# Source code for nicetoolbox.connectors.elan.labeling_data

"""Multiperson annotations grouped by subject, category, and labeled intervals."""

import copy
import logging
from collections import defaultdict
from dataclasses import dataclass


@dataclass(frozen=True)
class LabeledInterval:
    """Immutable time span annotated with a (possibly empty) set of labels.

    Instances are frozen, so they are hashable and compare by value. Fields
    are positional in the order ``start_sec, end_sec, labels``.
    """

    # Interval start; the field name indicates seconds.
    start_sec: float
    # Interval end; the field name indicates seconds.
    end_sec: float
    # Labels attached to this span; an empty frozenset means "unlabeled".
    labels: frozenset[str]
@dataclass
class MultipersonHierarchicalData:
    """Labeled-interval annotations for several subjects, grouped by category.

    ``data[subject][category]`` holds the list of :class:`LabeledInterval`
    objects annotated for that subject in that category.
    """

    # subject name -> category name -> labeled intervals
    data: dict[str, dict[str, list[LabeledInterval]]]

    @property
    def subjects(self) -> set[str]:
        """Return the set of subject names (the top-level keys of ``data``)."""
        return set(self.data.keys())

    @property
    def categories(self) -> set[str]:
        """Return every category name used by at least one subject.

        The union over all subjects is taken instead of inspecting a single
        subject, so a category that only some subjects carry is still
        reported and the result does not depend on dict ordering. Returns an
        empty set when there are no subjects.
        """
        return {category for subject_data in self.data.values() for category in subject_data}

    @property
    def labels_per_category(self) -> dict[str, set[str]]:
        """Return, per category, the union of labels seen across all subjects.

        A category appears in the result only if at least one subject has at
        least one interval in it; intervals with an empty label set add the
        key but no labels.
        """
        result: defaultdict[str, set[str]] = defaultdict(set)
        for subject_data in self.data.values():
            for cat, intervals in subject_data.items():
                for iv in intervals:
                    result[cat].update(iv.labels)
        return dict(result)

    def __str__(self) -> str:
        return (
            f"MultipersonHierarchicalData: "
            f"subjects=[{self.subjects}], "
            f"label_per_category={self.labels_per_category}"
        )
def apply_category_defaults(
    data: MultipersonHierarchicalData, category_defaults: dict[str, str]
) -> MultipersonHierarchicalData:
    """Return a copy of *data* where unlabeled intervals get a per-category default.

    For every subject and every category listed in *category_defaults*, each
    interval whose ``labels`` set is empty is replaced by an interval with the
    same bounds and ``labels == frozenset({category_defaults[category]})``.
    Intervals that already carry labels, and categories without a default,
    are left unchanged. The input object is not modified. The number of
    replacements per category is logged at INFO level.

    Args:
        data: Annotations to process.
        category_defaults: Mapping ``category -> default label``.

    Returns:
        A new MultipersonHierarchicalData with the defaults applied.

    Raises:
        ValueError: If *category_defaults* names a category that no subject has.
    """
    # Collect categories across *all* subjects rather than one subject's keys,
    # so a default for a category that only some subjects carry is accepted
    # instead of being rejected as unknown.
    known_categories = {category for subject_data in data.data.values() for category in subject_data}
    unknown = set(category_defaults) - known_categories
    if unknown:
        raise ValueError(
            f"Category defaults reference unknown categories: {unknown} (known: {known_categories})"
        )

    # Work on a deep copy so the caller's lists are never mutated in place.
    new_data = copy.deepcopy(data.data)
    counts: dict[str, int] = {}
    for subject_data in new_data.values():
        for category, intervals in subject_data.items():
            if category not in category_defaults:
                continue
            default = frozenset({category_defaults[category]})
            for idx, iv in enumerate(intervals):
                if not iv.labels:
                    # LabeledInterval is frozen, so swap in a new element instead of editing it.
                    intervals[idx] = LabeledInterval(iv.start_sec, iv.end_sec, default)
                    counts[category] = counts.get(category, 0) + 1

    for category, n in sorted(counts.items()):
        logging.info(f"Replaced {n} empty annotations in category '{category}' with '{category_defaults[category]}'")
    return MultipersonHierarchicalData(data=new_data)
def rename_subjects(data: MultipersonHierarchicalData, subject_mapping: dict[str, str]) -> MultipersonHierarchicalData:
    """Return a copy of *data* with subjects renamed according to *subject_mapping*.

    Subjects absent from the mapping keep their current name. Swaps such as
    ``{"A": "B", "B": "A"}`` are supported. The returned object owns fresh
    per-subject dicts and interval lists, so mutating it does not affect
    *data*. The interval objects themselves are shared; they are frozen.

    Args:
        data: Annotations whose subjects should be renamed.
        subject_mapping: Mapping ``old subject name -> new subject name``.

    Returns:
        A new MultipersonHierarchicalData keyed by the new subject names.

    Raises:
        ValueError: If *subject_mapping* references subjects not present in
            *data*, or if the renaming would give two subjects the same name
            (e.g. ``{"A": "B"}`` while an unrenamed subject "B" exists). Such
            a collision would otherwise silently drop one subject's annotations.
    """
    unknown = set(subject_mapping) - data.subjects
    if unknown:
        raise ValueError(f"Subject mapping references unknown subjects: {unknown} (known: {data.subjects})")

    new_names = {subj: subject_mapping.get(subj, subj) for subj in data.data}

    # Detect two source subjects ending up under one target name.
    owner: dict[str, str] = {}
    for old_name, new_name in new_names.items():
        if new_name in owner:
            raise ValueError(
                f"Subject mapping would merge subjects '{owner[new_name]}' and '{old_name}' into '{new_name}'"
            )
        owner[new_name] = old_name

    # Copy the containers so the result does not alias the input's dicts/lists.
    new_data = {
        new_names[subj]: {category: list(intervals) for category, intervals in categories.items()}
        for subj, categories in data.data.items()
    }
    logging.info(f"Renamed subjects: {subject_mapping}")
    return MultipersonHierarchicalData(data=new_data)