-
Notifications
You must be signed in to change notification settings - Fork 31
enable blockwise FP8 quantization on rocm #609
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from all commits
8335488
6226301
bdf905e
676d1f0
f8a0fc5
231e381
e158d3e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Copyright
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. added |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,6 @@ | ||
| /************************************************************************* | ||
| * This file was modified for portability to AMDGPU | ||
| * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. | ||
| * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| * | ||
| * See LICENSE for license information. | ||
|
|
@@ -28,7 +30,11 @@ namespace { | |
|
|
||
| // const values configuration | ||
|
|
||
| #if defined(__HIP_PLATFORM_AMD__) && !defined(__gfx1250__) | ||
| constexpr size_t kThreadsPerWarp = 64; | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It is platform dependent.
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. fixed now guarded with gfx1250 for 32 threads |
||
| #else | ||
| constexpr size_t kThreadsPerWarp = 32; | ||
| #endif | ||
| #ifdef TMA_HW_SUPPORTED | ||
| constexpr size_t BLOCK_TILE_DIM = 128; | ||
| constexpr size_t WARP_TILE_DIM_X = 32; | ||
|
|
@@ -40,8 +46,12 @@ constexpr size_t BLOCK_TILE_DIM = 128; | |
| constexpr size_t WARP_TILE_DIM_X = 64; | ||
| constexpr size_t WARP_TILE_DIM_Y = 32; | ||
| constexpr size_t THREAD_TILE_DIM_X = 8; | ||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| constexpr size_t THREAD_TILE_DIM_Y = 4; | ||
| #else | ||
| constexpr size_t THREAD_TILE_DIM_Y = 8; | ||
| #endif | ||
| #endif | ||
|
|
||
| #ifdef TMA_HW_SUPPORTED | ||
| constexpr size_t NUM_BYTES_PER_BANK = 4; | ||
|
|
@@ -62,6 +72,16 @@ constexpr size_t NUM_THREADS_Y_IN_WARP = kThreadsPerWarp / NUM_THREADS_X_IN_WARP | |
|
|
||
| #define MIN(a, b) (a < b ? a : b) | ||
|
|
||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| __device__ __forceinline__ float warp_reduce_max_64(float val) { | ||
| #pragma unroll | ||
| for (int delta = kThreadsPerWarp / 2; delta > 0; delta /= 2) | ||
| val = fmaxf(val, __shfl_down(val, delta, kThreadsPerWarp)); | ||
| return val; | ||
| } | ||
| #endif | ||
|
|
||
| #ifndef __HIP_PLATFORM_AMD__ | ||
| template <bool kReturnTranspose, typename CType, typename IType, typename OType> | ||
| __global__ void __launch_bounds__(THREADS_PER_BLOCK) | ||
| block_scaled_cast_transpose_kernel(const IType* const input, OType* const output_c, | ||
|
|
@@ -247,6 +267,7 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) | |
| #endif | ||
| } | ||
| } | ||
| #endif // __HIP_PLATFORM_AMD__ | ||
|
|
||
| template <bool kReturnTranspose, typename CType, typename IType, typename OType> | ||
| __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose_kernel_notaligned( | ||
|
|
@@ -357,10 +378,18 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose | |
| } | ||
| } | ||
| // Reduce amax in the warp (32x32 tile) | ||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| warp_tile_amax = warp_reduce_max_64(amax); | ||
| #else | ||
| warp_tile_amax = warp_reduce_max<kThreadsPerWarp>(amax); | ||
| #endif | ||
| // broadcast the amax to all threads in a warp from the lane 0 | ||
| constexpr int lane_zero = 0; | ||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| warp_tile_amax = __shfl(warp_tile_amax, lane_zero, kThreadsPerWarp); | ||
| #else | ||
| warp_tile_amax = __shfl_sync(0xFFFFFFFF, warp_tile_amax, lane_zero); | ||
| #endif | ||
|
|
||
| // reduce warp_tile_amax across multiple warps in a thread block using shared mem | ||
| if (tid_in_warp == 0) { | ||
|
|
@@ -456,6 +485,7 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose | |
| } | ||
| } | ||
|
|
||
| #ifndef __HIP_PLATFORM_AMD__ | ||
| template <typename OutputType> | ||
| CUtensorMap get_tensor_map(const SimpleTensor& tensor, size_t global_dim_x, size_t global_dim_y) { | ||
| CUtensorMapDataType dataType; | ||
|
|
@@ -473,6 +503,7 @@ CUtensorMap get_tensor_map(const SimpleTensor& tensor, size_t global_dim_x, size | |
| /*stride_elems=*/global_dim_x, /*offset_elems=*/0, sizeof(OutputType) * 8); | ||
| return tensor_map_output_trans; | ||
| } | ||
| #endif // __HIP_PLATFORM_AMD__ | ||
|
|
||
| } // namespace | ||
| } // namespace transformer_engine | ||
|
|
@@ -543,9 +574,10 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor | |
| return_transpose, kReturnTranspose, | ||
|
|
||
| dim3 grid(num_blocks_x, num_blocks_y, 1); | ||
|
|
||
| #ifndef __HIP_PLATFORM_AMD__ | ||
| const bool full_tile = | ||
| row_length % BLOCK_TILE_DIM == 0 && num_rows % BLOCK_TILE_DIM == 0; | ||
|
|
||
| if (full_tile) { | ||
| CUtensorMap tensor_map_output_trans; | ||
| if (return_transpose) { | ||
|
|
@@ -573,6 +605,18 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor | |
| scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, | ||
| pow_2_scale, noop_ptr); | ||
| } // full-tile | ||
| #else | ||
| block_scaled_cast_transpose_kernel_notaligned<kReturnTranspose, float, InputType, | ||
| OutputType> | ||
| <<<grid, THREADS_PER_BLOCK, 0, stream>>>( | ||
| reinterpret_cast<const InputType*>(input.dptr), | ||
| reinterpret_cast<OutputType*>(output.dptr), | ||
| reinterpret_cast<OutputType*>(output_t.dptr), | ||
| reinterpret_cast<float*>(scale_inv.dptr), | ||
| reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows, | ||
| scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, | ||
| pow_2_scale, noop_ptr); | ||
| #endif // __HIP_PLATFORM_AMD__ | ||
| ) // return_transpose | ||
| ) // OutputType | ||
| ) // InputType | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: This test and test_float8_current_scaling_exact should be ahead of test_quantized_tensor.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed