Skip to content

multiclass_confusion_matrix

MulticlassConfusionMatrix that supports logits.

MulticlassConfusionMatrix(num_classes: int, *, ignore_index: int | None = None, normalize: Literal['none', 'true', 'pred', 'all'] | None = None, validate_args: bool = True, **kwargs: object) #

MulticlassConfusionMatrix with one-hot target support.

Source code in src/xpdeep/metrics/zoo/multiclass_confusion_matrix.py
def __init__(
    self,
    num_classes: int,
    *,
    ignore_index: int | None = None,
    normalize: Literal["none", "true", "pred", "all"] | None = None,
    validate_args: bool = True,
    **kwargs: object,
):
    """Build the confusion matrix by delegating to the parent metric.

    Args:
        num_classes: Number of classes, forwarded unchanged to the parent.
        ignore_index: Forwarded unchanged to the parent.
        normalize: Normalization mode forwarded to the parent
            (one of "none", "true", "pred", "all", or None).
        validate_args: Forwarded unchanged to the parent.
        **kwargs: Any extra keyword arguments, forwarded to the parent.

    The parent is always called with ``average="micro"``. Supplying
    ``average`` through ``kwargs`` therefore raises ``TypeError``
    (duplicate keyword argument).
    """
    forwarded = {
        "num_classes": num_classes,
        "ignore_index": ignore_index,
        "normalize": normalize,
        "validate_args": validate_args,
    }
    # NOTE(review): "micro" is hard-wired here; presumably the parent class
    # accepts an ``average`` keyword — confirm against the base class.
    super().__init__(average="micro", **forwarded, **kwargs)

__module__ = 'xpdeep_utils.metrics.classification.confusion_matrix' #

update(preds: Tensor, target: Tensor) -> None #

Update the metric.

Source code in src/xpdeep/metrics/zoo/multiclass_confusion_matrix.py
def update(self, preds: Tensor, target: Tensor) -> None:
    """Update the metric with a batch of predictions and targets.

    ``target`` is expected to be one-hot encoded along dimension 1
    (e.g. ``(N, C)``). It is converted to class indices with
    ``torch.argmax(target, dim=1)`` before being forwarded to the parent
    ``update``.

    A 1-D ``target`` has no dimension 1, so it cannot be one-hot along it.
    Such a target is treated as already holding class indices and is
    forwarded unchanged. Previously, ``torch.argmax`` raised ``IndexError``
    on this input.

    Args:
        preds: Predictions, forwarded unchanged to the parent ``update``.
        target: One-hot targets (class axis at dim 1), or 1-D class-index
            targets.
    """
    if target.ndim > 1:
        # Collapse the one-hot class axis (dim 1) into integer class labels.
        target = torch.argmax(target, dim=1)
    super().update(preds, target)