From da82fda628e6d92cd3e35e0f916682c264c468e7 Mon Sep 17 00:00:00 2001 From: michal harakal Date: Wed, 22 Apr 2026 18:17:59 +0200 Subject: [PATCH] Record and emit scaledDotProductAttention for IREE (#543) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three changes to enable SDPA in the SKaiNET → StableHLO → IREE path: 1. RecordingExecution: record SDPA calls with operation + params (was delegating without recording, like conv1d before PR #532) 2. TensorOperations: add ScaledDotProductAttentionOperation with inferOutputs (output shape = query shape) 3. NeuralNetOperationsConverter: decompose SDPA into StableHLO: - dot_general Q @ K.T (batching_dims=[0,1], contracting_dims=[3]x[3]) - scale + optional mask - softmax (max-subtract-exp-sum-div decomposition) - dot_general weights @ V (contracting_dims=[3]x[2]) Also includes: - SdpaHloExportTest: verifies tape → graph → MLIR with dot_general - TapeAttentionPermuteBugTest: proves raw array permute creates zero constants - ShapeOperationsConverter: concatenate input type annotation fix - ISSUE-SDPA-recording-and-hlo.md: issue documentation Fixes #543 Co-Authored-By: Claude Opus 4.6 (1M context) --- .../ISSUE-SDPA-recording-and-hlo.md | 110 +++++++++++++ .../sk/ainet/tape/RecordingExecution.kt | 15 +- .../NeuralNetOperationsConverter.kt | 151 +++++++++++++++++- .../sk/ainet/compile/hlo/SdpaHloExportTest.kt | 60 +++++++ .../ainet/lang/tensor/ops/TensorOperations.kt | 33 ++++ 5 files changed, 366 insertions(+), 3 deletions(-) create mode 100644 docs/whisper-iree-issues/ISSUE-SDPA-recording-and-hlo.md create mode 100644 skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/SdpaHloExportTest.kt diff --git a/docs/whisper-iree-issues/ISSUE-SDPA-recording-and-hlo.md b/docs/whisper-iree-issues/ISSUE-SDPA-recording-and-hlo.md new file mode 100644 index 00000000..00218dd5 --- /dev/null +++ b/docs/whisper-iree-issues/ISSUE-SDPA-recording-and-hlo.md @@ -0,0 +1,110 @@ 
+# scaledDotProductAttention: not recorded by tape, no StableHLO converter + +## Summary + +`ctx.ops.scaledDotProductAttention()` exists in TensorOps interface, +VoidTensorOps, and DefaultCpuOps — but it is not tape-recorded and has +no StableHLO converter. This blocks multi-head attention in the +SKaiNET → IREE compilation path. + +## Impact + +Without SDPA, Whisper's multi-head attention must be decomposed into +individual ops (reshape, transpose, matmul, softmax, matmul). The +per-batch K transpose requires raw FloatArray manipulation which +creates zero constants in the VMFB (proven in TapeAttentionPermuteBugTest). + +Result: GPU Whisper encoder produces wrong hidden states → decoder +outputs "," instead of real transcription. + +## Three fixes needed + +### 1. RecordingExecution: record SDPA + +**File:** `skainet-compile-core/.../tape/RecordingExecution.kt` line 436 + +Current (just delegates, no recording): +```kotlin +override fun scaledDotProductAttention(...) = + base.scaledDotProductAttention(query, key, value, mask, scale, causal) +``` + +Fix (same pattern as conv1d in PR #532): +```kotlin +override fun scaledDotProductAttention( + query, key, value, mask, scale, causal +): Tensor { + val out = base.scaledDotProductAttention(query, key, value, mask, scale, causal) + val params = mapOf("scale" to scale, "causal" to causal) + record(ScaledDotProductAttentionOperation(params), + listOfNotNull(query, key, value, mask), listOf(out)) + return out +} +``` + +### 2. TensorOperations: add ScaledDotProductAttentionOperation + +**File:** `skainet-lang-core/.../tensor/ops/TensorOperations.kt` + +```kotlin +class ScaledDotProductAttentionOperation( + parameters: Map = emptyMap() +) : BaseOperation("scaledDotProductAttention", "nn", parameters) { + override fun inferOutputs(inputs: List): List { + // Output shape = query shape: [batch, nHeads, seqLen, headDim] + return listOf(TensorSpec("sdpa_output", inputs[0].shape, inputs[0].dtype)) + } +} +``` + +### 3. 
StableHLO converter: decompose SDPA + +**File:** `skainet-compile-hlo/.../converters/NeuralNetOperationsConverter.kt` + +Register "scaledDotProductAttention" and decompose into: +```mlir +// scores = Q @ K.T (batched matmul with K transposed) +%scores = stablehlo.dot_general %query, %key, + batching_dims = [0, 1] x [0, 1], + contracting_dims = [3] x [3] + : (tensor, tensor) -> tensor + +// scale +%scaled = stablehlo.multiply %scores, %scale_splat + +// optional mask (additive) +%masked = stablehlo.add %scaled, %mask // if mask != null + +// softmax over last dim +%weights = stablehlo softmax ... + +// output = weights @ V (batched matmul) +%output = stablehlo.dot_general %weights, %value, + batching_dims = [0, 1] x [0, 1], + contracting_dims = [3] x [2] +``` + +Note: `contracting_dims = [3] x [3]` for Q@K.T because we contract +headDim of Q (last dim) with headDim of K (also last dim). This is +different from standard matmul where you contract last of A with +second-to-last of B — here K is NOT pre-transposed. + +## Test + +```kotlin +val ctx = DefaultGraphExecutionContext.tape(baseOps = VoidTensorOps()) +val q = ctx.fromFloatArray(Shape(1, 6, 4, 64), ...) // [batch, heads, seq, headDim] +val k = ctx.fromFloatArray(Shape(1, 6, 4, 64), ...) +val v = ctx.fromFloatArray(Shape(1, 6, 4, 64), ...) 
+ +val (tape, out) = ctx.record { + ctx.ops.scaledDotProductAttention(q, k, v) +} + +val graph = tape!!.toComputeGraph(synthesizeExternalInputs = true) +val module = StableHloConverterFactory.createExtended().convert(graph, "test_sdpa") + +// Should contain dot_general for Q@K.T and weights@V +assertTrue(module.content.contains("stablehlo.dot_general")) +assertFalse(module.content.contains("dense<0.0>")) // no zero constants +``` diff --git a/skainet-compile/skainet-compile-core/src/commonMain/kotlin/sk/ainet/tape/RecordingExecution.kt b/skainet-compile/skainet-compile-core/src/commonMain/kotlin/sk/ainet/tape/RecordingExecution.kt index bcfbc8a6..15320c62 100644 --- a/skainet-compile/skainet-compile-core/src/commonMain/kotlin/sk/ainet/tape/RecordingExecution.kt +++ b/skainet-compile/skainet-compile-core/src/commonMain/kotlin/sk/ainet/tape/RecordingExecution.kt @@ -433,7 +433,20 @@ internal class RecordingTensorOpsDecorator(private val base: TensorOps) : Tensor override fun indexSelect(input: Tensor, indices: Tensor, dim: Int): Tensor = base.indexSelect(input, indices, dim) override fun exp(tensor: Tensor): Tensor = base.exp(tensor) override fun expm1(tensor: Tensor): Tensor = base.expm1(tensor) - override fun scaledDotProductAttention(query: Tensor, key: Tensor, value: Tensor, mask: Tensor?, scale: Float, causal: Boolean): Tensor = base.scaledDotProductAttention(query, key, value, mask, scale, causal) + override fun scaledDotProductAttention( + query: Tensor, key: Tensor, value: Tensor, + mask: Tensor?, scale: Float, causal: Boolean + ): Tensor { + val out = base.scaledDotProductAttention(query, key, value, mask, scale, causal) + val params = mutableMapOf( + "scale" to scale, + "causal" to causal + ) + @Suppress("UNCHECKED_CAST") + val inputs = listOfNotNull(query, key, value, mask) as List<Tensor<*, *>> + record(ScaledDotProductAttentionOperation(params), inputs, listOf(out)) + return out + } } private class ConcatRecordingOperation( diff --git
a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/NeuralNetOperationsConverter.kt b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/NeuralNetOperationsConverter.kt index 3cdd336b..2de8907b 100644 --- a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/NeuralNetOperationsConverter.kt +++ b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/NeuralNetOperationsConverter.kt @@ -28,7 +28,9 @@ public class NeuralNetOperationsConverter : StableHloOperationConverter { // Normalization operations "batchNorm", "batchNormalization", "BatchNormalization", "layerNorm", "layerNormalization", "LayerNormalization", - "rmsNorm", "rms_norm", "RMSNorm", "RmsNorm" + "rmsNorm", "rms_norm", "RMSNorm", "RmsNorm", + // Attention + "scaledDotProductAttention" ) override fun convert( @@ -44,6 +46,7 @@ public class NeuralNetOperationsConverter : StableHloOperationConverter { "batchnorm", "batchnormalization" -> convertBatchNorm(node, operands, context) "layernorm", "layernormalization" -> convertLayerNorm(node, operands, context) "rmsnorm", "rms_norm" -> convertRmsNorm(node, operands, context) + "scaleddotproductattention" -> convertSdpa(node, operands, context) else -> ConversionResult.Unsupported( node.operation.name, "Operation not supported by NeuralNetOperationsConverter" @@ -770,5 +773,149 @@ public class NeuralNetOperationsConverter : StableHloOperationConverter { "epsilon = $epsilon, feature_index = $featureIndex : $outputType" } } - + + /** + * Convert scaledDotProductAttention to StableHLO. 
+ * Decomposes into: Q @ K.T (batched) → scale → optional mask → softmax → @ V (batched) + * + * Input shapes: Q[B,H,S,D], K[B,H,T,D], V[B,H,T,D], optional mask[B,H,S,T] or broadcastable + * Output: [B,H,S,D] + */ + private fun convertSdpa( + node: GraphNode, + operands: List<String>, + context: ConversionContext + ): ConversionResult { + if (operands.size < 3) { + return ConversionResult.Failure("SDPA requires at least 3 operands (Q, K, V), got ${operands.size}") + } + + val query = operands[0] // [B, H, S, D] + val key = operands[1] // [B, H, T, D] + val value = operands[2] // [B, H, T, D] + val mask = if (operands.size >= 4) operands[3] else null + + val querySpec = node.inputs.getOrNull(0) + val keySpec = node.inputs.getOrNull(1) + val valueSpec = node.inputs.getOrNull(2) + + val outputSpec = node.outputs.firstOrNull() + val outputType = outputSpec?.let { context.getTypeMapper().mapTensorType(it) } + ?: "tensor" + + // Infer shapes for intermediate types + val qShape = querySpec?.shape ?: return ConversionResult.Failure("Unknown Q shape") + val kShape = keySpec?.shape ?: return ConversionResult.Failure("Unknown K shape") + val vShape = valueSpec?.shape ?: return ConversionResult.Failure("Unknown V shape") + + val rank = qShape.size + if (rank != 4) { + return ConversionResult.Failure("SDPA expects 4D tensors [B,H,S,D], got rank $rank") + } + + val batch = qShape[0] + val heads = qShape[1] + val seqQ = qShape[2] + val headDim = qShape[3] + val seqK = kShape[2] + + val queryType = context.getValueType(query) ?: "tensor<${qShape.joinToString("x")}xf32>" + val keyType = context.getValueType(key) ?: "tensor<${kShape.joinToString("x")}xf32>" + val valueType = context.getValueType(value) ?: "tensor<${vShape.joinToString("x")}xf32>" + + // scores = Q @ K.T: [B,H,S,D] @ [B,H,T,D] → [B,H,S,T] + // dot_general with batching_dims=[0,1], contracting_dims=[3]x[3] + val scoresType = "tensor<${batch}x${heads}x${seqQ}x${seqK}xf32>" + val scoresVal = context.nextTempValue() +
context.emitOperation( + "$scoresVal = stablehlo.dot_general $query, $key, " + + "batching_dims = [0, 1] x [0, 1], contracting_dims = [3] x [3] " + + ": ($queryType, $keyType) -> $scoresType" + ) + context.setValueType(scoresVal, scoresType) + + // Scale + val scale = node.operation.parameters["scale"] as? Float + ?: (1.0f / kotlin.math.sqrt(headDim.toFloat())) + val scaledVal = context.nextTempValue() + val scaleConst = context.nextTempValue() + context.emitOperation("$scaleConst = stablehlo.constant dense<$scale> : tensor") + context.emitOperation( + "$scaledVal = stablehlo.broadcast_in_dim $scaleConst, dims = [] " + + ": (tensor) -> $scoresType" + ) + val scaledScores = context.nextTempValue() + context.emitOperation( + "$scaledScores = stablehlo.multiply $scoresVal, $scaledVal : $scoresType" + ) + context.setValueType(scaledScores, scoresType) + + // Optional mask + var presoft = scaledScores + if (mask != null) { + val maskedVal = context.nextTempValue() + val maskType = context.getValueType(mask) ?: scoresType + context.emitOperation( + "$maskedVal = stablehlo.add $presoft, $mask : $scoresType" + ) + context.setValueType(maskedVal, scoresType) + presoft = maskedVal + } + + // Softmax over last dim (seqK) + // Decompose: exp(x - max(x)) / sum(exp(x - max(x))) + val maxVal = context.nextTempValue() + val maxInitVal = context.nextTempValue() + context.emitOperation("$maxInitVal = stablehlo.constant dense<0xFF800000> : tensor") // -inf + context.emitOperation( + "$maxVal = stablehlo.reduce($presoft init: $maxInitVal) applies stablehlo.maximum " + + "across dimensions = [${rank - 1}] : ($scoresType, tensor) -> " + + "tensor<${batch}x${heads}x${seqQ}xf32>" + ) + + val maxBcast = context.nextTempValue() + val reducedType = "tensor<${batch}x${heads}x${seqQ}xf32>" + context.emitOperation( + "$maxBcast = stablehlo.broadcast_in_dim $maxVal, dims = [0, 1, 2] " + + ": ($reducedType) -> $scoresType" + ) + + val shifted = context.nextTempValue() + 
context.emitOperation("$shifted = stablehlo.subtract $presoft, $maxBcast : $scoresType") + + val expVal = context.nextTempValue() + context.emitOperation("$expVal = stablehlo.exponential $shifted : $scoresType") + + val sumInit = context.nextTempValue() + context.emitOperation("$sumInit = stablehlo.constant dense<0.0> : tensor") + val sumVal = context.nextTempValue() + context.emitOperation( + "$sumVal = stablehlo.reduce($expVal init: $sumInit) applies stablehlo.add " + + "across dimensions = [${rank - 1}] : ($scoresType, tensor) -> $reducedType" + ) + + val sumBcast = context.nextTempValue() + context.emitOperation( + "$sumBcast = stablehlo.broadcast_in_dim $sumVal, dims = [0, 1, 2] " + + ": ($reducedType) -> $scoresType" + ) + + val weightsVal = context.nextTempValue() + context.emitOperation("$weightsVal = stablehlo.divide $expVal, $sumBcast : $scoresType") + context.setValueType(weightsVal, scoresType) + + // output = weights @ V: [B,H,S,T] @ [B,H,T,D] → [B,H,S,D] + val resultValue = context.nextTempValue() + context.emitOperation( + "$resultValue = stablehlo.dot_general $weightsVal, $value, " + + "batching_dims = [0, 1] x [0, 1], contracting_dims = [3] x [2] " + + ": ($scoresType, $valueType) -> $outputType" + ) + + return ConversionResult.Success( + outputValueName = resultValue, + emittedOperations = emptyList() + ) + } + } \ No newline at end of file diff --git a/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/SdpaHloExportTest.kt b/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/SdpaHloExportTest.kt new file mode 100644 index 00000000..86223750 --- /dev/null +++ b/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/SdpaHloExportTest.kt @@ -0,0 +1,60 @@ +package sk.ainet.compile.hlo + +import sk.ainet.lang.graph.DefaultGraphExecutionContext +import sk.ainet.lang.tensor.Shape +import sk.ainet.lang.tensor.Tensor +import sk.ainet.lang.tensor.ops.VoidTensorOps +import 
sk.ainet.lang.tape.toComputeGraph +import sk.ainet.lang.types.FP32 +import kotlin.test.Test +import kotlin.test.assertFalse +import kotlin.test.assertTrue + +class SdpaHloExportTest { + + @Test + fun sdpa_produces_dot_general_ops() { + val ctx = DefaultGraphExecutionContext.tape(baseOps = VoidTensorOps()) + + val q = ctx.fromFloatArray(Shape(1, 2, 4, 8), FP32::class, FloatArray(64)) + val k = ctx.fromFloatArray(Shape(1, 2, 4, 8), FP32::class, FloatArray(64)) + val v = ctx.fromFloatArray(Shape(1, 2, 4, 8), FP32::class, FloatArray(64)) + + @Suppress("UNCHECKED_CAST") + val inputIds = setOf( + ctx.session.refOf(q as Tensor<*, *>).id, + ctx.session.refOf(k as Tensor<*, *>).id, + ctx.session.refOf(v as Tensor<*, *>).id + ) + + val (tape, out) = ctx.record { + ctx.ops.scaledDotProductAttention(q, k, v) + } + + println("Output shape: ${out.shape}") + + val graph = tape!!.toComputeGraph( + synthesizeExternalInputs = true, + inputTensorIds = inputIds + ) + val nodes = graph.getTopologicalOrder() + println("Graph: ${nodes.size} nodes") + println("Ops: ${nodes.map { it.operation.name }}") + + val module = StableHloConverterFactory.createExtended().convert(graph, "sdpa_test") + println("MLIR:\n${module.content}") + + // Should contain dot_general for Q@K.T and weights@V + assertTrue(module.content.contains("stablehlo.dot_general"), "Should have dot_general ops") + + // Should NOT contain large zero constant tensors (from raw permutation) + // Scalar zeros (tensor) for softmax init are fine + assertFalse( + module.content.contains(Regex("dense<0\\.0> : tensor<\\d+x")), + "Should not have large zero constant tensors" + ) + + // Should contain exponential (softmax decomposition) + assertTrue(module.content.contains("stablehlo.exponential"), "Should have softmax (exponential)") + } +} diff --git a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/TensorOperations.kt 
b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/TensorOperations.kt index 5c39b8c5..90be21c3 100644 --- a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/TensorOperations.kt +++ b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/TensorOperations.kt @@ -902,3 +902,36 @@ public class UnsqueezeOperation( override fun clone(newParameters: Map<String, Any>): Operation = UnsqueezeOperation(newParameters) } + +/** + * Scaled dot-product attention operation for tape recording. + * Output shape = query shape: [batch, nHeads, seqLen, headDim] + */ +public class ScaledDotProductAttentionOperation( + parameters: Map<String, Any> = emptyMap() +) : BaseOperation("scaledDotProductAttention", "nn", parameters) { + + override fun execute( + inputs: List<Tensor<*, *>> + ): List<Tensor<*, *>> = emptyList() + + override fun validateInputs(inputs: List<TensorSpec>): ValidationResult { + return if (inputs.size >= 3) ValidationResult.Valid + else ValidationResult.Invalid(listOf("SDPA requires at least 3 inputs")) + } + + override fun inferOutputs(inputs: List<TensorSpec>): List<TensorSpec> { + require(inputs.size >= 3) { "SDPA requires at least 3 inputs (query, key, value)" } + return listOf( + TensorSpec( + name = "sdpa_output", + shape = inputs[0].shape, + dtype = inputs[0].dtype, + requiresGrad = inputs.any { it.requiresGrad } + ) + ) + } + + override fun clone(newParameters: Map<String, Any>): Operation = + ScaledDotProductAttentionOperation(newParameters) +}