Handwritten digit classifier using PyTorch.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import torchvision
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
# MNIST training split, read from ./data (download disabled) and converted to
# tensors, served 32 images per batch in dataset order.
# NOTE(review): shuffle=False on the training set looks unusual for SGD --
# confirm it is intentional (e.g. for reproducibility).
train_loader = torch.utils.data.DataLoader(
    dataset=datasets.MNIST(
        root='./data',
        train=True,
        download=False,
        transform=transforms.Compose([transforms.ToTensor()]),
    ),
    batch_size=32,
    shuffle=False,
)

# MNIST test split with the same tensor conversion and batching.
test_loader = torch.utils.data.DataLoader(
    dataset=datasets.MNIST(
        root='./data',
        train=False,
        transform=transforms.Compose([transforms.ToTensor()]),
    ),
    batch_size=32,
    shuffle=False,
)
class BasicNN(nn.Module):
    """Single linear layer mapping a flattened 28x28 input to 10 class scores."""

    def __init__(self):
        super().__init__()
        # 28 * 28 input features per sample, 10 outputs (one per class).
        self.net = nn.Linear(28 * 28, 10)

    def forward(self, x):
        """Return raw class scores (logits) of shape ``(batch, 10)``.

        Args:
            x: Tensor whose first dimension is the batch and whose remaining
                dimensions flatten to 28 * 28 values per sample.

        Returns:
            Unnormalised logits. No softmax is applied here because
            ``F.cross_entropy`` performs log-softmax internally; feeding it
            probabilities applies softmax twice, which flattens the gradients
            and keeps the loss from dropping. Use ``F.softmax(logits, dim=1)``
            on the result when probabilities are needed. The arg-max over
            classes is the same for logits and probabilities.
        """
        # Collapse everything but the batch dimension into one feature axis.
        flat = x.view(x.size(0), -1)
        return self.net(flat)
# Instantiate the classifier and a plain SGD optimiser over all of its
# parameters with a learning rate of 0.001.
model = BasicNN()
optimizer = optim.SGD(params=model.parameters(), lr=0.001)
def test():
    """Evaluate the module-level ``model`` on every batch of ``test_loader``.

    Returns:
        tuple: ``(total_loss, accuracy)`` as Python floats. ``total_loss`` is
        the mean of the per-batch cross-entropy losses and ``accuracy`` is the
        number of correct predictions divided by ``len(test_loader.dataset)``.
    """
    # Evaluation mode; train() switches the model back with model.train().
    model.eval()
    total_loss = 0.0
    correct = 0
    # Gradients are not needed for evaluation. Without no_grad(), summing the
    # loss tensors would keep every batch's autograd graph alive.
    with torch.no_grad():
        for image, label in test_loader:
            output = model(image)
            # .item() converts the 0-dim loss tensor to a float; indexing it
            # with .data[0] raises on current PyTorch versions.
            total_loss += F.cross_entropy(output, label).item()
            # The index of the largest score in each row is the prediction.
            predicted = output.argmax(dim=1)
            correct += (predicted == label).sum().item()
    total_loss /= len(test_loader)
    accuracy = correct / len(test_loader.dataset)
    return total_loss, accuracy
def train():
    """Run one pass of gradient updates over every batch in ``train_loader``.

    Operates on the module-level ``model`` and ``optimizer``; returns nothing.
    """
    model.train()
    for inputs, targets in train_loader:
        inputs = Variable(inputs)
        targets = Variable(targets)
        # Clear gradients left over from the previous batch before backward().
        optimizer.zero_grad()
        batch_loss = F.cross_entropy(model(inputs), targets)
        batch_loss.backward()
        optimizer.step()
# Train one epoch at a time (epochs 1..149) and stop at the first epoch whose
# test loss does not improve on the best value seen so far.
# NOTE(review): the stopping decision is made on the test split -- confirm
# that a separate validation split is not wanted here.
best_test_loss = None
best_test_accuracy = None
for e in range(1, 150):
    train()
    test_loss, test_accuracy = test()
    print("\n[Epoch: %d] Test Loss:%5.5f Test Accuracy:%5.5f" % (e, test_loss, test_accuracy))
    # Remember the best epoch (nothing is written to disk). Compare against
    # None explicitly so that a loss of exactly 0.0 is not treated as "unset".
    if best_test_loss is None or test_loss < best_test_loss:
        best_test_loss = test_loss
        best_test_accuracy = test_accuracy
    else:
        break
# Report the accuracy from the same epoch as the best loss; after an early
# stop, test_accuracy belongs to the later, worse epoch.
print("\nFinal Results\n-------------\n""Loss:", best_test_loss, "Test Accuracy: ", best_test_accuracy)
[Epoch: 1] Test Loss:2.27352 Test Accuracy:0.44360
[Epoch: 2] Test Loss:2.22371 Test Accuracy:0.45100
[Epoch: 3] Test Loss:2.16380 Test Accuracy:0.49840
[Epoch: 4] Test Loss:2.09973 Test Accuracy:0.51520
[Epoch: 5] Test Loss:2.04782 Test Accuracy:0.56200
[Epoch: 6] Test Loss:2.00434 Test Accuracy:0.60630
[Epoch: 7] Test Loss:1.96735 Test Accuracy:0.62930
[Epoch: 8] Test Loss:1.93913 Test Accuracy:0.64160
[Epoch: 9] Test Loss:1.91655 Test Accuracy:0.65620
[Epoch: 10] Test Loss:1.89545 Test Accuracy:0.68240
[Epoch: 11] Test Loss:1.87484 Test Accuracy:0.70650
[Epoch: 12] Test Loss:1.85802 Test Accuracy:0.71700
[Epoch: 13] Test Loss:1.84345 Test Accuracy:0.72550
[Epoch: 14] Test Loss:1.82930 Test Accuracy:0.74690
[Epoch: 15] Test Loss:1.81557 Test Accuracy:0.77430
[Epoch: 16] Test Loss:1.80372 Test Accuracy:0.78770
[Epoch: 17] Test Loss:1.79372 Test Accuracy:0.79150
[Epoch: 18] Test Loss:1.78501 Test Accuracy:0.79350
[Epoch: 19] Test Loss:1.77731 Test Accuracy:0.79600
[Epoch: 20] Test Loss:1.77043 Test Accuracy:0.79800
[Epoch: 21] Test Loss:1.76424 Test Accuracy:0.79990
[Epoch: 22] Test Loss:1.75864 Test Accuracy:0.80170
[Epoch: 23] Test Loss:1.75355 Test Accuracy:0.80300
[Epoch: 24] Test Loss:1.74890 Test Accuracy:0.80510
[Epoch: 25] Test Loss:1.74463 Test Accuracy:0.80620
[Epoch: 26] Test Loss:1.74069 Test Accuracy:0.80720
[Epoch: 27] Test Loss:1.73705 Test Accuracy:0.80880
[Epoch: 28] Test Loss:1.73367 Test Accuracy:0.80960
[Epoch: 29] Test Loss:1.73052 Test Accuracy:0.81040
[Epoch: 30] Test Loss:1.72757 Test Accuracy:0.81110
[Epoch: 31] Test Loss:1.72482 Test Accuracy:0.81170
[Epoch: 32] Test Loss:1.72223 Test Accuracy:0.81150
[Epoch: 33] Test Loss:1.71979 Test Accuracy:0.81260
[Epoch: 34] Test Loss:1.71750 Test Accuracy:0.81350
[Epoch: 35] Test Loss:1.71532 Test Accuracy:0.81350
[Epoch: 36] Test Loss:1.71326 Test Accuracy:0.81490
[Epoch: 37] Test Loss:1.71131 Test Accuracy:0.81560
[Epoch: 38] Test Loss:1.70945 Test Accuracy:0.81610
[Epoch: 39] Test Loss:1.70768 Test Accuracy:0.81660
[Epoch: 40] Test Loss:1.70599 Test Accuracy:0.81710
[Epoch: 41] Test Loss:1.70437 Test Accuracy:0.81810
[Epoch: 42] Test Loss:1.70282 Test Accuracy:0.81840
[Epoch: 43] Test Loss:1.70134 Test Accuracy:0.81910
[Epoch: 44] Test Loss:1.69992 Test Accuracy:0.81960
[Epoch: 45] Test Loss:1.69854 Test Accuracy:0.82030
[Epoch: 46] Test Loss:1.69722 Test Accuracy:0.82110
[Epoch: 47] Test Loss:1.69594 Test Accuracy:0.82090
[Epoch: 48] Test Loss:1.69470 Test Accuracy:0.82140
[Epoch: 49] Test Loss:1.69350 Test Accuracy:0.82170
[Epoch: 50] Test Loss:1.69233 Test Accuracy:0.82180
[Epoch: 51] Test Loss:1.69119 Test Accuracy:0.82220
[Epoch: 52] Test Loss:1.69007 Test Accuracy:0.82240
[Epoch: 53] Test Loss:1.68897 Test Accuracy:0.82280
[Epoch: 54] Test Loss:1.68787 Test Accuracy:0.82320
[Epoch: 55] Test Loss:1.68678 Test Accuracy:0.82370
[Epoch: 56] Test Loss:1.68567 Test Accuracy:0.82450
[Epoch: 57] Test Loss:1.68453 Test Accuracy:0.82490
[Epoch: 58] Test Loss:1.68333 Test Accuracy:0.82580
[Epoch: 59] Test Loss:1.68204 Test Accuracy:0.82700
[Epoch: 60] Test Loss:1.68060 Test Accuracy:0.82880
[Epoch: 61] Test Loss:1.67894 Test Accuracy:0.83060
[Epoch: 62] Test Loss:1.67696 Test Accuracy:0.83360
[Epoch: 63] Test Loss:1.67463 Test Accuracy:0.83750
[Epoch: 64] Test Loss:1.67199 Test Accuracy:0.84090
[Epoch: 65] Test Loss:1.66929 Test Accuracy:0.84380
[Epoch: 66] Test Loss:1.66677 Test Accuracy:0.84850
[Epoch: 67] Test Loss:1.66456 Test Accuracy:0.85150
[Epoch: 68] Test Loss:1.66259 Test Accuracy:0.85390
[Epoch: 69] Test Loss:1.66077 Test Accuracy:0.85540
[Epoch: 70] Test Loss:1.65905 Test Accuracy:0.85720
[Epoch: 71] Test Loss:1.65739 Test Accuracy:0.85920
[Epoch: 72] Test Loss:1.65576 Test Accuracy:0.86050
[Epoch: 73] Test Loss:1.65416 Test Accuracy:0.86220
[Epoch: 74] Test Loss:1.65260 Test Accuracy:0.86360
[Epoch: 75] Test Loss:1.65106 Test Accuracy:0.86480
[Epoch: 76] Test Loss:1.64955 Test Accuracy:0.86590
[Epoch: 77] Test Loss:1.64807 Test Accuracy:0.86770
[Epoch: 78] Test Loss:1.64661 Test Accuracy:0.86940
[Epoch: 79] Test Loss:1.64518 Test Accuracy:0.87040
[Epoch: 80] Test Loss:1.64379 Test Accuracy:0.87180
[Epoch: 81] Test Loss:1.64242 Test Accuracy:0.87240
[Epoch: 82] Test Loss:1.64109 Test Accuracy:0.87360
[Epoch: 83] Test Loss:1.63980 Test Accuracy:0.87450
[Epoch: 84] Test Loss:1.63853 Test Accuracy:0.87590
[Epoch: 85] Test Loss:1.63731 Test Accuracy:0.87790
[Epoch: 86] Test Loss:1.63612 Test Accuracy:0.87870
[Epoch: 87] Test Loss:1.63496 Test Accuracy:0.87950
[Epoch: 88] Test Loss:1.63384 Test Accuracy:0.88010
[Epoch: 89] Test Loss:1.63276 Test Accuracy:0.88130
[Epoch: 90] Test Loss:1.63171 Test Accuracy:0.88230
[Epoch: 91] Test Loss:1.63070 Test Accuracy:0.88320
[Epoch: 92] Test Loss:1.62971 Test Accuracy:0.88380
[Epoch: 93] Test Loss:1.62877 Test Accuracy:0.88490
[Epoch: 94] Test Loss:1.62785 Test Accuracy:0.88620
[Epoch: 95] Test Loss:1.62696 Test Accuracy:0.88650
[Epoch: 96] Test Loss:1.62610 Test Accuracy:0.88750
[Epoch: 97] Test Loss:1.62527 Test Accuracy:0.88740
[Epoch: 98] Test Loss:1.62447 Test Accuracy:0.88810
[Epoch: 99] Test Loss:1.62369 Test Accuracy:0.88830
[Epoch: 100] Test Loss:1.62294 Test Accuracy:0.88880
[Epoch: 101] Test Loss:1.62220 Test Accuracy:0.88930
[Epoch: 102] Test Loss:1.62149 Test Accuracy:0.88970
[Epoch: 103] Test Loss:1.62080 Test Accuracy:0.88990
[Epoch: 104] Test Loss:1.62013 Test Accuracy:0.89040
[Epoch: 105] Test Loss:1.61948 Test Accuracy:0.89060
[Epoch: 106] Test Loss:1.61885 Test Accuracy:0.89110
[Epoch: 107] Test Loss:1.61823 Test Accuracy:0.89170
[Epoch: 108] Test Loss:1.61763 Test Accuracy:0.89190
[Epoch: 109] Test Loss:1.61704 Test Accuracy:0.89230
[Epoch: 110] Test Loss:1.61647 Test Accuracy:0.89230
[Epoch: 111] Test Loss:1.61591 Test Accuracy:0.89290
[Epoch: 112] Test Loss:1.61536 Test Accuracy:0.89320
[Epoch: 113] Test Loss:1.61483 Test Accuracy:0.89340
[Epoch: 114] Test Loss:1.61430 Test Accuracy:0.89330
[Epoch: 115] Test Loss:1.61379 Test Accuracy:0.89360
[Epoch: 116] Test Loss:1.61329 Test Accuracy:0.89380
[Epoch: 117] Test Loss:1.61280 Test Accuracy:0.89400
[Epoch: 118] Test Loss:1.61232 Test Accuracy:0.89420
[Epoch: 119] Test Loss:1.61185 Test Accuracy:0.89430
[Epoch: 120] Test Loss:1.61139 Test Accuracy:0.89430
[Epoch: 121] Test Loss:1.61094 Test Accuracy:0.89430
[Epoch: 122] Test Loss:1.61049 Test Accuracy:0.89470
[Epoch: 123] Test Loss:1.61006 Test Accuracy:0.89500
[Epoch: 124] Test Loss:1.60963 Test Accuracy:0.89500
[Epoch: 125] Test Loss:1.60921 Test Accuracy:0.89510
[Epoch: 126] Test Loss:1.60880 Test Accuracy:0.89500
[Epoch: 127] Test Loss:1.60839 Test Accuracy:0.89500
[Epoch: 128] Test Loss:1.60799 Test Accuracy:0.89500
[Epoch: 129] Test Loss:1.60760 Test Accuracy:0.89500
[Epoch: 130] Test Loss:1.60721 Test Accuracy:0.89520
[Epoch: 131] Test Loss:1.60683 Test Accuracy:0.89550
[Epoch: 132] Test Loss:1.60646 Test Accuracy:0.89570
[Epoch: 133] Test Loss:1.60609 Test Accuracy:0.89580
[Epoch: 134] Test Loss:1.60573 Test Accuracy:0.89630
[Epoch: 135] Test Loss:1.60538 Test Accuracy:0.89660
[Epoch: 136] Test Loss:1.60503 Test Accuracy:0.89660
[Epoch: 137] Test Loss:1.60468 Test Accuracy:0.89680
[Epoch: 138] Test Loss:1.60434 Test Accuracy:0.89710
[Epoch: 139] Test Loss:1.60401 Test Accuracy:0.89730
[Epoch: 140] Test Loss:1.60368 Test Accuracy:0.89750
[Epoch: 141] Test Loss:1.60335 Test Accuracy:0.89780
[Epoch: 142] Test Loss:1.60303 Test Accuracy:0.89790
[Epoch: 143] Test Loss:1.60271 Test Accuracy:0.89820
[Epoch: 144] Test Loss:1.60240 Test Accuracy:0.89840
[Epoch: 145] Test Loss:1.60209 Test Accuracy:0.89850
[Epoch: 146] Test Loss:1.60179 Test Accuracy:0.89870
[Epoch: 147] Test Loss:1.60149 Test Accuracy:0.89860
[Epoch: 148] Test Loss:1.60119 Test Accuracy:0.89880
[Epoch: 149] Test Loss:1.60090 Test Accuracy:0.89870
Final Results
-------------
Loss: 1.60090330081245 Test Accuracy: 0.8987