From 9de7f8f6ebe926d13937da301d74ec2ce962533f Mon Sep 17 00:00:00 2001 From: bolunz Date: Thu, 28 May 2026 03:08:30 +0000 Subject: [PATCH 1/3] feat: add MLASelfAttention Module --- .../transformer/causal_self_attention.h | 7 - .../modules/transformer/mla_self_attention.h | 50 +++++ .../include/nn/modules/transformer/utils.h | 6 + .../transformer/causal_self_attention.cc | 38 +--- .../modules/transformer/mla_self_attention.cc | 185 ++++++++++++++++++ .../src/nn/modules/transformer/utils.cc | 38 ++++ .../test_transformer_architecture.cc | 25 +++ 7 files changed, 305 insertions(+), 44 deletions(-) create mode 100644 infini_train/include/nn/modules/transformer/mla_self_attention.h create mode 100644 infini_train/src/nn/modules/transformer/mla_self_attention.cc diff --git a/infini_train/include/nn/modules/transformer/causal_self_attention.h b/infini_train/include/nn/modules/transformer/causal_self_attention.h index 5ac55e31..7a96714f 100644 --- a/infini_train/include/nn/modules/transformer/causal_self_attention.h +++ b/infini_train/include/nn/modules/transformer/causal_self_attention.h @@ -1,7 +1,6 @@ #pragma once #include -#include #include #include "infini_train/include/nn/modules/module.h" @@ -43,12 +42,6 @@ class CausalSelfAttention : public infini_train::nn::CloneableModule> ForwardWithRoPE(const std::vector> &x); - // RoPE helper methods - std::tuple, std::shared_ptr> - ApplyRotaryEmbedding(const std::shared_ptr &xq, - const std::shared_ptr &xk, - const std::shared_ptr &freqs_cis); - // GQA helper method std::shared_ptr RepeatKV(const std::shared_ptr &x, int64_t n_rep); }; diff --git a/infini_train/include/nn/modules/transformer/mla_self_attention.h b/infini_train/include/nn/modules/transformer/mla_self_attention.h new file mode 100644 index 00000000..b4419e43 --- /dev/null +++ b/infini_train/include/nn/modules/transformer/mla_self_attention.h @@ -0,0 +1,50 @@ +#pragma once + +#include +#include + +#include "infini_train/include/nn/modules/module.h" +#include "infini_train/include/nn/modules/transformer/transformer_config.h" + +namespace infini_train::nn { + +class MLASelfAttention : public infini_train::nn::CloneableModule { +public: + static constexpr char kType[] = "MLASelfAttention"; + + static constexpr char kQAProjLayerName[] = "q_a_proj"; + static constexpr char kQANormLayerName[] = "q_a_layernorm"; + static constexpr char kQBProjLayerName[] = "q_b_proj"; + static constexpr char kKVAProjLayerName[] = "kv_a_proj_with_mqa"; + static constexpr char kKVANormLayerName[] = "kv_a_layernorm"; + static constexpr char kKVBProjLayerName[] = "kv_b_proj"; + static constexpr char kCProjLayerName[] = "c_proj"; + + static constexpr char kParamBiasName[] = "bias"; + + explicit MLASelfAttention(const TransformerConfig &config); + MLASelfAttention(const TransformerConfig &config, int64_t q_lora_rank, int64_t kv_lora_rank, + int64_t qk_nope_head_dim, int64_t qk_rope_head_dim, int64_t v_head_dim); + + std::vector> + Forward(const std::vector> &x) override; + +private: + TransformerConfig config_; + int64_t n_head_ = 0; + int64_t n_embd_ = 0; + int64_t local_n_head_ = 0; + + int64_t q_lora_rank_ = 0; + int64_t kv_lora_rank_ = 0; + int64_t qk_nope_head_dim_ = 0; + int64_t qk_rope_head_dim_ = 0; + int64_t qk_head_dim_ = 0; + int64_t v_head_dim_ = 0; + + void SetupAttention(const TransformerConfig &config, int64_t q_lora_rank, int64_t kv_lora_rank, + int64_t qk_nope_head_dim, int64_t qk_rope_head_dim, int64_t v_head_dim); + +}; + +} // namespace infini_train::nn diff --git a/infini_train/include/nn/modules/transformer/utils.h b/infini_train/include/nn/modules/transformer/utils.h index d3a62c63..30db08e6 100644 --- a/infini_train/include/nn/modules/transformer/utils.h +++ b/infini_train/include/nn/modules/transformer/utils.h @@ -1,6 +1,8 @@ #pragma once #include +#include +#include #include "infini_train/include/tensor.h" @@ -8,4 +10,8 @@ namespace infini_train { // RoPE helper method std::shared_ptr PrecomputeFreqsCis(int64_t dim, int64_t end, float theta = 10000.0f, bool use_scaled = false, Device device = Device()); + +std::tuple, std::shared_ptr> +ApplyRotaryEmbedding(const std::shared_ptr &xq, const std::shared_ptr &xk, + const std::shared_ptr &freqs_cis); } // namespace infini_train diff --git a/infini_train/src/nn/modules/transformer/causal_self_attention.cc b/infini_train/src/nn/modules/transformer/causal_self_attention.cc index 5ea9eec5..7320ca12 100644 --- a/infini_train/src/nn/modules/transformer/causal_self_attention.cc +++ b/infini_train/src/nn/modules/transformer/causal_self_attention.cc @@ -12,6 +12,7 @@ #include "infini_train/include/nn/modules/normalization.h" #include "infini_train/include/nn/modules/sparse.h" #include "infini_train/include/nn/modules/transformer/transformer_config.h" +#include "infini_train/include/nn/modules/transformer/utils.h" #include "infini_train/include/nn/parallel/global.h" #include "infini_train/include/nn/parallel/tensor_parallel.h" #include "infini_train/include/tensor.h" @@ -130,43 +131,6 @@ CausalSelfAttention::ForwardStandard(const std::vector, std::shared_ptr> -CausalSelfAttention::ApplyRotaryEmbedding(const std::shared_ptr &xq, - const std::shared_ptr &xk, - const std::shared_ptr &freqs_cis) { - // Reshape freqs_cis for broadcasting - const auto &x_shape = xq->Dims(); // (B, T, H, D) - const int64_t T = x_shape[1]; - const int64_t D = x_shape[3]; - - std::vector target_shape = {1, T, 1, D / 2, 2}; - auto cos_sin = freqs_cis->View(target_shape); // -> (1, T, 1, D/2, 2) - - auto cos = cos_sin->Slice(-1, 0, 1, 1)->Squeeze(-1); // (1, T, 1, D/2) - auto sin = cos_sin->Slice(-1, 1, 2, 1)->Squeeze(-1); // (1, T, 1, D/2) - - auto slice_pair = [](const std::shared_ptr &x) { - auto even = x->Slice(-1, 0, x->Dims().back(), 2); - auto odd = x->Slice(-1, 1, x->Dims().back(), 2); - return std::make_pair(even, odd); - }; - - auto [q_even, q_odd] = slice_pair(xq); - auto q_rotated_left = q_even * cos - q_odd * sin; - auto q_rotated_right = q_even * sin + q_odd * cos; - auto q_rotated - = nn::function::Stack(std::vector>{q_rotated_left, q_rotated_right}, -1)->Flatten(-2); - - auto [k_even, k_odd] = slice_pair(xk); - auto k_rotated_left = k_even * cos - k_odd * sin; - auto k_rotated_right = k_even * sin + k_odd * cos; - auto k_rotated - = nn::function::Stack(std::vector>{k_rotated_left, k_rotated_right}, -1)->Flatten(-2); - - return {q_rotated, k_rotated}; -} - std::shared_ptr CausalSelfAttention::RepeatKV(const std::shared_ptr &x, int64_t n_rep) { const auto &shape = x->Dims(); diff --git a/infini_train/src/nn/modules/transformer/mla_self_attention.cc b/infini_train/src/nn/modules/transformer/mla_self_attention.cc new file mode 100644 index 00000000..097cf830 --- /dev/null +++ b/infini_train/src/nn/modules/transformer/mla_self_attention.cc @@ -0,0 +1,185 @@ +#include "infini_train/include/nn/modules/transformer/mla_self_attention.h" + +#include +#include +#include +#include +#include + +#include "glog/logging.h" + +#include "infini_train/include/nn/functional.h" +#include "infini_train/include/nn/modules/linear.h" +#include "infini_train/include/nn/modules/normalization.h" +#include "infini_train/include/nn/modules/transformer/transformer_config.h" +#include "infini_train/include/nn/modules/transformer/utils.h" +#include "infini_train/include/nn/parallel/global.h" +#include "infini_train/include/nn/parallel/tensor_parallel.h" +#include "infini_train/include/nn/parallel/utils.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::nn { +namespace { +int64_t DefaultQKVHeadDim(const TransformerConfig &config) { + CHECK_EQ(config.n_embd % config.n_head, 0) << "n_embd must be divisible by n_head"; + return config.n_embd / config.n_head; +} + +int64_t DefaultQKRoPEHeadDim(const TransformerConfig &config) { + return DefaultQKVHeadDim(config); +} + +int64_t DefaultQKNoPEHeadDim(const TransformerConfig &config) { + return DefaultQKVHeadDim(config); +} +} // namespace + +MLASelfAttention::MLASelfAttention(const TransformerConfig &config) + : MLASelfAttention(config, + /*q_lora_rank=*/config.n_embd, + /*kv_lora_rank=*/config.n_embd, + /*qk_nope_head_dim=*/DefaultQKNoPEHeadDim(config), + /*qk_rope_head_dim=*/DefaultQKRoPEHeadDim(config), + /*v_head_dim=*/DefaultQKVHeadDim(config)) {} + +MLASelfAttention::MLASelfAttention(const TransformerConfig &config, int64_t q_lora_rank, int64_t kv_lora_rank, + int64_t qk_nope_head_dim, int64_t qk_rope_head_dim, int64_t v_head_dim) + : CloneableModule(kType), config_(config) { + SetupAttention(config, q_lora_rank, kv_lora_rank, qk_nope_head_dim, qk_rope_head_dim, v_head_dim); + + modules_[kQAProjLayerName] = std::make_shared( + /*in_features=*/n_embd_, + /*out_features=*/q_lora_rank_, + /*bias=*/config_.add_bias_linear); + modules_[kQANormLayerName] = std::make_shared(q_lora_rank_, config_.norm_eps); + modules_[kQBProjLayerName] = std::make_shared( + /*in_features=*/q_lora_rank_, + /*out_features=*/n_head_ * qk_head_dim_, + /*bias=*/config_.add_bias_linear, + /*gather_output=*/false, + /*input_is_parallel=*/false, + /*skip_bias_add=*/false, + /*sequence_parallel=*/nn::parallel::global::GetSequenceParallelEnabled()); + + modules_[kKVAProjLayerName] = std::make_shared( + /*in_features=*/n_embd_, + /*out_features=*/kv_lora_rank_ + qk_rope_head_dim_, + /*bias=*/config_.add_bias_linear); + modules_[kKVANormLayerName] = std::make_shared(kv_lora_rank_, config_.norm_eps); + modules_[kKVBProjLayerName] = std::make_shared( + /*in_features=*/kv_lora_rank_, + /*out_features=*/n_head_ * (qk_nope_head_dim_ + v_head_dim_), + /*bias=*/config_.add_bias_linear, + /*gather_output=*/false, + /*input_is_parallel=*/false, + /*skip_bias_add=*/false, + /*sequence_parallel=*/nn::parallel::global::GetSequenceParallelEnabled()); + + modules_[kCProjLayerName] = std::make_shared( + /*in_features=*/n_head_ * v_head_dim_, + /*out_features=*/n_embd_, + /*bias=*/config_.add_bias_linear, + /*reduce_output=*/true, + /*input_is_parallel=*/true, + /*skip_bias_add=*/false, + /*sequence_parallel=*/nn::parallel::global::GetSequenceParallelEnabled()); + + buffers_[kParamBiasName] = function::Tril(nn::function::Ones({config_.block_size, config_.block_size})) + ->View({1, 1, config_.block_size, config_.block_size}); +} + +void MLASelfAttention::SetupAttention(const TransformerConfig &config, int64_t q_lora_rank, int64_t kv_lora_rank, + int64_t qk_nope_head_dim, int64_t qk_rope_head_dim, int64_t v_head_dim) { + auto tp_world_size = nn::parallel::global::GetTensorParallelSize(); + + CHECK_EQ(config.n_embd % config.n_head, 0) << "n_embd must be divisible by n_head"; + CHECK_EQ(config.n_head % tp_world_size, 0) << "n_head must be divisible by TP world size"; + CHECK_GT(q_lora_rank, 0) << "q_lora_rank must be positive"; + CHECK_GT(kv_lora_rank, 0) << "kv_lora_rank must be positive"; + CHECK_GT(qk_nope_head_dim, 0) << "qk_nope_head_dim must be positive"; + CHECK_GT(qk_rope_head_dim, 0) << "qk_rope_head_dim must be positive"; + CHECK_GT(v_head_dim, 0) << "v_head_dim must be positive"; + CHECK_EQ(qk_rope_head_dim % 2, 0) << "qk_rope_head_dim must be even for RoPE"; + + n_head_ = config.n_head; + n_embd_ = config.n_embd; + local_n_head_ = n_head_ / tp_world_size; + + q_lora_rank_ = q_lora_rank; + kv_lora_rank_ = kv_lora_rank; + qk_nope_head_dim_ = qk_nope_head_dim; + qk_rope_head_dim_ = qk_rope_head_dim; + qk_head_dim_ = qk_nope_head_dim_ + qk_rope_head_dim_; + v_head_dim_ = v_head_dim; +} + +std::vector> +MLASelfAttention::Forward(const std::vector> &x) { + CHECK_GE(x.size(), 1) << "MLASelfAttention expects at least hidden states"; + + const auto B = x[0]->Dims()[0]; + const auto C = x[0]->Dims()[2]; + CHECK_EQ(C, n_embd_) << "hidden size must match n_embd"; + + const auto freqs_cis = x.size() > 1 ? x[1] : nullptr; + const auto external_mask = x.size() > 3 ? x[3] : nullptr; + if (config_.attention_type == AttentionType::kRoPE) { + CHECK(freqs_cis != nullptr) << "freqs_cis is null."; + } + + // (B, T, C) -> q_a -> RMSNorm -> q_b -> (B, T, H_local * (D_nope + D_rope)) + auto q = (*modules_[kQAProjLayerName])({x[0]})[0]; + q = (*modules_[kQANormLayerName])({q})[0]; + q = (*modules_[kQBProjLayerName])({q})[0]; + const auto T = q->Dims()[1]; + q = q->View({B, T, local_n_head_, qk_head_dim_}); + + auto q_nope = q->Slice(-1, 0, qk_nope_head_dim_); + auto q_pe = q->Slice(-1, qk_nope_head_dim_, qk_head_dim_); + + // (B, T, C) -> kv_a -> compressed kv latent and shared RoPE key. + auto compressed_kv_with_pe = (*modules_[kKVAProjLayerName])({x[0]})[0]; + auto compressed_kv = compressed_kv_with_pe->Slice(-1, 0, kv_lora_rank_); + auto k_pe = compressed_kv_with_pe->Slice(-1, kv_lora_rank_, kv_lora_rank_ + qk_rope_head_dim_) + ->Contiguous(); + if (nn::parallel::global::GetSequenceParallelEnabled()) { + k_pe = nn::parallel::GatherFromSPRegionFunc(k_pe)[0]; + } + k_pe = k_pe->View({B, T, 1, qk_rope_head_dim_}); + + // (B, T, R_kv) -> RMSNorm -> kv_b -> (B, T, H_local * (D_nope + D_v)) + auto kv = (*modules_[kKVANormLayerName])({compressed_kv})[0]; + kv = (*modules_[kKVBProjLayerName])({kv})[0]; + kv = kv->View({B, T, local_n_head_, qk_nope_head_dim_ + v_head_dim_}); + auto k_nope = kv->Slice(-1, 0, qk_nope_head_dim_); + auto v = kv->Slice(-1, qk_nope_head_dim_, qk_nope_head_dim_ + v_head_dim_); + + if (config_.attention_type == AttentionType::kRoPE) { + std::tie(q_pe, k_pe) = ApplyRotaryEmbedding(q_pe, k_pe, freqs_cis); + } + + k_pe = k_pe->RepeatInterleave(local_n_head_, 2); + q = nn::function::Concat(std::vector>{q_nope, q_pe}, -1); + auto k = nn::function::Concat(std::vector>{k_nope, k_pe}, -1); + + // (B, T, H_local, D) -> (B, H_local, T, D) + q = q->Transpose(1, 2); + k = k->Transpose(1, 2); + v = v->Transpose(1, 2); + + auto att = q->Matmul(k->Transpose(-2, -1)) * (1.0 / std::sqrt(static_cast(qk_head_dim_))); + if (external_mask) { + att = att->MaskedFill(external_mask, std::numeric_limits::lowest()); + } else { + auto mask = buffers_[kParamBiasName]->Slice({0, 0, 0, 0}, {1, 1, T, T}, {1, 1, 1, 1}); + att = att->MaskedFill(mask == 0, -std::numeric_limits::infinity()); + } + att = nn::function::Softmax(att, -1); + + auto y = att->Matmul(v); + y = y->Transpose(1, 2)->Contiguous()->View({B, T, local_n_head_ * v_head_dim_}); + y = (*modules_[kCProjLayerName])({y})[0]; + return {y}; +} + +} // namespace infini_train::nn diff --git a/infini_train/src/nn/modules/transformer/utils.cc b/infini_train/src/nn/modules/transformer/utils.cc index 98505fd0..4ec11f2d 100644 --- a/infini_train/src/nn/modules/transformer/utils.cc +++ b/infini_train/src/nn/modules/transformer/utils.cc @@ -1,5 +1,9 @@ #include "infini_train/include/nn/modules/transformer/utils.h" +#include +#include +#include + #include "glog/logging.h" #include "infini_train/include/nn/functional.h" @@ -27,4 +31,38 @@ std::shared_ptr PrecomputeFreqsCis(int64_t dim, int64_t end, float theta return freqs_cis; } + +std::tuple, std::shared_ptr> +ApplyRotaryEmbedding(const std::shared_ptr &xq, const std::shared_ptr &xk, + const std::shared_ptr &freqs_cis) { + const auto &x_shape = xq->Dims(); // (B, T, H, D) + const int64_t T = x_shape[1]; + const int64_t D = x_shape[3]; + + std::vector target_shape = {1, T, 1, D / 2, 2}; + auto cos_sin = freqs_cis->View(target_shape); // -> (1, T, 1, D/2, 2) + + auto cos = cos_sin->Slice(-1, 0, 1, 1)->Squeeze(-1); // (1, T, 1, D/2) + auto sin = cos_sin->Slice(-1, 1, 2, 1)->Squeeze(-1); // (1, T, 1, D/2) + + auto slice_pair = [](const std::shared_ptr &x) { + auto even = x->Slice(-1, 0, x->Dims().back(), 2); + auto odd = x->Slice(-1, 1, x->Dims().back(), 2); + return std::make_pair(even, odd); + }; + + auto [q_even, q_odd] = slice_pair(xq); + auto q_rotated_left = q_even * cos - q_odd * sin; + auto q_rotated_right = q_even * sin + q_odd * cos; + auto q_rotated + = nn::function::Stack(std::vector>{q_rotated_left, q_rotated_right}, -1)->Flatten(-2); + + auto [k_even, k_odd] = slice_pair(xk); + auto k_rotated_left = k_even * cos - k_odd * sin; + auto k_rotated_right = k_even * sin + k_odd * cos; + auto k_rotated + = nn::function::Stack(std::vector>{k_rotated_left, k_rotated_right}, -1)->Flatten(-2); + + return {q_rotated, k_rotated}; +} } // namespace infini_train diff --git a/tests/transformer/test_transformer_architecture.cc b/tests/transformer/test_transformer_architecture.cc index ba62e1e3..f36d10f6 100644 --- a/tests/transformer/test_transformer_architecture.cc +++ b/tests/transformer/test_transformer_architecture.cc @@ -7,6 +7,7 @@ #include "infini_train/include/nn/modules/normalization.h" #include "infini_train/include/nn/modules/sparse.h" #include "infini_train/include/nn/modules/transformer/causal_self_attention.h" +#include "infini_train/include/nn/modules/transformer/mla_self_attention.h" #include "infini_train/include/nn/modules/transformer/mlp.h" #include "infini_train/include/nn/modules/transformer/transformer.h" #include "infini_train/include/nn/modules/transformer/transformer_config.h" @@ -110,6 +111,30 @@ TEST_P(TransformerModuleTest, StandardAttention) { EXPECT_EQ(output[0]->Dims(), input->Dims()); } +TEST_P(TransformerModuleTest, MLAAttention) { + SKIP_CPU(); + nn::TransformerConfig config; + config.n_embd = 64; + config.n_head = 4; + config.block_size = 16; + config.attention_type = nn::AttentionType::kStandard; + config.add_bias_linear = true; + + auto attn = std::make_shared( + config, + /*q_lora_rank=*/32, + /*kv_lora_rank=*/32, + /*qk_nope_head_dim=*/8, + /*qk_rope_head_dim=*/8, + /*v_head_dim=*/16); + attn->To(GetDevice()); + EXPECT_FALSE(attn->Parameters().empty()); + + auto input = std::make_shared(std::vector{2, 8, 64}, DataType::kFLOAT32, GetDevice()); + auto output = (*attn)({input}); + EXPECT_EQ(output[0]->Dims(), input->Dims()); +} + TEST_P(TransformerModuleTest, GPT2TransformerLayer) { SKIP_CPU(); nn::TransformerConfig config; From 87ca357154f544aa783bbe533e451027ceb446f6 Mon Sep 17 00:00:00 2001 From: bolunz Date: Thu, 28 May 2026 13:33:10 +0000 Subject: [PATCH 2/3] feat: support q_lora/non-q_lora and tp/non-tp variations --- .../modules/transformer/mla_self_attention.h | 26 ++- infini_train/include/nn/parallel/utils.h | 1 + .../modules/transformer/mla_self_attention.cc | 201 ++++++++++++++---- .../src/nn/parallel/tensor_parallel.cc | 37 ++++ .../test_transformer_architecture.cc | 34 +++ 5 files changed, 242 insertions(+), 57 deletions(-) diff --git a/infini_train/include/nn/modules/transformer/mla_self_attention.h b/infini_train/include/nn/modules/transformer/mla_self_attention.h index b4419e43..75b9da3a 100644 --- a/infini_train/include/nn/modules/transformer/mla_self_attention.h +++ b/infini_train/include/nn/modules/transformer/mla_self_attention.h @@ -12,19 +12,21 @@ class MLASelfAttention : public infini_train::nn::CloneableModule> Forward(const std::vector> &x) override; @@ -42,9 +44,13 @@ class MLASelfAttention : public infini_train::nn::CloneableModule GetPipelineParallelGroupRanks(int global_rank); // TP/SP Communication Helper Functions std::vector> GatherFromTPRegionFunc(const std::shared_ptr &input); +std::vector> ScatterToSPRegionFunc(const std::shared_ptr &input); std::vector> ReduceScatterToSPRegionFunc(const std::shared_ptr &input); std::vector> GatherFromSPRegionFunc(const std::shared_ptr &input); std::vector> ScatterToTPRegionFunc(const std::shared_ptr &input); diff --git a/infini_train/src/nn/modules/transformer/mla_self_attention.cc b/infini_train/src/nn/modules/transformer/mla_self_attention.cc index 097cf830..423c91c5 100644 --- a/infini_train/src/nn/modules/transformer/mla_self_attention.cc +++ b/infini_train/src/nn/modules/transformer/mla_self_attention.cc @@ -43,30 +43,65 @@ MLASelfAttention::MLASelfAttention(const TransformerConfig &config) /*v_head_dim=*/DefaultQKVHeadDim(config)) {} MLASelfAttention::MLASelfAttention(const TransformerConfig &config, int64_t q_lora_rank, int64_t kv_lora_rank, - int64_t qk_nope_head_dim, int64_t qk_rope_head_dim, int64_t v_head_dim) + int64_t qk_nope_head_dim, int64_t qk_rope_head_dim, int64_t v_head_dim, + bool q_down_proj_use_tp, bool kv_down_proj_use_tp) : CloneableModule(kType), config_(config) { - SetupAttention(config, q_lora_rank, kv_lora_rank, qk_nope_head_dim, qk_rope_head_dim, v_head_dim); - - modules_[kQAProjLayerName] = std::make_shared( - /*in_features=*/n_embd_, - /*out_features=*/q_lora_rank_, - /*bias=*/config_.add_bias_linear); - modules_[kQANormLayerName] = std::make_shared(q_lora_rank_, config_.norm_eps); - modules_[kQBProjLayerName] = std::make_shared( - /*in_features=*/q_lora_rank_, - /*out_features=*/n_head_ * qk_head_dim_, - /*bias=*/config_.add_bias_linear, - /*gather_output=*/false, - /*input_is_parallel=*/false, - /*skip_bias_add=*/false, - /*sequence_parallel=*/nn::parallel::global::GetSequenceParallelEnabled()); + SetupAttention(config, q_lora_rank, kv_lora_rank, qk_nope_head_dim, qk_rope_head_dim, v_head_dim, + q_down_proj_use_tp, kv_down_proj_use_tp); + + if (use_q_lora_) { + if (q_down_proj_use_tp_) { + modules_[kLinearQDownProjLayerName] = std::make_shared( + /*in_features=*/n_embd_, + /*out_features=*/q_lora_rank_, + /*bias=*/config_.add_bias_linear, + /*gather_output=*/false, + /*input_is_parallel=*/false, + /*skip_bias_add=*/false, + /*sequence_parallel=*/nn::parallel::global::GetSequenceParallelEnabled()); + } else { + modules_[kLinearQDownProjLayerName] = std::make_shared( + /*in_features=*/n_embd_, + /*out_features=*/q_lora_rank_, + /*bias=*/config_.add_bias_linear); + } + modules_[kQLayerNormLayerName] = std::make_shared(q_lora_rank_, config_.norm_eps); + modules_[kLinearQUpProjLayerName] = std::make_shared( + /*in_features=*/q_lora_rank_, + /*out_features=*/n_head_ * qk_head_dim_, + /*bias=*/config_.add_bias_linear, + /*gather_output=*/false, + /*input_is_parallel=*/false, + /*skip_bias_add=*/false, + /*sequence_parallel=*/nn::parallel::global::GetSequenceParallelEnabled()); + } else { + modules_[kLinearQProjLayerName] = std::make_shared( + /*in_features=*/n_embd_, + /*out_features=*/n_head_ * qk_head_dim_, + /*bias=*/config_.add_bias_linear, + /*gather_output=*/false, + /*input_is_parallel=*/false, + /*skip_bias_add=*/false, + /*sequence_parallel=*/nn::parallel::global::GetSequenceParallelEnabled()); + } - modules_[kKVAProjLayerName] = std::make_shared( - /*in_features=*/n_embd_, - /*out_features=*/kv_lora_rank_ + qk_rope_head_dim_, - /*bias=*/config_.add_bias_linear); - modules_[kKVANormLayerName] = std::make_shared(kv_lora_rank_, config_.norm_eps); - modules_[kKVBProjLayerName] = std::make_shared( + if (kv_down_proj_use_tp_) { + modules_[kLinearKVDownProjLayerName] = std::make_shared( + /*in_features=*/n_embd_, + /*out_features=*/kv_lora_rank_ + qk_rope_head_dim_, + /*bias=*/config_.add_bias_linear, + /*gather_output=*/false, + /*input_is_parallel=*/false, + /*skip_bias_add=*/false, + /*sequence_parallel=*/nn::parallel::global::GetSequenceParallelEnabled()); + } else { + modules_[kLinearKVDownProjLayerName] = std::make_shared( + /*in_features=*/n_embd_, + /*out_features=*/kv_lora_rank_ + qk_rope_head_dim_, + /*bias=*/config_.add_bias_linear); + } + modules_[kKVLayerNormLayerName] = std::make_shared(kv_lora_rank_, config_.norm_eps); + modules_[kLinearKVUpProjLayerName] = std::make_shared( /*in_features=*/kv_lora_rank_, /*out_features=*/n_head_ * (qk_nope_head_dim_ + v_head_dim_), /*bias=*/config_.add_bias_linear, @@ -75,7 +110,7 @@ MLASelfAttention::MLASelfAttention(const TransformerConfig &config, int64_t q_lo /*skip_bias_add=*/false, /*sequence_parallel=*/nn::parallel::global::GetSequenceParallelEnabled()); - modules_[kCProjLayerName] = std::make_shared( + modules_[kLinearProjLayerName] = std::make_shared( /*in_features=*/n_head_ * v_head_dim_, /*out_features=*/n_embd_, /*bias=*/config_.add_bias_linear, @@ -89,12 +124,13 @@ MLASelfAttention::MLASelfAttention(const TransformerConfig &config, int64_t q_lo } void MLASelfAttention::SetupAttention(const TransformerConfig &config, int64_t q_lora_rank, int64_t kv_lora_rank, - int64_t qk_nope_head_dim, int64_t qk_rope_head_dim, int64_t v_head_dim) { + int64_t qk_nope_head_dim, int64_t qk_rope_head_dim, int64_t v_head_dim, + bool q_down_proj_use_tp, bool kv_down_proj_use_tp) { auto tp_world_size = nn::parallel::global::GetTensorParallelSize(); CHECK_EQ(config.n_embd % config.n_head, 0) << "n_embd must be divisible by n_head"; CHECK_EQ(config.n_head % tp_world_size, 0) << "n_head must be divisible by TP world size"; - CHECK_GT(q_lora_rank, 0) << "q_lora_rank must be positive"; + CHECK(q_lora_rank == -1 || q_lora_rank > 0) << "q_lora_rank must be positive, or -1 to disable q LoRA"; CHECK_GT(kv_lora_rank, 0) << "kv_lora_rank must be positive"; CHECK_GT(qk_nope_head_dim, 0) << "qk_nope_head_dim must be positive"; CHECK_GT(qk_rope_head_dim, 0) << "qk_rope_head_dim must be positive"; @@ -105,80 +141,151 @@ void MLASelfAttention::SetupAttention(const TransformerConfig &config, int64_t q n_embd_ = config.n_embd; local_n_head_ = n_head_ / tp_world_size; - q_lora_rank_ = q_lora_rank; + use_q_lora_ = q_lora_rank != -1; + q_lora_rank_ = use_q_lora_ ? q_lora_rank : 0; kv_lora_rank_ = kv_lora_rank; qk_nope_head_dim_ = qk_nope_head_dim; qk_rope_head_dim_ = qk_rope_head_dim; qk_head_dim_ = qk_nope_head_dim_ + qk_rope_head_dim_; v_head_dim_ = v_head_dim; + q_down_proj_use_tp_ = q_down_proj_use_tp; + kv_down_proj_use_tp_ = kv_down_proj_use_tp; } std::vector> MLASelfAttention::Forward(const std::vector> &x) { CHECK_GE(x.size(), 1) << "MLASelfAttention expects at least hidden states"; + // x[0]: (B, T_local, C) const auto B = x[0]->Dims()[0]; const auto C = x[0]->Dims()[2]; CHECK_EQ(C, n_embd_) << "hidden size must match n_embd"; + // freqs_cis: (T, D_rope / 2, 2) const auto freqs_cis = x.size() > 1 ? x[1] : nullptr; + // external_mask: (1, 1, T, T) const auto external_mask = x.size() > 3 ? x[3] : nullptr; if (config_.attention_type == AttentionType::kRoPE) { CHECK(freqs_cis != nullptr) << "freqs_cis is null."; } - // (B, T, C) -> q_a -> RMSNorm -> q_b -> (B, T, H_local * (D_nope + D_rope)) - auto q = (*modules_[kQAProjLayerName])({x[0]})[0]; - q = (*modules_[kQANormLayerName])({q})[0]; - q = (*modules_[kQBProjLayerName])({q})[0]; + const bool sequence_parallel_enabled = nn::parallel::global::GetSequenceParallelEnabled(); + + // ----------- Q PATH ----------- + // Q path, align with Megatron: + // - q_lora_rank == -1 -> linear_q_proj directly; + // - otherwise linear_q_down_proj -> q_layernorm -> linear_q_up_proj. + std::shared_ptr q; + if (use_q_lora_) { + // linear_q_down_proj: + // non-TP path: (B, T_local, C) -> (B, T_local, R_q) + // TP path before gather: (B, T, C) -> (B, T, R_q / TP) + // - Note that ColumnParallelLinear would perform a GatherFromSPRegion in the beginning + auto q_compressed = (*modules_[kLinearQDownProjLayerName])({x[0]})[0]; + if (q_down_proj_use_tp_ && q_compressed->Dims().back() != q_lora_rank_) { + // Gather the sharded latent dimension: (B, T, R_q / TP) -> (B, T, R_q). + q_compressed = nn::parallel::GatherFromTPRegionFunc(q_compressed)[0]; + if (sequence_parallel_enabled) { + // Keep the q_up input sequence-sharded: (B, T_full, R_q) -> (B, T_local, R_q). + q_compressed = nn::parallel::ScatterToSPRegionFunc(q_compressed)[0]; + } + } + // q_layernorm preserves shape: (B, T_local, R_q) + q_compressed = (*modules_[kQLayerNormLayerName])({q_compressed})[0]; + // linear_q_up_proj: (B, T_local, R_q) -> (B, T, H_local * (D_nope + D_rope)). + q = (*modules_[kLinearQUpProjLayerName])({q_compressed})[0]; + } else { + // linear_q_proj direct path: (B, T, C) -> (B, T, H_local * (D_nope + D_rope)). + q = (*modules_[kLinearQProjLayerName])({x[0]})[0]; + } + + // T should be the full seqlen after the q projection path gathers sequence-parallel input. const auto T = q->Dims()[1]; + // q: (B, T, H_local * D_qk) -> (B, T, H_local, D_qk) + // qk_head_dim_ = qk_nope_head_dim_ + qk_rope_head_dim_ q = q->View({B, T, local_n_head_, qk_head_dim_}); + // q_nope: (B, T, H_local, D_nope), q_pos_emb: (B, T, H_local, D_rope) auto q_nope = q->Slice(-1, 0, qk_nope_head_dim_); - auto q_pe = q->Slice(-1, qk_nope_head_dim_, qk_head_dim_); + auto q_pos_emb = q->Slice(-1, qk_nope_head_dim_, qk_head_dim_); + + // ----------- KV PATH ----------- + // linear_kv_down_proj: + // non-TP path: (B, T_local, C) -> (B, T_local, R_kv + D_rope) + // TP path before gather: (B, T, C) -> (B, T, (R_kv + D_rope) / TP) + auto compressed_kv_with_pe = (*modules_[kLinearKVDownProjLayerName])({x[0]})[0]; + const auto kv_down_proj_out_dim = kv_lora_rank_ + qk_rope_head_dim_; + const bool kv_down_proj_output_is_sharded = compressed_kv_with_pe->Dims().back() != kv_down_proj_out_dim; + if (kv_down_proj_use_tp_ && kv_down_proj_output_is_sharded) { + // Gather latent+RoPE dim: (B, T, (R_kv + D_rope) / TP) -> (B, T, R_kv + D_rope) + compressed_kv_with_pe = nn::parallel::GatherFromTPRegionFunc(compressed_kv_with_pe)[0]; + } - // (B, T, C) -> kv_a -> compressed kv latent and shared RoPE key. - auto compressed_kv_with_pe = (*modules_[kKVAProjLayerName])({x[0]})[0]; + // compressed_kv: (B, T_local, R_kv), k_pos_emb: (B, T_local, D_rope) auto compressed_kv = compressed_kv_with_pe->Slice(-1, 0, kv_lora_rank_); - auto k_pe = compressed_kv_with_pe->Slice(-1, kv_lora_rank_, kv_lora_rank_ + qk_rope_head_dim_) - ->Contiguous(); - if (nn::parallel::global::GetSequenceParallelEnabled()) { - k_pe = nn::parallel::GatherFromSPRegionFunc(k_pe)[0]; + auto k_pos_emb = compressed_kv_with_pe->Slice(-1, kv_lora_rank_, kv_lora_rank_ + qk_rope_head_dim_)->Contiguous(); + const bool k_pos_emb_has_full_sequence = kv_down_proj_use_tp_ && kv_down_proj_output_is_sharded + && sequence_parallel_enabled; + if (k_pos_emb_has_full_sequence) { + // k_pos_emb already has full T; keep only compressed_kv sequence-sharded for linear_kv_up_proj. + // compressed_kv: (B, T, R_kv) -> (B, T_local, R_kv) + compressed_kv = nn::parallel::ScatterToSPRegionFunc(compressed_kv)[0]; + } else if (sequence_parallel_enabled) { + // Replicated down-proj path produces local k_pos_emb; gather it for attention. + // k_pos_emb: (B, T_local, D_rope) -> (B, T, D_rope) + k_pos_emb = nn::parallel::GatherFromSPRegionFunc(k_pos_emb)[0]; } - k_pe = k_pe->View({B, T, 1, qk_rope_head_dim_}); + // k_pos_emb: (B, T, D_rope) -> (B, T, 1, D_rope), shared across local heads. + k_pos_emb = k_pos_emb->View({B, T, 1, qk_rope_head_dim_}); - // (B, T, R_kv) -> RMSNorm -> kv_b -> (B, T, H_local * (D_nope + D_v)) - auto kv = (*modules_[kKVANormLayerName])({compressed_kv})[0]; - kv = (*modules_[kKVBProjLayerName])({kv})[0]; + // (B, T, R_kv) -> kv_layernorm -> linear_kv_up_proj -> (B, T, H_local * (D_nope + D_v)) + // kv_layernorm preserves compressed_kv shape: (B, T_local, R_kv) + auto kv = (*modules_[kKVLayerNormLayerName])({compressed_kv})[0]; + // linear_kv_up_proj: (B, T_local, R_kv) -> (B, T, H_local * (D_nope + D_v)) + kv = (*modules_[kLinearKVUpProjLayerName])({kv})[0]; + // kv: (B, T, H_local * (D_nope + D_v)) -> (B, T, H_local, D_nope + D_v) kv = kv->View({B, T, local_n_head_, qk_nope_head_dim_ + v_head_dim_}); + // k_nope: (B, T, H_local, D_nope), v: (B, T, H_local, D_v) auto k_nope = kv->Slice(-1, 0, qk_nope_head_dim_); auto v = kv->Slice(-1, qk_nope_head_dim_, qk_nope_head_dim_ + v_head_dim_); if (config_.attention_type == AttentionType::kRoPE) { - std::tie(q_pe, k_pe) = ApplyRotaryEmbedding(q_pe, k_pe, freqs_cis); + // q_pos_emb: (B, T, H_local, D_rope), k_pos_emb: (B, T, 1, D_rope) + std::tie(q_pos_emb, k_pos_emb) = ApplyRotaryEmbedding(q_pos_emb, k_pos_emb, freqs_cis); } - k_pe = k_pe->RepeatInterleave(local_n_head_, 2); - q = nn::function::Concat(std::vector>{q_nope, q_pe}, -1); - auto k = nn::function::Concat(std::vector>{k_nope, k_pe}, -1); + // k_pos_emb: (B, T, 1, D_rope) -> (B, T, H_local, D_rope) + k_pos_emb = k_pos_emb->RepeatInterleave(local_n_head_, 2); + // q: (B, T, H_local, D_qk), k: (B, T, H_local, D_qk) + q = nn::function::Concat(std::vector>{q_nope, q_pos_emb}, -1); + auto k = nn::function::Concat(std::vector>{k_nope, k_pos_emb}, -1); - // (B, T, H_local, D) -> (B, H_local, T, D) + // ----------- CORE ATTN ----------- + // q/k: (B, T, H_local, D_qk) -> (B, H_local, T, D_qk) + // v: (B, T, H_local, D_v) -> (B, H_local, T, D_v) q = q->Transpose(1, 2); k = k->Transpose(1, 2); v = v->Transpose(1, 2); + // att: (B, H_local, T, T) auto att = q->Matmul(k->Transpose(-2, -1)) * (1.0 / std::sqrt(static_cast(qk_head_dim_))); if (external_mask) { att = att->MaskedFill(external_mask, std::numeric_limits::lowest()); } else { + // mask: (1, 1, T, T) auto mask = buffers_[kParamBiasName]->Slice({0, 0, 0, 0}, {1, 1, T, T}, {1, 1, 1, 1}); att = att->MaskedFill(mask == 0, -std::numeric_limits::infinity()); } + // att: (B, H_local, T, T) att = nn::function::Softmax(att, -1); + // y: (B, H_local, T, D_v) auto y = att->Matmul(v); + // y: (B, H_local, T, D_v) -> (B, T, H_local, D_v) -> (B, T, H_local * D_v) y = y->Transpose(1, 2)->Contiguous()->View({B, T, local_n_head_ * v_head_dim_}); - y = (*modules_[kCProjLayerName])({y})[0]; + // linear_proj: (B, T, H_local * D_v) -> (B, T, C) + y = (*modules_[kLinearProjLayerName])({y})[0]; + return {y}; } diff --git a/infini_train/src/nn/parallel/tensor_parallel.cc b/infini_train/src/nn/parallel/tensor_parallel.cc index 44ab8189..b83c5e52 100644 --- a/infini_train/src/nn/parallel/tensor_parallel.cc +++ b/infini_train/src/nn/parallel/tensor_parallel.cc @@ -45,6 +45,24 @@ std::shared_ptr GatherAlongFirstDim(const std::shared_ptr &tenso return gathered_output; } +std::shared_ptr ScatterAlongFirstDim(const std::shared_ptr &tensor) { + int world_size = global::GetTensorParallelSize(); + CHECK_GT(world_size, 0) << "Tensor Parallel group not initialized"; + if (world_size == 1) { + return tensor; + } + + auto device = tensor->GetDevice(); + auto tp_group = ProcessGroupFactory::Instance(device.type()) + ->Get(GetTensorParallelProcessGroupName(device.Rank().GlobalRank())); + auto rank = tp_group->GetGroupRank(device.Rank().GlobalRank()); + + CHECK_EQ(tensor->Dims()[0] % world_size, 0) << "First dimension must be divisible by TP world size"; + auto first_dim_size = tensor->Dims()[0] / world_size; + auto shards = tensor->Split(first_dim_size, 0); + return shards[rank]->Contiguous(); +} + std::shared_ptr GatherAlongLastDim(const std::shared_ptr &tensor) { int world_size = global::GetTensorParallelSize(); CHECK_GT(world_size, 0) << "Tensor Parallel group not initialized"; @@ -214,6 +232,21 @@ class ReduceScatterToSPRegion : public autograd::Function { }; }; +class ScatterToSPRegion : public autograd::Function { +public: + static constexpr char kType[] = "ScatterToSPRegionFunction"; + + explicit ScatterToSPRegion() : autograd::Function(kType) {} + + std::vector> Forward(const std::vector> &input_tensors) override { + return {ScatterAlongFirstDim(input_tensors[0]->Transpose(0, 1))->Transpose(0, 1)}; + }; + + std::vector> Backward(const std::vector> &grad_outputs) override { + return {GatherAlongFirstDim(grad_outputs[0]->Transpose(0, 1))->Transpose(0, 1)}; + }; +}; + class GatherFromSPRegion : public autograd::Function { public: static constexpr char kType[] = "GatherFromSPRegionFunction"; @@ -263,6 +296,10 @@ std::vector> ReduceScatterToSPRegionFunc(const std::shar return std::make_shared()->Apply({input}); } +std::vector> ScatterToSPRegionFunc(const std::shared_ptr &input) { + return std::make_shared()->Apply({input}); +} + std::vector> GatherFromSPRegionFunc(const std::shared_ptr &input) { return std::make_shared()->Apply({input}); } diff --git a/tests/transformer/test_transformer_architecture.cc b/tests/transformer/test_transformer_architecture.cc index f36d10f6..047566ea 100644 --- a/tests/transformer/test_transformer_architecture.cc +++ b/tests/transformer/test_transformer_architecture.cc @@ -4,6 +4,7 @@ #include "gtest/gtest.h" +#include "infini_train/include/nn/modules/linear.h" #include "infini_train/include/nn/modules/normalization.h" #include "infini_train/include/nn/modules/sparse.h" #include "infini_train/include/nn/modules/transformer/causal_self_attention.h" @@ -12,6 +13,7 @@ #include "infini_train/include/nn/modules/transformer/transformer.h" #include "infini_train/include/nn/modules/transformer/transformer_config.h" #include "infini_train/include/nn/modules/transformer/utils.h" +#include "infini_train/include/nn/parallel/tensor_parallel.h" #include "infini_train/include/tensor.h" #include "tests/common/test_utils.h" @@ -129,10 +131,42 @@ TEST_P(TransformerModuleTest, MLAAttention) { /*v_head_dim=*/16); attn->To(GetDevice()); EXPECT_FALSE(attn->Parameters().empty()); + EXPECT_EQ(attn->module(nn::MLASelfAttention::kLinearQDownProjLayerName).type(), nn::Linear::kType); + EXPECT_EQ(attn->module(nn::MLASelfAttention::kLinearKVDownProjLayerName).type(), nn::Linear::kType); auto input = std::make_shared(std::vector{2, 8, 64}, DataType::kFLOAT32, GetDevice()); auto output = (*attn)({input}); EXPECT_EQ(output[0]->Dims(), input->Dims()); + + auto tp_down_attn = std::make_shared( + config, + /*q_lora_rank=*/32, + /*kv_lora_rank=*/32, + /*qk_nope_head_dim=*/8, + /*qk_rope_head_dim=*/8, + /*v_head_dim=*/16, + /*q_down_proj_use_tp=*/true, + /*kv_down_proj_use_tp=*/true); + tp_down_attn->To(GetDevice()); + EXPECT_EQ(tp_down_attn->module(nn::MLASelfAttention::kLinearQDownProjLayerName).type(), + nn::parallel::ColumnParallelLinear::kType); + EXPECT_EQ(tp_down_attn->module(nn::MLASelfAttention::kLinearKVDownProjLayerName).type(), + nn::parallel::ColumnParallelLinear::kType); + output = (*tp_down_attn)({input}); + EXPECT_EQ(output[0]->Dims(), input->Dims()); + + auto direct_q_attn = std::make_shared( + config, + /*q_lora_rank=*/-1, + /*kv_lora_rank=*/32, + /*qk_nope_head_dim=*/8, + /*qk_rope_head_dim=*/8, + /*v_head_dim=*/16); + direct_q_attn->To(GetDevice()); + EXPECT_EQ(direct_q_attn->module(nn::MLASelfAttention::kLinearQProjLayerName).type(), + nn::parallel::ColumnParallelLinear::kType); + output = (*direct_q_attn)({input}); + EXPECT_EQ(output[0]->Dims(), input->Dims()); } TEST_P(TransformerModuleTest, GPT2TransformerLayer) { From dd18b35eeb8f9a80b34b0d95b9f8f06c4447df51 Mon Sep 17 00:00:00 2001 From: bolunz Date: Fri, 29 May 2026 02:36:49 +0000 Subject: [PATCH 3/3] fix: move mla args into TransformerConfig --- .../modules/transformer/mla_self_attention.h | 7 +-- .../modules/transformer/transformer_config.h | 10 ++++ .../modules/transformer/mla_self_attention.cc | 60 ++++++------------- .../src/nn/modules/transformer/transformer.cc | 17 ++++-- .../test_transformer_architecture.cc | 40 +++++-------- 5 files changed, 58 insertions(+), 76 deletions(-) diff --git a/infini_train/include/nn/modules/transformer/mla_self_attention.h b/infini_train/include/nn/modules/transformer/mla_self_attention.h index 75b9da3a..63177cc6 100644 --- a/infini_train/include/nn/modules/transformer/mla_self_attention.h +++ b/infini_train/include/nn/modules/transformer/mla_self_attention.h @@ -24,9 +24,6 @@ class MLASelfAttention : public infini_train::nn::CloneableModule> Forward(const std::vector> &x) override; @@ -48,9 +45,7 @@ class MLASelfAttention : public infini_train::nn::CloneableModule q_lora_rank = std::nullopt; // nullopt means direct linear_q_proj path. + int64_t kv_lora_rank = 0; // 0 falls back to n_embd in MLASelfAttention. + int64_t qk_nope_head_dim = 0; // 0 falls back to n_embd / n_head. + int64_t qk_rope_head_dim = 0; // 0 falls back to n_embd / n_head. + int64_t v_head_dim = 0; // 0 falls back to n_embd / n_head. + bool q_down_proj_use_tp = false; // Use ColumnParallelLinear for linear_q_down_proj. + bool kv_down_proj_use_tp = false; // Use ColumnParallelLinear for linear_kv_down_proj. + // Normalization float norm_eps = 1e-5f; // epsilon in RMSNorm diff --git a/infini_train/src/nn/modules/transformer/mla_self_attention.cc b/infini_train/src/nn/modules/transformer/mla_self_attention.cc index 423c91c5..7549e812 100644 --- a/infini_train/src/nn/modules/transformer/mla_self_attention.cc +++ b/infini_train/src/nn/modules/transformer/mla_self_attention.cc @@ -19,35 +19,9 @@ #include "infini_train/include/tensor.h" namespace infini_train::nn { -namespace { -int64_t DefaultQKVHeadDim(const TransformerConfig &config) { - CHECK_EQ(config.n_embd % config.n_head, 0) << "n_embd must be divisible by n_head"; - return config.n_embd / config.n_head; -} - -int64_t DefaultQKRoPEHeadDim(const TransformerConfig &config) { - return DefaultQKVHeadDim(config); -} -int64_t DefaultQKNoPEHeadDim(const TransformerConfig &config) { - return DefaultQKVHeadDim(config); -} -} // namespace - -MLASelfAttention::MLASelfAttention(const TransformerConfig &config) - : MLASelfAttention(config, - /*q_lora_rank=*/config.n_embd, - /*kv_lora_rank=*/config.n_embd, - /*qk_nope_head_dim=*/DefaultQKNoPEHeadDim(config), - /*qk_rope_head_dim=*/DefaultQKRoPEHeadDim(config), - /*v_head_dim=*/DefaultQKVHeadDim(config)) {} - -MLASelfAttention::MLASelfAttention(const TransformerConfig &config, int64_t q_lora_rank, int64_t kv_lora_rank, - int64_t qk_nope_head_dim, int64_t qk_rope_head_dim, int64_t v_head_dim, - bool q_down_proj_use_tp, bool kv_down_proj_use_tp) - : CloneableModule(kType), config_(config) { - SetupAttention(config, q_lora_rank, kv_lora_rank, qk_nope_head_dim, qk_rope_head_dim, v_head_dim, - q_down_proj_use_tp, kv_down_proj_use_tp); +MLASelfAttention::MLASelfAttention(const TransformerConfig &config) : CloneableModule(kType), config_(config) { + SetupAttention(config); if (use_q_lora_) { if (q_down_proj_use_tp_) { @@ -123,15 +97,19 @@ MLASelfAttention::MLASelfAttention(const TransformerConfig &config, int64_t q_lo ->View({1, 1, config_.block_size, config_.block_size}); } -void MLASelfAttention::SetupAttention(const TransformerConfig &config, int64_t q_lora_rank, int64_t kv_lora_rank, - int64_t qk_nope_head_dim, int64_t qk_rope_head_dim, int64_t v_head_dim, - bool q_down_proj_use_tp, bool kv_down_proj_use_tp) { +void MLASelfAttention::SetupAttention(const TransformerConfig &config) { auto tp_world_size = nn::parallel::global::GetTensorParallelSize(); CHECK_EQ(config.n_embd % config.n_head, 0) << "n_embd must be divisible by n_head"; CHECK_EQ(config.n_head % tp_world_size, 0) << "n_head must be divisible by TP world size"; - CHECK(q_lora_rank == -1 || q_lora_rank > 0) << "q_lora_rank must be positive, or -1 to disable q LoRA"; - CHECK_GT(kv_lora_rank, 0) << "kv_lora_rank must be positive"; + CHECK(!config.q_lora_rank.has_value() || config.q_lora_rank.value() > 0) << "q_lora_rank must be positive when set"; + + const auto default_head_dim = config.n_embd / config.n_head; + const int64_t kv_lora_rank = config.kv_lora_rank > 0 ? config.kv_lora_rank : config.n_embd; + const int64_t qk_nope_head_dim = config.qk_nope_head_dim > 0 ? config.qk_nope_head_dim : default_head_dim; + const int64_t qk_rope_head_dim = config.qk_rope_head_dim > 0 ? config.qk_rope_head_dim : default_head_dim; + const int64_t v_head_dim = config.v_head_dim > 0 ? config.v_head_dim : default_head_dim; + CHECK_GT(qk_nope_head_dim, 0) << "qk_nope_head_dim must be positive"; CHECK_GT(qk_rope_head_dim, 0) << "qk_rope_head_dim must be positive"; CHECK_GT(v_head_dim, 0) << "v_head_dim must be positive"; @@ -141,15 +119,15 @@ void MLASelfAttention::SetupAttention(const TransformerConfig &config, int64_t q n_embd_ = config.n_embd; local_n_head_ = n_head_ / tp_world_size; - use_q_lora_ = q_lora_rank != -1; - q_lora_rank_ = use_q_lora_ ? q_lora_rank : 0; + use_q_lora_ = config.q_lora_rank.has_value(); + q_lora_rank_ = config.q_lora_rank.value_or(0); kv_lora_rank_ = kv_lora_rank; qk_nope_head_dim_ = qk_nope_head_dim; qk_rope_head_dim_ = qk_rope_head_dim; qk_head_dim_ = qk_nope_head_dim_ + qk_rope_head_dim_; v_head_dim_ = v_head_dim; - q_down_proj_use_tp_ = q_down_proj_use_tp; - kv_down_proj_use_tp_ = kv_down_proj_use_tp; + q_down_proj_use_tp_ = config.q_down_proj_use_tp; + kv_down_proj_use_tp_ = config.kv_down_proj_use_tp; } std::vector> @@ -173,7 +151,7 @@ MLASelfAttention::Forward(const std::vector linear_q_proj directly; + // - q_lora_rank == nullopt -> linear_q_proj directly; // - otherwise linear_q_down_proj -> q_layernorm -> linear_q_up_proj. std::shared_ptr q; if (use_q_lora_) { @@ -224,8 +202,8 @@ MLASelfAttention::Forward(const std::vectorSlice(-1, 0, kv_lora_rank_); auto k_pos_emb = compressed_kv_with_pe->Slice(-1, kv_lora_rank_, kv_lora_rank_ + qk_rope_head_dim_)->Contiguous(); - const bool k_pos_emb_has_full_sequence = kv_down_proj_use_tp_ && kv_down_proj_output_is_sharded - && sequence_parallel_enabled; + const bool k_pos_emb_has_full_sequence + = kv_down_proj_use_tp_ && kv_down_proj_output_is_sharded && sequence_parallel_enabled; if (k_pos_emb_has_full_sequence) { // k_pos_emb already has full T; keep only compressed_kv sequence-sharded for linear_kv_up_proj. // compressed_kv: (B, T, R_kv) -> (B, T_local, R_kv) @@ -285,7 +263,7 @@ MLASelfAttention::Forward(const std::vectorTranspose(1, 2)->Contiguous()->View({B, T, local_n_head_ * v_head_dim_}); // linear_proj: (B, T, H_local * D_v) -> (B, T, C) y = (*modules_[kLinearProjLayerName])({y})[0]; - + return {y}; } diff --git a/infini_train/src/nn/modules/transformer/transformer.cc b/infini_train/src/nn/modules/transformer/transformer.cc index c7e0f28c..048cf96c 100644 --- a/infini_train/src/nn/modules/transformer/transformer.cc +++ b/infini_train/src/nn/modules/transformer/transformer.cc @@ -14,6 +14,7 @@ #include "infini_train/include/nn/modules/normalization.h" #include "infini_train/include/nn/modules/sparse.h" #include "infini_train/include/nn/modules/transformer/causal_self_attention.h" +#include "infini_train/include/nn/modules/transformer/mla_self_attention.h" #include "infini_train/include/nn/modules/transformer/mlp.h" #include "infini_train/include/nn/modules/transformer/utils.h" #include "infini_train/include/nn/parallel/global.h" @@ -28,8 +29,8 @@ TransformerFirstStage::TransformerFirstStage(const TransformerConfig &config) modules_[kWTELayerName] = std::make_shared( config_.vocab_size, config_.n_embd, parallel::global::GetSequenceParallelEnabled()); - // LLaMA3 use RoPE, so they don't need position embedding - if (config_.activation_type == MLPType::kGELU) { + // RoPE-based models do not use absolute position embedding. + if (config_.attention_type == AttentionType::kStandard) { modules_[kWPELayerName] = std::make_shared(config_.block_size, config_.n_embd); } } @@ -85,7 +86,11 @@ TransformerLayer::TransformerLayer(const nn::TransformerConfig &config) : Clonea LOG(FATAL) << "Unsupported norm type"; } - modules_[kAttnLayerName] = std::make_shared(config); + if (config.multi_latent_attention) { + modules_[kAttnLayerName] = std::make_shared(config); + } else { + modules_[kAttnLayerName] = std::make_shared(config); + } modules_[kMlpLayerName] = std::make_shared(config); } @@ -135,8 +140,10 @@ std::vector> TransformerChunk::Forward(const std::vector // Init freqs_cis on device only once if (buffers_[kFreqsCisName] == nullptr) { - int64_t head_dim = config_.n_embd / config_.n_head; - buffers_[kFreqsCisName] = PrecomputeFreqsCis(head_dim, config_.block_size * 2, config_.rope_theta, + int64_t rope_head_dim = config_.multi_latent_attention && config_.qk_rope_head_dim > 0 + ? config_.qk_rope_head_dim + : config_.n_embd / config_.n_head; + buffers_[kFreqsCisName] = PrecomputeFreqsCis(rope_head_dim, config_.block_size * 2, config_.rope_theta, config_.use_scaled_rope, device); } diff --git a/tests/transformer/test_transformer_architecture.cc b/tests/transformer/test_transformer_architecture.cc index 047566ea..9ff660a8 100644 --- a/tests/transformer/test_transformer_architecture.cc +++ b/tests/transformer/test_transformer_architecture.cc @@ -1,5 +1,6 @@ #include #include +#include #include #include "gtest/gtest.h" @@ -121,14 +122,14 @@ TEST_P(TransformerModuleTest, MLAAttention) { config.block_size = 16; config.attention_type = nn::AttentionType::kStandard; config.add_bias_linear = true; - - auto attn = std::make_shared( - config, - /*q_lora_rank=*/32, - /*kv_lora_rank=*/32, - /*qk_nope_head_dim=*/8, - /*qk_rope_head_dim=*/8, - /*v_head_dim=*/16); + config.multi_latent_attention = true; + config.q_lora_rank = 32; + config.kv_lora_rank = 32; + config.qk_nope_head_dim = 8; + config.qk_rope_head_dim = 8; + config.v_head_dim = 16; + + auto attn = std::make_shared(config); attn->To(GetDevice()); EXPECT_FALSE(attn->Parameters().empty()); EXPECT_EQ(attn->module(nn::MLASelfAttention::kLinearQDownProjLayerName).type(), nn::Linear::kType); @@ -138,15 +139,10 @@ TEST_P(TransformerModuleTest, MLAAttention) { auto output = (*attn)({input}); EXPECT_EQ(output[0]->Dims(), input->Dims()); - auto tp_down_attn = std::make_shared( - config, - /*q_lora_rank=*/32, - /*kv_lora_rank=*/32, - /*qk_nope_head_dim=*/8, - /*qk_rope_head_dim=*/8, - /*v_head_dim=*/16, - /*q_down_proj_use_tp=*/true, - /*kv_down_proj_use_tp=*/true); + auto tp_down_config = config; + tp_down_config.q_down_proj_use_tp = true; + tp_down_config.kv_down_proj_use_tp = true; + auto tp_down_attn = std::make_shared(tp_down_config); tp_down_attn->To(GetDevice()); EXPECT_EQ(tp_down_attn->module(nn::MLASelfAttention::kLinearQDownProjLayerName).type(), nn::parallel::ColumnParallelLinear::kType); @@ -155,13 +151,9 @@ TEST_P(TransformerModuleTest, MLAAttention) { output = (*tp_down_attn)({input}); EXPECT_EQ(output[0]->Dims(), input->Dims()); - auto direct_q_attn = std::make_shared( - config, - /*q_lora_rank=*/-1, - /*kv_lora_rank=*/32, - /*qk_nope_head_dim=*/8, - /*qk_rope_head_dim=*/8, - /*v_head_dim=*/16); + auto direct_q_config = config; + direct_q_config.q_lora_rank = std::nullopt; + auto direct_q_attn = std::make_shared(direct_q_config); direct_q_attn->To(GetDevice()); EXPECT_EQ(direct_q_attn->module(nn::MLASelfAttention::kLinearQProjLayerName).type(), nn::parallel::ColumnParallelLinear::kType);