
explainer

How to explain a trained model.

CounterfactualParameters #

Parameters to configure counterfactual computation.

Explainer #

Explain an XpdeepModel.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| description_representativeness | int | A parameter governing explanation quality: the greater the value, the better the explanation, but the slower the computation. | required |
| quality_metrics | list[QualityMetrics] | A list of quality metrics to compute, such as Sensitivity or Infidelity. | required |
| window_size | int \| None | DTW window parameter (proportion, in %). | None |
| metrics | DictMetrics \| None | A list of metrics to compute along with the explanation (F1 score, etc.). | None |
| statistics | DictStats \| None | A list of statistics to compute along with the explanation (variance on targets, etc.). | None |

description_representativeness: int #

quality_metrics: list[QualityMetrics] #

window_size: int | None = None #

metrics: DictMetrics | None = None #

statistics: DictStats | None = None #
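A minimal construction sketch (hedged: the exact import paths for Explainer and the quality metrics are assumptions, not confirmed by this page):

```python
from xpdeep.explain.explainer import Explainer  # assumed from the source path above
from xpdeep.explain.quality_metrics import Infidelity, Sensitivity  # hypothetical path

explainer = Explainer(
    description_representativeness=1000,  # higher = better explanations, slower compute
    quality_metrics=[Sensitivity(), Infidelity()],
    # window_size, metrics, and statistics default to None and may be omitted
)
```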

local_explain(trained_model: TrainedModelArtifact, train_set: FittedParquetDataset, dataset_filter: Filter, *, explanation_name: str | None = None, explanation_description: str | None = None, progress_bar_update_rate: float = 1, counterfactual_parameters: CounterfactualParameters | None = None) -> ExplanationArtifact #

Create a causal explanation from a trained model.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| trained_model | TrainedModelArtifact | A model trained via the trainer interface. | required |
| train_set | FittedParquetDataset | A dataset representing a train split. | required |
| dataset_filter | Filter | A filter used to filter the dataset and select the samples to explain. | required |
| explanation_name | str \| None | The explanation name. | None |
| explanation_description | str \| None | The explanation description. | None |
| progress_bar_update_rate | float | The progress bar update rate. | 1 |
| counterfactual_parameters | CounterfactualParameters \| None | Parameters used to compute counterfactuals, if given. | None |

Returns:

| Type | Description |
|------|-------------|
| ExplanationArtifact | The causal explanation results, containing the result as JSON. |
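A hedged call sketch; `trained_model`, `fitted_train_set`, and `my_filter` are placeholder variables assumed to have been created beforehand:

```python
# my_filter is a Filter selecting which samples to explain.
explanation = explainer.local_explain(
    trained_model=trained_model,
    train_set=fitted_train_set,
    dataset_filter=my_filter,
    explanation_name="demo_local_explanation",
    counterfactual_parameters=None,  # or a CounterfactualParameters instance
)
```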

Source code in src/xpdeep/explain/explainer.py
def local_explain(  # noqa:PLR0913
    self,
    trained_model: TrainedModelArtifact,
    train_set: FittedParquetDataset,
    dataset_filter: Filter,
    *,
    explanation_name: str | None = None,
    explanation_description: str | None = None,
    progress_bar_update_rate: float = 1,
    counterfactual_parameters: CounterfactualParameters | None = None,
) -> ExplanationArtifact:
    """Create a causal explanation from trained model.

    Parameters
    ----------
    trained_model: TrainedModelArtifact
        A model trained via the trainer interface.
    train_set: FittedParquetDataset
        A dataset representing a train split.
    dataset_filter: Filter
        A filter used to filter the dataset and get samples to explain.
    explanation_name: str | None
        The explanation name, default None.
    explanation_description: str | None
        The explanation description, default None.
    progress_bar_update_rate: float
        The progress bar update rate, default 1.
    counterfactual_parameters: CounterfactualParameters | None
        Parameters used to compute counterfactuals if given, default None.

    Returns
    -------
    ExplanationArtifact
        The causal explanation results, containing the result as JSON.
    """
    metrics = self.metrics.as_request_body if self.metrics is not None else None
    statistics = self.statistics.as_request_body if self.statistics is not None else None

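    # Apply the filter first; its id is then referenced in the request body below.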
    dataset_filter.apply()
    explanation_create_request_body = ExplanationCreateRequestBody(
        filter_id=cast(str, dataset_filter.id),
        trained_model_id=trained_model.id,
        train_dataset_id=train_set.artifact_id,
        name=explanation_name,
        description=explanation_description,
        config=ExplanationCreateConfigBody(
            description_representativeness=self.description_representativeness,
            quality_metrics=[
                ExplanationCreateConfigQualityMetricBody.from_dict(quality_metric.to_dict())
                for quality_metric in self.quality_metrics
            ],
            metrics=metrics,
            statistics=statistics,
        ),
    )
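    # With counterfactual parameters, submit a combined explanation + counterfactual
    # job; otherwise submit a plain explanation job.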
    if counterfactual_parameters:
        explanation_with_counterfactual_create_request_body = CounterfactualCreateRequestBody(
            explanation_params=explanation_create_request_body,
            counterfactual_params=counterfactual_parameters.as_request_body,
        )
        compute_local_explanation_job = cast(
            JobModel,
            create_counterfactual.sync(
                project_id=Project.CURRENT.get().model.id,
                client=ClientFactory.CURRENT.get()(),
                body=explanation_with_counterfactual_create_request_body,
            ),
        )
    else:
        compute_local_explanation_job = cast(
            JobModel,
            create_explanation.sync(
                project_id=Project.CURRENT.get().model.id,
                client=ClientFactory.CURRENT.get()(),
                body=explanation_create_request_body,
            ),
        )

    explanation_results_model = self._get_one_explanation(compute_local_explanation_job, progress_bar_update_rate)

    return ExplanationArtifact(explanation_results_model)

global_explain(trained_model: TrainedModelArtifact, train_set: FittedParquetDataset, test_set: FittedParquetDataset | None = None, validation_set: FittedParquetDataset | None = None, progress_bar_update_rate: float = 1) -> ExplanationArtifact #

Compute the model decision on a trained model.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| trained_model | TrainedModelArtifact | A model trained via the trainer interface. | required |
| train_set | FittedParquetDataset | A dataset representing a train split. | required |
| test_set | FittedParquetDataset \| None | A dataset representing a test split, used to optionally compute split statistics. | None |
| validation_set | FittedParquetDataset \| None | A dataset representing a validation split, used to optionally compute split statistics. | None |
| progress_bar_update_rate | float | The progress bar update rate. | 1 |

Returns:

| Type | Description |
|------|-------------|
| ExplanationArtifact | The model decision results, containing the result as JSON. |
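A hedged call sketch; the variable names are placeholders assumed to have been created beforehand:

```python
# test_set and validation_set are optional; passing them adds split statistics.
explanation = explainer.global_explain(
    trained_model=trained_model,
    train_set=fitted_train_set,
    test_set=fitted_test_set,
    validation_set=None,
)
```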

Source code in src/xpdeep/explain/explainer.py
def global_explain(
    self,
    trained_model: TrainedModelArtifact,
    train_set: FittedParquetDataset,
    test_set: FittedParquetDataset | None = None,
    validation_set: FittedParquetDataset | None = None,
    progress_bar_update_rate: float = 1,
) -> ExplanationArtifact:
    """Compute model decision on a trained model.

    Parameters
    ----------
    trained_model: TrainedModelArtifact
        A model trained via the trainer interface.
    train_set: FittedParquetDataset
        A dataset representing a train split.
    test_set: FittedParquetDataset | None
        A dataset representing a test split, used to optionally compute split statistics.
    validation_set: FittedParquetDataset | None
        A dataset representing a validation split, used to optionally compute split statistics.
    progress_bar_update_rate: float
        The progress bar update rate, default 1.

    Returns
    -------
    ExplanationArtifact
        The model decision results, containing the result as JSON.
    """
    test_set_id = test_set.artifact_id if test_set is not None else None
    validation_set_id = validation_set.artifact_id if validation_set is not None else None

    metrics = self.metrics.as_request_body if self.metrics is not None else None
    statistics = self.statistics.as_request_body if self.statistics is not None else None

    description = f"Global explanation on trained model : {trained_model.name} - with id : {trained_model.id}"
    body = GlobalExplanationCreateRequestBody(
        trained_model_id=trained_model.id,
        train_dataset_id=train_set.artifact_id,
        config=ExplanationCreateConfigBody(
            description_representativeness=self.description_representativeness,
            quality_metrics=[
                ExplanationCreateConfigQualityMetricBody.from_dict(quality_metric.to_dict())
                for quality_metric in self.quality_metrics
            ],
            metrics=metrics,
            statistics=statistics,
        ),
        name=f"Global_explanation_{trained_model.name}_{trained_model.id}",
        description=description,
        test_dataset_id=test_set_id,
        validation_dataset_id=validation_set_id,
    )

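    # Submit the global explanation job; results are fetched once the job completes.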
    compute_global_explanation_job = cast(
        JobModel,
        create_global_explanation.sync(
            project_id=Project.CURRENT.get().model.id, client=ClientFactory.CURRENT.get()(), body=body
        ),
    )

    explanation_results_model = self._get_one_explanation(compute_global_explanation_job, progress_bar_update_rate)
    return ExplanationArtifact(explanation_results_model)