Skip to content

From a pytorch model to a deep explainable model#

For a quick introduction to the Xpdeep APIs, this section demonstrates, on the Air Quality dataset, how to adapt a standard deep model's PyTorch code to transition to designing an explainable deep model.

We will review the key steps involved in designing a deep model — from architecture specification and training to, for the Xpdeep model, generating explanations.

For each step in building a deep model, we provide:

  • Tabs labeled "SOTA and Xpdeep" for code that is identical for both the SOTA deep model and the Xpdeep explainable model.
  • Tabs labeled "Xpdeep" for code specific to the Xpdeep explainable model.

1. Project Setup#

Set Up API Key and URL#

from xpdeep import init

# Authenticate against the Xpdeep server; replace both placeholders with the
# credentials provided for your deployment.
init(api_key="MY_API_KEY", api_url="MY_API_URL")

Create a Project#

from xpdeep import set_project
from xpdeep.project import Project

# Create the project on first run (or fetch it by name on later runs) and make
# it the active project for all subsequent API calls.
set_project(Project.create_or_get(name="Air Quality Tutorial"))

2. Data preparation#

Read Raw Data#

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from datetime import UTC, datetime
import torch


# 1. Load and clean the raw data.
# The first 24 rows contain invalid values for some columns, so skip them.
data = pd.read_csv("air_quality.csv")[24:]

# Forward-fill missing pm2.5 readings with the last valid observation.
data.update({"pm2.5": data["pm2.5"].ffill()})


# Build a timezone-aware datetime from the separate year/month/day/hour columns.
def _row_to_datetime(row):
    return datetime(year=row["year"], month=row["month"], day=row["day"], hour=row["hour"], tzinfo=UTC)


data["time"] = data.apply(_row_to_datetime, axis=1)

# Drop the columns that are no longer needed, including the categorical wind
# direction column "cbwd".
data = data.drop(columns=["year", "month", "day", "hour", "No", "cbwd"])

# Index the frame by timestamp.
data = data.set_index("time")
data.head()

# Create the samples
lookback = 24  # number of past timesteps fed to the model
horizon = 5  # number of future timesteps to predict

# Calculate the number of samples based on dataset length, look_back, and horizon. Each sample overlaps the
# next by all but 1 timestamp (the window slides by exactly 1 step).
num_samples = len(data) - lookback - horizon + 1

data_input_numpy = data.to_numpy()  # Inputs contains the target channel as well
# (with its lookback we predict the horizon)
data_target_numpy = data[["pm2.5"]].to_numpy()

# Broadcast the data input and target: (num_samples, T, C) read-only views,
# no copy of the underlying arrays yet.
repeated_data_input = np.broadcast_to(data_input_numpy, (num_samples, *data_input_numpy.shape))
repeated_data_target = np.broadcast_to(data_target_numpy, (num_samples, *data_target_numpy.shape))

# Row i holds the absolute timestep indices [i, i + lookback + horizon) of window i.
tensor_slices = torch.arange(lookback + horizon).unsqueeze(0) + torch.arange(num_samples).unsqueeze(1)

# The first `lookback` columns index the inputs, the remaining `horizon` columns the targets.
input_slices = tensor_slices[:, :lookback]
target_slices = tensor_slices[:, lookback:]

time_dimension = 1

# Number of dimensions apart from the temporal one (for multivariate, it's 1)
number_of_data_dims = len(data.shape) - 1

# Expand the window indices with a trailing channel axis so their shape
# matches the data layout expected by torch.gather.
input_indices_to_gather = input_slices.unsqueeze(*list(range(-number_of_data_dims, 0))).repeat(
    1, 1, *repeated_data_input.shape[2:]
)
target_indices_to_gather = target_slices.unsqueeze(*list(range(-number_of_data_dims, 0))).repeat(
    1, 1, *repeated_data_target.shape[2:]
)

# Gather along the time axis; `.copy()` is required because np.broadcast_to
# returns non-contiguous, read-only views that torch cannot consume directly.
transformed_inputs = (torch.gather(torch.from_numpy(repeated_data_input.copy()), time_dimension, input_indices_to_gather).numpy().copy())
transformed_targets = (torch.gather(torch.from_numpy(repeated_data_target.copy()), time_dimension, target_indices_to_gather).numpy().copy())

# Store each window as a nested list so it fits in a single DataFrame cell.
data = pd.DataFrame({
    "sensor airquality": transformed_inputs.tolist(),  # Convert to a list of arrays for storage in DataFrame
    "target pm2.5": transformed_targets.tolist(),
})

Split Data#

from sklearn.model_selection import train_test_split
import numpy as np

# Hold out 20% of the samples, then split the holdout evenly into test and
# validation sets. Fixed seeds keep the splits reproducible.
train_data, holdout = train_test_split(data, test_size=0.2, random_state=42)
test_data, val_data = train_test_split(holdout, test_size=0.5, random_state=42)

Conversion to Parquet Format#

import pyarrow as pa
import pyarrow.parquet as pq
from xpdeep.dataset.upload import upload

# Convert each split to a pyarrow Table, materialising the pandas index as an
# explicit "airq_index" column.
train_table = pa.Table.from_pandas(train_data.reset_index(names="airq_index"), preserve_index=False)
val_table = pa.Table.from_pandas(val_data.reset_index(names="airq_index"), preserve_index=False)
test_table = pa.Table.from_pandas(test_data.reset_index(names="airq_index"), preserve_index=False)

# Persist every split as a ".parquet" file.
for table, filename in ((train_table, "train.parquet"), (val_table, "val.parquet"), (test_table, "test.parquet")):
    pq.write_table(table, filename)

Upload#

import boto3
from botocore.config import Config  # fixed: `Config` was used below but never imported

# S3 client pointing at the dataset storage endpoint. The S3_DATASET_*
# constants are the storage credentials configured for your deployment.
client = boto3.client(
    service_name="s3",
    endpoint_url=S3_DATASET_ENDPOINT_URL,
    aws_access_key_id=S3_DATASET_ACCESS_KEY_ID,
    aws_secret_access_key=S3_DATASET_SECRET_ACCESS_KEY,
    config=Config(signature_version="s3v4"),
)

# Upload each parquet split under the "air_quality/" prefix of the bucket.
client.upload_file("train.parquet", S3_DATASET_BUCKET_NAME, "air_quality/train.parquet")
client.upload_file("val.parquet", S3_DATASET_BUCKET_NAME, "air_quality/val.parquet")
client.upload_file("test.parquet", S3_DATASET_BUCKET_NAME, "air_quality/test.parquet")

Preprocess Data#

from sklearn.preprocessing import StandardScaler

# Fit the scalers on the training split only, using just the first timestep of
# each window (shape (n_windows, n_channels)) — presumably to avoid counting
# overlapping timesteps multiple times; confirm this is the intended fit set.
input_data_for_preprocessor = np.array(train_data["sensor airquality"].to_list())[:,0,:]
target_data_for_preprocessor = np.array(train_data["target pm2.5"].to_list())[:,0,:]

input_encoder = StandardScaler().fit(input_data_for_preprocessor)
target_encoder = StandardScaler().fit(target_data_for_preprocessor)

# Standardise the full dataset, keeping the (n, lookback, channels) layout.
preprocessed_input = input_encoder.transform(np.array(data["sensor airquality"].to_list()).reshape(-1,7)).reshape(-1, lookback, 7)
preprocessed_target = target_encoder.transform(np.array(data["target pm2.5"].to_list()).reshape(-1,1)).reshape(-1, horizon, 1)

# NOTE(review): the splits below are reshaped to (-1, 7, lookback) for the
# Conv1d channels-first layout, but a plain reshape from (n, lookback, 7)
# interleaves the time and channel axes — a transpose after reshaping to
# (-1, lookback, 7) looks intended; confirm against the model's expectations.
x_train = input_encoder.transform(np.array(train_data["sensor airquality"].to_list()).reshape(-1,7)).reshape(-1, 7, lookback)
y_train = target_encoder.transform(np.array(train_data["target pm2.5"].to_list()).reshape(-1,1)).reshape(-1, horizon, 1)

x_val = input_encoder.transform(np.array(val_data["sensor airquality"].to_list()).reshape(-1,7)).reshape(-1, 7, lookback)
y_val = target_encoder.transform(np.array(val_data["target pm2.5"].to_list()).reshape(-1,1)).reshape(-1, horizon, 1)

x_test = input_encoder.transform(np.array(test_data["sensor airquality"].to_list()).reshape(-1,7)).reshape(-1, 7, lookback)
y_test = target_encoder.transform(np.array(test_data["target pm2.5"].to_list()).reshape(-1,1)).reshape(-1, horizon, 1)
class Scaler(TorchPreprocessor):
    """Air quality preprocessor applying (x - mean) / scale as a torch module."""

    def __init__(self, input_size: tuple[int, ...], mean: torch.Tensor, scale: torch.Tensor):
        super().__init__(input_size=input_size)
        # Registered as buffers for torch.export: saved/loaded with
        # `state_dict` but not updated by `optimizer.step()`.
        self.register_buffer("mean", mean)
        self.register_buffer("scale", scale)

    def transform(self, inputs: torch.Tensor) -> torch.Tensor:
        """Standardise ``inputs`` using the stored mean and scale."""
        return (inputs - self.mean) / self.scale

    def inverse_transform(self, output: torch.Tensor) -> torch.Tensor:
        """Map standardised values back to the original scale."""
        return output * self.scale + self.mean


# Per-channel statistics computed with torch on the first timestep of every
# window; they feed the exportable `Scaler` preprocessors below.
input_tensor = torch.tensor(data["sensor airquality"].to_list())
mean_input = input_tensor[:, 0, :].mean(dim=0)
scale_input = input_tensor[:, 0, :].std(dim=0)

target_tensor = torch.tensor(data["target pm2.5"].to_list())
mean_target = target_tensor[:, 0, :].mean(dim=0)
scale_target = target_tensor[:, 0, :].std(dim=0)

# NOTE(review): torch `.std()` uses the unbiased estimator (ddof=1) while the
# sklearn StandardScaler fitted earlier uses ddof=0, and these statistics are
# computed over the whole dataset rather than the training split — confirm
# both choices are intended.
# Schema pairing each feature with its preprocessor; the target mirrors the
# "pm2.5" input channel. (FittedSchema, ExplainableFeature, the time-series
# types and IndexMetadata are assumed imported from xpdeep elsewhere.)
fitted_schema = FittedSchema(
    ExplainableFeature(
        name="sensor airquality",
        preprocessor=Scaler((24, 7), mean=mean_input, scale=scale_input),
        feature_type=MultivariateTimeSeries(
            asynchronous=True, channel_names=["pm2.5", "DEWP", "TEMP", "PRES", "Iws", "Is", "Ir"]
        ),
    ),
    ExplainableFeature(
        name="target pm2.5",
        preprocessor=Scaler((5, 1), mean=mean_target, scale=scale_target),
        is_target=True,
        feature_type=UnivariateTimeSeries(
            asynchronous=True,
            mirrored_channel=("sensor airquality", "pm2.5"),
        ),
    ),
IndexMetadata(name="airq_index"),
input_size=(1, 24, 7),
target_size=(1, 5, 1),
)

# fsspec-style credentials the datasets use to read the parquet files
# directly from S3.
storage_options = {
    "key": S3_DATASET_ACCESS_KEY_ID,
    "secret": S3_DATASET_SECRET_ACCESS_KEY,
    "client_kwargs": {
        "endpoint_url": S3_DATASET_ENDPOINT_URL,
    },
    "s3_additional_kwargs": {"addressing_style": "path"},
}

# The training set carries the fitted schema; validation and test reuse it so
# all three splits share identical preprocessing.
fit_train_dataset = FittedParquetDataset(
    name="air_quality_train_set",
    path=f"s3://{S3_DATASET_BUCKET_NAME}/air_quality/train.parquet",
    storage_options=storage_options,
    fitted_schema=fitted_schema,
)
print(fitted_schema)

fit_test_dataset = FittedParquetDataset(
    name="air_quality_test_set",
    path=f"s3://{S3_DATASET_BUCKET_NAME}/air_quality/test.parquet",
    storage_options=storage_options,
    fitted_schema=fit_train_dataset.fitted_schema,
)

fit_val_dataset = FittedParquetDataset(
    name="air_quality_validation_set",
    path=f"s3://{S3_DATASET_BUCKET_NAME}/air_quality/val.parquet",
    storage_options=storage_options,
    fitted_schema=fit_train_dataset.fitted_schema,
)

# Model input/target sizes without the leading batch dimension.
input_size = fit_train_dataset.fitted_schema.input_size[1:]
target_size = fit_train_dataset.fitted_schema.target_size[1:]

3. Model Construction#

Architecture Specification#

import torch
from torch.nn import Sequential

class SotaModel(Sequential):
    """Baseline convolutional forecaster (standard PyTorch, no Xpdeep).

    Two Conv1d/ReLU stages followed by a flatten and a lazily-sized linear
    head. Expects channels-first input of shape (batch, 7, timesteps) and
    returns predictions of shape (batch, horizon, 1).
    """

    def __init__(self, horizon: int = 5):
        # `horizon` was previously read from a module-level global; taking it
        # as a parameter (defaulting to the tutorial value) keeps existing
        # `SotaModel()` callers working while removing the hidden dependency.
        layers = [
            torch.nn.Conv1d(7, 16, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv1d(16, 32, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.Flatten(),
            torch.nn.LazyLinear(out_features=horizon),
        ]
        super().__init__(*layers)
        self.horizon = horizon

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Run the stack and reshape the flat output to (batch, horizon, 1)."""
        x = super().forward(inputs)
        return x.reshape(-1, self.horizon, 1)
from torch.nn import Sequential

class FeatureExtractor(Sequential):
    """Convolutional feature extractor for the Xpdeep model.

    Reshapes each sample to channels-first (batch, 7, lookback) and applies
    two Conv1d/ReLU stages, producing a (batch, 32, lookback) feature map.
    """

    def __init__(self, lookback: int = 24):
        # `lookback` was previously read from a module-level global; taking it
        # as a parameter (defaulting to the tutorial value) keeps existing
        # `FeatureExtractor()` callers working and removes the hidden dependency.
        layers = [
            torch.nn.Conv1d(7, 16, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv1d(16, 32, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(),
        ]
        super().__init__(*layers)
        self.lookback = lookback

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Reshape to channels-first and extract features."""
        # NOTE(review): a plain reshape from time-major (batch, lookback, 7)
        # interleaves the time and channel axes; if inputs are time-major, a
        # transpose may be intended — kept as-is to preserve behavior.
        x = inputs.reshape(-1, 7, self.lookback)
        return super().forward(x)

class TaskLearner(Sequential):
    """Forecasting head: flattens the features and predicts the horizon."""

    def __init__(self, horizon: int = 5):
        # `horizon` was previously a module-level global; parameterized with
        # the tutorial default so existing `TaskLearner()` callers still work.
        layers = [torch.nn.Flatten(), torch.nn.LazyLinear(out_features=horizon)]
        super().__init__(*layers)
        self.horizon = horizon

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Return predictions shaped (batch, horizon, 1)."""
        x = super().forward(inputs)
        return x.reshape(-1, self.horizon, 1)

# Instantiate the two halves of the explainable model.
feature_extractor = FeatureExtractor()
task_learner = TaskLearner()

Model Instantiation#

# Baseline model instance for the standard PyTorch path.
sota_model = SotaModel()
from xpdeep.model.model_builder import ModelDecisionGraphParameters
from xpdeep.model.xpdeep_model import XpdeepModel

# Model specifications and hyperparameters.
# NOTE(review): `FeatureExtractionOutputType` is used below but never imported
# in this snippet — confirm the import is present in the full script.
model_specifications = ModelDecisionGraphParameters(
    feature_extraction_output_type=FeatureExtractionOutputType.TEMPORAL_MATRIX,
)

# Xpdeep Model Architecture: assembled from the feature extractor and task
# learner defined above, with no backbone in between.
xpdeep_model = XpdeepModel.from_torch(
    example_dataset=fit_train_dataset,
    feature_extraction=feature_extractor,
    task_learner=task_learner,
    backbone=None,
    decision_graph_parameters=model_specifications,
)

4. Training#

Training Specification#

from torch import nn
import torch

# SOTA training hyperparameters: MSE objective, AdamW over the baseline model.
loss_fn = nn.MSELoss()
optimizer = torch.optim.AdamW(sota_model.parameters(), lr=1e-3)
batch_size = 128
epochs = 30
from xpdeep.trainer.callbacks import EarlyStopping, ModelCheckpoint, Scheduler
from functools import partial
from xpdeep.metric import DictMetrics, TorchGlobalMetric, TorchLeafMetric
from torch.optim.lr_scheduler import ReduceLROnPlateau
from xpdeep.trainer.trainer import Trainer
from xpdeep.model.zoo.cross_entropy_loss_from_proba import CrossEntropyLossFromProbabilities
import time
from torchmetrics import MeanSquaredError  # fixed: the original line was missing `import`

# Metrics to monitor the training, computed on raw (un-standardised) data.
metrics = DictMetrics(
    mse=TorchGlobalMetric(metric=partial(MeanSquaredError), on_raw_data=True),
    leaf_metric_flatten_mse=TorchLeafMetric(metric=partial(MeanSquaredError), on_raw_data=True),
)

callbacks = [
    EarlyStopping(monitoring_metric="mse", mode="minimize", patience=5),
    Scheduler(
        pre_scheduler=partial(ReduceLROnPlateau, patience=3, mode="min"),
        step_method="epoch",
        monitoring_metric="Total loss",
    ),
]

# The Xpdeep optimizer is a partial because the trainer supplies the model
# parameters itself. Bound to a dedicated name so it does not clobber the
# `optimizer` instance used by the SOTA training loop further down.
xpdeep_optimizer = partial(torch.optim.AdamW, lr=0.01, foreach=False, fused=False)

trainer = Trainer(
    loss=torch.nn.MSELoss(reduction="none"),
    optimizer=xpdeep_optimizer,
    callbacks=callbacks,
    start_epoch=0,
    max_epochs=30,
    metrics=metrics,
)

Model Training#

import torch
import time
from sklearn.metrics import mean_squared_error, root_mean_squared_error

# Fix the RNG seed so weight initialisation and training are reproducible.
torch.manual_seed(0)

def train(X_train, y_train, model, loss_fn, optimizer, device="cpu", batch_size=128):
    """Run one epoch of mini-batch training and return the mean batch loss.

    Args:
        X_train: input array, shape (n_samples, channels, timesteps).
        y_train: target array, first dimension n_samples.
        model: the torch module to optimise.
        loss_fn: loss callable, e.g. ``nn.MSELoss()``.
        optimizer: instantiated torch optimizer over ``model``'s parameters.
        device: device for the batches. The original snippet referenced an
            undefined global ``device``; defaulting to CPU fixes that while
            remaining overridable.
        batch_size: mini-batch size (was read from a module-level global;
            default matches the tutorial value of 128).

    Returns:
        Average loss over the full batches of this epoch. The trailing
        partial batch (n_samples % batch_size samples) is skipped, matching
        the original behaviour.
    """
    size = len(X_train)
    num_batches = size // batch_size
    if num_batches == 0:
        raise ValueError("X_train must contain at least one full batch")

    model.train()
    total_loss = 0.0

    for batch in range(num_batches):
        start = batch * batch_size
        stop = start + batch_size
        X_batch = torch.tensor(X_train[start:stop, :, :], dtype=torch.float32).to(device)
        y_batch = torch.tensor(y_train[start:stop, :], dtype=torch.float32).to(device)

        # Forward pass and prediction error.
        pred = model(X_batch)
        loss = loss_fn(pred, y_batch)

        # Backpropagation.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    return total_loss / num_batches


def eval_(X_test, y_test, model, loss_fn, device="cpu", encoder=None, n_horizon=None):
    """Evaluate ``model`` on a full split and report metrics in raw units.

    Args:
        X_test, y_test: input/target arrays for the split.
        model: trained torch module.
        loss_fn: loss callable applied in standardised space.
        device: device for the evaluation tensors. The original snippet
            referenced an undefined global ``device``; defaults to CPU.
        encoder: object with ``inverse_transform`` used to map targets and
            predictions back to raw units; defaults to the module-level
            ``target_encoder`` as in the original code.
        n_horizon: forecast length used to flatten predictions; defaults to
            the module-level ``horizon``.

    Returns:
        Tuple of (raw-scale predictions, standardised test loss, raw MSE,
        raw RMSE).
    """
    if encoder is None:
        encoder = target_encoder  # module-level fallback, as in the original
    if n_horizon is None:
        n_horizon = horizon  # module-level fallback, as in the original

    model.eval()
    with torch.no_grad():
        X_t = torch.tensor(X_test, dtype=torch.float32).to(device)
        y_t = torch.tensor(y_test, dtype=torch.float32).to(device)

        pred = model(X_t)
        test_loss = loss_fn(pred, y_t).item()

        # Undo the target standardisation before computing raw-scale metrics.
        y_raw = np.asarray(encoder.inverse_transform(y_t.reshape(-1, n_horizon).cpu()))
        pred_raw = np.asarray(encoder.inverse_transform(pred.reshape(-1, n_horizon).cpu()))

        # Identical to sklearn's mean_squared_error / root_mean_squared_error
        # with the default uniform averaging, without the sklearn call here.
        mse = float(np.mean((y_raw - pred_raw) ** 2))
        rmse = float(np.sqrt(mse))

        return pred_raw, test_loss, mse, rmse


start_time = time.time()

# SOTA training loop: one `train` pass and one validation evaluation per epoch.
# NOTE(review): `optimizer` is re-bound above to a functools.partial for the
# Xpdeep trainer, so as written `train` would receive the partial rather than
# the AdamW instance created earlier — confirm which binding is intended.
for t in range(epochs):

    print(f"\nEpoch {t+1}\n-------------------------------")


    training_loss = train(
        x_train, 
        y_train, 
        sota_model, 
        loss_fn, 
        optimizer
    )

    _, val_loss, _, _ = eval_(
        x_val, 
        y_val, 
        sota_model, 
        loss_fn
    )

    print(f"Training Loss: {training_loss}\nValidation Loss: {val_loss}")

# Final raw-scale metrics on all three splits.
_, _, mse_on_train  , rmse_on_train = eval_(x_train, y_train, sota_model, loss_fn)
_, _, mse_on_validation, rmse_on_validation = eval_(x_val, y_val, sota_model, loss_fn)
_, _, mse_on_test, rmse_on_test = eval_(x_test, y_test, sota_model, loss_fn)

print(f"\nTraining time : --- {time.time() - start_time:.2f} seconds --- \n")
print(f"\nMSEs: "
      f"\nMSE on train set       : {mse_on_train}"
      f"\nMSE on validation set  : {mse_on_validation}"
      f"\nMSE on test set        : {mse_on_test}"
      )
# Train the explainable model with the Xpdeep trainer configured above.
trained_model = trainer.train(
    model=xpdeep_model,
    train_set=fit_train_dataset,
    validation_set=fit_val_dataset,
    batch_size=128,
)

5. Explanation Generation#

from xpdeep.explain.explainer import Explainer
from xpdeep.explain.quality_metrics import Infidelity, Sensitivity
from xpdeep.explain.statistic import DictStats, HistogramStat

# Histogram statistics over raw-scale targets, predictions, and errors.
statistics = DictStats(
    distribution_target=HistogramStat(on="target", num_bins=20, num_items=1000, on_raw_data=True),
    distribution_prediction=HistogramStat(on="prediction", num_bins=20, num_items=1000, on_raw_data=True),
    distribution_error=HistogramStat(on="prediction_error", num_bins=20, num_items=1000, on_raw_data=True),
)

# Explanation quality is assessed with sensitivity and infidelity scores.
quality_metrics = [Sensitivity(), Infidelity()]

explainer = Explainer(
    description_representativeness=1000,
    quality_metrics=quality_metrics,
    metrics=metrics,
    statistics=statistics,
)

# Build global (model-level) explanations over all three splits and print the
# link to the interactive visualisation.
model_explanations = explainer.global_explain(
    trained_model,
    train_set=fit_train_dataset,
    test_set=fit_test_dataset,
    validation_set=fit_val_dataset,
)
print(model_explanations.visualisation_link)