diff --git a/.github/workflows/code-checks.yaml b/.github/workflows/code-checks.yaml index 016db786f..84e8a695a 100644 --- a/.github/workflows/code-checks.yaml +++ b/.github/workflows/code-checks.yaml @@ -17,4 +17,11 @@ jobs: - uses: actions/setup-python@v5 with: python-version: "3.10" + - name: Setup uv + uses: astral-sh/setup-uv@v2 + - name: Check validate.pyi is up to date + run: | + uv run make pyi + git diff --exit-code pointblank/validate.pyi || \ + (echo "validate.pyi is out of date — run 'make pyi' and commit the result" && exit 1) - uses: pre-commit/action@v3.0.1 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a96aff5b0..a69606b93 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,15 +1,22 @@ exclude: "(.*\\.svg)|(.*\\.qmd)|(.*\\.ambr)|(.*\\.csv)|(.*\\.txt)|(.*\\.json)|(.*\\.ipynb)|(.*\\.html)" repos: + - repo: local + hooks: + - id: check-pyi-sync + name: validate.pyi must be up to date + language: system + entry: bash -c 'make pyi && git diff --exit-code pointblank/validate.pyi' + pass_filenames: false + stages: [commit] - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.6.0 hooks: - id: trailing-whitespace - id: end-of-file-fixer + # NOTE: ruff version must match the pin in pyproject.toml [dependency-groups] dev - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.9.9 + rev: v0.14.10 hooks: - # Run the linter. - id: ruff args: [--fix] - # Run the formatter. - id: ruff-format diff --git a/Makefile b/Makefile index fc9a3a79e..7e9b75d69 100644 --- a/Makefile +++ b/Makefile @@ -6,6 +6,8 @@ pyi: ## Generate .pyi stub files --include-private \ -o . 
def split_agg_name(name: str) -> tuple[str, str]:
    """Split an aggregation method name into aggregator and comparator names.

    An optional leading ``"col_"`` is removed first; the remainder is split at
    its *last* underscore, so an aggregator name may itself contain
    underscores.

    Args:
        name (str): The aggregation method name (e.g., "col_sum_eq" or "sum_eq").

    Returns:
        tuple[str, str]: A tuple of (agg_name, comp_name) e.g., ("sum", "eq").

    Raises:
        ValueError: If no underscore separates an aggregator from a comparator
            once the ``"col_"`` prefix has been removed.
    """
    stem = name.removeprefix("col_")
    agg_name, separator, comp_name = stem.rpartition("_")
    if not separator:
        # A bare tuple-unpack would also raise ValueError here, but with a
        # message that never mentions the offending method name.
        raise ValueError(
            f"Cannot split {name!r} into aggregator and comparator names; "
            "expected the form '<agg>_<comp>' (optionally prefixed with 'col_')."
        )
    return agg_name, comp_name
""" - name = name.removeprefix("col_") - agg_name, comp_name = name.split("_")[-2:] + agg_name, comp_name = split_agg_name(name) aggregator = AGGREGATOR_REGISTRY.get(agg_name) comparator = COMPARATOR_REGISTRY.get(comp_name) diff --git a/pointblank/validate.py b/pointblank/validate.py index 0085b29fc..2a80d38af 100644 --- a/pointblank/validate.py +++ b/pointblank/validate.py @@ -28,7 +28,12 @@ from great_tables.vals import fmt_integer, fmt_number from importlib_resources import files -from pointblank._agg import is_valid_agg, load_validation_method_grid, resolve_agg_registries +from pointblank._agg import ( + is_valid_agg, + load_validation_method_grid, + resolve_agg_registries, + split_agg_name, +) from pointblank._constants import ( ASSERTION_TYPE_METHOD_MAP, CHECK_MARK_SPAN, @@ -18878,6 +18883,15 @@ def _create_autobrief_or_failure_text( for_failure=for_failure, ) + if is_valid_agg(assertion_type): + return _create_text_agg( + lang=lang, + assertion_type=assertion_type, + column=column, + values=values, + for_failure=for_failure, + ) + return None @@ -18912,6 +18926,52 @@ def _create_text_comparison( ) +def _create_text_agg( + lang: str, + assertion_type: str, + column: str | list[str], + values: dict[str, Any], + for_failure: bool = False, +) -> str: + """Create autobrief text for aggregation methods like col_sum_eq, col_avg_gt, etc.""" + type_ = _expect_failure_type(for_failure=for_failure) + + agg_type, comp_type = split_agg_name(assertion_type) + + # this is covered by the test `test_brief_auto_all_agg_methods` to make sure we don't + # create any weird secret agg constants. + agg_display_names: dict[str, str] = { + "sum": "sum", + "avg": "average", + "sd": "standard deviation", + } + try: + agg_display: str = agg_display_names[agg_type] + except KeyError as ke: # pragma: no cover + raise AssertionError from ke # This should never happen in prod, it's caught in CI. 
+ + # Get the operator + comparison_assertion = f"col_vals_{comp_type}" + if lang == "ar": # pragma: no cover + operator = COMPARISON_OPERATORS_AR.get(comparison_assertion, comp_type) + else: + operator = COMPARISON_OPERATORS.get(comparison_assertion, comp_type) + + column_text = _prep_column_text(column=column) + + value = values.get("value", values) if isinstance(values, dict) else values + values_text = _prep_values_text(values=str(value), lang=lang, limit=3) + + # "Expect that the {agg} of {column} should be {operator} {value}." + agg_expectation_text = EXPECT_FAIL_TEXT[f"compare_{type_}_text"][lang] + + return agg_expectation_text.format( + column_text=f"the {agg_display} of {column_text}", + operator=operator, + values_text=values_text, + ) + + def _create_text_between( lang: str, column: str, diff --git a/pointblank/validate.pyi b/pointblank/validate.pyi index 3b8cdd540..2e825a603 100644 --- a/pointblank/validate.pyi +++ b/pointblank/validate.pyi @@ -1,8 +1,3 @@ -from pointblank import Actions, Thresholds -from pointblank._utils import _PBUnresolvedColumn -from pointblank.column import Column, ReferenceColumn -from pointblank._typing import Tolerance - import datetime from collections.abc import Collection from dataclasses import dataclass @@ -586,7 +581,7 @@ class Validate: thresholds: float | bool | tuple | dict | Thresholds | None = None, brief: str | bool = False, actions: Actions | None = None, - active: bool = True, + active: bool | Callable = True, ) -> Validate: """Assert the values in a column sum to a value eq some `value`. @@ -622,7 +617,7 @@ class Validate: thresholds: float | bool | tuple | dict | Thresholds | None = None, brief: str | bool = False, actions: Actions | None = None, - active: bool = True, + active: bool | Callable = True, ) -> Validate: """Assert the values in a column sum to a value gt some `value`. 
@@ -658,7 +653,7 @@ class Validate: thresholds: float | bool | tuple | dict | Thresholds | None = None, brief: str | bool = False, actions: Actions | None = None, - active: bool = True, + active: bool | Callable = True, ) -> Validate: """Assert the values in a column sum to a value ge some `value`. @@ -694,7 +689,7 @@ class Validate: thresholds: float | bool | tuple | dict | Thresholds | None = None, brief: str | bool = False, actions: Actions | None = None, - active: bool = True, + active: bool | Callable = True, ) -> Validate: """Assert the values in a column sum to a value lt some `value`. @@ -730,7 +725,7 @@ class Validate: thresholds: float | bool | tuple | dict | Thresholds | None = None, brief: str | bool = False, actions: Actions | None = None, - active: bool = True, + active: bool | Callable = True, ) -> Validate: """Assert the values in a column sum to a value le some `value`. @@ -766,7 +761,7 @@ class Validate: thresholds: float | bool | tuple | dict | Thresholds | None = None, brief: str | bool = False, actions: Actions | None = None, - active: bool = True, + active: bool | Callable = True, ) -> Validate: """Assert the values in a column avg to a value eq some `value`. @@ -802,7 +797,7 @@ class Validate: thresholds: float | bool | tuple | dict | Thresholds | None = None, brief: str | bool = False, actions: Actions | None = None, - active: bool = True, + active: bool | Callable = True, ) -> Validate: """Assert the values in a column avg to a value gt some `value`. @@ -838,7 +833,7 @@ class Validate: thresholds: float | bool | tuple | dict | Thresholds | None = None, brief: str | bool = False, actions: Actions | None = None, - active: bool = True, + active: bool | Callable = True, ) -> Validate: """Assert the values in a column avg to a value ge some `value`. 
@@ -874,7 +869,7 @@ class Validate: thresholds: float | bool | tuple | dict | Thresholds | None = None, brief: str | bool = False, actions: Actions | None = None, - active: bool = True, + active: bool | Callable = True, ) -> Validate: """Assert the values in a column avg to a value lt some `value`. @@ -910,7 +905,7 @@ class Validate: thresholds: float | bool | tuple | dict | Thresholds | None = None, brief: str | bool = False, actions: Actions | None = None, - active: bool = True, + active: bool | Callable = True, ) -> Validate: """Assert the values in a column avg to a value le some `value`. @@ -946,7 +941,7 @@ class Validate: thresholds: float | bool | tuple | dict | Thresholds | None = None, brief: str | bool = False, actions: Actions | None = None, - active: bool = True, + active: bool | Callable = True, ) -> Validate: """Assert the values in a column sd to a value eq some `value`. @@ -982,7 +977,7 @@ class Validate: thresholds: float | bool | tuple | dict | Thresholds | None = None, brief: str | bool = False, actions: Actions | None = None, - active: bool = True, + active: bool | Callable = True, ) -> Validate: """Assert the values in a column sd to a value gt some `value`. @@ -1018,7 +1013,7 @@ class Validate: thresholds: float | bool | tuple | dict | Thresholds | None = None, brief: str | bool = False, actions: Actions | None = None, - active: bool = True, + active: bool | Callable = True, ) -> Validate: """Assert the values in a column sd to a value ge some `value`. @@ -1054,7 +1049,7 @@ class Validate: thresholds: float | bool | tuple | dict | Thresholds | None = None, brief: str | bool = False, actions: Actions | None = None, - active: bool = True, + active: bool | Callable = True, ) -> Validate: """Assert the values in a column sd to a value lt some `value`. 
def _extract_body(func) -> str:
    """Return the source of a function's body with its leading docstring removed.

    The function's source is dedented and parsed with ``ast`` to find the first
    statement that is not the docstring. Every source line from that statement
    to the end of the function is stripped of surrounding whitespace and joined
    with newlines. Because each line is stripped, indentation inside
    multi-line statements is flattened.

    Args:
        func: A function whose source can be retrieved with ``inspect.getsource``.

    Returns:
        str: The whitespace-stripped body lines joined by newline characters.

    Raises:
        TypeError: If the retrieved source does not begin with a function definition.
        ValueError: If the function contains nothing besides a docstring.
    """
    source = textwrap.dedent(inspect.getsource(func))
    func_def = ast.parse(source).body[0]
    if not isinstance(func_def, (ast.FunctionDef, ast.AsyncFunctionDef)):
        raise TypeError(
            f"Expected the source of a function, got a {type(func_def).__name__} node"
        )
    stmts = func_def.body

    # Only a leading *string* constant is a docstring. Other constant
    # expressions (``...``, numbers) are real statements and must be kept.
    leading = stmts[0] if stmts else None
    if (
        isinstance(leading, ast.Expr)
        and isinstance(leading.value, ast.Constant)
        and isinstance(leading.value.value, str)
    ):
        stmts = stmts[1:]

    if not stmts:
        raise ValueError(f"No body found in {func.__name__}")

    source_lines = source.splitlines()
    first_line = stmts[0].lineno - 1  # ast line numbers are 1-indexed
    last_line = func_def.end_lineno  # inclusive, so usable directly as the slice end

    return "\n".join(line.strip() for line in source_lines[first_line:last_line])
+# ensure all methods have tests before generating +all_methods = [ + f"col_{agg}_{comp}" + for agg, comp in itertools.product(AGGREGATOR_REGISTRY.keys(), COMPARATOR_REGISTRY.keys()) +] + +missing_tests = [m for m in all_methods if m not in _TEST_FUNCTION_REGISTRY] +if missing_tests: + raise SystemExit(f"Missing doctest entries for: {missing_tests}") + +# all method names should be valid aggregator methods; sanity check +invalid = [m for m in all_methods if not is_valid_agg(m)] +if invalid: + raise SystemExit(f"Invalid agg method names: {invalid}") + +# Read the file and remove any previously generated sections with VALIDATE_PYI_PATH.open() as f: content = f.read() -with VALIDATE_PYI_PATH.open("w") as f: - f.write(IMPORT_HEADER + "\n\n" + content) + +# Remove the GENERATED section if it exists (but keep everything before it) +if "# === GENERATED START ===" in content: + content = content[: content.find("# === GENERATED START ===")].rstrip() +else: + content = content.rstrip() + +# Ensure content ends with newline before appending generated section +content += "\n" ## Create grid of aggs and comparators -with VALIDATE_PYI_PATH.open("a") as f: +with VALIDATE_PYI_PATH.open("w") as f: + f.write(content) f.write(" # === GENERATED START ===\n") for agg_name, comp_name in itertools.product( AGGREGATOR_REGISTRY.keys(), COMPARATOR_REGISTRY.keys() ): method = f"col_{agg_name}_{comp_name}" - assert is_valid_agg(method) # internal sanity check - # Extract examples from the doctest registry. 
+ # Extract examples from the doctest registry using robust AST parsing doctest_fn = _TEST_FUNCTION_REGISTRY[method] - try: - lines_to_skip = len(doctest_fn.__doc__.split("\n")) - except AttributeError: - lines_to_skip = 0 - - lines: list[str] = inspect.getsourcelines(doctest_fn)[0] - cleaned_lines: list[str] = [line.strip() for line in lines] - body: str = "\n".join(cleaned_lines[lines_to_skip + 2 :]) + body: str = _extract_body(doctest_fn) # Add >>> to each line in the body so doctest can run it body_with_arrows: str = "\n".join(f"\t>>> {line}" for line in body.split("\n")) @@ -100,6 +143,5 @@ f.write(" # === GENERATED END ===\n") -## Run formatter and linter on the generated file: +## Run formatter on the generated file: subprocess.run(["uv", "run", "ruff", "format", str(VALIDATE_PYI_PATH)]) -subprocess.run(["uv", "run", "ty", "check", str(VALIDATE_PYI_PATH)]) diff --git a/tests/test_agg.py b/tests/test_agg.py index ae1316ede..d3dd015e4 100644 --- a/tests/test_agg.py +++ b/tests/test_agg.py @@ -1142,6 +1142,102 @@ def test_agg_report_multiple_steps_formatting(): assert "2.0
tol=(0.1, 0.2)" in html +def test_brief_auto(): + """Test that auto briefs are generated correctly for aggregation methods.""" + data = pl.DataFrame({"amount": [100, 200, 300]}) + + validation = Validate(data).col_sum_gt(columns="amount", value=500, brief=True).interrogate() + + # Check that brief is set to auto template + assert validation.validation_info[0].brief == "{auto}" + + # Check that the HTML report generates auto brief text + html = validation.get_tabular_report().as_raw_html() + assert html is not None + assert len(html) > 0 + + # Auto brief should contain references to the aggregation type and column + # Should mention "sum" and "amount" and the comparison + assert "amount" in html + assert "sum" in html.lower() + + +def test_brief_custom(): + """Test that custom briefs are stored and displayed correctly.""" + data = pl.DataFrame({"sales": [1000, 2000, 3000]}) + + custom_brief = "Validating that total sales exceeds minimum threshold" + + validation = ( + Validate(data).col_avg_eq(columns="sales", value=2000, brief=custom_brief).interrogate() + ) + + # Check that custom brief is stored + assert validation.validation_info[0].brief == custom_brief + + # Check that custom brief appears in HTML report + html = validation.get_tabular_report().as_raw_html() + assert custom_brief in html + + +def test_brief_mixed(): + """Test mixing custom and auto brief templates across multiple validation steps.""" + data = pl.DataFrame({"value_a": [10, 20, 30], "value_b": [100, 200, 300]}) + + custom_brief_1 = "First check: sum validation" + + validation = ( + Validate(data) + .col_sum_gt(columns="value_a", value=50, brief=custom_brief_1) + .col_avg_lt(columns="value_b", value=400, brief=True) # auto brief template + .interrogate() + ) + + # First step should have custom brief + assert validation.validation_info[0].brief == custom_brief_1 + + # Second step should have auto brief template + assert validation.validation_info[1].brief == "{auto}" + + # Both should appear in 
HTML report + html = validation.get_tabular_report().as_raw_html() + assert custom_brief_1 in html + # Auto brief should mention the column and aggregation type + assert "value_b" in html + assert "average" in html.lower() + + +@pytest.mark.parametrize("method", load_validation_method_grid()) +def test_brief_auto_all_agg_methods(method: str): + """Test that auto briefs are generated for all aggregation methods. + + This ensures that the agg_display_names mapping in _create_text_agg + has coverage for all aggregation types (sum, avg, sd). + """ + from pointblank._agg import split_agg_name + + data = pl.DataFrame({"col": [10.0, 20.0, 30.0, 40.0, 50.0]}) + + v = Validate(data) + v = getattr(v, method)(columns="col", value=100, brief=True) + v = v.interrogate() + + assert v.validation_info[0].brief == "{auto}" + + html = v.get_tabular_report().as_raw_html() + assert html is not None + assert len(html) > 0 + assert "col" in html + + # Extract agg type from method name to verify it appears in report + # e.g., "col_sum_eq" -> agg_type="sum" + agg_type, _ = split_agg_name(method) + agg_display_map = {"sum": "sum", "avg": "average", "sd": "standard deviation"} + agg_display = agg_display_map.get(agg_type, agg_type) + + assert agg_display.lower() in html.lower() + + @pytest.mark.parametrize( ("data_eager", "ref_eager"), list(product([True, False], repeat=2)),