ANN with MNIST


By Prof. Seungchul Lee
http://iai.postech.ac.kr/
Industrial AI Lab at POSTECH

Table of Contents

1. What is MNIST?

From Wikipedia

  • The MNIST database (Mixed National Institute of Standards and Technology database) is a large database of handwritten digits that is commonly used for training various image processing systems. The database is also widely used for training and testing in the field of machine learning. It was created by "re-mixing" the samples from NIST's original datasets. The creators felt that since NIST's training dataset was taken from American Census Bureau employees, while the testing dataset was taken from American high school students, NIST's complete dataset was too hard.
  • MNIST (Mixed National Institute of Standards and Technology) database
    • Handwritten digit database
    • $28 \times 28$ grayscale images
    • Each image matrix is flattened into a vector of $28 \times 28 = 784$ values




We will use MNIST to build a multinomial classifier that detects whether a given MNIST image belongs to class 0, 1, 2, 3, 4, 5, 6, 7, 8, or 9. Succinctly, we are teaching a computer to recognize handwritten digits.

In [ ]:
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0"
In [1]:
# Import Library
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
%matplotlib inline

Let's download and load the dataset.

In [2]:
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
(TensorFlow prints several deprecation warnings here for the tf.contrib input pipeline; they can safely be ignored.)
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
In [3]:
print ("The training data set is:\n")
print (mnist.train.images.shape)
print (mnist.train.labels.shape)
The training data set is:

(55000, 784)
(55000, 10)
In [4]:
print ("The test data set is:")
print (mnist.test.images.shape)
print (mnist.test.labels.shape)
The test data set is:
(10000, 784)
(10000, 10)

Display a sample from it:

In [5]:
mnist.train.images[5]
Out[5]:
array([0.        , 0.        , 0.        , ..., 0.        , 0.        ,
       0.        ], dtype=float32)
In [6]:
# well, that's not a picture (or image), it's an array.

mnist.train.images[5].shape
Out[6]:
(784,)

You might expect the training set to be made up of $28 \times 28$ grayscale images of handwritten digits. Not quite!

Each $28 \times 28$ image has been flattened into a 1D array of 784 values. Let's reshape one back.

In [7]:
# reshape the flattened vector back into a 28x28 image
img = np.reshape(mnist.train.images[5], [28,28])
In [8]:
# equivalently, using the array's own reshape method
img = mnist.train.images[5].reshape([28,28])
In [9]:
# So now we have a 28x28 matrix, where each element is an intensity level from 0 to 1.  
img.shape
Out[9]:
(28, 28)

Let's visualize this image and its corresponding training label.

In [10]:
plt.figure(figsize = (6,6))
plt.imshow(img, 'gray')
plt.xticks([])
plt.yticks([])
plt.show()
In [11]:
mnist.train.labels[5]
Out[11]:
array([0., 0., 0., 0., 0., 0., 0., 0., 1., 0.])
In [12]:
np.argmax(mnist.train.labels[5])
Out[12]:
8

A mini-batch generator is built into the dataset object:

In [13]:
x, y = mnist.train.next_batch(3)

print(x.shape)
print(y.shape)
(3, 784)
(3, 10)
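
For intuition, here is a minimal NumPy sketch of what next_batch does under the hood (sample_batch is an illustrative helper, not part of the TensorFlow API):

import numpy as np

def sample_batch(images, labels, batch_size):
    # draw a random mini-batch by sampling row indices without replacement
    idx = np.random.choice(images.shape[0], batch_size, replace=False)
    return images[idx], labels[idx]

x, y = sample_batch(mnist.train.images, mnist.train.labels, 3)   # same shapes as above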

2. ANN with TensorFlow

  • Feed a grayscale image to the ANN


  • Our network model



  • Network training (learning), where $\ell$ denotes the loss function:
$$\omega := \omega - \alpha \,\nabla_{\omega}\, \ell \left( h_{\omega} \left(x^{(i)}\right),\, y^{(i)}\right)$$
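
To make the update rule concrete, here is a one-parameter toy example (the quadratic loss is an illustrative assumption, not the network's actual loss):

w = 0.0
alpha = 0.1                  # learning rate
for _ in range(100):
    grad = 2 * (w - 3)       # gradient of the toy loss (w - 3)^2
    w = w - alpha * grad     # the update rule above
print(w)                     # converges to 3, the minimizer of the toy loss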

2.1. Import Library

In [14]:
# Import Library
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

2.2. Load MNIST Data

  • Download MNIST data via the TensorFlow tutorial module
In [15]:
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
In [16]:
train_x, train_y = mnist.train.next_batch(1)
img = train_x[0,:].reshape(28,28)

plt.figure(figsize=(6,6))
plt.imshow(img,'gray')
plt.title("Label : {}".format(np.argmax(train_y[0,:])))
plt.xticks([])
plt.yticks([])
plt.show()

One-hot encoding

In [17]:
print ('Train labels : {}'.format(train_y[0, :]))
Train labels : [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
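
If you ever need to build one-hot labels yourself, a common NumPy idiom is the following sketch (the input_data loader already does this for us via one_hot=True):

import numpy as np

labels = np.array([1, 8, 3])      # integer class labels
one_hot = np.eye(10)[labels]      # each row becomes a one-hot vector
print(one_hot[0])                 # [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]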

2.3. Define an ANN Structure

  • Input size
  • Hidden layer size
  • The number of classes


In [18]:
n_input = 28*28
n_hidden = 100
n_output = 10
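
As a quick sanity check on model size, the number of trainable parameters implied by these choices can be computed directly (plain arithmetic, not a TensorFlow utility):

n_params = (n_input*n_hidden + n_hidden) + (n_hidden*n_output + n_output)
print(n_params)    # 78,500 + 1,010 = 79,510 trainable parameters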

2.4. Define Weights, Biases, and Placeholder

  • Define parameters based on predefined layer size
  • Initialize with normal distribution with $\mu = 0$ and $\sigma = 0.1$
In [19]:
weights = {
    'hidden' : tf.Variable(tf.random_normal([n_input, n_hidden], stddev = 0.1)),
    'output' : tf.Variable(tf.random_normal([n_hidden, n_output], stddev = 0.1))
}

biases = {
    'hidden' : tf.Variable(tf.random_normal([n_hidden], stddev = 0.1)),
    'output' : tf.Variable(tf.random_normal([n_output], stddev = 0.1))
}
In [20]:
x = tf.placeholder(tf.float32, [None, n_input])
y = tf.placeholder(tf.float32, [None, n_output])

2.5. Build a Model

First, the layer performs a matrix multiplication to produce a set of linear activations



$$y_j = \left(\sum\limits_i \omega_{ij}x_i\right) + b_j$$

$$y = \omega^T x + b$$
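
As a quick shape check of this affine transformation with our layer sizes, here is a NumPy sketch (W, x_vec, b are random placeholders, not the network's trained parameters):

import numpy as np

W = np.random.randn(784, 100)     # n_input x n_hidden weight matrix
x_vec = np.random.randn(784)      # one flattened image
b = np.random.randn(100)          # hidden-layer biases
y_vec = W.T @ x_vec + b           # the affine transformation above
print(y_vec.shape)                # (100,)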


Second, each linear activation is run through a nonlinear activation function




Third, the output layer produces predictions with another affine transformation



In [21]:
# Define Network
def build_model(x, weights, biases):
    
    # first hidden layer
    hidden = tf.add(tf.matmul(x, weights['hidden']), biases['hidden'])
    # nonlinear activation function
    hidden = tf.nn.relu(hidden)
    
    # Output layer 
    output = tf.add(tf.matmul(hidden, weights['output']), biases['output'])
    
    return output

2.6. Define Loss and Optimizer

Loss

  • This defines how we measure the model's accuracy during training. As covered in lecture, we want to minimize this function, which "steers" the model in the right direction.
  • Classification: cross entropy
    • Equivalent to applying logistic regression (the binary form is shown here; the code below uses the multiclass softmax version)
$$ -\frac{1}{m}\sum_{i=1}^{m}\left[\, y^{(i)}\log\left(h_{\theta}\left(x^{(i)}\right)\right) + \left(1-y^{(i)}\right)\log\left(1-h_{\theta}\left(x^{(i)}\right)\right)\right] $$
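
For intuition, here is a NumPy sketch of the multiclass (softmax) cross entropy that the code below actually minimizes; softmax_cross_entropy, logits, and y_onehot are illustrative names, not TensorFlow API:

import numpy as np

def softmax_cross_entropy(logits, y_onehot):
    # numerically stabilized log-softmax
    z = logits - logits.max(axis=1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    # average negative log-probability of the true class
    return -np.mean(np.sum(y_onehot * log_probs, axis=1))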

Optimizer

  • This defines how the model is updated based on the data it sees and its loss function.
  • AdamOptimizer: one of the most widely used optimizers
In [22]:
# Define Loss
pred = build_model(x, weights, biases)
loss = tf.nn.softmax_cross_entropy_with_logits(logits = pred, labels = y)
loss = tf.reduce_mean(loss)

LR = 0.0001
optm = tf.train.AdamOptimizer(LR).minimize(loss)
(TensorFlow warns that softmax_cross_entropy_with_logits is deprecated in favor of tf.nn.softmax_cross_entropy_with_logits_v2.)

2.7. Define Optimization Configuration and Then Optimize




  • Define parameters for training ANN
    • n_batch: batch size for mini-batch gradient descent
    • n_iter: the number of training iterations
    • n_prt: record and print the loss every n_prt iterations
  • Metrics
    • Here we can define metrics used to monitor the training and testing steps. In this example, we'll look at the accuracy, the fraction of the images that are correctly classified.
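
If you prefer to monitor accuracy inside the graph, a TensorFlow accuracy op could look like the following sketch (this notebook instead computes accuracy with NumPy in Section 2.8):

correct = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))         # per-image correctness
accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))         # fraction correct
# evaluated later, e.g. sess.run(accuracy, feed_dict = {x: test_x, y: test_y})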

Initializer

  • Initialize all the variables
In [23]:
n_batch = 50     # Batch Size
n_iter = 5000    # Learning Iteration
n_prt = 250      # Print Cycle
In [24]:
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)

loss_record_train = []
loss_record_test = []
for i in range(n_iter):
    train_x, train_y = mnist.train.next_batch(n_batch)
    sess.run(optm, feed_dict = {x: train_x, y: train_y})

    if i % n_prt == 0:
        test_x, test_y = mnist.test.next_batch(n_batch)
        c1 = sess.run(loss, feed_dict = {x: train_x, y: train_y})
        c2 = sess.run(loss, feed_dict = {x: test_x, y: test_y})
        loss_record_train.append(c1)
        loss_record_test.append(c2)
        print ("Iter : {}".format(i))
        print ("Cost : {}".format(c1))
        
plt.figure(figsize=(10,8))
plt.plot(np.arange(len(loss_record_train))*n_prt, 
         loss_record_train, label = 'training')
plt.plot(np.arange(len(loss_record_test))*n_prt, 
         loss_record_test, label = 'testing')
plt.xlabel('iteration', fontsize = 15)
plt.ylabel('loss', fontsize = 15)
plt.legend(fontsize = 12)
plt.ylim([0, np.max(loss_record_train)])
plt.show()
Iter : 0
Cost : 2.3691201210021973
Iter : 250
Cost : 1.1303077936172485
Iter : 500
Cost : 0.6283268332481384
Iter : 750
Cost : 0.557327389717102
Iter : 1000
Cost : 0.32952675223350525
Iter : 1250
Cost : 0.3076620399951935
Iter : 1500
Cost : 0.26567256450653076
Iter : 1750
Cost : 0.2599416971206665
Iter : 2000
Cost : 0.30930188298225403
Iter : 2250
Cost : 0.3725440502166748
Iter : 2500
Cost : 0.3152557611465454
Iter : 2750
Cost : 0.42607754468917847
Iter : 3000
Cost : 0.2844126522541046
Iter : 3250
Cost : 0.2923315465450287
Iter : 3500
Cost : 0.21499931812286377
Iter : 3750
Cost : 0.2560688555240631
Iter : 4000
Cost : 0.24431072175502777
Iter : 4250
Cost : 0.30018243193626404
Iter : 4500
Cost : 0.2014298290014267
Iter : 4750
Cost : 0.18581214547157288

2.8. Test or Evaluate

In [25]:
test_x, test_y = mnist.test.next_batch(100)

my_pred = sess.run(pred, feed_dict = {x : test_x})
my_pred = np.argmax(my_pred, axis = 1)

labels = np.argmax(test_y, axis = 1)

accr = np.mean(np.equal(my_pred, labels))
print("Accuracy : {}%".format(accr*100))
Accuracy : 93.0%
In [26]:
test_x, test_y = mnist.test.next_batch(1)
probs = sess.run(tf.nn.softmax(pred), feed_dict = {x : test_x})   # class probabilities
predict = np.argmax(probs)

plt.figure(figsize = (6,6))
plt.imshow(test_x.reshape(28,28), 'gray')
plt.xticks([])
plt.yticks([])
plt.show()

print('Prediction : {}'.format(predict))
np.set_printoptions(precision = 2, suppress = True)
print('Probability : {}'.format(probs.ravel()))
Prediction : 1
Probability : [0.   0.98 0.01 0.01 0.   0.   0.   0.   0.   0.  ]

You may observe that the accuracy on the test dataset is a little lower than the accuracy on the training dataset. This gap between training accuracy and test accuracy is an example of overfitting: a machine learning model performing worse on new data than on its training data.

What is the highest accuracy you can achieve with this first fully connected model? Since the handwritten digit classification task is pretty straightforward, you may be wondering how we can do better...

$\Rightarrow$ As we saw in lecture, convolutional neural networks (CNNs) are particularly well-suited for a variety of tasks in computer vision, and have achieved near-perfect accuracies on the MNIST dataset. We will build a CNN and ultimately output a probability distribution over the 10 digit classes (0-9) in the next lectures.

In [27]:
%%javascript
$.getScript('https://kmahelona.github.io/ipython_notebook_goodies/ipython_notebook_toc.js')