From 5399d5965313a19ec40a9d439093a1981ac54f26 Mon Sep 17 00:00:00 2001 From: Jamie Milsom Date: Mon, 8 Jun 2026 17:36:51 +0000 Subject: [PATCH 1/5] updated VectorStore methods to use self.batch_size so there is one single source of truth --- src/classifai/indexers/main.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/src/classifai/indexers/main.py b/src/classifai/indexers/main.py index 250bdba..4335fb2 100644 --- a/src/classifai/indexers/main.py +++ b/src/classifai/indexers/main.py @@ -72,7 +72,7 @@ class VectorStore: Attributes: file_name (str): the data file contatining the knowledgebase to build the `VectorStore` data_type (str): the data type of the data file (curently only csv supported) - vectoriser (VectoriserBase): A `Vectoriser` object from the corresponding ClassifAI Pacakge module + vectoriser (VectoriserBase): A `Vectoriser` object from the corresponding ClassifAI Package module batch_size (int): the batch size to pass to the vectoriser when embedding meta_data (dict): key-value pairs of metadata to extract from the input file and their correpsonding types output_dir (str): the path to the output directory where the `VectorStore` will be saved @@ -661,7 +661,7 @@ def reverse_search( # noqa: C901, PLR0912 return result_df - def search(self, query: VectorStoreSearchInput, n_results=10, batch_size=8) -> VectorStoreSearchOutput: # noqa: C901, PLR0912, PLR0915 + def search(self, query: VectorStoreSearchInput, n_results=10) -> VectorStoreSearchOutput: # noqa: C901, PLR0912, PLR0915 """Searches the `VectorStore` using queries from a `VectorStoreSearchInput` object. Outputs ranked results in `VectorStoreSearchOutput` object. In batches, converts users text queries into vector embeddings, computes cosine similarity with stored document vectors, and retrieves the top results. @@ -669,7 +669,6 @@ def search(self, query: VectorStoreSearchInput, n_results=10, batch_size=8) -> V Args: query (VectorStoreSearchInput): A `VectorStoreSearchInput` object containing the text query or list of queries to search for with ids. n_results (int): [optional] Number of top results to return for each query. Default 10. - batch_size (int): [optional] The batch size for processing queries. Default 8. Returns: (VectorStoreSearchOutput): A `VectorStoreSearchOutput` object containing search results with columns for `query_id`, `query_text`, @@ -691,9 +690,6 @@ def search(self, query: VectorStoreSearchInput, n_results=10, batch_size=8) -> V if not isinstance(n_results, int) or n_results < 1: raise DataValidationError("n_results must be an integer >= 1.", context={"n_results": n_results}) - if not isinstance(batch_size, int) or batch_size < 1: - raise DataValidationError("batch_size must be an integer >= 1.", context={"batch_size": batch_size}) - if self.vectors is None: raise ConfigurationError("Vector store is not initialized (vectors is None).") @@ -720,9 +716,9 @@ def search(self, query: VectorStoreSearchInput, n_results=10, batch_size=8) -> V all_results: list[pl.DataFrame] = [] - for i in tqdm(range(0, len(query), batch_size), desc="Processing query batches"): - query_text_batch = query.query.to_list()[i : i + batch_size] - query_ids_batch = query.id.to_list()[i : i + batch_size] + for i in tqdm(range(0, len(query), self.batch_size), desc="Processing query batches"): + query_text_batch = query.query.to_list()[i : i + self.batch_size] + query_ids_batch = query.id.to_list()[i : i + self.batch_size] if len(query_text_batch) == 0: continue @@ -814,7 +810,7 @@ def search(self, query: VectorStoreSearchInput, n_results=10, batch_size=8) -> V code="search_failed", context={ "n_queries": len(query), - "batch_size": batch_size, + "batch_size": self.batch_size, "n_results": n_results, "cause_type": type(e).__name__, "cause_message": str(e), From 8769e9ed6771be3a7a4a263dd5ed534c68f77ae9 Mon Sep 17 00:00:00 2001 From: Jamie Milsom Date: Fri, 12 Jun 2026 15:29:44 +0000 Subject: [PATCH 2/5] updated VectorStore search method to allow a query batch_size argument --- src/classifai/indexers/main.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/classifai/indexers/main.py b/src/classifai/indexers/main.py index 4335fb2..d258a4d 100644 --- a/src/classifai/indexers/main.py +++ b/src/classifai/indexers/main.py @@ -661,7 +661,7 @@ def reverse_search( # noqa: C901, PLR0912 return result_df - def search(self, query: VectorStoreSearchInput, n_results=10) -> VectorStoreSearchOutput: # noqa: C901, PLR0912, PLR0915 + def search(self, query: VectorStoreSearchInput, n_results=10, batch_size=None) -> VectorStoreSearchOutput: # noqa: C901, PLR0912, PLR0915 """Searches the `VectorStore` using queries from a `VectorStoreSearchInput` object. Outputs ranked results in `VectorStoreSearchOutput` object. In batches, converts users text queries into vector embeddings, computes cosine similarity with stored document vectors, and retrieves the top results. @@ -669,6 +669,7 @@ def search(self, query: VectorStoreSearchInput, n_results=10) -> VectorStoreSear Args: query (VectorStoreSearchInput): A `VectorStoreSearchInput` object containing the text query or list of queries to search for with ids. n_results (int): [optional] Number of top results to return for each query. Default 10. + batch_size (int): [optional] The batch size for processing queries. Defaults to the `batch_size` set during initialisation. Returns: (VectorStoreSearchOutput): A `VectorStoreSearchOutput` object containing search results with columns for `query_id`, `query_text`, @@ -690,6 +691,11 @@ def search(self, query: VectorStoreSearchInput, n_results=10) -> VectorStoreSear if not isinstance(n_results, int) or n_results < 1: raise DataValidationError("n_results must be an integer >= 1.", context={"n_results": n_results}) + query_batch_size = batch_size if batch_size is not None else self.batch_size + + if not isinstance(query_batch_size, int) or query_batch_size < 1: + raise DataValidationError("batch_size must be an integer >= 1.", context={"batch_size": query_batch_size}) + if self.vectors is None: raise ConfigurationError("Vector store is not initialized (vectors is None).") @@ -716,9 +722,9 @@ def search(self, query: VectorStoreSearchInput, n_results=10) -> VectorStoreSear all_results: list[pl.DataFrame] = [] - for i in tqdm(range(0, len(query), self.batch_size), desc="Processing query batches"): - query_text_batch = query.query.to_list()[i : i + self.batch_size] - query_ids_batch = query.id.to_list()[i : i + self.batch_size] + for i in tqdm(range(0, len(query), query_batch_size), desc="Processing query batches"): + query_text_batch = query.query.to_list()[i : i + query_batch_size] + query_ids_batch = query.id.to_list()[i : i + query_batch_size] if len(query_text_batch) == 0: continue @@ -810,7 +816,7 @@ def search(self, query: VectorStoreSearchInput, n_results=10) -> VectorStoreSear code="search_failed", context={ "n_queries": len(query), - "batch_size": self.batch_size, + "batch_size": query_batch_size, "n_results": n_results, "cause_type": type(e).__name__, "cause_message": str(e), From fa949a79db9ba910125f851732069094774c9afc Mon Sep 17 00:00:00 2001 From: Jamie Milsom Date: Sat, 13 Jun 2026 17:14:46 +0000 Subject: [PATCH 3/5] persist batch_size in metadata and expose via from_filespace --- src/classifai/indexers/main.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/classifai/indexers/main.py b/src/classifai/indexers/main.py index d258a4d..f8d8517 100644 --- a/src/classifai/indexers/main.py +++ b/src/classifai/indexers/main.py @@ -340,6 +340,7 @@ def _save_metadata(self, path: str): "vectoriser_class": self.vectoriser_class, "vector_shape": self.vector_shape, "num_vectors": self.num_vectors, + "batch_size": self.batch_size, "created_at": time.time(), "meta_data": serializable_column_meta_data, } @@ -840,7 +841,7 @@ def search(self, query: VectorStoreSearchInput, n_results=10, batch_size=None) - return result_df @classmethod - def from_filespace(cls, folder_path, vectoriser, hooks: dict | None = None): # noqa: C901, PLR0912, PLR0915 + def from_filespace(cls, folder_path, vectoriser, batch_size: int | None = None, hooks: dict | None = None): # noqa: C901, PLR0912, PLR0915 """Creates a `VectorStore` instance from stored metadata and Parquet files. This method reads the metadata and vectors from the specified folder, validates the contents, and initializes a `VectorStore` object with the @@ -854,6 +855,7 @@ def from_filespace(cls, folder_path, vectoriser, hooks: dict | None = None): # Args: folder_path (str): The folder path containing the metadata and Parquet files. vectoriser (object): The `Vectoriser` object used to transform text into vector embeddings. + batch_size (int): [optional] Overrides the batch size stored in metadata. Defaults to `None`, which uses the value from metadata. hooks (dict): [optional] A dictionary of user-defined hooks for preprocessing and postprocessing. Defaults to None. Returns: @@ -902,6 +904,9 @@ def from_filespace(cls, folder_path, vectoriser, hooks: dict | None = None): # context={"vectoriser_type": type(vectoriser).__name__}, ) + if batch_size is not None and (not isinstance(batch_size, int) or batch_size < 1): + raise DataValidationError("batch_size must be an integer >= 1 or None.", context={"batch_size": batch_size}) + if hooks is not None and not isinstance(hooks, dict): raise DataValidationError("hooks must be a dict or None.", context={"hooks_type": type(hooks).__name__}) @@ -929,7 +934,7 @@ def from_filespace(cls, folder_path, vectoriser, hooks: dict | None = None): # context={"metadata_path": metadata_in_path, "metadata_type": type(metadata).__name__}, ) - required_keys = ["vectoriser_class", "vector_shape", "num_vectors", "created_at", "meta_data"] + required_keys = ["vectoriser_class", "vector_shape", "num_vectors", "batch_size", "created_at", "meta_data"] missing = [k for k in required_keys if k not in metadata] if missing: raise DataValidationError( @@ -1012,7 +1017,7 @@ def from_filespace(cls, folder_path, vectoriser, hooks: dict | None = None): # vector_store.file_name = None vector_store.data_type = None vector_store.vectoriser = vectoriser - vector_store.batch_size = None + vector_store.batch_size = batch_size if batch_size is not None else metadata["batch_size"] vector_store.meta_data = deserialized_column_meta_data vector_store.vectors = df vector_store.vector_shape = metadata["vector_shape"] From 2d6f4e50dffabf14906bcf40229efe09accccdf5 Mon Sep 17 00:00:00 2001 From: Jamie Milsom Date: Sat, 13 Jun 2026 17:31:12 +0000 Subject: [PATCH 4/5] set default batch_size=250 as it is max batch size for gcp models --- src/classifai/indexers/main.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/classifai/indexers/main.py b/src/classifai/indexers/main.py index f8d8517..55687cb 100644 --- a/src/classifai/indexers/main.py +++ b/src/classifai/indexers/main.py @@ -88,7 +88,7 @@ def __init__( # noqa: C901, PLR0912, PLR0913, PLR0915 file_name: str, data_type: str, vectoriser: VectoriserBase, - batch_size: int = 8, + batch_size: int = 250, meta_data: dict | None = None, output_dir: str | None = None, overwrite: bool = False, @@ -104,7 +104,7 @@ def __init__( # noqa: C901, PLR0912, PLR0913, PLR0915 vectoriser (object): The `Vectoriser` object used to transform text into vector embeddings. batch_size (int): [optional] The batch size for processing the input file and batching to - vectoriser. Defaults to 8. + vectoriser. Defaults to 250. meta_data (dict): [optional] key,value pair metadata column names to extract from the input file and their types. Defaults to `None`. output_dir (str): [optional] The directory where the `VectorStore` will be saved. From e8fac530c6801723e4ff3af0fc09d723666598d7 Mon Sep 17 00:00:00 2001 From: Jamie Milsom Date: Tue, 23 Jun 2026 10:26:36 +0100 Subject: [PATCH 5/5] fix query_batch_size typo from merge --- src/classifai/indexers/main.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/classifai/indexers/main.py b/src/classifai/indexers/main.py index 9df5eb0..7ae19e8 100644 --- a/src/classifai/indexers/main.py +++ b/src/classifai/indexers/main.py @@ -734,9 +734,9 @@ def search(self, query: VectorStoreSearchInput, n_results=10, batch_size=None) - all_results: list[pl.DataFrame] = [] - for i in self.classifai_tqdm(range(0, len(query), batch_size), desc="Processing query batches"): - query_text_batch = query.query.to_list()[i : i + batch_size] - query_ids_batch = query.id.to_list()[i : i + batch_size] + for i in self.classifai_tqdm(range(0, len(query), query_batch_size), desc="Processing query batches"): + query_text_batch = query.query.to_list()[i : i + query_batch_size] + query_ids_batch = query.id.to_list()[i : i + query_batch_size] if len(query_text_batch) == 0: continue