Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 49 additions & 7 deletions daft_lance/lance_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
logger = logging.getLogger(__name__)


# TODO support fts and fast_search
# TODO support fast_search (ANN vector search speed optimization)
def _lancedb_table_factory_function(
ds_uri: str,
open_kwargs: dict[Any, Any] | None = None,
Expand All @@ -32,6 +32,7 @@ def _lancedb_table_factory_function(
limit: int | None = None,
include_fragment_id: bool | None = False,
nearest: dict[str, Any] | None = None,
full_text_query: str | dict[str, Any] | None = None,
) -> Iterator[PyRecordBatch]:
if fragment_ids is not None and nearest is not None:
raise ValueError(
Expand Down Expand Up @@ -64,6 +65,7 @@ def _iter_batches() -> Iterator[PyRecordBatch]:
columns=cols or None,
filter=filter,
limit=fragment_limit,
full_text_query=full_text_query,
blob_handling="blobs_descriptions",
)

Expand Down Expand Up @@ -94,6 +96,7 @@ def _iter_batches() -> Iterator[PyRecordBatch]:
filter=filter,
limit=limit,
nearest=nearest,
full_text_query=full_text_query,
blob_handling="blobs_descriptions",
)

Expand Down Expand Up @@ -141,6 +144,7 @@ def __init__(
self._fragment_group_size = fragment_group_size
self._include_fragment_id = include_fragment_id
self._enable_strict_filter_pushdown = get_context().daft_planning_config.enable_strict_filter_pushdown
self._full_text_query = self._resolve_full_text_query()
base = self._ds.schema
if self._include_fragment_id:
base = pa.schema([*base, pa.field("fragment_id", pa.int64())], metadata=base.metadata)
Expand Down Expand Up @@ -221,6 +225,7 @@ def to_scan_tasks(self, pushdowns: PyPushdowns) -> Iterator[ScanTask]:
required_columns.append("fragment_id")

nearest_option = self._nearest_default_option()
fts_option = self._full_text_query

# Check if there is a count aggregation pushdown
if (
Expand All @@ -234,19 +239,20 @@ def to_scan_tasks(self, pushdowns: PyPushdowns) -> Iterator[ScanTask]:
"Count mode %s is not supported for pushdown, falling back to original logic",
pushdowns.aggregation_count_mode(),
)
yield from self._create_regular_scan_tasks(pushdowns, required_columns, nearest_option)
yield from self._create_regular_scan_tasks(pushdowns, required_columns, nearest_option, fts_option)
else:
yield from self._create_count_rows_scan_task(pushdowns)
# Check if there is a limit pushdown and no filters and no nearest search
# Check if there is a limit pushdown and no filters and no nearest search and no fts
elif (
pushdowns.limit is not None
and self._pushed_filters is None
and pushdowns.filters is None
and nearest_option is None
and fts_option is None
):
yield from self._create_scan_tasks_with_limit_and_no_filters(pushdowns, required_columns)
else:
yield from self._create_regular_scan_tasks(pushdowns, required_columns, nearest_option)
yield from self._create_regular_scan_tasks(pushdowns, required_columns, nearest_option, fts_option)

def _create_count_rows_scan_task(self, pushdowns: PyPushdowns) -> Iterator[ScanTask]:
"""Create scan task for counting rows."""
Expand Down Expand Up @@ -304,6 +310,7 @@ def _create_scan_tasks_with_limit_and_no_filters(
rows_to_scan,
self._include_fragment_id,
None,
None,
),
schema=task_schema._schema,
num_rows=rows_to_scan,
Expand All @@ -314,7 +321,11 @@ def _create_scan_tasks_with_limit_and_no_filters(
)

def _create_regular_scan_tasks(
self, pushdowns: PyPushdowns, required_columns: list[str] | None, nearest_option: dict[str, Any] | None = None
self,
pushdowns: PyPushdowns,
required_columns: list[str] | None,
nearest_option: dict[str, Any] | None = None,
full_text_query_option: str | dict[str, Any] | None = None,
) -> Iterator[ScanTask]:
"""Create regular scan tasks without count pushdown."""
open_kwargs = getattr(self._ds, "_lance_open_kwargs", None)
Expand All @@ -339,6 +350,7 @@ def _python_factory_func_scan_task(
self._compute_limit_pushdown_with_filter(pushdowns),
self._include_fragment_id,
nearest_option,
full_text_query_option,
),
schema=self.schema()._schema,
num_rows=num_rows,
Expand All @@ -348,8 +360,8 @@ def _python_factory_func_scan_task(
source_name=self.display_name(),
)

# Use index-driven scan for point lookups with BTREE indices or nearest search.
if self._should_use_index_for_point_lookup() or nearest_option is not None:
# Use index-driven scan for point lookups with BTREE indices, nearest search, or FTS.
if self._should_use_index_for_point_lookup() or nearest_option is not None or fts_option is not None:
yield _python_factory_func_scan_task(fragment_ids=None, num_rows=None, size_bytes=None)
return

Expand Down Expand Up @@ -479,6 +491,36 @@ def _nearest_default_option(self) -> dict[str, Any] | None:
return None
return nearest

def _resolve_full_text_query(self) -> str | dict[str, Any] | None:
"""Return the full_text_query option configured on the Lance dataset, if any.

Extracts ``full_text_query`` from ``default_scan_options``, supporting both
``_daft_default_scan_options`` and ``_default_scan_options`` dataset attributes.
``full_text_query`` accepts a plain string (simple search term) or a dict with
``columns`` and ``q`` (column-qualified search).
"""
default_opts = getattr(self._ds, "_daft_default_scan_options", None)
if not isinstance(default_opts, dict):
default_opts = getattr(self._ds, "_default_scan_options", None)
if not isinstance(default_opts, dict):
open_kwargs = getattr(self._ds, "_lance_open_kwargs", None)
if isinstance(open_kwargs, dict):
default_opts = open_kwargs.get("default_scan_options")
if not isinstance(default_opts, dict):
return None

fts = default_opts.get("full_text_query")
if fts is None:
return None
if not isinstance(fts, (str, dict)):
logger.warning(
"Ignoring default_scan_options['full_text_query'] for dataset %s: expected str or dict, got %s",
getattr(self._ds, "uri", "<unknown>"),
type(fts).__name__,
)
return None
return fts

@staticmethod
def _estimate_size_bytes(fragment: lance.LanceFragment) -> int:
if fragment.metadata is None or fragment.metadata.files is None:
Expand Down
4 changes: 2 additions & 2 deletions daft_lance/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def construct_lance_dataset(
original_default_scan_options = kwargs.pop("default_scan_options", None)
safe_default_scan_options = None
if isinstance(original_default_scan_options, dict):
safe_default_scan_options = {k: v for k, v in original_default_scan_options.items() if k != "nearest"}
safe_default_scan_options = {k: v for k, v in original_default_scan_options.items() if k not in ("nearest", "full_text_query")}
if safe_default_scan_options:
kwargs["default_scan_options"] = safe_default_scan_options
elif original_default_scan_options is not None:
Expand All @@ -111,7 +111,7 @@ def construct_lance_dataset(
except Exception:
pass

# Preserve the full user-provided defaults (including nearest) for Daft's planning
# Preserve the full user-provided defaults (including nearest and full_text_query) for Daft's planning
# even if we stripped keys out before calling `lance.dataset`.
try:
ds._daft_default_scan_options = original_default_scan_options
Expand Down
166 changes: 166 additions & 0 deletions tests/io/lancedb/test_lancedb_full_text_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
from __future__ import annotations

import lance
import pyarrow as pa
import pytest

import daft
from daft_lance import create_scalar_index


def build_text_dataset(tmp_path_factory) -> str:
    """Write an eight-row Lance dataset of short sentences and return its path.

    The table has three columns: ``id`` (1-8), ``text`` (one sentence per row)
    and ``category`` (one label per row). It is written with a single
    ``lance.write_dataset`` call into a fresh ``lance_fts`` temp directory.
    """
    target_dir = tmp_path_factory.mktemp("lance_fts")

    # One tuple per row: (id, text, category).
    rows = [
        (1, "The quick brown fox jumps over the lazy dog", "animals"),
        (2, "Python is a powerful programming language", "tech"),
        (3, "Machine learning algorithms are fascinating", "ml"),
        (4, "Data science requires statistical knowledge", "data"),
        (5, "Natural language processing uses text analysis", "nlp"),
        (6, "Distributed computing scales horizontally", "distributed"),
        (7, "Daft framework enables parallel processing", "daft"),
        (8, "Lance format provides efficient storage", "storage"),
    ]
    columns = {
        "id": [row[0] for row in rows],
        "text": [row[1] for row in rows],
        "category": [row[2] for row in rows],
    }

    lance.write_dataset(pa.table(columns), target_dir)
    return str(target_dir)


def build_multi_fragment_text_dataset(tmp_path_factory) -> str:
    """Write an eight-row ``id``/``text`` Lance dataset capped at two rows per file.

    ``max_rows_per_file=2`` is passed to ``lance.write_dataset``; the intent is
    to spread the rows over several fragments. The dataset is written into a
    fresh ``lance_fts_multi`` temp directory whose path is returned as a string.
    """
    target_dir = tmp_path_factory.mktemp("lance_fts_multi")

    sentences = [
        "The quick brown fox jumps over the lazy dog",
        "Python is a powerful programming language",
        "Machine learning algorithms are fascinating",
        "Data science requires statistical knowledge",
        "Natural language processing uses text analysis",
        "Distributed computing scales horizontally",
        "Daft framework enables parallel processing",
        "Lance format provides efficient storage",
    ]
    # Ids are 1-based positions of the sentences.
    columns = {
        "id": list(range(1, len(sentences) + 1)),
        "text": sentences,
    }

    lance.write_dataset(pa.table(columns), target_dir, max_rows_per_file=2)
    return str(target_dir)


def test_full_text_query_string_search(tmp_path_factory) -> None:
    """A plain-string ``full_text_query`` returns the row containing the term."""
    uri = build_text_dataset(tmp_path_factory)

    # Index the searched column with an INVERTED index before querying.
    create_scalar_index(uri=uri, column="text", index_type="INVERTED")

    # The query is supplied through read_lance's default scan options.
    scan_options = {"full_text_query": "Python"}
    rows = daft.read_lance(uri, default_scan_options=scan_options).select("id", "text").to_pydict()
    matched_ids = rows["id"]

    # Row 2 is "Python is a powerful programming language".
    assert len(matched_ids) >= 1
    assert 2 in matched_ids


def test_full_text_query_no_results(tmp_path_factory) -> None:
    """A term that occurs in no row yields an empty result."""
    uri = build_text_dataset(tmp_path_factory)
    create_scalar_index(uri=uri, column="text", index_type="INVERTED")

    scan_options = {"full_text_query": "zzzzznonexistent"}
    rows = daft.read_lance(uri, default_scan_options=scan_options).select("id").to_pydict()

    assert rows["id"] == []


def test_full_text_query_multi_fragment(tmp_path_factory) -> None:
    """FTS finds a match in a dataset written with two rows per file."""
    uri = build_multi_fragment_text_dataset(tmp_path_factory)
    create_scalar_index(uri=uri, column="text", index_type="INVERTED")

    # "language" occurs in row 2 ("... programming language") and in
    # row 5 ("Natural language processing ..."); only row 5 is asserted.
    scan_options = {"full_text_query": "language"}
    rows = daft.read_lance(uri, default_scan_options=scan_options).select("id", "text").to_pydict()
    matched_ids = rows["id"]

    assert len(matched_ids) >= 1
    assert 5 in matched_ids


def test_full_text_query_with_filter(tmp_path_factory) -> None:
    """Test FTS combined with a standard filter.

    The search term matches two rows in different categories, so the result
    can only be a single row if the filter is applied together with the
    full-text query.
    """
    dataset_path = build_text_dataset(tmp_path_factory)

    create_scalar_index(uri=dataset_path, column="text", index_type="INVERTED")

    # "language" appears in row 2 (category=tech) and row 5 (category=nlp).
    search_term = "language"
    df = daft.read_lance(
        dataset_path,
        default_scan_options={"full_text_query": search_term},
    )

    # Apply the filter: keep only category=nlp, which drops the row 2 match.
    filtered = df.where(daft.col("category") == "nlp")
    result = filtered.select("id", "category").to_pydict()

    assert result["id"] == [5]  # "Natural language processing uses text analysis"
    assert result["category"] == ["nlp"]


def test_full_text_query_column_projection(tmp_path_factory) -> None:
    """FTS works when only the ``id`` column is selected, not the searched one."""
    uri = build_text_dataset(tmp_path_factory)
    create_scalar_index(uri=uri, column="text", index_type="INVERTED")

    scan_options = {"full_text_query": "storage"}
    projected = daft.read_lance(uri, default_scan_options=scan_options).select("id")
    rows = projected.to_pydict()

    # Only row 8 ("Lance format provides efficient storage") contains the term.
    assert rows["id"] == [8]


def test_full_text_query_without_index(tmp_path_factory) -> None:
    """FTS on a dataset with no INVERTED index is expected to still return matches.

    NOTE(review): this relies on Lance answering a full-text query without an
    index (presumably by scanning the data) - confirm for the pinned Lance
    version.
    """
    uri = build_text_dataset(tmp_path_factory)

    # Deliberately no create_scalar_index call here.
    scan_options = {"full_text_query": "Python"}
    rows = daft.read_lance(uri, default_scan_options=scan_options).select("id").to_pydict()
    matched_ids = rows["id"]

    # Row 2 ("Python is a powerful programming language") should be found.
    assert len(matched_ids) >= 1
    assert 2 in matched_ids