diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index 29444ba1b..a1ff94f98 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -26,9 +26,9 @@ env: WEAVIATE_132: 1.32.27 WEAVIATE_133: 1.33.18 WEAVIATE_134: 1.34.19 - WEAVIATE_135: 1.35.17 - WEAVIATE_136: 1.36.10 - WEAVIATE_137: 1.37.1 + WEAVIATE_135: 1.35.18 + WEAVIATE_136: 1.36.12 + WEAVIATE_137: 1.37.1-4e61e26.amd64 jobs: lint-and-format: diff --git a/integration/test_tokenize.py b/integration/test_tokenize.py index 97587235b..d2d46916d 100644 --- a/integration/test_tokenize.py +++ b/integration/test_tokenize.py @@ -1,11 +1,17 @@ """Integration tests for the tokenization module. These tests cover the client's responsibilities: -- Correct serialization of inputs (enums, _TextAnalyzerConfigCreate, _StopwordsCreate) -- Correct deserialization of responses into typed objects -- Client-side validation (_TextAnalyzerConfigCreate rejects invalid input) +- Correct serialization of inputs (enums, TextAnalyzerConfigCreate, StopwordsCreate) +- Correct deserialization of responses into the TokenizeResult object +- Client-side validation (TextAnalyzerConfigCreate, stopwords/stopword_presets mutex) - Version gate (>= 1.37.0) - Both sync and async client paths + +Server-side behavior this client relies on: +- Word tokenization defaults to preset "en" when no stopword config is sent. +- Both endpoints return only ``indexed`` and ``query``. +- ``stopwords`` and ``stopword_presets`` are mutually exclusive on the generic + endpoint — the server rejects requests that set both. """ from typing import AsyncGenerator, Generator @@ -14,17 +20,15 @@ import pytest_asyncio import weaviate -from weaviate.collections.classes.config import ( - StopwordsConfig, +from weaviate.classes.tokenization import ( + StopwordsCreate, StopwordsPreset, - TextAnalyzerConfig, + TextAnalyzerConfigCreate, Tokenization, - _StopwordsCreate, - _TextAnalyzerConfigCreate, + TokenizeResult, ) from weaviate.config import AdditionalConfig from weaviate.exceptions import WeaviateUnsupportedFeatureError -from weaviate.tokenization.models import TokenizeResult @pytest.fixture(scope="module") @@ -52,6 +56,29 @@ async def async_client() -> AsyncGenerator[weaviate.WeaviateAsyncClient, None]: await c.close() +@pytest.fixture +def recipe_collection(client: weaviate.WeaviateClient) -> Generator: + """Collection with a `recipe` word-tokenized property and an en + ["quick"] stopwords config.""" + name = "TestTokenizeRecipe" + client.collections.delete(name) + client.collections.create_from_dict( + { + "class": name, + "vectorizer": "none", + "invertedIndexConfig": { + "stopwords": {"preset": "en", "additions": ["quick"]}, + }, + "properties": [ + {"name": "recipe", "dataType": ["text"], "tokenization": "word"}, + ], + } + ) + try: + yield client.collections.get(name) + finally: + client.collections.delete(name) + + # --------------------------------------------------------------------------- # Serialization # --------------------------------------------------------------------------- @@ -62,13 +89,31 @@ class TestSerialization: """Verify the client correctly serializes different input forms.""" @pytest.mark.parametrize( - "tokenization,text,expected_tokens", + "tokenization,text,expected_indexed,expected_query", [ - (Tokenization.WORD, "The quick brown fox", ["the", "quick", "brown", "fox"]), - (Tokenization.LOWERCASE, "Hello World Test", ["hello", "world", "test"]), - (Tokenization.WHITESPACE, "Hello World Test", ["Hello", "World", "Test"]), - (Tokenization.FIELD, " Hello World ", ["Hello World"]), - (Tokenization.TRIGRAM, "Hello", ["hel", "ell", "llo"]), + # "the" is an English stopword — filtered from the query output + # by the server's default "en" preset for word tokenization. + ( + Tokenization.WORD, + "The quick brown fox", + ["the", "quick", "brown", "fox"], + ["quick", "brown", "fox"], + ), + # Non-word tokenizations do not apply the default "en" preset. + ( + Tokenization.LOWERCASE, + "Hello World Test", + ["hello", "world", "test"], + ["hello", "world", "test"], + ), + ( + Tokenization.WHITESPACE, + "Hello World Test", + ["Hello", "World", "Test"], + ["Hello", "World", "Test"], + ), + (Tokenization.FIELD, " Hello World ", ["Hello World"], ["Hello World"]), + (Tokenization.TRIGRAM, "Hello", ["hel", "ell", "llo"], ["hel", "ell", "llo"]), ], ) def test_tokenization_enum( @@ -76,97 +121,184 @@ def test_tokenization_enum( client: weaviate.WeaviateClient, tokenization: Tokenization, text: str, - expected_tokens: list, + expected_indexed: list, + expected_query: list, ) -> None: result = client.tokenization.text(text=text, tokenization=tokenization) assert isinstance(result, TokenizeResult) - assert result.tokenization == tokenization - assert result.indexed == expected_tokens - assert result.query == expected_tokens - - def test_no_analyzer_config(self, client: weaviate.WeaviateClient) -> None: - result = client.tokenization.text(text="hello world", tokenization=Tokenization.WORD) - assert result.tokenization == Tokenization.WORD - assert result.indexed == ["hello", "world"] - assert result.analyzer_config is None - - def test_ascii_fold(self, client: weaviate.WeaviateClient) -> None: - cfg = _TextAnalyzerConfigCreate(ascii_fold=True) - result = client.tokenization.text( - text="L'école est fermée", - tokenization=Tokenization.WORD, - analyzer_config=cfg, - ) - assert result.indexed == ["l", "ecole", "est", "fermee"] - - def test_ascii_fold_with_ignore(self, client: weaviate.WeaviateClient) -> None: - cfg = _TextAnalyzerConfigCreate(ascii_fold=True, ascii_fold_ignore=["é"]) - result = client.tokenization.text( - text="L'école est fermée", - tokenization=Tokenization.WORD, - analyzer_config=cfg, - ) - assert result.indexed == ["l", "école", "est", "fermée"] + assert result.indexed == expected_indexed + assert result.query == expected_query - def test_stopword_preset_enum(self, client: weaviate.WeaviateClient) -> None: - cfg = _TextAnalyzerConfigCreate(stopword_preset=StopwordsPreset.EN) - result = client.tokenization.text( - text="The quick brown fox", - tokenization=Tokenization.WORD, - analyzer_config=cfg, - ) - assert "the" not in result.query - assert "quick" in result.query + @pytest.mark.parametrize( + "call_kwargs,expected_indexed,expected_query", + [ + ( + {"text": "The quick brown fox"}, + ["the", "quick", "brown", "fox"], + ["quick", "brown", "fox"], + ), + ( + { + "text": "The quick brown fox", + "analyzer_config": TextAnalyzerConfigCreate( + stopword_preset=StopwordsPreset.NONE + ), + }, + ["the", "quick", "brown", "fox"], + ["the", "quick", "brown", "fox"], + ), + ( + { + "text": "L'école est fermée", + "analyzer_config": TextAnalyzerConfigCreate(ascii_fold=True), + }, + ["l", "ecole", "est", "fermee"], + ["l", "ecole", "est", "fermee"], + ), + ( + { + "text": "L'école est fermée", + "analyzer_config": TextAnalyzerConfigCreate( + ascii_fold=True, ascii_fold_ignore=["é"] + ), + }, + ["l", "école", "est", "fermée"], + ["l", "école", "est", "fermée"], + ), + ( + { + "text": "The quick brown fox", + "analyzer_config": TextAnalyzerConfigCreate(stopword_preset=StopwordsPreset.EN), + }, + ["the", "quick", "brown", "fox"], + ["quick", "brown", "fox"], + ), + ( + { + "text": "The quick brown fox", + "analyzer_config": TextAnalyzerConfigCreate(stopword_preset="en"), + }, + ["the", "quick", "brown", "fox"], + ["quick", "brown", "fox"], + ), + ( + { + "text": "The école est fermée", + "analyzer_config": TextAnalyzerConfigCreate( + ascii_fold=True, + ascii_fold_ignore=["é"], + stopword_preset=StopwordsPreset.EN, + ), + }, + ["the", "école", "est", "fermée"], + ["école", "est", "fermée"], + ), + ( + { + "text": "the quick brown fox", + "stopwords": StopwordsCreate( + preset=StopwordsPreset.EN, additions=["quick"], removals=None + ), + }, + ["the", "quick", "brown", "fox"], + ["brown", "fox"], + ), + ( + { + "text": "the quick hello world", + "stopwords": StopwordsCreate(preset=None, additions=["hello"], removals=None), + }, + ["the", "quick", "hello", "world"], + ["quick", "world"], + ), + ( + { + "text": "the quick is fast", + "stopwords": StopwordsCreate(preset=None, additions=None, removals=["the"]), + }, + ["the", "quick", "is", "fast"], + ["the", "quick", "fast"], + ), + ( + { + "text": "hello world test", + "analyzer_config": TextAnalyzerConfigCreate(stopword_preset="custom"), + "stopword_presets": {"custom": ["test"]}, + }, + ["hello", "world", "test"], + ["hello", "world"], + ), + ( + { + "text": "the quick hello world", + "stopword_presets": {"en": ["hello"]}, + }, + ["the", "quick", "hello", "world"], + ["the", "quick", "world"], + ), + ], + ids=[ + "default_en_applied_for_word", + "opt_out_of_default_en", + "ascii_fold", + "ascii_fold_with_ignore", + "stopword_preset_enum", + "stopword_preset_string", + "ascii_fold_combined_with_stopwords", + "stopwords_fallback", + "stopwords_additions_default_preset_to_en", + "stopwords_removals_default_preset_to_en", + "stopword_presets_named_reference", + "stopword_presets_override_builtin_en", + ], + ) + def test_text_tokenize( + self, + client: weaviate.WeaviateClient, + call_kwargs: dict, + expected_indexed: list, + expected_query: list, + ) -> None: + result = client.tokenization.text(tokenization=Tokenization.WORD, **call_kwargs) + assert isinstance(result, TokenizeResult) + assert result.indexed == expected_indexed + assert result.query == expected_query - def test_stopword_preset_string(self, client: weaviate.WeaviateClient) -> None: - cfg = _TextAnalyzerConfigCreate(stopword_preset="en") + def test_text_from_collection_config( + self, client: weaviate.WeaviateClient, recipe_collection + ) -> None: + """Values round-tripped through config.get() feed back into tokenization.text().""" + config = recipe_collection.config.get() + recipe = next(p for p in config.properties if p.name == "recipe") + stopwords = config.inverted_index_config.stopwords result = client.tokenization.text( - text="The quick brown fox", - tokenization=Tokenization.WORD, - analyzer_config=cfg, + text="the quick brown fox", + tokenization=recipe.tokenization, + stopwords=stopwords, ) - assert "the" not in result.query + assert result.indexed == ["the", "quick", "brown", "fox"] + assert result.query == ["brown", "fox"] - def test_ascii_fold_combined_with_stopwords(self, client: weaviate.WeaviateClient) -> None: - cfg = _TextAnalyzerConfigCreate( - ascii_fold=True, ascii_fold_ignore=["é"], stopword_preset=StopwordsPreset.EN - ) - result = client.tokenization.text( - text="The école est fermée", - tokenization=Tokenization.WORD, - analyzer_config=cfg, + def test_property_and_generic_endpoints_agree( + self, client: weaviate.WeaviateClient, recipe_collection + ) -> None: + """Property endpoint (server resolves config from schema) produces the same indexed/query as the generic endpoint fed the same config.""" + config = recipe_collection.config.get() + recipe = next(p for p in config.properties if p.name == "recipe") + stopwords = config.inverted_index_config.stopwords + + text = "the quick brown fox" + via_property = client.tokenization.for_property( + collection=recipe_collection.name, property_name="recipe", text=text ) - assert result.indexed == ["the", "école", "est", "fermée"] - assert "the" not in result.query - assert "école" in result.query - - def test_stopword_presets_custom_additions(self, client: weaviate.WeaviateClient) -> None: - cfg = _TextAnalyzerConfigCreate(stopword_preset="custom") - result = client.tokenization.text( - text="hello world test", - tokenization=Tokenization.WORD, - analyzer_config=cfg, - stopword_presets={ - "custom": _StopwordsCreate(preset=None, additions=["test"], removals=None), - }, + via_generic = client.tokenization.text( + text=text, + tokenization=recipe.tokenization, + stopwords=stopwords, ) - assert result.indexed == ["hello", "world", "test"] - assert result.query == ["hello", "world"] - def test_stopword_presets_with_base_and_removals(self, client: weaviate.WeaviateClient) -> None: - cfg = _TextAnalyzerConfigCreate(stopword_preset="en-no-the") - result = client.tokenization.text( - text="the quick", - tokenization=Tokenization.WORD, - analyzer_config=cfg, - stopword_presets={ - "en-no-the": _StopwordsCreate( - preset=StopwordsPreset.EN, additions=None, removals=["the"] - ), - }, - ) - assert result.indexed == ["the", "quick"] - assert result.query == ["the", "quick"] + assert via_property.indexed == via_generic.indexed + assert via_property.query == via_generic.query # --------------------------------------------------------------------------- @@ -176,61 +308,10 @@ def test_stopword_presets_with_base_and_removals(self, client: weaviate.Weaviate @pytest.mark.usefixtures("require_1_37") class TestDeserialization: - """Verify the client correctly deserializes response fields into typed objects.""" - - def test_result_type(self, client: weaviate.WeaviateClient) -> None: - result = client.tokenization.text(text="hello", tokenization=Tokenization.WORD) - assert isinstance(result, TokenizeResult) - assert isinstance(result.indexed, list) - assert isinstance(result.query, list) + """Verify the client correctly deserializes response fields into TokenizeResult.""" - def test_analyzer_config_deserialized(self, client: weaviate.WeaviateClient) -> None: - cfg = _TextAnalyzerConfigCreate( - ascii_fold=True, ascii_fold_ignore=["é"], stopword_preset=StopwordsPreset.EN - ) - result = client.tokenization.text( - text="L'école", - tokenization=Tokenization.WORD, - analyzer_config=cfg, - ) - assert isinstance(result.analyzer_config, TextAnalyzerConfig) - assert result.analyzer_config.ascii_fold is True - assert result.analyzer_config.ascii_fold_ignore == ["é"] - assert result.analyzer_config.stopword_preset == "en" - - def test_no_analyzer_config_returns_none(self, client: weaviate.WeaviateClient) -> None: - result = client.tokenization.text(text="hello", tokenization=Tokenization.WORD) - assert result.analyzer_config is None - - def test_stopword_config_deserialized_on_property( - self, client: weaviate.WeaviateClient - ) -> None: - client.collections.delete("TestDeserStopword") - try: - client.collections.create_from_dict( - { - "class": "TestDeserStopword", - "vectorizer": "none", - "properties": [ - { - "name": "title", - "dataType": ["text"], - "tokenization": "word", - "textAnalyzer": {"stopwordPreset": "en"}, - }, - ], - } - ) - col = client.collections.get("TestDeserStopword") - result = col.config.tokenize_property(property_name="title", text="the quick") - assert isinstance(result, TokenizeResult) - assert result.tokenization == Tokenization.WORD - if result.stopword_config is not None: - assert isinstance(result.stopword_config, StopwordsConfig) - finally: - client.collections.delete("TestDeserStopword") - - def test_property_result_types(self, client: weaviate.WeaviateClient) -> None: + def test_property_result_shape(self, client: weaviate.WeaviateClient) -> None: + """Property endpoint response deserializes into TokenizeResult — server resolves tokenization from the property's schema.""" client.collections.delete("TestDeserPropTypes") try: client.collections.create_from_dict( @@ -246,50 +327,105 @@ def test_property_result_types(self, client: weaviate.WeaviateClient) -> None: ], } ) - col = client.collections.get("TestDeserPropTypes") - result = col.config.tokenize_property(property_name="tag", text=" Hello World ") + result = client.tokenization.for_property( + collection="TestDeserPropTypes", property_name="tag", text=" Hello World " + ) assert isinstance(result, TokenizeResult) - assert result.tokenization == Tokenization.FIELD assert result.indexed == ["Hello World"] finally: client.collections.delete("TestDeserPropTypes") # --------------------------------------------------------------------------- -# Client-side validation (_TextAnalyzerConfigCreate) +# Client-side validation # --------------------------------------------------------------------------- class TestClientSideValidation: - """Verify that _TextAnalyzerConfigCreate rejects invalid input before hitting the server.""" - - def test_ascii_fold_ignore_without_fold_raises(self) -> None: - with pytest.raises(ValueError, match="asciiFoldIgnore"): - _TextAnalyzerConfigCreate(ascii_fold=False, ascii_fold_ignore=["é"]) + """Verify that client-side validation rejects invalid input before hitting the server.""" - def test_ascii_fold_ignore_without_fold_default_raises(self) -> None: + @pytest.mark.parametrize( + "kwargs", + [ + {"ascii_fold": False, "ascii_fold_ignore": ["é"]}, + {"ascii_fold_ignore": ["é"]}, + ], + ids=["explicit_false", "default"], + ) + def test_ascii_fold_ignore_without_fold_raises(self, kwargs: dict) -> None: with pytest.raises(ValueError, match="asciiFoldIgnore"): - _TextAnalyzerConfigCreate(ascii_fold_ignore=["é"]) - - def test_valid_config_does_not_raise(self) -> None: - cfg = _TextAnalyzerConfigCreate(ascii_fold=True, ascii_fold_ignore=["é", "ñ"]) - assert cfg.asciiFold is True - assert cfg.asciiFoldIgnore == ["é", "ñ"] - - def test_fold_without_ignore_is_valid(self) -> None: - cfg = _TextAnalyzerConfigCreate(ascii_fold=True) - assert cfg.asciiFold is True - assert cfg.asciiFoldIgnore is None + TextAnalyzerConfigCreate(**kwargs) - def test_stopword_preset_only_is_valid(self) -> None: - cfg = _TextAnalyzerConfigCreate(stopword_preset="en") - assert cfg.stopwordPreset == "en" + @pytest.mark.parametrize( + "kwargs,expected", + [ + ( + {"ascii_fold": True, "ascii_fold_ignore": ["é", "ñ"]}, + {"asciiFold": True, "asciiFoldIgnore": ["é", "ñ"]}, + ), + ( + {"ascii_fold": True}, + {"asciiFold": True, "asciiFoldIgnore": None}, + ), + ( + {"stopword_preset": "en"}, + {"stopwordPreset": "en"}, + ), + ( + {}, + {"asciiFold": None, "asciiFoldIgnore": None, "stopwordPreset": None}, + ), + ], + ids=["fold_with_ignore", "fold_without_ignore", "stopword_preset_only", "empty"], + ) + def test_valid_config(self, kwargs: dict, expected: dict) -> None: + cfg = TextAnalyzerConfigCreate(**kwargs) + for attr, value in expected.items(): + assert getattr(cfg, attr) == value + + def test_stopwords_and_stopword_presets_mutex(self, client: weaviate.WeaviateClient) -> None: + """Client rejects the mutex violation locally with ValueError, before sending the request (which the server would also reject with 422).""" + if client._connection._weaviate_version.is_lower_than(1, 37, 0): + pytest.skip("Tokenization requires Weaviate >= 1.37.0") + with pytest.raises(ValueError, match="mutually exclusive"): + client.tokenization.text( + text="hello", + tokenization=Tokenization.WORD, + stopwords=StopwordsCreate(preset=StopwordsPreset.EN, additions=None, removals=None), + stopword_presets={"custom": ["hello"]}, + ) - def test_empty_config_is_valid(self) -> None: - cfg = _TextAnalyzerConfigCreate() - assert cfg.asciiFold is None - assert cfg.asciiFoldIgnore is None - assert cfg.stopwordPreset is None + @pytest.mark.parametrize( + "stopword_presets,match", + [ + ({"custom": "hello"}, "must be a list of strings"), + ( + { + "custom": StopwordsCreate( + preset=StopwordsPreset.EN, additions=None, removals=None + ), + }, + "must be a list of strings", + ), + ({"custom": ["hello", 123]}, "must contain only strings"), + ], + ids=["str_value", "pydantic_model_value", "non_string_element"], + ) + def test_stopword_presets_invalid_shape_raises( + self, + client: weaviate.WeaviateClient, + stopword_presets: dict, + match: str, + ) -> None: + """Client rejects malformed stopword_presets values locally before sending — str would silently split into characters; a pydantic model would serialize to field tuples.""" + if client._connection._weaviate_version.is_lower_than(1, 37, 0): + pytest.skip("Tokenization requires Weaviate >= 1.37.0") + with pytest.raises(ValueError, match=match): + client.tokenization.text( + text="hello", + tokenization=Tokenization.WORD, + stopword_presets=stopword_presets, + ) # --------------------------------------------------------------------------- @@ -309,9 +445,8 @@ def test_text_raises_on_old_server(self, client: weaviate.WeaviateClient) -> Non def test_tokenize_property_raises_on_old_server(self, client: weaviate.WeaviateClient) -> None: if client._connection._weaviate_version.is_at_least(1, 37, 0): pytest.skip("Version gate only applies to Weaviate < 1.37.0") - col = client.collections.get("Any") with pytest.raises(WeaviateUnsupportedFeatureError): - col.config.tokenize_property(property_name="title", text="hello") + client.tokenization.for_property(collection="Any", property_name="title", text="hello") # --------------------------------------------------------------------------- @@ -321,7 +456,7 @@ def test_tokenize_property_raises_on_old_server(self, client: weaviate.WeaviateC @pytest.mark.usefixtures("require_1_37") class TestAsyncClient: - """Verify text() and tokenize_property() work through the async client.""" + """Verify tokenization.text() and tokenization.for_property() work through the async client.""" @pytest.mark.asyncio async def test_text_tokenize(self, async_client: weaviate.WeaviateAsyncClient) -> None: @@ -331,20 +466,21 @@ async def test_text_tokenize(self, async_client: weaviate.WeaviateAsyncClient) - ) assert isinstance(result, TokenizeResult) assert result.indexed == ["the", "quick", "brown", "fox"] + # default "en" applied server-side. + assert result.query == ["quick", "brown", "fox"] @pytest.mark.asyncio - async def test_text_with_analyzer_config( + async def test_text_with_stopwords_fallback( self, async_client: weaviate.WeaviateAsyncClient ) -> None: - cfg = _TextAnalyzerConfigCreate(ascii_fold=True, stopword_preset=StopwordsPreset.EN) + sw = StopwordsCreate(preset=StopwordsPreset.EN, additions=["quick"], removals=None) result = await async_client.tokenization.text( - text="L'école est fermée", + text="the quick brown fox", tokenization=Tokenization.WORD, - analyzer_config=cfg, + stopwords=sw, ) - assert result.indexed == ["l", "ecole", "est", "fermee"] - assert isinstance(result.analyzer_config, TextAnalyzerConfig) - assert result.analyzer_config.ascii_fold is True + assert result.indexed == ["the", "quick", "brown", "fox"] + assert result.query == ["brown", "fox"] @pytest.mark.asyncio async def test_property_tokenize(self, async_client: weaviate.WeaviateAsyncClient) -> None: @@ -364,15 +500,13 @@ async def test_property_tokenize(self, async_client: weaviate.WeaviateAsyncClien ], } ) - col = async_client.collections.get("TestAsyncPropTokenize") - result = await col.config.tokenize_property( + result = await async_client.tokenization.for_property( + collection="TestAsyncPropTokenize", property_name="title", text="The quick brown fox", ) assert isinstance(result, TokenizeResult) - assert result.tokenization == Tokenization.WORD assert result.indexed == ["the", "quick", "brown", "fox"] - assert "the" not in result.query - assert "quick" in result.query + assert result.query == ["quick", "brown", "fox"] finally: await async_client.collections.delete("TestAsyncPropTokenize") diff --git a/weaviate/classes/__init__.py b/weaviate/classes/__init__.py index d495744ac..69af5d920 100644 --- a/weaviate/classes/__init__.py +++ b/weaviate/classes/__init__.py @@ -13,6 +13,7 @@ rbac, replication, tenants, + tokenization, ) # noqa: F401 from .config import ConsistencyLevel @@ -29,6 +30,7 @@ "init", "query", "tenants", + "tokenization", "rbac", "replication", ] diff --git a/weaviate/classes/config.py b/weaviate/classes/config.py index 868cd1c79..c154062d3 100644 --- a/weaviate/classes/config.py +++ b/weaviate/classes/config.py @@ -11,8 +11,10 @@ ReferenceProperty, ReplicationDeletionStrategy, Rerankers, + StopwordsCreate, StopwordsPreset, TextAnalyzerConfig, + TextAnalyzerConfigCreate, Tokenization, VectorDistances, ) @@ -39,8 +41,10 @@ "PQEncoderType", "ReferenceProperty", "Rerankers", + "StopwordsCreate", "StopwordsPreset", "TextAnalyzerConfig", + "TextAnalyzerConfigCreate", "Tokenization", "Vectorizers", "VectorDistances", diff --git a/weaviate/classes/tokenization.py b/weaviate/classes/tokenization.py new file mode 100644 index 000000000..0e89fc64b --- /dev/null +++ b/weaviate/classes/tokenization.py @@ -0,0 +1,17 @@ +from weaviate.collections.classes.config import ( + StopwordsConfig, + StopwordsCreate, + StopwordsPreset, + TextAnalyzerConfigCreate, + Tokenization, +) +from weaviate.tokenization.models import TokenizeResult + +__all__ = [ + "StopwordsConfig", + "StopwordsCreate", + "StopwordsPreset", + "TextAnalyzerConfigCreate", + "Tokenization", + "TokenizeResult", +] diff --git a/weaviate/collections/classes/config.py b/weaviate/collections/classes/config.py index 6d60482a3..43d86375d 100644 --- a/weaviate/collections/classes/config.py +++ b/weaviate/collections/classes/config.py @@ -1,5 +1,6 @@ import datetime from dataclasses import dataclass +from dataclasses import fields as _dataclass_fields from typing import ( Any, ClassVar, @@ -1647,6 +1648,26 @@ class _StopwordsConfig(_ConfigBase): StopwordsConfig = _StopwordsConfig +StopwordsCreate = _StopwordsCreate + +# Invariant: the read-side dataclass (_StopwordsConfig) and the write-side +# pydantic model (_StopwordsCreate) must carry the same set of field names so +# that values round-tripped from ``collection.config.get()`` can flow back into +# ``tokenization.text()`` without silent data loss. If a field is added to one +# but not the other, importing this module fails loudly; the read→write +# conversion in ``weaviate/tokenization/executor.py::_TokenizationExecutor.text`` +# depends on this parity. +_read_fields = {f.name for f in _dataclass_fields(_StopwordsConfig)} +_write_fields = set(_StopwordsCreate.model_fields.keys()) +if _read_fields != _write_fields: + raise RuntimeError( + "_StopwordsConfig / _StopwordsCreate field drift detected — " + f"read-only={_read_fields - _write_fields}, " + f"write-only={_write_fields - _read_fields}. " + "Update both classes together, or adapt the read→write conversion in " + "weaviate/tokenization/executor.py::_TokenizationExecutor.text." + ) +del _read_fields, _write_fields @dataclass @@ -2224,6 +2245,9 @@ def _validate_ascii_fold_ignore(self) -> "_TextAnalyzerConfigCreate": return self +TextAnalyzerConfigCreate = _TextAnalyzerConfigCreate + + class Property(_ConfigCreateModel): """This class defines the structure of a data property that a collection can have within Weaviate. diff --git a/weaviate/collections/config/async_.pyi b/weaviate/collections/config/async_.pyi index a1f740ded..015b70dab 100644 --- a/weaviate/collections/config/async_.pyi +++ b/weaviate/collections/config/async_.pyi @@ -27,7 +27,6 @@ from weaviate.collections.classes.config import ( from weaviate.collections.classes.config_object_ttl import _ObjectTTLConfigUpdate from weaviate.collections.classes.config_vector_index import _VectorIndexConfigDynamicUpdate from weaviate.connect.v4 import ConnectionAsync -from weaviate.tokenization.models import TokenizeResult from .executor import _ConfigCollectionExecutor @@ -91,4 +90,3 @@ class _ConfigCollectionAsync(_ConfigCollectionExecutor[ConnectionAsync]): self, *, vector_config: Union[_VectorConfigCreate, List[_VectorConfigCreate]] ) -> None: ... async def delete_property_index(self, property_name: str, index_name: IndexName) -> bool: ... - async def tokenize_property(self, property_name: str, text: str) -> TokenizeResult: ... diff --git a/weaviate/collections/config/executor.py b/weaviate/collections/config/executor.py index fe9f5ec0d..103ab70ac 100644 --- a/weaviate/collections/config/executor.py +++ b/weaviate/collections/config/executor.py @@ -56,7 +56,6 @@ WeaviateInvalidInputError, WeaviateUnsupportedFeatureError, ) -from weaviate.tokenization.models import TokenizeResult from weaviate.util import ( _capitalize_first_letter, _decode_json_response_dict, @@ -667,42 +666,3 @@ def resp(res: Response) -> bool: error_msg="Property may not exist", status_codes=_ExpectedStatusCodes(ok_in=[200], error="property exists"), ) - - def tokenize_property( - self, - property_name: str, - text: str, - ) -> executor.Result[TokenizeResult]: - """Tokenize text using a property's configured tokenization settings. - - Args: - property_name: The property name whose tokenization config to use. - text: The text to tokenize. - - Returns: - A TokenizeResult with indexed and query token lists. - - Raises: - WeaviateUnsupportedFeatureError: If the server version is below 1.37.0. - """ - if self._connection._weaviate_version.is_lower_than(1, 37, 0): - raise WeaviateUnsupportedFeatureError( - "Tokenization", - str(self._connection._weaviate_version), - "1.37.0", - ) - - path = f"/schema/{self._name}/properties/{property_name}/tokenize" - payload: Dict[str, Any] = {"text": text} - - def resp(response: Response) -> TokenizeResult: - return TokenizeResult.model_validate(response.json()) - - return executor.execute( - response_callback=resp, - method=self._connection.post, - path=path, - weaviate_object=payload, - error_msg="Property tokenization failed", - status_codes=_ExpectedStatusCodes(ok_in=[200], error="tokenize property text"), - ) diff --git a/weaviate/collections/config/sync.pyi b/weaviate/collections/config/sync.pyi index 3664a0e1b..e54d8c8fc 100644 --- a/weaviate/collections/config/sync.pyi +++ b/weaviate/collections/config/sync.pyi @@ -27,7 +27,6 @@ from weaviate.collections.classes.config import ( from weaviate.collections.classes.config_object_ttl import _ObjectTTLConfigUpdate from weaviate.collections.classes.config_vector_index import _VectorIndexConfigDynamicUpdate from weaviate.connect.v4 import ConnectionSync -from weaviate.tokenization.models import TokenizeResult from .executor import _ConfigCollectionExecutor @@ -89,4 +88,3 @@ class _ConfigCollection(_ConfigCollectionExecutor[ConnectionSync]): self, *, vector_config: Union[_VectorConfigCreate, List[_VectorConfigCreate]] ) -> None: ... def delete_property_index(self, property_name: str, index_name: IndexName) -> bool: ... - def tokenize_property(self, property_name: str, text: str) -> TokenizeResult: ... diff --git a/weaviate/tokenization/async_.pyi b/weaviate/tokenization/async_.pyi index d5b1ab12c..ba12abc2a 100644 --- a/weaviate/tokenization/async_.pyi +++ b/weaviate/tokenization/async_.pyi @@ -1,9 +1,10 @@ -from typing import Dict, Optional +from typing import Dict, List, Optional, Union, overload from weaviate.collections.classes.config import ( + StopwordsConfig, + StopwordsCreate, + TextAnalyzerConfigCreate, Tokenization, - _StopwordsCreate, - _TextAnalyzerConfigCreate, ) from weaviate.connect.v4 import ConnectionAsync from weaviate.tokenization.models import TokenizeResult @@ -11,11 +12,24 @@ from weaviate.tokenization.models import TokenizeResult from .executor import _TokenizationExecutor class _TokenizationAsync(_TokenizationExecutor[ConnectionAsync]): + @overload async def text( self, text: str, tokenization: Tokenization, *, - analyzer_config: Optional[_TextAnalyzerConfigCreate] = None, - stopword_presets: Optional[Dict[str, _StopwordsCreate]] = None, + analyzer_config: Optional[TextAnalyzerConfigCreate] = ..., + stopwords: Optional[Union[StopwordsCreate, StopwordsConfig]] = ..., + ) -> TokenizeResult: ... + @overload + async def text( + self, + text: str, + tokenization: Tokenization, + *, + analyzer_config: Optional[TextAnalyzerConfigCreate] = ..., + stopword_presets: Optional[Dict[str, List[str]]] = ..., + ) -> TokenizeResult: ... + async def for_property( + self, collection: str, property_name: str, text: str ) -> TokenizeResult: ... diff --git a/weaviate/tokenization/executor.py b/weaviate/tokenization/executor.py index 3198c8e65..33f1c05f9 100644 --- a/weaviate/tokenization/executor.py +++ b/weaviate/tokenization/executor.py @@ -1,18 +1,20 @@ """Tokenize executor.""" -from typing import Any, Dict, Generic, Optional +from typing import Any, Dict, Generic, List, Optional, Union, overload from httpx import Response from weaviate.collections.classes.config import ( + StopwordsConfig, + StopwordsCreate, + TextAnalyzerConfigCreate, Tokenization, - _StopwordsCreate, - _TextAnalyzerConfigCreate, ) from weaviate.connect import executor from weaviate.connect.v4 import ConnectionType, _ExpectedStatusCodes from weaviate.exceptions import WeaviateUnsupportedFeatureError from weaviate.tokenization.models import TokenizeResult +from weaviate.util import _capitalize_first_letter class _TokenizationExecutor(Generic[ConnectionType]): @@ -27,32 +29,109 @@ def __check_version(self) -> None: "1.37.0", ) + # Overloads make ``stopwords`` and ``stopword_presets`` mutually exclusive + # at type-check time. Passing both is additionally rejected at runtime with + # ``ValueError`` in the implementation below. ``stopwords`` accepts either a + # ``StopwordsCreate`` (the write-side shape) or a ``StopwordsConfig`` (the + # read-side shape returned by ``collection.config.get()``), so values round- + # tripped through config reads can be passed back in directly. + @overload def text( self, text: str, tokenization: Tokenization, *, - analyzer_config: Optional[_TextAnalyzerConfigCreate] = None, - stopword_presets: Optional[Dict[str, _StopwordsCreate]] = None, + analyzer_config: Optional[TextAnalyzerConfigCreate] = ..., + stopwords: Optional[Union[StopwordsCreate, StopwordsConfig]] = ..., + ) -> executor.Result[TokenizeResult]: ... + + @overload + def text( + self, + text: str, + tokenization: Tokenization, + *, + analyzer_config: Optional[TextAnalyzerConfigCreate] = ..., + stopword_presets: Optional[Dict[str, List[str]]] = ..., + ) -> executor.Result[TokenizeResult]: ... + + def text( + self, + text: str, + tokenization: Tokenization, + *, + analyzer_config: Optional[TextAnalyzerConfigCreate] = None, + stopwords: Optional[Union[StopwordsCreate, StopwordsConfig]] = None, + stopword_presets: Optional[Dict[str, List[str]]] = None, ) -> executor.Result[TokenizeResult]: """Tokenize text using the generic /v1/tokenize endpoint. + For ``word`` tokenization the server defaults to the built-in ``en`` + stopword preset when no stopword configuration is supplied. Pass + ``analyzer_config=TextAnalyzerConfigCreate(stopword_preset="none")`` + or equivalent to opt out. + + Call patterns for stopword handling (``stopwords`` and + ``stopword_presets`` are mutually exclusive — pass at most one): + + 1. **No stopword config** — rely on the server default (``en`` for + word tokenization, none otherwise):: + + client.tokenization.text(text=..., tokenization=Tokenization.WORD) + + 2. **Apply a one-off stopwords block** via ``stopwords`` — the block + filters the query tokens directly, same shape as a collection's + ``invertedIndexConfig.stopwords``:: + + client.tokenization.text( + text=..., + tokenization=Tokenization.WORD, + stopwords=StopwordsCreate(preset=StopwordsPreset.EN, additions=["foo"]), + ) + + 3. **Register a named-preset catalog** via ``stopword_presets`` and + reference one by name from ``analyzer_config.stopword_preset``. + The catalog can also override built-in presets such as ``en``:: + + client.tokenization.text( + text=..., + tokenization=Tokenization.WORD, + analyzer_config=TextAnalyzerConfigCreate(stopword_preset="custom"), + stopword_presets={"custom": ["foo", "bar"]}, + ) + Args: text: The text to tokenize. - tokenization: The tokenization method to use (e.g. Tokenization.WORD). - analyzer_config: Text analyzer settings (ASCII folding, stopword preset). - stopword_presets: Custom stopword preset definitions, keyed by name. - Each value is a ``_StopwordsCreate`` with optional preset, additions, - and removals fields. + tokenization: The tokenization method to use (e.g. ``Tokenization.WORD``). + analyzer_config: Text analyzer settings (ASCII folding, stopword + preset name), built via ``Configure.text_analyzer(...)``. + ``stopword_preset`` may reference a built-in preset + (``en`` / ``none``) or a name defined in ``stopword_presets``. + stopwords: One-off stopwords block applied directly to this request. + Mirrors the collection-level ``invertedIndexConfig.stopwords`` + shape — hence the rich model with preset / additions / removals. + Mutually exclusive with ``stopword_presets``. + stopword_presets: Named-preset catalog (name → word list). Mirrors + the property-level preset catalog — a plain mapping, since a + property only references a preset by name (via + ``analyzer_config.stopword_preset``) rather than carrying the + full stopwords block. Entries can override built-ins like + ``en``. Mutually exclusive with ``stopwords``. Returns: - A TokenizeResult with indexed and query token lists. + A ``TokenizeResult`` with indexed and query token lists. The generic + endpoint does not echo request fields back in the response. Raises: WeaviateUnsupportedFeatureError: If the server version is below 1.37.0. + ValueError: If both ``stopwords`` and ``stopword_presets`` are passed, + or if any ``stopword_presets`` value is not a list/tuple of strings. """ self.__check_version() + if stopwords is not None and stopword_presets is not None: + raise ValueError("stopwords and stopword_presets are mutually exclusive; pass only one") + payload: Dict[str, Any] = { "text": text, "tokenization": tokenization.value, @@ -63,10 +142,42 @@ def text( if ac_dict: payload["analyzerConfig"] = ac_dict + if stopwords is not None: + if isinstance(stopwords, StopwordsConfig): + # Widen from the read-side shape returned by config.get() to the + # write-side shape the server expects. Field parity between the + # two classes is enforced at import time in + # ``weaviate/collections/classes/config.py``, so iterating + # ``StopwordsCreate.model_fields`` copies every field. + stopwords = StopwordsCreate( + **{name: getattr(stopwords, name) for name in StopwordsCreate.model_fields} + ) + sw_dict = stopwords._to_dict() + if sw_dict: + payload["stopwords"] = sw_dict + if stopword_presets is not None: - payload["stopwordPresets"] = { - name: cfg._to_dict() for name, cfg in stopword_presets.items() - } + # Plain word-list shape matching a collection's + # invertedIndexConfig.stopwordPresets. Reject str (would + # silently split into characters) and pydantic models / + # other non-sequence shapes up-front so callers get a clear + # error instead of a malformed payload. + validated: Dict[str, List[str]] = {} + for name, words in stopword_presets.items(): + if isinstance(words, (str, bytes)): + raise ValueError( + f"stopword_presets[{name!r}] must be a list of strings, " + f"got {type(words).__name__}" + ) + if not isinstance(words, (list, tuple)): + raise ValueError( + f"stopword_presets[{name!r}] must be a list of strings, " + f"got {type(words).__name__}" + ) + if not all(isinstance(w, str) for w in words): + raise ValueError(f"stopword_presets[{name!r}] must contain only strings") + validated[name] = list(words) + payload["stopwordPresets"] = validated def resp(response: Response) -> TokenizeResult: return TokenizeResult.model_validate(response.json()) @@ -79,3 +190,42 @@ def resp(response: Response) -> TokenizeResult: error_msg="Tokenization failed", status_codes=_ExpectedStatusCodes(ok_in=[200], error="tokenize text"), ) + + def for_property( + self, + collection: str, + property_name: str, + text: str, + ) -> executor.Result[TokenizeResult]: + """Tokenize text using a property's configured tokenization settings. + + The server resolves the tokenization and analyzer configuration from + the property's schema, so callers only supply the text. + + Args: + collection: The collection that owns the property. + property_name: The property name whose tokenization config to use. + text: The text to tokenize. + + Returns: + A TokenizeResult with indexed and query token lists. + + Raises: + WeaviateUnsupportedFeatureError: If the server version is below 1.37.0. + """ + self.__check_version() + + path = f"/schema/{_capitalize_first_letter(collection)}/properties/{property_name}/tokenize" + payload: Dict[str, Any] = {"text": text} + + def resp(response: Response) -> TokenizeResult: + return TokenizeResult.model_validate(response.json()) + + return executor.execute( + response_callback=resp, + method=self._connection.post, + path=path, + weaviate_object=payload, + error_msg="Property tokenization failed", + status_codes=_ExpectedStatusCodes(ok_in=[200], error="tokenize property text"), + ) diff --git a/weaviate/tokenization/models.py b/weaviate/tokenization/models.py index 8bfa508f8..baeac140c 100644 --- a/weaviate/tokenization/models.py +++ b/weaviate/tokenization/models.py @@ -1,56 +1,17 @@ """Return types for tokenization operations.""" -from typing import Any, Dict, List, Optional +from typing import List -from pydantic import BaseModel, ConfigDict, Field, field_validator - -from weaviate.collections.classes.config import ( - StopwordsConfig, - StopwordsPreset, - TextAnalyzerConfig, - Tokenization, -) +from pydantic import BaseModel class TokenizeResult(BaseModel): """Result of a tokenization operation. Attributes: - tokenization: The tokenization method that was applied. indexed: Tokens as they would be stored in the inverted index. query: Tokens as they would be used for querying (after stopword removal). - analyzer_config: The text analyzer configuration that was used, if any. - stopword_config: The stopword configuration that was used, if any. """ - model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True) - - tokenization: Tokenization indexed: List[str] query: List[str] - analyzer_config: Optional[TextAnalyzerConfig] = Field(default=None, alias="analyzerConfig") - stopword_config: Optional[StopwordsConfig] = Field(default=None, alias="stopwordConfig") - - @field_validator("analyzer_config", mode="before") - @classmethod - def _parse_analyzer_config(cls, v: Optional[Dict[str, Any]]) -> Optional[TextAnalyzerConfig]: - if v is None: - return None - if "asciiFold" not in v and "stopwordPreset" not in v: - return None - return TextAnalyzerConfig( - ascii_fold=v.get("asciiFold", False), - ascii_fold_ignore=v.get("asciiFoldIgnore"), - stopword_preset=v.get("stopwordPreset"), - ) - - @field_validator("stopword_config", mode="before") - @classmethod - def _parse_stopword_config(cls, v: Optional[Dict[str, Any]]) -> Optional[StopwordsConfig]: - if v is None: - return None - return StopwordsConfig( - preset=StopwordsPreset(v["preset"]), - additions=v.get("additions"), - removals=v.get("removals"), - ) diff --git a/weaviate/tokenization/sync.pyi b/weaviate/tokenization/sync.pyi index af30ef103..71aaaea5c 100644 --- a/weaviate/tokenization/sync.pyi +++ b/weaviate/tokenization/sync.pyi @@ -1,9 +1,10 @@ -from typing import Dict, Optional +from typing import Dict, List, Optional, Union, overload from weaviate.collections.classes.config import ( + StopwordsConfig, + StopwordsCreate, + TextAnalyzerConfigCreate, Tokenization, - _StopwordsCreate, - _TextAnalyzerConfigCreate, ) from weaviate.connect.v4 import ConnectionSync from weaviate.tokenization.models import TokenizeResult @@ -11,11 +12,22 @@ from weaviate.tokenization.models import TokenizeResult from .executor import _TokenizationExecutor class _Tokenization(_TokenizationExecutor[ConnectionSync]): + @overload def text( self, text: str, tokenization: Tokenization, *, - analyzer_config: Optional[_TextAnalyzerConfigCreate] = None, - stopword_presets: Optional[Dict[str, _StopwordsCreate]] = None, + analyzer_config: Optional[TextAnalyzerConfigCreate] = ..., + stopwords: Optional[Union[StopwordsCreate, StopwordsConfig]] = ..., ) -> TokenizeResult: ... + @overload + def text( + self, + text: str, + tokenization: Tokenization, + *, + analyzer_config: Optional[TextAnalyzerConfigCreate] = ..., + stopword_presets: Optional[Dict[str, List[str]]] = ..., + ) -> TokenizeResult: ... + def for_property(self, collection: str, property_name: str, text: str) -> TokenizeResult: ...