callbacks

Callbacks for training.

Classes:

Name Description
EarlyStopping

Initialize the EarlyStopping callback.

Scheduler

Initialize a scheduler object to encapsulate different torch schedulers.

ModelCheckpoint

Model checkpoint initialization.

Attributes:

Name Type Description
Callback

Callback = EarlyStopping | Scheduler | ModelCheckpoint

EarlyStopping

Initialize the EarlyStopping callback.

The given monitoring metric is evaluated on the validation set to assess the training process. If no validation set is provided when calling the Trainer.train method, the monitoring metric is computed on the training set instead, which can lead to an overfitted model.

Parameters:

Name Type Description Default

monitoring_metric

str

Monitoring metric required for early stopping, for instance "Total loss". The specified metric is by default computed on the validation set. If no validation set is provided, the only available metric is "Total loss", the epoch total loss on the train set.

required

mode

Literal['maximize', 'minimize']

Whether to "maximize" or "minimize" the provided metric.

required

patience

int

Number of epochs without improvement of the monitored value to wait before stopping training.

3

min_delta

float

Minimum change between two monitored values for the change to count as an improvement.

0.0
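The interplay of mode, patience and min_delta can be illustrated with a minimal, framework-free sketch. It assumes the common convention that a value is an improvement only if it beats the best value seen so far by more than min_delta, and that training stops once patience consecutive epochs pass without improvement. The class EarlyStoppingSketch is hypothetical and is not xpdeep's implementation.

```python
from typing import Literal, Optional


class EarlyStoppingSketch:
    """Hypothetical stand-in illustrating patience/min_delta semantics (not xpdeep code)."""

    def __init__(
        self,
        mode: Literal["maximize", "minimize"],
        patience: int = 3,
        min_delta: float = 0.0,
    ) -> None:
        if mode not in ("maximize", "minimize"):
            raise ValueError(f"mode must be 'maximize' or 'minimize', got {mode!r}")
        self.mode = mode
        self.patience = patience
        self.min_delta = min_delta
        self.best: Optional[float] = None
        self.epochs_without_improvement = 0

    def _is_improvement(self, value: float) -> bool:
        if self.best is None:
            return True  # the first observed value is always the best so far
        if self.mode == "minimize":
            return value < self.best - self.min_delta
        return value > self.best + self.min_delta

    def should_stop(self, value: float) -> bool:
        """Record one epoch's monitored value and report whether to stop."""
        if self._is_improvement(value):
            self.best = value
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.patience


# A loss that stops improving by more than min_delta after the second epoch.
stopper = EarlyStoppingSketch(mode="minimize", patience=2, min_delta=0.01)
for epoch, loss in enumerate([1.0, 0.8, 0.795, 0.792, 0.5], start=1):
    if stopper.should_stop(loss):
        print(f"stopping after epoch {epoch}")  # prints "stopping after epoch 4"
        break
```

Under this convention, the small decreases at epochs 3 and 4 (0.005 and 0.008 below the best value 0.8) are smaller than min_delta, so they count as epochs without improvement.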

Attributes:

Name Type Description
monitoring_metric str
mode Literal['maximize', 'minimize']
patience int
min_delta float

monitoring_metric: str

mode: Literal['maximize', 'minimize'] = field(validator=lambda self, attribute, value: value in [str(mode_) for mode_ in EarlyStoppingRequestBodyMode])

patience: int = 3

min_delta: float = 0.0

Scheduler

Initialize a scheduler object to encapsulate different torch schedulers.

Parameters:

Name Type Description Default

pre_scheduler

partial[LRScheduler | ReduceLROnPlateau]

Base torch LR scheduler to be instantiated, given as a functools.partial. It should not include the optimizer, as xpdeep internally uses the trainer's optimizer for the scheduler.

required

step_method

Literal['batch', 'epoch']

Whether to step the scheduler after every "batch" or after every "epoch".

required

monitoring_metric

str

Monitoring metric required for the step method, for instance "Total loss". The specified metric is by default computed on the validation set. If no validation set is provided, the only available metric is "Total loss", the epoch total loss on the train set.

required
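The sketch below shows how a pre_scheduler can be prepared with functools.partial from standard PyTorch schedulers, leaving out the optimizer argument. The later call pre_scheduler(optimizer) only imitates what the trainer is described as doing internally; the toy model and optimizers are placeholders. The ReduceLROnPlateau part shows why a monitoring metric is needed: its step method takes the monitored value.

```python
from functools import partial

import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR

# Optimizer-free scheduler factories, suitable as pre_scheduler values.
step_pre_scheduler = partial(StepLR, step_size=2, gamma=0.5)
plateau_pre_scheduler = partial(ReduceLROnPlateau, mode="min", factor=0.5, patience=1)

# Placeholder model and optimizer, standing in for the trainer's own optimizer.
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Bind the optimizer later, imitating what the trainer does internally.
step_scheduler = step_pre_scheduler(optimizer)
for _ in range(4):
    optimizer.step()       # no gradients were computed, so parameters are unchanged
    step_scheduler.step()  # halves the learning rate every 2 steps
print(optimizer.param_groups[0]["lr"])  # 0.025

# ReduceLROnPlateau needs the monitored value (e.g. "Total loss") at each step.
optimizer_2 = torch.optim.SGD(model.parameters(), lr=0.1)
plateau_scheduler = plateau_pre_scheduler(optimizer_2)
for total_loss in [1.0, 1.0, 1.0]:  # a loss that has stopped improving
    plateau_scheduler.step(total_loss)
print(optimizer_2.param_groups[0]["lr"])  # 0.05
```

Keeping the optimizer out of the partial lets the scheduler configuration be written before the optimizer exists, and the same factory can be bound to any optimizer later.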

Attributes:

Name Type Description
pre_scheduler partial[LRScheduler | ReduceLROnPlateau]
monitoring_metric str
step_method Literal['batch', 'epoch']

pre_scheduler: partial[LRScheduler | ReduceLROnPlateau]

monitoring_metric: str

step_method: Literal['batch', 'epoch'] = field(validator=lambda self, attribute, value: value in [str(step_method_) for step_method_ in SchedulerRequestBodyStepmethod])

ModelCheckpoint
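To illustrate step_method, the sketch below assumes that "batch" means the scheduler is stepped after every optimizer update and "epoch" means it is stepped once per pass over the training data. The run_training loop and the CountingScheduler class are hypothetical stand-ins, not the xpdeep trainer.

```python
from typing import Literal


class CountingScheduler:
    """Hypothetical scheduler that only counts how often step() is called."""

    def __init__(self) -> None:
        self.steps = 0

    def step(self) -> None:
        self.steps += 1


def run_training(
    step_method: Literal["batch", "epoch"], n_epochs: int, n_batches: int
) -> int:
    """Simulate a training loop and return the number of scheduler steps."""
    scheduler = CountingScheduler()
    for _ in range(n_epochs):
        for _ in range(n_batches):
            # Forward pass, backward pass and optimizer.step() would go here.
            if step_method == "batch":
                scheduler.step()
        if step_method == "epoch":
            scheduler.step()
    return scheduler.steps


print(run_training("batch", n_epochs=3, n_batches=100))  # 300
print(run_training("epoch", n_epochs=3, n_batches=100))  # 3
```

Under this assumption, schedule parameters such as StepLR's step_size are counted in batches when step_method is "batch" and in epochs when it is "epoch", so the same pre_scheduler decays much faster with "batch".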

Model checkpoint initialization.

Parameters:

Name Type Description Default

monitoring_metric

str

Monitoring metric used to select the best checkpoint, for instance "Total loss". The specified metric is by default computed on the validation set. If no validation set is provided, the only available metric is "Total loss", the epoch total loss on the train set.

required

save_every_epoch

int

How often, in epochs, to save the model. If None, only the best checkpoint is saved.

1

mode

Literal['maximize', 'minimize']

Whether to "maximize" or "minimize" the provided metric.

required
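A framework-free sketch of one plausible reading of these parameters: the best checkpoint is tracked with the monitoring metric and mode, and a periodic checkpoint is also written every save_every_epoch epochs unless it is None. The function checkpoint_epochs and the 1-based epoch numbering are assumptions for illustration only, not xpdeep's implementation.

```python
from typing import Literal, Optional


def checkpoint_epochs(
    values: list[float],
    mode: Literal["maximize", "minimize"],
    save_every_epoch: Optional[int] = 1,
) -> tuple[list[int], int]:
    """Return (periodic checkpoint epochs, best epoch) for per-epoch metric values.

    Hypothetical helper. Epochs are numbered from 1. With save_every_epoch=None,
    no periodic checkpoints are written and only the best epoch is kept.
    """
    if mode not in ("maximize", "minimize"):
        raise ValueError(f"mode must be 'maximize' or 'minimize', got {mode!r}")
    periodic: list[int] = []
    best_epoch: int = 0
    best_value: Optional[float] = None
    for epoch, value in enumerate(values, start=1):
        # Periodic checkpoint every save_every_epoch epochs.
        if save_every_epoch is not None and epoch % save_every_epoch == 0:
            periodic.append(epoch)
        # Best checkpoint: strictly better value in the requested direction.
        improved = best_value is None or (
            value < best_value if mode == "minimize" else value > best_value
        )
        if improved:
            best_epoch, best_value = epoch, value
    return periodic, best_epoch


accuracy = [0.61, 0.70, 0.68, 0.74, 0.73]
print(checkpoint_epochs(accuracy, mode="maximize", save_every_epoch=2))  # ([2, 4], 4)
print(checkpoint_epochs(accuracy, mode="maximize", save_every_epoch=None))  # ([], 4)
```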

Attributes:

Name Type Description
monitoring_metric str
mode Literal['maximize', 'minimize']
save_every_epoch int

monitoring_metric: str

mode: Literal['maximize', 'minimize'] = field(validator=lambda self, attribute, value: value in [str(mode_) for mode_ in ModelCheckpointRequestBodyMode])

save_every_epoch: int = 1