# Source code for nicetoolbox.detectors.method_detectors.base_method

"""
Base class for Method Detectors.
Method detectors run external inference scripts in separate virtual environments.
"""

import inspect
import json
import logging
import os
import subprocess
from abc import abstractmethod
from pathlib import Path
from typing import final

from nicetoolbox_core.entrypoint import SubprocessError, get_subprocess_error_path

from ...configs.schemas.detectors_instances_configs import MethodDetectorRuntime
from ...configs.video_runtime_config import SequenceRuntimeConfig
from ...utils.base_detectors import flatten_inference_config
from ...utils.config import save_config
from ...utils.hf_token import effective_hf_hub_token
from ...utils.system import detect_os_type
from ..base_detector import BaseDetector
from ..data import SequenceData
from ..in_out import SequenceIO


class BaseMethod(BaseDetector):
    """
    Abstract base class for method detectors.

    Method detectors run inference in external virtual environments via
    subprocess: the constructor writes a ``run_config.toml`` into every
    component's ``run_config`` folder, ``run()`` launches
    ``<algorithm_type>_inference.py`` (located next to the concrete subclass)
    in the configured conda/venv environment, and subclasses implement
    ``post_inference()``, which is called once the subprocess exits with code 0.
    """

    # If set, used for get_inference_path instead of components[0].
    inference_package_name: str | None = None

    runtime: MethodDetectorRuntime

    # subprocess settings
    os_type: str
    conda_path: Path
    venv: str  # environment type parsed from env_name ("conda" or "venv")
    env_name: str  # environment name parsed from env_name
    script_path: Path  # resolved by _resolve_inference_script()
    config_path: Path  # first component's run_config.toml, passed to the subprocess

    # method specific settings
    visualize: bool
    requires_out_folder: bool
    out_folders: dict[str, str]
    result_folders: dict[str, str]
    viz_folders: dict[str, str]
    config_paths: list[Path]

    @final
    def __init__(
        self,
        io: SequenceIO,
        data: SequenceData,
        sequence_context: SequenceRuntimeConfig,
        algorithm_instance: str,
    ) -> None:
        """
        Initialize the method detector and prepare its subprocess run.

        Resolves the execution environment and inference script, computes the
        output/result/visualization folders, builds the runtime configuration
        and saves the flattened config as ``run_config.toml`` for every
        component.

        Args:
            io: Sequence I/O helper (code folder, environment and output paths).
            data: Sequence data (subjects, calibration, input recipes).
            sequence_context: Runtime configuration of the current sequence.
            algorithm_instance: Name of the configured algorithm instance.

        Raises:
            FileNotFoundError: If the detector's inference script does not exist.
            ValueError: If ``env_name`` is not of the form ``"<type>:<name>"``.
        """
        # (1) Call BaseDetector __init__()
        super().__init__(io, data, sequence_context, algorithm_instance)
        logging.info(
            f"Initializing method detector {self.__class__.__name__} with instance '{self.algorithm_instance}' "
            f"and component(s) {self.components}."
        )

        # (2) Setup subprocess execution settings
        self._setup_subprocess_settings()

        # (3) Compute folders and build the runtime configuration.
        # NOTE(review): out folders are requested only when the config enables
        # `visualize` -- confirm this coupling is intended.
        self.requires_out_folder = getattr(self.detector_config, "visualize", False)
        self.out_folders = self.compute_output_folders(self.requires_out_folder)
        self.result_folders = self.compute_result_folders()
        self.viz_folders = self.compute_viz_folders(self.visualize)
        self.runtime = self._initialize_detector()  # This will be overloaded by child detectors!

        # (4) Flatten config for subprocess
        inference_config = flatten_inference_config(self.detector_config, self.runtime)

        # TODO: Pre-map single component out_folder and viz_folder for backward
        # compatibility; multi-component detectors must use self.out_folders /
        # self.viz_folders instead.
        if len(self.components) == 1:
            comp = self.components[0]
            self.out_folder = self.out_folders[comp]
            self.viz_folder = self.viz_folders[comp]
            inference_config["out_folder"] = self.out_folder
            inference_config["viz_folder"] = self.viz_folder

        # (5) Save config for subprocess for each component. The subprocess is
        # launched with the first component's file; the others get copies.
        folder = self.io.get_detector_output_folder(self.components[0], self.algorithm_instance, "run_config")
        self.config_path = folder / "run_config.toml"
        save_config(inference_config, self.config_path)
        for comp in self.components[1:]:
            dup = self.io.get_detector_output_folder(comp, self.algorithm_instance, "run_config") / "run_config.toml"
            save_config(inference_config, dup)
        logging.info("Inference preparation completed.\n")

    def _initialize_detector(self) -> MethodDetectorRuntime:
        """
        Create runtime configuration.

        Override in subclasses that need extended runtime fields. The override should:
        1. Call super()._initialize_detector() to get the base runtime
        2. Return an instance of the detector's RuntimeConfig class with additional fields

        Returns:
            MethodDetectorRuntime or subclass with detector-specific fields
        """
        return MethodDetectorRuntime(
            nicetoolbox_root=str(self.io.code_folder),
            result_folders=self.result_folders,
            out_folders=self.out_folders,
            viz_folders=self.viz_folders,
            algorithm=self.algorithm_instance,
            visualize=self.visualize,
            subjects_descr=self.data.subjects_descr,
            log_file=str(self.sequence_context.log_file),
            log_level=self.sequence_context.log_level,
            calibration=self.data.calibration,
            cam_sees_subjects=self.data.cam_sees_subjects,
            input_recipes=self.data.get_input_recipes(),
        )

    def _resolve_inference_script(self) -> Path:
        """
        Resolve the inference script path relative to the concrete subclass's file.

        Convention: <same_folder_as_class>/<algorithm_type>_inference.py

        Raises:
            FileNotFoundError: If the script does not exist.
        """
        class_file = Path(inspect.getfile(type(self)))
        script = class_file.parent / f"{self.detector_config.algorithm_type}_inference.py"
        if not script.exists():
            logging.error(f"Detector inference file {script} does not exist!")
            raise FileNotFoundError(script)
        return script

    def _setup_subprocess_settings(self) -> None:
        """
        Setup OS-specific subprocess execution settings.

        Reads ``env_name`` from the detector config (default ``"venv:nicetoolbox"``)
        and splits it into the environment type (``self.venv``) and the
        environment name (``self.env_name``). For ``venv`` environments the path
        that is sourced/executed to activate the environment is stored in
        ``self.venv_path``.

        Raises:
            ValueError: If ``env_name`` does not contain exactly one ``":"``.
            FileNotFoundError: If the inference script does not exist.
        """
        self.os_type = detect_os_type()
        self.conda_path = self.io.get_conda_path()
        env_name = getattr(self.detector_config, "env_name", "venv:nicetoolbox")  # Default to nicetoolbox env
        parts = env_name.split(":")
        if len(parts) != 2:
            # Same inputs are rejected as before (tuple unpacking also raised
            # ValueError), but now with a message naming the bad setting.
            raise ValueError(
                f"Invalid env_name '{env_name}' for {self.algorithm_instance}: "
                "expected '<conda|venv>:<environment name>'."
            )
        self.venv, self.env_name = parts
        self.script_path = self._resolve_inference_script()
        if self.venv == "venv":
            self.venv_path = self.io.get_venv_path(self.detector_config.algorithm_type, self.env_name)

    def _subprocess_env(self) -> dict:
        """Copy os.environ and set HF_TOKEN from machine_specific_paths.toml when configured."""
        env = os.environ.copy()
        tok = effective_hf_hub_token(self.sequence_context.machine)
        if tok:
            env["HF_TOKEN"] = tok
        return env

    # -------------------------------------------------------------------------
    # BaseDetector Interface Implementation
    # -------------------------------------------------------------------------

    def run(self) -> None:
        """
        Execute method detector: run subprocess inference + post_inference.

        Returns None - visualization uses external data.
        """
        self._run_inference()

    def _run_inference(self) -> None:
        """
        Run the inference subprocess and dispatch on its exit status.

        On exit code 0 ``post_inference()`` is called; otherwise the failure is
        turned into a ``RuntimeError`` by ``_handle_subprocess_error``.
        """
        logging.info(f"INFERENCE: Launching {self.algorithm_instance} subprocess...")
        command = self._create_command()

        # Delete the error report of previous runs before starting, so a stale
        # report is never mistaken for the outcome of this run.
        error_file = get_subprocess_error_path(self.config_path)
        error_file.unlink(missing_ok=True)

        # The command uses shell syntax (`&&`, `source`), hence shell=True; on
        # non-Windows systems it runs under bash because `source` is a bash builtin.
        shell_kwargs = {} if self.os_type == "windows" else {"executable": "/bin/bash"}
        cmd_result = subprocess.run(
            command,
            capture_output=True,
            text=True,
            shell=True,
            check=False,
            env=self._subprocess_env(),
            **shell_kwargs,
        )

        if cmd_result.returncode == 0:
            logging.info(f"INFERENCE: {self.algorithm_instance} finished successfully (Exit 0).")
            self.post_inference()
            return

        # Log the actual exit status rather than a hard-coded one.
        logging.error(f"INFERENCE: {self.algorithm_instance} subprocess failed (Exit {cmd_result.returncode}).")
        self._handle_subprocess_error(cmd_result)

    def _create_command(self) -> str:
        """
        Create the shell command to run inference.

        The command runs ``python <script_path> <config_path>`` inside the
        environment selected by ``self.venv`` / ``self.env_name``.

        Raises:
            ValueError: If ``self.venv`` is neither ``"conda"`` nor ``"venv"``.
        """
        script = f'"{self.script_path}"'
        config = f'"{self.config_path}"'

        if self.venv == "conda":
            if self.os_type == "windows":
                conda_env_path = os.path.join(self.io.code_folder, "envs", self.env_name)
                # fmt: off
                command = (
                    f"deactivate && "
                    f'cmd /s /c "conda activate {conda_env_path} && '
                    f'python {script} {config}"'
                )
                # fmt: on
            else:
                # Call the environment's interpreter directly; no activation needed.
                python_path = os.path.join(self.io.code_folder, "envs", self.env_name, "bin/python")
                command = f"'{python_path}' {script} {config}"
        elif self.venv == "venv":
            if self.os_type == "windows":
                command = f'cmd /s /c ""{self.venv_path}" && python {script} {config}"'
            else:
                command = f"source '{self.venv_path}' && python {script} {config}"
        else:
            raise ValueError(f"Unknown venv type '{self.venv}'. Expected 'conda' or 'venv'.")
        return command

    def _handle_subprocess_error(self, cmd_result: subprocess.CompletedProcess) -> None:
        """
        Convert a failed subprocess run into a ``RuntimeError``.

        If the subprocess left an error report (see
        ``get_subprocess_error_path``), the remote exception type, message and
        traceback are re-raised locally; otherwise the captured stderr is
        reported as a hard crash.

        Args:
            cmd_result: Completed process of the failed run (text mode).

        Raises:
            RuntimeError: Always.
        """
        error_file = get_subprocess_error_path(self.config_path)
        if not error_file.exists():
            raise RuntimeError(f"Subprocess Hard Crash (No Error Report):\n{cmd_result.stderr}")

        # Keep the try minimal: only reading and parsing the report can fail here.
        try:
            with open(error_file) as file:
                remote_exc_json = json.load(file)
            remote_exc = SubprocessError(**remote_exc_json)
        except (ValueError, TypeError, KeyError) as err:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError (and,
            # presumably, a validation error if SubprocessError validates input);
            # TypeError is raised by `**` for a non-mapping payload and by plain
            # constructors for missing/unexpected fields.
            raise RuntimeError(
                f"Subprocess failed and error report is corrupt: {err}\n"
                f"Stderr: {cmd_result.stderr}"
            ) from err

        raise RuntimeError(
            f"Subprocess raised {remote_exc.exception_type}: {remote_exc.message}\n\n"
            f"--- Remote Traceback ---\n\n{remote_exc.traceback}"
        )

    @abstractmethod
    def post_inference(self) -> None:
        """
        Post-processing after inference completes.

        Called by ``_run_inference`` only after the subprocess exited with code 0.
        """
        pass