Explainable AI using Grad-CAM


By Prof. Hyunseok Oh
https://sddo.gist.ac.kr/
SDDO Lab at GIST

1. Building model for Grad-CAM

1.1 Import Library

In [ ]:
import os
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow import keras
from tensorflow.keras import models
from tensorflow.keras import backend as K
from tensorflow.keras.preprocessing import image
from PIL import Image
import matplotlib.cm as cm
In [ ]:
# os.environ["CUDA_VISIBLE_DEVICES"]="1"
# gpus = tf.config.experimental.list_physical_devices('GPU')
# if gpus:
#     try:
#         tf.config.experimental.set_memory_growth(gpus[0], True)
#     except RuntimeError as e:
#         print(e)

1.2 Load MNIST data set

In [ ]:
mnist = tf.keras.datasets.mnist

(train_x, train_y), (test_x, test_y) = mnist.load_data()

# using 3000 data for test
train_x=train_x[:3000]
train_y=train_y[:3000]

# using 1000 data for test
test_x=test_x[:1000]
test_y=test_y[:1000]

train_x, test_x = train_x/255.0, test_x/255.0

train_x = train_x.reshape((train_x.shape[0], 28, 28, 1))
test_x = test_x.reshape((test_x.shape[0], 28, 28, 1))

n_train = train_x.shape[0]
n_test = test_x.shape[0]

print ("The number of training images : {}, shape : {}".format(n_train, train_x.shape))
print ("The number of testing images : {}, shape : {}".format(n_test, test_x.shape))
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 0s 0us/step
11501568/11490434 [==============================] - 0s 0us/step
The number of training images : 3000, shape : (3000, 28, 28, 1)
The number of testing images : 1000, shape : (1000, 28, 28, 1)
In [ ]:
idx = np.random.randint(n_train)
plt.figure(figsize = (8,6))
plt.imshow(train_x[idx][:,:,0], 'gray')
plt.title('image exsample', fontsize = 20)
plt.axis('off')
plt.show()

1.3 Build a CNN Model

image.png

In [ ]:
model = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(filters = 32,
                          kernel_size = (3, 3),
                          activation = 'relu',
                          padding = 'SAME',
                          input_shape = (28, 28, 1)),
    tf.keras.layers.MaxPool2D((2, 2)),
    
    tf.keras.layers.Conv2D(filters = 64,
                          kernel_size = (3, 3),
                          activation = 'relu',
                          padding = 'SAME',
                          input_shape = (14, 14, 32)),
    
    tf.keras.layers.MaxPool2D((2, 2)),
    
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(units = 32, activation = 'relu'),
    tf.keras.layers.Dense(units = 10, activation = 'softmax')
])

1.4 Define Loss and Optimizer

In [ ]:
model.compile(optimizer = 'adam',
             loss = 'sparse_categorical_crossentropy',
             metrics = 'accuracy')

1.5 Training

In [ ]:
history=model.fit(train_x, train_y, epochs = 10, batch_size = 64)
Epoch 1/10
47/47 [==============================] - 4s 71ms/step - loss: 1.2636 - accuracy: 0.6410
Epoch 2/10
47/47 [==============================] - 3s 71ms/step - loss: 0.3771 - accuracy: 0.8907
Epoch 3/10
47/47 [==============================] - 3s 71ms/step - loss: 0.2509 - accuracy: 0.9297
Epoch 4/10
47/47 [==============================] - 3s 70ms/step - loss: 0.1725 - accuracy: 0.9510
Epoch 5/10
47/47 [==============================] - 3s 70ms/step - loss: 0.1379 - accuracy: 0.9600
Epoch 6/10
47/47 [==============================] - 3s 69ms/step - loss: 0.1124 - accuracy: 0.9657
Epoch 7/10
47/47 [==============================] - 3s 70ms/step - loss: 0.0876 - accuracy: 0.9753
Epoch 8/10
47/47 [==============================] - 3s 70ms/step - loss: 0.0734 - accuracy: 0.9783
Epoch 9/10
47/47 [==============================] - 3s 70ms/step - loss: 0.0547 - accuracy: 0.9840
Epoch 10/10
47/47 [==============================] - 3s 70ms/step - loss: 0.0452 - accuracy: 0.9860

1.6 Test

In [ ]:
model.evaluate(test_x, test_y)
32/32 [==============================] - 1s 11ms/step - loss: 0.1551 - accuracy: 0.9480
Out[ ]:
[0.1551273614168167, 0.9480000138282776]
In [ ]:
import matplotlib.pyplot as plt
loss = history.history["loss"]

epochs = range(1, len(loss) + 1)
plt.plot(epochs, loss, label="Training loss")
plt.title("Training loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.legend()
plt.show()
In [ ]:
plt.clf()
acc = history.history["accuracy"]

plt.plot(epochs, acc, label="Training accuracy")

plt.title("Training accuracy")
plt.xlabel("Epochs")
plt.ylabel("Accuracy")
plt.legend()
plt.show()
In [ ]:
test_idx = [7]
test_img = test_x[test_idx]

predict = model.predict(test_img)
mypred = np.argmax(predict, axis = 1)

plt.figure(figsize = (12, 5))

plt.subplot(1, 2, 1)
plt.imshow(test_img.reshape(28, 28), 'gray')
plt.axis('off')
plt.subplot(1, 2, 2)
plt.stem(predict[0])
plt.show()

print('Prediction : {}'.format(mypred[0]))
/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:13: UserWarning: In Matplotlib 3.3 individual lines on a stem plot will be added as a LineCollection instead of individual lines. This significantly improves the performance of a stem plot. To remove this warning and switch to the new behaviour, set the "use_line_collection" keyword argument to True.
  del sys.path[0]
Prediction : 9

2. Visualization Using Grad-CAM

2.1 Model summary

In [ ]:
model.summary()
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 conv2d (Conv2D)             (None, 28, 28, 32)        320       
                                                                 
 max_pooling2d (MaxPooling2D  (None, 14, 14, 32)       0         
 )                                                               
                                                                 
 conv2d_1 (Conv2D)           (None, 14, 14, 64)        18496     
                                                                 
 max_pooling2d_1 (MaxPooling  (None, 7, 7, 64)         0         
 2D)                                                             
                                                                 
 flatten (Flatten)           (None, 3136)              0         
                                                                 
 dense (Dense)               (None, 32)                100384    
                                                                 
 dense_1 (Dense)             (None, 10)                330       
                                                                 
=================================================================
Total params: 119,530
Trainable params: 119,530
Non-trainable params: 0
_________________________________________________________________

2.2 Implementing Grad-CAM

image.png

In [ ]:
def make_gradcam_heatmap(img_array, model, conv_layer_name, pred_index=None):
    # 입력에 따른 피처 맵 모델 생성
    grad_model = tf.keras.models.Model(
        [model.inputs], [model.get_layer(conv_layer_name).output, model.output]
    )
    #출력에 따른 Gradient 계산 함수
    with tf.GradientTape() as tape:
        conv_layer_output, preds = grad_model(img_array)
        if pred_index is None:
            pred_index = tf.argmax(preds[0])
        class_channel = preds[:, pred_index]

    # 예측된 값에 대한 Gradient 호출
    grads = tape.gradient(class_channel, conv_layer_output)

    # 각 피쳐 맵에 대한 가중치 산출
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))

    # 계산된 가중치와 해당 레이어를 곱하여 얼마나 중요도를 가지는지 히트맵 계산
    conv_layer_output = conv_layer_output[0]
    heatmap = conv_layer_output @ pooled_grads[..., tf.newaxis]
    heatmap = tf.squeeze(heatmap)

    # 시각화를 위해 0-1 범위로 정규화 및 ReLU 수행
    heatmap = tf.maximum(heatmap, 0) / tf.math.reduce_max(heatmap)
    return heatmap.numpy()

2.3 Test

In [ ]:
test_img = test_x[[123]]
layer_name = 'conv2d'
heatmap = make_gradcam_heatmap(test_x[[12]], model, layer_name)
plt.matshow(heatmap)
plt.show()

2.4 Joint image and heatmap

In [ ]:
def save_and_display_gradcam(img_input, heatmap, cam_path = "cam.jpg",  alpha=0.01):
    # Load the original image
    img =img_input.reshape(28,28)
    img = keras.preprocessing.image.img_to_array(img)

    # 정규화된 Heatmap 이미지를 0-255 범위로 변환
    heatmap = np.uint8(255 * heatmap)
    jet = cm.get_cmap("jet")

    # 계산된 값을 RGB 값으로 변경 및 이미지 변환
    jet_colors = jet(np.arange(256))[:, :3]
    jet_heatmap = jet_colors[heatmap]
    jet_heatmap = keras.preprocessing.image.array_to_img(jet_heatmap)
    jet_heatmap = jet_heatmap.resize((img.shape[1], img.shape[0]))
    jet_heatmap = keras.preprocessing.image.img_to_array(jet_heatmap)

    # 계산한 Heatmap과 입력 이미지 결합
    # alpha 값을 통해 입력 이미지의 투명도 계산
    superimposed_img = jet_heatmap * alpha + img
    superimposed_img = keras.preprocessing.image.array_to_img(superimposed_img)

    # 생성된 이미지 저장
    superimposed_img.save(cam_path)

    # 가시화
    plt.imshow(superimposed_img)

2.5 Result of Jointing image and heatmap

In [ ]:
layer_name = 'conv2d'
layer_output = model.get_layer(layer_name).output
experiment = test_x[[1]]
In [ ]:
plt.imshow(experiment.reshape(28,28),'gray')
plt.show()
In [ ]:
make_heatmap = make_gradcam_heatmap(experiment, model, layer_name)
save_and_display_gradcam(experiment, make_heatmap)

2.6 Compare similar numbers

image.png

In [ ]:
## 1, 7, 9를 비교하여 Grad CAM이 어느 부분을 통해 판별하는지 점검 수행

list_1 = []
list_7 = []
list_9 = []

for i in range(1000):
    if test_y[i] == 1:
        list_1.append(i)
    if test_y[i] == 7:
        list_7.append(i)
    if test_y[i] == 9:
        list_9.append(i)
In [ ]:
test_idx_1 = [list_1[np.random.randint(0, 90)]]
test_image_1 = test_x[test_idx_1]

test_idx_2 = [list_7[np.random.randint(0, 90)]]
test_image_2 = test_x[test_idx_2]

test_idx_3 = [list_9[np.random.randint(0, 90)]]
test_image_3 = test_x[test_idx_3]
In [ ]:
make_heatmap = make_gradcam_heatmap(test_image_1, model, layer_name)
save_and_display_gradcam(test_image_1, make_heatmap)
In [ ]:
make_heatmap = make_gradcam_heatmap(test_image_2, model, layer_name)
save_and_display_gradcam(test_image_2, make_heatmap)
In [ ]:
make_heatmap = make_gradcam_heatmap(test_image_3, model, layer_name)
save_and_display_gradcam(test_image_3, make_heatmap)

3. Grad-CAM Using trained model(VGG16)

3.1 Import Library

In [ ]:
import numpy as np
import tensorflow as tf
from tensorflow import keras

# Display
from IPython.display import Image, display
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from tensorflow.keras.preprocessing.image    import img_to_array, load_img
from tensorflow.keras.applications.vgg16     import (preprocess_input, decode_predictions)

3.2 Import VGG 16 model and image data set

In [ ]:
!wget https://url.kr/argi5c -O 'images.zip'
!unzip images.zip -d './images'
--2022-01-18 04:58:33--  https://url.kr/argi5c
Resolving url.kr (url.kr)... 183.111.169.122
Connecting to url.kr (url.kr)|183.111.169.122|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://url.kr/spam_filtering_system.php?short=argi5c [following]
--2022-01-18 04:58:34--  https://url.kr/spam_filtering_system.php?short=argi5c
Reusing existing connection to url.kr:443.
HTTP request sent, awaiting response... 302 Found
Location: https://github.com/MinseokOff/KSME_advanced/raw/main/LRP/images.zip [following]
--2022-01-18 04:58:34--  https://github.com/MinseokOff/KSME_advanced/raw/main/LRP/images.zip
Resolving github.com (github.com)... 192.30.255.112
Connecting to github.com (github.com)|192.30.255.112|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/MinseokOff/KSME_advanced/main/LRP/images.zip [following]
--2022-01-18 04:58:34--  https://raw.githubusercontent.com/MinseokOff/KSME_advanced/main/LRP/images.zip
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.110.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 327023 (319K) [application/zip]
Saving to: ‘images.zip’

images.zip          100%[===================>] 319.36K  --.-KB/s    in 0.02s   

2022-01-18 04:58:34 (13.4 MB/s) - ‘images.zip’ saved [327023/327023]

Archive:  images.zip
replace ./images/cat.1.jpg? [y]es, [n]o, [A]ll, [N]one, [r]ename: y
  inflating: ./images/cat.1.jpg      
replace ./images/cat.2.jpg? [y]es, [n]o, [A]ll, [N]one, [r]ename: y
  inflating: ./images/cat.2.jpg      
replace ./images/cat.3.jpg? [y]es, [n]o, [A]ll, [N]one, [r]ename: y
  inflating: ./images/cat.3.jpg      
replace ./images/cat.4.jpg? [y]es, [n]o, [A]ll, [N]one, [r]ename: y
  inflating: ./images/cat.4.jpg      
replace ./images/cat.5.jpg? [y]es, [n]o, [A]ll, [N]one, [r]ename: y
  inflating: ./images/cat.5.jpg      
replace ./images/cat.6.jpg? [y]es, [n]o, [A]ll, [N]one, [r]ename: y
  inflating: ./images/cat.6.jpg      
replace ./images/cat.7.jpg? [y]es, [n]o, [A]ll, [N]one, [r]ename: y
  inflating: ./images/cat.7.jpg      
replace ./images/dog.0.jpg? [y]es, [n]o, [A]ll, [N]one, [r]ename: y
  inflating: ./images/dog.0.jpg      
replace ./images/dog.1.jpg? [y]es, [n]o, [A]ll, [N]one, [r]ename: y
  inflating: ./images/dog.1.jpg      
replace ./images/dog.2.jpg? [y]es, [n]o, [A]ll, [N]one, [r]ename: y
  inflating: ./images/dog.2.jpg      
replace ./images/dog.3.jpg? [y]es, [n]o, [A]ll, [N]one, [r]ename: y
  inflating: ./images/dog.3.jpg      
replace ./images/dog.4.jpg? [y]es, [n]o, [A]ll, [N]one, [r]ename: y
  inflating: ./images/dog.4.jpg      
replace ./images/dog.5.jpg? [y]es, [n]o, [A]ll, [N]one, [r]ename: y
  inflating: ./images/dog.5.jpg      
replace ./images/dog.6.jpg? [y]es, [n]o, [A]ll, [N]one, [r]ename: y
  inflating: ./images/dog.6.jpg      
replace ./images/dog.7.jpg? [y]es, [n]o, [A]ll, [N]one, [r]ename: y
  inflating: ./images/dog.7.jpg      
In [ ]:
model_builder = tf.keras.applications.vgg16.VGG16

preprocess_input = keras.applications.vgg16.preprocess_input
decode_predictions = keras.applications.vgg16.decode_predictions

# 이미지 경로
# image_ = load_img('data/cat.1.jpg', target_size=(224, 224))
image_= load_img('./images/cat.1.jpg', target_size=(224, 224))
plt.figure(figsize=(10,10))
plt.imshow(image_)
Out[ ]:
<matplotlib.image.AxesImage at 0x7f43bb13e210>