Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .python-version
Original file line number Diff line number Diff line change
@@ -1 +1 @@
3.10
3.11
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name = "openbench"
version = "0.1.0"
description = "Benchmark suite for speaker diarization"
authors = [{ name = "Argmax Inc.", email = "info@takeargmax.com" }]
requires-python = ">=3.10,<3.13"
requires-python = ">=3.11,<3.13"
readme = "README.md"
dependencies = [
"black>=24.10.0,<25",
Expand Down
43 changes: 43 additions & 0 deletions src/openbench/cli/command_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,49 @@ def get_available_pipelines() -> list[str]:
return PipelineRegistry.list_pipelines()


def parse_pipeline_config_overrides(items: list[str] | None) -> dict[str, str]:
"""Parse a list of `key=value` strings into a config-override dict.

Values are kept as raw strings. Pydantic performs the actual type
coercion when the pipeline config is built — so e.g. `seed=10`,
`temperature=0.9`, and `force_language=true` all work as expected
against the pipeline's Pydantic config, because Pydantic's lax mode
parses ints, floats, and the standard boolean spellings
("true"/"false"/"yes"/"no"/"1"/"0", case-insensitive) from strings.
Quote arguments that contain spaces.

Examples:
--pipeline-config speaker=serena
--pipeline-config seed=42 -pc temperature=0.7
--pipeline-config force_language=true

Args:
items: The raw values typer collected for `--pipeline-config`,
or `None` if the flag wasn't passed.

Returns:
Dict mapping config keys to (string) values. Empty if `items`
is `None` or empty.

Raises:
typer.BadParameter: If any item is missing the `=` separator,
or has an empty key.
"""
overrides: dict[str, str] = {}
if not items:
return overrides

for item in items:
if "=" not in item:
raise typer.BadParameter(f"Invalid --pipeline-config format: '{item}'. Expected key=value (e.g. seed=42)")
key, value = item.split("=", 1)
key = key.strip()
if not key:
raise typer.BadParameter(f"Invalid --pipeline-config format: '{item}'. Empty key before '='.")
overrides[key] = value
return overrides


def get_available_datasets() -> list[str]:
    """Return the dataset aliases reported by `DatasetRegistry.list_aliases()`."""
    aliases = DatasetRegistry.list_aliases()
    return aliases
Expand Down
24 changes: 24 additions & 0 deletions src/openbench/cli/commands/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
get_datasets_help_text,
get_metrics_help_text,
get_pipelines_help_text,
parse_pipeline_config_overrides,
validate_dataset_name,
validate_pipeline_dataset_compatibility,
validate_pipeline_metrics_compatibility,
Expand Down Expand Up @@ -176,6 +177,7 @@ def run_alias_mode(
wandb_tags: list[str] | None,
use_keywords: bool | None,
force_language: bool,
pipeline_config: list[str] | None,
verbose: bool,
) -> BenchmarkResult:
"""Run evaluation using pipeline and dataset aliases."""
Expand Down Expand Up @@ -208,6 +210,14 @@ def run_alias_mode(
if verbose:
typer.echo("✅ Force language: enabled")

# Handle generic pipeline config overrides (key=value pairs).
# Values are kept as strings; the pipeline's Pydantic config
# coerces them to int/float/bool/etc when instantiated.
for key, value in parse_pipeline_config_overrides(pipeline_config).items():
pipeline_config_override[key] = value
if verbose:
typer.echo(f"Config override: {key}={value}")

pipeline = PipelineRegistry.create_pipeline(pipeline_name, config=pipeline_config_override)

######### Build Benchmark Config #########
Expand Down Expand Up @@ -345,6 +355,19 @@ def evaluate(
"--force-language",
help="Force language hinting for compatible pipelines",
),
pipeline_config: list[str] | None = typer.Option(
None,
"--pipeline-config",
"-pc",
help=(
"Override one or more pipeline config fields as key=value pairs. "
"The value is parsed as a string; Pydantic coerces it to the field's "
"declared type when the pipeline is instantiated, so ints/floats/bools "
"all just work. Repeat the flag for multiple overrides. Examples: "
"`-pc speaker=serena`, `-pc seed=42 -pc temperature=0.7`, "
"`-pc force_language=true`."
),
),
verbose: bool = typer.Option(False, "--verbose", "-v", help="Enable verbose output"),
) -> None:
"""Run evaluation benchmarks.
Expand Down Expand Up @@ -406,6 +429,7 @@ def evaluate(
wandb_tags=wandb_tags,
use_keywords=use_keywords,
force_language=force_language,
pipeline_config=pipeline_config,
verbose=verbose,
)
display_result(result)
Expand Down
3 changes: 3 additions & 0 deletions src/openbench/dataset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .dataset_diarization import DiarizationDataset, DiarizationSample
from .dataset_orchestration import OrchestrationDataset, OrchestrationSample
from .dataset_registry import DatasetRegistry
from .dataset_speech_generation import SpeechGenerationDataset, SpeechGenerationSample
from .dataset_streaming_transcription import StreamingDataset, StreamingSample
from .dataset_transcription import TranscriptionDataset, TranscriptionSample

Expand All @@ -24,11 +25,13 @@
"TranscriptionDataset",
"StreamingDataset",
"OrchestrationDataset",
"SpeechGenerationDataset",
# Sample types
"DiarizationSample",
"TranscriptionSample",
"StreamingSample",
"OrchestrationSample",
"SpeechGenerationSample",
# Registry
"DatasetRegistry",
]
14 changes: 14 additions & 0 deletions src/openbench/dataset/dataset_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,20 @@ def register_dataset_aliases() -> None:
description="Common Voice dataset for transcription evaluation with up to 400 samples per language this subset contains only russian",
)

########## SPEECH GENERATION ##########

DatasetRegistry.register_alias(
"customer-service-tts-prompts-vocalized",
DatasetConfig(
dataset_id="argmaxinc/customer-service-tts-prompts-vocalized",
split="validation",
),
supported_pipeline_types={
PipelineType.SPEECH_GENERATION,
},
description="Customer service TTS prompts with vocalized audio for speech generation evaluation.",
)

########## STREAMING TRANSCRIPTION ##########

DatasetRegistry.register_alias(
Expand Down
2 changes: 2 additions & 0 deletions src/openbench/dataset/dataset_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .dataset_base import BaseDataset, DatasetConfig
from .dataset_diarization import DiarizationDataset
from .dataset_orchestration import OrchestrationDataset
from .dataset_speech_generation import SpeechGenerationDataset
from .dataset_streaming_transcription import StreamingDataset
from .dataset_transcription import TranscriptionDataset

Expand Down Expand Up @@ -139,3 +140,4 @@ def has_alias(cls, alias: str) -> bool:
DatasetRegistry.register(PipelineType.ORCHESTRATION, OrchestrationDataset)
DatasetRegistry.register(PipelineType.STREAMING_TRANSCRIPTION, StreamingDataset)
DatasetRegistry.register(PipelineType.TRANSCRIPTION, TranscriptionDataset)
DatasetRegistry.register(PipelineType.SPEECH_GENERATION, SpeechGenerationDataset)
76 changes: 76 additions & 0 deletions src/openbench/dataset/dataset_speech_generation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# For licensing see accompanying LICENSE.md file.
# Copyright (C) 2026 Argmax, Inc. All Rights Reserved.

import numpy as np
from typing_extensions import TypedDict

from ..pipeline_prediction import Transcript
from .dataset_base import BaseDataset, BaseSample


class SpeechGenerationExtraInfo(TypedDict, total=False):
    """Optional metadata carried alongside a speech-generation sample.

    Declared with ``total=False``, so every key may be omitted.
    """

    # Copied from the dataset row's "language" field when the row has one.
    language: str


class SpeechGenerationRow(TypedDict):
    """Row shape expected from a speech-generation dataset.

    Only the prompt is mandatory; no audio column is declared.
    """

    # The prompt string to be synthesized.
    text: str


class SpeechGenerationSample(BaseSample[Transcript, SpeechGenerationExtraInfo]):
    """Benchmark sample for speech-generation (text-to-speech) tasks.

    The sample's reference is a `Transcript`; in this module it is built
    from the prompt text, so the prompt can be recovered from it via
    `text`.

    NOTE(review): how the pipeline prediction and the WER metric consume
    this reference is defined outside this module — confirm there.
    """

    @property
    def text(self) -> str:
        """Prompt text, rendered from the reference transcript."""
        prompt = self.reference.get_transcript_string()
        return prompt


class SpeechGenerationDataset(BaseDataset[SpeechGenerationSample]):
    """Dataset that feeds text prompts to speech-generation pipelines.

    Rows must provide a ``text`` column holding the prompt. There is no
    input audio for this task, so `_extract_audio_info` returns a
    one-sample silent placeholder waveform instead of real audio.
    """

    _expected_columns = ["text"]
    _sample_class = SpeechGenerationSample

    def _extract_audio_info(self, row: dict) -> tuple[str, np.ndarray, int]:
        """Return ``(audio_name, waveform, sample_rate)`` using placeholder audio.

        The name is the row's ``audio_name`` when that key holds a truthy
        value, otherwise ``sample_<idx>``. The waveform is a single
        float32 zero and the sample rate is fixed at 16000.
        """
        # NOTE(review): the index-based fallback is built unconditionally, so
        # row["idx"] must exist even when audio_name is supplied — confirm the
        # base dataset always populates it.
        fallback_name = f"sample_{row['idx']}"
        provided_name = row.get("audio_name")
        name = str(provided_name) if provided_name else fallback_name

        placeholder = np.zeros(1, dtype=np.float32)
        return name, placeholder, 16000

    def prepare_sample(self, row: SpeechGenerationRow) -> tuple[Transcript, SpeechGenerationExtraInfo]:
        """Build the reference transcript and extra info for one row.

        The prompt is split on whitespace and the resulting words become
        the reference `Transcript`. ``language`` is forwarded into the
        extra info only when the row contains that key.
        """
        prompt_words = row["text"].split()
        reference = Transcript.from_words_info(words=prompt_words)

        extra_info: SpeechGenerationExtraInfo = {"language": row["language"]} if "language" in row else {}
        return reference, extra_info
4 changes: 4 additions & 0 deletions src/openbench/engine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
DiarizeCliOutput,
TranscriptionCliInput,
TranscriptionCliOutput,
TtsCliInput,
TtsCliOutput,
resolve_argmax_oss_cache_dir,
)
from .deepgram_engine import DeepgramApi, DeepgramApiResponse
Expand Down Expand Up @@ -33,6 +35,8 @@
"DiarizeCliOutput",
"TranscriptionCliInput",
"TranscriptionCliOutput",
"TtsCliInput",
"TtsCliOutput",
"resolve_argmax_oss_cache_dir",
"DeepgramApi",
"DeepgramApiResponse",
Expand Down
66 changes: 56 additions & 10 deletions src/openbench/engine/argmax_oss_engine.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
# For licensing see accompanying LICENSE.md file.
# Copyright (C) 2025 Argmax, Inc. All Rights Reserved.
# Copyright (C) 2026 Argmax, Inc. All Rights Reserved.

"""Argmax SDK open-source CLI (`argmax-cli`) — clone/build, transcribe, and diarize."""

from __future__ import annotations
"""Argmax SDK open-source CLI (`argmax-cli`) — clone/build, transcribe, diarize, and tts."""

import os
import subprocess
Expand All @@ -19,6 +17,12 @@
ARGMAX_OSS_PRODUCT = "argmax-cli"
DEFAULT_CACHE_SUBDIR = Path(".cache") / "openbench" / "argmax-oss"

# Process-wide cache of resolved CLI binary paths, keyed by (cache_root, commit_hash).
# Avoids re-cloning and re-invoking `swift build` when multiple pipelines (e.g. a
# TTS pipeline + the WER metric's transcription pipeline) build their own engine
# instances in the same run.
_CLI_PATH_CACHE: dict[tuple[str, str | None], str] = {}


def resolve_argmax_oss_cache_dir(explicit: str | Path | None = None) -> Path:
"""Absolute cache root for WhisperKit clone + `argmax-cli` build."""
Expand Down Expand Up @@ -79,16 +83,39 @@ class DiarizeCliOutput(BaseModel):
rttm_path: Path = Field(..., description="Written RTTM path")


class TtsCliInput(BaseModel):
    """Arguments for a single `argmax-cli tts` invocation."""

    # Both fields are required (`...` marks them as having no default).
    text: str = Field(
        ...,
        description="Text to synthesize.",
    )
    output_path: Path = Field(
        ...,
        description="Destination audio file path (extension picks format).",
    )


class TtsCliOutput(BaseModel):
    """Result of a single `argmax-cli tts` invocation."""

    # Required field (`...` marks it as having no default).
    audio_path: Path = Field(
        ...,
        description="Generated audio path.",
    )


class ArgmaxOpenSourceEngine:
"""Resolve `argmax-cli`, then run `transcribe` / `diarize` subcommands."""
"""Resolve `argmax-cli`, then run `transcribe` / `diarize` / `tts` subcommands."""

def __init__(self, config: ArgmaxOpenSourceEngineConfig) -> None:
self.config = config
if config.cli_path:
self.cli_path = str(Path(config.cli_path).expanduser().resolve())
logger.info(f"Using Argmax OSS CLI at {self.cli_path}")
else:
self.cli_path = self._clone_and_build_cli()
logger.info("Using Argmax OSS CLI at %s", self.cli_path)
return

cache_root = resolve_argmax_oss_cache_dir(config.cache_dir)
cache_key = (str(cache_root), config.commit_hash)
cached_cli_path = _CLI_PATH_CACHE.get(cache_key)
if cached_cli_path is not None:
logger.info("Reusing cached Argmax OSS CLI at %s", cached_cli_path)
self.cli_path = cached_cli_path
return

self.cli_path = self._clone_and_build_cli(cache_root)
_CLI_PATH_CACHE[cache_key] = self.cli_path

def _build_cli(self, repo_dir: str) -> str:
"""Run release build (swift build -c release, not debug) and return the dir containing the binary."""
Expand Down Expand Up @@ -119,8 +146,7 @@ def _build_cli(self, repo_dir: str) -> str:
logger.info("Built Argmax OSS CLI at %s", cli)
return bin_dir

def _clone_and_build_cli(self) -> str:
cache_root = resolve_argmax_oss_cache_dir(self.config.cache_dir)
def _clone_and_build_cli(self, cache_root: Path) -> str:
cache_root.mkdir(parents=True, exist_ok=True)
repo_url_parts = ARGMAX_OSS_REPO_URL.rstrip("/").split("/")
repo_name = repo_url_parts[-1]
Expand Down Expand Up @@ -201,3 +227,23 @@ def diarize(self, input: DiarizeCliInput, diarize_args: list[str]) -> DiarizeCli
input.audio_path.unlink(missing_ok=True)

return DiarizeCliOutput(rttm_path=input.rttm_path)

def tts(self, input: TtsCliInput, tts_args: list[str]) -> TtsCliOutput:
    """Invoke the `tts` subcommand of the resolved CLI binary.

    Args:
        input: Prompt text and destination path for the audio file.
        tts_args: Extra, already-formatted CLI flags appended after the
            `--text` / `--output-path` arguments.

    Returns:
        A `TtsCliOutput` pointing at `input.output_path`.

    Raises:
        RuntimeError: If the CLI exits with a non-zero status; the
            message carries the captured stderr.
    """
    destination = input.output_path
    # Create the destination directory up front.
    destination.parent.mkdir(parents=True, exist_ok=True)

    cmd = [self.cli_path, "tts", "--text", input.text, "--output-path", str(destination)]
    cmd.extend(tts_args)
    logger.debug("Argmax OSS tts: %s", cmd)

    try:
        subprocess.run(cmd, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as err:
        raise RuntimeError(f"argmax-cli tts failed: {err.stderr}") from err

    return TtsCliOutput(audio_path=destination)
1 change: 1 addition & 0 deletions src/openbench/metric/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
)
from .word_error_metrics import (
ConcatenatedMinimumPermutationWER,
SpeechGenerationWordErrorRate,
WordDiarizationErrorRate,
WordErrorRate,
)
Loading
Loading