/**
 * Lowers LayerNorm to real StableHLO elementwise ops using the standard
 * decomposition:
 *
 *     out = scale * (x - mean) / sqrt(var + eps) + offset
 *
 * Replaces an earlier lowering that emitted a placeholder
 * `stablehlo.custom_call @layer_norm` (which no MLIR tool in the repo
 * understands). Emission style matches the softmax fix (#467): the
 * reductions go through `stablehlo.custom_call @reduce_mean` /
 * `@reduce_variance` (both handled by `ReductionOperationsConverter`),
 * reduced tensors are broadcast back to the input shape via
 * `stablehlo.broadcast_in_dim`, and scale / offset are applied only
 * when those operands are actually present. Migrating every reduction
 * to real `stablehlo.reduce` regions is a separate, larger refactor.
 *
 * @param node     the graph node carrying `eps` / `axis` parameters and
 *                 input/output tensor specs.
 * @param operands SSA value names: [input, scale?, offset?].
 * @param context  conversion context used for type mapping, temp-value
 *                 naming and operation emission.
 * @return Success with the final SSA value name, or Failure when no
 *         input operand is supplied.
 */
private fun convertLayerNorm(
    node: GraphNode,
    operands: List<String>,
    context: ConversionContext
): ConversionResult {
    if (operands.isEmpty()) {
        return ConversionResult.Failure(
            "LayerNorm operation requires at least 1 operand (input), got ${operands.size}",
            "Unsupported layerNorm arity for node ${node.id}"
        )
    }

    val outputSpec = node.outputs.firstOrNull()
    val outputType = outputSpec?.let { context.getTypeMapper().mapTensorType(it) }
        ?: "tensor"
    val elementType = outputSpec?.let { context.getTypeMapper().mapDType(it.dtype) }
        ?: "f32"

    val inputShape = node.inputs.firstOrNull()?.shape ?: outputSpec?.shape ?: emptyList()
    val rank = inputShape.size

    // Normalize the axis parameter against rank. Default to the last
    // dimension, matching every standard LayerNorm in the wild. Callers
    // may supply either an `axis` integer or an `IntArray
    // normalized_shape`; in the latter case we reduce along the leading
    // element (simple single-axis support).
    val rawAxis = node.operation.parameters["axis"] as? Int
        ?: (node.operation.parameters["normalized_shape"] as? IntArray)?.firstOrNull()
        ?: -1
    val axis = when {
        rank == 0 -> 0
        rawAxis < 0 -> rank + rawAxis
        else -> rawAxis
    }.coerceIn(0, (rank - 1).coerceAtLeast(0))

    val reducedShape = if (rank > 0) {
        inputShape.filterIndexed { i, _ -> i != axis }
    } else {
        emptyList()
    }
    val reducedType = if (reducedShape.isEmpty()) {
        "tensor<$elementType>"
    } else {
        "tensor<${reducedShape.joinToString("x")}x$elementType>"
    }
    val broadcastDims = (0 until rank).filter { it != axis }.joinToString(", ")

    val epsilon = (node.operation.parameters["eps"] as? Double)
        ?: (node.operation.parameters["epsilon"] as? Double)
        ?: 1e-5 // LayerNorm family default.

    val xInput = operands[0]
    val scaleOperand: String? = if (operands.size > 1) operands[1] else null
    val offsetOperand: String? = if (operands.size > 2) operands[2] else null

    val meanValue = context.nextTempValue()
    val meanBroadcast = context.nextTempValue()
    val centered = context.nextTempValue()
    val varValue = context.nextTempValue()
    val epsConst = context.nextTempValue()
    val epsBroadcast = context.nextTempValue()
    val varPlusEps = context.nextTempValue()
    val stdValue = context.nextTempValue()
    val stdBroadcast = context.nextTempValue()
    val normalized = context.nextTempValue()

    val operations = mutableListOf<String>()

    // Aligns a scale/offset operand with the output shape. The usual
    // LayerNorm weight/bias is a 1-D tensor of length inputShape[axis];
    // combining it elementwise with the full-shape tensor without a
    // broadcast would emit shape-mismatched MLIR, so broadcast it onto
    // the normalization axis first. Full-shape (or unknown-shape)
    // parameters pass through untouched.
    fun alignParam(operand: String, spec: TensorSpec?): String {
        val paramShape = spec?.shape ?: return operand
        if (paramShape == inputShape || paramShape.size != 1 || rank <= 1) return operand
        val paramType = "tensor<${paramShape.joinToString("x")}x$elementType>"
        val aligned = context.nextTempValue()
        operations += "$aligned = stablehlo.broadcast_in_dim $operand, " +
            "dims = [$axis] : ($paramType) -> $outputType"
        return aligned
    }

    // mean(x) along the normalization axis.
    operations += "$meanValue = stablehlo.custom_call @reduce_mean($xInput) " +
        "{dimensions = [$axis], keepdim = false} : $reducedType"

    // Broadcast mean back to input shape.
    operations += "$meanBroadcast = stablehlo.broadcast_in_dim $meanValue, " +
        "dims = [$broadcastDims] : ($reducedType) -> $outputType"

    // Mean-center.
    operations += "$centered = stablehlo.subtract $xInput, $meanBroadcast : $outputType"

    // variance(x) along the normalization axis.
    operations += "$varValue = stablehlo.custom_call @reduce_variance($xInput) " +
        "{dimensions = [$axis], keepdim = false} : $reducedType"

    // Epsilon constant broadcast into the reduced shape.
    operations += "$epsConst = stablehlo.constant dense<$epsilon> : tensor<$elementType>"
    operations += "$epsBroadcast = stablehlo.broadcast_in_dim $epsConst, " +
        "dims = [] : (tensor<$elementType>) -> $reducedType"

    // variance + eps
    operations += "$varPlusEps = stablehlo.add $varValue, $epsBroadcast : $reducedType"

    // std = sqrt(variance + eps)
    operations += "$stdValue = stablehlo.sqrt $varPlusEps : $reducedType"

    // Broadcast std back to the input shape.
    operations += "$stdBroadcast = stablehlo.broadcast_in_dim $stdValue, " +
        "dims = [$broadcastDims] : ($reducedType) -> $outputType"

    // normalized = (x - mean) / std
    operations += "$normalized = stablehlo.divide $centered, $stdBroadcast : $outputType"

    // Apply scale and offset if present. Track the current running SSA
    // value so omitting either one keeps the emitted MLIR faithful to
    // the input graph.
    var current = normalized
    if (scaleOperand != null) {
        val alignedScale = alignParam(scaleOperand, node.inputs.getOrNull(1))
        val scaled = context.nextTempValue()
        operations += "$scaled = stablehlo.multiply $current, $alignedScale : $outputType"
        current = scaled
    }
    if (offsetOperand != null) {
        val alignedOffset = alignParam(offsetOperand, node.inputs.getOrNull(2))
        val offsetted = context.nextTempValue()
        operations += "$offsetted = stablehlo.add $current, $alignedOffset : $outputType"
        current = offsetted
    }

    operations.forEach { context.emitOperation(it) }

    return ConversionResult.Success(
        outputValueName = current,
        emittedOperations = operations
    )
}
import sk.ainet.lang.tensor.Tensor
import sk.ainet.lang.tensor.ops.Operation
import sk.ainet.lang.tensor.ops.TensorSpec
import sk.ainet.lang.tensor.ops.ValidationResult
import sk.ainet.lang.types.DType
import kotlin.test.Test
import kotlin.test.assertFalse
import kotlin.test.assertTrue

/**
 * Covers the LayerNorm lowering rewrite for #480.
 *
 * Before this fix, `NeuralNetOperationsConverter.convertLayerNorm`
 * emitted `stablehlo.custom_call @layer_norm(...)`, which no MLIR
 * tool in the repo understands. This test pins the new lowering:
 * a real elementwise decomposition using @reduce_mean / @reduce_variance
 * / broadcast_in_dim / sqrt / divide — matching softmax #467 and the
 * codebase's existing reduction-via-custom-call style.
 *
 *     layer_norm(x) = scale * (x - mean) / sqrt(var + eps) + offset
 */
class LayerNormConverterTest {

    @Test
    fun layerNorm_does_not_emit_custom_call_stub() {
        val graph = buildLayerNormGraph(withScale = true, withOffset = true)
        val converter = StableHloConverterFactory.createExtended()
        val module = converter.convert(graph, "test_layer_norm")
        println("[DEBUG_LOG] LayerNorm lowering:\n${module.content}")

        assertFalse(
            module.content.contains("@layer_norm"),
            "layerNorm must not fall back to the @layer_norm custom_call stub"
        )
    }

    @Test
    fun layerNorm_lowers_to_real_reductions_and_broadcasts() {
        val graph = buildLayerNormGraph(withScale = true, withOffset = true)
        val converter = StableHloConverterFactory.createExtended()
        val module = converter.convert(graph, "test_layer_norm_full")

        // Core elementwise decomposition.
        assertTrue(
            module.content.contains("@reduce_mean"),
            "layerNorm must lower mean(x) to a real reduction"
        )
        assertTrue(
            module.content.contains("@reduce_variance"),
            "layerNorm must lower var(x) to a real reduction"
        )
        assertTrue(
            module.content.contains("stablehlo.subtract"),
            "layerNorm must subtract the mean (mean-centering)"
        )
        assertTrue(
            module.content.contains("stablehlo.sqrt"),
            "layerNorm must take the square root of variance + epsilon"
        )
        assertTrue(
            module.content.contains("stablehlo.divide"),
            "layerNorm must divide by the standard deviation"
        )
        assertTrue(
            module.content.contains("stablehlo.broadcast_in_dim"),
            "layerNorm must broadcast the reduced mean / std back to the input shape"
        )
        assertTrue(
            module.content.contains("stablehlo.multiply"),
            "layerNorm must apply the scale multiplier when a scale operand is present"
        )
        assertTrue(
            module.content.contains("stablehlo.add"),
            "layerNorm must apply the additive offset when an offset operand is present"
        )
    }

    @Test
    fun layerNorm_without_scale_or_offset_still_lowers_correctly() {
        val graph = buildLayerNormGraph(withScale = false, withOffset = false)
        val converter = StableHloConverterFactory.createExtended()
        val module = converter.convert(graph, "test_layer_norm_minimal")

        assertFalse(module.content.contains("@layer_norm"))
        assertTrue(module.content.contains("@reduce_mean"))
        assertTrue(module.content.contains("@reduce_variance"))
        assertTrue(module.content.contains("stablehlo.subtract"))
        assertTrue(module.content.contains("stablehlo.sqrt"))
        assertTrue(module.content.contains("stablehlo.divide"))
    }

    /**
     * Builds a minimal graph: an input node feeding a layerNorm node,
     * optionally with 1-D scale / offset parameter nodes wired in as
     * extra operands (input indices 1 and 2 respectively).
     */
    private fun buildLayerNormGraph(withScale: Boolean, withOffset: Boolean): DefaultComputeGraph {
        val graph = DefaultComputeGraph()
        val shape = listOf(2, 4)

        val input = GraphNode(
            id = "x",
            operation = markerInputOp(),
            inputs = emptyList(),
            outputs = listOf(TensorSpec("x", shape, "FP32"))
        )
        graph.addNode(input)

        val layerNormInputs = mutableListOf(TensorSpec("x", shape, "FP32"))
        // Pair of (source node, layerNorm input index) for extra edges.
        val extraEdges = mutableListOf<Pair<GraphNode, Int>>()

        if (withScale) {
            val scaleNode = GraphNode(
                id = "scale",
                operation = markerInputOp(),
                inputs = emptyList(),
                outputs = listOf(TensorSpec("scale", listOf(4), "FP32"))
            )
            graph.addNode(scaleNode)
            layerNormInputs.add(TensorSpec("scale", listOf(4), "FP32"))
            extraEdges.add(scaleNode to (layerNormInputs.size - 1))
        }
        if (withOffset) {
            val offsetNode = GraphNode(
                id = "offset",
                operation = markerInputOp(),
                inputs = emptyList(),
                outputs = listOf(TensorSpec("offset", listOf(4), "FP32"))
            )
            graph.addNode(offsetNode)
            layerNormInputs.add(TensorSpec("offset", listOf(4), "FP32"))
            extraEdges.add(offsetNode to (layerNormInputs.size - 1))
        }

        val layerNorm = GraphNode(
            id = "ln1",
            operation = layerNormOp(eps = 1e-5, axis = -1),
            inputs = layerNormInputs.toList(),
            outputs = listOf(TensorSpec("y", shape, "FP32"))
        )
        graph.addNode(layerNorm)
        graph.addEdge(GraphEdge("e1", input, layerNorm, 0, 0, input.outputs[0]))
        extraEdges.forEachIndexed { i, (src, idx) ->
            graph.addEdge(GraphEdge("e${i + 2}", src, layerNorm, 0, idx, src.outputs[0]))
        }

        return graph
    }

    /** Parameter-less marker operation used for pure input nodes. */
    private fun markerInputOp(): Operation = object : Operation {
        override val name: String = "input"
        override val type: String = "input"
        override val parameters: Map<String, Any> = emptyMap()
        override fun execute(inputs: List<Tensor<*>>): List<Tensor<*>> =
            throw UnsupportedOperationException("test fixture only")
        override fun validateInputs(inputs: List<TensorSpec>): ValidationResult = ValidationResult.Valid
        override fun inferOutputs(inputs: List<TensorSpec>): List<TensorSpec> = emptyList()
        override fun clone(newParameters: Map<String, Any>): Operation = this
        override fun serialize(): Map<String, Any> = mapOf("name" to name, "type" to type)
    }

    /** LayerNorm operation fixture carrying only `eps` / `axis` parameters. */
    private fun layerNormOp(eps: Double, axis: Int): Operation = object : Operation {
        override val name: String = "layerNorm"
        override val type: String = "normalization"
        override val parameters: Map<String, Any> = mapOf("eps" to eps, "axis" to axis)
        override fun execute(inputs: List<Tensor<*>>): List<Tensor<*>> =
            throw UnsupportedOperationException("test fixture only")
        override fun validateInputs(inputs: List<TensorSpec>): ValidationResult = ValidationResult.Valid
        override fun inferOutputs(inputs: List<TensorSpec>): List<TensorSpec> = inputs.take(1)
        override fun clone(newParameters: Map<String, Any>): Operation = this
        override fun serialize(): Map<String, Any> = mapOf(
            "name" to name, "type" to type, "parameters" to parameters
        )
    }
}