top of page

Simple CNN architecture for CIFAR-10 image classification in PyTorch.

  • Writer: Tung San
  • Jun 28, 2022
  • 2 min read

Updated: Jun 29, 2022

CIFAR-10 dataset: 10 classes of 32x32 RGB images, loaded with PyTorch (torchvision).

Data Preparation

from torchvision import datasets, transforms

from torchvision.transforms import ToTensor, Lambda

# Per-channel (R, G, B) normalization constants for CIFAR-10. These are the
# widely copied values (e.g. from kuangliu/pytorch-cifar). The middle std
# entry was previously mistyped as 0.19994; the conventional value is 0.1994.
mean = [0.4914, 0.4822, 0.4465]
std = [0.2023, 0.1994, 0.2010]

# Training pipeline:
#   - augmentation: random 32x32 crop of the image padded by 4 pixels on each
#     side, plus a horizontal flip with probability 0.5;
#   - conversion to a tensor and per-channel normalization.
transform_train = transforms.Compose([
    transforms.RandomCrop([32, 32], padding=4),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

# Test pipeline: no augmentation, only tensor conversion and the same
# normalization as the training data.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

# download=True fetches the dataset into ./data if it is not already there.
training_data = datasets.CIFAR10(
    root="data",
    train=True,
    download=True,
    transform=transform_train,
)

test_data = datasets.CIFAR10(
    root="data",
    train=False,
    download=True,
    transform=transform_test,
)


from torch.utils.data import DataLoader

train_dataloader = DataLoader(training_data, batch_size=32, shuffle=True)
# Evaluation metrics do not depend on sample order, so the test set is
# iterated in a fixed order for reproducibility.
test_dataloader = DataLoader(test_data, batch_size=32, shuffle=False)



Display a sample training image (with the training transforms applied)

# Debug switch read by prints(): the value 1 enables the diagnostic output.
test_flag = 1


def prints(X, s=""):
    """Print ``X`` under the label ``s`` when the global ``test_flag`` is 1.

    Acts as a lightweight debug hook (for example, to show tensor shapes).
    Setting ``test_flag`` to any other value silences it.
    """
    if test_flag != 1:
        return
    print(f"Check Output: {s}\n {X}\n")


import matplotlib.pyplot as plt

X, y = next(iter(train_dataloader))

prints(X[0].shape, "X.shape")

plt.matshow(X[0].permute(1, 2, 0))

plt.show()































prints(X.shape, "BatchInfo.shape")




CNN architecture

Input: a batch of 32 images, each with 3 channels of 32 x 32 pixels.

  1. 2D convolution

  2. 2D Batch norm

  3. RELU

  4. 2D Max-pool

  5. Repeat steps 1–4 once more (second stage: 16 → 32 channels)

  6. Flatten

  7. Dense Layer

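  8. Log softmax — not part of the model itself: the network returns raw logits, and nn.CrossEntropyLoss applies log-softmax internally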



Build the model and inspect its layers with torchsummary.summary()

import torch

import torch.nn as nn

import torch.nn.functional as F

from torchsummary import summary


class CNNModel(nn.Module):
    """Two-stage CNN classifier for 32x32 images (CIFAR-10 by default).

    Each stage is conv -> batch norm -> ReLU -> 2x2 max-pool. The final
    feature map is flattened and fed to a single linear layer. ``forward``
    returns raw logits (no softmax), which is the input ``nn.CrossEntropyLoss``
    expects.

    Args:
        in_channels: number of input image channels (3 for RGB).
        num_classes: number of output logits.

    Spatial size walk-through for a 32x32 input:
        conv1 (5x5, no padding): 32 -> 28, max-pool: 28 -> 14
        conv2 (3x3, no padding): 14 -> 12, max-pool: 12 -> 6
    The flattened feature vector therefore has 32 * 6 * 6 entries. Other
    input sizes would need a different ``fc`` input width.
    """

    def __init__(self, in_channels=3, num_classes=10):
        super().__init__()
        # group 1 (attribute names and order are kept, so state_dicts stay compatible)
        self.conv1 = nn.Conv2d(in_channels, 16, kernel_size=5)
        self.BN1 = nn.BatchNorm2d(16)
        # group 2
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3)
        self.BN2 = nn.BatchNorm2d(32)
        # dense layer
        self.flatting = nn.Flatten()
        self.fc = nn.Linear(32 * 6 * 6, num_classes)

    def forward(self, X):
        """Return class logits for a batch ``X`` of shape (N, in_channels, 32, 32).

        Intermediate shapes are printed through ``prints`` (active only while
        the global ``test_flag`` is 1).
        """
        out = X
        prints(out.shape, "input shape")

        out = self.BN1(self.conv1(out))
        out = F.max_pool2d(F.relu(out), kernel_size=2, stride=2)
        prints(out.shape, "out shape")

        out = self.BN2(self.conv2(out))
        out = F.max_pool2d(F.relu(out), kernel_size=2, stride=2)
        prints(out.shape, "out shape")

        return self.fc(self.flatting(out))


# Build a model and print a per-layer summary. test_flag is still 1 at this
# point, so the forward pass run by summary() also prints intermediate shapes.
# Use the GPU when one is present; otherwise fall back to the CPU instead of
# crashing in .cuda().
device = "cuda" if torch.cuda.is_available() else "cpu"
model = CNNModel().to(device)
summary(model, input_size=(3, 32, 32), device=device)
















Define the training and test loops

def train_loop(dataloader, model, loss_fn, optimizer):
    """Run one training epoch of ``model`` over ``dataloader``.

    The model is put in training mode once. Then, for every batch, the loop
    computes the loss, back-propagates, and takes one optimizer step. Every
    100 batches it prints the current batch loss and the running sample
    position (batch index x batch size) out of the dataset size.

    Each batch is moved to the device holding the model's parameters, so the
    loop works for both CPU and GPU models. The model is assumed to have at
    least one parameter.
    """
    size = len(dataloader.dataset)
    device = next(model.parameters()).device
    # Nothing in the loop switches the mode back, so set it once per epoch.
    model.train()

    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            current = batch * len(X)
            print(f"loss: {loss.item():>7f} [{current:>5d}/{size:>5d}]")


def test_loop(dataloader, model, loss_fn):

model.eval()

size = len(dataloader.dataset)

num_batches = len(dataloader)

test_loss, correct = 0, 0


with torch.no_grad():

for X, y in dataloader:

X = X.cuda()

y = y.cuda()

pred = model(X)

test_loss += loss_fn(pred, y).item()

correct += (pred.argmax(1) == y).type(torch.float).sum().item()


test_loss /= num_batches

correct /= size

print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")



Train the model

# Silence the debug prints inside CNNModel.forward for the real training run
# (prints() only outputs when test_flag == 1).
test_flag = 0

# Use the GPU when one is present; otherwise fall back to the CPU instead of
# crashing in .cuda().
device = "cuda" if torch.cuda.is_available() else "cpu"
model = CNNModel().to(device)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999))
# scheduler.step() is called once per epoch below, so the learning rate is
# multiplied by 0.1 after epochs 5, 10 and 15.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[5, 10, 15], gamma=0.1)

epochs = 20
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(train_dataloader, model, loss_fn, optimizer)
    test_loop(test_dataloader, model, loss_fn)
    scheduler.step()
print("Done!")

































Comments


Post: Blog2 Post
bottom of page