Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 23 additions & 14 deletions src/classifai/indexers/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class VectorStore:
Attributes:
file_name (str): the data file containing the knowledgebase to build the `VectorStore`
data_type (str): the data type of the data file (currently only csv supported)
vectoriser (VectoriserBase): A `Vectoriser` object from the corresponding ClassifAI Pacakge module
vectoriser (VectoriserBase): A `Vectoriser` object from the corresponding ClassifAI Package module
batch_size (int): the batch size to pass to the vectoriser when embedding
meta_data (dict): key-value pairs of metadata to extract from the input file and their corresponding types
output_dir (str): the path to the output directory where the `VectorStore` will be saved
Expand All @@ -83,7 +83,7 @@ def __init__( # noqa: C901, PLR0912, PLR0913, PLR0915
file_name: str,
data_type: str,
vectoriser: VectoriserBase,
batch_size: int = 8,
batch_size: int = 250,
meta_data: dict | None = None,
output_dir: str | None = None,
overwrite: bool = False,
Expand All @@ -100,7 +100,7 @@ def __init__( # noqa: C901, PLR0912, PLR0913, PLR0915
vectoriser (object): The `Vectoriser` object used to transform text into
vector embeddings.
batch_size (int): [optional] The batch size for processing the input file and batching to
vectoriser. Defaults to 8.
vectoriser. Defaults to 250.
meta_data (dict): [optional] key,value pair metadata column names to extract from the input file and their types.
Defaults to `None`.
output_dir (str): [optional] The directory where the `VectorStore` will be saved.
Expand Down Expand Up @@ -351,6 +351,7 @@ def _save_metadata(self, path: str):
"vectoriser_class": self.vectoriser_class,
"vector_shape": self.vector_shape,
"num_vectors": self.num_vectors,
"batch_size": self.batch_size,
"created_at": time.time(),
"meta_data": serializable_column_meta_data,
}
Expand Down Expand Up @@ -672,15 +673,15 @@ def reverse_search( # noqa: C901, PLR0912

return result_df

def search(self, query: VectorStoreSearchInput, n_results=10, batch_size=8) -> VectorStoreSearchOutput: # noqa: C901, PLR0912, PLR0915

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we'd like to retain the option for users to specify a different batch size at this point, but we'd want the default behaviour to follow the single source of truth.

def search(self, query: VectorStoreSearchInput, n_results=10, batch_size=None) -> VectorStoreSearchOutput: # noqa: C901, PLR0912, PLR0915
"""Searches the `VectorStore` using queries from a `VectorStoreSearchInput` object.
Outputs ranked results in `VectorStoreSearchOutput` object. In batches, converts users text queries into vector embeddings,
computes cosine similarity with stored document vectors, and retrieves the top results.

Args:
query (VectorStoreSearchInput): A `VectorStoreSearchInput` object containing the text query or list of queries to search for with ids.
n_results (int): [optional] Number of top results to return for each query. Default 10.
batch_size (int): [optional] The batch size for processing queries. Default 8.
batch_size (int): [optional] The batch size for processing queries. Defaults to the `batch_size` set during initialisation.

Returns:
(VectorStoreSearchOutput): A `VectorStoreSearchOutput` object containing search results with columns for `query_id`, `query_text`,
Expand All @@ -702,8 +703,10 @@ def search(self, query: VectorStoreSearchInput, n_results=10, batch_size=8) -> V
if not isinstance(n_results, int) or n_results < 1:
raise DataValidationError("n_results must be an integer >= 1.", context={"n_results": n_results})

if not isinstance(batch_size, int) or batch_size < 1:
raise DataValidationError("batch_size must be an integer >= 1.", context={"batch_size": batch_size})
query_batch_size = batch_size if batch_size is not None else self.batch_size

if not isinstance(query_batch_size, int) or query_batch_size < 1:
raise DataValidationError("batch_size must be an integer >= 1.", context={"batch_size": query_batch_size})

if self.vectors is None:
raise ConfigurationError("Vector store is not initialized (vectors is None).")
Expand Down Expand Up @@ -731,9 +734,9 @@ def search(self, query: VectorStoreSearchInput, n_results=10, batch_size=8) -> V

all_results: list[pl.DataFrame] = []

for i in self.classifai_tqdm(range(0, len(query), batch_size), desc="Processing query batches"):
query_text_batch = query.query.to_list()[i : i + batch_size]
query_ids_batch = query.id.to_list()[i : i + batch_size]
for i in self.classifai_tqdm(range(0, len(query), query_batch_size), desc="Processing query batches"):
query_text_batch = query.query.to_list()[i : i + query_batch_size]
query_ids_batch = query.id.to_list()[i : i + query_batch_size]

if len(query_text_batch) == 0:
continue
Expand Down Expand Up @@ -825,7 +828,7 @@ def search(self, query: VectorStoreSearchInput, n_results=10, batch_size=8) -> V
code="search_failed",
context={
"n_queries": len(query),
"batch_size": batch_size,
"batch_size": query_batch_size,
"n_results": n_results,
"cause_type": type(e).__name__,
"cause_message": str(e),
Expand All @@ -849,7 +852,9 @@ def search(self, query: VectorStoreSearchInput, n_results=10, batch_size=8) -> V
return result_df

@classmethod
def from_filespace(cls, folder_path, vectoriser, hooks: dict | None = None, quiet_mode: bool = False): # noqa: C901, PLR0912, PLR0915
def from_filespace( # noqa: C901, PLR0912, PLR0915
cls, folder_path, vectoriser, batch_size: int | None = None, hooks: dict | None = None, quiet_mode: bool = False
):
"""Creates a `VectorStore` instance from stored metadata and Parquet files.
This method reads the metadata and vectors from the specified folder,
validates the contents, and initializes a `VectorStore` object with the
Expand All @@ -863,6 +868,7 @@ def from_filespace(cls, folder_path, vectoriser, hooks: dict | None = None, quie
Args:
folder_path (str): The folder path containing the metadata and Parquet files.
vectoriser (object): The `Vectoriser` object used to transform text into vector embeddings.
batch_size (int): [optional] Overrides the batch size stored in metadata. Defaults to `None`, which uses the value from metadata.
hooks (dict): [optional] A dictionary of user-defined hooks for preprocessing and postprocessing. Defaults to None.
quiet_mode (bool): [optional] Whether to minimise verbose output, such as progress bars. Defaults to `False`.

Expand Down Expand Up @@ -912,6 +918,9 @@ def from_filespace(cls, folder_path, vectoriser, hooks: dict | None = None, quie
context={"vectoriser_type": type(vectoriser).__name__},
)

if batch_size is not None and (not isinstance(batch_size, int) or batch_size < 1):
raise DataValidationError("batch_size must be an integer >= 1 or None.", context={"batch_size": batch_size})

if hooks is not None and not isinstance(hooks, dict):
raise DataValidationError("hooks must be a dict or None.", context={"hooks_type": type(hooks).__name__})

Expand Down Expand Up @@ -939,7 +948,7 @@ def from_filespace(cls, folder_path, vectoriser, hooks: dict | None = None, quie
context={"metadata_path": metadata_in_path, "metadata_type": type(metadata).__name__},
)

required_keys = ["vectoriser_class", "vector_shape", "num_vectors", "created_at", "meta_data"]
required_keys = ["vectoriser_class", "vector_shape", "num_vectors", "batch_size", "created_at", "meta_data"]
missing = [k for k in required_keys if k not in metadata]
if missing:
raise DataValidationError(
Expand Down Expand Up @@ -1022,7 +1031,7 @@ def from_filespace(cls, folder_path, vectoriser, hooks: dict | None = None, quie
vector_store.file_name = None
vector_store.data_type = None
vector_store.vectoriser = vectoriser
vector_store.batch_size = None
vector_store.batch_size = batch_size if batch_size is not None else metadata["batch_size"]
vector_store.meta_data = deserialized_column_meta_data
vector_store.vectors = df
vector_store.vector_shape = metadata["vector_shape"]
Expand Down
Loading