Domain Adaptation


By Prof. Seungchul Lee
http://iailab.kaist.ac.kr/
Industrial AI Lab at KAIST

1. Domain Adaptation

In many real-world machine learning applications, the data we have available for training and the data we encounter at test time do not come from the same distribution. A model trained on one dataset may perform well on that dataset but fail to generalize when deployed in a different environment. This discrepancy between training and test distributions is known as domain shift, and addressing it is the central challenge of domain adaptation.

Consider a practical example from manufacturing: a defect detection model trained on images collected from one production line may perform poorly when deployed on a different production line — even within the same factory. The two lines may produce the same product, but differences in lighting conditions, camera angles, machine wear, or material batches cause the statistical properties of the images to differ. The underlying task is identical, but the model trained on data from the source line does not automatically generalize to the target line. Collecting and annotating a new labeled dataset for every production line or every factory is costly and time-consuming, making domain adaptation a practically important problem in industrial settings.

Domain adaptation aims to bridge this gap. The goal is to learn a model on a labeled source domain that generalizes well to a target domain, even when the target domain has little or no labeled data.

Domain adaptation sits at the intersection of transfer learning and distribution matching. It draws on ideas from both fields — leveraging the structure learned from the source domain while aligning the source and target distributions so that the learned representations transfer effectively.

In the following sections, we will examine the problem more formally, discuss the key approaches that have been developed to address it, and explore how adversarial training, the same principle that underlies GANs, can be applied to learn domain-invariant representations.


1.1. Basic Assumption of AI Models

A fundamental assumption underlying most machine learning models is that the training data and the test data are drawn from the same distribution:


$$P_{\text{train}}(X, Y) \approx P_{\text{test}}(X, Y)$$


This assumption is what justifies expecting a model trained on a fixed dataset to perform well on unseen data. The model learns the statistical patterns present in the training distribution, and if the test data follows the same distribution, those patterns remain valid at inference time.

However, this assumption is frequently violated in practice. When a model is deployed in a real-world environment, the data it encounters may differ from the training data in subtle or significant ways — due to differences in sensor characteristics, environmental conditions, operational settings, or data collection procedures. The natural question then arises: will the model still perform well in a deployment environment that is different from the training environment?

Domain adaptation is designed to address exactly this question. Rather than assuming that the training and deployment distributions are identical, it explicitly acknowledges the gap between them and develops methods to bridge it.


Domain Shift



The above figure illustrates the domain shift problem in a manufacturing context. An AI model is trained on data collected from Line A. When the same model is tested on Line A, the output distribution is confident and well-concentrated on class C1, indicating correct and reliable predictions. On the other hand, when the same model is applied to Line B — a different production line with different equipment and camera setup — the output distribution becomes uncertain and spread across multiple classes, reflecting a significant drop in model confidence and reliability.

This phenomenon is known as domain shift: the data distribution differs between the training environment and the deployment environment. In practice, such shifts can arise from a wide range of factors — differences in equipment, sensor position, camera angle, process conditions, temperature, or lighting. Even subtle changes in any of these factors can cause the input data distribution to shift in ways that invalidate the decision boundaries learned during training. The model's learned criteria, which were well-calibrated for Line A, are no longer valid when applied to Line B.
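One simple way to quantify the "confident" versus "uncertain" output distributions described above is the entropy of the softmax output. The sketch below is an illustrative assumption rather than part of the original pipeline, and the helper name `predictive_entropy` is hypothetical. A peaked distribution (like the Line A case) has low entropy, while predictions spread evenly over $K$ classes (like the Line B case) have entropy close to $\log K$.

```python
import numpy as np

def predictive_entropy(probs, eps=1e-12):
    """Entropy (in nats) of each row of a softmax output matrix.

    probs: array of shape (num_samples, num_classes) whose rows sum to 1.
    Hypothetical helper for illustration only.
    """
    probs = np.clip(np.asarray(probs, dtype=np.float64), eps, 1.0)  # avoid log(0)
    return -np.sum(probs * np.log(probs), axis=1)

# A confident prediction (mass concentrated on class C1) vs. a spread-out one
confident = np.array([[0.97, 0.01, 0.01, 0.01]])
uncertain = np.array([[0.30, 0.25, 0.25, 0.20]])

print(predictive_entropy(confident))  # small, about 0.17
print(predictive_entropy(uncertain))  # close to log(4), about 1.386
```

Applied to the notebook's source classifier (whose last layer is a softmax), something like `predictive_entropy(source_classifier.predict(usps_test_images, verbose=0)).mean()` gives one scalar summary of how uncertain the model is on a domain. Higher average entropy on the target than on the source is a common symptom of domain shift, but low entropy does not guarantee correct predictions: a shifted model can also be confidently wrong.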


1.2. Source Domain and Target Domain

The figure below illustrates one of the most common manifestations of domain shift in manufacturing: the gap between the environment in which a model is developed and the environment in which it is deployed.

  • Source domain — the controlled environment where data is collected and the model is trained. This may be a simulation, a laboratory setup, or a dedicated test bench where conditions can be carefully managed and labeled data can be collected at relatively low cost.

  • Target domain — the real production process where the model is ultimately deployed. Here, conditions are less controlled, the environment is more complex, and labeled data is scarce or unavailable.




Formally, the source domain $\mathcal{D}_S$ and target domain $\mathcal{D}_T$ are defined by their respective data distributions. Depending on the availability of labeled data in the target domain, domain adaptation is typically categorized into two settings:

When no labeled data is available in the target domain — the unsupervised setting:


$$\mathcal{D}_S = \{(x_i^S, y_i^S)\}_{i=1}^{n_S} \sim P_S(X, Y) \qquad \mathcal{D}_T = \{x_j^T\}_{j=1}^{n_T} \sim P_T(X)$$


When a small number of labeled samples is available in the target domain — the semi-supervised setting:


$$\mathcal{D}_S = \{(x_i^S, y_i^S)\}_{i=1}^{n_S} \sim P_S(X, Y) \qquad \mathcal{D}_T = \{(x_j^T, y_j^T)\}_{j=1}^{n_T} \sim P_T(X, Y)$$


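where $n_T \ll n_S$: the number of labeled target samples is much smaller than the number of labeled source samples. In practice, a larger pool of unlabeled target samples is often available alongside these few labeled ones.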


In both settings, the domain shift problem arises because $P_S(X, Y) \neq P_T(X, Y)$ — the two distributions differ, and a model trained on $\mathcal{D}_S$ cannot be expected to perform well on $\mathcal{D}_T$ without adaptation. The goal of domain adaptation is to leverage the labeled source data together with the available target data — labeled or unlabeled — to learn a model that generalizes to the target domain.
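To connect the two settings to concrete arrays, the following sketch builds the target-domain data for either setting from a labeled target array. The helper `make_target_split` is hypothetical and not part of the original notebook; keeping the remaining target images as an unlabeled pool in the semi-supervised case is an assumption that reflects a common arrangement.

```python
import numpy as np

def make_target_split(target_images, target_labels, num_labeled=0, seed=0):
    """Split target-domain data into a small labeled subset and an unlabeled pool.

    num_labeled=0 reproduces the unsupervised setting (no target labels are kept);
    a small num_labeled > 0 gives the semi-supervised setting with n_T labeled samples.
    Hypothetical helper for illustration only.
    """
    rng = np.random.default_rng(seed)
    order = rng.permutation(len(target_images))
    labeled_idx, unlabeled_idx = order[:num_labeled], order[num_labeled:]
    labeled = (target_images[labeled_idx], target_labels[labeled_idx])
    unlabeled_images = target_images[unlabeled_idx]  # labels deliberately discarded
    return labeled, unlabeled_images

# Toy check with 20 dummy "images"
images = np.arange(20).reshape(20, 1)
labels = np.arange(20) % 10
(x_lab, y_lab), x_unlab = make_target_split(images, labels, num_labeled=5)
print(x_lab.shape, y_lab.shape, x_unlab.shape)  # (5, 1) (5,) (15, 1)
```

For example, `make_target_split(usps_train_images, usps_train_labels, num_labeled=100)` would mimic a semi-supervised MNIST-to-USPS setup with 100 labeled USPS images; the value 100 is arbitrary.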


1.3. Domain Shift

Domain shift occurs when the source and target domains differ in their data distribution or input-output relationship. Understanding the nature of this shift is important because different types of shift require different adaptation strategies. The most common and well-studied type is covariate shift.


Covariate Shift

Covariate shift refers to the case where the input distribution $P(X)$ changes between the source and target domains, while the underlying relationship between input and output remains unchanged. In other words, the task itself has not changed — a defective product is still defective, and a normal product is still normal — but the way the data looks has changed due to differences in the environment or measurement conditions.
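Factoring the joint distribution as $P(X, Y) = P(Y \mid X)\,P(X)$, this definition can be written compactly as:


$$P_S(X) \neq P_T(X), \qquad P_S(Y \mid X) = P_T(Y \mid X)$$


In words: the joint distributions $P_S(X, Y)$ and $P_T(X, Y)$ differ only through the marginal distribution of the inputs, while the conditional relationship from inputs to labels is shared by both domains.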

In manufacturing, covariate shift commonly arises from factors such as sensor position changes, equipment differences, or variations in lighting, temperature, and vibration conditions. Each production line may have its own measurement conditions that cause the input data distribution to shift, even when the underlying product and defect types remain the same.




The figure above illustrates this clearly. The source domain (blue) and target domain (red) contain the same two classes — circles and crosses — but their distributions occupy different regions of the feature space. The decision boundary learned from the source domain, shown as a dashed line, was well-suited for the source distribution but no longer correctly separates the target classes. The shift in input distribution causes the source decision boundary to misclassify target samples, even though the nature of the classification task has not changed at all.
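The same failure can be reproduced in a one-dimensional toy problem under strict covariate shift. In the sketch below (all values are arbitrary illustrative choices), a single fixed rule assigns label 1 whenever $|x| > 1$ in both domains, so $P(Y \mid X)$ never changes; only the input range moves from $[0, 3)$ in the source to $[-3, 0)$ in the target. A simple threshold classifier stands in for the learned decision boundary.

```python
import numpy as np

rng = np.random.default_rng(0)

def true_label(x):
    """Fixed labeling rule shared by both domains: P(Y|X) never changes."""
    return (np.abs(x) > 1.0).astype(int)

# Only the input distribution P(X) differs between the domains
x_src = rng.uniform(0.0, 3.0, size=1000)   # source inputs
x_tgt = rng.uniform(-3.0, 0.0, size=1000)  # target inputs
y_src, y_tgt = true_label(x_src), true_label(x_tgt)

# Fit a one-feature threshold classifier (predict 1 if x > t) on the source only
candidates = np.sort(x_src)
source_accs = [np.mean((x_src > t).astype(int) == y_src) for t in candidates]
t_best = candidates[int(np.argmax(source_accs))]

src_acc = np.mean((x_src > t_best).astype(int) == y_src)
tgt_acc = np.mean((x_tgt > t_best).astype(int) == y_tgt)
print(f"threshold: {t_best:.3f}")          # just below 1.0
print(f"source accuracy: {src_acc:.3f}")   # 1.000
print(f"target accuracy: {tgt_acc:.3f}")   # roughly one third
```

The classifier is perfect on the source because the source never shows it the region $x < -1$, where the same rule also assigns label 1. On the target it predicts 0 everywhere, so its accuracy falls to roughly the fraction of target points with $-1 \le x < 0$.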


Covariate Shift Example: Time-Varying Operating Conditions

The figure below shows a practical example of covariate shift arising from time-varying operating conditions. The blue region represents the source domain — the period during which training data was collected. The red region represents the target domain — a later period during which the model is deployed. The signal characteristics change noticeably between the two periods, reflecting a shift in the underlying operating conditions.

This type of shift is particularly challenging in industrial settings for two reasons. First, acquiring diverse and high-quality labeled data that covers all possible operating conditions is difficult and costly. Second, the operating environment changes continuously over time — due to machine aging, process drift, seasonal variation, or changes in raw materials — meaning that a model trained at one point in time may gradually become less reliable as conditions evolve.
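One common way to watch for this kind of gradual change, offered here as an assumed monitoring sketch rather than a method from this chapter, is to compare the distribution of an input feature in a reference window (the source period) with a recent window. The two-sample Kolmogorov-Smirnov statistic below measures the largest gap between the two empirical CDFs; SciPy provides a full test as `scipy.stats.ks_2samp`, but the NumPy version keeps the example self-contained. Note that it only detects changes in $P(X)$, one feature at a time, and says nothing about labels.

```python
import numpy as np

def ks_statistic(reference, current):
    """Two-sample Kolmogorov-Smirnov statistic: max gap between the two empirical CDFs.

    0 means identical empirical distributions; values near 1 mean almost no overlap.
    """
    reference = np.sort(np.asarray(reference, dtype=np.float64))
    current = np.sort(np.asarray(current, dtype=np.float64))
    grid = np.concatenate([reference, current])  # the sup is attained at sample points
    cdf_ref = np.searchsorted(reference, grid, side="right") / len(reference)
    cdf_cur = np.searchsorted(current, grid, side="right") / len(current)
    return float(np.max(np.abs(cdf_ref - cdf_cur)))

rng = np.random.default_rng(0)
reference_window = rng.normal(0.0, 1.0, size=2000)  # e.g. a sensor feature in the source period
same_window = rng.normal(0.0, 1.0, size=2000)       # later period, unchanged conditions
drifted_window = rng.normal(1.0, 1.0, size=2000)    # later period, the mean has drifted

print(ks_statistic(reference_window, same_window))     # small
print(ks_statistic(reference_window, drifted_window))  # much larger, around 0.38
```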




1.4. Data for Domain Adaptation: MNIST to USPS

Before introducing the various domain adaptation methods, let us first examine the data that will be used throughout this chapter to demonstrate them.

The figure below shows representative samples from the two datasets. The top row shows digits from MNIST (the source domain), which consists of handwritten digits with relatively thin strokes on a clean black background. The bottom row shows digits from USPS (the target domain), which also contains handwritten digits from 0 to 9 but with noticeably different visual characteristics: the strokes are thicker and smoother, the digits appear slightly larger relative to the image, and the overall appearance is softer and more blurred.

Despite sharing the same task — digit classification from 0 to 9 — the two datasets look quite different. This visual difference reflects an underlying shift in the input data distribution, making MNIST and USPS a standard and well-established benchmark for evaluating domain adaptation methods.

The experimental setup is as follows:

  • Source domain: MNIST — fully labeled, used for training
  • Target domain: USPS (treated as unlabeled during training; its labels are used only to measure accuracy)
  • Task: digit classification (0 to 9)
  • Model: CNN

A model trained solely on MNIST and tested directly on USPS will experience a noticeable drop in performance due to the domain shift. The goal of domain adaptation is to close this performance gap without requiring labeled data from USPS.



In [ ]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from matplotlib.lines import Line2D
import warnings
warnings.filterwarnings("ignore")

from google.colab import drive
drive.mount("/content/drive")

print("GPU devices:", tf.config.list_physical_devices("GPU"))
In [ ]:
data_dir = "/content/drive/MyDrive/DL_Colab/DL_data"

mnist_train_images = np.load(data_dir + "/MNIST_train_images.npy").astype("float32")
mnist_train_labels = np.load(data_dir + "/MNIST_train_labels.npy").astype("int64")
mnist_test_images = np.load(data_dir + "/MNIST_test_images.npy").astype("float32")
mnist_test_labels = np.load(data_dir + "/MNIST_test_labels.npy").astype("int64")

usps_train_images = np.load(data_dir + "/USPS_train_images.npy").astype("float32")
usps_train_labels = np.load(data_dir + "/USPS_train_labels.npy").astype("int64")
usps_test_images = np.load(data_dir + "/USPS_test_images.npy").astype("float32")
usps_test_labels = np.load(data_dir + "/USPS_test_labels.npy").astype("int64")

print("MNIST train:", mnist_train_images.shape, mnist_train_labels.shape)
print("MNIST test :", mnist_test_images.shape, mnist_test_labels.shape)
print("USPS train :", usps_train_images.shape, usps_train_labels.shape)
print("USPS test  :", usps_test_images.shape, usps_test_labels.shape)
In [ ]:
dataset_list = [
    ("MNIST", mnist_train_images, mnist_train_labels),
    ("USPS", usps_train_images, usps_train_labels)
]

fig, axes = plt.subplots(2, 10, figsize=(15, 3.3))

for row_index, (dataset_name, images, labels) in enumerate(dataset_list):
    for digit in range(10):
        sample_index = np.where(labels == digit)[0][0]
        axes[row_index, digit].imshow(images[sample_index, :, :, 0], cmap="gray", vmin=0.0, vmax=1.0)
        axes[row_index, digit].axis("off")

    axes[row_index, 0].text(-0.55, 0.5, dataset_name, transform=axes[row_index, 0].transAxes, ha="center", va="center", fontsize=18)

plt.tight_layout(rect=[0.08, 0, 1, 1])
plt.show()
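The visual differences described above (thicker strokes, larger digits) can also be checked numerically. The hypothetical helper below reports the mean pixel intensity and the fraction of "ink" pixels above a threshold, a rough proxy for stroke thickness and digit size; the threshold of 0.5 is an arbitrary choice, and the synthetic strokes only demonstrate what the numbers mean.

```python
import numpy as np

def pixel_statistics(images, ink_threshold=0.5):
    """Mean intensity and mean 'ink' fraction of grayscale images in [0, 1].

    The ink fraction (share of pixels brighter than ink_threshold) is a rough
    proxy for stroke thickness and digit size; the threshold is arbitrary.
    """
    images = np.asarray(images, dtype=np.float32)
    mean_intensity = float(images.mean())
    ink_fraction = float((images > ink_threshold).mean())
    return mean_intensity, ink_fraction

# Synthetic check: a 2-pixel-wide stroke vs. a 6-pixel-wide stroke on a 28x28 canvas
thin = np.zeros((1, 28, 28, 1), dtype=np.float32)
thin[0, 4:24, 13:15, 0] = 1.0   # 20 x 2 = 40 ink pixels
thick = np.zeros((1, 28, 28, 1), dtype=np.float32)
thick[0, 4:24, 11:17, 0] = 1.0  # 20 x 6 = 120 ink pixels

print(pixel_statistics(thin))   # ink fraction 40/784, about 0.051
print(pixel_statistics(thick))  # ink fraction 120/784, about 0.153
```

On the notebook's data one would compare `pixel_statistics(mnist_train_images)` with `pixel_statistics(usps_train_images)`. A higher ink fraction for USPS would be consistent with the thicker strokes seen in the figure, though the exact values depend on how the USPS images were preprocessed into the 28x28 arrays used here.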

1.5. Training with MNIST (Source)


In [ ]:
source_batch_size = 128
source_epochs = 3
source_learning_rate = 1e-3

source_classifier = tf.keras.models.Sequential([
    tf.keras.layers.Input(shape=(28, 28, 1)),
    tf.keras.layers.Conv2D(32, kernel_size=(7, 7), activation="relu", padding="same"),
    tf.keras.layers.MaxPool2D(pool_size=(2, 2)),
    tf.keras.layers.Conv2D(64, kernel_size=(7, 7), activation="relu", padding="same"),
    tf.keras.layers.MaxPool2D(pool_size=(2, 2)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation="relu"),
    tf.keras.layers.Dense(10, activation="softmax"),
], name="source_classifier")

source_classifier.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=source_learning_rate),
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"]
)
In [ ]:
source_classifier.summary()
In [ ]:
source_classifier.fit(
    mnist_train_images,
    mnist_train_labels,
    batch_size=source_batch_size,
    epochs=source_epochs,
    validation_data=(mnist_test_images, mnist_test_labels),
    verbose=1,
)

source_mnist_loss, source_mnist_acc = source_classifier.evaluate(
    mnist_test_images,
    mnist_test_labels,
    batch_size=source_batch_size,
    verbose=0,
)

source_usps_loss, source_usps_acc = source_classifier.evaluate(
    usps_test_images,
    usps_test_labels,
    batch_size=source_batch_size,
    verbose=0,
)

source_classifier_weights = source_classifier.get_weights()

print("[Source-only]")
print(f"MNIST accuracy: {source_mnist_acc:.4f}")
print(f"USPS accuracy : {source_usps_acc:.4f}")
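A single USPS accuracy hides which digits suffer most from the shift. The hypothetical helper below breaks accuracy down by class; it is an illustrative addition, not part of the original notebook.

```python
import numpy as np

def per_class_accuracy(true_labels, predicted_labels, num_classes=10):
    """Accuracy computed separately for each class; NaN for classes with no samples."""
    true_labels = np.asarray(true_labels)
    predicted_labels = np.asarray(predicted_labels)
    accuracies = np.full(num_classes, np.nan)
    for c in range(num_classes):
        mask = true_labels == c
        if mask.any():
            accuracies[c] = np.mean(predicted_labels[mask] == c)
    return accuracies

# Toy check: class 0 half right, class 1 all right, class 2 all wrong, class 3 absent
acc = per_class_accuracy([0, 0, 1, 1, 2], [0, 1, 1, 1, 0], num_classes=4)
print(acc)  # 0.5, 1.0, 0.0 and NaN
```

With the source-only model, the predicted labels could be obtained as `np.argmax(source_classifier.predict(usps_test_images, verbose=0), axis=1)` and compared against `usps_test_labels`.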

2. Generative Domain Adaptation

Among the various approaches to domain adaptation, one of the most intuitive is generative domain adaptation. The core idea is straightforward: rather than adapting the model to the target domain, we adapt the data itself.

Specifically, the source images are transformed at the pixel level to look like target domain images, while their original labels are preserved. The transformed images are then used to train a classifier, and the resulting decision boundary is applied directly to the target domain.

The appeal of this approach lies in its simplicity. If we can make source images visually indistinguishable from target images without changing their semantic content — that is, a digit 2 in the source domain becomes a digit 2 that looks like it came from the target domain — then the classifier trained on the transformed data should generalize well to the target domain without requiring any target labels.




2.1. Generative Domain Adaptation: CycleGAN

This pixel-level transformation is precisely the kind of task that CycleGAN was designed for. By learning an unpaired image-to-image mapping between the source and target domains, CycleGAN can transform source images into the visual style of the target domain while preserving their underlying content and labels. The transformed source dataset effectively serves as a bridge between the two domains, allowing a standard supervised classifier to be trained on source labels while being exposed to target-like visual characteristics.

The cycle consistency constraint plays a particularly important role here. Recall that CycleGAN enforces:


$$G_{YX}(G_{XY}(x)) \approx x$$


where $x$ is a source image, $G_{XY}$ transforms it into the target style, and $G_{YX}$ maps the result back to the source domain. Note that here $X$ and $Y$ denote the source and target image domains, not the input and label variables used in Section 1.2. This round-trip reconstruction constraint encourages the transformation to preserve the semantic content of the source image (the digit identity, shape, and structure) while changing only its visual style. Without this constraint, the generator could produce realistic-looking target images that no longer correspond to the correct source label, making the transformed data useless for training a classifier.
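For reference, the original CycleGAN formulation applies the cycle-consistency constraint in both directions and combines it with one adversarial loss per generator. With $D_X$ and $D_Y$ denoting the discriminators for the source and target domains and $\lambda$ a weighting hyperparameter:


$$\mathcal{L}_{\text{cyc}}(G_{XY}, G_{YX}) = \mathbb{E}_{x}\big[\lVert G_{YX}(G_{XY}(x)) - x \rVert_1\big] + \mathbb{E}_{y}\big[\lVert G_{XY}(G_{YX}(y)) - y \rVert_1\big]$$


$$\mathcal{L}(G_{XY}, G_{YX}, D_X, D_Y) = \mathcal{L}_{\text{GAN}}(G_{XY}, D_Y) + \mathcal{L}_{\text{GAN}}(G_{YX}, D_X) + \lambda\, \mathcal{L}_{\text{cyc}}(G_{XY}, G_{YX})$$


The generators minimize this objective while the discriminators maximize the adversarial terms. In the code of this section, $G_{XY}$ corresponds to `mnist_to_usps_generator`, $G_{YX}$ to `usps_to_mnist_generator`, $D_X$ to `mnist_discriminator`, and $D_Y$ to `usps_discriminator`; the exact loss choices in the training loop (for example, the norm used for the cycle term) may differ from this reference form.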

The overall pipeline for generative domain adaptation with CycleGAN is therefore:

  • Train CycleGAN on unpaired source and target images to learn the bidirectional mapping
  • Transform the labeled source images into the target style using $G_{XY}$
  • Train a classifier on the transformed source images using the original source labels
  • Deploy the trained classifier directly on the target domain
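Steps 2 and 3 involve one practical detail worth making explicit: the CycleGAN in this section works on images in $[-1, 1]$ (its generators end in `tanh`), whereas the source classifier from Section 1.5 was trained on images in $[0, 1]$. The following is a minimal sketch of the translation step under that assumption; `translate_in_batches` is a hypothetical helper, and `generator_fn` can be any callable that maps a $[-1, 1]$ batch to a $[-1, 1]$ batch.

```python
import numpy as np

def translate_in_batches(generator_fn, images_01, batch_size=256):
    """Apply an image-to-image generator to images stored in [0, 1].

    generator_fn: callable mapping a batch in [-1, 1] (the CycleGAN input range)
                  to a batch in [-1, 1] (tanh output), e.g. a trained G_XY.
    Returns the translated images rescaled back to [0, 1].
    """
    outputs = []
    for start in range(0, len(images_01), batch_size):
        batch = images_01[start:start + batch_size] * 2.0 - 1.0  # [0, 1] -> [-1, 1]
        translated = np.asarray(generator_fn(batch))
        outputs.append((translated + 1.0) / 2.0)                 # [-1, 1] -> [0, 1]
    return np.concatenate(outputs, axis=0).astype("float32")

# Sanity check with a stand-in "generator" that returns its input unchanged
dummy_images = np.random.default_rng(0).uniform(0.0, 1.0, size=(10, 28, 28, 1)).astype("float32")
roundtrip = translate_in_batches(lambda batch: batch, dummy_images, batch_size=4)
print(roundtrip.shape, np.allclose(roundtrip, dummy_images))  # (10, 28, 28, 1) True
```

Once the CycleGAN is trained, the translated training set could be produced with `translate_in_batches(lambda b: mnist_to_usps_generator(b, training=False).numpy(), mnist_train_images)` and paired with the unchanged `mnist_train_labels` to train a fresh classifier.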



In [ ]:
mnist_train_images_cyclegan = mnist_train_images * 2.0 - 1.0
mnist_test_images_cyclegan = mnist_test_images * 2.0 - 1.0

usps_train_images_cyclegan = usps_train_images * 2.0 - 1.0
usps_test_images_cyclegan = usps_test_images * 2.0 - 1.0

print("Image range")
print("MNIST:", mnist_train_images.min(), mnist_train_images.max())
print("USPS :", usps_train_images.min(), usps_train_images.max())

print("\nCycleGAN image range")
print("MNIST:", mnist_train_images_cyclegan.min(), mnist_train_images_cyclegan.max())
print("USPS :", usps_train_images_cyclegan.min(), usps_train_images_cyclegan.max())
Image range
MNIST: 0.0 1.0
USPS : 0.0 0.9999705

CycleGAN image range
MNIST: -1.0 1.0
USPS : -1.0 0.999941
In [ ]:
cycle_batch_size = 128

cycle_source_dataset = tf.data.Dataset.from_tensor_slices(mnist_train_images_cyclegan)
cycle_source_dataset = cycle_source_dataset.shuffle(len(mnist_train_images_cyclegan))
cycle_source_dataset = cycle_source_dataset.repeat()
cycle_source_dataset = cycle_source_dataset.batch(cycle_batch_size)

cycle_target_dataset = tf.data.Dataset.from_tensor_slices(usps_train_images_cyclegan)
cycle_target_dataset = cycle_target_dataset.shuffle(len(usps_train_images_cyclegan))
cycle_target_dataset = cycle_target_dataset.repeat()
cycle_target_dataset = cycle_target_dataset.batch(cycle_batch_size)

cycle_source_iter = iter(cycle_source_dataset)
cycle_target_iter = iter(cycle_target_dataset)

cycle_steps_per_epoch = min(
    len(mnist_train_images_cyclegan),
    len(usps_train_images_cyclegan)
) // cycle_batch_size
In [ ]:
mnist_to_usps_generator = tf.keras.models.Sequential([
    tf.keras.layers.Input(shape=(28, 28, 1)),
    tf.keras.layers.Conv2D(64, kernel_size=7, padding="same", use_bias=False),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.ReLU(),
    tf.keras.layers.Conv2D(128, kernel_size=3, strides=2, padding="same", use_bias=False),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.ReLU(),
    tf.keras.layers.Conv2D(256, kernel_size=3, strides=2, padding="same", use_bias=False),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.ReLU(),
    tf.keras.layers.Conv2DTranspose(128, kernel_size=3, strides=2, padding="same", use_bias=False),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.ReLU(),
    tf.keras.layers.Conv2DTranspose(64, kernel_size=3, strides=2, padding="same", use_bias=False),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.ReLU(),
    tf.keras.layers.Conv2D(1, kernel_size=7, padding="same", activation="tanh")], name="mnist_to_usps_generator")
In [ ]:
mnist_to_usps_generator.summary()
Model: "mnist_to_usps_generator"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                     Output Shape                  Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ conv2d_36 (Conv2D)              │ (None, 28, 28, 64)     │         3,136 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ batch_normalization_29          │ (None, 28, 28, 64)     │           256 │
│ (BatchNormalization)            │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ re_lu_25 (ReLU)                 │ (None, 28, 28, 64)     │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_37 (Conv2D)              │ (None, 14, 14, 128)    │        73,728 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ batch_normalization_30          │ (None, 14, 14, 128)    │           512 │
│ (BatchNormalization)            │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ re_lu_26 (ReLU)                 │ (None, 14, 14, 128)    │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_38 (Conv2D)              │ (None, 7, 7, 256)      │       294,912 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ batch_normalization_31          │ (None, 7, 7, 256)      │         1,024 │
│ (BatchNormalization)            │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ re_lu_27 (ReLU)                 │ (None, 7, 7, 256)      │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_transpose_10             │ (None, 14, 14, 128)    │       294,912 │
│ (Conv2DTranspose)               │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ batch_normalization_32          │ (None, 14, 14, 128)    │           512 │
│ (BatchNormalization)            │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ re_lu_28 (ReLU)                 │ (None, 14, 14, 128)    │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_transpose_11             │ (None, 28, 28, 64)     │        73,728 │
│ (Conv2DTranspose)               │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ batch_normalization_33          │ (None, 28, 28, 64)     │           256 │
│ (BatchNormalization)            │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ re_lu_29 (ReLU)                 │ (None, 28, 28, 64)     │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_39 (Conv2D)              │ (None, 28, 28, 1)      │         3,137 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
 Total params: 746,113 (2.85 MB)
 Trainable params: 744,833 (2.84 MB)
 Non-trainable params: 1,280 (5.00 KB)
In [ ]:
usps_to_mnist_generator = tf.keras.models.Sequential([
    tf.keras.layers.Input(shape=(28, 28, 1)),
    tf.keras.layers.Conv2D(64, kernel_size=7, padding="same", use_bias=False),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.ReLU(),
    tf.keras.layers.Conv2D(128, kernel_size=3, strides=2, padding="same", use_bias=False),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.ReLU(),
    tf.keras.layers.Conv2D(256, kernel_size=3, strides=2, padding="same", use_bias=False),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.ReLU(),
    tf.keras.layers.Conv2DTranspose(128, kernel_size=3, strides=2, padding="same", use_bias=False),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.ReLU(),
    tf.keras.layers.Conv2DTranspose(64, kernel_size=3, strides=2, padding="same", use_bias=False),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.ReLU(),
    tf.keras.layers.Conv2D(1, kernel_size=7, padding="same", activation="tanh")], name="usps_to_mnist_generator")
In [ ]:
usps_to_mnist_generator.summary()
Model: "usps_to_mnist_generator"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                     Output Shape                  Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ conv2d_40 (Conv2D)              │ (None, 28, 28, 64)     │         3,136 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ batch_normalization_34          │ (None, 28, 28, 64)     │           256 │
│ (BatchNormalization)            │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ re_lu_30 (ReLU)                 │ (None, 28, 28, 64)     │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_41 (Conv2D)              │ (None, 14, 14, 128)    │        73,728 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ batch_normalization_35          │ (None, 14, 14, 128)    │           512 │
│ (BatchNormalization)            │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ re_lu_31 (ReLU)                 │ (None, 14, 14, 128)    │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_42 (Conv2D)              │ (None, 7, 7, 256)      │       294,912 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ batch_normalization_36          │ (None, 7, 7, 256)      │         1,024 │
│ (BatchNormalization)            │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ re_lu_32 (ReLU)                 │ (None, 7, 7, 256)      │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_transpose_12             │ (None, 14, 14, 128)    │       294,912 │
│ (Conv2DTranspose)               │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ batch_normalization_37          │ (None, 14, 14, 128)    │           512 │
│ (BatchNormalization)            │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ re_lu_33 (ReLU)                 │ (None, 14, 14, 128)    │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_transpose_13             │ (None, 28, 28, 64)     │        73,728 │
│ (Conv2DTranspose)               │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ batch_normalization_38          │ (None, 28, 28, 64)     │           256 │
│ (BatchNormalization)            │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ re_lu_34 (ReLU)                 │ (None, 28, 28, 64)     │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_43 (Conv2D)              │ (None, 28, 28, 1)      │         3,137 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
 Total params: 746,113 (2.85 MB)
 Trainable params: 744,833 (2.84 MB)
 Non-trainable params: 1,280 (5.00 KB)
In [ ]:
mnist_discriminator = tf.keras.models.Sequential([
    tf.keras.layers.Input(shape=(28, 28, 1)),
    tf.keras.layers.Conv2D(64, kernel_size=4, strides=2, padding="same"),
    tf.keras.layers.LeakyReLU(0.2),
    tf.keras.layers.Conv2D(128, kernel_size=4, strides=2, padding="same", use_bias=False),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.LeakyReLU(0.2),
    tf.keras.layers.Conv2D(256, kernel_size=4, strides=1, padding="same", use_bias=False),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.LeakyReLU(0.2),
    tf.keras.layers.Conv2D(1, kernel_size=4, strides=1, padding="same")], name="mnist_discriminator")
In [ ]:
mnist_discriminator.summary()
Model: "mnist_discriminator"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                     Output Shape                  Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ conv2d_44 (Conv2D)              │ (None, 14, 14, 64)     │         1,088 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ leaky_re_lu_6 (LeakyReLU)       │ (None, 14, 14, 64)     │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_45 (Conv2D)              │ (None, 7, 7, 128)      │       131,072 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ batch_normalization_39          │ (None, 7, 7, 128)      │           512 │
│ (BatchNormalization)            │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ leaky_re_lu_7 (LeakyReLU)       │ (None, 7, 7, 128)      │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_46 (Conv2D)              │ (None, 7, 7, 256)      │       524,288 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ batch_normalization_40          │ (None, 7, 7, 256)      │         1,024 │
│ (BatchNormalization)            │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ leaky_re_lu_8 (LeakyReLU)       │ (None, 7, 7, 256)      │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_47 (Conv2D)              │ (None, 7, 7, 1)        │         4,097 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
 Total params: 662,081 (2.53 MB)
 Trainable params: 661,313 (2.52 MB)
 Non-trainable params: 768 (3.00 KB)
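The discriminator ends in a Conv2D layer with a single filter, so its output is a 7x7x1 map (see the summary above) rather than a single real/fake score: each spatial position judges a region of the input, in the spirit of the PatchGAN discriminator used in the original CycleGAN. The size of that region follows from the kernel sizes and strides. The sketch below computes it with the standard receptive-field recurrence, ignoring padding; the helper name is hypothetical.

```python
def receptive_field(layers):
    """Receptive field (in input pixels) of one output unit of a conv stack.

    layers: list of (kernel_size, stride) pairs, from input to output.
    Uses the recurrence r <- r + (k - 1) * j, j <- j * s, where j is the
    spacing (jump) between adjacent units measured in input pixels.
    """
    r, j = 1, 1
    for k, s in layers:
        r += (k - 1) * j
        j *= s
    return r

# (kernel_size, stride) of the four Conv2D layers in mnist_discriminator / usps_discriminator
discriminator_layers = [(4, 2), (4, 2), (4, 1), (4, 1)]
print(receptive_field(discriminator_layers))  # 34
```

A receptive field of 34 pixels is larger than the 28-pixel input, so at this resolution each of the 49 scores is influenced by a large part of the digit rather than by a small local patch. On larger images, the same layer stack would judge much more local regions.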
In [ ]:
usps_discriminator = tf.keras.models.Sequential([
    tf.keras.layers.Input(shape=(28, 28, 1)),
    tf.keras.layers.Conv2D(64, kernel_size=4, strides=2, padding="same"),
    tf.keras.layers.LeakyReLU(0.2),
    tf.keras.layers.Conv2D(128, kernel_size=4, strides=2, padding="same", use_bias=False),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.LeakyReLU(0.2),
    tf.keras.layers.Conv2D(256, kernel_size=4, strides=1, padding="same", use_bias=False),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.LeakyReLU(0.2),
    tf.keras.layers.Conv2D(1, kernel_size=4, strides=1, padding="same")], name="usps_discriminator")
In [ ]:
usps_discriminator.summary()
Model: "usps_discriminator"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                     Output Shape                  Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ conv2d_48 (Conv2D)              │ (None, 14, 14, 64)     │         1,088 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ leaky_re_lu_9 (LeakyReLU)       │ (None, 14, 14, 64)     │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_49 (Conv2D)              │ (None, 7, 7, 128)      │       131,072 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ batch_normalization_41          │ (None, 7, 7, 128)      │           512 │
│ (BatchNormalization)            │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ leaky_re_lu_10 (LeakyReLU)      │ (None, 7, 7, 128)      │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_50 (Conv2D)              │ (None, 7, 7, 256)      │       524,288 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ batch_normalization_42          │ (None, 7, 7, 256)      │         1,024 │
│ (BatchNormalization)            │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ leaky_re_lu_11 (LeakyReLU)      │ (None, 7, 7, 256)      │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_51 (Conv2D)              │ (None, 7, 7, 1)        │         4,097 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
 Total params: 662,081 (2.53 MB)
 Trainable params: 661,313 (2.52 MB)
 Non-trainable params: 768 (3.00 KB)
In [ ]:
mnist_to_usps_generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5, beta_2=0.999)
usps_to_mnist_generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5, beta_2=0.999)

mnist_discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5, beta_2=0.999)
usps_discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5, beta_2=0.999)

mse_loss = tf.keras.losses.MeanSquaredError()
In [ ]:
cycle_epochs = 30
cycle_display_epoch_interval = 10

for epoch in range(1, cycle_epochs + 1):
    g_mnist_to_usps_loss_sum = 0.0
    g_usps_to_mnist_loss_sum = 0.0
    d_mnist_loss_sum = 0.0
    d_usps_loss_sum = 0.0
    cycle_loss_sum = 0.0

    for _ in range(cycle_steps_per_epoch):
        real_mnist = next(cycle_source_iter)
        real_usps = next(cycle_target_iter)

        with tf.GradientTape(persistent=True) as tape:
            fake_usps = mnist_to_usps_generator(real_mnist, training=True)
            fake_mnist = usps_to_mnist_generator(real_usps, training=True)

            reconstructed_mnist = usps_to_mnist_generator(fake_usps, training=True)
            reconstructed_usps = mnist_to_usps_generator(fake_mnist, training=True)

            identity_mnist = usps_to_mnist_generator(real_mnist, training=True)
            identity_usps = mnist_to_usps_generator(real_usps, training=True)

            real_mnist_logits = mnist_discriminator(real_mnist, training=True)
            fake_mnist_logits = mnist_discriminator(fake_mnist, training=True)

            real_usps_logits = usps_discriminator(real_usps, training=True)
            fake_usps_logits = usps_discriminator(fake_usps, training=True)

            # Least-squares (LSGAN) adversarial terms: each generator wants its fakes scored as real (1)
            mnist_to_usps_gan_loss = mse_loss(tf.ones_like(fake_usps_logits), fake_usps_logits)
            usps_to_mnist_gan_loss = mse_loss(tf.ones_like(fake_mnist_logits), fake_mnist_logits)

            # Cycle-consistency term (weight 10): MNIST -> USPS -> MNIST and USPS -> MNIST -> USPS should return the input
            cycle_loss = 10.0 * (tf.reduce_mean(tf.abs(real_mnist - reconstructed_mnist))
                                 + tf.reduce_mean(tf.abs(real_usps - reconstructed_usps)))

            # Identity terms (weight 3): a generator fed an image already in its output domain should leave it unchanged
            identity_mnist_loss = 3.0 * tf.reduce_mean(tf.abs(real_mnist - identity_mnist))
            identity_usps_loss = 3.0 * tf.reduce_mean(tf.abs(real_usps - identity_usps))

            mnist_to_usps_generator_loss = (mnist_to_usps_gan_loss + cycle_loss + identity_usps_loss)
            usps_to_mnist_generator_loss = (usps_to_mnist_gan_loss + cycle_loss + identity_mnist_loss)

            # Least-squares discriminator losses: real images -> 1, generated images -> 0
            mnist_discriminator_loss = 0.5 * (mse_loss(tf.ones_like(real_mnist_logits), real_mnist_logits)
                                              + mse_loss(tf.zeros_like(fake_mnist_logits), fake_mnist_logits))
            usps_discriminator_loss = 0.5 * (mse_loss(tf.ones_like(real_usps_logits), real_usps_logits)
                                             + mse_loss(tf.zeros_like(fake_usps_logits), fake_usps_logits))

        mnist_to_usps_generator_optimizer.apply_gradients(
            zip(tape.gradient(mnist_to_usps_generator_loss,
                              mnist_to_usps_generator.trainable_variables),
                mnist_to_usps_generator.trainable_variables))

        usps_to_mnist_generator_optimizer.apply_gradients(
            zip(tape.gradient(usps_to_mnist_generator_loss,
                              usps_to_mnist_generator.trainable_variables),
                usps_to_mnist_generator.trainable_variables))

        mnist_discriminator_optimizer.apply_gradients(
            zip(tape.gradient(mnist_discriminator_loss,
                              mnist_discriminator.trainable_variables),
                mnist_discriminator.trainable_variables))

        usps_discriminator_optimizer.apply_gradients(
            zip(tape.gradient(usps_discriminator_loss,
                              usps_discriminator.trainable_variables),
                usps_discriminator.trainable_variables))

        del tape

        g_mnist_to_usps_loss_sum += float(mnist_to_usps_generator_loss)
        g_usps_to_mnist_loss_sum += float(usps_to_mnist_generator_loss)
        d_mnist_loss_sum += float(mnist_discriminator_loss)
        d_usps_loss_sum += float(usps_discriminator_loss)
        cycle_loss_sum += float(cycle_loss)

    g_mnist_to_usps_loss_mean = g_mnist_to_usps_loss_sum / cycle_steps_per_epoch
    g_usps_to_mnist_loss_mean = g_usps_to_mnist_loss_sum / cycle_steps_per_epoch
    d_mnist_loss_mean = d_mnist_loss_sum / cycle_steps_per_epoch
    d_usps_loss_mean = d_usps_loss_sum / cycle_steps_per_epoch
    cycle_loss_mean = cycle_loss_sum / cycle_steps_per_epoch

    if epoch == 1 or epoch % cycle_display_epoch_interval == 0 or epoch == cycle_epochs:
        print(
            f"Epoch {epoch:03d}/{cycle_epochs} | "
            f"g_mnist_to_usps={g_mnist_to_usps_loss_mean:.4f} | "
            f"g_usps_to_mnist={g_usps_to_mnist_loss_mean:.4f} | "
            f"d_mnist={d_mnist_loss_mean:.4f} | "
            f"d_usps={d_usps_loss_mean:.4f} | "
            f"cycle={cycle_loss_mean:.4f}"
        )
Epoch 001/30 | g_mnist_to_usps=4.5764 | g_usps_to_mnist=4.5477 | d_mnist=0.3153 | d_usps=0.3610 | cycle=3.4605
Epoch 010/30 | g_mnist_to_usps=1.8663 | g_usps_to_mnist=1.8185 | d_mnist=0.1826 | d_usps=0.1649 | cycle=1.2461
Epoch 020/30 | g_mnist_to_usps=1.8382 | g_usps_to_mnist=1.7249 | d_mnist=0.1669 | d_usps=0.1362 | cycle=1.1578
Epoch 030/30 | g_mnist_to_usps=2.4298 | g_usps_to_mnist=1.9249 | d_mnist=0.1462 | d_usps=0.7321 | cycle=1.3113
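For reference, the losses assembled in the loop above can be written compactly. Here $x$ is a real MNIST batch, $y$ a real USPS batch, $G_{XY}$ and $G_{YX}$ the MNIST-to-USPS and USPS-to-MNIST generators, and $D_X$, $D_Y$ the MNIST and USPS discriminators (the $D_X$, $D_Y$ symbols are notation introduced here). Each "mean" runs over the batch and over pixels or discriminator output patches, and the weights 10 and 3 are the ones used in the code:

$$
\begin{aligned}
\mathcal{L}_{\text{cyc}} &= 10\left(\operatorname{mean}\left|x - G_{YX}(G_{XY}(x))\right| + \operatorname{mean}\left|y - G_{XY}(G_{YX}(y))\right|\right) \\
\mathcal{L}_{G_{XY}} &= \operatorname{mean}\left(D_Y(G_{XY}(x)) - 1\right)^2 + \mathcal{L}_{\text{cyc}} + 3\,\operatorname{mean}\left|y - G_{XY}(y)\right| \\
\mathcal{L}_{G_{YX}} &= \operatorname{mean}\left(D_X(G_{YX}(y)) - 1\right)^2 + \mathcal{L}_{\text{cyc}} + 3\,\operatorname{mean}\left|x - G_{YX}(x)\right| \\
\mathcal{L}_{D_X} &= \tfrac{1}{2}\left[\operatorname{mean}\left(D_X(x) - 1\right)^2 + \operatorname{mean}\,D_X(G_{YX}(y))^2\right] \\
\mathcal{L}_{D_Y} &= \tfrac{1}{2}\left[\operatorname{mean}\left(D_Y(y) - 1\right)^2 + \operatorname{mean}\,D_Y(G_{XY}(x))^2\right]
\end{aligned}
$$

Each optimizer only applies the gradient of its own loss to its own network's variables, so the shared $\mathcal{L}_{\text{cyc}}$ term updates both generators.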
In [ ]:
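# Load previously saved CycleGAN weights from Google Drive (Colab path); this replaces the weights from the training run above
weight_dir = "/content/drive/MyDrive/DL_Colab/weights"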

mnist_to_usps_generator.load_weights(weight_dir + "/mnist_to_usps_generator.weights.h5")
usps_to_mnist_generator.load_weights(weight_dir + "/usps_to_mnist_generator.weights.h5")

mnist_discriminator.load_weights(weight_dir + "/mnist_discriminator.weights.h5")
usps_discriminator.load_weights(weight_dir + "/usps_discriminator.weights.h5")

print("Loaded CycleGAN weights")
Loaded CycleGAN weights
In [ ]:
cycle_mnist_sample_indices = np.array([np.where(mnist_test_labels == digit)[0][0] for digit in range(10)])
cycle_usps_sample_indices = np.array([np.where(usps_test_labels == digit)[0][0] for digit in range(10)])

cycle_mnist_sample_images = mnist_test_images_cyclegan[cycle_mnist_sample_indices]
cycle_usps_sample_images = usps_test_images_cyclegan[cycle_usps_sample_indices]

generated_usps_samples = mnist_to_usps_generator.predict(cycle_mnist_sample_images, batch_size=10, verbose=0)
reconstructed_mnist_samples = usps_to_mnist_generator.predict(generated_usps_samples, batch_size=10, verbose=0)

image_rows = [
    (cycle_mnist_sample_images + 1.0) / 2.0,
    (cycle_usps_sample_images + 1.0) / 2.0,
    (generated_usps_samples + 1.0) / 2.0,
    (reconstructed_mnist_samples + 1.0) / 2.0,
]

row_labels = [
    "Original MNIST",
    "Original USPS",
    "Generated USPS",
    "Reconstructed MNIST",
]

fig, axes = plt.subplots(4, 10, figsize=(15, 6))

for row_index in range(4):
    for digit in range(10):
        axes[row_index, digit].imshow(image_rows[row_index][digit, :, :, 0], cmap="gray", vmin=0.0, vmax=1.0)
        axes[row_index, digit].axis("off")

    axes[row_index, 0].text(-0.35, 0.5, row_labels[row_index], transform=axes[row_index, 0].transAxes, ha="right", va="center", fontsize=13)

plt.tight_layout()
plt.show()

  • Original MNIST — the labeled source images used as input
  • Original USPS — the unlabeled target images used for style reference
  • Generated USPS — MNIST images transformed into the visual style of USPS by $G_{XY}$
  • Reconstructed MNIST — the generated USPS images transformed back to the source style by $G_{YX}$

The Generated USPS row shows that CycleGAN transfers the visual style of the target domain onto the source images: the digits take on the thicker, smoother appearance characteristic of USPS while keeping their digit identity. The Reconstructed MNIST row confirms that the cycle-consistency constraint works as intended: the round-trip transformation recovers images that closely resemble the original MNIST inputs, indicating that the semantic content is preserved throughout the transformation.
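Why does the round trip preserve content? A toy illustration with plain NumPy arrays (the mappings below are hypothetical stand-ins, not the trained generators): if the forward mapping only changes "style" in an invertible way, a backward mapping can recover the input exactly and the cycle term vanishes; if the forward mapping throws content away, for example by sending every image to the same output, no backward mapping can recover the inputs and the cycle term stays large.

```python
import numpy as np


def cycle_consistency_l1(x, forward, backward):
    """Mean absolute round-trip error: one direction of the loop's cycle term, before its weight of 10."""
    return float(np.mean(np.abs(x - backward(forward(x)))))


def restyle(images):
    """Hypothetical invertible 'style' change: a contrast and brightness shift."""
    return 0.5 * images + 0.2


def unstyle(images):
    """Exact inverse of restyle."""
    return (images - 0.2) / 0.5


def collapse(images):
    """Hypothetical content-destroying translator: every image becomes the same blank image."""
    return np.zeros_like(images)


rng = np.random.default_rng(0)
x = rng.uniform(-1.0, 1.0, size=(4, 28, 28, 1))  # toy batch in the [-1, 1] range used for CycleGAN

round_trip_error = cycle_consistency_l1(x, restyle, unstyle)
collapsed_error = cycle_consistency_l1(x, collapse, unstyle)

print(round_trip_error)  # essentially 0 (floating-point error only)
print(collapsed_error)   # clearly positive: the inputs cannot be recovered
```

Because collapse maps every input to the same image, any backward mapping must return a single fixed image, which cannot match all of the distinct inputs; the cycle term is what penalizes such content loss in CycleGAN.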


In [ ]:
cycle_finetune_batch_size = 128
cycle_finetune_epochs = 20

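# Translate every labeled MNIST training image into USPS style with the MNIST-to-USPS generator (output range [-1, 1])
generated_usps_train_images = mnist_to_usps_generator.predict(
    mnist_train_images_cyclegan,
    batch_size=cycle_finetune_batch_size,
    verbose=1)

# Rescale from [-1, 1] back to [0, 1], the input range expected by the classifier
generated_usps_train_images = (generated_usps_train_images + 1.0) / 2.0

# Start from an exact copy of the pretrained source classifier: clone the architecture, build it, copy the weights
cycle_finetune_classifier = tf.keras.models.clone_model(source_classifier)
cycle_finetune_classifier(np.zeros((1, 28, 28, 1), dtype="float32"), training=False)
cycle_finetune_classifier.set_weights(source_classifier.get_weights())

# Fine-tune with a small learning rate on the translated images, reusing the original MNIST labels
cycle_finetune_classifier.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"])

cycle_finetune_classifier.fit(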
    generated_usps_train_images,
    mnist_train_labels,
    batch_size=cycle_finetune_batch_size,
    epochs=cycle_finetune_epochs,
    verbose=1)

cycle_mnist_loss, cycle_mnist_acc = cycle_finetune_classifier.evaluate(
    mnist_test_images,
    mnist_test_labels,
    batch_size=cycle_finetune_batch_size,
    verbose=0)

cycle_usps_loss, cycle_usps_acc = cycle_finetune_classifier.evaluate(
    usps_test_images,
    usps_test_labels,
    batch_size=cycle_finetune_batch_size,
    verbose=0)

print("[Source-only]")
print(f"MNIST accuracy: {source_mnist_acc:.4f}")
print(f"USPS accuracy : {source_usps_acc:.4f}")

print("\n[CycleGAN]")
print(f"MNIST accuracy: {cycle_mnist_acc:.4f}")
print(f"USPS accuracy : {cycle_usps_acc:.4f}")

print(f"\nUSPS improvement over source-only: {cycle_usps_acc - source_usps_acc:+.4f}")
469/469 ━━━━━━━━━━━━━━━━━━━━ 6s 11ms/step
Epoch 1/20
469/469 ━━━━━━━━━━━━━━━━━━━━ 7s 10ms/step - accuracy: 0.9615 - loss: 0.1307
Epoch 2/20
469/469 ━━━━━━━━━━━━━━━━━━━━ 4s 8ms/step - accuracy: 0.9765 - loss: 0.0769
Epoch 3/20
469/469 ━━━━━━━━━━━━━━━━━━━━ 3s 6ms/step - accuracy: 0.9801 - loss: 0.0643
Epoch 4/20
469/469 ━━━━━━━━━━━━━━━━━━━━ 3s 6ms/step - accuracy: 0.9822 - loss: 0.0577
Epoch 5/20
469/469 ━━━━━━━━━━━━━━━━━━━━ 3s 6ms/step - accuracy: 0.9833 - loss: 0.0534
Epoch 6/20
469/469 ━━━━━━━━━━━━━━━━━━━━ 4s 8ms/step - accuracy: 0.9843 - loss: 0.0502
Epoch 7/20
469/469 ━━━━━━━━━━━━━━━━━━━━ 5s 12ms/step - accuracy: 0.9851 - loss: 0.0477
Epoch 8/20
469/469 ━━━━━━━━━━━━━━━━━━━━ 3s 6ms/step - accuracy: 0.9858 - loss: 0.0456
Epoch 9/20
469/469 ━━━━━━━━━━━━━━━━━━━━ 3s 6ms/step - accuracy: 0.9864 - loss: 0.0438
Epoch 10/20
469/469 ━━━━━━━━━━━━━━━━━━━━ 3s 7ms/step - accuracy: 0.9868 - loss: 0.0423
Epoch 11/20
469/469 ━━━━━━━━━━━━━━━━━━━━ 3s 7ms/step - accuracy: 0.9873 - loss: 0.0409
Epoch 12/20
469/469 ━━━━━━━━━━━━━━━━━━━━ 3s 6ms/step - accuracy: 0.9877 - loss: 0.0396
Epoch 13/20
469/469 ━━━━━━━━━━━━━━━━━━━━ 3s 6ms/step - accuracy: 0.9880 - loss: 0.0385
Epoch 14/20
469/469 ━━━━━━━━━━━━━━━━━━━━ 3s 7ms/step - accuracy: 0.9883 - loss: 0.0375
Epoch 15/20
469/469 ━━━━━━━━━━━━━━━━━━━━ 3s 7ms/step - accuracy: 0.9887 - loss: 0.0365
Epoch 16/20
469/469 ━━━━━━━━━━━━━━━━━━━━ 3s 6ms/step - accuracy: 0.9891 - loss: 0.0356
Epoch 17/20
469/469 ━━━━━━━━━━━━━━━━━━━━ 3s 6ms/step - accuracy: 0.9893 - loss: 0.0347
Epoch 18/20
469/469 ━━━━━━━━━━━━━━━━━━━━ 3s 7ms/step - accuracy: 0.9896 - loss: 0.0339
Epoch 19/20
469/469 ━━━━━━━━━━━━━━━━━━━━ 4s 7ms/step - accuracy: 0.9899 - loss: 0.0331
Epoch 20/20
469/469 ━━━━━━━━━━━━━━━━━━━━ 3s 7ms/step - accuracy: 0.9900 - loss: 0.0324
[Source-only]
MNIST accuracy: 0.9856
USPS accuracy : 0.7414

[CycleGAN]
MNIST accuracy: 0.9880
USPS accuracy : 0.8979

USPS improvement over source-only: +0.1565

The generated USPS images, paired with the original MNIST labels, are then used to fine-tune a copy of the source classifier. The classification results demonstrate the effectiveness of this approach:

Training on the target-like generated images raises USPS accuracy from 74.1% to 89.8%, a gain of more than 15 percentage points, while MNIST accuracy stays comparable (98.6% to 98.8%). This indicates that pixel-level domain adaptation via CycleGAN can substantially narrow the domain gap without requiring any labeled target data.



3. Adversarial Domain Adaptation

The generative approach adapts the data to match the target domain at the pixel level. Adversarial domain adaptation works differently: rather than transforming the input images, it aligns the feature representations learned by the network.

The core idea is to learn features that are discriminative for the task — capturing information useful for classification — but invariant to the domain — containing no information about whether the input came from the source or the target. If such features can be learned, a classifier trained on labeled source features will generalize directly to target features, since the two are indistinguishable in the learned feature space.
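The idea that a source classifier transfers once the domains are indistinguishable in feature space suggests a simple check. The sketch below is an illustrative assumption, not part of the notebook: a hypothetical helper, domain_probe_accuracy, fits a nearest-centroid domain classifier on half of the pooled features and reports its accuracy on the other half. Accuracy near 0.5 means the domains overlap; accuracy near 1.0 means domain information is still present. A score near 0.5 does not by itself guarantee that each digit class is matched to the same class across domains.

```python
import numpy as np


def domain_probe_accuracy(source_features, target_features, seed=0):
    """Held-out accuracy of a nearest-centroid classifier that predicts the domain (0 = source, 1 = target)."""
    rng = np.random.default_rng(seed)
    features = np.vstack([source_features, target_features])
    domains = np.concatenate([np.zeros(len(source_features)), np.ones(len(target_features))])

    # Shuffle, then split the pooled features into a fitting half and an evaluation half
    order = rng.permutation(len(features))
    features, domains = features[order], domains[order]
    half = len(features) // 2
    train_x, train_d = features[:half], domains[:half]
    test_x, test_d = features[half:], domains[half:]

    # Fit: one centroid per domain on the fitting half
    centroids = np.stack([train_x[train_d == 0].mean(axis=0), train_x[train_d == 1].mean(axis=0)])

    # Predict: the domain whose centroid is closest in Euclidean distance
    distances = np.linalg.norm(test_x[:, None, :] - centroids[None, :, :], axis=2)
    predictions = distances.argmin(axis=1)
    return float(np.mean(predictions == test_d))


# Synthetic 16-dimensional "features"
rng = np.random.default_rng(1)
source = rng.normal(0.0, 1.0, size=(500, 16))
shifted_target = rng.normal(1.5, 1.0, size=(500, 16))  # carries a domain-specific offset
aligned_target = rng.normal(0.0, 1.0, size=(500, 16))  # same distribution as the source

shifted_acc = domain_probe_accuracy(source, shifted_target)
aligned_acc = domain_probe_accuracy(source, aligned_target)
print(shifted_acc)  # close to 1.0: domains are easy to separate
print(aligned_acc)  # close to 0.5: domains are indistinguishable to this probe
```

The same function could be applied to the 128-dimensional features that the notebook extracts later (for example dann_mnist_feature and dann_usps_feature); a nearest-centroid probe is much weaker than a trained domain classifier, so it only gives a rough, optimistic picture of invariance.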




The figure illustrates the framework. Both source and target data are passed through a shared feature extractor, producing source features and target features respectively. A domain classifier is then trained to distinguish between the two — playing a role analogous to the discriminator in a GAN. At the same time, the feature extractor is trained to fool the domain classifier — producing features that the domain classifier cannot tell apart, regardless of whether they came from the source or the target domain.

This adversarial objective drives the feature extractor toward domain-invariant representations. The training process alternates between two competing objectives:

  • The domain classifier is updated to correctly identify whether a feature came from the source or the target domain
  • The feature extractor is updated to confuse the domain classifier — making source and target features as similar as possible in the learned feature space
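The two competing objectives can be made concrete through the domain classifier's cross-entropy. In this NumPy sketch (the helper domain_cross_entropy and the probability values are made up for illustration), a classifier that separates the domains confidently has a small loss, while one that can only output 0.5 for every feature, which is the best it can do when source and target features follow the same distribution and the batches are balanced, has loss $\ln 2 \approx 0.693$. The notebook's two-way softmax with sparse categorical cross-entropy is equivalent to this binary form.

```python
import numpy as np


def domain_cross_entropy(p_target, is_target):
    """Binary cross-entropy of a domain classifier that outputs P(target | feature)."""
    p = np.clip(p_target, 1e-8, 1 - 1e-8)  # avoid log(0)
    return float(-np.mean(is_target * np.log(p) + (1 - is_target) * np.log(1 - p)))


is_target = np.array([0, 0, 0, 0, 1, 1, 1, 1])  # a balanced batch: four source, four target features

confident = np.array([0.05, 0.1, 0.1, 0.05, 0.9, 0.95, 0.9, 0.95])  # domains easy to tell apart
confused = np.full(8, 0.5)                                           # domains indistinguishable

confident_loss = domain_cross_entropy(confident, is_target)
confused_loss = domain_cross_entropy(confused, is_target)

print(confident_loss)  # small: the domain classifier is winning
print(confused_loss)   # ln(2), about 0.6931: the feature extractor is winning
```

This gives a reference point for the domain loss printed in the DANN training log later in this section: between the first and last logged epochs, domain accuracy falls from about 0.80 to 0.67 while the loss rises from about 0.42 to about 0.60, moving toward but not reaching $\ln 2$.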

The underlying intuition mirrors the GAN framework closely: just as the GAN generator learns to produce samples indistinguishable from real data, the feature extractor here learns to produce representations indistinguishable across domains. The domain classifier serves as the learned, adaptive signal that drives this alignment.


3.1. DANN: Domain-Adversarial Neural Network

DANN is one of the most representative and widely studied methods in adversarial domain adaptation. The figure illustrates its architecture, which consists of three components trained jointly.

The feature extractor takes both source and target data as input and produces a shared feature representation. The extracted features are then passed to two separate branches:

  • The label classifier takes source features and is trained to correctly predict the class label using the supervised source labels. This branch ensures that the learned features remain discriminative for the task — preserving class-relevant information.
  • The domain classifier takes features from both domains and is trained to distinguish whether a feature came from the source or the target domain. This branch drives the adversarial alignment between the two domains.

The key mechanism in DANN is the gradient reversal layer, represented by the $-\lambda \nabla \mathcal{L}_{\text{domain}}$ arrow in the figure. In the forward pass it acts as the identity; during backpropagation, it multiplies the gradient coming from the domain classifier by $-\lambda$ before passing it to the feature extractor. As a result, while the domain classifier is updated to better distinguish the two domains, the feature extractor is simultaneously updated in the opposite direction, making the two domains harder to distinguish.
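A gradient reversal layer can be written in a few lines with TensorFlow's tf.custom_gradient. This is a minimal sketch assuming TensorFlow 2.x in eager mode; the function name gradient_reversal is hypothetical and not part of TensorFlow or of this notebook.

```python
import numpy as np
import tensorflow as tf


def gradient_reversal(x, lam=1.0):
    """Identity in the forward pass; multiplies the incoming gradient by -lam in the backward pass."""

    @tf.custom_gradient
    def _reverse(inputs):
        def grad(upstream):
            # Flip the sign of (and scale) the gradient flowing back toward the feature extractor
            return -lam * upstream

        return tf.identity(inputs), grad

    return _reverse(x)


# Usage: the forward value is unchanged, but d(sum(y))/dx is -lam everywhere
x = tf.constant([[1.0, 2.0, 3.0]])
with tf.GradientTape() as tape:
    tape.watch(x)
    y = gradient_reversal(x, lam=0.5)
    loss = tf.reduce_sum(y)

grad_x = tape.gradient(loss, x)
print(y.numpy())       # [[1. 2. 3.]]
print(grad_x.numpy())  # [[-0.5 -0.5 -0.5]]
```

In a DANN model such a function would sit between the feature extractor and the domain classifier, e.g. dann_domain_classifier(gradient_reversal(features, lam)), so that a single optimizer minimizing class loss plus domain loss trains the domain classifier normally while pushing the feature extractor the other way. The training loop in this notebook reaches the same sign flip without a reversal layer: it updates the domain classifier in a separate step and subtracts domain_weight * domain_confusion_loss in the feature extractor's objective.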

The total training objective combines the two losses:


$$\mathcal{L} = \mathcal{L}_{\text{class}}^{\text{source}} - \lambda \mathcal{L}_{\text{domain}}$$


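The feature extractor is trained to minimize the classification loss while maximizing the domain classification loss, that is, to confuse the domain classifier; the domain classifier itself is trained to minimize $\mathcal{L}_{\text{domain}}$. The hyperparameter $\lambda$ controls the trade-off between preserving class-discriminative information and suppressing domain-specific information. The result is a feature space in which source and target representations are aligned, so that the label classifier trained on source features can be applied directly to target features. The implementation below sets $\lambda$ (domain_weight) to 0.1 and, beyond this objective, adds a target entropy term weighted by target_entropy_weight = 0.1, which encourages confident predictions on the unlabeled USPS samples.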
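The objective can also be written as a saddle-point problem, in the form usually given for DANN. Here $\theta_f$, $\theta_y$, and $\theta_d$ are notation introduced for the parameters of the feature extractor, label classifier, and domain classifier, and $\mu$ is a learning rate:

$$
(\hat{\theta}_f, \hat{\theta}_y) = \arg\min_{\theta_f,\, \theta_y} \mathcal{L}(\theta_f, \theta_y, \hat{\theta}_d), \qquad \hat{\theta}_d = \arg\max_{\theta_d} \mathcal{L}(\hat{\theta}_f, \hat{\theta}_y, \theta_d)
$$

Since only $-\lambda \mathcal{L}_{\text{domain}}$ depends on $\theta_d$, maximizing over $\theta_d$ is the same as minimizing $\mathcal{L}_{\text{domain}}$. One gradient step then reads:

$$
\begin{aligned}
\theta_f &\leftarrow \theta_f - \mu \left( \frac{\partial \mathcal{L}_{\text{class}}^{\text{source}}}{\partial \theta_f} - \lambda \frac{\partial \mathcal{L}_{\text{domain}}}{\partial \theta_f} \right) \\
\theta_y &\leftarrow \theta_y - \mu \frac{\partial \mathcal{L}_{\text{class}}^{\text{source}}}{\partial \theta_y} \\
\theta_d &\leftarrow \theta_d - \mu \lambda \frac{\partial \mathcal{L}_{\text{domain}}}{\partial \theta_d}
\end{aligned}
$$

The $-\lambda$ in the first line is exactly what the gradient reversal layer injects. In the notebook's loop the domain classifier has its own optimizer and learning rate, which plays the role of $\mu\lambda$ in the third line.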



In [ ]:
dann_feature_extractor = tf.keras.models.Sequential([
    tf.keras.layers.Input(shape=(28, 28, 1)),
    tf.keras.layers.Conv2D(32, kernel_size=7, activation="relu", padding="same"),
    tf.keras.layers.MaxPool2D(pool_size=2),
    tf.keras.layers.Conv2D(64, kernel_size=7, activation="relu", padding="same"),
    tf.keras.layers.MaxPool2D(pool_size=2),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation="relu")], name="dann_feature_extractor")

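# Initialize the two conv layers and the 128-unit dense layer from the pretrained source classifier (indices 0, 2, 5 in both models)
dann_feature_extractor.layers[0].set_weights(source_classifier.layers[0].get_weights())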
dann_feature_extractor.layers[2].set_weights(source_classifier.layers[2].get_weights())
dann_feature_extractor.layers[5].set_weights(source_classifier.layers[5].get_weights())
In [ ]:
dann_feature_extractor.summary()
Model: "dann_feature_extractor"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                    ┃ Output Shape           ┃       Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ conv2d_32 (Conv2D)              │ (None, 28, 28, 32)     │         1,600 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ max_pooling2d_4 (MaxPooling2D)  │ (None, 14, 14, 32)     │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_33 (Conv2D)              │ (None, 14, 14, 64)     │       100,416 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ max_pooling2d_5 (MaxPooling2D)  │ (None, 7, 7, 64)       │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ flatten_2 (Flatten)             │ (None, 3136)           │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_4 (Dense)                 │ (None, 128)            │       401,536 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
 Total params: 503,552 (1.92 MB)
 Trainable params: 503,552 (1.92 MB)
 Non-trainable params: 0 (0.00 B)
In [ ]:
dann_label_classifier = tf.keras.models.Sequential([
    tf.keras.layers.Input(shape=(128,)),
    tf.keras.layers.Dense(10, activation="softmax")], name="dann_label_classifier")

dann_label_classifier.layers[0].set_weights(source_classifier.layers[6].get_weights())
In [ ]:
dann_label_classifier.summary()
Model: "dann_label_classifier"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                    ┃ Output Shape           ┃       Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ dense_5 (Dense)                 │ (None, 10)             │         1,290 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
 Total params: 1,290 (5.04 KB)
 Trainable params: 1,290 (5.04 KB)
 Non-trainable params: 0 (0.00 B)
In [ ]:
dann_domain_classifier = tf.keras.models.Sequential([
    tf.keras.layers.Input(shape=(128,)),
    tf.keras.layers.Dense(64, activation="relu"),
    tf.keras.layers.Dense(2, activation="softmax")], name="dann_domain_classifier")
In [ ]:
dann_domain_classifier.summary()
Model: "dann_domain_classifier"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                    ┃ Output Shape           ┃       Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ dense_6 (Dense)                 │ (None, 64)             │         8,256 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_7 (Dense)                 │ (None, 2)              │           130 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
 Total params: 8,386 (32.76 KB)
 Trainable params: 8,386 (32.76 KB)
 Non-trainable params: 0 (0.00 B)
In [ ]:
dann_classifier_model = tf.keras.models.Sequential([
    dann_feature_extractor,
    dann_label_classifier], name="dann_classifier_model")

dann_feature_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
dann_label_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4)
dann_domain_optimizer = tf.keras.optimizers.Adam(learning_rate=3e-3)

class_loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
domain_loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
In [ ]:
dann_classifier_model.summary()
Model: "dann_classifier_model"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                    ┃ Output Shape           ┃       Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ dann_feature_extractor          │ (None, 128)            │       503,552 │
│ (Sequential)                    │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dann_label_classifier           │ (None, 10)             │         1,290 │
│ (Sequential)                    │                        │               │
└─────────────────────────────────┴────────────────────────┴───────────────┘
 Total params: 504,842 (1.93 MB)
 Trainable params: 504,842 (1.93 MB)
 Non-trainable params: 0 (0.00 B)
In [ ]:
dann_batch_size = 128

dann_source_dataset = tf.data.Dataset.from_tensor_slices((mnist_train_images, mnist_train_labels))
dann_source_dataset = dann_source_dataset.shuffle(len(mnist_train_images))
dann_source_dataset = dann_source_dataset.repeat()
dann_source_dataset = dann_source_dataset.batch(dann_batch_size)

dann_target_dataset = tf.data.Dataset.from_tensor_slices(usps_train_images)
dann_target_dataset = dann_target_dataset.shuffle(len(usps_train_images))
dann_target_dataset = dann_target_dataset.repeat()
dann_target_dataset = dann_target_dataset.batch(dann_batch_size)

dann_source_iter = iter(dann_source_dataset)
dann_target_iter = iter(dann_target_dataset)

dann_steps_per_epoch = len(usps_train_images) // dann_batch_size

print("DANN steps per epoch:", dann_steps_per_epoch)
DANN steps per epoch: 56
In [ ]:
dann_epochs = 50
dann_display_epoch_interval = 10

domain_weight = 0.1
target_entropy_weight = 0.1

dann_history = []

for epoch in range(1, dann_epochs + 1):
    total_loss_sum = 0.0
    class_loss_sum = 0.0
    domain_loss_sum = 0.0
    entropy_loss_sum = 0.0

    domain_correct = 0.0
    domain_count = 0.0

    for _ in range(dann_steps_per_epoch):
        source_x, source_y = next(dann_source_iter)
        target_x = next(dann_target_iter)

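        # Domain labels: 0 = source (MNIST), 1 = target (USPS); batches are always full because repeat() precedes batch()
        domain_x = tf.concat([source_x, target_x], axis=0)
        domain_y = tf.concat([tf.zeros((dann_batch_size,), dtype=tf.int64), tf.ones((dann_batch_size,), dtype=tf.int64)], axis=0)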

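        # Step 1: update the domain classifier on detached features, leaving the feature extractor untouched
        with tf.GradientTape() as domain_tape:
            domain_feature = dann_feature_extractor(domain_x, training=True)
            domain_feature = tf.stop_gradient(domain_feature)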

            domain_prob = dann_domain_classifier(domain_feature, training=True)
            domain_classifier_loss = domain_loss_fn(domain_y, domain_prob)

        dann_domain_optimizer.apply_gradients(
            zip(domain_tape.gradient(domain_classifier_loss,
                                     dann_domain_classifier.trainable_variables),
                dann_domain_classifier.trainable_variables))

        # Step 2: update the feature extractor and label classifier with the domain classifier held fixed
        with tf.GradientTape(persistent=True) as feature_tape:
            source_feature = dann_feature_extractor(source_x, training=True)
            target_feature = dann_feature_extractor(target_x, training=True)

            source_class_prob = dann_label_classifier(source_feature, training=True)
            target_class_prob = dann_label_classifier(target_feature, training=True)

            source_domain_prob = dann_domain_classifier(source_feature, training=False)
            target_domain_prob = dann_domain_classifier(target_feature, training=False)

            domain_prob_for_feature = tf.concat([source_domain_prob, target_domain_prob], axis=0)

            class_loss = class_loss_fn(source_y, source_class_prob)

            domain_confusion_loss = domain_loss_fn(domain_y, domain_prob_for_feature)

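            # Mean entropy of target predictions; minimizing it encourages confident predictions on unlabeled USPS
            entropy_loss = -tf.reduce_mean(tf.reduce_sum(target_class_prob * tf.math.log(target_class_prob + 1e-8), axis=1))

            # Subtracting the weighted domain loss makes the feature extractor maximize it (the sign flip of gradient reversal)
            total_loss = (class_loss - domain_weight * domain_confusion_loss + target_entropy_weight * entropy_loss)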

        dann_feature_optimizer.apply_gradients(
            zip(feature_tape.gradient(total_loss,
                                      dann_feature_extractor.trainable_variables),
                dann_feature_extractor.trainable_variables))

        dann_label_optimizer.apply_gradients(
            zip(feature_tape.gradient(total_loss,
                                      dann_label_classifier.trainable_variables),
                dann_label_classifier.trainable_variables))

        del feature_tape

        domain_pred = tf.argmax(domain_prob, axis=1, output_type=tf.int64)
        domain_correct += float(tf.reduce_sum(tf.cast(domain_pred == domain_y, tf.float32)))
        domain_count += float(tf.shape(domain_y)[0])

        total_loss_sum += float(total_loss)
        class_loss_sum += float(class_loss)
        domain_loss_sum += float(domain_classifier_loss)
        entropy_loss_sum += float(entropy_loss)

    total_loss_mean = total_loss_sum / dann_steps_per_epoch
    class_loss_mean = class_loss_sum / dann_steps_per_epoch
    domain_loss_mean = domain_loss_sum / dann_steps_per_epoch
    entropy_loss_mean = entropy_loss_sum / dann_steps_per_epoch
    domain_acc = domain_correct / domain_count

    dann_mnist_prob = dann_classifier_model.predict(mnist_test_images, batch_size=dann_batch_size, verbose=0)
    dann_usps_prob = dann_classifier_model.predict(usps_test_images, batch_size=dann_batch_size, verbose=0)

    dann_mnist_pred = np.argmax(dann_mnist_prob, axis=1)
    dann_usps_pred = np.argmax(dann_usps_prob, axis=1)

    dann_mnist_acc = np.mean(dann_mnist_pred == mnist_test_labels.reshape(-1))
    dann_usps_acc = np.mean(dann_usps_pred == usps_test_labels.reshape(-1))

    dann_history.append({
        "epoch": epoch,
        "total_loss": total_loss_mean,
        "class_loss": class_loss_mean,
        "domain_loss": domain_loss_mean,
        "entropy_loss": entropy_loss_mean,
        "domain_acc": domain_acc,
        "mnist_acc": dann_mnist_acc
    })

    if epoch == 1 or epoch % dann_display_epoch_interval == 0 or epoch == dann_epochs:
        print(
            f"Epoch {epoch:03d}/{dann_epochs} | "
            f"total={total_loss_mean:.4f} | "
            f"class={class_loss_mean:.4f} | "
            f"domain={domain_loss_mean:.4f} | "
            f"entropy={entropy_loss_mean:.4f} | "
            f"domain_acc={domain_acc:.4f} | "
            f"MNIST_acc={dann_mnist_acc:.4f} | "
        )
Epoch 001/50 | total=0.0118 | class=0.0201 | domain=0.4231 | entropy=0.3069 | domain_acc=0.8012 | MNIST_acc=0.9924 | 
Epoch 010/50 | total=-0.0038 | class=0.0127 | domain=0.2393 | entropy=0.0571 | domain_acc=0.9024 | MNIST_acc=0.9917 | 
Epoch 020/50 | total=-0.0242 | class=0.0084 | domain=0.3823 | entropy=0.0372 | domain_acc=0.8212 | MNIST_acc=0.9920 | 
Epoch 030/50 | total=-0.0379 | class=0.0107 | domain=0.5744 | entropy=0.0347 | domain_acc=0.7525 | MNIST_acc=0.9918 | 
Epoch 040/50 | total=-0.0390 | class=0.0114 | domain=0.5445 | entropy=0.0264 | domain_acc=0.7003 | MNIST_acc=0.9924 | 
Epoch 050/50 | total=-0.0456 | class=0.0094 | domain=0.5986 | entropy=0.0272 | domain_acc=0.6713 | MNIST_acc=0.9917 | 
In [ ]:
dann_mnist_prob = dann_classifier_model.predict(
    mnist_test_images,
    batch_size=dann_batch_size,
    verbose=0)

dann_usps_prob = dann_classifier_model.predict(
    usps_test_images,
    batch_size=dann_batch_size,
    verbose=0)

dann_mnist_pred = np.argmax(dann_mnist_prob, axis=1)
dann_usps_pred = np.argmax(dann_usps_prob, axis=1)

dann_mnist_acc = np.mean(dann_mnist_pred == mnist_test_labels.reshape(-1))
dann_usps_acc = np.mean(dann_usps_pred == usps_test_labels.reshape(-1))

print("[Source-only]")
print(f"MNIST accuracy: {source_mnist_acc:.4f}")
print(f"USPS accuracy : {source_usps_acc:.4f}")

print("\n[DANN]")
print(f"MNIST accuracy: {dann_mnist_acc:.4f}")
print(f"USPS accuracy : {dann_usps_acc:.4f}")

print(f"\nUSPS improvement over source-only: {dann_usps_acc - source_usps_acc:+.4f}")
[Source-only]
MNIST accuracy: 0.9856
USPS accuracy : 0.7414

[DANN]
MNIST accuracy: 0.9917
USPS accuracy : 0.9427

USPS improvement over source-only: +0.2013

The classification results demonstrate the effectiveness of DANN compared with both the source-only baseline and the generative approach. DANN reaches a USPS accuracy of 94.3%, about 20 percentage points above the source-only baseline (74.1%) and about 4.5 points above the CycleGAN-based approach (89.8%), while MNIST accuracy stays at 99.2%.

One plausible explanation for DANN's advantage over the generative approach is where adaptation takes place. CycleGAN operates at the pixel level: it translates the raw input images into the style of the target domain, and a classifier is then fine-tuned on the translated images. DANN operates at the feature level: it aligns the learned representations of the two domains inside the network. Feature-level alignment is arguably more directly tied to the classification objective, because the classification loss and the adversarial domain loss act on the same feature space during training. Note also that the DANN loop above adds a target entropy term, so the two methods differ in more than the level of alignment, and the comparison rests on a single run of each.


In [ ]:
selected_digits = [1, 2, 7]
plot_points_per_digit = 100

source_classifier(np.zeros((1, 28, 28, 1), dtype="float32"), training=False)

source_feature_extractor_for_plot = tf.keras.Model(
    inputs=source_classifier.inputs,
    outputs=source_classifier.layers[-2].output)

pretrain_mnist_feature = source_feature_extractor_for_plot.predict(
    mnist_test_images,
    batch_size=dann_batch_size,
    verbose=0)

pretrain_usps_feature = source_feature_extractor_for_plot.predict(
    usps_test_images,
    batch_size=dann_batch_size,
    verbose=0)

dann_mnist_feature = dann_feature_extractor.predict(
    mnist_test_images,
    batch_size=dann_batch_size,
    verbose=0)

dann_usps_feature = dann_feature_extractor.predict(
    usps_test_images,
    batch_size=dann_batch_size,
    verbose=0)

pretrain_feature_2d = PCA(n_components=2).fit_transform(np.vstack([pretrain_mnist_feature, pretrain_usps_feature]))
dann_feature_2d = PCA(n_components=2).fit_transform(np.vstack([dann_mnist_feature, dann_usps_feature]))

pretrain_mnist_2d = pretrain_feature_2d[:len(pretrain_mnist_feature)]
pretrain_usps_2d = pretrain_feature_2d[len(pretrain_mnist_feature):]

dann_mnist_2d = dann_feature_2d[:len(dann_mnist_feature)]
dann_usps_2d = dann_feature_2d[len(dann_mnist_feature):]

mnist_plot_indices = []
usps_plot_indices = []

for digit in selected_digits:
    mnist_idx = np.where(mnist_test_labels == digit)[0]
    usps_idx = np.where(usps_test_labels == digit)[0]

    mnist_plot_indices.extend(mnist_idx[:plot_points_per_digit])
    usps_plot_indices.extend(usps_idx[:plot_points_per_digit])

mnist_plot_indices = np.array(mnist_plot_indices)
usps_plot_indices = np.array(usps_plot_indices)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

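for color_index, digit in enumerate(selected_digits):
    # Flatten the labels so the boolean mask is 1-D even if the labels are stored as (N, 1).
    mnist_idx = mnist_plot_indices[mnist_test_labels.reshape(-1)[mnist_plot_indices] == digit]
    usps_idx = usps_plot_indices[usps_test_labels.reshape(-1)[usps_plot_indices] == digit]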
    color = f"C{color_index}"

    axes[0].scatter(pretrain_mnist_2d[mnist_idx, 0], pretrain_mnist_2d[mnist_idx, 1], s=18, alpha=0.35, c=color, marker="o")
    axes[0].scatter(pretrain_usps_2d[usps_idx, 0], pretrain_usps_2d[usps_idx, 1], s=28, alpha=0.80, c=color, marker="x")
    axes[1].scatter(dann_mnist_2d[mnist_idx, 0], dann_mnist_2d[mnist_idx, 1], s=18, alpha=0.35, c=color, marker="o")
    axes[1].scatter(dann_usps_2d[usps_idx, 0], dann_usps_2d[usps_idx, 1], s=28, alpha=0.80, c=color, marker="x")

axes[0].set_title("Source-only Feature Space", fontsize=20)
axes[1].set_title("DANN Feature Space", fontsize=20)

for ax in axes:
    ax.set_xlabel("PC1", fontsize=17)
    ax.set_ylabel("PC2", fontsize=17)
    ax.tick_params(axis="both", labelsize=15)
    ax.grid(alpha=0.3)

legend_handles = []
for color_index, digit in enumerate(selected_digits):
    color = f"C{color_index}"

    legend_handles.append(Line2D([0], [0], marker="o", color="w", markerfacecolor=color, markeredgecolor=color, markersize=8, label=f"MNIST {digit}"))
    legend_handles.append(Line2D([0], [0], marker="x", color=color, markersize=8, linewidth=0, label=f"USPS {digit}"))

fig.legend(handles=legend_handles, loc="center left", bbox_to_anchor=(0.98, 0.5), fontsize=12)

plt.tight_layout(rect=[0, 0, 0.88, 1])
plt.show()

The figure provides a visual explanation of why DANN outperforms the source-only baseline: it shows the learned feature spaces of the two models, each projected onto the first two principal components of that model's combined MNIST and USPS features.

In the source-only feature space (left), the MNIST samples (circles) and USPS samples (crosses) of the same digit class are clearly separated from each other. Because the USPS features fall in regions away from the MNIST features the classifier was trained on, the decision boundary learned from the source domain does not transfer well to the target domain.

In the DANN feature space (right), the picture is strikingly different. The MNIST and USPS samples of the same digit class are now interleaved and occupy overlapping regions: MNIST 1 and USPS 1 cluster together, as do MNIST 2 and USPS 2, and MNIST 7 and USPS 7. The features are now organized by class rather than by domain, indicating that the adversarial training has suppressed domain-specific information while preserving class-discriminative structure.

This visualization directly illustrates the goal of adversarial domain adaptation: to learn a feature space in which the source and target distributions are aligned, so that a classifier trained on labeled source features generalizes naturally to unlabeled target features.
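The plot gives a qualitative picture in two dimensions. As a hypothetical quantitative complement (not part of the original notebook), one could compare, for each digit, the distance between its MNIST and USPS centroids with the distance to the nearest other-class MNIST centroid; a ratio well below 1 suggests that the two domains overlap class by class. The function name `centroid_alignment_report` is illustrative, and target labels are used here only for diagnosis, never for training. A minimal NumPy sketch:

```python
import numpy as np

def centroid_alignment_report(src_feat, src_labels, tgt_feat, tgt_labels, classes):
    """For each class, return (source-target centroid gap) / (distance to nearest other source centroid).

    Requires at least two classes. Smaller values mean the domains overlap more, class by class.
    """
    src_centroids = {c: src_feat[src_labels == c].mean(axis=0) for c in classes}
    report = {}
    for c in classes:
        tgt_centroid = tgt_feat[tgt_labels == c].mean(axis=0)
        gap = np.linalg.norm(src_centroids[c] - tgt_centroid)
        # Scale: distance from this class centroid to the closest other-class source centroid.
        separation = min(np.linalg.norm(src_centroids[c] - src_centroids[k]) for k in classes if k != c)
        report[c] = float(gap / separation)
    return report

# Toy example: the target is the source shifted by (3, 4), so each gap is 5,
# while the two source class centroids are 10 apart.
src = np.array([[0.0, 0.0], [0.0, 0.0], [10.0, 0.0], [10.0, 0.0]])
labels = np.array([0, 0, 1, 1])
tgt = src + np.array([3.0, 4.0])
print(centroid_alignment_report(src, labels, tgt, labels, classes=[0, 1]))  # {0: 0.5, 1: 0.5}
```

With the arrays from the cell above, a call such as `centroid_alignment_report(dann_mnist_feature, mnist_test_labels.reshape(-1), dann_usps_feature, usps_test_labels.reshape(-1), classes=[1, 2, 7])` would score the DANN features, and the same call on `pretrain_mnist_feature` and `pretrain_usps_feature` gives the source-only baseline. Unlike the PCA plot, this works in the full feature space rather than a 2-D projection.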



4. Discrepancy-based Domain Adaptation

4.1. CORAL: CORrelation ALignment

While adversarial domain adaptation aligns source and target distributions through a learned domain classifier, discrepancy-based methods take a more direct approach — they explicitly measure and minimize a statistical distance between the source and target feature distributions.

CORAL is one of the simplest and most interpretable methods in this category. If we assume that the source and target feature distributions are approximately Gaussian, then each distribution is fully characterized by its first-order statistics (mean) and second-order statistics (covariance). CORAL focuses on aligning the covariance structure of the source and target features. Under this assumption, matching the covariances aligns the shape and spread of the two distributions, while the means can be aligned separately, for example by centering both feature sets.




The figure illustrates the intuition. On the left, the source and target feature distributions are misaligned — they occupy different regions and have different shapes in the feature space. On the right, the source distribution has been transformed to overlap with the target distribution by matching its covariance structure.

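The goal of CORAL is to find a linear transformation $A$ that minimizes the difference between the transformed source covariance and the target covariance. Here $C_S$ and $C_T$ denote the covariance matrices of the source and target features, and the difference is measured by the squared Frobenius norm:


$$\min_A \left\| A^T C_S A - C_T \right\|_F^2$$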


This objective has a closed-form solution: $A$ can be computed analytically from the sample covariances of the source and target features, without any iterative optimization. Once $A$ is obtained, the transformed source features $X_S^{\text{CORAL}} = X_S A$ are used to train a classifier, which is then applied to the target features.
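The closed form can be sketched in NumPy. This sketch assumes the common whitening-and-recoloring solution $A = C_S^{-1/2} C_T^{1/2}$ (symmetric matrix square roots), which satisfies $A^T C_S A = C_T$ exactly when $C_S$ is full rank. The regularizer `eps` (a multiple of the identity added to both covariances for numerical stability) and the function names are assumptions, not part of the original notebook.

```python
import numpy as np

def sym_matrix_power(C, power):
    """C**power for a symmetric positive-definite matrix C, via eigendecomposition."""
    eigvals, eigvecs = np.linalg.eigh(C)
    return (eigvecs * eigvals ** power) @ eigvecs.T

def coral_transform(X_s, X_t, eps=1.0):
    """Map source features so that their covariance matches the target covariance.

    A = C_S^{-1/2} C_T^{1/2}: whiten the source features, then re-color them with the target.
    eps * I is added to both covariances for numerical stability (assumed default).
    """
    d = X_s.shape[1]
    C_s = np.cov(X_s, rowvar=False) + eps * np.eye(d)
    C_t = np.cov(X_t, rowvar=False) + eps * np.eye(d)
    A = sym_matrix_power(C_s, -0.5) @ sym_matrix_power(C_t, 0.5)
    return X_s @ A, A

# Synthetic, well-conditioned features; eps=0 gives an exact covariance match.
rng = np.random.default_rng(0)
X_s = rng.normal(size=(500, 4)) * np.array([1.0, 2.0, 0.5, 3.0])
mix = np.array([[1.0, 0.5, 0.0, 0.0],
                [0.0, 1.0, 0.5, 0.0],
                [0.0, 0.0, 1.0, 0.5],
                [0.0, 0.0, 0.0, 1.0]])
X_t = rng.normal(size=(300, 4)) @ mix
X_s_coral, A = coral_transform(X_s, X_t, eps=0.0)
print(np.allclose(np.cov(X_s_coral, rowvar=False), np.cov(X_t, rowvar=False)))  # True
```

Because $A$ acts only on second-order statistics, the transformed source mean generally still differs from the target mean; if the means also differ, both feature sets can be centered first. A classifier trained on `X_s_coral` would then be applied to the original target features.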


4.2. DeepCORAL

CORAL as described above is a fixed linear transformation applied to pre-extracted features. DeepCORAL extends this idea by integrating the CORAL objective directly into neural network training, allowing the feature representations themselves to be learned with reduced covariance discrepancy.



The figure illustrates the framework. Both source and target data are passed through a shared feature extractor. The extracted source and target features are used to compute their respective covariance matrices, and the CORAL loss penalizes the difference between the two. At the same time, the source features are passed to a classifier trained with the standard cross-entropy loss on the source labels.

The total training objective combines the two losses:


$$\mathcal{L} = \mathcal{L}_{\text{class}}^{\text{source}} + \lambda \mathcal{L}_{\text{CORAL}}$$


where the CORAL loss is defined as:


$$\mathcal{L}_{\text{CORAL}} = \frac{1}{4d^2} \left\| C_S - C_T \right\|_F^2$$


  • $\mathcal{L}_{\text{class}}^{\text{source}}$: classification loss computed on labeled source data — ensures that the learned features remain discriminative for the task
  • $\mathcal{L}_{\text{CORAL}}$: covariance discrepancy between source and target features — encourages the feature extractor to produce representations with aligned second-order statistics across the two domains
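  • $C_S$, $C_T$: covariance matrices of the source and target features; in practice they are computed from each mini-batch
  • $d$: feature dimension, used for normalization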

The hyperparameter $\lambda$ controls the trade-off between the two objectives. By jointly minimizing the classification loss and the covariance discrepancy, DeepCORAL learns feature representations that are discriminative for the task and have aligned second-order statistics across domains, without requiring adversarial training or labeled target data.
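The CORAL loss is easy to check by hand on a tiny example. The NumPy sketch below mirrors the covariance and loss computation used in the TensorFlow training loop further down (unbiased covariance with $n-1$, normalization by $4d^2$); the function names are illustrative.

```python
import numpy as np

def batch_covariance(features):
    """Unbiased covariance of an (n, d) feature batch, returned as a (d, d) matrix."""
    n = features.shape[0]
    centered = features - features.mean(axis=0, keepdims=True)
    return centered.T @ centered / (n - 1)

def coral_loss(source_features, target_features):
    """L_CORAL = ||C_S - C_T||_F^2 / (4 d^2)."""
    d = source_features.shape[1]
    diff = batch_covariance(source_features) - batch_covariance(target_features)
    return float(np.sum(diff ** 2) / (4.0 * d * d))

# Worked example: C_S = diag(2, 0) and C_T = diag(0, 2), so
# ||C_S - C_T||_F^2 = 2^2 + (-2)^2 = 8 and L_CORAL = 8 / (4 * 2^2) = 0.5.
source = np.array([[1.0, 0.0], [-1.0, 0.0]])
target = np.array([[0.0, 1.0], [0.0, -1.0]])
print(coral_loss(source, target))  # 0.5
```

Adding a constant to all source features leaves the loss unchanged, since only centered second-order statistics enter it. The total objective is then `class_loss + lam * coral_loss(source_features, target_features)`, where the hypothetical `lam` plays the role of $\lambda$ (`coral_weight` in the notebook).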


In [ ]:
# Feature extractor mirroring the source-only classifier up to its 128-d Dense layer.
coral_feature_extractor = tf.keras.models.Sequential([
    tf.keras.layers.Input(shape=(28, 28, 1)),
    tf.keras.layers.Conv2D(32, kernel_size=7, activation="relu", padding="same"),
    tf.keras.layers.MaxPool2D(pool_size=2),
    tf.keras.layers.Conv2D(64, kernel_size=7, activation="relu", padding="same"),
    tf.keras.layers.MaxPool2D(pool_size=2),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation="relu"),
], name="coral_feature_extractor")

# Warm-start from the source-only classifier: copy its two Conv2D layers (indices 0 and 2)
# and its 128-d Dense layer (index 5).
coral_feature_extractor.layers[0].set_weights(source_classifier.layers[0].get_weights())
coral_feature_extractor.layers[2].set_weights(source_classifier.layers[2].get_weights())
coral_feature_extractor.layers[5].set_weights(source_classifier.layers[5].get_weights())
In [ ]:
coral_feature_extractor.summary()
Model: "coral_feature_extractor"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                     Output Shape                  Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ conv2d_34 (Conv2D)              │ (None, 28, 28, 32)     │         1,600 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ max_pooling2d_6 (MaxPooling2D)  │ (None, 14, 14, 32)     │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_35 (Conv2D)              │ (None, 14, 14, 64)     │       100,416 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ max_pooling2d_7 (MaxPooling2D)  │ (None, 7, 7, 64)       │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ flatten_3 (Flatten)             │ (None, 3136)           │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_8 (Dense)                 │ (None, 128)            │       401,536 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
 Total params: 503,552 (1.92 MB)
 Trainable params: 503,552 (1.92 MB)
 Non-trainable params: 0 (0.00 B)
In [ ]:
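# Label classifier: a single softmax layer on top of the 128-d features.
coral_label_classifier = tf.keras.models.Sequential([
    tf.keras.layers.Input(shape=(128,)),
    tf.keras.layers.Dense(10, activation="softmax")], name="coral_label_classifier")

# Initialize it from the output layer (index 6) of the source-only classifier.
coral_label_classifier.layers[0].set_weights(source_classifier.layers[6].get_weights())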
In [ ]:
coral_label_classifier.summary()
Model: "coral_label_classifier"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                     Output Shape                  Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ dense_9 (Dense)                 │ (None, 10)             │         1,290 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
 Total params: 1,290 (5.04 KB)
 Trainable params: 1,290 (5.04 KB)
 Non-trainable params: 0 (0.00 B)
In [ ]:
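# Full model (feature extractor + label classifier), used for evaluation.
coral_classifier_model = tf.keras.models.Sequential([
    coral_feature_extractor,
    coral_label_classifier], name="coral_classifier_model")

# Separate optimizers for the feature extractor and the label classifier.
coral_feature_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
coral_label_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4)

# The classifier outputs softmax probabilities, so from_logits stays at its default (False).
coral_class_loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()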
In [ ]:
coral_classifier_model.summary()
Model: "coral_classifier_model"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                     Output Shape                  Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ coral_feature_extractor         │ (None, 128)            │       503,552 │
│ (Sequential)                    │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ coral_label_classifier          │ (None, 10)             │         1,290 │
│ (Sequential)                    │                        │               │
└─────────────────────────────────┴────────────────────────┴───────────────┘
 Total params: 504,842 (1.93 MB)
 Trainable params: 504,842 (1.93 MB)
 Non-trainable params: 0 (0.00 B)
In [ ]:
coral_batch_size = 128

# Labeled source stream (MNIST images and labels), shuffled and repeated indefinitely.
# Batching after repeat() means every batch has exactly coral_batch_size samples.
coral_source_dataset = tf.data.Dataset.from_tensor_slices((mnist_train_images, mnist_train_labels))
coral_source_dataset = coral_source_dataset.shuffle(len(mnist_train_images))
coral_source_dataset = coral_source_dataset.repeat()
coral_source_dataset = coral_source_dataset.batch(coral_batch_size)

# Unlabeled target stream: USPS images only, since target labels are not used for training.
coral_target_dataset = tf.data.Dataset.from_tensor_slices(usps_train_images)
coral_target_dataset = coral_target_dataset.shuffle(len(usps_train_images))
coral_target_dataset = coral_target_dataset.repeat()
coral_target_dataset = coral_target_dataset.batch(coral_batch_size)

coral_source_iter = iter(coral_source_dataset)
coral_target_iter = iter(coral_target_dataset)

# One "epoch" is defined as one pass over the USPS training images.
coral_steps_per_epoch = len(usps_train_images) // coral_batch_size
In [ ]:
coral_epochs = 100
coral_display_epoch_interval = 10
coral_weight = 500.0  # lambda in the total objective: weight of the CORAL loss

for epoch in range(1, coral_epochs + 1):
    total_loss_sum = 0.0
    class_loss_sum = 0.0
    coral_loss_sum = 0.0

    for _ in range(coral_steps_per_epoch):
        # Labeled source batch and unlabeled target batch.
        source_x, source_y = next(coral_source_iter)
        target_x = next(coral_target_iter)

        with tf.GradientTape(persistent=True) as tape:
            # Shared feature extractor applied to both domains.
            source_feature = coral_feature_extractor(source_x, training=True)
            target_feature = coral_feature_extractor(target_x, training=True)

            # Cross-entropy on labeled source data only.
            source_class_prob = coral_label_classifier(source_feature, training=True)
            class_loss = coral_class_loss_fn(source_y, source_class_prob)

            # Mini-batch covariance matrices (d x d) of the source and target features.
            source_centered = source_feature - tf.reduce_mean(source_feature, axis=0, keepdims=True)
            target_centered = target_feature - tf.reduce_mean(target_feature, axis=0, keepdims=True)

            source_n = tf.cast(tf.shape(source_feature)[0], tf.float32)
            target_n = tf.cast(tf.shape(target_feature)[0], tf.float32)
            feature_dim_float = tf.cast(tf.shape(source_feature)[1], tf.float32)

            source_covariance = tf.matmul(source_centered, source_centered, transpose_a=True) / (source_n - 1.0)
            target_covariance = tf.matmul(target_centered, target_centered, transpose_a=True) / (target_n - 1.0)

            # L_CORAL = ||C_S - C_T||_F^2 / (4 d^2)
            coral_loss = tf.reduce_sum(
                tf.square(source_covariance - target_covariance)) / (4.0 * feature_dim_float * feature_dim_float)

            total_loss = class_loss + coral_weight * coral_loss

        # The feature extractor receives gradients from both terms. The CORAL term does not depend
        # on the label classifier, so the classifier's gradient comes from the class loss only.
        coral_feature_optimizer.apply_gradients(
            zip(tape.gradient(total_loss, coral_feature_extractor.trainable_variables),
                coral_feature_extractor.trainable_variables))

        coral_label_optimizer.apply_gradients(
            zip(tape.gradient(total_loss, coral_label_classifier.trainable_variables),
                coral_label_classifier.trainable_variables))
        del tape  # release the persistent tape

        total_loss_sum += float(total_loss)
        class_loss_sum += float(class_loss)
        coral_loss_sum += float(coral_loss)

    total_loss_mean = total_loss_sum / coral_steps_per_epoch
    class_loss_mean = class_loss_sum / coral_steps_per_epoch
    coral_loss_mean = coral_loss_sum / coral_steps_per_epoch
    scaled_coral_loss_mean = coral_weight * coral_loss_mean

    if epoch == 1 or epoch % coral_display_epoch_interval == 0 or epoch == coral_epochs:
        # Evaluate only on logging epochs; only the source (MNIST) test accuracy is reported here.
        coral_mnist_prob = coral_classifier_model.predict(
            mnist_test_images,
            batch_size=coral_batch_size,
            verbose=0)
        coral_mnist_pred = np.argmax(coral_mnist_prob, axis=1)
        coral_mnist_acc = np.mean(coral_mnist_pred == mnist_test_labels.reshape(-1))

        print(
            f"Epoch {epoch:03d}/{coral_epochs} | "
            f"total={total_loss_mean:.4f} | "
            f"class={class_loss_mean:.4f} | "
            f"coral={coral_loss_mean:.6f} | "
            f"scaled_coral={scaled_coral_loss_mean:.4f} | "
            f"MNIST_acc={coral_mnist_acc:.4f} | "
        )
Epoch 001/100 | total=16.3539 | class=0.0527 | coral=0.032602 | scaled_coral=16.3012 | MNIST_acc=0.9900 | 
Epoch 010/100 | total=0.2518 | class=0.1256 | coral=0.000252 | scaled_coral=0.1261 | MNIST_acc=0.9906 | 
Epoch 020/100 | total=0.1329 | class=0.0760 | coral=0.000114 | scaled_coral=0.0570 | MNIST_acc=0.9917 | 
Epoch 030/100 | total=0.0826 | class=0.0499 | coral=0.000065 | scaled_coral=0.0327 | MNIST_acc=0.9922 | 
Epoch 040/100 | total=0.0610 | class=0.0373 | coral=0.000047 | scaled_coral=0.0237 | MNIST_acc=0.9930 | 
Epoch 050/100 | total=0.0501 | class=0.0305 | coral=0.000039 | scaled_coral=0.0196 | MNIST_acc=0.9934 | 
Epoch 060/100 | total=0.0426 | class=0.0267 | coral=0.000032 | scaled_coral=0.0159 | MNIST_acc=0.9934 | 
Epoch 070/100 | total=0.0337 | class=0.0200 | coral=0.000028 | scaled_coral=0.0138 | MNIST_acc=0.9933 | 
Epoch 080/100 | total=0.0290 | class=0.0169 | coral=0.000024 | scaled_coral=0.0122 | MNIST_acc=0.9932 | 
Epoch 090/100 | total=0.0233 | class=0.0128 | coral=0.000021 | scaled_coral=0.0105 | MNIST_acc=0.9934 | 
Epoch 100/100 | total=0.0234 | class=0.0145 | coral=0.000018 | scaled_coral=0.0089 | MNIST_acc=0.9930 | 
In [ ]:
coral_mnist_prob = coral_classifier_model.predict(mnist_test_images, batch_size=coral_batch_size, verbose=0)
coral_usps_prob = coral_classifier_model.predict(usps_test_images, batch_size=coral_batch_size, verbose=0)

coral_mnist_pred = np.argmax(coral_mnist_prob, axis=1)
coral_usps_pred = np.argmax(coral_usps_prob, axis=1)

coral_mnist_acc = np.mean(coral_mnist_pred == mnist_test_labels.reshape(-1))
coral_usps_acc = np.mean(coral_usps_pred == usps_test_labels.reshape(-1))

print("[Source-only]")
print(f"MNIST accuracy: {source_mnist_acc:.4f}")
print(f"USPS accuracy : {source_usps_acc:.4f}")

print("\n[Deep CORAL]")
print(f"MNIST accuracy: {coral_mnist_acc:.4f}")
print(f"USPS accuracy : {coral_usps_acc:.4f}")
print(f"USPS improvement over source-only: {coral_usps_acc - source_usps_acc:+.4f}")
[Source-only]
MNIST accuracy: 0.9856
USPS accuracy : 0.7414

[Deep CORAL]
MNIST accuracy: 0.9930
USPS accuracy : 0.9038
USPS improvement over source-only: +0.1624
In [ ]:
selected_digits = [1, 2, 7]
plot_points_per_digit = 100

mnist_test_labels_1d = mnist_test_labels.reshape(-1)
usps_test_labels_1d = usps_test_labels.reshape(-1)

pretrain_mnist_feature = source_feature_extractor_for_plot.predict(
    mnist_test_images,
    batch_size=coral_batch_size,
    verbose=0)
pretrain_usps_feature = source_feature_extractor_for_plot.predict(
    usps_test_images,
    batch_size=coral_batch_size,
    verbose=0)

coral_mnist_feature = coral_feature_extractor.predict(
    mnist_test_images,
    batch_size=coral_batch_size,
    verbose=0)
coral_usps_feature = coral_feature_extractor.predict(
    usps_test_images,
    batch_size=coral_batch_size,
    verbose=0)

pretrain_feature_2d = PCA(n_components=2).fit_transform(np.vstack([pretrain_mnist_feature, pretrain_usps_feature]))
coral_feature_2d = PCA(n_components=2).fit_transform(np.vstack([coral_mnist_feature, coral_usps_feature]))

pretrain_mnist_2d = pretrain_feature_2d[:len(pretrain_mnist_feature)]
pretrain_usps_2d = pretrain_feature_2d[len(pretrain_mnist_feature):]

coral_mnist_2d = coral_feature_2d[:len(coral_mnist_feature)]
coral_usps_2d = coral_feature_2d[len(coral_mnist_feature):]

mnist_plot_indices = []
usps_plot_indices = []

for digit in selected_digits:
    mnist_idx = np.where(mnist_test_labels_1d == digit)[0]
    usps_idx = np.where(usps_test_labels_1d == digit)[0]

    mnist_plot_indices.extend(mnist_idx[:plot_points_per_digit])
    usps_plot_indices.extend(usps_idx[:plot_points_per_digit])

mnist_plot_indices = np.array(mnist_plot_indices)
usps_plot_indices = np.array(usps_plot_indices)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

for color_index, digit in enumerate(selected_digits):
    mnist_idx = mnist_plot_indices[mnist_test_labels_1d[mnist_plot_indices] == digit]
    usps_idx = usps_plot_indices[usps_test_labels_1d[usps_plot_indices] == digit]
    color = f"C{color_index}"

    axes[0].scatter(pretrain_mnist_2d[mnist_idx, 0], pretrain_mnist_2d[mnist_idx, 1], s=18, alpha=0.35, c=color, marker="o")
    axes[0].scatter(pretrain_usps_2d[usps_idx, 0], pretrain_usps_2d[usps_idx, 1], s=28, alpha=0.80, c=color, marker="x")
    axes[1].scatter(coral_mnist_2d[mnist_idx, 0], coral_mnist_2d[mnist_idx, 1], s=18, alpha=0.35, c=color, marker="o")
    axes[1].scatter(coral_usps_2d[usps_idx, 0], coral_usps_2d[usps_idx, 1], s=28, alpha=0.80, c=color, marker="x")

axes[0].set_title("Source-only Feature Space", fontsize=20)
axes[1].set_title("Deep CORAL Feature Space", fontsize=20)

for ax in axes:
    ax.set_xlabel("PC1", fontsize=17)
    ax.set_ylabel("PC2", fontsize=17)
    ax.tick_params(axis="both", labelsize=15)
    ax.grid(alpha=0.3)

legend_handles = []
for color_index, digit in enumerate(selected_digits):
    color = f"C{color_index}"

    legend_handles.append(Line2D([0], [0], marker="o", color="w", markerfacecolor=color, markeredgecolor=color, markersize=8, label=f"MNIST {digit}"))
    legend_handles.append(Line2D([0], [0], marker="x", color=color, markersize=8, linewidth=0, label=f"USPS {digit}"))

fig.legend(handles=legend_handles, loc="center left", bbox_to_anchor=(0.98, 0.5), fontsize=12)

plt.tight_layout(rect=[0, 0, 0.88, 1])
plt.show()
In [ ]:
summary_rows = [
    {"method": "Source-only", "mnist_acc": source_mnist_acc, "usps_acc": source_usps_acc, "usps_improvement": 0.0},
    {"method": "CycleGAN", "mnist_acc": cycle_mnist_acc, "usps_acc": cycle_usps_acc, "usps_improvement": cycle_usps_acc - source_usps_acc},
    {"method": "DANN", "mnist_acc": dann_mnist_acc, "usps_acc": dann_usps_acc, "usps_improvement": dann_usps_acc - source_usps_acc},
    {"method": "Deep CORAL", "mnist_acc": coral_mnist_acc, "usps_acc": coral_usps_acc, "usps_improvement": coral_usps_acc - source_usps_acc},
]

print("Domain Adaptation Summary")
print("-" * 78)
print(
    f"{'Method':<28} "
    f"{'MNIST Acc':>12} "
    f"{'USPS Acc':>12} "
    f"{'USPS Improvement':>18}")
print("-" * 78)

for row in summary_rows:
    print(
        f"{row['method']:<28} "
        f"{row['mnist_acc']:>12.4f} "
        f"{row['usps_acc']:>12.4f} "
        f"{row['usps_improvement']:>+18.4f}")
print("-" * 78)
Domain Adaptation Summary
------------------------------------------------------------------------------
Method                          MNIST Acc     USPS Acc   USPS Improvement
------------------------------------------------------------------------------
Source-only                        0.9856       0.7414            +0.0000
CycleGAN                           0.9880       0.8979            +0.1565
DANN                               0.9917       0.9427            +0.2013
Deep CORAL                         0.9930       0.9038            +0.1624
------------------------------------------------------------------------------