Source code for nicetoolbox.evaluation.metrics.categorical.roc_auc

import logging

import numpy as np
import pandas as pd
from sklearn.metrics import roc_auc_score, roc_curve

from ....configs.schemas.evaluation_metrics_config import RocAucConfig
from ...data.input_loader import align_arrays, get_meta_type, load_input
from ...data.plots import plot_roc_curves
from ...data.summary import pair_arrays_to_df, resolve_group_levels, split_aligned_arrays
from ..base_metric import BaseMetric
from ..metric_result import FrameResult, MetricResult, PlotResult, SummaryResult


class RocAucMetric(BaseMetric):
    """ROC curve and AUC score per pre-compute group.

    Predictions must be float confidence scores; ground truth must be bool
    (float columns holding only 0/1 are accepted by ``_validate_roc_inputs``).

    AUC is computed once per group by pooling all frames — same rationale as
    ConfusionMatrixMetric (avoids averaging per-pair scores across groups with
    different support and class balance).
    """

    metric_config: RocAucConfig

    def compute(self) -> MetricResult:
        """Compute per-group ROC curves, AUC and a Youden-optimal threshold.

        Rows whose prediction or ground truth is NaN are dropped (with a
        warning) before scoring. When ``negate_scores`` is set, scores are
        negated before scoring and the reported threshold is negated back so
        that it is expressed on the original score scale.

        Returns:
            A ``MetricResult`` holding the aligned prediction / ground-truth
            frames, a summary table with one row per group (group keys,
            ``auc``, ``optimal_threshold``, ``support``) and, if
            ``visualize`` is enabled, a ``roc_curves`` figure.

        Raises:
            ValueError: If no prediction / ground-truth pairs could be aligned.
            TypeError: If the ``pred`` / ``gt`` columns fail
                ``_validate_roc_inputs``.
        """
        cfg = self.metric_config

        preds = load_input(cfg.predictions)
        gt = load_input(cfg.ground_truth)
        pairs = align_arrays(preds, gt, cfg.broadcast_single)
        if not pairs:
            raise ValueError(f"Failed to compute ROC AUC '{self.metric_name}': no aligned pred/GT pairs found.")

        df = pair_arrays_to_df(pairs)
        n_before = len(df)
        df = df.dropna(subset=["pred", "gt"])
        n_dropped = n_before - len(df)
        if n_dropped:
            # Lazy %-args, consistent with the degenerate-group warning below.
            # n_dropped > 0 implies n_before > 0, so the division is safe.
            logging.warning(
                "[%s] Dropped %d/%d rows with NaN pred or gt (%.1f%%). Missing predictions are excluded.",
                self.metric_name,
                n_dropped,
                n_before,
                100 * n_dropped / n_before,
            )

        _validate_roc_inputs(df, self.metric_name)

        meta_type = get_meta_type([p for p, _ in pairs])
        group_levels = resolve_group_levels(df, meta_type, cfg.compute_group_by)

        # user dims only (exclude always-iterate) — used for short curve labels
        always = meta_type.always_iterate()
        label_levels = [lvl for lvl in group_levels if lvl not in always] or group_levels
        # drop dimensions that are constant across all groups — they add
        # length without distinguishing curves
        label_levels = [lvl for lvl in label_levels if df.index.get_level_values(lvl).nunique() > 1] or label_levels

        # compute per-group: AUC scalar + curve arrays
        records: list[dict] = []
        curves: list[tuple[str, np.ndarray, np.ndarray, float]] = []
        for keys, group in df.groupby(level=group_levels, observed=True):
            if not isinstance(keys, tuple):
                keys = (keys,)
            key_map = dict(zip(group_levels, keys))

            y_true = group["gt"].to_numpy(dtype=bool)
            y_score = group["pred"].to_numpy(dtype=float)
            if cfg.negate_scores:
                y_score = -y_score

            n_total = len(y_true)
            n_pos = int(y_true.sum())
            if n_pos == 0 or n_pos == n_total:
                # ROC is undefined with a single class. The group is kept in
                # the outputs with NaN metrics and a diagonal placeholder curve.
                logging.warning(
                    "[%s] Skipping degenerate group %s — only one class present (positives=%d/%d).",
                    self.metric_name,
                    keys,
                    n_pos,
                    n_total,
                )
                auc = float("nan")
                optimal_threshold = float("nan")
                fpr = tpr = np.array([0.0, 1.0])
            else:
                fpr, tpr, thresholds = roc_curve(y_true, y_score)
                auc = float(roc_auc_score(y_true, y_score))
                # Youden's J statistic: threshold maximising TPR - FPR, i.e.
                # the curve point with the largest height above the chance
                # diagonal.
                # NOTE(review): if no point lies above the diagonal, argmax
                # returns index 0, whose threshold is scikit-learn's sentinel
                # (np.inf in recent releases) — confirm this is acceptable.
                optimal_threshold = float(thresholds[np.argmax(tpr - fpr)])
                if cfg.negate_scores:
                    # Map the threshold back onto the original score scale.
                    optimal_threshold = -optimal_threshold

            label = " | ".join(f"{lvl}={key_map[lvl]}" for lvl in label_levels)
            curves.append((label, fpr, tpr, auc))
            records.append({**key_map, "auc": auc, "optimal_threshold": optimal_threshold, "support": n_total})

        # Pin the columns so that a result with zero groups still carries the
        # expected schema; for non-empty results this matches the dict key order.
        summary = pd.DataFrame(records, columns=[*group_levels, "auc", "optimal_threshold", "support"])

        figures = {}
        if cfg.visualize:
            figures["roc_curves"] = plot_roc_curves(curves, self.metric_name)

        pred_aligned, gt_aligned = split_aligned_arrays(*pairs)
        return MetricResult(
            self.metric_name,
            frames=FrameResult({"predictions": pred_aligned, "ground_truth": gt_aligned}),
            summary=SummaryResult({"summary": summary}),
            plots=PlotResult(figures),
        )


def _validate_roc_inputs(df: pd.DataFrame, metric_name: str) -> None:
    """Check that ``df`` holds a binary ``gt`` column and a float ``pred`` column.

    Checks run in this order and the first failure raises:

    1. ``gt`` has a bool or float dtype.
    2. every distinct ``gt`` value is 0/1 (equivalently False/True).
    3. ``pred`` has a float dtype.

    Args:
        df: Frame with ``gt`` and ``pred`` columns.
        metric_name: Name used as the ``[...]`` prefix of error messages.

    Raises:
        TypeError: If any of the three checks fails.
    """
    dtypes = pd.api.types

    truth = df["gt"]
    truth_dtype_ok = dtypes.is_bool_dtype(truth) or dtypes.is_float_dtype(truth)
    if not truth_dtype_ok:
        raise TypeError(f"[{metric_name}] Column 'gt' must be bool or float, got {df['gt'].dtype}.")

    # True/False hash and compare equal to 1.0/0.0, so one set covers both dtypes.
    unique = truth.unique()
    is_binary = set(unique) <= {0.0, 1.0}
    if not is_binary:
        raise TypeError(f"[{metric_name}] Column 'gt' must contain only 0/1 or True/False, got {unique}.")

    scores = df["pred"]
    if dtypes.is_float_dtype(scores):
        return
    raise TypeError(
        f"[{metric_name}] Column 'pred' must be float (confidence scores), got {df['pred'].dtype}. "
        "Bool predictions belong in confusion_matrix, not roc_auc."
    )