Self-supervised Learning
import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
import matplotlib.pyplot as plt
# Load the MNIST dataset
(X_train, Y_train), (X_test, Y_test) = keras.datasets.mnist.load_data()
# use the first 1,000 images for training
X_train = X_train[:1000]
Y_train = Y_train[:1000]
# use the first 300 images for testing
X_test = X_test[:300]
Y_test = Y_test[:300]
print(X_train.shape)
3.1 Dataset for pretext task (rotation)
n_samples = X_train.shape[0]
X_rotate = np.zeros(shape=(n_samples*4, X_train.shape[1], X_train.shape[2]))
Y_rotate = np.zeros(shape=(n_samples*4, 4))
for i in range(n_samples):
    img = X_train[i]
    # 0 degree rotation: original image
    X_rotate[4*i] = img
    Y_rotate[4*i] = [1, 0, 0, 0]
    # 90 degree rotation: transpose + vertical flip
    X_rotate[4*i+1] = np.flip(img.T, 0)
    Y_rotate[4*i+1] = [0, 1, 0, 0]
    # 180 degree rotation: vertical flip + horizontal flip
    X_rotate[4*i+2] = np.flip(np.flip(img, 0), 1)
    Y_rotate[4*i+2] = [0, 0, 1, 0]
    # 270 degree rotation: vertical flip + transpose
    X_rotate[4*i+3] = np.flip(img, 0).T
    Y_rotate[4*i+3] = [0, 0, 0, 1]
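As a quick sanity check, the flip/transpose combinations above should agree with NumPy's np.rot90, which rotates counterclockwise. This assertion block is optional and not part of the pipeline:
# optional check: each hand-written rotation matches np.rot90 (counterclockwise)
img = X_train[0]
assert np.array_equal(np.flip(img.T, 0), np.rot90(img, 1))            # 90 degrees
assert np.array_equal(np.flip(np.flip(img, 0), 1), np.rot90(img, 2))  # 180 degrees
assert np.array_equal(np.flip(img, 0).T, np.rot90(img, 3))            # 270 degrees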
3.2 Check pretext-task dataset
# display the four rotations of one training image
plt.figure(figsize=(10, 10))
plt.subplots_adjust(wspace=1)
plt.subplot(141)
plt.imshow(X_rotate[4], cmap = 'gray')
plt.title('label : ' + str(Y_rotate[4]))
plt.axis('off')
plt.subplot(142)
plt.imshow(X_rotate[5], cmap = 'gray')
plt.title('label : ' + str(Y_rotate[5]))
plt.axis('off')
plt.subplot(143)
plt.imshow(X_rotate[6], cmap = 'gray')
plt.title('label : ' + str(Y_rotate[6]))
plt.axis('off')
plt.subplot(144)
plt.imshow(X_rotate[7], cmap = 'gray')
plt.title('label : ' +str(Y_rotate[7]))
plt.axis('off')
# add a channel dimension so the images match Conv2D's expected input
print(X_rotate[0].shape)
X_rotate = X_rotate.reshape(-1, 28, 28, 1)
3.3 Build model
layer1 = Conv2D(64, kernel_size=(3, 3), strides=(2, 2), padding='same',
                activation='relu', kernel_initializer='random_normal')
layer2 = MaxPooling2D(pool_size=(2, 2), strides=(2, 2))
layer3 = Conv2D(32, kernel_size=(3, 3), strides=(1, 1), padding='same',
                activation='relu', kernel_initializer='random_normal')
layer4 = MaxPooling2D(pool_size=(2, 2), strides=(2, 2))
layer5 = Conv2D(16, kernel_size=(3, 3), strides=(2, 2), padding='same',
                activation='relu', kernel_initializer='random_normal')
layer6 = Flatten()
layer7 = Dense(4, activation='softmax', kernel_initializer='random_normal')
model_pre = Sequential([keras.Input(shape=(28, 28, 1)),
                        layer1, layer2, layer3, layer4, layer5,
                        layer6, layer7])
model_pre.summary()
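For reference, the output sizes in the summary above can be reproduced by hand. This is a sketch assuming Keras defaults: Conv2D with padding='same' gives ceil(in / stride), and MaxPooling2D defaults to padding='valid', giving floor((in - pool) / stride) + 1:
# reproduce the spatial sizes reported by model_pre.summary()
import math
s = 28
s = math.ceil(s / 2)   # layer1 (conv, stride 2) -> 14
s = (s - 2) // 2 + 1   # layer2 (pool)           -> 7
s = math.ceil(s / 1)   # layer3 (conv, stride 1) -> 7
s = (s - 2) // 2 + 1   # layer4 (pool)           -> 3
s = math.ceil(s / 2)   # layer5 (conv, stride 2) -> 2
print(s * s * 16)      # flattened feature size  -> 64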
3.4 Train pretext model
# train the rotation-prediction (pretext) task
sgd = keras.optimizers.SGD(learning_rate=0.001, momentum=0.9)
model_pre.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])
hist = model_pre.fit(X_rotate, Y_rotate, batch_size=192, epochs=50, verbose=2, shuffle=False)
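To verify the pretext task is actually being learned, it can help to plot the rotation-classification accuracy over epochs (optional, reusing the hist object returned above):
# optional: pretext-task learning curve
plt.plot(hist.history['accuracy'])
plt.xlabel('epoch')
plt.ylabel('rotation accuracy')
plt.show()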
# Freeze the pretext model
model_pre.trainable=False
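Since the downstream model below reuses these exact layer objects, it is worth confirming the freeze propagated to them (in TF2, setting trainable on a model also sets it on its sublayers):
# optional check: the shared conv layers are frozen
print(layer1.trainable, layer3.trainable)   # expect: False False
print(len(model_pre.trainable_weights))     # expect: 0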
4.1 Reshape dataset
X_train = X_train.reshape(-1, 28, 28, 1)
X_test = X_test.reshape(-1, 28, 28, 1)
# one-hot encode the digit labels
Y_train = tf.one_hot(Y_train, 10, on_value=1.0, off_value=0.0)
Y_test = tf.one_hot(Y_test, 10, on_value=1.0, off_value=0.0)
4.2 Build model
layer9 = Flatten()
layer10 = Dense(10, activation='softmax', kernel_initializer='random_normal')
# new head for the 10 digit classes; the first four (frozen) pretext layers are reused
model_down = keras.Sequential([keras.Input(shape=(28, 28, 1)),
                               layer1, layer2, layer3, layer4, layer9, layer10])
model_down.summary()
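Only the new Dense head should be trainable here; a quick way to confirm (two tensors: the Dense kernel and bias):
# optional check: only layer10's kernel and bias are trainable
print(len(model_down.trainable_weights))    # expect: 2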
4.3 Train downstream model
# stop when training accuracy plateaus (monitoring 'accuracy' also works when split == 0)
callback = tf.keras.callbacks.EarlyStopping(monitor='accuracy', patience=7)
# fraction of the training data held out for validation
split = 0.3
# fresh optimizer instance: reusing the one bound to model_pre would carry over its state
model_down.compile(loss='categorical_crossentropy',
                   optimizer=keras.optimizers.SGD(learning_rate=0.001, momentum=0.9),
                   metrics=['accuracy'])
hist_down = model_down.fit(X_train, Y_train, batch_size=64, validation_split=split,
                           epochs=50, verbose=2, callbacks=[callback])
4.4 Result examples
# compare predicted and true class labels on the training set
pred_labels = np.argmax(model_down.predict(X_train), axis=1)
true_labels = np.argmax(Y_train, axis=1)
# note: wrong_predictions can be empty if the model fits the training set perfectly
wrong_predictions = np.where(pred_labels != true_labels)[0]
correct_predictions = np.where(pred_labels == true_labels)[0]
img = X_train.reshape(-1, 28, 28)
plt.subplot(121)
plt.imshow(img[correct_predictions[0]], cmap='gray')
plt.title('Prediction : ' + str(pred_labels[correct_predictions[0]]) +
          '\n True answer : ' + str(true_labels[correct_predictions[0]]))
plt.axis('off')
plt.subplot(122)
plt.imshow(img[wrong_predictions[0]], cmap='gray')
plt.title('Prediction : ' + str(pred_labels[wrong_predictions[0]]) +
          '\n True answer : ' + str(true_labels[wrong_predictions[0]]))
plt.axis('off')
# supervised baseline: the same architecture trained from scratch with labels only
model_super = Sequential()
model_super.add(Conv2D(64, kernel_size=(3, 3), strides=(2, 2), activation='relu', padding='same',
                       kernel_initializer='random_normal', input_shape=(28, 28, 1)))
model_super.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))
model_super.add(Conv2D(32, kernel_size=(3, 3), strides=(1, 1), activation='relu', padding='same',
                       kernel_initializer='random_normal'))
model_super.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))
model_super.add(Flatten())
model_super.add(Dense(10, activation='softmax', kernel_initializer='random_normal'))
model_super.summary()
model_super.compile(loss='categorical_crossentropy',
                    optimizer=keras.optimizers.SGD(learning_rate=0.001, momentum=0.9),
                    metrics=['accuracy'])
hist_super = model_super.fit(X_train, Y_train, batch_size=64, validation_split=split,
                             epochs=50, verbose=0, callbacks=[callback])
plt.plot(range(len(hist_down.history['accuracy'])), hist_down.history['accuracy'], label='Self-supervised')
plt.plot(range(len(hist_super.history['accuracy'])), hist_super.history['accuracy'], label='Supervised')
plt.xlim([0, 50])
plt.ylim([0, 1])
plt.title('Training accuracy (' + str(round((1-split)*100, 1)) + '% of labels)')
plt.legend()
plt.show()
if split == 0:
    print("100% of the labels were used for training, so there is no validation set to plot")
else:
    plt.plot(range(len(hist_down.history['val_accuracy'])), hist_down.history['val_accuracy'], label='Self-supervised')
    plt.plot(range(len(hist_super.history['val_accuracy'])), hist_super.history['val_accuracy'], label='Supervised')
    plt.xlim([0, 50])
    plt.ylim([0, 1])
    plt.title('Validation accuracy (' + str(round((1-split)*100, 1)) + '% of labels)')
    plt.legend()
    plt.show()
# evaluate on the full 300-image test set (no steps needed with in-memory arrays)
eval_self = model_down.evaluate(X_test, Y_test, batch_size=64, verbose=2)
eval_super = model_super.evaluate(X_test, Y_test, batch_size=64, verbose=2)
img = X_test.reshape(-1,28,28)
img.shape
prediction_self = np.argmax(model_down.predict(X_test),axis=1)
prediction_super = np.argmax(model_super.predict(X_test),axis=1)
plt.figure(figsize=(20,25))
for n in range(10):
    plt.subplot(1, 10, n+1)
    plt.imshow(img[n], cmap='gray')
    plt.axis('off')
print('Prediction by self-supervised learning : ', prediction_self[0:10])
print('\nPrediction by supervised learning : ', prediction_super[0:10])
print('\nTrue answer : ', np.argmax(Y_test[0:10], axis=1))
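As a convenience summary (the evaluate() calls above already report the same numbers), the overall test accuracies can also be computed directly from the predicted labels:
# overall test accuracy from the predicted labels
true_test = np.argmax(Y_test, axis=1)
print('Self-supervised test accuracy :', np.mean(prediction_self == true_test))
print('Supervised test accuracy      :', np.mean(prediction_super == true_test))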