Contrastive Learning





1. Introduction

The success of deep learning has historically been driven by large-scale labeled datasets. Models trained on ImageNet, for instance, learn rich visual representations that transfer broadly across tasks. However, the process of collecting and annotating such datasets is expensive, time-consuming, and often infeasible in specialized domains such as medical imaging, satellite remote sensing, or scientific data analysis.

This raises a fundamental question: can a model learn useful representations from data alone, without any human-provided labels?


Contrastive learning is one of the most effective answers to this question. Rather than relying on annotated labels, contrastive learning defines a self-supervised pretext task directly from the structure of the data itself. The core principle is intuitive: two views of the same input should be represented similarly, while representations of different inputs should be kept apart. By optimizing this objective over large collections of unlabeled data, an encoder learns to organize the representation space according to semantic similarity — without ever being told what the semantic categories are.


This approach has proven remarkably effective. Modern contrastive learning methods such as SimCLR, MoCo, and CLIP have demonstrated that self-supervised representations can match or even surpass fully supervised baselines on a wide range of downstream tasks, including image classification, object detection, and cross-modal retrieval. In the case of CLIP, contrastive pretraining on web-scale image-text pairs yields representations that generalize to entirely new tasks in a zero-shot manner — without any task-specific fine-tuning.


The appeal of contrastive learning extends beyond its empirical performance. It offers a principled framework for thinking about representation learning: what invariances should a good representation encode? What information should be preserved, and what discarded? These questions connect contrastive learning to information theory, geometry, and the broader study of inductive biases in neural networks.


This course covers the theoretical foundations and practical implementation of contrastive learning. We begin with the core concepts and loss functions that underpin the framework, proceed through the landmark methods that have shaped the field, and conclude with advanced topics including hard negative mining, negative-free methods, and vision-language pretraining. Throughout, emphasis is placed on both rigorous understanding and hands-on implementation, with all experiments conducted in Google Colab.


By the end of this course, students will have a thorough understanding of why contrastive learning works, how its key components interact, and how to apply and extend these methods in their own research.



Objective Function

The core idea is to train an encoder that maps inputs into a representation space where:

  • Positive pairs (e.g., different views or augmentations of the same input) are mapped close together, and
  • Negative pairs (e.g., representations of distinct inputs) are mapped far apart.

Through this process, the encoder learns to capture semantic similarities and differences without relying on explicit labels. Contrastive learning serves as a foundational principle in many modern representation learning frameworks, including SimCLR, MoCo, and CLIP.


The figure below illustrates the contrastive learning objective in a feature space.

  • The anchor image (a raccoon, top right) is passed through the embedding network $\theta$ to produce a gray embedding point at the center of the feature space. Two other inputs are simultaneously embedded: a positive example (a raccoon of a different appearance, bottom right) and a negative example (an echidna, left), both processed by the same network $\theta$.

  • The positive embedding $\theta(I^+)$, shown in green, is placed close to the anchor, with a small distance $d^+$.

  • The negative embedding $\theta(I^-)$, shown in red, is placed far from the anchor, with a large distance $d^-$.

  • The training objective drives the model to simultaneously minimize $d^+$ — pulling semantically similar inputs together — and maximize $d^-$ — pushing semantically dissimilar inputs apart.


Note that all three inputs share the same embedding network $\theta$; there are no separate networks for anchor, positive, or negative. The network is trained end-to-end so that the geometry of the resulting feature space reflects semantic relationships among inputs.





A common loss function used in contrastive learning is the contrastive loss:


$$ \mathcal{L} = \mathbb{I}[y = 1] \cdot d^2 + \mathbb{I}[y = 0] \cdot \max(0, m - d)^2 $$


where:

  • $y = 1$ for a positive pair, and $y = 0$ for a negative pair,

  • $d = \lVert \theta(x_i) - \theta(x_j) \rVert_2$ is the Euclidean distance between two embeddings,

  • $m$ is a margin hyperparameter that negative pairs are required to exceed.


The first term penalizes large distances between positive pairs, encouraging the encoder to map semantically similar inputs close together. The second term penalizes negative pairs whose distance falls below the margin $m$, pushing dissimilar inputs apart by at least this threshold. When the distance between a negative pair already exceeds $m$, the loss contribution is zero and no gradient is applied.

This contrastive loss is one of several objective functions used in the field. More expressive alternatives, including the triplet loss and the InfoNCE loss, will be examined in subsequent sections.


1.1. Positive and Negative Samples in Contrastive Learning

In contrastive learning, the training process depends heavily on how positive and negative sample pairs are constructed.


1.1.1. Instance Discrimination Method



Unsupervised Contrastive Learning

(1) Positive Samples

A positive sample is another view or augmentation of the same underlying instance (anchor). It is assumed to have the same semantic content.


For example, if the anchor is an image of a dog, any of the following can serve as a positive sample:

  • A mirrored version
  • A grayscale version
  • Any other strongly augmented version (e.g., cropped or color-jittered)

Common augmentations used to create positive samples include:

  • Color jitter
  • Rotation
  • Horizontal/vertical flipping
  • Gaussian noise
  • Random affine transformations

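These augmentations are chosen to preserve the class identity while forcing the network to learn invariance to superficial changes. The right choice is domain-dependent: for handwritten digits, for instance, large rotations or vertical flips can turn a 6 into a 9, which would violate the assumption that a positive shares the anchor's semantic content.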
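The idea of generating a positive pair from two independent augmentations can be sketched in NumPy as follows. This is a minimal illustration, not a reference pipeline: the helper names `augment` and `make_positive_pair` are hypothetical, brightness scaling stands in for color jitter on a grayscale image, and the horizontal flip is included only for illustration (whether it preserves class identity depends on the domain).

```python
import numpy as np

def augment(image, rng):
    """Return one randomly augmented view of a 2-D grayscale image in [0, 1]."""
    view = image.copy()
    # Random horizontal flip with probability 0.5
    if rng.random() < 0.5:
        view = view[:, ::-1]
    # Brightness jitter (a grayscale stand-in for color jitter): scale by a factor in [0.8, 1.2]
    view = view * rng.uniform(0.8, 1.2)
    # Additive Gaussian noise
    view = view + rng.normal(0.0, 0.05, size=view.shape)
    # Keep pixel values in the valid range
    return np.clip(view, 0.0, 1.0)

def make_positive_pair(image, rng):
    """Two independent augmentations of the same image form a positive pair."""
    return augment(image, rng), augment(image, rng)

rng = np.random.default_rng(0)
image = rng.random((28, 28))
view_1, view_2 = make_positive_pair(image, rng)
print(view_1.shape, view_2.shape)
```

Because each call draws fresh random parameters, the two views differ in appearance while coming from the same underlying image.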



(2) Negative Samples

A negative sample is any image that is semantically different from the anchor.

  • In unsupervised settings, negative samples are typically drawn randomly from the rest of the dataset.
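  • There is no guarantee that they belong to a different class, but the assumption holds for most pairs when the dataset is large and spans many classes: with $K$ roughly balanced classes, a randomly drawn negative shares the anchor's class with probability of about $1/K$.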
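How often random negatives are actually "false negatives" can be estimated directly from a label vector. The sketch below assumes a hypothetical, perfectly balanced MNIST-like label set (10 classes with 100 samples each) and counts the fraction of distinct (anchor, negative) pairs that share a class; it illustrates the roughly $1/K$ rate rather than measuring any real dataset.

```python
import numpy as np

# Hypothetical balanced label set: 10 classes with 100 samples each (MNIST-like)
labels = np.repeat(np.arange(10), 100)

# Compare every anchor with every other sample (ordered pairs with i != j)
same = labels[:, None] == labels[None, :]
np.fill_diagonal(same, False)

n = len(labels)
false_negative_rate = same.sum() / (n * (n - 1))
print(round(false_negative_rate, 4))
```

The rate is slightly below $1/K = 0.1$ because an anchor cannot be paired with itself; with only ten classes, about one random negative in ten is a false negative.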



Supervised Contrastive Learning

When class labels are available, we can improve contrastive learning by using label information to define more meaningful positives.

  • In this setting, positive samples include all examples from the same class as the anchor, not just augmentations.
  • Negative samples are drawn from other classes.

This approach encourages embeddings of all instances with the same label to cluster tightly in the representation space, improving performance especially on downstream classification tasks.
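With labels available, the positive and negative sets for every anchor in a batch can be encoded as boolean masks. The following is a minimal NumPy sketch under the assumption that each row of the batch carries a class label; the function name `supervised_masks` is hypothetical. In practice each image usually appears as two augmented views, so every label would occur at least twice in the batch.

```python
import numpy as np

def supervised_masks(labels):
    """Return (positive_mask, negative_mask) for a batch of class labels.

    positive_mask[i, j] is True when j has the same label as anchor i and j != i.
    negative_mask[i, j] is True when j has a different label from anchor i.
    """
    labels = np.asarray(labels)
    same = labels[:, None] == labels[None, :]
    positive_mask = same.copy()
    np.fill_diagonal(positive_mask, False)  # an anchor is never its own positive
    negative_mask = ~same
    return positive_mask, negative_mask

labels = np.array([0, 1, 0, 2, 1])
positive_mask, negative_mask = supervised_masks(labels)
print(positive_mask.astype(int))
```

Anchor 3 (the only sample of class 2) has no positives in this batch, which shows why supervised contrastive training benefits from batches that contain several samples per class.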



1.1.2. Image Subsampling/Patching Method

Another approach to constructing positive and negative pairs is to use image patches instead of entire images.

  • Positive pairs are formed by extracting different patches from the same image. These patches may capture different regions but still share the same semantic identity.
  • Negative pairs are formed by pairing patches from different images.
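The patch-based construction can be sketched as follows. This is an illustrative NumPy version, assuming square grayscale images and square patches; the helpers `random_patch` and `patch_pairs`, the patch size, and the choice of "the next image in the batch" as the source of negatives are all hypothetical simplifications.

```python
import numpy as np

def random_patch(image, size, rng):
    """Crop a random size x size patch from a 2-D image."""
    h, w = image.shape
    top = rng.integers(0, h - size + 1)
    left = rng.integers(0, w - size + 1)
    return image[top:top + size, left:left + size]

def patch_pairs(images, size, rng):
    """Build one positive and one negative patch pair per image.

    Positive: two patches cropped from the same image.
    Negative: a patch from this image paired with a patch from a different image.
    """
    positives, negatives = [], []
    n = len(images)
    for i in range(n):
        positives.append((random_patch(images[i], size, rng),
                          random_patch(images[i], size, rng)))
        other = images[(i + 1) % n]  # any different image works; the next one keeps it simple
        negatives.append((random_patch(images[i], size, rng),
                          random_patch(other, size, rng)))
    return positives, negatives

rng = np.random.default_rng(0)
images = rng.random((4, 28, 28))
positives, negatives = patch_pairs(images, size=14, rng=rng)
print(len(positives), len(negatives), positives[0][0].shape)
```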



1.2. Objectives in Contrastive Learning


The objective function is the central component of any contrastive learning framework. It determines how the encoder is trained to organize the representation space, and different objectives encode different assumptions about what makes a good representation. In this section, we examine the three most influential loss functions in contrastive learning: the contrastive loss, the triplet loss, and the InfoNCE loss.


1.2.1. Contrastive Loss

The contrastive loss (Hadsell et al., 2006) is the earliest and most direct formulation of the contrastive learning objective. It operates on pairs of inputs and applies one of two penalties depending on whether the pair is positive or negative:


$$ \mathcal{L}_{\text{contrastive}} = \mathbb{I}[y=1] \cdot d^2 + \mathbb{I}[y=0] \cdot \max(0,\, m - d)^2 $$


where $d = \lVert \theta(x_i) - \theta(x_j) \rVert_2$ is the Euclidean distance between two embeddings, $y = 1$ for a positive pair and $y = 0$ for a negative pair, and $m > 0$ is a margin hyperparameter.

Intuition. Think of the embedding space as a physical room where each input occupies a position. The contrastive loss acts like a set of springs and repulsors. For a positive pair, a spring pulls the two points toward each other, and the farther apart they are, the stronger the pull. For a negative pair, a repulsor pushes the two points apart, but only if they are closer than the margin $m$. Once a negative pair is separated by at least $m$, the repulsor deactivates and contributes no gradient. The margin therefore defines a "safe zone": a sufficiently separated negative pair is no longer penalized.

The margin $m$ is a critical hyperparameter. If $m$ is too small, the model learns only weak separation between negatives. If $m$ is too large, the loss becomes dominated by easy negatives that are far apart but still within the margin, providing uninformative gradients.
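The pairwise loss can be sketched in NumPy as a direct transcription of the formula above. This is a minimal illustration, assuming embeddings are already computed and stored row-wise; the function name `contrastive_loss`, the default `margin=1.0`, and averaging over the batch are illustrative choices.

```python
import numpy as np

def contrastive_loss(emb_i, emb_j, y, margin=1.0):
    """Pairwise contrastive loss; y = 1 for positive pairs, y = 0 for negative pairs.

    emb_i, emb_j: arrays of shape (batch, dim); y: array of shape (batch,).
    Returns the mean loss over the batch.
    """
    d = np.linalg.norm(emb_i - emb_j, axis=1)                  # Euclidean distance per pair
    positive_term = y * d ** 2                                  # pull positive pairs together
    negative_term = (1 - y) * np.maximum(0.0, margin - d) ** 2  # push negatives beyond the margin
    return np.mean(positive_term + negative_term)

emb_a = np.array([[0.0, 0.0]])
close = np.array([[0.3, 0.4]])   # distance 0.5 from emb_a
far = np.array([[1.2, 1.6]])     # distance 2.0 from emb_a

print(contrastive_loss(emb_a, close, np.array([1])))  # positive pair at d = 0.5
print(contrastive_loss(emb_a, close, np.array([0])))  # negative pair inside the margin
print(contrastive_loss(emb_a, far, np.array([0])))    # negative pair beyond the margin
```

The last case shows the "safe zone": a negative pair already farther apart than $m$ contributes zero loss and therefore no gradient.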

Limitation. The contrastive loss considers only one pair at a time. This means the model has no sense of the global structure of the embedding space — it only knows whether two specific points should be close or far, without any reference to where all other points are. This makes it difficult to learn representations that are uniformly distributed and well-organized across the entire space.


1.2.2. Triplet Loss

The triplet loss (Schroff et al., 2015) addresses the pairwise limitation by considering triplets of inputs: an anchor $x$, a positive $x^+$, and a negative $x^-$. Rather than enforcing absolute distances, it enforces a relative constraint: the positive must be closer to the anchor than the negative, by at least a margin $m$.


$$ \mathcal{L}_{\text{triplet}} = \max\!\left(0,\; \lVert \theta(x) - \theta(x^+) \rVert_2^2 - \lVert \theta(x) - \theta(x^-) \rVert_2^2 + m\right) $$


Intuition. The triplet loss introduces a sense of competition between the positive and negative. Rather than asking "is this positive pair close enough?" and "is this negative pair far enough?" as two separate questions, the triplet loss asks a single comparative question: "is the positive closer to the anchor than the negative?" This is a more natural formulation of the learning objective, since what matters in practice is not the absolute position of any embedding, but its position relative to others.

Geometrically, for each anchor, imagine drawing a sphere around it whose radius equals the anchor-to-positive distance. Because the loss compares squared distances, it is satisfied, and contributes no gradient, only when $\lVert \theta(x) - \theta(x^-) \rVert_2^2 \geq \lVert \theta(x) - \theta(x^+) \rVert_2^2 + m$, that is, when the negative lies outside this sphere with an additional buffer of $m$ in squared distance. Otherwise, the loss simultaneously pushes the positive closer and the negative farther until the constraint is satisfied.
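A batched version of the triplet loss follows directly from the formula. The sketch below assumes anchors, positives, and negatives are aligned row by row; `triplet_loss` is a hypothetical helper name and `margin=0.2` is only an illustrative default.

```python
import numpy as np

def triplet_loss(anchor, positive, negative, margin=0.2):
    """Triplet loss with squared Euclidean distances, averaged over the batch."""
    d_pos = np.sum((anchor - positive) ** 2, axis=1)  # squared anchor-positive distance
    d_neg = np.sum((anchor - negative) ** 2, axis=1)  # squared anchor-negative distance
    return np.mean(np.maximum(0.0, d_pos - d_neg + margin))

anchor = np.array([[0.0, 0.0], [0.0, 0.0]])
positive = np.array([[0.1, 0.0], [0.1, 0.0]])
negative = np.array([[1.0, 0.0],    # easy negative: constraint already satisfied
                     [0.2, 0.0]])   # negative inside the margin buffer
print(triplet_loss(anchor, positive, negative))
```

Only the second triplet contributes: its negative is farther than the positive, but not by the required margin in squared distance.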

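Limitation. Like the contrastive loss, the triplet loss considers only a small local neighborhood: one positive and one negative per anchor. It provides no information about the broader structure of the embedding space. Furthermore, because most randomly sampled triplets quickly satisfy the margin and contribute zero loss, effective training usually relies on hard negative mining (Schroff et al. select semi-hard negatives, which are farther from the anchor than the positive but still within the margin). Mining introduces additional complexity and can destabilize training if the mined negatives are too hard, for example when they are actually false negatives.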


1.2.3. InfoNCE Loss

The InfoNCE loss (van den Oord et al., 2018) represents a fundamental shift in how the contrastive objective is formulated. Rather than considering pairs or triplets, it frames the problem as an $N$-way classification task: given an anchor, identify its positive among $N$ candidates, one positive and $N-1$ negatives. This allows the model to consider the structure of the entire batch simultaneously.


$$ \mathcal{L}_{\text{InfoNCE}} = -\frac{1}{N} \sum_{i=1}^{N} \log \frac{\exp(\hat{z}_i \cdot \hat{z}_i^{+} / \tau)}{\sum_{k=1}^{N} \exp(\hat{z}_i \cdot \hat{z}_k^{+} / \tau)} $$


where $\hat{z} = z / \lVert z \rVert_2$ denotes the $\ell_2$-normalized embedding, $\tau > 0$ is a temperature hyperparameter, $\hat{z}_i$ is the embedding of anchor $i$, and $\hat{z}_i^{+}$ is the embedding of its positive. For anchor $i$, the positives $\hat{z}_k^{+}$ of the other $N-1$ anchors ($k \neq i$) serve as its negatives.

Intuition. The InfoNCE loss can be understood as a softmax classifier over similarities. For each anchor, the model computes a similarity score between the anchor and every other embedding in the batch. These scores are passed through a softmax to produce a probability distribution over all candidates. The loss then maximizes the probability assigned to the true positive — exactly as in standard cross-entropy classification, where the "classes" are the candidate embeddings and the "label" is the index of the positive.

This formulation has a natural probabilistic interpretation: the model is learning to answer the question "which of these embeddings came from the same input as the anchor?" With a good representation, the positive should receive a much higher similarity score than any negative, making the softmax distribution sharply peaked at the correct answer.

A key advantage over pairwise and triplet losses is that every negative in the batch contributes to every anchor's loss simultaneously. With a batch size of $N$, each anchor is contrasted against $N-1$ negatives in a single forward pass. This provides dense, informative gradient signal without requiring any explicit mining strategy.
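The softmax-classifier view translates almost line for line into code. The following NumPy sketch assumes the batched form in which row $i$ of `candidates` is the positive for anchor $i$ and the remaining $N-1$ rows serve as its negatives; `info_nce` is a hypothetical helper, not a library function.

```python
import numpy as np

def info_nce(anchors, candidates, temperature=0.1):
    """Batched InfoNCE: row i of `candidates` is the positive for anchor i,
    and the other N-1 rows act as its negatives."""
    # l2-normalize so that dot products are cosine similarities
    a = anchors / np.linalg.norm(anchors, axis=1, keepdims=True)
    c = candidates / np.linalg.norm(candidates, axis=1, keepdims=True)
    logits = a @ c.T / temperature                   # (N, N) scaled similarity matrix
    # Numerically stable log-softmax over each row
    logits = logits - logits.max(axis=1, keepdims=True)
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    # Cross-entropy whose "label" for anchor i is column i
    return -np.mean(np.diag(log_probs))

eye = np.eye(4)
print(info_nce(eye, eye))              # each positive is clearly distinguishable
print(info_nce(eye, np.ones((4, 4))))  # all candidates identical: loss equals log N
```

When the candidates carry no information about which one is the positive, the softmax is uniform and the loss equals $\log N$; well-separated positives drive it toward zero.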

Temperature $\tau$. The temperature controls how sharply the softmax distribution is peaked. A low temperature amplifies differences in similarity scores, concentrating the gradient on the hardest negatives. A high temperature smooths the distribution, treating all negatives more equally. In practice, $\tau$ is typically set in the range $[0.07, 0.5]$ and treated as a tunable hyperparameter.
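The effect of $\tau$ on how the gradient is shared among negatives can be checked numerically. For InfoNCE, the gradient with respect to the similarity of negative $k$ is proportional to $\exp(s_k/\tau)$, so a softmax over the negatives' scaled similarities gives their relative weights. The similarity values below are hypothetical.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

# Hypothetical cosine similarities between one anchor and four negatives
similarities = np.array([0.8, 0.5, 0.2, -0.1])

# Relative gradient weight of each negative at several temperatures
weights_by_tau = {tau: softmax(similarities / tau) for tau in (0.05, 0.5, 5.0)}
for tau, weights in weights_by_tau.items():
    print(f"tau={tau}: {np.round(weights, 3)}")
```

At a low temperature almost all of the weight falls on the most similar (hardest) negative; at a high temperature the weights become nearly uniform.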

Connection to mutual information. The name InfoNCE reflects a deeper theoretical connection: minimizing the loss maximizes a lower bound on the mutual information $I(z; z^+)$ between the representations of the two views:


$$ I(z;\, z^+) \geq \log N - \mathcal{L}_{\text{InfoNCE}} $$


Since $\mathcal{L}_{\text{InfoNCE}} \geq 0$, the right-hand side can never exceed $\log N$, no matter how much information the two views actually share. Larger batches raise this ceiling, which provides one theoretical justification for the large batch sizes used in contrastive training.
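The $\log N$ ceiling can be seen numerically by evaluating the bound for two extreme, hypothetical logit matrices: one where the positive is indistinguishable from the negatives, and one where the positive logit (here an assumed value of 20) dominates every negative. The helper `info_nce_from_logits` is a hypothetical name for this sketch.

```python
import numpy as np

def info_nce_from_logits(logits):
    """InfoNCE for an (N, N) logit matrix whose diagonal holds the positive pairs."""
    shifted = logits - logits.max(axis=1, keepdims=True)   # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -np.mean(np.diag(log_probs))

bounds = {}
for n in (8, 64, 512):
    # Uninformative embeddings: every candidate gets the same logit
    loss_uninformative = info_nce_from_logits(np.zeros((n, n)))
    # Near-perfect embeddings: positive logit far above every negative logit
    loss_perfect = info_nce_from_logits(20.0 * np.eye(n))
    bounds[n] = (np.log(n) - loss_uninformative, np.log(n) - loss_perfect)
    print(f"N={n}: bound={bounds[n][0]:.3f} (uninformative), "
          f"{bounds[n][1]:.3f} (near-perfect), log N={np.log(n):.3f}")
```

Even a near-perfect encoder cannot certify more than $\log N$ nats with a batch of $N$ candidates.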



1.3. General Scheme for Contrastive Learning

SimCLR (A Simple Framework for Contrastive Learning of Visual Representations) is a foundational contrastive learning method introduced by Chen et al. (2020). It demonstrated that carefully designed data augmentations, a suitable contrastive loss, and large batch sizes are sufficient to learn high-quality image representations in a self-supervised setting.



(1) Data Augmentation Strategy

The process begins by applying two independent random augmentations to the same input image. These transformations, denoted by $T$, typically include:

  • Random cropping and resizing
  • Color jitter
  • Gaussian blur
  • Horizontal flip

This results in two different views of the same image, $x_i$ and $x_j$, which are treated as a positive pair.


(2) Encoder Network



Each of the transformed images is passed through a shared base encoder network $f(\cdot)$ (typically a ResNet-50). The encoder maps each input view into a representation vector:


$$ h_i = f(x_i), \quad h_j = f(x_j) $$


These embeddings $h_i$ and $h_j$ are high-dimensional and are intended to capture rich semantic features. However, they are not used directly for contrastive loss computation.


(3) Projection Head



The output representations $h_i$ and $h_j$ are further processed by a projection head $g(\cdot)$, typically a 2-layer MLP with ReLU activation in between:


$$ z_i = g(h_i), \quad z_j = g(h_j) $$


The projection head maps the representations into a space where the contrastive loss is applied. Empirically, the authors showed that contrastive learning benefits from applying the loss in this transformed space, and discarding the projection head after pretraining improves downstream performance.

The model is trained to maximize similarity between $z_i$ and $z_j$, while minimizing similarity between $z_i$ and all other negatives in the batch.
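Where the projection head sits in the pipeline can be illustrated with a toy forward pass. In this NumPy sketch the layer sizes (2048 → 512 → 128), the random weights, and the random stand-in for encoder outputs are all assumptions chosen for illustration; only the structure (two linear layers with a ReLU in between, applied to $h$) follows the description above.

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative sizes (assumptions): 2048-d encoder output h, 512 hidden units, 128-d projection z
W1 = rng.normal(0.0, 0.02, size=(2048, 512))
b1 = np.zeros(512)
W2 = rng.normal(0.0, 0.02, size=(512, 128))
b2 = np.zeros(128)

def projection_head(h):
    """g(h) = W2 ReLU(W1 h + b1) + b2: a 2-layer MLP applied to the encoder output."""
    hidden = np.maximum(0.0, h @ W1 + b1)   # ReLU between the two layers
    return hidden @ W2 + b2

h = rng.normal(size=(8, 2048))   # stand-in for encoder outputs f(x) of a batch of 8 views
z = projection_head(h)           # the contrastive loss is computed on z,
print(h.shape, z.shape)          # while h is what is kept for downstream tasks
```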


Why Does the Projection Head Help?

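The projection head improves downstream performance by a large margin: Chen et al. report that a nonlinear projection head raises ImageNet linear-evaluation top-1 accuracy by more than 10% compared with using no projection head. The explanation they offer is as follows.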

The contrastive loss enforces invariance to the augmentations applied to create positive pairs. However, some of the information discarded to achieve this invariance — such as exact color, orientation, and crop position — may actually be useful for certain downstream tasks.

By applying the contrastive loss to $z = g(h)$ rather than directly to $h$, the network can concentrate the augmentation-invariance pressure on the projection head. The representation $h$ is then free to retain more information about the input, some of which will be useful downstream.

This is sometimes described as: the projection head acts as a bottleneck that absorbs the invariances, protecting the encoder representation from unnecessary information loss.


(4) Contrastive Objective: NT-Xent Loss

The Normalized Temperature-scaled Cross Entropy Loss (NT-Xent) is applied to each positive pair, with all other views in the batch treated as negatives.

Given a batch of $N$ images, resulting in $2N$ augmented views, the loss for a positive pair $(z_i, z_j)$ is:


$$ \mathcal{L}_i = -\log \frac{\exp(\text{sim}(z_i, z_j)/\tau)}{\sum_{k=1}^{2N} \mathbb{I}_{[k \neq i]} \exp(\text{sim}(z_i, z_k)/\tau)} $$


where $\text{sim}(u, v) = u^\top v / (\lVert u \rVert_2 \lVert v \rVert_2)$ is cosine similarity and $\tau$ is a temperature parameter. Each anchor is therefore contrasted against $2(N-1)$ negatives. The total loss averages $\mathcal{L}_i$ over all $2N$ anchors, so every positive pair is counted in both directions, $(i, j)$ and $(j, i)$.
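A deliberately unvectorized NumPy transcription of the formula can serve as a reference when checking a vectorized implementation such as the `simclr_loss` function defined later in this notebook. The helper name `nt_xent` and the default `temperature=0.5` are illustrative choices; the code assumes `z_i[k]` and `z_j[k]` are the two views of image `k`.

```python
import numpy as np

def nt_xent(z_i, z_j, temperature=0.5):
    """Loop-based NT-Xent over 2N views; z_i[k] and z_j[k] form a positive pair."""
    z = np.concatenate([z_i, z_j], axis=0)
    z = z / np.linalg.norm(z, axis=1, keepdims=True)   # cosine similarity via dot products
    sim = z @ z.T
    n = z_i.shape[0]
    losses = []
    for a in range(2 * n):
        p = a + n if a < n else a - n                   # index of the positive view
        # Denominator: every view except the anchor itself (the k != i indicator)
        denom = sum(np.exp(sim[a, k] / temperature) for k in range(2 * n) if k != a)
        losses.append(-np.log(np.exp(sim[a, p] / temperature) / denom))
    return np.mean(losses)   # average over all 2N anchors: both (i, j) and (j, i)

aligned = nt_xent(np.eye(3), np.eye(3))
misaligned = nt_xent(np.eye(3), np.roll(np.eye(3), 1, axis=0))
print(aligned, misaligned)
```

When each positive pair is identical and orthogonal to the other views, every anchor's loss reduces to $\log(1 + 4e^{-2})$ for $N = 3$ and $\tau = 0.5$; pairing each view with the wrong partner increases the loss.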


(5) Downstream Usage



Once pretraining is complete, the projection head $g(\cdot)$ is discarded. The base encoder $f(\cdot)$ is retained and used to generate representations for downstream tasks such as image classification, object detection, or segmentation.

In the downstream phase, either:

  • The encoder is frozen, and a linear classifier is trained on top (linear probing), or
  • The encoder is fine-tuned along with the task-specific head.

SimCLR illustrates that even without labels, strong representation learning is possible through contrastive objectives, heavy data augmentation, and careful architectural design.



1.4. Summary

Contrastive learning demonstrates that strong representation learning is achievable even in the absence of labels. This is made possible by the combination of:

  • Carefully designed contrastive objectives (e.g., NT-Xent loss),

  • Heavy data augmentation to generate diverse yet semantically consistent views,

  • And a thoughtfully constructed architecture, including a projection head and encoder backbone.


By optimizing contrastive loss functions, the model learns to structure the embedding space in a meaningful way:

  • Positive pairs (views of the same instance, or samples of the same class in the supervised variant) are pulled closer together,

  • Negative pairs (different instances, or samples of different classes in the supervised variant) are pushed farther apart.


Because semantically similar inputs end up near each other and dissimilar inputs end up far apart, the resulting embeddings tend to form tight clusters within each class with clear separation between classes. Such embeddings are highly discriminative and well suited to downstream tasks such as classification, detection, or segmentation.




2. Contrastive Learning with TensorFlow



Import Library

In [ ]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

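Load MNIST Data

The first 10,000 training images (X_train) are used without their labels for the pretext task. The next 1,000 training images and their labels (XX_train, YY_train) are reserved for the downstream classifier, and the first 600 test images are split into two 300-image subsets (X_test/Y_test and XX_test/YY_test).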

In [ ]:
(X_train_all, Y_train_all), (X_test_all, Y_test_all) = tf.keras.datasets.mnist.load_data()
X_train = X_train_all[:10000]
XX_train = X_train_all[10000:11000]
YY_train = Y_train_all[10000:11000]
X_test = X_test_all[:300]
Y_test = Y_test_all[:300]
XX_test = X_test_all[300:600]
YY_test = Y_test_all[300:600]
In [ ]:
print('shape of X_train   :', X_train.shape)
print('shape of XX_train  :', XX_train.shape)
print('shape of YY_train  :', YY_train.shape)
print('shape of X_test    :', X_test.shape)
print('shape of Y_test    :', Y_test.shape)
print('shape of XX_test   :', XX_test.shape)
print('shape of YY_test   :', YY_test.shape)
shape of X_train   : (10000, 28, 28)
shape of XX_train  : (1000, 28, 28)
shape of YY_train  : (1000,)
shape of X_test    : (300, 28, 28)
shape of Y_test    : (300,)
shape of XX_test   : (300, 28, 28)
shape of YY_test   : (300,)

Preprocess Data

In [ ]:
X_train = X_train.astype('float32') / 255.0
XX_train = XX_train.astype('float32') / 255.0
X_test = X_test.astype('float32') / 255.0
XX_test = XX_test.astype('float32') / 255.0

X_train = X_train.reshape(-1, 28, 28, 1)
XX_train = XX_train.reshape(-1, 28, 28, 1)
X_test = X_test.reshape(-1, 28, 28, 1)
XX_test = XX_test.reshape(-1, 28, 28, 1)

YY_train = tf.keras.utils.to_categorical(YY_train, 10)
YY_test = tf.keras.utils.to_categorical(YY_test, 10)

2.1. Build SimCLR for Pretext Task

Dataset for Pretext Task (SimCLR)

In SimCLR, two different augmented views are generated from a single image.

(1) random_cutout

  • Masks part of the image to encourage robust feature learning.

(2) random_augment

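  • Applies a random crop (pad to 32×32, then crop back to 28×28, i.e., a small random shift), a cutout with probability 0.5, and Gaussian noise to create an augmented view.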

(3) augment_views

  • Generates two different augmented views from one image.

In [ ]:
def random_cutout(image, mask_size=8, fill_value=0.0):
    # Replace a randomly placed mask_size x mask_size square with fill_value (single-channel images)
    height = tf.shape(image)[0]
    width = tf.shape(image)[1]

    # Top-left corner of the square, chosen so that the square fits inside the image
    top = tf.random.uniform([], 0, height - mask_size + 1, dtype=tf.int32)
    left = tf.random.uniform([], 0, width - mask_size + 1, dtype=tf.int32)

    # Rebuild the image from slices around the square (tensors cannot be assigned in place)
    upper = image[:top, :, :]
    middle_left = image[top:top + mask_size, :left, :]
    middle_mask = tf.ones((mask_size, mask_size, 1), dtype=image.dtype) * fill_value
    middle_right = image[top:top + mask_size, left + mask_size:, :]
    lower = image[top + mask_size:, :, :]

    middle = tf.concat([middle_left, middle_mask, middle_right], axis=1)
    image = tf.concat([upper, middle, lower], axis=0)

    return image


def random_augment(image):
    image = tf.convert_to_tensor(image, dtype=tf.float32)

    # Random shift: pad 28x28 to 32x32, then randomly crop back to 28x28
    image = tf.image.resize_with_crop_or_pad(image, 32, 32)
    image = tf.image.random_crop(image, size=[28, 28, 1])

    # Apply cutout with probability 0.5
    image = tf.cond(
        tf.random.uniform([]) > 0.5,
        lambda: random_cutout(image, mask_size=8, fill_value=0.0),
        lambda: image
    )

    # Add Gaussian noise and keep pixel values in [0, 1]
    noise = tf.random.normal(shape=tf.shape(image), mean=0.0, stddev=0.05)
    image = tf.clip_by_value(image + noise, 0.0, 1.0)

    return image
In [ ]:
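def augment_views(image):
    # Two independent calls to random_augment yield two different views of the same image
    return random_augment(image), random_augment(image)


n_samples = X_train.shape[0]

X_view_1 = np.zeros((n_samples, X_train.shape[1], X_train.shape[2], X_train.shape[3]), dtype=np.float32)
X_view_2 = np.zeros((n_samples, X_train.shape[1], X_train.shape[2], X_train.shape[3]), dtype=np.float32)

# Note: views are generated once here, so each image keeps the same pair of views
# in every epoch (SimCLR itself re-samples the augmentations at every training step)
for i in range(n_samples):
    view_1, view_2 = augment_views(X_train[i])
    X_view_1[i] = view_1.numpy()
    X_view_2[i] = view_2.numpy()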
In [ ]:
plt.figure(figsize=(9, 3))

plt.subplot(1, 3, 1)
plt.imshow(X_train[0, :, :, 0], cmap='gray')
plt.axis('off')
plt.title('Original')

plt.subplot(1, 3, 2)
plt.imshow(X_view_1[0, :, :, 0], cmap='gray')
plt.axis('off')
plt.title('View 1')

plt.subplot(1, 3, 3)
plt.imshow(X_view_2[0, :, :, 0], cmap='gray')
plt.axis('off')
plt.title('View 2')

plt.show()

Build Model for Pretext Task (SimCLR)

(1) encoder

  • Extracts a representation vector from the input image.

(2) projection_head

  • Transforms the representation vector into a latent vector for contrastive loss computation.


In [ ]:
encoder = tf.keras.models.Sequential([
    tf.keras.layers.Input(shape = (28, 28, 1)),

    tf.keras.layers.Conv2D(filters = 32,
                           kernel_size = (3, 3),
                           strides = (2, 2),
                           padding = 'same',
                           activation = 'relu'),
    tf.keras.layers.BatchNormalization(),

    tf.keras.layers.Conv2D(filters = 64,
                           kernel_size = (3, 3),
                           strides = (2, 2),
                           padding = 'same',
                           activation = 'relu'),
    tf.keras.layers.BatchNormalization(),

    tf.keras.layers.Conv2D(filters = 128,
                           kernel_size = (3, 3),
                           strides = (2, 2),
                           padding = 'same',
                           activation = 'relu'),
    tf.keras.layers.BatchNormalization(),

    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(units = 128, activation = 'relu')
], name = 'encoder')

encoder.summary()
Model: "encoder"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                     Output Shape                  Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ conv2d_9 (Conv2D)               │ (None, 14, 14, 32)     │           320 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ batch_normalization_9           │ (None, 14, 14, 32)     │           128 │
│ (BatchNormalization)            │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_10 (Conv2D)              │ (None, 7, 7, 64)       │        18,496 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ batch_normalization_10          │ (None, 7, 7, 64)       │           256 │
│ (BatchNormalization)            │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_11 (Conv2D)              │ (None, 4, 4, 128)      │        73,856 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ batch_normalization_11          │ (None, 4, 4, 128)      │           512 │
│ (BatchNormalization)            │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ flatten_3 (Flatten)             │ (None, 2048)           │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_11 (Dense)                │ (None, 128)            │       262,272 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
 Total params: 355,840 (1.36 MB)
 Trainable params: 355,392 (1.36 MB)
 Non-trainable params: 448 (1.75 KB)


In [ ]:
projection_head = tf.keras.models.Sequential([
    tf.keras.layers.Input(shape=(128,)),
    tf.keras.layers.Dense(units=128, activation='relu'),
    tf.keras.layers.Dense(units=64)
], name='projection_head')

projection_head.summary()
Model: "projection_head"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                     Output Shape                  Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ dense_12 (Dense)                │ (None, 128)            │        16,512 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_13 (Dense)                │ (None, 64)             │         8,256 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
 Total params: 24,768 (96.75 KB)
 Trainable params: 24,768 (96.75 KB)
 Non-trainable params: 0 (0.00 B)
In [ ]:
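def simclr_loss(z_i, z_j, temperature=0.1):
    # l2-normalize so that dot products are cosine similarities
    z_i = tf.math.l2_normalize(z_i, axis=1)
    z_j = tf.math.l2_normalize(z_j, axis=1)

    batch_size = tf.shape(z_i)[0]

    # Stack both views (2N x d) and compute all pairwise similarities scaled by the temperature
    z = tf.concat([z_i, z_j], axis=0)
    similarity_matrix = tf.matmul(z, z, transpose_b=True) / temperature

    # Exclude self-similarity (the k != i term) by pushing the diagonal to a large negative value
    large_negative = 1e9
    mask = tf.eye(2 * batch_size)
    logits = similarity_matrix - mask * large_negative

    # The positive of view k in the first half is view k + N in the second half, and vice versa
    labels = tf.concat([tf.range(batch_size, 2 * batch_size), tf.range(batch_size)], axis=0)

    # Row-wise cross-entropy, averaged over all 2N anchors, gives the NT-Xent loss
    loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits))

    return loss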
In [ ]:
class SimCLRModel(tf.keras.Model):
    def __init__(self, encoder, projection_head, temperature=0.1):
        super().__init__()
        self.encoder = encoder
        self.projection_head = projection_head
        self.temperature = temperature

    def train_step(self, data):
        x, _, _ = tf.keras.utils.unpack_x_y_sample_weight(data)
        view_1, view_2 = x

        with tf.GradientTape() as tape:
            h_i = self.encoder(view_1, training=True)
            h_j = self.encoder(view_2, training=True)

            z_i = self.projection_head(h_i, training=True)
            z_j = self.projection_head(h_j, training=True)

            contrastive_loss = simclr_loss(z_i, z_j, temperature=self.temperature)

        trainable_variables = self.encoder.trainable_variables + self.projection_head.trainable_variables
        gradients = tape.gradient(contrastive_loss, trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, trainable_variables))

        return {'contrastive_loss': contrastive_loss}

Training the model for the pretext task

In [ ]:
simclr_model = SimCLRModel(encoder, projection_head, temperature=0.1)
simclr_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3))
simclr_model.fit( x=(X_view_1, X_view_2), epochs=30, batch_size=128, verbose=1, shuffle=True)
Epoch 1/30
79/79 ━━━━━━━━━━━━━━━━━━━━ 12s 71ms/step - contrastive_loss: 0.4648
Epoch 2/30
79/79 ━━━━━━━━━━━━━━━━━━━━ 1s 8ms/step - contrastive_loss: 0.1532
Epoch 3/30
79/79 ━━━━━━━━━━━━━━━━━━━━ 1s 8ms/step - contrastive_loss: 0.0688
Epoch 4/30
79/79 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - contrastive_loss: 0.1336
Epoch 5/30
79/79 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - contrastive_loss: 0.0709
Epoch 6/30
79/79 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - contrastive_loss: 0.0577
Epoch 7/30
79/79 ━━━━━━━━━━━━━━━━━━━━ 1s 12ms/step - contrastive_loss: 0.0502
Epoch 8/30
79/79 ━━━━━━━━━━━━━━━━━━━━ 1s 12ms/step - contrastive_loss: 0.0442
Epoch 9/30
79/79 ━━━━━━━━━━━━━━━━━━━━ 1s 10ms/step - contrastive_loss: 0.3689
Epoch 10/30
79/79 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - contrastive_loss: 0.0170
Epoch 11/30
79/79 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - contrastive_loss: 0.0141
Epoch 12/30
79/79 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - contrastive_loss: 0.0443
Epoch 13/30
79/79 ━━━━━━━━━━━━━━━━━━━━ 1s 9ms/step - contrastive_loss: 0.0139
Epoch 14/30
79/79 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - contrastive_loss: 0.0186
Epoch 15/30
79/79 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - contrastive_loss: 0.0148
Epoch 16/30
79/79 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - contrastive_loss: 0.1474
Epoch 17/30
79/79 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - contrastive_loss: 0.0259
Epoch 18/30
79/79 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - contrastive_loss: 0.0146
Epoch 19/30
79/79 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - contrastive_loss: 0.0176
Epoch 20/30
79/79 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - contrastive_loss: 0.0091
Epoch 21/30
79/79 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - contrastive_loss: 0.0149
Epoch 22/30
79/79 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - contrastive_loss: 0.0110
Epoch 23/30
79/79 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - contrastive_loss: 0.0089
Epoch 24/30
79/79 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - contrastive_loss: 0.0518
Epoch 25/30
79/79 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - contrastive_loss: 0.0088
Epoch 26/30
79/79 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - contrastive_loss: 0.0099
Epoch 27/30
79/79 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - contrastive_loss: 0.0205
Epoch 28/30
79/79 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - contrastive_loss: 0.0092
Epoch 29/30
79/79 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - contrastive_loss: 0.0122
Epoch 30/30
79/79 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - contrastive_loss: 0.0093
Out[ ]:
<keras.src.callbacks.history.History at 0x7c23b01a9ee0>
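The `SimCLRModel` class that produces `contrastive_loss` is defined earlier in the notebook and is not reproduced here. For reference, here is a minimal NumPy sketch of the NT-Xent (normalized temperature-scaled cross-entropy) loss used by SimCLR. It assumes that this is what `SimCLRModel` computes; the notebook's own implementation may differ in details such as how the two directions are averaged. Row i of the first view and row i of the second view form the positive pair, and the remaining 2N - 2 embeddings in the batch act as negatives. `temperature=0.1` matches the value passed to `SimCLRModel` above.

```python
import numpy as np


def nt_xent_loss(z1: np.ndarray, z2: np.ndarray, temperature: float = 0.1) -> float:
    """SimCLR-style NT-Xent loss (illustrative sketch, not the notebook's code).

    z1, z2: (N, D) projections of two augmented views of the same N inputs.
    """
    n = z1.shape[0]
    z = np.concatenate([z1, z2], axis=0)                   # (2N, D)
    z = z / np.linalg.norm(z, axis=1, keepdims=True)       # L2-normalize, so dot product = cosine
    sim = z @ z.T / temperature                            # (2N, 2N) scaled similarities
    np.fill_diagonal(sim, -np.inf)                         # an embedding is never its own negative
    # positive partner of row i is row i + N, and vice versa
    pos = np.concatenate([np.arange(n, 2 * n), np.arange(0, n)])
    # numerically stable log-softmax over each row
    row_max = np.max(sim, axis=1, keepdims=True)
    log_prob = sim - row_max - np.log(np.sum(np.exp(sim - row_max), axis=1, keepdims=True))
    # cross-entropy with the positive partner as the target class
    return float(-np.mean(log_prob[np.arange(2 * n), pos]))


rng = np.random.default_rng(0)
z1 = rng.normal(size=(8, 16))
aligned = nt_xent_loss(z1, z1 + 0.01 * rng.normal(size=z1.shape))   # views agree
random_pairs = nt_xent_loss(z1, rng.normal(size=(8, 16)))            # views unrelated
print(aligned, random_pairs)
```

When the two views map to nearly the same direction, the loss approaches zero. When they are unrelated, it stays near log(2N - 1). This is consistent with the pretext loss settling close to 0.01 in the log above.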

2.2. Build Downstream Task (MNIST Image Classification)

Build Model for Downstream Task

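Freeze the pretrained encoder and train only the final Dense classification layer. Because the encoder weights are fixed, the downstream accuracy reflects how linearly separable the learned representations are (the standard linear-evaluation protocol).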


In [ ]:
encoder.trainable = False

model_downstream = tf.keras.models.Sequential([
    tf.keras.layers.Input(shape=(28, 28, 1)),
    encoder,
    tf.keras.layers.Dense(units=10, activation='softmax')
], name = 'downstream')

model_downstream.summary()
Model: "downstream"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                    ┃ Output Shape           ┃       Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ encoder (Sequential)            │ (None, 128)            │       355,840 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_14 (Dense)                │ (None, 10)             │         1,290 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
 Total params: 357,130 (1.36 MB)
 Trainable params: 1,290 (5.04 KB)
 Non-trainable params: 355,840 (1.36 MB)
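The summary counts can be reproduced by hand. The sketch below assumes that the pretrained encoder is the three-layer strided CNN used in this notebook: 3x3 stride-2 Conv2D layers with 32, 64 and 128 filters, each followed by BatchNormalization, then Flatten and Dense(128). This is the same layout as `supervised_encoder` in Section 2.3, and the matching total of 355,840 supports the assumption. The helper functions are hypothetical and exist only for this arithmetic.

```python
def conv2d_params(kernel_h: int, kernel_w: int, in_ch: int, out_ch: int) -> int:
    # one kernel_h x kernel_w x in_ch filter per output channel, plus one bias each
    return kernel_h * kernel_w * in_ch * out_ch + out_ch


def batchnorm_params(channels: int) -> int:
    # gamma, beta, moving_mean, moving_variance: four values per channel
    return 4 * channels


def dense_params(in_units: int, out_units: int) -> int:
    # weight matrix plus bias vector
    return in_units * out_units + out_units


def same_padding_out(size: int, stride: int) -> int:
    # with padding='same', the output length is ceil(size / stride)
    return -(-size // stride)


size = 28
for _ in range(3):
    size = same_padding_out(size, 2)      # 28 -> 14 -> 7 -> 4

encoder_params = (
    conv2d_params(3, 3, 1, 32) + batchnorm_params(32)
    + conv2d_params(3, 3, 32, 64) + batchnorm_params(64)
    + conv2d_params(3, 3, 64, 128) + batchnorm_params(128)
    + dense_params(size * size * 128, 128)
)
head_params = dense_params(128, 10)       # the only trainable layer once the encoder is frozen

print(encoder_params, head_params, encoder_params + head_params)  # 355840 1290 357130
```

Roughly 74% of the encoder's parameters sit in the final Dense(128) layer (262,272 of 355,840), because the flattened 4x4x128 feature map has 2,048 inputs.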
In [ ]:
model_downstream.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-2),
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

model_downstream.fit(
    XX_train,
    YY_train,
    batch_size=64,
    epochs=20,
    validation_split=0.2,
    verbose=1
)
Epoch 1/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 6s 242ms/step - accuracy: 0.4062 - loss: 3.1627 - val_accuracy: 0.6400 - val_loss: 1.2156
Epoch 2/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - accuracy: 0.7688 - loss: 0.6980 - val_accuracy: 0.8500 - val_loss: 0.6489
Epoch 3/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - accuracy: 0.8913 - loss: 0.3301 - val_accuracy: 0.8450 - val_loss: 0.5882
Epoch 4/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - accuracy: 0.9337 - loss: 0.2090 - val_accuracy: 0.8900 - val_loss: 0.5542
Epoch 5/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - accuracy: 0.9362 - loss: 0.1744 - val_accuracy: 0.8900 - val_loss: 0.5125
Epoch 6/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 0s 9ms/step - accuracy: 0.9625 - loss: 0.1221 - val_accuracy: 0.9100 - val_loss: 0.4903
Epoch 7/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 0s 9ms/step - accuracy: 0.9725 - loss: 0.1038 - val_accuracy: 0.9200 - val_loss: 0.4823
Epoch 8/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - accuracy: 0.9825 - loss: 0.0861 - val_accuracy: 0.9100 - val_loss: 0.4761
Epoch 9/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - accuracy: 0.9862 - loss: 0.0752 - val_accuracy: 0.9200 - val_loss: 0.4700
Epoch 10/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - accuracy: 0.9925 - loss: 0.0625 - val_accuracy: 0.9100 - val_loss: 0.4711
Epoch 11/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - accuracy: 0.9937 - loss: 0.0553 - val_accuracy: 0.9150 - val_loss: 0.4864
Epoch 12/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - accuracy: 0.9900 - loss: 0.0531 - val_accuracy: 0.9100 - val_loss: 0.4685
Epoch 13/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - accuracy: 0.9912 - loss: 0.0515 - val_accuracy: 0.9200 - val_loss: 0.4985
Epoch 14/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 0s 27ms/step - accuracy: 0.9912 - loss: 0.0436 - val_accuracy: 0.9150 - val_loss: 0.4710
Epoch 15/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 0s 15ms/step - accuracy: 0.9950 - loss: 0.0366 - val_accuracy: 0.9150 - val_loss: 0.4833
Epoch 16/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 0s 26ms/step - accuracy: 0.9975 - loss: 0.0360 - val_accuracy: 0.9150 - val_loss: 0.4768
Epoch 17/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 1s 46ms/step - accuracy: 0.9987 - loss: 0.0317 - val_accuracy: 0.9150 - val_loss: 0.4795
Epoch 18/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 0s 21ms/step - accuracy: 0.9987 - loss: 0.0284 - val_accuracy: 0.9200 - val_loss: 0.4742
Epoch 19/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 0s 16ms/step - accuracy: 1.0000 - loss: 0.0276 - val_accuracy: 0.9200 - val_loss: 0.4912
Epoch 20/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 0s 14ms/step - accuracy: 1.0000 - loss: 0.0257 - val_accuracy: 0.9200 - val_loss: 0.4838
Out[ ]:
<keras.src.callbacks.history.History at 0x7c246c3607a0>
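The progress bars report 79 steps per pretext epoch and 13 steps per downstream epoch. The sketch below shows where those numbers come from, using the dataset sizes listed in the comparison summary at the end of this section: 10,000 unlabeled images for the pretext task and 1,000 labeled training images. Keras' `validation_split` holds out the last fraction of the arrays before shuffling, so 200 images are used for validation. This is also why every `val_accuracy` value above is a multiple of 1/200 = 0.005. The helper names are illustrative only.

```python
import math


def split_sizes(n_samples: int, validation_split: float) -> tuple[int, int]:
    # the validation samples are taken from the end of the arrays
    n_val = round(n_samples * validation_split)
    return n_samples - n_val, n_val


def steps_per_epoch(n_samples: int, batch_size: int) -> int:
    # a final, partial batch still counts as one step
    return math.ceil(n_samples / batch_size)


# Pretext task: 10,000 unlabeled view pairs, batch_size=128, no validation split
pretext_steps = steps_per_epoch(10_000, 128)

# Downstream task: 1,000 labeled images, validation_split=0.2, batch_size=64
n_train, n_val = split_sizes(1_000, 0.2)
downstream_steps = steps_per_epoch(n_train, 64)

print(pretext_steps, n_train, n_val, downstream_steps)  # 79 800 200 13
```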
In [ ]:
test_loss_ssl, test_acc_ssl = model_downstream.evaluate(XX_test, YY_test, verbose=0)

print('Contrastive Learning Test Accuracy : {:.2f}%'.format(test_acc_ssl * 100))
Contrastive Learning Test Accuracy : 86.00%

Downstream Task Result (Image Classification Example)

In [ ]:
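name = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
idx = 9                                        # note: indexes the training set, not the held-out test set
img = XX_train[idx].reshape(-1, 28, 28, 1)     # add a batch dimension: (1, 28, 28, 1)
label = YY_train[idx]                          # one-hot ground-truth label
predict = model_downstream.predict(img)        # class probabilities, shape (1, 10)
mypred = np.argmax(predict, axis=1)            # index of the most probable class

plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plt.imshow(img.reshape(28, 28), 'gray')
plt.axis('off')
plt.subplot(1, 2, 2)
plt.stem(predict[0])
plt.xticks(range(10), name)
plt.xlabel('Digit class')
plt.ylabel('Predicted probability')
plt.show()

print('Prediction : {}'.format(name[mypred[0]]))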
1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 891ms/step
Prediction : 2
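The cell above inspects a single training image. To see which digits the frozen representation tends to confuse, a per-class count over the full test set is more informative. Below is a minimal NumPy sketch. `confusion_matrix` is a hypothetical helper written for illustration, not scikit-learn's function of the same name. It assumes the labels are one-hot encoded, as the `categorical_crossentropy` loss implies.

```python
import numpy as np


def confusion_matrix(y_true_onehot: np.ndarray, y_prob: np.ndarray, num_classes: int = 10) -> np.ndarray:
    """Count predictions: rows are true classes, columns are predicted classes."""
    true_idx = np.argmax(y_true_onehot, axis=1)
    pred_idx = np.argmax(y_prob, axis=1)
    cm = np.zeros((num_classes, num_classes), dtype=int)
    np.add.at(cm, (true_idx, pred_idx), 1)   # unbuffered add, so repeated (row, col) pairs all count
    return cm


# Toy example with 3 classes. In the notebook one would instead pass
# YY_test and model_downstream.predict(XX_test).
y_true = np.eye(3)[[0, 1, 2, 2]]
y_prob = np.array([[0.9, 0.05, 0.05],
                   [0.1, 0.8, 0.1],
                   [0.2, 0.2, 0.6],
                   [0.1, 0.7, 0.2]])
cm = confusion_matrix(y_true, y_prob, num_classes=3)
print(cm)
print('accuracy:', np.trace(cm) / cm.sum())   # 0.75
```

The diagonal holds the correct predictions, so its sum divided by the total recovers the accuracy that `evaluate` reports.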

2.3. Build Supervised Model for Comparison

Convolutional Neural Network for MNIST Image Classification

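  • The total number of parameters (357,130) is the same as in the downstream model, but here almost all of them are trainable. Only 448 parameters are non-trainable (the BatchNormalization moving means and variances), compared with 355,840 frozen parameters in the downstream model.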
In [ ]:
supervised_encoder = tf.keras.models.Sequential([
    tf.keras.layers.Input(shape=(28, 28, 1)),

    tf.keras.layers.Conv2D(filters=32,
                           kernel_size=(3, 3),
                           strides=(2, 2),
                           padding='same',
                           activation='relu'),
    tf.keras.layers.BatchNormalization(),

    tf.keras.layers.Conv2D(filters=64,
                           kernel_size=(3, 3),
                           strides=(2, 2),
                           padding='same',
                           activation='relu'),
    tf.keras.layers.BatchNormalization(),

    tf.keras.layers.Conv2D(filters=128,
                           kernel_size=(3, 3),
                           strides=(2, 2),
                           padding='same',
                           activation='relu'),
    tf.keras.layers.BatchNormalization(),

    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(units=128, activation='relu')
])

model_supervised = tf.keras.models.Sequential([
    tf.keras.layers.Input(shape=(28, 28, 1)),
    supervised_encoder,
    tf.keras.layers.Dense(units=10, activation='softmax')
], name = 'supervised')

model_supervised.summary()
Model: "supervised"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                    ┃ Output Shape           ┃       Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ sequential_3 (Sequential)       │ (None, 128)            │       355,840 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_16 (Dense)                │ (None, 10)             │         1,290 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
 Total params: 357,130 (1.36 MB)
 Trainable params: 356,682 (1.36 MB)
 Non-trainable params: 448 (1.75 KB)
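The 448 non-trainable parameters come entirely from the three BatchNormalization layers. Each channel stores a learned scale and shift (gamma, beta) and two running statistics (moving mean, moving variance). The running statistics are updated as moving averages during training rather than by the optimizer, so Keras reports them as non-trainable. A quick check of the arithmetic:

```python
bn_channels = [32, 64, 128]   # channels of the three BatchNormalization layers

# gamma and beta: learned by gradient descent
bn_trainable = sum(2 * c for c in bn_channels)
# moving_mean and moving_variance: running averages, not updated by the optimizer
bn_non_trainable = sum(2 * c for c in bn_channels)

total_params = 357_130        # "Total params" reported by model_supervised.summary()
print(bn_trainable, bn_non_trainable, total_params - bn_non_trainable)  # 448 448 356682
```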
In [ ]:
model_supervised.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

model_supervised.fit(
    XX_train,
    YY_train,
    batch_size=64,
    epochs=20,
    validation_split=0.2,
    verbose=1
)
Epoch 1/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 10s 291ms/step - accuracy: 0.6488 - loss: 1.0865 - val_accuracy: 0.5950 - val_loss: 2.2087
Epoch 2/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 0s 11ms/step - accuracy: 0.9488 - loss: 0.1499 - val_accuracy: 0.4100 - val_loss: 2.1518
Epoch 3/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 0s 10ms/step - accuracy: 0.9950 - loss: 0.0339 - val_accuracy: 0.4050 - val_loss: 2.1269
Epoch 4/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 0s 15ms/step - accuracy: 1.0000 - loss: 0.0100 - val_accuracy: 0.3450 - val_loss: 2.1053
Epoch 5/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 0s 15ms/step - accuracy: 1.0000 - loss: 0.0039 - val_accuracy: 0.3550 - val_loss: 2.0889
Epoch 6/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 0s 14ms/step - accuracy: 1.0000 - loss: 0.0023 - val_accuracy: 0.3500 - val_loss: 2.0789
Epoch 7/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 0s 14ms/step - accuracy: 1.0000 - loss: 0.0018 - val_accuracy: 0.3450 - val_loss: 2.0599
Epoch 8/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 0s 15ms/step - accuracy: 1.0000 - loss: 0.0015 - val_accuracy: 0.3400 - val_loss: 2.0391
Epoch 9/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 0s 14ms/step - accuracy: 1.0000 - loss: 0.0013 - val_accuracy: 0.3450 - val_loss: 2.0180
Epoch 10/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 0s 30ms/step - accuracy: 1.0000 - loss: 0.0011 - val_accuracy: 0.3550 - val_loss: 1.9932
Epoch 11/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 0s 27ms/step - accuracy: 1.0000 - loss: 9.8316e-04 - val_accuracy: 0.3750 - val_loss: 1.9603
Epoch 12/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 0s 26ms/step - accuracy: 1.0000 - loss: 8.9548e-04 - val_accuracy: 0.3800 - val_loss: 1.9227
Epoch 13/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 0s 30ms/step - accuracy: 1.0000 - loss: 8.2387e-04 - val_accuracy: 0.4100 - val_loss: 1.8750
Epoch 14/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 0s 22ms/step - accuracy: 1.0000 - loss: 7.5596e-04 - val_accuracy: 0.4200 - val_loss: 1.8253
Epoch 15/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 1s 43ms/step - accuracy: 1.0000 - loss: 7.0116e-04 - val_accuracy: 0.4450 - val_loss: 1.7708
Epoch 16/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 0s 10ms/step - accuracy: 1.0000 - loss: 6.3810e-04 - val_accuracy: 0.4700 - val_loss: 1.7036
Epoch 17/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 0s 10ms/step - accuracy: 1.0000 - loss: 5.9262e-04 - val_accuracy: 0.4800 - val_loss: 1.6349
Epoch 18/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 0s 9ms/step - accuracy: 1.0000 - loss: 5.6187e-04 - val_accuracy: 0.4950 - val_loss: 1.5589
Epoch 19/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 0s 10ms/step - accuracy: 1.0000 - loss: 5.1555e-04 - val_accuracy: 0.5100 - val_loss: 1.4798
Epoch 20/20
13/13 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - accuracy: 1.0000 - loss: 4.9402e-04 - val_accuracy: 0.5400 - val_loss: 1.4023
Out[ ]:
<keras.src.callbacks.history.History at 0x7c246ae4f2c0>
In [ ]:
test_loss_sup, test_acc_sup = model_supervised.evaluate(
    XX_test,
    YY_test,
    verbose=0
)

print('Supervised Test Accuracy : {:.2f}%'.format(test_acc_sup * 100))
Supervised Test Accuracy : 45.33%

Compare Contrastive Learning and Supervised Learning

(1) Pretext Task

  • Input data: 10,000 MNIST images without labels

(2) Downstream Task and Supervised Learning (for performance comparison)

  • Training data: 1,000 MNIST images with labels
  • Test data: 300 MNIST images with labels
In [ ]:
print('Contrastive Learning Accuracy on Test Data: {:.2f}%'.format(test_acc_ssl * 100))
print('Supervised Learning Accuracy on Test Data : {:.2f}%'.format(test_acc_sup * 100))
Contrastive Learning Accuracy on Test Data: 86.00%
Supervised Learning Accuracy on Test Data : 45.33%
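With only 300 test images, each accuracy is a fairly coarse estimate. The sketch below converts the printed accuracies back into counts of correct predictions. It also computes a normal-approximation standard error for each accuracy. This captures sampling noise from the 300-image test set only. It does not capture variation across random seeds or data splits, since the results above come from a single run.

```python
import math

n_test = 300
acc_ssl = 0.8600   # contrastive pretraining + frozen encoder + linear classifier
acc_sup = 0.4533   # same architecture trained from scratch on 1,000 labels


def std_error(p: float, n: int) -> float:
    # normal-approximation standard error of an accuracy measured on n samples
    return math.sqrt(p * (1 - p) / n)


correct_ssl = round(acc_ssl * n_test)
correct_sup = round(acc_sup * n_test)
print(correct_ssl, correct_sup)  # 258 136
print('gap: {:.2f} points'.format((acc_ssl - acc_sup) * 100))
print('SE (contrastive): {:.1f} points, SE (supervised): {:.1f} points'.format(
    std_error(acc_ssl, n_test) * 100, std_error(acc_sup, n_test) * 100))
```

The gap of roughly 40 points is many times larger than either standard error (about 2 to 3 points). This suggests that the difference is not an artifact of the small test set, although repeated runs would be needed to rule out seed-to-seed variation.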