enable blockwise FP8 quantization on rocm#609
Conversation
| # TODO replace with call to fp8.py when recipe added. | ||
| recipe_available = not IS_HIP_EXTENSION and (get_device_compute_capability() >= (9, 0) and float(torch.version.cuda) >= 12.8) | ||
| if IS_HIP_EXTENSION: | ||
| recipe_available = get_device_compute_capability() >= (9, 0) |
There was a problem hiding this comment.
Wouldn't this be always True on ROCm TE?
There was a problem hiding this comment.
This test targets MI300 and MI350 so I set to (9,0)
| @@ -1 +1 @@ | |||
| /************************************************************************* | |||
| #ifndef __HIP_PLATFORM_AMD__ | ||
| #include <cudaTypedefs.h> | ||
| #endif | ||
| #include <cuda_bf16.h> | ||
| #include <cuda_runtime.h> | ||
|
|
||
| #include <cfloat> | ||
| #ifndef __HIP_PLATFORM_AMD__ | ||
| #include <cuda/barrier> | ||
| #endif | ||
|
|
||
| #include "common/common.h" | ||
| #include "common/recipe/recipe_common.cuh" | ||
| #include "common/util/cuda_runtime.h" | ||
| #ifndef __HIP_PLATFORM_AMD__ | ||
| #include "common/util/ptx.cuh" | ||
| #endif |
There was a problem hiding this comment.
These #includes should be already disabled via hipify, so probably no need for the #ifndefs here.
| static constexpr float max = 448.0f; | ||
| static constexpr float max_inverse = 1.0 / max; |
There was a problem hiding this comment.
Is this change necessary? fp8e4m3 max depends on the device type on AMD.
There was a problem hiding this comment.
quantize_transpose_square_blockwise.cu and quantize_transpose_vector_blockwise.cu use
compute_scale_from_types<IType, fp8e4m3> for the first time, which exposed a latent bug in common.h
The #else branch of TypeExtrema<fp8e4m3> declared max as a static float,
This caused the constexpr static float max_finite_value initializer in TypeInfo in the same file to fail when the template was instantiated on the host.
The fix uses HIP_FP8_TYPE_FNUZ, used in hip_float8.h for selecting FNUZ at compile time, to make the host-pass branch constexpr as well.
There was a problem hiding this comment.
If the value really used host size, it should be runtime detected. If it is only for host translation of GPU code (i.e.. results are discarded), you can keep 448, no extra ifdefs is needed
There was a problem hiding this comment.
reverted to the original upstream. I instead changed the recipe_common.cuh following the other convention in quantize_transpose_vector_blockwise_fp4.cu L230
|
Could you give a description of what you want to achieve with this PR? My understanding is that block fp8 quantization relies on some upstream kernels that will need to be adapted for AMD. If you're just trying to enable the interface, I would argue that we should do this last, after we have a working quantization and GEMM path (and enabled and passing C++/Python tests). |
|
@alextmagro I tested with |
OK, in that case we need to add the cpp blockwise tests to the CMake file, and the pytorch test file to ci/pytorch.sh. |
…dant HIP guards, revert unnecessary common.h change
| run_default_fa 1 test_deferred_init.py | ||
| run_default_fa 1 test_quantized_tensor.py | ||
| run_default_fa 1 test_float8_current_scaling_exact.py | ||
| run_default_fa 1 test_float8blockwisetensor.py |
There was a problem hiding this comment.
nit: This test and test_float8_current_scaling_exact should be ahead of test_quantized_tensor.
| } | ||
|
|
||
| float qscale_inv = 1.0 / qscale; | ||
| float qscale_inv = 1.0f / qscale; |
There was a problem hiding this comment.
We had discussed this as a potential solution offline, but if the tolerances are needed then we should remove this
| void compare_scaling_factors(const std::string& name, const float* test, const float* ref, | ||
| const size_t row_blocks, const size_t col_blocks, | ||
| const size_t test_stride, const size_t ref_stride) { | ||
| const size_t test_stride, const size_t ref_stride, |
There was a problem hiding this comment.
Both function signature and the function change should be guarded. We might also just want to put the atol into the function body itself for now, since it is only used for blockwise scaling.
| # TODO replace with call to fp8.py when recipe added. | ||
| recipe_available = not IS_HIP_EXTENSION and (get_device_compute_capability() >= (9, 0) and float(torch.version.cuda) >= 12.8) | ||
| if IS_HIP_EXTENSION: | ||
| recipe_available = get_device_compute_capability() >= (9, 5) |
There was a problem hiding this comment.
I think we should be able to support MI300s for this, right? If so, > (9, 0) should be correct.
| #ifdef __HIP_PLATFORM_AMD__ | ||
| warp_tile_amax = blockwise_warp_reduce_max(amax); | ||
| #else | ||
| warp_tile_amax = warp_reduce_max<kThreadsPerWarp>(amax); |
There was a problem hiding this comment.
We should be able to just use warp_reduce_max here, and remove the kThreadsPerWarp=64 logic too. For the most part, the compiler will double up and we will be okay here.
There was a problem hiding this comment.
warp_reduce_max in transformer_engine/common/utils.cuh uses THREADS_PER_WARP = 32 in the file which creates bug. Let me know if there is a better way
| } | ||
| } | ||
|
|
||
| #ifdef TMA_HW_SUPPORTED |
There was a problem hiding this comment.
I would prefer if we continued to use the AMD guard rather than TMA_HW_SUPPORTED here. Additionally, I am not sure we need to guard the def block around TMA_HW_SUPPORTED as all the cuda arch macros should be undefined.
There was a problem hiding this comment.
Fixed and now using AMD guard. Keep the upstream's TMA_HW_SUPPORTED guard, to guard non-TMA kernel
| scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, | ||
| tensor_map_output_trans, pow_2_scale, noop_ptr); | ||
| } else { | ||
| } else |
There was a problem hiding this comment.
Let's avoid splitting up the } else { line here. We can add another macro guard instead if needed.
There was a problem hiding this comment.
I could not think of a better way than the current code. I guarded and called the kernel.
| constexpr int kNVecSMem = 2; // The number of elements each LDS/STS touches | ||
|
|
||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| constexpr int kThreadsPerBlock = 512; // Thread block size, 8 warps (wave64) in total |
There was a problem hiding this comment.
Are there actual performance improvements for increasing the # of threads and the threads per warp? If not, we should use the already present values for now.
There was a problem hiding this comment.
The kernel expects 8 waves / block , so I increased the number of threads
| def check_fp8_block_scaling_support() -> Tuple[bool, str]: | ||
| """Return if fp8 block scaling support is available""" | ||
| if IS_HIP_EXTENSION: | ||
| return False, "FP8 block scaled gemm not yet supported for ROCm" |
There was a problem hiding this comment.
We need the same arch guard here as throughout the rest of the PR
|
|
||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| using WarpSyncMask = uint64_t; | ||
| constexpr WarpSyncMask kFullWarpMask = 0xFFFFFFFFFFFFFFFFULL; |
There was a problem hiding this comment.
ROCm should not use it. See how *_sync calls are guarded in other places
There was a problem hiding this comment.
removed the mask and use ROCm __shfl instead of __shfl_sync
| } | ||
| } | ||
| // Reduce amax in the warp (32x32 tile) | ||
| #ifdef __HIP_PLATFORM_AMD__ |
There was a problem hiding this comment.
The whole this code is under #ifndef HIP_PLATFORM_AMD
| // const values configuration | ||
|
|
||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| constexpr size_t kThreadsPerWarp = 64; |
There was a problem hiding this comment.
It is platform dependent.
There was a problem hiding this comment.
fixed now guarded with gfx1250 for 32 threads
| transpose/multi_cast_transpose.cu | ||
| transpose/quantize_transpose_vector_blockwise.cu #CUDA-only | ||
| transpose/quantize_transpose_vector_blockwise.cu | ||
| transpose/quantize_transpose_square_blockwise.cu |
There was a problem hiding this comment.
It should stay in transformer_engine_cuda_arch_specific_sources
| static constexpr float max = 448.0f; | ||
| static constexpr float max_inverse = 1.0 / max; |
There was a problem hiding this comment.
If the value really used host size, it should be runtime detected. If it is only for host translation of GPU code (i.e.. results are discarded), you can keep 448, no extra ifdefs is needed
|
MI300 has 64KB of LDS which makes overflow when loading 128 * 128 FP32 data into LDS. I created a helper and branched the kernel. When loading FP32 data, the kernel loads 128 * 64 chunk of data and iterate to quantize. From the host's view, the kernel quantizes 128 * 128 elements. |
Description
Please include a brief summary of the changes, relevant motivation and context.
Enable blockwise FP8 quantization on rocm
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
remove HIP guard in quantization.py
guard kernels using TMA in quantization.
add branch to handle rocm for different threads per wave
Checklist: