Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 73 additions & 2 deletions common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,12 @@ import org.apache.arrow.c.CDataDictionaryProvider
import org.apache.arrow.vector.{BigIntVector, BitVector, DateDayVector, DecimalVector, FieldVector, FixedSizeBinaryVector, Float4Vector, Float8Vector, IntVector, NullVector, SmallIntVector, TimeStampMicroTZVector, TimeStampMicroVector, TinyIntVector, ValueVector, VarBinaryVector, VarCharVector, VectorSchemaRoot}
import org.apache.arrow.vector.complex.{ListVector, MapVector, StructVector}
import org.apache.arrow.vector.dictionary.DictionaryProvider
import org.apache.arrow.vector.ipc.ArrowStreamWriter
import org.apache.arrow.vector.ipc.{ArrowStreamReader, ArrowStreamWriter}
import org.apache.arrow.vector.types._
import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}
import org.apache.arrow.vector.util.VectorSchemaRootAppender
import org.apache.spark.{SparkEnv, SparkException}
import org.apache.spark.internal.Logging
import org.apache.spark.io.CompressionCodec
import org.apache.spark.sql.comet.execution.arrow.ArrowReaderIterator
import org.apache.spark.sql.types._
Expand All @@ -43,7 +45,7 @@ import org.apache.comet.Constants.COMET_CONF_DIR_ENV
import org.apache.comet.shims.CometTypeShim
import org.apache.comet.vector.CometVector

object Utils extends CometTypeShim {
object Utils extends CometTypeShim with Logging {
def getConfPath(confFileName: String): String = {
sys.env
.get(COMET_CONF_DIR_ENV)
Expand Down Expand Up @@ -252,6 +254,75 @@ object Utils extends CometTypeShim {
new ArrowReaderIterator(Channels.newChannel(ins), source)
}

/**
 * Coalesces many small ChunkedByteBuffers (one per source partition) into at most one
 * ChunkedByteBuffer containing a single compressed Arrow IPC stream with one record batch. This
 * avoids each consumer partition having to deserialize N separate streams.
 *
 * Zero-sized input buffers are skipped. Each remaining buffer is decompressed with the codec
 * resolved from the active Spark conf, every record batch it contains is appended onto a root
 * created from the first batch's schema, and the merged root is written back out through the
 * same codec.
 *
 * @param input
 *   compressed Arrow IPC streams to merge
 * @return
 *   a one-element array holding the merged buffer, or an empty array when there is no non-empty
 *   input buffer or the inputs contain no record batch
 */
def coalesceBroadcastBatches(input: Iterator[ChunkedByteBuffer]): Array[ChunkedByteBuffer] = {
  val buffers = input.filterNot(_.size == 0).toArray
  if (buffers.isEmpty) {
    return Array.empty
  }

  // The codec is determined by the Spark conf alone, so resolve it once instead of once per
  // input buffer; it is only used as a factory for the individual streams.
  val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
  val allocator = org.apache.comet.CometArrowAllocator
    .newChildAllocator("broadcast-coalesce", 0, Long.MaxValue)
  try {
    // Created lazily from the first batch seen, because the schema is only known at that point.
    var targetRoot: VectorSchemaRoot = null
    var totalRows = 0L
    var batchCount = 0

    try {
      for (bytes <- buffers) {
        val ins = new DataInputStream(codec.compressedInputStream(bytes.toInputStream()))
        val reader = new ArrowStreamReader(Channels.newChannel(ins), allocator)
        try {
          while (reader.loadNextBatch()) {
            // The reader reuses its root across batches, so copy the rows out before advancing.
            val sourceRoot = reader.getVectorSchemaRoot
            if (targetRoot == null) {
              targetRoot = VectorSchemaRoot.create(sourceRoot.getSchema, allocator)
            }
            VectorSchemaRootAppender.append(targetRoot, sourceRoot)
            totalRows += sourceRoot.getRowCount
            batchCount += 1
          }
        } finally {
          reader.close()
        }
      }

      if (targetRoot == null) {
        return Array.empty
      }

      assert(
        targetRoot.getRowCount.toLong == totalRows,
        s"Row count mismatch after coalesce: ${targetRoot.getRowCount} != $totalRows")

      logInfo(s"Coalesced $batchCount broadcast batches into 1 ($totalRows rows)")

      val cbbos = new ChunkedByteBufferOutputStream(1024 * 1024, ByteBuffer.allocate)
      val out = new DataOutputStream(codec.compressedOutputStream(cbbos))
      var written = false
      try {
        val writer = new ArrowStreamWriter(targetRoot, null, Channels.newChannel(out))
        writer.start()
        writer.writeBatch()
        // Closing the writer ends the IPC stream and closes the channel, which in turn closes
        // `out` and `cbbos`; that must happen before `toChunkedByteBuffer` is called below.
        writer.close()
        written = true
      } finally {
        if (!written) {
          // Failure path only: release the compression stream without going through the
          // writer (whose close would try to finish the IPC stream) and without masking the
          // exception that is already propagating.
          try {
            out.close()
          } catch {
            case scala.util.control.NonFatal(_) => // keep the original failure
          }
        }
      }

      Array(cbbos.toChunkedByteBuffer)
    } finally {
      if (targetRoot != null) {
        targetRoot.close()
      }
    }
  } finally {
    // Closed last: the reader roots and `targetRoot` allocated from it must be released first.
    allocator.close()
  }
}

def getBatchFieldVectors(
batch: ColumnarBatch): (Seq[FieldVector], Option[DictionaryProvider]) = {
var provider: Option[DictionaryProvider] = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,9 @@ case class CometBroadcastExchangeExec(
val beforeBuild = System.nanoTime()
longMetric("collectTime") += NANOSECONDS.toMillis(beforeBuild - beforeCollect)

val batches = input.toArray
// Coalesce many small per-partition buffers into a single buffer so each
// consumer partition only deserializes one Arrow IPC stream.
val batches = Utils.coalesceBroadcastBatches(input)

val dataSize = batches.map(_.size).sum

Expand Down
55 changes: 55 additions & 0 deletions spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -383,4 +383,59 @@ class CometJoinSuite extends CometTestBase {
""".stripMargin))
}
}

test("Broadcast hash join build-side batch coalescing") {
  // Spread the broadcast side over many shuffle partitions so that it is collected as many
  // small batches, then check that the join reports no more build batches than partitions.
  val numPartitions = 512
  val confs = Seq(
    CometConf.COMET_BATCH_SIZE.key -> "100",
    SQLConf.PREFER_SORTMERGEJOIN.key -> "false",
    "spark.sql.join.forceApplyShuffledHashJoin" -> "true",
    SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
    SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
    SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false",
    SQLConf.SHUFFLE_PARTITIONS.key -> numPartitions.toString)
  val broadcastRows = (0 until 10000).map(i => (i, i % 5))
  val probeRows = (0 until 10000).map(i => (i % 10, i + 2))

  withSQLConf(confs: _*) {
    withParquetTable(broadcastRows, "tbl_a") {
      withParquetTable(probeRows, "tbl_b") {
        // The REPARTITION hint inserts a shuffle ahead of the broadcast, so the broadcast
        // source has numPartitions partitions rather than one per parquet file.
        val query =
          s"""SELECT /*+ BROADCAST(a) */ *
             |FROM (SELECT /*+ REPARTITION($numPartitions) */ * FROM tbl_a) a
             |JOIN tbl_b ON a._2 = tbl_b._1""".stripMargin

        // Step 1: results must match Spark and the plan must use the Comet operators.
        checkSparkAnswerAndOperator(
          sql(query),
          Seq(classOf[CometBroadcastExchangeExec], classOf[CometBroadcastHashJoinExec]))

        // Step 2: execute a fresh DataFrame and read the metrics off its executed plan.
        val executed = sql(query)
        executed.collect()

        val broadcastJoins = collect(executed.queryExecution.executedPlan) {
          case bhj: CometBroadcastHashJoinExec => bhj
        }
        assert(broadcastJoins.nonEmpty, "Expected CometBroadcastHashJoinExec in plan")

        val joinMetrics = broadcastJoins.head.metrics
        val buildBatches = joinMetrics("build_input_batches").value
        val buildRows = joinMetrics("build_input_rows").value

        // Uncoalesced, every task would read roughly numPartitions build batches, i.e. about
        // numPartitions * numPartitions in total. Coalesced, every task reads a single batch,
        // so the total stays at or below numPartitions.
        // scalastyle:off println
        println(s"Build-side metrics: batches=$buildBatches, rows=$buildRows")
        // scalastyle:on println
        assert(
          buildBatches <= numPartitions,
          s"Expected at most $numPartitions build batches (1 per task), got $buildBatches. " +
            "Broadcast batch coalescing may not be working.")
      }
    }
  }
}
}
Loading