Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#pragma once

#include <memory>
#include <tuple>
#include <vector>

#include "infini_train/include/nn/modules/module.h"
Expand Down Expand Up @@ -43,12 +42,6 @@ class CausalSelfAttention : public infini_train::nn::CloneableModule<CausalSelfA
// Forward variant that takes and returns a vector of tensors.
// NOTE(review): presumably the attention path used when RoPE is enabled — confirm against the definition.
std::vector<std::shared_ptr<infini_train::Tensor>>
ForwardWithRoPE(const std::vector<std::shared_ptr<infini_train::Tensor>> &x);

// RoPE helper methods
// Takes a query tensor, a key tensor and the `freqs_cis` tensor, and returns a tuple of two tensors.
// NOTE(review): presumably the tuple is (rotated xq, rotated xk) in that order — confirm against the definition.
std::tuple<std::shared_ptr<infini_train::Tensor>, std::shared_ptr<infini_train::Tensor>>
ApplyRotaryEmbedding(const std::shared_ptr<infini_train::Tensor> &xq,
const std::shared_ptr<infini_train::Tensor> &xk,
const std::shared_ptr<infini_train::Tensor> &freqs_cis);

// GQA helper method
// Takes a tensor `x` and a repeat count `n_rep`, and returns a tensor.
// NOTE(review): presumably repeats the key/value heads `n_rep` times so they line up with the query heads —
// confirm against the definition.
std::shared_ptr<infini_train::Tensor> RepeatKV(const std::shared_ptr<infini_train::Tensor> &x, int64_t n_rep);
};
Expand Down
51 changes: 51 additions & 0 deletions infini_train/include/nn/modules/transformer/mla_self_attention.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#pragma once

#include <memory>
#include <vector>

#include "infini_train/include/nn/modules/module.h"
#include "infini_train/include/nn/modules/transformer/transformer_config.h"

namespace infini_train::nn {

// Self-attention module for the MLA configuration ("MLA config" in TransformerConfig).
// Declared as a CloneableModule; it is constructed from a TransformerConfig and exposes a single
// Forward() override. All state below is private.
class MLASelfAttention : public infini_train::nn::CloneableModule<MLASelfAttention> {
public:
// Type-name string for this module.
static constexpr char kType[] = "MLASelfAttention";

// Name strings for the sub-layers.
// NOTE(review): presumably these are the keys under which the sub-modules are registered; the
// TransformerConfig comments describe linear_q_proj as the direct query path used when q_lora_rank
// is nullopt — confirm which layers are created in each mode against the .cc file.
static constexpr char kLinearQProjLayerName[] = "linear_q_proj";
static constexpr char kLinearQDownProjLayerName[] = "linear_q_down_proj";
static constexpr char kQLayerNormLayerName[] = "q_layernorm";
static constexpr char kLinearQUpProjLayerName[] = "linear_q_up_proj";
static constexpr char kLinearKVDownProjLayerName[] = "linear_kv_down_proj";
static constexpr char kKVLayerNormLayerName[] = "kv_layernorm";
static constexpr char kLinearKVUpProjLayerName[] = "linear_kv_up_proj";
static constexpr char kLinearProjLayerName[] = "linear_proj";

// Name string for a bias parameter.
static constexpr char kParamBiasName[] = "bias";

// Builds the module from `config`; explicit, so a TransformerConfig never converts implicitly.
explicit MLASelfAttention(const TransformerConfig &config);

// Module forward pass: takes a vector of input tensors and returns a vector of output tensors.
// NOTE(review): the expected number and layout of the input tensors is not visible here — confirm
// against the definition.
std::vector<std::shared_ptr<infini_train::Tensor>>
Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) override;

private:
// Copy of the configuration this module was built from.
TransformerConfig config_;
// NOTE(review): presumably head count, embedding width and per-rank head count under tensor
// parallelism — confirm in SetupAttention.
int64_t n_head_ = 0;
int64_t n_embd_ = 0;
int64_t local_n_head_ = 0;

// NOTE(review): presumably resolved from the same-named TransformerConfig fields, whose comments
// say 0 falls back to n_embd (kv_lora_rank) or n_embd / n_head (the head dims) — confirm.
int64_t q_lora_rank_ = 0;
int64_t kv_lora_rank_ = 0;
int64_t qk_nope_head_dim_ = 0;
int64_t qk_rope_head_dim_ = 0;
// NOTE(review): presumably qk_nope_head_dim_ + qk_rope_head_dim_ — confirm.
int64_t qk_head_dim_ = 0;
int64_t v_head_dim_ = 0;

// NOTE(review): use_q_lora_ defaults to true here; presumably cleared when config.q_lora_rank is
// nullopt. The *_use_tp_ flags presumably mirror the TransformerConfig fields that select
// ColumnParallelLinear for the corresponding down projection — confirm.
bool use_q_lora_ = true;
bool q_down_proj_use_tp_ = false;
bool kv_down_proj_use_tp_ = false;

// Private setup helper taking the same config as the constructor.
// NOTE(review): presumably fills the members above and creates the sub-layers — confirm.
void SetupAttention(const TransformerConfig &config);
};

} // namespace infini_train::nn
10 changes: 10 additions & 0 deletions infini_train/include/nn/modules/transformer/transformer_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,16 @@ struct TransformerConfig {
float rope_theta = 500000.0f; // theta in RoPE
bool use_scaled_rope = false; // scaled RoPE

// MLA config
bool multi_latent_attention = false; // Use MLA instead of standard causal self-attention.
std::optional<int64_t> q_lora_rank = std::nullopt; // nullopt means direct linear_q_proj path.
int64_t kv_lora_rank = 0; // 0 falls back to n_embd in MLASelfAttention.
int64_t qk_nope_head_dim = 0; // 0 falls back to n_embd / n_head.
int64_t qk_rope_head_dim = 0; // 0 falls back to n_embd / n_head.
int64_t v_head_dim = 0; // 0 falls back to n_embd / n_head.
bool q_down_proj_use_tp = false; // Use ColumnParallelLinear for linear_q_down_proj.
bool kv_down_proj_use_tp = false; // Use ColumnParallelLinear for linear_kv_down_proj.

// Normalization
float norm_eps = 1e-5f; // epsilon in RMSNorm

Expand Down
6 changes: 6 additions & 0 deletions infini_train/include/nn/modules/transformer/utils.h
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
#pragma once

#include <cstdint>
#include <memory>
#include <tuple>

#include "infini_train/include/tensor.h"

namespace infini_train {
// RoPE helper method
// Returns a tensor built from `dim` and `end`; `theta` defaults to 10000.0f, `use_scaled` defaults to
// false, and `device` defaults to a default-constructed Device.
// NOTE(review): presumably this produces the `freqs_cis` tensor consumed by ApplyRotaryEmbedding, with
// `dim` the per-head width and `end` the maximum sequence length — confirm against the definition.
std::shared_ptr<Tensor> PrecomputeFreqsCis(int64_t dim, int64_t end, float theta = 10000.0f, bool use_scaled = false,
Device device = Device());

// Free-function RoPE application: takes a query tensor, a key tensor and `freqs_cis`, and returns a
// tuple of two tensors.
// NOTE(review): presumably (rotated xq, rotated xk) in that order — confirm against the definition.
std::tuple<std::shared_ptr<Tensor>, std::shared_ptr<Tensor>>
ApplyRotaryEmbedding(const std::shared_ptr<Tensor> &xq, const std::shared_ptr<Tensor> &xk,
const std::shared_ptr<Tensor> &freqs_cis);
} // namespace infini_train
1 change: 1 addition & 0 deletions infini_train/include/nn/parallel/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ std::vector<int> GetPipelineParallelGroupRanks(int global_rank);

// TP/SP Communication Helper Functions
// Each helper takes a single input tensor and returns a vector of tensors.
// NOTE(review): the gather / scatter / reduce-scatter behavior and the meaning of the "TP" and "SP"
// regions (presumably tensor-parallel and sequence-parallel) are inferred from the names only — confirm
// against the definitions before relying on them.
std::vector<std::shared_ptr<Tensor>> GatherFromTPRegionFunc(const std::shared_ptr<Tensor> &input);
std::vector<std::shared_ptr<Tensor>> ScatterToSPRegionFunc(const std::shared_ptr<Tensor> &input);
std::vector<std::shared_ptr<Tensor>> ReduceScatterToSPRegionFunc(const std::shared_ptr<Tensor> &input);
std::vector<std::shared_ptr<Tensor>> GatherFromSPRegionFunc(const std::shared_ptr<Tensor> &input);
std::vector<std::shared_ptr<Tensor>> ScatterToTPRegionFunc(const std::shared_ptr<Tensor> &input);
Expand Down
38 changes: 1 addition & 37 deletions infini_train/src/nn/modules/transformer/causal_self_attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "infini_train/include/nn/modules/normalization.h"
#include "infini_train/include/nn/modules/sparse.h"
#include "infini_train/include/nn/modules/transformer/transformer_config.h"
#include "infini_train/include/nn/modules/transformer/utils.h"
#include "infini_train/include/nn/parallel/global.h"
#include "infini_train/include/nn/parallel/tensor_parallel.h"
#include "infini_train/include/tensor.h"
Expand Down Expand Up @@ -130,43 +131,6 @@ CausalSelfAttention::ForwardStandard(const std::vector<std::shared_ptr<infini_tr
return {y};
}

// RoPE helper methods
//
// Rotates the query and key tensors by the angles stored in `freqs_cis`.
//
// xq, xk    : tensors whose last axis is split into interleaved (even, odd) pairs; xq is read as
//             (B, T, H, D) to obtain T and D.
// freqs_cis : viewed as (1, T, 1, D/2, 2); index 0 of the last axis is used as cos, index 1 as sin.
// Returns   : {rotated xq, rotated xk}, each with the pair axis flattened back into the last axis.
std::tuple<std::shared_ptr<infini_train::Tensor>, std::shared_ptr<infini_train::Tensor>>
CausalSelfAttention::ApplyRotaryEmbedding(const std::shared_ptr<infini_train::Tensor> &xq,
                                          const std::shared_ptr<infini_train::Tensor> &xk,
                                          const std::shared_ptr<infini_train::Tensor> &freqs_cis) {
    // Sequence length and head width come from the query layout (B, T, H, D).
    const auto &q_dims = xq->Dims();
    const int64_t seq_len = q_dims[1];
    const int64_t head_dim = q_dims[3];

    // Give freqs_cis singleton batch/head axes so it broadcasts against the sliced inputs.
    std::vector<int64_t> broadcast_shape = {1, seq_len, 1, head_dim / 2, 2};
    auto freqs = freqs_cis->View(broadcast_shape);          // (1, T, 1, D/2, 2)
    auto cos_part = freqs->Slice(-1, 0, 1, 1)->Squeeze(-1); // (1, T, 1, D/2)
    auto sin_part = freqs->Slice(-1, 1, 2, 1)->Squeeze(-1); // (1, T, 1, D/2)

    // The same rotation is applied to the query and then to the key:
    //   left  = even * cos - odd * sin
    //   right = even * sin + odd * cos
    // and the two halves are re-interleaved by stacking on a new last axis and flattening it away.
    auto rotate = [&cos_part, &sin_part](const std::shared_ptr<Tensor> &x) {
        auto even = x->Slice(-1, 0, x->Dims().back(), 2);
        auto odd = x->Slice(-1, 1, x->Dims().back(), 2);
        auto rotated_left = even * cos_part - odd * sin_part;
        auto rotated_right = even * sin_part + odd * cos_part;
        return nn::function::Stack(std::vector<std::shared_ptr<Tensor>>{rotated_left, rotated_right}, -1)
            ->Flatten(-2);
    };

    auto q_rotated = rotate(xq);
    auto k_rotated = rotate(xk);

    return {q_rotated, k_rotated};
}

std::shared_ptr<infini_train::Tensor> CausalSelfAttention::RepeatKV(const std::shared_ptr<infini_train::Tensor> &x,
int64_t n_rep) {
const auto &shape = x->Dims();
Expand Down
Loading
Loading