diff --git a/sdks/python/apache_beam/runners/worker/data_plane.py b/sdks/python/apache_beam/runners/worker/data_plane.py index a5589ac33a1b..cfefa37d76b6 100644 --- a/sdks/python/apache_beam/runners/worker/data_plane.py +++ b/sdks/python/apache_beam/runners/worker/data_plane.py @@ -68,6 +68,9 @@ # Keep a set of completed instructions to discard late received data. The set # can have up to _MAX_CLEANED_INSTRUCTIONS items. See _GrpcDataChannel. _MAX_CLEANED_INSTRUCTIONS = 10000 +_DEFAULT_SEND_QUEUE_MAX_ELEMENTS = 10000 +_DEFAULT_SEND_QUEUE_MAX_BYTES = 100 << 20 # 100MB +_DEFAULT_RECEIVE_QUEUE_MAX_ELEMENTS = 5 # retry on transient UNAVAILABLE grpc error from data channels. _GRPC_SERVICE_CONFIG = json.dumps({ @@ -459,10 +462,20 @@ def __init__(self, data_buffer_time_limit_ms=0): self._data_buffer_time_limit_ms = data_buffer_time_limit_ms self._to_send = ByteLimitedQueue( - maxsize=10000, - maxbytes=100 << 20) # type: ByteLimitedQueue[DataOrTimers] + maxsize=_DEFAULT_SEND_QUEUE_MAX_ELEMENTS, + maxbytes=_DEFAULT_SEND_QUEUE_MAX_BYTES + ) # type: ByteLimitedQueue[DataOrTimers] + # Staging queue so a full send buffer does not block reading inputs. + self._pending_send = ByteLimitedQueue( + maxsize=_DEFAULT_SEND_QUEUE_MAX_ELEMENTS, + maxbytes=_DEFAULT_SEND_QUEUE_MAX_BYTES + ) # type: ByteLimitedQueue[DataOrTimers] + self._send_forwarder = None # type: Optional[threading.Thread] + self._start_send_forwarder() self._received = collections.defaultdict( - lambda: ByteLimitedQueue(maxsize=5, maxbytes=100 << 20) + lambda: ByteLimitedQueue( + maxsize=_DEFAULT_RECEIVE_QUEUE_MAX_ELEMENTS, maxbytes= + _DEFAULT_SEND_QUEUE_MAX_BYTES) ) # type: DefaultDict[str, ByteLimitedQueue[DataOrTimers]] # Keep a cache of completed instructions. Data for completed instructions @@ -478,9 +491,40 @@ def __init__(self, data_buffer_time_limit_ms=0): def close(self): # type: () -> None - self._to_send.put(self._WRITES_FINISHED, 0) + self._enqueue_to_send(self._WRITES_FINISHED) + if self._send_forwarder is not None: + self._send_forwarder.join() + if self._exception: + raise self._exception self._closed = True + def _start_send_forwarder(self): + # type: () -> None + forwarder = threading.Thread( + target=self._forward_pending_to_send, name='forward_grpc_outputs') + forwarder.daemon = True + forwarder.start() + self._send_forwarder = forwarder + + def _enqueue_to_send(self, elem): + # type: (DataOrTimers) -> None + size = self._get_element_size_bytes(elem) + self._pending_send.put((elem, size), size) + + def _forward_pending_to_send(self): + # type: () -> None + try: + while True: + elem, size = self._pending_send.get() + self._to_send.put(elem, size) + if elem is self._WRITES_FINISHED: + return + except Exception as e: + if not self._closed: + _LOGGER.exception('Failed to forward outputs in the data plane.') + self._exception = e + raise + def wait(self, timeout=None): # type: (Optional[int]) -> None self._reads_finished.wait(timeout) @@ -591,7 +635,7 @@ def add_to_send_queue(data): if data: elem = beam_fn_api_pb2.Elements.Data( instruction_id=instruction_id, transform_id=transform_id, data=data) - self._to_send.put(elem, self._get_element_size_bytes(elem)) + self._enqueue_to_send(elem) def close_callback(data): # type: (bytes) -> None @@ -601,7 +645,7 @@ def close_callback(data): instruction_id=instruction_id, transform_id=transform_id, is_last=True) - self._to_send.put(elem, self._get_element_size_bytes(elem)) + self._enqueue_to_send(elem) return ClosableOutputStream.create( close_callback, add_to_send_queue, self._data_buffer_time_limit_ms) @@ -622,7 +666,7 @@ def add_to_send_queue(timer): timer_family_id=timer_family_id, timers=timer, is_last=False) - self._to_send.put(elem, self._get_element_size_bytes(elem)) + self._enqueue_to_send(elem) def close_callback(timer): # type: (bytes) -> None @@ -632,7 +676,7 @@ def close_callback(timer): transform_id=transform_id, timer_family_id=timer_family_id, is_last=True) - self._to_send.put(elem, self._get_element_size_bytes(elem)) + self._enqueue_to_send(elem) return ClosableOutputStream.create( close_callback, add_to_send_queue, self._data_buffer_time_limit_ms)