
PyTorch DataLoader: A Complete Guide


In this tutorial, you’ll learn everything you need to know about the important and powerful PyTorch DataLoader class. PyTorch provides an intuitive and incredibly versatile tool, the DataLoader class, to load data in meaningful ways. Because data preparation is a critical step to any type of data work, being able to work with, and understand, DataLoaders is an important step in your deep learning journey.

By the end of this tutorial, you’ll have learned:

  • What the PyTorch DataLoader class is and how to use it
  • How to work with built-in datasets
  • How to access data and targets in a DataLoader object
  • How to move batches of data onto a GPU

What Does a PyTorch DataLoader Do?

The PyTorch DataLoader class is an important tool to help you prepare, manage, and serve your data to your deep learning networks. Because you’ll need to perform many pre-processing steps before you begin training a model, finding ways to standardize these processes is critical for the readability and maintainability of your code.

The PyTorch DataLoader allows you to:

  1. Define a dataset to work with: identifying where the data is coming from and how it should be accessed.
  2. Batch the data: define how many training or testing samples to use in a single iteration. Because training and testing sets are often large, working with batches keeps memory use under control and makes your training and testing processes more manageable.
  3. Shuffle the data: PyTorch can shuffle data for you as it loads it into batches. Reshuffling at every epoch means the model doesn’t see the samples in the same order each time, which helps avoid biases caused by how the data was stored (for example, all samples of one class grouped together).
  4. Support multi-processing: the DataLoader can load batches in parallel using multiple worker subprocesses, so data loading is less likely to become a bottleneck while your model trains. The num_workers parameter sets how many worker processes to use.
  5. Merge datasets together: optionally, PyTorch also lets you combine multiple datasets (for example, with torch.utils.data.ConcatDataset) and load them through a single DataLoader. While this may not be a common task, it’s a useful feature to have available.
  6. Pin memory for faster GPU transfers: with pin_memory=True, the DataLoader copies tensors into page-locked (pinned) memory before returning them, which speeds up their transfer to a CUDA GPU.
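To make batching, shuffling, and merging more concrete, here is a minimal sketch that uses small in-memory TensorDatasets instead of MNIST. The tensors are made-up toy data; the sketch combines two datasets with torch.utils.data.ConcatDataset and then lets a DataLoader shuffle and batch the result.

```python
import torch
from torch.utils.data import ConcatDataset, DataLoader, TensorDataset

# Two toy datasets (made-up data): 6 and 4 samples with 3 features each
dataset_a = TensorDataset(torch.randn(6, 3), torch.zeros(6, dtype=torch.long))
dataset_b = TensorDataset(torch.randn(4, 3), torch.ones(4, dtype=torch.long))

# Merge the datasets: indices 0-5 come from dataset_a, indices 6-9 from dataset_b
combined = ConcatDataset([dataset_a, dataset_b])
print(len(combined))  # 10

# Shuffle and batch the merged dataset; 10 samples in batches of 4 gives 4, 4, 2
loader = DataLoader(combined, batch_size=4, shuffle=True)
batch_shapes = [features.shape[0] for features, labels in loader]
print(batch_shapes)  # [4, 4, 2]
```

Shuffling changes which samples land in each batch, but not the batch sizes.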

Now that you have a strong understanding of the benefits of using the PyTorch DataLoader class, let’s take a look at how it is defined.

Understanding the PyTorch DataLoader Class

Before we dive into how to use a PyTorch DataLoader to load your data, let’s take a look at the basic syntax that makes up a DataLoader class. The code block below shows the parameters available in the PyTorch DataLoader class:

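# Understanding the PyTorch DataLoader Class
# The call below lists the DataLoader parameters and their default values for reference.
# It isn't meant to be run as-is, and some defaults differ slightly between PyTorch versions.
from torch.utils.data import DataLoader
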
DataLoader(
    dataset, 
    batch_size=1, 
    shuffle=False, 
    sampler=None, 
    batch_sampler=None, 
    num_workers=0, 
    collate_fn=None, 
    pin_memory=False, 
    drop_last=False, 
    timeout=0, 
    worker_init_fn=None, 
    multiprocessing_context=None, 
    generator=None, 
    *, 
    prefetch_factor=2, 
    persistent_workers=False
)

From the code block above, you can see that the DataLoader class has a lot of different parameters available. Let’s take a look at some of the most important ones that we’ll explore throughout this tutorial:

  • dataset expects a PyTorch Dataset from which to load the data
  • batch_size represents how many samples per batch to load
  • shuffle indicates whether the data should be reshuffled at every epoch
  • sampler defines how to draw samples from the dataset. It is mutually exclusive with shuffle: if you pass a sampler, don’t set shuffle=True.
  • num_workers represents how many subprocesses to use for loading data. The default, 0, loads data in the main process.
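The effects of batch_size, drop_last (another parameter from the signature above), and the sampler/shuffle restriction are easiest to see on a tiny dataset. This is a minimal sketch using made-up data in a TensorDataset rather than MNIST; num_workers is left at its default of 0.

```python
import torch
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset

# Toy dataset (made-up data): the values 0..9 as single-element samples
dataset = TensorDataset(torch.arange(10))

# batch_size=4 over 10 samples yields batches of 4, 4, and 2 samples
sizes = [len(batch[0]) for batch in DataLoader(dataset, batch_size=4)]
print(sizes)  # [4, 4, 2]

# drop_last=True discards the final, incomplete batch
sizes_drop = [len(batch[0]) for batch in DataLoader(dataset, batch_size=4, drop_last=True)]
print(sizes_drop)  # [4, 4]

# A sampler controls the order of indices; SequentialSampler keeps the original order
seq_loader = DataLoader(dataset, batch_size=5, sampler=SequentialSampler(dataset))
seq_batches = [batch[0].tolist() for batch in seq_loader]
print(seq_batches)  # [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]

# Combining a sampler with shuffle=True is rejected when the DataLoader is created
raised = False
try:
    DataLoader(dataset, sampler=SequentialSampler(dataset), shuffle=True)
except ValueError as err:
    raised = True
    print("ValueError:", err)
```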

Of course, one of the most important parameters is the dataset itself. Generally, you’ll be working with at least a training and a testing dataset, so it’s conventional to create at least two DataLoaders: one for your training data and one for your testing data.
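A minimal sketch of that convention is shown below. To stay self-contained, it splits a made-up TensorDataset with torch.utils.data.random_split; with MNIST you would instead load the test split by passing train=False. Shuffling the training loader but not the test loader is a common choice, not a requirement.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# Toy dataset (made-up data): 100 samples with 8 features and a binary label
full_dataset = TensorDataset(torch.randn(100, 8), torch.randint(0, 2, (100,)))

# Split into 80 training and 20 testing samples; the seeded generator makes the split reproducible
train_set, test_set = random_split(
    full_dataset, [80, 20], generator=torch.Generator().manual_seed(42)
)

# Reshuffle the training data every epoch; keep the test data in a fixed order
train_loader = DataLoader(train_set, batch_size=16, shuffle=True)
test_loader = DataLoader(test_set, batch_size=16, shuffle=False)

# len() of a DataLoader is the number of batches: ceil(80 / 16) and ceil(20 / 16)
print(len(train_loader), len(test_loader))  # 5 2
```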

PyTorch lets you define many different parameters to influence how data are loaded. These settings can have a big impact on how quickly and how well your model trains, and on whether data are sampled appropriately.

In the following section, you’ll learn how to use a PyTorch DataLoader to load a dataset in meaningful ways.

Creating and Using a PyTorch DataLoader

In this section, you’ll learn how to create a PyTorch DataLoader using a built-in dataset and how to use it to load and use the data. To keep things familiar, we’ll be working with one of the most popular datasets for deep learning, the MNIST dataset.

Let’s begin by loading the dataset and exploring it a little bit:

# Loading the MNIST Dataset Using PyTorch
# Importing Libraries
from torchvision.datasets import MNIST
from torchvision import transforms

# Downloading and Saving MNIST; ToTensor() converts each image to a tensor
data_train = MNIST('~/mnist_data', train=True, download=True, transform=transforms.ToTensor())

# Printing the Dataset
print(data_train)

# Returns:
# Dataset MNIST
#     Number of datapoints: 60000
#     Root location: /Users/nikpi/mnist_data
#     Split: Train
#     StandardTransform
# Transform: ToTensor()

In the code above, we first import the required dataset class. Then, we load the training data by instantiating the class. Finally, we print the dataset to see what it looks like.

We can see that the dataset has 60,000 records in the training set. Like lists and other Python sequences, the dataset lets us access an item by its index. Let’s take a look at the first item by accessing index 0:

# Accessing a Dataset Item
print(data_train[0])

# Returns:
# (tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
# ...
#           0.0000, 0.0000, 0.0000, 0.0000]]]), 5)

We can see above that accessing a dataset item returns an image tensor along with its label. We can also visualize this sample data point using Matplotlib’s imshow() function. Here we use the dataset’s data attribute, which holds the raw, untransformed images:

# Visualizing a Sample
import matplotlib.pyplot as plt
plt.imshow(data_train.data[0])
plt.show()

This returns the following image:

Our first data point in the MNIST Dataset
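Indexing works because MNIST is a map-style dataset: it implements __len__ and __getitem__. As a minimal sketch, a dataset of your own only needs those two methods to work with a DataLoader. The SquaresDataset class below is hypothetical and exists only for illustration.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class SquaresDataset(Dataset):
    """Hypothetical map-style dataset: sample i is the pair (i, i**2)."""

    def __init__(self, size):
        self.size = size

    def __len__(self):
        # Number of samples; the DataLoader uses this to generate indices
        return self.size

    def __getitem__(self, index):
        # Return one (data, target) pair as tensors
        x = torch.tensor([float(index)])
        y = torch.tensor(float(index) ** 2)
        return x, y


squares = SquaresDataset(5)
print(len(squares))  # 5
print(squares[3])    # (tensor([3.]), tensor(9.))

# The DataLoader stacks individual pairs into batched tensors
loader = SquaresDataset_loader = DataLoader(squares, batch_size=2)
for data, target in loader:
    print(data.shape, target)
```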

Now that we have loaded our dataset, we can create our DataLoader object:

# Creating a Training DataLoader Object
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader

# Downloading and Saving MNIST 
data_train = MNIST('~/mnist_data', train=True, download=True, transform=transforms.ToTensor())

# Creating Data Loader
data_loader = DataLoader(data_train, batch_size=20, shuffle=True)

print(data_loader)

# Returns:
# <torch.utils.data.dataloader.DataLoader object at 0x7fc3c021b6d0>

In the code above, we created a DataLoader object, data_loader, which wraps the training dataset, uses a batch size of 20, and reshuffles the data at every epoch.
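Because shuffle=True draws a new order on every pass, two epochs usually see the samples in different orders. If you need reproducible runs, you can pass a seeded torch.Generator through the generator parameter. The sketch below uses a made-up ten-element TensorDataset, and epoch_order is a hypothetical helper written only for this illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy dataset (made-up data): the values 0..9 as single-element samples
dataset = TensorDataset(torch.arange(10))


def epoch_order(loader):
    """Hypothetical helper: the sample values in the order one pass over `loader` yields them."""
    return [value for (batch,) in loader for value in batch.tolist()]


# Every pass over a shuffling DataLoader draws a new permutation of the samples
loader = DataLoader(dataset, batch_size=5, shuffle=True)
first_epoch = epoch_order(loader)
second_epoch = epoch_order(loader)
print(first_epoch)
print(second_epoch)  # usually a different order from the first epoch

# Seeding the DataLoader's generator makes the shuffled order reproducible
order_a = epoch_order(DataLoader(dataset, batch_size=5, shuffle=True,
                                 generator=torch.Generator().manual_seed(0)))
order_b = epoch_order(DataLoader(dataset, batch_size=5, shuffle=True,
                                 generator=torch.Generator().manual_seed(0)))
print(order_a == order_b)  # True
```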

Iterating over a PyTorch DataLoader

Conventionally, you’ll load both the index of a batch and the items in the batch, which you can do with Python’s enumerate() function. Let’s use the DataLoader object to load the first batch. To save space, we’ll print the shape of the image tensor rather than its values:

# Loading the First Batch and Printing Information
for idx, batch in enumerate(data_loader):
    print('Batch index: ', idx)
    print('Batch size: ', batch[0].size())
    print('Batch label: ', batch[1])
    break

# Returns:
# Batch index:  0
# Batch size:  torch.Size([20, 1, 28, 28])
# Batch label:  tensor([3, 3, 7, 7, 2, 4, 7, 2, 1, 8, 3, 3, 9, 3, 2, 3, 5, 0, 6, 8])

We can see in the code above that the first batch contains 20 images, each with a single color channel (since MNIST images are grayscale) and a size of 28×28 pixels. We can also access the labels for all 20 images through the second item of the batch. Note: a real training script generally wouldn’t use the break keyword; it’s used here only to avoid printing every batch.
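If you only need a single batch, for example to check shapes before training, a common idiom is to wrap the DataLoader in iter() and call next() on it, which avoids the loop-and-break pattern. This minimal sketch uses random tensors shaped like MNIST images as a stand-in, not real digits.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Made-up stand-in for MNIST: 100 random 1x28x28 "images" with labels 0-9
images = torch.rand(100, 1, 28, 28)
labels = torch.randint(0, 10, (100,))
loader = DataLoader(TensorDataset(images, labels), batch_size=20, shuffle=True)

# Pull exactly one batch without writing a for loop
data, target = next(iter(loader))
print(data.shape)    # torch.Size([20, 1, 28, 28])
print(target.shape)  # torch.Size([20])
```

Each call to iter() starts a fresh pass over the data, so with shuffle=True repeated calls generally return different batches.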

Now that you have your data loaded in batches, you’re able to move ahead with training your network!
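As a rough sketch of that next step, the loop below passes over every batch for two epochs and updates a model after each one. The tiny linear classifier, the random stand-in data, and the hyperparameters are all assumptions chosen for illustration, not a recommended MNIST setup.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Random stand-in data shaped like MNIST (not real digits)
images = torch.rand(200, 1, 28, 28)
labels = torch.randint(0, 10, (200,))
loader = DataLoader(TensorDataset(images, labels), batch_size=20, shuffle=True)

# Hypothetical minimal model: flatten each image and apply a single linear layer
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

epoch_losses = []
for epoch in range(2):
    running_loss = 0.0
    for data, target in loader:  # no break: every batch is used in each epoch
        optimizer.zero_grad()
        loss = loss_fn(model(data), target)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    epoch_losses.append(running_loss / len(loader))
    print(f"epoch {epoch}: mean loss {epoch_losses[-1]:.4f}")
```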

Accessing Data and Targets in a PyTorch DataLoader

As you saw in the code above, each batch that the DataLoader returns contains both the data and the targets (if the dataset provides both). We can access each item and its label by iterating over the batches.

Let’s see how this is conventionally done:

# Accessing Data and Targets in a PyTorch DataLoader
for idx, (data, target) in enumerate(data_loader):
    print(data[0])
    print(target[0])
    break

# Returns:
# tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
# ...
#           0.0000, 0.0000, 0.0000, 0.0000]]])
# tensor(1)

In the code above, we printed the first item of the first batch by accessing its data and its target. The data has been truncated to save space. Note: a real training script generally wouldn’t use the break keyword; it’s used here only to avoid printing every batch.
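Behind the scenes, the DataLoader’s default collate function stacks the individual (data, target) pairs into the batched tensors you unpacked above. When samples can’t be stacked directly, such as variable-length sequences, you can pass your own collate_fn. Here is a minimal sketch that pads sequences with torch.nn.utils.rnn.pad_sequence; the pad_collate function and the toy sequences are hypothetical.

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

# Toy map-style dataset (made-up data): a plain list of (sequence, label) pairs.
# The sequences have different lengths, so the default collate_fn can't stack them.
samples = [
    (torch.tensor([1, 2, 3]), 0),
    (torch.tensor([4, 5]), 1),
    (torch.tensor([6]), 0),
    (torch.tensor([7, 8, 9, 10]), 1),
]


def pad_collate(batch):
    """Hypothetical collate_fn: pad each sequence to the longest one in the batch."""
    sequences, targets = zip(*batch)
    padded = pad_sequence(list(sequences), batch_first=True, padding_value=0)
    return padded, torch.tensor(targets)


loader = DataLoader(samples, batch_size=2, collate_fn=pad_collate)
results = []
for data, target in loader:
    results.append((data.tolist(), target.tolist()))
    print(data.tolist(), target.tolist())
# [[1, 2, 3], [4, 5, 0]] [0, 1]
# [[6, 0, 0, 0], [7, 8, 9, 10]] [0, 1]
```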

Loading Data to a GPU (CUDA) With a PyTorch DataLoader

In this section, you’ll learn how to load data to a GPU (generally, CUDA) using a PyTorch DataLoader object. By default, the DataLoader returns batches as CPU tensors, so each batch is moved to the GPU inside the loop. We can also make the code dynamic, letting the program detect whether a GPU is available and fall back to the CPU if it isn’t. This prevents you from hard-coding the device, which would cause the program to fail on a machine without a GPU.

Let’s take a look at how this is done in Python:

# Loading Data to a GPU with a PyTorch DataLoader Object
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torchvision import transforms
import torch

data_train = MNIST('~/mnist_data', train=True, download=True, transform=transforms.ToTensor())
data_loader = DataLoader(data_train, batch_size=20, shuffle=True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

for idx, (data, target) in enumerate(data_loader):
    data = data.to(device)
    target = target.to(device)

Let’s break down what we did in the code above:

  1. We imported the same elements as before, as well as torch
  2. We created a variable, device, which checks whether or not CUDA is available. If it is, it assigns 'cuda', otherwise it uses 'cpu'
  3. In our enumeration of the DataLoader object, we move both the data and the target onto the provided device
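The loop above copies each batch synchronously. As an optional refinement, you can set pin_memory=True so batches are returned in page-locked memory, and then pass non_blocking=True to .to(), which allows the host-to-GPU copy to run asynchronously. This is a sketch rather than a required step: it enables pinning only when CUDA is available and uses random stand-in data instead of MNIST.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Random stand-in data shaped like MNIST (not real digits)
dataset = TensorDataset(torch.rand(100, 1, 28, 28), torch.randint(0, 10, (100,)))

# Pinned (page-locked) memory only helps when copying to a CUDA device
data_loader = DataLoader(
    dataset,
    batch_size=20,
    shuffle=True,
    pin_memory=torch.cuda.is_available(),
)

for data, target in data_loader:
    # non_blocking=True lets the copy run asynchronously when the source is pinned
    data = data.to(device, non_blocking=True)
    target = target.to(device, non_blocking=True)

print(data.device.type == device.type)  # True
```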

Conclusion

In this tutorial, you learned what the PyTorch DataLoader class is and how it can be used in practice. You learned the benefits of using a DataLoader and how it can be customized to meet your training and testing needs. Then, you learned how to use the PyTorch DataLoader class with a practical example. Finally, you learned how to iterate over batches of data and how to move data to a GPU.


Nik Piepenbreier

Nik is the author of datagy.io and has over a decade of experience working with data analytics, data science, and Python. He specializes in teaching developers how to use Python for data science using hands-on tutorials.
