cascade.utils.torch

class cascade.utils.torch.TorchModel(model_class: Type | None = None, model: Module | None = None, **kwargs: Any)

A wrapper around torch.nn.Module.

__init__(model_class: Type | None = None, model: Module | None = None, **kwargs: Any) -> None
Parameters:
  • model_class (type, optional) – The nn.Module subclass used to construct the model. Any arguments its constructor needs should be passed through kwargs.

  • model (torch.nn.Module, optional) – An existing module instance to use as the model. Takes priority over model_class if both are provided. model_class and model cannot both be None.
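
The priority rule between model and model_class can be sketched in plain Python. This is not the library's code: resolve_model and TinyNet are hypothetical names used only for illustration, and the exception type raised when both arguments are None is an assumption, since the documentation does not state it.

```python
from typing import Any, Optional, Type


def resolve_model(model_class: Optional[Type] = None, model: Any = None, **kwargs: Any) -> Any:
    """Hypothetical helper mirroring the documented argument priority."""
    if model is not None:
        # A ready-made instance takes priority over model_class.
        return model
    if model_class is not None:
        # Otherwise the class is instantiated with the extra keyword arguments.
        return model_class(**kwargs)
    # Assumption: the documentation does not say which exception is raised here.
    raise ValueError("Either model or model_class must be provided")


class TinyNet:
    """Stand-in for a user-defined nn.Module subclass (hypothetical)."""

    def __init__(self, hidden: int = 8) -> None:
        self.hidden = hidden


# Built from the class, with kwargs forwarded to its constructor.
built = resolve_model(model_class=TinyNet, hidden=16)

# When both are given, the existing instance is used and kwargs are ignored.
existing = TinyNet(hidden=4)
chosen = resolve_model(model_class=TinyNet, model=existing, hidden=16)
```

With TorchModel itself, the same choice would be made by passing either model_class plus constructor kwargs, or an already constructed module as model.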

evaluate(x: Any, y: Any, *args: Any, **kwargs: Any) -> None

Receives x and y validation sequences. Passes x to the model’s predict method along with any args or kwargs needed, then updates self.metrics with the values returned by the objects in metrics. metrics should contain Metric objects with a compute() method, or callables with the interface f(true, predicted) -> metric_value, where metric_value is a scalar.

Parameters:
  • x (Any) – Input of the model.

  • y (Any) – Desired output to compare with the values predicted.

  • metrics (List[Union[Metric, Callable[[Any, Any], MetricType]]]) – List of metrics or callables to compute metric values
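
A callable that satisfies the f(true, predicted) -> metric_value interface might look like the sketch below. The accuracy function and the sample labels are illustrative assumptions, not part of the library. How evaluate stores the result in self.metrics is not shown, because the documentation does not specify it.

```python
from typing import Sequence


def accuracy(true: Sequence[int], predicted: Sequence[int]) -> float:
    """Fraction of positions where the prediction equals the true label."""
    if len(true) != len(predicted):
        raise ValueError("true and predicted must have the same length")
    if not true:
        # Convention chosen for this sketch: empty input scores 0.0.
        return 0.0
    matches = sum(t == p for t, p in zip(true, predicted))
    return matches / len(true)


# Hypothetical validation labels and model predictions.
y_true = [1, 0, 1, 1]
y_pred = [1, 0, 0, 1]

# Three of four predictions match, so the metric value is the scalar 0.75.
score = accuracy(y_true, y_pred)
```

Any function with this two-argument shape that returns a scalar fits the interface described above.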

get_meta() -> List[Dict[Any, Any]]
Returns:

meta – A list whose first element is this object’s metadata. All other elements represent the other stages of the pipeline, if present.

Meta can be anything worth documenting about the object and its properties.

Meta is a list (see Meta type alias) to allow the formation of pipelines.

Return type:

Meta

load_artifact(path: str, *args: Any, **kwargs: Any) -> None

Loads the torch module. Additional args and kwargs are passed to torch.load.

Parameters:

path (str) – the folder from which to load checkpoint.pt

Raises:

ValueError – if the path is not a valid directory

predict(*args: Any, **kwargs: Any) -> Any

Calls the internal module with the arguments provided.

save(path: str, *args: Any, **kwargs: Any) -> None

Saves the model to the path provided. The path should be a folder; it is created if it does not exist, and the model is saved there as model.pkl.

This method saves only the wrapper. To save the torch module, use save_artifact.

save_artifact(path: str, *args: Any, **kwargs: Any) -> None

Saves the torch module. Additional args and kwargs are passed to torch.save.

Parameters:

path (str) – the folder in which to save checkpoint.pt

Raises:

ValueError – if the path is not a valid directory
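
The folder-based contract of save_artifact and load_artifact can be sketched with the standard library alone. This is not the library's implementation: the function names are hypothetical, pickle stands in for torch.save and torch.load, and the sketch assumes load_artifact reads the same checkpoint.pt file that save_artifact writes.

```python
import os
import pickle
import tempfile
from typing import Any

# File name taken from the save_artifact description.
CHECKPOINT_NAME = "checkpoint.pt"


def save_artifact_sketch(obj: Any, path: str) -> None:
    """Hypothetical stand-in: write obj into an existing folder."""
    if not os.path.isdir(path):
        raise ValueError(f"{path} is not a valid directory")
    with open(os.path.join(path, CHECKPOINT_NAME), "wb") as f:
        pickle.dump(obj, f)  # stand-in for torch.save


def load_artifact_sketch(path: str) -> Any:
    """Hypothetical stand-in: read the object back from the folder."""
    if not os.path.isdir(path):
        raise ValueError(f"{path} is not a valid directory")
    with open(os.path.join(path, CHECKPOINT_NAME), "rb") as f:
        return pickle.load(f)  # stand-in for torch.load


# Round trip through a temporary folder; the dict stands in for a module.
with tempfile.TemporaryDirectory() as folder:
    save_artifact_sketch({"weights": [0.1, 0.2]}, folder)
    restored = load_artifact_sketch(folder)
```

Both functions take a folder rather than a file path, which matches the documented ValueError for a path that is not a valid directory.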