Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/vectorcode/cli_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ class Config:
embedding_dims: Optional[int] = None
n_result: int = 1
force: bool = False
batch_size: int = 100
db_path: Optional[str] = "~/.local/share/vectorcode/chromadb/"
db_log_path: str = "~/.local/share/vectorcode/"
db_settings: Optional[dict] = None
Expand Down Expand Up @@ -290,6 +291,13 @@ def get_cli_parser():
default=False,
help="Force to vectorise the file(s) against the gitignore.",
)
vectorise_parser.add_argument(
"-b",
"--batch_size",
type=int,
default=__default_config.batch_size,
help="Number of files to process per batch (default: 100). Use -1 for no batching.",
)

query_parser = subparsers.add_parser(
"query",
Expand Down Expand Up @@ -439,6 +447,7 @@ async def parse_cli_args(args: Optional[Sequence[str]] = None):
configs_items["recursive"] = main_args.recursive
configs_items["include_hidden"] = main_args.include_hidden
configs_items["force"] = main_args.force
configs_items["batch_size"] = main_args.batch_size
configs_items["chunk_size"] = main_args.chunk_size
configs_items["overlap_ratio"] = main_args.overlap
configs_items["encoding"] = main_args.encoding
Expand Down
38 changes: 21 additions & 17 deletions src/vectorcode/subcommands/vectorise.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,28 +302,32 @@ async def vectorise(configs: Config) -> int:
max_batch_size = await client.get_max_batch_size()
semaphore = asyncio.Semaphore(os.cpu_count() or 1)

batch_size = configs.batch_size if configs.batch_size > 0 else len(files)

with tqdm.tqdm(
total=len(files), desc="Vectorising files...", disable=configs.pipe
) as bar:
try:
tasks = [
asyncio.create_task(
chunked_add(
str(file),
collection,
collection_lock,
stats,
stats_lock,
configs,
max_batch_size,
semaphore,
for i in range(0, len(files), batch_size):
batch = files[i : i + batch_size]
tasks = [
asyncio.create_task(
chunked_add(
str(file),
collection,
collection_lock,
stats,
stats_lock,
configs,
max_batch_size,
semaphore,
)
)
)
for file in files
]
for task in asyncio.as_completed(tasks):
await task
bar.update(1)
for file in batch
]
for task in asyncio.as_completed(tasks):
await task
bar.update(1)
except asyncio.CancelledError:
print("Abort.", file=sys.stderr)
return 1
Expand Down
124 changes: 124 additions & 0 deletions tests/subcommands/test_vectorise.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,6 +806,130 @@ async def test_vectorise_exclude_file_recursive():
assert mock_chunked_add.call_count == 1


@pytest.mark.asyncio
async def test_vectorise_with_batch_size():
    """With batch_size=3, vectorise must still process every one of the 10 files."""
    input_files = [f"file{i}.py" for i in range(10)]
    configs = Config(
        db_url="http://test_host:1234",
        db_path="test_db",
        embedding_function="SentenceTransformerEmbeddingFunction",
        embedding_params={},
        project_root="/test_project",
        files=input_files,
        recursive=False,
        force=False,
        pipe=False,
        batch_size=3,
    )
    client_stub = AsyncMock()
    client_stub.get_max_batch_size.return_value = 50
    collection_stub = MagicMock(spec=AsyncCollection)
    collection_stub.get.return_value = {"ids": []}
    collection_stub.delete.return_value = None
    collection_stub.metadata = {
        "embedding_function": "SentenceTransformerEmbeddingFunction",
        "path": "/test_project",
        "hostname": socket.gethostname(),
        "created-by": "VectorCode",
        "username": os.environ.get("USER", os.environ.get("USERNAME", "DEFAULT_USER")),
    }
    embedding_fn_stub = MagicMock()

    # Patchers that we only need active, not referenced.
    background_patches = (
        patch("vectorcode.subcommands.vectorise.ClientManager"),
        patch("os.path.isfile", return_value=False),
        patch(
            "vectorcode.subcommands.vectorise.expand_globs",
            return_value=configs.files,
        ),
    )
    with ExitStack() as stack:
        for patcher in background_patches:
            stack.enter_context(patcher)
        # Keep a handle on the chunked_add mock so we can count calls.
        add_spy = stack.enter_context(
            patch("vectorcode.subcommands.vectorise.chunked_add", return_value=None)
        )
        stack.enter_context(
            patch(
                "vectorcode.common.get_embedding_function",
                return_value=embedding_fn_stub,
            )
        )
        stack.enter_context(
            patch(
                "vectorcode.subcommands.vectorise.get_collection",
                return_value=collection_stub,
            )
        )

        result = await vectorise(configs)
        assert result == 0
        # Batching changes scheduling only; every file is still vectorised once.
        assert add_spy.call_count == 10


@pytest.mark.asyncio
async def test_vectorise_with_batch_size_disabled():
    """batch_size=-1 turns batching off; all 5 files are processed in one go."""
    input_files = [f"file{i}.py" for i in range(5)]
    configs = Config(
        db_url="http://test_host:1234",
        db_path="test_db",
        embedding_function="SentenceTransformerEmbeddingFunction",
        embedding_params={},
        project_root="/test_project",
        files=input_files,
        recursive=False,
        force=False,
        pipe=False,
        batch_size=-1,
    )
    client_stub = AsyncMock()
    client_stub.get_max_batch_size.return_value = 50
    collection_stub = MagicMock(spec=AsyncCollection)
    collection_stub.get.return_value = {"ids": []}
    collection_stub.delete.return_value = None
    collection_stub.metadata = {
        "embedding_function": "SentenceTransformerEmbeddingFunction",
        "path": "/test_project",
        "hostname": socket.gethostname(),
        "created-by": "VectorCode",
        "username": os.environ.get("USER", os.environ.get("USERNAME", "DEFAULT_USER")),
    }
    embedding_fn_stub = MagicMock()

    # Patchers that we only need active, not referenced.
    background_patches = (
        patch("vectorcode.subcommands.vectorise.ClientManager"),
        patch("os.path.isfile", return_value=False),
        patch(
            "vectorcode.subcommands.vectorise.expand_globs",
            return_value=configs.files,
        ),
    )
    with ExitStack() as stack:
        for patcher in background_patches:
            stack.enter_context(patcher)
        # Keep a handle on the chunked_add mock so we can count calls.
        add_spy = stack.enter_context(
            patch("vectorcode.subcommands.vectorise.chunked_add", return_value=None)
        )
        stack.enter_context(
            patch(
                "vectorcode.common.get_embedding_function",
                return_value=embedding_fn_stub,
            )
        )
        stack.enter_context(
            patch(
                "vectorcode.subcommands.vectorise.get_collection",
                return_value=collection_stub,
            )
        )

        result = await vectorise(configs)
        assert result == 0
        # Disabled batching must not drop or duplicate any file.
        assert add_spy.call_count == 5


@pytest.mark.asyncio
async def test_vectorise_uses_global_exclude_when_local_missing():
mock_client = AsyncMock()
Expand Down
Loading