Diffusion Models


By Prof. Seungchul Lee
http://iailab.kaist.ac.kr/
Industrial AI Lab at KAIST


1. Diffusion Models

  • Diffusion models are generative models that learn to generate samples by gradually transforming noise (a latent input) into a data sample.
  • Fundamentally, there are two processes:

    • Forward process: Given input image $x_0$, gradually add Gaussian noise to it for $T$ steps.
    • Reverse process: a neural network is trained to recover the original data by reversing the noising process.
  • After training, we can generate new samples by denoising random noise (a latent input) with the trained neural network.


  • Assumption: the transition between timesteps is Markovian, i.e., each timestep depends only on the previous one.



2. Diffusion Process

2.1. Forward Diffusion Process

The forward process iteratively adds noise to the image. Each noising step is Gaussian, so a single transition can be written as


$$q(x_t | x_{t-1}) = N (x_t;\sqrt{1-\beta_t}x_{t-1}, \beta_t \mathbb{I}), $$

where $\beta_t$ is the noise (variance) schedule at timestep $t$.
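
For illustration, a single forward transition can be simulated in a few lines of NumPy. The linear schedule below (rising from $10^{-4}$ to $0.02$ over $T = 1000$ steps, a commonly used choice) is only an example and is not the toy schedule used later in Section 3.

In [ ]:
import numpy as np

rng = np.random.default_rng(0)

# example linear noise schedule (illustrative values only)
T = 1000
betas = np.linspace(1e-4, 0.02, T)          # beta_t for t = 1, ..., T

# one forward transition q(x_t | x_{t-1}) for a scalar "image"
x_prev = 1.5                                # x_{t-1}
t = 500
x_t = np.sqrt(1.0 - betas[t]) * x_prev + np.sqrt(betas[t]) * rng.standard_normal()

print(betas[t], x_t)                        # small beta_t => x_t stays close to x_{t-1}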


For a given original image $x_{0}$, this process is repeated for $T$ steps, until the image becomes (approximately) isotropic Gaussian noise $N(0, \mathbb{I})$. Due to the Markovian property, the joint distribution of these transitions from time $1$ to $t$ factorizes into a chain of individual transitions,


$$q(x_{1:t} | x_{0}) = q(x_t | x_{t-1})\, q(x_{t-1} | x_{t-2}) \cdots q(x_2 | x_{1})\, q(x_1 | x_{0}). $$

Because a composition of Gaussian transitions is again Gaussian, the forward distribution of $x_t$ given the input image $x_0$ has the closed form


$$ q(x_t | x_{0}) = N \left(x_t;\, \sqrt{\bar{\alpha}_t}\, x_{0},\, (1-\bar{\alpha}_t) \mathbb{I}\right), $$

where $\alpha_t = 1- \beta_t$ and $\bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s$.


By using the reparametrization trick with $\epsilon \sim N(0, \mathbb{I})$, we can write $x_{t}$ as


$$x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1-\bar{\alpha}_t}\, \epsilon ,$$

enabling tractable closed-form sampling at any timestep.
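
As a quick numerical check of this closed form, the sketch below (plain NumPy, scalar "images", and an arbitrary toy schedule) compares iterating $q(x_t|x_{t-1})$ step by step against sampling $x_t$ directly with the reparametrized expression; both should yield mean $\sqrt{\bar{\alpha}_t}\,x_0$ and standard deviation $\sqrt{1-\bar{\alpha}_t}$.

In [ ]:
import numpy as np

rng = np.random.default_rng(0)

# toy linear schedule (illustrative values only)
T = 10
betas = np.linspace(0.01, 0.2, T)
alphas = 1.0 - betas
alpha_bars = np.cumprod(alphas)

x0 = 1.5                       # a scalar "image" for simplicity
n = 100_000                    # number of Monte Carlo samples

# (a) iterate q(x_t | x_{t-1}) for all T steps
x = np.full(n, x0)
for t in range(T):
    x = np.sqrt(1.0 - betas[t]) * x + np.sqrt(betas[t]) * rng.standard_normal(n)

# (b) sample x_T directly from the closed form q(x_T | x_0)
x_direct = np.sqrt(alpha_bars[-1]) * x0 + np.sqrt(1.0 - alpha_bars[-1]) * rng.standard_normal(n)

# both should have mean ~ sqrt(alpha_bar_T) * x0 and std ~ sqrt(1 - alpha_bar_T)
print(x.mean(), x_direct.mean(), np.sqrt(alpha_bars[-1]) * x0)
print(x.std(),  x_direct.std(),  np.sqrt(1.0 - alpha_bars[-1]))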


2.2. Reverse Diffusion Process

  • Disclaimer: intermediate derivation steps are omitted in this tutorial for simplicity. Please refer to the original paper for full details.

The goal of diffusion models is to learn the reverse process. Because the true reverse distribution $q(x_{t-1}|x_t)$ is intractable, we instead condition it on the input image $x_0$ and work with the forward process posterior $q(x_{t-1}|x_t, x_0)$. This distribution can be written as


$$q(x_{t-1}|x_t,x_0) = N \left(x_{t-1};\, \tilde{\mu}(x_t,x_0),\, \tilde{\beta}_t \mathbb{I}\right).$$

Using the equation derived for the forward process ($x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1-\bar{\alpha}_t}\, \epsilon$), we can express $\tilde{\mu}(x_t,x_0)$ and $\tilde{\beta}_t$ as


$$ \tilde{\mu}(x_t,x_0) = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{1-\alpha_t}{ \sqrt{1-\bar{\alpha}_t}} \epsilon_t \right) \qquad \text{and} \qquad \tilde{\beta}_t = \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t} \beta_t .$$
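
As a concrete sketch (again with scalar values and an arbitrary toy schedule), the posterior parameters can be computed directly from these formulas once the noise $\epsilon_t$ used in the forward step is known.

In [ ]:
import numpy as np

rng = np.random.default_rng(0)

# toy schedule (illustrative values only)
T = 10
betas = np.linspace(0.01, 0.2, T)
alphas = 1.0 - betas
alpha_bars = np.cumprod(alphas)

t = 5                                            # some intermediate timestep
x0 = 1.5                                         # scalar "image"
eps = rng.standard_normal()                      # noise used in the forward step
x_t = np.sqrt(alpha_bars[t]) * x0 + np.sqrt(1.0 - alpha_bars[t]) * eps

# posterior mean and variance of q(x_{t-1} | x_t, x_0)
mu_tilde = (x_t - (1.0 - alphas[t]) / np.sqrt(1.0 - alpha_bars[t]) * eps) / np.sqrt(alphas[t])
beta_tilde = (1.0 - alpha_bars[t - 1]) / (1.0 - alpha_bars[t]) * betas[t]

print(mu_tilde, beta_tilde)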

Here, we approximate the forward process posterior with a neural network $q_{\theta}(x_{t-1}|x_t)$ by minimizing the distance$^*$ between $q(x_{t-1}|x_t, x_0)$ and $q_{\theta}(x_{t-1}|x_t)$. In other words, we learn $\mu_{\theta}(x_t, t)$ and $\Sigma_{\theta}$, the mean and variance of $q_{\theta}$, so that they match those of the forward process posterior. Hence our loss can be formulated as follows:

  • $^*$Note: additional terms of the original loss function are omitted for simplicity in this tutorial.

$$ L_t = \mathbb{E} \left[ D_{KL}( q(x_{t-1}|x_t, x_0) || q_{\theta}(x_{t-1}|x_t)) \right] $$

where $D_{KL}$ denotes the KL divergence (a measure of the discrepancy between two distributions).


Writing out $L_t$ (using the closed form of the KL divergence between Gaussians), we have


$$ L_t = \mathbb{E} \left[ \frac{{(1-\alpha_t)}^2}{2 \alpha_t (1-\bar{\alpha}_t) \lVert \Sigma_{\theta} \rVert _2^2} \lVert \epsilon_t - \epsilon_{\theta} (\sqrt{\bar{\alpha}_t}x_0 + \sqrt{1-\bar{\alpha}_t} \epsilon_t, t) \rVert ^2\right]. $$

The authors of [1] show that the training process works better with a simplified objective (by ignoring the weighting term)


$$ L_t = \mathbb{E} \left[ \lVert \epsilon_t - \epsilon_{\theta} (\sqrt{\bar{\alpha}_t}x_0 + \sqrt{1-\bar{\alpha}_t} \epsilon_t, t) \rVert ^2\right] $$

or

$$ L_t = \mathbb{E} \left[ \lVert \epsilon_t - \epsilon_{\theta} (x_t, t) \rVert ^2\right] . $$

Minimizing this training objective allows the neural network to learn the reverse process.
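
In code, this simplified objective is just the mean squared error between the true noise and the network's prediction. The sketch below illustrates it with a placeholder eps_model standing in for the U-Net $\epsilon_{\theta}(x_t, t)$ built in Section 3; the schedule values are again arbitrary.

In [ ]:
import numpy as np
import tensorflow as tf

# stand-in noise predictor; in Section 3 this role is played by the U-Net epsilon_theta(x_t, t)
def eps_model(x_t, t):
    return tf.zeros_like(x_t)                    # placeholder prediction

# toy schedule (illustrative values only)
T = 10
betas = np.linspace(0.01, 0.2, T)
alpha_bars = np.cumprod(1.0 - betas)

x0 = tf.random.normal((8, 28, 28, 1))            # a batch of "images"
t = np.random.randint(0, T, size=8)              # one timestep per sample
eps = tf.random.normal(tf.shape(x0))             # true noise epsilon_t

# x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * epsilon_t
a_bar = tf.reshape(tf.gather(alpha_bars, t), (-1, 1, 1, 1))
a_bar = tf.cast(a_bar, tf.float32)
x_t = tf.sqrt(a_bar) * x0 + tf.sqrt(1.0 - a_bar) * eps

# L_t = E[ || epsilon_t - epsilon_theta(x_t, t) ||^2 ]
L_t = tf.reduce_mean(tf.square(eps - eps_model(x_t, t)))
print(float(L_t))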


2.3. Sampling from the Trained Model

Once $q_{\theta}(x_{t-1}|x_t)$ is learned (i.e., $\mu_{\theta}(x_t, t)$ and $\Sigma_{\theta}$ are determined), we can sample from this distribution for $T$ steps to generate new images.


$$q_{\theta}(x_{t-1}|x_t) = N \left( x_{t-1}; \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{1-\alpha_t}{ \sqrt{1-\bar{\alpha}_t}} \epsilon_{\theta} \right) , \Sigma_{\theta} \right).$$

The authors of [1] show that learning a diagonal variance $\Sigma_{\theta}$ leads to unstable training and poorer sample quality, and therefore fix $\Sigma_{\theta} = \sigma_t^2 \mathbb{I}$ with $\sigma_t^2 = \beta_t$ instead of making it learnable. Using the reparametrization trick, we can then express $x_{t-1}$ given $x_t$ as


$$x_{t-1} = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}} \epsilon_{\theta}(x_t,t) \right) + \sigma_t \textbf{z} ,$$

where $\textbf{z} \sim N(0,\mathbb{I})$. We repeat this process for $t = T, T-1, \ldots, 1$ (setting $\textbf{z} = 0$ at the final step) to obtain a generated sample $x_0$.



3. Diffusion with Fashion MNIST

In [ ]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
In [ ]:
(train_x, train_y), _ = tf.keras.datasets.fashion_mnist.load_data()

# keep a single class (label 7, sneakers), scale to [-1, 1], and add a channel dimension
X_train = train_x[train_y.squeeze() == 7]
X_train = (X_train / 127.5) - 1.0
X_train = X_train[..., np.newaxis]

print('train_images :', train_x.shape)
train_images : (60000, 28, 28)

3.1. Forward Process Implementation

  • First, define the total number of timesteps $T$ and the noise schedule $\beta_t$; here a simple linear schedule is used.
In [ ]:
T = 50
beta_schedule = np.linspace(0, 1.0, T+1)   # beta_t for t = 0, 1, ..., T (simple linear schedule)

  • Use the tractable closed form for $x_t$ derived above to sample noisy images $x_t$.

$$x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1-\bar{\alpha}_t}\, \epsilon .$$
In [ ]:
def forward_diffusion(x0, t):
    alphas = 1. - beta_schedule
    alpha_bars = tf.math.cumprod(alphas, axis=0)

    # noise to be added (this is also the training target)
    epsilon = np.random.normal(size=x0.shape)

    # gather alpha_bar_t for each image and reshape for broadcasting over (H, W, C)
    alpha_bar_t = tf.gather(alpha_bars, t)
    alpha_bar_t = np.reshape(alpha_bar_t, (-1, 1, 1, 1))

    # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * epsilon
    noisy_image = np.sqrt(alpha_bar_t)*x0 + np.sqrt(1 - alpha_bar_t)*epsilon

    return noisy_image, epsilon
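
Using forward_diffusion, we can visually check how a training image is progressively corrupted as $t$ increases (the timesteps shown below are arbitrary).

In [ ]:
# visualize the forward process on a single training image at a few timesteps
ts_show = [0, 10, 20, 30, 40, 49]

plt.figure(figsize=(12, 2))
for i, t in enumerate(ts_show):
    x_t, _ = forward_diffusion(X_train[:1], np.array([t]))
    plt.subplot(1, len(ts_show), i + 1)
    plt.imshow(np.array(x_t).reshape(28, 28), 'gray')
    plt.title(f't = {t}')
    plt.axis('off')
plt.show()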

3.2. Model architecture

  • A simple U-Net, with three downsampling (encoder) blocks and three upsampling (decoder) blocks, is used as the neural network.

  • Note that the neural network $\epsilon_{\theta}(x_t,t)$ takes two inputs ($x_t$ and $t$), so we modify the U-Net to embed the timestep information $t$ (a time embedding).
In [ ]:
def time_embedding(x_img, x_ts, n_channels):
    # image branch that will be gated by the time embedding
    x_parameter = tf.keras.layers.Conv2D(n_channels, kernel_size=3, padding='same')(x_img)
    x_parameter = tf.keras.layers.Activation('relu')(x_parameter)

    # project the time embedding to n_channels and reshape for channel-wise scaling
    time_parameter = tf.keras.layers.Dense(n_channels)(x_ts)
    time_parameter = tf.keras.layers.Activation('relu')(time_parameter)
    time_parameter = tf.keras.layers.Reshape((1, 1, n_channels))(time_parameter)
    x_parameter = x_parameter * time_parameter

    # second image branch plus the time-modulated features
    x_out = tf.keras.layers.Conv2D(n_channels, kernel_size=3, padding='same')(x_img)
    x_out = x_out + x_parameter
    x_out = tf.keras.layers.BatchNormalization()(x_out)
    x_out = tf.keras.layers.Activation('relu')(x_out)

    return x_out

def unet():
    x_input = tf.keras.layers.Input(shape=(28, 28, 1), name='x_input')

    # timestep input and its embedding
    x_ts_input = tf.keras.layers.Input(shape=(1,), name='x_ts_input')
    x_ts = tf.keras.layers.Dense(192)(x_ts_input)
    x_ts_ln = tf.keras.layers.BatchNormalization()(x_ts)
    x_ts_out = tf.keras.layers.Activation('relu')(x_ts_ln)

    # encoder: 28x28 -> 14x14 -> 7x7, with a time embedding at each resolution
    x28_down = time_embedding(x_input, x_ts_out, 16)
    block1_down = tf.keras.layers.MaxPool2D(2)(x28_down)

    x14_down = time_embedding(block1_down, x_ts_out, 32)
    block2_down = tf.keras.layers.MaxPool2D(2)(x14_down)

    x7_down = time_embedding(block2_down, x_ts_out, 64)

    # bottleneck: flatten, concatenate the time embedding, and project back to 7x7x64
    mlp_input = tf.keras.layers.Flatten()(x7_down)
    mlp_concat = tf.keras.layers.Concatenate()([mlp_input, x_ts_out])
    mlp = tf.keras.layers.Dense(7*7*64)(mlp_concat)
    mlp_ln = tf.keras.layers.BatchNormalization()(mlp)
    mlp_out = tf.keras.layers.Activation('relu')(mlp_ln)
    mlp_out = tf.keras.layers.Reshape((7, 7, 64))(mlp_out)

    # decoder: 7x7 -> 14x14 -> 28x28, with skip connections from the encoder
    x7_up_input = tf.keras.layers.Concatenate()([mlp_out, x7_down])
    x7_up = time_embedding(x7_up_input, x_ts_out, 64)
    block1_up = tf.keras.layers.Conv2DTranspose(64, (4,4),strides=(2,2),
                                                padding='same')(x7_up)

    x14_up_input = tf.keras.layers.Concatenate()([block1_up, x14_down])
    x14_up = time_embedding(x14_up_input, x_ts_out, 32)
    block2_up = tf.keras.layers.Conv2DTranspose(32, (4,4),strides=(2,2),
                                                padding='same')(x14_up)

    x28_up_input = tf.keras.layers.Concatenate()([block2_up, x28_down])
    x28_up = time_embedding(x28_up_input, x_ts_out, 16)

    # final 1x1 convolution predicts the noise epsilon_theta(x_t, t)
    x_out = tf.keras.layers.Conv2D(1, kernel_size=1, padding='same')(x28_up)
    model = tf.keras.models.Model([x_input, x_ts_input], x_out)
    return model

model = unet()

# model.summary()
In [ ]:
optimizer = tf.keras.optimizers.Adam(learning_rate=0.0008)
loss_func = tf.keras.losses.MeanSquaredError()
model.compile(loss=loss_func, optimizer=optimizer)

3.3. Training (learning the reverse process)

Step 1: Sample $x_0$ from the dataset ($x_0 \sim p(x)$), sample a timestep uniformly ($t \sim$ Uniform$(\{1, \ldots, T\})$), and add the corresponding noise to obtain $x_t$.


Step 2: Feed $(x_t, t)$ through the neural network and minimize the loss

$$ L_t = \mathbb{E} \left[ \lVert \epsilon_t - \epsilon_{\theta} (x_t, t) \rVert ^2\right] . $$
In [ ]:
# Training loop

n_batch = 64
n_iter = 1000
n_print = 200

for e in range(n_iter):
    # sample a mini-batch of x_0 and a uniformly random timestep t per image
    x_img = X_train[np.random.randint(len(X_train), size=n_batch)]
    ts = np.random.randint(0, T, size=len(x_img))
    x_t, epsilon = forward_diffusion(x_img, ts)

    # regress epsilon_theta(x_t, t) onto the true noise epsilon (MSE loss)
    loss = model.train_on_batch([x_t, ts], epsilon)

    if e % n_print == 0:
        print(f"Iteration {e}/{n_iter}. Loss: {loss:.5f}")
Iteration 0/1000. Loss: 1.80499
Iteration 200/1000. Loss: 0.16916
Iteration 400/1000. Loss: 0.11770
Iteration 600/1000. Loss: 0.09738
Iteration 800/1000. Loss: 0.08678

3.4. Reverse Diffusion (sampling process)

Step 1: Start from isotropic Gaussian noise $x_T \sim N(0, \mathbb{I})$ and predict $\epsilon_{\theta}(x_t, t)$ with the neural network.

Step 2: Sample from distribution


$$q_{\theta}(x_{t-1}|x_t) = N \left( x_{t-1}; \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{1-\alpha_t}{ \sqrt{1-\bar{\alpha}_t}} \epsilon_{\theta} \right) , \sigma_t^2 \mathbb{I} \right)$$

to get


$$x_{t-1} = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}} \epsilon_{\theta}(x_t,t) \right) + \sigma_t \textbf{z} ,$$

using the reparametrization trick.

Step 3: Repeat Steps 1 and 2 for $t = T, T-1, \ldots, 1$ to obtain $x_0$.

In [ ]:
def reverse_diffusion(n_sample):
    # start from pure Gaussian noise x_T
    x_t = np.random.normal(size=(n_sample, 28, 28, 1))

    alphas = 1. - beta_schedule
    alpha_bars = tf.math.cumprod(alphas, axis=0)

    for t in range(T-1, 0, -1):
        # predict the noise epsilon_theta(x_t, t)
        e_t = model.predict([x_t, np.full((n_sample,), t)], verbose=0)

        # posterior mean: (x_t - (1 - alpha_t)/sqrt(1 - alpha_bar_t) * e_t) / sqrt(alpha_t)
        x_t_1 = (1/np.sqrt(alphas[t]))*(x_t - ((1-alphas[t])/np.sqrt(1-alpha_bars[t])*e_t))

        # add noise with sigma_t = sqrt(beta_t), except at the final step
        sigma_t = np.sqrt(beta_schedule[t])
        z = np.random.normal(size=(n_sample, 28, 28, 1))

        if t > 1:
            x_t = x_t_1 + sigma_t * z
        else:
            x_t = x_t_1

    x0 = x_t

    return x0

3.5. Visualize the Results

In [ ]:
x_generate = reverse_diffusion(10)
In [ ]:
plt.figure(figsize = (12, 3))
for i in range(x_generate.shape[0]):
    plt.subplot(1, 10, i + 1)
    plt.imshow(x_generate[i].reshape((28, 28)), 'gray', interpolation = 'nearest')
    plt.axis('off')

plt.show()