From 270357f5a53cbf5963320c26fdf97c7c3cce2b41 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Mon, 8 Jun 2026 16:28:55 +0000 Subject: [PATCH] Minor refactoring of workspace setting and env getting --- .../common/fused_attn_rocm/fused_attn.cpp | 39 +++------- .../fused_attn_rocm/fused_attn_aotriton.cpp | 42 ++-------- .../common/fused_attn_rocm/fused_attn_ck.cpp | 76 +++---------------- .../common/fused_attn_rocm/utils.cpp | 29 +++++++ .../common/fused_attn_rocm/utils.h | 12 +++ 5 files changed, 70 insertions(+), 128 deletions(-) diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp index 943fa1697..b621d7174 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp @@ -12,6 +12,7 @@ #include "fused_attn_ck.h" #include "../common.h" #include "../util/cuda_runtime.h" //cuda::sm_arch +#include "../util/system.h" //getenv #include "utils.h" // map NVTE_QKV_Layout to NVTE_QKV_Layout_Group @@ -141,12 +142,9 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout) { // causal_bottom_right, padding_causal_bottom_right | (-1, 0) or (>=0, 0) std::pair check_set_window_size(NVTE_Mask_Type attn_mask_type, std::pair window_size){ //mask_type contain causal - bool nvte_log_fused_attn_config = false; - if (const char* env_p = std::getenv("NVTE_LOG_FUSED_ATTN_CONFIG") ) { - if (env_p != nullptr && std::string(env_p) == "1") - nvte_log_fused_attn_config = true; - } - if(attn_mask_type==NVTE_CAUSAL_MASK || attn_mask_type==NVTE_PADDING_CAUSAL_MASK || attn_mask_type==NVTE_CAUSAL_BOTTOM_RIGHT_MASK || attn_mask_type==NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK){ + const bool nvte_log_fused_attn_config = + transformer_engine::getenv("NVTE_LOG_FUSED_ATTN_CONFIG"); + if(transformer_engine::fused_attn_rocm::is_causal_mask(attn_mask_type)){ if(window_size==std::make_pair(-1, -1) || (window_size.first >=0 && window_size.second!=0)){ //TODO: better INFO logging if(nvte_log_fused_attn_config){ @@ -236,11 +234,8 @@ void log_fused_attn_config( size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, int64_t window_size_right) { //log the fused attn config at NVTE common level - bool nvte_log_fused_attn_config = false; - if (const char* env_p = std::getenv("NVTE_LOG_FUSED_ATTN_CONFIG") ) { - if (env_p != nullptr && std::string(env_p) == "1") - nvte_log_fused_attn_config = true; - } + const bool nvte_log_fused_attn_config = + transformer_engine::getenv("NVTE_LOG_FUSED_ATTN_CONFIG"); if(!nvte_log_fused_attn_config){ return; } @@ -297,24 +292,12 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( if (return_max_logit) return NVTE_Fused_Attn_Backend::NVTE_No_Backend; // by default, fused attn is enabled - bool nvte_fused_attn = true; - if (const char* env_p = std::getenv("NVTE_FUSED_ATTN") ) { - if (env_p != nullptr && std::string(env_p) == "0") - nvte_fused_attn = false; - } + const bool nvte_fused_attn = getenv("NVTE_FUSED_ATTN", true); - // by default, both ck and aotriton backends are enabled by nvte_fused_attn - bool nvte_fused_attn_ck = nvte_fused_attn; - bool nvte_fused_attn_aotriton = nvte_fused_attn; - - if (const char* env_p = std::getenv("NVTE_FUSED_ATTN_CK") ) { - if (env_p != nullptr && std::string(env_p) == "0") - nvte_fused_attn_ck = false; - } - if (const char* env_p = std::getenv("NVTE_FUSED_ATTN_AOTRITON") ) { - if (env_p != nullptr && std::string(env_p) == "0") - nvte_fused_attn_aotriton = false; - } + // by default, both ck and aotriton backends inherit the master toggle + const bool nvte_fused_attn_ck = nvte_fused_attn && getenv("NVTE_FUSED_ATTN_CK", true); + const bool nvte_fused_attn_aotriton = + nvte_fused_attn && getenv("NVTE_FUSED_ATTN_AOTRITON", true); // fix the incompatible window size from upstream frameworks pytorch/jax std::tie(window_size_left, window_size_right) = check_set_window_size(attn_mask_type, std::make_pair(window_size_left, window_size_right)); diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp index 33557d214..2928520fc 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp @@ -206,11 +206,7 @@ void fused_attn_aotriton_fwd_impl( std::array{1, 1, 1, 1}, dtype); - bool nvte_log_aotriton_config = false; - if (const char* env_p = std::getenv("NVTE_LOG_AOTRITON_CONFIG") ) { - if (env_p != nullptr && std::string(env_p) == "1") - nvte_log_aotriton_config = true; - } + const bool nvte_log_aotriton_config = getenv("NVTE_LOG_AOTRITON_CONFIG"); aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, dtype); auto seed = mk_aoscalartensor(devPtrDropoutSeed); auto offset1 = mk_aoscalartensor(devPtrDropoutOffset); @@ -401,11 +397,7 @@ void fused_attn_aotriton_bwd_impl( auto cu_seqlens_q = aotriton::TensorView<1>(reinterpret_cast(devPtrCuSeqlensQ), cu_seqlens_shape, cu_seqlens_stride, aotriton::DType::kInt32); auto cu_seqlens_k = aotriton::TensorView<1>(reinterpret_cast(devPtrCuSeqlensKV), cu_seqlens_shape, cu_seqlens_stride, aotriton::DType::kInt32); - bool nvte_log_aotriton_config = false; - if (const char* env_p = std::getenv("NVTE_LOG_AOTRITON_CONFIG") ) { - if (env_p != nullptr && std::string(env_p) == "1") - nvte_log_aotriton_config = true; - } + const bool nvte_log_aotriton_config = getenv("NVTE_LOG_AOTRITON_CONFIG"); aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, dtype); auto seed = mk_aoscalartensor(devPtrDropoutSeed); auto offset = mk_aoscalartensor(devPtrDropoutOffset); @@ -555,19 +547,8 @@ void fused_attn_aotriton_fwd( &workspace_size, stream); - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } else { - NVTE_ERROR("Unexpected workspace_size."); - } + set_workspace_size(workspace, workspace_size); + return; #else NVTE_ERROR("AOTriton backend not compiled."); #endif // USE_FUSED_ATTN_AOTRITON @@ -620,19 +601,8 @@ void fused_attn_aotriton_bwd( &workspace_size, stream); - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } else { - NVTE_ERROR("Unexpected workspace_size."); - } + set_workspace_size(workspace, workspace_size); + return; #else NVTE_ERROR("AOTriton backend not compiled."); #endif // USE_FUSED_ATTN_AOTRITON diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp index 597387555..0c7d80a6a 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp @@ -38,12 +38,8 @@ bool is_ck_backend_supported( #ifdef USE_FUSED_ATTN_CK // debug info setting - bool nvte_log_ck_config = false; - if (const char* env_p = std::getenv("NVTE_LOG_CK_CONFIG") ) { - if (env_p != nullptr && std::string(env_p) == "1") - nvte_log_ck_config = true; - } - + const bool nvte_log_ck_config = getenv("NVTE_LOG_CK_CONFIG"); + // single filters // filter based on num_heads and num_gqa_groups @@ -90,10 +86,7 @@ bool is_ck_backend_supported( } // joint filter based on sliding window and attn_mask - bool is_causal = (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK|| - attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK|| - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); + bool is_causal = is_causal_mask(attn_mask_type); if(is_causal){ // causal mask window must be with causal top left or causal bottom right mask type if (!((window_size_left ==-1 || window_size_left >=0) && window_size_right ==0 )){ @@ -137,9 +130,7 @@ bool is_ck_backend_supported( // in NVTE, padding can happen in both THD format or BSHD/SBHD format // For THD format, padding is natural // For BSHD/SBHD, padding can be inferred by a cu_seqlen which shows the actual seqlen for each batch, while the dim(S) is the max_seqlen - bool is_padding = (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); + bool is_padding = is_padding_mask(attn_mask_type); if(is_ragged && !is_padding){ if(nvte_log_ck_config){ std::cout<<"Ragged QKV input requires padding mask"<("NVTE_LOG_CK_CONFIG"); bool nvte_ck_uses_fwd_v3 = getenv("NVTE_CK_USES_FWD_V3", 1); int nvte_ck_how_v3_bf16_cvt = getenv("NVTE_CK_HOW_V3_BF16_CVT", 1); @@ -497,9 +484,7 @@ void fused_attn_ck_fwd_impl( bool is_SBHD = qkv_format==NVTE_QKV_Format::NVTE_SBHD || qkv_format==NVTE_QKV_Format::NVTE_SBHD_2BSHD; bool is_BSHD = qkv_format==NVTE_QKV_Format::NVTE_BSHD; - bool is_padding = (mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || - mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || - mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); + bool is_padding = is_padding_mask(mask_type); bool bshd_to_thd = is_BSHD && is_padding; // extract the qkv and o storage bytes to allocate buffer for padding removing @@ -730,11 +715,7 @@ void fused_attn_ck_bwd_impl( size_t *workspace_size, cudaStream_t stream) { - bool nvte_log_ck_config = false; - if (const char* env_p = std::getenv("NVTE_LOG_CK_CONFIG") ) { - if (env_p != nullptr && std::string(env_p) == "1") - nvte_log_ck_config = true; - } + const bool nvte_log_ck_config = getenv("NVTE_LOG_CK_CONFIG"); // bwd v3 is optional by enabling the following envs // default values follows the ck example setting bool nvte_ck_uses_bwd_v3 = getenv("NVTE_CK_USES_BWD_V3", 1); @@ -750,9 +731,7 @@ void fused_attn_ck_bwd_impl( bool is_ragged = qkv_format==NVTE_QKV_Format::NVTE_THD; bool is_SBHD = qkv_format==NVTE_QKV_Format::NVTE_SBHD || qkv_format==NVTE_QKV_Format::NVTE_SBHD_2BSHD; bool is_BSHD = qkv_format==NVTE_QKV_Format::NVTE_BSHD; - bool is_padding = (mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || - mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || - mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); + bool is_padding = is_padding_mask(mask_type); bool bshd_to_thd = is_BSHD && is_padding; NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout); @@ -1202,10 +1181,6 @@ void fused_attn_ck_fwd( } size_t workspace_size = 0; - bool is_padding = (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); - fused_attn_ck_fwd_impl( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, bias_b, bias_h, max_tokens_q, max_tokens_kv, @@ -1224,19 +1199,8 @@ void fused_attn_ck_fwd( &workspace_size, stream); - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } else { - NVTE_ERROR("Unexpected workspace_size."); - } + set_workspace_size(workspace, workspace_size); + return; #else NVTE_ERROR("CK fused attn backend not compiled."); #endif // USE_FUSED_ATTN_CK @@ -1290,11 +1254,6 @@ void fused_attn_ck_bwd( size_t workspace_size = 0; - bool is_ragged = nvte_get_qkv_format(qkv_layout)==NVTE_QKV_Format::NVTE_THD; - bool is_padding = (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); - // extract the max_tokens for padding/unpadding and softmax_lse buffer // b from cu_seqlen and max_seqlen are not the actual storage batch and seqlen for pad_between_seqs case size_t max_tokens_q = std::accumulate((input_Q->data).shape.begin(), (input_Q->data).shape.end(), static_cast(1), std::multiplies())/h_q/d_qk; @@ -1321,19 +1280,8 @@ void fused_attn_ck_bwd( &workspace_size, stream); - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } else { - NVTE_ERROR("Unexpected workspace_size."); - } + set_workspace_size(workspace, workspace_size); + return; #else NVTE_ERROR("CK fused attn backend not compiled."); #endif // USE_FUSED_ATTN_CK diff --git a/transformer_engine/common/fused_attn_rocm/utils.cpp b/transformer_engine/common/fused_attn_rocm/utils.cpp index 813003364..2261ef607 100644 --- a/transformer_engine/common/fused_attn_rocm/utils.cpp +++ b/transformer_engine/common/fused_attn_rocm/utils.cpp @@ -13,6 +13,35 @@ namespace fused_attn_rocm { using namespace transformer_engine; +bool is_padding_mask(NVTE_Mask_Type mask_type){ + return mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || + mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || + mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK; +} + +bool is_causal_mask(NVTE_Mask_Type mask_type){ + return mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || + mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || + mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK || + mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK; +} + +void set_workspace_size(Tensor *workspace, size_t workspace_size){ + // workspace_size is unsigned, so the only cases are >0 and ==0. + if(workspace_size > 0){ + if(workspace->data.dptr == nullptr){ + // sizing pass: request allocation + workspace->data.shape = {workspace_size}; + workspace->data.dtype = DType::kByte; + } + // execution pass (dptr != nullptr): kernel already ran, nothing to do + return; + } + // workspace_size == 0: report a 1-byte placeholder + workspace->data.shape = {1}; + workspace->data.dtype = DType::kByte; +} + size_t nvte_dtype_size(DType t_dtype){ switch(t_dtype){ case DType::kByte: diff --git a/transformer_engine/common/fused_attn_rocm/utils.h b/transformer_engine/common/fused_attn_rocm/utils.h index 626cad131..6cc9a1430 100644 --- a/transformer_engine/common/fused_attn_rocm/utils.h +++ b/transformer_engine/common/fused_attn_rocm/utils.h @@ -14,6 +14,7 @@ #include "transformer_engine/fused_attn.h" #include "transformer_engine/transformer_engine.h" +#include "../common.h" namespace transformer_engine { @@ -28,6 +29,17 @@ enum NVTE_QKV_Matrix { NVTE_O_Matrix = 3, // final output }; +// mask-class predicates shared across backend support checks and dispatch glue. +// Kept here so a new padding/causal mask variant only needs updating in one place. +bool is_padding_mask(NVTE_Mask_Type mask_type); +bool is_causal_mask(NVTE_Mask_Type mask_type); + +// Finalize a fused-attn workspace tensor from a computed size. In the sizing pass +// (workspace->data.dptr == nullptr) the shape/dtype are set so the framework can +// allocate; a size of 0 is reported as a 1-byte placeholder. Mirrors the boilerplate +// previously duplicated across every backend fwd/bwd entry point. +void set_workspace_size(Tensor *workspace, size_t workspace_size); + void generateMatrixStrides( uint64_t b, uint64_t h, uint64_t s_q, uint64_t s_kv,