diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp index b0ed9ae6e..7f462f821 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp @@ -32,217 +32,83 @@ __forceinline__ __device__ int binary_search(int32_t target, const int32_t *arra return left - 1; } -// define dk_dv_reduce function only for fp16 and bf16 types -template -__global__ void dk_dv_reduce( - uint64_t b, uint64_t h, uint64_t hg, uint64_t s_kv, uint64_t d, +// Reduce expanded (per-Q-head) dk/dv buffers back to hg KV-heads, accumulating in fp32. +// Defined only for fp16/bf16 (see launch helper's dtype switch). +// GroupMode : true -> THD/varlen addressing via cu_seqlen with padding early-exit; +// false -> batch addressing with an explicit batch stride. +// ReduceBoth : true -> reduce dk and dv together in a single pass, sharing the index +// math (used when d_qk == d_v, where dk/dv have identical layout/strides); +// false -> reduce only the tensor supplied in the dk slot (dv args unused). +// The d_qk != d_v case issues two ReduceBoth=false launches (one per tensor); the common +// d_qk == d_v case stays a single fused launch, so kernel-launch count and per-thread work +// are identical to the original four kernels. +template +__global__ void dkv_reduce( + uint64_t b, uint64_t h, uint64_t hg, uint64_t d, + const int32_t* cu_seqlen_kv_ptr, + const int32_t* cu_seqlen_kv_padded_ptr, const DataType *dk_expanded, const DataType *dv_expanded, uint64_t stride_b_dkv_expanded, uint64_t stride_h_dkv_expanded, uint64_t stride_s_dkv_expanded, DataType *dk, DataType *dv, - //k,v, dk, dv guaranteed to have the same stride + //k, v, dk, dv guaranteed to have the same stride uint64_t stride_b_dkv, uint64_t stride_h_dkv, uint64_t stride_s_dkv){ - uint64_t batch_idx = blockIdx.x; - uint64_t seqlen_idx = blockIdx.y; - uint64_t head_k_idx = blockIdx.z; - uint64_t hdim_idx = threadIdx.x; + const uint64_t hdim_idx = threadIdx.x; + assert(hdim_idx < d); // h guaranteed to be multiples of hg - uint64_t head_idx_offset = h / hg; - - float sum_dk = 0.0f; - float sum_dv = 0.0f; - - assert(hdim_idx){ - sum_dk += ck_tile::bf16_to_float(dk_expanded[read_idx]); - sum_dv += ck_tile::bf16_to_float(dv_expanded[read_idx]); - }else{ - sum_dk += dk_expanded[read_idx]; - sum_dv += dv_expanded[read_idx]; + const uint64_t head_idx_offset = h / hg; + + uint64_t read_idx, write_idx; + if constexpr (GroupMode){ + const uint64_t seqlen_idx = blockIdx.x; + const uint64_t head_k_idx = blockIdx.y; + // skip padding tokens beyond the (padded) total token count + if(seqlen_idx >= *((cu_seqlen_kv_padded_ptr? cu_seqlen_kv_padded_ptr: cu_seqlen_kv_ptr)+b)){ + return; } - read_idx += stride_h_dkv_expanded; - } - - // bf16 requires special casting in CK - if constexpr (std::is_same_v){ - dk[write_idx] = ck_tile::float_to_bf16(sum_dk); - dv[write_idx] = ck_tile::float_to_bf16(sum_dv); - }else{ - dk[write_idx] = sum_dk; - dv[write_idx] = sum_dv; - } -} - -// When d_qk != d_v, we need to reduce dk and dv separately -template -__global__ void dk_or_dv_reduce( - uint64_t b, uint64_t h, uint64_t hg, uint64_t s_kv, uint64_t d, - const DataType *dk_or_dv_expanded, - uint64_t stride_b_dk_or_dv_expanded, uint64_t stride_h_dk_or_dv_expanded, uint64_t stride_s_dk_or_dv_expanded, - DataType *dk_or_dv, - //k,v, dk, dv guaranteed to have the same stride - uint64_t stride_b_dk_or_dv, uint64_t stride_h_dk_or_dv, uint64_t stride_s_dk_or_dv){ - - uint64_t batch_idx = blockIdx.x; - uint64_t seqlen_idx = blockIdx.y; - uint64_t head_k_or_v_idx = blockIdx.z; - uint64_t hdim_idx = threadIdx.x; - - // h guaranteed to be multiples of hg - uint64_t head_idx_offset = h / hg; - - float sum_dk_or_dv = 0.0f; - - assert(hdim_idx){ - sum_dk_or_dv += ck_tile::bf16_to_float(dk_or_dv_expanded[read_idx]); - }else{ - sum_dk_or_dv += dk_or_dv_expanded[read_idx]; + if(cu_seqlen_kv_padded_ptr){ + uint64_t seq_idx = binary_search(seqlen_idx, cu_seqlen_kv_padded_ptr, b+1); + uint64_t unpadded_size = cu_seqlen_kv_ptr[seq_idx+1] - cu_seqlen_kv_ptr[seq_idx]; + if(seqlen_idx >= cu_seqlen_kv_padded_ptr[seq_idx] + unpadded_size){ + return; + } } - read_idx += stride_h_dk_or_dv_expanded; - } - - // bf16 requires special casting in CK - if constexpr (std::is_same_v){ - dk_or_dv[write_idx] = ck_tile::float_to_bf16(sum_dk_or_dv); + read_idx = head_k_idx*head_idx_offset*stride_h_dkv_expanded + seqlen_idx*stride_s_dkv_expanded + hdim_idx; + write_idx = head_k_idx*stride_h_dkv + seqlen_idx*stride_s_dkv + hdim_idx; }else{ - dk_or_dv[write_idx] = sum_dk_or_dv; + const uint64_t batch_idx = blockIdx.x; + const uint64_t seqlen_idx = blockIdx.y; + const uint64_t head_k_idx = blockIdx.z; + read_idx = batch_idx*stride_b_dkv_expanded + head_k_idx*head_idx_offset*stride_h_dkv_expanded + seqlen_idx*stride_s_dkv_expanded + hdim_idx; + write_idx = batch_idx*stride_b_dkv + head_k_idx*stride_h_dkv + seqlen_idx*stride_s_dkv + hdim_idx; } -} - -// define dk_dv_reduce function in THD layout only for fp16 and bf16 types -template -__global__ void dk_dv_reduce_thd( - uint64_t b, uint64_t h, uint64_t hg, uint64_t d, - const int32_t* cu_seqlen_kv_ptr, - const int32_t* cu_seqlen_kv_padded_ptr, - const DataType *dk_expanded, - const DataType *dv_expanded, - uint64_t stride_h_dkv_expanded, uint64_t stride_s_dkv_expanded, - DataType *dk, - DataType *dv, - //k,v, dk, dv guaranteed to have the same stride - uint64_t stride_h_dkv, uint64_t stride_s_dkv){ - - uint64_t seqlen_idx = blockIdx.x; - uint64_t head_k_idx = blockIdx.y; - uint64_t hdim_idx = threadIdx.x; - - assert(hdim_idx= *((cu_seqlen_kv_padded_ptr? cu_seqlen_kv_padded_ptr: cu_seqlen_kv_ptr)+b)){ - return; - } - if(cu_seqlen_kv_padded_ptr){ - uint64_t seq_idx = binary_search(seqlen_idx, cu_seqlen_kv_padded_ptr, b+1); - uint64_t unpadded_size = cu_seqlen_kv_ptr[seq_idx+1] - cu_seqlen_kv_ptr[seq_idx]; - if(seqlen_idx >= cu_seqlen_kv_padded_ptr[seq_idx] + unpadded_size){ - return; - } - } - // h guaranteed to be multiples of hg - uint64_t head_idx_offset = h / hg; float sum_dk = 0.0f; float sum_dv = 0.0f; - - - uint64_t read_idx = head_k_idx*head_idx_offset*stride_h_dkv_expanded + seqlen_idx*stride_s_dkv_expanded + hdim_idx; - uint64_t write_idx = head_k_idx*stride_h_dkv + seqlen_idx* stride_s_dkv + hdim_idx; - for(uint64_t ii = 0; ii < head_idx_offset; ii++){ - // bf16 requires special casting in CK - if constexpr (std::is_same_v){ - sum_dk += ck_tile::bf16_to_float(dk_expanded[read_idx]); - sum_dv += ck_tile::bf16_to_float(dv_expanded[read_idx]); - }else{ - sum_dk += dk_expanded[read_idx]; - sum_dv += dv_expanded[read_idx]; + sum_dk += to_f32(dk_expanded[read_idx]); + if constexpr (ReduceBoth){ + sum_dv += to_f32(dv_expanded[read_idx]); } read_idx += stride_h_dkv_expanded; } - - // bf16 requires special casting in CK - if constexpr (std::is_same_v){ - dk[write_idx] = ck_tile::float_to_bf16(sum_dk); - dv[write_idx] = ck_tile::float_to_bf16(sum_dv); - }else{ - dk[write_idx] = sum_dk; - dv[write_idx] = sum_dv; + dk[write_idx] = from_f32(sum_dk); + if constexpr (ReduceBoth){ + dv[write_idx] = from_f32(sum_dv); } } -// When d_qk != d_v, we need to reduce dk and dv separately -template -__global__ void dk_or_dv_reduce_thd( - uint64_t b, uint64_t h, uint64_t hg, uint64_t d, - const int32_t* cu_seqlen_kv_ptr, - const int32_t* cu_seqlen_kv_padded_ptr, - const DataType *dk_or_dv_expanded, - uint64_t stride_h_dk_or_dv_expanded, uint64_t stride_s_dk_or_dv_expanded, - DataType *dk_or_dv, - //k,v, dk, dv guaranteed to have the same stride - uint64_t stride_h_dk_or_dv, uint64_t stride_s_dk_or_dv){ - - uint64_t seqlen_idx = blockIdx.x; - uint64_t head_k_or_v_idx = blockIdx.y; - uint64_t hdim_idx = threadIdx.x; - - assert(hdim_idx= *((cu_seqlen_kv_padded_ptr? cu_seqlen_kv_padded_ptr: cu_seqlen_kv_ptr)+b)){ - return; - } - if(cu_seqlen_kv_padded_ptr){ - uint64_t seq_idx = binary_search(seqlen_idx, cu_seqlen_kv_padded_ptr, b+1); - uint64_t unpadded_size = cu_seqlen_kv_ptr[seq_idx+1] - cu_seqlen_kv_ptr[seq_idx]; - if(seqlen_idx >= cu_seqlen_kv_padded_ptr[seq_idx] + unpadded_size){ - return; - } - } - // h guaranteed to be multiples of hg - uint64_t head_idx_offset = h / hg; - - float sum_dk_or_dv = 0.0f; - - uint64_t read_idx = head_k_or_v_idx*head_idx_offset*stride_h_dk_or_dv_expanded + seqlen_idx*stride_s_dk_or_dv_expanded + hdim_idx; - uint64_t write_idx = head_k_or_v_idx*stride_h_dk_or_dv + seqlen_idx* stride_s_dk_or_dv + hdim_idx; - - for(uint64_t ii = 0; ii < head_idx_offset; ii++){ - // bf16 requires special casting in CK - if constexpr (std::is_same_v){ - sum_dk_or_dv += ck_tile::bf16_to_float(dk_or_dv_expanded[read_idx]); - }else{ - sum_dk_or_dv += dk_or_dv_expanded[read_idx]; - } - read_idx += stride_h_dk_or_dv_expanded; - } - - // bf16 requires special casting in CK - if constexpr (std::is_same_v){ - dk_or_dv[write_idx] = ck_tile::float_to_bf16(sum_dk_or_dv); - }else{ - dk_or_dv[write_idx] = sum_dk_or_dv; - } -} - - -// define dbias_reduce functions only for fp16 and bf16 types -template -__global__ void dbias_reduce_11ss( +// Reduce expanded dbias (b, h, s_q, s_kv) over the batch and/or head dims, accumulating +// in fp32. Defined only for fp16/bf16 (see launch helper's dtype switch). +// ReduceB && ReduceH -> sum over b and h -> output (1, 1, s_q, s_kv) [k11SS] +// ReduceB && !ReduceH -> sum over b, per head -> output (1, h, s_q, s_kv) [k1HSS] +// !ReduceB && ReduceH -> sum over h, per batch-> output (b, 1, s_q, s_kv) [kB1SS] +template +__global__ void dbias_reduce( uint64_t b, uint64_t h, uint64_t s_q, uint64_t s_kv, const DataType *dbias_expanded, DataType *dbias){ @@ -250,84 +116,105 @@ __global__ void dbias_reduce_11ss( const uint64_t stride_h = s_q*s_kv; const uint64_t stride_b = h*s_q*s_kv; for(uint64_t ss_idx = blockIdx.x*blockDim.x + threadIdx.x; ss_idx < s_q*s_kv; ss_idx += blockDim.x * gridDim.x){ - //sum over b, h dims both - float sum_dbias = 0.0f; - for(uint64_t b_idx = 0; b_idx< b; b_idx++){ + if constexpr (ReduceB && ReduceH){ + //sum over b, h dims both + float sum_dbias = 0.0f; + for(uint64_t b_idx = 0; b_idx< b; b_idx++){ + for(uint64_t h_idx = 0; h_idx < h; h_idx++){ + sum_dbias += to_f32(dbias_expanded[b_idx*stride_b + h_idx*stride_h+ss_idx]); + } + } + dbias[ss_idx] = from_f32(sum_dbias); + }else if constexpr (ReduceB){ for(uint64_t h_idx = 0; h_idx < h; h_idx++){ - if constexpr (std::is_same_v){ - // bf16 requires special casting in CK - sum_dbias += ck_tile::bf16_to_float(dbias_expanded[b_idx*stride_b + h_idx*stride_h+ss_idx]); - }else{ - sum_dbias += dbias_expanded[b_idx*stride_b + h_idx*stride_h+ss_idx]; + //sum over b dims only + float sum_dbias = 0.0f; + for(uint64_t b_idx = 0; b_idx< b; b_idx++){ + sum_dbias += to_f32(dbias_expanded[b_idx*stride_b + h_idx*stride_h+ss_idx]); } + dbias[ss_idx + h_idx*stride_h] = from_f32(sum_dbias); } - } - if constexpr (std::is_same_v){ - dbias[ss_idx] = ck_tile::float_to_bf16(sum_dbias); }else{ - dbias[ss_idx] = sum_dbias; - } - } -} - -// define dbias_reduce functions only for fp16 and bf16 types -template -__global__ void dbias_reduce_1hss( - uint64_t b, uint64_t h, uint64_t s_q, uint64_t s_kv, - const DataType *dbias_expanded, - DataType *dbias){ - - const uint64_t stride_h = s_q*s_kv; - const uint64_t stride_b = h*s_q*s_kv; - for(uint64_t ss_idx = blockIdx.x*blockDim.x + threadIdx.x; ss_idx < s_q*s_kv; ss_idx += blockDim.x * gridDim.x){ - for(uint64_t h_idx = 0; h_idx < h; h_idx++){ - //sum over b dims only - float sum_dbias = 0.0f; + // ReduceH only for(uint64_t b_idx = 0; b_idx< b; b_idx++){ - if constexpr (std::is_same_v){ - // bf16 requires special casting in CK - sum_dbias += ck_tile::bf16_to_float(dbias_expanded[b_idx*stride_b + h_idx*stride_h+ss_idx]); - }else{ - sum_dbias += dbias_expanded[b_idx*stride_b + h_idx*stride_h+ss_idx]; + //sum over h dims only + float sum_dbias = 0.0f; + for(uint64_t h_idx = 0; h_idx < h; h_idx++){ + sum_dbias += to_f32(dbias_expanded[b_idx*stride_b + h_idx*stride_h+ss_idx]); } - } - if constexpr (std::is_same_v){ - dbias[ss_idx + h_idx*stride_h] = ck_tile::float_to_bf16(sum_dbias); - }else{ - dbias[ss_idx + h_idx*stride_h] = sum_dbias; + // output is packed [b, s_q*s_kv]; per-batch slice size == stride_h (s_q*s_kv) + dbias[ss_idx + b_idx*stride_h] = from_f32(sum_dbias); } } } } -// define dbias_reduce functions only for fp16 and bf16 types -template -__global__ void dbias_reduce_b1ss( - uint64_t b, uint64_t h, uint64_t s_q, uint64_t s_kv, - const DataType *dbias_expanded, - DataType *dbias){ +// Streamlined logging for a dk/dv reduction launch (no-op unless CK logging is enabled). +static void log_dkv_reduce( + const char* name, const CkAttnBwdArgs& args, + const void* dk_exp, const void* dv_exp, + uint64_t stride_b_exp, uint64_t stride_h_exp, uint64_t stride_s_exp, + const void* dk, const void* dv, + uint64_t stride_b, uint64_t stride_h, uint64_t stride_s){ + std::ostream* log_file = get_ck_log_stream(); + if(!log_file) return; + (*log_file) << "\nrun " << name << ":\n"; + log_value(log_file, "cu_seqlen_kv_ptr", args.cu_seqlen_kv_ptr); + log_value(log_file, "cu_seqlen_kv_padded_ptr", args.cu_seqlen_kv_padded_ptr); + log_value(log_file, "dk_expanded_ptr", dk_exp); + log_value(log_file, "dv_expanded_ptr", dv_exp); + log_value(log_file, "stride_b_expanded", stride_b_exp); + log_value(log_file, "stride_h_expanded", stride_h_exp); + log_value(log_file, "stride_s_expanded", stride_s_exp); + log_value(log_file, "dk_ptr", dk); + log_value(log_file, "dv_ptr", dv); + log_value(log_file, "stride_b", stride_b); + log_value(log_file, "stride_h", stride_h); + log_value(log_file, "stride_s", stride_s); +} - const uint64_t stride_h = s_q*s_kv; - const uint64_t stride_b = h*s_q*s_kv; - for(uint64_t ss_idx = blockIdx.x*blockDim.x + threadIdx.x; ss_idx < s_q*s_kv; ss_idx += blockDim.x * gridDim.x){ - for(uint64_t b_idx = 0; b_idx< b; b_idx++){ - //sum over h dims only - float sum_dbias = 0.0f; - for(uint64_t h_idx = 0; h_idx < h; h_idx++){ - if constexpr (std::is_same_v){ - // bf16 requires special casting in CK - sum_dbias += ck_tile::bf16_to_float(dbias_expanded[b_idx*stride_b + h_idx*stride_h+ss_idx]); - }else{ - sum_dbias += dbias_expanded[b_idx*stride_b + h_idx*stride_h+ss_idx]; - } - } - if constexpr (std::is_same_v){ - dbias[ss_idx + b_idx*stride_h] = ck_tile::float_to_bf16(sum_dbias); - }else{ - dbias[ss_idx + b_idx*stride_h] = sum_dbias; - } - } +// Launch the unified dk/dv reduction for one tensor-set configuration. For the fused +// (d_qk == d_v) path call with ReduceBoth=true and both dk/dv pointers; for the split +// path call once per tensor with ReduceBoth=false and the tensor in the dk slot. +template +static void launch_dkv_reduce( + const char* name, const CkAttnBwdArgs& args, dim3 grid, dim3 block, + const void* dk_exp, const void* dv_exp, + uint64_t stride_b_exp, uint64_t stride_h_exp, uint64_t stride_s_exp, + void* dk, void* dv, + uint64_t stride_b, uint64_t stride_h, uint64_t stride_s, + uint64_t d, hipStream_t stream){ + log_dkv_reduce(name, args, dk_exp, dv_exp, stride_b_exp, stride_h_exp, stride_s_exp, + dk, dv, stride_b, stride_h, stride_s); + CK_FUSED_ATTN_TYPE_SWITCH_16BIT(args.dtype, CK_TILE_TYPE, + hipLaunchKernelGGL( + (dkv_reduce), grid, block, 0, stream, + args.b, args.h, args.hg, d, + static_cast(args.cu_seqlen_kv_ptr), + static_cast(args.cu_seqlen_kv_padded_ptr), + static_cast(dk_exp), + static_cast(dv_exp), + stride_b_exp, stride_h_exp, stride_s_exp, + static_cast(dk), + static_cast(dv), + stride_b, stride_h, stride_s);); +} + +// Launch the unified dbias reduction (batch mode only). +template +static void launch_dbias_reduce( + const char* name, const CkAttnBwdArgs& args, dim3 grid, dim3 block, hipStream_t stream){ + if (auto* log_file = get_ck_log_stream()) { + *log_file << "\nrun " << name << ":\n"; + log_value(log_file, "dbias_ptr", args.dbias_ptr); + log_value(log_file, "dbias_expanded_ptr", args.dbias_expanded_ptr); } + CK_FUSED_ATTN_TYPE_SWITCH_16BIT(args.dtype, CK_TILE_TYPE, + hipLaunchKernelGGL( + (dbias_reduce), grid, block, 0, stream, + args.b, args.h, args.s_q, args.s_kv, + static_cast(args.dbias_expanded_ptr), + static_cast(args.dbias_ptr));); } // print the fmha_traits and args passed into ck apis @@ -604,153 +491,61 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ } // Post-dispatch reductions for MQA/GQA: reduce dk_expanded/dv_expanded into dk/dv. - // Batch and group modes use different kernels (batch carves by stride_b; group carves by cu_seqlen). + // Batch and group modes use different addressing (batch carves by stride_b; group by + // cu_seqlen). When d_qk == d_v, dk and dv share layout and are reduced in one fused pass; + // otherwise each is reduced by its own launch. if(is_mqa_gqa){ if(args.is_group_mode()){ dim3 grid(args.max_tokens_kv, args.hg); if(args.d_qk == args.d_v){ - dim3 block(args.d_qk); - if (auto* log_file = get_ck_log_stream()) { - *log_file << "\n" << "run dk_dv_reduce_thd: " << "\n"; - *log_file << "cu_seqlen_kv_ptr: " << args.cu_seqlen_kv_ptr << "\n"; - *log_file << "cu_seqlen_kv_padded_ptr: " << args.cu_seqlen_kv_padded_ptr << "\n"; - *log_file << "dk_expanded_ptr: " << args.dk_expanded_ptr << "\n"; - *log_file << "dv_expanded_ptr: " << args.dv_expanded_ptr << "\n"; - *log_file << "stride_h_dkv_expanded: " << args.stride_h_dk_expanded << "\n"; - *log_file << "stride_s_dkv_expanded: " << args.stride_s_dk_expanded << "\n"; - *log_file << "dk_ptr: " << args.dk_ptr << "\n"; - *log_file << "dv_ptr: " << args.dv_ptr << "\n"; - *log_file << "stride_h_dk: " << args.stride_h_dk << "\n"; - *log_file << "stride_s_dk: " << args.stride_s_dk << "\n"; - } - CK_FUSED_ATTN_TYPE_SWITCH_16BIT(args.dtype, CK_TILE_TYPE, - hipLaunchKernelGGL( - dk_dv_reduce_thd, grid, block, 0, stream, - args.b, args.h, args.hg, args.d_qk, - static_cast(args.cu_seqlen_kv_ptr), - static_cast(args.cu_seqlen_kv_padded_ptr), - static_cast(args.dk_expanded_ptr), - static_cast(args.dv_expanded_ptr), - args.stride_h_dk_expanded, args.stride_s_dk_expanded, - static_cast(args.dk_ptr), - static_cast(args.dv_ptr), - args.stride_h_dk, args.stride_s_dk);); + launch_dkv_reduce( + "dk_dv_reduce_thd", args, grid, dim3(args.d_qk), + args.dk_expanded_ptr, args.dv_expanded_ptr, + args.stride_b_dk_expanded, args.stride_h_dk_expanded, args.stride_s_dk_expanded, + args.dk_ptr, args.dv_ptr, + args.stride_b_dk, args.stride_h_dk, args.stride_s_dk, + args.d_qk, stream); } else { - dim3 block_dk(args.d_qk); - if (auto* log_file = get_ck_log_stream()) { - *log_file << "\n" << "run dk_or_dv_reduce_thd on dk: " << "\n"; - *log_file << "cu_seqlen_kv_ptr: " << args.cu_seqlen_kv_ptr << "\n"; - *log_file << "cu_seqlen_kv_padded_ptr: " << args.cu_seqlen_kv_padded_ptr << "\n"; - *log_file << "dk_expanded_ptr: " << args.dk_expanded_ptr << "\n"; - *log_file << "stride_h_dk_expanded: " << args.stride_h_dk_expanded << "\n"; - *log_file << "stride_s_dk_expanded: " << args.stride_s_dk_expanded << "\n"; - *log_file << "dk_ptr: " << args.dk_ptr << "\n"; - *log_file << "stride_h_dk: " << args.stride_h_dk << "\n"; - *log_file << "stride_s_dk: " << args.stride_s_dk << "\n"; - } - CK_FUSED_ATTN_TYPE_SWITCH_16BIT(args.dtype, CK_TILE_TYPE, - hipLaunchKernelGGL( - dk_or_dv_reduce_thd, grid, block_dk, 0, stream, - args.b, args.h, args.hg, args.d_qk, - static_cast(args.cu_seqlen_kv_ptr), - static_cast(args.cu_seqlen_kv_padded_ptr), - static_cast(args.dk_expanded_ptr), - args.stride_h_dk_expanded, args.stride_s_dk_expanded, - static_cast(args.dk_ptr), - args.stride_h_dk, args.stride_s_dk);); - - dim3 block_dv(args.d_v); - if (auto* log_file = get_ck_log_stream()) { - *log_file << "\n" << "run dk_or_dv_reduce_thd on dv: " << "\n"; - *log_file << "cu_seqlen_kv_ptr: " << args.cu_seqlen_kv_ptr << "\n"; - *log_file << "cu_seqlen_kv_padded_ptr: " << args.cu_seqlen_kv_padded_ptr << "\n"; - *log_file << "dv_expanded_ptr: " << args.dv_expanded_ptr << "\n"; - *log_file << "stride_h_dv_expanded: " << args.stride_h_dv_expanded << "\n"; - *log_file << "stride_s_dv_expanded: " << args.stride_s_dv_expanded << "\n"; - *log_file << "dv_ptr: " << args.dv_ptr << "\n"; - *log_file << "stride_h_dv: " << args.stride_h_dv << "\n"; - *log_file << "stride_s_dv: " << args.stride_s_dv << "\n"; - } - CK_FUSED_ATTN_TYPE_SWITCH_16BIT(args.dtype, CK_TILE_TYPE, - hipLaunchKernelGGL( - dk_or_dv_reduce_thd, grid, block_dv, 0, stream, - args.b, args.h, args.hg, args.d_v, - static_cast(args.cu_seqlen_kv_ptr), - static_cast(args.cu_seqlen_kv_padded_ptr), - static_cast(args.dv_expanded_ptr), - args.stride_h_dv_expanded, args.stride_s_dv_expanded, - static_cast(args.dv_ptr), - args.stride_h_dv, args.stride_s_dv);); + launch_dkv_reduce( + "dk_reduce_thd", args, grid, dim3(args.d_qk), + args.dk_expanded_ptr, nullptr, + args.stride_b_dk_expanded, args.stride_h_dk_expanded, args.stride_s_dk_expanded, + args.dk_ptr, nullptr, + args.stride_b_dk, args.stride_h_dk, args.stride_s_dk, + args.d_qk, stream); + launch_dkv_reduce( + "dv_reduce_thd", args, grid, dim3(args.d_v), + args.dv_expanded_ptr, nullptr, + args.stride_b_dv_expanded, args.stride_h_dv_expanded, args.stride_s_dv_expanded, + args.dv_ptr, nullptr, + args.stride_b_dv, args.stride_h_dv, args.stride_s_dv, + args.d_v, stream); } } else { dim3 grid(args.b, args.s_kv, args.hg); if(args.d_qk == args.d_v){ - dim3 block(args.d_qk); - if (auto* log_file = get_ck_log_stream()) { - *log_file << "\n" << "run dk_dv_reduce: " << "\n"; - *log_file << "dk_expanded_ptr: " << args.dk_expanded_ptr << "\n"; - *log_file << "dv_expanded_ptr: " << args.dv_expanded_ptr << "\n"; - *log_file << "stride_b_dkv_expanded: " << args.stride_b_dk_expanded << "\n"; - *log_file << "stride_h_dkv_expanded: " << args.stride_h_dk_expanded << "\n"; - *log_file << "stride_s_dkv_expanded: " << args.stride_s_dk_expanded << "\n"; - *log_file << "dk_ptr: " << args.dk_ptr << "\n"; - *log_file << "dv_ptr: " << args.dv_ptr << "\n"; - *log_file << "stride_b_dk: " << args.stride_b_dk << "\n"; - *log_file << "stride_h_dk: " << args.stride_h_dk << "\n"; - *log_file << "stride_s_dk: " << args.stride_s_dk << "\n"; - } - CK_FUSED_ATTN_TYPE_SWITCH_16BIT(args.dtype, CK_TILE_TYPE, - hipLaunchKernelGGL( - dk_dv_reduce, grid, block, 0, stream, - args.b, args.h, args.hg, args.s_kv, args.d_qk, - static_cast(args.dk_expanded_ptr), - static_cast(args.dv_expanded_ptr), - args.stride_b_dk_expanded, args.stride_h_dk_expanded, args.stride_s_dk_expanded, - static_cast(args.dk_ptr), - static_cast(args.dv_ptr), - args.stride_b_dk, args.stride_h_dk, args.stride_s_dk);); + launch_dkv_reduce( + "dk_dv_reduce", args, grid, dim3(args.d_qk), + args.dk_expanded_ptr, args.dv_expanded_ptr, + args.stride_b_dk_expanded, args.stride_h_dk_expanded, args.stride_s_dk_expanded, + args.dk_ptr, args.dv_ptr, + args.stride_b_dk, args.stride_h_dk, args.stride_s_dk, + args.d_qk, stream); } else { - dim3 block_dk(args.d_qk); - if (auto* log_file = get_ck_log_stream()) { - *log_file << "\n" << "run dk_or_dv_reduce on dk: " << "\n"; - *log_file << "dk_expanded_ptr: " << args.dk_expanded_ptr << "\n"; - *log_file << "stride_b_dk_expanded: " << args.stride_b_dk_expanded << "\n"; - *log_file << "stride_h_dk_expanded: " << args.stride_h_dk_expanded << "\n"; - *log_file << "stride_s_dk_expanded: " << args.stride_s_dk_expanded << "\n"; - *log_file << "dk_ptr: " << args.dk_ptr << "\n"; - *log_file << "stride_b_dk: " << args.stride_b_dk << "\n"; - *log_file << "stride_h_dk: " << args.stride_h_dk << "\n"; - *log_file << "stride_s_dk: " << args.stride_s_dk << "\n"; - } - CK_FUSED_ATTN_TYPE_SWITCH_16BIT(args.dtype, CK_TILE_TYPE, - hipLaunchKernelGGL( - dk_or_dv_reduce, grid, block_dk, 0, stream, - args.b, args.h, args.hg, args.s_kv, args.d_qk, - static_cast(args.dk_expanded_ptr), - args.stride_b_dk_expanded, args.stride_h_dk_expanded, args.stride_s_dk_expanded, - static_cast(args.dk_ptr), - args.stride_b_dk, args.stride_h_dk, args.stride_s_dk);); - - dim3 block_dv(args.d_v); - if (auto* log_file = get_ck_log_stream()) { - *log_file << "\n" << "run dk_or_dv_reduce on dv: " << "\n"; - *log_file << "dv_expanded_ptr: " << args.dv_expanded_ptr << "\n"; - *log_file << "stride_b_dv_expanded: " << args.stride_b_dv_expanded << "\n"; - *log_file << "stride_h_dv_expanded: " << args.stride_h_dv_expanded << "\n"; - *log_file << "stride_s_dv_expanded: " << args.stride_s_dv_expanded << "\n"; - *log_file << "dv_ptr: " << args.dv_ptr << "\n"; - *log_file << "stride_b_dv: " << args.stride_b_dv << "\n"; - *log_file << "stride_h_dv: " << args.stride_h_dv << "\n"; - *log_file << "stride_s_dv: " << args.stride_s_dv << "\n"; - } - CK_FUSED_ATTN_TYPE_SWITCH_16BIT(args.dtype, CK_TILE_TYPE, - hipLaunchKernelGGL( - dk_or_dv_reduce, grid, block_dv, 0, stream, - args.b, args.h, args.hg, args.s_kv, args.d_v, - static_cast(args.dv_expanded_ptr), - args.stride_b_dv_expanded, args.stride_h_dv_expanded, args.stride_s_dv_expanded, - static_cast(args.dv_ptr), - args.stride_b_dv, args.stride_h_dv, args.stride_s_dv);); + launch_dkv_reduce( + "dk_reduce", args, grid, dim3(args.d_qk), + args.dk_expanded_ptr, nullptr, + args.stride_b_dk_expanded, args.stride_h_dk_expanded, args.stride_s_dk_expanded, + args.dk_ptr, nullptr, + args.stride_b_dk, args.stride_h_dk, args.stride_s_dk, + args.d_qk, stream); + launch_dkv_reduce( + "dv_reduce", args, grid, dim3(args.d_v), + args.dv_expanded_ptr, nullptr, + args.stride_b_dv_expanded, args.stride_h_dv_expanded, args.stride_s_dv_expanded, + args.dv_ptr, nullptr, + args.stride_b_dv, args.stride_h_dv, args.stride_s_dv, + args.d_v, stream); } } } @@ -762,41 +557,11 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ dim3 block(THREADS_PER_BLOCK); dim3 grid(ceil(1.0 * args.s_q * args.s_kv / THREADS_PER_BLOCK)); if(bias_shape==BiasShape::k11SS){ - if (auto* log_file = get_ck_log_stream()) { - *log_file << "\n" << "run dbias_reduce_11SS: " << "\n"; - *log_file << "dbias_ptr: " << args.dbias_ptr << "\n"; - *log_file << "dbias_expanded_ptr: " << args.dbias_expanded_ptr << "\n"; - } - CK_FUSED_ATTN_TYPE_SWITCH_16BIT(args.dtype, CK_TILE_TYPE, - hipLaunchKernelGGL( - dbias_reduce_11ss, grid, block, 0, stream, - args.b, args.h, args.s_q, args.s_kv, - static_cast(args.dbias_expanded_ptr), - static_cast(args.dbias_ptr));); + launch_dbias_reduce("dbias_reduce_11SS", args, grid, block, stream); }else if(bias_shape==BiasShape::k1HSS){ - if (auto* log_file = get_ck_log_stream()) { - *log_file << "\n" << "run dbias_reduce_1HSS: " << "\n"; - *log_file << "dbias_ptr: " << args.dbias_ptr << "\n"; - *log_file << "dbias_expanded_ptr: " << args.dbias_expanded_ptr << "\n"; - } - CK_FUSED_ATTN_TYPE_SWITCH_16BIT(args.dtype, CK_TILE_TYPE, - hipLaunchKernelGGL( - dbias_reduce_1hss, grid, block, 0, stream, - args.b, args.h, args.s_q, args.s_kv, - static_cast(args.dbias_expanded_ptr), - static_cast(args.dbias_ptr));); + launch_dbias_reduce("dbias_reduce_1HSS", args, grid, block, stream); }else if(bias_shape==BiasShape::kB1SS){ - if (auto* log_file = get_ck_log_stream()) { - *log_file << "\n" << "run dbias_reduce_B1SS: " << "\n"; - *log_file << "dbias_ptr: " << args.dbias_ptr << "\n"; - *log_file << "dbias_expanded_ptr: " << args.dbias_expanded_ptr << "\n"; - } - CK_FUSED_ATTN_TYPE_SWITCH_16BIT(args.dtype, CK_TILE_TYPE, - hipLaunchKernelGGL( - dbias_reduce_b1ss, grid, block, 0, stream, - args.b, args.h, args.s_q, args.s_kv, - static_cast(args.dbias_expanded_ptr), - static_cast(args.dbias_ptr));); + launch_dbias_reduce("dbias_reduce_B1SS", args, grid, block, stream); } } return hipSuccess; diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.hpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.hpp index a926d230d..dc110c6f0 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.hpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.hpp @@ -9,6 +9,7 @@ #include #include +#include #include #include "ck_tile/host.hpp" #include "ck_fused_attn/ck_fused_attn.hpp" @@ -36,6 +37,27 @@ switch (dtype) { \ throw std::runtime_error("Invalid type for 16 bit.."); \ } +// Device-side dtype conversion helpers for the reduction kernels. +// ck_tile::bf16_t lacks implicit float conversions and needs explicit +// bf16_to_float / float_to_bf16; fp16/fp32 convert implicitly. +template +__device__ __forceinline__ float to_f32(T x) { + if constexpr (std::is_same_v) { + return ck_tile::bf16_to_float(x); + } else { + return static_cast(x); + } +} + +template +__device__ __forceinline__ T from_f32(float x) { + if constexpr (std::is_same_v) { + return ck_tile::float_to_bf16(x); + } else { + return static_cast(x); + } +} + // element-wise bias shape enum class BiasShape{ k11SS = 0,