# Source code for nicetoolbox.evaluation.metrics.categorical.confusion_matrix

import logging

import pandas as pd
from sklearn.metrics import confusion_matrix, precision_recall_fscore_support

from ....configs.schemas.evaluation_metrics_config import ConfusionMatrixConfig
from ...data.input_loader import align_arrays, get_meta_type, load_input
from ...data.plots import plot_confusion_matrix_grid
from ...data.summary import pair_arrays_to_df, resolve_group_levels, split_aligned_arrays
from ..base_metric import BaseMetric
from ..metric_result import FrameResult, MetricResult, PlotResult, SummaryResult


class ConfusionMatrixMetric(BaseMetric):
    """Binary confusion matrix with precision, recall, F1 per pre-compute group.

    Pools all pred/gt boolean values within each group, then runs sklearn's
    confusion_matrix once per group. This avoids Simpson's paradox from
    averaging per-pair F1s across groups with different support.

    Input arrays must have bool dtype (or castable to bool). Confidence floats
    should use a separate metric — routing floats here silently corrupts results.
    """

    metric_config: ConfusionMatrixConfig

    def compute(self) -> MetricResult:
        """Compute one confusion-matrix summary row per group.

        Loads the configured predictions and ground truth, aligns them, drops
        rows where either value is NaN (logging how many were dropped), groups
        the remaining rows by the resolved group levels, and applies
        ``_compute_cm_row`` to each group. When ``metric_config.visualize`` is
        set, figures are produced from the summary table as well.

        Returns:
            A ``MetricResult`` carrying the aligned prediction / ground-truth
            arrays (frames), the per-group summary table, and any figures.

        Raises:
            ValueError: If alignment yields no pred/GT pairs.
        """
        # load gt and labels and align them frame by frame
        preds = load_input(self.metric_config.predictions)
        gt = load_input(self.metric_config.ground_truth)
        pairs = align_arrays(preds, gt, self.metric_config.broadcast_single)
        if not pairs:
            raise ValueError(
                f"Failed to compute Confusion Matrix '{self.metric_name}': no aligned pred/GT pairs found."
            )

        # convert pred and gt into one dataframe
        df = pair_arrays_to_df(pairs)

        # drop nan
        n_before = len(df)
        df = df.dropna(subset=["pred", "gt"])
        n_dropped = n_before - len(df)
        if n_dropped:
            # n_dropped > 0 implies n_before > 0, so the percentage is well defined.
            logging.warning(
                "[%s] Dropped %d/%d rows with NaN pred or gt "
                "(%.1f%%). Missing predictions are excluded, not counted as wrong.",
                self.metric_name,
                n_dropped,
                n_before,
                100 * n_dropped / n_before,
            )

        # figure out group by levels (always iterate + what user selected)
        meta_type = get_meta_type([p for p, _ in pairs])
        group_levels = resolve_group_levels(df, meta_type, self.metric_config.compute_group_by)

        # generate summary
        summary = df.groupby(level=group_levels, observed=True).apply(self._compute_cm_row).reset_index()

        # optional visualization
        figures = {}
        if self.metric_config.visualize:
            figures = plot_confusion_matrix_grid(summary, group_levels, self.metric_name)

        pred_aligned, gt_aligned = split_aligned_arrays(*pairs)
        return MetricResult(
            self.metric_name,
            frames=FrameResult({"predictions": pred_aligned, "ground_truth": gt_aligned}),
            summary=SummaryResult({"summary": summary}),
            plots=PlotResult(figures),
        )

    @staticmethod
    def _compute_cm_row(group: pd.DataFrame) -> pd.Series:
        """Build the confusion-matrix summary row for a single group.

        Args:
            group: Rows of one group; must provide ``"gt"`` and ``"pred"``
                columns, both of which are cast to bool.

        Returns:
            A Series with the keys ``tp``, ``fp``, ``fn``, ``tn``, ``accuracy``,
            ``precision``, ``recall``, ``f1`` and ``support`` (row count).
            ``accuracy`` is NaN when the group has no rows.
        """
        y_true = group["gt"].to_numpy(dtype=bool)
        y_pred = group["pred"].to_numpy(dtype=bool)

        # Fixing the label order guarantees a 2x2 matrix even when one class is
        # absent from the group, so the four-way unpacking below always works.
        cm = confusion_matrix(y_true, y_pred, labels=[False, True])
        tn, fp, fn, tp = cm.ravel()

        # zero_division=0 reports 0 instead of warning when a denominator is empty.
        precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average="binary", zero_division=0)

        support = int(len(y_true))
        accuracy = float(tp + tn) / support if support else float("nan")

        if precision == 0.0 and recall == 0.0:
            logging.warning(
                "Degenerate group (all-negative or all-positive): precision and recall are both 0. "
                "Support=%d, positives=%d",
                support,
                int(y_true.sum()),
            )

        return pd.Series(
            {
                "tp": int(tp),
                "fp": int(fp),
                "fn": int(fn),
                "tn": int(tn),
                "accuracy": accuracy,
                "precision": float(precision),
                "recall": float(recall),
                "f1": float(f1),
                "support": support,
            }
        )