Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 110 additions & 0 deletions docs/whisper-iree-issues/ISSUE-SDPA-recording-and-hlo.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# scaledDotProductAttention: not recorded by tape, no StableHLO converter

## Summary

`ctx.ops.scaledDotProductAttention()` exists in TensorOps interface,
VoidTensorOps, and DefaultCpuOps — but it is not tape-recorded and has
no StableHLO converter. This blocks multi-head attention in the
SKaiNET → IREE compilation path.

## Impact

Without SDPA, Whisper's multi-head attention must be decomposed into
individual ops (reshape, transpose, matmul, softmax, matmul). The
per-batch K transpose requires raw FloatArray manipulation which
creates zero constants in the VMFB (proven in TapeAttentionPermuteBugTest).

Result: GPU Whisper encoder produces wrong hidden states → decoder
outputs "," instead of real transcription.

## Three fixes needed

### 1. RecordingExecution: record SDPA

**File:** `skainet-compile-core/.../tape/RecordingExecution.kt` line 436

Current (just delegates, no recording):
```kotlin
override fun <T : DType, V> scaledDotProductAttention(...) =
base.scaledDotProductAttention(query, key, value, mask, scale, causal)
```

Fix (same pattern as conv1d in PR #532):
```kotlin
override fun <T : DType, V> scaledDotProductAttention(
query, key, value, mask, scale, causal
): Tensor<T, V> {
val out = base.scaledDotProductAttention(query, key, value, mask, scale, causal)
val params = mapOf("scale" to scale, "causal" to causal)
record(ScaledDotProductAttentionOperation(params),
listOfNotNull(query, key, value, mask), listOf(out))
return out
}
```

### 2. TensorOperations: add ScaledDotProductAttentionOperation

**File:** `skainet-lang-core/.../tensor/ops/TensorOperations.kt`

```kotlin
class ScaledDotProductAttentionOperation<T : DType, V>(
parameters: Map<String, Any> = emptyMap()
) : BaseOperation<T, V>("scaledDotProductAttention", "nn", parameters) {
override fun inferOutputs(inputs: List<TensorSpec>): List<TensorSpec> {
// Output shape = query shape: [batch, nHeads, seqLen, headDim]
return listOf(TensorSpec("sdpa_output", inputs[0].shape, inputs[0].dtype))
}
}
```

### 3. StableHLO converter: decompose SDPA

**File:** `skainet-compile-hlo/.../converters/NeuralNetOperationsConverter.kt`

Register "scaledDotProductAttention" and decompose into:
```mlir
// scores = Q @ K.T (batched matmul with K transposed)
%scores = stablehlo.dot_general %query, %key,
batching_dims = [0, 1] x [0, 1],
contracting_dims = [3] x [3]
: (tensor<BxHxSxDxf32>, tensor<BxHxTxDxf32>) -> tensor<BxHxSxTxf32>

// scale
%scaled = stablehlo.multiply %scores, %scale_splat

// optional mask (additive)
%masked = stablehlo.add %scaled, %mask // if mask != null

// softmax over last dim (StableHLO has no softmax op — decompose into
// reduce-max, subtract, exponential, reduce-sum, divide)
%weights = <softmax decomposition of %masked>

// output = weights @ V (batched matmul)
%output = stablehlo.dot_general %weights, %value,
batching_dims = [0, 1] x [0, 1],
contracting_dims = [3] x [2]
```

Note: `contracting_dims = [3] x [3]` for Q@K.T because we contract
headDim of Q (last dim) with headDim of K (also last dim). This is
different from standard matmul where you contract last of A with
second-to-last of B — here K is NOT pre-transposed.

## Test

```kotlin
val ctx = DefaultGraphExecutionContext.tape(baseOps = VoidTensorOps())
val q = ctx.fromFloatArray(Shape(1, 6, 4, 64), ...) // [batch, heads, seq, headDim]
val k = ctx.fromFloatArray(Shape(1, 6, 4, 64), ...)
val v = ctx.fromFloatArray(Shape(1, 6, 4, 64), ...)

val (tape, out) = ctx.record {
ctx.ops.scaledDotProductAttention(q, k, v)
}

val graph = tape!!.toComputeGraph(synthesizeExternalInputs = true)
val module = StableHloConverterFactory.createExtended().convert(graph, "test_sdpa")

// Should contain dot_general for Q@K.T and weights@V
assertTrue(module.content.contains("stablehlo.dot_general"))
assertFalse(module.content.contains("dense<0.0>")) // no zero constants
```
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,20 @@ internal class RecordingTensorOpsDecorator(private val base: TensorOps) : Tensor
override fun <T : DType, V> indexSelect(input: Tensor<T, V>, indices: Tensor<DType, *>, dim: Int): Tensor<T, V> = base.indexSelect(input, indices, dim)
override fun <T : DType, V> exp(tensor: Tensor<T, V>): Tensor<T, V> = base.exp(tensor)
override fun <T : DType, V> expm1(tensor: Tensor<T, V>): Tensor<T, V> = base.expm1(tensor)
override fun <T : DType, V> scaledDotProductAttention(query: Tensor<T, V>, key: Tensor<T, V>, value: Tensor<T, V>, mask: Tensor<T, V>?, scale: Float, causal: Boolean): Tensor<T, V> = base.scaledDotProductAttention(query, key, value, mask, scale, causal)
override fun <T : DType, V> scaledDotProductAttention(
    query: Tensor<T, V>, key: Tensor<T, V>, value: Tensor<T, V>,
    mask: Tensor<T, V>?, scale: Float, causal: Boolean
): Tensor<T, V> {
    // Execute on the wrapped ops first, then record the call on the tape so
    // the graph exporter sees SDPA as a single node (same pattern as conv1d).
    val out = base.scaledDotProductAttention(query, key, value, mask, scale, causal)
    // Attributes the StableHLO converter needs to reconstruct the op.
    // Plain mapOf: the map is never mutated after construction.
    val params = mapOf<String, Any>(
        "scale" to scale,
        "causal" to causal,
    )
    // mask is optional; listOfNotNull drops it when absent so the recorded
    // input arity matches what the op actually consumed.
    @Suppress("UNCHECKED_CAST")
    val inputs = listOfNotNull(query, key, value, mask) as List<Tensor<T, V>>
    record(ScaledDotProductAttentionOperation(params), inputs, listOf(out))
    return out
}
}

private class ConcatRecordingOperation(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ public class NeuralNetOperationsConverter : StableHloOperationConverter {
// Normalization operations
"batchNorm", "batchNormalization", "BatchNormalization",
"layerNorm", "layerNormalization", "LayerNormalization",
"rmsNorm", "rms_norm", "RMSNorm", "RmsNorm"
"rmsNorm", "rms_norm", "RMSNorm", "RmsNorm",
// Attention
"scaledDotProductAttention"
)

override fun convert(
Expand All @@ -44,6 +46,7 @@ public class NeuralNetOperationsConverter : StableHloOperationConverter {
"batchnorm", "batchnormalization" -> convertBatchNorm(node, operands, context)
"layernorm", "layernormalization" -> convertLayerNorm(node, operands, context)
"rmsnorm", "rms_norm" -> convertRmsNorm(node, operands, context)
"scaleddotproductattention" -> convertSdpa(node, operands, context)
else -> ConversionResult.Unsupported(
node.operation.name,
"Operation not supported by NeuralNetOperationsConverter"
Expand Down Expand Up @@ -770,5 +773,149 @@ public class NeuralNetOperationsConverter : StableHloOperationConverter {
"epsilon = $epsilon, feature_index = $featureIndex : $outputType"
}
}


/**
* Convert scaledDotProductAttention to StableHLO.
* Decomposes into: Q @ K.T (batched) → scale → optional mask → softmax → @ V (batched)
*
* Input shapes: Q[B,H,S,D], K[B,H,T,D], V[B,H,T,D], optional mask[B,H,S,T] or broadcastable
* Output: [B,H,S,D]
*/
/**
 * Convert scaledDotProductAttention to StableHLO.
 * Decomposes into: Q @ K.T (batched) → scale → optional additive mask →
 * softmax (reduce-max / exp / reduce-sum / divide) → @ V (batched).
 *
 * Input shapes: Q[B,H,S,D], K[B,H,T,D], V[B,H,T,D], optional mask broadcastable to [B,H,S,T].
 * Output: [B,H,S,D].
 *
 * Note: `contracting_dims = [3] x [3]` for Q@K.T because K is NOT
 * pre-transposed — we contract the headDim (last dim) of both operands.
 *
 * Causal masking is NOT implemented by this decomposition; a request with
 * causal=true is rejected rather than silently producing non-causal results.
 */
private fun convertSdpa(
    node: GraphNode,
    operands: List<String>,
    context: ConversionContext
): ConversionResult {
    if (operands.size < 3) {
        return ConversionResult.Failure("SDPA requires at least 3 operands (Q, K, V), got ${operands.size}")
    }

    val query = operands[0] // [B, H, S, D]
    val key = operands[1]   // [B, H, T, D]
    val value = operands[2] // [B, H, T, D]
    val mask = operands.getOrNull(3)

    // Fail loudly on causal=true: emitting non-causal attention here would
    // yield numerically wrong (hard to debug) results downstream.
    val causal = node.operation.parameters["causal"] as? Boolean ?: false
    if (causal) {
        return ConversionResult.Failure("SDPA with causal=true is not supported: causal mask emission is not implemented")
    }

    val querySpec = node.inputs.getOrNull(0)
    val keySpec = node.inputs.getOrNull(1)
    val valueSpec = node.inputs.getOrNull(2)

    val outputSpec = node.outputs.firstOrNull()
    val outputType = outputSpec?.let { context.getTypeMapper().mapTensorType(it) }
        ?: "tensor<?xf32>"

    // Static shapes are required to spell out the intermediate tensor types.
    val qShape = querySpec?.shape ?: return ConversionResult.Failure("Unknown Q shape")
    val kShape = keySpec?.shape ?: return ConversionResult.Failure("Unknown K shape")
    val vShape = valueSpec?.shape ?: return ConversionResult.Failure("Unknown V shape")

    val rank = qShape.size
    if (rank != 4) {
        return ConversionResult.Failure("SDPA expects 4D tensors [B,H,S,D], got rank $rank")
    }

    val batch = qShape[0]
    val heads = qShape[1]
    val seqQ = qShape[2]
    val headDim = qShape[3]
    val seqK = kShape[2]

    val queryType = context.getValueType(query) ?: "tensor<${qShape.joinToString("x")}xf32>"
    val keyType = context.getValueType(key) ?: "tensor<${kShape.joinToString("x")}xf32>"
    val valueType = context.getValueType(value) ?: "tensor<${vShape.joinToString("x")}xf32>"

    // scores = Q @ K.T: [B,H,S,D] @ [B,H,T,D] → [B,H,S,T]
    // dot_general with batching_dims=[0,1], contracting_dims=[3]x[3]
    val scoresType = "tensor<${batch}x${heads}x${seqQ}x${seqK}xf32>"
    val scoresVal = context.nextTempValue()
    context.emitOperation(
        "$scoresVal = stablehlo.dot_general $query, $key, " +
            "batching_dims = [0, 1] x [0, 1], contracting_dims = [3] x [3] " +
            ": ($queryType, $keyType) -> $scoresType"
    )
    context.setValueType(scoresVal, scoresType)

    // Scale factor: recorded parameter, defaulting to 1/sqrt(headDim).
    // Read via Number so a Double-boxed value is honored rather than
    // silently replaced by the default.
    val scale = (node.operation.parameters["scale"] as? Number)?.toFloat()
        ?: (1.0f / kotlin.math.sqrt(headDim.toFloat()))
    val scaleConst = context.nextTempValue()
    context.emitOperation("$scaleConst = stablehlo.constant dense<$scale> : tensor<f32>")
    // Splat the scalar across the scores shape before multiplying.
    val scaleSplat = context.nextTempValue()
    context.emitOperation(
        "$scaleSplat = stablehlo.broadcast_in_dim $scaleConst, dims = [] " +
            ": (tensor<f32>) -> $scoresType"
    )
    val scaledScores = context.nextTempValue()
    context.emitOperation(
        "$scaledScores = stablehlo.multiply $scoresVal, $scaleSplat : $scoresType"
    )
    context.setValueType(scaledScores, scoresType)

    // Optional additive mask (assumed broadcast-compatible with scores).
    var presoft = scaledScores
    if (mask != null) {
        val maskedVal = context.nextTempValue()
        context.emitOperation(
            "$maskedVal = stablehlo.add $presoft, $mask : $scoresType"
        )
        context.setValueType(maskedVal, scoresType)
        presoft = maskedVal
    }

    // Softmax over last dim (seqK), numerically stabilized:
    // exp(x - max(x)) / sum(exp(x - max(x)))
    val maxVal = context.nextTempValue()
    val maxInitVal = context.nextTempValue()
    context.emitOperation("$maxInitVal = stablehlo.constant dense<0xFF800000> : tensor<f32>") // -inf
    context.emitOperation(
        "$maxVal = stablehlo.reduce($presoft init: $maxInitVal) applies stablehlo.maximum " +
            "across dimensions = [${rank - 1}] : ($scoresType, tensor<f32>) -> " +
            "tensor<${batch}x${heads}x${seqQ}xf32>"
    )

    val maxBcast = context.nextTempValue()
    val reducedType = "tensor<${batch}x${heads}x${seqQ}xf32>"
    context.emitOperation(
        "$maxBcast = stablehlo.broadcast_in_dim $maxVal, dims = [0, 1, 2] " +
            ": ($reducedType) -> $scoresType"
    )

    val shifted = context.nextTempValue()
    context.emitOperation("$shifted = stablehlo.subtract $presoft, $maxBcast : $scoresType")

    val expVal = context.nextTempValue()
    context.emitOperation("$expVal = stablehlo.exponential $shifted : $scoresType")

    val sumInit = context.nextTempValue()
    context.emitOperation("$sumInit = stablehlo.constant dense<0.0> : tensor<f32>")
    val sumVal = context.nextTempValue()
    context.emitOperation(
        "$sumVal = stablehlo.reduce($expVal init: $sumInit) applies stablehlo.add " +
            "across dimensions = [${rank - 1}] : ($scoresType, tensor<f32>) -> $reducedType"
    )

    val sumBcast = context.nextTempValue()
    context.emitOperation(
        "$sumBcast = stablehlo.broadcast_in_dim $sumVal, dims = [0, 1, 2] " +
            ": ($reducedType) -> $scoresType"
    )

    val weightsVal = context.nextTempValue()
    context.emitOperation("$weightsVal = stablehlo.divide $expVal, $sumBcast : $scoresType")
    context.setValueType(weightsVal, scoresType)

    // output = weights @ V: [B,H,S,T] @ [B,H,T,D] → [B,H,S,D]
    // Here the standard contraction applies: last dim of weights (T) with
    // the second-to-last dim of V (T).
    val resultValue = context.nextTempValue()
    context.emitOperation(
        "$resultValue = stablehlo.dot_general $weightsVal, $value, " +
            "batching_dims = [0, 1] x [0, 1], contracting_dims = [3] x [2] " +
            ": ($scoresType, $valueType) -> $outputType"
    )

    // Operations were emitted directly through the context, so nothing
    // extra to attach here.
    return ConversionResult.Success(
        outputValueName = resultValue,
        emittedOperations = emptyList()
    )
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package sk.ainet.compile.hlo

import sk.ainet.lang.graph.DefaultGraphExecutionContext
import sk.ainet.lang.tensor.Shape
import sk.ainet.lang.tensor.Tensor
import sk.ainet.lang.tensor.ops.VoidTensorOps
import sk.ainet.lang.tape.toComputeGraph
import sk.ainet.lang.types.FP32
import kotlin.test.Test
import kotlin.test.assertFalse
import kotlin.test.assertTrue

class SdpaHloExportTest {

    @Test
    fun sdpa_produces_dot_general_ops() {
        // Record SDPA on a tape backed by no-op tensor math: only the
        // recorded graph matters here, not numeric results.
        val ctx = DefaultGraphExecutionContext.tape(baseOps = VoidTensorOps())

        // [batch=1, heads=2, seq=4, headDim=8] → 64 elements each
        val query = ctx.fromFloatArray<FP32, Float>(Shape(1, 2, 4, 8), FP32::class, FloatArray(64))
        val keyTensor = ctx.fromFloatArray<FP32, Float>(Shape(1, 2, 4, 8), FP32::class, FloatArray(64))
        val valueTensor = ctx.fromFloatArray<FP32, Float>(Shape(1, 2, 4, 8), FP32::class, FloatArray(64))

        // Mark Q/K/V as external graph inputs via their session ref ids.
        @Suppress("UNCHECKED_CAST")
        val externalIds = listOf(query, keyTensor, valueTensor)
            .map { ctx.session.refOf(it as Tensor<*, *>).id }
            .toSet()

        val (recordedTape, result) = ctx.record {
            ctx.ops.scaledDotProductAttention(query, keyTensor, valueTensor)
        }

        println("Output shape: ${result.shape}")

        val graph = recordedTape!!.toComputeGraph(
            synthesizeExternalInputs = true,
            inputTensorIds = externalIds
        )
        val orderedNodes = graph.getTopologicalOrder()
        println("Graph: ${orderedNodes.size} nodes")
        println("Ops: ${orderedNodes.map { it.operation.name }}")

        val module = StableHloConverterFactory.createExtended().convert(graph, "sdpa_test")
        println("MLIR:\n${module.content}")

        // Should contain dot_general for Q@K.T and weights@V
        assertTrue(module.content.contains("stablehlo.dot_general"), "Should have dot_general ops")

        // Should NOT contain large zero constant tensors (from raw permutation)
        // Scalar zeros (tensor<f32>) for softmax init are fine
        assertFalse(
            module.content.contains(Regex("dense<0\\.0> : tensor<\\d+x")),
            "Should not have large zero constant tensors"
        )

        // Should contain exponential (softmax decomposition)
        assertTrue(module.content.contains("stablehlo.exponential"), "Should have softmax (exponential)")
    }
}
Loading
Loading