Source code for nicetoolbox.utils.comparison

from dataclasses import dataclass
from pathlib import Path

import numpy as np
import pytest


@dataclass
class DifferentKeysError:
    """Record that two NPZ archives do not contain the same set of array names.

    ``compare_npz_files`` creates one of these when the sorted key list of the
    generated archive differs from the sorted key list of the golden archive.
    """

    # Path of the generated file, as a string.
    filename: str
    # Array names found in the golden (reference) archive.
    expected_keys: list[str]
    # Array names found in the generated archive.
    generated_keys: list[str]
@dataclass
class DifferentShapeError:
    """Record that an array has a different shape in the two NPZ archives.

    ``compare_npz_files`` creates one of these for a key present in both
    archives whose generated array shape differs from the golden array shape.
    """

    # Path of the generated file, as a string.
    filename: str
    # Name of the array inside the archive.
    key: str
    # Shape of the array in the golden (reference) archive.
    expected_shape: tuple
    # Shape of the array in the generated archive.
    generated_shape: tuple
@dataclass
class DifferentValuesError(Exception):
    """Record that an array holds different values in the two NPZ archives.

    The three statistics default to NaN. ``compare_npz_files`` fills them in
    only when the golden array has a floating-point dtype; for any other dtype
    the record is created with just ``filename`` and ``key``.
    """

    # Path of the generated file, as a string.
    filename: str
    # Name of the array inside the archive.
    key: str
    # Largest absolute element-wise difference (NaNs ignored).
    max_abs_diff: float = np.nan
    # Mean absolute element-wise difference (NaNs ignored).
    mean_abs_diff: float = np.nan
    # Standard deviation of the signed difference (NaNs ignored).
    std: float = np.nan
def _abs_diff_stats(gen_arr, gold_arr):
    """Return ``(max_abs_diff, mean_abs_diff, std)`` of ``gen_arr - gold_arr``, ignoring NaNs."""
    # NaN operands and inf - inf both yield NaN here; such entries are skipped
    # by the nan-aware reductions below, so the "invalid value" warning is noise.
    with np.errstate(invalid="ignore"):
        diff = gen_arr - gold_arr

    # If nothing comparable is left (e.g. NaN on one side only, everywhere),
    # the nan-aware reductions would only emit RuntimeWarnings and return NaN.
    if np.isnan(diff).all():
        return np.nan, np.nan, np.nan

    abs_diff = np.abs(diff)
    return (
        float(np.nanmax(abs_diff)),
        float(np.nanmean(abs_diff)),
        float(np.nanstd(diff)),
    )


def compare_npz_files(
    generated_path: Path,
    golden_path: Path,
    rtol: float = 1e-5,
    atol: float = 1e-8,
) -> list:
    """Compare a generated NPZ archive against a golden (reference) archive.

    Args:
        generated_path: Path of the NPZ file produced by the code under test.
        golden_path: Path of the reference NPZ file.
        rtol: Relative tolerance passed to ``np.allclose`` for floating-point arrays.
        atol: Absolute tolerance passed to ``np.allclose`` for floating-point arrays.

    Returns:
        A list of difference records, empty when the archives match:

        * one ``DifferentKeysError`` if the sorted key lists differ;
        * one ``DifferentShapeError`` per shared key whose shapes differ;
        * one ``DifferentValuesError`` per shared key with equal shapes but
          different values. Floating-point golden arrays are compared with
          ``np.allclose`` (NaNs in matching positions count as equal) and the
          record carries max/mean absolute difference and standard deviation;
          all other dtypes are compared with ``np.array_equal`` and the record
          carries no statistics.

        Keys that exist only in the golden archive are skipped in the per-key
        comparison (they are already covered by the ``DifferentKeysError``).

    Raises:
        Whatever ``pytest.fail`` raises, if either file cannot be opened as an
        NPZ archive.
    """
    # Local import keeps this block self-contained.
    from contextlib import ExitStack

    errors = []

    # The ExitStack closes both archives on every exit path, including the
    # case where the first load succeeds and the second one fails.
    with ExitStack() as stack:
        try:
            # NOTE(security): allow_pickle=True unpickles object arrays, which
            # can execute arbitrary code. Only use this on trusted files.
            generated_data = stack.enter_context(np.load(generated_path, allow_pickle=True))
            golden_data = stack.enter_context(np.load(golden_path, allow_pickle=True))
        except Exception as e:
            pytest.fail(f"Could not load NPZ files for comparison: {e}")

        generated_keys = sorted(generated_data.files)
        golden_keys = sorted(golden_data.files)
        if generated_keys != golden_keys:
            errors.append(DifferentKeysError(str(generated_path), golden_keys, generated_keys))

        present_in_generated = set(generated_keys)
        for key in golden_data.files:
            if key not in present_in_generated:
                continue

            gen_arr = generated_data[key]
            gold_arr = golden_data[key]

            if gen_arr.shape != gold_arr.shape:
                errors.append(
                    DifferentShapeError(str(generated_path), key, gold_arr.shape, gen_arr.shape)
                )
                continue

            if np.issubdtype(gold_arr.dtype, np.floating):
                if not np.allclose(gen_arr, gold_arr, rtol=rtol, atol=atol, equal_nan=True):
                    max_abs_diff, mean_abs_diff, std = _abs_diff_stats(gen_arr, gold_arr)
                    errors.append(
                        DifferentValuesError(
                            str(generated_path), key, max_abs_diff, mean_abs_diff, std
                        )
                    )
            elif not np.array_equal(gen_arr, gold_arr):
                errors.append(DifferentValuesError(str(generated_path), key))

    return errors