This blog will demonstrate how to use the “torch.argmax()” method in PyTorch.
How to Use “torch.argmax()” Method in PyTorch?
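The “torch.argmax()” method takes a tensor of any shape (such as a 1D or 2D tensor) as input and returns a tensor containing the indices of its maximum values. If a dimension is specified with the “dim” argument, it returns the index of the maximum value along that dimension; otherwise, it returns the index of the maximum value in the flattened tensor.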
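The syntax of the “torch.argmax()” method is “torch.argmax(input, dim=None, keepdim=False)”. Here, “input” is the tensor to search, “dim” is the dimension to reduce (if it is omitted, the tensor is searched as if it were flattened), and “keepdim” specifies whether the output keeps the reduced dimension with size 1.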
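Both call forms can be tried on a small tensor. This is a minimal sketch; the tensor “t” and its values are hypothetical and were chosen only to make the results easy to follow:

```python
import torch

# Hypothetical 2x3 tensor, used only to illustrate the call forms
t = torch.tensor([[1, 5, 2],
                  [7, 3, 4]])

# Form 1: no "dim", so the tensor is searched as if flattened to [1, 5, 2, 7, 3, 4]
flat_idx = torch.argmax(t)

# Form 2: reduce along dim=1 (within each row)
row_idx = torch.argmax(t, dim=1)

# Same reduction, but keepdim=True keeps a size-1 column dimension
row_idx_kept = torch.argmax(t, dim=1, keepdim=True)

print(flat_idx)            # tensor(3)
print(row_idx)             # tensor([1, 0])
print(row_idx_kept.shape)  # torch.Size([2, 1])
```

The returned indices are always integer (int64) tensors, regardless of the input tensor's data type.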
To use this method in PyTorch, go through the following examples for a better understanding:
Example 1: Use “torch.argmax()” Method With 1D Tensor
In the first example, we will create a 1D tensor and use the “torch.argmax()” method with it. Follow the steps below:
Step 1: Import PyTorch Library
First, import the “torch” library to use the “torch.argmax()” method:
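A minimal version of this import step is shown below. Printing the version is an optional extra, not something “torch.argmax()” requires:

```python
# Import PyTorch, which provides torch.argmax()
import torch

# Optional: confirm which PyTorch version is installed
print(torch.__version__)
```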
Step 2: Create 1D Tensor
Then, create a 1D tensor and print its elements. Here, we are creating the following “Tens1” tensor from a list using the “torch.tensor()” function:
print(Tens1)
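The contents of “Tens1” are not shown in the text, so this sketch uses hypothetical values chosen to match the result described later (the maximum value “9” at index 4):

```python
import torch

# Hypothetical values: the original tensor's contents are not shown,
# so these were chosen so that 9 is the largest value and sits at index 4
Tens1 = torch.tensor([3, 7, 1, 5, 9, 2])
print(Tens1)  # tensor([3, 7, 1, 5, 9, 2])
```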
This has created a 1D tensor as seen below:
Step 3: Find Indices of Maximum Value
Now, use the “torch.argmax()” function to find the index of the maximum value in the “Tens1” tensor:
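A sketch of this step, assuming hypothetical “Tens1” values in which the maximum, “9”, sits at index 4. The variable name “max_index” is our own choice:

```python
import torch

# Hypothetical 1D tensor whose maximum, 9, is at index 4
Tens1 = torch.tensor([3, 7, 1, 5, 9, 2])

# No "dim" argument: the result is a 0-dimensional tensor holding one index
max_index = torch.argmax(Tens1)
```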
Step 4: Print Index of Maximum Value
Lastly, display the maximum value’s index in the input tensor:
The output below shows that the index of the maximum value in the “Tens1” tensor is 4. Since indexing starts at 0, this means the highest value, “9”, is the fifth element of the tensor:
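To double-check the result, the index can be converted to a Python integer with “.item()” and used to look up the value. This sketch assumes hypothetical “Tens1” values with “9” at index 4:

```python
import torch

# Hypothetical 1D tensor whose maximum, 9, is at index 4
Tens1 = torch.tensor([3, 7, 1, 5, 9, 2])
max_index = torch.argmax(Tens1)

print(max_index)                  # tensor(4)
idx = max_index.item()            # plain Python int: 4
print(Tens1[idx])                 # tensor(9)
print(Tens1[idx] == Tens1.max())  # tensor(True)
```

According to the PyTorch documentation, if the maximum value occurs more than once, the index of its first occurrence is returned.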
Example 2: Use “torch.argmax()” Method With 2D Tensor
In the second example, we will create a 2D tensor and use the “torch.argmax()” method with it. Follow the steps below:
Step 1: Import PyTorch Library
First, import the “torch” library to use the “torch.argmax()” method:
Step 2: Create 2D Tensor
Then, use the “torch.tensor()” function to create a 2D tensor and print its elements. Here, we are creating the following “Tens2” 2D tensor:
print(Tens2)
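The contents of “Tens2” are not shown in the text either, so this sketch uses a hypothetical 3x3 tensor chosen to match the results described in the following steps (the maximum, “15”, at flattened index 3):

```python
import torch

# Hypothetical 3x3 values; 15 is the largest and sits at row 1, column 0
Tens2 = torch.tensor([[4, 8, 2],
                      [15, 7, 11],
                      [6, 13, 9]])
print(Tens2)
print(Tens2.shape)  # torch.Size([3, 3])
```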
This has created a 2D tensor as seen below:
Step 3: Find Indices of Maximum Value
Now, use the “torch.argmax()” function without the “dim” argument to find the index of the maximum value in the “Tens2” tensor. In this case, the tensor is searched as if it were flattened into 1D:
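A sketch of this step with a hypothetical 3x3 “Tens2”. It also shows that calling “torch.argmax()” on an explicitly flattened copy gives the same index:

```python
import torch

# Hypothetical 3x3 tensor; 15 is the maximum
Tens2 = torch.tensor([[4, 8, 2],
                      [15, 7, 11],
                      [6, 13, 9]])

# Without "dim", argmax searches the flattened tensor
# [4, 8, 2, 15, 7, 11, 6, 13, 9] and returns a single index
max_index = torch.argmax(Tens2)
print(max_index)  # tensor(3)

# Equivalent: flatten explicitly, then take argmax
print(torch.argmax(Tens2.flatten()))  # tensor(3)
```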
Step 4: Print Index of Maximum Value
Finally, display the maximum value’s index in the input tensor:
According to the output below, the index of the maximum value in the “Tens2” tensor is “3”. This index refers to the flattened tensor, so the highest value, “15”, is the fourth element when the rows are read one after another:
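A flattened index does not directly say which row and column hold the maximum. Because a tensor created with “torch.tensor()” is laid out row by row, the index can be converted with integer division and remainder by the number of columns. This sketch assumes a hypothetical 3x3 “Tens2” whose maximum, “15”, sits at flattened index 3:

```python
import torch

# Hypothetical 3x3 tensor; 15 is the maximum
Tens2 = torch.tensor([[4, 8, 2],
                      [15, 7, 11],
                      [6, 13, 9]])

flat_idx = torch.argmax(Tens2).item()  # 3
num_cols = Tens2.shape[1]              # 3

# Row-major layout: flat index = row * num_cols + col
row, col = divmod(flat_idx, num_cols)
print(row, col)         # 1 0
print(Tens2[row, col])  # tensor(15)
```

Recent PyTorch releases also include “torch.unravel_index()”, which performs the same conversion for tensors of any shape.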
Step 5: Find Indices of Maximum Value Along Columns
Users can also find the indices of the maximum values along each column of a tensor. For instance, pass the “dim=0” argument to the “torch.argmax()” function. This reduces dimension 0 (the rows) and returns one row index per column. The code below finds these indices for the “Tens2” tensor and then prints them:
print("Indices in columns:", indice_col)
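A runnable sketch of this step with a hypothetical 3x3 “Tens2”. With “dim=0”, each entry of the result is the row index of the largest value in the corresponding column:

```python
import torch

# Hypothetical 3x3 tensor
Tens2 = torch.tensor([[4, 8, 2],
                      [15, 7, 11],
                      [6, 13, 9]])

# dim=0 reduces over rows, giving one row index per column
indice_col = torch.argmax(Tens2, dim=0)
print("Indices in columns:", indice_col)  # Indices in columns: tensor([1, 2, 1])
```

For example, the first column is [4, 15, 6], so its maximum, 15, is in row 1.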
The output below shows the row index of the maximum value in each column of the tensor:
Step 6: Find Indices of Maximum Value Along Rows
Similarly, users can find the indices of the maximum values along each row of a tensor. For instance, pass the “dim=1” argument to the “torch.argmax()” function. This reduces dimension 1 (the columns) and returns one column index per row. The code below finds these indices for the “Tens2” tensor and then prints them:
print("Indices in rows:", indice_row)
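A runnable sketch of this step with a hypothetical 3x3 “Tens2”. It also uses “gather()” to pick out each row's maximum value from the returned indices and checks the result against “Tens2.max(dim=1)”:

```python
import torch

# Hypothetical 3x3 tensor
Tens2 = torch.tensor([[4, 8, 2],
                      [15, 7, 11],
                      [6, 13, 9]])

# dim=1 reduces over columns, giving one column index per row
indice_row = torch.argmax(Tens2, dim=1)
print("Indices in rows:", indice_row)  # Indices in rows: tensor([1, 0, 1])

# Use the indices to gather the maximum value from each row
row_max = Tens2.gather(1, indice_row.unsqueeze(1)).squeeze(1)
print(torch.equal(row_max, Tens2.max(dim=1).values))  # True
```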
The column index of the maximum value in each row of the “Tens2” tensor can be seen below:
We have explained how to use the “torch.argmax()” method in PyTorch.
Note: You can access our Google Colab Notebook at this link.
Conclusion
To use the “torch.argmax()” method in PyTorch, first import the “torch” library. Then, create the desired 1D or 2D tensor and view its elements. Next, use the “torch.argmax()” method to compute the indices of the maximum values in the tensor. Without the “dim” argument, it returns the index of the maximum value in the flattened tensor; with “dim=0” or “dim=1”, it returns the indices of the maximum values along each column or each row, respectively. Finally, display the resulting indices. This blog has demonstrated how to use the “torch.argmax()” method in PyTorch.