diff --git a/src/vectorcode/cli_utils.py b/src/vectorcode/cli_utils.py index 0131a5e2..1907b581 100644 --- a/src/vectorcode/cli_utils.py +++ b/src/vectorcode/cli_utils.py @@ -93,6 +93,7 @@ class Config: embedding_dims: Optional[int] = None n_result: int = 1 force: bool = False + batch_size: int = 100 db_path: Optional[str] = "~/.local/share/vectorcode/chromadb/" db_log_path: str = "~/.local/share/vectorcode/" db_settings: Optional[dict] = None @@ -290,6 +291,13 @@ def get_cli_parser(): default=False, help="Force to vectorise the file(s) against the gitignore.", ) + vectorise_parser.add_argument( + "-b", + "--batch_size", + type=int, + default=__default_config.batch_size, + help="Number of files to process per batch (default: 100). Use -1 for no batching.", + ) query_parser = subparsers.add_parser( "query", @@ -439,6 +447,7 @@ async def parse_cli_args(args: Optional[Sequence[str]] = None): configs_items["recursive"] = main_args.recursive configs_items["include_hidden"] = main_args.include_hidden configs_items["force"] = main_args.force + configs_items["batch_size"] = main_args.batch_size configs_items["chunk_size"] = main_args.chunk_size configs_items["overlap_ratio"] = main_args.overlap configs_items["encoding"] = main_args.encoding diff --git a/src/vectorcode/subcommands/vectorise.py b/src/vectorcode/subcommands/vectorise.py index 2ce0b249..3456abca 100644 --- a/src/vectorcode/subcommands/vectorise.py +++ b/src/vectorcode/subcommands/vectorise.py @@ -302,28 +302,32 @@ async def vectorise(configs: Config) -> int: max_batch_size = await client.get_max_batch_size() semaphore = asyncio.Semaphore(os.cpu_count() or 1) + batch_size = configs.batch_size if configs.batch_size > 0 else len(files) + with tqdm.tqdm( total=len(files), desc="Vectorising files...", disable=configs.pipe ) as bar: try: - tasks = [ - asyncio.create_task( - chunked_add( - str(file), - collection, - collection_lock, - stats, - stats_lock, - configs, - max_batch_size, - semaphore, + for i in range(0, len(files), batch_size): + batch = files[i : i + batch_size] + tasks = [ + asyncio.create_task( + chunked_add( + str(file), + collection, + collection_lock, + stats, + stats_lock, + configs, + max_batch_size, + semaphore, + ) ) - ) - for file in files - ] - for task in asyncio.as_completed(tasks): - await task - bar.update(1) + for file in batch + ] + for task in asyncio.as_completed(tasks): + await task + bar.update(1) except asyncio.CancelledError: print("Abort.", file=sys.stderr) return 1 diff --git a/tests/subcommands/test_vectorise.py b/tests/subcommands/test_vectorise.py index 3ce5683b..16fb5db9 100644 --- a/tests/subcommands/test_vectorise.py +++ b/tests/subcommands/test_vectorise.py @@ -806,6 +806,130 @@ async def test_vectorise_exclude_file_recursive(): assert mock_chunked_add.call_count == 1 +@pytest.mark.asyncio +async def test_vectorise_with_batch_size(): + """Test that batch_size correctly limits concurrent task creation.""" + configs = Config( + db_url="http://test_host:1234", + db_path="test_db", + embedding_function="SentenceTransformerEmbeddingFunction", + embedding_params={}, + project_root="/test_project", + files=[f"file{i}.py" for i in range(10)], + recursive=False, + force=False, + pipe=False, + batch_size=3, + ) + mock_client = AsyncMock() + mock_collection = MagicMock(spec=AsyncCollection) + mock_collection.get.return_value = {"ids": []} + mock_collection.delete.return_value = None + mock_collection.metadata = { + "embedding_function": "SentenceTransformerEmbeddingFunction", + "path": "/test_project", + "hostname": socket.gethostname(), + "created-by": "VectorCode", + "username": os.environ.get("USER", os.environ.get("USERNAME", "DEFAULT_USER")), + } + mock_client.get_max_batch_size.return_value = 50 + mock_embedding_function = MagicMock() + + with ExitStack() as stack: + stack.enter_context( + patch("vectorcode.subcommands.vectorise.ClientManager"), + ) + stack.enter_context(patch("os.path.isfile", return_value=False)) + stack.enter_context( + patch( + "vectorcode.subcommands.vectorise.expand_globs", + return_value=configs.files, + ) + ) + mock_chunked_add = stack.enter_context( + patch("vectorcode.subcommands.vectorise.chunked_add", return_value=None) + ) + stack.enter_context( + patch( + "vectorcode.common.get_embedding_function", + return_value=mock_embedding_function, + ) + ) + stack.enter_context( + patch( + "vectorcode.subcommands.vectorise.get_collection", + return_value=mock_collection, + ) + ) + + result = await vectorise(configs) + assert result == 0 + # All 10 files should be processed + assert mock_chunked_add.call_count == 10 + + +@pytest.mark.asyncio +async def test_vectorise_with_batch_size_disabled(): + """Test that batch_size=-1 disables batching (processes all files at once).""" + configs = Config( + db_url="http://test_host:1234", + db_path="test_db", + embedding_function="SentenceTransformerEmbeddingFunction", + embedding_params={}, + project_root="/test_project", + files=[f"file{i}.py" for i in range(5)], + recursive=False, + force=False, + pipe=False, + batch_size=-1, + ) + mock_client = AsyncMock() + mock_collection = MagicMock(spec=AsyncCollection) + mock_collection.get.return_value = {"ids": []} + mock_collection.delete.return_value = None + mock_collection.metadata = { + "embedding_function": "SentenceTransformerEmbeddingFunction", + "path": "/test_project", + "hostname": socket.gethostname(), + "created-by": "VectorCode", + "username": os.environ.get("USER", os.environ.get("USERNAME", "DEFAULT_USER")), + } + mock_client.get_max_batch_size.return_value = 50 + mock_embedding_function = MagicMock() + + with ExitStack() as stack: + stack.enter_context( + patch("vectorcode.subcommands.vectorise.ClientManager"), + ) + stack.enter_context(patch("os.path.isfile", return_value=False)) + stack.enter_context( + patch( + "vectorcode.subcommands.vectorise.expand_globs", + return_value=configs.files, + ) + ) + mock_chunked_add = stack.enter_context( + patch("vectorcode.subcommands.vectorise.chunked_add", return_value=None) + ) + stack.enter_context( + patch( + "vectorcode.common.get_embedding_function", + return_value=mock_embedding_function, + ) + ) + stack.enter_context( + patch( + "vectorcode.subcommands.vectorise.get_collection", + return_value=mock_collection, + ) + ) + + result = await vectorise(configs) + assert result == 0 + # All 5 files should be processed + assert mock_chunked_add.call_count == 5 + + @pytest.mark.asyncio async def test_vectorise_uses_global_exclude_when_local_missing(): mock_client = AsyncMock()