Self-supervised Learning


By Prof. Hyunseok Oh
https://sddo.gist.ac.kr/
SDDO Lab at GIST

1. Import library

In [ ]:
import tensorflow as tf
import numpy as np 

from tensorflow import keras
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense

import matplotlib.pyplot as plt

2. Load MNIST dataset

In [ ]:
#Load MNIST dataset
(X_train, Y_train), (X_test, Y_test) = keras.datasets.mnist.load_data()

# using 1000 data for test
X_train=X_train[:1000]
Y_train=Y_train[:1000]

# using 300 data for test
X_test=X_test[:300]
Y_test=Y_test[:300]

print(X_train.shape)
(1000, 28, 28)

3. Build Pretext task

ssl2.png

3.1 Dateset for Pretext task(rotation)

In [ ]:
n_samples = X_train.shape[0]

X_rotate= np.zeros(shape=(n_samples*4,X_train.shape[1],X_train.shape[2]))
Y_rotate =np.zeros(shape=(n_samples*4,4))

for i in range(n_samples):
    
    img =X_train[i]
    
    
    X_rotate[4*i-4]=img
    Y_rotate[4*i-4]=[1,0,0,0] 
    
    # 90degree rotation: Transpose + Vertical flip
    X_rotate[4*i-3]=np.flip(img.T,0)
    Y_rotate[4*i-3]=[0,1,0,0] 
    
    # 180degree rotation: Vertical flip + Horizontal flip
    X_rotate[4*i-2]=np.flip(np.flip(img,0),1)
    Y_rotate[4*i-2]=[0,0,1,0] 
    
    # 270degree rotation: Vertical flip + Transpose
    X_rotate[4*i-1]=np.flip(img,0).T
    Y_rotate[4*i-1]=[0,0,0,1] 

3.2 Check pretext task data set

In [ ]:
plt.subplots_adjust(left=None, bottom=None, right=None, top=None, wspace=1, hspace=None)
plt.subplots(figsize=(10, 10))

plt.subplot(141)
plt.imshow(X_rotate[4], cmap = 'gray')
plt.title('label : ' + str(Y_rotate[4]))
plt.axis('off') 

plt.subplot(142)
plt.imshow(X_rotate[5], cmap = 'gray')
plt.title('label : ' + str(Y_rotate[5]))
plt.axis('off') 

plt.subplot(143)
plt.imshow(X_rotate[6], cmap = 'gray')
plt.title('label : ' + str(Y_rotate[6]))
plt.axis('off') 

plt.subplot(144)
plt.imshow(X_rotate[7], cmap = 'gray')
plt.title('label : ' +str(Y_rotate[7]))
plt.axis('off') 
Out[ ]:
(-0.5, 27.5, 27.5, -0.5)
<Figure size 432x288 with 0 Axes>
In [ ]:
# reshape MNIST image data for convolution
print(X_rotate[0].shape)
X_rotate =X_rotate.reshape(-1,28,28,1)
(28, 28)

3.3 Build model

In [ ]:
layer1 = Conv2D(64, kernel_size=(3, 3), strides=(2, 2), padding='same',
                activation='relu',kernel_initializer='random_normal')
layer2 = MaxPooling2D(pool_size=(2, 2), strides=(2, 2))
layer3 = Conv2D(32, kernel_size=(3, 3), strides=(1, 1), padding='same', 
                activation='relu' ,kernel_initializer='random_normal')
layer4 = MaxPooling2D(pool_size=(2, 2), strides=(2, 2))
layer5 = Conv2D(16, kernel_size=(3, 3), strides=(2, 2), padding='same',
                activation='relu',kernel_initializer='random_normal')

layer6 = Flatten()
layer7= Dense(4, activation='softmax',kernel_initializer='random_normal')

model_pre = Sequential([keras.Input(shape=(28,28,1)), 
                          layer1, layer2, layer3,layer4,layer5,
                          layer6,layer7])
model_pre.summary()
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 conv2d (Conv2D)             (None, 14, 14, 64)        640       
                                                                 
 max_pooling2d (MaxPooling2D  (None, 7, 7, 64)         0         
 )                                                               
                                                                 
 conv2d_1 (Conv2D)           (None, 7, 7, 32)          18464     
                                                                 
 max_pooling2d_1 (MaxPooling  (None, 3, 3, 32)         0         
 2D)                                                             
                                                                 
 conv2d_2 (Conv2D)           (None, 2, 2, 16)          4624      
                                                                 
 flatten (Flatten)           (None, 64)                0         
                                                                 
 dense (Dense)               (None, 4)                 260       
                                                                 
=================================================================
Total params: 23,988
Trainable params: 23,988
Non-trainable params: 0
_________________________________________________________________

3.4 Train pretext model

In [ ]:
sgd = keras.optimizers.SGD(learning_rate = 0.001,momentum = 0.9)
model_pre.compile(loss = 'categorical_crossentropy', optimizer = sgd, metrics = ['accuracy'])
hist= model_pre.fit(X_rotate, Y_rotate, batch_size = 192, epochs = 50,verbose = 2, shuffle=False)
Epoch 1/50
21/21 - 2s - loss: 1.6171 - accuracy: 0.3975 - 2s/epoch - 94ms/step
Epoch 2/50
21/21 - 1s - loss: 0.7771 - accuracy: 0.6865 - 1s/epoch - 67ms/step
Epoch 3/50
21/21 - 1s - loss: 0.5225 - accuracy: 0.8160 - 1s/epoch - 68ms/step
Epoch 4/50
21/21 - 1s - loss: 0.3773 - accuracy: 0.8755 - 1s/epoch - 67ms/step
Epoch 5/50
21/21 - 1s - loss: 0.2949 - accuracy: 0.9038 - 1s/epoch - 67ms/step
Epoch 6/50
21/21 - 1s - loss: 0.2431 - accuracy: 0.9195 - 1s/epoch - 69ms/step
Epoch 7/50
21/21 - 1s - loss: 0.2037 - accuracy: 0.9342 - 1s/epoch - 67ms/step
Epoch 8/50
21/21 - 1s - loss: 0.1764 - accuracy: 0.9415 - 1s/epoch - 66ms/step
Epoch 9/50
21/21 - 1s - loss: 0.1547 - accuracy: 0.9510 - 1s/epoch - 66ms/step
Epoch 10/50
21/21 - 1s - loss: 0.1376 - accuracy: 0.9545 - 1s/epoch - 66ms/step
Epoch 11/50
21/21 - 1s - loss: 0.1240 - accuracy: 0.9582 - 1s/epoch - 67ms/step
Epoch 12/50
21/21 - 1s - loss: 0.1131 - accuracy: 0.9635 - 1s/epoch - 66ms/step
Epoch 13/50
21/21 - 1s - loss: 0.1087 - accuracy: 0.9655 - 1s/epoch - 69ms/step
Epoch 14/50
21/21 - 1s - loss: 0.1150 - accuracy: 0.9582 - 1s/epoch - 67ms/step
Epoch 15/50
21/21 - 1s - loss: 0.1306 - accuracy: 0.9520 - 1s/epoch - 66ms/step
Epoch 16/50
21/21 - 1s - loss: 0.1857 - accuracy: 0.9295 - 1s/epoch - 67ms/step
Epoch 17/50
21/21 - 1s - loss: 0.1415 - accuracy: 0.9485 - 1s/epoch - 66ms/step
Epoch 18/50
21/21 - 1s - loss: 0.1681 - accuracy: 0.9413 - 1s/epoch - 66ms/step
Epoch 19/50
21/21 - 1s - loss: 0.1142 - accuracy: 0.9572 - 1s/epoch - 67ms/step
Epoch 20/50
21/21 - 1s - loss: 0.0865 - accuracy: 0.9693 - 1s/epoch - 67ms/step
Epoch 21/50
21/21 - 1s - loss: 0.0762 - accuracy: 0.9720 - 1s/epoch - 66ms/step
Epoch 22/50
21/21 - 1s - loss: 0.0954 - accuracy: 0.9625 - 1s/epoch - 67ms/step
Epoch 23/50
21/21 - 1s - loss: 0.1278 - accuracy: 0.9480 - 1s/epoch - 67ms/step
Epoch 24/50
21/21 - 1s - loss: 0.1229 - accuracy: 0.9535 - 1s/epoch - 66ms/step
Epoch 25/50
21/21 - 1s - loss: 0.0919 - accuracy: 0.9640 - 1s/epoch - 67ms/step
Epoch 26/50
21/21 - 1s - loss: 0.0570 - accuracy: 0.9810 - 1s/epoch - 67ms/step
Epoch 27/50
21/21 - 1s - loss: 0.0568 - accuracy: 0.9808 - 1s/epoch - 67ms/step
Epoch 28/50
21/21 - 1s - loss: 0.0608 - accuracy: 0.9795 - 1s/epoch - 68ms/step
Epoch 29/50
21/21 - 1s - loss: 0.0646 - accuracy: 0.9762 - 1s/epoch - 67ms/step
Epoch 30/50
21/21 - 1s - loss: 0.0729 - accuracy: 0.9712 - 1s/epoch - 68ms/step
Epoch 31/50
21/21 - 1s - loss: 0.0889 - accuracy: 0.9665 - 1s/epoch - 66ms/step
Epoch 32/50
21/21 - 1s - loss: 0.1087 - accuracy: 0.9620 - 1s/epoch - 66ms/step
Epoch 33/50
21/21 - 1s - loss: 0.1336 - accuracy: 0.9532 - 1s/epoch - 68ms/step
Epoch 34/50
21/21 - 1s - loss: 0.1815 - accuracy: 0.9330 - 1s/epoch - 67ms/step
Epoch 35/50
21/21 - 1s - loss: 0.2080 - accuracy: 0.9298 - 1s/epoch - 68ms/step
Epoch 36/50
21/21 - 1s - loss: 0.1263 - accuracy: 0.9510 - 1s/epoch - 66ms/step
Epoch 37/50
21/21 - 1s - loss: 0.0512 - accuracy: 0.9825 - 1s/epoch - 67ms/step
Epoch 38/50
21/21 - 1s - loss: 0.0305 - accuracy: 0.9920 - 1s/epoch - 66ms/step
Epoch 39/50
21/21 - 1s - loss: 0.0297 - accuracy: 0.9915 - 1s/epoch - 66ms/step
Epoch 40/50
21/21 - 1s - loss: 0.0276 - accuracy: 0.9935 - 1s/epoch - 66ms/step
Epoch 41/50
21/21 - 1s - loss: 0.0288 - accuracy: 0.9910 - 1s/epoch - 66ms/step
Epoch 42/50
21/21 - 1s - loss: 0.0321 - accuracy: 0.9902 - 1s/epoch - 65ms/step
Epoch 43/50
21/21 - 1s - loss: 0.0330 - accuracy: 0.9890 - 1s/epoch - 67ms/step
Epoch 44/50
21/21 - 1s - loss: 0.0341 - accuracy: 0.9877 - 1s/epoch - 67ms/step
Epoch 45/50
21/21 - 1s - loss: 0.0348 - accuracy: 0.9865 - 1s/epoch - 66ms/step
Epoch 46/50
21/21 - 1s - loss: 0.0333 - accuracy: 0.9883 - 1s/epoch - 66ms/step
Epoch 47/50
21/21 - 1s - loss: 0.0340 - accuracy: 0.9880 - 1s/epoch - 65ms/step
Epoch 48/50
21/21 - 1s - loss: 0.0562 - accuracy: 0.9780 - 1s/epoch - 66ms/step
Epoch 49/50
21/21 - 1s - loss: 0.0552 - accuracy: 0.9790 - 1s/epoch - 66ms/step
Epoch 50/50
21/21 - 1s - loss: 0.0630 - accuracy: 0.9747 - 1s/epoch - 67ms/step
In [ ]:
# Freeze the pretext model
model_pre.trainable=False

4 Build downstream task

4.1 Reshape dataset

In [ ]:
X_train =X_train.reshape(-1,28,28,1)
X_test =X_test.reshape(-1,28,28,1)
Y_train = tf.one_hot(Y_train, 10,on_value=1.0, off_value=0.0)
Y_test = tf.one_hot(Y_test, 10,on_value=1.0, off_value=0.0)

4.2 Build model

In [ ]:
layer9=Flatten()
layer10 = Dense(10,activation = 'softmax',kernel_initializer='random_normal') 
# new layer to classify 10 numbers

model_down = keras.Sequential([keras.Input(shape=(28,28,1)), 
                           layer1, layer2, layer3,layer4,layer9,layer10])
model_down.summary()
Model: "sequential_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 conv2d (Conv2D)             (None, 14, 14, 64)        640       
                                                                 
 max_pooling2d (MaxPooling2D  (None, 7, 7, 64)         0         
 )                                                               
                                                                 
 conv2d_1 (Conv2D)           (None, 7, 7, 32)          18464     
                                                                 
 max_pooling2d_1 (MaxPooling  (None, 3, 3, 32)         0         
 2D)                                                             
                                                                 
 flatten_1 (Flatten)         (None, 288)               0         
                                                                 
 dense_1 (Dense)             (None, 10)                2890      
                                                                 
=================================================================
Total params: 21,994
Trainable params: 2,890
Non-trainable params: 19,104
_________________________________________________________________

4.3 Train downstream model

In [ ]:
callback = tf.keras.callbacks.EarlyStopping(monitor='accuracy', patience=7)
#validation_split for training
split = 0.3
In [ ]:
model_down.compile(loss = 'categorical_crossentropy', optimizer = sgd, metrics = ['accuracy'])
hist_down= model_down.fit(X_train, Y_train, batch_size = 64, validation_split = split,
                  epochs = 50,verbose = 2, callbacks = callback)
Epoch 1/50
11/11 - 1s - loss: 145.5323 - accuracy: 0.1443 - val_loss: 95.5060 - val_accuracy: 0.1333 - 660ms/epoch - 60ms/step
Epoch 2/50
11/11 - 0s - loss: 112.1947 - accuracy: 0.3200 - val_loss: 108.7943 - val_accuracy: 0.1800 - 160ms/epoch - 15ms/step
Epoch 3/50
11/11 - 0s - loss: 55.6467 - accuracy: 0.5029 - val_loss: 63.7907 - val_accuracy: 0.4933 - 144ms/epoch - 13ms/step
Epoch 4/50
11/11 - 0s - loss: 27.0338 - accuracy: 0.6500 - val_loss: 26.3296 - val_accuracy: 0.6367 - 144ms/epoch - 13ms/step
Epoch 5/50
11/11 - 0s - loss: 17.3450 - accuracy: 0.7286 - val_loss: 11.5728 - val_accuracy: 0.7700 - 149ms/epoch - 14ms/step
Epoch 6/50
11/11 - 0s - loss: 7.8281 - accuracy: 0.8429 - val_loss: 9.8907 - val_accuracy: 0.8000 - 154ms/epoch - 14ms/step
Epoch 7/50
11/11 - 0s - loss: 6.1283 - accuracy: 0.8643 - val_loss: 9.0447 - val_accuracy: 0.8067 - 154ms/epoch - 14ms/step
Epoch 8/50
11/11 - 0s - loss: 4.1972 - accuracy: 0.8771 - val_loss: 7.6835 - val_accuracy: 0.8333 - 141ms/epoch - 13ms/step
Epoch 9/50
11/11 - 0s - loss: 4.1619 - accuracy: 0.8929 - val_loss: 8.4072 - val_accuracy: 0.7800 - 149ms/epoch - 14ms/step
Epoch 10/50
11/11 - 0s - loss: 4.2850 - accuracy: 0.8729 - val_loss: 6.6977 - val_accuracy: 0.8200 - 141ms/epoch - 13ms/step
Epoch 11/50
11/11 - 0s - loss: 4.5401 - accuracy: 0.8714 - val_loss: 8.8403 - val_accuracy: 0.8000 - 144ms/epoch - 13ms/step
Epoch 12/50
11/11 - 0s - loss: 3.5040 - accuracy: 0.9014 - val_loss: 8.3725 - val_accuracy: 0.7867 - 146ms/epoch - 13ms/step
Epoch 13/50
11/11 - 0s - loss: 2.9546 - accuracy: 0.9129 - val_loss: 5.3033 - val_accuracy: 0.8533 - 156ms/epoch - 14ms/step
Epoch 14/50
11/11 - 0s - loss: 1.6648 - accuracy: 0.9343 - val_loss: 7.1830 - val_accuracy: 0.8367 - 141ms/epoch - 13ms/step
Epoch 15/50
11/11 - 0s - loss: 1.6900 - accuracy: 0.9357 - val_loss: 5.0708 - val_accuracy: 0.8800 - 147ms/epoch - 13ms/step
Epoch 16/50
11/11 - 0s - loss: 1.2530 - accuracy: 0.9386 - val_loss: 5.0104 - val_accuracy: 0.8767 - 142ms/epoch - 13ms/step
Epoch 17/50
11/11 - 0s - loss: 1.0832 - accuracy: 0.9443 - val_loss: 5.9515 - val_accuracy: 0.8733 - 152ms/epoch - 14ms/step
Epoch 18/50
11/11 - 0s - loss: 1.2069 - accuracy: 0.9286 - val_loss: 6.4323 - val_accuracy: 0.8567 - 136ms/epoch - 12ms/step
Epoch 19/50
11/11 - 0s - loss: 1.6460 - accuracy: 0.9314 - val_loss: 5.5614 - val_accuracy: 0.8700 - 149ms/epoch - 14ms/step
Epoch 20/50
11/11 - 0s - loss: 1.3438 - accuracy: 0.9314 - val_loss: 6.0327 - val_accuracy: 0.8567 - 150ms/epoch - 14ms/step
Epoch 21/50
11/11 - 0s - loss: 1.0919 - accuracy: 0.9343 - val_loss: 4.9564 - val_accuracy: 0.8767 - 149ms/epoch - 14ms/step
Epoch 22/50
11/11 - 0s - loss: 0.8793 - accuracy: 0.9529 - val_loss: 5.2268 - val_accuracy: 0.8700 - 142ms/epoch - 13ms/step
Epoch 23/50
11/11 - 0s - loss: 0.7904 - accuracy: 0.9471 - val_loss: 5.1340 - val_accuracy: 0.8933 - 147ms/epoch - 13ms/step
Epoch 24/50
11/11 - 0s - loss: 0.4839 - accuracy: 0.9657 - val_loss: 6.1588 - val_accuracy: 0.8467 - 141ms/epoch - 13ms/step
Epoch 25/50
11/11 - 0s - loss: 0.6237 - accuracy: 0.9557 - val_loss: 5.4716 - val_accuracy: 0.8767 - 141ms/epoch - 13ms/step
Epoch 26/50
11/11 - 0s - loss: 0.3927 - accuracy: 0.9700 - val_loss: 5.0492 - val_accuracy: 0.8800 - 146ms/epoch - 13ms/step
Epoch 27/50
11/11 - 0s - loss: 0.4778 - accuracy: 0.9643 - val_loss: 5.8995 - val_accuracy: 0.8400 - 158ms/epoch - 14ms/step
Epoch 28/50
11/11 - 0s - loss: 0.4085 - accuracy: 0.9686 - val_loss: 5.0752 - val_accuracy: 0.8667 - 144ms/epoch - 13ms/step
Epoch 29/50
11/11 - 0s - loss: 0.3238 - accuracy: 0.9800 - val_loss: 5.1408 - val_accuracy: 0.8900 - 146ms/epoch - 13ms/step
Epoch 30/50
11/11 - 0s - loss: 0.1134 - accuracy: 0.9857 - val_loss: 4.5113 - val_accuracy: 0.8900 - 149ms/epoch - 14ms/step
Epoch 31/50
11/11 - 0s - loss: 0.1502 - accuracy: 0.9857 - val_loss: 4.5989 - val_accuracy: 0.8900 - 145ms/epoch - 13ms/step
Epoch 32/50
11/11 - 0s - loss: 0.1626 - accuracy: 0.9829 - val_loss: 5.0097 - val_accuracy: 0.8767 - 142ms/epoch - 13ms/step
Epoch 33/50
11/11 - 0s - loss: 0.0347 - accuracy: 0.9900 - val_loss: 4.9909 - val_accuracy: 0.8700 - 154ms/epoch - 14ms/step
Epoch 34/50
11/11 - 0s - loss: 0.0437 - accuracy: 0.9914 - val_loss: 4.6969 - val_accuracy: 0.8867 - 157ms/epoch - 14ms/step
Epoch 35/50
11/11 - 0s - loss: 0.0343 - accuracy: 0.9900 - val_loss: 4.5254 - val_accuracy: 0.8900 - 156ms/epoch - 14ms/step
Epoch 36/50
11/11 - 0s - loss: 0.0326 - accuracy: 0.9929 - val_loss: 4.8820 - val_accuracy: 0.8767 - 152ms/epoch - 14ms/step
Epoch 37/50
11/11 - 0s - loss: 0.0259 - accuracy: 0.9929 - val_loss: 5.0051 - val_accuracy: 0.8800 - 153ms/epoch - 14ms/step
Epoch 38/50
11/11 - 0s - loss: 0.0597 - accuracy: 0.9914 - val_loss: 4.7422 - val_accuracy: 0.8900 - 157ms/epoch - 14ms/step
Epoch 39/50
11/11 - 0s - loss: 0.0087 - accuracy: 0.9971 - val_loss: 4.8613 - val_accuracy: 0.8733 - 157ms/epoch - 14ms/step
Epoch 40/50
11/11 - 0s - loss: 0.0262 - accuracy: 0.9943 - val_loss: 4.7491 - val_accuracy: 0.8900 - 160ms/epoch - 15ms/step
Epoch 41/50
11/11 - 0s - loss: 0.0301 - accuracy: 0.9886 - val_loss: 5.3986 - val_accuracy: 0.8533 - 155ms/epoch - 14ms/step
Epoch 42/50
11/11 - 0s - loss: 0.0903 - accuracy: 0.9857 - val_loss: 5.0836 - val_accuracy: 0.8800 - 155ms/epoch - 14ms/step
Epoch 43/50
11/11 - 0s - loss: 0.0312 - accuracy: 0.9943 - val_loss: 4.6061 - val_accuracy: 0.8867 - 148ms/epoch - 13ms/step
Epoch 44/50
11/11 - 0s - loss: 8.9462e-04 - accuracy: 1.0000 - val_loss: 4.7458 - val_accuracy: 0.8867 - 148ms/epoch - 13ms/step
Epoch 45/50
11/11 - 0s - loss: 0.0054 - accuracy: 0.9986 - val_loss: 4.7773 - val_accuracy: 0.8900 - 150ms/epoch - 14ms/step
Epoch 46/50
11/11 - 0s - loss: 0.0202 - accuracy: 0.9943 - val_loss: 4.6549 - val_accuracy: 0.8833 - 159ms/epoch - 14ms/step
Epoch 47/50
11/11 - 0s - loss: 0.0152 - accuracy: 0.9971 - val_loss: 4.9255 - val_accuracy: 0.8867 - 142ms/epoch - 13ms/step
Epoch 48/50
11/11 - 0s - loss: 0.0131 - accuracy: 0.9971 - val_loss: 4.6451 - val_accuracy: 0.8833 - 146ms/epoch - 13ms/step
Epoch 49/50
11/11 - 0s - loss: 0.0612 - accuracy: 0.9929 - val_loss: 4.6479 - val_accuracy: 0.9000 - 145ms/epoch - 13ms/step
Epoch 50/50
11/11 - 0s - loss: 0.0673 - accuracy: 0.9957 - val_loss: 4.9052 - val_accuracy: 0.8933 - 146ms/epoch - 13ms/step

4.4 Result examples

In [ ]:
prediction = np.round(model_down.predict(X_train))
wrong_predictions = np.where(prediction != Y_train)[0]
correct_predictions = np.where(prediction == Y_train)[0]
img = X_train.reshape(-1,28,28)

plt.subplot(121)
plt.imshow(img[correct_predictions[0]], cmap = 'gray')
plt.title('Prediction : '+ str(np.argmax(prediction[correct_predictions[0]]))+
          '\n True answer : '+ str(np.argmax(Y_train[correct_predictions[0]])))
plt.axis('off') 

plt.subplot(122)
plt.imshow(img[wrong_predictions[0]], cmap = 'gray')
plt.title('Prediction : '+ str(np.argmax(prediction[wrong_predictions[0]]))+
          '\n True answer : '+ str(np.argmax(Y_train[wrong_predictions[0]])))
plt.axis('off') 
Out[ ]:
(-0.5, 27.5, 27.5, -0.5)

5 Build supervised Model

In [ ]:
#supervised model
model_super = Sequential()

model_super.add(Conv2D(64, kernel_size=(3, 3), strides=(2, 2),activation='relu', padding='same',
                  kernel_initializer='random_normal',input_shape=(28,28,1)))
model_super.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))
model_super.add(Conv2D(32, kernel_size=(3, 3),strides=(1, 1), activation='relu', padding='same',
                  kernel_initializer='random_normal'))
model_super.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))

model_super.add(Flatten())
model_super.add(Dense(10, activation='softmax',kernel_initializer='random_normal'))

model_super.summary()
Model: "sequential_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 conv2d_3 (Conv2D)           (None, 14, 14, 64)        640       
                                                                 
 max_pooling2d_2 (MaxPooling  (None, 7, 7, 64)         0         
 2D)                                                             
                                                                 
 conv2d_4 (Conv2D)           (None, 7, 7, 32)          18464     
                                                                 
 max_pooling2d_3 (MaxPooling  (None, 3, 3, 32)         0         
 2D)                                                             
                                                                 
 flatten_2 (Flatten)         (None, 288)               0         
                                                                 
 dense_2 (Dense)             (None, 10)                2890      
                                                                 
=================================================================
Total params: 21,994
Trainable params: 21,994
Non-trainable params: 0
_________________________________________________________________
In [ ]:
model_super.compile(loss = 'categorical_crossentropy', optimizer = sgd, metrics = ['accuracy'])
hist_super= model_super.fit(X_train, Y_train, batch_size = 64, validation_split = split, 
                  epochs = 50,verbose = 0, callbacks = callback)

6 Compare self-supervised learning and supervised learning

In [ ]:
plt.plot(range(len(hist_down.history['accuracy'])), hist_down.history['accuracy'], label='Self-supervised')
plt.plot(range(len(hist_super.history['accuracy'])), hist_super.history['accuracy'], label='Supervised')

plt.xlim([0, 50])
plt.ylim([0, 1])
plt.title('Training accuracy (' +str(round((1-split)*100,1))+ '% label)')
plt.legend()
plt.show()
In [ ]:
if split == 0:
    print(" 100% labeled data is using for training")
    
else:
    plt.plot(range(len(hist_down.history['accuracy'])), hist_down.history['val_accuracy'], label='Self-Supervised')
    plt.plot(range(len(hist_super.history['accuracy'])), hist_super.history['val_accuracy'], label='Supervised')
    plt.xlim([0, 50])
    plt.ylim([0, 1])
    plt.title('Validation accuracy (' +str(round((1-split)*100,1))+ '% label)')
    plt.legend()
    plt.show()
    
    
In [ ]:
eval_self = model_down.evaluate(X_test,Y_test,batch_size = 64,steps =10,verbose = 2)
WARNING:tensorflow:Your input ran out of data; interrupting training. Make sure that your dataset or generator can generate at least `steps_per_epoch * epochs` batches (in this case, 10 batches). You may need to use the repeat() function when building your dataset.
10/10 - 0s - loss: 2.7465 - accuracy: 0.9067 - 373ms/epoch - 37ms/step
In [ ]:
eval_super = model_super.evaluate(X_test,Y_test,batch_size = 64,steps =10, verbose = 2)
WARNING:tensorflow:Your input ran out of data; interrupting training. Make sure that your dataset or generator can generate at least `steps_per_epoch * epochs` batches (in this case, 10 batches). You may need to use the repeat() function when building your dataset.
10/10 - 0s - loss: 0.6078 - accuracy: 0.8500 - 84ms/epoch - 8ms/step
In [22]:
img = X_test.reshape(-1,28,28)
img.shape

prediction_self = np.argmax(model_down.predict(X_test),axis=1)
prediction_super = np.argmax(model_super.predict(X_test),axis=1)

plt.figure(figsize=(20,25))   
for n in range(10):
    
   
  plt.subplot(1,10,n+1)
  plt.imshow(img[n], cmap = 'gray')
  plt.axis('off') 

print('Prediction by self-supervise learning : ', prediction_self[0:10])
print('\nPrediction by supervised learning     : ', prediction_super[0:10])
print('\nTrue answer                           : ',np.argmax(Y_test[0:10],axis=1))
Prediction by self-supervise learning :  [7 2 1 0 4 1 4 9 5 9]

Prediction by supervised learning     :  [7 2 1 0 4 1 4 9 5 7]

True answer                           :  [7 2 1 0 4 1 4 9 5 9]