# Model training using trainers

This example is about model training, like the previous one, but this time it shows how to use a Trainer.

[3]:
# !pip3 install torchvision
# !pip3 install scikit-learn
[4]:
import copy

from tqdm import tqdm
import torch
import torchvision
from torchvision.transforms import functional as F
from torch import nn

import cascade.data as cdd
import cascade.models as cdm
from cascade.utils.torch import TorchModel
from cascade.utils.sklearn import SkMetric
[5]:
import cascade
# Display the installed cascade version; the outputs in this notebook were produced with it.
cascade.__version__
[5]:
'0.14.0-alpha'

## Defining the data pipeline

[6]:
MNIST_ROOT = 'data'    # directory torchvision reads MNIST from / downloads it into
INPUT_SIZE = 28 * 28   # MNIST images are 28x28; the MLP consumes them flattened
BATCH_SIZE = 10
SEED = 42

# The notebook is stochastic (added noise, DataLoader shuffling, weight init),
# so seed torch's global RNG to make Restart & Run All reproducible.
torch.manual_seed(SEED)
[7]:
class NoiseModifier(cdd.Modifier):
    """Dataset modifier that adds small uniform noise to each image.

    In this notebook the images come from ``F.to_tensor``, i.e. float tensors
    with values in ``[0, 1]``, so the noisy image is clipped back into that
    range. Clipping to ``[0, 255]`` would never trigger on such inputs and
    would let pixel values exceed 1.
    """

    # Upper bound of the uniform noise added to every pixel.
    noise_scale = 0.1

    def __getitem__(self, index):
        img, label = self._dataset[index]
        # Build a new tensor instead of using ``+=`` so the item returned by the
        # wrapped dataset is never modified in place.
        noisy = img + torch.rand_like(img) * self.noise_scale
        return torch.clip(noisy, 0.0, 1.0), label


# Raw torchvision datasets wrapped for cascade. Only the train split passes
# download=True; the test split is read from the same MNIST_ROOT.
train_ds = cdd.Wrapper(
    torchvision.datasets.MNIST(
        root=MNIST_ROOT, train=True, transform=F.to_tensor, download=True
    )
)
train_ds.describe("This is MNIST dataset of handwritten images, TRAIN PART")
test_ds = cdd.Wrapper(
    torchvision.datasets.MNIST(root=MNIST_ROOT, train=False, transform=F.to_tensor)
)

# Add noise to every image, then constrain the number of samples to speed up
# learning in this example.
train_ds = cdd.CyclicSampler(NoiseModifier(train_ds), 10000)
test_ds = cdd.CyclicSampler(NoiseModifier(test_ds), 5000)

# Shuffle only the training data; evaluation order stays fixed.
train_dl, test_dl = (
    torch.utils.data.DataLoader(dataset=ds, batch_size=BATCH_SIZE, shuffle=shuffle)
    for ds, shuffle in ((train_ds, True), (test_ds, False))
)
[8]:
# Metadata for every layer of the training pipeline, outermost (the sampler) first.
train_ds.get_meta()
[8]:
[{'name': 'cascade.data.cyclic_sampler.CyclicSampler', 'description': None, 'tags': [], 'comments': [], 'links': [], 'type': 'dataset', 'len': 10000}, {'name': '__main__.NoiseModifier', 'description': None, 'tags': [], 'comments': [], 'links': [], 'type': 'dataset', 'len': 60000}, {'name': 'cascade.data.dataset.Wrapper', 'description': 'This is MNIST dataset of handwritten images, TRAIN PART', 'tags': [], 'comments': [], 'links': [], 'type': 'dataset', 'len': 60000, 'obj_type': "<class 'torchvision.datasets.mnist.MNIST'>"}]

## Model definition

[9]:
class SimpleNN(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    Extra positional and keyword arguments are accepted and ignored, so the
    cascade wrapper can pass bookkeeping hyperparameters (``num_epochs``,
    ``lr``, ...) straight through to this constructor.
    """

    def __init__(self, input_size, hidden_size, num_classes, *args, **kwargs):
        super().__init__()

        # Stored because the train/eval loops flatten images to ``input_size``.
        self.input_size, self.hidden_size = input_size, hidden_size

        self.l1 = nn.Linear(input_size, hidden_size)
        self.l2 = nn.Linear(hidden_size, num_classes)
        self.relu = nn.ReLU()

    def forward(self, y):
        """Map a batch of flattened inputs to unnormalized class scores."""
        hidden = self.relu(self.l1(y))
        return self.l2(hidden)

Next, we define Cascade's wrapper. Most of the interaction with PyTorch modules is already implemented in `cascade.utils.torch.TorchModel`, so we only need to define how to train and evaluate the model.

The difference from the previous example is in the `fit` method: it now trains for only one epoch per call and needs no extra logging, because the Trainer takes care of that.

[10]:
class Classifier(TorchModel):
    """Cascade wrapper that trains the wrapped PyTorch module one epoch per ``fit`` call.

    ``TorchModel`` keeps the wrapped module (here ``SimpleNN``) in
    ``self._model``. This class only supplies the training and evaluation
    loops; the Trainer drives epochs, evaluation, saving and logging.
    """

    def fit(self, train_dl, lr, *args, **kwargs):
        """Run a single training epoch over ``train_dl``.

        Args:
            train_dl: iterable of ``(imgs, labels)`` batches. Images are flattened
                to ``self._model.input_size`` features before the forward pass.
            lr: learning rate for Adam.
            *args, **kwargs: accepted and ignored (e.g. ``bs`` from the trainer's
                ``train_kwargs`` is only recorded in the logged parameters).
        """
        criterion = nn.CrossEntropyLoss()
        # NOTE(review): a fresh Adam optimizer is built on every call, so its
        # moment estimates restart at each epoch when the trainer calls fit per epoch.
        optim = torch.optim.Adam(self._model.parameters(), lr=lr)

        # evaluate() switches the module to eval mode; switch back before training.
        self._model.train()
        for imgs, labels in train_dl:
            imgs = imgs.reshape(-1, self._model.input_size)
            loss = criterion(self._model(imgs), labels)

            optim.zero_grad()
            loss.backward()
            optim.step()

    def evaluate(self, test_dl, metrics, *args, **kwargs):
        """Compute ``metrics`` on ``test_dl`` and register the results on the model.

        Args:
            test_dl: iterable of ``(imgs, labels)`` batches.
            metrics: metric objects exposing ``compute(gt, pred)``. Each one is
                deep-copied before computing, so the caller's instances are never
                mutated and every evaluation registers its own result object.
            *args, **kwargs: accepted for interface compatibility and ignored.
                They are no longer forwarded to the module's forward pass, because
                ``SimpleNN.forward`` accepts only the input tensor.

        Returns:
            None. Results are stored via ``self.add_metric``.
        """
        preds = []
        gts = []
        self._model.eval()
        # Gradients are not needed for evaluation; this avoids building autograd graphs.
        with torch.no_grad():
            for imgs, labels in tqdm(test_dl):
                imgs = imgs.reshape(-1, self._model.input_size)
                preds.append(torch.argmax(self._model(imgs), -1))
                gts.append(labels)

        pred = torch.concat(preds).numpy()
        gt = torch.concat(gts).numpy()

        for metric in metrics:
            # test_kwargs is built once per trainer.train call, so presumably the
            # same metric instances arrive every epoch; copying keeps each epoch's
            # value instead of overwriting one shared object.
            metric = copy.deepcopy(metric)
            metric.compute(gt, pred)
            self.add_metric(metric)

## Model initialization

[12]:
NUM_EPOCHS = 5
LR = 1e-3

# Classifier builds SimpleNN from all of these keyword arguments and records
# them in its metadata. SimpleNN uses the architecture arguments and ignores the rest.
model = Classifier(
    SimpleNN,
    # Architecture: consumed by SimpleNN.__init__ (passed as keywords so they are recorded).
    input_size=INPUT_SIZE,
    hidden_size=100,
    num_classes=10,
    # Training hyperparameters: skipped by SimpleNN, kept only as metadata.
    num_epochs=NUM_EPOCHS,
    lr=LR,
    bs=BATCH_SIZE,
)

## Set up the trainer

Let's set up logging first so that the trainer's log messages are displayed.

[1]:
import logging
import sys

# Send log records at INFO and above (including the trainer's progress messages)
# to the notebook's stdout.
logging.basicConfig(level='INFO', handlers=[logging.StreamHandler(sys.stdout)])
[2]:
from cascade.trainers import BasicTrainer

# Trainer working in the 'trainer_repo' repository; its logs report the repo
# and the line each run is written to.
trainer = BasicTrainer('trainer_repo')

The main method is `train`. It does everything we need: training, evaluation, saving and logging.

[12]:
trainer.train(
    model,
    train_data=train_dl,
    test_data=test_dl,
    # Passed into model.fit(); fit uses only 'lr', 'bs' ends up in the logged parameters.
    train_kwargs=dict(lr=LR, bs=BATCH_SIZE),
    # Passed into model.evaluate().
    test_kwargs=dict(metrics=[SkMetric("accuracy_score")]),
    epochs=NUM_EPOCHS,
    start_from=None,  # pass a line name (e.g. '00000') to resume from its checkpoint
    save_strategy=2,
    eval_strategy=1,
)
INFO:cascade.trainers.trainer:Training started with parameters:
{'lr': 0.001, 'bs': 10}
INFO:cascade.trainers.trainer:repo is Repo in trainer_repo of 1 lines
INFO:cascade.trainers.trainer:line is 00000
INFO:cascade.trainers.trainer:training will last 5 epochs
100%|██████████| 500/500 [00:04<00:00, 101.35it/s]
INFO:cascade.trainers.trainer:Epoch: 0
INFO:cascade.trainers.trainer:SkMetric(name=accuracy_score, value=0.874, created_at=2024-08-24 11:08:21.482021+00:00)

100%|██████████| 500/500 [00:04<00:00, 122.69it/s]
INFO:cascade.trainers.trainer:Epoch: 1
INFO:cascade.trainers.trainer:SkMetric(name=accuracy_score, value=0.8956, created_at=2024-08-24 11:08:21.482021+00:00)

100%|██████████| 500/500 [00:04<00:00, 118.09it/s]
INFO:cascade.trainers.trainer:Epoch: 2
INFO:cascade.trainers.trainer:SkMetric(name=accuracy_score, value=0.9092, created_at=2024-08-24 11:08:21.482021+00:00)
100%|██████████| 500/500 [00:04<00:00, 115.79it/s]
INFO:cascade.trainers.trainer:Epoch: 3
INFO:cascade.trainers.trainer:SkMetric(name=accuracy_score, value=0.9202, created_at=2024-08-24 11:08:21.482021+00:00)

100%|██████████| 500/500 [00:04<00:00, 121.44it/s]
INFO:cascade.trainers.trainer:Epoch: 4
INFO:cascade.trainers.trainer:SkMetric(name=accuracy_score, value=0.919, created_at=2024-08-24 11:08:21.482021+00:00)
INFO:cascade.trainers.trainer:Training finished in 1 minute
INFO:cascade.trainers.trainer:repo was Repo in trainer_repo of 1 lines
INFO:cascade.trainers.trainer:line was 00000
INFO:cascade.trainers.trainer:training ended on 4 epoch
INFO:cascade.trainers.trainer:Parameters:
{'lr': 0.001, 'bs': 10}
INFO:cascade.trainers.trainer:Metrics:
INFO:cascade.trainers.trainer:SkMetric(name=accuracy_score, value=0.919, created_at=2024-08-24 11:08:21.482021+00:00)

## Results

We can obtain the results of training from the trainer's metadata.

[13]:
# Trainer metadata: run settings (epochs, strategies) and training start/end timestamps.
trainer.get_meta()
[13]:
[{'name': 'cascade.trainers.trainer.BasicTrainer', 'epochs': 5, 'eval_strategy': 1, 'save_strategy': 2, 'description': None, 'tags': [], 'comments': [], 'links': [], 'type': 'trainer', 'training_started_at': DateTime(2024, 8, 24, 14, 8, 21, 498283, tzinfo=Timezone('Europe/Moscow')), 'training_ended_at': DateTime(2024, 8, 24, 14, 9, 45, 93067, tzinfo=Timezone('Europe/Moscow'))}]

## Start from a checkpoint

Let's continue training from where we left off, using the same line as before.

[13]:
# Resume line '00000' from a saved checkpoint and train for another NUM_EPOCHS
# epochs with the same fit/evaluate settings as the first run.
trainer.train(
    model,
    train_data=train_dl,
    test_data=test_dl,
    train_kwargs={'lr': LR, 'bs': BATCH_SIZE},
    test_kwargs={'metrics': [SkMetric("accuracy_score")]},
    epochs=NUM_EPOCHS,  # was a hard-coded 5; keep it tied to the config constant
    start_from='00000',
    save_strategy=4,
    eval_strategy=1,
)
INFO:cascade.trainers.trainer:Training started with parameters:
{'lr': 0.001, 'bs': 10}
INFO:cascade.trainers.trainer:repo is Repo in trainer_repo of 1 lines
INFO:cascade.trainers.trainer:line is 00000
INFO:cascade.trainers.trainer:started from model 9
INFO:cascade.trainers.trainer:training will last 5 epochs
100%|██████████| 500/500 [00:04<00:00, 108.53it/s]
INFO:cascade.trainers.trainer:Epoch: 0
INFO:cascade.trainers.trainer:SkMetric(name=accuracy_score, value=0.9408, created_at=2024-08-24 11:13:04.420271+00:00)
100%|██████████| 500/500 [00:05<00:00, 89.71it/s]
INFO:cascade.trainers.trainer:Epoch: 1
INFO:cascade.trainers.trainer:SkMetric(name=accuracy_score, value=0.9378, created_at=2024-08-24 11:13:04.420271+00:00)

100%|██████████| 500/500 [00:04<00:00, 121.56it/s]
INFO:cascade.trainers.trainer:Epoch: 2
INFO:cascade.trainers.trainer:SkMetric(name=accuracy_score, value=0.934, created_at=2024-08-24 11:13:04.420271+00:00)

100%|██████████| 500/500 [00:04<00:00, 118.74it/s]
INFO:cascade.trainers.trainer:Epoch: 3
INFO:cascade.trainers.trainer:SkMetric(name=accuracy_score, value=0.9398, created_at=2024-08-24 11:13:04.420271+00:00)
100%|██████████| 500/500 [00:04<00:00, 124.65it/s]
INFO:cascade.trainers.trainer:Epoch: 4
INFO:cascade.trainers.trainer:SkMetric(name=accuracy_score, value=0.939, created_at=2024-08-24 11:13:04.420271+00:00)
INFO:cascade.trainers.trainer:Training finished in 1 minute
INFO:cascade.trainers.trainer:repo was Repo in trainer_repo of 1 lines
INFO:cascade.trainers.trainer:line was 00000
INFO:cascade.trainers.trainer:started from model 9
INFO:cascade.trainers.trainer:training ended on 4 epoch
INFO:cascade.trainers.trainer:Parameters:
{'lr': 0.001, 'bs': 10}
INFO:cascade.trainers.trainer:Metrics:
INFO:cascade.trainers.trainer:SkMetric(name=accuracy_score, value=0.939, created_at=2024-08-24 11:13:04.420271+00:00)
[14]:
# Metrics collected by the trainer, as a list of per-evaluation metric lists.
# NOTE(review): every entry shows the same value and created_at even though the
# per-epoch logs report different accuracies — looks like a single SkMetric
# instance is shared across epochs and overwritten; confirm.
trainer.metrics
[14]:
[[SkMetric(name=accuracy_score, value=0.939, created_at=2024-08-24 11:13:04.420271+00:00)], [SkMetric(name=accuracy_score, value=0.939, created_at=2024-08-24 11:13:04.420271+00:00)], [SkMetric(name=accuracy_score, value=0.939, created_at=2024-08-24 11:13:04.420271+00:00)], [SkMetric(name=accuracy_score, value=0.939, created_at=2024-08-24 11:13:04.420271+00:00)], [SkMetric(name=accuracy_score, value=0.939, created_at=2024-08-24 11:13:04.420271+00:00)]]