Skip to content

kdn251/mnist

Folders and files

Name
Last commit message
Last commit date

Latest commit

 

History

4 Commits
 
 
 
 
 
 
 
 

Repository files navigation

MNIST

Handwritten digit classifier using PyTorch.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import torchvision
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
# %matplotlib inline  -- IPython/Jupyter magic; only valid inside a notebook.
# Commented out so the script also runs as a plain .py file.
# DataLoaders for the MNIST train/test splits (28x28 grayscale, 10 classes).
# download=True fetches the dataset on first run instead of crashing when
# ./data is missing; shuffle=True on the *training* split decorrelates
# batches, which SGD assumes (the test split stays unshuffled so metrics
# are reproducible).
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor()
                   ])),
    batch_size=32, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=False, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor()
                   ])),
    batch_size=32, shuffle=False)
class BasicNN(nn.Module):
    """Single-layer (multinomial logistic regression) MNIST classifier.

    Maps a flattened 784-pixel image to 10 raw class scores (logits).
    """

    def __init__(self):
        super(BasicNN, self).__init__()
        # One affine layer: 784 inputs -> 10 class scores.
        self.net = nn.Linear(28 * 28, 10)

    def forward(self, x):
        """Return unnormalized class logits of shape (batch, 10).

        Returning logits instead of ``F.softmax(output)`` fixes a
        double-softmax bug: ``F.cross_entropy`` applies log_softmax
        internally, so feeding it probabilities squashes gradients and
        stalls training (the original run plateaus near loss 1.6).
        Argmax-based accuracy is unaffected because softmax is monotonic.
        """
        batch_size = x.size(0)
        x = x.view(batch_size, -1)  # flatten (batch, 1, 28, 28) -> (batch, 784)
        return self.net(x)
# Instantiate the classifier and a vanilla SGD optimizer (learning rate 0.001)
# over its parameters; both are module-level globals read by train() and test().
model = BasicNN()
optimizer = optim.SGD(model.parameters(), lr=0.001)
def test():
    """Evaluate the global ``model`` on ``test_loader``.

    Returns:
        tuple: ``(mean_loss, accuracy)`` — mean cross-entropy per batch
        and the fraction of correctly classified test images.
    """
    model.eval()  # mirrors model.train() in train(); correct practice even for a bare Linear
    total_loss = 0.0
    correct = 0
    # no_grad replaces the deprecated Variable wrapper and skips autograd
    # bookkeeping during evaluation.
    with torch.no_grad():
        for image, label in test_loader:
            output = model(image)
            # .item() replaces the removed Tensor.data[0] accessor.
            total_loss += F.cross_entropy(output, label).item()
            pred = torch.max(output, 1)[1].view(label.size())
            correct += (pred == label).sum().item()
    total_loss = total_loss / len(test_loader)
    # correct is now a Python int, so this is true (float) division.
    accuracy = correct / len(test_loader.dataset)
    return total_loss, accuracy
def train():
    """Run one epoch of SGD over ``train_loader``, updating ``model`` in place."""
    model.train()
    for image, label in train_loader:
        # Variable() is deprecated: tensors carry autograd state directly.
        optimizer.zero_grad()  # clear gradients accumulated by the previous step
        output = model(image)
        loss = F.cross_entropy(output, label)
        loss.backward()        # accumulate d(loss)/d(params)
        optimizer.step()       # apply one SGD update
# Train for up to 149 epochs with early stopping: quit the first time the
# test loss fails to improve on the best seen so far.
# NOTE(review): the original comment claimed the best model is saved, but no
# checkpoint is ever written — add torch.save(model.state_dict(), ...) if
# persistence is actually wanted.
best_test_loss = None
for e in range(1, 150):
    train()
    test_loss, test_accuracy = test()
    print("\n[Epoch: %d] Test Loss:%5.5f Test Accuracy:%5.5f" % (e, test_loss, test_accuracy))
    # `is None` rather than truthiness: a legitimate loss of 0.0 is falsy and
    # would otherwise be mistaken for "no best loss recorded yet".
    if best_test_loss is None or test_loss < best_test_loss:
        best_test_loss = test_loss
    else:
        break
print("\nFinal Results\n-------------\n""Loss:", best_test_loss, "Test Accuracy: ", test_accuracy)
[Epoch: 1] Test Loss:2.27352 Test Accuracy:0.44360

[Epoch: 2] Test Loss:2.22371 Test Accuracy:0.45100

[Epoch: 3] Test Loss:2.16380 Test Accuracy:0.49840

[Epoch: 4] Test Loss:2.09973 Test Accuracy:0.51520

[Epoch: 5] Test Loss:2.04782 Test Accuracy:0.56200

[Epoch: 6] Test Loss:2.00434 Test Accuracy:0.60630

[Epoch: 7] Test Loss:1.96735 Test Accuracy:0.62930

[Epoch: 8] Test Loss:1.93913 Test Accuracy:0.64160

[Epoch: 9] Test Loss:1.91655 Test Accuracy:0.65620

[Epoch: 10] Test Loss:1.89545 Test Accuracy:0.68240

[Epoch: 11] Test Loss:1.87484 Test Accuracy:0.70650

[Epoch: 12] Test Loss:1.85802 Test Accuracy:0.71700

[Epoch: 13] Test Loss:1.84345 Test Accuracy:0.72550

[Epoch: 14] Test Loss:1.82930 Test Accuracy:0.74690

[Epoch: 15] Test Loss:1.81557 Test Accuracy:0.77430

[Epoch: 16] Test Loss:1.80372 Test Accuracy:0.78770

[Epoch: 17] Test Loss:1.79372 Test Accuracy:0.79150

[Epoch: 18] Test Loss:1.78501 Test Accuracy:0.79350

[Epoch: 19] Test Loss:1.77731 Test Accuracy:0.79600

[Epoch: 20] Test Loss:1.77043 Test Accuracy:0.79800

[Epoch: 21] Test Loss:1.76424 Test Accuracy:0.79990

[Epoch: 22] Test Loss:1.75864 Test Accuracy:0.80170

[Epoch: 23] Test Loss:1.75355 Test Accuracy:0.80300

[Epoch: 24] Test Loss:1.74890 Test Accuracy:0.80510

[Epoch: 25] Test Loss:1.74463 Test Accuracy:0.80620

[Epoch: 26] Test Loss:1.74069 Test Accuracy:0.80720

[Epoch: 27] Test Loss:1.73705 Test Accuracy:0.80880

[Epoch: 28] Test Loss:1.73367 Test Accuracy:0.80960

[Epoch: 29] Test Loss:1.73052 Test Accuracy:0.81040

[Epoch: 30] Test Loss:1.72757 Test Accuracy:0.81110

[Epoch: 31] Test Loss:1.72482 Test Accuracy:0.81170

[Epoch: 32] Test Loss:1.72223 Test Accuracy:0.81150

[Epoch: 33] Test Loss:1.71979 Test Accuracy:0.81260

[Epoch: 34] Test Loss:1.71750 Test Accuracy:0.81350

[Epoch: 35] Test Loss:1.71532 Test Accuracy:0.81350

[Epoch: 36] Test Loss:1.71326 Test Accuracy:0.81490

[Epoch: 37] Test Loss:1.71131 Test Accuracy:0.81560

[Epoch: 38] Test Loss:1.70945 Test Accuracy:0.81610

[Epoch: 39] Test Loss:1.70768 Test Accuracy:0.81660

[Epoch: 40] Test Loss:1.70599 Test Accuracy:0.81710

[Epoch: 41] Test Loss:1.70437 Test Accuracy:0.81810

[Epoch: 42] Test Loss:1.70282 Test Accuracy:0.81840

[Epoch: 43] Test Loss:1.70134 Test Accuracy:0.81910

[Epoch: 44] Test Loss:1.69992 Test Accuracy:0.81960

[Epoch: 45] Test Loss:1.69854 Test Accuracy:0.82030

[Epoch: 46] Test Loss:1.69722 Test Accuracy:0.82110

[Epoch: 47] Test Loss:1.69594 Test Accuracy:0.82090

[Epoch: 48] Test Loss:1.69470 Test Accuracy:0.82140

[Epoch: 49] Test Loss:1.69350 Test Accuracy:0.82170

[Epoch: 50] Test Loss:1.69233 Test Accuracy:0.82180

[Epoch: 51] Test Loss:1.69119 Test Accuracy:0.82220

[Epoch: 52] Test Loss:1.69007 Test Accuracy:0.82240

[Epoch: 53] Test Loss:1.68897 Test Accuracy:0.82280

[Epoch: 54] Test Loss:1.68787 Test Accuracy:0.82320

[Epoch: 55] Test Loss:1.68678 Test Accuracy:0.82370

[Epoch: 56] Test Loss:1.68567 Test Accuracy:0.82450

[Epoch: 57] Test Loss:1.68453 Test Accuracy:0.82490

[Epoch: 58] Test Loss:1.68333 Test Accuracy:0.82580

[Epoch: 59] Test Loss:1.68204 Test Accuracy:0.82700

[Epoch: 60] Test Loss:1.68060 Test Accuracy:0.82880

[Epoch: 61] Test Loss:1.67894 Test Accuracy:0.83060

[Epoch: 62] Test Loss:1.67696 Test Accuracy:0.83360

[Epoch: 63] Test Loss:1.67463 Test Accuracy:0.83750

[Epoch: 64] Test Loss:1.67199 Test Accuracy:0.84090

[Epoch: 65] Test Loss:1.66929 Test Accuracy:0.84380

[Epoch: 66] Test Loss:1.66677 Test Accuracy:0.84850

[Epoch: 67] Test Loss:1.66456 Test Accuracy:0.85150

[Epoch: 68] Test Loss:1.66259 Test Accuracy:0.85390

[Epoch: 69] Test Loss:1.66077 Test Accuracy:0.85540

[Epoch: 70] Test Loss:1.65905 Test Accuracy:0.85720

[Epoch: 71] Test Loss:1.65739 Test Accuracy:0.85920

[Epoch: 72] Test Loss:1.65576 Test Accuracy:0.86050

[Epoch: 73] Test Loss:1.65416 Test Accuracy:0.86220

[Epoch: 74] Test Loss:1.65260 Test Accuracy:0.86360

[Epoch: 75] Test Loss:1.65106 Test Accuracy:0.86480

[Epoch: 76] Test Loss:1.64955 Test Accuracy:0.86590

[Epoch: 77] Test Loss:1.64807 Test Accuracy:0.86770

[Epoch: 78] Test Loss:1.64661 Test Accuracy:0.86940

[Epoch: 79] Test Loss:1.64518 Test Accuracy:0.87040

[Epoch: 80] Test Loss:1.64379 Test Accuracy:0.87180

[Epoch: 81] Test Loss:1.64242 Test Accuracy:0.87240

[Epoch: 82] Test Loss:1.64109 Test Accuracy:0.87360

[Epoch: 83] Test Loss:1.63980 Test Accuracy:0.87450

[Epoch: 84] Test Loss:1.63853 Test Accuracy:0.87590

[Epoch: 85] Test Loss:1.63731 Test Accuracy:0.87790

[Epoch: 86] Test Loss:1.63612 Test Accuracy:0.87870

[Epoch: 87] Test Loss:1.63496 Test Accuracy:0.87950

[Epoch: 88] Test Loss:1.63384 Test Accuracy:0.88010

[Epoch: 89] Test Loss:1.63276 Test Accuracy:0.88130

[Epoch: 90] Test Loss:1.63171 Test Accuracy:0.88230

[Epoch: 91] Test Loss:1.63070 Test Accuracy:0.88320

[Epoch: 92] Test Loss:1.62971 Test Accuracy:0.88380

[Epoch: 93] Test Loss:1.62877 Test Accuracy:0.88490

[Epoch: 94] Test Loss:1.62785 Test Accuracy:0.88620

[Epoch: 95] Test Loss:1.62696 Test Accuracy:0.88650

[Epoch: 96] Test Loss:1.62610 Test Accuracy:0.88750

[Epoch: 97] Test Loss:1.62527 Test Accuracy:0.88740

[Epoch: 98] Test Loss:1.62447 Test Accuracy:0.88810

[Epoch: 99] Test Loss:1.62369 Test Accuracy:0.88830

[Epoch: 100] Test Loss:1.62294 Test Accuracy:0.88880

[Epoch: 101] Test Loss:1.62220 Test Accuracy:0.88930

[Epoch: 102] Test Loss:1.62149 Test Accuracy:0.88970

[Epoch: 103] Test Loss:1.62080 Test Accuracy:0.88990

[Epoch: 104] Test Loss:1.62013 Test Accuracy:0.89040

[Epoch: 105] Test Loss:1.61948 Test Accuracy:0.89060

[Epoch: 106] Test Loss:1.61885 Test Accuracy:0.89110

[Epoch: 107] Test Loss:1.61823 Test Accuracy:0.89170

[Epoch: 108] Test Loss:1.61763 Test Accuracy:0.89190

[Epoch: 109] Test Loss:1.61704 Test Accuracy:0.89230

[Epoch: 110] Test Loss:1.61647 Test Accuracy:0.89230

[Epoch: 111] Test Loss:1.61591 Test Accuracy:0.89290

[Epoch: 112] Test Loss:1.61536 Test Accuracy:0.89320

[Epoch: 113] Test Loss:1.61483 Test Accuracy:0.89340

[Epoch: 114] Test Loss:1.61430 Test Accuracy:0.89330

[Epoch: 115] Test Loss:1.61379 Test Accuracy:0.89360

[Epoch: 116] Test Loss:1.61329 Test Accuracy:0.89380

[Epoch: 117] Test Loss:1.61280 Test Accuracy:0.89400

[Epoch: 118] Test Loss:1.61232 Test Accuracy:0.89420

[Epoch: 119] Test Loss:1.61185 Test Accuracy:0.89430

[Epoch: 120] Test Loss:1.61139 Test Accuracy:0.89430

[Epoch: 121] Test Loss:1.61094 Test Accuracy:0.89430

[Epoch: 122] Test Loss:1.61049 Test Accuracy:0.89470

[Epoch: 123] Test Loss:1.61006 Test Accuracy:0.89500

[Epoch: 124] Test Loss:1.60963 Test Accuracy:0.89500

[Epoch: 125] Test Loss:1.60921 Test Accuracy:0.89510

[Epoch: 126] Test Loss:1.60880 Test Accuracy:0.89500

[Epoch: 127] Test Loss:1.60839 Test Accuracy:0.89500

[Epoch: 128] Test Loss:1.60799 Test Accuracy:0.89500

[Epoch: 129] Test Loss:1.60760 Test Accuracy:0.89500

[Epoch: 130] Test Loss:1.60721 Test Accuracy:0.89520

[Epoch: 131] Test Loss:1.60683 Test Accuracy:0.89550

[Epoch: 132] Test Loss:1.60646 Test Accuracy:0.89570

[Epoch: 133] Test Loss:1.60609 Test Accuracy:0.89580

[Epoch: 134] Test Loss:1.60573 Test Accuracy:0.89630

[Epoch: 135] Test Loss:1.60538 Test Accuracy:0.89660

[Epoch: 136] Test Loss:1.60503 Test Accuracy:0.89660

[Epoch: 137] Test Loss:1.60468 Test Accuracy:0.89680

[Epoch: 138] Test Loss:1.60434 Test Accuracy:0.89710

[Epoch: 139] Test Loss:1.60401 Test Accuracy:0.89730

[Epoch: 140] Test Loss:1.60368 Test Accuracy:0.89750

[Epoch: 141] Test Loss:1.60335 Test Accuracy:0.89780

[Epoch: 142] Test Loss:1.60303 Test Accuracy:0.89790

[Epoch: 143] Test Loss:1.60271 Test Accuracy:0.89820

[Epoch: 144] Test Loss:1.60240 Test Accuracy:0.89840

[Epoch: 145] Test Loss:1.60209 Test Accuracy:0.89850

[Epoch: 146] Test Loss:1.60179 Test Accuracy:0.89870

[Epoch: 147] Test Loss:1.60149 Test Accuracy:0.89860

[Epoch: 148] Test Loss:1.60119 Test Accuracy:0.89880

[Epoch: 149] Test Loss:1.60090 Test Accuracy:0.89870

Final Results
-------------
Loss: 1.60090330081245 Test Accuracy:  0.8987

About

Handwritten digit classifier using PyTorch.

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published