PyTorch is an open-source framework for the Python programming language. In PyTorch, data is processed in the form of tensors.

A tensor is a multidimensional array used to store data. To use tensors, we have to import the torch module.

To create a tensor, the method used is tensor():

**Syntax**:

`torch.tensor(data)`

Where data is a multi-dimensional array.
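As a minimal sketch of tensor(), a nested Python list becomes a 2-D tensor whose shape and data type can be inspected. The values are arbitrary and chosen only for illustration:

```python
import torch

# Build a 2-D tensor (2 rows, 3 columns) from a nested Python list
data = torch.tensor([[1.0, 5.0, 2.0],
                     [4.0, 0.5, 3.0]])

print(data)        # the tensor itself
print(data.shape)  # torch.Size([2, 3])
print(data.dtype)  # torch.float32 (default for Python floats)
```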

**argmax()**

argmax() in PyTorch is used to return the index of the maximum value of all elements in the input tensor.
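When no dim is given, the returned index refers to the tensor flattened into one dimension. The sketch below uses arbitrary illustrative values and converts that flat index back into a (row, column) pair with Python's divmod(), which is one simple way to do it for a 2-D tensor:

```python
import torch

data = torch.tensor([[1.0, 5.0, 2.0],
                     [4.0, 0.5, 9.0]])

# Without dim, argmax treats the tensor as flattened and returns a single index
flat_index = torch.argmax(data)
print(flat_index)  # tensor(5)

# Convert the flat index back to (row, column) using the number of columns
num_cols = data.shape[1]
row, col = divmod(flat_index.item(), num_cols)
print(row, col)  # 1 2
```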

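**Syntax**:

`torch.argmax(tensor, dim, keepdim)`

Where

- tensor is the input tensor
- dim is the dimension to reduce. dim=0 compares values down each column and returns, for every column, the row index of its maximum value; dim=1 compares values across each row and returns, for every row, the column index of its maximum value. If dim is omitted, the tensor is treated as flattened and a single index is returned.
- keepdim (False by default) specifies whether the output tensor keeps the reduced dimension (dim) with size 1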
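A short sketch, again with arbitrary values, showing how dim picks the direction of the comparison and how keepdim affects the output shape:

```python
import torch

data = torch.tensor([[1.0, 5.0, 2.0],
                     [4.0, 0.5, 9.0]])  # shape (2, 3)

# dim=0 reduces the row dimension: one row index per column
cols = torch.argmax(data, dim=0)
print(cols, cols.shape)  # tensor([1, 0, 1]) torch.Size([3])

# dim=1 reduces the column dimension: one column index per row
rows = torch.argmax(data, dim=1)
print(rows, rows.shape)  # tensor([1, 2]) torch.Size([2])

# keepdim=True keeps the reduced dimension with size 1
rows_kept = torch.argmax(data, dim=1, keepdim=True)
print(rows_kept.shape)  # torch.Size([2, 1])
```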

**Example 1**

In this example, we will create a 2-dimensional tensor with 3 rows and 5 columns and apply argmax() along its rows and columns.

```python
import torch

# Create a 2-D tensor (3 x 5) with random elements using randn()
data = torch.randn(3, 5)

# Display the tensor
print(data)

# Get the index of the maximum value along each column
print(torch.argmax(data, dim=0))

# Get the index of the maximum value along each row
print(torch.argmax(data, dim=1))
```

Output:

```
[-0.3117, 0.2488, 0.2677, 0.2568, 0.5337],
[-1.0966, 1.8024, -0.7538, -0.2553, -1.0591]])
tensor([0, 2, 1, 1, 0])
tensor([1, 4, 1])
```

Reading down each column, the maximum values and their row indices are:

- Max value – 0.6699. Its index is 0.
- Max value – 1.8024. Its index is 2.
- Max value – 0.2677. Its index is 1.
- Max value – 0.2568. Its index is 1.
- Max value – 0.6544. Its index is 0.

Similarly, reading across each row, the maximum values and their column indices are:

- Max value – 1.3390. Its index is 1.
- Max value – 0.5337. Its index is 4.
- Max value – 1.8024. Its index is 1.
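Because randn() produces different values on every run, the lists above cannot be reproduced exactly. The sketch below swaps in a hypothetical fixed 3 x 5 tensor and rebuilds the same kind of "maximum value and index" listing from the argmax() results:

```python
import torch

# Hypothetical fixed values (not the randn() output above) so the result is reproducible
data = torch.tensor([[ 0.9, -1.2,  0.1,  0.3, -0.5],
                     [-0.4,  2.0,  0.8, -0.7,  1.5],
                     [ 0.2,  0.6, -0.3,  1.1,  0.4]])

col_idx = torch.argmax(data, dim=0)  # one row index per column -> [0, 1, 1, 2, 1]
row_idx = torch.argmax(data, dim=1)  # one column index per row -> [0, 1, 3]

# Look up the value at each returned index to rebuild the listing
for c, r in enumerate(col_idx.tolist()):
    print(f"Column {c}: max value {data[r, c].item():.4f}, row index {r}")
for r, c in enumerate(row_idx.tolist()):
    print(f"Row {r}: max value {data[r, c].item():.4f}, column index {c}")
```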

**Example 2**

In this example, we will create a 2-dimensional tensor with 5 rows and 5 columns and apply argmax() along its rows and columns.

```python
import torch

# Create a 2-D tensor (5 x 5) with random elements using randn()
data = torch.randn(5, 5)

# Display the tensor
print(data)

# Get the index of the maximum value along each column
print(torch.argmax(data, dim=0))

# Get the index of the maximum value along each row
print(torch.argmax(data, dim=1))
```

Output:

```
[-0.5466, -1.6395, 0.2576, -0.3123, 0.6785],
[-0.4574, 1.5301, 0.4812, 0.3434, 0.1388],
[ 0.8364, 0.3821, 0.1529, 1.4529, 0.3747],
[-1.4991, -1.8821, -0.2861, -0.4067, 1.1323]])
tensor([3, 2, 2, 3, 4])
tensor([1, 4, 1, 3, 4])
```

Reading down each column, the maximum values and their row indices are:

- Max value – 0.8364. Its index is 3.
- Max value – 1.5301. Its index is 2.
- Max value – 0.4812. Its index is 2.
- Max value – 1.4529. Its index is 3.
- Max value – 1.1323. Its index is 4.

Similarly, reading across each row, the maximum values and their column indices are:

- Max value – -0.2611. Its index is 1.
- Max value – 0.6785. Its index is 4.
- Max value – 1.5301. Its index is 1.
- Max value – 1.4529. Its index is 3.
- Max value – 1.1323. Its index is 4.
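One way to double-check results like these is torch.max() with a dim argument, which returns both the maximum values and their indices. The sketch below (the seed 0 is arbitrary) compares those indices with argmax() and uses gather() to pull out the values at the argmax() positions:

```python
import torch

torch.manual_seed(0)  # arbitrary seed, only so reruns match
data = torch.randn(5, 5)

# torch.max with dim returns a named tuple of (values, indices)
col_max = torch.max(data, dim=0)
row_max = torch.max(data, dim=1)

# For data without ties, the indices agree with argmax
print(torch.equal(col_max.indices, torch.argmax(data, dim=0)))
print(torch.equal(row_max.indices, torch.argmax(data, dim=1)))

# gather() picks, in each row, the element at the argmax column index
row_vals = data.gather(1, torch.argmax(data, dim=1, keepdim=True)).squeeze(1)
print(torch.equal(row_vals, row_max.values))
```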

**Work With CPU**

Tensors are created on the CPU by default. To make sure argmax() runs on the CPU, we can call the cpu() function on the tensor when we create it. cpu() returns the tensor in CPU memory; if the tensor is already there, no copy is made.

**Syntax**:

`torch.tensor(data).cpu()`
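A minimal sketch of how to confirm where a tensor lives. The GPU branch only runs if CUDA is available and is included purely as an illustration of moving a tensor back to the CPU with cpu():

```python
import torch

data = torch.randn(5, 5)
print(data.device)  # the CPU, unless a different default device was configured

# cpu() returns the tensor in CPU memory; no copy is made if it is already there
cpu_data = data.cpu()
print(cpu_data.device)

if torch.cuda.is_available():
    # Illustration only: send the tensor to the GPU, then bring it back with cpu()
    cpu_data = data.cuda().cpu()

# argmax() now runs on the CPU copy
print(torch.argmax(cpu_data, dim=0))
```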

**Example 1**

```python
# Import the torch module
import torch

# Create a 2-D tensor (5 x 5) with random elements using randn(), on the CPU via cpu()
data = torch.randn(5, 5).cpu()

# Display the tensor
print(data)

# Get the index of the maximum value along each column
print(torch.argmax(data, dim=0))

# Get the index of the maximum value along each row
print(torch.argmax(data, dim=1))
```

Output:

```
[-0.4415, -2.5789, 0.8294, -0.9309, 1.3535],
[-1.3256, -0.7233, -0.9713, 1.0742, 1.9350],
[-0.7126, -1.3336, 0.7371, -0.2253, 0.1675],
[-0.1174, -0.5773, 0.8887, -0.2563, -1.0667]])
tensor([4, 0, 4, 2, 2])
tensor([1, 4, 4, 2, 2])
```

Reading down each column, the maximum values and their row indices are:

- Max value – -0.1174. Its index is 4.
- Max value – 1.6140. Its index is 0.
- Max value – 0.8887. Its index is 4.
- Max value – 1.0742. Its index is 2.
- Max value – 1.9350. Its index is 2.

Similarly, reading across each row, the maximum values and their column indices are:

- Max value – 1.6140. Its index is 1.
- Max value – 1.3535. Its index is 4.
- Max value – 1.9350. Its index is 4.
- Max value – 0.7371. Its index is 2.
- Max value – 0.8887. Its index is 2.
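argmax() can also be called as a tensor method, and dim accepts negative values counted from the last dimension. A small sketch checking that these forms agree on a random 5 x 5 CPU tensor:

```python
import torch

data = torch.randn(5, 5).cpu()

# argmax() is also available as a tensor method
method_cols = data.argmax(dim=0)
function_cols = torch.argmax(data, dim=0)

# Negative dims count from the last dimension: for a 2-D tensor,
# dim=-1 is the same as dim=1 and dim=-2 is the same as dim=0
last_dim = torch.argmax(data, dim=-1)
first_dim = torch.argmax(data, dim=-2)

print(torch.equal(method_cols, function_cols))
print(torch.equal(last_dim, torch.argmax(data, dim=1)))
print(torch.equal(first_dim, function_cols))
```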

**Example 2**

In this example, we will use the cpu() function to create a 2-dimensional tensor with 3 rows and 5 columns and apply argmax() along its rows and columns.

```python
import torch

# Create a 2-D tensor (3 x 5) with random elements using randn(), on the CPU via cpu()
data = torch.randn(3, 5).cpu()

# Display the tensor
print(data)

# Get the index of the maximum value along each column
print(torch.argmax(data, dim=0))

# Get the index of the maximum value along each row
print(torch.argmax(data, dim=1))
```

Output:

```
[-0.3117, 0.2488, 0.2677, 0.2568, 0.5337],
[-1.0966, 1.8024, -0.7538, -0.2553, -1.0591]])
tensor([0, 2, 1, 1, 0])
tensor([1, 4, 1])
```

Reading down each column, the maximum values and their row indices are:

- Max value – 0.6699. Its index is 0.
- Max value – 1.8024. Its index is 2.
- Max value – 0.2677. Its index is 1.
- Max value – 0.2568. Its index is 1.
- Max value – 0.6544. Its index is 0.

Similarly, reading across each row, the maximum values and their column indices are:

- Max value – 1.3390. Its index is 1.
- Max value – 0.5337. Its index is 4.
- Max value – 1.8024. Its index is 1.
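Because randn() draws fresh values on every call, two runs produce the same tensor (and the same argmax() indices) only when the random number generator is seeded. A minimal sketch with torch.manual_seed(); the seed 42 is an arbitrary choice:

```python
import torch

# Seeding the generator makes randn() return the same values on each run
torch.manual_seed(42)
first = torch.randn(3, 5).cpu()

torch.manual_seed(42)  # reset to the same seed
second = torch.randn(3, 5).cpu()

print(torch.equal(first, second))  # same seed, same tensor
print(torch.equal(torch.argmax(first, dim=1), torch.argmax(second, dim=1)))
```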

**Conclusion**

In this PyTorch lesson, we saw what argmax() is and how to apply it to a tensor to return the indices of the maximum values across columns and rows.

We also created tensors with the cpu() function and returned the indices of their maximum values. The dim parameter sets the direction of the comparison: with dim=0, argmax() returns the index of the maximum value in each column, and with dim=1, it returns the index of the maximum value in each row.