PyTorch Dataset: How to Use Datasets in Deep Learning

In this tutorial, you’ll learn about the PyTorch Dataset class and how it’s used in deep learning projects. PyTorch encapsulates much of its workflow in custom classes, such as DataLoaders and neural networks. Datasets are one of these classes and help you organize and load data for training and inference tasks.

By the end of this tutorial, you’ll have learned the following:

  • What PyTorch Datasets are and why they’re important
  • How to use built-in PyTorch Datasets
  • How to create your own Datasets in PyTorch
  • How to augment data using PyTorch Datasets

Understanding PyTorch Datasets in a Deep Learning Workflow

PyTorch uses custom classes (such as DataLoaders and neural networks) to structure deep learning projects. PyTorch Datasets are an essential component of these projects. PyTorch Datasets provide a helpful way to organize your data, both for training and inference tasks.

PyTorch Datasets also integrate smoothly with other PyTorch components, such as DataLoaders, which batch your data during training. Combined with a DataLoader, a Dataset can also be read in parallel worker processes, which makes it easier to work with large amounts of data.

In short, PyTorch Datasets provide a bridge between raw data and the deep learning models you’re hoping to build. By using them, you can streamline your workflow, ensuring code reusability while allowing you to focus more on model development and experimentation.

Using Built-in PyTorch Datasets

PyTorch’s domain libraries provide a number of built-in datasets for different kinds of projects, including:

  • Image Datasets available via the torchvision.datasets module
  • Text Datasets available via the torchtext.datasets module
  • Audio Datasets available via the torchaudio.datasets module

Built-in datasets serve as standardized benchmarks for testing and comparing models in different domains. This allows you to use these datasets to quickly prototype and experiment with different models and algorithms.

Let’s see how we can load the CIFAR-10 dataset, which contains 60,000 small (32×32) color images across ten classes.

# Loading a Built-in Dataset
import torch
import torchvision
import matplotlib.pyplot as plt

# Load CIFAR-10 training dataset
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True)

# Returns:
# Downloading to ./data/cifar-10-python.tar.gz
# 100%|██████████| 170498071/170498071 [00:24<00:00, 6892672.57it/s]
# Extracting ./data/cifar-10-python.tar.gz to ./data

When you run the code above, PyTorch downloads and loads the CIFAR-10 training dataset. Notice that we’re specifying a few parameters:

  1. root= is the directory where the dataset is stored (and downloaded to)
  2. train= controls whether to load the training set (True) or the testing set (False)
  3. download= controls whether to download the dataset if it isn’t already in root; if the files are already there, they aren’t downloaded again, which is helpful when you re-run your code later on

When we run this code, the dataset is downloaded. If we wanted to download the testing dataset, too, we could create a new variable and set train=False.

Because PyTorch Datasets are indexable, we can access one of the items using the code below:

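# Define class labels for CIFAR-10 and show the first image with its label
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
plt.imshow(trainset[0][0])
plt.title(classes[trainset[0][1]])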

In the code block above, we first defined a tuple containing our class names. Then, we used the Matplotlib imshow function to show the first item. Each item in the dataset is an (image, label) tuple. You’ll notice that we access the 0th index of the first item: this represents the features of our data (here, a PIL image, since we didn’t pass a transform). Then, we use the 1st index, an integer class label, to look up the class name for the title.

By running the code above, we return the following image:

Frog Image from PyTorch Dataset

In many cases, you’ll actually be working with your own data. In those cases, you’ll want to define your own datasets. Let’s dive into that in the following section.

Creating Custom Datasets in PyTorch

PyTorch provides significant flexibility in creating custom datasets, which allow you to handle diverse data types and tailor data processing to your specific needs. In this section, we’ll explore how to create custom datasets in PyTorch and how to use them for efficient data handling.

To create a PyTorch dataset, you define a class that inherits from the torch.utils.data.Dataset class. This class serves as the base for all PyTorch datasets and defines the interface that data-loading tools such as DataLoader rely on.

The Dataset class acts as an abstract class: its __getitem__ method simply raises NotImplementedError, and it doesn’t define __len__ at all. To create a (map-style) PyTorch Dataset, you implement the following methods:

  • __len__, which should return the size of the dataset, meaning the total number of samples
  • __getitem__, which takes an index and returns the corresponding sample (typically its features and label)
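To see why these two methods matter, here is a quick check of what the base class does on its own. This is a small sketch: indexing a bare Dataset raises NotImplementedError, and because the base class defines no __len__, calling len() on it raises TypeError. The exact error messages may vary between PyTorch versions.

```python
from torch.utils.data import Dataset

base = Dataset()

# The base class's __getitem__ only raises NotImplementedError
try:
    base[0]
    getitem_error = None
except NotImplementedError as exc:
    getitem_error = type(exc).__name__

# The base class defines no __len__, so len() fails with TypeError
try:
    len(base)
    len_error = None
except TypeError as exc:
    len_error = type(exc).__name__

print(getitem_error, len_error)  # NotImplementedError TypeError
```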

Let’s take a look at what the shell of a PyTorch Dataset looks like:

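# Creating a PyTorch Dataset
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self):
        super().__init__()
        # Store or load your data here

    def __len__(self):
        # Return the total number of samples
        pass

    def __getitem__(self, idx):
        # Return the sample (and its label) at position idx
        pass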

In the code block above, we implemented the shell of a Dataset called CustomDataset. In it, we used the super() function to initialize our base class. Then, we defined our custom methods to get the length and access items within our dataset.

Let’s see what this looks like in practice. We’ll first define both feature and target vectors for a dataset and then create our class.

# Loading a Sample Dataset
from sklearn.datasets import make_blobs
X, y = make_blobs(n_samples=10000, n_features=2, centers=2, random_state=123)

# Plotting the Dataset
plt.scatter(X[:, 0], X[:, 1], c=y)
plt.title('Sample Dataset')

In the code block above, we used Scikit-Learn’s make_blobs function to create a sample dataset of 10,000 points with two features each, grouped into two clusters. We then used Matplotlib to visualize the data, as shown below:

Overview of Sample Data for PyTorch dataset

Similarly, we can print out some of our records to help us better understand the underlying data:

# Exploring our sample data
for i, j in list(zip(X, y))[:5]:
    print(f'{i} --> {j}')

# Returns:
# [ 3.67376434 -7.07580241] --> 0
# [-5.71886901  2.84089272] --> 1
# [-5.34035256  0.77557054] --> 1
# [-6.49803939 -0.37869243] --> 1
# [ 4.82890527 -3.24697975] --> 0

In the code block above, we used the Python zip function to iterate over both arrays in parallel. We can see that each record in our feature matrix X contains two values, while each record in our target vector y is a single integer label.

Let’s now see how we can create a PyTorch Dataset with our new data:

# Creating a PyTorch Dataset
from torch.utils.data import Dataset

class BlobDataset(Dataset):
    def __init__(self, features, target):
        super().__init__()
        self.features = features
        self.target = target

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        X = self.features[idx]
        y = self.target[idx]

        return X, y

There’s a lot going on in the code block above, so let’s break it down section by section:

  1. In the __init__ method, we first run the super() function to allow us to access methods of the base class. We then create class attributes using the features and target being passed in.
  2. In the __len__ method, we return the length of the features array, which is the number of samples.
  3. Finally, in the __getitem__ method, we accept an index, idx. We then create two variables, X and y, which represent the indexed values from those attributes. Finally, we return a tuple containing the values from X and y.
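torch.nn layers use float32 parameters by default, while make_blobs returns float64 features and integer labels. One option, sketched below under that assumption, is to convert the arrays to tensors once in __init__ so that __getitem__ returns tensors directly. The class name TensorBlobDataset is hypothetical, and the random arrays only stand in for the make_blobs output.

```python
import numpy as np
import torch
from torch.utils.data import Dataset

class TensorBlobDataset(Dataset):
    """Hypothetical variant of BlobDataset that stores tensors instead of NumPy arrays."""

    def __init__(self, features, target):
        super().__init__()
        # Convert once up front: float32 features, int64 (long) class labels
        self.features = torch.as_tensor(features, dtype=torch.float32)
        self.target = torch.as_tensor(target, dtype=torch.long)

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        return self.features[idx], self.target[idx]

# Stand-in data with the same shapes as the make_blobs example
rng = np.random.default_rng(0)
X_demo = rng.normal(size=(100, 2))       # float64 array
y_demo = rng.integers(0, 2, size=100)    # integer labels

ds = TensorBlobDataset(X_demo, y_demo)
features, label = ds[0]
print(features.dtype, features.shape, label.dtype)  # torch.float32 torch.Size([2]) torch.int64
```

Converting in __init__ does the work once instead of on every access, at the cost of holding a second copy of the data in memory.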

In the code block above, we have only defined the Dataset class; we haven’t actually created our dataset yet. Let’s see how we can do this now:

# Instantiating Our Dataset
dataset = BlobDataset(X, y)
print(dataset)

# Returns: <__main__.BlobDataset object at 0x7fa799466cd0>

In the code block above, we instantiated our dataset by passing in our feature and target arrays. Printing it shows only the default object representation, because our class doesn’t define a __repr__ method.

Because we implemented both the __getitem__ and __len__ methods, we can now index our data (and, since NumPy arrays accept slices, slice it too) as well as get the dataset’s length. Let’s start with getting an item using its index:

# Accessing the first item in the dataset
print(dataset[0])

# Returns: (array([ 3.67376434, -7.07580241]), 0)

In the code block above, we accessed the first item by indexing the 0th item. This returns a tuple containing both our features and our target.
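Slicing works in this class only because __getitem__ passes idx straight to NumPy arrays, which accept slices as well as integers. A slice therefore returns a tuple of arrays rather than a list of (X, y) pairs. The self-contained sketch below uses small made-up data to show the shapes; by default, a DataLoader requests samples one integer index at a time, so slicing support isn’t required for batching.

```python
import numpy as np
from torch.utils.data import Dataset

class BlobDataset(Dataset):
    def __init__(self, features, target):
        super().__init__()
        self.features = features
        self.target = target

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        # idx can be an int or a slice; NumPy indexing handles both
        return self.features[idx], self.target[idx]

# Made-up data: 10 samples with 2 features each
X_small = np.arange(20, dtype=float).reshape(10, 2)
y_small = np.array([0, 1] * 5)

small = BlobDataset(X_small, y_small)
X_slice, y_slice = small[0:3]
print(X_slice.shape, y_slice)  # (3, 2) [0 1 0]
```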

Let’s now see how we can find the length of our dataset by passing it into the len() function:

# Accessing the length of our dataset
print(len(dataset))

# Returns: 10000

As expected, we can see that our dataset is 10,000 records long. Let’s now dive into how we can use PyTorch datasets to facilitate some augmentations.

Augmenting Data Using PyTorch Datasets

PyTorch datasets provide a valuable framework for augmenting data, offering a range of benefits and flexibility for enhancing the diversity and size of training datasets. Augmentation techniques play a crucial role in machine learning and deep learning tasks, allowing models to generalize better by exposing them to a wider range of variations in the data.

By building augmentation into PyTorch datasets, developers and researchers can integrate augmentation strategies directly into their data pipelines, which can improve model performance and robustness.

In short, some of the benefits of using PyTorch Datasets for augmentation include:

  1. Seamless integration with the broader PyTorch ecosystem. Because augmentations are applied as data is loaded, each batch receives freshly augmented samples, increasing the variety of training examples.
  2. The flexibility of implementation into broader deep learning workflows. Users can implement a wide range of augmentation techniques.
  3. Augmenting data through PyTorch datasets helps models to generalize better.
  4. Data augmentation acts as a form of regularization, helping to prevent overfitting by introducing controlled variations into the training data.
  5. Augmentations enable the expansion of the training dataset without requiring additional labeled samples. By applying random transformations to existing data, the effective dataset size can be significantly increased, reducing the risk of overfitting and providing the model with more diverse examples to learn from.

Let’s see how we can add random noise to our features each time a record is accessed.

# Adding Augmentation to the Dataset
from torch.utils.data import Dataset
import numpy as np

class BlobDataset(Dataset):
    def __init__(self, features, target):
        super().__init__()
        self.features = features
        self.target = target

    def __len__(self):
        return len(self.features)

    # Add method to augment data with Gaussian noise (mean 0, standard deviation 0.1)
    def random_noise(self, x):
        noise = np.random.normal(loc=0, scale=0.1, size=x.shape)
        x = x + noise
        return x

    def __getitem__(self, idx):
        X = self.features[idx]
        y = self.target[idx]

        # Apply random noise
        X = self.random_noise(X)

        return X, y

# Create the Dataset and Print the Same Record Twice
dataset = BlobDataset(X, y)
print(dataset[0])
print(dataset[0])

# Returns (values differ between runs because the noise is random):
# (array([ 3.61370675, -7.00042829]), 0)
# (array([ 3.60227616, -7.12122985]), 0)

In the example above, we added a random_noise method that adds Gaussian noise to the features. As a result, when we access the same record twice, we get features that are close to the original values but slightly different each time.
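Noise like this is usually wanted only during training. A common pattern, sketched here as an assumption rather than the tutorial’s own method, is to add a flag that switches augmentation on or off and to draw noise from a dedicated NumPy Generator so results are reproducible. The names NoisyBlobDataset, augment, noise_scale, and seed are hypothetical.

```python
import numpy as np
from torch.utils.data import Dataset

class NoisyBlobDataset(Dataset):
    """Hypothetical BlobDataset variant with switchable Gaussian-noise augmentation."""

    def __init__(self, features, target, augment=True, noise_scale=0.1, seed=None):
        super().__init__()
        self.features = features
        self.target = target
        self.augment = augment
        self.noise_scale = noise_scale
        # Dedicated generator: the same seed gives the same noise sequence
        self.rng = np.random.default_rng(seed)

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        X = self.features[idx]
        if self.augment:
            X = X + self.rng.normal(loc=0.0, scale=self.noise_scale, size=X.shape)
        return X, self.target[idx]

X_toy = np.array([[3.67, -7.08], [-5.72, 2.84]])
y_toy = np.array([0, 1])

train_ds = NoisyBlobDataset(X_toy, y_toy, augment=True, seed=0)
eval_ds = NoisyBlobDataset(X_toy, y_toy, augment=False)

print(train_ds[0][0], train_ds[0][0])  # two slightly different noisy copies
print(eval_ds[0][0])                   # the original features, unchanged
```

One caveat with this design: when a DataLoader uses num_workers > 0, each worker process receives a copy of the same generator state, so you would want to reseed it per worker (for example, in a worker_init_fn).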

More generally, __getitem__ is where you can apply any per-sample transformation, not only augmentation. For example, you can normalize features, apply image transformations, and much more!
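One way to keep such transformations flexible, following the transform= convention that torchvision datasets use, is to accept an optional callable and apply it inside __getitem__. Below is a minimal sketch that standardizes features with per-feature means and standard deviations. TransformBlobDataset and the standardize helper are hypothetical, and in a real project you would compute these statistics from the training split only.

```python
import numpy as np
from torch.utils.data import Dataset

class TransformBlobDataset(Dataset):
    """Hypothetical dataset that applies an optional per-sample transform."""

    def __init__(self, features, target, transform=None):
        super().__init__()
        self.features = features
        self.target = target
        self.transform = transform

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        X = self.features[idx]
        if self.transform is not None:
            X = self.transform(X)
        return X, self.target[idx]

X_toy = np.array([[1.0, 10.0], [3.0, 30.0], [5.0, 50.0]])
y_toy = np.array([0, 1, 0])

# Per-feature statistics used by the hypothetical standardize helper
mean = X_toy.mean(axis=0)
std = X_toy.std(axis=0)

def standardize(x):
    return (x - mean) / std

ds = TransformBlobDataset(X_toy, y_toy, transform=standardize)
print(ds[1][0])  # [0. 0.] because row 1 equals the per-feature mean
```

Passing the transform in, rather than hard-coding it, lets the same dataset class serve training (with augmentation) and evaluation (with normalization only).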

Now, let’s take a look at how the PyTorch Dataset class integrates with PyTorch DataLoaders.

Using Datasets with PyTorch DataLoaders

The PyTorch DataLoader class is an important tool to help you prepare, manage, and serve your data to your deep learning networks. Because you often need to perform many preprocessing steps before training a model, standardizing these processes is critical for the readability and maintainability of your code.

The PyTorch DataLoader allows you to:

  1. Define a dataset to work with: identifying where the data is coming from and how it should be accessed.
  2. Batch the data: define how many training or testing samples to use in a single iteration. Working with batches keeps memory use manageable, even when your training and testing sets are large.
  3. Shuffle the data: with shuffle=True, the DataLoader reshuffles the data at every epoch, so batches aren’t skewed by the order in which samples are stored (for example, sorted by class).
  4. Support multi-processing: the num_workers parameter sets how many subprocesses load data in parallel, which helps keep your CPUs busy and your GPU supplied with data while the model trains.
  5. Merge datasets together: optionally, PyTorch also lets you combine multiple datasets (for example, with torch.utils.data.ConcatDataset) before passing them to a DataLoader. While this may not be a common task, it’s a useful feature to have available.
  6. Use pinned memory for faster GPU transfers: with pin_memory=True, the DataLoader copies tensors into page-locked (pinned) host memory before returning them, which speeds up copying them to a CUDA GPU.
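The sketch below exercises several of these options on small, made-up in-memory data: ConcatDataset merges two datasets, shuffle=True reshuffles each epoch, num_workers sets the number of loading subprocesses (0 means loading happens in the main process), and pin_memory is enabled only when a CUDA device is available. TensorDataset is used here for brevity.

```python
import torch
from torch.utils.data import ConcatDataset, DataLoader, TensorDataset

# Two small made-up datasets with 2 features and an integer label each
part_a = TensorDataset(torch.randn(60, 2), torch.zeros(60, dtype=torch.long))
part_b = TensorDataset(torch.randn(40, 2), torch.ones(40, dtype=torch.long))

# Merge the datasets; the result behaves like one dataset of 100 samples
combined = ConcatDataset([part_a, part_b])

loader = DataLoader(
    combined,
    batch_size=32,                          # samples per batch
    shuffle=True,                           # reshuffle at the start of each epoch
    num_workers=0,                          # raise to load batches in parallel subprocesses
    pin_memory=torch.cuda.is_available(),   # page-locked memory speeds host-to-GPU copies
)

batch_sizes = [features.shape[0] for features, labels in loader]
print(len(combined), batch_sizes)  # 100 [32, 32, 32, 4]
```

Because drop_last defaults to False, the final batch holds the 4 leftover samples.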

Let’s see how we can create a training and testing DataLoader:

# Creating Training and Testing Data Loaders
from torch.utils.data import DataLoader, random_split
train_data, test_data = random_split(dataset, [0.8, 0.2])

# Shuffle the training batches each epoch; keep the test order fixed
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
test_loader = DataLoader(test_data, batch_size=32)

In the code block above, we split our dataset into training and testing data using the random_split function. Then, we passed these datasets into the DataLoader class with a batch size of 32.
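random_split draws a new random permutation each time it runs. Passing a seeded torch.Generator, as in the sketch below, makes the split reproducible, and pulling one batch from the loader shows what the model will receive. A stand-in TensorDataset of made-up data replaces the BlobDataset from earlier so the snippet runs on its own; fractional lengths such as [0.8, 0.2] require a reasonably recent PyTorch version.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# Stand-in for the BlobDataset: 1,000 samples with 2 features and a binary label
dataset = TensorDataset(torch.randn(1000, 2), torch.randint(0, 2, (1000,)))

# A seeded generator makes the 80/20 split the same on every run
generator = torch.Generator().manual_seed(42)
train_data, test_data = random_split(dataset, [0.8, 0.2], generator=generator)

train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
test_loader = DataLoader(test_data, batch_size=32)

# Inspect the first training batch
features, labels = next(iter(train_loader))
print(len(train_data), len(test_data))   # 800 200
print(features.shape, labels.shape)      # torch.Size([32, 2]) torch.Size([32])
```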

Because DataLoader works with any Dataset that implements __getitem__ and __len__, custom datasets get batching, shuffling, and parallel loading without extra code. This lets us create custom workflows while still relying on built-in functionality.

Frequently Asked Questions (FAQ)

What is a PyTorch Dataset?

A PyTorch Dataset is a class in the PyTorch library that represents a collection of data samples and their corresponding labels, designed for easy integration with deep learning models. It allows you to organize and preprocess your data, making it ready for training and evaluation. PyTorch Datasets provide an interface to access and manipulate data efficiently, enabling seamless integration into the PyTorch workflow.

How do I load my own dataset in PyTorch?

To load your own dataset in PyTorch, you can create a custom dataset by subclassing the torch.utils.data.Dataset class. In this custom dataset class, you need to implement the __len__ method to return the total number of samples and the __getitem__ method to return a specific sample and its corresponding label. Once you have created your custom dataset, you can use PyTorch’s DataLoader to efficiently load and iterate over the data in batches, enabling smooth integration with your deep learning models.

What is the difference between PyTorch Dataset and TensorDataset?

PyTorch Dataset is a base class that allows you to define and organize your custom datasets. It provides an interface for accessing individual samples and labels, and you can implement custom transformations and preprocessing logic within this class.
On the other hand, TensorDataset is a specific implementation of PyTorch Dataset that is designed to handle datasets composed of tensors. It takes one or more tensors as input and treats them as individual columns of the dataset. TensorDataset simplifies the process of working with tensor-based datasets, allowing you to directly access and manipulate the tensors during training or evaluation.
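As a brief illustration of the difference, the sketch below wraps two made-up tensors in a TensorDataset without writing any class; indexing returns the matching row from each tensor. All tensors passed in must have the same size in their first dimension.

```python
import torch
from torch.utils.data import TensorDataset

features = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
labels = torch.tensor([0, 1, 0])

# Each index returns a tuple with the matching row from every tensor
tensor_ds = TensorDataset(features, labels)

x, y = tensor_ds[1]
print(len(tensor_ds), x, y)  # 3 tensor([3., 4.]) tensor(1)
```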


Conclusion

In this tutorial, you learned what PyTorch Datasets are and why they’re an important part of a deep learning workflow. You saw how to load built-in datasets such as CIFAR-10 from torchvision, and how to create your own datasets by subclassing Dataset and implementing the __len__ and __getitem__ methods.

You also learned how to augment data each time a record is accessed, and how to pass datasets to DataLoaders to split, batch, and shuffle your data for training and testing.

With this foundation, you can organize your own data for PyTorch models. Remember to practice and experiment with different models, datasets, and techniques to expand your knowledge and improve your skills.

To learn more about PyTorch Datasets, check out the official documentation.

Nik Piepenbreier

Nik is the author of and has over a decade of experience working with data analytics, data science, and Python. He specializes in teaching developers how to use Python for data science using hands-on tutorials.
