Pre-trained Models


By Prof. Hyunseok Oh
https://sddo.gist.ac.kr/
SDDO Lab at GIST

1. Import libraries

In [9]:
import tensorflow as tf
from tensorflow.keras.applications.vgg16 import VGG16
from tensorflow.keras import models
from tensorflow.keras import layers
from tensorflow.keras import optimizers
import matplotlib.pyplot as plt
import numpy as np
from tensorflow.keras.preprocessing.image import img_to_array, array_to_img
from tensorflow.keras.applications.vgg16 import decode_predictions

2. Load data

In [10]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

x_train = x_train[:200].reshape(-1,28,28)
x_train = np.stack([x_train] * 3, axis=3)
train_X = np.asarray([img_to_array(array_to_img(im, scale=False).resize((32,32))) for im in x_train])
y_train = y_train[:200]
train_Y = tf.one_hot(y_train, 10, on_value=1.0, off_value=0.0)

x_test = x_test[:50].reshape(-1,28,28)
x_test = np.stack([x_test] * 3, axis=3)
test_X = np.asarray([img_to_array(array_to_img(im, scale=False).resize((32,32))) for im in x_test])
y_test = y_test[:50]
test_Y = tf.one_hot(y_test, 10, on_value=1.0, off_value=0.0)

x_show = x_test[:3]
x_show = np.asarray([img_to_array(array_to_img(im, scale=False).resize((224,224))) for im in x_show]).astype(int)
y_show = tf.one_hot(y_test[:3], 1000, on_value=1.0, off_value=0.0)

3. Using the pre-trained model as-is

In [11]:
model = VGG16()
pred = decode_predictions(model.predict(x_show), top=1)
for i in range(3):
    print("Predicted : ", pred[i][0][1])
    print("Label : ", np.argmax(y_show[i]))
    plt.imshow(x_show[i])
    plt.show()
Predicted :  nematode
Label :  7
Predicted :  nematode
Label :  2
Predicted :  nail
Label :  1
  • A nematode is a roundworm; a nail is a metal spike.
  • The predictions fail because ImageNet's 1,000 classes contain no handwritten digits: the unmodified VGG16 can only map each image to a superficially similar class.
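Note also that VGG16 was trained with a specific input normalization, while the cell above calls `model.predict(x_show)` on raw pixel values. Keras ships a matching `preprocess_input` that flips RGB to BGR and subtracts the ImageNet channel means; it would normally be applied before prediction. A sketch on dummy data shows what the transform does:

```python
import numpy as np
from tensorflow.keras.applications.vgg16 import preprocess_input

# On an all-zero image the result is just the negated BGR channel means,
# which makes the transform visible: RGB->BGR flip plus mean subtraction.
dummy = np.zeros((1, 224, 224, 3), dtype=np.float32)
out = preprocess_input(dummy.copy())
print(out[0, 0, 0])  # approximately [-103.94, -116.78, -123.68]
```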

4. Fine-tuning procedure

4.1 Feature extraction

(1) Add a new classifier on top of the pre-trained model
(2) Freeze the pre-trained network
(3) Train the newly added classifier

In [12]:
# Resize to 32x32 (the smallest input VGG16 accepts) and one-hot encode the labels
x_train = np.asarray([img_to_array(array_to_img(im, scale=False).resize((32,32))) for im in x_train])
x_test = np.asarray([img_to_array(array_to_img(im, scale=False).resize((32,32))) for im in x_test])
y_train = tf.one_hot(y_train, 10, on_value=1.0, off_value=0.0)
y_test = tf.one_hot(y_test, 10, on_value=1.0, off_value=0.0)

conv_base = VGG16(weights='imagenet',
                  include_top=False,
                  input_shape=(32, 32, 3))

model = models.Sequential()
model.add(conv_base)
model.add(layers.Flatten())
model.add(layers.Dense(256, activation='relu'))
model.add(layers.Dense(10, activation='softmax'))

conv_base.trainable = False

model.compile(loss='categorical_crossentropy',
              optimizer=optimizers.Adam(learning_rate=2e-5),
              metrics=['acc'])

history = model.fit(
      x_train,
      y_train,
      epochs=50,
      validation_data=(x_test, y_test),
      verbose=2)

plt.plot(range(len(history.history['acc'])), history.history['acc'], label='Training acc')
plt.plot(range(len(history.history['acc'])), history.history['val_acc'], label='Validation acc')
plt.title('Training and validation accuracy')
plt.legend()

plt.figure()

plt.plot(range(len(history.history['acc'])), history.history['loss'], label='Training loss')
plt.plot(range(len(history.history['acc'])), history.history['val_loss'], label='Validation loss')
plt.title('Training and validation loss')
plt.legend()
plt.show()
Epoch 1/50
7/7 - 4s - loss: 30.7446 - acc: 0.0650 - val_loss: 28.1439 - val_acc: 0.0800 - 4s/epoch - 539ms/step
Epoch 2/50
7/7 - 3s - loss: 29.4302 - acc: 0.0650 - val_loss: 26.8792 - val_acc: 0.0800 - 3s/epoch - 401ms/step
Epoch 3/50
7/7 - 3s - loss: 28.2108 - acc: 0.0650 - val_loss: 25.6510 - val_acc: 0.0800 - 3s/epoch - 397ms/step
(epochs 4-47 omitted; loss falls steadily while val_acc climbs from 0.08 to 0.32)
Epoch 48/50
7/7 - 3s - loss: 4.7641 - acc: 0.4300 - val_loss: 6.1032 - val_acc: 0.3200 - 3s/epoch - 402ms/step
Epoch 49/50
7/7 - 3s - loss: 4.5883 - acc: 0.4450 - val_loss: 5.9615 - val_acc: 0.3200 - 3s/epoch - 399ms/step
Epoch 50/50
7/7 - 3s - loss: 4.4095 - acc: 0.4550 - val_loss: 5.8371 - val_acc: 0.3200 - 3s/epoch - 396ms/step
In [13]:
pred = model.predict(x_test[:3])
for i in range(3):
    print("Predicted : ", np.argmax(pred[i]))
    print("Label : ", np.argmax(y_test[i]))
    plt.imshow(x_test[i].astype(int))
    plt.show()
WARNING:tensorflow:5 out of the last 5 calls to <function Model.make_predict_function.<locals>.predict_function at 0x7fd01ee6def0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
Predicted :  7
Label :  7
Predicted :  3
Label :  2
Predicted :  1
Label :  1

4.2 Fine-tuning

(4) Unfreeze some of the top layers of the pre-trained network
(5) Train the added classifier together with the unfrozen layers

In [14]:
model.trainable = True  # unfreeze everything first

# Re-freeze all conv_base layers below block5_conv1, leaving block5 trainable
set_trainable = False
for layer in conv_base.layers:
    if layer.name == 'block5_conv1':
        set_trainable = True
    if set_trainable:
        layer.trainable = True
    else:
        layer.trainable = False
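A quick sanity check on this unfreezing pattern. This standalone sketch rebuilds the architecture with `weights=None` (so nothing is downloaded) and verifies that only the `block5_*` layers end up trainable:

```python
from tensorflow.keras.applications.vgg16 import VGG16

conv_base = VGG16(weights=None, include_top=False, input_shape=(32, 32, 3))

# Same loop as above: freeze everything below block5_conv1.
set_trainable = False
for layer in conv_base.layers:
    if layer.name == 'block5_conv1':
        set_trainable = True
    layer.trainable = set_trainable

trainable_names = [l.name for l in conv_base.layers if l.trainable]
print(trainable_names)  # ['block5_conv1', 'block5_conv2', 'block5_conv3', 'block5_pool']
```

(The pooling layer has no weights, so its flag is harmless; only the three block5 convolutions actually receive gradient updates.)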
In [15]:
model.compile(loss='categorical_crossentropy',
              optimizer=optimizers.Adam(learning_rate=2e-5),
              metrics=['acc'])

history = model.fit(
      x_train,
      y_train,
      epochs=10,
      validation_data=(x_test, y_test),
      verbose=2)

plt.plot(range(len(history.history['acc'])), history.history['acc'], label='Training acc')
plt.plot(range(len(history.history['acc'])), history.history['val_acc'], label='Validation acc')
plt.title('Training and validation accuracy')
plt.legend()

plt.figure()
plt.plot(range(len(history.history['acc'])), history.history['loss'], label='Training loss')
plt.plot(range(len(history.history['acc'])), history.history['val_loss'], label='Validation loss')
plt.title('Training and validation loss')
plt.legend()
plt.show()
Epoch 1/10
7/7 - 12s - loss: 3.4961 - acc: 0.4800 - val_loss: 3.4754 - val_acc: 0.4600 - 12s/epoch - 2s/step
Epoch 2/10
7/7 - 11s - loss: 1.0802 - acc: 0.7700 - val_loss: 2.1089 - val_acc: 0.5400 - 11s/epoch - 2s/step
Epoch 3/10
7/7 - 11s - loss: 0.4703 - acc: 0.8850 - val_loss: 1.8127 - val_acc: 0.6200 - 11s/epoch - 2s/step
Epoch 4/10
7/7 - 11s - loss: 0.1841 - acc: 0.9500 - val_loss: 1.5410 - val_acc: 0.6800 - 11s/epoch - 2s/step
Epoch 5/10
7/7 - 11s - loss: 0.0433 - acc: 0.9950 - val_loss: 1.5427 - val_acc: 0.7400 - 11s/epoch - 2s/step
Epoch 6/10
7/7 - 11s - loss: 0.0214 - acc: 1.0000 - val_loss: 1.5450 - val_acc: 0.7200 - 11s/epoch - 2s/step
Epoch 7/10
7/7 - 11s - loss: 0.0091 - acc: 1.0000 - val_loss: 1.5660 - val_acc: 0.7000 - 11s/epoch - 2s/step
Epoch 8/10
7/7 - 11s - loss: 0.0050 - acc: 1.0000 - val_loss: 1.5974 - val_acc: 0.7000 - 11s/epoch - 2s/step
Epoch 9/10
7/7 - 11s - loss: 0.0034 - acc: 1.0000 - val_loss: 1.6065 - val_acc: 0.7000 - 11s/epoch - 2s/step
Epoch 10/10
7/7 - 11s - loss: 0.0025 - acc: 1.0000 - val_loss: 1.5962 - val_acc: 0.7000 - 11s/epoch - 2s/step
In [16]:
pred = model.predict(x_test[:3])
for i in range(3):
    print("Predicted : ", np.argmax(pred[i]))
    print("Label : ", np.argmax(y_test[i]))
    plt.imshow(x_test[i].astype(int))
    plt.show()
Predicted :  7
Label :  7
Predicted :  2
Label :  2
Predicted :  1
Label :  1
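To go beyond eyeballing three images, overall test accuracy can be computed by comparing predicted and true class indices. The snippet below uses hypothetical stand-in arrays so it is self-contained; in the notebook, `pred` would come from `model.predict(x_test)` and `true` would be the one-hot `y_test`:

```python
import numpy as np

# Stand-ins for model.predict(x_test) and one-hot y_test (illustrative values).
pred = np.array([[0.1, 0.7, 0.2], [0.8, 0.1, 0.1], [0.2, 0.2, 0.6]])
true = np.array([[0, 1, 0], [1, 0, 0], [0, 0, 1]])

# Fraction of samples whose highest-scoring class matches the true class.
acc = np.mean(np.argmax(pred, axis=1) == np.argmax(true, axis=1))
print(acc)  # 1.0
```

Equivalently, since the model was compiled with `metrics=['acc']`, `model.evaluate(x_test, y_test, verbose=0)` reports the loss and this accuracy directly.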