This blog will illustrate how to save and load specific model weights in PyTorch.
How to Save and Load Specific Model Weights in PyTorch?
To save and load particular model weights in PyTorch, follow the provided steps:
- Import PyTorch library and modules
- Build the model and define weights
- Save model weights using the “torch.save()” method
- Load model weights using the “load_state_dict()” method
Step 1: Import PyTorch Library and Modules
First, import the required libraries and modules. Here, we import the main “torch” library, the “torchvision.models” module for pre-trained model architectures, and the “torch.nn” module for building neural network layers:
import torch
import torchvision.models as models
import torch.nn as nn
Step 2: Build the Model and Define Weights
Then, create a specific model and define its weights. For instance, we use the “ResNet-50” model, initialize it with weights pre-trained on the “ImageNet” dataset by passing “IMAGENET1K_V2”, and replace its final fully connected layer so that it outputs 10 classes:
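def build_model(
    mod_weights='IMAGENET1K_V2',
    fine_tune=False,
    num_classes=10
):
    # Load ResNet-50 with ImageNet-pretrained weights (build_model is an illustrative name)
    model = models.resnet50(weights=mod_weights)
    if not fine_tune:
        # Freeze the pre-trained backbone so that only the new head will train
        for param in model.parameters():
            param.requires_grad = False
    # Replace the 1000-class ImageNet head with a layer for num_classes outputs
    model.fc = nn.Linear(in_features=2048, out_features=num_classes)
    return model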
Step 3: Save Model Weights Using the “torch.save()” Method
To save the model weights, first create the model with the function from Step 2, which loads the pre-trained weights. Then, use the “torch.save()” function to save the model's “state_dict” (a dictionary that maps each layer's name to its weight and buffer tensors) to a file. For instance, we are saving the weights to a file called “savedModel_weights.pth”:
torch.save(model.state_dict(), 'savedModel_weights.pth')
Upon doing so, the “savedModel_weights.pth” file is created in the current working directory.
The model weights have been saved successfully.
Step 4: Load Model Weights Using the “load_state_dict()” Method
To load the model weights, create an instance of the exact same model architecture, and then load the saved parameters into it with the “load_state_dict()” method:
This completes the process of saving and loading model weights in PyTorch.
Conclusion
To save and load model weights in PyTorch, first import the required PyTorch library and modules. Then, build a model and define its weights. Next, save the model's “state_dict” to a file using the “torch.save()” method, and finally load those weights into a model of the same architecture using the “load_state_dict()” method. This blog has illustrated the method of saving and loading model weights in PyTorch.