Convolutional Neural Networks (CNNs) with PyTorch

Authors: Jeffrey Huang and Alex Michels

In this notebook, we will use PyTorch CNNs to recognize text from images. We use CNNs in this use case because the individual values of pixels don't tell us very much, but convolutions can help us extract features.

In [1]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from pathlib import Path
import time
import torch
from torchvision import datasets
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
from torch import optim
from torch.autograd import Variable
import torch.nn as nn

Data Wrangling

First, we need to download the built in PyTorch MNIST dataset:

In [2]:
train_data = datasets.MNIST(
    root = 'data',
    train = True,
    transform = ToTensor(),
    download = True,
)
test_data = datasets.MNIST(
    root = 'data', 
    train = False,
    transform = ToTensor()
)

Next, we will examine the data. The data is of arrays based on the color of the pixel. We can also plot our data to see the images they make:

In [3]:
test_data[0]
Out[3]:
(tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3294, 0.7255,
           0.6235, 0.5922, 0.2353, 0.1412, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.8706, 0.9961,
           0.9961, 0.9961, 0.9961, 0.9451, 0.7765, 0.7765, 0.7765, 0.7765,
           0.7765, 0.7765, 0.7765, 0.7765, 0.6667, 0.2039, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2627, 0.4471,
           0.2824, 0.4471, 0.6392, 0.8902, 0.9961, 0.8824, 0.9961, 0.9961,
           0.9961, 0.9804, 0.8980, 0.9961, 0.9961, 0.5490, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0667, 0.2588, 0.0549, 0.2627, 0.2627,
           0.2627, 0.2314, 0.0824, 0.9255, 0.9961, 0.4157, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.3255, 0.9922, 0.8196, 0.0706, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0863, 0.9137, 1.0000, 0.3255, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.5059, 0.9961, 0.9333, 0.1725, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.2314, 0.9765, 0.9961, 0.2431, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.5216, 0.9961, 0.7333, 0.0196, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0353,
           0.8039, 0.9725, 0.2275, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.4941,
           0.9961, 0.7137, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2941, 0.9843,
           0.9412, 0.2235, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0745, 0.8667, 0.9961,
           0.6510, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0118, 0.7961, 0.9961, 0.8588,
           0.1373, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.1490, 0.9961, 0.9961, 0.3020,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.1216, 0.8784, 0.9961, 0.4510, 0.0039,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.5216, 0.9961, 0.9961, 0.2039, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.2392, 0.9490, 0.9961, 0.9961, 0.2039, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.4745, 0.9961, 0.9961, 0.8588, 0.1569, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.4745, 0.9961, 0.8118, 0.0706, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000]]]),
 7)
In [4]:
figure = plt.figure(figsize=(10, 8))
cols, rows = 5, 5
for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(train_data), size=(1,)).item()
    img, label = train_data[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(label)
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()

DataLoaders are necessary for letting the model process large datasets.

In [5]:
loaders = {
    'train' : torch.utils.data.DataLoader(train_data,
                                          batch_size=100,
                                          shuffle=True,
                                          num_workers=1),
    'test' : torch.utils.data.DataLoader(test_data,
                                         batch_size=100,
                                         shuffle=True,
                                         num_workers=1)
}

Creating the Model

Next, we define the model as well as its forward pass function. This is a fairly simple CNN with 2 convolutional layers.

In [6]:
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels=1,  # (int) -> number of channels in input image    =1 because output is grayscale image
                out_channels=16, # (int) -> number of channels produced by the convolution
                kernel_size=5, # (int, tuple) -> size of convolving kernel
                stride=1, # (int, tuple, optional) -> stride of convolution, default is 1
                padding=2, # (int, tuple, optional) -> zero-padding added to both sides of input, default is 0
            ),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, 5, 1, 2),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.out = nn.Linear(32 * 7 * 7, 10)
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)
        output = self.out(x)
        return output, x       
In [7]:
cnn = CNN()

Next we define the loss function and optimizer. Cross Entropy Loss is a loss function for classification problems, while Adam is a popular and powerful optimizer that extends the already effective stochastic gradient descent algorithm.

In [8]:
loss_func = nn.CrossEntropyLoss()
optimizer = optim.Adam(cnn.parameters(), lr=0.01)  

It is worthwhile to check and make sure what device we are running on.

In [9]:
# train on the GPU or on the CPU, if a GPU is not available
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

print('Using device:', device)
print()

#Additional Info when using cuda
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(0)/1024**3,1), 'GB')
    print('Cached:   ', round(torch.cuda.memory_reserved(0)/1024**3,1), 'GB')
Using device: cpu

Training the Model

Next is the training step.

As this is a convolutional neural network being trained on a large image dataset, expect training time to take much longer than for a simple linear regression model. It is also worthwhile to time the training loop.

In [10]:
start_time = time.time()

num_epochs = 5
def train(num_epochs, cnn, loaders):
    cnn.train()
    total_step = len(loaders['train'])
    for epoch in range(num_epochs):
        for i, (images, labels) in enumerate(loaders['train']):
            # gives batch data, normalizes x when iterating train_loader
            b_x = Variable(images)
            b_y = Variable(labels)
            
            output = cnn(b_x)[0]
            loss = loss_func(output, b_y)

            # clear gradients for this trainign step
            optimizer.zero_grad()
            # backpropogation, computing gradients
            loss.backward()
            # apply gradients
            optimizer.step()
            if (i+1) & 100 == 0:
                print('Epoch [{}/{}], Step[{}/{}], Loss: {:.4f}'
                      .format(epoch + 1, num_epochs, i+1, total_step, loss.item()))
                pass
            pass
        pass
train(num_epochs, cnn, loaders)

curr_time = time.time()
print(f"Training took {curr_time - start_time} seconds")
Epoch [1/5], Step[1/600], Loss: 2.3010
Epoch [1/5], Step[2/600], Loss: 2.7873
Epoch [1/5], Step[3/600], Loss: 2.2077
Epoch [1/5], Step[8/600], Loss: 1.4460
Epoch [1/5], Step[9/600], Loss: 1.1443
Epoch [1/5], Step[10/600], Loss: 0.8667
Epoch [1/5], Step[11/600], Loss: 0.9214
Epoch [1/5], Step[16/600], Loss: 1.0815
Epoch [1/5], Step[17/600], Loss: 0.5077
Epoch [1/5], Step[18/600], Loss: 1.0005
Epoch [1/5], Step[19/600], Loss: 0.6165
Epoch [1/5], Step[24/600], Loss: 0.4259
Epoch [1/5], Step[25/600], Loss: 0.3769
Epoch [1/5], Step[26/600], Loss: 0.3689
Epoch [1/5], Step[27/600], Loss: 0.3474
Epoch [1/5], Step[128/600], Loss: 0.0309
Epoch [1/5], Step[129/600], Loss: 0.0608
Epoch [1/5], Step[130/600], Loss: 0.0554
Epoch [1/5], Step[131/600], Loss: 0.2005
Epoch [1/5], Step[136/600], Loss: 0.1554
Epoch [1/5], Step[137/600], Loss: 0.2146
Epoch [1/5], Step[138/600], Loss: 0.0350
Epoch [1/5], Step[139/600], Loss: 0.0702
Epoch [1/5], Step[144/600], Loss: 0.1806
Epoch [1/5], Step[145/600], Loss: 0.0505
Epoch [1/5], Step[146/600], Loss: 0.0413
Epoch [1/5], Step[147/600], Loss: 0.0866
Epoch [1/5], Step[152/600], Loss: 0.1066
Epoch [1/5], Step[153/600], Loss: 0.0252
Epoch [1/5], Step[154/600], Loss: 0.1037
Epoch [1/5], Step[155/600], Loss: 0.0856
Epoch [1/5], Step[256/600], Loss: 0.0678
Epoch [1/5], Step[257/600], Loss: 0.0756
Epoch [1/5], Step[258/600], Loss: 0.1582
Epoch [1/5], Step[259/600], Loss: 0.1289
Epoch [1/5], Step[264/600], Loss: 0.0499
Epoch [1/5], Step[265/600], Loss: 0.0876
Epoch [1/5], Step[266/600], Loss: 0.0833
Epoch [1/5], Step[267/600], Loss: 0.0905
Epoch [1/5], Step[272/600], Loss: 0.1186
Epoch [1/5], Step[273/600], Loss: 0.0602
Epoch [1/5], Step[274/600], Loss: 0.0861
Epoch [1/5], Step[275/600], Loss: 0.1177
Epoch [1/5], Step[280/600], Loss: 0.0497
Epoch [1/5], Step[281/600], Loss: 0.0421
Epoch [1/5], Step[282/600], Loss: 0.1780
Epoch [1/5], Step[283/600], Loss: 0.0327
Epoch [1/5], Step[384/600], Loss: 0.0285
Epoch [1/5], Step[385/600], Loss: 0.0828
Epoch [1/5], Step[386/600], Loss: 0.0458
Epoch [1/5], Step[387/600], Loss: 0.0500
Epoch [1/5], Step[392/600], Loss: 0.2066
Epoch [1/5], Step[393/600], Loss: 0.1402
Epoch [1/5], Step[394/600], Loss: 0.0493
Epoch [1/5], Step[395/600], Loss: 0.0344
Epoch [1/5], Step[400/600], Loss: 0.0806
Epoch [1/5], Step[401/600], Loss: 0.0828
Epoch [1/5], Step[402/600], Loss: 0.1046
Epoch [1/5], Step[403/600], Loss: 0.0535
Epoch [1/5], Step[408/600], Loss: 0.0878
Epoch [1/5], Step[409/600], Loss: 0.0360
Epoch [1/5], Step[410/600], Loss: 0.0552
Epoch [1/5], Step[411/600], Loss: 0.0530
Epoch [1/5], Step[512/600], Loss: 0.0371
Epoch [1/5], Step[513/600], Loss: 0.0273
Epoch [1/5], Step[514/600], Loss: 0.0427
Epoch [1/5], Step[515/600], Loss: 0.0510
Epoch [1/5], Step[520/600], Loss: 0.0423
Epoch [1/5], Step[521/600], Loss: 0.0098
Epoch [1/5], Step[522/600], Loss: 0.0454
Epoch [1/5], Step[523/600], Loss: 0.0378
Epoch [1/5], Step[528/600], Loss: 0.0534
Epoch [1/5], Step[529/600], Loss: 0.1308
Epoch [1/5], Step[530/600], Loss: 0.0389
Epoch [1/5], Step[531/600], Loss: 0.0375
Epoch [1/5], Step[536/600], Loss: 0.0539
Epoch [1/5], Step[537/600], Loss: 0.0870
Epoch [1/5], Step[538/600], Loss: 0.0444
Epoch [1/5], Step[539/600], Loss: 0.0144
Epoch [2/5], Step[1/600], Loss: 0.1345
Epoch [2/5], Step[2/600], Loss: 0.1259
Epoch [2/5], Step[3/600], Loss: 0.0329
Epoch [2/5], Step[8/600], Loss: 0.0653
Epoch [2/5], Step[9/600], Loss: 0.0308
Epoch [2/5], Step[10/600], Loss: 0.0378
Epoch [2/5], Step[11/600], Loss: 0.0493
Epoch [2/5], Step[16/600], Loss: 0.1019
Epoch [2/5], Step[17/600], Loss: 0.0834
Epoch [2/5], Step[18/600], Loss: 0.0103
Epoch [2/5], Step[19/600], Loss: 0.0963
Epoch [2/5], Step[24/600], Loss: 0.0314
Epoch [2/5], Step[25/600], Loss: 0.0304
Epoch [2/5], Step[26/600], Loss: 0.0737
Epoch [2/5], Step[27/600], Loss: 0.0337
Epoch [2/5], Step[128/600], Loss: 0.0157
Epoch [2/5], Step[129/600], Loss: 0.0345
Epoch [2/5], Step[130/600], Loss: 0.0390
Epoch [2/5], Step[131/600], Loss: 0.0328
Epoch [2/5], Step[136/600], Loss: 0.0401
Epoch [2/5], Step[137/600], Loss: 0.0749
Epoch [2/5], Step[138/600], Loss: 0.0254
Epoch [2/5], Step[139/600], Loss: 0.0194
Epoch [2/5], Step[144/600], Loss: 0.0575
Epoch [2/5], Step[145/600], Loss: 0.1244
Epoch [2/5], Step[146/600], Loss: 0.0719
Epoch [2/5], Step[147/600], Loss: 0.2173
Epoch [2/5], Step[152/600], Loss: 0.0419
Epoch [2/5], Step[153/600], Loss: 0.0076
Epoch [2/5], Step[154/600], Loss: 0.0256
Epoch [2/5], Step[155/600], Loss: 0.0540
Epoch [2/5], Step[256/600], Loss: 0.0578
Epoch [2/5], Step[257/600], Loss: 0.0705
Epoch [2/5], Step[258/600], Loss: 0.0657
Epoch [2/5], Step[259/600], Loss: 0.0462
Epoch [2/5], Step[264/600], Loss: 0.0519
Epoch [2/5], Step[265/600], Loss: 0.0296
Epoch [2/5], Step[266/600], Loss: 0.0405
Epoch [2/5], Step[267/600], Loss: 0.0738
Epoch [2/5], Step[272/600], Loss: 0.0106
Epoch [2/5], Step[273/600], Loss: 0.0147
Epoch [2/5], Step[274/600], Loss: 0.0693
Epoch [2/5], Step[275/600], Loss: 0.1410
Epoch [2/5], Step[280/600], Loss: 0.0135
Epoch [2/5], Step[281/600], Loss: 0.0027
Epoch [2/5], Step[282/600], Loss: 0.0105
Epoch [2/5], Step[283/600], Loss: 0.0218
Epoch [2/5], Step[384/600], Loss: 0.1432
Epoch [2/5], Step[385/600], Loss: 0.0150
Epoch [2/5], Step[386/600], Loss: 0.1172
Epoch [2/5], Step[387/600], Loss: 0.0214
Epoch [2/5], Step[392/600], Loss: 0.0736
Epoch [2/5], Step[393/600], Loss: 0.0790
Epoch [2/5], Step[394/600], Loss: 0.0813
Epoch [2/5], Step[395/600], Loss: 0.0265
Epoch [2/5], Step[400/600], Loss: 0.0591
Epoch [2/5], Step[401/600], Loss: 0.0764
Epoch [2/5], Step[402/600], Loss: 0.1025
Epoch [2/5], Step[403/600], Loss: 0.0495
Epoch [2/5], Step[408/600], Loss: 0.0133
Epoch [2/5], Step[409/600], Loss: 0.0040
Epoch [2/5], Step[410/600], Loss: 0.0600
Epoch [2/5], Step[411/600], Loss: 0.0564
Epoch [2/5], Step[512/600], Loss: 0.0248
Epoch [2/5], Step[513/600], Loss: 0.1097
Epoch [2/5], Step[514/600], Loss: 0.0120
Epoch [2/5], Step[515/600], Loss: 0.0348
Epoch [2/5], Step[520/600], Loss: 0.0252
Epoch [2/5], Step[521/600], Loss: 0.0141
Epoch [2/5], Step[522/600], Loss: 0.0725
Epoch [2/5], Step[523/600], Loss: 0.0303
Epoch [2/5], Step[528/600], Loss: 0.0250
Epoch [2/5], Step[529/600], Loss: 0.0133
Epoch [2/5], Step[530/600], Loss: 0.0673
Epoch [2/5], Step[531/600], Loss: 0.1008
Epoch [2/5], Step[536/600], Loss: 0.0387
Epoch [2/5], Step[537/600], Loss: 0.0968
Epoch [2/5], Step[538/600], Loss: 0.0580
Epoch [2/5], Step[539/600], Loss: 0.0297
Epoch [3/5], Step[1/600], Loss: 0.0949
Epoch [3/5], Step[2/600], Loss: 0.0547
Epoch [3/5], Step[3/600], Loss: 0.0498
Epoch [3/5], Step[8/600], Loss: 0.0582
Epoch [3/5], Step[9/600], Loss: 0.0140
Epoch [3/5], Step[10/600], Loss: 0.0189
Epoch [3/5], Step[11/600], Loss: 0.0566
Epoch [3/5], Step[16/600], Loss: 0.0660
Epoch [3/5], Step[17/600], Loss: 0.0192
Epoch [3/5], Step[18/600], Loss: 0.0157
Epoch [3/5], Step[19/600], Loss: 0.0537
Epoch [3/5], Step[24/600], Loss: 0.0737
Epoch [3/5], Step[25/600], Loss: 0.0106
Epoch [3/5], Step[26/600], Loss: 0.0207
Epoch [3/5], Step[27/600], Loss: 0.0169
Epoch [3/5], Step[128/600], Loss: 0.0203
Epoch [3/5], Step[129/600], Loss: 0.0098
Epoch [3/5], Step[130/600], Loss: 0.0613
Epoch [3/5], Step[131/600], Loss: 0.0507
Epoch [3/5], Step[136/600], Loss: 0.1746
Epoch [3/5], Step[137/600], Loss: 0.0173
Epoch [3/5], Step[138/600], Loss: 0.0164
Epoch [3/5], Step[139/600], Loss: 0.0727
Epoch [3/5], Step[144/600], Loss: 0.0210
Epoch [3/5], Step[145/600], Loss: 0.0289
Epoch [3/5], Step[146/600], Loss: 0.0423
Epoch [3/5], Step[147/600], Loss: 0.1073
Epoch [3/5], Step[152/600], Loss: 0.0219
Epoch [3/5], Step[153/600], Loss: 0.0165
Epoch [3/5], Step[154/600], Loss: 0.0172
Epoch [3/5], Step[155/600], Loss: 0.0520
Epoch [3/5], Step[256/600], Loss: 0.0422
Epoch [3/5], Step[257/600], Loss: 0.0447
Epoch [3/5], Step[258/600], Loss: 0.0892
Epoch [3/5], Step[259/600], Loss: 0.0150
Epoch [3/5], Step[264/600], Loss: 0.0835
Epoch [3/5], Step[265/600], Loss: 0.0137
Epoch [3/5], Step[266/600], Loss: 0.1238
Epoch [3/5], Step[267/600], Loss: 0.0161
Epoch [3/5], Step[272/600], Loss: 0.0355
Epoch [3/5], Step[273/600], Loss: 0.1550
Epoch [3/5], Step[274/600], Loss: 0.0348
Epoch [3/5], Step[275/600], Loss: 0.0518
Epoch [3/5], Step[280/600], Loss: 0.0117
Epoch [3/5], Step[281/600], Loss: 0.0162
Epoch [3/5], Step[282/600], Loss: 0.0951
Epoch [3/5], Step[283/600], Loss: 0.0658
Epoch [3/5], Step[384/600], Loss: 0.1450
Epoch [3/5], Step[385/600], Loss: 0.1073
Epoch [3/5], Step[386/600], Loss: 0.0034
Epoch [3/5], Step[387/600], Loss: 0.1261
Epoch [3/5], Step[392/600], Loss: 0.0082
Epoch [3/5], Step[393/600], Loss: 0.0144
Epoch [3/5], Step[394/600], Loss: 0.0348
Epoch [3/5], Step[395/600], Loss: 0.0034
Epoch [3/5], Step[400/600], Loss: 0.0244
Epoch [3/5], Step[401/600], Loss: 0.0274
Epoch [3/5], Step[402/600], Loss: 0.0612
Epoch [3/5], Step[403/600], Loss: 0.0084
Epoch [3/5], Step[408/600], Loss: 0.0263
Epoch [3/5], Step[409/600], Loss: 0.1690
Epoch [3/5], Step[410/600], Loss: 0.0497
Epoch [3/5], Step[411/600], Loss: 0.0342
Epoch [3/5], Step[512/600], Loss: 0.0184
Epoch [3/5], Step[513/600], Loss: 0.0869
Epoch [3/5], Step[514/600], Loss: 0.0501
Epoch [3/5], Step[515/600], Loss: 0.0334
Epoch [3/5], Step[520/600], Loss: 0.0065
Epoch [3/5], Step[521/600], Loss: 0.0680
Epoch [3/5], Step[522/600], Loss: 0.0803
Epoch [3/5], Step[523/600], Loss: 0.1726
Epoch [3/5], Step[528/600], Loss: 0.0060
Epoch [3/5], Step[529/600], Loss: 0.0189
Epoch [3/5], Step[530/600], Loss: 0.1532
Epoch [3/5], Step[531/600], Loss: 0.0268
Epoch [3/5], Step[536/600], Loss: 0.0517
Epoch [3/5], Step[537/600], Loss: 0.0222
Epoch [3/5], Step[538/600], Loss: 0.0674
Epoch [3/5], Step[539/600], Loss: 0.0449
Epoch [4/5], Step[1/600], Loss: 0.0334
Epoch [4/5], Step[2/600], Loss: 0.0218
Epoch [4/5], Step[3/600], Loss: 0.0115
Epoch [4/5], Step[8/600], Loss: 0.0506
Epoch [4/5], Step[9/600], Loss: 0.0566
Epoch [4/5], Step[10/600], Loss: 0.0143
Epoch [4/5], Step[11/600], Loss: 0.0344
Epoch [4/5], Step[16/600], Loss: 0.0552
Epoch [4/5], Step[17/600], Loss: 0.0691
Epoch [4/5], Step[18/600], Loss: 0.0226
Epoch [4/5], Step[19/600], Loss: 0.0551
Epoch [4/5], Step[24/600], Loss: 0.0192
Epoch [4/5], Step[25/600], Loss: 0.0141
Epoch [4/5], Step[26/600], Loss: 0.0177
Epoch [4/5], Step[27/600], Loss: 0.0807
Epoch [4/5], Step[128/600], Loss: 0.0812
Epoch [4/5], Step[129/600], Loss: 0.0597
Epoch [4/5], Step[130/600], Loss: 0.0027
Epoch [4/5], Step[131/600], Loss: 0.0721
Epoch [4/5], Step[136/600], Loss: 0.0038
Epoch [4/5], Step[137/600], Loss: 0.0641
Epoch [4/5], Step[138/600], Loss: 0.0209
Epoch [4/5], Step[139/600], Loss: 0.0039
Epoch [4/5], Step[144/600], Loss: 0.0399
Epoch [4/5], Step[145/600], Loss: 0.0050
Epoch [4/5], Step[146/600], Loss: 0.0550
Epoch [4/5], Step[147/600], Loss: 0.0069
Epoch [4/5], Step[152/600], Loss: 0.0056
Epoch [4/5], Step[153/600], Loss: 0.0557
Epoch [4/5], Step[154/600], Loss: 0.1290
Epoch [4/5], Step[155/600], Loss: 0.0248
Epoch [4/5], Step[256/600], Loss: 0.0065
Epoch [4/5], Step[257/600], Loss: 0.0109
Epoch [4/5], Step[258/600], Loss: 0.0144
Epoch [4/5], Step[259/600], Loss: 0.0839
Epoch [4/5], Step[264/600], Loss: 0.0247
Epoch [4/5], Step[265/600], Loss: 0.0234
Epoch [4/5], Step[266/600], Loss: 0.0582
Epoch [4/5], Step[267/600], Loss: 0.0041
Epoch [4/5], Step[272/600], Loss: 0.1262
Epoch [4/5], Step[273/600], Loss: 0.0061
Epoch [4/5], Step[274/600], Loss: 0.0117
Epoch [4/5], Step[275/600], Loss: 0.0508
Epoch [4/5], Step[280/600], Loss: 0.0058
Epoch [4/5], Step[281/600], Loss: 0.1394
Epoch [4/5], Step[282/600], Loss: 0.1032
Epoch [4/5], Step[283/600], Loss: 0.0093
Epoch [4/5], Step[384/600], Loss: 0.0099
Epoch [4/5], Step[385/600], Loss: 0.0244
Epoch [4/5], Step[386/600], Loss: 0.0634
Epoch [4/5], Step[387/600], Loss: 0.0861
Epoch [4/5], Step[392/600], Loss: 0.0415
Epoch [4/5], Step[393/600], Loss: 0.0244
Epoch [4/5], Step[394/600], Loss: 0.0073
Epoch [4/5], Step[395/600], Loss: 0.0727
Epoch [4/5], Step[400/600], Loss: 0.0344
Epoch [4/5], Step[401/600], Loss: 0.0526
Epoch [4/5], Step[402/600], Loss: 0.0469
Epoch [4/5], Step[403/600], Loss: 0.0225
Epoch [4/5], Step[408/600], Loss: 0.0051
Epoch [4/5], Step[409/600], Loss: 0.0374
Epoch [4/5], Step[410/600], Loss: 0.1101
Epoch [4/5], Step[411/600], Loss: 0.0204
Epoch [4/5], Step[512/600], Loss: 0.0102
Epoch [4/5], Step[513/600], Loss: 0.0353
Epoch [4/5], Step[514/600], Loss: 0.0593
Epoch [4/5], Step[515/600], Loss: 0.0230
Epoch [4/5], Step[520/600], Loss: 0.1317
Epoch [4/5], Step[521/600], Loss: 0.0152
Epoch [4/5], Step[522/600], Loss: 0.1528
Epoch [4/5], Step[523/600], Loss: 0.0796
Epoch [4/5], Step[528/600], Loss: 0.0619
Epoch [4/5], Step[529/600], Loss: 0.0280
Epoch [4/5], Step[530/600], Loss: 0.0453
Epoch [4/5], Step[531/600], Loss: 0.0276
Epoch [4/5], Step[536/600], Loss: 0.0162
Epoch [4/5], Step[537/600], Loss: 0.0576
Epoch [4/5], Step[538/600], Loss: 0.0131
Epoch [4/5], Step[539/600], Loss: 0.0536
Epoch [5/5], Step[1/600], Loss: 0.0300
Epoch [5/5], Step[2/600], Loss: 0.0062
Epoch [5/5], Step[3/600], Loss: 0.0746
Epoch [5/5], Step[8/600], Loss: 0.0050
Epoch [5/5], Step[9/600], Loss: 0.0073
Epoch [5/5], Step[10/600], Loss: 0.0374
Epoch [5/5], Step[11/600], Loss: 0.0194
Epoch [5/5], Step[16/600], Loss: 0.0032
Epoch [5/5], Step[17/600], Loss: 0.0114
Epoch [5/5], Step[18/600], Loss: 0.0017
Epoch [5/5], Step[19/600], Loss: 0.0296
Epoch [5/5], Step[24/600], Loss: 0.0107
Epoch [5/5], Step[25/600], Loss: 0.1005
Epoch [5/5], Step[26/600], Loss: 0.0535
Epoch [5/5], Step[27/600], Loss: 0.0150
Epoch [5/5], Step[128/600], Loss: 0.0950
Epoch [5/5], Step[129/600], Loss: 0.0121
Epoch [5/5], Step[130/600], Loss: 0.0068
Epoch [5/5], Step[131/600], Loss: 0.0046
Epoch [5/5], Step[136/600], Loss: 0.0561
Epoch [5/5], Step[137/600], Loss: 0.0932
Epoch [5/5], Step[138/600], Loss: 0.0313
Epoch [5/5], Step[139/600], Loss: 0.0569
Epoch [5/5], Step[144/600], Loss: 0.0050
Epoch [5/5], Step[145/600], Loss: 0.0543
Epoch [5/5], Step[146/600], Loss: 0.1123
Epoch [5/5], Step[147/600], Loss: 0.0038
Epoch [5/5], Step[152/600], Loss: 0.0188
Epoch [5/5], Step[153/600], Loss: 0.0004
Epoch [5/5], Step[154/600], Loss: 0.0089
Epoch [5/5], Step[155/600], Loss: 0.0101
Epoch [5/5], Step[256/600], Loss: 0.0143
Epoch [5/5], Step[257/600], Loss: 0.0368
Epoch [5/5], Step[258/600], Loss: 0.1298
Epoch [5/5], Step[259/600], Loss: 0.0296
Epoch [5/5], Step[264/600], Loss: 0.0167
Epoch [5/5], Step[265/600], Loss: 0.0628
Epoch [5/5], Step[266/600], Loss: 0.0583
Epoch [5/5], Step[267/600], Loss: 0.0081
Epoch [5/5], Step[272/600], Loss: 0.1128
Epoch [5/5], Step[273/600], Loss: 0.0334
Epoch [5/5], Step[274/600], Loss: 0.0126
Epoch [5/5], Step[275/600], Loss: 0.0020
Epoch [5/5], Step[280/600], Loss: 0.0065
Epoch [5/5], Step[281/600], Loss: 0.0387
Epoch [5/5], Step[282/600], Loss: 0.0177
Epoch [5/5], Step[283/600], Loss: 0.0286
Epoch [5/5], Step[384/600], Loss: 0.0409
Epoch [5/5], Step[385/600], Loss: 0.0167
Epoch [5/5], Step[386/600], Loss: 0.0175
Epoch [5/5], Step[387/600], Loss: 0.0345
Epoch [5/5], Step[392/600], Loss: 0.0106
Epoch [5/5], Step[393/600], Loss: 0.0108
Epoch [5/5], Step[394/600], Loss: 0.0212
Epoch [5/5], Step[395/600], Loss: 0.0150
Epoch [5/5], Step[400/600], Loss: 0.0364
Epoch [5/5], Step[401/600], Loss: 0.0001
Epoch [5/5], Step[402/600], Loss: 0.0165
Epoch [5/5], Step[403/600], Loss: 0.0735
Epoch [5/5], Step[408/600], Loss: 0.0078
Epoch [5/5], Step[409/600], Loss: 0.0587
Epoch [5/5], Step[410/600], Loss: 0.0046
Epoch [5/5], Step[411/600], Loss: 0.0513
Epoch [5/5], Step[512/600], Loss: 0.0328
Epoch [5/5], Step[513/600], Loss: 0.0730
Epoch [5/5], Step[514/600], Loss: 0.0896
Epoch [5/5], Step[515/600], Loss: 0.0244
Epoch [5/5], Step[520/600], Loss: 0.1046
Epoch [5/5], Step[521/600], Loss: 0.0129
Epoch [5/5], Step[522/600], Loss: 0.0958
Epoch [5/5], Step[523/600], Loss: 0.1078
Epoch [5/5], Step[528/600], Loss: 0.0319
Epoch [5/5], Step[529/600], Loss: 0.0606
Epoch [5/5], Step[530/600], Loss: 0.0835
Epoch [5/5], Step[531/600], Loss: 0.0390
Epoch [5/5], Step[536/600], Loss: 0.0489
Epoch [5/5], Step[537/600], Loss: 0.0168
Epoch [5/5], Step[538/600], Loss: 0.0122
Epoch [5/5], Step[539/600], Loss: 0.0445
Training took 690.5935745239258 seconds

Testing the Model

Let's define a testing function for model evaluation:

In [11]:
def test():
    cnn.eval()  # set to evaluation mode
    with torch.no_grad():  # don't update gradients during testing
        correct = 0
        total = 0
        for images, labels in loaders['test']:  # for each item in the test set, test it
            test_output, last_layer = cnn(images)
            pred_y = torch.max(test_output, 1)[1].data.squeeze()
            accuracy = (pred_y == labels).sum().item() / float(labels.size(0))
    print('Test accuracy of model on 10000 test images: %.2f' % accuracy)
In [12]:
test()  # this takes a minute
Test accuracy of model on 10000 test images: 1.00

Let's grab a few samples from our test set, plot them, and then manually check our accuracy:

In [13]:
sample = next(iter(loaders['test']))
imgs, lbls = sample
actual_number = lbls[:10].numpy()

Plotting the samples:

In [14]:
figure = plt.figure(figsize=(10, 4))
cols, rows = 5, 2
for i in range(1, cols * rows + 1):
    img, label = imgs[i], lbls[i]
    figure.add_subplot(rows, cols, i)
    plt.title(label)
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()

Checking our prediction abilities on our 10 samples:

In [15]:
test_output, last_layer = cnn(imgs[:10])
pred_y = torch.max(test_output, 1)[1].data.numpy().squeeze()
print(f'Prediction number: {pred_y}')
print(f'Actual number: {actual_number}')
Prediction number: [0 9 2 9 8 8 7 3 8 1]
Actual number: [0 9 2 4 8 8 7 3 8 1]

Saving The Trained Model

Next, we can save the trained model's state using the lines below:

In [16]:
# 1. Create models directory 
MODEL_PATH = Path("models")
MODEL_PATH.mkdir(parents=True, exist_ok=True)

# 2. Create model save path 
MODEL_NAME = "01_pytorch_mnist_cnn.pth"
MODEL_SAVE_PATH = MODEL_PATH / MODEL_NAME

# 3. Save the model state dict 
print(f"Saving model to: {MODEL_SAVE_PATH}")
torch.save(obj=cnn.state_dict(), # only saving the state_dict() only saves the models learned parameters
           f=MODEL_SAVE_PATH)
Saving model to: models/01_pytorch_mnist_cnn.pth
In [17]:
# Instantiate a new instance of our model (this will be instantiated with random weights)
loaded_model_0 = CNN()

# Load the state_dict of our saved model (this will update the new instance of our model with trained weights)
loaded_model_0.load_state_dict(torch.load(f=MODEL_SAVE_PATH))
Out[17]:
<All keys matched successfully>
In [18]:
loaded_model_0.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in loaders['test']:
        test_output, last_layer = loaded_model_0(images)
        pred_y = torch.max(test_output, 1)[1].data.squeeze()
        accuracy = (pred_y == labels).sum().item() / float(labels.size(0))
        pass
print('Test accuracy of model on 10000 test images: %.2f' % accuracy)
Test accuracy of model on 10000 test images: 0.99