From 41a53ffd2aceb9785866ecdbc449e959cf5f61fa Mon Sep 17 00:00:00 2001 From: Tyler Riccio Date: Mon, 9 Mar 2026 21:25:10 -0400 Subject: [PATCH 1/9] add tests for briefs (not working yet) --- tests/test_agg.py | 64 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) diff --git a/tests/test_agg.py b/tests/test_agg.py index 95513e5c8..66448d1ee 100644 --- a/tests/test_agg.py +++ b/tests/test_agg.py @@ -1139,3 +1139,67 @@ def test_agg_report_multiple_steps_formatting(): # Step 3: Value with asymmetric tolerance assert "2.0
tol=(0.1, 0.2)" in html + + +def test_brief_auto(): + """Test that auto briefs are generated correctly for aggregation methods.""" + data = pl.DataFrame({"amount": [100, 200, 300]}) + + validation = Validate(data).col_sum_gt(columns="amount", value=500, brief=True).interrogate() + + # Check that brief is set to auto template + assert validation.validation_info[0].brief == "{auto}" + + # Check that the HTML report generates and contains meaningful content + html = validation.get_tabular_report().as_raw_html() + assert html is not None + assert len(html) > 0 + # Auto briefs should mention validation details about the column + assert "amount" in html + + +def test_brief_custom(): + """Test that custom briefs are stored and displayed correctly.""" + data = pl.DataFrame({"sales": [1000, 2000, 3000]}) + + custom_brief = "Validating that total sales exceeds minimum threshold" + + validation = ( + Validate(data).col_avg_eq(columns="sales", value=2000, brief=custom_brief).interrogate() + ) + + # Check that custom brief is stored + assert validation.validation_info[0].brief == custom_brief + + # Check that custom brief appears in HTML report + html = validation.get_tabular_report().as_raw_html() + assert custom_brief in html + + +def test_brief_mixed(): + """Test mixing auto and custom briefs across multiple validation steps.""" + data = pl.DataFrame({"value_a": [10, 20, 30], "value_b": [100, 200, 300]}) + + custom_brief_1 = "First check: sum validation" + + validation = ( + Validate(data) + .col_sum_gt(columns="value_a", value=50, brief=custom_brief_1) + .col_avg_lt(columns="value_b", value=400, brief=True) # auto brief + .interrogate() + ) + + # First step should have custom brief + assert validation.validation_info[0].brief == custom_brief_1 + + # Second step should have auto brief template + assert validation.validation_info[1].brief == "{auto}" + + # Both should appear in HTML report + html = validation.get_tabular_report().as_raw_html() + assert custom_brief_1 in html + # Auto briefs in the report should mention the column being validated + assert "value_b" in html + + +# test_brief_mixed() From 61a3a816d29c488639a6a2f93f2743a2ee18e4a6 Mon Sep 17 00:00:00 2001 From: Tyler Riccio Date: Mon, 9 Mar 2026 21:48:42 -0400 Subject: [PATCH 2/9] implement auto for aggs --- pointblank/_agg.py | 17 ++++++++++-- pointblank/validate.py | 62 +++++++++++++++++++++++++++++++++++++++++- 2 files changed, 76 insertions(+), 3 deletions(-) diff --git a/pointblank/_agg.py b/pointblank/_agg.py index 870dff8f2..e3377cd12 100644 --- a/pointblank/_agg.py +++ b/pointblank/_agg.py @@ -76,6 +76,20 @@ def _generic_between(real: Any, lower: Any, upper: Any) -> bool: return bool(lower <= real <= upper) +def split_agg_name(name: str) -> tuple[str, str]: + """Split an aggregation method name into aggregator and comparator names. + + Args: + name (str): The aggregation method name (e.g., "col_sum_eq" or "sum_eq"). + + Returns: + tuple[str, str]: A tuple of (agg_name, comp_name) e.g., ("sum", "eq"). + """ + name = name.removeprefix("col_") + agg_name, comp_name = name.rsplit("_", 1) + return agg_name, comp_name + + def resolve_agg_registries(name: str) -> tuple[Aggregator, Comparator]: """Resolve the assertion name to a valid aggregator @@ -85,8 +99,7 @@ def resolve_agg_registries(name: str) -> tuple[Aggregator, Comparator]: Returns: tuple[Aggregator, Comparator]: The aggregator and comparator functions. """ - name = name.removeprefix("col_") - agg_name, comp_name = name.split("_")[-2:] + agg_name, comp_name = split_agg_name(name) aggregator = AGGREGATOR_REGISTRY.get(agg_name) comparator = COMPARATOR_REGISTRY.get(comp_name) diff --git a/pointblank/validate.py b/pointblank/validate.py index 45d7375dc..bbc533a90 100644 --- a/pointblank/validate.py +++ b/pointblank/validate.py @@ -27,7 +27,12 @@ from great_tables.vals import fmt_integer, fmt_number from importlib_resources import files -from pointblank._agg import is_valid_agg, load_validation_method_grid, resolve_agg_registries +from pointblank._agg import ( + is_valid_agg, + load_validation_method_grid, + resolve_agg_registries, + split_agg_name, +) from pointblank._constants import ( ASSERTION_TYPE_METHOD_MAP, CHECK_MARK_SPAN, @@ -18869,6 +18874,15 @@ def _create_autobrief_or_failure_text( for_failure=for_failure, ) + if is_valid_agg(assertion_type): + return _create_text_agg( + lang=lang, + assertion_type=assertion_type, + column=column, + values=values, + for_failure=for_failure, + ) + return None @@ -18903,6 +18917,52 @@ def _create_text_comparison( ) +def _create_text_agg( + lang: str, + assertion_type: str, + column: str | list[str], + values: dict[str, Any], + for_failure: bool = False, +) -> str: + """Create autobrief text for aggregation methods like col_sum_eq, col_avg_gt, etc.""" + type_ = _expect_failure_type(for_failure=for_failure) + + agg_type, comp_type = split_agg_name(assertion_type) + + # this is covered by the test `test_brief_auto_all_agg_methods` to make sure we don't + # create any weird secret agg constants. + agg_display_names: dict[str, str] = { + "sum": "sum", + "avg": "average", + "sd": "standard deviation", + } + try: + agg_display: str = agg_display_names[agg_type] + except KeyError as ke: # pragma: no cover + raise AssertionError from ke # This should never happen in prod, it's caught in CI. + + # Get the operator + comparison_assertion = f"col_vals_{comp_type}" + if lang == "ar": # pragma: no cover + operator = COMPARISON_OPERATORS_AR.get(comparison_assertion, comp_type) + else: + operator = COMPARISON_OPERATORS.get(comparison_assertion, comp_type) + + column_text = _prep_column_text(column=column) + + value = values.get("value", values) if isinstance(values, dict) else values + values_text = _prep_values_text(values=str(value), lang=lang, limit=3) + + # "Expect that the {agg} of {column} should be {operator} {value}." + agg_expectation_text = EXPECT_FAIL_TEXT[f"compare_{type_}_text"][lang] + + return agg_expectation_text.format( + column_text=f"the {agg_display} of {column_text}", + operator=operator, + values_text=values_text, + ) + + def _create_text_between( lang: str, column: str, From cc1603f10f882fcd6e1e6e90d431d7a9fd5f2143 Mon Sep 17 00:00:00 2001 From: Tyler Riccio Date: Mon, 9 Mar 2026 21:49:39 -0400 Subject: [PATCH 3/9] harden auto tests --- tests/test_agg.py | 44 ++++++++++++++++++++++++++++++++++++++------ 1 file changed, 38 insertions(+), 6 deletions(-) diff --git a/tests/test_agg.py b/tests/test_agg.py index 66448d1ee..4e788b672 100644 --- a/tests/test_agg.py +++ b/tests/test_agg.py @@ -1150,12 +1150,15 @@ def test_brief_auto(): # Check that brief is set to auto template assert validation.validation_info[0].brief == "{auto}" - # Check that the HTML report generates and contains meaningful content + # Check that the HTML report generates auto brief text html = validation.get_tabular_report().as_raw_html() assert html is not None assert len(html) > 0 - # Auto briefs should mention validation details about the column + + # Auto brief should contain references to the aggregation type and column + # Should mention "sum" and "amount" and the comparison assert "amount" in html + assert "sum" in html.lower() def test_brief_custom(): @@ -1177,7 +1180,7 @@ def test_brief_custom(): def test_brief_mixed(): - """Test mixing auto and custom briefs across multiple validation steps.""" + """Test mixing custom and auto brief templates across multiple validation steps.""" data = pl.DataFrame({"value_a": [10, 20, 30], "value_b": [100, 200, 300]}) custom_brief_1 = "First check: sum validation" @@ -1185,7 +1188,7 @@ def test_brief_mixed(): validation = ( Validate(data) .col_sum_gt(columns="value_a", value=50, brief=custom_brief_1) - .col_avg_lt(columns="value_b", value=400, brief=True) # auto brief + .col_avg_lt(columns="value_b", value=400, brief=True) # auto brief template .interrogate() ) @@ -1198,8 +1201,37 @@ def test_brief_mixed(): # Both should appear in HTML report html = validation.get_tabular_report().as_raw_html() assert custom_brief_1 in html - # Auto briefs in the report should mention the column being validated + # Auto brief should mention the column and aggregation type assert "value_b" in html + assert "average" in html.lower() + + +@pytest.mark.parametrize("method", load_validation_method_grid()) +def test_brief_auto_all_agg_methods(method: str): + """Test that auto briefs are generated for all aggregation methods. + + This ensures that the agg_display_names mapping in _create_text_agg + has coverage for all aggregation types (sum, avg, sd). + """ + from pointblank._agg import split_agg_name + + data = pl.DataFrame({"col": [10.0, 20.0, 30.0, 40.0, 50.0]}) + + v = Validate(data) + v = getattr(v, method)(columns="col", value=100, brief=True) + v = v.interrogate() + + assert v.validation_info[0].brief == "{auto}" + + html = v.get_tabular_report().as_raw_html() + assert html is not None + assert len(html) > 0 + assert "col" in html + # Extract agg type from method name to verify it appears in report + # e.g., "col_sum_eq" -> agg_type="sum" + agg_type, _ = split_agg_name(method) + agg_display_map = {"sum": "sum", "avg": "average", "sd": "standard deviation"} + agg_display = agg_display_map.get(agg_type, agg_type) -# test_brief_mixed() + assert agg_display.lower() in html.lower() From 85056e6710212af723c31af2f0af9778a1872cdc Mon Sep 17 00:00:00 2001 From: Tyler Riccio Date: Wed, 11 Mar 2026 20:42:33 -0400 Subject: [PATCH 4/9] add validation to ci --- .github/workflows/code-checks.yaml | 7 +++ .pre-commit-config.yaml | 10 +++- scripts/generate_agg_validate_pyi.py | 76 +++++++++++++++++++++------- 3 files changed, 74 insertions(+), 19 deletions(-) diff --git a/.github/workflows/code-checks.yaml b/.github/workflows/code-checks.yaml index 016db786f..84e8a695a 100644 --- a/.github/workflows/code-checks.yaml +++ b/.github/workflows/code-checks.yaml @@ -17,4 +17,11 @@ jobs: - uses: actions/setup-python@v5 with: python-version: "3.10" + - name: Setup uv + uses: astral-sh/setup-uv@v2 + - name: Check validate.pyi is up to date + run: | + uv run make pyi + git diff --exit-code pointblank/validate.pyi || \ + (echo "validate.pyi is out of date — run 'make pyi' and commit the result" && exit 1) - uses: pre-commit/action@v3.0.1 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a96aff5b0..155c3a0dc 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,5 +1,13 @@ exclude: "(.*\\.svg)|(.*\\.qmd)|(.*\\.ambr)|(.*\\.csv)|(.*\\.txt)|(.*\\.json)|(.*\\.ipynb)|(.*\\.html)" repos: + - repo: local + hooks: + - id: check-pyi-sync + name: validate.pyi must be up to date + language: system + entry: bash -c 'uv run python scripts/generate_agg_validate_pyi.py && git diff --exit-code pointblank/validate.pyi' + pass_filenames: false + stages: [pre-commit] - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.6.0 hooks: @@ -8,8 +16,6 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.9.9 hooks: - # Run the linter. - id: ruff args: [--fix] - # Run the formatter. - id: ruff-format diff --git a/scripts/generate_agg_validate_pyi.py b/scripts/generate_agg_validate_pyi.py index 9abc9b2bc..50fb1b79e 100644 --- a/scripts/generate_agg_validate_pyi.py +++ b/scripts/generate_agg_validate_pyi.py @@ -1,7 +1,9 @@ +import ast import inspect import itertools import subprocess import sys +import textwrap from pathlib import Path from pointblank._agg import AGGREGATOR_REGISTRY, COMPARATOR_REGISTRY, is_valid_agg @@ -14,6 +16,32 @@ VALIDATE_PYI_PATH = Path("pointblank/validate.pyi") + +def _extract_body(func) -> str: + """Extract method body from doctest function using AST parsing. + + Reliably finds the first non-docstring statement and returns the + remaining function body as source code. + """ + source = textwrap.dedent(inspect.getsource(func)) + tree = ast.parse(source) + func_def = tree.body[0] + stmts = func_def.body # ty: ignore + + # Skip leading docstring if present + if stmts and isinstance(stmts[0], ast.Expr) and isinstance(stmts[0].value, ast.Constant): + stmts = stmts[1:] + + if not stmts: + raise ValueError(f"No body found in {func.__name__}") + + source_lines = source.splitlines() + first_line = stmts[0].lineno - 1 # ast line numbers are 1-indexed + last_line = func_def.end_lineno # inclusive + + return "\n".join(line.strip() for line in source_lines[first_line:last_line]) + + SIGNATURE = """ self, columns: _PBUnresolvedColumn, @@ -22,7 +50,7 @@ thresholds: float | bool | tuple | dict | Thresholds | None = None, brief: str | bool = False, actions: Actions | None = None, - active: bool = True, + active: bool | Callable = True, """ DOCSTRING = """ @@ -53,32 +81,47 @@ from pointblank._typing import Tolerance """ -# Write the headers to the end. Ruff will take care of sorting imports. +# ensure all methods have tests before generating +all_methods = [ + f"col_{agg}_{comp}" + for agg, comp in itertools.product(AGGREGATOR_REGISTRY.keys(), COMPARATOR_REGISTRY.keys()) +] + +missing_tests = [m for m in all_methods if m not in _TEST_FUNCTION_REGISTRY] +if missing_tests: + raise SystemExit(f"Missing doctest entries for: {missing_tests}") + +# all method names should be valid aggregator methods; sanity check +invalid = [m for m in all_methods if not is_valid_agg(m)] +if invalid: + raise SystemExit(f"Invalid agg method names: {invalid}") + +# Read the file and remove any previously generated sections with VALIDATE_PYI_PATH.open() as f: content = f.read() -with VALIDATE_PYI_PATH.open("w") as f: - f.write(IMPORT_HEADER + "\n\n" + content) + +# Remove the GENERATED section if it exists (but keep everything before it) +if "# === GENERATED START ===" in content: + content = content[: content.find("# === GENERATED START ===")].rstrip() +else: + content = content.rstrip() + +# Ensure content ends with newline before appending generated section +content += "\n" ## Create grid of aggs and comparators -with VALIDATE_PYI_PATH.open("a") as f: +with VALIDATE_PYI_PATH.open("w") as f: + f.write(content) f.write(" # === GENERATED START ===\n") for agg_name, comp_name in itertools.product( AGGREGATOR_REGISTRY.keys(), COMPARATOR_REGISTRY.keys() ): method = f"col_{agg_name}_{comp_name}" - assert is_valid_agg(method) # internal sanity check - # Extract examples from the doctest registry. + # Extract examples from the doctest registry using robust AST parsing doctest_fn = _TEST_FUNCTION_REGISTRY[method] - try: - lines_to_skip = len(doctest_fn.__doc__.split("\n")) - except AttributeError: - lines_to_skip = 0 - - lines: list[str] = inspect.getsourcelines(doctest_fn)[0] - cleaned_lines: list[str] = [line.strip() for line in lines] - body: str = "\n".join(cleaned_lines[lines_to_skip + 2 :]) + body: str = _extract_body(doctest_fn) # Add >>> to each line in the body so doctest can run it body_with_arrows: str = "\n".join(f"\t>>> {line}" for line in body.split("\n")) @@ -100,6 +143,5 @@ f.write(" # === GENERATED END ===\n") -## Run formatter and linter on the generated file: +## Run formatter on the generated file: subprocess.run(["uv", "run", "ruff", "format", str(VALIDATE_PYI_PATH)]) -subprocess.run(["uv", "run", "ty", "check", str(VALIDATE_PYI_PATH)]) From 97319ccbc799539e5cef8ff914686e783272649c Mon Sep 17 00:00:00 2001 From: Tyler Riccio Date: Fri, 13 Mar 2026 11:26:46 -0400 Subject: [PATCH 5/9] forgot to run my own type generator --- pointblank/validate.pyi | 118 ++++++++++++++++++++++------------------ 1 file changed, 66 insertions(+), 52 deletions(-) diff --git a/pointblank/validate.pyi b/pointblank/validate.pyi index 666d5fc47..125912c28 100644 --- a/pointblank/validate.pyi +++ b/pointblank/validate.pyi @@ -1,17 +1,15 @@ +import datetime from collections.abc import Collection from dataclasses import dataclass -from pathlib import Path -from typing import Any, Callable, Literal, ParamSpec, TypeVar - from great_tables import GT -from narwhals.typing import FrameT, IntoFrame - -from pointblank import Actions, Thresholds +from narwhals.typing import IntoDataFrame, IntoFrame +from pathlib import Path from pointblank._typing import SegmentSpec, Tolerance from pointblank._utils import _PBUnresolvedColumn from pointblank.column import Column, ColumnSelector, ColumnSelectorNarwhals, ReferenceColumn from pointblank.schema import Schema from pointblank.thresholds import Actions, FinalActions, Thresholds +from typing import Any, Callable, Literal, ParamSpec, TypeVar __all__ = [ "Validate", @@ -54,7 +52,7 @@ def config( def load_dataset( dataset: Literal["small_table", "game_revenue", "nycflights", "global_sales"] = "small_table", tbl_type: Literal["polars", "pandas", "duckdb"] = "polars", -) -> FrameT | Any: ... +) -> Any: ... def read_file(filepath: str | Path) -> Validate: ... def write_file( validation: Validate, @@ -69,7 +67,7 @@ def get_data_path( file_type: Literal["csv", "parquet", "duckdb"] = "csv", ) -> str: ... def preview( - data: FrameT | Any, + data: Any, columns_subset: str | list[str] | Column | None = None, n_head: int = 5, n_tail: int = 5, @@ -77,11 +75,11 @@ def preview( show_row_numbers: bool = True, max_col_width: int = 250, min_tbl_width: int = 500, - incl_header: bool = None, + incl_header: bool | None = None, ) -> GT: ... -def missing_vals_tbl(data: FrameT | Any) -> GT: ... -def get_column_count(data: FrameT | Any) -> int: ... -def get_row_count(data: FrameT | Any) -> int: ... +def missing_vals_tbl(data: Any) -> GT: ... +def get_column_count(data: Any) -> int: ... +def get_row_count(data: Any) -> int: ... @dataclass class _ValidationInfo: @classmethod @@ -102,7 +100,7 @@ class _ValidationInfo: sha1: str | None = ... assertion_type: str | None = ... column: Any | None = ... - values: Any | list[any] | tuple | None = ... + values: Any | list[Any] | tuple | None = ... inclusive: tuple[bool, bool] | None = ... na_pass: bool | None = ... pre: Callable | None = ... @@ -124,13 +122,13 @@ class _ValidationInfo: error: bool | None = ... critical: bool | None = ... failure_text: str | None = ... - tbl_checked: FrameT | None = ... - extract: FrameT | None = ... - val_info: dict[str, any] | None = ... + tbl_checked: Any = ... + extract: Any = ... + val_info: dict[str, Any] | None = ... time_processed: str | None = ... proc_duration_s: float | None = ... notes: dict[str, dict[str, str]] | None = ... - def get_val_info(self) -> dict[str, any]: ... + def get_val_info(self) -> dict[str, Any] | None: ... def _add_note(self, key: str, markdown: str, text: str | None = None) -> None: ... def _get_notes(self, format: str = "dict") -> dict[str, dict[str, str]] | list[str] | None: ... def _get_note(self, key: str, format: str = "dict") -> dict[str, str] | str | None: ... @@ -140,7 +138,7 @@ def connect_to_table(connection_string: str) -> Any: ... def print_database_tables(connection_string: str) -> list[str]: ... @dataclass class Validate: - data: FrameT | Any + data: IntoDataFrame reference: IntoFrame | None = ... tbl_name: str | None = ... label: str | None = ... @@ -150,6 +148,9 @@ class Validate: brief: str | bool | None = ... lang: str | None = ... locale: str | None = ... + owner: str | None = ... + consumers: str | list[str] | None = ... + version: str | None = ... col_names = ... col_types = ... time_start = ... @@ -166,10 +167,10 @@ class Validate: thresholds=None, brief: bool = False, actions=None, - active: bool | Callable = True, + active: bool = True, ): ... def set_tbl( - self, tbl: FrameT | Any, tbl_name: str | None = None, label: str | None = None + self, tbl: Any, tbl_name: str | None = None, label: str | None = None ) -> Validate: ... def _repr_html_(self) -> str: ... def col_vals_gt( @@ -179,7 +180,7 @@ class Validate: na_pass: bool = False, pre: Callable | None = None, segments: SegmentSpec | None = None, - thresholds: int | float | bool | tuple | dict | Thresholds = None, + thresholds: int | float | bool | tuple | dict | Thresholds | None = None, actions: Actions | None = None, brief: str | bool | None = None, active: bool | Callable = True, @@ -191,7 +192,7 @@ class Validate: na_pass: bool = False, pre: Callable | None = None, segments: SegmentSpec | None = None, - thresholds: int | float | bool | tuple | dict | Thresholds = None, + thresholds: int | float | bool | tuple | dict | Thresholds | None = None, actions: Actions | None = None, brief: str | bool | None = None, active: bool | Callable = True, @@ -203,7 +204,7 @@ class Validate: na_pass: bool = False, pre: Callable | None = None, segments: SegmentSpec | None = None, - thresholds: int | float | bool | tuple | dict | Thresholds = None, + thresholds: int | float | bool | tuple | dict | Thresholds | None = None, actions: Actions | None = None, brief: str | bool | None = None, active: bool | Callable = True, @@ -215,7 +216,7 @@ class Validate: na_pass: bool = False, pre: Callable | None = None, segments: SegmentSpec | None = None, - thresholds: int | float | bool | tuple | dict | Thresholds = None, + thresholds: int | float | bool | tuple | dict | Thresholds | None = None, actions: Actions | None = None, brief: str | bool | None = None, active: bool | Callable = True, @@ -227,7 +228,7 @@ class Validate: na_pass: bool = False, pre: Callable | None = None, segments: SegmentSpec | None = None, - thresholds: int | float | bool | tuple | dict | Thresholds = None, + thresholds: int | float | bool | tuple | dict | Thresholds | None = None, actions: Actions | None = None, brief: str | bool | None = None, active: bool | Callable = True, @@ -239,7 +240,7 @@ class Validate: na_pass: bool = False, pre: Callable | None = None, segments: SegmentSpec | None = None, - thresholds: int | float | bool | tuple | dict | Thresholds = None, + thresholds: int | float | bool | tuple | dict | Thresholds | None = None, actions: Actions | None = None, brief: str | bool | None = None, active: bool | Callable = True, @@ -253,7 +254,7 @@ class Validate: na_pass: bool = False, pre: Callable | None = None, segments: SegmentSpec | None = None, - thresholds: int | float | bool | tuple | dict | Thresholds = None, + thresholds: int | float | bool | tuple | dict | Thresholds | None = None, actions: Actions | None = None, brief: str | bool | None = None, active: bool | Callable = True, @@ -267,7 +268,7 @@ class Validate: na_pass: bool = False, pre: Callable | None = None, segments: SegmentSpec | None = None, - thresholds: int | float | bool | tuple | dict | Thresholds = None, + thresholds: int | float | bool | tuple | dict | Thresholds | None = None, actions: Actions | None = None, brief: str | bool | None = None, active: bool | Callable = True, @@ -278,7 +279,7 @@ class Validate: set: Collection[Any], pre: Callable | None = None, segments: SegmentSpec | None = None, - thresholds: int | float | bool | tuple | dict | Thresholds = None, + thresholds: int | float | bool | tuple | dict | Thresholds | None = None, actions: Actions | None = None, brief: str | bool | None = None, active: bool | Callable = True, @@ -289,7 +290,7 @@ class Validate: set: Collection[Any], pre: Callable | None = None, segments: SegmentSpec | None = None, - thresholds: int | float | bool | tuple | dict | Thresholds = None, + thresholds: int | float | bool | tuple | dict | Thresholds | None = None, actions: Actions | None = None, brief: str | bool | None = None, active: bool | Callable = True, @@ -302,7 +303,7 @@ class Validate: na_pass: bool = False, pre: Callable | None = None, segments: SegmentSpec | None = None, - thresholds: int | float | bool | tuple | dict | Thresholds = None, + thresholds: int | float | bool | tuple | dict | Thresholds | None = None, actions: Actions | None = None, brief: str | bool | None = None, active: bool | Callable = True, @@ -315,7 +316,7 @@ class Validate: na_pass: bool = False, pre: Callable | None = None, segments: SegmentSpec | None = None, - thresholds: int | float | bool | tuple | dict | Thresholds = None, + thresholds: int | float | bool | tuple | dict | Thresholds | None = None, actions: Actions | None = None, brief: str | bool | None = None, active: bool | Callable = True, @@ -325,7 +326,7 @@ class Validate: columns: str | list[str] | Column | ColumnSelector | ColumnSelectorNarwhals, pre: Callable | None = None, segments: SegmentSpec | None = None, - thresholds: int | float | bool | tuple | dict | Thresholds = None, + thresholds: int | float | bool | tuple | dict | Thresholds | None = None, actions: Actions | None = None, brief: str | bool | None = None, active: bool | Callable = True, @@ -335,7 +336,7 @@ class Validate: columns: str | list[str] | Column | ColumnSelector | ColumnSelectorNarwhals, pre: Callable | None = None, segments: SegmentSpec | None = None, - thresholds: int | float | bool | tuple | dict | Thresholds = None, + thresholds: int | float | bool | tuple | dict | Thresholds | None = None, actions: Actions | None = None, brief: str | bool | None = None, active: bool | Callable = True, @@ -348,7 +349,7 @@ class Validate: inverse: bool = False, pre: Callable | None = None, segments: SegmentSpec | None = None, - thresholds: int | float | bool | tuple | dict | Thresholds = None, + thresholds: int | float | bool | tuple | dict | Thresholds | None = None, actions: Actions | None = None, brief: str | bool | None = None, active: bool | Callable = True, @@ -360,7 +361,7 @@ class Validate: na_pass: bool = False, pre: Callable | None = None, segments: SegmentSpec | None = None, - thresholds: int | float | bool | tuple | dict | Thresholds = None, + thresholds: int | float | bool | tuple | dict | Thresholds | None = None, actions: Actions | None = None, brief: str | bool | None = None, active: bool | Callable = True, @@ -370,7 +371,7 @@ class Validate: expr: Any, pre: Callable | None = None, segments: SegmentSpec | None = None, - thresholds: int | float | bool | tuple | dict | Thresholds = None, + thresholds: int | float | bool | tuple | dict | Thresholds | None = None, actions: Actions | None = None, brief: str | bool | None = None, active: bool | Callable = True, @@ -378,7 +379,7 @@ class Validate: def col_exists( self, columns: str | list[str] | Column | ColumnSelector | ColumnSelectorNarwhals, - thresholds: int | float | bool | tuple | dict | Thresholds = None, + thresholds: int | float | bool | tuple | dict | Thresholds | None = None, actions: Actions | None = None, brief: str | bool | None = None, active: bool | Callable = True, @@ -398,7 +399,7 @@ class Validate: columns_subset: str | list[str] | None = None, pre: Callable | None = None, segments: SegmentSpec | None = None, - thresholds: int | float | bool | tuple | dict | Thresholds = None, + thresholds: int | float | bool | tuple | dict | Thresholds | None = None, actions: Actions | None = None, brief: str | bool | None = None, active: bool | Callable = True, @@ -408,7 +409,7 @@ class Validate: columns_subset: str | list[str] | None = None, pre: Callable | None = None, segments: SegmentSpec | None = None, - thresholds: int | float | bool | tuple | dict | Thresholds = None, + thresholds: int | float | bool | tuple | dict | Thresholds | None = None, actions: Actions | None = None, brief: str | bool | None = None, active: bool | Callable = True, @@ -422,7 +423,7 @@ class Validate: max_concurrent: int = 3, pre: Callable | None = None, segments: SegmentSpec | None = None, - thresholds: int | float | bool | tuple | dict | Thresholds = None, + thresholds: int | float | bool | tuple | dict | Thresholds | None = None, actions: Actions | None = None, brief: str | bool | None = None, active: bool | Callable = True, @@ -436,37 +437,50 @@ class Validate: case_sensitive_dtypes: bool = True, full_match_dtypes: bool = True, pre: Callable | None = None, - thresholds: int | float | bool | tuple | dict | Thresholds = None, + thresholds: int | float | bool | tuple | dict | Thresholds | None = None, actions: Actions | None = None, brief: str | bool | None = None, active: bool | Callable = True, ) -> Validate: ... def row_count_match( self, - count: int | FrameT | Any, + count: int | Any, tol: Tolerance = 0, inverse: bool = False, pre: Callable | None = None, - thresholds: int | float | bool | tuple | dict | Thresholds = None, + thresholds: int | float | bool | tuple | dict | Thresholds | None = None, + actions: Actions | None = None, + brief: str | bool | None = None, + active: bool | Callable = True, + ) -> Validate: ... + def data_freshness( + self, + column: str, + max_age: str | datetime.timedelta, + reference_time: datetime.datetime | str | None = None, + timezone: str | None = None, + allow_tz_mismatch: bool = False, + pre: Callable | None = None, + thresholds: int | float | bool | tuple | dict | Thresholds | None = None, actions: Actions | None = None, brief: str | bool | None = None, active: bool | Callable = True, ) -> Validate: ... def col_count_match( self, - count: int | FrameT | Any, + count: int | Any, inverse: bool = False, pre: Callable | None = None, - thresholds: int | float | bool | tuple | dict | Thresholds = None, + thresholds: int | float | bool | tuple | dict | Thresholds | None = None, actions: Actions | None = None, brief: str | bool | None = None, active: bool | Callable = True, ) -> Validate: ... def tbl_match( self, - tbl_compare: FrameT | Any, + tbl_compare: Any, pre: Callable | None = None, - thresholds: int | float | bool | tuple | dict | Thresholds = None, + thresholds: int | float | bool | tuple | dict | Thresholds | None = None, actions: Actions | None = None, brief: str | bool | None = None, active: bool | Callable = True, @@ -475,7 +489,7 @@ class Validate: self, *exprs: Callable, pre: Callable | None = None, - thresholds: int | float | bool | tuple | dict | Thresholds = None, + thresholds: int | float | bool | tuple | dict | Thresholds | None = None, actions: Actions | None = None, brief: str | bool | None = None, active: bool | Callable = True, @@ -484,7 +498,7 @@ class Validate: self, expr: Callable, pre: Callable | None = None, - thresholds: int | float | bool | tuple | dict | Thresholds = None, + thresholds: int | float | bool | tuple | dict | Thresholds | None = None, actions: Actions | None = None, brief: str | bool | None = None, active: bool | Callable = True, @@ -528,11 +542,11 @@ class Validate: ) -> dict[int, bool] | bool: ... def get_data_extracts( self, i: int | list[int] | None = None, frame: bool = False - ) -> dict[int, FrameT | None] | FrameT | None: ... + ) -> dict[int, Any] | Any: ... def get_json_report( self, use_fields: list[str] | None = None, exclude_fields: list[str] | None = None ) -> str: ... - def get_sundered_data(self, type: str = "pass") -> FrameT: ... + def get_sundered_data(self, type: str = "pass") -> Any: ... def get_notes( self, i: int, format: str = "dict" ) -> dict[str, dict[str, str]] | list[str] | None: ... From 7667731bb32067775f6bf0ad450326e3da9e6554 Mon Sep 17 00:00:00 2001 From: Tyler Riccio Date: Mon, 16 Mar 2026 19:17:49 -0400 Subject: [PATCH 6/9] run ruff in pyi command directly --- .pre-commit-config.yaml | 4 ++-- Makefile | 2 ++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 155c3a0dc..68bc02695 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,9 +5,9 @@ repos: - id: check-pyi-sync name: validate.pyi must be up to date language: system - entry: bash -c 'uv run python scripts/generate_agg_validate_pyi.py && git diff --exit-code pointblank/validate.pyi' + entry: bash -c 'make pyi && git diff --exit-code pointblank/validate.pyi' pass_filenames: false - stages: [pre-commit] + stages: [commit] - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.6.0 hooks: diff --git a/Makefile b/Makefile index 60f91f723..6338d8cb0 100644 --- a/Makefile +++ b/Makefile @@ -6,6 +6,8 @@ pyi: ## Generate .pyi stub files --include-private \ -o . @uv run scripts/generate_agg_validate_pyi.py + @uv run ruff check --fix pointblank/validate.pyi + @uv run ruff format pointblank/validate.pyi .PHONY: test test: From d250cb0b0f4f67e1bf677b91d0f4fb73c55e0959 Mon Sep 17 00:00:00 2001 From: Tyler Riccio Date: Mon, 16 Mar 2026 20:04:39 -0400 Subject: [PATCH 7/9] repull and regen --- pointblank/validate.pyi | 34 +++++++++++++++------------------- 1 file changed, 15 insertions(+), 19 deletions(-) diff --git a/pointblank/validate.pyi b/pointblank/validate.pyi index f27528cbc..2e825a603 100644 --- a/pointblank/validate.pyi +++ b/pointblank/validate.pyi @@ -1,7 +1,3 @@ -from pointblank import Actions, Thresholds -from pointblank._utils import _PBUnresolvedColumn -from pointblank.column import Column, ReferenceColumn -from pointblank._typing import Tolerance import datetime from collections.abc import Collection from dataclasses import dataclass @@ -585,7 +581,7 @@ class Validate: thresholds: float | bool | tuple | dict | Thresholds | None = None, brief: str | bool = False, actions: Actions | None = None, - active: bool = True, + active: bool | Callable = True, ) -> Validate: """Assert the values in a column sum to a value eq some `value`. @@ -621,7 +617,7 @@ class Validate: thresholds: float | bool | tuple | dict | Thresholds | None = None, brief: str | bool = False, actions: Actions | None = None, - active: bool = True, + active: bool | Callable = True, ) -> Validate: """Assert the values in a column sum to a value gt some `value`. @@ -657,7 +653,7 @@ class Validate: thresholds: float | bool | tuple | dict | Thresholds | None = None, brief: str | bool = False, actions: Actions | None = None, - active: bool = True, + active: bool | Callable = True, ) -> Validate: """Assert the values in a column sum to a value ge some `value`. @@ -693,7 +689,7 @@ class Validate: thresholds: float | bool | tuple | dict | Thresholds | None = None, brief: str | bool = False, actions: Actions | None = None, - active: bool = True, + active: bool | Callable = True, ) -> Validate: """Assert the values in a column sum to a value lt some `value`. @@ -729,7 +725,7 @@ class Validate: thresholds: float | bool | tuple | dict | Thresholds | None = None, brief: str | bool = False, actions: Actions | None = None, - active: bool = True, + active: bool | Callable = True, ) -> Validate: """Assert the values in a column sum to a value le some `value`. @@ -765,7 +761,7 @@ class Validate: thresholds: float | bool | tuple | dict | Thresholds | None = None, brief: str | bool = False, actions: Actions | None = None, - active: bool = True, + active: bool | Callable = True, ) -> Validate: """Assert the values in a column avg to a value eq some `value`. @@ -801,7 +797,7 @@ class Validate: thresholds: float | bool | tuple | dict | Thresholds | None = None, brief: str | bool = False, actions: Actions | None = None, - active: bool = True, + active: bool | Callable = True, ) -> Validate: """Assert the values in a column avg to a value gt some `value`. @@ -837,7 +833,7 @@ class Validate: thresholds: float | bool | tuple | dict | Thresholds | None = None, brief: str | bool = False, actions: Actions | None = None, - active: bool = True, + active: bool | Callable = True, ) -> Validate: """Assert the values in a column avg to a value ge some `value`. @@ -873,7 +869,7 @@ class Validate: thresholds: float | bool | tuple | dict | Thresholds | None = None, brief: str | bool = False, actions: Actions | None = None, - active: bool = True, + active: bool | Callable = True, ) -> Validate: """Assert the values in a column avg to a value lt some `value`. @@ -909,7 +905,7 @@ class Validate: thresholds: float | bool | tuple | dict | Thresholds | None = None, brief: str | bool = False, actions: Actions | None = None, - active: bool = True, + active: bool | Callable = True, ) -> Validate: """Assert the values in a column avg to a value le some `value`. @@ -945,7 +941,7 @@ class Validate: thresholds: float | bool | tuple | dict | Thresholds | None = None, brief: str | bool = False, actions: Actions | None = None, - active: bool = True, + active: bool | Callable = True, ) -> Validate: """Assert the values in a column sd to a value eq some `value`. @@ -981,7 +977,7 @@ class Validate: thresholds: float | bool | tuple | dict | Thresholds | None = None, brief: str | bool = False, actions: Actions | None = None, - active: bool = True, + active: bool | Callable = True, ) -> Validate: """Assert the values in a column sd to a value gt some `value`. @@ -1017,7 +1013,7 @@ class Validate: thresholds: float | bool | tuple | dict | Thresholds | None = None, brief: str | bool = False, actions: Actions | None = None, - active: bool = True, + active: bool | Callable = True, ) -> Validate: """Assert the values in a column sd to a value ge some `value`. @@ -1053,7 +1049,7 @@ class Validate: thresholds: float | bool | tuple | dict | Thresholds | None = None, brief: str | bool = False, actions: Actions | None = None, - active: bool = True, + active: bool | Callable = True, ) -> Validate: """Assert the values in a column sd to a value lt some `value`. @@ -1089,7 +1085,7 @@ class Validate: thresholds: float | bool | tuple | dict | Thresholds | None = None, brief: str | bool = False, actions: Actions | None = None, - active: bool = True, + active: bool | Callable = True, ) -> Validate: """Assert the values in a column sd to a value le some `value`. From 4e3b455e3b209ac20a12b0abb1b88e5e71625ecd Mon Sep 17 00:00:00 2001 From: Tyler Riccio Date: Fri, 20 Mar 2026 18:36:55 -0400 Subject: [PATCH 8/9] extra check for variable key error in polars --- pointblank/validate.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/pointblank/validate.py b/pointblank/validate.py index 0dd84a126..a2d4f9364 100644 --- a/pointblank/validate.py +++ b/pointblank/validate.py @@ -14097,9 +14097,12 @@ def interrogate( is_column_not_found = "column" in error_msg and "not found" in error_msg + # Older Polars versions (< ~1.33) raise KeyError instead of + # ColumnNotFoundError for missing columns in expressions, so we + # need to catch both error shapes. is_comparison_column_not_found = ( "unable to find column" in error_msg and "valid columns" in error_msg - ) + ) or isinstance(e, KeyError) if ( is_comparison_error or is_column_not_found or is_comparison_column_not_found @@ -14131,12 +14134,16 @@ def interrogate( # Add a note for comparison column not found errors elif is_comparison_column_not_found: # Extract column name from error message - # Error format: 'unable to find column "col_name"; valid columns: ...' + # ColumnNotFoundError: 'unable to find column "col_name"; valid columns: ...' + # KeyError (older Polars): "'col_name'" match = re.search(r'unable to find column "([^"]+)"', str(e)) - + missing_col_name = None if match: missing_col_name = match.group(1) + elif isinstance(e, KeyError) and e.args: + missing_col_name = e.args[0] + if missing_col_name is not None: # Determine position for between/outside validations position = None if assertion_type in ["col_vals_between", "col_vals_outside"]: From 7e6de38d80b902d6718fd2f321cfbcc0cccd123d Mon Sep 17 00:00:00 2001 From: Tyler Riccio Date: Tue, 24 Mar 2026 20:44:16 -0400 Subject: [PATCH 9/9] pin ruff to match local CI and github CI --- .pre-commit-config.yaml | 3 ++- pyproject.toml | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 68bc02695..a69606b93 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -13,8 +13,9 @@ repos: hooks: - id: trailing-whitespace - id: end-of-file-fixer + # NOTE: ruff version must match the pin in pyproject.toml [dependency-groups] dev - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.9.9 + rev: v0.14.10 hooks: - id: ruff args: [--fix] diff --git a/pyproject.toml b/pyproject.toml index 1bc203a61..6b1aa6246 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -107,7 +107,7 @@ dev = [ "pytest-xdist>=3.6.1", "pytz>=2025.2", "quartodoc>=0.8.1; python_version >= '3.9'", - "ruff>=0.9.9", + "ruff==0.14.10", # NOTE: must match rev in .pre-commit-config.yaml "shiny>=1.4.0", "openpyxl>=3.0.0", "mcp[cli]>=1.10.1",