How to Use the load_state_dict() Method in PyTorch?

PyTorch is a widely used machine learning framework. Every PyTorch model exposes its learned parameters, along with any persistent buffers, through a “state_dict”: a dictionary that maps each parameter name to its tensor. The “state_dict()” method returns this dictionary, and the “load_state_dict()” method copies the values from a given “state_dict” into a model, replacing its existing parameters.
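As a minimal sketch of this idea, the snippet below copies the parameters of one layer into another entirely in memory, with no file involved. The two “nn.Linear” layers and the variable names are hypothetical and chosen only for illustration:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Two independently initialized layers with the same shape (illustrative only).
source = nn.Linear(3, 1)
target = nn.Linear(3, 1)

# state_dict() maps parameter names to their tensors.
state = source.state_dict()
param_names = list(state.keys())
print(param_names)

# load_state_dict() copies those tensors into target, overwriting its parameters.
target.load_state_dict(state)

weights_match_after = (
    torch.equal(source.weight, target.weight)
    and torch.equal(source.bias, target.bias)
)
print(weights_match_after)
```

After the call, both layers hold identical values, which is the same replacement that happens when the dictionary is read from a file instead.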

This blog illustrates how to use the “load_state_dict()” method in PyTorch.

How to Use “load_state_dict()” Method in PyTorch

To use the “load_state_dict()” method in PyTorch, follow the steps below:

Step 1: Import PyTorch Library and Modules

First, import the required libraries and modules. Here, we import the “torch” library and its “torch.nn” module, which provides the building blocks for neural network layers and architectures:

import torch

import torch.nn as nn

Step 2: Define the Model

Then, define or download the model. For instance, we define a simple neural network with a single linear layer followed by a sigmoid activation, and create an instance of it that takes six input features:

class Model(nn.Module):

  def __init__(self, input_feat):

    super(Model, self).__init__()

    self.linear = nn.Linear(input_feat, 1)

  def forward(self, x):

    y_pred = torch.sigmoid(self.linear(x))

    return y_pred

model = Model(input_feat=6)
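To see what the “state_dict” of this model contains, it can be inspected directly. The following is a small sketch that repeats the model definition so it runs on its own; the “shapes” variable is a hypothetical helper name:

```python
import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self, input_feat):
        super().__init__()
        self.linear = nn.Linear(input_feat, 1)

    def forward(self, x):
        return torch.sigmoid(self.linear(x))


model = Model(input_feat=6)

# Each key has the form "<attribute name>.<parameter name>".
shapes = {name: tuple(tensor.shape) for name, tensor in model.state_dict().items()}
for name, shape in shapes.items():
    print(name, shape)
# linear.weight (1, 6)
# linear.bias (1,)
```

The key names come from the attribute name “linear” used in the class, which is why the model that later loads the dictionary must use the same attribute names.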

Step 3: Save Model Using “state_dict()” Method

To save the parameters of a PyTorch model, call the “torch.save()” function and pass it two arguments: the dictionary returned by the model's “state_dict()” method and the file path:

torch.save(model.state_dict(), 'model.pth')

Here:

  • “model” is our PyTorch model.
  • “model.state_dict()” is the state dictionary of the model.
  • “model.pth” is the name of the file where we want to save our model.

Subsequently, the “model.pth” file is created in the current working directory.

The model's parameters have been saved using the “state_dict()” method, and now we can load them using the “load_state_dict()” method.
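One way to confirm what was written to disk is to read the file back and look at its contents. The sketch below assumes a temporary directory rather than the current directory, and uses a plain “nn.Linear” layer as a stand-in for the model; both choices are only for illustration:

```python
import os
import tempfile

import torch
import torch.nn as nn

# Stand-in for the model defined in Step 2 (illustrative only).
model = nn.Linear(6, 1)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "model.pth")
    torch.save(model.state_dict(), path)

    file_exists = os.path.exists(path)

    # torch.load() returns the saved dictionary, not a model object.
    saved_state = torch.load(path)
    saved_keys = list(saved_state.keys())

print(file_exists, saved_keys)
```

Because the file holds only a dictionary of tensors, it carries no information about the model class itself. This is why the next step recreates the architecture before loading.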

Step 4: Load PyTorch Model Using “load_state_dict()” Method

To load a PyTorch model using the “load_state_dict()” method, first create an instance of the same model architecture that produced the saved “state_dict”. Then, read the file with “torch.load()” and pass the result to the “load_state_dict()” method of the new instance:

loaded_model = Model(input_feat=6)

loaded_model.load_state_dict(torch.load('model.pth'))

loaded_model.eval()

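We have successfully loaded the model using the “load_state_dict()” method in PyTorch. The “eval()” call switches the model to evaluation mode, which matters for layers such as dropout and batch normalization that behave differently during training.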
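To check that the round trip worked, the original and loaded models can be compared on the same input. The following sketch makes a few assumptions that are not part of the steps above: it saves to a temporary directory, passes the optional “map_location” argument to keep tensors on the CPU, and adds a deliberately mismatched model (“wrong_model”) to show what happens when the architecture differs:

```python
import os
import tempfile

import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self, input_feat):
        super().__init__()
        self.linear = nn.Linear(input_feat, 1)

    def forward(self, x):
        return torch.sigmoid(self.linear(x))


torch.manual_seed(0)
model = Model(input_feat=6)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "model.pth")
    torch.save(model.state_dict(), path)

    loaded_model = Model(input_feat=6)
    # map_location="cpu" is optional; it places the loaded tensors on the CPU.
    loaded_model.load_state_dict(torch.load(path, map_location="cpu"))
    loaded_model.eval()

# Identical parameters should give identical predictions.
x = torch.randn(4, 6)
with torch.no_grad():
    same_output = torch.allclose(model(x), loaded_model(x))
print(same_output)

# A model with a different input size cannot accept the saved parameters.
wrong_model = Model(input_feat=3)
try:
    wrong_model.load_state_dict(model.state_dict())
    mismatch_detected = False
except RuntimeError:
    mismatch_detected = True
print(mismatch_detected)
```

The failed load illustrates why the new instance must match the architecture that produced the “state_dict”: the tensor shapes have to agree with the layers that receive them.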

Conclusion

To use the “load_state_dict()” method in PyTorch, first save the model's parameters by passing “model.state_dict()” to “torch.save()”. Then, create an instance of the same architecture, read the file with “torch.load()”, and pass the result to “load_state_dict()”. This copies the saved state dictionary into the new model. This blog has illustrated how to use the “load_state_dict()” method in PyTorch.

About the author

Laiba Younas

I hold a bachelor's degree in Computer Science. Being passionate about learning new technologies, I am interested in exploring different programming languages and sharing my experience with the world.