
How to Save and Load Model Weights in PyTorch?

In PyTorch, model weights are the learnable parameters of a neural network that are updated during training. They enable the model to make accurate predictions. These weights are stored in the model's “state_dict”, a dictionary that maps each parameter name to its tensor. PyTorch provides functions to save and load them.
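To make the “state_dict” idea concrete, here is a minimal sketch that prints the entries of a very small model. The “tiny_model” network is a hypothetical stand-in chosen only to keep the output short; it is not part of the steps in this blog.

```python
import torch
import torch.nn as nn

# A tiny stand-in model (hypothetical, for illustration only)
tiny_model = nn.Sequential(
    nn.Linear(4, 3),
    nn.ReLU(),
    nn.Linear(3, 2),
)

# state_dict() maps each parameter name to its tensor
state = tiny_model.state_dict()
for name, tensor in state.items():
    print(name, tuple(tensor.shape))

# Prints:
# 0.weight (3, 4)
# 0.bias (3,)
# 2.weight (2, 3)
# 2.bias (2,)
```

The ReLU layer has no learnable parameters, so it contributes no entries.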

This blog will illustrate the method of saving and loading particular model weights in PyTorch.

How to Save and Load Specific Model Weights in PyTorch?

To save and load particular model weights in PyTorch, follow the provided steps:

Step 1: Import PyTorch Library and Modules

First, import the required libraries and modules. Here, we import the main “torch” library, the “torchvision.models” module, which provides pre-trained model architectures, and the “torch.nn” module for building neural network layers:

import torch
import torchvision.models as models
import torch.nn as nn

Step 2: Build the Model and Define Weights

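Then, create a specific model and define its weights. For instance, we have used the “ResNet-50” model and initialized it with the “IMAGENET1K_V2” weights, which were pre-trained on the “ImageNet” dataset. The final fully connected layer is then replaced so that the model outputs “num_classes” classes: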

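def build_model(
  mod_weights='IMAGENET1K_V2',
  fine_tune=False,  # accepted but not used in this snippet
  num_classes=10
):
  # Load ResNet-50 with the requested pre-trained weights
  model = models.resnet50(weights=mod_weights)
  # Replace the final layer; ResNet-50 feeds 2048 features into it
  model.fc = nn.Linear(in_features=2048, out_features=num_classes)
  return model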

Step 3: Save Model Weights Using the “torch.save()” Method

To save model weights, first create a model with pre-trained weights. Then, use the “torch.save()” function to write the model's “state_dict” (which includes the weights) to a file. For instance, we are saving the weights to a file called “savedModel_weights.pth”:

model = models.resnet50(weights='IMAGENET1K_V2')

torch.save(model.state_dict(), 'savedModel_weights.pth')

Upon doing so, the “savedModel_weights.pth” file is created in the current working directory.

The model weights have been saved successfully.
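If only part of a model needs to be stored, for example just the final layer, the “state_dict” can be filtered before saving because it is an ordinary dictionary. The sketch below illustrates this on a small hypothetical model (“TinyNet”) with a hypothetical file name, written to a temporary directory. It is one possible approach rather than the only way to do it:

```python
import os
import tempfile

import torch
import torch.nn as nn


# Hypothetical small model standing in for a larger network
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 4)
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(torch.relu(self.backbone(x)))


source = TinyNet()

# Keep only the entries whose names start with "fc."
fc_weights = {k: v for k, v in source.state_dict().items() if k.startswith("fc.")}

path = os.path.join(tempfile.mkdtemp(), "fc_only_weights.pth")
torch.save(fc_weights, path)

# strict=False lets load_state_dict accept a partial dictionary
target = TinyNet()
result = target.load_state_dict(torch.load(path), strict=False)
print(result.missing_keys)  # ['backbone.weight', 'backbone.bias']
```

The keys reported as missing keep whatever values the target model was initialized with, so it is worth checking that list to confirm nothing was skipped by accident.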

Step 4: Load Model Weights

To load model weights, create an instance of the same model architecture, and then load the parameters using the “load_state_dict()” method. Calling “model.eval()” afterwards switches the model to inference mode:

model = models.resnet50()

model.load_state_dict(torch.load('savedModel_weights.pth'))

model.eval()
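When weights are saved on one device and loaded on another, “torch.load()” accepts a “map_location” argument, and recent PyTorch versions also accept “weights_only=True” to restrict what is unpickled. The following sketch does a save and load round trip and checks that the restored tensors match. It uses a small hypothetical model and a temporary file instead of ResNet-50 to keep it fast, and it assumes a PyTorch version that supports “weights_only”:

```python
import os
import tempfile

import torch
import torch.nn as nn


# Hypothetical small model used in place of ResNet-50 to keep the example fast
def make_model():
    return nn.Sequential(nn.Linear(6, 3), nn.ReLU(), nn.Linear(3, 1))


original = make_model()
path = os.path.join(tempfile.mkdtemp(), "roundtrip_weights.pth")
torch.save(original.state_dict(), path)

# map_location="cpu" remaps tensors saved on a GPU onto the CPU;
# weights_only=True restricts unpickling to tensors and basic containers
state = torch.load(path, map_location="cpu", weights_only=True)

restored = make_model()
restored.load_state_dict(state)
restored.eval()  # switch layers such as dropout and batch norm to inference mode

# Every tensor in the restored model should match the original exactly
matches = all(
    torch.equal(original.state_dict()[k], restored.state_dict()[k])
    for k in original.state_dict()
)
print(matches)  # True
```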

That covers the method to save and load model weights in PyTorch.

Conclusion

To save and load model weights in PyTorch, first import the required PyTorch libraries and modules. Then, build a model and define its weights. Next, save the model's “state_dict” to a file using the “torch.save()” function, and finally load it into a model of the same architecture using the “load_state_dict()” method. This blog has illustrated the method of saving and loading model weights in PyTorch.

About the author

Laiba Younas

I have a bachelor's degree in Computer Science. Being passionate about learning new technologies, I am interested in exploring different programming languages and sharing my experience with the world.