Fully Convolutional Networks for Segmentation


By Prof. Seungchul Lee
http://iailab.kaist.ac.kr/
Industrial AI Lab at KAIST


1. Segmentation


  • Segmentation differs from classification: it requires predicting a class for every pixel of the input image, instead of a single class for the whole input.
  • Classification needs to understand what is in the input (namely, the context).
  • To predict a class for each pixel, however, segmentation needs to recover not only what is in the input, but also where it is.
  • The goal is to segment images into regions with different semantic categories, labeling and predicting objects at the pixel level; a sketch of such pixel-level labels follows below.
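
For example, with two classes the label of a 224 x 224 image is a 224 x 224 x 2 one-hot map, which is exactly the label shape of the dataset used below. A minimal sketch with a hypothetical 4 x 4 mask:

import numpy as np

# Hypothetical 4 x 4 integer mask: one class index per pixel
mask = np.array([[0, 0, 1, 1],
                 [0, 0, 1, 1],
                 [0, 0, 0, 0],
                 [0, 0, 0, 0]])

# One-hot encode along a new channel dimension: shape (H, W, num_classes)
one_hot = np.eye(2)[mask]
print(one_hot.shape)    # (4, 4, 2)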

2. Fully Convolutional Networks (FCN)


  • An FCN is built only from locally connected layers, such as convolution, pooling, and upsampling layers.
  • Note that no dense (fully connected) layer is used in this kind of architecture.
  • The network therefore works regardless of the original image size, without requiring a fixed number of units at any stage.
  • To obtain a segmentation map (the output), segmentation networks usually have two parts:

    • Downsampling path: captures semantic/contextual information
    • Upsampling path: recovers spatial information
  • The downsampling path extracts and interprets the context (what), while the upsampling path enables precise localization (where).
  • Furthermore, to fully recover the fine-grained spatial information lost in the pooling or downsampling layers, skip connections are often used.
  • Given a position on the spatial dimensions, the output along the channel dimension is the category prediction for the pixel at that location; a minimal sketch of this read-out follows below.
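
A minimal sketch of this read-out, using a hypothetical score array in place of an actual network output:

import numpy as np

# Hypothetical network output: one score per class at every pixel
scores = np.random.rand(224, 224, 2)

# The predicted class map is the argmax along the channel dimension
class_map = np.argmax(scores, axis = -1)
print(class_map.shape)    # (224, 224)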

3. Supervised Learning for Segmentation

3.1. Segmented (Labeled) Images

Download data

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
In [2]:
from google.colab import drive
drive.mount('/content/drive')
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
In [3]:
seg_train_imgs = np.load('/content/drive/MyDrive/DL_Colab/DL_data/seg_train_imgs.npy')
seg_train_labels = np.load('/content/drive/MyDrive/DL_Colab/DL_data/seg_train_labels.npy')
seg_test_imgs = np.load('/content/drive/MyDrive/DL_Colab/DL_data/seg_test_imgs.npy')

n_train = seg_train_imgs.shape[0]
n_test = seg_test_imgs.shape[0]

print ("The number of training images  : {}, shape : {}".format(n_train, seg_train_imgs.shape))
print ("The number of segmented images : {}, shape : {}".format(n_train, seg_train_labels.shape))
print ("The number of testing images   : {}, shape : {}".format(n_test, seg_test_imgs.shape))
The number of training images  : 180, shape : (180, 224, 224, 3)
The number of segmented images : 180, shape : (180, 224, 224, 2)
The number of testing images   : 27, shape : (27, 224, 224, 3)
In [4]:
## binary segmentation and one-hot encoding in this case

idx = np.random.randint(n_train)

plt.figure(figsize = (10, 4))
plt.subplot(1,3,1)
plt.imshow(seg_train_imgs[idx])
plt.axis('off')
plt.subplot(1,3,2)
plt.imshow(seg_train_labels[idx][:,:,0])
plt.axis('off')
plt.subplot(1,3,3)
plt.imshow(seg_train_labels[idx][:,:,1])
plt.axis('off')
plt.show()

3.2. From CAE to FCN


  • CAE: a convolutional autoencoder already pairs a downsampling encoder with an upsampling decoder, but it is trained to reconstruct its input rather than to predict a label map.
  • FCN
    • VGG16 is used as the encoder (downsampling path)
    • Skip connections fully recover the fine-grained spatial information lost in the pooling or downsampling layers; a minimal sketch follows below
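A minimal sketch of a skip connection, assuming hypothetical feature shapes: the coarse decoder feature is upsampled by a stride-2 transposed convolution and then added element-wise to an encoder feature of the same shape.

import tensorflow as tf

# Hypothetical encoder and decoder feature maps
encoder_feat = tf.keras.Input(shape = (14, 14, 512))
decoder_feat = tf.keras.Input(shape = (7, 7, 512))

# Upsample the coarse feature to the encoder resolution, then fuse by addition
upsampled = tf.keras.layers.Conv2DTranspose(filters = 512,
                                            kernel_size = (4,4),
                                            strides = (2,2),
                                            padding = 'SAME')(decoder_feat)
fused = upsampled + encoder_feat    # (None, 14, 14, 512)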

4. FCN Implementation





4.1. Utilize VGG16 Model for Encoder

In [5]:
model_type = tf.keras.applications.vgg16
base_model = model_type.VGG16()    # ImageNet-pretrained VGG16, including the classifier head
base_model.trainable = False       # freeze the encoder weights
base_model.summary()
Model: "vgg16"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_1 (InputLayer)        [(None, 224, 224, 3)]     0         
                                                                 
 block1_conv1 (Conv2D)       (None, 224, 224, 64)      1792      
                                                                 
 block1_conv2 (Conv2D)       (None, 224, 224, 64)      36928     
                                                                 
 block1_pool (MaxPooling2D)  (None, 112, 112, 64)      0         
                                                                 
 block2_conv1 (Conv2D)       (None, 112, 112, 128)     73856     
                                                                 
 block2_conv2 (Conv2D)       (None, 112, 112, 128)     147584    
                                                                 
 block2_pool (MaxPooling2D)  (None, 56, 56, 128)       0         
                                                                 
 block3_conv1 (Conv2D)       (None, 56, 56, 256)       295168    
                                                                 
 block3_conv2 (Conv2D)       (None, 56, 56, 256)       590080    
                                                                 
 block3_conv3 (Conv2D)       (None, 56, 56, 256)       590080    
                                                                 
 block3_pool (MaxPooling2D)  (None, 28, 28, 256)       0         
                                                                 
 block4_conv1 (Conv2D)       (None, 28, 28, 512)       1180160   
                                                                 
 block4_conv2 (Conv2D)       (None, 28, 28, 512)       2359808   
                                                                 
 block4_conv3 (Conv2D)       (None, 28, 28, 512)       2359808   
                                                                 
 block4_pool (MaxPooling2D)  (None, 14, 14, 512)       0         
                                                                 
 block5_conv1 (Conv2D)       (None, 14, 14, 512)       2359808   
                                                                 
 block5_conv2 (Conv2D)       (None, 14, 14, 512)       2359808   
                                                                 
 block5_conv3 (Conv2D)       (None, 14, 14, 512)       2359808   
                                                                 
 block5_pool (MaxPooling2D)  (None, 7, 7, 512)         0         
                                                                 
 flatten (Flatten)           (None, 25088)             0         
                                                                 
 fc1 (Dense)                 (None, 4096)              102764544 
                                                                 
 fc2 (Dense)                 (None, 4096)              16781312  
                                                                 
 predictions (Dense)         (None, 1000)              4097000   
                                                                 
=================================================================
Total params: 138357544 (527.79 MB)
Trainable params: 0 (0.00 Byte)
Non-trainable params: 138357544 (527.79 MB)
_________________________________________________________________
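
The decoder in Section 4.2 indexes the encoder layers by position (base_model.layers[14], base_model.layers[10]); an equivalent, arguably clearer sketch retrieves the skip-connection sources by the layer names printed above:

skip3 = base_model.get_layer('block3_pool').output    # (None, 28, 28, 256)
skip4 = base_model.get_layer('block4_pool').output    # (None, 14, 14, 512)
map5  = base_model.get_layer('block5_pool').output    # (None, 7, 7, 512)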

4.2. Build a FCN Model

  • tf.keras.layers (in particular Conv2DTranspose) is used to define the upsampling part; a shape-rule sketch follows below
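
With padding = 'SAME', a transposed convolution scales the spatial size by its stride; the final decoder layer below relies on this to upsample 28 x 28 by a factor of 8 to 224 x 224. A minimal sketch:

import tensorflow as tf

x = tf.zeros((1, 28, 28, 256))

# strides = (8, 8) upsamples 28 x 28 -> 224 x 224
y = tf.keras.layers.Conv2DTranspose(filters = 2,
                                    kernel_size = (16,16),
                                    strides = (8,8),
                                    padding = 'SAME')(x)
print(y.shape)    # (1, 224, 224, 2)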





In [6]:
map5 = base_model.layers[-5].output    # block5_pool output, (None, 7, 7, 512)

# sixth convolution layer
conv6 = tf.keras.layers.Conv2D(filters = 4096,
                               kernel_size = (7,7),
                               padding = 'SAME',
                               activation = 'relu')(map5)

# 1x1 convolution layers
fcn4 = tf.keras.layers.Conv2D(filters = 4096,
                              kernel_size = (1,1),
                              padding = 'SAME',
                              activation = 'relu')(conv6)

fcn3 = tf.keras.layers.Conv2D(filters = 2,
                              kernel_size = (1,1),
                              padding = 'SAME',
                              activation = 'relu')(fcn4)

# Upsampling layers with skip connections
fcn2 = tf.keras.layers.Conv2DTranspose(filters = 512,
                                       kernel_size = (4,4),
                                       strides = (2,2),
                                       padding = 'SAME')(fcn3)    # 7x7 -> 14x14

# Skip connection: add the block4_pool feature map (layers[14], 14x14x512)
fcn1 = tf.keras.layers.Conv2DTranspose(filters = 256,
                                       kernel_size = (4,4),
                                       strides = (2,2),
                                       padding = 'SAME')(fcn2 + base_model.layers[14].output)    # 14x14 -> 28x28

# Skip connection: add the block3_pool feature map (layers[10], 28x28x256),
# then upsample by 8 to the full 224x224 resolution
output = tf.keras.layers.Conv2DTranspose(filters = 2,
                                         kernel_size = (16,16),
                                         strides = (8,8),
                                         padding = 'SAME',
                                         activation = 'softmax')(fcn1 + base_model.layers[10].output)

model = tf.keras.Model(inputs = base_model.inputs, outputs = output)
In [7]:
model.summary()
Model: "model"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
==================================================================================================
 input_1 (InputLayer)        [(None, 224, 224, 3)]        0         []                            
                                                                                                  
 block1_conv1 (Conv2D)       (None, 224, 224, 64)         1792      ['input_1[0][0]']             
                                                                                                  
 block1_conv2 (Conv2D)       (None, 224, 224, 64)         36928     ['block1_conv1[0][0]']        
                                                                                                  
 block1_pool (MaxPooling2D)  (None, 112, 112, 64)         0         ['block1_conv2[0][0]']        
                                                                                                  
 block2_conv1 (Conv2D)       (None, 112, 112, 128)        73856     ['block1_pool[0][0]']         
                                                                                                  
 block2_conv2 (Conv2D)       (None, 112, 112, 128)        147584    ['block2_conv1[0][0]']        
                                                                                                  
 block2_pool (MaxPooling2D)  (None, 56, 56, 128)          0         ['block2_conv2[0][0]']        
                                                                                                  
 block3_conv1 (Conv2D)       (None, 56, 56, 256)          295168    ['block2_pool[0][0]']         
                                                                                                  
 block3_conv2 (Conv2D)       (None, 56, 56, 256)          590080    ['block3_conv1[0][0]']        
                                                                                                  
 block3_conv3 (Conv2D)       (None, 56, 56, 256)          590080    ['block3_conv2[0][0]']        
                                                                                                  
 block3_pool (MaxPooling2D)  (None, 28, 28, 256)          0         ['block3_conv3[0][0]']        
                                                                                                  
 block4_conv1 (Conv2D)       (None, 28, 28, 512)          1180160   ['block3_pool[0][0]']         
                                                                                                  
 block4_conv2 (Conv2D)       (None, 28, 28, 512)          2359808   ['block4_conv1[0][0]']        
                                                                                                  
 block4_conv3 (Conv2D)       (None, 28, 28, 512)          2359808   ['block4_conv2[0][0]']        
                                                                                                  
 block4_pool (MaxPooling2D)  (None, 14, 14, 512)          0         ['block4_conv3[0][0]']        
                                                                                                  
 block5_conv1 (Conv2D)       (None, 14, 14, 512)          2359808   ['block4_pool[0][0]']         
                                                                                                  
 block5_conv2 (Conv2D)       (None, 14, 14, 512)          2359808   ['block5_conv1[0][0]']        
                                                                                                  
 block5_conv3 (Conv2D)       (None, 14, 14, 512)          2359808   ['block5_conv2[0][0]']        
                                                                                                  
 block5_pool (MaxPooling2D)  (None, 7, 7, 512)            0         ['block5_conv3[0][0]']        
                                                                                                  
 conv2d (Conv2D)             (None, 7, 7, 4096)           1027645   ['block5_pool[0][0]']         
                                                          44                                      
                                                                                                  
 conv2d_1 (Conv2D)           (None, 7, 7, 4096)           1678131   ['conv2d[0][0]']              
                                                          2                                       
                                                                                                  
 conv2d_2 (Conv2D)           (None, 7, 7, 2)              8194      ['conv2d_1[0][0]']            
                                                                                                  
 conv2d_transpose (Conv2DTr  (None, 14, 14, 512)          16896     ['conv2d_2[0][0]']            
 anspose)                                                                                         
                                                                                                  
 tf.__operators__.add (TFOp  (None, 14, 14, 512)          0         ['conv2d_transpose[0][0]',    
 Lambda)                                                             'block4_pool[0][0]']         
                                                                                                  
 conv2d_transpose_1 (Conv2D  (None, 28, 28, 256)          2097408   ['tf.__operators__.add[0][0]']
 Transpose)                                                                                       
                                                                                                  
 tf.__operators__.add_1 (TF  (None, 28, 28, 256)          0         ['conv2d_transpose_1[0][0]',  
 OpLambda)                                                           'block3_pool[0][0]']         
                                                                                                  
 conv2d_transpose_2 (Conv2D  (None, 224, 224, 2)          131074    ['tf.__operators__.add_1[0][0]
 Transpose)                                                         ']                            
                                                                                                  
==================================================================================================
Total params: 136514116 (520.76 MB)
Trainable params: 121799428 (464.63 MB)
Non-trainable params: 14714688 (56.13 MB)
__________________________________________________________________________________________________

4.3. Training

In [8]:
model.compile(optimizer = 'adam',
              loss = 'categorical_crossentropy',    # applied pixel-wise to the (224, 224, 2) maps
              metrics = ['accuracy'])
In [9]:
model.fit(seg_train_imgs, seg_train_labels, batch_size = 5, epochs = 5)
Epoch 1/5
36/36 [==============================] - 21s 206ms/step - loss: 0.4534 - accuracy: 0.8714
Epoch 2/5
36/36 [==============================] - 7s 200ms/step - loss: 0.2277 - accuracy: 0.9138
Epoch 3/5
36/36 [==============================] - 7s 200ms/step - loss: 0.2077 - accuracy: 0.9197
Epoch 4/5
36/36 [==============================] - 7s 200ms/step - loss: 0.2020 - accuracy: 0.9220
Epoch 5/5
36/36 [==============================] - 7s 201ms/step - loss: 0.1909 - accuracy: 0.9260
Out[9]:
<keras.src.callbacks.History at 0x7fcf20508fa0>
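
Here 'categorical_crossentropy' is evaluated at every pixel of the (224, 224, 2) maps and averaged, so the reported accuracy is likewise pixel-wise. A minimal sketch with a hypothetical uniform prediction:

import tensorflow as tf

y_true = tf.one_hot(tf.zeros((1, 224, 224), dtype = tf.int32), depth = 2)
y_pred = tf.fill((1, 224, 224, 2), 0.5)    # hypothetical uniform prediction

cce = tf.keras.losses.CategoricalCrossentropy()
print(cce(y_true, y_pred).numpy())    # -log(0.5) ~ 0.693 at every pixel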

4.4. Testing

In [10]:
test_img = seg_test_imgs[[1]]
test_segmented = model.predict(test_img)

# probability of class 1 at each pixel, thresholded at 0.5 to form a binary mask
seg_mask = (test_segmented[:,:,:,1] > 0.5).reshape(224, 224, 1).astype(float)

plt.figure(figsize = (8,8))
plt.subplot(2,2,1)
plt.imshow(test_img[0])
plt.axis('off')
plt.subplot(2,2,2)
plt.imshow(seg_mask, cmap = 'Blues')
plt.axis('off')
plt.subplot(2,2,3)
plt.imshow(test_img[0])
plt.imshow(seg_mask, cmap = 'Blues', alpha = 0.5)
plt.axis('off')
plt.show()
1/1 [==============================] - 1s 1s/step
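
Ground-truth masks are not provided for the test set here, so a quantitative check such as intersection-over-union (IoU) can only be sketched on a training image:

pred = model.predict(seg_train_imgs[[0]])

pred_mask = np.argmax(pred[0], axis = -1).astype(bool)    # predicted class-1 region
true_mask = seg_train_labels[0][:,:,1].astype(bool)       # ground-truth class-1 region

iou = np.logical_and(pred_mask, true_mask).sum() / np.logical_or(pred_mask, true_mask).sum()
print('IoU: {:.3f}'.format(iou))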