
[draft] NVFP4 block-16 scale support for SM90 mixed-input grouped GEMM (CUTLASS 4.4.2)#2

Draft: changjonathanc wants to merge 1 commit into poolside-cutlass-v4.4.2 from nvfp4-w4a8-v4.4.2


@changjonathanc


Redo of #1 on a CUTLASS 4.4.2 base (the original #1 targeted 4.3.5).

Why 4.4.2 (not 4.3.5)

Measured on H200 with the Colonels W4A8 MoE kernel (capture-replay, isolated, locked-clock):
4.4.2's mixed-input grouped GEMM is ~5.3x faster than 4.3.5's. On 4.4.2 the W4A8 path
is ~2x faster than Marlin W4A16 at large batch (matching the kernel-benchmark premise); on
4.3.5 it is ~2.6x slower than Marlin W4A16, so building W4A8 against 4.3.5 is a net
regression. The fork base must therefore be 4.4.2.

Contents

  • Base branch poolside-cutlass-v4.4.2: NVIDIA CUTLASS v4.4.2 + the existing poolside
    TensorMapStorage shared-memory workaround (cherry-picked from the 4.3.5 fork branch).
  • This PR (+71/-14, two headers): the NVFP4 block-16 mixed-input patch. All new behaviour
    is gated by UseNvfp4Block16Scales/Broadcast, false for every existing instantiation, so
    existing kernels are byte-for-byte unaffected.

  • sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input.hpp: UseNvfp4Block16Scales /
    ScaleAtomM / ScaleAtomK; multi-column SmemLayoutAtomScale; relaxed static_assert;
    grouped-GEMM init clamp; explicit StrideScale; relaxed can_implement chunk-size check.
  • mixed_input_utils.hpp: UseNvfp4Block16ScaleBroadcast + get_mma_smem_layout_scale()
    broadcast view; refresh scales every k-block when multi-column.
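
For readers unfamiliar with the gating pattern, here is a minimal sketch of how a compile-time flag that defaults to false can leave existing instantiations untouched while enabling multiple scale columns per K tile. It is not the code in this PR: `ScaleConfigSketch`, `ScaleColumnsPerKTile` and `RefreshScalesEveryKBlock` are hypothetical names, and the one-scale-column default for the existing path is an assumption made for illustration.

```cpp
// Hypothetical stand-in for the scale-related part of a collective.
// This is NOT the CUTLASS class; names and defaults are illustrative only.
template <int TileK, bool UseNvfp4Block16Scales = false>
struct ScaleConfigSketch {
  // NVFP4 uses one scale per 16 elements along K.
  static constexpr int ScaleBlockK = 16;

  static_assert(!UseNvfp4Block16Scales || TileK % ScaleBlockK == 0,
                "K tile must be a multiple of the 16-element scale block");

  // Assumption: the existing path needs a single scale column per K tile.
  // With block-16 scales, one column is needed per 16 elements of K.
  static constexpr int ScaleColumnsPerKTile =
      UseNvfp4Block16Scales ? TileK / ScaleBlockK : 1;

  // With more than one column per tile, the scales change inside the tile,
  // so they have to be reloaded for every k-block.
  static constexpr bool RefreshScalesEveryKBlock = ScaleColumnsPerKTile > 1;
};
```

Because the flag is a template parameter with a false default, code such as `ScaleConfigSketch<64>` compiles to the same constants as before, and only an explicit `ScaleConfigSketch<64, true>` opts into the multi-column behaviour.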

Supersedes #1. The Colonels/Forge W4A8 extension builds against this branch.

🤖 Generated with Claude Code

Adds optional NVFP4 (e2m1 + e4m3-scale) block-16 scaling to the SM90 mixed-input
collectives for weight-scaled W4A8 grouped GEMM (scale block 16 < GMMA K tile).
Gated by UseNvfp4Block16Scales/Broadcast (false for all existing instantiations,
so existing kernels are byte-for-byte unaffected). Base: CUTLASS v4.4.2 +
TensorMapStorage workaround. 4.4.2 is required: its mixed-input GEMM is ~5x faster
than 4.3.5, making W4A8 ~2x faster than Marlin at large batch.
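
As a host-side reference for what "e2m1 + e4m3-scale, block 16" means numerically, the following sketch decodes a 4-bit e2m1 value and applies the scale of its 16-element block along K. It makes simplifying assumptions that are not taken from this PR: nibbles are stored one per byte, scales are already converted to `float` rather than kept as e4m3, and all function names are hypothetical.

```cpp
#include <cstdint>

// E2M1 magnitudes indexed by the low three bits (2 exponent bits, 1 mantissa bit).
constexpr float kE2m1Magnitudes[8] = {0.0f, 0.5f, 1.0f, 1.5f,
                                      2.0f, 3.0f, 4.0f, 6.0f};

// One scale covers 16 consecutive elements along K.
constexpr int kScaleBlockK = 16;

// Decode a 4-bit e2m1 value held in the low nibble; bit 3 is the sign.
constexpr float decode_e2m1(std::uint8_t nibble) {
  return (nibble & 0x8) ? -kE2m1Magnitudes[nibble & 0x7]
                        : kE2m1Magnitudes[nibble & 0x7];
}

// Index of the scale that applies to element k of a row.
constexpr int scale_index_for_k(int k) { return k / kScaleBlockK; }

// Dequantize element k of one weight row.
// Assumptions: one nibble per byte, scales pre-converted from e4m3 to float.
constexpr float dequantize_at(const std::uint8_t* nibbles,
                              const float* block_scales, int k) {
  return decode_e2m1(nibbles[k]) * block_scales[scale_index_for_k(k)];
}
```

With a K tile larger than 16, several consecutive values of `scale_index_for_k` fall inside one tile, which is the situation the commit message describes as "scale block 16 < GMMA K tile".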
