PyTorch is an open-source framework for the Python programming language. In PyTorch, data is processed in the form of tensors.
A tensor is a multidimensional array that is used to store data. To use tensors, we have to import the torch module.
To create a tensor, we use the tensor() method.
Syntax:
torch.tensor(data)
Where data is a multi-dimensional array (for example, a nested Python list).
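As a minimal sketch, assuming a recent PyTorch install, the following creates a tensor from a nested Python list with torch.tensor() and inspects its shape and data type. The list values are arbitrary illustration data.

```python
import torch

# A nested list with 2 rows and 3 columns becomes a 2-D tensor
data = [[1, 2, 3],
        [4, 5, 6]]
t = torch.tensor(data)

print(t)
print(t.shape)  # torch.Size([2, 3])
print(t.dtype)  # torch.int64, inferred from the Python ints
```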
argmax()
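argmax() in PyTorch returns the index of the maximum value in the input tensor. When no dim is given, it considers all elements and returns an index into the flattened tensor; when dim is given, it returns the indices of the maximum values along that dimension.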
Syntax:
torch.argmax(tensor, dim, keepdim)
Where
- tensor is the input tensor.
- dim is the dimension to reduce. dim=0 compares values down each column and returns, for every column, the row index of its maximum value; dim=1 compares values across each row and returns, for every row, the column index of its maximum value.
- keepdim specifies whether the output tensor retains the reduced dimension dim (with size 1). It is False by default.
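To make the three parameters concrete, here is a small sketch on a fixed tensor with no tied values, so the results are predictable. The values are arbitrary, and the divmod() step is just one way (an illustration, not part of argmax()) to turn the flattened index back into a row and column.

```python
import torch

# Fixed values (no ties) so the results are predictable
t = torch.tensor([[1.0, 9.0, 3.0],
                  [7.0, 2.0, 8.0]])

flat = torch.argmax(t)                            # no dim: index into the flattened tensor
cols = torch.argmax(t, dim=0)                     # row index of each column's maximum
rows = torch.argmax(t, dim=1)                     # column index of each row's maximum
rows_kept = torch.argmax(t, dim=1, keepdim=True)  # same, but dim 1 is kept with size 1

print(flat)             # tensor(1)
print(cols)             # tensor([1, 0, 1])
print(rows)             # tensor([1, 2])
print(rows_kept.shape)  # torch.Size([2, 1])

# Convert the flattened index back to (row, column) using the number of columns
row, col = divmod(flat.item(), t.shape[1])
print(row, col)         # 0 1
```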
Example 1
In this example, we will create a 2-dimensional tensor with 3 rows and 5 columns and apply argmax() along the columns and rows.
import torch
#create a tensor with 2 dimensions (3 * 5)
#with random elements using randn() function
data = torch.randn(3,5)
#display
print(data)
#get maximum index along columns with argmax
print(torch.argmax(data, dim=0))
#get maximum index along rows with argmax
print(torch.argmax(data, dim=1))
Output:
[-0.3117, 0.2488, 0.2677, 0.2568, 0.5337],
[-1.0966, 1.8024, -0.7538, -0.2553, -1.0591]])
tensor([0, 2, 1, 1, 0])
tensor([1, 4, 1])
We can see that the maximum values along the columns (dim=0) and their row indices are:
- Max value – 0.6699. Its index is 0.
- Max value – 1.8024. Its index is 2.
- Max value – 0.2677. Its index is 1.
- Max value – 0.2568. Its index is 1.
- Max value – 0.6544. Its index is 0.
Similarly, the maximum values along the rows (dim=1) and their column indices are:
- Max value – 1.3390. Its index is 1.
- Max value – 0.5337. Its index is 4.
- Max value – 1.8024. Its index is 1.
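The lists above pair each maximum value with its index. As a sketch, torch.max() called with a dim argument returns both at once as a (values, indices) pair, and its indices agree with torch.argmax() for the same dim. The tensor below uses arbitrary, exactly representable values rather than the random output shown earlier.

```python
import torch

# Arbitrary illustration values (not the random output above)
data = torch.tensor([[0.5, -1.25, 2.0],
                     [1.5,  0.25, -0.75]])

# torch.max with dim returns a (values, indices) pair
col_values, col_indices = torch.max(data, dim=0)
row_values, row_indices = torch.max(data, dim=1)

for value, index in zip(col_values.tolist(), col_indices.tolist()):
    print(f"Max value: {value}, index: {index}")

# The indices match what torch.argmax returns for the same dim
print(torch.equal(col_indices, torch.argmax(data, dim=0)))  # True
print(torch.equal(row_indices, torch.argmax(data, dim=1)))  # True
```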
Example 2
In this example, we will create a 5 * 5 tensor and apply argmax() along the columns and rows.
import torch
#create a tensor with 2 dimensions (5 * 5)
#with random elements using randn() function
data = torch.randn(5,5)
#display
print(data)
#get maximum index along columns with argmax
print(torch.argmax(data, dim=0))
#get maximum index along rows with argmax
print(torch.argmax(data, dim=1))
Output:
[-0.5466, -1.6395, 0.2576, -0.3123, 0.6785],
[-0.4574, 1.5301, 0.4812, 0.3434, 0.1388],
[ 0.8364, 0.3821, 0.1529, 1.4529, 0.3747],
[-1.4991, -1.8821, -0.2861, -0.4067, 1.1323]])
tensor([3, 2, 2, 3, 4])
tensor([1, 4, 1, 3, 4])
We can see that the maximum values along the columns (dim=0) and their row indices are:
- Max value – 0.8364. Its index is 3.
- Max value – 1.5301. Its index is 2.
- Max value – 0.4812. Its index is 2.
- Max value – 1.4529. Its index is 3.
- Max value – 1.1323. Its index is 4.
Similarly, the maximum values along the rows (dim=1) and their column indices are:
- Max value – -0.2611. Its index is 1.
- Max value – 0.6785. Its index is 4.
- Max value – 1.5301. Its index is 1.
- Max value – 1.4529. Its index is 3.
- Max value – 1.1323. Its index is 4.
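Because randn() draws new random values on every run, your output will differ from the one shown. A minimal sketch, assuming you want repeatable results, seeds the generator with torch.manual_seed() before creating the tensor; random_argmax() is a hypothetical helper and the seed value 0 is arbitrary.

```python
import torch

def random_argmax(seed):
    # Hypothetical helper: seed the global generator so randn() is repeatable
    torch.manual_seed(seed)
    data = torch.randn(5, 5)
    return data, torch.argmax(data, dim=0), torch.argmax(data, dim=1)

data_a, cols_a, rows_a = random_argmax(0)
data_b, cols_b, rows_b = random_argmax(0)

print(torch.equal(data_a, data_b))  # True: same seed, same tensor
print(torch.equal(cols_a, cols_b))  # True
# For a 2-D tensor, dim=-1 refers to the last dimension, the same as dim=1
print(torch.equal(rows_a, torch.argmax(data_a, dim=-1)))  # True
```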
Work With CPU
Tensors are created on the CPU by default, but if you want to make sure that argmax() runs on the CPU, you can place the tensor there with the cpu() function.
We can call cpu() at the time the tensor is created.
Syntax:
torch.tensor(data).cpu()
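As a sketch of how the device can be checked, the following inspects the tensor's device attribute. Calling cpu() on a tensor that is already in CPU memory performs no copy, which the data_ptr() comparison shows. The GPU branch runs only if CUDA is available and is included as an optional illustration.

```python
import torch

data = torch.randn(5, 5).cpu()
print(data.device)  # cpu

# cpu() on a tensor already in CPU memory does not copy it,
# so both refer to the same storage
print(data.cpu().data_ptr() == data.data_ptr())  # True

# Optionally compute on a GPU when one is available, then move the result back
if torch.cuda.is_available():
    idx = torch.argmax(data.cuda(), dim=1).cpu()
else:
    idx = torch.argmax(data, dim=1)

print(idx.device)  # cpu
```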
Example 1
#import torch module
import torch
#create a tensor with 2 dimensions (5 * 5)
#with random elements using randn() function with cpu()
data = torch.randn(5,5).cpu()
#display
print(data)
#get maximum index along columns with argmax
print(torch.argmax(data, dim=0))
#get maximum index along rows with argmax
print(torch.argmax(data, dim=1))
Output:
[-0.4415, -2.5789, 0.8294, -0.9309, 1.3535],
[-1.3256, -0.7233, -0.9713, 1.0742, 1.9350],
[-0.7126, -1.3336, 0.7371, -0.2253, 0.1675],
[-0.1174, -0.5773, 0.8887, -0.2563, -1.0667]])
tensor([4, 0, 4, 2, 2])
tensor([1, 4, 4, 2, 2])
We can see that the maximum values along the columns (dim=0) and their row indices are:
- Max value – -0.1174. Its index is 4.
- Max value – 1.6140. Its index is 0.
- Max value – 0.8887. Its index is 4.
- Max value – 1.0742. Its index is 2.
- Max value – 1.9350. Its index is 2.
Similarly, the maximum values along the rows (dim=1) and their column indices are:
- Max value – 1.6140. Its index is 1.
- Max value – 1.3535. Its index is 4.
- Max value – 1.9350. Its index is 4.
- Max value – 0.7371. Its index is 2.
- Max value – 0.8887. Its index is 2.
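To see what argmax() computes along each row, here is a plain-Python version to compare against. argmax_rows() is a hypothetical helper written only for illustration; it keeps the first index when values tie, and transposing the tensor lets the same helper handle dim=0.

```python
import torch

def argmax_rows(matrix):
    # Hypothetical helper: for each row, return the column index of its largest value.
    # Only a strictly larger value replaces the current best, so ties keep the first index.
    result = []
    for row in matrix:
        best = 0
        for j in range(1, len(row)):
            if row[j] > row[best]:
                best = j
        result.append(best)
    return result

data = torch.randn(5, 5).cpu()

manual_rows = argmax_rows(data.tolist())
# Transposing turns columns into rows, so the same helper covers dim=0
manual_cols = argmax_rows(data.t().tolist())

print(manual_rows == torch.argmax(data, dim=1).tolist())  # True
print(manual_cols == torch.argmax(data, dim=0).tolist())  # True
```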
Example 2
In this example, we will create a 2-dimensional tensor with 3 rows and 5 columns using the cpu() function and apply argmax() along the columns and rows.
import torch
#create a tensor with 2 dimensions (3 * 5)
#with random elements using randn() with cpu() function
data = torch.randn(3,5).cpu()
#display
print(data)
#get maximum index along columns with argmax
print(torch.argmax(data, dim=0))
#get maximum index along rows with argmax
print(torch.argmax(data, dim=1))
Output:
[-0.3117, 0.2488, 0.2677, 0.2568, 0.5337],
[-1.0966, 1.8024, -0.7538, -0.2553, -1.0591]])
tensor([0, 2, 1, 1, 0])
tensor([1, 4, 1])
We can see that the maximum values along the columns (dim=0) and their row indices are:
- Max value – 0.6699. Its index is 0.
- Max value – 1.8024. Its index is 2.
- Max value – 0.2677. Its index is 1.
- Max value – 0.2568. Its index is 1.
- Max value – 0.6544. Its index is 0.
Similarly, the maximum values along the rows (dim=1) and their column indices are:
- Max value – 1.3390. Its index is 1.
- Max value – 0.5337. Its index is 4.
- Max value – 1.8024. Its index is 1.
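The indices returned by argmax() are themselves a tensor. As a sketch, the following checks their data type, shape, and device, and then uses them with advanced indexing to pick out the maximum values, comparing the result against data.max(dim=0).values.

```python
import torch

data = torch.randn(3, 5).cpu()
idx = torch.argmax(data, dim=0)

print(idx.dtype)   # torch.int64: indices are returned as a LongTensor
print(idx.shape)   # torch.Size([5]): one index per column
print(idx.device)  # cpu: same device as the input

# Pair each column's row index with its column number to read out the maxima
max_per_column = data[idx, torch.arange(data.shape[1])]
print(torch.equal(max_per_column, data.max(dim=0).values))  # True
```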
Conclusion
In this PyTorch lesson, we saw what argmax() is and how to apply it to a tensor to return the indices of the maximum values across columns and rows.
We also created a tensor with the cpu() function and returned the indices of its maximum values. The dim parameter controls the direction: when it is set to 0, argmax() returns the indices of the maximum values across columns, and when it is set to 1, it returns them across rows.