Global Average Pooling (GAP)


By Prof. Seungchul Lee
http://iai.postech.ac.kr/
Industrial AI Lab at POSTECH

Table of Contents

  • CNN with a Fully Connected Layer

  • CAM: CNN with Global Average Pooling

  • CAM with MNIST

1. CNN with a Fully Connected Layer

The conventional CNN can be conceptually divided into two parts: feature extraction and classification. In the feature extraction stage, convolution layers extract features from the input data so that classification can be performed well. The classification stage then determines which class each input belongs to, using the features extracted from it.

When we visually identify images, we do not look at the whole image; instead, we intuitively focus on the most important parts. CNN learning is similar to the way humans focus: when the weights are optimized, the more important parts are given higher weights. Generally, however, we cannot observe this, because a generic CNN passes the features extracted by the convolution layers through a fully connected layer, which makes them more abstract.
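
To make the contrast concrete, here is a minimal sketch (TF 1.x style, consistent with the code later in this notebook) of the two classifier heads. The features placeholder and its shape are illustrative assumptions, not part of the model built below.

import tensorflow as tf

# illustrative: last conv feature maps of shape (None, 7, 7, 64)
features = tf.placeholder(tf.float32, [None, 7, 7, 64])

# (a) conventional head: flatten + fully connected, spatial layout is lost
flat = tf.reshape(features, [-1, 7*7*64])
fc_logits = tf.layers.dense(flat, 2)

# (b) GAP head (used for CAM): average each feature map to a single number,
# so each output weight corresponds directly to one feature map
gap = tf.reduce_mean(features, axis = (1, 2))       # (None, 64)
gap_logits = tf.layers.dense(gap, 2, use_bias = False)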



1.1. Issues with CNNs (or Deep Learning)

  • Deep learning performs well compared with other existing algorithms
  • But it works as a black box

    • A classification result is simply returned without knowing how it was derived → little interpretability
  • When we visually identify images, we do not look at the whole image

  • Instead, we intuitively focus on the most important parts of the image
  • When CNN weights are optimized, the more important parts are given higher weights

  • Class activation map (CAM)

    • We can determine which parts of the image the model is focusing on, based on the learned weights
    • It highlights the importance of each image region to the prediction



2. CAM: CNN with Global Average Pooling

  • Sheds light on how global average pooling explicitly enables the convolutional neural network to acquire remarkable localization ability
  • The resulting heatmap is the class activation map, highlighting the importance of each image region to the prediction

A deep learning model is a black box: when input data is received, a classification result of 1 or 0 is simply returned for a binary classification problem, without any indication of how that result was derived. The class activation map (CAM), by contrast, makes the classification result interpretable: by analyzing which parts of the image the model focuses on, we can tell which parts of the image it considers important.

The CAM is produced by a modified convolutional architecture. It directly highlights the important parts of the spatial grid of an image, so we can see which parts the model emphasizes. The figure below describes the procedure for class activation mapping.



The feature maps of the last convolution layer can be interpreted as a collection of visual spatial locations the model focuses on. The CAM is obtained as a weighted linear sum of these feature maps; since each unit has its own weight, the linear combination highlights different spatial locations for different input images. For a given image, $f_k(x,y)$ denotes the activation of unit $k$ in the last convolution layer at spatial location $(x,y)$. For a given class $c$, the class score $S_c$ is expressed as the following equation.


$$S_c = \sum_k \omega_k^c \sum_{x,y} f_k(x,y)= \sum_{x,y} \sum_k \omega_k^c \; f_k(x,y)$$

where $\omega_k^c$ is the weight corresponding to class $c$ for unit $k$. The class activation map for class $c$ is denoted as $M_c$.


$$M_c(x,y) = \sum_k \omega_k^c \; f_k(x,y)$$

$M_c$ directly indicates the importance of the feature map at spatial location $(x,y)$ for class $c$. Finally, the output of the softmax for class $c$ is


$$P_c = \frac{\exp\left(S_c\right)}{\sum_{c'} \exp\left(S_{c'}\right)}$$

In a CNN, the spatial size of the feature map is reduced by the pooling layers. By simply up-sampling the class activation map back to the input size, we can identify the image regions the model attends to for each label.
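
As a quick numerical check of the equations above, here is a minimal NumPy sketch with hypothetical shapes matching the MNIST model below (7×7 feature maps, 64 units, 2 classes); summing the CAM over all locations recovers the class score.

import numpy as np

K, n_classes = 64, 2
f = np.random.rand(7, 7, K)                  # f_k(x, y): last-layer feature maps
w = np.random.rand(K, n_classes)             # w_k^c: output-layer weights

M = f @ w                                    # M_c(x, y), shape (7, 7, n_classes)
S = f.sum(axis = (0, 1)) @ w                 # S_c, shape (n_classes,)
assert np.allclose(M.sum(axis = (0, 1)), S)  # summing the CAM gives the class score

P = np.exp(S) / np.exp(S).sum()              # softmax over class scores

# note: the model below averages (GAP) instead of summing, which only rescales
# S_c by a constant and leaves the ranking of the class scores unchanged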

3. CAM with MNIST

In [1]:
!pip install opencv-python 
In [ ]:
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0"
In [2]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import cv2

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot = True)
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
In [3]:
train_idx = ((np.argmax(mnist.train.labels, 1) == 7) | (np.argmax(mnist.train.labels, 1) == 9))
test_idx = ((np.argmax(mnist.test.labels, 1) == 7) | (np.argmax(mnist.test.labels, 1) == 9))

train_imgs   = mnist.train.images[train_idx]
train_labels = mnist.train.labels[train_idx]
test_imgs    = mnist.test.images[test_idx]
test_labels  = mnist.test.labels[test_idx]

n_train      = train_imgs.shape[0]
n_test       = test_imgs.shape[0]

print ("Packages loaded")
print ("The number of train images : {}, shape : {}".format(n_train, train_imgs.shape))
print ("The number of test images : {}, shape : {}".format(n_test, test_imgs.shape))
Packages loaded
The number of train images : 11169, shape : (11169, 784)
The number of test images : 2037, shape : (2037, 784)
In [4]:
def train_batch_maker(batch_size):
    idx = np.random.randint(n_train, size = batch_size)
    
    labels = []
    for i in idx:
        if np.argmax(train_labels[i,:]) == 7:
            labels.append([1, 0])
        else:
            labels.append([0, 1])    
    
    labels = np.array(labels)
    return train_imgs[idx], labels
In [5]:
def test_batch_maker(batch_size):
    idx = np.random.randint(n_test, size = batch_size)
    
    labels = []
    for i in idx:
        if np.argmax(test_labels[i,:]) == 7:
            labels.append([1, 0])
        else:
            labels.append([0, 1])    
    
    labels = np.array(labels)
    return test_imgs[idx], labels
In [6]:
train_x, train_y = train_batch_maker(1)

plt.figure(figsize = (5,5))
plt.imshow(train_x[0].reshape(28,28), 'gray')
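# note: train_batch_maker encodes digit 7 as [1, 0] and digit 9 as [0, 1],
# so the title shows the class index (0 or 1), not the digit itself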
plt.title("Label : {}".format(np.argmax(train_y[0,:])))
plt.axis('off')
plt.show()
In [7]:
input_h = 28 
input_w = 28 
input_ch = 1 
# (None, 28, 28, 1)

k1_h = 3
k1_w = 3
k1_ch = 32

p1_h = 2
p1_w = 2
# (None, 14, 14 ,32)

k2_h = 3
k2_w = 3
k2_ch = 64

p2_h = 2
p2_w = 2
# (None, 7, 7 ,64)

n_output = 2



In [8]:
weights = {
    'conv1' : tf.Variable(tf.random_normal([k1_h, k1_w, input_ch, k1_ch], stddev = 0.1)),
    'conv2' : tf.Variable(tf.random_normal([k2_h, k2_w, k1_ch, k2_ch], stddev = 0.1)),
    'output' : tf.Variable(tf.random_normal([k2_ch, n_output], stddev = 0.1))
}
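# note: the output layer has no bias, so its (64 x 2) weight matrix maps
# directly to the CAM weights w_k^c in the equations above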

biases = {
    'conv1' : tf.Variable(tf.random_normal([k1_ch], stddev = 0.1)),
    'conv2' : tf.Variable(tf.random_normal([k2_ch], stddev = 0.1))
}

x = tf.placeholder(tf.float32, [None, input_h, input_w, input_ch])
y = tf.placeholder(tf.float32, [None, n_output])
In [9]:
def net(x, weights, biases):
    ## First convolution layer
    conv1 = tf.nn.conv2d(x, weights['conv1'], 
                         strides= [1, 1, 1, 1], 
                         padding = 'SAME')
    conv1 = tf.nn.relu(tf.add(conv1, biases['conv1']))
    maxp1 = tf.nn.max_pool(conv1, 
                           ksize = [1, p1_h, p1_w, 1], 
                           strides = [1, p1_h, p1_w, 1], 
                           padding = 'VALID')
    
    ## Second convolution layer
    conv2 = tf.nn.conv2d(maxp1, weights['conv2'], 
                         strides= [1, 1, 1, 1], 
                         padding = 'SAME')
    conv2 = tf.nn.relu(tf.add(conv2, biases['conv2']))
    maxp2 = tf.nn.max_pool(conv2, 
                           ksize = [1, p2_h, p2_w, 1], 
                           strides = [1, p2_h, p2_w, 1], 
                           padding = 'VALID')

    ## global average pooling
    avg = tf.reduce_mean(maxp2, axis = (1,2))
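    # shape: (None, 7, 7, 64) -> (None, 64), one averaged scalar per feature map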
    output = tf.matmul(avg, weights['output'])

    return maxp2, output
In [10]:
maps, pred = net(x, weights, biases)
loss = tf.nn.softmax_cross_entropy_with_logits(labels = y, logits = pred)
loss = tf.reduce_mean(loss)

LR = 1e-4
optm = tf.train.AdamOptimizer(LR).minimize(loss)
In [11]:
n_batch = 50
n_iter = 5000
n_prt = 250

sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)

loss_record_train = []
for epoch in range(n_iter):
    train_x, train_y = train_batch_maker(n_batch)
    train_x = np.reshape(train_x, [-1, input_h, input_w, input_ch])    
    sess.run(optm, feed_dict = {x: train_x, y: train_y})
    
    if epoch % n_prt == 0:
        c = sess.run(loss, feed_dict = {x:train_x, y: train_y})
        loss_record_train.append(c)
        print("Iter : {}".format(epoch))
        print("Cost : {}".format(c))

plt.figure(figsize = (10,8))
plt.plot(np.arange(len(loss_record_train))*n_prt, loss_record_train, label = 'training')
plt.xlabel('iteration', fontsize = 15)
plt.ylabel('loss', fontsize = 15)
plt.legend(fontsize = 12)
plt.show()        
Iter : 0
Cost : 0.6871762871742249
Iter : 250
Cost : 0.6756047010421753
Iter : 500
Cost : 0.6492440104484558
Iter : 750
Cost : 0.6276745796203613
Iter : 1000
Cost : 0.563922643661499
Iter : 1250
Cost : 0.5437975525856018
Iter : 1500
Cost : 0.42801111936569214
Iter : 1750
Cost : 0.39010998606681824
Iter : 2000
Cost : 0.3676518201828003
Iter : 2250
Cost : 0.25770190358161926
Iter : 2500
Cost : 0.32501572370529175
Iter : 2750
Cost : 0.30564624071121216
Iter : 3000
Cost : 0.30439239740371704
Iter : 3250
Cost : 0.2174631506204605
Iter : 3500
Cost : 0.17130641639232635
Iter : 3750
Cost : 0.2645065188407898
Iter : 4000
Cost : 0.2141740769147873
Iter : 4250
Cost : 0.14054051041603088
Iter : 4500
Cost : 0.1698538064956665
Iter : 4750
Cost : 0.18479442596435547
In [12]:
test_x, test_y = test_batch_maker(100)

test_x = np.reshape(test_x, [-1, input_h, input_w, input_ch])
test_y = np.array(test_y)

my_pred = sess.run(pred, feed_dict = {x: test_x})
my_pred = np.argmax(my_pred, axis = 1)

labels = np.argmax(test_y, axis = 1)

accr = np.mean(np.equal(my_pred, labels))
print("Accuracy : {}%".format(accr*100))
Accuracy : 92.0%
In [13]:
test_x, test_y = test_batch_maker(1)

test_x = np.reshape(test_x, [-1, input_h, input_w, input_ch])
my_maps, my_pred, my_w = sess.run((maps, pred, weights['output']), feed_dict = {x: test_x})
attention = np.matmul(my_maps, my_w)
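# my_maps: (1, 7, 7, 64), my_w: (64, 2) -> attention: (1, 7, 7, 2);
# the next line keeps only the map for the predicted class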
attention = attention[:, :, :, np.argmax(my_pred, axis = 1)].reshape(7,7)

# up-sample the image and the 7x7 CAM (to 140 x 140) for overlay
large_test_x = cv2.resize(test_x.reshape(28,28), (28*5, 28*5))
large_attention = cv2.resize(attention, (28*5, 28*5), interpolation = cv2.INTER_CUBIC)

plt.figure(figsize = (10,15))
plt.subplot(3,2,1)
plt.imshow(test_x.reshape(28,28), 'gray')
plt.axis('off')

plt.subplot(3,2,2)
plt.imshow(attention)
plt.axis('off')

plt.subplot(3,2,3)
plt.imshow(large_test_x, 'gray')
plt.axis('off')

plt.subplot(3,2,4)
plt.imshow(large_attention, 'jet', alpha = 0.5)
plt.axis('off')

plt.subplot(3,2,6)
plt.imshow(large_test_x, 'gray')
plt.imshow(large_attention, 'jet', alpha = 0.5)
plt.axis('off')
plt.show()