Transfer Learning with PyTorch: Boosting Model Performance

In this tutorial, you’ll learn how to use transfer learning in PyTorch to significantly boost your deep learning projects. Transfer learning is about leveraging the knowledge gained from one task and applying it to another. This allows you to cut down your training time and improve the performance of your deep-learning models.

This tutorial is part of a broader series on using PyTorch and Python for deep learning. You can find the entire learning path here.

By the end of this tutorial, you’ll have learned the following:

  • What transfer learning is in deep learning and in PyTorch
  • What the different types of transfer learning are and when to use which
  • How to load a pre-trained model in PyTorch and how to fine-tune it for your specific task
  • How to apply transfer learning to your deep-learning projects through a hands-on example

What is Transfer Learning in Deep Learning

Transfer learning is about leveraging the knowledge gained from one task to improve performance in another. Specifically, it allows you to use the patterns (the weights) learned from one (or many) different deep-learning models and apply them to a new problem.

I find this concept easiest to understand when thinking of images. If you’ve read through the deep learning for image classification tutorial on convolutional neural networks, you may remember that much of the early training of our model is to find patterns in our images.

As training progresses, these learned patterns allow the model to pick up on increasingly complex features in images, such as circles, hair, and more! Because these patterns are common to many images, transfer learning allows you to reuse them and train only the tail-end of the model.

Some of the key benefits of transfer learning include:

  1. Time and Resource Savings: Training a deep learning model from scratch requires both a lot of time and resources. Transfer learning significantly reduces this by using pre-trained models, allowing you to work more efficiently.
  2. Reduced Data Requirements: Training a deep learning model from scratch requires a large amount of data. By using the learnings from similar projects, we can cut this requirement down significantly.
  3. Enhanced Performance: Pre-trained models have already captured intricate features from diverse datasets. By transferring this knowledge, your model can start with a higher level of understanding, leading to faster convergence and potentially better performance.
  4. Quick Experimentation: Transfer learning empowers you to experiment with various architectures and strategies without the need to train each model from scratch, enabling rapid iteration.

Now that you have explored some of the benefits, you might be thinking, “Why don’t I use transfer learning for everything?” Transfer learning requires similarity between the original and the target tasks. We’ll dive into this more later on, but finding a pre-trained model where the data are similar is a crucial aspect of working with transfer learning.

Types of Transfer Learning

So far, we have talked about transfer learning in a fairly abstract manner. However, there are different types of transfer learning that you’ll come across as you progress in your deep learning adventures.

In this section, we’ll explore the two main approaches to transfer learning:

  1. Fine-Tuning: Fine-tuning involves taking a pre-trained model and modifying its architecture slightly to adapt it to your target task. The idea is to retain most of the pre-trained model’s weights while only adjusting a few layers to specialize in your specific problem. This approach is particularly useful when your target task is closely related to the task the pre-trained model was originally designed for. For instance, you might use a pre-trained image classification model as a base and fine-tune it for a similar but more specific task, like identifying plant species within a subset of the original categories.
  2. Feature Extraction: Feature extraction entails using the pre-trained model as a feature extractor, where you remove the final classification layer and only use the learned features from the earlier layers. These features are then input to a new classifier that’s trained for your target task. Feature extraction is a go-to strategy when your target task has a different domain or structure compared to the original pre-trained task. For instance, you could employ a pre-trained model designed for image classification to extract features for a completely distinct task, like facial expression recognition.

Choosing the appropriate type of transfer learning depends on the nature of your task, the availability of data, and the level of similarity between the pre-trained model’s original task and your target task. In the following sections, I will delve into practical examples of these transfer learning types, providing you with the tools and insights to implement each strategy effectively.
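The difference between the two approaches comes down to which weights are allowed to change. Here is a minimal sketch, assuming a tiny stand-in network built with torch.nn (the make_model and count_trainable helpers are hypothetical and only for illustration, a real pre-trained model is far deeper):

```python
import torch
from torch import nn

# Hypothetical stand-in for a pre-trained network: a small "backbone"
# followed by a classification "head".
def make_model(num_outputs: int = 10) -> nn.Sequential:
    backbone = nn.Sequential(
        nn.Linear(8, 16), nn.ReLU(),
        nn.Linear(16, 16), nn.ReLU(),
    )
    head = nn.Linear(16, num_outputs)
    return nn.Sequential(backbone, head)

def count_trainable(model: nn.Module) -> int:
    # Number of individual weights that will receive gradient updates
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Feature extraction: freeze everything, then attach a fresh classifier.
feature_extractor = make_model()
for param in feature_extractor.parameters():
    param.requires_grad = False
feature_extractor[1] = nn.Linear(16, 2)  # new layers are trainable by default

# Fine-tuning: same new classifier, but the last backbone layer is also unfrozen.
fine_tuned = make_model()
for param in fine_tuned.parameters():
    param.requires_grad = False
fine_tuned[1] = nn.Linear(16, 2)
for param in fine_tuned[0][2].parameters():
    param.requires_grad = True

print(count_trainable(feature_extractor))  # 34  (16*2 weights + 2 biases)
print(count_trainable(fine_tuned))         # 306 (34 + 16*16 weights + 16 biases)
```

In both cases the bulk of the network stays frozen. Fine-tuning simply lets a few more of the later layers adapt to the new task.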

Loading Pre-Trained Models in PyTorch for Transfer Learning

While transfer learning can be applied to many different domains, such as natural language processing, we’ll focus our attention on image classification. Personally, I find that this makes the process more intuitive (and fun since we can visualize our results better).

PyTorch, through the torchvision library, comes with a number of different pre-trained models. Before we dive into choosing one, let’s explore some of the considerations that we need to keep in mind:

  1. Understand Model Architectures: Before selecting a pre-trained model, we need to understand the architecture and structure of different models. Models like VGG, ResNet, Inception, and MobileNet have varying depths and complexities, which may influence their suitability for your task.
  2. Match Task Complexity: Consider the complexity of your target task. For instance, if you’re dealing with simple image classification, a lighter model like MobileNet might suffice. On the other hand, if your task involves intricate object detection, a deeper model like ResNet or Inception might be more appropriate.
  3. Datasets and Domains: Pay attention to the datasets the pre-trained models were initially trained on. Models trained on diverse and extensive datasets, such as ImageNet, tend to have more generalizable features. If your dataset is similar to the one the model was trained on, your chances of success are higher.
  4. Computational Resources: Deeper models usually require more computational power and memory for both training and inference. Consider the resources available to you when choosing a model, as it could impact your ability to fine-tune effectively.
  5. Model Flexibility: Some models are designed to be more flexible, allowing you to modify and adapt their architecture more easily. This flexibility can be advantageous when you need to tailor the model to your specific task.
  6. Community Support: Models with strong community support often have a wealth of resources, tutorials, and pre-trained weights available, which can expedite your transfer learning process.

Now that you have a strong understanding of what to look for in a model, let’s dive into how to load a pre-trained model for transfer learning in PyTorch. For this tutorial, we’ll use a dataset containing two different types of flowers to see how well (and quickly) we can build an effective image classifier.

You can find the original dataset here on Kaggle, though the dataset is also included in the repository for convenience.

Because we’re hoping to create an image classifier, it makes sense to use a model such as ResNet. PyTorch comes bundled with a number of different versions of the network, but we’ll use the resnet50 model.

# Importing Libraries
import torch
from torch.utils.data import DataLoader, random_split
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.datasets import ImageFolder

# Setting up a Device
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load the pre-trained ResNet-50 model with its default weights
weights = models.ResNet50_Weights.DEFAULT
resnet50 = models.resnet50(weights=weights).to(device)

In the code block above, we accomplished a good number of things:

  1. We imported the libraries and modules that we need to work with
  2. We then set up our code to be device-agnostic, allowing us to switch between GPU or CPU support, as available
  3. Finally, we loaded our pre-trained model.

Beginning in version 0.13, torchvision introduced the method of loading a pre-trained model shown above. We first get the DEFAULT weights for the model, which point to the best available pre-trained weights. Then, we pass these weights to the resnet50 architecture. Finally, we send the model and its weights to the assigned device.

There you have it! You have successfully loaded a pre-trained model in PyTorch. Let’s now dive into how we can work with this model and freeze some of the weights.

Freezing a Pre-Trained Model for Transfer Learning in PyTorch

The next step in transfer learning is to prepare our pre-trained model. This generally involves two steps:

  1. Adapting the model to fit our current context, and
  2. Freezing some of the layers so that we can keep the learned weights

Let’s tackle the first step now. How you adapt the model will largely depend on the problem you’re facing. Our pre-trained model was trained on 1000 different classes, but we’re only working with two. Because of this, we’ll need to modify the final layer in our model to output our expected number of classes.

How do we know which layer to modify? Well, let’s start by printing out the model, which displays the architecture of the model. Given the model’s depth, I have significantly truncated the output and only kept some of the relevant pieces for our work.

# Printing the ResNet50 Model
print(resnet50)

# Returns:
# ResNet(
# ...
#   (layer4): Sequential(
#     (0): Bottleneck(
#       (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
#       (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#     ...
#   (fc): Linear(in_features=2048, out_features=1000, bias=True)
# )

Printing the model shows that the architecture is massive and quite deep. In the truncated printout above, there are two elements I want you to pay attention to:

  1. The final layer, fc, outputs 1000 features. This is the element we’ll need to modify in order to predict our actual number of classes.
  2. layer4, which itself contains a number of different layers, is the final set of convolutional layers. These are the layers for which we will unfreeze the parameters, allowing us to train these weights as we go through fine-tuning.

Let’s now modify our final layer, accepting the number of classes we want our model to predict:

# Modify the final classification layer for our specific task
num_classes = 2
# The new layer is created on the CPU, so move it to the model's device
resnet50.fc = torch.nn.Linear(resnet50.fc.in_features, num_classes).to(device)

In the code block above, we replaced the final layer with a new one that outputs only two values, one per class, rather than the original 1000. The new layer keeps the same number of input features and is moved to the same device as the rest of the model. Its weights are randomly initialized, so it still needs to be trained.

Now, we’ll take a look at how we can freeze some of the layers and keep some of them unfrozen.

# Freeze earlier layers
for param in resnet50.parameters():
    param.requires_grad = False

unfrozen_layers = ['layer4', 'fc', 'avgpool']  
for name, param in resnet50.named_parameters():
    if any(layer_name in name for layer_name in unfrozen_layers):
        param.requires_grad = True

In the code block above, we first freeze all of the layers by setting requires_grad = False on every parameter. We then define a list of layers that we want to unfreeze. We loop over the names and parameters returned by the .named_parameters() method of our model and use the any() function to check whether any of the layer names appears in the parameter’s name. If one does, we unfreeze that parameter. Note that avgpool has no learnable parameters, so including it in the list has no effect.

Let’s now explore how we can prepare our data for our model.

Preparing Our Data For Transfer Learning in PyTorch

In this section, we’ll explore how we can prepare our images to fit into our resnet50 model. From the documentation, the model expects images of dimensions 224x224x3. For the training data, we can also apply a number of augmenting transformations to help our model learn better. For the validation data, we only apply the transformations needed to put the images into the format the model expects.

Let’s take a look at how this works:

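# Define your desired transform for the training data (augmentations shown are a typical choice)
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),    # random crop, resized to 224x224
    transforms.RandomHorizontalFlip(),    # randomly mirror the image
    transforms.ToTensor(),                # PIL image -> float tensor in [0, 1]
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# No augmentation for the test data, only resizing and normalization
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])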

In the training and testing transformations, we created transformation compositions. We use two different ones, since we want our model to use augmented data during learning, but use true data while validating.

Let’s now create our dataset, using the ImageFolder class and apply our transformations:

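# Creating our Training and Testing Dataset
data_dir = 'data/transfer-learning-flowers/'

# Two views of the same folder, one per transform. Subsets returned by
# random_split share their underlying dataset, so a single ImageFolder
# could only carry one transform for both splits.
train_data = ImageFolder(root=data_dir, transform=train_transform)
test_data = ImageFolder(root=data_dir, transform=test_transform)

# Split both views with the same seed so the indices do not overlap
train_dataset, _ = random_split(
    train_data, [0.7, 0.3], generator=torch.Generator().manual_seed(42))
_, test_dataset = random_split(
    test_data, [0.7, 0.3], generator=torch.Generator().manual_seed(42))

In the code block above, we created two ImageFolder views of the same directory, one for each set of transformations. We then split both using the random_split function with the same seed, keeping the training portion of the first and the testing portion of the second, so that the two sets do not overlap.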
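One detail worth checking when combining random_split with per-split transformations: both Subset objects wrap the same underlying dataset, so a transform assigned through .dataset is shared between them. A minimal sketch with a hypothetical ToyDataset (integers stand in for images):

```python
import torch
from torch.utils.data import Dataset, random_split

class ToyDataset(Dataset):
    """Hypothetical dataset of the integers 0-9 with an optional transform."""
    def __init__(self):
        self.data = list(range(10))
        self.transform = None

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        return self.transform(item) if self.transform else item

toy = ToyDataset()
train_subset, test_subset = random_split(
    toy, [7, 3], generator=torch.Generator().manual_seed(0))

# Both subsets point at the very same dataset object
shares_dataset = train_subset.dataset is test_subset.dataset
print(shares_dataset)  # True

# The second assignment silently replaces the first
train_subset.dataset.transform = lambda x: x * 100
test_subset.dataset.transform = lambda x: -x

# The training subset now uses the "test" transform (value is <= 0)
first_train_item = train_subset[0]
print(first_train_item)
```

Giving each split its own dataset object, or wrapping each subset in a small class that applies its own transform, avoids this.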

Let’s now create PyTorch DataLoaders for both training and testing. This will allow us to batch our data when we train and test the model. Let’s see how we can do this, setting a batch size of 32 (meaning that we’ll load 32 images at a time):

# Defining Training and Testing DataLoaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)  # order doesn't matter for evaluation
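To get a feel for what a DataLoader yields, here is a small sketch using random tensors in place of the flower images (the 70 samples and the 32x32 image size are arbitrary choices to keep the example light):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# 70 fake RGB images with binary labels
images = torch.randn(70, 3, 32, 32)
labels = torch.randint(0, 2, (70,))
loader = DataLoader(TensorDataset(images, labels), batch_size=32, shuffle=True)

# Each iteration yields one batch of inputs and labels
batch_sizes = [batch_images.shape[0] for batch_images, _ in loader]
print(len(loader))   # 3
print(batch_sizes)   # [32, 32, 6]
```

The last batch is smaller because 70 is not a multiple of 32. Passing drop_last=True would discard it instead.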

Now that we have our transformations defined and our datasets and DataLoaders created, we can finally begin applying transfer learning to our model in PyTorch!

Using Transfer Learning to Train Our Model in PyTorch

The process of training and evaluating a model using transfer learning works in the same way as training and evaluating a regular model in PyTorch. We’ll follow a similar process as outlined in my guide on developing deep learning models in PyTorch:

  1. We’ll define a criterion using the torch.nn module, in particular using the cross-entropy loss to evaluate our performance
  2. We’ll then define an optimizer. In this case, we’ll use stochastic gradient descent
  3. Then, we’ll loop over a number of epochs (in each epoch, every data point is seen once). In each epoch, we first train the model, optimizing the unfrozen parameters. From there, we also validate our model inside the torch.no_grad() context manager, which lets us measure accuracy without tracking gradients.

Let’s see what this looks like in Python:

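# Define loss function and optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(resnet50.parameters(), lr=0.001, momentum=0.9)
num_epochs = 5

# Training loop
for epoch in range(num_epochs):
    resnet50.train()  # training behavior for batch norm layers
    for inputs, labels in train_loader:
        # Move the batch to the same device as the model
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()  # clear gradients from the previous step
        outputs = resnet50(inputs)
        loss = criterion(outputs, labels)
        loss.backward()        # compute gradients for the unfrozen layers
        optimizer.step()       # update the unfrozen weights

    # Validation loop
    resnet50.eval()  # inference behavior for batch norm layers
    with torch.no_grad():
        total_correct = 0
        total_samples = 0
        for val_inputs, val_labels in test_loader:
            val_inputs, val_labels = val_inputs.to(device), val_labels.to(device)
            val_outputs = resnet50(val_inputs)
            _, predicted = torch.max(val_outputs, 1)
            total_samples += val_labels.size(0)
            total_correct += (predicted == val_labels).sum().item()

        accuracy = total_correct / total_samples
        print(f'Epoch [{epoch+1}/{num_epochs}], Validation Accuracy: {accuracy:.4f}')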

# Returns:
# Epoch [1/5], Validation Accuracy: 0.9519
# Epoch [2/5], Validation Accuracy: 0.9852
# Epoch [3/5], Validation Accuracy: 0.9926
# Epoch [4/5], Validation Accuracy: 0.9852
# Epoch [5/5], Validation Accuracy: 0.9889

We can see that the model’s accuracy is quite high from the very first epoch (over 95%!). This is one of the huge benefits of transfer learning.

Because the model has already learned the generalized patterns of image data, we save the time and resources it would otherwise take to learn them from scratch.
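If you want to confirm that training only touches the unfrozen weights, you can compare them before and after an optimizer step. The sketch below uses a tiny stand-in network and random data rather than resnet50, so that it runs in well under a second. All names in it are illustrative:

```python
import torch
from torch import nn

torch.manual_seed(0)

# Tiny stand-in: a frozen "backbone" layer and a trainable "head"
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 2))
for param in model[0].parameters():
    param.requires_grad = False

# Snapshots of the weights before training
frozen_before = model[0].weight.clone()
head_before = model[2].weight.clone()

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
    [p for p in model.parameters() if p.requires_grad], lr=0.1, momentum=0.9)

# One training step on random data
inputs = torch.randn(32, 8)
labels = torch.randint(0, 2, (32,))
optimizer.zero_grad()
loss = criterion(model(inputs), labels)
loss.backward()
optimizer.step()

frozen_unchanged = torch.equal(frozen_before, model[0].weight)
head_changed = not torch.equal(head_before, model[2].weight)
print(frozen_unchanged, head_changed)  # True True
```

The same check can be applied to resnet50 by comparing, for example, a weight in layer1 with one in layer4.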


In conclusion, this tutorial has provided you with a comprehensive understanding of transfer learning in PyTorch and its potential to significantly enhance your deep learning projects. Transfer learning empowers you to leverage the knowledge gained from one task and apply it to another, resulting in reduced training time and improved model performance.

Throughout this tutorial, you’ve gained insights into the following key aspects:

  1. Definition and Benefits of Transfer Learning: Transfer learning involves utilizing patterns learned from one deep learning model and applying them to a new problem. This approach offers benefits such as time and resource savings, reduced data requirements, enhanced performance, and quick experimentation.
  2. Types of Transfer Learning: You’ve explored two main approaches—fine-tuning and feature extraction. Fine-tuning involves adapting a pre-trained model’s architecture to your target task, while feature extraction leverages the pre-trained model’s learned features for a new classifier. Choosing the appropriate approach depends on task similarity, data availability, and model structure.
  3. Loading Pre-Trained Models: The tutorial guided you through the process of loading pre-trained models in PyTorch, considering factors like model architecture, task complexity, datasets, computational resources, model flexibility, and community support.
  4. Modifying and Freezing Model Layers: You’ve learned how to modify the final classification layer of a pre-trained model to suit your task’s specific number of classes. Additionally, you explored how to freeze and unfreeze specific layers for fine-tuning.
  5. Data Preparation: The tutorial covered data preparation steps, including defining transformations, creating datasets, and setting up DataLoaders for both training and testing.
  6. Applying Transfer Learning: The final section demonstrated how to train and evaluate your model using transfer learning. By following the typical training loop and validation process, you were able to achieve high accuracy with minimal epochs due to the benefits of transfer learning.

Incorporating transfer learning into your deep learning projects can be a game-changer, allowing you to build effective models with less effort and time. Armed with the knowledge gained from this tutorial, you’re now equipped to confidently apply transfer learning to your own projects, taking advantage of pre-trained models to achieve impressive results. Keep experimenting and exploring the world of deep learning to further enhance your skills and capabilities.

Nik Piepenbreier

Nik has over a decade of experience working with data analytics, data science, and Python. He specializes in teaching developers how to use Python for data science using hands-on tutorials.
