This blog will illustrate how to use the “load_state_dict()” method in PyTorch.
How to Use “load_state_dict()” Method in PyTorch
To use the “load_state_dict()” method in PyTorch, follow the steps below:
- Import the PyTorch library and modules
- Define the model
- Save the model's “state_dict”
- Load the model using the “load_state_dict()” method
Step 1: Import PyTorch Library and Modules
First, import the required libraries and modules. For instance, we have imported the “torch” library and the “torch.nn” module for building neural network layers and architectures:
import torch
import torch.nn as nn
Step 2: Define the Model
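Then, create or download the model. For instance, we have defined a simple neural network model here, consisting of a single linear layer followed by a sigmoid activation:
class Model(nn.Module):
    def __init__(self, input_feat):
        super(Model, self).__init__()
        # One fully connected layer mapping input_feat features to a single output
        self.linear = nn.Linear(input_feat, 1)
    def forward(self, x):
        # Sigmoid squashes the linear output into the range (0, 1)
        y_pred = torch.sigmoid(self.linear(x))
        return y_pred
model = Model(input_feat=6)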
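To see what “state_dict()” actually holds for this model, the short sketch below rebuilds the same “Model” class and prints each entry. A state dictionary maps parameter (and buffer) names to tensors; for a single “nn.Linear(6, 1)” layer stored under the attribute name “linear”, it should contain a weight of shape (1, 6) and a bias of shape (1,).

```python
import torch
import torch.nn as nn


class Model(nn.Module):
    """Same single-layer logistic model as in the step above."""

    def __init__(self, input_feat):
        super().__init__()
        self.linear = nn.Linear(input_feat, 1)

    def forward(self, x):
        return torch.sigmoid(self.linear(x))


model = Model(input_feat=6)

# state_dict() returns an ordered mapping: parameter name -> tensor
state = model.state_dict()
for name, tensor in state.items():
    print(name, tuple(tensor.shape))
# linear.weight (1, 6)
# linear.bias (1,)
```

Because the keys are derived from attribute names, the model you later load into must define its layers under the same names.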
Step 3: Save Model Using “state_dict()” Method
To save a PyTorch model using the “state_dict()” method, call the “torch.save()” function and pass it the model's state dictionary (the value returned by “model.state_dict()”) and the file path as arguments:
torch.save(model.state_dict(), 'model.pth')
Here:
- “model” is our PyTorch model.
- “model.state_dict()” returns the model's state dictionary, an ordered mapping from parameter names to their tensors.
- “model.pth” is the file name where we want to save our model.
After this code runs, a “model.pth” file is created in the current working directory.
The model's state dictionary has now been saved, and we can load it using the “load_state_dict()” method.
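If you also plan to resume training later, a common pattern (described in the official PyTorch saving/loading tutorial) is to save a dictionary that bundles the model's state dictionary with the optimizer's. The sketch below assumes an SGD optimizer with a learning rate of 0.01, an epoch counter, and a file named “checkpoint.pth” inside a temporary directory; all of these are illustrative choices, not part of the steps in this blog.

```python
import os
import tempfile

import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self, input_feat):
        super().__init__()
        self.linear = nn.Linear(input_feat, 1)

    def forward(self, x):
        return torch.sigmoid(self.linear(x))


model = Model(input_feat=6)
# Assumption: plain SGD; every torch.optim optimizer exposes state_dict() the same way
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

ckpt_dir = tempfile.mkdtemp()
ckpt_path = os.path.join(ckpt_dir, "checkpoint.pth")  # hypothetical file name

# Bundle everything needed to resume training into one dictionary
torch.save(
    {
        "epoch": 5,  # hypothetical epoch counter
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    },
    ckpt_path,
)

# Later: rebuild the objects first, then restore each piece from the checkpoint
restored_model = Model(input_feat=6)
restored_opt = torch.optim.SGD(restored_model.parameters(), lr=0.01)
checkpoint = torch.load(ckpt_path)
restored_model.load_state_dict(checkpoint["model_state_dict"])
restored_opt.load_state_dict(checkpoint["optimizer_state_dict"])
start_epoch = checkpoint["epoch"] + 1
```

Note that both the model and the optimizer must be constructed before their “load_state_dict()” methods can be called; the checkpoint only stores values, not the objects themselves.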
Step 4: Load PyTorch Model Using “load_state_dict()” Method
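To load a PyTorch model using the “load_state_dict()” method, first create an instance of the same model architecture whose “state_dict” was saved. Then, call “load_state_dict()” on it with the dictionary returned by “torch.load()”, and call “eval()” to switch the model to evaluation mode before running inference:
loaded_model = Model(input_feat=6)
# torch.load() reads the saved dictionary; load_state_dict() copies it into the model
loaded_model.load_state_dict(torch.load('model.pth'))
# Put layers such as dropout and batch norm into inference mode
loaded_model.eval()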
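A quick way to confirm that “load_state_dict()” really copied the weights is to compare the outputs of the original and the reloaded model on the same input. The sketch below uses an in-memory “io.BytesIO” buffer instead of “model.pth” so it does not touch the disk, and passes “map_location="cpu"” to “torch.load()” so that a state dictionary saved on a GPU can still be read on a CPU-only machine; the random batch of four samples is an arbitrary illustration.

```python
import io

import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self, input_feat):
        super().__init__()
        self.linear = nn.Linear(input_feat, 1)

    def forward(self, x):
        return torch.sigmoid(self.linear(x))


torch.manual_seed(0)
model = Model(input_feat=6)

# Save the state dictionary to an in-memory buffer instead of a file
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)
buffer.seek(0)

# Fresh instance with its own randomly initialized weights
loaded_model = Model(input_feat=6)
loaded_model.load_state_dict(torch.load(buffer, map_location="cpu"))
loaded_model.eval()
model.eval()

x = torch.randn(4, 6)  # arbitrary batch of 4 samples with 6 features each
with torch.no_grad():
    same = torch.equal(model(x), loaded_model(x))
print(same)  # True
```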
We have successfully loaded the model using the “load_state_dict()” method in PyTorch.
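By default, “load_state_dict()” is strict: every key in the saved dictionary must match a key in the model, and vice versa. When two architectures only partly overlap (for example, reusing a saved layer inside a larger model), you can pass “strict=False”; the method then returns a named tuple listing the “missing_keys” and “unexpected_keys” instead of raising an error. In the sketch below, the “BiggerModel” class and its extra “head” layer are hypothetical and exist only to create a key mismatch.

```python
import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self, input_feat):
        super().__init__()
        self.linear = nn.Linear(input_feat, 1)

    def forward(self, x):
        return torch.sigmoid(self.linear(x))


class BiggerModel(nn.Module):
    """Hypothetical model: same 'linear' layer plus an extra 'head' layer."""

    def __init__(self, input_feat):
        super().__init__()
        self.linear = nn.Linear(input_feat, 1)
        self.head = nn.Linear(1, 1)

    def forward(self, x):
        return torch.sigmoid(self.head(self.linear(x)))


saved = Model(input_feat=6).state_dict()
bigger = BiggerModel(input_feat=6)

# With strict=True (the default) this call would raise a RuntimeError,
# because 'head.weight' and 'head.bias' have no saved values.
result = bigger.load_state_dict(saved, strict=False)
print(result.missing_keys)     # ['head.weight', 'head.bias']
print(result.unexpected_keys)  # []
```

Keep in mind that “strict=False” only relaxes key matching: a saved tensor whose shape differs from the model's parameter (for example, a state dictionary from “Model(input_feat=6)” loaded into “Model(input_feat=8)”) still raises a “RuntimeError”.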
Conclusion
To use the “load_state_dict()” method in PyTorch, first save the model's state dictionary with “torch.save(model.state_dict(), PATH)”, then create an instance of the same model class and pass the output of “torch.load(PATH)” to its “load_state_dict()” method. This copies the saved parameters into the new model. This blog has illustrated how to use the “load_state_dict()” method in PyTorch.