How to Normalize An Image in PyTorch?

PyTorch is a popular choice of framework for developing machine learning models that analyze images from large datasets. The thousands of images in such datasets can differ widely in properties such as brightness and color saturation. The “Normalization” of an image dataset rescales the pixel values to a common, narrow range, which typically makes training faster and more stable.

In this blog, the spotlight will be on normalizing an image dataset in PyTorch.

What is the Normalization of an Image in PyTorch?

The “Normalization” of images in PyTorch is a common preprocessing step that helps image analysis models train faster and more stably. It is usually applied before training so that the data is clean and consistently scaled. Normalized image data is easier for the model to learn from and to make credible inferences with. The concept behind normalization is a linear transformation of the numerical value of each image pixel. The “transforms.ToTensor()” function first scales the pixel values to the range “[0, 1]”. Then, the mean is subtracted from each individual pixel value and the result is divided by the standard deviation, with both statistics computed per color channel. The normalized values therefore have a mean close to 0 and a standard deviation close to 1, and they are no longer confined to “[0, 1]”.
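To make the arithmetic concrete, here is a minimal sketch in plain Python. It assumes five made-up pixel values that have already been scaled to “[0, 1]”; the names and values are illustrative only and do not come from the CIFAR-10 image used later:

```python
import statistics

# Hypothetical pixel values, already scaled to the range [0, 1]
pixels = [0.0, 0.25, 0.5, 0.75, 1.0]

mean = statistics.fmean(pixels)   # average pixel value
std = statistics.pstdev(pixels)   # population standard deviation

# Normalization: subtract the mean, then divide by the standard deviation
normalized = [(p - mean) / std for p in pixels]

print("Mean:", mean)
print("STD:", std)
print("Normalized:", normalized)
```

Subtracting the mean centers the values around zero, and dividing by the standard deviation gives them a unit spread, so some results are negative and some are greater than 1.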

What is the Importance of Image Normalization?

The importance of normalization lies in how the model processes its inputs. Widely varying data can make the model focus on outliers and draw invalid inferences. Moreover, the model may give too much attention to an exaggerated aspect of the data, such as the brightness of the images. If some images are much brighter than others, their large pixel values can dominate training, and the model learns the color information less effectively. Normalization puts every channel on a comparable scale so that no single aspect dominates simply because its values are larger.

How to Normalize an Image Dataset in PyTorch?

The normalization function in torchvision, “transforms.Normalize()”, takes two main arguments, the “mean” and the “standard deviation” (“std”) of the image data. Each is a sequence with one value per color channel.

Follow the steps below to learn how to normalize an image dataset in PyTorch:

Step 1: Launch Google Colab

Open the Google “Colaboratory” website and click on the “New Notebook” option to start the project:

Step 2: Install and Import the Required Libraries

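To start working on a PyTorch project, first install and import the “torch” and “torchvision” libraries into the IDE, along with the “Image” module from Pillow that is used later to open the image:

!pip install torch torchvision pillow

import torch
import torchvision
from torchvision import transforms
from PIL import Image  # needed for Image.open() in Step 5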

Step 3: Download the CIFAR-10 Dataset

In this tutorial, an image from the “CIFAR-10” image dataset will be used to demonstrate the normalization process. Go to the Kaggle website and download the CIFAR-10 dataset to use its images as shown:

Step 4: Upload an Image to PyTorch

Import the “files” module from “google.colab” to upload an image from the local system to the Colab session:

from google.colab import files

image = files.upload()

uploaded_airplane = list(image.keys())[0]

Click on the “Browse” option to select an image to upload from the system. Once the upload finishes, Colab shows the name of the saved file in the cell output.

Step 5: Load the Image into PyTorch

Use the “Image.open()” method from the Pillow library (imported with “from PIL import Image”) to load the selected image as shown:

path_of_image = uploaded_airplane
uploaded_airplane_image = Image.open(path_of_image)

 

Step 6: Image to Tensor

Use the “transforms.ToTensor()” function, wrapped in “transforms.Compose()”, to convert the uploaded image to a tensor for further calculations:

image_transformation = transforms.Compose([transforms.ToTensor()])

airplane_image_tensor = image_transformation(uploaded_airplane_image)

 

Step 7: Calculate the Mean and Standard Deviation

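Next, calculate the per-channel “Mean” and “Standard Deviation” for normalization. The image tensor has the shape (channels, height, width), so reducing over the height and width dimensions leaves one value per color channel. Moreover, use the “print()” method to show the output of the calculation:

# Reduce over height (dim 1) and width (dim 2) to get one value per channel
image_mean = airplane_image_tensor.mean(dim=(1, 2))
image_STD = airplane_image_tensor.std(dim=(1, 2))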

print("Image Mean: ", image_mean)
print("Image STD: ", image_STD)

 

The output shows the “Image Mean” and “Image STD” values:
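The steps above compute the statistics from a single image. When the goal is to normalize a whole dataset, one common approach is to compute one mean and one standard deviation per channel across all images. A minimal sketch, assuming a random batch as a stand-in for real images (the name “batch” and its shape are hypothetical):

```python
import torch

torch.manual_seed(0)

# Hypothetical batch of 8 RGB images: (images, channels, height, width)
batch = torch.rand(8, 3, 32, 32)

# Reduce over images, height, and width so that only the channel dimension remains
dataset_mean = batch.mean(dim=(0, 2, 3))
dataset_std = batch.std(dim=(0, 2, 3))

print("Dataset Mean: ", dataset_mean)
print("Dataset STD: ", dataset_std)
```

For a dataset that does not fit in memory, the same statistics would need to be accumulated batch by batch, which is outside the scope of this sketch.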

Step 8: Normalization

Lastly, normalize the uploaded image using the “transforms.Normalize()” function and place the previously calculated “mean” and “standard deviation” as the arguments of this function. Then, use the “print()” method to show the normalized image tensor in the output:

normalization = transforms.Normalize(mean=image_mean, std=image_STD)

transform_normalized = transforms.Compose([transforms.ToTensor(), normalization])

normalized_image = transform_normalized(uploaded_airplane_image)

print(normalized_image)

 

The output shows the normalized image tensor:

Note: To use complete code for image normalization, access our Google Colab Notebook at this link.

Pro-Tip

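You can showcase the normalized image in the output by using the “matplotlib.pyplot” library. Because normalized values can fall outside the range “[0, 1]”, “plt.imshow()” clips them, so the colors can look different from the original image: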

import matplotlib.pyplot as plt

normalized_airplane_image = normalized_image.numpy().transpose(1, 2, 0)

plt.imshow(normalized_airplane_image)
plt.title("Normalized Airplane Image")
plt.show()
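If the goal is to view the image with its original colors, one option is to undo the normalization before plotting. The following is a minimal sketch that assumes a random tensor in place of the airplane image; the names “restored” and “display_ready” are hypothetical:

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-in for the image tensor and its per-channel statistics
image = torch.rand(3, 32, 32)
mean = image.mean(dim=(1, 2))
std = image.std(dim=(1, 2))

# Normalize by hand; [:, None, None] reshapes (3,) to (3, 1, 1) for broadcasting
normalized = (image - mean[:, None, None]) / std[:, None, None]

# Undo the normalization: multiply by the std, then add the mean back
restored = normalized * std[:, None, None] + mean[:, None, None]

# Clamp to [0, 1] and move channels last, the layout that plt.imshow() expects
display_ready = restored.clamp(0, 1).permute(1, 2, 0).numpy()

print(display_ready.shape)
```

The resulting array can be passed to “plt.imshow()” in place of the normalized one.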

 

Success! We have just demonstrated how to normalize an image in PyTorch and the benefits of working with normalized image datasets.

Conclusion

Normalize an image in PyTorch by first converting it into a tensor, then calculating its mean and standard deviation, and finally passing these values to “transforms.Normalize()”. Additionally, you can show the normalized image in the output. Normalization can help image analysis models train faster and improve the quality of inferences drawn from the available data. In this article, we have showcased what normalization is and how to implement it.

About the author

Shehroz Azam
