Source code for nicetoolbox.evaluation.metrics.results_schema

"""
Result containers and schemas for evaluation metrics.
"""

from dataclasses import dataclass, field
from typing import Any, Dict, List, Union

import numpy as np
import torch

from ..data.discovery import ChunkWorkItem, FrameInfo


# (1) Container during metric processing (update calls)
[docs]@dataclass(frozen=True) class BatchResult: """Results container during batch evaluations. Further processing needed.""" results_tensor: torch.Tensor results_description: List[str] meta_chunk: ChunkWorkItem meta_frames: List[FrameInfo]
# (2) Container for summary metrics after processing batches (compute calls)
[docs]@dataclass(frozen=True) class AggregatedResult: """Result container for final summary metrics. No further processing needed.""" value: float metric_type: str metric_name: str component: str algorithm: str person: str camera: str
# Metrics can be frame based or aggregated summaries (jpe vs accuracy) MetricReturnType = Union[List[BatchResult], AggregatedResult] # (1.1) Container for unpacking BatchResults into single frame results -> FrameResult
[docs]@dataclass(frozen=True) class FrameResult: """Result container for final frame based metrics""" value: np.ndarray # (1) Used for generating file path metric_name: str metric_type: str session: str sequence: str component: str algorithm: str # (2) Used to re-grid to ndarray person: str camera: str frame: int # (3) Used for building data description axis3_description: List[str]
# (1.2) FrameResults are transformed to a ResultGrid, same structure as detector output
[docs]@dataclass class ResultGrid: metric_name: str # E.g. jpe, bone_length, ... values: np.ndarray # Shape: [#person x #camera x #frames x #data] description: Dict[str, Any] # {"axis0": ["person"], ..., "axis3": ["data"]}
# (1.3) One ResultFileGroup contains all ResultGrids (metrics) of a metric type
[docs]@dataclass class ResultFileGroup: # (1) Meta data for file path generation session: str sequence: str component: str algorithm: str metric_type: str # E.g. categorical metrics # (2) Entries of a single NPZ file grids: List[ResultGrid] = field(default_factory=list) # E.g. acc, precision, ...