"""Matplotlib figure builders for evaluation results (module nicetoolbox.evaluation.data.plots)."""

import logging

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt

from ...configs.schemas.evaluation_group_by import GroupBySpec
from .input_loader import LoadedArray

# Default ``split_by`` for the plotting entry points: a GroupBySpec without
# dims, i.e. no per-group splitting (all data goes into a single figure).
# It is shared across calls as a default argument; the functions in this
# module only read it (``.dims`` / ``.intersect``), never mutate it.
_NO_SPLIT = GroupBySpec(dims=[])
# Matplotlib's qualitative "tab10" palette colors. Used as the color cycle for
# labels/series and as the default color for single-series plots.
_TAB10 = plt.cm.tab10.colors


class _LabelColorMap:
    """Remember a tab10 color per label, handing out palette colors in order.

    The first lookup of a label assigns it the next color of the tab10 palette;
    every later lookup of that label returns the same color. Because the
    palette only has a fixed number of entries, colors wrap around and are
    reused once it is exhausted (so colors are unique only for the first
    ``len(_TAB10)`` labels).
    """

    def __init__(self):
        # label -> assigned color
        self._map: dict[str, tuple] = {}
        # number of labels seen so far; index of the next palette color
        self._counter = 0

    def __call__(self, label: str) -> tuple:
        try:
            return self._map[label]
        except KeyError:
            pass
        assigned = _TAB10[self._counter % len(_TAB10)]
        self._map[label] = assigned
        self._counter += 1
        return assigned


# Module-wide registry shared by every plot function below, so a given label
# keeps the same color across all figures produced in this process.
_label_color = _LabelColorMap()


def _plot_candle_single(
    df: pd.DataFrame,
    x_col: str,
    title: str,
    y_label: str,
    series_col: str | None = None,
) -> plt.Figure:
    """Render one candle chart for a single (subject, sequence, ...) slice.

    For every x value a candle is drawn: a thin whisker from ``min`` to ``max``
    with end caps, a translucent box from ``q25`` to ``q75`` and a thick line at
    ``median``. When ``series_col`` is given, the candles of the different
    series are placed side by side inside each x slot and a legend is added.

    Args:
        df: Candle statistics. Must contain ``x_col`` and the columns ``min``,
            ``q25``, ``median``, ``q75`` and ``max`` (plus ``series_col`` when
            given). Expected to hold at most one row per (x value, series).
        x_col: Column whose unique values (in order of appearance) become the
            x-axis categories.
        title: Axes title; long titles are wrapped.
        y_label: Y-axis label.
        series_col: Optional column splitting the rows into colored series.
            When None, all data is rendered in a single color with no legend.

    Returns:
        The newly created figure.
    """
    x_values = df[x_col].unique()
    n_x = len(x_values)
    fig, ax = plt.subplots(figsize=(max(8, n_x * 1.2), 6))
    x_positions = np.arange(n_x)

    if series_col is None:
        series_list = [(None, df, _TAB10[0])]
    else:
        series_list = [(s, df[df[series_col] == s], _label_color(s)) for s in df[series_col].unique()]

    n_series = len(series_list)
    # Each x slot is 0.8 wide and shared evenly by the series. max(..., 1)
    # keeps an empty frame (no series at all) from dividing by zero.
    width = 0.8 / max(n_series, 1)
    cap = width * 0.3  # half-width of the whisker end caps

    for i, (label, series_df, color) in enumerate(series_list):
        series_df = series_df.set_index(x_col)
        # Center the n_series candles of each slot on its tick; with a single
        # series the offset is zero.
        offsets = x_positions + (i - n_series / 2 + 0.5) * width

        for j, x_val in enumerate(x_values):
            if x_val not in series_df.index:
                continue  # this series has no data for this x value
            row = series_df.loc[x_val]
            low, high = row["min"], row["max"]
            q25, q75, median = row["q25"], row["q75"], row["median"]
            x = offsets[j]

            ax.plot([x, x], [low, high], color=color, linewidth=1)
            ax.plot([x - cap, x + cap], [low, low], color=color, linewidth=1)
            ax.plot([x - cap, x + cap], [high, high], color=color, linewidth=1)
            ax.bar(x, q75 - q25, bottom=q25, width=width * 0.9, color=color, alpha=0.6)
            ax.plot([x - width * 0.4, x + width * 0.4], [median, median], color=color, linewidth=2.5)

        if label is not None:
            # Empty proxy line so the series gets a legend entry.
            ax.plot([], [], color=color, label=label)

    ax.set_xticks(x_positions)
    ax.set_xticklabels(x_values, rotation=45, ha="right")
    ax.set_ylabel(y_label)
    ax.set_title(title, wrap=True)
    if series_col is not None and n_series:
        ax.legend()
    fig.tight_layout()
    return fig


def plot_score(
    df: pd.DataFrame,
    x_col: str,
    y_col: str,
    title: str,
    x_label: str,
    y_label: str,
) -> plt.Figure:
    """Bar chart comparing a scalar score across groups (e.g. algorithm vs CV).

    Rows sharing an x_col value are averaged into a single bar. Bars are sorted
    by ascending score, colored per label through the shared label->color
    registry, and annotated with their value to three decimals.

    Returns:
        The newly created figure.
    """
    means = df.groupby(x_col)[y_col].mean()
    ranked = means.sort_values()
    n_bars = len(ranked)
    slots = range(n_bars)
    bar_colors = [_label_color(name) for name in ranked.index]

    fig, ax = plt.subplots(figsize=(max(6, n_bars * 1.2), 5))
    bars = ax.bar(slots, ranked.values, color=bar_colors)
    ax.set_xticks(slots)
    ax.set_xticklabels(ranked.index, rotation=45, ha="right")
    ax.set(xlabel=x_label, ylabel=y_label, title=title)
    ax.bar_label(bars, fmt="%.3f", padding=3)
    fig.tight_layout()
    return fig


def plot_score_grouped(
    df: pd.DataFrame,
    x_col: str,
    series_col: str,
    y_col: str,
    title: str,
    x_label: str,
    y_label: str,
) -> plt.Figure:
    """Grouped bar chart comparing a scalar score across x groups, with one color per series.

    If multiple rows share the same (x_col, series_col) pair, their y_col values
    are averaged. Each bar is annotated with its value (two decimals); an
    (x, series) combination absent from the data leaves an empty, unlabeled slot.

    Args:
        df: Input rows; must contain x_col, series_col and y_col.
        x_col: Column defining the bar groups on the X axis.
        series_col: Column defining the colored bars within each group.
        y_col: Numeric column whose per-(x, series) mean is plotted.
        title: Axes title.
        x_label: X-axis label.
        y_label: Y-axis label.

    Returns:
        The newly created figure.
    """
    df = df.groupby([x_col, series_col])[y_col].mean().reset_index()
    x_values = df[x_col].unique()
    series = df[series_col].unique()
    n_x = len(x_values)
    n_series = len(series)
    x_positions = np.arange(n_x)
    # max(..., 1) avoids a ZeroDivisionError when df has no rows.
    width = 0.8 / max(n_series, 1)

    fig, ax = plt.subplots(figsize=(max(8, n_x * 1.2), 5))
    for i, s in enumerate(series):
        series_df = df[df[series_col] == s].set_index(x_col)
        # Center the n_series bars of each group on its tick.
        offsets = x_positions + (i - n_series / 2 + 0.5) * width
        values = [series_df.loc[x, y_col] if x in series_df.index else float("nan") for x in x_values]
        bars = ax.bar(offsets, values, width=width * 0.9, color=_label_color(s), label=s)
        # Explicit labels leave missing (NaN) bars unannotated; not every
        # matplotlib version skips NaN bars in bar_label on its own.
        labels = ["" if np.isnan(v) else f"{v:.2f}" for v in values]
        ax.bar_label(bars, labels=labels, padding=2, fontsize=7)

    ax.set_xticks(x_positions)
    ax.set_xticklabels(x_values, rotation=45, ha="right")
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    ax.set_title(title)
    if n_series:
        ax.legend()
    fig.tight_layout()
    return fig


def plot_score_heatmap(
    df: pd.DataFrame,
    x_col: str,
    y_col: str,
    value_col: str,
    title: str,
    x_label: str,
    y_label: str,
    *,
    vmin: float = 0,
    vmax: float = 1,
    cmap: str = "RdYlGn",
) -> plt.Figure:
    """Heatmap of scalar scores with x_col on X axis and y_col on Y axis.

    If multiple rows share the same (x_col, y_col) pair, their values are averaged.
    Each cell is annotated with its numeric value; combinations missing from the
    data stay blank and unannotated.

    Args:
        df: Input rows; must contain x_col, y_col and value_col.
        x_col: Column whose values become the heatmap columns.
        y_col: Column whose values become the heatmap rows.
        value_col: Numeric column averaged into each cell.
        title: Axes title.
        x_label: X-axis label.
        y_label: Y-axis label.
        vmin: Lower limit of the color scale (default 0).
        vmax: Upper limit of the color scale (default 1). The defaults suit
            scores in [0, 1]; values outside [vmin, vmax] saturate at the end colors.
        cmap: Matplotlib colormap name (default "RdYlGn").

    Returns:
        The newly created figure.
    """
    pivot = df.groupby([y_col, x_col])[value_col].mean().unstack(x_col)
    grid = pivot.values
    n_y, n_x = grid.shape

    fig, ax = plt.subplots(figsize=(max(6, n_x * 0.7), max(3, n_y * 0.6)))
    im = ax.imshow(grid, aspect="auto", cmap=cmap, vmin=vmin, vmax=vmax)
    # Attach the colorbar to this figure explicitly rather than pyplot's current figure.
    fig.colorbar(im, ax=ax)
    for i in range(n_y):
        for j in range(n_x):
            val = grid[i, j]
            if not np.isnan(val):
                ax.text(j, i, f"{val:.2f}", ha="center", va="center", fontsize=8)

    ax.set_xticks(range(n_x))
    ax.set_xticklabels(pivot.columns, rotation=45, ha="right")
    ax.set_yticks(range(n_y))
    ax.set_yticklabels(pivot.index)
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    ax.set_title(title)
    fig.tight_layout()
    return fig


def plot_frame_line(
    arrays: list[LoadedArray],
    base_title: str,
    y_label: str,
    series_col: str | None = None,
    split_by: GroupBySpec = _NO_SPLIT,
) -> dict[str, plt.Figure]:
    """Generate per-frame line charts from raw arrays.

    For each split_by group, produces a line chart with frames on X axis and
    mean metric value (across subjects, cameras, fields) on Y axis. Different
    series (e.g. algorithms) are shown as separate colored lines.

    split_by dims that are not columns of the data are ignored. If series_col
    is given but is not a column of the data, a warning is logged and no
    figure is produced (same behavior as plot_candle_per_group).

    Args:
        arrays: Raw arrays, converted to a long-format frame with ``frame`` and
            ``value`` columns by ``arrays_to_dataframe``.
        base_title: Figure title, and prefix of every returned filename.
        y_label: Y-axis label.
        series_col: Optional column drawn as separate colored lines.
        split_by: Columns defining one figure per unique combination.

    Returns:
        A dict mapping filename -> figure.
    """
    from .summary import arrays_to_dataframe

    df = arrays_to_dataframe(arrays)

    if series_col and series_col not in df.columns:
        logging.warning(
            f"Per-frame line chart skipped for '{base_title}': "
            f"series column '{series_col}' not found in data. Available columns: {list(df.columns)}"
        )
        return {}

    # Determine which columns to group by for each line
    extra = [series_col] if series_col else []
    # list(): the result is concatenated with other lists below.
    present_split = list(split_by.intersect(df.columns)) if split_by.dims else []

    # Group by split + series + frame, average over the rest
    group_cols = present_split + extra + ["frame"]
    agg_df = df.groupby(group_cols)["value"].mean().reset_index()

    if not present_split:
        fig = _plot_frame_line_single(agg_df, base_title, y_label, series_col)
        return {f"{base_title}_per_frame": fig}

    figures: dict[str, plt.Figure] = {}
    for keys, group_df in agg_df.groupby(present_split):
        if not isinstance(keys, tuple):  # older pandas yields scalars for one key
            keys = (keys,)
        label = " | ".join(f"{c}={v}" for c, v in zip(present_split, keys))
        # Group description on its own line below the base title.
        title = f"{base_title}\n{label}"
        filename = f"{base_title}_per_frame_{'_'.join(str(v) for v in keys)}"
        figures[filename] = _plot_frame_line_single(group_df, title, y_label, series_col)
    return figures


def _plot_frame_line_single(
    df: pd.DataFrame,
    title: str,
    y_label: str,
    series_col: str | None = None,
) -> plt.Figure:
    """Draw one per-frame line chart.

    ``df["frame"]`` is coerced to numbers (unparseable entries become NaN) on
    a copy, and each line is drawn in frame order. When ``series_col`` names an
    existing column, one labeled line per series value is drawn (with a
    legend); otherwise all rows form a single default-colored line.

    Returns:
        The newly created figure.
    """
    fig, ax = plt.subplots(figsize=(12, 5))
    data = df.assign(frame=pd.to_numeric(df["frame"], errors="coerce"))
    line_style = {"linewidth": 0.8, "alpha": 0.8}

    has_series = series_col is not None and series_col in data.columns
    if has_series:
        for name, group in data.groupby(series_col):
            ordered = group.sort_values("frame")
            ax.plot(ordered["frame"], ordered["value"], color=_label_color(name), label=name, **line_style)
        ax.legend()
    else:
        ordered = data.sort_values("frame")
        ax.plot(ordered["frame"], ordered["value"], color=_TAB10[0], **line_style)

    ax.set(xlabel="Frame", ylabel=y_label, title=title)
    fig.tight_layout()
    return fig
def plot_candle_per_group(
    arrays: list[LoadedArray],
    x_col: str,
    base_title: str,
    y_label: str,
    series_col: str | None = None,
    split_by: GroupBySpec = _NO_SPLIT,
) -> dict[str, plt.Figure]:
    """Generate one candle chart per unique combination of split_by columns.

    Computes candle statistics (min, q25, median, q75, max) from raw arrays.
    group_by is derived automatically as split_by + x_col + series_col (if provided).
    If series_col is given but absent from the summarized data, a warning is
    logged and no figure is produced.

    Args:
        arrays: Raw arrays to summarize.
        x_col: Column whose values become the candle x categories.
        base_title: Figure title, and prefix of every returned filename.
        y_label: Y-axis label.
        series_col: Optional column drawn as side-by-side colored candles.
        split_by: Columns defining one figure per unique combination.

    Returns:
        A dict mapping filename -> figure.
    """
    from ...configs.schemas.evaluation_aggr import AggSpec
    from .summary import summarize_with_group_by

    extra = [series_col] if series_col else []
    # list(): dims may be any sequence; it is concatenated with lists here.
    group_by = GroupBySpec(dims=list(split_by.dims or []) + [x_col] + extra)
    df = summarize_with_group_by(arrays, group_by=group_by, agg=AggSpec.of_type("min", "q25", "median", "q75", "max"))

    if series_col and series_col not in df.columns:
        logging.warning(
            f"Candle chart skipped for '{base_title}': "
            f"series column '{series_col}' not found in data. Available columns: {list(df.columns)}"
        )
        return {}

    # list(): pandas groupby treats a tuple as one key, not as several columns.
    present = list(split_by.intersect(df.columns))
    if not present:
        fig = _plot_candle_single(df, x_col, base_title, y_label, series_col)
        return {f"{base_title}": fig}

    figures: dict[str, plt.Figure] = {}
    for keys, group_df in df.groupby(present):
        if not isinstance(keys, tuple):  # older pandas yields scalars for one key
            keys = (keys,)
        label = " | ".join(f"{c}={v}" for c, v in zip(present, keys))
        # Group description on its own line below the base title.
        title = f"{base_title}\n{label}"
        filename = f"{base_title}_{'_'.join(str(v) for v in keys)}"
        figures[filename] = _plot_candle_single(group_df, x_col, title, y_label, series_col)
    return figures


def plot_roc_curves(
    curves: list[tuple[str, np.ndarray, np.ndarray, float]],
    base_title: str,
) -> plt.Figure:
    """ROC curve figure with one line per group.

    Each curve is drawn in its label's color, with its AUC quoted in the legend
    entry, above a dashed gray diagonal marking chance-level performance. The
    legend sits below the axes.

    Args:
        curves: list of (label, fpr, tpr, auc) tuples, one per group row.
        base_title: axes title.

    Returns:
        The newly created figure.
    """
    fig, ax = plt.subplots(figsize=(6, 6))

    for name, false_pos, true_pos, area in curves:
        legend_text = f"{name} (AUC={area:.3f})"
        ax.plot(false_pos, true_pos, color=_label_color(name), linewidth=1.5, label=legend_text)

    # Chance diagonal; drawn without a label so it stays out of the legend.
    ax.plot([0, 1], [0, 1], color="gray", linewidth=0.8, linestyle="--")

    ax.set(xlim=(0, 1), ylim=(0, 1.02), xlabel="False Positive Rate", ylabel="True Positive Rate")
    ax.legend(loc="upper center", bbox_to_anchor=(0.5, -0.12), fontsize=8)
    ax.set_title(base_title)
    fig.tight_layout()
    return fig


def plot_pr_curves(
    curves: list[tuple[str, np.ndarray, np.ndarray, float, float]],
    base_title: str,
) -> plt.Figure:
    """Precision-recall curve figure with one line per group.

    Each group also gets a dashed horizontal baseline, in its own color, at its
    positive class prevalence. That is the precision a random classifier
    achieves at every threshold. The legend (with the AP per group) sits below
    the axes.

    Args:
        curves: list of (label, recall, precision, ap, prevalence) tuples, one per group row.
        base_title: axes title.

    Returns:
        The newly created figure.
    """
    fig, ax = plt.subplots(figsize=(6, 6))

    for name, rec, prec, avg_precision, base_rate in curves:
        shade = _label_color(name)
        ax.plot(rec, prec, color=shade, linewidth=1.5, label=f"{name} (AP={avg_precision:.3f})")
        ax.axhline(base_rate, color=shade, linewidth=0.8, linestyle="--", alpha=0.6)

    ax.set(xlim=(0, 1), ylim=(0, 1.02), xlabel="Recall", ylabel="Precision")
    ax.legend(loc="upper center", bbox_to_anchor=(0.5, -0.12), fontsize=8)
    ax.set_title(base_title)
    fig.tight_layout()
    return fig


def plot_confusion_matrix_grid(
    summary: pd.DataFrame,
    group_levels: list[str],
    base_title: str,
) -> dict[str, plt.Figure]:
    """One figure per row in *summary*: CM heatmap + accuracy/precision/recall/F1 bar chart.

    Each row must provide the counts ``tn``, ``fp``, ``fn``, ``tp`` and the
    metrics ``accuracy``, ``precision``, ``recall``, ``f1``. The columns named
    in ``group_levels`` that exist in the row identify the group. They are
    listed in both subplot titles and their values are joined with "_" into
    the figure key ``cm_<values>`` (just ``cm`` when no group level is present).
    A row whose key is already taken gets a numeric suffix (``_1``, ``_2``, ...)
    instead of silently replacing the earlier figure.

    Args:
        summary: one row per group with the counts and metrics listed above.
        group_levels: names of the columns identifying a group.
        base_title: figure suptitle.

    Returns:
        A dict mapping filename -> figure.
    """
    metric_names = ["accuracy", "precision", "recall", "f1"]
    figures: dict[str, plt.Figure] = {}
    for _, row in summary.iterrows():
        levels = [lvl for lvl in group_levels if lvl in row.index]
        group_parts = [f"{lvl}={row[lvl]}" for lvl in levels]
        group_label = "_".join(str(row[lvl]) for lvl in levels)
        subtitle = "\n".join(group_parts)

        fig_key = f"cm_{group_label}" if group_label else "cm"
        if fig_key in figures:
            # Same group values as an earlier row: keep both figures.
            suffix = 1
            while f"{fig_key}_{suffix}" in figures:
                suffix += 1
            fig_key = f"{fig_key}_{suffix}"

        # Rows = true class (neg, pos), columns = predicted class (neg, pos).
        cm_vals = np.array([[int(row["tn"]), int(row["fp"])], [int(row["fn"]), int(row["tp"])]])

        fig, (ax, ax2) = plt.subplots(1, 2, figsize=(10, 4 + 0.2 * len(group_parts)))
        im = ax.imshow(cm_vals, interpolation="nearest", cmap="Blues")
        fig.colorbar(im, ax=ax)
        ax.set_xticks([0, 1])
        ax.set_yticks([0, 1])
        ax.set_xticklabels(["Pred Neg", "Pred Pos"])
        ax.set_yticklabels(["True Neg", "True Pos"])
        # imshow scales the colormap to [min, max] of the counts; cells above
        # the midpoint are dark blue, so their count is written in white.
        threshold = (cm_vals.max() + cm_vals.min()) / 2
        for i in range(2):
            for j in range(2):
                text_color = "white" if cm_vals[i, j] > threshold else "black"
                ax.text(j, i, str(cm_vals[i, j]), ha="center", va="center", fontsize=14, color=text_color)
        ax.set_title(f"Confusion Matrix\n{subtitle}", fontsize=9)

        values = [float(row[m]) for m in metric_names]
        bars = ax2.bar(metric_names, values, color=_TAB10[: len(metric_names)])
        ax2.set_ylim(0, 1.1)
        ax2.set_ylabel("Score")
        ax2.set_title(f"Metrics\n{subtitle}", fontsize=9)
        ax2.bar_label(bars, fmt="%.3f", padding=3)

        fig.suptitle(base_title)
        fig.tight_layout()
        figures[fig_key] = fig
    return figures