diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index 1c98f56ab..2d5b28841 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -480,12 +480,11 @@ async def _manage_channel( old_channel = super_channel.swap_channel(new_channel) self._invalidate_channel_stubs() # give old_channel a chance to complete existing rpcs - if CrossSync.is_async: - await old_channel.close(grace_period) - else: - if grace_period: - self._is_closed.wait(grace_period) # type: ignore - old_channel.close() # type: ignore + if grace_period: + await CrossSync.event_wait( + self._is_closed, grace_period, async_break_early=False + ) + await old_channel.close() # subtract the time spent waiting for the channel to be replaced next_refresh = random.uniform(refresh_interval_min, refresh_interval_max) next_sleep = max(next_refresh - (time.monotonic() - start_timestamp), 0) diff --git a/google/cloud/bigtable/data/_sync_autogen/client.py b/google/cloud/bigtable/data/_sync_autogen/client.py index a403643f5..8e99ef05c 100644 --- a/google/cloud/bigtable/data/_sync_autogen/client.py +++ b/google/cloud/bigtable/data/_sync_autogen/client.py @@ -364,7 +364,9 @@ def _manage_channel( old_channel = super_channel.swap_channel(new_channel) self._invalidate_channel_stubs() if grace_period: - self._is_closed.wait(grace_period) + CrossSync._Sync_Impl.event_wait( + self._is_closed, grace_period, async_break_early=False + ) old_channel.close() next_refresh = random.uniform(refresh_interval_min, refresh_interval_max) next_sleep = max(next_refresh - (time.monotonic() - start_timestamp), 0) diff --git a/tests/system/data/test_system_async.py b/tests/system/data/test_system_async.py index 39c454996..ac8a358a3 100644 --- a/tests/system/data/test_system_async.py +++ b/tests/system/data/test_system_async.py @@ -266,49 +266,49 @@ async def test_ping_and_warm(self, client, target): @CrossSync.pytest async def test_channel_refresh(self, table_id, instance_id, temp_rows): """ - change grpc channel to refresh after 1 second. Schedule a read_rows call after refresh, - to ensure new channel works + perform requests while swapping out the grpc channel. Requests should continue without error """ - await temp_rows.add_row(b"row_key_1") - await temp_rows.add_row(b"row_key_2") - client = self._make_client() - # start custom refresh task - try: + import time + + await temp_rows.add_row(b"test_row") + async with self._make_client() as client: + client._channel_refresh_task.cancel() + channel_wrapper = client.transport.grpc_channel + first_channel = channel_wrapper._channel + # swap channels frequently, with large grace windows client._channel_refresh_task = CrossSync.create_task( client._manage_channel, - refresh_interval_min=1, - refresh_interval_max=1, + refresh_interval_min=0.1, + refresh_interval_max=0.1, + grace_period=1, sync_executor=client._executor, ) - # let task run - await CrossSync.yield_to_event_loop() + + # hit channels with frequent requests + end_time = time.monotonic() + 3 async with client.get_table(instance_id, table_id) as table: - rows = await table.read_rows({}) - channel_wrapper = client.transport.grpc_channel - first_channel = channel_wrapper._channel - assert len(rows) == 2 - await CrossSync.sleep(2) - rows_after_refresh = await table.read_rows({}) - assert len(rows_after_refresh) == 2 - assert client.transport.grpc_channel is channel_wrapper - updated_channel = channel_wrapper._channel - assert updated_channel is not first_channel - # ensure interceptors are kept (gapic's logging interceptor, and metric interceptor) - if CrossSync.is_async: - unary_interceptors = updated_channel._unary_unary_interceptors - assert len(unary_interceptors) == 2 - assert GapicInterceptor in [type(i) for i in unary_interceptors] - assert client._metrics_interceptor in unary_interceptors - stream_interceptors = updated_channel._unary_stream_interceptors - assert len(stream_interceptors) == 1 - assert client._metrics_interceptor in stream_interceptors - else: - assert isinstance( - client.transport._logged_channel._interceptor, GapicInterceptor - ) - assert updated_channel._interceptor == client._metrics_interceptor - finally: - await client.close() + while time.monotonic() < end_time: + # we expect a CancelledError if a channel is closed before completion + rows = await table.read_rows({}) + assert len(rows) == 1 + await CrossSync.yield_to_event_loop() + # ensure channel was updated + updated_channel = channel_wrapper._channel + assert updated_channel is not first_channel + # ensure interceptors are kept (gapic's logging interceptor, and metric interceptor) + if CrossSync.is_async: + unary_interceptors = updated_channel._unary_unary_interceptors + assert len(unary_interceptors) == 2 + assert GapicInterceptor in [type(i) for i in unary_interceptors] + assert client._metrics_interceptor in unary_interceptors + stream_interceptors = updated_channel._unary_stream_interceptors + assert len(stream_interceptors) == 1 + assert client._metrics_interceptor in stream_interceptors + else: + assert isinstance( + client.transport._logged_channel._interceptor, GapicInterceptor + ) + assert updated_channel._interceptor == client._metrics_interceptor @CrossSync.pytest @pytest.mark.usefixtures("target") diff --git a/tests/system/data/test_system_autogen.py b/tests/system/data/test_system_autogen.py index 37c00f2ae..463235087 100644 --- a/tests/system/data/test_system_autogen.py +++ b/tests/system/data/test_system_autogen.py @@ -221,36 +221,33 @@ def test_ping_and_warm(self, client, target): reason="emulator mode doesn't refresh channel", ) def test_channel_refresh(self, table_id, instance_id, temp_rows): - """change grpc channel to refresh after 1 second. Schedule a read_rows call after refresh, - to ensure new channel works""" - temp_rows.add_row(b"row_key_1") - temp_rows.add_row(b"row_key_2") - client = self._make_client() - try: + """perform requests while swapping out the grpc channel. Requests should continue without error""" + import time + + temp_rows.add_row(b"test_row") + with self._make_client() as client: + client._channel_refresh_task.cancel() + channel_wrapper = client.transport.grpc_channel + first_channel = channel_wrapper._channel client._channel_refresh_task = CrossSync._Sync_Impl.create_task( client._manage_channel, - refresh_interval_min=1, - refresh_interval_max=1, + refresh_interval_min=0.1, + refresh_interval_max=0.1, + grace_period=1, sync_executor=client._executor, ) - CrossSync._Sync_Impl.yield_to_event_loop() + end_time = time.monotonic() + 3 with client.get_table(instance_id, table_id) as table: - rows = table.read_rows({}) - channel_wrapper = client.transport.grpc_channel - first_channel = channel_wrapper._channel - assert len(rows) == 2 - CrossSync._Sync_Impl.sleep(2) - rows_after_refresh = table.read_rows({}) - assert len(rows_after_refresh) == 2 - assert client.transport.grpc_channel is channel_wrapper - updated_channel = channel_wrapper._channel - assert updated_channel is not first_channel - assert isinstance( - client.transport._logged_channel._interceptor, GapicInterceptor - ) - assert updated_channel._interceptor == client._metrics_interceptor - finally: - client.close() + while time.monotonic() < end_time: + rows = table.read_rows({}) + assert len(rows) == 1 + CrossSync._Sync_Impl.yield_to_event_loop() + updated_channel = channel_wrapper._channel + assert updated_channel is not first_channel + assert isinstance( + client.transport._logged_channel._interceptor, GapicInterceptor + ) + assert updated_channel._interceptor == client._metrics_interceptor @pytest.mark.usefixtures("target") @CrossSync._Sync_Impl.Retry(