From 1df7f9f37c053330383cea4cfd2462a1d7acde48 Mon Sep 17 00:00:00 2001 From: riso Date: Mon, 1 Jun 2026 17:54:20 +0800 Subject: [PATCH] Fix Lance point lookup pushdown filters --- daft_lance/lance_scan.py | 61 ++++++++++++++----- tests/io/lancedb/test_lancedb_point_lookup.py | 46 +++++++++++++- 2 files changed, 90 insertions(+), 17 deletions(-) diff --git a/daft_lance/lance_scan.py b/daft_lance/lance_scan.py index 5fd3878..4341954 100644 --- a/daft_lance/lance_scan.py +++ b/daft_lance/lance_scan.py @@ -256,7 +256,7 @@ def _create_count_rows_scan_task(self, pushdowns: PyPushdowns) -> Iterator[ScanT yield ScanTask.python_factory_func_scan_task( module=_lancedb_count_result_function.__module__, func_name=_lancedb_count_result_function.__name__, - func_args=(self._ds.uri, open_kwargs, fields[0], self._combine_filters_to_arrow()), + func_args=(self._ds.uri, open_kwargs, fields[0], self._combine_filters_to_arrow(pushdowns)), schema=new_schema._schema, num_rows=1, size_bytes=None, @@ -319,7 +319,7 @@ def _create_regular_scan_tasks( """Create regular scan tasks without count pushdown.""" open_kwargs = getattr(self._ds, "_lance_open_kwargs", None) fragments = self._ds.get_fragments() - pushed_expr = self._combine_filters_to_arrow() + pushed_expr = self._combine_filters_to_arrow(pushdowns) def _python_factory_func_scan_task( fragment_ids: list[int] | None = None, @@ -349,7 +349,7 @@ def _python_factory_func_scan_task( ) # Use index-driven scan for point lookups with BTREE indices or nearest search. - if self._should_use_index_for_point_lookup() or nearest_option is not None: + if self._should_use_index_for_point_lookup(pushdowns) or nearest_option is not None: yield _python_factory_func_scan_task(fragment_ids=None, num_rows=None, size_bytes=None) return @@ -395,8 +395,20 @@ def _python_factory_func_scan_task( fragment_ids = [fragment.fragment_id for fragment in fragment_group] yield _python_factory_func_scan_task(fragment_ids, num_rows=num_rows, size_bytes=size_bytes) - def _combine_filters_to_arrow(self) -> pa.compute.Expression | None: - return combine_filters_to_arrow(self._pushed_filters) + def _effective_filters(self, pushdowns: PyPushdowns | None = None) -> list[PyExpr] | None: + if self._pushed_filters: + return self._pushed_filters + + if self._requires_fragment_scans(): + return None + + filters = getattr(pushdowns, "filters", None) if pushdowns is not None else None + if filters is None: + return None + return [filters] + + def _combine_filters_to_arrow(self, pushdowns: PyPushdowns | None = None) -> pa.compute.Expression | None: + return combine_filters_to_arrow(self._effective_filters(pushdowns)) def _compute_limit_pushdown_with_filter(self, pushdowns: PyPushdowns) -> int | None: """Decide whether to push down `limit` when filters are present.""" @@ -408,19 +420,21 @@ def _compute_limit_pushdown_with_filter(self, pushdowns: PyPushdowns) -> int | N return pushdowns.limit - def _should_use_index_for_point_lookup(self) -> bool: + def _should_use_index_for_point_lookup(self, pushdowns: PyPushdowns | None = None) -> bool: """Use index-driven scan only when all point-lookup columns have BTREE. Otherwise fall back to fragment enumeration. Passing fragment_ids=None signals index-driven scan; factory omits fragments so Lance selects them using indices. """ - if not self._pushed_filters: + if self._requires_fragment_scans(): + return False + + filters = self._effective_filters(pushdowns) + if not filters: return False try: - point_columns = detect_point_lookup_columns( - [Expression._from_pyexpr(expr) for expr in self._pushed_filters] - ) + point_columns = detect_point_lookup_columns([Expression._from_pyexpr(expr) for expr in filters]) except (ValueError, TypeError, AttributeError) as e: logger.warning("Failed to analyze filters for point lookup: %s", e, exc_info=True) return False @@ -452,11 +466,7 @@ def _should_use_index_for_point_lookup(self) -> bool: return True return False - def _nearest_default_option(self) -> dict[str, Any] | None: - """Return the default nearest option configured on the Lance dataset, if any. - - Prefer Daft-specific `_daft_default_scan_options` to preserve options stripped before `lance.dataset` (e.g., `nearest`). - """ + def _default_scan_options(self) -> dict[str, Any] | None: default_opts = getattr(self._ds, "_daft_default_scan_options", None) if not isinstance(default_opts, dict): default_opts = getattr(self._ds, "_default_scan_options", None) @@ -466,6 +476,27 @@ def _nearest_default_option(self) -> dict[str, Any] | None: default_opts = open_kwargs.get("default_scan_options") if not isinstance(default_opts, dict): return None + return default_opts + + def _requires_fragment_scans(self) -> bool: + if self._include_fragment_id: + return True + + default_opts = self._default_scan_options() + if not default_opts: + return False + + row_identity_options = ("with_row_id", "with_row_address", "with_rowaddr") + return any(bool(default_opts.get(option)) for option in row_identity_options) + + def _nearest_default_option(self) -> dict[str, Any] | None: + """Return the default nearest option configured on the Lance dataset, if any. + + Prefer Daft-specific `_daft_default_scan_options` to preserve options stripped before `lance.dataset` (e.g., `nearest`). + """ + default_opts = self._default_scan_options() + if default_opts is None: + return None nearest = default_opts.get("nearest") if nearest is None: diff --git a/tests/io/lancedb/test_lancedb_point_lookup.py b/tests/io/lancedb/test_lancedb_point_lookup.py index 92b07ff..553bbaa 100644 --- a/tests/io/lancedb/test_lancedb_point_lookup.py +++ b/tests/io/lancedb/test_lancedb_point_lookup.py @@ -1,5 +1,8 @@ from __future__ import annotations +from pathlib import Path +from typing import Any, cast + import pyarrow as pa import pytest @@ -26,7 +29,7 @@ def _scan(ds): @pytest.mark.parametrize( "idx_type", [ - pytest.param("BTREE", marks=pytest.mark.xfail(reason="Lance BTREE point lookup detection issue")), + "BTREE", "BITMAP", "BLOOMFILTER", ], @@ -46,7 +49,7 @@ def test_point_lookup_equal_hits_scalar_index(lance_dataset, idx_type): @pytest.mark.parametrize( "idx_type", [ - pytest.param("BTREE", marks=pytest.mark.xfail(reason="Lance BTREE point lookup detection issue")), + "BTREE", "BITMAP", "BLOOMFILTER", ], @@ -129,3 +132,42 @@ def test_to_scan_tasks_runs(lance_dataset, idx_type): # Ensure to_scan_tasks yields at least one task and the factory runs without error tasks = list(scan.to_scan_tasks(py_pushdowns)) assert len(tasks) >= 1 + + +def test_point_lookup_uses_pushdown_filters(tmp_path: Path) -> None: + table = pa.table({"id": list(range(10)), "value": [f"value-{i}" for i in range(10)]}) + ds: Any = lance.write_dataset(table, tmp_path, max_rows_per_file=1) + ds.create_scalar_index("id", "BTREE") + + from daft.daft import PyPushdowns + + scan = lance_scan.LanceDBScanOperator(ds) + py_pushdowns = PyPushdowns(columns=["id", "value"], filters=cast(Any, col("id") == 2)._expr, limit=5) + + assert scan._should_use_index_for_point_lookup(py_pushdowns) is True + tasks = list(scan.to_scan_tasks(py_pushdowns)) + assert len(tasks) == 1 + + +def test_point_lookup_skips_index_when_row_identity_requested(tmp_path: Path) -> None: + table = pa.table({"id": list(range(10)), "value": [f"value-{i}" for i in range(10)]}) + ds: Any = lance.write_dataset(table, tmp_path, max_rows_per_file=1) + ds.create_scalar_index("id", "BTREE") + setattr(ds, "_daft_default_scan_options", {"with_row_id": True}) + + from daft.daft import PyPushdowns + + scan = lance_scan.LanceDBScanOperator(ds) + py_pushdowns = PyPushdowns(columns=["id", "value"], filters=cast(Any, col("id") == 2)._expr, limit=5) + + assert scan._should_use_index_for_point_lookup(py_pushdowns) is False + + +def test_point_lookup_skips_index_when_fragment_id_requested(lance_dataset: Any) -> None: + ds = lance_dataset + ds.create_scalar_index("id", "BTREE") + + scan = lance_scan.LanceDBScanOperator(ds, include_fragment_id=True) + scan.push_filters([cast(Any, col("id") == 2)._expr]) + + assert scan._should_use_index_for_point_lookup() is False