Skip to content

Explain an Existing Model, Kitti Dataset#

This tutorial will focus on learning explanations from an existing model, based on the D-FINE architecture, trained for an object detection problem on the KITTI dataset, and will highlight the main differences with training a self-explainable model.

The dataset preparation and the explanation computation are exactly the same as for a self-explainable model. The model preparation and the training process are, however, slightly different.

Prepare the Model#

We now need to create an explainable model XpdeepModel.

As object detection models for the KITTI dataset are not exportable with PyTorch, we need to use one of the Xpdeep pre-existing models. The ObjectDetectionFeatureExtractor and ObjectDetectionTaskLearner models will be used as the feature extraction and task learner models to build the XpdeepModel.

Prepare the Model#

1. Create the required torch models#

In the frozen model context, we use an existing pre-trained pytorch model. For this tutorial, Xpdeep provides a model pre-trained on the Kitti dataset. It is available on HuggingFace model hub, but you can provide your own model link in the hub. Use the pretrained_model_path parameter and the correct credentials in the head_config parameter if required.

Warning

For the metrics to be correctly computed, you must set max_detections to be the maximum number of bounding boxes in your ground truth images.

pretrained_model_path="Xpdeep/dfine-small-kitti"
feature_extractor = ObjectDetectionFeatureExtractor(pretrained_model_path=pretrained_model_path)
feature_extractor.load_pretrained_weights()  # Load pretrained weight

task_learner = ObjectDetectionTaskLearner(pretrained_model_path=pretrained_model_path, max_detections=22)
task_learner.load_pretrained_weights()  # Load pretrained weight
👀 Full file preview
from __future__ import annotations

import os
from copy import deepcopy
from functools import partial
from pathlib import Path
from typing import Any

import albumentations as alb
import boto3
import cv2
import numpy as np
import torch
from botocore.config import Config
from datasets import Dataset, DatasetDict, load_dataset
from PIL import Image as PILImage
from sklearn.model_selection import train_test_split
from xpdeep_metrics.object_detection import Map50, Map50To95, MapPerClass
from xpdeep_modules.object_detection.dfine_models import ObjectDetectionFeatureExtractor, ObjectDetectionTaskLearner

from xpdeep import Project, get_project, init, set_project
from xpdeep.dataset.feature import ExplainableFeature, IndexMetadata
from xpdeep.dataset.feature.feature_types import BoundingBoxesFeature, ImageFeature
from xpdeep.dataset.parquet_dataset import AnalyzedParquetDataset, FittedParquetDataset
from xpdeep.dataset.preprocessor.preprocessor import BoundingBoxesPreprocessor, TorchPreprocessor
from xpdeep.dataset.schema import AnalyzedSchema
from xpdeep.explain.explainer import Explainer
from xpdeep.explain.quality_metrics import Infidelity, Sensitivity
from xpdeep.explain.statistic import DictStats
from xpdeep.filtering.filter import Filter
from xpdeep.metric import DictMetrics, TorchGlobalMetric, TorchLeafMetric
from xpdeep.model.feature_extraction_output_type import FeatureExtractionOutputType
from xpdeep.model.model_parameters import ModelDecisionGraphParameters
from xpdeep.model.xpdeep_model import XpdeepModel
from xpdeep.trainer.trainer import FrozenModelTrainer

# Credentials/endpoint for the `s3://` parquet datasets used below
# (`storage_options=STORAGE_OPTIONS`); all values come from environment
# variables so no secrets are hard-coded in the tutorial.
STORAGE_OPTIONS = {
    "key": os.getenv("S3_DATASET_ACCESS_KEY_ID"),
    "secret": os.getenv("S3_DATASET_SECRET_ACCESS_KEY"),
    "client_kwargs": {
        "endpoint_url": os.getenv("S3_DATASET_ENDPOINT_URL"),
    },
    # Path-style addressing, commonly required by non-AWS S3-compatible endpoints.
    "s3_additional_kwargs": {"addressing_style": "path"},
}

# Define env variable to ensure multiprocessing works: cap intra-op thread pools
# so the parallel `Dataset.map(num_proc=...)` workers do not oversubscribe CPUs.
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")
cv2.setNumThreads(0)  # same reason: disable OpenCV's internal threading in workers


def add_xpdeep_index(dataset_dict: "DatasetDict"):
    """Attach a globally unique 'kitti_index' column across all splits.

    Indices are offset per split so that no two rows in the whole dataset
    dictionary share the same value.
    """

    def _with_offset(batch: dict, indices: list, offset: int = 0) -> dict:
        # Shift the per-split row indices by the number of rows seen so far.
        batch["kitti_index"] = np.array(indices) + offset
        return batch

    running_offset = 0
    for split_name in dataset_dict:
        dataset_dict[split_name] = dataset_dict[split_name].map(
            _with_offset, batched=True, with_indices=True, fn_kwargs={"offset": running_offset}
        )
        running_offset += dataset_dict[split_name].num_rows


def read_kitti_label_file(
    txt_path: Path, image_width: int, image_height: int, class_to_id: dict[str, int]
) -> list[list[float | Any]]:
    """Parse a KITTI label file into YOLO-normalized boxes.

    Parameters
    ----------
    txt_path : Path
        Path to the KITTI label text file to read.
    image_width : int
        Width of the corresponding image in pixels.
    image_height : int
        Height of the corresponding image in pixels.
    class_to_id : dict[str, int]
        Mapping from KITTI class names (e.g., ``"Car"``) to integer class IDs.
        Any class name not present in this mapping is ignored.

    Returns
    -------
    list[list[float | Any]]
        A list of detections; each detection is
        ``[cx, cy, w, h, score, class_id]`` where:
        ``cx, cy, w, h`` are YOLO-normalized to ``[0, 1]`` relative to
        ``(image_width, image_height)``, ``score`` is ``1.0`` for parsed boxes,
        and ``class_id`` is the integer from ``class_to_id``.

    Notes
    -----
    - Boxes are clipped to the image frame; invalid (degenerate) boxes are skipped.
    - Lines whose class is not in ``class_to_id`` (e.g., ``DontCare``) are ignored.
    """
    boxes: list[list[float | Any]] = []
    if not txt_path.exists():
        return boxes

    for line in txt_path.read_text(encoding="utf8").strip().splitlines():
        parts = line.split()
        cls_name = parts[0]
        if cls_name not in class_to_id:
            continue

        xmin, ymin, xmax, ymax = map(float, parts[4:8])
        xmin = max(0.0, xmin)
        ymin = max(0.0, ymin)
        xmax = min(image_width - 1.0, xmax)
        ymax = min(image_height - 1.0, ymax)
        if xmax <= xmin or ymax <= ymin:
            continue

        cx = ((xmin + xmax) * 0.5) / image_width
        cy = ((ymin + ymax) * 0.5) / image_height
        w = (xmax - xmin) / image_width
        h = (ymax - ymin) / image_height
        class_id = int(class_to_id[cls_name])

        if w > 0.0 and h > 0.0:
            boxes.append([cx, cy, w, h, 1.0, class_id])
    return boxes


def compute_majority_class(labels: np.ndarray, areas: np.ndarray) -> int:
    """Return the most frequent class id, breaking ties by total covered area.

    Parameters
    ----------
    labels : np.ndarray
        Class ids for all objects in an image (must be non-empty; callers guard this).
    areas : np.ndarray
        Per-object area proxy (e.g., w*h), used only to break frequency ties.

    Returns
    -------
    int
        The winning class id.
    """
    occurrences = np.bincount(labels)
    candidates = np.flatnonzero(occurrences == occurrences.max())
    if len(candidates) > 1:
        # Several classes share the top count: pick the one covering the most area.
        area_by_class = np.bincount(labels, weights=areas, minlength=occurrences.shape[0])
        return int(candidates[np.argmax(area_by_class[candidates])])
    return int(candidates[0])


def attach_objects_to_sample(sample: dict, label_dir: Path, class_to_id: dict[str, int]) -> dict[str, Any]:
    """Per-sample HF `Dataset.map` function: parse the KITTI labels matching an image.

    Produces three new columns:
      - "objects": YOLO-normalized rows [cx, cy, w, h, score, class_id],
      - "majority_class": dominant class id (-1 when the image has no objects),
      - "n_objects": number of parsed boxes.
    """
    width, height = sample["image"].size
    label_path = label_dir / f"{Path(sample['image'].filename).stem}.txt"
    objects = read_kitti_label_file(label_path, width, height, class_to_id)

    if not objects:
        return {"objects": objects, "majority_class": -1, "n_objects": 0}

    box_array = np.asarray(objects, dtype=np.float32)
    majority = compute_majority_class(box_array[:, -1].astype(np.int64), box_array[:, 2] * box_array[:, 3])
    return {"objects": objects, "majority_class": majority, "n_objects": int(box_array.shape[0])}


def resize_and_map_fn(  # noqa: PLR0913
    batch: dict,
    *,
    do_pad: bool,
    image_size_wh: tuple[int, int],
    max_size_hw: tuple[int, int] | None,
    pad_size_hw: tuple[int, int] | None,
    padding_row: list[float],
    target_len: int,
) -> dict:
    """Picklable `Dataset.map` worker: resize a batch of images and remap their YOLO boxes.

    The boxes of every image are padded (with ``padding_row``) or truncated to
    exactly ``target_len`` rows so the "objects" column has a fixed shape.
    """
    target_w, target_h = image_size_wh  # config is (W, H)
    # Recreate bbox params inside each call so nothing is shared across processes.
    bbox_params = alb.BboxParams(format="yolo", label_fields=["labels", "scores"])

    images_out: list[Any] = []
    objects_out: list[list[list[float]]] = []

    for pil_image, rows_in in zip(batch["image"], batch["objects"], strict=False):
        pixels = np.array(pil_image)  # HxWxC
        height, width = pixels.shape[:2]

        bboxes = [row[:4] for row in rows_in] if rows_in else []
        scores = [float(row[4]) for row in rows_in] if rows_in else []
        labels = [int(row[5]) for row in rows_in] if rows_in else []

        if do_pad:
            if max_size_hw is None:
                msg = "When do_pad=True, set max_size_hw=(MAX_H, MAX_W)."
                raise ValueError(msg)
            limit_h, limit_w = max_size_hw
            canvas_h, canvas_w = pad_size_hw or max_size_hw

            # Letterbox: keep aspect ratio, then pad to the canvas (top-left anchor).
            ratio = min(limit_h / float(height), limit_w / float(width))
            steps = [
                alb.Resize(height=round(height * ratio), width=round(width * ratio), interpolation=cv2.INTER_LINEAR),
                alb.PadIfNeeded(
                    min_height=canvas_h,
                    min_width=canvas_w,
                    position="top_left",
                    border_mode=cv2.BORDER_CONSTANT,
                    value=0,
                ),
            ]
        else:
            # Plain anisotropic resize straight to the configured (W, H).
            steps = [alb.Resize(height=target_h, width=target_w, interpolation=cv2.INTER_LINEAR)]

        transformed = alb.Compose(steps, bbox_params=bbox_params)(
            image=pixels, bboxes=bboxes, labels=labels, scores=scores
        )

        rows = [
            [*bbox, score, label]
            for bbox, label, score in zip(
                transformed["bboxes"], transformed["labels"], transformed["scores"], strict=False
            )
        ]
        # Fixed-length objects column: pad with sentinel rows or drop the excess.
        rows = rows[:target_len] if len(rows) >= target_len else rows + [padding_row] * (target_len - len(rows))

        images_out.append(PILImage.fromarray(transformed["image"]))
        objects_out.append(rows)

    return {"image": images_out, "objects": objects_out}


class KittiPreparationScript:
    """
    KITTI preparation script (object detection).

    Steps:
      1) Load images via HF `imagefolder` and parse KITTI label files.
      2) Create stratified train/val/test splits by per-image majority class.
      3) Resize images using either:
           - direct resize to (W, H) with no padding (do_pad=False), or
           - letterbox (rectangular general case; square if H==W) with top-left placement (do_pad=True).
         Bounding boxes (YOLO format [cx, cy, w, h] normalized) are transformed automatically.
      4) (Optionally) write raw and/or fully preprocessed parquet splits and fit schema.

    Note: call `load_full_dataset_with_objects` before `make_splits` — the former
    records `maximum_objects_per_image`, which the latter needs to pad every
    per-image objects column to a fixed length.
    """

    def __init__(  # noqa: PLR0917,PLR0913
        self,
        dataset_root_path: str,
        image_size: tuple[int, int] = (1024, 320),
        max_size_hw: tuple[int, int] | None = None,
        do_pad: bool = False,
        pad_size_hw: tuple[int, int] | None = None,
        dataset_name: str = "KITTI",
        random_seed: int = 1225,
    ) -> None:
        """
        Initialize KITTI preparation script.

        Parameters
        ----------
        dataset_root_path : str
            Root path containing `images/` and `labels/` folders.
        image_size : tuple[int, int]
            Target (W, H) when do_pad=False (direct resize); default (1024, 320)
        max_size_hw : tuple[int, int] | None
            (MAX_H, MAX_W) for letterbox. If do_pad=True, this is required.
            Set MAX_H==MAX_W for square letterbox; default None.
        do_pad : bool
            If True, use letterbox with top-left padding; else anisotropic resize; default False.
        pad_size_hw : tuple[int, int] | None
            (PAD_H, PAD_W) final canvas for letterbox. Defaults to max_size_hw; default None
        dataset_name : str
            Name used when writing parquet splits; default "KITTI".
        random_seed : int
            Seed for deterministic splits; default 1225.
        """
        self.dataset_root_path = Path(dataset_root_path)
        self.image_size = image_size
        self.do_pad = do_pad
        self.max_size_hw = max_size_hw
        self.pad_size_hw = pad_size_hw
        self.dataset_name = dataset_name
        self.random_seed = random_seed

        # KITTI classes, we do not use DontCare class like other works
        self.classes = {
            "Car": 0,
            "Pedestrian": 1,
            "Van": 2,
            "Cyclist": 3,
            "Truck": 4,
            "Misc": 5,
            "Tram": 6,
            "Person_sitting": 7,
        }

        # Pad bboxes in each image to have same number of bboxes per image in the dataset.
        # padding row format [cx, cy, w, h, confidence, class_id], where confidence is 0
        # and class_id = num_classes (here 8)
        self.padding_row = [-1, -1, 0, 0, 0, len(self.classes)]

        # Will be filled after scanning the dataset (by `load_full_dataset_with_objects`)
        self.maximum_objects_per_image: int | None = None

        # Albumentations bbox settings (YOLO), tied to labels and scores
        # NOTE(review): this attribute looks unused — `resize_and_map_fn` builds its own
        # BboxParams inside each worker process; confirm before removing.
        self.albu_bbox_params = alb.BboxParams(format="yolo", label_fields=["labels", "scores"])

        # Parallelism for HF map()
        self.map_num_proc = min(4, os.cpu_count() or 1)

    def load_full_dataset_with_objects(self) -> Dataset:
        """Load images and labels from `images/` and `labels/` folder.

        Attach to each sample:
          - objects: list[[cx, cy, w, h, score, class_id]] in YOLO normalized coords
          - majority_class: per-image majority class id

        Side effect: sets `self.maximum_objects_per_image`, which `make_splits`
        uses as the fixed length of the objects column.
        """
        dataset = load_dataset(
            "imagefolder", data_dir=str(self.dataset_root_path / "images"), split="train", drop_labels=True
        )

        label_dir = self.dataset_root_path / "labels"
        map_fn = partial(attach_objects_to_sample, label_dir=label_dir, class_to_id=self.classes)

        dataset = dataset.map(map_fn, desc="Attach objects + majority class")
        # Largest per-image box count over the whole dataset (0 when the dataset is empty).
        self.maximum_objects_per_image = int(np.max(dataset["n_objects"])) if len(dataset) else 0
        return dataset.remove_columns(["n_objects"])

    def make_splits(self, full_dataset: Dataset) -> DatasetDict:
        """Create stratified splits by majority class, resize images and transform boxes according to config."""
        majority_class = np.array(full_dataset["majority_class"])
        full_dataset = full_dataset.remove_columns("majority_class")

        # 70% train / 30% temporary pool, stratified on the per-image majority class.
        all_indices = np.arange(len(full_dataset))
        training_indices, temporary_indices = train_test_split(
            all_indices, test_size=0.30, stratify=majority_class, random_state=self.random_seed
        )

        # Split the 30% pool in half: 15% validation / 15% test overall.
        validation_proportion_relative = 0.15 / 0.30
        validation_indices, test_indices = train_test_split(
            temporary_indices,
            test_size=1 - validation_proportion_relative,
            stratify=majority_class[temporary_indices],
            random_state=self.random_seed,
        )

        splits = {
            "train": full_dataset.select(training_indices),
            "validation": full_dataset.select(validation_indices),
            "test": full_dataset.select(test_indices),
        }

        # target_len relies on `load_full_dataset_with_objects` having run first.
        resize_and_map = partial(
            resize_and_map_fn,
            do_pad=self.do_pad,
            image_size_wh=self.image_size,
            max_size_hw=self.max_size_hw,
            pad_size_hw=self.pad_size_hw,
            padding_row=self.padding_row,
            target_len=self.maximum_objects_per_image,
        )

        # Apply preprocessing per split (batched & parallel).
        for split_name, split in splits.items():
            splits[split_name] = split.map(
                resize_and_map,
                desc=f"resize+map {split_name}",
                batched=True,
                batch_size=256,
                num_proc=self.map_num_proc,
                load_from_cache_file=True,
            )

        return DatasetDict(splits)


def main() -> None:
    """Process the dataset, train, and explain the model.

    End-to-end tutorial flow: prepare and upload the KITTI splits, declare and
    fit the schema, wrap the frozen pre-trained detection model into an
    XpdeepModel, train only the explanation components, then compute global
    and local explanations.
    """
    # ##### Prepare the Dataset #######
    preparation_class = KittiPreparationScript(dataset_root_path="object_detection_kitti")
    dataset = preparation_class.load_full_dataset_with_objects()
    splits = preparation_class.make_splits(dataset)

    # Add the xpdeep index column ("kitti_index"); it is declared manually as
    # IndexMetadata below because we do not use the dataset `analyze` method.
    add_xpdeep_index(splits)
    splits.set_format("numpy")  # convert to numpy the pil images

    # Save each split as ".parquet" file
    for split_name, split_data in splits.items():
        # HuggingFace: set chunk to be groups of 100mb by default, here we use 64 rows to be memory efficient.
        # For kitti, we set chunk_size to 64 to get ~50mb per group to avoid memory errors on XpViz.
        split_data.to_parquet(f"{split_name}.parquet", batch_size=64)

    # 2. Upload your Converted Data
    client = boto3.client(
        service_name="s3",
        endpoint_url=os.getenv("S3_DATASET_ENDPOINT_URL"),
        aws_access_key_id=os.getenv("S3_DATASET_ACCESS_KEY_ID"),
        aws_secret_access_key=os.getenv("S3_DATASET_SECRET_ACCESS_KEY"),
        config=Config(signature_version="s3v4"),
    )

    # Note: the local "validation.parquet" is uploaded under the shorter "val.parquet" key.
    client.upload_file("train.parquet", os.getenv("S3_DATASET_BUCKET_NAME"), "frozen_model_kitti/train.parquet")
    client.upload_file("validation.parquet", os.getenv("S3_DATASET_BUCKET_NAME"), "frozen_model_kitti/val.parquet")
    client.upload_file("test.parquet", os.getenv("S3_DATASET_BUCKET_NAME"), "frozen_model_kitti/test.parquet")

    # 3. Find a schema.

    # Set a custom image preprocessor, different from the one provided by the AutoAnalyzer.
    class ScaleKitti(TorchPreprocessor):
        """Kitti preprocessor: given an image with pixel values in [0, 255], scale them to [0, 1]."""

        def transform(self, inputs: torch.Tensor) -> torch.Tensor:  # noqa: PLR6301
            """Scale raw pixel values from [0, 255] down to [0, 1]."""
            return inputs / 255.0

        def inverse_transform(self, output: torch.Tensor) -> torch.Tensor:  # noqa: PLR6301
            """Map [0, 1] values back to the original [0, 255] pixel range."""
            return output * 255.0

    # input_size is (H, W, C) = (320, 1024, 3), matching the (W, H) = (1024, 320) resize above.
    image = ExplainableFeature(
        name="image", feature_type=ImageFeature(), preprocessor=ScaleKitti(input_size=(320, 1024, 3)), is_target=False
    )

    # BBOX not supported in AutoAnalyzer, needs to manually define the feature
    target = ExplainableFeature(
        name="objects",
        feature_type=BoundingBoxesFeature(
            categories=list(preparation_class.classes.keys()),
        ),
        preprocessor=BoundingBoxesPreprocessor(preprocessed_size=None),
        is_target=True,
    )

    index = IndexMetadata(name="kitti_index")

    # Add the index column in the AnalyzedSchema as we don't use the dataset `analyze` method
    analyzed_schema = AnalyzedSchema(image, target, index)
    analyzed_train_dataset = AnalyzedParquetDataset(
        analyzed_schema=analyzed_schema,
        name="kitti_train_set",
        path=f"s3://{os.getenv('S3_DATASET_BUCKET_NAME')}/frozen_model_kitti/train.parquet",
        storage_options=STORAGE_OPTIONS,
    )

    print(analyzed_schema)

    # 4. Fit the schema

    fit_train_dataset = analyzed_train_dataset.fit()

    # Test and validation sets reuse a copy of the schema fitted on the train set.
    fit_test_dataset = FittedParquetDataset(
        name="kitti_test_set",
        path=f"s3://{os.getenv('S3_DATASET_BUCKET_NAME')}/frozen_model_kitti/test.parquet",
        storage_options=STORAGE_OPTIONS,
        fitted_schema=deepcopy(fit_train_dataset.fitted_schema),
    )

    fit_val_dataset = FittedParquetDataset(
        name="kitti_validation_set",
        path=f"s3://{os.getenv('S3_DATASET_BUCKET_NAME')}/frozen_model_kitti/val.parquet",
        storage_options=STORAGE_OPTIONS,
        fitted_schema=deepcopy(fit_train_dataset.fitted_schema),
    )

    # ##### Prepare the Model #######
    # 1. Create the required torch models, pre-trained on KITTI and kept frozen.
    pretrained_model_path = "Xpdeep/dfine-small-kitti"  # Checkpoint path on huggingface hub
    feature_extractor = ObjectDetectionFeatureExtractor(pretrained_model_path=pretrained_model_path)
    feature_extractor.load_pretrained_weights()  # Load pretrained weight

    # max_detections must cover the maximum number of ground-truth boxes per image
    # for the detection metrics to be computed correctly (see tutorial warning).
    task_learner = ObjectDetectionTaskLearner(pretrained_model_path=pretrained_model_path, max_detections=22)
    task_learner.load_pretrained_weights()  # weights are not loaded by __init__; the call must be explicit

    # 2. Explainable Model Specifications
    # frozen_model=True: only the explanations are trained; the wrapped model stays unchanged.
    model_specifications = ModelDecisionGraphParameters(
        graph_depth=3,
        population_pruning_threshold=0.1,
        feature_extraction_output_type=FeatureExtractionOutputType.DFINE_MATRIX,
        frozen_model=True,
    )

    # 3. Create the Explainable Model
    xpdeep_model = XpdeepModel(
        example_dataset=fit_train_dataset,
        feature_extraction=feature_extractor,
        task_learner=task_learner,
        backbone=None,
        decision_graph_parameters=model_specifications,
    )

    # ##### Train #######

    # FrozenModelTrainer trains the explanations while keeping the model frozen.
    trainer = FrozenModelTrainer(start_epoch=0, max_epochs=13)
    trained_model = trainer.train(
        model=xpdeep_model,
        train_set=fit_train_dataset,
        validation_set=fit_val_dataset,
        batch_size=32,
    )

    # ##### Explain #######

    # 1. Build the Explainer
    statistics = DictStats()
    quality_metrics = [Sensitivity(), Infidelity()]

    # Detection metrics (mAP variants); boxes are in "cxcywh" (YOLO center) format.
    metrics = DictMetrics(
        map50=TorchGlobalMetric(metric=partial(Map50, box_format="cxcywh"), on_raw_data=False),
        leaf_metric_map50=TorchLeafMetric(metric=partial(Map50, box_format="cxcywh"), on_raw_data=False),
        map50_95=TorchGlobalMetric(metric=partial(Map50To95, box_format="cxcywh"), on_raw_data=False),
        leaf_metric_map50_95=TorchLeafMetric(metric=partial(Map50To95, box_format="cxcywh"), on_raw_data=False),
        map_per_class=TorchGlobalMetric(metric=partial(MapPerClass, box_format="cxcywh"), on_raw_data=False),
        leaf_metric_map_per_class=TorchLeafMetric(metric=partial(MapPerClass, box_format="cxcywh"), on_raw_data=False),
    )
    explainer = Explainer(
        description_representativeness=10, quality_metrics=quality_metrics, metrics=metrics, statistics=statistics
    )

    # 2. Model Functioning Explanations
    model_explanations = explainer.global_explain(
        trained_model,
        train_set=fit_train_dataset,
        test_set=fit_test_dataset,
        validation_set=fit_val_dataset,
    )
    print(model_explanations.visualisation_link)

    # 3. Inference and their Causal Explanations
    # No criterion exists to filter by image content, so we filter by indexes only.
    my_filter = Filter("testing_filter", fit_test_dataset, min_index=10, max_index=20)
    causal_explanations = explainer.local_explain(trained_model, fit_test_dataset, my_filter)

    print(causal_explanations.visualisation_link)


if __name__ == "__main__":
    # Authenticate against the Xpdeep platform and scope all artifacts to one project.
    init(api_key=os.getenv("API_KEY"), api_url=os.getenv("API_URL"))
    set_project(Project.create_or_get(name="Kitti Tutorial"))

    try:
        main()
    finally:
        # Tutorial cleanup: always delete the project, even when the run fails.
        get_project().delete()

Note

The load_pretrained_weights method loads the pretrained weights into the model. You need to call it explicitly, as the model initialization does not load the pretrained weights, for performance reasons.

2. Explainable Model Specifications#

As specified in this tutorial, we set frozen_model to True to specify that we only want to train the explanations of an existing model, which is kept unchanged.

from xpdeep.model.model_parameters import ModelDecisionGraphParameters
from xpdeep.model.feature_extraction_output_type import FeatureExtractionOutputType

model_specifications = ModelDecisionGraphParameters(
    graph_depth=3,
    population_pruning_threshold=0.1,
    feature_extraction_output_type=FeatureExtractionOutputType.DFINE_MATRIX,
    frozen_model=True,
)
👀 Full file preview
from __future__ import annotations

import os
from copy import deepcopy
from functools import partial
from pathlib import Path
from typing import Any

import albumentations as alb
import boto3
import cv2
import numpy as np
import torch
from botocore.config import Config
from datasets import Dataset, DatasetDict, load_dataset
from PIL import Image as PILImage
from sklearn.model_selection import train_test_split
from xpdeep_metrics.object_detection import Map50, Map50To95, MapPerClass
from xpdeep_modules.object_detection.dfine_models import ObjectDetectionFeatureExtractor, ObjectDetectionTaskLearner

from xpdeep import Project, get_project, init, set_project
from xpdeep.dataset.feature import ExplainableFeature, IndexMetadata
from xpdeep.dataset.feature.feature_types import BoundingBoxesFeature, ImageFeature
from xpdeep.dataset.parquet_dataset import AnalyzedParquetDataset, FittedParquetDataset
from xpdeep.dataset.preprocessor.preprocessor import BoundingBoxesPreprocessor, TorchPreprocessor
from xpdeep.dataset.schema import AnalyzedSchema
from xpdeep.explain.explainer import Explainer
from xpdeep.explain.quality_metrics import Infidelity, Sensitivity
from xpdeep.explain.statistic import DictStats
from xpdeep.filtering.filter import Filter
from xpdeep.metric import DictMetrics, TorchGlobalMetric, TorchLeafMetric
from xpdeep.model.feature_extraction_output_type import FeatureExtractionOutputType
from xpdeep.model.model_parameters import ModelDecisionGraphParameters
from xpdeep.model.xpdeep_model import XpdeepModel
from xpdeep.trainer.trainer import FrozenModelTrainer

# Credentials/endpoint for the `s3://` parquet datasets used below
# (`storage_options=STORAGE_OPTIONS`); all values come from environment
# variables so no secrets are hard-coded in the tutorial.
STORAGE_OPTIONS = {
    "key": os.getenv("S3_DATASET_ACCESS_KEY_ID"),
    "secret": os.getenv("S3_DATASET_SECRET_ACCESS_KEY"),
    "client_kwargs": {
        "endpoint_url": os.getenv("S3_DATASET_ENDPOINT_URL"),
    },
    # Path-style addressing, commonly required by non-AWS S3-compatible endpoints.
    "s3_additional_kwargs": {"addressing_style": "path"},
}

# Define env variable to ensure multiprocessing works: cap intra-op thread pools
# so the parallel `Dataset.map(num_proc=...)` workers do not oversubscribe CPUs.
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")
cv2.setNumThreads(0)  # same reason: disable OpenCV's internal threading in workers


def add_xpdeep_index(dataset_dict: "DatasetDict"):
    """Attach a globally unique 'kitti_index' column across all splits.

    Indices are offset per split so that no two rows in the whole dataset
    dictionary share the same value.
    """

    def _with_offset(batch: dict, indices: list, offset: int = 0) -> dict:
        # Shift the per-split row indices by the number of rows seen so far.
        batch["kitti_index"] = np.array(indices) + offset
        return batch

    running_offset = 0
    for split_name in dataset_dict:
        dataset_dict[split_name] = dataset_dict[split_name].map(
            _with_offset, batched=True, with_indices=True, fn_kwargs={"offset": running_offset}
        )
        running_offset += dataset_dict[split_name].num_rows


def read_kitti_label_file(
    txt_path: Path, image_width: int, image_height: int, class_to_id: dict[str, int]
) -> list[list[float | Any]]:
    """Parse a KITTI label file into YOLO-normalized boxes.

    Parameters
    ----------
    txt_path : Path
        Path to the KITTI label text file to read.
    image_width : int
        Width of the corresponding image in pixels.
    image_height : int
        Height of the corresponding image in pixels.
    class_to_id : dict[str, int]
        Mapping from KITTI class names (e.g., ``"Car"``) to integer class IDs.
        Any class name not present in this mapping is ignored.

    Returns
    -------
    list[list[float | Any]]
        A list of detections; each detection is
        ``[cx, cy, w, h, score, class_id]`` where:
        ``cx, cy, w, h`` are YOLO-normalized to ``[0, 1]`` relative to
        ``(image_width, image_height)``, ``score`` is ``1.0`` for parsed boxes,
        and ``class_id`` is the integer from ``class_to_id``.

    Notes
    -----
    - Boxes are clipped to the image frame; invalid (degenerate) boxes are skipped.
    - Lines whose class is not in ``class_to_id`` (e.g., ``DontCare``) are ignored.
    """
    boxes: list[list[float | Any]] = []
    if not txt_path.exists():
        return boxes

    for line in txt_path.read_text(encoding="utf8").strip().splitlines():
        parts = line.split()
        cls_name = parts[0]
        if cls_name not in class_to_id:
            continue

        xmin, ymin, xmax, ymax = map(float, parts[4:8])
        xmin = max(0.0, xmin)
        ymin = max(0.0, ymin)
        xmax = min(image_width - 1.0, xmax)
        ymax = min(image_height - 1.0, ymax)
        if xmax <= xmin or ymax <= ymin:
            continue

        cx = ((xmin + xmax) * 0.5) / image_width
        cy = ((ymin + ymax) * 0.5) / image_height
        w = (xmax - xmin) / image_width
        h = (ymax - ymin) / image_height
        class_id = int(class_to_id[cls_name])

        if w > 0.0 and h > 0.0:
            boxes.append([cx, cy, w, h, 1.0, class_id])
    return boxes


def compute_majority_class(labels: np.ndarray, areas: np.ndarray) -> int:
    """Return the most frequent class id, breaking ties by total covered area.

    Parameters
    ----------
    labels : np.ndarray
        Class ids for all objects in an image (must be non-empty; callers guard this).
    areas : np.ndarray
        Per-object area proxy (e.g., w*h), used only to break frequency ties.

    Returns
    -------
    int
        The winning class id.
    """
    occurrences = np.bincount(labels)
    candidates = np.flatnonzero(occurrences == occurrences.max())
    if len(candidates) > 1:
        # Several classes share the top count: pick the one covering the most area.
        area_by_class = np.bincount(labels, weights=areas, minlength=occurrences.shape[0])
        return int(candidates[np.argmax(area_by_class[candidates])])
    return int(candidates[0])


def attach_objects_to_sample(sample: dict, label_dir: Path, class_to_id: dict[str, int]) -> dict[str, Any]:
    """Per-sample HF `Dataset.map` function: parse the KITTI labels matching an image.

    Produces three new columns:
      - "objects": YOLO-normalized rows [cx, cy, w, h, score, class_id],
      - "majority_class": dominant class id (-1 when the image has no objects),
      - "n_objects": number of parsed boxes.
    """
    width, height = sample["image"].size
    label_path = label_dir / f"{Path(sample['image'].filename).stem}.txt"
    objects = read_kitti_label_file(label_path, width, height, class_to_id)

    if not objects:
        return {"objects": objects, "majority_class": -1, "n_objects": 0}

    box_array = np.asarray(objects, dtype=np.float32)
    majority = compute_majority_class(box_array[:, -1].astype(np.int64), box_array[:, 2] * box_array[:, 3])
    return {"objects": objects, "majority_class": majority, "n_objects": int(box_array.shape[0])}


def resize_and_map_fn(  # noqa: PLR0913
    batch: dict,
    *,
    do_pad: bool,
    image_size_wh: tuple[int, int],
    max_size_hw: tuple[int, int] | None,
    pad_size_hw: tuple[int, int] | None,
    padding_row: list[float],
    target_len: int,
) -> dict:
    """Picklable map function for multiprocessing: resize images and remap YOLO boxes.

    Parameters
    ----------
    batch : dict
        HF batch with "image" (PIL images) and "objects" (per-image lists of
        [cx, cy, w, h, score, class_id] rows in YOLO-normalized coordinates).
    do_pad : bool
        If True, letterbox-resize to fit inside `max_size_hw` then pad to
        `pad_size_hw` with top-left placement; otherwise resize directly to
        `image_size_wh`.
    image_size_wh : tuple[int, int]
        Target (W, H) for the direct-resize path.
    max_size_hw : tuple[int, int] | None
        (MAX_H, MAX_W) bounding size for the letterbox path; required when do_pad=True.
    pad_size_hw : tuple[int, int] | None
        (PAD_H, PAD_W) final canvas for the letterbox path; defaults to `max_size_hw`.
    padding_row : list[float]
        Row used to pad each image's object list up to `target_len`.
    target_len : int
        Fixed number of object rows per image after padding/truncation.

    Returns
    -------
    dict
        {"image": resized PIL images, "objects": fixed-length object rows}.

    Raises
    ------
    ValueError
        If do_pad=True but `max_size_hw` is None (raised up front, before any work).
    """
    out_w, out_h = image_size_wh  # (W, H)
    # Build bbox params *inside* each process to avoid cross-proc sharing
    albu_bbox_params = alb.BboxParams(format="yolo", label_fields=["labels", "scores"])

    # Fail fast on an invalid letterbox config instead of raising per image.
    if do_pad and max_size_hw is None:
        msg = "When do_pad=True, set max_size_hw=(MAX_H, MAX_W)."
        raise ValueError(msg)

    # The direct-resize pipeline is loop-invariant (it does not depend on the
    # input image size), so build it once. The letterbox pipeline depends on
    # each image's h/w and must stay inside the loop.
    direct_resize = None
    if not do_pad:
        direct_resize = alb.Compose(
            [alb.Resize(height=out_h, width=out_w, interpolation=cv2.INTER_LINEAR)],
            bbox_params=albu_bbox_params,
        )

    processed_images: list[Any] = []
    processed_objects: list[list[list[float]]] = []

    for pil_img, obj_list in zip(batch["image"], batch["objects"], strict=False):
        img_np = np.array(pil_img)  # HxWxC
        h, w = img_np.shape[:2]

        if obj_list:
            bboxes = [o[:4] for o in obj_list]
            scores = [float(o[4]) for o in obj_list]
            labels = [int(o[5]) for o in obj_list]
        else:
            bboxes, labels, scores = [], [], []

        if direct_resize is not None:
            transform = direct_resize
        else:
            max_h, max_w = max_size_hw
            pad_h, pad_w = pad_size_hw or max_size_hw

            # Letterbox: uniform scale so the image fits within (max_h, max_w).
            scale = min(max_h / float(h), max_w / float(w))
            new_h = round(h * scale)
            new_w = round(w * scale)

            transform = alb.Compose(
                [
                    alb.Resize(height=new_h, width=new_w, interpolation=cv2.INTER_LINEAR),
                    alb.PadIfNeeded(
                        min_height=pad_h,
                        min_width=pad_w,
                        position="top_left",
                        border_mode=cv2.BORDER_CONSTANT,
                        value=0,
                    ),
                ],
                bbox_params=albu_bbox_params,
            )

        out = transform(image=img_np, bboxes=bboxes, labels=labels, scores=scores)
        img_resized = PILImage.fromarray(out["image"])
        bboxes_out, labels_out, scores_out = out["bboxes"], out["labels"], out["scores"]

        rows = [[*bb, sc, lb] for bb, lb, sc in zip(bboxes_out, labels_out, scores_out, strict=False)]

        # Pad/truncate to fixed length. Use an independent copy of the padding
        # row per slot so a later in-place edit of one row cannot alias into
        # the others (the original appended N references to the same list).
        if len(rows) < target_len:
            rows += [list(padding_row) for _ in range(target_len - len(rows))]
        else:
            rows = rows[:target_len]

        processed_images.append(img_resized)
        processed_objects.append(rows)

    return {"image": processed_images, "objects": processed_objects}


class KittiPreparationScript:
    """
    KITTI preparation script (object detection).

    Steps:
      1) Load images via HF `imagefolder` and parse KITTI label files.
      2) Create stratified train/val/test splits by per-image majority class.
      3) Resize images using either:
           - direct resize to (W, H) with no padding (do_pad=False), or
           - letterbox (rectangular general case; square if H==W) with top-left placement (do_pad=True).
         Bounding boxes (YOLO format [cx, cy, w, h] normalized) are transformed automatically.
      4) (Optionally) write raw and/or fully preprocessed parquet splits and fit schema.

    Call order: `load_full_dataset_with_objects` must run before `make_splits`,
    because it computes the per-image maximum object count used to pad boxes.
    """

    def __init__(  # noqa: PLR0917,PLR0913
        self,
        dataset_root_path: str,
        image_size: tuple[int, int] = (1024, 320),
        max_size_hw: tuple[int, int] | None = None,
        do_pad: bool = False,
        pad_size_hw: tuple[int, int] | None = None,
        dataset_name: str = "KITTI",
        random_seed: int = 1225,
    ) -> None:
        """
        Initialize KITTI preparation script.

        Parameters
        ----------
        dataset_root_path : str
            Root path containing `images/` and `labels/` folders.
        image_size : tuple[int, int]
            Target (W, H) when do_pad=False (direct resize); default (1024, 320)
        max_size_hw : tuple[int, int] | None
            (MAX_H, MAX_W) for letterbox. If do_pad=True, this is required.
            Set MAX_H==MAX_W for square letterbox; default None.
        do_pad : bool
            If True, use letterbox with top-left padding; else anisotropic resize; default False.
        pad_size_hw : tuple[int, int] | None
            (PAD_H, PAD_W) final canvas for letterbox. Defaults to max_size_hw; default None
        dataset_name : str
            Name used when writing parquet splits; default "KITTI".
        random_seed : int
            Seed for deterministic splits; default 1225.
        """
        self.dataset_root_path = Path(dataset_root_path)
        self.image_size = image_size
        self.do_pad = do_pad
        self.max_size_hw = max_size_hw
        self.pad_size_hw = pad_size_hw
        self.dataset_name = dataset_name
        self.random_seed = random_seed

        # KITTI classes, we do not use DontCare class like other works
        self.classes = {
            "Car": 0,
            "Pedestrian": 1,
            "Van": 2,
            "Cyclist": 3,
            "Truck": 4,
            "Misc": 5,
            "Tram": 6,
            "Person_sitting": 7,
        }

        # Pad bboxes in each image to have same number of bboxes per image in the dataset.
        # padding row format [cx, cy, w, h, confidence, class_id], where confidence is 0
        # and class_id = num_classes (here 8)
        self.padding_row = [-1, -1, 0, 0, 0, len(self.classes)]

        # Filled by `load_full_dataset_with_objects` after scanning the dataset.
        self.maximum_objects_per_image: int | None = None

        # Albumentations bbox settings (YOLO), tied to labels and scores
        self.albu_bbox_params = alb.BboxParams(format="yolo", label_fields=["labels", "scores"])

        # Parallelism for HF map()
        self.map_num_proc = min(4, os.cpu_count() or 1)

    def load_full_dataset_with_objects(self) -> Dataset:
        """Load images and labels from `images/` and `labels/` folder.

        Attach to each sample:
          - objects: list[[cx, cy, w, h, score, class_id]] in YOLO normalized coords
          - majority_class: per-image majority class id

        Also records `maximum_objects_per_image`, which `make_splits` requires.
        """
        dataset = load_dataset(
            "imagefolder", data_dir=str(self.dataset_root_path / "images"), split="train", drop_labels=True
        )

        label_dir = self.dataset_root_path / "labels"
        map_fn = partial(attach_objects_to_sample, label_dir=label_dir, class_to_id=self.classes)

        dataset = dataset.map(map_fn, desc="Attach objects + majority class")
        self.maximum_objects_per_image = int(np.max(dataset["n_objects"])) if len(dataset) else 0
        return dataset.remove_columns(["n_objects"])

    def make_splits(self, full_dataset: Dataset) -> DatasetDict:
        """Create stratified splits by majority class, resize images and transform boxes according to config.

        Raises
        ------
        RuntimeError
            If called before `load_full_dataset_with_objects`: the padding
            length would be None and worker processes would fail with an
            obscure TypeError deep inside `map`.
        """
        if self.maximum_objects_per_image is None:
            msg = "Call `load_full_dataset_with_objects` before `make_splits`: the per-image padding length is unset."
            raise RuntimeError(msg)

        majority_class = np.array(full_dataset["majority_class"])
        full_dataset = full_dataset.remove_columns("majority_class")

        # 70% train; the remaining 30% is split into 15% validation / 15% test.
        all_indices = np.arange(len(full_dataset))
        training_indices, temporary_indices = train_test_split(
            all_indices, test_size=0.30, stratify=majority_class, random_state=self.random_seed
        )

        validation_proportion_relative = 0.15 / 0.30
        validation_indices, test_indices = train_test_split(
            temporary_indices,
            test_size=1 - validation_proportion_relative,
            stratify=majority_class[temporary_indices],
            random_state=self.random_seed,
        )

        splits = {
            "train": full_dataset.select(training_indices),
            "validation": full_dataset.select(validation_indices),
            "test": full_dataset.select(test_indices),
        }

        resize_and_map = partial(
            resize_and_map_fn,
            do_pad=self.do_pad,
            image_size_wh=self.image_size,
            max_size_hw=self.max_size_hw,
            pad_size_hw=self.pad_size_hw,
            padding_row=self.padding_row,
            target_len=self.maximum_objects_per_image,
        )

        # Apply preprocessing per split (batched & parallel).
        for split_name, split in splits.items():
            splits[split_name] = split.map(
                resize_and_map,
                desc=f"resize+map {split_name}",
                batched=True,
                batch_size=256,
                num_proc=self.map_num_proc,
                load_from_cache_file=True,
            )

        return DatasetDict(splits)


def main() -> None:
    """Prepare the KITTI dataset, train explanations on the frozen D-FINE model, and compute explanations."""
    # ##### Prepare the Dataset #######
    preparation_class = KittiPreparationScript(dataset_root_path="object_detection_kitti")
    dataset = preparation_class.load_full_dataset_with_objects()
    splits = preparation_class.make_splits(dataset)

    # Add the xpdeep index column (named "kitti_index"); it is declared explicitly
    # as IndexMetadata in the hand-built schema below, since the dataset `analyze`
    # method is not used here.
    add_xpdeep_index(splits)
    splits.set_format("numpy")  # convert to numpy the pil images

    # Save each split as ".parquet" file
    for split_name, split_data in splits.items():
        # HuggingFace: set chunk to be groups of 100mb by default, here we use 64 rows to be memory efficient.
        # For kitti, we set chunk_size to 64 to get ~50mb per group to avoid memory errors on XpViz.
        split_data.to_parquet(f"{split_name}.parquet", batch_size=64)

    # 2. Upload your Converted Data (S3-compatible storage; credentials from env vars).
    client = boto3.client(
        service_name="s3",
        endpoint_url=os.getenv("S3_DATASET_ENDPOINT_URL"),
        aws_access_key_id=os.getenv("S3_DATASET_ACCESS_KEY_ID"),
        aws_secret_access_key=os.getenv("S3_DATASET_SECRET_ACCESS_KEY"),
        config=Config(signature_version="s3v4"),
    )

    # NOTE: the local "validation.parquet" is stored remotely as "val.parquet".
    client.upload_file("train.parquet", os.getenv("S3_DATASET_BUCKET_NAME"), "frozen_model_kitti/train.parquet")
    client.upload_file("validation.parquet", os.getenv("S3_DATASET_BUCKET_NAME"), "frozen_model_kitti/val.parquet")
    client.upload_file("test.parquet", os.getenv("S3_DATASET_BUCKET_NAME"), "frozen_model_kitti/test.parquet")

    # 3. Find a schema.

    # Set a custom image preprocessor, different from the one provided by the AutoAnalyzer.
    class ScaleKitti(TorchPreprocessor):
        """Kitti preprocessor: scale pixel values from [0, 255] down to [0, 1]."""

        def transform(self, inputs: torch.Tensor) -> torch.Tensor:  # noqa: PLR6301
            """Scale raw pixel values into [0, 1]."""
            return inputs / 255.0

        def inverse_transform(self, output: torch.Tensor) -> torch.Tensor:  # noqa: PLR6301
            """Map scaled values back to the original [0, 255] pixel range."""
            return output * 255.0

    # Input feature: image tensor with input_size (320, 1024, 3), i.e. (H, W, C).
    image = ExplainableFeature(
        name="image", feature_type=ImageFeature(), preprocessor=ScaleKitti(input_size=(320, 1024, 3)), is_target=False
    )

    # BBOX not supported in AutoAnalyzer, needs to manually define the feature
    target = ExplainableFeature(
        name="objects",
        feature_type=BoundingBoxesFeature(
            categories=list(preparation_class.classes.keys()),
        ),
        preprocessor=BoundingBoxesPreprocessor(preprocessed_size=None),
        is_target=True,
    )

    index = IndexMetadata(name="kitti_index")

    # Add the index column in the AnalyzedSchema as we don't use the dataset `analyze` method
    analyzed_schema = AnalyzedSchema(image, target, index)
    analyzed_train_dataset = AnalyzedParquetDataset(
        analyzed_schema=analyzed_schema,
        name="kitti_train_set",
        path=f"s3://{os.getenv('S3_DATASET_BUCKET_NAME')}/frozen_model_kitti/train.parquet",
        storage_options=STORAGE_OPTIONS,
    )

    print(analyzed_schema)

    # 4. Fit the schema

    fit_train_dataset = analyzed_train_dataset.fit()

    # Test and validation sets reuse a deepcopy of the fitted train schema so
    # that all splits share identical preprocessing statistics.
    fit_test_dataset = FittedParquetDataset(
        name="kitti_test_set",
        path=f"s3://{os.getenv('S3_DATASET_BUCKET_NAME')}/frozen_model_kitti/test.parquet",
        storage_options=STORAGE_OPTIONS,
        fitted_schema=deepcopy(fit_train_dataset.fitted_schema),
    )

    fit_val_dataset = FittedParquetDataset(
        name="kitti_validation_set",
        path=f"s3://{os.getenv('S3_DATASET_BUCKET_NAME')}/frozen_model_kitti/val.parquet",
        storage_options=STORAGE_OPTIONS,
        fitted_schema=deepcopy(fit_train_dataset.fitted_schema),
    )

    # 1. Load the pre-trained D-FINE feature extractor and task learner.
    pretrained_model_path = "Xpdeep/dfine-small-kitti"  # Checkpoint path on huggingface hub
    feature_extractor = ObjectDetectionFeatureExtractor(pretrained_model_path=pretrained_model_path)
    feature_extractor.load_pretrained_weights()  # Load pretrained weight

    # max_detections must be set to the maximum number of ground-truth bounding
    # boxes per image for the metrics to be computed correctly.
    task_learner = ObjectDetectionTaskLearner(pretrained_model_path=pretrained_model_path, max_detections=22)
    task_learner.load_pretrained_weights()

    # 2. Explainable Model Specifications
    model_specifications = ModelDecisionGraphParameters(
        graph_depth=3,
        population_pruning_threshold=0.1,
        feature_extraction_output_type=FeatureExtractionOutputType.DFINE_MATRIX,
        frozen_model=True,  # keep the original model's weights untouched
    )

    # 3. Create the Explainable Model
    xpdeep_model = XpdeepModel(
        example_dataset=fit_train_dataset,
        feature_extraction=feature_extractor,
        task_learner=task_learner,
        backbone=None,
        decision_graph_parameters=model_specifications,
    )

    # ##### Train #######

    # FrozenModelTrainer carries the default values suited to training
    # explanations on a frozen (pre-trained) model.
    trainer = FrozenModelTrainer(start_epoch=0, max_epochs=13)
    trained_model = trainer.train(
        model=xpdeep_model,
        train_set=fit_train_dataset,
        validation_set=fit_val_dataset,
        batch_size=32,
    )

    # ##### Explain #######

    # 1. Build the Explainer
    statistics = DictStats()
    quality_metrics = [Sensitivity(), Infidelity()]

    # Detection metrics (mAP variants); boxes are in "cxcywh" (YOLO) format.
    metrics = DictMetrics(
        map50=TorchGlobalMetric(metric=partial(Map50, box_format="cxcywh"), on_raw_data=False),
        leaf_metric_map50=TorchLeafMetric(metric=partial(Map50, box_format="cxcywh"), on_raw_data=False),
        map50_95=TorchGlobalMetric(metric=partial(Map50To95, box_format="cxcywh"), on_raw_data=False),
        leaf_metric_map50_95=TorchLeafMetric(metric=partial(Map50To95, box_format="cxcywh"), on_raw_data=False),
        map_per_class=TorchGlobalMetric(metric=partial(MapPerClass, box_format="cxcywh"), on_raw_data=False),
        leaf_metric_map_per_class=TorchLeafMetric(metric=partial(MapPerClass, box_format="cxcywh"), on_raw_data=False),
    )
    explainer = Explainer(
        description_representativeness=10, quality_metrics=quality_metrics, metrics=metrics, statistics=statistics
    )

    # 2. Model Functioning Explanations
    model_explanations = explainer.global_explain(
        trained_model,
        train_set=fit_train_dataset,
        test_set=fit_test_dataset,
        validation_set=fit_val_dataset,
    )
    print(model_explanations.visualisation_link)

    # 3. Inference and their Causal Explanations
    # No criterion exist to filter by images, we should filter by indexes only.
    my_filter = Filter("testing_filter", fit_test_dataset, min_index=10, max_index=20)
    causal_explanations = explainer.local_explain(trained_model, fit_test_dataset, my_filter)

    print(causal_explanations.visualisation_link)


if __name__ == "__main__":
    # Authenticate against the Xpdeep API and scope all artifacts to one project.
    init(api_key=os.getenv("API_KEY"), api_url=os.getenv("API_URL"))
    set_project(Project.create_or_get(name="Kitti Tutorial"))

    try:
        main()
    finally:
        # Tutorial cleanup: always delete the project, even if `main` fails.
        get_project().delete()

Train the Explanations#

The train step is straightforward: we only need to specify the Trainer parameters. We will use the FrozenModelTrainer interface for convenience, as it carries the correct default values for the explanation training process.

Internally, Xpdeep uses a dedicated algorithm to compute and train explanations while keeping the original model's parameters and performance intact.

from xpdeep.trainer.trainer import FrozenModelTrainer

trainer = FrozenModelTrainer(start_epoch=0, max_epochs=13)
👀 Full file preview
from __future__ import annotations

import os
from copy import deepcopy
from functools import partial
from pathlib import Path
from typing import Any

import albumentations as alb
import boto3
import cv2
import numpy as np
import torch
from botocore.config import Config
from datasets import Dataset, DatasetDict, load_dataset
from PIL import Image as PILImage
from sklearn.model_selection import train_test_split
from xpdeep_metrics.object_detection import Map50, Map50To95, MapPerClass
from xpdeep_modules.object_detection.dfine_models import ObjectDetectionFeatureExtractor, ObjectDetectionTaskLearner

from xpdeep import Project, get_project, init, set_project
from xpdeep.dataset.feature import ExplainableFeature, IndexMetadata
from xpdeep.dataset.feature.feature_types import BoundingBoxesFeature, ImageFeature
from xpdeep.dataset.parquet_dataset import AnalyzedParquetDataset, FittedParquetDataset
from xpdeep.dataset.preprocessor.preprocessor import BoundingBoxesPreprocessor, TorchPreprocessor
from xpdeep.dataset.schema import AnalyzedSchema
from xpdeep.explain.explainer import Explainer
from xpdeep.explain.quality_metrics import Infidelity, Sensitivity
from xpdeep.explain.statistic import DictStats
from xpdeep.filtering.filter import Filter
from xpdeep.metric import DictMetrics, TorchGlobalMetric, TorchLeafMetric
from xpdeep.model.feature_extraction_output_type import FeatureExtractionOutputType
from xpdeep.model.model_parameters import ModelDecisionGraphParameters
from xpdeep.model.xpdeep_model import XpdeepModel
from xpdeep.trainer.trainer import FrozenModelTrainer

# fsspec/s3fs storage options used by the Xpdeep parquet datasets to read the
# splits from S3-compatible storage; credentials come from environment variables.
STORAGE_OPTIONS = {
    "key": os.getenv("S3_DATASET_ACCESS_KEY_ID"),
    "secret": os.getenv("S3_DATASET_SECRET_ACCESS_KEY"),
    "client_kwargs": {
        "endpoint_url": os.getenv("S3_DATASET_ENDPOINT_URL"),
    },
    "s3_additional_kwargs": {"addressing_style": "path"},
}

# Define env variable to ensure multiprocessing works.
# Pin BLAS and OpenCV threading so HF `map(num_proc=...)` worker processes
# do not oversubscribe the CPU.
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")
cv2.setNumThreads(0)


def add_xpdeep_index(dataset_dict: DatasetDict):
    """Add a unique 'kitti_index' column to every split, continuous across splits."""

    def _with_offset(batch: dict[str, Any], idx: list[int], offset: int = 0) -> dict[str, Any]:
        # Per-batch row indices, shifted so indices never repeat across splits.
        batch["kitti_index"] = np.array(idx) + offset
        return batch

    running_offset = 0
    for split_name in dataset_dict:
        dataset_dict[split_name] = dataset_dict[split_name].map(
            _with_offset, batched=True, with_indices=True, fn_kwargs={"offset": running_offset}
        )
        running_offset += dataset_dict[split_name].num_rows


def read_kitti_label_file(
    txt_path: Path, image_width: int, image_height: int, class_to_id: dict[str, int]
) -> list[list[float | Any]]:
    """Parse a KITTI label file into YOLO-normalized boxes.

    Parameters
    ----------
    txt_path : Path
        Path to the KITTI label text file to read.
    image_width : int
        Width of the corresponding image in pixels.
    image_height : int
        Height of the corresponding image in pixels.
    class_to_id : dict[str, int]
        Mapping from KITTI class names (e.g., ``"Car"``) to integer class IDs.
        Any class name not present in this mapping is ignored.

    Returns
    -------
    list[list[float | Any]]
        A list of detections; each detection is
        ``[cx, cy, w, h, score, class_id]`` where:
        ``cx, cy, w, h`` are YOLO-normalized to ``[0, 1]`` relative to
        ``(image_width, image_height)``, ``score`` is ``1.0`` for parsed boxes,
        and ``class_id`` is the integer from ``class_to_id``.

    Notes
    -----
    - Boxes are clipped to the image frame; invalid (degenerate) boxes are skipped.
    - Lines whose class is not in ``class_to_id`` (e.g., ``DontCare``) are ignored.
    - Blank or truncated lines (fewer than 8 fields) are skipped instead of raising.
    """
    boxes: list[list[float | Any]] = []
    if not txt_path.exists():
        return boxes

    for line in txt_path.read_text(encoding="utf8").strip().splitlines():
        parts = line.split()
        # Skip blank or malformed lines: a valid KITTI label line has at least
        # 8 fields (type, truncated, occluded, alpha, left, top, right, bottom, ...).
        if len(parts) < 8:
            continue

        cls_name = parts[0]
        if cls_name not in class_to_id:
            continue

        # Clip the 2D box to the image frame and drop degenerate boxes.
        xmin, ymin, xmax, ymax = map(float, parts[4:8])
        xmin = max(0.0, xmin)
        ymin = max(0.0, ymin)
        xmax = min(image_width - 1.0, xmax)
        ymax = min(image_height - 1.0, ymax)
        if xmax <= xmin or ymax <= ymin:
            continue

        # Convert to YOLO format, normalized to [0, 1]. Strictly positive w/h
        # is guaranteed by the degeneracy check above.
        cx = ((xmin + xmax) * 0.5) / image_width
        cy = ((ymin + ymax) * 0.5) / image_height
        w = (xmax - xmin) / image_width
        h = (ymax - ymin) / image_height
        boxes.append([cx, cy, w, h, 1.0, int(class_to_id[cls_name])])
    return boxes


def compute_majority_class(labels: np.ndarray, areas: np.ndarray) -> int:
    """Return the dominant class id for an image, breaking count ties by total area.

    Parameters
    ----------
    labels : np.ndarray
        Class ids for all objects in an image.
    areas : np.ndarray
        Per-object area proxy (e.g., w*h), used only to break frequency ties.

    Returns
    -------
    int
        The most frequent class id; on a tie, the tied class with the largest
        summed area wins. Raises on empty `labels` (call sites guard this).
    """
    class_counts = np.bincount(labels)
    candidates = np.flatnonzero(class_counts == class_counts.max())

    # A unique winner by frequency alone: nothing more to do.
    if len(candidates) == 1:
        return int(candidates[0])

    # Several classes share the top count; pick the one covering the most area.
    total_area = np.bincount(labels, weights=areas, minlength=len(class_counts))
    winner = candidates[np.argmax(total_area[candidates])]
    return int(winner)


def attach_objects_to_sample(sample: dict, label_dir: Path, class_to_id: dict[str, int]) -> dict[str, Any]:
    """Map function for HF `Dataset.map`.

    Reads the KITTI label file matching this sample's image, then returns the
    YOLO boxes together with per-image majority-class metadata.
    """
    width, height = sample["image"].size
    label_path = label_dir / f"{Path(sample['image'].filename).stem}.txt"
    objects = read_kitti_label_file(label_path, width, height, class_to_id)

    # No labelled objects: sentinel majority class of -1.
    if not objects:
        return {"objects": objects, "majority_class": -1, "n_objects": 0}

    as_array = np.asarray(objects, dtype=np.float32)
    class_ids = as_array[:, -1].astype(np.int64)
    majority = compute_majority_class(class_ids, as_array[:, 2] * as_array[:, 3])
    return {"objects": objects, "majority_class": majority, "n_objects": int(as_array.shape[0])}


def resize_and_map_fn(  # noqa: PLR0913
    batch: dict,
    *,
    do_pad: bool,
    image_size_wh: tuple[int, int],
    max_size_hw: tuple[int, int] | None,
    pad_size_hw: tuple[int, int] | None,
    padding_row: list[float],
    target_len: int,
) -> dict:
    """Picklable map function for multiprocessing: resize images and remap YOLO boxes.

    Parameters
    ----------
    batch : dict
        HF batch with "image" (PIL images) and "objects" (per-image lists of
        [cx, cy, w, h, score, class_id] rows in YOLO-normalized coordinates).
    do_pad : bool
        If True, letterbox-resize to fit inside `max_size_hw` then pad to
        `pad_size_hw` with top-left placement; otherwise resize directly to
        `image_size_wh`.
    image_size_wh : tuple[int, int]
        Target (W, H) for the direct-resize path.
    max_size_hw : tuple[int, int] | None
        (MAX_H, MAX_W) bounding size for the letterbox path; required when do_pad=True.
    pad_size_hw : tuple[int, int] | None
        (PAD_H, PAD_W) final canvas for the letterbox path; defaults to `max_size_hw`.
    padding_row : list[float]
        Row used to pad each image's object list up to `target_len`.
    target_len : int
        Fixed number of object rows per image after padding/truncation.

    Returns
    -------
    dict
        {"image": resized PIL images, "objects": fixed-length object rows}.

    Raises
    ------
    ValueError
        If do_pad=True but `max_size_hw` is None (raised up front, before any work).
    """
    out_w, out_h = image_size_wh  # (W, H)
    # Build bbox params *inside* each process to avoid cross-proc sharing
    albu_bbox_params = alb.BboxParams(format="yolo", label_fields=["labels", "scores"])

    # Fail fast on an invalid letterbox config instead of raising per image.
    if do_pad and max_size_hw is None:
        msg = "When do_pad=True, set max_size_hw=(MAX_H, MAX_W)."
        raise ValueError(msg)

    # The direct-resize pipeline is loop-invariant (it does not depend on the
    # input image size), so build it once. The letterbox pipeline depends on
    # each image's h/w and must stay inside the loop.
    direct_resize = None
    if not do_pad:
        direct_resize = alb.Compose(
            [alb.Resize(height=out_h, width=out_w, interpolation=cv2.INTER_LINEAR)],
            bbox_params=albu_bbox_params,
        )

    processed_images: list[Any] = []
    processed_objects: list[list[list[float]]] = []

    for pil_img, obj_list in zip(batch["image"], batch["objects"], strict=False):
        img_np = np.array(pil_img)  # HxWxC
        h, w = img_np.shape[:2]

        if obj_list:
            bboxes = [o[:4] for o in obj_list]
            scores = [float(o[4]) for o in obj_list]
            labels = [int(o[5]) for o in obj_list]
        else:
            bboxes, labels, scores = [], [], []

        if direct_resize is not None:
            transform = direct_resize
        else:
            max_h, max_w = max_size_hw
            pad_h, pad_w = pad_size_hw or max_size_hw

            # Letterbox: uniform scale so the image fits within (max_h, max_w).
            scale = min(max_h / float(h), max_w / float(w))
            new_h = round(h * scale)
            new_w = round(w * scale)

            transform = alb.Compose(
                [
                    alb.Resize(height=new_h, width=new_w, interpolation=cv2.INTER_LINEAR),
                    alb.PadIfNeeded(
                        min_height=pad_h,
                        min_width=pad_w,
                        position="top_left",
                        border_mode=cv2.BORDER_CONSTANT,
                        value=0,
                    ),
                ],
                bbox_params=albu_bbox_params,
            )

        out = transform(image=img_np, bboxes=bboxes, labels=labels, scores=scores)
        img_resized = PILImage.fromarray(out["image"])
        bboxes_out, labels_out, scores_out = out["bboxes"], out["labels"], out["scores"]

        rows = [[*bb, sc, lb] for bb, lb, sc in zip(bboxes_out, labels_out, scores_out, strict=False)]

        # Pad/truncate to fixed length. Use an independent copy of the padding
        # row per slot so a later in-place edit of one row cannot alias into
        # the others (the original appended N references to the same list).
        if len(rows) < target_len:
            rows += [list(padding_row) for _ in range(target_len - len(rows))]
        else:
            rows = rows[:target_len]

        processed_images.append(img_resized)
        processed_objects.append(rows)

    return {"image": processed_images, "objects": processed_objects}


class KittiPreparationScript:
    """
    KITTI preparation script (object detection).

    Steps:
      1) Load images via HF `imagefolder` and parse KITTI label files.
      2) Create stratified train/val/test splits by per-image majority class.
      3) Resize images using either:
           - direct resize to (W, H) with no padding (do_pad=False), or
           - letterbox (rectangular general case; square if H==W) with top-left placement (do_pad=True).
         Bounding boxes (YOLO format [cx, cy, w, h] normalized) are transformed automatically.
      4) (Optionally) write raw and/or fully preprocessed parquet splits and fit schema.

    Call order: `load_full_dataset_with_objects` must run before `make_splits`,
    because it computes the per-image maximum object count used to pad boxes.
    """

    def __init__(  # noqa: PLR0917,PLR0913
        self,
        dataset_root_path: str,
        image_size: tuple[int, int] = (1024, 320),
        max_size_hw: tuple[int, int] | None = None,
        do_pad: bool = False,
        pad_size_hw: tuple[int, int] | None = None,
        dataset_name: str = "KITTI",
        random_seed: int = 1225,
    ) -> None:
        """
        Initialize KITTI preparation script.

        Parameters
        ----------
        dataset_root_path : str
            Root path containing `images/` and `labels/` folders.
        image_size : tuple[int, int]
            Target (W, H) when do_pad=False (direct resize); default (1024, 320)
        max_size_hw : tuple[int, int] | None
            (MAX_H, MAX_W) for letterbox. If do_pad=True, this is required.
            Set MAX_H==MAX_W for square letterbox; default None.
        do_pad : bool
            If True, use letterbox with top-left padding; else anisotropic resize; default False.
        pad_size_hw : tuple[int, int] | None
            (PAD_H, PAD_W) final canvas for letterbox. Defaults to max_size_hw; default None
        dataset_name : str
            Name used when writing parquet splits; default "KITTI".
        random_seed : int
            Seed for deterministic splits; default 1225.
        """
        self.dataset_root_path = Path(dataset_root_path)
        self.image_size = image_size
        self.do_pad = do_pad
        self.max_size_hw = max_size_hw
        self.pad_size_hw = pad_size_hw
        self.dataset_name = dataset_name
        self.random_seed = random_seed

        # KITTI classes, we do not use DontCare class like other works
        self.classes = {
            "Car": 0,
            "Pedestrian": 1,
            "Van": 2,
            "Cyclist": 3,
            "Truck": 4,
            "Misc": 5,
            "Tram": 6,
            "Person_sitting": 7,
        }

        # Pad bboxes in each image to have same number of bboxes per image in the dataset.
        # padding row format [cx, cy, w, h, confidence, class_id], where confidence is 0
        # and class_id = num_classes (here 8)
        self.padding_row = [-1, -1, 0, 0, 0, len(self.classes)]

        # Filled by `load_full_dataset_with_objects` after scanning the dataset.
        self.maximum_objects_per_image: int | None = None

        # Albumentations bbox settings (YOLO), tied to labels and scores
        self.albu_bbox_params = alb.BboxParams(format="yolo", label_fields=["labels", "scores"])

        # Parallelism for HF map()
        self.map_num_proc = min(4, os.cpu_count() or 1)

    def load_full_dataset_with_objects(self) -> Dataset:
        """Load images and labels from `images/` and `labels/` folder.

        Attach to each sample:
          - objects: list[[cx, cy, w, h, score, class_id]] in YOLO normalized coords
          - majority_class: per-image majority class id

        Also records `maximum_objects_per_image`, which `make_splits` requires.
        """
        dataset = load_dataset(
            "imagefolder", data_dir=str(self.dataset_root_path / "images"), split="train", drop_labels=True
        )

        label_dir = self.dataset_root_path / "labels"
        map_fn = partial(attach_objects_to_sample, label_dir=label_dir, class_to_id=self.classes)

        dataset = dataset.map(map_fn, desc="Attach objects + majority class")
        self.maximum_objects_per_image = int(np.max(dataset["n_objects"])) if len(dataset) else 0
        return dataset.remove_columns(["n_objects"])

    def make_splits(self, full_dataset: Dataset) -> DatasetDict:
        """Create stratified splits by majority class, resize images and transform boxes according to config.

        Raises
        ------
        RuntimeError
            If called before `load_full_dataset_with_objects`: the padding
            length would be None and worker processes would fail with an
            obscure TypeError deep inside `map`.
        """
        if self.maximum_objects_per_image is None:
            msg = "Call `load_full_dataset_with_objects` before `make_splits`: the per-image padding length is unset."
            raise RuntimeError(msg)

        majority_class = np.array(full_dataset["majority_class"])
        full_dataset = full_dataset.remove_columns("majority_class")

        # 70% train; the remaining 30% is split into 15% validation / 15% test.
        all_indices = np.arange(len(full_dataset))
        training_indices, temporary_indices = train_test_split(
            all_indices, test_size=0.30, stratify=majority_class, random_state=self.random_seed
        )

        validation_proportion_relative = 0.15 / 0.30
        validation_indices, test_indices = train_test_split(
            temporary_indices,
            test_size=1 - validation_proportion_relative,
            stratify=majority_class[temporary_indices],
            random_state=self.random_seed,
        )

        splits = {
            "train": full_dataset.select(training_indices),
            "validation": full_dataset.select(validation_indices),
            "test": full_dataset.select(test_indices),
        }

        resize_and_map = partial(
            resize_and_map_fn,
            do_pad=self.do_pad,
            image_size_wh=self.image_size,
            max_size_hw=self.max_size_hw,
            pad_size_hw=self.pad_size_hw,
            padding_row=self.padding_row,
            target_len=self.maximum_objects_per_image,
        )

        # Apply preprocessing per split (batched & parallel).
        for split_name, split in splits.items():
            splits[split_name] = split.map(
                resize_and_map,
                desc=f"resize+map {split_name}",
                batched=True,
                batch_size=256,
                num_proc=self.map_num_proc,
                load_from_cache_file=True,
            )

        return DatasetDict(splits)


def _prepare_and_save_splits() -> KittiPreparationScript:
    """Build stratified Kitti splits and save each one locally as a ".parquet" file.

    Returns the preparation script instance (needed later for the class names).
    """
    preparation_class = KittiPreparationScript(dataset_root_path="object_detection_kitti")
    dataset = preparation_class.load_full_dataset_with_objects()
    splits = preparation_class.make_splits(dataset)

    # Add index xpdeep, "index_xp_deep" will be recognized as IndexMetadata by the analyze method.
    add_xpdeep_index(splits)
    splits.set_format("numpy")  # convert to numpy the pil images

    # Save each split as ".parquet" file.
    for split_name, split_data in splits.items():
        # HuggingFace: set chunk to be groups of 100mb by default, here we use 64 rows to be memory efficient.
        # For kitti, we set chunk_size to 64 to get ~50mb per group to avoid memory errors on XpViz.
        split_data.to_parquet(f"{split_name}.parquet", batch_size=64)

    return preparation_class


def _upload_splits_to_s3() -> None:
    """Upload the converted parquet splits to the S3 dataset bucket."""
    client = boto3.client(
        service_name="s3",
        endpoint_url=os.getenv("S3_DATASET_ENDPOINT_URL"),
        aws_access_key_id=os.getenv("S3_DATASET_ACCESS_KEY_ID"),
        aws_secret_access_key=os.getenv("S3_DATASET_SECRET_ACCESS_KEY"),
        config=Config(signature_version="s3v4"),
    )

    client.upload_file("train.parquet", os.getenv("S3_DATASET_BUCKET_NAME"), "frozen_model_kitti/train.parquet")
    client.upload_file("validation.parquet", os.getenv("S3_DATASET_BUCKET_NAME"), "frozen_model_kitti/val.parquet")
    client.upload_file("test.parquet", os.getenv("S3_DATASET_BUCKET_NAME"), "frozen_model_kitti/test.parquet")


def _build_fitted_datasets(
    preparation_class: KittiPreparationScript,
) -> tuple[FittedParquetDataset, FittedParquetDataset, FittedParquetDataset]:
    """Define the explainable schema, fit it on the train split and reuse the fit for validation/test.

    Returns the (train, validation, test) fitted datasets.
    """

    # Set a custom image preprocessor, different from the one provided by the AutoAnalyzer.
    class ScaleKitti(TorchPreprocessor):
        """Kitti preprocessor, given an image in range [0, 256], scale the pixel values to [0 ,1]."""

        def transform(self, inputs: torch.Tensor) -> torch.Tensor:  # noqa: PLR6301
            """Transform."""
            return inputs / 255.0

        def inverse_transform(self, output: torch.Tensor) -> torch.Tensor:  # noqa: PLR6301
            """Apply inverse transform."""
            return output * 255.0

    image = ExplainableFeature(
        name="image", feature_type=ImageFeature(), preprocessor=ScaleKitti(input_size=(320, 1024, 3)), is_target=False
    )

    # BBOX not supported in AutoAnalyzer, needs to manually define the feature.
    target = ExplainableFeature(
        name="objects",
        feature_type=BoundingBoxesFeature(
            categories=list(preparation_class.classes.keys()),
        ),
        preprocessor=BoundingBoxesPreprocessor(preprocessed_size=None),
        is_target=True,
    )

    index = IndexMetadata(name="kitti_index")

    # Add the index column in the AnalyzedSchema as we don't use the dataset `analyze` method.
    analyzed_schema = AnalyzedSchema(image, target, index)
    analyzed_train_dataset = AnalyzedParquetDataset(
        analyzed_schema=analyzed_schema,
        name="kitti_train_set",
        path=f"s3://{os.getenv('S3_DATASET_BUCKET_NAME')}/frozen_model_kitti/train.parquet",
        storage_options=STORAGE_OPTIONS,
    )

    print(analyzed_schema)

    # Fit the schema on the train split; validation/test reuse a copy of the fitted schema.
    fit_train_dataset = analyzed_train_dataset.fit()

    fit_val_dataset = FittedParquetDataset(
        name="kitti_validation_set",
        path=f"s3://{os.getenv('S3_DATASET_BUCKET_NAME')}/frozen_model_kitti/val.parquet",
        storage_options=STORAGE_OPTIONS,
        fitted_schema=deepcopy(fit_train_dataset.fitted_schema),
    )

    fit_test_dataset = FittedParquetDataset(
        name="kitti_test_set",
        path=f"s3://{os.getenv('S3_DATASET_BUCKET_NAME')}/frozen_model_kitti/test.parquet",
        storage_options=STORAGE_OPTIONS,
        fitted_schema=deepcopy(fit_train_dataset.fitted_schema),
    )

    return fit_train_dataset, fit_val_dataset, fit_test_dataset


def _build_explainable_model(fit_train_dataset: FittedParquetDataset) -> XpdeepModel:
    """Assemble the frozen-model XpdeepModel from the pre-trained D-FINE checkpoint."""
    pretrained_model_path = "Xpdeep/dfine-small-kitti"  # Checkpoint path on huggingface hub
    feature_extractor = ObjectDetectionFeatureExtractor(pretrained_model_path=pretrained_model_path)
    feature_extractor.load_pretrained_weights()  # Load pretrained weight

    # max_detections should cover the maximum number of ground-truth boxes per image for correct metrics.
    task_learner = ObjectDetectionTaskLearner(pretrained_model_path=pretrained_model_path, max_detections=22)
    task_learner.load_pretrained_weights()  # Load pretrained weight

    # Explainable Model Specifications.
    model_specifications = ModelDecisionGraphParameters(
        graph_depth=3,
        population_pruning_threshold=0.1,
        feature_extraction_output_type=FeatureExtractionOutputType.DFINE_MATRIX,
        frozen_model=True,
    )

    return XpdeepModel(
        example_dataset=fit_train_dataset,
        feature_extraction=feature_extractor,
        task_learner=task_learner,
        backbone=None,
        decision_graph_parameters=model_specifications,
    )


def _explain(
    trained_model: Any,
    fit_train_dataset: FittedParquetDataset,
    fit_val_dataset: FittedParquetDataset,
    fit_test_dataset: FittedParquetDataset,
) -> None:
    """Compute and print global (model functioning) and local (causal) explanations."""
    # Metrics and statistics to evaluate the explanations.
    statistics = DictStats()
    quality_metrics = [Sensitivity(), Infidelity()]

    metrics = DictMetrics(
        map50=TorchGlobalMetric(metric=partial(Map50, box_format="cxcywh"), on_raw_data=False),
        leaf_metric_map50=TorchLeafMetric(metric=partial(Map50, box_format="cxcywh"), on_raw_data=False),
        map50_95=TorchGlobalMetric(metric=partial(Map50To95, box_format="cxcywh"), on_raw_data=False),
        leaf_metric_map50_95=TorchLeafMetric(metric=partial(Map50To95, box_format="cxcywh"), on_raw_data=False),
        map_per_class=TorchGlobalMetric(metric=partial(MapPerClass, box_format="cxcywh"), on_raw_data=False),
        leaf_metric_map_per_class=TorchLeafMetric(metric=partial(MapPerClass, box_format="cxcywh"), on_raw_data=False),
    )
    explainer = Explainer(
        description_representativeness=10, quality_metrics=quality_metrics, metrics=metrics, statistics=statistics
    )

    # Model Functioning Explanations.
    model_explanations = explainer.global_explain(
        trained_model,
        train_set=fit_train_dataset,
        test_set=fit_test_dataset,
        validation_set=fit_val_dataset,
    )
    print(model_explanations.visualisation_link)

    # Inference and their Causal Explanations.
    # No criterion exist to filter by images, we should filter by indexes only.
    my_filter = Filter("testing_filter", fit_test_dataset, min_index=10, max_index=20)
    causal_explanations = explainer.local_explain(trained_model, fit_test_dataset, my_filter)

    print(causal_explanations.visualisation_link)


def main() -> None:
    """Process the dataset, train, and explain the model."""
    # ##### Prepare the Dataset #######
    preparation_class = _prepare_and_save_splits()
    _upload_splits_to_s3()

    # ##### Schema & Fitted Datasets #######
    fit_train_dataset, fit_val_dataset, fit_test_dataset = _build_fitted_datasets(preparation_class)

    # ##### Model #######
    xpdeep_model = _build_explainable_model(fit_train_dataset)

    # ##### Train #######
    trainer = FrozenModelTrainer(start_epoch=0, max_epochs=13)
    trained_model = trainer.train(
        model=xpdeep_model,
        train_set=fit_train_dataset,
        validation_set=fit_val_dataset,
        batch_size=32,
    )

    # ##### Explain #######
    _explain(trained_model, fit_train_dataset, fit_val_dataset, fit_test_dataset)


if __name__ == "__main__":
    # Authenticate against the Xpdeep API and attach to (or create) the tutorial project.
    init(api_key=os.getenv("API_KEY"), api_url=os.getenv("API_URL"))
    project = Project.create_or_get(name="Kitti Tutorial")
    set_project(project)

    try:
        main()
    finally:
        # Clean up the tutorial project even when the run fails midway.
        get_project().delete()

There are no further differences from the self-explainable model's training and explanation process!

You can check on XpViz that the initial model's performance remains the same after the explanation training. For Kitti, we use mAP metrics (Map50/Map50To95/MapPerClass).