diff --git a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/NeuralNetOperationsConverter.kt b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/NeuralNetOperationsConverter.kt
index fc6e0a27..afd09e06 100644
--- a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/NeuralNetOperationsConverter.kt
+++ b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/NeuralNetOperationsConverter.kt
@@ -27,12 +27,13 @@ public class NeuralNetOperationsConverter : StableHloOperationConverter {
         "maxPool2d", "avgPool2d", "averagePool2d",
         // Normalization operations
         "batchNorm", "batchNormalization", "BatchNormalization",
-        "layerNorm", "layerNormalization", "LayerNormalization"
+        "layerNorm", "layerNormalization", "LayerNormalization",
+        "rmsNorm", "rms_norm", "RMSNorm", "RmsNorm"
     )
-    
+
     override fun convert(
-        node: GraphNode,
-        operands: List<String>,
+        node: GraphNode,
+        operands: List<String>,
         context: ConversionContext
     ): ConversionResult {
         return when (node.operation.name.lowercase()) {
@@ -42,6 +43,7 @@ public class NeuralNetOperationsConverter : StableHloOperationConverter {
             "avgpool2d", "averagepool2d" -> convertAvgPool2d(node, operands, context)
             "batchnorm", "batchnormalization" -> convertBatchNorm(node, operands, context)
             "layernorm", "layernormalization" -> convertLayerNorm(node, operands, context)
+            "rmsnorm", "rms_norm" -> convertRmsNorm(node, operands, context)
             else -> ConversionResult.Unsupported(
                 node.operation.name,
                 "Operation not supported by NeuralNetOperationsConverter"
@@ -320,6 +322,133 @@ public class NeuralNetOperationsConverter : StableHloOperationConverter {
         )
     }
 
+    /**
+     * Lower RMSNorm to real StableHLO elementwise ops. This is the
+     * normalization every Llama / Mistral / Qwen / Gemma family
+     * transformer uses — it drops the mean-centering and the additive
+     * offset of LayerNorm, leaving:
+     *
+     *   rms = sqrt(mean(x^2, axis) + eps)
+     *   out = scale * x / rms        (scale operand is optional)
+     *
+     * The reductions are emitted as `stablehlo.custom_call @reduce_mean`
+     * to match the style already used by `ReductionOperationsConverter`
+     * and by the softmax lowering in `ActivationOperationsConverter`.
+     * Migrating every reduction to proper `stablehlo.reduce` regions is
+     * a separate, larger refactor.
+     */
+    private fun convertRmsNorm(
+        node: GraphNode,
+        operands: List<String>,
+        context: ConversionContext
+    ): ConversionResult {
+        if (operands.isEmpty()) {
+            return ConversionResult.Failure(
+                "RMSNorm operation requires at least 1 operand (input), got ${operands.size}",
+                "Unsupported rmsNorm arity for node ${node.id}"
+            )
+        }
+
+        val outputSpec = node.outputs.firstOrNull()
+        val outputType = outputSpec?.let { context.getTypeMapper().mapTensorType(it) }
+            ?: "tensor<f32>"
+        val elementType = outputSpec?.let { context.getTypeMapper().mapDType(it.dtype) }
+            ?: "f32"
+
+        val inputShape = node.inputs.firstOrNull()?.shape ?: outputSpec?.shape ?: emptyList()
+        val rank = inputShape.size
+
+        // Normalize the axis parameter against rank. Default to the
+        // last dimension, consistent with softmax and every LLM RMSNorm
+        // implementation in the wild.
+        val rawAxis = node.operation.parameters["axis"] as? Int
+            ?: (node.operation.parameters["normalized_shape"] as? IntArray)?.firstOrNull()
+            ?: -1
+        val axis = when {
+            rank == 0 -> 0
+            rawAxis < 0 -> rank + rawAxis
+            else -> rawAxis
+        }.coerceIn(0, (rank - 1).coerceAtLeast(0))
+
+        // Reduced tensor shape: input with `axis` removed. Matches the
+        // same reduced-type convention `convertSoftmax` uses.
+        val reducedShape = if (rank > 0) {
+            inputShape.filterIndexed { i, _ -> i != axis }
+        } else {
+            emptyList()
+        }
+        val reducedType = if (reducedShape.isEmpty()) {
+            "tensor<$elementType>"
+        } else {
+            "tensor<${reducedShape.joinToString("x")}x$elementType>"
+        }
+
+        // Dimensions kept for broadcast_in_dim: every input dim except
+        // `axis`, mapped to its position in the reduced tensor.
+        val broadcastDims = (0 until rank).filter { it != axis }.joinToString(", ")
+
+        val eps = (node.operation.parameters["eps"] as? Double)
+            ?: (node.operation.parameters["epsilon"] as? Double)
+            ?: 1e-6 // Llama family default; LayerNorm typically uses 1e-5
+
+        val xInput = operands[0]
+        val scaleOperand: String? = if (operands.size >= 2) operands[1] else null
+
+        val xSquared = context.nextTempValue()
+        val meanSquared = context.nextTempValue()
+        val epsConst = context.nextTempValue()
+        val epsBroadcast = context.nextTempValue()
+        val meanPlusEps = context.nextTempValue()
+        val rms = context.nextTempValue()
+        val rmsBroadcast = context.nextTempValue()
+        val normalized = context.nextTempValue()
+        val resultValue = context.nextTempValue()
+
+        val operations = mutableListOf<String>()
+
+        // x^2
+        operations += "$xSquared = stablehlo.multiply $xInput, $xInput : $outputType"
+
+        // reduce_mean(x^2, axis)
+        operations += "$meanSquared = stablehlo.custom_call @reduce_mean($xSquared) " +
+            "{dimensions = [$axis], keepdim = false} : $reducedType"
+
+        // eps constant broadcast into the reduced shape
+        operations += "$epsConst = stablehlo.constant dense<$eps> : tensor<$elementType>"
+        operations += "$epsBroadcast = stablehlo.broadcast_in_dim $epsConst, " +
+            "dims = [] : (tensor<$elementType>) -> $reducedType"
+
+        // mean + eps
+        operations += "$meanPlusEps = stablehlo.add $meanSquared, $epsBroadcast : $reducedType"
+
+        // rms = sqrt(mean + eps)
+        operations += "$rms = stablehlo.sqrt $meanPlusEps : $reducedType"
+
+        // Broadcast rms back to the input shape for the elementwise divide.
+        operations += "$rmsBroadcast = stablehlo.broadcast_in_dim $rms, " +
+            "dims = [$broadcastDims] : ($reducedType) -> $outputType"
+
+        // x / rms
+        operations += "$normalized = stablehlo.divide $xInput, $rmsBroadcast : $outputType"
+
+        // Final scale multiply is optional — when the caller did not
+        // pass a scale operand we return the normalized value directly.
+        val finalValue: String
+        if (scaleOperand != null) {
+            operations += "$resultValue = stablehlo.multiply $normalized, $scaleOperand : $outputType"
+            finalValue = resultValue
+        } else {
+            finalValue = normalized
+        }
+
+        operations.forEach { context.emitOperation(it) }
+
+        return ConversionResult.Success(
+            outputValueName = finalValue,
+            emittedOperations = operations
+        )
+    }
+
     // Helper functions for parameter extraction
 
     private fun extractStride(params: Map<String, Any>): Pair<Int, Int> {
diff --git a/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/RmsNormConverterTest.kt b/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/RmsNormConverterTest.kt
new file mode 100644
index 00000000..f09e4ede
--- /dev/null
+++ b/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/RmsNormConverterTest.kt
@@ -0,0 +1,178 @@
+package sk.ainet.compile.hlo
+
+import sk.ainet.compile.hlo.converters.NeuralNetOperationsConverter
+import sk.ainet.lang.graph.DefaultComputeGraph
+import sk.ainet.lang.graph.GraphEdge
+import sk.ainet.lang.graph.GraphNode
+import sk.ainet.lang.tensor.Tensor
+import sk.ainet.lang.tensor.ops.Operation
+import sk.ainet.lang.tensor.ops.TensorSpec
+import sk.ainet.lang.tensor.ops.ValidationResult
+import sk.ainet.lang.types.DType
+import kotlin.test.Test
+import kotlin.test.assertFalse
+import kotlin.test.assertTrue
+
+/**
+ * Covers the RMSNorm converter added for #479. Every Llama / Mistral /
+ * Qwen / Gemma family transformer normalizes activations with RMSNorm,
+ * not LayerNorm, so a missing converter here blocks every modern LLM
+ * export through the StableHLO pipeline.
+ *
+ * RMSNorm:
+ *   rms = sqrt(mean(x^2, axis) + eps)
+ *   out = scale * x / rms
+ *
+ * (No mean-centering, no offset — that's the distinction from LayerNorm.)
+ */
+class RmsNormConverterTest {
+
+    @Test
+    fun rmsNorm_operation_is_supported_by_neural_net_converter() {
+        val registry = StableHloOperationRegistry()
+        registry.register(NeuralNetOperationsConverter())
+        assertTrue(
+            registry.isSupported("rmsNorm"),
+            "NeuralNetOperationsConverter must register `rmsNorm`"
+        )
+        assertTrue(
+            registry.isSupported("rms_norm"),
+            "NeuralNetOperationsConverter must register snake-case alias `rms_norm`"
+        )
+        assertTrue(
+            registry.isSupported("RMSNorm"),
+            "NeuralNetOperationsConverter must register PascalCase alias `RMSNorm`"
+        )
+    }
+
+    @Test
+    fun rmsNorm_with_scale_lowers_to_real_ops() {
+        val graph = buildRmsNormGraph(withScale = true)
+        val converter = StableHloConverterFactory.createExtended()
+        val module = converter.convert(graph, "test_rms_norm")
+        println("[DEBUG_LOG] RMSNorm with scale:\n${module.content}")
+
+        // Must not fall through to the registry's "unsupported" path.
+        assertFalse(
+            module.content.contains("Unsupported operation rmsNorm"),
+            "RMSNorm must be claimed by a converter, not dropped as unsupported"
+        )
+        assertFalse(
+            module.content.contains("No converter found"),
+            "RMSNorm must be claimed by a converter, not left without a handler"
+        )
+
+        // Core elementwise decomposition of the norm.
+        assertTrue(
+            module.content.contains("stablehlo.multiply"),
+            "RMSNorm must emit at least one multiply (x*x and/or scale*x/rms)"
+        )
+        assertTrue(
+            module.content.contains("@reduce_mean"),
+            "RMSNorm must lower mean(x^2) to a real reduction (custom_call style matches the rest of the emitter)"
+        )
+        assertTrue(
+            module.content.contains("stablehlo.sqrt"),
+            "RMSNorm must take the sqrt of the mean-square-plus-eps term"
+        )
+        assertTrue(
+            module.content.contains("stablehlo.divide"),
+            "RMSNorm must divide by the rms value"
+        )
+        assertTrue(
+            module.content.contains("stablehlo.broadcast_in_dim"),
+            "RMSNorm must broadcast the reduced rms back to the input shape"
+        )
+
+        // Regression: the converter must NOT hardcode the normalization
+        // denominator or the mean-square term as a placeholder constant.
+        assertFalse(
+            module.content.contains("stablehlo.constant dense<1.0> : tensor<2x4xf32>"),
+            "RMSNorm must not emit a fake-rms constant at output shape"
+        )
+    }
+
+    @Test
+    fun rmsNorm_without_scale_still_normalizes() {
+        val graph = buildRmsNormGraph(withScale = false)
+        val converter = StableHloConverterFactory.createExtended()
+        val module = converter.convert(graph, "test_rms_norm_no_scale")
+        println("[DEBUG_LOG] RMSNorm without scale:\n${module.content}")
+
+        assertFalse(
+            module.content.contains("Unsupported operation"),
+            "RMSNorm without a scale operand must still be claimed by the converter"
+        )
+        // The core norm still happens — we just skip the final scale multiply.
+        assertTrue(module.content.contains("@reduce_mean"))
+        assertTrue(module.content.contains("stablehlo.sqrt"))
+        assertTrue(module.content.contains("stablehlo.divide"))
+    }
+
+    private fun buildRmsNormGraph(withScale: Boolean): DefaultComputeGraph {
+        val graph = DefaultComputeGraph()
+        val shape = listOf(2, 4)
+
+        val input = GraphNode(
+            id = "x",
+            operation = markerInputOp(),
+            inputs = emptyList(),
+            outputs = listOf(TensorSpec("x", shape, "FP32"))
+        )
+        graph.addNode(input)
+
+        val rmsNormInputs = mutableListOf(TensorSpec("x", shape, "FP32"))
+        val scaleNode: GraphNode? = if (withScale) {
+            val s = GraphNode(
+                id = "scale",
+                operation = markerInputOp(),
+                inputs = emptyList(),
+                outputs = listOf(TensorSpec("scale", listOf(4), "FP32"))
+            )
+            graph.addNode(s)
+            rmsNormInputs.add(TensorSpec("scale", listOf(4), "FP32"))
+            s
+        } else null
+
+        val rmsNorm = GraphNode(
+            id = "rms1",
+            operation = rmsNormOp(eps = 1e-6, axis = -1),
+            inputs = rmsNormInputs.toList(),
+            outputs = listOf(TensorSpec("y", shape, "FP32"))
+        )
+        graph.addNode(rmsNorm)
+
+        graph.addEdge(GraphEdge("e1", input, rmsNorm, 0, 0, input.outputs[0]))
+        if (scaleNode != null) {
+            graph.addEdge(GraphEdge("e2", scaleNode, rmsNorm, 0, 1, scaleNode.outputs[0]))
+        }
+
+        return graph
+    }
+
+    private fun markerInputOp(): Operation = object : Operation {
+        override val name: String = "input"
+        override val type: String = "input"
+        override val parameters: Map<String, Any> = emptyMap()
+        override fun execute(inputs: List<Tensor<*>>): List<Tensor<*>> =
+            throw UnsupportedOperationException("test fixture only")
+        override fun validateInputs(inputs: List<TensorSpec>): ValidationResult = ValidationResult.Valid
+        override fun inferOutputs(inputs: List<TensorSpec>): List<TensorSpec> = emptyList()
+        override fun clone(newParameters: Map<String, Any>): Operation = this
+        override fun serialize(): Map<String, Any> = mapOf("name" to name, "type" to type)
+    }
+
+    private fun rmsNormOp(eps: Double, axis: Int): Operation = object : Operation {
+        override val name: String = "rmsNorm"
+        override val type: String = "normalization"
+        override val parameters: Map<String, Any> = mapOf("eps" to eps, "axis" to axis)
+        override fun execute(inputs: List<Tensor<*>>): List<Tensor<*>> =
+            throw UnsupportedOperationException("test fixture only")
+        override fun validateInputs(inputs: List<TensorSpec>): ValidationResult = ValidationResult.Valid
+        override fun inferOutputs(inputs: List<TensorSpec>): List<TensorSpec> = inputs.take(1)
+        override fun clone(newParameters: Map<String, Any>): Operation = this
+        override fun serialize(): Map<String, Any> = mapOf(
+            "name" to name, "type" to type, "parameters" to parameters
+        )
+    }
+}