Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 33 additions & 9 deletions src/jabs/scripts/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,9 +322,24 @@ def prune(ctx: click.Context, directory: Path, behavior: str | None):
)
@click.option(
"--grouping-strategy",
type=click.Choice(["video", "individual"], case_sensitive=False),
type=click.Choice(["video", "individual", "filename"], case_sensitive=False),
default=None,
help=("Cross validation grouping strategy. If not provided, use the project setting."),
help=(
"Cross validation grouping strategy. If not provided, use the project setting. "
"The 'filename' strategy groups videos by a regular expression applied to their "
"filenames and requires --grouping-pattern (or a pattern saved in the project)."
),
)
@click.option(
"--grouping-pattern",
"grouping_pattern",
type=str,
default=None,
help=(
"Regular expression used to extract a grouping key from each video filename. "
"Only used with '--grouping-strategy filename'. If not provided, the pattern saved "
"in the project settings is used."
),
)
@click.option(
"--classifier",
Expand All @@ -348,6 +363,7 @@ def cross_validation(
behavior: str,
k: int,
grouping_strategy: str | None,
grouping_pattern: str | None,
classifier: str,
report_file: Path | None,
):
Expand All @@ -357,16 +373,24 @@ def cross_validation(
"Report file must have a .md (Markdown) or .json (JSON) extension."
)

if grouping_strategy and grouping_strategy.lower() == "video":
cv_grouping = CrossValidationGroupingStrategy.VIDEO
elif grouping_strategy and grouping_strategy.lower() == "individual":
cv_grouping = CrossValidationGroupingStrategy.INDIVIDUAL
else:
cv_grouping = None
cv_grouping_by_name = {
"video": CrossValidationGroupingStrategy.VIDEO,
"individual": CrossValidationGroupingStrategy.INDIVIDUAL,
"filename": CrossValidationGroupingStrategy.FILENAME_PATTERN,
}
cv_grouping = cv_grouping_by_name[grouping_strategy.lower()] if grouping_strategy else None

try:
classifier_type = ClassifierType[classifier.upper()]
run_cross_validation(directory, behavior, classifier_type, cv_grouping, k, report_file)
run_cross_validation(
directory,
behavior,
classifier_type,
cv_grouping,
k,
report_file,
grouping_regex=grouping_pattern,
)
except Exception as e:
raise click.ClickException(str(e)) from e

Expand Down
26 changes: 24 additions & 2 deletions src/jabs/scripts/cli/cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def run_cross_validation(
grouping_strategy: CrossValidationGroupingStrategy | None,
k: int,
report_file: Path | None = None,
grouping_regex: str | None = None,
) -> None:
"""Run cross-validation for a JABS project from the command line.

Expand All @@ -41,6 +42,9 @@ def run_cross_validation(
k (int): Number of cross-validation splits. Use 0 for max splits.
report_file (Path | None): Path to save the training report file.
Format will be determined by the extension (.md for markdown or .json for JSON).
grouping_regex (str | None): Regular expression used to extract a grouping key
from each video filename. Only used when ``grouping_strategy`` is
``FILENAME_PATTERN``. If None, uses the pattern saved in project settings.
"""
if k < 0:
raise ValueError("The number of cross-validation splits 'k' must be non-negative.")
Expand Down Expand Up @@ -90,6 +94,7 @@ def progress_callback():
features, group_mapping = project.get_labeled_features(
behavior,
grouping_strategy=grouping_strategy,
grouping_regex=grouping_regex,
)

with progress:
Expand Down Expand Up @@ -178,6 +183,19 @@ def progress_callback():
unit = "cm" if project.feature_manager.distance_unit == ProjectDistanceUnit.CM else "pixel"
report_timestamp = datetime.now()
behavior_settings = project.settings_manager.get_behavior(behavior)

# resolve the grouping strategy/regex actually used so the report reflects any
# command-line overrides rather than the project's saved settings.
effective_grouping_strategy = (
grouping_strategy
if grouping_strategy is not None
else project.settings_manager.cv_grouping_strategy
)
effective_grouping_regex = (
grouping_regex
if grouping_regex is not None
else project.settings_manager.cv_grouping_regex
)
training_data = TrainingReportData(
behavior_name=behavior,
classifier_type=classifier.classifier_name,
Expand All @@ -193,8 +211,12 @@ def progress_callback():
training_time_ms=elapsed_ms,
timestamp=report_timestamp,
window_size=behavior_settings["window_size"],
cv_grouping_strategy=project.settings_manager.cv_grouping_strategy,
cv_grouping_regex=project.settings_manager.cv_grouping_regex,
cv_grouping_strategy=effective_grouping_strategy,
cv_grouping_regex=(
effective_grouping_regex
if effective_grouping_strategy == CrossValidationGroupingStrategy.FILENAME_PATTERN
else None
),
)

# Save markdown report
Expand Down
94 changes: 94 additions & 0 deletions tests/scripts/test_cross_validation_cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
"""Tests for the ``jabs-cli cross-validation`` command option parsing.

These tests exercise the Click command's translation of ``--grouping-strategy`` /
``--grouping-pattern`` into the arguments passed to
:func:`jabs.scripts.cli.cross_validation.run_cross_validation`. The heavy
``run_cross_validation`` implementation is replaced with a spy so the tests stay
fast and do not require a real JABS project on disk.
"""

from pathlib import Path
from unittest import mock

import pytest
from click.testing import CliRunner

import jabs.scripts.cli.cli as cli_module
from jabs.core.enums import CrossValidationGroupingStrategy
from jabs.scripts.cli.cli import cli


@pytest.fixture
def run_cv_spy(monkeypatch: pytest.MonkeyPatch) -> mock.Mock:
    """Swap ``run_cross_validation`` on the cli module for a ``Mock`` and return it.

    The attribute is patched on ``cli_module`` (where the command looks the name up),
    so invoking the command records the call on the mock instead of running the
    real implementation. ``monkeypatch`` restores the original after the test.
    """
    recorder = mock.Mock()
    monkeypatch.setattr(cli_module, "run_cross_validation", recorder)
    return recorder


def _invoke(tmp_path: Path, *extra_args: str):
    """Run ``cross-validation`` on ``tmp_path`` for behavior ``Walk`` and return the result.

    Args:
        tmp_path: Directory passed as the command's positional argument.
        *extra_args: Additional command-line tokens appended after ``--behavior Walk``.

    Returns:
        The result object produced by ``CliRunner.invoke``.
    """
    argv = ["cross-validation", str(tmp_path), "--behavior", "Walk"]
    argv.extend(extra_args)
    return CliRunner().invoke(cli, argv)


@pytest.mark.parametrize(
    ("strategy_arg", "expected"),
    [
        pytest.param("video", CrossValidationGroupingStrategy.VIDEO, id="video"),
        pytest.param("individual", CrossValidationGroupingStrategy.INDIVIDUAL, id="individual"),
        pytest.param(
            "filename",
            CrossValidationGroupingStrategy.FILENAME_PATTERN,
            id="filename",
        ),
        pytest.param(
            "FILENAME",
            CrossValidationGroupingStrategy.FILENAME_PATTERN,
            id="filename-uppercase",
        ),
    ],
)
def test_grouping_strategy_maps_to_enum(
    tmp_path: Path,
    run_cv_spy: mock.Mock,
    strategy_arg: str,
    expected: CrossValidationGroupingStrategy,
) -> None:
    """Each ``--grouping-strategy`` value, in any letter case, yields the matching enum."""
    outcome = _invoke(tmp_path, "--grouping-strategy", strategy_arg)

    assert outcome.exit_code == 0, outcome.output
    run_cv_spy.assert_called_once()
    # The command passes the grouping strategy as the 4th positional argument.
    positional_args = run_cv_spy.call_args.args
    assert positional_args[3] == expected


def test_filename_pattern_passed_as_grouping_regex(tmp_path: Path, run_cv_spy: mock.Mock) -> None:
    """The ``--grouping-pattern`` value reaches the spy as the ``grouping_regex`` keyword."""
    pattern = r"^(\w+?)_"
    outcome = _invoke(tmp_path, "--grouping-strategy", "filename", "--grouping-pattern", pattern)

    assert outcome.exit_code == 0, outcome.output
    run_cv_spy.assert_called_once()
    recorded = run_cv_spy.call_args
    # Strategy travels positionally (index 3); the pattern travels by keyword.
    assert recorded.args[3] == CrossValidationGroupingStrategy.FILENAME_PATTERN
    assert recorded.kwargs["grouping_regex"] == pattern


def test_no_strategy_defaults_to_none(tmp_path: Path, run_cv_spy: mock.Mock) -> None:
    """With neither option given, ``None`` is forwarded for both strategy and pattern."""
    outcome = _invoke(tmp_path)

    assert outcome.exit_code == 0, outcome.output
    run_cv_spy.assert_called_once()
    recorded = run_cv_spy.call_args
    # None in both slots means no command-line override was supplied.
    assert recorded.args[3] is None
    assert recorded.kwargs["grouping_regex"] is None


def test_invalid_grouping_strategy_rejected(tmp_path: Path, run_cv_spy: mock.Mock) -> None:
    """An unknown strategy is rejected by Click before run_cross_validation is called.

    Click reports an invalid option value as a usage error, which exits with status 2.
    Asserting that exact status, rather than any non-zero value, stops the test from
    passing when the command fails for an unrelated reason (for example an uncaught
    exception, which exits with status 1) before the spy is reached.
    """
    result = _invoke(tmp_path, "--grouping-strategy", "bogus")

    # 2 is the exit status Click uses for usage errors such as a bad choice value.
    assert result.exit_code == 2, result.output
    run_cv_spy.assert_not_called()
Loading