Chapter 20: Image Classification Project with PyTorch
Abstract:
- Dataset Loading: Load your image dataset. This can involve using torchvision.datasets for common datasets (e.g., CIFAR-10, Fashion MNIST) or creating a custom Dataset class for your specific data.
- Data Augmentation and Preprocessing: Apply transformations to your images using torchvision.transforms. This includes resizing, cropping, normalization (e.g., ToTensor, Normalize), and data augmentation techniques like random rotations or flips to improve model generalization.
- DataLoader Creation: Create DataLoader objects to efficiently load and batch your data during training and evaluation.
- Choose/Define a CNN Architecture: Select a suitable Convolutional Neural Network (CNN) architecture. This could be a pre-trained model from torchvision.models (e.g., ResNet, VGG) for transfer learning, or you can define your own custom CNN using torch.nn.Module to build layers like Conv2d, MaxPool2d, and Linear.
- Loss Function: Define a loss function appropriate for your classification task (e.g., nn.CrossEntropyLoss for multi-class classification).
- Optimizer: Choose an optimizer (e.g., optim.Adam, optim.SGD) to update your model's weights during training.
- Device Configuration: Specify whether to use a GPU (CUDA) or CPU for training.
- Training Loop: Implement a training loop that iterates through epochs. Within each epoch:
  - Iterate through batches from the training DataLoader.
  - Perform a forward pass through the model.
  - Calculate the loss.
  - Perform backpropagation to compute gradients.
  - Update model weights using the optimizer.
- Validation: Periodically evaluate the model's performance on a validation set to monitor for overfitting and track progress. Use torch.no_grad() during validation to avoid gradient computation.
- Testing: Evaluate the trained model on a separate test set to assess its generalization performance using metrics like accuracy, precision, recall, or F1-score.
- Prediction: Use the trained model to make predictions on new, unseen images.
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader

# 1. Data Preparation
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# 2. Model Definition (using a pre-trained ResNet)
# torchvision >= 0.13 selects pre-trained weights via `weights`; older releases used pretrained=True
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 10)  # 10 classes for CIFAR-10

# 3. Training Setup
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# 4. Model Training (simplified loop)
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
    # Reports the loss of the last batch in the epoch
    print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")

# 5. Evaluation (conceptual)
# ... (load test data, evaluate model)
Chapter 20: Image Classification Project
Chapter Outline
- Introduction
- Data Pipeline and Model Selection
- Training and Evaluation
- Deployment
- Summary
- Exercises
20.1 Introduction
Image classification is one of the most common and foundational tasks in computer vision. It involves assigning a label or class to an image based on its visual content. Applications range from facial recognition and medical diagnosis to autonomous driving and defect detection in manufacturing.
In this chapter, we will build a complete image classification project using PyTorch, covering every stage of the workflow: data loading, model design, training and evaluation, and finally deployment for inference. The demonstration dataset is CIFAR-10, which consists of 60,000 32×32 color images in 10 classes: airplane, automobile, bird, cat, deer, dog, frog, horse, ship, and truck.
20.2 Data Pipeline and Model Selection
20.2.1 Dataset Overview
The CIFAR-10 dataset is a standard benchmark in computer vision. It contains:
- Training images: 50,000
- Test images: 10,000
- Classes: 10 object categories
Each image is a small color image (32×32 pixels).
We will use PyTorch’s torchvision library, which provides easy access to this dataset and common transformations for preprocessing.
20.2.2 Data Preprocessing
Before training, images are preprocessed to improve model convergence and generalization. Steps include:
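- Normalization: Rescaling pixel values to a standardized range. Here, ToTensor maps pixels to [0, 1], and Normalize with a mean and standard deviation of 0.5 per channel shifts them to [-1, 1].
- Data augmentation: Random transformations, applied to the training set only, that increase data diversity (flipping, cropping, rotation).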
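To make the normalization step concrete: Normalize(mean, std) computes (x - mean) / std for each channel, after ToTensor has scaled 8-bit pixel values to [0, 1]. The sketch below reproduces that arithmetic in plain Python for a single value, assuming the mean and standard deviation of 0.5 used in this chapter. The helper names normalize and denormalize are hypothetical and exist only for illustration.

```python
def normalize(value, mean=0.5, std=0.5):
    """Apply the per-channel formula used by transforms.Normalize."""
    return (value - mean) / std


def denormalize(value, mean=0.5, std=0.5):
    """Invert the normalization, e.g. before displaying an image."""
    return value * std + mean


# ToTensor maps 8-bit pixel values 0..255 onto 0.0..1.0 first.
for pixel in (0, 128, 255):
    scaled = pixel / 255
    print(pixel, round(normalize(scaled), 3))
```

The inverse function is useful when plotting images, since normalized tensors fall outside the [0, 1] range that plotting libraries usually expect.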
import torch
import torchvision
import torchvision.transforms as transforms

# Define data transformations
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5),
                         (0.5, 0.5, 0.5))
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5),
                         (0.5, 0.5, 0.5))
])

# Load datasets
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64,
                                          shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=64,
                                         shuffle=False, num_workers=2)
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')
20.2.3 Model Selection
We can design a custom Convolutional Neural Network (CNN) or use a pre-trained model like ResNet, VGG, or MobileNet from torchvision.models.
For this project, we will start with a custom CNN for clarity.
import torch.nn as nn
import torch.nn.functional as F

class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        # Two 3x3 convolutions; padding=1 preserves the spatial size
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        # Each 2x2 max-pool halves height and width: 32 -> 16 -> 8
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 8 * 8, 512)
        self.fc2 = nn.Linear(512, 10)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten to (batch, 64 * 8 * 8) for the fully connected layers
        x = x.view(-1, 64 * 8 * 8)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)  # raw logits; CrossEntropyLoss applies softmax internally
        return x

model = SimpleCNN()
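The input size of fc1 (64 * 8 * 8) follows from how each layer changes the spatial dimensions. The sketch below traces a 32×32 input through the layers of SimpleCNN using the standard output-size formula for convolution and pooling layers. The helper name conv_output_size is hypothetical.

```python
def conv_output_size(size, kernel, stride=1, padding=0):
    """Spatial size after a convolution or pooling layer (floor division)."""
    return (size + 2 * padding - kernel) // stride + 1


size = 32                                            # CIFAR-10 images are 32x32
size = conv_output_size(size, kernel=3, padding=1)   # conv1 keeps 32
size = conv_output_size(size, kernel=2, stride=2)    # pool halves to 16
size = conv_output_size(size, kernel=3, padding=1)   # conv2 keeps 16
size = conv_output_size(size, kernel=2, stride=2)    # pool halves to 8

flattened = 64 * size * size   # 64 channels come out of conv2
print(size, flattened)         # 8 4096
```

If the input resolution or the number of pooling layers changes, the same calculation gives the new value to use in nn.Linear.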
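Alternatively, using a pre-trained ResNet18:
from torchvision import models
# torchvision >= 0.13 selects pre-trained weights via `weights`; older releases used pretrained=True
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
model.fc = nn.Linear(model.fc.in_features, 10)  # CIFAR-10 has 10 classes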
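When fine-tuning a pre-trained network, one common option is to freeze the pre-trained layers and train only the new classification layer. The sketch below shows that pattern on a small stand-in network so it runs without downloading any weights. The names backbone and head are hypothetical; with ResNet18 the same idea would apply to the model's existing parameters and its replaced fc layer.

```python
import torch.nn as nn
import torch.optim as optim

# Hypothetical stand-in for a pre-trained network: a feature extractor plus a head.
backbone = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
)
head = nn.Linear(8, 10)  # new layer for the 10 CIFAR-10 classes
model = nn.Sequential(backbone, head)

# Freeze the feature extractor so its weights receive no gradient updates.
for param in backbone.parameters():
    param.requires_grad = False

# Give the optimizer only the parameters that are still trainable.
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable, lr=0.001)

print(sum(p.numel() for p in trainable))  # 8 * 10 weights + 10 biases = 90
```

Freezing reduces training cost and can help on small datasets; unfreezing everything with a lower learning rate is the usual alternative when more data is available.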
20.3 Training and Evaluation
20.3.1 Defining Loss Function and Optimizer
We will use CrossEntropyLoss as the loss function (suitable for multi-class classification) and Adam as the optimizer.
import torch.optim as optim
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
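nn.CrossEntropyLoss expects raw logits and integer class indices, which is why the model's forward pass does not end in a softmax. The sketch below, using made-up logits, checks that the loss equals the mean negative log-probability of the true class computed by hand.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Raw, unnormalized scores for a batch of 2 images and 3 classes (made-up numbers).
logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.1, 0.2, 3.0]])
labels = torch.tensor([0, 2])  # integer class indices, not one-hot vectors

criterion = nn.CrossEntropyLoss()
loss = criterion(logits, labels)

# Equivalent manual computation: mean negative log-probability of the true class.
log_probs = F.log_softmax(logits, dim=1)
manual = -log_probs[torch.arange(len(labels)), labels].mean()

print(torch.allclose(loss, manual))
```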
20.3.2 Training Loop
Training involves feeding batches of data, computing loss, and updating model weights.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
num_epochs = 10
for epoch in range(num_epochs):
    running_loss = 0.0
    model.train()
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if i % 100 == 99:  # Print every 100 mini-batches
            print(f'Epoch [{epoch + 1}, {i + 1:5d}] loss: {running_loss / 100:.3f}')
            running_loss = 0.0
print('Training Finished')
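The loop above reports training loss only. One common way to watch for overfitting is to hold out part of the training data and measure the loss on it after each epoch. The following is a minimal sketch of that idea, using a small synthetic dataset and a tiny linear model as stand-ins so it runs without downloading CIFAR-10. The split sizes, the seed, and the validate helper are assumptions for illustration.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset, random_split

# Synthetic stand-in for CIFAR-10: 100 random "images" with labels 0..9.
full_set = TensorDataset(torch.randn(100, 3, 32, 32),
                         torch.randint(0, 10, (100,)))

# Hold out 10% for validation; the fixed seed makes the split reproducible.
generator = torch.Generator().manual_seed(0)
train_set, val_set = random_split(full_set, [90, 10], generator=generator)
val_loader = DataLoader(val_set, batch_size=32, shuffle=False)

# Hypothetical tiny model, only so the function below has something to evaluate.
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
criterion = nn.CrossEntropyLoss()


def validate(model, loader):
    """Return the average loss per sample over a data loader, without gradients."""
    model.eval()
    total_loss, count = 0.0, 0
    with torch.no_grad():
        for inputs, labels in loader:
            loss = criterion(model(inputs), labels)
            total_loss += loss.item() * labels.size(0)
            count += labels.size(0)
    return total_loss / count


val_loss = validate(model, val_loader)
print(f'validation loss: {val_loss:.3f}')
```

In a real project the held-out subset would come from the training set, and model.train() should be called again before the next training epoch, because validate switches the model to evaluation mode.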
20.3.3 Model Evaluation
After training, the model’s performance is measured using accuracy on the test dataset.
correct = 0
total = 0
model.eval()
with torch.no_grad():
    for data in testloader:
        images, labels = data[0].to(device), data[1].to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print(f'Accuracy of the network on test images: {100 * correct / total:.2f}%')
You can also compute per-class accuracy:
class_correct = [0] * 10
class_total = [0] * 10
model.eval()
with torch.no_grad():
    for data in testloader:
        images, labels = data[0].to(device), data[1].to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        # Element-wise comparison; no squeeze(), which would break on a batch of size 1
        matches = (predicted == labels)
        for label, match in zip(labels.tolist(), matches.tolist()):
            class_correct[label] += int(match)
            class_total[label] += 1
for i in range(10):
    print(f'Accuracy of {classes[i]}: {100 * class_correct[i] / class_total[i]:.2f}%')
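Accuracy can hide class-specific errors, so precision, recall, and F1-score are often reported per class as well. The sketch below computes them in plain Python from lists of true and predicted labels, treating one class as the positive class. The function name precision_recall_f1 and the sample labels are made up for illustration.

```python
def precision_recall_f1(y_true, y_pred, positive):
    """Compute precision, recall, and F1 for one class treated as positive."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == positive and p == positive)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t != positive and p == positive)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == positive and p != positive)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = (2 * precision * recall / (precision + recall)
          if precision + recall else 0.0)
    return precision, recall, f1


# Made-up labels for six images; 3 = 'cat', 5 = 'dog' in the classes tuple.
y_true = [3, 3, 5, 5, 3, 5]
y_pred = [3, 5, 5, 5, 3, 3]

precision, recall, f1 = precision_recall_f1(y_true, y_pred, positive=3)
print(precision, recall, f1)
```

In the evaluation loop, y_true and y_pred could be collected by extending two lists with labels.tolist() and predicted.tolist() for each batch.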
20.4 Deployment
20.4.1 Saving the Model
To reuse or deploy the trained model, it must be saved.
PATH = './cifar_net.pth'
torch.save(model.state_dict(), PATH)
20.4.2 Loading the Model
When deploying for inference, load the saved weights.
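model = SimpleCNN()
# map_location lets weights saved on a GPU load on a CPU-only machine
model.load_state_dict(torch.load(PATH, map_location=device))
model.to(device)  # keep the model on the same device as the inputs
model.eval()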
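A quick way to confirm that saving and loading worked is to check that the restored model produces the same outputs as the original. The sketch below performs that round trip with a tiny stand-in model and a temporary file. The model, the file name demo_net.pth, and the input shape are assumptions for illustration.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical tiny model standing in for SimpleCNN.
model = nn.Linear(4, 2)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, 'demo_net.pth')
    torch.save(model.state_dict(), path)

    # A fresh instance has its own random weights
    # until the saved state dict is loaded into it.
    restored = nn.Linear(4, 2)
    state = torch.load(path, map_location='cpu')
    restored.load_state_dict(state)

restored.eval()
x = torch.randn(3, 4)
same = torch.allclose(model(x), restored(x))
print(same)
```

Saving only the state dict keeps the file independent of how the class is pickled, but it means the model class must be defined before loading.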
20.4.3 Inference on New Images
To classify new unseen images, preprocess them similarly and perform a forward pass.
from PIL import Image

def predict_image(image_path, model):
    # Convert to RGB so grayscale or RGBA files match the 3-channel input
    image = Image.open(image_path).convert('RGB')
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5),
                             (0.5, 0.5, 0.5))
    ])
    # Add a batch dimension: (3, 32, 32) -> (1, 3, 32, 32)
    image = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():  # no gradients are needed for inference
        outputs = model(image)
    _, predicted = torch.max(outputs, 1)
    return classes[predicted.item()]

print(predict_image('sample_image.jpg', model))
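Returning only the top label discards how confident the model is. A softmax over the logits gives class probabilities, and torch.topk selects the most likely classes. The sketch below uses made-up logits for a single image in place of a real model output.

```python
import torch
import torch.nn.functional as F

classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

# Made-up logits for a single image; a real call would use model(image).
outputs = torch.tensor([[0.2, -1.0, 0.5, 3.0, 0.1, 2.5, -0.3, 0.0, -2.0, 0.4]])

probs = F.softmax(outputs, dim=1)                   # each row now sums to 1
top_probs, top_idx = torch.topk(probs, k=3, dim=1)  # three most likely classes

for p, idx in zip(top_probs[0], top_idx[0]):
    print(f'{classes[idx.item()]}: {p.item():.2%}')
```

Softmax probabilities from a neural network are not necessarily well calibrated, so they are best read as a relative ranking rather than an exact likelihood.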
20.4.4 Web Deployment using Flask
A simple Flask API can expose the model as a web service.
from flask import Flask, request, jsonify

app = Flask(__name__)

@app.route('/predict', methods=['POST'])
def predict():
    file = request.files['image']
    # predict_image calls Image.open, which accepts a file-like object as well as a path
    label = predict_image(file.stream, model)
    return jsonify({'prediction': label})

if __name__ == '__main__':
    # debug=True is meant for local development only
    app.run(debug=True)
With this setup, you can send an image in a POST request to /predict, and the API returns the predicted class as JSON. Flask's built-in server is intended for development; a production deployment would typically run the app behind a WSGI server.
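The endpoint can be exercised without starting a server by using Flask's test client, which sends requests to the app in-process. The sketch below is self-contained: it defines a reduced version of the route with a hypothetical fake_predict function standing in for the trained model, and it adds a check for a missing upload, which the route above does not perform.

```python
import io

from flask import Flask, jsonify, request

app = Flask(__name__)


def fake_predict(image_bytes):
    """Hypothetical stand-in for the real model; always answers 'cat'."""
    return 'cat'


@app.route('/predict', methods=['POST'])
def predict():
    # Reject requests that do not include the expected form field.
    if 'image' not in request.files:
        return jsonify({'error': 'no image uploaded'}), 400
    image_bytes = request.files['image'].read()
    return jsonify({'prediction': fake_predict(image_bytes)})


# The test client sends requests to the app in-process, so no server is needed.
client = app.test_client()
response = client.post(
    '/predict',
    data={'image': (io.BytesIO(b'not a real image'), 'sample.jpg')},
    content_type='multipart/form-data',
)
result = response.get_json()
print(response.status_code, result)

# A request without a file gets a 400 response instead of an unhandled error.
missing = client.post('/predict')
print(missing.status_code)
```

Against a running server, the same multipart upload can be sent by any HTTP client, with the file placed in a form field named image.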
20.5 Summary
In this chapter, we developed a complete image classification project using PyTorch. We:
- Built a data pipeline for image preprocessing and loading.
- Designed and trained a CNN model on the CIFAR-10 dataset.
- Evaluated performance and saved the trained model.
- Deployed it for real-world inference through a Flask API.
This project consolidates all concepts from previous chapters—data handling, model design, training workflow, and deployment—into a practical, end-to-end application.
20.6 Exercises
- Modify the CNN architecture to include Batch Normalization and observe its effect on accuracy.
- Implement transfer learning using a pretrained ResNet50 model for CIFAR-10.
- Try learning rate scheduling and compare model convergence.
- Deploy your model using FastAPI instead of Flask.
- Visualize sample predictions using matplotlib to display images alongside predicted labels.