Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions csrc/pybind11/engine/engine.hpp
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个改动需要解释一下,不加这个会发生什么?
另外这里确实有个bug,回头可以跟老马商量一下
@ma-hang

Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,13 @@ inline void bind_infer_engine(py::module &m) {
return state_dict_tp_all;
})
.def("process_weights_after_loading", &InferEngine::process_weights_after_loading, "Process the weights after loading on all workers (e.g., for quantization)")
.def("forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments")
.def("reset_cache", [](InferEngine &self, std::shared_ptr<cache::CacheConfig> cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, py::arg("cache_config") = py::none())
.def("forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); },
"Run inference on all ranks with arbitrary arguments",
py::call_guard<py::gil_scoped_release>())
.def("reset_cache", [](InferEngine &self, std::shared_ptr<cache::CacheConfig> cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); },
py::arg("cache_config") = py::none(),
py::call_guard<py::gil_scoped_release>())

.def("get_cache_config", [](const InferEngine &self) -> std::shared_ptr<cache::CacheConfig> {
auto cfg = self.get_cache_config();
return cfg ? std::shared_ptr<cache::CacheConfig>(cfg->unique_copy()) : nullptr; })
Expand Down
85 changes: 60 additions & 25 deletions examples/bench.py
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个是顺手加了支持paged的warmup是么?

Original file line number Diff line number Diff line change
Expand Up @@ -325,48 +325,83 @@ def run(
# Warmup
# ---------------------------------------------------------------------------- #
if cfg.warmup:
warmup_steps = 1

# warmup cache capacity
warmup_cache_len = 128
warmup_batch = len(test.input_ids_list)

test.model.reset_cache(
StaticKVCacheConfig(
max_batch_size=warmup_batch,
max_cache_len=warmup_cache_len,
print("=================== warmup start ===================")
# -------------------------------------------------------- #
# reset cache before warmup
# support both paged cache and static cache
# -------------------------------------------------------- #
if cache_config is not None:
# Paged KVCache
test.model.reset_cache(cache_config)
else:
# Static KVCache
max_batch_size = max(c["batch_size"] for _, c in cases_dict.items())
max_cache_len = max(
c["input_len"] + c["output_len"]
for _, c in cases_dict.items()
)
)

avg_prompt_len = min(64, max(len(ids) for ids in test.input_ids_list))

warmup_ids = [
ids[:avg_prompt_len] if len(ids) >= avg_prompt_len else ids
for ids in test.input_ids_list
]
test.model.reset_cache(
StaticKVCacheConfig(
max_batch_size=max_batch_size,
max_cache_len=max_cache_len,
)
)

input_ids_infini = infinicore.from_list(warmup_ids)
warmup_shapes = []
seen = set()
for _, case in cases_dict.items():
key = (case["batch_size"], case["input_len"])
if key in seen:
continue
seen.add(key)
warmup_shapes.append((case["batch_size"], case["input_len"]))

for w_batch, w_input_len in warmup_shapes:
tqdm.write(
f"\033[93m[warmup] batch={w_batch}, input_len={w_input_len}, "
f"will prefill + 3 decode steps\033[0m"
)

print("=================== warmup start ===================")
warmup_ids = repeat_prompt(test.input_ids_list[0], target_length=w_input_len)
warmup_ids_list = [warmup_ids] * w_batch
warmup_input = infinicore.from_list(warmup_ids_list)

for _ in range(warmup_steps):
_ = test.model.generate(
input_ids_infini,
warmup_input,
GenerationConfig(
max_new_tokens=5, # decode kernel warmup
temperature=cfg.temperature,
max_new_tokens=3,
eos_token_id=[],
top_k=cfg.top_k,
top_p=cfg.top_p,
temperature=cfg.temperature,
stop_on_eos=False,
),
_measure_and_log_time=False,
)

print("=================== warmup done ====================")

# reset cache back to benchmark config
# -------------------------------------------------------- #
# reset cache back to benchmark config
# support both paged cache and static cache
# -------------------------------------------------------- #
if cache_config is not None:
# Paged KVCache
test.model.reset_cache(cache_config)
else:
# Static KVCache
max_batch_size = max(c["batch_size"] for _, c in cases_dict.items())
max_cache_len = max(
c["input_len"] + c["output_len"]
for _, c in cases_dict.items()
)

test.model.reset_cache(
StaticKVCacheConfig(
max_batch_size=max_batch_size,
max_cache_len=max_cache_len,
)
)

# ---------------------------------------------------------------------------- #
# Warmup done
Expand Down