The “**torch.argmax()**” method is a built-in PyTorch function that returns the indices of the maximum values of a tensor along a given dimension. It is useful whenever you work with tensors and need to know where the largest value is located, rather than the value itself. For example, in classification tasks it identifies the class with the highest predicted probability.
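As a rough illustration of the classification use case, the sketch below applies “torch.argmax()” to a small batch of made-up classifier scores. The “logits” values are hypothetical and chosen only for this example:

```python
import torch

# Hypothetical classifier scores: 3 samples (rows) x 4 classes (columns).
logits = torch.tensor([[0.1, 2.5, 0.3, 0.8],
                       [1.9, 0.2, 0.4, 0.1],
                       [0.3, 0.6, 0.2, 3.1]])

# Turn each row of scores into probabilities over the 4 classes.
probs = torch.softmax(logits, dim=1)

# argmax over dim=1 picks the most likely class index for each sample.
predicted = torch.argmax(probs, dim=1)
print(predicted)  # tensor([1, 0, 3])
```

Because softmax preserves the ordering of values within a row, calling “torch.argmax()” directly on the raw scores gives the same predicted classes.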

This blog demonstrates how to use the “torch.argmax()” method in PyTorch.

**How to Use “torch.argmax()” Method in PyTorch?**

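The “torch.argmax()” method takes a tensor of any shape as input (this blog uses 1D and 2D tensors) and returns a tensor that contains the indices of the maximum values along the given dimension. If no dimension is specified, it returns the index of the maximum value in the flattened tensor.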

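The “torch.argmax()” method can be called as “torch.argmax(input)” or as “torch.argmax(input, dim, keepdim=False)”. Here, “input” is the tensor, “dim” is the dimension to reduce, and “keepdim” controls whether the output keeps that dimension with size 1.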
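A minimal sketch of both call forms on a small 2-by-3 tensor (the tensor “x” is hypothetical):

```python
import torch

# Hypothetical example tensor with 2 rows and 3 columns.
x = torch.tensor([[1, 5, 2],
                  [7, 3, 4]])

# Form 1: no dim, so the result is the index of the maximum in the flattened tensor.
flat_index = torch.argmax(x)

# Form 2: with dim, the result holds one index per slice along that dimension.
per_row = torch.argmax(x, dim=1)

# keepdim=True keeps the reduced dimension, with size 1.
per_row_kept = torch.argmax(x, dim=1, keepdim=True)

print(flat_index)          # tensor(3)
print(per_row)             # tensor([1, 0])
print(per_row_kept.shape)  # torch.Size([2, 1])
```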

To use this method in PyTorch, go through the following examples for a better understanding:

**Example 1: Use “torch.argmax()” Method With 1D Tensor**

In the first example, we will create a 1D tensor and use the “torch.argmax()” method with it. Follow the step-by-step procedure below:

**Step 1: Import PyTorch Library**

First, import the “**torch**” library to use the “torch.argmax()” method:
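The import statement did not survive in this copy of the post; it is presumably the standard one:

```python
# Import PyTorch; torch.argmax is available directly on the torch module.
import torch

print(torch.__version__)
```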

**Step 2: Create 1D Tensor**

Then, create a 1D tensor and print its elements. Here, we are creating the following “**Tens1**” tensor from a list using the “**torch.tensor()**” function:

`print(Tens1)`
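The original values of “Tens1” are not shown here. The sketch below assumes values chosen so that the largest element, 9, sits at index 4, which matches the result described in Step 4:

```python
import torch

# Assumed values: the original Tens1 is not shown, but its maximum (9) is at index 4.
Tens1 = torch.tensor([3, 7, 1, 5, 9, 2])
print(Tens1)  # tensor([3, 7, 1, 5, 9, 2])
```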

This has created a 1D tensor as seen below:

**Step 3: Find Indices of Maximum Value**

Now, use the “**torch.argmax()**” function to find the index of the maximum value in the “**Tens1**” tensor:
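A possible version of this step, using the assumed “Tens1” values. The variable name “max_index” is our own choice and not necessarily the one used in the original post:

```python
import torch

# Assumed values; the original Tens1 is not shown.
Tens1 = torch.tensor([3, 7, 1, 5, 9, 2])

# With no dim argument, argmax scans the whole tensor and returns a 0-D tensor.
max_index = torch.argmax(Tens1)
```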

**Step 4: Print Index of Maximum Value**

Lastly, display the maximum value’s index in the input tensor:
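Printing the result could look like this (same assumed values and hypothetical “max_index” name). The result is a 0-D tensor, and “.item()” converts it to a plain Python integer:

```python
import torch

# Assumed values; the original Tens1 is not shown.
Tens1 = torch.tensor([3, 7, 1, 5, 9, 2])
max_index = torch.argmax(Tens1)

print(max_index)         # tensor(4)
print(max_index.item())  # 4
```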

The output below shows that the index of the maximum value in the “**Tens1**” tensor is 4. Because PyTorch indexing starts at 0, this is the fifth element of the tensor, and its value is “**9**”:
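To confirm which value sits at the returned index, one option is to index back into the tensor. This sketch uses the same assumed “Tens1” values:

```python
import torch

# Assumed values; the original Tens1 is not shown.
Tens1 = torch.tensor([3, 7, 1, 5, 9, 2])
max_index = torch.argmax(Tens1)

# Indexing with the returned position recovers the maximum value.
max_value = Tens1[max_index.item()]
print(max_value)  # tensor(9)

# torch.max without a dim argument returns the same value directly.
print(torch.equal(max_value, torch.max(Tens1)))  # True
```

If the maximum value occurs more than once, the PyTorch documentation states that the index of the first occurrence is returned.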

**Example 2: Use “torch.argmax()” Method With 2D Tensor**

In the second example, we will create a 2D tensor and use the “torch.argmax()” method with it. Follow the steps below:

**Step 1: Import PyTorch Library**

First, import the “**torch**” library to use the “torch.argmax()” method:

**Step 2: Create 2D Tensor**

Then, use the “**torch.tensor()**” function to create a 2D tensor and print its elements. Here, we are creating the following “**Tens2**” 2D tensor:

`print(Tens2)`
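The original values of “Tens2” are not shown either. The sketch below assumes a 3-by-3 tensor whose maximum, 15, lands at position 3 when the tensor is read row by row, consistent with the result described in Step 4:

```python
import torch

# Assumed values: the original Tens2 is not shown, but its maximum (15)
# sits at position 3 when the tensor is flattened row by row.
Tens2 = torch.tensor([[4, 8, 2],
                      [15, 6, 11],
                      [9, 13, 5]])
print(Tens2)
print(Tens2.shape)  # torch.Size([3, 3])
```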

This has created a 2D tensor as seen below:

**Step 3: Find Indices of Maximum Value**

Now, find the index of the maximum value in the “**Tens2**” tensor by utilizing the “**torch.argmax()**” function:
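A sketch of this step with the assumed “Tens2” values. Because the result is an index into the flattened tensor, it can be converted back to a (row, column) pair with “divmod()” and the number of columns; the “flat_index”, “row”, and “col” names are hypothetical:

```python
import torch

# Assumed values; the original Tens2 is not shown.
Tens2 = torch.tensor([[4, 8, 2],
                      [15, 6, 11],
                      [9, 13, 5]])

# Without dim, argmax works on the flattened tensor and returns a single index.
flat_index = torch.argmax(Tens2)
print(flat_index)  # tensor(3)

# Convert the flat index back to (row, column) using the number of columns.
row, col = divmod(flat_index.item(), Tens2.shape[1])
print(row, col)  # 1 0
```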

**Step 4: Print Index of Maximum Value**

Finally, display the maximum value’s index in the input tensor:

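According to the output below, the index of the maximum value in the “**Tens2**” tensor is “3”. Since no “dim” argument is given, “torch.argmax()” treats the 2D tensor as if it were flattened into a 1D tensor (row by row), so 3 is the position of the highest value, “**15**”, in that flattened view rather than a row or column number: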
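One way to check this behavior, again with the assumed “Tens2” values, is to flatten the tensor explicitly and compare:

```python
import torch

# Assumed values; the original Tens2 is not shown.
Tens2 = torch.tensor([[4, 8, 2],
                      [15, 6, 11],
                      [9, 13, 5]])

# flatten() reads the tensor row by row: [4, 8, 2, 15, 6, 11, 9, 13, 5]
flat = Tens2.flatten()

# argmax on the 2D tensor and on its flattened view agree.
print(torch.argmax(Tens2), torch.argmax(flat))  # tensor(3) tensor(3)
print(flat[3])  # tensor(15)
```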

**Step 5: Find Indices of Maximum Values Along Columns**

Users can also find the indices of the maximum values along each column of a tensor. To do so, pass the “**dim=0**” argument to the “torch.argmax()” function. It reduces the tensor along dimension 0 (down the rows), so it returns one index per column: the row position of that column’s maximum. The following code finds these indices for the “**Tens2**” tensor and prints them:

`print("Indices in columns:", indice_col)`
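A sketch of this step with the assumed “Tens2” values (the “indice_col” name comes from the original print statement):

```python
import torch

# Assumed values; the original Tens2 is not shown.
Tens2 = torch.tensor([[4, 8, 2],
                      [15, 6, 11],
                      [9, 13, 5]])

# dim=0 reduces over the rows, giving one row index per column.
indice_col = torch.argmax(Tens2, dim=0)
print("Indices in columns:", indice_col)  # Indices in columns: tensor([1, 2, 1])
```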

The output below shows the index of the maximum value in each column of the tensor:

**Step 6: Find Indices of Maximum Values Along Rows**

Similarly, users can find the indices of the maximum values along each row of a tensor. Pass the “**dim=1**” argument to the “torch.argmax()” function; it reduces the tensor along dimension 1 (across the columns), so it returns one index per row: the column position of that row’s maximum. The following code finds these indices for the “**Tens2**” tensor and prints them:

`print("Indices in rows:", indice_row)`
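The row-wise version, again using the assumed “Tens2” values (the “indice_row” name comes from the original print statement):

```python
import torch

# Assumed values; the original Tens2 is not shown.
Tens2 = torch.tensor([[4, 8, 2],
                      [15, 6, 11],
                      [9, 13, 5]])

# dim=1 reduces over the columns, giving one column index per row.
indice_row = torch.argmax(Tens2, dim=1)
print("Indices in rows:", indice_row)  # Indices in rows: tensor([1, 0, 1])
```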

The output below shows the index of the maximum value in each row of the “**Tens2**” tensor:

This covers how to use the “torch.argmax()” method in PyTorch with both 1D and 2D tensors.

**Note**: You can access our Google Colab Notebook at this link.

**Conclusion**

To use the “torch.argmax()” method in PyTorch, first import the “**torch**” library. Then, create the desired 1D or 2D tensor and view its elements. Next, use the “**torch.argmax()**” method to compute the index of the maximum value; without a “**dim**” argument, it returns the index of the maximum value in the flattened tensor. Finally, display the resulting index. Users can also find the indices of the maximum values along each column or row of a tensor by passing “**dim=0**” or “**dim=1**”. This blog has demonstrated how to use the “torch.argmax()” method in PyTorch.