Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -278,45 +278,145 @@ public class NeuralNetOperationsConverter : StableHloOperationConverter {
)
}

/**
 * Lower LayerNorm to real StableHLO elementwise ops. Replaces an
 * earlier lowering that emitted `stablehlo.custom_call @layer_norm`
 * as a placeholder (no MLIR tool in the repo understands that
 * custom_call), using the standard decomposition:
 *
 *     out = scale * (x - mean) / sqrt(var + eps) + offset
 *
 * Emission style matches the softmax fix (#467) and the rest of
 * the emitter: reductions go through
 * `stablehlo.custom_call @reduce_mean` / `@reduce_variance` (both
 * already supported by `ReductionOperationsConverter`), the reduced
 * tensors are broadcast back to the input shape via
 * `stablehlo.broadcast_in_dim`, and scale / offset are elementwise
 * multiplied / added only when their operands are actually present.
 * Migrating every reduction to real `stablehlo.reduce` regions is
 * a separate, larger refactor.
 *
 * @param node graph node; parameters may carry `axis` (Int),
 *   `normalized_shape` (IntArray), and `eps` / `epsilon` (Double).
 * @param operands SSA value names: `[input, scale?, offset?]`.
 * @param context emission context (type mapping, temp SSA names, op sink).
 * @return [ConversionResult.Success] with the final SSA value and all
 *   emitted operations, or [ConversionResult.Failure] when no input
 *   operand exists.
 */
private fun convertLayerNorm(
    node: GraphNode,
    operands: List<String>,
    context: ConversionContext
): ConversionResult {
    if (operands.isEmpty()) {
        return ConversionResult.Failure(
            "Unsupported layerNorm arity for node ${node.id}"
        )
    }

    val outputSpec = node.outputs.firstOrNull()
    val outputType = outputSpec?.let { context.getTypeMapper().mapTensorType(it) }
        ?: "tensor<?x?xf32>"
    val elementType = outputSpec?.let { context.getTypeMapper().mapDType(it.dtype) }
        ?: "f32"

    val inputShape = node.inputs.firstOrNull()?.shape ?: outputSpec?.shape ?: emptyList()
    val rank = inputShape.size

    // Normalize the axis parameter against rank. Default to the
    // last dimension, matching every standard LayerNorm in the
    // wild. Callers may supply either an `axis` integer or an
    // `IntArray normalized_shape`; in the latter case we reduce
    // along the leading element (simple single-axis support).
    val rawAxis = node.operation.parameters["axis"] as? Int
        ?: (node.operation.parameters["normalized_shape"] as? IntArray)?.firstOrNull()
        ?: -1
    val axis = when {
        rank == 0 -> 0
        rawAxis < 0 -> rank + rawAxis
        else -> rawAxis
    }.coerceIn(0, (rank - 1).coerceAtLeast(0))

    // Shape / type of a tensor reduced along `axis` with keepdim = false,
    // and the broadcast dims that map it back onto the input shape.
    val reducedShape = if (rank > 0) {
        inputShape.filterIndexed { i, _ -> i != axis }
    } else {
        emptyList()
    }
    val reducedType = if (reducedShape.isEmpty()) {
        "tensor<$elementType>"
    } else {
        "tensor<${reducedShape.joinToString("x")}x$elementType>"
    }
    val broadcastDims = (0 until rank).filter { it != axis }.joinToString(", ")

    // Accept both `eps` and `epsilon` spellings; 1e-5 is the
    // LayerNorm family default.
    val epsilon = (node.operation.parameters["eps"] as? Double)
        ?: (node.operation.parameters["epsilon"] as? Double)
        ?: 1e-5

    val xInput = operands[0]
    val scaleOperand: String? = if (operands.size > 1) operands[1] else null
    val offsetOperand: String? = if (operands.size > 2) operands[2] else null

    val meanValue = context.nextTempValue()
    val meanBroadcast = context.nextTempValue()
    val centered = context.nextTempValue()
    val varValue = context.nextTempValue()
    val epsConst = context.nextTempValue()
    val epsBroadcast = context.nextTempValue()
    val varPlusEps = context.nextTempValue()
    val stdValue = context.nextTempValue()
    val stdBroadcast = context.nextTempValue()
    val normalized = context.nextTempValue()

    val operations = mutableListOf<String>()

    // mean(x) along the normalization axis.
    operations += "$meanValue = stablehlo.custom_call @reduce_mean($xInput) " +
        "{dimensions = [$axis], keepdim = false} : $reducedType"

    // Broadcast mean back to input shape.
    operations += "$meanBroadcast = stablehlo.broadcast_in_dim $meanValue, " +
        "dims = [$broadcastDims] : ($reducedType) -> $outputType"

    // Mean-center.
    operations += "$centered = stablehlo.subtract $xInput, $meanBroadcast : $outputType"

    // variance(x) along the normalization axis.
    operations += "$varValue = stablehlo.custom_call @reduce_variance($xInput) " +
        "{dimensions = [$axis], keepdim = false} : $reducedType"

    // Epsilon constant broadcast into the reduced shape.
    operations += "$epsConst = stablehlo.constant dense<$epsilon> : tensor<$elementType>"
    operations += "$epsBroadcast = stablehlo.broadcast_in_dim $epsConst, " +
        "dims = [] : (tensor<$elementType>) -> $reducedType"

    // variance + eps
    operations += "$varPlusEps = stablehlo.add $varValue, $epsBroadcast : $reducedType"

    // std = sqrt(variance + eps)
    operations += "$stdValue = stablehlo.sqrt $varPlusEps : $reducedType"

    // Broadcast std back to the input shape.
    operations += "$stdBroadcast = stablehlo.broadcast_in_dim $stdValue, " +
        "dims = [$broadcastDims] : ($reducedType) -> $outputType"

    // normalized = (x - mean) / std
    operations += "$normalized = stablehlo.divide $centered, $stdBroadcast : $outputType"

    // Apply scale and offset if present. Track the current running
    // SSA value so omitting either one keeps the emitted MLIR
    // faithful to the input graph.
    var current = normalized
    if (scaleOperand != null) {
        val scaled = context.nextTempValue()
        operations += "$scaled = stablehlo.multiply $current, $scaleOperand : $outputType"
        current = scaled
    }
    if (offsetOperand != null) {
        val offsetted = context.nextTempValue()
        operations += "$offsetted = stablehlo.add $current, $offsetOperand : $outputType"
        current = offsetted
    }

    operations.forEach { context.emitOperation(it) }

    return ConversionResult.Success(
        outputValueName = current,
        emittedOperations = operations
    )
}

Expand Down Expand Up @@ -528,23 +628,4 @@ public class NeuralNetOperationsConverter : StableHloOperationConverter {
}
}

/**
 * Formats a placeholder `stablehlo.custom_call @layer_norm` line.
 *
 * NOTE: this is a stub lowering, not a real mean/variance
 * decomposition — it emits a custom_call that downstream MLIR
 * tooling does not implement. Scale and offset are forwarded only
 * when BOTH are provided; otherwise only the input operand is
 * emitted. `normalizedShape` is accepted for interface parity but
 * is not encoded into the text.
 *
 * @param resultValue SSA name assigned to the call result.
 * @param input SSA name of the tensor being normalized.
 * @param scale optional SSA name of the scale operand.
 * @param offset optional SSA name of the offset operand.
 * @param outputType MLIR tensor type string of the result.
 * @param epsilon numerical-stability constant embedded as an attribute.
 * @param normalizedShape unused in the emitted text (see note above).
 * @return the single MLIR operation as a string.
 */
private fun buildLayerNormOperation(
    resultValue: String,
    input: String,
    scale: String?,
    offset: String?,
    outputType: String,
    epsilon: Double,
    normalizedShape: IntArray
): String {
    // Only the fully-parameterized form carries scale/offset operands.
    val callOperands = when {
        scale != null && offset != null -> "$input, $scale, $offset"
        else -> input
    }
    return "$resultValue = stablehlo.custom_call @layer_norm($callOperands) " +
        "{epsilon = $epsilon} : $outputType"
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
package sk.ainet.compile.hlo

import sk.ainet.lang.graph.DefaultComputeGraph
import sk.ainet.lang.graph.GraphEdge
import sk.ainet.lang.graph.GraphNode
import sk.ainet.lang.tensor.Tensor
import sk.ainet.lang.tensor.ops.Operation
import sk.ainet.lang.tensor.ops.TensorSpec
import sk.ainet.lang.tensor.ops.ValidationResult
import sk.ainet.lang.types.DType
import kotlin.test.Test
import kotlin.test.assertFalse
import kotlin.test.assertTrue

/**
 * Covers the LayerNorm lowering rewrite for #480.
 *
 * Before this fix, `NeuralNetOperationsConverter.convertLayerNorm`
 * emitted `stablehlo.custom_call @layer_norm(...)`, which no MLIR
 * tool in the repo understands. This test pins the new lowering:
 * a real elementwise decomposition using @reduce_mean / @reduce_variance
 * / broadcast_in_dim / sqrt / divide — matching softmax #467 and the
 * codebase's existing reduction-via-custom-call style.
 *
 * layer_norm(x) = scale * (x - mean) / sqrt(var + eps) + offset
 */
class LayerNormConverterTest {

// Guards the regression itself: the old stub spelling must be gone.
@Test
fun layerNorm_does_not_emit_custom_call_stub() {
val graph = buildLayerNormGraph(withScale = true, withOffset = true)
val converter = StableHloConverterFactory.createExtended()
val module = converter.convert(graph, "test_layer_norm")
println("[DEBUG_LOG] LayerNorm lowering:\n${module.content}")

assertFalse(
module.content.contains("@layer_norm"),
"layerNorm must not fall back to the @layer_norm custom_call stub"
)
}

// Pins every op of the decomposition for the full (scale + offset) form.
// Substring checks only — operand order and SSA names are not asserted.
@Test
fun layerNorm_lowers_to_real_reductions_and_broadcasts() {
val graph = buildLayerNormGraph(withScale = true, withOffset = true)
val converter = StableHloConverterFactory.createExtended()
val module = converter.convert(graph, "test_layer_norm_full")

// Core elementwise decomposition.
assertTrue(
module.content.contains("@reduce_mean"),
"layerNorm must lower mean(x) to a real reduction"
)
assertTrue(
module.content.contains("@reduce_variance"),
"layerNorm must lower var(x) to a real reduction"
)
assertTrue(
module.content.contains("stablehlo.subtract"),
"layerNorm must subtract the mean (mean-centering)"
)
assertTrue(
module.content.contains("stablehlo.sqrt"),
"layerNorm must take the square root of variance + epsilon"
)
assertTrue(
module.content.contains("stablehlo.divide"),
"layerNorm must divide by the standard deviation"
)
assertTrue(
module.content.contains("stablehlo.broadcast_in_dim"),
"layerNorm must broadcast the reduced mean / std back to the input shape"
)
assertTrue(
module.content.contains("stablehlo.multiply"),
"layerNorm must apply the scale multiplier when a scale operand is present"
)
assertTrue(
module.content.contains("stablehlo.add"),
"layerNorm must apply the additive offset when an offset operand is present"
)
}

// Minimal form: input only. The core decomposition must still appear.
// (No multiply/add assertions here — scale/offset are absent by design.)
@Test
fun layerNorm_without_scale_or_offset_still_lowers_correctly() {
val graph = buildLayerNormGraph(withScale = false, withOffset = false)
val converter = StableHloConverterFactory.createExtended()
val module = converter.convert(graph, "test_layer_norm_minimal")

assertFalse(module.content.contains("@layer_norm"))
assertTrue(module.content.contains("@reduce_mean"))
assertTrue(module.content.contains("@reduce_variance"))
assertTrue(module.content.contains("stablehlo.subtract"))
assertTrue(module.content.contains("stablehlo.sqrt"))
assertTrue(module.content.contains("stablehlo.divide"))
}

// Builds a 2x4 FP32 graph: input -> layerNorm, optionally wiring 1-D
// scale/offset (shape [4]) input nodes into the layerNorm's extra
// operand slots. `extraEdges` remembers each optional node with the
// layerNorm input index it must connect to.
private fun buildLayerNormGraph(withScale: Boolean, withOffset: Boolean): DefaultComputeGraph {
val graph = DefaultComputeGraph()
val shape = listOf(2, 4)

val input = GraphNode(
id = "x",
operation = markerInputOp(),
inputs = emptyList(),
outputs = listOf(TensorSpec("x", shape, "FP32"))
)
graph.addNode(input)

val layerNormInputs = mutableListOf(TensorSpec("x", shape, "FP32"))
val extraEdges = mutableListOf<Pair<GraphNode, Int>>()

if (withScale) {
val scaleNode = GraphNode(
id = "scale",
operation = markerInputOp(),
inputs = emptyList(),
outputs = listOf(TensorSpec("scale", listOf(4), "FP32"))
)
graph.addNode(scaleNode)
layerNormInputs.add(TensorSpec("scale", listOf(4), "FP32"))
extraEdges.add(scaleNode to (layerNormInputs.size - 1))
}
if (withOffset) {
val offsetNode = GraphNode(
id = "offset",
operation = markerInputOp(),
inputs = emptyList(),
outputs = listOf(TensorSpec("offset", listOf(4), "FP32"))
)
graph.addNode(offsetNode)
layerNormInputs.add(TensorSpec("offset", listOf(4), "FP32"))
extraEdges.add(offsetNode to (layerNormInputs.size - 1))
}

val layerNorm = GraphNode(
id = "ln1",
operation = layerNormOp(eps = 1e-5, axis = -1),
inputs = layerNormInputs.toList(),
outputs = listOf(TensorSpec("y", shape, "FP32"))
)
graph.addNode(layerNorm)
graph.addEdge(GraphEdge("e1", input, layerNorm, 0, 0, input.outputs[0]))
// Edge ids continue from "e1"; each extra edge targets input slot `idx`.
extraEdges.forEachIndexed { i, (src, idx) ->
graph.addEdge(GraphEdge("e${i + 2}", src, layerNorm, 0, idx, src.outputs[0]))
}

return graph
}

// Anonymous no-op Operation marking a graph input. execute() must never
// run in these tests — conversion only reads specs and parameters.
private fun markerInputOp(): Operation = object : Operation {
override val name: String = "input"
override val type: String = "input"
override val parameters: Map<String, Any> = emptyMap()
override fun <T : DType, V> execute(inputs: List<Tensor<T, V>>): List<Tensor<T, V>> =
throw UnsupportedOperationException("test fixture only")
override fun validateInputs(inputs: List<TensorSpec>): ValidationResult = ValidationResult.Valid
override fun inferOutputs(inputs: List<TensorSpec>): List<TensorSpec> = emptyList()
override fun clone(newParameters: Map<String, Any>): Operation = this
override fun serialize(): Map<String, Any> = mapOf("name" to name, "type" to type)
}

// LayerNorm Operation stub exposing exactly the parameters the converter
// reads ("eps", "axis"). Like markerInputOp, execute() is unreachable.
private fun layerNormOp(eps: Double, axis: Int): Operation = object : Operation {
override val name: String = "layerNorm"
override val type: String = "normalization"
override val parameters: Map<String, Any> = mapOf("eps" to eps, "axis" to axis)
override fun <T : DType, V> execute(inputs: List<Tensor<T, V>>): List<Tensor<T, V>> =
throw UnsupportedOperationException("test fixture only")
override fun validateInputs(inputs: List<TensorSpec>): ValidationResult = ValidationResult.Valid
override fun inferOutputs(inputs: List<TensorSpec>): List<TensorSpec> = inputs.take(1)
override fun clone(newParameters: Map<String, Any>): Operation = this
override fun serialize(): Map<String, Any> = mapOf(
"name" to name, "type" to type, "parameters" to parameters
)
}
}
Loading