Self-supervised Learning

By Prof. Hyunseok Oh
https://sddo.gist.ac.kr/
SDDO Lab at GIST

1. Import library¶

In [ ]:
import tensorflow as tf
import numpy as np

from tensorflow import keras
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense

import matplotlib.pyplot as plt


In [ ]:
#Load MNIST dataset
(X_train, Y_train), (X_test, Y_test) = keras.datasets.mnist.load_data()

# using 1000 data for test
X_train=X_train[:1000]
Y_train=Y_train[:1000]

# using 300 data for test
X_test=X_test[:300]
Y_test=Y_test[:300]

print(X_train.shape)

(1000, 28, 28)


In [ ]:
n_samples = X_train.shape[0]

X_rotate= np.zeros(shape=(n_samples*4,X_train.shape[1],X_train.shape[2]))
Y_rotate =np.zeros(shape=(n_samples*4,4))

for i in range(n_samples):

img =X_train[i]

X_rotate[4*i-4]=img
Y_rotate[4*i-4]=[1,0,0,0]

# 90degree rotation: Transpose + Vertical flip
X_rotate[4*i-3]=np.flip(img.T,0)
Y_rotate[4*i-3]=[0,1,0,0]

# 180degree rotation: Vertical flip + Horizontal flip
X_rotate[4*i-2]=np.flip(np.flip(img,0),1)
Y_rotate[4*i-2]=[0,0,1,0]

# 270degree rotation: Vertical flip + Transpose
X_rotate[4*i-1]=np.flip(img,0).T
Y_rotate[4*i-1]=[0,0,0,1]


3.2 Check pretext task data set¶

In [ ]:
plt.subplots_adjust(left=None, bottom=None, right=None, top=None, wspace=1, hspace=None)
plt.subplots(figsize=(10, 10))

plt.subplot(141)
plt.imshow(X_rotate[4], cmap = 'gray')
plt.title('label : ' + str(Y_rotate[4]))
plt.axis('off')

plt.subplot(142)
plt.imshow(X_rotate[5], cmap = 'gray')
plt.title('label : ' + str(Y_rotate[5]))
plt.axis('off')

plt.subplot(143)
plt.imshow(X_rotate[6], cmap = 'gray')
plt.title('label : ' + str(Y_rotate[6]))
plt.axis('off')

plt.subplot(144)
plt.imshow(X_rotate[7], cmap = 'gray')
plt.title('label : ' +str(Y_rotate[7]))
plt.axis('off')

Out[ ]:
(-0.5, 27.5, 27.5, -0.5)
<Figure size 432x288 with 0 Axes>
In [ ]:
# reshape MNIST image data for convolution
print(X_rotate[0].shape)
X_rotate =X_rotate.reshape(-1,28,28,1)

(28, 28)


3.3 Build model¶

In [ ]:
layer1 = Conv2D(64, kernel_size=(3, 3), strides=(2, 2), padding='same',
activation='relu',kernel_initializer='random_normal')
layer2 = MaxPooling2D(pool_size=(2, 2), strides=(2, 2))
layer3 = Conv2D(32, kernel_size=(3, 3), strides=(1, 1), padding='same',
activation='relu' ,kernel_initializer='random_normal')
layer4 = MaxPooling2D(pool_size=(2, 2), strides=(2, 2))
layer5 = Conv2D(16, kernel_size=(3, 3), strides=(2, 2), padding='same',
activation='relu',kernel_initializer='random_normal')

layer6 = Flatten()
layer7= Dense(4, activation='softmax',kernel_initializer='random_normal')

model_pre = Sequential([keras.Input(shape=(28,28,1)),
layer1, layer2, layer3,layer4,layer5,
layer6,layer7])
model_pre.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                Output Shape              Param #
=================================================================
conv2d (Conv2D)             (None, 14, 14, 64)        640

max_pooling2d (MaxPooling2D  (None, 7, 7, 64)         0
)

conv2d_1 (Conv2D)           (None, 7, 7, 32)          18464

max_pooling2d_1 (MaxPooling  (None, 3, 3, 32)         0
2D)

conv2d_2 (Conv2D)           (None, 2, 2, 16)          4624

flatten (Flatten)           (None, 64)                0

dense (Dense)               (None, 4)                 260

=================================================================
Total params: 23,988
Trainable params: 23,988
Non-trainable params: 0
_________________________________________________________________


3.4 Train pretext model¶

In [ ]:
sgd = keras.optimizers.SGD(learning_rate = 0.001,momentum = 0.9)
model_pre.compile(loss = 'categorical_crossentropy', optimizer = sgd, metrics = ['accuracy'])
hist= model_pre.fit(X_rotate, Y_rotate, batch_size = 192, epochs = 50,verbose = 2, shuffle=False)

Epoch 1/50
21/21 - 2s - loss: 1.6171 - accuracy: 0.3975 - 2s/epoch - 94ms/step
Epoch 2/50
21/21 - 1s - loss: 0.7771 - accuracy: 0.6865 - 1s/epoch - 67ms/step
Epoch 3/50
21/21 - 1s - loss: 0.5225 - accuracy: 0.8160 - 1s/epoch - 68ms/step
Epoch 4/50
21/21 - 1s - loss: 0.3773 - accuracy: 0.8755 - 1s/epoch - 67ms/step
Epoch 5/50
21/21 - 1s - loss: 0.2949 - accuracy: 0.9038 - 1s/epoch - 67ms/step
Epoch 6/50
21/21 - 1s - loss: 0.2431 - accuracy: 0.9195 - 1s/epoch - 69ms/step
Epoch 7/50
21/21 - 1s - loss: 0.2037 - accuracy: 0.9342 - 1s/epoch - 67ms/step
Epoch 8/50
21/21 - 1s - loss: 0.1764 - accuracy: 0.9415 - 1s/epoch - 66ms/step
Epoch 9/50
21/21 - 1s - loss: 0.1547 - accuracy: 0.9510 - 1s/epoch - 66ms/step
Epoch 10/50
21/21 - 1s - loss: 0.1376 - accuracy: 0.9545 - 1s/epoch - 66ms/step
Epoch 11/50
21/21 - 1s - loss: 0.1240 - accuracy: 0.9582 - 1s/epoch - 67ms/step
Epoch 12/50
21/21 - 1s - loss: 0.1131 - accuracy: 0.9635 - 1s/epoch - 66ms/step
Epoch 13/50
21/21 - 1s - loss: 0.1087 - accuracy: 0.9655 - 1s/epoch - 69ms/step
Epoch 14/50
21/21 - 1s - loss: 0.1150 - accuracy: 0.9582 - 1s/epoch - 67ms/step
Epoch 15/50
21/21 - 1s - loss: 0.1306 - accuracy: 0.9520 - 1s/epoch - 66ms/step
Epoch 16/50
21/21 - 1s - loss: 0.1857 - accuracy: 0.9295 - 1s/epoch - 67ms/step
Epoch 17/50
21/21 - 1s - loss: 0.1415 - accuracy: 0.9485 - 1s/epoch - 66ms/step
Epoch 18/50
21/21 - 1s - loss: 0.1681 - accuracy: 0.9413 - 1s/epoch - 66ms/step
Epoch 19/50
21/21 - 1s - loss: 0.1142 - accuracy: 0.9572 - 1s/epoch - 67ms/step
Epoch 20/50
21/21 - 1s - loss: 0.0865 - accuracy: 0.9693 - 1s/epoch - 67ms/step
Epoch 21/50
21/21 - 1s - loss: 0.0762 - accuracy: 0.9720 - 1s/epoch - 66ms/step
Epoch 22/50
21/21 - 1s - loss: 0.0954 - accuracy: 0.9625 - 1s/epoch - 67ms/step
Epoch 23/50
21/21 - 1s - loss: 0.1278 - accuracy: 0.9480 - 1s/epoch - 67ms/step
Epoch 24/50
21/21 - 1s - loss: 0.1229 - accuracy: 0.9535 - 1s/epoch - 66ms/step
Epoch 25/50
21/21 - 1s - loss: 0.0919 - accuracy: 0.9640 - 1s/epoch - 67ms/step
Epoch 26/50
21/21 - 1s - loss: 0.0570 - accuracy: 0.9810 - 1s/epoch - 67ms/step
Epoch 27/50
21/21 - 1s - loss: 0.0568 - accuracy: 0.9808 - 1s/epoch - 67ms/step
Epoch 28/50
21/21 - 1s - loss: 0.0608 - accuracy: 0.9795 - 1s/epoch - 68ms/step
Epoch 29/50
21/21 - 1s - loss: 0.0646 - accuracy: 0.9762 - 1s/epoch - 67ms/step
Epoch 30/50
21/21 - 1s - loss: 0.0729 - accuracy: 0.9712 - 1s/epoch - 68ms/step
Epoch 31/50
21/21 - 1s - loss: 0.0889 - accuracy: 0.9665 - 1s/epoch - 66ms/step
Epoch 32/50
21/21 - 1s - loss: 0.1087 - accuracy: 0.9620 - 1s/epoch - 66ms/step
Epoch 33/50
21/21 - 1s - loss: 0.1336 - accuracy: 0.9532 - 1s/epoch - 68ms/step
Epoch 34/50
21/21 - 1s - loss: 0.1815 - accuracy: 0.9330 - 1s/epoch - 67ms/step
Epoch 35/50
21/21 - 1s - loss: 0.2080 - accuracy: 0.9298 - 1s/epoch - 68ms/step
Epoch 36/50
21/21 - 1s - loss: 0.1263 - accuracy: 0.9510 - 1s/epoch - 66ms/step
Epoch 37/50
21/21 - 1s - loss: 0.0512 - accuracy: 0.9825 - 1s/epoch - 67ms/step
Epoch 38/50
21/21 - 1s - loss: 0.0305 - accuracy: 0.9920 - 1s/epoch - 66ms/step
Epoch 39/50
21/21 - 1s - loss: 0.0297 - accuracy: 0.9915 - 1s/epoch - 66ms/step
Epoch 40/50
21/21 - 1s - loss: 0.0276 - accuracy: 0.9935 - 1s/epoch - 66ms/step
Epoch 41/50
21/21 - 1s - loss: 0.0288 - accuracy: 0.9910 - 1s/epoch - 66ms/step
Epoch 42/50
21/21 - 1s - loss: 0.0321 - accuracy: 0.9902 - 1s/epoch - 65ms/step
Epoch 43/50
21/21 - 1s - loss: 0.0330 - accuracy: 0.9890 - 1s/epoch - 67ms/step
Epoch 44/50
21/21 - 1s - loss: 0.0341 - accuracy: 0.9877 - 1s/epoch - 67ms/step
Epoch 45/50
21/21 - 1s - loss: 0.0348 - accuracy: 0.9865 - 1s/epoch - 66ms/step
Epoch 46/50
21/21 - 1s - loss: 0.0333 - accuracy: 0.9883 - 1s/epoch - 66ms/step
Epoch 47/50
21/21 - 1s - loss: 0.0340 - accuracy: 0.9880 - 1s/epoch - 65ms/step
Epoch 48/50
21/21 - 1s - loss: 0.0562 - accuracy: 0.9780 - 1s/epoch - 66ms/step
Epoch 49/50
21/21 - 1s - loss: 0.0552 - accuracy: 0.9790 - 1s/epoch - 66ms/step
Epoch 50/50
21/21 - 1s - loss: 0.0630 - accuracy: 0.9747 - 1s/epoch - 67ms/step

In [ ]:
# Freeze the pretext model
model_pre.trainable=False


4.1 Reshape dataset¶

In [ ]:
X_train =X_train.reshape(-1,28,28,1)
X_test =X_test.reshape(-1,28,28,1)
Y_train = tf.one_hot(Y_train, 10,on_value=1.0, off_value=0.0)
Y_test = tf.one_hot(Y_test, 10,on_value=1.0, off_value=0.0)


4.2 Build model¶

In [ ]:
layer9=Flatten()
layer10 = Dense(10,activation = 'softmax',kernel_initializer='random_normal')
# new layer to classify 10 numbers

model_down = keras.Sequential([keras.Input(shape=(28,28,1)),
layer1, layer2, layer3,layer4,layer9,layer10])
model_down.summary()

Model: "sequential_1"
_________________________________________________________________
Layer (type)                Output Shape              Param #
=================================================================
conv2d (Conv2D)             (None, 14, 14, 64)        640

max_pooling2d (MaxPooling2D  (None, 7, 7, 64)         0
)

conv2d_1 (Conv2D)           (None, 7, 7, 32)          18464

max_pooling2d_1 (MaxPooling  (None, 3, 3, 32)         0
2D)

flatten_1 (Flatten)         (None, 288)               0

dense_1 (Dense)             (None, 10)                2890

=================================================================
Total params: 21,994
Trainable params: 2,890
Non-trainable params: 19,104
_________________________________________________________________


4.3 Train downstream model¶

In [ ]:
callback = tf.keras.callbacks.EarlyStopping(monitor='accuracy', patience=7)
#validation_split for training
split = 0.3

In [ ]:
model_down.compile(loss = 'categorical_crossentropy', optimizer = sgd, metrics = ['accuracy'])
hist_down= model_down.fit(X_train, Y_train, batch_size = 64, validation_split = split,
epochs = 50,verbose = 2, callbacks = callback)

Epoch 1/50
11/11 - 1s - loss: 145.5323 - accuracy: 0.1443 - val_loss: 95.5060 - val_accuracy: 0.1333 - 660ms/epoch - 60ms/step
Epoch 2/50
11/11 - 0s - loss: 112.1947 - accuracy: 0.3200 - val_loss: 108.7943 - val_accuracy: 0.1800 - 160ms/epoch - 15ms/step
Epoch 3/50
11/11 - 0s - loss: 55.6467 - accuracy: 0.5029 - val_loss: 63.7907 - val_accuracy: 0.4933 - 144ms/epoch - 13ms/step
Epoch 4/50
11/11 - 0s - loss: 27.0338 - accuracy: 0.6500 - val_loss: 26.3296 - val_accuracy: 0.6367 - 144ms/epoch - 13ms/step
Epoch 5/50
11/11 - 0s - loss: 17.3450 - accuracy: 0.7286 - val_loss: 11.5728 - val_accuracy: 0.7700 - 149ms/epoch - 14ms/step
Epoch 6/50
11/11 - 0s - loss: 7.8281 - accuracy: 0.8429 - val_loss: 9.8907 - val_accuracy: 0.8000 - 154ms/epoch - 14ms/step
Epoch 7/50
11/11 - 0s - loss: 6.1283 - accuracy: 0.8643 - val_loss: 9.0447 - val_accuracy: 0.8067 - 154ms/epoch - 14ms/step
Epoch 8/50
11/11 - 0s - loss: 4.1972 - accuracy: 0.8771 - val_loss: 7.6835 - val_accuracy: 0.8333 - 141ms/epoch - 13ms/step
Epoch 9/50
11/11 - 0s - loss: 4.1619 - accuracy: 0.8929 - val_loss: 8.4072 - val_accuracy: 0.7800 - 149ms/epoch - 14ms/step
Epoch 10/50
11/11 - 0s - loss: 4.2850 - accuracy: 0.8729 - val_loss: 6.6977 - val_accuracy: 0.8200 - 141ms/epoch - 13ms/step
Epoch 11/50
11/11 - 0s - loss: 4.5401 - accuracy: 0.8714 - val_loss: 8.8403 - val_accuracy: 0.8000 - 144ms/epoch - 13ms/step
Epoch 12/50
11/11 - 0s - loss: 3.5040 - accuracy: 0.9014 - val_loss: 8.3725 - val_accuracy: 0.7867 - 146ms/epoch - 13ms/step
Epoch 13/50
11/11 - 0s - loss: 2.9546 - accuracy: 0.9129 - val_loss: 5.3033 - val_accuracy: 0.8533 - 156ms/epoch - 14ms/step
Epoch 14/50
11/11 - 0s - loss: 1.6648 - accuracy: 0.9343 - val_loss: 7.1830 - val_accuracy: 0.8367 - 141ms/epoch - 13ms/step
Epoch 15/50
11/11 - 0s - loss: 1.6900 - accuracy: 0.9357 - val_loss: 5.0708 - val_accuracy: 0.8800 - 147ms/epoch - 13ms/step
Epoch 16/50
11/11 - 0s - loss: 1.2530 - accuracy: 0.9386 - val_loss: 5.0104 - val_accuracy: 0.8767 - 142ms/epoch - 13ms/step
Epoch 17/50
11/11 - 0s - loss: 1.0832 - accuracy: 0.9443 - val_loss: 5.9515 - val_accuracy: 0.8733 - 152ms/epoch - 14ms/step
Epoch 18/50
11/11 - 0s - loss: 1.2069 - accuracy: 0.9286 - val_loss: 6.4323 - val_accuracy: 0.8567 - 136ms/epoch - 12ms/step
Epoch 19/50
11/11 - 0s - loss: 1.6460 - accuracy: 0.9314 - val_loss: 5.5614 - val_accuracy: 0.8700 - 149ms/epoch - 14ms/step
Epoch 20/50
11/11 - 0s - loss: 1.3438 - accuracy: 0.9314 - val_loss: 6.0327 - val_accuracy: 0.8567 - 150ms/epoch - 14ms/step
Epoch 21/50
11/11 - 0s - loss: 1.0919 - accuracy: 0.9343 - val_loss: 4.9564 - val_accuracy: 0.8767 - 149ms/epoch - 14ms/step
Epoch 22/50
11/11 - 0s - loss: 0.8793 - accuracy: 0.9529 - val_loss: 5.2268 - val_accuracy: 0.8700 - 142ms/epoch - 13ms/step
Epoch 23/50
11/11 - 0s - loss: 0.7904 - accuracy: 0.9471 - val_loss: 5.1340 - val_accuracy: 0.8933 - 147ms/epoch - 13ms/step
Epoch 24/50
11/11 - 0s - loss: 0.4839 - accuracy: 0.9657 - val_loss: 6.1588 - val_accuracy: 0.8467 - 141ms/epoch - 13ms/step
Epoch 25/50
11/11 - 0s - loss: 0.6237 - accuracy: 0.9557 - val_loss: 5.4716 - val_accuracy: 0.8767 - 141ms/epoch - 13ms/step
Epoch 26/50
11/11 - 0s - loss: 0.3927 - accuracy: 0.9700 - val_loss: 5.0492 - val_accuracy: 0.8800 - 146ms/epoch - 13ms/step
Epoch 27/50
11/11 - 0s - loss: 0.4778 - accuracy: 0.9643 - val_loss: 5.8995 - val_accuracy: 0.8400 - 158ms/epoch - 14ms/step
Epoch 28/50
11/11 - 0s - loss: 0.4085 - accuracy: 0.9686 - val_loss: 5.0752 - val_accuracy: 0.8667 - 144ms/epoch - 13ms/step
Epoch 29/50
11/11 - 0s - loss: 0.3238 - accuracy: 0.9800 - val_loss: 5.1408 - val_accuracy: 0.8900 - 146ms/epoch - 13ms/step
Epoch 30/50
11/11 - 0s - loss: 0.1134 - accuracy: 0.9857 - val_loss: 4.5113 - val_accuracy: 0.8900 - 149ms/epoch - 14ms/step
Epoch 31/50
11/11 - 0s - loss: 0.1502 - accuracy: 0.9857 - val_loss: 4.5989 - val_accuracy: 0.8900 - 145ms/epoch - 13ms/step
Epoch 32/50
11/11 - 0s - loss: 0.1626 - accuracy: 0.9829 - val_loss: 5.0097 - val_accuracy: 0.8767 - 142ms/epoch - 13ms/step
Epoch 33/50
11/11 - 0s - loss: 0.0347 - accuracy: 0.9900 - val_loss: 4.9909 - val_accuracy: 0.8700 - 154ms/epoch - 14ms/step
Epoch 34/50
11/11 - 0s - loss: 0.0437 - accuracy: 0.9914 - val_loss: 4.6969 - val_accuracy: 0.8867 - 157ms/epoch - 14ms/step
Epoch 35/50
11/11 - 0s - loss: 0.0343 - accuracy: 0.9900 - val_loss: 4.5254 - val_accuracy: 0.8900 - 156ms/epoch - 14ms/step
Epoch 36/50
11/11 - 0s - loss: 0.0326 - accuracy: 0.9929 - val_loss: 4.8820 - val_accuracy: 0.8767 - 152ms/epoch - 14ms/step
Epoch 37/50
11/11 - 0s - loss: 0.0259 - accuracy: 0.9929 - val_loss: 5.0051 - val_accuracy: 0.8800 - 153ms/epoch - 14ms/step
Epoch 38/50
11/11 - 0s - loss: 0.0597 - accuracy: 0.9914 - val_loss: 4.7422 - val_accuracy: 0.8900 - 157ms/epoch - 14ms/step
Epoch 39/50
11/11 - 0s - loss: 0.0087 - accuracy: 0.9971 - val_loss: 4.8613 - val_accuracy: 0.8733 - 157ms/epoch - 14ms/step
Epoch 40/50
11/11 - 0s - loss: 0.0262 - accuracy: 0.9943 - val_loss: 4.7491 - val_accuracy: 0.8900 - 160ms/epoch - 15ms/step
Epoch 41/50
11/11 - 0s - loss: 0.0301 - accuracy: 0.9886 - val_loss: 5.3986 - val_accuracy: 0.8533 - 155ms/epoch - 14ms/step
Epoch 42/50
11/11 - 0s - loss: 0.0903 - accuracy: 0.9857 - val_loss: 5.0836 - val_accuracy: 0.8800 - 155ms/epoch - 14ms/step
Epoch 43/50
11/11 - 0s - loss: 0.0312 - accuracy: 0.9943 - val_loss: 4.6061 - val_accuracy: 0.8867 - 148ms/epoch - 13ms/step
Epoch 44/50
11/11 - 0s - loss: 8.9462e-04 - accuracy: 1.0000 - val_loss: 4.7458 - val_accuracy: 0.8867 - 148ms/epoch - 13ms/step
Epoch 45/50
11/11 - 0s - loss: 0.0054 - accuracy: 0.9986 - val_loss: 4.7773 - val_accuracy: 0.8900 - 150ms/epoch - 14ms/step
Epoch 46/50
11/11 - 0s - loss: 0.0202 - accuracy: 0.9943 - val_loss: 4.6549 - val_accuracy: 0.8833 - 159ms/epoch - 14ms/step
Epoch 47/50
11/11 - 0s - loss: 0.0152 - accuracy: 0.9971 - val_loss: 4.9255 - val_accuracy: 0.8867 - 142ms/epoch - 13ms/step
Epoch 48/50
11/11 - 0s - loss: 0.0131 - accuracy: 0.9971 - val_loss: 4.6451 - val_accuracy: 0.8833 - 146ms/epoch - 13ms/step
Epoch 49/50
11/11 - 0s - loss: 0.0612 - accuracy: 0.9929 - val_loss: 4.6479 - val_accuracy: 0.9000 - 145ms/epoch - 13ms/step
Epoch 50/50
11/11 - 0s - loss: 0.0673 - accuracy: 0.9957 - val_loss: 4.9052 - val_accuracy: 0.8933 - 146ms/epoch - 13ms/step


4.4 Result examples¶

In [ ]:
prediction = np.round(model_down.predict(X_train))
wrong_predictions = np.where(prediction != Y_train)[0]
correct_predictions = np.where(prediction == Y_train)[0]
img = X_train.reshape(-1,28,28)

plt.subplot(121)
plt.imshow(img[correct_predictions[0]], cmap = 'gray')
plt.title('Prediction : '+ str(np.argmax(prediction[correct_predictions[0]]))+
'\n True answer : '+ str(np.argmax(Y_train[correct_predictions[0]])))
plt.axis('off')

plt.subplot(122)
plt.imshow(img[wrong_predictions[0]], cmap = 'gray')
plt.title('Prediction : '+ str(np.argmax(prediction[wrong_predictions[0]]))+
'\n True answer : '+ str(np.argmax(Y_train[wrong_predictions[0]])))
plt.axis('off')

Out[ ]:
(-0.5, 27.5, 27.5, -0.5)