Skip to content

MNIST Dataset

In this section, we detail the PyTorch code for designing an explainable deep model that processes the MNIST dataset.

MNIST is an image classification dataset.

We use the HuggingFace datasets hub for convenience, but you can also download the data from another source and adapt the tutorial accordingly.

The goal of this task is to classify a given image of a handwritten digit into one of 10 classes representing the integer values 0 to 9, inclusive.

Please follow this end-to-end tutorial to prepare the dataset, create and train the model, and finally compute explanations.

Prepare the Dataset

1. Split and Convert your Raw Data

The first step is to create your train, test and validation splits as StandardDataset.

We load the dataset from the HuggingFace datasets hub for convenience.

from datasets import load_dataset

# Fetch MNIST from the HuggingFace hub; its "train" and "test" splits are used below.
dataset = load_dataset(path="mnist", trust_remote_code=True)

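The dataset is then split into train, validation and test sets: the original train split is kept as the train set, and the original test split is halved (stratified on the label) into the validation and test sets.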

from datasets import DatasetDict

# Halve the original test split: its first half becomes the validation set, the second one the test set.
# Stratifying on "label" keeps the same digit distribution in both halves; the seed makes it reproducible.
test_eval = dataset["test"].train_test_split(
    test_size=0.5,
    stratify_by_column="label",
    seed=1225,
)

train_split = dataset["train"]
val_split = test_eval["train"]  # train_test_split always names its first half "train".
test_split = test_eval["test"]
splits = DatasetDict({"train": train_split, "val": val_split, "test": test_split})
👀 Full file preview
"""MNIST workflow, classification with image data."""

import os
from copy import deepcopy
from functools import partial

import boto3
import pyarrow as pa
import pyarrow.parquet as pq
import torch
from botocore.client import Config
from datasets import DatasetDict, load_dataset
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchmetrics.classification import MulticlassAccuracy, MulticlassConfusionMatrix, MulticlassF1Score
from torchvision.transforms import RandomRotation

from xpdeep import init, set_project
from xpdeep.dataset.feature.augmentation.augmentation import FeatureAugmentation
from xpdeep.dataset.parquet_dataset import FittedParquetDataset, ParquetDataset
from xpdeep.explain.explainer import Explainer
from xpdeep.explain.quality_metrics import Infidelity, Sensitivity
from xpdeep.explain.statistic import DictStats, DistributionStat
from xpdeep.filtering.criteria import CategoricalCriterion
from xpdeep.filtering.filter import Filter
from xpdeep.metric import DictMetrics, TorchGlobalMetric, TorchLeafMetric
from xpdeep.model.feature_extraction_output_type import FeatureExtractionOutputType
from xpdeep.model.model_parameters import ModelDecisionGraphParameters
from xpdeep.model.xpdeep_model import XpdeepModel
from xpdeep.model.zoo.cross_entropy_loss_from_proba import CrossEntropyLossFromProbabilities
from xpdeep.model.zoo.doc import MnistCNN
from xpdeep.model.zoo.mlp import MLP
from xpdeep.project import Project, get_project
from xpdeep.trainer.callbacks import EarlyStopping, ModelCheckpoint, Scheduler
from xpdeep.trainer.loss import XpdeepLoss
from xpdeep.trainer.trainer import Trainer

# Storage options handed to every ParquetDataset reading from the S3 bucket. Credentials and the
# endpoint URL are taken from the environment.
# NOTE(review): in s3fs, `s3_additional_kwargs` are extra parameters for S3 API calls, and path-style
# addressing is usually configured via config_kwargs={"s3": {"addressing_style": "path"}} - confirm
# that the entry below has the intended effect with the consumer of these options.
STORAGE_OPTIONS = dict(
    key=os.getenv("S3_DATASET_ACCESS_KEY_ID"),
    secret=os.getenv("S3_DATASET_SECRET_ACCESS_KEY"),
    client_kwargs=dict(endpoint_url=os.getenv("S3_DATASET_ENDPOINT_URL")),
    s3_additional_kwargs=dict(addressing_style="path"),
)


def main():
    """Process the dataset, train, and explain the model.

    Runs the MNIST tutorial end to end:

    1. Split the HuggingFace MNIST data into train, validation and test sets, save each split as a local
       ".parquet" file and upload it to the bucket named by the ``S3_DATASET_BUCKET_NAME`` environment
       variable.
    2. Analyze and fit the schema on the train set, then give the validation and test sets a copy of it.
    3. Build an explainable model (CNN feature extractor + MLP task learner) and train it.
    4. Compute global and causal explanations and print their visualisation links.

    Raises:
        RuntimeError: If the ``S3_DATASET_BUCKET_NAME`` environment variable is unset or empty.
    """
    torch.random.manual_seed(5)

    # Resolve the bucket name once and fail fast: otherwise the whole dataset would be downloaded and
    # converted before the first upload fails on a missing bucket name, and every "s3://" path below
    # would point to a bucket literally named "None".
    bucket_name = os.getenv("S3_DATASET_BUCKET_NAME")
    if not bucket_name:
        raise RuntimeError("The S3_DATASET_BUCKET_NAME environment variable must be set.")

    # Split names double as local file names ("<split>.parquet") and S3 keys ("mnist/<split>.parquet").
    split_names = ("train", "val", "test")

    # ##### Prepare the Dataset #######

    # 1. Split and Convert your Raw Data
    # Load the dataset from HuggingFace datasets hub for convenience.
    dataset = load_dataset("mnist", trust_remote_code=True)

    # Halve the original test split into validation and test sets; stratifying on "label" keeps the same
    # digit distribution in both halves.
    test_eval = dataset["test"].train_test_split(test_size=0.5, stratify_by_column="label", seed=1225)

    splits = DatasetDict({
        "train": dataset["train"],
        "val": test_eval["train"],  # train_test_split names its first half "train".
        "test": test_eval["test"],
    })

    # Convert each split to a pyarrow Table and save it as a local ".parquet" file, one split at a time.
    # preserve_index=True keeps the dataframe index as a dedicated column.
    for split_name in split_names:
        table = pa.Table.from_pandas(splits[split_name].to_pandas(), preserve_index=True)
        pq.write_table(table, f"{split_name}.parquet")

    # 2. Upload your Converted Data
    client = boto3.client(
        service_name="s3",
        endpoint_url=os.getenv("S3_DATASET_ENDPOINT_URL"),
        aws_access_key_id=os.getenv("S3_DATASET_ACCESS_KEY_ID"),
        aws_secret_access_key=os.getenv("S3_DATASET_SECRET_ACCESS_KEY"),
        config=Config(signature_version="s3v4"),
    )

    for split_name in split_names:
        client.upload_file(f"{split_name}.parquet", bucket_name, f"mnist/{split_name}.parquet")

    # 3. Instantiate a Dataset
    train_dataset = ParquetDataset(
        name="mnist_train_set",
        path=f"s3://{bucket_name}/mnist/train.parquet",
        storage_options=STORAGE_OPTIONS,
    )

    # 4. Find a schema ("label" is the target column)
    analyzed_train_dataset = train_dataset.analyze(target_names=["label"])

    # 5. Fit the schema
    fit_train_dataset = analyzed_train_dataset.fit()

    # The test and validation sets get a deep copy of the schema fitted on the train set, so later changes
    # to the train schema (such as the augmentation below) do not leak into them.
    fit_test_dataset = FittedParquetDataset(
        name="mnist_test_set",
        path=f"s3://{bucket_name}/mnist/test.parquet",
        storage_options=STORAGE_OPTIONS,
        fitted_schema=deepcopy(fit_train_dataset.fitted_schema),
    )

    fit_val_dataset = FittedParquetDataset(
        name="mnist_validation_set",
        path=f"s3://{bucket_name}/mnist/val.parquet",
        storage_options=STORAGE_OPTIONS,
        fitted_schema=deepcopy(fit_train_dataset.fitted_schema),
    )

    # 6. Add augmentation on the train set fitted schema only (optional), so that it is not applied to the
    # test and validation sets.
    # As the images are grayscale, there is no need to permute the channel dimension to put it first.
    # For an RGB image, we could have used Compose([Permute([0, 3, 1, 2]), RandomRotation(90), Permute([0, 2, 3, 1])])
    augmentation = RandomRotation(90)
    image_rotation_augmentation = FeatureAugmentation(augment_raw=augmentation, augment_preprocessed=augmentation)
    fit_train_dataset.fitted_schema.add_augmentation("image", image_rotation_augmentation)

    # ##### Prepare the Model #######

    # 1. Create the required torch models
    input_size = fit_train_dataset.fitted_schema.input_size[1:]  # 28 x 28
    target_size = fit_train_dataset.fitted_schema.target_size[1]  # 10

    print(f"input_size: {input_size} - target_size: {target_size}")

    # The CNN produces a 128-feature vector, which the MLP task learner maps to one probability per class
    # (softmax output).
    feature_extractor = MnistCNN(output_size=128)
    task_learner = MLP(
        input_size=128,
        activation_layer=partial(torch.nn.ReLU),
        flatten_input=True,
        hidden_channels=[target_size],
        last_activation=partial(torch.nn.Softmax, dim=1),
    )

    # 2. Explainable Model Specifications
    model_specifications = ModelDecisionGraphParameters(
        feature_extraction_output_type=FeatureExtractionOutputType.VECTOR,
        balancing_weight=0.1,  # This incentivizes the explainable model to provide more decisions.
    )

    # 3. Create the Explainable Model
    xpdeep_model = XpdeepModel.from_torch(
        example_dataset=fit_train_dataset,
        feature_extraction=feature_extractor,
        task_learner=task_learner,
        backbone=None,
        decision_graph_parameters=model_specifications,
    )

    # ##### Train #######

    # Metrics to monitor the training, each tracked both globally (TorchGlobalMetric) and per leaf
    # (TorchLeafMetric).
    metrics = DictMetrics(
        global_multi_class_accuracy=TorchGlobalMetric(
            partial(MulticlassAccuracy, num_classes=target_size, average="micro"), target_as_indexes=True
        ),
        leaf_multi_class_accuracy=TorchLeafMetric(
            partial(MulticlassAccuracy, num_classes=target_size, average="micro"), target_as_indexes=True
        ),
        global_multi_class_F1_score=TorchGlobalMetric(
            partial(MulticlassF1Score, num_classes=target_size, average="macro"), target_as_indexes=True
        ),
        leaf_multi_class_F1_score=TorchLeafMetric(
            partial(MulticlassF1Score, num_classes=target_size, average="macro"), target_as_indexes=True
        ),
        global_confusion_matrix=TorchGlobalMetric(
            partial(MulticlassConfusionMatrix, num_classes=target_size, normalize="all"), target_as_indexes=True
        ),
        leaf_confusion_matrix=TorchLeafMetric(
            partial(MulticlassConfusionMatrix, num_classes=target_size, normalize="all"), target_as_indexes=True
        ),
    )

    callbacks = [
        # NOTE(review): with max_epochs=6 below, a patience of 10 cannot trigger early stopping (assuming
        # patience counts epochs); it only matters if max_epochs is raised.
        EarlyStopping(monitoring_metric="global_multi_class_accuracy", mode="maximize", patience=10),
        # Lower the learning rate when the total loss plateaus, checked once per epoch.
        Scheduler(
            pre_scheduler=partial(ReduceLROnPlateau, mode="min", patience=1),
            step_method="epoch",
            monitoring_metric="Total loss",
        ),
        # Checkpoint on global accuracy, where higher is better.
        ModelCheckpoint(monitoring_metric="global_multi_class_accuracy", mode="maximize"),
    ]

    # The optimizer is a partial object because the model parameters are given to it later, during training.
    optimizer = partial(torch.optim.AdamW, lr=0.001, foreach=False, fused=False)

    # The task learner already outputs probabilities (softmax), hence a cross-entropy computed from
    # probabilities. reduction="none" keeps per-sample losses; presumably XpdeepLoss aggregates them - confirm.
    loss = XpdeepLoss.from_torch(
        loss=CrossEntropyLossFromProbabilities(reduction="none"),
        model=xpdeep_model,
        fitted_schema=fit_train_dataset.fitted_schema,
    )

    trainer = Trainer(
        loss=loss,
        optimizer=optimizer,
        callbacks=callbacks,
        start_epoch=0,
        max_epochs=6,
        metrics=metrics,
    )

    trained_model = trainer.train(
        model=xpdeep_model,
        train_set=fit_train_dataset,
        validation_set=fit_val_dataset,
        batch_size=128,
    )

    # ##### Explain #######

    # 1. Build the Explainer
    statistics = DictStats(
        distribution_target=DistributionStat(on="target"), distribution_prediction=DistributionStat(on="prediction")
    )
    quality_metrics = [Sensitivity(), Infidelity()]

    explainer = Explainer(quality_metrics=quality_metrics, metrics=metrics, statistics=statistics)

    # 2. Model Functioning Explanations
    model_explanations = explainer.global_explain(
        trained_model,
        train_set=fit_train_dataset,
        test_set=fit_test_dataset,
        validation_set=fit_val_dataset,
    )

    print(model_explanations.visualisation_link)

    # 3. Inference and their Causal Explanations, restricted to test samples labelled 1 or 2.
    my_filter = Filter("testing_filter", fit_test_dataset)
    my_filter.add_criteria(
        CategoricalCriterion(fit_test_dataset.fitted_schema["label"], categories=[1, 2]),
    )

    causal_explanations = explainer.local_explain(trained_model, fit_test_dataset, my_filter)
    print(causal_explanations.visualisation_link)


if __name__ == "__main__":
    # Authenticate with credentials from the environment, then run everything inside a dedicated project
    # that is deleted at the end, whether the run succeeded or not.
    init(api_key=os.getenv("API_KEY"), api_url=os.getenv("API_URL"))
    review_project = Project.create_or_get(name="MNIST Tutorial Review")
    set_project(review_project)

    try:
        main()
    finally:
        get_project().delete()

As stated in the documentation, Xpdeep requires a ".parquet" file to create a dataset. The original data is stored in a DatasetDict object, so each split must be converted to a ".parquet" file.

Tip

To get your ".parquet" files, convert each split to a pandas.DataFrame, then to a pyarrow.Table, and write that table with pyarrow.parquet.

Tip

We set preserve_index to True to keep the dataframe index as a dedicated column that is later captured as IndexMetadata in the schema.

import pyarrow as pa
import pyarrow.parquet as pq

# Convert every split to pandas, then to a pyarrow Table.
# preserve_index=True keeps the dataframe index as a dedicated column.
train_table, val_table, test_table = (
    pa.Table.from_pandas(splits[split].to_pandas(), preserve_index=True) for split in ("train", "val", "test")
)

# Write each Table to its own ".parquet" file.
tables_to_write = {"train.parquet": train_table, "val.parquet": val_table, "test.parquet": test_table}
for parquet_file, table in tables_to_write.items():
    pq.write_table(table, parquet_file)
👀 Full file preview
"""MNIST workflow, classification with image data."""

import os
from copy import deepcopy
from functools import partial

import boto3
import pyarrow as pa
import pyarrow.parquet as pq
import torch
from botocore.client import Config
from datasets import DatasetDict, load_dataset
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchmetrics.classification import MulticlassAccuracy, MulticlassConfusionMatrix, MulticlassF1Score
from torchvision.transforms import RandomRotation

from xpdeep import init, set_project
from xpdeep.dataset.feature.augmentation.augmentation import FeatureAugmentation
from xpdeep.dataset.parquet_dataset import FittedParquetDataset, ParquetDataset
from xpdeep.explain.explainer import Explainer
from xpdeep.explain.quality_metrics import Infidelity, Sensitivity
from xpdeep.explain.statistic import DictStats, DistributionStat
from xpdeep.filtering.criteria import CategoricalCriterion
from xpdeep.filtering.filter import Filter
from xpdeep.metric import DictMetrics, TorchGlobalMetric, TorchLeafMetric
from xpdeep.model.feature_extraction_output_type import FeatureExtractionOutputType
from xpdeep.model.model_parameters import ModelDecisionGraphParameters
from xpdeep.model.xpdeep_model import XpdeepModel
from xpdeep.model.zoo.cross_entropy_loss_from_proba import CrossEntropyLossFromProbabilities
from xpdeep.model.zoo.doc import MnistCNN
from xpdeep.model.zoo.mlp import MLP
from xpdeep.project import Project, get_project
from xpdeep.trainer.callbacks import EarlyStopping, ModelCheckpoint, Scheduler
from xpdeep.trainer.loss import XpdeepLoss
from xpdeep.trainer.trainer import Trainer

# Storage options handed to every ParquetDataset reading from the S3 bucket. Credentials and the
# endpoint URL are taken from the environment.
# NOTE(review): in s3fs, `s3_additional_kwargs` are extra parameters for S3 API calls, and path-style
# addressing is usually configured via config_kwargs={"s3": {"addressing_style": "path"}} - confirm
# that the entry below has the intended effect with the consumer of these options.
STORAGE_OPTIONS = dict(
    key=os.getenv("S3_DATASET_ACCESS_KEY_ID"),
    secret=os.getenv("S3_DATASET_SECRET_ACCESS_KEY"),
    client_kwargs=dict(endpoint_url=os.getenv("S3_DATASET_ENDPOINT_URL")),
    s3_additional_kwargs=dict(addressing_style="path"),
)


def main():
    """Process the dataset, train, and explain the model.

    Runs the MNIST tutorial end to end:

    1. Split the HuggingFace MNIST data into train, validation and test sets, save each split as a local
       ".parquet" file and upload it to the bucket named by the ``S3_DATASET_BUCKET_NAME`` environment
       variable.
    2. Analyze and fit the schema on the train set, then give the validation and test sets a copy of it.
    3. Build an explainable model (CNN feature extractor + MLP task learner) and train it.
    4. Compute global and causal explanations and print their visualisation links.

    Raises:
        RuntimeError: If the ``S3_DATASET_BUCKET_NAME`` environment variable is unset or empty.
    """
    torch.random.manual_seed(5)

    # Resolve the bucket name once and fail fast: otherwise the whole dataset would be downloaded and
    # converted before the first upload fails on a missing bucket name, and every "s3://" path below
    # would point to a bucket literally named "None".
    bucket_name = os.getenv("S3_DATASET_BUCKET_NAME")
    if not bucket_name:
        raise RuntimeError("The S3_DATASET_BUCKET_NAME environment variable must be set.")

    # Split names double as local file names ("<split>.parquet") and S3 keys ("mnist/<split>.parquet").
    split_names = ("train", "val", "test")

    # ##### Prepare the Dataset #######

    # 1. Split and Convert your Raw Data
    # Load the dataset from HuggingFace datasets hub for convenience.
    dataset = load_dataset("mnist", trust_remote_code=True)

    # Halve the original test split into validation and test sets; stratifying on "label" keeps the same
    # digit distribution in both halves.
    test_eval = dataset["test"].train_test_split(test_size=0.5, stratify_by_column="label", seed=1225)

    splits = DatasetDict({
        "train": dataset["train"],
        "val": test_eval["train"],  # train_test_split names its first half "train".
        "test": test_eval["test"],
    })

    # Convert each split to a pyarrow Table and save it as a local ".parquet" file, one split at a time.
    # preserve_index=True keeps the dataframe index as a dedicated column.
    for split_name in split_names:
        table = pa.Table.from_pandas(splits[split_name].to_pandas(), preserve_index=True)
        pq.write_table(table, f"{split_name}.parquet")

    # 2. Upload your Converted Data
    client = boto3.client(
        service_name="s3",
        endpoint_url=os.getenv("S3_DATASET_ENDPOINT_URL"),
        aws_access_key_id=os.getenv("S3_DATASET_ACCESS_KEY_ID"),
        aws_secret_access_key=os.getenv("S3_DATASET_SECRET_ACCESS_KEY"),
        config=Config(signature_version="s3v4"),
    )

    for split_name in split_names:
        client.upload_file(f"{split_name}.parquet", bucket_name, f"mnist/{split_name}.parquet")

    # 3. Instantiate a Dataset
    train_dataset = ParquetDataset(
        name="mnist_train_set",
        path=f"s3://{bucket_name}/mnist/train.parquet",
        storage_options=STORAGE_OPTIONS,
    )

    # 4. Find a schema ("label" is the target column)
    analyzed_train_dataset = train_dataset.analyze(target_names=["label"])

    # 5. Fit the schema
    fit_train_dataset = analyzed_train_dataset.fit()

    # The test and validation sets get a deep copy of the schema fitted on the train set, so later changes
    # to the train schema (such as the augmentation below) do not leak into them.
    fit_test_dataset = FittedParquetDataset(
        name="mnist_test_set",
        path=f"s3://{bucket_name}/mnist/test.parquet",
        storage_options=STORAGE_OPTIONS,
        fitted_schema=deepcopy(fit_train_dataset.fitted_schema),
    )

    fit_val_dataset = FittedParquetDataset(
        name="mnist_validation_set",
        path=f"s3://{bucket_name}/mnist/val.parquet",
        storage_options=STORAGE_OPTIONS,
        fitted_schema=deepcopy(fit_train_dataset.fitted_schema),
    )

    # 6. Add augmentation on the train set fitted schema only (optional), so that it is not applied to the
    # test and validation sets.
    # As the images are grayscale, there is no need to permute the channel dimension to put it first.
    # For an RGB image, we could have used Compose([Permute([0, 3, 1, 2]), RandomRotation(90), Permute([0, 2, 3, 1])])
    augmentation = RandomRotation(90)
    image_rotation_augmentation = FeatureAugmentation(augment_raw=augmentation, augment_preprocessed=augmentation)
    fit_train_dataset.fitted_schema.add_augmentation("image", image_rotation_augmentation)

    # ##### Prepare the Model #######

    # 1. Create the required torch models
    input_size = fit_train_dataset.fitted_schema.input_size[1:]  # 28 x 28
    target_size = fit_train_dataset.fitted_schema.target_size[1]  # 10

    print(f"input_size: {input_size} - target_size: {target_size}")

    # The CNN produces a 128-feature vector, which the MLP task learner maps to one probability per class
    # (softmax output).
    feature_extractor = MnistCNN(output_size=128)
    task_learner = MLP(
        input_size=128,
        activation_layer=partial(torch.nn.ReLU),
        flatten_input=True,
        hidden_channels=[target_size],
        last_activation=partial(torch.nn.Softmax, dim=1),
    )

    # 2. Explainable Model Specifications
    model_specifications = ModelDecisionGraphParameters(
        feature_extraction_output_type=FeatureExtractionOutputType.VECTOR,
        balancing_weight=0.1,  # This incentivizes the explainable model to provide more decisions.
    )

    # 3. Create the Explainable Model
    xpdeep_model = XpdeepModel.from_torch(
        example_dataset=fit_train_dataset,
        feature_extraction=feature_extractor,
        task_learner=task_learner,
        backbone=None,
        decision_graph_parameters=model_specifications,
    )

    # ##### Train #######

    # Metrics to monitor the training, each tracked both globally (TorchGlobalMetric) and per leaf
    # (TorchLeafMetric).
    metrics = DictMetrics(
        global_multi_class_accuracy=TorchGlobalMetric(
            partial(MulticlassAccuracy, num_classes=target_size, average="micro"), target_as_indexes=True
        ),
        leaf_multi_class_accuracy=TorchLeafMetric(
            partial(MulticlassAccuracy, num_classes=target_size, average="micro"), target_as_indexes=True
        ),
        global_multi_class_F1_score=TorchGlobalMetric(
            partial(MulticlassF1Score, num_classes=target_size, average="macro"), target_as_indexes=True
        ),
        leaf_multi_class_F1_score=TorchLeafMetric(
            partial(MulticlassF1Score, num_classes=target_size, average="macro"), target_as_indexes=True
        ),
        global_confusion_matrix=TorchGlobalMetric(
            partial(MulticlassConfusionMatrix, num_classes=target_size, normalize="all"), target_as_indexes=True
        ),
        leaf_confusion_matrix=TorchLeafMetric(
            partial(MulticlassConfusionMatrix, num_classes=target_size, normalize="all"), target_as_indexes=True
        ),
    )

    callbacks = [
        # NOTE(review): with max_epochs=6 below, a patience of 10 cannot trigger early stopping (assuming
        # patience counts epochs); it only matters if max_epochs is raised.
        EarlyStopping(monitoring_metric="global_multi_class_accuracy", mode="maximize", patience=10),
        # Lower the learning rate when the total loss plateaus, checked once per epoch.
        Scheduler(
            pre_scheduler=partial(ReduceLROnPlateau, mode="min", patience=1),
            step_method="epoch",
            monitoring_metric="Total loss",
        ),
        # Checkpoint on global accuracy, where higher is better.
        ModelCheckpoint(monitoring_metric="global_multi_class_accuracy", mode="maximize"),
    ]

    # The optimizer is a partial object because the model parameters are given to it later, during training.
    optimizer = partial(torch.optim.AdamW, lr=0.001, foreach=False, fused=False)

    # The task learner already outputs probabilities (softmax), hence a cross-entropy computed from
    # probabilities. reduction="none" keeps per-sample losses; presumably XpdeepLoss aggregates them - confirm.
    loss = XpdeepLoss.from_torch(
        loss=CrossEntropyLossFromProbabilities(reduction="none"),
        model=xpdeep_model,
        fitted_schema=fit_train_dataset.fitted_schema,
    )

    trainer = Trainer(
        loss=loss,
        optimizer=optimizer,
        callbacks=callbacks,
        start_epoch=0,
        max_epochs=6,
        metrics=metrics,
    )

    trained_model = trainer.train(
        model=xpdeep_model,
        train_set=fit_train_dataset,
        validation_set=fit_val_dataset,
        batch_size=128,
    )

    # ##### Explain #######

    # 1. Build the Explainer
    statistics = DictStats(
        distribution_target=DistributionStat(on="target"), distribution_prediction=DistributionStat(on="prediction")
    )
    quality_metrics = [Sensitivity(), Infidelity()]

    explainer = Explainer(quality_metrics=quality_metrics, metrics=metrics, statistics=statistics)

    # 2. Model Functioning Explanations
    model_explanations = explainer.global_explain(
        trained_model,
        train_set=fit_train_dataset,
        test_set=fit_test_dataset,
        validation_set=fit_val_dataset,
    )

    print(model_explanations.visualisation_link)

    # 3. Inference and their Causal Explanations, restricted to test samples labelled 1 or 2.
    my_filter = Filter("testing_filter", fit_test_dataset)
    my_filter.add_criteria(
        CategoricalCriterion(fit_test_dataset.fitted_schema["label"], categories=[1, 2]),
    )

    causal_explanations = explainer.local_explain(trained_model, fit_test_dataset, my_filter)
    print(causal_explanations.visualisation_link)


if __name__ == "__main__":
    # Authenticate with credentials from the environment, then run everything inside a dedicated project
    # that is deleted at the end, whether the run succeeded or not.
    init(api_key=os.getenv("API_KEY"), api_url=os.getenv("API_URL"))
    review_project = Project.create_or_get(name="MNIST Tutorial Review")
    set_project(review_project)

    try:
        main()
    finally:
        get_project().delete()

2. Upload your Converted Data

Warning

Don't forget to set up a Project and initialize the API with your credentials!

from xpdeep import init, set_project
from xpdeep.project import Project

# Authenticate against the Xpdeep API (replace "api_key" and "api_url" with your own credentials),
# then make the "MNIST Tutorial" project the active one.
init(api_key="api_key", api_url="api_url")
tutorial_project = Project.create_or_get(name="MNIST Tutorial")
set_project(tutorial_project)

With your Project set up, you can upload the converted ".parquet" files to your fsspec-compatible storage, here an S3 bucket.

import boto3
from botocore.client import Config

# The S3_DATASET_* values are yours to define: endpoint, credentials and bucket of your storage
# (the full script reads them from environment variables).
s3_config = Config(signature_version="s3v4")
client = boto3.client(
    service_name="s3",
    endpoint_url=S3_DATASET_ENDPOINT_URL,
    aws_access_key_id=S3_DATASET_ACCESS_KEY_ID,
    aws_secret_access_key=S3_DATASET_SECRET_ACCESS_KEY,
    config=s3_config,
)

# Upload each split under the "mnist/" prefix of the bucket, keeping its file name.
for parquet_file in ("train.parquet", "val.parquet", "test.parquet"):
    client.upload_file(parquet_file, S3_DATASET_BUCKET_NAME, f"mnist/{parquet_file}")
👀 Full file preview
"""MNIST workflow, classification with image data."""

import os
from copy import deepcopy
from functools import partial

import boto3
import pyarrow as pa
import pyarrow.parquet as pq
import torch
from botocore.client import Config
from datasets import DatasetDict, load_dataset
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchmetrics.classification import MulticlassAccuracy, MulticlassConfusionMatrix, MulticlassF1Score
from torchvision.transforms import RandomRotation

from xpdeep import init, set_project
from xpdeep.dataset.feature.augmentation.augmentation import FeatureAugmentation
from xpdeep.dataset.parquet_dataset import FittedParquetDataset, ParquetDataset
from xpdeep.explain.explainer import Explainer
from xpdeep.explain.quality_metrics import Infidelity, Sensitivity
from xpdeep.explain.statistic import DictStats, DistributionStat
from xpdeep.filtering.criteria import CategoricalCriterion
from xpdeep.filtering.filter import Filter
from xpdeep.metric import DictMetrics, TorchGlobalMetric, TorchLeafMetric
from xpdeep.model.feature_extraction_output_type import FeatureExtractionOutputType
from xpdeep.model.model_parameters import ModelDecisionGraphParameters
from xpdeep.model.xpdeep_model import XpdeepModel
from xpdeep.model.zoo.cross_entropy_loss_from_proba import CrossEntropyLossFromProbabilities
from xpdeep.model.zoo.doc import MnistCNN
from xpdeep.model.zoo.mlp import MLP
from xpdeep.project import Project, get_project
from xpdeep.trainer.callbacks import EarlyStopping, ModelCheckpoint, Scheduler
from xpdeep.trainer.loss import XpdeepLoss
from xpdeep.trainer.trainer import Trainer

# Storage options handed to every ParquetDataset reading from the S3 bucket. Credentials and the
# endpoint URL are taken from the environment.
# NOTE(review): in s3fs, `s3_additional_kwargs` are extra parameters for S3 API calls, and path-style
# addressing is usually configured via config_kwargs={"s3": {"addressing_style": "path"}} - confirm
# that the entry below has the intended effect with the consumer of these options.
STORAGE_OPTIONS = dict(
    key=os.getenv("S3_DATASET_ACCESS_KEY_ID"),
    secret=os.getenv("S3_DATASET_SECRET_ACCESS_KEY"),
    client_kwargs=dict(endpoint_url=os.getenv("S3_DATASET_ENDPOINT_URL")),
    s3_additional_kwargs=dict(addressing_style="path"),
)


def main():
    """Process the dataset, train, and explain the model.

    Runs the MNIST tutorial end to end:

    1. Split the HuggingFace MNIST data into train, validation and test sets, save each split as a local
       ".parquet" file and upload it to the bucket named by the ``S3_DATASET_BUCKET_NAME`` environment
       variable.
    2. Analyze and fit the schema on the train set, then give the validation and test sets a copy of it.
    3. Build an explainable model (CNN feature extractor + MLP task learner) and train it.
    4. Compute global and causal explanations and print their visualisation links.

    Raises:
        RuntimeError: If the ``S3_DATASET_BUCKET_NAME`` environment variable is unset or empty.
    """
    torch.random.manual_seed(5)

    # Resolve the bucket name once and fail fast: otherwise the whole dataset would be downloaded and
    # converted before the first upload fails on a missing bucket name, and every "s3://" path below
    # would point to a bucket literally named "None".
    bucket_name = os.getenv("S3_DATASET_BUCKET_NAME")
    if not bucket_name:
        raise RuntimeError("The S3_DATASET_BUCKET_NAME environment variable must be set.")

    # Split names double as local file names ("<split>.parquet") and S3 keys ("mnist/<split>.parquet").
    split_names = ("train", "val", "test")

    # ##### Prepare the Dataset #######

    # 1. Split and Convert your Raw Data
    # Load the dataset from HuggingFace datasets hub for convenience.
    dataset = load_dataset("mnist", trust_remote_code=True)

    # Halve the original test split into validation and test sets; stratifying on "label" keeps the same
    # digit distribution in both halves.
    test_eval = dataset["test"].train_test_split(test_size=0.5, stratify_by_column="label", seed=1225)

    splits = DatasetDict({
        "train": dataset["train"],
        "val": test_eval["train"],  # train_test_split names its first half "train".
        "test": test_eval["test"],
    })

    # Convert each split to a pyarrow Table and save it as a local ".parquet" file, one split at a time.
    # preserve_index=True keeps the dataframe index as a dedicated column.
    for split_name in split_names:
        table = pa.Table.from_pandas(splits[split_name].to_pandas(), preserve_index=True)
        pq.write_table(table, f"{split_name}.parquet")

    # 2. Upload your Converted Data
    client = boto3.client(
        service_name="s3",
        endpoint_url=os.getenv("S3_DATASET_ENDPOINT_URL"),
        aws_access_key_id=os.getenv("S3_DATASET_ACCESS_KEY_ID"),
        aws_secret_access_key=os.getenv("S3_DATASET_SECRET_ACCESS_KEY"),
        config=Config(signature_version="s3v4"),
    )

    for split_name in split_names:
        client.upload_file(f"{split_name}.parquet", bucket_name, f"mnist/{split_name}.parquet")

    # 3. Instantiate a Dataset
    train_dataset = ParquetDataset(
        name="mnist_train_set",
        path=f"s3://{bucket_name}/mnist/train.parquet",
        storage_options=STORAGE_OPTIONS,
    )

    # 4. Find a schema ("label" is the target column)
    analyzed_train_dataset = train_dataset.analyze(target_names=["label"])

    # 5. Fit the schema
    fit_train_dataset = analyzed_train_dataset.fit()

    # The test and validation sets get a deep copy of the schema fitted on the train set, so later changes
    # to the train schema (such as the augmentation below) do not leak into them.
    fit_test_dataset = FittedParquetDataset(
        name="mnist_test_set",
        path=f"s3://{bucket_name}/mnist/test.parquet",
        storage_options=STORAGE_OPTIONS,
        fitted_schema=deepcopy(fit_train_dataset.fitted_schema),
    )

    fit_val_dataset = FittedParquetDataset(
        name="mnist_validation_set",
        path=f"s3://{bucket_name}/mnist/val.parquet",
        storage_options=STORAGE_OPTIONS,
        fitted_schema=deepcopy(fit_train_dataset.fitted_schema),
    )

    # 6. Add augmentation on the train set fitted schema only (optional), so that it is not applied to the
    # test and validation sets.
    # As the images are grayscale, there is no need to permute the channel dimension to put it first.
    # For an RGB image, we could have used Compose([Permute([0, 3, 1, 2]), RandomRotation(90), Permute([0, 2, 3, 1])])
    augmentation = RandomRotation(90)
    image_rotation_augmentation = FeatureAugmentation(augment_raw=augmentation, augment_preprocessed=augmentation)
    fit_train_dataset.fitted_schema.add_augmentation("image", image_rotation_augmentation)

    # ##### Prepare the Model #######

    # 1. Create the required torch models
    input_size = fit_train_dataset.fitted_schema.input_size[1:]  # 28 x 28
    target_size = fit_train_dataset.fitted_schema.target_size[1]  # 10

    print(f"input_size: {input_size} - target_size: {target_size}")

    # The CNN produces a 128-feature vector, which the MLP task learner maps to one probability per class
    # (softmax output).
    feature_extractor = MnistCNN(output_size=128)
    task_learner = MLP(
        input_size=128,
        activation_layer=partial(torch.nn.ReLU),
        flatten_input=True,
        hidden_channels=[target_size],
        last_activation=partial(torch.nn.Softmax, dim=1),
    )

    # 2. Explainable Model Specifications
    model_specifications = ModelDecisionGraphParameters(
        feature_extraction_output_type=FeatureExtractionOutputType.VECTOR,
        balancing_weight=0.1,  # This incentivizes the explainable model to provide more decisions.
    )

    # 3. Create the Explainable Model
    xpdeep_model = XpdeepModel.from_torch(
        example_dataset=fit_train_dataset,
        feature_extraction=feature_extractor,
        task_learner=task_learner,
        backbone=None,
        decision_graph_parameters=model_specifications,
    )

    # ##### Train #######

    # Metrics to monitor the training, each tracked both globally (TorchGlobalMetric) and per leaf
    # (TorchLeafMetric).
    metrics = DictMetrics(
        global_multi_class_accuracy=TorchGlobalMetric(
            partial(MulticlassAccuracy, num_classes=target_size, average="micro"), target_as_indexes=True
        ),
        leaf_multi_class_accuracy=TorchLeafMetric(
            partial(MulticlassAccuracy, num_classes=target_size, average="micro"), target_as_indexes=True
        ),
        global_multi_class_F1_score=TorchGlobalMetric(
            partial(MulticlassF1Score, num_classes=target_size, average="macro"), target_as_indexes=True
        ),
        leaf_multi_class_F1_score=TorchLeafMetric(
            partial(MulticlassF1Score, num_classes=target_size, average="macro"), target_as_indexes=True
        ),
        global_confusion_matrix=TorchGlobalMetric(
            partial(MulticlassConfusionMatrix, num_classes=target_size, normalize="all"), target_as_indexes=True
        ),
        leaf_confusion_matrix=TorchLeafMetric(
            partial(MulticlassConfusionMatrix, num_classes=target_size, normalize="all"), target_as_indexes=True
        ),
    )

    callbacks = [
        # NOTE(review): with max_epochs=6 below, a patience of 10 cannot trigger early stopping (assuming
        # patience counts epochs); it only matters if max_epochs is raised.
        EarlyStopping(monitoring_metric="global_multi_class_accuracy", mode="maximize", patience=10),
        # Lower the learning rate when the total loss plateaus, checked once per epoch.
        Scheduler(
            pre_scheduler=partial(ReduceLROnPlateau, mode="min", patience=1),
            step_method="epoch",
            monitoring_metric="Total loss",
        ),
        # Checkpoint on global accuracy, where higher is better.
        ModelCheckpoint(monitoring_metric="global_multi_class_accuracy", mode="maximize"),
    ]

    # The optimizer is a partial object because the model parameters are given to it later, during training.
    optimizer = partial(torch.optim.AdamW, lr=0.001, foreach=False, fused=False)

    # The task learner already outputs probabilities (softmax), hence a cross-entropy computed from
    # probabilities. reduction="none" keeps per-sample losses; presumably XpdeepLoss aggregates them - confirm.
    loss = XpdeepLoss.from_torch(
        loss=CrossEntropyLossFromProbabilities(reduction="none"),
        model=xpdeep_model,
        fitted_schema=fit_train_dataset.fitted_schema,
    )

    trainer = Trainer(
        loss=loss,
        optimizer=optimizer,
        callbacks=callbacks,
        start_epoch=0,
        max_epochs=6,
        metrics=metrics,
    )

    trained_model = trainer.train(
        model=xpdeep_model,
        train_set=fit_train_dataset,
        validation_set=fit_val_dataset,
        batch_size=128,
    )

    # ##### Explain #######

    # 1. Build the Explainer
    statistics = DictStats(
        distribution_target=DistributionStat(on="target"), distribution_prediction=DistributionStat(on="prediction")
    )
    quality_metrics = [Sensitivity(), Infidelity()]

    explainer = Explainer(quality_metrics=quality_metrics, metrics=metrics, statistics=statistics)

    # 2. Model Functioning Explanations
    model_explanations = explainer.global_explain(
        trained_model,
        train_set=fit_train_dataset,
        test_set=fit_test_dataset,
        validation_set=fit_val_dataset,
    )

    print(model_explanations.visualisation_link)

    # 3. Inference and their Causal Explanations, restricted to test samples labelled 1 or 2.
    my_filter = Filter("testing_filter", fit_test_dataset)
    my_filter.add_criteria(
        CategoricalCriterion(fit_test_dataset.fitted_schema["label"], categories=[1, 2]),
    )

    causal_explanations = explainer.local_explain(trained_model, fit_test_dataset, my_filter)
    print(causal_explanations.visualisation_link)


if __name__ == "__main__":
    # Authenticate with credentials from the environment, then run everything inside a dedicated project
    # that is deleted at the end, whether the run succeeded or not.
    init(api_key=os.getenv("API_KEY"), api_url=os.getenv("API_URL"))
    review_project = Project.create_or_get(name="MNIST Tutorial Review")
    set_project(review_project)

    try:
        main()
    finally:
        get_project().delete()

3. Instantiate a Dataset

Here we instantiate a ParquetDataset for the train set only. We will create the validation and test datasets later.

from xpdeep.dataset.parquet_dataset import ParquetDataset

# Wrap only the train split for now: the validation and test splits are created
# later as FittedParquetDatasets, once the schema has been fitted on this split.
# NOTE(review): S3_DATASET_BUCKET_NAME and STORAGE_OPTIONS are presumably defined
# in an earlier step of this tutorial — confirm before running this snippet alone.
train_dataset = ParquetDataset(
    name="mnist_train_dataset",
    path=f"s3://{S3_DATASET_BUCKET_NAME}/mnist/train.parquet",
    storage_options=STORAGE_OPTIONS,
)
👀 Full file preview
"""MNIST workflow, classification with image data."""

import os
from copy import deepcopy
from functools import partial

import boto3
import pyarrow as pa
import pyarrow.parquet as pq
import torch
from botocore.client import Config
from datasets import DatasetDict, load_dataset
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchmetrics.classification import MulticlassAccuracy, MulticlassConfusionMatrix, MulticlassF1Score
from torchvision.transforms import RandomRotation

from xpdeep import init, set_project
from xpdeep.dataset.feature.augmentation.augmentation import FeatureAugmentation
from xpdeep.dataset.parquet_dataset import FittedParquetDataset, ParquetDataset
from xpdeep.explain.explainer import Explainer
from xpdeep.explain.quality_metrics import Infidelity, Sensitivity
from xpdeep.explain.statistic import DictStats, DistributionStat
from xpdeep.filtering.criteria import CategoricalCriterion
from xpdeep.filtering.filter import Filter
from xpdeep.metric import DictMetrics, TorchGlobalMetric, TorchLeafMetric
from xpdeep.model.feature_extraction_output_type import FeatureExtractionOutputType
from xpdeep.model.model_parameters import ModelDecisionGraphParameters
from xpdeep.model.xpdeep_model import XpdeepModel
from xpdeep.model.zoo.cross_entropy_loss_from_proba import CrossEntropyLossFromProbabilities
from xpdeep.model.zoo.doc import MnistCNN
from xpdeep.model.zoo.mlp import MLP
from xpdeep.project import Project, get_project
from xpdeep.trainer.callbacks import EarlyStopping, ModelCheckpoint, Scheduler
from xpdeep.trainer.loss import XpdeepLoss
from xpdeep.trainer.trainer import Trainer

# Storage options handed to every parquet dataset of this workflow (presumably
# forwarded to an fsspec/s3fs filesystem). Credentials and endpoint are read from
# the environment; a missing variable yields None.
STORAGE_OPTIONS = dict(
    key=os.environ.get("S3_DATASET_ACCESS_KEY_ID"),
    secret=os.environ.get("S3_DATASET_SECRET_ACCESS_KEY"),
    client_kwargs=dict(endpoint_url=os.environ.get("S3_DATASET_ENDPOINT_URL")),
    # NOTE(review): in s3fs, `s3_additional_kwargs` are per-request API parameters;
    # path-style addressing is usually configured through
    # `config_kwargs={"s3": {"addressing_style": "path"}}`. Confirm this key has the
    # intended effect for whatever consumes these options.
    s3_additional_kwargs=dict(addressing_style="path"),
)


def main():
    """Process the dataset, train, and explain the model.

    End-to-end MNIST workflow:

    1. Load MNIST from the HuggingFace hub, split the original test split in half
       (stratified on ``label``) into validation and test splits, save every split
       as a local ``.parquet`` file and upload it to the S3 bucket named by the
       ``S3_DATASET_BUCKET_NAME`` environment variable.
    2. Analyze and fit a schema on the train split, give the validation and test
       splits independent deep copies of the fitted schema, and register a
       random-rotation augmentation on the train schema only.
    3. Build an explainable model (``MnistCNN`` feature extractor + ``MLP`` task
       learner), train it, print the link to the model functioning explanations,
       then print the link to the causal explanations of test samples labelled
       1 or 2.

    Raises:
        RuntimeError: If ``S3_DATASET_BUCKET_NAME`` is unset or empty.
    """
    torch.random.manual_seed(5)

    # Resolve the bucket once so the uploads and every dataset path agree, and fail
    # before the dataset download rather than later on a ``None`` bucket.
    bucket_name = os.getenv("S3_DATASET_BUCKET_NAME")
    if not bucket_name:
        raise RuntimeError("The S3_DATASET_BUCKET_NAME environment variable must be set.")
    s3_prefix = f"s3://{bucket_name}/mnist"
    split_names = ("train", "val", "test")

    # ##### Prepare the Dataset #######

    # 1. Split and Convert your Raw Data
    # Load the dataset from HuggingFace datasets hub for convenience.
    dataset = load_dataset("mnist", trust_remote_code=True)

    # Half of the original test split becomes the validation set; stratifying on
    # "label" keeps the class proportions of both halves aligned.
    test_eval = dataset["test"].train_test_split(test_size=0.5, stratify_by_column="label", seed=1225)

    splits = DatasetDict({
        "train": dataset["train"],
        "val": test_eval["train"],  # HuggingFace requires the "train" keyword.
        "test": test_eval["test"],
    })

    # Convert each split to a pyarrow Table and save it as a ".parquet" file.
    # preserve_index=True stores the pandas index as the "__index_level_0__" column,
    # which the schema analyzer later detects as index metadata.
    for split_name in split_names:
        table = pa.Table.from_pandas(splits[split_name].to_pandas(), preserve_index=True)
        pq.write_table(table, f"{split_name}.parquet")

    # 2. Upload your Converted Data
    client = boto3.client(
        service_name="s3",
        endpoint_url=os.getenv("S3_DATASET_ENDPOINT_URL"),
        aws_access_key_id=os.getenv("S3_DATASET_ACCESS_KEY_ID"),
        aws_secret_access_key=os.getenv("S3_DATASET_SECRET_ACCESS_KEY"),
        config=Config(signature_version="s3v4"),
    )

    for split_name in split_names:
        client.upload_file(f"{split_name}.parquet", bucket_name, f"mnist/{split_name}.parquet")

    # 3. Instantiate a Dataset
    train_dataset = ParquetDataset(
        name="mnist_train_set",
        path=f"{s3_prefix}/train.parquet",
        storage_options=STORAGE_OPTIONS,
    )

    # 4. Find a schema
    analyzed_train_dataset = train_dataset.analyze(target_names=["label"])

    # 5. Fit the schema
    fit_train_dataset = analyzed_train_dataset.fit()

    # The test and validation sets each get their own deep copy of the fitted schema,
    # so the augmentation added to the train schema below cannot leak into them.
    fit_test_dataset = FittedParquetDataset(
        name="mnist_test_set",
        path=f"{s3_prefix}/test.parquet",
        storage_options=STORAGE_OPTIONS,
        fitted_schema=deepcopy(fit_train_dataset.fitted_schema),
    )

    fit_val_dataset = FittedParquetDataset(
        name="mnist_validation_set",
        path=f"{s3_prefix}/val.parquet",
        storage_options=STORAGE_OPTIONS,
        fitted_schema=deepcopy(fit_train_dataset.fitted_schema),
    )

    # 6. Add augmentation on the train set fitted schema (optional), to ensure augmentation won't be applied on test
    # and validation sets.
    # Here as images are grayscale, we don't need to permute the channel dimension to get the
    # channel dimension first.
    # For an RGB image, we could have used Compose([Permute([0, 3, 1, 2]), RandomRotation(90), Permute([0, 2, 3, 1])])
    # RandomRotation(90) samples a rotation angle in [-90, 90] degrees.
    # NOTE(review): the rotation is registered for both raw and preprocessed images;
    # confirm that augmenting the raw images is intended.
    augmentation = RandomRotation(90)
    image_rotation_augmentation = FeatureAugmentation(augment_raw=augmentation, augment_preprocessed=augmentation)
    fit_train_dataset.fitted_schema.add_augmentation("image", image_rotation_augmentation)

    # ##### Prepare the Model #######

    # 1. Create the required torch models
    input_size = fit_train_dataset.fitted_schema.input_size[1:]  # 28 x 28
    target_size = fit_train_dataset.fitted_schema.target_size[1]  # 10

    print(f"input_size: {input_size} - target_size: {target_size}")

    feature_extractor = MnistCNN(output_size=128)
    # The task learner maps the 128 extracted features to class probabilities
    # (softmax), matching the probability-based cross-entropy loss used below.
    task_learner = MLP(
        input_size=128,
        activation_layer=partial(torch.nn.ReLU),
        flatten_input=True,
        hidden_channels=[target_size],
        last_activation=partial(torch.nn.Softmax, dim=1),
    )

    # 2. Explainable Model Specifications
    model_specifications = ModelDecisionGraphParameters(
        feature_extraction_output_type=FeatureExtractionOutputType.VECTOR,
        balancing_weight=0.1,  # This incentivizes the explainable model to provide more decisions.
    )

    # 3. Create the Explainable Model
    xpdeep_model = XpdeepModel.from_torch(
        example_dataset=fit_train_dataset,
        feature_extraction=feature_extractor,
        task_learner=task_learner,
        backbone=None,
        decision_graph_parameters=model_specifications,
    )

    # ##### Train #######

    # Metrics to monitor the training.
    metrics = DictMetrics(
        global_multi_class_accuracy=TorchGlobalMetric(
            partial(MulticlassAccuracy, num_classes=target_size, average="micro"), target_as_indexes=True
        ),
        leaf_multi_class_accuracy=TorchLeafMetric(
            partial(MulticlassAccuracy, num_classes=target_size, average="micro"), target_as_indexes=True
        ),
        global_multi_class_F1_score=TorchGlobalMetric(
            partial(MulticlassF1Score, num_classes=target_size, average="macro"), target_as_indexes=True
        ),
        leaf_multi_class_F1_score=TorchLeafMetric(
            partial(MulticlassF1Score, num_classes=target_size, average="macro"), target_as_indexes=True
        ),
        global_confusion_matrix=TorchGlobalMetric(
            partial(MulticlassConfusionMatrix, num_classes=target_size, normalize="all"), target_as_indexes=True
        ),
        leaf_confusion_matrix=TorchLeafMetric(
            partial(MulticlassConfusionMatrix, num_classes=target_size, normalize="all"), target_as_indexes=True
        ),
    )

    # NOTE(review): with max_epochs=6 below, a patience of 10 likely means early
    # stopping never triggers — confirm this is intended.
    callbacks = [
        EarlyStopping(monitoring_metric="global_multi_class_accuracy", mode="maximize", patience=10),
        Scheduler(
            pre_scheduler=partial(ReduceLROnPlateau, mode="min", patience=1),
            step_method="epoch",
            monitoring_metric="Total loss",
        ),
        ModelCheckpoint(monitoring_metric="global_multi_class_accuracy", mode="maximize"),
    ]

    # The optimizer is a partial because the model parameters are not known yet;
    # they are presumably bound by the trainer when it instantiates the optimizer.
    optimizer = partial(torch.optim.AdamW, lr=0.001, foreach=False, fused=False)

    loss = XpdeepLoss.from_torch(
        loss=CrossEntropyLossFromProbabilities(reduction="none"),
        model=xpdeep_model,
        fitted_schema=fit_train_dataset.fitted_schema,
    )

    trainer = Trainer(
        loss=loss,
        optimizer=optimizer,
        callbacks=callbacks,
        start_epoch=0,
        max_epochs=6,
        metrics=metrics,
    )

    trained_model = trainer.train(
        model=xpdeep_model,
        train_set=fit_train_dataset,
        validation_set=fit_val_dataset,
        batch_size=128,
    )

    # ##### Explain #######

    # 1. Build the Explainer
    statistics = DictStats(
        distribution_target=DistributionStat(on="target"), distribution_prediction=DistributionStat(on="prediction")
    )
    quality_metrics = [Sensitivity(), Infidelity()]

    explainer = Explainer(quality_metrics=quality_metrics, metrics=metrics, statistics=statistics)

    # 2. Model Functioning Explanations
    model_explanations = explainer.global_explain(
        trained_model,
        train_set=fit_train_dataset,
        test_set=fit_test_dataset,
        validation_set=fit_val_dataset,
    )

    print(model_explanations.visualisation_link)

    # 3. Inference and their Causal Explanations
    # Restrict the causal explanations to test samples whose label is 1 or 2.
    my_filter = Filter("testing_filter", fit_test_dataset)
    my_filter.add_criteria(
        CategoricalCriterion(fit_test_dataset.fitted_schema["label"], categories=[1, 2]),
    )

    causal_explanations = explainer.local_explain(trained_model, fit_test_dataset, my_filter)
    print(causal_explanations.visualisation_link)


if __name__ == "__main__":
    # Initialise the Xpdeep client from environment-provided credentials.
    init(api_url=os.getenv("API_URL"), api_key=os.getenv("API_KEY"))

    # Run the whole workflow inside a dedicated project.
    tutorial_project = Project.create_or_get(name="MNIST Tutorial Review")
    set_project(tutorial_project)

    try:
        main()
    finally:
        # Delete the current project whether or not the workflow succeeded.
        get_project().delete()

4. Find a schema

We use the AutoAnalyzer to get a schema proposal on the train set. The only requirement is to specify the target name, here the "label" feature with its 10 digit classes.

# Ask the analyzer for a schema proposal on the train split, declaring "label" as the target.
analyzed_train_dataset = train_dataset.analyze(target_names=["label"])
# Display the proposed schema so it can be reviewed before fitting.
print(analyzed_train_dataset.analyzed_schema)
+----------------------------------------------------+
|                  Schema Contents                   |
+--------------------+-------------------+-----------+
| Type               | Name              | Is Target |
+--------------------+-------------------+-----------+
| ImageFeature       | image             | ❌        |
| CategoricalFeature | label             | ✅        |
| IndexMetadata      | __index_level_0__ |           |
+--------------------+-------------------+-----------+

By default, the "image" feature inferred by the analysis includes a scaling preprocessor that maps pixel values from [0, 255] to [-1, 1].

Note

Please note that the __index_level_0__ column (the pandas index saved when converting with preserve_index=True) is automatically set as an IndexMetadata in the schema.

👀 Full file preview
"""MNIST workflow, classification with image data."""

import os
from copy import deepcopy
from functools import partial

import boto3
import pyarrow as pa
import pyarrow.parquet as pq
import torch
from botocore.client import Config
from datasets import DatasetDict, load_dataset
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchmetrics.classification import MulticlassAccuracy, MulticlassConfusionMatrix, MulticlassF1Score
from torchvision.transforms import RandomRotation

from xpdeep import init, set_project
from xpdeep.dataset.feature.augmentation.augmentation import FeatureAugmentation
from xpdeep.dataset.parquet_dataset import FittedParquetDataset, ParquetDataset
from xpdeep.explain.explainer import Explainer
from xpdeep.explain.quality_metrics import Infidelity, Sensitivity
from xpdeep.explain.statistic import DictStats, DistributionStat
from xpdeep.filtering.criteria import CategoricalCriterion
from xpdeep.filtering.filter import Filter
from xpdeep.metric import DictMetrics, TorchGlobalMetric, TorchLeafMetric
from xpdeep.model.feature_extraction_output_type import FeatureExtractionOutputType
from xpdeep.model.model_parameters import ModelDecisionGraphParameters
from xpdeep.model.xpdeep_model import XpdeepModel
from xpdeep.model.zoo.cross_entropy_loss_from_proba import CrossEntropyLossFromProbabilities
from xpdeep.model.zoo.doc import MnistCNN
from xpdeep.model.zoo.mlp import MLP
from xpdeep.project import Project, get_project
from xpdeep.trainer.callbacks import EarlyStopping, ModelCheckpoint, Scheduler
from xpdeep.trainer.loss import XpdeepLoss
from xpdeep.trainer.trainer import Trainer

# Storage options handed to every parquet dataset of this workflow (presumably
# forwarded to an fsspec/s3fs filesystem). Credentials and endpoint are read from
# the environment; a missing variable yields None.
STORAGE_OPTIONS = dict(
    key=os.environ.get("S3_DATASET_ACCESS_KEY_ID"),
    secret=os.environ.get("S3_DATASET_SECRET_ACCESS_KEY"),
    client_kwargs=dict(endpoint_url=os.environ.get("S3_DATASET_ENDPOINT_URL")),
    # NOTE(review): in s3fs, `s3_additional_kwargs` are per-request API parameters;
    # path-style addressing is usually configured through
    # `config_kwargs={"s3": {"addressing_style": "path"}}`. Confirm this key has the
    # intended effect for whatever consumes these options.
    s3_additional_kwargs=dict(addressing_style="path"),
)


def main():
    """Process the dataset, train, and explain the model.

    End-to-end MNIST workflow:

    1. Load MNIST from the HuggingFace hub, split the original test split in half
       (stratified on ``label``) into validation and test splits, save every split
       as a local ``.parquet`` file and upload it to the S3 bucket named by the
       ``S3_DATASET_BUCKET_NAME`` environment variable.
    2. Analyze and fit a schema on the train split, give the validation and test
       splits independent deep copies of the fitted schema, and register a
       random-rotation augmentation on the train schema only.
    3. Build an explainable model (``MnistCNN`` feature extractor + ``MLP`` task
       learner), train it, print the link to the model functioning explanations,
       then print the link to the causal explanations of test samples labelled
       1 or 2.

    Raises:
        RuntimeError: If ``S3_DATASET_BUCKET_NAME`` is unset or empty.
    """
    torch.random.manual_seed(5)

    # Resolve the bucket once so the uploads and every dataset path agree, and fail
    # before the dataset download rather than later on a ``None`` bucket.
    bucket_name = os.getenv("S3_DATASET_BUCKET_NAME")
    if not bucket_name:
        raise RuntimeError("The S3_DATASET_BUCKET_NAME environment variable must be set.")
    s3_prefix = f"s3://{bucket_name}/mnist"
    split_names = ("train", "val", "test")

    # ##### Prepare the Dataset #######

    # 1. Split and Convert your Raw Data
    # Load the dataset from HuggingFace datasets hub for convenience.
    dataset = load_dataset("mnist", trust_remote_code=True)

    # Half of the original test split becomes the validation set; stratifying on
    # "label" keeps the class proportions of both halves aligned.
    test_eval = dataset["test"].train_test_split(test_size=0.5, stratify_by_column="label", seed=1225)

    splits = DatasetDict({
        "train": dataset["train"],
        "val": test_eval["train"],  # HuggingFace requires the "train" keyword.
        "test": test_eval["test"],
    })

    # Convert each split to a pyarrow Table and save it as a ".parquet" file.
    # preserve_index=True stores the pandas index as the "__index_level_0__" column,
    # which the schema analyzer later detects as index metadata.
    for split_name in split_names:
        table = pa.Table.from_pandas(splits[split_name].to_pandas(), preserve_index=True)
        pq.write_table(table, f"{split_name}.parquet")

    # 2. Upload your Converted Data
    client = boto3.client(
        service_name="s3",
        endpoint_url=os.getenv("S3_DATASET_ENDPOINT_URL"),
        aws_access_key_id=os.getenv("S3_DATASET_ACCESS_KEY_ID"),
        aws_secret_access_key=os.getenv("S3_DATASET_SECRET_ACCESS_KEY"),
        config=Config(signature_version="s3v4"),
    )

    for split_name in split_names:
        client.upload_file(f"{split_name}.parquet", bucket_name, f"mnist/{split_name}.parquet")

    # 3. Instantiate a Dataset
    train_dataset = ParquetDataset(
        name="mnist_train_set",
        path=f"{s3_prefix}/train.parquet",
        storage_options=STORAGE_OPTIONS,
    )

    # 4. Find a schema
    analyzed_train_dataset = train_dataset.analyze(target_names=["label"])

    # 5. Fit the schema
    fit_train_dataset = analyzed_train_dataset.fit()

    # The test and validation sets each get their own deep copy of the fitted schema,
    # so the augmentation added to the train schema below cannot leak into them.
    fit_test_dataset = FittedParquetDataset(
        name="mnist_test_set",
        path=f"{s3_prefix}/test.parquet",
        storage_options=STORAGE_OPTIONS,
        fitted_schema=deepcopy(fit_train_dataset.fitted_schema),
    )

    fit_val_dataset = FittedParquetDataset(
        name="mnist_validation_set",
        path=f"{s3_prefix}/val.parquet",
        storage_options=STORAGE_OPTIONS,
        fitted_schema=deepcopy(fit_train_dataset.fitted_schema),
    )

    # 6. Add augmentation on the train set fitted schema (optional), to ensure augmentation won't be applied on test
    # and validation sets.
    # Here as images are grayscale, we don't need to permute the channel dimension to get the
    # channel dimension first.
    # For an RGB image, we could have used Compose([Permute([0, 3, 1, 2]), RandomRotation(90), Permute([0, 2, 3, 1])])
    # RandomRotation(90) samples a rotation angle in [-90, 90] degrees.
    # NOTE(review): the rotation is registered for both raw and preprocessed images;
    # confirm that augmenting the raw images is intended.
    augmentation = RandomRotation(90)
    image_rotation_augmentation = FeatureAugmentation(augment_raw=augmentation, augment_preprocessed=augmentation)
    fit_train_dataset.fitted_schema.add_augmentation("image", image_rotation_augmentation)

    # ##### Prepare the Model #######

    # 1. Create the required torch models
    input_size = fit_train_dataset.fitted_schema.input_size[1:]  # 28 x 28
    target_size = fit_train_dataset.fitted_schema.target_size[1]  # 10

    print(f"input_size: {input_size} - target_size: {target_size}")

    feature_extractor = MnistCNN(output_size=128)
    # The task learner maps the 128 extracted features to class probabilities
    # (softmax), matching the probability-based cross-entropy loss used below.
    task_learner = MLP(
        input_size=128,
        activation_layer=partial(torch.nn.ReLU),
        flatten_input=True,
        hidden_channels=[target_size],
        last_activation=partial(torch.nn.Softmax, dim=1),
    )

    # 2. Explainable Model Specifications
    model_specifications = ModelDecisionGraphParameters(
        feature_extraction_output_type=FeatureExtractionOutputType.VECTOR,
        balancing_weight=0.1,  # This incentivizes the explainable model to provide more decisions.
    )

    # 3. Create the Explainable Model
    xpdeep_model = XpdeepModel.from_torch(
        example_dataset=fit_train_dataset,
        feature_extraction=feature_extractor,
        task_learner=task_learner,
        backbone=None,
        decision_graph_parameters=model_specifications,
    )

    # ##### Train #######

    # Metrics to monitor the training.
    metrics = DictMetrics(
        global_multi_class_accuracy=TorchGlobalMetric(
            partial(MulticlassAccuracy, num_classes=target_size, average="micro"), target_as_indexes=True
        ),
        leaf_multi_class_accuracy=TorchLeafMetric(
            partial(MulticlassAccuracy, num_classes=target_size, average="micro"), target_as_indexes=True
        ),
        global_multi_class_F1_score=TorchGlobalMetric(
            partial(MulticlassF1Score, num_classes=target_size, average="macro"), target_as_indexes=True
        ),
        leaf_multi_class_F1_score=TorchLeafMetric(
            partial(MulticlassF1Score, num_classes=target_size, average="macro"), target_as_indexes=True
        ),
        global_confusion_matrix=TorchGlobalMetric(
            partial(MulticlassConfusionMatrix, num_classes=target_size, normalize="all"), target_as_indexes=True
        ),
        leaf_confusion_matrix=TorchLeafMetric(
            partial(MulticlassConfusionMatrix, num_classes=target_size, normalize="all"), target_as_indexes=True
        ),
    )

    # NOTE(review): with max_epochs=6 below, a patience of 10 likely means early
    # stopping never triggers — confirm this is intended.
    callbacks = [
        EarlyStopping(monitoring_metric="global_multi_class_accuracy", mode="maximize", patience=10),
        Scheduler(
            pre_scheduler=partial(ReduceLROnPlateau, mode="min", patience=1),
            step_method="epoch",
            monitoring_metric="Total loss",
        ),
        ModelCheckpoint(monitoring_metric="global_multi_class_accuracy", mode="maximize"),
    ]

    # The optimizer is a partial because the model parameters are not known yet;
    # they are presumably bound by the trainer when it instantiates the optimizer.
    optimizer = partial(torch.optim.AdamW, lr=0.001, foreach=False, fused=False)

    loss = XpdeepLoss.from_torch(
        loss=CrossEntropyLossFromProbabilities(reduction="none"),
        model=xpdeep_model,
        fitted_schema=fit_train_dataset.fitted_schema,
    )

    trainer = Trainer(
        loss=loss,
        optimizer=optimizer,
        callbacks=callbacks,
        start_epoch=0,
        max_epochs=6,
        metrics=metrics,
    )

    trained_model = trainer.train(
        model=xpdeep_model,
        train_set=fit_train_dataset,
        validation_set=fit_val_dataset,
        batch_size=128,
    )

    # ##### Explain #######

    # 1. Build the Explainer
    statistics = DictStats(
        distribution_target=DistributionStat(on="target"), distribution_prediction=DistributionStat(on="prediction")
    )
    quality_metrics = [Sensitivity(), Infidelity()]

    explainer = Explainer(quality_metrics=quality_metrics, metrics=metrics, statistics=statistics)

    # 2. Model Functioning Explanations
    model_explanations = explainer.global_explain(
        trained_model,
        train_set=fit_train_dataset,
        test_set=fit_test_dataset,
        validation_set=fit_val_dataset,
    )

    print(model_explanations.visualisation_link)

    # 3. Inference and their Causal Explanations
    # Restrict the causal explanations to test samples whose label is 1 or 2.
    my_filter = Filter("testing_filter", fit_test_dataset)
    my_filter.add_criteria(
        CategoricalCriterion(fit_test_dataset.fitted_schema["label"], categories=[1, 2]),
    )

    causal_explanations = explainer.local_explain(trained_model, fit_test_dataset, my_filter)
    print(causal_explanations.visualisation_link)


if __name__ == "__main__":
    # Initialise the Xpdeep client from environment-provided credentials.
    init(api_url=os.getenv("API_URL"), api_key=os.getenv("API_KEY"))

    # Run the whole workflow inside a dedicated project.
    tutorial_project = Project.create_or_get(name="MNIST Tutorial Review")
    set_project(tutorial_project)

    try:
        main()
    finally:
        # Delete the current project whether or not the workflow succeeded.
        get_project().delete()

5. Fit the schema

With the schema analyzed on the train set, you can now fit it: each feature's preprocessor is fitted on the train set.

# Fit the preprocessors of the analyzed schema on the train split.
fit_train_dataset = analyzed_train_dataset.fit()

We use a deep copy of the same FittedSchema to create a FittedParquetDataset for each of the validation and test sets. These copies ensure that data augmentation applied later to the training set schema does not unintentionally modify the fitted schemas of the validation and test sets.

from xpdeep.dataset.parquet_dataset import FittedParquetDataset
from copy import deepcopy

# Each evaluation split receives its own deep copy of the train-fitted schema, so
# augmentations added later to the train schema do not modify these schemas.
# NOTE(review): S3_DATASET_BUCKET_NAME and STORAGE_OPTIONS are presumably defined
# in an earlier step of this tutorial.
fit_test_dataset = FittedParquetDataset(
    name="mnist_test_dataset",
    path=f"s3://{S3_DATASET_BUCKET_NAME}/mnist/test.parquet",
    storage_options=STORAGE_OPTIONS,
    fitted_schema=deepcopy(fit_train_dataset.fitted_schema)
)

fit_val_dataset = FittedParquetDataset(
    name="mnist_validation_dataset",
    path=f"s3://{S3_DATASET_BUCKET_NAME}/mnist/val.parquet",
    storage_options=STORAGE_OPTIONS,
    fitted_schema=deepcopy(fit_train_dataset.fitted_schema)
)
👀 Full file preview
"""MNIST workflow, classification with image data."""

import os
from copy import deepcopy
from functools import partial

import boto3
import pyarrow as pa
import pyarrow.parquet as pq
import torch
from botocore.client import Config
from datasets import DatasetDict, load_dataset
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchmetrics.classification import MulticlassAccuracy, MulticlassConfusionMatrix, MulticlassF1Score
from torchvision.transforms import RandomRotation

from xpdeep import init, set_project
from xpdeep.dataset.feature.augmentation.augmentation import FeatureAugmentation
from xpdeep.dataset.parquet_dataset import FittedParquetDataset, ParquetDataset
from xpdeep.explain.explainer import Explainer
from xpdeep.explain.quality_metrics import Infidelity, Sensitivity
from xpdeep.explain.statistic import DictStats, DistributionStat
from xpdeep.filtering.criteria import CategoricalCriterion
from xpdeep.filtering.filter import Filter
from xpdeep.metric import DictMetrics, TorchGlobalMetric, TorchLeafMetric
from xpdeep.model.feature_extraction_output_type import FeatureExtractionOutputType
from xpdeep.model.model_parameters import ModelDecisionGraphParameters
from xpdeep.model.xpdeep_model import XpdeepModel
from xpdeep.model.zoo.cross_entropy_loss_from_proba import CrossEntropyLossFromProbabilities
from xpdeep.model.zoo.doc import MnistCNN
from xpdeep.model.zoo.mlp import MLP
from xpdeep.project import Project, get_project
from xpdeep.trainer.callbacks import EarlyStopping, ModelCheckpoint, Scheduler
from xpdeep.trainer.loss import XpdeepLoss
from xpdeep.trainer.trainer import Trainer

# Storage options handed to every parquet dataset of this workflow (presumably
# forwarded to an fsspec/s3fs filesystem). Credentials and endpoint are read from
# the environment; a missing variable yields None.
STORAGE_OPTIONS = dict(
    key=os.environ.get("S3_DATASET_ACCESS_KEY_ID"),
    secret=os.environ.get("S3_DATASET_SECRET_ACCESS_KEY"),
    client_kwargs=dict(endpoint_url=os.environ.get("S3_DATASET_ENDPOINT_URL")),
    # NOTE(review): in s3fs, `s3_additional_kwargs` are per-request API parameters;
    # path-style addressing is usually configured through
    # `config_kwargs={"s3": {"addressing_style": "path"}}`. Confirm this key has the
    # intended effect for whatever consumes these options.
    s3_additional_kwargs=dict(addressing_style="path"),
)


def main():
    """Process the dataset, train, and explain the model.

    End-to-end MNIST workflow:

    1. Load MNIST from the HuggingFace hub, split the original test split in half
       (stratified on ``label``) into validation and test splits, save every split
       as a local ``.parquet`` file and upload it to the S3 bucket named by the
       ``S3_DATASET_BUCKET_NAME`` environment variable.
    2. Analyze and fit a schema on the train split, give the validation and test
       splits independent deep copies of the fitted schema, and register a
       random-rotation augmentation on the train schema only.
    3. Build an explainable model (``MnistCNN`` feature extractor + ``MLP`` task
       learner), train it, print the link to the model functioning explanations,
       then print the link to the causal explanations of test samples labelled
       1 or 2.

    Raises:
        RuntimeError: If ``S3_DATASET_BUCKET_NAME`` is unset or empty.
    """
    torch.random.manual_seed(5)

    # Resolve the bucket once so the uploads and every dataset path agree, and fail
    # before the dataset download rather than later on a ``None`` bucket.
    bucket_name = os.getenv("S3_DATASET_BUCKET_NAME")
    if not bucket_name:
        raise RuntimeError("The S3_DATASET_BUCKET_NAME environment variable must be set.")
    s3_prefix = f"s3://{bucket_name}/mnist"
    split_names = ("train", "val", "test")

    # ##### Prepare the Dataset #######

    # 1. Split and Convert your Raw Data
    # Load the dataset from HuggingFace datasets hub for convenience.
    dataset = load_dataset("mnist", trust_remote_code=True)

    # Half of the original test split becomes the validation set; stratifying on
    # "label" keeps the class proportions of both halves aligned.
    test_eval = dataset["test"].train_test_split(test_size=0.5, stratify_by_column="label", seed=1225)

    splits = DatasetDict({
        "train": dataset["train"],
        "val": test_eval["train"],  # HuggingFace requires the "train" keyword.
        "test": test_eval["test"],
    })

    # Convert each split to a pyarrow Table and save it as a ".parquet" file.
    # preserve_index=True stores the pandas index as the "__index_level_0__" column,
    # which the schema analyzer later detects as index metadata.
    for split_name in split_names:
        table = pa.Table.from_pandas(splits[split_name].to_pandas(), preserve_index=True)
        pq.write_table(table, f"{split_name}.parquet")

    # 2. Upload your Converted Data
    client = boto3.client(
        service_name="s3",
        endpoint_url=os.getenv("S3_DATASET_ENDPOINT_URL"),
        aws_access_key_id=os.getenv("S3_DATASET_ACCESS_KEY_ID"),
        aws_secret_access_key=os.getenv("S3_DATASET_SECRET_ACCESS_KEY"),
        config=Config(signature_version="s3v4"),
    )

    for split_name in split_names:
        client.upload_file(f"{split_name}.parquet", bucket_name, f"mnist/{split_name}.parquet")

    # 3. Instantiate a Dataset
    train_dataset = ParquetDataset(
        name="mnist_train_set",
        path=f"{s3_prefix}/train.parquet",
        storage_options=STORAGE_OPTIONS,
    )

    # 4. Find a schema
    analyzed_train_dataset = train_dataset.analyze(target_names=["label"])

    # 5. Fit the schema
    fit_train_dataset = analyzed_train_dataset.fit()

    # The test and validation sets each get their own deep copy of the fitted schema,
    # so the augmentation added to the train schema below cannot leak into them.
    fit_test_dataset = FittedParquetDataset(
        name="mnist_test_set",
        path=f"{s3_prefix}/test.parquet",
        storage_options=STORAGE_OPTIONS,
        fitted_schema=deepcopy(fit_train_dataset.fitted_schema),
    )

    fit_val_dataset = FittedParquetDataset(
        name="mnist_validation_set",
        path=f"{s3_prefix}/val.parquet",
        storage_options=STORAGE_OPTIONS,
        fitted_schema=deepcopy(fit_train_dataset.fitted_schema),
    )

    # 6. Add augmentation on the train set fitted schema (optional), to ensure augmentation won't be applied on test
    # and validation sets.
    # Here as images are grayscale, we don't need to permute the channel dimension to get the
    # channel dimension first.
    # For an RGB image, we could have used Compose([Permute([0, 3, 1, 2]), RandomRotation(90), Permute([0, 2, 3, 1])])
    # RandomRotation(90) samples a rotation angle in [-90, 90] degrees.
    # NOTE(review): the rotation is registered for both raw and preprocessed images;
    # confirm that augmenting the raw images is intended.
    augmentation = RandomRotation(90)
    image_rotation_augmentation = FeatureAugmentation(augment_raw=augmentation, augment_preprocessed=augmentation)
    fit_train_dataset.fitted_schema.add_augmentation("image", image_rotation_augmentation)

    # ##### Prepare the Model #######

    # 1. Create the required torch models
    input_size = fit_train_dataset.fitted_schema.input_size[1:]  # 28 x 28
    target_size = fit_train_dataset.fitted_schema.target_size[1]  # 10

    print(f"input_size: {input_size} - target_size: {target_size}")

    feature_extractor = MnistCNN(output_size=128)
    # The task learner maps the 128 extracted features to class probabilities
    # (softmax), matching the probability-based cross-entropy loss used below.
    task_learner = MLP(
        input_size=128,
        activation_layer=partial(torch.nn.ReLU),
        flatten_input=True,
        hidden_channels=[target_size],
        last_activation=partial(torch.nn.Softmax, dim=1),
    )

    # 2. Explainable Model Specifications
    model_specifications = ModelDecisionGraphParameters(
        feature_extraction_output_type=FeatureExtractionOutputType.VECTOR,
        balancing_weight=0.1,  # This incentivizes the explainable model to provide more decisions.
    )

    # 3. Create the Explainable Model
    xpdeep_model = XpdeepModel.from_torch(
        example_dataset=fit_train_dataset,
        feature_extraction=feature_extractor,
        task_learner=task_learner,
        backbone=None,
        decision_graph_parameters=model_specifications,
    )

    # ##### Train #######

    # Metrics to monitor the training.
    metrics = DictMetrics(
        global_multi_class_accuracy=TorchGlobalMetric(
            partial(MulticlassAccuracy, num_classes=target_size, average="micro"), target_as_indexes=True
        ),
        leaf_multi_class_accuracy=TorchLeafMetric(
            partial(MulticlassAccuracy, num_classes=target_size, average="micro"), target_as_indexes=True
        ),
        global_multi_class_F1_score=TorchGlobalMetric(
            partial(MulticlassF1Score, num_classes=target_size, average="macro"), target_as_indexes=True
        ),
        leaf_multi_class_F1_score=TorchLeafMetric(
            partial(MulticlassF1Score, num_classes=target_size, average="macro"), target_as_indexes=True
        ),
        global_confusion_matrix=TorchGlobalMetric(
            partial(MulticlassConfusionMatrix, num_classes=target_size, normalize="all"), target_as_indexes=True
        ),
        leaf_confusion_matrix=TorchLeafMetric(
            partial(MulticlassConfusionMatrix, num_classes=target_size, normalize="all"), target_as_indexes=True
        ),
    )

    # NOTE(review): with max_epochs=6 below, a patience of 10 likely means early
    # stopping never triggers — confirm this is intended.
    callbacks = [
        EarlyStopping(monitoring_metric="global_multi_class_accuracy", mode="maximize", patience=10),
        Scheduler(
            pre_scheduler=partial(ReduceLROnPlateau, mode="min", patience=1),
            step_method="epoch",
            monitoring_metric="Total loss",
        ),
        ModelCheckpoint(monitoring_metric="global_multi_class_accuracy", mode="maximize"),
    ]

    # The optimizer is a partial because the model parameters are not known yet;
    # they are presumably bound by the trainer when it instantiates the optimizer.
    optimizer = partial(torch.optim.AdamW, lr=0.001, foreach=False, fused=False)

    loss = XpdeepLoss.from_torch(
        loss=CrossEntropyLossFromProbabilities(reduction="none"),
        model=xpdeep_model,
        fitted_schema=fit_train_dataset.fitted_schema,
    )

    trainer = Trainer(
        loss=loss,
        optimizer=optimizer,
        callbacks=callbacks,
        start_epoch=0,
        max_epochs=6,
        metrics=metrics,
    )

    trained_model = trainer.train(
        model=xpdeep_model,
        train_set=fit_train_dataset,
        validation_set=fit_val_dataset,
        batch_size=128,
    )

    # ##### Explain #######

    # 1. Build the Explainer
    statistics = DictStats(
        distribution_target=DistributionStat(on="target"), distribution_prediction=DistributionStat(on="prediction")
    )
    quality_metrics = [Sensitivity(), Infidelity()]

    explainer = Explainer(quality_metrics=quality_metrics, metrics=metrics, statistics=statistics)

    # 2. Model Functioning Explanations
    model_explanations = explainer.global_explain(
        trained_model,
        train_set=fit_train_dataset,
        test_set=fit_test_dataset,
        validation_set=fit_val_dataset,
    )

    print(model_explanations.visualisation_link)

    # 3. Inference and their Causal Explanations
    # Restrict the causal explanations to test samples whose label is 1 or 2.
    my_filter = Filter("testing_filter", fit_test_dataset)
    my_filter.add_criteria(
        CategoricalCriterion(fit_test_dataset.fitted_schema["label"], categories=[1, 2]),
    )

    causal_explanations = explainer.local_explain(trained_model, fit_test_dataset, my_filter)
    print(causal_explanations.visualisation_link)


if __name__ == "__main__":
    # Initialise the Xpdeep client from environment-provided credentials.
    init(api_url=os.getenv("API_URL"), api_key=os.getenv("API_KEY"))

    # Run the whole workflow inside a dedicated project.
    tutorial_project = Project.create_or_get(name="MNIST Tutorial Review")
    set_project(tutorial_project)

    try:
        main()
    finally:
        # Delete the current project whether or not the workflow succeeded.
        get_project().delete()

6. Define data augmentation methods

We add data augmentation to the preprocessed images of the train set by applying a random rotation with an angle drawn between -90° and +90° (RandomRotation(90)).

# RandomRotation(90) rotates each image by a random angle in [-90, 90] degrees.
augmentation = RandomRotation(90)
# Only the preprocessed images get the rotation; augment_raw=None configures no raw-image augmentation.
image_rotation_augmentation = FeatureAugmentation(augment_raw=None, augment_preprocessed=augmentation)

# Register the augmentation on the train schema only; the validation and test sets
# hold their own deep copies of the fitted schema and therefore stay unaugmented.
fit_train_dataset.fitted_schema.add_augmentation("image", image_rotation_augmentation)
👀 Full file preview
"""MNIST workflow, classification with image data."""

import os
from copy import deepcopy
from functools import partial

import boto3
import pyarrow as pa
import pyarrow.parquet as pq
import torch
from botocore.client import Config
from datasets import DatasetDict, load_dataset
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchmetrics.classification import MulticlassAccuracy, MulticlassConfusionMatrix, MulticlassF1Score
from torchvision.transforms import RandomRotation

from xpdeep import init, set_project
from xpdeep.dataset.feature.augmentation.augmentation import FeatureAugmentation
from xpdeep.dataset.parquet_dataset import FittedParquetDataset, ParquetDataset
from xpdeep.explain.explainer import Explainer
from xpdeep.explain.quality_metrics import Infidelity, Sensitivity
from xpdeep.explain.statistic import DictStats, DistributionStat
from xpdeep.filtering.criteria import CategoricalCriterion
from xpdeep.filtering.filter import Filter
from xpdeep.metric import DictMetrics, TorchGlobalMetric, TorchLeafMetric
from xpdeep.model.feature_extraction_output_type import FeatureExtractionOutputType
from xpdeep.model.model_parameters import ModelDecisionGraphParameters
from xpdeep.model.xpdeep_model import XpdeepModel
from xpdeep.model.zoo.cross_entropy_loss_from_proba import CrossEntropyLossFromProbabilities
from xpdeep.model.zoo.doc import MnistCNN
from xpdeep.model.zoo.mlp import MLP
from xpdeep.project import Project, get_project
from xpdeep.trainer.callbacks import EarlyStopping, ModelCheckpoint, Scheduler
from xpdeep.trainer.loss import XpdeepLoss
from xpdeep.trainer.trainer import Trainer

# Storage options passed to ParquetDataset / FittedParquetDataset when reading the "s3://"
# parquet splits. Credentials and endpoint are read from the same environment variables
# as the boto3 client used for uploading.
# NOTE(review): with s3fs, "s3_additional_kwargs" is forwarded to S3 API calls (and
# filtered to the parameters each call accepts); path-style addressing is normally set via
# config_kwargs={"s3": {"addressing_style": "path"}} — confirm this entry has an effect.
STORAGE_OPTIONS = dict(
    key=os.getenv("S3_DATASET_ACCESS_KEY_ID"),
    secret=os.getenv("S3_DATASET_SECRET_ACCESS_KEY"),
    client_kwargs=dict(endpoint_url=os.getenv("S3_DATASET_ENDPOINT_URL")),
    s3_additional_kwargs=dict(addressing_style="path"),
)


def main():
    """Process the dataset, train, and explain the model.

    Workflow:

    1. Download MNIST from the HuggingFace hub, split the original test set in half
       (stratified on ``label``, fixed seed) into validation and test sets, save every split
       as a local ``<split>.parquet`` file in the current working directory and upload it to
       ``s3://<bucket>/mnist/``.
    2. Analyze and fit a schema on the train split; the validation and test splits reuse deep
       copies of that fitted schema. A random-rotation augmentation is then registered on the
       train schema only.
    3. Build an ``XpdeepModel`` (``MnistCNN`` feature extractor + ``MLP`` task learner) and
       train it for at most 6 epochs.
    4. Compute the global (model functioning) and local (causal) explanations and print their
       visualisation links.

    The bucket name comes from the ``S3_DATASET_BUCKET_NAME`` environment variable; the S3
    endpoint and credentials come from ``S3_DATASET_ENDPOINT_URL``,
    ``S3_DATASET_ACCESS_KEY_ID`` and ``S3_DATASET_SECRET_ACCESS_KEY``.

    Raises:
        RuntimeError: If ``S3_DATASET_BUCKET_NAME`` is unset or empty.
    """
    torch.random.manual_seed(5)

    # Resolve the bucket once and fail fast: an unset variable would otherwise produce
    # "s3://None/mnist/..." dataset paths and only fail later, inside boto3.
    bucket_name = os.getenv("S3_DATASET_BUCKET_NAME")
    if not bucket_name:
        msg = "The S3_DATASET_BUCKET_NAME environment variable must be set."
        raise RuntimeError(msg)
    dataset_prefix = f"s3://{bucket_name}/mnist"

    # ##### Prepare the Dataset #######

    # 1. Split and Convert your Raw Data
    # Load the dataset from HuggingFace datasets hub for convenience.
    dataset = load_dataset("mnist", trust_remote_code=True)

    # Half of the original test set becomes the validation set, the other half the test set.
    test_eval = dataset["test"].train_test_split(test_size=0.5, stratify_by_column="label", seed=1225)

    splits = DatasetDict({
        "train": dataset["train"],
        "val": test_eval["train"],  # HuggingFace requires the "train" keyword.
        "test": test_eval["test"],
    })

    # 2. Upload your Converted Data
    client = boto3.client(
        service_name="s3",
        endpoint_url=os.getenv("S3_DATASET_ENDPOINT_URL"),
        aws_access_key_id=os.getenv("S3_DATASET_ACCESS_KEY_ID"),
        aws_secret_access_key=os.getenv("S3_DATASET_SECRET_ACCESS_KEY"),
        config=Config(signature_version="s3v4"),
    )

    # Convert each split to a pyarrow Table, save it as a local ".parquet" file and upload it.
    # Handling one split at a time avoids keeping all three tables in memory at once.
    for split_name, split in splits.items():
        file_name = f"{split_name}.parquet"
        table = pa.Table.from_pandas(split.to_pandas(), preserve_index=True)
        pq.write_table(table, file_name)
        client.upload_file(file_name, bucket_name, f"mnist/{file_name}")

    # 3. Instantiate a Dataset
    train_dataset = ParquetDataset(
        name="mnist_train_set",
        path=f"{dataset_prefix}/train.parquet",
        storage_options=STORAGE_OPTIONS,
    )

    # 4. Find a schema
    analyzed_train_dataset = train_dataset.analyze(target_names=["label"])

    # 5. Fit the schema
    fit_train_dataset = analyzed_train_dataset.fit()

    # Validation and test sets reuse independent copies of the schema fitted on the train set,
    # so later changes to the train schema (e.g. the augmentation below) do not affect them.
    fit_test_dataset = FittedParquetDataset(
        name="mnist_test_set",
        path=f"{dataset_prefix}/test.parquet",
        storage_options=STORAGE_OPTIONS,
        fitted_schema=deepcopy(fit_train_dataset.fitted_schema),
    )

    fit_val_dataset = FittedParquetDataset(
        name="mnist_validation_set",
        path=f"{dataset_prefix}/val.parquet",
        storage_options=STORAGE_OPTIONS,
        fitted_schema=deepcopy(fit_train_dataset.fitted_schema),
    )

    # 6. Add augmentation on the train set fitted schema (optional). The schema copies above were
    # made beforehand, so the augmentation won't be applied on the test and validation sets.
    # Here as images are grayscale, we don't need to permute the channel dimension to get the
    # channel dimension first.
    # For an RGB image, we could have used Compose([Permute([0, 3, 1, 2]), RandomRotation(90), Permute([0, 2, 3, 1])])
    # NOTE(review): the step-by-step tutorial snippet uses augment_raw=None (preprocessed data
    # only); confirm which variant is intended.
    augmentation = RandomRotation(90)
    image_rotation_augmentation = FeatureAugmentation(augment_raw=augmentation, augment_preprocessed=augmentation)
    fit_train_dataset.fitted_schema.add_augmentation("image", image_rotation_augmentation)

    # ##### Prepare the Model #######

    # 1. Create the required torch models
    # Schema sizes include the batch dimension at index 0.
    input_size = fit_train_dataset.fitted_schema.input_size[1:]  # 28 x 28
    target_size = fit_train_dataset.fitted_schema.target_size[1]  # 10

    print(f"input_size: {input_size} - target_size: {target_size}")

    feature_extractor = MnistCNN(output_size=128)
    task_learner = MLP(
        input_size=128,
        activation_layer=partial(torch.nn.ReLU),
        flatten_input=True,
        hidden_channels=[target_size],
        last_activation=partial(torch.nn.Softmax, dim=1),
    )

    # 2. Explainable Model Specifications
    model_specifications = ModelDecisionGraphParameters(
        feature_extraction_output_type=FeatureExtractionOutputType.VECTOR,
        balancing_weight=0.1,  # This incentivizes the explainable model to provide more decisions.
    )

    # 3. Create the Explainable Model
    xpdeep_model = XpdeepModel.from_torch(
        example_dataset=fit_train_dataset,
        feature_extraction=feature_extractor,
        task_learner=task_learner,
        backbone=None,
        decision_graph_parameters=model_specifications,
    )

    # ##### Train #######

    # Metrics to monitor the training.
    metrics = DictMetrics(
        global_multi_class_accuracy=TorchGlobalMetric(
            partial(MulticlassAccuracy, num_classes=target_size, average="micro"), target_as_indexes=True
        ),
        leaf_multi_class_accuracy=TorchLeafMetric(
            partial(MulticlassAccuracy, num_classes=target_size, average="micro"), target_as_indexes=True
        ),
        global_multi_class_F1_score=TorchGlobalMetric(
            partial(MulticlassF1Score, num_classes=target_size, average="macro"), target_as_indexes=True
        ),
        leaf_multi_class_F1_score=TorchLeafMetric(
            partial(MulticlassF1Score, num_classes=target_size, average="macro"), target_as_indexes=True
        ),
        global_confusion_matrix=TorchGlobalMetric(
            partial(MulticlassConfusionMatrix, num_classes=target_size, normalize="all"), target_as_indexes=True
        ),
        leaf_confusion_matrix=TorchLeafMetric(
            partial(MulticlassConfusionMatrix, num_classes=target_size, normalize="all"), target_as_indexes=True
        ),
    )

    callbacks = [
        # NOTE: patience (presumably counted in epochs) is 10 while max_epochs is 6 below, so
        # early stopping cannot trigger in this run; it only matters for longer trainings.
        EarlyStopping(monitoring_metric="global_multi_class_accuracy", mode="maximize", patience=10),
        # Reduce the learning rate when the total loss stops decreasing (patience=1).
        Scheduler(
            pre_scheduler=partial(ReduceLROnPlateau, mode="min", patience=1),
            step_method="epoch",
            monitoring_metric="Total loss",
        ),
        ModelCheckpoint(monitoring_metric="global_multi_class_accuracy", mode="maximize"),
    ]

    # Optimizer is a partial object as pytorch needs to give the model as optimizer parameter.
    optimizer = partial(torch.optim.AdamW, lr=0.001, foreach=False, fused=False)

    loss = XpdeepLoss.from_torch(
        loss=CrossEntropyLossFromProbabilities(reduction="none"),
        model=xpdeep_model,
        fitted_schema=fit_train_dataset.fitted_schema,
    )

    trainer = Trainer(
        loss=loss,
        optimizer=optimizer,
        callbacks=callbacks,
        start_epoch=0,
        max_epochs=6,
        metrics=metrics,
    )

    trained_model = trainer.train(
        model=xpdeep_model,
        train_set=fit_train_dataset,
        validation_set=fit_val_dataset,
        batch_size=128,
    )

    # ##### Explain #######

    # 1. Build the Explainer
    statistics = DictStats(
        distribution_target=DistributionStat(on="target"), distribution_prediction=DistributionStat(on="prediction")
    )
    quality_metrics = [Sensitivity(), Infidelity()]

    explainer = Explainer(quality_metrics=quality_metrics, metrics=metrics, statistics=statistics)

    # 2. Model Functioning Explanations
    model_explanations = explainer.global_explain(
        trained_model,
        train_set=fit_train_dataset,
        test_set=fit_test_dataset,
        validation_set=fit_val_dataset,
    )

    print(model_explanations.visualisation_link)

    # 3. Inference and their Causal Explanations
    # Restrict the local explanations to the test samples labelled 1 or 2.
    my_filter = Filter("testing_filter", fit_test_dataset)
    my_filter.add_criteria(
        CategoricalCriterion(fit_test_dataset.fitted_schema["label"], categories=[1, 2]),
    )

    causal_explanations = explainer.local_explain(trained_model, fit_test_dataset, my_filter)
    print(causal_explanations.visualisation_link)


if __name__ == "__main__":
    # Initialise the Xpdeep client with the API URL and key read from the environment.
    init(api_url=os.getenv("API_URL"), api_key=os.getenv("API_KEY"))

    # Run the whole tutorial inside a dedicated project.
    tutorial_project = Project.create_or_get(name="MNIST Tutorial Review")
    set_project(tutorial_project)

    try:
        main()
    finally:
        # The project is deleted whether main() succeeds or raises.
        get_project().delete()

And that's all for the dataset preparation. We now have three FittedParquetDataset objects, each with its own FittedSchema, ready to be used.

Prepare the Model#

We now need to create an explainable model, XpdeepModel.

1. Create the required torch models#

We have a multi-class classification task with image input data. We will use a Multi-Layer Perceptron (MLP) as the task learner, in combination with a CNN feature extractor.

Tip

Model input and output sizes (including the batch dimension) can be easily retrieved from the fitted schema.

# Schema sizes include the batch dimension at index 0, so it is dropped here.
# The input keeps every remaining dimension (the full image shape), hence the slice.
input_size = fit_train_dataset.fitted_schema.input_size[1:]  # (28, 28)
target_size = fit_train_dataset.fitted_schema.target_size[1]  # 10 classes

print(f"input_size: {input_size} - target_size: {target_size}")
input_size: (28, 28) - target_size: 10

Therefore:

  • The FeatureExtractionModel will embed input data into a 128-dimensional space using a CNN with a series of residual blocks.
  • The TaskLearnerModel will use a Softmax output layer for the 10 output classes.
from functools import partial

import torch

from xpdeep.model.zoo.doc import MnistCNN
from xpdeep.model.zoo.mlp import MLP

# CNN feature extractor producing a 128-dimensional embedding per image.
feature_extractor = MnistCNN(output_size=128)

# MLP head mapping the 128-d embedding to `target_size` outputs, normalised with a
# softmax over dimension 1 (one probability per class).
task_learner = MLP(
    input_size=128,
    hidden_channels=[target_size],
    flatten_input=True,
    activation_layer=partial(torch.nn.ReLU),
    last_activation=partial(torch.nn.Softmax, dim=1),
)
👀 Full file preview
"""MNIST workflow, classification with image data."""

import os
from copy import deepcopy
from functools import partial

import boto3
import pyarrow as pa
import pyarrow.parquet as pq
import torch
from botocore.client import Config
from datasets import DatasetDict, load_dataset
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchmetrics.classification import MulticlassAccuracy, MulticlassConfusionMatrix, MulticlassF1Score
from torchvision.transforms import RandomRotation

from xpdeep import init, set_project
from xpdeep.dataset.feature.augmentation.augmentation import FeatureAugmentation
from xpdeep.dataset.parquet_dataset import FittedParquetDataset, ParquetDataset
from xpdeep.explain.explainer import Explainer
from xpdeep.explain.quality_metrics import Infidelity, Sensitivity
from xpdeep.explain.statistic import DictStats, DistributionStat
from xpdeep.filtering.criteria import CategoricalCriterion
from xpdeep.filtering.filter import Filter
from xpdeep.metric import DictMetrics, TorchGlobalMetric, TorchLeafMetric
from xpdeep.model.feature_extraction_output_type import FeatureExtractionOutputType
from xpdeep.model.model_parameters import ModelDecisionGraphParameters
from xpdeep.model.xpdeep_model import XpdeepModel
from xpdeep.model.zoo.cross_entropy_loss_from_proba import CrossEntropyLossFromProbabilities
from xpdeep.model.zoo.doc import MnistCNN
from xpdeep.model.zoo.mlp import MLP
from xpdeep.project import Project, get_project
from xpdeep.trainer.callbacks import EarlyStopping, ModelCheckpoint, Scheduler
from xpdeep.trainer.loss import XpdeepLoss
from xpdeep.trainer.trainer import Trainer

# Storage options passed to ParquetDataset / FittedParquetDataset when reading the "s3://"
# parquet splits. Credentials and endpoint are read from the same environment variables
# as the boto3 client used for uploading.
# NOTE(review): with s3fs, "s3_additional_kwargs" is forwarded to S3 API calls (and
# filtered to the parameters each call accepts); path-style addressing is normally set via
# config_kwargs={"s3": {"addressing_style": "path"}} — confirm this entry has an effect.
STORAGE_OPTIONS = dict(
    key=os.getenv("S3_DATASET_ACCESS_KEY_ID"),
    secret=os.getenv("S3_DATASET_SECRET_ACCESS_KEY"),
    client_kwargs=dict(endpoint_url=os.getenv("S3_DATASET_ENDPOINT_URL")),
    s3_additional_kwargs=dict(addressing_style="path"),
)


def main():
    """Process the dataset, train, and explain the model.

    Workflow:

    1. Download MNIST from the HuggingFace hub, split the original test set in half
       (stratified on ``label``, fixed seed) into validation and test sets, save every split
       as a local ``<split>.parquet`` file in the current working directory and upload it to
       ``s3://<bucket>/mnist/``.
    2. Analyze and fit a schema on the train split; the validation and test splits reuse deep
       copies of that fitted schema. A random-rotation augmentation is then registered on the
       train schema only.
    3. Build an ``XpdeepModel`` (``MnistCNN`` feature extractor + ``MLP`` task learner) and
       train it for at most 6 epochs.
    4. Compute the global (model functioning) and local (causal) explanations and print their
       visualisation links.

    The bucket name comes from the ``S3_DATASET_BUCKET_NAME`` environment variable; the S3
    endpoint and credentials come from ``S3_DATASET_ENDPOINT_URL``,
    ``S3_DATASET_ACCESS_KEY_ID`` and ``S3_DATASET_SECRET_ACCESS_KEY``.

    Raises:
        RuntimeError: If ``S3_DATASET_BUCKET_NAME`` is unset or empty.
    """
    torch.random.manual_seed(5)

    # Resolve the bucket once and fail fast: an unset variable would otherwise produce
    # "s3://None/mnist/..." dataset paths and only fail later, inside boto3.
    bucket_name = os.getenv("S3_DATASET_BUCKET_NAME")
    if not bucket_name:
        msg = "The S3_DATASET_BUCKET_NAME environment variable must be set."
        raise RuntimeError(msg)
    dataset_prefix = f"s3://{bucket_name}/mnist"

    # ##### Prepare the Dataset #######

    # 1. Split and Convert your Raw Data
    # Load the dataset from HuggingFace datasets hub for convenience.
    dataset = load_dataset("mnist", trust_remote_code=True)

    # Half of the original test set becomes the validation set, the other half the test set.
    test_eval = dataset["test"].train_test_split(test_size=0.5, stratify_by_column="label", seed=1225)

    splits = DatasetDict({
        "train": dataset["train"],
        "val": test_eval["train"],  # HuggingFace requires the "train" keyword.
        "test": test_eval["test"],
    })

    # 2. Upload your Converted Data
    client = boto3.client(
        service_name="s3",
        endpoint_url=os.getenv("S3_DATASET_ENDPOINT_URL"),
        aws_access_key_id=os.getenv("S3_DATASET_ACCESS_KEY_ID"),
        aws_secret_access_key=os.getenv("S3_DATASET_SECRET_ACCESS_KEY"),
        config=Config(signature_version="s3v4"),
    )

    # Convert each split to a pyarrow Table, save it as a local ".parquet" file and upload it.
    # Handling one split at a time avoids keeping all three tables in memory at once.
    for split_name, split in splits.items():
        file_name = f"{split_name}.parquet"
        table = pa.Table.from_pandas(split.to_pandas(), preserve_index=True)
        pq.write_table(table, file_name)
        client.upload_file(file_name, bucket_name, f"mnist/{file_name}")

    # 3. Instantiate a Dataset
    train_dataset = ParquetDataset(
        name="mnist_train_set",
        path=f"{dataset_prefix}/train.parquet",
        storage_options=STORAGE_OPTIONS,
    )

    # 4. Find a schema
    analyzed_train_dataset = train_dataset.analyze(target_names=["label"])

    # 5. Fit the schema
    fit_train_dataset = analyzed_train_dataset.fit()

    # Validation and test sets reuse independent copies of the schema fitted on the train set,
    # so later changes to the train schema (e.g. the augmentation below) do not affect them.
    fit_test_dataset = FittedParquetDataset(
        name="mnist_test_set",
        path=f"{dataset_prefix}/test.parquet",
        storage_options=STORAGE_OPTIONS,
        fitted_schema=deepcopy(fit_train_dataset.fitted_schema),
    )

    fit_val_dataset = FittedParquetDataset(
        name="mnist_validation_set",
        path=f"{dataset_prefix}/val.parquet",
        storage_options=STORAGE_OPTIONS,
        fitted_schema=deepcopy(fit_train_dataset.fitted_schema),
    )

    # 6. Add augmentation on the train set fitted schema (optional). The schema copies above were
    # made beforehand, so the augmentation won't be applied on the test and validation sets.
    # Here as images are grayscale, we don't need to permute the channel dimension to get the
    # channel dimension first.
    # For an RGB image, we could have used Compose([Permute([0, 3, 1, 2]), RandomRotation(90), Permute([0, 2, 3, 1])])
    # NOTE(review): the step-by-step tutorial snippet uses augment_raw=None (preprocessed data
    # only); confirm which variant is intended.
    augmentation = RandomRotation(90)
    image_rotation_augmentation = FeatureAugmentation(augment_raw=augmentation, augment_preprocessed=augmentation)
    fit_train_dataset.fitted_schema.add_augmentation("image", image_rotation_augmentation)

    # ##### Prepare the Model #######

    # 1. Create the required torch models
    # Schema sizes include the batch dimension at index 0.
    input_size = fit_train_dataset.fitted_schema.input_size[1:]  # 28 x 28
    target_size = fit_train_dataset.fitted_schema.target_size[1]  # 10

    print(f"input_size: {input_size} - target_size: {target_size}")

    feature_extractor = MnistCNN(output_size=128)
    task_learner = MLP(
        input_size=128,
        activation_layer=partial(torch.nn.ReLU),
        flatten_input=True,
        hidden_channels=[target_size],
        last_activation=partial(torch.nn.Softmax, dim=1),
    )

    # 2. Explainable Model Specifications
    model_specifications = ModelDecisionGraphParameters(
        feature_extraction_output_type=FeatureExtractionOutputType.VECTOR,
        balancing_weight=0.1,  # This incentivizes the explainable model to provide more decisions.
    )

    # 3. Create the Explainable Model
    xpdeep_model = XpdeepModel.from_torch(
        example_dataset=fit_train_dataset,
        feature_extraction=feature_extractor,
        task_learner=task_learner,
        backbone=None,
        decision_graph_parameters=model_specifications,
    )

    # ##### Train #######

    # Metrics to monitor the training.
    metrics = DictMetrics(
        global_multi_class_accuracy=TorchGlobalMetric(
            partial(MulticlassAccuracy, num_classes=target_size, average="micro"), target_as_indexes=True
        ),
        leaf_multi_class_accuracy=TorchLeafMetric(
            partial(MulticlassAccuracy, num_classes=target_size, average="micro"), target_as_indexes=True
        ),
        global_multi_class_F1_score=TorchGlobalMetric(
            partial(MulticlassF1Score, num_classes=target_size, average="macro"), target_as_indexes=True
        ),
        leaf_multi_class_F1_score=TorchLeafMetric(
            partial(MulticlassF1Score, num_classes=target_size, average="macro"), target_as_indexes=True
        ),
        global_confusion_matrix=TorchGlobalMetric(
            partial(MulticlassConfusionMatrix, num_classes=target_size, normalize="all"), target_as_indexes=True
        ),
        leaf_confusion_matrix=TorchLeafMetric(
            partial(MulticlassConfusionMatrix, num_classes=target_size, normalize="all"), target_as_indexes=True
        ),
    )

    callbacks = [
        # NOTE: patience (presumably counted in epochs) is 10 while max_epochs is 6 below, so
        # early stopping cannot trigger in this run; it only matters for longer trainings.
        EarlyStopping(monitoring_metric="global_multi_class_accuracy", mode="maximize", patience=10),
        # Reduce the learning rate when the total loss stops decreasing (patience=1).
        Scheduler(
            pre_scheduler=partial(ReduceLROnPlateau, mode="min", patience=1),
            step_method="epoch",
            monitoring_metric="Total loss",
        ),
        ModelCheckpoint(monitoring_metric="global_multi_class_accuracy", mode="maximize"),
    ]

    # Optimizer is a partial object as pytorch needs to give the model as optimizer parameter.
    optimizer = partial(torch.optim.AdamW, lr=0.001, foreach=False, fused=False)

    loss = XpdeepLoss.from_torch(
        loss=CrossEntropyLossFromProbabilities(reduction="none"),
        model=xpdeep_model,
        fitted_schema=fit_train_dataset.fitted_schema,
    )

    trainer = Trainer(
        loss=loss,
        optimizer=optimizer,
        callbacks=callbacks,
        start_epoch=0,
        max_epochs=6,
        metrics=metrics,
    )

    trained_model = trainer.train(
        model=xpdeep_model,
        train_set=fit_train_dataset,
        validation_set=fit_val_dataset,
        batch_size=128,
    )

    # ##### Explain #######

    # 1. Build the Explainer
    statistics = DictStats(
        distribution_target=DistributionStat(on="target"), distribution_prediction=DistributionStat(on="prediction")
    )
    quality_metrics = [Sensitivity(), Infidelity()]

    explainer = Explainer(quality_metrics=quality_metrics, metrics=metrics, statistics=statistics)

    # 2. Model Functioning Explanations
    model_explanations = explainer.global_explain(
        trained_model,
        train_set=fit_train_dataset,
        test_set=fit_test_dataset,
        validation_set=fit_val_dataset,
    )

    print(model_explanations.visualisation_link)

    # 3. Inference and their Causal Explanations
    # Restrict the local explanations to the test samples labelled 1 or 2.
    my_filter = Filter("testing_filter", fit_test_dataset)
    my_filter.add_criteria(
        CategoricalCriterion(fit_test_dataset.fitted_schema["label"], categories=[1, 2]),
    )

    causal_explanations = explainer.local_explain(trained_model, fit_test_dataset, my_filter)
    print(causal_explanations.visualisation_link)


if __name__ == "__main__":
    # Initialise the Xpdeep client with the API URL and key read from the environment.
    init(api_url=os.getenv("API_URL"), api_key=os.getenv("API_KEY"))

    # Run the whole tutorial inside a dedicated project.
    tutorial_project = Project.create_or_get(name="MNIST Tutorial Review")
    set_project(tutorial_project)

    try:
        main()
    finally:
        # The project is deleted whether main() succeeds or raises.
        get_project().delete()

2. Explainable Model Specifications#

Here comes the crucial part: we need to set the model specifications in ModelDecisionGraphParameters to get the best explanations (Model Decision Graph and Inference Graph).

from xpdeep.model.model_parameters import ModelDecisionGraphParameters
from xpdeep.model.feature_extraction_output_type import FeatureExtractionOutputType

# Explainability settings of the model decision graph.
model_specifications = ModelDecisionGraphParameters(
    balancing_weight=0.1,  # Incentivizes the explainable model to provide more decisions.
    feature_extraction_output_type=FeatureExtractionOutputType.VECTOR,  # Extractor outputs a vector.
)
👀 Full file preview
"""MNIST workflow, classification with image data."""

import os
from copy import deepcopy
from functools import partial

import boto3
import pyarrow as pa
import pyarrow.parquet as pq
import torch
from botocore.client import Config
from datasets import DatasetDict, load_dataset
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchmetrics.classification import MulticlassAccuracy, MulticlassConfusionMatrix, MulticlassF1Score
from torchvision.transforms import RandomRotation

from xpdeep import init, set_project
from xpdeep.dataset.feature.augmentation.augmentation import FeatureAugmentation
from xpdeep.dataset.parquet_dataset import FittedParquetDataset, ParquetDataset
from xpdeep.explain.explainer import Explainer
from xpdeep.explain.quality_metrics import Infidelity, Sensitivity
from xpdeep.explain.statistic import DictStats, DistributionStat
from xpdeep.filtering.criteria import CategoricalCriterion
from xpdeep.filtering.filter import Filter
from xpdeep.metric import DictMetrics, TorchGlobalMetric, TorchLeafMetric
from xpdeep.model.feature_extraction_output_type import FeatureExtractionOutputType
from xpdeep.model.model_parameters import ModelDecisionGraphParameters
from xpdeep.model.xpdeep_model import XpdeepModel
from xpdeep.model.zoo.cross_entropy_loss_from_proba import CrossEntropyLossFromProbabilities
from xpdeep.model.zoo.doc import MnistCNN
from xpdeep.model.zoo.mlp import MLP
from xpdeep.project import Project, get_project
from xpdeep.trainer.callbacks import EarlyStopping, ModelCheckpoint, Scheduler
from xpdeep.trainer.loss import XpdeepLoss
from xpdeep.trainer.trainer import Trainer

# Storage options passed to ParquetDataset / FittedParquetDataset when reading the "s3://"
# parquet splits. Credentials and endpoint are read from the same environment variables
# as the boto3 client used for uploading.
# NOTE(review): with s3fs, "s3_additional_kwargs" is forwarded to S3 API calls (and
# filtered to the parameters each call accepts); path-style addressing is normally set via
# config_kwargs={"s3": {"addressing_style": "path"}} — confirm this entry has an effect.
STORAGE_OPTIONS = dict(
    key=os.getenv("S3_DATASET_ACCESS_KEY_ID"),
    secret=os.getenv("S3_DATASET_SECRET_ACCESS_KEY"),
    client_kwargs=dict(endpoint_url=os.getenv("S3_DATASET_ENDPOINT_URL")),
    s3_additional_kwargs=dict(addressing_style="path"),
)


def main():
    """Process the dataset, train, and explain the model.

    Workflow:

    1. Download MNIST from the HuggingFace hub, split the original test set in half
       (stratified on ``label``, fixed seed) into validation and test sets, save every split
       as a local ``<split>.parquet`` file in the current working directory and upload it to
       ``s3://<bucket>/mnist/``.
    2. Analyze and fit a schema on the train split; the validation and test splits reuse deep
       copies of that fitted schema. A random-rotation augmentation is then registered on the
       train schema only.
    3. Build an ``XpdeepModel`` (``MnistCNN`` feature extractor + ``MLP`` task learner) and
       train it for at most 6 epochs.
    4. Compute the global (model functioning) and local (causal) explanations and print their
       visualisation links.

    The bucket name comes from the ``S3_DATASET_BUCKET_NAME`` environment variable; the S3
    endpoint and credentials come from ``S3_DATASET_ENDPOINT_URL``,
    ``S3_DATASET_ACCESS_KEY_ID`` and ``S3_DATASET_SECRET_ACCESS_KEY``.

    Raises:
        RuntimeError: If ``S3_DATASET_BUCKET_NAME`` is unset or empty.
    """
    torch.random.manual_seed(5)

    # Resolve the bucket once and fail fast: an unset variable would otherwise produce
    # "s3://None/mnist/..." dataset paths and only fail later, inside boto3.
    bucket_name = os.getenv("S3_DATASET_BUCKET_NAME")
    if not bucket_name:
        msg = "The S3_DATASET_BUCKET_NAME environment variable must be set."
        raise RuntimeError(msg)
    dataset_prefix = f"s3://{bucket_name}/mnist"

    # ##### Prepare the Dataset #######

    # 1. Split and Convert your Raw Data
    # Load the dataset from HuggingFace datasets hub for convenience.
    dataset = load_dataset("mnist", trust_remote_code=True)

    # Half of the original test set becomes the validation set, the other half the test set.
    test_eval = dataset["test"].train_test_split(test_size=0.5, stratify_by_column="label", seed=1225)

    splits = DatasetDict({
        "train": dataset["train"],
        "val": test_eval["train"],  # HuggingFace requires the "train" keyword.
        "test": test_eval["test"],
    })

    # 2. Upload your Converted Data
    client = boto3.client(
        service_name="s3",
        endpoint_url=os.getenv("S3_DATASET_ENDPOINT_URL"),
        aws_access_key_id=os.getenv("S3_DATASET_ACCESS_KEY_ID"),
        aws_secret_access_key=os.getenv("S3_DATASET_SECRET_ACCESS_KEY"),
        config=Config(signature_version="s3v4"),
    )

    # Convert each split to a pyarrow Table, save it as a local ".parquet" file and upload it.
    # Handling one split at a time avoids keeping all three tables in memory at once.
    for split_name, split in splits.items():
        file_name = f"{split_name}.parquet"
        table = pa.Table.from_pandas(split.to_pandas(), preserve_index=True)
        pq.write_table(table, file_name)
        client.upload_file(file_name, bucket_name, f"mnist/{file_name}")

    # 3. Instantiate a Dataset
    train_dataset = ParquetDataset(
        name="mnist_train_set",
        path=f"{dataset_prefix}/train.parquet",
        storage_options=STORAGE_OPTIONS,
    )

    # 4. Find a schema
    analyzed_train_dataset = train_dataset.analyze(target_names=["label"])

    # 5. Fit the schema
    fit_train_dataset = analyzed_train_dataset.fit()

    # Validation and test sets reuse independent copies of the schema fitted on the train set,
    # so later changes to the train schema (e.g. the augmentation below) do not affect them.
    fit_test_dataset = FittedParquetDataset(
        name="mnist_test_set",
        path=f"{dataset_prefix}/test.parquet",
        storage_options=STORAGE_OPTIONS,
        fitted_schema=deepcopy(fit_train_dataset.fitted_schema),
    )

    fit_val_dataset = FittedParquetDataset(
        name="mnist_validation_set",
        path=f"{dataset_prefix}/val.parquet",
        storage_options=STORAGE_OPTIONS,
        fitted_schema=deepcopy(fit_train_dataset.fitted_schema),
    )

    # 6. Add augmentation on the train set fitted schema (optional). The schema copies above were
    # made beforehand, so the augmentation won't be applied on the test and validation sets.
    # Here as images are grayscale, we don't need to permute the channel dimension to get the
    # channel dimension first.
    # For an RGB image, we could have used Compose([Permute([0, 3, 1, 2]), RandomRotation(90), Permute([0, 2, 3, 1])])
    # NOTE(review): the step-by-step tutorial snippet uses augment_raw=None (preprocessed data
    # only); confirm which variant is intended.
    augmentation = RandomRotation(90)
    image_rotation_augmentation = FeatureAugmentation(augment_raw=augmentation, augment_preprocessed=augmentation)
    fit_train_dataset.fitted_schema.add_augmentation("image", image_rotation_augmentation)

    # ##### Prepare the Model #######

    # 1. Create the required torch models
    # Schema sizes include the batch dimension at index 0.
    input_size = fit_train_dataset.fitted_schema.input_size[1:]  # 28 x 28
    target_size = fit_train_dataset.fitted_schema.target_size[1]  # 10

    print(f"input_size: {input_size} - target_size: {target_size}")

    feature_extractor = MnistCNN(output_size=128)
    task_learner = MLP(
        input_size=128,
        activation_layer=partial(torch.nn.ReLU),
        flatten_input=True,
        hidden_channels=[target_size],
        last_activation=partial(torch.nn.Softmax, dim=1),
    )

    # 2. Explainable Model Specifications
    model_specifications = ModelDecisionGraphParameters(
        feature_extraction_output_type=FeatureExtractionOutputType.VECTOR,
        balancing_weight=0.1,  # This incentivizes the explainable model to provide more decisions.
    )

    # 3. Create the Explainable Model
    xpdeep_model = XpdeepModel.from_torch(
        example_dataset=fit_train_dataset,
        feature_extraction=feature_extractor,
        task_learner=task_learner,
        backbone=None,
        decision_graph_parameters=model_specifications,
    )

    # ##### Train #######

    # Metrics to monitor the training.
    metrics = DictMetrics(
        global_multi_class_accuracy=TorchGlobalMetric(
            partial(MulticlassAccuracy, num_classes=target_size, average="micro"), target_as_indexes=True
        ),
        leaf_multi_class_accuracy=TorchLeafMetric(
            partial(MulticlassAccuracy, num_classes=target_size, average="micro"), target_as_indexes=True
        ),
        global_multi_class_F1_score=TorchGlobalMetric(
            partial(MulticlassF1Score, num_classes=target_size, average="macro"), target_as_indexes=True
        ),
        leaf_multi_class_F1_score=TorchLeafMetric(
            partial(MulticlassF1Score, num_classes=target_size, average="macro"), target_as_indexes=True
        ),
        global_confusion_matrix=TorchGlobalMetric(
            partial(MulticlassConfusionMatrix, num_classes=target_size, normalize="all"), target_as_indexes=True
        ),
        leaf_confusion_matrix=TorchLeafMetric(
            partial(MulticlassConfusionMatrix, num_classes=target_size, normalize="all"), target_as_indexes=True
        ),
    )

    callbacks = [
        # NOTE: patience (presumably counted in epochs) is 10 while max_epochs is 6 below, so
        # early stopping cannot trigger in this run; it only matters for longer trainings.
        EarlyStopping(monitoring_metric="global_multi_class_accuracy", mode="maximize", patience=10),
        # Reduce the learning rate when the total loss stops decreasing (patience=1).
        Scheduler(
            pre_scheduler=partial(ReduceLROnPlateau, mode="min", patience=1),
            step_method="epoch",
            monitoring_metric="Total loss",
        ),
        ModelCheckpoint(monitoring_metric="global_multi_class_accuracy", mode="maximize"),
    ]

    # Optimizer is a partial object as pytorch needs to give the model as optimizer parameter.
    optimizer = partial(torch.optim.AdamW, lr=0.001, foreach=False, fused=False)

    loss = XpdeepLoss.from_torch(
        loss=CrossEntropyLossFromProbabilities(reduction="none"),
        model=xpdeep_model,
        fitted_schema=fit_train_dataset.fitted_schema,
    )

    trainer = Trainer(
        loss=loss,
        optimizer=optimizer,
        callbacks=callbacks,
        start_epoch=0,
        max_epochs=6,
        metrics=metrics,
    )

    trained_model = trainer.train(
        model=xpdeep_model,
        train_set=fit_train_dataset,
        validation_set=fit_val_dataset,
        batch_size=128,
    )

    # ##### Explain #######

    # 1. Build the Explainer
    statistics = DictStats(
        distribution_target=DistributionStat(on="target"), distribution_prediction=DistributionStat(on="prediction")
    )
    quality_metrics = [Sensitivity(), Infidelity()]

    explainer = Explainer(quality_metrics=quality_metrics, metrics=metrics, statistics=statistics)

    # 2. Model Functioning Explanations
    model_explanations = explainer.global_explain(
        trained_model,
        train_set=fit_train_dataset,
        test_set=fit_test_dataset,
        validation_set=fit_val_dataset,
    )

    print(model_explanations.visualisation_link)

    # 3. Inference and their Causal Explanations
    # Restrict the local explanations to the test samples labelled 1 or 2.
    my_filter = Filter("testing_filter", fit_test_dataset)
    my_filter.add_criteria(
        CategoricalCriterion(fit_test_dataset.fitted_schema["label"], categories=[1, 2]),
    )

    causal_explanations = explainer.local_explain(trained_model, fit_test_dataset, my_filter)
    print(causal_explanations.visualisation_link)


if __name__ == "__main__":
    # Initialise the Xpdeep client with the API URL and key read from the environment.
    init(api_url=os.getenv("API_URL"), api_key=os.getenv("API_KEY"))

    # Run the whole tutorial inside a dedicated project.
    tutorial_project = Project.create_or_get(name="MNIST Tutorial Review")
    set_project(tutorial_project)

    try:
        main()
    finally:
        # The project is deleted whether main() succeeds or raises.
        get_project().delete()

For further details, see the docs.

Note

All parameters have default values: you can start with the defaults, then iterate and update the configuration to find suitable explanations.

3. Create the Explainable Model#

Given the model architecture and configuration, we can finally instantiate the explainable model XpdeepModel.

from xpdeep.model.xpdeep_model import XpdeepModel

# Assemble the explainable model from the torch sub-models and the decision graph specifications;
# the fitted train set is given as an example dataset.
xpdeep_model = XpdeepModel.from_torch(
    feature_extraction=feature_extractor,
    task_learner=task_learner,
    backbone=None,  # no backbone model is used here
    decision_graph_parameters=model_specifications,
    example_dataset=fit_train_dataset,
)
👀 Full file preview
"""MNIST workflow, classification with image data."""

import os
from copy import deepcopy
from functools import partial

import boto3
import pyarrow as pa
import pyarrow.parquet as pq
import torch
from botocore.client import Config
from datasets import DatasetDict, load_dataset
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchmetrics.classification import MulticlassAccuracy, MulticlassConfusionMatrix, MulticlassF1Score
from torchvision.transforms import RandomRotation

from xpdeep import init, set_project
from xpdeep.dataset.feature.augmentation.augmentation import FeatureAugmentation
from xpdeep.dataset.parquet_dataset import FittedParquetDataset, ParquetDataset
from xpdeep.explain.explainer import Explainer
from xpdeep.explain.quality_metrics import Infidelity, Sensitivity
from xpdeep.explain.statistic import DictStats, DistributionStat
from xpdeep.filtering.criteria import CategoricalCriterion
from xpdeep.filtering.filter import Filter
from xpdeep.metric import DictMetrics, TorchGlobalMetric, TorchLeafMetric
from xpdeep.model.feature_extraction_output_type import FeatureExtractionOutputType
from xpdeep.model.model_parameters import ModelDecisionGraphParameters
from xpdeep.model.xpdeep_model import XpdeepModel
from xpdeep.model.zoo.cross_entropy_loss_from_proba import CrossEntropyLossFromProbabilities
from xpdeep.model.zoo.doc import MnistCNN
from xpdeep.model.zoo.mlp import MLP
from xpdeep.project import Project, get_project
from xpdeep.trainer.callbacks import EarlyStopping, ModelCheckpoint, Scheduler
from xpdeep.trainer.loss import XpdeepLoss
from xpdeep.trainer.trainer import Trainer

# Storage options passed to every (Fitted)ParquetDataset below to read the splits from S3.
# Credentials and endpoint come from the environment (each value is None when its variable is unset).
STORAGE_OPTIONS = {
    "key": os.getenv("S3_DATASET_ACCESS_KEY_ID"),
    "secret": os.getenv("S3_DATASET_SECRET_ACCESS_KEY"),
    "client_kwargs": {"endpoint_url": os.getenv("S3_DATASET_ENDPOINT_URL")},
    # Requests path-style addressing. NOTE(review): s3fs usually takes this via config_kwargs; confirm it is honored.
    "s3_additional_kwargs": {"addressing_style": "path"},
}


def main() -> None:
    """Process the dataset, train, and explain the model.

    End-to-end MNIST tutorial workflow:

    1. Load MNIST from the HuggingFace hub, split its original test set 50/50 (stratified on ``label``)
       into validation and test splits, write every split to a local ``.parquet`` file and upload it
       to the S3 bucket named by the ``S3_DATASET_BUCKET_NAME`` environment variable.
    2. Analyze and fit a schema on the train split, give the validation and test splits deep copies of
       that fitted schema, then add a random-rotation augmentation to the train schema only.
    3. Build an ``XpdeepModel`` from a CNN feature extractor and an MLP task learner and train it.
    4. Compute global (model functioning) explanations and local (causal) explanations on a filtered
       subset of the test split, printing the visualisation link of each.

    Side effects: writes ``train.parquet``, ``val.parquet`` and ``test.parquet`` in the current working
    directory (this function does not delete them) and uploads them under the ``mnist/`` key prefix.
    """
    # Seed torch's global RNG so that runs are reproducible.
    torch.random.manual_seed(5)

    # ##### Prepare the Dataset #######

    # 1. Split and Convert your Raw Data
    # Load the dataset from HuggingFace datasets hub for convenience.
    dataset = load_dataset("mnist", trust_remote_code=True)

    # Split the original test set in two equal halves, stratified on the label, with a fixed seed.
    test_eval = dataset["test"].train_test_split(test_size=0.5, stratify_by_column="label", seed=1225)

    splits = DatasetDict({
        "train": dataset["train"],
        "val": test_eval["train"],  # train_test_split always names its halves "train"/"test".
        "test": test_eval["test"],
    })

    # Convert to pyarrow Table format (preserve_index=True stores the pandas index as a column).
    train_table = pa.Table.from_pandas(splits["train"].to_pandas(), preserve_index=True)
    val_table = pa.Table.from_pandas(splits["val"].to_pandas(), preserve_index=True)
    test_table = pa.Table.from_pandas(splits["test"].to_pandas(), preserve_index=True)

    # Save each split as ".parquet" file, relative to the current working directory.
    pq.write_table(train_table, "train.parquet")
    pq.write_table(val_table, "val.parquet")
    pq.write_table(test_table, "test.parquet")

    # 2. Upload your Converted Data
    # S3 client built from the same environment variables as STORAGE_OPTIONS, with SigV4 request signing.
    client = boto3.client(
        service_name="s3",
        endpoint_url=os.getenv("S3_DATASET_ENDPOINT_URL"),
        aws_access_key_id=os.getenv("S3_DATASET_ACCESS_KEY_ID"),
        aws_secret_access_key=os.getenv("S3_DATASET_SECRET_ACCESS_KEY"),
        config=Config(signature_version="s3v4"),
    )

    # Upload each local file under the "mnist/" prefix of the dataset bucket.
    client.upload_file("train.parquet", os.getenv("S3_DATASET_BUCKET_NAME"), "mnist/train.parquet")
    client.upload_file("val.parquet", os.getenv("S3_DATASET_BUCKET_NAME"), "mnist/val.parquet")
    client.upload_file("test.parquet", os.getenv("S3_DATASET_BUCKET_NAME"), "mnist/test.parquet")

    # 3. Instantiate a Dataset
    # The train split is read back from S3, not from the local file.
    train_dataset = ParquetDataset(
        name="mnist_train_set",
        path=f"s3://{os.getenv('S3_DATASET_BUCKET_NAME')}/mnist/train.parquet",
        storage_options=STORAGE_OPTIONS,
    )

    # 4. Find a schema
    # Analyze the train split, declaring "label" as the target.
    analyzed_train_dataset = train_dataset.analyze(target_names=["label"])

    # 5. Fit the schema
    fit_train_dataset = analyzed_train_dataset.fit()

    # Validation and test splits reuse the schema fitted on the train split. Each receives its own deep copy,
    # so modifying the train schema afterwards (step 6) does not affect them.
    fit_test_dataset = FittedParquetDataset(
        name="mnist_test_set",
        path=f"s3://{os.getenv('S3_DATASET_BUCKET_NAME')}/mnist/test.parquet",
        storage_options=STORAGE_OPTIONS,
        fitted_schema=deepcopy(fit_train_dataset.fitted_schema),
    )

    fit_val_dataset = FittedParquetDataset(
        name="mnist_validation_set",
        path=f"s3://{os.getenv('S3_DATASET_BUCKET_NAME')}/mnist/val.parquet",
        storage_options=STORAGE_OPTIONS,
        fitted_schema=deepcopy(fit_train_dataset.fitted_schema),
    )

    # 6. Add augmentation on the train set fitted schema only (optional); the deep copies above ensure the
    # augmentation won't be applied on the test and validation sets.
    # Here as images are grayscale, we don't need to permute the channel dimension to get the
    # channel dimension first.
    # For an RGB image, we could have used Compose([Permute([0, 3, 1, 2]), RandomRotation(90), Permute([0, 2, 3, 1])])
    # RandomRotation(90) picks a random angle in [-90, 90] degrees; the same transform is used for both the raw
    # and the preprocessed representations of the "image" feature.
    augmentation = RandomRotation(90)
    image_rotation_augmentation = FeatureAugmentation(augment_raw=augmentation, augment_preprocessed=augmentation)
    fit_train_dataset.fitted_schema.add_augmentation("image", image_rotation_augmentation)

    # ##### Prepare the Model #######

    # 1. Create the required torch models
    # input_size is only used by the print below.
    input_size = fit_train_dataset.fitted_schema.input_size[1:]  # 28 x 28
    target_size = fit_train_dataset.fitted_schema.target_size[1]  # 10

    print(f"input_size: {input_size} - target_size: {target_size}")

    # The feature extractor's output_size (128) matches the task learner's input_size. The task learner ends
    # with a softmax over dim 1, so it outputs class probabilities.
    feature_extractor = MnistCNN(output_size=128)
    task_learner = MLP(
        input_size=128,
        activation_layer=partial(torch.nn.ReLU),
        flatten_input=True,
        hidden_channels=[target_size],
        last_activation=partial(torch.nn.Softmax, dim=1),
    )

    # 2. Explainable Model Specifications
    model_specifications = ModelDecisionGraphParameters(
        feature_extraction_output_type=FeatureExtractionOutputType.VECTOR,
        balancing_weight=0.1,  # This incentivizes the explainable model to provide more decisions.
    )

    # 3. Create the Explainable Model
    xpdeep_model = XpdeepModel.from_torch(
        example_dataset=fit_train_dataset,
        feature_extraction=feature_extractor,
        task_learner=task_learner,
        backbone=None,
        decision_graph_parameters=model_specifications,
    )

    # ##### Train #######

    # Metrics to monitor the training, each in a global and a leaf-level variant.
    # target_as_indexes=True converts the one-hot targets to class indexes, as torchmetrics expects.
    metrics = DictMetrics(
        global_multi_class_accuracy=TorchGlobalMetric(
            partial(MulticlassAccuracy, num_classes=target_size, average="micro"), target_as_indexes=True
        ),
        leaf_multi_class_accuracy=TorchLeafMetric(
            partial(MulticlassAccuracy, num_classes=target_size, average="micro"), target_as_indexes=True
        ),
        global_multi_class_F1_score=TorchGlobalMetric(
            partial(MulticlassF1Score, num_classes=target_size, average="macro"), target_as_indexes=True
        ),
        leaf_multi_class_F1_score=TorchLeafMetric(
            partial(MulticlassF1Score, num_classes=target_size, average="macro"), target_as_indexes=True
        ),
        global_confusion_matrix=TorchGlobalMetric(
            partial(MulticlassConfusionMatrix, num_classes=target_size, normalize="all"), target_as_indexes=True
        ),
        leaf_confusion_matrix=TorchLeafMetric(
            partial(MulticlassConfusionMatrix, num_classes=target_size, normalize="all"), target_as_indexes=True
        ),
    )

    # Early stopping and checkpointing both track the global accuracy (to maximize); the learning rate is
    # reduced when "Total loss" plateaus, with the scheduler stepped once per epoch.
    callbacks = [
        EarlyStopping(monitoring_metric="global_multi_class_accuracy", mode="maximize", patience=10),
        Scheduler(
            pre_scheduler=partial(ReduceLROnPlateau, mode="min", patience=1),
            step_method="epoch",
            monitoring_metric="Total loss",
        ),
        ModelCheckpoint(monitoring_metric="global_multi_class_accuracy", mode="maximize"),
    ]

    # Optimizer is a partial object as pytorch needs to give the model as optimizer parameter.
    # foreach and fused are disabled as they may currently lead to unstable training.
    optimizer = partial(torch.optim.AdamW, lr=0.001, foreach=False, fused=False)

    # Wrap the torch loss so that it is compatible with the explainable model's output format.
    loss = XpdeepLoss.from_torch(
        loss=CrossEntropyLossFromProbabilities(reduction="none"),
        model=xpdeep_model,
        fitted_schema=fit_train_dataset.fitted_schema,
    )

    trainer = Trainer(
        loss=loss,
        optimizer=optimizer,
        callbacks=callbacks,
        start_epoch=0,
        max_epochs=6,
        metrics=metrics,
    )

    trained_model = trainer.train(
        model=xpdeep_model,
        train_set=fit_train_dataset,
        validation_set=fit_val_dataset,
        batch_size=128,
    )

    # ##### Explain #######

    # 1. Build the Explainer
    # Target and prediction distributions are computed along with the explanations.
    statistics = DictStats(
        distribution_target=DistributionStat(on="target"), distribution_prediction=DistributionStat(on="prediction")
    )
    quality_metrics = [Sensitivity(), Infidelity()]

    explainer = Explainer(quality_metrics=quality_metrics, metrics=metrics, statistics=statistics)

    # 2. Model Functioning Explanations
    # Global explanations computed on the train, test and validation splits.
    model_explanations = explainer.global_explain(
        trained_model,
        train_set=fit_train_dataset,
        test_set=fit_test_dataset,
        validation_set=fit_val_dataset,
    )

    print(model_explanations.visualisation_link)

    # 3. Inference and their Causal Explanations
    # Restrict the local explanations to test samples whose "label" falls in the categories 1 and 2.
    my_filter = Filter("testing_filter", fit_test_dataset)
    my_filter.add_criteria(
        CategoricalCriterion(fit_test_dataset.fitted_schema["label"], categories=[1, 2]),
    )

    causal_explanations = explainer.local_explain(trained_model, fit_test_dataset, my_filter)
    print(causal_explanations.visualisation_link)


if __name__ == "__main__":
    # Connect to the Xpdeep API with credentials taken from the environment.
    api_key = os.getenv("API_KEY")
    api_url = os.getenv("API_URL")
    init(api_key=api_key, api_url=api_url)

    # Make the tutorial project the current one for every SDK call made by main().
    set_project(Project.create_or_get(name="MNIST Tutorial Review"))

    try:
        main()
    finally:
        # The project is deleted whether main() succeeded or raised.
        get_project().delete()

Train#

The train step is straightforward: we need to specify the Trainer parameters.

from functools import partial

import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchmetrics.classification import MulticlassAccuracy, MulticlassConfusionMatrix, MulticlassF1Score

from xpdeep.metric import DictMetrics, TorchGlobalMetric, TorchLeafMetric
from xpdeep.model.zoo.cross_entropy_loss_from_proba import CrossEntropyLossFromProbabilities
from xpdeep.trainer.callbacks import EarlyStopping, ModelCheckpoint, Scheduler
from xpdeep.trainer.loss import XpdeepLoss
from xpdeep.trainer.trainer import Trainer

# Metrics to monitor the training, each in a global and a leaf-level variant.
# target_as_indexes=True converts the one-hot targets to class indexes, as torchmetrics expects.
metrics = DictMetrics(
    global_multi_class_accuracy=TorchGlobalMetric(
        partial(MulticlassAccuracy, num_classes=target_size, average="micro"), target_as_indexes=True
    ),
    leaf_multi_class_accuracy=TorchLeafMetric(
        partial(MulticlassAccuracy, num_classes=target_size, average="micro"), target_as_indexes=True
    ),
    global_multi_class_F1_score=TorchGlobalMetric(
        partial(MulticlassF1Score, num_classes=target_size, average="macro"), target_as_indexes=True
    ),
    leaf_multi_class_F1_score=TorchLeafMetric(
        partial(MulticlassF1Score, num_classes=target_size, average="macro"), target_as_indexes=True
    ),
    global_confusion_matrix=TorchGlobalMetric(
        partial(MulticlassConfusionMatrix, num_classes=target_size, normalize="all"), target_as_indexes=True
    ),
    leaf_confusion_matrix=TorchLeafMetric(
        partial(MulticlassConfusionMatrix, num_classes=target_size, normalize="all"), target_as_indexes=True
    ),
)

# Early stopping and checkpointing both track the global accuracy (to maximize); the learning rate is
# reduced when "Total loss" plateaus, with the scheduler stepped once per epoch.
callbacks = [
    EarlyStopping(monitoring_metric="global_multi_class_accuracy", mode="maximize", patience=10),
    Scheduler(
        pre_scheduler=partial(ReduceLROnPlateau, mode="min", patience=1),
        step_method="epoch",
        monitoring_metric="Total loss",
    ),
    ModelCheckpoint(monitoring_metric="global_multi_class_accuracy", mode="maximize"),
]

# Optimizer is a partial object as pytorch needs to give the model as optimizer parameter.
# foreach and fused are disabled as they may currently lead to unstable training.
optimizer = partial(torch.optim.AdamW, lr=0.001, foreach=False, fused=False)

# The Trainer expects an XpdeepLoss: wrap the torch loss so that it is compatible with the
# explainable model's output format.
loss = XpdeepLoss.from_torch(
    loss=CrossEntropyLossFromProbabilities(reduction="none"),
    model=xpdeep_model,
    fitted_schema=fit_train_dataset.fitted_schema,
)

trainer = Trainer(
    loss=loss,
    optimizer=optimizer,
    callbacks=callbacks,
    start_epoch=0,
    max_epochs=6,
    metrics=metrics,
)
👀 Full file preview
"""MNIST workflow, classification with image data."""

import os
from copy import deepcopy
from functools import partial

import boto3
import pyarrow as pa
import pyarrow.parquet as pq
import torch
from botocore.client import Config
from datasets import DatasetDict, load_dataset
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchmetrics.classification import MulticlassAccuracy, MulticlassConfusionMatrix, MulticlassF1Score
from torchvision.transforms import RandomRotation

from xpdeep import init, set_project
from xpdeep.dataset.feature.augmentation.augmentation import FeatureAugmentation
from xpdeep.dataset.parquet_dataset import FittedParquetDataset, ParquetDataset
from xpdeep.explain.explainer import Explainer
from xpdeep.explain.quality_metrics import Infidelity, Sensitivity
from xpdeep.explain.statistic import DictStats, DistributionStat
from xpdeep.filtering.criteria import CategoricalCriterion
from xpdeep.filtering.filter import Filter
from xpdeep.metric import DictMetrics, TorchGlobalMetric, TorchLeafMetric
from xpdeep.model.feature_extraction_output_type import FeatureExtractionOutputType
from xpdeep.model.model_parameters import ModelDecisionGraphParameters
from xpdeep.model.xpdeep_model import XpdeepModel
from xpdeep.model.zoo.cross_entropy_loss_from_proba import CrossEntropyLossFromProbabilities
from xpdeep.model.zoo.doc import MnistCNN
from xpdeep.model.zoo.mlp import MLP
from xpdeep.project import Project, get_project
from xpdeep.trainer.callbacks import EarlyStopping, ModelCheckpoint, Scheduler
from xpdeep.trainer.loss import XpdeepLoss
from xpdeep.trainer.trainer import Trainer

# Storage options passed to every (Fitted)ParquetDataset below to read the splits from S3.
# Credentials and endpoint come from the environment (each value is None when its variable is unset).
STORAGE_OPTIONS = {
    "key": os.getenv("S3_DATASET_ACCESS_KEY_ID"),
    "secret": os.getenv("S3_DATASET_SECRET_ACCESS_KEY"),
    "client_kwargs": {"endpoint_url": os.getenv("S3_DATASET_ENDPOINT_URL")},
    # Requests path-style addressing. NOTE(review): s3fs usually takes this via config_kwargs; confirm it is honored.
    "s3_additional_kwargs": {"addressing_style": "path"},
}


def main() -> None:
    """Process the dataset, train, and explain the model.

    End-to-end MNIST tutorial workflow:

    1. Load MNIST from the HuggingFace hub, split its original test set 50/50 (stratified on ``label``)
       into validation and test splits, write every split to a local ``.parquet`` file and upload it
       to the S3 bucket named by the ``S3_DATASET_BUCKET_NAME`` environment variable.
    2. Analyze and fit a schema on the train split, give the validation and test splits deep copies of
       that fitted schema, then add a random-rotation augmentation to the train schema only.
    3. Build an ``XpdeepModel`` from a CNN feature extractor and an MLP task learner and train it.
    4. Compute global (model functioning) explanations and local (causal) explanations on a filtered
       subset of the test split, printing the visualisation link of each.

    Side effects: writes ``train.parquet``, ``val.parquet`` and ``test.parquet`` in the current working
    directory (this function does not delete them) and uploads them under the ``mnist/`` key prefix.
    """
    # Seed torch's global RNG so that runs are reproducible.
    torch.random.manual_seed(5)

    # ##### Prepare the Dataset #######

    # 1. Split and Convert your Raw Data
    # Load the dataset from HuggingFace datasets hub for convenience.
    dataset = load_dataset("mnist", trust_remote_code=True)

    # Split the original test set in two equal halves, stratified on the label, with a fixed seed.
    test_eval = dataset["test"].train_test_split(test_size=0.5, stratify_by_column="label", seed=1225)

    splits = DatasetDict({
        "train": dataset["train"],
        "val": test_eval["train"],  # train_test_split always names its halves "train"/"test".
        "test": test_eval["test"],
    })

    # Convert to pyarrow Table format (preserve_index=True stores the pandas index as a column).
    train_table = pa.Table.from_pandas(splits["train"].to_pandas(), preserve_index=True)
    val_table = pa.Table.from_pandas(splits["val"].to_pandas(), preserve_index=True)
    test_table = pa.Table.from_pandas(splits["test"].to_pandas(), preserve_index=True)

    # Save each split as ".parquet" file, relative to the current working directory.
    pq.write_table(train_table, "train.parquet")
    pq.write_table(val_table, "val.parquet")
    pq.write_table(test_table, "test.parquet")

    # 2. Upload your Converted Data
    # S3 client built from the same environment variables as STORAGE_OPTIONS, with SigV4 request signing.
    client = boto3.client(
        service_name="s3",
        endpoint_url=os.getenv("S3_DATASET_ENDPOINT_URL"),
        aws_access_key_id=os.getenv("S3_DATASET_ACCESS_KEY_ID"),
        aws_secret_access_key=os.getenv("S3_DATASET_SECRET_ACCESS_KEY"),
        config=Config(signature_version="s3v4"),
    )

    # Upload each local file under the "mnist/" prefix of the dataset bucket.
    client.upload_file("train.parquet", os.getenv("S3_DATASET_BUCKET_NAME"), "mnist/train.parquet")
    client.upload_file("val.parquet", os.getenv("S3_DATASET_BUCKET_NAME"), "mnist/val.parquet")
    client.upload_file("test.parquet", os.getenv("S3_DATASET_BUCKET_NAME"), "mnist/test.parquet")

    # 3. Instantiate a Dataset
    # The train split is read back from S3, not from the local file.
    train_dataset = ParquetDataset(
        name="mnist_train_set",
        path=f"s3://{os.getenv('S3_DATASET_BUCKET_NAME')}/mnist/train.parquet",
        storage_options=STORAGE_OPTIONS,
    )

    # 4. Find a schema
    # Analyze the train split, declaring "label" as the target.
    analyzed_train_dataset = train_dataset.analyze(target_names=["label"])

    # 5. Fit the schema
    fit_train_dataset = analyzed_train_dataset.fit()

    # Validation and test splits reuse the schema fitted on the train split. Each receives its own deep copy,
    # so modifying the train schema afterwards (step 6) does not affect them.
    fit_test_dataset = FittedParquetDataset(
        name="mnist_test_set",
        path=f"s3://{os.getenv('S3_DATASET_BUCKET_NAME')}/mnist/test.parquet",
        storage_options=STORAGE_OPTIONS,
        fitted_schema=deepcopy(fit_train_dataset.fitted_schema),
    )

    fit_val_dataset = FittedParquetDataset(
        name="mnist_validation_set",
        path=f"s3://{os.getenv('S3_DATASET_BUCKET_NAME')}/mnist/val.parquet",
        storage_options=STORAGE_OPTIONS,
        fitted_schema=deepcopy(fit_train_dataset.fitted_schema),
    )

    # 6. Add augmentation on the train set fitted schema only (optional); the deep copies above ensure the
    # augmentation won't be applied on the test and validation sets.
    # Here as images are grayscale, we don't need to permute the channel dimension to get the
    # channel dimension first.
    # For an RGB image, we could have used Compose([Permute([0, 3, 1, 2]), RandomRotation(90), Permute([0, 2, 3, 1])])
    # RandomRotation(90) picks a random angle in [-90, 90] degrees; the same transform is used for both the raw
    # and the preprocessed representations of the "image" feature.
    augmentation = RandomRotation(90)
    image_rotation_augmentation = FeatureAugmentation(augment_raw=augmentation, augment_preprocessed=augmentation)
    fit_train_dataset.fitted_schema.add_augmentation("image", image_rotation_augmentation)

    # ##### Prepare the Model #######

    # 1. Create the required torch models
    # input_size is only used by the print below.
    input_size = fit_train_dataset.fitted_schema.input_size[1:]  # 28 x 28
    target_size = fit_train_dataset.fitted_schema.target_size[1]  # 10

    print(f"input_size: {input_size} - target_size: {target_size}")

    # The feature extractor's output_size (128) matches the task learner's input_size. The task learner ends
    # with a softmax over dim 1, so it outputs class probabilities.
    feature_extractor = MnistCNN(output_size=128)
    task_learner = MLP(
        input_size=128,
        activation_layer=partial(torch.nn.ReLU),
        flatten_input=True,
        hidden_channels=[target_size],
        last_activation=partial(torch.nn.Softmax, dim=1),
    )

    # 2. Explainable Model Specifications
    model_specifications = ModelDecisionGraphParameters(
        feature_extraction_output_type=FeatureExtractionOutputType.VECTOR,
        balancing_weight=0.1,  # This incentivizes the explainable model to provide more decisions.
    )

    # 3. Create the Explainable Model
    xpdeep_model = XpdeepModel.from_torch(
        example_dataset=fit_train_dataset,
        feature_extraction=feature_extractor,
        task_learner=task_learner,
        backbone=None,
        decision_graph_parameters=model_specifications,
    )

    # ##### Train #######

    # Metrics to monitor the training, each in a global and a leaf-level variant.
    # target_as_indexes=True converts the one-hot targets to class indexes, as torchmetrics expects.
    metrics = DictMetrics(
        global_multi_class_accuracy=TorchGlobalMetric(
            partial(MulticlassAccuracy, num_classes=target_size, average="micro"), target_as_indexes=True
        ),
        leaf_multi_class_accuracy=TorchLeafMetric(
            partial(MulticlassAccuracy, num_classes=target_size, average="micro"), target_as_indexes=True
        ),
        global_multi_class_F1_score=TorchGlobalMetric(
            partial(MulticlassF1Score, num_classes=target_size, average="macro"), target_as_indexes=True
        ),
        leaf_multi_class_F1_score=TorchLeafMetric(
            partial(MulticlassF1Score, num_classes=target_size, average="macro"), target_as_indexes=True
        ),
        global_confusion_matrix=TorchGlobalMetric(
            partial(MulticlassConfusionMatrix, num_classes=target_size, normalize="all"), target_as_indexes=True
        ),
        leaf_confusion_matrix=TorchLeafMetric(
            partial(MulticlassConfusionMatrix, num_classes=target_size, normalize="all"), target_as_indexes=True
        ),
    )

    # Early stopping and checkpointing both track the global accuracy (to maximize); the learning rate is
    # reduced when "Total loss" plateaus, with the scheduler stepped once per epoch.
    callbacks = [
        EarlyStopping(monitoring_metric="global_multi_class_accuracy", mode="maximize", patience=10),
        Scheduler(
            pre_scheduler=partial(ReduceLROnPlateau, mode="min", patience=1),
            step_method="epoch",
            monitoring_metric="Total loss",
        ),
        ModelCheckpoint(monitoring_metric="global_multi_class_accuracy", mode="maximize"),
    ]

    # Optimizer is a partial object as pytorch needs to give the model as optimizer parameter.
    # foreach and fused are disabled as they may currently lead to unstable training.
    optimizer = partial(torch.optim.AdamW, lr=0.001, foreach=False, fused=False)

    # Wrap the torch loss so that it is compatible with the explainable model's output format.
    loss = XpdeepLoss.from_torch(
        loss=CrossEntropyLossFromProbabilities(reduction="none"),
        model=xpdeep_model,
        fitted_schema=fit_train_dataset.fitted_schema,
    )

    trainer = Trainer(
        loss=loss,
        optimizer=optimizer,
        callbacks=callbacks,
        start_epoch=0,
        max_epochs=6,
        metrics=metrics,
    )

    trained_model = trainer.train(
        model=xpdeep_model,
        train_set=fit_train_dataset,
        validation_set=fit_val_dataset,
        batch_size=128,
    )

    # ##### Explain #######

    # 1. Build the Explainer
    # Target and prediction distributions are computed along with the explanations.
    statistics = DictStats(
        distribution_target=DistributionStat(on="target"), distribution_prediction=DistributionStat(on="prediction")
    )
    quality_metrics = [Sensitivity(), Infidelity()]

    explainer = Explainer(quality_metrics=quality_metrics, metrics=metrics, statistics=statistics)

    # 2. Model Functioning Explanations
    # Global explanations computed on the train, test and validation splits.
    model_explanations = explainer.global_explain(
        trained_model,
        train_set=fit_train_dataset,
        test_set=fit_test_dataset,
        validation_set=fit_val_dataset,
    )

    print(model_explanations.visualisation_link)

    # 3. Inference and their Causal Explanations
    # Restrict the local explanations to test samples whose "label" falls in the categories 1 and 2.
    my_filter = Filter("testing_filter", fit_test_dataset)
    my_filter.add_criteria(
        CategoricalCriterion(fit_test_dataset.fitted_schema["label"], categories=[1, 2]),
    )

    causal_explanations = explainer.local_explain(trained_model, fit_test_dataset, my_filter)
    print(causal_explanations.visualisation_link)


if __name__ == "__main__":
    # Connect to the Xpdeep API with credentials taken from the environment.
    api_key = os.getenv("API_KEY")
    api_url = os.getenv("API_URL")
    init(api_key=api_key, api_url=api_url)

    # Make the tutorial project the current one for every SDK call made by main().
    set_project(Project.create_or_get(name="MNIST Tutorial Review"))

    try:
        main()
    finally:
        # The project is deleted whether main() succeeded or raised.
        get_project().delete()

Note

Here, the loss is a custom loss compatible with our output format, based on the default torch loss.

Note

For multiclass metrics, torchmetrics expects the target to be an index vector, not a one-hot vector. As our targets are one-hot vectors, we add the target_as_indexes parameter, which converts targets from one-hot vectors to indexes prior to the metric computation.

Warning

Here, we set foreach and fused to False, as enabling them may currently lead to unstable behaviour in the training process.

We can now train the model:

# Train the explainable model with mini-batches of 128 samples; the validation split is passed
# alongside the train split.
trained_model = trainer.train(
    model=xpdeep_model,
    batch_size=128,
    train_set=fit_train_dataset,
    validation_set=fit_val_dataset,
)

👀 Full file preview
"""MNIST workflow, classification with image data."""

import os
from copy import deepcopy
from functools import partial

import boto3
import pyarrow as pa
import pyarrow.parquet as pq
import torch
from botocore.client import Config
from datasets import DatasetDict, load_dataset
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchmetrics.classification import MulticlassAccuracy, MulticlassConfusionMatrix, MulticlassF1Score
from torchvision.transforms import RandomRotation

from xpdeep import init, set_project
from xpdeep.dataset.feature.augmentation.augmentation import FeatureAugmentation
from xpdeep.dataset.parquet_dataset import FittedParquetDataset, ParquetDataset
from xpdeep.explain.explainer import Explainer
from xpdeep.explain.quality_metrics import Infidelity, Sensitivity
from xpdeep.explain.statistic import DictStats, DistributionStat
from xpdeep.filtering.criteria import CategoricalCriterion
from xpdeep.filtering.filter import Filter
from xpdeep.metric import DictMetrics, TorchGlobalMetric, TorchLeafMetric
from xpdeep.model.feature_extraction_output_type import FeatureExtractionOutputType
from xpdeep.model.model_parameters import ModelDecisionGraphParameters
from xpdeep.model.xpdeep_model import XpdeepModel
from xpdeep.model.zoo.cross_entropy_loss_from_proba import CrossEntropyLossFromProbabilities
from xpdeep.model.zoo.doc import MnistCNN
from xpdeep.model.zoo.mlp import MLP
from xpdeep.project import Project, get_project
from xpdeep.trainer.callbacks import EarlyStopping, ModelCheckpoint, Scheduler
from xpdeep.trainer.loss import XpdeepLoss
from xpdeep.trainer.trainer import Trainer

# Storage options passed to every (Fitted)ParquetDataset below to read the splits from S3.
# Credentials and endpoint come from the environment (each value is None when its variable is unset).
STORAGE_OPTIONS = {
    "key": os.getenv("S3_DATASET_ACCESS_KEY_ID"),
    "secret": os.getenv("S3_DATASET_SECRET_ACCESS_KEY"),
    "client_kwargs": {"endpoint_url": os.getenv("S3_DATASET_ENDPOINT_URL")},
    # Requests path-style addressing. NOTE(review): s3fs usually takes this via config_kwargs; confirm it is honored.
    "s3_additional_kwargs": {"addressing_style": "path"},
}


def main() -> None:
    """Process the dataset, train, and explain the model.

    End-to-end MNIST tutorial workflow:

    1. Load MNIST from the HuggingFace hub, split its original test set 50/50 (stratified on ``label``)
       into validation and test splits, write every split to a local ``.parquet`` file and upload it
       to the S3 bucket named by the ``S3_DATASET_BUCKET_NAME`` environment variable.
    2. Analyze and fit a schema on the train split, give the validation and test splits deep copies of
       that fitted schema, then add a random-rotation augmentation to the train schema only.
    3. Build an ``XpdeepModel`` from a CNN feature extractor and an MLP task learner and train it.
    4. Compute global (model functioning) explanations and local (causal) explanations on a filtered
       subset of the test split, printing the visualisation link of each.

    Side effects: writes ``train.parquet``, ``val.parquet`` and ``test.parquet`` in the current working
    directory (this function does not delete them) and uploads them under the ``mnist/`` key prefix.
    """
    # Seed torch's global RNG so that runs are reproducible.
    torch.random.manual_seed(5)

    # ##### Prepare the Dataset #######

    # 1. Split and Convert your Raw Data
    # Load the dataset from HuggingFace datasets hub for convenience.
    dataset = load_dataset("mnist", trust_remote_code=True)

    # Split the original test set in two equal halves, stratified on the label, with a fixed seed.
    test_eval = dataset["test"].train_test_split(test_size=0.5, stratify_by_column="label", seed=1225)

    splits = DatasetDict({
        "train": dataset["train"],
        "val": test_eval["train"],  # train_test_split always names its halves "train"/"test".
        "test": test_eval["test"],
    })

    # Convert to pyarrow Table format (preserve_index=True stores the pandas index as a column).
    train_table = pa.Table.from_pandas(splits["train"].to_pandas(), preserve_index=True)
    val_table = pa.Table.from_pandas(splits["val"].to_pandas(), preserve_index=True)
    test_table = pa.Table.from_pandas(splits["test"].to_pandas(), preserve_index=True)

    # Save each split as ".parquet" file, relative to the current working directory.
    pq.write_table(train_table, "train.parquet")
    pq.write_table(val_table, "val.parquet")
    pq.write_table(test_table, "test.parquet")

    # 2. Upload your Converted Data
    # S3 client built from the same environment variables as STORAGE_OPTIONS, with SigV4 request signing.
    client = boto3.client(
        service_name="s3",
        endpoint_url=os.getenv("S3_DATASET_ENDPOINT_URL"),
        aws_access_key_id=os.getenv("S3_DATASET_ACCESS_KEY_ID"),
        aws_secret_access_key=os.getenv("S3_DATASET_SECRET_ACCESS_KEY"),
        config=Config(signature_version="s3v4"),
    )

    # Upload each local file under the "mnist/" prefix of the dataset bucket.
    client.upload_file("train.parquet", os.getenv("S3_DATASET_BUCKET_NAME"), "mnist/train.parquet")
    client.upload_file("val.parquet", os.getenv("S3_DATASET_BUCKET_NAME"), "mnist/val.parquet")
    client.upload_file("test.parquet", os.getenv("S3_DATASET_BUCKET_NAME"), "mnist/test.parquet")

    # 3. Instantiate a Dataset
    # The train split is read back from S3, not from the local file.
    train_dataset = ParquetDataset(
        name="mnist_train_set",
        path=f"s3://{os.getenv('S3_DATASET_BUCKET_NAME')}/mnist/train.parquet",
        storage_options=STORAGE_OPTIONS,
    )

    # 4. Find a schema
    # Analyze the train split, declaring "label" as the target.
    analyzed_train_dataset = train_dataset.analyze(target_names=["label"])

    # 5. Fit the schema
    fit_train_dataset = analyzed_train_dataset.fit()

    # Validation and test splits reuse the schema fitted on the train split. Each receives its own deep copy,
    # so modifying the train schema afterwards (step 6) does not affect them.
    fit_test_dataset = FittedParquetDataset(
        name="mnist_test_set",
        path=f"s3://{os.getenv('S3_DATASET_BUCKET_NAME')}/mnist/test.parquet",
        storage_options=STORAGE_OPTIONS,
        fitted_schema=deepcopy(fit_train_dataset.fitted_schema),
    )

    fit_val_dataset = FittedParquetDataset(
        name="mnist_validation_set",
        path=f"s3://{os.getenv('S3_DATASET_BUCKET_NAME')}/mnist/val.parquet",
        storage_options=STORAGE_OPTIONS,
        fitted_schema=deepcopy(fit_train_dataset.fitted_schema),
    )

    # 6. Add augmentation on the train set fitted schema only (optional); the deep copies above ensure the
    # augmentation won't be applied on the test and validation sets.
    # Here as images are grayscale, we don't need to permute the channel dimension to get the
    # channel dimension first.
    # For an RGB image, we could have used Compose([Permute([0, 3, 1, 2]), RandomRotation(90), Permute([0, 2, 3, 1])])
    # RandomRotation(90) picks a random angle in [-90, 90] degrees; the same transform is used for both the raw
    # and the preprocessed representations of the "image" feature.
    augmentation = RandomRotation(90)
    image_rotation_augmentation = FeatureAugmentation(augment_raw=augmentation, augment_preprocessed=augmentation)
    fit_train_dataset.fitted_schema.add_augmentation("image", image_rotation_augmentation)

    # ##### Prepare the Model #######

    # 1. Create the required torch models
    # input_size is only used by the print below.
    input_size = fit_train_dataset.fitted_schema.input_size[1:]  # 28 x 28
    target_size = fit_train_dataset.fitted_schema.target_size[1]  # 10

    print(f"input_size: {input_size} - target_size: {target_size}")

    # The feature extractor's output_size (128) matches the task learner's input_size. The task learner ends
    # with a softmax over dim 1, so it outputs class probabilities.
    feature_extractor = MnistCNN(output_size=128)
    task_learner = MLP(
        input_size=128,
        activation_layer=partial(torch.nn.ReLU),
        flatten_input=True,
        hidden_channels=[target_size],
        last_activation=partial(torch.nn.Softmax, dim=1),
    )

    # 2. Explainable Model Specifications
    model_specifications = ModelDecisionGraphParameters(
        feature_extraction_output_type=FeatureExtractionOutputType.VECTOR,
        balancing_weight=0.1,  # This incentivizes the explainable model to provide more decisions.
    )

    # 3. Create the Explainable Model
    xpdeep_model = XpdeepModel.from_torch(
        example_dataset=fit_train_dataset,
        feature_extraction=feature_extractor,
        task_learner=task_learner,
        backbone=None,
        decision_graph_parameters=model_specifications,
    )

    # ##### Train #######

    # Metrics to monitor the training, each in a global and a leaf-level variant.
    # target_as_indexes=True converts the one-hot targets to class indexes, as torchmetrics expects.
    metrics = DictMetrics(
        global_multi_class_accuracy=TorchGlobalMetric(
            partial(MulticlassAccuracy, num_classes=target_size, average="micro"), target_as_indexes=True
        ),
        leaf_multi_class_accuracy=TorchLeafMetric(
            partial(MulticlassAccuracy, num_classes=target_size, average="micro"), target_as_indexes=True
        ),
        global_multi_class_F1_score=TorchGlobalMetric(
            partial(MulticlassF1Score, num_classes=target_size, average="macro"), target_as_indexes=True
        ),
        leaf_multi_class_F1_score=TorchLeafMetric(
            partial(MulticlassF1Score, num_classes=target_size, average="macro"), target_as_indexes=True
        ),
        global_confusion_matrix=TorchGlobalMetric(
            partial(MulticlassConfusionMatrix, num_classes=target_size, normalize="all"), target_as_indexes=True
        ),
        leaf_confusion_matrix=TorchLeafMetric(
            partial(MulticlassConfusionMatrix, num_classes=target_size, normalize="all"), target_as_indexes=True
        ),
    )

    # Early stopping and checkpointing both track the global accuracy (to maximize); the learning rate is
    # reduced when "Total loss" plateaus, with the scheduler stepped once per epoch.
    callbacks = [
        EarlyStopping(monitoring_metric="global_multi_class_accuracy", mode="maximize", patience=10),
        Scheduler(
            pre_scheduler=partial(ReduceLROnPlateau, mode="min", patience=1),
            step_method="epoch",
            monitoring_metric="Total loss",
        ),
        ModelCheckpoint(monitoring_metric="global_multi_class_accuracy", mode="maximize"),
    ]

    # Optimizer is a partial object as pytorch needs to give the model as optimizer parameter.
    # foreach and fused are disabled as they may currently lead to unstable training.
    optimizer = partial(torch.optim.AdamW, lr=0.001, foreach=False, fused=False)

    # Wrap the torch loss so that it is compatible with the explainable model's output format.
    loss = XpdeepLoss.from_torch(
        loss=CrossEntropyLossFromProbabilities(reduction="none"),
        model=xpdeep_model,
        fitted_schema=fit_train_dataset.fitted_schema,
    )

    trainer = Trainer(
        loss=loss,
        optimizer=optimizer,
        callbacks=callbacks,
        start_epoch=0,
        max_epochs=6,
        metrics=metrics,
    )

    trained_model = trainer.train(
        model=xpdeep_model,
        train_set=fit_train_dataset,
        validation_set=fit_val_dataset,
        batch_size=128,
    )

    # ##### Explain #######

    # 1. Build the Explainer
    # Target and prediction distributions are computed along with the explanations.
    statistics = DictStats(
        distribution_target=DistributionStat(on="target"), distribution_prediction=DistributionStat(on="prediction")
    )
    quality_metrics = [Sensitivity(), Infidelity()]

    explainer = Explainer(quality_metrics=quality_metrics, metrics=metrics, statistics=statistics)

    # 2. Model Functioning Explanations
    # Global explanations computed on the train, test and validation splits.
    model_explanations = explainer.global_explain(
        trained_model,
        train_set=fit_train_dataset,
        test_set=fit_test_dataset,
        validation_set=fit_val_dataset,
    )

    print(model_explanations.visualisation_link)

    # 3. Inference and their Causal Explanations
    # Restrict the local explanations to test samples whose "label" falls in the categories 1 and 2.
    my_filter = Filter("testing_filter", fit_test_dataset)
    my_filter.add_criteria(
        CategoricalCriterion(fit_test_dataset.fitted_schema["label"], categories=[1, 2]),
    )

    causal_explanations = explainer.local_explain(trained_model, fit_test_dataset, my_filter)
    print(causal_explanations.visualisation_link)


if __name__ == "__main__":
    # Connect to the Xpdeep API with credentials taken from the environment.
    api_key = os.getenv("API_KEY")
    api_url = os.getenv("API_URL")
    init(api_key=api_key, api_url=api_url)

    # Make the tutorial project the current one for every SDK call made by main().
    set_project(Project.create_or_get(name="MNIST Tutorial Review"))

    try:
        main()
    finally:
        # The project is deleted whether main() succeeded or raised.
        get_project().delete()

The training logs are displayed in the console:

Train  - Epoch 1/6 -  Loss: 0.190 | █████████                                |  24.3 % Complete

Once the model is trained, it can be used to get explanations.

Explain#

Similarly to the Trainer, explanations are computed with an Explainer interface.

1. Build the Explainer#

We provide the Explainer with quality metrics to get insights into the explanation quality. In addition, we compute the distributions of targets and predictions along with the explanations.

from xpdeep.explain.explainer import Explainer
from xpdeep.explain.quality_metrics import Infidelity, Sensitivity
from xpdeep.explain.statistic import DictStats, DistributionStat

# Statistics computed alongside the explanations: distribution of the targets
# and distribution of the model predictions.
statistics = DictStats(
    distribution_target=DistributionStat(on="target"),
    distribution_prediction=DistributionStat(on="prediction"),
)
# Metrics assessing the quality of the explanations themselves.
quality_metrics = [Sensitivity(), Infidelity()]

# `metrics` is the DictMetrics object already used by the Trainer.
explainer = Explainer(
    quality_metrics=quality_metrics,
    metrics=metrics,
    statistics=statistics,
)
👀 Full file preview
"""MNIST workflow, classification with image data."""

import os
from copy import deepcopy
from functools import partial

import boto3
import pyarrow as pa
import pyarrow.parquet as pq
import torch
from botocore.client import Config
from datasets import DatasetDict, load_dataset
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchmetrics.classification import MulticlassAccuracy, MulticlassConfusionMatrix, MulticlassF1Score
from torchvision.transforms import RandomRotation

from xpdeep import init, set_project
from xpdeep.dataset.feature.augmentation.augmentation import FeatureAugmentation
from xpdeep.dataset.parquet_dataset import FittedParquetDataset, ParquetDataset
from xpdeep.explain.explainer import Explainer
from xpdeep.explain.quality_metrics import Infidelity, Sensitivity
from xpdeep.explain.statistic import DictStats, DistributionStat
from xpdeep.filtering.criteria import CategoricalCriterion
from xpdeep.filtering.filter import Filter
from xpdeep.metric import DictMetrics, TorchGlobalMetric, TorchLeafMetric
from xpdeep.model.feature_extraction_output_type import FeatureExtractionOutputType
from xpdeep.model.model_parameters import ModelDecisionGraphParameters
from xpdeep.model.xpdeep_model import XpdeepModel
from xpdeep.model.zoo.cross_entropy_loss_from_proba import CrossEntropyLossFromProbabilities
from xpdeep.model.zoo.doc import MnistCNN
from xpdeep.model.zoo.mlp import MLP
from xpdeep.project import Project, get_project
from xpdeep.trainer.callbacks import EarlyStopping, ModelCheckpoint, Scheduler
from xpdeep.trainer.loss import XpdeepLoss
from xpdeep.trainer.trainer import Trainer

# Storage options handed to ParquetDataset / FittedParquetDataset to read the splits from S3.
# They use the same environment variables (credentials, endpoint) as the boto3 upload client.
# NOTE(review): option names look like s3fs/fsspec ones — confirm how xpdeep forwards them.
STORAGE_OPTIONS = dict(
    key=os.getenv("S3_DATASET_ACCESS_KEY_ID"),
    secret=os.getenv("S3_DATASET_SECRET_ACCESS_KEY"),
    client_kwargs=dict(endpoint_url=os.getenv("S3_DATASET_ENDPOINT_URL")),
    s3_additional_kwargs=dict(addressing_style="path"),
)


def main():
    """Process the dataset, train, and explain the model.

    The MNIST test split is halved into validation and test sets. Every split is saved as a
    local ".parquet" file and uploaded to the S3 bucket named by the ``S3_DATASET_BUCKET_NAME``
    environment variable. A schema fitted on the train split is then used to train an
    ``XpdeepModel``, and the links to its global and causal explanations are printed.

    Raises:
        RuntimeError: If the ``S3_DATASET_BUCKET_NAME`` environment variable is not set.
    """
    torch.random.manual_seed(5)

    # Read the bucket name once and fail fast: when it is unset, the uploads and the
    # "s3://None/..." dataset paths below would otherwise only fail later.
    bucket_name = os.getenv("S3_DATASET_BUCKET_NAME")
    if not bucket_name:
        raise RuntimeError("The S3_DATASET_BUCKET_NAME environment variable must be set.")

    # ##### Prepare the Dataset #######

    # 1. Split and Convert your Raw Data
    # Load the dataset from HuggingFace datasets hub for convenience.
    dataset = load_dataset("mnist", trust_remote_code=True)

    # Split the original test set into two stratified halves, used as validation and test sets.
    test_eval = dataset["test"].train_test_split(test_size=0.5, stratify_by_column="label", seed=1225)

    splits = DatasetDict({
        "train": dataset["train"],
        "val": test_eval["train"],  # HuggingFace requires the "train" keyword.
        "test": test_eval["test"],
    })

    # 2. Upload your Converted Data
    client = boto3.client(
        service_name="s3",
        endpoint_url=os.getenv("S3_DATASET_ENDPOINT_URL"),
        aws_access_key_id=os.getenv("S3_DATASET_ACCESS_KEY_ID"),
        aws_secret_access_key=os.getenv("S3_DATASET_SECRET_ACCESS_KEY"),
        config=Config(signature_version="s3v4"),
    )

    # Convert each split to pyarrow Table format, save it as a ".parquet" file and upload it.
    # Handling one split at a time avoids holding the three tables in memory at once.
    for split_name in ("train", "val", "test"):
        table = pa.Table.from_pandas(splits[split_name].to_pandas(), preserve_index=True)
        file_name = f"{split_name}.parquet"
        pq.write_table(table, file_name)
        client.upload_file(file_name, bucket_name, f"mnist/{file_name}")

    # 3. Instantiate a Dataset
    train_dataset = ParquetDataset(
        name="mnist_train_set",
        path=f"s3://{bucket_name}/mnist/train.parquet",
        storage_options=STORAGE_OPTIONS,
    )

    # 4. Find a schema
    analyzed_train_dataset = train_dataset.analyze(target_names=["label"])

    # 5. Fit the schema
    fit_train_dataset = analyzed_train_dataset.fit()

    # Test and validation sets each get an independent copy of the schema fitted on the train set.
    fit_test_dataset = FittedParquetDataset(
        name="mnist_test_set",
        path=f"s3://{bucket_name}/mnist/test.parquet",
        storage_options=STORAGE_OPTIONS,
        fitted_schema=deepcopy(fit_train_dataset.fitted_schema),
    )

    fit_val_dataset = FittedParquetDataset(
        name="mnist_validation_set",
        path=f"s3://{bucket_name}/mnist/val.parquet",
        storage_options=STORAGE_OPTIONS,
        fitted_schema=deepcopy(fit_train_dataset.fitted_schema),
    )

    # 6. Add augmentation on the train set fitted schema (optional). The schema copies above were
    # taken before this step, so augmentation won't be applied on the test and validation sets.
    # Here as images are grayscale, we don't need to permute the channel dimension to get the
    # channel dimension first.
    # For an RGB image, we could have used Compose([Permute([0, 3, 1, 2]), RandomRotation(90), Permute([0, 2, 3, 1])])
    augmentation = RandomRotation(90)
    image_rotation_augmentation = FeatureAugmentation(augment_raw=augmentation, augment_preprocessed=augmentation)
    fit_train_dataset.fitted_schema.add_augmentation("image", image_rotation_augmentation)

    # ##### Prepare the Model #######

    # 1. Create the required torch models
    input_size = fit_train_dataset.fitted_schema.input_size[1:]  # 28 x 28
    target_size = fit_train_dataset.fitted_schema.target_size[1]  # 10

    print(f"input_size: {input_size} - target_size: {target_size}")

    feature_extractor = MnistCNN(output_size=128)
    # The task learner maps the 128 extracted features to one probability per class.
    task_learner = MLP(
        input_size=128,
        activation_layer=partial(torch.nn.ReLU),
        flatten_input=True,
        hidden_channels=[target_size],
        last_activation=partial(torch.nn.Softmax, dim=1),
    )

    # 2. Explainable Model Specifications
    model_specifications = ModelDecisionGraphParameters(
        feature_extraction_output_type=FeatureExtractionOutputType.VECTOR,
        balancing_weight=0.1,  # This incentivizes the explainable model to provide more decisions.
    )

    # 3. Create the Explainable Model
    xpdeep_model = XpdeepModel.from_torch(
        example_dataset=fit_train_dataset,
        feature_extraction=feature_extractor,
        task_learner=task_learner,
        backbone=None,
        decision_graph_parameters=model_specifications,
    )

    # ##### Train #######

    # Metrics to monitor the training, computed both on the whole model ("global") and per leaf.
    metrics = DictMetrics(
        global_multi_class_accuracy=TorchGlobalMetric(
            partial(MulticlassAccuracy, num_classes=target_size, average="micro"), target_as_indexes=True
        ),
        leaf_multi_class_accuracy=TorchLeafMetric(
            partial(MulticlassAccuracy, num_classes=target_size, average="micro"), target_as_indexes=True
        ),
        global_multi_class_F1_score=TorchGlobalMetric(
            partial(MulticlassF1Score, num_classes=target_size, average="macro"), target_as_indexes=True
        ),
        leaf_multi_class_F1_score=TorchLeafMetric(
            partial(MulticlassF1Score, num_classes=target_size, average="macro"), target_as_indexes=True
        ),
        global_confusion_matrix=TorchGlobalMetric(
            partial(MulticlassConfusionMatrix, num_classes=target_size, normalize="all"), target_as_indexes=True
        ),
        leaf_confusion_matrix=TorchLeafMetric(
            partial(MulticlassConfusionMatrix, num_classes=target_size, normalize="all"), target_as_indexes=True
        ),
    )

    callbacks = [
        # NOTE(review): patience=10 exceeds max_epochs=6 below, so early stopping presumably never
        # triggers — confirm the intended values.
        EarlyStopping(monitoring_metric="global_multi_class_accuracy", mode="maximize", patience=10),
        # Reduce the learning rate when the total loss stops decreasing.
        Scheduler(
            pre_scheduler=partial(ReduceLROnPlateau, mode="min", patience=1),
            step_method="epoch",
            monitoring_metric="Total loss",
        ),
        ModelCheckpoint(monitoring_metric="global_multi_class_accuracy", mode="maximize"),
    ]

    # The optimizer is given as a partial object because it must be instantiated later with the
    # model parameters.
    optimizer = partial(torch.optim.AdamW, lr=0.001, foreach=False, fused=False)

    # The task learner already outputs Softmax probabilities, hence a cross-entropy from probabilities.
    loss = XpdeepLoss.from_torch(
        loss=CrossEntropyLossFromProbabilities(reduction="none"),
        model=xpdeep_model,
        fitted_schema=fit_train_dataset.fitted_schema,
    )

    trainer = Trainer(
        loss=loss,
        optimizer=optimizer,
        callbacks=callbacks,
        start_epoch=0,
        max_epochs=6,
        metrics=metrics,
    )

    trained_model = trainer.train(
        model=xpdeep_model,
        train_set=fit_train_dataset,
        validation_set=fit_val_dataset,
        batch_size=128,
    )

    # ##### Explain #######

    # 1. Build the Explainer
    statistics = DictStats(
        distribution_target=DistributionStat(on="target"), distribution_prediction=DistributionStat(on="prediction")
    )
    quality_metrics = [Sensitivity(), Infidelity()]

    explainer = Explainer(quality_metrics=quality_metrics, metrics=metrics, statistics=statistics)

    # 2. Model Functioning Explanations
    model_explanations = explainer.global_explain(
        trained_model,
        train_set=fit_train_dataset,
        test_set=fit_test_dataset,
        validation_set=fit_val_dataset,
    )

    print(model_explanations.visualisation_link)

    # 3. Inference and their Causal Explanations, on test samples labelled 1 or 2.
    my_filter = Filter("testing_filter", fit_test_dataset)
    my_filter.add_criteria(
        CategoricalCriterion(fit_test_dataset.fitted_schema["label"], categories=[1, 2]),
    )

    causal_explanations = explainer.local_explain(trained_model, fit_test_dataset, my_filter)
    print(causal_explanations.visualisation_link)


if __name__ == "__main__":
    # Authenticate against the Xpdeep API with credentials read from the environment.
    init(
        api_key=os.getenv("API_KEY"),
        api_url=os.getenv("API_URL"),
    )
    tutorial_project = Project.create_or_get(name="MNIST Tutorial Review")
    set_project(tutorial_project)

    try:
        main()
    finally:
        # Delete the current project whether or not main() raised.
        get_project().delete()

Tip

Here we reuse the metrics from the training stage for convenience, but they can be adapted to your needs!

2. Model Functioning Explanations

Model Functioning Explanations are computed with the global_explain method.

# Global (model functioning) explanations, computed over the train, test and validation sets.
model_explanations = explainer.global_explain(
    trained_model,
    train_set=fit_train_dataset,
    test_set=fit_test_dataset,
    validation_set=fit_val_dataset,
)
model_visualisation_link = model_explanations.visualisation_link
print(model_visualisation_link)
👀 Full file preview
"""MNIST workflow, classification with image data."""

import os
from copy import deepcopy
from functools import partial

import boto3
import pyarrow as pa
import pyarrow.parquet as pq
import torch
from botocore.client import Config
from datasets import DatasetDict, load_dataset
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchmetrics.classification import MulticlassAccuracy, MulticlassConfusionMatrix, MulticlassF1Score
from torchvision.transforms import RandomRotation

from xpdeep import init, set_project
from xpdeep.dataset.feature.augmentation.augmentation import FeatureAugmentation
from xpdeep.dataset.parquet_dataset import FittedParquetDataset, ParquetDataset
from xpdeep.explain.explainer import Explainer
from xpdeep.explain.quality_metrics import Infidelity, Sensitivity
from xpdeep.explain.statistic import DictStats, DistributionStat
from xpdeep.filtering.criteria import CategoricalCriterion
from xpdeep.filtering.filter import Filter
from xpdeep.metric import DictMetrics, TorchGlobalMetric, TorchLeafMetric
from xpdeep.model.feature_extraction_output_type import FeatureExtractionOutputType
from xpdeep.model.model_parameters import ModelDecisionGraphParameters
from xpdeep.model.xpdeep_model import XpdeepModel
from xpdeep.model.zoo.cross_entropy_loss_from_proba import CrossEntropyLossFromProbabilities
from xpdeep.model.zoo.doc import MnistCNN
from xpdeep.model.zoo.mlp import MLP
from xpdeep.project import Project, get_project
from xpdeep.trainer.callbacks import EarlyStopping, ModelCheckpoint, Scheduler
from xpdeep.trainer.loss import XpdeepLoss
from xpdeep.trainer.trainer import Trainer

# Storage options handed to ParquetDataset / FittedParquetDataset to read the splits from S3.
# They use the same environment variables (credentials, endpoint) as the boto3 upload client.
# NOTE(review): option names look like s3fs/fsspec ones — confirm how xpdeep forwards them.
STORAGE_OPTIONS = dict(
    key=os.getenv("S3_DATASET_ACCESS_KEY_ID"),
    secret=os.getenv("S3_DATASET_SECRET_ACCESS_KEY"),
    client_kwargs=dict(endpoint_url=os.getenv("S3_DATASET_ENDPOINT_URL")),
    s3_additional_kwargs=dict(addressing_style="path"),
)


def main():
    """Process the dataset, train, and explain the model.

    The MNIST test split is halved into validation and test sets. Every split is saved as a
    local ".parquet" file and uploaded to the S3 bucket named by the ``S3_DATASET_BUCKET_NAME``
    environment variable. A schema fitted on the train split is then used to train an
    ``XpdeepModel``, and the links to its global and causal explanations are printed.

    Raises:
        RuntimeError: If the ``S3_DATASET_BUCKET_NAME`` environment variable is not set.
    """
    torch.random.manual_seed(5)

    # Read the bucket name once and fail fast: when it is unset, the uploads and the
    # "s3://None/..." dataset paths below would otherwise only fail later.
    bucket_name = os.getenv("S3_DATASET_BUCKET_NAME")
    if not bucket_name:
        raise RuntimeError("The S3_DATASET_BUCKET_NAME environment variable must be set.")

    # ##### Prepare the Dataset #######

    # 1. Split and Convert your Raw Data
    # Load the dataset from HuggingFace datasets hub for convenience.
    dataset = load_dataset("mnist", trust_remote_code=True)

    # Split the original test set into two stratified halves, used as validation and test sets.
    test_eval = dataset["test"].train_test_split(test_size=0.5, stratify_by_column="label", seed=1225)

    splits = DatasetDict({
        "train": dataset["train"],
        "val": test_eval["train"],  # HuggingFace requires the "train" keyword.
        "test": test_eval["test"],
    })

    # 2. Upload your Converted Data
    client = boto3.client(
        service_name="s3",
        endpoint_url=os.getenv("S3_DATASET_ENDPOINT_URL"),
        aws_access_key_id=os.getenv("S3_DATASET_ACCESS_KEY_ID"),
        aws_secret_access_key=os.getenv("S3_DATASET_SECRET_ACCESS_KEY"),
        config=Config(signature_version="s3v4"),
    )

    # Convert each split to pyarrow Table format, save it as a ".parquet" file and upload it.
    # Handling one split at a time avoids holding the three tables in memory at once.
    for split_name in ("train", "val", "test"):
        table = pa.Table.from_pandas(splits[split_name].to_pandas(), preserve_index=True)
        file_name = f"{split_name}.parquet"
        pq.write_table(table, file_name)
        client.upload_file(file_name, bucket_name, f"mnist/{file_name}")

    # 3. Instantiate a Dataset
    train_dataset = ParquetDataset(
        name="mnist_train_set",
        path=f"s3://{bucket_name}/mnist/train.parquet",
        storage_options=STORAGE_OPTIONS,
    )

    # 4. Find a schema
    analyzed_train_dataset = train_dataset.analyze(target_names=["label"])

    # 5. Fit the schema
    fit_train_dataset = analyzed_train_dataset.fit()

    # Test and validation sets each get an independent copy of the schema fitted on the train set.
    fit_test_dataset = FittedParquetDataset(
        name="mnist_test_set",
        path=f"s3://{bucket_name}/mnist/test.parquet",
        storage_options=STORAGE_OPTIONS,
        fitted_schema=deepcopy(fit_train_dataset.fitted_schema),
    )

    fit_val_dataset = FittedParquetDataset(
        name="mnist_validation_set",
        path=f"s3://{bucket_name}/mnist/val.parquet",
        storage_options=STORAGE_OPTIONS,
        fitted_schema=deepcopy(fit_train_dataset.fitted_schema),
    )

    # 6. Add augmentation on the train set fitted schema (optional). The schema copies above were
    # taken before this step, so augmentation won't be applied on the test and validation sets.
    # Here as images are grayscale, we don't need to permute the channel dimension to get the
    # channel dimension first.
    # For an RGB image, we could have used Compose([Permute([0, 3, 1, 2]), RandomRotation(90), Permute([0, 2, 3, 1])])
    augmentation = RandomRotation(90)
    image_rotation_augmentation = FeatureAugmentation(augment_raw=augmentation, augment_preprocessed=augmentation)
    fit_train_dataset.fitted_schema.add_augmentation("image", image_rotation_augmentation)

    # ##### Prepare the Model #######

    # 1. Create the required torch models
    input_size = fit_train_dataset.fitted_schema.input_size[1:]  # 28 x 28
    target_size = fit_train_dataset.fitted_schema.target_size[1]  # 10

    print(f"input_size: {input_size} - target_size: {target_size}")

    feature_extractor = MnistCNN(output_size=128)
    # The task learner maps the 128 extracted features to one probability per class.
    task_learner = MLP(
        input_size=128,
        activation_layer=partial(torch.nn.ReLU),
        flatten_input=True,
        hidden_channels=[target_size],
        last_activation=partial(torch.nn.Softmax, dim=1),
    )

    # 2. Explainable Model Specifications
    model_specifications = ModelDecisionGraphParameters(
        feature_extraction_output_type=FeatureExtractionOutputType.VECTOR,
        balancing_weight=0.1,  # This incentivizes the explainable model to provide more decisions.
    )

    # 3. Create the Explainable Model
    xpdeep_model = XpdeepModel.from_torch(
        example_dataset=fit_train_dataset,
        feature_extraction=feature_extractor,
        task_learner=task_learner,
        backbone=None,
        decision_graph_parameters=model_specifications,
    )

    # ##### Train #######

    # Metrics to monitor the training, computed both on the whole model ("global") and per leaf.
    metrics = DictMetrics(
        global_multi_class_accuracy=TorchGlobalMetric(
            partial(MulticlassAccuracy, num_classes=target_size, average="micro"), target_as_indexes=True
        ),
        leaf_multi_class_accuracy=TorchLeafMetric(
            partial(MulticlassAccuracy, num_classes=target_size, average="micro"), target_as_indexes=True
        ),
        global_multi_class_F1_score=TorchGlobalMetric(
            partial(MulticlassF1Score, num_classes=target_size, average="macro"), target_as_indexes=True
        ),
        leaf_multi_class_F1_score=TorchLeafMetric(
            partial(MulticlassF1Score, num_classes=target_size, average="macro"), target_as_indexes=True
        ),
        global_confusion_matrix=TorchGlobalMetric(
            partial(MulticlassConfusionMatrix, num_classes=target_size, normalize="all"), target_as_indexes=True
        ),
        leaf_confusion_matrix=TorchLeafMetric(
            partial(MulticlassConfusionMatrix, num_classes=target_size, normalize="all"), target_as_indexes=True
        ),
    )

    callbacks = [
        # NOTE(review): patience=10 exceeds max_epochs=6 below, so early stopping presumably never
        # triggers — confirm the intended values.
        EarlyStopping(monitoring_metric="global_multi_class_accuracy", mode="maximize", patience=10),
        # Reduce the learning rate when the total loss stops decreasing.
        Scheduler(
            pre_scheduler=partial(ReduceLROnPlateau, mode="min", patience=1),
            step_method="epoch",
            monitoring_metric="Total loss",
        ),
        ModelCheckpoint(monitoring_metric="global_multi_class_accuracy", mode="maximize"),
    ]

    # The optimizer is given as a partial object because it must be instantiated later with the
    # model parameters.
    optimizer = partial(torch.optim.AdamW, lr=0.001, foreach=False, fused=False)

    # The task learner already outputs Softmax probabilities, hence a cross-entropy from probabilities.
    loss = XpdeepLoss.from_torch(
        loss=CrossEntropyLossFromProbabilities(reduction="none"),
        model=xpdeep_model,
        fitted_schema=fit_train_dataset.fitted_schema,
    )

    trainer = Trainer(
        loss=loss,
        optimizer=optimizer,
        callbacks=callbacks,
        start_epoch=0,
        max_epochs=6,
        metrics=metrics,
    )

    trained_model = trainer.train(
        model=xpdeep_model,
        train_set=fit_train_dataset,
        validation_set=fit_val_dataset,
        batch_size=128,
    )

    # ##### Explain #######

    # 1. Build the Explainer
    statistics = DictStats(
        distribution_target=DistributionStat(on="target"), distribution_prediction=DistributionStat(on="prediction")
    )
    quality_metrics = [Sensitivity(), Infidelity()]

    explainer = Explainer(quality_metrics=quality_metrics, metrics=metrics, statistics=statistics)

    # 2. Model Functioning Explanations
    model_explanations = explainer.global_explain(
        trained_model,
        train_set=fit_train_dataset,
        test_set=fit_test_dataset,
        validation_set=fit_val_dataset,
    )

    print(model_explanations.visualisation_link)

    # 3. Inference and their Causal Explanations, on test samples labelled 1 or 2.
    my_filter = Filter("testing_filter", fit_test_dataset)
    my_filter.add_criteria(
        CategoricalCriterion(fit_test_dataset.fitted_schema["label"], categories=[1, 2]),
    )

    causal_explanations = explainer.local_explain(trained_model, fit_test_dataset, my_filter)
    print(causal_explanations.visualisation_link)


if __name__ == "__main__":
    # Authenticate against the Xpdeep API with credentials read from the environment.
    init(
        api_key=os.getenv("API_KEY"),
        api_url=os.getenv("API_URL"),
    )
    tutorial_project = Project.create_or_get(name="MNIST Tutorial Review")
    set_project(tutorial_project)

    try:
        main()
    finally:
        # Delete the current project whether or not main() raised.
        get_project().delete()

You can visualize the explanations with XpViz using the link in model_explanations.visualisation_link, provided you have already requested the correct credentials.

3. Inference and their Causal Explanations

Causal Explanations are computed on a subset of samples. Here we filter the test set on the image label, selecting samples with labels 1 and 2, which represents 1083 samples.

from xpdeep.filtering.filter import Filter
from xpdeep.filtering.criteria import CategoricalCriterion

# Keep only the test samples whose label is 1 or 2.
my_filter = Filter("testing_filter", fit_test_dataset)
label_feature = fit_test_dataset.fitted_schema["label"]
my_filter.add_criteria(CategoricalCriterion(label_feature, categories=[1, 2]))

Explanations can then be computed using the Explainer's local_explain method.

# Causal explanations are computed on the test set, which is the dataset `my_filter` was built on.
causal_explanations = explainer.local_explain(trained_model, fit_test_dataset, my_filter)
print(causal_explanations.visualisation_link)
👀 Full file preview
"""MNIST workflow, classification with image data."""

import os
from copy import deepcopy
from functools import partial

import boto3
import pyarrow as pa
import pyarrow.parquet as pq
import torch
from botocore.client import Config
from datasets import DatasetDict, load_dataset
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchmetrics.classification import MulticlassAccuracy, MulticlassConfusionMatrix, MulticlassF1Score
from torchvision.transforms import RandomRotation

from xpdeep import init, set_project
from xpdeep.dataset.feature.augmentation.augmentation import FeatureAugmentation
from xpdeep.dataset.parquet_dataset import FittedParquetDataset, ParquetDataset
from xpdeep.explain.explainer import Explainer
from xpdeep.explain.quality_metrics import Infidelity, Sensitivity
from xpdeep.explain.statistic import DictStats, DistributionStat
from xpdeep.filtering.criteria import CategoricalCriterion
from xpdeep.filtering.filter import Filter
from xpdeep.metric import DictMetrics, TorchGlobalMetric, TorchLeafMetric
from xpdeep.model.feature_extraction_output_type import FeatureExtractionOutputType
from xpdeep.model.model_parameters import ModelDecisionGraphParameters
from xpdeep.model.xpdeep_model import XpdeepModel
from xpdeep.model.zoo.cross_entropy_loss_from_proba import CrossEntropyLossFromProbabilities
from xpdeep.model.zoo.doc import MnistCNN
from xpdeep.model.zoo.mlp import MLP
from xpdeep.project import Project, get_project
from xpdeep.trainer.callbacks import EarlyStopping, ModelCheckpoint, Scheduler
from xpdeep.trainer.loss import XpdeepLoss
from xpdeep.trainer.trainer import Trainer

# Storage options handed to ParquetDataset / FittedParquetDataset to read the splits from S3.
# They use the same environment variables (credentials, endpoint) as the boto3 upload client.
# NOTE(review): option names look like s3fs/fsspec ones — confirm how xpdeep forwards them.
STORAGE_OPTIONS = dict(
    key=os.getenv("S3_DATASET_ACCESS_KEY_ID"),
    secret=os.getenv("S3_DATASET_SECRET_ACCESS_KEY"),
    client_kwargs=dict(endpoint_url=os.getenv("S3_DATASET_ENDPOINT_URL")),
    s3_additional_kwargs=dict(addressing_style="path"),
)


def main():
    """Process the dataset, train, and explain the model.

    The MNIST test split is halved into validation and test sets. Every split is saved as a
    local ".parquet" file and uploaded to the S3 bucket named by the ``S3_DATASET_BUCKET_NAME``
    environment variable. A schema fitted on the train split is then used to train an
    ``XpdeepModel``, and the links to its global and causal explanations are printed.

    Raises:
        RuntimeError: If the ``S3_DATASET_BUCKET_NAME`` environment variable is not set.
    """
    torch.random.manual_seed(5)

    # Read the bucket name once and fail fast: when it is unset, the uploads and the
    # "s3://None/..." dataset paths below would otherwise only fail later.
    bucket_name = os.getenv("S3_DATASET_BUCKET_NAME")
    if not bucket_name:
        raise RuntimeError("The S3_DATASET_BUCKET_NAME environment variable must be set.")

    # ##### Prepare the Dataset #######

    # 1. Split and Convert your Raw Data
    # Load the dataset from HuggingFace datasets hub for convenience.
    dataset = load_dataset("mnist", trust_remote_code=True)

    # Split the original test set into two stratified halves, used as validation and test sets.
    test_eval = dataset["test"].train_test_split(test_size=0.5, stratify_by_column="label", seed=1225)

    splits = DatasetDict({
        "train": dataset["train"],
        "val": test_eval["train"],  # HuggingFace requires the "train" keyword.
        "test": test_eval["test"],
    })

    # 2. Upload your Converted Data
    client = boto3.client(
        service_name="s3",
        endpoint_url=os.getenv("S3_DATASET_ENDPOINT_URL"),
        aws_access_key_id=os.getenv("S3_DATASET_ACCESS_KEY_ID"),
        aws_secret_access_key=os.getenv("S3_DATASET_SECRET_ACCESS_KEY"),
        config=Config(signature_version="s3v4"),
    )

    # Convert each split to pyarrow Table format, save it as a ".parquet" file and upload it.
    # Handling one split at a time avoids holding the three tables in memory at once.
    for split_name in ("train", "val", "test"):
        table = pa.Table.from_pandas(splits[split_name].to_pandas(), preserve_index=True)
        file_name = f"{split_name}.parquet"
        pq.write_table(table, file_name)
        client.upload_file(file_name, bucket_name, f"mnist/{file_name}")

    # 3. Instantiate a Dataset
    train_dataset = ParquetDataset(
        name="mnist_train_set",
        path=f"s3://{bucket_name}/mnist/train.parquet",
        storage_options=STORAGE_OPTIONS,
    )

    # 4. Find a schema
    analyzed_train_dataset = train_dataset.analyze(target_names=["label"])

    # 5. Fit the schema
    fit_train_dataset = analyzed_train_dataset.fit()

    # Test and validation sets each get an independent copy of the schema fitted on the train set.
    fit_test_dataset = FittedParquetDataset(
        name="mnist_test_set",
        path=f"s3://{bucket_name}/mnist/test.parquet",
        storage_options=STORAGE_OPTIONS,
        fitted_schema=deepcopy(fit_train_dataset.fitted_schema),
    )

    fit_val_dataset = FittedParquetDataset(
        name="mnist_validation_set",
        path=f"s3://{bucket_name}/mnist/val.parquet",
        storage_options=STORAGE_OPTIONS,
        fitted_schema=deepcopy(fit_train_dataset.fitted_schema),
    )

    # 6. Add augmentation on the train set fitted schema (optional). The schema copies above were
    # taken before this step, so augmentation won't be applied on the test and validation sets.
    # Here as images are grayscale, we don't need to permute the channel dimension to get the
    # channel dimension first.
    # For an RGB image, we could have used Compose([Permute([0, 3, 1, 2]), RandomRotation(90), Permute([0, 2, 3, 1])])
    augmentation = RandomRotation(90)
    image_rotation_augmentation = FeatureAugmentation(augment_raw=augmentation, augment_preprocessed=augmentation)
    fit_train_dataset.fitted_schema.add_augmentation("image", image_rotation_augmentation)

    # ##### Prepare the Model #######

    # 1. Create the required torch models
    input_size = fit_train_dataset.fitted_schema.input_size[1:]  # 28 x 28
    target_size = fit_train_dataset.fitted_schema.target_size[1]  # 10

    print(f"input_size: {input_size} - target_size: {target_size}")

    feature_extractor = MnistCNN(output_size=128)
    # The task learner maps the 128 extracted features to one probability per class.
    task_learner = MLP(
        input_size=128,
        activation_layer=partial(torch.nn.ReLU),
        flatten_input=True,
        hidden_channels=[target_size],
        last_activation=partial(torch.nn.Softmax, dim=1),
    )

    # 2. Explainable Model Specifications
    model_specifications = ModelDecisionGraphParameters(
        feature_extraction_output_type=FeatureExtractionOutputType.VECTOR,
        balancing_weight=0.1,  # This incentivizes the explainable model to provide more decisions.
    )

    # 3. Create the Explainable Model
    xpdeep_model = XpdeepModel.from_torch(
        example_dataset=fit_train_dataset,
        feature_extraction=feature_extractor,
        task_learner=task_learner,
        backbone=None,
        decision_graph_parameters=model_specifications,
    )

    # ##### Train #######

    # Metrics to monitor the training, computed both on the whole model ("global") and per leaf.
    metrics = DictMetrics(
        global_multi_class_accuracy=TorchGlobalMetric(
            partial(MulticlassAccuracy, num_classes=target_size, average="micro"), target_as_indexes=True
        ),
        leaf_multi_class_accuracy=TorchLeafMetric(
            partial(MulticlassAccuracy, num_classes=target_size, average="micro"), target_as_indexes=True
        ),
        global_multi_class_F1_score=TorchGlobalMetric(
            partial(MulticlassF1Score, num_classes=target_size, average="macro"), target_as_indexes=True
        ),
        leaf_multi_class_F1_score=TorchLeafMetric(
            partial(MulticlassF1Score, num_classes=target_size, average="macro"), target_as_indexes=True
        ),
        global_confusion_matrix=TorchGlobalMetric(
            partial(MulticlassConfusionMatrix, num_classes=target_size, normalize="all"), target_as_indexes=True
        ),
        leaf_confusion_matrix=TorchLeafMetric(
            partial(MulticlassConfusionMatrix, num_classes=target_size, normalize="all"), target_as_indexes=True
        ),
    )

    callbacks = [
        # NOTE(review): patience=10 exceeds max_epochs=6 below, so early stopping presumably never
        # triggers — confirm the intended values.
        EarlyStopping(monitoring_metric="global_multi_class_accuracy", mode="maximize", patience=10),
        # Reduce the learning rate when the total loss stops decreasing.
        Scheduler(
            pre_scheduler=partial(ReduceLROnPlateau, mode="min", patience=1),
            step_method="epoch",
            monitoring_metric="Total loss",
        ),
        ModelCheckpoint(monitoring_metric="global_multi_class_accuracy", mode="maximize"),
    ]

    # The optimizer is given as a partial object because it must be instantiated later with the
    # model parameters.
    optimizer = partial(torch.optim.AdamW, lr=0.001, foreach=False, fused=False)

    # The task learner already outputs Softmax probabilities, hence a cross-entropy from probabilities.
    loss = XpdeepLoss.from_torch(
        loss=CrossEntropyLossFromProbabilities(reduction="none"),
        model=xpdeep_model,
        fitted_schema=fit_train_dataset.fitted_schema,
    )

    trainer = Trainer(
        loss=loss,
        optimizer=optimizer,
        callbacks=callbacks,
        start_epoch=0,
        max_epochs=6,
        metrics=metrics,
    )

    trained_model = trainer.train(
        model=xpdeep_model,
        train_set=fit_train_dataset,
        validation_set=fit_val_dataset,
        batch_size=128,
    )

    # ##### Explain #######

    # 1. Build the Explainer
    statistics = DictStats(
        distribution_target=DistributionStat(on="target"), distribution_prediction=DistributionStat(on="prediction")
    )
    quality_metrics = [Sensitivity(), Infidelity()]

    explainer = Explainer(quality_metrics=quality_metrics, metrics=metrics, statistics=statistics)

    # 2. Model Functioning Explanations
    model_explanations = explainer.global_explain(
        trained_model,
        train_set=fit_train_dataset,
        test_set=fit_test_dataset,
        validation_set=fit_val_dataset,
    )

    print(model_explanations.visualisation_link)

    # 3. Inference and their Causal Explanations, on test samples labelled 1 or 2.
    my_filter = Filter("testing_filter", fit_test_dataset)
    my_filter.add_criteria(
        CategoricalCriterion(fit_test_dataset.fitted_schema["label"], categories=[1, 2]),
    )

    causal_explanations = explainer.local_explain(trained_model, fit_test_dataset, my_filter)
    print(causal_explanations.visualisation_link)


if __name__ == "__main__":
    # Authenticate against the Xpdeep API with credentials read from the environment.
    init(
        api_key=os.getenv("API_KEY"),
        api_url=os.getenv("API_URL"),
    )
    tutorial_project = Project.create_or_get(name="MNIST Tutorial Review")
    set_project(tutorial_project)

    try:
        main()
    finally:
        # Delete the current project whether or not main() raised.
        get_project().delete()

Causal explanations can likewise be visualized using causal_explanations.visualisation_link.