Source code for nicetoolbox.evaluation.data.dataset

"""
Dataset class for evaluation, implemented as an IterableDataset.
"""

import logging
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
from torch.utils.data import IterableDataset

from .discovery import ChunkWorkItem
from .loaders import AnnotationLoader, PredictionLoader


class EvaluationDataset(IterableDataset):
    """Iterable dataset yielding per-frame prediction / ground-truth pairs.

    Each work item (chunk) is loaded once via the configured loaders; the
    frames listed in the chunk are then sliced out of the in-memory arrays,
    optionally reconciled, and yielded one at a time.
    """

    def __init__(
        self,
        work_items: List["ChunkWorkItem"],
        prediction_loader: "PredictionLoader",
        annotation_loader: Optional["AnnotationLoader"],
    ):
        """Initialize the EvaluationDataset with found work items, the prediction
        loader, and optionally the annotation loader.

        Args:
            work_items (List[ChunkWorkItem]): List of work items (chunks) to process.
            prediction_loader (PredictionLoader): Loader for prediction data.
            annotation_loader (Optional[AnnotationLoader]): Loader for annotation
                data, or None when no ground truth is loaded.
        """
        self.work_items = work_items
        self.pred_loader = prediction_loader
        self.annot_loader = annotation_loader
        # Total frame count across all chunks; reported by __len__.
        self.total = sum(len(chunk.frames) for chunk in self.work_items)
        if not self.work_items:
            logging.warning("EvaluationDataset initialized with zero work items.")

    def __len__(self):
        """Returns the total number of frames to be processed.

        Frames (or whole chunks) that fail to load or process are skipped
        during iteration, so fewer samples than this may actually be yielded.
        """
        return self.total

    @staticmethod
    def _apply_reconciliation(
        data: Optional[np.ndarray],
        rec_map: Optional[Dict[str, Tuple[int, ...]]],
    ) -> Optional[np.ndarray]:
        """Apply a reconciliation map to a per-frame data array.

        Slices the array according to the provided map. This handles cases
        where predictions and ground truth have different dimensions or need
        to be reconciled (e.g., different keypoint sets).

        Recognised keys:
            - ``"axis3"``: indices selected along the first axis of ``data``
              (axis 3 of the full loaded array, since ``__iter__`` slices off
              three leading indices before calling this).
            - ``"axis4"``: indices selected along the last axis of ``data`` when
              used alone, or the second axis when combined with ``"axis3"``
              via ``np.ix_``. NOTE(review): these coincide only for 2-D
              per-frame arrays; confirm that is always the case.

        Args:
            data (Optional[np.ndarray]): The data array to reconcile, or None.
            rec_map (Optional[Dict[str, Tuple[int, ...]]]): The reconciliation
                map. None or empty means "no reconciliation".

        Returns:
            Optional[np.ndarray]: The reconciled array. ``data`` is returned
            unchanged when it is None, when the map is empty, or when the map
            contains neither recognised key.
        """
        if data is None or not rec_map:
            return data

        axis3 = rec_map.get("axis3")
        axis4 = rec_map.get("axis4")
        if axis3 is not None and axis4 is not None:
            # np.ix_ builds an open mesh: the two index lists are applied as an
            # outer product (sub-matrix selection), not paired element-wise.
            return data[np.ix_(axis3, axis4)]
        if axis3 is not None:
            return data[axis3, ...]
        if axis4 is not None:
            return data[..., axis4]
        # No recognised axis key in the map: nothing to slice.
        return data

    def _worker_work_items(self) -> List["ChunkWorkItem"]:
        """Return the subset of work items this DataLoader worker should process.

        With ``num_workers > 0`` every worker holds a full replica of this
        IterableDataset; without sharding each worker would yield every frame
        and the evaluation would see duplicates. Chunks are distributed
        round-robin across workers. In the main process (no worker info),
        all work items are returned.
        """
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            return self.work_items
        return self.work_items[worker_info.id :: worker_info.num_workers]

    @staticmethod
    def _extract_frame(
        chunk: "ChunkWorkItem",
        frame_info: Any,
        raw_pred_array: np.ndarray,
        raw_gt_array: Optional[np.ndarray],
    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        """Slice one frame's prediction (and GT, if available) and reconcile it.

        Returns:
            Tuple[np.ndarray, Optional[np.ndarray]]: ``(pred_data, gt_data)``;
            ``gt_data`` is None when no GT array or GT indices are available.
        """
        p_idx, c_idx, f_idx = frame_info.pred_slicing_indices
        raw_pred_data = raw_pred_array[p_idx, c_idx, f_idx, ...]

        raw_gt_data = None
        if raw_gt_array is not None and frame_info.annot_slicing_indices:
            p_idx_gt, c_idx_gt, f_idx_gt = frame_info.annot_slicing_indices
            raw_gt_data = raw_gt_array[p_idx_gt, c_idx_gt, f_idx_gt, ...]

        pred_data = EvaluationDataset._apply_reconciliation(
            raw_pred_data, chunk.pred_reconciliation_map
        )
        gt_data = EvaluationDataset._apply_reconciliation(
            raw_gt_data, chunk.gt_reconciliation_map
        )
        return pred_data, gt_data

    def __iter__(self):
        """Iterate over the work items assigned to this worker.

        Loads each chunk's prediction (and, if configured, ground-truth) array
        once, then slices out each listed frame and applies reconciliation
        maps. A chunk that fails to load, or a frame that fails to process, is
        logged and skipped.

        Yields:
            Tuple[np.ndarray, Optional[np.ndarray], ChunkWorkItem, Any]:
                - The processed prediction data (float32).
                - The processed ground-truth data (float32), or None.
                - The chunk (work item) the frame belongs to.
                - The frame info object for the frame.
        """
        # Outer loop: iterate through each chunk (expensive I/O operation)
        for chunk in self._worker_work_items():
            # === (1) Load full prediction and GT arrays ===
            try:
                raw_pred_array = self.pred_loader.load_full_array(
                    path=chunk.pred_path, data_key=chunk.pred_data_key
                ).astype(np.float32)

                raw_gt_array: Optional[np.ndarray] = None
                if self.annot_loader and chunk.annot_path and chunk.annot_data_key:
                    raw_gt_array = self.annot_loader.load_full_array(
                        data_key=chunk.annot_data_key
                    ).astype(np.float32)
            except Exception as e:
                logging.exception(f"Failed to load data for {chunk}: {e}")
                continue

            # Inner loop: iterate through each frame in the chunk
            for frame_info in chunk.frames:
                # === (2) + (3) Slice the frame and apply reconciliation maps ===
                try:
                    pred_data, gt_data = self._extract_frame(
                        chunk, frame_info, raw_pred_array, raw_gt_array
                    )
                except Exception as e:
                    logging.exception(
                        f"Failed to process frame {frame_info} in chunk "
                        f"{chunk.pred_path}: {e}"
                    )
                    continue
                # Yield outside the try so exceptions raised into the generator
                # by the consumer are not misreported as frame failures.
                yield pred_data, gt_data, chunk, frame_info

    @staticmethod
    def collate_fn(batch: List[tuple]) -> Dict[str, Any]:
        """Custom collate function for the torch DataLoader.

        Groups the samples yielded by ``EvaluationDataset.__iter__`` by
        (metric type, prediction shape, session, component, algorithm,
        prediction data key), so each group holds homogeneous data that can
        be stacked and handed to the metric code.

        Args:
            batch (List[tuple]): A list of 4-tuples
                ``(pred_data, gt_data, chunk, frame_info)``.

        Returns:
            Dict[tuple, Dict[str, Any]]: Maps each compound group key to:
                - 'pred': Stacked prediction data as a torch tensor.
                - 'gt': Stacked ground-truth data as a torch tensor, or None
                  if any sample in the group has no ground truth.
                - 'chunk': The (shared) chunk of the group.
                - 'frames': List of frame info objects, in batch order.

        Raises:
            ValueError: If the samples in a group do not all share an equal chunk.
        """
        grouped_samples = defaultdict(list)
        for pred, gt, chunk, frame in batch:
            compound_key = (
                chunk.metric_type,
                pred.shape,
                chunk.session,
                chunk.component,
                chunk.algorithm,
                chunk.pred_data_key,
            )
            grouped_samples[compound_key].append((pred, gt, chunk, frame))

        final_batch = {}
        for key, samples in grouped_samples.items():
            preds, gts, chunks, frames = zip(*samples)

            # Ensure that all chunks are equal objects since we grouped them
            if any(chunk != chunks[0] for chunk in chunks[1:]):
                logging.error(
                    "In collate_fn, not all chunks are equal within a group. "
                    "This should not happen."
                )
                raise ValueError("Inconsistent chunks in collate_fn grouping.")

            stacked_preds = torch.from_numpy(np.stack(preds))

            gt_present = [g is not None for g in gts]
            stacked_gts = None
            if all(gt_present):
                stacked_gts = torch.from_numpy(np.stack(gts))
            elif any(gt_present):
                # Partial GT cannot be stacked; make the drop visible.
                logging.warning(
                    "collate_fn: ground truth missing for %d of %d samples in "
                    "group %s; dropping GT for the whole group.",
                    gt_present.count(False),
                    len(gt_present),
                    key,
                )

            final_batch[key] = {
                "pred": stacked_preds,
                "gt": stacked_gts,
                "chunk": chunks[0],
                "frames": list(frames),
            }

        return final_batch