PyTorch Transforms: Understanding PyTorch Transformations

In this tutorial, you’ll learn how to use PyTorch transforms to preprocess and augment your data, which can make your deep-learning models more robust. In deep learning, the quality of data plays an important role in determining the performance and generalization of the models you build. PyTorch transforms are a collection of operations that can be applied to data, enabling us to manipulate, augment, and preprocess it effectively before feeding it into the model.

By the end of this tutorial, you’ll have a strong understanding of:

  • What PyTorch transforms are and why we use them
  • Examples of common PyTorch transformations that you’ll often apply
  • How to chain multiple transformations into a single pipeline using Compose
  • How to integrate PyTorch transforms into torchvision Datasets

Understanding PyTorch Transforms

Your deep learning model’s success largely depends on the quality of the data that you feed into it, since the data determines how well the model performs and generalizes. Raw data, however, is rarely in the form that you need. This is where PyTorch transformations come into play.

PyTorch transforms serve two main purposes:

  1. Data preprocessing: allows you to transform data into a suitable format for training
  2. Data augmentation: allows you to generate new training examples by applying various transformations on existing data

Both data preprocessing and data augmentation are essential for improving the robustness and effectiveness of machine learning models.

In this tutorial, we’ll dive into the torchvision transforms, which allow you to apply powerful transformations to images and other data. Let’s start off by importing the torchvision library and the transforms module.

# Importing the torchvision library
import torchvision
from torchvision import transforms

from PIL import Image
from IPython.display import display

import numpy as np

In the code block above, we imported torchvision, the transforms module, Image from PIL (to load our images), display from IPython (to show images in a notebook), and NumPy (useful if you want to inspect image data as arrays).

I have provided a sample image here, which we’ll use throughout the tutorial to see the effect on images. Let’s begin by loading it into a variable. To do this, save the image and add the file path to the code block below:

# Loading a Sample Image
img = Image.open('resources/nik-from-datagy.jpeg')
display(img)

We can see that this is the same image from the sidebar! If you don’t want to look at my face throughout this tutorial, feel free to use a different image!

This returns the original image, shown below:

Original Image for PyTorch transforms Tutorial

Let’s now dive into some common PyTorch transforms to see what effect they’ll have on the image above.

Resizing with PyTorch Transforms

To start looking at some simple transformations, we can begin by resizing our image using PyTorch transforms. PyTorch provides an aptly-named transformation to resize images: transforms.Resize(). You can pass in a (height, width) tuple to resize to an exact size, or a single integer, in which case the shorter edge is resized to that value and the aspect ratio is preserved.

Let’s take a look at how we can resize an image with PyTorch transformations:

# Resizing Images with PyTorch Transforms
img = Image.open('resources/nik-from-datagy.jpeg')
resize_transform = transforms.Resize((150, 150))
resized_image = resize_transform(img)

display(resized_image)

In the code block above, we used the transforms.Resize() class to resize our image. We then displayed the smaller image to see the effect that the transformation had on the image. This returned the image shown below:

Resizing Images with PyTorch transforms

We can see that the image is quite a bit smaller! Let’s now take a look at how to convert images to tensors.

Converting Images to Tensors with PyTorch Transforms

In many cases, you’ll need to convert your images or other data to PyTorch tensors. The transforms.ToTensor() transformation converts a PIL image or NumPy array into a tensor. For typical 8-bit images, it also scales pixel values from the 0 to 255 range into floats between 0.0 and 1.0 and reorders the dimensions from height x width x channels to channels x height x width. Working with tensors has several benefits:

  1. Seamless Integration: Deep learning models, especially those built using PyTorch, expect input data in tensor format. Converting data to tensors enables smooth integration into the model.
  2. Efficient Computations: Tensors allow for efficient mathematical operations and computations required during model training and inference.
  3. Automatic Differentiation: Tensors in PyTorch enable automatic differentiation, a fundamental concept for training neural networks using gradient-based optimization algorithms.

Let’s now take a look at how we can convert our image to PyTorch tensors using our transforms module:

# Converting Images to Tensors with PyTorch Transforms
img = Image.open('resources/nik-from-datagy.jpeg')
tensor_transform = transforms.ToTensor()
tensor_image = tensor_transform(img)

print(tensor_image)

# Returns:
# tensor([[[0.7961, 0.8000, 0.8039,  ..., 0.7059, 0.7059, 0.7020],
#          [0.7882, 0.7922, 0.7961,  ..., 0.7098, 0.7059, 0.7020],
#          [0.7922, 0.7922, 0.7961,  ..., 0.7137, 0.7098, 0.7059],
#          ...,
#          [0.4667, 0.4627, 0.4745,  ..., 0.2667, 0.2745, 0.2824],
#          [0.4784, 0.4824, 0.4980,  ..., 0.2941, 0.3020, 0.3020],
#          [0.4941, 0.4824, 0.5020,  ..., 0.3059, 0.3137, 0.3216]]])

In the code block above, we first created our transformation object. Storing it in a variable is optional: you could also create and call the transformation in a single expression, such as transforms.ToTensor()(img). Then, we passed our image into this object to convert it to a tensor. When we printed it out, we saw that our image was now a PyTorch tensor with values between 0 and 1!

Let’s now take a look at how to normalize data with PyTorch transformations.

Normalize Tensors with PyTorch Transforms

Normalization is one of the cornerstones of effective data preprocessing. It allows you to ensure that your input features are scaled and centered consistently, which often leads to better convergence during training.

The PyTorch Normalize transformation offers a convenient way to standardize input data, such as images. Because it only shifts and rescales each channel, the shape of the underlying data distribution remains intact.

Normalizing data provides a number of benefits, including:

  1. Stable Training: Normalized data can lead to more stable and faster convergence during training, as it mitigates the issue of varying scales among input features.
  2. Reduced Sensitivity: When input features share a common scale, no single feature dominates the others simply because of its magnitude, which can improve generalization.
  3. Numerical Stability: Normalized data can prevent numerical instability issues that might arise during computations involving large or small numbers.

Let’s now see how we can normalize our data with PyTorch:

# Normalizing an Image with PyTorch transforms
normalize_transform = transforms.Normalize(
    mean=[0.5, 0.5, 0.5],
    std=[0.3, 0.3, 0.3]
)

normalized_image = normalize_transform(tensor_image)

# Transform the image to a PIL Image
normalized_image_show = transforms.ToPILImage()(normalized_image)
display(normalized_image_show)

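In the code block above, we used our image that had already been turned into a tensor, since transforms.Normalize() works on tensors rather than PIL images. We then passed the per-channel mean and standard deviation into the transforms.Normalize() class, which computes (value - mean) / std for each channel. The right values depend entirely on your data, so compute them from your dataset (or use the published statistics for a pretrained model) before passing them in. Note that normalized values can fall outside the 0 to 1 range, so converting the result back to a PIL image gives only a rough visualization.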

This returns the following image:

Normalize Images with PyTorch transforms

In the following section, you’ll learn how to flip images with PyTorch.

Flip Images with PyTorch Transforms

In this section, we’ll take a look at how to flip an image randomly using PyTorch. This can be done using the transforms.RandomHorizontalFlip() class, which by default flips an image with a probability of 50% (controlled by the p parameter). This is a simple way to augment your data, which can help your model generalize better.

Let’s see how we can flip an image using PyTorch:

# Randomly Flipping an Image with PyTorch transforms
img = Image.open('resources/nik-from-datagy.jpeg')

flip_transform = transforms.RandomHorizontalFlip()
flipped_image = flip_transform(img)

display(flipped_image)

In the code block above, we created a transform object, into which we later passed our image. Calling display() then shows the result. Because the flip is random, roughly half of the time you’ll see the original image instead. In this run, we got the image below:

Randomly Flip Images with PyTorch transforms

Let’s now take a look at a similar transformation: rotating images.

Rotate Images with PyTorch Transforms

Similar to flipping an image, we can also rotate an image. PyTorch provides a helpful transformation, transforms.RandomRotation(), which allows you to set a range of degrees by which to rotate images. You can pass in a tuple of degrees that represent the minimum and maximum rotation.

Similarly, you can pass in a single integer, where the rotation will be chosen from (-degrees, degrees). Let’s see how we can set the random rotation to be between (-30, 30):

# Randomly Rotate an Image with PyTorch transforms
img = Image.open('resources/nik-from-datagy.jpeg')

# Pick a random angle between -30 and +30 degrees on each call
rotation_transform = transforms.RandomRotation(degrees=30)
rotated_image = rotation_transform(img)

display(rotated_image)

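In the code block above, we instructed PyTorch to pick an angle between -30 and +30 degrees to rotate the image by. By default, the output keeps the original size and the areas left empty by the rotation are filled with black. Rotation adds meaningful variety to our data, which can help our models generalize better.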

Running this code block returns the image shown below:

Randomly Rotate Images with PyTorch transforms

Now, let’s take a look at one final transformation: color jitter.

ColorJitter Images with PyTorch Transforms

The final sample transformation we’ll take a look at in this tutorial is the PyTorch color jitter transformation. This allows you to define random alterations based on the following color transformations:

  1. Brightness,
  2. Contrast,
  3. Saturation, and
  4. Hue

Let’s see how we can pass in some values to adjust our image.

# Color Jitter an Image with PyTorch transforms
img = Image.open('resources/nik-from-datagy.jpeg')

color_jitter_transform = transforms.ColorJitter(
    brightness=0.7, contrast=0.2, saturation=0.5, hue=0.4
)
jittered = color_jitter_transform(img)

display(jittered)

Each value defines a range from which a random adjustment factor is drawn every time the transformation is applied. This can help mimic different cameras, lighting situations, etc. Because of this, you can more easily add variability to your data.

The code block above returns the image below:

Color Jitter Images with PyTorch transforms

Now that you have learned about some of the different transformations that are available, let’s see how we can combine multiple transformations into a single workflow.

Combining Multiple Transformations with PyTorch Compose

So far, you have learned how to transform an image using a single transformation. In most cases, you’ll want to combine multiple transformations in your data pipeline. PyTorch makes this easy with a Compose object, which takes a list of transformations and applies them to your data one after another, in the order listed.

Let’s see how we can use a PyTorch Compose object to combine multiple transformations into a single workflow:

# Composing Multiple Transformations
img = Image.open('resources/nik-from-datagy.jpeg')

transformations = transforms.Compose([
    transforms.ColorJitter(brightness=0.7, contrast=0.2, saturation=0.5, hue=0.4),
    transforms.RandomRotation(degrees=30),
])

transformed = transformations(img)
display(transformed)

In the code block above, we combined the color jitter and the random rotation transformations. PyTorch provides flexibility to add more than these two transformations, but we’re keeping it simple for now!

The code block above returns the image below:

Combining Multiple Transformations with PyTorch Compose

Now that you know how to compose multiple transformations into one workflow, let’s see how transformations can feed into PyTorch datasets.

Integrating Transforms into PyTorch Datasets

Integrating data transformations seamlessly into your deep learning pipeline is essential for effective data preprocessing. PyTorch makes this integration straightforward through its support for applying transformations directly within dataset objects. Leveraging transformations within datasets not only streamlines the preprocessing process but also maintains data integrity and consistency throughout the training process.

PyTorch uses custom classes (such as DataLoaders and neural networks) to structure deep learning projects. PyTorch Datasets are an essential component of these projects. PyTorch Datasets provide a helpful way to organize your data, both for training and inference tasks.

Similarly, PyTorch Datasets allow you to easily integrate with other PyTorch components, such as DataLoaders which allow you to effortlessly batch your data during training. They also allow you to easily load data in efficient and parallel ways, making it easier to load large amounts of data.

Applying Transformations in PyTorch Datasets:

To incorporate transformations within PyTorch datasets, you can use the transform argument of dataset classes such as torchvision.datasets.ImageFolder, or accept a similar argument in your own custom dataset.

Let’s take a look at an example:

import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder

# Define transformations
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

# Create a dataset with transformations
dataset = ImageFolder(root='path', transform=transform)

In this example, the transform sequence is applied to each image as it’s loaded from the dataset. This ensures that all images are resized, converted to tensors, and normalized before being used in training.

By using transformations directly within datasets, you simplify your code, maintain data consistency, and make your deep learning pipeline more efficient and effective. This approach empowers you to focus on the model itself, confident that your data is being properly prepared for training without the need for manual preprocessing steps.

Conclusion

In this tutorial, you explored PyTorch transforms: a powerful toolset for preprocessing and augmenting data for deep learning tasks. Along the way, you learned how transformations can improve your model’s performance and generalization ability while streamlining your preprocessing pipeline.

Throughout this tutorial, you’ve covered the following key aspects:

  1. Introduction to PyTorch Transforms: You started by understanding the significance of data preprocessing and augmentation in deep learning. PyTorch transforms emerged as a versatile solution to manipulate, augment, and preprocess data, ultimately enhancing model performance.
  2. Common PyTorch Transformations: You explored a variety of common transformations, ranging from resizing, converting to tensors, and normalization to random horizontal flips, rotations, and color jitter. Each transformation serves a unique purpose, contributing to data variety and model robustness.
  3. Combining Transformations: By combining multiple transformations using PyTorch’s transforms.Compose(), you learned how to create complex preprocessing pipelines that manipulate data according to your needs, resulting in richer and more diverse datasets.
  4. Integration with PyTorch Datasets: You discovered the seamless integration of transformations within PyTorch datasets. Leveraging this capability, you can apply transformations directly to your data loading process, maintaining data consistency and integrity while effortlessly feeding preprocessed data into your models.

As you embark on your journey of building deep learning models, armed with the knowledge of PyTorch transforms, you can create pipelines that not only preprocess data efficiently but also enhance the performance and generalization of your models. By embracing the art of transformation, you’re well-equipped to tackle a wide array of data challenges and unlock the true potential of your deep learning endeavors.

Nik Piepenbreier

Nik is the author of datagy.io and has over a decade of experience working with data analytics, data science, and Python. He specializes in teaching developers how to use Python for data science using hands-on tutorials.
