
Sorting the Elements of a Tensor in PyTorch

In this PyTorch tutorial, we will see how to sort the elements of a PyTorch tensor.

PyTorch is an open-source framework for the Python programming language. In PyTorch, data is processed in the form of tensors. A tensor is a multidimensional array that is used to store the data. To use tensors, we have to import the torch module, and to create a tensor, we use the tensor() method.

Syntax:

torch.tensor(data)

where data is a multidimensional array, such as a nested Python list.
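As a quick illustration, here is a minimal sketch of creating a tensor from a nested Python list with torch.tensor(). The variable name data is only illustrative. A list of integers produces an int64 tensor, and the nesting depth determines the number of dimensions.

```python
import torch

# a nested Python list becomes a 2D tensor with 2 rows and 3 columns
data = torch.tensor([[1, 2, 3], [4, 5, 6]])

print(data.shape)  # torch.Size([2, 3])
print(data.dtype)  # torch.int64
print(data.dim())  # 2
```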

torch.sort()

torch.sort() sorts the elements of a tensor in ascending order. Its second parameter (dim) selects the dimension to sort along: for a two-dimensional tensor, specifying 1 sorts each row (row-wise) and specifying 0 sorts each column (column-wise).
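Ascending order is only the default. torch.sort() also accepts a descending keyword argument (False by default) that reverses the order along the chosen dimension. The sketch below uses a small illustrative tensor named data; it is not one of the tutorial's examples.

```python
import torch

data = torch.tensor([[23, 45, 67, 0, 0], [12, 21, 10, 34, 78]])

# sort each row (dim 1) from largest to smallest
result = torch.sort(data, 1, descending=True)
print(result.values)
```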

Syntax:
Row-wise: torch.sort(two_dimensional_tensor_object, 1)

Column-wise: torch.sort(two_dimensional_tensor_object, 0)

Parameter:

  1. two_dimensional_tensor_object is a tensor with two dimensions.
  2. The second argument is the dimension to sort along: 1 sorts row-wise and 0 sorts column-wise.

If the second argument is omitted, it sorts along the last dimension (dim=-1), which for a 2D tensor means row-wise.
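To check the default behavior, we can compare the sorted values with and without an explicit dimension. This sketch compares only the values, because the order of indices for equal elements is not guaranteed by default; the variable names are illustrative.

```python
import torch

data1 = torch.tensor([[23, 45, 67, 0, 0], [12, 21, 10, 34, 78],
                      [3, 4, 5, 2, 3], [45, 67, 54, 32, 22]])

# the default dim is -1 (the last dimension), which is dim 1 for a 2D tensor
default_sort = torch.sort(data1)
explicit_sort = torch.sort(data1, dim=1)
last_dim_sort = torch.sort(data1, dim=-1)

print(torch.equal(default_sort.values, explicit_sort.values))  # True
print(torch.equal(default_sort.values, last_dim_sort.values))  # True
```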

Return:
It returns a named tuple of two tensors: values, the sorted tensor, and indices, the position of each sorted element in the original tensor.
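Because the result is a named tuple, we can unpack it directly or read its values and indices fields by name. The following minimal sketch uses an illustrative two-row tensor; the indices tensor has the int64 dtype.

```python
import torch

data1 = torch.tensor([[23, 45, 67, 0, 0], [12, 21, 10, 34, 78]])

result = torch.sort(data1, 1)

# the result can be unpacked like an ordinary tuple...
values, indices = result

# ...or accessed through its named fields
print(result.values)
print(result.indices)
```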

Example 1:

Let’s create a 2D tensor that has 4 rows and 5 columns. Then, we sort it row-wise without specifying a second parameter.

#import torch module
import torch
 
#create a 2D tensor - data1 with 5 numeric values in 4 rows
data1 = torch.tensor([[23,45,67,0,0],[12,21,10,34,78],[3,4,5,2,3],[45,67,54,32,22]])
 
#display
print("Tensor: ",data1)
 
#sort the above tensor
print("After sorting row-wise: ")
print(torch.sort(data1))

Output:

Tensor:  tensor([[23, 45, 67,  0,  0],
        [12, 21, 10, 34, 78],
        [ 3,  4,  5,  2,  3],
        [45, 67, 54, 32, 22]])
After sorting row-wise:
torch.return_types.sort(
values=tensor([[ 0,  0, 23, 45, 67],
        [10, 12, 21, 34, 78],
        [ 2,  3,  3,  4,  5],
        [22, 32, 45, 54, 67]]),
indices=tensor([[3, 4, 0, 1, 2],
        [2, 0, 1, 3, 4],
        [3, 0, 4, 1, 2],
        [4, 3, 0, 2, 1]]))

We can observe that the elements of each row are sorted in ascending order, and that the indices of their positions in the original tensor are returned as well.
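The first and third rows contain repeated values (two 0s and two 3s). By default, torch.sort() does not guarantee the relative order of equal elements, so the indices reported for ties are not guaranteed to follow their original order. Assuming a reasonably recent PyTorch release that supports the stable keyword, passing stable=True keeps equal elements in their original order, which makes the indices deterministic.

```python
import torch

data1 = torch.tensor([[23, 45, 67, 0, 0], [12, 21, 10, 34, 78],
                      [3, 4, 5, 2, 3], [45, 67, 54, 32, 22]])

# stable=True: for equal values, the one that appears first in the row keeps
# the smaller position in the sorted result
values, indices = torch.sort(data1, dim=1, stable=True)
print(indices)
```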

Example 2:

Let’s create the same 2D tensor with 4 rows and 5 columns. Then, we sort it row-wise by specifying 1 as the second parameter.

#import torch module
import torch
 
#create a 2D tensor - data1 with 5 numeric values in 4 rows
data1 = torch.tensor([[23,45,67,0,0],[12,21,10,34,78],[3,4,5,2,3],[45,67,54,32,22]])
 
#display
print("Tensor: ",data1)
 
#sort the above tensor
print("After sorting row-wise: ")
print(torch.sort(data1,1))

Output:

Tensor:  tensor([[23, 45, 67,  0,  0],
        [12, 21, 10, 34, 78],
        [ 3,  4,  5,  2,  3],
        [45, 67, 54, 32, 22]])
After sorting row-wise:
torch.return_types.sort(
values=tensor([[ 0,  0, 23, 45, 67],
        [10, 12, 21, 34, 78],
        [ 2,  3,  3,  4,  5],
        [22, 32, 45, 54, 67]]),
indices=tensor([[3, 4, 0, 1, 2],
        [2, 0, 1, 3, 4],
        [3, 0, 4, 1, 2],
        [4, 3, 0, 2, 1]]))

We can observe that passing 1 gives the same result as the default: each row is sorted in ascending order, and the indices of the elements' positions in the original tensor are returned.
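The same row-wise sort can also be written with the Tensor.sort() method, which takes the same dim argument as torch.sort(). This sketch shows that the method returns new tensors and leaves the original tensor unchanged.

```python
import torch

data1 = torch.tensor([[23, 45, 67, 0, 0], [12, 21, 10, 34, 78],
                      [3, 4, 5, 2, 3], [45, 67, 54, 32, 22]])

# method form of torch.sort(); data1 itself is not modified
values, indices = data1.sort(dim=1)
print(values)
print(data1[0])  # still tensor([23, 45, 67,  0,  0])
```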

Example 3:

Let’s create the same 2D tensor with 4 rows and 5 columns. Then, we sort it column-wise by specifying 0 as the second parameter.

#import torch module
import torch
 
#create a 2D tensor - data1 with 5 numeric values in 4 rows
data1 = torch.tensor([[23,45,67,0,0],[12,21,10,34,78],[3,4,5,2,3],[45,67,54,32,22]])
 
#display
print("Tensor: ",data1)
 
#sort the above tensor
print("After sorting column-wise: ")
print(torch.sort(data1,0))

Output:

Tensor:  tensor([[23, 45, 67,  0,  0],
        [12, 21, 10, 34, 78],
        [ 3,  4,  5,  2,  3],
        [45, 67, 54, 32, 22]])
After sorting column-wise:
torch.return_types.sort(
values=tensor([[ 3,  4,  5,  0,  0],
        [12, 21, 10,  2,  3],
        [23, 45, 54, 32, 22],
        [45, 67, 67, 34, 78]]),
indices=tensor([[2, 2, 2, 0, 0],
        [1, 1, 1, 2, 2],
        [0, 0, 3, 3, 3],
        [3, 3, 0, 1, 1]]))

We can observe that the elements of each column are sorted in ascending order, and that the indices of their positions in the original tensor are returned as well.
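For a column-wise sort, indices[i][j] is the row from which values[i][j] was taken in column j. One way to confirm this is torch.gather() along dimension 0, which picks data1[indices[i][j]][j] for every position. The name rebuilt is only illustrative.

```python
import torch

data1 = torch.tensor([[23, 45, 67, 0, 0], [12, 21, 10, 34, 78],
                      [3, 4, 5, 2, 3], [45, 67, 54, 32, 22]])

values, indices = torch.sort(data1, 0)

# gather along dim 0: rebuilt[i][j] = data1[indices[i][j]][j]
rebuilt = torch.gather(data1, 0, indices)
print(torch.equal(rebuilt, values))  # True
```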

Example 4:

Let’s create a 1D tensor that has 5 values. Then, we sort it by using the sort() function.

#import torch module
import torch
 
#create a 1D tensor - data1 with 5 numeric values
data1 = torch.tensor([23,45,67,0,0])
 
#display
print("Tensor: ",data1)
 
#sort the above tensor
print("After sorting: ")
print(torch.sort(data1))

Output:

Tensor:  tensor([23, 45, 67,  0,  0])
After sorting:
torch.return_types.sort(
values=tensor([ 0,  0, 23, 45, 67]),
indices=tensor([3, 4, 0, 1, 2]))

We can observe that the elements are sorted in ascending order, and that the indices of their positions in the original tensor are returned as well.
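If only the sorting order is needed, torch.argsort() returns just the indices. Indexing the 1D tensor with those indices produces the sorted values. The sketch does not rely on which of the two 0s comes first, since that tie order is not guaranteed by default.

```python
import torch

data1 = torch.tensor([23, 45, 67, 0, 0])

# only the indices that would sort the tensor
order = torch.argsort(data1)

# indexing with those indices yields the sorted values
print(data1[order])  # tensor([ 0,  0, 23, 45, 67])
```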

Work with CPU

Tensors are created on the CPU by default, so sort() runs on the CPU unless the tensor has been moved to another device. To make this explicit, we can call the cpu() method when we create the tensor.

Syntax:

torch.tensor(data).cpu()
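We can confirm where a tensor lives through its device attribute. In this sketch, the optional GPU branch runs only if torch.cuda.is_available() reports a CUDA device; that branch is an illustration beyond the tutorial's CPU-only example.

```python
import torch

data1 = torch.tensor([[23, 45, 67, 0, 0], [12, 21, 10, 34, 78]])

# tensors are placed on the CPU unless another device is requested
print(data1.device)  # cpu

# .cpu() returns a tensor in CPU memory (no copy if it is already there)
cpu_data = data1.cpu()
print(cpu_data.device)  # cpu

# optional: run the same row-wise sort on a GPU when one is available
if torch.cuda.is_available():
    gpu_values, gpu_indices = torch.sort(data1.cuda(), 1)
    print(gpu_values.cpu())
```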

Example:

Let’s create a 2D tensor with 4 rows and 5 columns on the CPU. Then, we sort it row-wise by specifying 1 as the second parameter and column-wise by specifying 0 as the second parameter.

#import torch module
import torch
 
#create a 2D tensor - data1 with 5 numeric values in 4 rows
data1 = torch.tensor([[23,45,67,0,0],[12,21,10,34,78],[3,4,5,2,3],[45,67,54,32,22]]).cpu()
 
#display
print("Tensor: ",data1)
 
print()
#sort the above tensor
print("After sorting row-wise: ")
print(torch.sort(data1,1))

print()

#sort the above tensor
print("After sorting column-wise: ")
print(torch.sort(data1,0))

Output:

Tensor:  tensor([[23, 45, 67,  0,  0],
        [12, 21, 10, 34, 78],
        [ 3,  4,  5,  2,  3],
        [45, 67, 54, 32, 22]])

After sorting row-wise:
torch.return_types.sort(
values=tensor([[ 0,  0, 23, 45, 67],
        [10, 12, 21, 34, 78],
        [ 2,  3,  3,  4,  5],
        [22, 32, 45, 54, 67]]),
indices=tensor([[3, 4, 0, 1, 2],
        [2, 0, 1, 3, 4],
        [3, 0, 4, 1, 2],
        [4, 3, 0, 2, 1]]))

After sorting column-wise:
torch.return_types.sort(
values=tensor([[ 3,  4,  5,  0,  0],
        [12, 21, 10,  2,  3],
        [23, 45, 54, 32, 22],
        [45, 67, 67, 34, 78]]),
indices=tensor([[2, 2, 2, 0, 0],
        [1, 1, 1, 2, 2],
        [0, 0, 3, 3, 3],
        [3, 3, 0, 1, 1]]))

We can observe that the elements are sorted row-wise and column-wise in ascending order, and that the indices of their positions in the original tensor are returned in both cases.

Conclusion

In this PyTorch tutorial, we learned how to sort the elements of a tensor in ascending order using the torch.sort() function. For a two-dimensional tensor, it sorts row-wise when we specify 1 and column-wise when we specify 0. It returns the sorted tensor along with the index positions of the elements in the original tensor.

We worked through several examples, including one that uses the cpu() function. For a 1D tensor, torch.sort() does not need a dimension argument.

About the author

Gottumukkala Sravan Kumar

B.Tech (Hons.) in Information Technology; known programming languages: Python, R, PHP, MySQL; published 500+ articles in the computer science domain