From af98395f59414dec506f41da707300f14becdc27 Mon Sep 17 00:00:00 2001 From: Michal Harakal Date: Mon, 13 Apr 2026 13:00:31 +0200 Subject: [PATCH 1/2] Pin missing RMSNorm converter with failing test (#479) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds RmsNormConverterTest with three cases: 1. rmsNorm_operation_is_supported_by_neural_net_converter — asserts NeuralNetOperationsConverter registers rmsNorm plus the rms_norm and RMSNorm aliases. Red today. 2. rmsNorm_with_scale_lowers_to_real_ops — builds a 2×4 FP32 graph with an RMSNorm node and a per-channel scale operand, runs the converter, asserts the emitted module contains @reduce_mean, sqrt, divide, broadcast_in_dim, multiply, and is not labelled as "Unsupported operation rmsNorm". Red today because no converter claims the op. 3. rmsNorm_without_scale_still_normalizes — the scale operand is optional (RMSNorm can be used without the trailing affine multiply, though most LLMs do include it). The core norm must still lower to real ops. Tests use a minimal in-file fixture op stub rather than a real RMSNorm Operation subclass since the converter only reads `operation.name` and `operation.parameters`. Co-Authored-By: Claude Opus 4.6 (1M context) --- .../ainet/compile/hlo/RmsNormConverterTest.kt | 178 ++++++++++++++++++ 1 file changed, 178 insertions(+) create mode 100644 skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/RmsNormConverterTest.kt diff --git a/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/RmsNormConverterTest.kt b/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/RmsNormConverterTest.kt new file mode 100644 index 00000000..f09e4ede --- /dev/null +++ b/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/RmsNormConverterTest.kt @@ -0,0 +1,178 @@ +package sk.ainet.compile.hlo + +import sk.ainet.compile.hlo.converters.NeuralNetOperationsConverter +import sk.ainet.lang.graph.DefaultComputeGraph +import sk.ainet.lang.graph.GraphEdge +import sk.ainet.lang.graph.GraphNode +import sk.ainet.lang.tensor.Tensor +import sk.ainet.lang.tensor.ops.Operation +import sk.ainet.lang.tensor.ops.TensorSpec +import sk.ainet.lang.tensor.ops.ValidationResult +import sk.ainet.lang.types.DType +import kotlin.test.Test +import kotlin.test.assertFalse +import kotlin.test.assertTrue + +/** + * Covers the RMSNorm converter added for #479. Every Llama / Mistral / + * Qwen / Gemma family transformer normalizes activations with RMSNorm, + * not LayerNorm, so a missing converter here blocks every modern LLM + * export through the StableHLO pipeline. + * + * RMSNorm: + * rms = sqrt(mean(x^2, axis) + eps) + * out = scale * x / rms + * + * (No mean-centering, no offset — that's the distinction from LayerNorm.) + */ +class RmsNormConverterTest { + + @Test + fun rmsNorm_operation_is_supported_by_neural_net_converter() { + val registry = StableHloOperationRegistry() + registry.register(NeuralNetOperationsConverter()) + assertTrue( + registry.isSupported("rmsNorm"), + "NeuralNetOperationsConverter must register `rmsNorm`" + ) + assertTrue( + registry.isSupported("rms_norm"), + "NeuralNetOperationsConverter must register snake-case alias `rms_norm`" + ) + assertTrue( + registry.isSupported("RMSNorm"), + "NeuralNetOperationsConverter must register PascalCase alias `RMSNorm`" + ) + } + + @Test + fun rmsNorm_with_scale_lowers_to_real_ops() { + val graph = buildRmsNormGraph(withScale = true) + val converter = StableHloConverterFactory.createExtended() + val module = converter.convert(graph, "test_rms_norm") + println("[DEBUG_LOG] RMSNorm with scale:\n${module.content}") + + // Must not fall through to the registry's "unsupported" path. + assertFalse( + module.content.contains("Unsupported operation rmsNorm"), + "RMSNorm must be claimed by a converter, not dropped as unsupported" + ) + assertFalse( + module.content.contains("No converter found"), + "RMSNorm must be claimed by a converter, not left without a handler" + ) + + // Core elementwise decomposition of the norm. + assertTrue( + module.content.contains("stablehlo.multiply"), + "RMSNorm must emit at least one multiply (x*x and/or scale*x/rms)" + ) + assertTrue( + module.content.contains("@reduce_mean"), + "RMSNorm must lower mean(x^2) to a real reduction (custom_call style matches the rest of the emitter)" + ) + assertTrue( + module.content.contains("stablehlo.sqrt"), + "RMSNorm must take the sqrt of the mean-square-plus-eps term" + ) + assertTrue( + module.content.contains("stablehlo.divide"), + "RMSNorm must divide by the rms value" + ) + assertTrue( + module.content.contains("stablehlo.broadcast_in_dim"), + "RMSNorm must broadcast the reduced rms back to the input shape" + ) + + // Regression: the converter must NOT hardcode the normalization + // denominator or the mean-square term as a placeholder constant. + assertFalse( + module.content.contains("stablehlo.constant dense<1.0> : tensor<2x4xf32>"), + "RMSNorm must not emit a fake-rms constant at output shape" + ) + } + + @Test + fun rmsNorm_without_scale_still_normalizes() { + val graph = buildRmsNormGraph(withScale = false) + val converter = StableHloConverterFactory.createExtended() + val module = converter.convert(graph, "test_rms_norm_no_scale") + println("[DEBUG_LOG] RMSNorm without scale:\n${module.content}") + + assertFalse( + module.content.contains("Unsupported operation"), + "RMSNorm without a scale operand must still be claimed by the converter" + ) + // The core norm still happens — we just skip the final scale multiply. + assertTrue(module.content.contains("@reduce_mean")) + assertTrue(module.content.contains("stablehlo.sqrt")) + assertTrue(module.content.contains("stablehlo.divide")) + } + + private fun buildRmsNormGraph(withScale: Boolean): DefaultComputeGraph { + val graph = DefaultComputeGraph() + val shape = listOf(2, 4) + + val input = GraphNode( + id = "x", + operation = markerInputOp(), + inputs = emptyList(), + outputs = listOf(TensorSpec("x", shape, "FP32")) + ) + graph.addNode(input) + + val rmsNormInputs = mutableListOf(TensorSpec("x", shape, "FP32")) + val scaleNode: GraphNode? = if (withScale) { + val s = GraphNode( + id = "scale", + operation = markerInputOp(), + inputs = emptyList(), + outputs = listOf(TensorSpec("scale", listOf(4), "FP32")) + ) + graph.addNode(s) + rmsNormInputs.add(TensorSpec("scale", listOf(4), "FP32")) + s + } else null + + val rmsNorm = GraphNode( + id = "rms1", + operation = rmsNormOp(eps = 1e-6, axis = -1), + inputs = rmsNormInputs.toList(), + outputs = listOf(TensorSpec("y", shape, "FP32")) + ) + graph.addNode(rmsNorm) + + graph.addEdge(GraphEdge("e1", input, rmsNorm, 0, 0, input.outputs[0])) + if (scaleNode != null) { + graph.addEdge(GraphEdge("e2", scaleNode, rmsNorm, 0, 1, scaleNode.outputs[0])) + } + + return graph + } + + private fun markerInputOp(): Operation = object : Operation { + override val name: String = "input" + override val type: String = "input" + override val parameters: Map = emptyMap() + override fun execute(inputs: List>): List> = + throw UnsupportedOperationException("test fixture only") + override fun validateInputs(inputs: List): ValidationResult = ValidationResult.Valid + override fun inferOutputs(inputs: List): List = emptyList() + override fun clone(newParameters: Map): Operation = this + override fun serialize(): Map = mapOf("name" to name, "type" to type) + } + + private fun rmsNormOp(eps: Double, axis: Int): Operation = object : Operation { + override val name: String = "rmsNorm" + override val type: String = "normalization" + override val parameters: Map = mapOf("eps" to eps, "axis" to axis) + override fun execute(inputs: List>): List> = + throw UnsupportedOperationException("test fixture only") + override fun validateInputs(inputs: List): ValidationResult = ValidationResult.Valid + override fun inferOutputs(inputs: List): List = inputs.take(1) + override fun clone(newParameters: Map): Operation = this + override fun serialize(): Map = mapOf( + "name" to name, "type" to type, "parameters" to parameters + ) + } +} From 862a7719d5d65ee40da8baac64258ffaba5b9f3a Mon Sep 17 00:00:00 2001 From: Michal Harakal Date: Mon, 13 Apr 2026 13:01:34 +0200 Subject: [PATCH 2/2] Add RMSNorm lowering to NeuralNetOperationsConverter (#479) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Extends NeuralNetOperationsConverter with a convertRmsNorm method covering the rmsNorm / rms_norm / RMSNorm / RmsNorm operation names and registers them in supportedOperations. The lowering is the standard Llama-family form: rms = sqrt(mean(x^2, axis) + eps) out = scale * x / rms (scale operand optional) Emission style matches the softmax fix (#467) and the rest of the emitter: reductions go through `stablehlo.custom_call @reduce_mean`, the reduced tensor is broadcast back to the input shape via `stablehlo.broadcast_in_dim` for the final divide, and the epsilon is materialized as a scalar constant broadcast into the reduced shape. Migrating all reductions to real `stablehlo.reduce` regions is a separate refactor, explicitly out of scope. Axis normalization against rank handles negative axes and also accepts an `IntArray` `normalized_shape` parameter for callers that prefer PyTorch-style configuration. Default epsilon is 1e-6, matching Llama / Mistral / Qwen / Gemma; callers can override via `eps` or `epsilon`. Without a scale operand the final affine multiply is skipped and the normalized value is returned directly — a few implementations use RMSNorm without a learnable scale, and dropping the multiply keeps the emitted MLIR faithful to the input graph. Tests: 3/3 in RmsNormConverterTest green, full compile-hlo jvmTest suite still green. Co-Authored-By: Claude Opus 4.6 (1M context) --- .../NeuralNetOperationsConverter.kt | 137 +++++++++++++++++- 1 file changed, 133 insertions(+), 4 deletions(-) diff --git a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/NeuralNetOperationsConverter.kt b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/NeuralNetOperationsConverter.kt index fc6e0a27..afd09e06 100644 --- a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/NeuralNetOperationsConverter.kt +++ b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/NeuralNetOperationsConverter.kt @@ -27,12 +27,13 @@ public class NeuralNetOperationsConverter : StableHloOperationConverter { "maxPool2d", "avgPool2d", "averagePool2d", // Normalization operations "batchNorm", "batchNormalization", "BatchNormalization", - "layerNorm", "layerNormalization", "LayerNormalization" + "layerNorm", "layerNormalization", "LayerNormalization", + "rmsNorm", "rms_norm", "RMSNorm", "RmsNorm" ) - + override fun convert( - node: GraphNode, - operands: List, + node: GraphNode, + operands: List, context: ConversionContext ): ConversionResult { return when (node.operation.name.lowercase()) { @@ -42,6 +43,7 @@ public class NeuralNetOperationsConverter : StableHloOperationConverter { "avgpool2d", "averagepool2d" -> convertAvgPool2d(node, operands, context) "batchnorm", "batchnormalization" -> convertBatchNorm(node, operands, context) "layernorm", "layernormalization" -> convertLayerNorm(node, operands, context) + "rmsnorm", "rms_norm" -> convertRmsNorm(node, operands, context) else -> ConversionResult.Unsupported( node.operation.name, "Operation not supported by NeuralNetOperationsConverter" @@ -320,6 +322,133 @@ public class NeuralNetOperationsConverter : StableHloOperationConverter { ) } + /** + * Lower RMSNorm to real StableHLO elementwise ops. This is the + * normalization every Llama / Mistral / Qwen / Gemma family + * transformer uses — it drops the mean-centering and the additive + * offset of LayerNorm, leaving: + * + * rms = sqrt(mean(x^2, axis) + eps) + * out = scale * x / rms (scale operand is optional) + * + * The reductions are emitted as `stablehlo.custom_call @reduce_mean` + * to match the style already used by `ReductionOperationsConverter` + * and by the softmax lowering in `ActivationOperationsConverter`. + * Migrating every reduction to proper `stablehlo.reduce` regions is + * a separate, larger refactor. + */ + private fun convertRmsNorm( + node: GraphNode, + operands: List, + context: ConversionContext + ): ConversionResult { + if (operands.isEmpty()) { + return ConversionResult.Failure( + "RMSNorm operation requires at least 1 operand (input), got ${operands.size}", + "Unsupported rmsNorm arity for node ${node.id}" + ) + } + + val outputSpec = node.outputs.firstOrNull() + val outputType = outputSpec?.let { context.getTypeMapper().mapTensorType(it) } + ?: "tensor" + val elementType = outputSpec?.let { context.getTypeMapper().mapDType(it.dtype) } + ?: "f32" + + val inputShape = node.inputs.firstOrNull()?.shape ?: outputSpec?.shape ?: emptyList() + val rank = inputShape.size + + // Normalize the axis parameter against rank. Default to the + // last dimension, consistent with softmax and every LLM RMSNorm + // implementation in the wild. + val rawAxis = node.operation.parameters["axis"] as? Int + ?: (node.operation.parameters["normalized_shape"] as? IntArray)?.firstOrNull() + ?: -1 + val axis = when { + rank == 0 -> 0 + rawAxis < 0 -> rank + rawAxis + else -> rawAxis + }.coerceIn(0, (rank - 1).coerceAtLeast(0)) + + // Reduced tensor shape: input with `axis` removed. Matches the + // same reduced-type convention `convertSoftmax` uses. + val reducedShape = if (rank > 0) { + inputShape.filterIndexed { i, _ -> i != axis } + } else { + emptyList() + } + val reducedType = if (reducedShape.isEmpty()) { + "tensor<$elementType>" + } else { + "tensor<${reducedShape.joinToString("x")}x$elementType>" + } + + // Dimensions kept for broadcast_in_dim: every input dim except + // `axis`, mapped to its position in the reduced tensor. + val broadcastDims = (0 until rank).filter { it != axis }.joinToString(", ") + + val eps = (node.operation.parameters["eps"] as? Double) + ?: (node.operation.parameters["epsilon"] as? Double) + ?: 1e-6 // Llama family default; LayerNorm typically uses 1e-5 + + val xInput = operands[0] + val scaleOperand: String? = if (operands.size >= 2) operands[1] else null + + val xSquared = context.nextTempValue() + val meanSquared = context.nextTempValue() + val epsConst = context.nextTempValue() + val epsBroadcast = context.nextTempValue() + val meanPlusEps = context.nextTempValue() + val rms = context.nextTempValue() + val rmsBroadcast = context.nextTempValue() + val normalized = context.nextTempValue() + val resultValue = context.nextTempValue() + + val operations = mutableListOf() + + // x^2 + operations += "$xSquared = stablehlo.multiply $xInput, $xInput : $outputType" + + // reduce_mean(x^2, axis) + operations += "$meanSquared = stablehlo.custom_call @reduce_mean($xSquared) " + + "{dimensions = [$axis], keepdim = false} : $reducedType" + + // eps constant broadcast into the reduced shape + operations += "$epsConst = stablehlo.constant dense<$eps> : tensor<$elementType>" + operations += "$epsBroadcast = stablehlo.broadcast_in_dim $epsConst, " + + "dims = [] : (tensor<$elementType>) -> $reducedType" + + // mean + eps + operations += "$meanPlusEps = stablehlo.add $meanSquared, $epsBroadcast : $reducedType" + + // rms = sqrt(mean + eps) + operations += "$rms = stablehlo.sqrt $meanPlusEps : $reducedType" + + // Broadcast rms back to the input shape for the elementwise divide. + operations += "$rmsBroadcast = stablehlo.broadcast_in_dim $rms, " + + "dims = [$broadcastDims] : ($reducedType) -> $outputType" + + // x / rms + operations += "$normalized = stablehlo.divide $xInput, $rmsBroadcast : $outputType" + + // Final scale multiply is optional — when the caller did not + // pass a scale operand we return the normalized value directly. + val finalValue: String + if (scaleOperand != null) { + operations += "$resultValue = stablehlo.multiply $normalized, $scaleOperand : $outputType" + finalValue = resultValue + } else { + finalValue = normalized + } + + operations.forEach { context.emitOperation(it) } + + return ConversionResult.Success( + outputValueName = finalValue, + emittedOperations = operations + ) + } + // Helper functions for parameter extraction private fun extractStride(params: Map): Pair {