diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py index b2e30a8325db0..98dd3ef4380f1 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -1667,13 +1667,16 @@ def handle_response( for batch in reader: assert isinstance(batch, pa.RecordBatch) num_records_in_batch += batch.num_rows - if num_records_in_batch != b.arrow_batch.row_count: - raise SparkConnectException( - f"Expected {b.arrow_batch.row_count} rows in arrow batch but " - + f"got {num_records_in_batch}." - ) - num_records += num_records_in_batch + num_records += batch.num_rows yield batch + # An Arrow IPC stream ([Schema][RecordBatch]*[EOS]) may carry + # multiple RecordBatches, so validate row_count only once the + # reader is fully consumed. + if num_records_in_batch != b.arrow_batch.row_count: + raise SparkConnectException( + f"Expected {b.arrow_batch.row_count} rows in arrow batch but " + + f"got {num_records_in_batch}." + ) if b.HasField("create_resource_profile_command_result"): profile_id = b.create_resource_profile_command_result.profile_id yield {"create_resource_profile_command_result": profile_id} diff --git a/python/pyspark/sql/tests/connect/client/test_client.py b/python/pyspark/sql/tests/connect/client/test_client.py index 9b0f59522e257..7ffac7752b797 100644 --- a/python/pyspark/sql/tests/connect/client/test_client.py +++ b/python/pyspark/sql/tests/connect/client/test_client.py @@ -277,6 +277,44 @@ def test_user_agent_default(self): mock.req.client_type, r"^_SPARK_CONNECT_PYTHON spark/[^ ]+ os/[^ ]+ python/[^ ]+$" ) + def test_multiple_record_batches_in_single_arrow_batch(self): + # An Arrow IPC stream may carry multiple RecordBatches; row_count is the total + # across them and must be validated only after the stream is fully consumed. 
+        client = SparkConnectClient("sc://foo/", use_reattachable_execute=False)
+
+        class MultiBatchMockService:
+            def __init__(self, session_id: str):
+                self._session_id = session_id
+                self.req: Optional[proto.ExecutePlanRequest] = None
+
+            def ExecutePlan(self, req: proto.ExecutePlanRequest, metadata):
+                self.req = req
+                resp = proto.ExecutePlanResponse()
+                resp.session_id = self._session_id
+                resp.operation_id = req.operation_id
+
+                pdf = pd.DataFrame(data={"col1": [1, 2, 3, 4]})
+                source = pa.Table.from_pandas(pdf)
+                sink = pa.BufferOutputStream()
+                # Two RecordBatches in one IPC stream. The context manager closes the
+                # writer (emitting the end-of-stream marker) even if a write raises.
+                with pa.ipc.new_stream(sink, schema=source.schema) as writer:
+                    for batch in source.to_batches(max_chunksize=2):
+                        writer.write_batch(batch)
+
+                resp.arrow_batch.data = sink.getvalue().to_pybytes()
+                # row_count is the total across all RecordBatches in the message.
+                resp.arrow_batch.row_count = 4
+                return [resp]
+
+        mock = MultiBatchMockService(client._session_id)
+        client._stub = mock
+
+        plan = proto.Plan()
+        table, _, _ = client.to_table(plan, {})
+        self.assertEqual(table.num_rows, 4)
+        self.assertEqual(table.column("col1").to_pylist(), [1, 2, 3, 4])
+
     def test_properties(self):
         client = SparkConnectClient("sc://foo/;token=bar", use_reattachable_execute=False)
         self.assertEqual(client.token, "bar")