From 96424a720469dea023381dabb520d62c9f4acc94 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Sat, 14 Mar 2026 12:53:07 -0400 Subject: [PATCH 1/3] Coalesce broadcast exchange batches before broadcasting CometBroadcastExchangeExec previously broadcast an Array[ChunkedByteBuffer] with one entry per source partition. Each consumer partition independently deserialized all entries, creating a separate compression codec and Arrow IPC reader per entry. For broadcasts with many source partitions, this produced large per-task overhead in the hash join build-side collection. Decode and concatenate all broadcast batches into a single ChunkedByteBuffer on the driver using VectorSchemaRootAppender before broadcasting. Falls back to per-batch serialization if dictionary-encoded vectors are present. --- .../apache/spark/sql/comet/util/Utils.scala | 73 ++++++++++++++++++- .../comet/CometBroadcastExchangeExec.scala | 5 +- 2 files changed, 76 insertions(+), 2 deletions(-) diff --git a/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala b/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala index 6eaa9cad44..297c1e8aa6 100644 --- a/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala +++ b/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala @@ -32,7 +32,9 @@ import org.apache.arrow.vector.dictionary.DictionaryProvider import org.apache.arrow.vector.ipc.ArrowStreamWriter import org.apache.arrow.vector.types._ import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema} +import org.apache.arrow.vector.util.VectorSchemaRootAppender import org.apache.spark.{SparkEnv, SparkException} +import org.apache.spark.internal.Logging import org.apache.spark.io.CompressionCodec import org.apache.spark.sql.comet.execution.arrow.ArrowReaderIterator import org.apache.spark.sql.types._ @@ -43,7 +45,7 @@ import org.apache.comet.Constants.COMET_CONF_DIR_ENV import org.apache.comet.shims.CometTypeShim import 
org.apache.comet.vector.CometVector -object Utils extends CometTypeShim { +object Utils extends CometTypeShim with Logging { def getConfPath(confFileName: String): String = { sys.env .get(COMET_CONF_DIR_ENV) @@ -252,6 +254,75 @@ object Utils extends CometTypeShim { new ArrowReaderIterator(Channels.newChannel(ins), source) } + /** + * Coalesces many small ChunkedByteBuffers (one per source partition) into a single + * ChunkedByteBuffer containing one Arrow IPC stream with one record batch. This avoids each + * consumer partition having to deserialize N separate streams. + */ + def coalesceBroadcastBatches(input: Iterator[ChunkedByteBuffer]): Array[ChunkedByteBuffer] = { + val decoded = input.flatMap(decodeBatches(_, "broadcast-coalesce")).toArray + if (decoded.isEmpty) { + return Array.empty + } + + try { + var hasDictionary = false + val sourceRoots = decoded.map { batch => + val (fieldVectors, providerOpt) = getBatchFieldVectors(batch) + if (providerOpt.isDefined) { + hasDictionary = true + } + new VectorSchemaRoot(fieldVectors.asJava) + } + + // Fall back to per-batch serialization if any batch has dictionary-encoded vectors, + // since merging dictionaries across batches is not supported. 
+ if (hasDictionary) { + logInfo( + s"Broadcast coalesce falling back to per-batch serialization due to " + + s"dictionary-encoded vectors (${decoded.length} batches)") + return decoded.flatMap { batch => + serializeBatches(Iterator(batch)).map(_._2) + } + } + + val allocator = org.apache.comet.CometArrowAllocator + .newChildAllocator("broadcast-coalesce", 0, Long.MaxValue) + try { + val schema = sourceRoots.head.getSchema + val targetRoot = VectorSchemaRoot.create(schema, allocator) + try { + VectorSchemaRootAppender.append(targetRoot, sourceRoots: _*) + + val expectedRows = decoded.map(_.numRows().toLong).sum + assert( + targetRoot.getRowCount.toLong == expectedRows, + s"Row count mismatch after coalesce: ${targetRoot.getRowCount} != $expectedRows") + + logInfo( + s"Coalesced ${decoded.length} broadcast batches into 1 " + + s"($expectedRows rows)") + + val codec = CompressionCodec.createCodec(SparkEnv.get.conf) + val cbbos = new ChunkedByteBufferOutputStream(1024 * 1024, ByteBuffer.allocate) + val out = new DataOutputStream(codec.compressedOutputStream(cbbos)) + val writer = new ArrowStreamWriter(targetRoot, null, Channels.newChannel(out)) + writer.start() + writer.writeBatch() + writer.close() + + Array(cbbos.toChunkedByteBuffer) + } finally { + targetRoot.close() + } + } finally { + allocator.close() + } + } finally { + decoded.foreach(_.close()) + } + } + def getBatchFieldVectors( batch: ColumnarBatch): (Seq[FieldVector], Option[DictionaryProvider]) = { var provider: Option[DictionaryProvider] = None diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala index f40e05ea0c..ea4e8b1405 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala @@ -155,7 +155,10 @@ case class CometBroadcastExchangeExec( val beforeBuild = 
System.nanoTime() longMetric("collectTime") += NANOSECONDS.toMillis(beforeBuild - beforeCollect) - val batches = input.toArray + // Coalesce many small per-partition buffers into a single buffer so each + // consumer partition only deserializes one Arrow IPC stream. + // May produce multiple buffers if dictionary-encoded vectors are present. + val batches = Utils.coalesceBroadcastBatches(input) val dataSize = batches.map(_.size).sum From a51bc8b1709617b3b3ab4985851fafbcb5cb1230 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Sat, 14 Mar 2026 13:05:33 -0400 Subject: [PATCH 2/3] Fix scalastyle. --- .../src/main/scala/org/apache/spark/sql/comet/util/Utils.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala b/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala index 297c1e8aa6..28943f5ad1 100644 --- a/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala +++ b/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala @@ -279,7 +279,7 @@ object Utils extends CometTypeShim with Logging { // since merging dictionaries across batches is not supported. if (hasDictionary) { logInfo( - s"Broadcast coalesce falling back to per-batch serialization due to " + + "Broadcast coalesce falling back to per-batch serialization due to " + s"dictionary-encoded vectors (${decoded.length} batches)") return decoded.flatMap { batch => serializeBatches(Iterator(batch)).map(_._2) From 43d3920c9350f96121bcb099f396c179d15f5a18 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Sun, 15 Mar 2026 16:06:31 -0400 Subject: [PATCH 3/3] Stream-decode batches while coalescing, remove dictionary fallback, add test.
--- .../apache/spark/sql/comet/util/Utils.scala | 92 +++++++++---------- .../comet/CometBroadcastExchangeExec.scala | 1 - .../apache/comet/exec/CometJoinSuite.scala | 55 +++++++++++ 3 files changed, 101 insertions(+), 47 deletions(-) diff --git a/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala b/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala index 28943f5ad1..525d7c1759 100644 --- a/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala +++ b/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala @@ -29,7 +29,7 @@ import org.apache.arrow.c.CDataDictionaryProvider import org.apache.arrow.vector.{BigIntVector, BitVector, DateDayVector, DecimalVector, FieldVector, FixedSizeBinaryVector, Float4Vector, Float8Vector, IntVector, NullVector, SmallIntVector, TimeStampMicroTZVector, TimeStampMicroVector, TinyIntVector, ValueVector, VarBinaryVector, VarCharVector, VectorSchemaRoot} import org.apache.arrow.vector.complex.{ListVector, MapVector, StructVector} import org.apache.arrow.vector.dictionary.DictionaryProvider -import org.apache.arrow.vector.ipc.ArrowStreamWriter +import org.apache.arrow.vector.ipc.{ArrowStreamReader, ArrowStreamWriter} import org.apache.arrow.vector.types._ import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema} import org.apache.arrow.vector.util.VectorSchemaRootAppender @@ -260,66 +260,66 @@ object Utils extends CometTypeShim with Logging { * consumer partition having to deserialize N separate streams. 
*/ def coalesceBroadcastBatches(input: Iterator[ChunkedByteBuffer]): Array[ChunkedByteBuffer] = { - val decoded = input.flatMap(decodeBatches(_, "broadcast-coalesce")).toArray - if (decoded.isEmpty) { + val buffers = input.filterNot(_.size == 0).toArray + if (buffers.isEmpty) { return Array.empty } + val allocator = org.apache.comet.CometArrowAllocator + .newChildAllocator("broadcast-coalesce", 0, Long.MaxValue) try { - var hasDictionary = false - val sourceRoots = decoded.map { batch => - val (fieldVectors, providerOpt) = getBatchFieldVectors(batch) - if (providerOpt.isDefined) { - hasDictionary = true + var targetRoot: VectorSchemaRoot = null + var totalRows = 0L + var batchCount = 0 + + try { + for (bytes <- buffers) { + val codec = CompressionCodec.createCodec(SparkEnv.get.conf) + val cbbis = bytes.toInputStream() + val ins = new DataInputStream(codec.compressedInputStream(cbbis)) + val channel = Channels.newChannel(ins) + val reader = new ArrowStreamReader(channel, allocator) + try { + while (reader.loadNextBatch()) { + val sourceRoot = reader.getVectorSchemaRoot + if (targetRoot == null) { + targetRoot = VectorSchemaRoot.create(sourceRoot.getSchema, allocator) + } + VectorSchemaRootAppender.append(targetRoot, sourceRoot) + totalRows += sourceRoot.getRowCount + batchCount += 1 + } + } finally { + reader.close() + } } - new VectorSchemaRoot(fieldVectors.asJava) - } - // Fall back to per-batch serialization if any batch has dictionary-encoded vectors, - // since merging dictionaries across batches is not supported. 
- if (hasDictionary) { - logInfo( - "Broadcast coalesce falling back to per-batch serialization due to " + - s"dictionary-encoded vectors (${decoded.length} batches)") - return decoded.flatMap { batch => - serializeBatches(Iterator(batch)).map(_._2) + if (targetRoot == null) { + return Array.empty } - } - val allocator = org.apache.comet.CometArrowAllocator - .newChildAllocator("broadcast-coalesce", 0, Long.MaxValue) - try { - val schema = sourceRoots.head.getSchema - val targetRoot = VectorSchemaRoot.create(schema, allocator) - try { - VectorSchemaRootAppender.append(targetRoot, sourceRoots: _*) + assert( + targetRoot.getRowCount.toLong == totalRows, + s"Row count mismatch after coalesce: ${targetRoot.getRowCount} != $totalRows") - val expectedRows = decoded.map(_.numRows().toLong).sum - assert( - targetRoot.getRowCount.toLong == expectedRows, - s"Row count mismatch after coalesce: ${targetRoot.getRowCount} != $expectedRows") + logInfo(s"Coalesced $batchCount broadcast batches into 1 ($totalRows rows)") - logInfo( - s"Coalesced ${decoded.length} broadcast batches into 1 " + - s"($expectedRows rows)") + val outCodec = CompressionCodec.createCodec(SparkEnv.get.conf) + val cbbos = new ChunkedByteBufferOutputStream(1024 * 1024, ByteBuffer.allocate) + val out = new DataOutputStream(outCodec.compressedOutputStream(cbbos)) + val writer = new ArrowStreamWriter(targetRoot, null, Channels.newChannel(out)) + writer.start() + writer.writeBatch() + writer.close() - val codec = CompressionCodec.createCodec(SparkEnv.get.conf) - val cbbos = new ChunkedByteBufferOutputStream(1024 * 1024, ByteBuffer.allocate) - val out = new DataOutputStream(codec.compressedOutputStream(cbbos)) - val writer = new ArrowStreamWriter(targetRoot, null, Channels.newChannel(out)) - writer.start() - writer.writeBatch() - writer.close() - - Array(cbbos.toChunkedByteBuffer) - } finally { + Array(cbbos.toChunkedByteBuffer) + } finally { + if (targetRoot != null) { targetRoot.close() } - } finally { - 
allocator.close() } } finally { - decoded.foreach(_.close()) + allocator.close() } } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala index ea4e8b1405..291006f356 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala @@ -157,7 +157,6 @@ case class CometBroadcastExchangeExec( // Coalesce many small per-partition buffers into a single buffer so each // consumer partition only deserializes one Arrow IPC stream. - // May produce multiple buffers if dictionary-encoded vectors are present. val batches = Utils.coalesceBroadcastBatches(input) val dataSize = batches.map(_.size).sum diff --git a/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala index d5a8387be7..8de66b8d2f 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala @@ -383,4 +383,59 @@ class CometJoinSuite extends CometTestBase { """.stripMargin)) } } + + test("Broadcast hash join build-side batch coalescing") { + // Use many shuffle partitions to produce many small broadcast batches, + // then verify that coalescing reduces the build-side batch count to 1 per task. 
+ val numPartitions = 512 + withSQLConf( + CometConf.COMET_BATCH_SIZE.key -> "100", + SQLConf.PREFER_SORTMERGEJOIN.key -> "false", + "spark.sql.join.forceApplyShuffledHashJoin" -> "true", + SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false", + SQLConf.SHUFFLE_PARTITIONS.key -> numPartitions.toString) { + withParquetTable((0 until 10000).map(i => (i, i % 5)), "tbl_a") { + withParquetTable((0 until 10000).map(i => (i % 10, i + 2)), "tbl_b") { + // Force a shuffle on tbl_a before broadcast so the broadcast source has + // numPartitions partitions, not just the number of parquet files. + val query = + s"""SELECT /*+ BROADCAST(a) */ * + |FROM (SELECT /*+ REPARTITION($numPartitions) */ * FROM tbl_a) a + |JOIN tbl_b ON a._2 = tbl_b._1""".stripMargin + + // First verify correctness + val df = sql(query) + checkSparkAnswerAndOperator( + df, + Seq(classOf[CometBroadcastExchangeExec], classOf[CometBroadcastHashJoinExec])) + + // Run again and check metrics on the executed plan + val df2 = sql(query) + df2.collect() + + val joins = collect(df2.queryExecution.executedPlan) { + case j: CometBroadcastHashJoinExec => j + } + assert(joins.nonEmpty, "Expected CometBroadcastHashJoinExec in plan") + + val join = joins.head + val buildBatches = join.metrics("build_input_batches").value + val buildRows = join.metrics("build_input_rows").value + + // Without coalescing, build_input_batches would be ~numPartitions per task, + // totaling ~numPartitions * numPartitions across all tasks. + // With coalescing, each task gets 1 batch, so total ≈ numPartitions. + // scalastyle:off println + println(s"Build-side metrics: batches=$buildBatches, rows=$buildRows") + // scalastyle:on println + assert( + buildBatches <= numPartitions, + s"Expected at most $numPartitions build batches (1 per task), got $buildBatches. " + + "Broadcast batch coalescing may not be working.") + } + } + } + } }