Fully Convolutional Networks for Segmentation
Table of Contents
%%html
<center><iframe src="https://www.youtube.com/embed/8-PA11R3e9c?start=1915&rel=0"
width="560" height="315" frameborder="0" allowfullscreen></iframe></center>
%%html
<center><iframe src="https://www.youtube.com/embed/4vyohdppEoY?rel=0"
width="560" height="315" frameborder="0" allowfullscreen></iframe></center>
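To obtain a segmentation map (the output), segmentation networks usually have two parts: a downsampling path (encoder) that extracts features at progressively coarser resolutions, and an upsampling path (decoder) that restores those features to the input resolution to produce a per-pixel prediction.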
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
# Load the pre-split dataset: training images, their segmentation targets,
# and the test images (paths are relative to the notebook's working dir).
train_imgs = np.load('./data_files/images_training.npy')
train_seg = np.load('./data_files/seg_training.npy')
test_imgs = np.load('./data_files/images_testing.npy')

n_train = train_imgs.shape[0]
n_test = test_imgs.shape[0]

# Report each array's own leading dimension. The segmentation count is read
# from train_seg itself (not reused from n_train), so a mismatch between
# images and targets shows up here instead of later inside model.fit.
print("The number of training images : {}, shape : {}".format(n_train, train_imgs.shape))
print("The number of segmented images : {}, shape : {}".format(train_seg.shape[0], train_seg.shape))
print("The number of testing images : {}, shape : {}".format(n_test, test_imgs.shape))
# Display one randomly chosen training example in a single row: the input
# image, then channel 0 and channel 1 of its segmentation target.
# NOTE(review): channel 1 is presumably the foreground (it is the channel
# thresholded after prediction below) — confirm against the dataset.
idx = np.random.randint(n_train)
sample_seg = train_seg[idx]
panels = (train_imgs[idx], sample_seg[:, :, 0], sample_seg[:, :, 1])

plt.figure(figsize=(15, 10))
for position, panel in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    plt.imshow(panel)
    plt.axis('off')
plt.show()
# Encoder: Keras VGG16 with its default (ImageNet) weights, frozen so that
# only the new FCN layers added below are trained.
model_type = tf.keras.applications.vgg16
base_model = model_type.VGG16()
base_model.trainable = False
base_model.summary()

# Encoder feature maps, looked up by Keras layer name rather than by position.
# In the default VGG16 these are the same layers as layers[-5], layers[14] and
# layers[10]. The stride-2 upsampling steps below imply the relative scales:
# map5 is half the size of pool4, and pool4 is half the size of pool3.
map5 = base_model.get_layer('block5_pool').output
pool4 = base_model.get_layer('block4_pool').output
pool3 = base_model.get_layer('block3_pool').output

# sixth convolution layer (dense classifier re-expressed as a 7x7 conv)
conv6 = tf.keras.layers.Conv2D(filters=4096,
                               kernel_size=(7, 7),
                               padding='SAME',
                               activation='relu')(map5)

# 1x1 convolution layers
fcn4 = tf.keras.layers.Conv2D(filters=4096,
                              kernel_size=(1, 1),
                              padding='SAME',
                              activation='relu')(conv6)
# Two-channel score map, one channel per output class.
# NOTE(review): a ReLU on this 2-channel bottleneck can zero out the whole
# coarse path when both scores go negative; the original FCN uses linear score
# layers. Consider activation=None (this changes training results).
fcn3 = tf.keras.layers.Conv2D(filters=2,
                              kernel_size=(1, 1),
                              padding='SAME',
                              activation='relu')(fcn4)

# Upsampling layers (decoder) with additive skip connections.
# x2 upsampling to pool4's size; 512 filters to match pool4's channels.
fcn2 = tf.keras.layers.Conv2DTranspose(filters=512,
                                       kernel_size=(4, 4),
                                       strides=(2, 2),
                                       padding='SAME')(fcn3)
# Fuse with pool4, then x2 upsample to pool3's size (256 channels to match).
fcn1 = tf.keras.layers.Conv2DTranspose(filters=256,
                                       kernel_size=(4, 4),
                                       strides=(2, 2),
                                       padding='SAME')(tf.keras.layers.Add()([fcn2, pool4]))
# Fuse with pool3, then x8 upsample back to the input resolution with a
# per-pixel softmax over the 2 classes.
output = tf.keras.layers.Conv2DTranspose(filters=2,
                                         kernel_size=(16, 16),
                                         strides=(8, 8),
                                         padding='SAME',
                                         activation='softmax')(tf.keras.layers.Add()([fcn1, pool3]))

model = tf.keras.Model(inputs=base_model.inputs, outputs=output)
model.summary()
# Train the FCN as a per-pixel 2-class classifier against train_seg.
# categorical_crossentropy presumably expects train_seg to be one-hot along its
# last axis (2 channels) — confirm against the dataset.
# `metrics` must be a list: Keras documents it as a list, and Keras 3 raises a
# ValueError when given a bare string such as 'accuracy'.
model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])
model.fit(train_imgs, train_seg, batch_size=5, epochs=5)
# Predict on a single test image. Indexing with a list ([[1]]) keeps the
# leading batch axis, so test_x has shape (1, ...).
test_x = test_imgs[[1]]
test_seg = model.predict(test_x)

# Binary mask: pixels whose class-1 probability exceeds 0.5. Slicing with `1:`
# keeps a trailing channel axis, giving shape (H, W, 1) for any input size
# rather than relying on a hard-coded reshape(224, 224, 1).
seg_mask = (test_seg[0, :, :, 1:] > 0.5).astype(float)

plt.figure(figsize=(14, 14))
# Input image.
plt.subplot(2, 2, 1)
plt.imshow(test_x[0])
plt.axis('off')
# Predicted mask alone.
plt.subplot(2, 2, 2)
plt.imshow(seg_mask, cmap='Blues')
plt.axis('off')
# Mask overlaid semi-transparently on the input image.
plt.subplot(2, 2, 3)
plt.imshow(test_x[0])
plt.imshow(seg_mask, cmap='Blues', alpha=0.5)
plt.axis('off')
plt.show()
%%javascript
$.getScript('https://kmahelona.github.io/ipython_notebook_goodies/ipython_notebook_toc.js')