From 009a48eaf535186ad535ee7dc9ed843e96a056a3 Mon Sep 17 00:00:00 2001 From: Michal Harakal Date: Mon, 13 Apr 2026 09:37:45 +0200 Subject: [PATCH 1/2] Pin broken softmax StableHLO lowering with failing test (#467) Extends testSoftmaxOperation to assert that the converter must not emit hardcoded dense<0.0> / dense<1.0> placeholder constants at the output shape in place of the max(x) and sum(exp(...)) terms, and must invoke real reductions plus a broadcast_in_dim back to the input shape. Red against current ActivationOperationsConverter. Co-Authored-By: Claude Opus 4.6 (1M context) --- .../hlo/ActivationOperationsConverterTest.kt | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/ActivationOperationsConverterTest.kt b/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/ActivationOperationsConverterTest.kt index 06cf3a83..7ecaa25f 100644 --- a/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/ActivationOperationsConverterTest.kt +++ b/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/ActivationOperationsConverterTest.kt @@ -2,6 +2,7 @@ package sk.ainet.compile.hlo import kotlin.test.Test import kotlin.test.assertTrue +import kotlin.test.assertFalse import kotlin.test.assertNotNull import kotlin.test.assertEquals import sk.ainet.lang.graph.DefaultComputeGraph @@ -65,6 +66,34 @@ class ActivationOperationsConverterTest { assertTrue(module.content.contains("stablehlo.exponential")) assertTrue(module.content.contains("stablehlo.divide")) assertTrue(module.content.contains("tensor<2x3xf32>")) + + // Regression: softmax must NOT hardcode its max/sum terms as constants + // (issue #467 — the fake-max `dense<0.0>` and fake-sum `dense<1.0>` at + // full output shape produced numerically wrong MLIR). + assertFalse( + module.content.contains("stablehlo.constant dense<0.0> : tensor<2x3xf32>"), + "softmax must not emit a fake-max constant at output shape" + ) + assertFalse( + module.content.contains("stablehlo.constant dense<1.0> : tensor<2x3xf32>"), + "softmax must not emit a fake-sum constant at output shape" + ) + + // The corrected lowering invokes real reductions (via custom_call for + // now — matching ReductionOperationsConverter style) plus a broadcast + // back to the input shape before the subtract/divide. + assertTrue( + module.content.contains("@reduce_max"), + "softmax must lower max(x) to a real reduction" + ) + assertTrue( + module.content.contains("@reduce_sum"), + "softmax must lower sum(exp(...)) to a real reduction" + ) + assertTrue( + module.content.contains("stablehlo.broadcast_in_dim"), + "softmax must broadcast reduced values back to the input shape" + ) } @Test From 06808f1fa5a345dc3a80913709749e1021f7fd4f Mon Sep 17 00:00:00 2001 From: Michal Harakal Date: Mon, 13 Apr 2026 09:40:38 +0200 Subject: [PATCH 2/2] Lower softmax to real reductions + broadcast_in_dim (#467) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replaces the dense<0.0>/dense<1.0> placeholder constants with custom_call @reduce_max and @reduce_sum (matching the codebase's existing reduction-converter style) and broadcasts the reduced values back to the input shape via stablehlo.broadcast_in_dim before subtract / divide. Handles negative axis correctly. Branch is parked pending P0 roadmap work (quant-in-IR + backend-api extraction) — see issue #467 for context. Co-Authored-By: Claude Opus 4.6 (1M context) --- .../ActivationOperationsConverter.kt | 94 +++++++++++++------ 1 file changed, 64 insertions(+), 30 deletions(-) diff --git a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/ActivationOperationsConverter.kt b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/ActivationOperationsConverter.kt index a0eaaf81..f04d1e91 100644 --- a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/ActivationOperationsConverter.kt +++ b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/ActivationOperationsConverter.kt @@ -86,8 +86,14 @@ public class ActivationOperationsConverter : StableHloOperationConverter { } /** - * Convert softmax activation using stablehlo.reduce and stablehlo.broadcast_in_dim. + * Convert softmax activation using real reductions and broadcast_in_dim. * softmax(x) = exp(x - max(x)) / sum(exp(x - max(x))) + * + * The max and sum terms are lowered to stablehlo.custom_call @reduce_max + * and @reduce_sum — matching ReductionOperationsConverter's style — then + * broadcast back to the input shape before subtract / divide. This replaces + * an earlier lowering that used `dense<0.0>` / `dense<1.0>` placeholder + * constants (see #467) and produced numerically wrong MLIR. */ private fun convertSoftmax( node: GraphNode, @@ -100,49 +106,77 @@ public class ActivationOperationsConverter : StableHloOperationConverter { "Unsupported softmax arity for node ${node.id}" ) } - + val outputSpec = node.outputs.firstOrNull() - val outputType = outputSpec?.let { context.getTypeMapper().mapTensorType(it) } + val outputType = outputSpec?.let { context.getTypeMapper().mapTensorType(it) } ?: "tensor" - val elementType = outputSpec?.let { context.getTypeMapper().mapDType(it.dtype) } + val elementType = outputSpec?.let { context.getTypeMapper().mapDType(it.dtype) } ?: "f32" - - // Get axis parameter (default to last dimension) - val axis = node.operation.parameters["axis"] as? Int ?: -1 - val actualAxis = if (axis < 0) { - // For simplicity, assume last dimension (1 for 2D tensor) - 1 + + val inputShape = node.inputs.firstOrNull()?.shape ?: outputSpec?.shape ?: emptyList() + val rank = inputShape.size + + // Normalize axis against rank. Default to the last dimension. + val rawAxis = node.operation.parameters["axis"] as? Int ?: -1 + val axis = when { + rank == 0 -> 0 + rawAxis < 0 -> rank + rawAxis + else -> rawAxis + }.coerceIn(0, (rank - 1).coerceAtLeast(0)) + + // Reduced tensor type: input shape with `axis` dimension removed. + val reducedShape = if (rank > 0) { + inputShape.filterIndexed { i, _ -> i != axis } } else { - axis + emptyList() } - - // For a simplified softmax implementation, we'll use element-wise operations - // This is a basic implementation that works along the last dimension + val reducedType = if (reducedShape.isEmpty()) { + "tensor<$elementType>" + } else { + "tensor<${reducedShape.joinToString("x")}x$elementType>" + } + + // Dimensions kept for broadcast_in_dim: every input dim except `axis`, + // mapped to its position in the reduced tensor. + val broadcastDims = (0 until rank).filter { it != axis }.joinToString(", ") + val maxValue = context.nextTempValue() + val maxBroadcast = context.nextTempValue() val shiftedValue = context.nextTempValue() val expValue = context.nextTempValue() val sumValue = context.nextTempValue() + val sumBroadcast = context.nextTempValue() val resultValue = context.nextTempValue() - + val operations = listOf( - // Find maximum (simplified - using a constant for now) - "$maxValue = stablehlo.constant dense<0.0> : $outputType", - - // Subtract max for numerical stability (simplified) - "$shiftedValue = stablehlo.subtract ${operands[0]}, $maxValue : $outputType", - - // Apply exponential + // Reduce-max along the softmax axis (for numerical stability). + "$maxValue = stablehlo.custom_call @reduce_max(${operands[0]}) " + + "{dimensions = [$axis], keepdim = false} : $reducedType", + + // Broadcast reduced max back to the input shape. + "$maxBroadcast = stablehlo.broadcast_in_dim $maxValue, " + + "dims = [$broadcastDims] : ($reducedType) -> $outputType", + + // Subtract the max for numerical stability. + "$shiftedValue = stablehlo.subtract ${operands[0]}, $maxBroadcast : $outputType", + + // Elementwise exponential. "$expValue = stablehlo.exponential $shiftedValue : $outputType", - - // Sum (simplified - using a constant sum for now) - "$sumValue = stablehlo.constant dense<1.0> : $outputType", - - // Divide to get final softmax - "$resultValue = stablehlo.divide $expValue, $sumValue : $outputType" + + // Reduce-sum along the softmax axis. + "$sumValue = stablehlo.custom_call @reduce_sum($expValue) " + + "{dimensions = [$axis], keepdim = false} : $reducedType", + + // Broadcast the sum back to the input shape. + "$sumBroadcast = stablehlo.broadcast_in_dim $sumValue, " + + "dims = [$broadcastDims] : ($reducedType) -> $outputType", + + // Normalize. + "$resultValue = stablehlo.divide $expValue, $sumBroadcast : $outputType" ) - + operations.forEach { context.emitOperation(it) } - + return ConversionResult.Success( outputValueName = resultValue, emittedOperations = operations