How to Flatten a Tensor in PyTorch?

In PyTorch, tensors are multidimensional arrays whose elements share a single data type. Flattening a tensor means reshaping it into a one-dimensional tensor. A tensor with any number of dimensions, such as a 2D or 3D tensor, can be flattened to a 1D tensor. PyTorch provides the “flatten()” method for this purpose. It changes only the shape; the values and their order stay the same.

This article covers flattening a tensor with the default settings and with the “start_dim” and “end_dim” parameters.

How to Flatten a Tensor in PyTorch?

To flatten a tensor in PyTorch, follow the steps below:

Step 1: Import PyTorch Library

First, import the “torch” library, which provides the “flatten()” method used to flatten the input tensor:

import torch

Step 2: Create Input Tensor

Then, define the input tensor and print its elements. For instance, the following code creates a 3D tensor using the “torch.tensor()” function:

tens = torch.tensor([[[1, 3],[5, 7]],[[9, 11],[13, 15]]])

print("Tensor: \n", tens)

This creates a 3D tensor that holds two 2x2 matrices, eight elements in total.

Step 3: View Input Tensor Size

Next, use the “size()” method to view the size of the “tens” tensor created above:

print(tens.size())

The output is “torch.Size([2, 2, 2])”, so the shape of the “tens” tensor is 2x2x2.
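As a small sketch of the same check, the snippet below compares the “size()” method with the “shape” attribute and also reads the number of dimensions and elements. The variable names “size_from_method” and “size_from_attr” are illustrative only:

```python
import torch

# Same 3D tensor as in Step 2
tens = torch.tensor([[[1, 3], [5, 7]], [[9, 11], [13, 15]]])

size_from_method = tens.size()  # size() is a method, so it needs parentheses
size_from_attr = tens.shape     # shape is an attribute holding the same value

print(size_from_method == size_from_attr)  # compares the two torch.Size objects
print(tens.dim())    # number of dimensions
print(tens.numel())  # total number of elements
```

Either form can be used; the rest of this article prints “shape” for the flattened tensor.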

Step 4: Flatten Input Tensor

Now, flatten the input tensor via the “flatten()” method. It returns a flattened tensor and leaves the shape of “tens” itself unchanged, so the result is stored in a new variable. The new shape is verified in the next step:

flatten_tens = torch.flatten(tens)

Step 5: Display the Flattened Tensor and Its Shape

Lastly, display the elements of the flattened tensor and the tensor shape:

print(flatten_tens)

print(flatten_tens.shape)

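The output shows that the 3D input tensor has been changed to a 1D tensor. It holds the eight elements 1, 3, 5, 7, 9, 11, 13, and 15 in their original order, and its shape is “torch.Size([8])”.

This indicates that the tensor has been flattened successfully.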
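The snippet below is a minimal sketch that checks these results in code rather than by eye. It also shows that the original tensor keeps its shape and that the method form “tens.flatten()” gives the same result. The names “flat_fn”, “flat_method”, and “expected” are illustrative:

```python
import torch

tens = torch.tensor([[[1, 3], [5, 7]], [[9, 11], [13, 15]]])

flat_fn = torch.flatten(tens)  # function form used in Step 4
flat_method = tens.flatten()   # equivalent method form

# Elements keep their row-major order
expected = torch.tensor([1, 3, 5, 7, 9, 11, 13, 15])
print(torch.equal(flat_fn, expected))
print(torch.equal(flat_fn, flat_method))

# The input tensor itself still has its original 2x2x2 shape
print(flat_fn.shape, tens.shape)
```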

How to Flatten a Tensor With Different Parameters in PyTorch?

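Users can also pass parameters to the “torch.flatten()” method to flatten only some of the dimensions. These parameters are “start_dim” and “end_dim”. The “start_dim” parameter is the first dimension to flatten (default 0) and the “end_dim” parameter is the last dimension to flatten (default -1, the last dimension). All dimensions from “start_dim” through “end_dim” are merged into a single dimension, and the remaining dimensions are kept as they are.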
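As a rough illustration of how “start_dim” and “end_dim” select a range of dimensions to merge, the sketch below uses a hypothetical 4D tensor “x” of shape 2x3x4x5. It is not part of the article's running example and is chosen only because the merged ranges are easier to see with four dimensions:

```python
import torch

# Hypothetical 4D tensor, used only to make the dimension ranges visible
x = torch.zeros(2, 3, 4, 5)

print(torch.flatten(x).shape)                          # dims 0..3 merged: 2*3*4*5 = 120
print(torch.flatten(x, start_dim=1).shape)             # dims 1..3 merged: (2, 60)
print(torch.flatten(x, end_dim=1).shape)               # dims 0..1 merged: (6, 4, 5)
print(torch.flatten(x, start_dim=1, end_dim=2).shape)  # dims 1..2 merged: (2, 12, 5)
```

In each case the size of the merged dimension is the product of the sizes it replaces.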

To flatten a tensor with these parameters, look at the provided examples:

Example 1: Flatten Tensor With “start_dim” Parameter

In the first example, we will flatten the input “tens” tensor with the “start_dim” parameter. Here, we have specified the dimension “1”, which means that flattening starts from dimension 1 and dimension 0 is kept. The output tensor will be a 2D tensor of shape 2x4:

flatten_st1 = torch.flatten(tens, start_dim=1)

print(flatten_st1)

The output is a 2D tensor with two rows, [1, 3, 5, 7] and [9, 11, 13, 15], so each 2x2 matrix of the input has been flattened into one row.
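A minimal sketch of a programmatic check for this example follows. The “expected” tensor is written out by hand for illustration:

```python
import torch

tens = torch.tensor([[[1, 3], [5, 7]], [[9, 11], [13, 15]]])
flatten_st1 = torch.flatten(tens, start_dim=1)

# Dimension 0 is kept; dimensions 1 and 2 are merged into one of length 4
expected = torch.tensor([[1, 3, 5, 7], [9, 11, 13, 15]])
print(flatten_st1.shape)
print(torch.equal(flatten_st1, expected))
```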

Example 2: Flatten Tensor With “end_dim” Parameter

In the second example, we will flatten the input “tens” tensor with the “end_dim” parameter. Here, we have specified the dimension “1”, which means that dimensions 0 through 1 are merged and dimension 2 is kept. The resultant tensor will be a 2D tensor of shape 4x2:

flatten_ed = torch.flatten(tens, end_dim=1)

print(flatten_ed)

The output is a 2D tensor with four rows: [1, 3], [5, 7], [9, 11], and [13, 15].
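The same kind of check can be sketched for the “end_dim” example. Again, the “expected” tensor is typed in by hand for illustration:

```python
import torch

tens = torch.tensor([[[1, 3], [5, 7]], [[9, 11], [13, 15]]])
flatten_ed = torch.flatten(tens, end_dim=1)

# Dimensions 0 and 1 are merged into one of length 4; dimension 2 is kept
expected = torch.tensor([[1, 3], [5, 7], [9, 11], [13, 15]])
print(flatten_ed.shape)
print(torch.equal(flatten_ed, expected))
```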

Example 3: Flatten Tensor With “start_dim” and “end_dim” Parameters

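In the third example, we will flatten the tensor with both “start_dim” and “end_dim” parameters. Dimensions 1 through 2 are merged and dimension 0 is kept. Because dimension 2 is the last dimension of this 3D tensor, the result has the same 2x4 shape as in Example 1: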

flatten_tens = torch.flatten(tens, start_dim=1, end_dim=2)

print(flatten_tens)

The output is a 2D tensor with the rows [1, 3, 5, 7] and [9, 11, 13, 15], which confirms that the input tensor has been flattened over the specified range of dimensions.
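The sketch below compares this call with two equivalent ways of writing it for a 3D tensor: omitting “end_dim” so that its default applies, and passing the negative index -1 for the last dimension. The variable names are illustrative:

```python
import torch

tens = torch.tensor([[[1, 3], [5, 7]], [[9, 11], [13, 15]]])

both = torch.flatten(tens, start_dim=1, end_dim=2)
start_only = torch.flatten(tens, start_dim=1)                # end_dim defaults to -1
negative_end = torch.flatten(tens, start_dim=1, end_dim=-1)  # -1 is the last dimension

print(both.shape)
print(torch.equal(both, start_only))
print(torch.equal(both, negative_end))
```

For tensors with more than three dimensions, “end_dim=2” and “end_dim=-1” refer to different dimensions, so the three calls would no longer match.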

These examples cover how to flatten a tensor with different parameters in PyTorch.

Conclusion

To flatten a tensor in PyTorch, first, import the torch library. Then, create the input tensor and view its elements and size. After that, use the “flatten()” method to get a flattened copy of the input tensor with a new shape. Finally, print the elements of the flattened tensor and the tensor shape. Moreover, users can pass the “start_dim” and “end_dim” parameters to the “torch.flatten()” method to merge only a chosen range of dimensions and keep the rest. This article has explained the method to flatten a tensor in PyTorch.

About the author

Laiba Younas

I hold a bachelor's degree in Computer Science. Being passionate about learning new technologies, I am interested in exploring different programming languages and sharing my experience with the world.