How to Train a Self-Explainable Model
You can train your XpdeepModel object using the Trainer interface.
The Trainer class is dedicated to learning a self-explainable model: both the model and its explanations are trained.
The Trainer is fully customizable and encapsulates your hyperparameters, callbacks and other training configurations.
Trainer-related methods are asynchronous, meaning you don't have to wait for their completion before starting a new API call.
Future Release
The training process will be interruptible with ctrl+C
Requirements:
- an XpdeepModel, see build model.
- train and validation sets as FittedParquetDataset objects, see create dataset.
Warning
This section presents the training process used to train both the model and its explanations. To learn the explanations of an already existing model, please refer to the Learn Explanations of an Existing Model section.
1. Build the Trainer
Let's dive into the Trainer configuration.
loss: a loss function, as an XpdeepLoss object. The loss function is expected to output a 1-D tensor of shape (batch_size,).
For ease of use, if the loss function output has more than one dimension, it is automatically averaged into a 1-D tensor.
The loss takes the predictions (task learner model output) as its first argument
and the preprocessed target as its second argument.
As with model creation, you can either
build your own loss as an ApiLoss with XpdeepLoss.from_torch() or use a pre-existing Xpdeep loss.
Tip
Most PyTorch built-in loss functions respect this convention when instantiated with reduction="none"
(e.g. torch.nn.MSELoss(reduction="none")).
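The shape convention can be checked with plain PyTorch. This sketch uses only torch and does not call Xpdeep: it shows that nn.MSELoss(reduction="none") keeps one value per element, and how averaging the non-batch dimensions yields a (batch_size,) tensor. Treating the automatic averaging as a mean over the extra dimensions is an assumption, not Xpdeep's internal code.

```python
import torch
from torch import nn

torch.manual_seed(0)

batch_size, n_outputs = 4, 3
predictions = torch.randn(batch_size, n_outputs)  # task learner outputs
targets = torch.randn(batch_size, n_outputs)      # preprocessed targets

# With reduction="none", the loss is returned element-wise: shape (4, 3).
elementwise = nn.MSELoss(reduction="none")(predictions, targets)

# Averaging every dimension except the batch one gives one loss per sample.
# (Assumption: this mirrors the automatic averaging described above.)
per_sample = elementwise.flatten(start_dim=1).mean(dim=1)

print(elementwise.shape)  # torch.Size([4, 3])
print(per_sample.shape)   # torch.Size([4])
```

Averaging the per-sample values again gives the usual reduction="mean" result, so no information about the overall loss is lost.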
Future Release
It will be possible to specify custom "loss outputs" when designing models, to leverage loss functions that require inputs other than the predictions and targets.
optimizer: any PyTorch optimizer, provided as a partial optimizer that does not specify
the "params" argument, as the association with the model parameters must be done server side.
Warning
Set the torch optimizer foreach and fused parameters to False, as enabling them may currently lead to unstable behaviour in the training process.
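The partial optimizer pattern is standard functools.partial usage. In this sketch, the toy nn.Linear model is hypothetical and only shows how parameters get bound later; with Xpdeep that binding happens server side, so you would only pass the partial itself to the Trainer.

```python
from functools import partial

from torch import nn
from torch.optim import AdamW

# Partial optimizer: hyperparameters are fixed, but "params" is left out.
# foreach and fused are disabled, as recommended in the warning above.
optimizer_factory = partial(AdamW, lr=0.01, foreach=False, fused=False)

# The partial keeps the optimizer class and its keyword arguments,
# which can be inspected without instantiating anything.
print(optimizer_factory.func.__name__)  # AdamW
print(optimizer_factory.keywords)       # {'lr': 0.01, 'foreach': False, 'fused': False}

# Local illustration only: binding the parameters of a hypothetical toy model.
# With Xpdeep, this association is done server side, not by you.
toy_model = nn.Linear(8, 2)
optimizer = optimizer_factory(toy_model.parameters())
print(len(optimizer.param_groups[0]["params"]))  # 2 (weight and bias)
```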
metrics: you can set your own metrics in a DictMetrics structure. As with the loss function, metrics are computed
between the model predictions and the targets. Any metric from
torchmetrics that respects this convention is supported (provided as a partial), and is serialized by its name, again for security reasons.
If you add a metric from torchmetrics, two types of metrics are automatically computed:
- Global model metrics (TorchGlobalMetric), which relate to the model's overall performance.
- Leaf metrics (TorchLeafMetric), which provide detailed metrics specific to each leaf (or predictive region) of the model, offering more granular insights.
Note that DictMetrics keys are used as metric names in XpViz.
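To make the global/leaf distinction concrete, here is a plain-PyTorch sketch that computes an accuracy once over all samples and once per leaf. It assumes each sample is routed to exactly one leaf, given by a hypothetical leaf_ids tensor; it illustrates the idea only and is not how TorchLeafMetric is implemented.

```python
import torch

# Predicted and true class indexes for 6 samples.
preds = torch.tensor([0, 1, 2, 1, 0, 2])
targets = torch.tensor([0, 1, 2, 1, 2, 0])
# Hypothetical leaf (predictive region) assignment of each sample.
leaf_ids = torch.tensor([0, 0, 0, 1, 1, 1])

correct = (preds == targets).float()

# Global metric: a single value for the whole model.
global_accuracy = correct.mean().item()

# Leaf metrics: one value per leaf, computed on that leaf's samples only.
leaf_accuracy = {
    int(leaf): correct[leaf_ids == leaf].mean().item()
    for leaf in leaf_ids.unique()
}

print(global_accuracy)  # ~0.667
print(leaf_accuracy)    # {0: 1.0, 1: ~0.333}
```

The global value hides the fact that the model is perfect on leaf 0 and weak on leaf 1, which is the kind of granular insight leaf metrics are meant to surface.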
callbacks: a list of callbacks applied during training, such as EarlyStopping.
Finally, you can provide a set of self-explanatory parameters, such as max_epochs.
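The EarlyStopping callback in the example below monitors "Total loss" in "minimize" mode. As a rough, library-agnostic sketch of what such monitoring typically means, the hypothetical class here stops training once the monitored value has not improved for a given number of epochs. The class name and its patience parameter are illustrative assumptions, not Xpdeep's EarlyStopping API.

```python
class MinimizeEarlyStopping:
    """Hypothetical early-stopping helper, for illustration only."""

    def __init__(self, patience: int = 2) -> None:
        self.patience = patience
        self.best = float("inf")
        self.epochs_without_improvement = 0

    def should_stop(self, monitored_value: float) -> bool:
        # In "minimize" mode, a lower monitored value counts as an improvement.
        if monitored_value < self.best:
            self.best = monitored_value
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.patience


# Simulated per-epoch "Total loss" values.
losses = [1.0, 0.8, 0.75, 0.76, 0.77, 0.70]
stopper = MinimizeEarlyStopping(patience=2)
stopped_at = None
for epoch, value in enumerate(losses):
    if stopper.should_stop(value):
        stopped_at = epoch
        break

print(stopped_at)  # 4
```

Training stops at epoch 4 because epochs 3 and 4 both fail to beat the best value of 0.75; the later improvement to 0.70 is never reached, which is the usual trade-off of a small patience.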
Tip
You can directly specify a TorchGlobalMetric or a TorchLeafMetric to compute only global or only per-leaf metrics.
In addition, both accept additional parameters, such as on_raw_data, which lets you compute
the metrics on raw or preprocessed data.
For instance, if you scaled your data during the preprocessing stage and use an MSE with TorchGlobalMetric and
on_raw_data=True, the MSE is computed on the raw data rather than on the scaled data.
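The difference between raw and preprocessed metrics can be worked out by hand. This torch-only sketch assumes a hypothetical standard scaler (mean 20, standard deviation 10) applied to the targets, and assumes on_raw_data=False corresponds to the preprocessed space. It computes the MSE in both spaces: the raw-space MSE is the scaled-space MSE multiplied by the variance (10² = 100).

```python
import torch
import torch.nn.functional as F

# Hypothetical scaler fitted during preprocessing: x_scaled = (x - mean) / std.
mean, std = 20.0, 10.0

raw_targets = torch.tensor([10.0, 20.0, 30.0])
scaled_targets = (raw_targets - mean) / std          # [-1.0, 0.0, 1.0]
scaled_predictions = torch.tensor([-0.5, 0.0, 1.5])  # model outputs, scaled space

# Metric on preprocessed data (assumed on_raw_data=False behaviour).
mse_scaled = F.mse_loss(scaled_predictions, scaled_targets)

# Metric on raw data: undo the scaling first (on_raw_data=True behaviour).
raw_predictions = scaled_predictions * std + mean    # [15.0, 20.0, 35.0]
mse_raw = F.mse_loss(raw_predictions, raw_targets)

print(mse_scaled.item())  # ~0.1667
print(mse_raw.item())     # ~16.67
```

The raw-space value is usually the easier one to interpret, since it is expressed in the units of the original target.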
Warning
For multiclass metrics, torchmetrics expects the target to be an index vector, not a one-hot vector.
If your targets are one-hot vectors rather than indexes, use the target_as_indexes parameter, which converts targets
from one-hot vectors to indexes before the metric is computed.
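The conversion presumably amounts to taking the position of the 1 in each one-hot row. This torch-only sketch, not Xpdeep code, performs that conversion with argmax and then computes a micro-averaged accuracy by hand on the resulting index vector; the model scores are made-up values for illustration.

```python
import torch
import torch.nn.functional as F

# One-hot targets for 4 samples and 3 classes.
onehot_targets = torch.tensor([
    [0, 1, 0],
    [1, 0, 0],
    [0, 0, 1],
    [0, 0, 1],
])

# Index targets, the format torchmetrics multiclass metrics expect.
index_targets = onehot_targets.argmax(dim=-1)  # tensor([1, 0, 2, 2])

# Round trip back to one-hot, to check that the conversion is lossless.
assert torch.equal(F.one_hot(index_targets, num_classes=3), onehot_targets)

# Class scores predicted by a model (one row per sample).
scores = torch.tensor([
    [0.1, 0.8, 0.1],
    [0.7, 0.2, 0.1],
    [0.2, 0.2, 0.6],
    [0.3, 0.4, 0.3],
])
# Micro-averaged accuracy computed by hand on the index targets.
accuracy = (scores.argmax(dim=-1) == index_targets).float().mean().item()
print(accuracy)  # 0.75
```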
Future Release
For more flexibility, it will be possible to specify transform functions that will be applied to the model predictions and to the targets before the metrics are computed.
import torch.nn as nn
from torch.optim import AdamW
from functools import partial
from xpdeep.trainer.callbacks import EarlyStopping
from xpdeep.trainer.loss import XpdeepLoss
from xpdeep.trainer.trainer import Trainer
from torchmetrics.classification import MulticlassF1Score, MulticlassAccuracy
from xpdeep.metric import DictMetrics, TorchGlobalMetric, TorchLeafMetric
metrics = DictMetrics(
    global_multi_class_accuracy=TorchGlobalMetric(partial(MulticlassAccuracy, num_classes=3, average="micro"), target_as_indexes=True),
    leaf_multi_class_accuracy=TorchLeafMetric(partial(MulticlassAccuracy, num_classes=3, average="micro"), target_as_indexes=True),
    global_multi_class_F1_score=TorchGlobalMetric(partial(MulticlassF1Score, num_classes=3, average="macro"), target_as_indexes=True),
    leaf_multi_class_F1_score=TorchLeafMetric(partial(MulticlassF1Score, num_classes=3, average="macro"), target_as_indexes=True),
)
# xpdeep_model and fitted_train_dataset come from the "build model" and "create dataset" steps.
loss = XpdeepLoss.from_torch(loss=nn.MSELoss(reduction="none"), model=xpdeep_model, fitted_schema=fitted_train_dataset.fitted_schema)
trainer = Trainer(
    loss=loss,
    optimizer=partial(AdamW, lr=0.01),
    callbacks=[EarlyStopping(monitoring_metric="Total loss", mode="minimize")],
    metrics=metrics,
    max_epochs=5,
)
Full file preview:
"""Tutorial to train a model."""
from functools import partial
from build_model import xpdeep_model
from create_dataset import fitted_train_dataset, fitted_validation_dataset
from torch import nn
from torch.optim import AdamW
from torchmetrics import Accuracy, F1Score
from xpdeep.metric import DictMetrics, TorchGlobalMetric, TorchLeafMetric
from xpdeep.trainer.callbacks import EarlyStopping
from xpdeep.trainer.loss import XpdeepLoss
from xpdeep.trainer.trainer import Trainer
metrics = DictMetrics(
    global_f1_score=TorchGlobalMetric(partial(F1Score, task="multiclass", num_classes=3), on_raw_data=True),
    leaf_f1_score=TorchLeafMetric(partial(F1Score, task="multiclass", num_classes=3), on_raw_data=True),
    global_accuracy=TorchGlobalMetric(partial(Accuracy, task="multiclass", num_classes=3), on_raw_data=True),
    leaf_accuracy=TorchLeafMetric(partial(Accuracy, task="multiclass", num_classes=3), on_raw_data=True),
)
loss = XpdeepLoss.from_torch(
    loss=nn.MSELoss(reduction="none"), model=xpdeep_model, fitted_schema=fitted_train_dataset.fitted_schema
)
trainer = Trainer(
    loss=loss,
    optimizer=partial(AdamW, lr=0.01),
    callbacks=[EarlyStopping(monitoring_metric="Total loss", mode="minimize")],
    metrics=metrics,
    max_epochs=5,
)
trained_model = trainer.train(
    xpdeep_model, train_set=fitted_train_dataset, validation_set=fitted_validation_dataset, batch_size=32
)
2. Train the Model
With your trainer, you can finally train the explainable model and get a trained model as a TrainedModelArtifact.
trained_model = trainer.train(xpdeep_model, train_set=fitted_train_dataset, validation_set=fitted_validation_dataset, batch_size=32)
Under the hood, the trained model is saved as a trained model artifact within your Project.
Training logs are displayed in your terminal.
Future Release
Insights and logs will be provided as artifacts.
3. Evaluate the Model
Model performance is evaluated through the Explainer interface, by its global_explain
and local_explain methods.
A wide range of metrics is supported; see Get Learned Explanations.