Domain Adaptation


By Prof. Hyunseok Oh
https://sddo.gist.ac.kr/
SDDO Lab at GIST

  • Adapts information from the source domain to the target domain to improve prediction performance
  • Uses the relationship between the two domains' data in feature space

Domain adversarial neural net (DANN)

  • Trains without target-domain labels
  • End-to-end training (task classifier + domain classifier)
  • Adds a layer that modifies backpropagation: the gradient reversal layer (GRL)
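The notebook below implements the gradient reversal manually inside its `train_step` (by subtracting the scaled domain-loss gradient). For reference, the same idea can be written as a standalone Keras layer; this is an illustrative sketch (the class name is not part of the notebook's code):

```python
import tensorflow as tf

class GradientReversalLayer(tf.keras.layers.Layer):
    """Identity in the forward pass; multiplies gradients by -alpha on the way back."""
    def __init__(self, alpha=1.0, **kwargs):
        super().__init__(**kwargs)
        self.alpha = alpha

    def call(self, x):
        @tf.custom_gradient
        def _reverse(x):
            def grad(dy):
                # Flip the sign of the incoming gradient and scale it.
                return -self.alpha * dy
            return tf.identity(x), grad
        return _reverse(x)
```

Placed between the feature generator and the domain predictor, this layer makes a single optimizer step push the features toward being domain-indistinguishable while the domain predictor still learns to discriminate.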

1. Import library

In [31]:
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Conv2D, MaxPool2D
from tensorflow.keras import Model
import numpy as np
import matplotlib.pyplot as plt

2. Data preparation

2.1 MNIST and SVHN dataset

In [5]:
!wget https://url.kr/6ri2mk -O 'data.zip'
!unzip data.zip -d './data'
--2022-01-18 05:05:51--  https://url.kr/6ri2mk
Resolving url.kr (url.kr)... 183.111.169.122
Connecting to url.kr (url.kr)|183.111.169.122|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://url.kr/spam_filtering_system.php?short=6ri2mk [following]
--2022-01-18 05:05:53--  https://url.kr/spam_filtering_system.php?short=6ri2mk
Reusing existing connection to url.kr:443.
HTTP request sent, awaiting response... 302 Found
Location: https://github.com/MinseokOff/KSME_advanced/raw/main/Domain_adaptation/data.zip [following]
--2022-01-18 05:05:53--  https://github.com/MinseokOff/KSME_advanced/raw/main/Domain_adaptation/data.zip
Resolving github.com (github.com)... 140.82.112.4
Connecting to github.com (github.com)|140.82.112.4|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/MinseokOff/KSME_advanced/main/Domain_adaptation/data.zip [following]
--2022-01-18 05:05:53--  https://raw.githubusercontent.com/MinseokOff/KSME_advanced/main/Domain_adaptation/data.zip
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 569122 (556K) [application/zip]
Saving to: ‘data.zip’

data.zip            100%[===================>] 555.78K  --.-KB/s    in 0.007s  

2022-01-18 05:05:53 (83.1 MB/s) - ‘data.zip’ saved [569122/569122]

Archive:  data.zip
  inflating: ./data/test_mnist.npz   
  inflating: ./data/test_svhn.npz    
  inflating: ./data/train_mnist.npz  
  inflating: ./data/train_svhn.npz   
In [6]:
mnist_train = np.load('./data/train_mnist.npz')
x_train, y_train  = mnist_train['x'], mnist_train['y']
mnist_test = np.load('./data/test_mnist.npz')
x_test, y_test  = mnist_test['x'], mnist_test['y']

train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(100).batch(32)
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)
In [7]:
svhn_train = np.load('./data/train_svhn.npz')
svhn_train_ls, svhn_train_y  = svhn_train['x'], svhn_train['y']
svhn_test = np.load('./data/test_svhn.npz')
svhn_test_ls, svhn_test_y  = svhn_test['x'], svhn_test['y']

svhn_train_ds = tf.data.Dataset.from_tensor_slices((svhn_train_ls,svhn_train_y)).batch(32)
svhn_test_ds = tf.data.Dataset.from_tensor_slices((svhn_test_ls,svhn_test_y)).batch(32)

2.2 Normalization Information

In [8]:
all_train_domain_images = np.vstack((x_train, svhn_train_ls))
channel_mean = all_train_domain_images.mean((0,1,2))
channel_mean
Out[8]:
array([73.41482781, 73.1164477 , 75.94361926])
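The per-channel mean above is computed over the stacked source and target training images and is later subtracted inside each model before scaling by 255. A minimal sketch of that same computation on toy data (the array shapes here are stand-ins):

```python
import numpy as np

# Stand-in for the stacked RGB training images, shaped (N, H, W, 3)
images = np.random.randint(0, 256, size=(10, 28, 28, 3)).astype(np.float64)

# One mean per colour channel, averaged over batch, height, and width
channel_mean = images.mean((0, 1, 2))

# The normalisation the models apply: centre per channel, then scale by 255
normalised = (images - channel_mean) / 255.0
```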

3. Train baseline CNN

3.1 Baseline model, loss, optimizer, metrics

In [9]:
class BaselineModel(Model):
    def __init__(self):
        super(BaselineModel, self).__init__()
    
        self.normalise = lambda x: (tf.cast(x, tf.float64) - channel_mean) / 255.0
        self.conv1 = Conv2D(64, 5, activation='relu')
        self.conv2 = Conv2D(128, 5, activation='relu')
        self.maxpool = MaxPool2D(2)
        self.flatten = Flatten()
    
        self.d1 = Dense(128, activation='relu')
        self.d2 = Dense(10, activation='softmax')
    

    def call(self, x):
        x = self.normalise(x)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.maxpool(x)
        x = self.flatten(x)
        x = self.d1(x)

        return self.d2(x)

model = BaselineModel()
In [10]:
loss_object = tf.keras.losses.SparseCategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adam()
In [11]:
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

test_loss = tf.keras.metrics.Mean(name='test_loss')
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')

svhn_test_loss = tf.keras.metrics.Mean(name='svhn_test_loss')
svhn_test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='svhn_test_accuracy')

3.2 Train and Test function

In [12]:
@tf.function
def train_step(images, labels):
    with tf.GradientTape() as tape:
        predictions = model(images)
        loss = loss_object(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    train_loss(loss)
    train_accuracy(labels, predictions)
In [13]:
@tf.function
def test_step(mnist_images, labels, svhn_images, labels2):
    predictions = model(mnist_images)
    t_loss = loss_object(labels, predictions)

    test_loss(t_loss)
    test_accuracy(labels, predictions)

    predictions = model(svhn_images)
    t_loss = loss_object(labels2, predictions)

    svhn_test_loss(t_loss)
    svhn_test_accuracy(labels2, predictions)
In [14]:
def reset_metrics():
    train_loss.reset_states()
    train_accuracy.reset_states()
    test_loss.reset_states()
    test_accuracy.reset_states()
    svhn_test_loss.reset_states()
    svhn_test_accuracy.reset_states()

3.3 Train on MNIST

In [15]:
EPOCHS = 20

train_acc = []
test_acc = []
svhn_test_acc = []
for epoch in range(EPOCHS):
    reset_metrics()
    for images, labels in train_ds:
        train_step(images, labels)

    for test_data, svhn_test_data in zip(test_ds,svhn_test_ds):
        test_step(test_data[0], test_data[1], svhn_test_data[0], svhn_test_data[1])

    template = 'Epoch {}, Train Accuracy: {}, Source Test Accuracy: {}, Target Test Accuracy: {}'
    print (template.format(epoch+1,
                           train_accuracy.result()*100,
                           test_accuracy.result()*100,
                           svhn_test_accuracy.result()*100,))
    
    train_acc.append(train_accuracy.result()*100)
    test_acc.append(test_accuracy.result()*100)
    svhn_test_acc.append(svhn_test_accuracy.result()*100)
    
plt.plot(list(range(EPOCHS)), train_acc, label = "Train MNIST")
plt.plot(list(range(EPOCHS)), test_acc, label = "Test MNIST")
plt.plot(list(range(EPOCHS)), svhn_test_acc, label = "Test SVHN")
plt.legend(loc='upper left', bbox_to_anchor=(1.0, 1.0),fontsize = 10,frameon=False)
plt.show()    
Epoch 1, Train Accuracy: 40.0, Source Test Accuracy: 64.0, Target Test Accuracy: 14.0
Epoch 2, Train Accuracy: 82.0, Source Test Accuracy: 68.0, Target Test Accuracy: 16.0
Epoch 3, Train Accuracy: 91.0, Source Test Accuracy: 84.0, Target Test Accuracy: 14.0
Epoch 4, Train Accuracy: 97.5, Source Test Accuracy: 86.0, Target Test Accuracy: 10.0
Epoch 5, Train Accuracy: 99.0, Source Test Accuracy: 92.0, Target Test Accuracy: 16.0
Epoch 6, Train Accuracy: 100.0, Source Test Accuracy: 94.0, Target Test Accuracy: 12.0
Epoch 7, Train Accuracy: 100.0, Source Test Accuracy: 92.0, Target Test Accuracy: 14.0
Epoch 8, Train Accuracy: 100.0, Source Test Accuracy: 94.0, Target Test Accuracy: 12.0
Epoch 9, Train Accuracy: 100.0, Source Test Accuracy: 94.0, Target Test Accuracy: 14.0
Epoch 10, Train Accuracy: 100.0, Source Test Accuracy: 94.0, Target Test Accuracy: 14.0
Epoch 11, Train Accuracy: 100.0, Source Test Accuracy: 94.0, Target Test Accuracy: 14.0
Epoch 12, Train Accuracy: 100.0, Source Test Accuracy: 94.0, Target Test Accuracy: 14.0
Epoch 13, Train Accuracy: 100.0, Source Test Accuracy: 94.0, Target Test Accuracy: 14.0
Epoch 14, Train Accuracy: 100.0, Source Test Accuracy: 94.0, Target Test Accuracy: 14.0
Epoch 15, Train Accuracy: 100.0, Source Test Accuracy: 94.0, Target Test Accuracy: 12.0
Epoch 16, Train Accuracy: 100.0, Source Test Accuracy: 94.0, Target Test Accuracy: 12.0
Epoch 17, Train Accuracy: 100.0, Source Test Accuracy: 94.0, Target Test Accuracy: 12.0
Epoch 18, Train Accuracy: 100.0, Source Test Accuracy: 94.0, Target Test Accuracy: 12.0
Epoch 19, Train Accuracy: 100.0, Source Test Accuracy: 94.0, Target Test Accuracy: 12.0
Epoch 20, Train Accuracy: 100.0, Source Test Accuracy: 94.0, Target Test Accuracy: 12.0
In [16]:
for images, labels in test_ds:
    predictions = model(images)
    print("Prediction : ", np.argmax(predictions[3]))
    print("Ground truth : ", labels[3].numpy())
    plt.imshow(images[3])
    plt.show()
Prediction :  0
Ground truth :  0
Prediction :  2
Ground truth :  2
In [17]:
for images, labels in svhn_test_ds:
    predictions = model(images)
    print("Prediction : ", np.argmax(predictions[3]))
    print("Ground truth : ", labels[3].numpy())
    plt.imshow(images[3])
    plt.show()
Prediction :  0
Ground truth :  0
Prediction :  0
Ground truth :  1

4. Domain adversarial neural network

4.1 Domain adversarial neural network model, loss, optimizer, metrics

  • feature generator
  • label predictor
  • domain predictor
In [18]:
class FeatureGenerator(Model):
    def __init__(self):
        super(FeatureGenerator, self).__init__() 
        self.normalise = lambda x: (tf.cast(x, tf.float64) - channel_mean) / 255.0
        self.conv1 = Conv2D(64, 5, activation='relu')
        self.conv2 = Conv2D(128, 5, activation='relu')
        self.maxpool = MaxPool2D(2)
        self.flatten = Flatten()
    
    def call(self, x):
        x = self.normalise(x)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.maxpool(x)

        return self.flatten(x)

feature_generator = FeatureGenerator()
In [19]:
class LabelPredictor(Model):
    def __init__(self):
        super(LabelPredictor, self).__init__() 
        self.d1 = Dense(128, activation='relu')
        self.d2 = Dense(10, activation='softmax')

    def call(self, feats):  
        feats = self.d1(feats)
        return self.d2(feats)

label_predictor = LabelPredictor()
In [20]:
class DomainPredictor(Model):
    def __init__(self):
        super(DomainPredictor, self).__init__()   
        self.d3 = Dense(64, activation='relu')
        self.d4 = Dense(2, activation='softmax')

    def call(self, feats):
        feats = self.d3(feats)
        return self.d4(feats)

domain_predictor = DomainPredictor()
In [21]:
loss_object = tf.keras.losses.SparseCategoricalCrossentropy()
d_optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
f_optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
In [22]:
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

test_loss = tf.keras.metrics.Mean(name='test_loss')
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')

svhn_test_loss = tf.keras.metrics.Mean(name='svhn_test_loss')
svhn_test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='svhn_test_accuracy')
In [23]:
conf_train_loss = tf.keras.metrics.Mean(name='c_train_loss')
conf_train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='c_train_accuracy')

4.2 Domain label generation

In [24]:
x_train_domain_labels = np.ones([len(x_train)])
svhn_train_domain_labels = np.zeros([len(svhn_train_ls)])
all_train_domain_labels = np.hstack((x_train_domain_labels, svhn_train_domain_labels))

domain_train_ds = tf.data.Dataset.from_tensor_slices(
    (all_train_domain_images, tf.cast(all_train_domain_labels, tf.int8))).shuffle(60000).batch(32)
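A quick sanity check of the labeling convention above, with toy arrays standing in for the real datasets: source images (`x_train`) receive domain label 1, target images (`svhn_train_ls`) label 0, and stacking keeps images and labels aligned.

```python
import numpy as np

src = np.random.rand(5, 28, 28, 3)   # stand-in for x_train (source domain)
tgt = np.random.rand(7, 28, 28, 3)   # stand-in for svhn_train_ls (target domain)

images = np.vstack((src, tgt))                               # stack along the batch axis
labels = np.hstack((np.ones(len(src)), np.zeros(len(tgt))))  # 1 = source, 0 = target

assert len(images) == len(labels)
assert labels[:len(src)].all() and not labels[len(src):].any()
```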

4.3 Train and Test function

In [25]:
@tf.function
def train_step(images, labels, images2, domains, alpha):
    
    ## Update the generator and the classifier
    with tf.GradientTape(persistent=True) as tape:
        features = feature_generator(images)
        l_predictions = label_predictor(features)
        features = feature_generator(images2)
        d_predictions = domain_predictor(features)
        label_loss = loss_object(labels, l_predictions)
        domain_loss = loss_object(domains, d_predictions)
    
    # Explicit gradient reversal: the feature generator descends the label loss
    # while ascending the domain loss, hence the minus sign on the domain term.
    f_gradients_on_label_loss = tape.gradient(label_loss, feature_generator.trainable_variables)
    f_gradients_on_domain_loss = tape.gradient(domain_loss, feature_generator.trainable_variables)
    f_gradients = [f_gradients_on_label_loss[i] - alpha*f_gradients_on_domain_loss[i]
                   for i in range(len(f_gradients_on_domain_loss))]

    l_gradients = tape.gradient(label_loss, label_predictor.trainable_variables)
    f_optimizer.apply_gradients(zip(f_gradients+l_gradients, 
                                  feature_generator.trainable_variables+label_predictor.trainable_variables)) 
    
    ## Update the discriminator separately (asynchronous update).
    ## Comment this block out to perform all updates in a single step.
    with tf.GradientTape() as tape:
        features = feature_generator(images2)
        d_predictions = domain_predictor(features)
        domain_loss = loss_object(domains, d_predictions)
   
    d_gradients = tape.gradient(domain_loss, domain_predictor.trainable_variables)  
    d_gradients = [alpha*i for i in d_gradients]
    d_optimizer.apply_gradients(zip(d_gradients, domain_predictor.trainable_variables))
  
    train_loss(label_loss)
    train_accuracy(labels, l_predictions)
    conf_train_loss(domain_loss)
    conf_train_accuracy(domains, d_predictions)
In [26]:
@tf.function
def test_step(mnist_images, labels, svhn_images, labels2):
    features = feature_generator(mnist_images)
    predictions = label_predictor(features)
    t_loss = loss_object(labels, predictions)

    test_loss(t_loss)
    test_accuracy(labels, predictions)

    features = feature_generator(svhn_images)
    predictions = label_predictor(features)
    t_loss = loss_object(labels2, predictions)

    svhn_test_loss(t_loss)
    svhn_test_accuracy(labels2, predictions)
In [27]:
def reset_metrics():
    train_loss.reset_states()
    train_accuracy.reset_states()
    test_loss.reset_states()
    test_accuracy.reset_states()
    svhn_test_loss.reset_states()
    svhn_test_accuracy.reset_states()

4.4 Train with MNIST and SVHN

In [28]:
EPOCHS = 50
alpha = 10

train_acc = []
test_acc = []
svhn_test_acc = []
for epoch in range(EPOCHS):
    reset_metrics()

    for domain_data, label_data in zip(domain_train_ds, train_ds):
    
        try:
            train_step(label_data[0], label_data[1], domain_data[0], domain_data[1], alpha=alpha)

        except ValueError: 
            pass
    
    for test_data, svhn_test_data in zip(test_ds, svhn_test_ds):
        test_step(test_data[0], test_data[1], svhn_test_data[0], svhn_test_data[1])
  
    template = 'Epoch {}, Train Accuracy: {}, Domain Accuracy: {}, Source Test Accuracy: {}, Target Test Accuracy: {}'
    print (template.format(epoch+1,
                           train_accuracy.result()*100,
                           conf_train_accuracy.result()*100,
                           test_accuracy.result()*100,
                           svhn_test_accuracy.result()*100,))
    
    train_acc.append(train_accuracy.result()*100)
    test_acc.append(test_accuracy.result()*100)
    svhn_test_acc.append(svhn_test_accuracy.result()*100)
    
plt.plot(list(range(EPOCHS)), train_acc, label = "Train MNIST")
plt.plot(list(range(EPOCHS)), test_acc, label = "Test MNIST")
plt.plot(list(range(EPOCHS)), svhn_test_acc, label = "Test SVHN")
plt.legend(loc='upper left', bbox_to_anchor=(1.0, 1.0),fontsize = 10,frameon=False)
plt.show()    
Epoch 1, Train Accuracy: 22.0, Domain Accuracy: 50.44642639160156, Source Test Accuracy: 34.0, Target Test Accuracy: 6.0
Epoch 2, Train Accuracy: 40.0, Domain Accuracy: 52.67857360839844, Source Test Accuracy: 44.0, Target Test Accuracy: 10.0
Epoch 3, Train Accuracy: 27.000001907348633, Domain Accuracy: 50.2976188659668, Source Test Accuracy: 12.0, Target Test Accuracy: 12.0
Epoch 4, Train Accuracy: 26.499998092651367, Domain Accuracy: 50.44642639160156, Source Test Accuracy: 42.0, Target Test Accuracy: 14.0
Epoch 5, Train Accuracy: 58.0, Domain Accuracy: 54.107139587402344, Source Test Accuracy: 72.0, Target Test Accuracy: 12.0
Epoch 6, Train Accuracy: 68.0, Domain Accuracy: 55.20833206176758, Source Test Accuracy: 66.0, Target Test Accuracy: 16.0
Epoch 7, Train Accuracy: 66.5, Domain Accuracy: 55.61224365234375, Source Test Accuracy: 76.0, Target Test Accuracy: 10.0
Epoch 8, Train Accuracy: 82.0, Domain Accuracy: 58.705360412597656, Source Test Accuracy: 70.0, Target Test Accuracy: 16.0
Epoch 9, Train Accuracy: 82.5, Domain Accuracy: 61.70634460449219, Source Test Accuracy: 72.0, Target Test Accuracy: 12.0
Epoch 10, Train Accuracy: 81.0, Domain Accuracy: 63.83928680419922, Source Test Accuracy: 66.0, Target Test Accuracy: 12.0
Epoch 11, Train Accuracy: 78.0, Domain Accuracy: 66.2337646484375, Source Test Accuracy: 52.0, Target Test Accuracy: 8.0
Epoch 12, Train Accuracy: 70.5, Domain Accuracy: 68.34077453613281, Source Test Accuracy: 62.0, Target Test Accuracy: 10.0
Epoch 13, Train Accuracy: 77.5, Domain Accuracy: 69.60851287841797, Source Test Accuracy: 66.0, Target Test Accuracy: 8.0
Epoch 14, Train Accuracy: 82.5, Domain Accuracy: 70.82270050048828, Source Test Accuracy: 52.0, Target Test Accuracy: 10.0
Epoch 15, Train Accuracy: 79.0, Domain Accuracy: 71.60713958740234, Source Test Accuracy: 58.0, Target Test Accuracy: 6.0
Epoch 16, Train Accuracy: 68.0, Domain Accuracy: 71.90290069580078, Source Test Accuracy: 48.0, Target Test Accuracy: 8.0
Epoch 17, Train Accuracy: 80.0, Domain Accuracy: 72.13760375976562, Source Test Accuracy: 52.0, Target Test Accuracy: 8.0
Epoch 18, Train Accuracy: 74.0, Domain Accuracy: 72.09821319580078, Source Test Accuracy: 40.0, Target Test Accuracy: 10.0
Epoch 19, Train Accuracy: 63.0, Domain Accuracy: 72.55638885498047, Source Test Accuracy: 28.0, Target Test Accuracy: 10.0
Epoch 20, Train Accuracy: 65.0, Domain Accuracy: 72.58928680419922, Source Test Accuracy: 52.0, Target Test Accuracy: 4.0
Epoch 21, Train Accuracy: 76.0, Domain Accuracy: 73.12925720214844, Source Test Accuracy: 66.0, Target Test Accuracy: 14.0
Epoch 22, Train Accuracy: 80.5, Domain Accuracy: 73.11283111572266, Source Test Accuracy: 48.0, Target Test Accuracy: 14.0
Epoch 23, Train Accuracy: 73.5, Domain Accuracy: 72.12732696533203, Source Test Accuracy: 40.0, Target Test Accuracy: 4.0
Epoch 24, Train Accuracy: 79.0, Domain Accuracy: 71.14955139160156, Source Test Accuracy: 56.0, Target Test Accuracy: 6.0
Epoch 25, Train Accuracy: 83.5, Domain Accuracy: 70.35714721679688, Source Test Accuracy: 72.0, Target Test Accuracy: 18.0
Epoch 26, Train Accuracy: 90.5, Domain Accuracy: 69.72870635986328, Source Test Accuracy: 80.0, Target Test Accuracy: 12.0
Epoch 27, Train Accuracy: 96.0, Domain Accuracy: 68.88227844238281, Source Test Accuracy: 82.0, Target Test Accuracy: 10.0
Epoch 28, Train Accuracy: 98.5, Domain Accuracy: 68.2238540649414, Source Test Accuracy: 84.0, Target Test Accuracy: 18.0
Epoch 29, Train Accuracy: 97.5, Domain Accuracy: 67.61083984375, Source Test Accuracy: 88.0, Target Test Accuracy: 10.0
Epoch 30, Train Accuracy: 99.0, Domain Accuracy: 66.97916412353516, Source Test Accuracy: 82.0, Target Test Accuracy: 8.0
Epoch 31, Train Accuracy: 99.0, Domain Accuracy: 66.44585418701172, Source Test Accuracy: 82.0, Target Test Accuracy: 10.0
Epoch 32, Train Accuracy: 99.5, Domain Accuracy: 65.90401458740234, Source Test Accuracy: 84.0, Target Test Accuracy: 16.0
Epoch 33, Train Accuracy: 99.0, Domain Accuracy: 65.39501953125, Source Test Accuracy: 88.0, Target Test Accuracy: 16.0
Epoch 34, Train Accuracy: 100.0, Domain Accuracy: 64.94223022460938, Source Test Accuracy: 86.0, Target Test Accuracy: 10.0
Epoch 35, Train Accuracy: 100.0, Domain Accuracy: 64.41326904296875, Source Test Accuracy: 86.0, Target Test Accuracy: 10.0
Epoch 36, Train Accuracy: 100.0, Domain Accuracy: 64.0625, Source Test Accuracy: 84.0, Target Test Accuracy: 12.0
Epoch 37, Train Accuracy: 100.0, Domain Accuracy: 63.706565856933594, Source Test Accuracy: 84.0, Target Test Accuracy: 12.0
Epoch 38, Train Accuracy: 100.0, Domain Accuracy: 63.392860412597656, Source Test Accuracy: 84.0, Target Test Accuracy: 16.0
Epoch 39, Train Accuracy: 100.0, Domain Accuracy: 62.95787811279297, Source Test Accuracy: 86.0, Target Test Accuracy: 18.0
Epoch 40, Train Accuracy: 100.0, Domain Accuracy: 62.64508819580078, Source Test Accuracy: 84.0, Target Test Accuracy: 16.0
Epoch 41, Train Accuracy: 100.0, Domain Accuracy: 62.39111328125, Source Test Accuracy: 84.0, Target Test Accuracy: 16.0
Epoch 42, Train Accuracy: 100.0, Domain Accuracy: 62.06420135498047, Source Test Accuracy: 86.0, Target Test Accuracy: 18.0
Epoch 43, Train Accuracy: 100.0, Domain Accuracy: 61.8563117980957, Source Test Accuracy: 86.0, Target Test Accuracy: 18.0
Epoch 44, Train Accuracy: 100.0, Domain Accuracy: 61.525978088378906, Source Test Accuracy: 86.0, Target Test Accuracy: 16.0
Epoch 45, Train Accuracy: 100.0, Domain Accuracy: 61.34920883178711, Source Test Accuracy: 86.0, Target Test Accuracy: 14.0
Epoch 46, Train Accuracy: 100.0, Domain Accuracy: 61.10248565673828, Source Test Accuracy: 86.0, Target Test Accuracy: 16.0
Epoch 47, Train Accuracy: 100.0, Domain Accuracy: 60.77127456665039, Source Test Accuracy: 84.0, Target Test Accuracy: 16.0
Epoch 48, Train Accuracy: 100.0, Domain Accuracy: 60.556175231933594, Source Test Accuracy: 84.0, Target Test Accuracy: 16.0
Epoch 49, Train Accuracy: 100.0, Domain Accuracy: 60.349853515625, Source Test Accuracy: 82.0, Target Test Accuracy: 16.0
Epoch 50, Train Accuracy: 100.0, Domain Accuracy: 60.14285659790039, Source Test Accuracy: 82.0, Target Test Accuracy: 18.0
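The loop above fixes `alpha` throughout training. The original DANN paper instead ramps the domain-loss weight from 0 to 1 as training progresses, so the domain classifier is less noisy early on. A sketch of that schedule (the function name is illustrative):

```python
import numpy as np

def dann_alpha(step, total_steps, gamma=10.0):
    """Adaptation-weight schedule from the DANN paper: 2/(1+exp(-gamma*p)) - 1,
    where p in [0, 1] is the fraction of training completed."""
    p = step / total_steps
    return 2.0 / (1.0 + np.exp(-gamma * p)) - 1.0
```

The weight starts at exactly 0 and saturates near 1 by the end of training; the per-step value would be passed to `train_step` in place of the fixed `alpha`.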
In [29]:
for images, labels in test_ds:
    features = feature_generator(images)
    predictions = label_predictor(features)
    print("Prediction : ", np.argmax(predictions[3]))
    print("Ground truth : ", labels[3].numpy())
    plt.imshow(images[3])
    plt.show()
Prediction :  0
Ground truth :  0
Prediction :  2
Ground truth :  2
In [30]:
for images, labels in svhn_test_ds:
    features = feature_generator(images)
    predictions = label_predictor(features)
    print("Prediction : ", np.argmax(predictions[3]))
    print("Ground truth : ", labels[3].numpy())
    plt.imshow(images[3])
    plt.show()
Prediction :  1
Ground truth :  0
Prediction :  0
Ground truth :  1