An Introduction To PyTorch Dataset and DataLoader
Why Write Good Data Loaders and Datasets?
The Basic PyTorch Dataset Structure
Implementing A Custom Dataset In PyTorch
Best Practices For Creating Custom Datasets
The Basic PyTorch DataLoader Class Structure
Example: Creating A Data Loader From A Dataset
Using Custom Samplers For More Control Over Data Loading
In this tutorial, we’ll go through the PyTorch data primitives, namely torch.utils.data.DataLoader and torch.utils.data.Dataset, understand how the pre-loaded datasets work, and learn how to create our own Datasets and DataLoaders by subclassing these classes.
Why should we learn how to write good data loaders and datasets? Isn’t modeling the most important part of a deep learning pipeline?
Your training pipeline should be as modular as possible to aid quick prototyping and keep the code maintainable. Using a poorly written data loader, or skipping the data loader entirely in favor of a plain Python generator or function, can hurt the parallelism of your code. Dataset processing is a critical part of any training pipeline and should be kept separate from modeling.
The same technique won’t work everywhere. Some problems might require you to use image augmentations, so you’d prefer to have an argument (something like data = Dataset(…, fetch = True)) that lets you toggle this behavior when testing the model’s performance. Or you might need to experiment with different sequence lengths and strides when fine-tuning an NLP model. To these ends, it’s recommended to use custom Datasets and DataLoaders.
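As a rough sketch of what such a configurable Dataset could look like (everything here is hypothetical, not code from a real pipeline), the experiment knobs simply become constructor arguments:

import torch
from torch.utils.data import Dataset

class ConfigurableTextDataset(Dataset):
    """Hypothetical dataset whose preprocessing is controlled by arguments."""
    def __init__(self, token_ids, seq_len=128, stride=64):
        # Slice one long token stream into overlapping windows;
        # seq_len and stride can be changed per experiment.
        self.windows = [
            token_ids[i:i + seq_len]
            for i in range(0, max(len(token_ids) - seq_len + 1, 1), stride)
        ]

    def __len__(self):
        return len(self.windows)

    def __getitem__(self, idx):
        return torch.tensor(self.windows[idx], dtype=torch.long)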
The following code snippet contains the original implementation of the Dataset class from PyTorch. All pre-loaded Datasets inherit from this basic structure.
class Dataset(...):
    # Raises NotImplementedError
    def __getitem__(self, index):
        ...

    # Allows us to add/concatenate Datasets
    def __add__(self, other):
        ...

    # Returns the attribute value or raises an AttributeError
    def __getattr__(self, attribute_name):
        ...

    # Utility method to "register" functions
    @classmethod
    def register_function(cls, ...):
        ...

    # Utility method to "register" DataPipes as functions
    @classmethod
    def register_datapipe_as_function(cls, ...):
        ...
Because the structure is this simple, you don’t always need to inherit from torch.utils.data.Dataset; in most cases you can get by with writing a few key methods yourself.
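For example, any object that defines __getitem__ and __len__ already behaves like a map-style dataset and can be handed to a DataLoader; this toy class is only an illustration of that point:

import torch
from torch.utils.data import DataLoader

class SquaresDataset:
    """Map-style dataset that does not inherit from torch.utils.data.Dataset."""
    def __len__(self):
        return 10

    def __getitem__(self, idx):
        return torch.tensor([idx, idx ** 2])

loader = DataLoader(SquaresDataset(), batch_size=4)
for batch in loader:
    print(batch.shape)   # torch.Size([4, 2]) for the full batches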
Now, for most purposes, you will need to write your own implementation of a Dataset. So let’s see how you can write a custom dataset by subclassing torch.utils.data.Dataset.
You’ll need to implement 3 functions:
class CustomDataset(torch.utils.data.Dataset):
    # Basic instantiation
    def __init__(self, ..., *args, **kwargs):
        ...

    # Length of the Dataset
    def __len__(self):
        ...

    # Fetch an item from the Dataset
    def __getitem__(self, idx):
        ...
Let’s walk through some examples of Custom Datasets.
This code snippet is taken from my Kaggle Kernel on Neural Image Captioning. Let’s walk through the code:
class FlickrDataset(Dataset):
    def __init__(self, df, transforms):
        self.df = df
        # Note: the passed-in transforms are replaced by this fixed pipeline
        self.transforms = T.Compose([
            T.ToTensor(),
            T.Normalize(mean=[0.5], std=[0.5]),
            T.Resize((256, 256)),
        ])

    def __len__(self) -> int:
        return len(self.df)

    def __getitem__(self, idx: int):
        # Load the image referenced by this row of the DataFrame
        image_id = self.df.image_name.values[idx]
        image = Image.open(image_id).convert('RGB')
        if self.transforms is not None:
            image = self.transforms(image)

        # Fetch the caption belonging to this image and tokenize it
        comments = self.df[self.df.image_name == image_id].values.tolist()[0][1:][0]
        encoded_inputs = tokenizer(comments,
                                   return_token_type_ids=False,
                                   return_attention_mask=False,
                                   max_length=100,
                                   padding="max_length",
                                   return_tensors="pt")

        sample = {"image": image.to(device),
                  "captions": encoded_inputs["input_ids"].flatten().to(device)}
        return sample
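A quick usage sketch (hypothetical: train_df is assumed to hold image paths in an image_name column plus a caption column, and tokenizer, T, Image and device are defined elsewhere in the kernel; note that the constructor replaces whatever transforms you pass with its own pipeline):

train_ds = FlickrDataset(df=train_df, transforms=None)
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=32, shuffle=True)

batch = next(iter(train_loader))
print(batch["image"].shape)     # e.g. torch.Size([32, 3, 256, 256])
print(batch["captions"].shape)  # e.g. torch.Size([32, 100])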
This code snippet is taken from my Custom Wrapper for the RSNA-MICCAI Brain Tumor Radiogenomic Classification Kaggle Competition. Let’s walk through the code:
class Dataset(torch_data.Dataset):
    def __init__(
        self,
        paths,
        targets=None,
        mri_type=None,
        label_smoothing: float = 0.01,
        split: str = "train",
        augment: bool = False,
    ):
        self.paths = paths
        self.targets = targets
        self.mri_type = mri_type
        self.label_smoothing = label_smoothing
        self.split = split
        self.augment = augment

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        scan_id = self.paths[index]
        # Inference data has no targets and is loaded with the configured split
        if self.targets is None:
            data = load_dicom_images_3d(
                str(scan_id).zfill(5), mri_type=self.mri_type[index], split=self.split
            )
        else:
            data = load_dicom_images_3d(
                str(scan_id).zfill(5), mri_type=self.mri_type[index], split="train"
            )

        # Optional augmentation; seq is an augmentation pipeline defined elsewhere in the kernel
        if self.augment:
            data = seq(images=data)

        if self.targets is None:
            return {"X": torch.tensor(data).float(), "id": scan_id}
        else:
            # Label smoothing turns the hard 0/1 targets into 0.01/0.99
            y = torch.tensor(
                abs(self.targets[index] - self.label_smoothing), dtype=torch.float
            )
            return {"X": torch.tensor(data).float(), "y": y}
There are some general things you need to remember while creating custom datasets.
Sometimes, when working with a big dataset, it becomes quite difficult to load the entire dataset into memory at once. The only way forward is then to load and process the data in batches, which would normally mean writing extra code yourself. PyTorch has you covered here with its DataLoader class.
The DataLoader class lives in the torch.utils.data module and supports the following tasks:
- automatic batching of individual samples into mini-batches,
- customizing the data loading order (shuffling or custom samplers),
- single- and multi-process data loading,
- automatic memory pinning for faster host-to-GPU transfers.
The following code snippet contains the original implementation of the DataLoader class from PyTorch.
class DataLoader(...):
    # Basic __init__ function
    def __init__(self, ...):
        ...

    # Returns either a single- or a multi-process iterator
    def _get_iterator(self):
        ...

    # Handle multiprocessing
    @property
    def multiprocessing_context(self):
        ...

    # Handle multiprocessing
    @multiprocessing_context.setter
    def multiprocessing_context(self, multiprocessing_context):
        ...

    # Override the default __setattr__ method
    def __setattr__(self, attr, val):
        ...

    # Override the default __iter__ method
    def __iter__(self):
        ...

    # Helper property for collation
    @property
    def _auto_collation(self):
        ...

    # The actual sampler used for fetching
    @property
    def _index_sampler(self):
        ...

    # Returns the length of the index sampler (in the case of a map-style dataset)
    def __len__(self) -> int:
        ...

    # Checks whether the number of workers is reasonable given system resources
    def check_worker_number_rationality(self):
        ...
Now, this does look complicated, but in most cases we don’t need to understand the details. Still, it’s nice to know how PyTorch takes care of multiprocessing and of the different kinds of iterators.
The following section shows the signature of the DataLoader class along with information on its parameters.
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None, *, prefetch_factor=2,
persistent_workers=False)
Parameters:
- dataset: the Dataset object from which to load the data.
- batch_size: how many samples to load per batch (default: 1).
- shuffle: whether to reshuffle the data at every epoch (default: False).
- sampler: defines a custom strategy for drawing samples; mutually exclusive with shuffle.
- batch_sampler: like sampler, but returns a batch of indices at a time.
- num_workers: how many subprocesses to use for data loading (0 means loading happens in the main process).
- collate_fn: a function that merges a list of samples into a mini-batch.
- pin_memory: if True, tensors are copied into page-locked memory before being returned, which speeds up transfer to the GPU.
- drop_last: if True, the last incomplete batch is dropped.
- timeout: how long to wait for collecting a batch from a worker before raising an error.
- worker_init_fn: a function called on each worker subprocess at startup.
- prefetch_factor: number of batches loaded in advance by each worker.
- persistent_workers: if True, the worker processes are kept alive between epochs.
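Putting a few of these parameters together in one call (dataset here stands for any map-style Dataset object; the values are purely illustrative):

loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=32,      # samples per mini-batch
    shuffle=True,       # reshuffle the indices at every epoch
    num_workers=4,      # load batches in 4 worker processes
    pin_memory=True,    # use page-locked memory for faster host-to-GPU copies
    drop_last=True,     # throw away the final incomplete batch
)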
This first example showcases how the built-in MNIST dataset of torchvision can be handled with the DataLoader class. (MNIST is a famous dataset of hand-written digits.)
import torch
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
Here, we are using the transforms module of torchvision. It is generally used when we have to handle image datasets and helps with normalizing, resizing, and cropping the images.
For this MNIST dataset, we normalize the images: ToTensor() first scales the pixel values to the range 0 to 1, and Normalize((0.5,), (0.5,)) then shifts and scales them to the range -1 to +1.
The following code defines the transform used for normalization.
# Define a transform to normalize the data
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,)),
                                ])
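A quick sanity check of that mapping (using a random tensor as a stand-in for an image that ToTensor() has already scaled to the 0-1 range):

sample = torch.rand(1, 28, 28)   # stand-in for an image already scaled to [0, 1]
normalized = transforms.Normalize((0.5,), (0.5,))(sample)
print(normalized.min().item(), normalized.max().item())   # values now lie in [-1, 1]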
The following code snippet loads the dataset. We use the PyTorch DataLoader with batch_size=64 and enable shuffling so that the data is re-ordered at every epoch.
# Download and load the training data
trainset = datasets.MNIST('~/.pytorch/MNIST_data/', download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
Output (truncated):
Extracting /root/.pytorch/MNIST_data/MNIST/raw/t10k-labels-idx1-ubyte.gz to /root/.pytorch/MNIST_data/MNIST/raw
Processing...
Done!
To fetch a batch of images from the dataset, we create an iterator from the DataLoader with iter() and pull the next batch with next().
dataiter = iter(trainloader)
images, labels = next(dataiter)
print(images.shape)
print(labels.shape)
plt.imshow(images[1].numpy().squeeze(), cmap='Greys_r')
Output:
torch.Size([64, 1, 28, 28])
torch.Size([64])
<matplotlib.image.AxesImage at 0x7fdc324cdb50>
This second example shows how we can use PyTorch dataloader on custom datasets. So let us first create a custom dataset.
The code snippet below helps us create a custom dataset that contains 999 random integers between the two bounds passed to the constructor.
from torch.utils.data import Dataset
import random

class SampleDataset(Dataset):
    def __init__(self, r1, r2):
        # Generate 999 random integers between r1 and r2 (inclusive)
        randomlist = []
        for i in range(1, 1000):
            n = random.randint(r1, r2)
            randomlist.append(n)
        self.samples = randomlist

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

dataset = SampleDataset(4, 445)
dataset[100:120]
Output:
[435,
117,
315,
266,
279,
441,
364,
383,
241,
299,
146,
124,
74,
128,
404,
400,
214,
237,
40,
382]
Finally, we can use the DataLoader on our custom dataset. Notice that we set batch_size=12 and enable parallel, multi-process data loading with num_workers=2.
The output shows that the loaded data is divided into batches of 12 elements each (84 batches in total, the last one holding the remaining 3 samples). Some of the tensors are displayed for reference.
from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=12, shuffle=True, num_workers=2)

for i, batch in enumerate(loader):
    print(i, batch)
Output:
0 tensor([ 16, 179, 246, 127, 263, 418, 33, 410, 107, 281, 438, 164])
1 tensor([421, 55, 183, 19, 47, 402, 336, 290, 241, 121, 308, 140])
2 tensor([265, 149, 62, 421, 67, 427, 302, 149, 134, 269, 116, 267])
3 tensor([318, 404, 365, 324, 229, 184, 10, 391, 71, 424, 387, 256])
4 tensor([178, 138, 200, 398, 420, 98, 147, 338, 341, 434, 58, 332])
5 tensor([403, 256, 290, 238, 186, 57, 343, 361, 388, 81, 271, 111])
6 tensor([340, 59, 73, 298, 275, 102, 20, 413, 95, 83, 380, 323])
7 tensor([ 71, 15, 443, 44, 394, 252, 103, 11, 383, 292, 57, 109])
8 tensor([398, 406, 84, 369, 272, 409, 367, 205, 353, 24, 305, 21])
9 tensor([280, 200, 79, 424, 26, 58, 233, 194, 362, 379, 228, 428])
10 tensor([316, 225, 231, 272, 382, 132, 306, 295, 150, 365, 420, 17])
11 tensor([280, 432, 51, 123, 356, 29, 172, 225, 143, 147, 226, 262])
12 tensor([208, 366, 267, 389, 135, 398, 359, 365, 52, 210, 152, 214])
.
.
.
69 tensor([ 43, 351, 383, 435, 368, 26, 316, 145, 409, 140, 224, 159])
70 tensor([210, 68, 404, 30, 32, 324, 18, 416, 340, 354, 337, 436])
71 tensor([414, 114, 233, 320, 105, 318, 326, 139, 319, 205, 69, 123])
72 tensor([165, 265, 381, 33, 392, 261, 57, 23, 131, 186, 232, 186])
73 tensor([404, 105, 345, 436, 51, 392, 263, 138, 364, 439, 12, 295])
74 tensor([163, 70, 137, 435, 250, 354, 190, 335, 39, 323, 365, 96])
75 tensor([148, 383, 322, 300, 309, 125, 46, 29, 231, 432, 258, 376])
76 tensor([314, 266, 248, 236, 296, 434, 93, 138, 140, 12, 444, 302])
77 tensor([ 41, 257, 13, 64, 295, 330, 396, 251, 379, 232, 108, 364])
78 tensor([ 70, 161, 168, 41, 434, 258, 327, 270, 42, 347, 384, 282])
79 tensor([392, 13, 258, 416, 146, 308, 32, 276, 302, 177, 410, 263])
80 tensor([186, 433, 420, 11, 273, 230, 377, 416, 303, 83, 20, 240])
81 tensor([ 47, 354, 171, 207, 178, 351, 137, 138, 33, 224, 422, 280])
82 tensor([214, 193, 444, 432, 274, 268, 67, 217, 64, 84, 27, 102])
83 tensor([419, 62, 244])
Most pre-loaded datasets from Torchvision are subclasses of torch.utils.data.Dataset, so we can feed them directly into the torch.utils.data.DataLoader class and then enumerate through them in our training loop.
For example, this code snippet from the PyTorch tutorials shows how easily we can create data loaders using pre-loaded datasets from torchvision.
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
training_data = datasets.FashionMNIST(root="data", train=True, download=True, transform=ToTensor())
test_data = datasets.FashionMNIST(root="data", train=False, download=True, transform=ToTensor())
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
The aforementioned code example returns mini-batches of data with the provided batch size.
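In the training loop you then simply enumerate over the loader; a minimal sketch (the model, loss computation and optimizer step are left out):

for batch_idx, (images, labels) in enumerate(train_dataloader):
    # images has shape [64, 1, 28, 28] and labels has shape [64]
    # forward pass, loss computation and optimizer step would go here
    if batch_idx == 0:
        print(images.shape, labels.shape)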
For even more control over your data loading, use custom Samplers. Every Sampler subclass must implement an __iter__ method, which yields dataset indices, and usually a __len__ method that returns the number of indices it will produce. For more information, refer to the PyTorch docs. A short example follows below.
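A minimal sketch of such a Sampler, here one that simply yields indices in reverse order (purely illustrative):

from torch.utils.data import Sampler, DataLoader

class ReverseSampler(Sampler):
    """Yields dataset indices from last to first."""
    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(range(len(self.data_source) - 1, -1, -1))

    def __len__(self):
        return len(self.data_source)

# e.g. with the SampleDataset created earlier:
# loader = DataLoader(dataset, batch_size=12, sampler=ReverseSampler(dataset))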
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
Pandas is not essential for creating a Dataset object, but it’s a powerful tool for managing data, so I’m going to use it. From torch.utils.data we import the Dataset and DataLoader classes we need.
class CustomTextDataset(Dataset):
    def __init__(self, text, labels):
        self.labels = labels
        self.text = text

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        label = self.labels[idx]
        text = self.text[idx]
        sample = {"Text": text, "Class": label}
        return sample
class CustomTextDataset(Dataset): Create a class called ‘CustomTextDataset’; this can be called whatever you want. The Dataset class we imported earlier is passed in as the parent class.
def __init__(self, text, labels): When you initialize the class you need to pass in two variables. In this case, the variables are called ‘text’ and ‘labels’ to match the data which will be added.
self.labels = labels & self.text = text: The passed-in variables can now be used in functions within the class via self.text or self.labels.
def __len__(self): This function just returns the length of the labels when called. E.g., if you had a dataset with 5 labels, then the integer 5 would be returned.
def __getitem__(self, idx): This function is used by PyTorch to fetch a single sample at a given index. When you iterate over the dataset (directly or through a DataLoader), this function is called repeatedly to construct each sample.
# define data and class labels
text = ['Happy', 'Amazing', 'Sad', 'Unhapy', 'Glum']
labels = ['Positive', 'Positive', 'Negative', 'Negative', 'Negative']

# create Pandas DataFrame
text_labels_df = pd.DataFrame({'Text': text, 'Labels': labels})

# define data set object
TD = CustomTextDataset(text_labels_df['Text'], text_labels_df['Labels'])
First, we create two lists called ‘text’ and ‘labels’ as an example.
text_labels_df = pd.DataFrame({‘Text’: text, ‘Labels’: labels}): This is not essential, but Pandas is a useful tool for data management and pre-processing and will probably be used in your PyTorch pipeline. In this section, the lists ‘text’ and ‘labels’ containing the data are saved in a Pandas DataFrame.
TD = CustomTextDataset(text_labels_df[‘Text’], text_labels_df[‘Labels’]): This initialises the class we made earlier with the ‘Text’ and ‘Labels’ data being passed in. This data will become ‘self.text’ and ‘self.labels’ within the class. The Dataset is saved under the variable named TD.
The Dataset is now initialized and ready to be used!
This will show you how the data is stored within the Dataset.
# Display text and label.
print('\nFirst iteration of data set: ', next(iter(TD)), '\n')
# Print how many items are in the data set
print('Length of data set: ', len(TD), '\n')
# Print entire data set
print('Entire data set: ', list(DataLoader(TD)), '\n')
Output:
First iteration of data set: {‘Text’: ‘Happy’, ‘Class’: ‘Positive’}
Length of data set: 5
Entire data set: [{‘Text’: [‘Happy’], ‘Class’: [‘Positive’]}, {‘Text’: [‘Amazing’], ‘Class’: [‘Positive’]}, {‘Text’: [‘Sad’], ‘Class’: [‘Negative’]}, {‘Text’: [‘Unhapy’], ‘Class’: [‘Negative’]}, {‘Text’: [‘Glum’], ‘Class’: [‘Negative’]}]
In machine learning or deep learning, text needs to be cleaned and turned into vectors prior to training. DataLoader has a handy parameter called collate_fn. This parameter allows you to create separate data processing functions and will apply the processing within that function to the data before it is output.
def collate_batch(batch):
    word_tensor = torch.tensor([[1.], [0.], [45.]])
    label_tensor = torch.tensor([[1.]])
    text_list, classes = [], []
    for (_text, _class) in batch:
        text_list.append(word_tensor)
        classes.append(label_tensor)
    text = torch.cat(text_list)
    classes = torch.tensor(classes)
    return text, classes

DL_DS = DataLoader(TD, batch_size=2, collate_fn=collate_batch)
As an example, two fixed tensors are created to represent a word and a class label; in practice these could be word vectors produced by another function. The batch is then unpacked, and the word and label tensors are appended to two lists.
The word tensors are concatenated, and the list of class tensors (each holding a single value) is combined into one tensor. The function then returns the processed text data, ready for training.
To activate this function you simply add the parameter collate_fn=Your_Function_name when initializing the DataLoader object.
We will iterate through the Dataset without using collate_fn because it’s easier to see how the words and classes are being output by DataLoader. If the above function were used with collate_fn then the output would be tensors.
DL_DS = DataLoader(TD, batch_size=2, shuffle=True)
for (idx, batch) in enumerate(DL_DS):
    # Print the 'text' data of the batch
    print(idx, 'Text data: ', batch['Text'])
    # Print the 'class' data of batch
    print(idx, 'Class data: ', batch['Class'], '\n')
DL_DS = DataLoader(TD, batch_size=2, shuffle=True): This initialises DataLoader with the Dataset object “TD” which we just created. In this example, the batch size is set to 2. This means that when you iterate through the Dataset, DataLoader will output 2 instances of data instead of one. For more information on batches see this article. Shuffle will reshuffle the data at each epoch, which prevents the model from learning the order of the training data.
for (idx, batch) in enumerate(DL_DS): Iterate through the data in the DataLoader object we just created. enumerate(DL_DS) returns the index number of the batch and the batch consisting of two data instances.
Output:
As you can see, the 5 data instances we created are output in batches of 2. Since we have an odd number of training examples, the last one is output in its own batch. Each index (0, 1, or 2) represents one batch.
Full code:
# Import libraries
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
# create custom dataset class
class CustomTextDataset(Dataset):
    def __init__(self, text, labels):
        self.labels = labels
        self.text = text

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        label = self.labels[idx]
        data = self.text[idx]
        sample = {"Text": data, "Class": label}
        return sample
# define data and class labels
text = ['Happy', 'Amazing', 'Sad', 'Unhapy', 'Glum']
labels = ['Positive', 'Positive', 'Negative', 'Negative', 'Negative']
# create Pandas DataFrame
text_labels_df = pd.DataFrame({'Text': text, 'Labels': labels})
# define data set object
TD = CustomTextDataset(text_labels_df['Text'], text_labels_df['Labels'])
# Display text and label.
print('\nFirst iteration of data set: ', next(iter(TD)), '\n')
# Print how many items are in the data set
print('Length of data set: ', len(TD), '\n')
# Print entire data set
print('Entire data set: ', list(DataLoader(TD)), '\n')
# collate_fn
def collate_batch(batch):
    word_tensor = torch.tensor([[1.], [0.], [45.]])
    label_tensor = torch.tensor([[1.]])
    text_list, classes = [], []
    for (_text, _class) in batch:
        text_list.append(word_tensor)
        classes.append(label_tensor)
    text = torch.cat(text_list)
    classes = torch.tensor(classes)
    return text, classes
# create DataLoader object of DataSet object
bat_size = 2
DL_DS = DataLoader(TD, batch_size=bat_size, shuffle=True)
# loop through each batch in the DataLoader object
for (idx, batch) in enumerate(DL_DS):
    # Print the 'text' data of the batch
    print(idx, 'Text data: ', batch['Text'])
    # Print the 'class' data of batch
    print(idx, 'Class data: ', batch['Class'], '\n')
Resources:
https://blog.paperspace.com/dataloaders-abstractions-pytorch/
https://medium.com/analytics-vidhya/creating-a-custom-dataset-and-dataloader-in-pytorch-76f210a1df5d