[fix] Fix CUTLASS grouped GEMM segfault for empty groups #3037
Signed-off-by: yangfan.bai <yangfan.bai@shopee.com>
Greptile Summary
This PR fixes a native segfault in the CUTLASS grouped GEMM path when MoE routing produces a microbatch where every expert receives zero tokens. The PyTorch wrapper already filtered zero-token GEMMs individually, but did not handle the case where all groups were filtered, allowing
Confidence Score: 4/5
The fix is narrowly scoped, targets a real crash path, and does not change behaviour for non-empty inputs. Safe to merge. Both changes are small, isolated guards with no effect on the non-empty GEMM path. The early return in te_general_grouped_gemm correctly mirrors the existing return value semantics. The only notable weakness is that the TN/NN test assertions are trivially true (empty-vs-empty compare), so those sub-cases are purely a no-crash check rather than a semantic postcondition. No files require special attention beyond the test assertion coverage noted in the inline comment on test_numerics.py.
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A["te_general_grouped_gemm(A, B, D, ...)"] --> B["Loop over all GEMM groups"]
B --> C{"te_A.numel()==0 or\nte_B.numel()==0?"}
C -- Yes --> D["zero_() output if non-empty\nzero_() bias/gelu if grad\ncontinue"]
D --> B
C -- No --> E["Build te_A/B/D wrappers\nappend to vectors"]
E --> B
B -- loop done --> F{"te_A_wrappers.empty()?\n(ALL groups filtered)"}
F -- Yes --> G["return bias ← NEW early return\n(prevents null deref in CUTLASS path)"]
F -- No --> H["swizzle scales\nbuild NVTETensor vectors"]
H --> I["nvte_multi_tensor_gemm(...)"]
I --> J{"num_gemms <= 0?\n← NEW guard"}
J -- Yes --> K["return (no-op)"]
J -- No --> L{"is_hopper &&\nuse_cutlass?"}
L -- No --> M["multi_stream_cublas_gemm"]
L -- Yes --> N["CUTLASS grouped GEMM\n(accesses A[0]/B[0]/D[0])"]
Reviews (1): Last reviewed commit: "[fix] fix grouped GEMM zero-work bug."
    for tensor in out:
        torch.testing.assert_close(tensor, torch.zeros_like(tensor), rtol=0, atol=0)


def _pack_grouped_tensor(grouped_tensor: GroupedTensor, tensors: List[torch.Tensor]) -> None:
    data = grouped_tensor.rowwise_data
Zero-assertion is trivially true for TN and NN layouts
For TN and NN, out is constructed as a list containing a single 0-element tensor (torch.empty(0, n/k, ...)). torch.testing.assert_close on two empty tensors passes unconditionally regardless of any computation, so those two sub-cases only serve as crash/segfault guards. The meaningful assertion only fires for NT, where out[0] is a full (n, k) buffer that the C++ code zeros in-place. Consider either documenting this in a comment or, for TN/NN, adding a small non-empty output tensor and asserting it is zero to provide the same level of postcondition coverage as NT.
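The point about vacuous assertions can be illustrated without any GPU or PyTorch dependency. The sketch below is a pure-Python analogy, not the actual test: `assert_all_zero` and `stale_kernel` are hypothetical names, and plain lists stand in for tensors. It shows that an element-wise zero check cannot detect a kernel that leaves stale data behind when the output has no elements, while a small non-empty output can.

```python
from typing import List


def assert_all_zero(values: List[float]) -> None:
    """Element-wise zero check; it has nothing to inspect when `values` is empty."""
    for index, value in enumerate(values):
        assert value == 0.0, f"element {index} is {value}, expected 0.0"


def stale_kernel(out: List[float]) -> List[float]:
    """Hypothetical buggy kernel: it should zero `out` but leaves it untouched."""
    return out


def check_detects_bug(out: List[float]) -> bool:
    """Return True if the zero check catches the buggy kernel for this output."""
    try:
        assert_all_zero(stale_kernel(out))
    except AssertionError:
        return True
    return False


# Empty output (analogous to the TN/NN sub-cases): the bug goes unnoticed.
empty_case_detects = check_detects_bug([])
# Non-empty output (analogous to NT): the stale values are caught.
nonempty_case_detects = check_detects_bug([3.0, -1.5])
```

With an empty output the check passes even though the kernel did nothing, which is why those sub-cases only guard against a crash.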
Hi @Baibaifan, could you resolve the conflicts?
Description
Handle grouped GEMM calls where all groups are empty.
MoE routing can legally produce a microbatch where no local expert receives
tokens. The PyTorch grouped GEMM wrapper filters those zero-token GEMMs, but
the CUTLASS grouped GEMM path could still be reached with num_gemms == 0 and
then dereference A[0]/B[0]/D[0], causing a native segfault.
Return early after filtering all GEMMs in te_general_grouped_gemm, and add a
defensive num_gemms <= 0 guard in nvte_multi_tensor_gemm.

Add a Hopper/CUTLASS regression test covering all-empty grouped GEMM inputs for
TN, NN, and NT layouts.
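The two guards described above can be sketched in plain Python. This is only a model of the control flow, not the actual C++ change: nested lists stand in for tensors, and `grouped_gemm`, `backend_multi_gemm`, and `numel` are hypothetical names chosen for illustration. The wrapper filters empty groups and returns early if none remain; the backend additionally refuses to touch element 0 when it receives no work.

```python
from typing import List, Optional

Matrix = List[List[float]]  # row-major nested lists stand in for tensors


def numel(m: Matrix) -> int:
    """Number of elements in a matrix; 0 for an empty group."""
    return sum(len(row) for row in m)


def backend_multi_gemm(a: List[Matrix], b: List[Matrix], d: List[Matrix]) -> None:
    """Stand-in for a backend that reads group 0 unconditionally."""
    num_gemms = len(a)
    if num_gemms <= 0:  # defensive guard: nothing to do, do not read a[0]
        return
    _ = (a[0], b[0], d[0])  # without the guard this would fail on empty input
    for ai, bi, di in zip(a, b, d):
        for r, row in enumerate(ai):
            for c in range(len(bi[0])):
                di[r][c] = sum(row[k] * bi[k][c] for k in range(len(bi)))


def grouped_gemm(
    a: List[Matrix], b: List[Matrix], d: List[Matrix],
    bias: Optional[List[float]] = None,
) -> Optional[List[float]]:
    """Filter empty groups, zero their outputs, and skip the backend if none remain."""
    kept_a, kept_b, kept_d = [], [], []
    for ai, bi, di in zip(a, b, d):
        if numel(ai) == 0 or numel(bi) == 0:
            for row in di:  # zero a non-empty output in place
                for c in range(len(row)):
                    row[c] = 0.0
            continue
        kept_a.append(ai)
        kept_b.append(bi)
        kept_d.append(di)
    if not kept_a:  # every group was filtered: early return
        return bias
    backend_multi_gemm(kept_a, kept_b, kept_d)
    return bias
```

Either guard alone would avoid the out-of-range access in this model; keeping both means the backend stays safe even if another caller passes it zero GEMMs.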
Type of change
Changes
Please list the changes introduced in this PR:
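- Return early in te_general_grouped_gemm when all GEMMs were filtered.
- Add a defensive num_gemms <= 0 guard in nvte_multi_tensor_gemm.
- Add a Hopper/CUTLASS regression test that runs with NVTE_USE_CUTLASS_GROUPED_GEMM=1.
- Cover all-empty grouped GEMM inputs for TN, NN, and NT layouts.
Testing
pytest -q tests/pytorch/test_numerics.py::test_grouped_gemm_cutlass_empty_groups -s
Checklist: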