PyTorch is a Python-based library that enables users to build machine-learning models. Users may want to count the number of parameters in a PyTorch model for several reasons, such as estimating the model size, debugging, or optimization. In such cases, PyTorch makes it straightforward to get the model’s total number of parameters.
This article will demonstrate how to check the total number of parameters in a PyTorch model.
How to Check/Find Out the Total Number of Parameters in the PyTorch Model?
PyTorch does not provide a dedicated built-in function that returns the total number of model parameters. However, the count is easy to compute from the model itself: every model (an nn.Module subclass) has a parameters() method that returns an iterator over all of the model’s parameter tensors.
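To see what that iterator actually yields, the sketch below walks a small model with named_parameters(), which returns the same tensors as parameters() but paired with their names. The three-layer nn.Sequential model is a hypothetical stand-in chosen only for illustration.

```python
import torch.nn as nn

# Hypothetical small model, used only to inspect its parameter tensors
model = nn.Sequential(
    nn.Linear(4, 3),  # weight (3, 4) and bias (3,)
    nn.ReLU(),        # no parameters
    nn.Linear(3, 2),  # weight (2, 3) and bias (2,)
)

# named_parameters() yields (name, tensor) pairs; parameters() yields the same tensors without names
for name, p in model.named_parameters():
    print(name, tuple(p.shape), p.numel())
```

For this model the loop reports four tensors holding 12 + 3 + 6 + 2 = 23 parameters; the ReLU layer contributes none.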
To get the total number of parameters, use the following snippet:
total_params = sum(p.numel() for p in model.parameters())
print(f"Number of parameters: {total_params}")
Here:
- “total_params” is a variable that holds the calculated total number of parameters.
- “model.parameters()” returns an iterator over all of the model’s parameter tensors.
- “for p in model.parameters()” iterates through each parameter tensor in the model.
- “p.numel()” returns the number of elements in the tensor “p”, i.e., how many individual parameters it holds.
- The “sum()” function adds up these per-tensor counts.
- The “print()” function prints the calculated total number of parameters.
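The one-liner described in the list above can be tried end to end with the sketch below. It assumes a hypothetical single nn.Linear(10, 5) layer in place of a real model, and cross-checks numel() against the product of each tensor’s shape using Python’s math.prod.

```python
import math
import torch.nn as nn

# Hypothetical stand-in model: weight of shape (5, 10) plus bias of shape (5,)
model = nn.Linear(10, 5)

# The one-liner: add up the number of elements in every parameter tensor
total_params = sum(p.numel() for p in model.parameters())

# Cross-check: numel() equals the product of a tensor's dimensions
manual = sum(math.prod(p.shape) for p in model.parameters())

print(f"Number of parameters: {total_params}")  # Number of parameters: 55
```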
Now, let us use this snippet with different PyTorch models to calculate their total number of parameters. Check out the following examples for a better understanding:
Example 1: Check the Total Number of Parameters of the ResNet-152 Model
In this example, we load the pretrained “resnet152” model from torchvision and calculate its total number of parameters:
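from torchvision import models
# torchvision >= 0.13 uses the "weights" argument; the older pretrained=True form is deprecated
model = models.resnet152(weights=models.ResNet152_Weights.DEFAULT)
# Freezing parameters stops them from being trained but does not change how many there are
for param in model.parameters():
    param.requires_grad = False
# Sum the number of elements in every parameter tensor
total_params = sum(p.numel() for p in model.parameters())
print(f"Number of parameters: {total_params}")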
The output shows that this model has a total of “60192808” parameters (about 60.2 million).
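Because Example 1 sets requires_grad = False on every parameter, the total is unchanged but none of those parameters would be updated during training. The sketch below reports both counts. The helper name count_parameters and the small nn.Linear(8, 4) stand-in model are assumptions for illustration (not part of PyTorch), used so the effect can be shown without downloading ResNet-152.

```python
import torch.nn as nn

def count_parameters(model: nn.Module, trainable_only: bool = False) -> int:
    """Hypothetical helper: sum numel() over parameters, optionally only trainable ones."""
    return sum(
        p.numel() for p in model.parameters()
        if not trainable_only or p.requires_grad
    )

# Hypothetical stand-in model: 8 * 4 weights + 4 biases = 36 parameters
model = nn.Linear(8, 4)

# Freeze every parameter, as Example 1 does
for param in model.parameters():
    param.requires_grad = False

print(count_parameters(model))                       # 36
print(count_parameters(model, trainable_only=True))  # 0
```

Filtering on requires_grad is a common way to report how many parameters an optimizer would actually update, for example after freezing a pretrained backbone.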
Example 2: Check the Total Number of Parameters of a Simple Neural Network
In this example, we define a simple neural network model and get its total number of parameters:
import torch
import torch.nn as nn
class Model(nn.Module):
    def __init__(self, input_feat):
        super().__init__()
        # A single linear layer mapping input_feat features to one output
        self.linear = nn.Linear(input_feat, 1)
    def forward(self, x):
        # Squash the linear output into the range (0, 1)
        y_pred = torch.sigmoid(self.linear(x))
        return y_pred
model = Model(input_feat=6)
total_params = sum(p.numel() for p in model.parameters())
print(f"Number of parameters: {total_params}")
The output shows that this model has “7” parameters: 6 weights in the linear layer (one per input feature) plus 1 bias.
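That total follows a simple rule: an nn.Linear(in_features, out_features) layer holds an out_features × in_features weight matrix plus, when bias=True (the default), an out_features bias vector. The sketch below checks this against numel() for a few layer sizes; linear_param_count is a hypothetical helper name, and the layer sizes other than (6, 1) are arbitrary.

```python
import torch.nn as nn

def linear_param_count(in_features: int, out_features: int, bias: bool = True) -> int:
    # Weight has shape (out_features, in_features); bias has shape (out_features,)
    return in_features * out_features + (out_features if bias else 0)

for in_f, out_f, bias in [(6, 1, True), (6, 1, False), (128, 10, True)]:
    layer = nn.Linear(in_f, out_f, bias=bias)
    counted = sum(p.numel() for p in layer.parameters())
    # With bias=False, parameters() yields only the weight tensor
    print(in_f, out_f, bias, counted)
    assert counted == linear_param_count(in_f, out_f, bias)
```

For the model in Example 2 this gives 6 × 1 + 1 = 7, matching the printed count.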
We have demonstrated how to find the total number of parameters in PyTorch models.
Conclusion
To check the number of parameters in a PyTorch model, call the model’s “parameters()” method to get an iterator over all of its parameter tensors, and use the “numel()” method of each tensor to get its number of elements; summing these values gives the total. This article has demonstrated how to get the total number of parameters in a PyTorch model.