Model training

This use case covers model training.
By working through it, you will learn how to use Cascade for metadata tracking, hyperparameter tuning and model selection.
The preceding part, pipeline building, is reproduced here without comments.
For a more detailed description of it, see the Pipeline building example.
#!pip3 install torchvision
import cascade.data as cdd
import cascade.models as cdm
import cascade.meta as cde
from cascade.utils.torch import TorchModel
from cascade.utils.sklearn import SkMetric

from tqdm import tqdm
import torch
import torchvision
from torchvision.transforms import functional as F
from torch import nn
import cascade

Data Pipeline

MNIST_ROOT = 'data'
class NoiseModifier(cdd.Modifier):
    def __getitem__(self, index):
        img, label = self._dataset[index]
        img += torch.rand_like(img) * 0.1
        # Pixels are floats in [0, 1] after tensor conversion, so clip back into that range
        img = torch.clip(img, 0, 1)
        return img, label
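The modifier pattern used by NoiseModifier can be illustrated without PyTorch. The sketch below is a plain-Python analogue and is not Cascade's cdd.Modifier. The hypothetical NoisyDataset wraps any indexable dataset, adds uniform noise in [0, 0.1) to a sample only when that sample is requested, and clips the result to [0, 1].

```python
import random
from typing import List, Sequence, Tuple

Sample = Tuple[List[float], int]


class NoisyDataset:
    """Hypothetical plain-Python analogue of NoiseModifier.

    Wraps an indexable dataset and perturbs each sample lazily,
    only when it is requested, leaving the wrapped data untouched.
    """

    def __init__(self, dataset: Sequence[Sample], scale: float = 0.1, seed: int = 0):
        self._dataset = dataset
        self._scale = scale
        self._rng = random.Random(seed)

    def __len__(self) -> int:
        return len(self._dataset)

    def __getitem__(self, index: int) -> Sample:
        pixels, label = self._dataset[index]
        # Add uniform noise in [0, scale) and clip back into [0, 1]
        noisy = [min(max(p + self._rng.random() * self._scale, 0.0), 1.0) for p in pixels]
        return noisy, label


data = [([0.0, 0.5, 0.95], 3)]
noisy_ds = NoisyDataset(data)
pixels, label = noisy_ds[0]
print(label, all(0.0 <= p <= 1.0 for p in pixels))
```

Wrapping the dataset in this way keeps the original data intact and makes the augmentation a separate, describable step in the pipeline.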

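train_ds = torchvision.datasets.MNIST(root=MNIST_ROOT,
                                      train=True,
                                      download=True,
                                      transform=F.to_tensor)
test_ds = torchvision.datasets.MNIST(root=MNIST_ROOT,
                                     train=False,
                                     download=True,
                                     transform=F.to_tensor)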

train_ds = cdd.Wrapper(train_ds)
train_ds.describe("This is MNIST dataset of handwritten images, TRAIN PART")
test_ds = cdd.Wrapper(test_ds)

train_ds = NoiseModifier(train_ds)
test_ds = NoiseModifier(test_ds)

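# Batch size 10 (recorded as 'bs' in the metadata below)
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=10)
test_dl = torch.utils.data.DataLoader(test_ds, batch_size=10)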
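With a batch size of 10, the number of batches per epoch follows directly from the dataset sizes. This explains the 6000 steps in the training logs and the 1000 iterations in the evaluation progress bar. The quick check below assumes the standard MNIST split of 60000 training and 10000 test images. The 60000 figure also appears in the dataset metadata. It further assumes the DataLoader default drop_last=False. The num_batches helper is our own.

```python
import math


def num_batches(dataset_len: int, batch_size: int, drop_last: bool = False) -> int:
    """Number of batches a DataLoader-style iterator yields per epoch."""
    if drop_last:
        return dataset_len // batch_size
    # The last, possibly smaller, batch is kept
    return math.ceil(dataset_len / batch_size)


print(num_batches(60000, 10))  # 6000
print(num_batches(10000, 10))  # 1000
```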

Module definition

class SimpleNN(nn.Module):
    # Extra positional and keyword arguments (e.g. training settings) are accepted and ignored
    def __init__(self, input_size, hidden_size, num_classes, *args, **kwargs):
        # nn.Module must be initialized before submodules are assigned
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.l1 = nn.Linear(input_size, hidden_size)
        self.l2 = nn.Linear(hidden_size, num_classes)
        self.relu = nn.ReLU()

    def forward(self, y):
        out = self.l1(y)
        out = self.relu(out)
        out = self.l2(out)

        return out
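Each nn.Linear(in_features, out_features) layer holds a weight matrix of in_features * out_features entries plus one bias per output, and ReLU has no parameters. The sketch below works out the trainable parameter count for the sizes used later (784 inputs, 100 hidden units, 10 classes, as recorded in the metadata). The helper functions are our own.

```python
def linear_params(in_features: int, out_features: int, bias: bool = True) -> int:
    """Parameter count of a fully connected layer: weights plus optional biases."""
    return in_features * out_features + (out_features if bias else 0)


def simple_nn_params(input_size: int, hidden_size: int, num_classes: int) -> int:
    # l1: input -> hidden, l2: hidden -> classes; ReLU adds nothing
    return linear_params(input_size, hidden_size) + linear_params(hidden_size, num_classes)


print(simple_nn_params(784, 100, 10))  # 79510
```

In PyTorch, the same number can be obtained from a module instance with sum(p.numel() for p in module.parameters()).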

Cascade wrapper

class Classifier(TorchModel):
    # In fit we copy-paste a regular PyTorch train loop,
    # but use self._model, where our SimpleNN is placed
    def fit(self, train_dl, num_epochs, lr, *args, **kwargs):
        criterion = nn.CrossEntropyLoss()
        optim = torch.optim.Adam(self._model.parameters(), lr=lr)

        ds_size = len(train_dl)
        for epoch in range(num_epochs):
            for x, (imgs, labels) in enumerate(train_dl):
                # Flatten 28x28 images into vectors of input_size
                imgs = imgs.reshape(-1, self._model.input_size)

                out = self._model(imgs)
                loss = criterion(out, labels)

                # Backward pass and parameter update
                optim.zero_grad()
                loss.backward()
                optim.step()

                if x % 500 == 0:
                    print(f'Epochs [{epoch}/{num_epochs}], Step[{x}/{ds_size}], Loss: {loss.item():.4f}')

    # Evaluate function takes the metrics from arguments
    # and populates self.metrics without returning anything
    def evaluate(self, test_dl, metrics, *args, **kwargs):
        pred = []
        gt = []
        # Gradients are not needed for evaluation
        with torch.no_grad():
            for imgs, labels in tqdm(test_dl):
                imgs = imgs.reshape(-1, self._model.input_size)
                out = torch.argmax(self._model(imgs), -1)
                # Collect predictions and labels batch by batch
                pred.append(out)
                gt.append(labels)

        pred = torch.concat(pred).detach().numpy()
        gt = torch.concat(gt).detach().numpy()

        for metric in metrics:
            metric.compute(gt, pred)

Model training

Now we are ready to train our model. We define the hyperparameters and pass them to our wrapper. The wrapper accepts the PyTorch module’s class and all the parameters needed to initialize it.
Additionally, we pass keyword arguments related to training so that they are added to the model’s metadata.
LR = 1e-3

# Classifier will initialize SimpleNN with all the parameters passed,
# but some of them are not for SimpleNN and are only recorded in metadata
model = Classifier(SimpleNN,
    # These arguments are needed by SimpleNN,
    # but passed as keywords to be recorded in meta
    input_size=784,
    hidden_size=100,
    num_classes=10,
    # These arguments will be skipped by SimpleNN,
    # but will be added to meta
    num_epochs=2,
    lr=LR,
    bs=10)

model.fit(train_dl, num_epochs=2, lr=LR)
Epochs [0/2], Step[0/6000], Loss: 2.2891
Epochs [0/2], Step[500/6000], Loss: 0.4226
Epochs [0/2], Step[1000/6000], Loss: 0.2755
Epochs [0/2], Step[1500/6000], Loss: 0.1671
Epochs [0/2], Step[2000/6000], Loss: 0.1510
Epochs [0/2], Step[2500/6000], Loss: 0.2112
Epochs [0/2], Step[3000/6000], Loss: 0.1839
Epochs [0/2], Step[3500/6000], Loss: 0.0139
Epochs [0/2], Step[4000/6000], Loss: 0.0661
Epochs [0/2], Step[4500/6000], Loss: 0.0417
Epochs [0/2], Step[5000/6000], Loss: 0.2169
Epochs [0/2], Step[5500/6000], Loss: 0.2178
Epochs [1/2], Step[0/6000], Loss: 0.0850
Epochs [1/2], Step[500/6000], Loss: 0.0512
Epochs [1/2], Step[1000/6000], Loss: 0.0338
Epochs [1/2], Step[1500/6000], Loss: 0.0106
Epochs [1/2], Step[2000/6000], Loss: 0.4859
Epochs [1/2], Step[2500/6000], Loss: 0.0068
Epochs [1/2], Step[3000/6000], Loss: 0.1055
Epochs [1/2], Step[3500/6000], Loss: 0.1380
Epochs [1/2], Step[4000/6000], Loss: 0.0124
Epochs [1/2], Step[4500/6000], Loss: 1.1131
Epochs [1/2], Step[5000/6000], Loss: 0.1694
Epochs [1/2], Step[5500/6000], Loss: 0.0235
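The mechanism described above forwards a single set of keyword arguments to the module's constructor and also records it as metadata. It can be sketched in plain Python. HypotheticalWrapper and TinyModule are illustrative stand-ins, not Cascade's TorchModel. The sketch only shows why the module has to accept **kwargs in order to ignore the training-related keys.

```python
from typing import Any, Dict, Type


class TinyModule:
    """Stand-in for SimpleNN: uses two arguments, silently ignores the rest."""

    def __init__(self, input_size: int, num_classes: int, **kwargs: Any) -> None:
        self.input_size = input_size
        self.num_classes = num_classes


class HypotheticalWrapper:
    """Illustrative wrapper: builds the module and records every keyword."""

    def __init__(self, module_cls: Type, **kwargs: Any) -> None:
        # All keywords go to the module constructor...
        self._model = module_cls(**kwargs)
        # ...and are also stored, so training settings end up in the params too
        self.params: Dict[str, Any] = dict(kwargs)


wrapper = HypotheticalWrapper(TinyModule, input_size=784, num_classes=10, lr=1e-3, bs=10)
print(wrapper.params)
```

If TinyModule did not accept **kwargs, constructing it with lr and bs would raise a TypeError. That is why SimpleNN declares *args and **kwargs.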

Evaluate the model

Now we can evaluate the model’s performance on the test dataset. We pass the data and a single metric, which is a wrapper around an sklearn metric.

model.evaluate(test_dl, [SkMetric("accuracy_score")])
100%|██████████| 1000/1000 [00:07<00:00, 125.10it/s]
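Accuracy is the fraction of test samples whose predicted class equals the label. As in evaluate, the predicted class is the argmax over the logits. Below is a minimal stdlib sketch of that computation with made-up toy logits. For plain label lists it gives the same result as sklearn.metrics.accuracy_score, but it is not SkMetric's implementation.

```python
from typing import List, Sequence


def argmax(row: Sequence[float]) -> int:
    """Index of the largest value, like torch.argmax(..., -1) for one row."""
    return max(range(len(row)), key=lambda i: row[i])


def accuracy(gt: Sequence[int], pred: Sequence[int]) -> float:
    if len(gt) != len(pred):
        raise ValueError("gt and pred must have the same length")
    return sum(g == p for g, p in zip(gt, pred)) / len(gt)


# Toy batches of logits (3 classes) and labels, collected batch by batch
batches = [
    ([[0.1, 2.0, -1.0], [3.0, 0.0, 0.5]], [1, 0]),
    ([[0.2, 0.1, 0.9], [1.5, 1.4, 0.0]], [2, 1]),
]
pred: List[int] = []
gt: List[int] = []
for logits, labels in batches:
    pred.extend(argmax(row) for row in logits)
    gt.extend(labels)

print(pred, gt, accuracy(gt, pred))
```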

Check the metadata

Let’s examine the metadata obtained from the model after training.

model.get_meta()
[{'name': '__main__.Classifier', 'description': None, 'tags': [], 'comments': [], 'links': [], 'type': 'model', 'created_at': DateTime(2024, 8, 24, 10, 52, 54, 681116, tzinfo=Timezone('UTC')), 'metrics': [SkMetric(name=accuracy_score, value=0.9652, created_at=2024-08-24 10:55:07.930148+00:00)], 'params': {'input_size': 784, 'hidden_size': 100, 'num_classes': 10, 'num_epochs': 2, 'lr': 0.001, 'bs': 10}, 'module': 'SimpleNN(\n  (l1): Linear(in_features=784, out_features=100, bias=True)\n  (l2): Linear(in_features=100, out_features=10, bias=True)\n  (relu): ReLU()\n)'}]
Several things are worth noting. The model tracks its creation time, and its metrics are in place, as expected after evaluation.
Now look at the params dict. It contains all the parameters that we passed to the wrapper as keywords, and the wrapper recorded them in the metadata for us automatically.
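The metadata is a list of plain dictionaries, so individual fields can be read back with ordinary indexing. The sketch below uses a trimmed, hand-copied version of the printed structure. It keeps only the name, type and params entries and omits the metrics entry, which holds SkMetric objects. It then separates the module's constructor arguments from the training settings. That split and the MODULE_KEYS set are our own illustration, not something Cascade does.

```python
from typing import Any, Dict, List

# Trimmed copy of the metadata printed above
meta: List[Dict[str, Any]] = [
    {
        "name": "__main__.Classifier",
        "type": "model",
        "params": {"input_size": 784, "hidden_size": 100, "num_classes": 10,
                   "num_epochs": 2, "lr": 0.001, "bs": 10},
    }
]

# Arguments consumed by SimpleNN's constructor
MODULE_KEYS = {"input_size", "hidden_size", "num_classes"}

params = meta[0]["params"]
module_args = {k: v for k, v in params.items() if k in MODULE_KEYS}
training_args = {k: v for k, v in params.items() if k not in MODULE_KEYS}
print(module_args)
print(training_args)
```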

Saving the model

It’s time to save the trained model. We could save it directly through the model itself, but let’s look at another Cascade tool for model management.
It is called Repo.
from cascade.repos import Repo

repo = Repo('repo')

A repo is a repository of models. It manages a series of experiments over sets of models of different architectures, called model lines.

repo.add_line('linear_nn', type="model", obj_cls=Classifier)
<class 'cascade.lines.model_line.ModelLine'>(0) items of <class 'cascade.models.model.Model'>

A model line manages models with a similar architecture but different parameters or a different number of training epochs. It handles saving a model along with its metadata, as well as loading models back.
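Here is a toy, stdlib-only sketch that makes the idea of a model line concrete. The line stores each saved item's metadata under an auto-incremented index, similar in spirit to the num column in the MetricViewer table. ToyLine, its zero-padded folder names and its meta.json layout are all hypothetical and do not reflect Cascade's actual on-disk format.

```python
import json
import tempfile
from pathlib import Path
from typing import Any, Dict


class ToyLine:
    """Hypothetical model line: one numbered folder per saved item."""

    def __init__(self, root: Path) -> None:
        self.root = root
        self.root.mkdir(parents=True, exist_ok=True)

    def __len__(self) -> int:
        return sum(1 for p in self.root.iterdir() if p.is_dir())

    def save(self, meta: Dict[str, Any]) -> Path:
        # Next index = number of items already saved (0, 1, 2, ...)
        folder = self.root / f"{len(self):05d}"
        folder.mkdir()
        (folder / "meta.json").write_text(json.dumps(meta))
        return folder

    def load_meta(self, num: int) -> Dict[str, Any]:
        return json.loads((self.root / f"{num:05d}" / "meta.json").read_text())


with tempfile.TemporaryDirectory() as tmp:
    line = ToyLine(Path(tmp) / "linear_nn")
    for hidden in (10, 50, 100):
        line.save({"params": {"hidden_size": hidden}})
    count = len(line)
    second = line.load_meta(1)

print(count, second)
```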

Aside from the model’s metadata, we would also like to know what data the model was trained on.

model.link(train_ds, name='train_data')
model.get_meta()
[{'name': '__main__.Classifier', 'description': None, 'tags': [], 'comments': [], 'links': [{'id': '1', 'name': 'train_data', 'uri': None, 'meta': [{'name': '__main__.NoiseModifier', 'description': None, 'tags': [], 'comments': [], 'links': [], 'type': 'dataset', 'len': 60000}, {'name': '', 'description': 'This is MNIST dataset of handwritten images, TRAIN PART', 'tags': [], 'comments': [], 'links': [], 'type': 'dataset', 'len': 60000, 'obj_type': "<class 'torchvision.datasets.mnist.MNIST'>"}], 'created_at': DateTime(2024, 8, 24, 10, 55, 16, 169857, tzinfo=Timezone('UTC'))}], 'type': 'model', 'created_at': DateTime(2024, 8, 24, 10, 52, 54, 681116, tzinfo=Timezone('UTC')), 'metrics': [SkMetric(name=accuracy_score, value=0.9652, created_at=2024-08-24 10:55:07.930148+00:00)], 'params': {'input_size': 784, 'hidden_size': 100, 'num_classes': 10, 'num_epochs': 2, 'lr': 0.001, 'bs': 10}, 'module': 'SimpleNN(\n  (l1): Linear(in_features=784, out_features=100, bias=True)\n  (l2): Linear(in_features=100, out_features=10, bias=True)\n  (relu): ReLU()\n)'}]

We are ready to save the model.

This will save the model to the path:
And metadata:

Peeking inside repo

To see a model’s metrics and parameters, we don’t need to browse the folders mentioned above manually or print large metadata to the console. Cascade has tools for displaying metrics conveniently, and one of them is MetricViewer.

mv = cde.MetricViewer(repo)
# We can show the table like this
# mv.plot_table()
# Or we can open web-application like this
# mv.serve()
# but it will not be rendered in the documentation, so...
        line  num                       created_at            saved  \
0  linear_nn    0 2024-08-24 10:52:54.681116+00:00  2 minutes after

   input_size  hidden_size  num_classes  num_epochs     lr  bs tags  \
0         784          100           10           2  0.001  10   []

   comment_count  link_count            name   value
0              0           1  accuracy_score  0.9652

MetricViewer accepts the repo object and can show tables of metrics and metadata. When the table becomes too big and a more powerful tool is needed, mv also has a serve method that opens a fully interactive table of metrics, which can be sorted and filtered.

More experiments

What if we want to run a number of experiments automatically and then choose the best model?
The workflow is very similar. In the example below, we try to find the best hidden size for the model.
We define a set of parameters for the experiments and run them in a loop, saving the results each time.
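params = [
    {'hidden_size': 10,  'num_epochs': 2, 'lr': 0.001, 'bs': 10},
    {'hidden_size': 50,  'num_epochs': 2, 'lr': 0.001, 'bs': 10},
    {'hidden_size': 100, 'num_epochs': 2, 'lr': 0.001, 'bs': 10}
]

for p in params:
    # SimpleNN takes hidden_size from p; the other keys only go to metadata
    model = Classifier(SimpleNN,
        input_size=784,
        num_classes=10,
        **p)
    model.fit(train_dl, num_epochs=p['num_epochs'], lr=p['lr'])
    model.evaluate(test_dl, [SkMetric("accuracy_score")])
    # Save the model to the 'linear_nn' line here, as in "Saving the model"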
Epochs [0/2], Step[0/6000], Loss: 2.4026
Epochs [0/2], Step[500/6000], Loss: 1.1337
Epochs [0/2], Step[1000/6000], Loss: 0.6567
Epochs [0/2], Step[1500/6000], Loss: 0.0914
Epochs [0/2], Step[2000/6000], Loss: 0.2892
Epochs [0/2], Step[2500/6000], Loss: 0.0613
Epochs [0/2], Step[3000/6000], Loss: 0.2035
Epochs [0/2], Step[3500/6000], Loss: 0.4300
Epochs [0/2], Step[4000/6000], Loss: 0.8379
Epochs [0/2], Step[4500/6000], Loss: 0.1027
Epochs [0/2], Step[5000/6000], Loss: 0.5138
Epochs [0/2], Step[5500/6000], Loss: 0.0586
Epochs [1/2], Step[0/6000], Loss: 0.1320
Epochs [1/2], Step[500/6000], Loss: 0.2849
Epochs [1/2], Step[1000/6000], Loss: 0.0615
Epochs [1/2], Step[1500/6000], Loss: 0.2261
Epochs [1/2], Step[2000/6000], Loss: 0.3681
Epochs [1/2], Step[2500/6000], Loss: 0.7509
Epochs [1/2], Step[3000/6000], Loss: 0.7053
Epochs [1/2], Step[3500/6000], Loss: 0.1424
Epochs [1/2], Step[4000/6000], Loss: 0.6824
Epochs [1/2], Step[4500/6000], Loss: 0.2610
Epochs [1/2], Step[5000/6000], Loss: 0.2609
Epochs [1/2], Step[5500/6000], Loss: 0.4192
100%|██████████| 1000/1000 [00:06<00:00, 148.56it/s]
Epochs [0/2], Step[0/6000], Loss: 2.2791
Epochs [0/2], Step[500/6000], Loss: 0.3728
Epochs [0/2], Step[1000/6000], Loss: 0.4797
Epochs [0/2], Step[1500/6000], Loss: 0.3007
Epochs [0/2], Step[2000/6000], Loss: 0.5284
Epochs [0/2], Step[2500/6000], Loss: 0.1441
Epochs [0/2], Step[3000/6000], Loss: 0.0626
Epochs [0/2], Step[3500/6000], Loss: 0.1782
Epochs [0/2], Step[4000/6000], Loss: 0.2281
Epochs [0/2], Step[4500/6000], Loss: 0.1399
Epochs [0/2], Step[5000/6000], Loss: 0.0370
Epochs [0/2], Step[5500/6000], Loss: 0.1297
Epochs [1/2], Step[0/6000], Loss: 0.2715
Epochs [1/2], Step[500/6000], Loss: 0.4796
Epochs [1/2], Step[1000/6000], Loss: 0.0554
Epochs [1/2], Step[1500/6000], Loss: 0.0662
Epochs [1/2], Step[2000/6000], Loss: 0.0662
Epochs [1/2], Step[2500/6000], Loss: 0.1186
Epochs [1/2], Step[3000/6000], Loss: 0.0965
Epochs [1/2], Step[3500/6000], Loss: 1.1392
Epochs [1/2], Step[4000/6000], Loss: 0.6301
Epochs [1/2], Step[4500/6000], Loss: 0.0048
Epochs [1/2], Step[5000/6000], Loss: 0.0046
Epochs [1/2], Step[5500/6000], Loss: 0.0098
100%|██████████| 1000/1000 [00:07<00:00, 131.16it/s]
Epochs [0/2], Step[0/6000], Loss: 2.3463
Epochs [0/2], Step[500/6000], Loss: 0.2545
Epochs [0/2], Step[1000/6000], Loss: 0.1970
Epochs [0/2], Step[1500/6000], Loss: 0.0619
Epochs [0/2], Step[2000/6000], Loss: 0.0328
Epochs [0/2], Step[2500/6000], Loss: 0.0237
Epochs [0/2], Step[3000/6000], Loss: 0.7900
Epochs [0/2], Step[3500/6000], Loss: 0.0399
Epochs [0/2], Step[4000/6000], Loss: 0.0198
Epochs [0/2], Step[4500/6000], Loss: 0.0266
Epochs [0/2], Step[5000/6000], Loss: 0.1952
Epochs [0/2], Step[5500/6000], Loss: 0.2487
Epochs [1/2], Step[0/6000], Loss: 0.5751
Epochs [1/2], Step[500/6000], Loss: 0.0471
Epochs [1/2], Step[1000/6000], Loss: 0.0931
Epochs [1/2], Step[1500/6000], Loss: 0.0056
Epochs [1/2], Step[2000/6000], Loss: 0.0699
Epochs [1/2], Step[2500/6000], Loss: 0.1815
Epochs [1/2], Step[3000/6000], Loss: 1.2539
Epochs [1/2], Step[3500/6000], Loss: 0.4243
Epochs [1/2], Step[4000/6000], Loss: 0.3889
Epochs [1/2], Step[4500/6000], Loss: 0.0390
Epochs [1/2], Step[5000/6000], Loss: 0.0132
Epochs [1/2], Step[5500/6000], Loss: 0.0530
100%|██████████| 1000/1000 [00:07<00:00, 130.79it/s]
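For larger sweeps, writing the params list by hand gets tedious. The sketch below generates it instead. The hidden_size values and the fixed settings are copied from the list above, while the make_grid helper is our own and not part of Cascade. When more than one key varies, itertools.product yields every combination.

```python
import itertools
from typing import Any, Dict, List, Sequence


def make_grid(fixed: Dict[str, Any], **varying: Sequence[Any]) -> List[Dict[str, Any]]:
    """Cartesian product of the varying values, each merged with the fixed settings."""
    keys = list(varying)
    return [
        {**dict(zip(keys, values)), **fixed}
        for values in itertools.product(*(varying[k] for k in keys))
    ]


grid = make_grid({'num_epochs': 2, 'lr': 0.001, 'bs': 10}, hidden_size=[10, 50, 100])
for p in grid:
    print(p)
```

The grid grows multiplicatively with each varying key. For example, three hidden sizes times two learning rates already means six training runs.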


We can now see the results of our experiments: all of them are present in the table, and we can choose the best option.

mv = cde.MetricViewer(repo)
# mv.plot_table()
        line  num                       created_at            saved  \
0  linear_nn    0 2024-08-24 10:52:54.681116+00:00  2 minutes after
1  linear_nn    1 2024-08-24 10:55:16.503485+00:00  2 minutes after
2  linear_nn    2 2024-08-24 10:57:21.684847+00:00  2 minutes after
3  linear_nn    3 2024-08-24 10:59:35.747805+00:00  2 minutes after

   input_size  hidden_size  num_classes  num_epochs     lr  bs tags  \
0         784          100           10           2  0.001  10   []
1         784           10           10           2  0.001  10   []
2         784           50           10           2  0.001  10   []
3         784          100           10           2  0.001  10   []

   comment_count  link_count            name   value
0              0           1  accuracy_score  0.9652
1              0           0  accuracy_score  0.9186
2              0           0  accuracy_score  0.9590
3              0           0  accuracy_score  0.9674
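With several runs in the line, choosing the best one comes down to taking the row with the highest metric value. The sketch below works on rows hand-copied from the table above, keeping only num, hidden_size and value. In practice you would read these values from the viewer or the saved metadata instead. Rows 0 and 3 share the same hyperparameters yet differ in accuracy. Small gaps between configurations may therefore partly reflect run-to-run variance, which is why the sketch also groups results by hidden size.

```python
from typing import Any, Dict, List

# Hand-copied from the MetricViewer table above
rows: List[Dict[str, Any]] = [
    {"num": 0, "hidden_size": 100, "value": 0.9652},
    {"num": 1, "hidden_size": 10, "value": 0.9186},
    {"num": 2, "hidden_size": 50, "value": 0.9590},
    {"num": 3, "hidden_size": 100, "value": 0.9674},
]

# Single best run by accuracy
best = max(rows, key=lambda r: r["value"])
print(best)

# Best accuracy per hidden size, to compare architectures rather than single runs
best_per_size: Dict[int, float] = {}
for r in rows:
    size = r["hidden_size"]
    best_per_size[size] = max(best_per_size.get(size, 0.0), r["value"])
print(best_per_size)
```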