Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 148 additions & 7 deletions tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,13 @@ def check_nvfp4_row_scaled_grouped_gemm_matches_per_gemm(
torch.cuda.manual_seed(23)

num_gemms = len(m_splits)
uses_cudnn_grouped_path = (
out_dtype in (torch.bfloat16, torch.float16)
and not use_4over6
and all(m % 256 == 0 for m in m_splits)
and k % 128 == 0
and n % 128 == 0
)

x_quantizer = NVFP4Quantizer(
fp4_dtype=te_dtype,
Expand Down Expand Up @@ -301,40 +308,174 @@ def check_nvfp4_row_scaled_grouped_gemm_matches_per_gemm(
)
)
bias.append(torch.randn(n, dtype=torch.bfloat16, device=device) if use_bias else None)
if uses_cudnn_grouped_path:
expected.append(
general_gemm(
w_nvfp4[-1],
x_nvfp4[-1],
out_dtype=out_dtype,
layout="TN",
bias=bias[-1],
)[0]
)

if single_output:
out = [torch.empty((sum(m_splits), n), dtype=out_dtype, device=device)]
else:
out = [torch.empty((m, n), dtype=out_dtype, device=device) for m in m_splits]

grouped_gemm_args = (
w_nvfp4,
x_nvfp4,
out,
)
grouped_gemm_kwargs = {
"quantization_params": [None] * num_gemms,
"out_dtype": out_dtype,
"layout": "TN",
"m_splits": m_splits,
"bias": bias,
"use_bias": use_bias,
"single_output": single_output,
}
if not uses_cudnn_grouped_path:
with pytest.raises((NotImplementedError, ValueError)):
general_grouped_gemm(*grouped_gemm_args, **grouped_gemm_kwargs)
return

try:
import cudnn
except ImportError as exc:
pytest.skip(f"cudnn frontend unavailable: {exc}")
if not hasattr(cudnn, "grouped_gemm_quant_wrapper_sm100"):
pytest.skip("grouped_gemm_quant_wrapper_sm100 unavailable")

grouped_out, _, _ = general_grouped_gemm(*grouped_gemm_args, **grouped_gemm_kwargs)

if single_output:
grouped_slices = torch.split(grouped_out, m_splits, dim=0)
else:
grouped_slices = grouped_out
for grouped, ref in zip(grouped_slices, expected):
torch.testing.assert_close(grouped, ref, atol=0.5, rtol=0.25)


@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
    "use_bias, single_output",
    [(False, False), (True, True)],
    ids=["no_bias_list_output", "bias_single_output"],
)
def test_nvfp4_row_scaled_grouped_gemm_uses_cudnn_quant_wrapper(
    use_bias: bool,
    single_output: bool,
    monkeypatch,
) -> None:
    """Check that the row-scaled NVFP4 grouped GEMM calls the cuDNN quant wrapper.

    The test wraps ``cudnn.grouped_gemm_quant_wrapper_sm100`` with a tracer,
    runs ``general_grouped_gemm`` on two row-scaled NVFP4 problems, and asserts
    that the wrapper was invoked exactly once with the expected keyword
    arguments. The grouped result is then compared against per-GEMM
    ``general_gemm`` references.

    Args:
        use_bias: Whether a bias vector is supplied for each GEMM.
        single_output: Whether all GEMMs write into one concatenated output
            tensor (``True``) or into one tensor per GEMM (``False``).
        monkeypatch: pytest fixture used to install the tracing wrapper.

    The test is skipped when the device capability is below (10, 0), when the
    ``cudnn`` module cannot be imported, or when it does not expose
    ``grouped_gemm_quant_wrapper_sm100``.
    """
    if torch.cuda.get_device_capability() < (10, 0):
        pytest.skip("Requires SM100+ for cuDNN grouped GEMM quant kernel.")

    try:
        import cudnn
    except ImportError as exc:
        pytest.skip(f"cudnn frontend unavailable: {exc}")
    if not hasattr(cudnn, "grouped_gemm_quant_wrapper_sm100"):
        pytest.skip("grouped_gemm_quant_wrapper_sm100 unavailable")

    te_dtype = tex.DType.kFloat4E2M1
    device = "cuda"
    dtype = torch.bfloat16
    # NOTE(review): these sizes presumably satisfy the alignment needed to
    # reach the cuDNN grouped path (m % 256, k % 128, n % 128) -- confirm
    # against the dispatch logic in general_grouped_gemm.
    m_splits = [256, 512]
    k = 128
    n = 128
    torch.manual_seed(29)
    torch.cuda.manual_seed(29)

    # Activations use the row-scaled NVFP4 variant; weights use the default
    # quantizer with both rowwise and columnwise data enabled.
    x_quantizer = NVFP4Quantizer(
        fp4_dtype=te_dtype,
        rowwise=True,
        columnwise=False,
        with_amax_reduction=False,
        amax_reduction_group=None,
        with_rht=False,
        with_post_rht_amax=False,
        row_scaled_nvfp4=True,
    )
    w_quantizer = NVFP4Quantizer(
        fp4_dtype=te_dtype,
        rowwise=True,
        columnwise=True,
        with_amax_reduction=False,
        amax_reduction_group=None,
        with_rht=False,
        with_post_rht_amax=False,
    )

    x_nvfp4 = []
    w_nvfp4 = []
    bias = []
    expected = []
    for m in m_splits:
        x = torch.randn((m, k), dtype=dtype, device=device)
        w = torch.randn((n, k), dtype=dtype, device=device)
        x_nvfp4.append(
            x_quantizer.update_quantized(
                x,
                x_quantizer.make_empty(x.shape, dtype=dtype, device=device),
            )
        )
        w_nvfp4.append(
            w_quantizer.update_quantized(
                w,
                w_quantizer.make_empty(w.shape, dtype=dtype, device=device),
            )
        )
        bias.append(torch.randn(n, dtype=torch.bfloat16, device=device) if use_bias else None)
        # Per-GEMM reference computed before the wrapper is patched, so it is
        # never recorded in ``calls``.
        expected.append(
            general_gemm(
                w_nvfp4[-1],
                x_nvfp4[-1],
                out_dtype=dtype,
                layout="TN",
                bias=bias[-1],
            )[0]
        )

    calls = []
    original_wrapper = cudnn.grouped_gemm_quant_wrapper_sm100

    def traced_wrapper(*args, **kwargs):
        # Only keyword arguments are recorded; the assertions below look the
        # inspected values up by keyword name.
        calls.append(kwargs)
        return original_wrapper(*args, **kwargs)

    monkeypatch.setattr(cudnn, "grouped_gemm_quant_wrapper_sm100", traced_wrapper)

    if single_output:
        out = [torch.empty((sum(m_splits), n), dtype=dtype, device=device)]
    else:
        out = [torch.empty((m, n), dtype=dtype, device=device) for m in m_splits]

    grouped_out, _, _ = general_grouped_gemm(
        w_nvfp4,
        x_nvfp4,
        out,
        quantization_params=[None] * len(m_splits),
        out_dtype=dtype,
        layout="TN",
        m_splits=m_splits,
        bias=bias,
        use_bias=use_bias,
        single_output=single_output,
    )

    # The grouped path must go through the cuDNN wrapper exactly once.
    assert len(calls) == 1
    assert calls[0]["sf_vec_size"] == 16
    assert calls[0]["row_scale_tensor"].shape == (sum(m_splits),)
    assert calls[0]["b_major"] == "k"
    assert (calls[0]["bias_tensor"] is not None) == use_bias

    if single_output:
        grouped_slices = torch.split(grouped_out, m_splits, dim=0)
    else:
        grouped_slices = grouped_out
    for grouped, ref in zip(grouped_slices, expected):
        torch.testing.assert_close(grouped, ref, atol=0.5, rtol=0.25)


def check_nvfp4_row_scaled_gemm_matches_emulated(
Expand Down
Loading