Introduction

Over the past few weeks, I have been experimenting with and learning more about the popular PyTorch (torch) package, which, like Keras and TensorFlow, provides an easy way to implement deep learning algorithms. In this post, I show why I think torch’s Dataset class is an elegant and beginner-friendly way to handle data processing. As I have written previously, it is more important to ensure you have the correct data than to implement a target model. This, of course, is not to say that implementing a model is always straightforward. However, it is easier and faster to fix an erroneous model than it is to correct a data processing pipeline.

With the importance of data processing re-emphasised, let us dive into the beauty of the torch Dataset class.

But First, a word on Li Thresholding

As part of my experiments over the past few weeks, I came across Li thresholding, a method that was unknown to me until this year. Since I liked it, I thought it would be worth writing a word or two about it.

In this blog post, we shall be using a simple two-image dataset organised as follows:

* sample
  * train
    - images
    - masks

Li Thresholding with scikit-image

Li thresholding is a thresholding method, introduced by Li & Lee (1993), that finds an optimal threshold by minimising the cross entropy between an image and its segmentation.
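To make the idea concrete, below is a rough, brute-force sketch of that criterion as I read it (up to an additive constant), not scikit-image’s implementation, which uses a faster iterative scheme. The li_cross_entropy function, the synthetic bimodal image, and its mode locations are all illustrative assumptions.

import numpy as np
from skimage.filters import threshold_li

def li_cross_entropy(image, t):
    # Rough reading of the Li & Lee criterion, up to an additive constant:
    # split pixels at threshold t and score the split by
    # -(sum_below * log(mean_below) + sum_above * log(mean_above))
    values = image.ravel().astype(float)
    below, above = values[values <= t], values[values > t]
    if below.size == 0 or above.size == 0:
        return np.inf
    return -(below.sum() * np.log(below.mean()) + above.sum() * np.log(above.mean()))

# Brute-force search on a synthetic bimodal "image" with modes near 60 and 180
rng = np.random.default_rng(0)
fake = np.concatenate([rng.normal(60, 10, 5000), rng.normal(180, 10, 5000)])
fake = np.clip(fake, 0, 255).astype("uint8").reshape(100, 100)
best_t = min(range(1, 255), key=lambda t: li_cross_entropy(fake, t))
print(best_t, threshold_li(fake))  # both should fall between the two modes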

In the scikit-image implementation, you can either use the method with its defaults or supply an initial guess from which the optimal threshold is found. In this example, we will calculate the 0.95 quantile of the grayscale image and use it as the initial_guess argument of the threshold_li method.

In summary, we read a colored image, convert it to grayscale, and finally threshold it with the Li method. For convenience, I did not perform Gaussian denoising; I have also recently read that Gaussian pre-processing may increase the likelihood of overfitting in convolutional neural networks.

The full process in code:

import cv2
import matplotlib.pyplot as plt
import numpy as np
from skimage.filters import threshold_li
from skimage.io import imread

# Read the image, then convert it to grayscale with OpenCV
img = imread("https://github.com/Nelson-Gon/nelson-gon.github.io/blob/master/images/dog-test.png?raw=true")
gray_img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
# Use the 0.95 quantile of the grayscale values as the initial guess
li_thresh = threshold_li(gray_img, initial_guess=np.quantile(gray_img, 0.95))
# Binarise: keep pixels brighter than the Li threshold
thresholded_image = gray_img > li_thresh
thresholded_numpy = np.array(thresholded_image).astype("uint8")
plt.imshow(thresholded_numpy, cmap="gray")
plt.axis("off")

The result:

[Image: result of Li thresholding on the example image]
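Had we wanted to denoise after all, a minimal sketch (assuming the gray_img array from the code above; sigma=1 is an arbitrary choice) would be to smooth the grayscale image with scikit-image's gaussian filter before thresholding:

from skimage.filters import gaussian

# gaussian() returns a float image rescaled to [0, 1], so the threshold and
# the initial guess are recomputed on the smoothed values
smoothed = gaussian(gray_img, sigma=1)
li_thresh_smoothed = threshold_li(smoothed, initial_guess=np.quantile(smoothed, 0.95))
smoothed_mask = smoothed > li_thresh_smoothed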

I will leave it at that for Li thresholding and leave further exploration to those who may be interested. We now proceed to the main aim of this blog post: appreciating the elegance of Dataset.

Writing a simple custom data loader with torch

If you are coming from a Keras/TensorFlow background (like myself), you will probably (dis)agree that data processing in Keras/TensorFlow is less obvious and can sometimes feel less programmer-like. To be clear, this is not to say that the Keras/TensorFlow combination is bad. On the contrary, I love its model building process but think torch provides an easier data processing workflow.

To write a simple custom data loader, we will need to import the following modules/packages:

from torch.utils.data import Dataset
import glob
import os
import torch
from skimage.io import imread
import torchvision.transforms as tf
from PIL import Image
import matplotlib.pyplot as plt

For the purposes of this blog post, I wanted the CustomDataLoader to do only two things:

  • Read image/mask pairs from the provided directories.
  • Perform simple transformations on the images/masks. In this case, we will simply convert the images to PIL images and resize them.

Creating our class

To create the class, we simply inherit from the Dataset class, which I find a great way to work, partly because I like the idea of inheritance. For this simple example, the class’s __init__ method will take only an image directory argument, a mask directory argument, an image suffix, and a target_size specifying the target size for resizing. We also add some sanity checks to ensure that we get exactly what we expect. The class then becomes:



class CustomDataLoader(Dataset):
    def __init__(self, train_images, train_labels, image_suffix="jpg", target_size=(30, 30)):
        """

        :param train_images: Directory containing train_images
        :param train_labels: Directory containing train_labels
        :param image_suffix: Image suffix, for convenience
        :param target_size: Target (height, width) used when resizing images and masks

        """

        self.train_images = train_images
        self.train_labels = train_labels
        self.image_suffix = image_suffix
        self.target_size = target_size

        # Include sanity checks

        # In this case, we only check that both train_images and train_labels directories exist
        # We also ensure that we get the correct image suffix, jpg and png here

        if not (all(os.path.isdir(directory) for directory in [train_images, train_labels])):
            raise NotADirectoryError("Please ensure that both train_images and train_labels are valid directories.")

        if self.image_suffix not in ["png", "jpg", "jpeg"]:
            raise ValueError("Only supporting PNG and JPG files for now.")

        # glob for our file names
        # sorted because I prefer 1, 10, 11, ...
        # ideally should be able to glob both png and jpg, for simplicity glob only one type
        self.image_list = sorted(glob.glob(os.path.join(self.train_images, "*." + self.image_suffix)))
        self.labels_list = sorted(glob.glob(os.path.join(self.train_labels, "*." + self.image_suffix)))


        if len(self.image_list) != len(self.labels_list):
            raise ValueError("Images list and labels list should be the same length.")

Next, by convention, and also for sanity checking, we define a __len__ method that returns the number of images in the image directory, i.e. how many image/mask pairs we will work with.

    def __len__(self):
        """

        :return: length of the dataset

        """
        return len(self.image_list)

As stated above, we also define a transform method that will allow us to transform images. Here we will simply convert each image to a PIL Image, since that is what the torchvision transforms we use expect, and resize it to the target size defined in __init__.

    def transform(self, image):
        # Convert to a PIL Image because torchvision's Resize transform expects one
        to_pil = Image.fromarray(image)
        # Resize to the target size defined in __init__
        resizer = tf.Resize(self.target_size)
        image = resizer(to_pil)
        return image

Oh, the elegance of __getitem__!

Finally, here comes what I think is the most elegant part of the Dataset class.

We can define a “getter” method, __getitem__, that allows us to work with one image at a time. I think of it as something like a generator driven by next(iter(...)): what we are really doing is “looping” through the data one index at a time and applying the necessary transformations. The short sketch after the code below makes this concrete.

Programmatically:

    def __getitem__(self, img_index):
        # Indices may arrive as tensors (for example from a sampler), so
        # convert them to plain Python values first
        if torch.is_tensor(img_index):
            img_index = img_index.tolist()

        train_image = imread(self.image_list[img_index])
        train_label = imread(self.labels_list[img_index])

        return {"image": self.transform(train_image), "mask": self.transform(train_label)}

Putting it all together


class CustomDataLoader(Dataset):
    def __init__(self, train_images, train_labels, image_suffix="jpg", target_size=(30, 30)):
        """

        :param train_images: Directory containing train_images
        :param train_labels: Directory containing train_labels
        :param image_suffix: Image suffix, for convenience
        :param target_size: Target (height, width) used when resizing images and masks

        """

        self.train_images = train_images
        self.train_labels = train_labels
        self.image_suffix = image_suffix
        self.target_size = target_size

        # Include sanity checks

        # In this case, we only check that both train_images and train_labels directories exist
        # We also ensure that we get the correct image suffix, jpg and png here

        if not (all(os.path.isdir(directory) for directory in [train_images, train_labels])):
            raise NotADirectoryError("Please ensure that both train_images and train_labels are valid directories.")

        if self.image_suffix not in ["png", "jpg", "jpeg"]:
            raise ValueError("Only supporting PNG and JPG files for now.")

        # glob for our file names
        # sorted because I prefer 1, 10, 11, ...
        # ideally should be able to glob both png and jpg, for simplicity glob only one type
        self.image_list = sorted(glob.glob(os.path.join(self.train_images, "*." + self.image_suffix)))
        self.labels_list = sorted(glob.glob(os.path.join(self.train_labels, "*." + self.image_suffix)))


        if len(self.image_list) != len(self.labels_list):
            raise ValueError("Images list and labels list should be the same length.")

    def __len__(self):
        """

        :return: length of the dataset

        """
        return len(self.image_list)

    def transform(self, image):
        # Convert to a PIL Image because torchvision's Resize transform expects one
        to_pil = Image.fromarray(image)
        # Resize to the target size defined in __init__
        resizer = tf.Resize(self.target_size)
        image = resizer(to_pil)
        return image

    def __getitem__(self, img_index):
        # Indices may arrive as tensors (for example from a sampler), so
        # convert them to plain Python values first
        if torch.is_tensor(img_index):
            img_index = img_index.tolist()

        train_image = imread(self.image_list[img_index])
        train_label = imread(self.labels_list[img_index])

        return {"image": self.transform(train_image), "mask": self.transform(train_label)}

Usage

To use our newly created dataset, we can simply create an object as usual:

images_loader = CustomDataLoader("path_to_images", "path_to_masks",
                                 target_size=(512, 512), image_suffix="png")

Finally, to test things for the purposes of this post, we will simply use next(iter(...)) or plain “manual” indexing:

next(iter(images_loader))

I will demonstrate what happens when we plot the image at index 0 using manual indexing:

plt.imshow(images_loader[0]["image"])

This gives us:

[Image: the image at index 0 after resizing]

Next Steps

In this blog post, we have looked at a very basic example of the elegance of the torch Dataset class. As next steps, one could implement the following in our CustomDataLoader class:

  • Perform more transformations in transform, for example random flipping of images.

  • Extend the data loader to handle multiple image types.

  • Finalize the transform by converting to torch tensors. This is important because, when running models with the torch.nn.Module class, you will likely encounter errors asking you to use tensors. A minimal sketch of this last step follows below.
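As a rough pointer for that last step, here is a minimal sketch rather than the implementation used in this post: end the transform with torchvision's ToTensor, after which torch's DataLoader can batch the image/mask dictionaries returned by __getitem__ (this reuses the images_loader object from the Usage section). Note that ToTensor also scales pixel values to [0, 1], which you may or may not want for masks.

import torchvision.transforms as tf
from torch.utils.data import DataLoader

# Hypothetical tensor-returning pipeline: resize the PIL image, then convert it
# to a float tensor in [0, 1] with shape (channels, height, width)
to_tensor_pipeline = tf.Compose([tf.Resize((512, 512)), tf.ToTensor()])

# Assuming transform is updated to return to_tensor_pipeline(to_pil) instead of a
# PIL image, the default DataLoader collation handles dictionaries of tensors
loader = DataLoader(images_loader, batch_size=2, shuffle=True)
batch = next(iter(loader))
print(batch["image"].shape)  # e.g. torch.Size([2, 3, 512, 512])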

Conclusion

The torch Dataset class is powerful and can be used to do more than what I have shown here. If you are interested in learning more about the torch package, I highly recommend that you take a look at the official documentation and spend some time exploring libraries that implement deep learning methods using torch. The full code for this implementation is available at https://github.com/Nelson-Gon/nelson-gon.github.io/blob/master/code/elegant_torch.py

Thank you very much for reading. As always, do comment below or contact me if you would like to discuss any of the posts/projects on this site.

Keep Building

Modules used in this blog post

torch==1.7.1
torchvision==0.2.2.post3
opencv-python>=4.2 #actual 4.4.046
scikit-image>=0.16 #actual 0.18.1

References

https://www.sciencedirect.com/science/article/abs/pii/003132039390115D

https://en.wikipedia.org/wiki/Cross_entropy

https://github.com/Nelson-Gon/nelson-gon.github.io/blob/master/code/elegant_torch.py

https://www.nature.com/articles/s41524-020-00363-x#Sec1

https://pytorch.org/docs/stable/index.html