Super-resolution and Deblurring


By Prof. Seungchul Lee
http://iailab.kaist.ac.kr/
Industrial AI Lab at KAIST

Table of Contents

1. Image Restoration





  • Image restoration tries to recover the original image from a degraded one using prior knowledge of the degradation process.
  • The sources of corruption in digital images arise during image acquisition (digitization) and transmission; a minimal degradation sketch follows after this list.

    • Imaging sensors can be affected by ambient conditions.
    • Interference can be added to an image during transmission.
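A minimal sketch of a common degradation model, y = h * x + n (blur by a point-spread function, then additive noise). The kernel width and noise level below are illustrative assumptions, not parameters taken from this notebook's data.

import numpy as np
from scipy.ndimage import gaussian_filter

# illustrative degradation: blur with a Gaussian point-spread function, then add noise
x = np.zeros((64, 64))                       # clean synthetic image: a bright square
x[24:40, 24:40] = 1.0

y = gaussian_filter(x, sigma = 2)            # blurring introduced during acquisition
y = y + 0.05 * np.random.randn(*y.shape)     # interference/noise added during transmission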

2. Inverse Problem




  • The reconstruction is the inverse of the acquisition.
  • Inverse problems involve modeling the degradation and applying the inverse process to recover the original image from inadequate observations.
  • The observations contain incomplete information about the target parameter or data due to physical limitations of the measurement devices.
  • Consequently, solutions to inverse problems are generally non-unique; a toy example follows after this list.
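As a toy illustration of this non-uniqueness, consider a linear model y = Ax with fewer observations than unknowns; the matrix sizes below are arbitrary choices for the sketch.

import numpy as np

# toy linear inverse problem: 3 observations, 5 unknowns
np.random.seed(0)
A = np.random.randn(3, 5)           # measurement operator
x_true = np.random.randn(5)         # unknown signal
y = A @ x_true                      # incomplete observations

x_hat = np.linalg.pinv(A) @ y       # minimum-norm least-squares solution (one of infinitely many)

print(np.allclose(A @ x_hat, y))    # True : the observations are explained
print(np.allclose(x_hat, x_true))   # False: the original signal is generally not recovered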

3. Image Super-resolution

  • Restore a high-resolution (HR) image from a low-resolution (LR) image.




  • There are numerous learning-based SR approaches; for contrast, a simple non-learning interpolation baseline is sketched below.
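For reference, the simplest non-learning baseline is plain interpolation; the sketch below uses bicubic resizing on a placeholder single-channel image and is only meant to contrast with the learned model that follows.

import numpy as np
import tensorflow as tf

# classical baseline: 2x upsampling by bicubic interpolation (no learning involved)
lr = np.random.rand(1, 112, 112, 1).astype('float32')            # placeholder LR batch
hr_bicubic = tf.image.resize(lr, (224, 224), method = 'bicubic')
print(hr_bicubic.shape)                                           # (1, 224, 224, 1)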


In [1]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
In [2]:
from google.colab import drive
drive.mount('/content/drive')
Mounted at /content/drive
3.1. Low-resolution and High-resolution Images

In [3]:
train_lr = np.load('/content/drive/MyDrive/DL_Colab/DL_data/SR_train_lr.npy')
train_hr = np.load('/content/drive/MyDrive/DL_Colab/DL_data/SR_train_hr.npy')
test_lr = np.load('/content/drive/MyDrive/DL_Colab/DL_data/SR_test_lr.npy')

n_train = train_lr.shape[0]
n_test = test_lr.shape[0]

print ("The number of training LR images : {}, shape : {}".format(n_train, train_lr.shape))
print ("The number of training HR images : {}, shape : {}".format(n_train, train_hr.shape))
print ("The number of testing LR images  : {}, shape : {}".format(n_test, test_lr.shape))
The number of training LR images : 79, shape : (79, 112, 112, 1)
The number of training HR images : 79, shape : (79, 224, 224, 1)
The number of testing LR images  : 20, shape : (20, 112, 112, 1)
In [4]:
idx = np.random.randint(n_train)

plt.figure(figsize = (8, 6))
plt.subplot(1,2,1)
plt.imshow(train_lr[idx][:,:,0], 'gray')
plt.title('Low-resolution image')
plt.axis('off')
plt.subplot(1,2,2)
plt.imshow(train_hr[idx][:,:,0], 'gray')
plt.title('High-resolution image')
plt.axis('off')
plt.show()

3.2. Build an FCN Model



In [5]:
inputs = tf.keras.Input(shape = (112, 112, 1))

# 3x3 convolutional layer
x = tf.keras.layers.Conv2D(filters = 16,
                           kernel_size = (3,3),
                           padding = 'SAME',
                           activation = 'relu')(inputs)

# first residual block
x_skip = x
x = tf.keras.layers.Conv2D(filters = 16,
                           kernel_size = (3,3),
                           padding = 'SAME',
                           activation = 'relu')(x)

x = tf.keras.layers.Conv2D(filters = 16,
                           kernel_size = (3,3),
                           padding = 'SAME',
                           activation = 'relu')(x)

x = tf.keras.layers.Add()([x_skip, x])

# second residual block
x_skip = x
x = tf.keras.layers.Conv2D(filters = 16,
                           kernel_size = (3,3),
                           padding = 'SAME',
                           activation = 'relu')(x)

x = tf.keras.layers.Conv2D(filters = 16,
                           kernel_size = (3,3),
                           padding = 'SAME',
                           activation = 'relu')(x)

x = tf.keras.layers.Add()([x_skip, x])

# third residual block
x_skip = x
x = tf.keras.layers.Conv2D(filters = 16,
                           kernel_size = (3,3),
                           padding = 'SAME',
                           activation = 'relu')(x)

x = tf.keras.layers.Conv2D(filters = 16,
                           kernel_size = (3,3),
                           padding = 'SAME',
                           activation = 'relu')(x)

x = tf.keras.layers.Add()([x_skip, x])

# upsampling layer
x = tf.keras.layers.Conv2DTranspose(filters = 16,
                                    kernel_size = (4,4),
                                    strides = (2,2),
                                    padding = 'SAME',
                                    activation = 'relu')(x)

# 3x3 convolutional layer
outputs = tf.keras.layers.Conv2D(filters = 1,
                                 kernel_size = (3,3),
                                 padding = 'SAME',
                                 activation = 'sigmoid')(x)

model = tf.keras.Model(inputs, outputs)
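The three residual blocks above are identical, so the same architecture can be written more compactly with a small helper. residual_block and build_sr_fcn below are hypothetical names (not Keras APIs), and the sketch is intended to build a network equivalent to the one defined above.

def residual_block(x, filters = 16):
    # two 3x3 convolutions plus an identity skip connection
    x_skip = x
    x = tf.keras.layers.Conv2D(filters, (3,3), padding = 'SAME', activation = 'relu')(x)
    x = tf.keras.layers.Conv2D(filters, (3,3), padding = 'SAME', activation = 'relu')(x)
    return tf.keras.layers.Add()([x_skip, x])

def build_sr_fcn():
    inputs = tf.keras.Input(shape = (112, 112, 1))
    x = tf.keras.layers.Conv2D(16, (3,3), padding = 'SAME', activation = 'relu')(inputs)
    for _ in range(3):
        x = residual_block(x)
    x = tf.keras.layers.Conv2DTranspose(16, (4,4), strides = (2,2), padding = 'SAME', activation = 'relu')(x)
    outputs = tf.keras.layers.Conv2D(1, (3,3), padding = 'SAME', activation = 'sigmoid')(x)
    return tf.keras.Model(inputs, outputs)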

3.3. Training

In [6]:
model.compile(optimizer = 'adam',
              loss = 'mean_absolute_error',
              metrics = ['mean_squared_error'])
In [7]:
model.fit(train_lr, train_hr, batch_size = 16, epochs = 30)
Epoch 1/30
5/5 [==============================] - 14s 143ms/step - loss: 0.1640 - mean_squared_error: 0.0454
Epoch 2/30
5/5 [==============================] - 0s 30ms/step - loss: 0.1558 - mean_squared_error: 0.0480
Epoch 3/30
5/5 [==============================] - 0s 30ms/step - loss: 0.1521 - mean_squared_error: 0.0449
Epoch 4/30
5/5 [==============================] - 0s 31ms/step - loss: 0.1479 - mean_squared_error: 0.0425
Epoch 5/30
5/5 [==============================] - 0s 30ms/step - loss: 0.1425 - mean_squared_error: 0.0403
Epoch 6/30
5/5 [==============================] - 0s 30ms/step - loss: 0.1363 - mean_squared_error: 0.0349
Epoch 7/30
5/5 [==============================] - 0s 31ms/step - loss: 0.1267 - mean_squared_error: 0.0301
Epoch 8/30
5/5 [==============================] - 0s 31ms/step - loss: 0.1195 - mean_squared_error: 0.0258
Epoch 9/30
5/5 [==============================] - 0s 29ms/step - loss: 0.1094 - mean_squared_error: 0.0211
Epoch 10/30
5/5 [==============================] - 0s 37ms/step - loss: 0.1028 - mean_squared_error: 0.0185
Epoch 11/30
5/5 [==============================] - 0s 41ms/step - loss: 0.1015 - mean_squared_error: 0.0179
Epoch 12/30
5/5 [==============================] - 0s 39ms/step - loss: 0.0997 - mean_squared_error: 0.0169
Epoch 13/30
5/5 [==============================] - 0s 36ms/step - loss: 0.0909 - mean_squared_error: 0.0143
Epoch 14/30
5/5 [==============================] - 0s 34ms/step - loss: 0.0861 - mean_squared_error: 0.0129
Epoch 15/30
5/5 [==============================] - 0s 34ms/step - loss: 0.0812 - mean_squared_error: 0.0117
Epoch 16/30
5/5 [==============================] - 0s 33ms/step - loss: 0.0767 - mean_squared_error: 0.0106
Epoch 17/30
5/5 [==============================] - 0s 36ms/step - loss: 0.0726 - mean_squared_error: 0.0097
Epoch 18/30
5/5 [==============================] - 0s 35ms/step - loss: 0.0738 - mean_squared_error: 0.0098
Epoch 19/30
5/5 [==============================] - 0s 33ms/step - loss: 0.0707 - mean_squared_error: 0.0093
Epoch 20/30
5/5 [==============================] - 0s 35ms/step - loss: 0.0738 - mean_squared_error: 0.0097
Epoch 21/30
5/5 [==============================] - 0s 34ms/step - loss: 0.0726 - mean_squared_error: 0.0095
Epoch 22/30
5/5 [==============================] - 0s 32ms/step - loss: 0.0701 - mean_squared_error: 0.0090
Epoch 23/30
5/5 [==============================] - 0s 32ms/step - loss: 0.0686 - mean_squared_error: 0.0087
Epoch 24/30
5/5 [==============================] - 0s 40ms/step - loss: 0.0671 - mean_squared_error: 0.0083
Epoch 25/30
5/5 [==============================] - 0s 34ms/step - loss: 0.0673 - mean_squared_error: 0.0083
Epoch 26/30
5/5 [==============================] - 0s 35ms/step - loss: 0.0679 - mean_squared_error: 0.0084
Epoch 27/30
5/5 [==============================] - 0s 34ms/step - loss: 0.0661 - mean_squared_error: 0.0081
Epoch 28/30
5/5 [==============================] - 0s 35ms/step - loss: 0.0654 - mean_squared_error: 0.0079
Epoch 29/30
5/5 [==============================] - 0s 36ms/step - loss: 0.0682 - mean_squared_error: 0.0083
Epoch 30/30
5/5 [==============================] - 0s 33ms/step - loss: 0.0672 - mean_squared_error: 0.0083
Out[7]:
<keras.src.callbacks.History at 0x7b0d28239c30>

3.4. Testing

In [8]:
test_x = test_lr[[3]]
test_sr = model.predict(test_x)

plt.figure(figsize = (8, 6))
plt.subplot(1,2,1)
plt.imshow(test_x[0][:,:,0], 'gray')
plt.title('Low-resolution image')
plt.axis('off')
plt.subplot(1,2,2)
plt.imshow(test_sr[0][:,:,0], 'gray')
plt.title('Super-resolved image')
plt.axis('off')
plt.show()
1/1 [==============================] - 1s 734ms/step
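The visual check can be complemented with a quantitative one. The sketch below computes PSNR on the training pairs (the only ones with ground-truth HR images loaded here) and compares the network against bicubic upsampling, assuming pixel values are scaled to [0, 1] (consistent with the sigmoid output).

# PSNR on the training set, where ground-truth HR images are available
pred_hr = model.predict(train_lr)
bicubic_hr = tf.image.resize(train_lr, (224, 224), method = 'bicubic')

hr = tf.cast(train_hr, tf.float32)
psnr_model = tf.image.psnr(hr, tf.cast(pred_hr, tf.float32), max_val = 1.0)
psnr_bicubic = tf.image.psnr(hr, tf.cast(bicubic_hr, tf.float32), max_val = 1.0)

print('Mean PSNR (FCN)     : {:.2f} dB'.format(tf.reduce_mean(psnr_model).numpy()))
print('Mean PSNR (bicubic) : {:.2f} dB'.format(tf.reduce_mean(psnr_bicubic).numpy()))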

4. Image Deblurring

4.1. Blurred and Deblurred Images

Download data from here
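As an aside, blurred/sharp training pairs like these can be synthesized from sharp images by applying a known blur kernel; the Gaussian blur below is an illustrative assumption, not a statement about how this particular dataset was produced.

from scipy.ndimage import gaussian_filter

# sketch: synthesize a blurred input from a sharp image batch of shape (N, H, W, 1);
# sigma controls blur strength and is applied to the spatial axes only
def make_blur_pair(sharp, sigma = 2.0):
    blurred = gaussian_filter(sharp, sigma = (0, sigma, sigma, 0))
    return blurred, sharp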

In [9]:
train_blur = np.load('/content/drive/MyDrive/DL_Colab/DL_data/deblurring_train_blur.npy')
train_deblur = np.load('/content/drive/MyDrive/DL_Colab/DL_data/deblurring_train_deblur.npy')
test_blur = np.load('/content/drive/MyDrive/DL_Colab/DL_data/deblurring_test_blur.npy')

n_train = train_blur.shape[0]
n_test = test_blur.shape[0]

print ("The number of training blur images   : {}, shape : {}".format(n_train, train_blur.shape))
print ("The number of training deblur images : {}, shape : {}".format(n_train, train_deblur.shape))
print ("The number of testing blur images    : {}, shape : {}".format(n_test, test_blur.shape))
The number of training blur images   : 79, shape : (79, 224, 224, 1)
The number of training deblur images : 79, shape : (79, 224, 224, 1)
The number of testing blur images    : 20, shape : (20, 224, 224, 1)
In [10]:
idx = np.random.randint(n_train)

plt.figure(figsize = (8, 6))
plt.subplot(1,2,1)
plt.imshow(train_blur[idx][:,:,0], 'gray')
plt.title('Blurred image')
plt.axis('off')
plt.subplot(1,2,2)
plt.imshow(train_deblur[idx][:,:,0], 'gray')
plt.title('Deblurred image')
plt.axis('off')
plt.show()

4.2. Build an FCN Model



In [11]:
inputs = tf.keras.Input(shape = (224, 224, 1))

# 3x3 convolutional layer
x = tf.keras.layers.Conv2D(filters = 16,
                           kernel_size = (3,3),
                           padding = 'SAME',
                           activation = 'relu')(inputs)

# first residual block
x_skip = x
x = tf.keras.layers.Conv2D(filters = 16,
                           kernel_size = (3,3),
                           padding = 'SAME',
                           activation = 'relu')(x)

x = tf.keras.layers.Conv2D(filters = 16,
                           kernel_size = (3,3),
                           padding = 'SAME',
                           activation = 'relu')(x)

x = tf.keras.layers.Add()([x_skip, x])

# second residual block
x_skip = x
x = tf.keras.layers.Conv2D(filters = 16,
                           kernel_size = (3,3),
                           padding = 'SAME',
                           activation = 'relu')(x)

x = tf.keras.layers.Conv2D(filters = 16,
                           kernel_size = (3,3),
                           padding = 'SAME',
                           activation = 'relu')(x)

x = tf.keras.layers.Add()([x_skip, x])

# third residual block
x_skip = x
x = tf.keras.layers.Conv2D(filters = 16,
                           kernel_size = (3,3),
                           padding = 'SAME',
                           activation = 'relu')(x)

x = tf.keras.layers.Conv2D(filters = 16,
                           kernel_size = (3,3),
                           padding = 'SAME',
                           activation = 'relu')(x)

x = tf.keras.layers.Add()([x_skip, x])

# 3x3 convolutional layer
outputs = tf.keras.layers.Conv2D(filters = 1,
                                 kernel_size = (3,3),
                                 padding = 'SAME',
                                 activation = 'sigmoid')(x)

model = tf.keras.Model(inputs, outputs)
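Unlike the super-resolution network, this model keeps the spatial size unchanged (blurring degrades detail but not array dimensions), so no transposed-convolution upsampling layer is needed. A quick sanity check on the built model:

# input and output spatial sizes should match for deblurring
print(model.input_shape)     # (None, 224, 224, 1)
print(model.output_shape)    # (None, 224, 224, 1)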

4.3. Training

In [12]:
model.compile(optimizer = 'adam',
              loss = 'mean_absolute_error',
              metrics = ['mean_squared_error'])
In [13]:
model.fit(train_blur, train_deblur, batch_size = 16, epochs = 30)
Epoch 1/30
5/5 [==============================] - 4s 198ms/step - loss: 0.1618 - mean_squared_error: 0.0480
Epoch 2/30
5/5 [==============================] - 0s 71ms/step - loss: 0.1571 - mean_squared_error: 0.0486
Epoch 3/30
5/5 [==============================] - 0s 72ms/step - loss: 0.1491 - mean_squared_error: 0.0425
Epoch 4/30
5/5 [==============================] - 0s 72ms/step - loss: 0.1393 - mean_squared_error: 0.0377
Epoch 5/30
5/5 [==============================] - 0s 74ms/step - loss: 0.1262 - mean_squared_error: 0.0301
Epoch 6/30
5/5 [==============================] - 0s 81ms/step - loss: 0.1049 - mean_squared_error: 0.0215
Epoch 7/30
5/5 [==============================] - 0s 79ms/step - loss: 0.0897 - mean_squared_error: 0.0156
Epoch 8/30
5/5 [==============================] - 0s 88ms/step - loss: 0.0801 - mean_squared_error: 0.0120
Epoch 9/30
5/5 [==============================] - 0s 82ms/step - loss: 0.0749 - mean_squared_error: 0.0105
Epoch 10/30
5/5 [==============================] - 0s 80ms/step - loss: 0.0697 - mean_squared_error: 0.0095
Epoch 11/30
5/5 [==============================] - 0s 89ms/step - loss: 0.0668 - mean_squared_error: 0.0086
Epoch 12/30
5/5 [==============================] - 0s 90ms/step - loss: 0.0636 - mean_squared_error: 0.0077
Epoch 13/30
5/5 [==============================] - 0s 81ms/step - loss: 0.0606 - mean_squared_error: 0.0071
Epoch 14/30
5/5 [==============================] - 0s 77ms/step - loss: 0.0587 - mean_squared_error: 0.0066
Epoch 15/30
5/5 [==============================] - 0s 76ms/step - loss: 0.0583 - mean_squared_error: 0.0064
Epoch 16/30
5/5 [==============================] - 0s 74ms/step - loss: 0.0572 - mean_squared_error: 0.0061
Epoch 17/30
5/5 [==============================] - 0s 74ms/step - loss: 0.0560 - mean_squared_error: 0.0058
Epoch 18/30
5/5 [==============================] - 0s 74ms/step - loss: 0.0517 - mean_squared_error: 0.0052
Epoch 19/30
5/5 [==============================] - 0s 74ms/step - loss: 0.0510 - mean_squared_error: 0.0050
Epoch 20/30
5/5 [==============================] - 0s 74ms/step - loss: 0.0499 - mean_squared_error: 0.0048
Epoch 21/30
5/5 [==============================] - 0s 74ms/step - loss: 0.0489 - mean_squared_error: 0.0046
Epoch 22/30
5/5 [==============================] - 0s 73ms/step - loss: 0.0514 - mean_squared_error: 0.0048
Epoch 23/30
5/5 [==============================] - 0s 73ms/step - loss: 0.0508 - mean_squared_error: 0.0047
Epoch 24/30
5/5 [==============================] - 0s 74ms/step - loss: 0.0496 - mean_squared_error: 0.0045
Epoch 25/30
5/5 [==============================] - 0s 73ms/step - loss: 0.0483 - mean_squared_error: 0.0043
Epoch 26/30
5/5 [==============================] - 0s 74ms/step - loss: 0.0460 - mean_squared_error: 0.0040
Epoch 27/30
5/5 [==============================] - 0s 75ms/step - loss: 0.0452 - mean_squared_error: 0.0039
Epoch 28/30
5/5 [==============================] - 0s 74ms/step - loss: 0.0459 - mean_squared_error: 0.0040
Epoch 29/30
5/5 [==============================] - 0s 74ms/step - loss: 0.0442 - mean_squared_error: 0.0038
Epoch 30/30
5/5 [==============================] - 0s 74ms/step - loss: 0.0431 - mean_squared_error: 0.0036
Out[13]:
<keras.src.callbacks.History at 0x7b0cfffc50c0>

4.4. Testing

In [14]:
test_x = test_blur[[1]]
test_deblur = model.predict(test_x)

plt.figure(figsize = (8, 6))
plt.subplot(1,2,1)
plt.imshow(test_x[0][:,:,0], 'gray')
plt.title('Blurred image')
plt.axis('off')
plt.subplot(1,2,2)
plt.imshow(test_deblur[0][:,:,0], 'gray')
plt.title('Deblurred image')
plt.axis('off')
plt.show()
1/1 [==============================] - 0s 124ms/step
In [15]:
%%javascript
$.getScript('https://kmahelona.github.io/ipython_notebook_goodies/ipython_notebook_toc.js')