Source code for nicetoolbox.evaluation.metrics.categorical

"""
Categorical metrics for classification tasks, including accuracy, precision,
recall, and F1 score.
"""

from collections import defaultdict
from inspect import FrameInfo
from typing import Dict, List, Tuple

import torch

from ..data.discovery import ChunkWorkItem
from .base_metric import Metric, MetricHandler, MetricReturnType
from .results_schema import AggregatedResult


[docs]class CategoricalMetric(MetricHandler): """ Handler for all categorical metrics. """ @property def name(self) -> str: return "categorical_metrics" def _create_metrics(self) -> Dict[str, Metric]: """Instantiates the individual metric objects this handler manages.""" metric_map = { "accuracy": Accuracy, "precision": Precision, "recall": Recall, "f1_score": F1Score, } return {name: metric_map[name]() for name in self.cfg.metric_names if name in metric_map}
class _BinaryClassificationMetric(Metric): """Base class for binary classification metrics like precision, recall, accuracy.""" def __init__(self): self.reset() def reset(self) -> None: """Reset the metric's internal state: counts of TP, TN, FP, FN.""" # Key: (component, algorithm, person, camera) self.counts: Dict[Tuple[str, str, str, str], Dict[str, int]] = defaultdict(lambda: defaultdict(int)) def update( self, preds: torch.Tensor, gts: torch.Tensor, meta_chunk: ChunkWorkItem, meta_frames: List[FrameInfo], ) -> None: """ Update metric state with a batch of data and its corresponding metadata, accumulating counts of true positives, true negatives, false positives, and false negatives. Assumes binary classification with predictions and ground truths as boolean tensors. Args: preds (torch.Tensor): Predictions tensor (binary). gts (torch.Tensor): Ground truth tensor (binary). meta_chunk (ChunkWorkItem): Metadata for the chunk. meta_frames (List[FrameInfo]): List of frame metadata. """ if gts is None: return preds_bool, gts_bool = preds.bool(), gts.bool() assert preds_bool.shape == gts_bool.shape, "Preds and gts shape does not match." context_keys = [ (meta_chunk.component, meta_chunk.algorithm, frame.person, frame.camera) for frame in meta_frames ] unique_keys = set(context_keys) for key in unique_keys: select_frames = [k == key for k in context_keys] if not select_frames: continue preds = preds_bool[select_frames] gts = gts_bool[select_frames] assert preds.shape == gts.shape, "Preds and gts must have the same shape." self.counts[key]["tp"] += torch.sum(preds & gts).item() self.counts[key]["tn"] += torch.sum(~preds & ~gts).item() self.counts[key]["fp"] += torch.sum(preds & ~gts).item() self.counts[key]["fn"] += torch.sum(~preds & gts).item() def compute(self) -> Dict[Tuple[str, str, str], MetricReturnType]: raise NotImplementedError def get_axis3(self): raise NotImplementedError # Not needed currently for summary metrics
[docs]class Accuracy(_BinaryClassificationMetric): """Calculate the accuracy for binary classification. Returns: Dict[Tuple[str, str, str, str, str], MetricReturnType]: Dictionary with key (component, algorithm, person, camera, "accuracy") and the corresponding aggregated accuracy value. """
[docs] def compute(self) -> Dict[Tuple[str, str, str, str, str], MetricReturnType]: results = {} for (component, algorithm, person, camera), counts in self.counts.items(): denominator = counts["tp"] + counts["tn"] + counts["fp"] + counts["fn"] value = (counts["tp"] + counts["tn"]) / denominator if denominator > 0 else 0.0 key = (component, algorithm, person, camera, "accuracy") results[key] = AggregatedResult( value=value, metric_type="categorical_metrics", metric_name="accuracy", component=component, algorithm=algorithm, person=person, camera=camera, ) return results
[docs]class Precision(_BinaryClassificationMetric): """Calculate the precision for binary classification. Returns: Dict[Tuple[str, str, str, str, str], MetricReturnType]: Dictionary with key (component, algorithm, person, camera, "precision") and the corresponding aggregated precision value. """
[docs] def compute(self) -> Dict[Tuple[str, str, str, str, str], MetricReturnType]: results = {} for (component, algorithm, person, camera), counts in self.counts.items(): denominator = counts["tp"] + counts["fp"] value = counts["tp"] / denominator if denominator > 0 else 0.0 key = (component, algorithm, person, camera, "precision") results[key] = AggregatedResult( value=value, metric_type="categorical_metrics", metric_name="precision", component=component, algorithm=algorithm, person=person, camera=camera, ) return results
[docs]class Recall(_BinaryClassificationMetric): """Calculate the recall for binary classification. Returns: Dict[Tuple[str, str, str, str, str], MetricReturnType]: Dictionary with key (component, algorithm, person, camera, "recall") and the corresponding aggregated recall value. """
[docs] def compute(self) -> Dict[Tuple[str, str, str, str, str], MetricReturnType]: results = {} for (component, algorithm, person, camera), counts in self.counts.items(): denominator = counts["tp"] + counts["fn"] value = counts["tp"] / denominator if denominator > 0 else 0.0 key = (component, algorithm, person, camera, "recall") results[key] = AggregatedResult( value=value, metric_type="categorical_metrics", metric_name="recall", component=component, algorithm=algorithm, person=person, camera=camera, ) return results
[docs]class F1Score(_BinaryClassificationMetric): """Calculate the F1 score for binary classification. Returns: Dict[Tuple[str, str, str, str, str], MetricReturnType]: Dictionary with key (component, algorithm, person, camera, "f1_score") and the corresponding aggregated F1 score value. """
[docs] def compute(self) -> Dict[Tuple[str, str, str, str, str], MetricReturnType]: results = {} for (component, algorithm, person, camera), counts in self.counts.items(): precision = counts["tp"] / (counts["tp"] + counts["fp"]) if (counts["tp"] + counts["fp"]) > 0 else 0.0 recall = counts["tp"] / (counts["tp"] + counts["fn"]) if (counts["tp"] + counts["fn"]) > 0 else 0.0 denominator = precision + recall f1_score = (2 * precision * recall / denominator) if denominator > 0 else 0.0 key = (component, algorithm, person, camera, "f1_score") results[key] = AggregatedResult( value=f1_score, metric_type="categorical_metrics", metric_name="f1_score", component=component, algorithm=algorithm, person=person, camera=camera, ) return results