From 4a83db6468be0d88c3a23ba1e9551234ae93f5aa Mon Sep 17 00:00:00 2001 From: zhushuang Date: Wed, 27 May 2026 19:50:02 +0800 Subject: [PATCH] issue/394 - feat: support flash-attn via MooreThreads/mate for moore gpu --- csrc/pybind11/engine/engine.hpp | 9 +++- examples/bench.py | 85 +++++++++++++++++++++++---------- 2 files changed, 67 insertions(+), 27 deletions(-) diff --git a/csrc/pybind11/engine/engine.hpp b/csrc/pybind11/engine/engine.hpp index 0d480bbf..7361fedf 100644 --- a/csrc/pybind11/engine/engine.hpp +++ b/csrc/pybind11/engine/engine.hpp @@ -70,8 +70,13 @@ inline void bind_infer_engine(py::module &m) { return state_dict_tp_all; }) .def("process_weights_after_loading", &InferEngine::process_weights_after_loading, "Process the weights after loading on all workers (e.g., for quantization)") - .def("forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments") - .def("reset_cache", [](InferEngine &self, std::shared_ptr cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, py::arg("cache_config") = py::none()) + .def("forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, + "Run inference on all ranks with arbitrary arguments", + py::call_guard()) + .def("reset_cache", [](InferEngine &self, std::shared_ptr cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, + py::arg("cache_config") = py::none(), + py::call_guard()) + .def("get_cache_config", [](const InferEngine &self) -> std::shared_ptr { auto cfg = self.get_cache_config(); return cfg ? std::shared_ptr(cfg->unique_copy()) : nullptr; }) diff --git a/examples/bench.py b/examples/bench.py index 891f2cb0..1c54b158 100644 --- a/examples/bench.py +++ b/examples/bench.py @@ -325,48 +325,83 @@ def run( # Warmup # ---------------------------------------------------------------------------- # if cfg.warmup: - warmup_steps = 1 - - # warmup cache capacity - warmup_cache_len = 128 - warmup_batch = len(test.input_ids_list) - - test.model.reset_cache( - StaticKVCacheConfig( - max_batch_size=warmup_batch, - max_cache_len=warmup_cache_len, + print("=================== warmup start ===================") + # -------------------------------------------------------- # + # reset cache before warmup + # support both paged cache and static cache + # -------------------------------------------------------- # + if cache_config is not None: + # Paged KVCache + test.model.reset_cache(cache_config) + else: + # Static KVCache + max_batch_size = max(c["batch_size"] for _, c in cases_dict.items()) + max_cache_len = max( + c["input_len"] + c["output_len"] + for _, c in cases_dict.items() ) - ) - avg_prompt_len = min(64, max(len(ids) for ids in test.input_ids_list)) - - warmup_ids = [ - ids[:avg_prompt_len] if len(ids) >= avg_prompt_len else ids - for ids in test.input_ids_list - ] + test.model.reset_cache( + StaticKVCacheConfig( + max_batch_size=max_batch_size, + max_cache_len=max_cache_len, + ) + ) - input_ids_infini = infinicore.from_list(warmup_ids) + warmup_shapes = [] + seen = set() + for _, case in cases_dict.items(): + key = (case["batch_size"], case["input_len"]) + if key in seen: + continue + seen.add(key) + warmup_shapes.append((case["batch_size"], case["input_len"])) + + for w_batch, w_input_len in warmup_shapes: + tqdm.write( + f"\033[93m[warmup] batch={w_batch}, input_len={w_input_len}, " + f"will prefill + 3 decode steps\033[0m" + ) - print("=================== warmup start ===================") + warmup_ids = repeat_prompt(test.input_ids_list[0], target_length=w_input_len) + warmup_ids_list = [warmup_ids] * w_batch + warmup_input = infinicore.from_list(warmup_ids_list) - for _ in range(warmup_steps): _ = test.model.generate( - input_ids_infini, + warmup_input, GenerationConfig( - max_new_tokens=5, # decode kernel warmup - temperature=cfg.temperature, + max_new_tokens=3, + eos_token_id=[], top_k=cfg.top_k, top_p=cfg.top_p, + temperature=cfg.temperature, stop_on_eos=False, ), _measure_and_log_time=False, ) print("=================== warmup done ====================") - - # reset cache back to benchmark config + # -------------------------------------------------------- # + # reset cache back to benchmark config + # support both paged cache and static cache + # -------------------------------------------------------- # if cache_config is not None: + # Paged KVCache test.model.reset_cache(cache_config) + else: + # Static KVCache + max_batch_size = max(c["batch_size"] for _, c in cases_dict.items()) + max_cache_len = max( + c["input_len"] + c["output_len"] + for _, c in cases_dict.items() + ) + + test.model.reset_cache( + StaticKVCacheConfig( + max_batch_size=max_batch_size, + max_cache_len=max_cache_len, + ) + ) # ---------------------------------------------------------------------------- # # Warmup done