XAI: eXplainable AI


By Prof. Seungchul Lee
http://iailab.kaist.ac.kr/
Industrial AI Lab at KAIST

Table of Contents


0. Lecture Video

In [ ]:
from IPython.display import YouTubeVideo
YouTubeVideo('TuBdJYH22AA', width = "560", height = "315")
Out[ ]:

1. eXplainable AI (XAI)

Feature Importance

In regression, feature importance reflects how much each predictor (input variable) contributes to the model's output. By identifying the most impactful features, you can interpret the model, simplify it by removing unimportant features, and improve its performance.

Consider the example of a process engineer working on a manufacturing production line. It is essential to understand not only how the current production settings affect product quality but also which control parameters need adjustment when production targets are not met. In this context, having an accurate predictive model is important, but equally critical is understanding which features influence the model’s decisions and to what extent.

Feature importance provides insights into how individual features impact the model's output. However, for complex systems and models, a more comprehensive approach is needed. This leads to the broader concept of Explainable AI (XAI), which extends feature importance by providing deeper, interpretable insights into model behavior, ensuring transparency and trust in the decision-making process.
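As a quick illustration, the sketch below estimates feature importance with scikit-learn's permutation importance on synthetic regression data; the data, model, and feature names are illustrative assumptions rather than part of this lecture's case study.

In [ ]:
# A minimal sketch: permutation importance shuffles one feature at a time
# and measures how much the model's score degrades
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.inspection import permutation_importance

rng = np.random.default_rng(0)
X = rng.normal(size = (500, 3))     # three hypothetical process settings
y = 2.0*X[:, 0] + 0.5*X[:, 1] + rng.normal(scale = 0.1, size = 500)   # x2 is irrelevant

reg = RandomForestRegressor(random_state = 0).fit(X, y)

result = permutation_importance(reg, X, y, n_repeats = 10, random_state = 0)
for name, imp in zip(['x0', 'x1', 'x2'], result.importances_mean):
    print('{} : {:.3f}'.format(name, imp))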


Introduction to Explainable AI (XAI)

Explainable Artificial Intelligence (XAI) refers to a set of methods and techniques that make the decision-making processes of AI systems transparent, interpretable, and understandable to humans. As AI systems become more complex and influential in critical domains such as healthcare, finance, and law enforcement, it is crucial to ensure that their predictions and recommendations are comprehensible and trustworthy.


Key Goals of Explainable AI (XAI)

  • Transparency: Providing insights into how and why an AI model makes specific predictions.
  • Interpretability: Making the behavior of the AI system interpretable by stakeholders (users, regulators, developers).
  • Accountability: Enabling the identification of biases, errors, and vulnerabilities in AI systems.
  • Trust and Adoption: Building user confidence by demonstrating that the model's decisions are logical and ethical.

Importance of Explainable AI (XAI)

  • Ethical and Fair AI: Helps detect and mitigate biases in model predictions.
  • Debugging and Improving Models: Enables researchers and developers to understand model weaknesses and improve performance.
  • Regulatory Compliance: Some regulations, such as GDPR, require explanations for decisions made by AI, especially in automated decision-making systems.
  • User Confidence: XAI can enhance user trust, particularly in sensitive domains like healthcare, where understanding a diagnosis is essential.

Methods for Explainable AI

(1) Post-Hoc Explanation Methods: Techniques that explain decisions after the model has made predictions.

  • SHAP (Shapley Additive Explanations): Assigns importance values to input features based on their contribution to the prediction.
  • LIME (Local Interpretable Model-Agnostic Explanations): Builds a local interpretable model around each prediction to approximate the decision boundary.
  • Saliency Maps: Highlight areas of an input image that are most relevant for the model's prediction.

(2) Intrinsically Interpretable Models: Models that are inherently interpretable due to their design.

  • Linear Regression and Decision Trees: Provide clear insights into the influence of input features.
  • Rule-Based Systems: Make predictions using if-then-else rules, which are easy to interpret.

Having introduced the key categories of explainable AI methods, we will now focus on post-hoc explanation techniques in greater detail. These methods play a crucial role in interpreting complex, black-box models by providing insights into how predictions are made after the model has processed the input. By examining popular approaches such as SHAP, LIME, and saliency maps, we can better understand how these tools enhance transparency and support informed decision-making.


Note:

"Post-Hoc" indicates that interpretability techniques are applied after the model has been constructed.



2. SHAP (Shapley Additive Explanations)

SHAP (Shapley Additive Explanations) is an explainability framework that interprets the predictions of machine learning models by attributing importance values (SHAP values) to each input feature. These values indicate how much each feature contributes to pushing the model’s prediction higher or lower compared to a baseline. SHAP values are based on Shapley values from cooperative game theory, which ensures fairness and consistency in the distribution of contributions among features.

SHAP was introduced by Scott M. Lundberg and Su-In Lee in their 2017 paper titled "A Unified Approach to Interpreting Model Predictions." This work presents a method to interpret complex machine learning models by assigning each feature an importance value for a particular prediction, grounded in cooperative game theory's Shapley values. The paper has been highly influential in the field of explainable AI, providing a consistent framework for understanding model outputs across various machine learning algorithms.


Key Idea of SHAP

The core idea of SHAP is to treat the features of a model as "players" in a cooperative game and the model's prediction as the "payout." The SHAP value for each feature represents its "fair share" of the prediction based on its contribution.


Mathematical Formulation (optional)

The SHAP value $ \phi_i $ for a feature $ i $ is given by:


$$ \phi_i = \sum_{S \subseteq F \setminus \{i\}} \frac{|S|! (|F| - |S| - 1)!}{|F|!} \left( f(S \cup \{i\}) - f(S) \right) $$

where:

  • $ F $ is the set of all features.
  • $ S $ is a subset of features excluding $ i $.
  • $ f(S) $ is the model's prediction when using only the features in $ S $.
  • $ f(S \cup \{i\}) - f(S) $ is the marginal contribution of feature $ i $ to the prediction.

This formula sums over all possible subsets of features to compute the marginal contribution of each feature.
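For a small number of features, the formula can be evaluated directly. The sketch below is a brute-force illustration (not the optimized algorithms used in practice); here $f(S)$ is approximated by replacing the features outside $S$ with baseline values, which is one common convention and an assumption on our part.

In [ ]:
# Brute-force Shapley values for a single instance (illustrative sketch only;
# exponential in the number of features, so feasible for small d)
import numpy as np
from itertools import combinations
from math import factorial

def shap_bruteforce(model_fn, x, baseline):
    d = len(x)
    phi = np.zeros(d)
    for i in range(d):
        others = [j for j in range(d) if j != i]
        for size in range(len(others) + 1):
            for S in combinations(others, size):
                # weight |S|! (|F| - |S| - 1)! / |F|!
                w = factorial(len(S))*factorial(d - len(S) - 1)/factorial(d)
                x_S = baseline.copy()           # f(S): features not in S stay at baseline
                x_S[list(S)] = x[list(S)]
                x_Si = x_S.copy()               # f(S ∪ {i})
                x_Si[i] = x[i]
                phi[i] += w*(model_fn(x_Si) - model_fn(x_S))
    return phi

# for a linear model with a zero baseline, phi recovers w_i * x_i
w = np.array([3.0, -1.0, 0.5])
print(shap_bruteforce(lambda x: float(w @ x),
                      np.array([1.0, 2.0, 3.0]),
                      np.zeros(3)))    # ≈ [3.0, -2.0, 1.5]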


Advantages of SHAP

  • Provides both local explanations (for individual predictions) and global explanations (aggregated insights across the entire dataset).
  • Ensures consistency and fairness in feature importance attribution.
  • Model-agnostic (works for any model) and also has specialized optimizations for tree-based and deep learning models.

SHAP is a powerful and principled framework for interpreting machine learning model predictions. By assigning fair and consistent contributions to each feature, SHAP provides both local and global interpretability. Although computationally intensive, it has become a widely-used method for ensuring transparency and trust in AI systems across a variety of domains.


We will not provide an in-depth discussion of SHAP, as it is already comprehensively implemented in widely available Python libraries. It is generally acceptable for researchers to leverage SHAP directly in their studies to enhance model interpretability and perform feature analysis.
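For reference, a minimal usage sketch of the shap package is shown below, assuming `pip install shap`, a tree-based regressor, and toy data of our own; consult the library documentation for the full API.

In [ ]:
import numpy as np
import shap
from sklearn.ensemble import RandomForestRegressor

# toy data (illustrative)
rng = np.random.default_rng(0)
X = rng.normal(size = (200, 3))
y = 2.0*X[:, 0] + 0.5*X[:, 1] + rng.normal(scale = 0.1, size = 200)

reg = RandomForestRegressor(random_state = 0).fit(X, y)

explainer = shap.TreeExplainer(reg)      # optimized explainer for tree ensembles
shap_values = explainer.shap_values(X)   # one value per sample and feature

shap.summary_plot(shap_values, X)        # global view built from local explanations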



3. LIME (Local Interpretable Model-Agnostic Explanations)

LIME stands for Local Interpretable Model-Agnostic Explanations. It is a technique used to explain individual predictions made by any machine learning model by approximating the complex model locally with a simpler, interpretable model (e.g., a linear regression model or a decision tree). LIME was introduced by Marco Tulio Ribeiro, Sameer Singh, and Carlos Guestrin in 2016.


LIME works on the principle that while the overall behavior of a machine learning model might be complex and nonlinear, the model's predictions can often be explained locally (i.e., near the instance being explained) by a simpler, linear model.


How LIME Works

(1) Perturbation:

  • LIME creates slightly modified versions of the input (perturbed samples) by randomly changing the feature values.
  • These perturbed samples are then passed through the model to get predictions.

(2) Simple Model Fitting:

  • A simple interpretable model (like linear regression) is fit to approximate the relationship between the features and the model’s output in the local neighborhood of the instance.

(3) Feature Importance:

  • The coefficients of the simple model indicate the importance of features in making the prediction for the given instance.
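The three steps above can be condensed into a few lines of code. The following is an illustrative sketch of the mechanism (not the official lime package); the Gaussian perturbation scale and the proximity kernel are our own simplifying assumptions.

In [ ]:
import numpy as np
from sklearn.linear_model import Ridge

def lime_explain(predict_fn, x, n_samples = 1000, scale = 0.3, kernel_width = 0.75):
    d = len(x)
    rng = np.random.default_rng(0)

    # (1) Perturbation: sample points in the neighborhood of x
    Z = x + rng.normal(scale = scale, size = (n_samples, d))
    y = predict_fn(Z)

    # weight samples by proximity to x (exponential kernel)
    dist = np.linalg.norm(Z - x, axis = 1)
    weights = np.exp(-dist**2 / kernel_width**2)

    # (2) Simple model fitting: weighted linear regression on the neighborhood
    surrogate = Ridge(alpha = 1.0).fit(Z, y, sample_weight = weights)

    # (3) Feature importance: the local linear coefficients
    return surrogate.coef_

# nonlinear black-box example: local slopes ≈ [3*cos(0.3), 2*0.5]
black_box = lambda Z: np.sin(3*Z[:, 0]) + Z[:, 1]**2
print(lime_explain(black_box, np.array([0.1, 0.5])))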

Strengths of LIME

  • Model-agnostic: Works with any machine learning model (e.g., neural networks, random forests, etc.).
  • Local interpretability: Provides instance-specific explanations, which are useful for understanding individual predictions.
  • Flexibility: Can be used for different types of data (e.g., tabular, image, text).

Limitations of LIME

  • Instability: The generated explanation can vary with different perturbations.
  • Computational Cost: Perturbing inputs and fitting local models can be computationally expensive.
  • Assumption of Linearity: Assumes that the local decision boundary is linear, which may not always hold for highly non-linear models.

LIME is a powerful technique for explaining individual model predictions by approximating the model locally with a simpler, interpretable model. Its flexibility and model-agnostic nature make it a widely-used tool for ensuring transparency in machine learning, especially in domains where individual predictions need to be explained.


Again, we will not provide an in-depth discussion of LIME, as it is already comprehensively implemented in widely available Python libraries. It is generally acceptable for researchers to leverage LIME directly in their studies to enhance model interpretability and perform feature analysis.
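As a pointer, a typical call pattern of the lime package for tabular data looks like the sketch below; `X_train`, `feature_names`, `model`, and `X_test` are placeholders for your own data and fitted model, and `pip install lime` is assumed.

In [ ]:
from lime.lime_tabular import LimeTabularExplainer

explainer = LimeTabularExplainer(X_train,
                                 feature_names = feature_names,
                                 mode = 'regression')

# explain a single test instance; model.predict is the black-box function
exp = explainer.explain_instance(X_test[0], model.predict, num_features = 5)
print(exp.as_list())    # (feature, local weight) pairs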


Side Note: LIME and Sensitivity Analysis

LIME (Local Interpretable Model-Agnostic Explanations) and sensitivity analysis are both approaches used to understand the behavior of machine learning models and how changes in input features influence the output. While they share some conceptual similarities, they differ significantly in their methodology and focus.


(1) Understanding Sensitivity Analysis:

Sensitivity analysis measures how sensitive a model's output is to variations in its input features. It quantifies the importance of each input feature by systematically varying it and measuring the corresponding change in the output, and it is widely used to assess the robustness, stability, and interpretability of machine learning models. Sensitivity analysis can be:

  • Global: Assessing the influence of input features over the entire input space.
  • Local: Assessing the influence of input features at a particular point (similar to what LIME does).

(2) LIME as a Localized Sensitivity Analysis:

LIME performs a form of local sensitivity analysis by perturbing the input features around a specific instance and observing how the model's output changes. It builds a local interpretable model (e.g., linear regression) that fits the perturbed data, revealing the relative importance of each feature for the given prediction.

  • Perturbation Strategy: LIME perturbs each feature independently and generates multiple "nearby" samples.
  • Feature Importance: The importance of each feature is derived from the coefficients of the simple model, indicating how much the feature influences the output in the local neighborhood.
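For comparison, a bare-bones local sensitivity analysis can be written as a finite-difference gradient estimate, as in the illustrative sketch below (the step size and test function are assumptions).

In [ ]:
import numpy as np

def local_sensitivity(predict_fn, x, eps = 1e-4):
    # perturb one feature at a time and record the change in the output
    sens = np.zeros(len(x))
    for i in range(len(x)):
        x_plus, x_minus = x.copy(), x.copy()
        x_plus[i] += eps
        x_minus[i] -= eps
        sens[i] = (predict_fn(x_plus) - predict_fn(x_minus)) / (2*eps)
    return sens

f = lambda x: np.sin(3*x[0]) + x[1]**2
print(local_sensitivity(f, np.array([0.1, 0.5])))   # ≈ [3*cos(0.3), 1.0]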

4. Saliency Maps in Explainable AI

Having discussed SHAP and LIME, which are commonly used for tabular and textual data, we now shift our focus to explainable AI methods designed specifically for image-based models.

When we visually identify an object in an image, we do not examine the entire image; rather, we intuitively focus on the most important parts. A CNN mimics this behavior by assigning higher weights to the most relevant regions during optimization. However, a CNN typically cannot point to these important regions explicitly, as the features extracted by the convolutional layers become increasingly abstract and lose their spatial grounding after passing through the fully connected layers.


In this context, saliency maps play a crucial role by visually illustrating which regions of an image are most influential in the model's decision-making process. A saliency map is a visual explanation technique used to interpret deep learning models, especially in image classification tasks: it highlights the most "salient" (important) pixels or regions of the input image that the model relies on to make its prediction. The brighter a pixel appears in the saliency map, the more it contributed to the model's decision.


Key Concept of Saliency Maps

Saliency maps are generated by computing the gradient of the model's output with respect to the input image. The intuition is that the magnitude of the gradient at a specific pixel indicates how sensitive the model's prediction is to changes in the value of that pixel.


Mathematical Formulation

Given:

  • $ f(x) $ is the model's output for an input image $ x $.
  • The class label of interest is $ c $.

The saliency map $ S(x) $ for class $ c $ is computed as:


$$ S(x) = \left| \frac{\partial f_c(x)}{\partial x} \right| $$

where:

  • $ f_c(x) $ represents the model's score for class $ c $.
  • $ \frac{\partial f_c(x)}{\partial x} $ is the gradient of the output with respect to the input image $ x $.
  • The absolute value is taken to focus on the magnitude of sensitivity, ignoring the direction.

Steps for Generating a Saliency Map

  • Forward Pass: Pass the input image through the model to obtain the class probabilities.
  • Compute Gradients: Compute the gradient of the class score with respect to the input image.
  • Visualize the Magnitudes: Visualize the absolute value of the gradient as a heatmap over the input image.
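These steps translate almost directly into TensorFlow. The sketch below is a minimal vanilla-gradient saliency implementation, assuming `model` is a trained differentiable classifier and `image` is a single input array (both assumptions for illustration).

In [ ]:
import tensorflow as tf

def saliency_map(model, image, class_idx = None):
    image = tf.convert_to_tensor(image[tf.newaxis, ...], dtype = tf.float32)

    with tf.GradientTape() as tape:
        tape.watch(image)                   # track gradients w.r.t. the input
        preds = model(image)                # forward pass
        if class_idx is None:
            class_idx = tf.argmax(preds[0]) # explain the predicted class
        score = preds[:, class_idx]

    grads = tape.gradient(score, image)     # ∂f_c/∂x
    return tf.reduce_max(tf.abs(grads), axis = -1)[0]   # per-pixel magnitude

# usage sketch: plt.imshow(saliency_map(model, test_x[7]), 'jet')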

Advantages of Saliency Maps

  • Visual Interpretability: Provides an intuitive, visual understanding of the model's reasoning.
  • Model-Agnostic: Saliency maps can be computed for any differentiable model.
  • Instance-Specific: Explains individual predictions, making them useful for debugging specific inputs.

Several methods have been proposed to generate saliency maps in deep learning-based models, especially for image-based tasks. These methods aim to highlight important regions or features in the input that influence the model's output. Below are some of the most widely used saliency map generation methods:

  • Vanilla Gradient Saliency Maps
  • SmoothGrad (Smoothed Gradient)
  • Integrated Gradients
  • Guided Backpropagation
  • Grad-CAM (Gradient-weighted Class Activation Mapping)
  • Guided Grad-CAM
  • DeepLIFT (Deep Learning Important Features)

Among the various saliency map methods, we will focus on studying Grad-CAM due to its ability to provide class-specific explanations and highlight spatial regions relevant to model predictions.


4.1. Grad-CAM: Gradient-weighted Class Activation Maps

Grad-CAM (Gradient-weighted Class Activation Mapping) is a powerful visualization tool in explainable AI used to understand the decisions of convolutional neural networks (CNNs). Grad-CAM produces heatmaps that highlight the regions in the input image that were most relevant for the model's prediction of a specific class.

Grad-CAM uses the gradients of the class score with respect to the feature maps of a convolutional layer to identify the important regions of the image.


Mathematical Formulation

The class activation map for class $c$ is computed as:


$$ L_{\text{Grad-CAM}}^c(x) = \text{ReLU} \left( \sum_k \alpha_k^c A_k(x) \right) $$

where:

  • $A_k(x)$ is the $k$-th feature map of the convolutional layer for input $x$.

  • $\alpha_k^c$ is the weight for the $k$-th feature map corresponding to class $c$, computed as:


    $$ \alpha_k^c = \frac{1}{Z} \sum_i \sum_j \frac{\partial f_c(x)}{\partial A_k^{ij}(x)} $$

    where $Z$ is the number of pixels in the feature map.


The ReLU function ensures that only positive gradients (positive contributions to the class score) are visualized.




Step-by-Step Process for Grad-CAM


Step 1: Input the Image

  • Provide the image you want to analyze.

Step 2: Forward Pass through the trained CNN

  • Pass the input image through the trained CNN to get the output prediction (class scores).
  • Identify the predicted class (or a target class of interest).

Step 3: Identify the Last Convolutional Layer

  • Select the last convolutional layer of the CNN, as it retains spatial information while capturing high-level features.
  • This layer will be used to generate the saliency map.

Step 4: Compute the Gradients

  • Perform a backward pass to compute the gradients of the score for the target class with respect to the feature maps of the selected convolutional layer.


    $$ \frac{\partial y^c}{\partial A_k} $$
    where $ y^c $ is the score for class $ c $, and $ A_k $ is the $ k $-th feature map.

Step 5: Compute the Global Average of the Gradients

  • Compute the global average of the gradients for each feature map across the spatial dimensions:


    $$ \alpha_k^c = \frac{1}{Z} \sum_{i} \sum_{j} \frac{\partial y^c}{\partial A_k^{ij}} $$
    where $ Z $ is the number of pixels in the feature map.

Step 6: Weight the Feature Maps

  • Multiply each feature map by its corresponding weight $ \alpha_k^c $ (the average gradient):


    $$ L_{\text{Grad-CAM}}^c = \sum_{k} \alpha_k^c A_k $$
    This forms the weighted combination of the feature maps.

Step 7: Apply ReLU Activation

  • Apply a ReLU function to the weighted sum to retain only the positive contributions:


    $$ L_{\text{Grad-CAM}}^c = \text{ReLU} \left( \sum_{k} \alpha_k^c A_k \right) $$
    This ensures that only features positively influencing the target class are considered.

Step 8: Resize the Grad-CAM Map

  • The generated saliency map is upsampled to the same size as the original input image using interpolation.

Step 9: Overlay the Grad-CAM Map on the Input Image

  • Overlay the Grad-CAM heatmap on the input image using a colormap (e.g., "jet" colormap).
  • This visualization highlights the regions of the image that contributed most to the model's decision.

Summary:

Grad-CAM creates a visual explanation for CNN predictions by computing the gradient of the class score relative to the feature maps of the last convolutional layer. By weighting and combining these feature maps, it highlights the regions most relevant to the prediction, offering an intuitive way to interpret CNN decisions.



4.2. Lab: Grad-CAM with NEU

We will explore the implementation of Grad-CAM using the NEU case.


Download NEU steel surface defect images and labels


First, build the CNN model


In [ ]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import cv2
In [ ]:
from google.colab import drive
drive.mount('/content/drive')
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
In [ ]:
# Change file paths if necessary

train_x = np.load('/content/drive/MyDrive/DL/DL_data/NEU_train_imgs.npy')
train_y = np.load('/content/drive/MyDrive/DL/DL_data/NEU_train_labels.npy')

test_x = np.load('/content/drive/MyDrive/DL/DL_data/NEU_test_imgs.npy')
test_y = np.load('/content/drive/MyDrive/DL/DL_data/NEU_test_labels.npy')

n_train = train_x.shape[0]
n_test = test_x.shape[0]

print ("The number of training images : {}, shape : {}".format(n_train, train_x.shape))
print ("The number of testing images : {}, shape : {}".format(n_test, test_x.shape))
The number of training images : 1500, shape : (1500, 200, 200, 1)
The number of testing images : 300, shape : (300, 200, 200, 1)
In [ ]:
model = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(filters = 32,
                           kernel_size = (3,3),
                           activation = 'relu',
                           padding = 'SAME',
                           input_shape = (200, 200, 1)),

    tf.keras.layers.MaxPool2D((2,2)),

    # input here: (100, 100, 32); only the first layer needs input_shape
    tf.keras.layers.Conv2D(filters = 64,
                           kernel_size = (3,3),
                           activation = 'relu',
                           padding = 'SAME'),

    tf.keras.layers.MaxPool2D((2,2)),

    # input here: (50, 50, 64)
    tf.keras.layers.Conv2D(filters = 64,
                           kernel_size = (3,3),
                           activation = 'relu',
                           padding = 'SAME'),

    tf.keras.layers.MaxPool2D((2,2)),

    # input here: (25, 25, 64)
    tf.keras.layers.Conv2D(filters = 64,
                           kernel_size = (3,3),
                           activation = 'relu',
                           padding = 'SAME'),

    tf.keras.layers.MaxPool2D((2,2)),

    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(units = 64, activation = 'relu'),
    tf.keras.layers.Dense(units = 6, activation = 'softmax')
])
In [ ]:
model.compile(optimizer = 'adam',
              loss = 'sparse_categorical_crossentropy',
              metrics = ['accuracy'])
In [ ]:
model.fit(train_x, train_y, epochs = 15)
Epoch 1/15
47/47 [==============================] - 4s 33ms/step - loss: 1.7092 - accuracy: 0.2200
Epoch 2/15
47/47 [==============================] - 2s 32ms/step - loss: 1.1546 - accuracy: 0.5287
Epoch 3/15
47/47 [==============================] - 2s 34ms/step - loss: 0.7600 - accuracy: 0.7233
Epoch 4/15
47/47 [==============================] - 2s 33ms/step - loss: 0.5364 - accuracy: 0.8067
Epoch 5/15
47/47 [==============================] - 2s 33ms/step - loss: 0.2964 - accuracy: 0.9007
Epoch 6/15
47/47 [==============================] - 2s 32ms/step - loss: 0.2367 - accuracy: 0.9127
Epoch 7/15
47/47 [==============================] - 2s 32ms/step - loss: 0.1665 - accuracy: 0.9420
Epoch 8/15
47/47 [==============================] - 1s 32ms/step - loss: 0.1637 - accuracy: 0.9487
Epoch 9/15
47/47 [==============================] - 2s 32ms/step - loss: 0.2170 - accuracy: 0.9140
Epoch 10/15
47/47 [==============================] - 2s 32ms/step - loss: 0.1179 - accuracy: 0.9607
Epoch 11/15
47/47 [==============================] - 2s 32ms/step - loss: 0.1644 - accuracy: 0.9387
Epoch 12/15
47/47 [==============================] - 2s 33ms/step - loss: 0.1116 - accuracy: 0.9653
Epoch 13/15
47/47 [==============================] - 2s 34ms/step - loss: 0.1251 - accuracy: 0.9593
Epoch 14/15
47/47 [==============================] - 2s 33ms/step - loss: 0.4703 - accuracy: 0.8693
Epoch 15/15
47/47 [==============================] - 1s 32ms/step - loss: 0.1523 - accuracy: 0.9513
Out[ ]:
<keras.src.callbacks.History at 0x7e4195b05510>
In [ ]:
model.summary()
Model: "sequential_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 conv2d_4 (Conv2D)           (None, 200, 200, 32)      320       
                                                                 
 max_pooling2d_3 (MaxPoolin  (None, 100, 100, 32)      0         
 g2D)                                                            
                                                                 
 conv2d_5 (Conv2D)           (None, 100, 100, 64)      18496     
                                                                 
 max_pooling2d_4 (MaxPoolin  (None, 50, 50, 64)        0         
 g2D)                                                            
                                                                 
 conv2d_6 (Conv2D)           (None, 50, 50, 64)        36928     
                                                                 
 max_pooling2d_5 (MaxPoolin  (None, 25, 25, 64)        0         
 g2D)                                                            
                                                                 
 conv2d_7 (Conv2D)           (None, 25, 25, 64)        36928     
                                                                 
 max_pooling2d_6 (MaxPoolin  (None, 12, 12, 64)        0         
 g2D)                                                            
                                                                 
 flatten (Flatten)           (None, 9216)              0         
                                                                 
 dense_1 (Dense)             (None, 64)                589888    
                                                                 
 dense_2 (Dense)             (None, 6)                 390       
                                                                 
=================================================================
Total params: 682950 (2.61 MB)
Trainable params: 682950 (2.61 MB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________

Get Gradients of the Target Class and Compute the Grad-CAM Heatmap


In [ ]:
test_idx = [7]
test_image = tf.convert_to_tensor(test_x[test_idx], dtype = tf.float32)

# build a model that maps the input image to both the last conv layer's
# feature maps and the final class predictions
conv_layer = model.get_layer(index = 6)
grad_model = tf.keras.models.Model(inputs = model.layers[0].input,
                                   outputs = [conv_layer.output, model.layers[-1].output])

with tf.GradientTape() as tape:
    desired_conv_layer_output, preds = grad_model(test_image)

    pred_index = tf.argmax(preds[0])        # explain the predicted class
    class_channel = preds[:, pred_index]    # class score y^c

# Step 4: gradients of the class score w.r.t. the conv feature maps
grads = tape.gradient(class_channel, desired_conv_layer_output)

# Step 5: global average of the gradients -> weights alpha_k^c
pooled_grads = tf.reduce_mean(grads, axis = (0, 1, 2))

# Step 6: weighted sum of the feature maps
heatmap = tf.matmul(desired_conv_layer_output[0], pooled_grads[..., tf.newaxis])
heatmap = tf.squeeze(heatmap)

Overlay Heatmap on the Original Image


In [ ]:
# Step 7: apply ReLU to keep only the positive contributions
attention_grad = np.maximum(np.reshape(heatmap, (25,25)), 0)

resized_attention_grad = cv2.resize(attention_grad,
                                    (200*5, 200*5),
                                    interpolation = cv2.INTER_CUBIC)

resized_test_x = cv2.resize(test_image.numpy().reshape(200,200),
                            (200*5, 200*5),
                            interpolation = cv2.INTER_CUBIC)

plt.figure(figsize = (6, 9))
plt.subplot(3,2,1)
plt.imshow(test_x[test_idx].reshape(200,200), 'gray')
plt.axis('off')
plt.subplot(3,2,2)
plt.imshow(attention_grad)
plt.axis('off')
plt.subplot(3,2,3)
plt.imshow(resized_test_x, 'gray')
plt.axis('off')
plt.subplot(3,2,4)
plt.imshow(resized_attention_grad, 'jet', alpha = 0.5)
plt.axis('off')
plt.subplot(3,2,6)
plt.imshow(resized_test_x, 'gray')
plt.imshow(resized_attention_grad, 'jet', alpha = 0.5)
plt.axis('off')
plt.show()

5. Class Activation Maps (CAM)

Grad-CAM offers the significant advantage of generating saliency maps from any CNN without requiring modifications to its architecture. In contrast, Class Activation Mapping (CAM) produces similar saliency maps but requires the CNN architecture to be modified. For this reason, although CAM was developed earlier, Grad-CAM was introduced first. We will now explore CAM to understand how its architecture is adapted to generate saliency maps.

Class Activation Maps (CAM) are a type of visualization technique used in deep learning to highlight the spatial regions in an input image that a convolutional neural network (CNN) focuses on to make a prediction for a particular class. CAM helps explain which parts of the image were most influential in determining the predicted class.

  • Using the learned weights, we can determine which parts of the image the model focuses on.
  • The resulting map highlights the importance of each image region to the prediction.



5.1. CNN with a Fully Connected Layer

The conventional CNN can be conceptually divided into two parts: feature extraction and classification.

  • The feature extraction stage uses convolutional layers to extract features from the input data so that classification can be performed well.
  • The classification stage uses the extracted features to determine which class each input belongs to.



Instead of strictly adhering to the conventional CNN architecture as Grad-CAM does, one may modify the CNN structure so that saliency maps can be generated in a more intuitive and interpretable manner.


5.2. CAM

Global Average Pooling (GAP)

Global Average Pooling (GAP) is a pooling operation commonly used in convolutional neural networks (CNNs) to reduce the spatial dimensions of feature maps by computing the average of each feature map. Unlike fully connected (dense) layers, which flatten the feature maps, GAP condenses each feature map into a single value by taking the spatial average, resulting in a vector with a length equal to the number of feature maps.


Example of GAP:

  • Input: A feature map of size $6 \times 6 \times 3$.
  • GAP Output: A vector of length 3, where each value is the average of the corresponding $6 \times 6$ feature map.
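This example can be verified in a few lines of TensorFlow, as in the small sketch below (the random feature map stands in for a real activation).

In [ ]:
import tensorflow as tf

feature_maps = tf.random.normal((1, 6, 6, 3))       # batch of one 6x6x3 map
gap = tf.keras.layers.GlobalAveragePooling2D()

print(gap(feature_maps).shape)                      # (1, 3)
print(tf.reduce_mean(feature_maps, axis = (1, 2)))  # identical spatial averages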

Characteristics of GAP

  • Spatial Information Summarization: It provides a global summary of feature maps, retaining the essence of features while discarding spatial details.

  • Architecture Simplicity: Simplifies the network architecture by removing the need for flattening operations and dense layers.


CNN with a Global Average Pooling (GAP) Layer

  • A modified convolutional network architecture that replaces fully connected layers with a Global Average Pooling layer.
  • The Class Activation Map (CAM) is a type of heatmap that highlights the importance of specific regions of the image for the model's prediction.
  • As a result, the key regions contributing to the prediction are visually emphasized, providing insights into the model's decision-making process.

The figure below describes the procedure for class activation mapping.



Once the training process with the modified convolutional network architecture is complete, a new input image is fed into the trained network. The following steps outline the process for generating the Class Activation Map (CAM) for that new input.


Step-by-Step Process of Class Activation Mapping (CAM)

Step 1: Input Image

  • Provide an input image to the CNN.
  • The image is passed through the initial convolutional, activation, and pooling layers to extract high-level feature maps.

Step 2: Feature Maps Generation

  • The CNN processes the image through multiple layers and produces feature maps at the final convolutional layer.
  • These feature maps capture spatial patterns such as edges, textures, and shapes relevant to the class prediction.

Step 3: Apply Global Average Pooling (GAP)

  • Instead of flattening the feature maps and passing them through fully connected layers, the Global Average Pooling (GAP) layer is applied.
  • GAP condenses each feature map into a single scalar value by taking the average of all its values.
  • This results in a vector, where each element corresponds to the average activation of a feature map.

Step 4: Linear Layer for Class Scores

  • The vector from the GAP layer is passed through a linear layer to compute the class scores.
  • The weights of this linear layer indicate the importance of each feature map for the corresponding class.

Step 5: Compute Weighted Sum of Feature Maps

  • To generate the Class Activation Map, the feature maps are weighted by the corresponding class-specific weights from the linear layer.

  • The weighted sum of feature maps is computed to create a 2D saliency map.


    $$L^c_{\text{CAM}}(x,y)= \sum_k \omega^c_k \cdot A_k (x, y)$$
    where:
    • $A_k(x,y)$ is the activation at spatial location $(x, y)$ in the $k$-th feature map.

    • $\omega^c_k$ is the weight corresponding to class $c$ for the $k$-th feature map.


Step 6: Resize the CAM to Match the Original Image

  • The generated Class Activation Map is upsampled to the same size as the input image using interpolation.
  • This allows the highlighted regions to align with the original image dimensions.

Step 7: Overlay the CAM on the Original Image

  • The final CAM is visualized by overlaying it on the input image.
  • The resulting visualization highlights the areas of the image that contributed the most to the CNN's decision.

Key Notes:

  • Architecture Requirement: CAM requires replacing fully connected layers with a GAP layer followed by a linear layer.
  • Limitations: CAM can only be applied to architectures with this specific modification, which led to the development of Grad-CAM for broader applicability.


5.3. Lab: CAM with NEU

Again, we will explore the implementation of CAM using the NEU case.


Download NEU steel surface defect images and labels

In [ ]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import cv2
In [ ]:
from google.colab import drive
drive.mount('/content/drive')
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
In [ ]:
# Change file paths if necessary

train_x = np.load('/content/drive/MyDrive/DL/DL_data/NEU_train_imgs.npy')
train_y = np.load('/content/drive/MyDrive/DL/DL_data/NEU_train_labels.npy')

test_x = np.load('/content/drive/MyDrive/DL/DL_data/NEU_test_imgs.npy')
test_y = np.load('/content/drive/MyDrive/DL/DL_data/NEU_test_labels.npy')

n_train = train_x.shape[0]
n_test = test_x.shape[0]

print ("The number of training images : {}, shape : {}".format(n_train, train_x.shape))
print ("The number of testing images : {}, shape : {}".format(n_test, test_x.shape))
The number of training images : 1500, shape : (1500, 200, 200, 1)
The number of testing images : 300, shape : (300, 200, 200, 1)
In [ ]:
model = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(filters = 32,
                           kernel_size = (3,3),
                           activation = 'relu',
                           padding = 'SAME',
                           input_shape = (200, 200, 1)),

    tf.keras.layers.MaxPool2D((2,2)),

    # input here: (100, 100, 32); only the first layer needs input_shape
    tf.keras.layers.Conv2D(filters = 64,
                           kernel_size = (3,3),
                           activation = 'relu',
                           padding = 'SAME'),

    tf.keras.layers.MaxPool2D((2,2)),

    # input here: (50, 50, 64)
    tf.keras.layers.Conv2D(filters = 64,
                           kernel_size = (3,3),
                           activation = 'relu',
                           padding = 'SAME'),

    tf.keras.layers.MaxPool2D((2,2)),

    # input here: (25, 25, 64)
    tf.keras.layers.Conv2D(filters = 64,
                           kernel_size = (3,3),
                           activation = 'relu',
                           padding = 'SAME'),

    # GAP replaces the Flatten + hidden Dense layers of the previous model
    tf.keras.layers.GlobalAveragePooling2D(),

    tf.keras.layers.Dense(6, activation = 'softmax')
])
In [ ]:
model.summary()
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 conv2d (Conv2D)             (None, 200, 200, 32)      320       
                                                                 
 max_pooling2d (MaxPooling2  (None, 100, 100, 32)      0         
 D)                                                              
                                                                 
 conv2d_1 (Conv2D)           (None, 100, 100, 64)      18496     
                                                                 
 max_pooling2d_1 (MaxPoolin  (None, 50, 50, 64)        0         
 g2D)                                                            
                                                                 
 conv2d_2 (Conv2D)           (None, 50, 50, 64)        36928     
                                                                 
 max_pooling2d_2 (MaxPoolin  (None, 25, 25, 64)        0         
 g2D)                                                            
                                                                 
 conv2d_3 (Conv2D)           (None, 25, 25, 64)        36928     
                                                                 
 global_average_pooling2d (  (None, 64)                0         
 GlobalAveragePooling2D)                                         
                                                                 
 dense (Dense)               (None, 6)                 390       
                                                                 
=================================================================
Total params: 93062 (363.52 KB)
Trainable params: 93062 (363.52 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________
In [ ]:
model.compile(optimizer = 'adam',
              loss = 'sparse_categorical_crossentropy',
              metrics = ['accuracy'])
In [ ]:
model.fit(train_x, train_y, epochs = 15)
Epoch 1/15
47/47 [==============================] - 10s 68ms/step - loss: 1.7321 - accuracy: 0.2100
Epoch 2/15
47/47 [==============================] - 2s 45ms/step - loss: 1.1566 - accuracy: 0.5353
Epoch 3/15
47/47 [==============================] - 2s 38ms/step - loss: 0.6805 - accuracy: 0.7527
Epoch 4/15
47/47 [==============================] - 2s 36ms/step - loss: 0.6311 - accuracy: 0.7533
Epoch 5/15
47/47 [==============================] - 2s 36ms/step - loss: 0.5192 - accuracy: 0.8047
Epoch 6/15
47/47 [==============================] - 2s 38ms/step - loss: 0.4316 - accuracy: 0.8380
Epoch 7/15
47/47 [==============================] - 2s 39ms/step - loss: 0.4138 - accuracy: 0.8447
Epoch 8/15
47/47 [==============================] - 2s 38ms/step - loss: 0.4769 - accuracy: 0.8280
Epoch 9/15
47/47 [==============================] - 2s 44ms/step - loss: 0.3368 - accuracy: 0.8813
Epoch 10/15
47/47 [==============================] - 2s 47ms/step - loss: 0.3436 - accuracy: 0.8740
Epoch 11/15
47/47 [==============================] - 2s 36ms/step - loss: 0.3365 - accuracy: 0.8787
Epoch 12/15
47/47 [==============================] - 2s 33ms/step - loss: 0.2938 - accuracy: 0.8913
Epoch 13/15
47/47 [==============================] - 2s 33ms/step - loss: 0.3180 - accuracy: 0.8847
Epoch 14/15
47/47 [==============================] - 2s 33ms/step - loss: 0.3724 - accuracy: 0.8633
Epoch 15/15
47/47 [==============================] - 2s 33ms/step - loss: 0.2569 - accuracy: 0.9080
Out[ ]:
<keras.src.callbacks.History at 0x7e41e206a500>
In [ ]:
# accuracy test
test_loss, test_acc = model.evaluate(test_x, test_y)
10/10 [==============================] - 1s 41ms/step - loss: 0.2665 - accuracy: 0.9033

Once the training process is complete, a new input image is fed into the trained network to generate the Class Activation Map (CAM). The CAM is computed as a linear combination of the learned weights $ \omega^c_k $ and the corresponding feature maps $ A_k $, highlighting the regions that contribute most to the class prediction. Since the spatial dimensions of the feature maps are smaller than the original input image, the activation map is resized to match the input image's dimensions and is overlaid on the input to provide an intuitive visualization of the model's decision-making process.


In [ ]:
# get the last convolutional layer and the dense layer's weights
conv_layer = model.get_layer(index = 6)
fc_layer_weights = model.layers[8].get_weights()[0]   # shape: (64, 6)

# define the class activation map operation using a Lambda layer:
# a weighted sum of the 64 feature maps for each of the 6 classes
cam_layer = tf.keras.layers.Lambda(lambda x: tf.matmul(x, fc_layer_weights))(conv_layer.output)

# create the class activation map model
CAM = tf.keras.Model(inputs = model.inputs, outputs = cam_layer)
In [ ]:
test_idx = [7]
test_image = test_x[test_idx]

pred = np.argmax(model.predict(test_image), axis = 1)
predCAM = CAM.predict(test_image)

attention = predCAM[:,:,:,pred]
attention = np.abs(np.reshape(attention,(25,25)))

resized_attention = cv2.resize(attention,
                               (200*5, 200*5),
                               interpolation = cv2.INTER_CUBIC)

resized_test_x = cv2.resize(test_image.reshape(200,200),
                            (200*5, 200*5),
                            interpolation = cv2.INTER_CUBIC)

plt.figure(figsize = (6, 9))
plt.subplot(3,2,1)
plt.imshow(test_x[test_idx].reshape(200,200), 'gray')
plt.axis('off')
plt.subplot(3,2,2)
plt.imshow(attention)
plt.axis('off')
plt.subplot(3,2,3)
plt.imshow(resized_test_x, 'gray')
plt.axis('off')
plt.subplot(3,2,4)
plt.imshow(resized_attention, 'jet', alpha = 0.5)
plt.axis('off')
plt.subplot(3,2,6)
plt.imshow(resized_test_x, 'gray')
plt.imshow(resized_attention, 'jet', alpha = 0.5)
plt.axis('off')
plt.show()
1/1 [==============================] - 0s 190ms/step
1/1 [==============================] - 0s 63ms/step
In [ ]:
%%javascript
$.getScript('https://kmahelona.github.io/ipython_notebook_goodies/ipython_notebook_toc.js')