Source code for nicetoolbox.detectors.method_detectors.base_detector

"""
A template class for Detectors.
"""

import json
import logging
import os
import subprocess
from abc import ABC, abstractmethod
from pathlib import Path

from nicetoolbox_core.entrypoint import SubprocessError

from ...utils.config import save_config
from ...utils.system import detect_os_type


[docs]class BaseDetector(ABC): """ Abstract class to setup and run existing computer vision research code, called method detectors. Attributes: components (list): A list of components associated with the method detector. algorithm (str): The algorithm used for detecting the component. out_folder (str): The output folder. viz_folder (str): The visualization folder. subjects_descr (str): The subjects description. config_path (str): The path to the configuration file. conda_path (str): The path to the conda installation. framework (str): The name of the framework used for the method detector. script_path (str): The path to the script used for the method detector. venv (str): The type of virtual environment used for the method detector. env_name (str): The name of the virtual environment used for the method detector. venv_path (str): The path to the virtual environment used for the method detector. """ def __init__(self, config, io, data, requires_out_folder=True) -> None: """ Sets up the input and output folders required for each method based on the provided configurations. Saves a copy of the configuration file for the method detector. Args: config (dict): Configuration parameters for the detector. io (IO): An instance of the IO class for input/output operations. data (Data): An instance of the Data class for accessing data. requires_out_folder (bool, optional): Indicates whether an output folder is required. Defaults to True. """ # (1) Store general detector/data/io information self.config = config self.io = io self.data = data self.results_folders = dict( (comp, io.get_detector_output_folder(comp, self.algorithm, "result")) for comp in self.components ) main_component = self.components[0] self.out_folder, self.viz_folder = None, None if requires_out_folder: self.out_folder = io.get_detector_output_folder(main_component, self.algorithm, "output") if self.config["visualize"]: self.viz_folder = io.get_detector_output_folder(main_component, self.algorithm, "visualization") self.subjects_descr = data.subjects_descr # (2) Extend the content of the detector config (used during venv inference) self.config["log_file"], self.config["log_level"] = io.get_log_file_level() # Get log file and level self.config["result_folders"] = self.results_folders self.config["out_folder"] = self.out_folder self.config["algorithm"] = self.algorithm self.config["calibration"] = data.calibration self.config["subjects_descr"] = data.subjects_descr self.config["cam_sees_subjects"] = data.camera_mapping["cam_sees_subjects"] self.config.update(data.get_input_recipe()) # Add data recipe to config for dataloader during inference # (3) Save the detector config (for venv inference) self.config_path = os.path.join( io.get_detector_output_folder(main_component, self.algorithm, "run_config"), "run_config.toml", ) save_config(self.config, self.config_path) # (4) Prepare OS specific venv/conda inference settings self.os_type = detect_os_type() self.conda_path = io.get_conda_path() framework = config.get("framework", self.algorithm) self.venv, self.env_name = config["env_name"].split(":") self.script_path = io.get_inference_path(main_component, framework) if self.venv == "venv": self.venv_path = io.get_venv_path(framework, self.env_name) def __str__(self): """ Returns a description of the method detector for printing. Returns: str: A string representation of the method detector, including its components, and the associated algorithm. """ return f"Instance of component {self.components} \n\t" f"algorithm = {self.algorithm} \n\t" + " \n\t".join( [f"{attr} = {value}" for (attr, value) in self.__dict__.items()] )
[docs] def run_inference(self) -> None: """ Runs the inference of the method detector in a separate terminal/cmd window using the specified virtual environment or conda environment. Captures the output and logs the success or failure of the inference. """ logging.info(f"INFERENCE: Launching {self.algorithm} subprocess...") # (1) Create the command to run the method detector command = self._create_command() # (2) Run the command in a separate terminal/cmd window if self.os_type == "windows": cmd_result = subprocess.run( command, capture_output=True, text=True, shell=True, check=False, ) else: cmd_result = subprocess.run( command, capture_output=True, text=True, shell=True, executable="/bin/bash", check=False, ) # (3) Check the return code and log the result if cmd_result.returncode == 0: logging.info(f"INFERENCE: {self.algorithm} finished successfully (Exit 0).") self.post_inference() return # (4) Handle subprocess failure logging.error(f"INFERENCE: {self.algorithm} subprocess failed (Exit 1).") config_path = Path(self.config_path) error_file = config_path.parent / "error.json" if error_file.exists(): try: with open(error_file) as file: remote_exc_json = json.load(file) remote_exc = SubprocessError(**remote_exc_json) # Raise. Now we need to figure out where to catch and # ignore this raise of the exception based on ErrorLevel. raise RuntimeError( f"Subprocess raised {remote_exc.exception_type}: {remote_exc.message}\n\n" f"--- Remote Traceback ---\n\n{remote_exc.traceback}" ) except (json.JSONDecodeError, KeyError) as err: raise RuntimeError( f"Subprocess failed and error report is corrupt: {err}\nStderr: {cmd_result.stderr}" ) from err else: # The script died before it could catch the exception and write the json. raise RuntimeError(f"Subprocess Hard Crash (No Error Report):\n{cmd_result.stderr}")
def _create_command(self) -> str: """ Creates the command to run the method detector in a separate terminal/cmd window using the specified virtual environment or conda environment. Returns: str: The command to run the method detector. """ if self.venv == "conda": if self.os_type == "windows": command = ( f"deactivate && " f'cmd "/c conda activate {self.env_name} && ' f'python {self.script_path} {self.config_path}"' ) elif self.os_type == "linux": conda_path = os.path.join(self.conda_path, "bin/activate") python_path = os.path.join(self.conda_path, "envs", self.env_name, "bin/python") command = ( f"conda init bash && source ~/.bashrc && " f"{conda_path} {self.env_name} && " f"{python_path} {self.script_path} {self.config_path}" ) elif self.venv == "venv": if self.os_type == "windows": command = f'cmd "/c {self.venv_path} && ' f'python {self.script_path} {self.config_path}"' elif self.os_type == "linux": command = f"source {self.venv_path} && " f"python {self.script_path} {self.config_path}" else: print(f"WARNING! venv '{self.venv}' is not known. " f"Detector not running.") return command
[docs] @abstractmethod def post_inference(self) -> None: # noqa: B027 """ Post-processing after inference. This method is called after the inference step and is used for any post-processing tasks that need to be performed. """ pass
@property @abstractmethod def components(self) -> list[str]: """ Abstract property that returns the components of the method detector. This property should be implemented in the derived classes to specify the components that the method detector is associated with. Returns: list: A list of strings representing the components associated with the method detector. Raises: NotImplementedError: If the property is not set in the derived classes. """ raise NotImplementedError @property @abstractmethod def algorithm(self) -> str: """ Abstract property that returns the algorithm of the method detector. This property should be implemented in the derived classes to specify the algorithm that the method detector is associated with. Returns: str: A string representing the algorithm associated with the method detector Raises: NotImplementedError: If the property is not set in the derived classes. """ raise NotImplementedError
[docs] @abstractmethod def visualization(self, data) -> None: """ Abstract method to visualize the output of the method, preferably as a video. This method is intended to generate a visual representation of the method detector's output. The visualization should be saved in the self.viz_folder. Args: data (any): The data to be visualized. The type and content of this parameter depend on the specific implementation of the method detector. Returns: None. This method does not return any value. However, it should save the visualization in the self.viz_folder. Raises: NotImplementedError: If this method is not implemented in the derived classes. """ pass