CK MXFP8 Group Gemm gfx1250 Enablement #613

Open
aris134 wants to merge 50 commits into dev from amartin/ck-mxfp8-group-gemm-gfx1250-clean

Conversation

aris134 commented Jun 8, 2026

Contributor

Description

Integrates the CK MXFP8 Group GEMM pipeline into TE.

Fixes https://github.com/ROCm/frameworks-internal/issues/16039

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Adds a new submodule, 3rdparty/rocm_libraries, with a sparse checkout of projects/composablekernel. It is needed for the CK MXFP8 group GEMM definitions.
  • Adds CK MXFP8 group GEMM integration into TE, with run-time arch detection for gfx1250 support.
  • Adds C++ tests (tests/cpp/operator/test_ck_grouped_mxfp8.cu). Note that the PyTorch grouped linear CK test coverage already includes MXFP8 (set NVTE_ROCM_ENABLE_MXFP8=1 in addition to NVTE_USE_CUTLASS_GROUPED_GEMM=1).

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

aris134 and others added 28 commits May 6, 2026 14:18
…rash; remaining issue is numerical validation vs BF16 sequential reference.
aris134 self-assigned this Jun 9, 2026
aris134 added the ci-level 1 label Jun 9, 2026
aris134 marked this pull request as ready for review June 9, 2026 01:01
aris134 requested a review from ipanfilo June 9, 2026 01:01
Comment on lines +88 to +96
if (arch == 94) {
  return GPUArch::GFX942;
}
if (arch == 95) {
  return GPUArch::GFX950;
}
if (arch == 1250) {
  return GPUArch::GFX1250;
}

Contributor

Could this be a switch?

Contributor Author

Yeah that looks nicer, thanks. Done in f3ecda3

if (arch == 95) {
  return GPUArch::GFX950;
}
if (arch == 1250) {

Contributor

Should this be 125?

Contributor Author

Yeah I think you're right, thanks. Fixed in f3ecda3
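
The two review points above (use a switch, and match 125 rather than 1250) can be combined in a minimal sketch. The `GPUArch` enumerator names come from the diff; the helper name `arch_from_code`, the `Unknown` fallback, and the assumption that the arch query reports 125 for gfx1250 are illustrative and not confirmed by the code in this PR.

```cpp
#include <cassert>

// Sketch only: enumerators mirror the diff; Unknown is a hypothetical fallback.
enum class GPUArch { GFX942, GFX950, GFX1250, Unknown };

// Hypothetical helper that maps an integer arch code to GPUArch.
constexpr GPUArch arch_from_code(int arch) {
  switch (arch) {
    case 94:  return GPUArch::GFX942;
    case 95:  return GPUArch::GFX950;
    case 125: return GPUArch::GFX1250;  // assumed code for gfx1250, per the review thread
    default:  return GPUArch::Unknown;  // any unlisted code, including 1250
  }
}

// Compile-time checks of the mapping.
static_assert(arch_from_code(125) == GPUArch::GFX1250, "125 maps to GFX1250");
static_assert(arch_from_code(1250) == GPUArch::Unknown, "1250 is not matched");
```

An explicit `default` branch keeps unsupported architectures from silently falling through to a CK path.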

std::vector<mx_grouped_gemm_kargs> descs;
descs.reserve(group_num);

std::vector<std::unique_ptr<ck_tile::DeviceMem>> a_scale_shuffled_bufs;

Contributor

Does ck_tile::DeviceMem allocate new memory? Can we use a workspace here?

Contributor Author

Yes, we can use workspace here. Done in 94b0126
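
As a rough illustration of the workspace idea, the sketch below carves aligned sub-buffers out of one preallocated region instead of allocating a new buffer per group. The class name `WorkspaceCarver`, its interface, and the 256-byte default alignment are assumptions made for this example; the actual change in 94b0126 is not shown in this thread and may differ.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical bump allocator over a caller-provided workspace.
// It never frees; the owner of the workspace releases it as a whole.
class WorkspaceCarver {
 public:
  WorkspaceCarver(void* base, std::size_t bytes)
      : base_(static_cast<std::uint8_t*>(base)), size_(bytes) {}

  // Returns an aligned sub-buffer, or nullptr if the workspace is too small.
  // `alignment` must be a power of two.
  void* take(std::size_t bytes, std::size_t alignment = 256) {
    const std::uintptr_t addr = reinterpret_cast<std::uintptr_t>(base_) + offset_;
    const std::uintptr_t mask = static_cast<std::uintptr_t>(alignment) - 1;
    const std::uintptr_t aligned = (addr + mask) & ~mask;
    const std::size_t start = offset_ + static_cast<std::size_t>(aligned - addr);
    if (start > size_ || bytes > size_ - start) return nullptr;
    offset_ = start + bytes;
    return base_ + start;
  }

  // Bytes consumed so far, including alignment padding.
  std::size_t used() const { return offset_; }

 private:
  std::uint8_t* base_;
  std::size_t size_;
  std::size_t offset_ = 0;
};
```

With this pattern, the per-group shuffled scale buffers would be obtained by calling `take` once per group on a single workspace pointer, so no device allocation happens inside the GEMM call.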

};

template <typename ScaleType, ck_tile::index_t ScaleBlockSize, bool KStride>
__global__ void preshuffle_scale_gfx1250_kernel(const ScaleType* __restrict__ src,

Contributor

Is this the same shuffling as in #605? Maybe we can add a comment here.

Contributor Author

Not quite the same. The swizzle in #605 groups the scale-K dimension into tiles of 4, whereas this CK preshuffle additionally organizes scales into 32-row M blocks to match the layout expected by the CK gfx1250 WMMA kernel. I've added a comment to help clarify the layout in 479c509.

matthiasdiener Jun 10, 2026

Contributor

I think the comment you added goes in the right direction. I would additionally mention what you said here: that this is different from the other MXFP8 GEMM swizzling, and that it is the layout expected by the CK gfx1250 WMMA kernel.

Contributor Author

Done in f4c97ca
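
To make the description of the preshuffle more concrete, here is a host-side reference sketch of a blocked scale layout that uses 32-row M blocks and scale-K tiles of 4. Only those two sizes come from the thread. The ordering of the sub-indices, the function names, and the requirement that both dimensions divide evenly are assumptions for illustration, so this should not be read as the exact layout the CK gfx1250 WMMA kernel expects.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

constexpr std::size_t kMBlock = 32;  // rows per M block (from the review thread)
constexpr std::size_t kKTile = 4;    // scale-K entries per tile (from the review thread)

// Hypothetical destination index for element (m, k) of a row-major
// [M x KScale] scale matrix. The nesting order below is assumed.
inline std::size_t blocked_index(std::size_t m, std::size_t k, std::size_t k_scale) {
  const std::size_t k_tiles = k_scale / kKTile;
  const std::size_t m_block = m / kMBlock, m_in = m % kMBlock;
  const std::size_t k_tile = k / kKTile, k_in = k % kKTile;
  return ((m_block * k_tiles + k_tile) * kMBlock + m_in) * kKTile + k_in;
}

// Host reference shuffle. Requires M % 32 == 0 and KScale % 4 == 0;
// padding of ragged shapes is not handled in this sketch.
inline std::vector<std::uint8_t> preshuffle_reference(const std::vector<std::uint8_t>& src,
                                                      std::size_t M, std::size_t k_scale) {
  assert(M % kMBlock == 0 && k_scale % kKTile == 0 && src.size() == M * k_scale);
  std::vector<std::uint8_t> dst(src.size());
  for (std::size_t m = 0; m < M; ++m)
    for (std::size_t k = 0; k < k_scale; ++k)
      dst[blocked_index(m, k, k_scale)] = src[m * k_scale + k];
  return dst;
}
```

A reference like this is mainly useful in tests, where the output of a device kernel can be compared element by element against the host result.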

aris134 requested a review from matthiasdiener June 10, 2026 15:06

alextmagro left a comment

Contributor

Sorry, my review was left as pending, so some of my comments may have already been addressed. Thanks!

}

template <typename T>
static void fill_randn_cpu(Tensor* t, float scale, int seed) {

Contributor

Why not use our hipRAND generator in test_common?

Contributor Author

Good point. Changed it in 5b4b7fe

return cases;
}

static const std::vector<CaseConfig> kCases = make_cases();

Contributor

I think we should probably use seeds generated from test names, like the rest of the C++ tests.

Contributor Author

Yeah, it should now be consistent in 5b4b7fe
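
One way to derive a reproducible seed from a test name is to hash the name with a fixed, platform-independent function. The sketch below uses 32-bit FNV-1a; the function name `seed_from_name` is hypothetical, and the helper that the existing C++ tests actually use is not shown in this thread, so it may work differently.

```cpp
#include <cassert>
#include <cstdint>
#include <string_view>

// Hypothetical helper: 32-bit FNV-1a hash of the test name, used as an RNG seed.
// Unlike std::hash, the result is the same on every platform and toolchain.
constexpr std::uint32_t seed_from_name(std::string_view name) {
  std::uint32_t h = 2166136261u;  // FNV-1a 32-bit offset basis
  for (char c : name) {
    h ^= static_cast<std::uint8_t>(c);
    h *= 16777619u;  // FNV 32-bit prime
  }
  return h;
}
```

Each test case then gets a distinct but stable random stream, so a failure can be reproduced by rerunning only that case.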

#pragma once

#include <hip/hip_runtime.h>
#include "common/util/cuda_runtime.h"

Contributor

nit: this belongs after common headers

Contributor Author

Done in 68ed32a

#include "ck_tile/core.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/host/kernel_launch.hpp"

Contributor

nit: /host/ goes before /ops/, and /elementwise/ goes before /gemm/

Contributor Author

Done in 68ed32a

NVTE_ERROR("ck_tile_mx_grouped_gemm: expected effective A/B scale_inv tensors to be rank-2.");
}

const int64_t M = ctx.transA ? Ad1 : Ad0;

Contributor

I think these should be size_t, unless negative values are needed.

Contributor Author

Yeah that's fair. I changed that in bdc6b4e

KScale,
stream);
}
descs.emplace_back(mx_grouped_gemm_kargs(

Contributor

Another stylistic comment: many function calls are broken across lines with one parameter per line. I personally prefer a more compact style, with line breaks only as needed, especially when variable names are relatively short.

Contributor Author

Made some additional stylistic changes in bdc6b4e

ok = invoke_mx_grouped_gemm<GroupedGemKernelParam_Wmma,
AType, BType, CType,
AScaleType, BScaleType>(descs,ctx,s);
});

Contributor

We need // NOLINT(*) at the end of every TRANSFORMER_ENGINE_TYPE_SWITCH_* statement.

Contributor Author

Done in bdc6b4e

* License for AMD contributions = MIT. See LICENSE for more information
************************************************************************/

bool ck_tile_mx_grouped_gemm(const NVTETensor* A,

Contributor

Missing #pragma once, and maybe name the file .h instead of .hpp for consistency?

Contributor

On second thought, can we just add this to ck_grouped_gemm_common.h?

Contributor Author

Good catch. Since ck_grouped_gemm.h is meant to be the public API and ck_grouped_gemm_common.h is internal, I moved the declaration to ck_grouped_gemm.h, which removes the need for ck_mx_grouped_gemm.h. Changes made in bdc6b4e

}
cublas_path();
auto *inputA = transformer_engine::convertNVTETensorCheck(A[0]);
const bool mxfp8_gemm = transformer_engine::is_mxfp8_scaling(inputA->scaling_mode);

Contributor

This can probably be inlined into the if statement, since it is only used once.

Contributor Author

Done in 457bbc1


static constexpr ck_tile::index_t ScaleBlockSize = 32;

enum struct MxGemmPipelineType

Contributor

I do prefer K&R style, and we lean towards that in the codebase. Consider moving opening braces to the same line throughout, and maybe using post-increments and attaching references/pointers to the variable instead of the type.

Contributor Author

Thanks for pointing that out. Made the edits in 2e74a63

};

static inline GPUArch detect_gpu_arch() {
switch (cuda::sm_arch(0)) {

Contributor

I think you can just use

Suggested change
- switch (cuda::sm_arch(0)) {
+ switch (cuda::sm_arch()) {

Contributor Author

Done in 19151f4

Labels

ci-level 1

4 participants