From 0c37636bb33bda768fffda4c8a9fd307f7340022 Mon Sep 17 00:00:00 2001 From: zhushuang Date: Thu, 21 May 2026 16:12:29 +0800 Subject: [PATCH] issue/1167 - feat: add flash-attn via MooreThreads/mate for moore gpu --- .gitmodules | 6 + README.md | 13 +- include/infinicore/adaptor/aten_adaptor.hpp | 19 +- python/infinicore/__init__.py | 6 + .../infinicore/ops/moore_mate_flash_attn.py | 304 ++++++++++++++++++ src/infinicore/adaptor/aten_adaptor.cc | 7 + .../ops/mha_kvcache/mha_kvcache_flashattn.cc | 90 ++++++ .../mha_varlen_flashattn.cc | 91 ++++++ third_party/mate | 1 + xmake.lua | 107 +++++- 10 files changed, 633 insertions(+), 11 deletions(-) create mode 100644 python/infinicore/ops/moore_mate_flash_attn.py create mode 160000 third_party/mate diff --git a/.gitmodules b/.gitmodules index bca919479..b7a65778a 100644 --- a/.gitmodules +++ b/.gitmodules @@ -5,3 +5,9 @@ path = third_party/nlohmann_json url = https://github.com/nlohmann/json.git branch = master +[submodule "third_party/mate"] + path = third_party/mate + url = https://github.com/MooreThreads/mate + branch = v0.1.3 + ignore = untracked + update = none diff --git a/README.md b/README.md index bd0f7fe64..f5895008f 100644 --- a/README.md +++ b/README.md @@ -155,7 +155,7 @@ python scripts/install.py [XMAKE_CONFIG_FLAGS] xmake f --ascend-npu=true -cv ``` -##### 试验功能 -- 使用flash attention库中的算子 +##### 试验功能 -- 使用英伟达平台 flash attention 库中的算子 ```shell @@ -176,6 +176,17 @@ python scripts/install.py [XMAKE_CONFIG_FLAGS] ``` +##### 试验功能 -- 使用摩尔线程开源 mate 提供的 flash attention 能力 + ```shell + #该功能依赖摩尔线程开源项目 mate(https://github.com/MooreThreads/mate) v0.1.3 版本,默认不随仓库递归拉取。 + + #若需启用 Moore MATE FlashAttention,请手动初始化对应子模块: + git -c submodule.third_party/mate.update=checkout submodule update --init --recursive third_party/mate + + #随后参考 mate v0.1.3 README 进行编译,之后在 xmake 配置环节额外打开 --aten 开关使用 mate 提供的 flash attention 能力,可参考: + xmake f --moore-gpu=y --ccl=y --aten=y -cv + ``` + 2. 
编译安装 默认安装路径为 `$HOME/.infini`。 diff --git a/include/infinicore/adaptor/aten_adaptor.hpp b/include/infinicore/adaptor/aten_adaptor.hpp index 00d5cbec2..030be7d0e 100644 --- a/include/infinicore/adaptor/aten_adaptor.hpp +++ b/include/infinicore/adaptor/aten_adaptor.hpp @@ -11,6 +11,12 @@ #include #endif +#if defined(ENABLE_MOORE_API) +#include +#include +#include +#endif + namespace infinicore::adaptor { inline at::ScalarType to_at_dtype(DataType dtype) { switch (dtype) { @@ -36,7 +42,13 @@ inline at::Device to_at_device(const Device &device) { return at::Device(at::kCUDA, device.getIndex()); } else if (device.getType() == Device::Type::CPU) { return at::Device(at::kCPU); - } else { + } +#if defined(ENABLE_MOORE_API) + else if (device.getType() == Device::Type::MOORE) { + return at::Device(at::DeviceType::PrivateUse1, device.getIndex()); + } +#endif + else { throw std::runtime_error("Unsupported device type for ATen"); } } @@ -46,6 +58,11 @@ at::Tensor to_aten_tensor(const infinicore::Tensor &t); #if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) || defined(ENABLE_QY_API) c10::cuda::CUDAStream get_cuda_stream(); #endif + +#if defined(ENABLE_MOORE_API) +c10::musa::MUSAStream get_musa_stream(); +#endif + } // namespace infinicore::adaptor #endif // ENABLE_ATEN diff --git a/python/infinicore/__init__.py b/python/infinicore/__init__.py index 8c9adc64c..6c947887d 100644 --- a/python/infinicore/__init__.py +++ b/python/infinicore/__init__.py @@ -106,6 +106,10 @@ from infinicore.ops.matmul import matmul from infinicore.ops.mha_kvcache import mha_kvcache from infinicore.ops.mha_varlen import mha_varlen +from infinicore.ops.moore_mate_flash_attn import ( + moore_mate_flash_attn_decode, + moore_mate_flash_attn_prefill, +) from infinicore.ops.mul import mul from infinicore.ops.narrow import narrow from infinicore.ops.nrm2 import nrm2 @@ -276,6 +280,8 @@ "zeros", "sum", "var_mean", + "moore_mate_flash_attn_prefill", + "moore_mate_flash_attn_decode", "var", "topk", "all", 
diff --git a/python/infinicore/ops/moore_mate_flash_attn.py b/python/infinicore/ops/moore_mate_flash_attn.py new file mode 100644 index 000000000..e50d0f6a8 --- /dev/null +++ b/python/infinicore/ops/moore_mate_flash_attn.py @@ -0,0 +1,304 @@ +""" +Paged Flash-Attention wrapper backed by MooreThreads mate (flash_attn). + +Runtime requirements: + - torch (with MUSA) + - mate (repo : https://github.com/MooreThreads/mate) + +Provides two entry points: + - moore_mate_flash_attn_decode: decode with layout (num_blocks, block_size, num_kv_heads, head_size) + - moore_mate_flash_attn_prefill: variable-length prefill (used by mha_varlen) +""" + +import torch + +try: + from flash_attn import flash_attn_with_kvcache, get_scheduler_metadata + + _MATE_AVAILABLE = True +except ImportError: + _MATE_AVAILABLE = False + + +def is_available() -> bool: + """Return True if mate / flash_attn is installed and importable.""" + return _MATE_AVAILABLE + + +def _check_mate_available(): + """Raise a clear error if mate is not installed.""" + if not _MATE_AVAILABLE: + raise RuntimeError( + "flash_attn (mate) is not installed. " + "Please build and install MooreThreads/mate first." + ) + + +# ============================================================================= +# Decode kernels +# ============================================================================= + + +@torch.inference_mode() +def moore_mate_flash_attn_decode( + q: torch.Tensor, # (num_seqs, num_heads, head_size) + k_cache: torch.Tensor, # (num_blocks, block_size, num_kv_heads, head_size) + v_cache: torch.Tensor, # (num_blocks, block_size, num_kv_heads, head_size) + block_tables: torch.Tensor, # (num_seqs, max_blocks_per_seq) + seq_lens: torch.Tensor, # (num_seqs,) + scale: float, + block_size: int, + max_seq_len: int, +) -> torch.Tensor: + """ + Decode entry point with native flash_attn KV cache layout (B, P, H, D). + No layout conversion is performed. 
+ """ + _check_mate_available() + + num_seqs, num_heads, head_size = q.shape + num_kv_heads = k_cache.shape[2] + device = q.device + + cache_seqlens = seq_lens.to(torch.int32) + page_table = block_tables.to(torch.int32) + cu_seqlens_q = torch.arange(0, num_seqs + 1, dtype=torch.int32, device=device) + pack_gqa = num_heads != num_kv_heads + + metadata = get_scheduler_metadata( + batch_size=num_seqs, + max_seqlen_q=1, + max_seqlen_k=max_seq_len, + num_heads_q=num_heads, + num_heads_kv=num_kv_heads, + headdim=head_size, + cache_seqlens=cache_seqlens, + qkv_dtype=q.dtype, + headdim_v=head_size, + cu_seqlens_q=cu_seqlens_q, + page_size=block_size, + causal=False, + window_size=(None, None), + pack_gqa=pack_gqa, + ) + + out, *_ = flash_attn_with_kvcache( + q=q, + k_cache=k_cache, + v_cache=v_cache, + cache_seqlens=cache_seqlens, + page_table=page_table, + cu_seqlens_q=cu_seqlens_q, + max_seqlen_q=1, + softmax_scale=scale, + causal=False, + scheduler_metadata=metadata, + pack_gqa=pack_gqa, + return_softmax_lse=True, + ) + return out + + +# ============================================================================= +# Prefill kernel (variable-length) +# ============================================================================= + + +@torch.inference_mode() +def moore_mate_flash_attn_prefill( + q: torch.Tensor, # (total_q, num_heads, head_size) -- varlen unpad + k_cache: torch.Tensor, # (num_blocks, block_size, num_kv_heads, head_size) + v_cache: torch.Tensor, # (num_blocks, block_size, num_kv_heads, head_size) + cu_seqlens_q: torch.Tensor, # (batch+1,) int32 + cu_seqlens_k: torch.Tensor, # (batch+1,) int32 + block_tables: torch.Tensor, # (batch, max_blocks_per_seq) + scale: float, + max_seqlen_q: int, + max_seqlen_k: int, + block_size: int, + causal: bool = True, # prefill is typically causal +) -> torch.Tensor: + """ + Variable-length prefill entry point. Layout follows flash_attn (B, P, H, D). + Intended to be called from the C++ mha_varlen Moore branch. 
+ """ + _check_mate_available() + + cu_seqlens_q = cu_seqlens_q.to(torch.int32) + cu_seqlens_k = cu_seqlens_k.to(torch.int32) + page_table = block_tables.to(torch.int32) + + # mate uses cache_seqlens (per-batch KV length), derived from cu_seqlens_k + cache_seqlens = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).to(torch.int32).contiguous() + + batch_size = cache_seqlens.shape[0] + num_heads = q.shape[1] + head_size = q.shape[2] + num_kv_heads = k_cache.shape[2] + pack_gqa = num_heads != num_kv_heads + + metadata = get_scheduler_metadata( + batch_size=batch_size, + max_seqlen_q=int(max_seqlen_q), + max_seqlen_k=int(max_seqlen_k), + num_heads_q=num_heads, + num_heads_kv=num_kv_heads, + headdim=head_size, + cache_seqlens=cache_seqlens, + qkv_dtype=q.dtype, + headdim_v=head_size, + cu_seqlens_q=cu_seqlens_q, + page_size=block_size, + causal=causal, + window_size=(None, None), + pack_gqa=pack_gqa, + ) + + out, *_ = flash_attn_with_kvcache( + q=q, + k_cache=k_cache, + v_cache=v_cache, + cache_seqlens=cache_seqlens, + page_table=page_table, + cu_seqlens_q=cu_seqlens_q, + max_seqlen_q=int(max_seqlen_q), + softmax_scale=scale, + causal=causal, + scheduler_metadata=metadata, + pack_gqa=pack_gqa, + return_softmax_lse=True, + ) + return out + + +# ============================================================================= +# Self tests +# ============================================================================= + + +def _test_moore_mate_flash_attn_decode(): + """Test moore_mate_flash_attn_decode with flash_attn layout (B, P, H, D).""" + print("\n=== Test 1: moore_mate_flash_attn_decode (decode, flash_attn layout) ===") + device = torch.device("musa") + + num_seqs, num_heads, num_kv_heads = 2, 8, 2 + head_size, block_size, max_seq_len = 128, 16, 64 + num_blocks = 32 + + q = torch.randn(num_seqs, num_heads, head_size, dtype=torch.float16, device=device) + k_cache = torch.randn( + num_blocks, + block_size, + num_kv_heads, + head_size, + dtype=torch.float16, + device=device, + ) + 
v_cache = torch.randn( + num_blocks, + block_size, + num_kv_heads, + head_size, + dtype=torch.float16, + device=device, + ) + block_tables = torch.zeros(num_seqs, 4, dtype=torch.int32, device=device) + block_tables[0, 0] = 0 + block_tables[1, 0] = 1 + seq_lens = torch.tensor([32, 48], dtype=torch.int32, device=device) + + out = moore_mate_flash_attn_decode( + q=q, + k_cache=k_cache, + v_cache=v_cache, + block_tables=block_tables, + seq_lens=seq_lens, + scale=head_size**-0.5, + block_size=block_size, + max_seq_len=max_seq_len, + ) + torch.musa.synchronize() + print(f"output shape = {tuple(out.shape)}") + assert out.shape == q.shape + print("moore_mate_flash_attn_decode passed") + + +def _test_moore_mate_flash_attn_prefill(): + """Test moore_mate_flash_attn_prefill with variable-length input.""" + print("\n=== Test 2: moore_mate_flash_attn_prefill (varlen prefill) ===") + device = torch.device("musa") + torch.manual_seed(666) + torch.musa.manual_seed(666) + + batch_size = 2 + seqlens_q = [55, 222] + seqlens_kv = [55, 222] # prefill: q_len == k_len + num_heads, num_kv_heads = 8, 2 + head_size, block_size = 128, 16 + + total_q = sum(seqlens_q) + max_q = max(seqlens_q) + max_k = max(seqlens_kv) + num_blocks_per_seq = (max_k + block_size - 1) // block_size + num_blocks = batch_size * num_blocks_per_seq + + q_unpad = torch.randn( + total_q, num_heads, head_size, dtype=torch.float16, device=device + ) + k_cache = torch.randn( + num_blocks, + block_size, + num_kv_heads, + head_size, + dtype=torch.float16, + device=device, + ) + v_cache = torch.randn( + num_blocks, + block_size, + num_kv_heads, + head_size, + dtype=torch.float16, + device=device, + ) + + cu_seqlens_q = torch.tensor( + [0] + list(torch.cumsum(torch.tensor(seqlens_q), 0).tolist()), + dtype=torch.int32, + device=device, + ) + cu_seqlens_k = torch.tensor( + [0] + list(torch.cumsum(torch.tensor(seqlens_kv), 0).tolist()), + dtype=torch.int32, + device=device, + ) + block_tables = torch.arange(num_blocks, 
dtype=torch.int32, device=device).view( + batch_size, num_blocks_per_seq + ) + + out = moore_mate_flash_attn_prefill( + q=q_unpad, + k_cache=k_cache, + v_cache=v_cache, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + block_tables=block_tables, + scale=head_size**-0.5, + max_seqlen_q=max_q, + max_seqlen_k=max_k, + block_size=block_size, + causal=True, + ) + torch.musa.synchronize() + print(f"output shape = {tuple(out.shape)}") + assert out.shape == q_unpad.shape + print("moore_mate_flash_attn_prefill passed") + + +if __name__ == "__main__": + if not is_available(): + raise SystemExit("mate / flash_attn not available, please build mate first.") + + _test_moore_mate_flash_attn_decode() + _test_moore_mate_flash_attn_prefill() + print("\nAll tests passed.") diff --git a/src/infinicore/adaptor/aten_adaptor.cc b/src/infinicore/adaptor/aten_adaptor.cc index 04db643f9..803bf67f8 100644 --- a/src/infinicore/adaptor/aten_adaptor.cc +++ b/src/infinicore/adaptor/aten_adaptor.cc @@ -39,6 +39,13 @@ c10::cuda::CUDAStream get_cuda_stream() { } #endif +#if defined(ENABLE_MOORE_API) +c10::musa::MUSAStream get_musa_stream() { + return c10::musa::getStreamFromExternal( + musaStream_t(infinicore::context::getStream()), infinicore::context::getDevice().getIndex()); +} +#endif + } // namespace infinicore::adaptor #endif // ENABLE_ATEN diff --git a/src/infinicore/ops/mha_kvcache/mha_kvcache_flashattn.cc b/src/infinicore/ops/mha_kvcache/mha_kvcache_flashattn.cc index 0167c17df..8c3ca461c 100644 --- a/src/infinicore/ops/mha_kvcache/mha_kvcache_flashattn.cc +++ b/src/infinicore/ops/mha_kvcache/mha_kvcache_flashattn.cc @@ -16,8 +16,39 @@ #define INFINICORE_FLASH_OP(name) flash::name #endif +#if defined(ENABLE_MOORE_MATE_FLASH_ATTN) +#include "infinicore/adaptor/aten_adaptor.hpp" +#include +#include +#include +#endif + namespace infinicore::op::mha_kvcache_impl::flashattn { +#if defined(ENABLE_MOORE_MATE_FLASH_ATTN) +namespace py = pybind11; + +// Lightweight RAII: Binds MUSA 
streams, +// avoiding the need to include +namespace { +class LocalMUSAStreamGuard { +public: + explicit LocalMUSAStreamGuard(const c10::musa::MUSAStream &s) + : prev_(c10::musa::getCurrentMUSAStream(s.device_index())) { + c10::musa::setCurrentMUSAStream(s); + } + ~LocalMUSAStreamGuard() { + c10::musa::setCurrentMUSAStream(prev_); + } + LocalMUSAStreamGuard(const LocalMUSAStreamGuard &) = delete; + LocalMUSAStreamGuard &operator=(const LocalMUSAStreamGuard &) = delete; + +private: + c10::musa::MUSAStream prev_; +}; +} // namespace +#endif // ENABLE_MOORE_MATE_FLASH_ATTN + struct PlannedMeta { graph::GraphTensor out, q, k_cache, v_cache, seqlens_k, block_table; std::optional alibi_slopes; @@ -43,7 +74,66 @@ void *plan(Tensor out, scale}; } +#if defined(ENABLE_MOORE_MATE_FLASH_ATTN) +static void run_moore_mate_flash_attn_decode(PlannedMeta *p) { + if (p->alibi_slopes.has_value()) { + throw std::runtime_error( + "[mha_kvcache/moore] ALiBi not supported by mate flash_attn_with_kvcache"); + } + + LocalMUSAStreamGuard guard(infinicore::adaptor::get_musa_stream()); + + auto out_tensor = infinicore::adaptor::to_aten_tensor(p->out); + auto q_4d = infinicore::adaptor::to_aten_tensor(p->q); + auto k_cache = infinicore::adaptor::to_aten_tensor(p->k_cache); + auto v_cache = infinicore::adaptor::to_aten_tensor(p->v_cache); + auto seqlens_k = infinicore::adaptor::to_aten_tensor(p->seqlens_k); + auto block_table = infinicore::adaptor::to_aten_tensor(p->block_table); + + auto q_3d = q_4d.squeeze(1); + + const int64_t block_size = k_cache.size(1); + const int64_t max_seq_len = block_table.size(1) * block_size; + + try { + py::gil_scoped_acquire gil; + py::module_ wrapper = py::module_::import("infinicore.ops.moore_mate_flash_attn"); + + py::object py_q = py::cast(q_3d); + py::object py_k_cache = py::cast(k_cache); + py::object py_v_cache = py::cast(v_cache); + py::object py_seqlens_k = py::cast(seqlens_k); + py::object py_blk_tbl = py::cast(block_table); + + py::object result = 
wrapper.attr("moore_mate_flash_attn_decode")( + py_q, + py_k_cache, + py_v_cache, + py_blk_tbl, + py_seqlens_k, + p->scale, + block_size, + max_seq_len); + + at::Tensor result_t = result.cast(); + out_tensor.copy_(result_t.unsqueeze(1)); + + result = py::none(); + py_q = py_k_cache = py_v_cache = py_seqlens_k = py_blk_tbl = py::none(); + } catch (const py::error_already_set &e) { + throw std::runtime_error( + std::string("[mha_kvcache/moore] Python error: ") + e.what()); + } +} +#endif // ENABLE_MOORE_MATE_FLASH_ATTN + void run(void *planned_meta) { +#if defined(ENABLE_MOORE_MATE_FLASH_ATTN) + auto *p = reinterpret_cast(planned_meta); + run_moore_mate_flash_attn_decode(p); + return; +#endif + #ifdef ENABLE_FLASH_ATTN #if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) || defined(ENABLE_QY_API) c10::cuda::CUDAStreamGuard guard(infinicore::adaptor::get_cuda_stream()); diff --git a/src/infinicore/ops/multi_head_attention_varlen/mha_varlen_flashattn.cc b/src/infinicore/ops/multi_head_attention_varlen/mha_varlen_flashattn.cc index f80107e7e..16401a458 100644 --- a/src/infinicore/ops/multi_head_attention_varlen/mha_varlen_flashattn.cc +++ b/src/infinicore/ops/multi_head_attention_varlen/mha_varlen_flashattn.cc @@ -10,8 +10,37 @@ #endif #endif +#if defined(ENABLE_MOORE_MATE_FLASH_ATTN) +#include "infinicore/adaptor/aten_adaptor.hpp" +#include +#include +#include +#endif + namespace infinicore::op::mha_varlen_impl::flashattn { +#if defined(ENABLE_MOORE_MATE_FLASH_ATTN) +namespace py = pybind11; + +namespace { +class LocalMUSAStreamGuard { +public: + explicit LocalMUSAStreamGuard(const c10::musa::MUSAStream &s) + : prev_(c10::musa::getCurrentMUSAStream(s.device_index())) { + c10::musa::setCurrentMUSAStream(s); + } + ~LocalMUSAStreamGuard() { + c10::musa::setCurrentMUSAStream(prev_); + } + LocalMUSAStreamGuard(const LocalMUSAStreamGuard &) = delete; + LocalMUSAStreamGuard &operator=(const LocalMUSAStreamGuard &) = delete; + +private: + c10::musa::MUSAStream prev_; 
+}; +} // namespace +#endif // ENABLE_MOORE_MATE_FLASH_ATTN + struct PlannedMeta { graph::GraphTensor out, q, k, v, cum_seqlens_q, cum_seqlens_k, block_table; int max_seqlen_q, max_seqlen_k; @@ -59,7 +88,69 @@ namespace { #endif // ENABLE_FLASH_ATTN } // namespace +#if defined(ENABLE_MOORE_MATE_FLASH_ATTN) +static void run_moore_mate_flash_attn_prefill(PlannedMeta *p) { + if (p->alibi_slopes.has_value()) { + throw std::runtime_error( + "[mha_varlen/moore] ALiBi not supported by mate flash_attn_varlen"); + } + + LocalMUSAStreamGuard guard(infinicore::adaptor::get_musa_stream()); + + auto out_tensor = infinicore::adaptor::to_aten_tensor(p->out); + auto q_tensor = infinicore::adaptor::to_aten_tensor(p->q); + auto k_cache = infinicore::adaptor::to_aten_tensor(p->k); + auto v_cache = infinicore::adaptor::to_aten_tensor(p->v); + auto cu_seqlens_q = infinicore::adaptor::to_aten_tensor(p->cum_seqlens_q); + auto cu_seqlens_k = infinicore::adaptor::to_aten_tensor(p->cum_seqlens_k); + auto block_table = infinicore::adaptor::to_aten_tensor(p->block_table); + + const int64_t block_size = k_cache.size(1); + + int max_seqlen_q_bound = static_cast(q_tensor.size(0)); + int max_seqlen_k_bound = static_cast(q_tensor.size(0)); + + try { + py::gil_scoped_acquire gil; + py::module_ wrapper = py::module_::import("infinicore.ops.moore_mate_flash_attn"); + + py::object py_q = py::cast(q_tensor); + py::object py_k = py::cast(k_cache); + py::object py_v = py::cast(v_cache); + py::object py_cuq = py::cast(cu_seqlens_q); + py::object py_cuk = py::cast(cu_seqlens_k); + py::object py_blk = py::cast(block_table); + + py::object result = wrapper.attr("moore_mate_flash_attn_prefill")( + py_q, + py_k, + py_v, + py_cuq, + py_cuk, + py_blk, + p->scale, + max_seqlen_q_bound, + max_seqlen_k_bound, + block_size, + true); + + at::Tensor result_t = result.cast(); + out_tensor.copy_(result_t); + } catch (const py::error_already_set &e) { + throw std::runtime_error( + std::string("[mha_varlen/moore] Python 
error: ") + e.what()); + } +} +#endif // ENABLE_MOORE_MATE_FLASH_ATTN + void run(void *planned_meta) { + +#if defined(ENABLE_MOORE_MATE_FLASH_ATTN) + auto *p = reinterpret_cast(planned_meta); + run_moore_mate_flash_attn_prefill(p); + return; +#endif + #ifdef ENABLE_FLASH_ATTN c10::cuda::CUDAStreamGuard guard(infinicore::adaptor::get_cuda_stream()); auto *p = reinterpret_cast(planned_meta); diff --git a/third_party/mate b/third_party/mate new file mode 160000 index 000000000..99bbf1feb --- /dev/null +++ b/third_party/mate @@ -0,0 +1 @@ +Subproject commit 99bbf1feb925ccfcbd0f01d7786543e8ce6f268b diff --git a/xmake.lua b/xmake.lua index ccae79cd2..f4d2f8f6c 100644 --- a/xmake.lua +++ b/xmake.lua @@ -237,7 +237,8 @@ option_end() if has_config("aten") then add_defines("ENABLE_ATEN") - if get_config("flash-attn") and get_config("flash-attn") ~= "" then + if get_config("flash-attn") and get_config("flash-attn") ~= "" + and (has_config("nv-gpu") or has_config("qy-gpu")) then add_defines("ENABLE_FLASH_ATTN") end end @@ -487,6 +488,16 @@ target("infinicore_cpp_api") -- from other included scripts; MetaX and QY each register their own hook in `xmake/metax.lua` -- and `xmake/qy.lua`. + -- Moore mate: pybind11/embed.h for mha_kvcache branch + if has_config("moore-gpu") and has_config("aten") then + add_packages("pybind11") + end + + -- Moore mate: enable Python bridge macro for flash-attn Moore path + if has_config("moore-gpu") and has_config("aten") then + add_defines("ENABLE_MOORE_MATE_FLASH_ATTN") + end + before_build(function (target) -- MetaX + flash-attn: `flash_attn_2_cuda` may use a different `mha_fwd_kvcache` ABI -- depending on the underlying stack version. 
When building with MACA (`--use-mc=y`), @@ -527,18 +538,96 @@ target("infinicore_cpp_api") path.join(TORCH_DIR, "lib"), { public = true } ) - target:add( - "links", - "torch", - "c10", - "torch_cuda", - "c10_cuda", - { public = true } - ) + + -- Moore mate: link torch_musa instead of torch_cuda/c10_cuda + if has_config("moore-gpu") then + target:add( + "links", + "torch", + "torch_cpu", + "torch_python", + "c10", + { public = true } + ) + + -- Detect torch_musa install path + local musa_outdata = os.iorunv("python", {"-c", "import torch_musa, os; print(os.path.dirname(torch_musa.__file__))"}):trim() + local TORCH_MUSA_DIR = musa_outdata + local MUSA_ROOT = os.getenv("MUSA_ROOT") or os.getenv("MUSA_HOME") or os.getenv("MUSA_PATH") or "/usr/local/musa" + + target:add( + "includedirs", + path.join(MUSA_ROOT, "include"), + path.directory(TORCH_MUSA_DIR), + path.join(TORCH_MUSA_DIR, "include"), + path.join(TORCH_MUSA_DIR, "share/generated_cuda_compatible/include"), + path.join(TORCH_MUSA_DIR, "share/generated_cuda_compatible"), + { public = true } + ) + + target:add( + "linkdirs", + path.join(TORCH_MUSA_DIR, "lib"), + { public = true } + ) + target:add( + "links", + "musa_python", + "musa_kernels", + { public = true } + ) + + -- libpython for pybind11::scoped_interpreter / embed + local pyinc = os.iorunv("python", {"-c", + "import sysconfig; print(sysconfig.get_path('include'))"}):trim() + local pylib = os.iorunv("python", {"-c", + "import sysconfig; print(sysconfig.get_config_var('LIBDIR'))"}):trim() + local pyver = os.iorunv("python", {"-c", + "import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}')"}):trim() + target:add("includedirs", pyinc, { public = true }) + target:add("linkdirs", pylib, { public = true }) + target:add("links", "python" .. pyver, { public = true }) + + target:add( + "shflags", + "-Wl,-rpath," .. path.join(TORCH_MUSA_DIR, "lib"), + "-Wl,-rpath," .. path.join(MUSA_ROOT, "lib"), + "-Wl,-rpath," .. 
path.join(TORCH_DIR, "lib"), + "-Wl,-rpath," .. pylib, + { force = true } + ) + else + target:add( + "links", + "torch", + "c10", + "torch_cuda", + "c10_cuda", + { public = true } + ) + end end end) + -- Moore mate: force link torch_python to bypass --as-needed + if has_config("moore-gpu") and has_config("aten") then + before_link(function (target) + local torch_dir = os.iorunv("python", {"-c", + "import torch, os; print(os.path.dirname(torch.__file__))"}):trim() + local torch_lib = path.join(torch_dir, "lib") + target:add("shflags", + "-Wl,--no-as-needed", + "-L" .. torch_lib, + "-ltorch_python", + "-ltorch_cpu", + "-lc10", + "-Wl,--as-needed", + "-Wl,-rpath," .. torch_lib, + {force = true}) + end) + end + -- Add InfiniCore C++ source files (needed for RoPE and other nn modules) add_files("src/infinicore/*.cc") add_files("src/infinicore/adaptor/*.cc")