This article demonstrates methods for changing the data type of tensors in PyTorch.
How to Change/Modify the Data Type of a Tensor in PyTorch?
To change the data type of a tensor in PyTorch, you can use either of the following methods:
Method 1: Convert Data Type of Tensor Using “to()” Method
The “to()” method takes a data type as an argument and returns a tensor with the same shape whose values are converted to the specified data type. The original tensor is not modified. Follow the examples below to understand it better:
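Here is a minimal sketch of what “returns a tensor” means in practice, using illustrative values that are not from the article. It assumes a CPU tensor and shows three things: converting floats to an integer type truncates toward zero, the original tensor keeps its dtype, and “to()” returns the very same object when the tensor already has the requested dtype.

```python
import torch

# Illustrative float tensor with fractional values (not from the article).
x = torch.tensor([1.5, -2.7, 3.0])

# to() returns a converted tensor; x itself keeps its original dtype.
y = x.to(torch.int32)
print(x.dtype)      # torch.float32
print(y.dtype)      # torch.int32
print(y.tolist())   # [1, -2, 3]  (fractional parts truncated toward zero)

# If the dtype already matches, to() returns the same tensor object.
z = x.to(torch.float32)
print(z is x)       # True
```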
Example 1: Converting Floating Tensor to Integer Tensor
In the first example, we will create a zero tensor, which has a float data type by default, and convert its data type from float to integer. Follow the instructions below:
First, import the “torch” library:
Then, create the desired tensor and print its elements. For instance, we are creating a 3×4 zero tensor “tens1” using the “torch.zeros()” function:
tens1 = torch.zeros(3, 4)
print(tens1)
This has created a 3×4 zero tensor:
After that, print the data type of the tensor using the “dtype” attribute:
The data type of “tens1” is “float32” (32-bit floating-point values):
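The float32 default is not specific to “torch.zeros()”: PyTorch's factory functions fall back on a global default floating-point dtype. A quick check, assuming the default has not been changed with “torch.set_default_dtype()”:

```python
import torch

# Factory functions such as torch.zeros() use the global default
# floating-point dtype when no dtype argument is given.
default_dtype = torch.get_default_dtype()
zeros_default = torch.zeros(3, 4)
print(default_dtype)          # torch.float32, unless the default was changed
print(zeros_default.dtype)    # same as the default dtype

# Passing dtype= at creation time produces the target type directly.
zeros_int = torch.zeros(3, 4, dtype=torch.int32)
print(zeros_int.dtype)        # torch.int32
```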
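Now, convert the data type of the “tens1” tensor to “int32” using the “to()” method. Since “to()” returns a new tensor instead of modifying the original one, assign the result back to “tens1”: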
Finally, print the modified data type of the tensor:
The below output shows that the data type of the “tens1” tensor has been successfully changed from “float32” to “int32”:
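Putting the steps of this example together, a minimal runnable sketch could look like the following. It assumes the tensor was created with “torch.zeros(3, 4)” and that the result of “to()” was assigned back to “tens1”:

```python
import torch

# Create a 3x4 tensor of zeros; its dtype defaults to float32.
tens1 = torch.zeros(3, 4)
print(tens1)
print(tens1.dtype)   # torch.float32

# to() returns a new tensor, so assign the result back to tens1.
tens1 = tens1.to(torch.int32)
print(tens1.dtype)   # torch.int32
```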
Example 2: Converting Integer Tensor to Float Tensor
In the second example, we will create a simple 2D tensor from a list of integers, which gives the tensor an integer data type by default. We will then convert its data type from integer to float.
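The integer default comes from the way “torch.tensor()” infers a dtype from the Python values it receives. A small sketch with illustrative values:

```python
import torch

# torch.tensor() infers the dtype from the Python values it is given.
ints = torch.tensor([[1, 2], [3, 4]])
floats = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
bools = torch.tensor([True, False])

print(ints.dtype)    # torch.int64
print(floats.dtype)  # torch.float32 (the default floating-point dtype)
print(bools.dtype)   # torch.bool
```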
Use the “torch.tensor()” function to define a 2D tensor from a list of integer values and print its content. Here, we are creating the following tensor and storing it in a “tens2” variable:
print(tens2)
This has created a 2D integer tensor:
Then, view the data type of the tensor via the “.dtype” attribute:
According to the output below, the data type of the “tens2” tensor is “int64” (64-bit integer values):
Now, convert the data type of the “tens2” tensor to “float32” using the “to()” method:
Finally, print the modified data type of the tensor:
The data type of the “tens2” tensor has been successfully changed from “int64” to “float32”:
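The whole of Example 2 might be sketched as follows. The list values are hypothetical, since the article's original list is not reproduced here:

```python
import torch

# Hypothetical values standing in for the article's list.
tens2 = torch.tensor([[1, 2, 3], [4, 5, 6]])
print(tens2)
print(tens2.dtype)   # torch.int64

# Convert to 32-bit floats; small integers like these are represented exactly.
tens2 = tens2.to(torch.float32)
print(tens2.dtype)   # torch.float32
print(tens2)
```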
Method 2: Convert Data Type of Tensor Using “type()” Method
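The “type()” method takes a data type (for example, “torch.float16”) as an argument and returns a tensor with the same shape whose values are converted to that data type. When called without an argument, it instead returns the tensor's type as a string. Check out the examples below to understand it better.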
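A short sketch of both ways to call “type()”, using illustrative values. Besides a dtype object, “type()” also accepts a tensor-type string such as “torch.IntTensor”:

```python
import torch

t = torch.tensor([1, 2, 3])

# Without arguments, type() returns the tensor type as a string.
print(t.type())          # torch.LongTensor

# With a dtype, it returns a converted tensor; t itself is unchanged.
h = t.type(torch.float16)
print(h.dtype, t.dtype)  # torch.float16 torch.int64

# A tensor-type string is accepted as well.
i = t.type('torch.IntTensor')
print(i.dtype)           # torch.int32
```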
Example 1: Converting Integer Tensor to Floating Tensor
Here, we will create a simple 1D tensor from a list of integers, which has an integer data type by default, and convert its data type from integer to floating-point.
First, import the “torch” library:
Then, define a 1D tensor from a list of integer values and print its content. Here, we are creating the following tensor and storing it in a “tensor1” variable:
print(tensor1)
This has created a 1D integer tensor:
Next, use the “.dtype” attribute to view the data type of the tensor:
The output below shows that the data type of the “tensor1” tensor is “int64” (64-bit integer values):
To convert the “tensor1” tensor to a floating-point type, use the “type()” method. Here, we will convert the tensor into the “float16” data type:
Lastly, print the modified data type of the tensor:
According to the below output, the data type of the “tensor1” tensor has been successfully changed from “int64” to “float16”:
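As a sketch of this example end to end, with hypothetical list values standing in for the article's:

```python
import torch

# Hypothetical values standing in for the article's list.
tensor1 = torch.tensor([10, 20, 30, 40])
print(tensor1)
print(tensor1.dtype)   # torch.int64

# type() returns a converted tensor, so assign it back to tensor1.
tensor1 = tensor1.type(torch.float16)
print(tensor1.dtype)   # torch.float16
```

Keep in mind that float16 has limited precision: it represents every integer exactly only up to 2048 in magnitude, so larger integer values may be rounded after the conversion.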
Example 2: Converting Floating Tensor to Integer Tensor
In this example, we will create a tensor of random values, which has a floating-point data type by default, and convert its data type from float to integer.
Create a random tensor using the “torch.rand()” function and print its content. Here, we are creating a tensor of random values with 2 rows and 3 columns and storing it in a “tensor2” variable:
tensor2 = torch.rand(2, 3)
print(tensor2)
The 2×3 random tensor has been created successfully:
Next, view the data type of the tensor with the help of the “.dtype” attribute:
The below output shows that the data type of the “tensor2” tensor is “float32” (32-bit floating-point values):
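Now, use the “type()” method to convert the “tensor2” tensor to an integer type. Here, we will convert the tensor into the “int64” data type. Note that converting floating-point values to integers truncates the fractional part, so the values from “torch.rand()”, which all lie in [0, 1), become 0: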
Finally, print the modified data type of the tensor:
It can be observed that the data type of the “tensor2” tensor has been successfully changed from “float32” to “int64”:
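A minimal end-to-end sketch of this example. Because “torch.rand()” draws values from [0, 1) and the conversion truncates toward zero, every entry of the converted tensor is 0; the random values printed first will differ from run to run:

```python
import torch

# 2x3 tensor of random floats from [0, 1); dtype defaults to float32.
tensor2 = torch.rand(2, 3)
print(tensor2)
print(tensor2.dtype)   # torch.float32

# Convert to int64; fractional parts are truncated toward zero.
tensor2 = tensor2.type(torch.int64)
print(tensor2.dtype)   # torch.int64
print(tensor2)         # all zeros, since every original value was below 1
```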
We have explained the methods of changing the data type of tensors in PyTorch with examples.
Conclusion
To change the data type of a tensor in PyTorch, first import the “torch” library and create the desired tensor. Then, use the “dtype” attribute to print the tensor's current data type, which can be, for example, an integer, floating-point, or boolean type. Next, convert the tensor to the desired data type using the “to()” or “type()” method and print the new data type. This article has demonstrated both methods with examples.
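Boolean tensors can be converted the same way. As an illustrative sketch (values chosen only for demonstration), converting to “torch.bool” with “to()” maps zero to False and any nonzero value to True, and converting back maps True to 1 and False to 0:

```python
import torch

x = torch.tensor([0.0, 1.5, -2.0, 0.0])

# Zero becomes False; any nonzero value becomes True.
b = x.to(torch.bool)
print(b.tolist())                    # [False, True, True, False]

# Converting back to an integer type gives 1 for True and 0 for False.
print(b.to(torch.int64).tolist())    # [0, 1, 1, 0]
```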