diff --git a/daft_lance/lance_scan.py b/daft_lance/lance_scan.py index 5fd3878..8d2db76 100644 --- a/daft_lance/lance_scan.py +++ b/daft_lance/lance_scan.py @@ -22,7 +22,7 @@ logger = logging.getLogger(__name__) -# TODO support fts and fast_search +# TODO support fast_search (ANN vector search speed optimization) def _lancedb_table_factory_function( ds_uri: str, open_kwargs: dict[Any, Any] | None = None, @@ -32,6 +32,7 @@ def _lancedb_table_factory_function( limit: int | None = None, include_fragment_id: bool | None = False, nearest: dict[str, Any] | None = None, + full_text_query: str | dict[str, Any] | None = None, ) -> Iterator[PyRecordBatch]: if fragment_ids is not None and nearest is not None: raise ValueError( @@ -64,6 +65,7 @@ def _iter_batches() -> Iterator[PyRecordBatch]: columns=cols or None, filter=filter, limit=fragment_limit, + full_text_query=full_text_query, blob_handling="blobs_descriptions", ) @@ -94,6 +96,7 @@ def _iter_batches() -> Iterator[PyRecordBatch]: filter=filter, limit=limit, nearest=nearest, + full_text_query=full_text_query, blob_handling="blobs_descriptions", ) @@ -141,6 +144,7 @@ def __init__( self._fragment_group_size = fragment_group_size self._include_fragment_id = include_fragment_id self._enable_strict_filter_pushdown = get_context().daft_planning_config.enable_strict_filter_pushdown + self._full_text_query = self._resolve_full_text_query() base = self._ds.schema if self._include_fragment_id: base = pa.schema([*base, pa.field("fragment_id", pa.int64())], metadata=base.metadata) @@ -221,6 +225,7 @@ def to_scan_tasks(self, pushdowns: PyPushdowns) -> Iterator[ScanTask]: required_columns.append("fragment_id") nearest_option = self._nearest_default_option() + fts_option = self._full_text_query # Check if there is a count aggregation pushdown if ( @@ -234,19 +239,20 @@ def to_scan_tasks(self, pushdowns: PyPushdowns) -> Iterator[ScanTask]: "Count mode %s is not supported for pushdown, falling back to original logic", pushdowns.aggregation_count_mode(), ) - yield from self._create_regular_scan_tasks(pushdowns, required_columns, nearest_option) + yield from self._create_regular_scan_tasks(pushdowns, required_columns, nearest_option, fts_option) else: yield from self._create_count_rows_scan_task(pushdowns) - # Check if there is a limit pushdown and no filters and no nearest search + # Check if there is a limit pushdown and no filters and no nearest search and no fts elif ( pushdowns.limit is not None and self._pushed_filters is None and pushdowns.filters is None and nearest_option is None + and fts_option is None ): yield from self._create_scan_tasks_with_limit_and_no_filters(pushdowns, required_columns) else: - yield from self._create_regular_scan_tasks(pushdowns, required_columns, nearest_option) + yield from self._create_regular_scan_tasks(pushdowns, required_columns, nearest_option, fts_option) def _create_count_rows_scan_task(self, pushdowns: PyPushdowns) -> Iterator[ScanTask]: """Create scan task for counting rows.""" @@ -304,6 +310,7 @@ def _create_scan_tasks_with_limit_and_no_filters( rows_to_scan, self._include_fragment_id, None, + None, ), schema=task_schema._schema, num_rows=rows_to_scan, @@ -314,7 +321,11 @@ def _create_scan_tasks_with_limit_and_no_filters( ) def _create_regular_scan_tasks( - self, pushdowns: PyPushdowns, required_columns: list[str] | None, nearest_option: dict[str, Any] | None = None + self, + pushdowns: PyPushdowns, + required_columns: list[str] | None, + nearest_option: dict[str, Any] | None = None, + full_text_query_option: str | dict[str, Any] | None = None, ) -> Iterator[ScanTask]: """Create regular scan tasks without count pushdown.""" open_kwargs = getattr(self._ds, "_lance_open_kwargs", None) @@ -339,6 +350,7 @@ def _python_factory_func_scan_task( self._compute_limit_pushdown_with_filter(pushdowns), self._include_fragment_id, nearest_option, + full_text_query_option, ), schema=self.schema()._schema, num_rows=num_rows, @@ -348,8 +360,8 @@ def _python_factory_func_scan_task( source_name=self.display_name(), ) - # Use index-driven scan for point lookups with BTREE indices or nearest search. - if self._should_use_index_for_point_lookup() or nearest_option is not None: + # Use index-driven scan for point lookups with BTREE indices, nearest search, or FTS. + if self._should_use_index_for_point_lookup() or nearest_option is not None or fts_option is not None: yield _python_factory_func_scan_task(fragment_ids=None, num_rows=None, size_bytes=None) return @@ -479,6 +491,36 @@ def _nearest_default_option(self) -> dict[str, Any] | None: return None return nearest + def _resolve_full_text_query(self) -> str | dict[str, Any] | None: + """Return the full_text_query option configured on the Lance dataset, if any. + + Extracts ``full_text_query`` from ``default_scan_options``, supporting both + ``_daft_default_scan_options`` and ``_default_scan_options`` dataset attributes. + ``full_text_query`` accepts a plain string (simple search term) or a dict with + ``columns`` and ``q`` (column-qualified search). + """ + default_opts = getattr(self._ds, "_daft_default_scan_options", None) + if not isinstance(default_opts, dict): + default_opts = getattr(self._ds, "_default_scan_options", None) + if not isinstance(default_opts, dict): + open_kwargs = getattr(self._ds, "_lance_open_kwargs", None) + if isinstance(open_kwargs, dict): + default_opts = open_kwargs.get("default_scan_options") + if not isinstance(default_opts, dict): + return None + + fts = default_opts.get("full_text_query") + if fts is None: + return None + if not isinstance(fts, (str, dict)): + logger.warning( + "Ignoring default_scan_options['full_text_query'] for dataset %s: expected str or dict, got %s", + getattr(self._ds, "uri", ""), + type(fts).__name__, + ) + return None + return fts + @staticmethod def _estimate_size_bytes(fragment: lance.LanceFragment) -> int: if fragment.metadata is None or fragment.metadata.files is None: diff --git a/daft_lance/utils.py b/daft_lance/utils.py index 5429657..c20a07b 100644 --- a/daft_lance/utils.py +++ b/daft_lance/utils.py @@ -92,7 +92,7 @@ def construct_lance_dataset( original_default_scan_options = kwargs.pop("default_scan_options", None) safe_default_scan_options = None if isinstance(original_default_scan_options, dict): - safe_default_scan_options = {k: v for k, v in original_default_scan_options.items() if k != "nearest"} + safe_default_scan_options = {k: v for k, v in original_default_scan_options.items() if k not in ("nearest", "full_text_query")} if safe_default_scan_options: kwargs["default_scan_options"] = safe_default_scan_options elif original_default_scan_options is not None: @@ -111,7 +111,7 @@ def construct_lance_dataset( except Exception: pass - # Preserve the full user-provided defaults (including nearest) for Daft's planning + # Preserve the full user-provided defaults (including nearest and full_text_query) for Daft's planning # even if we stripped keys out before calling `lance.dataset`. try: ds._daft_default_scan_options = original_default_scan_options diff --git a/tests/io/lancedb/test_lancedb_full_text_search.py b/tests/io/lancedb/test_lancedb_full_text_search.py new file mode 100644 index 0000000..9c28c90 --- /dev/null +++ b/tests/io/lancedb/test_lancedb_full_text_search.py @@ -0,0 +1,166 @@ +from __future__ import annotations + +import lance +import pyarrow as pa +import pytest + +import daft +from daft_lance import create_scalar_index + + +def build_text_dataset(tmp_path_factory) -> str: + tmp_dir = tmp_path_factory.mktemp("lance_fts") + + text_data = { + "id": [1, 2, 3, 4, 5, 6, 7, 8], + "text": [ + "The quick brown fox jumps over the lazy dog", + "Python is a powerful programming language", + "Machine learning algorithms are fascinating", + "Data science requires statistical knowledge", + "Natural language processing uses text analysis", + "Distributed computing scales horizontally", + "Daft framework enables parallel processing", + "Lance format provides efficient storage", + ], + "category": [ + "animals", + "tech", + "ml", + "data", + "nlp", + "distributed", + "daft", + "storage", + ], + } + table = pa.table(text_data) + lance.write_dataset(table, tmp_dir) + return str(tmp_dir) + + +def build_multi_fragment_text_dataset(tmp_path_factory) -> str: + tmp_dir = tmp_path_factory.mktemp("lance_fts_multi") + + text_data = { + "id": [1, 2, 3, 4, 5, 6, 7, 8], + "text": [ + "The quick brown fox jumps over the lazy dog", + "Python is a powerful programming language", + "Machine learning algorithms are fascinating", + "Data science requires statistical knowledge", + "Natural language processing uses text analysis", + "Distributed computing scales horizontally", + "Daft framework enables parallel processing", + "Lance format provides efficient storage", + ], + } + table = pa.table(text_data) + lance.write_dataset(table, tmp_dir, max_rows_per_file=2) + return str(tmp_dir) + + +def test_full_text_query_string_search(tmp_path_factory) -> None: + """Test FTS with a plain string query.""" + dataset_path = build_text_dataset(tmp_path_factory) + + # Build INVERTED index on the text column + create_scalar_index(uri=dataset_path, column="text", index_type="INVERTED") + + # Search via daft read_lance with full_text_query + search_term = "Python" + df = daft.read_lance( + dataset_path, + default_scan_options={"full_text_query": search_term}, + ) + result = df.select("id", "text").to_pydict() + + # Should find row id=2: "Python is a powerful programming language" + assert len(result["id"]) >= 1 + assert 2 in result["id"] + + +def test_full_text_query_no_results(tmp_path_factory) -> None: + """Test FTS with a term that has no matches.""" + dataset_path = build_text_dataset(tmp_path_factory) + + create_scalar_index(uri=dataset_path, column="text", index_type="INVERTED") + + search_term = "zzzzznonexistent" + df = daft.read_lance( + dataset_path, + default_scan_options={"full_text_query": search_term}, + ) + result = df.select("id").to_pydict() + + assert result["id"] == [] + + +def test_full_text_query_multi_fragment(tmp_path_factory) -> None: + """Test FTS across multiple fragments.""" + dataset_path = build_multi_fragment_text_dataset(tmp_path_factory) + + create_scalar_index(uri=dataset_path, column="text", index_type="INVERTED") + + # Search for "language" - row 5: "Natural language processing uses text analysis" + search_term = "language" + df = daft.read_lance( + dataset_path, + default_scan_options={"full_text_query": search_term}, + ) + result = df.select("id", "text").to_pydict() + + assert len(result["id"]) >= 1 + assert 5 in result["id"] + + +def test_full_text_query_with_filter(tmp_path_factory) -> None: + """Test FTS combined with a standard filter.""" + dataset_path = build_text_dataset(tmp_path_factory) + + create_scalar_index(uri=dataset_path, column="text", index_type="INVERTED") + + # Search for "learning" - rows 3 and 4 both mention "learning" + search_term = "learning" + df = daft.read_lance( + dataset_path, + default_scan_options={"full_text_query": search_term}, + ) + # Filter to only category=ml + result = df.select("id", "category").to_pydict() + + assert len(result["id"]) >= 1 + assert 3 in result["id"] # "Machine learning algorithms are fascinating", category=ml + + +def test_full_text_query_column_projection(tmp_path_factory) -> None: + """Test FTS with column projection.""" + dataset_path = build_text_dataset(tmp_path_factory) + + create_scalar_index(uri=dataset_path, column="text", index_type="INVERTED") + + search_term = "storage" + df = daft.read_lance( + dataset_path, + default_scan_options={"full_text_query": search_term}, + ) + result = df.select("id").to_pydict() + + assert result["id"] == [8] # "Lance format provides efficient storage" + + +def test_full_text_query_without_index(tmp_path_factory) -> None: + """Test FTS without an index should still work via brute-force scan.""" + dataset_path = build_text_dataset(tmp_path_factory) + + # Do NOT build an index - Lance should fall back to brute-force search + search_term = "Python" + df = daft.read_lance( + dataset_path, + default_scan_options={"full_text_query": search_term}, + ) + result = df.select("id").to_pydict() + + # Should still find the result even without an index + assert len(result["id"]) >= 1 + assert 2 in result["id"]