From a3abb0f299bc5a38ddbc09612eac2e6a6c7bc43c Mon Sep 17 00:00:00 2001 From: Tommy Smith Date: Thu, 16 Apr 2026 12:14:03 +0100 Subject: [PATCH 01/17] Add pytest timeouts: - mark batch tests specifically with long timeouts - add timeouts to GH action jobs as a backup - dump full stack-trace on timeout detection to debug hangs in the CI --- .github/workflows/main.yaml | 12 ++++++++++++ conftest.py | 24 ++++++++++++++++++++++++ integration/test_batch_v4.py | 4 ++++ integration/test_collection_batch.py | 3 +++ pytest.ini | 2 +- requirements-test.txt | 1 + 6 files changed, 45 insertions(+), 1 deletion(-) create mode 100644 conftest.py diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index 2f68432ba..62203108c 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -35,6 +35,7 @@ jobs: lint-and-format: name: Run Linter and Formatter runs-on: ubuntu-latest + timeout-minutes: 5 steps: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 @@ -60,6 +61,7 @@ jobs: type-checking: name: Run Type Checking runs-on: ubuntu-latest + timeout-minutes: 5 strategy: fail-fast: false matrix: @@ -80,6 +82,7 @@ jobs: unit-tests: name: Run Unit Tests runs-on: ubuntu-latest + timeout-minutes: 5 strategy: fail-fast: false matrix: @@ -104,6 +107,7 @@ jobs: proto-test: name: Run importing protos test runs-on: ubuntu-latest + timeout-minutes: 5 strategy: fail-fast: false matrix: @@ -124,6 +128,7 @@ jobs: integration-tests-embedded: name: Run Integration Tests Embedded runs-on: ubuntu-latest + timeout-minutes: 30 strategy: matrix: version: ["3.10", "3.11", "3.12", "3.13", "3.14"] @@ -153,6 +158,7 @@ jobs: integration-tests: name: Run Integration Tests runs-on: ubuntu-latest + timeout-minutes: 30 strategy: fail-fast: false matrix: @@ -208,6 +214,7 @@ jobs: journey-tests: name: Run Journey Tests runs-on: ubuntu-latest + timeout-minutes: 30 strategy: fail-fast: false matrix: @@ -243,6 +250,7 @@ jobs: Codecov: needs: [Unit-Tests, Integration-Tests] 
runs-on: ubuntu-latest + timeout-minutes: 5 if: github.ref_name != 'main' && !github.event.pull_request.head.repo.fork steps: - uses: actions/checkout@v4 @@ -273,6 +281,7 @@ jobs: build-package: name: Build package runs-on: ubuntu-latest + timeout-minutes: 10 steps: - name: Checkout uses: actions/checkout@v4 @@ -297,6 +306,7 @@ jobs: test-package: needs: [build-package] runs-on: ubuntu-latest + timeout-minutes: 30 strategy: fail-fast: false matrix: @@ -341,6 +351,7 @@ jobs: name: Build and publish Python 🐍 distributions 📦 to PyPI and TestPyPI needs: [integration-tests, unit-tests, lint-and-format, type-checking, test-package, proto-test] runs-on: ubuntu-latest + timeout-minutes: 20 steps: - name: Checkout uses: actions/checkout@v4 @@ -366,6 +377,7 @@ jobs: name: Create a GitHub Release on new tags if: startsWith(github.ref, 'refs/tags') runs-on: ubuntu-latest + timeout-minutes: 5 needs: [build-and-publish] steps: - name: Download build artifact to append to release diff --git a/conftest.py b/conftest.py new file mode 100644 index 000000000..6b40adda9 --- /dev/null +++ b/conftest.py @@ -0,0 +1,24 @@ +import faulthandler + +import pytest + + +def pytest_runtest_setup(item: pytest.Item) -> None: + """Set faulthandler alarm as a backup timeout mechanism. + + This fires even if the process is stuck in C code (e.g., gRPC core). + Set to pytest-timeout value + 30s so pytest-timeout handles it first. 
+ """ + marker = item.get_closest_marker("timeout") + if marker and marker.args: + test_timeout = marker.args[0] + else: + test_timeout = item.config.getini("timeout") or 300 + + if test_timeout and float(test_timeout) > 0: + faulthandler.dump_traceback_later(float(test_timeout) + 30, exit=True) + + +def pytest_runtest_teardown(item: pytest.Item, nextitem: pytest.Item | None) -> None: + """Cancel the faulthandler alarm after each test.""" + faulthandler.cancel_dump_traceback_later() diff --git a/integration/test_batch_v4.py b/integration/test_batch_v4.py index ad4c70bb2..f4ce7669e 100644 --- a/integration/test_batch_v4.py +++ b/integration/test_batch_v4.py @@ -410,6 +410,7 @@ def test_add_ref_batch_with_tenant(client_factory: ClientFactory) -> None: assert ret_obj.references["test"].objects[0].uuid == obj[0] +@pytest.mark.timeout(600) @pytest.mark.parametrize( "batching_method", [ @@ -717,6 +718,7 @@ def test_non_existant_collection(client_factory: ClientFactory) -> None: # not, so we do not check for errors here +@pytest.mark.timeout(600) def test_number_of_stored_results_in_batch(client_factory: ClientFactory) -> None: client, name = client_factory() with client.batch.dynamic() as batch: @@ -816,6 +818,7 @@ def test_references_with_to_uuids(client_factory: ClientFactory) -> None: @pytest.mark.asyncio +@pytest.mark.timeout(600) async def test_add_one_hundred_thousand_objects_async_client( async_client_factory: AsyncClientFactory, ) -> None: @@ -846,6 +849,7 @@ async def test_add_one_hundred_thousand_objects_async_client( await client.collections.delete(name) +@pytest.mark.timeout(600) def test_add_one_hundred_thousand_objects_sync_client( client_factory: ClientFactory, ) -> None: diff --git a/integration/test_collection_batch.py b/integration/test_collection_batch.py index 40683a26f..e670e4883 100644 --- a/integration/test_collection_batch.py +++ b/integration/test_collection_batch.py @@ -271,6 +271,7 @@ def test_non_existant_collection(collection_factory_get: 
CollectionFactoryGet) - @pytest.mark.asyncio +@pytest.mark.timeout(600) async def test_batch_one_hundred_thousand_objects_async_collection( batch_collection_async: BatchCollectionAsync, ) -> None: @@ -298,6 +299,7 @@ async def test_batch_one_hundred_thousand_objects_async_collection( @pytest.mark.asyncio +@pytest.mark.timeout(600) async def test_ingest_one_hundred_thousand_data_objects_async( batch_collection_async: BatchCollectionAsync, ) -> None: @@ -319,6 +321,7 @@ async def test_ingest_one_hundred_thousand_data_objects_async( assert len(results.errors) == 0, [obj.message for obj in results.errors.values()] +@pytest.mark.timeout(600) def test_ingest_one_hundred_thousand_data_objects( batch_collection: BatchCollection, ) -> None: diff --git a/pytest.ini b/pytest.ini index 36321aa1d..0516dd966 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,5 +1,5 @@ [pytest] -addopts = -m 'not profiling' --benchmark-skip -l +addopts = -m 'not profiling' --benchmark-skip -l --timeout=300 --timeout_method=thread markers = profiling: marks tests that can be profiled asyncio_default_fixture_loop_scope = function \ No newline at end of file diff --git a/requirements-test.txt b/requirements-test.txt index c267bab24..ab1c3c036 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -3,6 +3,7 @@ pytest-cov==6.2.1 pytest-asyncio==1.3.0 pytest-benchmark==5.1.0 pytest-profiling==1.8.1 +pytest-timeout==2.3.1 coverage==7.10.7 pytest-xdist==3.7.0 werkzeug==3.1.6 From b439d5cd105ab6e46b1c0e5c8d3cc636da18c759 Mon Sep 17 00:00:00 2001 From: Tommy Smith Date: Thu, 16 Apr 2026 13:19:32 +0100 Subject: [PATCH 02/17] Avoid crashing the xdist node on failure --- conftest.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/conftest.py b/conftest.py index 6b40adda9..f95927fab 100644 --- a/conftest.py +++ b/conftest.py @@ -8,6 +8,11 @@ def pytest_runtest_setup(item: pytest.Item) -> None: This fires even if the process is stuck in C code (e.g., gRPC core). 
Set to pytest-timeout value + 30s so pytest-timeout handles it first. + + Uses exit=False to avoid killing xdist worker processes — a killed worker + causes 'node down: Not properly terminated' and loses the stack trace output. + With exit=False, faulthandler dumps tracebacks to stderr (relayed by xdist) + without terminating the process, letting pytest-timeout handle the interruption. """ marker = item.get_closest_marker("timeout") if marker and marker.args: @@ -16,7 +21,7 @@ def pytest_runtest_setup(item: pytest.Item) -> None: test_timeout = item.config.getini("timeout") or 300 if test_timeout and float(test_timeout) > 0: - faulthandler.dump_traceback_later(float(test_timeout) + 30, exit=True) + faulthandler.dump_traceback_later(float(test_timeout) + 30, exit=False) def pytest_runtest_teardown(item: pytest.Item, nextitem: pytest.Item | None) -> None: From b1b25e081bc2953bb60a0561aa50c115a44948be Mon Sep 17 00:00:00 2001 From: Tommy Smith Date: Thu, 16 Apr 2026 14:26:25 +0100 Subject: [PATCH 03/17] Timeout with signal rather than thread to avoid crashing xdist worker node --- pytest.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytest.ini b/pytest.ini index 0516dd966..6c3287680 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,5 +1,5 @@ [pytest] -addopts = -m 'not profiling' --benchmark-skip -l --timeout=300 --timeout_method=thread +addopts = -m 'not profiling' --benchmark-skip -l --timeout=300 --timeout_method=signal markers = profiling: marks tests that can be profiled asyncio_default_fixture_loop_scope = function \ No newline at end of file From 1630b0b1c5196c1c600bcc1f09446a5ac53ae75f Mon Sep 17 00:00:00 2001 From: Tommy Smith Date: Thu, 16 Apr 2026 17:31:09 +0100 Subject: [PATCH 04/17] Add test to verify behaviour with xdist, change behaviour to match test --- conftest.py | 62 +++++++++++++++++++++-------- pytest.ini | 3 +- requirements-test.txt | 1 - test/test_timeout.py | 90 +++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 
138 insertions(+), 18 deletions(-) create mode 100644 test/test_timeout.py diff --git a/conftest.py b/conftest.py index f95927fab..ced4f6ed2 100644 --- a/conftest.py +++ b/conftest.py @@ -1,29 +1,59 @@ import faulthandler +import os +import threading import pytest +DEFAULT_TIMEOUT = 300 # 5 minutes -def pytest_runtest_setup(item: pytest.Item) -> None: - """Set faulthandler alarm as a backup timeout mechanism. +_timeout_timer: threading.Timer | None = None - This fires even if the process is stuck in C code (e.g., gRPC core). - Set to pytest-timeout value + 30s so pytest-timeout handles it first. - Uses exit=False to avoid killing xdist worker processes — a killed worker - causes 'node down: Not properly terminated' and loses the stack trace output. - With exit=False, faulthandler dumps tracebacks to stderr (relayed by xdist) - without terminating the process, letting pytest-timeout handle the interruption. - """ +def _get_timeout(item: pytest.Item) -> float: marker = item.get_closest_marker("timeout") if marker and marker.args: - test_timeout = marker.args[0] - else: - test_timeout = item.config.getini("timeout") or 300 + return float(marker.args[0]) + return float(DEFAULT_TIMEOUT) - if test_timeout and float(test_timeout) > 0: - faulthandler.dump_traceback_later(float(test_timeout) + 30, exit=False) + +def pytest_runtest_setup(item: pytest.Item) -> None: + """Start a watchdog timer that dumps all thread stack traces on timeout. + + Unlike pytest-timeout, this does NOT raise KeyboardInterrupt (which crashes + xdist workers and corrupts asyncio event loops). Instead it: + 1. Writes the test name + all thread tracebacks directly to fd 2 (stderr). + With --capture=sys in pytest.ini, fd 2 is the real stderr (not captured), + so the output goes directly to the CI log even under xdist. + 2. Calls os._exit(1) to terminate the worker process. 
+ + xdist will report 'node down: Not properly terminated' which is expected — + the diagnostic output will already be in the CI logs above that message. + """ + global _timeout_timer + timeout = _get_timeout(item) + if timeout <= 0: + return + + def _on_timeout() -> None: + banner = "=" * 70 + os.write(2, f"\n\n{banner}\n".encode()) + os.write(2, f"TIMEOUT: {item.nodeid} exceeded {timeout}s\n".encode()) + os.write(2, f"{banner}\n\n".encode()) + # faulthandler needs a file object — wrap a dup of fd 2 to avoid closing it + with os.fdopen(os.dup(2), "w") as f: + faulthandler.dump_traceback(file=f) + f.flush() + os.write(2, f"\n{banner}\n\n".encode()) + os._exit(1) + + _timeout_timer = threading.Timer(timeout, _on_timeout) + _timeout_timer.daemon = True + _timeout_timer.start() def pytest_runtest_teardown(item: pytest.Item, nextitem: pytest.Item | None) -> None: - """Cancel the faulthandler alarm after each test.""" - faulthandler.cancel_dump_traceback_later() + """Cancel the watchdog timer after each test completes.""" + global _timeout_timer + if _timeout_timer is not None: + _timeout_timer.cancel() + _timeout_timer = None diff --git a/pytest.ini b/pytest.ini index 6c3287680..f87230abf 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,5 +1,6 @@ [pytest] -addopts = -m 'not profiling' --benchmark-skip -l --timeout=300 --timeout_method=signal +addopts = -m 'not profiling' --benchmark-skip -l --capture=sys --max-worker-restart=3 markers = profiling: marks tests that can be profiled + timeout: marks tests with a custom timeout in seconds (default: 300) asyncio_default_fixture_loop_scope = function \ No newline at end of file diff --git a/requirements-test.txt b/requirements-test.txt index ab1c3c036..c267bab24 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -3,7 +3,6 @@ pytest-cov==6.2.1 pytest-asyncio==1.3.0 pytest-benchmark==5.1.0 pytest-profiling==1.8.1 -pytest-timeout==2.3.1 coverage==7.10.7 pytest-xdist==3.7.0 werkzeug==3.1.6 diff --git 
a/test/test_timeout.py b/test/test_timeout.py new file mode 100644 index 000000000..fc49f51e0 --- /dev/null +++ b/test/test_timeout.py @@ -0,0 +1,90 @@ +"""Tests for the custom per-test timeout mechanism in conftest.py. + +Uses subprocess because the timeout mechanism calls os._exit(1). +""" + +import subprocess +import sys +import textwrap +from pathlib import Path + +PROJECT_ROOT = Path(__file__).parent.parent + + +def _run_pytest(tmp_path: Path, test_code: str, *extra_args: str) -> subprocess.CompletedProcess: + """Run pytest in a subprocess with a copy of our timeout conftest.""" + (tmp_path / "conftest.py").write_text((PROJECT_ROOT / "conftest.py").read_text()) + (tmp_path / "pytest.ini").write_text( + "[pytest]\naddopts = --capture=sys --max-worker-restart=0\nmarkers =\n timeout: custom timeout\n" + ) + (tmp_path / "test_it.py").write_text(textwrap.dedent(test_code)) + return subprocess.run( + [ + sys.executable, + "-m", + "pytest", + "-v", + "-n", + "auto", + "--dist", + "loadgroup", + "test_it.py", + *extra_args, + ], + capture_output=True, + text=True, + timeout=60, + cwd=str(tmp_path), + ) + + +def test_timeout_prints_test_name_and_stacktrace(tmp_path: Path) -> None: + result = _run_pytest( + tmp_path, + """\ + import time + import pytest + + @pytest.mark.timeout(2) + def test_hangs(): + time.sleep(999) + """, + ) + assert result.returncode != 0 + assert "TIMEOUT: test_it.py::test_hangs exceeded 2.0s" in result.stderr + assert "test_hangs" in result.stderr + + +def test_fast_test_not_killed(tmp_path: Path) -> None: + result = _run_pytest( + tmp_path, + """\ + import pytest + + @pytest.mark.timeout(10) + def test_fast(): + assert True + """, + ) + assert result.returncode == 0 + assert "TIMEOUT" not in result.stderr + + +def test_timeout_with_passing_and_hanging_test(tmp_path: Path) -> None: + result = _run_pytest( + tmp_path, + """\ + import time + import pytest + + @pytest.mark.timeout(2) + def test_hangs_in_worker(): + time.sleep(999) + + def 
test_passes(): + assert True + """, + ) + assert result.returncode != 0 + assert "TIMEOUT: test_it.py::test_hangs_in_worker exceeded 2.0s" in result.stderr + assert "test_hangs_in_worker" in result.stderr From 95486700f924cfbdcbc8595df703ec5ce960fb71 Mon Sep 17 00:00:00 2001 From: Tommy Smith Date: Fri, 17 Apr 2026 10:49:51 +0100 Subject: [PATCH 05/17] Attempt bug squashing of hanging test flake in CI through stricter inflight_{objs,refs} logic --- weaviate/collections/batch/async_.py | 34 +++++++++++++--------------- weaviate/collections/batch/base.py | 8 +++++++ weaviate/collections/batch/sync.py | 31 +++++++++++++------------ 3 files changed, 41 insertions(+), 32 deletions(-) diff --git a/weaviate/collections/batch/async_.py b/weaviate/collections/batch/async_.py index b71f8be39..a3c88926d 100644 --- a/weaviate/collections/batch/async_.py +++ b/weaviate/collections/batch/async_.py @@ -17,6 +17,7 @@ ObjectsBatchRequest, ReferencesBatchRequest, _BatchDataWrapper, + _BatchStreamRequest, _ClusterBatchAsync, ) from weaviate.collections.batch.grpc_batch import _BatchGRPC @@ -117,9 +118,7 @@ def __init__( # maxsize=1 so that __send does not run faster than generator for __recv # thereby using too much buffer in case of server-side shutdown - self.__reqs: asyncio.Queue[Optional[batch_pb2.BatchStreamRequest]] = asyncio.Queue( - maxsize=1 - ) + self.__reqs: asyncio.Queue[Optional[_BatchStreamRequest]] = asyncio.Queue(maxsize=1) self.__bg_exception: Optional[Exception] = None self.__bg_tasks: Optional[_BgTasks] = None @@ -170,6 +169,8 @@ async def recv_wrapper() -> None: await self.__batch_objects.aprepend(list(self.__objs_cache.values())) async with self.__refs_cache_lock: await self.__batch_references.aprepend(list(self.__refs_cache.values())) + self.__inflight_objs.clear() + self.__inflight_refs.clear() # start a new stream with a newly reconnected channel return await recv_wrapper() @@ -262,7 +263,7 @@ def __generate_stream_requests( self, objects: List[BatchObject], 
references: List[BatchReference], - ) -> Generator[batch_pb2.BatchStreamRequest, None, None]: + ) -> Generator[_BatchStreamRequest, None, None]: per_object_overhead = 4 # extra overhead bytes per object in the request def request_maker(): @@ -271,8 +272,7 @@ def request_maker(): request = request_maker() total_size = request.ByteSize() - inflight_objs = set() - inflight_refs = set() + uuids, beacons = set() for object_ in objects: obj = self.__batch_grpc.grpc_object(object_._to_internal()) obj_size = obj.ByteSize() + per_object_overhead @@ -283,35 +283,31 @@ def request_maker(): ) if total_size + obj_size >= self.__batch_grpc.grpc_max_msg_size: - self.__inflight_objs.update(inflight_objs) - self.__inflight_refs.update(inflight_refs) - yield request + yield _BatchStreamRequest(request, uuids, beacons) request = request_maker() total_size = request.ByteSize() + uuids, beacons = set(), set() request.data.objects.values.append(obj) total_size += obj_size - inflight_objs.add(obj.uuid) + uuids.add(obj.uuid) for reference in references: ref = self.__batch_grpc.grpc_reference(reference._to_internal()) ref_size = ref.ByteSize() + per_object_overhead if total_size + ref_size >= self.__batch_grpc.grpc_max_msg_size: - self.__inflight_objs.update(inflight_objs) - self.__inflight_refs.update(inflight_refs) - yield request + yield _BatchStreamRequest(request, uuids, beacons) request = request_maker() total_size = request.ByteSize() + uuids, beacons = set(), set() request.data.references.values.append(ref) total_size += ref_size - inflight_refs.add(reference._to_beacon()) + beacons.add(reference._to_beacon()) if len(request.data.objects.values) > 0 or len(request.data.references.values) > 0: - self.__inflight_objs.update(inflight_objs) - self.__inflight_refs.update(inflight_refs) - yield request + yield _BatchStreamRequest(request, uuids, beacons) async def __send( self, @@ -341,7 +337,9 @@ async def __send( ) yield 
batch_pb2.BatchStreamRequest(stop=batch_pb2.BatchStreamRequest.Stop()) return - yield req + yield req.proto + self.__inflight_objs.update(req.uuids) + self.__inflight_refs.update(req.beacons) continue except asyncio.TimeoutError: if self.__is_shutting_down.is_set(): diff --git a/weaviate/collections/batch/base.py b/weaviate/collections/batch/base.py index 14a0c0768..9ee04e136 100644 --- a/weaviate/collections/batch/base.py +++ b/weaviate/collections/batch/base.py @@ -43,6 +43,7 @@ WeaviateBatchValidationError, ) from weaviate.logger import logger +from weaviate.proto.v1 import batch_pb2 from weaviate.types import UUID, VECTORS from weaviate.util import _decode_json_response_dict from weaviate.warnings import _Warnings @@ -227,6 +228,13 @@ async def ahead(self) -> Optional[Obj]: return self.__head() +@dataclass +class _BatchStreamRequest: + proto: batch_pb2.BatchStreamRequest + uuids: set[str] + beacons: set[str] + + @dataclass class _BatchDataWrapper: results: BatchResult = field(default_factory=BatchResult) diff --git a/weaviate/collections/batch/sync.py b/weaviate/collections/batch/sync.py index f219de563..772c17906 100644 --- a/weaviate/collections/batch/sync.py +++ b/weaviate/collections/batch/sync.py @@ -13,6 +13,7 @@ ReferencesBatchRequest, _BatchDataWrapper, _BatchMode, + _BatchStreamRequest, _BgThreads, _ClusterBatch, ) @@ -109,7 +110,7 @@ def __init__( # maxsize=1 so that __loop does not run faster than generator for __recv # thereby using too much buffer in case of server-side shutdown - self.__reqs: Queue[Optional[batch_pb2.BatchStreamRequest]] = Queue(maxsize=1) + self.__reqs: Queue[Optional[_BatchStreamRequest]] = Queue(maxsize=1) @property def number_errors(self) -> int: @@ -214,7 +215,7 @@ def __generate_stream_requests( self, objects: List[BatchObject], references: List[BatchReference], - ) -> Generator[batch_pb2.BatchStreamRequest, None, None]: + ) -> Generator[_BatchStreamRequest, None, None]: per_object_overhead = 4 # extra overhead bytes per 
object in the request def request_maker(): @@ -223,8 +224,7 @@ def request_maker(): request = request_maker() total_size = request.ByteSize() - inflight_objs = set() - inflight_refs = set() + uuids, beacons = set(), set() for object_ in objects: obj = self.__batch_grpc.grpc_object(object_._to_internal()) obj_size = obj.ByteSize() + per_object_overhead @@ -235,33 +235,31 @@ def request_maker(): ) if total_size + obj_size >= self.__batch_grpc.grpc_max_msg_size: - yield request + yield _BatchStreamRequest(request, uuids, beacons) request = request_maker() total_size = request.ByteSize() + uuids, beacons = set(), set() request.data.objects.values.append(obj) total_size += obj_size - inflight_objs.add(obj.uuid) + uuids.add(obj.uuid) for reference in references: ref = self.__batch_grpc.grpc_reference(reference._to_internal()) ref_size = ref.ByteSize() + per_object_overhead if total_size + ref_size >= self.__batch_grpc.grpc_max_msg_size: - yield request + yield _BatchStreamRequest(request, uuids, beacons) request = request_maker() total_size = request.ByteSize() + uuids, beacons = set(), set() request.data.references.values.append(ref) total_size += ref_size - inflight_refs.add(reference._to_beacon()) - - with self.__acks_lock: - self.__inflight_objs.update(inflight_objs) - self.__inflight_refs.update(inflight_refs) + beacons.add(reference._to_beacon()) if len(request.data.objects.values) > 0 or len(request.data.references.values) > 0: - yield request + yield _BatchStreamRequest(request, uuids, beacons) def __send( self, @@ -291,7 +289,10 @@ def __send( ) yield batch_pb2.BatchStreamRequest(stop=batch_pb2.BatchStreamRequest.Stop()) return - yield req + yield req.proto + with self.__acks_lock: + self.__inflight_objs.update(req.uuids) + self.__inflight_refs.update(req.beacons) continue except Empty: if self.__is_shutting_down.is_set(): @@ -513,6 +514,8 @@ def recv_wrapper() -> None: self.__batch_objects.prepend(list(self.__objs_cache.values())) with self.__refs_cache_lock: 
self.__batch_references.prepend(list(self.__refs_cache.values())) + self.__inflight_objs.clear() + self.__inflight_refs.clear() # start a new stream with a newly reconnected channel return recv_wrapper() From 5a0e8f400b5833ce281b697b8c8a13153b7e8059 Mon Sep 17 00:00:00 2001 From: Tommy Smith Date: Fri, 17 Apr 2026 10:55:21 +0100 Subject: [PATCH 06/17] Fix typo --- weaviate/collections/batch/async_.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/weaviate/collections/batch/async_.py b/weaviate/collections/batch/async_.py index a3c88926d..d687af02f 100644 --- a/weaviate/collections/batch/async_.py +++ b/weaviate/collections/batch/async_.py @@ -272,7 +272,7 @@ def request_maker(): request = request_maker() total_size = request.ByteSize() - uuids, beacons = set() + uuids, beacons = set(), set() for object_ in objects: obj = self.__batch_grpc.grpc_object(object_._to_internal()) obj_size = obj.ByteSize() + per_object_overhead From 1c67654d961a133a1c8780ff9f2fbd0d42f438e5 Mon Sep 17 00:00:00 2001 From: Tommy Smith Date: Fri, 17 Apr 2026 11:42:05 +0100 Subject: [PATCH 07/17] Move inflight updates before yield to grpc to avoid racing --- weaviate/collections/batch/async_.py | 2 +- weaviate/collections/batch/sync.py | 12 ++++++++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/weaviate/collections/batch/async_.py b/weaviate/collections/batch/async_.py index d687af02f..e4aa685ea 100644 --- a/weaviate/collections/batch/async_.py +++ b/weaviate/collections/batch/async_.py @@ -337,9 +337,9 @@ async def __send( ) yield batch_pb2.BatchStreamRequest(stop=batch_pb2.BatchStreamRequest.Stop()) return - yield req.proto self.__inflight_objs.update(req.uuids) self.__inflight_refs.update(req.beacons) + yield req.proto continue except asyncio.TimeoutError: if self.__is_shutting_down.is_set(): diff --git a/weaviate/collections/batch/sync.py b/weaviate/collections/batch/sync.py index 772c17906..d796d18f1 100644 --- a/weaviate/collections/batch/sync.py 
+++ b/weaviate/collections/batch/sync.py @@ -289,10 +289,10 @@ def __send( ) yield batch_pb2.BatchStreamRequest(stop=batch_pb2.BatchStreamRequest.Stop()) return - yield req.proto with self.__acks_lock: self.__inflight_objs.update(req.uuids) self.__inflight_refs.update(req.beacons) + yield req.proto continue except Empty: if self.__is_shutting_down.is_set(): @@ -304,7 +304,9 @@ def __send( elif self.__is_hungup.is_set(): logger.info("Detected hung up stream, closing the client-side of the stream") return - logger.info("Timed out getting request from queue, but not stopping, continuing...") + logger.debug( + "Timed out getting request from queue, but not stopping, continuing..." + ) logger.info("Batch send thread exiting due to exception...") def __recv(self) -> None: @@ -575,6 +577,11 @@ def _add_object( self.__objs_count += 1 while self.__is_blocked(): + logger.warning("Batch is blocked, waiting to add more objects...") + if len(self.__inflight_objs) >= self.__batch_size: + logger.info( + f"Too many inflight_objs, waiting for acknowledgements from the server: {len(self.__inflight_objs)}, {self.__batch_size}" + ) self.__check_bg_threads_alive() time.sleep(0.01) @@ -617,6 +624,7 @@ def _add_reference( self.__refs_cache[batch_reference._to_beacon()] = batch_reference self.__refs_count += 1 while self.__is_blocked(): + logger.warning("Batch is blocked, waiting to add more references...") self.__check_bg_threads_alive() time.sleep(0.01) From 66d5c96cae4f1cb0bd628d975dc0ae22d84b1359 Mon Sep 17 00:00:00 2001 From: Tommy Smith Date: Fri, 17 Apr 2026 11:49:59 +0100 Subject: [PATCH 08/17] Remove debugging logs --- weaviate/collections/batch/sync.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/weaviate/collections/batch/sync.py b/weaviate/collections/batch/sync.py index d796d18f1..b85ef7c1d 100644 --- a/weaviate/collections/batch/sync.py +++ b/weaviate/collections/batch/sync.py @@ -577,11 +577,6 @@ def _add_object( self.__objs_count += 1 while self.__is_blocked(): - 
logger.warning("Batch is blocked, waiting to add more objects...") - if len(self.__inflight_objs) >= self.__batch_size: - logger.info( - f"Too many inflight_objs, waiting for acknowledgements from the server: {len(self.__inflight_objs)}, {self.__batch_size}" - ) self.__check_bg_threads_alive() time.sleep(0.01) From a674ec1845948d0ce09bb148d07d43f350b8c883 Mon Sep 17 00:00:00 2001 From: Tommy Smith Date: Fri, 17 Apr 2026 15:40:32 +0100 Subject: [PATCH 09/17] Add safety hatch in case server doesn't close its side of the stream in time (300s) --- weaviate/collections/batch/async_.py | 15 ++++++++++++++- weaviate/collections/batch/base.py | 1 + weaviate/collections/batch/sync.py | 14 ++++++++++++++ 3 files changed, 29 insertions(+), 1 deletion(-) diff --git a/weaviate/collections/batch/async_.py b/weaviate/collections/batch/async_.py index e4aa685ea..0a8503684 100644 --- a/weaviate/collections/batch/async_.py +++ b/weaviate/collections/batch/async_.py @@ -14,6 +14,7 @@ from weaviate.collections.batch.base import ( GCP_STREAM_TIMEOUT, + SHUTDOWN_TIMEOUT, ObjectsBatchRequest, ReferencesBatchRequest, _BatchDataWrapper, @@ -182,8 +183,20 @@ async def recv_wrapper() -> None: loop=loop, ) - async def _wait(self): + async def _wait(self) -> None: assert self.__bg_tasks is not None + deadline = time.time() + SHUTDOWN_TIMEOUT + while time.time() < deadline: + if not self.__bg_tasks.all_alive(): + break + await asyncio.sleep(0.1) + if self.__bg_tasks.all_alive(): + logger.warning( + f"Background batch tasks did not exit within {SHUTDOWN_TIMEOUT}s. " + f"Forcing shutdown. 
inflight_objs={len(self.__inflight_objs)}, " + f"inflight_refs={len(self.__inflight_refs)}" + ) + self.__shutdown_loop.set() # force __loop to exit await self.__bg_tasks.gather() # copy the results to the public results diff --git a/weaviate/collections/batch/base.py b/weaviate/collections/batch/base.py index 9ee04e136..08cc1a9df 100644 --- a/weaviate/collections/batch/base.py +++ b/weaviate/collections/batch/base.py @@ -64,6 +64,7 @@ GCP_STREAM_TIMEOUT = ( 160 # GCP connections have a max lifetime of 180s, leave 20s of buffer as safety ) +SHUTDOWN_TIMEOUT = 300 # time to wait for background threads to exit after shutdown is initiated, in seconds, in the event the server never hangs up class BatchRequest(ABC, Generic[TBatchInput, TBatchReturn]): diff --git a/weaviate/collections/batch/sync.py b/weaviate/collections/batch/sync.py index b85ef7c1d..01b082ea3 100644 --- a/weaviate/collections/batch/sync.py +++ b/weaviate/collections/batch/sync.py @@ -9,6 +9,7 @@ from weaviate.collections.batch.base import ( GCP_STREAM_TIMEOUT, + SHUTDOWN_TIMEOUT, ObjectsBatchRequest, ReferencesBatchRequest, _BatchDataWrapper, @@ -135,6 +136,19 @@ def _start(self) -> None: ) def _wait(self) -> None: + deadline = time.time() + SHUTDOWN_TIMEOUT + while time.time() < deadline: + if not self.__bg_threads.is_alive(): + break + time.sleep(0.1) + if self.__bg_threads.is_alive(): + logger.warning( + f"Background batch threads did not exit within {SHUTDOWN_TIMEOUT}s. " + f"Forcing shutdown. 
inflight_objs={len(self.__inflight_objs)}, " + f"inflight_refs={len(self.__inflight_refs)}" + ) + self.__shutdown_loop.set() # force __loop to exit + self.__bg_threads.join() # copy the results to the public results From f5d50989c676c74240f8b75d3e38b2faa4b281fb Mon Sep 17 00:00:00 2001 From: Tommy Smith Date: Fri, 17 Apr 2026 17:19:30 +0100 Subject: [PATCH 10/17] Improve shutdown behaviour of client by adding timeouts to joins and cancelling stream if necessary --- test/collection/test_batch_lifecycle.py | 59 +++++++++++++++++++++++++ weaviate/collections/batch/async_.py | 40 ++++++++++++++--- weaviate/collections/batch/base.py | 10 +++-- weaviate/collections/batch/sync.py | 59 ++++++++++++++++++++----- weaviate/connect/v4.py | 37 +++++++++------- 5 files changed, 168 insertions(+), 37 deletions(-) create mode 100644 test/collection/test_batch_lifecycle.py diff --git a/test/collection/test_batch_lifecycle.py b/test/collection/test_batch_lifecycle.py new file mode 100644 index 000000000..862f642c9 --- /dev/null +++ b/test/collection/test_batch_lifecycle.py @@ -0,0 +1,59 @@ +import asyncio +import threading +import time + +from weaviate.collections.batch.async_ import _BgTasks +from weaviate.collections.batch.base import _BgThreads + + +def test_bg_threads_any_alive_and_join_timeout() -> None: + def _short() -> None: + time.sleep(0.05) + + def _long() -> None: + time.sleep(0.3) + + bg_threads = _BgThreads( + loop=threading.Thread(target=_long, daemon=True), + recv=threading.Thread(target=_short, daemon=True), + ) + bg_threads.start_recv() + bg_threads.start_loop() + + time.sleep(0.12) + assert bg_threads.recv_alive() is False + assert bg_threads.loop_alive() is True + assert bg_threads.is_alive() is False + assert bg_threads.any_alive() is True + + start = time.time() + bg_threads.join(timeout=0.01) + assert time.time() - start < 0.2 + assert bg_threads.any_alive() is True + + bg_threads.join(timeout=1) + assert bg_threads.any_alive() is False + + +def 
test_bg_tasks_any_alive() -> None: + async def _short() -> None: + await asyncio.sleep(0.05) + + async def _long() -> None: + await asyncio.sleep(0.3) + + async def _run() -> None: + recv = asyncio.create_task(_short()) + loop = asyncio.create_task(_long()) + bg_tasks = _BgTasks(recv=recv, loop=loop) + + await asyncio.sleep(0.12) + assert bg_tasks.recv_alive() is False + assert bg_tasks.loop_alive() is True + assert bg_tasks.all_alive() is False + assert bg_tasks.any_alive() is True + + await bg_tasks.gather() + assert bg_tasks.any_alive() is False + + asyncio.run(_run()) diff --git a/weaviate/collections/batch/async_.py b/weaviate/collections/batch/async_.py index 0a8503684..efe1007c3 100644 --- a/weaviate/collections/batch/async_.py +++ b/weaviate/collections/batch/async_.py @@ -61,9 +61,18 @@ def __init__(self, recv: asyncio.Task[None], loop: asyncio.Task[None]) -> None: def all_alive(self) -> bool: return all([not self.recv.done(), not self.loop.done()]) + def any_alive(self) -> bool: + return not self.recv.done() or not self.loop.done() + + def recv_alive(self) -> bool: + return not self.recv.done() + + def loop_alive(self) -> bool: + return not self.loop.done() + async def gather(self) -> None: tasks = [self.recv, self.loop] - await asyncio.gather(*tasks) + await asyncio.gather(*tasks, return_exceptions=True) class _BatchBaseAsync: @@ -187,17 +196,32 @@ async def _wait(self) -> None: assert self.__bg_tasks is not None deadline = time.time() + SHUTDOWN_TIMEOUT while time.time() < deadline: - if not self.__bg_tasks.all_alive(): + if not self.__bg_tasks.any_alive(): break await asyncio.sleep(0.1) - if self.__bg_tasks.all_alive(): + if self.__bg_tasks.any_alive(): logger.warning( f"Background batch tasks did not exit within {SHUTDOWN_TIMEOUT}s. " f"Forcing shutdown. 
inflight_objs={len(self.__inflight_objs)}, " - f"inflight_refs={len(self.__inflight_refs)}" + f"inflight_refs={len(self.__inflight_refs)}, " + f"loop_alive={self.__bg_tasks.loop_alive()}, " + f"recv_alive={self.__bg_tasks.recv_alive()}" ) self.__shutdown_loop.set() # force __loop to exit - await self.__bg_tasks.gather() + self.__bg_tasks.recv.cancel() + self.__bg_tasks.loop.cancel() + try: + await asyncio.wait_for(self.__bg_tasks.gather(), timeout=5) + except asyncio.TimeoutError as e: + raise WeaviateBatchStreamError( + "Background batch tasks did not terminate after forced shutdown." + ) from e + if self.__bg_tasks.any_alive(): + raise WeaviateBatchStreamError( + "Background batch tasks did not terminate after forced shutdown. " + f"loop_alive={self.__bg_tasks.loop_alive()}, " + f"recv_alive={self.__bg_tasks.recv_alive()}" + ) # copy the results to the public results self.__results_for_wrapper_backup.results = self.__results_for_wrapper.results @@ -355,7 +379,11 @@ async def __send( yield req.proto continue except asyncio.TimeoutError: - if self.__is_shutting_down.is_set(): + if self.__shutdown_loop.is_set() or self.__is_stopped.is_set(): + logger.info("Batch shutdown requested, stopping and closing the stream") + yield batch_pb2.BatchStreamRequest(stop=batch_pb2.BatchStreamRequest.Stop()) + return + elif self.__is_shutting_down.is_set(): logger.info("Server shutting down, closing the client-side of the stream") return elif self.__is_oom.is_set(): diff --git a/weaviate/collections/batch/base.py b/weaviate/collections/batch/base.py index 08cc1a9df..ceba10a25 100644 --- a/weaviate/collections/batch/base.py +++ b/weaviate/collections/batch/base.py @@ -895,6 +895,10 @@ def is_alive(self) -> bool: """Check if the background threads are still alive.""" return self.loop_alive() and self.recv_alive() + def any_alive(self) -> bool: + """Check if any background thread is still alive.""" + return self.loop_alive() or self.recv_alive() + def loop_alive(self) -> bool: 
"""Check if the loop background thread is still alive.""" if self.__started_loop: @@ -907,10 +911,10 @@ def recv_alive(self) -> bool: return self.recv.is_alive() return True # not started yet so considered alive - def join(self) -> None: + def join(self, timeout: Optional[float] = None) -> None: """Join the background threads.""" - self.loop.join() - self.recv.join() + self.loop.join(timeout=timeout) + self.recv.join(timeout=timeout) class _ClusterBatch: diff --git a/weaviate/collections/batch/sync.py b/weaviate/collections/batch/sync.py index 01b082ea3..42aad3a73 100644 --- a/weaviate/collections/batch/sync.py +++ b/weaviate/collections/batch/sync.py @@ -5,6 +5,7 @@ from queue import Empty, Full, Queue from typing import Generator, List, Optional, Set, Union +from grpc import Call from pydantic import ValidationError from weaviate.collections.batch.base import ( @@ -98,7 +99,6 @@ def __init__( self.__oom_wait_time = 300 self.__shutdown_loop = threading.Event() - self.__sent_sentinel = threading.Event() self.__objs_cache_lock = threading.Lock() self.__refs_cache_lock = threading.Lock() @@ -112,6 +112,8 @@ def __init__( # maxsize=1 so that __loop does not run faster than generator for __recv # thereby using too much buffer in case of server-side shutdown self.__reqs: Queue[Optional[_BatchStreamRequest]] = Queue(maxsize=1) + self.__stream_lock = threading.Lock() + self.__active_stream: Optional[Call] = None @property def number_errors(self) -> int: @@ -123,6 +125,26 @@ def number_errors(self) -> int: def __all_threads_alive(self) -> bool: return self.__bg_threads.is_alive() + def __any_threads_alive(self) -> bool: + return self.__bg_threads.any_alive() + + def __set_active_stream(self, call: Call) -> None: + with self.__stream_lock: + self.__active_stream = call + + def __clear_active_stream(self) -> None: + with self.__stream_lock: + self.__active_stream = None + + def __cancel_active_stream(self) -> bool: + with self.__stream_lock: + stream = self.__active_stream + 
+ if stream is None: + return False + + return stream.cancel() + def _start(self) -> None: self.__start_bg_threads() logger.info("Provisioned stream to the server for batch processing") @@ -138,18 +160,28 @@ def _start(self) -> None: def _wait(self) -> None: deadline = time.time() + SHUTDOWN_TIMEOUT while time.time() < deadline: - if not self.__bg_threads.is_alive(): + if not self.__any_threads_alive(): break time.sleep(0.1) - if self.__bg_threads.is_alive(): + if self.__any_threads_alive(): logger.warning( f"Background batch threads did not exit within {SHUTDOWN_TIMEOUT}s. " f"Forcing shutdown. inflight_objs={len(self.__inflight_objs)}, " - f"inflight_refs={len(self.__inflight_refs)}" + f"inflight_refs={len(self.__inflight_refs)}, " + f"loop_alive={self.__bg_threads.loop_alive()}, " + f"recv_alive={self.__bg_threads.recv_alive()}" ) self.__shutdown_loop.set() # force __loop to exit - - self.__bg_threads.join() + self.__is_stopped.set() + self.__cancel_active_stream() # force __recv to exit by cancelling the stream + + self.__bg_threads.join(timeout=5) + if self.__any_threads_alive(): + raise WeaviateBatchStreamError( + "Background batch threads did not terminate after forced shutdown. 
" + f"loop_alive={self.__bg_threads.loop_alive()}, " + f"recv_alive={self.__bg_threads.recv_alive()}" + ) # copy the results to the public results self.__results_for_wrapper_backup.results = self.__results_for_wrapper.results @@ -216,13 +248,11 @@ def __loop(self) -> None: return elif ( self.__is_stopped.is_set() - and not self.__sent_sentinel.is_set() and not self.__is_hungup.is_set() and not self.__is_shutting_down.is_set() and not self.__is_oom.is_set() ): self.__reqs.put(None) - self.__sent_sentinel.set() time.sleep(refresh_time) def __generate_stream_requests( @@ -278,7 +308,6 @@ def request_maker(): def __send( self, ) -> Generator[batch_pb2.BatchStreamRequest, None, None]: - self.__sent_sentinel.clear() yield batch_pb2.BatchStreamRequest( start=batch_pb2.BatchStreamRequest.Start( consistency_level=self.__batch_grpc._consistency_level, @@ -309,7 +338,11 @@ def __send( yield req.proto continue except Empty: - if self.__is_shutting_down.is_set(): + if self.__shutdown_loop.is_set() or self.__is_stopped.is_set(): + logger.info("Batch shutdown requested, stopping and closing the stream") + yield batch_pb2.BatchStreamRequest(stop=batch_pb2.BatchStreamRequest.Stop()) + return + elif self.__is_shutting_down.is_set(): logger.info("Server shutting down, closing the client-side of the stream") return elif self.__is_oom.is_set(): @@ -324,14 +357,16 @@ def __send( logger.info("Batch send thread exiting due to exception...") def __recv(self) -> None: - stream = self.__batch_grpc.stream( + gen, call = self.__batch_grpc.stream( connection=self.__connection, requests=self.__send(), ) + self.__set_active_stream(call) + self.__is_renewing_stream.clear() self.__is_shutting_down.clear() self.__is_hungup.clear() - for message in stream: + for message in gen: if message.HasField("started"): logger.info("Batch stream started successfully") diff --git a/weaviate/connect/v4.py b/weaviate/connect/v4.py index 56ece8ca2..4d8265375 100644 --- a/weaviate/connect/v4.py +++ 
b/weaviate/connect/v4.py @@ -1018,22 +1018,27 @@ def grpc_batch_objects( def grpc_batch_stream( self, requests: Generator[batch_pb2.BatchStreamRequest, None, None], - ) -> Generator[batch_pb2.BatchStreamReply, None, None]: - try: - assert self.grpc_stub is not None - for msg in self.grpc_stub.BatchStream( - request_iterator=requests, - timeout=self.timeout_config.stream, - metadata=self.grpc_headers(), - ): - yield msg - except RpcError as e: - error = cast(Call, e) - if error.code() == StatusCode.PERMISSION_DENIED: - raise InsufficientPermissionsError(error) - if error.code() == StatusCode.ABORTED: - raise _BatchStreamShutdownError() - raise WeaviateBatchStreamError(f"{error.code()}({error.details()})") + ) -> tuple[Generator[batch_pb2.BatchStreamReply, None, None], Call]: + assert self.grpc_stub is not None + call = self.grpc_stub.BatchStream( + request_iterator=requests, + timeout=self.timeout_config.stream, + metadata=self.grpc_headers(), + )() + + def generator(): + try: + for msg in call: + yield msg + except RpcError as e: + error = cast(Call, e) + if error.code() == StatusCode.PERMISSION_DENIED: + raise InsufficientPermissionsError(error) + if error.code() == StatusCode.ABORTED: + raise _BatchStreamShutdownError() + raise WeaviateBatchStreamError(f"{error.code()}({error.details()})") + + return generator(), call def grpc_batch_delete( self, request: batch_delete_pb2.BatchDeleteRequest From fd5453db994eaf3f5ea9d65a5931ab9aec6e799c Mon Sep 17 00:00:00 2001 From: Tommy Smith Date: Mon, 20 Apr 2026 12:05:42 +0100 Subject: [PATCH 11/17] Improvements to batching logic: - Avoid 60s timeout putting into self.__reqs - Handle graceful stopping of all bg threads - Allow cancelling of hanging streams - Reraise bg_exceptions when they happen - Align shutdown timeout with client-defined insert timeout - Add mock tests of cancelling bidi streamsa --- mock_tests/conftest.py | 48 +++++++++++++++++++++++++++- mock_tests/test_connect.py | 25 +++++++++++++++ 
weaviate/collections/batch/async_.py | 9 +++--- weaviate/collections/batch/base.py | 7 ++-- weaviate/collections/batch/sync.py | 41 +++++++++++++++++------- weaviate/connect/v4.py | 2 +- 6 files changed, 110 insertions(+), 22 deletions(-) create mode 100644 mock_tests/test_connect.py diff --git a/mock_tests/conftest.py b/mock_tests/conftest.py index de22b2e51..9c0bf19ec 100644 --- a/mock_tests/conftest.py +++ b/mock_tests/conftest.py @@ -1,7 +1,7 @@ import json import time from concurrent import futures -from typing import Generator, Mapping +from typing import AsyncGenerator, Generator, Mapping import grpc import pytest @@ -141,6 +141,16 @@ def weaviate_client( client.close() +@pytest.fixture(scope="function") +async def weaviate_client_async( + weaviate_mock: HTTPServer, start_grpc_server: grpc.Server +) -> AsyncGenerator[weaviate.WeaviateAsyncClient, None]: + client = weaviate.use_async_with_local(port=MOCK_PORT, host=MOCK_IP, grpc_port=MOCK_PORT_GRPC) + await client.connect() + yield client + await client.close() + + @pytest.fixture(scope="function") def weaviate_timeouts_client( weaviate_timeouts_mock: HTTPServer, start_grpc_server: grpc.Server @@ -368,3 +378,39 @@ def forbidden( service = MockForbiddenWeaviateService() weaviate_pb2_grpc.add_WeaviateServicer_to_server(service, start_grpc_server) return weaviate_client.collections.use("ForbiddenCollection") + + +class MockWeaviateService(weaviate_pb2_grpc.WeaviateServicer): + def BatchStream( + self, + request_iterator: Generator[batch_pb2.BatchStreamRequest, None, None], + context: grpc.ServicerContext, + ) -> Generator[batch_pb2.BatchStreamReply, None, None]: + while True: + if context.is_active(): + time.sleep(0.1) + else: + raise grpc.RpcError(grpc.StatusCode.DEADLINE_EXCEEDED, "Deadline exceeded") + + +@pytest.fixture(scope="function") +def stream_cancel( + weaviate_client: weaviate.WeaviateClient, + weaviate_mock: HTTPServer, + start_grpc_server: grpc.Server, +) -> 
Generator[weaviate.collections.Collection, None, None]: + name = "StreamCancelCollection" + weaviate_mock.expect_request(f"/v1/schema/{name}").respond_with_response( + Response(status=404) + ) # skips __create_batch_reset vectorizer logic + weaviate_pb2_grpc.add_WeaviateServicer_to_server(MockWeaviateService(), start_grpc_server) + client = weaviate.connect_to_local( + port=MOCK_PORT, + host=MOCK_IP, + grpc_port=MOCK_PORT_GRPC, + additional_config=weaviate.classes.init.AdditionalConfig( + timeout=weaviate.classes.init.Timeout(insert=1) + ), + ) + yield client.collections.use(name) + client.close() diff --git a/mock_tests/test_connect.py b/mock_tests/test_connect.py new file mode 100644 index 000000000..bb6625e53 --- /dev/null +++ b/mock_tests/test_connect.py @@ -0,0 +1,25 @@ +import time +import pytest +import weaviate +from weaviate.proto.v1 import batch_pb2 + + +def test_bidi_stream_cancel_sync(stream_cancel: weaviate.collections.Collection): + def gen(): + time.sleep(10) + yield batch_pb2.BatchStreamRequest() + + out, call = stream_cancel._connection.grpc_batch_stream(gen()) + assert call.is_active() + call.cancel() + assert not call.is_active() + with pytest.raises(weaviate.exceptions.WeaviateBatchStreamError) as e: + next(out) + assert "StatusCode.CANCELLED(Locally cancelled by application!)" in e.value.message + + +def test_batch_stream_hanging_server(stream_cancel: weaviate.collections.Collection): + with pytest.raises(weaviate.exceptions.WeaviateBatchStreamError) as e: + with stream_cancel.batch.stream() as batch: + batch.add_object() + assert "The server did not hangup its side of the stream gracefully in time" in e.value.message diff --git a/weaviate/collections/batch/async_.py b/weaviate/collections/batch/async_.py index efe1007c3..7b6eb8e1d 100644 --- a/weaviate/collections/batch/async_.py +++ b/weaviate/collections/batch/async_.py @@ -14,7 +14,6 @@ from weaviate.collections.batch.base import ( GCP_STREAM_TIMEOUT, - SHUTDOWN_TIMEOUT, 
ObjectsBatchRequest, ReferencesBatchRequest, _BatchDataWrapper, @@ -194,14 +193,16 @@ async def recv_wrapper() -> None: async def _wait(self) -> None: assert self.__bg_tasks is not None - deadline = time.time() + SHUTDOWN_TIMEOUT + # this is how long an insert will take to timeout for, so we wait at most this time +5s for the batch to finish after shutdown is initiated, in case the server never hangs up + shutdown_timeout = self.__connection.timeout_config.insert + 5 + deadline = time.time() + shutdown_timeout while time.time() < deadline: if not self.__bg_tasks.any_alive(): break await asyncio.sleep(0.1) if self.__bg_tasks.any_alive(): logger.warning( - f"Background batch tasks did not exit within {SHUTDOWN_TIMEOUT}s. " + f"Background batch tasks did not exit within {shutdown_timeout}s. " f"Forcing shutdown. inflight_objs={len(self.__inflight_objs)}, " f"inflight_refs={len(self.__inflight_refs)}, " f"loop_alive={self.__bg_tasks.loop_alive()}, " @@ -211,7 +212,7 @@ async def _wait(self) -> None: self.__bg_tasks.recv.cancel() self.__bg_tasks.loop.cancel() try: - await asyncio.wait_for(self.__bg_tasks.gather(), timeout=5) + await asyncio.wait_for(self.__bg_tasks.gather(), timeout=None) except asyncio.TimeoutError as e: raise WeaviateBatchStreamError( "Background batch tasks did not terminate after forced shutdown." 
diff --git a/weaviate/collections/batch/base.py b/weaviate/collections/batch/base.py index ceba10a25..fd9aae261 100644 --- a/weaviate/collections/batch/base.py +++ b/weaviate/collections/batch/base.py @@ -64,7 +64,6 @@ GCP_STREAM_TIMEOUT = ( 160 # GCP connections have a max lifetime of 180s, leave 20s of buffer as safety ) -SHUTDOWN_TIMEOUT = 300 # time to wait for background threads to exit after shutdown is initiated, in seconds, in the event the server never hangs up class BatchRequest(ABC, Generic[TBatchInput, TBatchReturn]): @@ -911,10 +910,10 @@ def recv_alive(self) -> bool: return self.recv.is_alive() return True # not started yet so considered alive - def join(self, timeout: Optional[float] = None) -> None: + def join(self) -> None: """Join the background threads.""" - self.loop.join(timeout=timeout) - self.recv.join(timeout=timeout) + self.loop.join() + self.recv.join() class _ClusterBatch: diff --git a/weaviate/collections/batch/sync.py b/weaviate/collections/batch/sync.py index 42aad3a73..204e6a0ab 100644 --- a/weaviate/collections/batch/sync.py +++ b/weaviate/collections/batch/sync.py @@ -10,7 +10,6 @@ from weaviate.collections.batch.base import ( GCP_STREAM_TIMEOUT, - SHUTDOWN_TIMEOUT, ObjectsBatchRequest, ReferencesBatchRequest, _BatchDataWrapper, @@ -158,14 +157,16 @@ def _start(self) -> None: ) def _wait(self) -> None: - deadline = time.time() + SHUTDOWN_TIMEOUT + # this is how long an insert will take to timeout for, so we wait at most this time +5s for the batch to finish after shutdown is initiated, in case the server never hangs up + shutdown_timeout = self.__connection.timeout_config.insert + 5 + deadline = time.time() + shutdown_timeout while time.time() < deadline: if not self.__any_threads_alive(): break time.sleep(0.1) if self.__any_threads_alive(): logger.warning( - f"Background batch threads did not exit within {SHUTDOWN_TIMEOUT}s. " + f"Background batch threads did not exit within {shutdown_timeout}s. " f"Forcing shutdown. 
inflight_objs={len(self.__inflight_objs)}, " f"inflight_refs={len(self.__inflight_refs)}, " f"loop_alive={self.__bg_threads.loop_alive()}, " @@ -175,7 +176,7 @@ def _wait(self) -> None: self.__is_stopped.set() self.__cancel_active_stream() # force __recv to exit by cancelling the stream - self.__bg_threads.join(timeout=5) + self.__bg_threads.join() if self.__any_threads_alive(): raise WeaviateBatchStreamError( "Background batch threads did not terminate after forced shutdown. " @@ -193,10 +194,29 @@ def _wait(self) -> None: self.__results_for_wrapper.imported_shards ) + if self.__bg_exception is not None: + if "StatusCode.CANCELLED(Locally cancelled by application!)" in str( + self.__bg_exception + ): + raise WeaviateBatchStreamError( + "The server did not hangup its side of the stream gracefully in time" + ) + raise self.__bg_exception + def _shutdown(self) -> None: # Shutdown the current batch and wait for all requests to be finished self.__is_stopped.set() + def __put(self, req: _BatchStreamRequest | None): + while True: + try: + self.__reqs.put(req, timeout=1) + return True + except Full: + if self.__bg_exception is not None or self.__shutdown_loop.is_set(): + return False + return self.__put(req) + def __loop(self) -> None: refresh_time: float = 0.01 while self.__bg_exception is None and not self.__shutdown_loop.is_set(): @@ -238,13 +258,8 @@ def __loop(self) -> None: if paused: logger.info("Server is back up, resuming batching loop...") paused = False - try: - self.__reqs.put(req, timeout=60) - except Full as e: - logger.warning( - "Batch queue is blocked for more than 60 seconds. 
Exiting the loop" - ) - self.__bg_exception = e + if not self.__put(req): + logger.info("Batch loop is shutting down, stopping putting requests...") return elif ( self.__is_stopped.is_set() @@ -252,7 +267,9 @@ def __loop(self) -> None: and not self.__is_shutting_down.is_set() and not self.__is_oom.is_set() ): - self.__reqs.put(None) + if not self.__put(None): + logger.info("Batch loop is shutting down, stopping putting shutdown signal...") + return time.sleep(refresh_time) def __generate_stream_requests( diff --git a/weaviate/connect/v4.py b/weaviate/connect/v4.py index 4d8265375..7faf2ba76 100644 --- a/weaviate/connect/v4.py +++ b/weaviate/connect/v4.py @@ -1024,7 +1024,7 @@ def grpc_batch_stream( request_iterator=requests, timeout=self.timeout_config.stream, metadata=self.grpc_headers(), - )() + ) def generator(): try: From a47587f6fc24772013028baa34e21d3b005832ab Mon Sep 17 00:00:00 2001 From: Tommy Smith Date: Mon, 20 Apr 2026 12:09:16 +0100 Subject: [PATCH 12/17] Remove redundant unit tests --- test/collection/test_batch_lifecycle.py | 59 ------------------------- 1 file changed, 59 deletions(-) delete mode 100644 test/collection/test_batch_lifecycle.py diff --git a/test/collection/test_batch_lifecycle.py b/test/collection/test_batch_lifecycle.py deleted file mode 100644 index 862f642c9..000000000 --- a/test/collection/test_batch_lifecycle.py +++ /dev/null @@ -1,59 +0,0 @@ -import asyncio -import threading -import time - -from weaviate.collections.batch.async_ import _BgTasks -from weaviate.collections.batch.base import _BgThreads - - -def test_bg_threads_any_alive_and_join_timeout() -> None: - def _short() -> None: - time.sleep(0.05) - - def _long() -> None: - time.sleep(0.3) - - bg_threads = _BgThreads( - loop=threading.Thread(target=_long, daemon=True), - recv=threading.Thread(target=_short, daemon=True), - ) - bg_threads.start_recv() - bg_threads.start_loop() - - time.sleep(0.12) - assert bg_threads.recv_alive() is False - assert bg_threads.loop_alive() is 
True - assert bg_threads.is_alive() is False - assert bg_threads.any_alive() is True - - start = time.time() - bg_threads.join(timeout=0.01) - assert time.time() - start < 0.2 - assert bg_threads.any_alive() is True - - bg_threads.join(timeout=1) - assert bg_threads.any_alive() is False - - -def test_bg_tasks_any_alive() -> None: - async def _short() -> None: - await asyncio.sleep(0.05) - - async def _long() -> None: - await asyncio.sleep(0.3) - - async def _run() -> None: - recv = asyncio.create_task(_short()) - loop = asyncio.create_task(_long()) - bg_tasks = _BgTasks(recv=recv, loop=loop) - - await asyncio.sleep(0.12) - assert bg_tasks.recv_alive() is False - assert bg_tasks.loop_alive() is True - assert bg_tasks.all_alive() is False - assert bg_tasks.any_alive() is True - - await bg_tasks.gather() - assert bg_tasks.any_alive() is False - - asyncio.run(_run()) From b5ddfb2323832f5ebc03bbb8a387976817efd735 Mon Sep 17 00:00:00 2001 From: Tommy Smith Date: Mon, 20 Apr 2026 12:27:36 +0100 Subject: [PATCH 13/17] Remove wrong if clause when req is Empty --- weaviate/collections/batch/async_.py | 6 +----- weaviate/collections/batch/sync.py | 6 +----- 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/weaviate/collections/batch/async_.py b/weaviate/collections/batch/async_.py index 7b6eb8e1d..3630b6097 100644 --- a/weaviate/collections/batch/async_.py +++ b/weaviate/collections/batch/async_.py @@ -380,11 +380,7 @@ async def __send( yield req.proto continue except asyncio.TimeoutError: - if self.__shutdown_loop.is_set() or self.__is_stopped.is_set(): - logger.info("Batch shutdown requested, stopping and closing the stream") - yield batch_pb2.BatchStreamRequest(stop=batch_pb2.BatchStreamRequest.Stop()) - return - elif self.__is_shutting_down.is_set(): + if self.__is_shutting_down.is_set(): logger.info("Server shutting down, closing the client-side of the stream") return elif self.__is_oom.is_set(): diff --git a/weaviate/collections/batch/sync.py 
b/weaviate/collections/batch/sync.py index 204e6a0ab..2d00cdd49 100644 --- a/weaviate/collections/batch/sync.py +++ b/weaviate/collections/batch/sync.py @@ -355,11 +355,7 @@ def __send( yield req.proto continue except Empty: - if self.__shutdown_loop.is_set() or self.__is_stopped.is_set(): - logger.info("Batch shutdown requested, stopping and closing the stream") - yield batch_pb2.BatchStreamRequest(stop=batch_pb2.BatchStreamRequest.Stop()) - return - elif self.__is_shutting_down.is_set(): + if self.__is_shutting_down.is_set(): logger.info("Server shutting down, closing the client-side of the stream") return elif self.__is_oom.is_set(): From df5014e63e3c49d1d3ecdc02bdc5c4fd8863856b Mon Sep 17 00:00:00 2001 From: Tommy Smith Date: Mon, 20 Apr 2026 15:44:31 +0100 Subject: [PATCH 14/17] Remove redundant stream cancellation logic, align async with sync impl --- .github/workflows/main.yaml | 15 +++--- mock_tests/test_connect.py | 25 ---------- weaviate/collections/batch/async_.py | 63 ++++++++--------------- weaviate/collections/batch/base.py | 6 +-- weaviate/collections/batch/sync.py | 75 +++++----------------------- weaviate/connect/v4.py | 35 ++++++------- 6 files changed, 57 insertions(+), 162 deletions(-) delete mode 100644 mock_tests/test_connect.py diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index 62203108c..8535e67b4 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -22,14 +22,13 @@ env: WEAVIATE_128: 1.28.16 WEAVIATE_129: 1.29.11 WEAVIATE_130: 1.30.22 - WEAVIATE_131: 1.31.20 - WEAVIATE_132: 1.32.23 - WEAVIATE_133: 1.33.10 - WEAVIATE_134: 1.34.5 - WEAVIATE_135: 1.35.16-efdedfa - WEAVIATE_136: 1.36.9-d905e6c - WEAVIATE_137: 1.37.0-rc.0-b313954.amd64 - + WEAVIATE_131: 1.31.22 + WEAVIATE_132: 1.32.27 + WEAVIATE_133: 1.33.18 + WEAVIATE_134: 1.34.19 + WEAVIATE_135: 1.35.15 + WEAVIATE_136: 1.36.6-8edcf08.amd64 + WEAVIATE_137: 1.37.0-dev-29d5c87.amd64 jobs: lint-and-format: diff --git a/mock_tests/test_connect.py 
b/mock_tests/test_connect.py deleted file mode 100644 index bb6625e53..000000000 --- a/mock_tests/test_connect.py +++ /dev/null @@ -1,25 +0,0 @@ -import time -import pytest -import weaviate -from weaviate.proto.v1 import batch_pb2 - - -def test_bidi_stream_cancel_sync(stream_cancel: weaviate.collections.Collection): - def gen(): - time.sleep(10) - yield batch_pb2.BatchStreamRequest() - - out, call = stream_cancel._connection.grpc_batch_stream(gen()) - assert call.is_active() - call.cancel() - assert not call.is_active() - with pytest.raises(weaviate.exceptions.WeaviateBatchStreamError) as e: - next(out) - assert "StatusCode.CANCELLED(Locally cancelled by application!)" in e.value.message - - -def test_batch_stream_hanging_server(stream_cancel: weaviate.collections.Collection): - with pytest.raises(weaviate.exceptions.WeaviateBatchStreamError) as e: - with stream_cancel.batch.stream() as batch: - batch.add_object() - assert "The server did not hangup its side of the stream gracefully in time" in e.value.message diff --git a/weaviate/collections/batch/async_.py b/weaviate/collections/batch/async_.py index 3630b6097..4a30f349b 100644 --- a/weaviate/collections/batch/async_.py +++ b/weaviate/collections/batch/async_.py @@ -115,7 +115,6 @@ def __init__( self.__oom_wait_time = 300 self.__shutdown_loop = asyncio.Event() - self.__sent_sentinel = asyncio.Event() self.__objs_cache_lock = asyncio.Lock() self.__objs_cache: dict[str, BatchObject] = {} @@ -195,34 +194,12 @@ async def _wait(self) -> None: assert self.__bg_tasks is not None # this is how long an insert will take to timeout for, so we wait at most this time +5s for the batch to finish after shutdown is initiated, in case the server never hangs up shutdown_timeout = self.__connection.timeout_config.insert + 5 - deadline = time.time() + shutdown_timeout - while time.time() < deadline: - if not self.__bg_tasks.any_alive(): - break - await asyncio.sleep(0.1) - if self.__bg_tasks.any_alive(): - logger.warning( - 
f"Background batch tasks did not exit within {shutdown_timeout}s. " - f"Forcing shutdown. inflight_objs={len(self.__inflight_objs)}, " - f"inflight_refs={len(self.__inflight_refs)}, " - f"loop_alive={self.__bg_tasks.loop_alive()}, " - f"recv_alive={self.__bg_tasks.recv_alive()}" - ) - self.__shutdown_loop.set() # force __loop to exit - self.__bg_tasks.recv.cancel() - self.__bg_tasks.loop.cancel() try: - await asyncio.wait_for(self.__bg_tasks.gather(), timeout=None) + await asyncio.wait_for(self.__bg_tasks.gather(), timeout=shutdown_timeout) except asyncio.TimeoutError as e: raise WeaviateBatchStreamError( "Background batch tasks did not terminate after forced shutdown." ) from e - if self.__bg_tasks.any_alive(): - raise WeaviateBatchStreamError( - "Background batch tasks did not terminate after forced shutdown. " - f"loop_alive={self.__bg_tasks.loop_alive()}, " - f"recv_alive={self.__bg_tasks.recv_alive()}" - ) # copy the results to the public results self.__results_for_wrapper_backup.results = self.__results_for_wrapper.results @@ -237,6 +214,15 @@ async def _wait(self) -> None: async def _shutdown(self) -> None: self.__is_stopped.set() + async def __put(self, req: _BatchStreamRequest | None): + try: + await asyncio.wait_for(self.__reqs.put(req), timeout=1) + return True + except asyncio.TimeoutError: + if self.__bg_exception is not None or self.__shutdown_loop.is_set(): + return False + return await self.__put(req) + async def __loop(self) -> None: refresh_time: float = 0.01 while self.__bg_exception is None and not self.__shutdown_loop.is_set(): @@ -278,23 +264,18 @@ async def __loop(self) -> None: if paused: logger.info("Server is back up, resuming batching loop...") paused = False - try: - await asyncio.wait_for(self.__reqs.put(req), timeout=60) - except asyncio.TimeoutError as e: - logger.warning( - "Batch queue is blocked for more than 60 seconds. 
Exiting the loop" - ) - self.__bg_exception = e + if not self.__put(req): + logger.info("Batch loop is shutting down, stopping putting new requests...") return elif ( self.__is_stopped.is_set() - and not self.__sent_sentinel.is_set() and not self.__is_hungup.is_set() and not self.__is_shutting_down.is_set() and not self.__is_oom.is_set() ): - await self.__reqs.put(None) - self.__sent_sentinel.set() + await self.__put(None) + logger.info("Sent sentinel, stopping batch loop...") + return await asyncio.sleep(refresh_time) def __generate_stream_requests( @@ -347,10 +328,7 @@ def request_maker(): if len(request.data.objects.values) > 0 or len(request.data.references.values) > 0: yield _BatchStreamRequest(request, uuids, beacons) - async def __send( - self, - ) -> AsyncGenerator[batch_pb2.BatchStreamRequest, None]: - self.__sent_sentinel.clear() + async def __send(self) -> AsyncGenerator[batch_pb2.BatchStreamRequest, None]: yield batch_pb2.BatchStreamRequest( start=batch_pb2.BatchStreamRequest.Start( consistency_level=self.__batch_grpc._consistency_level, @@ -393,14 +371,13 @@ async def __send( logger.info("Batch send thread exiting due to exception...") async def __recv(self) -> None: - stream = self.__batch_grpc.astream( - connection=self.__connection, - requests=self.__send(), - ) self.__is_renewing_stream.clear() self.__is_shutting_down.clear() self.__is_hungup.clear() - async for message in stream: + async for message in self.__batch_grpc.astream( + connection=self.__connection, + requests=self.__send(), + ): if message.HasField("started"): logger.info("Batch stream started successfully") diff --git a/weaviate/collections/batch/base.py b/weaviate/collections/batch/base.py index fd9aae261..73de73531 100644 --- a/weaviate/collections/batch/base.py +++ b/weaviate/collections/batch/base.py @@ -910,10 +910,10 @@ def recv_alive(self) -> bool: return self.recv.is_alive() return True # not started yet so considered alive - def join(self) -> None: + def join(self, timeout: 
float | None = None) -> None: """Join the background threads.""" - self.loop.join() - self.recv.join() + self.loop.join(timeout=timeout) + self.recv.join(timeout=timeout) class _ClusterBatch: diff --git a/weaviate/collections/batch/sync.py b/weaviate/collections/batch/sync.py index 2d00cdd49..910423b43 100644 --- a/weaviate/collections/batch/sync.py +++ b/weaviate/collections/batch/sync.py @@ -66,7 +66,6 @@ def __init__( self.__connection = connection self.__is_gcp_on_wcd = connection._connection_params.is_gcp_on_wcd() - self.__stream_start: Optional[float] = None self.__is_renewing_stream = threading.Event() self.__consistency_level: ConsistencyLevel = consistency_level or ConsistencyLevel.QUORUM self.__batch_size = 100 @@ -124,26 +123,6 @@ def number_errors(self) -> int: def __all_threads_alive(self) -> bool: return self.__bg_threads.is_alive() - def __any_threads_alive(self) -> bool: - return self.__bg_threads.any_alive() - - def __set_active_stream(self, call: Call) -> None: - with self.__stream_lock: - self.__active_stream = call - - def __clear_active_stream(self) -> None: - with self.__stream_lock: - self.__active_stream = None - - def __cancel_active_stream(self) -> bool: - with self.__stream_lock: - stream = self.__active_stream - - if stream is None: - return False - - return stream.cancel() - def _start(self) -> None: self.__start_bg_threads() logger.info("Provisioned stream to the server for batch processing") @@ -159,30 +138,12 @@ def _start(self) -> None: def _wait(self) -> None: # this is how long an insert will take to timeout for, so we wait at most this time +5s for the batch to finish after shutdown is initiated, in case the server never hangs up shutdown_timeout = self.__connection.timeout_config.insert + 5 - deadline = time.time() + shutdown_timeout - while time.time() < deadline: - if not self.__any_threads_alive(): - break - time.sleep(0.1) - if self.__any_threads_alive(): - logger.warning( - f"Background batch threads did not exit within 
{shutdown_timeout}s. " - f"Forcing shutdown. inflight_objs={len(self.__inflight_objs)}, " - f"inflight_refs={len(self.__inflight_refs)}, " - f"loop_alive={self.__bg_threads.loop_alive()}, " - f"recv_alive={self.__bg_threads.recv_alive()}" - ) - self.__shutdown_loop.set() # force __loop to exit - self.__is_stopped.set() - self.__cancel_active_stream() # force __recv to exit by cancelling the stream - - self.__bg_threads.join() - if self.__any_threads_alive(): + try: + self.__bg_threads.join(shutdown_timeout) + except TimeoutError as e: raise WeaviateBatchStreamError( - "Background batch threads did not terminate after forced shutdown. " - f"loop_alive={self.__bg_threads.loop_alive()}, " - f"recv_alive={self.__bg_threads.recv_alive()}" - ) + "Background batch threads did not terminate after forced shutdown." + ) from e # copy the results to the public results self.__results_for_wrapper_backup.results = self.__results_for_wrapper.results @@ -194,15 +155,6 @@ def _wait(self) -> None: self.__results_for_wrapper.imported_shards ) - if self.__bg_exception is not None: - if "StatusCode.CANCELLED(Locally cancelled by application!)" in str( - self.__bg_exception - ): - raise WeaviateBatchStreamError( - "The server did not hangup its side of the stream gracefully in time" - ) - raise self.__bg_exception - def _shutdown(self) -> None: # Shutdown the current batch and wait for all requests to be finished self.__is_stopped.set() @@ -267,9 +219,9 @@ def __loop(self) -> None: and not self.__is_shutting_down.is_set() and not self.__is_oom.is_set() ): - if not self.__put(None): - logger.info("Batch loop is shutting down, stopping putting shutdown signal...") - return + self.__put(None) + logger.info("Sent sentinel, stopping batch loop...") + return time.sleep(refresh_time) def __generate_stream_requests( @@ -370,16 +322,13 @@ def __send( logger.info("Batch send thread exiting due to exception...") def __recv(self) -> None: - gen, call = self.__batch_grpc.stream( - 
connection=self.__connection, - requests=self.__send(), - ) - self.__set_active_stream(call) - self.__is_renewing_stream.clear() self.__is_shutting_down.clear() self.__is_hungup.clear() - for message in gen: + for message in self.__batch_grpc.stream( + connection=self.__connection, + requests=self.__send(), + ): if message.HasField("started"): logger.info("Batch stream started successfully") diff --git a/weaviate/connect/v4.py b/weaviate/connect/v4.py index 7faf2ba76..adac4be38 100644 --- a/weaviate/connect/v4.py +++ b/weaviate/connect/v4.py @@ -1018,27 +1018,22 @@ def grpc_batch_objects( def grpc_batch_stream( self, requests: Generator[batch_pb2.BatchStreamRequest, None, None], - ) -> tuple[Generator[batch_pb2.BatchStreamReply, None, None], Call]: + ) -> Generator[batch_pb2.BatchStreamReply, None, None]: assert self.grpc_stub is not None - call = self.grpc_stub.BatchStream( - request_iterator=requests, - timeout=self.timeout_config.stream, - metadata=self.grpc_headers(), - ) - - def generator(): - try: - for msg in call: - yield msg - except RpcError as e: - error = cast(Call, e) - if error.code() == StatusCode.PERMISSION_DENIED: - raise InsufficientPermissionsError(error) - if error.code() == StatusCode.ABORTED: - raise _BatchStreamShutdownError() - raise WeaviateBatchStreamError(f"{error.code()}({error.details()})") - - return generator(), call + try: + for msg in self.grpc_stub.BatchStream( + request_iterator=requests, + timeout=self.timeout_config.stream, + metadata=self.grpc_headers(), + ): + yield msg + except RpcError as e: + error = cast(Call, e) + if error.code() == StatusCode.PERMISSION_DENIED: + raise InsufficientPermissionsError(error) + if error.code() == StatusCode.ABORTED: + raise _BatchStreamShutdownError() + raise WeaviateBatchStreamError(f"{error.code()}({error.details()})") def grpc_batch_delete( self, request: batch_delete_pb2.BatchDeleteRequest From a5ecad7e3eec0b8402f625e15d49dff982da38b2 Mon Sep 17 00:00:00 2001 From: Tommy Smith Date: Mon, 
20 Apr 2026 16:22:12 +0100 Subject: [PATCH 15/17] Fix missed await --- weaviate/collections/batch/async_.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/weaviate/collections/batch/async_.py b/weaviate/collections/batch/async_.py index 4a30f349b..e4417daa1 100644 --- a/weaviate/collections/batch/async_.py +++ b/weaviate/collections/batch/async_.py @@ -264,7 +264,7 @@ async def __loop(self) -> None: if paused: logger.info("Server is back up, resuming batching loop...") paused = False - if not self.__put(req): + if not await self.__put(req): logger.info("Batch loop is shutting down, stopping putting new requests...") return elif ( From ec476babeab387990c38e29486079a29350af6f8 Mon Sep 17 00:00:00 2001 From: Tommy Smith Date: Mon, 20 Apr 2026 16:25:35 +0100 Subject: [PATCH 16/17] Remove redundant code --- weaviate/collections/batch/async_.py | 9 --------- weaviate/collections/batch/base.py | 4 ---- weaviate/collections/batch/sync.py | 3 --- 3 files changed, 16 deletions(-) diff --git a/weaviate/collections/batch/async_.py b/weaviate/collections/batch/async_.py index e4417daa1..ad073e513 100644 --- a/weaviate/collections/batch/async_.py +++ b/weaviate/collections/batch/async_.py @@ -60,15 +60,6 @@ def __init__(self, recv: asyncio.Task[None], loop: asyncio.Task[None]) -> None: def all_alive(self) -> bool: return all([not self.recv.done(), not self.loop.done()]) - def any_alive(self) -> bool: - return not self.recv.done() or not self.loop.done() - - def recv_alive(self) -> bool: - return not self.recv.done() - - def loop_alive(self) -> bool: - return not self.loop.done() - async def gather(self) -> None: tasks = [self.recv, self.loop] await asyncio.gather(*tasks, return_exceptions=True) diff --git a/weaviate/collections/batch/base.py b/weaviate/collections/batch/base.py index 73de73531..a700df53f 100644 --- a/weaviate/collections/batch/base.py +++ b/weaviate/collections/batch/base.py @@ -894,10 +894,6 @@ def is_alive(self) -> bool: """Check if the 
background threads are still alive.""" return self.loop_alive() and self.recv_alive() - def any_alive(self) -> bool: - """Check if any background thread is still alive.""" - return self.loop_alive() or self.recv_alive() - def loop_alive(self) -> bool: """Check if the loop background thread is still alive.""" if self.__started_loop: diff --git a/weaviate/collections/batch/sync.py b/weaviate/collections/batch/sync.py index 910423b43..6627f8911 100644 --- a/weaviate/collections/batch/sync.py +++ b/weaviate/collections/batch/sync.py @@ -5,7 +5,6 @@ from queue import Empty, Full, Queue from typing import Generator, List, Optional, Set, Union -from grpc import Call from pydantic import ValidationError from weaviate.collections.batch.base import ( @@ -110,8 +109,6 @@ def __init__( # maxsize=1 so that __loop does not run faster than generator for __recv # thereby using too much buffer in case of server-side shutdown self.__reqs: Queue[Optional[_BatchStreamRequest]] = Queue(maxsize=1) - self.__stream_lock = threading.Lock() - self.__active_stream: Optional[Call] = None @property def number_errors(self) -> int: From eddac53aa74f8192a551dc5ec736b9eff9131e9e Mon Sep 17 00:00:00 2001 From: Tommy Smith Date: Mon, 20 Apr 2026 16:27:29 +0100 Subject: [PATCH 17/17] Achieve parity in async/sync impls --- weaviate/collections/batch/async_.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/weaviate/collections/batch/async_.py b/weaviate/collections/batch/async_.py index ad073e513..8b997586c 100644 --- a/weaviate/collections/batch/async_.py +++ b/weaviate/collections/batch/async_.py @@ -60,9 +60,9 @@ def __init__(self, recv: asyncio.Task[None], loop: asyncio.Task[None]) -> None: def all_alive(self) -> bool: return all([not self.recv.done(), not self.loop.done()]) - async def gather(self) -> None: + async def gather(self, timeout: float | None = None) -> None: tasks = [self.recv, self.loop] - await asyncio.gather(*tasks, return_exceptions=True) + await 
asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=timeout) class _BatchBaseAsync: @@ -186,7 +186,7 @@ async def _wait(self) -> None: # this is how long an insert will take to timeout for, so we wait at most this time +5s for the batch to finish after shutdown is initiated, in case the server never hangs up shutdown_timeout = self.__connection.timeout_config.insert + 5 try: - await asyncio.wait_for(self.__bg_tasks.gather(), timeout=shutdown_timeout) + await self.__bg_tasks.gather(timeout=shutdown_timeout) except asyncio.TimeoutError as e: raise WeaviateBatchStreamError( "Background batch tasks did not terminate after forced shutdown."