PyTorch Custom Datasets
In the last notebook, notebook 03, we looked at how to build computer vision models on an in-built dataset
in PyTorch (FashionMNIST).
The steps we took are similar across many different problems in machine learning.
Find a dataset, turn the dataset into numbers, build a model (or find an existing model) to find patterns in
those numbers that can be used for prediction.
PyTorch has many built-in datasets used for a wide number of machine learning benchmarks, however,
you'll often want to use your own custom dataset.
For example, if we were building a food image classification app like Nutrify, our custom dataset might be
images of food.
Or if we were trying to build a model to classify whether or not a text-based review on a website was
positive or negative, our custom dataset might be examples of existing customer reviews and their ratings.
Or if we were trying to build a sound classification app, our custom dataset might be sound samples
alongside their sample labels.
Or if we were trying to build a recommendation system for customers purchasing things on our website,
our custom dataset might be examples of products other people have bought.
PyTorch includes many existing functions to load in various custom datasets in the TorchVision ,
TorchText , TorchAudio and TorchRec domain libraries.
But sometimes these existing functions may not be enough.
In that case, we can always subclass torch.utils.data.Dataset and customize it to our liking.
But instead of using an in-built PyTorch dataset, we're going to be using our own dataset of pizza, steak
and sushi images.
The goal will be to load these images and then build a model to train and predict on them.
What we're going to build. We'll use torchvision.datasets as well as our own custom Dataset class to
load in images of food and then we'll build a PyTorch computer vision model to hopefully be able to classify
them.
| Topic | Contents |
| --- | --- |
| 0. Importing PyTorch and setting up device-agnostic code | Let's get PyTorch loaded and then follow best practice to set up our code to be device-agnostic. |
| 1. Get data | We're going to be using our own custom dataset of pizza, steak and sushi images. |
| 2. Become one with the data (data preparation) | At the beginning of any new machine learning problem, it's paramount to understand the data you're working with. Here we'll take some steps to figure out what data we have. |
| 3. Transforming data | Often, the data you get won't be 100% ready to use with a machine learning model. Here we'll look at some steps we can take to transform our images so they're ready to be used with a model. |
| 4. Loading data with ImageFolder (option 1) | PyTorch has many in-built data loading functions for common types of data. ImageFolder is helpful if our images are in standard image classification format. |
| 5. Loading image data with a custom Dataset | What if PyTorch didn't have an in-built function to load data with? This is where we can build our own custom subclass of torch.utils.data.Dataset . |
| 6. Other forms of transforms (data augmentation) | Data augmentation is a common technique for expanding the diversity of your training data. Here we'll explore some of torchvision 's in-built data augmentation functions. |
| 7. Model 0: TinyVGG without data augmentation | By this stage we'll have our data ready. Let's build a model capable of fitting it. We'll also create some training and testing functions for training and evaluating our model. |
| 8. Exploring loss curves | Loss curves are a great way to see how your model is training/improving over time. They're also a good way to see if your model is underfitting or overfitting. |
| 9. Model 1: TinyVGG with data augmentation | By now we've tried a model without data augmentation, how about we try one with it? |
| 10. Compare model results | Let's compare our different models' loss curves, see which performed better and discuss some options for improving performance. |
| 11. Making a prediction on a custom image | Our model is trained on a dataset of pizza, steak and sushi images. In this section we'll cover how to use our trained model to predict on an image outside of our existing dataset. |
If you run into trouble, you can ask a question on the course GitHub Discussions page.
And of course, there's the PyTorch documentation and PyTorch developer forums, a very helpful place for
all things PyTorch.
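0. Importing PyTorch and setting up device-agnostic code
Let's start by importing PyTorch and checking the version we're running (a minimal setup sketch; your version string will likely differ from the one below):

import torch
from torch import nn

# Check the PyTorch version
torch.__version__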
Out[1]: '1.12.1+cu113'
And now let's follow best practice and set up device-agnostic code.
Note: If you're using Google Colab, and you don't have a GPU turned on yet, it's now time to turn one on
via Runtime -> Change runtime type -> Hardware accelerator -> GPU . If you do this, your runtime
will likely reset and you'll have to run all of the cells above by going Runtime -> Run before .
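One common pattern for device-agnostic code (a sketch assuming a CUDA GPU is the only accelerator we care about):

# Setup device-agnostic code: use a GPU if one is available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
device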
Out[2]: 'cuda'
1. Get data
First things first, we need some data.
And like any good cooking show, some data has already been prepared for us.
We're not looking to train the biggest model or use the biggest dataset yet.
Machine learning is an iterative process: start small, get something working and increase when necessary.
Food101 is a popular computer vision benchmark: it contains 1,000 images of each of 101 different kinds of
food, totaling 101,000 images (75,750 train and 25,250 test).
Instead of 101 food classes though, we're going to start with 3: pizza, steak and sushi.
And instead of 1,000 images per class, we're going to start with a random 10% (start small, increase when
necessary).
If you'd like to see where the data came from, see the following resources:
data/pizza_steak_sushi.zip - the zip archive of pizza, steak and sushi images from Food101,
created with the notebook linked above.
Let's write some code to download the formatted data from GitHub.
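One way to do it is with requests and zipfile (a sketch; the GitHub URL below assumes the course repository, mrdbourke/pytorch-deep-learning, so adjust it if your copy of the data lives elsewhere):

import requests
import zipfile
from pathlib import Path

# Setup path to a data folder
data_path = Path("data/")
image_path = data_path / "pizza_steak_sushi"

# If the image folder doesn't already exist, download it and unzip it
if image_path.is_dir():
    print(f"{image_path} directory exists.")
else:
    print(f"Did not find {image_path} directory, creating one...")
    image_path.mkdir(parents=True, exist_ok=True)

    # Download pizza, steak, sushi data (assumed URL)
    with open(data_path / "pizza_steak_sushi.zip", "wb") as f:
        request = requests.get("https://github.com/mrdbourke/pytorch-deep-learning/raw/main/data/pizza_steak_sushi.zip")
        f.write(request.content)

    # Unzip pizza, steak, sushi data into the image folder
    with zipfile.ZipFile(data_path / "pizza_steak_sushi.zip", "r") as zip_ref:
        zip_ref.extractall(image_path)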
Note: The dataset we're about to use has been pre-formatted for what we'd like to use it for. However,
you'll often have to format your own datasets for whatever problem you're working on. This is a regular
practice in the machine learning world.
2. Become one with the data (data preparation)
Data preparation is paramount. Before building a model, become one with the data. Ask: what am I trying to
do here? Source: @mrdbourke Twitter.
Before starting a project or building any kind of model, it's important to know what data you're working with.
In our case, we have images of pizza, steak and sushi in standard image classification format.
Image classification format stores separate classes of images in separate directories, each titled with a
particular class name.
For example, all images of pizza are contained in the pizza/ directory.
This format is popular across many different image classification benchmarks, including ImageNet (one of
the most popular computer vision benchmark datasets).
You can see an example of the storage format below (the image numbers are arbitrary).
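For instance (directory names taken from our dataset; the image file names are made up for illustration):

pizza_steak_sushi/ <- overall dataset folder
    train/ <- overall training images
        pizza/ <- class name as folder name
            image01.jpeg
            image02.jpeg
            ...
        steak/
            image24.jpeg
            ...
        sushi/
            image37.jpeg
            ...
    test/ <- overall testing images
        pizza/
            image101.jpeg
            ...
        steak/
            ...
        sushi/
            ...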
The goal will be to take this data storage structure and turn it into a dataset usable with PyTorch.
Note: The structure of the data you work with will vary depending on the problem you're working on. But
the premise still remains: become one with the data, then find a way to best turn it into a dataset
compatible with PyTorch.
We can inspect what's in our data directory by writing a small helper function to walk through each of the
subdirectories and count the files present.
In [4]: import os

def walk_through_dir(dir_path):
    """
    Walks through dir_path returning its contents.

    Args:
        dir_path (str or pathlib.Path): target directory

    Returns:
        A print out of:
            number of subdirectories in dir_path
            number of images (files) in each subdirectory
            name of each subdirectory
    """
    for dirpath, dirnames, filenames in os.walk(dir_path):
        print(f"There are {len(dirnames)} directories and {len(filenames)} images in '{dirpath}'.")
In [5]: walk_through_dir(image_path)
Excellent!
It looks like we've got about 75 images per training class and 25 images per testing class.
You can see how they were created in the data creation notebook.
While we're at it, let's setup our training and testing paths.
In [6]: # Setup train and testing paths
train_dir = image_path / "train"
test_dir = image_path / "test"

train_dir, test_dir
Out[6]: (PosixPath('data/pizza_steak_sushi/train'),
PosixPath('data/pizza_steak_sushi/test'))
Now in the spirit of the data explorer, it's time to visualize, visualize, visualize!
To do so, we'll:
1. Get all of the image paths using pathlib.Path.glob() to find all of the files ending in .jpg .
2. Pick a random image path using Python's random.choice() .
3. Get the image class name using pathlib.Path.parent.stem (the image class is the name of the
directory where the image is stored).
4. And since we're working with images, open the random image path using PIL.Image.open() (PIL
stands for Python Image Library).
5. Print some metadata about the image and show it.
import random
from PIL import Image

# Set seed
random.seed(42) # <- try changing this and see what happens

# 1. Get all image paths (* means "any combination")
image_path_list = list(image_path.glob("*/*/*.jpg"))

# 2. Get a random image path
random_image_path = random.choice(image_path_list)

# 3. Get image class from path name (the image class is the name of the directory where the image is stored)
image_class = random_image_path.parent.stem

# 4. Open image
img = Image.open(random_image_path)

# 5. Print metadata
print(f"Random image path: {random_image_path}")
print(f"Image class: {image_class}")
print(f"Image height: {img.height}")
print(f"Image width: {img.width}")
img
We can do the same with matplotlib.pyplot.imshow() , except we have to convert the image to a
NumPy array first.
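For example, a sketch using numpy.asarray() on the img we opened above:

import numpy as np
import matplotlib.pyplot as plt

# Turn the PIL image into a NumPy array
img_as_array = np.asarray(img)

# Plot the image with matplotlib
plt.figure(figsize=(10, 7))
plt.imshow(img_as_array)
plt.title(f"Image class: {image_class} | Image shape: {img_as_array.shape} -> [height, width, color_channels]")
plt.axis(False)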
3. Transforming data
Now what if we wanted to load our image data into PyTorch?
Before we can use our image data with PyTorch we need to:
1. Turn it into tensors (numerical representations of our images).
2. Turn it into a torch.utils.data.Dataset and subsequently a torch.utils.data.DataLoader .
There are several different kinds of pre-built datasets and dataset loaders for PyTorch, depending on the
problem you're working on.

| Problem space | Pre-built datasets and functions |
| --- | --- |
| Vision | torchvision.datasets |
| Audio | torchaudio.datasets |
| Text | torchtext.datasets |
Since we're working with a vision problem, we'll be looking at torchvision.datasets for our data loading
functions as well as torchvision.transforms for preparing our data.
We've got folders of images but before we can use them with PyTorch, we need to convert them into
tensors.
torchvision.transforms contains many pre-built methods for formatting images, turning them into
tensors and even manipulating them for data augmentation purposes (the practice of altering data to make
it harder for a model to learn; we'll see this later on).
To get experience with torchvision.transforms , let's write a series of transform steps that:
1. Resize the images using transforms.Resize() (from about 512x512 to 64x64, the same shape as
the images on the CNN Explainer website).
2. Flip our images randomly on the horizontal using transforms.RandomHorizontalFlip() (this could
be considered a form of data augmentation because it will artificially change our image data).
3. Turn our images from a PIL image to a PyTorch tensor using transforms.ToTensor() .
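Composing those three steps with torchvision.transforms.Compose() gives us something like the following sketch (the name data_transform matches the one used with the plotting function below):

from torchvision import transforms

# Write a transform pipeline for our images
data_transform = transforms.Compose([
    # 1. Resize the images to 64x64
    transforms.Resize(size=(64, 64)),
    # 2. Flip the images randomly on the horizontal (50% of the time)
    transforms.RandomHorizontalFlip(p=0.5),
    # 3. Turn the PIL image into a torch.Tensor
    transforms.ToTensor()
])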
Now we've got a composition of transforms, let's write a function to try them out on various images.
def plot_transformed_images(image_paths, transform, n=3, seed=42):
    """Plots a series of random images from image_paths, original vs. transformed.

    Args:
        image_paths (list): List of target image paths.
        transform (PyTorch Transforms): Transforms to apply to images.
        n (int, optional): Number of images to plot. Defaults to 3.
        seed (int, optional): Random seed for the random generator. Defaults to 42.
    """
    random.seed(seed)
    random_image_paths = random.sample(image_paths, k=n)
    for image_path in random_image_paths:
        with Image.open(image_path) as f:
            fig, ax = plt.subplots(1, 2)

            # Original image
            ax[0].imshow(f)
            ax[0].set_title(f"Original \nSize: {f.size}")
            ax[0].axis("off")

            # Transformed image (permute from [C, H, W] to [H, W, C] for matplotlib)
            transformed_image = transform(f).permute(1, 2, 0)
            ax[1].imshow(transformed_image)
            ax[1].set_title(f"Transformed \nSize: {transformed_image.shape}")
            ax[1].axis("off")

            fig.suptitle(f"Class: {image_path.parent.stem}", fontsize=16)
plot_transformed_images(image_path_list,
transform=data_transform,
n=3)
Nice!
We've now got a way to convert our images to tensors using torchvision.transforms .
We can also manipulate their size and orientation if needed (some models prefer images of different sizes
and shapes).
Generally, the larger the shape of the image, the more information a model can recover.
For example, an image of size [256, 256, 3] will have 16x more pixels than an image of size [64, 64,
3] ( (256*256*3)/(64*64*3)=16 ).
Exercise: Try commenting out one of the transforms in data_transform and running the plotting
function plot_transformed_images() again, what happens?
4. Loading data with ImageFolder (option 1)
Since our data is in standard image classification format, we can use the class
torchvision.datasets.ImageFolder .
We can pass it the file path of a target image directory as well as a series of transforms we'd like to
perform on our images.
Let's test it out on our data folders train_dir and test_dir passing in transform=data_transform to
turn our images into tensors.
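A sketch of what that looks like (target_transform is for transforming labels, which we don't need here; the variable names train_data and test_data are our choice):

from torchvision import datasets

# Use ImageFolder to create dataset(s)
train_data = datasets.ImageFolder(root=train_dir, # target folder of images
                                  transform=data_transform, # transforms to perform on images
                                  target_transform=None) # transforms to perform on labels (if needed)

test_data = datasets.ImageFolder(root=test_dir,
                                 transform=data_transform)

print(f"Train data:\n{train_data}\nTest data:\n{test_data}")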