From 1f707d71c48a60706c1f5f6f3110fe81ff6054e3 Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Sat, 25 Apr 2026 17:16:42 +0000 Subject: [PATCH 01/47] initial commit for CK Tile MXFP8 integration for gfx1250 --- .../ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp | 492 ++++++++++++++++++ 1 file changed, 492 insertions(+) create mode 100644 transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp diff --git a/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp b/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp new file mode 100644 index 000000000..c7746b563 --- /dev/null +++ b/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp @@ -0,0 +1,492 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +#include +#include "../../common.h" +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/gemm/kernel/mx_grouped_gemm_kernel.hpp" +#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" + +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using mx_grouped_gemm_kargs = ck_tile::MxGroupedGemmHostArgs<>; + +template struct TETypeToCKType; +template <> struct TETypeToCKType { using type = ck_tile::fp8_t; }; +template <> struct TETypeToCKType { using type = ck_tile::bf8_t; }; +template <> struct TETypeToCKType { using type = ck_tile::e8m0_t; }; +template <> struct TETypeToCKType { using type = ck_tile::half_t; }; +template <> struct TETypeToCKType { using type = ck_tile::bfloat16_t; }; +template <> struct TETypeToCKType { using type = float; }; + +struct GroupedGemmRunContext { + const NVTETensor* A = nullptr; + const NVTETensor* B = nullptr; + NVTETensor* D = nullptr; + + int group_num = 0; + bool transA = false; + bool transB = false; + + void* workspace = nullptr; + size_t workspace_bytes = 0; + hipStream_t stream = nullptr; + +}; + +static constexpr ck_tile::index_t ScaleBlockSize = 32; + +enum struct MxGemmPipelineType +{ + CompTDMV1, + CompTDMV2 +}; + +template +struct MxGemmPipelineTypeSelector; +template +struct MxGemmPipelineTypeSelector +{ + using base_pipeline = ck_tile::BaseGemmPipelineAgBgCrCompTDM; + using pipeline = ck_tile::GemmPipelineAgBgCrCompTDMV1; + static constexpr auto GetName() { return "GemmPipelineAgBgCrCompTDMV1"; } +}; + +template +struct MxGemmPipelineTypeSelector +{ + using base_pipeline = ck_tile::BaseGemmPipelineAgBgCrCompTDM; + using pipeline = ck_tile::GemmPipelineAgBgCrCompTDMV2; + static constexpr auto GetName() { return "GemmPipelineAgBgCrCompTDMV2"; } +}; + +static inline const transformer_engine::SimpleTensor& data_view(const transformer_engine::Tensor& t) { + return t.data; +} + +static inline const transformer_engine::SimpleTensor& scale_inv_view(const transformer_engine::Tensor& t) { + return t.scale_inv; +} + +template +static inline bool has_sufficient_workspace(const GroupedGemmRunContext& ctx) { + const size_t needed = Kernel::GetWorkSpaceSize(ctx.group_num); + if (!ctx.workspace || ctx.workspace_bytes < needed) { + NVTE_WARN("ck_tile_grouped_gemm: insufficient workspace for CK path. Needed bytes=", needed, + ", available bytes=", ctx.workspace_bytes, ". Falling back."); + return false; + } + return true; +} + +struct GroupedGemKernelParam_Wmma +{ + static const bool kPadM = false; + static const bool kPadN = false; + static const bool kPadK = false; + static const int kBlockPerCu = 1; + static const ck_tile::index_t M_Tile = 64; + static const ck_tile::index_t N_Tile = 64; + static const ck_tile::index_t K_Tile = 128; + static const ck_tile::index_t M_Warp = 2; + static const ck_tile::index_t N_Warp = 2; + static const ck_tile::index_t K_Warp = 1; + static const ck_tile::index_t M_Warp_Tile = 32; + static const ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = 128; +}; + +template +__global__ void preshuffle_scale_gfx1250_kernel(const ScaleType* __restrict__ src, + ScaleType* __restrict__ dst, + int actual_rows, + int output_rows, + int KScale) +{ + static_assert(ScaleBlockSize == 32 && sizeof(ScaleType) == 1, + "gfx1250 scale preshuffle only supports 8-bit scale with ScaleBlockSize=32"); + constexpr int MPerXdlops = 16; + constexpr int KPerXdlops = 128; + constexpr int MNPack = 2; + constexpr int KPack = 1; + constexpr int MNStep = MPerXdlops; // 16 + constexpr int KStep = KPerXdlops / ScaleBlockSize; // 4 + const int K0 = KScale / (KPack * KStep); + const int linear = blockIdx.x * blockDim.x + threadIdx.x; + const int total = output_rows * KScale; + if(linear >= total) + return; + const int mn = linear / KScale; + const int k = linear % KScale; + const int iMNRepeat = mn / (MNStep * MNPack); + const int tempmn = mn % (MNStep * MNPack); + const int iKRepeat = k / (KStep * KPack); + const int tempk = k % (KStep * KPack); + const int outputIndex = + (iMNRepeat * MNPack * MNStep) * (KStep * KPack * K0) + + (iKRepeat * KStep * KPack) * (MNStep * MNPack) + + tempmn * (KStep * KPack) + + tempk; + ScaleType value{}; + if(mn < actual_rows) + { + if constexpr(KStride) + value = src[mn * KScale + k]; + else + value = src[k * actual_rows + mn]; + } + dst[outputIndex] = value; +} + +template +void preShuffleScaleBuffer_gfx1250(const ScaleType* src, + ScaleType* dst, + int actual_rows, + int output_rows, + int KScale, + hipStream_t stream) +{ + constexpr int KPerXdlops = 128; + constexpr int KStep = KPerXdlops / ScaleBlockSize; // 4 + if(KScale % KStep != 0) + { + NVTE_ERROR("preshuffle_scale_gfx1250: KScale must be a multiple of 4, " + "i.e. original K must be a multiple of 128 for ScaleBlockSize=32."); + } + const int total = output_rows * KScale; + constexpr int block_size = 256; + const int grid_size = (total + block_size - 1) / block_size; + hipLaunchKernelGGL((preshuffle_scale_gfx1250_kernel), + dim3(grid_size), + dim3(block_size), + 0, + stream, + src, + dst, + actual_rows, + output_rows, + KScale); + NVTE_CHECK_CUDA(hipGetLastError()); +} + +template +bool invoke_mx_grouped_gemm(const std::vector& descs, const GroupedGemmRunContext& ctx, const ck_tile::stream_config& stream_cfg) +{ + // check hardware WMMA support for the warp tile + static constexpr bool has_wmma_support = + ck_tile::has_wmma_traits_v; + + NVTE_CHECK(has_wmma_support, + "ck_tile_mx_grouped_gemm: unsupported gfx125 WMMA traits for " + "AType/BType/AccType with warp tile shape ", + MXFP8GemmConfig::M_Warp_Tile, "x", + MXFP8GemmConfig::N_Warp_Tile, "x", + MXFP8GemmConfig::K_Warp_Tile); + + using CLayout = RowMajor; + constexpr bool preshuffle = false; + constexpr bool DoubleSmemBuffer = true; // TDM pipeline requires double smem buffer + constexpr bool TransposeC = + std::is_same_v && + MXFP8GemmConfig::M_Warp_Tile == MXFP8GemmConfig::N_Warp_Tile; + static constexpr bool StructuredSparsity = false; + static constexpr bool NumWaveGroup = 1; + constexpr ck_tile::index_t TileParitionerGroupNum = 8; + constexpr ck_tile::index_t TileParitionerM01 = 4; + using GemmShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; + TRANSFORMER_ENGINE_SWITCH_CONDITION(ctx.transA, kTransA, { + using ALayout = std::conditional_t; + TRANSFORMER_ENGINE_SWITCH_CONDITION(ctx.transB, kTransB, { + using BLayout = std::conditional_t; + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; + using UniversalGemmProblem = + ck_tile::MxGemmPipelineProblem; + using PipelineType = MxGemmPipelineType::CompTDMV1; + /* make pipeline selective */ + using GemmPipeline = + typename MxGemmPipelineTypeSelector::pipeline; + using GemmEpilogue = ck_tile::TdmEpilogue< + ck_tile::CShuffleEpilogueProblem,//DsDataType + float, + CType, + ck_tile::tuple<>,//DsLayout + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + MXFP8GemmConfig::M_Warp, + MXFP8GemmConfig::N_Warp, + MXFP8GemmConfig::M_Warp_Tile, + MXFP8GemmConfig::N_Warp_Tile, + MXFP8GemmConfig::K_Warp_Tile, + UniversalGemmProblem::TransposeC, + 1, /*kNumWaveGroups_*/ + false, /*FixedVectorSize_*/ + 1, /*VectorSizeC_*/ + false, /*TiledMMAPermuteN_*/ + 1, /*BlockedXDLN_PerWarp_*/ + DoubleSmemBuffer, /*DoubleSmemBuffer*/ + AType, /*AType_*/ + BType /*BType_*/>>; + using Kernel = ck_tile::MxGroupedGemmKernel; + + if (!has_sufficient_workspace(ctx)) { + return false; + } + + auto kargs = Kernel::MakeKargs(descs); + if(!Kernel::IsSupportedArgument(kargs)) + { + NVTE_WARN("ck_tile_mx_grouped_gemm: CK_Tile kernel arguments not supported for this config. " + "Falling back."); + return false; + } + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(kargs); + NVTE_CHECK_CUDA(hipMemcpyAsync(ctx.workspace, + kargs.data(), + kargs.size() * sizeof(typename decltype(kargs)::value_type), + hipMemcpyHostToDevice, + ctx.stream)); + ck_tile::ignore = ck_tile::launch_kernel( + stream_cfg, ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, + ck_tile::cast_pointer_to_constant_address_space(ctx.workspace), + kargs.size())); + return true; + }); + }); + return false; +} + +bool ck_tile_mx_grouped_gemm(const NVTETensor* A, + const NVTETensor* B, + NVTETensor* D, + int group_num, + bool transA, + bool transB, + NVTETensor* workspace, + bool accumulate,//ignored for now + hipStream_t stream) { + if (group_num <= 0) { + return true; + } + + // Normalize input mats + // I.e., swap A and B, as well as transa and transb. + const NVTETensor* A_use = B; + const NVTETensor* B_use = A; + bool transA_use = transB; + bool transB_use = transA; + + // Validate scale type / data type combination + // Expected input data format: fp8/bf8 (e4m3/e5m2) + // Expected scale data format: e8m0 + const auto* A0 = convertNVTETensorCheck(A_use[0]); + const auto* B0 = convertNVTETensorCheck(B_use[0]); + const auto* D0 = convertNVTETensorCheck(D[0]); + NVTE_CHECK(A0->scale_inv.has_data(), "ck_tile_mx_grouped_gemm: A[0] scale_inv is not initialized"); + NVTE_CHECK(B0->scale_inv.has_data(), "ck_tile_mx_grouped_gemm: B[0] scale_inv is not initialized"); + + const auto a_scale_dtype = A0->scale_inv.dtype; + const auto b_scale_dtype = B0->scale_inv.dtype; + NVTE_CHECK(a_scale_dtype == DType::kFloat8E8M0, + "ck_tile_mx_grouped_gemm: A scale_inv dtype must be Float8E8M0, got ", + static_cast(a_scale_dtype)); + + NVTE_CHECK(b_scale_dtype == DType::kFloat8E8M0, + "ck_tile_mx_grouped_gemm: B scale_inv dtype must be Float8E8M0, got ", + static_cast(b_scale_dtype)); + + const auto a_dtype = A0->dtype(); + const auto b_dtype = B0->dtype(); + const auto d_dtype = D0->dtype(); + NVTE_CHECK(is_fp8_dtype(a_dtype), "ck_tile_mx_grouped_gemm: A dtype must be FP8"); + NVTE_CHECK(is_fp8_dtype(b_dtype), "ck_tile_mx_grouped_gemm: B dtype must be FP8"); + + using AScaleType = typename TETypeToCKType::type; + using BScaleType = typename TETypeToCKType::type; + + void* ws_ptr = nullptr; + size_t ws_bytes = 0; + if (workspace) { + auto* ws_te = convertNVTETensorCheck(*workspace); + ws_ptr = ws_te->data.dptr; + ws_bytes = ws_te->data.numel() * typeToSize(ws_te->data.dtype); + } + + GroupedGemmRunContext ctx = { + A_use, + B_use, + D, + group_num, + transA_use, + transB_use, + ws_ptr, + ws_bytes, + stream}; + + const ck_tile::stream_config s{ctx.stream}; + + std::vector descs; + descs.reserve(group_num); + + std::vector> a_scale_shuffled_bufs; + std::vector> b_scale_shuffled_bufs; + a_scale_shuffled_bufs.reserve(group_num); + b_scale_shuffled_bufs.reserve(group_num); + + for (int i = 0; i < group_num; i++) { + const transformer_engine::Tensor* const A_te = + transformer_engine::convertNVTETensorCheck(ctx.A[i]); + const transformer_engine::Tensor* const B_te = + transformer_engine::convertNVTETensorCheck(ctx.B[i]); + transformer_engine::Tensor* D_te = + transformer_engine::convertNVTETensorCheck(ctx.D[i]); + const auto& a = data_view(*A_te); + const auto& b = data_view(*B_te); + const auto& d = data_view(*D_te); + int64_t Ad0 = 0, Ad1 = 0, Bd0 = 0, Bd1 = 0, Dd0 = 0, Dd1 = 0; + if (!get_flat_2d_dims(*A_te, Ad0, Ad1) || + !get_flat_2d_dims(*B_te, Bd0, Bd1) || + !get_flat_2d_dims(*D_te, Dd0, Dd1)) { + NVTE_ERROR("ck_tile_mx_grouped_gemm: expected all groups to be rank>=2."); + } + const auto& a_scales = scale_inv_view(*A_te); + const auto& b_scales = scale_inv_view(*B_te); + if (a_scales.shape.size() != 2 || b_scales.shape.size() != 2) { + NVTE_ERROR("ck_tile_mx_grouped_gemm: expected A/B scale_inv tensors to be rank-2."); + } + const int64_t M = ctx.transA ? Ad1 : Ad0; + const int64_t K = ctx.transA ? Ad0 : Ad1; + const int64_t N = ctx.transB ? Bd0 : Bd1; + const int64_t Kb = ctx.transB ? Bd1 : Bd0; + if (K % ScaleBlockSize != 0) { + NVTE_ERROR("ck_tile_mx_grouped_gemm: K must be a multiple of ScaleBlockSize for MX GEMM", i); + } + const int KScale = static_cast(K / ScaleBlockSize); + if (Kb != K) { + NVTE_ERROR("ck_tile_mx_grouped_gemm: K mismatch between A and B in group ", i); + } + if (Dd0 != M || Dd1 != N) { + NVTE_ERROR("ck_tile_mx_grouped_gemm: D shape mismatch in group ", i); + } + const ck_tile::index_t stride_A = static_cast(Ad1); + const ck_tile::index_t stride_B = static_cast(Bd1); + const ck_tile::index_t stride_E = static_cast(Dd1); + // Pre-shuffle scale buffers for the hardware + const int a_scale_actual_rows = static_cast(M); + const int a_scale_output_rows = + ck_tile::integer_least_multiple( + static_cast(M), + static_cast(GroupedGemKernelParam_Wmma::M_Warp_Tile)); + const int b_scale_actual_rows = static_cast(N); + const int b_scale_output_rows = static_cast(N); + const size_t a_scale_shuffled_bytes = + static_cast(a_scale_output_rows) * + static_cast(KScale) * + sizeof(AScaleType); + const size_t b_scale_shuffled_bytes = + static_cast(b_scale_output_rows) * + static_cast(KScale) * + sizeof(BScaleType); + a_scale_shuffled_bufs.push_back( + std::make_unique(a_scale_shuffled_bytes)); + b_scale_shuffled_bufs.push_back( + std::make_unique(b_scale_shuffled_bytes)); + void* a_scale_shuffled_ptr = a_scale_shuffled_bufs.back()->GetDeviceBuffer(); + void* b_scale_shuffled_ptr = b_scale_shuffled_bufs.back()->GetDeviceBuffer(); + preShuffleScaleBuffer_gfx1250( + reinterpret_cast(a_scales.dptr), + reinterpret_cast(a_scale_shuffled_ptr), + a_scale_actual_rows, + a_scale_output_rows, + KScale, + stream); + preShuffleScaleBuffer_gfx1250( + reinterpret_cast(b_scales.dptr), + reinterpret_cast(b_scale_shuffled_ptr), + b_scale_actual_rows, + b_scale_output_rows, + KScale, + stream); + descs.emplace_back(mx_grouped_gemm_kargs( + a.dptr, + a_scale_shuffled_ptr, + b.dptr, + b_scale_shuffled_ptr, + {/*ds_ptr*/}, + d.dptr, + 1,//kbatch + M, + N, + K, + stride_A, + stride_B, + {/*stride_Ds*/}, + stride_E)); + } + // invoke gemm + bool ok = false; + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(a_dtype, a_te_type, { + using AType = typename TETypeToCKType::type; + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(b_dtype, b_te_type, { + using BType = typename TETypeToCKType::type; + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(d_dtype, d_te_type, { + using CType = typename TETypeToCKType::type; + ok = invoke_mx_grouped_gemm(descs,ctx,s); + }); + }); + }); + return ok; +} From e102f00a1e3c73386a9c79a0b990c52203a7274f Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Sat, 25 Apr 2026 17:38:06 +0000 Subject: [PATCH 02/47] ck mxfp8 gfx1250 integration builds successfully --- .../ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp | 33 +++++++++++++++---- 1 file changed, 26 insertions(+), 7 deletions(-) diff --git a/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp b/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp index c7746b563..1ced0d61e 100644 --- a/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp +++ b/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp @@ -14,6 +14,9 @@ #include "ck_tile/ops/gemm/kernel/mx_grouped_gemm_kernel.hpp" #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" +namespace transformer_engine { +namespace mx_grouped_gemm { + using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; using ColMajor = ck_tile::tensor_layout::gemm::ColumnMajor; using mx_grouped_gemm_kargs = ck_tile::MxGroupedGemmHostArgs<>; @@ -21,7 +24,6 @@ using mx_grouped_gemm_kargs = ck_tile::MxGroupedGemmHostArgs<>; template struct TETypeToCKType; template <> struct TETypeToCKType { using type = ck_tile::fp8_t; }; template <> struct TETypeToCKType { using type = ck_tile::bf8_t; }; -template <> struct TETypeToCKType { using type = ck_tile::e8m0_t; }; template <> struct TETypeToCKType { using type = ck_tile::half_t; }; template <> struct TETypeToCKType { using type = ck_tile::bfloat16_t; }; template <> struct TETypeToCKType { using type = float; }; @@ -41,6 +43,18 @@ struct GroupedGemmRunContext { }; +// Treat TE tensors as generalized 2D matrices by flattening: +// (D1, D2, ..., Dn) -> (D1*...*D(n-1), Dn), consistent with TE Tensor::flat_*_dim. +static inline bool get_flat_2d_dims(const transformer_engine::Tensor& t, + int64_t& d0, int64_t& d1) { + if (t.shape().size() < 2) { + return false; + } + d0 = static_cast(t.flat_first_dim()); + d1 = static_cast(t.flat_last_dim()); + return true; +} + static constexpr ck_tile::index_t ScaleBlockSize = 32; enum struct MxGemmPipelineType @@ -247,10 +261,10 @@ bool invoke_mx_grouped_gemm(const std::vector& descs, con BType, AScaleType, BScaleType>; - using PipelineType = MxGemmPipelineType::CompTDMV1; /* make pipeline selective */ using GemmPipeline = - typename MxGemmPipelineTypeSelector::pipeline; + typename MxGemmPipelineTypeSelector::pipeline; using GemmEpilogue = ck_tile::TdmEpilogue< ck_tile::CShuffleEpilogueProblemscale_inv.has_data(), "ck_tile_mx_grouped_gemm: A[0] scale_inv is not initialized"); - NVTE_CHECK(B0->scale_inv.has_data(), "ck_tile_mx_grouped_gemm: B[0] scale_inv is not initialized"); + NVTE_CHECK(A0->scale_inv.dptr != nullptr, + "ck_tile_mx_grouped_gemm: A[0] scale_inv is not initialized"); + NVTE_CHECK(B0->scale_inv.dptr != nullptr, + "ck_tile_mx_grouped_gemm: B[0] scale_inv is not initialized"); const auto a_scale_dtype = A0->scale_inv.dtype; const auto b_scale_dtype = B0->scale_inv.dtype; @@ -352,8 +368,8 @@ bool ck_tile_mx_grouped_gemm(const NVTETensor* A, NVTE_CHECK(is_fp8_dtype(a_dtype), "ck_tile_mx_grouped_gemm: A dtype must be FP8"); NVTE_CHECK(is_fp8_dtype(b_dtype), "ck_tile_mx_grouped_gemm: B dtype must be FP8"); - using AScaleType = typename TETypeToCKType::type; - using BScaleType = typename TETypeToCKType::type; + using AScaleType = ck_tile::e8m0_t; + using BScaleType = ck_tile::e8m0_t; void* ws_ptr = nullptr; size_t ws_bytes = 0; @@ -490,3 +506,6 @@ bool ck_tile_mx_grouped_gemm(const NVTETensor* A, }); return ok; } + +} // namespace mx_grouped_gemm +} // namespace transformer_engine \ No newline at end of file From 52a28875302c38dcfbc6ae128f4024be4d56a2e9 Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Sat, 25 Apr 2026 18:11:08 +0000 Subject: [PATCH 03/47] add entrypoint to ck mx group gemm in caller --- .../ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp | 15 ++++++++- .../ck_mx_grouped_gemm/ck_mx_grouped_gemm.hpp | 16 ++++++++++ .../common/gemm/cublaslt_gemm.cu | 32 +++++++++++-------- 3 files changed, 49 insertions(+), 14 deletions(-) create mode 100644 transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.hpp diff --git a/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp b/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp index 1ced0d61e..b05b844f8 100644 --- a/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp +++ b/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp @@ -508,4 +508,17 @@ bool ck_tile_mx_grouped_gemm(const NVTETensor* A, } } // namespace mx_grouped_gemm -} // namespace transformer_engine \ No newline at end of file +} // namespace transformer_engine + +bool ck_tile_mx_grouped_gemm(const NVTETensor* A, + const NVTETensor* B, + NVTETensor* D, + int group_num, + bool transA, + bool transB, + NVTETensor* workspace, + bool accumulate, + hipStream_t stream) { + return transformer_engine::mx_grouped_gemm::ck_tile_mx_grouped_gemm( + A, B, D, group_num, transA, transB, workspace, accumulate, stream); +} diff --git a/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.hpp b/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.hpp new file mode 100644 index 000000000..96d3cd11b --- /dev/null +++ b/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.hpp @@ -0,0 +1,16 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +bool ck_tile_mx_grouped_gemm(const NVTETensor* A, + const NVTETensor* B, + NVTETensor* D, + int group_num, + bool transA, + bool transB, + NVTETensor* workspace, + bool accumulate, + hipStream_t stream); + diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 7326f330f..1d030686b 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -33,7 +33,8 @@ #include "./cutlass_grouped_gemm.cuh" #else #include "ck_grouped_gemm/ck_grouped_gemm.h" -#endif +#include "ck_mx_grouped_gemm/ck_mx_grouped_gemm.hpp" + #ifndef __HIP_PLATFORM_AMD__ namespace { @@ -1163,13 +1164,7 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor auto A_dt = inputA->data.dtype; auto B_dt = inputB->data.dtype; auto D_dt = OutputD->data.dtype; - return ( - (is_fp8_dtype(A_dt) && is_fp8_dtype(B_dt)) - ) || - ( - (A_dt == B_dt) && (A_dt == D_dt) && - (is_fp16_dtype(A_dt)) - ); + return (is_fp8_dtype(A_dt) && is_fp8_dtype(B_dt)); #else auto A_type = get_cuda_dtype(inputA->data.dtype); auto B_type = get_cuda_dtype(inputB->data.dtype); @@ -1192,11 +1187,22 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor if (is_empty_arr(bias) && is_empty_arr(pre_gelu_out) && is_supported_dtype() && #ifdef __HIP_PLATFORM_AMD__ true) { - if (!ck_tile_grouped_gemm(A, B, D, num_gemms, transa, transb, workspace, accumulate, stream)) { - if (warn_fallback) { - NVTE_WARN("Fallback to cuBLAS grouped GEMM."); - } - cublas_path(); + const bool mxfp8_gemm = !use_fp4 && is_mxfp8_scaling(inputA->scaling_mode); + + bool handled_by_ck = false; + if (mxfp8_gemm) { + handled_by_ck = ck_tile_mx_grouped_gemm( + A, B, D, num_gemms, transa, transb, workspace, accumulate, stream); + } else { + handled_by_ck = ck_tile_grouped_gemm( + A, B, D, num_gemms, transa, transb, workspace, accumulate, stream); + } + + if (!handled_by_ck) { + if (warn_fallback) { + NVTE_WARN("Fallback to cuBLAS grouped GEMM."); + } + cublas_path(); } #else all_groups_uniform_k128(B, transb)) { From 80227775d296c6d140957110582b2a46d07de1dd Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Sat, 25 Apr 2026 18:25:46 +0000 Subject: [PATCH 04/47] temporary hacky change to test_numerics for bringup testing --- tests/pytorch/test_numerics.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index d9c7d1fb0..642999ef7 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -2253,6 +2253,7 @@ def test_grouped_linear_accuracy_cutlass( delay_wgrad_compute, ): os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1" + os.environ["NVTE_ROCM_ENABLE_MXFP8"] = "1" test_grouped_linear_accuracy( dtype, num_gemms, @@ -2268,6 +2269,7 @@ def test_grouped_linear_accuracy_cutlass( use_cutlass=True, ) os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None) + os.environ.pop("NVTE_ROCM_ENABLE_MXFP8", None) @pytest.mark.parametrize("dtype", param_types, ids=str) From bc6253d013e0dbc19ff22539cb426def9531521f Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Sat, 25 Apr 2026 19:34:12 +0000 Subject: [PATCH 05/47] add warning print to confirm we are in fallback --- transformer_engine/common/gemm/cublaslt_gemm.cu | 3 +++ 1 file changed, 3 insertions(+) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 1d030686b..9a45b6e10 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -1125,6 +1125,9 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor // Currently only support cutlass group gemm on Hopper Arch if (!(is_hopper && use_cutlass)) { #endif + if (warn_fallback) { + NVTE_WARN("Fallback to cuBLAS grouped GEMM."); + } cublas_path(); return; } From d26f52e9715d81bffa5cb6769e418ebb568445f0 Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Sat, 2 May 2026 16:34:30 +0000 Subject: [PATCH 06/47] MXFP8 grouped fwd/bwd now reaches CK path and runs without fallback/crash; remaining issue is numerical validation vs BF16 sequential reference. --- .../ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp | 168 +++++++++++++----- .../common/gemm/cublaslt_gemm.cu | 20 ++- 2 files changed, 140 insertions(+), 48 deletions(-) diff --git a/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp b/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp index b05b844f8..59ae7c0ff 100644 --- a/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp +++ b/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp @@ -41,6 +41,8 @@ struct GroupedGemmRunContext { size_t workspace_bytes = 0; hipStream_t stream = nullptr; + bool use_a_colwise_data = false; + bool use_b_colwise_data = false; }; // Treat TE tensors as generalized 2D matrices by flattening: @@ -55,6 +57,19 @@ static inline bool get_flat_2d_dims(const transformer_engine::Tensor& t, return true; } +// Columnwise storage is the physical transposed view used to rewrite a +// normalized GEMM into CK's preferred NT presentation. Interpret its +// 2D shape consistently with the FP8 grouped GEMM path. +static inline bool get_columnwise_storage_2d_dims(const transformer_engine::SimpleTensor& t, + int64_t& d0, int64_t& d1) { + if (t.shape.size() != 2) { + return false; + } + d0 = static_cast(t.shape[1]); + d1 = static_cast(t.shape[0]); + return true; +} + static constexpr ck_tile::index_t ScaleBlockSize = 32; enum struct MxGemmPipelineType @@ -81,19 +96,11 @@ struct MxGemmPipelineTypeSelector static constexpr auto GetName() { return "GemmPipelineAgBgCrCompTDMV2"; } }; -static inline const transformer_engine::SimpleTensor& data_view(const transformer_engine::Tensor& t) { - return t.data; -} - -static inline const transformer_engine::SimpleTensor& scale_inv_view(const transformer_engine::Tensor& t) { - return t.scale_inv; -} - template static inline bool has_sufficient_workspace(const GroupedGemmRunContext& ctx) { const size_t needed = Kernel::GetWorkSpaceSize(ctx.group_num); if (!ctx.workspace || ctx.workspace_bytes < needed) { - NVTE_WARN("ck_tile_grouped_gemm: insufficient workspace for CK path. Needed bytes=", needed, + NVTE_WARN("ck_tile_mx_grouped_gemm: insufficient workspace for CK path. Needed bytes=", needed, ", available bytes=", ctx.workspace_bytes, ". Falling back."); return false; } @@ -161,11 +168,11 @@ __global__ void preshuffle_scale_gfx1250_kernel(const ScaleType* __restrict__ sr template void preShuffleScaleBuffer_gfx1250(const ScaleType* src, - ScaleType* dst, - int actual_rows, - int output_rows, - int KScale, - hipStream_t stream) + ScaleType* dst, + int actual_rows, + int output_rows, + int KScale, + hipStream_t stream) { constexpr int KPerXdlops = 128; constexpr int KStep = KPerXdlops / ScaleBlockSize; // 4 @@ -334,39 +341,75 @@ bool ck_tile_mx_grouped_gemm(const NVTETensor* A, return true; } - // Normalize input mats + // Normalize input mats similar to the FP8 grouped path. // I.e., swap A and B, as well as transa and transb. const NVTETensor* A_use = B; const NVTETensor* B_use = A; bool transA_use = transB; bool transB_use = transA; - // Validate scale type / data type combination + bool use_a_colwise_data = false; + bool use_b_colwise_data = false; + + Tensor* A0_te = convertNVTETensorCheck(A_use[0]); + Tensor* B0_te = convertNVTETensorCheck(B_use[0]); + + // CK MX grouped GEMM is presented as normalized NT, matching the FP8 grouped path. + // Selecting columnwise_data rewrites the physical storage and effective dims used by CK + // while preserving the original math. + if (transA_use) { + if (!A0_te->has_columnwise_data() || A0_te->columnwise_scale_inv.dptr == nullptr) { + NVTE_WARN("ck_tile_mx_grouped_gemm: missing A columnwise MXFP8 view for NT rewrite; falling back."); + return false; + } + use_a_colwise_data = true; + transA_use = false; + } + + if (!transB_use) { + if (!B0_te->has_columnwise_data() || B0_te->columnwise_scale_inv.dptr == nullptr) { + NVTE_WARN("ck_tile_mx_grouped_gemm: missing B columnwise MXFP8 view for NT rewrite; falling back."); + return false; + } + use_b_colwise_data = true; + transB_use = true; + } + + // Validate scale type / data type combination using the effective storage + // selected by the NT canonicalization above. // Expected input data format: fp8/bf8 (e4m3/e5m2) // Expected scale data format: e8m0 - const auto* A0 = convertNVTETensorCheck(A_use[0]); - const auto* B0 = convertNVTETensorCheck(B_use[0]); const auto* D0 = convertNVTETensorCheck(D[0]); - NVTE_CHECK(A0->scale_inv.dptr != nullptr, - "ck_tile_mx_grouped_gemm: A[0] scale_inv is not initialized"); - NVTE_CHECK(B0->scale_inv.dptr != nullptr, - "ck_tile_mx_grouped_gemm: B[0] scale_inv is not initialized"); - const auto a_scale_dtype = A0->scale_inv.dtype; - const auto b_scale_dtype = B0->scale_inv.dtype; + const auto& A0_data = use_a_colwise_data ? A0_te->columnwise_data : A0_te->data; + const auto& B0_data = use_b_colwise_data ? B0_te->columnwise_data : B0_te->data; + const auto& A0_scale = use_a_colwise_data ? A0_te->columnwise_scale_inv : A0_te->scale_inv; + const auto& B0_scale = use_b_colwise_data ? B0_te->columnwise_scale_inv : B0_te->scale_inv; + + NVTE_CHECK(A0_data.dptr != nullptr, + "ck_tile_mx_grouped_gemm: effective A[0] data is not initialized"); + NVTE_CHECK(B0_data.dptr != nullptr, + "ck_tile_mx_grouped_gemm: effective B[0] data is not initialized"); + NVTE_CHECK(A0_scale.dptr != nullptr, + "ck_tile_mx_grouped_gemm: effective A[0] scale_inv is not initialized"); + NVTE_CHECK(B0_scale.dptr != nullptr, + "ck_tile_mx_grouped_gemm: effective B[0] scale_inv is not initialized"); + + const auto a_scale_dtype = A0_scale.dtype; + const auto b_scale_dtype = B0_scale.dtype; NVTE_CHECK(a_scale_dtype == DType::kFloat8E8M0, "ck_tile_mx_grouped_gemm: A scale_inv dtype must be Float8E8M0, got ", static_cast(a_scale_dtype)); - + NVTE_CHECK(b_scale_dtype == DType::kFloat8E8M0, "ck_tile_mx_grouped_gemm: B scale_inv dtype must be Float8E8M0, got ", static_cast(b_scale_dtype)); - - const auto a_dtype = A0->dtype(); - const auto b_dtype = B0->dtype(); + + const auto a_dtype = A0_data.dtype; + const auto b_dtype = B0_data.dtype; const auto d_dtype = D0->dtype(); - NVTE_CHECK(is_fp8_dtype(a_dtype), "ck_tile_mx_grouped_gemm: A dtype must be FP8"); - NVTE_CHECK(is_fp8_dtype(b_dtype), "ck_tile_mx_grouped_gemm: B dtype must be FP8"); + NVTE_CHECK(is_fp8_dtype(a_dtype), "ck_tile_mx_grouped_gemm: effective A dtype must be FP8"); + NVTE_CHECK(is_fp8_dtype(b_dtype), "ck_tile_mx_grouped_gemm: effective B dtype must be FP8"); using AScaleType = ck_tile::e8m0_t; using BScaleType = ck_tile::e8m0_t; @@ -378,7 +421,7 @@ bool ck_tile_mx_grouped_gemm(const NVTETensor* A, ws_ptr = ws_te->data.dptr; ws_bytes = ws_te->data.numel() * typeToSize(ws_te->data.dtype); } - + GroupedGemmRunContext ctx = { A_use, B_use, @@ -388,7 +431,9 @@ bool ck_tile_mx_grouped_gemm(const NVTETensor* A, transB_use, ws_ptr, ws_bytes, - stream}; + stream, + use_a_colwise_data, + use_b_colwise_data}; const ck_tile::stream_config s{ctx.stream}; @@ -407,20 +452,48 @@ bool ck_tile_mx_grouped_gemm(const NVTETensor* A, transformer_engine::convertNVTETensorCheck(ctx.B[i]); transformer_engine::Tensor* D_te = transformer_engine::convertNVTETensorCheck(ctx.D[i]); - const auto& a = data_view(*A_te); - const auto& b = data_view(*B_te); - const auto& d = data_view(*D_te); + + const auto& a = ctx.use_a_colwise_data ? A_te->columnwise_data : A_te->data; + const auto& b = ctx.use_b_colwise_data ? B_te->columnwise_data : B_te->data; + const auto& d = D_te->data; + const auto& a_scales = + ctx.use_a_colwise_data ? A_te->columnwise_scale_inv : A_te->scale_inv; + const auto& b_scales = + ctx.use_b_colwise_data ? B_te->columnwise_scale_inv : B_te->scale_inv; + int64_t Ad0 = 0, Ad1 = 0, Bd0 = 0, Bd1 = 0, Dd0 = 0, Dd1 = 0; - if (!get_flat_2d_dims(*A_te, Ad0, Ad1) || - !get_flat_2d_dims(*B_te, Bd0, Bd1) || - !get_flat_2d_dims(*D_te, Dd0, Dd1)) { - NVTE_ERROR("ck_tile_mx_grouped_gemm: expected all groups to be rank>=2."); + + if (ctx.use_a_colwise_data) { + if (!get_columnwise_storage_2d_dims(A_te->columnwise_data, Ad0, Ad1)) { + NVTE_ERROR("ck_tile_mx_grouped_gemm: expected 2D columnwise_data for A in group ", i); + } + } else { + if (!get_flat_2d_dims(*A_te, Ad0, Ad1)) { + NVTE_ERROR("ck_tile_mx_grouped_gemm: expected rank>=2 for normalized A in group ", i); + } + } + + if (ctx.use_b_colwise_data) { + if (!get_columnwise_storage_2d_dims(B_te->columnwise_data, Bd0, Bd1)) { + NVTE_ERROR("ck_tile_mx_grouped_gemm: expected 2D columnwise_data for B in group ", i); + } + } else { + if (!get_flat_2d_dims(*B_te, Bd0, Bd1)) { + NVTE_ERROR("ck_tile_mx_grouped_gemm: expected rank>=2 for normalized B in group ", i); + } + } + + if (!get_flat_2d_dims(*D_te, Dd0, Dd1)) { + NVTE_ERROR("ck_tile_mx_grouped_gemm: expected rank>=2 for normalized D in group ", i); + } + if (a.dptr == nullptr || b.dptr == nullptr || a_scales.dptr == nullptr || + b_scales.dptr == nullptr) { + NVTE_ERROR("ck_tile_mx_grouped_gemm: effective A/B data or scale_inv is missing."); } - const auto& a_scales = scale_inv_view(*A_te); - const auto& b_scales = scale_inv_view(*B_te); if (a_scales.shape.size() != 2 || b_scales.shape.size() != 2) { - NVTE_ERROR("ck_tile_mx_grouped_gemm: expected A/B scale_inv tensors to be rank-2."); + NVTE_ERROR("ck_tile_mx_grouped_gemm: expected effective A/B scale_inv tensors to be rank-2."); } + const int64_t M = ctx.transA ? Ad1 : Ad0; const int64_t K = ctx.transA ? Ad0 : Ad1; const int64_t N = ctx.transB ? Bd0 : Bd1; @@ -430,15 +503,20 @@ bool ck_tile_mx_grouped_gemm(const NVTETensor* A, } const int KScale = static_cast(K / ScaleBlockSize); if (Kb != K) { - NVTE_ERROR("ck_tile_mx_grouped_gemm: K mismatch between A and B in group ", i); + NVTE_ERROR("ck_tile_mx_grouped_gemm: K mismatch between A and B in group ", i, + ". op(A)=", M, "x", K, ", op(B)=", Kb, "x", N); } if (Dd0 != M || Dd1 != N) { - NVTE_ERROR("ck_tile_mx_grouped_gemm: D shape mismatch in group ", i); + NVTE_ERROR("ck_tile_mx_grouped_gemm: D shape mismatch in group ", i, + ". D=", Dd0, "x", Dd1, ", expected=", M, "x", N); } + const ck_tile::index_t stride_A = static_cast(Ad1); const ck_tile::index_t stride_B = static_cast(Bd1); const ck_tile::index_t stride_E = static_cast(Dd1); - // Pre-shuffle scale buffers for the hardware + + // Pre-shuffle scale buffers for the hardware. + // For the NT-normalized presentation, A scales are MxKScale and B scales are NxKScale. const int a_scale_actual_rows = static_cast(M); const int a_scale_output_rows = ck_tile::integer_least_multiple( diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 9a45b6e10..b3863350e 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -1159,15 +1159,29 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor }; #endif +#ifdef __HIP_PLATFORM_AMD__ + auto effective_dtype = [](const transformer_engine::Tensor *t) { + if (t->has_data()) { + return t->data.dtype; + } + if (t->has_columnwise_data()) { + return t->columnwise_data.dtype; + } + return t->data.dtype; + }; +#endif + auto is_supported_dtype = [&]() -> bool { auto *inputA = transformer_engine::convertNVTETensorCheck(A[0]); auto *inputB = transformer_engine::convertNVTETensorCheck(B[0]); auto *OutputD = transformer_engine::convertNVTETensorCheck(D[0]); #ifdef __HIP_PLATFORM_AMD__ - auto A_dt = inputA->data.dtype; - auto B_dt = inputB->data.dtype; + auto A_dt = effective_dtype(inputA); + auto B_dt = effective_dtype(inputB); auto D_dt = OutputD->data.dtype; - return (is_fp8_dtype(A_dt) && is_fp8_dtype(B_dt)); + + return ((is_fp8_dtype(A_dt) && is_fp8_dtype(B_dt)) || + ((A_dt == B_dt) && (A_dt == D_dt) && is_fp16_dtype(A_dt))); #else auto A_type = get_cuda_dtype(inputA->data.dtype); auto B_type = get_cuda_dtype(inputB->data.dtype); From e295e745518927a44b784816f75f25ae9c21a1bc Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Sat, 2 May 2026 20:09:45 +0000 Subject: [PATCH 07/47] add cpp test for ck tile group mxfp8 gemm forward --- tests/cpp/operator/CMakeLists.txt | 6 +- .../test_te_ck_grouped_mxfp8_forward_refs.cu | 554 ++++++++++++++++++ .../ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp | 20 +- 3 files changed, 578 insertions(+), 2 deletions(-) create mode 100644 tests/cpp/operator/test_te_ck_grouped_mxfp8_forward_refs.cu diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index 0ebd7fdfe..fa9f9a542 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -16,6 +16,7 @@ list(APPEND test_cuda_sources test_dequantize_mxfp8.cu test_dequantize_nvfp4.cu test_cast_nvfp4_transpose.cu + test_te_ck_grouped_mxfp8_forward_refs.cu test_transpose.cu test_cast_transpose.cu test_cast_transpose_current_scaling.cu @@ -31,7 +32,8 @@ list(APPEND test_cuda_sources test_multi_unpadding.cu test_causal_softmax.cu test_swap_first_dims.cu - ../test_common.cu) + ../test_common.cu) + if(USE_CUDA) list(APPEND test_cuda_sources test_cast_float8blockwise.cu @@ -54,12 +56,14 @@ endif() # Find required packages find_package(OpenMP REQUIRED) + if(USE_CUDA) list(APPEND test_operator_LINKER_LIBS CUDA::cudart GTest::gtest_main ${TE_LIB} CUDA::nvrtc CUDNN::cudnn) target_link_libraries(test_operator PUBLIC ${test_operator_LINKER_LIBS} OpenMP::OpenMP_CXX) else() target_link_libraries(test_operator PUBLIC hip::host hip::device GTest::gtest_main ${TE_LIB} OpenMP::OpenMP_CXX hiprand) endif() + target_compile_options(test_operator PRIVATE -O2 -fopenmp) include(GoogleTest) diff --git a/tests/cpp/operator/test_te_ck_grouped_mxfp8_forward_refs.cu b/tests/cpp/operator/test_te_ck_grouped_mxfp8_forward_refs.cu new file mode 100644 index 000000000..0872b1640 --- /dev/null +++ b/tests/cpp/operator/test_te_ck_grouped_mxfp8_forward_refs.cu @@ -0,0 +1,554 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +// Forward-only TE CK grouped MXFP8 validation. +// +// Compares three paths for grouped MXFP8 forward GEMM: +// 1. TE nvte_multi_tensor_gemm grouped forward path (CK backend selected by env) +// 2. ck_tile::reference_mx_gemm host reference, using exact quantized operands/scales +// 3. TE HIP reference kernel adapted from test_cublaslt_gemm.cu compute_ref_kernel +// +// Intended drop-in location: +// TransformerEngine/tests/cpp/operator/test_te_ck_grouped_mxfp8_forward_refs.cu + +#ifndef CK_TILE_USE_OCP_FP8 +#define CK_TILE_USE_OCP_FP8 1 +#endif + +#include +#include +#include + +#include +#include +#include + +#include "../test_common.h" + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace transformer_engine; +using namespace test; + +using fp8 = fp8e4m3; +using bf16_t = bf16; +using e8m0_t_te = fp8e8m0; + +namespace { + +struct CaseConfig { + size_t m_total; + size_t n; + size_t k; + int experts; + float scale; + int seed; + int ck_ref_groups; +}; + +static std::string case_name(const testing::TestParamInfo& info) { + const auto& c = info.param; + std::ostringstream os; + os << "M" << c.m_total << "_N" << c.n << "_K" << c.k + << "_E" << c.experts; + return os.str(); +} + +static void set_env_defaults() { + setenv("NVTE_USE_CUTLASS_GROUPED_GEMM", "1", 1); + setenv("NVTE_CUTLASS_GROUPED_GEMM_WARN_FALLBACK", "1", 0); + setenv("NVTE_ROCM_ENABLE_MXFP8", "1", 0); +} + +static float to_float(float x) { return x; } +static float to_float(const bf16_t& x) { return static_cast(x); } +static float to_float(const ck_tile::bfloat16_t& x) { return static_cast(x); } + +__device__ __host__ __forceinline__ float ref_gelu_unused(float x) { + float cdf = 0.5f * (1.0f + tanhf((0.7978845608028654f * (x + 0.044715f * x * x * x)))); + return x * cdf; +} + +template +__global__ void compute_ref_kernel( + const A_Type* __restrict__ a_data, + const B_Type* __restrict__ b_data, + float a_scale_inv_scalar, + float b_scale_inv_scalar, + const e8m0_t_te* __restrict__ a_scale_inv_mxfp8, + const e8m0_t_te* __restrict__ b_scale_inv_mxfp8, + size_t a_scale_ld, + size_t b_scale_ld, + bool a_scale_is_colwise, + bool b_scale_is_colwise, + const Bias_Type* __restrict__ bias_data, + float d_scale, + size_t m, size_t k, size_t n, + D_Type* __restrict__ d_data, + float* __restrict__ d_amax, + Gelu_Type* __restrict__ gelu_data, + bool transa, + bool transb, + bool is_fp8_output, + bool a_is_colwise, + bool b_is_colwise, + bool use_mxfp8) { + const size_t jj = blockIdx.x * blockDim.x + threadIdx.x; + const size_t ii = blockIdx.y * blockDim.y + threadIdx.y; + const bool in_range = (ii < m) && (jj < n); + + float val = 0.0f; + if (in_range) { + for (size_t kk = 0; kk < k; ++kk) { + size_t a_idx = 0; + size_t b_idx = 0; + + if (use_mxfp8) { + a_idx = transa ? (ii * k + kk) : (kk * m + ii); + b_idx = transb ? (kk * n + jj) : (jj * k + kk); + } else { + a_idx = a_is_colwise ? (ii * k + kk) + : (transa ? (ii * k + kk) : (kk * m + ii)); + b_idx = b_is_colwise ? (jj * k + kk) + : (transb ? (kk * n + jj) : (jj * k + kk)); + } + + float a_scale_inv_val = a_scale_inv_scalar; + float b_scale_inv_val = b_scale_inv_scalar; + + if (a_scale_inv_mxfp8) { + const size_t kc = kk / 32; + const size_t a_scale_idx = + a_scale_is_colwise ? (kc * a_scale_ld + ii) : (ii * a_scale_ld + kc); + const size_t b_scale_idx = + b_scale_is_colwise ? (kc * b_scale_ld + jj) : (jj * b_scale_ld + kc); + a_scale_inv_val = exp2f(a_scale_inv_mxfp8[a_scale_idx] - 127.0f); + b_scale_inv_val = exp2f(b_scale_inv_mxfp8[b_scale_idx] - 127.0f); + } + + const float a_val = static_cast(a_data[a_idx]); + const float b_val = static_cast(b_data[b_idx]); + val += a_scale_inv_val * a_val * b_scale_inv_val * b_val; + } + + if (bias_data) val += static_cast(bias_data[ii]); + if (gelu_data) { + gelu_data[ii + jj * m] = static_cast(val); + val = ref_gelu_unused(val); + } + + const float scaled = val * d_scale; + d_data[ii + jj * m] = static_cast(scaled); + } + + if (is_fp8_output && d_amax) { + const int tid = threadIdx.y * blockDim.x + threadIdx.x; + const int nthreads = blockDim.x * blockDim.y; + extern __shared__ float s_amax[]; + s_amax[tid] = in_range ? fabsf(val) : 0.0f; + __syncthreads(); + for (int offset = nthreads / 2; offset > 0; offset /= 2) { + if (tid < offset) s_amax[tid] = fmaxf(s_amax[tid], s_amax[tid + offset]); + __syncthreads(); + } + if (tid == 0) atomicMax(d_amax, s_amax[0]); + } +} + +template +static void fill_randn_cpu(Tensor* t, float scale, int seed) { + std::mt19937 gen(seed); + std::normal_distribution dist(0.0f, scale); + const size_t n = product(t->rowwise_shape()); + T* ptr = t->rowwise_cpu_dptr(); + for (size_t i = 0; i < n; ++i) ptr[i] = static_cast(dist(gen)); + t->from_cpu(); +} + +static std::vector split_even(size_t m_total, int experts) { + NVTE_CHECK(experts > 0, "experts must be > 0"); + NVTE_CHECK(m_total % static_cast(experts) == 0, + "m_total must be divisible by experts"); + return std::vector(experts, m_total / static_cast(experts)); +} + +struct ErrorStats { + size_t count = 0; + double sum_abs = 0.0; + double sum_rel = 0.0; + double sum_ref_abs = 0.0; + double sum_got_abs = 0.0; + float max_abs = 0.0f; + float max_rel = 0.0f; + std::vector abs_errs; +}; + +static void add_err(ErrorStats& s, float got, float ref) { + const float abs_err = std::abs(got - ref); + const float rel_err = abs_err / std::max(std::abs(ref), 1.0e-12f); + s.count++; + s.sum_abs += abs_err; + s.sum_rel += rel_err; + s.sum_ref_abs += std::abs(ref); + s.sum_got_abs += std::abs(got); + s.max_abs = std::max(s.max_abs, abs_err); + s.max_rel = std::max(s.max_rel, rel_err); + s.abs_errs.push_back(abs_err); +} + +static float quantile(std::vector& values, double q) { + if (values.empty()) return 0.0f; + const size_t pos = std::min(static_cast(q * (values.size() - 1)), values.size() - 1); + std::nth_element(values.begin(), values.begin() + pos, values.end()); + return values[pos]; +} + +static void print_stats(const std::string& label, ErrorStats s) { + std::vector v50 = s.abs_errs; + std::vector v90 = s.abs_errs; + std::vector v99 = s.abs_errs; + const double denom = static_cast(std::max(s.count, 1)); + std::cout << std::fixed << std::setprecision(6) + << label + << " count=" << s.count + << " max_abs=" << s.max_abs + << " mean_abs=" << (s.sum_abs / denom) + << " p50_abs=" << quantile(v50, 0.50) + << " p90_abs=" << quantile(v90, 0.90) + << " p99_abs=" << quantile(v99, 0.99) + << " max_rel=" << s.max_rel + << " mean_rel=" << (s.sum_rel / denom) + << " ref_abs_mean=" << (s.sum_ref_abs / denom) + << " got_abs_mean=" << (s.sum_got_abs / denom) + << std::endl; +} + +static void expect_reference_match(const std::string& label, + const ErrorStats& stats, + float max_abs_limit, + float mean_abs_limit) { + print_stats(label, stats); + EXPECT_LE(stats.max_abs, max_abs_limit) << label; + EXPECT_LE(stats.sum_abs / static_cast(std::max(stats.count, 1)), + static_cast(mean_abs_limit)) << label; +} + +static void run_te_grouped_mxfp8_forward(const std::vector& weights_mx, + const std::vector& inputs_mx, + std::vector* outputs, + Tensor* workspace, + int math_sm_count) { + const size_t groups = weights_mx.size(); + std::vector A(groups), B(groups), D(groups), Bias(groups), PreGelu(groups); + std::vector empty_bias(groups), empty_pregelu(groups); + + // Match GroupedLinear forward / te_general_grouped_gemm: + // A = weight [N,K], transa=true + // B = input [M,K], transb=false + // D = output [M,N] + for (size_t i = 0; i < groups; ++i) { + A[i] = const_cast(weights_mx[i]).data(); + B[i] = const_cast(inputs_mx[i]).data(); + D[i] = (*outputs)[i].data(); + Bias[i] = empty_bias[i].data(); + PreGelu[i] = empty_pregelu[i].data(); + } + + std::vector Workspaces(1); + Workspaces[0] = workspace->data(); + + nvte_multi_tensor_gemm(A.data(), + B.data(), + D.data(), + Bias.data(), + PreGelu.data(), + groups, + true, // transa: weight [N,K] -> op(A) [K,N] + false, // transb: input [M,K] -> op(B) [M,K] + false, // grad + Workspaces.data(), + false, // accumulate + false, // use_split_accumulator + math_sm_count, + 0); + NVTE_CHECK_CUDA(cudaDeviceSynchronize()); +} + +template +static void run_hip_ref_for_group(const Tensor& input_mx, + const Tensor& weight_mx, + Tensor* ref_d_colmajor, + size_t m, + size_t k, + size_t n) { + // compute_ref_kernel expects A=input [M,K], B=weight [N,K], transa=true, transb=false, + // and writes D as column-major MxN into rowwise storage shaped [N,M]. + const auto a_s = input_mx.rowwise_scale_inv_shape(); + const auto b_s = weight_mx.rowwise_scale_inv_shape(); + NVTE_CHECK(a_s.ndim == 2 && b_s.ndim == 2, "Expected 2D MXFP8 scale_inv tensors"); + const size_t a_scale_ld = a_s.data[1]; + const size_t b_scale_ld = b_s.data[1]; + + dim3 block(16, 16); + dim3 grid(static_cast((n + block.x - 1) / block.x), + static_cast((m + block.y - 1) / block.y)); + const size_t shmem_bytes = size_t(block.x) * size_t(block.y) * sizeof(float); + + compute_ref_kernel + <<>>( + static_cast(input_mx.rowwise_dptr()), + static_cast(weight_mx.rowwise_dptr()), + 1.0f, + 1.0f, + static_cast(input_mx.rowwise_scale_inv_dptr()), + static_cast(weight_mx.rowwise_scale_inv_dptr()), + a_scale_ld, + b_scale_ld, + false, // input scale rowwise [M,K/32] + false, // weight scale rowwise [N,K/32] + nullptr, + 1.0f, + m, k, n, + static_cast(ref_d_colmajor->rowwise_dptr()), + nullptr, + nullptr, + true, // transa for A=input in this reference-kernel convention + false, // transb for B=weight + false, + false, + false, + true); + NVTE_CHECK_CUDA(cudaGetLastError()); + NVTE_CHECK_CUDA(cudaDeviceSynchronize()); +} + +static ck_tile::HostTensor run_ck_tile_reference_for_group( + const Tensor& input_mx, + const Tensor& weight_mx, + size_t m, + size_t k, + size_t n) { + using namespace ck_tile::literals; + using AType = ck_tile::fp8_t; + using BType = ck_tile::fp8_t; + using CType = ck_tile::bfloat16_t; + using ScaleType = ck_tile::e8m0_t; + + const size_t kscale = k / 32; + + ck_tile::HostTensor a_host( + ck_tile::HostTensorDescriptor({m, k}, {k, 1_uz})); + ck_tile::HostTensor b_host( + ck_tile::HostTensorDescriptor({k, n}, {1_uz, k})); + ck_tile::HostTensor c_ref( + ck_tile::HostTensorDescriptor({m, n}, {n, 1_uz})); + ck_tile::HostTensor a_scale_ref( + ck_tile::HostTensorDescriptor({m, kscale}, {kscale, 1_uz})); + ck_tile::HostTensor b_scale_ref( + ck_tile::HostTensorDescriptor({kscale, n}, {1_uz, kscale})); + + c_ref.SetZero(); + + NVTE_CHECK_CUDA(cudaMemcpy(a_host.data(), + input_mx.rowwise_dptr(), + a_host.get_element_space_size_in_bytes(), + cudaMemcpyDeviceToHost)); + NVTE_CHECK_CUDA(cudaMemcpy(b_host.data(), + weight_mx.rowwise_dptr(), + b_host.get_element_space_size_in_bytes(), + cudaMemcpyDeviceToHost)); + NVTE_CHECK_CUDA(cudaMemcpy(a_scale_ref.data(), + input_mx.rowwise_scale_inv_dptr(), + a_scale_ref.get_element_space_size_in_bytes(), + cudaMemcpyDeviceToHost)); + NVTE_CHECK_CUDA(cudaMemcpy(b_scale_ref.data(), + weight_mx.rowwise_scale_inv_dptr(), + b_scale_ref.get_element_space_size_in_bytes(), + cudaMemcpyDeviceToHost)); + + ck_tile::reference_mx_gemm( + a_host, b_host, c_ref, a_scale_ref, b_scale_ref); + return c_ref; +} + +static ErrorStats compare_te_vs_hip(const Tensor& te_out_rowmajor, + const Tensor& hip_ref_colmajor, + size_t m, + size_t n) { + ErrorStats stats; + const bf16_t* te = te_out_rowmajor.rowwise_cpu_dptr(); + const bf16_t* hip = hip_ref_colmajor.rowwise_cpu_dptr(); + for (size_t i = 0; i < m; ++i) { + for (size_t j = 0; j < n; ++j) { + add_err(stats, to_float(te[i * n + j]), to_float(hip[j * m + i])); + } + } + return stats; +} + +static ErrorStats compare_te_vs_ck(const Tensor& te_out_rowmajor, + const ck_tile::HostTensor& ck_ref, + size_t m, + size_t n) { + ErrorStats stats; + const bf16_t* te = te_out_rowmajor.rowwise_cpu_dptr(); + for (size_t i = 0; i < m; ++i) { + for (size_t j = 0; j < n; ++j) { + add_err(stats, to_float(te[i * n + j]), to_float(ck_ref(i, j))); + } + } + return stats; +} + +static ErrorStats compare_ck_vs_hip(const ck_tile::HostTensor& ck_ref, + const Tensor& hip_ref_colmajor, + size_t m, + size_t n) { + ErrorStats stats; + const bf16_t* hip = hip_ref_colmajor.rowwise_cpu_dptr(); + for (size_t i = 0; i < m; ++i) { + for (size_t j = 0; j < n; ++j) { + add_err(stats, to_float(ck_ref(i, j)), to_float(hip[j * m + i])); + } + } + return stats; +} + +static void run_case(const CaseConfig& cfg) { + set_env_defaults(); + + ASSERT_EQ(cfg.k % 128, 0UL) << "K must be a multiple of 128 for MXFP8"; + ASSERT_EQ(cfg.m_total % static_cast(cfg.experts), 0UL); + + cudaDeviceProp prop; + NVTE_CHECK_CUDA(cudaGetDeviceProperties(&prop, 0)); +#ifdef __HIP_PLATFORM_AMD__ + const bool is_gfx950_or_newer_cdna = (prop.major == 9 && prop.minor >= 5); + const bool is_gfx1250 = (prop.major == 12 && prop.minor == 5); + + if (!is_gfx950_or_newer_cdna && !is_gfx1250) { + GTEST_SKIP() << "MXFP8 requires gfx950+ or gfx1250 in this test. GPU=" << prop.name + << " major=" << prop.major << " minor=" << prop.minor; + } +#endif + + const auto m_splits = split_even(cfg.m_total, cfg.experts); + const size_t per_m = m_splits[0]; + const int groups_to_ck = std::min(cfg.ck_ref_groups, cfg.experts); + + std::cout << "\n=== TE CK grouped MXFP8 forward reference comparison ===\n" + << "M_total=" << cfg.m_total << " N=" << cfg.n << " K=" << cfg.k + << " experts=" << cfg.experts << " per_expert_M=" << per_m + << " scale=" << cfg.scale << " seed=" << cfg.seed << "\n" + << "NVTE_USE_CUTLASS_GROUPED_GEMM=" << std::getenv("NVTE_USE_CUTLASS_GROUPED_GEMM") << "\n" + << "NVTE_ROCM_ENABLE_MXFP8=" << std::getenv("NVTE_ROCM_ENABLE_MXFP8") << "\n" + << "CK_TILE_USE_OCP_FP8=" << CK_TILE_USE_OCP_FP8 << "\n" + << "GPU=" << prop.name << " SM/CU count=" << prop.multiProcessorCount << "\n"; + + std::vector input_src; + std::vector weight_src; + std::vector input_mx; + std::vector weight_mx; + std::vector output_te; + std::vector output_hip_colmajor; + input_src.reserve(cfg.experts); + weight_src.reserve(cfg.experts); + input_mx.reserve(cfg.experts); + weight_mx.reserve(cfg.experts); + output_te.reserve(cfg.experts); + output_hip_colmajor.reserve(cfg.experts); + + for (int g = 0; g < cfg.experts; ++g) { + const size_t m = m_splits[g]; + input_src.emplace_back("input_src", std::vector{m, cfg.k}, DType::kBFloat16); + weight_src.emplace_back("weight_src", std::vector{cfg.n, cfg.k}, DType::kBFloat16); + + fill_randn_cpu(&input_src.back(), cfg.scale, cfg.seed + 1009 * g + 17); + fill_randn_cpu(&weight_src.back(), cfg.scale, cfg.seed + 1009 * g + 29); + + input_mx.emplace_back("input_mx", std::vector{m, cfg.k}, DType::kFloat8E4M3, + true, false, NVTEScalingMode::NVTE_MXFP8_1D_SCALING); + weight_mx.emplace_back("weight_mx", std::vector{cfg.n, cfg.k}, DType::kFloat8E4M3, + true, false, NVTEScalingMode::NVTE_MXFP8_1D_SCALING); + + nvte_quantize(input_src.back().data(), input_mx.back().data(), 0); + nvte_quantize(weight_src.back().data(), weight_mx.back().data(), 0); + + output_te.emplace_back("output_te", std::vector{m, cfg.n}, DType::kBFloat16); + output_hip_colmajor.emplace_back("output_hip_colmajor", std::vector{cfg.n, m}, DType::kBFloat16); + } + NVTE_CHECK_CUDA(cudaDeviceSynchronize()); + + Tensor workspace("workspace", std::vector{67108864}, DType::kByte); + + run_te_grouped_mxfp8_forward(weight_mx, input_mx, &output_te, &workspace, + prop.multiProcessorCount); + for (auto& out : output_te) out.to_cpu(); + + for (int g = 0; g < cfg.experts; ++g) { + run_hip_ref_for_group(input_mx[g], weight_mx[g], &output_hip_colmajor[g], + m_splits[g], cfg.k, cfg.n); + output_hip_colmajor[g].to_cpu(); + expect_reference_match("group " + std::to_string(g) + " TE_vs_HIP_REF", + compare_te_vs_hip(output_te[g], output_hip_colmajor[g], + m_splits[g], cfg.n), + 0.25f, + 0.03f); + } + + for (int g = 0; g < groups_to_ck; ++g) { + auto ck_ref = run_ck_tile_reference_for_group(input_mx[g], weight_mx[g], + m_splits[g], cfg.k, cfg.n); + expect_reference_match("group " + std::to_string(g) + " TE_vs_CK_REF ", + compare_te_vs_ck(output_te[g], ck_ref, m_splits[g], cfg.n), + 0.25f, + 0.03f); + expect_reference_match("group " + std::to_string(g) + " CK_vs_HIP_REF", + compare_ck_vs_hip(ck_ref, output_hip_colmajor[g], + m_splits[g], cfg.n), + 0.25f, + 0.03f); + } +} + +} // namespace + +class GroupedMXFP8ForwardRefsTestSuite : public ::testing::TestWithParam {}; + +TEST_P(GroupedMXFP8ForwardRefsTestSuite, MatchesCKTileAndHIPReferences) { + run_case(GetParam()); +} + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + GroupedMXFP8ForwardRefsTestSuite, + ::testing::Values( + // Small enough for quick CI-style sanity. + CaseConfig{1024, 1024, 1024, 2, 0.25f, 1234, 1}, + // Reproduces the earlier forward-only "failure" scale/shape regime, but + // validates against true MXFP8 references instead of BF16. + CaseConfig{1536, 4096, 4096, 3, 0.25f, 1234, 1}, + // Llama-ish suspicious path. CK reference only group 0 to keep runtime sane; + // HIP reference checks all groups. + CaseConfig{4096, 12288, 4096, 4, 0.25f, 1234, 1}), + case_name); diff --git a/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp b/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp index 59ae7c0ff..932f22ecc 100644 --- a/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp +++ b/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp @@ -6,14 +6,17 @@ #include #include "../../common.h" + #include "ck_tile/core.hpp" -#include "ck_tile/host.hpp" #include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/gemm.hpp" #include "ck_tile/ops/gemm/kernel/mx_grouped_gemm_kernel.hpp" #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" +#include +#include + namespace transformer_engine { namespace mx_grouped_gemm { @@ -511,6 +514,21 @@ bool ck_tile_mx_grouped_gemm(const NVTETensor* A, ". D=", Dd0, "x", Dd1, ", expected=", M, "x", N); } + if (i == 0) { + printf("[MX CK] transA=%d transB=%d use_a_col=%d use_b_col=%d " + "M=%ld N=%ld K=%ld Ad=[%ld,%ld] Bd=[%ld,%ld] " + "a_scale_shape=[%zu,%zu] b_scale_shape=[%zu,%zu]\n", + static_cast(ctx.transA), + static_cast(ctx.transB), + static_cast(ctx.use_a_colwise_data), + static_cast(ctx.use_b_colwise_data), + M, N, K, Ad0, Ad1, Bd0, Bd1, + a_scales.shape.size() > 0 ? a_scales.shape[0] : 0, + a_scales.shape.size() > 1 ? a_scales.shape[1] : 0, + b_scales.shape.size() > 0 ? b_scales.shape[0] : 0, + b_scales.shape.size() > 1 ? b_scales.shape[1] : 0); + } + const ck_tile::index_t stride_A = static_cast(Ad1); const ck_tile::index_t stride_B = static_cast(Bd1); const ck_tile::index_t stride_E = static_cast(Dd1); From 1784045d88d240a4cc3eff6aceec1a9fc88c51be Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Sun, 3 May 2026 15:33:02 +0000 Subject: [PATCH 08/47] Fix MXFP8 grouped GEMM scale handling for NN/TN/NT --- .../ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp | 121 ++++++++---------- 1 file changed, 53 insertions(+), 68 deletions(-) diff --git a/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp b/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp index 932f22ecc..4a5a4aaa9 100644 --- a/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp +++ b/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp @@ -15,6 +15,7 @@ #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" #include +#include #include namespace transformer_engine { @@ -60,19 +61,6 @@ static inline bool get_flat_2d_dims(const transformer_engine::Tensor& t, return true; } -// Columnwise storage is the physical transposed view used to rewrite a -// normalized GEMM into CK's preferred NT presentation. Interpret its -// 2D shape consistently with the FP8 grouped GEMM path. -static inline bool get_columnwise_storage_2d_dims(const transformer_engine::SimpleTensor& t, - int64_t& d0, int64_t& d1) { - if (t.shape.size() != 2) { - return false; - } - d0 = static_cast(t.shape[1]); - d1 = static_cast(t.shape[0]); - return true; -} - static constexpr ck_tile::index_t ScaleBlockSize = 32; enum struct MxGemmPipelineType @@ -351,35 +339,13 @@ bool ck_tile_mx_grouped_gemm(const NVTETensor* A, bool transA_use = transB; bool transB_use = transA; - bool use_a_colwise_data = false; - bool use_b_colwise_data = false; + const bool use_a_colwise_data = transA_use; + const bool use_b_colwise_data = !transB_use; Tensor* A0_te = convertNVTETensorCheck(A_use[0]); Tensor* B0_te = convertNVTETensorCheck(B_use[0]); - // CK MX grouped GEMM is presented as normalized NT, matching the FP8 grouped path. - // Selecting columnwise_data rewrites the physical storage and effective dims used by CK - // while preserving the original math. - if (transA_use) { - if (!A0_te->has_columnwise_data() || A0_te->columnwise_scale_inv.dptr == nullptr) { - NVTE_WARN("ck_tile_mx_grouped_gemm: missing A columnwise MXFP8 view for NT rewrite; falling back."); - return false; - } - use_a_colwise_data = true; - transA_use = false; - } - - if (!transB_use) { - if (!B0_te->has_columnwise_data() || B0_te->columnwise_scale_inv.dptr == nullptr) { - NVTE_WARN("ck_tile_mx_grouped_gemm: missing B columnwise MXFP8 view for NT rewrite; falling back."); - return false; - } - use_b_colwise_data = true; - transB_use = true; - } - - // Validate scale type / data type combination using the effective storage - // selected by the NT canonicalization above. + // Validate scale type / data type combination. // Expected input data format: fp8/bf8 (e4m3/e5m2) // Expected scale data format: e8m0 const auto* D0 = convertNVTETensorCheck(D[0]); @@ -466,24 +432,17 @@ bool ck_tile_mx_grouped_gemm(const NVTETensor* A, int64_t Ad0 = 0, Ad1 = 0, Bd0 = 0, Bd1 = 0, Dd0 = 0, Dd1 = 0; - if (ctx.use_a_colwise_data) { - if (!get_columnwise_storage_2d_dims(A_te->columnwise_data, Ad0, Ad1)) { - NVTE_ERROR("ck_tile_mx_grouped_gemm: expected 2D columnwise_data for A in group ", i); - } - } else { - if (!get_flat_2d_dims(*A_te, Ad0, Ad1)) { - NVTE_ERROR("ck_tile_mx_grouped_gemm: expected rank>=2 for normalized A in group ", i); - } + // MXFP8 columnwise_data is not a physical transpose. It has the same + // logical tensor shape as rowwise data, but is quantized with scales + // along the other dimension. Therefore dims/strides must always be + // derived from the TE tensor shape, not from columnwise_data.shape + // interpreted as a transposed storage view. + if (!get_flat_2d_dims(*A_te, Ad0, Ad1)) { + NVTE_ERROR("ck_tile_mx_grouped_gemm: expected rank>=2 for normalized A in group ", i); } - if (ctx.use_b_colwise_data) { - if (!get_columnwise_storage_2d_dims(B_te->columnwise_data, Bd0, Bd1)) { - NVTE_ERROR("ck_tile_mx_grouped_gemm: expected 2D columnwise_data for B in group ", i); - } - } else { - if (!get_flat_2d_dims(*B_te, Bd0, Bd1)) { - NVTE_ERROR("ck_tile_mx_grouped_gemm: expected rank>=2 for normalized B in group ", i); - } + if (!get_flat_2d_dims(*B_te, Bd0, Bd1)) { + NVTE_ERROR("ck_tile_mx_grouped_gemm: expected rank>=2 for normalized B in group ", i); } if (!get_flat_2d_dims(*D_te, Dd0, Dd1)) { @@ -556,20 +515,46 @@ bool ck_tile_mx_grouped_gemm(const NVTETensor* A, std::make_unique(b_scale_shuffled_bytes)); void* a_scale_shuffled_ptr = a_scale_shuffled_bufs.back()->GetDeviceBuffer(); void* b_scale_shuffled_ptr = b_scale_shuffled_bufs.back()->GetDeviceBuffer(); - preShuffleScaleBuffer_gfx1250( - reinterpret_cast(a_scales.dptr), - reinterpret_cast(a_scale_shuffled_ptr), - a_scale_actual_rows, - a_scale_output_rows, - KScale, - stream); - preShuffleScaleBuffer_gfx1250( - reinterpret_cast(b_scales.dptr), - reinterpret_cast(b_scale_shuffled_ptr), - b_scale_actual_rows, - b_scale_output_rows, - KScale, - stream); + // CK expects canonical pre-shuffled scale buffers laid out as + // A: [M, KScale] and B: [N, KScale], independent of A/B data layouts. + // TE rowwise MXFP8 scale_inv is [rows, KScale] and can be read with + // KStride=true. TE columnwise_scale_inv is [KScale, rows] and must be + // read with KStride=false before writing CK's canonical shuffled layout. + if (ctx.use_a_colwise_data) { + preShuffleScaleBuffer_gfx1250( + reinterpret_cast(a_scales.dptr), + reinterpret_cast(a_scale_shuffled_ptr), + a_scale_actual_rows, + a_scale_output_rows, + KScale, + stream); + } else { + preShuffleScaleBuffer_gfx1250( + reinterpret_cast(a_scales.dptr), + reinterpret_cast(a_scale_shuffled_ptr), + a_scale_actual_rows, + a_scale_output_rows, + KScale, + stream); + } + + if (ctx.use_b_colwise_data) { + preShuffleScaleBuffer_gfx1250( + reinterpret_cast(b_scales.dptr), + reinterpret_cast(b_scale_shuffled_ptr), + b_scale_actual_rows, + b_scale_output_rows, + KScale, + stream); + } else { + preShuffleScaleBuffer_gfx1250( + reinterpret_cast(b_scales.dptr), + reinterpret_cast(b_scale_shuffled_ptr), + b_scale_actual_rows, + b_scale_output_rows, + KScale, + stream); + } descs.emplace_back(mx_grouped_gemm_kargs( a.dptr, a_scale_shuffled_ptr, From fe99bf30d1b99c701217ef37b223f09da14e2384 Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Sun, 3 May 2026 15:59:03 +0000 Subject: [PATCH 09/47] update ck mxfp8 group gemm gtest to exercise mixed dtypes --- tests/cpp/operator/CMakeLists.txt | 2 +- .../test_te_ck_grouped_mxfp8_forward_refs.cu | 554 ------------------ 2 files changed, 1 insertion(+), 555 deletions(-) delete mode 100644 tests/cpp/operator/test_te_ck_grouped_mxfp8_forward_refs.cu diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index fa9f9a542..c81ab1e62 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -16,7 +16,7 @@ list(APPEND test_cuda_sources test_dequantize_mxfp8.cu test_dequantize_nvfp4.cu test_cast_nvfp4_transpose.cu - test_te_ck_grouped_mxfp8_forward_refs.cu + test_te_ck_grouped_mxfp8.cu test_transpose.cu test_cast_transpose.cu test_cast_transpose_current_scaling.cu diff --git a/tests/cpp/operator/test_te_ck_grouped_mxfp8_forward_refs.cu b/tests/cpp/operator/test_te_ck_grouped_mxfp8_forward_refs.cu deleted file mode 100644 index 0872b1640..000000000 --- a/tests/cpp/operator/test_te_ck_grouped_mxfp8_forward_refs.cu +++ /dev/null @@ -1,554 +0,0 @@ -/************************************************************************* - * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. - * - * License for AMD contributions = MIT. See LICENSE for more information - ************************************************************************/ - -// Forward-only TE CK grouped MXFP8 validation. -// -// Compares three paths for grouped MXFP8 forward GEMM: -// 1. TE nvte_multi_tensor_gemm grouped forward path (CK backend selected by env) -// 2. ck_tile::reference_mx_gemm host reference, using exact quantized operands/scales -// 3. TE HIP reference kernel adapted from test_cublaslt_gemm.cu compute_ref_kernel -// -// Intended drop-in location: -// TransformerEngine/tests/cpp/operator/test_te_ck_grouped_mxfp8_forward_refs.cu - -#ifndef CK_TILE_USE_OCP_FP8 -#define CK_TILE_USE_OCP_FP8 1 -#endif - -#include -#include -#include - -#include -#include -#include - -#include "../test_common.h" - -#include "ck_tile/core.hpp" -#include "ck_tile/host.hpp" -#include "ck_tile/ops/gemm.hpp" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -using namespace transformer_engine; -using namespace test; - -using fp8 = fp8e4m3; -using bf16_t = bf16; -using e8m0_t_te = fp8e8m0; - -namespace { - -struct CaseConfig { - size_t m_total; - size_t n; - size_t k; - int experts; - float scale; - int seed; - int ck_ref_groups; -}; - -static std::string case_name(const testing::TestParamInfo& info) { - const auto& c = info.param; - std::ostringstream os; - os << "M" << c.m_total << "_N" << c.n << "_K" << c.k - << "_E" << c.experts; - return os.str(); -} - -static void set_env_defaults() { - setenv("NVTE_USE_CUTLASS_GROUPED_GEMM", "1", 1); - setenv("NVTE_CUTLASS_GROUPED_GEMM_WARN_FALLBACK", "1", 0); - setenv("NVTE_ROCM_ENABLE_MXFP8", "1", 0); -} - -static float to_float(float x) { return x; } -static float to_float(const bf16_t& x) { return static_cast(x); } -static float to_float(const ck_tile::bfloat16_t& x) { return static_cast(x); } - -__device__ __host__ __forceinline__ float ref_gelu_unused(float x) { - float cdf = 0.5f * (1.0f + tanhf((0.7978845608028654f * (x + 0.044715f * x * x * x)))); - return x * cdf; -} - -template -__global__ void compute_ref_kernel( - const A_Type* __restrict__ a_data, - const B_Type* __restrict__ b_data, - float a_scale_inv_scalar, - float b_scale_inv_scalar, - const e8m0_t_te* __restrict__ a_scale_inv_mxfp8, - const e8m0_t_te* __restrict__ b_scale_inv_mxfp8, - size_t a_scale_ld, - size_t b_scale_ld, - bool a_scale_is_colwise, - bool b_scale_is_colwise, - const Bias_Type* __restrict__ bias_data, - float d_scale, - size_t m, size_t k, size_t n, - D_Type* __restrict__ d_data, - float* __restrict__ d_amax, - Gelu_Type* __restrict__ gelu_data, - bool transa, - bool transb, - bool is_fp8_output, - bool a_is_colwise, - bool b_is_colwise, - bool use_mxfp8) { - const size_t jj = blockIdx.x * blockDim.x + threadIdx.x; - const size_t ii = blockIdx.y * blockDim.y + threadIdx.y; - const bool in_range = (ii < m) && (jj < n); - - float val = 0.0f; - if (in_range) { - for (size_t kk = 0; kk < k; ++kk) { - size_t a_idx = 0; - size_t b_idx = 0; - - if (use_mxfp8) { - a_idx = transa ? (ii * k + kk) : (kk * m + ii); - b_idx = transb ? (kk * n + jj) : (jj * k + kk); - } else { - a_idx = a_is_colwise ? (ii * k + kk) - : (transa ? (ii * k + kk) : (kk * m + ii)); - b_idx = b_is_colwise ? (jj * k + kk) - : (transb ? (kk * n + jj) : (jj * k + kk)); - } - - float a_scale_inv_val = a_scale_inv_scalar; - float b_scale_inv_val = b_scale_inv_scalar; - - if (a_scale_inv_mxfp8) { - const size_t kc = kk / 32; - const size_t a_scale_idx = - a_scale_is_colwise ? (kc * a_scale_ld + ii) : (ii * a_scale_ld + kc); - const size_t b_scale_idx = - b_scale_is_colwise ? (kc * b_scale_ld + jj) : (jj * b_scale_ld + kc); - a_scale_inv_val = exp2f(a_scale_inv_mxfp8[a_scale_idx] - 127.0f); - b_scale_inv_val = exp2f(b_scale_inv_mxfp8[b_scale_idx] - 127.0f); - } - - const float a_val = static_cast(a_data[a_idx]); - const float b_val = static_cast(b_data[b_idx]); - val += a_scale_inv_val * a_val * b_scale_inv_val * b_val; - } - - if (bias_data) val += static_cast(bias_data[ii]); - if (gelu_data) { - gelu_data[ii + jj * m] = static_cast(val); - val = ref_gelu_unused(val); - } - - const float scaled = val * d_scale; - d_data[ii + jj * m] = static_cast(scaled); - } - - if (is_fp8_output && d_amax) { - const int tid = threadIdx.y * blockDim.x + threadIdx.x; - const int nthreads = blockDim.x * blockDim.y; - extern __shared__ float s_amax[]; - s_amax[tid] = in_range ? fabsf(val) : 0.0f; - __syncthreads(); - for (int offset = nthreads / 2; offset > 0; offset /= 2) { - if (tid < offset) s_amax[tid] = fmaxf(s_amax[tid], s_amax[tid + offset]); - __syncthreads(); - } - if (tid == 0) atomicMax(d_amax, s_amax[0]); - } -} - -template -static void fill_randn_cpu(Tensor* t, float scale, int seed) { - std::mt19937 gen(seed); - std::normal_distribution dist(0.0f, scale); - const size_t n = product(t->rowwise_shape()); - T* ptr = t->rowwise_cpu_dptr(); - for (size_t i = 0; i < n; ++i) ptr[i] = static_cast(dist(gen)); - t->from_cpu(); -} - -static std::vector split_even(size_t m_total, int experts) { - NVTE_CHECK(experts > 0, "experts must be > 0"); - NVTE_CHECK(m_total % static_cast(experts) == 0, - "m_total must be divisible by experts"); - return std::vector(experts, m_total / static_cast(experts)); -} - -struct ErrorStats { - size_t count = 0; - double sum_abs = 0.0; - double sum_rel = 0.0; - double sum_ref_abs = 0.0; - double sum_got_abs = 0.0; - float max_abs = 0.0f; - float max_rel = 0.0f; - std::vector abs_errs; -}; - -static void add_err(ErrorStats& s, float got, float ref) { - const float abs_err = std::abs(got - ref); - const float rel_err = abs_err / std::max(std::abs(ref), 1.0e-12f); - s.count++; - s.sum_abs += abs_err; - s.sum_rel += rel_err; - s.sum_ref_abs += std::abs(ref); - s.sum_got_abs += std::abs(got); - s.max_abs = std::max(s.max_abs, abs_err); - s.max_rel = std::max(s.max_rel, rel_err); - s.abs_errs.push_back(abs_err); -} - -static float quantile(std::vector& values, double q) { - if (values.empty()) return 0.0f; - const size_t pos = std::min(static_cast(q * (values.size() - 1)), values.size() - 1); - std::nth_element(values.begin(), values.begin() + pos, values.end()); - return values[pos]; -} - -static void print_stats(const std::string& label, ErrorStats s) { - std::vector v50 = s.abs_errs; - std::vector v90 = s.abs_errs; - std::vector v99 = s.abs_errs; - const double denom = static_cast(std::max(s.count, 1)); - std::cout << std::fixed << std::setprecision(6) - << label - << " count=" << s.count - << " max_abs=" << s.max_abs - << " mean_abs=" << (s.sum_abs / denom) - << " p50_abs=" << quantile(v50, 0.50) - << " p90_abs=" << quantile(v90, 0.90) - << " p99_abs=" << quantile(v99, 0.99) - << " max_rel=" << s.max_rel - << " mean_rel=" << (s.sum_rel / denom) - << " ref_abs_mean=" << (s.sum_ref_abs / denom) - << " got_abs_mean=" << (s.sum_got_abs / denom) - << std::endl; -} - -static void expect_reference_match(const std::string& label, - const ErrorStats& stats, - float max_abs_limit, - float mean_abs_limit) { - print_stats(label, stats); - EXPECT_LE(stats.max_abs, max_abs_limit) << label; - EXPECT_LE(stats.sum_abs / static_cast(std::max(stats.count, 1)), - static_cast(mean_abs_limit)) << label; -} - -static void run_te_grouped_mxfp8_forward(const std::vector& weights_mx, - const std::vector& inputs_mx, - std::vector* outputs, - Tensor* workspace, - int math_sm_count) { - const size_t groups = weights_mx.size(); - std::vector A(groups), B(groups), D(groups), Bias(groups), PreGelu(groups); - std::vector empty_bias(groups), empty_pregelu(groups); - - // Match GroupedLinear forward / te_general_grouped_gemm: - // A = weight [N,K], transa=true - // B = input [M,K], transb=false - // D = output [M,N] - for (size_t i = 0; i < groups; ++i) { - A[i] = const_cast(weights_mx[i]).data(); - B[i] = const_cast(inputs_mx[i]).data(); - D[i] = (*outputs)[i].data(); - Bias[i] = empty_bias[i].data(); - PreGelu[i] = empty_pregelu[i].data(); - } - - std::vector Workspaces(1); - Workspaces[0] = workspace->data(); - - nvte_multi_tensor_gemm(A.data(), - B.data(), - D.data(), - Bias.data(), - PreGelu.data(), - groups, - true, // transa: weight [N,K] -> op(A) [K,N] - false, // transb: input [M,K] -> op(B) [M,K] - false, // grad - Workspaces.data(), - false, // accumulate - false, // use_split_accumulator - math_sm_count, - 0); - NVTE_CHECK_CUDA(cudaDeviceSynchronize()); -} - -template -static void run_hip_ref_for_group(const Tensor& input_mx, - const Tensor& weight_mx, - Tensor* ref_d_colmajor, - size_t m, - size_t k, - size_t n) { - // compute_ref_kernel expects A=input [M,K], B=weight [N,K], transa=true, transb=false, - // and writes D as column-major MxN into rowwise storage shaped [N,M]. - const auto a_s = input_mx.rowwise_scale_inv_shape(); - const auto b_s = weight_mx.rowwise_scale_inv_shape(); - NVTE_CHECK(a_s.ndim == 2 && b_s.ndim == 2, "Expected 2D MXFP8 scale_inv tensors"); - const size_t a_scale_ld = a_s.data[1]; - const size_t b_scale_ld = b_s.data[1]; - - dim3 block(16, 16); - dim3 grid(static_cast((n + block.x - 1) / block.x), - static_cast((m + block.y - 1) / block.y)); - const size_t shmem_bytes = size_t(block.x) * size_t(block.y) * sizeof(float); - - compute_ref_kernel - <<>>( - static_cast(input_mx.rowwise_dptr()), - static_cast(weight_mx.rowwise_dptr()), - 1.0f, - 1.0f, - static_cast(input_mx.rowwise_scale_inv_dptr()), - static_cast(weight_mx.rowwise_scale_inv_dptr()), - a_scale_ld, - b_scale_ld, - false, // input scale rowwise [M,K/32] - false, // weight scale rowwise [N,K/32] - nullptr, - 1.0f, - m, k, n, - static_cast(ref_d_colmajor->rowwise_dptr()), - nullptr, - nullptr, - true, // transa for A=input in this reference-kernel convention - false, // transb for B=weight - false, - false, - false, - true); - NVTE_CHECK_CUDA(cudaGetLastError()); - NVTE_CHECK_CUDA(cudaDeviceSynchronize()); -} - -static ck_tile::HostTensor run_ck_tile_reference_for_group( - const Tensor& input_mx, - const Tensor& weight_mx, - size_t m, - size_t k, - size_t n) { - using namespace ck_tile::literals; - using AType = ck_tile::fp8_t; - using BType = ck_tile::fp8_t; - using CType = ck_tile::bfloat16_t; - using ScaleType = ck_tile::e8m0_t; - - const size_t kscale = k / 32; - - ck_tile::HostTensor a_host( - ck_tile::HostTensorDescriptor({m, k}, {k, 1_uz})); - ck_tile::HostTensor b_host( - ck_tile::HostTensorDescriptor({k, n}, {1_uz, k})); - ck_tile::HostTensor c_ref( - ck_tile::HostTensorDescriptor({m, n}, {n, 1_uz})); - ck_tile::HostTensor a_scale_ref( - ck_tile::HostTensorDescriptor({m, kscale}, {kscale, 1_uz})); - ck_tile::HostTensor b_scale_ref( - ck_tile::HostTensorDescriptor({kscale, n}, {1_uz, kscale})); - - c_ref.SetZero(); - - NVTE_CHECK_CUDA(cudaMemcpy(a_host.data(), - input_mx.rowwise_dptr(), - a_host.get_element_space_size_in_bytes(), - cudaMemcpyDeviceToHost)); - NVTE_CHECK_CUDA(cudaMemcpy(b_host.data(), - weight_mx.rowwise_dptr(), - b_host.get_element_space_size_in_bytes(), - cudaMemcpyDeviceToHost)); - NVTE_CHECK_CUDA(cudaMemcpy(a_scale_ref.data(), - input_mx.rowwise_scale_inv_dptr(), - a_scale_ref.get_element_space_size_in_bytes(), - cudaMemcpyDeviceToHost)); - NVTE_CHECK_CUDA(cudaMemcpy(b_scale_ref.data(), - weight_mx.rowwise_scale_inv_dptr(), - b_scale_ref.get_element_space_size_in_bytes(), - cudaMemcpyDeviceToHost)); - - ck_tile::reference_mx_gemm( - a_host, b_host, c_ref, a_scale_ref, b_scale_ref); - return c_ref; -} - -static ErrorStats compare_te_vs_hip(const Tensor& te_out_rowmajor, - const Tensor& hip_ref_colmajor, - size_t m, - size_t n) { - ErrorStats stats; - const bf16_t* te = te_out_rowmajor.rowwise_cpu_dptr(); - const bf16_t* hip = hip_ref_colmajor.rowwise_cpu_dptr(); - for (size_t i = 0; i < m; ++i) { - for (size_t j = 0; j < n; ++j) { - add_err(stats, to_float(te[i * n + j]), to_float(hip[j * m + i])); - } - } - return stats; -} - -static ErrorStats compare_te_vs_ck(const Tensor& te_out_rowmajor, - const ck_tile::HostTensor& ck_ref, - size_t m, - size_t n) { - ErrorStats stats; - const bf16_t* te = te_out_rowmajor.rowwise_cpu_dptr(); - for (size_t i = 0; i < m; ++i) { - for (size_t j = 0; j < n; ++j) { - add_err(stats, to_float(te[i * n + j]), to_float(ck_ref(i, j))); - } - } - return stats; -} - -static ErrorStats compare_ck_vs_hip(const ck_tile::HostTensor& ck_ref, - const Tensor& hip_ref_colmajor, - size_t m, - size_t n) { - ErrorStats stats; - const bf16_t* hip = hip_ref_colmajor.rowwise_cpu_dptr(); - for (size_t i = 0; i < m; ++i) { - for (size_t j = 0; j < n; ++j) { - add_err(stats, to_float(ck_ref(i, j)), to_float(hip[j * m + i])); - } - } - return stats; -} - -static void run_case(const CaseConfig& cfg) { - set_env_defaults(); - - ASSERT_EQ(cfg.k % 128, 0UL) << "K must be a multiple of 128 for MXFP8"; - ASSERT_EQ(cfg.m_total % static_cast(cfg.experts), 0UL); - - cudaDeviceProp prop; - NVTE_CHECK_CUDA(cudaGetDeviceProperties(&prop, 0)); -#ifdef __HIP_PLATFORM_AMD__ - const bool is_gfx950_or_newer_cdna = (prop.major == 9 && prop.minor >= 5); - const bool is_gfx1250 = (prop.major == 12 && prop.minor == 5); - - if (!is_gfx950_or_newer_cdna && !is_gfx1250) { - GTEST_SKIP() << "MXFP8 requires gfx950+ or gfx1250 in this test. GPU=" << prop.name - << " major=" << prop.major << " minor=" << prop.minor; - } -#endif - - const auto m_splits = split_even(cfg.m_total, cfg.experts); - const size_t per_m = m_splits[0]; - const int groups_to_ck = std::min(cfg.ck_ref_groups, cfg.experts); - - std::cout << "\n=== TE CK grouped MXFP8 forward reference comparison ===\n" - << "M_total=" << cfg.m_total << " N=" << cfg.n << " K=" << cfg.k - << " experts=" << cfg.experts << " per_expert_M=" << per_m - << " scale=" << cfg.scale << " seed=" << cfg.seed << "\n" - << "NVTE_USE_CUTLASS_GROUPED_GEMM=" << std::getenv("NVTE_USE_CUTLASS_GROUPED_GEMM") << "\n" - << "NVTE_ROCM_ENABLE_MXFP8=" << std::getenv("NVTE_ROCM_ENABLE_MXFP8") << "\n" - << "CK_TILE_USE_OCP_FP8=" << CK_TILE_USE_OCP_FP8 << "\n" - << "GPU=" << prop.name << " SM/CU count=" << prop.multiProcessorCount << "\n"; - - std::vector input_src; - std::vector weight_src; - std::vector input_mx; - std::vector weight_mx; - std::vector output_te; - std::vector output_hip_colmajor; - input_src.reserve(cfg.experts); - weight_src.reserve(cfg.experts); - input_mx.reserve(cfg.experts); - weight_mx.reserve(cfg.experts); - output_te.reserve(cfg.experts); - output_hip_colmajor.reserve(cfg.experts); - - for (int g = 0; g < cfg.experts; ++g) { - const size_t m = m_splits[g]; - input_src.emplace_back("input_src", std::vector{m, cfg.k}, DType::kBFloat16); - weight_src.emplace_back("weight_src", std::vector{cfg.n, cfg.k}, DType::kBFloat16); - - fill_randn_cpu(&input_src.back(), cfg.scale, cfg.seed + 1009 * g + 17); - fill_randn_cpu(&weight_src.back(), cfg.scale, cfg.seed + 1009 * g + 29); - - input_mx.emplace_back("input_mx", std::vector{m, cfg.k}, DType::kFloat8E4M3, - true, false, NVTEScalingMode::NVTE_MXFP8_1D_SCALING); - weight_mx.emplace_back("weight_mx", std::vector{cfg.n, cfg.k}, DType::kFloat8E4M3, - true, false, NVTEScalingMode::NVTE_MXFP8_1D_SCALING); - - nvte_quantize(input_src.back().data(), input_mx.back().data(), 0); - nvte_quantize(weight_src.back().data(), weight_mx.back().data(), 0); - - output_te.emplace_back("output_te", std::vector{m, cfg.n}, DType::kBFloat16); - output_hip_colmajor.emplace_back("output_hip_colmajor", std::vector{cfg.n, m}, DType::kBFloat16); - } - NVTE_CHECK_CUDA(cudaDeviceSynchronize()); - - Tensor workspace("workspace", std::vector{67108864}, DType::kByte); - - run_te_grouped_mxfp8_forward(weight_mx, input_mx, &output_te, &workspace, - prop.multiProcessorCount); - for (auto& out : output_te) out.to_cpu(); - - for (int g = 0; g < cfg.experts; ++g) { - run_hip_ref_for_group(input_mx[g], weight_mx[g], &output_hip_colmajor[g], - m_splits[g], cfg.k, cfg.n); - output_hip_colmajor[g].to_cpu(); - expect_reference_match("group " + std::to_string(g) + " TE_vs_HIP_REF", - compare_te_vs_hip(output_te[g], output_hip_colmajor[g], - m_splits[g], cfg.n), - 0.25f, - 0.03f); - } - - for (int g = 0; g < groups_to_ck; ++g) { - auto ck_ref = run_ck_tile_reference_for_group(input_mx[g], weight_mx[g], - m_splits[g], cfg.k, cfg.n); - expect_reference_match("group " + std::to_string(g) + " TE_vs_CK_REF ", - compare_te_vs_ck(output_te[g], ck_ref, m_splits[g], cfg.n), - 0.25f, - 0.03f); - expect_reference_match("group " + std::to_string(g) + " CK_vs_HIP_REF", - compare_ck_vs_hip(ck_ref, output_hip_colmajor[g], - m_splits[g], cfg.n), - 0.25f, - 0.03f); - } -} - -} // namespace - -class GroupedMXFP8ForwardRefsTestSuite : public ::testing::TestWithParam {}; - -TEST_P(GroupedMXFP8ForwardRefsTestSuite, MatchesCKTileAndHIPReferences) { - run_case(GetParam()); -} - -INSTANTIATE_TEST_SUITE_P( - OperatorTest, - GroupedMXFP8ForwardRefsTestSuite, - ::testing::Values( - // Small enough for quick CI-style sanity. - CaseConfig{1024, 1024, 1024, 2, 0.25f, 1234, 1}, - // Reproduces the earlier forward-only "failure" scale/shape regime, but - // validates against true MXFP8 references instead of BF16. - CaseConfig{1536, 4096, 4096, 3, 0.25f, 1234, 1}, - // Llama-ish suspicious path. CK reference only group 0 to keep runtime sane; - // HIP reference checks all groups. - CaseConfig{4096, 12288, 4096, 4, 0.25f, 1234, 1}), - case_name); From e7159c495b53b733a5a42a0892dd739ec5bfc89b Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Sun, 3 May 2026 16:00:35 +0000 Subject: [PATCH 10/47] include renamed test file --- .../cpp/operator/test_te_ck_grouped_mxfp8.cu | 629 ++++++++++++++++++ 1 file changed, 629 insertions(+) create mode 100644 tests/cpp/operator/test_te_ck_grouped_mxfp8.cu diff --git a/tests/cpp/operator/test_te_ck_grouped_mxfp8.cu b/tests/cpp/operator/test_te_ck_grouped_mxfp8.cu new file mode 100644 index 000000000..1ad32557c --- /dev/null +++ b/tests/cpp/operator/test_te_ck_grouped_mxfp8.cu @@ -0,0 +1,629 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +// TE CK grouped MXFP8 validation. +// +// Compares three paths for grouped MXFP8 GEMM across NN/NT/TN transpose layouts: +// 1. TE nvte_multi_tensor_gemm grouped path (CK backend selected by env) +// 2. ck_tile::reference_mx_gemm host reference, using exact quantized operands/scales +// 3. TE HIP reference kernel adapted from test_cublaslt_gemm.cu compute_ref_kernel +// +// Intended drop-in location: +// TransformerEngine/tests/cpp/operator/test_te_ck_grouped_mxfp8.cu + +#ifndef CK_TILE_USE_OCP_FP8 +#define CK_TILE_USE_OCP_FP8 1 +#endif + +#include +#include +#include + +#include +#include +#include + +#include "../test_common.h" + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace transformer_engine; +using namespace test; + +using fp8 = fp8e4m3; +using bf8 = fp8e5m2; +using bf16_t = bf16; +using e8m0_t_te = fp8e8m0; + +namespace { + +enum class MXOperandDType { + FP8, + BF8, +}; + +struct DTypeConfig { + const char* name; + MXOperandDType a; + MXOperandDType b; +}; + +static DType te_dtype(MXOperandDType t) { + return t == MXOperandDType::FP8 ? DType::kFloat8E4M3 : DType::kFloat8E5M2; +} + +struct LayoutConfig { + const char* name; + bool transa; + bool transb; +}; + +struct CaseConfig { + size_t m_total; + size_t n; + size_t k; + int experts; + float scale; + int seed; + LayoutConfig layout; + DTypeConfig dtype; +}; + +static std::string case_name(const testing::TestParamInfo& info) { + const auto& c = info.param; + std::ostringstream os; + os << "M" << c.m_total << "_N" << c.n << "_K" << c.k + << "_E" << c.experts << "_" << c.layout.name << "_" << c.dtype.name; + return os.str(); +} + +static void set_env_defaults() { + setenv("NVTE_USE_CUTLASS_GROUPED_GEMM", "1", 1); + setenv("NVTE_CUTLASS_GROUPED_GEMM_WARN_FALLBACK", "1", 0); + setenv("NVTE_ROCM_ENABLE_MXFP8", "1", 0); +} + +static float to_float(float x) { return x; } +static float to_float(const bf16_t& x) { return static_cast(x); } +static float to_float(const ck_tile::bfloat16_t& x) { return static_cast(x); } + +__device__ __host__ __forceinline__ float ref_gelu_unused(float x) { + float cdf = 0.5f * (1.0f + tanhf((0.7978845608028654f * (x + 0.044715f * x * x * x)))); + return x * cdf; +} + +template +__global__ void compute_ref_kernel( + const A_Type* __restrict__ a_data, + const B_Type* __restrict__ b_data, + float a_scale_inv_scalar, + float b_scale_inv_scalar, + const e8m0_t_te* __restrict__ a_scale_inv_mxfp8, + const e8m0_t_te* __restrict__ b_scale_inv_mxfp8, + size_t a_scale_ld, + size_t b_scale_ld, + bool a_scale_is_colwise, + bool b_scale_is_colwise, + const Bias_Type* __restrict__ bias_data, + float d_scale, + size_t m, size_t k, size_t n, + D_Type* __restrict__ d_data, + float* __restrict__ d_amax, + Gelu_Type* __restrict__ gelu_data, + bool transa, + bool transb, + bool is_fp8_output, + bool a_is_colwise, + bool b_is_colwise, + bool use_mxfp8) { + const size_t jj = blockIdx.x * blockDim.x + threadIdx.x; + const size_t ii = blockIdx.y * blockDim.y + threadIdx.y; + const bool in_range = (ii < m) && (jj < n); + + float val = 0.0f; + if (in_range) { + for (size_t kk = 0; kk < k; ++kk) { + size_t a_idx = 0; + size_t b_idx = 0; + + if (use_mxfp8) { + a_idx = transa ? (ii * k + kk) : (kk * m + ii); + b_idx = transb ? (kk * n + jj) : (jj * k + kk); + } else { + a_idx = a_is_colwise ? (ii * k + kk) + : (transa ? (ii * k + kk) : (kk * m + ii)); + b_idx = b_is_colwise ? (jj * k + kk) + : (transb ? (kk * n + jj) : (jj * k + kk)); + } + + float a_scale_inv_val = a_scale_inv_scalar; + float b_scale_inv_val = b_scale_inv_scalar; + + if (a_scale_inv_mxfp8) { + const size_t kc = kk / 32; + const size_t a_scale_idx = + a_scale_is_colwise ? (kc * a_scale_ld + ii) : (ii * a_scale_ld + kc); + const size_t b_scale_idx = + b_scale_is_colwise ? (kc * b_scale_ld + jj) : (jj * b_scale_ld + kc); + a_scale_inv_val = exp2f(a_scale_inv_mxfp8[a_scale_idx] - 127.0f); + b_scale_inv_val = exp2f(b_scale_inv_mxfp8[b_scale_idx] - 127.0f); + } + + const float a_val = static_cast(a_data[a_idx]); + const float b_val = static_cast(b_data[b_idx]); + val += a_scale_inv_val * a_val * b_scale_inv_val * b_val; + } + + if (bias_data) val += static_cast(bias_data[ii]); + if (gelu_data) { + gelu_data[ii + jj * m] = static_cast(val); + val = ref_gelu_unused(val); + } + + const float scaled = val * d_scale; + d_data[ii + jj * m] = static_cast(scaled); + } + + if (is_fp8_output && d_amax) { + const int tid = threadIdx.y * blockDim.x + threadIdx.x; + const int nthreads = blockDim.x * blockDim.y; + extern __shared__ float s_amax[]; + s_amax[tid] = in_range ? fabsf(val) : 0.0f; + __syncthreads(); + for (int offset = nthreads / 2; offset > 0; offset /= 2) { + if (tid < offset) s_amax[tid] = fmaxf(s_amax[tid], s_amax[tid + offset]); + __syncthreads(); + } + if (tid == 0) atomicMax(d_amax, s_amax[0]); + } +} + +template +static void fill_randn_cpu(Tensor* t, float scale, int seed) { + std::mt19937 gen(seed); + std::normal_distribution dist(0.0f, scale); + const size_t n = product(t->rowwise_shape()); + T* ptr = t->rowwise_cpu_dptr(); + for (size_t i = 0; i < n; ++i) ptr[i] = static_cast(dist(gen)); + t->from_cpu(); +} + +static std::vector split_even(size_t m_total, int experts) { + NVTE_CHECK(experts > 0, "experts must be > 0"); + NVTE_CHECK(m_total % static_cast(experts) == 0, + "m_total must be divisible by experts"); + return std::vector(experts, m_total / static_cast(experts)); +} + +static std::vector a_shape_for_te(size_t n, size_t k, bool transa) { + // TE grouped GEMM computes output shape [M,N]. A contributes the N dimension. + // transa=true means physical A is [N,K]; transa=false means physical A is [K,N]. + return transa ? std::vector{n, k} : std::vector{k, n}; +} + +static std::vector b_shape_for_te(size_t m, size_t k, bool transb) { + // B contributes the M dimension. + // transb=false means physical B is [M,K]; transb=true means physical B is [K,M]. + return transb ? std::vector{k, m} : std::vector{m, k}; +} + +struct ErrorStats { + size_t count = 0; + double sum_abs = 0.0; + double sum_rel = 0.0; + double sum_ref_abs = 0.0; + double sum_got_abs = 0.0; + float max_abs = 0.0f; + float max_rel = 0.0f; + std::vector abs_errs; +}; + +static void add_err(ErrorStats& s, float got, float ref) { + const float abs_err = std::abs(got - ref); + const float rel_err = abs_err / std::max(std::abs(ref), 1.0e-12f); + s.count++; + s.sum_abs += abs_err; + s.sum_rel += rel_err; + s.sum_ref_abs += std::abs(ref); + s.sum_got_abs += std::abs(got); + s.max_abs = std::max(s.max_abs, abs_err); + s.max_rel = std::max(s.max_rel, rel_err); + s.abs_errs.push_back(abs_err); +} + + +static void expect_reference_match(const std::string& label, + const ErrorStats& stats, + float max_abs_limit, + float mean_abs_limit) { + EXPECT_LE(stats.max_abs, max_abs_limit) << label; + EXPECT_LE(stats.sum_abs / static_cast(std::max(stats.count, 1)), + static_cast(mean_abs_limit)) << label; +} + +static void run_te_grouped_mxfp8(const std::vector& a_mx, + const std::vector& b_mx, + std::vector* outputs, + Tensor* workspace, + bool transa, + bool transb, + int math_sm_count) { + const size_t groups = a_mx.size(); + std::vector A(groups), B(groups), D(groups), Bias(groups), PreGelu(groups); + std::vector empty_bias(groups), empty_pregelu(groups); + + for (size_t i = 0; i < groups; ++i) { + A[i] = const_cast(a_mx[i]).data(); + B[i] = const_cast(b_mx[i]).data(); + D[i] = (*outputs)[i].data(); + Bias[i] = empty_bias[i].data(); + PreGelu[i] = empty_pregelu[i].data(); + } + + std::vector Workspaces(1); + Workspaces[0] = workspace->data(); + + nvte_multi_tensor_gemm(A.data(), + B.data(), + D.data(), + Bias.data(), + PreGelu.data(), + groups, + transa, + transb, + false, // grad + Workspaces.data(), + false, // accumulate + false, // use_split_accumulator + math_sm_count, + 0); + NVTE_CHECK_CUDA(cudaDeviceSynchronize()); +} + +template +static void run_hip_ref_for_group(const Tensor& a_mx, + const Tensor& b_mx, + Tensor* ref_d_colmajor, + size_t m, + size_t k, + size_t n, + bool transa, + bool transb) { + // TE grouped GEMM output is op(B) [M,K] * op(A) [K,N] -> [M,N]. + // compute_ref_kernel convention is A_left [M,K] * B_right [K,N]. + // Therefore left operand is TE B and right operand is TE A. + const bool left_transa = !transb; + const bool right_transb = !transa; + + const bool left_use_colwise = !left_transa; // Same rule as test_cublaslt_gemm run_reference. + const bool right_use_colwise = right_transb; // Same rule as test_cublaslt_gemm run_reference. + + const auto left_s = left_use_colwise ? b_mx.columnwise_scale_inv_shape() + : b_mx.rowwise_scale_inv_shape(); + const auto right_s = right_use_colwise ? a_mx.columnwise_scale_inv_shape() + : a_mx.rowwise_scale_inv_shape(); + NVTE_CHECK(left_s.ndim == 2 && right_s.ndim == 2, "Expected 2D MXFP8 scale_inv tensors"); + const size_t left_scale_ld = left_s.data[1]; + const size_t right_scale_ld = right_s.data[1]; + + dim3 block(16, 16); + dim3 grid(static_cast((n + block.x - 1) / block.x), + static_cast((m + block.y - 1) / block.y)); + const size_t shmem_bytes = size_t(block.x) * size_t(block.y) * sizeof(float); + + compute_ref_kernel + <<>>( + static_cast(left_use_colwise ? b_mx.columnwise_dptr() : b_mx.rowwise_dptr()), + static_cast(right_use_colwise ? a_mx.columnwise_dptr() : a_mx.rowwise_dptr()), + 1.0f, + 1.0f, + static_cast(left_use_colwise ? b_mx.columnwise_scale_inv_dptr() + : b_mx.rowwise_scale_inv_dptr()), + static_cast(right_use_colwise ? a_mx.columnwise_scale_inv_dptr() + : a_mx.rowwise_scale_inv_dptr()), + left_scale_ld, + right_scale_ld, + left_use_colwise, + right_use_colwise, + nullptr, + 1.0f, + m, k, n, + static_cast(ref_d_colmajor->rowwise_dptr()), + nullptr, + nullptr, + left_transa, + right_transb, + false, + left_use_colwise, + right_use_colwise, + true); + NVTE_CHECK_CUDA(cudaGetLastError()); + NVTE_CHECK_CUDA(cudaDeviceSynchronize()); +} + +template +static ck_tile::HostTensor run_ck_tile_reference_for_group( + const Tensor& a_mx, + const Tensor& b_mx, + size_t m, + size_t k, + size_t n, + bool transa, + bool transb) { + using namespace ck_tile::literals; + using AType = CkAType; + using BType = CkBType; + using CType = ck_tile::bfloat16_t; + using ScaleType = ck_tile::e8m0_t; + + const size_t kscale = k / 32; + + const bool left_transa = !transb; + const bool right_transb = !transa; + const bool left_use_colwise = !left_transa; + const bool right_use_colwise = right_transb; + + ck_tile::HostTensor a_left( + left_transa ? ck_tile::HostTensorDescriptor({m, k}, {k, 1_uz}) + : ck_tile::HostTensorDescriptor({m, k}, {1_uz, m})); + ck_tile::HostTensor b_right( + right_transb ? ck_tile::HostTensorDescriptor({k, n}, {n, 1_uz}) + : ck_tile::HostTensorDescriptor({k, n}, {1_uz, k})); + ck_tile::HostTensor c_ref( + ck_tile::HostTensorDescriptor({m, n}, {n, 1_uz})); + + ck_tile::HostTensor a_scale_ref( + left_use_colwise ? ck_tile::HostTensorDescriptor({m, kscale}, {1_uz, m}) + : ck_tile::HostTensorDescriptor({m, kscale}, {kscale, 1_uz})); + ck_tile::HostTensor b_scale_ref( + right_use_colwise ? ck_tile::HostTensorDescriptor({kscale, n}, {n, 1_uz}) + : ck_tile::HostTensorDescriptor({kscale, n}, {1_uz, kscale})); + + c_ref.SetZero(); + + NVTE_CHECK_CUDA(cudaMemcpy(a_left.data(), + left_use_colwise ? b_mx.columnwise_dptr() : b_mx.rowwise_dptr(), + a_left.get_element_space_size_in_bytes(), + cudaMemcpyDeviceToHost)); + NVTE_CHECK_CUDA(cudaMemcpy(b_right.data(), + right_use_colwise ? a_mx.columnwise_dptr() : a_mx.rowwise_dptr(), + b_right.get_element_space_size_in_bytes(), + cudaMemcpyDeviceToHost)); + NVTE_CHECK_CUDA(cudaMemcpy(a_scale_ref.data(), + left_use_colwise ? b_mx.columnwise_scale_inv_dptr() + : b_mx.rowwise_scale_inv_dptr(), + a_scale_ref.get_element_space_size_in_bytes(), + cudaMemcpyDeviceToHost)); + NVTE_CHECK_CUDA(cudaMemcpy(b_scale_ref.data(), + right_use_colwise ? a_mx.columnwise_scale_inv_dptr() + : a_mx.rowwise_scale_inv_dptr(), + b_scale_ref.get_element_space_size_in_bytes(), + cudaMemcpyDeviceToHost)); + + ck_tile::reference_mx_gemm( + a_left, b_right, c_ref, a_scale_ref, b_scale_ref); + return c_ref; +} + +static ErrorStats compare_te_vs_hip(const Tensor& te_out_rowmajor, + const Tensor& hip_ref_colmajor, + size_t m, + size_t n) { + ErrorStats stats; + const bf16_t* te = te_out_rowmajor.rowwise_cpu_dptr(); + const bf16_t* hip = hip_ref_colmajor.rowwise_cpu_dptr(); + for (size_t i = 0; i < m; ++i) { + for (size_t j = 0; j < n; ++j) { + add_err(stats, to_float(te[i * n + j]), to_float(hip[j * m + i])); + } + } + return stats; +} + +static ErrorStats compare_te_vs_ck(const Tensor& te_out_rowmajor, + const ck_tile::HostTensor& ck_ref, + size_t m, + size_t n) { + ErrorStats stats; + const bf16_t* te = te_out_rowmajor.rowwise_cpu_dptr(); + for (size_t i = 0; i < m; ++i) { + for (size_t j = 0; j < n; ++j) { + add_err(stats, to_float(te[i * n + j]), to_float(ck_ref(i, j))); + } + } + return stats; +} + +static ErrorStats compare_ck_vs_hip(const ck_tile::HostTensor& ck_ref, + const Tensor& hip_ref_colmajor, + size_t m, + size_t n) { + ErrorStats stats; + const bf16_t* hip = hip_ref_colmajor.rowwise_cpu_dptr(); + for (size_t i = 0; i < m; ++i) { + for (size_t j = 0; j < n; ++j) { + add_err(stats, to_float(ck_ref(i, j)), to_float(hip[j * m + i])); + } + } + return stats; +} + +template +static void run_case_typed(const CaseConfig& cfg) { + set_env_defaults(); + + ASSERT_EQ(cfg.k % 128, 0UL) << "K must be a multiple of 128 for MXFP8"; + ASSERT_EQ(cfg.m_total % static_cast(cfg.experts), 0UL); + + cudaDeviceProp prop; + NVTE_CHECK_CUDA(cudaGetDeviceProperties(&prop, 0)); +#ifdef __HIP_PLATFORM_AMD__ + const bool is_gfx950_or_newer_cdna = (prop.major == 9 && prop.minor >= 5); + const bool is_gfx1250 = (prop.major == 12 && prop.minor == 5); + + if (!is_gfx950_or_newer_cdna && !is_gfx1250) { + GTEST_SKIP() << "MXFP8 requires gfx950+ or gfx1250 in this test. GPU=" << prop.name + << " major=" << prop.major << " minor=" << prop.minor; + } +#endif + + const auto m_splits = split_even(cfg.m_total, cfg.experts); + + std::vector a_src; + std::vector b_src; + std::vector a_mx; + std::vector b_mx; + std::vector output_te; + std::vector output_hip_colmajor; + a_src.reserve(cfg.experts); + b_src.reserve(cfg.experts); + a_mx.reserve(cfg.experts); + b_mx.reserve(cfg.experts); + output_te.reserve(cfg.experts); + output_hip_colmajor.reserve(cfg.experts); + + for (int g = 0; g < cfg.experts; ++g) { + const size_t m = m_splits[g]; + const auto a_shape = a_shape_for_te(cfg.n, cfg.k, cfg.layout.transa); + const auto b_shape = b_shape_for_te(m, cfg.k, cfg.layout.transb); + + a_src.emplace_back("a_src", a_shape, DType::kBFloat16); + b_src.emplace_back("b_src", b_shape, DType::kBFloat16); + + fill_randn_cpu(&a_src.back(), cfg.scale, cfg.seed + 1009 * g + 17); + fill_randn_cpu(&b_src.back(), cfg.scale, cfg.seed + 1009 * g + 29); + + // Allocate both rowwise and columnwise MX views so the backend can canonicalize NN/NT/TN. + a_mx.emplace_back("a_mx", a_shape, te_dtype(cfg.dtype.a), + true, true, NVTEScalingMode::NVTE_MXFP8_1D_SCALING); + b_mx.emplace_back("b_mx", b_shape, te_dtype(cfg.dtype.b), + true, true, NVTEScalingMode::NVTE_MXFP8_1D_SCALING); + + nvte_quantize(a_src.back().data(), a_mx.back().data(), 0); + nvte_quantize(b_src.back().data(), b_mx.back().data(), 0); + + output_te.emplace_back("output_te", std::vector{m, cfg.n}, DType::kBFloat16); + output_hip_colmajor.emplace_back("output_hip_colmajor", std::vector{cfg.n, m}, DType::kBFloat16); + } + NVTE_CHECK_CUDA(cudaDeviceSynchronize()); + + Tensor workspace("workspace", std::vector{67108864}, DType::kByte); + + run_te_grouped_mxfp8(a_mx, b_mx, &output_te, &workspace, + cfg.layout.transa, cfg.layout.transb, + prop.multiProcessorCount); + for (auto& out : output_te) out.to_cpu(); + + for (int g = 0; g < cfg.experts; ++g) { + run_hip_ref_for_group(a_mx[g], b_mx[g], &output_hip_colmajor[g], + m_splits[g], cfg.k, cfg.n, + cfg.layout.transa, cfg.layout.transb); + output_hip_colmajor[g].to_cpu(); + expect_reference_match("group " + std::to_string(g) + " TE_vs_HIP_REF", + compare_te_vs_hip(output_te[g], output_hip_colmajor[g], + m_splits[g], cfg.n), + 0.25f, + 0.03f); + } + + for (int g = 0; g < cfg.experts; ++g) { + auto ck_ref = run_ck_tile_reference_for_group(a_mx[g], b_mx[g], + m_splits[g], cfg.k, cfg.n, + cfg.layout.transa, cfg.layout.transb); + expect_reference_match("group " + std::to_string(g) + " TE_vs_CK_REF ", + compare_te_vs_ck(output_te[g], ck_ref, m_splits[g], cfg.n), + 0.25f, + 0.03f); + expect_reference_match("group " + std::to_string(g) + " CK_vs_HIP_REF", + compare_ck_vs_hip(ck_ref, output_hip_colmajor[g], + m_splits[g], cfg.n), + 0.25f, + 0.03f); + } +} + +static void run_case(const CaseConfig& cfg) { + if (cfg.dtype.a == MXOperandDType::FP8 && cfg.dtype.b == MXOperandDType::FP8) { + run_case_typed(cfg); + } else if (cfg.dtype.a == MXOperandDType::FP8 && cfg.dtype.b == MXOperandDType::BF8) { + run_case_typed(cfg); + } else if (cfg.dtype.a == MXOperandDType::BF8 && cfg.dtype.b == MXOperandDType::FP8) { + run_case_typed(cfg); + } else { + run_case_typed(cfg); + } +} + +} // namespace + +class GroupedMXFP8TestSuite : public ::testing::TestWithParam {}; + +TEST_P(GroupedMXFP8TestSuite, MatchesCKTileAndHIPReferences) { + run_case(GetParam()); +} + +static constexpr LayoutConfig kNN{"NN", false, false}; +static constexpr LayoutConfig kNT{"NT", false, true}; +static constexpr LayoutConfig kTN{"TN", true, false}; + +static constexpr DTypeConfig kFP8FP8{"FP8xFP8", MXOperandDType::FP8, MXOperandDType::FP8}; +static constexpr DTypeConfig kFP8BF8{"FP8xBF8", MXOperandDType::FP8, MXOperandDType::BF8}; +static constexpr DTypeConfig kBF8FP8{"BF8xFP8", MXOperandDType::BF8, MXOperandDType::FP8}; +static constexpr DTypeConfig kBF8BF8{"BF8xBF8", MXOperandDType::BF8, MXOperandDType::BF8}; + +static std::vector make_cases() { + const std::vector dtypes = {kFP8FP8, kFP8BF8, kBF8FP8, kBF8BF8}; + const std::vector base_cases = { + // Small sanity across NN/NT/TN. + CaseConfig{1024, 1024, 1024, 2, 0.25f, 1234, kNN, kFP8FP8}, + CaseConfig{1024, 1024, 1024, 2, 0.25f, 1234, kNT, kFP8FP8}, + CaseConfig{1024, 1024, 1024, 2, 0.25f, 1234, kTN, kFP8FP8}, + // Earlier failure regime across NN/NT/TN. + CaseConfig{1536, 4096, 4096, 3, 0.25f, 1234, kNN, kFP8FP8}, + CaseConfig{1536, 4096, 4096, 3, 0.25f, 1234, kNT, kFP8FP8}, + CaseConfig{1536, 4096, 4096, 3, 0.25f, 1234, kTN, kFP8FP8}, + // Llama-ish suspicious path across NN/NT/TN. + CaseConfig{4096, 12288, 4096, 4, 0.25f, 1234, kNN, kFP8FP8}, + CaseConfig{4096, 12288, 4096, 4, 0.25f, 1234, kNT, kFP8FP8}, + CaseConfig{4096, 12288, 4096, 4, 0.25f, 1234, kTN, kFP8FP8}, + }; + + std::vector cases; + cases.reserve(base_cases.size() * dtypes.size()); + for (const auto& base : base_cases) { + for (const auto& dtype : dtypes) { + CaseConfig c = base; + c.dtype = dtype; + cases.push_back(c); + } + } + return cases; +} + +static const std::vector kCases = make_cases(); + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + GroupedMXFP8TestSuite, + ::testing::ValuesIn(kCases), + case_name); From 972cea3035c8c45d68a8e86cd512a8bd8f7badc1 Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Sun, 3 May 2026 17:22:45 +0000 Subject: [PATCH 11/47] clean up code --- tests/cpp/operator/CMakeLists.txt | 2 +- ...uped_mxfp8.cu => test_ck_grouped_mxfp8.cu} | 10 ++--- .../ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp | 43 ++++++------------- 3 files changed, 17 insertions(+), 38 deletions(-) rename tests/cpp/operator/{test_te_ck_grouped_mxfp8.cu => test_ck_grouped_mxfp8.cu} (98%) diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index c81ab1e62..4f87a9091 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -16,7 +16,7 @@ list(APPEND test_cuda_sources test_dequantize_mxfp8.cu test_dequantize_nvfp4.cu test_cast_nvfp4_transpose.cu - test_te_ck_grouped_mxfp8.cu + test_ck_grouped_mxfp8.cu test_transpose.cu test_cast_transpose.cu test_cast_transpose_current_scaling.cu diff --git a/tests/cpp/operator/test_te_ck_grouped_mxfp8.cu b/tests/cpp/operator/test_ck_grouped_mxfp8.cu similarity index 98% rename from tests/cpp/operator/test_te_ck_grouped_mxfp8.cu rename to tests/cpp/operator/test_ck_grouped_mxfp8.cu index 1ad32557c..7ea939320 100644 --- a/tests/cpp/operator/test_te_ck_grouped_mxfp8.cu +++ b/tests/cpp/operator/test_ck_grouped_mxfp8.cu @@ -10,9 +10,6 @@ // 1. TE nvte_multi_tensor_gemm grouped path (CK backend selected by env) // 2. ck_tile::reference_mx_gemm host reference, using exact quantized operands/scales // 3. TE HIP reference kernel adapted from test_cublaslt_gemm.cu compute_ref_kernel -// -// Intended drop-in location: -// TransformerEngine/tests/cpp/operator/test_te_ck_grouped_mxfp8.cu #ifndef CK_TILE_USE_OCP_FP8 #define CK_TILE_USE_OCP_FP8 1 @@ -478,12 +475,11 @@ static void run_case_typed(const CaseConfig& cfg) { cudaDeviceProp prop; NVTE_CHECK_CUDA(cudaGetDeviceProperties(&prop, 0)); #ifdef __HIP_PLATFORM_AMD__ - const bool is_gfx950_or_newer_cdna = (prop.major == 9 && prop.minor >= 5); const bool is_gfx1250 = (prop.major == 12 && prop.minor == 5); - if (!is_gfx950_or_newer_cdna && !is_gfx1250) { - GTEST_SKIP() << "MXFP8 requires gfx950+ or gfx1250 in this test. GPU=" << prop.name - << " major=" << prop.major << " minor=" << prop.minor; + if (!is_gfx1250) { + GTEST_SKIP() << "This MXFP8 grouped GEMM test currently exercises the gfx1250-compatible CK pipeline only. GPU=" + << prop.name << " major=" << prop.major << " minor=" << prop.minor; } #endif diff --git a/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp b/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp index 4a5a4aaa9..4d7323be3 100644 --- a/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp +++ b/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp @@ -188,7 +188,7 @@ void preShuffleScaleBuffer_gfx1250(const ScaleType* src, NVTE_CHECK_CUDA(hipGetLastError()); } -template +template bool invoke_mx_grouped_gemm(const std::vector& descs, const GroupedGemmRunContext& ctx, const ck_tile::stream_config& stream_cfg) { // check hardware WMMA support for the warp tile @@ -261,7 +261,7 @@ bool invoke_mx_grouped_gemm(const std::vector& descs, con BScaleType>; /* make pipeline selective */ using GemmPipeline = - typename MxGemmPipelineTypeSelector::pipeline; using GemmEpilogue = ck_tile::TdmEpilogue< ck_tile::CShuffleEpilogueProblemcolumnwise_scale_inv : B0_te->scale_inv; NVTE_CHECK(A0_data.dptr != nullptr, - "ck_tile_mx_grouped_gemm: effective A[0] data is not initialized"); + "ck_tile_mx_grouped_gemm: A[0] data is not initialized"); NVTE_CHECK(B0_data.dptr != nullptr, - "ck_tile_mx_grouped_gemm: effective B[0] data is not initialized"); + "ck_tile_mx_grouped_gemm: B[0] data is not initialized"); NVTE_CHECK(A0_scale.dptr != nullptr, - "ck_tile_mx_grouped_gemm: effective A[0] scale_inv is not initialized"); + "ck_tile_mx_grouped_gemm: A[0] scale_inv is not initialized"); NVTE_CHECK(B0_scale.dptr != nullptr, - "ck_tile_mx_grouped_gemm: effective B[0] scale_inv is not initialized"); + "ck_tile_mx_grouped_gemm: B[0] scale_inv is not initialized"); const auto a_scale_dtype = A0_scale.dtype; const auto b_scale_dtype = B0_scale.dtype; @@ -377,8 +381,8 @@ bool ck_tile_mx_grouped_gemm(const NVTETensor* A, const auto a_dtype = A0_data.dtype; const auto b_dtype = B0_data.dtype; const auto d_dtype = D0->dtype(); - NVTE_CHECK(is_fp8_dtype(a_dtype), "ck_tile_mx_grouped_gemm: effective A dtype must be FP8"); - NVTE_CHECK(is_fp8_dtype(b_dtype), "ck_tile_mx_grouped_gemm: effective B dtype must be FP8"); + NVTE_CHECK(is_fp8_dtype(a_dtype), "ck_tile_mx_grouped_gemm: A dtype must be FP8"); + NVTE_CHECK(is_fp8_dtype(b_dtype), "ck_tile_mx_grouped_gemm: B dtype must be FP8"); using AScaleType = ck_tile::e8m0_t; using BScaleType = ck_tile::e8m0_t; @@ -432,11 +436,6 @@ bool ck_tile_mx_grouped_gemm(const NVTETensor* A, int64_t Ad0 = 0, Ad1 = 0, Bd0 = 0, Bd1 = 0, Dd0 = 0, Dd1 = 0; - // MXFP8 columnwise_data is not a physical transpose. It has the same - // logical tensor shape as rowwise data, but is quantized with scales - // along the other dimension. Therefore dims/strides must always be - // derived from the TE tensor shape, not from columnwise_data.shape - // interpreted as a transposed storage view. if (!get_flat_2d_dims(*A_te, Ad0, Ad1)) { NVTE_ERROR("ck_tile_mx_grouped_gemm: expected rank>=2 for normalized A in group ", i); } @@ -473,27 +472,11 @@ bool ck_tile_mx_grouped_gemm(const NVTETensor* A, ". D=", Dd0, "x", Dd1, ", expected=", M, "x", N); } - if (i == 0) { - printf("[MX CK] transA=%d transB=%d use_a_col=%d use_b_col=%d " - "M=%ld N=%ld K=%ld Ad=[%ld,%ld] Bd=[%ld,%ld] " - "a_scale_shape=[%zu,%zu] b_scale_shape=[%zu,%zu]\n", - static_cast(ctx.transA), - static_cast(ctx.transB), - static_cast(ctx.use_a_colwise_data), - static_cast(ctx.use_b_colwise_data), - M, N, K, Ad0, Ad1, Bd0, Bd1, - a_scales.shape.size() > 0 ? a_scales.shape[0] : 0, - a_scales.shape.size() > 1 ? a_scales.shape[1] : 0, - b_scales.shape.size() > 0 ? b_scales.shape[0] : 0, - b_scales.shape.size() > 1 ? b_scales.shape[1] : 0); - } - const ck_tile::index_t stride_A = static_cast(Ad1); const ck_tile::index_t stride_B = static_cast(Bd1); const ck_tile::index_t stride_E = static_cast(Dd1); // Pre-shuffle scale buffers for the hardware. - // For the NT-normalized presentation, A scales are MxKScale and B scales are NxKScale. const int a_scale_actual_rows = static_cast(M); const int a_scale_output_rows = ck_tile::integer_least_multiple( From c0fabff0c93569b39a9006f49dea35245f52b8df Mon Sep 17 00:00:00 2001 From: Aristotle <89488299+aris134@users.noreply.github.com> Date: Wed, 6 May 2026 16:40:13 -0400 Subject: [PATCH 12/47] Update cublaslt_gemm.cu --- transformer_engine/common/gemm/cublaslt_gemm.cu | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index b3863350e..1aef1b0de 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -1123,7 +1123,8 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor if (!use_cutlass || num_gemms == 1) { #else // Currently only support cutlass group gemm on Hopper Arch - if (!(is_hopper && use_cutlass)) { + // if (!(is_hopper && use_cutlass)) { + if (!use_cutlass) { #endif if (warn_fallback) { NVTE_WARN("Fallback to cuBLAS grouped GEMM."); From 3db2e5a40d72f9c17efc0e190e37392aeed94ee7 Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Mon, 11 May 2026 09:10:14 -0400 Subject: [PATCH 13/47] address pr comments --- tests/cpp/operator/CMakeLists.txt | 4 ++-- transformer_engine/common/gemm/cublaslt_gemm.cu | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index 4f87a9091..c1bc43faa 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -16,7 +16,6 @@ list(APPEND test_cuda_sources test_dequantize_mxfp8.cu test_dequantize_nvfp4.cu test_cast_nvfp4_transpose.cu - test_ck_grouped_mxfp8.cu test_transpose.cu test_cast_transpose.cu test_cast_transpose_current_scaling.cu @@ -41,7 +40,8 @@ if(USE_CUDA) else() list(APPEND test_cuda_sources test_cublaslt_gemm.cu - test_cast_mxfp4_transpose.cu) + test_cast_mxfp4_transpose.cu + test_ck_grouped_mxfp8.cu) endif() if(USE_CUDA) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 1aef1b0de..445a5ce0e 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -1123,7 +1123,7 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor if (!use_cutlass || num_gemms == 1) { #else // Currently only support cutlass group gemm on Hopper Arch - // if (!(is_hopper && use_cutlass)) { + if (!(is_hopper && use_cutlass)) { if (!use_cutlass) { #endif if (warn_fallback) { From 910d30fcfba0f481f09c6da13b3c5fda985e9021 Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Sun, 17 May 2026 17:27:47 -0400 Subject: [PATCH 14/47] fix ck group mxfp8 dispatch --- transformer_engine/common/gemm/cublaslt_gemm.cu | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 445a5ce0e..5aa03d4f4 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -34,7 +34,7 @@ #else #include "ck_grouped_gemm/ck_grouped_gemm.h" #include "ck_mx_grouped_gemm/ck_mx_grouped_gemm.hpp" - +#endif #ifndef __HIP_PLATFORM_AMD__ namespace { @@ -1205,7 +1205,8 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor if (is_empty_arr(bias) && is_empty_arr(pre_gelu_out) && is_supported_dtype() && #ifdef __HIP_PLATFORM_AMD__ true) { - const bool mxfp8_gemm = !use_fp4 && is_mxfp8_scaling(inputA->scaling_mode); + auto *inputA = transformer_engine::convertNVTETensorCheck(A[0]); + const bool mxfp8_gemm = transformer_engine::is_mxfp8_scaling(inputA->scaling_mode); bool handled_by_ck = false; if (mxfp8_gemm) { From 1b66d2990b49c75ee2167b41435dee66dc37fb4b Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Sun, 17 May 2026 17:46:54 -0400 Subject: [PATCH 15/47] update CMakeLists.txt --- transformer_engine/common/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 774065fca..32c8b95dc 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -259,6 +259,7 @@ else() gemm/ck_grouped_gemm/ck_grouped_gemm.cpp gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp + gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp amd_detail/system.cpp) list(APPEND transformer_engine_cuda_sources fused_attn_rocm/fused_attn_aotriton.cpp From 23b505f486ff05d78ebf79882f7794cd59da050c Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Tue, 19 May 2026 01:20:23 +0000 Subject: [PATCH 16/47] Add direct ROCm libraries dependency for CK grouped GEMM --- .gitmodules | 4 ++++ 3rdparty/rocm_libraries | 1 + transformer_engine/common/CMakeLists.txt | 2 +- 3 files changed, 6 insertions(+), 1 deletion(-) create mode 160000 3rdparty/rocm_libraries diff --git a/.gitmodules b/.gitmodules index c81bdb590..8c466697e 100644 --- a/.gitmodules +++ b/.gitmodules @@ -26,3 +26,7 @@ [submodule "3rdparty/QoLA"] path = 3rdparty/QoLA url = https://github.com/Micky774/QoLA.git +[submodule "rocm_libraries"] + path = 3rdparty/rocm_libraries + url = https://github.com/ROCm/rocm-libraries.git + branch = users/jia/ck/fix_grouped_gemm_quant_mxtype diff --git a/3rdparty/rocm_libraries b/3rdparty/rocm_libraries new file mode 160000 index 000000000..66b1d1467 --- /dev/null +++ b/3rdparty/rocm_libraries @@ -0,0 +1 @@ +Subproject commit 66b1d146722c42d86a794ebc9c6097c2e1c9f7a4 diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 32c8b95dc..94bb076fc 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -355,7 +355,7 @@ set_property( PROPERTY COMPILE_OPTIONS "-g0;-dopt=on") else() - set(CK_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/aiter/3rdparty/composable_kernel) + set(CK_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/rocm_libraries/projects/composablekernel) target_include_directories(transformer_engine PRIVATE ${CK_ROOT}/include) endif() #USE_CUDA From 746afea76bee6ab54b491a589d1a2586eb8787d6 Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Tue, 19 May 2026 22:03:53 +0000 Subject: [PATCH 17/47] Remove redundant MXFP8 env override from grouped linear test --- tests/pytorch/test_numerics.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 642999ef7..d9c7d1fb0 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -2253,7 +2253,6 @@ def test_grouped_linear_accuracy_cutlass( delay_wgrad_compute, ): os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1" - os.environ["NVTE_ROCM_ENABLE_MXFP8"] = "1" test_grouped_linear_accuracy( dtype, num_gemms, @@ -2269,7 +2268,6 @@ def test_grouped_linear_accuracy_cutlass( use_cutlass=True, ) os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None) - os.environ.pop("NVTE_ROCM_ENABLE_MXFP8", None) @pytest.mark.parametrize("dtype", param_types, ids=str) From 175855d3395e05110529909fc77faff82d986378 Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Wed, 20 May 2026 00:47:37 +0000 Subject: [PATCH 18/47] factor out common definitions from mxfp8 ck ggemm --- transformer_engine/common/CMakeLists.txt | 2 +- .../ck_grouped_gemm/ck_grouped_gemm_common.h | 4 + .../ck_mx_grouped_gemm.cpp | 109 ++++-------------- .../ck_mx_grouped_gemm.hpp | 0 .../common/gemm/cublaslt_gemm.cu | 2 +- 5 files changed, 31 insertions(+), 86 deletions(-) rename transformer_engine/common/gemm/{ck_mx_grouped_gemm => ck_grouped_gemm}/ck_mx_grouped_gemm.cpp (87%) rename transformer_engine/common/gemm/{ck_mx_grouped_gemm => ck_grouped_gemm}/ck_mx_grouped_gemm.hpp (100%) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 94bb076fc..1c60efae0 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -259,7 +259,7 @@ else() gemm/ck_grouped_gemm/ck_grouped_gemm.cpp gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp - gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp + gemm/ck_grouped_gemm/ck_mx_grouped_gemm.cpp amd_detail/system.cpp) list(APPEND transformer_engine_cuda_sources fused_attn_rocm/fused_attn_aotriton.cpp diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h index c89f10232..e8176be26 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h @@ -8,6 +8,7 @@ #include +#include #include #include #include @@ -19,6 +20,9 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/gemm.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/gemm/kernel/mx_grouped_gemm_kernel.hpp" +#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" namespace transformer_engine { namespace grouped_gemm { diff --git a/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_mx_grouped_gemm.cpp similarity index 87% rename from transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp rename to transformer_engine/common/gemm/ck_grouped_gemm/ck_mx_grouped_gemm.cpp index 4d7323be3..1228c85a7 100644 --- a/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_mx_grouped_gemm.cpp @@ -4,63 +4,13 @@ * License for AMD contributions = MIT. See LICENSE for more information ************************************************************************/ -#include -#include "../../common.h" - -#include "ck_tile/core.hpp" -#include "ck_tile/host/kernel_launch.hpp" -#include "ck_tile/ops/epilogue.hpp" -#include "ck_tile/ops/gemm.hpp" -#include "ck_tile/ops/gemm/kernel/mx_grouped_gemm_kernel.hpp" -#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" - -#include -#include -#include +#include "ck_grouped_gemm_common.h" namespace transformer_engine { -namespace mx_grouped_gemm { +namespace grouped_gemm { -using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; -using ColMajor = ck_tile::tensor_layout::gemm::ColumnMajor; using mx_grouped_gemm_kargs = ck_tile::MxGroupedGemmHostArgs<>; -template struct TETypeToCKType; -template <> struct TETypeToCKType { using type = ck_tile::fp8_t; }; -template <> struct TETypeToCKType { using type = ck_tile::bf8_t; }; -template <> struct TETypeToCKType { using type = ck_tile::half_t; }; -template <> struct TETypeToCKType { using type = ck_tile::bfloat16_t; }; -template <> struct TETypeToCKType { using type = float; }; - -struct GroupedGemmRunContext { - const NVTETensor* A = nullptr; - const NVTETensor* B = nullptr; - NVTETensor* D = nullptr; - - int group_num = 0; - bool transA = false; - bool transB = false; - - void* workspace = nullptr; - size_t workspace_bytes = 0; - hipStream_t stream = nullptr; - - bool use_a_colwise_data = false; - bool use_b_colwise_data = false; -}; - -// Treat TE tensors as generalized 2D matrices by flattening: -// (D1, D2, ..., Dn) -> (D1*...*D(n-1), Dn), consistent with TE Tensor::flat_*_dim. -static inline bool get_flat_2d_dims(const transformer_engine::Tensor& t, - int64_t& d0, int64_t& d1) { - if (t.shape().size() < 2) { - return false; - } - d0 = static_cast(t.flat_first_dim()); - d1 = static_cast(t.flat_last_dim()); - return true; -} - static constexpr ck_tile::index_t ScaleBlockSize = 32; enum struct MxGemmPipelineType @@ -87,17 +37,6 @@ struct MxGemmPipelineTypeSelector static constexpr auto GetName() { return "GemmPipelineAgBgCrCompTDMV2"; } }; -template -static inline bool has_sufficient_workspace(const GroupedGemmRunContext& ctx) { - const size_t needed = Kernel::GetWorkSpaceSize(ctx.group_num); - if (!ctx.workspace || ctx.workspace_bytes < needed) { - NVTE_WARN("ck_tile_mx_grouped_gemm: insufficient workspace for CK path. Needed bytes=", needed, - ", available bytes=", ctx.workspace_bytes, ". Falling back."); - return false; - } - return true; -} - struct GroupedGemKernelParam_Wmma { static const bool kPadM = false; @@ -283,7 +222,6 @@ bool invoke_mx_grouped_gemm(const std::vector& descs, con 1, /*kNumWaveGroups_*/ false, /*FixedVectorSize_*/ 1, /*VectorSizeC_*/ - false, /*TiledMMAPermuteN_*/ 1, /*BlockedXDLN_PerWarp_*/ DoubleSmemBuffer, /*DoubleSmemBuffer*/ AType, /*AType_*/ @@ -395,18 +333,21 @@ bool ck_tile_mx_grouped_gemm(const NVTETensor* A, ws_bytes = ws_te->data.numel() * typeToSize(ws_te->data.dtype); } - GroupedGemmRunContext ctx = { - A_use, - B_use, - D, - group_num, - transA_use, - transB_use, - ws_ptr, - ws_bytes, - stream, - use_a_colwise_data, - use_b_colwise_data}; + GroupedGemmRunContext ctx{ + .A = A_use, + .B = B_use, + .D = D, + .N = 0, + .group_num = group_num, + .transA = transA_use, + .transB = transB_use, + .workspace = ws_ptr, + .workspace_bytes = ws_bytes, + .stream = stream, + .use_a_columnwise_data = use_a_colwise_data, + .use_b_columnwise_data = use_b_colwise_data, + .accumulate = false, + }; const ck_tile::stream_config s{ctx.stream}; @@ -426,13 +367,13 @@ bool ck_tile_mx_grouped_gemm(const NVTETensor* A, transformer_engine::Tensor* D_te = transformer_engine::convertNVTETensorCheck(ctx.D[i]); - const auto& a = ctx.use_a_colwise_data ? A_te->columnwise_data : A_te->data; - const auto& b = ctx.use_b_colwise_data ? B_te->columnwise_data : B_te->data; + const auto& a = ctx.use_a_columnwise_data ? A_te->columnwise_data : A_te->data; + const auto& b = ctx.use_b_columnwise_data ? B_te->columnwise_data : B_te->data; const auto& d = D_te->data; const auto& a_scales = - ctx.use_a_colwise_data ? A_te->columnwise_scale_inv : A_te->scale_inv; + ctx.use_a_columnwise_data ? A_te->columnwise_scale_inv : A_te->scale_inv; const auto& b_scales = - ctx.use_b_colwise_data ? B_te->columnwise_scale_inv : B_te->scale_inv; + ctx.use_b_columnwise_data ? B_te->columnwise_scale_inv : B_te->scale_inv; int64_t Ad0 = 0, Ad1 = 0, Bd0 = 0, Bd1 = 0, Dd0 = 0, Dd1 = 0; @@ -503,7 +444,7 @@ bool ck_tile_mx_grouped_gemm(const NVTETensor* A, // TE rowwise MXFP8 scale_inv is [rows, KScale] and can be read with // KStride=true. TE columnwise_scale_inv is [KScale, rows] and must be // read with KStride=false before writing CK's canonical shuffled layout. - if (ctx.use_a_colwise_data) { + if (ctx.use_a_columnwise_data) { preShuffleScaleBuffer_gfx1250( reinterpret_cast(a_scales.dptr), reinterpret_cast(a_scale_shuffled_ptr), @@ -521,7 +462,7 @@ bool ck_tile_mx_grouped_gemm(const NVTETensor* A, stream); } - if (ctx.use_b_colwise_data) { + if (ctx.use_b_columnwise_data) { preShuffleScaleBuffer_gfx1250( reinterpret_cast(b_scales.dptr), reinterpret_cast(b_scale_shuffled_ptr), @@ -571,7 +512,7 @@ bool ck_tile_mx_grouped_gemm(const NVTETensor* A, return ok; } -} // namespace mx_grouped_gemm +} // namespace grouped_gemm } // namespace transformer_engine bool ck_tile_mx_grouped_gemm(const NVTETensor* A, @@ -583,6 +524,6 @@ bool ck_tile_mx_grouped_gemm(const NVTETensor* A, NVTETensor* workspace, bool accumulate, hipStream_t stream) { - return transformer_engine::mx_grouped_gemm::ck_tile_mx_grouped_gemm( + return transformer_engine::grouped_gemm::ck_tile_mx_grouped_gemm( A, B, D, group_num, transA, transB, workspace, accumulate, stream); } diff --git a/transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.hpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_mx_grouped_gemm.hpp similarity index 100% rename from transformer_engine/common/gemm/ck_mx_grouped_gemm/ck_mx_grouped_gemm.hpp rename to transformer_engine/common/gemm/ck_grouped_gemm/ck_mx_grouped_gemm.hpp diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 5aa03d4f4..792546f64 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -33,7 +33,7 @@ #include "./cutlass_grouped_gemm.cuh" #else #include "ck_grouped_gemm/ck_grouped_gemm.h" -#include "ck_mx_grouped_gemm/ck_mx_grouped_gemm.hpp" +#include "ck_grouped_gemm/ck_mx_grouped_gemm.hpp" #endif #ifndef __HIP_PLATFORM_AMD__ From f00fb7fc7b600785d9145b217e03c7d7473b417e Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Thu, 21 May 2026 15:57:35 +0000 Subject: [PATCH 19/47] add pr comments --- tests/cpp/operator/test_ck_grouped_mxfp8.cu | 75 +++------------------ 1 file changed, 9 insertions(+), 66 deletions(-) diff --git a/tests/cpp/operator/test_ck_grouped_mxfp8.cu b/tests/cpp/operator/test_ck_grouped_mxfp8.cu index 7ea939320..7456be812 100644 --- a/tests/cpp/operator/test_ck_grouped_mxfp8.cu +++ b/tests/cpp/operator/test_ck_grouped_mxfp8.cu @@ -9,11 +9,7 @@ // Compares three paths for grouped MXFP8 GEMM across NN/NT/TN transpose layouts: // 1. TE nvte_multi_tensor_gemm grouped path (CK backend selected by env) // 2. ck_tile::reference_mx_gemm host reference, using exact quantized operands/scales -// 3. TE HIP reference kernel adapted from test_cublaslt_gemm.cu compute_ref_kernel - -#ifndef CK_TILE_USE_OCP_FP8 -#define CK_TILE_USE_OCP_FP8 1 -#endif +// 3. TE HIP reference kernel simplified from test_cublaslt_gemm.cu compute_ref_kernel #include #include @@ -99,17 +95,10 @@ static void set_env_defaults() { setenv("NVTE_ROCM_ENABLE_MXFP8", "1", 0); } -static float to_float(float x) { return x; } static float to_float(const bf16_t& x) { return static_cast(x); } static float to_float(const ck_tile::bfloat16_t& x) { return static_cast(x); } -__device__ __host__ __forceinline__ float ref_gelu_unused(float x) { - float cdf = 0.5f * (1.0f + tanhf((0.7978845608028654f * (x + 0.044715f * x * x * x)))); - return x * cdf; -} - -template +template __global__ void compute_ref_kernel( const A_Type* __restrict__ a_data, const B_Type* __restrict__ b_data, @@ -121,18 +110,10 @@ __global__ void compute_ref_kernel( size_t b_scale_ld, bool a_scale_is_colwise, bool b_scale_is_colwise, - const Bias_Type* __restrict__ bias_data, - float d_scale, size_t m, size_t k, size_t n, D_Type* __restrict__ d_data, - float* __restrict__ d_amax, - Gelu_Type* __restrict__ gelu_data, bool transa, - bool transb, - bool is_fp8_output, - bool a_is_colwise, - bool b_is_colwise, - bool use_mxfp8) { + bool transb) { const size_t jj = blockIdx.x * blockDim.x + threadIdx.x; const size_t ii = blockIdx.y * blockDim.y + threadIdx.y; const bool in_range = (ii < m) && (jj < n); @@ -143,15 +124,8 @@ __global__ void compute_ref_kernel( size_t a_idx = 0; size_t b_idx = 0; - if (use_mxfp8) { - a_idx = transa ? (ii * k + kk) : (kk * m + ii); - b_idx = transb ? (kk * n + jj) : (jj * k + kk); - } else { - a_idx = a_is_colwise ? (ii * k + kk) - : (transa ? (ii * k + kk) : (kk * m + ii)); - b_idx = b_is_colwise ? (jj * k + kk) - : (transb ? (kk * n + jj) : (jj * k + kk)); - } + a_idx = transa ? (ii * k + kk) : (kk * m + ii); + b_idx = transb ? (kk * n + jj) : (jj * k + kk); float a_scale_inv_val = a_scale_inv_scalar; float b_scale_inv_val = b_scale_inv_scalar; @@ -171,27 +145,7 @@ __global__ void compute_ref_kernel( val += a_scale_inv_val * a_val * b_scale_inv_val * b_val; } - if (bias_data) val += static_cast(bias_data[ii]); - if (gelu_data) { - gelu_data[ii + jj * m] = static_cast(val); - val = ref_gelu_unused(val); - } - - const float scaled = val * d_scale; - d_data[ii + jj * m] = static_cast(scaled); - } - - if (is_fp8_output && d_amax) { - const int tid = threadIdx.y * blockDim.x + threadIdx.x; - const int nthreads = blockDim.x * blockDim.y; - extern __shared__ float s_amax[]; - s_amax[tid] = in_range ? fabsf(val) : 0.0f; - __syncthreads(); - for (int offset = nthreads / 2; offset > 0; offset /= 2) { - if (tid < offset) s_amax[tid] = fmaxf(s_amax[tid], s_amax[tid + offset]); - __syncthreads(); - } - if (tid == 0) atomicMax(d_amax, s_amax[0]); + d_data[ii + jj * m] = static_cast(val); } } @@ -326,10 +280,9 @@ static void run_hip_ref_for_group(const Tensor& a_mx, dim3 block(16, 16); dim3 grid(static_cast((n + block.x - 1) / block.x), static_cast((m + block.y - 1) / block.y)); - const size_t shmem_bytes = size_t(block.x) * size_t(block.y) * sizeof(float); - compute_ref_kernel - <<>>( + compute_ref_kernel + <<>>( static_cast(left_use_colwise ? b_mx.columnwise_dptr() : b_mx.rowwise_dptr()), static_cast(right_use_colwise ? a_mx.columnwise_dptr() : a_mx.rowwise_dptr()), 1.0f, @@ -342,18 +295,10 @@ static void run_hip_ref_for_group(const Tensor& a_mx, right_scale_ld, left_use_colwise, right_use_colwise, - nullptr, - 1.0f, m, k, n, static_cast(ref_d_colmajor->rowwise_dptr()), - nullptr, - nullptr, left_transa, - right_transb, - false, - left_use_colwise, - right_use_colwise, - true); + right_transb); NVTE_CHECK_CUDA(cudaGetLastError()); NVTE_CHECK_CUDA(cudaDeviceSynchronize()); } @@ -474,14 +419,12 @@ static void run_case_typed(const CaseConfig& cfg) { cudaDeviceProp prop; NVTE_CHECK_CUDA(cudaGetDeviceProperties(&prop, 0)); -#ifdef __HIP_PLATFORM_AMD__ const bool is_gfx1250 = (prop.major == 12 && prop.minor == 5); if (!is_gfx1250) { GTEST_SKIP() << "This MXFP8 grouped GEMM test currently exercises the gfx1250-compatible CK pipeline only. GPU=" << prop.name << " major=" << prop.major << " minor=" << prop.minor; } -#endif const auto m_splits = split_even(cfg.m_total, cfg.experts); From 45343f19172ab6ba90684772a3665c5c1ec56776 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 21 May 2026 13:10:22 -0500 Subject: [PATCH 20/47] add MXFP8 pre-swizzling for gfx1250 GEMM (#568) --- tests/cpp/operator/CMakeLists.txt | 4 +- tests/cpp/operator/test_cublaslt_gemm.cu | 101 ++++++++-- tests/cpp/operator/test_swizzle.cu | 180 ++++++++++++++++++ transformer_engine/common/swizzle/swizzle.cu | 178 +++++++++++++++++ .../jax/csrc/extensions/gemm.cpp | 28 ++- .../pytorch/csrc/extensions/gemm.cpp | 6 +- .../pytorch/csrc/extensions/swizzle.cpp | 15 ++ transformer_engine/pytorch/csrc/quantizer.cpp | 14 ++ transformer_engine/pytorch/csrc/util.h | 6 +- .../pytorch/tensor/mxfp8_tensor.py | 28 ++- 10 files changed, 524 insertions(+), 36 deletions(-) diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index 0ebd7fdfe..20d6919cc 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -31,11 +31,11 @@ list(APPEND test_cuda_sources test_multi_unpadding.cu test_causal_softmax.cu test_swap_first_dims.cu + test_swizzle.cu ../test_common.cu) if(USE_CUDA) list(APPEND test_cuda_sources - test_cast_float8blockwise.cu - test_swizzle.cu) + test_cast_float8blockwise.cu) else() list(APPEND test_cuda_sources test_cublaslt_gemm.cu diff --git a/tests/cpp/operator/test_cublaslt_gemm.cu b/tests/cpp/operator/test_cublaslt_gemm.cu index 85f183bf7..b8312de00 100644 --- a/tests/cpp/operator/test_cublaslt_gemm.cu +++ b/tests/cpp/operator/test_cublaslt_gemm.cu @@ -11,6 +11,7 @@ #include #include #include +#include #include #include "../test_common.h" @@ -30,7 +31,15 @@ std::vector> test_case_sizes = { std::vector> test_case_sizes_mxfp8 = { {32, 128, 16}, + {64, 128, 32}, + {128, 128, 64}, + {64, 256, 32}, + {128, 384, 64}, + {256, 512, 128}, + {512, 1024, 256}, {768, 3072, 4096}, + {1024, 2048, 128}, + {4096, 8192, 64}, }; // A, B, Bias, Gelu, D @@ -303,6 +312,40 @@ void cpu_rowwise_to_columnwise( } } +// Swizzle MXFP8 scale_inv of a test::Tensor in-place for gfx1250. +static void swizzle_mxfp8_scales(test::Tensor &t, bool rowwise) { + using namespace transformer_engine; + void *scale_ptr = rowwise ? t.rowwise_scale_inv_dptr() + : t.columnwise_scale_inv_dptr(); + if (!scale_ptr) return; + const NVTEShape scale_shape = rowwise ? t.rowwise_scale_inv_shape() + : t.columnwise_scale_inv_shape(); + const NVTEShape data_shape = rowwise ? t.rowwise_shape() + : t.columnwise_shape(); + size_t num_scales = 1; + for (size_t d = 0; d < scale_shape.ndim; d++) num_scales *= scale_shape.data[d]; + uint8_t *d_tmp = nullptr; + NVTE_CHECK_CUDA(cudaMalloc(&d_tmp, num_scales)); + TensorWrapper input_tw(NVTE_MXFP8_1D_SCALING); + TensorWrapper output_tw(NVTE_MXFP8_1D_SCALING); + output_tw.set_with_gemm_swizzled_scales(true); + if (rowwise) { + input_tw.set_rowwise_data(nullptr, t.dtype(), data_shape); + input_tw.set_rowwise_scale_inv(scale_ptr, DType::kFloat8E8M0, scale_shape); + output_tw.set_rowwise_data(nullptr, t.dtype(), data_shape); + output_tw.set_rowwise_scale_inv(d_tmp, DType::kFloat8E8M0, scale_shape); + } else { + input_tw.set_columnwise_data(nullptr, t.dtype(), data_shape); + input_tw.set_columnwise_scale_inv(scale_ptr, DType::kFloat8E8M0, scale_shape); + output_tw.set_columnwise_data(nullptr, t.dtype(), data_shape); + output_tw.set_columnwise_scale_inv(d_tmp, DType::kFloat8E8M0, scale_shape); + } + nvte_swizzle_scaling_factors(input_tw.data(), output_tw.data(), 0); + NVTE_CHECK_CUDA(cudaDeviceSynchronize()); + NVTE_CHECK_CUDA(cudaMemcpy(scale_ptr, d_tmp, num_scales, cudaMemcpyDeviceToDevice)); + NVTE_CHECK_CUDA(cudaFree(d_tmp)); +} + std::pair getTestTolerances(const DType type, bool use_fp8, bool use_mxfp8) { auto [atol, rtol] = getTolerances(type); @@ -318,6 +361,12 @@ std::pair getTestTolerances(const DType type, bool use_fp8, bool else if (use_fp8) { atol = 1e-3; rtol = std::max(rtol, 1e-2); + // Relax for gfx1250 + cudaDeviceProp prop; + (void)cudaGetDeviceProperties(&prop, 0); + if (prop.major == 12 && type == DType::kBFloat16) { + rtol = std::max(rtol, 5e-2); + } } else if (type == DType::kBFloat16) { //relax for certain prime number TN gemm @@ -496,6 +545,31 @@ void performTest(const TestParams& params) { #endif Tensor Workspace("Workspace", TShape{ workspace_size }, DType::kByte); + //perform the reference gemm on GPU (before swizzle, which modifies scales in-place) + Tensor RefD("RefD", TShape{ params.n, params.m }, dtype); + Tensor RefPreGeluOut; + + if (params.use_gelu) { + RefPreGeluOut = Tensor("RefPreGeluOut", TShape{ params.n, params.m }, gelu_type); + } + + run_reference( + params, + A, + B, + params.use_bias ? &bias : nullptr, + D, + RefD, + params.use_gelu ? &RefPreGeluOut : nullptr); + + // On gfx1250, hipBLASLt MXFP8 kernels expect pre-swizzled scales. + if (use_mxfp8 && prop.major == 12) { + if (!a_colwise) swizzle_mxfp8_scales(A, true); + if (a_colwise) swizzle_mxfp8_scales(A, false); + if (!b_colwise) swizzle_mxfp8_scales(B, true); + if (b_colwise) swizzle_mxfp8_scales(B, false); + } + //perform the gemm in GPU nvte_cublas_gemm(A.data(), B.data(), @@ -517,23 +591,6 @@ void performTest(const TestParams& params) { pre_gelu_out.to_cpu(); } - //perform the reference gemm on GPU - Tensor RefD("RefD", TShape{ params.n, params.m }, dtype); - Tensor RefPreGeluOut; - - if (params.use_gelu) { - RefPreGeluOut = Tensor("RefPreGeluOut", TShape{ params.n, params.m }, gelu_type); - } - - run_reference( - params, - A, - B, - params.use_bias ? &bias : nullptr, - D, - RefD, - params.use_gelu ? &RefPreGeluOut : nullptr); - // check if error message happens in running (void)cudaDeviceSynchronize(); auto err = cudaGetLastError(); @@ -605,6 +662,16 @@ void performDqTest(const TestParams ¶ms) { nvte_dequantize(A_fp8.data(), A_ref.data(), 0); nvte_dequantize(B_fp8.data(), B_ref.data(), 0); + // On gfx1250, hipBLASLt MXFP8 kernels expect pre-swizzled scales. + if (prop.major == 12) { + const bool a_colwise = !params.transa; + const bool b_colwise = params.transb; + if (!a_colwise) swizzle_mxfp8_scales(A_fp8, true); + if (a_colwise) swizzle_mxfp8_scales(A_fp8, false); + if (!b_colwise) swizzle_mxfp8_scales(B_fp8, true); + if (b_colwise) swizzle_mxfp8_scales(B_fp8, false); + } + Tensor bias; Tensor pre_gelu_out; diff --git a/tests/cpp/operator/test_swizzle.cu b/tests/cpp/operator/test_swizzle.cu index 3209d2335..0092a0c62 100644 --- a/tests/cpp/operator/test_swizzle.cu +++ b/tests/cpp/operator/test_swizzle.cu @@ -166,3 +166,183 @@ INSTANTIATE_TEST_SUITE_P( std::to_string(std::get<2>(info.param)); return name; }); + +#ifdef __HIP_PLATFORM_AMD__ + +// MX pre-swizzle test (gfx1250 Tensile 3D layout) +// +// Tensile 3D: {K_scale, M}.reshape({K_scale, padM/4, 4}).permute({1, 0, 2}) +// For source (m, k): dst = (m/4) * (K*4) + k*4 + (m%4) + +// CPU reference for Tensile 3D MX scale pre-swizzle. +// Row-major input [M, K], output is a flat permuted array. +void compute_ref_mx_swizzle_row(const uint8_t *h_input, uint8_t *h_output, + const int M, const int K, + const int orig_M, const int orig_K) { + constexpr int GROUP = 4; + for (int m = 0; m < M; m++) { + for (int k = 0; k < K; k++) { + uint8_t val = 127; // E8M0 identity: 2^0 = 1.0 + if (m < orig_M && k < orig_K) { + val = h_input[m * orig_K + k]; + } + int group = k / GROUP; + int within = k % GROUP; + int dst = group * (M * GROUP) + m * GROUP + within; + h_output[dst] = val; + } + } +} + +void compute_ref_mx_swizzle_col(const uint8_t *h_input, uint8_t *h_output, + const int M, const int K, + const int orig_M, const int orig_K) { + constexpr int GROUP = 4; + for (int m = 0; m < M; m++) { + for (int k = 0; k < K; k++) { + uint8_t val = 127; + if (m < orig_M && k < orig_K) { + val = h_input[k * orig_M + m]; + } + int group = k / GROUP; + int within = k % GROUP; + int dst = group * (M * GROUP) + m * GROUP + within; + h_output[dst] = val; + } + } +} + +static size_t roundup_sz(size_t val, size_t mult) { + return ((val + mult - 1) / mult) * mult; +} + +class MxSwizzleTestSuite + : public ::testing::TestWithParam< + std::tuple, bool>> {}; + +TEST_P(MxSwizzleTestSuite, TestMxSwizzle) { + using namespace transformer_engine; + using namespace test; + + cudaDeviceProp prop; + cudaGetDeviceProperties(&prop, 0); + if (prop.major < 12) { + GTEST_SKIP() << "MXFP8 pre-swizzle is only supported on gfx1250"; + } + + const auto dims = std::get<0>(GetParam()); + const bool rowwise = std::get<1>(GetParam()); + + // Original (unpadded) scale dimensions + const size_t orig_M = dims.first; + const size_t orig_K = dims.second; + + // Padded dimensions: K-tiled layout requires K_scale padded to multiple of 4 + const size_t M = orig_M; + const size_t K = roundup_sz(orig_K, 4); + + // Allocate host input (unpadded) and fill with random data + const size_t input_size = orig_M * orig_K; + std::unique_ptr h_input(new uint8_t[input_size]); + std::mt19937 rng(42); + for (size_t i = 0; i < input_size; i++) { + h_input[i] = static_cast(rng() % 256); + } + + // Allocate device input + uint8_t *d_input = nullptr; + NVTE_CHECK_CUDA(cudaMalloc(&d_input, input_size)); + NVTE_CHECK_CUDA(cudaMemcpy(d_input, h_input.get(), input_size, cudaMemcpyHostToDevice)); + + // Allocate device output (padded size) + const size_t output_size = M * K; + uint8_t *d_output = nullptr; + NVTE_CHECK_CUDA(cudaMalloc(&d_output, output_size)); + NVTE_CHECK_CUDA(cudaMemset(d_output, 0, output_size)); + + // Build TensorWrapper for input and output + TensorWrapper input_tw(NVTE_MXFP8_1D_SCALING); + TensorWrapper output_tw(NVTE_MXFP8_1D_SCALING); + output_tw.set_with_gemm_swizzled_scales(true); + + // Data shape must be consistent with scale shape for validation. + // Scale shapes use padded K; data shapes use unpadded dims + // (kernel derives original_M/K from them). + if (rowwise) { + std::vector data_shape_in = {orig_M, orig_K * 32}; + std::vector data_shape_out = {M, K * 32}; + std::vector scale_shape_in = {M, K}; + std::vector scale_shape_out = {M, K}; + input_tw.set_rowwise_data(nullptr, DType::kFloat8E4M3, data_shape_in); + input_tw.set_rowwise_scale_inv(d_input, DType::kFloat8E8M0, scale_shape_in); + output_tw.set_rowwise_data(nullptr, DType::kFloat8E4M3, data_shape_out); + output_tw.set_rowwise_scale_inv(d_output, DType::kFloat8E8M0, scale_shape_out); + } else { + std::vector data_shape_in = {orig_K * 32, orig_M}; + std::vector data_shape_out = {K * 32, M}; + std::vector scale_shape_in = {K, M}; + std::vector scale_shape_out = {K, M}; + input_tw.set_columnwise_data(nullptr, DType::kFloat8E4M3, data_shape_in); + input_tw.set_columnwise_scale_inv(d_input, DType::kFloat8E8M0, scale_shape_in); + output_tw.set_columnwise_data(nullptr, DType::kFloat8E4M3, data_shape_out); + output_tw.set_columnwise_scale_inv(d_output, DType::kFloat8E8M0, scale_shape_out); + } + + nvte_swizzle_scaling_factors(input_tw.data(), output_tw.data(), 0); + + NVTE_CHECK_CUDA(cudaDeviceSynchronize()); + + // Copy output back to host + std::unique_ptr h_output(new uint8_t[output_size]); + NVTE_CHECK_CUDA(cudaMemcpy(h_output.get(), d_output, output_size, cudaMemcpyDeviceToHost)); + + // Compute reference + std::unique_ptr h_ref(new uint8_t[output_size]); + memset(h_ref.get(), 0, output_size); + if (rowwise) { + compute_ref_mx_swizzle_row(h_input.get(), h_ref.get(), M, K, orig_M, orig_K); + } else { + compute_ref_mx_swizzle_col(h_input.get(), h_ref.get(), M, K, orig_M, orig_K); + } + + // Compare + compareResults("mx_swizzle", h_output.get(), h_ref.get(), output_size); + + cudaFree(d_input); + cudaFree(d_output); +} + +namespace { + +// Scale dimensions (M_scale, K_scale). +// K_scale will be padded to multiple of 4 by the test. +std::vector> mx_scale_dims = { + {4, 4}, // minimal + {8, 4}, // small + {32, 8}, // medium + {64, 16}, // larger + {96, 8}, // non-power-of-2 M + {128, 32}, // big + {256, 64}, // bigger + {512, 128}, // stress inter-tile + {1024, 256}, // large + {4096, 256}, // max stress +}; + +} // namespace + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + MxSwizzleTestSuite, + ::testing::Combine( + ::testing::ValuesIn(mx_scale_dims), + ::testing::Values(true, false) + ), + [](const testing::TestParamInfo& info) { + std::string name = "M" + std::to_string(std::get<0>(info.param).first) + + "_K" + std::to_string(std::get<0>(info.param).second) + + (std::get<1>(info.param) ? "_row" : "_col"); + return name; + }); + +#endif // __HIP_PLATFORM_AMD__ diff --git a/transformer_engine/common/swizzle/swizzle.cu b/transformer_engine/common/swizzle/swizzle.cu index c634c73fb..1324debfb 100644 --- a/transformer_engine/common/swizzle/swizzle.cu +++ b/transformer_engine/common/swizzle/swizzle.cu @@ -14,6 +14,7 @@ #include #include "../common.h" +#include "../util/cuda_runtime.h" #include "../util/logging.h" #include "transformer_engine/transformer_engine.h" @@ -347,9 +348,168 @@ __global__ void multi_tensor_swizzle_col_scaling_kernel(MultiSwizzleArgs kernel_ input, output, M, K, original_M, original_K, bid_x, bid_y, grid_dim_x, grid_dim_y); } +#ifdef __HIP_PLATFORM_AMD__ +// ============================================================================ +// MX scale pre-swizzle kernel for gfx1250 — K-tiled 3D layout +// +// hipBLASlt Tensile kernels expect scales in a permuted 3D layout that +// groups K_scale into tiles of 4 (= 128 / MXBlock32): +// Tensor({M, K_scale}).pad(K_scale to mult of 4).reshape({M, K_scale/4, 4}).permute({1, 0, 2}) +// +// For source position (m, k) in the [M, K_scale] scale matrix: +// group = k / 4 +// within = k % 4 +// dst = group * (M * 4) + m * 4 + within +// +// Padding: K_scale to multiple of 4. No M padding required. +// Identity padding value: E8M0 127 = 2^0 = 1.0 +// +// Reference: swizzle_mx_scale() in hipblaslt/clients/common/include/testing_matmul.hpp +// ============================================================================ + +constexpr int MX_PRESWIZZLE_GROUP_SIZE = 4; + +// Unified MX scale pre-swizzle kernel for both row-wise and column-wise. +// Iterates only over valid (non-padded) elements; the caller must pre-fill +// the output buffer with identity (127) to handle padding. +// +// kRowwise=true: input is [orig_M, orig_K] row-major +// kRowwise=false: input is [orig_K, orig_M] row-major (column-wise scales) +template +__global__ void __launch_bounds__(256) + swizzle_scaling_mx_kernel(const uint8_t* __restrict__ input, + uint8_t* __restrict__ output, + const int padded_M, + const int orig_M, const int orig_K) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + const int total = orig_M * orig_K; + if (idx >= total) return; + + const int m = idx / orig_K; + const int k = idx % orig_K; + + uint8_t val; + if constexpr (kRowwise) { + val = input[idx]; // == input[m * orig_K + k] + } else { + val = input[k * orig_M + m]; + } + + const int group = k / MX_PRESWIZZLE_GROUP_SIZE; + const int within = k % MX_PRESWIZZLE_GROUP_SIZE; + const int dst = group * (padded_M * MX_PRESWIZZLE_GROUP_SIZE) + + m * MX_PRESWIZZLE_GROUP_SIZE + within; + + output[dst] = val; +} + } // namespace +void swizzle_scaling_factors_mx(const Tensor* input, Tensor* output, cudaStream_t stream) { + // Check scaling mode + const auto& scaling_mode = input->scaling_mode; + NVTE_CHECK(scaling_mode == NVTE_MXFP8_1D_SCALING, + "MX pre-swizzle only supports MXFP8 scaling mode (got ", + to_string(input->scaling_mode), ")."); + + // Check tensors + CheckInputTensor(*input, "scaling_factor_input"); + CheckInputTensor(*output, "scaling_factor_output"); + NVTE_CHECK(!input->with_gemm_swizzled_scales, + "Expected input tensor with scales in compact format."); + NVTE_CHECK(output->with_gemm_swizzled_scales, + "Expected output tensor with scales in GEMM swizzled format."); + NVTE_CHECK(is_fp8_dtype(input->dtype()), "Input tensor has invalid dtype (expected FP8, got ", + to_string(input->dtype()), ")."); + + // Check if scaling factors are non-trivial + const bool has_rowwise_scale_inv = input->scale_inv.has_data(); + const bool has_columnwise_scale_inv = input->columnwise_scale_inv.has_data(); + NVTE_CHECK(!has_rowwise_scale_inv || !has_columnwise_scale_inv, + "Input tensor has both row-wise and column-wise scaling factors"); + if (!has_rowwise_scale_inv && !has_columnwise_scale_inv) { + return; + } + + // Deduce tensor dims + int m{0}, k{0}; + if (has_rowwise_scale_inv) { + NVTE_CHECK(input->scale_inv.shape.size() == 2, + "Expected 2D scaling factors, got shape=", input->scale_inv.shape, "."); + m = input->scale_inv.shape[0]; + k = input->scale_inv.shape[1]; + } else if (has_columnwise_scale_inv) { + NVTE_CHECK(input->columnwise_scale_inv.shape.size() == 2, + "Expected 2D scaling factors, got shape=", input->columnwise_scale_inv.shape, "."); + m = input->columnwise_scale_inv.shape[1]; + k = input->columnwise_scale_inv.shape[0]; + } + + // Check dims -- K-tiled layout requires K_scale padded to multiple of 4 + NVTE_CHECK(k % MX_PRESWIZZLE_GROUP_SIZE == 0, + "Scale K dimension must be padded to multiple of ", MX_PRESWIZZLE_GROUP_SIZE, + ", got ", k, "."); + + // Validate output dimensions match + if (has_rowwise_scale_inv) { + NVTE_CHECK(output->scale_inv.has_data(), + "Output tensor does not have row-wise scaling factors."); + NVTE_CHECK(m * k == output->scale_inv.numel(), "Expected output tensor to have ", m * k, + " row-wise scaling factors, but got shape=", output->scale_inv.shape, "."); + } + if (has_columnwise_scale_inv) { + NVTE_CHECK(output->columnwise_scale_inv.has_data(), + "Output tensor does not have column-wise scaling factors."); + NVTE_CHECK(m * k == output->columnwise_scale_inv.numel(), + "Expected output tensor to have ", m * k, + " column-wise scaling factors, but got shape=", + output->columnwise_scale_inv.shape, "."); + } + + const int total = m * k; + constexpr int block = 256; + + // Row-wise swizzle + if (has_rowwise_scale_inv) { + const int original_M = input->flat_first_dim(); + const int original_K = input->flat_last_dim() / MXFP8_BLOCK_SIZE; + // Pre-fill output with E8M0 identity (127 = 2^0) to handle padding + NVTE_CHECK_CUDA(cudaMemsetAsync(output->scale_inv.dptr, 127, total, stream)); + const int orig_total = original_M * original_K; + const int grid = (orig_total + block - 1) / block; + swizzle_scaling_mx_kernel<<>>( + reinterpret_cast(input->scale_inv.dptr), + reinterpret_cast(output->scale_inv.dptr), + m, original_M, original_K); + NVTE_CHECK_CUDA(cudaGetLastError()); + } + + // Column-wise swizzle + if (has_columnwise_scale_inv) { + const int original_M = input->flat_last_dim(); + const int original_K = input->flat_first_dim() / MXFP8_BLOCK_SIZE; + // Pre-fill output with E8M0 identity (127 = 2^0) to handle padding + NVTE_CHECK_CUDA(cudaMemsetAsync(output->columnwise_scale_inv.dptr, 127, total, stream)); + const int orig_total = original_M * original_K; + const int grid = (orig_total + block - 1) / block; + swizzle_scaling_mx_kernel<<>>( + reinterpret_cast(input->columnwise_scale_inv.dptr), + reinterpret_cast(output->columnwise_scale_inv.dptr), + m, original_M, original_K); + NVTE_CHECK_CUDA(cudaGetLastError()); + } +} +#endif // __HIP_PLATFORM_AMD__ + void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t stream) { +#ifdef __HIP_PLATFORM_AMD__ + // On gfx1250, MXFP8 uses the MX pre-swizzle layout (K-tiled, grouped by 4). + if (input->scaling_mode == NVTE_MXFP8_1D_SCALING && cuda::sm_arch() == 125) { + swizzle_scaling_factors_mx(input, output, stream); + return; + } +#endif // __HIP_PLATFORM_AMD__ + // Check scaling mode const auto& scaling_mode = input->scaling_mode; NVTE_CHECK(scaling_mode == NVTE_MXFP8_1D_SCALING || scaling_mode == NVTE_NVFP4_1D_SCALING, @@ -667,6 +827,24 @@ void launch_multi_tensor_swizzle_scaling_factors(MultiSwizzleArgs& kernel_args, void multi_tensor_swizzle_scaling_factors(const std::vector& input, std::vector& output, cudaStream_t stream) { +#ifdef __HIP_PLATFORM_AMD__ + // On gfx1250, MXFP8 uses the MX pre-swizzle layout. + if (cuda::sm_arch() == 125) { + bool any_mxfp8 = false; + for (size_t i = 0; i < input.size(); i++) { + if (is_mxfp8_scaling(input[i]->scaling_mode)) { + any_mxfp8 = true; + } + } + if (any_mxfp8) { + for (size_t i = 0; i < input.size(); i++) { + swizzle_scaling_factors_mx(input[i], output[i], stream); + } + return; + } + } +#endif // __HIP_PLATFORM_AMD__ + auto num_tensors = input.size(); bool all_has_data = true; bool all_has_columnwise_data = true; diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 40121049a..e32a42b1d 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -80,12 +80,31 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( "Inverse scale factors need to have an 8-bit data type."); } if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) { - // Assume MXFP8 scales are already swizzled if (rowwise) { input.set_rowwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape); } else { input.set_columnwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape); } +#ifdef USE_ROCM + // On gfx1250, pre-swizzle MXFP8 scales for hipBLASLt + if (transformer_engine::cuda::sm_arch() == 125 && swizzle_scale_ptr) { + TensorWrapper output(get_nvte_scaling_mode(scaling_mode)); + if (rowwise) { + output.set_rowwise_data(buffer.untyped_data(), input_dtype, input_shape); + output.set_rowwise_scale_inv(swizzle_scale_ptr, scale_dtype, scale_shape); + } else { + output.set_columnwise_data(buffer.untyped_data(), input_dtype, input_shape); + output.set_columnwise_scale_inv(swizzle_scale_ptr, scale_dtype, scale_shape); + } + output.set_with_gemm_swizzled_scales(true); + nvte_swizzle_scaling_factors(input.data(), output.data(), stream); + if (rowwise) { + input.set_rowwise_scale_inv(swizzle_scale_ptr, scale_dtype, scale_shape); + } else { + input.set_columnwise_scale_inv(swizzle_scale_ptr, scale_dtype, scale_shape); + } + } +#endif input.set_with_gemm_swizzled_scales(true); } else if (is_nvfp4) { // Swizzle for NVFP4 #ifdef USE_ROCM @@ -195,7 +214,12 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i auto workspace_ptr = reinterpret_cast(workspace->untyped_data()); workspace_ptr = move_ptr_to_next_256B_aligned(workspace_ptr); size_t workspace_size = static_cast(workspace->element_count()) - 256; - if (is_nvfp4_scaling(scaling_mode)) { + if (is_nvfp4_scaling(scaling_mode) +#ifdef USE_ROCM + || (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING + && transformer_engine::cuda::sm_arch() == 125) +#endif + ) { auto lhs_scale_size = product(lhs_scale_inv.dimensions()); auto rhs_scale_size = product(rhs_scale_inv.dimensions()); workspace_size = workspace_size - lhs_scale_size - rhs_scale_size; diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index 6898ce387..7a54728c2 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -244,13 +244,10 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans config.set_use_split_accumulator(use_split_accumulator); config.set_sm_count(num_math_sms); -#ifndef USE_ROCM // Keep the swizzled scaling factor tensors alive during the GEMM. std::vector> swizzled_scale_inverses_list; -#endif auto main_stream = at::cuda::getCurrentCUDAStream(); if (A_tensor.numel() != 0 && B_tensor.numel() != 0) { -#ifndef USE_ROCM // Optionally swizzle the scaling factors auto [A_row_scales, A_col_scales] = swizzle_scales_for_gemm(A_tensor, transa, !transa); auto [B_row_scales, B_col_scales] = swizzle_scales_for_gemm(B_tensor, !transb, transb); @@ -259,6 +256,7 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans swizzled_scale_inverses_list.emplace_back(std::move(B_row_scales)); swizzled_scale_inverses_list.emplace_back(std::move(B_col_scales)); +#ifndef USE_ROCM // Emulate the FP8 block scaling recipe with MXFP8 on Blackwell and newer // as it is not natively supported by cublasLt if (fp8_block_scaling && transformer_engine::cuda::sm_arch() >= 100) { @@ -532,7 +530,6 @@ std::optional> te_general_grouped_gemm( te_pre_gelu_out_wrappers.emplace_back(std::move(te_pre_gelu_out)); } -#ifndef USE_ROCM // Keep the swizzled scaling factor tensors alive during the GEMM. std::vector> swizzled_scale_inverses_list; @@ -542,6 +539,7 @@ std::optional> te_general_grouped_gemm( swizzled_scale_inverses_list.emplace_back( multi_tensor_swizzle_scales_for_gemm(te_B_wrappers, !transb, transb)); +#ifndef USE_ROCM // Emulate the FP8 block scaling recipe with MXFP8 on Blackwell and newer // as it is not natively supported by cublasLt if (transformer_engine::cuda::sm_arch() >= 100) { diff --git a/transformer_engine/pytorch/csrc/extensions/swizzle.cpp b/transformer_engine/pytorch/csrc/extensions/swizzle.cpp index 4ad57bbf1..d9929c93e 100644 --- a/transformer_engine/pytorch/csrc/extensions/swizzle.cpp +++ b/transformer_engine/pytorch/csrc/extensions/swizzle.cpp @@ -13,6 +13,7 @@ #include "common.h" #include "common/common.h" +#include "common/util/cuda_runtime.h" #include "extensions.h" #include "pybind.h" #include "util.h" @@ -55,6 +56,13 @@ std::tuple, std::optional> swizzle_scales_ return {std::nullopt, std::nullopt}; } +#ifdef USE_ROCM + // On ROCm, only MXFP8 on gfx1250 needs scale pre-swizzling + if (scaling_mode != NVTE_MXFP8_1D_SCALING || transformer_engine::cuda::sm_arch() != 125) { + return {std::nullopt, std::nullopt}; + } +#endif + // Return early if scales are already swizzled if (tensor.get_with_gemm_swizzled_scales()) { return {std::nullopt, std::nullopt}; @@ -164,6 +172,13 @@ std::optional multi_tensor_swizzle_scales_for_gemm( return std::nullopt; } +#ifdef USE_ROCM + // On ROCm, only MXFP8 on gfx1250 needs scale pre-swizzling + if (scaling_mode != NVTE_MXFP8_1D_SCALING || transformer_engine::cuda::sm_arch() != 125) { + return std::nullopt; + } +#endif + // Filter out tensors that already have swizzled scales std::vector tensors_needing_swizzle; for (auto &tensor : tensors) { diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index bb960406d..f1f6d690a 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -9,6 +9,9 @@ #include #include "common.h" +#ifdef USE_ROCM +#include "common/util/cuda_runtime.h" +#endif #include "pybind.h" #include "torch/torch.h" @@ -1104,6 +1107,17 @@ std::vector MXFP8Quantizer::get_scale_shape(const std::vector& s "MXFP8 requires tensor dims that are divisible by ", MXFP8_BLOCK_SIZE, " (got shape=", shape, ")"); #ifdef USE_ROCM + // gfx1250 MX pre-swizzle (Tensile 3D) layout requires M padded to multiple of 4. + if (transformer_engine::cuda::sm_arch() == 125) { + size_t m_dim = numel / last_dim; + size_t k_scale = last_dim / MXFP8_BLOCK_SIZE; + if (!columnwise) { + return {roundup(m_dim, 4), k_scale}; + } else { + return {k_scale, roundup(m_dim, 4)}; + } + } + return !columnwise ? std::vector{numel / last_dim, last_dim / MXFP8_BLOCK_SIZE} : std::vector{numel / last_dim / MXFP8_BLOCK_SIZE, last_dim}; diff --git a/transformer_engine/pytorch/csrc/util.h b/transformer_engine/pytorch/csrc/util.h index 6588aa6c5..f2310b61f 100644 --- a/transformer_engine/pytorch/csrc/util.h +++ b/transformer_engine/pytorch/csrc/util.h @@ -9,8 +9,6 @@ #ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_ #define TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_ -#ifndef USE_ROCM - #include #include @@ -37,6 +35,7 @@ std::optional multi_tensor_swizzle_scales_for_gemm(std::vector multi_tensor_swizzle_scales_for_gemm(std::vector Date: Thu, 21 May 2026 19:51:08 -0400 Subject: [PATCH 21/47] CK Tile Group GEMM gfx1250 (#576) --- .gitmodules | 4 + 3rdparty/rocm_libraries | 1 + transformer_engine/common/CMakeLists.txt | 2 +- .../gemm/ck_grouped_gemm/ck_grouped_gemm.cpp | 45 ++++--- .../ck_grouped_gemm/ck_grouped_gemm_common.h | 23 ++++ .../ck_grouped_gemm/ck_grouped_gemm_fp16.cpp | 78 ++++++++++-- .../ck_grouped_gemm/ck_grouped_gemm_fp8.cpp | 118 +++++++++++------- .../common/gemm/cublaslt_gemm.cu | 1 + 8 files changed, 192 insertions(+), 80 deletions(-) create mode 160000 3rdparty/rocm_libraries diff --git a/.gitmodules b/.gitmodules index c81bdb590..8c466697e 100644 --- a/.gitmodules +++ b/.gitmodules @@ -26,3 +26,7 @@ [submodule "3rdparty/QoLA"] path = 3rdparty/QoLA url = https://github.com/Micky774/QoLA.git +[submodule "rocm_libraries"] + path = 3rdparty/rocm_libraries + url = https://github.com/ROCm/rocm-libraries.git + branch = users/jia/ck/fix_grouped_gemm_quant_mxtype diff --git a/3rdparty/rocm_libraries b/3rdparty/rocm_libraries new file mode 160000 index 000000000..66b1d1467 --- /dev/null +++ b/3rdparty/rocm_libraries @@ -0,0 +1 @@ +Subproject commit 66b1d146722c42d86a794ebc9c6097c2e1c9f7a4 diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 774065fca..1f75f725d 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -354,7 +354,7 @@ set_property( PROPERTY COMPILE_OPTIONS "-g0;-dopt=on") else() - set(CK_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/aiter/3rdparty/composable_kernel) + set(CK_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/rocm_libraries/projects/composablekernel) target_include_directories(transformer_engine PRIVATE ${CK_ROOT}/include) endif() #USE_CUDA diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp index 5684be1cd..c5f8f4086 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp @@ -57,19 +57,19 @@ bool ck_tile_grouped_gemm(const NVTETensor* A, // FP8 special handling. // // A_use/B_use and transA_use/transB_use have already gone through the - // upstream-style grouped GEMM normalization above. This block only rewrites - // that normalized presentation into the CK FP8 preferred NT presentation by selecting - // `columnwise_data` when needed. + // upstream-style grouped GEMM normalization above. CK FP8 grouped GEMM is + // compiled only for the preferred NT presentation: // - // CK FP8 target presentation: - // A_use: N - // B_use: T + // transA_use = false + // transB_use = true // - // The outer condition checks whether this NT presentation is possible: - // - A_use is already N, or can be made N using columnwise_data - // - B_use is already T, or can be made T using columnwise_data + // This block rewrites the normalized presentation into that NT form by + // selecting columnwise_data when needed. If the required columnwise_data view + // is unavailable, this CK FP8 backend cannot represent the GEMM in its + // supported layout form, so we fall back instead of compiling/running an + // unsupported layout variant. // - // Then each operand is rewritten independently only if needed: + // Rewrite cases: // NN -> rewrite B only // TN -> rewrite A and B // NT -> already in target form @@ -81,16 +81,23 @@ bool ck_tile_grouped_gemm(const NVTETensor* A, const bool has_a_col = A0_te->has_columnwise_data(); const bool has_b_col = B0_te->has_columnwise_data(); - if ((!transA_use || has_a_col) && (transB_use || has_b_col)) { - if (transA_use) { - use_a_colwise_data = true; - transA_use = false; - } + const bool can_make_a_nt = !transA_use || has_a_col; + const bool can_make_b_nt = transB_use || has_b_col; - if (!transB_use) { - use_b_colwise_data = true; - transB_use = true; - } + if (!can_make_a_nt || !can_make_b_nt) { + NVTE_WARN("ck_tile_grouped_gemm: FP8 grouped GEMM requires NT presentation. " + "Missing required columnwise_data for layout rewrite; falling back."); + return false; + } + + if (transA_use) { + use_a_colwise_data = true; + transA_use = false; + } + + if (!transB_use) { + use_b_colwise_data = true; + transB_use = true; } } diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h index c89f10232..33cfce07f 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h @@ -7,6 +7,7 @@ #pragma once #include +#include "common/util/cuda_runtime.h" #include #include @@ -70,6 +71,28 @@ static inline const transformer_engine::SimpleTensor& scale_inv_view(const trans return t.scale_inv; } +enum class GPUArch { + GFX942, + GFX950, + GFX1250, + UNKNOWN +}; + +static inline GPUArch detect_gpu_arch() { + int arch = cuda::sm_arch(0); + + if (arch == 94) { + return GPUArch::GFX942; + } + if (arch == 95) { + return GPUArch::GFX950; + } + if (arch == 1250) { + return GPUArch::GFX1250; + } + return GPUArch::UNKNOWN; +} + struct GroupedGemmRunContext { const NVTETensor* A = nullptr; const NVTETensor* B = nullptr; diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp index 660dbefb8..df47261c7 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp @@ -14,7 +14,7 @@ namespace grouped_gemm { // Tile configs: FP16/BF16 // ------------------------- -struct TileCfg_256x256x64 { +struct TileCfg_256x256x64_MFMA { static constexpr ck_tile::index_t M_Tile = 256; static constexpr ck_tile::index_t N_Tile = 256; static constexpr ck_tile::index_t K_Tile = 64; @@ -37,14 +37,37 @@ struct TileCfg_256x256x64 { static constexpr ck_tile::index_t TilePartitionerM01 = 4; }; -struct TileCfg_256x128x64 : TileCfg_256x256x64 { +struct TileCfg_256x128x64_MFMA : TileCfg_256x256x64_MFMA { static constexpr ck_tile::index_t N_Tile = 128; }; -struct TileCfg_256x128x64_padding : TileCfg_256x128x64 { +struct TileCfg_256x128x64_MFMA_padding : TileCfg_256x128x64_MFMA { static constexpr bool kPadN = true; }; +struct TileCfg_256x256x64_WMMA { + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 64; + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 32; + + static constexpr bool kPadM = true; + static constexpr bool kPadN = true; + static constexpr bool kPadK = true; + + static constexpr bool DoubleSmemBuffer = false; + + static constexpr ck_tile::index_t TilePartitionerGroupNum = 8; + static constexpr ck_tile::index_t TilePartitionerM01 = 4; +}; + template (); \ }) -bool ck_tile_grouped_gemm_fp16_dispatch(DType a_dtype, - DType b_dtype, - DType d_dtype, - const GroupedGemmRunContext& ctx) { +template +bool ck_tile_grouped_gemm_fp16_dispatch_arch(DType a_dtype, + DType b_dtype, + DType d_dtype, + const GroupedGemmRunContext& ctx) { const ck_tile::stream_config s{ctx.stream}; std::unique_ptr runner = nullptr; @@ -229,13 +253,17 @@ bool ck_tile_grouped_gemm_fp16_dispatch(DType a_dtype, TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(d_dtype, d_te_type, { using CType = typename TETypeToCKType::type; - - if (ctx.N % 256 == 0) { - MAKE_RUNNER(TileCfg_256x256x64); - } else if (ctx.N % 128 == 0) { - MAKE_RUNNER(TileCfg_256x128x64); + + if constexpr (Arch == GPUArch::GFX1250) { + MAKE_RUNNER(TileCfg_256x256x64_WMMA); } else { - MAKE_RUNNER(TileCfg_256x128x64_padding); + if (ctx.N % 256 == 0) { + MAKE_RUNNER(TileCfg_256x256x64_MFMA); + } else if (ctx.N % 128 == 0) { + MAKE_RUNNER(TileCfg_256x128x64_MFMA); + } else { + MAKE_RUNNER(TileCfg_256x128x64_MFMA_padding); + } } }); }); @@ -249,6 +277,30 @@ bool ck_tile_grouped_gemm_fp16_dispatch(DType a_dtype, return runner->run(s, ctx); } +bool ck_tile_grouped_gemm_fp16_dispatch(DType a_dtype, + DType b_dtype, + DType d_dtype, + const GroupedGemmRunContext& ctx) { + switch (detect_gpu_arch()) { +#if defined(__gfx942__) + case GPUArch::GFX942: + return ck_tile_grouped_gemm_fp16_dispatch_arch(a_dtype, b_dtype, d_dtype, ctx); +#endif +#if defined(__gfx950__) + case GPUArch::GFX950: + return ck_tile_grouped_gemm_fp16_dispatch_arch(a_dtype, b_dtype, d_dtype, ctx); +#endif +#if defined(__gfx1250__) + case GPUArch::GFX1250: + return ck_tile_grouped_gemm_fp16_dispatch_arch(a_dtype, b_dtype, d_dtype, ctx); +#endif + + default: + NVTE_ERROR("ck_tile_grouped_gemm: available architectures = {gfx942, gfx950, gfx1250}"); + return false; + } +} + #undef MAKE_RUNNER } // namespace grouped_gemm diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp index 50b701c05..f31d73100 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp @@ -6,7 +6,6 @@ #include "ck_grouped_gemm_common.h" #include "ck_grouped_gemm_fp8.h" -#include "common/util/cuda_runtime.h" #include "ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp" #include "ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp" @@ -16,12 +15,6 @@ namespace transformer_engine { namespace grouped_gemm { -enum class GPUArch { - GFX942, - GFX950, - UNKNOWN -}; - struct TileCfg_128x128x128_16x16x128_2x2x1 { static constexpr ck_tile::index_t M_Tile = 128; static constexpr ck_tile::index_t N_Tile = 128; @@ -45,6 +38,29 @@ struct TileCfg_128x128x128_16x16x128_2x2x1 { static constexpr ck_tile::index_t TilePartitionerM01 = 8; }; +struct TileCfg_128x128x128_16x16x64_2x2x1 { + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128; + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 64; + + static constexpr bool kPadM = false; + static constexpr bool kPadN = false; + static constexpr bool kPadK = false; + + static constexpr bool DoubleSmemBuffer = false; + + static constexpr ck_tile::index_t TilePartitionerGroupNum = 16; + static constexpr ck_tile::index_t TilePartitionerM01 = 8; +}; + // gfx950 device compilation cannot instantiate the literal 32x32x16 FP8 tile // configuration due to an unsupported warp GEMM dispatcher configuration. // See: ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp for supported variants. @@ -115,8 +131,7 @@ class QuantGroupedGemmRunner : public RunnerInterface { AccType, GemmShape, UniversalTraits, - false, - AccType>; + false>; using Pipeline = ck_tile::GemmPipelineAgBgCrCompV3; @@ -265,18 +280,6 @@ class QuantGroupedGemmRunner : public RunnerInterface { } }; -static inline GPUArch detect_gpu_arch() { - int arch = cuda::sm_arch(0); - - if (arch == 94) { - return GPUArch::GFX942; - } - if (arch == 95) { - return GPUArch::GFX950; - } - return GPUArch::UNKNOWN; -} - template struct FP8TileCfg; @@ -290,6 +293,11 @@ struct FP8TileCfg { using type = TileCfg_128x128x128_16x16x128_2x2x1; }; +template <> +struct FP8TileCfg { + using type = TileCfg_128x128x128_16x16x64_2x2x1; +}; + template static bool ck_tile_grouped_gemm_fp8_dispatch_arch(DType a_dtype, DType b_dtype, @@ -301,31 +309,38 @@ static bool ck_tile_grouped_gemm_fp8_dispatch_arch(DType a_dtype, using CTypeLayout = RowMajor; using TileCfg = typename FP8TileCfg::type; - TRANSFORMER_ENGINE_SWITCH_CONDITION(ctx.transA, kTransA, { - using ALayout = std::conditional_t; - - TRANSFORMER_ENGINE_SWITCH_CONDITION(ctx.transB, kTransB, { - using BLayout = std::conditional_t; - - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(a_dtype, a_te_type, { - using AType = typename TETypeToCKType::type; - - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(b_dtype, b_te_type, { - using BType = typename TETypeToCKType::type; - - TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(d_dtype, d_te_type, { - using CType = typename TETypeToCKType::type; - using Runner = QuantGroupedGemmRunner; - runner = std::make_unique(); - }); - }); + // FP8 grouped GEMM is only compiled for CK's preferred NT presentation: + // transA=false, transB=true + // which maps to: + // ALayout=RowMajor, BLayout=ColMajor. + // + // The caller is responsible for rewriting other FP8 layouts into this form + // using columnwise_data when needed. Reject anything that did not normalize + // successfully so we do not instantiate unreachable/unsupported layout variants. + if (ctx.transA || !ctx.transB) { + return false; + } + + using ALayout = RowMajor; + using BLayout = ColMajor; + + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(a_dtype, a_te_type, { + using AType = typename TETypeToCKType::type; + + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(b_dtype, b_te_type, { + using BType = typename TETypeToCKType::type; + + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(d_dtype, d_te_type, { + using CType = typename TETypeToCKType::type; + using Runner = QuantGroupedGemmRunner; + runner = std::make_unique(); }); }); }); @@ -342,12 +357,21 @@ bool ck_tile_grouped_gemm_fp8_dispatch(DType a_dtype, DType d_dtype, const GroupedGemmRunContext& ctx) { switch (detect_gpu_arch()) { +#if defined(__gfx942__) case GPUArch::GFX942: return ck_tile_grouped_gemm_fp8_dispatch_arch(a_dtype, b_dtype, d_dtype, ctx); +#endif +#if defined(__gfx950__) case GPUArch::GFX950: return ck_tile_grouped_gemm_fp8_dispatch_arch(a_dtype, b_dtype, d_dtype, ctx); +#endif +#if defined(__gfx1250__) + case GPUArch::GFX1250: + return ck_tile_grouped_gemm_fp8_dispatch_arch(a_dtype, b_dtype, d_dtype, ctx); +#endif + default: - NVTE_ERROR("ck_tile_grouped_gemm: available architectures = {gfx942, gfx950}"); + NVTE_ERROR("ck_tile_grouped_gemm: available architectures = {gfx942, gfx950, gfx1250}"); return false; } } diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 7326f330f..efdff259d 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -1123,6 +1123,7 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor #else // Currently only support cutlass group gemm on Hopper Arch if (!(is_hopper && use_cutlass)) { + if (!use_cutlass) { #endif cublas_path(); return; From 7c3f49903cd9462ffde23f61594a269a1919a794 Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Mon, 8 Jun 2026 13:39:13 +0000 Subject: [PATCH 22/47] Add sparse rocm-libraries submodule for Composable Kernel --- .gitmodules | 1 - 3rdparty/rocm_libraries | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/.gitmodules b/.gitmodules index 8c466697e..4aac0ed83 100644 --- a/.gitmodules +++ b/.gitmodules @@ -29,4 +29,3 @@ [submodule "rocm_libraries"] path = 3rdparty/rocm_libraries url = https://github.com/ROCm/rocm-libraries.git - branch = users/jia/ck/fix_grouped_gemm_quant_mxtype diff --git a/3rdparty/rocm_libraries b/3rdparty/rocm_libraries index 66b1d1467..287088769 160000 --- a/3rdparty/rocm_libraries +++ b/3rdparty/rocm_libraries @@ -1 +1 @@ -Subproject commit 66b1d146722c42d86a794ebc9c6097c2e1c9f7a4 +Subproject commit 2870887694e0de120aca16c302779a2326386a85 From 7b5ba68b56644e9af28bd71d94b5a6a843830259 Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Mon, 8 Jun 2026 13:44:36 +0000 Subject: [PATCH 23/47] update submodule name --- .gitmodules | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitmodules b/.gitmodules index 4aac0ed83..b07958a7f 100644 --- a/.gitmodules +++ b/.gitmodules @@ -26,6 +26,6 @@ [submodule "3rdparty/QoLA"] path = 3rdparty/QoLA url = https://github.com/Micky774/QoLA.git -[submodule "rocm_libraries"] +[submodule "3rdparty/rocm_libraries"] path = 3rdparty/rocm_libraries url = https://github.com/ROCm/rocm-libraries.git From 508613c078714ca830fd06d3b470c55ff9f421bc Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Mon, 8 Jun 2026 14:00:33 +0000 Subject: [PATCH 24/47] override CK_ROOT --- transformer_engine/common/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index e437e9787..28d24c5bd 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -384,7 +384,7 @@ set_property( PROPERTY COMPILE_OPTIONS "-g0;-dopt=on") else() - set(CK_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/aiter/3rdparty/composable_kernel) + set(CK_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/rocm_libraries/projects/composablekernel) target_include_directories(transformer_engine PRIVATE ${CK_ROOT}/include) endif() #USE_CUDA From 12461eed8bbb111d016baa26d42001dd241c58b5 Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Mon, 8 Jun 2026 14:10:08 +0000 Subject: [PATCH 25/47] fix util --- transformer_engine/pytorch/csrc/util.h | 1 - 1 file changed, 1 deletion(-) diff --git a/transformer_engine/pytorch/csrc/util.h b/transformer_engine/pytorch/csrc/util.h index b8d7a89c9..587ec289a 100644 --- a/transformer_engine/pytorch/csrc/util.h +++ b/transformer_engine/pytorch/csrc/util.h @@ -56,7 +56,6 @@ std::optional maybe_swizzle_grouped_tensor_for_gemm( * during the GEMM. */ at::Tensor convert_block_scaling_to_mxfp8_tensor(TensorWrapper& input, bool rowwise); -#endif //!USE_ROCM } // namespace pytorch } // namespace transformer_engine From e65634110e1a812da25efefaf13c41e2318da098 Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Mon, 8 Jun 2026 14:32:17 +0000 Subject: [PATCH 26/47] add runtime guard for arch --- .../common/gemm/ck_grouped_gemm/ck_mx_grouped_gemm.cpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_mx_grouped_gemm.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_mx_grouped_gemm.cpp index 1228c85a7..ee52c2bed 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_mx_grouped_gemm.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_mx_grouped_gemm.cpp @@ -266,6 +266,12 @@ bool ck_tile_mx_grouped_gemm(const NVTETensor* A, NVTETensor* workspace, bool accumulate,//ignored for now hipStream_t stream) { + + if (detect_gpu_arch() != GPUArch::GFX1250) { + NVTE_WARN("ck_tile_mx_grouped_gemm: only supported on gfx1250. Falling back."); + return false; + } + if (group_num <= 0) { return true; } From cb1614a606da2ef41e70a34127209dc7b773fc59 Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Mon, 8 Jun 2026 19:13:30 +0000 Subject: [PATCH 27/47] Restore unrelated files from dev --- tests/cpp/operator/test_cublaslt_gemm.cu | 101 ++-------- tests/cpp/operator/test_swizzle.cu | 180 ------------------ .../common/gemm/cublaslt_gemm.cu | 53 ++---- transformer_engine/common/swizzle/swizzle.cu | 178 ----------------- .../jax/csrc/extensions/gemm.cpp | 21 +- .../pytorch/csrc/extensions/gemm.cpp | 6 +- .../pytorch/csrc/extensions/swizzle.cpp | 15 -- transformer_engine/pytorch/csrc/quantizer.cpp | 14 -- .../pytorch/tensor/mxfp8_tensor.py | 28 +-- 9 files changed, 43 insertions(+), 553 deletions(-) diff --git a/tests/cpp/operator/test_cublaslt_gemm.cu b/tests/cpp/operator/test_cublaslt_gemm.cu index b8312de00..85f183bf7 100644 --- a/tests/cpp/operator/test_cublaslt_gemm.cu +++ b/tests/cpp/operator/test_cublaslt_gemm.cu @@ -11,7 +11,6 @@ #include #include #include -#include #include #include "../test_common.h" @@ -31,15 +30,7 @@ std::vector> test_case_sizes = { std::vector> test_case_sizes_mxfp8 = { {32, 128, 16}, - {64, 128, 32}, - {128, 128, 64}, - {64, 256, 32}, - {128, 384, 64}, - {256, 512, 128}, - {512, 1024, 256}, {768, 3072, 4096}, - {1024, 2048, 128}, - {4096, 8192, 64}, }; // A, B, Bias, Gelu, D @@ -312,40 +303,6 @@ void cpu_rowwise_to_columnwise( } } -// Swizzle MXFP8 scale_inv of a test::Tensor in-place for gfx1250. -static void swizzle_mxfp8_scales(test::Tensor &t, bool rowwise) { - using namespace transformer_engine; - void *scale_ptr = rowwise ? t.rowwise_scale_inv_dptr() - : t.columnwise_scale_inv_dptr(); - if (!scale_ptr) return; - const NVTEShape scale_shape = rowwise ? t.rowwise_scale_inv_shape() - : t.columnwise_scale_inv_shape(); - const NVTEShape data_shape = rowwise ? t.rowwise_shape() - : t.columnwise_shape(); - size_t num_scales = 1; - for (size_t d = 0; d < scale_shape.ndim; d++) num_scales *= scale_shape.data[d]; - uint8_t *d_tmp = nullptr; - NVTE_CHECK_CUDA(cudaMalloc(&d_tmp, num_scales)); - TensorWrapper input_tw(NVTE_MXFP8_1D_SCALING); - TensorWrapper output_tw(NVTE_MXFP8_1D_SCALING); - output_tw.set_with_gemm_swizzled_scales(true); - if (rowwise) { - input_tw.set_rowwise_data(nullptr, t.dtype(), data_shape); - input_tw.set_rowwise_scale_inv(scale_ptr, DType::kFloat8E8M0, scale_shape); - output_tw.set_rowwise_data(nullptr, t.dtype(), data_shape); - output_tw.set_rowwise_scale_inv(d_tmp, DType::kFloat8E8M0, scale_shape); - } else { - input_tw.set_columnwise_data(nullptr, t.dtype(), data_shape); - input_tw.set_columnwise_scale_inv(scale_ptr, DType::kFloat8E8M0, scale_shape); - output_tw.set_columnwise_data(nullptr, t.dtype(), data_shape); - output_tw.set_columnwise_scale_inv(d_tmp, DType::kFloat8E8M0, scale_shape); - } - nvte_swizzle_scaling_factors(input_tw.data(), output_tw.data(), 0); - NVTE_CHECK_CUDA(cudaDeviceSynchronize()); - NVTE_CHECK_CUDA(cudaMemcpy(scale_ptr, d_tmp, num_scales, cudaMemcpyDeviceToDevice)); - NVTE_CHECK_CUDA(cudaFree(d_tmp)); -} - std::pair getTestTolerances(const DType type, bool use_fp8, bool use_mxfp8) { auto [atol, rtol] = getTolerances(type); @@ -361,12 +318,6 @@ std::pair getTestTolerances(const DType type, bool use_fp8, bool else if (use_fp8) { atol = 1e-3; rtol = std::max(rtol, 1e-2); - // Relax for gfx1250 - cudaDeviceProp prop; - (void)cudaGetDeviceProperties(&prop, 0); - if (prop.major == 12 && type == DType::kBFloat16) { - rtol = std::max(rtol, 5e-2); - } } else if (type == DType::kBFloat16) { //relax for certain prime number TN gemm @@ -545,31 +496,6 @@ void performTest(const TestParams& params) { #endif Tensor Workspace("Workspace", TShape{ workspace_size }, DType::kByte); - //perform the reference gemm on GPU (before swizzle, which modifies scales in-place) - Tensor RefD("RefD", TShape{ params.n, params.m }, dtype); - Tensor RefPreGeluOut; - - if (params.use_gelu) { - RefPreGeluOut = Tensor("RefPreGeluOut", TShape{ params.n, params.m }, gelu_type); - } - - run_reference( - params, - A, - B, - params.use_bias ? &bias : nullptr, - D, - RefD, - params.use_gelu ? &RefPreGeluOut : nullptr); - - // On gfx1250, hipBLASLt MXFP8 kernels expect pre-swizzled scales. - if (use_mxfp8 && prop.major == 12) { - if (!a_colwise) swizzle_mxfp8_scales(A, true); - if (a_colwise) swizzle_mxfp8_scales(A, false); - if (!b_colwise) swizzle_mxfp8_scales(B, true); - if (b_colwise) swizzle_mxfp8_scales(B, false); - } - //perform the gemm in GPU nvte_cublas_gemm(A.data(), B.data(), @@ -591,6 +517,23 @@ void performTest(const TestParams& params) { pre_gelu_out.to_cpu(); } + //perform the reference gemm on GPU + Tensor RefD("RefD", TShape{ params.n, params.m }, dtype); + Tensor RefPreGeluOut; + + if (params.use_gelu) { + RefPreGeluOut = Tensor("RefPreGeluOut", TShape{ params.n, params.m }, gelu_type); + } + + run_reference( + params, + A, + B, + params.use_bias ? &bias : nullptr, + D, + RefD, + params.use_gelu ? &RefPreGeluOut : nullptr); + // check if error message happens in running (void)cudaDeviceSynchronize(); auto err = cudaGetLastError(); @@ -662,16 +605,6 @@ void performDqTest(const TestParams ¶ms) { nvte_dequantize(A_fp8.data(), A_ref.data(), 0); nvte_dequantize(B_fp8.data(), B_ref.data(), 0); - // On gfx1250, hipBLASLt MXFP8 kernels expect pre-swizzled scales. - if (prop.major == 12) { - const bool a_colwise = !params.transa; - const bool b_colwise = params.transb; - if (!a_colwise) swizzle_mxfp8_scales(A_fp8, true); - if (a_colwise) swizzle_mxfp8_scales(A_fp8, false); - if (!b_colwise) swizzle_mxfp8_scales(B_fp8, true); - if (b_colwise) swizzle_mxfp8_scales(B_fp8, false); - } - Tensor bias; Tensor pre_gelu_out; diff --git a/tests/cpp/operator/test_swizzle.cu b/tests/cpp/operator/test_swizzle.cu index 0092a0c62..3209d2335 100644 --- a/tests/cpp/operator/test_swizzle.cu +++ b/tests/cpp/operator/test_swizzle.cu @@ -166,183 +166,3 @@ INSTANTIATE_TEST_SUITE_P( std::to_string(std::get<2>(info.param)); return name; }); - -#ifdef __HIP_PLATFORM_AMD__ - -// MX pre-swizzle test (gfx1250 Tensile 3D layout) -// -// Tensile 3D: {K_scale, M}.reshape({K_scale, padM/4, 4}).permute({1, 0, 2}) -// For source (m, k): dst = (m/4) * (K*4) + k*4 + (m%4) - -// CPU reference for Tensile 3D MX scale pre-swizzle. -// Row-major input [M, K], output is a flat permuted array. -void compute_ref_mx_swizzle_row(const uint8_t *h_input, uint8_t *h_output, - const int M, const int K, - const int orig_M, const int orig_K) { - constexpr int GROUP = 4; - for (int m = 0; m < M; m++) { - for (int k = 0; k < K; k++) { - uint8_t val = 127; // E8M0 identity: 2^0 = 1.0 - if (m < orig_M && k < orig_K) { - val = h_input[m * orig_K + k]; - } - int group = k / GROUP; - int within = k % GROUP; - int dst = group * (M * GROUP) + m * GROUP + within; - h_output[dst] = val; - } - } -} - -void compute_ref_mx_swizzle_col(const uint8_t *h_input, uint8_t *h_output, - const int M, const int K, - const int orig_M, const int orig_K) { - constexpr int GROUP = 4; - for (int m = 0; m < M; m++) { - for (int k = 0; k < K; k++) { - uint8_t val = 127; - if (m < orig_M && k < orig_K) { - val = h_input[k * orig_M + m]; - } - int group = k / GROUP; - int within = k % GROUP; - int dst = group * (M * GROUP) + m * GROUP + within; - h_output[dst] = val; - } - } -} - -static size_t roundup_sz(size_t val, size_t mult) { - return ((val + mult - 1) / mult) * mult; -} - -class MxSwizzleTestSuite - : public ::testing::TestWithParam< - std::tuple, bool>> {}; - -TEST_P(MxSwizzleTestSuite, TestMxSwizzle) { - using namespace transformer_engine; - using namespace test; - - cudaDeviceProp prop; - cudaGetDeviceProperties(&prop, 0); - if (prop.major < 12) { - GTEST_SKIP() << "MXFP8 pre-swizzle is only supported on gfx1250"; - } - - const auto dims = std::get<0>(GetParam()); - const bool rowwise = std::get<1>(GetParam()); - - // Original (unpadded) scale dimensions - const size_t orig_M = dims.first; - const size_t orig_K = dims.second; - - // Padded dimensions: K-tiled layout requires K_scale padded to multiple of 4 - const size_t M = orig_M; - const size_t K = roundup_sz(orig_K, 4); - - // Allocate host input (unpadded) and fill with random data - const size_t input_size = orig_M * orig_K; - std::unique_ptr h_input(new uint8_t[input_size]); - std::mt19937 rng(42); - for (size_t i = 0; i < input_size; i++) { - h_input[i] = static_cast(rng() % 256); - } - - // Allocate device input - uint8_t *d_input = nullptr; - NVTE_CHECK_CUDA(cudaMalloc(&d_input, input_size)); - NVTE_CHECK_CUDA(cudaMemcpy(d_input, h_input.get(), input_size, cudaMemcpyHostToDevice)); - - // Allocate device output (padded size) - const size_t output_size = M * K; - uint8_t *d_output = nullptr; - NVTE_CHECK_CUDA(cudaMalloc(&d_output, output_size)); - NVTE_CHECK_CUDA(cudaMemset(d_output, 0, output_size)); - - // Build TensorWrapper for input and output - TensorWrapper input_tw(NVTE_MXFP8_1D_SCALING); - TensorWrapper output_tw(NVTE_MXFP8_1D_SCALING); - output_tw.set_with_gemm_swizzled_scales(true); - - // Data shape must be consistent with scale shape for validation. - // Scale shapes use padded K; data shapes use unpadded dims - // (kernel derives original_M/K from them). - if (rowwise) { - std::vector data_shape_in = {orig_M, orig_K * 32}; - std::vector data_shape_out = {M, K * 32}; - std::vector scale_shape_in = {M, K}; - std::vector scale_shape_out = {M, K}; - input_tw.set_rowwise_data(nullptr, DType::kFloat8E4M3, data_shape_in); - input_tw.set_rowwise_scale_inv(d_input, DType::kFloat8E8M0, scale_shape_in); - output_tw.set_rowwise_data(nullptr, DType::kFloat8E4M3, data_shape_out); - output_tw.set_rowwise_scale_inv(d_output, DType::kFloat8E8M0, scale_shape_out); - } else { - std::vector data_shape_in = {orig_K * 32, orig_M}; - std::vector data_shape_out = {K * 32, M}; - std::vector scale_shape_in = {K, M}; - std::vector scale_shape_out = {K, M}; - input_tw.set_columnwise_data(nullptr, DType::kFloat8E4M3, data_shape_in); - input_tw.set_columnwise_scale_inv(d_input, DType::kFloat8E8M0, scale_shape_in); - output_tw.set_columnwise_data(nullptr, DType::kFloat8E4M3, data_shape_out); - output_tw.set_columnwise_scale_inv(d_output, DType::kFloat8E8M0, scale_shape_out); - } - - nvte_swizzle_scaling_factors(input_tw.data(), output_tw.data(), 0); - - NVTE_CHECK_CUDA(cudaDeviceSynchronize()); - - // Copy output back to host - std::unique_ptr h_output(new uint8_t[output_size]); - NVTE_CHECK_CUDA(cudaMemcpy(h_output.get(), d_output, output_size, cudaMemcpyDeviceToHost)); - - // Compute reference - std::unique_ptr h_ref(new uint8_t[output_size]); - memset(h_ref.get(), 0, output_size); - if (rowwise) { - compute_ref_mx_swizzle_row(h_input.get(), h_ref.get(), M, K, orig_M, orig_K); - } else { - compute_ref_mx_swizzle_col(h_input.get(), h_ref.get(), M, K, orig_M, orig_K); - } - - // Compare - compareResults("mx_swizzle", h_output.get(), h_ref.get(), output_size); - - cudaFree(d_input); - cudaFree(d_output); -} - -namespace { - -// Scale dimensions (M_scale, K_scale). -// K_scale will be padded to multiple of 4 by the test. -std::vector> mx_scale_dims = { - {4, 4}, // minimal - {8, 4}, // small - {32, 8}, // medium - {64, 16}, // larger - {96, 8}, // non-power-of-2 M - {128, 32}, // big - {256, 64}, // bigger - {512, 128}, // stress inter-tile - {1024, 256}, // large - {4096, 256}, // max stress -}; - -} // namespace - -INSTANTIATE_TEST_SUITE_P( - OperatorTest, - MxSwizzleTestSuite, - ::testing::Combine( - ::testing::ValuesIn(mx_scale_dims), - ::testing::Values(true, false) - ), - [](const testing::TestParamInfo& info) { - std::string name = "M" + std::to_string(std::get<0>(info.param).first) + - "_K" + std::to_string(std::get<0>(info.param).second) + - (std::get<1>(info.param) ? "_row" : "_col"); - return name; - }); - -#endif // __HIP_PLATFORM_AMD__ diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 25d7a768e..35cad5092 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -33,7 +33,6 @@ #include "./cutlass_grouped_gemm.cuh" #else #include "ck_grouped_gemm/ck_grouped_gemm.h" -#include "ck_grouped_gemm/ck_mx_grouped_gemm.hpp" #endif #ifndef __HIP_PLATFORM_AMD__ @@ -1126,11 +1125,7 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor #else // Currently only support cutlass group gemm on Hopper Arch if (!(is_hopper && use_cutlass)) { - if (!use_cutlass) { #endif - if (warn_fallback) { - NVTE_WARN("Fallback to cuBLAS grouped GEMM."); - } cublas_path(); return; } @@ -1162,29 +1157,21 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor }; #endif -#ifdef __HIP_PLATFORM_AMD__ - auto effective_dtype = [](const transformer_engine::Tensor *t) { - if (t->has_data()) { - return t->data.dtype; - } - if (t->has_columnwise_data()) { - return t->columnwise_data.dtype; - } - return t->data.dtype; - }; -#endif - auto is_supported_dtype = [&]() -> bool { auto *inputA = transformer_engine::convertNVTETensorCheck(A[0]); auto *inputB = transformer_engine::convertNVTETensorCheck(B[0]); auto *OutputD = transformer_engine::convertNVTETensorCheck(D[0]); #ifdef __HIP_PLATFORM_AMD__ - auto A_dt = effective_dtype(inputA); - auto B_dt = effective_dtype(inputB); + auto A_dt = inputA->data.dtype; + auto B_dt = inputB->data.dtype; auto D_dt = OutputD->data.dtype; - - return ((is_fp8_dtype(A_dt) && is_fp8_dtype(B_dt)) || - ((A_dt == B_dt) && (A_dt == D_dt) && is_fp16_dtype(A_dt))); + return ( + (is_fp8_dtype(A_dt) && is_fp8_dtype(B_dt)) + ) || + ( + (A_dt == B_dt) && (A_dt == D_dt) && + (is_fp16_dtype(A_dt)) + ); #else auto A_type = get_cuda_dtype(inputA->data.dtype); auto B_type = get_cuda_dtype(inputB->data.dtype); @@ -1207,23 +1194,11 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor if (is_empty_arr(bias) && is_empty_arr(pre_gelu_out) && is_supported_dtype() && #ifdef __HIP_PLATFORM_AMD__ true) { - auto *inputA = transformer_engine::convertNVTETensorCheck(A[0]); - const bool mxfp8_gemm = transformer_engine::is_mxfp8_scaling(inputA->scaling_mode); - - bool handled_by_ck = false; - if (mxfp8_gemm) { - handled_by_ck = ck_tile_mx_grouped_gemm( - A, B, D, num_gemms, transa, transb, workspace, accumulate, stream); - } else { - handled_by_ck = ck_tile_grouped_gemm( - A, B, D, num_gemms, transa, transb, workspace, accumulate, stream); - } - - if (!handled_by_ck) { - if (warn_fallback) { - NVTE_WARN("Fallback to cuBLAS grouped GEMM."); - } - cublas_path(); + if (!ck_tile_grouped_gemm(A, B, D, num_gemms, transa, transb, workspace, accumulate, stream)) { + if (warn_fallback) { + NVTE_WARN("Fallback to cuBLAS grouped GEMM."); + } + cublas_path(); } #else all_groups_uniform_k128(B, transb)) { diff --git a/transformer_engine/common/swizzle/swizzle.cu b/transformer_engine/common/swizzle/swizzle.cu index 2bb0a1dd8..592992d61 100644 --- a/transformer_engine/common/swizzle/swizzle.cu +++ b/transformer_engine/common/swizzle/swizzle.cu @@ -14,7 +14,6 @@ #include #include "../common.h" -#include "../util/cuda_runtime.h" #include "../util/logging.h" #include "transformer_engine/transformer_engine.h" @@ -382,168 +381,9 @@ __global__ void multi_tensor_swizzle_col_scaling_kernel(MultiSwizzleArgs kernel_ input, output, M, K, original_M, original_K, bid_x, bid_y, grid_dim_x, grid_dim_y); } -#ifdef __HIP_PLATFORM_AMD__ -// ============================================================================ -// MX scale pre-swizzle kernel for gfx1250 — K-tiled 3D layout -// -// hipBLASlt Tensile kernels expect scales in a permuted 3D layout that -// groups K_scale into tiles of 4 (= 128 / MXBlock32): -// Tensor({M, K_scale}).pad(K_scale to mult of 4).reshape({M, K_scale/4, 4}).permute({1, 0, 2}) -// -// For source position (m, k) in the [M, K_scale] scale matrix: -// group = k / 4 -// within = k % 4 -// dst = group * (M * 4) + m * 4 + within -// -// Padding: K_scale to multiple of 4. No M padding required. -// Identity padding value: E8M0 127 = 2^0 = 1.0 -// -// Reference: swizzle_mx_scale() in hipblaslt/clients/common/include/testing_matmul.hpp -// ============================================================================ - -constexpr int MX_PRESWIZZLE_GROUP_SIZE = 4; - -// Unified MX scale pre-swizzle kernel for both row-wise and column-wise. -// Iterates only over valid (non-padded) elements; the caller must pre-fill -// the output buffer with identity (127) to handle padding. -// -// kRowwise=true: input is [orig_M, orig_K] row-major -// kRowwise=false: input is [orig_K, orig_M] row-major (column-wise scales) -template -__global__ void __launch_bounds__(256) - swizzle_scaling_mx_kernel(const uint8_t* __restrict__ input, - uint8_t* __restrict__ output, - const int padded_M, - const int orig_M, const int orig_K) { - const int idx = blockIdx.x * blockDim.x + threadIdx.x; - const int total = orig_M * orig_K; - if (idx >= total) return; - - const int m = idx / orig_K; - const int k = idx % orig_K; - - uint8_t val; - if constexpr (kRowwise) { - val = input[idx]; // == input[m * orig_K + k] - } else { - val = input[k * orig_M + m]; - } - - const int group = k / MX_PRESWIZZLE_GROUP_SIZE; - const int within = k % MX_PRESWIZZLE_GROUP_SIZE; - const int dst = group * (padded_M * MX_PRESWIZZLE_GROUP_SIZE) - + m * MX_PRESWIZZLE_GROUP_SIZE + within; - - output[dst] = val; -} - } // namespace -void swizzle_scaling_factors_mx(const Tensor* input, Tensor* output, cudaStream_t stream) { - // Check scaling mode - const auto& scaling_mode = input->scaling_mode; - NVTE_CHECK(scaling_mode == NVTE_MXFP8_1D_SCALING, - "MX pre-swizzle only supports MXFP8 scaling mode (got ", - to_string(input->scaling_mode), ")."); - - // Check tensors - CheckInputTensor(*input, "scaling_factor_input"); - CheckInputTensor(*output, "scaling_factor_output"); - NVTE_CHECK(!input->with_gemm_swizzled_scales, - "Expected input tensor with scales in compact format."); - NVTE_CHECK(output->with_gemm_swizzled_scales, - "Expected output tensor with scales in GEMM swizzled format."); - NVTE_CHECK(is_fp8_dtype(input->dtype()), "Input tensor has invalid dtype (expected FP8, got ", - to_string(input->dtype()), ")."); - - // Check if scaling factors are non-trivial - const bool has_rowwise_scale_inv = input->scale_inv.has_data(); - const bool has_columnwise_scale_inv = input->columnwise_scale_inv.has_data(); - NVTE_CHECK(!has_rowwise_scale_inv || !has_columnwise_scale_inv, - "Input tensor has both row-wise and column-wise scaling factors"); - if (!has_rowwise_scale_inv && !has_columnwise_scale_inv) { - return; - } - - // Deduce tensor dims - int m{0}, k{0}; - if (has_rowwise_scale_inv) { - NVTE_CHECK(input->scale_inv.shape.size() == 2, - "Expected 2D scaling factors, got shape=", input->scale_inv.shape, "."); - m = input->scale_inv.shape[0]; - k = input->scale_inv.shape[1]; - } else if (has_columnwise_scale_inv) { - NVTE_CHECK(input->columnwise_scale_inv.shape.size() == 2, - "Expected 2D scaling factors, got shape=", input->columnwise_scale_inv.shape, "."); - m = input->columnwise_scale_inv.shape[1]; - k = input->columnwise_scale_inv.shape[0]; - } - - // Check dims -- K-tiled layout requires K_scale padded to multiple of 4 - NVTE_CHECK(k % MX_PRESWIZZLE_GROUP_SIZE == 0, - "Scale K dimension must be padded to multiple of ", MX_PRESWIZZLE_GROUP_SIZE, - ", got ", k, "."); - - // Validate output dimensions match - if (has_rowwise_scale_inv) { - NVTE_CHECK(output->scale_inv.has_data(), - "Output tensor does not have row-wise scaling factors."); - NVTE_CHECK(m * k == output->scale_inv.numel(), "Expected output tensor to have ", m * k, - " row-wise scaling factors, but got shape=", output->scale_inv.shape, "."); - } - if (has_columnwise_scale_inv) { - NVTE_CHECK(output->columnwise_scale_inv.has_data(), - "Output tensor does not have column-wise scaling factors."); - NVTE_CHECK(m * k == output->columnwise_scale_inv.numel(), - "Expected output tensor to have ", m * k, - " column-wise scaling factors, but got shape=", - output->columnwise_scale_inv.shape, "."); - } - - const int total = m * k; - constexpr int block = 256; - - // Row-wise swizzle - if (has_rowwise_scale_inv) { - const int original_M = input->flat_first_dim(); - const int original_K = input->flat_last_dim() / MXFP8_BLOCK_SIZE; - // Pre-fill output with E8M0 identity (127 = 2^0) to handle padding - NVTE_CHECK_CUDA(cudaMemsetAsync(output->scale_inv.dptr, 127, total, stream)); - const int orig_total = original_M * original_K; - const int grid = (orig_total + block - 1) / block; - swizzle_scaling_mx_kernel<<>>( - reinterpret_cast(input->scale_inv.dptr), - reinterpret_cast(output->scale_inv.dptr), - m, original_M, original_K); - NVTE_CHECK_CUDA(cudaGetLastError()); - } - - // Column-wise swizzle - if (has_columnwise_scale_inv) { - const int original_M = input->flat_last_dim(); - const int original_K = input->flat_first_dim() / MXFP8_BLOCK_SIZE; - // Pre-fill output with E8M0 identity (127 = 2^0) to handle padding - NVTE_CHECK_CUDA(cudaMemsetAsync(output->columnwise_scale_inv.dptr, 127, total, stream)); - const int orig_total = original_M * original_K; - const int grid = (orig_total + block - 1) / block; - swizzle_scaling_mx_kernel<<>>( - reinterpret_cast(input->columnwise_scale_inv.dptr), - reinterpret_cast(output->columnwise_scale_inv.dptr), - m, original_M, original_K); - NVTE_CHECK_CUDA(cudaGetLastError()); - } -} -#endif // __HIP_PLATFORM_AMD__ - void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t stream) { -#ifdef __HIP_PLATFORM_AMD__ - // On gfx1250, MXFP8 uses the MX pre-swizzle layout (K-tiled, grouped by 4). - if (input->scaling_mode == NVTE_MXFP8_1D_SCALING && cuda::sm_arch() == 125) { - swizzle_scaling_factors_mx(input, output, stream); - return; - } -#endif // __HIP_PLATFORM_AMD__ - // Check scaling mode const auto& scaling_mode = input->scaling_mode; NVTE_CHECK(scaling_mode == NVTE_MXFP8_1D_SCALING || scaling_mode == NVTE_NVFP4_1D_SCALING, @@ -861,24 +701,6 @@ void launch_multi_tensor_swizzle_scaling_factors(MultiSwizzleArgs& kernel_args, void multi_tensor_swizzle_scaling_factors(const std::vector& input, std::vector& output, cudaStream_t stream) { -#ifdef __HIP_PLATFORM_AMD__ - // On gfx1250, MXFP8 uses the MX pre-swizzle layout. - if (cuda::sm_arch() == 125) { - bool any_mxfp8 = false; - for (size_t i = 0; i < input.size(); i++) { - if (is_mxfp8_scaling(input[i]->scaling_mode)) { - any_mxfp8 = true; - } - } - if (any_mxfp8) { - for (size_t i = 0; i < input.size(); i++) { - swizzle_scaling_factors_mx(input[i], output[i], stream); - } - return; - } - } -#endif // __HIP_PLATFORM_AMD__ - auto num_tensors = input.size(); bool all_has_data = true; bool all_has_columnwise_data = true; diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index e02617e25..2d9a13278 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -81,31 +81,12 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( "Inverse scale factors need to have an 8-bit data type."); } if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) { + // Assume MXFP8 scales are already swizzled if (rowwise) { input.set_rowwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape); } else { input.set_columnwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape); } -#ifdef USE_ROCM - // On gfx1250, pre-swizzle MXFP8 scales for hipBLASLt - if (transformer_engine::cuda::sm_arch() == 125 && swizzle_scale_ptr) { - TensorWrapper output(get_nvte_scaling_mode(scaling_mode)); - if (rowwise) { - output.set_rowwise_data(buffer.untyped_data(), input_dtype, input_shape); - output.set_rowwise_scale_inv(swizzle_scale_ptr, scale_dtype, scale_shape); - } else { - output.set_columnwise_data(buffer.untyped_data(), input_dtype, input_shape); - output.set_columnwise_scale_inv(swizzle_scale_ptr, scale_dtype, scale_shape); - } - output.set_with_gemm_swizzled_scales(true); - nvte_swizzle_scaling_factors(input.data(), output.data(), stream); - if (rowwise) { - input.set_rowwise_scale_inv(swizzle_scale_ptr, scale_dtype, scale_shape); - } else { - input.set_columnwise_scale_inv(swizzle_scale_ptr, scale_dtype, scale_shape); - } - } -#endif input.set_with_gemm_swizzled_scales(true); } else if (is_nvfp4) { // Swizzle for NVFP4 #ifdef USE_ROCM diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index 73303a179..12fa16c1e 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -284,10 +284,13 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans config.set_use_split_accumulator(use_split_accumulator); config.set_sm_count(num_math_sms); +#ifndef USE_ROCM // Keep the swizzled scaling factor tensors alive during the GEMM. std::vector> swizzled_scale_inverses_list; +#endif auto main_stream = at::cuda::getCurrentCUDAStream(); if (A_tensor.numel() != 0 && B_tensor.numel() != 0) { +#ifndef USE_ROCM // Optionally swizzle the scaling factors auto [A_row_scales, A_col_scales] = swizzle_scales_for_gemm(A_tensor, transa, !transa); auto [B_row_scales, B_col_scales] = swizzle_scales_for_gemm(B_tensor, !transb, transb); @@ -296,7 +299,6 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans swizzled_scale_inverses_list.emplace_back(std::move(B_row_scales)); swizzled_scale_inverses_list.emplace_back(std::move(B_col_scales)); -#ifndef USE_ROCM // Emulate the FP8 block scaling recipe with MXFP8 on Blackwell and newer // as it is not natively supported by cublasLt if (fp8_block_scaling && transformer_engine::cuda::sm_arch() >= 100) { @@ -570,6 +572,7 @@ std::optional> te_general_grouped_gemm( te_pre_gelu_out_wrappers.emplace_back(std::move(te_pre_gelu_out)); } +#ifndef USE_ROCM // Keep the swizzled scaling factor tensors alive during the GEMM. std::vector> swizzled_scale_inverses_list; @@ -579,7 +582,6 @@ std::optional> te_general_grouped_gemm( swizzled_scale_inverses_list.emplace_back( multi_tensor_swizzle_scales_for_gemm(te_B_wrappers, !transb, transb)); -#ifndef USE_ROCM // Emulate the FP8 block scaling recipe with MXFP8 on Blackwell and newer // as it is not natively supported by cublasLt if (transformer_engine::cuda::sm_arch() >= 100) { diff --git a/transformer_engine/pytorch/csrc/extensions/swizzle.cpp b/transformer_engine/pytorch/csrc/extensions/swizzle.cpp index f814bd4c9..bd5524b56 100644 --- a/transformer_engine/pytorch/csrc/extensions/swizzle.cpp +++ b/transformer_engine/pytorch/csrc/extensions/swizzle.cpp @@ -13,7 +13,6 @@ #include "common.h" #include "common/common.h" -#include "common/util/cuda_runtime.h" #include "extensions.h" #include "pybind.h" #include "util.h" @@ -63,13 +62,6 @@ std::tuple, std::optional> swizzle_scales_ return {std::nullopt, std::nullopt}; } -#ifdef USE_ROCM - // On ROCm, only MXFP8 on gfx1250 needs scale pre-swizzling - if (scaling_mode != NVTE_MXFP8_1D_SCALING || transformer_engine::cuda::sm_arch() != 125) { - return {std::nullopt, std::nullopt}; - } -#endif - // Return early if scales are already swizzled if (tensor.get_with_gemm_swizzled_scales()) { return {std::nullopt, std::nullopt}; @@ -179,13 +171,6 @@ std::optional multi_tensor_swizzle_scales_for_gemm( return std::nullopt; } -#ifdef USE_ROCM - // On ROCm, only MXFP8 on gfx1250 needs scale pre-swizzling - if (scaling_mode != NVTE_MXFP8_1D_SCALING || transformer_engine::cuda::sm_arch() != 125) { - return std::nullopt; - } -#endif - // Filter out tensors that already have swizzled scales std::vector tensors_needing_swizzle; for (auto &tensor : tensors) { diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 0ec72028e..e8d012073 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -9,9 +9,6 @@ #include #include "common.h" -#ifdef USE_ROCM -#include "common/util/cuda_runtime.h" -#endif #include "pybind.h" #include "torch/torch.h" @@ -1698,17 +1695,6 @@ std::vector MXFP8Quantizer::get_scale_shape(const std::vector& s "MXFP8 requires tensor dims that are divisible by ", MXFP8_BLOCK_SIZE, " (got shape=", shape, ")"); #ifdef USE_ROCM - // gfx1250 MX pre-swizzle (Tensile 3D) layout requires M padded to multiple of 4. - if (transformer_engine::cuda::sm_arch() == 125) { - size_t m_dim = numel / last_dim; - size_t k_scale = last_dim / MXFP8_BLOCK_SIZE; - if (!columnwise) { - return {roundup(m_dim, 4), k_scale}; - } else { - return {k_scale, roundup(m_dim, 4)}; - } - } - return !columnwise ? std::vector{numel / last_dim, last_dim / MXFP8_BLOCK_SIZE} : std::vector{numel / last_dim / MXFP8_BLOCK_SIZE, last_dim}; diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 442a381bc..63a460276 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -19,7 +19,7 @@ from transformer_engine.common.recipe import MXFP8BlockScaling, Recipe from ..constants import MXFP8_BLOCK_SCALING_SIZE -from ..utils import devices_match, get_device_compute_capability, round_up_to_nearest_multiple +from ..utils import devices_match, round_up_to_nearest_multiple from .storage.mxfp8_tensor_storage import MXFP8TensorStorage, _FromMXFP8Func from ..quantized_tensor import QuantizedTensor, Quantizer from ._quantization_helpers import _IdentityFunc @@ -144,14 +144,9 @@ def make_empty( data = torch.empty(shape, dtype=torch.uint8, device=device) # ROCm TE does not implement fuse padding zeros so use zero tensor here if IS_HIP_EXTENSION: - m_dim = math.prod(shape[:-1]) - k_scale = math.ceil(shape[-1] / MXFP8_BLOCK_SCALING_SIZE) - # gfx1250 MX pre-swizzle layout requires M padded to multiple of 4 - if get_device_compute_capability() == (12, 5): - m_dim = round_up_to_nearest_multiple(m_dim, 4) scale_inv = torch.zeros( - m_dim, - k_scale, + math.prod(shape[:-1]), + math.ceil(shape[-1] / MXFP8_BLOCK_SCALING_SIZE), dtype=torch.uint8, device=device, pin_memory=pin_memory, @@ -174,14 +169,9 @@ def make_empty( ) # ROCm TE does not implement fuse padding zeros so use zero tensor here if IS_HIP_EXTENSION: - k_scale = math.ceil(math.prod(shape[:-1]) / MXFP8_BLOCK_SCALING_SIZE) - m_dim = shape[-1] - # gfx1250 MX pre-swizzle layout requires M padded to multiple of 4 - if get_device_compute_capability() == (12, 5): - m_dim = round_up_to_nearest_multiple(m_dim, 4) columnwise_scale_inv = torch.zeros( - k_scale, - m_dim, + math.ceil(math.prod(shape[:-1]) / MXFP8_BLOCK_SCALING_SIZE), + shape[-1], dtype=torch.uint8, device=device, pin_memory=pin_memory, @@ -523,12 +513,8 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): scale_invs = [tensor._rowwise_scale_inv, tensor._columnwise_scale_inv] split_sizes_for_scale = [split_size, split_size // MXFP8_BLOCK_SCALING_SIZE] - if IS_HIP_EXTENSION and get_device_compute_capability() == (12, 5): - # gfx1250 MX pre-swizzle layout requires M padded to multiple of 4 - padding_multiples = [4, 1] - else: - # Padding requirements: rowwise dim0 should be divisble by 128, columnwise dim0 should be divisble by 4 - padding_multiples = [128, 4] + # Padding requirements: rowwise dim0 should be divisble by 128, columnwise dim0 should be divisble by 4 + padding_multiples = [128, 4] for scale_inv, scale_split_size, pad_multiple in zip( scale_invs, split_sizes_for_scale, padding_multiples ): From a8bb9502415f86e4eea25b7fff517ca7510e6c8f Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Mon, 8 Jun 2026 19:19:25 +0000 Subject: [PATCH 28/47] Restore FP8 grouped GEMM source from dev --- .../ck_grouped_gemm/ck_grouped_gemm_fp8.cpp | 118 +++++++----------- 1 file changed, 47 insertions(+), 71 deletions(-) diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp index f31d73100..50b701c05 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp @@ -6,6 +6,7 @@ #include "ck_grouped_gemm_common.h" #include "ck_grouped_gemm_fp8.h" +#include "common/util/cuda_runtime.h" #include "ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp" #include "ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp" @@ -15,30 +16,13 @@ namespace transformer_engine { namespace grouped_gemm { -struct TileCfg_128x128x128_16x16x128_2x2x1 { - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 128; - static constexpr ck_tile::index_t K_Tile = 128; - - static constexpr ck_tile::index_t M_Warp = 2; - static constexpr ck_tile::index_t N_Warp = 2; - static constexpr ck_tile::index_t K_Warp = 1; - - static constexpr ck_tile::index_t M_Warp_Tile = 16; - static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = 128; - - static constexpr bool kPadM = false; - static constexpr bool kPadN = false; - static constexpr bool kPadK = false; - - static constexpr bool DoubleSmemBuffer = false; - - static constexpr ck_tile::index_t TilePartitionerGroupNum = 16; - static constexpr ck_tile::index_t TilePartitionerM01 = 8; +enum class GPUArch { + GFX942, + GFX950, + UNKNOWN }; -struct TileCfg_128x128x128_16x16x64_2x2x1 { +struct TileCfg_128x128x128_16x16x128_2x2x1 { static constexpr ck_tile::index_t M_Tile = 128; static constexpr ck_tile::index_t N_Tile = 128; static constexpr ck_tile::index_t K_Tile = 128; @@ -49,7 +33,7 @@ struct TileCfg_128x128x128_16x16x64_2x2x1 { static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = 64; + static constexpr ck_tile::index_t K_Warp_Tile = 128; static constexpr bool kPadM = false; static constexpr bool kPadN = false; @@ -131,7 +115,8 @@ class QuantGroupedGemmRunner : public RunnerInterface { AccType, GemmShape, UniversalTraits, - false>; + false, + AccType>; using Pipeline = ck_tile::GemmPipelineAgBgCrCompV3; @@ -280,6 +265,18 @@ class QuantGroupedGemmRunner : public RunnerInterface { } }; +static inline GPUArch detect_gpu_arch() { + int arch = cuda::sm_arch(0); + + if (arch == 94) { + return GPUArch::GFX942; + } + if (arch == 95) { + return GPUArch::GFX950; + } + return GPUArch::UNKNOWN; +} + template struct FP8TileCfg; @@ -293,11 +290,6 @@ struct FP8TileCfg { using type = TileCfg_128x128x128_16x16x128_2x2x1; }; -template <> -struct FP8TileCfg { - using type = TileCfg_128x128x128_16x16x64_2x2x1; -}; - template static bool ck_tile_grouped_gemm_fp8_dispatch_arch(DType a_dtype, DType b_dtype, @@ -309,38 +301,31 @@ static bool ck_tile_grouped_gemm_fp8_dispatch_arch(DType a_dtype, using CTypeLayout = RowMajor; using TileCfg = typename FP8TileCfg::type; - // FP8 grouped GEMM is only compiled for CK's preferred NT presentation: - // transA=false, transB=true - // which maps to: - // ALayout=RowMajor, BLayout=ColMajor. - // - // The caller is responsible for rewriting other FP8 layouts into this form - // using columnwise_data when needed. Reject anything that did not normalize - // successfully so we do not instantiate unreachable/unsupported layout variants. - if (ctx.transA || !ctx.transB) { - return false; - } - - using ALayout = RowMajor; - using BLayout = ColMajor; - - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(a_dtype, a_te_type, { - using AType = typename TETypeToCKType::type; - - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(b_dtype, b_te_type, { - using BType = typename TETypeToCKType::type; - - TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(d_dtype, d_te_type, { - using CType = typename TETypeToCKType::type; - using Runner = QuantGroupedGemmRunner; - runner = std::make_unique(); + TRANSFORMER_ENGINE_SWITCH_CONDITION(ctx.transA, kTransA, { + using ALayout = std::conditional_t; + + TRANSFORMER_ENGINE_SWITCH_CONDITION(ctx.transB, kTransB, { + using BLayout = std::conditional_t; + + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(a_dtype, a_te_type, { + using AType = typename TETypeToCKType::type; + + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(b_dtype, b_te_type, { + using BType = typename TETypeToCKType::type; + + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(d_dtype, d_te_type, { + using CType = typename TETypeToCKType::type; + using Runner = QuantGroupedGemmRunner; + runner = std::make_unique(); + }); + }); }); }); }); @@ -357,21 +342,12 @@ bool ck_tile_grouped_gemm_fp8_dispatch(DType a_dtype, DType d_dtype, const GroupedGemmRunContext& ctx) { switch (detect_gpu_arch()) { -#if defined(__gfx942__) case GPUArch::GFX942: return ck_tile_grouped_gemm_fp8_dispatch_arch(a_dtype, b_dtype, d_dtype, ctx); -#endif -#if defined(__gfx950__) case GPUArch::GFX950: return ck_tile_grouped_gemm_fp8_dispatch_arch(a_dtype, b_dtype, d_dtype, ctx); -#endif -#if defined(__gfx1250__) - case GPUArch::GFX1250: - return ck_tile_grouped_gemm_fp8_dispatch_arch(a_dtype, b_dtype, d_dtype, ctx); -#endif - default: - NVTE_ERROR("ck_tile_grouped_gemm: available architectures = {gfx942, gfx950, gfx1250}"); + NVTE_ERROR("ck_tile_grouped_gemm: available architectures = {gfx942, gfx950}"); return false; } } From baeba44fbb8ee37ffa73c51158c627d17b688efb Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Mon, 8 Jun 2026 19:21:15 +0000 Subject: [PATCH 29/47] Restore unrelated CK grouped GEMM files from dev --- .../gemm/ck_grouped_gemm/ck_grouped_gemm.cpp | 45 ++++++++----------- 1 file changed, 19 insertions(+), 26 deletions(-) diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp index c5f8f4086..5684be1cd 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp @@ -57,19 +57,19 @@ bool ck_tile_grouped_gemm(const NVTETensor* A, // FP8 special handling. // // A_use/B_use and transA_use/transB_use have already gone through the - // upstream-style grouped GEMM normalization above. CK FP8 grouped GEMM is - // compiled only for the preferred NT presentation: + // upstream-style grouped GEMM normalization above. This block only rewrites + // that normalized presentation into the CK FP8 preferred NT presentation by selecting + // `columnwise_data` when needed. // - // transA_use = false - // transB_use = true + // CK FP8 target presentation: + // A_use: N + // B_use: T // - // This block rewrites the normalized presentation into that NT form by - // selecting columnwise_data when needed. If the required columnwise_data view - // is unavailable, this CK FP8 backend cannot represent the GEMM in its - // supported layout form, so we fall back instead of compiling/running an - // unsupported layout variant. + // The outer condition checks whether this NT presentation is possible: + // - A_use is already N, or can be made N using columnwise_data + // - B_use is already T, or can be made T using columnwise_data // - // Rewrite cases: + // Then each operand is rewritten independently only if needed: // NN -> rewrite B only // TN -> rewrite A and B // NT -> already in target form @@ -81,23 +81,16 @@ bool ck_tile_grouped_gemm(const NVTETensor* A, const bool has_a_col = A0_te->has_columnwise_data(); const bool has_b_col = B0_te->has_columnwise_data(); - const bool can_make_a_nt = !transA_use || has_a_col; - const bool can_make_b_nt = transB_use || has_b_col; + if ((!transA_use || has_a_col) && (transB_use || has_b_col)) { + if (transA_use) { + use_a_colwise_data = true; + transA_use = false; + } - if (!can_make_a_nt || !can_make_b_nt) { - NVTE_WARN("ck_tile_grouped_gemm: FP8 grouped GEMM requires NT presentation. " - "Missing required columnwise_data for layout rewrite; falling back."); - return false; - } - - if (transA_use) { - use_a_colwise_data = true; - transA_use = false; - } - - if (!transB_use) { - use_b_colwise_data = true; - transB_use = true; + if (!transB_use) { + use_b_colwise_data = true; + transB_use = true; + } } } From b18099f1a6afb15f89d89fe71e1d76d8271e4406 Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Mon, 8 Jun 2026 22:14:36 +0000 Subject: [PATCH 30/47] update dispatch --- .../common/gemm/cublaslt_gemm.cu | 50 +++++++++++++------ 1 file changed, 36 insertions(+), 14 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 35cad5092..dfadad460 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -33,6 +33,7 @@ #include "./cutlass_grouped_gemm.cuh" #else #include "ck_grouped_gemm/ck_grouped_gemm.h" +#include "ck_grouped_gemm/ck_mx_grouped_gemm.hpp" #endif #ifndef __HIP_PLATFORM_AMD__ @@ -1157,21 +1158,30 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor }; #endif + #ifdef __HIP_PLATFORM_AMD__ + auto effective_dtype = [](const transformer_engine::Tensor *t) { + if (t->has_data()) { + return t->data.dtype; + } + if (t->has_columnwise_data()) { + return t->columnwise_data.dtype; + } + return t->data.dtype; + }; + #endif + auto is_supported_dtype = [&]() -> bool { auto *inputA = transformer_engine::convertNVTETensorCheck(A[0]); auto *inputB = transformer_engine::convertNVTETensorCheck(B[0]); auto *OutputD = transformer_engine::convertNVTETensorCheck(D[0]); #ifdef __HIP_PLATFORM_AMD__ - auto A_dt = inputA->data.dtype; - auto B_dt = inputB->data.dtype; + auto A_dt = effective_dtype(inputA); + auto B_dt = effective_dtype(inputB); auto D_dt = OutputD->data.dtype; - return ( - (is_fp8_dtype(A_dt) && is_fp8_dtype(B_dt)) - ) || - ( - (A_dt == B_dt) && (A_dt == D_dt) && - (is_fp16_dtype(A_dt)) - ); + + return ((is_fp8_dtype(A_dt) && is_fp8_dtype(B_dt)) || + ((A_dt == B_dt) && (A_dt == D_dt) && is_fp16_dtype(A_dt))); + #else auto A_type = get_cuda_dtype(inputA->data.dtype); auto B_type = get_cuda_dtype(inputB->data.dtype); @@ -1194,11 +1204,23 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor if (is_empty_arr(bias) && is_empty_arr(pre_gelu_out) && is_supported_dtype() && #ifdef __HIP_PLATFORM_AMD__ true) { - if (!ck_tile_grouped_gemm(A, B, D, num_gemms, transa, transb, workspace, accumulate, stream)) { - if (warn_fallback) { - NVTE_WARN("Fallback to cuBLAS grouped GEMM."); - } - cublas_path(); + auto *inputA = transformer_engine::convertNVTETensorCheck(A[0]) + const bool mxfp8_gemm = transformer_engine::is_mxfp8_scaling(inputA->scaling_mode); + + bool handled_by_ck = false; + if (mxfp8_gemm) { + handled_by_ck = ck_tile_mx_grouped_gemm( + A, B, D, num_gemms, transa, transb, workspace, accumulate, stream); + } else { + handled_by_ck = ck_tile_grouped_gemm( + A, B, D, num_gemms, transa, transb, workspace, accumulate, stream); + } + + if (!handled_by_ck) { + if (warn_fallback) { + NVTE_WARN("Fallback to cuBLAS grouped GEMM."); + } + cublas_path(); } #else all_groups_uniform_k128(B, transb)) { From d69f40c19d87be304cd502fe17736bf148ef3427 Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Mon, 8 Jun 2026 22:21:24 +0000 Subject: [PATCH 31/47] Remove rocm_libraries submodule --- 3rdparty/rocm_libraries | 1 - 1 file changed, 1 deletion(-) delete mode 160000 3rdparty/rocm_libraries diff --git a/3rdparty/rocm_libraries b/3rdparty/rocm_libraries deleted file mode 160000 index 287088769..000000000 --- a/3rdparty/rocm_libraries +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 2870887694e0de120aca16c302779a2326386a85 From 88bb3dd83bf90b0b8ef77acb3ee7e65273381e04 Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Mon, 8 Jun 2026 22:34:11 +0000 Subject: [PATCH 32/47] Add standalone Composable Kernel submodule --- .gitmodules | 4 ++++ 3rdparty/composable_kernel | 1 + transformer_engine/common/CMakeLists.txt | 2 +- 3 files changed, 6 insertions(+), 1 deletion(-) create mode 160000 3rdparty/composable_kernel diff --git a/.gitmodules b/.gitmodules index e44ffa060..eefd36fb7 100644 --- a/.gitmodules +++ b/.gitmodules @@ -30,3 +30,7 @@ path = 3rdparty/rocm_libraries url = https://github.com/ROCm/rocm-libraries.git +[submodule "3rdparty/composable_kernel"] + path = 3rdparty/composable_kernel + url = https://github.com/ROCm/composable_kernel.git + branch = develop diff --git a/3rdparty/composable_kernel b/3rdparty/composable_kernel new file mode 160000 index 000000000..0d18f4fc0 --- /dev/null +++ b/3rdparty/composable_kernel @@ -0,0 +1 @@ +Subproject commit 0d18f4fc05a31890e5ee365cfc15d82e1ba94669 diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 28d24c5bd..c09f3065a 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -384,7 +384,7 @@ set_property( PROPERTY COMPILE_OPTIONS "-g0;-dopt=on") else() - set(CK_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/rocm_libraries/projects/composablekernel) + set(CK_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/composable_kernel) target_include_directories(transformer_engine PRIVATE ${CK_ROOT}/include) endif() #USE_CUDA From e07767004ce18c91befe875c2fc1f507856992ce Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Mon, 8 Jun 2026 22:36:07 +0000 Subject: [PATCH 33/47] update gitmodules --- .gitmodules | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/.gitmodules b/.gitmodules index eefd36fb7..ef033e69d 100644 --- a/.gitmodules +++ b/.gitmodules @@ -26,11 +26,8 @@ [submodule "3rdparty/QoLA"] path = 3rdparty/QoLA url = https://github.com/ROCm/QoLA.git -[submodule "3rdparty/rocm_libraries"] - path = 3rdparty/rocm_libraries - url = https://github.com/ROCm/rocm-libraries.git - [submodule "3rdparty/composable_kernel"] path = 3rdparty/composable_kernel url = https://github.com/ROCm/composable_kernel.git branch = develop + From 1f764d26bba6c86600d7d9e1de5cbf65b5dcb666 Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Mon, 8 Jun 2026 23:00:16 +0000 Subject: [PATCH 34/47] minor fixes --- .../ck_grouped_gemm/ck_grouped_gemm_fp8.cpp | 18 ------------------ .../common/gemm/cublaslt_gemm.cu | 2 +- 2 files changed, 1 insertion(+), 19 deletions(-) diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp index 50b701c05..4c26ad8e0 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp @@ -16,12 +16,6 @@ namespace transformer_engine { namespace grouped_gemm { -enum class GPUArch { - GFX942, - GFX950, - UNKNOWN -}; - struct TileCfg_128x128x128_16x16x128_2x2x1 { static constexpr ck_tile::index_t M_Tile = 128; static constexpr ck_tile::index_t N_Tile = 128; @@ -265,18 +259,6 @@ class QuantGroupedGemmRunner : public RunnerInterface { } }; -static inline GPUArch detect_gpu_arch() { - int arch = cuda::sm_arch(0); - - if (arch == 94) { - return GPUArch::GFX942; - } - if (arch == 95) { - return GPUArch::GFX950; - } - return GPUArch::UNKNOWN; -} - template struct FP8TileCfg; diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index dfadad460..cface8ebd 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -1204,7 +1204,7 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor if (is_empty_arr(bias) && is_empty_arr(pre_gelu_out) && is_supported_dtype() && #ifdef __HIP_PLATFORM_AMD__ true) { - auto *inputA = transformer_engine::convertNVTETensorCheck(A[0]) + auto *inputA = transformer_engine::convertNVTETensorCheck(A[0]); const bool mxfp8_gemm = transformer_engine::is_mxfp8_scaling(inputA->scaling_mode); bool handled_by_ck = false; From 4a262d5356463d586472ac21cd7d4a243f75584f Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Mon, 8 Jun 2026 23:22:39 +0000 Subject: [PATCH 35/47] address PR comments --- tests/cpp/operator/CMakeLists.txt | 7 ------- 1 file changed, 7 deletions(-) diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index d29f436e8..ce6b884e0 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -49,11 +49,6 @@ if(USE_ROCM) test_cublaslt_gemm.cu test_cast_mxfp4_transpose.cu test_ck_grouped_mxfp8.cu) -endif() - -if(USE_CUDA) - add_executable(test_operator ${test_cuda_sources}) -else() TE_GetHipifiedSources("${test_cuda_sources}" ${CMAKE_CURRENT_SOURCE_DIR} test_hip_sources) TE_AddHipifyDeps("${test_cuda_sources}" ${CMAKE_CURRENT_SOURCE_DIR}) message("${message_line}") @@ -63,14 +58,12 @@ endif() # Find required packages find_package(OpenMP REQUIRED) - if(USE_CUDA) list(APPEND test_operator_LINKER_LIBS CUDA::cudart GTest::gtest_main ${TE_LIB} CUDA::nvrtc CUDNN::cudnn) target_link_libraries(test_operator PUBLIC ${test_operator_LINKER_LIBS} OpenMP::OpenMP_CXX) else() target_link_libraries(test_operator PUBLIC hip::host hip::device GTest::gtest_main ${TE_LIB} OpenMP::OpenMP_CXX hiprand) endif() - target_compile_options(test_operator PRIVATE -O2 -fopenmp) include(GoogleTest) From 669c4cc6165163163880e3b2d94ab0f9ccf73c53 Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Tue, 9 Jun 2026 00:10:28 +0000 Subject: [PATCH 36/47] address PR comments --- tests/cpp/operator/test_ck_grouped_mxfp8.cu | 29 ++++++++++----------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/tests/cpp/operator/test_ck_grouped_mxfp8.cu b/tests/cpp/operator/test_ck_grouped_mxfp8.cu index 7456be812..d853f921c 100644 --- a/tests/cpp/operator/test_ck_grouped_mxfp8.cu +++ b/tests/cpp/operator/test_ck_grouped_mxfp8.cu @@ -11,20 +11,6 @@ // 2. ck_tile::reference_mx_gemm host reference, using exact quantized operands/scales // 3. TE HIP reference kernel simplified from test_cublaslt_gemm.cu compute_ref_kernel -#include -#include -#include - -#include -#include -#include - -#include "../test_common.h" - -#include "ck_tile/core.hpp" -#include "ck_tile/host.hpp" -#include "ck_tile/ops/gemm.hpp" - #include #include #include @@ -39,6 +25,20 @@ #include #include +#include +#include +#include + +#include +#include +#include + +#include "../test_common.h" + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + using namespace transformer_engine; using namespace test; @@ -92,7 +92,6 @@ static std::string case_name(const testing::TestParamInfo& info) { static void set_env_defaults() { setenv("NVTE_USE_CUTLASS_GROUPED_GEMM", "1", 1); setenv("NVTE_CUTLASS_GROUPED_GEMM_WARN_FALLBACK", "1", 0); - setenv("NVTE_ROCM_ENABLE_MXFP8", "1", 0); } static float to_float(const bf16_t& x) { return static_cast(x); } From f3ecda3cab366818fe260b951f00f22e01b444b2 Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Tue, 9 Jun 2026 21:46:31 +0000 Subject: [PATCH 37/47] address pr comments: fix gfx1250 arch name and convert if-else to switch in detect_gpu_arch --- .../ck_grouped_gemm/ck_grouped_gemm_common.h | 20 +++++++++---------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h index c8b0e7cd7..acb24dba5 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h @@ -83,18 +83,16 @@ enum class GPUArch { }; static inline GPUArch detect_gpu_arch() { - int arch = cuda::sm_arch(0); - - if (arch == 94) { - return GPUArch::GFX942; - } - if (arch == 95) { - return GPUArch::GFX950; - } - if (arch == 1250) { - return GPUArch::GFX1250; + switch (cuda::sm_arch(0)) { + case 94: + return GPUArch::GFX942; + case 95: + return GPUArch::GFX950; + case 125: + return GPUArch::GFX1250; + default: + return GPUArch::UNKNOWN; } - return GPUArch::UNKNOWN; } struct GroupedGemmRunContext { From 94b0126a789b00386c801fecfcd2a67118711e0f Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Wed, 10 Jun 2026 13:59:50 +0000 Subject: [PATCH 38/47] use workspace for ck group gemm mxfp8 scales --- .../ck_grouped_gemm/ck_mx_grouped_gemm.cpp | 39 ++++++++++++++----- 1 file changed, 29 insertions(+), 10 deletions(-) diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_mx_grouped_gemm.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_mx_grouped_gemm.cpp index ee52c2bed..b6cf03e1b 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_mx_grouped_gemm.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_mx_grouped_gemm.cpp @@ -360,10 +360,15 @@ bool ck_tile_mx_grouped_gemm(const NVTETensor* A, std::vector descs; descs.reserve(group_num); - std::vector> a_scale_shuffled_bufs; - std::vector> b_scale_shuffled_bufs; - a_scale_shuffled_bufs.reserve(group_num); - b_scale_shuffled_bufs.reserve(group_num); + NVTE_CHECK(ctx.workspace != nullptr, + "ck_tile_mx_grouped_gemm: workspace is required for shuffled MXFP8 scales."); + + // Carve regions from the end of the workspace for mxfp8 scales. + // Layout: [CK kargs workspace ... | a_scales (i) | b_scales (i) | ... | a_scales (group_num-1) | b_scales (group_num-1)] + constexpr size_t kScaleWorkspaceAlign = 256; + uint8_t* scale_workspace_base = reinterpret_cast(ctx.workspace); + size_t scale_workspace_end = + (ctx.workspace_bytes / kScaleWorkspaceAlign) * kScaleWorkspaceAlign; for (int i = 0; i < group_num; i++) { const transformer_engine::Tensor* const A_te = @@ -439,12 +444,24 @@ bool ck_tile_mx_grouped_gemm(const NVTETensor* A, static_cast(b_scale_output_rows) * static_cast(KScale) * sizeof(BScaleType); - a_scale_shuffled_bufs.push_back( - std::make_unique(a_scale_shuffled_bytes)); - b_scale_shuffled_bufs.push_back( - std::make_unique(b_scale_shuffled_bytes)); - void* a_scale_shuffled_ptr = a_scale_shuffled_bufs.back()->GetDeviceBuffer(); - void* b_scale_shuffled_ptr = b_scale_shuffled_bufs.back()->GetDeviceBuffer(); + const size_t scale_pair_bytes = + a_scale_shuffled_bytes + b_scale_shuffled_bytes; + + scale_workspace_end = + (scale_workspace_end / kScaleWorkspaceAlign) * kScaleWorkspaceAlign; + + NVTE_CHECK(scale_workspace_end >= scale_pair_bytes, + "ck_tile_mx_grouped_gemm: insufficient workspace for shuffled MXFP8 scales. " + "Need current group scale bytes=", scale_pair_bytes, + ", available workspace bytes=", scale_workspace_end, + ". Increase the grouped GEMM workspace size."); + + scale_workspace_end -= scale_pair_bytes; + uint8_t* scale_pair_ptr = scale_workspace_base + scale_workspace_end; + + void* a_scale_shuffled_ptr = scale_pair_ptr; + void* b_scale_shuffled_ptr = scale_pair_ptr + a_scale_shuffled_bytes; + // CK expects canonical pre-shuffled scale buffers laid out as // A: [M, KScale] and B: [N, KScale], independent of A/B data layouts. // TE rowwise MXFP8 scale_inv is [rows, KScale] and can be read with @@ -501,6 +518,8 @@ bool ck_tile_mx_grouped_gemm(const NVTETensor* A, {/*stride_Ds*/}, stride_E)); } + ctx.workspace_bytes = scale_workspace_end; + // invoke gemm bool ok = false; TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(a_dtype, a_te_type, { From 479c509ea4fc7d2e3833dcb655d400405e9a31a3 Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Wed, 10 Jun 2026 15:03:14 +0000 Subject: [PATCH 39/47] add comment to ck gfx1250 mxfp8 scale swizzle --- .../gemm/ck_grouped_gemm/ck_mx_grouped_gemm.cpp | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_mx_grouped_gemm.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_mx_grouped_gemm.cpp index b6cf03e1b..e200103a3 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_mx_grouped_gemm.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_mx_grouped_gemm.cpp @@ -54,6 +54,18 @@ struct GroupedGemKernelParam_Wmma static constexpr ck_tile::index_t K_Warp_Tile = 128; }; +// gfx1250 scale preshuffle. +// +// Input scales are logically [MN, KScale] +// +// The output layout groups KScale into tiles of 4 (= 128 / ScaleBlockSize) +// and additionally blocks M into chunks of 32 rows: +// +// [MN, KScale] +// -> [MN/32, KScale/4, 32, 4] +// +// For A scales, rows=M and output_rows is M padded to M_Warp_Tile. +// For B scales, rows=N and output_rows is currently N. template __global__ void preshuffle_scale_gfx1250_kernel(const ScaleType* __restrict__ src, ScaleType* __restrict__ dst, From 5b4b7fe621d669a5f2d592d8e02377d301eb7f6d Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Wed, 10 Jun 2026 19:37:14 +0000 Subject: [PATCH 40/47] change random generation in test_ck_grouped_mxfp8.cu to use pre-existing utilities from test_common.cu --- tests/cpp/operator/test_ck_grouped_mxfp8.cu | 39 +++++++-------------- 1 file changed, 13 insertions(+), 26 deletions(-) diff --git a/tests/cpp/operator/test_ck_grouped_mxfp8.cu b/tests/cpp/operator/test_ck_grouped_mxfp8.cu index d853f921c..ef71f406a 100644 --- a/tests/cpp/operator/test_ck_grouped_mxfp8.cu +++ b/tests/cpp/operator/test_ck_grouped_mxfp8.cu @@ -19,7 +19,6 @@ #include #include #include -#include #include #include #include @@ -75,8 +74,6 @@ struct CaseConfig { size_t n; size_t k; int experts; - float scale; - int seed; LayoutConfig layout; DTypeConfig dtype; }; @@ -148,16 +145,6 @@ __global__ void compute_ref_kernel( } } -template -static void fill_randn_cpu(Tensor* t, float scale, int seed) { - std::mt19937 gen(seed); - std::normal_distribution dist(0.0f, scale); - const size_t n = product(t->rowwise_shape()); - T* ptr = t->rowwise_cpu_dptr(); - for (size_t i = 0; i < n; ++i) ptr[i] = static_cast(dist(gen)); - t->from_cpu(); -} - static std::vector split_even(size_t m_total, int experts) { NVTE_CHECK(experts > 0, "experts must be > 0"); NVTE_CHECK(m_total % static_cast(experts) == 0, @@ -445,11 +432,11 @@ static void run_case_typed(const CaseConfig& cfg) { const auto a_shape = a_shape_for_te(cfg.n, cfg.k, cfg.layout.transa); const auto b_shape = b_shape_for_te(m, cfg.k, cfg.layout.transb); - a_src.emplace_back("a_src", a_shape, DType::kBFloat16); - b_src.emplace_back("b_src", b_shape, DType::kBFloat16); + a_src.emplace_back("a_src" + std::to_string(g), a_shape, DType::kBFloat16); + b_src.emplace_back("b_src" + std::to_string(g), b_shape, DType::kBFloat16); - fill_randn_cpu(&a_src.back(), cfg.scale, cfg.seed + 1009 * g + 17); - fill_randn_cpu(&b_src.back(), cfg.scale, cfg.seed + 1009 * g + 29); + fillUniform(&a_src.back()); + fillUniform(&b_src.back()); // Allocate both rowwise and columnwise MX views so the backend can canonicalize NN/NT/TN. a_mx.emplace_back("a_mx", a_shape, te_dtype(cfg.dtype.a), @@ -533,17 +520,17 @@ static std::vector make_cases() { const std::vector dtypes = {kFP8FP8, kFP8BF8, kBF8FP8, kBF8BF8}; const std::vector base_cases = { // Small sanity across NN/NT/TN. - CaseConfig{1024, 1024, 1024, 2, 0.25f, 1234, kNN, kFP8FP8}, - CaseConfig{1024, 1024, 1024, 2, 0.25f, 1234, kNT, kFP8FP8}, - CaseConfig{1024, 1024, 1024, 2, 0.25f, 1234, kTN, kFP8FP8}, + CaseConfig{1024, 1024, 1024, 2, kNN, kFP8FP8}, + CaseConfig{1024, 1024, 1024, 2, kNT, kFP8FP8}, + CaseConfig{1024, 1024, 1024, 2, kTN, kFP8FP8}, // Earlier failure regime across NN/NT/TN. - CaseConfig{1536, 4096, 4096, 3, 0.25f, 1234, kNN, kFP8FP8}, - CaseConfig{1536, 4096, 4096, 3, 0.25f, 1234, kNT, kFP8FP8}, - CaseConfig{1536, 4096, 4096, 3, 0.25f, 1234, kTN, kFP8FP8}, + CaseConfig{1536, 4096, 4096, 3, kNN, kFP8FP8}, + CaseConfig{1536, 4096, 4096, 3, kNT, kFP8FP8}, + CaseConfig{1536, 4096, 4096, 3, kTN, kFP8FP8}, // Llama-ish suspicious path across NN/NT/TN. - CaseConfig{4096, 12288, 4096, 4, 0.25f, 1234, kNN, kFP8FP8}, - CaseConfig{4096, 12288, 4096, 4, 0.25f, 1234, kNT, kFP8FP8}, - CaseConfig{4096, 12288, 4096, 4, 0.25f, 1234, kTN, kFP8FP8}, + CaseConfig{4096, 12288, 4096, 4, kNN, kFP8FP8}, + CaseConfig{4096, 12288, 4096, 4, kNT, kFP8FP8}, + CaseConfig{4096, 12288, 4096, 4, kTN, kFP8FP8}, }; std::vector cases; From 68ed32a02d7947738ad83b3d2d720b7da3ee6670 Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Wed, 10 Jun 2026 19:42:36 +0000 Subject: [PATCH 41/47] address nits in ck_grouped_gemm_common.h --- .../common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h index acb24dba5..8fa8c75fe 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h @@ -7,7 +7,6 @@ #pragma once #include -#include "common/util/cuda_runtime.h" #include #include @@ -16,14 +15,16 @@ #include #include + +#include "common/util/cuda_runtime.h" #include "../../common.h" #include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" #include "ck_tile/ops/gemm.hpp" -#include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/ops/gemm/kernel/mx_grouped_gemm_kernel.hpp" -#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" namespace transformer_engine { namespace grouped_gemm { From 2e74a63e27410c5467aee5c497f7a92344c57602 Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Wed, 10 Jun 2026 20:36:01 +0000 Subject: [PATCH 42/47] stylistic changes --- .../ck_grouped_gemm/ck_mx_grouped_gemm.cpp | 671 +++++++++--------- 1 file changed, 337 insertions(+), 334 deletions(-) diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_mx_grouped_gemm.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_mx_grouped_gemm.cpp index e200103a3..c562c131b 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_mx_grouped_gemm.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_mx_grouped_gemm.cpp @@ -13,45 +13,42 @@ using mx_grouped_gemm_kargs = ck_tile::MxGroupedGemmHostArgs<>; static constexpr ck_tile::index_t ScaleBlockSize = 32; -enum struct MxGemmPipelineType -{ - CompTDMV1, - CompTDMV2 +enum struct MxGemmPipelineType { + CompTDMV1, + CompTDMV2 }; template struct MxGemmPipelineTypeSelector; + template -struct MxGemmPipelineTypeSelector -{ - using base_pipeline = ck_tile::BaseGemmPipelineAgBgCrCompTDM; - using pipeline = ck_tile::GemmPipelineAgBgCrCompTDMV1; - static constexpr auto GetName() { return "GemmPipelineAgBgCrCompTDMV1"; } +struct MxGemmPipelineTypeSelector { + using base_pipeline = ck_tile::BaseGemmPipelineAgBgCrCompTDM; + using pipeline = ck_tile::GemmPipelineAgBgCrCompTDMV1; + static constexpr auto GetName() { return "GemmPipelineAgBgCrCompTDMV1"; } }; template -struct MxGemmPipelineTypeSelector -{ - using base_pipeline = ck_tile::BaseGemmPipelineAgBgCrCompTDM; - using pipeline = ck_tile::GemmPipelineAgBgCrCompTDMV2; - static constexpr auto GetName() { return "GemmPipelineAgBgCrCompTDMV2"; } +struct MxGemmPipelineTypeSelector { + using base_pipeline = ck_tile::BaseGemmPipelineAgBgCrCompTDM; + using pipeline = ck_tile::GemmPipelineAgBgCrCompTDMV2; + static constexpr auto GetName() { return "GemmPipelineAgBgCrCompTDMV2"; } }; -struct GroupedGemKernelParam_Wmma -{ - static const bool kPadM = false; - static const bool kPadN = false; - static const bool kPadK = false; - static const int kBlockPerCu = 1; - static const ck_tile::index_t M_Tile = 64; - static const ck_tile::index_t N_Tile = 64; - static const ck_tile::index_t K_Tile = 128; - static const ck_tile::index_t M_Warp = 2; - static const ck_tile::index_t N_Warp = 2; - static const ck_tile::index_t K_Warp = 1; - static const ck_tile::index_t M_Warp_Tile = 32; - static const ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = 128; +struct GroupedGemKernelParam_Wmma { + static const bool kPadM = false; + static const bool kPadN = false; + static const bool kPadK = false; + static const int kBlockPerCu = 1; + static const ck_tile::index_t M_Tile = 64; + static const ck_tile::index_t N_Tile = 64; + static const ck_tile::index_t K_Tile = 128; + static const ck_tile::index_t M_Warp = 2; + static const ck_tile::index_t N_Warp = 2; + static const ck_tile::index_t K_Warp = 1; + static const ck_tile::index_t M_Warp_Tile = 32; + static const ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = 128; }; // gfx1250 scale preshuffle. @@ -67,97 +64,103 @@ struct GroupedGemKernelParam_Wmma // For A scales, rows=M and output_rows is M padded to M_Warp_Tile. // For B scales, rows=N and output_rows is currently N. template -__global__ void preshuffle_scale_gfx1250_kernel(const ScaleType* __restrict__ src, - ScaleType* __restrict__ dst, +__global__ void preshuffle_scale_gfx1250_kernel(const ScaleType *__restrict__ src, + ScaleType *__restrict__ dst, int actual_rows, int output_rows, - int KScale) -{ - static_assert(ScaleBlockSize == 32 && sizeof(ScaleType) == 1, - "gfx1250 scale preshuffle only supports 8-bit scale with ScaleBlockSize=32"); - constexpr int MPerXdlops = 16; - constexpr int KPerXdlops = 128; - constexpr int MNPack = 2; - constexpr int KPack = 1; - constexpr int MNStep = MPerXdlops; // 16 - constexpr int KStep = KPerXdlops / ScaleBlockSize; // 4 - const int K0 = KScale / (KPack * KStep); - const int linear = blockIdx.x * blockDim.x + threadIdx.x; - const int total = output_rows * KScale; - if(linear >= total) - return; - const int mn = linear / KScale; - const int k = linear % KScale; - const int iMNRepeat = mn / (MNStep * MNPack); - const int tempmn = mn % (MNStep * MNPack); - const int iKRepeat = k / (KStep * KPack); - const int tempk = k % (KStep * KPack); - const int outputIndex = - (iMNRepeat * MNPack * MNStep) * (KStep * KPack * K0) + - (iKRepeat * KStep * KPack) * (MNStep * MNPack) + - tempmn * (KStep * KPack) + - tempk; - ScaleType value{}; - if(mn < actual_rows) - { - if constexpr(KStride) - value = src[mn * KScale + k]; - else - value = src[k * actual_rows + mn]; + int KScale) { + static_assert(ScaleBlockSize == 32 && sizeof(ScaleType) == 1, + "gfx1250 scale preshuffle only supports 8-bit scale with ScaleBlockSize=32"); + constexpr int MPerXdlops = 16; + constexpr int KPerXdlops = 128; + constexpr int MNPack = 2; + constexpr int KPack = 1; + constexpr int MNStep = MPerXdlops; // 16 + constexpr int KStep = KPerXdlops / ScaleBlockSize; // 4 + const int K0 = KScale / (KPack * KStep); + const int linear = blockIdx.x * blockDim.x + threadIdx.x; + const int total = output_rows * KScale; + if (linear >= total) { + return; + } + const int mn = linear / KScale; + const int k = linear % KScale; + const int iMNRepeat = mn / (MNStep * MNPack); + const int tempmn = mn % (MNStep * MNPack); + const int iKRepeat = k / (KStep * KPack); + const int tempk = k % (KStep * KPack); + const int outputIndex = + (iMNRepeat * MNPack * MNStep) * (KStep * KPack * K0) + + (iKRepeat * KStep * KPack) * (MNStep * MNPack) + + tempmn * (KStep * KPack) + + tempk; + ScaleType value{}; + if (mn < actual_rows) { + if constexpr (KStride) { + value = src[mn * KScale + k]; + } else { + value = src[k * actual_rows + mn]; } - dst[outputIndex] = value; + } + dst[outputIndex] = value; } template -void preShuffleScaleBuffer_gfx1250(const ScaleType* src, - ScaleType* dst, +void preShuffleScaleBuffer_gfx1250(const ScaleType *src, + ScaleType *dst, int actual_rows, int output_rows, int KScale, - hipStream_t stream) -{ - constexpr int KPerXdlops = 128; - constexpr int KStep = KPerXdlops / ScaleBlockSize; // 4 - if(KScale % KStep != 0) - { - NVTE_ERROR("preshuffle_scale_gfx1250: KScale must be a multiple of 4, " - "i.e. original K must be a multiple of 128 for ScaleBlockSize=32."); - } - const int total = output_rows * KScale; - constexpr int block_size = 256; - const int grid_size = (total + block_size - 1) / block_size; - hipLaunchKernelGGL((preshuffle_scale_gfx1250_kernel), - dim3(grid_size), - dim3(block_size), - 0, - stream, - src, - dst, - actual_rows, - output_rows, - KScale); - NVTE_CHECK_CUDA(hipGetLastError()); + hipStream_t stream) { + constexpr int KPerXdlops = 128; + constexpr int KStep = KPerXdlops / ScaleBlockSize; // 4 + if (KScale % KStep != 0) { + NVTE_ERROR("preshuffle_scale_gfx1250: KScale must be a multiple of 4, " + "i.e. original K must be a multiple of 128 for ScaleBlockSize=32."); + } + const int total = output_rows * KScale; + constexpr int block_size = 256; + const int grid_size = (total + block_size - 1) / block_size; + hipLaunchKernelGGL((preshuffle_scale_gfx1250_kernel), + dim3(grid_size), + dim3(block_size), + 0, + stream, + src, + dst, + actual_rows, + output_rows, + KScale); + NVTE_CHECK_CUDA(hipGetLastError()); } -template -bool invoke_mx_grouped_gemm(const std::vector& descs, const GroupedGemmRunContext& ctx, const ck_tile::stream_config& stream_cfg) -{ - // check hardware WMMA support for the warp tile +template +bool invoke_mx_grouped_gemm(const std::vector &descs, + const GroupedGemmRunContext &ctx, + const ck_tile::stream_config &stream_cfg) { + // Check hardware WMMA support for the warp tile. static constexpr bool has_wmma_support = ck_tile::has_wmma_traits_v; + AType, + BType, + AccType, + MXFP8GemmConfig::M_Warp_Tile, + MXFP8GemmConfig::N_Warp_Tile, + MXFP8GemmConfig::K_Warp_Tile>; NVTE_CHECK(has_wmma_support, - "ck_tile_mx_grouped_gemm: unsupported gfx125 WMMA traits for " - "AType/BType/AccType with warp tile shape ", - MXFP8GemmConfig::M_Warp_Tile, "x", - MXFP8GemmConfig::N_Warp_Tile, "x", - MXFP8GemmConfig::K_Warp_Tile); + "ck_tile_mx_grouped_gemm: unsupported gfx125 WMMA traits for " + "AType/BType/AccType with warp tile shape ", + MXFP8GemmConfig::M_Warp_Tile, "x", + MXFP8GemmConfig::N_Warp_Tile, "x", + MXFP8GemmConfig::K_Warp_Tile); using CLayout = RowMajor; constexpr bool preshuffle = false; @@ -194,50 +197,52 @@ bool invoke_mx_grouped_gemm(const std::vector& descs, con CLayout, TransposeC, StructuredSparsity, - false,//Persistent + false, // Persistent NumWaveGroup, preshuffle>; using UniversalGemmProblem = - ck_tile::MxGemmPipelineProblem; + /* Make pipeline selective. */ + using GemmPipeline = + typename MxGemmPipelineTypeSelector< + PipelineType, + UniversalGemmProblem>::pipeline; + + using GemmEpilogue = ck_tile::TdmEpilogue< + ck_tile::CShuffleEpilogueProblem, // DsDataType float, - GemmShape, - GemmUniversalTraits, - ck_tile::GemmPipelineScheduler::Intrawave, - ck_tile::element_wise::PassThrough, + CType, + ck_tile::tuple<>, // DsLayout + CLayout, ck_tile::element_wise::PassThrough, - AType, - BType, - AScaleType, - BScaleType>; - /* make pipeline selective */ - using GemmPipeline = - typename MxGemmPipelineTypeSelector::pipeline; - using GemmEpilogue = ck_tile::TdmEpilogue< - ck_tile::CShuffleEpilogueProblem,//DsDataType - float, - CType, - ck_tile::tuple<>,//DsLayout - CLayout, - ck_tile::element_wise::PassThrough, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - MXFP8GemmConfig::M_Warp, - MXFP8GemmConfig::N_Warp, - MXFP8GemmConfig::M_Warp_Tile, - MXFP8GemmConfig::N_Warp_Tile, - MXFP8GemmConfig::K_Warp_Tile, - UniversalGemmProblem::TransposeC, - 1, /*kNumWaveGroups_*/ - false, /*FixedVectorSize_*/ - 1, /*VectorSizeC_*/ - 1, /*BlockedXDLN_PerWarp_*/ - DoubleSmemBuffer, /*DoubleSmemBuffer*/ - AType, /*AType_*/ - BType /*BType_*/>>; + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + MXFP8GemmConfig::M_Warp, + MXFP8GemmConfig::N_Warp, + MXFP8GemmConfig::M_Warp_Tile, + MXFP8GemmConfig::N_Warp_Tile, + MXFP8GemmConfig::K_Warp_Tile, + UniversalGemmProblem::TransposeC, + 1, /* kNumWaveGroups_ */ + false, /* FixedVectorSize_ */ + 1, /* VectorSizeC_ */ + 1, /* BlockedXDLN_PerWarp_ */ + DoubleSmemBuffer, /* DoubleSmemBuffer */ + AType, /* AType_ */ + BType /* BType_ */>>; using Kernel = ck_tile::MxGroupedGemmKernel; if (!has_sufficient_workspace(ctx)) { @@ -245,10 +250,9 @@ bool invoke_mx_grouped_gemm(const std::vector& descs, con } auto kargs = Kernel::MakeKargs(descs); - if(!Kernel::IsSupportedArgument(kargs)) - { + if (!Kernel::IsSupportedArgument(kargs)) { NVTE_WARN("ck_tile_mx_grouped_gemm: CK_Tile kernel arguments not supported for this config. " - "Falling back."); + "Falling back."); return false; } const dim3 blocks = Kernel::BlockSize(); @@ -259,25 +263,25 @@ bool invoke_mx_grouped_gemm(const std::vector& descs, con hipMemcpyHostToDevice, ctx.stream)); ck_tile::ignore = ck_tile::launch_kernel( - stream_cfg, ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, - ck_tile::cast_pointer_to_constant_address_space(ctx.workspace), - kargs.size())); + stream_cfg, ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, + ck_tile::cast_pointer_to_constant_address_space(ctx.workspace), + kargs.size())); return true; }); }); return false; } -bool ck_tile_mx_grouped_gemm(const NVTETensor* A, - const NVTETensor* B, - NVTETensor* D, - int group_num, - bool transA, - bool transB, - NVTETensor* workspace, - bool accumulate,//ignored for now - hipStream_t stream) { +bool ck_tile_mx_grouped_gemm(const NVTETensor *A, + const NVTETensor *B, + NVTETensor *D, + int group_num, + bool transA, + bool transB, + NVTETensor *workspace, + bool accumulate, // ignored for now + hipStream_t stream) { if (detect_gpu_arch() != GPUArch::GFX1250) { NVTE_WARN("ck_tile_mx_grouped_gemm: only supported on gfx1250. Falling back."); @@ -290,8 +294,8 @@ bool ck_tile_mx_grouped_gemm(const NVTETensor* A, // Normalize input mats // I.e., swap A and B, as well as transa and transb. - const NVTETensor* A_use = B; - const NVTETensor* B_use = A; + const NVTETensor *A_use = B; + const NVTETensor *B_use = A; bool transA_use = transB; bool transB_use = transA; @@ -302,18 +306,18 @@ bool ck_tile_mx_grouped_gemm(const NVTETensor* A, const bool use_a_colwise_data = transA_use; const bool use_b_colwise_data = !transB_use; - Tensor* A0_te = convertNVTETensorCheck(A_use[0]); - Tensor* B0_te = convertNVTETensorCheck(B_use[0]); + Tensor *A0_te = convertNVTETensorCheck(A_use[0]); + Tensor *B0_te = convertNVTETensorCheck(B_use[0]); // Validate scale type / data type combination. // Expected input data format: fp8/bf8 (e4m3/e5m2) // Expected scale data format: e8m0 - const auto* D0 = convertNVTETensorCheck(D[0]); + const auto *D0 = convertNVTETensorCheck(D[0]); - const auto& A0_data = use_a_colwise_data ? A0_te->columnwise_data : A0_te->data; - const auto& B0_data = use_b_colwise_data ? B0_te->columnwise_data : B0_te->data; - const auto& A0_scale = use_a_colwise_data ? A0_te->columnwise_scale_inv : A0_te->scale_inv; - const auto& B0_scale = use_b_colwise_data ? B0_te->columnwise_scale_inv : B0_te->scale_inv; + const auto &A0_data = use_a_colwise_data ? A0_te->columnwise_data : A0_te->data; + const auto &B0_data = use_b_colwise_data ? B0_te->columnwise_data : B0_te->data; + const auto &A0_scale = use_a_colwise_data ? A0_te->columnwise_scale_inv : A0_te->scale_inv; + const auto &B0_scale = use_b_colwise_data ? B0_te->columnwise_scale_inv : B0_te->scale_inv; NVTE_CHECK(A0_data.dptr != nullptr, "ck_tile_mx_grouped_gemm: A[0] data is not initialized"); @@ -327,12 +331,12 @@ bool ck_tile_mx_grouped_gemm(const NVTETensor* A, const auto a_scale_dtype = A0_scale.dtype; const auto b_scale_dtype = B0_scale.dtype; NVTE_CHECK(a_scale_dtype == DType::kFloat8E8M0, - "ck_tile_mx_grouped_gemm: A scale_inv dtype must be Float8E8M0, got ", - static_cast(a_scale_dtype)); + "ck_tile_mx_grouped_gemm: A scale_inv dtype must be Float8E8M0, got ", + static_cast(a_scale_dtype)); NVTE_CHECK(b_scale_dtype == DType::kFloat8E8M0, - "ck_tile_mx_grouped_gemm: B scale_inv dtype must be Float8E8M0, got ", - static_cast(b_scale_dtype)); + "ck_tile_mx_grouped_gemm: B scale_inv dtype must be Float8E8M0, got ", + static_cast(b_scale_dtype)); const auto a_dtype = A0_data.dtype; const auto b_dtype = B0_data.dtype; @@ -343,28 +347,28 @@ bool ck_tile_mx_grouped_gemm(const NVTETensor* A, using AScaleType = ck_tile::e8m0_t; using BScaleType = ck_tile::e8m0_t; - void* ws_ptr = nullptr; + void *ws_ptr = nullptr; size_t ws_bytes = 0; if (workspace) { - auto* ws_te = convertNVTETensorCheck(*workspace); + auto *ws_te = convertNVTETensorCheck(*workspace); ws_ptr = ws_te->data.dptr; ws_bytes = ws_te->data.numel() * typeToSize(ws_te->data.dtype); } GroupedGemmRunContext ctx{ - .A = A_use, - .B = B_use, - .D = D, - .N = 0, - .group_num = group_num, - .transA = transA_use, - .transB = transB_use, - .workspace = ws_ptr, - .workspace_bytes = ws_bytes, - .stream = stream, - .use_a_columnwise_data = use_a_colwise_data, - .use_b_columnwise_data = use_b_colwise_data, - .accumulate = false, + .A = A_use, + .B = B_use, + .D = D, + .N = 0, + .group_num = group_num, + .transA = transA_use, + .transB = transB_use, + .workspace = ws_ptr, + .workspace_bytes = ws_bytes, + .stream = stream, + .use_a_columnwise_data = use_a_colwise_data, + .use_b_columnwise_data = use_b_colwise_data, + .accumulate = false, }; const ck_tile::stream_config s{ctx.stream}; @@ -378,161 +382,160 @@ bool ck_tile_mx_grouped_gemm(const NVTETensor* A, // Carve regions from the end of the workspace for mxfp8 scales. // Layout: [CK kargs workspace ... | a_scales (i) | b_scales (i) | ... | a_scales (group_num-1) | b_scales (group_num-1)] constexpr size_t kScaleWorkspaceAlign = 256; - uint8_t* scale_workspace_base = reinterpret_cast(ctx.workspace); + uint8_t *scale_workspace_base = reinterpret_cast(ctx.workspace); size_t scale_workspace_end = (ctx.workspace_bytes / kScaleWorkspaceAlign) * kScaleWorkspaceAlign; for (int i = 0; i < group_num; i++) { - const transformer_engine::Tensor* const A_te = - transformer_engine::convertNVTETensorCheck(ctx.A[i]); - const transformer_engine::Tensor* const B_te = - transformer_engine::convertNVTETensorCheck(ctx.B[i]); - transformer_engine::Tensor* D_te = - transformer_engine::convertNVTETensorCheck(ctx.D[i]); - - const auto& a = ctx.use_a_columnwise_data ? A_te->columnwise_data : A_te->data; - const auto& b = ctx.use_b_columnwise_data ? B_te->columnwise_data : B_te->data; - const auto& d = D_te->data; - const auto& a_scales = - ctx.use_a_columnwise_data ? A_te->columnwise_scale_inv : A_te->scale_inv; - const auto& b_scales = - ctx.use_b_columnwise_data ? B_te->columnwise_scale_inv : B_te->scale_inv; - - int64_t Ad0 = 0, Ad1 = 0, Bd0 = 0, Bd1 = 0, Dd0 = 0, Dd1 = 0; - - if (!get_flat_2d_dims(*A_te, Ad0, Ad1)) { - NVTE_ERROR("ck_tile_mx_grouped_gemm: expected rank>=2 for normalized A in group ", i); - } + const transformer_engine::Tensor *const A_te = + transformer_engine::convertNVTETensorCheck(ctx.A[i]); + const transformer_engine::Tensor *const B_te = + transformer_engine::convertNVTETensorCheck(ctx.B[i]); + transformer_engine::Tensor *D_te = + transformer_engine::convertNVTETensorCheck(ctx.D[i]); + + const auto &a = ctx.use_a_columnwise_data ? A_te->columnwise_data : A_te->data; + const auto &b = ctx.use_b_columnwise_data ? B_te->columnwise_data : B_te->data; + const auto &d = D_te->data; + const auto &a_scales = + ctx.use_a_columnwise_data ? A_te->columnwise_scale_inv : A_te->scale_inv; + const auto &b_scales = + ctx.use_b_columnwise_data ? B_te->columnwise_scale_inv : B_te->scale_inv; + + int64_t Ad0 = 0, Ad1 = 0, Bd0 = 0, Bd1 = 0, Dd0 = 0, Dd1 = 0; + + if (!get_flat_2d_dims(*A_te, Ad0, Ad1)) { + NVTE_ERROR("ck_tile_mx_grouped_gemm: expected rank>=2 for normalized A in group ", i); + } - if (!get_flat_2d_dims(*B_te, Bd0, Bd1)) { - NVTE_ERROR("ck_tile_mx_grouped_gemm: expected rank>=2 for normalized B in group ", i); - } + if (!get_flat_2d_dims(*B_te, Bd0, Bd1)) { + NVTE_ERROR("ck_tile_mx_grouped_gemm: expected rank>=2 for normalized B in group ", i); + } - if (!get_flat_2d_dims(*D_te, Dd0, Dd1)) { - NVTE_ERROR("ck_tile_mx_grouped_gemm: expected rank>=2 for normalized D in group ", i); - } - if (a.dptr == nullptr || b.dptr == nullptr || a_scales.dptr == nullptr || - b_scales.dptr == nullptr) { - NVTE_ERROR("ck_tile_mx_grouped_gemm: effective A/B data or scale_inv is missing."); - } - if (a_scales.shape.size() != 2 || b_scales.shape.size() != 2) { - NVTE_ERROR("ck_tile_mx_grouped_gemm: expected effective A/B scale_inv tensors to be rank-2."); - } + if (!get_flat_2d_dims(*D_te, Dd0, Dd1)) { + NVTE_ERROR("ck_tile_mx_grouped_gemm: expected rank>=2 for normalized D in group ", i); + } + if (a.dptr == nullptr || b.dptr == nullptr || a_scales.dptr == nullptr || + b_scales.dptr == nullptr) { + NVTE_ERROR("ck_tile_mx_grouped_gemm: effective A/B data or scale_inv is missing."); + } + if (a_scales.shape.size() != 2 || b_scales.shape.size() != 2) { + NVTE_ERROR("ck_tile_mx_grouped_gemm: expected effective A/B scale_inv tensors to be rank-2."); + } - const int64_t M = ctx.transA ? Ad1 : Ad0; - const int64_t K = ctx.transA ? Ad0 : Ad1; - const int64_t N = ctx.transB ? Bd0 : Bd1; - const int64_t Kb = ctx.transB ? Bd1 : Bd0; - if (K % ScaleBlockSize != 0) { - NVTE_ERROR("ck_tile_mx_grouped_gemm: K must be a multiple of ScaleBlockSize for MX GEMM", i); - } - const int KScale = static_cast(K / ScaleBlockSize); - if (Kb != K) { - NVTE_ERROR("ck_tile_mx_grouped_gemm: K mismatch between A and B in group ", i, - ". op(A)=", M, "x", K, ", op(B)=", Kb, "x", N); - } - if (Dd0 != M || Dd1 != N) { - NVTE_ERROR("ck_tile_mx_grouped_gemm: D shape mismatch in group ", i, - ". D=", Dd0, "x", Dd1, ", expected=", M, "x", N); - } + const int64_t M = ctx.transA ? Ad1 : Ad0; + const int64_t K = ctx.transA ? Ad0 : Ad1; + const int64_t N = ctx.transB ? Bd0 : Bd1; + const int64_t Kb = ctx.transB ? Bd1 : Bd0; + if (K % ScaleBlockSize != 0) { + NVTE_ERROR("ck_tile_mx_grouped_gemm: K must be a multiple of ScaleBlockSize for MX GEMM", i); + } + const int KScale = static_cast(K / ScaleBlockSize); + if (Kb != K) { + NVTE_ERROR("ck_tile_mx_grouped_gemm: K mismatch between A and B in group ", i, + ". op(A)=", M, "x", K, ", op(B)=", Kb, "x", N); + } + if (Dd0 != M || Dd1 != N) { + NVTE_ERROR("ck_tile_mx_grouped_gemm: D shape mismatch in group ", i, + ". D=", Dd0, "x", Dd1, ", expected=", M, "x", N); + } - const ck_tile::index_t stride_A = static_cast(Ad1); - const ck_tile::index_t stride_B = static_cast(Bd1); - const ck_tile::index_t stride_E = static_cast(Dd1); + const ck_tile::index_t stride_A = static_cast(Ad1); + const ck_tile::index_t stride_B = static_cast(Bd1); + const ck_tile::index_t stride_E = static_cast(Dd1); - // Pre-shuffle scale buffers for the hardware. - const int a_scale_actual_rows = static_cast(M); - const int a_scale_output_rows = + // Pre-shuffle scale buffers for the hardware. + const int a_scale_actual_rows = static_cast(M); + const int a_scale_output_rows = ck_tile::integer_least_multiple( - static_cast(M), - static_cast(GroupedGemKernelParam_Wmma::M_Warp_Tile)); - const int b_scale_actual_rows = static_cast(N); - const int b_scale_output_rows = static_cast(N); - const size_t a_scale_shuffled_bytes = - static_cast(a_scale_output_rows) * - static_cast(KScale) * - sizeof(AScaleType); - const size_t b_scale_shuffled_bytes = - static_cast(b_scale_output_rows) * - static_cast(KScale) * - sizeof(BScaleType); - const size_t scale_pair_bytes = - a_scale_shuffled_bytes + b_scale_shuffled_bytes; - - scale_workspace_end = - (scale_workspace_end / kScaleWorkspaceAlign) * kScaleWorkspaceAlign; - - NVTE_CHECK(scale_workspace_end >= scale_pair_bytes, - "ck_tile_mx_grouped_gemm: insufficient workspace for shuffled MXFP8 scales. " - "Need current group scale bytes=", scale_pair_bytes, - ", available workspace bytes=", scale_workspace_end, - ". Increase the grouped GEMM workspace size."); - - scale_workspace_end -= scale_pair_bytes; - uint8_t* scale_pair_ptr = scale_workspace_base + scale_workspace_end; - - void* a_scale_shuffled_ptr = scale_pair_ptr; - void* b_scale_shuffled_ptr = scale_pair_ptr + a_scale_shuffled_bytes; - - // CK expects canonical pre-shuffled scale buffers laid out as - // A: [M, KScale] and B: [N, KScale], independent of A/B data layouts. - // TE rowwise MXFP8 scale_inv is [rows, KScale] and can be read with - // KStride=true. TE columnwise_scale_inv is [KScale, rows] and must be - // read with KStride=false before writing CK's canonical shuffled layout. - if (ctx.use_a_columnwise_data) { - preShuffleScaleBuffer_gfx1250( - reinterpret_cast(a_scales.dptr), - reinterpret_cast(a_scale_shuffled_ptr), - a_scale_actual_rows, - a_scale_output_rows, - KScale, - stream); - } else { - preShuffleScaleBuffer_gfx1250( - reinterpret_cast(a_scales.dptr), - reinterpret_cast(a_scale_shuffled_ptr), - a_scale_actual_rows, - a_scale_output_rows, - KScale, - stream); - } + static_cast(M), + static_cast(GroupedGemKernelParam_Wmma::M_Warp_Tile)); + const int b_scale_actual_rows = static_cast(N); + const int b_scale_output_rows = static_cast(N); + const size_t a_scale_shuffled_bytes = + static_cast(a_scale_output_rows) * + static_cast(KScale) * + sizeof(AScaleType); + const size_t b_scale_shuffled_bytes = + static_cast(b_scale_output_rows) * + static_cast(KScale) * + sizeof(BScaleType); + const size_t scale_pair_bytes = + a_scale_shuffled_bytes + b_scale_shuffled_bytes; + scale_workspace_end = + (scale_workspace_end / kScaleWorkspaceAlign) * kScaleWorkspaceAlign; + + NVTE_CHECK(scale_workspace_end >= scale_pair_bytes, + "ck_tile_mx_grouped_gemm: insufficient workspace for shuffled MXFP8 scales. " + "Need current group scale bytes=", scale_pair_bytes, + ", available workspace bytes=", scale_workspace_end, + ". Increase the grouped GEMM workspace size."); + + scale_workspace_end -= scale_pair_bytes; + uint8_t *scale_pair_ptr = scale_workspace_base + scale_workspace_end; + + void *a_scale_shuffled_ptr = scale_pair_ptr; + void *b_scale_shuffled_ptr = scale_pair_ptr + a_scale_shuffled_bytes; + + // CK expects canonical pre-shuffled scale buffers laid out as + // A: [M, KScale] and B: [N, KScale], independent of A/B data layouts. + // TE rowwise MXFP8 scale_inv is [rows, KScale] and can be read with + // KStride=true. TE columnwise_scale_inv is [KScale, rows] and must be + // read with KStride=false before writing CK's canonical shuffled layout. + if (ctx.use_a_columnwise_data) { + preShuffleScaleBuffer_gfx1250( + reinterpret_cast(a_scales.dptr), + reinterpret_cast(a_scale_shuffled_ptr), + a_scale_actual_rows, + a_scale_output_rows, + KScale, + stream); + } else { + preShuffleScaleBuffer_gfx1250( + reinterpret_cast(a_scales.dptr), + reinterpret_cast(a_scale_shuffled_ptr), + a_scale_actual_rows, + a_scale_output_rows, + KScale, + stream); + } - if (ctx.use_b_columnwise_data) { - preShuffleScaleBuffer_gfx1250( - reinterpret_cast(b_scales.dptr), - reinterpret_cast(b_scale_shuffled_ptr), - b_scale_actual_rows, - b_scale_output_rows, - KScale, - stream); - } else { - preShuffleScaleBuffer_gfx1250( - reinterpret_cast(b_scales.dptr), - reinterpret_cast(b_scale_shuffled_ptr), - b_scale_actual_rows, - b_scale_output_rows, - KScale, - stream); - } - descs.emplace_back(mx_grouped_gemm_kargs( - a.dptr, - a_scale_shuffled_ptr, - b.dptr, - b_scale_shuffled_ptr, - {/*ds_ptr*/}, - d.dptr, - 1,//kbatch - M, - N, - K, - stride_A, - stride_B, - {/*stride_Ds*/}, - stride_E)); + if (ctx.use_b_columnwise_data) { + preShuffleScaleBuffer_gfx1250( + reinterpret_cast(b_scales.dptr), + reinterpret_cast(b_scale_shuffled_ptr), + b_scale_actual_rows, + b_scale_output_rows, + KScale, + stream); + } else { + preShuffleScaleBuffer_gfx1250( + reinterpret_cast(b_scales.dptr), + reinterpret_cast(b_scale_shuffled_ptr), + b_scale_actual_rows, + b_scale_output_rows, + KScale, + stream); + } + descs.emplace_back(mx_grouped_gemm_kargs( + a.dptr, + a_scale_shuffled_ptr, + b.dptr, + b_scale_shuffled_ptr, + {/*ds_ptr*/}, + d.dptr, + 1, // kbatch + M, + N, + K, + stride_A, + stride_B, + {/*stride_Ds*/}, + stride_E)); } ctx.workspace_bytes = scale_workspace_end; - // invoke gemm + // Invoke the GEMM. bool ok = false; TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(a_dtype, a_te_type, { using AType = typename TETypeToCKType::type; @@ -542,7 +545,7 @@ bool ck_tile_mx_grouped_gemm(const NVTETensor* A, using CType = typename TETypeToCKType::type; ok = invoke_mx_grouped_gemm(descs,ctx,s); + AScaleType, BScaleType>(descs, ctx, s); }); }); }); @@ -552,15 +555,15 @@ bool ck_tile_mx_grouped_gemm(const NVTETensor* A, } // namespace grouped_gemm } // namespace transformer_engine -bool ck_tile_mx_grouped_gemm(const NVTETensor* A, - const NVTETensor* B, - NVTETensor* D, +bool ck_tile_mx_grouped_gemm(const NVTETensor *A, + const NVTETensor *B, + NVTETensor *D, int group_num, bool transA, bool transB, - NVTETensor* workspace, + NVTETensor *workspace, bool accumulate, hipStream_t stream) { return transformer_engine::grouped_gemm::ck_tile_mx_grouped_gemm( - A, B, D, group_num, transA, transB, workspace, accumulate, stream); + A, B, D, group_num, transA, transB, workspace, accumulate, stream); } From f4c97caa048ebcc566c1a4239ab96647dc394f66 Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Wed, 10 Jun 2026 20:45:03 +0000 Subject: [PATCH 43/47] address pr comment: explicitly mention purpose of ck gfx1250 swizzle over existing implementation --- .../common/gemm/ck_grouped_gemm/ck_mx_grouped_gemm.cpp | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_mx_grouped_gemm.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_mx_grouped_gemm.cpp index c562c131b..e942f0e85 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_mx_grouped_gemm.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_mx_grouped_gemm.cpp @@ -53,13 +53,18 @@ struct GroupedGemKernelParam_Wmma { // gfx1250 scale preshuffle. // +// Unlike the existing MXFP8 GEMM scale swizzle defined in: +// transformer_engine/common/swizzle/swizzle.cu +// +// CK gfx1250 WMMA kernels expect scales in the layout below: +// // Input scales are logically [MN, KScale] // // The output layout groups KScale into tiles of 4 (= 128 / ScaleBlockSize) // and additionally blocks M into chunks of 32 rows: // -// [MN, KScale] -// -> [MN/32, KScale/4, 32, 4] +// [MN, KScale] +// -> [MN/32, KScale/4, 32, 4] // // For A scales, rows=M and output_rows is M padded to M_Warp_Tile. // For B scales, rows=N and output_rows is currently N. From bdc6b4eb78ce3317e24e53e1ffc6b1308050df9a Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Wed, 10 Jun 2026 21:20:07 +0000 Subject: [PATCH 44/47] address PR comments --- .../gemm/ck_grouped_gemm/ck_grouped_gemm.h | 20 ++++++-- .../ck_grouped_gemm/ck_mx_grouped_gemm.cpp | 51 +++++-------------- .../ck_grouped_gemm/ck_mx_grouped_gemm.hpp | 16 ------ .../common/gemm/cublaslt_gemm.cu | 1 - 4 files changed, 30 insertions(+), 58 deletions(-) delete mode 100644 transformer_engine/common/gemm/ck_grouped_gemm/ck_mx_grouped_gemm.hpp diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.h b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.h index 97b4cfd88..cce56ff9b 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.h +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.h @@ -4,12 +4,24 @@ * License for AMD contributions = MIT. See LICENSE for more information ************************************************************************/ -bool ck_tile_grouped_gemm(const NVTETensor* A, - const NVTETensor* B, - NVTETensor* D, +#pragma once + +bool ck_tile_grouped_gemm(const NVTETensor *A, + const NVTETensor *B, + NVTETensor *D, int group_num, bool transA, bool transB, - NVTETensor* workspace, + NVTETensor *workspace, bool accumulate, hipStream_t stream); + +bool ck_tile_mx_grouped_gemm(const NVTETensor *A, + const NVTETensor *B, + NVTETensor *D, + int group_num, + bool transA, + bool transB, + NVTETensor *workspace, + bool accumulate, + hipStream_t stream); diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_mx_grouped_gemm.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_mx_grouped_gemm.cpp index e942f0e85..1b55236ce 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_mx_grouped_gemm.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_mx_grouped_gemm.cpp @@ -428,10 +428,10 @@ bool ck_tile_mx_grouped_gemm(const NVTETensor *A, NVTE_ERROR("ck_tile_mx_grouped_gemm: expected effective A/B scale_inv tensors to be rank-2."); } - const int64_t M = ctx.transA ? Ad1 : Ad0; - const int64_t K = ctx.transA ? Ad0 : Ad1; - const int64_t N = ctx.transB ? Bd0 : Bd1; - const int64_t Kb = ctx.transB ? Bd1 : Bd0; + const size_t M = ctx.transA ? Ad1 : Ad0; + const size_t K = ctx.transA ? Ad0 : Ad1; + const size_t N = ctx.transB ? Bd0 : Bd1; + const size_t Kb = ctx.transB ? Bd1 : Bd0; if (K % ScaleBlockSize != 0) { NVTE_ERROR("ck_tile_mx_grouped_gemm: K must be a multiple of ScaleBlockSize for MX GEMM", i); } @@ -491,52 +491,29 @@ bool ck_tile_mx_grouped_gemm(const NVTETensor *A, preShuffleScaleBuffer_gfx1250( reinterpret_cast(a_scales.dptr), reinterpret_cast(a_scale_shuffled_ptr), - a_scale_actual_rows, - a_scale_output_rows, - KScale, - stream); + a_scale_actual_rows, a_scale_output_rows, KScale, stream); } else { preShuffleScaleBuffer_gfx1250( reinterpret_cast(a_scales.dptr), reinterpret_cast(a_scale_shuffled_ptr), - a_scale_actual_rows, - a_scale_output_rows, - KScale, - stream); + a_scale_actual_rows, a_scale_output_rows, KScale, stream); } if (ctx.use_b_columnwise_data) { preShuffleScaleBuffer_gfx1250( reinterpret_cast(b_scales.dptr), reinterpret_cast(b_scale_shuffled_ptr), - b_scale_actual_rows, - b_scale_output_rows, - KScale, - stream); + b_scale_actual_rows, b_scale_output_rows, KScale, stream); } else { preShuffleScaleBuffer_gfx1250( reinterpret_cast(b_scales.dptr), reinterpret_cast(b_scale_shuffled_ptr), - b_scale_actual_rows, - b_scale_output_rows, - KScale, - stream); + b_scale_actual_rows, b_scale_output_rows, KScale, stream); } descs.emplace_back(mx_grouped_gemm_kargs( - a.dptr, - a_scale_shuffled_ptr, - b.dptr, - b_scale_shuffled_ptr, - {/*ds_ptr*/}, - d.dptr, - 1, // kbatch - M, - N, - K, - stride_A, - stride_B, - {/*stride_Ds*/}, - stride_E)); + a.dptr, a_scale_shuffled_ptr, b.dptr, b_scale_shuffled_ptr, + {/*ds_ptr*/}, d.dptr, 1, // kbatch + M, N, K, stride_A, stride_B, {/*stride_Ds*/}, stride_E)); } ctx.workspace_bytes = scale_workspace_end; @@ -551,9 +528,9 @@ bool ck_tile_mx_grouped_gemm(const NVTETensor *A, ok = invoke_mx_grouped_gemm(descs, ctx, s); - }); - }); - }); + }); // NOLINT(*) + }); // NOLINT(*) + }); // NOLINT(*) return ok; } diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_mx_grouped_gemm.hpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_mx_grouped_gemm.hpp deleted file mode 100644 index 96d3cd11b..000000000 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_mx_grouped_gemm.hpp +++ /dev/null @@ -1,16 +0,0 @@ -/************************************************************************* - * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. - * - * License for AMD contributions = MIT. See LICENSE for more information - ************************************************************************/ - -bool ck_tile_mx_grouped_gemm(const NVTETensor* A, - const NVTETensor* B, - NVTETensor* D, - int group_num, - bool transA, - bool transB, - NVTETensor* workspace, - bool accumulate, - hipStream_t stream); - diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index cface8ebd..26c76d3c3 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -33,7 +33,6 @@ #include "./cutlass_grouped_gemm.cuh" #else #include "ck_grouped_gemm/ck_grouped_gemm.h" -#include "ck_grouped_gemm/ck_mx_grouped_gemm.hpp" #endif #ifndef __HIP_PLATFORM_AMD__ From 457bbc1a5c806ad188f9062c89a4e7147fa549b9 Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Wed, 10 Jun 2026 21:25:40 +0000 Subject: [PATCH 45/47] inline mxfp8_gemm bool into if statement --- transformer_engine/common/gemm/cublaslt_gemm.cu | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 26c76d3c3..3e87d0544 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -1204,10 +1204,9 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor #ifdef __HIP_PLATFORM_AMD__ true) { auto *inputA = transformer_engine::convertNVTETensorCheck(A[0]); - const bool mxfp8_gemm = transformer_engine::is_mxfp8_scaling(inputA->scaling_mode); bool handled_by_ck = false; - if (mxfp8_gemm) { + if (transformer_engine::is_mxfp8_scaling(inputA->scaling_mode)) { handled_by_ck = ck_tile_mx_grouped_gemm( A, B, D, num_gemms, transa, transb, workspace, accumulate, stream); } else { From 19151f4053f351eefa9c9cd001fc3dae258a1cd2 Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Wed, 10 Jun 2026 21:28:34 +0000 Subject: [PATCH 46/47] address nit --- .../common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h index 8fa8c75fe..7a5dbe63a 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h @@ -84,7 +84,7 @@ enum class GPUArch { }; static inline GPUArch detect_gpu_arch() { - switch (cuda::sm_arch(0)) { + switch (cuda::sm_arch()) { case 94: return GPUArch::GFX942; case 95: From 231c91654b61d8718cff14aa6b04ef6103e0fbed Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Wed, 10 Jun 2026 21:50:16 +0000 Subject: [PATCH 47/47] add warn fallback for mxfp8 ck --- .../ck_grouped_gemm/ck_mx_grouped_gemm.cpp | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_mx_grouped_gemm.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_mx_grouped_gemm.cpp index 1b55236ce..96726abb8 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_mx_grouped_gemm.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_mx_grouped_gemm.cpp @@ -149,7 +149,8 @@ template bool invoke_mx_grouped_gemm(const std::vector &descs, const GroupedGemmRunContext &ctx, - const ck_tile::stream_config &stream_cfg) { + const ck_tile::stream_config &stream_cfg, + bool warn_fallback) { // Check hardware WMMA support for the warp tile. static constexpr bool has_wmma_support = ck_tile::has_wmma_traits_v &descs, auto kargs = Kernel::MakeKargs(descs); if (!Kernel::IsSupportedArgument(kargs)) { - NVTE_WARN("ck_tile_mx_grouped_gemm: CK_Tile kernel arguments not supported for this config. " - "Falling back."); + if (warn_fallback) { + NVTE_WARN("ck_tile_mx_grouped_gemm: CK_Tile kernel arguments not supported for this config. " + "Falling back."); + } return false; } const dim3 blocks = Kernel::BlockSize(); @@ -288,8 +291,13 @@ bool ck_tile_mx_grouped_gemm(const NVTETensor *A, bool accumulate, // ignored for now hipStream_t stream) { + const bool warn_fallback = + getenv("NVTE_CUTLASS_GROUPED_GEMM_WARN_FALLBACK", false); + if (detect_gpu_arch() != GPUArch::GFX1250) { - NVTE_WARN("ck_tile_mx_grouped_gemm: only supported on gfx1250. Falling back."); + if (warn_fallback) { + NVTE_WARN("ck_tile_mx_grouped_gemm: only supported on gfx1250. Falling back."); + } return false; } @@ -527,7 +535,7 @@ bool ck_tile_mx_grouped_gemm(const NVTETensor *A, using CType = typename TETypeToCKType::type; ok = invoke_mx_grouped_gemm(descs, ctx, s); + AScaleType, BScaleType>(descs, ctx, s, warn_fallback); }); // NOLINT(*) }); // NOLINT(*) }); // NOLINT(*)