Korean Society for Technology of Plasticity (KSTP), Lab Session 3

XAI: eXplainable Artificial Intelligence


By Prof. Seungchul Lee
http://iai.postech.ac.kr/
Industrial AI Lab at POSTECH

Table of Contents

1. Black Box AI

  • Black box AI: the AI produces insights from a data set, but the end user does not know how
    • Many machine learning and deep learning models share the ‘black box’ problem
    • The AI does not provide the reasons behind the decisions or predictions it makes
    • As a result, the reliability of AI models may be questioned





  • XAI: AI whose decisions or predictions can be understood by humans





Why XAI?

  • XAI increases the interpretability of AI by describing the expected outcome and potential bias of the model
  • Depending on the AI performance, XAI results can be used in various ways:
    • AI performance < Human performance
      • XAI suggests improvement directions for AI models
    • AI performance ≈ Human performance
      • XAI identifies the principles behind AI model learning
    • AI performance > Human performance
      • XAI enables acquiring new knowledge from AI





Model-Specific XAI

  • Model-Specific XAI: applicable only to specific algorithms; explanations are obtained from the intrinsic structure of the model
    • Examples: Class Activation Mapping (CAM) and Gradient-weighted CAM (Grad-CAM) for Convolutional Neural Network (CNN) models (a minimal Grad-CAM sketch follows)
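To make the idea concrete, below is a minimal, illustrative Grad-CAM pass written in TensorFlow. The tiny CNN, the layer name 'last_conv', the input shape, and the dummy image are all assumptions made for illustration; they are not part of the lecture code.

import numpy as np
import tensorflow as tf

# Tiny illustrative CNN (assumed: 28x28 grayscale input, 10 classes, random weights)
cnn = tf.keras.Sequential([
    tf.keras.Input(shape = (28, 28, 1)),
    tf.keras.layers.Conv2D(16, 3, activation = 'relu', name = 'last_conv'),
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(10, activation = 'softmax')
])

# A model that returns the last conv feature maps together with the prediction
grad_model = tf.keras.Model(cnn.input,
                            [cnn.get_layer('last_conv').output, cnn.output])

x = np.random.rand(1, 28, 28, 1).astype('float32')        # dummy input image

with tf.GradientTape() as tape:
    conv_maps, preds = grad_model(x)
    class_idx = int(tf.argmax(preds[0]))                   # predicted class
    class_score = preds[:, class_idx]

grads = tape.gradient(class_score, conv_maps)              # d(score) / d(feature maps)
weights = tf.reduce_mean(grads, axis = (1, 2))             # channel weights (GAP of gradients)
cam = tf.nn.relu(tf.reduce_sum(conv_maps * weights[:, None, None, :], axis = -1))
cam = cam / (tf.reduce_max(cam) + 1e-8)                    # normalized class activation heatmap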





Model-Agnostic XAI

  • Model-Agnostic XAI: applicable to any machine learning algorithm and works directly on the black box model
    • Explanations are obtained by perturbing (mutating) the input data and measuring how sensitive the model's performance is to these mutations relative to its performance on the original data (see the sketch after this list)
    • Examples:
      • SHapley Additive exPlanations (SHAP)
      • Local Interpretable Model-agnostic Explanations (LIME)
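As a concrete toy version of this perturbation idea, the sketch below shuffles one feature at a time and measures how much the error of a fitted black-box regressor degrades. The synthetic data set and the random forest are placeholders, not the lecture's data or model.

import numpy as np
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error

# Placeholder data and black-box model (any fitted model would do)
X, y = make_regression(n_samples = 300, n_features = 5, random_state = 0)
model = RandomForestRegressor(random_state = 0).fit(X, y)

base_mse = mean_squared_error(y, model.predict(X))
rng = np.random.default_rng(0)

for j in range(X.shape[1]):
    X_perturbed = X.copy()
    X_perturbed[:, j] = rng.permutation(X_perturbed[:, j])   # destroy the information in feature j
    mse_j = mean_squared_error(y, model.predict(X_perturbed))
    print(f"feature {j}: MSE increase after perturbation = {mse_j - base_mse:.3f}")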





2. SHapley Additive exPlanations (SHAP)

  • SHAP: a game-theoretic approach to explaining the output of any machine learning model
    • Computes a feature importance for each predicted value using Shapley values
    • The SHAP value expresses the contribution of each feature as the degree to which the model output changes when that feature's contribution is excluded
    • Unlike a plain permutation method, SHAP attributes the model output while accounting for dependencies between features (a toy Shapley-value computation follows)
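To make the Shapley value concrete, the toy sketch below computes exact Shapley values for three features from a hand-made coalition payoff table. This is purely illustrative: the payoffs are invented numbers, and the shap library uses fast approximations such as Tree SHAP instead of this brute-force enumeration.

from itertools import combinations
from math import factorial

features = (0, 1, 2)

# Hypothetical "payoff" of the model when only a given coalition of features is used
payoffs = {(): 0, (0,): 10, (1,): 20, (2,): 5,
           (0, 1): 40, (0, 2): 20, (1, 2): 30, (0, 1, 2): 60}

def value(coalition):
    return payoffs[tuple(sorted(coalition))]

n = len(features)
for i in features:
    others = [f for f in features if f != i]
    phi = 0.0
    for r in range(len(others) + 1):
        for S in combinations(others, r):
            weight = factorial(len(S)) * factorial(n - len(S) - 1) / factorial(n)
            phi += weight * (value(S + (i,)) - value(S))       # weighted marginal contribution
    print(f"Shapley value of feature {i}: {phi:.2f}")

# The three Shapley values sum to value(all features) - value(empty set) = 60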





3. Local Interpretable Model-agnostic Explanations (LIME)

  • LIME: shows which features of the current input the model focused on and which features served as the basis for its prediction
    • LIME partially masks the input data
    • The transformed input is then fed to the model and a prediction is obtained
    • If the prediction changes a lot, the masked part was important
    • Conversely, if the prediction changes little, the masked part was not very important (see the masking sketch below)
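A toy sketch of this masking idea: replace one feature of a single instance with its column mean (a simple form of "masking") and watch how far the prediction moves. The data and model below are placeholders; the real lime library goes further by fitting a locally weighted linear surrogate model around the instance.

import numpy as np
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor

# Placeholder data and black-box model
X, y = make_regression(n_samples = 300, n_features = 5, random_state = 0)
model = RandomForestRegressor(random_state = 0).fit(X, y)

x0 = X[0]                                        # the instance to be explained
base_pred = model.predict(x0.reshape(1, -1))[0]

for j in range(len(x0)):
    x_masked = x0.copy()
    x_masked[j] = X[:, j].mean()                 # "mask" feature j with its column mean
    pred_j = model.predict(x_masked.reshape(1, -1))[0]
    print(f"feature {j}: prediction change when masked = {abs(pred_j - base_pred):.3f}")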









4. XAI with Machine Learning

4.1. RandomForest Model

In [1]:
!pip install shap lime
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting shap
  Downloading shap-0.41.0-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (575 kB)
Collecting lime
  Downloading lime-0.2.0.1.tar.gz (275 kB)
Collecting slicer==0.0.7
  Downloading slicer-0.0.7-py3-none-any.whl (14 kB)
...
Building wheels for collected packages: lime
  Building wheel for lime (setup.py) ... done
Successfully built lime
Installing collected packages: slicer, shap, lime
Successfully installed lime-0.2.0.1 shap-0.41.0 slicer-0.0.7
In [2]:
from google.colab import drive
drive.mount('/content/drive')
Mounted at /content/drive

Import Libraries

In [ ]:
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error
from sklearn.metrics import r2_score

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

Define Input Data

In [ ]:
df = pd.read_csv("/content/drive/MyDrive/kstp/data_files/datafile.csv")

normalized_df = (df-df.min())/(df.max()-df.min())
df = normalized_df
In [ ]:
df
In [ ]:
y = df['Weight'].values
In [ ]:
X = df.drop(["Weight"], axis = 1)
In [ ]:
print("There are {} possible descriptors:".format(len(X.columns)))
print(X.columns)

Split Train-Test Set and Train Random Forest Regressor

In [ ]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.30, random_state = 1)

rf = RandomForestRegressor(random_state = 1)
rf.fit(X_train, y_train)

Evaluate the Random Forest Regressor with Test Set

In [ ]:
y_pred = rf.predict(X_test)
mse = mean_squared_error(y_test, y_pred)
print('RMSE = {:.3f} '.format(np.sqrt(mse)))
In [ ]:
r2 = r2_score(y_test, y_pred)
print('R2 = {:.3f} '.format(r2))
In [ ]:
fig, ax = plt.subplots(figsize=(10, 10))
ax.scatter(y_test, y_pred, edgecolors=(0, 0, 0))
ax.plot([y.min(), y.max()], [y.min(), y.max()], "r--", lw=4)
ax.set_xlabel("Measured")
ax.set_ylabel("Predicted")
plt.show()

4.2. SHAP

SHAP Implementation

In [ ]:
import shap

explainer = shap.TreeExplainer(rf)
shap_values = explainer.shap_values(X_test)

SHAP Values of a Single Test Sample (Local Feature Importance)

In [ ]:
shap.bar_plot(shap_values[0], features = X_test.iloc[0, :], feature_names = X.columns)
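The same local explanation can also be rendered with shap's force plot, which shows how each feature pushes the prediction above or below the expected value. This is an optional extra (not in the original lecture cell), reusing the explainer and shap_values defined above; matplotlib = True draws it as a static figure.

shap.force_plot(explainer.expected_value, shap_values[0], X_test.iloc[0, :], matplotlib = True)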

Average of Absolute SHAP Values over the Entire Test Set (Global Feature Importance)

In [ ]:
shap.summary_plot(shap_values, X_test, plot_type = "bar")
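The bar summary can be cross-checked by hand: the global importance of a feature is simply the mean absolute SHAP value over the test set. A short sketch reusing shap_values and X from above:

import numpy as np

global_importance = np.abs(shap_values).mean(axis = 0)      # mean |SHAP| per feature
for name, value in sorted(zip(X.columns, global_importance), key = lambda t: -t[1]):
    print(f"{name:12s} {value:.4f}")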

4.3. LIME

LIME Implementation

In [ ]:
from lime import lime_tabular

X_train = X_train.to_numpy()
X_test = X_test.to_numpy()

explainer = lime_tabular.LimeTabularExplainer(X_train, 
                                              mode = "regression",
                                              feature_names = X.columns)
explainer

LIME Result

In [ ]:
explanation = explainer.explain_instance(X_test[0], rf.predict, num_features = len(X.columns))
explanation
explanation.show_in_notebook()
In [ ]:
with plt.style.context("ggplot"):
    explanation.as_pyplot_figure()
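Besides the plots, the explanation can be read programmatically as (feature rule, weight) pairs via explanation.as_list(). A short sketch reusing the explanation object above:

for rule, weight in explanation.as_list():
    print(f"{rule:30s} {weight:+.4f}")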

5. XAI with Deep Learning

5.1. ANN Model

Import Libraries

In [3]:
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from sklearn.metrics import r2_score

import matplotlib.pyplot as plt
import tensorflow as tf
import pandas as pd
import numpy as np

Define Input Data

In [4]:
df = pd.read_csv("/content/drive/MyDrive/kstp/data_files/datafile.csv") 

normalized_df = (df - df.min())/(df.max() - df.min())
df = normalized_df
In [5]:
df
Out[5]:
Weight Filltime Melt temp Mold temp Pack press Packtime
0 0.035330 0.0 0.25000 0.133333 0.0 0.2
1 0.049863 0.0 0.25000 0.133333 0.0 0.6
2 0.069474 0.0 0.25000 0.133333 0.0 1.0
3 0.022718 0.0 0.34375 0.233333 0.5 0.2
4 0.035128 0.0 0.34375 0.233333 0.5 0.6
... ... ... ... ... ... ...
1895 0.987906 1.0 0.84375 0.800000 1.0 0.6
1896 1.000000 1.0 0.84375 0.800000 1.0 1.0
1897 0.924032 1.0 1.00000 0.866667 0.0 0.2
1898 0.940060 1.0 1.00000 0.866667 0.0 0.6
1899 0.956140 1.0 1.00000 0.866667 0.0 1.0

1900 rows × 6 columns

In [6]:
y = df['Weight'].values
In [7]:
X = df.drop(["Weight"], axis = 1)
In [8]:
print("There are {} possible descriptors:".format(len(X.columns)))
print(X.columns)
There are 5 possible descriptors:
Index(['Filltime', 'Melt temp', 'Mold temp', 'Pack press', 'Packtime'], dtype='object')

Split Train-Test Set

In [9]:
X_train, X_test, Y_train, Y_test = train_test_split(X, y, test_size = 0.30, random_state = 10)
In [10]:
X_train.shape
Out[10]:
(1330, 5)

Define ANN Model

In [11]:
model = tf.keras.models.Sequential([
    tf.keras.Input(shape = (X_train.shape[1],)),
    tf.keras.layers.Dense(1024, activation = 'relu'),
    tf.keras.layers.Dense(128, activation = 'relu'),
    tf.keras.layers.Dense(1)
])
In [12]:
model.summary()
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 dense (Dense)               (None, 1024)              6144      
                                                                 
 dense_1 (Dense)             (None, 128)               131200    
                                                                 
 dense_2 (Dense)             (None, 1)                 129       
                                                                 
=================================================================
Total params: 137,473
Trainable params: 137,473
Non-trainable params: 0
_________________________________________________________________
In [13]:
model.compile(loss = tf.keras.losses.MeanSquaredError(),
              optimizer = tf.keras.optimizers.Adam())

Train ANN Model

In [14]:
model.fit(X_train, 
          Y_train, 
          epochs = 150, 
          batch_size = 64, 
          validation_split = 0.1)
Epoch 1/150
19/19 [==============================] - 2s 38ms/step - loss: 0.0640 - val_loss: 0.0461
Epoch 2/150
19/19 [==============================] - 0s 14ms/step - loss: 0.0431 - val_loss: 0.0379
Epoch 3/150
19/19 [==============================] - 0s 17ms/step - loss: 0.0392 - val_loss: 0.0294
...
Epoch 149/150
19/19 [==============================] - 0s 8ms/step - loss: 0.0073 - val_loss: 0.0138
Epoch 150/150
19/19 [==============================] - 0s 10ms/step - loss: 0.0076 - val_loss: 0.0134
Out[14]:
<keras.callbacks.History at 0x7fe945497c40>

Evaluate ANN Model

In [15]:
y_pred = model.predict(X_test)
mse = mean_squared_error(Y_test, y_pred)
print('RMSE = {:.3f} '.format(np.sqrt(mse)))
18/18 [==============================] - 0s 3ms/step
RMSE = 0.115 
In [16]:
r2 = r2_score(Y_test, y_pred)
print('R2 = {:.3f} '.format(r2))
R2 = 0.780 
In [17]:
fig, ax = plt.subplots(figsize = (10, 10))
ax.scatter(Y_test, y_pred, edgecolors = (0, 0, 0))
ax.plot([y.min(), y.max()], [y.min(), y.max()], "r--", lw = 4)
ax.set_xlabel("Measured")
ax.set_ylabel("Predicted")
plt.show()

5.2. SHAP

SHAP Implementation

In [18]:
import shap

X_train = X_train.to_numpy()
X_test = X_test.to_numpy()

explainer_shap = shap.DeepExplainer(model = model, data = X_train)
shap_values = explainer_shap.shap_values(X_test)
keras is no longer supported, please use tf.keras instead.
Your TensorFlow version is newer than 2.4.0 and so graph support has been removed in eager mode and some static graphs may not be supported. See PR #1483 for discussion.
`tf.keras.backend.set_learning_phase` is deprecated and will be removed after 2020-10-11. To update it, simply pass a True/False value to the `training` argument of the `__call__` method of your layer or model.

SHAP Values of a Single Test Sample (Local Feature Importance)

In [19]:
shap.bar_plot(shap_values[0][1], features = X_test[1], feature_names = X.columns)
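Note that DeepExplainer returns a list with one SHAP array per model output, which is why the values for test sample 1 are indexed as shap_values[0][1]. A quick shape check (reusing the objects above; with 570 test samples and 5 features, the expected shape is (570, 5)):

print(len(shap_values), np.array(shap_values[0]).shape)     # expected: 1 (570, 5)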

Average of Absolute SHAP Values over the Entire Test Set (Global Feature Importance)

In [20]:
shap.summary_plot(shap_values, X_test, feature_names = X.columns)

5.3. LIME

LIME Implementation

In [21]:
from lime import lime_tabular

explainer = lime_tabular.LimeTabularExplainer(X_train, 
                                              mode = "regression",
                                              feature_names = X.columns)
explainer
Out[21]:
<lime.lime_tabular.LimeTabularExplainer at 0x7fe8bf61a970>
In [22]:
print("Prediction : ", model.predict(X_test[1].reshape(1,-1)))
print("Actual :     ", [[Y_test[1]]])
explanation = explainer.explain_instance(X_test[1], model.predict, num_features = len(X.columns))
explanation
1/1 [==============================] - 0s 62ms/step
Prediction :  [[0.7690255]]
Actual :      [[0.7912527303672922]]
157/157 [==============================] - 0s 3ms/step
Out[22]:
<lime.explanation.Explanation at 0x7fe8bf55d400>

LIME Result

In [23]:
explanation.show_in_notebook()
In [24]:
with plt.style.context("ggplot"):
    explanation.as_pyplot_figure()
In [1]:
%%javascript
$.getScript('https://kmahelona.github.io/ipython_notebook_goodies/ipython_notebook_toc.js')