Transfer Learning


By Prof. Seungchul Lee
http://iai.postech.ac.kr/
Industrial AI Lab at POSTECH


1. Pre-trained Model (VGG16)

  • Training a model on ImageNet from scratch takes days or weeks.
  • Many models trained on ImageNet and their weights are publicly available!
  • Transfer learning
    • Use pre-trained weights, remove last layers to compute representations of images
    • The network is used as a generic feature extractor (see the sketch after this list)
    • Train a classification model from these features on a new classification task
    • Pre-trained models can extract general image features that help identify edges, textures, shapes, and object composition
    • Better than handcrafted feature extraction on natural images
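As an illustration of the feature-extractor idea, the sketch below pools VGG16's convolutional features and trains a simple classifier on top. The file paths mirror the ones used later in this notebook; the use of scikit-learn's LogisticRegression is an assumption for illustration only, not part of this lecture's pipeline.

import numpy as np
from keras.applications.vgg16 import VGG16, preprocess_input
from sklearn.linear_model import LogisticRegression

# headless VGG16; global average pooling yields one 512-d feature vector per image
base = VGG16(weights = 'imagenet', include_top = False, pooling = 'avg')

imgs = np.load('./data_files/target_images.npy').astype('float32')
labels = np.load('./data_files/target_labels.npy')

features = base.predict(preprocess_input(imgs))       # shape: (n, 512)
clf = LogisticRegression(max_iter = 1000)
clf.fit(features, np.argmax(labels, axis = 1))        # train only the small classifier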






1.1. Import Library

In [1]:
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
from keras.applications.vgg16 import VGG16
from keras.applications.vgg16 import decode_predictions
Using TensorFlow backend.

1.2. Load Data

Download data files from here

In [2]:
train_imgs = np.load('./data_files/target_images.npy')
train_labels = np.load('./data_files/target_labels.npy')
test_imgs = np.load('./data_files/test_images.npy')
test_labels = np.load('./data_files/test_labels.npy')

print(train_imgs.shape)
print(train_labels[0]) # one-hot-encoded 5 classes 
(65, 224, 224, 3)
[1. 0. 0. 0. 0.]
In [3]:
Dict = ['Hat','Cube','Card','Torch','screw']
In [4]:
n_train = train_imgs.shape[0]
n_test = test_imgs.shape[0]
idx = np.random.randint(n_train)

plt.figure(figsize = (6,6))
plt.imshow(train_imgs[idx])
plt.title("Label : {}".format(Dict[np.argmax(train_labels[idx])]))
plt.axis('off')
plt.show()
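One caveat worth noting: keras.applications.vgg16 also provides preprocess_input, which converts RGB images to the BGR, mean-subtracted format VGG16 was trained on. Whether the saved .npy arrays were already preprocessed is not stated, so the call below is a hedged sketch for the case where they hold raw RGB pixel values.

from keras.applications.vgg16 import preprocess_input

# flips RGB -> BGR and subtracts the ImageNet channel means;
# work on a float copy so the original array is left untouched
proc_imgs = preprocess_input(train_imgs.astype('float32').copy())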

1.3. Load VGG16 Model

In [5]:
model = VGG16(weights = 'imagenet')

model.summary()
Downloading data from https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg16_weights_tf_dim_ordering_tf_kernels.h5
553467904/553467096 [==============================] - 688s 1us/step
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         (None, 224, 224, 3)       0         
_________________________________________________________________
block1_conv1 (Conv2D)        (None, 224, 224, 64)      1792      
_________________________________________________________________
block1_conv2 (Conv2D)        (None, 224, 224, 64)      36928     
_________________________________________________________________
block1_pool (MaxPooling2D)   (None, 112, 112, 64)      0         
_________________________________________________________________
block2_conv1 (Conv2D)        (None, 112, 112, 128)     73856     
_________________________________________________________________
block2_conv2 (Conv2D)        (None, 112, 112, 128)     147584    
_________________________________________________________________
block2_pool (MaxPooling2D)   (None, 56, 56, 128)       0         
_________________________________________________________________
block3_conv1 (Conv2D)        (None, 56, 56, 256)       295168    
_________________________________________________________________
block3_conv2 (Conv2D)        (None, 56, 56, 256)       590080    
_________________________________________________________________
block3_conv3 (Conv2D)        (None, 56, 56, 256)       590080    
_________________________________________________________________
block3_pool (MaxPooling2D)   (None, 28, 28, 256)       0         
_________________________________________________________________
block4_conv1 (Conv2D)        (None, 28, 28, 512)       1180160   
_________________________________________________________________
block4_conv2 (Conv2D)        (None, 28, 28, 512)       2359808   
_________________________________________________________________
block4_conv3 (Conv2D)        (None, 28, 28, 512)       2359808   
_________________________________________________________________
block4_pool (MaxPooling2D)   (None, 14, 14, 512)       0         
_________________________________________________________________
block5_conv1 (Conv2D)        (None, 14, 14, 512)       2359808   
_________________________________________________________________
block5_conv2 (Conv2D)        (None, 14, 14, 512)       2359808   
_________________________________________________________________
block5_conv3 (Conv2D)        (None, 14, 14, 512)       2359808   
_________________________________________________________________
block5_pool (MaxPooling2D)   (None, 7, 7, 512)         0         
_________________________________________________________________
flatten (Flatten)            (None, 25088)             0         
_________________________________________________________________
fc1 (Dense)                  (None, 4096)              102764544 
_________________________________________________________________
fc2 (Dense)                  (None, 4096)              16781312  
_________________________________________________________________
predictions (Dense)          (None, 1000)              4097000   
=================================================================
Total params: 138,357,544
Trainable params: 138,357,544
Non-trainable params: 0
_________________________________________________________________
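The parameter counts in the summary can be verified by hand: a Conv2D layer with a k x k kernel, c_in input channels, and c_out filters has k*k*c_in*c_out + c_out parameters, and a Dense layer has n_in*n_out + n_out. A quick check:

print(3*3*3*64 + 64)          # block1_conv1: 1792
print(7*7*512*4096 + 4096)    # fc1: 102764544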

1.4. Testing for Target Data

Note that the 1000 ImageNet classes do not include our five target objects, so the pre-trained classifier should be uncertain on these images. The low, scattered top-5 probabilities below confirm this and motivate transfer learning.

In [6]:
idx = np.random.randint(n_test)
pred = model.predict(test_imgs[idx].reshape(-1, 224, 224, 3))

for (i, (_, label, prob)) in enumerate(decode_predictions(pred)[0]):
    print("{}. {}: {:.2f}%".format(i + 1, label, prob*100))
    
plt.figure(figsize = (6,6))
plt.imshow(test_imgs[idx])
plt.title("Label : {}".format(Dict[np.argmax(test_labels[idx])]))
plt.axis('off')
plt.show()    
1. mosquito_net: 2.74%
2. toilet_tissue: 2.53%
3. envelope: 1.95%
4. Band_Aid: 1.40%
5. shower_curtain: 1.24%

2. Transfer Learning

  • We assume that the model parameters encode the knowledge learned from the source data set and that this knowledge is also applicable to the target data set.
  • Here we train the output layer from scratch, while the parameters of all remaining layers are kept frozen at the source model's values.
  • Alternatively, initialize all weights from the pre-trained model, then fine-tune all of them on the target data (a Keras-style sketch of the frozen-base approach follows below).
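For comparison, the frozen-base strategy implemented below can be written compactly with the Keras functional API. This is a minimal sketch, not the approach used in the remainder of this notebook (which wires the graph by hand in TensorFlow 1.x); the layer name 'fc2' comes from the summary above, while the optimizer and training settings are illustrative assumptions.

from keras.models import Model
from keras.layers import Dense

base = VGG16(weights = 'imagenet')          # full VGG16 with fc1/fc2
for layer in base.layers:
    layer.trainable = False                 # freeze every pre-trained layer

# attach a fresh 5-way softmax head on top of fc2
fc2 = base.get_layer('fc2').output
out = Dense(5, activation = 'softmax', name = 'new_output')(fc2)

tl_model = Model(inputs = base.input, outputs = out)
tl_model.compile(optimizer = 'adam', loss = 'categorical_crossentropy',
                 metrics = ['accuracy'])
# tl_model.fit(train_imgs, train_labels, batch_size = 20, epochs = 10)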









2.1. Pre-trained Weights and Biases

In [7]:
vgg16_weights = model.get_weights()
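# get_weights() returns the parameters layer by layer in [kernel, bias, kernel, bias, ...]
# order, so the even indices below are kernels and the odd indices are biases.
# Wrapping them in tf.constant freezes them; only the 'output' tf.Variable is trained.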

weights = {
    'conv1_1' : tf.constant(vgg16_weights[0]),
    'conv1_2' : tf.constant(vgg16_weights[2]),
    
    'conv2_1' : tf.constant(vgg16_weights[4]),
    'conv2_2' : tf.constant(vgg16_weights[6]),
    
    'conv3_1' : tf.constant(vgg16_weights[8]),
    'conv3_2' : tf.constant(vgg16_weights[10]),
    'conv3_3' : tf.constant(vgg16_weights[12]),
    
    'conv4_1' : tf.constant(vgg16_weights[14]),
    'conv4_2' : tf.constant(vgg16_weights[16]),
    'conv4_3' : tf.constant(vgg16_weights[18]),
    
    'conv5_1' : tf.constant(vgg16_weights[20]),
    'conv5_2' : tf.constant(vgg16_weights[22]),
    'conv5_3' : tf.constant(vgg16_weights[24]),
    
    'fc1' : tf.constant(vgg16_weights[26]),
    'fc2' : tf.constant(vgg16_weights[28]),
    
    # train from scratch
    'output' : tf.Variable(tf.random_normal([4096, 5], stddev = 0.1))
}

biases = {
    'conv1_1' : tf.constant(vgg16_weights[1]),
    'conv1_2' : tf.constant(vgg16_weights[3]),
    
    'conv2_1' : tf.constant(vgg16_weights[5]),
    'conv2_2' : tf.constant(vgg16_weights[7]),
    
    'conv3_1' : tf.constant(vgg16_weights[9]),
    'conv3_2' : tf.constant(vgg16_weights[11]),
    'conv3_3' : tf.constant(vgg16_weights[13]),
    
    'conv4_1' : tf.constant(vgg16_weights[15]),
    'conv4_2' : tf.constant(vgg16_weights[17]),
    'conv4_3' : tf.constant(vgg16_weights[19]),
    
    'conv5_1' : tf.constant(vgg16_weights[21]),
    'conv5_2' : tf.constant(vgg16_weights[23]),
    'conv5_3' : tf.constant(vgg16_weights[25]),
    
    'fc1' : tf.constant(vgg16_weights[27]),
    'fc2' : tf.constant(vgg16_weights[29]),
    
    # train from scratch
    'output' : tf.Variable(tf.random_normal([5], stddev = 0.1))
}
In [8]:
x = tf.placeholder(tf.float32, [None, 224, 224, 3])
y = tf.placeholder(tf.float32, [None, 5])

2.2. Build a Transfer Learning Model

Because every layer except 'output' is wrapped in tf.constant, gradients update only the new output layer; the pre-trained network acts as a fixed feature extractor.

In [9]:
def transfer(x, weights, biases):
    # First convolution layer
    conv1_1 = tf.nn.conv2d(x, 
                         weights['conv1_1'], 
                         strides = [1, 1, 1, 1], 
                         padding = 'SAME')
    conv1_1 = tf.nn.relu(tf.add(conv1_1, biases['conv1_1']))
    conv1_2 = tf.nn.conv2d(conv1_1, 
                         weights['conv1_2'], 
                         strides = [1, 1, 1, 1], 
                         padding = 'SAME')
    conv1_2 = tf.nn.relu(tf.add(conv1_2, biases['conv1_2']))
    maxp1 = tf.nn.max_pool(conv1_2, 
                           ksize = [1, 2, 2, 1], 
                           strides = [1, 2, 2, 1], 
                           padding = 'VALID')

    # Second convolution layer
    conv2_1 = tf.nn.conv2d(maxp1, 
                         weights['conv2_1'], 
                         strides = [1, 1, 1, 1], 
                         padding = 'SAME')
    conv2_1 = tf.nn.relu(tf.add(conv2_1, biases['conv2_1']))
    conv2_2 = tf.nn.conv2d(conv2_1, 
                         weights['conv2_2'], 
                         strides = [1, 1, 1, 1], 
                         padding = 'SAME')
    conv2_2 = tf.nn.relu(tf.add(conv2_2, biases['conv2_2']))
    maxp2 = tf.nn.max_pool(conv2_2, 
                           ksize = [1, 2, 2, 1], 
                           strides = [1, 2, 2, 1], 
                           padding = 'VALID')

    # Third convolution layer
    conv3_1 = tf.nn.conv2d(maxp2, 
                         weights['conv3_1'], 
                         strides = [1, 1, 1, 1], 
                         padding = 'SAME')
    conv3_1 = tf.nn.relu(tf.add(conv3_1, biases['conv3_1']))
    conv3_2 = tf.nn.conv2d(conv3_1, 
                         weights['conv3_2'], 
                         strides = [1, 1, 1, 1], 
                         padding = 'SAME')
    conv3_2 = tf.nn.relu(tf.add(conv3_2, biases['conv3_2']))
    conv3_3 = tf.nn.conv2d(conv3_2, 
                         weights['conv3_3'], 
                         strides = [1, 1, 1, 1], 
                         padding = 'SAME')
    conv3_3 = tf.nn.relu(tf.add(conv3_3, biases['conv3_3']))
    maxp3 = tf.nn.max_pool(conv3_3, 
                           ksize = [1, 2, 2, 1], 
                           strides = [1, 2, 2, 1], 
                           padding = 'VALID')

    # Fourth convolution layer
    conv4_1 = tf.nn.conv2d(maxp3, 
                         weights['conv4_1'], 
                         strides = [1, 1, 1, 1], 
                         padding = 'SAME')
    conv4_1 = tf.nn.relu(tf.add(conv4_1, biases['conv4_1']))
    conv4_2 = tf.nn.conv2d(conv4_1, 
                         weights['conv4_2'], 
                         strides = [1, 1, 1, 1], 
                         padding = 'SAME')
    conv4_2 = tf.nn.relu(tf.add(conv4_2, biases['conv4_2']))
    conv4_3 = tf.nn.conv2d(conv4_2, 
                         weights['conv4_3'], 
                         strides = [1, 1, 1, 1], 
                         padding = 'SAME')
    conv4_3 = tf.nn.relu(tf.add(conv4_3, biases['conv4_3']))
    maxp4 = tf.nn.max_pool(conv4_3, 
                           ksize = [1, 2, 2, 1], 
                           strides = [1, 2, 2, 1], 
                           padding = 'VALID')

    # Fifth convolution layer
    conv5_1 = tf.nn.conv2d(maxp4, 
                         weights['conv5_1'], 
                         strides = [1, 1, 1, 1], 
                         padding = 'SAME')
    conv5_1 = tf.nn.relu(tf.add(conv5_1, biases['conv5_1']))
    conv5_2 = tf.nn.conv2d(conv5_1, 
                         weights['conv5_2'], 
                         strides = [1, 1, 1, 1], 
                         padding = 'SAME')
    conv5_2 = tf.nn.relu(tf.add(conv5_2, biases['conv5_2']))
    conv5_3 = tf.nn.conv2d(conv5_2, 
                         weights['conv5_3'], 
                         strides = [1, 1, 1, 1], 
                         padding = 'SAME')
    conv5_3 = tf.nn.relu(tf.add(conv5_3, biases['conv5_3']))
    maxp5 = tf.nn.max_pool(conv5_3, 
                           ksize = [1, 2, 2, 1], 
                           strides = [1, 2, 2, 1], 
                           padding = 'VALID')
    
    maxp5 = tf.reshape(maxp5, [-1, 7*7*512])

    # Fully connected layers
    fc1 = tf.add(tf.matmul(maxp5, weights['fc1']), biases['fc1'])
    fc1 = tf.nn.relu(fc1)
    
    fc2 = tf.add(tf.matmul(fc1, weights['fc2']), biases['fc2'])
    fc2 = tf.nn.relu(fc2)
    
    # our output layer for a new classification task
    output = tf.add(tf.matmul(fc2, weights['output']), biases['output'])
    
    return output

2.3. Define Loss and Optimizer

tf.nn.softmax_cross_entropy_with_logits applies the softmax internally, which is why transfer() returns raw logits without a softmax on the output layer.

In [10]:
LR  = 0.001

pred = transfer(x, weights, biases)
loss = tf.nn.softmax_cross_entropy_with_logits(logits = pred, labels = y)
loss = tf.reduce_mean(loss)

optm  = tf.train.AdamOptimizer(LR).minimize(loss)
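As a quick sanity check (with illustrative numbers only), the loss this op computes for a single example is the cross entropy -sum_i y_i * log(softmax(z)_i), which is easy to reproduce in NumPy:

z = np.array([2.0, 1.0, 0.1, 0.0, -1.0])    # raw logits for one example
y_true = np.array([1., 0., 0., 0., 0.])     # one-hot label

p = np.exp(z) / np.sum(np.exp(z))           # softmax
print(-np.sum(y_true * np.log(p)))          # cross entropy for this example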

2.4. Optimize

Note that each "epoch" in the loop below performs a single update on one random mini-batch of 20 images, i.e., it is an iteration rather than a full pass over the 65 training images.

In [11]:
def train_batch_maker(batch_size):
    random_idx = np.random.randint(n_train, size = batch_size)
    return train_imgs[random_idx], train_labels[random_idx]
In [12]:
def test_batch_maker(batch_size):
    random_idx = np.random.randint(n_test, size = batch_size)
    return test_imgs[random_idx], test_labels[random_idx]
In [13]:
n_batch = 20
n_epoch = 300
n_prt = 30

sess = tf.Session()
sess.run(tf.global_variables_initializer())

loss_record_train = []
for epoch in range(n_epoch):
    train_x, train_y = train_batch_maker(n_batch)
    sess.run(optm, feed_dict = {x: train_x, y: train_y})

    if epoch % n_prt == 0:        
        c = sess.run(loss, feed_dict = {x: train_x, y: train_y})
        loss_record_train.append(c)
        print ("Epoch : {}".format(epoch))
        print ("Cost : {}".format(c))
        
plt.figure(figsize = (10,8))
plt.plot(np.arange(len(loss_record_train))*n_prt, loss_record_train, label = 'train')
plt.xlabel('epoch', fontsize = 15)
plt.ylabel('loss', fontsize = 15)
plt.legend(fontsize = 12)
plt.ylim([0, np.max(loss_record_train)])
plt.show()
Epoch : 0
Cost : 3.2474300861358643
Epoch : 30
Cost : 0.8468533754348755
Epoch : 60
Cost : 0.4728967249393463
Epoch : 90
Cost : 0.3230406641960144
Epoch : 120
Cost : 0.20700104534626007
Epoch : 150
Cost : 0.1553143858909607
Epoch : 180
Cost : 0.10301370918750763
Epoch : 210
Cost : 0.07996686547994614
Epoch : 240
Cost : 0.08920763432979584
Epoch : 270
Cost : 0.05560795217752457

2.5. Test and Evaluate

In [14]:
my_pred = sess.run(pred, feed_dict = {x: test_imgs})
my_pred = np.argmax(my_pred, axis = 1)

labels = np.argmax(test_labels, axis = 1)

accr = np.mean(np.equal(my_pred, labels))
print("Accuracy : {}%".format(accr*100))
Accuracy : 100.0%
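Since the test set is small, a per-class breakdown (reusing my_pred and labels from the cell above) makes it easier to see that each of the five classes is actually covered; this check is an illustrative addition, not part of the original notebook.

for c, name in enumerate(Dict):
    mask = (labels == c)
    if mask.sum() > 0:
        print("{} : {}/{} correct".format(name, np.sum(my_pred[mask] == c), mask.sum()))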
In [15]:
test_x, test_y = test_batch_maker(1)
probs = sess.run(tf.nn.softmax(pred), feed_dict = {x: test_x.reshape(-1, 224, 224, 3)})
predict = np.argmax(probs)

plt.figure(figsize = (6,6))
plt.imshow(test_x.reshape(224, 224, 3))
plt.axis('off')
plt.show()

np.set_printoptions(precision = 2, suppress = True)
print('Prediction : {}'.format(Dict[predict]))
print('Probability : {}'.format(probs.ravel()))
Prediction : Hat
Probability : [0.97 0.01 0.   0.01 0.  ]