ANN with MNIST


By Prof. Seungchul Lee
http://iai.postech.ac.kr/
Industrial AI Lab at POSTECH


1. What's an MNIST?

From Wikipedia

  • The MNIST database (Mixed National Institute of Standards and Technology database) is a large database of handwritten digits that is commonly used for training various image processing systems. The database is also widely used for training and testing in the field of machine learning. It was created by "re-mixing" the samples from NIST's original datasets. The creators felt that since NIST's training dataset was taken from American Census Bureau employees, while the testing dataset was taken from American high school students, NIST's complete dataset was too hard.
  • MNIST (Mixed National Institute of Standards and Technology) database
    • Handwritten digit database
    • $28 \times 28$ grayscale images
    • Each image is flattened into a vector of length $28 \times 28 = 784$




We will use MNIST to build a multiclass classifier that predicts which digit (0, 1, 2, 3, 4, 5, 6, 7, 8 or 9) an MNIST image shows. Succinctly, we're teaching a computer to recognize handwritten digits.

In [1]:
# Import Library
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

Let's load the dataset.

In [2]:
mnist_train_images = np.load('./data_files/mnist_train_images.npy')
mnist_train_labels = np.load('./data_files/mnist_train_labels.npy')
mnist_test_images = np.load('./data_files/mnist_test_images.npy')
mnist_test_labels = np.load('./data_files/mnist_test_labels.npy')
In [3]:
print ("The training data set is:\n")
print (mnist_train_images.shape)
print (mnist_train_labels.shape)
The training data set is:

(55000, 784)
(55000, 10)
In [4]:
print ("The test data set is:")
print (mnist_test_images.shape)
print (mnist_test_labels.shape)
The test data set is:
(10000, 784)
(10000, 10)

Let's look at one sample (index 5):

In [5]:
mnist_train_images[5]
Out[5]:
array([0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.23137257, 0.6392157 , 0.9960785 , 0.9960785 , 0.9960785 ,
       0.7607844 , 0.43921572, 0.07058824, 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.01568628, 0.5176471 , 0.93725497, 0.9921569 ,
       0.9921569 , 0.9921569 , 0.9921569 , 0.9960785 , 0.9921569 ,
       0.627451  , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.5372549 ,
       0.9921569 , 0.9960785 , 0.9921569 , 0.9921569 , 0.9921569 ,
       0.75294125, 0.9960785 , 0.9921569 , 0.8980393 , 0.0509804 ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.01568628, 0.5372549 , 0.9843138 , 0.9921569 , 0.9568628 ,
       0.50980395, 0.19215688, 0.07450981, 0.01960784, 0.6392157 ,
       0.9921569 , 0.8235295 , 0.03529412, 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.37254903, 0.9921569 ,
       0.9921569 , 0.8431373 , 0.1764706 , 0.        , 0.        ,
       0.        , 0.        , 0.6117647 , 0.9921569 , 0.68235296,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.8431373 , 0.9960785 , 0.8117648 , 0.09019608,
       0.        , 0.        , 0.        , 0.03921569, 0.3803922 ,
       0.85098046, 0.9176471 , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.83921576,
       0.9921569 , 0.2784314 , 0.        , 0.        , 0.00784314,
       0.19607845, 0.8352942 , 0.9921569 , 0.9960785 , 0.7058824 ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.83921576, 0.9921569 , 0.19215688,
       0.        , 0.        , 0.19607845, 0.9921569 , 0.9921569 ,
       0.9921569 , 0.7176471 , 0.04705883, 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.7803922 , 0.9921569 , 0.95294124, 0.7686275 , 0.62352943,
       0.95294124, 0.9921569 , 0.9686275 , 0.5411765 , 0.03137255,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.16470589, 0.9921569 ,
       0.9921569 , 0.9921569 , 0.9960785 , 0.9921569 , 0.9921569 ,
       0.39607847, 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.23137257, 0.58431375, 0.9960785 , 0.9960785 , 0.9960785 ,
       1.        , 0.9960785 , 0.6862745 , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.13333334, 0.75294125, 0.9960785 , 0.9921569 ,
       0.9921569 , 0.9921569 , 0.7843138 , 0.53333336, 0.89019614,
       0.9450981 , 0.27058825, 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.33333334, 0.9686275 ,
       0.9921569 , 0.9960785 , 0.9921569 , 0.77647066, 0.48235297,
       0.07058824, 0.        , 0.19607845, 0.9921569 , 0.8352942 ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.2784314 , 0.9686275 , 0.9921569 , 0.9294118 , 0.75294125,
       0.2784314 , 0.02352941, 0.        , 0.        , 0.        ,
       0.00784314, 0.5019608 , 0.9803922 , 0.21176472, 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.46274513, 0.9921569 ,
       0.8705883 , 0.14117648, 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.03137255, 0.7176471 ,
       0.9921569 , 0.227451  , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.46274513, 0.9960785 , 0.54509807, 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.05490196, 0.7294118 , 0.9960785 , 0.9960785 , 0.227451  ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.2784314 ,
       0.9686275 , 0.9686275 , 0.54509807, 0.0627451 , 0.        ,
       0.        , 0.07450981, 0.227451  , 0.87843144, 0.9921569 ,
       0.9921569 , 0.8313726 , 0.03529412, 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.42352945, 0.9921569 ,
       0.9921569 , 0.92549026, 0.6862745 , 0.6862745 , 0.9686275 ,
       0.9921569 , 0.9960785 , 0.9921569 , 0.77647066, 0.16862746,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.2627451 , 0.8352942 , 0.8980393 , 0.9960785 ,
       0.9921569 , 0.9921569 , 0.9921569 , 0.9921569 , 0.83921576,
       0.48627454, 0.02352941, 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.09019608, 0.60784316, 0.60784316, 0.8745099 ,
       0.7843138 , 0.46274513, 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        ], dtype=float32)
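The printed values lie between 0 and 1, and the non-zero ones look like multiples of 1/255 (for example 0.9921569 is about 253/255). A plausible reading, which is an assumption and not stated in this notebook, is that these are the original 8-bit MNIST intensities (0 to 255) divided by 255. The sketch below checks that assumption on a few values copied from the output above.

```python
import numpy as np

# A few distinct pixel values copied from the mnist_train_images[5] printout above.
values = np.array([0.0, 0.23137257, 0.6392157, 0.9921569, 0.9960785, 1.0], dtype=np.float32)

# Assumption: each stored value is an 8-bit intensity (0-255) divided by 255.
raw = np.rint(values * 255).astype(np.uint8)
print(raw.tolist())                                  # [0, 59, 163, 253, 254, 255]

# If the assumption holds, dividing back by 255 reproduces the stored floats.
print(np.allclose(raw / 255.0, values, atol=1e-6))   # True
```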
In [6]:
# well, that's not a picture (or image), it's an array.

mnist_train_images[5].shape
Out[6]:
(784,)

You might expect the training set to be stored as 28 $\times$ 28 grayscale images of handwritten digits. It is not, at least not in that shape.

The images have been flattened: each 28 $\times$ 28 image is stored as a 1D array of 784 values. Let's reshape one back.

In [7]:
img = np.reshape(mnist_train_images[5], [28,28])
In [8]:
img = mnist_train_images[5].reshape([28,28])
In [9]:
# So now we have a 28x28 matrix, where each element is an intensity level from 0 to 1.  
img.shape
Out[9]:
(28, 28)
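Reshaping a 784-vector back to 28 $\times$ 28 only gives the right picture if we undo the flattening in the same order it was done. The sketch below assumes the images were flattened row by row (NumPy's default C order), which is what `reshape([28,28])` inverts; it uses a hypothetical stand-in image whose pixel values are simply their own positions, so the index mapping is easy to see.

```python
import numpy as np

# Hypothetical 28x28 "image" whose pixel values are just 0..783,
# so each value records where it sits (not real MNIST data).
demo_img = np.arange(28 * 28, dtype=np.float32).reshape(28, 28)

flat = demo_img.reshape(-1)                # flatten row by row (NumPy's default C order)
print(flat.shape)                          # (784,)

# Pixel (r, c) of the image sits at index r * 28 + c of the flat vector.
r, c = 10, 3
print(flat[r * 28 + c] == demo_img[r, c])  # True

# Reshaping back to (28, 28) restores the image exactly.
print(np.array_equal(flat.reshape(28, 28), demo_img))  # True
```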

Let's visualize this image and look at its corresponding training label.

In [10]:
plt.figure(figsize = (6,6))
plt.imshow(img, 'gray')
plt.xticks([])
plt.yticks([])
plt.show()
In [11]:
mnist_train_labels[5]
Out[11]:
array([0., 0., 0., 0., 0., 0., 0., 0., 1., 0.])
In [12]:
np.argmax(mnist_train_labels[5])
Out[12]:
8
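`np.argmax` on a single label picks out the position of the 1. The same call with `axis=1` decodes a whole label matrix at once, and `np.bincount` then counts how many samples each digit class has. The small label matrix below is hypothetical; on the real data the same two lines would be `np.bincount(np.argmax(mnist_train_labels, axis=1), minlength=10)`.

```python
import numpy as np

# Hypothetical one-hot label matrix (4 samples, 10 classes), shaped like mnist_train_labels.
labels_onehot = np.array([
    [0, 0, 0, 0, 0, 0, 0, 0, 1, 0],   # digit 8
    [0, 0, 0, 1, 0, 0, 0, 0, 0, 0],   # digit 3
    [0, 0, 0, 0, 0, 0, 0, 1, 0, 0],   # digit 7
    [0, 0, 0, 1, 0, 0, 0, 0, 0, 0],   # digit 3
], dtype=np.float64)

# argmax along axis=1 returns the index of the 1 in every row.
digits = np.argmax(labels_onehot, axis=1)
print(digits.tolist())    # [8, 3, 7, 3]

# bincount with minlength=10 counts samples per digit class 0..9.
counts = np.bincount(digits, minlength=10)
print(counts.tolist())    # [0, 0, 0, 2, 0, 0, 0, 1, 1, 0]
```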

2. ANN with scikit-learn

  • Feed a gray image to ANN
  • Our network model



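  • Network training (learning): $$\omega:= \omega - \alpha \nabla_{\omega}\, \ell\left( h_{\omega} \left(x^{(i)}\right),y^{(i)}\right)$$ where $\alpha$ is the learning rate and $\ell$ is the loss between the network output $h_{\omega}\left(x^{(i)}\right)$ and the label $y^{(i)}$.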
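The update rule moves the parameters $\omega$ a small step against the gradient of the loss. The sketch below makes that concrete in NumPy under several assumptions not taken from this notebook: a softmax output with cross-entropy loss, no hidden layer (so it is softmax regression, not the 100-unit network trained below), full-batch gradient steps instead of one sample $x^{(i)}$ at a time, and random toy data in place of MNIST.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy stand-in for MNIST: 200 random "images" with 784 pixel values in [0, 1]
# and random digit labels. This is not real data; it only exercises the update rule.
X = rng.random((200, 784))
y = rng.integers(0, 10, size=200)
Y = np.eye(10)[y]                        # one-hot targets, like mnist_train_labels

W = np.zeros((784, 10))                  # omega: weights (no hidden layer in this sketch)
b = np.zeros(10)                         # omega: biases
alpha = 0.005                            # learning rate, small because raw pixel features are not centered

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)  # subtract the row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def loss_and_grads(X, Y, W, b):
    P = softmax(X @ W + b)                                  # h_omega(x): class probabilities
    loss = -np.mean(np.sum(Y * np.log(P + 1e-12), axis=1))  # mean cross-entropy loss
    G = (P - Y) / X.shape[0]                                # gradient w.r.t. the pre-softmax scores
    return loss, X.T @ G, G.sum(axis=0)

losses = []
for step in range(100):
    loss, dW, db = loss_and_grads(X, Y, W, b)
    losses.append(loss)
    W -= alpha * dW                      # omega := omega - alpha * gradient
    b -= alpha * db

print(losses[0], losses[-1])             # starts at ln(10) (uniform guess) and decreases
```

With all-zero weights the model predicts a uniform distribution, so the first loss is exactly $\ln 10 \approx 2.303$; every later step lowers it.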

2.1. Import Library

In [13]:
# Import Library
import numpy as np
import matplotlib.pyplot as plt

from sklearn.neural_network import MLPClassifier
from sklearn.metrics import accuracy_score

2.2. Load MNIST Data

In [14]:
train_x = np.load('./data_files/mnist_train_images.npy')
train_y = np.load('./data_files/mnist_train_labels.npy')
test_x = np.load('./data_files/mnist_test_images.npy')
test_y = np.load('./data_files/mnist_test_labels.npy')
In [15]:
img = train_x[1,:].reshape(28,28)

plt.figure(figsize=(6,6))
plt.imshow(img,'gray')
plt.title("Label : {}".format(np.argmax(train_y[1,:])))
plt.xticks([])
plt.yticks([])
plt.show()

One-hot encoding: each label is a length-10 vector with a 1 at the index of the digit and 0 elsewhere.

In [16]:
print ('Train labels : {}'.format(train_y[1, :]))
Train labels : [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
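Going the other way, from integer digits to one-hot vectors, can be done by indexing rows of a 10 $\times$ 10 identity matrix, since row $k$ of the identity is exactly the one-hot vector for class $k$. This is a common NumPy idiom, not necessarily how these `.npy` files were produced; the labels below are hypothetical.

```python
import numpy as np

digits = np.array([3, 8, 0])        # hypothetical integer labels
onehot = np.eye(10)[digits]         # row k of the 10x10 identity = one-hot vector for class k
print(onehot[0])                    # [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]

# argmax along axis=1 inverts the encoding.
print(np.argmax(onehot, axis=1).tolist())   # [3, 8, 0]
```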

2.3. Define an ANN Structure

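  • Hidden layer size (`hidden_layer_sizes`)
  • Activation function (`activation`)
  • Optimizer (`solver`)
  • Learning rate (`learning_rate_init`)
  • Batch size (`batch_size`)
  • Maximum number of iterations (`max_iter`; for the `adam` solver this is the number of epochs)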
In [17]:
clf = MLPClassifier(hidden_layer_sizes=(100,), activation = 'relu', solver='adam', 
                    learning_rate_init = 0.0001, batch_size = 50, max_iter = 20, verbose = True)
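With `hidden_layer_sizes=(100,)`, the network maps 784 inputs to 100 hidden units and then to 10 outputs. Each fully connected layer has an `n_in × n_out` weight matrix plus one bias per output unit, so the total parameter count follows from simple arithmetic. The helper below is a hypothetical illustration, not part of scikit-learn.

```python
# Layer widths for hidden_layer_sizes=(100,) on MNIST: 784 inputs -> 100 hidden -> 10 outputs.
layer_sizes = [784, 100, 10]

def count_parameters(sizes):
    """Weights plus biases of a fully connected network with the given layer widths."""
    total = 0
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        total += n_in * n_out + n_out   # weight matrix + one bias per output unit
    return total

print(count_parameters(layer_sizes))   # 79510
```

That is $784 \cdot 100 + 100 = 78500$ for the hidden layer and $100 \cdot 10 + 10 = 1010$ for the output layer. After fitting, scikit-learn stores these arrays in `clf.coefs_` (shapes `(784, 100)` and `(100, 10)`) and `clf.intercepts_` (shapes `(100,)` and `(10,)`).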

2.4. Optimize

In [18]:
clf.fit(train_x, train_y)
Iteration 1, loss = 1.81264161
Iteration 2, loss = 0.86213484
Iteration 3, loss = 0.68763434
Iteration 4, loss = 0.59347783
Iteration 5, loss = 0.52895041
Iteration 6, loss = 0.48005256
Iteration 7, loss = 0.44179623
Iteration 8, loss = 0.41014217
Iteration 9, loss = 0.38342902
Iteration 10, loss = 0.36060402
Iteration 11, loss = 0.34058624
Iteration 12, loss = 0.32290041
Iteration 13, loss = 0.30728495
Iteration 14, loss = 0.29320335
Iteration 15, loss = 0.28043393
Iteration 16, loss = 0.26918163
Iteration 17, loss = 0.25822949
Iteration 18, loss = 0.24852496
Iteration 19, loss = 0.23958331
Iteration 20, loss = 0.23091648
/mnt/disk1/project/.env/lib/python3.6/site-packages/sklearn/neural_network/multilayer_perceptron.py:564: ConvergenceWarning: Stochastic Optimizer: Maximum iterations (20) reached and the optimization hasn't converged yet.
  % self.max_iter, ConvergenceWarning)
Out[18]:
MLPClassifier(activation='relu', alpha=0.0001, batch_size=50, beta_1=0.9,
       beta_2=0.999, early_stopping=False, epsilon=1e-08,
       hidden_layer_sizes=(100,), learning_rate='constant',
       learning_rate_init=0.0001, max_iter=20, momentum=0.9,
       nesterovs_momentum=True, power_t=0.5, random_state=None,
       shuffle=True, solver='adam', tol=0.0001, validation_fraction=0.1,
       verbose=True, warm_start=False)
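One detail worth knowing about the cell above: `train_y` is one-hot encoded (shape `(55000, 10)`), and scikit-learn treats a 2D 0/1 target as a multilabel problem. The output layer then uses an independent logistic (sigmoid) unit per digit rather than a softmax over the 10 digits, and `predict` returns indicator rows. Passing integer labels such as `np.argmax(train_y, axis=1)` would instead give a softmax output. The sketch below checks this on small random toy data (3 classes, not MNIST) via the fitted `out_activation_` attribute.

```python
import warnings
import numpy as np
from sklearn.exceptions import ConvergenceWarning
from sklearn.neural_network import MLPClassifier

rng = np.random.default_rng(0)
X = rng.random((60, 20))        # toy stand-in for images (random features)
y = np.arange(60) % 3           # integer labels 0, 1, 2 (toy; MNIST has 10 classes)
Y = np.eye(3)[y]                # the same labels, one-hot encoded like train_y

with warnings.catch_warnings():
    warnings.simplefilter("ignore", ConvergenceWarning)   # a few epochs suffice for this check
    clf_onehot = MLPClassifier(hidden_layer_sizes=(8,), max_iter=5, random_state=0).fit(X, Y)
    clf_int = MLPClassifier(hidden_layer_sizes=(8,), max_iter=5, random_state=0).fit(X, y)

print(clf_onehot.out_activation_)       # logistic: one independent yes/no output per class
print(clf_int.out_activation_)          # softmax: one probability distribution over classes

print(clf_onehot.predict(X[:2]).shape)  # (2, 3): indicator rows, like pred in this notebook
print(clf_int.predict(X[:2]).shape)     # (2,): one integer label per sample
```

If you switch to integer labels, the evaluation cells change too: `clf.predict` then returns digits directly, so `np.argmax(predict)` would have to become `predict[0]`, and `accuracy_score` would compare integer labels.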

2.5. Test or Evaluate

In [19]:
pred = clf.predict(test_x)
print("Accuracy : {}%".format(accuracy_score(test_y, pred)*100))
Accuracy : 93.77%
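Because both `test_y` and `pred` are one-hot/indicator arrays here, `accuracy_score` computes subset accuracy: a test image counts as correct only if all 10 entries of its predicted row match the true row. The hypothetical arrays below show the rule, including a row where the model switches on no output at all, which is possible with independent logistic outputs.

```python
import numpy as np
from sklearn.metrics import accuracy_score

# Hypothetical one-hot targets and indicator predictions for 4 test images.
y_true = np.eye(10)[[7, 2, 1, 0]]
y_pred = np.eye(10)[[7, 2, 1, 6]]   # the 4th prediction is the wrong digit
y_pred[1] = 0                        # the 2nd row switches on no output at all

# Subset accuracy: a sample is correct only if every one of its 10 entries matches.
subset_acc = np.mean(np.all(y_true == y_pred, axis=1))
print(subset_acc)                        # 0.5
print(accuracy_score(y_true, y_pred))    # 0.5, the same rule
```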
In [20]:
probs = clf.predict_proba(test_x[:1])   # per-class probabilities (not logits)
predict = clf.predict(test_x[:1])       # one-hot indicator row, since the model was trained on one-hot labels

plt.figure(figsize = (6,6))
plt.imshow(test_x[:1].reshape(28,28), 'gray')
plt.xticks([])
plt.yticks([])
plt.show()

# argmax of the indicator row gives the predicted digit
print('Prediction : {}'.format(np.argmax(predict)))
np.set_printoptions(precision = 2, suppress = True)
print('Probability : {}'.format(probs.ravel()))
Prediction : 7
Probability : [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
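Since the model was trained on one-hot targets, each column of `predict_proba` comes from its own sigmoid, so a row need not sum to 1, and `predict` thresholds each column (at 0.5) independently. An image can therefore receive no label, in which case `np.argmax` of an all-zero indicator row silently returns 0. Taking the argmax of the probabilities instead always yields exactly one digit. The probability rows below are hypothetical.

```python
import numpy as np

# Hypothetical predict_proba output for 2 images from a model trained on one-hot labels:
# each column is an independent sigmoid, so rows need not sum to 1.
probs = np.array([
    [0.01, 0.00, 0.02, 0.01, 0.00, 0.00, 0.00, 0.97, 0.01, 0.03],
    [0.02, 0.01, 0.00, 0.31, 0.00, 0.05, 0.00, 0.01, 0.44, 0.02],
])

indicator = (probs > 0.5).astype(int)     # per-column threshold, like predict() in the multilabel case
print(indicator.sum(axis=1).tolist())     # [1, 0]: the second image gets no label at all

# argmax of the probabilities always picks exactly one digit per image.
digits = np.argmax(probs, axis=1)
print(digits.tolist())                    # [7, 8]
```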