
add MXFP8 pre-swizzling for gfx1250 GEMM (#568) #605

Open
matthiasdiener wants to merge 21 commits into dev from mdiener/mxfp8-swizzle-gfx1250

@matthiasdiener matthiasdiener commented Jun 1, 2026

Contributor

Description

Cherry-picked from #568 (same code)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@matthiasdiener matthiasdiener self-assigned this Jun 1, 2026
@matthiasdiener matthiasdiener added the ci-level 1 CI test level 1 label Jun 1, 2026
@matthiasdiener matthiasdiener requested a review from alextmagro June 1, 2026 18:37
@matthiasdiener matthiasdiener marked this pull request as ready for review June 1, 2026 18:37
@matthiasdiener

Contributor Author

Manually tested on gfx1250, should be ready to go from my perspective.

GTEST_SKIP() << "MXFP8 is not supported in current config";
}

// hipBLASLt on gfx950 produces incorrect results for certain small MXFP8

Collaborator

Is there a ticket for that?

Contributor Author

No, there isn't.

if (is_nvfp4_scaling(config.scaling_mode)) {
if (is_nvfp4_scaling(config.scaling_mode)
#ifdef USE_ROCM
|| (config.scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING

Collaborator

I think there should be a corresponding update to the workspace size calculation: the scale sizes should be added to it.

Contributor Author

Added in ce60ce0

@github-actions

Claude Walkthrough

Intent. Add MXFP8 GEMM support on gfx1250 by teaching TE to pre-swizzle MXFP8 scale tensors into the K-tiled "Tensile 3D" layout that hipBLASLt requires there, and to canonicalize operand layouts to TN (the only layout the gfx1250 MXFP8 kernels accept). Cherry-picked from #568.

Key changes.

  • New device-side swizzle implementing the K-tiled 3D layout (group of 4) for both row-wise and column-wise scales: transformer_engine/common/swizzle/swizzle.cu:385 (swizzle_scaling_mx_kernel + swizzle_scaling_factors_mx). Dispatched from swizzle_scaling_factors and the multi-tensor variant only when cuda::sm_arch() == 125 and scaling mode is MXFP8.
  • Force TN GEMM layout on gfx1250 by routing N/T requests through columnwise data, in transformer_engine/common/gemm/rocm_gemm.cu:311 and :368.
  • Tighten scale-shape validation: both scale dims must be multiples of 4 on gfx1250 — transformer_engine/common/transformer_engine.cpp:129.
  • Pad MXFP8 scale-inv allocations to multiples of 4 on gfx1250: transformer_engine/pytorch/csrc/quantizer.cpp:1700 and transformer_engine/pytorch/tensor/mxfp8_tensor.py:144 (also :175 for columnwise, and the FSDP2 split path at :525).
  • Enable the PyTorch GEMM extension swizzle call on ROCm (it was fully #ifndef USE_ROCM-guarded), and add a narrower ROCm filter inside swizzle_scales_for_gemm that returns no-op unless MXFP8 + gfx1250: transformer_engine/pytorch/csrc/extensions/gemm.cpp:287, transformer_engine/pytorch/csrc/extensions/swizzle.cpp:66 and :182.
  • JAX path mirrors the same gating: workspace budget extended to hold pre-swizzled scales (transformer_engine/jax/cpp_extensions/gemm.py:624, transformer_engine/jax/csrc/extensions/gemm.cpp:245) and xla_buffer_to_nvte_gemm_operand now invokes nvte_swizzle_scaling_factors on gfx1250 (transformer_engine/jax/csrc/extensions/gemm.cpp:91).

Walkthrough.

  • swizzle.cu is the heart of the change. A single templated kernel swizzle_scaling_mx_kernel<bool kRowwise> rewrites compact [M, K_scale] (or transposed) scales into the permuted K-tiled 3D layout via dst = (k/4)*(M*4) + m*4 + (k%4). The caller pre-fills the output with 0x7F (E8M0 identity = 2^0 = 1.0) so K_scale padding is a no-op for the GEMM. The wrapper validates that K_scale is already padded to a multiple of 4, that input/output with_gemm_swizzled_scales flags are consistent, and that only one of row/col scales is present. The same wrapper is fanned out by multi_tensor_swizzle_scaling_factors for grouped GEMM.
  • rocm_gemm.cu reshapes the operand-canonicalization logic so the previously-shared assignments are inlined per branch, then layers in a gfx1250-only override: when A is N-transposed (or B is T-transposed) the columnwise buffer is used and the trans flag is flipped to T (or N) with the corresponding lda/ldb, mirroring how tensor-scaling already worked. Net effect: gfx1250 MXFP8 GEMMs always land at hipBLASLt as TN, regardless of caller-requested layout.
  • quantizer.cpp + mxfp8_tensor.py keep allocations consistent with the new layout: scale-inv shapes are padded to multiples of 4 on both axes when the device is gfx1250, for both rowwise and columnwise. The FSDP2 split path in mxfp8_tensor.py picks 4/4 padding multiples for gfx1250 (existing branches stay at 128/4 for CUDA and 1/1 for older ROCm).
  • pytorch/csrc/extensions/gemm.cpp and swizzle.cpp previously bypassed swizzling entirely on ROCm. The new code removes the broad USE_ROCM exclusion around swizzle_scales_for_gemm so MXFP8 tensors can be pre-swizzled, but the swizzle helpers themselves bail out early for any ROCm config that is not gfx1250 MXFP8 — preserving prior behavior for gfx94x/gfx95x.
  • JAX bindings plumb the same idea through XLA. xla_buffer_to_nvte_gemm_operand now builds a second TensorWrapper pointed at the workspace-allocated swizzle_scale_ptr, calls nvte_swizzle_scaling_factors, and rebinds the input scale_inv to the swizzled pointer so downstream GEMM sees the post-swizzle layout. GemmV2FFI and the Python-side workspace sizing reserve lhs_scale_inv.size + rhs_scale_inv.size additional bytes — same accounting trick the NVFP4 path already uses.
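
The row-wise index mapping described in the walkthrough can be sketched on the CPU as follows. This is a minimal illustration, not the kernel from this PR: the function name `swizzle_rowwise_reference` is hypothetical, the input is assumed to be a compact row-major [M, K_scale] buffer, and K_scale is assumed to be already padded to a multiple of 4.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical CPU reference for the row-wise case (name is an assumption).
// `scales` holds E8M0 bytes in compact row-major [M, K_scale] order.
std::vector<uint8_t> swizzle_rowwise_reference(const std::vector<uint8_t>& scales,
                                               std::size_t M, std::size_t K_scale) {
  assert(K_scale % 4 == 0);              // the wrapper validates this in the PR
  assert(scales.size() == M * K_scale);  // compact input assumed here
  // 0x7F is the E8M0 encoding of 2^0 = 1.0, so untouched slots stay neutral.
  std::vector<uint8_t> out(M * K_scale, 0x7F);
  for (std::size_t m = 0; m < M; ++m) {
    for (std::size_t k = 0; k < K_scale; ++k) {
      // K-tiled 3D layout, group of 4: dst = (k/4)*(M*4) + m*4 + (k%4)
      const std::size_t dst = (k / 4) * (M * 4) + m * 4 + (k % 4);
      out[dst] = scales[m * K_scale + k];
    }
  }
  return out;
}
```

With M = 2 and K_scale = 8, the first four scales of row 1 land directly after the first four scales of row 0, and the second K-group of both rows follows.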

Testing.

  • tests/cpp/operator/test_swizzle.cu gains a HIP-only MxSwizzleTestSuite with CPU reference implementations of the row- and column-wise Tensile-3D permutation, padding-zone validation, and parameterized sweeps. The CMake change removes test_swizzle.cu from the ROCm exclusion list so it actually compiles on AMD.
  • tests/cpp/operator/test_cublaslt_gemm.cu adds a swizzle_mxfp8_scales helper that calls nvte_swizzle_scaling_factors after the reference GEMM is computed and copies results back in place, broadens the basic MXFP8 size set, and introduces a ProdDqGEMMTestSuite parameterized via testing::Combine over DeepSeek3 and Qwen3 shapes x micro-batch sizes 1/2/4 x TN/NN/NT layouts. Tolerances are bumped for gfx1250 BF16 output, and a hardcoded kGfx950Skips set marks GEMM shapes known to fail on gfx950 hipBLASLt.

Notes for reviewers.

  • The gfx1250 TN-only override in rocm_gemm.cu assumes the columnwise scale/data buffer is always populated whenever an MXFP8 operand reaches GEMM in N (for A) or T (for B). The existing NVTE_CHECK(A.has_columnwise_data(), ...) guard now enforces that, but it is worth a second look for call sites that previously got away with row-only quantization on ROCm.
  • Padding scales to multiples of 4 changes the on-disk size of scale_inv tensors; the FSDP2 split path was updated for that, but anything that snapshots/loads scale_inv tensors (checkpointing, state-dict round-trips) on gfx1250 will see larger buffers than before.
  • The with_gemm_swizzled_scales flag is now load-bearing on ROCm: setting it incorrectly will either re-swizzle already-swizzled scales or skip needed swizzling. The early-return in swizzle_scales_for_gemm already checks this, but JAX callers set the flag manually after the swizzle.
  • The gfx950 skip list is shape-name-based and string-coupled to the production test parameterization — if the shape labels change, skips silently stop matching.
  • Author reports manual gfx1250 validation; no end-to-end PyTorch model test was added.

Generated by Claude. To request a code review, comment /claude review.


uint8_t val;
if constexpr (kRowwise) {
val = input[idx]; // == input[m * orig_K + k]


Wrong input stride when padded_K_scale > orig_K_scale.

The kernel reads input[idx] = input[m * orig_K + k], treating the input as if its row stride were orig_K. But the actual scale_inv buffer is allocated with the padded scale shape [padded_M, padded_K] (see MXFP8Quantizer::get_scale_shape and mxfp8_tensor.py), so its row stride is padded_K (= the m parameter that's not passed — only padded_M is). For row m > 0 this reads from the wrong memory location.

Concrete example: data shape [3, 96] → orig_M=3, orig_K=3, padded_M=4, padded_K=4. Buffer stride is 4, but the kernel reads input[1] for logical (m=1, k=0) instead of input[4].

This is masked in practice because:

  • All test_swizzle.cu cases use orig_K already a multiple of 4 (so padded_K == orig_K), and the test allocates the input buffer with size orig_M * orig_K rather than padded_M * padded_K — matching the kernel's incorrect stride assumption rather than the production layout.
  • All production LLM shapes here have K_data as a multiple of 128, so K_scale % 4 == 0 and no K-side padding occurs.

But the kernel will misbehave for K_data ∈ {32, 64, 96, 160, ...} (any K that's a multiple of 32 but not 128). Suggested fix: also pass the padded scale K (= k in the caller) as a stride parameter, e.g.:

swizzle_scaling_mx_kernel(input, output, padded_M, padded_K, orig_M, orig_K);
...
if constexpr (kRowwise) {
    val = input[m * padded_K + k];
} else {
    val = input[k * padded_M + m];  // padded_M for columnwise (currently OK because K_data%32==0 ⇒ padded_M==orig_M, but better to be explicit)
}

And update the test to allocate d_input of size padded_M * padded_K and exercise non-multiple-of-4 orig_K (e.g. {3,3}, {8,5}, {32,7}) to lock this down.


// Scale dimensions (M_scale, K_scale).
// K_scale will be padded to multiple of 4 by the test.
std::vector<std::pair<int, int>> mx_scale_dims = {

Coverage gap: no padded case is exercised.

Every entry in mx_scale_dims has orig_K already a multiple of 4, so K = roundup_sz(orig_K, 4) == orig_K. The test also allocates d_input of size orig_M * orig_K (line 248) rather than the padded M * K, so the test only validates the kernel against a tightly-packed (unpadded) input layout — not the padded [padded_M, padded_K] layout used by MXFP8Quantizer::get_scale_shape in production. Likewise no case has orig_M % 4 != 0.

Please add cases with orig_K not a multiple of 4 (e.g. {4,3}, {8,5}, {32,7}) and at least one with orig_M % 4 != 0 (e.g. {3,4}), allocating the input with the padded stride matching production. This would have caught the stride bug flagged in swizzle.cu:427.
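
One possible shape for such a padded test input is sketched below. Everything here is an assumption for illustration: the `PaddedScales` struct, the `make_padded_scales` helper, the fill pattern, and the choice to fill the padding with 0x7F are not taken from the PR's test code.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical test fixture: a scale buffer laid out with the padded stride.
struct PaddedScales {
  std::size_t padded_M;
  std::size_t padded_K;
  std::vector<uint8_t> data;  // row-major [padded_M, padded_K]
};

// Builds a [padded_M, padded_K] buffer (both rounded up to multiples of 4).
// Padding slots get `pad_value` (assumed 0x7F here); logical entries get a
// deterministic pattern so a reference check can tell them apart.
PaddedScales make_padded_scales(std::size_t orig_M, std::size_t orig_K,
                                uint8_t pad_value = 0x7F) {
  PaddedScales s;
  s.padded_M = (orig_M + 3) / 4 * 4;
  s.padded_K = (orig_K + 3) / 4 * 4;
  s.data.assign(s.padded_M * s.padded_K, pad_value);
  for (std::size_t m = 0; m < orig_M; ++m) {
    for (std::size_t k = 0; k < orig_K; ++k) {
      // Index with the padded row stride, matching the production allocation.
      s.data[m * s.padded_K + k] = static_cast<uint8_t>((m * orig_K + k) % 127);
    }
  }
  return s;
}
```

For the {3,3} case this yields a 4x4 buffer in which logical element (m=1, k=0) sits at offset 4, so a kernel that strides by orig_K would read the wrong byte.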

// Relax for gfx1250
cudaDeviceProp prop;
(void)cudaGetDeviceProperties(&prop, 0);
if (prop.major == 12 && type == DType::kBFloat16) {

This relaxation fires for every FP8 GEMM test on any gfx12 device (tensor-scaling FP8 included), not just MXFP8 on gfx1250. The comment ("Relax for gfx1250") and the PR scope suggest the intent is the gfx1250 MXFP8 path specifically. Consider guarding with use_mxfp8 and/or prop.major == 12 && prop.minor == 5 so non-MXFP8 FP8 tests don't silently lose precision coverage on this arch.
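
A narrower guard along the lines suggested above might look like the following sketch. The function name `should_relax_tolerance` and its parameters are hypothetical, and the assumption that gfx1250 reports major == 12 and minor == 5 comes from this comment rather than from verified device properties.

```cpp
#include <cassert>

// Hypothetical predicate; not the PR's getTestTolerances logic.
// Relaxes tolerances only for MXFP8 GEMMs with BF16 output on a device that
// reports compute capability 12.5 (assumed here to identify gfx1250).
constexpr bool should_relax_tolerance(int major, int minor, bool use_mxfp8,
                                      bool is_bf16_output) {
  return major == 12 && minor == 5 && use_mxfp8 && is_bf16_output;
}
```

With this form, tensor-scaling FP8 tests on the same device keep their original tolerances.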

for (size_t i = 0; i < input.size(); i++) {
if (is_mxfp8_scaling(input[i]->scaling_mode)) {
any_mxfp8 = true;
}

If any_mxfp8 is true, every tensor in the batch is dispatched to swizzle_scaling_factors_mx, which asserts scaling_mode == NVTE_MXFP8_1D_SCALING. The single-tensor multi_tensor_swizzle_scaling_factors contract (a few lines below) only requires each tensor to be (fp8 && mxfp8) || (fp4 && nvfp4) — i.e. a batch can in principle mix MXFP8 and NVFP4 tensors. On gfx1250 such a mix would crash inside the MX helper.

Probably theoretical today (NVFP4 isn't supported on gfx1250 yet), but the safer form is either to assert all_mxfp8 here, or to dispatch per-tensor through swizzle_scaling_factors (which already routes MXFP8→MX on gfx1250 and leaves NVFP4 on the existing path).
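
The per-tensor alternative can be sketched with a small routing function. The enums and the `choose_paths` helper below are illustrative stand-ins, not types or functions from Transformer Engine; the point is only that the path is chosen per tensor instead of from a batch-wide any_mxfp8 flag.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-ins for the real scaling-mode and dispatch types.
enum class ScalingMode { kMxfp8, kNvfp4 };
enum class SwizzlePath { kMxGfx1250, kExisting };

// Picks a swizzle path for each tensor individually. Only MXFP8 tensors on
// gfx1250 take the MX path; everything else stays on the existing path.
std::vector<SwizzlePath> choose_paths(const std::vector<ScalingMode>& modes,
                                      bool is_gfx1250) {
  std::vector<SwizzlePath> paths;
  paths.reserve(modes.size());
  for (std::size_t i = 0; i < modes.size(); ++i) {
    const bool use_mx = is_gfx1250 && modes[i] == ScalingMode::kMxfp8;
    paths.push_back(use_mx ? SwizzlePath::kMxGfx1250 : SwizzlePath::kExisting);
  }
  return paths;
}
```

A mixed batch then routes each tensor to a helper that accepts its scaling mode, so the MX helper's assertion is never reached with an NVFP4 tensor.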

Comment on lines +661 to +664
if (!a_colwise) swizzle_mxfp8_scales(A, true);
if (a_colwise) swizzle_mxfp8_scales(A, false);
if (!b_colwise) swizzle_mxfp8_scales(B, true);
if (b_colwise) swizzle_mxfp8_scales(B, false);

Nit: each pair of if (!x) ...; if (x) ...; lines unconditionally calls swizzle_mxfp8_scales with the negation of x. The same lines exist again at 776-779.

Suggested change
- if (!a_colwise) swizzle_mxfp8_scales(A, true);
- if (a_colwise) swizzle_mxfp8_scales(A, false);
- if (!b_colwise) swizzle_mxfp8_scales(B, true);
- if (b_colwise) swizzle_mxfp8_scales(B, false);
+ swizzle_mxfp8_scales(A, !a_colwise);
+ swizzle_mxfp8_scales(B, !b_colwise);

@github-actions

Claude review

Reviewed the gfx1250 MXFP8 pre-swizzle plumbing (GEMM path + scale padding + new swizzle_scaling_factors_mx kernel + production-shape GEMM tests). Overall the structure is clean: ROCm-specific behavior is consistently guarded behind #ifdef __HIP_PLATFORM_AMD__ / USE_ROCM plus a runtime sm_arch() == 125 check, and the JAX workspace fix flagged earlier looks consistent on both Python and C++ sides.

Findings (see inline comments):

  • swizzle_scaling_mx_kernel row-wise input stride (swizzle.cu:427) — kernel uses orig_K as the input row stride, but the production scale_inv buffer is allocated with the padded shape [padded_M, padded_K]. Latent unless K_data is not a multiple of 128 (the typical LLM case is fine), but worth fixing.
  • test_swizzle.cu coverage gap — none of the parametrized cases exercise padded_K > orig_K or padded_M > orig_M, so the kernel is only validated against an unpadded layout. This is what masks the bug above.
  • getTestTolerances relaxation is too broad — fires for all FP8 GEMM tests on any gfx12 device, not just MXFP8.
  • multi_tensor_swizzle_scaling_factors gfx1250 path — dispatches all tensors to MX if any tensor is MXFP8; would crash on a mixed MXFP8/NVFP4 batch. Theoretical today but fragile.
  • Style nit: swizzle_mxfp8_scales(t, !x) collapses the two-line if-pair in performTest/performDqTest.

Copyright headers: OK (all 12 changed files have correct AMD 2025-2026 or 2026 lines, NVIDIA lines preserved).
