Self-supervised Learning
http://iailab.kaist.ac.kr/
Industrial AI Lab at KAIST
Table of Contents
Many images from https://amitness.com/2020/02/illustrated-self-supervised-learning/
1. Supervised Learning and Transfer Learning¶
Supervised pretraining on large labeled datasets has led to successful transfer learning
ImageNet
Pretrain for fine-grained image classification of 1000 classes
- Use feature representations for downstream tasks, e.g., object detection, image segmentation, and action recognition
But supervised pretraining comes at a cost …
Time-consuming and expensive to label datasets for new tasks
Domain expertise needed for specialized tasks
- Radiologists to label medical images
- Native speakers or language specialists for labeling text in different languages
To relieve the burden of labeling,
- Semi-supervised learning
- Weakly-supervised learning
- Unsupervised learning
Self-supervised learning
Self-supervised learning (SSL): supervise using labels generated from the data without any manual or weak label sources
- Sub-class of unsupervised learning
Idea: Hide or modify part of the input. Ask model to recover input or classify what changed
- The self-supervised task, referred to as the pretext task, can be formulated using only unlabeled data
- The features obtained from pretext tasks are transferred to downstream tasks like classification, object detection, and segmentation
Pretext Tasks
Solving the pretext tasks allows the model to learn good features.
We can automatically generate labels for the pretext tasks.
2. Pretext Tasks¶
2.1. Pretext Task - Context Prediction¶
After creating 9 patches from one input image, a classifier is trained to predict the relative location of a surrounding patch with respect to the middle patch
A pair consisting of the middle patch and one of the eight surrounding patches is given as input to the network
Method to avoid trivial solutions
- uneven spacing between patches (a gap plus random jitter), so the network cannot simply match patch boundaries
Carl Doersch, Abhinav Gupta, Alexei A. Efros, 2015, "Unsupervised Visual Representation Learning by Context Prediction," Proceedings of the IEEE International Conference on Computer Vision (ICCV), pp. 1422-1430.
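As a rough illustration (not the authors' implementation), the (middle patch, neighbor patch, location label) pairs can be generated as sketched below; the patch size, gap, and jitter values are assumptions chosen for images of roughly 96 x 96 pixels or larger.
import numpy as np

def make_context_pair(image, patch_size = 15, gap = 3, jitter = 2):
    # Sample (middle patch, neighbor patch, position label) from one image.
    # The gap and the random jitter implement the "uneven spacing" trick that
    # prevents the network from simply matching patch boundaries.
    offsets = [(-1, -1), (-1, 0), (-1, 1),
               ( 0, -1),          ( 0, 1),
               ( 1, -1), ( 1, 0), ( 1, 1)]              # the 8 possible neighbor positions
    label = np.random.randint(8)                        # which neighbor is sampled
    dy, dx = offsets[label]

    stride = patch_size + gap                           # uneven spacing between patch centers
    cy, cx = image.shape[0] // 2, image.shape[1] // 2   # middle of the 3 x 3 grid
    jy, jx = np.random.randint(-jitter, jitter + 1, size = 2)

    def crop(y, x):
        y0, x0 = y - patch_size // 2, x - patch_size // 2
        return image[y0:y0 + patch_size, x0:x0 + patch_size]

    middle   = crop(cy, cx)
    neighbor = crop(cy + dy * stride + jy, cx + dx * stride + jx)
    return middle, neighbor, label
The network receives the two patches and solves an 8-way classification problem over the returned label.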
2.2. Pretext Task - Jigsaw Puzzle¶
Generate 9 patches from the input image
After shuffling the patches, train a classifier that predicts which permutation was applied, i.e., how to return the patches to their original positions
Noroozi, M., and Favaro, P., 2016, "Unsupervised Learning of Visual Representations by Solving Jigsaw Puzzles," Computer Vision – ECCV 2016, 69–84.
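A minimal sketch of the data generation, assuming a small random permutation set; the actual paper selects a fixed set of maximally different permutations (e.g., 100), so the helper below is illustrative only.
import numpy as np

rng = np.random.default_rng(0)
permutations = [rng.permutation(9) for _ in range(10)]   # assumed small permutation "answer key"

def make_jigsaw_example(image, grid = 3):
    # Split the image into grid x grid tiles, shuffle them with one of the
    # predefined permutations, and return the shuffled tiles together with the
    # permutation index, which is what the classifier must predict.
    h, w = image.shape[0] // grid, image.shape[1] // grid
    tiles = [image[r*h:(r+1)*h, c*w:(c+1)*w]
             for r in range(grid) for c in range(grid)]
    label = rng.integers(len(permutations))
    shuffled = [tiles[j] for j in permutations[label]]
    return np.stack(shuffled), label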
2.3. Pretext Task - Image Colorization¶
Given a grayscale photograph as input, image colorization attacks the problem of hallucinating a plausible color version of the photograph
Transfer the trained encoder to the downstream task
Zhang, R., Isola, P., and Efros, A. A., 2016, "Colorful Image Colorization," Computer Vision – ECCV 2016, 649–666.
- Training data generation for self-supervised learning
- Network architecture
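The training-data generation step is essentially free, since the color image itself is the target. A minimal sketch, assuming plain RGB regression targets (the actual paper predicts quantized color distributions in Lab space):
import tensorflow as tf

def make_colorization_pair(rgb_image):
    # rgb_image: float tensor in [0, 1] with shape (H, W, 3)
    gray = tf.image.rgb_to_grayscale(rgb_image)   # network input, shape (H, W, 1)
    return gray, rgb_image                        # target: the original colors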
2.4. Pretext Task - Image Super-resolution¶
What if we prepared training pairs of (small, upscaled) images by downsampling millions of images we have freely available?
Training data generation for self-supervised learning
- Network architecture
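A minimal sketch of the pair generation, assuming a 4x scale factor and bicubic resizing; the super-resolution network itself is omitted.
import tensorflow as tf

def make_sr_pair(image, scale = 4):
    # image: float tensor in [0, 1] with static shape (H, W, C), H and W divisible by scale
    h, w = image.shape[0], image.shape[1]
    small = tf.image.resize(image, (h // scale, w // scale), method = 'bicubic')
    upscaled = tf.image.resize(small, (h, w), method = 'bicubic')
    return upscaled, image                        # input: re-upscaled image, target: original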
2.5. Pretext Task - Image Inpainting¶
What if we prepared training pairs of (corrupted, fixed) images by randomly removing part of images?
Training data generation for self-supervised learning
- Network architecture
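A minimal sketch of generating (corrupted, original) pairs; the square mask and its size are illustrative assumptions, and real methods often remove larger or irregular regions.
import numpy as np

def make_inpainting_pair(image, mask_size = 8):
    # image: array with shape (H, W) or (H, W, C)
    H, W = image.shape[0], image.shape[1]
    y = np.random.randint(0, H - mask_size)
    x = np.random.randint(0, W - mask_size)
    corrupted = image.copy()
    corrupted[y:y + mask_size, x:x + mask_size] = 0   # hide a random square patch
    return corrupted, image                           # target: the uncorrupted image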
3. Self-supervised Learning¶
Benefits of Self-supervised Learning
Like supervised pretraining, can learn general-purpose feature representations for downstream tasks
Reduce expense of hand-labeling large datasets
Can leverage nearly unlimited unlabeled data available on the web
Pipeline of Self-supervised Learning
In the pretext task, the deep neural network learns visual features from the unlabeled input data
The learned parameters of the trained network then serve as a pre-trained model for downstream tasks
The pre-trained model is transferred to downstream tasks, where it is fine-tuned or kept frozen while new task-specific layers are trained
The performance on the downstream tasks is used to evaluate how well the pretext task learns features from unlabeled data
Jing, L., & Tian, Y., 2021, "Self-supervised visual feature learning with Deep Neural Networks: A survey," IEEE Transactions on Pattern Analysis and Machine Intelligence, 43(11), 4037–4058.
Downstream Tasks
After transferring the neural network pre-trained by the pretext task, freeze the weights and build additional layers for the downstream tasks
Wide variety of downstream tasks
- Classification
- Regression
- Object detection
- Segmentation
4. Self-supervised Learning with TensorFlow¶
Pretext Task - Rotation
RotNet
Hypothesis: a model could recognize the correct rotation of an object only if it has the “visual commonsense” of what the object should look like
- Self-supervised learning by rotating the entire input images
- The model learns to predict which rotation is applied (4-way classification)
- RotNet: Supervised vs Self-supervised
- The accuracy gap between the RotNet-based model and the fully supervised Network-In-Network (NIN) model is very small, only 1.64 percentage points
- We do not need data labels to train the RotNet-based model, yet it achieves accuracy similar to that of the model trained with labels
Import Library
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
Load MNIST Data
(X_train, Y_train), (X_test, Y_test) = tf.keras.datasets.mnist.load_data()
# 1,000 labeled images for the downstream task
XX_train = X_train[10000:11000]
YY_train = Y_train[10000:11000]

# 10,000 images used without their labels for the pretext task
X_train = X_train[:10000]
Y_train = Y_train[:10000]

# 300 labeled test images for evaluating the downstream and supervised models
XX_test = X_test[300:600]
YY_test = Y_test[300:600]
X_test = X_test[:300]
Y_test = Y_test[:300]
print('shape of x_train:', X_train.shape)
print('shape of y_train:', Y_train.shape)
print('shape of xx_train:', XX_train.shape)
print('shape of yy_train:', YY_train.shape)
print('shape of x_test:', X_test.shape)
print('shape of y_test:', Y_test.shape)
print('shape of xx_test:', XX_test.shape)
print('shape of yy_test:', YY_test.shape)
4.1. Build RotNet for Pretext Task¶
Dataset for Pretext Task (Rotation)
Need to generate rotated images and their labels to train the model for the pretext task
- [1, 0, 0, 0]: 0$^\circ $ rotation
- [0, 1, 0, 0]: 90$^\circ $ rotation
- [0, 0, 1, 0]: 180$^\circ $ rotation
- [0, 0, 0, 1]: 270$^\circ $ rotation
n_samples = X_train.shape[0]
X_rotate = np.zeros(shape = (n_samples*4,
X_train.shape[1],
X_train.shape[2]))
Y_rotate = np.zeros(shape = (n_samples*4, 4))
for i in range(n_samples):
    img = X_train[i]

    # 0 degrees (original image)
    X_rotate[4*i] = img
    Y_rotate[4*i] = tf.one_hot(0, depth = 4)

    # 90 degrees rotation
    X_rotate[4*i + 1] = np.rot90(img, k = 1)
    Y_rotate[4*i + 1] = tf.one_hot(1, depth = 4)

    # 180 degrees rotation
    X_rotate[4*i + 2] = np.rot90(img, k = 2)
    Y_rotate[4*i + 2] = tf.one_hot(2, depth = 4)

    # 270 degrees rotation
    X_rotate[4*i + 3] = np.rot90(img, k = 3)
    Y_rotate[4*i + 3] = tf.one_hot(3, depth = 4)
Plot Dataset for Pretext Task (Rotation)
plt.figure(figsize = (10, 10))
plt.subplot(141)
plt.imshow(X_rotate[12], cmap = 'gray')
plt.axis('off')
plt.subplot(142)
plt.imshow(X_rotate[13], cmap = 'gray')
plt.axis('off')
plt.subplot(143)
plt.imshow(X_rotate[14], cmap = 'gray')
plt.axis('off')
plt.subplot(144)
plt.imshow(X_rotate[15], cmap = 'gray')
plt.axis('off')
X_rotate = X_rotate.reshape(-1,28,28,1)
Build Model for Pretext Task (Rotation)
model_pretext = tf.keras.models.Sequential([
tf.keras.layers.Conv2D(filters = 64,
kernel_size = (3,3),
strides = (2,2),
activation = 'relu',
padding = 'SAME',
input_shape = (28, 28, 1)),
tf.keras.layers.MaxPool2D(pool_size = (2, 2),
strides = (2, 2)),
tf.keras.layers.Conv2D(filters = 32,
kernel_size = (3,3),
strides = (1,1),
activation = 'relu',
padding = 'SAME',
input_shape = (7, 7, 64)),
tf.keras.layers.MaxPool2D(pool_size = (2, 2),
strides = (2, 2)),
tf.keras.layers.Conv2D(filters = 16,
kernel_size = (3,3),
strides = (2,2),
activation = 'relu',
padding = 'SAME',
input_shape = (3, 3, 32)),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(units = 4, activation = 'softmax')
])
model_pretext.summary()
- Training the model for the pretext task
model_pretext.compile(optimizer = 'adam',
loss = 'categorical_crossentropy',
metrics = 'accuracy')
model_pretext.fit(X_rotate,
Y_rotate,
batch_size = 192,
epochs = 50,
verbose = 0,
shuffle = False)
4.2. Build Downstream Task (MNIST Image Classification)¶
- Freezing the trained parameters to transfer them to the downstream task
model_pretext.trainable = False
Reshape Dataset
XX_train = XX_train.reshape(-1,28,28,1)
XX_test = XX_test.reshape(-1,28,28,1)
YY_train = tf.one_hot(YY_train, 10, on_value = 1.0, off_value = 0.0)
YY_test = tf.one_hot(YY_test, 10, on_value = 1.0, off_value = 0.0)
Build Model
- Model: two convolution layers and one fully connected layer
- Two convolution layers are transferred from the model for the pretext task
- Only the single fully connected layer is trained
model_downstream = tf.keras.models.Sequential([
model_pretext.get_layer(index = 0),
model_pretext.get_layer(index = 1),
model_pretext.get_layer(index = 2),
model_pretext.get_layer(index = 3),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(units = 10, activation = 'softmax')
])
model_downstream.summary()
model_downstream.compile(optimizer = tf.keras.optimizers.SGD(learning_rate = 0.001,momentum = 0.9),
loss = 'categorical_crossentropy',
metrics = 'accuracy')
model_downstream.fit(XX_train,
YY_train,
batch_size = 64,
validation_split = 0.2,
epochs = 50,
verbose = 0,
callbacks = [tf.keras.callbacks.EarlyStopping(monitor = 'accuracy', patience = 7)])
Downstream Task Trained Result (Image Classification Result)
name = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
idx = 9
img = XX_train[idx].reshape(-1,28,28,1)
label = YY_train[idx]
predict = model_downstream.predict(img)
mypred = np.argmax(predict, axis = 1)
plt.figure(figsize = (8, 4))
plt.subplot(1,2,1)
plt.imshow(img.reshape(28, 28), 'gray')
plt.axis('off')
plt.subplot(1,2,2)
plt.stem(predict[0])
plt.show()
print('Prediction : {}'.format(name[mypred[0]]))
4.3. Build Supervised Model for Comparison¶
- Convolutional neural network for MNIST image classification
- Model: same architecture as the model for the downstream task
- The total number of parameters is the same as in the downstream model, but it has zero non-trainable parameters (all weights are trained from scratch)
model_sup = tf.keras.models.Sequential([
tf.keras.layers.Conv2D(filters = 64,
kernel_size = (3,3),
strides = (2,2),
activation = 'relu',
padding = 'SAME',
input_shape = (28, 28, 1)),
tf.keras.layers.MaxPool2D(pool_size = (2, 2),
strides = (2, 2)),
tf.keras.layers.Conv2D(filters = 32,
kernel_size = (3,3),
strides = (1,1),
activation = 'relu',
padding = 'SAME',
input_shape = (7, 7, 64)),
tf.keras.layers.MaxPool2D(pool_size = (2, 2),
strides = (2, 2)),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(units = 10, activation = 'softmax')
])
model_sup.summary()
model_sup.compile(optimizer = tf.keras.optimizers.SGD(learning_rate = 0.001, momentum = 0.9),
loss = 'categorical_crossentropy',
metrics = 'accuracy')
model_sup.fit(XX_train,
YY_train,
batch_size = 32,
validation_split = 0.2,
epochs = 50,
verbose = 0)
Compare Self-supervised Learning and Supervised Learning
Pretext Task
- Input data: 10,000 MNIST images without labels
Downstream Task and Supervised Learning (for performance comparison)
- Training data: 1,000 MNIST images with labels
- Test data: 300 MNIST images with labels
Key concepts
- For transfer learning, networks such as VGG-16 have traditionally been pretrained on large labeled image datasets such as ImageNet
- With self-supervised learning, we can pretrain such networks on unlabeled image datasets, which are far larger than labeled datasets, and then perform transfer learning
- Comparing the downstream task performance with that of supervised learning therefore amounts to comparing (self-supervised) transfer learning against purely supervised learning
test_self = model_downstream.evaluate(XX_test, YY_test, batch_size = 64, verbose = 2)
print("")
print('Self-supervised Learning Accuracy on Test Data: {:.2f}%'.format(test_self[1]*100))
test_sup = model_sup.evaluate(XX_test, YY_test, batch_size = 64, verbose = 2)
print("")
print('Supervised Learning Accuracy on Test Data: {:.2f}%'.format(test_sup[1]*100))