Skip to content

[fix] Fix CUTLASS grouped GEMM segfault for empty groups#3037

Open
Baibaifan wants to merge 1 commit into
NVIDIA:mainfrom
Baibaifan:zero_groupgemm
Open

[fix] Fix CUTLASS grouped GEMM segfault for empty groups#3037
Baibaifan wants to merge 1 commit into
NVIDIA:mainfrom
Baibaifan:zero_groupgemm

Conversation

@Baibaifan
Copy link
Copy Markdown

Description

Handle grouped GEMM calls where all groups are empty.

MoE routing can legally produce a microbatch where no local expert receives
tokens. The PyTorch grouped GEMM wrapper filters those zero-token GEMMs, but
the CUTLASS grouped GEMM path could still be reached with num_gemms == 0 and
then dereference A[0]/B[0]/D[0], causing a native segfault.

Return early after filtering all GEMMs in te_general_grouped_gemm, and add a
defensive num_gemms <= 0 guard in nvte_multi_tensor_gemm.

Add a Hopper/CUTLASS regression test covering all-empty grouped GEMM inputs for
TN, NN, and NT layouts.

Type of change

  • Bug fix (non-breaking change which fixes an issue)

Changes

Please list the changes introduced in this PR:

  • Return early from te_general_grouped_gemm when all GEMMs were filtered.
  • Add a defensive num_gemms <= 0 guard in nvte_multi_tensor_gemm.
  • Add a Hopper-only regression test for all-empty grouped GEMM inputs under
    NVTE_USE_CUTLASS_GROUPED_GEMM=1.
  • Cover TN, NN, and NT layouts.

Testing

pytest -q tests/pytorch/test_numerics.py::test_grouped_gemm_cutlass_empty_groups -s

Checklist:

  • I have read and followed the [contributing guidelines]
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@Baibaifan Baibaifan requested a review from ksivaman as a code owner May 22, 2026 04:06
@github-actions github-actions Bot added the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label May 22, 2026
Signed-off-by: yangfan.bai <yangfan.bai@shopee.com>
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented May 22, 2026

Greptile Summary

This PR fixes a native segfault in the CUTLASS grouped GEMM path when MoE routing produces a microbatch where every expert receives zero tokens. The PyTorch wrapper already filtered zero-token GEMMs individually, but did not handle the case where all groups were filtered, allowing nvte_multi_tensor_gemm to be invoked with an empty tensor array and dereference A[0]/B[0]/D[0].

  • gemm.cpp: After the per-group filtering loop, an early return is added when te_A_wrappers is empty, returning bias to match the function's existing normal return value.
  • cublaslt_gemm.cu: A num_gemms <= 0 guard is added at the top of nvte_multi_tensor_gemm as a secondary defence for direct C-API callers.
  • test_numerics.py: A Hopper-only regression test exercises TN, NN, and NT layouts with all-zero m_splits; the NT case meaningfully asserts that the wgrad buffer is zeroed in-place, while TN/NN serve primarily as no-crash guards.

Confidence Score: 4/5

The fix is narrowly scoped, targets a real crash path, and does not change behaviour for non-empty inputs. Safe to merge.

Both changes are small, isolated guards with no effect on the non-empty GEMM path. The early return in te_general_grouped_gemm correctly mirrors the existing return value semantics. The only notable weakness is that the TN/NN test assertions are trivially true (empty-vs-empty compare), so those sub-cases are purely a no-crash check rather than a semantic postcondition.

No files require special attention beyond the test assertion coverage noted in the inline comment on test_numerics.py.

Important Files Changed

Filename Overview
transformer_engine/pytorch/csrc/extensions/gemm.cpp Adds an early return in te_general_grouped_gemm when all GEMM groups are filtered (te_A_wrappers is empty); returns the bias vector, consistent with the existing return at the bottom of the function. The fix is logically correct: empty groups are already handled by the in-loop continue + zero_(), and the early exit prevents a subsequent call to nvte_multi_tensor_gemm with a zero-length vector.
transformer_engine/common/gemm/cublaslt_gemm.cu Adds a defensive num_gemms <= 0 guard at the entry of nvte_multi_tensor_gemm. This is a secondary safety net for callers that bypass te_general_grouped_gemm and invoke the C API directly with an empty array.
tests/pytorch/test_numerics.py Adds a Hopper-only regression test for all-empty grouped GEMM. TN/NN out-tensors are 0-element so the zero-check is trivial (useful only as a no-crash guard); NT out-tensor is a real (n,k) buffer whose in-place zero_() is meaningfully exercised. Test env-var setup/teardown follows the existing pattern.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["te_general_grouped_gemm(A, B, D, ...)"] --> B["Loop over all GEMM groups"]
    B --> C{"te_A.numel()==0 or\nte_B.numel()==0?"}
    C -- Yes --> D["zero_() output if non-empty\nzero_() bias/gelu if grad\ncontinue"]
    D --> B
    C -- No --> E["Build te_A/B/D wrappers\nappend to vectors"]
    E --> B
    B -- loop done --> F{"te_A_wrappers.empty()?\n(ALL groups filtered)"}
    F -- Yes --> G["return bias  ← NEW early return\n(prevents null deref in CUTLASS path)"]
    F -- No --> H["swizzle scales\nbuild NVTETensor vectors"]
    H --> I["nvte_multi_tensor_gemm(...)"]
    I --> J{"num_gemms <= 0?\n← NEW guard"}
    J -- Yes --> K["return (no-op)"]
    J -- No --> L{"is_hopper &&\nuse_cutlass?"}
    L -- No --> M["multi_stream_cublas_gemm"]
    L -- Yes --> N["CUTLASS grouped GEMM\n(accesses A[0]/B[0]/D[0])"]
Loading

Reviews (1): Last reviewed commit: "[fix] fix grouped GEMM zero-work bug." | Re-trigger Greptile

Comment on lines +2850 to 2855
for tensor in out:
torch.testing.assert_close(tensor, torch.zeros_like(tensor), rtol=0, atol=0)


def _pack_grouped_tensor(grouped_tensor: GroupedTensor, tensors: List[torch.Tensor]) -> None:
data = grouped_tensor.rowwise_data
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Zero-assertion is trivially true for TN and NN layouts

For TN and NN, out is constructed as a list containing a single 0-element tensor (torch.empty(0, n/k, ...)). torch.testing.assert_close on two empty tensors passes unconditionally regardless of any computation, so those two sub-cases only serve as crash/segfault guards. The meaningful assertion only fires for NT, where out[0] is a full (n, k) buffer that the C++ code zeros in-place. Consider either documenting this in a comment or, for TN/NN, adding a small non-empty output tensor and asserting it is zero to provide the same level of postcondition coverage as NT.

@ptrendx
Copy link
Copy Markdown
Member

ptrendx commented May 29, 2026

Hi @Baibaifan, could you resolve the conflicts?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

community-contribution PRs from external contributor outside the core maintainers, representing community-driven work.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants