Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 46 additions & 15 deletions daft_lance/lance_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def _create_count_rows_scan_task(self, pushdowns: PyPushdowns) -> Iterator[ScanT
yield ScanTask.python_factory_func_scan_task(
module=_lancedb_count_result_function.__module__,
func_name=_lancedb_count_result_function.__name__,
func_args=(self._ds.uri, open_kwargs, fields[0], self._combine_filters_to_arrow()),
func_args=(self._ds.uri, open_kwargs, fields[0], self._combine_filters_to_arrow(pushdowns)),
schema=new_schema._schema,
num_rows=1,
size_bytes=None,
Expand Down Expand Up @@ -319,7 +319,7 @@ def _create_regular_scan_tasks(
"""Create regular scan tasks without count pushdown."""
open_kwargs = getattr(self._ds, "_lance_open_kwargs", None)
fragments = self._ds.get_fragments()
pushed_expr = self._combine_filters_to_arrow()
pushed_expr = self._combine_filters_to_arrow(pushdowns)

def _python_factory_func_scan_task(
fragment_ids: list[int] | None = None,
Expand Down Expand Up @@ -349,7 +349,7 @@ def _python_factory_func_scan_task(
)

# Use index-driven scan for point lookups with BTREE indices or nearest search.
if self._should_use_index_for_point_lookup() or nearest_option is not None:
if self._should_use_index_for_point_lookup(pushdowns) or nearest_option is not None:
yield _python_factory_func_scan_task(fragment_ids=None, num_rows=None, size_bytes=None)
return

Expand Down Expand Up @@ -395,8 +395,20 @@ def _python_factory_func_scan_task(
fragment_ids = [fragment.fragment_id for fragment in fragment_group]
yield _python_factory_func_scan_task(fragment_ids, num_rows=num_rows, size_bytes=size_bytes)

def _combine_filters_to_arrow(self) -> pa.compute.Expression | None:
return combine_filters_to_arrow(self._pushed_filters)
def _effective_filters(self, pushdowns: PyPushdowns | None = None) -> list[PyExpr] | None:
if self._pushed_filters:
return self._pushed_filters

if self._requires_fragment_scans():
return None

filters = getattr(pushdowns, "filters", None) if pushdowns is not None else None
if filters is None:
return None
return [filters]

def _combine_filters_to_arrow(self, pushdowns: PyPushdowns | None = None) -> pa.compute.Expression | None:
return combine_filters_to_arrow(self._effective_filters(pushdowns))

def _compute_limit_pushdown_with_filter(self, pushdowns: PyPushdowns) -> int | None:
"""Decide whether to push down `limit` when filters are present."""
Expand All @@ -408,19 +420,21 @@ def _compute_limit_pushdown_with_filter(self, pushdowns: PyPushdowns) -> int | N

return pushdowns.limit

def _should_use_index_for_point_lookup(self) -> bool:
def _should_use_index_for_point_lookup(self, pushdowns: PyPushdowns | None = None) -> bool:
"""Use index-driven scan only when all point-lookup columns have BTREE.

Otherwise fall back to fragment enumeration. Passing fragment_ids=None signals
index-driven scan; factory omits fragments so Lance selects them using indices.
"""
if not self._pushed_filters:
if self._requires_fragment_scans():
return False

filters = self._effective_filters(pushdowns)
if not filters:
return False

try:
point_columns = detect_point_lookup_columns(
[Expression._from_pyexpr(expr) for expr in self._pushed_filters]
)
point_columns = detect_point_lookup_columns([Expression._from_pyexpr(expr) for expr in filters])
except (ValueError, TypeError, AttributeError) as e:
logger.warning("Failed to analyze filters for point lookup: %s", e, exc_info=True)
return False
Expand Down Expand Up @@ -452,11 +466,7 @@ def _should_use_index_for_point_lookup(self) -> bool:
return True
return False

def _nearest_default_option(self) -> dict[str, Any] | None:
"""Return the default nearest option configured on the Lance dataset, if any.

Prefer Daft-specific `_daft_default_scan_options` to preserve options stripped before `lance.dataset` (e.g., `nearest`).
"""
def _default_scan_options(self) -> dict[str, Any] | None:
default_opts = getattr(self._ds, "_daft_default_scan_options", None)
if not isinstance(default_opts, dict):
default_opts = getattr(self._ds, "_default_scan_options", None)
Expand All @@ -466,6 +476,27 @@ def _nearest_default_option(self) -> dict[str, Any] | None:
default_opts = open_kwargs.get("default_scan_options")
if not isinstance(default_opts, dict):
return None
return default_opts

def _requires_fragment_scans(self) -> bool:
if self._include_fragment_id:
return True

default_opts = self._default_scan_options()
if not default_opts:
return False

row_identity_options = ("with_row_id", "with_row_address", "with_rowaddr")
return any(bool(default_opts.get(option)) for option in row_identity_options)

def _nearest_default_option(self) -> dict[str, Any] | None:
"""Return the default nearest option configured on the Lance dataset, if any.

Prefer Daft-specific `_daft_default_scan_options` to preserve options stripped before `lance.dataset` (e.g., `nearest`).
"""
default_opts = self._default_scan_options()
if default_opts is None:
return None

nearest = default_opts.get("nearest")
if nearest is None:
Expand Down
46 changes: 44 additions & 2 deletions tests/io/lancedb/test_lancedb_point_lookup.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from __future__ import annotations

from pathlib import Path
from typing import Any, cast

import pyarrow as pa
import pytest

Expand All @@ -26,7 +29,7 @@ def _scan(ds):
@pytest.mark.parametrize(
"idx_type",
[
pytest.param("BTREE", marks=pytest.mark.xfail(reason="Lance BTREE point lookup detection issue")),
"BTREE",
"BITMAP",
"BLOOMFILTER",
],
Expand All @@ -46,7 +49,7 @@ def test_point_lookup_equal_hits_scalar_index(lance_dataset, idx_type):
@pytest.mark.parametrize(
"idx_type",
[
pytest.param("BTREE", marks=pytest.mark.xfail(reason="Lance BTREE point lookup detection issue")),
"BTREE",
"BITMAP",
"BLOOMFILTER",
],
Expand Down Expand Up @@ -129,3 +132,42 @@ def test_to_scan_tasks_runs(lance_dataset, idx_type):
# Ensure to_scan_tasks yields at least one task and the factory runs without error
tasks = list(scan.to_scan_tasks(py_pushdowns))
assert len(tasks) >= 1


def test_point_lookup_uses_pushdown_filters(tmp_path: Path) -> None:
table = pa.table({"id": list(range(10)), "value": [f"value-{i}" for i in range(10)]})
ds: Any = lance.write_dataset(table, tmp_path, max_rows_per_file=1)
ds.create_scalar_index("id", "BTREE")

from daft.daft import PyPushdowns

scan = lance_scan.LanceDBScanOperator(ds)
py_pushdowns = PyPushdowns(columns=["id", "value"], filters=cast(Any, col("id") == 2)._expr, limit=5)

assert scan._should_use_index_for_point_lookup(py_pushdowns) is True
tasks = list(scan.to_scan_tasks(py_pushdowns))
assert len(tasks) == 1


def test_point_lookup_skips_index_when_row_identity_requested(tmp_path: Path) -> None:
table = pa.table({"id": list(range(10)), "value": [f"value-{i}" for i in range(10)]})
ds: Any = lance.write_dataset(table, tmp_path, max_rows_per_file=1)
ds.create_scalar_index("id", "BTREE")
setattr(ds, "_daft_default_scan_options", {"with_row_id": True})

from daft.daft import PyPushdowns

scan = lance_scan.LanceDBScanOperator(ds)
py_pushdowns = PyPushdowns(columns=["id", "value"], filters=cast(Any, col("id") == 2)._expr, limit=5)

assert scan._should_use_index_for_point_lookup(py_pushdowns) is False


def test_point_lookup_skips_index_when_fragment_id_requested(lance_dataset: Any) -> None:
ds = lance_dataset
ds.create_scalar_index("id", "BTREE")

scan = lance_scan.LanceDBScanOperator(ds, include_fragment_id=True)
scan.push_filters([cast(Any, col("id") == 2)._expr])

assert scan._should_use_index_for_point_lookup() is False