Lower layerNorm to real StableHLO ops instead of custom_call (#480) #482

Merged

michalharakal merged 2 commits into develop from feature/480-layernorm-real-lowering
Apr 13, 2026
Conversation

@michalharakal
Contributor

Closes #480.

Summary

Replaces `NeuralNetOperationsConverter.convertLayerNorm`'s `stablehlo.custom_call @layer_norm` placeholder with a real elementwise decomposition that downstream tools (IREE, any StableHLO-reading pass) can actually interpret.

out = scale * (x - mean) / sqrt(var + eps) + offset

Previously, every `.mlir` module emitted with a LayerNorm had a hole punched in it at the most expensive op in every transformer block. This lowering matches the style established by softmax (#467) and RMSNorm (#479): reductions go through `stablehlo.custom_call @reduce_mean` / `@reduce_variance` (already supported by `ReductionOperationsConverter`), the reduced tensors are broadcast back to the input shape via `stablehlo.broadcast_in_dim`, and scale / offset are applied elementwise only when their operands are present.
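For reference, the math the new lowering emits can be checked numerically with a pure-Python sketch (the function name, list-of-lists shapes, and per-row loop are illustrative only, not the converter's API; the real emission works on StableHLO tensors):

```python
import math

def layer_norm(x, scale=None, offset=None, eps=1e-5):
    """Reference for out = scale * (x - mean) / sqrt(var + eps) + offset,
    normalizing each row of a 2-D input over its last axis."""
    out = []
    for row in x:
        mean = sum(row) / len(row)                            # @reduce_mean
        var = sum((v - mean) ** 2 for v in row) / len(row)    # @reduce_variance
        std = math.sqrt(var + eps)                            # add eps, sqrt
        normed = [(v - mean) / std for v in row]              # subtract, divide
        if scale is not None:                                 # optional multiply
            normed = [v * s for v, s in zip(normed, scale)]
        if offset is not None:                                # optional add
            normed = [v + b for v, b in zip(normed, offset)]
        out.append(normed)
    return out
```

Each normalized row comes out with mean approximately 0 and variance approximately 1 (exactly `var / (var + eps)`), which is the invariant the decomposition preserves regardless of whether scale / offset are present.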

Two commits

1. Failing test — `LayerNormConverterTest`

  • `layerNorm_does_not_emit_custom_call_stub` — asserts the emitted MLIR does not contain `@layer_norm`. Red against the pre-fix converter.
  • `layerNorm_lowers_to_real_reductions_and_broadcasts` — asserts the presence of `@reduce_mean`, `@reduce_variance`, `subtract`, `sqrt`, `divide`, `broadcast_in_dim`, `multiply`, and `add`.
  • `layerNorm_without_scale_or_offset_still_lowers_correctly` — scale and offset are optional; dropping them must still leave a correctly-lowered norm.

2. The fix

  • Rewrites `convertLayerNorm` inline and deletes the now-unused `buildLayerNormOperation` helper.
  • Handles negative axis / `normalized_shape: IntArray` parameter form.
  • Default `epsilon` 1e-5 (LayerNorm family convention); overridable via `eps` / `epsilon`.
  • Tracks a running SSA value across the optional scale / offset stages so omitting either one keeps the emitted MLIR faithful to the input graph.
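The negative-axis handling mentioned above follows the usual convention of normalizing an axis against the input rank; a minimal sketch (the helper name and error message are hypothetical, not taken from the converter):

```python
def normalize_axis(axis: int, rank: int) -> int:
    """Map a possibly-negative axis into [0, rank); -1 means the last dim."""
    if not -rank <= axis < rank:
        raise ValueError(f"axis {axis} out of range for rank {rank}")
    return axis + rank if axis < 0 else axis
```

So for the 2×4 test fixture with `axis=-1`, the reduction runs over dimension 1.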

Test plan

  • `LayerNormConverterTest` — 3/3 green
  • Full `:skainet-compile:skainet-compile-hlo:jvmTest` — green, no regressions (existing layerNorm registration check at `NeuralNetOperationsConverterTest:29` is unaffected)
  • CI: full multiplatform build

Out of scope

  • Real `stablehlo.reduce` region bodies. Every reduction in the emitter uses `custom_call @reduce_*` today; migrating them to proper regions is a larger separate refactor.
  • Fused LayerNorm-inside-attention patterns (IREE-side optimization).
  • Quantized LayerNorm — depends on further P0-1 work.

🤖 Generated with Claude Code

michalharakal and others added 2 commits April 13, 2026 13:07
Adds LayerNormConverterTest with three cases:

1. layerNorm_does_not_emit_custom_call_stub — asserts the
   emitted MLIR does not contain `@layer_norm`. Red against
   NeuralNetOperationsConverter today, which emits
   `stablehlo.custom_call @layer_norm(...)` as a placeholder.
2. layerNorm_lowers_to_real_reductions_and_broadcasts — asserts
   the full elementwise decomposition surfaces:
     @reduce_mean, @reduce_variance, subtract (mean-centering),
     sqrt (of var+eps), divide (by std), broadcast_in_dim
     (reduced back to input shape), multiply (scale), add
     (offset).
3. layerNorm_without_scale_or_offset_still_lowers_correctly —
   scale and offset are optional; dropping them must still
   leave a correctly-lowered norm.

Test uses the same minimal in-file operation fixture style as
RmsNormConverterTest. Shape is 2×4 FP32, axis=-1, eps=1e-5
(LayerNorm family default).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Replaces the @layer_norm custom_call placeholder with the
standard elementwise decomposition:

    out = scale * (x - mean) / sqrt(var + eps) + offset

No MLIR tool in the repo understands @layer_norm, so every
.mlir module emitted with a LayerNorm had a hole punched in it
at the most expensive op in every transformer block. This fix
rewrites convertLayerNorm to emit a sequence that consumers
(IREE, any StableHLO-reading pass) can actually interpret:

- mean(x) via stablehlo.custom_call @reduce_mean
- broadcast_in_dim of the reduced mean back to the input shape
- stablehlo.subtract for mean-centering
- variance(x) via stablehlo.custom_call @reduce_variance
- stablehlo.constant + broadcast_in_dim for the epsilon term
- stablehlo.add of var + eps
- stablehlo.sqrt for std
- broadcast_in_dim of std back to the input shape
- stablehlo.divide of (x - mean) by std
- optional stablehlo.multiply with scale (omitted if absent)
- optional stablehlo.add with offset   (omitted if absent)

Emission style matches softmax #467 and RMSNorm #479 — all three
lower their reductions through stablehlo.custom_call @reduce_*
to stay consistent with ReductionOperationsConverter. Migrating
every reduction to real stablehlo.reduce regions is a separate
larger refactor.

Axis normalization against rank handles negative axes and the
`normalized_shape: IntArray` parameter form. Default epsilon
1e-5 matches the LayerNorm family convention; callers can
override via `eps` or `epsilon`.
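The override rule described above amounts to a small lookup with a fallback; a hedged sketch (helper name and the `eps`-before-`epsilon` precedence are assumptions for illustration, the commit only states that both keys are accepted):

```python
def resolve_epsilon(params: dict) -> float:
    """Pick the epsilon for var + eps: explicit 'eps' or 'epsilon'
    if provided, else the LayerNorm-family default of 1e-5."""
    if "eps" in params:
        return float(params["eps"])
    if "epsilon" in params:
        return float(params["epsilon"])
    return 1e-5
```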

buildLayerNormOperation is deleted — the new convertLayerNorm
emits its ops inline, matching the pattern used by the softmax
and RMSNorm paths.

Tests: 3/3 in LayerNormConverterTest green, full compile-hlo
jvmTest suite still green (existing layerNorm registration
check at NeuralNetOperationsConverterTest:29 is unaffected).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@michalharakal michalharakal merged commit effc502 into develop Apr 13, 2026
2 of 4 checks passed
@michalharakal michalharakal deleted the feature/480-layernorm-real-lowering branch April 13, 2026 11:13

Development

Successfully merging this pull request may close these issues.

Lower layerNorm to real StableHLO ops instead of custom_call stub (P1)