Transfer Learning with Pretrained Models in PyTorch

Transfer learning is a technique that lets you take a model pretrained on one task and adapt it to a related task with similar data. It has revolutionized the field by enabling the development of highly accurate models even with limited data.

One of the key components of transfer learning is the use of pretrained models: neural networks trained on massive datasets, typically for image recognition tasks. Instead of training a neural network from scratch for the new task, transfer learning lets us leverage the learned representations and features from the pretrained model and adapt them to the new problem.

PyTorch, a popular deep learning framework, provides an easy-to-use platform for implementing transfer learning with pretrained models. In this article, we will explore the fundamentals of transfer learning and walk through how to use pretrained models effectively in PyTorch.

Advantages of Transfer Learning

Feature Extraction

The Convolutional Neural Networks (CNNs) used in many pretrained models act as excellent feature extractors. The early layers of these models learn basic features such as edges, textures, and simple shapes, while the later layers learn more complex features, such as object parts and how they are arranged, that are specific to the dataset the model was trained on. By reusing these learned features, we can significantly reduce the training time and data requirements for a new task.
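
To make this concrete, a pretrained CNN can serve as a frozen feature extractor by dropping its final classification layer. Below is a minimal sketch using ResNet-18; the dummy input batch and the resulting shapes are illustrative.

import torch
import torchvision.models as models

# Load a pretrained ResNet-18 and keep everything except its final classification layer
backbone = models.resnet18(pretrained=True)
feature_extractor = torch.nn.Sequential(*list(backbone.children())[:-1])
feature_extractor.eval()

# A dummy batch of 4 RGB images at 224x224 (illustrative input)
images = torch.randn(4, 3, 224, 224)
with torch.no_grad():
    features = feature_extractor(images)  # shape: (4, 512, 1, 1)
features = features.flatten(1)            # shape: (4, 512)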

Generalization

Pretrained models are trained on large, diverse datasets, which improves their ability to generalize across different tasks. Because they have seen a wide range of classes and visual variation, they can be adapted to our own specific tasks with relatively little effort. The knowledge they acquire from large-scale data allows them to identify patterns and features that are common to many tasks.

Speed and Efficiency

Training a deep neural network from scratch can be computationally intensive and time-consuming. By starting with a pretrained model, we save substantial time and computational resources. Essentially, we take a model that has already been trained on data somewhat similar to ours and replace its final few layers to repurpose it for our specific needs. This way, a complex training task is completed far more efficiently, with less time and fewer resources spent on training.

Working with Pretrained Models in PyTorch

PyTorch provides a rich ecosystem to work with pretrained models. The “torchvision” library, an extension of PyTorch, offers a collection of popular image datasets, model architectures, and pretrained models. To use these pretrained models, ensure that you have PyTorch and torchvision installed. You can install them using pip.

pip install torch torchvision

Load the Pretrained Model

In this step, we choose a suitable pretrained model from the “torchvision.models” module and load it with its pretrained weights. Popular choices include ResNet, VGG, AlexNet, and DenseNet, among others. The choice of model depends on the complexity of the task and the size of the dataset.

import torch
import torchvision.models as models

# Choose a pretrained model (e.g., ResNet-18)
model = models.resnet18(pretrained=True)

The pretrained=True argument ensures that the model is initialized with weights pretrained on the “ImageNet” dataset.
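
Note that recent torchvision releases (0.13 and later) deprecate the pretrained argument in favor of an explicit weights enum. If your installed version supports it, the equivalent call is:

# Equivalent in newer torchvision versions (0.13+)
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)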

Freeze the Model Layers

To retain the features learned by the pretrained model and prevent them from being modified during fine-tuning, we freeze its layers. The lower layers of the model capture general features like edges and textures, while the higher layers encode more task-specific information. By freezing the pretrained layers and training only the new classifier added in the next step, we let the model focus on learning task-specific patterns from the new dataset.

# Freeze all layers of the pretrained model
for param in model.parameters():
    param.requires_grad = False

Modify the Classifier

The final layers of the pretrained model need to be replaced with a new classifier that matches the number of classes in the target dataset. For instance, if you are working on a binary classification task, the classifier will have two output nodes; for a multi-class problem with N classes, the classifier will have N output nodes.

import torch.nn as nn

# Modify the classifier (for binary classification)
num_classes = 2
model.fc = nn.Linear(model.fc.in_features, num_classes)

In this example, we replace the last fully connected “model.fc” layer with a new linear layer.
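
With the classifier replaced, you can sanity-check that only its parameters will be updated during training:

# Count trainable vs. total parameters; only the new classifier should be trainable
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Trainable parameters: {trainable:,} of {total:,}")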

Train the Modified Model

With the pretrained model modified for our specific task, we can now train it using the new dataset. Since we only updated the classifier layers and froze the rest of the model, the number of trainable parameters is significantly reduced, making the training faster and more efficient.
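
The training loop below assumes a dataloader over the new dataset. One way such a dataloader might be built is with torchvision's ImageFolder; the directory path, image size, and batch size here are illustrative assumptions rather than part of the original example.

import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

# Standard ImageNet preprocessing so inputs match what the pretrained model expects
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# "data/train" is a placeholder path; ImageFolder expects one subdirectory per class
dataset = ImageFolder("data/train", transform=transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)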

# Train the modified model on the new dataset using the dataloader prepared above

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
# Only the new classifier's parameters require gradients, so we optimize just those
optimizer = torch.optim.Adam(model.fc.parameters(), lr=0.001)

# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for inputs, labels in dataloader: # Assuming dataloader contains batches of data
        optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        # Backward pass and update weights
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/len(dataloader)}")

Fine-Tuning

After training the modified model, you can choose to unfreeze some of the previously frozen layers, or even the entire model, and continue training. This fine-tuning allows the model to adapt further to the new task and dataset. It is most beneficial when the new dataset is reasonably large; on very small datasets, updating many layers risks overfitting, so unfreezing only the top few layers is usually the safer choice.
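
As one possible approach (a sketch assuming the ResNet-18 model from earlier, whose final residual block is named layer4), you can unfreeze the last block and continue training with a smaller learning rate so the pretrained weights change only gradually:

# Unfreeze the last residual block for fine-tuning
for param in model.layer4.parameters():
    param.requires_grad = True

# Re-create the optimizer over all trainable parameters with a lower learning rate
optimizer = torch.optim.Adam(
    (p for p in model.parameters() if p.requires_grad), lr=1e-4
)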

Conclusion

Transfer learning with pretrained models is a powerful technique that enables the efficient development of highly accurate deep learning models, even with very limited data. PyTorch provides a user-friendly interface through the “torchvision” library to work with pretrained models. By following the steps that are outlined in this article, you can easily adapt the pretrained models to your specific tasks and achieve impressive results.

Remember to choose a suitable pretrained model, freeze the lower layers to preserve the learned features, modify the classifier for your task, and optionally fine-tune the entire model or specific layers. Experiment with different architectures, hyperparameters, and datasets to find the best configuration for your specific application.

About the author

Zeeman Memon

Hi there! I'm a Software Engineer who loves to write about tech. You can reach out to me on LinkedIn.