Chapter 12: Transfer Learning and Fine-Tuning in PyTorch

Abstract:

Transfer learning and fine-tuning are powerful techniques in PyTorch for leveraging pre-trained models on new, related tasks, especially when limited data or computational resources are available.
Transfer Learning:
Transfer learning involves using a model pre-trained on a large dataset for a general task (e.g., image classification on ImageNet) as a starting point for a different, but related, task. The idea is that the pre-trained model has already learned rich feature representations that are transferable to the new task.
In PyTorch, a common approach is to load a pre-trained model from torchvision.models or other sources. You can then modify the final classification layer to match the number of classes in your new task.
Fine-Tuning:
Fine-tuning is a specific type of transfer learning where, after replacing the final layer, you continue training the entire model (or parts of it) on your new dataset. This allows the pre-trained weights to adapt to the specifics of your new data and task.
Implementing Transfer Learning and Fine-Tuning in PyTorch:
Load a Pre-trained Model.
Python
    import torchvision.models as models

    # Load a pre-trained ResNet-18 model
    model = models.resnet18(pretrained=True)
Modify the Output Layer.
Replace the final fully connected layer to match the number of classes in your new task.
Python
    import torch

    # Replace the final fully connected layer with one sized for the new task
    num_ftrs = model.fc.in_features
    model.fc = torch.nn.Linear(num_ftrs, num_classes_for_your_task)
Freeze Layers (Optional but Recommended for Feature Extraction).
For feature extraction, you can freeze the weights of the pre-trained layers to prevent them from being updated during training, ensuring the learned features are preserved.
Python
    # Freeze all pre-trained parameters
    for param in model.parameters():
        param.requires_grad = False

    # Re-enable gradients for the newly added final layer so it can be trained
    for param in model.fc.parameters():
        param.requires_grad = True
Define Loss Function and Optimizer.
Python
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
If you froze layers, ensure the optimizer only updates the parameters of the unfrozen layers (e.g., optimizer = torch.optim.Adam(model.fc.parameters(), lr=0.001)).
Train the Model.
Train the model on your new dataset, allowing the weights (either only the new layer or the entire network during fine-tuning) to adjust.
Python
    # Training loop (simplified)
    for epoch in range(num_epochs):
        for inputs, labels in dataloader:
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
Key Considerations:
  • Dataset Size: 
    For small datasets, feature extraction (freezing pre-trained layers) is often preferred. For larger datasets, fine-tuning the entire model can yield better performance.
  • Learning Rate: 
    When fine-tuning, it is common to use a smaller learning rate than when training from scratch to avoid rapidly altering the pre-trained weights.
  • Unfreezing Layers Gradually: 
    You can start by training only the new layers and then gradually unfreeze more layers of the pre-trained model for further fine-tuning, as sketched below.
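
As a rough illustration of gradual unfreezing, one option is to train only the new final layer for a few epochs and then unfreeze the last residual block while lowering the learning rate (a minimal sketch, assuming a ResNet-style model like the ResNet-18 loaded above):
Python
    # Phase 1: train only the newly added final layer
    optimizer = torch.optim.Adam(model.fc.parameters(), lr=1e-3)
    # ... run a few warm-up epochs with this optimizer ...

    # Phase 2: unfreeze the last residual block and continue with a smaller learning rate
    for name, param in model.named_parameters():
        if "layer4" in name:
            param.requires_grad = True
    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4
    )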



Chapter 12: Transfer Learning and Fine-Tuning


Learning Objectives

After completing this chapter, you will be able to:

  1. Understand the concept and importance of transfer learning in deep learning.

  2. Differentiate between feature extraction and fine-tuning strategies.

  3. Utilize pre-trained models available in torchvision.models.

  4. Implement practical applications of transfer learning using PyTorch.

  5. Apply transfer learning to real-world tasks such as image classification, object detection, and medical imaging.


12.1 Concept of Transfer Learning

Transfer learning is a machine learning technique where a model developed for one task is reused or adapted for another related task. Instead of training a model from scratch — which requires large datasets and extensive computation — transfer learning allows us to leverage pre-trained models that have already learned useful features from large benchmark datasets like ImageNet.

12.1.1 Motivation

  • Faster Training: Reusing learned features accelerates training time.

  • Better Performance with Less Data: Transfer learning works even when labeled data is scarce.

  • Generalization: Pre-trained models often generalize better to new tasks.

12.1.2 The Idea Behind Transfer Learning

A deep neural network learns hierarchical representations:

  • Early layers learn generic features (edges, textures, colors).

  • Middle layers learn abstract features (shapes, patterns).

  • Later layers learn task-specific features (object categories).

By transferring the early and middle layers to a new model, we can reuse these general features and train only the final layers for a new task.
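
To see this hierarchy concretely, you can list the top-level blocks of a pre-trained network; the mapping from blocks to "early", "middle", and "late" features is only approximate, but it illustrates the idea (a minimal sketch using ResNet-18):

from torchvision import models

# List the top-level blocks of a pre-trained ResNet-18
model = models.resnet18(pretrained=True)
for name, module in model.named_children():
    print(name, "->", type(module).__name__)

# conv1 / bn1 / relu / maxpool : early layers (generic edges and textures)
# layer1 ... layer4            : intermediate residual blocks (shapes, patterns)
# avgpool / fc                 : task-specific head (object categories)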


12.2 Feature Extraction and Fine-Tuning Strategies

Transfer learning typically follows one of two strategies: feature extraction and fine-tuning.

12.2.1 Feature Extraction

In feature extraction, we freeze all the convolutional layers of the pre-trained model so their weights remain unchanged. We then replace the final fully connected (classification) layer to match our new task’s output classes and train only that layer.

Advantages:

  • Fast training.

  • Less prone to overfitting.

  • Ideal when the new dataset is small or similar to the original dataset.

Example Workflow:

  1. Load a pre-trained model (e.g., ResNet18).

  2. Freeze all layers except the final classifier.

  3. Replace the classifier to match the new task.

  4. Train the new classifier on your dataset.
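
A quick way to confirm this setup is to count how many parameters will actually be updated; if everything except the classifier is frozen, the trainable count should be a small fraction of the total (a minimal sketch, assuming model is the modified network from the workflow above):

# Count trainable vs. total parameters to verify the feature-extraction setup
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Trainable parameters: {trainable:,} of {total:,}")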

12.2.2 Fine-Tuning

In fine-tuning, we unfreeze some or all layers of the pre-trained model and continue training them with a low learning rate. This allows the model to adjust pre-trained weights to better fit the new dataset.

Advantages:

  • Improved accuracy on the target task.

  • Allows the model to adapt to domain-specific features.

Guidelines:

  • Fine-tune only a few top layers if the dataset is small.

  • Use a smaller learning rate for pre-trained layers (see the sketch after this list).

  • Avoid overfitting by applying regularization or dropout.
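
One common way to give the pre-trained layers a smaller learning rate than the new classifier is to use optimizer parameter groups (a minimal sketch, assuming a ResNet-style model whose final layer is named fc):

import torch

# Separate parameter groups: small learning rate for pre-trained weights,
# larger learning rate for the newly added classifier
backbone_params = [p for name, p in model.named_parameters() if not name.startswith("fc")]
optimizer = torch.optim.Adam([
    {"params": backbone_params, "lr": 1e-5},
    {"params": model.fc.parameters(), "lr": 1e-3},
])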


12.3 Pre-trained Models from torchvision.models

The torchvision.models module provides many state-of-the-art pre-trained models trained on the ImageNet dataset. These models can be used directly for feature extraction or fine-tuning.

12.3.1 Common Pre-trained Models

Model          Architecture Type          Description
ResNet         Residual Network           Efficient deep network using skip connections.
VGG            Deep CNN                   Simple architecture with stacked convolutional layers.
DenseNet       Densely Connected CNN      Improves information flow between layers.
MobileNet      Lightweight CNN            Optimized for mobile and embedded devices.
EfficientNet   Scaled CNN                 Balances accuracy and efficiency.
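
Note that the name of the final classification layer differs across these architectures, so the replacement step looks slightly different for each; a brief sketch for two of them (assuming torchvision's standard model definitions and a 5-class task):

import torch.nn as nn
from torchvision import models

# VGG16: the classifier is an nn.Sequential; its last element is the final Linear layer
vgg = models.vgg16(pretrained=True)
vgg.classifier[6] = nn.Linear(vgg.classifier[6].in_features, 5)

# DenseNet121: the classifier is a single Linear layer
densenet = models.densenet121(pretrained=True)
densenet.classifier = nn.Linear(densenet.classifier.in_features, 5)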

12.3.2 Loading Pre-trained Models

Example: Loading and modifying a pre-trained ResNet18.

import torch
import torch.nn as nn
from torchvision import models

# Load a pre-trained ResNet18 model
model = models.resnet18(pretrained=True)

# Freeze all parameters
for param in model.parameters():
    param.requires_grad = False

# Modify the final layer to match your dataset (e.g., 5 classes)
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, 5)

print(model)
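
Note: in recent torchvision releases (0.13 and later) the pretrained argument is deprecated in favor of explicit weights enums; assuming such a version, the equivalent call is:

# Newer torchvision API: pass a weights enum instead of pretrained=True
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)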

12.3.3 Fine-Tuning Example

To fine-tune the top layers, unfreeze selected parameters:

# Unfreeze the last few layers for fine-tuning
for name, param in model.named_parameters():
    if "layer4" in name or "fc" in name:
        param.requires_grad = True
    else:
        param.requires_grad = False

Then train with a smaller learning rate:

optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)
criterion = nn.CrossEntropyLoss()

12.4 Practical Applications

Transfer learning has become indispensable across a wide range of applications due to its efficiency and performance advantages.

12.4.1 Image Classification

Pre-trained models such as ResNet, VGG, or EfficientNet can classify new image datasets (e.g., medical scans, plant diseases) with limited data.

Example:
Classifying different flower species using ResNet18 pre-trained on ImageNet.

12.4.2 Object Detection

Models like Faster R-CNN or SSD can be fine-tuned for specific object detection tasks (see the sketch after this list), such as:

  • Detecting vehicles in traffic images.

  • Identifying tumors in medical scans.
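
For example, torchvision's detection models can be adapted by replacing the box predictor head with one sized for the new classes (a minimal sketch, assuming a two-class task such as background plus vehicle):

import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

# Load a Faster R-CNN model pre-trained on COCO
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)

# Replace the box predictor head for a 2-class task (background + vehicle)
num_classes = 2
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)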

12.4.3 Medical Image Analysis

Transfer learning helps when datasets are small or annotated data is limited:

  • Detecting pneumonia in chest X-rays.

  • Identifying diabetic retinopathy in retina images.

12.4.4 Natural Language Processing (NLP)

Although this chapter focuses on computer vision, transfer learning is equally powerful in NLP:

  • Models like BERT, GPT, and RoBERTa are pre-trained on large text corpora and fine-tuned for sentiment analysis, translation, and question answering.


12.5 Practical Example: Transfer Learning with ResNet18

Below is a simplified example of transfer learning for an image classification task.

import torch
import torch.nn as nn
from torchvision import models, datasets, transforms
from torch.utils.data import DataLoader

# 1. Data Preprocessing
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    # Normalize with the ImageNet statistics expected by the pre-trained weights
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

train_data = datasets.ImageFolder('data/train', transform=transform)
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)

# 2. Load Pre-trained Model
model = models.resnet18(pretrained=True)

# 3. Freeze Parameters
for param in model.parameters():
    param.requires_grad = False

# 4. Replace Final Layer
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, len(train_data.classes))

# 5. Define Loss and Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.fc.parameters(), lr=0.001)

# 6. Training Loop
for epoch in range(5):
    for images, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
    print(f"Epoch {epoch+1}: Loss = {loss.item():.4f}")

This example demonstrates feature extraction using a pre-trained ResNet18 model. To fine-tune, simply unfreeze additional layers and use a smaller learning rate.


12.6 Key Takeaways

  • Transfer Learning saves computation and improves performance on small datasets.

  • Feature Extraction uses fixed pre-trained features, while Fine-Tuning adjusts the entire model or selected layers.

  • PyTorch provides many pre-trained models via torchvision.models.

  • Real-world applications include image recognition, object detection, medical diagnosis, and NLP tasks.


Exercises

  1. Conceptual Questions

    1. Define transfer learning and explain its advantages over training from scratch.

    2. Differentiate between feature extraction and fine-tuning.

    3. Why should a smaller learning rate be used when fine-tuning pre-trained models?

  2. Programming Tasks

    1. Load a pre-trained VGG16 model and use it for feature extraction on a custom dataset.

    2. Fine-tune the last two convolutional blocks of a ResNet34 model.

    3. Compare the training time and accuracy of a model trained from scratch versus one using transfer learning.

  3. Advanced Exercise

    • Implement transfer learning for a medical image dataset (e.g., classifying X-ray images into “Normal” or “Pneumonia”). Report your accuracy and observations.


Conclusion

Transfer learning and fine-tuning represent a major leap in deep learning practice, allowing models to build upon existing knowledge instead of starting from zero. With PyTorch’s simple access to powerful pre-trained models, developers can achieve high performance with minimal data and resources. Whether used in computer vision, NLP, or speech recognition, transfer learning continues to accelerate innovation and accessibility in artificial intelligence.
