Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions custom_ops/gpu_ops/cpp_extensions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -622,6 +622,14 @@ void GetPositionIds(const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& position_ids);
// Forward declaration of the op implemented in
// gpu_ops/get_position_ids_and_slot_mapping.cu.
// Fills `position_ids` and `slot_mapping` in place (both are passed as const
// references but are written through by the implementation) for the tokens
// described by `seq_lens_decoder` / `seq_lens_this_time`, resolving slots via
// `batch_id_per_token`, `block_tables` and `block_size`.
void GetPositionIdsAndSlotMapping(const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& block_tables,
const paddle::Tensor& position_ids,
const paddle::Tensor& slot_mapping,
const int block_size);

std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
const paddle::Tensor& kv_nope,
Expand Down Expand Up @@ -1731,6 +1739,9 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
#endif

m.def("get_position_ids", &GetPositionIds, "get_position_ids function");
m.def("get_position_ids_and_slot_mapping",
&GetPositionIdsAndSlotMapping,
"get_position_ids_and_slot_mapping function");

/**
* cutlass_scaled_mm.cu
Expand Down
108 changes: 108 additions & 0 deletions custom_ops/gpu_ops/get_position_ids_and_slot_mapping.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"
#include "paddle/extension.h"

// Writes the position id and the slot index of every token scheduled in the
// current step into the flat `position_ids` / `slot_mapping` buffers.
//
// One thread handles one batch entry `bid` (global thread index):
//   * the entry's tokens start at flat offset sum(seq_lens_this_time[0..bid));
//   * token i of the entry gets position seq_lens_decoder[bid] + i;
//   * its slot is
//       block_tables[batch_id * max_num_blocks + pos / block_size] * block_size
//       + pos % block_size,
//     with batch_id = batch_id_per_token[flat index].
//
// `seq_lens_encoder` is part of the signature but is not read here.
// `block_size` must be non-zero (it is used as a divisor).
__global__ void GetPositionIdsAndSlotMappingKernel(
    const int* __restrict__ seq_lens_encoder,
    const int* __restrict__ seq_lens_decoder,
    const int* __restrict__ seq_lens_this_time,
    const int* __restrict__ batch_id_per_token,
    const int* __restrict__ block_tables,
    const int bsz,
    const int max_num_blocks,
    const int block_size,
    int64_t* __restrict__ position_ids,
    int64_t* __restrict__ slot_mapping) {
  // Use the global thread index so the kernel stays correct when the batch is
  // spread over several thread blocks (bsz larger than one block can hold).
  const int current_bid = blockIdx.x * blockDim.x + threadIdx.x;
  if (current_bid >= bsz) return;

  // Entries without tokens in this step have nothing to write.
  const int token_num_this_batch = seq_lens_this_time[current_bid];
  if (token_num_this_batch <= 0) return;

  // Offset of this entry's first token inside the flat output buffers.
  int buffer_offset = 0;
  for (int i = 0; i < current_bid; ++i) {
    buffer_offset += seq_lens_this_time[i];
  }

  // Position of this entry's first token in this step.
  const int token_offset = seq_lens_decoder[current_bid];

  for (int i = 0; i < token_num_this_batch; ++i) {
    const int pos_id = token_offset + i;
    const int idx = buffer_offset + i;

    position_ids[idx] = pos_id;

    // Split the position into (logical block, offset inside the block).
    const int block_idx = pos_id / block_size;
    const int block_offset = pos_id % block_size;

    // Translate the logical block into the block id stored in block_tables.
    const int batch_id = batch_id_per_token[idx];
    const int block_id = block_tables[batch_id * max_num_blocks + block_idx];

    // Widen before multiplying so the slot index cannot overflow 32 bits.
    slot_mapping[idx] =
        static_cast<int64_t>(block_id) * block_size + block_offset;
  }
}

// Host entry point: launches GetPositionIdsAndSlotMappingKernel on the stream
// of `seq_lens_this_time`.
//
// `position_ids` and `slot_mapping` are taken by const reference but are
// written in place (hence the const_cast below); the op registration maps
// them to outputs through its in-place map.
//
// bsz is taken from seq_lens_this_time.shape()[0] and max_num_blocks from
// block_tables.shape()[1]. Nothing is launched when the batch or the
// position_ids buffer is empty.
void GetPositionIdsAndSlotMapping(const paddle::Tensor& seq_lens_encoder,
                                  const paddle::Tensor& seq_lens_decoder,
                                  const paddle::Tensor& seq_lens_this_time,
                                  const paddle::Tensor& batch_id_per_token,
                                  const paddle::Tensor& block_tables,
                                  const paddle::Tensor& position_ids,
                                  const paddle::Tensor& slot_mapping,
                                  const int block_size) {
  const int bsz = static_cast<int>(seq_lens_this_time.shape()[0]);
  const int total_token_num = static_cast<int>(position_ids.shape()[0]);

  // A launch with zero threads is an invalid configuration, and with an empty
  // output buffer there is nothing that may legally be written.
  if (bsz <= 0 || total_token_num <= 0) return;

  const int max_num_blocks = static_cast<int>(block_tables.shape()[1]);

  // CUDA allows at most 1024 threads per block. Batches up to that size keep
  // the single-block launch; larger ones are split across several blocks.
  constexpr int kMaxThreadsPerBlock = 1024;
  const int threads_per_block =
      bsz < kMaxThreadsPerBlock ? bsz : kMaxThreadsPerBlock;
  const int num_blocks = (bsz + threads_per_block - 1) / threads_per_block;

  GetPositionIdsAndSlotMappingKernel<<<num_blocks,
                                       threads_per_block,
                                       0,
                                       seq_lens_this_time.stream()>>>(
      seq_lens_encoder.data<int>(),
      seq_lens_decoder.data<int>(),
      seq_lens_this_time.data<int>(),
      batch_id_per_token.data<int>(),
      block_tables.data<int>(),
      bsz,
      max_num_blocks,
      block_size,
      const_cast<int64_t*>(position_ids.data<int64_t>()),
      const_cast<int64_t*>(slot_mapping.data<int64_t>()));
}

// Registers the `get_position_ids_and_slot_mapping` custom op.
// `position_ids` and `slot_mapping` are listed as inputs and mapped in place
// onto the outputs `position_ids_out` / `slot_mapping_out`, so the kernel
// function writes its results directly into the caller-provided tensors.
// `block_size` is passed as an integer attribute.
PD_BUILD_STATIC_OP(get_position_ids_and_slot_mapping)
.Inputs({
"seq_lens_encoder",
"seq_lens_decoder",
"seq_lens_this_time",
"batch_id_per_token",
"block_tables",
"position_ids",
"slot_mapping",
})
.Attrs({"block_size: int"})
.Outputs({"position_ids_out", "slot_mapping_out"})
.SetInplaceMap({{"position_ids", "position_ids_out"},
{"slot_mapping", "slot_mapping_out"}})
.SetKernelFn(PD_KERNEL(GetPositionIdsAndSlotMapping));
1 change: 1 addition & 0 deletions custom_ops/setup_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,7 @@ def find_end_files(directory, end_str):
"gpu_ops/sample_kernels/top_k_renorm_probs.cu",
"gpu_ops/sample_kernels/min_p_sampling_from_probs.cu",
"gpu_ops/get_position_ids.cu",
"gpu_ops/get_position_ids_and_slot_mapping.cu",
"gpu_ops/fused_rotary_position_encoding.cu",
"gpu_ops/noaux_tc.cu",
"gpu_ops/noaux_tc_redundant.cu",
Expand Down
37 changes: 36 additions & 1 deletion fastdeploy/cache_manager/routing_cache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ def __init__(self, fd_config, num_gpu_blocks: int):
self.routing_dtype = routing_replay_config.routing_dtype
self.only_last_turn = routing_replay_config.only_last_turn
self.use_fused_put = routing_replay_config.use_fused_put
self.debug_mode = routing_replay_config.debug_mode
self.block_size = fd_config.cache_config.block_size
self.return_mode = (
routing_replay_config.routing_store_type
Expand Down Expand Up @@ -235,7 +236,41 @@ def gather_routing_for_request(self, block_table, seq_len: int) -> np.ndarray:
block_indices = positions // self.block_size
offsets = positions % self.block_size
slot_mapping = np.array(block_ids)[block_indices] * self.block_size + offsets
return self.host_view.gather(slot_mapping)
routing_data = self.host_view.gather(slot_mapping)

if self.debug_mode:
expected_routing = np.arange(seq_len, dtype=routing_data.dtype)[:, None, None]
expected_routing = np.broadcast_to(expected_routing, (seq_len, self.num_moe_layers, self.moe_top_k))
if not np.array_equal(routing_data, expected_routing):
# Find all mismatched tokens
mismatch_mask = (routing_data != expected_routing).any(axis=(1, 2))
mismatched_token_indices = np.where(mismatch_mask)[0]
# Check for duplicate slots in gather
unique_slots, counts = np.unique(slot_mapping, return_counts=True)
num_duplicates = np.sum(counts > 1)
dup_info = ""
if num_duplicates > 0:
dup_indices = np.where(counts > 1)[0]
dup_slots = unique_slots[dup_indices]
dup_info = f", duplicate_slots={list(dup_slots)}"
logger.error(
f"[R3 Debug] Gather mismatch! seq_len={seq_len}, mismatched_tokens={len(mismatched_token_indices)}, "
f"slots=[{slot_mapping[0]}...{slot_mapping[-1]}]{dup_info}"
)
logger.error(f"Mismatched token indices: {mismatched_token_indices}")
for idx in mismatched_token_indices: # Print all mismatches tokens
logger.error(
f" position={idx}, slot={slot_mapping[idx]}, "
f"expected={expected_routing[idx, 0, 0]}, actual={routing_data[idx, 0, 0]}"
)
raise ValueError("[R3 Debug]Routing gather validation failed.")
else:
logger.debug(
f"[R3 Debug] Gather validation passed: seq_len={seq_len}, "
f"slots=[{slot_mapping[0]}...{slot_mapping[-1]}]"
)

return routing_data

def on_request_finished(self, request_id: str, block_table, seq_len: int) -> Optional[np.ndarray]:
"""
Expand Down
5 changes: 5 additions & 0 deletions fastdeploy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1855,6 +1855,9 @@ def __init__(self, args) -> None:
# Fused routing of all layers
self.use_fused_put: bool = False

# Debug mode: hack topk_ids to use position_ids for validation
self.debug_mode: bool = False

# Auto-filled by FDConfig from ModelConfig (do not set manually)
self.routing_dtype: str = "" # "uint8" / "uint16" / "uint32"
self.num_moe_layers: int = 0
Expand Down Expand Up @@ -1885,6 +1888,8 @@ def postprocess(self, model_config: "ModelConfig") -> None:
self.routing_dtype = "uint32"
else:
raise ValueError(f"num_experts {num_experts} exceeds uint32 range")
if self.debug_mode:
self.routing_dtype = "int64"

def to_json_string(self):
"""
Expand Down
2 changes: 2 additions & 0 deletions fastdeploy/model_executor/layers/moe/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -754,6 +754,8 @@ def forward(
ep_size=self.fd_config.parallel_config.expert_parallel_size,
tp_group=self.fd_config.parallel_config.tp_group,
total_token_num=forward_meta.batch_id_per_token.shape[0],
position_ids=forward_meta.position_ids,
debug_mode=self.fd_config.routing_replay_config.debug_mode,
)
if current_platform.is_intel_hpu():
out = self.forward_normal(
Expand Down
106 changes: 104 additions & 2 deletions fastdeploy/model_executor/layers/moe/routing_indices_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ def save_routing_to_buffer_v2(
ep_size: int,
tp_group: dist.communication.group.Group,
total_token_num: int = -1,
position_ids: paddle.Tensor = None,
debug_mode: bool = False,
):
token_num_per_rank = topk_ids.shape[0]
if token_num_per_rank == 0:
Expand All @@ -83,6 +85,12 @@ def save_routing_to_buffer_v2(
), f"[R3] total_token_num={total_token_num} < token_num_per_rank={token_num_per_rank}"
topk_ids = topk_ids_all[:total_token_num, :]

if debug_mode and position_ids is not None:
token_num, top_k = topk_ids.shape
hack_ids = position_ids[:token_num].cast(topk_ids.dtype)
hack_ids = hack_ids.unsqueeze(1).expand([-1, top_k])
topk_ids = hack_ids

token_num, top_k = topk_ids.shape
buf_max_tokens, num_moe_layers, buf_top_k = device_routing_buffer.shape

Expand Down Expand Up @@ -124,7 +132,9 @@ def __init__(self, fd_config: FDConfig, total_block_num: int):
self.num_moe_layers = rrc.num_moe_layers
self.moe_top_k = rrc.moe_top_k
self.routing_dtype = rrc.routing_dtype
self.debug_mode = rrc.debug_mode
self.tp_rank = fd_config.parallel_config.tensor_parallel_rank
self.token_num_overlap = 0

logger.info(f"[R3] RoutedExpertsCapturer config: {rrc}")

Expand All @@ -133,6 +143,7 @@ def __init__(self, fd_config: FDConfig, total_block_num: int):
def _init_routing_cache(self, dtype: str, total_block_num: int):
"""Initialize GPU transient buffer, staging buffers, and CPU pinned buffers."""
max_num_kv_tokens = total_block_num * self.fd_config.cache_config.block_size
self.max_num_kv_tokens = max_num_kv_tokens # Save for slot range validation

# Small GPU transient buffer: only current step's token routing
# TODO(Chengyanfu): Use max_num_batched_tokens to replace get_max_chunk_tokens()
Expand All @@ -145,6 +156,14 @@ def _init_routing_cache(self, dtype: str, total_block_num: int):

self.cpu_routing_buf = paddle.zeros(shape, dtype=dtype).pin_memory()
self.cpu_slot_mapping_buf = paddle.zeros([max_num_batched_tokens], dtype=paddle.int64).pin_memory()

if self.debug_mode:
self.position_ids_staging_buf = paddle.zeros([max_num_batched_tokens], dtype=paddle.int64)
self.cpu_position_ids_buf = paddle.zeros([max_num_batched_tokens], dtype=paddle.int64).pin_memory()
else:
self.position_ids_staging_buf = None
self.cpu_position_ids_buf = None

self._pending_save = None # {"num_tokens": int}

# Lazy attach to SharedMemory routing_host_buffer (created by Engine after profiling)
Expand Down Expand Up @@ -180,7 +199,9 @@ def _try_attach_routing_host_view(self):
"Routing capture will be skipped."
)

def prepare_pending_save(self, num_tokens: int, slot_mapping_gpu: paddle.Tensor):
def prepare_pending_save(
self, num_tokens: int, slot_mapping_gpu: paddle.Tensor, position_ids_gpu: paddle.Tensor = None
):
"""
Enqueue D2D + async D2H for routing data and slot_mapping.
Must be called before post_process_event.record().
Expand All @@ -190,14 +211,25 @@ def prepare_pending_save(self, num_tokens: int, slot_mapping_gpu: paddle.Tensor)
2. D2D (non-blocking): slot_mapping_gpu → slot_mapping_staging_buf
3. async D2H: routing_staging_buf → cpu_routing_buf
4. async D2H: slot_mapping_staging_buf → cpu_slot_mapping_buf
5. async D2H (debug mode): position_ids_gpu → cpu_position_ids_buf
"""
if num_tokens > 0:
if self.fd_config.scheduler_config.enable_overlap_schedule:
num_tokens = self.token_num_overlap

This comment was marked as outdated.

slot_mapping_gpu = slot_mapping_gpu[:num_tokens]
position_ids_gpu = position_ids_gpu[:num_tokens]

# D2D: GPU → staging
self.routing_staging_buf.copy_(self.device_routing_buffer, False)
self.slot_mapping_staging_buf.copy_(slot_mapping_gpu, False)
# Async D2H: staging → CPU pinned
self.cpu_routing_buf.copy_(self.routing_staging_buf, False)
self.cpu_slot_mapping_buf.copy_(self.slot_mapping_staging_buf, False)

if self.debug_mode and position_ids_gpu is not None and self.cpu_position_ids_buf is not None:
self.position_ids_staging_buf.copy_(position_ids_gpu, False)
self.cpu_position_ids_buf.copy_(self.position_ids_staging_buf, False)

self._pending_save = {"num_tokens": num_tokens}
else:
self._pending_save = None
Expand All @@ -222,7 +254,77 @@ def flush_pending_save(self):
num_tokens = pending["num_tokens"]
# NOTE(gongshaotian): Slice pinned memory tensor maybe cause problem.
data = self.cpu_routing_buf.cpu()[:num_tokens].numpy()
slot_np = self.cpu_slot_mapping_buf.cpu()[:num_tokens].numpy()
slot_cpu = self.cpu_slot_mapping_buf.cpu()
slot_cpu_slice = slot_cpu[:num_tokens]
slot_np = slot_cpu_slice.numpy()

if self.debug_mode and self.cpu_position_ids_buf is not None:
position_ids = self.cpu_position_ids_buf.cpu()[:num_tokens].numpy()
expected_routing = position_ids[:, None, None]
expected_routing = np.broadcast_to(expected_routing, (num_tokens, self.num_moe_layers, self.moe_top_k))
if not np.array_equal(data, expected_routing):
# 1. Check routing capture
mismatch_mask = (data != expected_routing).any(axis=(1, 2))
mismatched_token_indices = np.where(mismatch_mask)[0]
logger.error(
f"[R3 Debug] flush mismatch! num_tokens={num_tokens}, mismatched_tokens={len(mismatched_token_indices)}"
)
logger.error(f"Mismatched token indices: {mismatched_token_indices}")
for idx in mismatched_token_indices:
logger.error(
f" token={idx}, position_id={position_ids[idx]}, slot={slot_np[idx]}, "
f"expected={expected_routing[idx, :, :]}, actual={data[idx, :, :]}"
)
raise ValueError("Routing data verification failed.")
else:
# 2. Check slot mapping generation and validate slot indices (should be >= 0)
if slot_cpu_slice.min() < 0:
error_parts = [f"[R3 Debug] Invalid slot indices: num_tokens={num_tokens}"]
error_parts.append(" token |slot_staging | slot_pinned | slot_cpu | position_id | data[0,0]")
error_parts.append(" " + "-" * 50)
for i in range(num_tokens):
error_parts.append(
f" {i:4d} | {int(self.slot_mapping_staging_buf[i]):7d} | {int(self.cpu_slot_mapping_buf[i]):7d} | {int(slot_cpu[i]):7d} | {int(position_ids[i]):11d} | {int(data[i, 0, 0])}"
)
raise AssertionError("\n".join(error_parts))
# 2.1 Check slot range (should be < max_num_kv_tokens)
max_slot = slot_cpu_slice.max()
if max_slot >= self.max_num_kv_tokens:
invalid_slots = np.where(slot_np >= self.max_num_kv_tokens)[0]
error_parts = [
f"[R3 Debug] Slot indices out of range: num_tokens={num_tokens}, "
f"max_slot={max_slot}, max_num_kv_tokens={self.max_num_kv_tokens}"
]
error_parts.append(f" Invalid slot indices: {invalid_slots[:10]}... ({len(invalid_slots)} total)")
error_parts.append(" token |slot | position_id | data[0,0]")
error_parts.append(" " + "-" * 50)
for idx in invalid_slots[:10]:
error_parts.append(
f" {idx:4d} | {int(slot_np[idx]):6d} | {int(position_ids[idx]):11d} | {int(data[idx, 0, 0])}"
)
raise AssertionError("\n".join(error_parts))
# 3. Check slot mapping duplicates
unique_slots, counts = np.unique(slot_np, return_counts=True)
num_unique = len(unique_slots)
num_duplicates = np.sum(counts > 1)
if num_duplicates > 0:
duplicate_indices = np.where(counts > 1)[0]
dup_slots_info = []
for slot_idx in duplicate_indices[:5]:
slot = unique_slots[slot_idx]
count = counts[slot_idx]
dup_token_indices = np.where(slot_np == slot)[0]
dup_slots_info.append(f"slot={slot} count={count} indices={dup_token_indices}")
logger.error(
f"[R3 Debug] flush validation passed but found duplicate slots! "
f"num_tokens={num_tokens}, unique_slots={num_unique}, duplicates={num_duplicates}. "
f"Details: {'; '.join(dup_slots_info)}"
)
else:
logger.debug(
f"[R3 Debug] flush validation passed: num_tokens={num_tokens}, "
f"slots=[{slot_np[0]}...{slot_np[-1]}], unique_slots={num_unique}"
)

self.routing_host_view.scatter(slot_np, data)

Expand Down
Loading
Loading