diff --git a/examples/workers/l3/ep_dispatch_combine/kernels/aiv/dispatch.cpp b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/dispatch.cpp index 547a759b6..1ee637655 100644 --- a/examples/workers/l3/ep_dispatch_combine/kernels/aiv/dispatch.cpp +++ b/examples/workers/l3/ep_dispatch_combine/kernels/aiv/dispatch.cpp @@ -33,7 +33,8 @@ * recovers column-0 values * because columns [1, W_PAD) * are zero by design); - * recv_idx_out is [L, R] (scalar copy of column 0) + * recv_idx_out is [L, R] (TROWSUM over [L,R,IDX_PAD] + * wide window, INT32) * * Design notes: * - All cross-rank GM writes go through tile primitives (TPUT). No AIV @@ -44,9 +45,8 @@ * - Weight uses TROWSUM along the W_PAD axis to compact the wide window * [L, R, W_PAD] → [L, R] FP32: sum-of-row recovers slot [0] because the * other lanes are zero. One TLOAD + TROWSUM + TSTORE per expert. - * - Idx uses scalar GM copy of column 0 to compact [L, R, IDX_PAD] → - * [L, R] INT32. INT32 TROWSUM exists in pto-isa but hangs on a2a3 in - * this configuration; the L*R = 128 scalar stores are negligible. + * - Idx uses the same TROWSUM compaction along the IDX_PAD axis to compact + * [L, R, IDX_PAD] → [L, R] INT32. One TLOAD + TROWSUM + TSTORE per expert. */ #ifndef __gm__ @@ -527,18 +527,25 @@ extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ in wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID1); } - // Stage out idx: scalar copy of column 0 from the wide window. - // - // ⚠ The same TROWSUM compaction used above for the FP32 weight channel - // does NOT work reliably for INT32 on a2a3: pto-isa declares INT32 - // TROWSUM support, but with the same [R, IDX_PAD] / Layout::DN setup - // the kernel hangs on hardware. Until that path is stabilized, fall - // back to a scalar copy here. Volume is small (L*R = 128 INT32 stores) - // so the perf cost is negligible. + // Stage out idx: same TROWSUM compaction as the weight channel, on the + // INT32 [R, IDX_PAD] wide window. sum-along-PAD recovers slot [0] because + // columns [1, IDX_PAD) are zero by design. for (int e = 0; e < L; ++e) { - for (int slot = 0; slot < R; ++slot) { - recv_idx_out[e * R + slot] = recv_idx_local[(e * R + slot) * IDX_PAD]; - } + __gm__ int32_t *idx_win = recv_idx_local + e * R * IDX_PAD; + __gm__ int32_t *idx_out = recv_idx_out + e * R; + IWideG idx_win_g(idx_win); + ISumG idx_out_g(idx_out); + TLOAD(idx_wide_tile, idx_win_g); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + pipe_barrier(PIPE_V); + TROWSUM(idx_sum_tile, idx_wide_tile, idx_tmp_tile); + pipe_barrier(PIPE_V); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); + TSTORE(idx_out_g, idx_sum_tile); + set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID1); + wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID1); } pipe_barrier(PIPE_ALL); } diff --git a/examples/workers/l3/ep_dispatch_combine/main.py b/examples/workers/l3/ep_dispatch_combine/main.py index 54f674d25..9f5d5ed25 100644 --- a/examples/workers/l3/ep_dispatch_combine/main.py +++ b/examples/workers/l3/ep_dispatch_combine/main.py @@ -43,7 +43,7 @@ [weight, 0, 0, …, 0]; receiver writes recv_w[loc_e][slot, :W_PAD] and the kernel TROWSUM-compacts to a [L, R] FP32 host output. - Idx uses the same minimum-tile rationale: 1xIDX_PAD=8 INT32 per - route, actual r=t*TOPK+k at slot [0]; compacted via scalar copy to + route, actual r=t*TOPK+k at slot [0]; TROWSUM-compacted to [L, R] INT32 host output. Combine reads it to address routed_y_buf[t, k, :] without a host-built origin_map. - ``recv_count_out`` is [L, 1] INT32 emitted by dispatch's prefix_sum