Super-resolution and Deblurring
Table of Contents
Corruption in digital images typically arises during image acquisition (digitization) and during transmission.
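Download the data from here and place the files in `./data_files/`. The code below loads `lr_training.npy`, `hr_training.npy`, `lr_testing.npy`, `blur_training.npy`, `deblur_training.npy` and `blur_testing.npy` from that folder.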
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
# Load the super-resolution data: low-resolution (LR) training inputs, their
# high-resolution (HR) training targets, and LR test inputs.
train_lr = np.load('./data_files/lr_training.npy')
train_hr = np.load('./data_files/hr_training.npy')
test_lr = np.load('./data_files/lr_testing.npy')

n_train = train_lr.shape[0]
n_test = test_lr.shape[0]
# Count the HR images from the HR array itself; reusing n_train (the LR count)
# would hide a mismatch between the two files.
n_train_hr = train_hr.shape[0]

print("The number of training LR images : {}, shape : {}".format(n_train, train_lr.shape))
print("The number of training HR images : {}, shape : {}".format(n_train_hr, train_hr.shape))
print("The number of testing LR images : {}, shape : {}".format(n_test, test_lr.shape))

# model.fit(train_lr, train_hr, ...) treats the arrays as (input, target)
# pairs, so their lengths must agree; fail early with a clear message.
if n_train_hr != n_train:
    raise ValueError(
        "LR/HR training sets differ in size: {} LR vs {} HR images".format(n_train, n_train_hr))
# Display one randomly selected training pair side by side: the LR input on
# the left and its HR target on the right (channel 0, grayscale colormap).
idx = np.random.randint(n_train)
pair_panels = [
    (train_lr[idx], 'Low-resolution image'),
    (train_hr[idx], 'High-resolution image'),
]

plt.figure(figsize=(20, 16))
for panel_no, (panel_img, panel_title) in enumerate(pair_panels, start=1):
    plt.subplot(1, 2, panel_no)
    plt.imshow(panel_img[:, :, 0], 'gray')
    plt.title(panel_title, fontsize=20)
    plt.axis('off')
plt.show()
# Super-resolution network:
#   3x3 conv stem -> residual blocks -> stride-2 transposed conv (upsampling)
#   -> 3x3 conv with a sigmoid, giving a single-channel output in (0, 1).
# The input shape is read from the training data rather than hard-coded, so
# the model always matches the arrays later passed to model.fit.
n_filters = 16      # feature channels used throughout the network
n_res_blocks = 3    # number of residual blocks

inputs = tf.keras.Input(shape = train_lr.shape[1:])

# 3x3 convolutional layer
x = tf.keras.layers.Conv2D(filters = n_filters,
                           kernel_size = (3,3),
                           padding = 'SAME',
                           activation = 'relu')(inputs)

# Residual blocks: two 3x3 convolutions plus an identity skip connection.
# 'SAME' padding keeps height/width unchanged, so the Add shapes always match.
for _block in range(n_res_blocks):
    x_skip = x
    for _conv in range(2):
        x = tf.keras.layers.Conv2D(filters = n_filters,
                                   kernel_size = (3,3),
                                   padding = 'SAME',
                                   activation = 'relu')(x)
    x = tf.keras.layers.Add()([x_skip, x])

# upsampling layer: with 'SAME' padding, stride 2 doubles height and width
x = tf.keras.layers.Conv2DTranspose(filters = n_filters,
                                    kernel_size = (4,4),
                                    strides = (2,2),
                                    padding = 'SAME',
                                    activation = 'relu')(x)

# 3x3 convolutional layer projecting back to a single channel
outputs = tf.keras.layers.Conv2D(filters = 1,
                                 kernel_size = (3,3),
                                 padding = 'SAME',
                                 activation = 'sigmoid')(x)

model = tf.keras.Model(inputs, outputs)
# Train the super-resolution model with Adam on an L1 (mean absolute error)
# loss; MSE is only reported as an extra metric.
# 30 epochs, mini-batches of 16 LR/HR pairs.
model.compile(
    optimizer='adam',
    loss='mean_absolute_error',
    metrics=['mean_squared_error'],
)
model.fit(x=train_lr, y=train_hr, batch_size=16, epochs=30)
# Super-resolve test image #3. Indexing with [[3]] (a list) keeps the leading
# batch axis, which model.predict expects. The LR input and the model output
# are then shown side by side.
test_x = test_lr[[3]]
test_sr = model.predict(test_x)

result_panels = (
    (test_x[0], 'Low-resolution image'),
    (test_sr[0], 'Super-resolved image'),
)
plt.figure(figsize=(20, 16))
for slot, (shown, heading) in enumerate(result_panels, start=1):
    plt.subplot(1, 2, slot)
    plt.imshow(shown[:, :, 0], 'gray')
    plt.title(heading, fontsize=20)
    plt.axis('off')
plt.show()
# Load the deblurring data: blurred training inputs, their sharp
# ("deblur") training targets, and blurred test inputs.
train_blur = np.load('./data_files/blur_training.npy')
train_deblur = np.load('./data_files/deblur_training.npy')
test_blur = np.load('./data_files/blur_testing.npy')

n_train = train_blur.shape[0]
n_test = test_blur.shape[0]
# Count the target images from their own array; reusing n_train (the blur
# count) would hide a mismatch between the two files.
n_train_deblur = train_deblur.shape[0]

print("The number of training blur images : {}, shape : {}".format(n_train, train_blur.shape))
print("The number of training deblur images : {}, shape : {}".format(n_train_deblur, train_deblur.shape))
print("The number of testing blur images : {}, shape : {}".format(n_test, test_blur.shape))

# model.fit(train_blur, train_deblur, ...) treats the arrays as
# (input, target) pairs, so their lengths must agree; fail early if not.
if n_train_deblur != n_train:
    raise ValueError(
        "Blur/deblur training sets differ in size: {} blur vs {} deblur images".format(
            n_train, n_train_deblur))
# Display one randomly selected training pair side by side: the blurred input
# on the left and its target image on the right (channel 0, grayscale).
idx = np.random.randint(n_train)
pair_panels = [
    (train_blur[idx], 'Blurred image'),
    (train_deblur[idx], 'Deblurred image'),
]

plt.figure(figsize=(20, 16))
for panel_no, (panel_img, panel_title) in enumerate(pair_panels, start=1):
    plt.subplot(1, 2, panel_no)
    plt.imshow(panel_img[:, :, 0], 'gray')
    plt.title(panel_title, fontsize=20)
    plt.axis('off')
plt.show()
# Deblurring network: 3x3 conv stem -> residual blocks -> 3x3 conv with a
# sigmoid, giving a single-channel output in (0, 1).
# There is no upsampling; the output has the same height/width as the input.
# The input shape is read from the training data rather than hard-coded, so
# the model always matches the arrays later passed to model.fit.
n_filters = 16      # feature channels used throughout the network
n_res_blocks = 3    # number of residual blocks

inputs = tf.keras.Input(shape = train_blur.shape[1:])

# 3x3 convolutional layer
x = tf.keras.layers.Conv2D(filters = n_filters,
                           kernel_size = (3,3),
                           padding = 'SAME',
                           activation = 'relu')(inputs)

# Residual blocks: two 3x3 convolutions plus an identity skip connection.
# 'SAME' padding keeps height/width unchanged, so the Add shapes always match.
for _block in range(n_res_blocks):
    x_skip = x
    for _conv in range(2):
        x = tf.keras.layers.Conv2D(filters = n_filters,
                                   kernel_size = (3,3),
                                   padding = 'SAME',
                                   activation = 'relu')(x)
    x = tf.keras.layers.Add()([x_skip, x])

# 3x3 convolutional layer projecting back to a single channel
outputs = tf.keras.layers.Conv2D(filters = 1,
                                 kernel_size = (3,3),
                                 padding = 'SAME',
                                 activation = 'sigmoid')(x)

model = tf.keras.Model(inputs, outputs)
# Train the deblurring model with Adam on an L1 (mean absolute error) loss;
# MSE is only reported as an extra metric.
# 30 epochs, mini-batches of 16 blur/target pairs.
model.compile(
    optimizer='adam',
    loss='mean_absolute_error',
    metrics=['mean_squared_error'],
)
model.fit(x=train_blur, y=train_deblur, batch_size=16, epochs=30)
# Deblur test image #1. Indexing with [[1]] (a list) keeps the leading batch
# axis, which model.predict expects. The blurred input and the model output
# are then shown side by side.
test_x = test_blur[[1]]
test_deblur = model.predict(test_x)

result_panels = (
    (test_x[0], 'Blurred image'),
    (test_deblur[0], 'Deblurred image'),
)
plt.figure(figsize=(20, 16))
for slot, (shown, heading) in enumerate(result_panels, start=1):
    plt.subplot(1, 2, slot)
    plt.imshow(shown[:, :, 0], 'gray')
    plt.title(heading, fontsize=20)
    plt.axis('off')
plt.show()
%%javascript
// Fetch and run a third-party script from kmahelona.github.io (needs network access).
// Presumably it renders the notebook's "Table of Contents" — confirm against the script itself.
$.getScript('https://kmahelona.github.io/ipython_notebook_goodies/ipython_notebook_toc.js')