diff --git a/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala b/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala
index 6eaa9cad44..525d7c1759 100644
--- a/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala
+++ b/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala
@@ -29,10 +29,12 @@ import org.apache.arrow.c.CDataDictionaryProvider
 import org.apache.arrow.vector.{BigIntVector, BitVector, DateDayVector, DecimalVector, FieldVector, FixedSizeBinaryVector, Float4Vector, Float8Vector, IntVector, NullVector, SmallIntVector, TimeStampMicroTZVector, TimeStampMicroVector, TinyIntVector, ValueVector, VarBinaryVector, VarCharVector, VectorSchemaRoot}
 import org.apache.arrow.vector.complex.{ListVector, MapVector, StructVector}
 import org.apache.arrow.vector.dictionary.DictionaryProvider
-import org.apache.arrow.vector.ipc.ArrowStreamWriter
+import org.apache.arrow.vector.ipc.{ArrowStreamReader, ArrowStreamWriter}
 import org.apache.arrow.vector.types._
 import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}
+import org.apache.arrow.vector.util.VectorSchemaRootAppender
 import org.apache.spark.{SparkEnv, SparkException}
+import org.apache.spark.internal.Logging
 import org.apache.spark.io.CompressionCodec
 import org.apache.spark.sql.comet.execution.arrow.ArrowReaderIterator
 import org.apache.spark.sql.types._
@@ -43,7 +45,7 @@ import org.apache.comet.Constants.COMET_CONF_DIR_ENV
 import org.apache.comet.shims.CometTypeShim
 import org.apache.comet.vector.CometVector
 
-object Utils extends CometTypeShim {
+object Utils extends CometTypeShim with Logging {
   def getConfPath(confFileName: String): String = {
     sys.env
       .get(COMET_CONF_DIR_ENV)
@@ -252,6 +254,75 @@ object Utils extends CometTypeShim {
     new ArrowReaderIterator(Channels.newChannel(ins), source)
   }
 
+  /**
+   * Coalesces many small ChunkedByteBuffers (one per source partition) into a single
+   * ChunkedByteBuffer containing one Arrow IPC stream with one record batch. This avoids each
+   * consumer partition having to deserialize N separate streams.
+   */
+  def coalesceBroadcastBatches(input: Iterator[ChunkedByteBuffer]): Array[ChunkedByteBuffer] = {
+    val buffers = input.filterNot(_.size == 0).toArray
+    if (buffers.isEmpty) {
+      return Array.empty
+    }
+
+    val allocator = org.apache.comet.CometArrowAllocator
+      .newChildAllocator("broadcast-coalesce", 0, Long.MaxValue)
+    try {
+      var targetRoot: VectorSchemaRoot = null
+      var totalRows = 0L
+      var batchCount = 0
+
+      try {
+        for (bytes <- buffers) {
+          val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
+          val cbbis = bytes.toInputStream()
+          val ins = new DataInputStream(codec.compressedInputStream(cbbis))
+          val channel = Channels.newChannel(ins)
+          val reader = new ArrowStreamReader(channel, allocator)
+          try {
+            while (reader.loadNextBatch()) {
+              val sourceRoot = reader.getVectorSchemaRoot
+              if (targetRoot == null) {
+                targetRoot = VectorSchemaRoot.create(sourceRoot.getSchema, allocator)
+              }
+              VectorSchemaRootAppender.append(targetRoot, sourceRoot)
+              totalRows += sourceRoot.getRowCount
+              batchCount += 1
+            }
+          } finally {
+            reader.close()
+          }
+        }
+
+        if (targetRoot == null) {
+          return Array.empty
+        }
+
+        assert(
+          targetRoot.getRowCount.toLong == totalRows,
+          s"Row count mismatch after coalesce: ${targetRoot.getRowCount} != $totalRows")
+
+        logInfo(s"Coalesced $batchCount broadcast batches into 1 ($totalRows rows)")
+
+        val outCodec = CompressionCodec.createCodec(SparkEnv.get.conf)
+        val cbbos = new ChunkedByteBufferOutputStream(1024 * 1024, ByteBuffer.allocate)
+        val out = new DataOutputStream(outCodec.compressedOutputStream(cbbos))
+        val writer = new ArrowStreamWriter(targetRoot, null, Channels.newChannel(out))
+        writer.start()
+        writer.writeBatch()
+        writer.close()
+
+        Array(cbbos.toChunkedByteBuffer)
+      } finally {
+        if (targetRoot != null) {
+          targetRoot.close()
+        }
+      }
+    } finally {
+      allocator.close()
+    }
+  }
+
   def getBatchFieldVectors(
       batch: ColumnarBatch): (Seq[FieldVector], Option[DictionaryProvider]) = {
     var provider: Option[DictionaryProvider] = None
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
index f40e05ea0c..291006f356 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
@@ -155,7 +155,9 @@
         val beforeBuild = System.nanoTime()
         longMetric("collectTime") += NANOSECONDS.toMillis(beforeBuild - beforeCollect)
 
-        val batches = input.toArray
+        // Coalesce many small per-partition buffers into a single buffer so each
+        // consumer partition only deserializes one Arrow IPC stream.
+        val batches = Utils.coalesceBroadcastBatches(input)
 
         val dataSize = batches.map(_.size).sum
 
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala
index d5a8387be7..8de66b8d2f 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala
@@ -383,4 +383,59 @@
         """.stripMargin))
     }
   }
+
+  test("Broadcast hash join build-side batch coalescing") {
+    // Use many shuffle partitions to produce many small broadcast batches,
+    // then verify that coalescing reduces the build-side batch count to 1 per task.
+    val numPartitions = 512
+    withSQLConf(
+      CometConf.COMET_BATCH_SIZE.key -> "100",
+      SQLConf.PREFER_SORTMERGEJOIN.key -> "false",
+      "spark.sql.join.forceApplyShuffledHashJoin" -> "true",
+      SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false",
+      SQLConf.SHUFFLE_PARTITIONS.key -> numPartitions.toString) {
+      withParquetTable((0 until 10000).map(i => (i, i % 5)), "tbl_a") {
+        withParquetTable((0 until 10000).map(i => (i % 10, i + 2)), "tbl_b") {
+          // Force a shuffle on tbl_a before broadcast so the broadcast source has
+          // numPartitions partitions, not just the number of parquet files.
+          val query =
+            s"""SELECT /*+ BROADCAST(a) */ *
+               |FROM (SELECT /*+ REPARTITION($numPartitions) */ * FROM tbl_a) a
+               |JOIN tbl_b ON a._2 = tbl_b._1""".stripMargin
+
+          // First verify correctness
+          val df = sql(query)
+          checkSparkAnswerAndOperator(
+            df,
+            Seq(classOf[CometBroadcastExchangeExec], classOf[CometBroadcastHashJoinExec]))
+
+          // Run again and check metrics on the executed plan
+          val df2 = sql(query)
+          df2.collect()
+
+          val joins = collect(df2.queryExecution.executedPlan) {
+            case j: CometBroadcastHashJoinExec => j
+          }
+          assert(joins.nonEmpty, "Expected CometBroadcastHashJoinExec in plan")
+
+          val join = joins.head
+          val buildBatches = join.metrics("build_input_batches").value
+          val buildRows = join.metrics("build_input_rows").value
+
+          // Without coalescing, build_input_batches would be ~numPartitions per task,
+          // totaling ~numPartitions * numPartitions across all tasks.
+          // With coalescing, each task gets 1 batch, so total ≈ numPartitions.
+          // scalastyle:off println
+          println(s"Build-side metrics: batches=$buildBatches, rows=$buildRows")
+          // scalastyle:on println
+          assert(
+            buildBatches <= numPartitions,
+            s"Expected at most $numPartitions build batches (1 per task), got $buildBatches. " +
+              "Broadcast batch coalescing may not be working.")
+        }
+      }
+    }
+  }
 }