In this tutorial, you’ll learn everything you need to know about the important and powerful PyTorch DataLoader class. PyTorch provides an intuitive and incredibly versatile tool, the DataLoader class, to load data in meaningful ways. Because data preparation is a critical step in any type of data work, being able to work with, and understand, DataLoaders is an important step in your deep learning journey.
By the end of this tutorial, you’ll have learned:
- What the PyTorch DataLoader class is and how to use it
- How to work with built-in datasets
- How to access data and targets in a DataLoader object
What Does a PyTorch DataLoader Do?
The PyTorch DataLoader class is an important tool to help you prepare, manage, and serve your data to your deep learning networks. Because many pre-processing steps are required before you can begin training a model, finding ways to standardize these processes is critical for the readability and maintainability of your code.
The PyTorch DataLoader allows you to:
- Define a dataset to work with: identifying where the data is coming from and how it should be accessed.
- Batch the data: define how many training or testing samples to use in a single iteration. Because training and testing sets are often large, working in batches keeps your training and testing processes manageable.
- Shuffle the data: PyTorch can handle shuffling data for you as it loads data into batches. This can make each batch more representative of the dataset as a whole and prevent accidental skew.
- Support multi-processing: PyTorch is optimized to run multiple processes at once in order to make better use of modern CPUs and GPUs and to save time in training and testing your data. The DataLoader class lets you define how many workers should run at once.
- Merge datasets together: optionally, PyTorch also allows you to merge multiple datasets together. While this may not be a common task, having it available is a great feature (a short sketch follows this list).
- Load data directly on CUDA tensors: because PyTorch can run on the GPU, you can load the data directly onto CUDA tensors before they’re returned.
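To make the dataset-merging point concrete, here is a minimal sketch using torch.utils.data.ConcatDataset. The two small TensorDatasets of random features and binary labels are hypothetical stand-ins invented purely for illustration:
# A Minimal Sketch of Merging Datasets with ConcatDataset
import torch
from torch.utils.data import ConcatDataset, DataLoader, TensorDataset
# Two made-up datasets: random features with binary labels
ds_a = TensorDataset(torch.randn(100, 3), torch.randint(0, 2, (100,)))
ds_b = TensorDataset(torch.randn(50, 3), torch.randint(0, 2, (50,)))
# ConcatDataset chains them into a single 150-sample dataset
merged = ConcatDataset([ds_a, ds_b])
loader = DataLoader(merged, batch_size=10, shuffle=True)
print(len(merged))
# Returns:
# 150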
Now that you have a strong understanding of the benefits of using the PyTorch DataLoader class, let’s take a look at how one is defined.
Understanding the PyTorch DataLoader Class
Before we dive into how to use a PyTorch DataLoader to load your data, let’s take a look at the basic syntax that makes up a DataLoader class. The code block below shows the parameters available in the PyTorch DataLoader class:
# Understanding the PyTorch DataLoader Class
from torch.utils.data import DataLoader
DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    sampler=None,
    batch_sampler=None,
    num_workers=0,
    collate_fn=None,
    pin_memory=False,
    drop_last=False,
    timeout=0,
    worker_init_fn=None,
    multiprocessing_context=None,
    generator=None,
    *,
    prefetch_factor=2,
    persistent_workers=False
)
From the code block above, you can see that the DataLoader class has a lot of different parameters available. Let’s take a look at some of the most important ones that we’ll explore throughout this tutorial:
- dataset expects a PyTorch Dataset from which to load the data
- batch_size represents how many samples per batch to load
- shuffle indicates whether data should be shuffled at every epoch you run
- sampler defines how to draw samples from the dataset; it cannot be used when the data are being shuffled (a short sketch follows this list)
- num_workers represents how many subprocesses to use for loading data
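As one example of the sampler parameter (and why it excludes shuffle), here is a hedged sketch using SubsetRandomSampler to draw batches from only the first 100 indices. The dataset here is a made-up stand-in, and the index range is purely illustrative:
# A Sketch of Using a Sampler Instead of Shuffling
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler, TensorDataset
# A made-up dataset of 1,000 samples
dataset = TensorDataset(torch.randn(1000, 3), torch.randint(0, 2, (1000,)))
# Draw randomly from only the first 100 indices;
# sampler requires shuffle to remain False (its default)
sampler = SubsetRandomSampler(range(100))
loader = DataLoader(dataset, batch_size=10, sampler=sampler)
for data, target in loader:
    print(data.shape)
    break
# Returns:
# torch.Size([10, 3])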
Of course, one of the most important parameters is the actual dataset. Generally, you’ll be working with at least a training and a testing dataset. Because of this, it’s conventional to create at least two DataLoaders: one for your training data and one for your testing data.
PyTorch lets you define many different parameters to influence how data are loaded. This can have a big impact on how quickly your model trains, how well it trains, and whether data are sampled appropriately.
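As a hedged illustration of the performance-related parameters (the specific values below are illustrative, not recommendations), a throughput-oriented DataLoader might look like this, again using a made-up dataset:
# A Sketch of Performance-Oriented DataLoader Settings
import torch
from torch.utils.data import DataLoader, TensorDataset
dataset = TensorDataset(torch.randn(1000, 3), torch.randint(0, 2, (1000,)))
# num_workers spawns subprocesses to load batches in parallel;
# pin_memory speeds up later host-to-GPU copies;
# persistent_workers keeps workers alive between epochs (requires num_workers > 0)
loader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    persistent_workers=True,
)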
In the following section, you’ll learn how to use a PyTorch DataLoader to load a dataset in meaningful ways.
Creating and Using a PyTorch DataLoader
In this section, you’ll learn how to create a PyTorch DataLoader using a built-in dataset and how to use it to load and use the data. To keep things familiar, we’ll be working with one of the most popular datasets for deep learning, the MNIST dataset.
Let’s begin by loading the dataset and exploring it a little bit:
# Loading the MNIST Dataset Using PyTorch
# Importing Libraries
from torchvision.datasets import MNIST
from torchvision import transforms
# Downloading and Saving MNIST
data_train = MNIST('~/mnist_data', train=True, download=True, transform=transforms.ToTensor())
# Print Data
print(data_train)
# Dataset MNIST
# Number of datapoints: 60000
# Root location: /Users/nikpi/mnist_data
# Split: Train
# StandardTransform
# Transform: ToTensor()
In the code above, we first import the required dataset class. Then, we load the training data by instantiating the class. Finally, we print the dataset to see what it looks like.
We can see that the dataset has 60,000 records in the training set. Similar to other iterable objects in Python, we can access an item by accessing its index. Let’s take a look at the first item, by accessing the 0th index:
# Accessing a Dataset Item
print(data_train[0])
# Returns:
# (tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
# ...
# 0.0000, 0.0000, 0.0000, 0.0000]]]), 5)
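Because each item is a (tensor, label) tuple, you can also unpack it directly. For example:
# Unpacking a Dataset Item
image, label = data_train[0]
print(image.shape)
print(label)
# Returns:
# torch.Size([1, 28, 28])
# 5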
We can see above that by accessing a dataset item, we get an image back, as well as its label. Similarly, we can visualize this sample datapoint by using the imshow() function in Matplotlib:
# Visualizing a Sample
import matplotlib.pyplot as plt
plt.imshow(data_train.data[0])
plt.show()
This displays the sample image: a handwritten digit 5.
Now that we have loaded our dataset, we can create our DataLoader object:
# Creating a Training DataLoader Object
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torchvision import transforms
# Downloading and Saving MNIST
data_train = MNIST('~/mnist_data', train=True, download=True, transform=transforms.ToTensor())
# Creating Data Loader
data_loader = DataLoader(data_train, batch_size=20, shuffle=True)
print(data_loader)
# Returns:
# <torch.utils.data.dataloader.DataLoader object at 0x7fc3c021b6d0>
In the code above, we created a DataLoader object, data_loader, which loaded in the training dataset, set the batch size to 20, and instructed the dataset to shuffle at each epoch.
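Since you’ll conventionally want a second DataLoader for your testing data, a matching loader might look like the sketch below (shuffling is typically left off for evaluation):
# Creating a Testing DataLoader Object
data_test = MNIST('~/mnist_data', train=False, download=True, transform=transforms.ToTensor())
test_loader = DataLoader(data_test, batch_size=20, shuffle=False)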
Iterating over a PyTorch DataLoader
Conventionally, you will load both the index of a batch and the items in the batch. We can do this using the enumerate() function. Let’s use the DataLoader object to load the first batch, printing only its shape to save space:
# Loading the First Batch and Printing Information
for idx, batch in enumerate(data_loader):
    print('Batch index: ', idx)
    print('Batch size: ', batch[0].size())
    print('Batch label: ', batch[1])
    break
# Returns:
# Batch index: 0
# Batch size: torch.Size([20, 1, 28, 28])
# Batch label: tensor([3, 3, 7, 7, 2, 4, 7, 2, 1, 8, 3, 3, 9, 3, 2, 3, 5, 0, 6, 8])
We can see in the code above that the first batch has 20 images, each with a single color channel (as they are grayscale) and a size of 28×28. Similarly, we were able to access the labels for all 20 images by accessing the second item in the return value. Note: generally the script would not use the break keyword – this is done only to prevent printing everything.
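If you want to confirm how many batches the DataLoader will produce, it supports len() whenever the underlying dataset has a known size:
# Counting Batches in a DataLoader
print(len(data_loader))
# Returns:
# 3000
With 60,000 samples and a batch size of 20, the loader yields 3,000 batches per epoch.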
Now that you have your data loaded in batches, you’re able to move ahead with training your network!
Accessing Data and Targets in a PyTorch DataLoader
As you saw above, the DataLoader returns an object that contains both the data and the target (if the dataset contains both). We can access each item and its label by iterating over the batches.
Let’s see how this is conventionally done:
# Accessing Data and Targets in a PyTorch DataLoader
for idx, (data, target) in enumerate(data_loader):
    print(data[0])
    print(target[0])
    break
# Returns:
# tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
# ...
# 0.0000, 0.0000, 0.0000, 0.0000]]])
# tensor(1)
In the code above, we printed the first item of the first batch by accessing its data and its target. The data has been truncated to save space. Note: generally the script would not use the break keyword – this is done only to prevent printing everything.
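To show where these (data, target) pairs usually end up, here is a minimal, hedged training-loop sketch. The tiny linear model, loss function, and optimizer are illustrative stand-ins, not part of the original tutorial:
# A Minimal Training-Loop Sketch Using the DataLoader
import torch
from torch import nn
# An illustrative one-layer classifier for 28x28 grayscale images
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# One epoch: each batch from the DataLoader drives one optimizer step
for data, target in data_loader:
    optimizer.zero_grad()
    loss = loss_fn(model(data), target)
    loss.backward()
    optimizer.step()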
Loading Data to a GPU (CUDA) With a PyTorch DataLoader
In this section, you’ll learn how to load data to a GPU (generally, via CUDA) using a PyTorch DataLoader object. We can make our code dynamic, allowing the program to identify whether it’s running on a GPU or a CPU. This prevents you from accidentally hard-coding elements of your program, causing it to fail if a GPU isn’t available.
Let’s take a look at how this is done in Python:
# Loading Data to a GPU with a PyTorch DataLoader Object
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torchvision import transforms
import torch
data_train = MNIST('~/mnist_data', train=True, download=True, transform=transforms.ToTensor())
data_loader = DataLoader(data_train, batch_size=20, shuffle=True)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
for idx, (data, target) in enumerate(data_loader):
    data = data.to(device)
    target = target.to(device)
Let’s break down what we did in the code above:
- We imported the previous elements, but also imported torch
- We created a variable, device, which checks whether or not CUDA is available. If it is, it assigns 'cuda'; otherwise it uses 'cpu'
- In our enumeration of the DataLoader object, we move both the data and the target onto the provided device
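As a common refinement (an assumption on our part, not something the original code does), you can combine pin_memory=True on the DataLoader with non_blocking=True on the transfer, which allows the host-to-GPU copy to overlap with other work:
# A Sketch of Pinned-Memory, Non-Blocking Transfers
data_loader = DataLoader(data_train, batch_size=20, shuffle=True, pin_memory=True)
for idx, (data, target) in enumerate(data_loader):
    # non_blocking transfers only help when the source tensors are pinned
    data = data.to(device, non_blocking=True)
    target = target.to(device, non_blocking=True)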
Conclusion
In this tutorial, you learned what the PyTorch DataLoader class is and how it can be implemented in practice. You learned what the benefits of using a DataLoader are and how it can be customized to meet your training and testing needs. Then, you learned how to use the PyTorch DataLoader class with a practical example. Finally, you learned how to iterate over batches of data and how to move data to a GPU.