Explain an Existing Model, Adult Income Dataset#

Please refer to the Adult Income tutorial for a more detailed overview of how to learn a self-explainable model on this dataset. This tutorial focuses on learning explanations for an existing model trained on the Adult Income classification problem, and highlights the main differences from training a self-explainable model.

The dataset preparation and the explanation computation are exactly the same as for a self-explainable model. The model preparation and the training process, however, are slightly different.

Prepare the Model#

Initial Training#

To simulate an already trained model whose performance we want to preserve while computing explanations, we first train the feature extractor and task learner using PyTorch.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.nn.Sequential(feature_extractor, task_learner).to(device)

opt = torch.optim.AdamW(model.parameters(), lr=3e-4)
crit = torch.nn.CrossEntropyLoss()
batch_size = 128
epochs = 8

def preprocess(df, fit_dataset):
    outs = []
    for exp_col in fit_dataset.fitted_schema.columns:
        if exp_col.name not in df.columns:
            continue

        Xc = df[[exp_col.name]]
        Xt = exp_col.preprocessor.transform(Xc)

        if hasattr(Xt, "toarray"):
            Xt = Xt.toarray()  # densify if sparse
        Xt = np.asarray(Xt, dtype=np.float32)
        if Xt.ndim == 1:
            Xt = Xt.reshape(-1, 1)  # guarantee a 2D column
        outs.append(Xt)

    X = np.concatenate(outs, axis=1)

    return torch.from_numpy(X)

X_train, y_train = train_data.drop(columns=["income"]), train_data[["income"]]
X_test, y_test = test_data.drop(columns=["income"]), test_data[["income"]]
X_val, y_val = val_data.drop(columns=["income"]), val_data[["income"]]

X_train, y_train = preprocess(X_train, fit_train_dataset), preprocess(y_train, fit_train_dataset).argmax(1)
X_test, y_test = preprocess(X_test, fit_test_dataset), preprocess(y_test, fit_test_dataset).argmax(1)
X_val, y_val = preprocess(X_val, fit_val_dataset), preprocess(y_val, fit_val_dataset).argmax(1)

train_loader = DataLoader(TensorDataset(X_train, y_train), batch_size=batch_size, shuffle=True)
val_loader = DataLoader(TensorDataset(X_val, y_val), batch_size=batch_size)

@torch.no_grad()
def val_acc():
    model.eval()
    correct = total = 0
    for xb, yb in val_loader:
        xb, yb = xb.to(device), yb.to(device)
        pred = model(xb).argmax(1)
        correct += (pred == yb).sum().item()
        total += yb.numel()
    return correct / max(total, 1)

for e in range(1, epochs + 1):
    model.train()
    for xb, yb in train_loader:
        xb, yb = xb.to(device), yb.to(device)
        opt.zero_grad(set_to_none=True)
        loss = crit(model(xb), yb)
        loss.backward()
        opt.step()
    print(f"epoch {e:02d} | val_acc={val_acc():.4f}")

print(f"Accuracy score on train set: {accuracy_score(model(X_train.to(device)).argmax(1).cpu(), y_train)}")
print(f"Accuracy score on test set: {accuracy_score(model(X_test.to(device)).argmax(1).cpu(), y_test)}")
print(f"Accuracy score on validation set: {accuracy_score(model(X_val.to(device)).argmax(1).cpu(), y_val)}")
👀 Full file preview
"""Adult Income workflow with the model frozen, classification with tabular data."""

import os
from functools import partial

import boto3
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import torch
from botocore.config import Config
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from torch.nn import Sequential
from torch.utils.data import DataLoader, TensorDataset
from torchmetrics.classification import MulticlassAccuracy, MulticlassConfusionMatrix, MulticlassF1Score

from xpdeep import init, set_project
from xpdeep.dataset.feature import ExplainableFeature
from xpdeep.dataset.feature.feature_types import NumericalFeature
from xpdeep.dataset.parquet_dataset import FittedParquetDataset, ParquetDataset
from xpdeep.dataset.preprocessor.preprocessor import SklearnPreprocessor
from xpdeep.explain.explainer import Explainer
from xpdeep.explain.quality_metrics import Infidelity, Sensitivity
from xpdeep.explain.statistic import DictStats, DistributionStat
from xpdeep.filtering.criteria import CategoricalCriterion, NumericalCriterion
from xpdeep.filtering.filter import Filter
from xpdeep.metric import DictMetrics, TorchGlobalMetric, TorchLeafMetric
from xpdeep.model.feature_extraction_output_type import FeatureExtractionOutputType
from xpdeep.model.model_parameters import ModelDecisionGraphParameters
from xpdeep.model.xpdeep_model import XpdeepModel
from xpdeep.project import Project, get_project
from xpdeep.trainer.trainer import FrozenModelTrainer

STORAGE_OPTIONS = {
    "key": os.getenv("S3_DATASET_ACCESS_KEY_ID"),
    "secret": os.getenv("S3_DATASET_SECRET_ACCESS_KEY"),
    "client_kwargs": {
        "endpoint_url": os.getenv("S3_DATASET_ENDPOINT_URL"),
    },
    "s3_additional_kwargs": {"addressing_style": "path"},
}


def main():
    """Process the dataset, train, and explain the model."""
    torch.random.manual_seed(5)

    # ##### Prepare the Dataset #######

    # 1. Split and Convert your Raw Data
    # Load the CSV file
    file_path = "adult_income.csv"
    data = pd.read_csv(file_path)
    data = data.drop(columns=["fnlwgt", "education"])

    # Split the data into training and testing sets
    train_data, test_data = train_test_split(data, test_size=0.2, random_state=42)

    # Further split the training set into training and validation sets
    train_data, val_data = train_test_split(train_data, test_size=0.25, random_state=42)

    # Convert to pyarrow Table format
    train_table = pa.Table.from_pandas(train_data, preserve_index=True)
    val_table = pa.Table.from_pandas(val_data, preserve_index=True)
    test_table = pa.Table.from_pandas(test_data, preserve_index=True)

    # Save each split as ".parquet" file
    pq.write_table(train_table, "train.parquet")
    pq.write_table(val_table, "val.parquet")
    pq.write_table(test_table, "test.parquet")

    # 2. Upload your Converted Data
    client = boto3.client(
        service_name="s3",
        endpoint_url=os.getenv("S3_DATASET_ENDPOINT_URL"),
        aws_access_key_id=os.getenv("S3_DATASET_ACCESS_KEY_ID"),
        aws_secret_access_key=os.getenv("S3_DATASET_SECRET_ACCESS_KEY"),
        config=Config(signature_version="s3v4"),
    )

    client.upload_file("train.parquet", os.getenv("S3_DATASET_BUCKET_NAME"), "frozen_model_adult_income/train.parquet")
    client.upload_file("val.parquet", os.getenv("S3_DATASET_BUCKET_NAME"), "frozen_model_adult_income/val.parquet")
    client.upload_file("test.parquet", os.getenv("S3_DATASET_BUCKET_NAME"), "frozen_model_adult_income/test.parquet")

    # 3. Instantiate a Dataset

    train_dataset = ParquetDataset(
        name="adult_income_train_set",
        path=f"s3://{os.getenv('S3_DATASET_BUCKET_NAME')}/frozen_model_adult_income/train.parquet",
        storage_options=STORAGE_OPTIONS,
    )

    # 4. Find a schema
    analyzed_train_dataset = train_dataset.analyze(target_names=["income"])
    print(analyzed_train_dataset.analyzed_schema)

    preprocessor = SklearnPreprocessor(preprocess_function=StandardScaler())
    analyzed_train_dataset.analyzed_schema["educational-num"] = ExplainableFeature(
        name="educational-num", is_target=False, preprocessor=preprocessor, feature_type=NumericalFeature()
    )
    print(analyzed_train_dataset.analyzed_schema)

    # 5. Fit the schema
    fit_train_dataset = analyzed_train_dataset.fit()

    fit_test_dataset = FittedParquetDataset(
        name="adult_income_test_set",
        path=f"s3://{os.getenv('S3_DATASET_BUCKET_NAME')}/frozen_model_adult_income/test.parquet",
        storage_options=STORAGE_OPTIONS,
        fitted_schema=fit_train_dataset.fitted_schema,
    )

    fit_val_dataset = FittedParquetDataset(
        name="adult_income_validation_set",
        path=f"s3://{os.getenv('S3_DATASET_BUCKET_NAME')}/frozen_model_adult_income/val.parquet",
        storage_options=STORAGE_OPTIONS,
        fitted_schema=fit_train_dataset.fitted_schema,
    )

    # ##### Prepare the Model #######

    # 1. Create the required torch models
    input_size = fit_train_dataset.fitted_schema.input_size[1]
    target_size = fit_train_dataset.fitted_schema.target_size[1]

    print(f"input_size: {input_size} - target_size: {target_size}")

    feature_extractor = Sequential(
        torch.nn.Linear(input_size, 128),
        torch.nn.ReLU(),
        torch.nn.Linear(128, 64),
        torch.nn.ReLU(),
    )

    task_learner = Sequential(torch.nn.Linear(64, target_size))

    # 2. Train model locally

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = torch.nn.Sequential(feature_extractor, task_learner).to(device)

    opt = torch.optim.AdamW(model.parameters(), lr=3e-4)
    crit = torch.nn.CrossEntropyLoss()
    batch_size = 128
    epochs = 8

    def preprocess(df, fit_dataset):
        outs = []
        for exp_col in fit_dataset.fitted_schema.columns:
            if exp_col.name not in df.columns:
                continue

            Xc = df[[exp_col.name]]
            Xt = exp_col.preprocessor.transform(Xc)

            if hasattr(Xt, "toarray"):
                Xt = Xt.toarray()  # densify if sparse
            Xt = np.asarray(Xt, dtype=np.float32)
            if Xt.ndim == 1:
                Xt = Xt.reshape(-1, 1)  # guarantee a 2D column
            outs.append(Xt)

        X = np.concatenate(outs, axis=1)

        return torch.from_numpy(X)

    X_train, y_train = train_data.drop(columns=["income"]), train_data[["income"]]
    X_test, y_test = test_data.drop(columns=["income"]), test_data[["income"]]
    X_val, y_val = val_data.drop(columns=["income"]), val_data[["income"]]

    X_train, y_train = preprocess(X_train, fit_train_dataset), preprocess(y_train, fit_train_dataset).argmax(1)
    X_test, y_test = preprocess(X_test, fit_test_dataset), preprocess(y_test, fit_test_dataset).argmax(1)
    X_val, y_val = preprocess(X_val, fit_val_dataset), preprocess(y_val, fit_val_dataset).argmax(1)

    train_loader = DataLoader(TensorDataset(X_train, y_train), batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(TensorDataset(X_val, y_val), batch_size=batch_size)

    @torch.no_grad()
    def val_acc():
        model.eval()
        correct = total = 0
        for xb, yb in val_loader:
            xb, yb = xb.to(device), yb.to(device)
            pred = model(xb).argmax(1)
            correct += (pred == yb).sum().item()
            total += yb.numel()
        return correct / max(total, 1)

    for e in range(1, epochs + 1):
        model.train()
        for xb, yb in train_loader:
            xb, yb = xb.to(device), yb.to(device)
            opt.zero_grad(set_to_none=True)
            loss = crit(model(xb), yb)
            loss.backward()
            opt.step()
        print(f"epoch {e:02d} | val_acc={val_acc():.4f}")

    print(f"Accuracy score on train set: {accuracy_score(model(X_train.to(device)).argmax(1).cpu(), y_train)}")
    print(f"Accuracy score on test set: {accuracy_score(model(X_test.to(device)).argmax(1).cpu(), y_test)}")
    print(f"Accuracy score on validation set: {accuracy_score(model(X_val.to(device)).argmax(1).cpu(), y_val)}")

    # 3. Explainable Model Specifications

    model_specifications = ModelDecisionGraphParameters(
        graph_depth=3,
        population_pruning_threshold=0.1,
        feature_extraction_output_type=FeatureExtractionOutputType.VECTOR,
        frozen_model=True,
    )

    # 4. Create the Explainable Model
    xpdeep_model = XpdeepModel.from_torch(
        example_dataset=fit_train_dataset,
        feature_extraction=feature_extractor,
        task_learner=task_learner,
        backbone=None,
        decision_graph_parameters=model_specifications,
    )

    # ##### Train #######

    # Metrics to monitor the training.
    trainer = FrozenModelTrainer(start_epoch=0, max_epochs=10)

    trained_model = trainer.train(
        model=xpdeep_model,
        train_set=fit_train_dataset,
        validation_set=fit_val_dataset,
        batch_size=128,
    )

    # ##### Explain #######

    # 1. Build the Explainer
    statistics = DictStats(
        distribution_target=DistributionStat(on="target"), distribution_prediction=DistributionStat(on="prediction")
    )
    quality_metrics = [Sensitivity(), Infidelity()]

    metrics = DictMetrics(
        global_multi_class_accuracy=TorchGlobalMetric(
            partial(MulticlassAccuracy, num_classes=2, average="micro"), target_as_indexes=True
        ),
        leaf_multi_class_accuracy=TorchLeafMetric(
            partial(MulticlassAccuracy, num_classes=2, average="micro"), target_as_indexes=True
        ),
        global_multi_class_F1_score=TorchGlobalMetric(
            partial(MulticlassF1Score, num_classes=2, average="macro"), target_as_indexes=True
        ),
        leaf_multi_class_F1_score=TorchLeafMetric(
            partial(MulticlassF1Score, num_classes=2, average="macro"), target_as_indexes=True
        ),
        global_confusion_matrix=TorchGlobalMetric(
            partial(MulticlassConfusionMatrix, num_classes=2, normalize="all"), target_as_indexes=True
        ),
        leaf_confusion_matrix=TorchLeafMetric(
            partial(MulticlassConfusionMatrix, num_classes=2, normalize="all"), target_as_indexes=True
        ),
    )

    explainer = Explainer(quality_metrics=quality_metrics, metrics=metrics, statistics=statistics)

    # 2. Model Functioning Explanations
    model_explanations = explainer.global_explain(
        trained_model,
        train_set=fit_train_dataset,
        test_set=fit_test_dataset,
        validation_set=fit_val_dataset,
    )
    print(model_explanations.visualisation_link)

    # 3. Inference and their Causal Explanations
    my_filter = Filter("testing_filter", fit_test_dataset)
    my_filter.add_criteria(
        NumericalCriterion(fit_test_dataset.fitted_schema["age"], max_=30),
        CategoricalCriterion(fit_test_dataset.fitted_schema["workclass"], categories=["Private"]),
    )

    causal_explanations = explainer.local_explain(trained_model, fit_train_dataset, my_filter)

    print(causal_explanations.visualisation_link)


if __name__ == "__main__":
    init(api_key=os.getenv("API_KEY"), api_url=os.getenv("API_URL"))
    set_project(Project.create_or_get(name="Frozen Model Tutorial"))

    try:
        main()
    finally:
        get_project().delete()

Explainable Model Specifications#

As specified in this tutorial, we set frozen_model to True to indicate that we only want to train the explanations of an existing model, which is kept unchanged.

from xpdeep.model.model_parameters import ModelDecisionGraphParameters
from xpdeep.model.feature_extraction_output_type import FeatureExtractionOutputType

model_specifications = ModelDecisionGraphParameters(
    graph_depth=3,
    population_pruning_threshold=0.1,
    feature_extraction_output_type=FeatureExtractionOutputType.VECTOR,
    frozen_model=True,
)
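
These parameters, together with the torch sub-models trained above, are then passed to XpdeepModel.from_torch to build the explainable model, exactly as in the full script:

from xpdeep.model.xpdeep_model import XpdeepModel

xpdeep_model = XpdeepModel.from_torch(
    example_dataset=fit_train_dataset,
    feature_extraction=feature_extractor,
    task_learner=task_learner,
    backbone=None,
    decision_graph_parameters=model_specifications,
)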

Train the Explanations#

The training step is straightforward: we only need to specify the trainer parameters. We use the FrozenModelTrainer interface for convenience, as it provides sensible defaults for the explanation training process.

Internally, Xpdeep uses its own algorithm to compute and train explanations while keeping the original model parameters and performance intact.

from xpdeep.trainer.trainer import FrozenModelTrainer

trainer = FrozenModelTrainer(start_epoch=0, max_epochs=10)
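
The explanation training is then launched on the fitted datasets, exactly as in the full script:

trained_model = trainer.train(
    model=xpdeep_model,
    train_set=fit_train_dataset,
    validation_set=fit_val_dataset,
    batch_size=128,
)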

There are no further differences from the self-explainable model training and explanation process!

You can check on XpViz that the initial model performance remains unchanged after the explanation training. For Adult Income, we use accuracy.
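
As a local reference point, you can recompute the frozen model's test accuracy, reusing the model and tensors from the training step above, and compare it with the global accuracy reported in XpViz:

# Reference accuracy of the frozen PyTorch model on the test split,
# reusing `model`, `X_test`, `y_test` and `device` from the training step.
model.eval()
with torch.no_grad():
    test_pred = model(X_test.to(device)).argmax(1).cpu()
print(f"Reference test accuracy: {accuracy_score(y_test, test_pred):.4f}")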