Pipeline building

This use case shows how to build a data pipeline. We will prepare the MNIST digits dataset for training a classifier.

[1]:
#!pip3 install torchvision
[2]:
import cascade.data as cdd
import cascade.meta as cme

import torch
import torchvision
import torchvision.transforms.functional as F
[3]:
import cascade
cascade.__version__
[3]:
'0.14.0-alpha'

Let’s load the dataset from torchvision.

[4]:
MNIST_ROOT = 'data'

train_ds = torchvision.datasets.MNIST(root=MNIST_ROOT,
                                     train=True,
                                     transform=F.to_tensor,
                                     download=True)
test_ds = torchvision.datasets.MNIST(root=MNIST_ROOT,
                                    train=False,
                                    transform=F.to_tensor)
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz
100%|██████████| 9912422/9912422 [00:52<00:00, 190022.28it/s]
Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz
100%|██████████| 28881/28881 [00:00<00:00, 72451.92it/s]
Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz
100%|██████████| 1648877/1648877 [00:03<00:00, 528276.00it/s]
Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz
100%|██████████| 4542/4542 [00:00<00:00, 2330054.89it/s]
Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw

Creating Cascade Dataset

[5]:
train_ds = cdd.Wrapper(train_ds)
train_ds.describe("This is MNIST dataset of handwritten images")
test_ds = cdd.Wrapper(test_ds)

Applying noise

Let’s say we want to add noise to each image.
We will hardcode the noise magnitude to keep the example simple.
[6]:
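class NoiseModifier(cdd.Modifier):
    def __getitem__(self, index):
        img, label = self._dataset[index]
        # Add uniform noise in [0, 0.1) without mutating the source tensor
        img = img + torch.rand_like(img) * 0.1
        # to_tensor scales pixels to [0, 1], so clip to that range
        img = torch.clip(img, 0, 1)
        return img, label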
[7]:
train_ds = NoiseModifier(train_ds)

Viewing metadata

[8]:
train_ds.get_meta()
[8]:
[{'name': '__main__.NoiseModifier', 'description': None, 'tags': [], 'comments': [], 'links': [], 'type': 'dataset', 'len': 60000}, {'name': 'cascade.data.dataset.Wrapper', 'description': 'This is MNIST dataset of handwritten images', 'tags': [], 'comments': [], 'links': [], 'type': 'dataset', 'len': 60000, 'obj_type': "<class 'torchvision.datasets.mnist.MNIST'>"}]
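
The output lists one metadata entry per pipeline stage, with the outermost stage first. The sketch below illustrates this wrapping pattern in plain Python. It is not Cascade's implementation: `ListDataset` and `ScaleModifier` are hypothetical names, and the metadata is reduced to the `name` and `len` fields.

```python
class ListDataset:
    """Hypothetical stand-in for a wrapped data source."""

    def __init__(self, items):
        self._items = list(items)

    def __getitem__(self, index):
        return self._items[index]

    def __len__(self):
        return len(self._items)

    def get_meta(self):
        # A source stage reports only itself
        return [{"name": type(self).__name__, "len": len(self)}]


class ScaleModifier:
    """Hypothetical modifier: wraps a dataset and transforms items on access."""

    def __init__(self, dataset, factor):
        self._dataset = dataset
        self._factor = factor

    def __getitem__(self, index):
        # Delegate to the wrapped dataset, then apply the transformation
        return self._dataset[index] * self._factor

    def __len__(self):
        return len(self._dataset)

    def get_meta(self):
        # Put this stage first, then append the metadata of the wrapped stages
        own = {"name": type(self).__name__, "len": len(self)}
        return [own] + self._dataset.get_meta()


pipeline = ScaleModifier(ListDataset([1, 2, 3]), factor=10)
first_item = pipeline[0]
meta = pipeline.get_meta()
print(first_item)
print(meta)
```

Because each stage only delegates to the stage it wraps, modifiers can be stacked in any order and the metadata list grows by one entry per stage.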

Ready to train model

Now we can set the batch size and pass our pipeline to the DataLoaders.

[9]:
BATCH_SIZE = 10
[10]:
trainldr = torch.utils.data.DataLoader(dataset=train_ds,
                                       batch_size=BATCH_SIZE,
                                       shuffle=True)
testldr = torch.utils.data.DataLoader(dataset=test_ds,
                                      batch_size=BATCH_SIZE,
                                      shuffle=False)
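
Before training, it can help to pull a single batch and confirm its shapes. The sketch below is self-contained: it substitutes a synthetic `TensorDataset` of random 1x28x28 tensors for the MNIST pipeline, an assumption made so that the snippet runs without any downloads. A batch from `trainldr` could be inspected in the same way.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

BATCH_SIZE = 10

# Synthetic stand-in for the MNIST pipeline: 100 random "images" and labels
images = torch.rand(100, 1, 28, 28)
labels = torch.randint(0, 10, (100,))
dataset = TensorDataset(images, labels)

loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# Take one batch and inspect its shapes
batch_images, batch_labels = next(iter(loader))
print(batch_images.shape)
print(batch_labels.shape)
```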