Source code for nicetoolbox.detectors.method_detectors.whisperx.whisperx_detector

"""
WhisperX method detector class (mock/debug implementation).
"""

import json
import logging
import os

from nicetoolbox_core.audio_loaders import AudioStreamLoader

from ....configs.schemas.detectors_instances_configs import MethodDetectorRuntime, WhisperXConfig
from ....utils.hf_token import effective_hf_hub_token
from ....utils.video import frames_to_video
from ..base_method import BaseMethod


class WhisperX(BaseMethod):
    """
    WhisperX method detector.

    Declares the ``audio_transcription``, ``audio_diarization`` and
    ``speaker_aligned_transcription`` components. After inference, the per-track
    JSON outputs are merged into one result file (``post_inference``) and, when
    visualization is enabled, subtitle videos are rendered from the per-track
    SRT files (``visualization``).
    """

    algorithm_type = "whisperx"
    components = ["audio_transcription", "audio_diarization", "speaker_aligned_transcription"]

    def _initialize_detector(self) -> MethodDetectorRuntime:
        """
        Initializes the WhisperX detector.

        Validates that audio data and a Hugging Face token are available, creates
        the audio loader and extends the base runtime config with the token.

        Returns:
            The runtime config produced by the base class, rebuilt as a
            ``WhisperXConfig.RuntimeConfig`` with ``hf_token`` set.

        Raises:
            RuntimeError: If no audio data was prepared.
            ValueError: If no Hugging Face token could be resolved.
        """
        if not self.data.has_audio():
            raise RuntimeError("WhisperX requires audio data but no audio was prepared.")

        resolved = effective_hf_hub_token(self.sequence_context.machine)
        if not resolved:
            raise ValueError(
                "WhisperX requires a Hugging Face token for pyannote diarization. "
                "Set hugging_face_token in machine_specific_paths.toml."
            )
        self._resolved_hf_hub_token = resolved

        # Initialize global dataloader
        self.audio_loader = AudioStreamLoader(
            config=self.data.get_input_recipes(),
            expected_tracks=self.detector_config.track_names,
        )

        rt = super()._initialize_detector()
        return WhisperXConfig.RuntimeConfig(**rt.model_dump(), hf_token=self._resolved_hf_hub_token)

    @staticmethod
    def _build_track_entry(track_data: dict) -> dict:
        """
        Convert one raw per-track JSON payload into the final per-track entry.

        The segments of ``track_data`` are modified in place: their ``words`` and
        ``speaker`` keys are removed.

        Raises:
            KeyError: If ``segments``, ``word_segments`` or ``language`` is missing.
        """
        segments = track_data["segments"]

        total_text = ""
        total_start = None
        total_end = None
        if segments:
            # Filter after stripping so whitespace-only segments do not leave
            # stray double spaces in the concatenated text.
            stripped_texts = (seg["text"].strip() for seg in segments if seg["text"])
            total_text = " ".join(text for text in stripped_texts if text)
            total_start = segments[0]["start"]
            total_end = segments[-1]["end"]

        # Remove redundant words list from each segment and speaker labels
        for seg in segments:
            seg.pop("words", None)
            seg.pop("speaker", None)

        return {
            "total": {"text": total_text, "start": total_start, "end": total_end},
            "segments": segments,
            "word_segments": track_data["word_segments"],
            "language": track_data["language"],
        }

    def post_inference(self) -> None:
        """
        Process individual speaker aligned transcription json outputs into our
        final json format.

        Tracks without a JSON output are skipped with a warning. The merged result
        is written to ``<algorithm_instance>.json`` in the
        ``speaker_aligned_transcription`` result folder.

        Structure:
            {
                "track_name": {
                    "total": {
                        "text": "full concatenated transcription text for the track",
                        "start": start_time_of_first_segment,
                        "end": end_time_of_last_segment,
                    },
                    "segments": [
                        {
                            "start": segment_start_time,
                            "end": segment_end_time,
                            "text": "segment_transcription_text",
                            "avg_logprob": segment_avg_log_probability,
                        },
                        ...
                    ],
                    "word_segments": [
                        {
                            "word": word_text,
                            "start": word_start_time,
                            "end": word_end_time,
                            "score": word log probability score,
                            "speaker": speaker_label provided by pyannote
                        },
                        ...
                    ],
                    "language": detected_language
                },
                ...
            }
        """
        folder = self.result_folders["speaker_aligned_transcription"]
        out_dict = {}

        for track_name in self.audio_loader.tracks:
            json_path = os.path.join(self.out_folders["speaker_aligned_transcription"], f"{track_name}.json")
            if not os.path.exists(json_path):
                logging.warning(f"No JSON output found for {track_name} in post_inference, skipping.")
                continue

            # Transcripts may contain non-ASCII text; decode explicitly as UTF-8
            # instead of relying on the platform's locale encoding.
            with open(json_path, encoding="utf-8") as f:
                track_data = json.load(f)

            out_dict[track_name] = self._build_track_entry(track_data)

        with open(os.path.join(folder, f"{self.algorithm_instance}.json"), "w", encoding="utf-8") as f:
            json.dump(out_dict, f, indent=4)

        logging.info("WhisperX post-inference processing complete. Speaker aligned transcription results collected.")

    def _select_frame_source(self, info, video_recipe):
        """
        Choose the frame folder and frame range used as the video background.

        Returns:
            A ``(frame_folder, start_frame, frame_limit)`` tuple. ``frame_folder``
            is ``None`` when there is no video recipe or no usable camera; in that
            case ``start_frame`` is ``self.data.video_start_frame_index`` and
            ``frame_limit`` is ``None``.
        """
        frame_folder = None
        start_frame = self.data.video_start_frame_index
        frame_limit = None

        if video_recipe:
            camera_to_use = info.get("camera")
            # Fall back to the front camera when the stream names no usable camera.
            if not camera_to_use or camera_to_use not in video_recipe.camera_names:
                camera_to_use = self.data.camera_mapping["cam_front"]

            if camera_to_use and camera_to_use in video_recipe.camera_names:
                frame_folder = os.path.join(video_recipe.root_path, camera_to_use, "frames")
                start_frame = video_recipe.range_start
                # NOTE(review): range_end is passed on as frame_limit unchanged -
                # confirm frames_to_video expects an end index and not a count.
                frame_limit = video_recipe.range_end

        return frame_folder, start_frame, frame_limit

    def visualization(self, _) -> None:
        """
        Generates visualizations overlaying SRTs subtitles onto video files.

        Uses the per-track SRT files found in the ``speaker_aligned_transcription``
        output folder. A new video is created from the video frames (if available)
        or a black background (if no video frames are available), with the SRT
        subtitles overlaid and the track's source audio attached.

        Does nothing when ``self.visualize`` is falsy. Tracks whose SRT file is
        missing or empty are skipped with a warning.

        Args:
            _: Unused.
        """
        if not self.visualize:
            return

        for track_name in self.audio_loader.tracks:
            # Find generated SRT in detector_output directory (extra outputs)
            srt_path = os.path.join(self.out_folders["speaker_aligned_transcription"], f"{track_name}.srt")
            if not os.path.exists(srt_path):
                logging.warning(f"No SRT found for {track_name}, skipping visualization.")
                continue
            if os.path.getsize(srt_path) == 0:
                # NOTE(review): an empty SRT more plausibly means that no speech
                # was transcribed; the message text is kept unchanged here.
                logging.warning(
                    f"SRT file {srt_path} is empty for {track_name}. This probably means no person was"
                    " detected for this track. Skipping visualization."
                )
                continue

            info = self.audio_loader.get_stream_info(track_name)
            source = info["source_path"]

            viz_dir = self.viz_folders["speaker_aligned_transcription"]
            os.makedirs(viz_dir, exist_ok=True)
            video_out = os.path.join(viz_dir, f"{track_name}.mp4")

            video_recipe = self.data.get_input_recipes().video_input_recipe
            frame_folder, start_frame, frame_limit = self._select_frame_source(info, video_recipe)
            fps = self.data.fps

            logging.info(f"Baking subtitles into {video_out}")
            logging.info(f"Using frame folder: {frame_folder}" if frame_folder else "Using black background fallback.")

            frames_to_video(
                input_folder=frame_folder,
                out_filename=video_out,
                fps=fps,
                start_frame=start_frame,
                audio_path=source,
                srt_path=srt_path,
                frame_limit=frame_limit,
            )