Skip to content

enable blockwise FP8 quantization on rocm#609

Open
asdfvg123 wants to merge 7 commits into
devfrom
yeonsoo/blockwise_fp8
Open

enable blockwise FP8 quantization on rocm#609
asdfvg123 wants to merge 7 commits into
devfrom
yeonsoo/blockwise_fp8

Conversation

@asdfvg123

Copy link
Copy Markdown

Description

Please include a brief summary of the changes, relevant motivation and context.

Enable blockwise FP8 quantization on rocm

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

remove HIP guard in quantization.py
guard kernels using TMA in quantization.
add branch to handle rocm for different threads per wave

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

# TODO replace with call to fp8.py when recipe added.
recipe_available = not IS_HIP_EXTENSION and (get_device_compute_capability() >= (9, 0) and float(torch.version.cuda) >= 12.8)
if IS_HIP_EXTENSION:
recipe_available = get_device_compute_capability() >= (9, 0)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't this be always True on ROCm TE?

@asdfvg123 asdfvg123 Jun 4, 2026

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test targets MI300 and MI350 so I set to (9,0)

@@ -1 +1 @@
/*************************************************************************

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Needs AMD copyright

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added

Comment on lines +8 to +24
#ifndef __HIP_PLATFORM_AMD__
#include <cudaTypedefs.h>
#endif
#include <cuda_bf16.h>
#include <cuda_runtime.h>

#include <cfloat>
#ifndef __HIP_PLATFORM_AMD__
#include <cuda/barrier>
#endif

#include "common/common.h"
#include "common/recipe/recipe_common.cuh"
#include "common/util/cuda_runtime.h"
#ifndef __HIP_PLATFORM_AMD__
#include "common/util/ptx.cuh"
#endif

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These #includes should be already disabled via hipify, so probably no need for the #ifndefs here.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

Comment thread transformer_engine/common/common.h Outdated
Comment on lines +639 to +640
static constexpr float max = 448.0f;
static constexpr float max_inverse = 1.0 / max;

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this change necessary? fp8e4m3 max depends on the device type on AMD.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

quantize_transpose_square_blockwise.cu and quantize_transpose_vector_blockwise.cu use
compute_scale_from_types<IType, fp8e4m3> for the first time, which exposed a latent bug in common.h

The #else branch of TypeExtrema<fp8e4m3> declared max as a static float,
This caused the constexpr static float max_finite_value initializer in TypeInfo in the same file to fail when the template was instantiated on the host.

The fix uses HIP_FP8_TYPE_FNUZ, used in hip_float8.h for selecting FNUZ at compile time, to make the host-pass branch constexpr as well.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the value really used host size, it should be runtime detected. If it is only for host translation of GPU code (i.e.. results are discarded), you can keep 448, no extra ifdefs is needed

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

reverted to the original upstream. I instead changed the recipe_common.cuh following the other convention in quantize_transpose_vector_blockwise_fp4.cu L230

@alextmagro

alextmagro commented Jun 3, 2026

Copy link
Copy Markdown
Contributor

Could you give a description of what you want to achieve with this PR? My understanding is that block fp8 quantization relies on some upstream kernels that will need to be adapted for AMD.

If you're just trying to enable the interface, I would argue that we should do this last, after we have a working quantization and GEMM path (and enabled and passing C++/Python tests).

@asdfvg123

Copy link
Copy Markdown
Author

@alextmagro
This PR is to enable only the quantization in the AMD gpus, not the GEMM. There are two kernels in the upstream which uses TMA for the quantization and does not uses TMA for the quantization. I guarded the kernels which uses TMA and used the non-TMA kernels to quantize for AMD.

I tested with
tests/pytorch/test_float8blockwisetensor.py
and it passes [175 passed / 32 xpassed / 5 warnings]

@alextmagro

Copy link
Copy Markdown
Contributor

@alextmagro This PR is to enable only the quantization in the AMD gpus, not the GEMM. There are two kernels in the upstream which uses TMA for the quantization and does not uses TMA for the quantization. I guarded the kernels which uses TMA and used the non-TMA kernels to quantize for AMD.

I tested with tests/pytorch/test_float8blockwisetensor.py and it passes [175 passed / 32 xpassed / 5 warnings]

OK, in that case we need to add the cpp blockwise tests to the CMake file, and the pytorch test file to ci/pytorch.sh.

…dant HIP guards, revert unnecessary common.h change
Comment thread ci/pytorch.sh
run_default_fa 1 test_deferred_init.py
run_default_fa 1 test_quantized_tensor.py
run_default_fa 1 test_float8_current_scaling_exact.py
run_default_fa 1 test_float8blockwisetensor.py

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: This test and test_float8_current_scaling_exact should be ahead of test_quantized_tensor.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

}

float qscale_inv = 1.0 / qscale;
float qscale_inv = 1.0f / qscale;

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We had discussed this as a potential solution offline, but if the tolerances are needed then we should remove this

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

void compare_scaling_factors(const std::string& name, const float* test, const float* ref,
const size_t row_blocks, const size_t col_blocks,
const size_t test_stride, const size_t ref_stride) {
const size_t test_stride, const size_t ref_stride,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Both function signature and the function change should be guarded. We might also just want to put the atol into the function body itself for now, since it is only used for blockwise scaling.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

# TODO replace with call to fp8.py when recipe added.
recipe_available = not IS_HIP_EXTENSION and (get_device_compute_capability() >= (9, 0) and float(torch.version.cuda) >= 12.8)
if IS_HIP_EXTENSION:
recipe_available = get_device_compute_capability() >= (9, 5)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should be able to support MI300s for this, right? If so, > (9, 0) should be correct.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed, keeping (9,0)

#ifdef __HIP_PLATFORM_AMD__
warp_tile_amax = blockwise_warp_reduce_max(amax);
#else
warp_tile_amax = warp_reduce_max<kThreadsPerWarp>(amax);

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should be able to just use warp_reduce_max here, and remove the kThreadsPerWarp=64 logic too. For the most part, the compiler will double up and we will be okay here.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warp_reduce_max in transformer_engine/common/utils.cuh uses THREADS_PER_WARP = 32 in the file which creates bug. Let me know if there is a better way

}
}

#ifdef TMA_HW_SUPPORTED

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would prefer if we continued to use the AMD guard rather than TMA_HW_SUPPORTED here. Additionally, I am not sure we need to guard the def block around TMA_HW_SUPPORTED as all the cuda arch macros should be undefined.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed and now using AMD guard. Keep the upstream's TMA_HW_SUPPORTED guard, to guard non-TMA kernel

scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon,
tensor_map_output_trans, pow_2_scale, noop_ptr);
} else {
} else

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's avoid splitting up the } else { line here. We can add another macro guard instead if needed.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I could not think of a better way than the current code. I guarded and called the kernel.

constexpr int kNVecSMem = 2; // The number of elements each LDS/STS touches

#ifdef __HIP_PLATFORM_AMD__
constexpr int kThreadsPerBlock = 512; // Thread block size, 8 warps (wave64) in total

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are there actual performance improvements for increasing the # of threads and the threads per warp? If not, we should use the already present values for now.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The kernel expects 8 waves / block , so I increased the number of threads

def check_fp8_block_scaling_support() -> Tuple[bool, str]:
"""Return if fp8 block scaling support is available"""
if IS_HIP_EXTENSION:
return False, "FP8 block scaled gemm not yet supported for ROCm"

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need the same arch guard here as throughout the rest of the PR

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

@alextmagro

alextmagro commented Jun 4, 2026

Copy link
Copy Markdown
Contributor

By the way, to run CI you need to add a CI level label. L3 is required before merging, L1 is for lighter testing, mostly sGPU tests, if you are midway through the ticket and expect to make more changes

Uploading image.png…

@asdfvg123 asdfvg123 added the ci-level 1 CI test level 1 label Jun 4, 2026

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copyright

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added


#ifdef __HIP_PLATFORM_AMD__
using WarpSyncMask = uint64_t;
constexpr WarpSyncMask kFullWarpMask = 0xFFFFFFFFFFFFFFFFULL;

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ROCm should not use it. See how *_sync calls are guarded in other places

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed the mask and use ROCm __shfl instead of __shfl_sync

}
}
// Reduce amax in the warp (32x32 tile)
#ifdef __HIP_PLATFORM_AMD__

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The whole this code is under #ifndef HIP_PLATFORM_AMD

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed the dead branch

// const values configuration

#ifdef __HIP_PLATFORM_AMD__
constexpr size_t kThreadsPerWarp = 64;

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is platform dependent.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed now guarded with gfx1250 for 32 threads

transpose/multi_cast_transpose.cu
transpose/quantize_transpose_vector_blockwise.cu #CUDA-only
transpose/quantize_transpose_vector_blockwise.cu
transpose/quantize_transpose_square_blockwise.cu

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should stay in transformer_engine_cuda_arch_specific_sources

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

Comment thread transformer_engine/common/common.h Outdated
Comment on lines +639 to +640
static constexpr float max = 448.0f;
static constexpr float max_inverse = 1.0 / max;

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the value really used host size, it should be runtime detected. If it is only for host translation of GPU code (i.e.. results are discarded), you can keep 448, no extra ifdefs is needed

@asdfvg123

Copy link
Copy Markdown
Author

MI300 has 64KB of LDS which makes overflow when loading 128 * 128 FP32 data into LDS. I created a helper and branched the kernel. When loading FP32 data, the kernel loads 128 * 64 chunk of data and iterate to quantize. From the host's view, the kernel quantizes 128 * 128 elements.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ci-level 1 CI test level 1

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants