diff --git a/src/classifai/indexers/main.py b/src/classifai/indexers/main.py index 6171c0d..7ae19e8 100644 --- a/src/classifai/indexers/main.py +++ b/src/classifai/indexers/main.py @@ -66,7 +66,7 @@ class VectorStore: Attributes: file_name (str): the data file contatining the knowledgebase to build the `VectorStore` data_type (str): the data type of the data file (curently only csv supported) - vectoriser (VectoriserBase): A `Vectoriser` object from the corresponding ClassifAI Pacakge module + vectoriser (VectoriserBase): A `Vectoriser` object from the corresponding ClassifAI Package module batch_size (int): the batch size to pass to the vectoriser when embedding meta_data (dict): key-value pairs of metadata to extract from the input file and their correpsonding types output_dir (str): the path to the output directory where the `VectorStore` will be saved @@ -83,7 +83,7 @@ def __init__( # noqa: C901, PLR0912, PLR0913, PLR0915 file_name: str, data_type: str, vectoriser: VectoriserBase, - batch_size: int = 8, + batch_size: int = 250, meta_data: dict | None = None, output_dir: str | None = None, overwrite: bool = False, @@ -100,7 +100,7 @@ def __init__( # noqa: C901, PLR0912, PLR0913, PLR0915 vectoriser (object): The `Vectoriser` object used to transform text into vector embeddings. batch_size (int): [optional] The batch size for processing the input file and batching to - vectoriser. Defaults to 8. + vectoriser. Defaults to 250. meta_data (dict): [optional] key,value pair metadata column names to extract from the input file and their types. Defaults to `None`. output_dir (str): [optional] The directory where the `VectorStore` will be saved. 
@@ -351,6 +351,7 @@ def _save_metadata(self, path: str): "vectoriser_class": self.vectoriser_class, "vector_shape": self.vector_shape, "num_vectors": self.num_vectors, + "batch_size": self.batch_size, "created_at": time.time(), "meta_data": serializable_column_meta_data, } @@ -672,7 +673,7 @@ def reverse_search( # noqa: C901, PLR0912 return result_df - def search(self, query: VectorStoreSearchInput, n_results=10, batch_size=8) -> VectorStoreSearchOutput: # noqa: C901, PLR0912, PLR0915 + def search(self, query: VectorStoreSearchInput, n_results=10, batch_size=None) -> VectorStoreSearchOutput: # noqa: C901, PLR0912, PLR0915 """Searches the `VectorStore` using queries from a `VectorStoreSearchInput` object. Outputs ranked results in `VectorStoreSearchOutput` object. In batches, converts users text queries into vector embeddings, computes cosine similarity with stored document vectors, and retrieves the top results. @@ -680,7 +681,7 @@ def search(self, query: VectorStoreSearchInput, n_results=10, batch_size=8) -> V Args: query (VectorStoreSearchInput): A `VectorStoreSearchInput` object containing the text query or list of queries to search for with ids. n_results (int): [optional] Number of top results to return for each query. Default 10. - batch_size (int): [optional] The batch size for processing queries. Default 8. + batch_size (int): [optional] The batch size for processing queries. Defaults to the `batch_size` set during initialisation. 
Returns: (VectorStoreSearchOutput): A `VectorStoreSearchOutput` object containing search results with columns for `query_id`, `query_text`, @@ -702,8 +703,10 @@ def search(self, query: VectorStoreSearchInput, n_results=10, batch_size=8) -> V if not isinstance(n_results, int) or n_results < 1: raise DataValidationError("n_results must be an integer >= 1.", context={"n_results": n_results}) - if not isinstance(batch_size, int) or batch_size < 1: - raise DataValidationError("batch_size must be an integer >= 1.", context={"batch_size": batch_size}) + query_batch_size = batch_size if batch_size is not None else self.batch_size + + if not isinstance(query_batch_size, int) or query_batch_size < 1: + raise DataValidationError("batch_size must be an integer >= 1.", context={"batch_size": query_batch_size}) if self.vectors is None: raise ConfigurationError("Vector store is not initialized (vectors is None).") @@ -731,9 +734,9 @@ def search(self, query: VectorStoreSearchInput, n_results=10, batch_size=8) -> V all_results: list[pl.DataFrame] = [] - for i in self.classifai_tqdm(range(0, len(query), batch_size), desc="Processing query batches"): - query_text_batch = query.query.to_list()[i : i + batch_size] - query_ids_batch = query.id.to_list()[i : i + batch_size] + for i in self.classifai_tqdm(range(0, len(query), query_batch_size), desc="Processing query batches"): + query_text_batch = query.query.to_list()[i : i + query_batch_size] + query_ids_batch = query.id.to_list()[i : i + query_batch_size] if len(query_text_batch) == 0: continue @@ -825,7 +828,7 @@ def search(self, query: VectorStoreSearchInput, n_results=10, batch_size=8) -> V code="search_failed", context={ "n_queries": len(query), - "batch_size": batch_size, + "batch_size": query_batch_size, "n_results": n_results, "cause_type": type(e).__name__, "cause_message": str(e), @@ -849,7 +852,9 @@ def search(self, query: VectorStoreSearchInput, n_results=10, batch_size=8) -> V return result_df @classmethod - def 
from_filespace(cls, folder_path, vectoriser, hooks: dict | None = None, quiet_mode: bool = False): # noqa: C901, PLR0912, PLR0915 + def from_filespace( # noqa: C901, PLR0912, PLR0915 + cls, folder_path, vectoriser, hooks: dict | None = None, quiet_mode: bool = False, batch_size: int | None = None + ): """Creates a `VectorStore` instance from stored metadata and Parquet files. This method reads the metadata and vectors from the specified folder, validates the contents, and initializes a `VectorStore` object with the @@ -863,6 +868,7 @@ def from_filespace(cls, folder_path, vectoriser, hooks: dict | None = None, quie Args: folder_path (str): The folder path containing the metadata and Parquet files. vectoriser (object): The `Vectoriser` object used to transform text into vector embeddings. hooks (dict): [optional] A dictionary of user-defined hooks for preprocessing and postprocessing. Defaults to None. quiet_mode (bool): [optional] Whether to minimise verbose output, such as progress bars. Defaults to `False`. + batch_size (int): [optional] Overrides the batch size stored in metadata. Defaults to `None`, which uses the value from metadata. 
@@ -912,6 +918,9 @@ def from_filespace(cls, folder_path, vectoriser, hooks: dict | None = None, quie context={"vectoriser_type": type(vectoriser).__name__}, ) + if batch_size is not None and (not isinstance(batch_size, int) or batch_size < 1): + raise DataValidationError("batch_size must be an integer >= 1 or None.", context={"batch_size": batch_size}) + if hooks is not None and not isinstance(hooks, dict): raise DataValidationError("hooks must be a dict or None.", context={"hooks_type": type(hooks).__name__}) @@ -1022,7 +1031,8 @@ def from_filespace(cls, folder_path, vectoriser, hooks: dict | None = None, quie vector_store.file_name = None vector_store.data_type = None vector_store.vectoriser = vectoriser - vector_store.batch_size = None + # Stores saved before batch_size was persisted have no such metadata key; fall back to the constructor default (250) so they still load. + vector_store.batch_size = batch_size if batch_size is not None else metadata.get("batch_size", 250) vector_store.meta_data = deserialized_column_meta_data vector_store.vectors = df vector_store.vector_shape = metadata["vector_shape"]