Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 109 additions & 0 deletions integration/test_collection_diversity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import pytest

from integration.conftest import CollectionFactory
from weaviate.classes.query import Diversity
from weaviate.collections.classes.config import Configure, DataType, Property
from weaviate.collections.classes.data import DataObject


def _create_clustered_collection(collection_factory: CollectionFactory):
"""Create a collection with 3 tight clusters (a, b, c) of vectors in 3D."""
collection = collection_factory(
properties=[Property(name="text", data_type=DataType.TEXT)],
vectorizer_config=Configure.Vectorizer.none(),
)
if collection._connection._weaviate_version.is_lower_than(1, 37, 0):
pytest.skip("Diversity selection requires Weaviate >= 1.37.0")
collection.data.insert_many(
[
DataObject(properties={"text": "a1"}, vector=[1.0, 0.0, 0.0]),
DataObject(properties={"text": "a2"}, vector=[0.95, 0.05, 0.0]),
DataObject(properties={"text": "a3"}, vector=[0.9, 0.1, 0.0]),
DataObject(properties={"text": "b1"}, vector=[0.0, 1.0, 0.0]),
DataObject(properties={"text": "b2"}, vector=[0.05, 0.95, 0.0]),
DataObject(properties={"text": "c1"}, vector=[0.0, 0.0, 1.0]),
]
)
return collection


def test_near_vector_diversity_pure_relevance(
collection_factory: CollectionFactory,
) -> None:
"""balance=1.0 -> MMR degenerates to pure relevance (same as no diversity)."""
collection = _create_clustered_collection(collection_factory)

baseline = collection.query.near_vector(near_vector=[1.0, 0.0, 0.0], limit=3).objects
diverse = collection.query.near_vector(
near_vector=[1.0, 0.0, 0.0],
selection=Diversity.MMR(limit=3, balance=1.0),
).objects

assert [o.properties["text"] for o in baseline] == [o.properties["text"] for o in diverse]


def test_near_vector_diversity_pure_diversity(
collection_factory: CollectionFactory,
) -> None:
"""balance=0.0 -> MMR picks maximally diverse results (one per cluster)."""
collection = _create_clustered_collection(collection_factory)

result = collection.query.near_vector(
near_vector=[1.0, 0.0, 0.0],
selection=Diversity.MMR(limit=3, balance=0.0),
)
texts = {o.properties["text"] for o in result.objects}
assert len(texts) == 3
# Pure diversity should pick one from each cluster (a*, b*, c*)
clusters = {t[0] for t in texts}
assert clusters == {"a", "b", "c"}


def test_near_object_diversity(collection_factory: CollectionFactory) -> None:
"""near_object supports diversity selection."""
collection = _create_clustered_collection(collection_factory)
anchor = next(iter(collection.query.fetch_objects().objects)).uuid

result = collection.query.near_object(
near_object=anchor,
selection=Diversity.MMR(limit=3, balance=0.0),
)
assert len(result.objects) == 3
clusters = {o.properties["text"][0] for o in result.objects}
assert len(clusters) == 3


def test_diversity_cannot_be_instantiated() -> None:
"""Diversity is a factory — direct instantiation should fail."""
with pytest.raises(TypeError):
Diversity()


def test_diversity_mmr_only_limit(collection_factory: CollectionFactory) -> None:
"""MMR accepts just a limit (balance defaults to server-side value)."""
collection = _create_clustered_collection(collection_factory)
result = collection.query.near_vector(
near_vector=[1.0, 0.0, 0.0],
selection=Diversity.MMR(limit=2),
)
assert len(result.objects) == 2


def test_near_text_diversity(collection_factory: CollectionFactory) -> None:
"""near_text supports diversity selection via text2vec-contextionary."""
collection = collection_factory(
properties=[Property(name="name", data_type=DataType.TEXT)],
vectorizer_config=Configure.Vectorizer.text2vec_contextionary(
vectorize_collection_name=False
),
)
if collection._connection._weaviate_version.is_lower_than(1, 37, 0):
pytest.skip("Diversity selection requires Weaviate >= 1.37.0")
for name in ["banana", "apple", "orange", "car", "truck", "bike"]:
collection.data.insert({"name": name})

result = collection.query.near_text(
query="fruit",
selection=Diversity.MMR(limit=3, balance=0.0),
)
assert len(result.objects) == 3
4 changes: 4 additions & 0 deletions weaviate/classes/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
BM25OperatorFactory as BM25Operator,
)
from weaviate.collections.classes.grpc import (
MMR,
Diversity,
GroupBy,
HybridFusion,
HybridVector,
Expand All @@ -21,6 +23,7 @@
from weaviate.collections.classes.types import GeoCoordinate

__all__ = [
"Diversity",
"Filter",
"FilterReturn",
"GeoCoordinate",
Expand All @@ -29,6 +32,7 @@
"HybridFusion",
"HybridVector",
"BM25Operator",
"MMR",
"MetadataQuery",
"Metrics",
"Move",
Expand Down
22 changes: 22 additions & 0 deletions weaviate/collections/classes/grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,28 @@ class Rerank(_WeaviateInput):
query: Optional[str] = Field(default=None)


@dataclass
class MMR:
"""Define MMR (Maximal Marginal Relevance) diversity selection.

Args:
limit: Optional number of candidates to consider for diversification.
balance: Optional MMR lambda in [0.0, 1.0] — 1.0 is pure relevance, 0.0 is pure diversity.
"""

limit: Optional[int] = None
balance: Optional[float] = None


class Diversity:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should match the parameter name

"""Use this factory class to apply diversity selection to search results via MMR."""

def __init__(self) -> None:
raise TypeError("Diversity cannot be instantiated directly. Use Diversity.MMR(...).")

MMR = MMR
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would add a factory here, eg something like

      @staticmethod
      def mmr(limit: Optional[int] = None, balance: Optional[float] = None) -> MMR:
          """Maximal Marginal Relevance.

          Args:
              limit: number of candidates to consider for diversification.
              balance: MMR lambda in [0.0, 1.0] — 1.0 pure relevance, 0.0 pure diversity.
          """
          return MMR(limit=limit, balance=balance)



@dataclass
class BM25OperatorOptions:
# replace with ClassVar[base_search_pb2.SearchOperatorOptions.Operator] once python 3.10 is removed
Expand Down
13 changes: 11 additions & 2 deletions weaviate/collections/grpc/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
QueryNested,
Rerank,
TargetVectorJoinType,
MMR,
_MetadataQuery,
_QueryReference,
_QueryReferenceMultiTarget,
Expand Down Expand Up @@ -262,6 +263,7 @@ def near_vector(
return_metadata: Optional[_MetadataQuery] = None,
return_properties: Union[PROPERTIES, bool, None] = None,
return_references: Optional[REFERENCES] = None,
selection: Optional[MMR] = None,
) -> search_get_pb2.SearchRequest:
return self.__create_request(
limit=limit,
Expand All @@ -275,7 +277,7 @@ def near_vector(
autocut=autocut,
group_by=group_by,
near_vector=self._parse_near_vector(
near_vector, certainty, distance, target_vector=target_vector
near_vector, certainty, distance, target_vector=target_vector, selection=selection
),
)

Expand All @@ -296,6 +298,7 @@ def near_object(
return_metadata: Optional[_MetadataQuery] = None,
return_properties: Union[PROPERTIES, bool, None] = None,
return_references: Optional[REFERENCES] = None,
selection: Optional[MMR] = None,
) -> search_get_pb2.SearchRequest:
return self.__create_request(
limit=limit,
Expand All @@ -308,7 +311,9 @@ def near_object(
rerank=rerank,
autocut=autocut,
group_by=group_by,
near_object=self._parse_near_object(near_object, certainty, distance, target_vector),
near_object=self._parse_near_object(
near_object, certainty, distance, target_vector, selection=selection
),
)

def near_text(
Expand All @@ -330,6 +335,7 @@ def near_text(
return_metadata: Optional[_MetadataQuery] = None,
return_properties: Union[PROPERTIES, bool, None] = None,
return_references: Optional[REFERENCES] = None,
selection: Optional[MMR] = None,
) -> search_get_pb2.SearchRequest:
return self.__create_request(
limit=limit,
Expand All @@ -349,6 +355,7 @@ def near_text(
move_away=move_away,
move_to=move_to,
target_vector=target_vector,
selection=selection,
),
)

Expand All @@ -370,6 +377,7 @@ def near_media(
return_metadata: Optional[_MetadataQuery] = None,
return_properties: Union[PROPERTIES, bool, None] = None,
return_references: Optional[REFERENCES] = None,
selection: Optional[MMR] = None,
) -> search_get_pb2.SearchRequest:
return self.__create_request(
limit=limit,
Expand All @@ -388,6 +396,7 @@ def near_media(
certainty,
distance,
target_vector,
selection=selection,
),
)

Expand Down
28 changes: 28 additions & 0 deletions weaviate/collections/grpc/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
PrimitiveVectorType,
TargetVectorJoinType,
TwoDimensionalVectorType,
MMR,
_HybridNearText,
_HybridNearVector,
_ListOfVectorsQuery,
Expand Down Expand Up @@ -310,12 +311,26 @@ def _parse_near_options(
float(distance) if distance is not None else None,
)

@staticmethod
def _selection_to_grpc(
selection: Optional[MMR],
) -> Optional[base_search_pb2.Selection]:
if selection is None:
return None
return base_search_pb2.Selection(
mmr=base_search_pb2.Selection.MMR(
limit=selection.limit,
balance=selection.balance,
)
)

def _parse_near_vector(
self,
near_vector: NearVectorInputType,
certainty: Optional[NUMBER],
distance: Optional[NUMBER],
target_vector: Optional[TargetVectorJoinType],
selection: Optional[MMR] = None,
) -> base_search_pb2.NearVector:
if self._validate_arguments:
_validate_input(
Expand Down Expand Up @@ -399,6 +414,7 @@ def _parse_near_vector(
vector_per_target=vector_per_target_tmp,
vector_for_targets=vector_for_targets,
vectors=vectors,
selection=self._selection_to_grpc(selection),
)

@staticmethod
Expand All @@ -423,6 +439,7 @@ def _parse_near_text(
move_to: Optional[Move],
move_away: Optional[Move],
target_vector: Optional[TargetVectorJoinType],
selection: Optional[MMR] = None,
) -> base_search_pb2.NearTextSearch:
if self._validate_arguments:
_validate_input(
Expand Down Expand Up @@ -451,6 +468,7 @@ def _parse_near_text(
move_to=self.__parse_move(move_to),
targets=targets,
target_vectors=target_vector,
selection=self._selection_to_grpc(selection),
)

def _parse_near_object(
Expand All @@ -459,6 +477,7 @@ def _parse_near_object(
certainty: Optional[NUMBER],
distance: Optional[NUMBER],
target_vector: Optional[TargetVectorJoinType],
selection: Optional[MMR] = None,
) -> base_search_pb2.NearObject:
if self._validate_arguments:
_validate_input(
Expand All @@ -482,6 +501,7 @@ def _parse_near_object(
distance=distance,
targets=targets,
target_vectors=target_vector,
selection=self._selection_to_grpc(selection),
)

def _parse_media(
Expand All @@ -491,6 +511,7 @@ def _parse_media(
certainty: Optional[NUMBER],
distance: Optional[NUMBER],
target_vector: Optional[TargetVectorJoinType],
selection: Optional[MMR] = None,
) -> dict:
if self._validate_arguments:
_validate_input(
Expand All @@ -508,13 +529,15 @@ def _parse_media(

kwargs: Dict[str, Any] = {}
targets, target_vector = self.__target_vector_to_grpc(target_vector)
selection_grpc = self._selection_to_grpc(selection)
if type_ == "audio":
kwargs["near_audio"] = base_search_pb2.NearAudioSearch(
audio=media,
distance=distance,
certainty=certainty,
target_vectors=target_vector,
targets=targets,
selection=selection_grpc,
)
elif type_ == "depth":
kwargs["near_depth"] = base_search_pb2.NearDepthSearch(
Expand All @@ -523,6 +546,7 @@ def _parse_media(
certainty=certainty,
target_vectors=target_vector,
targets=targets,
selection=selection_grpc,
)
elif type_ == "image":
kwargs["near_image"] = base_search_pb2.NearImageSearch(
Expand All @@ -531,6 +555,7 @@ def _parse_media(
certainty=certainty,
target_vectors=target_vector,
targets=targets,
selection=selection_grpc,
)
elif type_ == "imu":
kwargs["near_imu"] = base_search_pb2.NearIMUSearch(
Expand All @@ -539,6 +564,7 @@ def _parse_media(
certainty=certainty,
target_vectors=target_vector,
targets=targets,
selection=selection_grpc,
)
elif type_ == "thermal":
kwargs["near_thermal"] = base_search_pb2.NearThermalSearch(
Expand All @@ -547,6 +573,7 @@ def _parse_media(
certainty=certainty,
target_vectors=target_vector,
targets=targets,
selection=selection_grpc,
)
elif type_ == "video":
kwargs["near_video"] = base_search_pb2.NearVideoSearch(
Expand All @@ -555,6 +582,7 @@ def _parse_media(
certainty=certainty,
target_vectors=target_vector,
targets=targets,
selection=selection_grpc,
)
else:
raise ValueError(
Expand Down
Loading
Loading