
PyTorch – argmax()

In this PyTorch tutorial, we will see how to return the index positions of maximum values in a tensor using argmax().

PyTorch is an open-source framework for the Python programming language. In PyTorch, data is processed in the form of tensors.

A tensor is a multidimensional array used to store data. To work with tensors, we have to import the torch module.

To create a tensor, we use the tensor() method.

Syntax:

torch.tensor(data)

Where data is a multidimensional array, such as a nested Python list.
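As a quick illustration, the sketch below builds a small tensor from a nested Python list with torch.tensor() and inspects its shape and data type. The values are arbitrary and chosen only for illustration.

```python
import torch

# Build a 2 x 3 tensor from a nested Python list (values chosen for illustration)
data = torch.tensor([[1.0, 7.0, 3.0],
                     [9.0, 2.0, 5.0]])

print(data)
print(data.shape)  # torch.Size([2, 3])
print(data.dtype)  # torch.float32, the default for Python floats
```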

argmax()

argmax() in PyTorch returns the index of the maximum value among all elements of the input tensor, or along a given dimension when dim is specified.
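When you call argmax() with only the input tensor, it treats the tensor as if it were flattened into one dimension, so the returned index counts positions in row-major order. The sketch below uses an arbitrary example tensor to show this, then converts the flat index back to a (row, column) pair with integer division. The names flat_index, row, and col are illustrative only.

```python
import torch

data = torch.tensor([[1.0, 7.0, 3.0],
                     [9.0, 2.0, 5.0]])

# No dim: the index refers to the flattened tensor [1, 7, 3, 9, 2, 5]
flat_index = torch.argmax(data)
print(flat_index)  # tensor(3)

# Convert the flat index back to (row, column) for a row-major 2D tensor
num_cols = data.shape[1]
row, col = divmod(flat_index.item(), num_cols)
print(row, col)        # 1 0
print(data[row, col])  # tensor(9.)
```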

Syntax:

torch.argmax(tensor, dim, keepdim)

Where

  1. tensor is the input tensor.
  2. dim is the dimension to reduce. For a 2D tensor, dim=0 compares the values in each column and returns the row index of each column's maximum, while dim=1 compares the values in each row and returns the column index of each row's maximum.
  3. keepdim specifies whether the output tensor retains the reduced dimension (dim) with size 1. It defaults to False.
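To make dim and keepdim concrete, here is a small sketch on an arbitrary 2 x 3 tensor. dim=0 gives one row index per column. dim=1 gives one column index per row. keepdim=True keeps the reduced dimension with size 1, which lines the result up with the input for functions such as torch.gather().

```python
import torch

data = torch.tensor([[1.0, 7.0, 3.0],
                     [9.0, 2.0, 5.0]])

cols = torch.argmax(data, dim=0)  # row index of the max in each column
rows = torch.argmax(data, dim=1)  # column index of the max in each row
print(cols)  # tensor([1, 0, 1])
print(rows)  # tensor([1, 0])

# keepdim=True keeps the reduced dimension with size 1
kept = torch.argmax(data, dim=1, keepdim=True)
print(rows.shape)  # torch.Size([2])
print(kept.shape)  # torch.Size([2, 1])

# The kept dimension lines up with data, so gather() can pick out the max values
max_values = torch.gather(data, 1, kept)
print(max_values.squeeze(1))  # tensor([7., 9.])
```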

Example 1

In this example, we will create a 2-dimensional tensor with 3 rows and 5 columns and apply argmax() along the columns and rows.

# import torch module
import torch

# create a tensor with 2 dimensions (3 * 5)
# with random elements using randn() function
data = torch.randn(3, 5)

# display
print(data)

# get maximum index along columns with argmax
print(torch.argmax(data, dim=0))

# get maximum index along rows with argmax
print(torch.argmax(data, dim=1))

Output:

tensor([[ 0.6699,  1.3390, -1.0658, -1.8200,  0.6544],
        [-0.3117,  0.2488,  0.2677,  0.2568,  0.5337],
        [-1.0966,  1.8024, -0.7538, -0.2553, -1.0591]])
tensor([0, 2, 1, 1, 0])
tensor([1, 4, 1])

We can see that the maximum value in each column (dim=0) and its row index are:

  1. Max value – 0.6699. Its index is 0.
  2. Max value – 1.8024. Its index is 2.
  3. Max value – 0.2677. Its index is 1.
  4. Max value – 0.2568. Its index is 1.
  5. Max value – 0.6544. Its index is 0.

Similarly, the maximum value in each row (dim=1) and its column index are:

  1. Max value – 1.3390. Its index is 1.
  2. Max value – 0.5337. Its index is 4.
  3. Max value – 1.8024. Its index is 1.
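You can check the two lists above programmatically. The sketch below rebuilds the tensor from the printed output of this example, with values rounded to four decimals as displayed. It then uses advanced indexing with the argmax() results to look up the maximum values themselves.

```python
import torch

# Tensor copied from the printed output above (rounded to 4 decimals)
data = torch.tensor([[ 0.6699, 1.3390, -1.0658, -1.8200,  0.6544],
                     [-0.3117, 0.2488,  0.2677,  0.2568,  0.5337],
                     [-1.0966, 1.8024, -0.7538, -0.2553, -1.0591]])

col_idx = torch.argmax(data, dim=0)  # row index of the max in each column
row_idx = torch.argmax(data, dim=1)  # column index of the max in each row

# Pair each index with its column (or row) number to read off the max values
col_max = data[col_idx, torch.arange(data.shape[1])]
row_max = data[torch.arange(data.shape[0]), row_idx]

print(col_idx)  # tensor([0, 2, 1, 1, 0])
print(row_idx)  # tensor([1, 4, 1])
print(col_max)
print(row_max)
```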

Example 2

In this example, we will create a 2-dimensional tensor with 5 rows and 5 columns and apply argmax() along the columns and rows.

# import torch module
import torch

# create a tensor with 2 dimensions (5 * 5)
# with random elements using randn() function
data = torch.randn(5, 5)

# display
print(data)

# get maximum index along columns with argmax
print(torch.argmax(data, dim=0))

# get maximum index along rows with argmax
print(torch.argmax(data, dim=1))

Output:

tensor([[-0.9553, -0.2611, -2.1233, -0.5208, -0.3458],
        [-0.5466, -1.6395,  0.2576, -0.3123,  0.6785],
        [-0.4574,  1.5301,  0.4812,  0.3434,  0.1388],
        [ 0.8364,  0.3821,  0.1529,  1.4529,  0.3747],
        [-1.4991, -1.8821, -0.2861, -0.4067,  1.1323]])
tensor([3, 2, 2, 3, 4])
tensor([1, 4, 1, 3, 4])

We can see that the maximum value in each column (dim=0) and its row index are:

  1. Max value – 0.8364. Its index is 3.
  2. Max value – 1.5301. Its index is 2.
  3. Max value – 0.4812. Its index is 2.
  4. Max value – 1.4529. Its index is 3.
  5. Max value – 1.1323. Its index is 4.

Similarly, the maximum value in each row (dim=1) and its column index are:

  1. Max value – -0.2611. Its index is 1.
  2. Max value – 0.6785. Its index is 4.
  3. Max value – 1.5301. Its index is 1.
  4. Max value – 1.4529. Its index is 3.
  5. Max value – 1.1323. Its index is 4.
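To get the maximum values along with their indices, call torch.max() with a dim argument. It returns both at once as a (values, indices) pair, and the PyTorch documentation describes argmax() as the second value returned by torch.max(). The sketch below shows this on a random 5 x 5 tensor, as in this example.

```python
import torch

data = torch.randn(5, 5)

# torch.max with dim returns a named tuple (values, indices)
values, indices = torch.max(data, dim=1)

print(values)   # maximum value in each row
print(indices)  # column index of each row's maximum

# The indices agree with argmax along the same dimension
print(torch.equal(indices, torch.argmax(data, dim=1)))
```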

Work With CPU

To run argmax() on the CPU, the tensor has to be on the CPU. Tensors are created on the CPU by default, but you can call the cpu() method to make this explicit. If the tensor is already in CPU memory, cpu() performs no copy and returns the original tensor.

We can call cpu() directly when creating a tensor.

Syntax:

torch.tensor(data).cpu()
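Tensors start on the CPU, so cpu() mainly matters when a tensor may be on another device such as a GPU. The sketch below uses an arbitrary example tensor. It checks the device attribute, then moves the tensor to a GPU and brings the argmax() result back with cpu(), but only if a CUDA GPU happens to be available.

```python
import torch

data = torch.tensor([[1.0, 7.0, 3.0],
                     [9.0, 2.0, 5.0]]).cpu()
print(data.device)  # cpu

# cpu() on a CPU tensor does not copy, so both share the same memory
print(data.cpu().data_ptr() == data.data_ptr())  # True

result = torch.argmax(data, dim=0)
print(result)  # tensor([1, 0, 1])

# Only runs on a machine with a CUDA-capable GPU
if torch.cuda.is_available():
    gpu_result = torch.argmax(data.to("cuda"), dim=0)
    print(torch.equal(gpu_result.cpu(), result))
```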

Example 1

In this example, we will create a 2-dimensional tensor with 5 rows and 5 columns using the cpu() function and apply argmax() along the columns and rows.

# import torch module
import torch

# create a tensor with 2 dimensions (5 * 5)
# with random elements using randn() function with cpu()
data = torch.randn(5, 5).cpu()

# display
print(data)

# get maximum index along columns with argmax
print(torch.argmax(data, dim=0))

# get maximum index along rows with argmax
print(torch.argmax(data, dim=1))

Output:

tensor([[-0.2213,  1.6140, -0.0774,  0.4135,  0.1379],
        [-0.4415, -2.5789,  0.8294, -0.9309,  1.3535],
        [-1.3256, -0.7233, -0.9713,  1.0742,  1.9350],
        [-0.7126, -1.3336,  0.7371, -0.2253,  0.1675],
        [-0.1174, -0.5773,  0.8887, -0.2563, -1.0667]])
tensor([4, 0, 4, 2, 2])
tensor([1, 4, 4, 2, 2])

We can see that the maximum value in each column (dim=0) and its row index are:

  1. Max value – -0.1174. Its index is 4.
  2. Max value – 1.6140. Its index is 0.
  3. Max value – 0.8887. Its index is 4.
  4. Max value – 1.0742. Its index is 2.
  5. Max value – 1.9350. Its index is 2.

Similarly, the maximum value in each row (dim=1) and its column index are:

  1. Max value – 1.6140. Its index is 1.
  2. Max value – 1.3535. Its index is 4.
  3. Max value – 1.9350. Its index is 4.
  4. Max value – 0.7371. Its index is 2.
  5. Max value – 0.8887. Its index is 2.
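Because randn() draws new random numbers on each run, your printed tensor and indices will not match the ones shown here. For repeatable output, one option is to seed PyTorch's random number generator with torch.manual_seed() before calling randn(). The seed value 0 below is arbitrary.

```python
import torch

# Seed the generator so randn() produces the same values each run (seed is arbitrary)
torch.manual_seed(0)
first = torch.randn(5, 5).cpu()

torch.manual_seed(0)
second = torch.randn(5, 5).cpu()

# Same seed, same tensor, so argmax() gives the same indices
print(torch.equal(first, second))  # True
print(torch.argmax(first, dim=0))
print(torch.argmax(first, dim=1))
```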

Example 2

In this example, we will create a 2-dimensional tensor with 3 rows and 5 columns using the cpu() function and apply argmax() along the columns and rows.

# import torch module
import torch

# create a tensor with 2 dimensions (3 * 5)
# with random elements using randn() with cpu() function
data = torch.randn(3, 5).cpu()

# display
print(data)

# get maximum index along columns with argmax
print(torch.argmax(data, dim=0))

# get maximum index along rows with argmax
print(torch.argmax(data, dim=1))

Output:

tensor([[ 0.6699,  1.3390, -1.0658, -1.8200,  0.6544],
        [-0.3117,  0.2488,  0.2677,  0.2568,  0.5337],
        [-1.0966,  1.8024, -0.7538, -0.2553, -1.0591]])
tensor([0, 2, 1, 1, 0])
tensor([1, 4, 1])

We can see that the maximum value in each column (dim=0) and its row index are:

  1. Max value – 0.6699. Its index is 0.
  2. Max value – 1.8024. Its index is 2.
  3. Max value – 0.2677. Its index is 1.
  4. Max value – 0.2568. Its index is 1.
  5. Max value – 0.6544. Its index is 0.

Similarly, the maximum value in each row (dim=1) and its column index are:

  1. Max value – 1.3390. Its index is 1.
  2. Max value – 0.5337. Its index is 4.
  3. Max value – 1.8024. Its index is 1.

Conclusion

In this PyTorch lesson, we saw what argmax() is and how to apply it to a tensor to return the indices of the maximum values across columns and rows.

We also created a tensor with the cpu() function and returned the indices of its maximum values. The dim parameter returns the indices of the maximum values across columns when it is set to 0 and across rows when it is set to 1.

About the author

Gottumukkala Sravan Kumar

B.Tech (Hons.) in Information Technology; known programming languages: Python, R, PHP, MySQL; published 500+ articles in the computer science domain.