Style Transfer
1. Neural Style Transfer
Neural style transfer (NST) is a technique in computer vision and image synthesis that generates a new image by combining the content of one image with the style of another. The content image defines the spatial layout and objects, while the style image contributes color, texture, and brushstroke patterns.
The technique was first introduced by Gatys et al. in 2015 and has since become a foundational approach in image generation and visual creativity using deep neural networks.
The figure below illustrates an example of artistic style transfer. On the left is a content image of a European town, and on the right is the generated image that adopts the painting style of Vincent van Gogh’s Starry Night (shown in the inset). The result preserves the structure of the content image while expressing the texture and color palette of the style image.
1.1. Problem Definition
Given two input images:
- A content image $I_c$, which provides the spatial structure, layout, and objects that should be preserved
- A style image $I_s$, which contributes the textures, colors, and artistic patterns to be transferred
The objective is to synthesize a third image $I_g$ (the generated image) such that:
- The structural and semantic content in $I_g$ resembles that of $I_c$
- The visual appearance (in terms of style) of $I_g$ matches that of $I_s$
This is accomplished by optimizing a loss function that balances content and style terms. The resulting image $I_g$ is computed iteratively so that it matches the high-level features of the content image while reproducing the statistical patterns (style) extracted from the style image.
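Anticipating the loss terms defined in Section 1.4, this trade-off can be written as a single optimization problem over the pixels of $I_g$:
$$ I_g^{*} = \arg\min_{I_g} \; \alpha \, \mathcal{L}_{\text{content}}(I_g, I_c) + \beta \, \mathcal{L}_{\text{style}}(I_g, I_s) $$
where the weights $\alpha$ and $\beta$ control how strongly content and style are preserved.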
1.2. Feature Extraction using CNNs
Neural style transfer relies on the hierarchical representation capabilities of convolutional neural networks (CNNs), such as VGG-19 pretrained on the ImageNet dataset. These networks extract multi-scale features from images, which are useful for encoding both spatial content and visual style.
- Lower convolutional layers tend to capture local structures, such as edges, textures, and basic color contrasts.
- Higher convolutional layers encode more abstract representations, including object shapes, part relationships, and spatial configurations.
Because of this hierarchy, intermediate feature maps extracted from different layers of a CNN can be used to characterize the content and style of an image:
- Content features are drawn from deeper layers, where spatial semantics are preserved.
- Style features are obtained by computing feature correlations (e.g., via Gram matrices) across multiple layers, capturing texture and appearance.
While this interpretation is not mathematically rigorous in a strict sense, it is strongly supported by empirical results across many experiments. In practice, this layer-wise disentanglement of content and style forms the foundation for neural style transfer.
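As a concrete illustration, intermediate activations can be read out of a pretrained network with a few lines of code. The sketch below uses tf.keras's bundled VGG16, whose layer names such as block4_conv2 correspond to the conv4_2 notation used later; the implementation in this notebook instead rebuilds the same network manually from the pretrained VGG16 weights.
import tensorflow as tf

# pretrained VGG16 used purely as a fixed feature extractor
vgg = tf.keras.applications.VGG16(include_top = False, weights = 'imagenet')
vgg.trainable = False

layer_names = ['block1_conv1', 'block2_conv1', 'block3_conv1',
               'block4_conv1', 'block5_conv1', 'block4_conv2']
extractor = tf.keras.Model(inputs = vgg.input,
                           outputs = [vgg.get_layer(name).output for name in layer_names])

# one feature map per requested layer (here for a random test image;
# real images should first go through keras.applications.vgg16.preprocess_input)
features = extractor(tf.random.uniform((1, 224, 224, 3), maxval = 255.0))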
1.3. Overview of the Style Transfer Procedure
This section outlines the overall structure of the neural style transfer procedure. The method follows a well-defined sequence of steps that combine feature extraction, loss computation, and image optimization.
Step 1: Select a pretrained CNN
- Use a fixed, pretrained convolutional neural network such as VGG‑19, trained on ImageNet.
- The network serves as a feature extractor; its weights remain unchanged during the optimization process.
Step 2: Extract features from the content and style images
- Feed the content image $I_c$ and the style image $I_s$ independently through identical copies of the pretrained network.
- Content features are extracted from deeper layers (e.g., conv4_2) that preserve spatial structure and object layout.
- Style features are extracted from multiple shallower and intermediate layers (e.g., conv1_1, conv2_1, conv3_1, conv4_1, conv5_1), where texture and low-level statistics are more prominent.
Step 3: Initialize the generated image
- Initialize the generated image $I_g$, e.g., with white noise or with a copy of $I_c$.
- Unlike $I_c$ and $I_s$, the generated image $I_g$ is the only input that will be updated through gradient-based optimization.
Step 4: Optimization loop
- Forward $I_g$ through the same pretrained network to compute its content and style features.
- Compute the content loss between the features of $I_g$ and those of $I_c$.
- Compute the style loss between the Gram matrices of $I_g$ and $I_s$ across the selected style layers.
- Backpropagate the total loss with respect to the pixels of $I_g$.
- Update $I_g$ using an optimizer such as L-BFGS or Adam.
- Repeat this process for a fixed number of iterations or until convergence.
Step 5: Result
After optimization, the resulting image $I_g$ will combine:
- The structural content of $I_c$, and
- The visual style of $I_s$
This synthesis is achieved by aligning intermediate representations of $I_g$ with those of both $I_c$ and $I_s$ through loss minimization.
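The loop in Step 4 can be summarized in a few lines of code. Below is a minimal sketch using the TensorFlow 2.x eager API, where total_loss stands in for the combined loss defined in Section 1.4 (a hypothetical callable, not part of this notebook); the implementation later in this section realizes the same loop with TensorFlow 1.x placeholders and sessions.
import tensorflow as tf

def run_style_transfer(total_loss, img_content, img_style, n_iter = 500, lr = 5.0):
    # the generated image is the only trainable variable;
    # here it is initialized with a copy of the content image
    img_gen = tf.Variable(tf.cast(img_content, tf.float32))
    opt = tf.keras.optimizers.Adam(learning_rate = lr)

    for _ in range(n_iter):
        with tf.GradientTape() as tape:
            loss = total_loss(img_gen, img_content, img_style)
        grad = tape.gradient(loss, img_gen)                     # gradient w.r.t. pixels only
        opt.apply_gradients([(grad, img_gen)])
        img_gen.assign(tf.clip_by_value(img_gen, 0.0, 255.0))   # keep pixel values valid

    return img_gen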
1.4. Loss in Neural Style Transfer
In neural style transfer, the optimization of the generated image relies on carefully defined loss functions. These guide the synthesis toward blending content from one image with the artistic style of another. This section outlines each component of the loss and explains how they are combined and optimized.
1.4.1. Content Representation
Let $F^l(I)$ be the feature map of image $I$ extracted at layer $l$ of a pretrained CNN. The content loss measures the difference between high-level feature representations of the generated and content images:
$$ \mathcal{L}_{\text{content}}(I_g, I_c) = \frac{1}{2} \sum_{i,j} \left( F^l_{ij}(I_g) - F^l_{ij}(I_c) \right)^2 $$
This loss encourages the generated image $I_g$ to preserve the semantic structure of the content image $I_c$. Importantly, it is computed in feature space, not pixel space.
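A direct transcription of this formula as a function (a sketch; the notebook's own get_loss_content further below uses the mean of squared differences instead of the sum, which only rescales the loss):
import tensorflow as tf

def content_loss(F_gen, F_content):
    # 1/2 * sum of squared differences between the two feature maps
    return 0.5 * tf.reduce_sum(tf.square(F_gen - F_content))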
1.4.2. Style Representation via Gram Matrices
To capture style, the correlations between feature channels are considered. This is achieved through the Gram matrix:
(1) Given feature map $F^l(I)$ of shape $C \times H \times W$, flatten spatial dimensions into $C \times (H W)$
(2) Compute the Gram matrix $G^l(I) \in \mathbb{R}^{C \times C}$:
$$ G^l_{ij}(I) = \sum_{k} F^l_{ik}(I) \cdot F^l_{jk}(I) $$
or
$$G^{l}(I) = F^{l}(I) \; \left( F^{l}(I) \right)^T$$
This matrix encodes how often feature channels $i$ and $j$ are activated together.
To transfer style, we synthesize an image whose feature correlations (Gram matrices) match those of the style image. The style loss is then:
$$ \mathcal{L}_{\text{style}}(I_g, I_s) = \sum_{l \in \mathcal{L}_s} w_l \cdot \left\| G^l(I_g) - G^l(I_s) \right\|^2 $$
where $\mathcal{L}_s$ is the set of layers used to represent style, and $w_l$ are their weights.
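Both definitions translate almost line by line into code. A sketch (the notebook's get_gram_matrix below additionally normalizes the Gram matrix by its size, and its style loss uses equal layer weights):
import tensorflow as tf

def gram_matrix(F):
    # F: feature map of shape [1, H, W, C] -> C x C Gram matrix
    C = F.shape[-1]
    F = tf.reshape(F, (-1, C))                  # flatten spatial dimensions to (H*W, C)
    return tf.matmul(F, F, transpose_a = True)  # equals F F^T in the C x (HW) convention above

def style_loss(feats_gen, feats_style, layer_weights):
    # feats_gen, feats_style: dicts mapping layer name -> feature map
    # layer_weights: dict mapping layer name -> w_l
    loss = 0.0
    for name, w_l in layer_weights.items():
        G_gen = gram_matrix(feats_gen[name])
        G_style = gram_matrix(feats_style[name])
        loss += w_l * tf.reduce_sum(tf.square(G_gen - G_style))
    return loss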
1.4.3. Total Variation Loss
Optimization may introduce high-frequency noise in $I_g$. This is mitigated by total variation loss, which encourages spatial smoothness:
$$ \mathcal{L}_{\text{TV}}(I_g) = \sum_{i,j} \; \lvert I_g[i,j+1] - I_g[i,j] \rvert + \lvert I_g[i+1,j] - I_g[i,j]\rvert $$
or
$$\sum_{i,j} \; \left| x_{i,j+1}- x_{i,j}\right| + \left| x_{i+1,j} - x_{i,j}\right|$$
This loss penalizes large intensity differences between neighboring pixels.
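To make the definition concrete, here is a tiny hand-checked example on a 2 x 2 image (a NumPy sketch):
import numpy as np

x = np.array([[0., 10.],
              [5., 10.]])

# horizontal differences: |10 - 0| + |10 - 5|  = 15
# vertical differences:   |5 - 0|  + |10 - 10| = 5
tv = np.sum(np.abs(x[:, 1:] - x[:, :-1])) + np.sum(np.abs(x[1:, :] - x[:-1, :]))
print(tv)   # 20.0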
1.4.4. Combined Loss and Optimization
The total loss combines content, style, and total variation components:
$$ \mathcal{L}_{\text{total}} = \alpha \cdot \mathcal{L}_{\text{content}} + \beta \cdot \mathcal{L}_{\text{style}} + \gamma \cdot \mathcal{L}_{\text{TV}} $$
where:
- $\alpha$: weight for preserving content
- $\beta$: weight for applying style
- $\gamma$: weight for smoothing artifacts
During optimization:
(1) Initialize $I_g$ (e.g., with white noise or a copy of $I_c$)
(2) Keep all CNN weights fixed
(3) Compute each loss via forward passes
(4) Backpropagate gradients with respect to $I_g$ only
(5) Update $I_g$ using an optimizer
(6) Repeat until synthesis achieves a satisfactory blend of content, style, and smoothness
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from keras.applications.vgg16 import VGG16
import cv2
h_image, w_image = 600, 1000
img_content = cv2.imread('./image_files/postech_flag.jpg')
img_content = cv2.cvtColor(img_content, cv2.COLOR_BGR2RGB)
img_content = cv2.resize(img_content, (w_image, h_image))
plt.figure(figsize = (10,8))
plt.imshow(img_content)
plt.axis('off')
plt.show()
img_style = cv2.imread('./image_files/la_muse.jpg')
img_style = cv2.cvtColor(img_style, cv2.COLOR_BGR2RGB)
img_style = cv2.resize(img_style, (w_image, h_image))
plt.figure(figsize = (10,8))
plt.imshow(img_style)
plt.axis('off')
plt.show()
Pre-trained Model (VGG16)
model = VGG16(weights = 'imagenet')
model.summary()
vgg16_weights = model.get_weights()
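# model.get_weights() returns the [kernel, bias] pairs of the 13 convolutional
# layers in order (indices 0-25); the fully connected weights that follow them
# are not used for style transfer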
# kernel size: [kernel_height, kernel_width, input_ch, output_ch]
weights = {
'conv1_1' : tf.constant(vgg16_weights[0]),
'conv1_2' : tf.constant(vgg16_weights[2]),
'conv2_1' : tf.constant(vgg16_weights[4]),
'conv2_2' : tf.constant(vgg16_weights[6]),
'conv3_1' : tf.constant(vgg16_weights[8]),
'conv3_2' : tf.constant(vgg16_weights[10]),
'conv3_3' : tf.constant(vgg16_weights[12]),
'conv4_1' : tf.constant(vgg16_weights[14]),
'conv4_2' : tf.constant(vgg16_weights[16]),
'conv4_3' : tf.constant(vgg16_weights[18]),
'conv5_1' : tf.constant(vgg16_weights[20]),
'conv5_2' : tf.constant(vgg16_weights[22]),
'conv5_3' : tf.constant(vgg16_weights[24]),
}
# bias size: [output_ch] or [neuron_size]
biases = {
'conv1_1' : tf.constant(vgg16_weights[1]),
'conv1_2' : tf.constant(vgg16_weights[3]),
'conv2_1' : tf.constant(vgg16_weights[5]),
'conv2_2' : tf.constant(vgg16_weights[7]),
'conv3_1' : tf.constant(vgg16_weights[9]),
'conv3_2' : tf.constant(vgg16_weights[11]),
'conv3_3' : tf.constant(vgg16_weights[13]),
'conv4_1' : tf.constant(vgg16_weights[15]),
'conv4_2' : tf.constant(vgg16_weights[17]),
'conv4_3' : tf.constant(vgg16_weights[19]),
'conv5_1' : tf.constant(vgg16_weights[21]),
'conv5_2' : tf.constant(vgg16_weights[23]),
'conv5_3' : tf.constant(vgg16_weights[25]),
}
# input layer: [1, image_height, image_width, channels]
input_content = tf.placeholder(tf.float32, [1, h_image, w_image, 3])
input_style = tf.placeholder(tf.float32, [1, h_image, w_image, 3])
def net(x, weights, biases):
# First convolution layer
conv1_1 = tf.nn.conv2d(x,
weights['conv1_1'],
strides = [1, 1, 1, 1],
padding = 'SAME')
conv1_1 = tf.nn.relu(tf.add(conv1_1, biases['conv1_1']))
conv1_2 = tf.nn.conv2d(conv1_1,
weights['conv1_2'],
strides = [1, 1, 1, 1],
padding = 'SAME')
conv1_2 = tf.nn.relu(tf.add(conv1_2, biases['conv1_2']))
maxp1 = tf.nn.max_pool(conv1_2,
ksize = [1, 2, 2, 1],
strides = [1, 2, 2, 1],
padding = 'VALID')
# Second convolution layer
conv2_1 = tf.nn.conv2d(maxp1,
weights['conv2_1'],
strides = [1, 1, 1, 1],
padding = 'SAME')
conv2_1 = tf.nn.relu(tf.add(conv2_1, biases['conv2_1']))
conv2_2 = tf.nn.conv2d(conv2_1,
weights['conv2_2'],
strides = [1, 1, 1, 1],
padding = 'SAME')
conv2_2 = tf.nn.relu(tf.add(conv2_2, biases['conv2_2']))
maxp2 = tf.nn.max_pool(conv2_2,
ksize = [1, 2, 2, 1],
strides = [1, 2, 2, 1],
padding = 'VALID')
# third convolution layer
conv3_1 = tf.nn.conv2d(maxp2,
weights['conv3_1'],
strides = [1, 1, 1, 1],
padding = 'SAME')
conv3_1 = tf.nn.relu(tf.add(conv3_1, biases['conv3_1']))
conv3_2 = tf.nn.conv2d(conv3_1,
weights['conv3_2'],
strides = [1, 1, 1, 1],
padding = 'SAME')
conv3_2 = tf.nn.relu(tf.add(conv3_2, biases['conv3_2']))
conv3_3 = tf.nn.conv2d(conv3_2,
weights['conv3_3'],
strides = [1, 1, 1, 1],
padding = 'SAME')
conv3_3 = tf.nn.relu(tf.add(conv3_3, biases['conv3_3']))
maxp3 = tf.nn.max_pool(conv3_3,
ksize = [1, 2, 2, 1],
strides = [1, 2, 2, 1],
padding = 'VALID')
# fourth convolution layer
conv4_1 = tf.nn.conv2d(maxp3,
weights['conv4_1'],
strides = [1, 1, 1, 1],
padding = 'SAME')
conv4_1 = tf.nn.relu(tf.add(conv4_1, biases['conv4_1']))
conv4_2 = tf.nn.conv2d(conv4_1,
weights['conv4_2'],
strides = [1, 1, 1, 1],
padding = 'SAME')
conv4_2 = tf.nn.relu(tf.add(conv4_2, biases['conv4_2']))
conv4_3 = tf.nn.conv2d(conv4_2,
weights['conv4_3'],
strides = [1, 1, 1, 1],
padding = 'SAME')
conv4_3 = tf.nn.relu(tf.add(conv4_3, biases['conv4_3']))
maxp4 = tf.nn.max_pool(conv4_3,
ksize = [1, 2, 2, 1],
strides = [1, 2, 2, 1],
padding = 'VALID')
# fifth convolution layer
conv5_1 = tf.nn.conv2d(maxp4,
weights['conv5_1'],
strides = [1, 1, 1, 1],
padding = 'SAME')
conv5_1 = tf.nn.relu(tf.add(conv5_1, biases['conv5_1']))
conv5_2 = tf.nn.conv2d(conv5_1,
weights['conv5_2'],
strides = [1, 1, 1, 1],
padding = 'SAME')
conv5_2 = tf.nn.relu(tf.add(conv5_2, biases['conv5_2']))
conv5_3 = tf.nn.conv2d(conv5_2,
weights['conv5_3'],
strides = [1, 1, 1, 1],
padding = 'SAME')
conv5_3 = tf.nn.relu(tf.add(conv5_3, biases['conv5_3']))
maxp5 = tf.nn.max_pool(conv5_3,
ksize = [1, 2, 2, 1],
strides = [1, 2, 2, 1],
padding = 'VALID')
return {
'conv1_1' : conv1_1,
'conv1_2' : conv1_2,
'conv2_1' : conv2_1,
'conv2_2' : conv2_2,
'conv3_1' : conv3_1,
'conv3_2' : conv3_2,
'conv3_3' : conv3_3,
'conv4_1' : conv4_1,
'conv4_2' : conv4_2,
'conv4_3' : conv4_3,
'conv5_1' : conv5_1,
'conv5_2' : conv5_2,
'conv5_3' : conv5_3
}
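# style is represented by one layer from each convolutional block and content by
# conv4_2, as discussed in Section 1.3; the learning rate is relatively large
# because the optimization variable is the raw image in [0, 255], not network weights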
layers_style = ['conv1_1', 'conv2_1', 'conv3_1', 'conv4_1', 'conv5_1']
layers_content = ['conv4_2']
LR = 30
Image Composition (Generation) as tf.Variable
# composite image is the only variable that needs to be updated
input_gen = tf.Variable(tf.random_uniform([1, h_image, w_image, 3], maxval = 255))
Style Loss and Content Loss
(1) Style loss
$$ G^l_{ij}(I) = \sum_{k} F^l_{ik}(I) \cdot F^l_{jk}(I) $$
or
$$G^{l}(I) = F^{l}(I) \; \left( F^{l}(I) \right)^T$$
(2) Content loss
$$ \mathcal{L}_{\text{content}}(I_g, I_c) = \frac{1}{2} \sum_{i,j} \left( F^l_{ij}(I_g) - F^l_{ij}(I_c) \right)^2 $$
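# get_gram_matrix flattens the spatial dimensions to (H*W, C) and computes F^T F,
# i.e. the C x C Gram matrix above, normalized by (H*W)*C; both loss functions
# use the mean squared difference, which only rescales the losses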
def get_gram_matrix(conv_layer):
channels = conv_layer.get_shape().as_list()[3]
conv_layer = tf.reshape(conv_layer, (-1, channels))
gram_matrix = tf.matmul(tf.transpose(conv_layer), conv_layer)
return gram_matrix/((conv_layer.get_shape().as_list()[0])*channels)
def get_loss_style(gram_matrix_gen, gram_matrix_ref):
loss = tf.reduce_mean(tf.square(gram_matrix_gen - gram_matrix_ref))
return loss
def get_loss_content(gen_layer, ref_layer):
loss = tf.reduce_mean(tf.square(gen_layer - ref_layer))
return loss
features_style = net(input_style, weights, biases)
features_content = net(input_content, weights, biases)
features_gen = net(input_gen, weights, biases)
loss_style = 0
for key in layers_style:
loss_style += get_loss_style(get_gram_matrix(features_gen[key]), get_gram_matrix(features_style[key]))
loss_content = 0
for key in layers_content:
loss_content += get_loss_content(features_gen[key], features_content[key])
# style weight relative to the content term
g = 1/(1e1)
loss_total = loss_content + g*loss_style
optm = tf.train.AdamOptimizer(LR).minimize(loss_total)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
Composite Image
n_iter = 1000
n_prt = 100
for itr in range(n_iter + 1):
sess.run(optm, feed_dict = {input_style: img_style[np.newaxis,:,:,:],
input_content: img_content[np.newaxis,:,:,:]})
if itr%n_prt == 0:
ls = sess.run(loss_style, feed_dict = {input_style: img_style[np.newaxis,:,:,:]})
lc = sess.run(loss_content, feed_dict = {input_content: img_content[np.newaxis,:,:,:]})
print('Iteration: {}'.format(itr))
print('Style loss: {}'.format(g*ls))
print('Content loss: {}\n'.format(lc))
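# fetch the optimized image, clip it to the valid [0, 255] pixel range,
# and drop the batch dimension before display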
image = sess.run(input_gen)
image = np.uint8(np.clip(np.round(image), 0, 255)).squeeze()
plt.figure(figsize = (10,8))
plt.imshow(image)
plt.axis('off')
plt.show()
Style Transfer with Total Variation Loss
Sometimes the composite image we learn contains a lot of high-frequency noise, in particular isolated bright or dark pixels.
One common noise reduction method is total variation denoising.
$$\sum_{i,j} \; \left| x_{i,j+1}- x_{i,j}\right| + \left| x_{i+1,j} - x_{i,j}\right|$$
def get_loss_TV(conv_layer):
loss = tf.reduce_mean(tf.abs(conv_layer[:,:,1:,:] - conv_layer[:,:,:-1,:])) \
+ tf.reduce_mean(tf.abs(conv_layer[:,1:,:,:] - conv_layer[:,:-1,:,:]))
return loss
loss_TV = get_loss_TV(input_gen)
loss_total = loss_content + g*loss_style + 100*loss_TV
optm = tf.train.AdamOptimizer(LR).minimize(loss_total)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
n_iter = 500
n_prt = 100
for itr in range(n_iter + 1):
sess.run(optm, feed_dict = {input_style : img_style[np.newaxis,:,:,:],
input_content : img_content[np.newaxis,:,:,:]})
if itr%n_prt == 0:
ls = sess.run(loss_style, feed_dict = {input_style : img_style[np.newaxis,:,:,:]})
lc = sess.run(loss_content, feed_dict = {input_content : img_content[np.newaxis,:,:,:]})
ltv = sess.run(loss_TV)
print('Iteration: {}'.format(itr))
print('Style loss: {}'.format(g*ls))
print('Content loss: {}'.format(lc))
print('TV loss: {}\n'.format(ltv))
image = sess.run(input_gen)
image = np.uint8(np.clip(np.round(image), 0, 255)).squeeze()
plt.figure(figsize = (10,8))
plt.imshow(image)
plt.axis('off')
plt.show()