diff --git a/src/openbench/engine/whisperkitpro_engine.py b/src/openbench/engine/whisperkitpro_engine.py index 0103320..fd17e00 100644 --- a/src/openbench/engine/whisperkitpro_engine.py +++ b/src/openbench/engine/whisperkitpro_engine.py @@ -13,6 +13,11 @@ logger = get_logger(__name__) + +def _config_str_provided(value: str | None) -> bool: + return value is not None and value.strip() != "" + + COMPUTE_UNITS_MAPPER = { ct.ComputeUnit.CPU_ONLY: "cpuOnly", ct.ComputeUnit.CPU_AND_NE: "cpuAndNeuralEngine", @@ -26,9 +31,10 @@ class WhisperKitProConfig(BaseModel): """Configuration for transcription operations. - Supports two modes: - 1. Legacy: model_version, model_prefix, model_repo_name - 2. New: repo_id, model_variant (downloads models locally) + Supports three modes: + 1. Local: model_dir only (existing directory on disk; no Hugging Face download) + 2. Hugging Face: repo_id + model_variant (downloads unless model_dir already exists) + 3. Legacy: model_version, model_prefix, model_repo_name """ # Legacy fields @@ -56,7 +62,10 @@ class WhisperKitProConfig(BaseModel): ) model_dir: str | None = Field( None, - description="Local path to model directory, if this is provided it will ignore the repo_id and model_variant and use the provided path directly", + description=( + "Local directory passed as --model-path. If set, must exist; repo_id/model_variant are not used for " + "download (no Hugging Face fetch when only model_dir is configured)." + ), ) word_timestamps: bool = Field( True, @@ -128,7 +137,7 @@ def generate_cli_args(self, model_path: Path | None = None) -> list[str]: # Use either --model-path (new) or legacy model args if self.use_model_path: if model_path is None: - raise ValueError("model_path required when using repo_id/model_variant") + raise ValueError("model_path required when using --model-path mode") args = [ "--model-path", str(model_path), @@ -185,26 +194,34 @@ def generate_cli_args(self, model_path: Path | None = None) -> list[str]: @property def use_model_path(self) -> bool: - """Check if we should use --model-path vs legacy args.""" - return self.repo_id is not None and self.model_variant is not None + """Use --model-path when model_dir is set, or when HF repo_id + model_variant are set.""" + if _config_str_provided(self.model_dir): + return True + return _config_str_provided(self.repo_id) and _config_str_provided(self.model_variant) def download_and_prepare_model(self) -> Path: - """Download model from HuggingFace and prepare folder. + """Resolve local model directory or download from Hugging Face. Returns: Path to model directory for --model-path """ if not self.use_model_path: - raise ValueError("download_and_prepare_model requires repo_id and model_variant") + raise ValueError("download_and_prepare_model requires model_dir or repo_id/model_variant") + + if _config_str_provided(self.model_dir): + p = Path(self.model_dir).expanduser().resolve() + if not p.is_dir(): + raise FileNotFoundError( + f"model_dir must be an existing directory (no Hugging Face download when model_dir is set): {self.model_dir}" + ) + logger.info(f"Using local model at: {p}") + return p - # Check if model already exists - if self.model_dir is not None and os.path.exists(self.model_dir): - logger.info(f"Model already exists at: {self.model_dir}") - return Path(self.model_dir) + if not (_config_str_provided(self.repo_id) and _config_str_provided(self.model_variant)): + raise ValueError("repo_id and model_variant are required when model_dir is not set") logger.info(f"Downloading model from {self.repo_id}, variant: {self.model_variant}") - # Download specific model variant folder from HuggingFace try: downloaded_path = snapshot_download(repo_id=self.repo_id, allow_patterns=f"{self.model_variant}/*") return Path(f"{downloaded_path}/{self.model_variant}") @@ -252,10 +269,20 @@ def __init__( # Download and prepare model if using new model management self.model_path = None if self.transcription_config.use_model_path: - logger.debug("Using model path management with repo_id/model_variant") + logger.debug("Using --model-path (local model_dir and/or Hugging Face ids)") self.model_path = self.transcription_config.download_and_prepare_model() else: logger.debug("Using legacy model management") + if not ( + _config_str_provided(self.transcription_config.model_version) + and _config_str_provided(self.transcription_config.model_prefix) + and _config_str_provided(self.transcription_config.model_repo_name) + ): + raise ValueError( + "WhisperKitPro requires one of: model_dir (existing directory), " + "(repo_id and model_variant for Hugging Face), or " + "(model_version, model_prefix, model_repo_name) for legacy CLI args." + ) # Generate CLI args (with model_path if available) self.transcription_args = self.transcription_config.generate_cli_args(model_path=self.model_path) diff --git a/src/openbench/pipeline/pipeline_aliases.py b/src/openbench/pipeline/pipeline_aliases.py index 97287b0..b96ba80 100644 --- a/src/openbench/pipeline/pipeline_aliases.py +++ b/src/openbench/pipeline/pipeline_aliases.py @@ -487,6 +487,20 @@ def register_pipeline_aliases() -> None: description="WhisperKitPro transcription pipeline using the parakeet-v3 version of the model compressed to 494MB. Requires `WHISPERKITPRO_CLI_PATH` env var and depending on your permissions also `WHISPERKITPRO_API_KEY` env var.", ) + PipelineRegistry.register_alias( + "whisperkitpro-local-model", + WhisperKitProTranscriptionPipeline, + default_config={ + "model_dir": os.getenv("WHISPERKITPRO_LOCAL_MODEL_PATH"), + "cli_path": os.getenv("WHISPERKITPRO_CLI_PATH"), + }, + description=( + "WhisperKitPro transcription using only a local model directory (no default Hugging Face repo). " + "Set `WHISPERKITPRO_LOCAL_MODEL_PATH` to the folder passed as `--model-path` on the CLI; it must exist. " + "Requires `WHISPERKITPRO_CLI_PATH` and may require `WHISPERKITPRO_API_KEY`." + ), + ) + PipelineRegistry.register_alias( "groq-whisper-large-v3-turbo", GroqTranscriptionPipeline, diff --git a/src/openbench/pipeline/transcription/transcription_whisperkitpro.py b/src/openbench/pipeline/transcription/transcription_whisperkitpro.py index a88a517..bcb60fa 100644 --- a/src/openbench/pipeline/transcription/transcription_whisperkitpro.py +++ b/src/openbench/pipeline/transcription/transcription_whisperkitpro.py @@ -24,9 +24,7 @@ class WhisperKitProTranscriptionConfig(TranscriptionConfig): """Configuration for WhisperKitPro transcription pipeline. - Supports two modes: - 1. Legacy: model_version, model_prefix, model_repo_name - 2. New: repo_id, model_variant (downloads and manages models) + Supports local model_dir only, Hugging Face repo_id + model_variant, or legacy model fields. """ cli_path: str = Field( @@ -59,7 +57,7 @@ class WhisperKitProTranscriptionConfig(TranscriptionConfig): ) model_dir: str | None = Field( None, - description="Directory to cache downloaded models", + description="Existing local model directory for --model-path (no Hugging Face download when this is the only model source).", ) audio_encoder_compute_units: ComputeUnit = Field(