Diffusion Models
Source
CS480/680 Intro to Machine Learning
CS 198-126: Modern Computer Vision Fall 2022
Diffusion Model tutorial including the math
Diffusion models from scratch in Pytorch
[1] Denoising Diffusion Probabilistic Models
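Fundamentally, there are two processes: a forward (diffusion) process that gradually adds Gaussian noise to an image, and a reverse (denoising) process that a neural network learns in order to remove that noise step by step.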
After training, we can generate new samples by starting from pure noise (the latent input) and repeatedly denoising it with the trained neural network.
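The forward process iteratively adds noise to the image. The noise added at each step is Gaussian, so a single transition can be written as
$$ q(x_t|x_{t-1}) = N\left(x_t; \sqrt{1-\beta_t}\,x_{t-1}, \beta_t I\right), $$
where $\beta_t$ is the noise schedule, i.e., the variance of the noise added at step $t$.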
For a given original image $x_0$, this process is repeated for $T$ steps, until the image becomes (approximately) an isotropic Gaussian $N(0, I)$. Due to the Markov property, the distribution of the whole trajectory factors into a chain of individual transitions from time $1$ to $t$:
$$ q(x_{1:t}|x_0) = \prod_{s=1}^{t} q(x_s|x_{s-1}). $$
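A minimal NumPy sketch of this chain, applied to many copies of a scalar "image" $x_0 = 3$, shows the end state approaching $N(0, 1)$. It assumes the linear schedule from $\beta_1 = 10^{-4}$ to $\beta_T = 0.02$ with $T = 1000$ used in [1] (the notebook code uses a different, much shorter schedule), and `forward_chain` is a hypothetical helper name.

```python
import numpy as np

def forward_chain(x0, betas, rng):
    """Run the forward Markov chain x_0 -> x_1 -> ... -> x_T one step at a time.

    Each step samples x_t ~ N(sqrt(1 - beta_t) * x_{t-1}, beta_t * I).
    Only the final state x_T is returned.
    """
    x = x0
    for beta_t in betas:
        x = np.sqrt(1.0 - beta_t) * x + np.sqrt(beta_t) * rng.standard_normal(x.shape)
    return x

rng = np.random.default_rng(0)
betas = np.linspace(1e-4, 0.02, 1000)   # assumed linear schedule from [1]
x0 = np.full(10_000, 3.0)               # 10,000 copies of a scalar "image" far from 0
xT = forward_chain(x0, betas, rng)
print(xT.mean(), xT.std())              # mean close to 0, standard deviation close to 1
```

After 1000 small steps almost nothing of $x_0$ survives, which is why sampling can start from pure noise.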
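Because each transition is a linear Gaussian, their composition is also Gaussian, and the forward process distribution for a given input image $x_0$ at time $t$ has the closed form
$$ q(x_t|x_0) = N\left(x_t; \sqrt{\bar{\alpha}_t}\,x_0, (1-\bar{\alpha}_t) I\right), $$
where $\alpha_t = 1- \beta_t$, $\bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s$, and $\epsilon \sim N(0, I).$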
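Computing $\alpha_t$ and $\bar{\alpha}_t$ from a schedule is a single cumulative product. A minimal sketch, assuming the linear schedule of [1] ($\beta_1 = 10^{-4}$ to $\beta_T = 0.02$, $T = 1000$); the helper name `alphas_and_alpha_bars` is hypothetical, and array index `i` stands for timestep $t = i + 1$.

```python
import numpy as np

def alphas_and_alpha_bars(betas):
    """Return alpha_t = 1 - beta_t and alpha_bar_t = prod_{s=1}^{t} alpha_s."""
    alphas = 1.0 - np.asarray(betas, dtype=np.float64)
    return alphas, np.cumprod(alphas)

# Assumed schedule from [1]; array index i corresponds to timestep t = i + 1
betas = np.linspace(1e-4, 0.02, 1000)
alphas, alpha_bars = alphas_and_alpha_bars(betas)

# alpha_bar_t shrinks monotonically: the weight on x_0 fades while the noise variance grows
for t in (1, 100, 500, 1000):
    print(t, alpha_bars[t - 1], 1.0 - alpha_bars[t - 1])
```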
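By using the reparametrization trick, we can write $x_{t}$ as
$$ x_t = \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon, $$
which enables tractable, closed-form sampling of $x_t$ at any timestep without simulating the intermediate steps.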
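The closed form can be checked without sampling by tracking the first two moments of $x_t$ through the step-by-step recursion $x_t = \sqrt{\alpha_t}\,x_{t-1} + \sqrt{\beta_t}\,\epsilon_t$: the mean coefficient multiplies by $\sqrt{\alpha_t}$ and the variance follows $v_t = \alpha_t v_{t-1} + \beta_t$. This sketch again assumes the schedule of [1].

```python
import numpy as np

betas = np.linspace(1e-4, 0.02, 1000)   # assumed linear schedule from [1]
alphas = 1.0 - betas
alpha_bars = np.cumprod(alphas)

# Moments of x_t | x_0 under the recursion x_t = sqrt(alpha_t) x_{t-1} + sqrt(beta_t) eps_t:
# E[x_t] = c_t * x_0 and Var[x_t] = v_t, starting from c_0 = 1, v_0 = 0.
c, v = 1.0, 0.0
coefs, variances = [], []
for a_t, b_t in zip(alphas, betas):
    c = np.sqrt(a_t) * c
    v = a_t * v + b_t
    coefs.append(c)
    variances.append(v)

# Closed form: q(x_t | x_0) = N(sqrt(alpha_bar_t) x_0, (1 - alpha_bar_t) I)
print(np.allclose(coefs, np.sqrt(alpha_bars)), np.allclose(variances, 1.0 - alpha_bars))  # True True
```

Inductively, if $v_{t-1} = 1 - \bar{\alpha}_{t-1}$ then $v_t = \alpha_t - \bar{\alpha}_t + 1 - \alpha_t = 1 - \bar{\alpha}_t$, which is what the check confirms numerically.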
The goal of diffusion models is to learn the reverse process. The true reverse distribution $q(x_{t-1}|x_t)$ is intractable, because computing it requires knowledge of the entire data distribution. We therefore condition it on the input image $x_0$ and work with the forward process posterior $q(x_{t-1}|x_t, x_0)$, which is Gaussian:
$$ q(x_{t-1}|x_t, x_0) = N\left(x_{t-1}; \tilde{\mu}(x_t,x_0), \tilde{\beta}_t I\right). $$
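Completing the square in $x_{t-1}$ (Bayes' rule applied to $q(x_t|x_{t-1})$ and $q(x_{t-1}|x_0)$) gives $\tilde{\mu}(x_t,x_0)$ and $\tilde{\beta}_t$, and substituting the forward-process equation $x_t = \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon$ rewrites the mean in terms of the noise:
$$ \tilde{\mu}(x_t,x_0) = \frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{1-\bar{\alpha}_t}x_0 + \frac{\sqrt{\alpha_t}\,(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}x_t = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon\right), $$
$$ \tilde{\beta}_t = \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\beta_t. $$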
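These formulas can be verified numerically for a scalar case. The sketch below, assuming the schedule of [1] and hypothetical helper names, compares three things: the closed-form $\tilde{\mu}, \tilde{\beta}_t$; a direct Bayes computation that multiplies the two Gaussians in $x_{t-1}$ (precisions add, precision-weighted means add); and the $\epsilon$-form of the mean.

```python
import numpy as np

betas = np.linspace(1e-4, 0.02, 1000)  # assumed schedule; index i <-> timestep t = i + 1
alphas = 1.0 - betas
alpha_bars = np.cumprod(alphas)

def posterior_closed_form(x_t, x0, i):
    """mu_tilde and beta_tilde from the formulas above (needs i >= 1, i.e. t >= 2)."""
    a_t, b_t = alphas[i], betas[i]
    ab_t, ab_prev = alpha_bars[i], alpha_bars[i - 1]
    mu = (np.sqrt(ab_prev) * b_t / (1 - ab_t)) * x0 + (np.sqrt(a_t) * (1 - ab_prev) / (1 - ab_t)) * x_t
    var = (1 - ab_prev) / (1 - ab_t) * b_t
    return mu, var

def posterior_via_bayes(x_t, x0, i):
    """Multiply q(x_t | x_{t-1}) by q(x_{t-1} | x_0) and complete the square in x_{t-1}."""
    a_t, b_t, ab_prev = alphas[i], betas[i], alpha_bars[i - 1]
    precision = a_t / b_t + 1.0 / (1 - ab_prev)
    mean = (np.sqrt(a_t) * x_t / b_t + np.sqrt(ab_prev) * x0 / (1 - ab_prev)) / precision
    return mean, 1.0 / precision

def posterior_eps_form(x_t, eps, i):
    """Posterior mean written with the noise eps instead of x_0."""
    return (x_t - betas[i] / np.sqrt(1 - alpha_bars[i]) * eps) / np.sqrt(alphas[i])

rng = np.random.default_rng(0)
x0, eps, i = rng.standard_normal(5), rng.standard_normal(5), 500
x_t = np.sqrt(alpha_bars[i]) * x0 + np.sqrt(1 - alpha_bars[i]) * eps
mu_cf, var_cf = posterior_closed_form(x_t, x0, i)
mu_b, var_b = posterior_via_bayes(x_t, x0, i)
print(np.allclose(mu_cf, mu_b), np.isclose(var_cf, var_b),
      np.allclose(mu_cf, posterior_eps_form(x_t, eps, i)))  # True True True
```

The check also shows that $\tilde{\beta}_t < \beta_t$: conditioning on $x_0$ removes some uncertainty about $x_{t-1}$.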
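Here, we approximate this distribution with a neural network $q_{\theta}(x_{t-1}|x_t) = N\left(x_{t-1}; \mu_{\theta}(x_t,t), \Sigma_{\theta}(x_t,t)\right)$ by minimizing the divergence between $q(x_{t-1}|x_t, x_0)$ and $q_{\theta}(x_{t-1}|x_t)$. In other words, we learn $\mu_{\theta}(x_t,t)$ and $\Sigma_{\theta}$ to match the mean $\tilde{\mu}(x_t,x_0)$ and variance $\tilde{\beta}_t$ of the forward process posterior, even though the network only sees $x_t$ and $t$, not $x_0$. The loss also contains a prior term at $t=T$, which has no learnable parameters, and a reconstruction term at $t=1$. The remaining terms form a sum over timesteps:
$$ L = \sum_{t=2}^{T} L_t, \qquad L_t = \mathbb{E}_q\left[ D_{KL}\left(q(x_{t-1}|x_t, x_0) \,\|\, q_{\theta}(x_{t-1}|x_t)\right) \right], $$
where $D_{KL}$ refers to the KL-divergence, a measure of dissimilarity between distributions (not a true distance metric, since it is not symmetric).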
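Both distributions are Gaussian. With the variance of $q_{\theta}$ fixed to $\sigma_t^2 I$ (the choice $\sigma_t^2 = \beta_t$ is discussed below), the KL-divergence reduces to a squared distance between the means plus a constant $C$ that does not depend on $\theta$:
$$ L_t = \mathbb{E}_q\left[ \frac{1}{2\sigma_t^2} \lVert \tilde{\mu}(x_t,x_0) - \mu_{\theta}(x_t,t) \rVert^2 \right] + C. $$
Parametrizing the network mean in the same form as $\tilde{\mu}$, i.e. $\mu_{\theta}(x_t,t) = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon_{\theta}(x_t,t)\right)$, lets the network $\epsilon_{\theta}$ predict the noise, and gives
$$ L_t = \mathbb{E}_{x_0,\epsilon}\left[ \frac{\beta_t^2}{2\sigma_t^2\alpha_t(1-\bar{\alpha}_t)} \lVert \epsilon - \epsilon_{\theta}(\sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon, t) \rVert^2 \right] + C. $$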
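The step from the KL-divergence to a squared mean difference can be checked numerically. A minimal NumPy sketch, assuming isotropic covariances and illustrative fixed variances (0.015 for $\tilde{\beta}_t$ and 0.02 for $\sigma_t^2$) and a hypothetical helper `kl_isotropic_gaussians`. Subtracting the KL obtained when the means agree (the constant $C$) leaves exactly $\lVert \tilde{\mu} - \mu_{\theta} \rVert^2 / (2\sigma_t^2)$.

```python
import numpy as np

def kl_isotropic_gaussians(mu1, var1, mu2, var2):
    """KL( N(mu1, var1 I) || N(mu2, var2 I) ) for d-dimensional mean vectors."""
    d = mu1.size
    return 0.5 * (d * var1 / var2 + np.sum((mu2 - mu1) ** 2) / var2 - d + d * np.log(var2 / var1))

rng = np.random.default_rng(0)
mu_tilde, mu_theta = rng.standard_normal(8), rng.standard_normal(8)
beta_tilde, sigma2 = 0.015, 0.02   # illustrative fixed variances, not taken from the notebook

kl = kl_isotropic_gaussians(mu_tilde, beta_tilde, mu_theta, sigma2)
C = kl_isotropic_gaussians(mu_tilde, beta_tilde, mu_tilde, sigma2)  # KL when the means agree
print(np.isclose(kl - C, np.sum((mu_tilde - mu_theta) ** 2) / (2 * sigma2)))  # True
```

Because $C$ depends only on the fixed variances, it has no gradient with respect to $\theta$ and can be dropped from the objective.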
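The authors of [1] report that training works better, in terms of sample quality, with a simplified objective that drops the weighting term:
$$ L_{\text{simple}} = \mathbb{E}_{t,x_0,\epsilon}\left[ \lVert \epsilon - \epsilon_{\theta}(\sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon, t) \rVert^2 \right] $$
or, writing $x_t$ for the noisy input and $\epsilon_t$ for the noise used to create it,
$$ L_t = \mathbb{E} \left[ \lVert \epsilon_t - \epsilon_{\theta} (x_t, t) \rVert ^2\right] . $$
Minimizing this training objective, i.e., training the network to predict the noise that was added, allows the neural network to learn the reverse process.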
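To see what the simplification changes, the weight $\frac{\beta_t^2}{2\sigma_t^2\alpha_t(1-\bar{\alpha}_t)}$ in front of the noise-prediction error can be evaluated directly. This sketch assumes the schedule of [1] and $\sigma_t^2 = \beta_t$, in which case the weight simplifies to $\beta_t / (2\alpha_t(1-\bar{\alpha}_t))$.

```python
import numpy as np

betas = np.linspace(1e-4, 0.02, 1000)    # assumed schedule from [1]
alphas = 1.0 - betas
alpha_bars = np.cumprod(alphas)
sigma2 = betas                           # fixed reverse variance sigma_t^2 = beta_t

# Weight multiplying ||eps - eps_theta||^2 in the full L_t; L_simple replaces it with 1
weights = betas ** 2 / (2 * sigma2 * alphas * (1 - alpha_bars))
print(weights[0], weights[499], weights[-1])
```

At $t = 1$ the weight is $1/(2\alpha_1) \approx 0.5$, while at large $t$ it is on the order of $0.01$. Setting every weight to 1 therefore down-weights the easy, small-$t$ denoising terms relative to the harder large-$t$ ones.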
Once $q_{\theta}(x_{t-1}|x_t)$ is learned (i.e., $\mu_{\theta}(x_t,t)$ and $\Sigma_{\theta}$ are learned), we can generate samples by starting from noise and sampling from this distribution for $T$ steps.
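The authors of [1] found that learning a diagonal variance $\Sigma_{\theta}$ leads to unstable training and poorer sample quality, so they fix $\Sigma_{\theta} = \sigma_t^2 I$ with $\sigma_t^2 = \beta_t$ instead of making it learnable. Using the reparametrization trick, we can then express $x_{t-1}$ given $x_t$ as
$$ x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon_{\theta}(x_t, t)\right) + \sigma_t \textbf{z}, $$
where $\textbf{z} \sim N(0,I)$. We repeat this process for $t=T,T-1, \dots,1$, setting $\textbf{z} = 0$ at the final step $t=1$, which yields the sample $x_0$.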
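A useful sanity check for the sampling loop is to replace $\epsilon_{\theta}$ by a hypothetical "oracle" that returns the exact noise for a dataset containing a single point $x_0^{\ast}$. In that case every mean equals $\tilde{\mu}(x_t, x_0^{\ast})$, and the noise-free last step lands exactly on $x_0^{\ast}$. A minimal NumPy sketch, assuming the schedule of [1] and hypothetical helpers `reverse_step` and `oracle_eps`. A trained network only approximates this oracle.

```python
import numpy as np

def reverse_step(x_t, eps_pred, i, alphas, alpha_bars, betas, rng):
    """One ancestral sampling step x_t -> x_{t-1} with sigma_t^2 = beta_t.

    Array index i corresponds to timestep t = i + 1; no noise is added at t = 1 (i == 0).
    """
    mean = (x_t - (1 - alphas[i]) / np.sqrt(1 - alpha_bars[i]) * eps_pred) / np.sqrt(alphas[i])
    if i == 0:
        return mean
    return mean + np.sqrt(betas[i]) * rng.standard_normal(x_t.shape)

betas = np.linspace(1e-4, 0.02, 1000)   # assumed schedule from [1]
alphas = 1.0 - betas
alpha_bars = np.cumprod(alphas)
rng = np.random.default_rng(0)

# Hypothetical oracle for a one-point dataset: the exact noise mapping x0_true to x_t
x0_true = np.array([0.5, -1.0, 2.0])
def oracle_eps(x_t, i):
    return (x_t - np.sqrt(alpha_bars[i]) * x0_true) / np.sqrt(1 - alpha_bars[i])

x = rng.standard_normal(3)                 # x_T ~ N(0, I)
for i in reversed(range(len(betas))):      # t = T, T-1, ..., 1
    x = reverse_step(x, oracle_eps(x, i), i, alphas, alpha_bars, betas, rng)
print(np.allclose(x, x0_true))             # True: the final step recovers x0_true
```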
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
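# Load Fashion-MNIST and keep only one class (label 7 = "Sneaker")
(train_x, train_y), _ = tf.keras.datasets.fashion_mnist.load_data()
X_train = train_x[train_y.squeeze() == 7]
# Scale pixel values from [0, 255] to [-1, 1]
X_train = ((X_train / 127.5) - 1.0).astype(np.float32)
print('train_images :', X_train.shape)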
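T = 50  # number of diffusion steps
# Linear noise schedule indexed by t = 0..T: beta_0 = 0 (no noise), beta_T = 1.0 (pure noise)
beta_schedule = np.linspace(0, 1.0, T+1)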
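def forward_diffusion(x0, t):
    """Sample x_t ~ q(x_t | x_0) in closed form; return x_t and the noise that was added."""
    alphas = 1. - beta_schedule
    alpha_bars = np.cumprod(alphas)
    epsilon = np.random.normal(size=x0.shape).astype(np.float32)
    # Reshape alpha_bar_t to (batch, 1, 1, ...) so it broadcasts over each image
    alpha_bar_t = alpha_bars[t].reshape((-1,) + (1,) * (x0.ndim - 1))
    noisy_image = np.sqrt(alpha_bar_t) * x0 + np.sqrt(1 - alpha_bar_t) * epsilon
    return noisy_image.astype(np.float32), epsilon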
def time_embedding(x_img, x_ts, n_channels):
    # Convolutional block conditioned on the timestep features x_ts
    x_parameter = tf.keras.layers.Conv2D(n_channels, kernel_size=3, padding='same')(x_img)
    x_parameter = tf.keras.layers.Activation('relu')(x_parameter)
    # Project the timestep features to n_channels and broadcast them over the spatial dims
    time_parameter = tf.keras.layers.Dense(n_channels)(x_ts)
    time_parameter = tf.keras.layers.Activation('relu')(time_parameter)
    time_parameter = tf.keras.layers.Reshape((1, 1, n_channels))(time_parameter)
    # Modulate the image features multiplicatively by the timestep features
    x_parameter = x_parameter * time_parameter
    # Add a second, unconditioned convolution branch, then normalize and activate
    x_out = tf.keras.layers.Conv2D(n_channels, kernel_size=3, padding='same')(x_img)
    x_out = x_out + x_parameter
    x_out = tf.keras.layers.BatchNormalization()(x_out)
    x_out = tf.keras.layers.Activation('relu')(x_out)
    return x_out
def unet():
    # Inputs: a noisy image x_t and its timestep t
    x_input = tf.keras.layers.Input(shape=(28, 28, 1), name='x_input')
    x_ts_input = tf.keras.layers.Input(shape=(1,), name='x_ts_input')
    # Timestep features shared by all blocks
    x_ts = tf.keras.layers.Dense(192)(x_ts_input)
    x_ts_ln = tf.keras.layers.BatchNormalization()(x_ts)
    x_ts_out = tf.keras.layers.Activation('relu')(x_ts_ln)

    # Encoder: 28x28 -> 14x14 -> 7x7
    x28_down = time_embedding(x_input, x_ts_out, 16)
    block1_down = tf.keras.layers.MaxPool2D(2)(x28_down)
    x14_down = time_embedding(block1_down, x_ts_out, 32)
    block2_down = tf.keras.layers.MaxPool2D(2)(x14_down)
    x7_down = time_embedding(block2_down, x_ts_out, 64)

    # Bottleneck: MLP on the flattened 7x7 features concatenated with the timestep features
    mlp_input = tf.keras.layers.Flatten()(x7_down)
    mlp_concat = tf.keras.layers.Concatenate()([mlp_input, x_ts_out])
    mlp = tf.keras.layers.Dense(7*7*64)(mlp_concat)
    mlp_ln = tf.keras.layers.BatchNormalization()(mlp)
    mlp_out = tf.keras.layers.Activation('relu')(mlp_ln)
    mlp_out = tf.keras.layers.Reshape((7, 7, 64))(mlp_out)

    # Decoder with skip connections: 7x7 -> 14x14 -> 28x28
    x7_up_input = tf.keras.layers.Concatenate()([mlp_out, x7_down])
    x7_up = time_embedding(x7_up_input, x_ts_out, 64)
    block1_up = tf.keras.layers.Conv2DTranspose(64, (4, 4), strides=(2, 2),
                                                padding='same')(x7_up)
    x14_up_input = tf.keras.layers.Concatenate()([block1_up, x14_down])
    x14_up = time_embedding(x14_up_input, x_ts_out, 32)
    block2_up = tf.keras.layers.Conv2DTranspose(32, (4, 4), strides=(2, 2),
                                                padding='same')(x14_up)
    x28_up_input = tf.keras.layers.Concatenate()([block2_up, x28_down])
    x28_up = time_embedding(x28_up_input, x_ts_out, 16)

    # Output: predicted noise eps_theta(x_t, t), same shape as the input image
    x_out = tf.keras.layers.Conv2D(1, kernel_size=1, padding='same')(x28_up)
    model = tf.keras.models.Model([x_input, x_ts_input], x_out)
    return model
model = unet()
# model.summary()
optimizer = tf.keras.optimizers.Adam(learning_rate=0.0008)
loss_func = tf.keras.losses.MeanSquaredError()
model.compile(loss=loss_func, optimizer=optimizer)
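The network above conditions on $t$ by feeding the raw integer timestep into a `Dense` layer. DDPM [1] instead specifies $t$ with the Transformer sinusoidal position embedding. Below is a minimal NumPy sketch of that embedding. The name `sinusoidal_embedding`, the base 10000, and the sin/cos layout follow the common Transformer convention and are assumptions; none of this is used by this notebook. Using it here would mean changing `x_ts_input` to shape `(dim,)` and feeding the embedded timesteps instead of `ts`.

```python
import numpy as np

def sinusoidal_embedding(t, dim):
    """Embed integer timesteps t (shape (n,)) into (n, dim) sin/cos features; dim must be even."""
    half = dim // 2
    # Geometric sequence of frequencies from 1 down to roughly 1/10000
    freqs = np.exp(-np.log(10000.0) * np.arange(half) / half)
    args = np.asarray(t, dtype=np.float64)[:, None] * freqs[None, :]
    return np.concatenate([np.sin(args), np.cos(args)], axis=1)

emb = sinusoidal_embedding(np.array([0, 1, 25, 49]), 32)
print(emb.shape)  # (4, 32)
```

Each timestep maps to a distinct, bounded vector, so the network does not have to cope with raw inputs ranging from 1 to $T$.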
Step 1: Sample $x_0$ from the dataset ($x_0 \sim p(x)$), sample a timestep $t \sim \mathrm{Uniform}(\{1, \dots, T\})$, and add noise to obtain $x_t$.
Step 2: Feed $(x_t, t)$ through the neural network and minimize the loss
$$ L_t = \mathbb{E} \left[ \lVert \epsilon_t - \epsilon_{\theta} (x_t, t) \rVert ^2\right] . $$
# Training loop
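n_batch = 64
n_iter = 1000
n_print = 200
for e in range(n_iter):
    # Step 1: sample a batch of images and timesteps, then add noise
    x_img = X_train[np.random.randint(len(X_train), size=n_batch)]
    # Skip t = 0 (beta_0 = 0 adds no noise) and t = T (the sampler below never uses it)
    ts = np.random.randint(1, T, size=len(x_img))
    x_t, epsilon = forward_diffusion(x_img, ts)
    # Step 2: regress the added noise with the MSE loss
    loss = model.train_on_batch([x_t, ts.reshape(-1, 1).astype(np.float32)], epsilon)
    if e % n_print == 0:
        print(f"Iteration {e}/{n_iter}. Loss: {loss:.5f}")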
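One way to read the printed losses: the target $\epsilon$ is standard normal, so a network that always predicts zero would score a mean-squared error of about 1 ($\mathbb{E}[\epsilon^2] = 1$ per pixel). Losses well below 1 therefore indicate that the model has learned something beyond the trivial baseline. The NumPy check below assumes a batch shaped like the training targets, `(64, 28, 28)`.

```python
import numpy as np

rng = np.random.default_rng(0)
eps = rng.standard_normal((64, 28, 28))        # noise targets, shaped like one training batch
baseline_mse = np.mean((eps - np.zeros_like(eps)) ** 2)
print(baseline_mse)                            # close to 1.0
```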
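Step 1: Start from isotropic noise $x_T \sim N(0, I)$, and predict $\epsilon_{\theta}(x_t, t)$ with the neural network.
Step 2: Sample $x_{t-1}$ from $q_{\theta}(x_{t-1}|x_t) = N\left(x_{t-1}; \mu_{\theta}(x_t,t), \sigma_t^2 I\right)$ with $\sigma_t^2 = \beta_t$, using the reparametrization trick:
$$ x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon_{\theta}(x_t, t)\right) + \sigma_t \textbf{z}, \quad \textbf{z} \sim N(0, I). $$
Step 3: Repeat the prediction and sampling for $t = T, \dots, 1$ (with $\textbf{z} = 0$ at $t = 1$) to obtain $x_0$.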
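def reverse_diffusion(n_sample):
    # Step 1: start from isotropic noise x_T ~ N(0, I)
    x_t = np.random.normal(size=(n_sample, 28, 28, 1)).astype(np.float32)
    alphas = 1. - beta_schedule
    alpha_bars = np.cumprod(alphas)
    # Denoise for t = T-1, ..., 1 (t = T is skipped: beta_T = 1 makes alphas[T] = 0)
    for t in range(T - 1, 0, -1):
        t_input = np.full((n_sample, 1), t, dtype=np.float32)
        # Predict the noise eps_theta(x_t, t); calling the model directly is faster than predict() for one batch
        e_t = np.asarray(model([x_t, t_input], training=False))
        # Step 2: mean of q_theta(x_{t-1} | x_t), plus sigma_t * z with sigma_t^2 = beta_t
        x_t_1 = (1 / np.sqrt(alphas[t])) * (x_t - ((1 - alphas[t]) / np.sqrt(1 - alpha_bars[t]) * e_t))
        if t > 1:
            sigma_t = np.sqrt(beta_schedule[t])
            z = np.random.normal(size=(n_sample, 28, 28, 1))
            x_t = x_t_1 + sigma_t * z
        else:
            x_t = x_t_1  # no noise is added at the final step t = 1
        x_t = x_t.astype(np.float32)
    # Step 3: after the loop, x_t holds the generated samples x_0
    return x_t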
x_generate = reverse_diffusion(10)
plt.figure(figsize=(12, 3))
for i in range(x_generate.shape[0]):
    plt.subplot(1, 10, i + 1)
    plt.imshow(x_generate[i].reshape((28, 28)), cmap='gray', interpolation='nearest')
    plt.axis('off')
plt.show()