
Add RMSNorm StableHLO converter (#479)#481

Merged
michalharakal merged 2 commits into develop from feature/479-rmsnorm-converter
Apr 13, 2026
Conversation

@michalharakal
Contributor

Closes #479.

Summary

Adds a real RMSNorm lowering so modern transformer exports (Llama / Mistral / Qwen / Gemma — every open-weight LLM family uses RMSNorm rather than LayerNorm) stop dropping through the converter registry's "no converter found" path.

rms  = sqrt(mean(x^2, axis) + eps)
out  = scale * x / rms      (scale operand optional)

Emission style matches the softmax fix (#467) and the rest of the emitter: reductions go through `stablehlo.custom_call @reduce_mean`, the reduced tensor is broadcast back to the input shape via `stablehlo.broadcast_in_dim` for the final divide, and the epsilon is materialized as a scalar constant broadcast into the reduced shape.
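The numerics of the lowering can be sketched as a small stdlib-only Python reference (illustrative only; the actual converter is Kotlin and emits StableHLO rather than computing values):

```python
import math

def rms_norm(x, eps=1e-6, scale=None):
    """Reference RMSNorm over the last axis of a 2-D list of floats:
    rms = sqrt(mean(v^2) + eps); out = scale * v / rms, per row."""
    out = []
    for row in x:
        mean_sq = sum(v * v for v in row) / len(row)
        rms = math.sqrt(mean_sq + eps)       # epsilon added before the sqrt
        normed = [v / rms for v in row]      # divide "broadcast" over the row
        if scale is not None:                # trailing affine multiply is optional
            normed = [s * v for s, v in zip(scale, normed)]
        out.append(normed)
    return out

x = [[1.0, 2.0, 3.0, 4.0],
     [4.0, 3.0, 2.0, 1.0]]
y = rms_norm(x)  # 2x4 input, per-row RMS — the shape the tests below use
```

After normalization each row's mean square is ~1 (off only by the epsilon term), which is a quick sanity check on any lowering of this form.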

Two commits

1. Failing test — `RmsNormConverterTest`

  • `rmsNorm_operation_is_supported_by_neural_net_converter` — asserts the registry recognizes `rmsNorm`, `rms_norm`, and `RMSNorm`.
  • `rmsNorm_with_scale_lowers_to_real_ops` — builds a 2×4 FP32 graph with an RMSNorm node and a per-channel scale operand, asserts the emitted MLIR contains `@reduce_mean`, `sqrt`, `divide`, `broadcast_in_dim`, and `multiply`, and is not labelled `Unsupported operation rmsNorm`.
  • `rmsNorm_without_scale_still_normalizes` — the scale operand is optional; dropping it must still produce the core norm.

Red against the pre-fix converter: no handler claims `rmsNorm` at all.
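For orientation, the emitted module for the 2×4 case looks roughly like the fragment below. This is illustrative only — SSA names, attribute spellings, and the exact op order are the emitter's; the tests assert only that the named ops appear:

```mlir
%sq     = stablehlo.multiply %x, %x : tensor<2x4xf32>
%mean   = stablehlo.custom_call @reduce_mean(%sq) : (tensor<2x4xf32>) -> tensor<2x1xf32>
%eps    = stablehlo.constant dense<1.000000e-06> : tensor<f32>
%epsb   = stablehlo.broadcast_in_dim %eps, dims = [] : (tensor<f32>) -> tensor<2x1xf32>
%sum    = stablehlo.add %mean, %epsb : tensor<2x1xf32>
%rms    = stablehlo.sqrt %sum : tensor<2x1xf32>
%rmsb   = stablehlo.broadcast_in_dim %rms, dims = [0, 1] : (tensor<2x1xf32>) -> tensor<2x4xf32>
%norm   = stablehlo.divide %x, %rmsb : tensor<2x4xf32>
%scaleb = stablehlo.broadcast_in_dim %scale, dims = [1] : (tensor<4xf32>) -> tensor<2x4xf32>
%out    = stablehlo.multiply %norm, %scaleb : tensor<2x4xf32>
```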

2. The fix

  • `NeuralNetOperationsConverter.supportedOperations` adds `rmsNorm`, `rms_norm`, `RMSNorm`, `RmsNorm`.
  • New `convertRmsNorm` method normalizes the reduction axis against the tensor rank (negative axes wrap; a PyTorch-style `IntArray` `normalized_shape` parameter is also accepted), uses 1e-6 as the default epsilon (the Llama-family convention; overridable via `eps` / `epsilon`), and omits the final affine multiply when no scale operand is present, so the emitted MLIR stays faithful to the input graph.
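The axis handling can be sketched like so (Python standing in for the Kotlin converter; the helper and its exact parameter handling are hypothetical, following the description above):

```python
def resolve_axis(rank, axis=None, normalized_shape=None):
    """Resolve the RMSNorm reduction axis against tensor rank.

    - negative axes wrap (PyTorch convention);
    - a normalized_shape of length k means "the trailing k dims",
      so the first reduced axis is rank - k;
    - with neither given, default to the last axis.
    """
    if normalized_shape is not None:
        return rank - len(normalized_shape)
    if axis is None:
        axis = -1                      # default: last dimension
    if axis < 0:
        axis += rank                   # wrap negative axes
    if not 0 <= axis < rank:
        raise ValueError(f"axis {axis} out of range for rank {rank}")
    return axis
```

For the common LLM case (`normalized_shape=[hidden_dim]` on a rank-3 activation) this resolves to the last axis, matching the default.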

Test plan

  • `RmsNormConverterTest` — 3/3 green
  • Full `:skainet-compile:skainet-compile-hlo:jvmTest` — green, no regressions
  • CI: full multiplatform build

Out of scope

  • Real `stablehlo.reduce` region bodies. Every reduction in the emitter uses `custom_call @reduce_*` today; migrating all reductions to proper regions is a separate, larger refactor.
  • Fusing RMSNorm with the following matmul / attention. That's an IREE-side optimization.
  • Quantized RMSNorm — depends on further P0-1 track work.

🤖 Generated with Claude Code

michalharakal and others added 2 commits April 13, 2026 13:00
Adds RmsNormConverterTest with three cases:

1. rmsNorm_operation_is_supported_by_neural_net_converter —
   asserts NeuralNetOperationsConverter registers rmsNorm plus
   the rms_norm and RMSNorm aliases. Red today.
2. rmsNorm_with_scale_lowers_to_real_ops — builds a 2×4 FP32
   graph with an RMSNorm node and a per-channel scale operand,
   runs the converter, asserts the emitted module contains
   @reduce_mean, sqrt, divide, broadcast_in_dim, multiply, and
   is not labelled as "Unsupported operation rmsNorm". Red today
   because no converter claims the op.
3. rmsNorm_without_scale_still_normalizes — the scale operand
   is optional (RMSNorm can be used without the trailing
   affine multiply, though most LLMs do include it). The core
   norm must still lower to real ops.

Tests use a minimal in-file fixture op stub rather than a real
RMSNorm Operation subclass since the converter only reads
`operation.name` and `operation.parameters`.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Extends NeuralNetOperationsConverter with a convertRmsNorm method
covering the rmsNorm / rms_norm / RMSNorm / RmsNorm operation
names and registers them in supportedOperations. The lowering is
the standard Llama-family form:

    rms  = sqrt(mean(x^2, axis) + eps)
    out  = scale * x / rms   (scale operand optional)

Emission style matches the softmax fix (#467) and the rest of
the emitter: reductions go through `stablehlo.custom_call
@reduce_mean`, the reduced tensor is broadcast back to the input
shape via `stablehlo.broadcast_in_dim` for the final divide, and
the epsilon is materialized as a scalar constant broadcast into
the reduced shape. Migrating all reductions to real
`stablehlo.reduce` regions is a separate refactor, explicitly
out of scope.

Axis normalization against rank handles negative axes and also
accepts an `IntArray` `normalized_shape` parameter for callers
that prefer PyTorch-style configuration. Default epsilon is
1e-6, matching Llama / Mistral / Qwen / Gemma; callers can
override via `eps` or `epsilon`.

Without a scale operand the final affine multiply is skipped and
the normalized value is returned directly — a few implementations
use RMSNorm without a learnable scale, and dropping the multiply
keeps the emitted MLIR faithful to the input graph.

Tests: 3/3 in RmsNormConverterTest green, full compile-hlo
jvmTest suite still green.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@michalharakal michalharakal merged commit ce6e7aa into develop Apr 13, 2026
2 of 4 checks passed
@michalharakal michalharakal deleted the feature/479-rmsnorm-converter branch April 13, 2026 11:12

Development

Successfully merging this pull request may close these issues.

Add RMSNorm StableHLO converter (P1)
