
How to Use “torch.argmax()” Method in PyTorch?

In PyTorch, the “torch.argmax()” method is a built-in function that returns the indices of the maximum values of a tensor, either over the whole tensor or along a given dimension. It is useful whenever you work with tensors and need to know where the largest value is located rather than what the value is. A common example is classification, where “torch.argmax()” selects the class with the highest score or probability.
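As a rough illustration of the classification use case, the sketch below takes a batch of made-up model outputs (the “logits” values are invented for this example and do not come from a real model) and picks the highest-scoring class for each sample along dimension 1:

```python
import torch

# Hypothetical raw scores (logits) for a batch of 3 samples and 4 classes.
logits = torch.tensor([[0.1, 2.5, 0.3, -1.0],
                       [1.7, 0.2, 0.9, 0.4],
                       [-0.5, 0.0, 0.1, 3.2]])

# Softmax turns scores into probabilities but preserves their order,
# so argmax over logits and over probabilities picks the same class.
probs = torch.softmax(logits, dim=1)

# One predicted class index per sample (i.e., per row).
pred_from_logits = torch.argmax(logits, dim=1)
pred_from_probs = torch.argmax(probs, dim=1)

print(pred_from_logits.tolist())                        # [1, 0, 3]
print(torch.equal(pred_from_logits, pred_from_probs))   # True
```

Because softmax is monotonic, applying it first does not change which class wins; it only matters if you also need the probabilities themselves.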

This blog demonstrates how to use the “torch.argmax()” method in PyTorch.

How to Use “torch.argmax()” Method in PyTorch?

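The “torch.argmax()” method takes a tensor of any shape (1D and 2D tensors are used in the examples below) and returns a tensor containing the indices of the maximum values. If no dimension is specified, the input is treated as a flattened 1D tensor and a single index is returned.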

The syntax of the “torch.argmax()” method is given below:

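torch.argmax(<input_tensor>, dim=<dimension>, keepdim=False)

Here, “dim” is optional and selects the dimension to reduce; if it is omitted, the index of the maximum value in the flattened tensor is returned. “keepdim” controls whether the reduced dimension is kept in the output with size 1.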
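Besides the input tensor, “torch.argmax()” accepts the optional “dim” and “keepdim” arguments. The short sketch below uses an arbitrary 2×3 tensor “x” (chosen only for illustration) to show how these arguments change the shape of the result:

```python
import torch

x = torch.tensor([[3, 7, 2],
                  [8, 1, 5]])

# No dim: x is treated as flattened [3, 7, 2, 8, 1, 5]; a 0-dimensional tensor is returned.
flat_idx = torch.argmax(x)
# dim=1: one index per row, so the result has one entry per row.
row_idx = torch.argmax(x, dim=1)
# keepdim=True: the reduced dimension is kept with size 1.
row_idx_kept = torch.argmax(x, dim=1, keepdim=True)

print(flat_idx.item())      # 3
print(row_idx.tolist())     # [1, 0]
print(row_idx_kept.shape)   # torch.Size([2, 1])
```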

To use this method in PyTorch, go through the following examples for a better understanding:

Example 1: Use “torch.argmax()” Method With 1D Tensor

In the first example, we will create a 1D tensor and use the “torch.argmax()” method with it. Let’s follow the step-by-step procedure below:

Step 1: Import PyTorch Library

First, import the “torch” library to use the “torch.argmax()” method:

import torch

Step 2: Create 1D Tensor

Then, create a 1D tensor and print its elements. Here, we are creating the following “Tens1” tensor from a list using the “torch.tensor()” function:

Tens1 = torch.tensor([5, 0, -8, 1, 9, 7])

print(Tens1)

This creates a 1D tensor containing the six values listed above.

Step 3: Find Indices of Maximum Value

Now, use the “torch.argmax()” function to find the index of the maximum value in the “Tens1” tensor:

T1_ind = torch.argmax(Tens1)

Step 4: Print Index of Maximum Value

Lastly, display the maximum value’s index in the input tensor:

print("Indices:", T1_ind)

The output is “Indices: tensor(4)”, i.e., the maximum value of the “Tens1” tensor is at index 4. Since PyTorch indexing starts at 0, this is the fifth element, which is “9”.
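The returned index is a 0-dimensional tensor, so it can be converted to a plain Python integer with “.item()” and used to read the value back out of “Tens1”. The sketch below also adds a second, made-up tensor “Tens_tie” to show that, per the PyTorch documentation, when the maximum value appears more than once “torch.argmax()” returns the index of its first occurrence:

```python
import torch

Tens1 = torch.tensor([5, 0, -8, 1, 9, 7])
T1_ind = torch.argmax(Tens1)

# The index is a 0-dimensional tensor; .item() gives a plain Python int.
idx = T1_ind.item()
print(idx)                 # 4
print(Tens1[idx].item())   # 9

# The value at that index is the same value torch.max() reports.
print(Tens1[idx].item() == torch.max(Tens1).item())  # True

# Hypothetical tensor with a repeated maximum (9 appears twice):
# argmax returns the position of the first occurrence.
Tens_tie = torch.tensor([9, 2, 9, 4])
print(torch.argmax(Tens_tie).item())  # 0
```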

Example 2: Use “torch.argmax()” Method With 2D Tensor

In the second example, we will create a 2D tensor and use the “torch.argmax()” method with it. Let’s follow the provided steps:

Step 1: Import PyTorch Library

First, import the “torch” library to use the “torch.argmax()” method:

import torch

Step 2: Create 2D Tensor

Then, use the “torch.tensor()” function to create a 2D tensor and print its elements. Here, we are creating the following “Tens2” 2D tensor:

Tens2 = torch.tensor([[4, 1, -7], [15, 6, 0], [-7, 9, 2]])

print(Tens2)

This creates a 2D tensor with three rows and three columns.

Step 3: Find Indices of Maximum Value

Now, find the index of the maximum value in the “Tens2” tensor by utilizing the “torch.argmax()” function:

T2_ind = torch.argmax(Tens2)

Step 4: Print Index of Maximum Value

Finally, display the maximum value’s index in the input tensor:

print("Indices:", T2_ind)

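The output is “Indices: tensor(3)”. Without a “dim” argument, “torch.argmax()” treats the “Tens2” tensor as if it were flattened into the 1D sequence [4, 1, -7, 15, 6, 0, -7, 9, 2]. Index 3 of this sequence holds “15”, the highest value, which sits in row 1, column 0 of the original 2D tensor.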
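Since no “dim” was given, the index 3 refers to the flattened tensor. A minimal way to translate such a flat index back into a (row, column) position, assuming a 2D tensor in the usual row-major order, is integer division and remainder by the number of columns:

```python
import torch

Tens2 = torch.tensor([[4, 1, -7], [15, 6, 0], [-7, 9, 2]])
T2_ind = torch.argmax(Tens2)

# Without dim, argmax indexes the flattened tensor [4, 1, -7, 15, 6, 0, -7, 9, 2].
print(Tens2.flatten()[T2_ind].item())  # 15

# Row-major layout: flat index = row * num_cols + col.
num_cols = Tens2.shape[1]
row, col = divmod(T2_ind.item(), num_cols)
print(row, col)                  # 1 0
print(Tens2[row, col].item())    # 15
```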

Step 5: Find Indices of Maximum Value Along Columns

Moreover, users can also find the index of the maximum value in each column of a tensor. To do this, pass the “dim=0” argument to the “torch.argmax()” function. It reduces dimension 0 (the rows), so it returns one row index per column. The following code finds these indices for the “Tens2” tensor and prints them:

indice_col = torch.argmax(Tens2, dim=0)

print("Indices in columns:", indice_col)

The output is “Indices in columns: tensor([1, 2, 2])”. The maximum of the first column (15) is in row 1, and the maxima of the second column (9) and the third column (2) are both in row 2.
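To cross-check the column-wise result, the sketch below compares “torch.argmax(Tens2, dim=0)” with “torch.max(Tens2, dim=0)”, which returns both the maximum values and their indices, and then uses “torch.gather()” to read the maxima back out of “Tens2” at the returned row indices. The indices agree here because no column of “Tens2” contains a tied maximum:

```python
import torch

Tens2 = torch.tensor([[4, 1, -7], [15, 6, 0], [-7, 9, 2]])

# dim=0 reduces over rows: one row index per column.
indice_col = torch.argmax(Tens2, dim=0)
print(indice_col.tolist())  # [1, 2, 2]

# torch.max with a dim returns (values, indices).
values, indices = torch.max(Tens2, dim=0)
print(values.tolist())                      # [15, 9, 2]
print(torch.equal(indices, indice_col))     # True

# gather picks Tens2[indice_col[j], j] for each column j.
picked = torch.gather(Tens2, 0, indice_col.unsqueeze(0))
print(picked.tolist())  # [[15, 9, 2]]
```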

Step 6: Find Indices of Maximum Value Along Rows

Similarly, users can find the index of the maximum value in each row of a tensor. Pass the “dim=1” argument to the “torch.argmax()” function. It reduces dimension 1 (the columns), so it returns one column index per row. The following code finds these indices for the “Tens2” tensor and prints them:

indice_row = torch.argmax(Tens2, dim=1)

print("Indices in rows:", indice_row)

The output is “Indices in rows: tensor([0, 0, 1])”. The maxima of the first row (4) and the second row (15) are in column 0, and the maximum of the third row (9) is in column 1.
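As a sanity check on the row-wise result, the sketch below recomputes the same indices with a plain Python loop over each row. This is a deliberately naive reference written for illustration, not how PyTorch computes the result internally. It also shows the “keepdim=True” variant, which keeps the result as a column so that it lines up with the rows of “Tens2”:

```python
import torch

Tens2 = torch.tensor([[4, 1, -7], [15, 6, 0], [-7, 9, 2]])

# dim=1 reduces over columns: one column index per row.
indice_row = torch.argmax(Tens2, dim=1)
print(indice_row.tolist())  # [0, 0, 1]

# Naive reference: position of the largest element in each row.
# Python's max() keeps the first maximal position on ties, like argmax.
reference = [max(range(len(row)), key=lambda j: row[j]) for row in Tens2.tolist()]
print(reference)  # [0, 0, 1]

# keepdim=True keeps the reduced dimension, giving shape (3, 1).
indice_row_kept = torch.argmax(Tens2, dim=1, keepdim=True)
print(indice_row_kept.shape)  # torch.Size([3, 1])
```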


Note: You can access our Google Colab Notebook at this link.

Conclusion

To use the “torch.argmax()” method in PyTorch, first import the “torch” library. Then, create the desired tensor (for example, a 1D or 2D tensor) and view its elements. Next, use the “torch.argmax()” method to compute the index of the maximum value; without a “dim” argument, this index refers to the flattened tensor. To get the index of the maximum value in each column or row instead, pass the “dim” argument (“dim=0” or “dim=1” for a 2D tensor). Finally, print the resulting indices. This blog has demonstrated how to use the “torch.argmax()” method in PyTorch.

About the author

Laiba Younas

I have a bachelor’s degree in Computer Science. Being passionate about learning new technologies, I am interested in exploring different programming languages and sharing my experience with the world.