Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
c873e46
add production GEMM tests
matthiasdiener May 19, 2026
c4c2ea5
rename
matthiasdiener May 26, 2026
00da5e6
Merge remote-tracking branch 'origin/dev' into mdiener/prodgemm-test
matthiasdiener May 28, 2026
8eaf06d
restructure based on review comments
matthiasdiener May 28, 2026
77f1c45
clarify switch
matthiasdiener May 28, 2026
76c8d98
skip known-bad tests
matthiasdiener May 28, 2026
db3123f
loosen tolerances a bit
matthiasdiener May 28, 2026
c6cc59f
restructure tests into test_cublaslt_gemm
matthiasdiener Jun 1, 2026
b6440e0
Merge remote-tracking branch 'origin/dev' into mdiener/prodgemm-test
matthiasdiener Jun 1, 2026
e46d6da
add MXFP8 pre-swizzling for gfx1250 GEMM (#568)
matthiasdiener May 21, 2026
8b37f0f
Merge remote-tracking branch 'origin/dev' into mdiener/mxfp8-swizzle-…
matthiasdiener Jun 2, 2026
d5a16c9
skip known-failing mxfp8 tests due to hipblaslt limitation
matthiasdiener Jun 2, 2026
1cf0dad
Merge branch 'dev' into mdiener/prodgemm-test
matthiasdiener Jun 2, 2026
a24f739
Merge branch 'dev' into mdiener/mxfp8-swizzle-gfx1250
matthiasdiener Jun 2, 2026
5f510c1
Fix gemm.cpp
matthiasdiener Jun 3, 2026
b540068
Merge branch 'dev' into mdiener/mxfp8-swizzle-gfx1250
matthiasdiener Jun 3, 2026
1d2f222
Merge branch 'mdiener/prodgemm-test' into mdiener/mxfp8-swizzle-gfx1250
matthiasdiener Jun 4, 2026
bfedb4a
more padding fixes
matthiasdiener Jun 5, 2026
98fb2ff
Merge remote-tracking branch 'origin/dev' into mdiener/mxfp8-swizzle-…
matthiasdiener Jun 10, 2026
ce60ce0
increase jax WS size
matthiasdiener Jun 10, 2026
b668a2c
more support for non-TN
matthiasdiener Jun 10, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions tests/cpp/operator/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ add_executable(test_operator
test_multi_padding.cu
test_multi_unpadding.cu
test_causal_softmax.cu
test_swizzle.cu #CUDA-only test
test_swizzle.cu
test_swap_first_dims.cu
test_grouped_gemm.cu #CUDA-only test
../test_common.cu)
Expand All @@ -42,7 +42,6 @@ if(USE_ROCM)
# Remove CUDA-only tests and add ROCm specific ones
list(REMOVE_ITEM test_cuda_sources
test_cast_float8blockwise.cu
test_swizzle.cu
test_grouped_gemm.cu)
list(APPEND test_cuda_sources
test_dequantize_nvfp4.cu
Expand Down
297 changes: 279 additions & 18 deletions tests/cpp/operator/test_cublaslt_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,15 @@
************************************************************************/
#include <cmath>
#include <iostream>
#include <optional>
#include <set>
#include <string>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/cast.h>
#include <transformer_engine/gemm.h>
#include <transformer_engine/swizzle.h>
#include <transformer_engine/transformer_engine.h>
#include "../test_common.h"

Expand All @@ -30,7 +33,107 @@ std::vector<std::tuple<size_t, size_t, size_t>> test_case_sizes = {

// Problem sizes used by the MXFP8 GEMM tests, from small tiles up to large
// LLM-like sizes.
// NOTE(review): each tuple is presumably ordered (M, K, N) -- TODO confirm
// against the code that unpacks these tuples.
std::vector<std::tuple<size_t, size_t, size_t>> test_case_sizes_mxfp8 = {
{32, 128, 16},
{64, 128, 32},
{128, 128, 64},
{64, 256, 32},
{128, 384, 64},
{256, 512, 128},
{512, 1024, 256},
{768, 3072, 4096},
{1024, 2048, 128},
{4096, 8192, 64},
};

// ============================================================================
// Production LLM shapes for MXFP8 GEMM testing.
//
// Each shape is tested with 3 micro-batch sizes (MBS = 1, 2, 4)
// yielding tokens = 4096, 8192, 16384, and 3 layouts (TN, NN, NT)
// via ::testing::Combine.
//
// GemmPass selects the FP8 type combination:
// FWD: E4M3 x E4M3 -> BF16
// DGRAD: E5M2 x E4M3 -> BF16
// WGRAD: E4M3 x E5M2 -> BF16
// ============================================================================

// Training pass a production GEMM shape belongs to. The pass decides how the
// two stored dimensions of a ShapeDef are mapped onto (M, K, N).
enum class GemmPass { FWD, DGRAD, WGRAD };

// One named production GEMM shape. Only the two token-independent dimensions
// are stored; the token-dependent one is derived in resolve_mkn().
struct ShapeDef {
  const char* label;  // name written by operator<< below
  size_t dim1;        // N for FWD/DGRAD, M for WGRAD
  size_t dim2;        // K for FWD/DGRAD, N for WGRAD
  GemmPass pass;      // selects the dimension mapping
};

// Writes only the shape's label to the stream.
std::ostream& operator<<(std::ostream& os, const ShapeDef& s) {
  os << s.label;
  return os;
}

// Expands a ShapeDef plus a micro-batch size into concrete GEMM dimensions.
//
//   s    shape definition (two fixed dimensions and the pass)
//   mbs  micro-batch size; the token count is mbs * 4096
//   m, k, n  outputs. For FWD/DGRAD the token count becomes M; for WGRAD it
//            becomes K. The outputs are left untouched if s.pass holds a value
//            outside the enumerators.
static void resolve_mkn(const ShapeDef& s, size_t mbs,
                        size_t& m, size_t& k, size_t& n) {
  const size_t token_count = mbs * 4096;
  const bool tokens_on_m =
      (s.pass == GemmPass::FWD) || (s.pass == GemmPass::DGRAD);

  if (tokens_on_m) {
    m = token_count;
    n = s.dim1;
    k = s.dim2;
  } else if (s.pass == GemmPass::WGRAD) {
    k = token_count;
    m = s.dim1;
    n = s.dim2;
  }
}

// DeepSeek3 (hidden=7168, MLA, seq=4096, incl. LM Head)
//
// Each entry is {label, dim1, dim2, pass}. The token-dependent dimension is not
// stored here; resolve_mkn() fills it in from the micro-batch size:
//   FWD / DGRAD: dim1 = N, dim2 = K, M = tokens
//   WGRAD:       dim1 = M, dim2 = N, K = tokens
static const ShapeDef deepseek3_shapes[] = {
// Forward (M=tokens, N, K)
{"DeepSeek3_Linear0_fwd", 1536, 7168, GemmPass::FWD},
{"DeepSeek3_Linear1_fwd", 576, 7168, GemmPass::FWD},
{"DeepSeek3_LNLinear0_fwd", 24576, 1536, GemmPass::FWD},
{"DeepSeek3_LNLinear1_fwd", 32768, 512, GemmPass::FWD},
{"DeepSeek3_Linear_attn_fwd", 7168, 16384, GemmPass::FWD},
{"DeepSeek3_LNMLP_gateup_fwd", 36864, 7168, GemmPass::FWD},
{"DeepSeek3_LNMLP_down_fwd", 7168, 18432, GemmPass::FWD},
{"DeepSeek3_SharedExp_gu_fwd", 4096, 7168, GemmPass::FWD},
{"DeepSeek3_SharedExp_dn_fwd", 7168, 2048, GemmPass::FWD},
{"DeepSeek3_TopKRouter_fwd", 256, 7168, GemmPass::FWD},
{"DeepSeek3_LMHead_fwd", 129280, 7168, GemmPass::FWD},
// Dgrad (M=tokens, N, K)
{"DeepSeek3_attn_dgrad", 16384, 7168, GemmPass::DGRAD},
{"DeepSeek3_LNLinear1_dgrad", 512, 32768, GemmPass::DGRAD},
{"DeepSeek3_LNLinear0_dgrad", 1536, 24576, GemmPass::DGRAD},
{"DeepSeek3_SharedExp_dn_dgrad", 2048, 7168, GemmPass::DGRAD},
{"DeepSeek3_SharedExp_gu_dgrad", 7168, 4096, GemmPass::DGRAD},
{"DeepSeek3_TopKRouter_dgrad", 7168, 256, GemmPass::DGRAD},
{"DeepSeek3_MLP_post_dgrad", 7168, 14336, GemmPass::DGRAD},
{"DeepSeek3_LMHead_dgrad", 7168, 129280, GemmPass::DGRAD},
// Wgrad (M, N, K=tokens)
// NOTE(review): there is no wgrad counterpart for DeepSeek3_MLP_post_dgrad --
// confirm whether that omission is intentional.
{"DeepSeek3_attn_wgrad", 16384, 7168, GemmPass::WGRAD},
{"DeepSeek3_LNLinear1_wgrad", 512, 32768, GemmPass::WGRAD},
{"DeepSeek3_LNLinear0_wgrad", 1536, 24576, GemmPass::WGRAD},
{"DeepSeek3_SharedExp_dn_wgrad", 2048, 7168, GemmPass::WGRAD},
{"DeepSeek3_SharedExp_gu_wgrad", 7168, 4096, GemmPass::WGRAD},
{"DeepSeek3_TopKRouter_wgrad", 7168, 256, GemmPass::WGRAD},
{"DeepSeek3_LMHead_wgrad", 7168, 129280, GemmPass::WGRAD},
};

// Qwen3 (hidden=4096, GQA, seq=4096, incl. LM Head)
//
// Each entry is {label, dim1, dim2, pass}. The token-dependent dimension is not
// stored here; resolve_mkn() fills it in from the micro-batch size:
//   FWD / DGRAD: dim1 = N, dim2 = K, M = tokens
//   WGRAD:       dim1 = M, dim2 = N, K = tokens
static const ShapeDef qwen3_shapes[] = {
// Forward (M=tokens, N, K)
{"Qwen3_LNLinear_QKV_fwd", 9216, 4096, GemmPass::FWD},
{"Qwen3_Linear_attn_fwd", 4096, 8192, GemmPass::FWD},
{"Qwen3_Router_fwd", 128, 4096, GemmPass::FWD},
{"Qwen3_LMHead_fwd", 151936, 4096, GemmPass::FWD},
// Dgrad (M=tokens, N, K)
{"Qwen3_Router_dgrad", 4096, 128, GemmPass::DGRAD},
{"Qwen3_Linear_attn_dgrad", 8192, 4096, GemmPass::DGRAD},
{"Qwen3_LNLinear_dgrad", 4096, 9216, GemmPass::DGRAD},
{"Qwen3_LMHead_dgrad", 4096, 151936, GemmPass::DGRAD},
// Wgrad (M, N, K=tokens)
{"Qwen3_Router_wgrad", 4096, 128, GemmPass::WGRAD},
{"Qwen3_Linear_attn_wgrad", 8192, 4096, GemmPass::WGRAD},
{"Qwen3_LNLinear_wgrad", 4096, 9216, GemmPass::WGRAD},
{"Qwen3_LMHead_wgrad", 4096, 151936, GemmPass::WGRAD},
};

// A, B, Bias, Gelu, D
Expand Down Expand Up @@ -303,6 +406,40 @@ void cpu_rowwise_to_columnwise(
}
}

// Swizzle the MXFP8 scale_inv buffer of a test::Tensor in-place for gfx1250.
//
//   t        tensor whose scale-inverse buffer is rewritten
//   rowwise  true: operate on the rowwise scales; false: on the columnwise ones
//
// Does nothing if the selected scale buffer is null. The swizzled scales are
// written to a scratch device buffer and then copied back over the original
// buffer, so after this call the tensor's scales are in the swizzled layout
// and no longer in the layout they had before.
static void swizzle_mxfp8_scales(test::Tensor &t, bool rowwise) {
  using namespace transformer_engine;

  void *scale_ptr = rowwise ? t.rowwise_scale_inv_dptr()
                            : t.columnwise_scale_inv_dptr();
  if (!scale_ptr) return;

  const NVTEShape scale_shape = rowwise ? t.rowwise_scale_inv_shape()
                                        : t.columnwise_scale_inv_shape();
  const NVTEShape data_shape = rowwise ? t.rowwise_shape()
                                       : t.columnwise_shape();

  // The scratch buffer is sized at one byte per scale element.
  size_t num_scales = 1;
  for (size_t d = 0; d < scale_shape.ndim; d++) num_scales *= scale_shape.data[d];

  // Scope guard: releases the scratch allocation if anything below exits this
  // function early (e.g. a failing check that throws), so an error in one test
  // does not leak device memory into the tests that follow.
  struct ScratchGuard {
    uint8_t *ptr = nullptr;
    ~ScratchGuard() {
      if (ptr != nullptr) (void)cudaFree(ptr);
    }
  } scratch;
  NVTE_CHECK_CUDA(cudaMalloc(&scratch.ptr, num_scales));

  // Only the scale buffers are real; the data pointers are null and just carry
  // dtype/shape information for the swizzle call.
  TensorWrapper input_tw(NVTE_MXFP8_1D_SCALING);
  TensorWrapper output_tw(NVTE_MXFP8_1D_SCALING);
  output_tw.set_with_gemm_swizzled_scales(true);
  if (rowwise) {
    input_tw.set_rowwise_data(nullptr, t.dtype(), data_shape);
    input_tw.set_rowwise_scale_inv(scale_ptr, DType::kFloat8E8M0, scale_shape);
    output_tw.set_rowwise_data(nullptr, t.dtype(), data_shape);
    output_tw.set_rowwise_scale_inv(scratch.ptr, DType::kFloat8E8M0, scale_shape);
  } else {
    input_tw.set_columnwise_data(nullptr, t.dtype(), data_shape);
    input_tw.set_columnwise_scale_inv(scale_ptr, DType::kFloat8E8M0, scale_shape);
    output_tw.set_columnwise_data(nullptr, t.dtype(), data_shape);
    output_tw.set_columnwise_scale_inv(scratch.ptr, DType::kFloat8E8M0, scale_shape);
  }

  nvte_swizzle_scaling_factors(input_tw.data(), output_tw.data(), 0);
  NVTE_CHECK_CUDA(cudaDeviceSynchronize());
  NVTE_CHECK_CUDA(cudaMemcpy(scale_ptr, scratch.ptr, num_scales, cudaMemcpyDeviceToDevice));

  // Success path: take the pointer back from the guard so the free stays
  // error-checked, as it was before the guard was introduced.
  uint8_t *released = scratch.ptr;
  scratch.ptr = nullptr;
  NVTE_CHECK_CUDA(cudaFree(released));
}

std::pair<double, double> getTestTolerances(const DType type, bool use_fp8, bool use_mxfp8) {
auto [atol, rtol] = getTolerances(type);

Expand All @@ -318,6 +455,12 @@ std::pair<double, double> getTestTolerances(const DType type, bool use_fp8, bool
else if (use_fp8) {
atol = 1e-3;
rtol = std::max(rtol, 1e-2);
// Relax for gfx1250
cudaDeviceProp prop;
(void)cudaGetDeviceProperties(&prop, 0);
if (prop.major == 12 && type == DType::kBFloat16) {

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This relaxation fires for every FP8 GEMM test on any gfx12 device (tensor-scaling FP8 included), not just MXFP8 on gfx1250. The comment ("Relax for gfx1250") and the PR scope suggest the intent is the gfx1250 MXFP8 path specifically. Consider guarding with use_mxfp8 and/or prop.major == 12 && prop.minor == 5 so non-MXFP8 FP8 tests don't silently lose precision coverage on this arch.

rtol = std::max(rtol, 5e-2);
}
}
else if (type == DType::kBFloat16) {
//relax for certain prime number TN gemm
Expand Down Expand Up @@ -496,6 +639,31 @@ void performTest(const TestParams& params) {
#endif
Tensor Workspace("Workspace", TShape{ workspace_size }, DType::kByte);

//perform the reference gemm on GPU (before swizzle, which modifies scales in-place)
Tensor RefD("RefD", TShape{ params.n, params.m }, dtype);
Tensor RefPreGeluOut;

if (params.use_gelu) {
RefPreGeluOut = Tensor("RefPreGeluOut", TShape{ params.n, params.m }, gelu_type);
}

run_reference<A_Type, B_Type, Bias_Type, Gelu_Type, D_Type>(
params,
A,
B,
params.use_bias ? &bias : nullptr,
D,
RefD,
params.use_gelu ? &RefPreGeluOut : nullptr);

// On gfx1250, hipBLASLt MXFP8 kernels expect pre-swizzled scales.
if (use_mxfp8 && prop.major == 12) {
if (!a_colwise) swizzle_mxfp8_scales(A, true);
if (a_colwise) swizzle_mxfp8_scales(A, false);
if (!b_colwise) swizzle_mxfp8_scales(B, true);
if (b_colwise) swizzle_mxfp8_scales(B, false);
Comment on lines +661 to +664

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: each pair of if (!x) ...; if (x) ...; lines unconditionally calls swizzle_mxfp8_scales with the negation of x. The same lines exist again at 776-779.

Suggested change
if (!a_colwise) swizzle_mxfp8_scales(A, true);
if (a_colwise) swizzle_mxfp8_scales(A, false);
if (!b_colwise) swizzle_mxfp8_scales(B, true);
if (b_colwise) swizzle_mxfp8_scales(B, false);
swizzle_mxfp8_scales(A, !a_colwise);
swizzle_mxfp8_scales(B, !b_colwise);

}

//perform the gemm in GPU
nvte_cublas_gemm(A.data(),
B.data(),
Expand All @@ -517,23 +685,6 @@ void performTest(const TestParams& params) {
pre_gelu_out.to_cpu();
}

//perform the reference gemm on GPU
Tensor RefD("RefD", TShape{ params.n, params.m }, dtype);
Tensor RefPreGeluOut;

if (params.use_gelu) {
RefPreGeluOut = Tensor("RefPreGeluOut", TShape{ params.n, params.m }, gelu_type);
}

run_reference<A_Type, B_Type, Bias_Type, Gelu_Type, D_Type>(
params,
A,
B,
params.use_bias ? &bias : nullptr,
D,
RefD,
params.use_gelu ? &RefPreGeluOut : nullptr);

// check if error message happens in running
(void)cudaDeviceSynchronize();
auto err = cudaGetLastError();
Expand All @@ -559,7 +710,9 @@ void performTest(const TestParams& params) {

#ifdef __HIP_PLATFORM_AMD__
template <typename A_Type, typename B_Type, typename D_Type>
void performDqTest(const TestParams &params) {
void performDqTest(const TestParams &params,
std::optional<double> atol_override = std::nullopt,
std::optional<double> rtol_override = std::nullopt) {
DType atype = TypeInfo<A_Type>::dtype;
DType btype = TypeInfo<B_Type>::dtype;
DType dtype = TypeInfo<D_Type>::dtype;
Expand All @@ -582,6 +735,17 @@ void performDqTest(const TestParams &params) {
GTEST_SKIP() << "MXFP8 is not supported in current config";
}

// hipBLASLt on gfx950 produces incorrect results for certain small MXFP8

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a ticket for that?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, there isn't.

// GEMMs with non-TN layouts.
if (prop.major == 9 && prop.minor == 5) {
const bool is_NN = !params.transa && !params.transb;
const bool is_NT = !params.transa && params.transb;
if ((is_NN && params.m == 64) ||
(is_NT && params.m > 32 && params.m <= 128 && params.n <= 64)) {
GTEST_SKIP() << "hipBLASLt MXFP8 non-TN GEMM with small M/N is not supported on gfx950";
}
}

DType ref_type = dtype;
TShape a_shape = params.transa ? TShape{params.m, params.k} : TShape{params.k, params.m};
TShape b_shape = params.transb ? TShape{params.k, params.n} : TShape{params.n, params.k};
Expand All @@ -605,6 +769,16 @@ void performDqTest(const TestParams &params) {
nvte_dequantize(A_fp8.data(), A_ref.data(), 0);
nvte_dequantize(B_fp8.data(), B_ref.data(), 0);

// On gfx1250, hipBLASLt MXFP8 kernels expect pre-swizzled scales.
if (prop.major == 12) {
const bool a_colwise = !params.transa;
const bool b_colwise = params.transb;
if (!a_colwise) swizzle_mxfp8_scales(A_fp8, true);
if (a_colwise) swizzle_mxfp8_scales(A_fp8, false);
if (!b_colwise) swizzle_mxfp8_scales(B_fp8, true);
if (b_colwise) swizzle_mxfp8_scales(B_fp8, false);
}

Tensor bias;
Tensor pre_gelu_out;

Expand Down Expand Up @@ -633,6 +807,10 @@ void performDqTest(const TestParams &params) {

//compare results
auto [atol, rtol] = getTestTolerances(dtype, true, true);
if (atol_override)
atol = *atol_override;
if (rtol_override)
rtol = *rtol_override;
compareResults("D", D, D_ref.rowwise_cpu_dptr<D_Type>(), true, atol, rtol);
}
#endif // __HIP_PLATFORM_AMD__
Expand Down Expand Up @@ -751,6 +929,89 @@ INSTANTIATE_TEST_SUITE_P(OperatorTest, DqGEMMTestSuite,
return MKN(std::get<0>(info.param)) + "x" + TN(std::get<3>(info.param));
});

// ============================================================================
// Production GEMM shape instantiations (run with --gtest_filter='ProdGemm*')
// ============================================================================

// Known-failing GEMM shapes on gfx950
//
// Entries are full production test names in the form
// "<ShapeDef label>_mbs<micro-batch size>_<layout>", i.e. the same string the
// production test builds from its parameters before looking it up here.
static const std::set<std::string> kGfx950Skips = {
"DeepSeek3_Linear1_fwd_mbs1_NT",
"DeepSeek3_Linear1_fwd_mbs2_NT",
"DeepSeek3_Linear1_fwd_mbs4_NT",
"DeepSeek3_LNLinear0_fwd_mbs4_NN",
"DeepSeek3_LNLinear0_fwd_mbs4_NT",
"DeepSeek3_attn_wgrad_mbs1_NN",
"Qwen3_LMHead_fwd_mbs2_NN",
"Qwen3_Router_fwd_mbs2_NT",
"Qwen3_LMHead_fwd_mbs4_TN",
"Qwen3_LMHead_fwd_mbs4_NN",
"Qwen3_LMHead_fwd_mbs4_NT",
};

// Production GEMM test suite using ShapeDef x MBS x Layout via testing::Combine.
// Tuple elements: <0> shape definition, <1> micro-batch size,
// <2> layout, whose .first/.second are used as transa/transb by the test body.
using ProdGemmParam = std::tuple<ShapeDef, size_t, Layout>;

// Parameterized fixture with no state of its own; it only carries ProdGemmParam.
class ProdDqGEMMTestSuite : public ::testing::TestWithParam<ProdGemmParam> {};

// Runs one production-shape MXFP8 GEMM against the dequantized reference.
//
// Parameters (see ProdGemmParam): shape definition, micro-batch size, layout.
// The shape's GemmPass selects the FP8 operand types:
//   FWD:   fp8 x fp8 -> bf16
//   DGRAD: bf8 x fp8 -> bf16
//   WGRAD: fp8 x bf8 -> bf16
TEST_P(ProdDqGEMMTestSuite, TestMxfp8Dq) {
  const auto& shape = std::get<0>(GetParam());
  const size_t mbs = std::get<1>(GetParam());
  const auto& layout = std::get<2>(GetParam());

  // kGfx950Skips records failures specific to gfx950, so it is consulted only
  // on that architecture (compute capability 9.5). Applying it everywhere
  // would silently drop these shapes from coverage on other devices.
  cudaDeviceProp device_prop;
  ASSERT_EQ(cudaGetDeviceProperties(&device_prop, 0), cudaSuccess);
  const bool is_gfx950 = (device_prop.major == 9 && device_prop.minor == 5);
  if (is_gfx950) {
    const std::string name = std::string(shape.label) + "_mbs" + std::to_string(mbs)
                             + "_" + TN(layout);
    if (kGfx950Skips.count(name)) {
      GTEST_SKIP() << "Known gfx950 hipBLASLt failure: " << name;
    }
  }

  size_t m, k, n;
  resolve_mkn(shape, mbs, m, k, n);

  TestParams params = {.m = m, .k = k, .n = n,
                       .use_bias = false, .use_gelu = false,
                       .transa = layout.first, .transb = layout.second,
                       .scaling_mode = NVTEScalingMode::NVTE_MXFP8_1D_SCALING};

  // Production shapes use looser tolerances: the MXFP8 and bf16 reference
  // GEMM use different internal accumulation paths, so results can differ by
  // a few bf16 ULPs. (bf16 machine epsilon is 2^-7, about 0.8%, so the 2%
  // relative tolerance below corresponds to roughly 2.5-5 ULPs, not 1.)
  const double prod_atol = 1e-3;
  const double prod_rtol = 2e-2;

  switch (shape.pass) {
    case GemmPass::FWD:
      performDqTest<fp8, fp8, bf16>(params, prod_atol, prod_rtol);
      break;
    case GemmPass::DGRAD:
      performDqTest<bf8, fp8, bf16>(params, prod_atol, prod_rtol);
      break;
    case GemmPass::WGRAD:
      performDqTest<fp8, bf8, bf16>(params, prod_atol, prod_rtol);
      break;
  }
}

// Builds the per-case test name "<label>_mbs<micro-batch size>_<layout>" from
// the parameter tuple, for use as the name generator of the production suites.
static auto prodTestName = [](const testing::TestParamInfo<ProdGemmParam>& info) {
  const auto& [shape, mbs, layout] = info.param;

  std::string test_name(shape.label);
  test_name += "_mbs";
  test_name += std::to_string(mbs);
  test_name += "_";
  test_name += TN(layout);
  return test_name;
};

// DeepSeek3 production cases: every entry of deepseek3_shapes, crossed with
// micro-batch sizes 1, 2 and 4 and with every layout in kLayouts.
INSTANTIATE_TEST_SUITE_P(ProdGemmDeepSeek3, ProdDqGEMMTestSuite,
::testing::Combine(
::testing::ValuesIn(deepseek3_shapes),
::testing::Values(size_t{1}, size_t{2}, size_t{4}),
::testing::ValuesIn(kLayouts)),
prodTestName);

// Qwen3 production cases: every entry of qwen3_shapes, crossed with
// micro-batch sizes 1, 2 and 4 and with every layout in kLayouts.
INSTANTIATE_TEST_SUITE_P(ProdGemmQwen3, ProdDqGEMMTestSuite,
::testing::Combine(
::testing::ValuesIn(qwen3_shapes),
::testing::Values(size_t{1}, size_t{2}, size_t{4}),
::testing::ValuesIn(kLayouts)),
prodTestName);

TEST(InputGenTest, FillUniform_DoesNotGetOverwrittenByFromCpu) {
const size_t rows = 128;
const size_t cols = 256;
Expand Down
Loading
Loading