NVFP4 default recipe (RHT + stochastic rounding) crashes on sm_120: fused kernel exceeds 101376-byte shared-memory opt-in cap #3062

@Infatoshi

Description

On sm_120 (RTX PRO 6000 Blackwell Workstation / GeForce Blackwell), any NVFP4 GEMM that enables the Random Hadamard Transform or stochastic rounding crashes inside the fused RHT kernel:

RuntimeError: .../common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu:1200
  in function row_col_rht_gemm_ntt_w_sfc: CUDA Error: invalid argument

Since these are on by default in NVFP4BlockScaling() (backward-gradient quant uses random_hadamard_transform=True, stochastic_rounding=True), the default NVFP4 recipe is unusable for training on sm_120. The plain NVFP4 GEMM (2D weight scaling + round-to-nearest) works fine.

Minimal repro

import torch, transformer_engine.pytorch as tep
from transformer_engine.common.recipe import NVFP4BlockScaling

lin = tep.Linear(512, 512, bias=False, params_dtype=torch.float32).cuda()
x = torch.randn(512, 512, device="cuda", dtype=torch.bfloat16, requires_grad=True)
with tep.fp8_autocast(enabled=True, fp8_recipe=NVFP4BlockScaling()):  # RHT+SR default
    y = lin(x)
y.float().pow(2).mean().backward()      # -> CUDA Error: invalid argument

# Works if RHT and SR are disabled:
# NVFP4BlockScaling(disable_rht=True, disable_stochastic_rounding=True)

Root cause

The failing line is the dynamic-shared-memory opt-in:

// row_cast_col_hadamard_transform_cast_fusion.cu:1200
NVTE_CHECK_CUDA(cudaFuncSetAttribute(*kernel_ptr,
    cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));   // smem_size = sizeof(SharedStorage<...>)

smem_size is sized for the sm_100 datacenter shared-memory budget (up to 232448 bytes, i.e. 227 KB). On sm_120 the opt-in cap is 101376 bytes (99 KB), so cudaFuncSetAttribute(MaxDynamicSharedMemorySize, smem_size) returns cudaErrorInvalidValue ("invalid argument"). Confirmed with a standalone probe on the device:

device: NVIDIA RTX PRO 6000 Blackwell Workstation Edition  sm_120
sharedMemPerBlockOptin = 101376 bytes
cudaFuncSetAttribute(MaxDynShmem= 101376) -> no error
cudaFuncSetAttribute(MaxDynShmem= 101377) -> invalid argument
cudaFuncSetAttribute(MaxDynShmem= 232448) -> invalid argument

The kernel cubins exist for the device family (cuobjdump -elf shows sm_120f instantiations of row_col_rht_gemm/col_hadamard alongside sm_100a/sm_103a), so this is a launch-configuration / shared-memory-budget problem, not a missing image. The kernel's SharedStorage (multiple mainloop/accumulator/scheduler pipeline stages) simply doesn't fit in sm_120's 99 KB.
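To make the budget gap concrete, here is a rough sketch of how many pipeline stages fit under each cap. It assumes SharedStorage can be modeled as a fixed overhead plus identical per-stage buffers. The `FIXED_BYTES` and `STAGE_BYTES` values are hypothetical placeholders, not measurements of the real kernel, and `max_stages` is an illustrative helper, not a Transformer Engine function.

```python
SM120_OPTIN_CAP = 101376  # bytes, sharedMemPerBlockOptin reported by the probe
SM100_OPTIN_CAP = 232448  # bytes, the datacenter-sized request from the probe


def max_stages(cap_bytes: int, fixed_bytes: int, stage_bytes: int) -> int:
    """Largest stage count such that fixed + stages * stage_bytes <= cap."""
    if stage_bytes <= 0:
        raise ValueError("stage_bytes must be positive")
    usable = cap_bytes - fixed_bytes
    return max(usable // stage_bytes, 0)


# Hypothetical sizes, for illustration only.
FIXED_BYTES = 4096    # barriers, scheduler state, etc. (assumed)
STAGE_BYTES = 32768   # one mainloop stage of tile buffers (assumed)

print(SM120_OPTIN_CAP // 1024, "KB on sm_120")               # 99 KB on sm_120
print(SM100_OPTIN_CAP // 1024, "KB on sm_100")               # 227 KB on sm_100
print(max_stages(SM100_OPTIN_CAP, FIXED_BYTES, STAGE_BYTES))  # 6
print(max_stages(SM120_OPTIN_CAP, FIXED_BYTES, STAGE_BYTES))  # 2
```

With these assumed sizes, a configuration tuned for six stages on sm_100 would have to drop to two on sm_120. The real numbers depend on the kernel's tile shapes and pipeline layout.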

Two secondary points:

  1. check_recipe_support() has no NVFP4BlockScaling branch, so there is no early, descriptive error — users hit a raw CUDA error deep in the kernel.
  2. MXFP8BlockScaling is already reported unsupported on 12.0; NVFP4+RHT/SR is the analogous gap.

Environment

  • GPU: NVIDIA RTX PRO 6000 Blackwell Workstation Edition (sm_120, compute_cap 12.0, 96 GB)
  • Transformer Engine 2.15.0 (prebuilt transformer_engine_cu13 wheel + transformer_engine_torch built from sdist)
  • PyTorch 2.11.0+cu130, CUDA 13.0 runtime / 13.2 toolkit, driver 580.159.03, Ubuntu 24.04

Suggested fix

  • Provide an sm_120 tile/pipeline-stage config for the fused RHT/SR GEMM whose SharedStorage fits within 101376 bytes (fewer mainloop/accumulator stages), or a non-fused fallback (separate RHT + cast + GEMM) on sm_120.
  • Add an NVFP4BlockScaling branch to check_recipe_support() that, on sm_120 with RHT/SR enabled, raises a clear error (or auto-disables RHT/SR with a warning) instead of letting cudaFuncSetAttribute fail cryptically.
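The second bullet could look roughly like the sketch below. This is not the existing check_recipe_support() code: `check_nvfp4_support` and its parameters are hypothetical names, and the sketch only encodes the sm_120 condition described in this issue.

```python
SM120_OPTIN_CAP = 101376  # bytes, from the probe in this issue


def check_nvfp4_support(compute_capability, rht_enabled, sr_enabled):
    """Return (supported, reason) for an NVFP4 recipe on a given device.

    Hypothetical helper: flags the one combination this issue reports
    as failing (sm_120 with RHT or stochastic rounding enabled).
    """
    if tuple(compute_capability) == (12, 0) and (rht_enabled or sr_enabled):
        reason = (
            "NVFP4 with RHT or stochastic rounding is not supported on sm_120: "
            f"the fused kernel needs more than {SM120_OPTIN_CAP} bytes of "
            "dynamic shared memory. Disable RHT and stochastic rounding."
        )
        return False, reason
    return True, ""


supported, reason = check_nvfp4_support((12, 0), rht_enabled=True, sr_enabled=True)
if not supported:
    print(reason)
```

In a real integration the capability tuple would presumably come from `torch.cuda.get_device_capability()`, and the caller would decide whether to raise or to auto-disable with a warning.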

Workaround

Disable RHT and stochastic rounding on sm_120: NVFP4BlockScaling(disable_rht=True, disable_stochastic_rounding=True). The remaining NVFP4 path (2D weight scaling + round-to-nearest-even, RNE) trains stably; keeping numerically sensitive layers (first/last blocks) in higher precision recovers convergence in our tests.
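For scripts that run on both sm_100 and sm_120, the workaround can be applied conditionally. A minimal sketch follows; `nvfp4_recipe_kwargs` is a hypothetical helper, and only the two keyword names are taken from the workaround above.

```python
def nvfp4_recipe_kwargs(compute_capability):
    """Keyword arguments for NVFP4BlockScaling on the given device.

    Hypothetical helper: returns the two disable flags on sm_120 and
    nothing elsewhere, so other devices keep the default recipe.
    """
    if tuple(compute_capability) == (12, 0):
        return {"disable_rht": True, "disable_stochastic_rounding": True}
    return {}


print(nvfp4_recipe_kwargs((12, 0)))
print(nvfp4_recipe_kwargs((10, 0)))

# Intended use (needs a GPU and Transformer Engine, so left as a comment):
#   recipe = NVFP4BlockScaling(**nvfp4_recipe_kwargs(torch.cuda.get_device_capability()))
```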
