diff --git a/src/jabs/scripts/cli/cli.py b/src/jabs/scripts/cli/cli.py index d3f8fb8e..0e89c406 100644 --- a/src/jabs/scripts/cli/cli.py +++ b/src/jabs/scripts/cli/cli.py @@ -322,9 +322,24 @@ def prune(ctx: click.Context, directory: Path, behavior: str | None): ) @click.option( "--grouping-strategy", - type=click.Choice(["video", "individual"], case_sensitive=False), + type=click.Choice(["video", "individual", "filename"], case_sensitive=False), default=None, - help=("Cross validation grouping strategy. If not provided, use the project setting."), + help=( + "Cross validation grouping strategy. If not provided, use the project setting. " + "The 'filename' strategy groups videos by a regular expression applied to their " + "filenames and requires --grouping-pattern (or a pattern saved in the project)." + ), +) +@click.option( + "--grouping-pattern", + "grouping_pattern", + type=str, + default=None, + help=( + "Regular expression used to extract a grouping key from each video filename. " + "Only used with '--grouping-strategy filename'. If not provided, the pattern saved " + "in the project settings is used." + ), ) @click.option( "--classifier", @@ -348,6 +363,7 @@ def cross_validation( behavior: str, k: int, grouping_strategy: str | None, + grouping_pattern: str | None, classifier: str, report_file: Path | None, ): @@ -357,16 +373,24 @@ def cross_validation( "Report file must have a .md (Markdown) or .json (JSON) extension." ) - if grouping_strategy and grouping_strategy.lower() == "video": - cv_grouping = CrossValidationGroupingStrategy.VIDEO - elif grouping_strategy and grouping_strategy.lower() == "individual": - cv_grouping = CrossValidationGroupingStrategy.INDIVIDUAL - else: - cv_grouping = None + cv_grouping_by_name = { + "video": CrossValidationGroupingStrategy.VIDEO, + "individual": CrossValidationGroupingStrategy.INDIVIDUAL, + "filename": CrossValidationGroupingStrategy.FILENAME_PATTERN, + } + cv_grouping = cv_grouping_by_name[grouping_strategy.lower()] if grouping_strategy else None try: classifier_type = ClassifierType[classifier.upper()] - run_cross_validation(directory, behavior, classifier_type, cv_grouping, k, report_file) + run_cross_validation( + directory, + behavior, + classifier_type, + cv_grouping, + k, + report_file, + grouping_regex=grouping_pattern, + ) except Exception as e: raise click.ClickException(str(e)) from e diff --git a/src/jabs/scripts/cli/cross_validation.py b/src/jabs/scripts/cli/cross_validation.py index 52175114..a90fd897 100644 --- a/src/jabs/scripts/cli/cross_validation.py +++ b/src/jabs/scripts/cli/cross_validation.py @@ -27,6 +27,7 @@ def run_cross_validation( grouping_strategy: CrossValidationGroupingStrategy | None, k: int, report_file: Path | None = None, + grouping_regex: str | None = None, ) -> None: """Run cross-validation for a JABS project from the command line. @@ -41,6 +42,9 @@ def run_cross_validation( k (int): Number of cross-validation splits. Use 0 for max splits. report_file (Path | None): Path to save the training report file. Format will be determined by the extension (.md for markdown or .json for JSON). + grouping_regex (str | None): Regular expression used to extract a grouping key + from each video filename. Only used when ``grouping_strategy`` is + ``FILENAME_PATTERN``. If None, uses the pattern saved in project settings. """ if k < 0: raise ValueError("The number of cross-validation splits 'k' must be non-negative.") @@ -90,6 +94,7 @@ def progress_callback(): features, group_mapping = project.get_labeled_features( behavior, grouping_strategy=grouping_strategy, + grouping_regex=grouping_regex, ) with progress: @@ -178,6 +183,19 @@ def progress_callback(): unit = "cm" if project.feature_manager.distance_unit == ProjectDistanceUnit.CM else "pixel" report_timestamp = datetime.now() behavior_settings = project.settings_manager.get_behavior(behavior) + + # resolve the grouping strategy/regex actually used so the report reflects any + # command-line overrides rather than the project's saved settings. + effective_grouping_strategy = ( + grouping_strategy + if grouping_strategy is not None + else project.settings_manager.cv_grouping_strategy + ) + effective_grouping_regex = ( + grouping_regex + if grouping_regex is not None + else project.settings_manager.cv_grouping_regex + ) training_data = TrainingReportData( behavior_name=behavior, classifier_type=classifier.classifier_name, @@ -193,8 +211,12 @@ def progress_callback(): training_time_ms=elapsed_ms, timestamp=report_timestamp, window_size=behavior_settings["window_size"], - cv_grouping_strategy=project.settings_manager.cv_grouping_strategy, - cv_grouping_regex=project.settings_manager.cv_grouping_regex, + cv_grouping_strategy=effective_grouping_strategy, + cv_grouping_regex=( + effective_grouping_regex + if effective_grouping_strategy == CrossValidationGroupingStrategy.FILENAME_PATTERN + else None + ), ) # Save markdown report diff --git a/tests/scripts/test_cross_validation_cli.py b/tests/scripts/test_cross_validation_cli.py new file mode 100644 index 00000000..14ae00a2 --- /dev/null +++ b/tests/scripts/test_cross_validation_cli.py @@ -0,0 +1,94 @@ +"""Tests for the ``jabs-cli cross-validation`` command option parsing. + +These tests exercise the Click command's translation of ``--grouping-strategy`` / +``--grouping-pattern`` into the arguments passed to +:func:`jabs.scripts.cli.cross_validation.run_cross_validation`. The heavy +``run_cross_validation`` implementation is replaced with a spy so the tests stay +fast and do not require a real JABS project on disk. +""" + +from pathlib import Path +from unittest import mock + +import pytest +from click.testing import CliRunner + +import jabs.scripts.cli.cli as cli_module +from jabs.core.enums import CrossValidationGroupingStrategy +from jabs.scripts.cli.cli import cli + + +@pytest.fixture +def run_cv_spy(monkeypatch: pytest.MonkeyPatch) -> mock.Mock: + """Replace ``run_cross_validation`` (as imported into cli.py) with a spy.""" + spy = mock.Mock() + monkeypatch.setattr(cli_module, "run_cross_validation", spy) + return spy + + +def _invoke(tmp_path: Path, *extra_args: str): + """Invoke the cross-validation command against ``tmp_path`` with extra args.""" + runner = CliRunner() + return runner.invoke( + cli, + ["cross-validation", str(tmp_path), "--behavior", "Walk", *extra_args], + ) + + +@pytest.mark.parametrize( + ("strategy_arg", "expected"), + [ + ("video", CrossValidationGroupingStrategy.VIDEO), + ("individual", CrossValidationGroupingStrategy.INDIVIDUAL), + ("filename", CrossValidationGroupingStrategy.FILENAME_PATTERN), + ("FILENAME", CrossValidationGroupingStrategy.FILENAME_PATTERN), + ], + ids=["video", "individual", "filename", "filename-uppercase"], +) +def test_grouping_strategy_maps_to_enum( + tmp_path: Path, + run_cv_spy: mock.Mock, + strategy_arg: str, + expected: CrossValidationGroupingStrategy, +) -> None: + """``--grouping-strategy`` values (case-insensitive) map to the right enum.""" + result = _invoke(tmp_path, "--grouping-strategy", strategy_arg) + + assert result.exit_code == 0, result.output + run_cv_spy.assert_called_once() + # grouping_strategy is the 4th positional argument + assert run_cv_spy.call_args.args[3] == expected + + +def test_filename_pattern_passed_as_grouping_regex(tmp_path: Path, run_cv_spy: mock.Mock) -> None: + """``--grouping-pattern`` is forwarded as the ``grouping_regex`` keyword.""" + result = _invoke( + tmp_path, + "--grouping-strategy", + "filename", + "--grouping-pattern", + r"^(\w+?)_", + ) + + assert result.exit_code == 0, result.output + run_cv_spy.assert_called_once() + assert run_cv_spy.call_args.args[3] == CrossValidationGroupingStrategy.FILENAME_PATTERN + assert run_cv_spy.call_args.kwargs["grouping_regex"] == r"^(\w+?)_" + + +def test_no_strategy_defaults_to_none(tmp_path: Path, run_cv_spy: mock.Mock) -> None: + """Omitting the strategy/pattern defers to project settings (None passed through).""" + result = _invoke(tmp_path) + + assert result.exit_code == 0, result.output + run_cv_spy.assert_called_once() + assert run_cv_spy.call_args.args[3] is None + assert run_cv_spy.call_args.kwargs["grouping_regex"] is None + + +def test_invalid_grouping_strategy_rejected(tmp_path: Path, run_cv_spy: mock.Mock) -> None: + """An unknown strategy is rejected by Click before run_cross_validation is called.""" + result = _invoke(tmp_path, "--grouping-strategy", "bogus") + + assert result.exit_code != 0 + run_cv_spy.assert_not_called()