diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index 2f68432ba..8535e67b4 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -22,19 +22,19 @@ env: WEAVIATE_128: 1.28.16 WEAVIATE_129: 1.29.11 WEAVIATE_130: 1.30.22 - WEAVIATE_131: 1.31.20 - WEAVIATE_132: 1.32.23 - WEAVIATE_133: 1.33.10 - WEAVIATE_134: 1.34.5 - WEAVIATE_135: 1.35.16-efdedfa - WEAVIATE_136: 1.36.9-d905e6c - WEAVIATE_137: 1.37.0-rc.0-b313954.amd64 - + WEAVIATE_131: 1.31.22 + WEAVIATE_132: 1.32.27 + WEAVIATE_133: 1.33.18 + WEAVIATE_134: 1.34.19 + WEAVIATE_135: 1.35.15 + WEAVIATE_136: 1.36.6-8edcf08.amd64 + WEAVIATE_137: 1.37.0-dev-29d5c87.amd64 jobs: lint-and-format: name: Run Linter and Formatter runs-on: ubuntu-latest + timeout-minutes: 5 steps: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 @@ -60,6 +60,7 @@ jobs: type-checking: name: Run Type Checking runs-on: ubuntu-latest + timeout-minutes: 5 strategy: fail-fast: false matrix: @@ -80,6 +81,7 @@ jobs: unit-tests: name: Run Unit Tests runs-on: ubuntu-latest + timeout-minutes: 5 strategy: fail-fast: false matrix: @@ -104,6 +106,7 @@ jobs: proto-test: name: Run importing protos test runs-on: ubuntu-latest + timeout-minutes: 5 strategy: fail-fast: false matrix: @@ -124,6 +127,7 @@ jobs: integration-tests-embedded: name: Run Integration Tests Embedded runs-on: ubuntu-latest + timeout-minutes: 30 strategy: matrix: version: ["3.10", "3.11", "3.12", "3.13", "3.14"] @@ -153,6 +157,7 @@ jobs: integration-tests: name: Run Integration Tests runs-on: ubuntu-latest + timeout-minutes: 30 strategy: fail-fast: false matrix: @@ -208,6 +213,7 @@ jobs: journey-tests: name: Run Journey Tests runs-on: ubuntu-latest + timeout-minutes: 30 strategy: fail-fast: false matrix: @@ -243,6 +249,7 @@ jobs: Codecov: needs: [Unit-Tests, Integration-Tests] runs-on: ubuntu-latest + timeout-minutes: 5 if: github.ref_name != 'main' && !github.event.pull_request.head.repo.fork steps: - uses: actions/checkout@v4 @@ -273,6 +280,7 @@ jobs: build-package: name: Build package runs-on: ubuntu-latest + timeout-minutes: 10 steps: - name: Checkout uses: actions/checkout@v4 @@ -297,6 +305,7 @@ jobs: test-package: needs: [build-package] runs-on: ubuntu-latest + timeout-minutes: 30 strategy: fail-fast: false matrix: @@ -341,6 +350,7 @@ jobs: name: Build and publish Python 🐍 distributions 📦 to PyPI and TestPyPI needs: [integration-tests, unit-tests, lint-and-format, type-checking, test-package, proto-test] runs-on: ubuntu-latest + timeout-minutes: 20 steps: - name: Checkout uses: actions/checkout@v4 @@ -366,6 +376,7 @@ jobs: name: Create a GitHub Release on new tags if: startsWith(github.ref, 'refs/tags') runs-on: ubuntu-latest + timeout-minutes: 5 needs: [build-and-publish] steps: - name: Download build artifact to append to release diff --git a/conftest.py b/conftest.py new file mode 100644 index 000000000..ced4f6ed2 --- /dev/null +++ b/conftest.py @@ -0,0 +1,59 @@ +import faulthandler +import os +import threading + +import pytest + +DEFAULT_TIMEOUT = 300 # 5 minutes + +_timeout_timer: threading.Timer | None = None + + +def _get_timeout(item: pytest.Item) -> float: + marker = item.get_closest_marker("timeout") + if marker and marker.args: + return float(marker.args[0]) + return float(DEFAULT_TIMEOUT) + + +def pytest_runtest_setup(item: pytest.Item) -> None: + """Start a watchdog timer that dumps all thread stack traces on timeout. + + Unlike pytest-timeout, this does NOT raise KeyboardInterrupt (which crashes + xdist workers and corrupts asyncio event loops). Instead it: + 1. Writes the test name + all thread tracebacks directly to fd 2 (stderr). + With --capture=sys in pytest.ini, fd 2 is the real stderr (not captured), + so the output goes directly to the CI log even under xdist. + 2. Calls os._exit(1) to terminate the worker process. + + xdist will report 'node down: Not properly terminated' which is expected — + the diagnostic output will already be in the CI logs above that message. + """ + global _timeout_timer + timeout = _get_timeout(item) + if timeout <= 0: + return + + def _on_timeout() -> None: + banner = "=" * 70 + os.write(2, f"\n\n{banner}\n".encode()) + os.write(2, f"TIMEOUT: {item.nodeid} exceeded {timeout}s\n".encode()) + os.write(2, f"{banner}\n\n".encode()) + # faulthandler needs a file object — wrap a dup of fd 2 to avoid closing it + with os.fdopen(os.dup(2), "w") as f: + faulthandler.dump_traceback(file=f) + f.flush() + os.write(2, f"\n{banner}\n\n".encode()) + os._exit(1) + + _timeout_timer = threading.Timer(timeout, _on_timeout) + _timeout_timer.daemon = True + _timeout_timer.start() + + +def pytest_runtest_teardown(item: pytest.Item, nextitem: pytest.Item | None) -> None: + """Cancel the watchdog timer after each test completes.""" + global _timeout_timer + if _timeout_timer is not None: + _timeout_timer.cancel() + _timeout_timer = None diff --git a/integration/test_batch_v4.py b/integration/test_batch_v4.py index ad4c70bb2..f4ce7669e 100644 --- a/integration/test_batch_v4.py +++ b/integration/test_batch_v4.py @@ -410,6 +410,7 @@ def test_add_ref_batch_with_tenant(client_factory: ClientFactory) -> None: assert ret_obj.references["test"].objects[0].uuid == obj[0] +@pytest.mark.timeout(600) @pytest.mark.parametrize( "batching_method", [ @@ -717,6 +718,7 @@ def test_non_existant_collection(client_factory: ClientFactory) -> None: # not, so we do not check for errors here +@pytest.mark.timeout(600) def test_number_of_stored_results_in_batch(client_factory: ClientFactory) -> None: client, name = client_factory() with client.batch.dynamic() as batch: @@ -816,6 +818,7 @@ def test_references_with_to_uuids(client_factory: ClientFactory) -> None: @pytest.mark.asyncio +@pytest.mark.timeout(600) async def test_add_one_hundred_thousand_objects_async_client( async_client_factory: AsyncClientFactory, ) -> None: @@ -846,6 +849,7 @@ async def test_add_one_hundred_thousand_objects_async_client( await client.collections.delete(name) +@pytest.mark.timeout(600) def test_add_one_hundred_thousand_objects_sync_client( client_factory: ClientFactory, ) -> None: diff --git a/integration/test_collection_batch.py b/integration/test_collection_batch.py index 40683a26f..e670e4883 100644 --- a/integration/test_collection_batch.py +++ b/integration/test_collection_batch.py @@ -271,6 +271,7 @@ def test_non_existant_collection(collection_factory_get: CollectionFactoryGet) - @pytest.mark.asyncio +@pytest.mark.timeout(600) async def test_batch_one_hundred_thousand_objects_async_collection( batch_collection_async: BatchCollectionAsync, ) -> None: @@ -298,6 +299,7 @@ async def test_batch_one_hundred_thousand_objects_async_collection( @pytest.mark.asyncio +@pytest.mark.timeout(600) async def test_ingest_one_hundred_thousand_data_objects_async( batch_collection_async: BatchCollectionAsync, ) -> None: @@ -319,6 +321,7 @@ async def test_ingest_one_hundred_thousand_data_objects_async( assert len(results.errors) == 0, [obj.message for obj in results.errors.values()] +@pytest.mark.timeout(600) def test_ingest_one_hundred_thousand_data_objects( batch_collection: BatchCollection, ) -> None: diff --git a/mock_tests/conftest.py b/mock_tests/conftest.py index de22b2e51..9c0bf19ec 100644 --- a/mock_tests/conftest.py +++ b/mock_tests/conftest.py @@ -1,7 +1,7 @@ import json import time from concurrent import futures -from typing import Generator, Mapping +from typing import AsyncGenerator, Generator, Mapping import grpc import pytest @@ -141,6 +141,16 @@ def weaviate_client( client.close() +@pytest.fixture(scope="function") +async def weaviate_client_async( + weaviate_mock: HTTPServer, start_grpc_server: grpc.Server +) -> AsyncGenerator[weaviate.WeaviateAsyncClient, None]: + client = weaviate.use_async_with_local(port=MOCK_PORT, host=MOCK_IP, grpc_port=MOCK_PORT_GRPC) + await client.connect() + yield client + await client.close() + + @pytest.fixture(scope="function") def weaviate_timeouts_client( weaviate_timeouts_mock: HTTPServer, start_grpc_server: grpc.Server @@ -368,3 +378,39 @@ def forbidden( service = MockForbiddenWeaviateService() weaviate_pb2_grpc.add_WeaviateServicer_to_server(service, start_grpc_server) return weaviate_client.collections.use("ForbiddenCollection") + + +class MockWeaviateService(weaviate_pb2_grpc.WeaviateServicer): + def BatchStream( + self, + request_iterator: Generator[batch_pb2.BatchStreamRequest, None, None], + context: grpc.ServicerContext, + ) -> Generator[batch_pb2.BatchStreamReply, None, None]: + while True: + if context.is_active(): + time.sleep(0.1) + else: + raise grpc.RpcError(grpc.StatusCode.DEADLINE_EXCEEDED, "Deadline exceeded") + + +@pytest.fixture(scope="function") +def stream_cancel( + weaviate_client: weaviate.WeaviateClient, + weaviate_mock: HTTPServer, + start_grpc_server: grpc.Server, +) -> Generator[weaviate.collections.Collection, None, None]: + name = "StreamCancelCollection" + weaviate_mock.expect_request(f"/v1/schema/{name}").respond_with_response( + Response(status=404) + ) # skips __create_batch_reset vectorizer logic + weaviate_pb2_grpc.add_WeaviateServicer_to_server(MockWeaviateService(), start_grpc_server) + client = weaviate.connect_to_local( + port=MOCK_PORT, + host=MOCK_IP, + grpc_port=MOCK_PORT_GRPC, + additional_config=weaviate.classes.init.AdditionalConfig( + timeout=weaviate.classes.init.Timeout(insert=1) + ), + ) + yield client.collections.use(name) + client.close() diff --git a/pytest.ini b/pytest.ini index 36321aa1d..f87230abf 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,5 +1,6 @@ [pytest] -addopts = -m 'not profiling' --benchmark-skip -l +addopts = -m 'not profiling' --benchmark-skip -l --capture=sys --max-worker-restart=3 markers = profiling: marks tests that can be profiled + timeout: marks tests with a custom timeout in seconds (default: 300) asyncio_default_fixture_loop_scope = function \ No newline at end of file diff --git a/test/test_timeout.py b/test/test_timeout.py new file mode 100644 index 000000000..fc49f51e0 --- /dev/null +++ b/test/test_timeout.py @@ -0,0 +1,90 @@ +"""Tests for the custom per-test timeout mechanism in conftest.py. + +Uses subprocess because the timeout mechanism calls os._exit(1). +""" + +import subprocess +import sys +import textwrap +from pathlib import Path + +PROJECT_ROOT = Path(__file__).parent.parent + + +def _run_pytest(tmp_path: Path, test_code: str, *extra_args: str) -> subprocess.CompletedProcess: + """Run pytest in a subprocess with a copy of our timeout conftest.""" + (tmp_path / "conftest.py").write_text((PROJECT_ROOT / "conftest.py").read_text()) + (tmp_path / "pytest.ini").write_text( + "[pytest]\naddopts = --capture=sys --max-worker-restart=0\nmarkers =\n timeout: custom timeout\n" + ) + (tmp_path / "test_it.py").write_text(textwrap.dedent(test_code)) + return subprocess.run( + [ + sys.executable, + "-m", + "pytest", + "-v", + "-n", + "auto", + "--dist", + "loadgroup", + "test_it.py", + *extra_args, + ], + capture_output=True, + text=True, + timeout=60, + cwd=str(tmp_path), + ) + + +def test_timeout_prints_test_name_and_stacktrace(tmp_path: Path) -> None: + result = _run_pytest( + tmp_path, + """\ + import time + import pytest + + @pytest.mark.timeout(2) + def test_hangs(): + time.sleep(999) + """, + ) + assert result.returncode != 0 + assert "TIMEOUT: test_it.py::test_hangs exceeded 2.0s" in result.stderr + assert "test_hangs" in result.stderr + + +def test_fast_test_not_killed(tmp_path: Path) -> None: + result = _run_pytest( + tmp_path, + """\ + import pytest + + @pytest.mark.timeout(10) + def test_fast(): + assert True + """, + ) + assert result.returncode == 0 + assert "TIMEOUT" not in result.stderr + + +def test_timeout_with_passing_and_hanging_test(tmp_path: Path) -> None: + result = _run_pytest( + tmp_path, + """\ + import time + import pytest + + @pytest.mark.timeout(2) + def test_hangs_in_worker(): + time.sleep(999) + + def test_passes(): + assert True + """, + ) + assert result.returncode != 0 + assert "TIMEOUT: test_it.py::test_hangs_in_worker exceeded 2.0s" in result.stderr + assert "test_hangs_in_worker" in result.stderr diff --git a/weaviate/collections/batch/async_.py b/weaviate/collections/batch/async_.py index b71f8be39..8b997586c 100644 --- a/weaviate/collections/batch/async_.py +++ b/weaviate/collections/batch/async_.py @@ -17,6 +17,7 @@ ObjectsBatchRequest, ReferencesBatchRequest, _BatchDataWrapper, + _BatchStreamRequest, _ClusterBatchAsync, ) from weaviate.collections.batch.grpc_batch import _BatchGRPC @@ -59,9 +60,9 @@ def __init__(self, recv: asyncio.Task[None], loop: asyncio.Task[None]) -> None: def all_alive(self) -> bool: return all([not self.recv.done(), not self.loop.done()]) - async def gather(self) -> None: + async def gather(self, timeout: float | None = None) -> None: tasks = [self.recv, self.loop] - await asyncio.gather(*tasks) + await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=timeout) class _BatchBaseAsync: @@ -105,7 +106,6 @@ def __init__( self.__oom_wait_time = 300 self.__shutdown_loop = asyncio.Event() - self.__sent_sentinel = asyncio.Event() self.__objs_cache_lock = asyncio.Lock() self.__objs_cache: dict[str, BatchObject] = {} @@ -117,9 +117,7 @@ def __init__( # maxsize=1 so that __send does not run faster than generator for __recv # thereby using too much buffer in case of server-side shutdown - self.__reqs: asyncio.Queue[Optional[batch_pb2.BatchStreamRequest]] = asyncio.Queue( - maxsize=1 - ) + self.__reqs: asyncio.Queue[Optional[_BatchStreamRequest]] = asyncio.Queue(maxsize=1) self.__bg_exception: Optional[Exception] = None self.__bg_tasks: Optional[_BgTasks] = None @@ -170,6 +168,8 @@ async def recv_wrapper() -> None: await self.__batch_objects.aprepend(list(self.__objs_cache.values())) async with self.__refs_cache_lock: await self.__batch_references.aprepend(list(self.__refs_cache.values())) + self.__inflight_objs.clear() + self.__inflight_refs.clear() # start a new stream with a newly reconnected channel return await recv_wrapper() @@ -181,9 +181,16 @@ async def recv_wrapper() -> None: loop=loop, ) - async def _wait(self): + async def _wait(self) -> None: assert self.__bg_tasks is not None - await self.__bg_tasks.gather() + # this is how long an insert will take to timeout for, so we wait at most this time +5s for the batch to finish after shutdown is initiated, in case the server never hangs up + shutdown_timeout = self.__connection.timeout_config.insert + 5 + try: + await self.__bg_tasks.gather(timeout=shutdown_timeout) + except asyncio.TimeoutError as e: + raise WeaviateBatchStreamError( + "Background batch tasks did not terminate after forced shutdown." + ) from e # copy the results to the public results self.__results_for_wrapper_backup.results = self.__results_for_wrapper.results @@ -198,6 +205,15 @@ async def _wait(self): async def _shutdown(self) -> None: self.__is_stopped.set() + async def __put(self, req: _BatchStreamRequest | None): + try: + await asyncio.wait_for(self.__reqs.put(req), timeout=1) + return True + except asyncio.TimeoutError: + if self.__bg_exception is not None or self.__shutdown_loop.is_set(): + return False + return await self.__put(req) + async def __loop(self) -> None: refresh_time: float = 0.01 while self.__bg_exception is None and not self.__shutdown_loop.is_set(): @@ -239,30 +255,25 @@ async def __loop(self) -> None: if paused: logger.info("Server is back up, resuming batching loop...") paused = False - try: - await asyncio.wait_for(self.__reqs.put(req), timeout=60) - except asyncio.TimeoutError as e: - logger.warning( - "Batch queue is blocked for more than 60 seconds. Exiting the loop" - ) - self.__bg_exception = e + if not await self.__put(req): + logger.info("Batch loop is shutting down, stopping putting new requests...") return elif ( self.__is_stopped.is_set() - and not self.__sent_sentinel.is_set() and not self.__is_hungup.is_set() and not self.__is_shutting_down.is_set() and not self.__is_oom.is_set() ): - await self.__reqs.put(None) - self.__sent_sentinel.set() + await self.__put(None) + logger.info("Sent sentinel, stopping batch loop...") + return await asyncio.sleep(refresh_time) def __generate_stream_requests( self, objects: List[BatchObject], references: List[BatchReference], - ) -> Generator[batch_pb2.BatchStreamRequest, None, None]: + ) -> Generator[_BatchStreamRequest, None, None]: per_object_overhead = 4 # extra overhead bytes per object in the request def request_maker(): @@ -271,8 +282,7 @@ def request_maker(): request = request_maker() total_size = request.ByteSize() - inflight_objs = set() - inflight_refs = set() + uuids, beacons = set(), set() for object_ in objects: obj = self.__batch_grpc.grpc_object(object_._to_internal()) obj_size = obj.ByteSize() + per_object_overhead @@ -283,40 +293,33 @@ def request_maker(): ) if total_size + obj_size >= self.__batch_grpc.grpc_max_msg_size: - self.__inflight_objs.update(inflight_objs) - self.__inflight_refs.update(inflight_refs) - yield request + yield _BatchStreamRequest(request, uuids, beacons) request = request_maker() total_size = request.ByteSize() + uuids, beacons = set(), set() request.data.objects.values.append(obj) total_size += obj_size - inflight_objs.add(obj.uuid) + uuids.add(obj.uuid) for reference in references: ref = self.__batch_grpc.grpc_reference(reference._to_internal()) ref_size = ref.ByteSize() + per_object_overhead if total_size + ref_size >= self.__batch_grpc.grpc_max_msg_size: - self.__inflight_objs.update(inflight_objs) - self.__inflight_refs.update(inflight_refs) - yield request + yield _BatchStreamRequest(request, uuids, beacons) request = request_maker() total_size = request.ByteSize() + uuids, beacons = set(), set() request.data.references.values.append(ref) total_size += ref_size - inflight_refs.add(reference._to_beacon()) + beacons.add(reference._to_beacon()) if len(request.data.objects.values) > 0 or len(request.data.references.values) > 0: - self.__inflight_objs.update(inflight_objs) - self.__inflight_refs.update(inflight_refs) - yield request + yield _BatchStreamRequest(request, uuids, beacons) - async def __send( - self, - ) -> AsyncGenerator[batch_pb2.BatchStreamRequest, None]: - self.__sent_sentinel.clear() + async def __send(self) -> AsyncGenerator[batch_pb2.BatchStreamRequest, None]: yield batch_pb2.BatchStreamRequest( start=batch_pb2.BatchStreamRequest.Start( consistency_level=self.__batch_grpc._consistency_level, @@ -341,7 +344,9 @@ async def __send( ) yield batch_pb2.BatchStreamRequest(stop=batch_pb2.BatchStreamRequest.Stop()) return - yield req + self.__inflight_objs.update(req.uuids) + self.__inflight_refs.update(req.beacons) + yield req.proto continue except asyncio.TimeoutError: if self.__is_shutting_down.is_set(): @@ -357,14 +362,13 @@ async def __send( logger.info("Batch send thread exiting due to exception...") async def __recv(self) -> None: - stream = self.__batch_grpc.astream( - connection=self.__connection, - requests=self.__send(), - ) self.__is_renewing_stream.clear() self.__is_shutting_down.clear() self.__is_hungup.clear() - async for message in stream: + async for message in self.__batch_grpc.astream( + connection=self.__connection, + requests=self.__send(), + ): if message.HasField("started"): logger.info("Batch stream started successfully") diff --git a/weaviate/collections/batch/base.py b/weaviate/collections/batch/base.py index 14a0c0768..a700df53f 100644 --- a/weaviate/collections/batch/base.py +++ b/weaviate/collections/batch/base.py @@ -43,6 +43,7 @@ WeaviateBatchValidationError, ) from weaviate.logger import logger +from weaviate.proto.v1 import batch_pb2 from weaviate.types import UUID, VECTORS from weaviate.util import _decode_json_response_dict from weaviate.warnings import _Warnings @@ -227,6 +228,13 @@ async def ahead(self) -> Optional[Obj]: return self.__head() +@dataclass +class _BatchStreamRequest: + proto: batch_pb2.BatchStreamRequest + uuids: set[str] + beacons: set[str] + + @dataclass class _BatchDataWrapper: results: BatchResult = field(default_factory=BatchResult) @@ -898,10 +906,10 @@ def recv_alive(self) -> bool: return self.recv.is_alive() return True # not started yet so considered alive - def join(self) -> None: + def join(self, timeout: float | None = None) -> None: """Join the background threads.""" - self.loop.join() - self.recv.join() + self.loop.join(timeout=timeout) + self.recv.join(timeout=timeout) class _ClusterBatch: diff --git a/weaviate/collections/batch/sync.py b/weaviate/collections/batch/sync.py index f219de563..6627f8911 100644 --- a/weaviate/collections/batch/sync.py +++ b/weaviate/collections/batch/sync.py @@ -13,6 +13,7 @@ ReferencesBatchRequest, _BatchDataWrapper, _BatchMode, + _BatchStreamRequest, _BgThreads, _ClusterBatch, ) @@ -64,7 +65,6 @@ def __init__( self.__connection = connection self.__is_gcp_on_wcd = connection._connection_params.is_gcp_on_wcd() - self.__stream_start: Optional[float] = None self.__is_renewing_stream = threading.Event() self.__consistency_level: ConsistencyLevel = consistency_level or ConsistencyLevel.QUORUM self.__batch_size = 100 @@ -96,7 +96,6 @@ def __init__( self.__oom_wait_time = 300 self.__shutdown_loop = threading.Event() - self.__sent_sentinel = threading.Event() self.__objs_cache_lock = threading.Lock() self.__refs_cache_lock = threading.Lock() @@ -109,7 +108,7 @@ def __init__( # maxsize=1 so that __loop does not run faster than generator for __recv # thereby using too much buffer in case of server-side shutdown - self.__reqs: Queue[Optional[batch_pb2.BatchStreamRequest]] = Queue(maxsize=1) + self.__reqs: Queue[Optional[_BatchStreamRequest]] = Queue(maxsize=1) @property def number_errors(self) -> int: @@ -134,7 +133,14 @@ def _start(self) -> None: ) def _wait(self) -> None: - self.__bg_threads.join() + # this is how long an insert will take to timeout for, so we wait at most this time +5s for the batch to finish after shutdown is initiated, in case the server never hangs up + shutdown_timeout = self.__connection.timeout_config.insert + 5 + try: + self.__bg_threads.join(shutdown_timeout) + except TimeoutError as e: + raise WeaviateBatchStreamError( + "Background batch threads did not terminate after forced shutdown." + ) from e # copy the results to the public results self.__results_for_wrapper_backup.results = self.__results_for_wrapper.results @@ -150,6 +156,16 @@ def _shutdown(self) -> None: # Shutdown the current batch and wait for all requests to be finished self.__is_stopped.set() + def __put(self, req: _BatchStreamRequest | None): + while True: + try: + self.__reqs.put(req, timeout=1) + return True + except Full: + if self.__bg_exception is not None or self.__shutdown_loop.is_set(): + return False + return self.__put(req) + def __loop(self) -> None: refresh_time: float = 0.01 while self.__bg_exception is None and not self.__shutdown_loop.is_set(): @@ -191,30 +207,25 @@ def __loop(self) -> None: if paused: logger.info("Server is back up, resuming batching loop...") paused = False - try: - self.__reqs.put(req, timeout=60) - except Full as e: - logger.warning( - "Batch queue is blocked for more than 60 seconds. Exiting the loop" - ) - self.__bg_exception = e + if not self.__put(req): + logger.info("Batch loop is shutting down, stopping putting requests...") return elif ( self.__is_stopped.is_set() - and not self.__sent_sentinel.is_set() and not self.__is_hungup.is_set() and not self.__is_shutting_down.is_set() and not self.__is_oom.is_set() ): - self.__reqs.put(None) - self.__sent_sentinel.set() + self.__put(None) + logger.info("Sent sentinel, stopping batch loop...") + return time.sleep(refresh_time) def __generate_stream_requests( self, objects: List[BatchObject], references: List[BatchReference], - ) -> Generator[batch_pb2.BatchStreamRequest, None, None]: + ) -> Generator[_BatchStreamRequest, None, None]: per_object_overhead = 4 # extra overhead bytes per object in the request def request_maker(): @@ -223,8 +234,7 @@ def request_maker(): request = request_maker() total_size = request.ByteSize() - inflight_objs = set() - inflight_refs = set() + uuids, beacons = set(), set() for object_ in objects: obj = self.__batch_grpc.grpc_object(object_._to_internal()) obj_size = obj.ByteSize() + per_object_overhead @@ -235,38 +245,35 @@ def request_maker(): ) if total_size + obj_size >= self.__batch_grpc.grpc_max_msg_size: - yield request + yield _BatchStreamRequest(request, uuids, beacons) request = request_maker() total_size = request.ByteSize() + uuids, beacons = set(), set() request.data.objects.values.append(obj) total_size += obj_size - inflight_objs.add(obj.uuid) + uuids.add(obj.uuid) for reference in references: ref = self.__batch_grpc.grpc_reference(reference._to_internal()) ref_size = ref.ByteSize() + per_object_overhead if total_size + ref_size >= self.__batch_grpc.grpc_max_msg_size: - yield request + yield _BatchStreamRequest(request, uuids, beacons) request = request_maker() total_size = request.ByteSize() + uuids, beacons = set(), set() request.data.references.values.append(ref) total_size += ref_size - inflight_refs.add(reference._to_beacon()) - - with self.__acks_lock: - self.__inflight_objs.update(inflight_objs) - self.__inflight_refs.update(inflight_refs) + beacons.add(reference._to_beacon()) if len(request.data.objects.values) > 0 or len(request.data.references.values) > 0: - yield request + yield _BatchStreamRequest(request, uuids, beacons) def __send( self, ) -> Generator[batch_pb2.BatchStreamRequest, None, None]: - self.__sent_sentinel.clear() yield batch_pb2.BatchStreamRequest( start=batch_pb2.BatchStreamRequest.Start( consistency_level=self.__batch_grpc._consistency_level, @@ -291,7 +298,10 @@ def __send( ) yield batch_pb2.BatchStreamRequest(stop=batch_pb2.BatchStreamRequest.Stop()) return - yield req + with self.__acks_lock: + self.__inflight_objs.update(req.uuids) + self.__inflight_refs.update(req.beacons) + yield req.proto continue except Empty: if self.__is_shutting_down.is_set(): @@ -303,18 +313,19 @@ def __send( elif self.__is_hungup.is_set(): logger.info("Detected hung up stream, closing the client-side of the stream") return - logger.info("Timed out getting request from queue, but not stopping, continuing...") + logger.debug( + "Timed out getting request from queue, but not stopping, continuing..." + ) logger.info("Batch send thread exiting due to exception...") def __recv(self) -> None: - stream = self.__batch_grpc.stream( - connection=self.__connection, - requests=self.__send(), - ) self.__is_renewing_stream.clear() self.__is_shutting_down.clear() self.__is_hungup.clear() - for message in stream: + for message in self.__batch_grpc.stream( + connection=self.__connection, + requests=self.__send(), + ): if message.HasField("started"): logger.info("Batch stream started successfully") @@ -513,6 +524,8 @@ def recv_wrapper() -> None: self.__batch_objects.prepend(list(self.__objs_cache.values())) with self.__refs_cache_lock: self.__batch_references.prepend(list(self.__refs_cache.values())) + self.__inflight_objs.clear() + self.__inflight_refs.clear() # start a new stream with a newly reconnected channel return recv_wrapper() @@ -614,6 +627,7 @@ def _add_reference( self.__refs_cache[batch_reference._to_beacon()] = batch_reference self.__refs_count += 1 while self.__is_blocked(): + logger.warning("Batch is blocked, waiting to add more references...") self.__check_bg_threads_alive() time.sleep(0.01) diff --git a/weaviate/connect/v4.py b/weaviate/connect/v4.py index 56ece8ca2..adac4be38 100644 --- a/weaviate/connect/v4.py +++ b/weaviate/connect/v4.py @@ -1019,8 +1019,8 @@ def grpc_batch_stream( self, requests: Generator[batch_pb2.BatchStreamRequest, None, None], ) -> Generator[batch_pb2.BatchStreamReply, None, None]: + assert self.grpc_stub is not None try: - assert self.grpc_stub is not None for msg in self.grpc_stub.BatchStream( request_iterator=requests, timeout=self.timeout_config.stream,