Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 11 additions & 28 deletions transformer_engine/common/fused_attn_rocm/fused_attn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "fused_attn_ck.h"
#include "../common.h"
#include "../util/cuda_runtime.h" //cuda::sm_arch
#include "../util/system.h" //getenv
#include "utils.h"

// map NVTE_QKV_Layout to NVTE_QKV_Layout_Group
Expand Down Expand Up @@ -141,12 +142,9 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout) {
// causal_bottom_right, padding_causal_bottom_right | (-1, 0) or (>=0, 0)
std::pair<int64_t, int64_t> check_set_window_size(NVTE_Mask_Type attn_mask_type, std::pair<int64_t, int64_t> window_size){
//mask_type contain causal
bool nvte_log_fused_attn_config = false;
if (const char* env_p = std::getenv("NVTE_LOG_FUSED_ATTN_CONFIG") ) {
if (env_p != nullptr && std::string(env_p) == "1")
nvte_log_fused_attn_config = true;
}
if(attn_mask_type==NVTE_CAUSAL_MASK || attn_mask_type==NVTE_PADDING_CAUSAL_MASK || attn_mask_type==NVTE_CAUSAL_BOTTOM_RIGHT_MASK || attn_mask_type==NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK){
const bool nvte_log_fused_attn_config =
transformer_engine::getenv<bool>("NVTE_LOG_FUSED_ATTN_CONFIG");
if(transformer_engine::fused_attn_rocm::is_causal_mask(attn_mask_type)){
if(window_size==std::make_pair<int64_t, int64_t>(-1, -1) || (window_size.first >=0 && window_size.second!=0)){
//TODO: better INFO logging
if(nvte_log_fused_attn_config){
Expand Down Expand Up @@ -236,11 +234,8 @@ void log_fused_attn_config(
size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, int64_t window_size_right) {

//log the fused attn config at NVTE common level
bool nvte_log_fused_attn_config = false;
if (const char* env_p = std::getenv("NVTE_LOG_FUSED_ATTN_CONFIG") ) {
if (env_p != nullptr && std::string(env_p) == "1")
nvte_log_fused_attn_config = true;
}
const bool nvte_log_fused_attn_config =
transformer_engine::getenv<bool>("NVTE_LOG_FUSED_ATTN_CONFIG");
if(!nvte_log_fused_attn_config){
return;
}
Expand Down Expand Up @@ -297,24 +292,12 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
if (return_max_logit) return NVTE_Fused_Attn_Backend::NVTE_No_Backend;

// by default, fused attn is enabled
bool nvte_fused_attn = true;
if (const char* env_p = std::getenv("NVTE_FUSED_ATTN") ) {
if (env_p != nullptr && std::string(env_p) == "0")
nvte_fused_attn = false;
}
const bool nvte_fused_attn = getenv<bool>("NVTE_FUSED_ATTN", true);

// by default, both ck and aotriton backends are enabled by nvte_fused_attn
bool nvte_fused_attn_ck = nvte_fused_attn;
bool nvte_fused_attn_aotriton = nvte_fused_attn;

if (const char* env_p = std::getenv("NVTE_FUSED_ATTN_CK") ) {
if (env_p != nullptr && std::string(env_p) == "0")
nvte_fused_attn_ck = false;
}
if (const char* env_p = std::getenv("NVTE_FUSED_ATTN_AOTRITON") ) {
if (env_p != nullptr && std::string(env_p) == "0")
nvte_fused_attn_aotriton = false;
}
// by default, both ck and aotriton backends inherit the master toggle
const bool nvte_fused_attn_ck = nvte_fused_attn && getenv<bool>("NVTE_FUSED_ATTN_CK", true);
const bool nvte_fused_attn_aotriton =
nvte_fused_attn && getenv<bool>("NVTE_FUSED_ATTN_AOTRITON", true);

// fix the incompatible window size from upstream frameworks pytorch/jax
std::tie(window_size_left, window_size_right) = check_set_window_size(attn_mask_type, std::make_pair(window_size_left, window_size_right));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,11 +206,7 @@ void fused_attn_aotriton_fwd_impl(
std::array<uint64_t, 4>{1, 1, 1, 1},
dtype);

bool nvte_log_aotriton_config = false;
if (const char* env_p = std::getenv("NVTE_LOG_AOTRITON_CONFIG") ) {
if (env_p != nullptr && std::string(env_p) == "1")
nvte_log_aotriton_config = true;
}
const bool nvte_log_aotriton_config = getenv<bool>("NVTE_LOG_AOTRITON_CONFIG");
aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, dtype);
auto seed = mk_aoscalartensor(devPtrDropoutSeed);
auto offset1 = mk_aoscalartensor(devPtrDropoutOffset);
Expand Down Expand Up @@ -401,11 +397,7 @@ void fused_attn_aotriton_bwd_impl(
auto cu_seqlens_q = aotriton::TensorView<1>(reinterpret_cast<intptr_t>(devPtrCuSeqlensQ), cu_seqlens_shape, cu_seqlens_stride, aotriton::DType::kInt32);
auto cu_seqlens_k = aotriton::TensorView<1>(reinterpret_cast<intptr_t>(devPtrCuSeqlensKV), cu_seqlens_shape, cu_seqlens_stride, aotriton::DType::kInt32);

bool nvte_log_aotriton_config = false;
if (const char* env_p = std::getenv("NVTE_LOG_AOTRITON_CONFIG") ) {
if (env_p != nullptr && std::string(env_p) == "1")
nvte_log_aotriton_config = true;
}
const bool nvte_log_aotriton_config = getenv<bool>("NVTE_LOG_AOTRITON_CONFIG");
aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, dtype);
auto seed = mk_aoscalartensor(devPtrDropoutSeed);
auto offset = mk_aoscalartensor(devPtrDropoutOffset);
Expand Down Expand Up @@ -555,19 +547,8 @@ void fused_attn_aotriton_fwd(
&workspace_size,
stream);

if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
workspace->data.shape = {workspace_size};
workspace->data.dtype = DType::kByte;
return;
}
} else if (workspace_size == 0) {
workspace->data.shape = {1};
workspace->data.dtype = DType::kByte;
return;
} else {
NVTE_ERROR("Unexpected workspace_size.");
}
set_workspace_size(workspace, workspace_size);
return;
#else
NVTE_ERROR("AOTriton backend not compiled.");
#endif // USE_FUSED_ATTN_AOTRITON
Expand Down Expand Up @@ -620,19 +601,8 @@ void fused_attn_aotriton_bwd(
&workspace_size,
stream);

if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
workspace->data.shape = {workspace_size};
workspace->data.dtype = DType::kByte;
return;
}
} else if (workspace_size == 0) {
workspace->data.shape = {1};
workspace->data.dtype = DType::kByte;
return;
} else {
NVTE_ERROR("Unexpected workspace_size.");
}
set_workspace_size(workspace, workspace_size);
return;
#else
NVTE_ERROR("AOTriton backend not compiled.");
#endif // USE_FUSED_ATTN_AOTRITON
Expand Down
76 changes: 12 additions & 64 deletions transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,8 @@ bool is_ck_backend_supported(
#ifdef USE_FUSED_ATTN_CK

// debug info setting
bool nvte_log_ck_config = false;
if (const char* env_p = std::getenv("NVTE_LOG_CK_CONFIG") ) {
if (env_p != nullptr && std::string(env_p) == "1")
nvte_log_ck_config = true;
}

const bool nvte_log_ck_config = getenv<bool>("NVTE_LOG_CK_CONFIG");

// single filters

// filter based on num_heads and num_gqa_groups
Expand Down Expand Up @@ -90,10 +86,7 @@ bool is_ck_backend_supported(
}

// joint filter based on sliding window and attn_mask
bool is_causal = (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK||
attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK);
bool is_causal = is_causal_mask(attn_mask_type);
if(is_causal){
// causal mask window must be with causal top left or causal bottom right mask type
if (!((window_size_left ==-1 || window_size_left >=0) && window_size_right ==0 )){
Expand Down Expand Up @@ -137,9 +130,7 @@ bool is_ck_backend_supported(
// in NVTE, padding can happen in both THD format or BSHD/SBHD format
// For THD format, padding is natural
// For BSHD/SBHD, padding can be inferred by a cu_seqlen which shows the actual seqlen for each batch, while the dim(S) is the max_seqlen
bool is_padding = (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK);
bool is_padding = is_padding_mask(attn_mask_type);
if(is_ragged && !is_padding){
if(nvte_log_ck_config){
std::cout<<"Ragged QKV input requires padding mask"<<std::endl;
Expand Down Expand Up @@ -483,11 +474,7 @@ void fused_attn_ck_fwd_impl(
size_t *workspace_size,
cudaStream_t stream){

bool nvte_log_ck_config = false;
if (const char* env_p = std::getenv("NVTE_LOG_CK_CONFIG") ) {
if (env_p != nullptr && std::string(env_p) == "1")
nvte_log_ck_config = true;
}
const bool nvte_log_ck_config = getenv<bool>("NVTE_LOG_CK_CONFIG");

bool nvte_ck_uses_fwd_v3 = getenv<int>("NVTE_CK_USES_FWD_V3", 1);
int nvte_ck_how_v3_bf16_cvt = getenv<int>("NVTE_CK_HOW_V3_BF16_CVT", 1);
Expand All @@ -497,9 +484,7 @@ void fused_attn_ck_fwd_impl(
bool is_SBHD = qkv_format==NVTE_QKV_Format::NVTE_SBHD || qkv_format==NVTE_QKV_Format::NVTE_SBHD_2BSHD;
bool is_BSHD = qkv_format==NVTE_QKV_Format::NVTE_BSHD;

bool is_padding = (mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK ||
mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK);
bool is_padding = is_padding_mask(mask_type);
bool bshd_to_thd = is_BSHD && is_padding;

// extract the qkv and o storage bytes to allocate buffer for padding removing
Expand Down Expand Up @@ -730,11 +715,7 @@ void fused_attn_ck_bwd_impl(
size_t *workspace_size,
cudaStream_t stream) {

bool nvte_log_ck_config = false;
if (const char* env_p = std::getenv("NVTE_LOG_CK_CONFIG") ) {
if (env_p != nullptr && std::string(env_p) == "1")
nvte_log_ck_config = true;
}
const bool nvte_log_ck_config = getenv<bool>("NVTE_LOG_CK_CONFIG");
// bwd v3 is optional by enabling the following envs
// default values follows the ck example setting
bool nvte_ck_uses_bwd_v3 = getenv<int>("NVTE_CK_USES_BWD_V3", 1);
Expand All @@ -750,9 +731,7 @@ void fused_attn_ck_bwd_impl(
bool is_ragged = qkv_format==NVTE_QKV_Format::NVTE_THD;
bool is_SBHD = qkv_format==NVTE_QKV_Format::NVTE_SBHD || qkv_format==NVTE_QKV_Format::NVTE_SBHD_2BSHD;
bool is_BSHD = qkv_format==NVTE_QKV_Format::NVTE_BSHD;
bool is_padding = (mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK ||
mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK);
bool is_padding = is_padding_mask(mask_type);
bool bshd_to_thd = is_BSHD && is_padding;
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout);

Expand Down Expand Up @@ -1202,10 +1181,6 @@ void fused_attn_ck_fwd(
}
size_t workspace_size = 0;

bool is_padding = (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK);

fused_attn_ck_fwd_impl(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, bias_b, bias_h,
max_tokens_q, max_tokens_kv,
Expand All @@ -1224,19 +1199,8 @@ void fused_attn_ck_fwd(
&workspace_size,
stream);

if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
workspace->data.shape = {workspace_size};
workspace->data.dtype = DType::kByte;
return;
}
} else if (workspace_size == 0) {
workspace->data.shape = {1};
workspace->data.dtype = DType::kByte;
return;
} else {
NVTE_ERROR("Unexpected workspace_size.");
}
set_workspace_size(workspace, workspace_size);
return;
#else
NVTE_ERROR("CK fused attn backend not compiled.");
#endif // USE_FUSED_ATTN_CK
Expand Down Expand Up @@ -1290,11 +1254,6 @@ void fused_attn_ck_bwd(

size_t workspace_size = 0;

bool is_ragged = nvte_get_qkv_format(qkv_layout)==NVTE_QKV_Format::NVTE_THD;
bool is_padding = (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK);

// extract the max_tokens for padding/unpadding and softmax_lse buffer
// b from cu_seqlen and max_seqlen are not the actual storage batch and seqlen for pad_between_seqs case
size_t max_tokens_q = std::accumulate((input_Q->data).shape.begin(), (input_Q->data).shape.end(), static_cast<size_t>(1), std::multiplies<size_t>())/h_q/d_qk;
Expand All @@ -1321,19 +1280,8 @@ void fused_attn_ck_bwd(
&workspace_size,
stream);

if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
workspace->data.shape = {workspace_size};
workspace->data.dtype = DType::kByte;
return;
}
} else if (workspace_size == 0) {
workspace->data.shape = {1};
workspace->data.dtype = DType::kByte;
return;
} else {
NVTE_ERROR("Unexpected workspace_size.");
}
set_workspace_size(workspace, workspace_size);
return;
#else
NVTE_ERROR("CK fused attn backend not compiled.");
#endif // USE_FUSED_ATTN_CK
Expand Down
29 changes: 29 additions & 0 deletions transformer_engine/common/fused_attn_rocm/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,35 @@ namespace fused_attn_rocm {

using namespace transformer_engine;

bool is_padding_mask(NVTE_Mask_Type mask_type){
return mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK ||
mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK;
}

bool is_causal_mask(NVTE_Mask_Type mask_type){
return mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK ||
mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK ||
mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK;
}

void set_workspace_size(Tensor *workspace, size_t workspace_size){
// workspace_size is unsigned, so the only cases are >0 and ==0.
if(workspace_size > 0){
if(workspace->data.dptr == nullptr){
// sizing pass: request allocation
workspace->data.shape = {workspace_size};
workspace->data.dtype = DType::kByte;
}
// execution pass (dptr != nullptr): kernel already ran, nothing to do
return;
}
// workspace_size == 0: report a 1-byte placeholder
workspace->data.shape = {1};
workspace->data.dtype = DType::kByte;
}

size_t nvte_dtype_size(DType t_dtype){
switch(t_dtype){
case DType::kByte:
Expand Down
12 changes: 12 additions & 0 deletions transformer_engine/common/fused_attn_rocm/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include "transformer_engine/fused_attn.h"
#include "transformer_engine/transformer_engine.h"
#include "../common.h"


namespace transformer_engine {
Expand All @@ -28,6 +29,17 @@ enum NVTE_QKV_Matrix {
NVTE_O_Matrix = 3, // final output
};

// mask-class predicates shared across backend support checks and dispatch glue.
// Kept here so a new padding/causal mask variant only needs updating in one place.
bool is_padding_mask(NVTE_Mask_Type mask_type);
bool is_causal_mask(NVTE_Mask_Type mask_type);

// Finalize a fused-attn workspace tensor from a computed size. In the sizing pass
// (workspace->data.dptr == nullptr) the shape/dtype are set so the framework can
// allocate; a size of 0 is reported as a 1-byte placeholder. Mirrors the boilerplate
// previously duplicated across every backend fwd/bwd entry point.
void set_workspace_size(Tensor *workspace, size_t workspace_size);

void generateMatrixStrides(
uint64_t b, uint64_t h,
uint64_t s_q, uint64_t s_kv,
Expand Down
Loading