From 6077feade5f2689cf76a62a17e331956576d4c92 Mon Sep 17 00:00:00 2001 From: UED Date: Sun, 22 Feb 2026 23:59:15 +0000 Subject: [PATCH 01/22] Add int8 quantization for vortex. Key changes: 1. Memory Pool (`vtx_graph_memory_pool.py`): - Removed hardcoded bf16 assertions in `VTXGraphCachePool` to support `torch.int8` allocations. - Added parallel `float32` scale buffers (`k_scale`, `v_scale`) mapped to the paged layout. - Preserved `bfloat16` shadow buffers (`k_bf16`) for auxiliary metadata (e.g., centroids) to ensure the Vortex sparse indexer/TopK remains unaffected and mathematically identical. 2. Quantize-on-Write (`set_kv.py`): - Implemented a custom Triton kernel (`set_kv_buffer_int8_kernel`) that quantizes incoming `bf16` tokens into `int8` on the fly using per-token absmax scaling (`scale = max(abs(x)) / 127.0`). - Wired the new launcher into the cache update flow. 3. Decode Path (`vtx_graph_backend.py` & `paged_decode_int8.py`): - Bypassed FlashInfer for INT8 decoding. - Wired in the custom Triton decode kernel (`paged_decode_int8`) that reads the `int8` pages and `float32` scales directly into SRAM, performing fused inline dequantization without allocating temporary full-cache VRAM buffers. - Seamlessly integrated with existing sparse routing indices (`indptr`, `indices`). 4. Prefill Path (`vtx_graph_backend.py` & `paged_prefill_int8.py`): - Implemented an OOM-safe `bf16` fallback for prefill. - Added a new Triton kernel (`dequant_paged_int8_to_bf16`) to dynamically extract and dequantize *only the accessed pages* for the current batch into a tiny, compacted `bf16` buffer. - Modified the FlashInfer `BatchPrefillWithPagedKVCacheWrapper` planner to map over the compacted subset indices, entirely avoiding full-cache dequantization OOMs. 
--- CLAUDE.md | 88 +++++ examples/verify_algo.py | 27 +- examples/verify_algo.sh | 27 +- examples/verify_algo_quant.sh | 25 ++ vortex_torch/cache/__init__.py | 3 +- vortex_torch/cache/triton_kernels/__init__.py | 11 +- .../cache/triton_kernels/paged_decode_int8.py | 355 ++++++++++++++++++ .../triton_kernels/paged_prefill_int8.py | 90 +++++ vortex_torch/cache/triton_kernels/set_kv.py | 87 +++++ 9 files changed, 696 insertions(+), 17 deletions(-) create mode 100644 CLAUDE.md create mode 100644 examples/verify_algo_quant.sh create mode 100644 vortex_torch/cache/triton_kernels/paged_decode_int8.py create mode 100644 vortex_torch/cache/triton_kernels/paged_prefill_int8.py diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..db54c75 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,88 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Project Overview + +Vortex is a lightweight, modular framework for building custom sparse attention algorithms for LLM inference. It provides a PyTorch-like frontend that abstracts away batching, caching, and paged attention, running on optimized backends (FlashInfer, CUDA Graph) via SGLang integration. + +## Build & Install + +```bash +# Install SGLang dependency (custom fork in third_party/) +cd third_party/sglang && bash install.sh && cd ../../ + +# Install Vortex (editable mode, compiles CUDA extensions for SM_89/SM_90) +pip install -e . +``` + +Requires Python >=3.10, torch>=2.7. CUDA extensions are built from `csrc/` (register.cc, utils_sglang.cu, topk.cu). + +## Running Examples + +```bash +# Single algorithm verification against SGLang +python examples/verify_algo.py --trials 2 --topk-val 30 --vortex-module-name block_sparse_attention + +# Batch test multiple algorithms +bash examples/verify_algo.sh +``` + +## Building Documentation + +```bash +make -C docs html +``` + +Uses Sphinx with myst_parser and furo theme. 
Deployed via GitHub Actions on push to v1 branch. + +## Architecture + +### Core Abstraction: vFlow (`vortex_torch/flow/flow.py`) + +All sparse attention algorithms inherit from `vFlow` and implement three methods: + +- **`forward_indexer(q, o, cache, ctx)`** — Compute sparse page indices from queries. Operates on page-packed tensor view `[S, r, c]`. +- **`forward_cache(cache, loc, ctx)`** — Update/summarize custom cache tensors when a page completes. Operates on batch-major view `[B, r, c]`. +- **`create_cache(page_size, head_dim)`** — Declare custom cache tensor shapes as a dict of `{name: (rows, cols)}`. + +Algorithms are registered via `@register("name")` decorator and instantiated with `build_vflow()`. + +### Operator System (`vortex_torch/indexer/`, `vortex_torch/cache/`) + +Operators (`vOp` subclasses) run in two modes: +- **Profile mode**: Pre-compute output shapes and allocate buffers +- **Execute mode**: Perform actual GPU computation + +Operators are split into two parallel hierarchies: +- **Indexer ops** (`vortex_torch/indexer/`): GeMM, GeMV, topK, reduce (Mean/Max/Min/Sum/L2Norm), softmax, elementwise, transpose, save/load +- **Cache ops** (`vortex_torch/cache/`): GeMM, reduce, elementwise, fill, KV buffer setup + +Both use Triton kernels (in respective `triton_kernels/` subdirectories) for GPU execution. + +### Tensor Format (`vortex_torch/abs/tensor.py`) + +`vTensor` wraps `torch.Tensor` with format metadata (BATCHED, RAGGED, PAGED) to enforce layout consistency across operations. + +### Context System (`vortex_torch/abs/context_base.py`) + +`ContextBase` carries per-step runtime state. 
Specialized as: +- `Indexer.Context`: Page layout, head config, hardware info +- `Cache.Context`: Page size, total pages, model info + +### Concrete Algorithms (`vortex_torch/flow/algorithms.py`) + +- **BlockSparseAttention**: Centroid-based routing (query avg → GeMV with centroids → topK) +- **GQABlockSparseAttention**: Grouped-query variant with softmax + group aggregation +- **GQAQuestSparseAttention**: Query-envelope matching using per-page max/min bounds + +### SGLang Integration + +Custom SGLang fork lives in `third_party/sglang` (git submodule, "graph" branch). CUDA extensions in `csrc/` provide PyBind11 bindings for `sglang_plan_decode`, `sglang_plan_prefill`, and transpose operations. + +## Key Conventions + +- **Tensor shapes**: Query `[B, H_q, D]`, sparse output `[S_sparse, 1, 1]`, cache indexer-view `[S, r, c]`, cache batch-view `[B, r, c]` +- **GeMM semantics**: `GeMM(x, y)` computes `y @ x^T` (note transposition) +- **Standard cache keys**: `"k"` and `"v"` have inner shape `(page_size, head_dim)`; custom caches declared in `create_cache()` +- **Branch**: Main development is on `v1` diff --git a/examples/verify_algo.py b/examples/verify_algo.py index e290a81..f418598 100644 --- a/examples/verify_algo.py +++ b/examples/verify_algo.py @@ -54,7 +54,8 @@ def verify_algos( vortex_module_name: str = "gqa_block_sparse_attention", model_name: str = "Qwen/Qwen3-1.7B", sparse_attention: bool = True, -mem: float = 0.8 +mem: float = 0.8, +kv_cache_dtype: str = "auto", ): llm = sgl.Engine(model_path=model_name, @@ -69,10 +70,11 @@ def verify_algos( vortex_layers_skip=list(range(1)), vortex_module_name=vortex_module_name, vortex_max_seq_lens=12288, - mem_fraction_static=mem + mem_fraction_static=mem, + kv_cache_dtype=kv_cache_dtype, ) - with open("examples/amc23.jsonl", "r", encoding="utf-8") as f: + with open("amc23.jsonl", "r", encoding="utf-8") as f: requests = [json.loads(line) for line in f] requests = requests * trials @@ -110,6 +112,14 @@ def verify_algos( 
"num_tokens": item["meta_info"]["completion_tokens"] } ) + # --- Per-question debug output --- + print(f"[Q{len(results):03d}] score={float(result):.1f} " + f"tokens={item['meta_info']['completion_tokens']} " + f"latency={item['meta_info']['e2e_latency']:.2f}s " + f"gold={golds[0]}") + print(f" question: {data['question'][:120]}...") + print(f" prediction: {predictions[:200]}...") + print() total_accuracy = 0.0 @@ -203,6 +213,14 @@ def parse_args(): default=0.8, help="memory fraction in sglang", ) + + parser.add_argument( + "--kv-cache-dtype", + type=str, + default="auto", + choices=["auto", "fp8_e5m2", "fp8_e4m3", "int8"], + help='KV cache dtype (default: "auto").', + ) return parser.parse_args() if __name__ == "__main__": @@ -215,7 +233,8 @@ def parse_args(): vortex_module_name=args.vortex_module_name, model_name=args.model_name, sparse_attention=not(args.full_attention), - mem=args.mem + mem=args.mem, + kv_cache_dtype=args.kv_cache_dtype, ) print(summary) diff --git a/examples/verify_algo.sh b/examples/verify_algo.sh index 17c2a5e..d80f09a 100644 --- a/examples/verify_algo.sh +++ b/examples/verify_algo.sh @@ -1,17 +1,24 @@ #!/usr/bin/env bash set -e +export CUDA_VISIBLE_DEVICES=1 sparse_algos=( - + "block_sparse_attention" ) -for algo in "${sparse_algos[@]}"; do - echo ">>> Running verify_algo.py with --vortex-module-name ${algo}" - python examples/verify_algo.py \ - --trials 8 \ - --topk-val 30 \ - --vortex-module-name "${algo}" \ - --model-name Qwen/Qwen3-1.7B \ - --mem 0.7 -done +RESULTS_DIR="results" +mkdir -p "${RESULTS_DIR}" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) + for algo in "${sparse_algos[@]}"; do + OUTFILE="${RESULTS_DIR}/${algo}_bf16_${TIMESTAMP}.log" + echo ">>> Running verify_algo.py with --vortex-module-name ${algo} --kv-cache-dtype bf16" + echo ">>> Saving results to ${OUTFILE}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --mem 0.7 ; } \ + 2>&1 | tee 
"${OUTFILE}" + done \ No newline at end of file diff --git a/examples/verify_algo_quant.sh b/examples/verify_algo_quant.sh new file mode 100644 index 0000000..4cf1366 --- /dev/null +++ b/examples/verify_algo_quant.sh @@ -0,0 +1,25 @@ +#!/usr/bin/env bash +set -e +export CUDA_VISIBLE_DEVICES=2 + +sparse_algos=( + "block_sparse_attention" +) + +RESULTS_DIR="results" +mkdir -p "${RESULTS_DIR}" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) + + for algo in "${sparse_algos[@]}"; do + OUTFILE="${RESULTS_DIR}/${algo}_int8_${TIMESTAMP}.log" + echo ">>> Running verify_algo.py with --vortex-module-name ${algo} --kv-cache-dtype int8" + echo ">>> Saving results to ${OUTFILE}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --kv-cache-dtype int8 \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" + done \ No newline at end of file diff --git a/vortex_torch/cache/__init__.py b/vortex_torch/cache/__init__.py index eddfa46..b886559 100644 --- a/vortex_torch/cache/__init__.py +++ b/vortex_torch/cache/__init__.py @@ -29,11 +29,12 @@ from .matmul import GeMM from .elementwise import Relu, Silu, Sigmoid, Abs, Add_Mul from .elementwise_binary import Maximum, Minimum, Multiply, Add -from .triton_kernels import set_kv_buffer_launcher +from .triton_kernels import set_kv_buffer_launcher, set_kv_buffer_int8_launcher __all__ = [ "set_kv_buffer_launcher", + "set_kv_buffer_int8_launcher", "Mean", "Max", "Min", "L2Norm", "GeMM", "Relu", "Silu", "Sigmoid", "Abs", "Add_Mul", diff --git a/vortex_torch/cache/triton_kernels/__init__.py b/vortex_torch/cache/triton_kernels/__init__.py index 6bf6dfc..2d6384f 100644 --- a/vortex_torch/cache/triton_kernels/__init__.py +++ b/vortex_torch/cache/triton_kernels/__init__.py @@ -1,4 +1,11 @@ -from .set_kv import set_kv_buffer_launcher +from .set_kv import set_kv_buffer_launcher, set_kv_buffer_int8_launcher +from .paged_decode_int8 import paged_decode_int8 +from .paged_prefill_int8 import 
dequant_paged_int8_to_bf16 -__all__ = ["set_kv_buffer_launcher"] +__all__ = [ + "set_kv_buffer_launcher", + "set_kv_buffer_int8_launcher", + "paged_decode_int8", + "dequant_paged_int8_to_bf16", +] diff --git a/vortex_torch/cache/triton_kernels/paged_decode_int8.py b/vortex_torch/cache/triton_kernels/paged_decode_int8.py new file mode 100644 index 0000000..480c787 --- /dev/null +++ b/vortex_torch/cache/triton_kernels/paged_decode_int8.py @@ -0,0 +1,355 @@ +""" +Custom Triton paged decode attention kernel for int8 KV cache. + +Loads int8 K/V pages with per-token float32 scales, dequantizes inline in SRAM, +and computes standard multi-head attention with online softmax. + +Adapted from SGLang's decode_attention.py for use with Vortex's paged layout +where each KV head is treated as a separate "batch" entry. +""" + +import torch +import triton +import triton.language as tl + +_MIN_BLOCK_KV = 32 + + +@triton.jit +def tanh(x): + return 2 * tl.sigmoid(2 * x) - 1 + + +@triton.jit +def _fwd_kernel_int8_stage1( + Q, # [batch, num_qo_heads, head_dim] bf16 + K_Buffer, # int8 paged: flat + V_Buffer, # int8 paged: flat + K_Scale_Buffer, # float32: flat (one scale per token slot) + V_Scale_Buffer, # float32: flat + sm_scale, + kv_indptr, # [batch + 1] int32, page-level + kv_indices, # page indices + last_page_len, # [batch] int32, tokens valid in last page + Att_Out, # [batch, num_qo_heads, max_kv_splits, head_dim] + Att_Lse, # [batch, num_qo_heads, max_kv_splits] + num_kv_splits, # [batch] int32 + stride_qbs, + stride_qh, + stride_buf_kbs, # stride per token in K_Buffer (= head_dim) + stride_buf_vbs, # stride per token in V_Buffer (= head_dim) + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + kv_group_num: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_DV: tl.constexpr, + BLOCK_N: tl.constexpr, + MIN_BLOCK_KV: tl.constexpr, + logit_cap: tl.constexpr, + Lk: tl.constexpr, + Lv: tl.constexpr, + PAGE_SIZE: tl.constexpr, +): + """ + Stage 1: For each (batch, head, kv_split), 
compute partial attention output and LSE. + + kv_indptr is page-level. Total tokens for batch i: + (num_pages - 1) * PAGE_SIZE + last_page_len[i] + """ + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + split_kv_id = tl.program_id(2) + + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_dv = tl.arange(0, BLOCK_DV) + mask_d = offs_d < Lk + mask_dv = offs_dv < Lv + + cur_batch_kv_start_idx = tl.load(kv_indptr + cur_batch) + cur_batch_num_pages = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx + cur_last_page_len = tl.load(last_page_len + cur_batch) + # Correct token count accounting for partial last page + cur_batch_seq_len = (cur_batch_num_pages - 1) * PAGE_SIZE + cur_last_page_len + kv_splits = tl.load(num_kv_splits + cur_batch) + + off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d + + kv_len_per_split = ( + tl.cdiv(tl.cdiv(cur_batch_seq_len, kv_splits), MIN_BLOCK_KV) * MIN_BLOCK_KV + ) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) + + e_max = -float("inf") + e_sum = 0.0 + acc = tl.zeros([BLOCK_DV], dtype=tl.float32) + + if split_kv_end > split_kv_start: + q = tl.load(Q + off_q, mask=mask_d, other=0.0).to(tl.float32) + + for start_n in range(split_kv_start, split_kv_end, BLOCK_N): + offs_n = start_n + tl.arange(0, BLOCK_N) + mask_n = offs_n < split_kv_end + + # Convert token offsets to page_id + in-page offset + page_indices_in_seq = offs_n // PAGE_SIZE + in_page_offsets = offs_n % PAGE_SIZE + + # Load page indices from kv_indices (physical page IDs) + page_ids = tl.load( + kv_indices + cur_batch_kv_start_idx + page_indices_in_seq, + mask=mask_n, + other=0, + ) + + # Flat token location: physical_page * PAGE_SIZE + in_page_offset + kv_loc = page_ids * PAGE_SIZE + in_page_offsets + + # Load int8 K and dequantize + offs_buf_k = kv_loc[:, None] * stride_buf_kbs + offs_d[None, :] + k_int8 = tl.load( + K_Buffer + offs_buf_k, + mask=mask_n[:, None] & 
mask_d[None, :], + other=0, + ).to(tl.float32) + + k_scale = tl.load( + K_Scale_Buffer + kv_loc, + mask=mask_n, + other=1.0, + ) + k = k_int8 * k_scale[:, None] + + # Compute QK + qk = tl.sum(q[None, :] * k, 1) + qk *= sm_scale + + if logit_cap > 0: + qk = logit_cap * tanh(qk / logit_cap) + + qk = tl.where(mask_n, qk, float("-inf")) + + # Load int8 V and dequantize + offs_buf_v = kv_loc[:, None] * stride_buf_vbs + offs_dv[None, :] + v_int8 = tl.load( + V_Buffer + offs_buf_v, + mask=mask_n[:, None] & mask_dv[None, :], + other=0, + ).to(tl.float32) + + v_scale = tl.load( + V_Scale_Buffer + kv_loc, + mask=mask_n, + other=1.0, + ) + v = v_int8 * v_scale[:, None] + + # Online softmax accumulation + n_e_max = tl.maximum(tl.max(qk, 0), e_max) + re_scale = tl.exp(e_max - n_e_max) + p = tl.exp(qk - n_e_max) + acc *= re_scale + acc += tl.sum(p[:, None] * v, 0) + + e_sum = e_sum * re_scale + tl.sum(p, 0) + e_max = n_e_max + + offs_mid_o = ( + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + + split_kv_id * stride_mid_os + + offs_dv + ) + + tl.store( + Att_Out + offs_mid_o, + acc / e_sum, + mask=mask_dv, + ) + + offs_mid_o_1 = ( + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + + split_kv_id * stride_mid_os + ) // Lv + + tl.store( + Att_Lse + offs_mid_o_1, + e_max + tl.log(e_sum), + ) + + +@triton.jit +def _fwd_kernel_int8_stage2( + Mid_O, + Mid_O_1, + O, + kv_indptr, + last_page_len, + num_kv_splits, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_obs, + stride_oh, + MAX_KV_SPLITS: tl.constexpr, + MIN_BLOCK_KV: tl.constexpr, + BLOCK_DV: tl.constexpr, + Lv: tl.constexpr, + PAGE_SIZE: tl.constexpr, +): + """Stage 2: Reduce split outputs via log-sum-exp merge.""" + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + cur_batch_num_pages = tl.load(kv_indptr + cur_batch + 1) - tl.load(kv_indptr + cur_batch) + cur_last_page_len = tl.load(last_page_len + cur_batch) + cur_batch_seq_len = (cur_batch_num_pages - 1) * PAGE_SIZE + cur_last_page_len 
+ kv_splits = tl.load(num_kv_splits + cur_batch) + + offs_d = tl.arange(0, BLOCK_DV) + mask_d = offs_d < Lv + + e_sum = 0.0 + e_max = -float("inf") + acc = tl.zeros([BLOCK_DV], dtype=tl.float32) + + offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d + offs_logic = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh) // Lv + kv_len_per_split = ( + tl.cdiv(tl.cdiv(cur_batch_seq_len, kv_splits), MIN_BLOCK_KV) * MIN_BLOCK_KV + ) + + for split_kv_id in range(0, MAX_KV_SPLITS): + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) + + if split_kv_end > split_kv_start: + tv = tl.load( + Mid_O + offs_v + split_kv_id * stride_mid_os, mask=mask_d, other=0.0 + ) + tlogic = tl.load(Mid_O_1 + offs_logic + split_kv_id * stride_mid_os // Lv) + n_e_max = tl.maximum(tlogic, e_max) + + old_scale = tl.exp(e_max - n_e_max) + acc *= old_scale + exp_logic = tl.exp(tlogic - n_e_max) + acc += exp_logic * tv + + e_sum = e_sum * old_scale + exp_logic + e_max = n_e_max + + tl.store( + O + cur_batch * stride_obs + cur_head * stride_oh + offs_d, + acc / e_sum, + mask=mask_d, + ) + + +def paged_decode_int8( + q: torch.Tensor, # [batch, num_qo_heads, head_dim] bf16 + k_buffer: torch.Tensor, # int8 paged K cache + v_buffer: torch.Tensor, # int8 paged V cache + k_scale_buffer: torch.Tensor, # float32 scale for K + v_scale_buffer: torch.Tensor, # float32 scale for V + o: torch.Tensor, # [batch, num_qo_heads, head_dim] bf16 output + kv_indptr: torch.Tensor, # [batch + 1] int32, page-level + kv_indices: torch.Tensor, # page indices + last_page_len: torch.Tensor, # [batch] int32 + num_kv_splits: torch.Tensor, # [batch] int32 + max_kv_splits: int, + sm_scale: float, + page_size: int, + logit_cap: float = 0.0, +): + """ + Paged decode attention with int8 KV cache and inline dequantization. + + kv_indptr is page-level. last_page_len specifies valid tokens in the last page + for each batch entry. 
Total tokens = (num_pages - 1) * page_size + last_page_len. + """ + batch = q.shape[0] + head_num = q.shape[1] + Lk = q.shape[2] + Lv = Lk + + BLOCK_DMODEL = triton.next_power_of_2(Lk) + BLOCK_DV = triton.next_power_of_2(Lv) + BLOCK_N = 64 + MAX_KV_SPLITS = max_kv_splits + + kv_group_num = head_num + + num_warps = 4 if kv_group_num == 1 else 2 + + # Intermediate buffers for split reduction + att_out = torch.empty( + (batch, head_num, MAX_KV_SPLITS, Lv), + dtype=torch.float32, + device=q.device, + ) + att_lse = torch.empty( + (batch, head_num, MAX_KV_SPLITS), + dtype=torch.float32, + device=q.device, + ) + + stride_buf_kbs = k_buffer.shape[-1] + stride_buf_vbs = v_buffer.shape[-1] + + grid_stage1 = (batch, head_num, MAX_KV_SPLITS) + _fwd_kernel_int8_stage1[grid_stage1]( + q, + k_buffer, + v_buffer, + k_scale_buffer, + v_scale_buffer, + sm_scale, + kv_indptr, + kv_indices, + last_page_len, + att_out, + att_lse, + num_kv_splits, + q.stride(0), + q.stride(1), + stride_buf_kbs, + stride_buf_vbs, + att_out.stride(0), + att_out.stride(1), + att_out.stride(2), + kv_group_num=kv_group_num, + BLOCK_DMODEL=BLOCK_DMODEL, + BLOCK_DV=BLOCK_DV, + BLOCK_N=BLOCK_N, + MIN_BLOCK_KV=_MIN_BLOCK_KV, + logit_cap=logit_cap, + num_warps=num_warps, + num_stages=2, + Lk=Lk, + Lv=Lv, + PAGE_SIZE=page_size, + ) + + grid_stage2 = (batch, head_num) + _fwd_kernel_int8_stage2[grid_stage2]( + att_out, + att_lse, + o, + kv_indptr, + last_page_len, + num_kv_splits, + att_out.stride(0), + att_out.stride(1), + att_out.stride(2), + o.stride(0), + o.stride(1), + MAX_KV_SPLITS=MAX_KV_SPLITS, + MIN_BLOCK_KV=_MIN_BLOCK_KV, + BLOCK_DV=BLOCK_DV, + Lv=Lv, + PAGE_SIZE=page_size, + num_warps=4, + num_stages=2, + ) diff --git a/vortex_torch/cache/triton_kernels/paged_prefill_int8.py b/vortex_torch/cache/triton_kernels/paged_prefill_int8.py new file mode 100644 index 0000000..75c3857 --- /dev/null +++ b/vortex_torch/cache/triton_kernels/paged_prefill_int8.py @@ -0,0 +1,90 @@ +""" +OOM-safe bf16 fallback for int8 
KV-cache prefill. + +Instead of implementing full 2D-tiled Triton prefill with int8 dequantization, +this module dequantizes only the accessed KV pages into a compact temporary +bf16 buffer and remaps indices so FlashInfer can operate on the compact buffer. + +This avoids dequantizing the entire global cache buffer. +""" + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _dequant_pages_kernel( + src_int8, # int8 paged buffer [num_pages, page_size, head_dim] flat + src_scale, # float32 scale buffer [num_pages, page_size, 1] flat + dst_bf16, # bf16 compact buffer [num_accessed_pages, page_size, head_dim] flat + page_indices, # int32 [num_accessed_pages] — which global pages to dequant + NUM_PAGES: tl.constexpr, + PAGE_SIZE: tl.constexpr, + HEAD_DIM: tl.constexpr, + BLOCK_DIM: tl.constexpr, +): + """Dequantize selected int8 pages to bf16 compact buffer.""" + page_idx = tl.program_id(0) # index into page_indices + token_idx = tl.program_id(1) # token within page [0, PAGE_SIZE) + + if page_idx >= NUM_PAGES: + return + + global_page_id = tl.load(page_indices + page_idx) + dims = tl.arange(0, BLOCK_DIM) + mask_dim = dims < HEAD_DIM + + # Source: global_page_id * PAGE_SIZE * HEAD_DIM + token_idx * HEAD_DIM + dims + src_offset = (global_page_id * PAGE_SIZE + token_idx) * HEAD_DIM + dims + val_int8 = tl.load(src_int8 + src_offset, mask=mask_dim, other=0).to(tl.float32) + + # Scale: global_page_id * PAGE_SIZE + token_idx + scale_offset = global_page_id * PAGE_SIZE + token_idx + scale = tl.load(src_scale + scale_offset) + + val_bf16 = (val_int8 * scale).to(tl.bfloat16) + + # Destination: page_idx * PAGE_SIZE * HEAD_DIM + token_idx * HEAD_DIM + dims + dst_offset = (page_idx * PAGE_SIZE + token_idx) * HEAD_DIM + dims + tl.store(dst_bf16 + dst_offset, val_bf16, mask=mask_dim) + + +def dequant_paged_int8_to_bf16( + src_int8: torch.Tensor, # int8 [num_pages, page_size, head_dim] + src_scale: torch.Tensor, # float32 [num_pages, page_size, 1] + 
page_indices: torch.Tensor, # int32 [num_accessed_pages] + page_size: int, + head_dim: int, +) -> torch.Tensor: + """ + Dequantize only the accessed pages from int8 cache to a compact bf16 buffer. + + Returns: + bf16 tensor of shape [num_accessed_pages, page_size, head_dim] + """ + num_accessed_pages = page_indices.shape[0] + if num_accessed_pages == 0: + return torch.empty((0, page_size, head_dim), dtype=torch.bfloat16, device=src_int8.device) + + dst_bf16 = torch.empty( + (num_accessed_pages, page_size, head_dim), + dtype=torch.bfloat16, + device=src_int8.device, + ) + + BLOCK_DIM = triton.next_power_of_2(head_dim) + + grid = (num_accessed_pages, page_size) + _dequant_pages_kernel[grid]( + src_int8, + src_scale, + dst_bf16, + page_indices, + NUM_PAGES=num_accessed_pages, + PAGE_SIZE=page_size, + HEAD_DIM=head_dim, + BLOCK_DIM=BLOCK_DIM, + ) + + return dst_bf16 diff --git a/vortex_torch/cache/triton_kernels/set_kv.py b/vortex_torch/cache/triton_kernels/set_kv.py index cfa3cab..4318428 100644 --- a/vortex_torch/cache/triton_kernels/set_kv.py +++ b/vortex_torch/cache/triton_kernels/set_kv.py @@ -36,6 +36,93 @@ def set_kv_buffer_kernel( tl.store(dst_v_ptr, src_v) +@triton.jit +def set_kv_buffer_int8_kernel( + k_cache, # int8 paged K cache + v_cache, # int8 paged V cache + k_scale_cache, # float32 per-token K scale [num_pages, page_size, 1] + v_scale_cache, # float32 per-token V scale [num_pages, page_size, 1] + new_k, # bf16 input K [NNZ, NUM_KV_HEAD, HEAD_DIM] + new_v, # bf16 input V [NNZ, NUM_KV_HEAD, HEAD_DIM] + loc, # int64 token positions + NUM_KV_HEAD: tl.constexpr, + NNZ: tl.constexpr, + HEAD_DIM: tl.constexpr, + PAGE_SIZE: tl.constexpr +): + """Quantize bf16 K/V to int8 with per-token absmax scaling and write to paged buffers.""" + token_id = tl.program_id(0) + if token_id >= NNZ: + return + head_id = tl.program_id(1) + dim = tl.arange(0, HEAD_DIM) + + # Load bf16 source values + src_ptr = token_id * NUM_KV_HEAD * HEAD_DIM + head_id * HEAD_DIM + dim + src_k = 
tl.load(new_k + src_ptr).to(tl.float32) + src_v = tl.load(new_v + src_ptr).to(tl.float32) + + # Compute per-token absmax scale: scale = absmax / 127 + absmax_k = tl.max(tl.abs(src_k), axis=0) + absmax_v = tl.max(tl.abs(src_v), axis=0) + # Avoid division by zero + scale_k = absmax_k / 127.0 + 1e-10 + scale_v = absmax_v / 127.0 + 1e-10 + + # Quantize to int8: round(x / scale), clamp to [-128, 127] + q_k = tl.extra.cuda.libdevice.rint(src_k / scale_k) + q_k = tl.minimum(tl.maximum(q_k, -128.0), 127.0).to(tl.int8) + q_v = tl.extra.cuda.libdevice.rint(src_v / scale_v) + q_v = tl.minimum(tl.maximum(q_v, -128.0), 127.0).to(tl.int8) + + # Compute paged destination offset (same layout as bf16 kernel) + token_position = tl.load(loc + token_id) + page_id = token_position // PAGE_SIZE + in_page_offset = token_position % PAGE_SIZE + position_trans = page_id * (PAGE_SIZE * NUM_KV_HEAD) + head_id * PAGE_SIZE + in_page_offset + + # Write int8 values + dst_k_ptr = k_cache + position_trans * HEAD_DIM + dim + dst_v_ptr = v_cache + position_trans * HEAD_DIM + dim + tl.store(dst_k_ptr, q_k) + tl.store(dst_v_ptr, q_v) + + # Write per-(token, head) scales: one fp32 scale per token per head + # Layout: (page_id * NUM_KV_HEAD + head_id) * PAGE_SIZE + in_page_offset — matches the decode kernel's head-expanded kv_loc + scale_offset = (page_id * NUM_KV_HEAD + head_id) * PAGE_SIZE + in_page_offset + tl.store(k_scale_cache + scale_offset, scale_k) + tl.store(v_scale_cache + scale_offset, scale_v) + + +def set_kv_buffer_int8_launcher( + k_cache: torch.Tensor, + v_cache: torch.Tensor, + k_scale_cache: torch.Tensor, + v_scale_cache: torch.Tensor, + new_k: torch.Tensor, + new_v: torch.Tensor, + loc: torch.LongTensor, + page_size: int +): + NNZ = loc.shape[0] + NUM_KV_HEAD = new_k.shape[1] + HEAD_DIM = new_k.shape[2] + + set_kv_buffer_int8_kernel[(NNZ, NUM_KV_HEAD)]( + k_cache, + v_cache, + k_scale_cache, + v_scale_cache, + new_k, + new_v, + loc, + NUM_KV_HEAD, + NNZ, + HEAD_DIM, + page_size + ) + + def set_kv_buffer_launcher( k_cache:
torch.Tensor, v_cache: torch.Tensor, From 1f52772d451ce0e8663784c1052475b4f7b2618f Mon Sep 17 00:00:00 2001 From: UED Date: Mon, 23 Feb 2026 03:19:56 +0000 Subject: [PATCH 02/22] 1. Add support for pro 6000. 2. Correction for vortex --- CLAUDE.md | 53 +++++++++++ setup.py | 5 +- third_party/sglang | 2 +- vortex_torch/cache/__init__.py | 3 +- vortex_torch/cache/triton_kernels/__init__.py | 3 +- .../cache/triton_kernels/paged_decode_int8.py | 42 +++++---- .../triton_kernels/paged_prefill_int8.py | 94 +++++++++++++++++-- vortex_torch/cache/triton_kernels/set_kv.py | 10 +- 8 files changed, 178 insertions(+), 34 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index db54c75..1593d61 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -86,3 +86,56 @@ Custom SGLang fork lives in `third_party/sglang` (git submodule, "graph" branch) - **GeMM semantics**: `GeMM(x, y)` computes `y @ x^T` (note transposition) - **Standard cache keys**: `"k"` and `"v"` have inner shape `(page_size, head_dim)`; custom caches declared in `create_cache()` - **Branch**: Main development is on `v1` + +## Workflow Orchestration + +### 1. Plan Node Default +- Enter plan mode for ANY non-trivial task (3+ steps or architectural decisions) +- If something goes sideways, STOP and re-plan immediately - don't keep pushing +- Use plan mode for verification steps, not just building +- Write detailed specs upfront to reduce ambiguity + +### 2. Subagent Strategy +- Use subagents liberally to keep main context window clean +- Offload research, exploration, and parallel analysis to subagents +- For complex problems, throw more compute at it via subagents +- One tack per subagent for focused execution + +### 3. Self-Improvement Loop +- After ANY correction from the user: update `tasks/lessons.md` with the pattern +- Write rules for yourself that prevent the same mistake +- Ruthlessly iterate on these lessons until mistake rate drops +- Review lessons at session start for relevant project + +### 4. 
Verification Before Done +- Never mark a task complete without proving it works +- Diff behavior between main and your changes when relevant +- Ask yourself: "Would a staff engineer approve this?" +- Run tests, check logs, demonstrate correctness + +### 5. Demand Elegance (Balanced) +- For non-trivial changes: pause and ask "is there a more elegant way?" +- If a fix feels hacky: "Knowing everything I know now, implement the elegant solution" +- Skip this for simple, obvious fixes - don't over-engineer +- Challenge your own work before presenting it + +### 6. Autonomous Bug Fixing +- When given a bug report: just fix it. Don't ask for hand-holding +- Point at logs, errors, failing tests - then resolve them +- Zero context switching required from the user +- Go fix failing CI tests without being told how + +## Task Management + +1. **Plan First**: Write plan to `tasks/todo.md` with checkable items +2. **Verify Plan**: Check in before starting implementation +3. **Track Progress**: Mark items complete as you go +4. **Explain Changes**: High-level summary at each step +5. **Document Results**: Add review section to `tasks/todo.md` +6. **Capture Lessons**: Update `tasks/lessons.md` after corrections + +## Core Principles + +- **Simplicity First**: Make every change as simple as possible. Impact minimal code. +- **No Laziness**: Find root causes. No temporary fixes. Senior developer standards. +- **Minimal Impact**: Changes should only touch what's necessary. Avoid introducing bugs. 
\ No newline at end of file diff --git a/setup.py b/setup.py index e272326..6efeebe 100644 --- a/setup.py +++ b/setup.py @@ -23,8 +23,11 @@ 'cxx': ['-O3'], 'nvcc': [ '-O3', + '-gencode=arch=compute_86,code=sm_86', '-gencode=arch=compute_89,code=sm_89', - '-gencode=arch=compute_90,code=sm_90' + '-gencode=arch=compute_90,code=sm_90', + '-gencode=arch=compute_100a,code=sm_100a', + '-gencode=arch=compute_120,code=sm_120' ], }, ), diff --git a/third_party/sglang b/third_party/sglang index e383c0f..9672e9a 160000 --- a/third_party/sglang +++ b/third_party/sglang @@ -1 +1 @@ -Subproject commit e383c0fdd551f74f24d247e8a7cc8013861949ad +Subproject commit 9672e9a7f90bcb782ccdfb2ee123ede7f2ef5d17 diff --git a/vortex_torch/cache/__init__.py b/vortex_torch/cache/__init__.py index b886559..b32d8bc 100644 --- a/vortex_torch/cache/__init__.py +++ b/vortex_torch/cache/__init__.py @@ -29,12 +29,13 @@ from .matmul import GeMM from .elementwise import Relu, Silu, Sigmoid, Abs, Add_Mul from .elementwise_binary import Maximum, Minimum, Multiply, Add -from .triton_kernels import set_kv_buffer_launcher, set_kv_buffer_int8_launcher +from .triton_kernels import set_kv_buffer_launcher, set_kv_buffer_int8_launcher, dequant_paged_int8_to_bf16_inplace __all__ = [ "set_kv_buffer_launcher", "set_kv_buffer_int8_launcher", + "dequant_paged_int8_to_bf16_inplace", "Mean", "Max", "Min", "L2Norm", "GeMM", "Relu", "Silu", "Sigmoid", "Abs", "Add_Mul", diff --git a/vortex_torch/cache/triton_kernels/__init__.py b/vortex_torch/cache/triton_kernels/__init__.py index 2d6384f..a18067e 100644 --- a/vortex_torch/cache/triton_kernels/__init__.py +++ b/vortex_torch/cache/triton_kernels/__init__.py @@ -1,11 +1,12 @@ from .set_kv import set_kv_buffer_launcher, set_kv_buffer_int8_launcher from .paged_decode_int8 import paged_decode_int8 -from .paged_prefill_int8 import dequant_paged_int8_to_bf16 +from .paged_prefill_int8 import dequant_paged_int8_to_bf16, dequant_paged_int8_to_bf16_inplace __all__ = [ 
"set_kv_buffer_launcher", "set_kv_buffer_int8_launcher", "paged_decode_int8", "dequant_paged_int8_to_bf16", + "dequant_paged_int8_to_bf16_inplace", ] diff --git a/vortex_torch/cache/triton_kernels/paged_decode_int8.py b/vortex_torch/cache/triton_kernels/paged_decode_int8.py index 480c787..4f33cd4 100644 --- a/vortex_torch/cache/triton_kernels/paged_decode_int8.py +++ b/vortex_torch/cache/triton_kernels/paged_decode_int8.py @@ -25,8 +25,8 @@ def _fwd_kernel_int8_stage1( Q, # [batch, num_qo_heads, head_dim] bf16 K_Buffer, # int8 paged: flat V_Buffer, # int8 paged: flat - K_Scale_Buffer, # float32: flat (one scale per token slot) - V_Scale_Buffer, # float32: flat + K_Scale_Buffer, # fp16: flat (one scale per token slot) + V_Scale_Buffer, # fp16: flat sm_scale, kv_indptr, # [batch + 1] int32, page-level kv_indices, # page indices @@ -118,7 +118,7 @@ def _fwd_kernel_int8_stage1( K_Scale_Buffer + kv_loc, mask=mask_n, other=1.0, - ) + ).to(tl.float32) k = k_int8 * k_scale[:, None] # Compute QK @@ -142,7 +142,7 @@ def _fwd_kernel_int8_stage1( V_Scale_Buffer + kv_loc, mask=mask_n, other=1.0, - ) + ).to(tl.float32) v = v_int8 * v_scale[:, None] # Online softmax accumulation @@ -251,8 +251,8 @@ def paged_decode_int8( q: torch.Tensor, # [batch, num_qo_heads, head_dim] bf16 k_buffer: torch.Tensor, # int8 paged K cache v_buffer: torch.Tensor, # int8 paged V cache - k_scale_buffer: torch.Tensor, # float32 scale for K - v_scale_buffer: torch.Tensor, # float32 scale for V + k_scale_buffer: torch.Tensor, # fp16 scale for K + v_scale_buffer: torch.Tensor, # fp16 scale for V o: torch.Tensor, # [batch, num_qo_heads, head_dim] bf16 output kv_indptr: torch.Tensor, # [batch + 1] int32, page-level kv_indices: torch.Tensor, # page indices @@ -262,6 +262,8 @@ def paged_decode_int8( sm_scale: float, page_size: int, logit_cap: float = 0.0, + att_out: torch.Tensor = None, # optional pre-allocated [batch, head_num, max_kv_splits, Lv] + att_lse: torch.Tensor = None, # optional pre-allocated 
[batch, head_num, max_kv_splits] ): """ Paged decode attention with int8 KV cache and inline dequantization. @@ -283,17 +285,23 @@ def paged_decode_int8( num_warps = 4 if kv_group_num == 1 else 2 - # Intermediate buffers for split reduction - att_out = torch.empty( - (batch, head_num, MAX_KV_SPLITS, Lv), - dtype=torch.float32, - device=q.device, - ) - att_lse = torch.empty( - (batch, head_num, MAX_KV_SPLITS), - dtype=torch.float32, - device=q.device, - ) + # Use pre-allocated buffers if provided, otherwise allocate + if att_out is None: + att_out = torch.empty( + (batch, head_num, MAX_KV_SPLITS, Lv), + dtype=torch.float32, + device=q.device, + ) + else: + att_out = att_out[:batch] + if att_lse is None: + att_lse = torch.empty( + (batch, head_num, MAX_KV_SPLITS), + dtype=torch.float32, + device=q.device, + ) + else: + att_lse = att_lse[:batch] stride_buf_kbs = k_buffer.shape[-1] stride_buf_vbs = v_buffer.shape[-1] diff --git a/vortex_torch/cache/triton_kernels/paged_prefill_int8.py b/vortex_torch/cache/triton_kernels/paged_prefill_int8.py index 75c3857..8927983 100644 --- a/vortex_torch/cache/triton_kernels/paged_prefill_int8.py +++ b/vortex_torch/cache/triton_kernels/paged_prefill_int8.py @@ -16,7 +16,7 @@ @triton.jit def _dequant_pages_kernel( src_int8, # int8 paged buffer [num_pages, page_size, head_dim] flat - src_scale, # float32 scale buffer [num_pages, page_size, 1] flat + src_scale, # fp16 scale buffer [num_pages, page_size, 1] flat dst_bf16, # bf16 compact buffer [num_accessed_pages, page_size, head_dim] flat page_indices, # int32 [num_accessed_pages] — which global pages to dequant NUM_PAGES: tl.constexpr, @@ -41,7 +41,7 @@ def _dequant_pages_kernel( # Scale: global_page_id * PAGE_SIZE + token_idx scale_offset = global_page_id * PAGE_SIZE + token_idx - scale = tl.load(src_scale + scale_offset) + scale = tl.load(src_scale + scale_offset).to(tl.float32) val_bf16 = (val_int8 * scale).to(tl.bfloat16) @@ -52,26 +52,35 @@ def _dequant_pages_kernel( def 
dequant_paged_int8_to_bf16( src_int8: torch.Tensor, # int8 [num_pages, page_size, head_dim] - src_scale: torch.Tensor, # float32 [num_pages, page_size, 1] + src_scale: torch.Tensor, # fp16 [num_pages, page_size, 1] page_indices: torch.Tensor, # int32 [num_accessed_pages] page_size: int, head_dim: int, + out: torch.Tensor = None, # optional pre-allocated bf16 [>=num_accessed_pages, page_size, head_dim] ) -> torch.Tensor: """ Dequantize only the accessed pages from int8 cache to a compact bf16 buffer. + If `out` is provided, writes into it (must have room for num_accessed_pages). + Otherwise allocates a new buffer. + Returns: bf16 tensor of shape [num_accessed_pages, page_size, head_dim] """ num_accessed_pages = page_indices.shape[0] if num_accessed_pages == 0: + if out is not None: + return out[:0] return torch.empty((0, page_size, head_dim), dtype=torch.bfloat16, device=src_int8.device) - dst_bf16 = torch.empty( - (num_accessed_pages, page_size, head_dim), - dtype=torch.bfloat16, - device=src_int8.device, - ) + if out is not None: + dst_bf16 = out[:num_accessed_pages] + else: + dst_bf16 = torch.empty( + (num_accessed_pages, page_size, head_dim), + dtype=torch.bfloat16, + device=src_int8.device, + ) BLOCK_DIM = triton.next_power_of_2(head_dim) @@ -88,3 +97,72 @@ def dequant_paged_int8_to_bf16( ) return dst_bf16 + + +@triton.jit +def _dequant_pages_inplace_kernel( + src_int8, # int8 paged buffer flat + src_scale, # scale buffer flat (one scale per token slot) + dst_bf16, # bf16 destination buffer (same page layout as src) + page_indices, # int32 [num_pages] — which global pages to dequant + NUM_PAGES: tl.constexpr, + PAGE_SIZE: tl.constexpr, + HEAD_DIM: tl.constexpr, + BLOCK_DIM: tl.constexpr, +): + """Dequantize selected int8 pages to bf16, writing to the SAME page positions in dst.""" + page_idx = tl.program_id(0) # index into page_indices + token_idx = tl.program_id(1) # token within page [0, PAGE_SIZE) + + if page_idx >= NUM_PAGES: + return + + global_page_id = 
tl.load(page_indices + page_idx) + dims = tl.arange(0, BLOCK_DIM) + mask_dim = dims < HEAD_DIM + + # Source and destination use the SAME offset (in-place layout) + offset = (global_page_id * PAGE_SIZE + token_idx) * HEAD_DIM + dims + val_int8 = tl.load(src_int8 + offset, mask=mask_dim, other=0).to(tl.float32) + + scale_offset = global_page_id * PAGE_SIZE + token_idx + scale = tl.load(src_scale + scale_offset).to(tl.float32) + + val_bf16 = (val_int8 * scale).to(tl.bfloat16) + + # Write to the SAME page position in dst (not compacted) + tl.store(dst_bf16 + offset, val_bf16, mask=mask_dim) + + +def dequant_paged_int8_to_bf16_inplace( + src_int8: torch.Tensor, # int8 paged cache (flat) + src_scale: torch.Tensor, # fp16 scale buffer (flat) + dst_bf16: torch.Tensor, # bf16 destination (same shape as src_int8) + page_indices: torch.Tensor, # int32 [num_pages] — which pages to dequant + page_size: int, + head_dim: int, +) -> None: + """ + Dequantize selected pages from int8 cache to bf16 IN-PLACE. + + Unlike dequant_paged_int8_to_bf16 (which compacts into a dense buffer), + this writes to the SAME page positions in dst_bf16, preserving the paged layout. + Used to populate the bf16 working buffer for forward_cache (centroid computation). 
+ """ + num_pages = page_indices.shape[0] + if num_pages == 0: + return + + BLOCK_DIM = triton.next_power_of_2(head_dim) + + grid = (num_pages, page_size) + _dequant_pages_inplace_kernel[grid]( + src_int8, + src_scale, + dst_bf16, + page_indices, + NUM_PAGES=num_pages, + PAGE_SIZE=page_size, + HEAD_DIM=head_dim, + BLOCK_DIM=BLOCK_DIM, + ) diff --git a/vortex_torch/cache/triton_kernels/set_kv.py b/vortex_torch/cache/triton_kernels/set_kv.py index 4318428..2a2c785 100644 --- a/vortex_torch/cache/triton_kernels/set_kv.py +++ b/vortex_torch/cache/triton_kernels/set_kv.py @@ -40,8 +40,8 @@ def set_kv_buffer_kernel( def set_kv_buffer_int8_kernel( k_cache, # int8 paged K cache v_cache, # int8 paged V cache - k_scale_cache, # float32 per-token K scale [num_pages, page_size, 1] - v_scale_cache, # float32 per-token V scale [num_pages, page_size, 1] + k_scale_cache, # fp16 per-token K scale [num_pages, page_size, 1] + v_scale_cache, # fp16 per-token V scale [num_pages, page_size, 1] new_k, # bf16 input K [NNZ, NUM_KV_HEAD, HEAD_DIM] new_v, # bf16 input V [NNZ, NUM_KV_HEAD, HEAD_DIM] loc, # int64 token positions @@ -87,11 +87,11 @@ def set_kv_buffer_int8_kernel( tl.store(dst_k_ptr, q_k) tl.store(dst_v_ptr, q_v) - # Write per-token scales: shape [num_pages, page_size, 1] + # Write per-token scales (fp16): shape [num_pages, page_size, 1] # Layout: page_id * PAGE_SIZE + in_page_offset (flat per-head, one scale per token per head) scale_offset = (page_id * NUM_KV_HEAD + head_id) * PAGE_SIZE + in_page_offset - tl.store(k_scale_cache + scale_offset, scale_k) - tl.store(v_scale_cache + scale_offset, scale_v) + tl.store(k_scale_cache + scale_offset, scale_k.to(tl.float16)) + tl.store(v_scale_cache + scale_offset, scale_v.to(tl.float16)) def set_kv_buffer_int8_launcher( From 584f23355412a4464215ccb15d06758ad1b2762c Mon Sep 17 00:00:00 2001 From: UED Date: Mon, 23 Feb 2026 07:07:28 +0000 Subject: [PATCH 03/22] 1. Correction on int8 (maximize memory occupation) 2. 
Implement fp8 quantization. --- CLAUDE.md | 141 -------- examples/verify_algo.py | 14 +- examples/verify_algo_fp8.sh | 25 ++ ...rify_algo_quant.sh => verify_algo_int8.sh} | 0 third_party/sglang | 2 +- vortex_torch/cache/__init__.py | 3 +- vortex_torch/cache/context.py | 20 +- vortex_torch/cache/reduce.py | 6 +- vortex_torch/cache/triton_kernels/__init__.py | 3 +- .../cache/triton_kernels/reduce_impl.py | 328 +++++++++--------- vortex_torch/cache/triton_kernels/set_kv.py | 97 +++++- 11 files changed, 319 insertions(+), 320 deletions(-) delete mode 100644 CLAUDE.md create mode 100755 examples/verify_algo_fp8.sh rename examples/{verify_algo_quant.sh => verify_algo_int8.sh} (100%) diff --git a/CLAUDE.md b/CLAUDE.md deleted file mode 100644 index 1593d61..0000000 --- a/CLAUDE.md +++ /dev/null @@ -1,141 +0,0 @@ -# CLAUDE.md - -This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. - -## Project Overview - -Vortex is a lightweight, modular framework for building custom sparse attention algorithms for LLM inference. It provides a PyTorch-like frontend that abstracts away batching, caching, and paged attention, running on optimized backends (FlashInfer, CUDA Graph) via SGLang integration. - -## Build & Install - -```bash -# Install SGLang dependency (custom fork in third_party/) -cd third_party/sglang && bash install.sh && cd ../../ - -# Install Vortex (editable mode, compiles CUDA extensions for SM_89/SM_90) -pip install -e . -``` - -Requires Python >=3.10, torch>=2.7. CUDA extensions are built from `csrc/` (register.cc, utils_sglang.cu, topk.cu). - -## Running Examples - -```bash -# Single algorithm verification against SGLang -python examples/verify_algo.py --trials 2 --topk-val 30 --vortex-module-name block_sparse_attention - -# Batch test multiple algorithms -bash examples/verify_algo.sh -``` - -## Building Documentation - -```bash -make -C docs html -``` - -Uses Sphinx with myst_parser and furo theme. 
Deployed via GitHub Actions on push to v1 branch. - -## Architecture - -### Core Abstraction: vFlow (`vortex_torch/flow/flow.py`) - -All sparse attention algorithms inherit from `vFlow` and implement three methods: - -- **`forward_indexer(q, o, cache, ctx)`** — Compute sparse page indices from queries. Operates on page-packed tensor view `[S, r, c]`. -- **`forward_cache(cache, loc, ctx)`** — Update/summarize custom cache tensors when a page completes. Operates on batch-major view `[B, r, c]`. -- **`create_cache(page_size, head_dim)`** — Declare custom cache tensor shapes as a dict of `{name: (rows, cols)}`. - -Algorithms are registered via `@register("name")` decorator and instantiated with `build_vflow()`. - -### Operator System (`vortex_torch/indexer/`, `vortex_torch/cache/`) - -Operators (`vOp` subclasses) run in two modes: -- **Profile mode**: Pre-compute output shapes and allocate buffers -- **Execute mode**: Perform actual GPU computation - -Operators are split into two parallel hierarchies: -- **Indexer ops** (`vortex_torch/indexer/`): GeMM, GeMV, topK, reduce (Mean/Max/Min/Sum/L2Norm), softmax, elementwise, transpose, save/load -- **Cache ops** (`vortex_torch/cache/`): GeMM, reduce, elementwise, fill, KV buffer setup - -Both use Triton kernels (in respective `triton_kernels/` subdirectories) for GPU execution. - -### Tensor Format (`vortex_torch/abs/tensor.py`) - -`vTensor` wraps `torch.Tensor` with format metadata (BATCHED, RAGGED, PAGED) to enforce layout consistency across operations. - -### Context System (`vortex_torch/abs/context_base.py`) - -`ContextBase` carries per-step runtime state. 
Specialized as: -- `Indexer.Context`: Page layout, head config, hardware info -- `Cache.Context`: Page size, total pages, model info - -### Concrete Algorithms (`vortex_torch/flow/algorithms.py`) - -- **BlockSparseAttention**: Centroid-based routing (query avg → GeMV with centroids → topK) -- **GQABlockSparseAttention**: Grouped-query variant with softmax + group aggregation -- **GQAQuestSparseAttention**: Query-envelope matching using per-page max/min bounds - -### SGLang Integration - -Custom SGLang fork lives in `third_party/sglang` (git submodule, "graph" branch). CUDA extensions in `csrc/` provide PyBind11 bindings for `sglang_plan_decode`, `sglang_plan_prefill`, and transpose operations. - -## Key Conventions - -- **Tensor shapes**: Query `[B, H_q, D]`, sparse output `[S_sparse, 1, 1]`, cache indexer-view `[S, r, c]`, cache batch-view `[B, r, c]` -- **GeMM semantics**: `GeMM(x, y)` computes `y @ x^T` (note transposition) -- **Standard cache keys**: `"k"` and `"v"` have inner shape `(page_size, head_dim)`; custom caches declared in `create_cache()` -- **Branch**: Main development is on `v1` - -## Workflow Orchestration - -### 1. Plan Node Default -- Enter plan mode for ANY non-trivial task (3+ steps or architectural decisions) -- If something goes sideways, STOP and re-plan immediately - don't keep pushing -- Use plan mode for verification steps, not just building -- Write detailed specs upfront to reduce ambiguity - -### 2. Subagent Strategy -- Use subagents liberally to keep main context window clean -- Offload research, exploration, and parallel analysis to subagents -- For complex problems, throw more compute at it via subagents -- One tack per subagent for focused execution - -### 3. 
Self-Improvement Loop -- After ANY correction from the user: update `tasks/lessons.md` with the pattern -- Write rules for yourself that prevent the same mistake -- Ruthlessly iterate on these lessons until mistake rate drops -- Review lessons at session start for relevant project - -### 4. Verification Before Done -- Never mark a task complete without proving it works -- Diff behavior between main and your changes when relevant -- Ask yourself: "Would a staff engineer approve this?" -- Run tests, check logs, demonstrate correctness - -### 5. Demand Elegance (Balanced) -- For non-trivial changes: pause and ask "is there a more elegant way?" -- If a fix feels hacky: "Knowing everything I know now, implement the elegant solution" -- Skip this for simple, obvious fixes - don't over-engineer -- Challenge your own work before presenting it - -### 6. Autonomous Bug Fixing -- When given a bug report: just fix it. Don't ask for hand-holding -- Point at logs, errors, failing tests - then resolve them -- Zero context switching required from the user -- Go fix failing CI tests without being told how - -## Task Management - -1. **Plan First**: Write plan to `tasks/todo.md` with checkable items -2. **Verify Plan**: Check in before starting implementation -3. **Track Progress**: Mark items complete as you go -4. **Explain Changes**: High-level summary at each step -5. **Document Results**: Add review section to `tasks/todo.md` -6. **Capture Lessons**: Update `tasks/lessons.md` after corrections - -## Core Principles - -- **Simplicity First**: Make every change as simple as possible. Impact minimal code. -- **No Laziness**: Find root causes. No temporary fixes. Senior developer standards. -- **Minimat Impact**: Changes should only touch what's necessary. Avoid introducing bugs. 
\ No newline at end of file diff --git a/examples/verify_algo.py b/examples/verify_algo.py index f418598..9958b7e 100644 --- a/examples/verify_algo.py +++ b/examples/verify_algo.py @@ -113,13 +113,13 @@ def verify_algos( } ) # --- Per-question debug output --- - print(f"[Q{len(results):03d}] score={float(result):.1f} " - f"tokens={item['meta_info']['completion_tokens']} " - f"latency={item['meta_info']['e2e_latency']:.2f}s " - f"gold={golds[0]}") - print(f" question: {data['question'][:120]}...") - print(f" prediction: {predictions[:200]}...") - print() + # print(f"[Q{len(results):03d}] score={float(result):.1f} " + # f"tokens={item['meta_info']['completion_tokens']} " + # f"latency={item['meta_info']['e2e_latency']:.2f}s " + # f"gold={golds[0]}") + # print(f" question: {data['question'][:120]}...") + # print(f" prediction: {predictions[:200]}...") + # print() total_accuracy = 0.0 diff --git a/examples/verify_algo_fp8.sh b/examples/verify_algo_fp8.sh new file mode 100755 index 0000000..7f266e5 --- /dev/null +++ b/examples/verify_algo_fp8.sh @@ -0,0 +1,25 @@ +#!/usr/bin/env bash +set -e +export CUDA_VISIBLE_DEVICES=3 + +sparse_algos=( + "block_sparse_attention" +) + +RESULTS_DIR="results" +mkdir -p "${RESULTS_DIR}" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) + + for algo in "${sparse_algos[@]}"; do + OUTFILE="${RESULTS_DIR}/${algo}_fp8_${TIMESTAMP}.log" + echo ">>> Running verify_algo.py with --vortex-module-name ${algo} --kv-cache-dtype fp8_e4m3" + echo ">>> Saving results to ${OUTFILE}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --kv-cache-dtype fp8_e4m3 \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" + done diff --git a/examples/verify_algo_quant.sh b/examples/verify_algo_int8.sh similarity index 100% rename from examples/verify_algo_quant.sh rename to examples/verify_algo_int8.sh diff --git a/third_party/sglang b/third_party/sglang index 9672e9a..7105719 160000 --- 
a/third_party/sglang +++ b/third_party/sglang @@ -1 +1 @@ -Subproject commit 9672e9a7f90bcb782ccdfb2ee123ede7f2ef5d17 +Subproject commit 7105719f0a2ac464ee7ffdc0a899fa6a656656a2 diff --git a/vortex_torch/cache/__init__.py b/vortex_torch/cache/__init__.py index b32d8bc..8c4d0e0 100644 --- a/vortex_torch/cache/__init__.py +++ b/vortex_torch/cache/__init__.py @@ -29,12 +29,13 @@ from .matmul import GeMM from .elementwise import Relu, Silu, Sigmoid, Abs, Add_Mul from .elementwise_binary import Maximum, Minimum, Multiply, Add -from .triton_kernels import set_kv_buffer_launcher, set_kv_buffer_int8_launcher, dequant_paged_int8_to_bf16_inplace +from .triton_kernels import set_kv_buffer_launcher, set_kv_buffer_int8_launcher, set_kv_buffer_fp8_launcher, dequant_paged_int8_to_bf16_inplace __all__ = [ "set_kv_buffer_launcher", "set_kv_buffer_int8_launcher", + "set_kv_buffer_fp8_launcher", "dequant_paged_int8_to_bf16_inplace", "Mean", "Max", "Min", "L2Norm", "GeMM", diff --git a/vortex_torch/cache/context.py b/vortex_torch/cache/context.py index ae2dd5c..dd1bd02 100644 --- a/vortex_torch/cache/context.py +++ b/vortex_torch/cache/context.py @@ -10,17 +10,21 @@ class Context(ContextBase): """ __slots__ = ContextBase.__slots__ + ( - + #page information "max_new_tokens_per_batch", "page_size", "total_num_pages", - + #model information "head_dim", "head_num", - + # auxiliary memory in graph "_aux_total_bytes", - - "_aux_total_flops" + + "_aux_total_flops", + + # FP8 quantization: fp8_type (0=none, 1=e4m3, 2=e5m2), kv_scale (per-tensor) + "fp8_type", + "kv_scale", ) @@ -36,7 +40,11 @@ def __init__(self) -> None: elif name == "_aux_total_flops": object.__setattr__(self, name, 0) # start from 0 flops elif name == "mode": - object.__setattr__(self, name, Mode.profile) + object.__setattr__(self, name, Mode.profile) + elif name == "fp8_type": + object.__setattr__(self, name, 0) # 0 = no fp8 (bf16 default) + elif name == "kv_scale": + object.__setattr__(self, name, 1.0) # identity scale for 
bf16 else: object.__setattr__(self, name, UNSET) diff --git a/vortex_torch/cache/reduce.py b/vortex_torch/cache/reduce.py index 3c4edf2..5800458 100644 --- a/vortex_torch/cache/reduce.py +++ b/vortex_torch/cache/reduce.py @@ -345,8 +345,10 @@ def execute( ) output = self.output_buffer - # Launch the kernel/implementation: impl(x, output, loc, ctx, dim, reduce_type) - self.impl(x, output, loc, ctx, self.dim, self.reduce_type) + # Launch the kernel/implementation: impl(x, output, loc, ctx, dim, reduce_type, fp8_type, scale) + fp8_type = getattr(ctx, 'fp8_type', 0) + scale = getattr(ctx, 'kv_scale', 1.0) + self.impl(x, output, loc, ctx, self.dim, self.reduce_type, fp8_type, scale) return output diff --git a/vortex_torch/cache/triton_kernels/__init__.py b/vortex_torch/cache/triton_kernels/__init__.py index a18067e..009e728 100644 --- a/vortex_torch/cache/triton_kernels/__init__.py +++ b/vortex_torch/cache/triton_kernels/__init__.py @@ -1,10 +1,11 @@ -from .set_kv import set_kv_buffer_launcher, set_kv_buffer_int8_launcher +from .set_kv import set_kv_buffer_launcher, set_kv_buffer_int8_launcher, set_kv_buffer_fp8_launcher from .paged_decode_int8 import paged_decode_int8 from .paged_prefill_int8 import dequant_paged_int8_to_bf16, dequant_paged_int8_to_bf16_inplace __all__ = [ "set_kv_buffer_launcher", "set_kv_buffer_int8_launcher", + "set_kv_buffer_fp8_launcher", "paged_decode_int8", "dequant_paged_int8_to_bf16", "dequant_paged_int8_to_bf16_inplace", diff --git a/vortex_torch/cache/triton_kernels/reduce_impl.py b/vortex_torch/cache/triton_kernels/reduce_impl.py index 9921e08..9670acd 100644 --- a/vortex_torch/cache/triton_kernels/reduce_impl.py +++ b/vortex_torch/cache/triton_kernels/reduce_impl.py @@ -4,6 +4,16 @@ from ..context import Context from ...utils import ReduceType + +# --------------------------------------------------------------------------- +# Helper: Load a page block from src_ptr, handling bf16 or fp8-stored-as-uint8. 
+# FP8_TYPE == 0 -> bf16 pointer, load normally +# FP8_TYPE == 1 -> uint8 pointer, bitcast to float8e4nv, dequant with scale +# FP8_TYPE == 2 -> uint8 pointer, bitcast to float8e5, dequant with scale +# All paths return a float32 tensor ready for reduction. +# --------------------------------------------------------------------------- + + @triton.jit def reduce_pp_kernel( x, output, loc, @@ -12,9 +22,11 @@ def reduce_pp_kernel( NUM_KV_HEAD: tl.constexpr, PAGE_SIZE: tl.constexpr, REDUCE_TYPE: tl.constexpr, # 0:Mean, 1:Max, 2:Min, 3:L2Norm -DIM: tl.constexpr # 1: over rows (axis=0) -> len x_D1; 2: over cols (axis=1) -> len x_D0 -): - +DIM: tl.constexpr, # 1: over rows (axis=0) -> len x_D1; 2: over cols (axis=1) -> len x_D0 +FP8_TYPE: tl.constexpr, # 0: bf16, 1: e4m3, 2: e5m2 +scale, # float: 1.0 for bf16, kv_scale for fp8 +): + token_id = tl.program_id(0) head_id = tl.program_id(1) @@ -29,7 +41,15 @@ def reduce_pp_kernel( rows = tl.arange(0, x_D0)[:, None] # [x_D0, 1] cols = tl.arange(0, x_D1)[None, :] # [1, x_D1] src_ptr = x + x_offset + rows * x_D1 + cols - page_block = tl.load(src_ptr) + + if FP8_TYPE == 1: + raw = tl.load(src_ptr) + page_block = raw.to(tl.float8e4nv, bitcast=True).to(tl.float32) * scale + elif FP8_TYPE == 2: + raw = tl.load(src_ptr) + page_block = raw.to(tl.float8e5, bitcast=True).to(tl.float32) * scale + else: + page_block = tl.load(src_ptr).to(tl.float32) if DIM == 1: # reduce over rows -> axis=0 -> length x_D1 @@ -40,7 +60,7 @@ def reduce_pp_kernel( elif REDUCE_TYPE == 2: # Min reduce_vec = tl.min(page_block, axis=0).to(tl.bfloat16) else: # L2Norm - s = tl.sum(page_block * page_block, axis=0).to(tl.float32) + s = tl.sum(page_block * page_block, axis=0) reduce_vec = tl.sqrt(s).to(tl.bfloat16) dst_ptr = output + page_id * x_D1 + tl.arange(0, x_D1) @@ -55,7 +75,7 @@ def reduce_pp_kernel( elif REDUCE_TYPE == 2: # Min reduce_vec = tl.min(page_block, axis=1).to(tl.bfloat16) else: # L2Norm - s = tl.sum(page_block * page_block, axis=1).to(tl.float32) 
+ s = tl.sum(page_block * page_block, axis=1) reduce_vec = tl.sqrt(s).to(tl.bfloat16) dst_ptr = output + page_id * x_D0 + tl.arange(0, x_D0) @@ -71,11 +91,13 @@ def reduce_pp( ctx: Context, dim: int, reduce_type: ReduceType, +fp8_type: int = 0, +scale: float = 1.0, ): - + NNZ = loc.shape[0] NUM_KV_HEAD = ctx.head_num - + reduce_pp_kernel[(NNZ, NUM_KV_HEAD)]( x=x, output=output, @@ -85,7 +107,9 @@ def reduce_pp( NUM_KV_HEAD=NUM_KV_HEAD, PAGE_SIZE=ctx.page_size, REDUCE_TYPE=reduce_type.value, - DIM=dim + DIM=dim, + FP8_TYPE=fp8_type, + scale=scale, ) @@ -97,11 +121,13 @@ def _reduce_pp( page_size: int, dim: int, reduce_type: ReduceType, +fp8_type: int = 0, +scale: float = 1.0, ): - + NNZ = loc.shape[0] NUM_KV_HEAD = num_kv_heads - + reduce_pp_kernel[(NNZ, NUM_KV_HEAD)]( x=x, output=output, @@ -111,7 +137,9 @@ def _reduce_pp( NUM_KV_HEAD=NUM_KV_HEAD, PAGE_SIZE=page_size, REDUCE_TYPE=reduce_type.value, - DIM=dim + DIM=dim, + FP8_TYPE=fp8_type, + scale=scale, ) @@ -119,84 +147,67 @@ def _reduce_pp( @triton.jit def reduce_rp_kernel( x, output, loc, - x_D0: tl.constexpr, # rows per token-page - x_D1: tl.constexpr, # cols per token-page + x_D0: tl.constexpr, + x_D1: tl.constexpr, NUM_KV_HEAD: tl.constexpr, PAGE_SIZE: tl.constexpr, - REDUCE_TYPE: tl.constexpr, # 0: Mean, 1: Max, 2: Min, 3: L2Norm (not RMS) - DIM: tl.constexpr # 1: reduce over rows -> len x_D1; 2: reduce over cols -> len x_D0 + REDUCE_TYPE: tl.constexpr, + DIM: tl.constexpr, + FP8_TYPE: tl.constexpr, + scale, ): - - # Program IDs: - # pid0 = token index (0 .. num_tokens-1) - # pid1 = head index (0 .. NUM_KV_HEAD-1) + token_id = tl.program_id(0) head_id = tl.program_id(1) - # Load the absolute position of this token (used to map to page index). token_position = tl.load(loc + token_id) - # Only the last token of a page triggers the reduction. 
if (token_position + 1) % PAGE_SIZE != 0: return - # Output page index: - # Logical page = token_position // PAGE_SIZE - # One vector per head, so linearize by NUM_KV_HEAD. page_id = (token_position // PAGE_SIZE) * NUM_KV_HEAD + head_id - - # Input layout is [num_tokens, num_heads, x_D0, x_D1] (row-major). - # For this token/head, compute the base element offset in `x`. x_offset = (token_id * NUM_KV_HEAD + head_id) * x_D0 * x_D1 - # Build 2D indices within a page (row-major addressing). - rows = tl.arange(0, x_D0)[:, None] # shape [x_D0, 1] - cols = tl.arange(0, x_D1)[None, :] # shape [1, x_D1] + rows = tl.arange(0, x_D0)[:, None] + cols = tl.arange(0, x_D1)[None, :] src_ptr = x + x_offset + rows * x_D1 + cols - # Load the full page block for this (token_id, head_id). - # Assumes the page is full; add masks here if you have partial tiles. - page_block = tl.load(src_ptr) + if FP8_TYPE == 1: + raw = tl.load(src_ptr) + page_block = raw.to(tl.float8e4nv, bitcast=True).to(tl.float32) * scale + elif FP8_TYPE == 2: + raw = tl.load(src_ptr) + page_block = raw.to(tl.float8e5, bitcast=True).to(tl.float32) * scale + else: + page_block = tl.load(src_ptr).to(tl.float32) - # Reduction: if DIM == 1: - # Reduce over rows (axis=0) -> output vector length x_D1 (per-column reduce). 
- if REDUCE_TYPE == 0: # Mean - # NOTE: precision-sensitive workloads may want fp32 accumulation: - # s = tl.sum(page_block.to(tl.float32), axis=0) - # reduce_vec = (s / x_D0).to(tl.bfloat16) + if REDUCE_TYPE == 0: reduce_vec = (tl.sum(page_block, axis=0) / x_D0).to(tl.bfloat16) - elif REDUCE_TYPE == 1: # Max + elif REDUCE_TYPE == 1: reduce_vec = tl.max(page_block, axis=0).to(tl.bfloat16) - elif REDUCE_TYPE == 2: # Min + elif REDUCE_TYPE == 2: reduce_vec = tl.min(page_block, axis=0).to(tl.bfloat16) - else: # L2Norm (sqrt(sum(x*x))); NOT RMS - # For RMS, use: tl.sqrt(tl.sum(page_block*page_block, axis=0) / x_D0) - s = tl.sum(page_block * page_block, axis=0).to(tl.float32) + else: + s = tl.sum(page_block * page_block, axis=0) reduce_vec = tl.sqrt(s).to(tl.bfloat16) - # Write to output: layout [num_pages, x_D1] for DIM==1. dst_ptr = output + page_id * x_D1 + tl.arange(0, x_D1) tl.store(dst_ptr, reduce_vec) else: - # DIM == 2: Reduce over cols (axis=1) -> output vector length x_D0 (per-row reduce). - if REDUCE_TYPE == 0: # Mean - # s = tl.sum(page_block.to(tl.float32), axis=1) - # reduce_vec = (s / x_D1).to(tl.bfloat16) + if REDUCE_TYPE == 0: reduce_vec = (tl.sum(page_block, axis=1) / x_D1).to(tl.bfloat16) - elif REDUCE_TYPE == 1: # Max + elif REDUCE_TYPE == 1: reduce_vec = tl.max(page_block, axis=1).to(tl.bfloat16) - elif REDUCE_TYPE == 2: # Min + elif REDUCE_TYPE == 2: reduce_vec = tl.min(page_block, axis=1).to(tl.bfloat16) - else: # L2Norm (sqrt(sum(x*x))); NOT RMS - s = tl.sum(page_block * page_block, axis=1).to(tl.float32) + else: + s = tl.sum(page_block * page_block, axis=1) reduce_vec = tl.sqrt(s).to(tl.bfloat16) - # Write to output: layout [num_pages, x_D0] for DIM==2. 
dst_ptr = output + page_id * x_D0 + tl.arange(0, x_D0) tl.store(dst_ptr, reduce_vec) - def reduce_rp( @@ -206,11 +217,13 @@ def reduce_rp( ctx: Context, dim: int, reduce_type: ReduceType, +fp8_type: int = 0, +scale: float = 1.0, ): - + NNZ = loc.shape[0] NUM_KV_HEAD = ctx.head_num - + reduce_rp_kernel[(NNZ, NUM_KV_HEAD)]( x=x, output=output, @@ -220,7 +233,9 @@ def reduce_rp( NUM_KV_HEAD=NUM_KV_HEAD, PAGE_SIZE=ctx.page_size, REDUCE_TYPE=reduce_type.value, - DIM=dim + DIM=dim, + FP8_TYPE=fp8_type, + scale=scale, ) @@ -232,11 +247,13 @@ def _reduce_rp( page_size: int, dim: int, reduce_type: ReduceType, +fp8_type: int = 0, +scale: float = 1.0, ): - + NNZ = loc.shape[0] NUM_KV_HEAD = num_kv_heads - + reduce_rp_kernel[(NNZ, NUM_KV_HEAD)]( x=x, output=output, @@ -246,92 +263,76 @@ def _reduce_rp( NUM_KV_HEAD=NUM_KV_HEAD, PAGE_SIZE=page_size, REDUCE_TYPE=reduce_type.value, - DIM=dim + DIM=dim, + FP8_TYPE=fp8_type, + scale=scale, ) @triton.jit def reduce_pr_kernel( x, output, loc, -x_D0: tl.constexpr, # rows per page -x_D1: tl.constexpr, # cols per page +x_D0: tl.constexpr, +x_D1: tl.constexpr, NUM_KV_HEAD: tl.constexpr, PAGE_SIZE: tl.constexpr, -REDUCE_TYPE: tl.constexpr, # 0: Mean, 1: Max, 2: Min, 3: L2Norm (not RMS) -DIM: tl.constexpr # 1: reduce over rows -> len x_D1; 2: reduce over cols -> len x_D0 +REDUCE_TYPE: tl.constexpr, +DIM: tl.constexpr, +FP8_TYPE: tl.constexpr, +scale, ): - """ - Layouts: - x: [num_pages * NUM_KV_HEAD, x_D0, x_D1] (page-major, row-major inside page) - output: [num_tokens * NUM_KV_HEAD, vec_len] (token-major; vec_len = x_D1 if DIM==1 else x_D0) - - Behavior: - - token_id comes from pid0; head_id comes from pid1. - - Read loc[token_id] to get absolute position; only proceed at page end. - - Map token -> page via page_idx = (token_position // PAGE_SIZE). - - Read the whole page for this (page_idx, head_id), do reduction, - then write a single vector to output at (token_id, head_id, :). 
- """ - - # --- Program IDs --- - token_id = tl.program_id(0) # [0 .. num_tokens-1] - head_id = tl.program_id(1) # [0 .. NUM_KV_HEAD-1] - - # --- Trigger only at end-of-page token --- + + token_id = tl.program_id(0) + head_id = tl.program_id(1) + token_position = tl.load(loc + token_id) if (token_position + 1) % PAGE_SIZE != 0: return - # --- Page indexing for x (page-major) --- - # page linear id across heads page_idx = token_position // PAGE_SIZE page_id = page_idx * NUM_KV_HEAD + head_id - # Base element offset into x for this (page_id, head_id) - # x is laid out as contiguous pages, each page is [x_D0, x_D1] x_offset = page_id * x_D0 * x_D1 - # 2D row-major addressing within the page - rows = tl.arange(0, x_D0)[:, None] # [x_D0, 1] - cols = tl.arange(0, x_D1)[None, :] # [1, x_D1] + rows = tl.arange(0, x_D0)[:, None] + cols = tl.arange(0, x_D1)[None, :] src_ptr = x + x_offset + rows * x_D1 + cols - # Load the full page block. Assumes full tiles; add masks if needed. - page_block = tl.load(src_ptr) + if FP8_TYPE == 1: + raw = tl.load(src_ptr) + page_block = raw.to(tl.float8e4nv, bitcast=True).to(tl.float32) * scale + elif FP8_TYPE == 2: + raw = tl.load(src_ptr) + page_block = raw.to(tl.float8e5, bitcast=True).to(tl.float32) * scale + else: + page_block = tl.load(src_ptr).to(tl.float32) - # --- Reduction & write-out --- if DIM == 1: - # Reduce over rows (axis=0) -> per-column vector, length = x_D1 - if REDUCE_TYPE == 0: # Mean - # For better accuracy you may upcast: tl.sum(page_block.to(tl.float32), axis=0) + if REDUCE_TYPE == 0: reduce_vec = (tl.sum(page_block, axis=0) / x_D0).to(tl.bfloat16) - elif REDUCE_TYPE == 1: # Max + elif REDUCE_TYPE == 1: reduce_vec = tl.max(page_block, axis=0).to(tl.bfloat16) - elif REDUCE_TYPE == 2: # Min + elif REDUCE_TYPE == 2: reduce_vec = tl.min(page_block, axis=0).to(tl.bfloat16) - else: # L2Norm (NOT RMS) - s = tl.sum(page_block * page_block, axis=0).to(tl.float32) + else: + s = tl.sum(page_block * page_block, axis=0) reduce_vec 
= tl.sqrt(s).to(tl.bfloat16) - # output is token-major: [num_tokens, NUM_KV_HEAD, x_D1] out_base = (token_id * NUM_KV_HEAD + head_id) * x_D1 dst_ptr = output + out_base + tl.arange(0, x_D1) tl.store(dst_ptr, reduce_vec) else: - # DIM == 2: Reduce over cols (axis=1) -> per-row vector, length = x_D0 - if REDUCE_TYPE == 0: # Mean + if REDUCE_TYPE == 0: reduce_vec = (tl.sum(page_block, axis=1) / x_D1).to(tl.bfloat16) - elif REDUCE_TYPE == 1: # Max + elif REDUCE_TYPE == 1: reduce_vec = tl.max(page_block, axis=1).to(tl.bfloat16) - elif REDUCE_TYPE == 2: # Min + elif REDUCE_TYPE == 2: reduce_vec = tl.min(page_block, axis=1).to(tl.bfloat16) - else: # L2Norm (NOT RMS) - s = tl.sum(page_block * page_block, axis=1).to(tl.float32) + else: + s = tl.sum(page_block * page_block, axis=1) reduce_vec = tl.sqrt(s).to(tl.bfloat16) - - # output is token-major: [num_tokens, NUM_KV_HEAD, x_D0] out_base = (token_id * NUM_KV_HEAD + head_id) * x_D0 dst_ptr = output + out_base + tl.arange(0, x_D0) tl.store(dst_ptr, reduce_vec) @@ -344,11 +345,13 @@ def reduce_pr( ctx: Context, dim: int, reduce_type: ReduceType, +fp8_type: int = 0, +scale: float = 1.0, ): - + NNZ = loc.shape[0] NUM_KV_HEAD = ctx.head_num - + reduce_pr_kernel[(NNZ, NUM_KV_HEAD)]( x=x, output=output, @@ -358,9 +361,11 @@ def reduce_pr( NUM_KV_HEAD=NUM_KV_HEAD, PAGE_SIZE=ctx.page_size, REDUCE_TYPE=reduce_type.value, - DIM=dim + DIM=dim, + FP8_TYPE=fp8_type, + scale=scale, ) - + def _reduce_pr( x: torch.Tensor, output: torch.Tensor, @@ -369,11 +374,13 @@ def _reduce_pr( page_size: int, dim: int, reduce_type: ReduceType, +fp8_type: int = 0, +scale: float = 1.0, ): - + NNZ = loc.shape[0] NUM_KV_HEAD = num_kv_heads - + reduce_pr_kernel[(NNZ, NUM_KV_HEAD)]( x=x, output=output, @@ -383,72 +390,68 @@ def _reduce_pr( NUM_KV_HEAD=NUM_KV_HEAD, PAGE_SIZE=page_size, REDUCE_TYPE=reduce_type.value, - DIM=dim + DIM=dim, + FP8_TYPE=fp8_type, + scale=scale, ) @triton.jit def reduce_rr_kernel( x, output, loc, -x_D0: tl.constexpr, # rows per 
token-page -x_D1: tl.constexpr, # cols per token-page +x_D0: tl.constexpr, +x_D1: tl.constexpr, NUM_KV_HEAD: tl.constexpr, PAGE_SIZE: tl.constexpr, -REDUCE_TYPE: tl.constexpr, # 0: Mean, 1: Max, 2: Min, 3: L2Norm (not RMS) -DIM: tl.constexpr # 1: reduce over rows -> len x_D1; 2: reduce over cols -> len x_D0 +REDUCE_TYPE: tl.constexpr, +DIM: tl.constexpr, +FP8_TYPE: tl.constexpr, +scale, ): - """ - Layouts: - x: [num_tokens * NUM_KV_HEAD, x_D0, x_D1] (token-major) - output: [num_tokens * NUM_KV_HEAD, vec_len] (token-major; vec_len = x_D1 if DIM==1 else x_D0) - Only the last token of each page performs the reduction and writes to output[token_id, head_id, :]. - """ - - - # program ids - token_id = tl.program_id(0) # 0..num_tokens-1 - head_id = tl.program_id(1) # 0..NUM_KV_HEAD-1 + token_id = tl.program_id(0) + head_id = tl.program_id(1) - # trigger only at end-of-page token token_position = tl.load(loc + token_id) if (token_position + 1) % PAGE_SIZE != 0: return - # ---- read from x (token-major) ---- x_base = (token_id * NUM_KV_HEAD + head_id) * x_D0 * x_D1 - rows = tl.arange(0, x_D0)[:, None] # [x_D0, 1] - cols = tl.arange(0, x_D1)[None, :] # [1, x_D1] + rows = tl.arange(0, x_D0)[:, None] + cols = tl.arange(0, x_D1)[None, :] src_ptr = x + x_base + rows * x_D1 + cols - page_blk = tl.load(src_ptr) # assumes full page; add masks if needed - # ---- reduce ---- + if FP8_TYPE == 1: + raw = tl.load(src_ptr) + page_blk = raw.to(tl.float8e4nv, bitcast=True).to(tl.float32) * scale + elif FP8_TYPE == 2: + raw = tl.load(src_ptr) + page_blk = raw.to(tl.float8e5, bitcast=True).to(tl.float32) * scale + else: + page_blk = tl.load(src_ptr).to(tl.float32) + if DIM == 1: - # over rows -> axis=0 -> vector len x_D1 - if REDUCE_TYPE == 0: # Mean - # For better accuracy you may upcast to fp32 before sum. 
+ if REDUCE_TYPE == 0: vec = (tl.sum(page_blk, axis=0) / x_D0).to(tl.bfloat16) - elif REDUCE_TYPE == 1: # Max + elif REDUCE_TYPE == 1: vec = tl.max(page_blk, axis=0).to(tl.bfloat16) - elif REDUCE_TYPE == 2: # Min + elif REDUCE_TYPE == 2: vec = tl.min(page_blk, axis=0).to(tl.bfloat16) - else: # L2Norm (NOT RMS) + else: s = tl.sum(page_blk * page_blk, axis=0) vec = tl.sqrt(s).to(tl.bfloat16) - # ---- write to output (token-major) ---- out_base = (token_id * NUM_KV_HEAD + head_id) * x_D1 tl.store(output + out_base + tl.arange(0, x_D1), vec) else: - # DIM == 2: over cols -> axis=1 -> vector len x_D0 - if REDUCE_TYPE == 0: # Mean + if REDUCE_TYPE == 0: vec = (tl.sum(page_blk, axis=1) / x_D1).to(tl.bfloat16) - elif REDUCE_TYPE == 1: # Max + elif REDUCE_TYPE == 1: vec = tl.max(page_blk, axis=1).to(tl.bfloat16) - elif REDUCE_TYPE == 2: # Min + elif REDUCE_TYPE == 2: vec = tl.min(page_blk, axis=1).to(tl.bfloat16) - else: # L2Norm (NOT RMS) + else: s = tl.sum(page_blk * page_blk, axis=1) vec = tl.sqrt(s).to(tl.bfloat16) @@ -456,7 +459,6 @@ def reduce_rr_kernel( tl.store(output + out_base + tl.arange(0, x_D0), vec) - def reduce_rr( x: torch.Tensor, output: torch.Tensor, @@ -464,11 +466,13 @@ def reduce_rr( ctx: Context, dim: int, reduce_type: ReduceType, +fp8_type: int = 0, +scale: float = 1.0, ): - + NNZ = loc.shape[0] NUM_KV_HEAD = ctx.head_num - + reduce_rr_kernel[(NNZ, NUM_KV_HEAD)]( x=x, output=output, @@ -478,9 +482,11 @@ def reduce_rr( NUM_KV_HEAD=NUM_KV_HEAD, PAGE_SIZE=ctx.page_size, REDUCE_TYPE=reduce_type.value, - DIM=dim + DIM=dim, + FP8_TYPE=fp8_type, + scale=scale, ) - + def _reduce_rr( x: torch.Tensor, @@ -490,11 +496,13 @@ def _reduce_rr( page_size: int, dim: int, reduce_type: ReduceType, +fp8_type: int = 0, +scale: float = 1.0, ): - + NNZ = loc.shape[0] NUM_KV_HEAD = num_kv_heads - + reduce_rr_kernel[(NNZ, NUM_KV_HEAD)]( x=x, output=output, @@ -504,5 +512,7 @@ def _reduce_rr( NUM_KV_HEAD=NUM_KV_HEAD, PAGE_SIZE=page_size, REDUCE_TYPE=reduce_type.value, - 
DIM=dim - ) \ No newline at end of file + DIM=dim, + FP8_TYPE=fp8_type, + scale=scale, + ) diff --git a/vortex_torch/cache/triton_kernels/set_kv.py b/vortex_torch/cache/triton_kernels/set_kv.py index 2a2c785..6b289df 100644 --- a/vortex_torch/cache/triton_kernels/set_kv.py +++ b/vortex_torch/cache/triton_kernels/set_kv.py @@ -131,11 +131,11 @@ def set_kv_buffer_launcher( loc: torch.LongTensor, page_size: int ): - + NNZ = loc.shape[0] NUM_KV_HEAD = new_k.shape[1] HEAD_DIM = new_k.shape[2] - + set_kv_buffer_kernel[(NNZ, NUM_KV_HEAD)]( k_cache, v_cache, @@ -148,3 +148,96 @@ def set_kv_buffer_launcher( page_size ) + +@triton.jit +def set_kv_buffer_fp8_kernel( + k_cache, # uint8 paged K cache + v_cache, # uint8 paged V cache + new_k, # bf16 input K [NNZ, NUM_KV_HEAD, HEAD_DIM] + new_v, # bf16 input V [NNZ, NUM_KV_HEAD, HEAD_DIM] + loc, # int64 token positions + NUM_KV_HEAD: tl.constexpr, + NNZ: tl.constexpr, + HEAD_DIM: tl.constexpr, + PAGE_SIZE: tl.constexpr, + FP8_TYPE: tl.constexpr, # 1: e4m3 (max=448), 2: e5m2 (max=57344) + k_scale, # float: per-tensor scale for K quantization + v_scale, # float: per-tensor scale for V quantization +): + """Quantize bf16 K/V to fp8, bitcast to uint8, and scatter into paged cache.""" + token_id = tl.program_id(0) + if token_id >= NNZ: + return + head_id = tl.program_id(1) + dim = tl.arange(0, HEAD_DIM) + + # Load bf16 source values + src_ptr = token_id * NUM_KV_HEAD * HEAD_DIM + head_id * HEAD_DIM + dim + src_k = tl.load(new_k + src_ptr).to(tl.float32) + src_v = tl.load(new_v + src_ptr).to(tl.float32) + + # Scale down: quantized = real_value / scale + inv_k_scale = 1.0 / k_scale + inv_v_scale = 1.0 / v_scale + scaled_k = src_k * inv_k_scale + scaled_v = src_v * inv_v_scale + + # Clamp and cast to fp8, then bitcast to uint8 for storage + if FP8_TYPE == 1: + # e4m3: max = 448.0 + clamped_k = tl.minimum(tl.maximum(scaled_k, -448.0), 448.0) + clamped_v = tl.minimum(tl.maximum(scaled_v, -448.0), 448.0) + q_k = 
clamped_k.to(tl.float8e4nv).to(tl.uint8, bitcast=True) + q_v = clamped_v.to(tl.float8e4nv).to(tl.uint8, bitcast=True) + else: + # e5m2: max = 57344.0 + clamped_k = tl.minimum(tl.maximum(scaled_k, -57344.0), 57344.0) + clamped_v = tl.minimum(tl.maximum(scaled_v, -57344.0), 57344.0) + q_k = clamped_k.to(tl.float8e5).to(tl.uint8, bitcast=True) + q_v = clamped_v.to(tl.float8e5).to(tl.uint8, bitcast=True) + + # Compute paged destination offset + token_position = tl.load(loc + token_id) + page_id = token_position // PAGE_SIZE + in_page_offset = token_position % PAGE_SIZE + position_trans = page_id * (PAGE_SIZE * NUM_KV_HEAD) + head_id * PAGE_SIZE + in_page_offset + + # Write uint8 values + dst_k_ptr = k_cache + position_trans * HEAD_DIM + dim + dst_v_ptr = v_cache + position_trans * HEAD_DIM + dim + tl.store(dst_k_ptr, q_k) + tl.store(dst_v_ptr, q_v) + + +def set_kv_buffer_fp8_launcher( + k_cache: torch.Tensor, + v_cache: torch.Tensor, + new_k: torch.Tensor, + new_v: torch.Tensor, + loc: torch.LongTensor, + page_size: int, + k_scale: float, + v_scale: float, + fp8_type: int = 1, +): + """Quantize bf16 K/V to fp8, bitcast to uint8, and scatter into paged cache. + + Args: + fp8_type: 1 for e4m3 (default), 2 for e5m2. + k_scale: per-tensor scale used for K quantization. + v_scale: per-tensor scale used for V quantization. 
+ """ + NNZ = loc.shape[0] + NUM_KV_HEAD = new_k.shape[1] + HEAD_DIM = new_k.shape[2] + + set_kv_buffer_fp8_kernel[(NNZ, NUM_KV_HEAD)]( + k_cache, v_cache, + new_k, new_v, + loc, + NUM_KV_HEAD, NNZ, HEAD_DIM, page_size, + FP8_TYPE=fp8_type, + k_scale=k_scale, + v_scale=v_scale, + ) + From f25fb13e6075111a9b05aedb060a3e7bd346cebd Mon Sep 17 00:00:00 2001 From: UED Date: Sun, 1 Mar 2026 08:02:49 +0000 Subject: [PATCH 04/22] update on parameters for reduce_pp_kernel with quantization --- setup.py | 2 +- vortex_torch/cache/context.py | 12 +- vortex_torch/cache/reduce.py | 7 +- .../cache/triton_kernels/reduce_impl.py | 305 ++++++++++++------ 4 files changed, 224 insertions(+), 102 deletions(-) diff --git a/setup.py b/setup.py index 6efeebe..f35ddae 100644 --- a/setup.py +++ b/setup.py @@ -16,7 +16,7 @@ sources=[ 'csrc/register.cc', 'csrc/utils_sglang.cu', - 'csrc/topk.cu' + 'csrc/topk.cu', ], include_dirs=['csrc'], extra_compile_args={ diff --git a/vortex_torch/cache/context.py b/vortex_torch/cache/context.py index dd1bd02..0e7171c 100644 --- a/vortex_torch/cache/context.py +++ b/vortex_torch/cache/context.py @@ -22,9 +22,11 @@ class Context(ContextBase): "_aux_total_flops", - # FP8 quantization: fp8_type (0=none, 1=e4m3, 2=e5m2), kv_scale (per-tensor) - "fp8_type", + # Quantization: quant_type (0=none, 1=int8, 2=e4m3, 3=e5m2), + # kv_scale (per-tensor fp8 scale), kv_scale_ptr (per-token int8 scale tensor) + "quant_type", "kv_scale", + "kv_scale_ptr", ) @@ -41,10 +43,12 @@ def __init__(self) -> None: object.__setattr__(self, name, 0) # start from 0 flops elif name == "mode": object.__setattr__(self, name, Mode.profile) - elif name == "fp8_type": - object.__setattr__(self, name, 0) # 0 = no fp8 (bf16 default) + elif name == "quant_type": + object.__setattr__(self, name, 0) # 0 = none (bf16 default) elif name == "kv_scale": object.__setattr__(self, name, 1.0) # identity scale for bf16 + elif name == "kv_scale_ptr": + object.__setattr__(self, name, None) # per-token scale 
tensor (int8 only) else: object.__setattr__(self, name, UNSET) diff --git a/vortex_torch/cache/reduce.py b/vortex_torch/cache/reduce.py index 5800458..eb94795 100644 --- a/vortex_torch/cache/reduce.py +++ b/vortex_torch/cache/reduce.py @@ -345,10 +345,11 @@ def execute( ) output = self.output_buffer - # Launch the kernel/implementation: impl(x, output, loc, ctx, dim, reduce_type, fp8_type, scale) - fp8_type = getattr(ctx, 'fp8_type', 0) + # Launch the kernel/implementation: impl(x, output, loc, ctx, dim, reduce_type, quant_type, scale, kv_scale_ptr) + quant_type = getattr(ctx, 'quant_type', 0) scale = getattr(ctx, 'kv_scale', 1.0) - self.impl(x, output, loc, ctx, self.dim, self.reduce_type, fp8_type, scale) + kv_scale_ptr = getattr(ctx, 'kv_scale_ptr', None) + self.impl(x, output, loc, ctx, self.dim, self.reduce_type, quant_type, scale, kv_scale_ptr) return output diff --git a/vortex_torch/cache/triton_kernels/reduce_impl.py b/vortex_torch/cache/triton_kernels/reduce_impl.py index 9670acd..0146af7 100644 --- a/vortex_torch/cache/triton_kernels/reduce_impl.py +++ b/vortex_torch/cache/triton_kernels/reduce_impl.py @@ -6,11 +6,12 @@ # --------------------------------------------------------------------------- -# Helper: Load a page block from src_ptr, handling bf16 or fp8-stored-as-uint8. -# FP8_TYPE == 0 -> bf16 pointer, load normally -# FP8_TYPE == 1 -> uint8 pointer, bitcast to float8e4nv, dequant with scale -# FP8_TYPE == 2 -> uint8 pointer, bitcast to float8e5, dequant with scale -# All paths return a float32 tensor ready for reduction. +# Helper: Load a page block from src_ptr, handling bf16 / int8 / fp8-stored-as-uint8. 
+# QUANT_TYPE == 0 -> bf16 pointer, load normally +# QUANT_TYPE == 1 -> int8 pointer, dequant with per-row scale from kv_scale_ptr +# QUANT_TYPE == 2 -> uint8 pointer, bitcast to float8e4nv, dequant with per-tensor scale +# QUANT_TYPE == 3 -> uint8 pointer, bitcast to float8e5, dequant with per-tensor scale +# All quantised paths return a float32 tensor ready for reduction. # --------------------------------------------------------------------------- @@ -23,8 +24,9 @@ def reduce_pp_kernel( PAGE_SIZE: tl.constexpr, REDUCE_TYPE: tl.constexpr, # 0:Mean, 1:Max, 2:Min, 3:L2Norm DIM: tl.constexpr, # 1: over rows (axis=0) -> len x_D1; 2: over cols (axis=1) -> len x_D0 -FP8_TYPE: tl.constexpr, # 0: bf16, 1: e4m3, 2: e5m2 +QUANT_TYPE: tl.constexpr, # 0: bf16, 1: int8, 2: e4m3, 3: e5m2 scale, # float: 1.0 for bf16, kv_scale for fp8 +kv_scale_ptr, # pointer to per-token int8 scales (unused when QUANT_TYPE != 1) ): token_id = tl.program_id(0) @@ -42,14 +44,21 @@ def reduce_pp_kernel( cols = tl.arange(0, x_D1)[None, :] # [1, x_D1] src_ptr = x + x_offset + rows * x_D1 + cols - if FP8_TYPE == 1: + if QUANT_TYPE == 1: + # int8: load int8 values, dequant with per-row scale + raw = tl.load(src_ptr).to(tl.float32) + # Per-row scales stored at kv_scale_ptr[page_id * x_D0 + row] + scale_offset = page_id * x_D0 + tl.arange(0, x_D0) + row_scales = tl.load(kv_scale_ptr + scale_offset).to(tl.float32) # [x_D0] + page_block = raw * row_scales[:, None] # broadcast [x_D0, 1] + elif QUANT_TYPE == 2: raw = tl.load(src_ptr) page_block = raw.to(tl.float8e4nv, bitcast=True).to(tl.float32) * scale - elif FP8_TYPE == 2: + elif QUANT_TYPE == 3: raw = tl.load(src_ptr) page_block = raw.to(tl.float8e5, bitcast=True).to(tl.float32) * scale else: - page_block = tl.load(src_ptr).to(tl.float32) + page_block = tl.load(src_ptr) if DIM == 1: # reduce over rows -> axis=0 -> length x_D1 @@ -60,7 +69,7 @@ def reduce_pp_kernel( elif REDUCE_TYPE == 2: # Min reduce_vec = tl.min(page_block, axis=0).to(tl.bfloat16) 
else: # L2Norm - s = tl.sum(page_block * page_block, axis=0) + s = tl.sum(page_block * page_block, axis=0).to(tl.float32) reduce_vec = tl.sqrt(s).to(tl.bfloat16) dst_ptr = output + page_id * x_D1 + tl.arange(0, x_D1) @@ -75,7 +84,7 @@ def reduce_pp_kernel( elif REDUCE_TYPE == 2: # Min reduce_vec = tl.min(page_block, axis=1).to(tl.bfloat16) else: # L2Norm - s = tl.sum(page_block * page_block, axis=1) + s = tl.sum(page_block * page_block, axis=1).to(tl.float32) reduce_vec = tl.sqrt(s).to(tl.bfloat16) dst_ptr = output + page_id * x_D0 + tl.arange(0, x_D0) @@ -91,8 +100,9 @@ def reduce_pp( ctx: Context, dim: int, reduce_type: ReduceType, -fp8_type: int = 0, +quant_type: int = 0, scale: float = 1.0, +kv_scale_ptr=None, ): NNZ = loc.shape[0] @@ -108,8 +118,9 @@ def reduce_pp( PAGE_SIZE=ctx.page_size, REDUCE_TYPE=reduce_type.value, DIM=dim, - FP8_TYPE=fp8_type, + QUANT_TYPE=quant_type, scale=scale, + kv_scale_ptr=kv_scale_ptr if kv_scale_ptr is not None else x, ) @@ -121,8 +132,9 @@ def _reduce_pp( page_size: int, dim: int, reduce_type: ReduceType, -fp8_type: int = 0, +quant_type: int = 0, scale: float = 1.0, +kv_scale_ptr=None, ): NNZ = loc.shape[0] @@ -138,8 +150,9 @@ def _reduce_pp( PAGE_SIZE=page_size, REDUCE_TYPE=reduce_type.value, DIM=dim, - FP8_TYPE=fp8_type, + QUANT_TYPE=quant_type, scale=scale, + kv_scale_ptr=kv_scale_ptr if kv_scale_ptr is not None else x, ) @@ -147,69 +160,102 @@ def _reduce_pp( @triton.jit def reduce_rp_kernel( x, output, loc, - x_D0: tl.constexpr, - x_D1: tl.constexpr, + x_D0: tl.constexpr, # rows per token-page + x_D1: tl.constexpr, # cols per token-page NUM_KV_HEAD: tl.constexpr, PAGE_SIZE: tl.constexpr, - REDUCE_TYPE: tl.constexpr, - DIM: tl.constexpr, - FP8_TYPE: tl.constexpr, - scale, + REDUCE_TYPE: tl.constexpr, # 0: Mean, 1: Max, 2: Min, 3: L2Norm (not RMS) + DIM: tl.constexpr, # 1: reduce over rows -> len x_D1; 2: reduce over cols -> len x_D0 + QUANT_TYPE: tl.constexpr, # 0: bf16, 1: int8, 2: e4m3, 3: e5m2 + scale, # float: 1.0 for 
bf16, kv_scale for fp8 + kv_scale_ptr, # pointer to per-token int8 scales (unused when QUANT_TYPE != 1) ): + # Program IDs: + # pid0 = token index (0 .. num_tokens-1) + # pid1 = head index (0 .. NUM_KV_HEAD-1) token_id = tl.program_id(0) head_id = tl.program_id(1) + # Load the absolute position of this token (used to map to page index). token_position = tl.load(loc + token_id) + # Only the last token of a page triggers the reduction. if (token_position + 1) % PAGE_SIZE != 0: return + # Output page index: + # Logical page = token_position // PAGE_SIZE + # One vector per head, so linearize by NUM_KV_HEAD. page_id = (token_position // PAGE_SIZE) * NUM_KV_HEAD + head_id + + # Input layout is [num_tokens, num_heads, x_D0, x_D1] (row-major). + # For this token/head, compute the base element offset in `x`. x_offset = (token_id * NUM_KV_HEAD + head_id) * x_D0 * x_D1 - rows = tl.arange(0, x_D0)[:, None] - cols = tl.arange(0, x_D1)[None, :] + # Build 2D indices within a page (row-major addressing). + rows = tl.arange(0, x_D0)[:, None] # shape [x_D0, 1] + cols = tl.arange(0, x_D1)[None, :] # shape [1, x_D1] src_ptr = x + x_offset + rows * x_D1 + cols - if FP8_TYPE == 1: + # Load the full page block for this (token_id, head_id). + # Assumes the page is full; add masks here if you have partial tiles. 
+ if QUANT_TYPE == 1: + # int8: load int8 values, dequant with per-row scale + raw = tl.load(src_ptr).to(tl.float32) + scale_offset = page_id * x_D0 + tl.arange(0, x_D0) + row_scales = tl.load(kv_scale_ptr + scale_offset).to(tl.float32) + page_block = raw * row_scales[:, None] + elif QUANT_TYPE == 2: raw = tl.load(src_ptr) page_block = raw.to(tl.float8e4nv, bitcast=True).to(tl.float32) * scale - elif FP8_TYPE == 2: + elif QUANT_TYPE == 3: raw = tl.load(src_ptr) page_block = raw.to(tl.float8e5, bitcast=True).to(tl.float32) * scale else: - page_block = tl.load(src_ptr).to(tl.float32) + page_block = tl.load(src_ptr) + # Reduction: if DIM == 1: - if REDUCE_TYPE == 0: + # Reduce over rows (axis=0) -> output vector length x_D1 (per-column reduce). + if REDUCE_TYPE == 0: # Mean + # NOTE: precision-sensitive workloads may want fp32 accumulation: + # s = tl.sum(page_block.to(tl.float32), axis=0) + # reduce_vec = (s / x_D0).to(tl.bfloat16) reduce_vec = (tl.sum(page_block, axis=0) / x_D0).to(tl.bfloat16) - elif REDUCE_TYPE == 1: + elif REDUCE_TYPE == 1: # Max reduce_vec = tl.max(page_block, axis=0).to(tl.bfloat16) - elif REDUCE_TYPE == 2: + elif REDUCE_TYPE == 2: # Min reduce_vec = tl.min(page_block, axis=0).to(tl.bfloat16) - else: - s = tl.sum(page_block * page_block, axis=0) + else: # L2Norm (sqrt(sum(x*x))); NOT RMS + # For RMS, use: tl.sqrt(tl.sum(page_block*page_block, axis=0) / x_D0) + s = tl.sum(page_block * page_block, axis=0).to(tl.float32) reduce_vec = tl.sqrt(s).to(tl.bfloat16) + # Write to output: layout [num_pages, x_D1] for DIM==1. dst_ptr = output + page_id * x_D1 + tl.arange(0, x_D1) tl.store(dst_ptr, reduce_vec) else: - if REDUCE_TYPE == 0: + # DIM == 2: Reduce over cols (axis=1) -> output vector length x_D0 (per-row reduce). 
+ if REDUCE_TYPE == 0: # Mean + # s = tl.sum(page_block.to(tl.float32), axis=1) + # reduce_vec = (s / x_D1).to(tl.bfloat16) reduce_vec = (tl.sum(page_block, axis=1) / x_D1).to(tl.bfloat16) - elif REDUCE_TYPE == 1: + elif REDUCE_TYPE == 1: # Max reduce_vec = tl.max(page_block, axis=1).to(tl.bfloat16) - elif REDUCE_TYPE == 2: + elif REDUCE_TYPE == 2: # Min reduce_vec = tl.min(page_block, axis=1).to(tl.bfloat16) - else: - s = tl.sum(page_block * page_block, axis=1) + else: # L2Norm (sqrt(sum(x*x))); NOT RMS + s = tl.sum(page_block * page_block, axis=1).to(tl.float32) reduce_vec = tl.sqrt(s).to(tl.bfloat16) + # Write to output: layout [num_pages, x_D0] for DIM==2. dst_ptr = output + page_id * x_D0 + tl.arange(0, x_D0) tl.store(dst_ptr, reduce_vec) + def reduce_rp( x: torch.Tensor, output: torch.Tensor, @@ -217,8 +263,9 @@ def reduce_rp( ctx: Context, dim: int, reduce_type: ReduceType, -fp8_type: int = 0, +quant_type: int = 0, scale: float = 1.0, +kv_scale_ptr=None, ): NNZ = loc.shape[0] @@ -234,8 +281,9 @@ def reduce_rp( PAGE_SIZE=ctx.page_size, REDUCE_TYPE=reduce_type.value, DIM=dim, - FP8_TYPE=fp8_type, + QUANT_TYPE=quant_type, scale=scale, + kv_scale_ptr=kv_scale_ptr if kv_scale_ptr is not None else x, ) @@ -247,8 +295,9 @@ def _reduce_rp( page_size: int, dim: int, reduce_type: ReduceType, -fp8_type: int = 0, +quant_type: int = 0, scale: float = 1.0, +kv_scale_ptr=None, ): NNZ = loc.shape[0] @@ -264,75 +313,110 @@ def _reduce_rp( PAGE_SIZE=page_size, REDUCE_TYPE=reduce_type.value, DIM=dim, - FP8_TYPE=fp8_type, + QUANT_TYPE=quant_type, scale=scale, + kv_scale_ptr=kv_scale_ptr if kv_scale_ptr is not None else x, ) @triton.jit def reduce_pr_kernel( x, output, loc, -x_D0: tl.constexpr, -x_D1: tl.constexpr, +x_D0: tl.constexpr, # rows per page +x_D1: tl.constexpr, # cols per page NUM_KV_HEAD: tl.constexpr, PAGE_SIZE: tl.constexpr, -REDUCE_TYPE: tl.constexpr, -DIM: tl.constexpr, -FP8_TYPE: tl.constexpr, -scale, +REDUCE_TYPE: tl.constexpr, # 0: Mean, 1: Max, 2: Min, 3: 
L2Norm (not RMS) +DIM: tl.constexpr, # 1: reduce over rows -> len x_D1; 2: reduce over cols -> len x_D0 +QUANT_TYPE: tl.constexpr, # 0: bf16, 1: int8, 2: e4m3, 3: e5m2 +scale, # float: 1.0 for bf16, kv_scale for fp8 +kv_scale_ptr, # pointer to per-token int8 scales (unused when QUANT_TYPE != 1) ): - - token_id = tl.program_id(0) - head_id = tl.program_id(1) - + """ + Layouts: + x: [num_pages * NUM_KV_HEAD, x_D0, x_D1] (page-major, row-major inside page) + output: [num_tokens * NUM_KV_HEAD, vec_len] (token-major; vec_len = x_D1 if DIM==1 else x_D0) + + Behavior: + - token_id comes from pid0; head_id comes from pid1. + - Read loc[token_id] to get absolute position; only proceed at page end. + - Map token -> page via page_idx = (token_position // PAGE_SIZE). + - Read the whole page for this (page_idx, head_id), do reduction, + then write a single vector to output at (token_id, head_id, :). + """ + + # --- Program IDs --- + token_id = tl.program_id(0) # [0 .. num_tokens-1] + head_id = tl.program_id(1) # [0 .. NUM_KV_HEAD-1] + + # --- Trigger only at end-of-page token --- token_position = tl.load(loc + token_id) if (token_position + 1) % PAGE_SIZE != 0: return + # --- Page indexing for x (page-major) --- + # page linear id across heads page_idx = token_position // PAGE_SIZE page_id = page_idx * NUM_KV_HEAD + head_id + # Base element offset into x for this (page_id, head_id) + # x is laid out as contiguous pages, each page is [x_D0, x_D1] x_offset = page_id * x_D0 * x_D1 - rows = tl.arange(0, x_D0)[:, None] - cols = tl.arange(0, x_D1)[None, :] + # 2D row-major addressing within the page + rows = tl.arange(0, x_D0)[:, None] # [x_D0, 1] + cols = tl.arange(0, x_D1)[None, :] # [1, x_D1] src_ptr = x + x_offset + rows * x_D1 + cols - if FP8_TYPE == 1: + # Load the full page block. Assumes full tiles; add masks if needed. 
+ if QUANT_TYPE == 1: + # int8: load int8 values, dequant with per-row scale + raw = tl.load(src_ptr).to(tl.float32) + scale_offset = page_id * x_D0 + tl.arange(0, x_D0) + row_scales = tl.load(kv_scale_ptr + scale_offset).to(tl.float32) + page_block = raw * row_scales[:, None] + elif QUANT_TYPE == 2: raw = tl.load(src_ptr) page_block = raw.to(tl.float8e4nv, bitcast=True).to(tl.float32) * scale - elif FP8_TYPE == 2: + elif QUANT_TYPE == 3: raw = tl.load(src_ptr) page_block = raw.to(tl.float8e5, bitcast=True).to(tl.float32) * scale else: - page_block = tl.load(src_ptr).to(tl.float32) + page_block = tl.load(src_ptr) + # --- Reduction & write-out --- if DIM == 1: - if REDUCE_TYPE == 0: + # Reduce over rows (axis=0) -> per-column vector, length = x_D1 + if REDUCE_TYPE == 0: # Mean + # For better accuracy you may upcast: tl.sum(page_block.to(tl.float32), axis=0) reduce_vec = (tl.sum(page_block, axis=0) / x_D0).to(tl.bfloat16) - elif REDUCE_TYPE == 1: + elif REDUCE_TYPE == 1: # Max reduce_vec = tl.max(page_block, axis=0).to(tl.bfloat16) - elif REDUCE_TYPE == 2: + elif REDUCE_TYPE == 2: # Min reduce_vec = tl.min(page_block, axis=0).to(tl.bfloat16) - else: - s = tl.sum(page_block * page_block, axis=0) + else: # L2Norm (NOT RMS) + s = tl.sum(page_block * page_block, axis=0).to(tl.float32) reduce_vec = tl.sqrt(s).to(tl.bfloat16) + # output is token-major: [num_tokens, NUM_KV_HEAD, x_D1] out_base = (token_id * NUM_KV_HEAD + head_id) * x_D1 dst_ptr = output + out_base + tl.arange(0, x_D1) tl.store(dst_ptr, reduce_vec) else: - if REDUCE_TYPE == 0: + # DIM == 2: Reduce over cols (axis=1) -> per-row vector, length = x_D0 + if REDUCE_TYPE == 0: # Mean reduce_vec = (tl.sum(page_block, axis=1) / x_D1).to(tl.bfloat16) - elif REDUCE_TYPE == 1: + elif REDUCE_TYPE == 1: # Max reduce_vec = tl.max(page_block, axis=1).to(tl.bfloat16) - elif REDUCE_TYPE == 2: + elif REDUCE_TYPE == 2: # Min reduce_vec = tl.min(page_block, axis=1).to(tl.bfloat16) - else: - s = tl.sum(page_block * page_block, 
axis=1) + else: # L2Norm (NOT RMS) + s = tl.sum(page_block * page_block, axis=1).to(tl.float32) reduce_vec = tl.sqrt(s).to(tl.bfloat16) + + # output is token-major: [num_tokens, NUM_KV_HEAD, x_D0] out_base = (token_id * NUM_KV_HEAD + head_id) * x_D0 dst_ptr = output + out_base + tl.arange(0, x_D0) tl.store(dst_ptr, reduce_vec) @@ -345,8 +429,9 @@ def reduce_pr( ctx: Context, dim: int, reduce_type: ReduceType, -fp8_type: int = 0, +quant_type: int = 0, scale: float = 1.0, +kv_scale_ptr=None, ): NNZ = loc.shape[0] @@ -362,8 +447,9 @@ def reduce_pr( PAGE_SIZE=ctx.page_size, REDUCE_TYPE=reduce_type.value, DIM=dim, - FP8_TYPE=fp8_type, + QUANT_TYPE=quant_type, scale=scale, + kv_scale_ptr=kv_scale_ptr if kv_scale_ptr is not None else x, ) def _reduce_pr( @@ -374,8 +460,9 @@ def _reduce_pr( page_size: int, dim: int, reduce_type: ReduceType, -fp8_type: int = 0, +quant_type: int = 0, scale: float = 1.0, +kv_scale_ptr=None, ): NNZ = loc.shape[0] @@ -391,67 +478,92 @@ def _reduce_pr( PAGE_SIZE=page_size, REDUCE_TYPE=reduce_type.value, DIM=dim, - FP8_TYPE=fp8_type, + QUANT_TYPE=quant_type, scale=scale, + kv_scale_ptr=kv_scale_ptr if kv_scale_ptr is not None else x, ) @triton.jit def reduce_rr_kernel( x, output, loc, -x_D0: tl.constexpr, -x_D1: tl.constexpr, +x_D0: tl.constexpr, # rows per token-page +x_D1: tl.constexpr, # cols per token-page NUM_KV_HEAD: tl.constexpr, PAGE_SIZE: tl.constexpr, -REDUCE_TYPE: tl.constexpr, -DIM: tl.constexpr, -FP8_TYPE: tl.constexpr, -scale, +REDUCE_TYPE: tl.constexpr, # 0: Mean, 1: Max, 2: Min, 3: L2Norm (not RMS) +DIM: tl.constexpr, # 1: reduce over rows -> len x_D1; 2: reduce over cols -> len x_D0 +QUANT_TYPE: tl.constexpr, # 0: bf16, 1: int8, 2: e4m3, 3: e5m2 +scale, # float: 1.0 for bf16, kv_scale for fp8 +kv_scale_ptr, # pointer to per-token int8 scales (unused when QUANT_TYPE != 1) ): + """ + Layouts: + x: [num_tokens * NUM_KV_HEAD, x_D0, x_D1] (token-major) + output: [num_tokens * NUM_KV_HEAD, vec_len] (token-major; vec_len = x_D1 if 
DIM==1 else x_D0) + + Only the last token of each page performs the reduction and writes to output[token_id, head_id, :]. + """ - token_id = tl.program_id(0) - head_id = tl.program_id(1) + # program ids + token_id = tl.program_id(0) # 0..num_tokens-1 + head_id = tl.program_id(1) # 0..NUM_KV_HEAD-1 + + # trigger only at end-of-page token token_position = tl.load(loc + token_id) if (token_position + 1) % PAGE_SIZE != 0: return + # ---- read from x (token-major) ---- x_base = (token_id * NUM_KV_HEAD + head_id) * x_D0 * x_D1 - rows = tl.arange(0, x_D0)[:, None] - cols = tl.arange(0, x_D1)[None, :] + rows = tl.arange(0, x_D0)[:, None] # [x_D0, 1] + cols = tl.arange(0, x_D1)[None, :] # [1, x_D1] src_ptr = x + x_base + rows * x_D1 + cols - if FP8_TYPE == 1: + if QUANT_TYPE == 1: + # int8: load int8 values, dequant with per-row scale + raw = tl.load(src_ptr).to(tl.float32) + page_id = (token_position // PAGE_SIZE) * NUM_KV_HEAD + head_id + scale_offset = page_id * x_D0 + tl.arange(0, x_D0) + row_scales = tl.load(kv_scale_ptr + scale_offset).to(tl.float32) + page_blk = raw * row_scales[:, None] + elif QUANT_TYPE == 2: raw = tl.load(src_ptr) page_blk = raw.to(tl.float8e4nv, bitcast=True).to(tl.float32) * scale - elif FP8_TYPE == 2: + elif QUANT_TYPE == 3: raw = tl.load(src_ptr) page_blk = raw.to(tl.float8e5, bitcast=True).to(tl.float32) * scale else: - page_blk = tl.load(src_ptr).to(tl.float32) + page_blk = tl.load(src_ptr) # assumes full page; add masks if needed + # ---- reduce ---- if DIM == 1: - if REDUCE_TYPE == 0: + # over rows -> axis=0 -> vector len x_D1 + if REDUCE_TYPE == 0: # Mean + # For better accuracy you may upcast to fp32 before sum. 
vec = (tl.sum(page_blk, axis=0) / x_D0).to(tl.bfloat16) - elif REDUCE_TYPE == 1: + elif REDUCE_TYPE == 1: # Max vec = tl.max(page_blk, axis=0).to(tl.bfloat16) - elif REDUCE_TYPE == 2: + elif REDUCE_TYPE == 2: # Min vec = tl.min(page_blk, axis=0).to(tl.bfloat16) - else: + else: # L2Norm (NOT RMS) s = tl.sum(page_blk * page_blk, axis=0) vec = tl.sqrt(s).to(tl.bfloat16) + # ---- write to output (token-major) ---- out_base = (token_id * NUM_KV_HEAD + head_id) * x_D1 tl.store(output + out_base + tl.arange(0, x_D1), vec) else: - if REDUCE_TYPE == 0: + # DIM == 2: over cols -> axis=1 -> vector len x_D0 + if REDUCE_TYPE == 0: # Mean vec = (tl.sum(page_blk, axis=1) / x_D1).to(tl.bfloat16) - elif REDUCE_TYPE == 1: + elif REDUCE_TYPE == 1: # Max vec = tl.max(page_blk, axis=1).to(tl.bfloat16) - elif REDUCE_TYPE == 2: + elif REDUCE_TYPE == 2: # Min vec = tl.min(page_blk, axis=1).to(tl.bfloat16) - else: + else: # L2Norm (NOT RMS) s = tl.sum(page_blk * page_blk, axis=1) vec = tl.sqrt(s).to(tl.bfloat16) @@ -459,6 +571,7 @@ def reduce_rr_kernel( tl.store(output + out_base + tl.arange(0, x_D0), vec) + def reduce_rr( x: torch.Tensor, output: torch.Tensor, @@ -466,8 +579,9 @@ def reduce_rr( ctx: Context, dim: int, reduce_type: ReduceType, -fp8_type: int = 0, +quant_type: int = 0, scale: float = 1.0, +kv_scale_ptr=None, ): NNZ = loc.shape[0] @@ -483,8 +597,9 @@ def reduce_rr( PAGE_SIZE=ctx.page_size, REDUCE_TYPE=reduce_type.value, DIM=dim, - FP8_TYPE=fp8_type, + QUANT_TYPE=quant_type, scale=scale, + kv_scale_ptr=kv_scale_ptr if kv_scale_ptr is not None else x, ) @@ -496,8 +611,9 @@ def _reduce_rr( page_size: int, dim: int, reduce_type: ReduceType, -fp8_type: int = 0, +quant_type: int = 0, scale: float = 1.0, +kv_scale_ptr=None, ): NNZ = loc.shape[0] @@ -513,6 +629,7 @@ def _reduce_rr( PAGE_SIZE=page_size, REDUCE_TYPE=reduce_type.value, DIM=dim, - FP8_TYPE=fp8_type, + QUANT_TYPE=quant_type, scale=scale, + kv_scale_ptr=kv_scale_ptr if kv_scale_ptr is not None else x, ) From 
b9eb71786eb8dd12150f02b4dcd15b053328b931 Mon Sep 17 00:00:00 2001 From: UED Date: Mon, 2 Mar 2026 06:33:57 +0000 Subject: [PATCH 05/22] adapt topk kernel from sglang to vortex --- csrc/topk.cu | 1029 +++++++++++++++++++++++++++------- examples/verify_algo.sh | 2 +- examples/verify_algo_fp8.sh | 1 - examples/verify_algo_int8.sh | 1 - 4 files changed, 827 insertions(+), 206 deletions(-) diff --git a/csrc/topk.cu b/csrc/topk.cu index 62d747e..8a48aad 100644 --- a/csrc/topk.cu +++ b/csrc/topk.cu @@ -1,203 +1,826 @@ -#include "register.h" -#include - - -template -__global__ void TopKOutput_F32_Kernel( -const float* __restrict__ score, -const int* __restrict__ dense_kv_indptr, -const int* __restrict__ sparse_kv_indptr, -const int* __restrict__ dense_kv_indices, -int* __restrict__ sparse_kv_indices, -const int topk_val, -const int page_reserved_bos, -const int page_reserved_eos) -{ - const int bx = blockIdx.x; - const int tx = threadIdx.x; - - const int start = dense_kv_indptr[bx] + page_reserved_bos; - const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; - const int nblk = end - start; - if (nblk <= topk_val) return; - - const float* __restrict__ score_blk = score + start; - const int* __restrict__ idx_blk = dense_kv_indices + start; - int* __restrict__ out_blk = sparse_kv_indices + sparse_kv_indptr[bx] + page_reserved_bos; - - float key[ITEM_PER_THREAD]; - int val[ITEM_PER_THREAD]; - - using BLF = cub::BlockLoad; - using BLI = cub::BlockLoad; - using BSI = cub::BlockStore; - using Sort = cub::BlockRadixSort; - - __shared__ union { - typename BLF::TempStorage lf; - typename BLI::TempStorage li; - typename BSI::TempStorage si; - typename Sort::TempStorage sort; - } temp; - - BLF(temp.lf).Load(score_blk, key, nblk, -INFINITY); - __syncthreads(); - BLI(temp.li).Load(idx_blk, val, nblk, 0); - __syncthreads(); - - Sort(temp.sort).SortDescending(key, val); - __syncthreads(); - - const int valid_out = min(topk_val, nblk); - BSI(temp.si).Store(out_blk, /*per-thread 
regs*/ val, valid_out); -} - - -template -__global__ void TopKOutput_BF16_Kernel( -const __nv_bfloat16* __restrict__ score, -const int* __restrict__ dense_kv_indptr, -const int* __restrict__ sparse_kv_indptr, -const int* __restrict__ dense_kv_indices, -int* __restrict__ sparse_kv_indices, -const int topk_val, -const int page_reserved_bos, -const int page_reserved_eos) -{ - const int bx = blockIdx.x; - const int tx = threadIdx.x; - - const int start = dense_kv_indptr[bx] + page_reserved_bos; - const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; - const int nblk = end - start; - if (nblk <= topk_val) return; - - const __nv_bfloat16* __restrict__ score_blk = score + start; - const int* __restrict__ idx_blk = dense_kv_indices + start; - int* __restrict__ out_blk = sparse_kv_indices + sparse_kv_indptr[bx] + page_reserved_bos; - - const __nv_bfloat16 ninf_bf16 = __float2bfloat16(-CUDART_INF_F); - - __nv_bfloat16 key_bf16[ITEM_PER_THREAD]; - float key[ITEM_PER_THREAD]; - int val[ITEM_PER_THREAD]; - - using BLF = cub::BlockLoad<__nv_bfloat16, NUM_THREADS, ITEM_PER_THREAD, cub::BLOCK_LOAD_WARP_TRANSPOSE>; - using BLI = cub::BlockLoad; - using BSI = cub::BlockStore; - using Sort = cub::BlockRadixSort; - - __shared__ union { - typename BLF::TempStorage lf; - typename BLI::TempStorage li; - typename BSI::TempStorage si; - typename Sort::TempStorage sort; - } temp; - - BLF(temp.lf).Load(score_blk, key_bf16, nblk, ninf_bf16); - - #pragma unroll - for (int i = 0; i < ITEM_PER_THREAD; ++i){ - key[i] = __bfloat162float(key_bf16[i]); - } - __syncthreads(); - - BLI(temp.li).Load(idx_blk, val, nblk, 0); - __syncthreads(); - - Sort(temp.sort).SortDescending(key, val); - __syncthreads(); - - const int valid_out = min(topk_val, nblk); - BSI(temp.si).Store(out_blk, /*per-thread regs*/ val, valid_out); -} - - - -void topk_output( -const at::Tensor& x, -const at::Tensor& dense_kv_indptr, -const at::Tensor& sparse_kv_indptr, -const at::Tensor& dense_kv_indices, -at::Tensor& 
sparse_kv_indices, -const int64_t eff_batch_size, -const int64_t topk_val, -const int64_t reserved_bos, -const int64_t reserved_eos, -const int64_t max_num_pages -){ - - - dim3 nblks(eff_batch_size); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - if (max_num_pages <= 128){ - TopKOutput_BF16_Kernel<128, 1><<>>( - reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), - dense_kv_indptr.data_ptr(), - sparse_kv_indptr.data_ptr(), - dense_kv_indices.data_ptr(), - sparse_kv_indices.data_ptr(), - topk_val, - reserved_bos, - reserved_eos - ); - } else if (max_num_pages <= 256){ - TopKOutput_BF16_Kernel<128, 2><<>>( - reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), - dense_kv_indptr.data_ptr(), - sparse_kv_indptr.data_ptr(), - dense_kv_indices.data_ptr(), - sparse_kv_indices.data_ptr(), - topk_val, - reserved_bos, - reserved_eos - ); - } else if (max_num_pages <= 512){ - TopKOutput_BF16_Kernel<128, 4><<>>( - reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), - dense_kv_indptr.data_ptr(), - sparse_kv_indptr.data_ptr(), - dense_kv_indices.data_ptr(), - sparse_kv_indices.data_ptr(), - topk_val, - reserved_bos, - reserved_eos - ); - } else if (max_num_pages <= 1024){ - TopKOutput_BF16_Kernel<256, 4><<>>( - reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), - dense_kv_indptr.data_ptr(), - sparse_kv_indptr.data_ptr(), - dense_kv_indices.data_ptr(), - sparse_kv_indices.data_ptr(), - topk_val, - reserved_bos, - reserved_eos - ); - } else if (max_num_pages <= 2048){ - TopKOutput_BF16_Kernel<256, 8><<>>( - reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), - dense_kv_indptr.data_ptr(), - sparse_kv_indptr.data_ptr(), - dense_kv_indices.data_ptr(), - sparse_kv_indices.data_ptr(), - topk_val, - reserved_bos, - reserved_eos - ); - } else if (max_num_pages <= 4096){ - TopKOutput_BF16_Kernel<512, 8><<>>( - reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), - dense_kv_indptr.data_ptr(), - sparse_kv_indptr.data_ptr(), - dense_kv_indices.data_ptr(), - sparse_kv_indices.data_ptr(), - 
topk_val, - reserved_bos, - reserved_eos - ); - } else { - TORCH_CHECK(false); - } - -} +/** + * @NOTE: This file is adapted from + * https://github.com/tile-ai/tilelang/blob/main/examples/deepseek_v32/topk_selector.py + * We: + * 1. adapt from tilelang to pure cuda + * 2. optimize the performance a little + * 3. fix the potential illegal memory access + */ + #include + #include + #include + #include + #include + #include + #include + #include + #include + + #include + #include + #include + + namespace { + + constexpr int TopK = 2048; + constexpr int kThreadsPerBlock = 1024; + + #ifdef USE_ROCM + // On ROCm, the per-workgroup LDS budget depends on the target arch, so we inject a + // per-arch value from `setup_rocm.py` via `-DSGL_TOPK_DYNAMIC_SMEM_BYTES=...`. + #ifdef SGL_TOPK_DYNAMIC_SMEM_BYTES + constexpr size_t kSmem = static_cast(SGL_TOPK_DYNAMIC_SMEM_BYTES); + #else + constexpr size_t kSmem = 48 * 1024; // bytes + #endif + #else + // Reduced from 128KB to 32KB to improve occupancy. + // Each radix pass needs at most ~TopK candidates in the threshold bin, + // so 4K entries per round (2 rounds = 8K entries = 32KB) is sufficient. + constexpr size_t kSmem = 8 * 1024 * sizeof(uint32_t); // 32KB (bytes) + #endif + + struct FastTopKParams { + const float* __restrict__ input; // [B, input_stride] + const int32_t* __restrict__ row_starts; // [B] + int32_t* __restrict__ indices; // [B, TopK] + int32_t* __restrict__ lengths; // [B] + int64_t input_stride; + }; + + // when length <= TopK, we can directly write the indices + __device__ void naive_topk_cuda(const float* __restrict__ score, int32_t* __restrict__ indice, int32_t length) { + const auto tid = threadIdx.x; + for (int i = tid; i < TopK; i += kThreadsPerBlock) { + indice[i] = (i < length) ? 
i : -1; + } + } + + // keep the first `length` entries, set others to -1 + __device__ void naive_topk_transform( + const float* __restrict__ score, + int32_t length, + int32_t* __restrict__ dst_page_table, + const int32_t* __restrict__ src_page_table) { + const auto tid = threadIdx.x; + for (auto i = tid; i < TopK; i += kThreadsPerBlock) { + dst_page_table[i] = (i < length) ? src_page_table[i] : -1; + } + } + + // keep the first `length` entries, set others to -1 + __device__ void naive_topk_transform_ragged( + const float* __restrict__ score, int32_t length, int32_t* __restrict__ topk_indices_ragged, int32_t offset) { + const auto tid = threadIdx.x; + for (auto i = tid; i < TopK; i += kThreadsPerBlock) { + topk_indices_ragged[i] = (i < length) ? static_cast(i) + offset : -1; + } + } + + __device__ __forceinline__ auto convert_to_uint8(float x) -> uint8_t { + __half h = __float2half_rn(x); + uint16_t bits = __half_as_ushort(h); + uint16_t key = (bits & 0x8000) ? static_cast(~bits) : static_cast(bits | 0x8000); + return static_cast(key >> 8); + } + + __device__ __forceinline__ auto convert_to_uint32(float x) -> uint32_t { + uint32_t bits = __float_as_uint(x); + return (bits & 0x80000000u) ? 
~bits : (bits | 0x80000000u); + } + + __device__ void fast_topk_cuda_tl(const float* __restrict__ input, int* __restrict__ index, int row_start, int length) { + // An optimized topk kernel copied from tilelang kernel + // We assume length > TopK here, or it will crash + int topk = TopK; + constexpr auto BLOCK_SIZE = 1024; + constexpr auto RADIX = 256; + constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); + + alignas(128) __shared__ int s_histogram_buf[2][RADIX + 128]; + alignas(128) __shared__ int s_counter; + alignas(128) __shared__ int s_threshold_bin_id; + alignas(128) __shared__ int s_num_input[2]; + + auto& s_histogram = s_histogram_buf[0]; + // allocate for two rounds + extern __shared__ int s_input_idx[][SMEM_INPUT_SIZE]; + + const int tx = threadIdx.x; + + // stage 1: 8bit coarse histogram + if (tx < RADIX + 1) s_histogram[tx] = 0; + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = convert_to_uint8(input[idx + row_start]); + ::atomicAdd(&s_histogram[bin], 1); + } + __syncthreads(); + + const auto run_cumsum = [&] { + #pragma unroll 8 + for (int i = 0; i < 8; ++i) { + static_assert(1 << 8 == RADIX); + if (C10_LIKELY(tx < RADIX)) { + const auto j = 1 << i; + const auto k = i & 1; + auto value = s_histogram_buf[k][tx]; + if (tx < RADIX - j) { + value += s_histogram_buf[k][tx + j]; + } + s_histogram_buf[k ^ 1][tx] = value; + } + __syncthreads(); + } + }; + + run_cumsum(); + if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { + s_threshold_bin_id = tx; + s_num_input[0] = 0; + s_counter = 0; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + topk -= s_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = static_cast(convert_to_uint8(input[idx + row_start])); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + return; + } 
else { + __syncthreads(); + if (tx < RADIX + 1) { + s_histogram[tx] = 0; + } + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto raw_input = input[idx + row_start]; + const auto bin = static_cast(convert_to_uint8(raw_input)); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + const auto pos = ::atomicAdd(&s_num_input[0], 1); + /// NOTE: (dark) fuse the histogram computation here + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + s_input_idx[0][pos] = idx; + const auto bin = convert_to_uint32(raw_input); + const auto sub_bin = (bin >> 24) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } + } + __syncthreads(); + } + + // stage 2: refine with 8bit radix passes + #pragma unroll 4 + for (int round = 0; round < 4; ++round) { + __shared__ int s_last_remain; + const auto r_idx = round % 2; + + // clip here to prevent overflow + const auto _raw_num_input = s_num_input[r_idx]; + const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) ? 
_raw_num_input : int(SMEM_INPUT_SIZE); + + run_cumsum(); + if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { + s_threshold_bin_id = tx; + s_num_input[r_idx ^ 1] = 0; + s_last_remain = topk - s_histogram[tx + 1]; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + topk -= s_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(input[idx + row_start]) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + break; + } else { + __syncthreads(); + if (tx < RADIX + 1) { + s_histogram[tx] = 0; + } + __syncthreads(); + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto raw_input = input[idx + row_start]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + if (round == 3) { + const auto pos = ::atomicAdd(&s_last_remain, -1); + if (pos > 0) { + index[TopK - pos] = idx; + } + } else { + const auto pos = ::atomicAdd(&s_num_input[r_idx ^ 1], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + /// NOTE: (dark) fuse the histogram computation here + s_input_idx[r_idx ^ 1][pos] = idx; + const auto bin = convert_to_uint32(raw_input); + const auto sub_bin = (bin >> (offset - 8)) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } + } + } + __syncthreads(); + } + } + } + + __global__ __launch_bounds__(kThreadsPerBlock) // topk + void topk_kernel(const FastTopKParams params) { + const auto& [input, row_starts, indices, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto row_start = row_starts == nullptr ? 
0 : row_starts[bid]; + const auto length = lengths[bid]; + const auto indice = indices + bid * TopK; + const auto score = input + bid * input_stride; + if (length <= TopK) { + return naive_topk_cuda(score, indice, length); + } else { + return fast_topk_cuda_tl(score, indice, row_start, length); + } + } + + __global__ __launch_bounds__(kThreadsPerBlock) // decode + void topk_transform_decode_kernel( + const FastTopKParams params, + int32_t* __restrict__ dst_page_table, + const int32_t* __restrict__ src_page_table, + const int64_t src_stride) { + const auto& [input, _1, _2, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto tid = threadIdx.x; + const auto row_start = 0; + const auto length = lengths[bid]; + const auto src_page_entry = src_page_table + bid * src_stride; + const auto dst_page_entry = dst_page_table + bid * TopK; + const auto score = input + bid * input_stride; + if (length <= TopK) { + return naive_topk_transform(score, length, dst_page_entry, src_page_entry); + } else { + __shared__ int s_indices[TopK]; + fast_topk_cuda_tl(score, s_indices, row_start, length); + // copy src[s_indices] to dst, we manually unroll here + static_assert(TopK % kThreadsPerBlock == 0); + static_assert(TopK / kThreadsPerBlock == 2); + const auto idx_0 = tid; + const auto pos_0 = s_indices[idx_0]; + dst_page_entry[idx_0] = src_page_entry[pos_0]; + const auto idx_1 = tid + kThreadsPerBlock; + const auto pos_1 = s_indices[idx_1]; + dst_page_entry[idx_1] = src_page_entry[pos_1]; + } + } + + __global__ __launch_bounds__(kThreadsPerBlock) // prefill + void topk_transform_prefill_kernel( + const FastTopKParams params, + int32_t* __restrict__ dst_page_table, + const int32_t* __restrict__ src_page_table, + const int64_t src_stride, + const int32_t* __restrict__ cu_seqlens_q, + const int64_t prefill_bs) { + const auto& [input, row_starts, _, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto tid = 
threadIdx.x; + const auto length = lengths[bid]; + const auto row_start = row_starts == nullptr ? 0 : row_starts[bid]; + const auto dst_page_entry = dst_page_table + bid * TopK; + const auto score = input + bid * input_stride; + + /// NOTE: prefill bs is usually small, we can just use a simple loop here + /// We ensure that last cu_seqlens is equal to number of blocks launched + __shared__ const int32_t* s_src_page_entry; + if (C10_LIKELY(prefill_bs <= kThreadsPerBlock)) { + if (tid < prefill_bs) { + if (bid >= cu_seqlens_q[tid] && bid < cu_seqlens_q[tid + 1]) { + s_src_page_entry = src_page_table + tid * src_stride; + } + } + } else { + for (int64_t i = tid; i < prefill_bs; i += kThreadsPerBlock) { + if (bid >= cu_seqlens_q[i] && bid < cu_seqlens_q[i + 1]) { + s_src_page_entry = src_page_table + i * src_stride; + } + } + } + __syncthreads(); + const auto src_page_entry = s_src_page_entry; + + if (length <= TopK) { + return naive_topk_transform(score, length, dst_page_entry, src_page_entry); + } else { + __shared__ int s_indices[TopK]; + fast_topk_cuda_tl(score, s_indices, row_start, length); + // copy src[s_indices] to dst, we manually unroll here + static_assert(TopK % kThreadsPerBlock == 0); + static_assert(TopK / kThreadsPerBlock == 2); + const auto idx_0 = tid; + const auto pos_0 = s_indices[idx_0]; + dst_page_entry[idx_0] = src_page_entry[pos_0]; + const auto idx_1 = tid + kThreadsPerBlock; + const auto pos_1 = s_indices[idx_1]; + dst_page_entry[idx_1] = src_page_entry[pos_1]; + } + } + + __global__ __launch_bounds__(kThreadsPerBlock) // prefill, ragged kv + void topk_transform_prefill_ragged_kernel( + const FastTopKParams params, + int32_t* __restrict__ topk_indices_ragged, + const int32_t* __restrict__ topk_indices_offset) { + const auto& [input, row_starts, _, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto tid = threadIdx.x; + const auto row_start = row_starts == nullptr ? 
0 : row_starts[bid]; + const auto length = lengths[bid]; + const auto dst_indices_entry = topk_indices_ragged + bid * TopK; + const auto score = input + bid * input_stride; + const auto offset = topk_indices_offset[bid]; + + if (length <= TopK) { + return naive_topk_transform_ragged(score, length, dst_indices_entry, offset); + } else { + __shared__ int s_indices[TopK]; + fast_topk_cuda_tl(score, s_indices, row_start, length); + // copy src[s_indices] to dst, we manually unroll here + static_assert(TopK % kThreadsPerBlock == 0); + static_assert(TopK / kThreadsPerBlock == 2); + const auto idx_0 = tid; + const auto pos_0 = s_indices[idx_0]; + dst_indices_entry[idx_0] = pos_0 + offset; + const auto idx_1 = tid + kThreadsPerBlock; + const auto pos_1 = s_indices[idx_1]; + dst_indices_entry[idx_1] = pos_1 + offset; + } + } + + auto get_params( + const at::Tensor& score, + const at::Tensor& lengths, + std::optional row_starts_opt = std::nullopt, + std::optional indices_opt = std::nullopt) -> FastTopKParams { + const auto B = score.size(0); + TORCH_CHECK(score.dim() == 2 && score.stride(1) == 1); + if (row_starts_opt.has_value()) { + const auto& row_starts = row_starts_opt.value(); + TORCH_CHECK(row_starts.dim() == 1); + TORCH_CHECK(row_starts.size(0) == B); + } + TORCH_CHECK(lengths.dim() == 1 && lengths.is_contiguous()); + TORCH_CHECK(lengths.size(0) == B); + int32_t* indices_data_ptr = nullptr; + if (indices_opt.has_value()) { + const auto& indices = indices_opt.value(); + TORCH_CHECK(indices.dim() == 2 && indices.is_contiguous()); + TORCH_CHECK(indices.size(0) == B); + TORCH_CHECK(indices.size(1) == TopK); + indices_data_ptr = indices.data_ptr(); + } + + return FastTopKParams{ + .input = score.data_ptr(), + .row_starts = row_starts_opt.has_value() ? 
row_starts_opt->data_ptr() : nullptr, + .indices = indices_data_ptr, + .lengths = lengths.data_ptr(), + .input_stride = score.stride(0), + }; + } + + template + void setup_kernel_smem_once() { + [[maybe_unused]] + static const auto result = [] { + #ifdef USE_ROCM + // hipify will turn cudaFuncSetAttribute -> hipFuncSetAttribute. On ROCm, + // hipFuncSetAttribute expects `const void*` and hipcc does not accept passing + // a function pointer directly, so cast explicitly. + return ::cudaFuncSetAttribute( + reinterpret_cast(f), ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); + #else + // CUDA: keep original behavior (no cast needed). + return ::cudaFuncSetAttribute(f, ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); + #endif + }(); + TORCH_CHECK(result == cudaSuccess, "set_up_kernel_once failed:", ::cudaGetErrorString(result)); + } + + // ====================================================================== + // Vortex integration: BOS/EOS-aware segmented TopK with index remapping + // ====================================================================== + + template + __device__ __forceinline__ float vortex_to_float(T x); + + template <> + __device__ __forceinline__ float vortex_to_float(float x) { return x; } + + template <> + __device__ __forceinline__ float vortex_to_float<__nv_bfloat16>(__nv_bfloat16 x) { + return __bfloat162float(x); + } + + constexpr int VORTEX_MAX_TOPK = 2048; + + // Templated version of fast_topk_cuda_tl: + // - ScoreT: float or __nv_bfloat16 + // - target_k: runtime parameter (replaces compile-time TopK) + template + __device__ void fast_topk_vortex( + const ScoreT* __restrict__ input, + int* __restrict__ index, + int row_start, + int length, + int target_k) + { + int topk = target_k; + constexpr auto BLOCK_SIZE = 1024; + constexpr auto RADIX = 256; + constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); + + alignas(128) __shared__ int vh_histogram_buf[2][RADIX + 128]; + alignas(128) __shared__ int 
vh_counter; + alignas(128) __shared__ int vh_threshold_bin_id; + alignas(128) __shared__ int vh_num_input[2]; + + auto& vh_histogram = vh_histogram_buf[0]; + extern __shared__ int vh_input_idx[][SMEM_INPUT_SIZE]; + + const int tx = threadIdx.x; + + // Stage 1: 8-bit coarse histogram + if (tx < RADIX + 1) vh_histogram[tx] = 0; + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = convert_to_uint8(vortex_to_float(input[idx + row_start])); + ::atomicAdd(&vh_histogram[bin], 1); + } + __syncthreads(); + + const auto run_cumsum = [&] { + #pragma unroll 8 + for (int i = 0; i < 8; ++i) { + static_assert(1 << 8 == RADIX); + if (C10_LIKELY(tx < RADIX)) { + const auto j = 1 << i; + const auto k = i & 1; + auto value = vh_histogram_buf[k][tx]; + if (tx < RADIX - j) { + value += vh_histogram_buf[k][tx + j]; + } + vh_histogram_buf[k ^ 1][tx] = value; + } + __syncthreads(); + } + }; + + run_cumsum(); + if (tx < RADIX && vh_histogram[tx] > topk && vh_histogram[tx + 1] <= topk) { + vh_threshold_bin_id = tx; + vh_num_input[0] = 0; + vh_counter = 0; + } + __syncthreads(); + + const auto threshold_bin = vh_threshold_bin_id; + topk -= vh_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = static_cast( + convert_to_uint8(vortex_to_float(input[idx + row_start]))); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&vh_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + return; + } else { + __syncthreads(); + if (tx < RADIX + 1) vh_histogram[tx] = 0; + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto raw_input = vortex_to_float(input[idx + row_start]); + const auto bin = static_cast(convert_to_uint8(raw_input)); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&vh_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + const auto pos = ::atomicAdd(&vh_num_input[0], 1); + if (C10_LIKELY(pos < 
SMEM_INPUT_SIZE)) { + vh_input_idx[0][pos] = idx; + const auto b32 = convert_to_uint32(raw_input); + const auto sub_bin = (b32 >> 24) & 0xFF; + ::atomicAdd(&vh_histogram[sub_bin], 1); + } + } + } + __syncthreads(); + } + + // Stage 2: refine with 8-bit radix passes + #pragma unroll 4 + for (int round = 0; round < 4; ++round) { + __shared__ int vh_last_remain; + const auto r_idx = round % 2; + + const auto _raw_num_input = vh_num_input[r_idx]; + const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) + ? _raw_num_input + : int(SMEM_INPUT_SIZE); + + run_cumsum(); + if (tx < RADIX && vh_histogram[tx] > topk && vh_histogram[tx + 1] <= topk) { + vh_threshold_bin_id = tx; + vh_num_input[r_idx ^ 1] = 0; + vh_last_remain = topk - vh_histogram[tx + 1]; + } + __syncthreads(); + + const auto threshold_bin = vh_threshold_bin_id; + topk -= vh_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = vh_input_idx[r_idx][i]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32( + vortex_to_float(input[idx + row_start])) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&vh_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + break; + } else { + __syncthreads(); + if (tx < RADIX + 1) vh_histogram[tx] = 0; + __syncthreads(); + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = vh_input_idx[r_idx][i]; + const auto raw_input = vortex_to_float(input[idx + row_start]); + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&vh_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + if (round == 3) { + const auto pos = ::atomicAdd(&vh_last_remain, -1); + if (pos > 0) { + index[target_k - pos] = idx; + } + } else { + const auto pos = ::atomicAdd(&vh_num_input[r_idx ^ 1], 1); + if (C10_LIKELY(pos < 
SMEM_INPUT_SIZE)) { + vh_input_idx[r_idx ^ 1][pos] = idx; + const auto b32 = convert_to_uint32(raw_input); + const auto sub_bin = (b32 >> (offset - 8)) & 0xFF; + ::atomicAdd(&vh_histogram[sub_bin], 1); + } + } + } + } + __syncthreads(); + } + } + } + + // Wrapper kernel: one CUDA block per batch*head segment + template + __global__ __launch_bounds__(kThreadsPerBlock) + void TopKOutput_Kernel( + const ScoreT* __restrict__ score, + const int* __restrict__ dense_kv_indptr, + const int* __restrict__ sparse_kv_indptr, + const int* __restrict__ dense_kv_indices, + int* __restrict__ sparse_kv_indices, + const int topk_val, + const int page_reserved_bos, + const int page_reserved_eos) + { + const int bx = blockIdx.x; + + const int start = dense_kv_indptr[bx] + page_reserved_bos; + const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; + const int nblk = end - start; + if (nblk <= topk_val) return; + + const ScoreT* __restrict__ score_blk = score + start; + const int* __restrict__ idx_blk = dense_kv_indices + start; + int* __restrict__ out_blk = sparse_kv_indices + + sparse_kv_indptr[bx] + + page_reserved_bos; + + __shared__ int s_indices[VORTEX_MAX_TOPK]; + fast_topk_vortex(score_blk, s_indices, 0, nblk, topk_val); + __syncthreads(); + + // Remap position indices -> page indices via dense_kv_indices + const int tx = threadIdx.x; + for (int i = tx; i < topk_val; i += kThreadsPerBlock) { + out_blk[i] = idx_blk[s_indices[i]]; + } + } + + } // namespace + + #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") + + void fast_topk_interface( + const at::Tensor& score, at::Tensor& indices, const at::Tensor& lengths, std::optional row_starts_opt) { + CHECK_CUDA(score); + CHECK_CUDA(indices); + if (row_starts_opt.has_value()) { + CHECK_CUDA(row_starts_opt.value()); + } + CHECK_CUDA(lengths); + const auto params = get_params(score, lengths, row_starts_opt, indices); + const auto B = score.size(0); + const auto stream = 
at::cuda::getCurrentCUDAStream().stream(); + const auto grid = dim3{static_cast(B)}; + const auto block = dim3{kThreadsPerBlock}; + setup_kernel_smem_once(); + topk_kernel<<>>(params); + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); + } + + void fast_topk_transform_interface( + const at::Tensor& score, + const at::Tensor& lengths, + at::Tensor& dst_page_table, + const at::Tensor& src_page_table, + const at::Tensor& cu_seqlens_q, + std::optional row_starts_opt) { + CHECK_CUDA(score); + CHECK_CUDA(lengths); + CHECK_CUDA(dst_page_table); + CHECK_CUDA(src_page_table); + CHECK_CUDA(cu_seqlens_q); + if (row_starts_opt.has_value()) { + CHECK_CUDA(row_starts_opt.value()); + } + const auto params = get_params(score, lengths, row_starts_opt); + const auto B = score.size(0); + TORCH_CHECK(dst_page_table.dim() == 2 && dst_page_table.is_contiguous()); + TORCH_CHECK(src_page_table.dim() == 2 && src_page_table.stride(1) == 1); + TORCH_CHECK(cu_seqlens_q.dim() == 1 && cu_seqlens_q.is_contiguous()); + const auto prefill_bs = cu_seqlens_q.size(0) - 1; + TORCH_CHECK(dst_page_table.size(0) == B); + TORCH_CHECK(dst_page_table.size(1) == TopK); + TORCH_CHECK(src_page_table.size(0) == prefill_bs); + TORCH_CHECK(prefill_bs <= B); // prefill_bs should be smaller than expanded bs + + // launch kernel + const auto stream = at::cuda::getCurrentCUDAStream().stream(); + const auto grid = dim3{static_cast(B)}; + const auto block = dim3{kThreadsPerBlock}; + const auto src_stride = src_page_table.stride(0); + + // dispatch to decode or prefill + // extend and draft extend: row_starts_opt is not null, invokes the prefill kernel + // decode: row_starts_opt is null, invokes the decode kernel + // target verify: row_starts_opt is null, invokes the prefill kernel + const auto is_decode = !row_starts_opt.has_value() && prefill_bs == B; + if (is_decode) { + setup_kernel_smem_once(); + topk_transform_decode_kernel<<>>( + 
params, dst_page_table.data_ptr(), src_page_table.data_ptr(), src_stride); + } else { + setup_kernel_smem_once(); + topk_transform_prefill_kernel<<>>( + params, + dst_page_table.data_ptr(), + src_page_table.data_ptr(), + src_stride, + cu_seqlens_q.data_ptr(), + prefill_bs); + } + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); + } + + void fast_topk_transform_ragged_interface( + const at::Tensor& score, + const at::Tensor& lengths, + at::Tensor& topk_indices_ragged, + const at::Tensor& topk_indices_offset, + std::optional row_starts_opt) { + CHECK_CUDA(score); + CHECK_CUDA(lengths); + CHECK_CUDA(topk_indices_ragged); + CHECK_CUDA(topk_indices_offset); + if (row_starts_opt.has_value()) { + CHECK_CUDA(row_starts_opt.value()); + } + + const auto params = get_params(score, lengths, row_starts_opt); + const auto B = score.size(0); + TORCH_CHECK(topk_indices_ragged.dim() == 2 && topk_indices_ragged.is_contiguous()); + TORCH_CHECK(topk_indices_offset.dim() == 1); + + TORCH_CHECK(topk_indices_ragged.size(0) == B); + TORCH_CHECK(topk_indices_ragged.size(1) == TopK); + TORCH_CHECK(topk_indices_offset.size(0) == B); + + // launch kernel + const auto stream = at::cuda::getCurrentCUDAStream().stream(); + const auto grid = dim3{static_cast(B)}; + const auto block = dim3{kThreadsPerBlock}; + + setup_kernel_smem_once(); + topk_transform_prefill_ragged_kernel<<>>( + params, topk_indices_ragged.data_ptr(), topk_indices_offset.data_ptr()); + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); + } + + // ====================================================================== + // Vortex host entry point — same interface as topk_output in topk.cu + // ====================================================================== + void topk_output( + const at::Tensor& x, + const at::Tensor& dense_kv_indptr, + const at::Tensor& 
sparse_kv_indptr, + const at::Tensor& dense_kv_indices, + at::Tensor& sparse_kv_indices, + const int64_t eff_batch_size, + const int64_t topk_val, + const int64_t reserved_bos, + const int64_t reserved_eos, + const int64_t max_num_pages) + { + TORCH_CHECK(topk_val <= VORTEX_MAX_TOPK, + "topk_output: topk_val (", topk_val, + ") exceeds VORTEX_MAX_TOPK (", VORTEX_MAX_TOPK, ")"); + + dim3 nblks(eff_batch_size); + dim3 nthreads(kThreadsPerBlock); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + + if (x.scalar_type() == at::ScalarType::BFloat16) { + setup_kernel_smem_once, kSmem>(); + TopKOutput_Kernel<__nv_bfloat16><<>>( + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + topk_val, + reserved_bos, + reserved_eos); + } else if (x.scalar_type() == at::ScalarType::Float) { + setup_kernel_smem_once, kSmem>(); + TopKOutput_Kernel<<>>( + x.data_ptr(), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + topk_val, + reserved_bos, + reserved_eos); + } else { + TORCH_CHECK(false, + "topk_output: unsupported dtype ", + x.scalar_type()); + } + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, + "topk_output kernel failed: ", ::cudaGetErrorString(result)); + } \ No newline at end of file diff --git a/examples/verify_algo.sh b/examples/verify_algo.sh index d80f09a..7487708 100644 --- a/examples/verify_algo.sh +++ b/examples/verify_algo.sh @@ -1,6 +1,6 @@ #!/usr/bin/env bash set -e -export CUDA_VISIBLE_DEVICES=1 +# export CUDA_VISIBLE_DEVICES=0 sparse_algos=( "block_sparse_attention" diff --git a/examples/verify_algo_fp8.sh b/examples/verify_algo_fp8.sh index 7f266e5..fd85dad 100755 --- a/examples/verify_algo_fp8.sh +++ b/examples/verify_algo_fp8.sh @@ -1,6 +1,5 @@ #!/usr/bin/env bash set -e -export CUDA_VISIBLE_DEVICES=3 sparse_algos=( 
"block_sparse_attention" diff --git a/examples/verify_algo_int8.sh b/examples/verify_algo_int8.sh index 4cf1366..e57c63f 100644 --- a/examples/verify_algo_int8.sh +++ b/examples/verify_algo_int8.sh @@ -1,6 +1,5 @@ #!/usr/bin/env bash set -e -export CUDA_VISIBLE_DEVICES=2 sparse_algos=( "block_sparse_attention" From ede862425998eaa1e5a3b449dec274798d0ffb1c Mon Sep 17 00:00:00 2001 From: UED Date: Mon, 9 Mar 2026 05:22:43 +0000 Subject: [PATCH 06/22] add parameter to switch between two topk kernels (naive or sglang) --- csrc/register.cc | 1 + csrc/register.h | 12 + csrc/topk.cu | 1029 ++++++--------------------- examples/verify_algo.py | 15 +- examples/verify_algo_fp8.sh | 1 + examples/verify_algo_int8.sh | 1 + vortex_torch/indexer/context.py | 4 +- vortex_torch/indexer/output_func.py | 58 +- 8 files changed, 276 insertions(+), 845 deletions(-) diff --git a/csrc/register.cc b/csrc/register.cc index fd9d4eb..532fcdf 100644 --- a/csrc/register.cc +++ b/csrc/register.cc @@ -8,6 +8,7 @@ PYBIND11_MODULE(vortex_torch_C, m){ m.def("Chunkwise_NH2HN_Transpose", &Chunkwise_NH2HN_Transpose); m.def("Chunkwise_HN2NH_Transpose", &Chunkwise_HN2NH_Transpose); m.def("topk_output", &topk_output); + m.def("topk_output_sglang", &topk_output_sglang); m.def("sglang_plan_decode_fa3", &sglang_plan_decode_fa3); m.def("sglang_plan_prefill_fa3", &sglang_plan_prefill_fa3); m.def("Chunkwise_HN2NH_Transpose_FA3", &Chunkwise_HN2NH_Transpose_FA3); diff --git a/csrc/register.h b/csrc/register.h index 92499ed..b81168b 100644 --- a/csrc/register.h +++ b/csrc/register.h @@ -85,6 +85,18 @@ const int64_t reserved_eos, const int64_t max_seq_lengths ); +void topk_output_sglang( +const at::Tensor& x, +const at::Tensor& dense_kv_indptr, +const at::Tensor& sparse_kv_indptr, +const at::Tensor& dense_kv_indices, +at::Tensor& sparse_kv_indices, +const int64_t eff_batch_size, +const int64_t topk_val, +const int64_t reserved_bos, +const int64_t reserved_eos, +const int64_t max_seq_lengths +); void 
sglang_plan_decode_fa3( const at::Tensor& cached_seq_lens, diff --git a/csrc/topk.cu b/csrc/topk.cu index 8a48aad..3aa49b9 100644 --- a/csrc/topk.cu +++ b/csrc/topk.cu @@ -1,826 +1,203 @@ -/** - * @NOTE: This file is adapted from - * https://github.com/tile-ai/tilelang/blob/main/examples/deepseek_v32/topk_selector.py - * We: - * 1. adapt from tilelang to pure cuda - * 2. optimize the performance a little - * 3. fix the potential illegal memory access - */ - #include - #include - #include - #include - #include - #include - #include - #include - #include - - #include - #include - #include - - namespace { - - constexpr int TopK = 2048; - constexpr int kThreadsPerBlock = 1024; - - #ifdef USE_ROCM - // On ROCm, the per-workgroup LDS budget depends on the target arch, so we inject a - // per-arch value from `setup_rocm.py` via `-DSGL_TOPK_DYNAMIC_SMEM_BYTES=...`. - #ifdef SGL_TOPK_DYNAMIC_SMEM_BYTES - constexpr size_t kSmem = static_cast(SGL_TOPK_DYNAMIC_SMEM_BYTES); - #else - constexpr size_t kSmem = 48 * 1024; // bytes - #endif - #else - // Reduced from 128KB to 32KB to improve occupancy. - // Each radix pass needs at most ~TopK candidates in the threshold bin, - // so 4K entries per round (2 rounds = 8K entries = 32KB) is sufficient. - constexpr size_t kSmem = 8 * 1024 * sizeof(uint32_t); // 32KB (bytes) - #endif - - struct FastTopKParams { - const float* __restrict__ input; // [B, input_stride] - const int32_t* __restrict__ row_starts; // [B] - int32_t* __restrict__ indices; // [B, TopK] - int32_t* __restrict__ lengths; // [B] - int64_t input_stride; - }; - - // when length <= TopK, we can directly write the indices - __device__ void naive_topk_cuda(const float* __restrict__ score, int32_t* __restrict__ indice, int32_t length) { - const auto tid = threadIdx.x; - for (int i = tid; i < TopK; i += kThreadsPerBlock) { - indice[i] = (i < length) ? 
i : -1; - } - } - - // keep the first `length` entries, set others to -1 - __device__ void naive_topk_transform( - const float* __restrict__ score, - int32_t length, - int32_t* __restrict__ dst_page_table, - const int32_t* __restrict__ src_page_table) { - const auto tid = threadIdx.x; - for (auto i = tid; i < TopK; i += kThreadsPerBlock) { - dst_page_table[i] = (i < length) ? src_page_table[i] : -1; - } - } - - // keep the first `length` entries, set others to -1 - __device__ void naive_topk_transform_ragged( - const float* __restrict__ score, int32_t length, int32_t* __restrict__ topk_indices_ragged, int32_t offset) { - const auto tid = threadIdx.x; - for (auto i = tid; i < TopK; i += kThreadsPerBlock) { - topk_indices_ragged[i] = (i < length) ? static_cast(i) + offset : -1; - } - } - - __device__ __forceinline__ auto convert_to_uint8(float x) -> uint8_t { - __half h = __float2half_rn(x); - uint16_t bits = __half_as_ushort(h); - uint16_t key = (bits & 0x8000) ? static_cast(~bits) : static_cast(bits | 0x8000); - return static_cast(key >> 8); - } - - __device__ __forceinline__ auto convert_to_uint32(float x) -> uint32_t { - uint32_t bits = __float_as_uint(x); - return (bits & 0x80000000u) ? 
~bits : (bits | 0x80000000u); - } - - __device__ void fast_topk_cuda_tl(const float* __restrict__ input, int* __restrict__ index, int row_start, int length) { - // An optimized topk kernel copied from tilelang kernel - // We assume length > TopK here, or it will crash - int topk = TopK; - constexpr auto BLOCK_SIZE = 1024; - constexpr auto RADIX = 256; - constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); - - alignas(128) __shared__ int s_histogram_buf[2][RADIX + 128]; - alignas(128) __shared__ int s_counter; - alignas(128) __shared__ int s_threshold_bin_id; - alignas(128) __shared__ int s_num_input[2]; - - auto& s_histogram = s_histogram_buf[0]; - // allocate for two rounds - extern __shared__ int s_input_idx[][SMEM_INPUT_SIZE]; - - const int tx = threadIdx.x; - - // stage 1: 8bit coarse histogram - if (tx < RADIX + 1) s_histogram[tx] = 0; - __syncthreads(); - - for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - const auto bin = convert_to_uint8(input[idx + row_start]); - ::atomicAdd(&s_histogram[bin], 1); - } - __syncthreads(); - - const auto run_cumsum = [&] { - #pragma unroll 8 - for (int i = 0; i < 8; ++i) { - static_assert(1 << 8 == RADIX); - if (C10_LIKELY(tx < RADIX)) { - const auto j = 1 << i; - const auto k = i & 1; - auto value = s_histogram_buf[k][tx]; - if (tx < RADIX - j) { - value += s_histogram_buf[k][tx + j]; - } - s_histogram_buf[k ^ 1][tx] = value; - } - __syncthreads(); - } - }; - - run_cumsum(); - if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { - s_threshold_bin_id = tx; - s_num_input[0] = 0; - s_counter = 0; - } - __syncthreads(); - - const auto threshold_bin = s_threshold_bin_id; - topk -= s_histogram[threshold_bin + 1]; - - if (topk == 0) { - for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - const auto bin = static_cast(convert_to_uint8(input[idx + row_start])); - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&s_counter, 1); - index[pos] = idx; - } - } - __syncthreads(); - return; - } 
else { - __syncthreads(); - if (tx < RADIX + 1) { - s_histogram[tx] = 0; - } - __syncthreads(); - - for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - const auto raw_input = input[idx + row_start]; - const auto bin = static_cast(convert_to_uint8(raw_input)); - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&s_counter, 1); - index[pos] = idx; - } else if (bin == threshold_bin) { - const auto pos = ::atomicAdd(&s_num_input[0], 1); - /// NOTE: (dark) fuse the histogram computation here - if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { - s_input_idx[0][pos] = idx; - const auto bin = convert_to_uint32(raw_input); - const auto sub_bin = (bin >> 24) & 0xFF; - ::atomicAdd(&s_histogram[sub_bin], 1); - } - } - } - __syncthreads(); - } - - // stage 2: refine with 8bit radix passes - #pragma unroll 4 - for (int round = 0; round < 4; ++round) { - __shared__ int s_last_remain; - const auto r_idx = round % 2; - - // clip here to prevent overflow - const auto _raw_num_input = s_num_input[r_idx]; - const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) ? 
_raw_num_input : int(SMEM_INPUT_SIZE); - - run_cumsum(); - if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { - s_threshold_bin_id = tx; - s_num_input[r_idx ^ 1] = 0; - s_last_remain = topk - s_histogram[tx + 1]; - } - __syncthreads(); - - const auto threshold_bin = s_threshold_bin_id; - topk -= s_histogram[threshold_bin + 1]; - - if (topk == 0) { - for (int i = tx; i < num_input; i += BLOCK_SIZE) { - const auto idx = s_input_idx[r_idx][i]; - const auto offset = 24 - round * 8; - const auto bin = (convert_to_uint32(input[idx + row_start]) >> offset) & 0xFF; - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&s_counter, 1); - index[pos] = idx; - } - } - __syncthreads(); - break; - } else { - __syncthreads(); - if (tx < RADIX + 1) { - s_histogram[tx] = 0; - } - __syncthreads(); - for (int i = tx; i < num_input; i += BLOCK_SIZE) { - const auto idx = s_input_idx[r_idx][i]; - const auto raw_input = input[idx + row_start]; - const auto offset = 24 - round * 8; - const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF; - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&s_counter, 1); - index[pos] = idx; - } else if (bin == threshold_bin) { - if (round == 3) { - const auto pos = ::atomicAdd(&s_last_remain, -1); - if (pos > 0) { - index[TopK - pos] = idx; - } - } else { - const auto pos = ::atomicAdd(&s_num_input[r_idx ^ 1], 1); - if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { - /// NOTE: (dark) fuse the histogram computation here - s_input_idx[r_idx ^ 1][pos] = idx; - const auto bin = convert_to_uint32(raw_input); - const auto sub_bin = (bin >> (offset - 8)) & 0xFF; - ::atomicAdd(&s_histogram[sub_bin], 1); - } - } - } - } - __syncthreads(); - } - } - } - - __global__ __launch_bounds__(kThreadsPerBlock) // topk - void topk_kernel(const FastTopKParams params) { - const auto& [input, row_starts, indices, lengths, input_stride] = params; - const auto bid = static_cast(blockIdx.x); - const auto row_start = row_starts == nullptr ? 
0 : row_starts[bid]; - const auto length = lengths[bid]; - const auto indice = indices + bid * TopK; - const auto score = input + bid * input_stride; - if (length <= TopK) { - return naive_topk_cuda(score, indice, length); - } else { - return fast_topk_cuda_tl(score, indice, row_start, length); - } - } - - __global__ __launch_bounds__(kThreadsPerBlock) // decode - void topk_transform_decode_kernel( - const FastTopKParams params, - int32_t* __restrict__ dst_page_table, - const int32_t* __restrict__ src_page_table, - const int64_t src_stride) { - const auto& [input, _1, _2, lengths, input_stride] = params; - const auto bid = static_cast(blockIdx.x); - const auto tid = threadIdx.x; - const auto row_start = 0; - const auto length = lengths[bid]; - const auto src_page_entry = src_page_table + bid * src_stride; - const auto dst_page_entry = dst_page_table + bid * TopK; - const auto score = input + bid * input_stride; - if (length <= TopK) { - return naive_topk_transform(score, length, dst_page_entry, src_page_entry); - } else { - __shared__ int s_indices[TopK]; - fast_topk_cuda_tl(score, s_indices, row_start, length); - // copy src[s_indices] to dst, we manually unroll here - static_assert(TopK % kThreadsPerBlock == 0); - static_assert(TopK / kThreadsPerBlock == 2); - const auto idx_0 = tid; - const auto pos_0 = s_indices[idx_0]; - dst_page_entry[idx_0] = src_page_entry[pos_0]; - const auto idx_1 = tid + kThreadsPerBlock; - const auto pos_1 = s_indices[idx_1]; - dst_page_entry[idx_1] = src_page_entry[pos_1]; - } - } - - __global__ __launch_bounds__(kThreadsPerBlock) // prefill - void topk_transform_prefill_kernel( - const FastTopKParams params, - int32_t* __restrict__ dst_page_table, - const int32_t* __restrict__ src_page_table, - const int64_t src_stride, - const int32_t* __restrict__ cu_seqlens_q, - const int64_t prefill_bs) { - const auto& [input, row_starts, _, lengths, input_stride] = params; - const auto bid = static_cast(blockIdx.x); - const auto tid = 
threadIdx.x; - const auto length = lengths[bid]; - const auto row_start = row_starts == nullptr ? 0 : row_starts[bid]; - const auto dst_page_entry = dst_page_table + bid * TopK; - const auto score = input + bid * input_stride; - - /// NOTE: prefill bs is usually small, we can just use a simple loop here - /// We ensure that last cu_seqlens is equal to number of blocks launched - __shared__ const int32_t* s_src_page_entry; - if (C10_LIKELY(prefill_bs <= kThreadsPerBlock)) { - if (tid < prefill_bs) { - if (bid >= cu_seqlens_q[tid] && bid < cu_seqlens_q[tid + 1]) { - s_src_page_entry = src_page_table + tid * src_stride; - } - } - } else { - for (int64_t i = tid; i < prefill_bs; i += kThreadsPerBlock) { - if (bid >= cu_seqlens_q[i] && bid < cu_seqlens_q[i + 1]) { - s_src_page_entry = src_page_table + i * src_stride; - } - } - } - __syncthreads(); - const auto src_page_entry = s_src_page_entry; - - if (length <= TopK) { - return naive_topk_transform(score, length, dst_page_entry, src_page_entry); - } else { - __shared__ int s_indices[TopK]; - fast_topk_cuda_tl(score, s_indices, row_start, length); - // copy src[s_indices] to dst, we manually unroll here - static_assert(TopK % kThreadsPerBlock == 0); - static_assert(TopK / kThreadsPerBlock == 2); - const auto idx_0 = tid; - const auto pos_0 = s_indices[idx_0]; - dst_page_entry[idx_0] = src_page_entry[pos_0]; - const auto idx_1 = tid + kThreadsPerBlock; - const auto pos_1 = s_indices[idx_1]; - dst_page_entry[idx_1] = src_page_entry[pos_1]; - } - } - - __global__ __launch_bounds__(kThreadsPerBlock) // prefill, ragged kv - void topk_transform_prefill_ragged_kernel( - const FastTopKParams params, - int32_t* __restrict__ topk_indices_ragged, - const int32_t* __restrict__ topk_indices_offset) { - const auto& [input, row_starts, _, lengths, input_stride] = params; - const auto bid = static_cast(blockIdx.x); - const auto tid = threadIdx.x; - const auto row_start = row_starts == nullptr ? 
0 : row_starts[bid]; - const auto length = lengths[bid]; - const auto dst_indices_entry = topk_indices_ragged + bid * TopK; - const auto score = input + bid * input_stride; - const auto offset = topk_indices_offset[bid]; - - if (length <= TopK) { - return naive_topk_transform_ragged(score, length, dst_indices_entry, offset); - } else { - __shared__ int s_indices[TopK]; - fast_topk_cuda_tl(score, s_indices, row_start, length); - // copy src[s_indices] to dst, we manually unroll here - static_assert(TopK % kThreadsPerBlock == 0); - static_assert(TopK / kThreadsPerBlock == 2); - const auto idx_0 = tid; - const auto pos_0 = s_indices[idx_0]; - dst_indices_entry[idx_0] = pos_0 + offset; - const auto idx_1 = tid + kThreadsPerBlock; - const auto pos_1 = s_indices[idx_1]; - dst_indices_entry[idx_1] = pos_1 + offset; - } - } - - auto get_params( - const at::Tensor& score, - const at::Tensor& lengths, - std::optional row_starts_opt = std::nullopt, - std::optional indices_opt = std::nullopt) -> FastTopKParams { - const auto B = score.size(0); - TORCH_CHECK(score.dim() == 2 && score.stride(1) == 1); - if (row_starts_opt.has_value()) { - const auto& row_starts = row_starts_opt.value(); - TORCH_CHECK(row_starts.dim() == 1); - TORCH_CHECK(row_starts.size(0) == B); - } - TORCH_CHECK(lengths.dim() == 1 && lengths.is_contiguous()); - TORCH_CHECK(lengths.size(0) == B); - int32_t* indices_data_ptr = nullptr; - if (indices_opt.has_value()) { - const auto& indices = indices_opt.value(); - TORCH_CHECK(indices.dim() == 2 && indices.is_contiguous()); - TORCH_CHECK(indices.size(0) == B); - TORCH_CHECK(indices.size(1) == TopK); - indices_data_ptr = indices.data_ptr(); - } - - return FastTopKParams{ - .input = score.data_ptr(), - .row_starts = row_starts_opt.has_value() ? 
row_starts_opt->data_ptr() : nullptr, - .indices = indices_data_ptr, - .lengths = lengths.data_ptr(), - .input_stride = score.stride(0), - }; - } - - template - void setup_kernel_smem_once() { - [[maybe_unused]] - static const auto result = [] { - #ifdef USE_ROCM - // hipify will turn cudaFuncSetAttribute -> hipFuncSetAttribute. On ROCm, - // hipFuncSetAttribute expects `const void*` and hipcc does not accept passing - // a function pointer directly, so cast explicitly. - return ::cudaFuncSetAttribute( - reinterpret_cast(f), ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); - #else - // CUDA: keep original behavior (no cast needed). - return ::cudaFuncSetAttribute(f, ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); - #endif - }(); - TORCH_CHECK(result == cudaSuccess, "set_up_kernel_once failed:", ::cudaGetErrorString(result)); - } - - // ====================================================================== - // Vortex integration: BOS/EOS-aware segmented TopK with index remapping - // ====================================================================== - - template - __device__ __forceinline__ float vortex_to_float(T x); - - template <> - __device__ __forceinline__ float vortex_to_float(float x) { return x; } - - template <> - __device__ __forceinline__ float vortex_to_float<__nv_bfloat16>(__nv_bfloat16 x) { - return __bfloat162float(x); - } - - constexpr int VORTEX_MAX_TOPK = 2048; - - // Templated version of fast_topk_cuda_tl: - // - ScoreT: float or __nv_bfloat16 - // - target_k: runtime parameter (replaces compile-time TopK) - template - __device__ void fast_topk_vortex( - const ScoreT* __restrict__ input, - int* __restrict__ index, - int row_start, - int length, - int target_k) - { - int topk = target_k; - constexpr auto BLOCK_SIZE = 1024; - constexpr auto RADIX = 256; - constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); - - alignas(128) __shared__ int vh_histogram_buf[2][RADIX + 128]; - alignas(128) __shared__ int 
vh_counter; - alignas(128) __shared__ int vh_threshold_bin_id; - alignas(128) __shared__ int vh_num_input[2]; - - auto& vh_histogram = vh_histogram_buf[0]; - extern __shared__ int vh_input_idx[][SMEM_INPUT_SIZE]; - - const int tx = threadIdx.x; - - // Stage 1: 8-bit coarse histogram - if (tx < RADIX + 1) vh_histogram[tx] = 0; - __syncthreads(); - - for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - const auto bin = convert_to_uint8(vortex_to_float(input[idx + row_start])); - ::atomicAdd(&vh_histogram[bin], 1); - } - __syncthreads(); - - const auto run_cumsum = [&] { - #pragma unroll 8 - for (int i = 0; i < 8; ++i) { - static_assert(1 << 8 == RADIX); - if (C10_LIKELY(tx < RADIX)) { - const auto j = 1 << i; - const auto k = i & 1; - auto value = vh_histogram_buf[k][tx]; - if (tx < RADIX - j) { - value += vh_histogram_buf[k][tx + j]; - } - vh_histogram_buf[k ^ 1][tx] = value; - } - __syncthreads(); - } - }; - - run_cumsum(); - if (tx < RADIX && vh_histogram[tx] > topk && vh_histogram[tx + 1] <= topk) { - vh_threshold_bin_id = tx; - vh_num_input[0] = 0; - vh_counter = 0; - } - __syncthreads(); - - const auto threshold_bin = vh_threshold_bin_id; - topk -= vh_histogram[threshold_bin + 1]; - - if (topk == 0) { - for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - const auto bin = static_cast( - convert_to_uint8(vortex_to_float(input[idx + row_start]))); - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&vh_counter, 1); - index[pos] = idx; - } - } - __syncthreads(); - return; - } else { - __syncthreads(); - if (tx < RADIX + 1) vh_histogram[tx] = 0; - __syncthreads(); - - for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - const auto raw_input = vortex_to_float(input[idx + row_start]); - const auto bin = static_cast(convert_to_uint8(raw_input)); - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&vh_counter, 1); - index[pos] = idx; - } else if (bin == threshold_bin) { - const auto pos = ::atomicAdd(&vh_num_input[0], 1); - if (C10_LIKELY(pos < 
SMEM_INPUT_SIZE)) { - vh_input_idx[0][pos] = idx; - const auto b32 = convert_to_uint32(raw_input); - const auto sub_bin = (b32 >> 24) & 0xFF; - ::atomicAdd(&vh_histogram[sub_bin], 1); - } - } - } - __syncthreads(); - } - - // Stage 2: refine with 8-bit radix passes - #pragma unroll 4 - for (int round = 0; round < 4; ++round) { - __shared__ int vh_last_remain; - const auto r_idx = round % 2; - - const auto _raw_num_input = vh_num_input[r_idx]; - const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) - ? _raw_num_input - : int(SMEM_INPUT_SIZE); - - run_cumsum(); - if (tx < RADIX && vh_histogram[tx] > topk && vh_histogram[tx + 1] <= topk) { - vh_threshold_bin_id = tx; - vh_num_input[r_idx ^ 1] = 0; - vh_last_remain = topk - vh_histogram[tx + 1]; - } - __syncthreads(); - - const auto threshold_bin = vh_threshold_bin_id; - topk -= vh_histogram[threshold_bin + 1]; - - if (topk == 0) { - for (int i = tx; i < num_input; i += BLOCK_SIZE) { - const auto idx = vh_input_idx[r_idx][i]; - const auto offset = 24 - round * 8; - const auto bin = (convert_to_uint32( - vortex_to_float(input[idx + row_start])) >> offset) & 0xFF; - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&vh_counter, 1); - index[pos] = idx; - } - } - __syncthreads(); - break; - } else { - __syncthreads(); - if (tx < RADIX + 1) vh_histogram[tx] = 0; - __syncthreads(); - for (int i = tx; i < num_input; i += BLOCK_SIZE) { - const auto idx = vh_input_idx[r_idx][i]; - const auto raw_input = vortex_to_float(input[idx + row_start]); - const auto offset = 24 - round * 8; - const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF; - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&vh_counter, 1); - index[pos] = idx; - } else if (bin == threshold_bin) { - if (round == 3) { - const auto pos = ::atomicAdd(&vh_last_remain, -1); - if (pos > 0) { - index[target_k - pos] = idx; - } - } else { - const auto pos = ::atomicAdd(&vh_num_input[r_idx ^ 1], 1); - if (C10_LIKELY(pos < 
SMEM_INPUT_SIZE)) { - vh_input_idx[r_idx ^ 1][pos] = idx; - const auto b32 = convert_to_uint32(raw_input); - const auto sub_bin = (b32 >> (offset - 8)) & 0xFF; - ::atomicAdd(&vh_histogram[sub_bin], 1); - } - } - } - } - __syncthreads(); - } - } - } - - // Wrapper kernel: one CUDA block per batch*head segment - template - __global__ __launch_bounds__(kThreadsPerBlock) - void TopKOutput_Kernel( - const ScoreT* __restrict__ score, - const int* __restrict__ dense_kv_indptr, - const int* __restrict__ sparse_kv_indptr, - const int* __restrict__ dense_kv_indices, - int* __restrict__ sparse_kv_indices, - const int topk_val, - const int page_reserved_bos, - const int page_reserved_eos) - { - const int bx = blockIdx.x; - - const int start = dense_kv_indptr[bx] + page_reserved_bos; - const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; - const int nblk = end - start; - if (nblk <= topk_val) return; - - const ScoreT* __restrict__ score_blk = score + start; - const int* __restrict__ idx_blk = dense_kv_indices + start; - int* __restrict__ out_blk = sparse_kv_indices - + sparse_kv_indptr[bx] - + page_reserved_bos; - - __shared__ int s_indices[VORTEX_MAX_TOPK]; - fast_topk_vortex(score_blk, s_indices, 0, nblk, topk_val); - __syncthreads(); - - // Remap position indices -> page indices via dense_kv_indices - const int tx = threadIdx.x; - for (int i = tx; i < topk_val; i += kThreadsPerBlock) { - out_blk[i] = idx_blk[s_indices[i]]; - } - } - - } // namespace - - #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") - - void fast_topk_interface( - const at::Tensor& score, at::Tensor& indices, const at::Tensor& lengths, std::optional row_starts_opt) { - CHECK_CUDA(score); - CHECK_CUDA(indices); - if (row_starts_opt.has_value()) { - CHECK_CUDA(row_starts_opt.value()); - } - CHECK_CUDA(lengths); - const auto params = get_params(score, lengths, row_starts_opt, indices); - const auto B = score.size(0); - const auto stream = 
at::cuda::getCurrentCUDAStream().stream(); - const auto grid = dim3{static_cast(B)}; - const auto block = dim3{kThreadsPerBlock}; - setup_kernel_smem_once(); - topk_kernel<<>>(params); - const auto result = cudaGetLastError(); - TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); - } - - void fast_topk_transform_interface( - const at::Tensor& score, - const at::Tensor& lengths, - at::Tensor& dst_page_table, - const at::Tensor& src_page_table, - const at::Tensor& cu_seqlens_q, - std::optional row_starts_opt) { - CHECK_CUDA(score); - CHECK_CUDA(lengths); - CHECK_CUDA(dst_page_table); - CHECK_CUDA(src_page_table); - CHECK_CUDA(cu_seqlens_q); - if (row_starts_opt.has_value()) { - CHECK_CUDA(row_starts_opt.value()); - } - const auto params = get_params(score, lengths, row_starts_opt); - const auto B = score.size(0); - TORCH_CHECK(dst_page_table.dim() == 2 && dst_page_table.is_contiguous()); - TORCH_CHECK(src_page_table.dim() == 2 && src_page_table.stride(1) == 1); - TORCH_CHECK(cu_seqlens_q.dim() == 1 && cu_seqlens_q.is_contiguous()); - const auto prefill_bs = cu_seqlens_q.size(0) - 1; - TORCH_CHECK(dst_page_table.size(0) == B); - TORCH_CHECK(dst_page_table.size(1) == TopK); - TORCH_CHECK(src_page_table.size(0) == prefill_bs); - TORCH_CHECK(prefill_bs <= B); // prefill_bs should be smaller than expanded bs - - // launch kernel - const auto stream = at::cuda::getCurrentCUDAStream().stream(); - const auto grid = dim3{static_cast(B)}; - const auto block = dim3{kThreadsPerBlock}; - const auto src_stride = src_page_table.stride(0); - - // dispatch to decode or prefill - // extend and draft extend: row_starts_opt is not null, invokes the prefill kernel - // decode: row_starts_opt is null, invokes the decode kernel - // target verify: row_starts_opt is null, invokes the prefill kernel - const auto is_decode = !row_starts_opt.has_value() && prefill_bs == B; - if (is_decode) { - setup_kernel_smem_once(); - topk_transform_decode_kernel<<>>( - 
params, dst_page_table.data_ptr(), src_page_table.data_ptr(), src_stride); - } else { - setup_kernel_smem_once(); - topk_transform_prefill_kernel<<>>( - params, - dst_page_table.data_ptr(), - src_page_table.data_ptr(), - src_stride, - cu_seqlens_q.data_ptr(), - prefill_bs); - } - - const auto result = cudaGetLastError(); - TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); - } - - void fast_topk_transform_ragged_interface( - const at::Tensor& score, - const at::Tensor& lengths, - at::Tensor& topk_indices_ragged, - const at::Tensor& topk_indices_offset, - std::optional row_starts_opt) { - CHECK_CUDA(score); - CHECK_CUDA(lengths); - CHECK_CUDA(topk_indices_ragged); - CHECK_CUDA(topk_indices_offset); - if (row_starts_opt.has_value()) { - CHECK_CUDA(row_starts_opt.value()); - } - - const auto params = get_params(score, lengths, row_starts_opt); - const auto B = score.size(0); - TORCH_CHECK(topk_indices_ragged.dim() == 2 && topk_indices_ragged.is_contiguous()); - TORCH_CHECK(topk_indices_offset.dim() == 1); - - TORCH_CHECK(topk_indices_ragged.size(0) == B); - TORCH_CHECK(topk_indices_ragged.size(1) == TopK); - TORCH_CHECK(topk_indices_offset.size(0) == B); - - // launch kernel - const auto stream = at::cuda::getCurrentCUDAStream().stream(); - const auto grid = dim3{static_cast(B)}; - const auto block = dim3{kThreadsPerBlock}; - - setup_kernel_smem_once(); - topk_transform_prefill_ragged_kernel<<>>( - params, topk_indices_ragged.data_ptr(), topk_indices_offset.data_ptr()); - - const auto result = cudaGetLastError(); - TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); - } - - // ====================================================================== - // Vortex host entry point — same interface as topk_output in topk.cu - // ====================================================================== - void topk_output( - const at::Tensor& x, - const at::Tensor& dense_kv_indptr, - const at::Tensor& 
sparse_kv_indptr, - const at::Tensor& dense_kv_indices, - at::Tensor& sparse_kv_indices, - const int64_t eff_batch_size, - const int64_t topk_val, - const int64_t reserved_bos, - const int64_t reserved_eos, - const int64_t max_num_pages) - { - TORCH_CHECK(topk_val <= VORTEX_MAX_TOPK, - "topk_output: topk_val (", topk_val, - ") exceeds VORTEX_MAX_TOPK (", VORTEX_MAX_TOPK, ")"); - - dim3 nblks(eff_batch_size); - dim3 nthreads(kThreadsPerBlock); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - if (x.scalar_type() == at::ScalarType::BFloat16) { - setup_kernel_smem_once, kSmem>(); - TopKOutput_Kernel<__nv_bfloat16><<>>( - reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), - dense_kv_indptr.data_ptr(), - sparse_kv_indptr.data_ptr(), - dense_kv_indices.data_ptr(), - sparse_kv_indices.data_ptr(), - topk_val, - reserved_bos, - reserved_eos); - } else if (x.scalar_type() == at::ScalarType::Float) { - setup_kernel_smem_once, kSmem>(); - TopKOutput_Kernel<<>>( - x.data_ptr(), - dense_kv_indptr.data_ptr(), - sparse_kv_indptr.data_ptr(), - dense_kv_indices.data_ptr(), - sparse_kv_indices.data_ptr(), - topk_val, - reserved_bos, - reserved_eos); - } else { - TORCH_CHECK(false, - "topk_output: unsupported dtype ", - x.scalar_type()); - } - - const auto result = cudaGetLastError(); - TORCH_CHECK(result == cudaSuccess, - "topk_output kernel failed: ", ::cudaGetErrorString(result)); - } \ No newline at end of file +#include "register.h" +#include + + +template +__global__ void TopKOutput_F32_Kernel( +const float* __restrict__ score, +const int* __restrict__ dense_kv_indptr, +const int* __restrict__ sparse_kv_indptr, +const int* __restrict__ dense_kv_indices, +int* __restrict__ sparse_kv_indices, +const int topk_val, +const int page_reserved_bos, +const int page_reserved_eos) +{ + const int bx = blockIdx.x; + const int tx = threadIdx.x; + + const int start = dense_kv_indptr[bx] + page_reserved_bos; + const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; + 
const int nblk = end - start; + if (nblk <= topk_val) return; + + const float* __restrict__ score_blk = score + start; + const int* __restrict__ idx_blk = dense_kv_indices + start; + int* __restrict__ out_blk = sparse_kv_indices + sparse_kv_indptr[bx] + page_reserved_bos; + + float key[ITEM_PER_THREAD]; + int val[ITEM_PER_THREAD]; + + using BLF = cub::BlockLoad; + using BLI = cub::BlockLoad; + using BSI = cub::BlockStore; + using Sort = cub::BlockRadixSort; + + __shared__ union { + typename BLF::TempStorage lf; + typename BLI::TempStorage li; + typename BSI::TempStorage si; + typename Sort::TempStorage sort; + } temp; + + BLF(temp.lf).Load(score_blk, key, nblk, -INFINITY); + __syncthreads(); + BLI(temp.li).Load(idx_blk, val, nblk, 0); + __syncthreads(); + + Sort(temp.sort).SortDescending(key, val); + __syncthreads(); + + const int valid_out = min(topk_val, nblk); + BSI(temp.si).Store(out_blk, /*per-thread regs*/ val, valid_out); +} + + +template +__global__ void TopKOutput_BF16_Kernel( +const __nv_bfloat16* __restrict__ score, +const int* __restrict__ dense_kv_indptr, +const int* __restrict__ sparse_kv_indptr, +const int* __restrict__ dense_kv_indices, +int* __restrict__ sparse_kv_indices, +const int topk_val, +const int page_reserved_bos, +const int page_reserved_eos) +{ + const int bx = blockIdx.x; + const int tx = threadIdx.x; + + const int start = dense_kv_indptr[bx] + page_reserved_bos; + const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; + const int nblk = end - start; + if (nblk <= topk_val) return; + + const __nv_bfloat16* __restrict__ score_blk = score + start; + const int* __restrict__ idx_blk = dense_kv_indices + start; + int* __restrict__ out_blk = sparse_kv_indices + sparse_kv_indptr[bx] + page_reserved_bos; + + const __nv_bfloat16 ninf_bf16 = __float2bfloat16(-CUDART_INF_F); + + __nv_bfloat16 key_bf16[ITEM_PER_THREAD]; + float key[ITEM_PER_THREAD]; + int val[ITEM_PER_THREAD]; + + using BLF = cub::BlockLoad<__nv_bfloat16, NUM_THREADS, 
ITEM_PER_THREAD, cub::BLOCK_LOAD_WARP_TRANSPOSE>; + using BLI = cub::BlockLoad; + using BSI = cub::BlockStore; + using Sort = cub::BlockRadixSort; + + __shared__ union { + typename BLF::TempStorage lf; + typename BLI::TempStorage li; + typename BSI::TempStorage si; + typename Sort::TempStorage sort; + } temp; + + BLF(temp.lf).Load(score_blk, key_bf16, nblk, ninf_bf16); + + #pragma unroll + for (int i = 0; i < ITEM_PER_THREAD; ++i){ + key[i] = __bfloat162float(key_bf16[i]); + } + __syncthreads(); + + BLI(temp.li).Load(idx_blk, val, nblk, 0); + __syncthreads(); + + Sort(temp.sort).SortDescending(key, val); + __syncthreads(); + + const int valid_out = min(topk_val, nblk); + BSI(temp.si).Store(out_blk, /*per-thread regs*/ val, valid_out); +} + + + +void topk_output( +const at::Tensor& x, +const at::Tensor& dense_kv_indptr, +const at::Tensor& sparse_kv_indptr, +const at::Tensor& dense_kv_indices, +at::Tensor& sparse_kv_indices, +const int64_t eff_batch_size, +const int64_t topk_val, +const int64_t reserved_bos, +const int64_t reserved_eos, +const int64_t max_num_pages +){ + + + dim3 nblks(eff_batch_size); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + if (max_num_pages <= 128){ + TopKOutput_BF16_Kernel<128, 1><<>>( + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + topk_val, + reserved_bos, + reserved_eos + ); + } else if (max_num_pages <= 256){ + TopKOutput_BF16_Kernel<128, 2><<>>( + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + topk_val, + reserved_bos, + reserved_eos + ); + } else if (max_num_pages <= 512){ + TopKOutput_BF16_Kernel<128, 4><<>>( + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + 
sparse_kv_indices.data_ptr(), + topk_val, + reserved_bos, + reserved_eos + ); + } else if (max_num_pages <= 1024){ + TopKOutput_BF16_Kernel<256, 4><<>>( + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + topk_val, + reserved_bos, + reserved_eos + ); + } else if (max_num_pages <= 2048){ + TopKOutput_BF16_Kernel<256, 8><<>>( + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + topk_val, + reserved_bos, + reserved_eos + ); + } else if (max_num_pages <= 4096){ + TopKOutput_BF16_Kernel<512, 8><<>>( + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + topk_val, + reserved_bos, + reserved_eos + ); + } else { + TORCH_CHECK(false); + } + +} \ No newline at end of file diff --git a/examples/verify_algo.py b/examples/verify_algo.py index 9958b7e..1187aca 100644 --- a/examples/verify_algo.py +++ b/examples/verify_algo.py @@ -56,12 +56,13 @@ def verify_algos( sparse_attention: bool = True, mem: float = 0.8, kv_cache_dtype: str = "auto", +topk_type: str = "naive", ): - llm = sgl.Engine(model_path=model_name, + llm = sgl.Engine(model_path=model_name, disable_cuda_graph=False, page_size=page_size, - vortex_topk_val=topk_val, + vortex_topk_val=topk_val, disable_overlap_schedule=True, attention_backend="flashinfer", enable_vortex_sparsity=sparse_attention, @@ -72,6 +73,7 @@ def verify_algos( vortex_max_seq_lens=12288, mem_fraction_static=mem, kv_cache_dtype=kv_cache_dtype, + vortex_topk_type=topk_type, ) with open("amc23.jsonl", "r", encoding="utf-8") as f: @@ -221,6 +223,14 @@ def parse_args(): choices=["auto", "fp8_e5m2", "fp8_e4m3", "int8"], help='KV cache dtype (default: "auto").', ) + + parser.add_argument( + 
"--topk-type", + type=str, + default="naive", + choices=["naive", "sglang"], + help='TopK kernel type: "naive" for topk_output, "sglang" for topk_output_sglang (default: "naive").', + ) return parser.parse_args() if __name__ == "__main__": @@ -235,6 +245,7 @@ def parse_args(): sparse_attention=not(args.full_attention), mem=args.mem, kv_cache_dtype=args.kv_cache_dtype, + topk_type=args.topk_type, ) print(summary) diff --git a/examples/verify_algo_fp8.sh b/examples/verify_algo_fp8.sh index fd85dad..c0b8814 100755 --- a/examples/verify_algo_fp8.sh +++ b/examples/verify_algo_fp8.sh @@ -1,5 +1,6 @@ #!/usr/bin/env bash set -e +# export CUDA_VISIBLE_DEVICES=0 sparse_algos=( "block_sparse_attention" diff --git a/examples/verify_algo_int8.sh b/examples/verify_algo_int8.sh index e57c63f..bf24c2d 100644 --- a/examples/verify_algo_int8.sh +++ b/examples/verify_algo_int8.sh @@ -1,5 +1,6 @@ #!/usr/bin/env bash set -e +# export CUDA_VISIBLE_DEVICES=0 sparse_algos=( "block_sparse_attention" diff --git a/vortex_torch/indexer/context.py b/vortex_torch/indexer/context.py index 6d3c586..d6da9c1 100644 --- a/vortex_torch/indexer/context.py +++ b/vortex_torch/indexer/context.py @@ -22,7 +22,7 @@ class Context(ContextBase): # hardware / paging "num_sms", "page_size", "max_num_pages", "max_num_pages_per_request", # misc - "indexer_dtype", "topk_val", "page_reserved_bos", "page_reserved_eos", + "indexer_dtype", "topk_val", "page_reserved_bos", "page_reserved_eos", "topk_type", # auxilary memory in graph "_aux_total_bytes", @@ -68,6 +68,7 @@ class Context(ContextBase): topk_val: int #: Top-K value used in pruning or selection. page_reserved_bos: int #: Reserved page count for BOS (begin-of-sequence). page_reserved_eos: int #: Reserved page count for EOS (end-of-sequence). + topk_type: str #: TopK kernel type: "naive" or "sglang". # --- auxiliary --- _aux_total_bytes: int #: Accumulated auxiliary memory in bytes. 
@@ -144,6 +145,7 @@ def create(self, parent: Any, model_runner: Any, *, overwrite: bool = False) -> self.page_reserved_bos = sa.vortex_page_reserved_bos self.page_reserved_eos = sa.vortex_page_reserved_eos + self.topk_type = getattr(sa, "vortex_topk_type", "naive") self.max_num_workloads = ( (self.max_num_pages // max(1, sa.vortex_lb_min_chunk_size)) + max_bs * self.num_kv_heads diff --git a/vortex_torch/indexer/output_func.py b/vortex_torch/indexer/output_func.py index 5df795b..f7d0d9c 100644 --- a/vortex_torch/indexer/output_func.py +++ b/vortex_torch/indexer/output_func.py @@ -1,7 +1,7 @@ import torch from typing import Dict, Callable, Optional from ..abs import vOp -from vortex_torch_C import topk_output +from vortex_torch_C import topk_output, topk_output_sglang from .context import Context from ..abs import vTensor, FORMAT @@ -75,13 +75,17 @@ class topK(vOp): """ # Dispatch by input format; only RAGGED is supported for now. - _impl_map: Dict[FORMAT, Callable] = { - FORMAT.RAGGED: topk_output, + _impl_map: Dict[FORMAT, Dict[str, Callable]] = { + FORMAT.RAGGED: { + "naive": topk_output, + "sglang": topk_output_sglang, + }, } def __init__(self): super().__init__() self.impl: Optional[Callable] = None + self.topk_type: str = "naive" # ---------------- profile ---------------- def profile(self, x: vTensor, o: vTensor, ctx: Context) -> None: @@ -152,7 +156,13 @@ def profile(self, x: vTensor, o: vTensor, ctx: Context) -> None: f"{prefix}no implementation for x._format={x_fmt}. " f"Available: {list(self._impl_map.keys())}" ) - self.impl = self._impl_map[x_fmt] + self.topk_type = getattr(ctx, "topk_type", "naive") + impl_variants = self._impl_map[x_fmt] + assert self.topk_type in impl_variants, ( + f"{prefix}no topk implementation for topk_type='{self.topk_type}'. 
" + f"Available: {list(impl_variants.keys())}" + ) + self.impl = impl_variants[self.topk_type] # ---- optional sanity checks on `o` ---- # We only assert device consistency and leave exact (S_pack, D0, D1) @@ -220,16 +230,32 @@ def execute(self, x: torch.Tensor, o: torch.Tensor, ctx: Context) -> torch.Tenso prefix = self._prefix() assert self.impl is not None, f"{prefix}execute called before profile() (impl is None)" - self.impl( - x, - ctx.dense_kv_indptr, - ctx.sparse_kv_indptr, - ctx.dense_kv_indices, - o, - ctx.batch_size * ctx.num_kv_heads, - ctx.topk_val, - ctx.page_reserved_bos, - ctx.page_reserved_eos, - ctx.max_num_pages_per_request, - ) + if self.topk_type == "sglang": + # topk_output_sglang: (x, dense_kv_indptr, sparse_kv_indptr, dense_kv_indices, sparse_kv_indices, ...) + self.impl( + x, + ctx.dense_kv_indptr, + ctx.sparse_kv_indptr, + ctx.dense_kv_indices, + o, + ctx.batch_size * ctx.num_kv_heads, + ctx.topk_val, + ctx.page_reserved_bos, + ctx.page_reserved_eos, + ctx.max_num_pages_per_request, + ) + else: + # topk_output (naive): (x, dense_kv_indptr, dense_kv_indices, sparse_kv_indptr, sparse_kv_indices, ...) 
+ self.impl( + x, + ctx.dense_kv_indptr, + ctx.dense_kv_indices, + ctx.sparse_kv_indptr, + o, + ctx.batch_size * ctx.num_kv_heads, + ctx.topk_val, + ctx.page_reserved_bos, + ctx.page_reserved_eos, + ctx.max_num_pages_per_request, + ) return o From edbf7899b34bb70724839b466c208750c2f6df94 Mon Sep 17 00:00:00 2001 From: UED Date: Mon, 9 Mar 2026 05:30:20 +0000 Subject: [PATCH 07/22] add parameter to switch between two topk kernels (naive or sglang) --- examples/verify_algo.sh | 1 + setup.py | 1 + 2 files changed, 2 insertions(+) diff --git a/examples/verify_algo.sh b/examples/verify_algo.sh index 7487708..73ac2f4 100644 --- a/examples/verify_algo.sh +++ b/examples/verify_algo.sh @@ -19,6 +19,7 @@ TIMESTAMP=$(date +%Y%m%d_%H%M%S) --topk-val 30 \ --vortex-module-name "${algo}" \ --model-name Qwen/Qwen3-1.7B \ + --topk-type sglang \ --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" done \ No newline at end of file diff --git a/setup.py b/setup.py index f35ddae..99c6529 100644 --- a/setup.py +++ b/setup.py @@ -17,6 +17,7 @@ 'csrc/register.cc', 'csrc/utils_sglang.cu', 'csrc/topk.cu', + 'csrc/topk_sglang.cu', ], include_dirs=['csrc'], extra_compile_args={ From 87d7664c8cefcbf23e835ece5df549e53863772a Mon Sep 17 00:00:00 2001 From: UED Date: Wed, 18 Mar 2026 23:40:58 +0000 Subject: [PATCH 08/22] add aim24 --- csrc/topk_sglang.cu | 826 +++++++++++++++++++++++++++++++++++++++ examples/verify_aim24.py | 111 ++++++ 2 files changed, 937 insertions(+) create mode 100644 csrc/topk_sglang.cu create mode 100644 examples/verify_aim24.py diff --git a/csrc/topk_sglang.cu b/csrc/topk_sglang.cu new file mode 100644 index 0000000..314f0fd --- /dev/null +++ b/csrc/topk_sglang.cu @@ -0,0 +1,826 @@ +/** + * @NOTE: This file is adapted from + * https://github.com/tile-ai/tilelang/blob/main/examples/deepseek_v32/topk_selector.py + * We: + * 1. adapt from tilelang to pure cuda + * 2. optimize the performance a little + * 3. 
fix the potential illegal memory access + */ + #include + #include + #include + #include + #include + #include + #include + #include + #include + + #include + #include + #include + + namespace { + + constexpr int TopK = 2048; + constexpr int kThreadsPerBlock = 1024; + + #ifdef USE_ROCM + // On ROCm, the per-workgroup LDS budget depends on the target arch, so we inject a + // per-arch value from `setup_rocm.py` via `-DSGL_TOPK_DYNAMIC_SMEM_BYTES=...`. + #ifdef SGL_TOPK_DYNAMIC_SMEM_BYTES + constexpr size_t kSmem = static_cast(SGL_TOPK_DYNAMIC_SMEM_BYTES); + #else + constexpr size_t kSmem = 48 * 1024; // bytes + #endif + #else + // Reduced from 128KB to 32KB to improve occupancy. + // Each radix pass needs at most ~TopK candidates in the threshold bin, + // so 4K entries per round (2 rounds = 8K entries = 32KB) is sufficient. + constexpr size_t kSmem = 8 * 1024 * sizeof(uint32_t); // 32KB (bytes) + #endif + + struct FastTopKParams { + const float* __restrict__ input; // [B, input_stride] + const int32_t* __restrict__ row_starts; // [B] + int32_t* __restrict__ indices; // [B, TopK] + int32_t* __restrict__ lengths; // [B] + int64_t input_stride; + }; + + // when length <= TopK, we can directly write the indices + __device__ void naive_topk_cuda(const float* __restrict__ score, int32_t* __restrict__ indice, int32_t length) { + const auto tid = threadIdx.x; + for (int i = tid; i < TopK; i += kThreadsPerBlock) { + indice[i] = (i < length) ? i : -1; + } + } + + // keep the first `length` entries, set others to -1 + __device__ void naive_topk_transform( + const float* __restrict__ score, + int32_t length, + int32_t* __restrict__ dst_page_table, + const int32_t* __restrict__ src_page_table) { + const auto tid = threadIdx.x; + for (auto i = tid; i < TopK; i += kThreadsPerBlock) { + dst_page_table[i] = (i < length) ? 
src_page_table[i] : -1; + } + } + + // keep the first `length` entries, set others to -1 + __device__ void naive_topk_transform_ragged( + const float* __restrict__ score, int32_t length, int32_t* __restrict__ topk_indices_ragged, int32_t offset) { + const auto tid = threadIdx.x; + for (auto i = tid; i < TopK; i += kThreadsPerBlock) { + topk_indices_ragged[i] = (i < length) ? static_cast(i) + offset : -1; + } + } + + __device__ __forceinline__ auto convert_to_uint8(float x) -> uint8_t { + __half h = __float2half_rn(x); + uint16_t bits = __half_as_ushort(h); + uint16_t key = (bits & 0x8000) ? static_cast(~bits) : static_cast(bits | 0x8000); + return static_cast(key >> 8); + } + + __device__ __forceinline__ auto convert_to_uint32(float x) -> uint32_t { + uint32_t bits = __float_as_uint(x); + return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u); + } + + __device__ void fast_topk_cuda_tl(const float* __restrict__ input, int* __restrict__ index, int row_start, int length) { + // An optimized topk kernel copied from tilelang kernel + // We assume length > TopK here, or it will crash + int topk = TopK; + constexpr auto BLOCK_SIZE = 1024; + constexpr auto RADIX = 256; + constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); + + alignas(128) __shared__ int s_histogram_buf[2][RADIX + 128]; + alignas(128) __shared__ int s_counter; + alignas(128) __shared__ int s_threshold_bin_id; + alignas(128) __shared__ int s_num_input[2]; + + auto& s_histogram = s_histogram_buf[0]; + // allocate for two rounds + extern __shared__ int s_input_idx[][SMEM_INPUT_SIZE]; + + const int tx = threadIdx.x; + + // stage 1: 8bit coarse histogram + if (tx < RADIX + 1) s_histogram[tx] = 0; + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = convert_to_uint8(input[idx + row_start]); + ::atomicAdd(&s_histogram[bin], 1); + } + __syncthreads(); + + const auto run_cumsum = [&] { + #pragma unroll 8 + for (int i = 0; i < 8; ++i) { + static_assert(1 << 8 == 
RADIX); + if (C10_LIKELY(tx < RADIX)) { + const auto j = 1 << i; + const auto k = i & 1; + auto value = s_histogram_buf[k][tx]; + if (tx < RADIX - j) { + value += s_histogram_buf[k][tx + j]; + } + s_histogram_buf[k ^ 1][tx] = value; + } + __syncthreads(); + } + }; + + run_cumsum(); + if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { + s_threshold_bin_id = tx; + s_num_input[0] = 0; + s_counter = 0; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + topk -= s_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = static_cast(convert_to_uint8(input[idx + row_start])); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + return; + } else { + __syncthreads(); + if (tx < RADIX + 1) { + s_histogram[tx] = 0; + } + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto raw_input = input[idx + row_start]; + const auto bin = static_cast(convert_to_uint8(raw_input)); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + const auto pos = ::atomicAdd(&s_num_input[0], 1); + /// NOTE: (dark) fuse the histogram computation here + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + s_input_idx[0][pos] = idx; + const auto bin = convert_to_uint32(raw_input); + const auto sub_bin = (bin >> 24) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } + } + __syncthreads(); + } + + // stage 2: refine with 8bit radix passes + #pragma unroll 4 + for (int round = 0; round < 4; ++round) { + __shared__ int s_last_remain; + const auto r_idx = round % 2; + + // clip here to prevent overflow + const auto _raw_num_input = s_num_input[r_idx]; + const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) ? 
_raw_num_input : int(SMEM_INPUT_SIZE); + + run_cumsum(); + if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { + s_threshold_bin_id = tx; + s_num_input[r_idx ^ 1] = 0; + s_last_remain = topk - s_histogram[tx + 1]; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + topk -= s_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(input[idx + row_start]) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + break; + } else { + __syncthreads(); + if (tx < RADIX + 1) { + s_histogram[tx] = 0; + } + __syncthreads(); + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto raw_input = input[idx + row_start]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + if (round == 3) { + const auto pos = ::atomicAdd(&s_last_remain, -1); + if (pos > 0) { + index[TopK - pos] = idx; + } + } else { + const auto pos = ::atomicAdd(&s_num_input[r_idx ^ 1], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + /// NOTE: (dark) fuse the histogram computation here + s_input_idx[r_idx ^ 1][pos] = idx; + const auto bin = convert_to_uint32(raw_input); + const auto sub_bin = (bin >> (offset - 8)) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } + } + } + __syncthreads(); + } + } + } + + __global__ __launch_bounds__(kThreadsPerBlock) // topk + void topk_kernel(const FastTopKParams params) { + const auto& [input, row_starts, indices, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto row_start = row_starts == nullptr ? 
0 : row_starts[bid]; + const auto length = lengths[bid]; + const auto indice = indices + bid * TopK; + const auto score = input + bid * input_stride; + if (length <= TopK) { + return naive_topk_cuda(score, indice, length); + } else { + return fast_topk_cuda_tl(score, indice, row_start, length); + } + } + + __global__ __launch_bounds__(kThreadsPerBlock) // decode + void topk_transform_decode_kernel( + const FastTopKParams params, + int32_t* __restrict__ dst_page_table, + const int32_t* __restrict__ src_page_table, + const int64_t src_stride) { + const auto& [input, _1, _2, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto tid = threadIdx.x; + const auto row_start = 0; + const auto length = lengths[bid]; + const auto src_page_entry = src_page_table + bid * src_stride; + const auto dst_page_entry = dst_page_table + bid * TopK; + const auto score = input + bid * input_stride; + if (length <= TopK) { + return naive_topk_transform(score, length, dst_page_entry, src_page_entry); + } else { + __shared__ int s_indices[TopK]; + fast_topk_cuda_tl(score, s_indices, row_start, length); + // copy src[s_indices] to dst, we manually unroll here + static_assert(TopK % kThreadsPerBlock == 0); + static_assert(TopK / kThreadsPerBlock == 2); + const auto idx_0 = tid; + const auto pos_0 = s_indices[idx_0]; + dst_page_entry[idx_0] = src_page_entry[pos_0]; + const auto idx_1 = tid + kThreadsPerBlock; + const auto pos_1 = s_indices[idx_1]; + dst_page_entry[idx_1] = src_page_entry[pos_1]; + } + } + + __global__ __launch_bounds__(kThreadsPerBlock) // prefill + void topk_transform_prefill_kernel( + const FastTopKParams params, + int32_t* __restrict__ dst_page_table, + const int32_t* __restrict__ src_page_table, + const int64_t src_stride, + const int32_t* __restrict__ cu_seqlens_q, + const int64_t prefill_bs) { + const auto& [input, row_starts, _, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto tid = 
threadIdx.x; + const auto length = lengths[bid]; + const auto row_start = row_starts == nullptr ? 0 : row_starts[bid]; + const auto dst_page_entry = dst_page_table + bid * TopK; + const auto score = input + bid * input_stride; + + /// NOTE: prefill bs is usually small, we can just use a simple loop here + /// We ensure that last cu_seqlens is equal to number of blocks launched + __shared__ const int32_t* s_src_page_entry; + if (C10_LIKELY(prefill_bs <= kThreadsPerBlock)) { + if (tid < prefill_bs) { + if (bid >= cu_seqlens_q[tid] && bid < cu_seqlens_q[tid + 1]) { + s_src_page_entry = src_page_table + tid * src_stride; + } + } + } else { + for (int64_t i = tid; i < prefill_bs; i += kThreadsPerBlock) { + if (bid >= cu_seqlens_q[i] && bid < cu_seqlens_q[i + 1]) { + s_src_page_entry = src_page_table + i * src_stride; + } + } + } + __syncthreads(); + const auto src_page_entry = s_src_page_entry; + + if (length <= TopK) { + return naive_topk_transform(score, length, dst_page_entry, src_page_entry); + } else { + __shared__ int s_indices[TopK]; + fast_topk_cuda_tl(score, s_indices, row_start, length); + // copy src[s_indices] to dst, we manually unroll here + static_assert(TopK % kThreadsPerBlock == 0); + static_assert(TopK / kThreadsPerBlock == 2); + const auto idx_0 = tid; + const auto pos_0 = s_indices[idx_0]; + dst_page_entry[idx_0] = src_page_entry[pos_0]; + const auto idx_1 = tid + kThreadsPerBlock; + const auto pos_1 = s_indices[idx_1]; + dst_page_entry[idx_1] = src_page_entry[pos_1]; + } + } + + __global__ __launch_bounds__(kThreadsPerBlock) // prefill, ragged kv + void topk_transform_prefill_ragged_kernel( + const FastTopKParams params, + int32_t* __restrict__ topk_indices_ragged, + const int32_t* __restrict__ topk_indices_offset) { + const auto& [input, row_starts, _, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto tid = threadIdx.x; + const auto row_start = row_starts == nullptr ? 
0 : row_starts[bid]; + const auto length = lengths[bid]; + const auto dst_indices_entry = topk_indices_ragged + bid * TopK; + const auto score = input + bid * input_stride; + const auto offset = topk_indices_offset[bid]; + + if (length <= TopK) { + return naive_topk_transform_ragged(score, length, dst_indices_entry, offset); + } else { + __shared__ int s_indices[TopK]; + fast_topk_cuda_tl(score, s_indices, row_start, length); + // copy src[s_indices] to dst, we manually unroll here + static_assert(TopK % kThreadsPerBlock == 0); + static_assert(TopK / kThreadsPerBlock == 2); + const auto idx_0 = tid; + const auto pos_0 = s_indices[idx_0]; + dst_indices_entry[idx_0] = pos_0 + offset; + const auto idx_1 = tid + kThreadsPerBlock; + const auto pos_1 = s_indices[idx_1]; + dst_indices_entry[idx_1] = pos_1 + offset; + } + } + + auto get_params( + const at::Tensor& score, + const at::Tensor& lengths, + std::optional row_starts_opt = std::nullopt, + std::optional indices_opt = std::nullopt) -> FastTopKParams { + const auto B = score.size(0); + TORCH_CHECK(score.dim() == 2 && score.stride(1) == 1); + if (row_starts_opt.has_value()) { + const auto& row_starts = row_starts_opt.value(); + TORCH_CHECK(row_starts.dim() == 1); + TORCH_CHECK(row_starts.size(0) == B); + } + TORCH_CHECK(lengths.dim() == 1 && lengths.is_contiguous()); + TORCH_CHECK(lengths.size(0) == B); + int32_t* indices_data_ptr = nullptr; + if (indices_opt.has_value()) { + const auto& indices = indices_opt.value(); + TORCH_CHECK(indices.dim() == 2 && indices.is_contiguous()); + TORCH_CHECK(indices.size(0) == B); + TORCH_CHECK(indices.size(1) == TopK); + indices_data_ptr = indices.data_ptr(); + } + + return FastTopKParams{ + .input = score.data_ptr(), + .row_starts = row_starts_opt.has_value() ? 
row_starts_opt->data_ptr() : nullptr, + .indices = indices_data_ptr, + .lengths = lengths.data_ptr(), + .input_stride = score.stride(0), + }; + } + + template + void setup_kernel_smem_once() { + [[maybe_unused]] + static const auto result = [] { + #ifdef USE_ROCM + // hipify will turn cudaFuncSetAttribute -> hipFuncSetAttribute. On ROCm, + // hipFuncSetAttribute expects `const void*` and hipcc does not accept passing + // a function pointer directly, so cast explicitly. + return ::cudaFuncSetAttribute( + reinterpret_cast(f), ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); + #else + // CUDA: keep original behavior (no cast needed). + return ::cudaFuncSetAttribute(f, ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); + #endif + }(); + TORCH_CHECK(result == cudaSuccess, "set_up_kernel_once failed:", ::cudaGetErrorString(result)); + } + + // ====================================================================== + // Vortex integration: BOS/EOS-aware segmented TopK with index remapping + // ====================================================================== + + template + __device__ __forceinline__ float vortex_to_float(T x); + + template <> + __device__ __forceinline__ float vortex_to_float(float x) { return x; } + + template <> + __device__ __forceinline__ float vortex_to_float<__nv_bfloat16>(__nv_bfloat16 x) { + return __bfloat162float(x); + } + + constexpr int VORTEX_MAX_TOPK = 2048; + + // Templated version of fast_topk_cuda_tl: + // - ScoreT: float or __nv_bfloat16 + // - target_k: runtime parameter (replaces compile-time TopK) + template + __device__ void fast_topk_vortex( + const ScoreT* __restrict__ input, + int* __restrict__ index, + int row_start, + int length, + int target_k) + { + int topk = target_k; + constexpr auto BLOCK_SIZE = 1024; + constexpr auto RADIX = 256; + constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); + + alignas(128) __shared__ int vh_histogram_buf[2][RADIX + 128]; + alignas(128) __shared__ int 
vh_counter; + alignas(128) __shared__ int vh_threshold_bin_id; + alignas(128) __shared__ int vh_num_input[2]; + + auto& vh_histogram = vh_histogram_buf[0]; + extern __shared__ int vh_input_idx[][SMEM_INPUT_SIZE]; + + const int tx = threadIdx.x; + + // Stage 1: 8-bit coarse histogram + if (tx < RADIX + 1) vh_histogram[tx] = 0; + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = convert_to_uint8(vortex_to_float(input[idx + row_start])); + ::atomicAdd(&vh_histogram[bin], 1); + } + __syncthreads(); + + const auto run_cumsum = [&] { + #pragma unroll 8 + for (int i = 0; i < 8; ++i) { + static_assert(1 << 8 == RADIX); + if (C10_LIKELY(tx < RADIX)) { + const auto j = 1 << i; + const auto k = i & 1; + auto value = vh_histogram_buf[k][tx]; + if (tx < RADIX - j) { + value += vh_histogram_buf[k][tx + j]; + } + vh_histogram_buf[k ^ 1][tx] = value; + } + __syncthreads(); + } + }; + + run_cumsum(); + if (tx < RADIX && vh_histogram[tx] > topk && vh_histogram[tx + 1] <= topk) { + vh_threshold_bin_id = tx; + vh_num_input[0] = 0; + vh_counter = 0; + } + __syncthreads(); + + const auto threshold_bin = vh_threshold_bin_id; + topk -= vh_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = static_cast( + convert_to_uint8(vortex_to_float(input[idx + row_start]))); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&vh_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + return; + } else { + __syncthreads(); + if (tx < RADIX + 1) vh_histogram[tx] = 0; + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto raw_input = vortex_to_float(input[idx + row_start]); + const auto bin = static_cast(convert_to_uint8(raw_input)); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&vh_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + const auto pos = ::atomicAdd(&vh_num_input[0], 1); + if (C10_LIKELY(pos < 
SMEM_INPUT_SIZE)) { + vh_input_idx[0][pos] = idx; + const auto b32 = convert_to_uint32(raw_input); + const auto sub_bin = (b32 >> 24) & 0xFF; + ::atomicAdd(&vh_histogram[sub_bin], 1); + } + } + } + __syncthreads(); + } + + // Stage 2: refine with 8-bit radix passes + #pragma unroll 4 + for (int round = 0; round < 4; ++round) { + __shared__ int vh_last_remain; + const auto r_idx = round % 2; + + const auto _raw_num_input = vh_num_input[r_idx]; + const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) + ? _raw_num_input + : int(SMEM_INPUT_SIZE); + + run_cumsum(); + if (tx < RADIX && vh_histogram[tx] > topk && vh_histogram[tx + 1] <= topk) { + vh_threshold_bin_id = tx; + vh_num_input[r_idx ^ 1] = 0; + vh_last_remain = topk - vh_histogram[tx + 1]; + } + __syncthreads(); + + const auto threshold_bin = vh_threshold_bin_id; + topk -= vh_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = vh_input_idx[r_idx][i]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32( + vortex_to_float(input[idx + row_start])) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&vh_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + break; + } else { + __syncthreads(); + if (tx < RADIX + 1) vh_histogram[tx] = 0; + __syncthreads(); + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = vh_input_idx[r_idx][i]; + const auto raw_input = vortex_to_float(input[idx + row_start]); + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&vh_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + if (round == 3) { + const auto pos = ::atomicAdd(&vh_last_remain, -1); + if (pos > 0) { + index[target_k - pos] = idx; + } + } else { + const auto pos = ::atomicAdd(&vh_num_input[r_idx ^ 1], 1); + if (C10_LIKELY(pos < 
SMEM_INPUT_SIZE)) { + vh_input_idx[r_idx ^ 1][pos] = idx; + const auto b32 = convert_to_uint32(raw_input); + const auto sub_bin = (b32 >> (offset - 8)) & 0xFF; + ::atomicAdd(&vh_histogram[sub_bin], 1); + } + } + } + } + __syncthreads(); + } + } + } + + // Wrapper kernel: one CUDA block per batch*head segment + template + __global__ __launch_bounds__(kThreadsPerBlock) + void TopKOutput_Kernel( + const ScoreT* __restrict__ score, + const int* __restrict__ dense_kv_indptr, + const int* __restrict__ sparse_kv_indptr, + const int* __restrict__ dense_kv_indices, + int* __restrict__ sparse_kv_indices, + const int topk_val, + const int page_reserved_bos, + const int page_reserved_eos) + { + const int bx = blockIdx.x; + + const int start = dense_kv_indptr[bx] + page_reserved_bos; + const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; + const int nblk = end - start; + if (nblk <= topk_val) return; + + const ScoreT* __restrict__ score_blk = score + start; + const int* __restrict__ idx_blk = dense_kv_indices + start; + int* __restrict__ out_blk = sparse_kv_indices + + sparse_kv_indptr[bx] + + page_reserved_bos; + + __shared__ int s_indices[VORTEX_MAX_TOPK]; + fast_topk_vortex(score_blk, s_indices, 0, nblk, topk_val); + __syncthreads(); + + // Remap position indices -> page indices via dense_kv_indices + const int tx = threadIdx.x; + for (int i = tx; i < topk_val; i += kThreadsPerBlock) { + out_blk[i] = idx_blk[s_indices[i]]; + } + } + + } // namespace + + #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") + + void fast_topk_interface( + const at::Tensor& score, at::Tensor& indices, const at::Tensor& lengths, std::optional row_starts_opt) { + CHECK_CUDA(score); + CHECK_CUDA(indices); + if (row_starts_opt.has_value()) { + CHECK_CUDA(row_starts_opt.value()); + } + CHECK_CUDA(lengths); + const auto params = get_params(score, lengths, row_starts_opt, indices); + const auto B = score.size(0); + const auto stream = 
at::cuda::getCurrentCUDAStream().stream(); + const auto grid = dim3{static_cast(B)}; + const auto block = dim3{kThreadsPerBlock}; + setup_kernel_smem_once(); + topk_kernel<<>>(params); + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); + } + + void fast_topk_transform_interface( + const at::Tensor& score, + const at::Tensor& lengths, + at::Tensor& dst_page_table, + const at::Tensor& src_page_table, + const at::Tensor& cu_seqlens_q, + std::optional row_starts_opt) { + CHECK_CUDA(score); + CHECK_CUDA(lengths); + CHECK_CUDA(dst_page_table); + CHECK_CUDA(src_page_table); + CHECK_CUDA(cu_seqlens_q); + if (row_starts_opt.has_value()) { + CHECK_CUDA(row_starts_opt.value()); + } + const auto params = get_params(score, lengths, row_starts_opt); + const auto B = score.size(0); + TORCH_CHECK(dst_page_table.dim() == 2 && dst_page_table.is_contiguous()); + TORCH_CHECK(src_page_table.dim() == 2 && src_page_table.stride(1) == 1); + TORCH_CHECK(cu_seqlens_q.dim() == 1 && cu_seqlens_q.is_contiguous()); + const auto prefill_bs = cu_seqlens_q.size(0) - 1; + TORCH_CHECK(dst_page_table.size(0) == B); + TORCH_CHECK(dst_page_table.size(1) == TopK); + TORCH_CHECK(src_page_table.size(0) == prefill_bs); + TORCH_CHECK(prefill_bs <= B); // prefill_bs should be smaller than expanded bs + + // launch kernel + const auto stream = at::cuda::getCurrentCUDAStream().stream(); + const auto grid = dim3{static_cast(B)}; + const auto block = dim3{kThreadsPerBlock}; + const auto src_stride = src_page_table.stride(0); + + // dispatch to decode or prefill + // extend and draft extend: row_starts_opt is not null, invokes the prefill kernel + // decode: row_starts_opt is null, invokes the decode kernel + // target verify: row_starts_opt is null, invokes the prefill kernel + const auto is_decode = !row_starts_opt.has_value() && prefill_bs == B; + if (is_decode) { + setup_kernel_smem_once(); + topk_transform_decode_kernel<<>>( + 
params, dst_page_table.data_ptr(), src_page_table.data_ptr(), src_stride); + } else { + setup_kernel_smem_once(); + topk_transform_prefill_kernel<<>>( + params, + dst_page_table.data_ptr(), + src_page_table.data_ptr(), + src_stride, + cu_seqlens_q.data_ptr(), + prefill_bs); + } + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); + } + + void fast_topk_transform_ragged_interface( + const at::Tensor& score, + const at::Tensor& lengths, + at::Tensor& topk_indices_ragged, + const at::Tensor& topk_indices_offset, + std::optional row_starts_opt) { + CHECK_CUDA(score); + CHECK_CUDA(lengths); + CHECK_CUDA(topk_indices_ragged); + CHECK_CUDA(topk_indices_offset); + if (row_starts_opt.has_value()) { + CHECK_CUDA(row_starts_opt.value()); + } + + const auto params = get_params(score, lengths, row_starts_opt); + const auto B = score.size(0); + TORCH_CHECK(topk_indices_ragged.dim() == 2 && topk_indices_ragged.is_contiguous()); + TORCH_CHECK(topk_indices_offset.dim() == 1); + + TORCH_CHECK(topk_indices_ragged.size(0) == B); + TORCH_CHECK(topk_indices_ragged.size(1) == TopK); + TORCH_CHECK(topk_indices_offset.size(0) == B); + + // launch kernel + const auto stream = at::cuda::getCurrentCUDAStream().stream(); + const auto grid = dim3{static_cast(B)}; + const auto block = dim3{kThreadsPerBlock}; + + setup_kernel_smem_once(); + topk_transform_prefill_ragged_kernel<<>>( + params, topk_indices_ragged.data_ptr(), topk_indices_offset.data_ptr()); + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); + } + + // ====================================================================== + // Vortex host entry point — same interface as topk_output in topk.cu + // ====================================================================== + void topk_output_sglang( + const at::Tensor& x, + const at::Tensor& dense_kv_indptr, + const 
at::Tensor& sparse_kv_indptr, + const at::Tensor& dense_kv_indices, + at::Tensor& sparse_kv_indices, + const int64_t eff_batch_size, + const int64_t topk_val, + const int64_t reserved_bos, + const int64_t reserved_eos, + const int64_t max_num_pages) + { + TORCH_CHECK(topk_val <= VORTEX_MAX_TOPK, + "topk_output: topk_val (", topk_val, + ") exceeds VORTEX_MAX_TOPK (", VORTEX_MAX_TOPK, ")"); + + dim3 nblks(eff_batch_size); + dim3 nthreads(kThreadsPerBlock); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + + if (x.scalar_type() == at::ScalarType::BFloat16) { + setup_kernel_smem_once, kSmem>(); + TopKOutput_Kernel<__nv_bfloat16><<>>( + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + topk_val, + reserved_bos, + reserved_eos); + } else if (x.scalar_type() == at::ScalarType::Float) { + setup_kernel_smem_once, kSmem>(); + TopKOutput_Kernel<<>>( + x.data_ptr(), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + topk_val, + reserved_bos, + reserved_eos); + } else { + TORCH_CHECK(false, + "topk_output: unsupported dtype ", + x.scalar_type()); + } + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, + "topk_output kernel failed: ", ::cudaGetErrorString(result)); + } \ No newline at end of file diff --git a/examples/verify_aim24.py b/examples/verify_aim24.py new file mode 100644 index 0000000..5152680 --- /dev/null +++ b/examples/verify_aim24.py @@ -0,0 +1,111 @@ +import json +import sys +sys.path.append("../") +import python.sglang as sgl +from transformers import AutoTokenizer +import os +from tqdm import tqdm +import time +import torch +os.environ["TOKENIZERS_PARALLELISM"] = "false" +MATH_QUERY_TEMPLATE = """ +Solve the following math problem efficiently and clearly. 
The last line of your response should be of the following format: 'Therefore, the final answer is: $\\boxed{{ANSWER}}$. I hope it is correct' (without quotes) where ANSWER is just the final number or expression that solves the problem. Think step by step before answering. + +{Question} +""".strip() + +from datasets import load_dataset, Dataset, concatenate_datasets +def generate_requests(dataset: Dataset, field_name: str, data_format: str, trial: int = 1, rank: int = 0, world_size: int = 1): + requests = [] + + # Step 1: Expand dataset trial times + if trial > 1: + dataset = Dataset.from_dict(dataset.to_dict().copy())  # ensure copy + datasets = [dataset] * trial + dataset = concatenate_datasets(datasets) + + total = len(dataset) + + # Step 2: Partition across ranks + per_proc = total // world_size + remainder = total % world_size + start = rank * per_proc + min(rank, remainder) + end = start + per_proc + (1 if rank < remainder else 0) + subset = dataset.select(list(range(start, end))) + + # Step 3: Format requests + for data in dataset: + conversations = [ + {"role": "user", "content": data_format.format(Question=data[field_name])} + ] + data["conversations"] = conversations + requests.append(data) + + return requests + + + + + + + +def main(): + model_name = "Qwen/Qwen3-0.6B" + llm = sgl.Engine(model_path=model_name, + disable_cuda_graph=False, + page_size=16, + vortex_num_selected_pages=29, + disable_overlap_schedule=True, + attention_backend="flashinfer", + enable_vortex_sparsity=True, + vortex_page_reserved_bos=1, + vortex_page_reserved_eos=2, + vortex_layers_skip=list(range(1)), + mem_fraction_static=0.9, + vortex_cg=True, + vortex_graph=True, + vortex_module_name="block_sparse_attention", + vortex_max_seq_lens=20480 + ) + + dataset = load_dataset("HuggingFaceH4/aime_2024", split="train") + + requests = generate_requests(dataset, "problem", MATH_QUERY_TEMPLATE) + + + + texts = [ + x["conversations"] for x in requests + ] + + tokenizer = 
AutoTokenizer.from_pretrained(model_name) + prompts = [ + tokenizer.apply_chat_template( + text, + tokenize=False, + add_generation_prompt=True, + enable_thinking=True + ) for text in texts + ] * 8 + + sampling_params = {"temperature": 0.6, "top_p": 0.95, "top_k": 20, "max_new_tokens": 16384} + total_tokens = 0 + total_time = 0.0 + start = time.perf_counter() + o = llm.generate(prompts, sampling_params) + elapsed = time.perf_counter() - start + total_time += elapsed + e2e_time = 0 + with open(f"0.6B_VTX_CG_TP1_16K.jsonl", "w", encoding="utf-8") as f: + for item in o: + total_tokens += item["meta_info"]["completion_tokens"] + e2e_time = max(e2e_time, item["meta_info"]["e2e_latency"]) + json.dump(item, f, ensure_ascii=False) + f.write("\n") + + meta_data = {"e2e_time": e2e_time, "total_time": total_time, "total_tokens": total_tokens, "throughput": total_tokens / total_time} + json.dump(meta_data, f, ensure_ascii=False) + f.write("\n") + +if __name__ == "__main__": + main() \ No newline at end of file From 66237d748ee2c69977c208c5a64faa2d382ffd79 Mon Sep 17 00:00:00 2001 From: UED Date: Tue, 24 Mar 2026 18:22:40 +0000 Subject: [PATCH 09/22] Implement sparse prefill with topk on a new ragged only warpper --- examples/verify_aim24.py | 5 ---- examples/verify_algo.sh | 2 +- examples/verify_algo_int8.sh | 25 ------------------- ...erify_algo_fp8.sh => verify_algo_quant.sh} | 18 +++++++++++-- vortex_torch/cache/context.py | 4 +++ vortex_torch/flow/flow.py | 1 + 6 files changed, 22 insertions(+), 33 deletions(-) delete mode 100644 examples/verify_algo_int8.sh rename examples/{verify_algo_fp8.sh => verify_algo_quant.sh} (55%) mode change 100755 => 100644 diff --git a/examples/verify_aim24.py b/examples/verify_aim24.py index 5152680..9e54a96 100644 --- a/examples/verify_aim24.py +++ b/examples/verify_aim24.py @@ -44,11 +44,6 @@ def generate_requests(dataset: Dataset, field_name: str, data_format: str, trial return requests - - - - - def main(): model_name = "Qwen/Qwen3-0.6B" 
llm = sgl.Engine(model_path=model_name, diff --git a/examples/verify_algo.sh b/examples/verify_algo.sh index 73ac2f4..8416e54 100644 --- a/examples/verify_algo.sh +++ b/examples/verify_algo.sh @@ -1,6 +1,6 @@ #!/usr/bin/env bash set -e -# export CUDA_VISIBLE_DEVICES=0 +export CUDA_VISIBLE_DEVICES=1 sparse_algos=( "block_sparse_attention" diff --git a/examples/verify_algo_int8.sh b/examples/verify_algo_int8.sh deleted file mode 100644 index bf24c2d..0000000 --- a/examples/verify_algo_int8.sh +++ /dev/null @@ -1,25 +0,0 @@ -#!/usr/bin/env bash -set -e -# export CUDA_VISIBLE_DEVICES=0 - -sparse_algos=( - "block_sparse_attention" -) - -RESULTS_DIR="results" -mkdir -p "${RESULTS_DIR}" -TIMESTAMP=$(date +%Y%m%d_%H%M%S) - - for algo in "${sparse_algos[@]}"; do - OUTFILE="${RESULTS_DIR}/${algo}_int8_${TIMESTAMP}.log" - echo ">>> Running verify_algo.py with --vortex-module-name ${algo} --kv-cache-dtype int8" - echo ">>> Saving results to ${OUTFILE}" - { time python verify_algo.py \ - --trials 8 \ - --topk-val 30 \ - --vortex-module-name "${algo}" \ - --model-name Qwen/Qwen3-1.7B \ - --kv-cache-dtype int8 \ - --mem 0.7 ; } \ - 2>&1 | tee "${OUTFILE}" - done \ No newline at end of file diff --git a/examples/verify_algo_fp8.sh b/examples/verify_algo_quant.sh old mode 100755 new mode 100644 similarity index 55% rename from examples/verify_algo_fp8.sh rename to examples/verify_algo_quant.sh index c0b8814..c344474 --- a/examples/verify_algo_fp8.sh +++ b/examples/verify_algo_quant.sh @@ -1,6 +1,6 @@ #!/usr/bin/env bash set -e -# export CUDA_VISIBLE_DEVICES=0 +export CUDA_VISIBLE_DEVICES=0 sparse_algos=( "block_sparse_attention" @@ -11,6 +11,20 @@ mkdir -p "${RESULTS_DIR}" TIMESTAMP=$(date +%Y%m%d_%H%M%S) for algo in "${sparse_algos[@]}"; do + OUTFILE="${RESULTS_DIR}/${algo}_int8_${TIMESTAMP}.log" + echo ">>> Running verify_algo.py with --vortex-module-name ${algo} --kv-cache-dtype int8" + echo ">>> Saving results to ${OUTFILE}" + { time python verify_algo.py \ + --trials 8 \ + 
--topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --kv-cache-dtype int8 \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" + done + + for algo in "${sparse_algos[@]}"; do OUTFILE="${RESULTS_DIR}/${algo}_fp8_${TIMESTAMP}.log" echo ">>> Running verify_algo.py with --vortex-module-name ${algo} --kv-cache-dtype fp8_e4m3" echo ">>> Saving results to ${OUTFILE}" @@ -22,4 +36,4 @@ TIMESTAMP=$(date +%Y%m%d_%H%M%S) --kv-cache-dtype fp8_e4m3 \ --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" - done + done \ No newline at end of file diff --git a/vortex_torch/cache/context.py b/vortex_torch/cache/context.py index 0e7171c..3cdf095 100644 --- a/vortex_torch/cache/context.py +++ b/vortex_torch/cache/context.py @@ -24,9 +24,11 @@ class Context(ContextBase): # Quantization: quant_type (0=none, 1=int8, 2=e4m3, 3=e5m2), # kv_scale (per-tensor fp8 scale), kv_scale_ptr (per-token int8 scale tensor) + # fp8_type: 0=none, 1=e4m3, 2=e5m2 (encoding for Triton kernels) "quant_type", "kv_scale", "kv_scale_ptr", + "fp8_type", ) @@ -49,6 +51,8 @@ def __init__(self) -> None: object.__setattr__(self, name, 1.0) # identity scale for bf16 elif name == "kv_scale_ptr": object.__setattr__(self, name, None) # per-token scale tensor (int8 only) + elif name == "fp8_type": + object.__setattr__(self, name, 0) # 0 = none (bf16 default) else: object.__setattr__(self, name, UNSET) diff --git a/vortex_torch/flow/flow.py b/vortex_torch/flow/flow.py index 7efc80e..7da5c72 100644 --- a/vortex_torch/flow/flow.py +++ b/vortex_torch/flow/flow.py @@ -431,6 +431,7 @@ def run_indexer_virtual(self, group_size: int, page_size: int, head_dim: int): ctx.page_size = page_size ctx.max_num_pages = 0 ctx.max_num_pages_per_request = 0 + ctx.topk_type = "naive" device = "cuda" dtype = torch.bfloat16 From 9a73a8cccf635a055564d5f5fb155d152854748c Mon Sep 17 00:00:00 2001 From: UED Date: Sun, 29 Mar 2026 05:01:09 +0000 Subject: [PATCH 10/22] fix on the ragged warpper, using single ragged warpper on concated 
rags and pages; fix on the previous quantization implementation, with launch_graph dtype set to the quant type --- examples/verify_algo.py | 17 +++++++++++------ examples/verify_algo.sh | 7 +++++-- third_party/sglang | 2 +- vortex_torch/flow/__init__.py | 4 +++- vortex_torch/indexer/output_func.py | 4 ++-- 5 files changed, 22 insertions(+), 12 deletions(-) diff --git a/examples/verify_algo.py b/examples/verify_algo.py index 1187aca..91f92e7 100644 --- a/examples/verify_algo.py +++ b/examples/verify_algo.py @@ -142,12 +142,17 @@ def verify_algos( if sparse_attention: llm_cfg = AutoConfig.from_pretrained(model_name) - flow = vortex_torch.flow.build_vflow(vortex_module_name) - memory_access_runtime = flow.run_indexer_virtual( - group_size=llm_cfg.num_attention_heads // llm_cfg.num_key_value_heads, - page_size=page_size, - head_dim=llm_cfg.head_dim, - ) + flow = vortex_torch.flow.build_vflow(vortex_module_name) + try: + memory_access_runtime = flow.run_indexer_virtual( + group_size=llm_cfg.num_attention_heads // llm_cfg.num_key_value_heads, + page_size=page_size, + head_dim=llm_cfg.head_dim, + ) + except Exception: + # External algorithms (nsa, fsa, flash_moba) override run_indexer_virtual + # to return 0 since their vendored kernels don't participate in vortex profiling + memory_access_runtime = 0.0 else: memory_access_runtime = 0.0 diff --git a/examples/verify_algo.sh b/examples/verify_algo.sh index 8416e54..cc174b4 100644 --- a/examples/verify_algo.sh +++ b/examples/verify_algo.sh @@ -1,9 +1,12 @@ #!/usr/bin/env bash set -e -export CUDA_VISIBLE_DEVICES=1 +export CUDA_VISIBLE_DEVICES=7 sparse_algos=( "block_sparse_attention" + "nsa" + "fsa" + "flash_moba" ) RESULTS_DIR="results" @@ -19,7 +22,7 @@ TIMESTAMP=$(date +%Y%m%d_%H%M%S) --topk-val 30 \ --vortex-module-name "${algo}" \ --model-name Qwen/Qwen3-1.7B \ - --topk-type sglang \ + --topk-type naive \ --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" done \ No newline at end of file diff --git a/third_party/sglang
b/third_party/sglang index 7105719..20e4c29 160000 --- a/third_party/sglang +++ b/third_party/sglang @@ -1 +1 @@ -Subproject commit 7105719f0a2ac464ee7ffdc0a899fa6a656656a2 +Subproject commit 20e4c29d206046d6b4eb3b57cc26fd20bf9c519b diff --git a/vortex_torch/flow/__init__.py b/vortex_torch/flow/__init__.py index b2fcadc..bb60b89 100644 --- a/vortex_torch/flow/__init__.py +++ b/vortex_torch/flow/__init__.py @@ -34,9 +34,11 @@ class BlockSparseAttention(vFlow): from .registry import register from .loader import build_vflow from . import algorithms +from . import external_algorithms __all__ = [ "vFlow", "register", "build_vflow", - "algorithms" + "algorithms", + "external_algorithms", ] \ No newline at end of file diff --git a/vortex_torch/indexer/output_func.py b/vortex_torch/indexer/output_func.py index f7d0d9c..8859d61 100644 --- a/vortex_torch/indexer/output_func.py +++ b/vortex_torch/indexer/output_func.py @@ -245,12 +245,12 @@ def execute(self, x: torch.Tensor, o: torch.Tensor, ctx: Context) -> torch.Tenso ctx.max_num_pages_per_request, ) else: - # topk_output (naive): (x, dense_kv_indptr, dense_kv_indices, sparse_kv_indptr, sparse_kv_indices, ...) + # topk_output (naive): (x, dense_kv_indptr, sparse_kv_indptr, dense_kv_indices, sparse_kv_indices, ...) 
self.impl( x, ctx.dense_kv_indptr, - ctx.dense_kv_indices, ctx.sparse_kv_indptr, + ctx.dense_kv_indices, o, ctx.batch_size * ctx.num_kv_heads, ctx.topk_val, From a8fd32854d62ec677a3467c6f29557bf9540a2cd Mon Sep 17 00:00:00 2001 From: UED Date: Mon, 30 Mar 2026 02:35:40 +0000 Subject: [PATCH 11/22] Sparse attention kernel adaptation with full attention kernels, including (naive sparse attention, flash sparse attention, flashmoba) --- examples/verify_algo.sh | 5 +- examples/verify_algo_topk.sh | 44 + examples/verify_sparse_backends.sh | 27 + vortex_torch/attention_backend/__init__.py | 3 + .../attention_backend/flashmoba/__init__.py | 13 + .../flashmoba/flash_moba_interface.py | 730 ++++++ .../flashmoba/triton_mean_pool.py | 158 ++ .../fsa/FSA_topk_sparse_attention.py | 2040 +++++++++++++++++ .../attention_backend/fsa/__init__.py | 9 + .../attention_backend/nsa/__init__.py | 9 + .../nsa/topk_sparse_attention.py | 1280 +++++++++++ vortex_torch/flow/external_algorithms.py | 76 + vortex_torch/kernels/__init__.py | 0 vortex_torch/kernels/fsa/__init__.py | 5 + .../kernels/fsa/fused_score_kernels.py | 300 +++ vortex_torch/kernels/nsa/__init__.py | 24 + .../kernels/nsa/compressed_attention.py | 1317 +++++++++++ vortex_torch/kernels/nsa/flash_attention.py | 886 +++++++ vortex_torch/kernels/nsa/utils.py | 50 + vortex_torch/kernels/nsa/weighted_pool.py | 341 +++ 20 files changed, 7313 insertions(+), 4 deletions(-) create mode 100644 examples/verify_algo_topk.sh create mode 100755 examples/verify_sparse_backends.sh create mode 100644 vortex_torch/attention_backend/__init__.py create mode 100644 vortex_torch/attention_backend/flashmoba/__init__.py create mode 100644 vortex_torch/attention_backend/flashmoba/flash_moba_interface.py create mode 100644 vortex_torch/attention_backend/flashmoba/triton_mean_pool.py create mode 100644 vortex_torch/attention_backend/fsa/FSA_topk_sparse_attention.py create mode 100644
vortex_torch/attention_backend/nsa/__init__.py create mode 100644 vortex_torch/attention_backend/nsa/topk_sparse_attention.py create mode 100644 vortex_torch/flow/external_algorithms.py create mode 100644 vortex_torch/kernels/__init__.py create mode 100644 vortex_torch/kernels/fsa/__init__.py create mode 100644 vortex_torch/kernels/fsa/fused_score_kernels.py create mode 100644 vortex_torch/kernels/nsa/__init__.py create mode 100644 vortex_torch/kernels/nsa/compressed_attention.py create mode 100644 vortex_torch/kernels/nsa/flash_attention.py create mode 100644 vortex_torch/kernels/nsa/utils.py create mode 100644 vortex_torch/kernels/nsa/weighted_pool.py diff --git a/examples/verify_algo.sh b/examples/verify_algo.sh index cc174b4..0dcbe9f 100644 --- a/examples/verify_algo.sh +++ b/examples/verify_algo.sh @@ -1,12 +1,9 @@ #!/usr/bin/env bash set -e -export CUDA_VISIBLE_DEVICES=7 +export CUDA_VISIBLE_DEVICES=5 sparse_algos=( "block_sparse_attention" - "nsa" - "fsa" - "flash_moba" ) RESULTS_DIR="results" diff --git a/examples/verify_algo_topk.sh b/examples/verify_algo_topk.sh new file mode 100644 index 0000000..6b2744a --- /dev/null +++ b/examples/verify_algo_topk.sh @@ -0,0 +1,44 @@ +#!/usr/bin/env bash +set -e +export CUDA_VISIBLE_DEVICES=5 + +sparse_algos=( + "block_sparse_attention" +) + +RESULTS_DIR="results" +REPEAT_COUNT="${REPEAT_COUNT:-3}" +mkdir -p "${RESULTS_DIR}" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) + +for repeat_idx in $(seq 1 "${REPEAT_COUNT}"); do + for algo in "${sparse_algos[@]}"; do + OUTFILE="${RESULTS_DIR}/${algo}_naive_${TIMESTAMP}_run${repeat_idx}.log" + echo ">>> Run ${repeat_idx}/${REPEAT_COUNT}: verify_algo.py with --vortex-module-name ${algo} --topk-type naive" + echo ">>> Saving results to ${OUTFILE}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --topk-type naive \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" + done +done + +for repeat_idx in $(seq 1 
"${REPEAT_COUNT}"); do + for algo in "${sparse_algos[@]}"; do + OUTFILE="${RESULTS_DIR}/${algo}_sglang_${TIMESTAMP}_run${repeat_idx}.log" + echo ">>> Run ${repeat_idx}/${REPEAT_COUNT}: verify_algo.py with --vortex-module-name ${algo} --topk-type sglang" + echo ">>> Saving results to ${OUTFILE}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --topk-type sglang \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" + done +done \ No newline at end of file diff --git a/examples/verify_sparse_backends.sh b/examples/verify_sparse_backends.sh new file mode 100755 index 0000000..81b3562 --- /dev/null +++ b/examples/verify_sparse_backends.sh @@ -0,0 +1,27 @@ +#!/usr/bin/env bash +set -e +export CUDA_VISIBLE_DEVICES=5 + +sparse_algos=( + "nsa" + "fsa" + "flash_moba" +) + +RESULTS_DIR="results" +mkdir -p "${RESULTS_DIR}" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) + +for algo in "${sparse_algos[@]}"; do + OUTFILE="${RESULTS_DIR}/${algo}_bf16_${TIMESTAMP}.log" + echo ">>> Running verify_algo.py with --vortex-module-name ${algo}" + echo ">>> Saving results to ${OUTFILE}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --topk-type naive \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" +done diff --git a/vortex_torch/attention_backend/__init__.py b/vortex_torch/attention_backend/__init__.py new file mode 100644 index 0000000..9ca7855 --- /dev/null +++ b/vortex_torch/attention_backend/__init__.py @@ -0,0 +1,3 @@ +# Vendored sparse attention backends for Vortex forward_extend. +# NSA and FSA are pure Triton kernels. +# FlashMoBA requires flash_moba_cuda C++ extension (pip install flash_moba). 
diff --git a/vortex_torch/attention_backend/flashmoba/__init__.py b/vortex_torch/attention_backend/flashmoba/__init__.py new file mode 100644 index 0000000..aa912b9 --- /dev/null +++ b/vortex_torch/attention_backend/flashmoba/__init__.py @@ -0,0 +1,13 @@ +from .flash_moba_interface import ( + flash_moba_varlen_func, + flash_moba_attn_varlen_func, + flash_topk_varlen_func, + decide_lg_block_m, +) + +__all__ = [ + "flash_moba_varlen_func", + "flash_moba_attn_varlen_func", + "flash_topk_varlen_func", + "decide_lg_block_m", +] diff --git a/vortex_torch/attention_backend/flashmoba/flash_moba_interface.py b/vortex_torch/attention_backend/flashmoba/flash_moba_interface.py new file mode 100644 index 0000000..c196c21 --- /dev/null +++ b/vortex_torch/attention_backend/flashmoba/flash_moba_interface.py @@ -0,0 +1,730 @@ +from typing import Optional, Sequence, Tuple, Union + +import torch +import torch.nn as nn +import os + +try: + import flash_moba_cuda as flash_moba_gpu +except ImportError: + flash_moba_gpu = None +from .triton_mean_pool import flash_topk_mean_pool + +########################################################################################################################## +# Helper functions +########################################################################################################################## + +def round_multiple(x: int, m: int) -> int: + """Round x up to the nearest multiple of m.""" + return ((x + m - 1) // m) * m + +########################################################################################################################## + +def decide_lg_block_m(top_k: int, chunk_size: int, seqlen: int, causal: bool = False) -> int: + sparsity = 0.0 + budget = top_k * chunk_size + if causal: + density = (2*(budget * seqlen) - budget**2) / (seqlen**2) + else: + density = budget / seqlen + + sparsity = 1 - density + + if sparsity <= 0.5: + lg_block_m = 128 + elif sparsity <= 0.7: + lg_block_m = 256 + elif sparsity <= 0.8: + lg_block_m = 
512 + elif sparsity <= 0.9: + lg_block_m = 768 + else: + lg_block_m = 1024 + + # [Optimization] Hardware-aware cap for A6000/3090/4090 to avoid Shared Memory OOM + if torch.cuda.is_available(): + major, minor = torch.cuda.get_device_capability() + # sm86 (A6000, 3090) and sm89 (4090, L40) have smaller shared memory than A100 (sm80) + if major == 8 and minor > 0: + lg_block_m = min(lg_block_m, 512) + + return lg_block_m + +########################################################################################################################## + +# torch.compile() support is only enabled for pytorch >= 2.4 +# The reason for this is that we are using the new custom_op and register_fake +# APIs, which support inplace modification of inputs in the function itself +if torch.__version__ >= "2.4.0": + _torch_custom_op_wrapper = torch.library.custom_op + _torch_register_fake_wrapper = torch.library.register_fake +else: + def noop_custom_op_wrapper(name, fn=None, /, *, mutates_args, device_types=None, schema=None): + def wrap(func): + return func + if fn is None: + return wrap + return fn + def noop_register_fake_wrapper(op, fn=None, /, *, lib=None, _stacklevel=1): + def wrap(func): + return func + if fn is None: + return wrap + return fn + _torch_custom_op_wrapper = noop_custom_op_wrapper + _torch_register_fake_wrapper = noop_register_fake_wrapper + + +def maybe_contiguous(x): + return x.contiguous() if x is not None and x.stride(-1) != 1 else x + + +########################################################################################################################## +# Custom ops +########################################################################################################################## + +@_torch_custom_op_wrapper("flash_moba::_moba_fused_topk", mutates_args=(), device_types="cuda") +def _moba_fused_topk( + q: torch.Tensor, + km: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + cu_seqlens_km: torch.Tensor, + max_seqlen_q: 
int, + max_seqlen_k: int, + # MOBA sparse pattern parameters + moba_topk: int, + moba_chunk_size: int, + causal: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + q, km = [maybe_contiguous(x) for x in (q, km)] + + col_offsets, col_nnz, indices, _, _ = flash_moba_gpu.moba_fused_topk( + q, + km, + cu_seqlens_q, + cu_seqlens_k, + cu_seqlens_km, + max_seqlen_q, + max_seqlen_k, + moba_topk, + moba_chunk_size, + causal, + ) + return col_offsets, col_nnz, indices + +@_torch_register_fake_wrapper("flash_moba::_moba_fused_topk") +def _moba_fused_topk_fake( + q: torch.Tensor, + km: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + cu_seqlens_km: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + # MOBA sparse pattern parameters + moba_topk: int, + moba_chunk_size: int, + causal: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + q, km = [maybe_contiguous(x) for x in (q, km)] + batch_size = cu_seqlens_q.numel() - 1 + total_q, num_heads, _ = q.shape + + max_lg_col_num = (max_seqlen_k + moba_chunk_size - 1) // moba_chunk_size + + col_offsets = torch.empty((batch_size, num_heads, max_lg_col_num), device=q.device, dtype=torch.int64) + col_nnz = torch.empty((batch_size, num_heads, max_lg_col_num), device=q.device, dtype=torch.int32) + indices = torch.empty((total_q * num_heads * moba_topk), device=q.device, dtype=torch.int32) + + return col_offsets, col_nnz, indices + +if torch.__version__ >= "2.4.0": + _wrapped_moba_fused_topk = torch.ops.flash_moba._moba_fused_topk +else: + _wrapped_moba_fused_topk = _moba_fused_topk + +########################################################################################################################## + +@_torch_custom_op_wrapper("flash_moba::_varlen_sort", mutates_args=(), device_types="cuda") +def _varlen_sort( + col_offsets: torch.Tensor, + col_nnz: torch.Tensor, + indices: torch.Tensor, +) -> torch.Tensor: + col_offset_ends = col_offsets.view(-1) + col_nnz.view(-1) 
+ return flash_moba_gpu.varlen_sort( + col_offsets.view(-1), col_offset_ends, indices + ) + +@_torch_register_fake_wrapper("flash_moba::_varlen_sort") +def _varlen_sort_fake( + col_offsets: torch.Tensor, + col_nnz: torch.Tensor, + indices: torch.Tensor, +) -> torch.Tensor: + # varlen_sort is out-of-place + col_offset_ends = col_offsets.view(-1) + col_nnz.view(-1) + return torch.empty_like(indices) + +if torch.__version__ >= "2.4.0": + _wrapped_varlen_sort = torch.ops.flash_moba._varlen_sort +else: + _wrapped_varlen_sort = _varlen_sort + +########################################################################################################################## + +@_torch_custom_op_wrapper("flash_moba::_flash_moba_attn_varlen_forward", mutates_args=(), device_types="cuda") +def _flash_moba_attn_varlen_forward( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + # MOBA sparse pattern parameters + moba_col_offsets: torch.Tensor, + moba_col_nnz: torch.Tensor, + moba_row_indices: torch.Tensor, + lg_block_m: int, + lg_block_n: int, + dropout_p: float, + softmax_scale: float, + causal: bool, + softcap: float = 0.0, + alibi_slopes: Optional[torch.Tensor] = None, + return_softmax: bool = False, + leftpad_k: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, + zero_tensors: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + q, k, v = [maybe_contiguous(x) for x in (q, k, v)] + moba_col_offsets = maybe_contiguous(moba_col_offsets) + moba_col_nnz = maybe_contiguous(moba_col_nnz) + moba_row_indices = maybe_contiguous(moba_row_indices) + + out, softmax_lse, S_dmask, rng_state = flash_moba_gpu.moba_varlen_fwd( + q, + k, + v, + None, + moba_col_offsets, + moba_col_nnz, + moba_row_indices, + cu_seqlens_q, + cu_seqlens_k, + seqused_k, + leftpad_k, + alibi_slopes, + max_seqlen_q, + max_seqlen_k, + dropout_p, + 
softmax_scale, + zero_tensors, + causal, + softcap, + return_softmax, + lg_block_m, + lg_block_n, + None, + ) + # if out.isnan().any() or softmax_lse.isnan().any(): + # breakpoint() + return out, softmax_lse, S_dmask, rng_state + +@_torch_register_fake_wrapper("flash_moba::_flash_moba_attn_varlen_forward") +def _flash_moba_attn_varlen_forward_fake( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + # MOBA sparse pattern parameters + moba_col_offsets: torch.Tensor, + moba_col_nnz: torch.Tensor, + moba_row_indices: torch.Tensor, + lg_block_m: int, + lg_block_n: int, + dropout_p: float, + softmax_scale: float, + causal: bool, + softcap: float = 0.0, + alibi_slopes: Optional[torch.Tensor] = None, + return_softmax: bool = False, + leftpad_k: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, + zero_tensors: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + q, k, v = [maybe_contiguous(x) for x in (q, k, v)] + batch_size = cu_seqlens_q.numel() - 1 + total_q, num_heads, _ = q.shape + + out = torch.empty_like(q) + softmax_lse = torch.empty((num_heads, total_q), dtype=torch.float32, device=q.device, layout=q.layout) + p = torch.empty((0,), dtype=q.dtype, device=q.device, layout=q.layout) + seqlen_q_rounded = round_multiple(max_seqlen_q, 128) + seqlen_k_rounded = round_multiple(max_seqlen_k, 128) + if return_softmax: + p = torch.empty((batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded), dtype=q.dtype, device=q.device, layout=q.layout) + rng_state = torch.empty((2,), dtype=torch.int64, device=q.device) + return out, softmax_lse, p, rng_state + +if torch.__version__ >= "2.4.0": + _wrapped_flash_moba_attn_varlen_forward = torch.ops.flash_moba._flash_moba_attn_varlen_forward +else: + _wrapped_flash_moba_attn_varlen_forward = _flash_moba_attn_varlen_forward + 
+########################################################################################################################## + +@_torch_custom_op_wrapper("flash_moba::_flash_moba_attn_varlen_backward", mutates_args=("dq", "dk", "dv"), device_types="cuda") +def _flash_moba_attn_varlen_backward( + dout: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: torch.Tensor, + softmax_lse: torch.Tensor, + dq: Optional[torch.Tensor], + dk: Optional[torch.Tensor], + dv: Optional[torch.Tensor], + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + # MOBA sparse pattern parameters + moba_col_offsets: torch.Tensor, + moba_col_nnz: torch.Tensor, + moba_row_indices: torch.Tensor, + lg_block_m: int, + lg_block_n: int, + dropout_p: float, + softmax_scale: float, + causal: bool, + softcap: float, + alibi_slopes: Optional[torch.Tensor], + deterministic: bool, + rng_state: Optional[torch.Tensor] = None, + zero_tensors: bool = False, +) -> torch.Tensor: + # dq, dk, dv are allocated by us so they should already be contiguous + dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] + ( + dq, + dk, + dv, + softmax_d, + ) = flash_moba_gpu.moba_varlen_bwd( + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + moba_col_offsets, + moba_col_nnz, + moba_row_indices, + cu_seqlens_q, + cu_seqlens_k, + alibi_slopes, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + zero_tensors, + causal, + softcap, + deterministic, + lg_block_m, + lg_block_n, + None, + rng_state, + ) + # if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any(): + # breakpoint() + return softmax_d + +@_torch_register_fake_wrapper("flash_moba::_flash_moba_attn_varlen_backward") +def _flash_moba_attn_varlen_backward_fake( + dout: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: torch.Tensor, + softmax_lse: torch.Tensor, + dq: Optional[torch.Tensor], + dk: 
Optional[torch.Tensor], + dv: Optional[torch.Tensor], + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + # MOBA sparse pattern parameters + moba_col_offsets: torch.Tensor, + moba_col_nnz: torch.Tensor, + moba_row_indices: torch.Tensor, + lg_block_m: int, + lg_block_n: int, + dropout_p: float, + softmax_scale: float, + causal: bool, + softcap: float, + alibi_slopes: Optional[torch.Tensor], + deterministic: bool, + rng_state: Optional[torch.Tensor] = None, + zero_tensors: bool = False, +) -> torch.Tensor: + dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] + batch_size = cu_seqlens_q.numel() - 1 + total_q, num_heads, _ = q.shape + + if dq is None: + dq = torch.empty_like(q) + if dk is None: + dk = torch.empty_like(k) + if dv is None: + dv = torch.empty_like(v) + softmax_d = torch.empty((num_heads, total_q + 128 * batch_size), device=q.device, dtype=torch.float32) + + return softmax_d + +if torch.__version__ >= "2.4.0": + _wrapped_flash_moba_attn_varlen_backward = torch.ops.flash_moba._flash_moba_attn_varlen_backward +else: + _wrapped_flash_moba_attn_varlen_backward = _flash_moba_attn_varlen_backward + +########################################################################################################################## + +class FlashMobaAttnVarlenFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + # MOBA sparse pattern parameters + moba_col_offsets, + moba_col_nnz, + moba_row_indices, + lg_block_m, + lg_block_n, + dropout_p, + softmax_scale, + causal, + softcap, + alibi_slopes, + deterministic, + return_softmax, + is_grad_enabled, + ): + is_grad = is_grad_enabled and any( + x.requires_grad for x in [q, k, v] + ) + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + head_size_og = q.size(2) + if head_size_og % 8 != 0: + q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) 
+ k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) + v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) + out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_moba_attn_varlen_forward( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + moba_col_offsets, + moba_col_nnz, + moba_row_indices, + lg_block_m, + lg_block_n, + dropout_p, + softmax_scale, + causal=causal, + softcap=softcap, + alibi_slopes=alibi_slopes, + return_softmax=return_softmax and dropout_p > 0, + ) + if is_grad: + ctx.save_for_backward( + q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state, + moba_col_offsets, moba_col_nnz, moba_row_indices + ) + ctx.dropout_p = dropout_p + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_k = max_seqlen_k + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.softcap = softcap + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + ctx.lg_block_m = lg_block_m + ctx.lg_block_n = lg_block_n + + out = out_padded[..., :head_size_og] + return out if not return_softmax else (out, softmax_lse, S_dmask) + + @staticmethod + def backward(ctx, dout, *args): + q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state, moba_col_offsets, moba_col_nnz, moba_row_indices = ctx.saved_tensors + dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) + head_size_og = dout.size(2) + dout_padded = dout + if head_size_og % 8 != 0: + dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8]) + _wrapped_flash_moba_attn_varlen_backward( + dout_padded, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + cu_seqlens_q, + cu_seqlens_k, + ctx.max_seqlen_q, + ctx.max_seqlen_k, + moba_col_offsets, + moba_col_nnz, + moba_row_indices, + ctx.lg_block_m, + ctx.lg_block_n, + ctx.dropout_p, + ctx.softmax_scale, + ctx.causal, + ctx.softcap, + ctx.alibi_slopes, + ctx.deterministic, + rng_state=rng_state, + ) + dq = dq[..., : dout.shape[-1]] # We could have padded the 
head dimension + dk = dk[..., : dout.shape[-1]] + dv = dv[..., : dout.shape[-1]] + return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None + +########################################################################################################################## + +def flash_topk_varlen_func( + q, + k, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + # MOBA sparse pattern parameters + moba_topk, + moba_chunk_size, + causal=False, +): + """ + Computes the top-k indices for Mixture-of-Blocks Attention (MOBA). + This function handles variable length sequences. + + Args: + q (torch.Tensor): Query tensor of shape (total_q, num_heads, head_size). + k (torch.Tensor): Key tensor of shape (total_k, num_heads, head_size). + cu_seqlens_q (torch.Tensor): Cumulative sequence lengths for queries, shape (batch_size + 1,). + cu_seqlens_k (torch.Tensor): Cumulative sequence lengths for keys, shape (batch_size + 1,). + max_seqlen_q (int): Maximum sequence length for queries. + max_seqlen_k (int): Maximum sequence length for keys. + moba_topk (int): The number of top-k elements to select. + moba_chunk_size (int): The chunk size for MOBA. + causal (bool): Whether to apply causal masking. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing: + - col_offsets (torch.Tensor): Column offsets for the sparse matrix. + - col_nnz (torch.Tensor): Number of non-zero elements per column block. + - indices (torch.Tensor): The top-k indices. 
+ """ + head_size_og = q.size(2) + if head_size_og % 8 != 0: + q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) + k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) + + km, cu_seqlens_km, _ = flash_topk_mean_pool(k, cu_seqlens_k, max_seqlen_k, moba_chunk_size) + + col_offsets, col_nnz, indices = _wrapped_moba_fused_topk( + q, + km, + cu_seqlens_q, + cu_seqlens_k, + cu_seqlens_km, + max_seqlen_q, + max_seqlen_k, + moba_topk, + moba_chunk_size, + causal=causal + ) + + indices = _wrapped_varlen_sort( + col_offsets, col_nnz, indices + ) + + return col_offsets, col_nnz, indices + +########################################################################################################################## + +def flash_moba_attn_varlen_func( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + # MOBA sparse pattern parameters + moba_col_offsets, + moba_col_nnz, + moba_row_indices, + lg_block_m=64, + lg_block_n=64, + dropout_p=0.0, + softmax_scale=None, + causal=False, + softcap=0.0, # 0.0 means deactivated + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, +): + """dropout_p should be set to 0.0 during evaluation + Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads + than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. + For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head + 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. + + If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. + For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: + 1 1 1 1 0 + 1 1 1 1 1 + If seqlen_q = 5 and seqlen_k = 2, the causal mask is: + 0 0 + 0 0 + 0 0 + 1 0 + 1 1 + If the row of the mask is all zero, the output will be zero. + + If window_size != (-1, -1), implements sliding window local attention. 
Query at position i + will only attend to keys between + [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. + + Arguments: + q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. + k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. + v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. + cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths + of the sequences in the batch, used to index into q. + cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths + of the sequences in the batch, used to index into kv. + max_seqlen_q: int. Maximum query sequence length in the batch. + max_seqlen_k: int. Maximum key sequence length in the batch. + dropout_p: float. Dropout probability. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + window_size: (left, right). If not (-1, -1), implements sliding window local attention. + softcap: float. Anything > 0 activates softcapping attention. + alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of + (-alibi_slope * |i + seqlen_k - seqlen_q - j|) + is added to the attention score of query i and key j. + deterministic: bool. Whether to use the deterministic implementation of the backward pass, + which is slightly slower and uses more memory. The forward pass is always deterministic. + return_attn_probs: bool. Whether to return the attention probabilities. This option is for + testing only. The returned probabilities are not guaranteed to be correct + (they might not have the right scaling). + moba_col_offsets: Optional[torch.Tensor]. Column offsets for MOBA sparse pattern. + Shape: (batch_size, num_heads, max_lg_col_num), dtype: int64 + moba_col_nnz: Optional[torch.Tensor]. 
Non-zero counts per column for MOBA sparse pattern. + Shape: (batch_size, num_heads, max_lg_col_num), dtype: int32 + moba_row_indices: Optional[torch.Tensor]. Row indices for MOBA sparse pattern (flattened). + dtype: int32 + lg_block_m: int. Logical block size in M dimension (query). Default: 64 + lg_block_n: int. Logical block size in N dimension (key). Default: 64 + Return: + out: (total, nheads, headdim). + softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The + logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax + normalization factor). + S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). + The output of softmax (possibly with different scaling). It also encodes the dropout + pattern (negative means that location was dropped, nonnegative means it was kept). + """ + return FlashMobaAttnVarlenFunc.apply( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + moba_col_offsets, + moba_col_nnz, + moba_row_indices, + lg_block_m, + lg_block_n, + dropout_p, + softmax_scale, + causal, + softcap, + alibi_slopes, + deterministic, + return_attn_probs, + torch.is_grad_enabled(), + ) + +########################################################################################################################## + +def flash_moba_varlen_func( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + moba_chunk_size, + moba_topk, + causal=True, +): + + col_offsets, col_nnz, indices = flash_topk_varlen_func( + q, + k, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + # MOBA sparse pattern parameters + moba_topk, + moba_chunk_size, + causal=causal, + ) + + lg_block_m = decide_lg_block_m(moba_topk, moba_chunk_size, max_seqlen_k, causal) + + return flash_moba_attn_varlen_func( + q, k, v, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, max_seqlen_k, + col_offsets, + col_nnz, + indices, + lg_block_m, + moba_chunk_size, + dropout_p=0.0, + 
causal=causal, + ) diff --git a/vortex_torch/attention_backend/flashmoba/triton_mean_pool.py b/vortex_torch/attention_backend/flashmoba/triton_mean_pool.py new file mode 100644 index 0000000..6fbd59f --- /dev/null +++ b/vortex_torch/attention_backend/flashmoba/triton_mean_pool.py @@ -0,0 +1,158 @@ +# Copyright (c) 2025, FlashMoBA Team. +import torch +import torch.nn.functional as F + +import triton +import triton.language as tl + + +# `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes: +# - A list of `triton.Config` objects that define different configurations of +# meta-parameters (e.g., `BLOCK_SIZE_M`) and compilation options (e.g., `num_warps`) to try +# - An auto-tuning *key* whose change in values will trigger evaluation of all the +# provided configs +@triton.autotune( + configs=[ + # triton.Config({'kBlockN': 16}, num_warps=2, num_stages=3), + triton.Config({'kBlockN': 32}, num_warps=2, num_stages=3), + triton.Config({'kBlockN': 32}, num_warps=4, num_stages=3), + triton.Config({'kBlockN': 32}, num_warps=4, num_stages=4), + triton.Config({'kBlockN': 64}, num_warps=2, num_stages=3), + triton.Config({'kBlockN': 64}, num_warps=4, num_stages=3), + triton.Config({'kBlockN': 64}, num_warps=4, num_stages=4), + triton.Config({'kBlockN': 64}, num_warps=8, num_stages=3), + triton.Config({'kBlockN': 128}, num_warps=2, num_stages=3), + triton.Config({'kBlockN': 128}, num_warps=4, num_stages=3), + triton.Config({'kBlockN': 128}, num_warps=4, num_stages=4), + triton.Config({'kBlockN': 128}, num_warps=8, num_stages=3), + triton.Config({'kBlockN': 128}, num_warps=8, num_stages=4), + # triton.Config({'kBlockN': 256}, num_warps=4, num_stages=3), + # triton.Config({'kBlockN': 256}, num_warps=8, num_stages=3), + # triton.Config({'kBlockN': 256}, num_warps=8, num_stages=4), + # triton.Config({'kBlockN': 256}, num_warps=16, num_stages=2), + # triton.Config({'kBlockN': 512}, num_warps=8, num_stages=2), + # triton.Config({'kBlockN': 
512}, num_warps=16, num_stages=2), + # triton.Config({'kBlockN': 512}, num_warps=16, num_stages=3), + # triton.Config({'kBlockN': 1024}, num_warps=16, num_stages=2), + ], + key=['HEAD_DIM', 'POOL_BLOCK_SIZE'], +) +@triton.jit +def mean_pool_kernel( + # Pointers to matrices + input_ptr, + output_ptr, + # Matrix dimensions + HEAD_DIM: tl.constexpr, + POOL_BLOCK_SIZE: tl.constexpr, + cu_seqlens_input, + cu_seqlens_output, + input_stride_row, input_stride_head, + output_stride_row, output_stride_head, + # Meta-parameters + kBlockN: tl.constexpr, +): + """ + Triton kernel for mean pooling over variable-length sequences. + + This kernel computes the mean of non-overlapping blocks of size `POOL_BLOCK_SIZE` + for each sequence in a batch. It is designed to handle variable sequence lengths. + + Args: + input_ptr: Pointer to the input tensor of shape (total_seqlen, num_heads, head_dim). + output_ptr: Pointer to the output tensor of shape (total_blocks, num_heads, head_dim). + HEAD_DIM: The dimension of each head. + POOL_BLOCK_SIZE: The size of the pooling window. + cu_seqlens_input: Cumulative sequence lengths of the input tensor, shape (batch_size + 1,). + cu_seqlens_output: Cumulative sequence lengths of the output tensor, shape (batch_size + 1,). + input_stride_row: Stride of the input tensor along the sequence dimension. + input_stride_head: Stride of the input tensor along the head dimension. + output_stride_row: Stride of the output tensor along the sequence dimension. + output_stride_head: Stride of the output tensor along the head dimension. + kBlockN: Block size for the sequence dimension, a meta-parameter for tuning. 
+ """ + n_block = tl.program_id(0) + bidb = tl.program_id(1) + bidh = tl.program_id(2) + + seq_start = tl.load(cu_seqlens_input + bidb) + seq_end = tl.load(cu_seqlens_input + bidb + 1) + + block_start_row = seq_start + n_block * POOL_BLOCK_SIZE + + if seq_end <= block_start_row: + return + + actual_block_size = tl.minimum(POOL_BLOCK_SIZE, seq_end - block_start_row) + + offsets_d = tl.arange(0, HEAD_DIM) + # mask_d = offsets_d < HEAD_DIM + + acc = tl.zeros([HEAD_DIM], dtype=tl.float32) + + for block_k_start in range(0, actual_block_size, kBlockN): + offsets_k = block_k_start + tl.arange(0, kBlockN) + mask_k = offsets_k < actual_block_size + + row_indices = block_start_row + offsets_k + + input_offset = row_indices[:, None] * input_stride_row.to(tl.int64) + bidh * input_stride_head.to(tl.int64) + offsets_d[None, :] + + inp = tl.load(input_ptr + input_offset, mask=mask_k[:, None], other=0.0) + acc += tl.sum(inp, axis=0) + + # safe division + mean_val = acc / actual_block_size + + output_start = tl.load(cu_seqlens_output + bidb) + output_offset = (output_start + n_block) * output_stride_row.to(tl.int64) + bidh * output_stride_head.to(tl.int64) + offsets_d + tl.store(output_ptr + output_offset, mean_val) + + +def flash_topk_mean_pool(input, cu_seqlens_input, max_seqlen_input, pool_block_size): + """ + Performs mean pooling on variable-length sequences using a Triton kernel. + + This function takes a tensor of packed sequences and applies mean pooling over + fixed-size blocks. + + Args: + input (torch.Tensor): The input tensor of shape (total_seqlen, num_heads, head_dim). + cu_seqlens_input (torch.Tensor): Cumulative sequence lengths for the input, shape (batch_size + 1,). + max_seqlen_input (int): The maximum sequence length in the input batch. + pool_block_size (int): The size of the pooling window. + + Returns: + Tuple[torch.Tensor, torch.Tensor, int]: A tuple containing: + - output (torch.Tensor): The pooled output tensor of shape (total_blocks, num_heads, head_dim). 
+ - cu_seqlens_output (torch.Tensor): Cumulative sequence lengths for the output. + - max_seqlen_output (int): The maximum number of blocks for any sequence in the batch. + """ + total_seqlen, head_num, head_dim = input.shape + batch_size = cu_seqlens_input.shape[0] - 1 + + max_seqlen_output = (max_seqlen_input + pool_block_size - 1) // pool_block_size + + actual_input_seqlens = cu_seqlens_input[1:] - cu_seqlens_input[:-1] + actual_output_seqlens = (actual_input_seqlens + pool_block_size - 1) // pool_block_size + cu_seqlens_output = F.pad(torch.cumsum(actual_output_seqlens, dim=0), (1, 0)).to(torch.int32) + + total_blocks = cu_seqlens_output[-1].item() + + output = torch.zeros((total_blocks, head_num, head_dim), dtype=input.dtype, device=input.device) + + grid = (max_seqlen_output, batch_size, head_num) + + mean_pool_kernel[grid]( + input, + output, + head_dim, + pool_block_size, + cu_seqlens_input, + cu_seqlens_output, + input.stride(0), input.stride(1), + output.stride(0), output.stride(1), + ) + + return output, cu_seqlens_output, max_seqlen_output + \ No newline at end of file diff --git a/vortex_torch/attention_backend/fsa/FSA_topk_sparse_attention.py b/vortex_torch/attention_backend/fsa/FSA_topk_sparse_attention.py new file mode 100644 index 0000000..acca2ac --- /dev/null +++ b/vortex_torch/attention_backend/fsa/FSA_topk_sparse_attention.py @@ -0,0 +1,2040 @@ +# Copyright 2025 Ran Yan. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific +import math +from typing import Any, Optional + +import torch +import triton +import triton.language as tl + +from ..nsa.topk_sparse_attention import (backward_sum_o_do, + reorder_topk_idx, + get_num_warps_stages, + is_hopper_gpu) + +IS_HOPPER_GPU = is_hopper_gpu() + + +@triton.jit +def fused_fill_kernel(ptr_tile, ptr_m_i_cur_tiles, N, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < N + + tl.store(ptr_tile + offsets, -1, mask=mask) # fill int32 with -1 + tl.store(ptr_m_i_cur_tiles + offsets, float("-inf"), mask=mask) + + +def fused_fill(topk_idx_permuted_tile: torch.Tensor, m_i_cur_tiles): + + numel = topk_idx_permuted_tile.numel() + BLOCK_SIZE = 1024 + + # Flatten for pointer access + tile_flat = topk_idx_permuted_tile.view(-1) + + m_i_cur_tiles_flat = m_i_cur_tiles.view(-1) + + grid = lambda meta: (triton.cdiv(numel, meta['BLOCK_SIZE']),) + + fused_fill_kernel[grid]( + tile_flat, + m_i_cur_tiles_flat, + numel, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=1, + num_stages=3, + ) + + +@triton.jit +def block_to_token_kernel( + topk_idx_ptr, + result_ptr, + N_token, + K, + min_block_id, + max_block_id, + padding_value, + ts_h, + ts_b, + ts_n, + rs_h, + rs_b, + rs_n, + num_q_loops: tl.constexpr, + BLOCK_K: tl.constexpr, +): + pid = tl.program_id(0) # token index i + pid_h = 0 + offs = tl.arange(0, BLOCK_K) # [0, 1, ..., K-1] + + offs_q = tl.arange(0, num_q_loops) + + pid_j = pid * num_q_loops + offs_q + + topk_idx_offset = pid_h * ts_h + pid_j[None, :] * K + offs[:, None] + block_ids = tl.load( + topk_idx_ptr + topk_idx_offset, mask=(pid_j < N_token)[None, :] & (offs < K)[:, None], other=padding_value + ) + + result_ptrs = result_ptr + pid_h * rs_h + block_ids * N_token + pid_j[None, :] + + mask = (block_ids >= 0) & (block_ids != padding_value) & (pid_j < N_token)[None, :] + tl.store(result_ptrs, pid_j[None, :], mask=mask) + + +def build_block_to_token_triton( + 
result: torch.Tensor, topk_idx: torch.Tensor, min_block_id: int, max_block_id: int, padding_value: int = -1 +): + """ + Args: + topk_idx: [num_heads, N_token, TopK], block indices per token, padded with padding_value for invalid blocks + num_blocks: int + padding_value: int + + Returns: + result: [num_blocks, N_token], token indices per block, padded by padding_value + """ + assert topk_idx.ndim == 3 + assert padding_value == -1 + num_heads, N_token, TopK = topk_idx.shape + + # 每个 token,每个head 一个 program + num_q_loops = 4 + grid = (triton.cdiv(N_token, num_q_loops),) + BLOCK_K = triton.next_power_of_2(TopK) + block_to_token_kernel[grid]( + topk_idx, + result, + N_token, + TopK, + min_block_id, + max_block_id, + padding_value, + topk_idx.stride(0), + topk_idx.stride(1), + topk_idx.stride(2), + result.stride(0), + result.stride(1), + result.stride(2), + num_q_loops, + BLOCK_K=BLOCK_K, + num_warps=2, + num_stages=3, + ) + return result + + +@triton.jit +def reduce_kernel( + lse_ptr, # float32 [H, N] + m_ij_ptr, # float32 [H, B, N] + l_ij_first_ptr, # float32 [H, 1, N] + l_ij_rest_ptr, # float32 [H, B, N] + m_ij_last_ptr, # float32 [H, N] + o_ptr, # o: n x h x d + o_tiles_first_ptr, # o_tiles: n x h x 1 x d + o_tiles_rest_ptr, # o_tiles: n x h x b x d + acc_o_scales_first_ptr, # acc_o_scales: n x h x 1 + acc_o_scales_rest_ptr, # acc_o_scales: n x h x b + t_ptr, # topk_idx: h x n x k + token_index_mapping_ptr, + start_head_id, + num_qz_loop, + TOPK, + total_len, + # stride + stride_lse_h, + stride_lse_n, + stride_m_ij_h, + stride_m_ij_b, + stride_m_ij_n, + stride_l_ij_fh, + stride_l_ij_fb, + stride_l_ij_fn, + stride_l_ij_rh, + stride_l_ij_rb, + stride_l_ij_rn, + stride_on, + stride_oh, + stride_od, + stride_otfh, + stride_otfb, + stride_otfn, + stride_otfd, + stride_otrh, + stride_otrb, + stride_otrn, + stride_otrd, + stride_acc_fh, + stride_acc_fb, + stride_acc_fn, + stride_acc_rh, + stride_acc_rb, + stride_acc_rn, + stride_th, + stride_tn, + stride_tk, + stride_tim_h, 
+ stride_tim_b, + stride_tim_n, + # META parameters + BLOCK_SIZE_T: tl.constexpr, + BLOCK_SIZE_D: tl.constexpr, +): + pid_qy = tl.program_id(0) + pid_q = tl.program_id(1) # token + + pid_q_j = pid_q + pid_qy * num_qz_loop + if pid_q_j < total_len: + t_ptr_j = t_ptr + pid_q_j * stride_tn + + off_d = tl.arange(0, BLOCK_SIZE_D) + o_ptrs = o_ptr + pid_q_j * stride_on + off_d + last_acc_o = tl.load(o_ptrs, mask=off_d < BLOCK_SIZE_D, other=0.0) + acc_o = tl.zeros((BLOCK_SIZE_D,), dtype=tl.float32) + acc_o += last_acc_o + + lse_ptrs = lse_ptr + pid_q_j * stride_lse_n + # Load lse + lse = tl.load(lse_ptrs, mask=pid_q_j < total_len, other=float("-inf")) + + # the stride is 1 for m_ij_last + m_ij_last = tl.load(m_ij_last_ptr + pid_q_j) + + for block_id in range(TOPK): + t = tl.load(t_ptr_j + block_id * stride_tk, mask=block_id < TOPK, other=-1) + if t != -1: + if t == 0: + real_block_pos = 0 + l_ij_ptr = l_ij_first_ptr + o_tiles_ptr = o_tiles_first_ptr + acc_o_scales_ptr = acc_o_scales_first_ptr + stride_l_ij_b = stride_l_ij_fb + stride_l_ij_n = stride_l_ij_fn + stride_acc_b = stride_acc_fb + stride_acc_n = stride_acc_fn + stride_otb = stride_otfb + stride_otn = stride_otfn + else: + real_block_pos = t - 1 + l_ij_ptr = l_ij_rest_ptr + o_tiles_ptr = o_tiles_rest_ptr + acc_o_scales_ptr = acc_o_scales_rest_ptr + stride_l_ij_b = stride_l_ij_rb + stride_l_ij_n = stride_l_ij_rn + stride_acc_b = stride_acc_rb + stride_acc_n = stride_acc_rn + stride_otb = stride_otrb + stride_otn = stride_otrn + + # init pointers + token_index_mapping_ptrs = ( + token_index_mapping_ptr + t.to(tl.int64) * stride_tim_b + (pid_q_j) * stride_tim_n + ) + real_token_index = tl.load(token_index_mapping_ptrs) + + m_ij = tl.load( + m_ij_ptr + t * stride_m_ij_b + pid_q_j * stride_m_ij_n, mask=pid_q_j < total_len, other=float("-inf") + ) + l_ij = tl.load( + l_ij_ptr + real_block_pos * stride_l_ij_b + real_token_index * stride_l_ij_n, + mask=real_token_index < total_len, + other=0.0, + ) + delta = lse - m_ij + 
+ log_delta = tl.exp2(delta) + l_ij + + # Update lse + lse = m_ij + tl.log2(log_delta) + + o_tiles_ptrs = ( + o_tiles_ptr + real_block_pos.to(tl.int64) * stride_otb + (real_token_index) * stride_otn + off_d + ) + acc_o_scales_ptrs = acc_o_scales_ptr + real_block_pos * stride_acc_b + (real_token_index) * stride_acc_n + + o_tiles = tl.load(o_tiles_ptrs) + acc_o_scales_tiles = tl.load(acc_o_scales_ptrs) + acc_o = o_tiles + acc_o * acc_o_scales_tiles + + # final scale + acc_o = acc_o * tl.exp2(m_ij_last - lse) + tl.store(o_ptrs, acc_o, mask=off_d < BLOCK_SIZE_D) + + # Store back + tl.store( + lse_ptrs, + lse, + mask=pid_q_j < total_len, + ) + + +@triton.jit +def qk_kernel( + q_ptr, # Q: n x h x d + k_ptr, # K: n x h x d + m_i_tiles_ptr, # m_i: h x b x n + selected_tokens_ptr, # selected_tokens: sum(valid_lens), + valid_lens_ptr, # valid_lens: (h x b), + valid_start_indices_ptr, # valid_start_indices: (h x b), + num_heads, + num_blocks, + # seqlens + cu_seqlens_q, + cu_seqlens_k, + # shape + HEAD_DIM, + # sm_scale + sm_scale, + num_q_blocks, + num_b_blocks, + # stride + stride_qn, + stride_qh, + stride_qd, + stride_kn, + stride_kh, + stride_kd, + stride_m_i_tiles_h, + stride_m_i_tiles_b, + stride_m_i_tiles_n, + # META parameters + BLOCK_SIZE_Q: tl.constexpr, # q block size + BLOCK_SIZE_K: tl.constexpr, # k block size + BLOCK_SIZE_D: tl.constexpr, +): + qk_scale = sm_scale * 1.44269504 + # get batch id and head id + pid_block_grid = tl.program_id(0) // num_heads # block id + head_id = tl.program_id(0) % num_heads + pid_q = tl.program_id(1) # token + + # get q k start and len after rmpad + k_len = tl.load(cu_seqlens_k + 1) + k_ptrs = tl.make_block_ptr( + base=k_ptr + head_id * stride_kh, + shape=(HEAD_DIM, k_len), + strides=(stride_kd, stride_kn), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K), + order=(0, 1), + ) + + for bb in range(num_b_blocks): + pid_block = bb + pid_block_grid * num_b_blocks + + start_id = tl.load(valid_start_indices_ptr + head_id * 
num_blocks + pid_block) + valid_tokens = tl.load(valid_lens_ptr + head_id * num_blocks + pid_block) + if pid_q * BLOCK_SIZE_Q < valid_tokens: + + c = pid_block * BLOCK_SIZE_K + + # load k + k = tl.load(tl.advance(k_ptrs, (0, c)), boundary_check=(1, 0), padding_option="zero") + + off_k = tl.arange(0, BLOCK_SIZE_K) + off_d = tl.arange(0, BLOCK_SIZE_D) + for j in range(num_q_blocks): + pid_q_j = pid_q * num_q_blocks + j + # Enable early return + if pid_q_j * BLOCK_SIZE_Q < valid_tokens: + # one thread block for one KV block, a subset of selected tokens + st_offs = start_id + (pid_q_j * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q)) + # st should be in shape [BLOCK_SIZE_Q] + st_mask = (pid_q_j * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q)) < valid_tokens + + st = tl.load(selected_tokens_ptr + st_offs, mask=st_mask, other=-1) + # otherwise, st selects a set of q tokens, selected_tokens_ptr should be sorted + q_ptrs_off = st[:, None] * stride_qn + off_d[None, :] * stride_qd + q_ptrs = q_ptr + head_id * stride_qh + q_ptrs_off + # load q + q_mask = (st != -1)[:, None] & (off_d < HEAD_DIM)[None, :] + q = tl.load(q_ptrs, mask=q_mask, other=0) + # compute qk + qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32) + qk += tl.where((st[:, None] >= c + off_k[None, :]), 0, float("-inf")) + # [BLOCK_SIZE_Q, HEAD_DIM] @ [HEAD_DIM, BLOCK_SIZE_K] -> [BLOCK_SIZE_Q, BLOCK_SIZE_K] + qk += tl.dot(q, k) * qk_scale + + m_i = tl.max(qk, axis=1) + + m_i_tiles_ptrs = ( + m_i_tiles_ptr + + head_id * stride_m_i_tiles_h + + pid_block * stride_m_i_tiles_b + + st * stride_m_i_tiles_n + ) + tl.store(m_i_tiles_ptrs, m_i, mask=(st != -1)) + + +@triton.jit +def forward_kernel_opt( + q_ptr, + k_ptr, + v_ptr, # V: n x h x d + o_tiles_ptr, # O: n x h x b x d + acc_o_scales_ptr, # acc_o_scales: h x b x n + m_ij_tiles_ptr, + l_ij_ptr, # h x b x n + token_index_mapping_ptr, + selected_tokens_ptr, # selected_tokens: sum(valid_lens), + valid_lens_ptr, # valid_lens: (h x b), + valid_start_indices_ptr, # 
valid_start_indices: (h x b), + min_block_id, + cur_max_valid_tokens, + num_heads, + num_blocks, + # seqlens + cu_seqlens_q, + cu_seqlens_k, + # shape + HEAD_DIM, + # sm_scale + sm_scale, + num_q_blocks, + # stride + stride_qn, + stride_qh, + stride_qd, + stride_kn, + stride_kh, + stride_kd, + stride_vn, + stride_vh, + stride_vd, + stride_oth, + stride_otb, + stride_otn, + stride_otd, + stride_acc_oh, + stride_acc_ob, + stride_acc_on, + stride_m_ij_tiles_h, + stride_m_ij_tiles_b, + stride_m_ij_tiles_n, + stride_l_ij_h, + stride_l_ij_b, + stride_l_ij_n, + stride_tim_h, + stride_tim_b, + stride_tim_n, + # META parameters + BLOCK_SIZE_Q: tl.constexpr, # q block size + BLOCK_SIZE_K: tl.constexpr, # k block size + BLOCK_SIZE_D: tl.constexpr, +): + # get batch id and head id + pid_block = tl.program_id(0) // num_heads # block id + head_id = tl.program_id(0) % num_heads + pid_q = tl.program_id(1) # token + # seq packing is not supported yet + q_start = 0 + k_start = 0 + + k_len = tl.load(cu_seqlens_k + 1) - k_start + + start_id = tl.load(valid_start_indices_ptr + head_id * num_blocks + pid_block) + valid_tokens = tl.load(valid_lens_ptr + head_id * num_blocks + pid_block) + if num_q_blocks * pid_q * BLOCK_SIZE_Q >= valid_tokens: + return + + c = (min_block_id + pid_block) * BLOCK_SIZE_K + k_ptrs = tl.make_block_ptr( + base=k_ptr + k_start * stride_kn + head_id * stride_kh, + shape=(HEAD_DIM, k_len), + strides=(stride_kd, stride_kn), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K), + order=(0, 1), + ) + # load k + k = tl.load(tl.advance(k_ptrs, (0, c)), boundary_check=(1, 0), padding_option="zero") + + v_ptrs = tl.make_block_ptr( + base=v_ptr + k_start * stride_vn + head_id * stride_vh, + shape=(k_len, HEAD_DIM), + strides=(stride_vn, stride_vd), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + + # load v + v = tl.load(tl.advance(v_ptrs, (c, 0)), boundary_check=(0, 1), padding_option="zero") + + off_k = tl.arange(0, 
BLOCK_SIZE_K) + off_d = tl.arange(0, BLOCK_SIZE_D) + for j in range(num_q_blocks): + pid_q_j = pid_q * num_q_blocks + j + if pid_q_j * BLOCK_SIZE_Q < valid_tokens: + # one thread block for one KV block, a subset of selected tokens + st_offs = start_id + (q_start + pid_q_j * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q)) + # st should be in shape [BLOCK_SIZE_Q] + st_mask = (pid_q_j * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q)) < valid_tokens + + st = tl.load(selected_tokens_ptr + st_offs, mask=st_mask, other=-1) + + # otherwise, st selects a set of q tokens, selected_tokens_ptr should be sorted + q_ptrs_off = st[:, None] * stride_qn + off_d[None, :] * stride_qd + + # load m_i + mask = st != -1 + + m_ij_tiles_ptrs = ( + m_ij_tiles_ptr + + head_id * stride_m_ij_tiles_h + + (q_start + st) * stride_m_ij_tiles_n + + (pid_block + min_block_id) * stride_m_ij_tiles_b + ) + m_ij = tl.load(m_ij_tiles_ptrs, mask=mask, other=float("-inf")) + + m_ij_tiles_prev_ptrs = ( + m_ij_tiles_ptr + + head_id * stride_m_ij_tiles_h + + (q_start + st) * stride_m_ij_tiles_n + + (pid_block + min_block_id - 1) * stride_m_ij_tiles_b + ) + m_ij_prev = tl.load(m_ij_tiles_prev_ptrs, mask=mask & (pid_block + min_block_id > 0), other=float("-inf")) + + m_i_minus_m_ij = m_ij_prev - m_ij + + q_ptrs = q_ptr + q_start * stride_qn + head_id * stride_qh + q_ptrs_off + # load q + q_mask = mask[:, None] & (off_d < HEAD_DIM)[None, :] + q = tl.load(q_ptrs, mask=q_mask, other=0) + + # compute qk + qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32) + qk += tl.where((st[:, None] >= c + off_k[None, :]), 0, float("-inf")) + + # [BLOCK_SIZE_Q, HEAD_DIM] @ [HEAD_DIM, BLOCK_SIZE_K] -> [BLOCK_SIZE_Q, BLOCK_SIZE_K] + qk_scale = sm_scale * 1.44269504 + qk += tl.dot(q, k) * qk_scale + + # init statistics + acc_o_buffer = tl.full((BLOCK_SIZE_Q, BLOCK_SIZE_D), 0, dtype=tl.float32) + + # load m_ij and compute l_ij + p = tl.exp2(qk - m_ij[:, None]) + l_ij = tl.sum(p, axis=1) + + # load token index mapping + 
token_index_mapping_ptrs = ( + token_index_mapping_ptr + (st) * stride_tim_n + (pid_block + min_block_id) * stride_tim_b + ) + token_index_mapping = tl.load(token_index_mapping_ptrs, mask=mask, other=-1) + + l_ij_ptrs = ( + l_ij_ptr + + head_id * stride_l_ij_h + + (q_start + token_index_mapping) * stride_l_ij_n + + (pid_block) * stride_l_ij_b + ) + tl.store(l_ij_ptrs, l_ij, mask=mask) + # scale acc_o + if pid_block + min_block_id == 0: + acc_o_scale = tl.full((BLOCK_SIZE_Q,), 1.0, dtype=tl.float32) + else: + acc_o_scale = tl.exp2(m_i_minus_m_ij) + + tl.store( + acc_o_scales_ptr + + head_id * stride_acc_oh + + (pid_block) * stride_acc_ob + + (q_start + token_index_mapping) * stride_acc_on, + acc_o_scale, + mask=(st != -1), + ) + + p = p.to(v.dtype) + acc_o_buffer = tl.dot(p, v) + + o_ptrs_off = token_index_mapping[:, None] * stride_otn + off_d[None, :] * stride_otd + o_ptrs = o_tiles_ptr + head_id * stride_oth + o_ptrs_off + (pid_block).to(tl.int64) * stride_otb + tl.store(o_ptrs, acc_o_buffer.to(o_tiles_ptr.dtype.element_ty), mask=q_mask) + + +def _topk_sparse_attention_fwd_opt( + q: torch.Tensor, # [total_len, num_heads, head_dim] + k: torch.Tensor, # [total_len, num_heads, head_dim] + v: torch.Tensor, # [total_len, num_heads, head_dim] + topk_idx: torch.Tensor, # [num_heads, total_len, topk] + block_size: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + sm_scale: float, + causal=True, +): + """ + TODO: Currently sequence packing is explicitly done in for loop, will merge in kernels. 
+ """ + o = torch.empty_like(q) + total_len, num_heads, _ = q.shape + lse = torch.empty((num_heads, total_len), dtype=torch.float32, device=q.device) + + permute_results = [] + for i in range(len(cu_seqlens_q) - 1): + cu_seqlens_q_ = cu_seqlens_q[i: i + 2] - cu_seqlens_q[i] + cu_seqlens_k_ = cu_seqlens_k[i: i + 2] - cu_seqlens_k[i] + max_seqlen_q_ = cu_seqlens_q_[1] - cu_seqlens_q_[0] + max_seqlen_k_ = cu_seqlens_k_[1] - cu_seqlens_k_[0] + + q_ = q[cu_seqlens_q[i]: cu_seqlens_q[i + 1]] + k_ = k[cu_seqlens_k[i]: cu_seqlens_k[i + 1]] + v_ = v[cu_seqlens_k[i]: cu_seqlens_k[i + 1]] + topk_idx_ = topk_idx[:, cu_seqlens_q[i]: cu_seqlens_q[i + 1]] + o_seq, lse_seq, permute_results_seq = _topk_sparse_attention_fwd_opt_per_seq( + q_, + k_, + v_, + topk_idx_, + block_size, + cu_seqlens_q_, + cu_seqlens_k_, + max_seqlen_q_, + max_seqlen_k_, + sm_scale, + causal, + ) + o[cu_seqlens_q[i]: cu_seqlens_q[i + 1]] = o_seq + + lse[:, cu_seqlens_q[i]: cu_seqlens_q[i + 1]] = lse_seq + permute_results.append(permute_results_seq) + + return o, lse, permute_results + + +@triton.jit +def index_mapping_kernel( + token_index_mapping_ptr, + selected_tokens_ptr, + valid_lens_ptr, + valid_start_indices_ptr, + stride_im_h, + stride_im_b, + stride_im_n, + BLOCK_SIZE_K: tl.constexpr, +): + pid_b = tl.program_id(0) + pid_n = tl.program_id(1) + + offs_q = tl.arange(0, BLOCK_SIZE_K) + offs_n = pid_n * BLOCK_SIZE_K + offs_q + + start_id = tl.load(valid_start_indices_ptr + pid_b) + valid_tokens = tl.load(valid_lens_ptr + pid_b) + + st_offs = start_id + offs_n + # st should be in shape [BLOCK_SIZE_K] + st_mask = offs_n < valid_tokens + + st = tl.load(selected_tokens_ptr + st_offs, mask=st_mask, other=-1) + + token_im_ptrs = token_index_mapping_ptr + pid_b * stride_im_b + st * stride_im_n + + tl.store(token_im_ptrs, offs_n, mask=st_mask) + + +def index_mapping(token_index_mapping, valid_topk_idx_permuted_tile, valid_lens, valid_start_indices, num_blocks): + max_tokens = valid_lens.max() + BLOCK_SIZE_K = 
1024 + grid = (num_blocks, triton.cdiv(max_tokens, BLOCK_SIZE_K)) + + index_mapping_kernel[grid]( + token_index_mapping, + valid_topk_idx_permuted_tile, + valid_lens, + valid_start_indices, + token_index_mapping.stride(0), + token_index_mapping.stride(1), + token_index_mapping.stride(2), + BLOCK_SIZE_K, + num_warps=2, + num_stages=3, + ) + + +def online_softmax( + q_tile, + k_tile, + m_i_cur_tiles, + valid_topk_idx_permuted_tile, + valid_lens, + valid_start_indices, + compute_min_block_id, + cur_max_valid_tokens, + block_size, + num_blocks, + head_tile, + head_dim, + sm_scale, + cu_seqlens_q, + cu_seqlens_k, +): + + # launch kernel + BLOCK_SIZE_Q = 128 + BLOCK_SIZE_K = triton.next_power_of_2(block_size) + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + num_q_blocks = 8 + num_b_blocks = 1 + grid_qk = lambda META: ( + triton.cdiv(num_blocks, num_b_blocks), + triton.cdiv(cur_max_valid_tokens, BLOCK_SIZE_Q * num_q_blocks), + ) + qk_kernel[grid_qk]( + q_tile, + k_tile, + m_i_cur_tiles, + valid_topk_idx_permuted_tile, + valid_lens, + valid_start_indices, + head_tile, + num_blocks, + cu_seqlens_q, + cu_seqlens_k, + head_dim, + sm_scale, + num_q_blocks, + num_b_blocks, + q_tile.stride(0), + q_tile.stride(1), + q_tile.stride(2), + k_tile.stride(0), + k_tile.stride(1), + k_tile.stride(2), + m_i_cur_tiles.stride(0), + m_i_cur_tiles.stride(1), + m_i_cur_tiles.stride(2), + BLOCK_SIZE_Q=BLOCK_SIZE_Q, + BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_D=BLOCK_SIZE_D, + num_warps=8, + num_stages=3, + ) + + m_ij_tiles = m_i_cur_tiles.cummax(dim=1).values + m_ij_last = m_ij_tiles[:, -1] + + return m_ij_tiles, m_ij_last + + +def qkv_kernel( + q_tile, + k_tile, + v_tile, + o_tiles, + acc_o_scales, + m_ij_tiles, + l_ij, + token_index_mapping, + valid_topk_idx_permuted_tile, + valid_lens, + valid_start_indices, + compute_min_block_id, + cur_max_valid_tokens, + head_tile, + compute_tile_size, + cu_seqlens_q, + cu_seqlens_k, + head_dim, + sm_scale, + block_size, +): + BLOCK_SIZE_Q = 128 + 
BLOCK_SIZE_K = triton.next_power_of_2(block_size) + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + + # a heuristic that avoids large grid size, and redudant KV loading + num_q_blocks = 8 + + grid_fwd = lambda META: ( + compute_tile_size * head_tile, + triton.cdiv(cur_max_valid_tokens, BLOCK_SIZE_Q * num_q_blocks), + ) + + forward_kernel_opt[grid_fwd]( + q_tile, + k_tile, + v_tile, + o_tiles, + acc_o_scales, + m_ij_tiles, + l_ij, + token_index_mapping, + valid_topk_idx_permuted_tile, + valid_lens, + valid_start_indices, + compute_min_block_id, + cur_max_valid_tokens, + head_tile, + compute_tile_size, + cu_seqlens_q, + cu_seqlens_k, + head_dim, + sm_scale, + num_q_blocks, + q_tile.stride(0), + q_tile.stride(1), + q_tile.stride(2), + k_tile.stride(0), + k_tile.stride(1), + k_tile.stride(2), + v_tile.stride(0), + v_tile.stride(1), + v_tile.stride(2), + o_tiles.stride(0), + o_tiles.stride(1), + o_tiles.stride(2), + o_tiles.stride(3), + acc_o_scales.stride(0), + acc_o_scales.stride(1), + acc_o_scales.stride(2), + m_ij_tiles.stride(0), + m_ij_tiles.stride(1), + m_ij_tiles.stride(2), + l_ij.stride(0), + l_ij.stride(1), + l_ij.stride(2), + token_index_mapping.stride(0), + token_index_mapping.stride(1), + token_index_mapping.stride(2), + BLOCK_SIZE_Q=BLOCK_SIZE_Q, + BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_D=BLOCK_SIZE_D, + num_stages=3, + num_warps=4, + ) + + +def reduce_output( + lse, + o, + o_tiles_first, + o_tiles_rest, + m_ij_tiles, + l_ij_first, + l_ij_rest, + m_ij_last, + acc_o_scales_first, + acc_o_scales_rest, + topk_idx_tile, + token_index_mapping, + h, + head_tile, + total_len, + TOPK, + head_dim, +): + + num_qy_loop = 4 + num_qz_loop = total_len // num_qy_loop + + grid_reduce = lambda META: ( + num_qy_loop + (total_len % num_qy_loop != 0), + num_qz_loop, + ) + + reduce_kernel[grid_reduce]( + lse, + m_ij_tiles, + l_ij_first, + l_ij_rest, + m_ij_last, + o, + o_tiles_first, + o_tiles_rest, + acc_o_scales_first, + acc_o_scales_rest, + topk_idx_tile, + 
token_index_mapping, + h * head_tile, + num_qz_loop, + TOPK, + total_len, + lse.stride(0), + lse.stride(1), + m_ij_tiles.stride(0), + m_ij_tiles.stride(1), + m_ij_tiles.stride(2), + l_ij_first.stride(0), + l_ij_first.stride(1), + l_ij_first.stride(2), + l_ij_rest.stride(0), + l_ij_rest.stride(1), + l_ij_rest.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + o_tiles_first.stride(0), + o_tiles_first.stride(1), + o_tiles_first.stride(2), + o_tiles_first.stride(3), + o_tiles_rest.stride(0), + o_tiles_rest.stride(1), + o_tiles_rest.stride(2), + o_tiles_rest.stride(3), + acc_o_scales_first.stride(0), + acc_o_scales_first.stride(1), + acc_o_scales_first.stride(2), + acc_o_scales_rest.stride(0), + acc_o_scales_rest.stride(1), + acc_o_scales_rest.stride(2), + topk_idx_tile.stride(0), + topk_idx_tile.stride(1), + topk_idx_tile.stride(2), + token_index_mapping.stride(0), + token_index_mapping.stride(1), + token_index_mapping.stride(2), + BLOCK_SIZE_T=triton.next_power_of_2(TOPK), + BLOCK_SIZE_D=triton.next_power_of_2(head_dim), + num_warps=1, + num_stages=2, + ) + + +def _topk_sparse_attention_fwd_opt_per_seq( + q: torch.Tensor, # [total_len, num_heads, head_dim] + k: torch.Tensor, # [total_len, num_kv_heads, head_dim] + v: torch.Tensor, # [total_len, num_kv_heads, head_dim] + topk_idx: torch.Tensor, # [num_heads, total_len, topk] + block_size: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + sm_scale: float, + causal=True, +): + # dtype check + assert k.dtype == q.dtype and v.dtype == q.dtype + assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32 + assert block_size in {16, 32, 64, 128, 256} + # shape + + total_len, num_heads, head_dim = q.shape + total_len, num_kv_heads, head_dim = k.shape + + assert num_heads % num_kv_heads == 0 + gqa_deg = num_heads // num_kv_heads + + TOPK = topk_idx.shape[-1] + + real_num_blocks = math.ceil(total_len / block_size) + num_blocks = 
max(real_num_blocks, TOPK) + + head_tile = 1 + reduce_tile_size = num_blocks - 1 + + valid_lens_all = torch.zeros( + ( + num_kv_heads, + num_blocks, + ), + dtype=torch.int32, + device=q.device, + ) + for h in range(num_kv_heads): + topk_idx_tile = topk_idx[h * head_tile: (h + 1) * head_tile] + topk_idx_nonneg = topk_idx_tile[topk_idx_tile >= 0] + valid_lens = torch.bincount(topk_idx_nonneg.view(-1), minlength=num_blocks) + valid_lens_all[h * head_tile: (h + 1) * head_tile] = valid_lens + + global_max_valid_tokens = valid_lens_all[:, 1:].max() if num_blocks > 1 else valid_lens_all.max() + + o_full = torch.zeros_like(q) + lse_full = torch.full((num_heads, total_len), float("-inf"), dtype=torch.float32, device=q.device) + + # New introduced buffers + topk_idx_permuted_tile = torch.full((head_tile, num_blocks, total_len), -1, dtype=torch.int32, device=q.device) + + token_index_mapping = torch.full((head_tile, num_blocks, total_len), 0, dtype=torch.int32, device=q.device) + # first KV block is computed seaprately + o_tiles_first = torch.zeros((head_tile, 1, total_len, head_dim), dtype=torch.bfloat16, device=q.device) + o_tiles_rest = torch.zeros( + (head_tile, reduce_tile_size, global_max_valid_tokens, head_dim), dtype=torch.bfloat16, device=q.device + ) + + # Statistics buffers + # m_i_tiles: 历史最大, m_diff_tiles: 历史最大和当前最大的差值 + # m_i_cur_tiles: 当前最大, # m_ij_tiles: 考虑当前和历史后的最大 + m_i_cur_tiles: torch.Tensor = torch.full( + (head_tile, num_blocks, total_len), float("-inf"), dtype=torch.float32, device=q.device + ) + + # first KV block is reduced separately + l_ij_first = torch.full((head_tile, 1, total_len), 0, dtype=torch.float32, device=q.device) + acc_o_scales_first = torch.full((head_tile, 1, total_len), 1, dtype=torch.float32, device=q.device) + + l_ij_rest = torch.full( + (head_tile, reduce_tile_size, global_max_valid_tokens), 0, dtype=torch.float32, device=q.device + ) + acc_o_scales_rest = torch.full( + (head_tile, reduce_tile_size, global_max_valid_tokens), 1, 
dtype=torch.float32, device=q.device + ) + + permute_results = {} + permute_results['global_max_valid_tokens'] = global_max_valid_tokens + permute_results['num_blocks'] = num_blocks + permute_results['real_num_blocks'] = real_num_blocks + permute_results['valid_topk_idx_permuted_tile'] = [] + permute_results['valid_lens_all'] = valid_lens_all + permute_results['valid_lens'] = [] + permute_results['valid_start_indices'] = [] + + for h in range(num_heads // head_tile): + q_tile = q[:, h * head_tile: (h + 1) * head_tile] + k_tile = k[:, (h // gqa_deg) * head_tile: ((h // gqa_deg + 1)) * head_tile] + v_tile = v[:, (h // gqa_deg) * head_tile: ((h // gqa_deg + 1)) * head_tile] + o = o_full[:, h * head_tile: (h + 1) * head_tile] + lse = lse_full[h * head_tile: (h + 1) * head_tile] + + permute_min_block_id = 0 + permute_max_block_id = min(permute_min_block_id + num_blocks, num_blocks) + + topk_idx_tile = topk_idx[(h // gqa_deg) * head_tile: ((h // gqa_deg + 1)) * head_tile] + + if h % gqa_deg == 0: + topk_idx_permuted_tile = build_block_to_token_triton( + topk_idx_permuted_tile, topk_idx_tile, permute_min_block_id, permute_max_block_id, padding_value=-1 + ) + + valid_topk_idx_permuted_tile = topk_idx_permuted_tile[topk_idx_permuted_tile != -1] + valid_lens = valid_lens_all[(h // gqa_deg) * head_tile, :] + valid_start_indices = torch.nn.functional.pad(valid_lens.cumsum(0)[:-1], (1, 0), value=0) + + index_mapping( + token_index_mapping, valid_topk_idx_permuted_tile, valid_lens, valid_start_indices, num_blocks + ) + + permute_results['valid_topk_idx_permuted_tile'].append(valid_topk_idx_permuted_tile) + permute_results['valid_lens'].append(valid_lens) + permute_results['valid_start_indices'].append(valid_start_indices) + + m_ij_tiles, m_ij_last = online_softmax( + q_tile, + k_tile, + m_i_cur_tiles, + valid_topk_idx_permuted_tile, + valid_lens, + valid_start_indices, + 0, + total_len, + block_size, + num_blocks, + head_tile, + head_dim, + sm_scale, + cu_seqlens_q, + 
cu_seqlens_k, + ) + + m_ij_tiles[:, :, :] = m_ij_tiles[:, :, 0][:, :, None] + m_ij_last[:, :] = m_ij_last[:, 0] + for compute_min_block_id in range(min(2, num_blocks)): + if compute_min_block_id == 0: + cur_max_valid_tokens = total_len + cur_valid_lens = valid_lens[0] + cur_valid_start_indices = valid_start_indices[0] + o_tiles = o_tiles_first + l_ij = l_ij_first + acc_o_scales = acc_o_scales_first + compute_tile_size = 1 + else: + cur_max_valid_tokens = valid_lens[compute_min_block_id:].max() + cur_valid_lens = valid_lens[compute_min_block_id:] + cur_valid_start_indices = valid_start_indices[compute_min_block_id:] + o_tiles = o_tiles_rest + l_ij = l_ij_rest + acc_o_scales = acc_o_scales_rest + compute_tile_size = num_blocks - 1 + + # launch kernel + qkv_kernel( + q_tile, + k_tile, + v_tile, + o_tiles, + acc_o_scales, + m_ij_tiles, + l_ij, + token_index_mapping, + valid_topk_idx_permuted_tile, + cur_valid_lens, + cur_valid_start_indices, + compute_min_block_id, + cur_max_valid_tokens, + head_tile, + compute_tile_size, + cu_seqlens_q, + cu_seqlens_k, + head_dim, + sm_scale, + block_size, + ) + + reduce_output( + lse, + o, + o_tiles_first, + o_tiles_rest, + m_ij_tiles, + l_ij_first, + l_ij_rest, + m_ij_last, + acc_o_scales_first, + acc_o_scales_rest, + topk_idx_tile, + token_index_mapping, + h, + head_tile, + total_len, + TOPK, + head_dim, + ) + + o_full[:, h * head_tile: (h + 1) * head_tile] = o + lse_full[h * head_tile: (h + 1) * head_tile] = lse + + if h % gqa_deg == 0: + fused_fill(topk_idx_permuted_tile, m_i_cur_tiles) + + return o_full, lse_full, permute_results + + +@triton.jit +def dq_compute_kernel( + q_ptr, + k_ptr, + v_ptr, + lse_ptr, + delta_ptr, + do_ptr, + dq_tiles_ptr, + token_index_mapping_ptr, + selected_tokens_ptr, + valid_lens_ptr, + valid_start_indices_ptr, + cur_max_valid_tokens, + compute_min_block_id, + head_tile, + num_blocks, + HEAD_DIM, + cu_seqlens_k, + num_dq_blocks, + sm_scale, + debug_ptr, + stride_qn, + stride_qh, + stride_qd, + 
stride_kn, + stride_kh, + stride_kd, + stride_vn, + stride_vh, + stride_vd, + stride_tim_h, + stride_tim_b, + stride_tim_n, + stride_dqth, + stride_dqtb, + stride_dqtn, + stride_dqtd, + BLOCK_SIZE_Q: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_D: tl.constexpr, +): + + pid_block = tl.program_id(0) + pid_q = tl.program_id(1) # token + # seq packing is not supported yet + q_start = 0 + k_start = 0 + + k_len = tl.load(cu_seqlens_k + 1) - k_start + + start_id = tl.load(valid_start_indices_ptr + pid_block) + valid_tokens = tl.load(valid_lens_ptr + pid_block) + if num_dq_blocks * pid_q * BLOCK_SIZE_Q >= valid_tokens: + return + + c = (pid_block + compute_min_block_id) * BLOCK_SIZE_K + k_ptrs = tl.make_block_ptr( + base=k_ptr + k_start * stride_kn, + shape=(k_len, HEAD_DIM), + strides=(stride_kn, stride_kd), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + + # load k + k = tl.load(tl.advance(k_ptrs, (c, 0)), boundary_check=(1, 0), padding_option="zero") + v_ptrs = tl.make_block_ptr( + base=v_ptr + k_start * stride_vn, + shape=(HEAD_DIM, k_len), + strides=(stride_vd, stride_vn), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K), + order=(0, 1), + ) + + # load v + v = tl.load(tl.advance(v_ptrs, (0, c)), boundary_check=(0, 1), padding_option="zero") + + qk_scale = sm_scale * 1.44269504 + + off_k = tl.arange(0, BLOCK_SIZE_K) + off_d = tl.arange(0, BLOCK_SIZE_D) + for j in range(num_dq_blocks): + pid_q_j = pid_q * num_dq_blocks + j + if pid_q_j * BLOCK_SIZE_Q < valid_tokens: + # one thread block for one KV block, a subset of selected tokens + st_offs = start_id + (q_start + pid_q_j * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q)) + # st should be in shape [BLOCK_SIZE_Q] + st_mask = (pid_q_j * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q)) < valid_tokens + + st = tl.load(selected_tokens_ptr + st_offs, mask=st_mask, other=-1) + tl.store(debug_ptr + tl.arange(0, BLOCK_SIZE_Q), st_offs) + # otherwise, st selects a set of q tokens, 
selected_tokens_ptr should be sorted + q_ptrs_off = st[:, None] * stride_qn + off_d[None, :] * stride_qd + + mask = st != -1 + + q_ptrs = q_ptr + q_start * stride_qn + q_ptrs_off + # load q + q_mask = mask[:, None] & (off_d < HEAD_DIM)[None, :] + q = tl.load(q_ptrs, mask=q_mask, other=0) + do_ptrs = do_ptr + q_start * stride_qn + q_ptrs_off + do = tl.load(do_ptrs, mask=q_mask, other=0) + delta_ptrs = delta_ptr + st[:, None] + d = tl.load(delta_ptrs, mask=mask[:, None], other=0) + lse_ptrs = lse_ptr + st[:, None] + lse = tl.load(lse_ptrs, mask=mask[:, None], other=0) + + dq = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_D), dtype=tl.float32) + qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32) + qk += tl.where((st[:, None] >= c + off_k[None, :]), 0, float("-inf")) + qk += tl.dot(q, tl.trans(k)) * qk_scale # [BLOCK_SIZE_Q, BLOCK_SIZE_K] + p = tl.exp2(qk - lse) # [BLOCK_SIZE_Q, BLOCK_SIZE_K] + dp = tl.dot(do, v) # [BLOCK_SIZE_Q, BLOCK_SIZE_K] + ds = sm_scale * p * (dp - d) # [BLOCK_SIZE_Q, BLOCK_SIZE_K] + ds = ds.to(q.dtype) + dq = tl.dot(ds, k) # [BLOCK_SIZE_Q, BLOCK_SIZE_D] + + # load token index mapping + token_index_mapping_ptrs = ( + token_index_mapping_ptr + (st) * stride_tim_n + (pid_block + compute_min_block_id) * stride_tim_b + ) + token_index_mapping = tl.load(token_index_mapping_ptrs, mask=mask, other=-1) + + dq_ptrs_off = token_index_mapping[:, None] * stride_dqtn + off_d[None, :] * stride_dqtd + dq_tiles_ptrs = dq_tiles_ptr + dq_ptrs_off + (pid_block).to(tl.int64) * stride_dqtb + tl.store(dq_tiles_ptrs, dq.to(dq_tiles_ptr.dtype.element_ty), mask=q_mask) + + +@triton.jit +def dq_reduce_kernel( + dq_buffer_first_ptr, # [H, 1, N, D] + dq_buffer_rest_ptr, # [H, B, N, D] + dq_ptr, # o: n x h x d + t_ptr, # topk_idx: h x n x k + token_index_mapping_ptr, + num_qz_loop, + TOPK, + total_len, + # stride + stride_dqtfh, + stride_dqtfb, + stride_dqtfn, + stride_dqtfd, + stride_dqtrh, + stride_dqtrb, + stride_dqtrn, + stride_dqtrd, + stride_dqn, + stride_dqh, + 
stride_dqd, + stride_th, + stride_tn, + stride_tk, + stride_tim_h, + stride_tim_b, + stride_tim_n, + # META parameters + BLOCK_SIZE_T: tl.constexpr, + BLOCK_SIZE_D: tl.constexpr, +): + pid_qy = tl.program_id(0) + pid_q = tl.program_id(1) # token + + pid_q_j = pid_q + pid_qy * num_qz_loop + if pid_q_j < total_len: + t_ptr_j = t_ptr + pid_q_j * stride_tn + + off_d = tl.arange(0, BLOCK_SIZE_D) + dq_ptrs = dq_ptr + pid_q_j * stride_dqn + off_d + acc_dq = tl.zeros((BLOCK_SIZE_D,), dtype=tl.float32) + + for block_id in range(TOPK): + t = tl.load(t_ptr_j + block_id * stride_tk, mask=block_id < TOPK, other=-1) + if t != -1: + if t == 0: + dq_buffer_ptr = dq_buffer_first_ptr + stride_dqtb = stride_dqtfb + stride_dqtn = stride_dqtfn + real_block_pos = 0 + else: + dq_buffer_ptr = dq_buffer_rest_ptr + stride_dqtb = stride_dqtrb + stride_dqtn = stride_dqtrn + real_block_pos = t - 1 + + # init pointers + token_index_mapping_ptrs = ( + token_index_mapping_ptr + t.to(tl.int64) * stride_tim_b + (pid_q_j) * stride_tim_n + ) + real_token_index = tl.load(token_index_mapping_ptrs) + + dq_buffer_ptrs = ( + dq_buffer_ptr + real_block_pos.to(tl.int64) * stride_dqtb + (real_token_index) * stride_dqtn + off_d + ) + + dq_buffers = tl.load(dq_buffer_ptrs) + acc_dq = dq_buffers + acc_dq + + tl.store(dq_ptrs, acc_dq, mask=off_d < BLOCK_SIZE_D) + + +def backward_dq_opt( + q, # [total_len, num_heads, head_dim] + k, # [total_len, num_k_heads, head_dim] + v, # [total_len, num_k_heads, head_dim] + topk_idx, # [num_k_heads, total_len, topk] + lse, # [num_heads, total_len] + delta, # [num_heads, total_len] + do, # [total_len, num_heads, head_dim] + dq, # [total_len, num_heads, head_dim] + cu_seqlens_q, + cu_seqlens_k, + num_k_heads, + num_share_q_heads, + head_dim, + topk, + sm_scale, + block_size, + permute_results, +): + """ + TODO: Currently sequence packing is explicitly done in for loop, will merge in kernels. 
+ """ + for i in range(len(cu_seqlens_q) - 1): + cu_seqlens_q_ = cu_seqlens_q[i: i + 2] - cu_seqlens_q[i] + cu_seqlens_k_ = cu_seqlens_k[i: i + 2] - cu_seqlens_k[i] + + permute_results_ = permute_results[i] + + q_ = q[cu_seqlens_q[i]: cu_seqlens_q[i + 1]] + k_ = k[cu_seqlens_k[i]: cu_seqlens_k[i + 1]] + v_ = v[cu_seqlens_k[i]: cu_seqlens_k[i + 1]] + topk_idx_ = topk_idx[:, cu_seqlens_q[i]: cu_seqlens_q[i + 1]] + lse_ = lse[:, cu_seqlens_q[i]: cu_seqlens_q[i + 1]] + delta_ = delta[:, cu_seqlens_q[i]: cu_seqlens_q[i + 1]] + do_ = do[cu_seqlens_q[i]: cu_seqlens_q[i + 1]] + dq_ = dq[cu_seqlens_q[i]: cu_seqlens_q[i + 1]] + + backward_dq_opt_per_seq( + q_, + k_, + v_, + topk_idx_, + lse_, + delta_, + do_, + dq_, + cu_seqlens_q_, + cu_seqlens_k_, + num_k_heads, + num_share_q_heads, + head_dim, + topk, + sm_scale, + block_size, + permute_results_, + ) + + dq[cu_seqlens_q[i]: cu_seqlens_q[i + 1]] = dq_ + + return dq + + +def backward_dq_opt_per_seq( + q, # [total_len, num_k_heads, head_dim] + k, # [total_len, num_k_heads, head_dim] + v, # [total_len, num_k_heads, head_dim] + topk_idx, # [num_k_heads, total_len, topk] + lse, # [num_k_heads, total_len] + delta, # [num_k_heads, total_len] + do, # [total_len, num_k_heads, head_dim] + dq, # [total_len, num_k_heads, head_dim] + cu_seqlens_q, + cu_seqlens_k, + num_k_heads, + num_share_q_heads, + head_dim, + topk, + sm_scale, + block_size, + permute_results, +): + head_tile = 1 + total_len = topk_idx.shape[1] + global_max_valid_tokens = permute_results['global_max_valid_tokens'] + num_blocks = permute_results['num_blocks'] + reduce_tile_size = num_blocks - 1 + dq_buffer_first = torch.zeros((head_tile, 1, total_len, head_dim), dtype=torch.bfloat16, device=dq.device) + dq_buffer_rest = torch.zeros( + (head_tile, reduce_tile_size, global_max_valid_tokens, head_dim), dtype=torch.bfloat16, device=dq.device + ) + + num_heads = num_share_q_heads * num_k_heads + + token_index_mapping = torch.full((head_tile, num_blocks, total_len), 0, 
dtype=torch.int32, device=q.device) + for h in range(num_heads // head_tile): + valid_topk_idx_permuted_tile = permute_results['valid_topk_idx_permuted_tile'][h // num_share_q_heads] + + valid_lens = permute_results['valid_lens'][h // num_share_q_heads] + valid_start_indices = permute_results['valid_start_indices'][h // num_share_q_heads] + + index_mapping(token_index_mapping, valid_topk_idx_permuted_tile, valid_lens, valid_start_indices, num_blocks) + q_tile = q[:, h * head_tile: (h + 1) * head_tile] + k_tile = k[:, (h // num_share_q_heads) * head_tile: ((h // num_share_q_heads + 1)) * head_tile] + v_tile = v[:, (h // num_share_q_heads) * head_tile: ((h // num_share_q_heads + 1)) * head_tile] + do_tile = do[:, h * head_tile: (h + 1) * head_tile] + lse_tile = lse[h * head_tile: (h + 1) * head_tile] + topk_idx_tile = topk_idx[(h // num_share_q_heads) * head_tile: ((h // num_share_q_heads + 1)) * head_tile] + delta_tile = delta[h * head_tile: (h + 1) * head_tile] + dq_tile = dq[:, h * head_tile: (h + 1) * head_tile] + + for compute_min_block_id in range(min(2, num_blocks)): + if compute_min_block_id == 0: + compute_tile_size = 1 + cur_max_valid_tokens = total_len + cur_valid_lens = valid_lens[0] + cur_valid_start_indices = valid_start_indices[0] + dq_buffer = dq_buffer_first + else: + compute_tile_size = num_blocks - 1 + cur_max_valid_tokens = valid_lens[compute_min_block_id:].max() + cur_valid_lens = valid_lens[compute_min_block_id:] + cur_valid_start_indices = valid_start_indices[compute_min_block_id:] + dq_buffer = dq_buffer_rest + + BLOCK_SIZE_Q = 128 + num_dq_blocks = 8 + grid_dq = lambda META: ( + compute_tile_size, + triton.cdiv(cur_max_valid_tokens, BLOCK_SIZE_Q * num_dq_blocks), + ) + + num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_Q, IS_HOPPER_GPU) + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + BLOCK_SIZE_K = triton.next_power_of_2(block_size) + debug = torch.zeros((BLOCK_SIZE_Q,), dtype=torch.int32, device=dq.device) + 
dq_compute_kernel[grid_dq]( + q_tile, + k_tile, + v_tile, + lse_tile, + delta_tile, + do_tile, + dq_buffer, + token_index_mapping, + valid_topk_idx_permuted_tile, + cur_valid_lens, + cur_valid_start_indices, + cur_max_valid_tokens, + compute_min_block_id, + head_tile, + num_blocks, + head_dim, + cu_seqlens_k, + num_dq_blocks, + sm_scale, + debug, + q_tile.stride(0), + q_tile.stride(1), + q_tile.stride(2), + k_tile.stride(0), + k_tile.stride(1), + k_tile.stride(2), + v_tile.stride(0), + v_tile.stride(1), + v_tile.stride(2), + token_index_mapping.stride(0), + token_index_mapping.stride(1), + token_index_mapping.stride(2), + dq_buffer.stride(0), + dq_buffer.stride(1), + dq_buffer.stride(2), + dq_buffer.stride(3), + BLOCK_SIZE_Q=BLOCK_SIZE_Q, + BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_D=BLOCK_SIZE_D, + num_warps=num_warps, + num_stages=num_stages, + ) + + num_qy_loop = 4 + num_qz_loop = total_len // num_qy_loop + + grid_reduce = lambda META: ( + num_qy_loop + (total_len % num_qy_loop != 0), + num_qz_loop, + ) + dq_reduce_kernel[grid_reduce]( + dq_buffer_first, + dq_buffer_rest, + dq_tile, + topk_idx_tile, + token_index_mapping, + num_qz_loop, + topk, + total_len, + dq_buffer_first.stride(0), + dq_buffer_first.stride(1), + dq_buffer_first.stride(2), + dq_buffer_first.stride(3), + dq_buffer_rest.stride(0), + dq_buffer_rest.stride(1), + dq_buffer_rest.stride(2), + dq_buffer_rest.stride(3), + dq_tile.stride(0), + dq_tile.stride(1), + dq_tile.stride(2), + topk_idx_tile.stride(0), + topk_idx_tile.stride(1), + topk_idx_tile.stride(2), + token_index_mapping.stride(0), + token_index_mapping.stride(1), + token_index_mapping.stride(2), + BLOCK_SIZE_T=triton.next_power_of_2(topk), + BLOCK_SIZE_D=BLOCK_SIZE_D, + num_warps=1, + num_stages=2, + ) + + dq[:, h * head_tile: (h + 1) * head_tile] = dq_tile + + return dq + + +@triton.jit +def backward_dkdv( + q_ptr, # Q: n x qh x d + k_ptr, # K: n x kh x d + v_ptr, # V: n x kh x d + tq_ptr, # topk_q_idx: kh x N + lse_ptr, # LSE: qh x n + 
d_ptr, # Delta: qh x n + do_ptr, + dk_ptr, # DK: sh x n x kh x d + dv_ptr, # DK: sh x n x kh x d + # seqlens + cu_seqlens_q, # [batch_size + 1] + cu_seqlens_k, # [batch_size + 1] + cu_seqblocks, # [batch_size + 1] + cu_topk_q_count, # [kh, total_blocks] + # shape + NUM_KV_HEADS, + NUM_SHARE_Q_HEADS, + HEAD_DIM, + TOPK, + # sm_scale + sm_scale, + # stride + stride_qn, + stride_qh, + stride_qd, + stride_kn, + stride_kh, + stride_kd, + stride_vn, + stride_vh, + stride_vd, + stride_tqh, + stride_tqn, + stride_ctqh, + stride_ctqn, + stride_lh, + stride_ln, + stride_dh, + stride_dn, + stride_don, + stride_doh, + stride_dod, + stride_dks, + stride_dkn, + stride_dkh, + stride_dkd, + stride_dvs, + stride_dvn, + stride_dvh, + stride_dvd, + # META parameters + BLOCK_SIZE_Q: tl.constexpr, # q block size + BLOCK_SIZE_K: tl.constexpr, # k block size + BLOCK_SIZE_D: tl.constexpr, +): + qk_scale = sm_scale * 1.44269504 + # get batch id and head id + pid_b = tl.program_id(0) + pid_h = tl.program_id(1) + pid_kh = pid_h // NUM_SHARE_Q_HEADS + pid_sh = pid_h % NUM_SHARE_Q_HEADS + pid_k = tl.program_id(2) + # get q k start and len after rmpad + q_start = tl.load(cu_seqlens_q + pid_b) + tl.load(cu_seqlens_q + pid_b + 1) - q_start + k_start = tl.load(cu_seqlens_k + pid_b) + k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start + if BLOCK_SIZE_K * pid_k >= k_len: + return + # get topk_q_idx + b_start = tl.load(cu_seqblocks + pid_b) # how many blocks before current sequence + act_q_start = tl.load(cu_topk_q_count + pid_kh * stride_ctqh + (b_start + pid_k) * stride_ctqn) + act_q_end = tl.load(cu_topk_q_count + pid_kh * stride_ctqh + (b_start + pid_k + 1) * stride_ctqn) + act_q_len = act_q_end - act_q_start + tq_ptr = tq_ptr + pid_kh * stride_tqh + act_q_start * stride_tqn + # init pointers + k_ptrs = tl.make_block_ptr( + base=k_ptr + k_start * stride_kn + pid_kh * stride_kh, + shape=(k_len, HEAD_DIM), + strides=(stride_kn, stride_kd), + offsets=(pid_k * BLOCK_SIZE_K, 0), + 
block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + dk_ptrs = tl.make_block_ptr( + base=dk_ptr + k_start * stride_dkn + pid_kh * stride_dkh + pid_sh * stride_dks, + shape=(k_len, HEAD_DIM), + strides=(stride_dkn, stride_dkd), + offsets=(pid_k * BLOCK_SIZE_K, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + v_ptrs = tl.make_block_ptr( + base=v_ptr + k_start * stride_vn + pid_kh * stride_vh, + shape=(k_len, HEAD_DIM), + strides=(stride_vn, stride_vd), + offsets=(pid_k * BLOCK_SIZE_K, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + dv_ptrs = tl.make_block_ptr( + base=dv_ptr + k_start * stride_dvn + pid_kh * stride_dvh + pid_sh * stride_dvs, + shape=(k_len, HEAD_DIM), + strides=(stride_dvn, stride_dvd), + offsets=(pid_k * BLOCK_SIZE_K, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + # offsets + off_q = tl.arange(0, BLOCK_SIZE_Q) + off_k = tl.arange(0, BLOCK_SIZE_K) + pid_k * BLOCK_SIZE_K + off_d = tl.arange(0, BLOCK_SIZE_D) + # load k v and keep in SRAM + k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero") + v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero") + # init dk dv + dk = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32) + dv = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32) + # init ptrs + q_ptrs = q_ptr + q_start * stride_qn + pid_h * stride_qh + off_d[None, :] * stride_qd + do_ptrs = do_ptr + q_start * stride_don + pid_h * stride_doh + off_d[None, :] * stride_dod + d_ptrs = d_ptr + q_start * stride_dn + pid_h * stride_dh + lse_ptrs = lse_ptr + q_start * stride_ln + pid_h * stride_lh + # loop for q blocks + for i in range(0, act_q_len, BLOCK_SIZE_Q): + # load + idx_q = tl.load(tq_ptr + i + off_q, mask=off_q < act_q_len - i, other=0).to(tl.int32) + q = tl.load( + q_ptrs + idx_q[:, None] * stride_qn, + mask=(off_q < act_q_len - i)[:, None] & (off_d < HEAD_DIM)[None, :], + other=0, + ) + do = tl.load( + do_ptrs + idx_q[:, None] * stride_don, 
+ mask=(off_q < act_q_len - i)[:, None] & (off_d < HEAD_DIM)[None, :], + other=0, + ) + lse = tl.load( + lse_ptrs + idx_q[:, None] * stride_ln, + mask=(off_q < act_q_len - i)[:, None], + other=0, + ) + d = tl.load( + d_ptrs + idx_q[:, None] * stride_dn, + mask=(off_q < act_q_len - i)[:, None], + other=0, + ) + # compute qk + qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32) + qk += tl.where(idx_q[:, None] >= off_k[None, :], float(0.0), float("-inf")) + qk += tl.dot(q, k.T) * qk_scale + # compute p, ds + p = tl.exp2(qk - lse) + dp = tl.dot(do, v.T) + ds = sm_scale * p * (dp - d) + # cast dtype + p = p.to(do.dtype) + ds = ds.to(q.dtype) + # update dk and dv + dk += tl.dot(ds.T, q) + dv += tl.dot(p.T, do) + # save dk dv + tl.store(dk_ptrs, dk.to(dk_ptr.dtype.element_ty), boundary_check=(0, 1)) + tl.store(dv_ptrs, dv.to(dv_ptr.dtype.element_ty), boundary_check=(0, 1)) + + +def _topk_sparse_attention_bwd_opt( + o: torch.Tensor, + do: torch.Tensor, + lse: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + topk_idx: torch.Tensor, + block_size: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + sm_scale: float, + permute_results, +): + + assert block_size in {16, 32, 64, 128, 256} + q_len, num_q_heads, head_dim = q.shape + k_len, num_k_heads, head_dim = k.shape + v_len, num_v_heads, head_dim = v.shape + o_len, num_o_heads, head_dim = o.shape + num_share_q_heads = num_q_heads // num_k_heads + topk = topk_idx.shape[-1] + # compute D + delta = torch.zeros([num_o_heads, o_len], device=o.device, dtype=torch.float32) + BLOCK_SIZE_O = 256 + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_O, IS_HOPPER_GPU) + grid = (triton.cdiv(o_len, BLOCK_SIZE_O), num_o_heads) + backward_sum_o_do[grid]( + o, + do, + delta, + o_len, + head_dim, + o.stride(0), + o.stride(1), + o.stride(2), + do.stride(0), + do.stride(1), + do.stride(2), + 
delta.stride(0), + delta.stride(1), + BLOCK_SIZE_O=BLOCK_SIZE_O, + BLOCK_SIZE_D=BLOCK_SIZE_D, + num_warps=num_warps, + num_stages=num_stages, + ) + # count active querys for each key block, shape: (num_k_heads, total_k_blocks) + seqlens = cu_seqlens_q[1:] - cu_seqlens_q[:-1] + seqblocks = torch.ceil(seqlens / block_size).to(torch.int32) + cu_seqblocks = torch.cat( + [ + torch.zeros(1, dtype=torch.int32, device=topk_idx.device), + torch.cumsum(seqblocks, dim=0), + ] + ).to(torch.int32) + + topk_q_count = torch.cat( + [ + permute_results[i]['valid_lens_all'][:, : permute_results[i]['real_num_blocks']] + for i in range(len(permute_results)) + ], + dim=1, + ) + + cu_topk_q_count = torch.cat( + [ + torch.zeros(topk_q_count.shape[0], 1, dtype=torch.int32, device=topk_idx.device), + torch.cumsum(topk_q_count, dim=-1), + ], + dim=-1, + ).to(torch.int32) + # active query idx for each key block + # how to get active query idx for sequence b, head h, kv block i? + topk_q_idx = reorder_topk_idx(topk_idx, cu_topk_q_count, cu_seqlens_q, cu_seqblocks, block_size) + # compute dk dv + dk = torch.zeros(num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype) + dv = torch.zeros(num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype) + batch_size = cu_seqlens_q.shape[0] - 1 + BLOCK_SIZE_K = triton.next_power_of_2(block_size) + BLOCK_SIZE_Q = 64 + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_Q, IS_HOPPER_GPU) + grid = (batch_size, num_q_heads, triton.cdiv(max_seqlen_k, BLOCK_SIZE_K)) + backward_dkdv[grid]( + q, + k, + v, + topk_q_idx, + lse, + delta, + do, + dk, + dv, + cu_seqlens_q, + cu_seqlens_k, + cu_seqblocks, + cu_topk_q_count, + num_k_heads, + num_share_q_heads, + head_dim, + topk, + sm_scale, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + topk_q_idx.stride(0), + 
topk_q_idx.stride(1), + cu_topk_q_count.stride(0), + cu_topk_q_count.stride(1), + lse.stride(0), + lse.stride(1), + delta.stride(0), + delta.stride(1), + do.stride(0), + do.stride(1), + do.stride(2), + dk.stride(0), + dk.stride(1), + dk.stride(2), + dk.stride(3), + dv.stride(0), + dv.stride(1), + dv.stride(2), + dv.stride(3), + BLOCK_SIZE_Q=BLOCK_SIZE_Q, + BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_D=BLOCK_SIZE_D, + num_warps=num_warps, + num_stages=num_stages, + ) + dk = dk.sum(0) + dv = dv.sum(0) + # compute dq + dq = torch.zeros_like(q) + num_q_loop = max_seqlen_q // 32768 + 1 # calculate multiple querys in one kernel if seqlence length is too long + grid = (batch_size, num_k_heads, triton.cdiv(max_seqlen_q, num_q_loop)) + BLOCK_SIZE_K = block_size + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_K, IS_HOPPER_GPU) + + backward_dq_opt( + q, + k, + v, + topk_idx, + lse, + delta, + do, + dq, + cu_seqlens_q, + cu_seqlens_k, + num_k_heads, + num_share_q_heads, + head_dim, + topk, + sm_scale, + block_size, + permute_results, + ) + + return dq, dk, dv + + +class FSATopkSparseAttention(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q: torch.Tensor, # [total_len, num_q_heads, head_dim] + k: torch.Tensor, # [total_len, num_k_heads, head_dim] + v: torch.Tensor, # [total_len, num_k_heads, head_dim] + topk_idx: torch.Tensor, # [num_kv_heads, total_len, topk] + block_size: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: torch.Tensor, + max_seqlen_k: torch.Tensor, + sm_scale=None, + ): + # dtype check + assert q.dtype == torch.bfloat16 or q.dtype == torch.float16 + assert q.dtype == k.dtype and k.dtype == v.dtype + assert topk_idx.dtype == torch.int32 + assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32 + # softmax scale + if sm_scale is None: + sm_scale = 1 / math.sqrt(q.shape[-1]) + + permute_results = None + + o, lse, 
permute_results = _topk_sparse_attention_fwd_opt( + q, + k, + v, + topk_idx, + block_size, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + sm_scale, + ) + + ctx.save_for_backward(q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k, topk_idx) + ctx.permute_results = permute_results + ctx.sm_scale = sm_scale + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_k = max_seqlen_k + ctx.block_size = block_size + return o + + @staticmethod + def backward(ctx, do: torch.Tensor, *args) -> Any: + q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k, topk_idx = ctx.saved_tensors + permute_results = ctx.permute_results + + max_seqlen_q = ctx.max_seqlen_q + max_seqlen_k = ctx.max_seqlen_k + sm_scale = ctx.sm_scale + block_size = ctx.block_size + assert block_size in {16, 32, 64, 128, 256} + + dq, dk, dv = _topk_sparse_attention_bwd_opt( + o, + do, + lse, + q, + k, + v, + topk_idx, + block_size, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + sm_scale, + permute_results, + ) + return dq, dk, dv, None, None, None, None, None, None, None, None + + +def FSA_topk_sparse_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + topk_idx: torch.Tensor, + block_size: int, + cu_seqlens: torch.Tensor, + softmax_scale: Optional[float] = None, +) -> torch.Tensor: + """Topk sparse attention varlen version implemented in triton. + + Args: + q (torch.Tensor): shape [total_len, num_q_heads, head_dim] + k (torch.Tensor): shape [total_len, num_kv_heads, head_dim] + v (torch.Tensor): shape [total_len, num_kv_heads, head_dim] + topk_idx (torch.Tensor): topk block idx for each query, shape [num_kv_heads, total_len, topk]. -1 means padding. + block_size (int): key value block size. + cu_seqlens (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens in flash_attn_func_varlen. + softmax_scale (Optional[float], optional): Defaults to None, means 1/sqrt(head_dim). 
+ + Returns: + torch.Tensor: attention output, shape [total_len, num_q_heads, head_dim] + """ + + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + return FSATopkSparseAttention.apply( + q, + k, + v, + topk_idx, + block_size, + cu_seqlens, + cu_seqlens, + max_seqlen, + max_seqlen, + softmax_scale, + ) + + +def FSA_topk_sparse_attention_varlen( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + topk_idx: torch.Tensor, + block_size: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + softmax_scale: Optional[float] = None, +) -> torch.Tensor: + """FSA topk sparse attention with separate Q and K sequence lengths (for extend/prefill). + + Args: + q (torch.Tensor): shape [total_q, num_q_heads, head_dim] + k (torch.Tensor): shape [total_k, num_kv_heads, head_dim] + v (torch.Tensor): shape [total_k, num_kv_heads, head_dim] + topk_idx (torch.Tensor): topk block idx for each query, shape [num_kv_heads, total_q, topk]. -1 means padding. + block_size (int): key value block size. + cu_seqlens_q (torch.Tensor): shape [batch_size + 1], cumulative Q sequence lengths. + cu_seqlens_k (torch.Tensor): shape [batch_size + 1], cumulative K sequence lengths. + softmax_scale (Optional[float], optional): Defaults to None, means 1/sqrt(head_dim). 
+ + Returns: + torch.Tensor: attention output, shape [total_q, num_q_heads, head_dim] + """ + max_seqlen_q = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).max().item() + max_seqlen_k = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).max().item() + return FSATopkSparseAttention.apply( + q, + k, + v, + topk_idx, + block_size, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + softmax_scale, + ) diff --git a/vortex_torch/attention_backend/fsa/__init__.py b/vortex_torch/attention_backend/fsa/__init__.py new file mode 100644 index 0000000..9efd474 --- /dev/null +++ b/vortex_torch/attention_backend/fsa/__init__.py @@ -0,0 +1,9 @@ +from .FSA_topk_sparse_attention import ( + FSA_topk_sparse_attention, + FSA_topk_sparse_attention_varlen, +) + +__all__ = [ + "FSA_topk_sparse_attention", + "FSA_topk_sparse_attention_varlen", +] diff --git a/vortex_torch/attention_backend/nsa/__init__.py b/vortex_torch/attention_backend/nsa/__init__.py new file mode 100644 index 0000000..382da01 --- /dev/null +++ b/vortex_torch/attention_backend/nsa/__init__.py @@ -0,0 +1,9 @@ +from .topk_sparse_attention import ( + topk_sparse_attention, + topk_sparse_attention_varlen, +) + +__all__ = [ + "topk_sparse_attention", + "topk_sparse_attention_varlen", +] diff --git a/vortex_torch/attention_backend/nsa/topk_sparse_attention.py b/vortex_torch/attention_backend/nsa/topk_sparse_attention.py new file mode 100644 index 0000000..57a2be7 --- /dev/null +++ b/vortex_torch/attention_backend/nsa/topk_sparse_attention.py @@ -0,0 +1,1280 @@ +# Copyright 2025 Xunhao Lai & Jianqiao Lu. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from typing import Any, Optional + +import torch +import triton +import triton.language as tl + +def is_hopper_gpu(): + if torch.cuda.is_available(): + device_capability = torch.cuda.get_device_capability(0) + major, minor = device_capability + return major == 9 + return False + + +def get_num_warps_stages(head_dim, block_size, is_hopper_gpu): + head_large = head_dim > 64 + block_large = block_size > 64 + if is_hopper_gpu: + if head_large and block_large: + num_warps, num_stages = 8, 3 + elif head_large or block_large: + num_warps, num_stages = 4, 3 + else: + num_warps, num_stages = 2, 2 + else: + if head_large and block_large: + num_warps, num_stages = 8, 3 + elif head_large or block_large: + num_warps, num_stages = 8, 3 + else: + num_warps, num_stages = 2, 2 + return num_warps, num_stages + + +IS_HOPPER_GPU = is_hopper_gpu() + + +@triton.jit +def forward_kernel_orig( + q_ptr, # Q: n x h x d + k_ptr, # K: n x kh x d + v_ptr, # V: n x kh x d + t_ptr, # topk_idx: kh x n x k + o_ptr, # O: n x h x d + lse_ptr, # LSE: h x n + # seqlens + cu_seqlens_q, + cu_seqlens_k, + # shape + NUM_KV_HEADS, + NUM_SHARE_Q_HEADS, + HEAD_DIM, + TOPK, + block_size, + # sm_scale + sm_scale, + # stride + stride_qn, + stride_qh, + stride_qd, + stride_kn, + stride_kh, + stride_kd, + stride_vn, + stride_vh, + stride_vd, + stride_th, + stride_tn, + stride_tk, + stride_on, + stride_oh, + stride_od, + stride_lh, + stride_ln, + # META parameters + # q loop num + num_q_loop: tl.constexpr, + num_k_loop: tl.constexpr, + MAX_SEQ_LEN: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, # 
k block size + BLOCK_SIZE_D: tl.constexpr, + BLOCK_SIZE_H: tl.constexpr, + BLOCK_SIZE_T: tl.constexpr, +): + qk_scale = sm_scale * 1.44269504 + # get batch id and head id + pid = tl.program_id(0) + + Q = MAX_SEQ_LEN // num_q_loop + HK = NUM_KV_HEADS // num_k_loop + + # 第几个 (b, kh_chunk, q_chunk) + pid_b = pid // (HK * Q) + pid_kh_chunk = (pid % (HK * Q)) // Q # 每个block处理num_k_loop个KV head + pid_q = pid % Q + + # get q k start and len after rmpad + q_start = tl.load(cu_seqlens_q + pid_b) + q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start + k_start = tl.load(cu_seqlens_k + pid_b) + k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start + + if pid_q * num_q_loop >= q_len: + return + real_q_loop = min(num_q_loop, q_len - pid_q * num_q_loop) + + for kh_offset in range(num_k_loop): + pid_kh = pid_kh_chunk * num_k_loop + kh_offset + pid_h = pid_kh * NUM_SHARE_Q_HEADS + + for j in range(real_q_loop): + pid_q_j = pid_q * num_q_loop + j + # init topk idx pointer + off_t = tl.arange(0, BLOCK_SIZE_T) + t_ptr_j = t_ptr + (q_start + pid_q_j) * stride_tn + pid_kh * stride_th + topk_idx = tl.load(t_ptr_j + off_t * stride_tk, mask=off_t < TOPK, other=-1) + + """Removed causal attention, which should be: + real_topk = tl.sum( + tl.where((topk_idx >= 0) & (topk_idx <= pid_q_j // block_size), 1, 0), + axis=0, + ) + """ + # real_topk = tl.sum( + # tl.where((topk_idx >= 0), 1, 0), + # axis=0, + # ) + real_topk = tl.sum( + tl.where((topk_idx >= 0) & (topk_idx <= pid_q_j // block_size), 1, 0), + axis=0, + ) + # init qkv pointer + q_ptrs = tl.make_block_ptr( + base=q_ptr + (q_start + pid_q_j) * stride_qn + pid_h * stride_qh, + shape=(NUM_SHARE_Q_HEADS, HEAD_DIM), + strides=(stride_qh, stride_qd), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D), + order=(1, 0), + ) + k_ptrs = tl.make_block_ptr( + base=k_ptr + k_start * stride_kn + pid_kh * stride_kh, + shape=(HEAD_DIM, k_len), + strides=(stride_kd, stride_kn), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K), + 
order=(0, 1), + ) + v_ptrs = tl.make_block_ptr( + base=v_ptr + k_start * stride_vn + pid_kh * stride_vh, + shape=(k_len, HEAD_DIM), + strides=(stride_vn, stride_vd), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + # load q + q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero") + # init statistics + off_h = tl.arange(0, BLOCK_SIZE_H) + off_k = tl.arange(0, BLOCK_SIZE_K) + m_i = tl.full((BLOCK_SIZE_H,), float("-inf"), dtype=tl.float32) + lse_i = tl.full((BLOCK_SIZE_H,), float("-inf"), dtype=tl.float32) + acc_o = tl.full((BLOCK_SIZE_H, BLOCK_SIZE_D), 0, dtype=tl.float32) + # sparse attention + for i in range(real_topk): + # get current block start index + c = tl.load(t_ptr_j).to(tl.int32) * BLOCK_SIZE_K + t_ptr_j = t_ptr_j + stride_tk + # load k + k = tl.load(tl.advance(k_ptrs, (0, c)), boundary_check=(1, 0), padding_option="zero") + # compute qk + qk = tl.zeros((BLOCK_SIZE_H, BLOCK_SIZE_K), dtype=tl.float32) + qk += tl.where((pid_q_j >= c + off_k)[None, :], 0, float("-inf")) + # [BLOCK_SIZE_H, HEAD_DIM] @ [HEAD_DIM, BLOCK_SIZE_K] -> [BLOCK_SIZE_H, BLOCK_SIZE_K] + qk += tl.dot(q, k) * qk_scale + # compute m_ij and l_ij + m_ij = tl.maximum(m_i, tl.max(qk, axis=1)) + p = tl.exp2(qk - m_ij[:, None]) + l_ij = tl.sum(p, axis=1) + # scale acc_o + acc_o_scale = tl.exp2(m_i - m_ij) + acc_o = acc_o * acc_o_scale[:, None] + # load v and update acc_o + v = tl.load(tl.advance(v_ptrs, (c, 0)), boundary_check=(0, 1), padding_option="zero") + p = p.to(v.dtype) + acc_o += tl.dot(p, v) + # update statistics + m_i = m_ij + lse_i = m_ij + tl.math.log2(tl.exp2(lse_i - m_ij) + l_ij) + + # final scale + acc_o = acc_o * tl.exp2(m_i - lse_i)[:, None] + # save output + o_ptrs = tl.make_block_ptr( + base=o_ptr + (q_start + pid_q_j) * stride_on + pid_h * stride_oh, + shape=(NUM_SHARE_Q_HEADS, HEAD_DIM), + strides=(stride_oh, stride_od), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D), + order=(1, 0), + ) + tl.store(o_ptrs, 
acc_o.to(o_ptr.dtype.element_ty), boundary_check=(0, 1)) + # save lse + lse_ptrs = lse_ptr + (q_start + pid_q_j) * stride_ln + (pid_h + off_h) * stride_lh + tl.store(lse_ptrs, lse_i, mask=off_h < NUM_SHARE_Q_HEADS) + + +@triton.jit +def backward_sum_o_do( + o_ptr, # O: n x h x d + do_ptr, # dO: n x h x d + delta_ptr, # D: h x n + o_len, + HEAD_DIM, + stride_on, + stride_oh, + stride_od, + stride_don, + stride_doh, + stride_dod, + stride_dh, + stride_dn, + BLOCK_SIZE_O: tl.constexpr, + BLOCK_SIZE_D: tl.constexpr, +): + pid_n = tl.program_id(0) + pid_h = tl.program_id(1) + off_o = pid_n * BLOCK_SIZE_O + tl.arange(0, BLOCK_SIZE_O) + off_d = tl.arange(0, BLOCK_SIZE_D) + o = tl.load( + o_ptr + off_o[:, None] * stride_on + pid_h * stride_oh + off_d[None, :] * stride_od, + mask=(off_o[:, None] < o_len) & (off_d[None, :] < HEAD_DIM), + other=0, + ).to(tl.float32) + do = tl.load( + do_ptr + off_o[:, None] * stride_don + pid_h * stride_doh + off_d[None, :] * stride_dod, + mask=(off_o[:, None] < o_len) & (off_d[None, :] < HEAD_DIM), + other=0, + ).to(tl.float32) + delta = tl.sum(o * do, axis=1) + tl.store(delta_ptr + pid_h * stride_dh + off_o * stride_dn, delta, mask=off_o < o_len) + + +@triton.jit +def count_kernel( + x_ptr, # [num_kv_heads, total_len, topk] + y_ptr, # [num_kv_heads, total_blocks] + cu_seqlens, # [batch_size + 1] + cu_seqblocks, # [batch_size + 1] + topk, + stride_xh, + stride_xn, + stride_xk, + stride_yh, + stride_yn, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_R: tl.constexpr, +): + pid_h = tl.program_id(0) + pid_b = tl.program_id(1) + # get start and len after rmpad + seq_start = tl.load(cu_seqlens + pid_b) + seq_len = tl.load(cu_seqlens + pid_b + 1) - seq_start + blocks_start = tl.load(cu_seqblocks + pid_b) + num_blocks = tl.load(cu_seqblocks + pid_b + 1) - blocks_start + # load x + off_k = tl.arange(0, BLOCK_SIZE_K) + off_n = tl.arange(0, BLOCK_SIZE_N) + x_ptr = x_ptr + pid_h * stride_xh + seq_start * stride_xn + x_ptrs = 
x_ptr + off_n[:, None] * stride_xn + off_k[None, :] * stride_xk + # init y + y = tl.zeros((BLOCK_SIZE_R,), dtype=tl.int32) + # loop + for i in range(0, seq_len, BLOCK_SIZE_N): + x = tl.load( + x_ptrs, + mask=(off_n < seq_len - i)[:, None] & (off_k < topk)[None, :], + other=-1, + ) + x = tl.ravel(x) + y += tl.histogram(x, BLOCK_SIZE_R) + x_ptrs += BLOCK_SIZE_N * stride_xn + # store result + off_r = tl.arange(0, BLOCK_SIZE_R) + y_ptr = y_ptr + pid_h * stride_yh + blocks_start * stride_yn + y_ptrs = y_ptr + off_r * stride_yn + tl.store(y_ptrs, y.to(y_ptr.dtype.element_ty), mask=off_r < num_blocks) + + +def count_query( + topk_idx: torch.Tensor, + cu_seqlens: torch.Tensor, + cu_seqblocks: torch.Tensor, + block_size: int, +): + num_kv_heads, total_len, topk = topk_idx.shape + seqlens = cu_seqlens[1:] - cu_seqlens[:-1] + seqblocks = cu_seqblocks[1:] - cu_seqblocks[:-1] + batch_size = seqlens.shape[0] + BLOCK_SIZE_K = triton.next_power_of_2(topk) + BLOCK_SIZE_N = triton.next_power_of_2(4096 // BLOCK_SIZE_K) + BLOCK_SIZE_R = triton.next_power_of_2(seqblocks.max().item() + 2) + active_query_count = torch.zeros(num_kv_heads, cu_seqblocks[-1], dtype=torch.int32, device=topk_idx.device) + grid = (num_kv_heads, batch_size) + count_kernel[grid]( + topk_idx, + active_query_count, + cu_seqlens, + cu_seqblocks, + topk, + topk_idx.stride(0), + topk_idx.stride(1), + topk_idx.stride(2), + active_query_count.stride(0), + active_query_count.stride(1), + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_R=BLOCK_SIZE_R, + num_warps=4, + num_stages=3, + ) + return active_query_count + + +@triton.jit +def pad_topk_idx_kernel( + t_ptr, + p_ptr, + cu_seqlens, + topk, + stride_th, + stride_tn, + stride_tk, + stride_pb, + stride_ph, + stride_pn, + stride_pk, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_T: tl.constexpr, +): + pid_b = tl.program_id(0) + pid_h = tl.program_id(1) + pid_n = tl.program_id(2) + # get q start and len after rmpad + q_start = tl.load(cu_seqlens + pid_b) 
+ q_len = tl.load(cu_seqlens + pid_b + 1) - q_start + if BLOCK_SIZE_N * pid_n >= q_len: + return + # init prts + t_ptrs = tl.make_block_ptr( + base=t_ptr + pid_h * stride_th + q_start * stride_tn, + shape=(q_len, topk), + strides=(stride_tn, stride_tk), + offsets=(pid_n * BLOCK_SIZE_N, 0), + block_shape=(BLOCK_SIZE_N, BLOCK_SIZE_T), + order=(1, 0), + ) + p_ptrs = tl.make_block_ptr( + base=p_ptr + pid_b * stride_pb + pid_h * stride_ph, + shape=(q_len, topk), + strides=(stride_pn, stride_pk), + offsets=(pid_n * BLOCK_SIZE_N, 0), + block_shape=(BLOCK_SIZE_N, BLOCK_SIZE_T), + order=(1, 0), + ) + # load and save + idxs = tl.load(t_ptrs, boundary_check=(0, 1)) + tl.store(p_ptrs, idxs, boundary_check=(0, 1)) + + +@triton.jit +def save_topk_idx_kernel( + p_ptr, + t_ptr, + cu_seqblocks, + cu_topk_q_count, + n_len, + stride_pb, + stride_ph, + stride_pn, + stride_th, + stride_tn, + stride_ch, + stride_cn, + BLOCK_SIZE_N: tl.constexpr, +): + pid_b = tl.program_id(0) + pid_h = tl.program_id(1) + pid_n = tl.program_id(2) + # get q start and len after rmpad + q_block_start = tl.load(cu_seqblocks + pid_b) + q_block_end = tl.load(cu_seqblocks + pid_b + 1) + c_start = tl.load(cu_topk_q_count + pid_h * stride_ch + q_block_start * stride_cn) + c_end = tl.load(cu_topk_q_count + pid_h * stride_ch + q_block_end * stride_cn) + c_len = c_end - c_start + if c_len <= 0: + return + if pid_n * BLOCK_SIZE_N >= c_len: + return + # init ptrs + p_ptrs = tl.make_block_ptr( + base=p_ptr + pid_b * stride_pb + pid_h * stride_ph + (n_len - c_len) * stride_pn, + shape=(c_len,), + strides=(stride_pn,), + offsets=(pid_n * BLOCK_SIZE_N,), + block_shape=(BLOCK_SIZE_N,), + order=(0,), + ) + t_ptrs = tl.make_block_ptr( + base=t_ptr + pid_h * stride_th + c_start * stride_tn, + shape=(c_len,), + strides=(stride_tn,), + offsets=(pid_n * BLOCK_SIZE_N,), + block_shape=(BLOCK_SIZE_N,), + order=(0,), + ) + # load and save + idxs = tl.load(p_ptrs, boundary_check=(0,)) + tl.store(t_ptrs, idxs, boundary_check=(0,)) + + 
+def reorder_topk_idx( + topk_idx: torch.Tensor, + cu_topk_q_count: torch.Tensor, + cu_seqlens: torch.Tensor, + cu_seqblocks: torch.Tensor, + block_size: int, +): + num_kv_heads, total_len, topk = topk_idx.shape + batch_size = cu_seqlens.shape[0] - 1 + seq_lens = cu_seqlens[1:] - cu_seqlens[:-1] + max_seqlen = seq_lens.max().item() + # pad shape [num_kv_heads, total_seqlen, topk] to [batch_size, num_kv_heads, max_seqlen, topk] + pad_topk_idx = torch.full( + (batch_size, num_kv_heads, max_seqlen, topk), + fill_value=-1, + device=topk_idx.device, + dtype=torch.int32, + ) + BLOCK_SIZE_T = triton.next_power_of_2(topk) + BLOCK_SIZE_N = min(triton.next_power_of_2(max_seqlen), triton.next_power_of_2(8192 // BLOCK_SIZE_T)) + grid = (batch_size, num_kv_heads, triton.cdiv(max_seqlen, BLOCK_SIZE_N)) + pad_topk_idx_kernel[grid]( + topk_idx, + pad_topk_idx, + cu_seqlens, + topk, + topk_idx.stride(0), + topk_idx.stride(1), + topk_idx.stride(2), + pad_topk_idx.stride(0), + pad_topk_idx.stride(1), + pad_topk_idx.stride(2), + pad_topk_idx.stride(3), + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_T=BLOCK_SIZE_T, + ) + # argsort + pad_topk_q_idx = pad_topk_idx.view(batch_size, num_kv_heads, -1).argsort(-1) // topk + pad_topk_q_idx = pad_topk_q_idx.to(torch.int32) + # save as remove pad version + topk_q_idx = torch.full( + (num_kv_heads, cu_topk_q_count[:, -1].max().item()), + fill_value=-1, + device=topk_idx.device, + dtype=torch.int32, + ) + max_len = (cu_topk_q_count[:, cu_seqblocks][:, 1:] - cu_topk_q_count[:, cu_seqblocks][:, :-1]).max().item() + BLOCK_SIZE_N = min(triton.next_power_of_2(max_len), 8192) + grid = (batch_size, num_kv_heads, triton.cdiv(max_len, BLOCK_SIZE_N)) + save_topk_idx_kernel[grid]( + pad_topk_q_idx, + topk_q_idx, + cu_seqblocks, + cu_topk_q_count, + pad_topk_q_idx.shape[-1], + pad_topk_q_idx.stride(0), + pad_topk_q_idx.stride(1), + pad_topk_q_idx.stride(2), + topk_q_idx.stride(0), + topk_q_idx.stride(1), + cu_topk_q_count.stride(0), + cu_topk_q_count.stride(1), + 
BLOCK_SIZE_N=BLOCK_SIZE_N, + ) + return topk_q_idx + + +@triton.jit +def backward_dkdv( + q_ptr, # Q: n x qh x d + k_ptr, # K: n x kh x d + v_ptr, # V: n x kh x d + tq_ptr, # topk_q_idx: kh x N + lse_ptr, # LSE: qh x n + d_ptr, # Delta: qh x n + do_ptr, + dk_ptr, # DK: sh x n x kh x d + dv_ptr, # DK: sh x n x kh x d + # seqlens + cu_seqlens_q, # [batch_size + 1] + cu_seqlens_k, # [batch_size + 1] + cu_seqblocks, # [batch_size + 1] + cu_topk_q_count, # [kh, total_blocks] + # shape + NUM_KV_HEADS, + NUM_SHARE_Q_HEADS, + HEAD_DIM, + TOPK, + # sm_scale + sm_scale, + # stride + stride_qn, + stride_qh, + stride_qd, + stride_kn, + stride_kh, + stride_kd, + stride_vn, + stride_vh, + stride_vd, + stride_tqh, + stride_tqn, + stride_ctqh, + stride_ctqn, + stride_lh, + stride_ln, + stride_dh, + stride_dn, + stride_don, + stride_doh, + stride_dod, + stride_dks, + stride_dkn, + stride_dkh, + stride_dkd, + stride_dvs, + stride_dvn, + stride_dvh, + stride_dvd, + # META parameters + BLOCK_SIZE_Q: tl.constexpr, # q block size + BLOCK_SIZE_K: tl.constexpr, # k block size + BLOCK_SIZE_D: tl.constexpr, +): + qk_scale = sm_scale * 1.44269504 + # get batch id and head id + pid_b = tl.program_id(0) + pid_h = tl.program_id(1) + pid_kh = pid_h // NUM_SHARE_Q_HEADS + pid_sh = pid_h % NUM_SHARE_Q_HEADS + pid_k = tl.program_id(2) + # get q k start and len after rmpad + q_start = tl.load(cu_seqlens_q + pid_b) + tl.load(cu_seqlens_q + pid_b + 1) - q_start + k_start = tl.load(cu_seqlens_k + pid_b) + k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start + if BLOCK_SIZE_K * pid_k >= k_len: + return + # get topk_q_idx + b_start = tl.load(cu_seqblocks + pid_b) # how many blocks before current sequence + act_q_start = tl.load(cu_topk_q_count + pid_kh * stride_ctqh + (b_start + pid_k) * stride_ctqn) + act_q_end = tl.load(cu_topk_q_count + pid_kh * stride_ctqh + (b_start + pid_k + 1) * stride_ctqn) + act_q_len = act_q_end - act_q_start + tq_ptr = tq_ptr + pid_kh * stride_tqh + act_q_start * stride_tqn + # 
init pointers + k_ptrs = tl.make_block_ptr( + base=k_ptr + k_start * stride_kn + pid_kh * stride_kh, + shape=(k_len, HEAD_DIM), + strides=(stride_kn, stride_kd), + offsets=(pid_k * BLOCK_SIZE_K, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + dk_ptrs = tl.make_block_ptr( + base=dk_ptr + k_start * stride_dkn + pid_kh * stride_dkh + pid_sh * stride_dks, + shape=(k_len, HEAD_DIM), + strides=(stride_dkn, stride_dkd), + offsets=(pid_k * BLOCK_SIZE_K, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + v_ptrs = tl.make_block_ptr( + base=v_ptr + k_start * stride_vn + pid_kh * stride_vh, + shape=(k_len, HEAD_DIM), + strides=(stride_vn, stride_vd), + offsets=(pid_k * BLOCK_SIZE_K, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + dv_ptrs = tl.make_block_ptr( + base=dv_ptr + k_start * stride_dvn + pid_kh * stride_dvh + pid_sh * stride_dvs, + shape=(k_len, HEAD_DIM), + strides=(stride_dvn, stride_dvd), + offsets=(pid_k * BLOCK_SIZE_K, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + # offsets + off_q = tl.arange(0, BLOCK_SIZE_Q) + off_k = tl.arange(0, BLOCK_SIZE_K) + pid_k * BLOCK_SIZE_K + off_d = tl.arange(0, BLOCK_SIZE_D) + # load k v and keep in SRAM + k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero") + v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero") + # init dk dv + dk = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32) + dv = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32) + # init ptrs + q_ptrs = q_ptr + q_start * stride_qn + pid_h * stride_qh + off_d[None, :] * stride_qd + do_ptrs = do_ptr + q_start * stride_don + pid_h * stride_doh + off_d[None, :] * stride_dod + d_ptrs = d_ptr + q_start * stride_dn + pid_h * stride_dh + lse_ptrs = lse_ptr + q_start * stride_ln + pid_h * stride_lh + # loop for q blocks + for i in range(0, act_q_len, BLOCK_SIZE_Q): + # load + idx_q = tl.load(tq_ptr + i + off_q, mask=off_q < act_q_len - i, 
other=0).to(tl.int32) + q = tl.load( + q_ptrs + idx_q[:, None] * stride_qn, + mask=(off_q < act_q_len - i)[:, None] & (off_d < HEAD_DIM)[None, :], + other=0, + ) + do = tl.load( + do_ptrs + idx_q[:, None] * stride_don, + mask=(off_q < act_q_len - i)[:, None] & (off_d < HEAD_DIM)[None, :], + other=0, + ) + lse = tl.load( + lse_ptrs + idx_q[:, None] * stride_ln, + mask=(off_q < act_q_len - i)[:, None], + other=0, + ) + d = tl.load( + d_ptrs + idx_q[:, None] * stride_dn, + mask=(off_q < act_q_len - i)[:, None], + other=0, + ) + # compute qk + qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32) + qk += tl.where(idx_q[:, None] >= off_k[None, :], float(0.0), float("-inf")) + qk += tl.dot(q, k.T) * qk_scale + # compute p, ds + p = tl.exp2(qk - lse) + dp = tl.dot(do, v.T) + ds = sm_scale * p * (dp - d) + # cast dtype + p = p.to(do.dtype) + ds = ds.to(q.dtype) + # update dk and dv + dk += tl.dot(ds.T, q) + dv += tl.dot(p.T, do) + # save dk dv + tl.store(dk_ptrs, dk.to(dk_ptr.dtype.element_ty), boundary_check=(0, 1)) + tl.store(dv_ptrs, dv.to(dv_ptr.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def backward_dq( + q_ptr, # Q: n x qh x d + k_ptr, # K: n x kh x d + v_ptr, # V: n x kh x d + t_ptr, # topk_idx: kh x n x k + lse_ptr, # LSE: qh x n + d_ptr, # Delta: qh x n + do_ptr, + dq_ptr, + # seqlens + cu_seqlens_q, + cu_seqlens_k, + # shape + NUM_KV_HEADS, + NUM_SHARE_Q_HEADS, + HEAD_DIM, + TOPK, + # q loop num + num_q_loop, + # sm_scale + sm_scale, + # stride + stride_qn, + stride_qh, + stride_qd, + stride_kn, + stride_kh, + stride_kd, + stride_vn, + stride_vh, + stride_vd, + stride_th, + stride_tn, + stride_tk, + stride_lh, + stride_ln, + stride_dh, + stride_dn, + stride_don, + stride_doh, + stride_dod, + stride_dqn, + stride_dqh, + stride_dqd, + # META parameters + BLOCK_SIZE_K: tl.constexpr, # k block size + BLOCK_SIZE_D: tl.constexpr, + BLOCK_SIZE_H: tl.constexpr, + BLOCK_SIZE_T: tl.constexpr, +): + qk_scale = sm_scale * 1.44269504 + # get batch id 
and head id + pid_b = tl.program_id(0) + pid_kh = tl.program_id(1) + pid_q = tl.program_id(2) + pid_h = pid_kh * NUM_SHARE_Q_HEADS + # get q k start and len after rmpad + q_start = tl.load(cu_seqlens_q + pid_b) + q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start + k_start = tl.load(cu_seqlens_k + pid_b) + k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start + if pid_q * num_q_loop >= q_len: + return + real_q_loop = min(num_q_loop, q_len - pid_q * num_q_loop) + for j in range(real_q_loop): + pid_q_j = pid_q * num_q_loop + j + # init topk idx pointer + off_t = tl.arange(0, BLOCK_SIZE_T) + t_ptr_j = t_ptr + (q_start + pid_q_j) * stride_tn + pid_kh * stride_th + topk_idx = tl.load(t_ptr_j + off_t * stride_tk, mask=off_t < TOPK, other=-1) + + real_topk = tl.sum( + tl.where((topk_idx >= 0) & (topk_idx <= pid_q_j // BLOCK_SIZE_K), 1, 0), + axis=0, + ) + # init pointers + q_ptrs = tl.make_block_ptr( + base=q_ptr + (q_start + pid_q_j) * stride_qn + pid_h * stride_qh, + shape=(NUM_SHARE_Q_HEADS, HEAD_DIM), + strides=(stride_qh, stride_qd), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D), + order=(1, 0), + ) + dq_ptrs = tl.make_block_ptr( + base=dq_ptr + (q_start + pid_q_j) * stride_dqn + pid_h * stride_dqh, + shape=(NUM_SHARE_Q_HEADS, HEAD_DIM), + strides=(stride_dqh, stride_dqd), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D), + order=(1, 0), + ) + k_ptrs = tl.make_block_ptr( + base=k_ptr + k_start * stride_kn + pid_kh * stride_kh, + shape=(k_len, HEAD_DIM), + strides=(stride_kn, stride_kd), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + v_ptrs = tl.make_block_ptr( + base=v_ptr + k_start * stride_vn + pid_kh * stride_vh, + shape=(HEAD_DIM, k_len), + strides=(stride_vd, stride_vn), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K), + order=(0, 1), + ) + do_ptrs = tl.make_block_ptr( + base=do_ptr + (q_start + pid_q_j) * stride_don + pid_h * stride_doh, + shape=(NUM_SHARE_Q_HEADS, HEAD_DIM), + 
strides=(stride_doh, stride_dod), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D), + order=(1, 0), + ) + d_ptrs = tl.make_block_ptr( + base=d_ptr + (q_start + pid_q_j) * stride_dn + pid_h * stride_dh, + shape=(NUM_SHARE_Q_HEADS, 1), + strides=(stride_dh, stride_dn), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_H, 1), + order=(1, 0), + ) + lse_ptrs = tl.make_block_ptr( + base=lse_ptr + (q_start + pid_q_j) * stride_ln + pid_h * stride_lh, + shape=(NUM_SHARE_Q_HEADS, 1), + strides=(stride_lh, stride_ln), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_H, 1), + order=(1, 0), + ) + # offsets + off_k = tl.arange(0, BLOCK_SIZE_K) + # load q, do, lse, delta, and keep in SRAM + q = tl.load(q_ptrs, boundary_check=(1, 0), padding_option="zero") + do = tl.load(do_ptrs, boundary_check=(0, 1), padding_option="zero") + lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero") + d = tl.load(d_ptrs, boundary_check=(0, 1), padding_option="zero") + # init dq + dq = tl.zeros((BLOCK_SIZE_H, BLOCK_SIZE_D), dtype=tl.float32) + # sparse + for i in range(real_topk): + # get current block start index + c = tl.load(t_ptr_j).to(tl.int32) * BLOCK_SIZE_K + t_ptr_j = t_ptr_j + stride_tk + # load + k = tl.load(tl.advance(k_ptrs, (c, 0)), boundary_check=(1, 0), padding_option="zero") + v = tl.load(tl.advance(v_ptrs, (0, c)), boundary_check=(0, 1), padding_option="zero") + # compute qk + qk = tl.zeros((BLOCK_SIZE_H, BLOCK_SIZE_K), dtype=tl.float32) + qk += tl.where((pid_q_j >= c + off_k)[None, :], 0, float("-inf")) + # [BLOCK_SIZE_H, HEAD_DIM] @ [BLOCK_SIZE_K, HEAD_DIM].T -> [BLOCK_SIZE_H, BLOCK_SIZE_K] + qk += tl.dot(q, tl.trans(k)) * qk_scale + # compute p, ds + p = tl.exp2(qk - lse) + dp = tl.dot(do, v) + ds = sm_scale * p * (dp - d) + # cast dtype + ds = ds.to(q.dtype) + # update dq + dq += tl.dot(ds, k) + # save dq + tl.store(dq_ptrs, dq.to(dq_ptr.dtype.element_ty), boundary_check=(0, 1)) + + +def _topk_sparse_attention_fwd( + q: torch.Tensor, # [total_len, num_q_heads, 
head_dim] + k: torch.Tensor, # [total_len, num_k_heads, head_dim] + v: torch.Tensor, # [total_len, num_k_heads, head_dim] + topk_idx: torch.Tensor, # [num_kv_heads, total_len, topk] + block_size: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + sm_scale: float, +): + # dtype check + assert k.dtype == q.dtype and v.dtype == q.dtype + assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32 + assert block_size in {16, 32, 64, 128, 256} + # shape + q_len, num_q_heads, head_dim = q.shape + k_len, num_k_heads, head_dim = k.shape + v_len, num_v_heads, head_dim = v.shape + batch_size = cu_seqlens_q.shape[0] - 1 + # assert q_len == k_len and k_len == v_len + topk = topk_idx.shape[-1] + assert topk_idx.shape[0] == num_k_heads + assert topk_idx.shape[1] == q_len + # gqa + assert num_k_heads == num_v_heads + assert num_q_heads % num_k_heads == 0 + num_share_q_heads = num_q_heads // num_k_heads + # output tensor + o = torch.zeros_like(q) + + lse = torch.zeros(num_q_heads, q_len, dtype=torch.float32, device=q.device) + + # launch kernel + num_q_loop = num_k_loop = 1 + BLOCK_SIZE_K = triton.next_power_of_2(block_size) + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + BLOCK_SIZE_H = max(16, triton.next_power_of_2(num_share_q_heads)) + BLOCK_SIZE_T = triton.next_power_of_2(topk) + + def grid(meta): + grid = ( + batch_size * triton.cdiv(num_k_heads, num_k_loop) * triton.cdiv(max_seqlen_q, num_q_loop), + ) + return grid + + num_warps, num_stages = get_num_warps_stages(head_dim, block_size, IS_HOPPER_GPU) + forward_kernel_orig[grid]( + q, + k, + v, + topk_idx, + o, + lse, + cu_seqlens_q, + cu_seqlens_k, + num_k_heads, + num_share_q_heads, + head_dim, + topk, + block_size, + # num_q_loop, + sm_scale, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + topk_idx.stride(0), + topk_idx.stride(1), + topk_idx.stride(2), + 
o.stride(0), + o.stride(1), + o.stride(2), + lse.stride(0), + lse.stride(1), + num_q_loop=num_q_loop, + num_k_loop=num_k_loop, + MAX_SEQ_LEN=max_seqlen_q, + BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_D=BLOCK_SIZE_D, + BLOCK_SIZE_H=BLOCK_SIZE_H, + BLOCK_SIZE_T=BLOCK_SIZE_T, + num_warps=num_warps, + num_stages=num_stages, + ) + return o, lse + + +def _topk_sparse_attention_bwd( + o: torch.Tensor, + do: torch.Tensor, + lse: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + topk_idx: torch.Tensor, + block_size: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + sm_scale: float, +): + + assert block_size in {16, 32, 64, 128, 256} + q_len, num_q_heads, head_dim = q.shape + k_len, num_k_heads, head_dim = k.shape + v_len, num_v_heads, head_dim = v.shape + o_len, num_o_heads, head_dim = o.shape + num_share_q_heads = num_q_heads // num_k_heads + topk = topk_idx.shape[-1] + # compute D + delta = torch.zeros([num_o_heads, o_len], device=o.device, dtype=torch.float32) + BLOCK_SIZE_O = 256 + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_O, IS_HOPPER_GPU) + grid = (triton.cdiv(o_len, BLOCK_SIZE_O), num_o_heads) + + backward_sum_o_do[grid]( + o, + do, + delta, + o_len, + head_dim, + o.stride(0), + o.stride(1), + o.stride(2), + do.stride(0), + do.stride(1), + do.stride(2), + delta.stride(0), + delta.stride(1), + BLOCK_SIZE_O=BLOCK_SIZE_O, + BLOCK_SIZE_D=BLOCK_SIZE_D, + num_warps=num_warps, + num_stages=num_stages, + ) + # count active querys for each key block, shape: (num_k_heads, total_k_blocks) + seqlens = cu_seqlens_q[1:] - cu_seqlens_q[:-1] + seqblocks = torch.ceil(seqlens / block_size).to(torch.int32) + cu_seqblocks = torch.cat( + [ + torch.zeros(1, dtype=torch.int32, device=topk_idx.device), + torch.cumsum(seqblocks, dim=0), + ] + ).to(torch.int32) + + topk_q_count = count_query(topk_idx, cu_seqlens_q, cu_seqblocks, block_size) + 
+ cu_topk_q_count = torch.cat( + [ + torch.zeros(topk_q_count.shape[0], 1, dtype=torch.int32, device=topk_idx.device), + torch.cumsum(topk_q_count, dim=-1), + ], + dim=-1, + ).to(torch.int32) + # active query idx for each key block + # how to get active query idx for sequence b, head h, kv block i? + topk_q_idx = reorder_topk_idx(topk_idx, cu_topk_q_count, cu_seqlens_q, cu_seqblocks, block_size) + # compute dk dv + dk = torch.zeros(num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype) + dv = torch.zeros(num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype) + batch_size = cu_seqlens_q.shape[0] - 1 + BLOCK_SIZE_K = triton.next_power_of_2(block_size) + BLOCK_SIZE_Q = 64 + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_Q, IS_HOPPER_GPU) + grid = (batch_size, num_q_heads, triton.cdiv(max_seqlen_k, BLOCK_SIZE_K)) + backward_dkdv[grid]( + q, + k, + v, + topk_q_idx, + lse, + delta, + do, + dk, + dv, + cu_seqlens_q, + cu_seqlens_k, + cu_seqblocks, + cu_topk_q_count, + num_k_heads, + num_share_q_heads, + head_dim, + topk, + sm_scale, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + topk_q_idx.stride(0), + topk_q_idx.stride(1), + cu_topk_q_count.stride(0), + cu_topk_q_count.stride(1), + lse.stride(0), + lse.stride(1), + delta.stride(0), + delta.stride(1), + do.stride(0), + do.stride(1), + do.stride(2), + dk.stride(0), + dk.stride(1), + dk.stride(2), + dk.stride(3), + dv.stride(0), + dv.stride(1), + dv.stride(2), + dv.stride(3), + BLOCK_SIZE_Q=BLOCK_SIZE_Q, + BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_D=BLOCK_SIZE_D, + num_warps=num_warps, + num_stages=num_stages, + ) + dk = dk.sum(0) + dv = dv.sum(0) + # compute dq + dq = torch.zeros_like(q) + num_q_loop = max_seqlen_q // 32768 + 1 # calculate multiple querys in one kernel if seqlence length is too long + grid = 
(batch_size, num_k_heads, triton.cdiv(max_seqlen_q, num_q_loop)) + BLOCK_SIZE_K = block_size + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + BLOCK_SIZE_H = max(16, triton.next_power_of_2(num_share_q_heads)) + BLOCK_SIZE_T = triton.next_power_of_2(topk) + num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_K, IS_HOPPER_GPU) + + backward_dq[grid]( + q, + k, + v, + topk_idx, + lse, + delta, + do, + dq, + cu_seqlens_q, + cu_seqlens_k, + num_k_heads, + num_share_q_heads, + head_dim, + topk, + num_q_loop, + sm_scale, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + topk_idx.stride(0), + topk_idx.stride(1), + topk_idx.stride(2), + lse.stride(0), + lse.stride(1), + delta.stride(0), + delta.stride(1), + do.stride(0), + do.stride(1), + do.stride(2), + dq.stride(0), + dq.stride(1), + dq.stride(2), + BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_D=BLOCK_SIZE_D, + BLOCK_SIZE_H=BLOCK_SIZE_H, + BLOCK_SIZE_T=BLOCK_SIZE_T, + num_warps=num_warps, + num_stages=num_stages, + ) + return dq, dk, dv + + +class TopkSparseAttention(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q: torch.Tensor, # [total_len, num_q_heads, head_dim] + k: torch.Tensor, # [total_len, num_k_heads, head_dim] + v: torch.Tensor, # [total_len, num_k_heads, head_dim] + topk_idx: torch.Tensor, # [num_kv_heads, total_len, topk] + block_size: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: torch.Tensor, + max_seqlen_k: torch.Tensor, + sm_scale=None, + ): + # dtype check + assert q.dtype == torch.bfloat16 or q.dtype == torch.float16 + assert q.dtype == k.dtype and k.dtype == v.dtype + assert topk_idx.dtype == torch.int32 + assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32 + # softmax scale + if sm_scale is None: + sm_scale = 1 / math.sqrt(q.shape[-1]) + + o, lse = _topk_sparse_attention_fwd( + q, + k, + v, + topk_idx, + block_size, + 
cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + sm_scale, + ) + + ctx.save_for_backward(q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k, topk_idx) + ctx.sm_scale = sm_scale + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_k = max_seqlen_k + ctx.block_size = block_size + return o + + @staticmethod + def backward(ctx, do: torch.Tensor, *args) -> Any: + q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k, topk_idx = ctx.saved_tensors + + max_seqlen_q = ctx.max_seqlen_q + max_seqlen_k = ctx.max_seqlen_k + sm_scale = ctx.sm_scale + block_size = ctx.block_size + assert block_size in {16, 32, 64, 128, 256} + + dq, dk, dv = _topk_sparse_attention_bwd( + o, + do, + lse, + q, + k, + v, + topk_idx, + block_size, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + sm_scale, + ) + return dq, dk, dv, None, None, None, None, None, None, None, None + + +def topk_sparse_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + topk_idx: torch.Tensor, + block_size: int, + cu_seqlens: torch.Tensor, + softmax_scale: Optional[float] = None, +) -> torch.Tensor: + """Topk sparse attention varlen version implemented in triton. + + Args: + q (torch.Tensor): shape [total_len, num_q_heads, head_dim] + k (torch.Tensor): shape [total_len, num_kv_heads, head_dim] + v (torch.Tensor): shape [total_len, num_kv_heads, head_dim] + topk_idx (torch.Tensor): topk block idx for each query, shape [num_kv_heads, total_len, topk]. -1 means padding. + block_size (int): key value block size. + cu_seqlens (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens in flash_attn_func_varlen. + softmax_scale (Optional[float], optional): Defaults to None, means 1/sqrt(head_dim). 
+ + Returns: + torch.Tensor: attention output, shape [total_len, num_q_heads, head_dim] + """ + + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + return TopkSparseAttention.apply( + q, + k, + v, + topk_idx, + block_size, + cu_seqlens, + cu_seqlens, + max_seqlen, + max_seqlen, + softmax_scale, + ) + + +def topk_sparse_attention_varlen( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + topk_idx: torch.Tensor, + block_size: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + softmax_scale: Optional[float] = None, +) -> torch.Tensor: + """Topk sparse attention with separate Q and K sequence lengths (for extend/prefill). + + Same as topk_sparse_attention but accepts separate cu_seqlens for Q and K. + Useful when Q only covers new tokens while K covers all tokens (prefix + new). + + Args: + q (torch.Tensor): shape [total_q, num_q_heads, head_dim] + k (torch.Tensor): shape [total_k, num_kv_heads, head_dim] + v (torch.Tensor): shape [total_k, num_kv_heads, head_dim] + topk_idx (torch.Tensor): topk block idx for each query, shape [num_kv_heads, total_q, topk]. -1 means padding. + block_size (int): key value block size. + cu_seqlens_q (torch.Tensor): shape [batch_size + 1], cumulative Q sequence lengths. + cu_seqlens_k (torch.Tensor): shape [batch_size + 1], cumulative K sequence lengths. + softmax_scale (Optional[float], optional): Defaults to None, means 1/sqrt(head_dim). 
+ + Returns: + torch.Tensor: attention output, shape [total_q, num_q_heads, head_dim] + """ + max_seqlen_q = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).max().item() + max_seqlen_k = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).max().item() + return TopkSparseAttention.apply( + q, + k, + v, + topk_idx, + block_size, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + softmax_scale, + ) diff --git a/vortex_torch/flow/external_algorithms.py b/vortex_torch/flow/external_algorithms.py new file mode 100644 index 0000000..5f8935f --- /dev/null +++ b/vortex_torch/flow/external_algorithms.py @@ -0,0 +1,76 @@ +""" +External sparse attention algorithm registrations for NSA, FSA, and FlashMoBA. + +These vFlow subclasses use simple centroid-based routing for the DECODE path +(forward_indexer + forward_cache), identical to BlockSparseAttention. + +The EXTEND path (forward_extend) is handled directly in vtx_graph_backend.py +using each algorithm's own sparse attention kernel — these vFlow classes are +not involved in extend. +""" + +import torch +from typing import Dict, Tuple + +from .flow import vFlow +from ..indexer import topK, GeMV +from ..cache import Mean as CMean +from ..abs import ContextBase +from .registry import register + + +class _ExternalAlgoBase(vFlow): + """ + Base vFlow for external sparse attention algorithms (NSA, FSA, FlashMoBA). + + Decode routing: centroid-based (same as BlockSparseAttention). + Extend: bypassed — vtx_graph_backend dispatches to algorithm-specific kernels. 
+ """ + + def __init__(self): + super().__init__() + self.gemv = GeMV() + self.output_func = topK() + self.reduction = CMean(dim=1) + + def forward_indexer( + self, + q: torch.Tensor, + o: torch.Tensor, + cache: Dict[str, torch.Tensor], + ctx: ContextBase, + ): + q_mean = q.mean(dim=1, keepdim=True) + score = self.gemv(q_mean, cache["centroids"], ctx=ctx) + self.output_func(score, o, ctx=ctx) + + def forward_cache( + self, + cache: Dict[str, torch.Tensor], + loc: torch.Tensor, + ctx: ContextBase, + ): + self.reduction(cache["k"], cache["centroids"], loc=loc, ctx=ctx) + + def create_cache(self, page_size: int, head_dim: int) -> Dict[str, Tuple[int, int]]: + return { + "centroids": (1, head_dim), + } + + +@register("nsa") +class NSASparseAttention(_ExternalAlgoBase): + """Naive Sparse Attention — decode uses centroid routing, extend uses NSA kernels.""" + pass + + +@register("fsa") +class FSASparseAttention(_ExternalAlgoBase): + """Flash Sparse Attention — decode uses centroid routing, extend uses FSA kernels.""" + pass + + +@register("flash_moba") +class FlashMoBASparseAttention(_ExternalAlgoBase): + """FlashMoBA — decode uses centroid routing, extend uses FlashMoBA kernels.""" + pass diff --git a/vortex_torch/kernels/__init__.py b/vortex_torch/kernels/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vortex_torch/kernels/fsa/__init__.py b/vortex_torch/kernels/fsa/__init__.py new file mode 100644 index 0000000..25d5b3e --- /dev/null +++ b/vortex_torch/kernels/fsa/__init__.py @@ -0,0 +1,5 @@ +from .fused_score_kernels import _fused_attention_score_and_transform + +__all__ = [ + "_fused_attention_score_and_transform", +] diff --git a/vortex_torch/kernels/fsa/fused_score_kernels.py b/vortex_torch/kernels/fsa/fused_score_kernels.py new file mode 100644 index 0000000..f2a05ed --- /dev/null +++ b/vortex_torch/kernels/fsa/fused_score_kernels.py @@ -0,0 +1,300 @@ +# This file provides a fused implementation of computing attention score for selected 
attention indices. +# TODO: this implementation may incur illegal memory access issues, will be fixed. +import math + +import torch +import triton +import triton.language as tl + +from ..nsa.utils import is_hopper_gpu + +IS_HOPPER_GPU = is_hopper_gpu() + + +@triton.jit +def fused_score_kernel( + q_ptr, # q_len x h x d + k_ptr, # k_len x h x d + lse_ptr, # h x n + bs_ptr, # h x n x nb + offs_ptr, # BO + kernel_size, + kernel_stride, + num_offs, # BO + num_k_blocks, + # seqlens + cu_seqlens_q, + cu_seqlens_k, + # shape + NUM_KV_HEADS, # which is also num_q_heads + HEAD_DIM, + # sm_scale + sm_scale, + max_blocks, + pad_len, + block_size, + block_stride, + init_blocks, + local_blocks, + # stride + stride_qn, + stride_qh, + stride_qd, + stride_kn, + stride_kh, + stride_kd, + stride_lh, + stride_ln, + stride_bsh, + stride_bsq, + stride_bsnb, + # META parameters + BLOCK_SIZE_Q: tl.constexpr, # q block size + BLOCK_SIZE_K: tl.constexpr, # k block size + BLOCK_SIZE_D: tl.constexpr, +): + qk_scale = sm_scale * 1.44269504 + # get batch id and head id + pid_bkh = tl.program_id(0) + pid_b = pid_bkh // NUM_KV_HEADS + pid_kh = pid_bkh % NUM_KV_HEADS + pid_q = tl.program_id(1) + pid_k = tl.program_id(2) # the blocks id of k + # get q k start and len after rmpad + q_start = tl.load(cu_seqlens_q + pid_b) + q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start + k_start = tl.load(cu_seqlens_k + pid_b) + k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start + + k_start += pid_k * BLOCK_SIZE_K * num_k_blocks + if pid_q * BLOCK_SIZE_Q >= q_len or pid_k * BLOCK_SIZE_K >= k_len: + return + + q_ptrs = tl.make_block_ptr( + base=q_ptr + q_start * stride_qn + pid_kh * stride_qh, + shape=(q_len, HEAD_DIM), + strides=(stride_qn, stride_qd), + offsets=(pid_q * BLOCK_SIZE_Q, 0), + block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D), + order=(1, 0), + ) + lse_ptrs = tl.make_block_ptr( + base=lse_ptr + q_start * stride_ln + pid_kh * stride_lh, + shape=(q_len, 1), + strides=(stride_ln, stride_lh), + offsets=(pid_q 
* BLOCK_SIZE_Q, 0), + block_shape=(BLOCK_SIZE_Q, 1), + order=(0, 1), + ) + # load q and lse + q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero") + lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero") + + for j in range(num_k_blocks): + k_start_j = k_start + j * BLOCK_SIZE_K + if k_start_j < k_len: + off_d = tl.arange(0, BLOCK_SIZE_D) + off_q = pid_q * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q) + # k offsets + off_k = (k_start_j + tl.arange(0, BLOCK_SIZE_K)) * block_stride - pad_len + k_ptrs = k_ptr + pid_kh * stride_kh + off_k[None, :] * stride_kn + off_d[:, None] * stride_kd + causal_mask = off_q[:, None] >= (off_k * kernel_stride + kernel_size - 1)[None, :] + + # init block score + bs = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32) + for i in range(num_offs): + k = tl.load(k_ptrs, mask=causal_mask, other=0) + w = tl.load(offs_ptr + i, mask=i < num_offs, other=0) + # compute qk + qk = tl.dot(q, k) * qk_scale + # compute score and apply weight + bs += w * tl.where(causal_mask, tl.exp2(qk - lse), 0) + + # increment pointers + off_k += 1 + k_ptrs = k_ptr + pid_kh * stride_kh + off_k[None, :] * stride_kn + off_d[:, None] * stride_kd + causal_mask = off_q[:, None] >= (off_k * kernel_stride + kernel_size - 1)[None, :] + + # init mask and local mask + off_bq = off_q // block_size + off_bk = tl.arange(0, BLOCK_SIZE_K) + bs = tl.where( + ( + (off_bq[:, None] >= k_start_j + off_bk[None, :]) + & (off_bq[:, None] < k_start_j + off_bk[None, :] + local_blocks) + ) + | (off_bk[None, :] < init_blocks - k_start_j), + float("inf"), + bs, + ) + + # save output + bs_ptrs = ( + bs_ptr + + pid_kh.to(tl.int64) * stride_bsh + + q_start * stride_bsq + + k_start_j * stride_bsnb + + off_q[:, None] * stride_bsq + + off_bk[None, :] * stride_bsnb + ) + + tl.store( + bs_ptrs, + bs.to(bs_ptr.dtype.element_ty), + mask=(off_q < q_len)[:, None] & (off_bk < max_blocks - k_start_j)[None, :], + ) + + +def _fused_attention_score_and_transform( + q: 
torch.Tensor, # [total_query_len, num_q_heads, head_dim] + k: torch.Tensor, # [total_key_len, num_k_heads, head_dim] + lse: torch.Tensor, # [num_q_heads, total_query_len] + kernel_size: int, + kernel_stride: int, + block_size: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + sm_scale: float, + init_blocks: int = 1, + local_blocks: int = 2, + align_baseline: bool = False, +) -> torch.Tensor: + + q_len, num_q_heads, head_dim = q.shape + k_len, num_k_heads, head_dim = k.shape + max_blocks = math.ceil(max_seqlen_q / block_size) + # init block score + block_scores = torch.zeros( + num_k_heads, + q_len, + max_blocks, + dtype=torch.float32 if align_baseline else torch.bfloat16, + device=q.device, + ) + offs = ( + torch.arange(kernel_size // kernel_stride, device=q.device)[:, None] + + torch.arange(block_size // kernel_stride, device=q.device)[None, :] + ).view(-1) + + offs = torch.histc(offs, bins=offs.max() + 1, min=0, max=offs.max()) + + num_offs = int(offs.shape[0]) + for i in range(cu_seqlens_q.shape[0] - 1): + q_seq = q[cu_seqlens_q[i]: cu_seqlens_q[i + 1]] + k_seq = k[cu_seqlens_k[i]: cu_seqlens_k[i + 1]] + lse_seq = lse[:, cu_seqlens_q[i]: cu_seqlens_q[i + 1]] + block_scores_seq = block_scores[:, cu_seqlens_q[i]: cu_seqlens_q[i + 1]] + + _fused_attention_score_and_transform_per_seq( + q_seq, + k_seq, + lse_seq, + block_scores_seq, + kernel_size, + kernel_stride, + block_size, + offs, + num_offs, + cu_seqlens_q[i: i + 2] - cu_seqlens_q[i], + cu_seqlens_k[i: i + 2] - cu_seqlens_k[i], + cu_seqlens_q[i + 1] - cu_seqlens_q[i], + cu_seqlens_k[i + 1] - cu_seqlens_k[i], + sm_scale, + init_blocks, + local_blocks, + ) + block_scores[:, cu_seqlens_q[i]: cu_seqlens_q[i + 1]] = block_scores_seq + return block_scores + + +@torch.inference_mode() +def _fused_attention_score_and_transform_per_seq( + q: torch.Tensor, # [total_query_len, num_q_heads, head_dim] + k: torch.Tensor, # [total_key_len, num_k_heads, head_dim] + 
lse: torch.Tensor, # [num_q_heads, total_query_len] + block_score: torch.Tensor, + kernel_size: int, + kernel_stride: int, + block_size: int, + offs, + num_offs, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + sm_scale: float, + init_blocks: int = 1, + local_blocks: int = 2, +) -> torch.Tensor: + # dtype check + assert q.dtype == torch.bfloat16 or q.dtype == torch.float16 + assert q.dtype == k.dtype + assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32 + assert lse.dtype == torch.float32 # lse here is log2(sum(exp(qk*scale))), not log(sum(exp(qk*scale))) + # shape + q_len, num_q_heads, head_dim = q.shape + k_len, num_k_heads, head_dim = k.shape + batch_size = cu_seqlens_q.shape[0] - 1 + assert q_len > k_len + if sm_scale is None: + sm_scale = 1 / math.sqrt(head_dim) + + max_blocks = math.ceil(max_seqlen_q / block_size) + + pad_len = kernel_size // kernel_stride - 1 + max_blocks = math.ceil(max_seqlen_q / block_size) + + BLOCK_SIZE_K = min(128, triton.next_power_of_2(max_blocks)) + # ensure qk is valid on triton + BLOCK_SIZE_K = max(BLOCK_SIZE_K, 16) + BLOCK_SIZE_Q = 128 + + # launch kernel + num_k_blocks = 1 + grid = lambda META: ( + batch_size * num_k_heads, + triton.cdiv(max_seqlen_q, BLOCK_SIZE_Q), + triton.cdiv(max_blocks, BLOCK_SIZE_K * num_k_blocks), + ) + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + + fused_score_kernel[grid]( + q, + k, + lse, + block_score, + offs, + kernel_size, + kernel_stride, + num_offs, + num_k_blocks, + cu_seqlens_q, + cu_seqlens_k, + num_k_heads, + head_dim, + sm_scale, + max_blocks, + pad_len, + block_size, + block_size // kernel_stride, + init_blocks, + local_blocks, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + lse.stride(0), + lse.stride(1), + block_score.stride(0), + block_score.stride(1), + block_score.stride(2), + BLOCK_SIZE_Q=BLOCK_SIZE_Q, + BLOCK_SIZE_K=BLOCK_SIZE_K, + 
BLOCK_SIZE_D=BLOCK_SIZE_D, + num_warps=8, + num_stages=3, + ) diff --git a/vortex_torch/kernels/nsa/__init__.py b/vortex_torch/kernels/nsa/__init__.py new file mode 100644 index 0000000..9af3029 --- /dev/null +++ b/vortex_torch/kernels/nsa/__init__.py @@ -0,0 +1,24 @@ +# Copyright 2025 Xunhao Lai. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .compressed_attention import compressed_attention +from .weighted_pool import (avgpool_compress, softmaxpool_compress, + weightedpool_compress) + +__all__ = [ + "compressed_attention", + "avgpool_compress", + "weightedpool_compress", + "softmaxpool_compress", +] diff --git a/vortex_torch/kernels/nsa/compressed_attention.py b/vortex_torch/kernels/nsa/compressed_attention.py new file mode 100644 index 0000000..9770a94 --- /dev/null +++ b/vortex_torch/kernels/nsa/compressed_attention.py @@ -0,0 +1,1317 @@ +# Copyright 2025 Xunhao Lai & Jianqiao Lu. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import math +import warnings +from typing import Any, Tuple, Union + +import torch +import triton +import triton.language as tl + +from .utils import get_num_warps_stages, is_hopper_gpu + +IS_HOPPER_GPU = is_hopper_gpu() + + +@triton.jit +def forward_kernel( + q_ptr, # Q: n x h x d + k_ptr, # K: n x h x d + v_ptr, # V: n x h x d + o_ptr, # O: n x h x d + lse_ptr, # LSE: h x n + # size and stride at compresstion + kernel_size, + kernel_stride, + # seqlens + cu_seqlens_q, + cu_seqlens_k, + # shape + NUM_KV_HEADS, + NUM_SHARE_Q_HEADS, + HEAD_DIM, + # sm_scale + sm_scale, + # stride + stride_qn, + stride_qh, + stride_qd, + stride_kn, + stride_kh, + stride_kd, + stride_vn, + stride_vh, + stride_vd, + stride_on, + stride_oh, + stride_od, + stride_lh, + stride_ln, + # META parameters + BLOCK_SIZE_Q: tl.constexpr, # q block size + BLOCK_SIZE_K: tl.constexpr, # k block size + BLOCK_SIZE_D: tl.constexpr, +): + qk_scale = sm_scale * 1.44269504 + # get batch id and head id + pid_b = tl.program_id(0) + pid_h = tl.program_id(1) + pid_q = tl.program_id(2) + pid_kh = pid_h // NUM_SHARE_Q_HEADS + # get q k start and len after rmpad + q_start = tl.load(cu_seqlens_q + pid_b) + q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start + k_start = tl.load(cu_seqlens_k + pid_b) + k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start + # skip first kernel_size query block, because they do no attend to any keys + q_start_in_seq = pid_q * BLOCK_SIZE_Q + kernel_size - 1 + if q_start_in_seq >= q_len: + return + # init qkv pointer + q_ptrs = tl.make_block_ptr( + base=q_ptr + q_start * stride_qn + pid_h * stride_qh, + shape=(q_len, HEAD_DIM), + strides=(stride_qn, stride_qd), + offsets=(q_start_in_seq, 0), + block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D), + order=(1, 0), + ) + k_ptrs = tl.make_block_ptr( + base=k_ptr + k_start * stride_kn + pid_kh * stride_kh, + shape=(HEAD_DIM, k_len), + strides=(stride_kd, stride_kn), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K), + order=(0, 1), + ) + 
v_ptrs = tl.make_block_ptr( + base=v_ptr + k_start * stride_vn + pid_kh * stride_vh, + shape=(k_len, HEAD_DIM), + strides=(stride_vn, stride_vd), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + # load q + q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero") + # init statistics + off_q = tl.arange(0, BLOCK_SIZE_Q) + q_start_in_seq + off_k = tl.arange(0, BLOCK_SIZE_K) * kernel_stride + kernel_size - 1 + m_i = tl.full((BLOCK_SIZE_Q,), float("-inf"), dtype=tl.float32) + lse_i = tl.full((BLOCK_SIZE_Q,), float("-inf"), dtype=tl.float32) + acc_o = tl.full((BLOCK_SIZE_Q, BLOCK_SIZE_D), 0, dtype=tl.float32) + # attention + lo = 0 + hi = min(k_len, (q_start_in_seq + BLOCK_SIZE_Q - kernel_size) // kernel_stride + 1) + for i in range(lo, hi, BLOCK_SIZE_K): + i = tl.multiple_of(i, BLOCK_SIZE_K) + # load k + k = tl.load(k_ptrs, boundary_check=(1, 0), padding_option="zero") + # compute qk + qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32) + qk += tl.where(off_q[:, None] >= (i * kernel_stride + off_k)[None, :], 0, float("-inf")) + qk += tl.dot(q, k) * qk_scale + # compute m_ij and l_ij + m_ij = tl.maximum(m_i, tl.max(qk, axis=1)) + p = tl.exp2(qk - m_ij[:, None]) + l_ij = tl.sum(p, axis=1) + # scale acc_o + acc_o_scale = tl.exp2(m_i - m_ij) + acc_o = acc_o * acc_o_scale[:, None] + # load v and update acc_o + v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero") + p = p.to(v.dtype) + acc_o += tl.dot(p, v) + # update statistics + m_i = m_ij + lse_i = m_ij + tl.math.log2(tl.exp2(lse_i - m_ij) + l_ij) + # update ptrs + k_ptrs = tl.advance(k_ptrs, (0, BLOCK_SIZE_K)) + v_ptrs = tl.advance(v_ptrs, (BLOCK_SIZE_K, 0)) + # final scale + acc_o = acc_o * tl.exp2(m_i - lse_i)[:, None] + # save output + o_ptrs = tl.make_block_ptr( + base=o_ptr + q_start * stride_on + pid_h * stride_oh, + shape=(q_len, HEAD_DIM), + strides=(stride_on, stride_od), + offsets=(q_start_in_seq, 0), + block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D), + 
order=(1, 0), + ) + tl.store(o_ptrs, acc_o.to(o_ptr.dtype.element_ty), boundary_check=(0, 1)) + # save lse + l_ptrs = lse_ptr + q_start * stride_ln + pid_h * stride_lh + off_q * stride_ln + tl.store(l_ptrs, lse_i, mask=off_q < q_len) + + +@triton.jit +def backward_sum_o_do( + o_ptr, # O: n x h x d + do_ptr, # dO: n x h x d + delta_ptr, # D: h x n + o_len, + HEAD_DIM, + stride_on, + stride_oh, + stride_od, + stride_don, + stride_doh, + stride_dod, + stride_dh, + stride_dn, + BLOCK_SIZE_O: tl.constexpr, + BLOCK_SIZE_D: tl.constexpr, +): + pid_n = tl.program_id(0) + pid_h = tl.program_id(1) + off_n = pid_n * BLOCK_SIZE_O + tl.arange(0, BLOCK_SIZE_O) + off_d = tl.arange(0, BLOCK_SIZE_D) + o = tl.load( + o_ptr + off_n[:, None] * stride_on + pid_h * stride_oh + off_d[None, :] * stride_od, + mask=(off_n[:, None] < o_len) & (off_d[None, :] < HEAD_DIM), + other=0, + ).to(tl.float32) + do = tl.load( + do_ptr + off_n[:, None] * stride_don + pid_h * stride_doh + off_d[None, :] * stride_dod, + mask=(off_n[:, None] < o_len) & (off_d[None, :] < HEAD_DIM), + other=0, + ).to(tl.float32) + delta = tl.sum(o * do, axis=1) + tl.store(delta_ptr + pid_h * stride_dh + off_n * stride_dn, delta, mask=off_n < o_len) + + +@triton.jit +def backward_dkdv( + q_ptr, # Q: n x qh x d + k_ptr, # K: n x kh x d + v_ptr, # V: n x kh x d + lse_ptr, # LSE: qh x n + d_ptr, # Delta: qh x n + do_ptr, + dk_ptr, # DK: sh x n x kh x d + dv_ptr, # DV: sh x n x kh x d + kernel_size, + kernel_stride, + # seqlens + cu_seqlens_q, + cu_seqlens_k, + # shape + NUM_KV_HEADS, + NUM_SHARE_Q_HEADS, + HEAD_DIM, + # sm_scale + sm_scale, + # stride + stride_qn, + stride_qh, + stride_qd, + stride_kn, + stride_kh, + stride_kd, + stride_vn, + stride_vh, + stride_vd, + stride_lh, + stride_ln, + stride_dh, + stride_dn, + stride_don, + stride_doh, + stride_dod, + stride_dks, + stride_dkn, + stride_dkh, + stride_dkd, + stride_dvs, + stride_dvn, + stride_dvh, + stride_dvd, + # META parameters + BLOCK_SIZE_Q: tl.constexpr, # q block 
size + BLOCK_SIZE_K: tl.constexpr, # k block size + BLOCK_SIZE_D: tl.constexpr, +): + qk_scale = sm_scale * 1.44269504 + # get batch id and head id + pid_b = tl.program_id(0) + pid_h = tl.program_id(1) + pid_kh = pid_h // NUM_SHARE_Q_HEADS + pid_sh = pid_h % NUM_SHARE_Q_HEADS + pid_k = tl.program_id(2) + # get q k start and len after rmpad + q_start = tl.load(cu_seqlens_q + pid_b) + q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start + k_start = tl.load(cu_seqlens_k + pid_b) + k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start + if BLOCK_SIZE_K * pid_k >= k_len: + return + # init pointers + k_ptrs = tl.make_block_ptr( + base=k_ptr + k_start * stride_kn + pid_kh * stride_kh, + shape=(k_len, HEAD_DIM), + strides=(stride_kn, stride_kd), + offsets=(pid_k * BLOCK_SIZE_K, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + dk_ptrs = tl.make_block_ptr( + base=dk_ptr + k_start * stride_dkn + pid_kh * stride_dkh + pid_sh * stride_dks, + shape=(k_len, HEAD_DIM), + strides=(stride_dkn, stride_dkd), + offsets=(pid_k * BLOCK_SIZE_K, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + v_ptrs = tl.make_block_ptr( + base=v_ptr + k_start * stride_vn + pid_kh * stride_vh, + shape=(k_len, HEAD_DIM), + strides=(stride_vn, stride_vd), + offsets=(pid_k * BLOCK_SIZE_K, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + dv_ptrs = tl.make_block_ptr( + base=dv_ptr + k_start * stride_dvn + pid_kh * stride_dvh + pid_sh * stride_dvs, + shape=(k_len, HEAD_DIM), + strides=(stride_dvn, stride_dvd), + offsets=(pid_k * BLOCK_SIZE_K, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + # offsets + off_q = tl.arange(0, BLOCK_SIZE_Q) + off_k = pid_k * BLOCK_SIZE_K * kernel_stride + tl.arange(0, BLOCK_SIZE_K) * kernel_stride + kernel_size - 1 + # load k v and keep in SRAM + k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero") + v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero") + # init dk dv + dk = 
tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32) + dv = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32) + q_lo = pid_k * BLOCK_SIZE_K * kernel_stride + kernel_size - 1 + q_ptrs = tl.make_block_ptr( + base=q_ptr + q_start * stride_qn + pid_h * stride_qh, + shape=(HEAD_DIM, q_len), + strides=(stride_qd, stride_qn), + offsets=(0, q_lo), + block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_Q), + order=(0, 1), + ) + do_ptrs = tl.make_block_ptr( + base=do_ptr + q_start * stride_don + pid_h * stride_doh, + shape=(HEAD_DIM, q_len), + strides=(stride_dod, stride_don), + offsets=(0, q_lo), + block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_Q), + order=(0, 1), + ) + d_ptrs = tl.make_block_ptr( + base=d_ptr + q_start * stride_dn + pid_h * stride_dh, + shape=(1, q_len), + strides=(0, stride_dn), + offsets=(0, q_lo), + block_shape=(1, BLOCK_SIZE_Q), + order=(1, 0), + ) + lse_ptrs = tl.make_block_ptr( + base=lse_ptr + q_start * stride_ln + pid_h * stride_lh, + shape=(1, q_len), + strides=(0, stride_ln), + offsets=(0, q_lo), + block_shape=(1, BLOCK_SIZE_Q), + order=(0, 1), + ) + # loop for q blocks + for i in range(q_lo, q_len, BLOCK_SIZE_Q): + # load + q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero") + do = tl.load(do_ptrs, boundary_check=(0, 1), padding_option="zero") + lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero") + d = tl.load(d_ptrs, boundary_check=(0, 1), padding_option="zero") + # compute qk + # [BLOCK_SIZE_K, HEAD_DIM] @ [HEAD_DIM, BLOCK_SIE_Q] -> [BLOCK_SIZE_K, BLOCK_SIE_Q] + qk = tl.where(off_k[:, None] <= (off_q + i)[None, :], float(0.0), float("-inf")) + qk += tl.dot(k, q) * qk_scale + # compute p, ds + # [BLOCK_SIZE_K, BLOCK_SIE_Q] - [1, BLOCK_SIZE_Q] -> [BLOCK_SIZE_K, BLOCK_SIE_Q] + p = tl.exp2(qk - lse) + # [BLOCK_SIZE_K, HEAD_DIM] @ [HEAD_DIM, BLOCK_SIE_Q] -> [BLOCK_SIZE_K, BLOCK_SIE_Q] + dp = tl.dot(v, do) + ds = sm_scale * p * (dp - d) + # cast dtype + p = p.to(do.dtype) + ds = ds.to(q.dtype) + # update dk and dv + # 
[BLOCK_SIZE_K, BLOCK_SIE_Q] @ [BLOCK_SIE_Q, HEAD_DIM] -> [BLOCK_SIZE_K, HEAD_DIM] + dk += tl.dot(ds, tl.trans(q)) + dv += tl.dot(p, tl.trans(do)) + # increment pointers + q_ptrs = tl.advance(q_ptrs, (0, BLOCK_SIZE_Q)) + do_ptrs = tl.advance(do_ptrs, (0, BLOCK_SIZE_Q)) + lse_ptrs = tl.advance(lse_ptrs, (0, BLOCK_SIZE_Q)) + d_ptrs = tl.advance(d_ptrs, (0, BLOCK_SIZE_Q)) + # save dk dv + tl.store(dk_ptrs, dk.to(dk_ptr.dtype.element_ty), boundary_check=(0, 1)) + tl.store(dv_ptrs, dv.to(dv_ptr.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def backward_dq( + q_ptr, # Q: n x qh x d + k_ptr, # K: n x kh x d + v_ptr, # V: n x kh x d + lse_ptr, # LSE: qh x n + d_ptr, # Delta: qh x n + do_ptr, + dq_ptr, + kernel_size, + kernel_stride, + # seqlens + cu_seqlens_q, + cu_seqlens_k, + # shape + NUM_KV_HEADS, + NUM_SHARE_Q_HEADS, + HEAD_DIM, + # sm_scale + sm_scale, + # stride + stride_qn, + stride_qh, + stride_qd, + stride_kn, + stride_kh, + stride_kd, + stride_vn, + stride_vh, + stride_vd, + stride_lh, + stride_ln, + stride_dh, + stride_dn, + stride_don, + stride_doh, + stride_dod, + stride_dqn, + stride_dqh, + stride_dqd, + # META parameters + BLOCK_SIZE_Q: tl.constexpr, # q block size + BLOCK_SIZE_K: tl.constexpr, # k block size + BLOCK_SIZE_D: tl.constexpr, +): + qk_scale = sm_scale * 1.44269504 + # get batch id and head id + pid_b = tl.program_id(0) + pid_h = tl.program_id(1) + pid_q = tl.program_id(2) + pid_kh = pid_h // NUM_SHARE_Q_HEADS + # get q k start and len after rmpad + q_start = tl.load(cu_seqlens_q + pid_b) + q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start + k_start = tl.load(cu_seqlens_k + pid_b) + k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start + # skip first kernel_size query block, because they do no attend to any keys + q_start_in_seq = pid_q * BLOCK_SIZE_Q + kernel_size - 1 + if q_start_in_seq >= q_len: + return + # init pointers + q_ptrs = tl.make_block_ptr( + base=q_ptr + q_start * stride_qn + pid_h * stride_qh, + shape=(q_len, 
HEAD_DIM), + strides=(stride_qn, stride_qd), + offsets=(q_start_in_seq, 0), + block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D), + order=(1, 0), + ) + dq_ptrs = tl.make_block_ptr( + base=dq_ptr + q_start * stride_dqn + pid_h * stride_dqh, + shape=(q_len, HEAD_DIM), + strides=(stride_dqn, stride_dqd), + offsets=(q_start_in_seq, 0), + block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D), + order=(1, 0), + ) + k_ptrs = tl.make_block_ptr( + base=k_ptr + k_start * stride_kn + pid_kh * stride_kh, + shape=(k_len, HEAD_DIM), + strides=(stride_kn, stride_kd), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + v_ptrs = tl.make_block_ptr( + base=v_ptr + k_start * stride_vn + pid_kh * stride_vh, + shape=(HEAD_DIM, k_len), + strides=(stride_vd, stride_vn), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K), + order=(0, 1), + ) + do_ptrs = tl.make_block_ptr( + base=do_ptr + q_start * stride_don + pid_h * stride_doh, + shape=(q_len, HEAD_DIM), + strides=(stride_don, stride_dod), + offsets=(q_start_in_seq, 0), + block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D), + order=(1, 0), + ) + d_ptrs = tl.make_block_ptr( + base=d_ptr + q_start * stride_dn + pid_h * stride_dh, + shape=(q_len, 1), + strides=(stride_dn, stride_dh), + offsets=(q_start_in_seq, 0), + block_shape=(BLOCK_SIZE_Q, 1), + order=(0, 1), + ) + lse_ptrs = tl.make_block_ptr( + base=lse_ptr + q_start * stride_ln + pid_h * stride_lh, + shape=(q_len, 1), + strides=(stride_ln, stride_lh), + offsets=(q_start_in_seq, 0), + block_shape=(BLOCK_SIZE_Q, 1), + order=(0, 1), + ) + # offsets + off_q = tl.arange(0, BLOCK_SIZE_Q) + q_start_in_seq + off_k = tl.arange(0, BLOCK_SIZE_K) * kernel_stride + kernel_size - 1 + # load q, do, lse, delta, and keep in SRAM + q = tl.load(q_ptrs, boundary_check=(1, 0), padding_option="zero") + do = tl.load(do_ptrs, boundary_check=(0, 1), padding_option="zero") + lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero") + d = tl.load(d_ptrs, boundary_check=(0, 1), 
padding_option="zero") + # init dq + dq = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_D), dtype=tl.float32) + lo = 0 + hi = min(k_len, (q_start_in_seq + BLOCK_SIZE_Q - kernel_size) // kernel_stride + 1) + for i in range(lo, hi, BLOCK_SIZE_K): + # load + k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero") + v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero") + # compute qk + qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32) + qk += tl.where(off_q[:, None] >= (i * kernel_stride + off_k)[None, :], 0, float("-inf")) + qk += tl.dot(q, tl.trans(k)) * qk_scale + # compute p, ds + p = tl.exp2(qk - lse) + dp = tl.dot(do, v) + ds = sm_scale * p * (dp - d) + # cast dtype + ds = ds.to(q.dtype) + # update dq + dq += tl.dot(ds, k) + # increment pointers + k_ptrs = tl.advance(k_ptrs, (BLOCK_SIZE_K, 0)) + v_ptrs = tl.advance(v_ptrs, (0, BLOCK_SIZE_K)) + # save dq + tl.store(dq_ptrs, dq.to(dq_ptr.dtype.element_ty), boundary_check=(0, 1)) + + +def _compressed_attention_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + kernel_size: int, + kernel_stride: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: torch.Tensor, + max_seqlen_k: torch.Tensor, + sm_scale: float, +): + # dtype check + assert k.dtype == q.dtype and v.dtype == q.dtype + assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32 + # shape + q_len, num_q_heads, head_dim = q.shape + k_len, num_k_heads, head_dim = k.shape + v_len, num_v_heads, head_dim = v.shape + batch_size = cu_seqlens_q.shape[0] - 1 + assert k_len == v_len and q_len > k_len + # gqa + assert num_k_heads == num_v_heads + assert num_q_heads % num_k_heads == 0 + num_share_q_heads = num_q_heads // num_k_heads + # output tensor + o = torch.zeros_like(q) + lse = torch.full( + (num_q_heads, q_len), + fill_value=-torch.inf, + dtype=torch.float32, + device=q.device, + ) + # launch kernel + grid = lambda META: ( + batch_size, + num_q_heads, + triton.cdiv(max_seqlen_q, 
META["BLOCK_SIZE_Q"]), + ) + BLOCK_SIZE_Q = 128 + BLOCK_SIZE_K = 128 + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_Q, IS_HOPPER_GPU) + forward_kernel[grid]( + q, + k, + v, + o, + lse, + kernel_size, + kernel_stride, + cu_seqlens_q, + cu_seqlens_k, + num_k_heads, + num_share_q_heads, + head_dim, + sm_scale, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + lse.stride(0), + lse.stride(1), + BLOCK_SIZE_Q=BLOCK_SIZE_Q, + BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_D=BLOCK_SIZE_D, + num_warps=num_warps, + num_stages=num_stages, + ) + return o, lse + + +def _compressed_attention_bwd( + o: torch.Tensor, + do: torch.Tensor, + lse: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + kernel_size: int, + kernel_stride: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: torch.Tensor, + max_seqlen_k: torch.Tensor, + sm_scale: float, +): + q_len, num_q_heads, head_dim = q.shape + k_len, num_k_heads, head_dim = k.shape + v_len, num_v_heads, head_dim = v.shape + o_len, num_o_heads, head_dim = o.shape + num_share_q_heads = num_q_heads // num_k_heads + # compute D + delta = torch.zeros([num_o_heads, o_len], device=o.device, dtype=torch.float32) + grid = lambda META: (triton.cdiv(o_len, META["BLOCK_SIZE_O"]), num_o_heads) + BLOCK_SIZE_O = 256 + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_O, IS_HOPPER_GPU) + backward_sum_o_do[grid]( + o, + do, + delta, + o_len, + head_dim, + o.stride(0), + o.stride(1), + o.stride(2), + do.stride(0), + do.stride(1), + do.stride(2), + delta.stride(0), + delta.stride(1), + BLOCK_SIZE_O=BLOCK_SIZE_O, + BLOCK_SIZE_D=BLOCK_SIZE_D, + num_warps=num_warps, + num_stages=num_stages, + ) + # compute dk dv + dk = torch.zeros(num_share_q_heads, 
k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype) + dv = torch.zeros(num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype) + batch_size = cu_seqlens_q.shape[0] - 1 + grid = lambda META: ( + batch_size, + num_q_heads, + triton.cdiv(max_seqlen_k, META["BLOCK_SIZE_K"]), + ) + BLOCK_SIZE_Q = 64 + BLOCK_SIZE_K = 128 + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_K, IS_HOPPER_GPU) + backward_dkdv[grid]( + q, + k, + v, + lse, + delta, + do, + dk, + dv, + kernel_size, + kernel_stride, + cu_seqlens_q, + cu_seqlens_k, + num_k_heads, + num_share_q_heads, + head_dim, + sm_scale, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + lse.stride(0), + lse.stride(1), + delta.stride(0), + delta.stride(1), + do.stride(0), + do.stride(1), + do.stride(2), + dk.stride(0), + dk.stride(1), + dk.stride(2), + dk.stride(3), + dv.stride(0), + dv.stride(1), + dv.stride(2), + dv.stride(3), + BLOCK_SIZE_Q=BLOCK_SIZE_Q, + BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_D=BLOCK_SIZE_D, + num_warps=num_warps, + num_stages=num_stages, + ) + dk = dk.sum(0) + dv = dv.sum(0) + # compute dq + dq = torch.zeros_like(q) + grid = lambda META: ( + batch_size, + num_q_heads, + triton.cdiv(max_seqlen_q, META["BLOCK_SIZE_Q"]), + ) + BLOCK_SIZE_Q = 128 + BLOCK_SIZE_K = 64 + num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_Q, IS_HOPPER_GPU) + backward_dq[grid]( + q, + k, + v, + lse, + delta, + do, + dq, + kernel_size, + kernel_stride, + cu_seqlens_q, + cu_seqlens_k, + num_k_heads, + num_share_q_heads, + head_dim, + sm_scale, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + lse.stride(0), + lse.stride(1), + delta.stride(0), + delta.stride(1), + do.stride(0), + do.stride(1), + do.stride(2), + dq.stride(0), + dq.stride(1), + 
dq.stride(2), + BLOCK_SIZE_Q=BLOCK_SIZE_Q, + BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_D=BLOCK_SIZE_D, + num_warps=num_warps, + num_stages=num_stages, + ) + return dq, dk, dv + + +class CompressedAttention(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + kernel_size: int, + kernel_stride: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: torch.Tensor, + max_seqlen_k: torch.Tensor, + sm_scale=None, + ): + # dtype check + assert q.dtype == torch.bfloat16 or q.dtype == torch.float16 + assert q.dtype == k.dtype and k.dtype == v.dtype + assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32 + # softmax scale + if sm_scale is None: + sm_scale = 1 / math.sqrt(q.shape[-1]) + + o, lse = _compressed_attention_fwd( + q, + k, + v, + kernel_size, + kernel_stride, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + sm_scale, + ) + ctx.save_for_backward(q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k) + ctx.sm_scale = sm_scale + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_k = max_seqlen_k + ctx.kernel_size = kernel_size + ctx.kernel_stride = kernel_stride + return o, lse + + @staticmethod + def backward(ctx, do: torch.Tensor, *args) -> Any: + q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors + max_seqlen_q = ctx.max_seqlen_q + max_seqlen_k = ctx.max_seqlen_k + sm_scale = ctx.sm_scale + kernel_size = ctx.kernel_size + kernel_stride = ctx.kernel_stride + + dq, dk, dv = _compressed_attention_bwd( + o, + do, + lse, + q, + k, + v, + kernel_size, + kernel_stride, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + sm_scale, + ) + return dq, dk, dv, None, None, None, None, None, None, None + + +@triton.jit +def score_kernel( + q_ptr, + k_ptr, + lse_ptr, + s_ptr, + kernel_size, + kernel_stride, + # seqlens + cu_seqlens_q, + cu_seqlens_k, + # shape + NUM_KV_HEADS, + NUM_SHARE_Q_HEADS, + HEAD_DIM, + # sm_scale + sm_scale, 
+ # stride + stride_qn, + stride_qh, + stride_qd, + stride_kn, + stride_kh, + stride_kd, + stride_lh, + stride_ln, + stride_sh, + stride_sq, + stride_sk, + # META parameters + BLOCK_SIZE_Q: tl.constexpr, # q block size + BLOCK_SIZE_K: tl.constexpr, # k block size + BLOCK_SIZE_D: tl.constexpr, +): + qk_scale = sm_scale * 1.44269504 + # get batch id and head id + pid_bkh = tl.program_id(0) + pid_b = pid_bkh // NUM_KV_HEADS + pid_kh = pid_bkh % NUM_KV_HEADS + pid_q = tl.program_id(1) + pid_k = tl.program_id(2) + # get q k start and len after rmpad + q_start = tl.load(cu_seqlens_q + pid_b) + q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start + k_start = tl.load(cu_seqlens_k + pid_b) + k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start + if pid_q * BLOCK_SIZE_Q >= q_len or pid_k * BLOCK_SIZE_K >= k_len: + return + # init k pointer and load k + k_ptrs = tl.make_block_ptr( + base=k_ptr + k_start * stride_kn + pid_kh * stride_kh, + shape=(HEAD_DIM, k_len), + strides=(stride_kd, stride_kn), + offsets=(0, pid_k * BLOCK_SIZE_K), + block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K), + order=(0, 1), + ) + k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero") + # offsets + off_q = tl.arange(0, BLOCK_SIZE_Q) + pid_q * BLOCK_SIZE_Q + off_k = tl.arange(0, BLOCK_SIZE_K) + pid_k * BLOCK_SIZE_K + causal_mask = off_q[:, None] >= (off_k * kernel_stride + kernel_size - 1)[None, :] + # init score + s = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32) + + q_ptrs = tl.make_block_ptr( + base=q_ptr + q_start * stride_qn + pid_kh * stride_qh, + shape=(q_len, HEAD_DIM), + strides=(stride_qn, stride_qd), + offsets=(pid_q * BLOCK_SIZE_Q, 0), + block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D), + order=(1, 0), + ) + lse_ptrs = tl.make_block_ptr( + base=lse_ptr + q_start * stride_ln + pid_kh * stride_lh, + shape=(q_len, 1), + strides=(stride_ln, stride_lh), + offsets=(pid_q * BLOCK_SIZE_Q, 0), + block_shape=(BLOCK_SIZE_Q, 1), + order=(0, 1), + ) + # load q and lse + q = tl.load(q_ptrs, 
boundary_check=(0, 1), padding_option="zero") + lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero") + # compute qk + qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32) + qk += tl.dot(q, k) * qk_scale + # compute score + s += tl.where(causal_mask, tl.exp2(qk - lse), 0) + # save output + s_ptrs = tl.make_block_ptr( + base=s_ptr + pid_kh * stride_sh + q_start * stride_sq, + shape=(q_len, k_len), + strides=(stride_sq, stride_sk), + offsets=(pid_q * BLOCK_SIZE_Q, pid_k * BLOCK_SIZE_K), + block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_K), + order=(1, 0), + ) + tl.store(s_ptrs, s.to(s_ptr.dtype.element_ty), boundary_check=(0, 1)) + + +def _get_attention_score( + q: torch.Tensor, # [total_query_len, num_q_heads, head_dim] + k: torch.Tensor, # [total_key_len, num_k_heads, head_dim] + lse: torch.Tensor, # [num_q_heads, total_query_len] + kernel_size: int, + kernel_stride: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + sm_scale: float, +) -> torch.Tensor: + # dtype check + assert q.dtype == torch.bfloat16 or q.dtype == torch.float16 + assert q.dtype == k.dtype + assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32 + assert lse.dtype == torch.float32 # lse here is log2(sum(exp(qk*scale))), not log(sum(exp(qk*scale))) + # shape + q_len, num_q_heads, head_dim = q.shape + k_len, num_k_heads, head_dim = k.shape + batch_size = cu_seqlens_q.shape[0] - 1 + assert q_len > k_len + if sm_scale is None: + sm_scale = 1 / math.sqrt(head_dim) + # gqa + assert num_q_heads % num_k_heads == 0 + num_share_q_heads = num_q_heads // num_k_heads + # init score + score = torch.zeros(num_k_heads, q_len, max_seqlen_k, dtype=torch.float32, device=q.device) + + # launch kernel + grid = lambda META: ( + batch_size * num_k_heads, + triton.cdiv(max_seqlen_q, META["BLOCK_SIZE_Q"]), + triton.cdiv(max_seqlen_k, META["BLOCK_SIZE_K"]), + ) + BLOCK_SIZE_Q = 128 + BLOCK_SIZE_K = 128 + BLOCK_SIZE_D = 
triton.next_power_of_2(head_dim) + + score_kernel[grid]( + q, + k, + lse, + score, + kernel_size, + kernel_stride, + cu_seqlens_q, + cu_seqlens_k, + num_k_heads, + num_share_q_heads, + head_dim, + sm_scale, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + lse.stride(0), + lse.stride(1), + score.stride(0), + score.stride(1), + score.stride(2), + BLOCK_SIZE_Q=BLOCK_SIZE_Q, + BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_D=BLOCK_SIZE_D, + num_warps=8, + num_stages=3, + ) + return score + + +@triton.jit +def _transform_score_kernel( + s_ptr, # score, shape: [num_heads, q_len, k_len] + bs_ptr, # block wise score: [num_heads, q_len, num_k_block] + offs, + cu_seqlens_q, + # shape + num_heads, + num_offs, + max_k_len, + max_blocks, + pad_len, + # kernel & block size + block_size, + block_stride, # block_size // kernel_stride + init_blocks, + local_blocks, + # stride + stride_sh, + stride_sq, + stride_sk, + stride_bsh, + stride_bsq, + stride_bsk, + TOTAL_QUERY_LEN: tl.constexpr, + BLOCK_SIZE_Q: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_O: tl.constexpr, +): + pid_bh = tl.program_id(0) + pid_b = pid_bh // num_heads + pid_h = pid_bh % num_heads + pid_q = tl.program_id(1) + pid_k = tl.program_id(2) + q_start = tl.load(cu_seqlens_q + pid_b) + q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start + k_start = pid_k * BLOCK_SIZE_K + if pid_q * BLOCK_SIZE_Q >= q_len: + return + # load weight + off_o = tl.arange(0, BLOCK_SIZE_O) + w = tl.load(offs + off_o, mask=off_o < num_offs, other=0) + # load score + off_q = pid_q * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q) + off_k = (k_start + tl.arange(0, BLOCK_SIZE_K)) * block_stride - pad_len + off_k = off_k[None, :] + off_o[:, None] + s_ptrs = ( + s_ptr + + q_start * stride_sq + + pid_h * stride_sh + + off_q[:, None, None] * stride_sq + + off_k[None, :, :] * stride_sk + ) + # weighted sum, [BQ, BO, BK] * [1, BO, 1] -> [BQ, BO, BK] -> [BQ, BK] + s = tl.load( + s_ptrs, + mask=(off_q < q_len)[:, 
None, None] & (off_k >= 0) & (off_k < max_k_len), + other=0, + ) + s = s * w[None, :, None] + s = tl.sum(s, axis=1) + # init mask and local mask + off_bq = off_q // block_size + off_bk = k_start + tl.arange(0, BLOCK_SIZE_K) + s = tl.where( + ((off_bq[:, None] >= off_bk[None, :]) & (off_bq[:, None] < off_bk[None, :] + local_blocks)) + | (off_bk[None, :] < init_blocks - k_start), + float("inf"), + s, + ) + # store block wise score + bs_ptrs = ( + bs_ptr + q_start * stride_bsq + pid_h * stride_bsh + off_q[:, None] * stride_bsq + off_bk[None, :] * stride_bsk + ) + tl.store( + bs_ptrs, + s, + mask=(off_q < q_len)[:, None] & (off_bk < max_blocks)[None, :], + ) + + +def transform_score( + score: torch.Tensor, + kernel_size: int, + kernel_stride: int, + block_size: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + init_blocks: int = 1, + local_blocks: int = 2, +) -> torch.Tensor: + num_k_heads, total_query_len, max_key_len = score.shape + batch_size = cu_seqlens_q.shape[0] - 1 + pad_len = kernel_size // kernel_stride - 1 + max_blocks = math.ceil(max_seqlen_q / block_size) + block_score = torch.zeros( + num_k_heads, + total_query_len, + max_blocks, + dtype=torch.float32, + device=score.device, + ) + offs = ( + torch.arange(kernel_size // kernel_stride, device=score.device)[:, None] + + torch.arange(block_size // kernel_stride, device=score.device)[None, :] + ).view(-1) + + offs = torch.histc(offs, bins=offs.max() + 1, min=0, max=offs.max()) + + num_offs = int(offs.shape[0]) + + BLOCK_SIZE_Q = 16 + BLOCK_SIZE_K = min(128, triton.next_power_of_2(max_blocks)) + BLOCK_SIZE_O = triton.next_power_of_2(num_offs) + + def grid(meta): + grid = ( + num_k_heads * batch_size, + triton.cdiv(total_query_len, BLOCK_SIZE_Q), + triton.cdiv(max_blocks, BLOCK_SIZE_K), + ) + return grid + + _transform_score_kernel[grid]( + score, + block_score, + offs, + cu_seqlens_q, + num_k_heads, + offs.shape[0], + max_key_len, + max_blocks, + 
pad_len, + block_size, + block_size // kernel_stride, + init_blocks, + local_blocks, + score.stride(0), + score.stride(1), + score.stride(2), + block_score.stride(0), + block_score.stride(1), + block_score.stride(2), + TOTAL_QUERY_LEN=total_query_len, + BLOCK_SIZE_Q=BLOCK_SIZE_Q, + BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_O=BLOCK_SIZE_O, + num_warps=4, + num_stages=3, + ) + return block_score + + +def compressed_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + kernel_size: int, + kernel_stride: int, + block_size: int, + topk: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + sm_scale: float = None, + init_blocks: int = 1, + local_blocks: int = 2, + parallel_topk_compute: Union[str, bool] = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Attention between query and compressed key and value. Compute attention output and topk block idx used in topk_sparse_attention. + + Args: + q (torch.Tensor): shape [total_q_len, num_q_heads, head_dim] + k (torch.Tensor): shape [total_kv_len, num_kv_heads, head_dim] + v (torch.Tensor): shape [total_kv_len, num_kv_heads, head_dim] + kernel_size (int): kernel size in compress_key_value + kernel_stride (int): stride of compress_key_value + block_size (int): key value block size for topk sparse attention. + topk (int): number of blocks for each query. + cu_seqlens_q (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_func_varlen. + cu_seqlens_k (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_k in flash_attn_func_varlen. + max_seqlen_q (int): max q len of the batch. + max_seqlen_k (int): max k len of the batch. + sm_scale (float, optional): softmax scale. Defaults to None, means 1/sqrt(head_dim). + init_blocks (int, optional): Number of init blocks for each query. Defaults to 1. + local_blocks (int, optional): Number of local blocks for each query. Defaults to 2. 
+ parallel_topk_compute (str, optional): Only set it to False when the sequence length is too long. This can avoid a current bug. + We'll fix this issue later. Defaults to auto, it will be set to False when the sequence length is greater than 32k and True otherwise. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: attention output and topk_idx used in topk_sparse_attention + """ + + if max_seqlen_q is None: + max_seqlen_q = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).max().item() + if max_seqlen_k is None: + max_seqlen_k = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).max().item() + + attn_output, lse = CompressedAttention.apply( + q, + k, + v, + kernel_size, + kernel_stride, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + sm_scale, + ) + + # do not select topk index + if topk <= 0: + warnings.warn("topk <= 0, returned topk_idx will be None") + return attn_output, None + + assert topk >= init_blocks + local_blocks + with torch.no_grad(): + num_k_heads, num_q_heads = k.shape[1], q.shape[1] + num_shared_q_heads = num_q_heads // num_k_heads + batch_size = cu_seqlens_q.shape[0] - 1 + q_idx = torch.cat( + [torch.arange(cu_seqlens_q[i + 1] - cu_seqlens_q[i], device=q.device) for i in range(batch_size)], + dim=0, + ) + q_idx = q_idx // block_size + + # whether to use parallel version + if parallel_topk_compute == "auto": + parallel_topk_compute = cu_seqlens_q[-1] <= 32768 + # parallel version + if parallel_topk_compute: + # recompute score + score = _get_attention_score( + q, + k, + lse, + kernel_size, + kernel_stride, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + sm_scale, + ) + # transform score to block-wise score + score = transform_score( + score, + kernel_size, + kernel_stride, + block_size, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + init_blocks, + local_blocks, + ) + # get topk + topk = min(topk, score.shape[-1]) + topk_idx = score.topk(topk, dim=-1).indices.sort(-1).values + topk_idx[topk_idx > q_idx[None, :, 
None]] = -1 + topk_idx = topk_idx.to(torch.int32) + # non parallel version, avoid some current bugs when sequence length is too long + # FIXME: need to fix later + else: + topk_idx_list = [] + head_tile = 1 + assert num_k_heads % head_tile == 0, f"Num kv heads: {num_k_heads}, head_tile: {head_tile}" + for h in range(num_k_heads // head_tile): + # recompute score + score = _get_attention_score( + q[:, h * num_shared_q_heads * head_tile: (h + 1) * num_shared_q_heads * head_tile], + k[:, h * head_tile: (h + 1) * head_tile], + lse[h * num_shared_q_heads * head_tile: (h + 1) * num_shared_q_heads * head_tile], + kernel_size, + kernel_stride, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + sm_scale, + ) + # transform score to block-wise score + score = transform_score( + score, + kernel_size, + kernel_stride, + block_size, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + init_blocks, + local_blocks, + ) + # get topk + topk = min(topk, score.shape[-1]) + if score.dtype == torch.float32: + score = score.to(torch.bfloat16) + topk_idx = score.topk(topk, dim=-1, sorted=False).indices + topk_idx = topk_idx.sort(-1).values + + topk_idx[topk_idx > q_idx[None, :, None]] = -1 + topk_idx = topk_idx.to(torch.int32) + topk_idx_list.append(topk_idx) + topk_idx = torch.cat(topk_idx_list, dim=0) + + return attn_output, topk_idx diff --git a/vortex_torch/kernels/nsa/flash_attention.py b/vortex_torch/kernels/nsa/flash_attention.py new file mode 100644 index 0000000..c556a4c --- /dev/null +++ b/vortex_torch/kernels/nsa/flash_attention.py @@ -0,0 +1,886 @@ +# Copyright 2025 Xunhao Lai & Jianqiao Lu. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from typing import Any, Optional + +import torch +import triton +import triton.language as tl + +from .utils import get_num_warps_stages, is_hopper_gpu + +IS_HOPPER_GPU = is_hopper_gpu() + + +@triton.jit +def forward_kernel( + q_ptr, # Q: n x h x d + k_ptr, # K: n x h x d + v_ptr, # V: n x h x d + o_ptr, # O: n x h x d + lse_ptr, # LSE: h x n + # seqlens + cu_seqlens_q, + cu_seqlens_k, + # shape + NUM_KV_HEADS, + NUM_SHARE_Q_HEADS, + HEAD_DIM, + # sm_scale + sm_scale, + # causal + causal, + # gqa + gqa_interleave, + # stride + stride_qn, + stride_qh, + stride_qd, + stride_kn, + stride_kh, + stride_kd, + stride_vn, + stride_vh, + stride_vd, + stride_on, + stride_oh, + stride_od, + stride_lh, + stride_ln, + # META parameters + BLOCK_SIZE_Q: tl.constexpr, # q block size + BLOCK_SIZE_K: tl.constexpr, # k block size + BLOCK_SIZE_D: tl.constexpr, +): + qk_scale = sm_scale * 1.44269504 + # get batch id and head id + pid_b = tl.program_id(0) + pid_h = tl.program_id(1) + pid_q = tl.program_id(2) + if gqa_interleave: + pid_kh = pid_h % NUM_KV_HEADS + else: + pid_kh = pid_h // NUM_SHARE_Q_HEADS + # get q k start and len after rmpad + q_start = tl.load(cu_seqlens_q + pid_b) + q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start + k_start = tl.load(cu_seqlens_k + pid_b) + k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start + if BLOCK_SIZE_Q * pid_q >= q_len: + return + # init qkv pointer + q_ptrs = tl.make_block_ptr( + base=q_ptr + q_start * stride_qn + pid_h * stride_qh, + shape=(q_len, HEAD_DIM), + strides=(stride_qn, stride_qd), + offsets=(pid_q * 
BLOCK_SIZE_Q, 0), + block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D), + order=(1, 0), + ) + k_ptrs = tl.make_block_ptr( + base=k_ptr + k_start * stride_kn + pid_kh * stride_kh, + shape=(HEAD_DIM, k_len), + strides=(stride_kd, stride_kn), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K), + order=(0, 1), + ) + v_ptrs = tl.make_block_ptr( + base=v_ptr + k_start * stride_vn + pid_kh * stride_vh, + shape=(k_len, HEAD_DIM), + strides=(stride_vn, stride_vd), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + # load q + q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero") + # init statistics + off_q = tl.arange(0, BLOCK_SIZE_Q) + pid_q * BLOCK_SIZE_Q + off_k = tl.arange(0, BLOCK_SIZE_K) + m_i = tl.full((BLOCK_SIZE_Q,), float("-inf"), dtype=tl.float32) + lse_i = tl.full((BLOCK_SIZE_Q,), float("-inf"), dtype=tl.float32) + acc_o = tl.full((BLOCK_SIZE_Q, BLOCK_SIZE_D), 0, dtype=tl.float32) + # full attention or causal attention + lo = 0 + if causal: + hi = min(k_len, (pid_q + 1) * BLOCK_SIZE_Q) + else: + hi = k_len + for i in range(lo, hi, BLOCK_SIZE_K): + i = tl.multiple_of(i, BLOCK_SIZE_K) + # load k + k = tl.load(k_ptrs, boundary_check=(1, 0), padding_option="zero") + # compute qk + qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32) + if causal: + qk += tl.where(off_q[:, None] >= (i + off_k)[None, :], 0, float("-inf")) + else: + qk += tl.where((off_k < k_len - i)[None, :], 0, float("-inf")) + qk += tl.dot(q, k) * qk_scale + # compute m_ij and l_ij + m_ij = tl.maximum(m_i, tl.max(qk, axis=1)) + p = tl.math.exp2(qk - m_ij[:, None]) + l_ij = tl.sum(p, axis=1) + # scale acc_o + acc_o_scale = tl.math.exp2(m_i - m_ij) + acc_o = acc_o * acc_o_scale[:, None] + # load v and update acc_o + v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero") + p = p.to(v.dtype) + acc_o += tl.dot(p, v) + # update statistics + m_i = m_ij + lse_i = m_ij + tl.math.log2(tl.math.exp2(lse_i - m_ij) + l_ij) + # update ptrs + k_ptrs = 
tl.advance(k_ptrs, (0, BLOCK_SIZE_K)) + v_ptrs = tl.advance(v_ptrs, (BLOCK_SIZE_K, 0)) + # final scale + acc_o = acc_o * tl.math.exp2(m_i - lse_i)[:, None] + # save output + o_ptrs = tl.make_block_ptr( + base=o_ptr + q_start * stride_on + pid_h * stride_oh, + shape=(q_len, HEAD_DIM), + strides=(stride_on, stride_od), + offsets=(pid_q * BLOCK_SIZE_Q, 0), + block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D), + order=(1, 0), + ) + tl.store(o_ptrs, acc_o.to(o_ptr.dtype.element_ty), boundary_check=(0, 1)) + # save lse + l_ptrs = lse_ptr + q_start * stride_ln + pid_h * stride_lh + off_q * stride_ln + tl.store(l_ptrs, lse_i, mask=off_q < q_len) + + +@triton.jit +def backward_sum_o_do( + o_ptr, # O: n x h x d + do_ptr, # dO: n x h x d + delta_ptr, # D: h x n + o_len, + HEAD_DIM, + stride_on, + stride_oh, + stride_od, + stride_don, + stride_doh, + stride_dod, + stride_dh, + stride_dn, + BLOCK_SIZE_O: tl.constexpr, + BLOCK_SIZE_D: tl.constexpr, +): + pid_n = tl.program_id(0) + pid_h = tl.program_id(1) + off_n = pid_n * BLOCK_SIZE_O + tl.arange(0, BLOCK_SIZE_O) + off_d = tl.arange(0, BLOCK_SIZE_D) + o = tl.load( + o_ptr + off_n[:, None] * stride_on + pid_h * stride_oh + off_d[None, :] * stride_od, + mask=(off_n[:, None] < o_len) & (off_d[None, :] < HEAD_DIM), + other=0, + ).to(tl.float32) + do = tl.load( + do_ptr + off_n[:, None] * stride_don + pid_h * stride_doh + off_d[None, :] * stride_dod, + mask=(off_n[:, None] < o_len) & (off_d[None, :] < HEAD_DIM), + other=0, + ).to(tl.float32) + delta = tl.sum(o * do, axis=1) + tl.store(delta_ptr + pid_h * stride_dh + off_n * stride_dn, delta, mask=off_n < o_len) + + +@triton.jit +def backward_dkdv( + q_ptr, # Q: n x qh x d + k_ptr, # K: n x kh x d + v_ptr, # V: n x kh x d + lse_ptr, # LSE: qh x n + d_ptr, # Delta: qh x n + do_ptr, + dk_ptr, # DK: sh x n x kh x d + dv_ptr, # DV: sh x n x kh x d + # seqlens + cu_seqlens_q, + cu_seqlens_k, + # shape + NUM_KV_HEADS, + NUM_SHARE_Q_HEADS, + HEAD_DIM, + # sm_scale + sm_scale, + # causal + causal, + # 
gqa + gqa_interleave, + # stride + stride_qn, + stride_qh, + stride_qd, + stride_kn, + stride_kh, + stride_kd, + stride_vn, + stride_vh, + stride_vd, + stride_lh, + stride_ln, + stride_dh, + stride_dn, + stride_don, + stride_doh, + stride_dod, + stride_dks, + stride_dkn, + stride_dkh, + stride_dkd, + stride_dvs, + stride_dvn, + stride_dvh, + stride_dvd, + # META parameters + BLOCK_SIZE_Q: tl.constexpr, # q block size + BLOCK_SIZE_K: tl.constexpr, # k block size + BLOCK_SIZE_D: tl.constexpr, +): + qk_scale = sm_scale * 1.44269504 + # get batch id and head id + pid_b = tl.program_id(0) + pid_h = tl.program_id(1) + if gqa_interleave: + pid_kh = pid_h % NUM_SHARE_Q_HEADS + pid_sh = pid_h // NUM_SHARE_Q_HEADS + else: + pid_kh = pid_h // NUM_SHARE_Q_HEADS + pid_sh = pid_h % NUM_SHARE_Q_HEADS + pid_k = tl.program_id(2) + # get q k start and len after rmpad + q_start = tl.load(cu_seqlens_q + pid_b) + q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start + k_start = tl.load(cu_seqlens_k + pid_b) + k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start + if BLOCK_SIZE_K * pid_k >= k_len: + return + # init pointers + k_ptrs = tl.make_block_ptr( + base=k_ptr + k_start * stride_kn + pid_kh * stride_kh, + shape=(k_len, HEAD_DIM), + strides=(stride_kn, stride_kd), + offsets=(pid_k * BLOCK_SIZE_K, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + dk_ptrs = tl.make_block_ptr( + base=dk_ptr + k_start * stride_dkn + pid_kh * stride_dkh + pid_sh * stride_dks, + shape=(k_len, HEAD_DIM), + strides=(stride_dkn, stride_dkd), + offsets=(pid_k * BLOCK_SIZE_K, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + v_ptrs = tl.make_block_ptr( + base=v_ptr + k_start * stride_vn + pid_kh * stride_vh, + shape=(k_len, HEAD_DIM), + strides=(stride_vn, stride_vd), + offsets=(pid_k * BLOCK_SIZE_K, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + dv_ptrs = tl.make_block_ptr( + base=dv_ptr + k_start * stride_dvn + pid_kh * stride_dvh + pid_sh * 
stride_dvs, + shape=(k_len, HEAD_DIM), + strides=(stride_dvn, stride_dvd), + offsets=(pid_k * BLOCK_SIZE_K, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + # offsets + off_q = tl.arange(0, BLOCK_SIZE_Q) + off_k = tl.arange(0, BLOCK_SIZE_K) + pid_k * BLOCK_SIZE_K + # load k v and keep in SRAM + k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero") + v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero") + # init dk dv + dk = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32) + dv = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32) + # causal + if causal: + q_lo = pid_k * BLOCK_SIZE_K + else: + q_lo = 0 + q_ptrs = tl.make_block_ptr( + base=q_ptr + q_start * stride_qn + pid_h * stride_qh, + shape=(q_len, HEAD_DIM), + strides=(stride_qn, stride_qd), + offsets=(q_lo, 0), + block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D), + order=(1, 0), + ) + do_ptrs = tl.make_block_ptr( + base=do_ptr + q_start * stride_don + pid_h * stride_doh, + shape=(q_len, HEAD_DIM), + strides=(stride_don, stride_dod), + offsets=(q_lo, 0), + block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D), + order=(1, 0), + ) + d_ptrs = tl.make_block_ptr( + base=d_ptr + q_start * stride_dn + pid_h * stride_dh, + shape=(q_len, 1), + strides=(stride_dn, stride_dh), + offsets=(q_lo, 0), + block_shape=(BLOCK_SIZE_Q, 1), + order=(0, 1), + ) + lse_ptrs = tl.make_block_ptr( + base=lse_ptr + q_start * stride_ln + pid_h * stride_lh, + shape=(q_len, 1), + strides=(stride_ln, stride_lh), + offsets=(q_lo, 0), + block_shape=(BLOCK_SIZE_Q, 1), + order=(0, 1), + ) + # loop for q blocks + for i in range(q_lo, q_len, BLOCK_SIZE_Q): + # load + q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero") + do = tl.load(do_ptrs, boundary_check=(0, 1), padding_option="zero") + lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero") + d = tl.load(d_ptrs, boundary_check=(0, 1), padding_option="zero") + # compute qk + if causal: + qk = tl.where((off_q + i)[:, None] >= 
off_k[None, :], float(0.0), float("-inf")) + else: + qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32) + qk += tl.dot(q, k.T) * qk_scale + # compute p, ds + p = tl.math.exp2(qk - lse) + dp = tl.dot(do, v.T) + ds = sm_scale * p * (dp - d) + # cast dtype + p = p.to(do.dtype) + ds = ds.to(q.dtype) + # update dk and dv + dk += tl.dot(ds.T, q) + dv += tl.dot(p.T, do) + # increment pointers + q_ptrs = tl.advance(q_ptrs, (BLOCK_SIZE_Q, 0)) + do_ptrs = tl.advance(do_ptrs, (BLOCK_SIZE_Q, 0)) + lse_ptrs = tl.advance(lse_ptrs, (BLOCK_SIZE_Q, 0)) + d_ptrs = tl.advance(d_ptrs, (BLOCK_SIZE_Q, 0)) + # save dk dv + tl.store(dk_ptrs, dk.to(dk_ptr.dtype.element_ty), boundary_check=(0, 1)) + tl.store(dv_ptrs, dv.to(dv_ptr.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def backward_dq( + q_ptr, # Q: n x qh x d + k_ptr, # K: n x kh x d + v_ptr, # V: n x kh x d + lse_ptr, # LSE: qh x n + d_ptr, # Delta: qh x n + do_ptr, + dq_ptr, + # seqlens + cu_seqlens_q, + cu_seqlens_k, + # shape + NUM_KV_HEADS, + NUM_SHARE_Q_HEADS, + HEAD_DIM, + # sm_scale + sm_scale, + # causal + causal, + # gqa + gqa_interleave, + # stride + stride_qn, + stride_qh, + stride_qd, + stride_kn, + stride_kh, + stride_kd, + stride_vn, + stride_vh, + stride_vd, + stride_lh, + stride_ln, + stride_dh, + stride_dn, + stride_don, + stride_doh, + stride_dod, + stride_dqn, + stride_dqh, + stride_dqd, + # META parameters + BLOCK_SIZE_Q: tl.constexpr, # q block size + BLOCK_SIZE_K: tl.constexpr, # k block size + BLOCK_SIZE_D: tl.constexpr, +): + qk_scale = sm_scale * 1.44269504 + # get batch id and head id + pid_b = tl.program_id(0) + pid_h = tl.program_id(1) + pid_q = tl.program_id(2) + if gqa_interleave: + pid_kh = pid_h % NUM_KV_HEADS + else: + pid_kh = pid_h // NUM_SHARE_Q_HEADS + # get q k start and len after rmpad + q_start = tl.load(cu_seqlens_q + pid_b) + q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start + k_start = tl.load(cu_seqlens_k + pid_b) + k_len = tl.load(cu_seqlens_k + pid_b + 1) - 
k_start + if BLOCK_SIZE_Q * pid_q >= q_len: + return + # init pointers + q_ptrs = tl.make_block_ptr( + base=q_ptr + q_start * stride_qn + pid_h * stride_qh, + shape=(q_len, HEAD_DIM), + strides=(stride_qn, stride_qd), + offsets=(pid_q * BLOCK_SIZE_Q, 0), + block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D), + order=(1, 0), + ) + dq_ptrs = tl.make_block_ptr( + base=dq_ptr + q_start * stride_dqn + pid_h * stride_dqh, + shape=(q_len, HEAD_DIM), + strides=(stride_dqn, stride_dqd), + offsets=(pid_q * BLOCK_SIZE_Q, 0), + block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D), + order=(1, 0), + ) + k_ptrs = tl.make_block_ptr( + base=k_ptr + k_start * stride_kn + pid_kh * stride_kh, + shape=(k_len, HEAD_DIM), + strides=(stride_kn, stride_kd), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + v_ptrs = tl.make_block_ptr( + base=v_ptr + k_start * stride_vn + pid_kh * stride_vh, + shape=(k_len, HEAD_DIM), + strides=(stride_vn, stride_vd), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + do_ptrs = tl.make_block_ptr( + base=do_ptr + q_start * stride_don + pid_h * stride_doh, + shape=(q_len, HEAD_DIM), + strides=(stride_don, stride_dod), + offsets=(pid_q * BLOCK_SIZE_Q, 0), + block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D), + order=(1, 0), + ) + d_ptrs = tl.make_block_ptr( + base=d_ptr + q_start * stride_dn + pid_h * stride_dh, + shape=(q_len, 1), + strides=(stride_dn, stride_dh), + offsets=(pid_q * BLOCK_SIZE_Q, 0), + block_shape=(BLOCK_SIZE_Q, 1), + order=(0, 1), + ) + lse_ptrs = tl.make_block_ptr( + base=lse_ptr + q_start * stride_ln + pid_h * stride_lh, + shape=(q_len, 1), + strides=(stride_ln, stride_lh), + offsets=(pid_q * BLOCK_SIZE_Q, 0), + block_shape=(BLOCK_SIZE_Q, 1), + order=(0, 1), + ) + # offsets + off_q = tl.arange(0, BLOCK_SIZE_Q) + pid_q * BLOCK_SIZE_Q + off_k = tl.arange(0, BLOCK_SIZE_K) + # load q, do, lse, delta, and keep in SRAM + q = tl.load(q_ptrs, boundary_check=(1, 0), padding_option="zero") + do = tl.load(do_ptrs, 
boundary_check=(0, 1), padding_option="zero") + lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero") + d = tl.load(d_ptrs, boundary_check=(0, 1), padding_option="zero") + # init dq + dq = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_D), dtype=tl.float32) + # causal + if causal: + k_hi = (pid_q + 1) * BLOCK_SIZE_Q + else: + k_hi = k_len + for j in range(0, k_hi, BLOCK_SIZE_K): + # load + k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero") + v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero") + # compute qk + if causal: + qk = tl.where(off_q[:, None] >= (off_k + j)[None, :], float(0.0), float("-inf")) + else: + qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32) + qk += tl.dot(q, k.T) * qk_scale + # compute p, ds + p = tl.math.exp2(qk - lse) + dp = tl.dot(do, v.T) + ds = sm_scale * p * (dp - d) + # cast dtype + ds = ds.to(q.dtype) + # update dq + dq += tl.dot(ds, k) + # increment pointers + k_ptrs = tl.advance(k_ptrs, (BLOCK_SIZE_K, 0)) + v_ptrs = tl.advance(v_ptrs, (BLOCK_SIZE_K, 0)) + # save dq + tl.store(dq_ptrs, dq.to(dq_ptr.dtype.element_ty), boundary_check=(0, 1)) + + +def _flash_attention_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: torch.Tensor, + max_seqlen_k: torch.Tensor, + causal: bool, + sm_scale: float, + gqa_interleave: bool = False, +): + # dtype check + assert q.dtype == torch.bfloat16 or q.dtype == torch.float16 + assert k.dtype == q.dtype and v.dtype == q.dtype + assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32 + # shape + q_len, num_q_heads, head_dim = q.shape + k_len, num_k_heads, head_dim = k.shape + v_len, num_v_heads, head_dim = v.shape + batch_size = cu_seqlens_q.shape[0] - 1 + # assert q_len == k_len and k_len == v_len + # gqa + assert num_k_heads == num_v_heads + assert num_q_heads % num_k_heads == 0 + num_share_q_heads = num_q_heads // num_k_heads + # output tensor + o = 
torch.empty_like(q) + lse = torch.empty(num_q_heads, q_len, dtype=torch.float32, device=q.device) + # launch kernel + grid = lambda META: ( + batch_size, + num_q_heads, + triton.cdiv(max_seqlen_q, META["BLOCK_SIZE_Q"]), + ) + BLOCK_SIZE_Q = 128 + BLOCK_SIZE_K = 64 + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_Q, IS_HOPPER_GPU) + forward_kernel[grid]( + q, + k, + v, + o, + lse, + cu_seqlens_q, + cu_seqlens_k, + num_k_heads, + num_share_q_heads, + head_dim, + sm_scale, + causal, + gqa_interleave, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + lse.stride(0), + lse.stride(1), + BLOCK_SIZE_Q=BLOCK_SIZE_Q, + BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_D=BLOCK_SIZE_D, + num_warps=num_warps, + num_stages=num_stages, + ) + return o, lse + + +def _flash_attention_bwd( + o: torch.Tensor, + do: torch.Tensor, + lse: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: torch.Tensor, + max_seqlen_k: torch.Tensor, + causal: bool, + sm_scale: float, + gqa_interleave: bool = False, +): + q_len, num_q_heads, head_dim = q.shape + k_len, num_k_heads, head_dim = k.shape + v_len, num_v_heads, head_dim = v.shape + o_len, num_o_heads, head_dim = o.shape + num_share_q_heads = num_q_heads // num_k_heads + # compute D + delta = torch.empty([num_o_heads, o_len], device=o.device, dtype=torch.float32) + grid = lambda META: (triton.cdiv(o_len, META["BLOCK_SIZE_O"]), num_o_heads) + BLOCK_SIZE_O = 256 + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_O, IS_HOPPER_GPU) + backward_sum_o_do[grid]( + o, + do, + delta, + o_len, + head_dim, + o.stride(0), + o.stride(1), + o.stride(2), + do.stride(0), + do.stride(1), + do.stride(2), + 
delta.stride(0), + delta.stride(1), + BLOCK_SIZE_O=BLOCK_SIZE_O, + BLOCK_SIZE_D=BLOCK_SIZE_D, + num_warps=num_warps, + num_stages=num_stages, + ) + # compute dk dv + dk = torch.empty(num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype) + dv = torch.empty(num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype) + batch_size = cu_seqlens_q.shape[0] - 1 + grid = lambda META: ( + batch_size, + num_q_heads, + triton.cdiv(max_seqlen_k, META["BLOCK_SIZE_K"]), + ) + BLOCK_SIZE_Q = 64 + BLOCK_SIZE_K = 64 + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_K, IS_HOPPER_GPU) + backward_dkdv[grid]( + q, + k, + v, + lse, + delta, + do, + dk, + dv, + cu_seqlens_q, + cu_seqlens_k, + num_k_heads, + num_share_q_heads, + head_dim, + sm_scale, + causal, + gqa_interleave, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + lse.stride(0), + lse.stride(1), + delta.stride(0), + delta.stride(1), + do.stride(0), + do.stride(1), + do.stride(2), + dk.stride(0), + dk.stride(1), + dk.stride(2), + dk.stride(3), + dv.stride(0), + dv.stride(1), + dv.stride(2), + dv.stride(3), + BLOCK_SIZE_Q=BLOCK_SIZE_Q, + BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_D=BLOCK_SIZE_D, + num_warps=num_warps, + num_stages=num_stages, + ) + dk = dk.sum(0) + dv = dv.sum(0) + # compute dq + dq = torch.empty_like(q) + grid = lambda META: ( + batch_size, + num_q_heads, + triton.cdiv(max_seqlen_q, META["BLOCK_SIZE_Q"]), + ) + BLOCK_SIZE_Q = 128 + BLOCK_SIZE_K = 64 + num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_Q, IS_HOPPER_GPU) + backward_dq[grid]( + q, + k, + v, + lse, + delta, + do, + dq, + cu_seqlens_q, + cu_seqlens_k, + num_k_heads, + num_share_q_heads, + head_dim, + sm_scale, + causal, + gqa_interleave, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + 
v.stride(0), + v.stride(1), + v.stride(2), + lse.stride(0), + lse.stride(1), + delta.stride(0), + delta.stride(1), + do.stride(0), + do.stride(1), + do.stride(2), + dq.stride(0), + dq.stride(1), + dq.stride(2), + BLOCK_SIZE_Q=BLOCK_SIZE_Q, + BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_D=BLOCK_SIZE_D, + num_warps=num_warps, + num_stages=num_stages, + ) + return dq, dk, dv + + +class FlashAttention(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: torch.Tensor, + max_seqlen_k: torch.Tensor, + causal=True, + sm_scale=None, + gqa_interleave=False, + ): + # softmax scale + if sm_scale is None: + sm_scale = 1 / math.sqrt(q.shape[-1]) + o, lse = _flash_attention_fwd( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + causal, + sm_scale, + gqa_interleave, + ) + ctx.save_for_backward(q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k) + ctx.sm_scale = sm_scale + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_k = max_seqlen_k + ctx.causal = causal + ctx.gqa_interleave = gqa_interleave + return o + + @staticmethod + def backward(ctx, do: torch.Tensor, *args) -> Any: + q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors + max_seqlen_q = ctx.max_seqlen_q + max_seqlen_k = ctx.max_seqlen_k + sm_scale = ctx.sm_scale + causal = ctx.causal + gqa_interleave = ctx.gqa_interleave + dq, dk, dv = _flash_attention_bwd( + o, + do, + lse, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + causal, + sm_scale, + gqa_interleave, + ) + return dq, dk, dv, None, None, None, None, None, None, None + + +def flash_attention_varlen( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: torch.Tensor, + max_seqlen_k: torch.Tensor, + causal: bool = False, + sm_scale: Optional[float] = None, + gqa_interleave: bool = 
False, +) -> torch.Tensor: + """Flash attention with variable length based on triton. + + Args: + q (torch.Tensor): shape [total_q_len, num_q_heads, head_dim] + k (torch.Tensor): shape [total_kv_len, num_q_heads, head_dim] + v (torch.Tensor): shape [total_kv_len, num_q_heads, head_dim] + cu_seqlens_q (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_func_varlen. + cu_seqlens_k (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_k in flash_attn_func_varlen. + max_seqlen_q (torch.Tensor): max q len of the batch. + max_seqlen_k (torch.Tensor): max k len of the batch. + causal (bool, optional): Causal mask. Defaults to False. + sm_scale (float, optional): softmax scale. Defaults to None, means 1/sqrt(head_dim). + gqa_interleave (bool, optional): GQA pattern. Defaults to False, use Llama style GQA. + + Returns: + torch.Tensor: attention output with shape [total_q_len, num_q_heads, head_dim] + """ + return FlashAttention.apply( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + causal, + sm_scale, + gqa_interleave, + ) diff --git a/vortex_torch/kernels/nsa/utils.py b/vortex_torch/kernels/nsa/utils.py new file mode 100644 index 0000000..1f158a1 --- /dev/null +++ b/vortex_torch/kernels/nsa/utils.py @@ -0,0 +1,50 @@ +import torch + + +def is_hopper_gpu(): + if torch.cuda.is_available(): + device_capability = torch.cuda.get_device_capability(0) + major, minor = device_capability + return major == 9 + return False + + +def get_num_warps_stages(head_dim, block_size, is_hopper_gpu): + """ + Returns recommended num_warps and num_stages for a Sparse Attention kernel in Triton. + + Args: + head_dim (int): Size of the head dimension. + block_size (int): Size of the block in the attention matrix. + is_hopper_gpu (bool): True if Hopper GPU, False if Ampere GPU. + + Returns: + tuple: (num_warps, num_stages) recommended values. 
+ """ + # Determine if head_dim and block_size exceed 64 + head_large = head_dim > 64 + block_large = block_size > 64 + + if is_hopper_gpu: + # Hopper GPU recommendations + if head_large and block_large: + num_warps = 8 + num_stages = 3 + elif head_large or block_large: + num_warps = 4 + num_stages = 3 + else: + num_warps = 2 + num_stages = 2 + else: + # Ampere GPU recommendations + if head_large and block_large: + num_warps = 8 + num_stages = 3 + elif head_large or block_large: + num_warps = 8 + num_stages = 3 + else: + num_warps = 2 + num_stages = 2 + return num_warps, num_stages diff --git a/vortex_torch/kernels/nsa/weighted_pool.py b/vortex_torch/kernels/nsa/weighted_pool.py new file mode 100644 index 0000000..abfe9d3 --- /dev/null +++ b/vortex_torch/kernels/nsa/weighted_pool.py @@ -0,0 +1,341 @@ +# Copyright 2025 Xunhao Lai. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Optional + +import torch +import triton +import triton.language as tl +from einops import einsum + + +@triton.jit +def sliding_pool_fwd_kernel( + x_ptr, + y_ptr, + w_ptr, + cu_seqlens, + y_cu_seqlens, + head_dim, + kernel_size, + kernel_stride, + stride_xn, + stride_xh, + stride_xd, + stride_yn, + stride_yh, + stride_yd, + stride_wh, + stride_wk, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_D: tl.constexpr, +): + pid_b = tl.program_id(0) + pid_h = tl.program_id(1) + pid_k = tl.program_id(2) + # get start and len after rmpad + x_start = tl.load(cu_seqlens + pid_b) + x_len = tl.load(cu_seqlens + pid_b + 1) - x_start + y_start = tl.load(y_cu_seqlens + pid_b) + y_len = tl.load(y_cu_seqlens + pid_b + 1) - y_start + if pid_k >= y_len: + return + if w_ptr is not None: + # load w + w_ptrs = tl.make_block_ptr( + base=w_ptr + pid_h * stride_wh, + shape=(kernel_size, 1), + strides=(stride_wk, 0), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_K, 1), + order=(0, 1), + ) + w = tl.load(w_ptrs, boundary_check=(0, 1), padding_option="zero") + # load x + x_ptrs = tl.make_block_ptr( + base=x_ptr + x_start * stride_xn + pid_h * stride_xh, + shape=(x_len, head_dim), + strides=(stride_xn, stride_xd), + offsets=(pid_k * kernel_stride, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + x = tl.load(x_ptrs, boundary_check=(0, 1), padding_option="zero") + # compute y + if w_ptr is not None: + y = tl.sum(x * w, axis=0) + else: + y = tl.sum(x, axis=0) / kernel_size + off_d = tl.arange(0, BLOCK_SIZE_D) + tl.store( + y_ptr + (y_start + pid_k) * stride_yn + pid_h * stride_yh + off_d * stride_yd, + y.to(y_ptr.dtype.element_ty), + mask=off_d < head_dim, + ) + + +@triton.jit +def sliding_pool_dxdw_kernel( + x_ptr, + dx_ptr, + dy_ptr, + w_ptr, + dw_ptr, + cu_seqlens, + y_cu_seqlens, + head_dim, + kernel_size, + kernel_stride, + stride_xn, + stride_xh, + stride_xd, + stride_dxn, + stride_dxh, + stride_dxd, + stride_dyn, + stride_dyh, + stride_dyd, + stride_wh, + 
stride_wk, + stride_dwh, + stride_dwn, + stride_dwk, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_D: tl.constexpr, +): + pid_b = tl.program_id(0) + pid_h = tl.program_id(1) + pid_k = tl.program_id(2) + # get start and len after rmpad + x_start = tl.load(cu_seqlens + pid_b) + x_len = tl.load(cu_seqlens + pid_b + 1) - x_start + y_start = tl.load(y_cu_seqlens + pid_b) + y_len = tl.load(y_cu_seqlens + pid_b + 1) - y_start + if pid_k >= y_len: + return + # offsets + off_d = tl.arange(0, BLOCK_SIZE_D) + off_k = tl.arange(0, BLOCK_SIZE_K) + if w_ptr is not None: + # load w + w_ptrs = w_ptr + pid_h * stride_wh + off_k * stride_wk + w = tl.load(w_ptrs, mask=off_k < kernel_size, other=0) + # load x + x_ptrs = tl.make_block_ptr( + base=x_ptr + x_start * stride_xn + pid_h * stride_xh, + shape=(head_dim, x_len), + strides=(stride_xd, stride_xn), + offsets=(0, pid_k * kernel_stride), + block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K), + order=(0, 1), + ) + x = tl.load(x_ptrs, boundary_check=(0, 1), padding_option="zero") + # load dy + dy_ptrs = dy_ptr + pid_h * stride_dyh + (y_start + pid_k) * stride_dyn + off_d * stride_dyd + dy = tl.load(dy_ptrs, mask=off_d < head_dim, other=0) + if w_ptr is not None: + # compute dx, [1, D] x [K, 1] -> [K, D] + dx = dy[None, :] * w[:, None] + # compute dw, [D, 1] x [D, K] -> [D, K] -> [K] + dw = tl.sum(dy[:, None] * x, axis=0) + # store dw + dw_ptrs = dw_ptr + pid_h * stride_dwh + (y_start + pid_k) * stride_dwn + off_k * stride_dwk + tl.store(dw_ptrs, dw.to(dw_ptr.dtype.element_ty), mask=off_k < kernel_size) + else: + dx = dy[None, :] / kernel_size + # store dx + dx_ptrs = ( + dx_ptr + + pid_h * stride_dxh + + (x_start + pid_k * kernel_stride + off_k[:, None]) * stride_dxn + + off_d[None, :] * stride_dxd + ) + tl.atomic_add( + dx_ptrs, + dx.to(dx_ptr.dtype.element_ty), + mask=(off_k < x_len - pid_k * kernel_stride)[:, None] & (off_d < head_dim)[None, :], + ) + + +class SlidingWindowWeightedPool(torch.autograd.Function): + @staticmethod + def forward( + 
ctx, + x: torch.Tensor, # [total_len, num_heads, head_dim] + w: torch.Tensor, # [num_heads, kernel_size] + cu_seqlens: torch.Tensor, + kernel_size: int, + kernel_stride: int, + ): + # dtype check + assert x.dtype == torch.float16 or x.dtype == torch.bfloat16 + if w is not None: + assert x.dtype == w.dtype + assert cu_seqlens.dtype == torch.int32 + # shape check + total_len, num_heads, head_dim = x.shape + batch_size = cu_seqlens.shape[0] - 1 + if w is not None: + assert w.shape[0] == num_heads + assert w.shape[1] == kernel_size + assert kernel_size % kernel_stride == 0 + assert kernel_size in {16, 32, 64, 128} + # compute seqlens after compression + seqlens = cu_seqlens[1:] - cu_seqlens[:-1] + y_seqlens = torch.floor((seqlens - kernel_size) / kernel_stride).to(torch.int32) + 1 + # corner case, if sequence_length < kernel_size, no compression for this sequence + y_seqlens[seqlens < kernel_size] = 0 + y_cu_seqlens = torch.cat( + [ + torch.zeros(1, dtype=torch.int32, device="cuda"), + torch.cumsum(y_seqlens, dim=0), + ], + dim=0, + ).to(torch.int32) + # output buffer + y = torch.zeros(y_cu_seqlens[-1], num_heads, head_dim, dtype=x.dtype, device=x.device) + # launch kernel + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + BLOCK_SIZE_K = triton.next_power_of_2(kernel_size) + grid = (batch_size, num_heads, y_seqlens.max().item()) + sliding_pool_fwd_kernel[grid]( + x, + y, + w, + cu_seqlens, + y_cu_seqlens, + head_dim, + kernel_size, + kernel_stride, + x.stride(0), + x.stride(1), + x.stride(2), + y.stride(0), + y.stride(1), + y.stride(2), + w.stride(0) if w is not None else None, + w.stride(1) if w is not None else None, + BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_D=BLOCK_SIZE_D, + ) + ctx.save_for_backward(x, w, seqlens, cu_seqlens, y_seqlens, y_cu_seqlens) + ctx.kernel_size = kernel_size + ctx.kernel_stride = kernel_stride + ctx.head_dim = head_dim + return y, y_cu_seqlens + + @staticmethod + def backward(ctx, dy, _): + x, w, seqlens, cu_seqlens, y_seqlens, y_cu_seqlens 
= ctx.saved_tensors + kernel_size = ctx.kernel_size + kernel_stride = ctx.kernel_stride + head_dim = ctx.head_dim + batch_size = cu_seqlens.shape[0] - 1 + num_heads = x.shape[1] + # compute dx + dx = torch.zeros_like(x, dtype=torch.float32) + if w is not None: + dw = torch.zeros( + num_heads, + y_cu_seqlens[-1], + kernel_size, + dtype=torch.float32, + device=w.device, + ) + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + BLOCK_SIZE_K = triton.next_power_of_2(kernel_size) + grid = (batch_size, num_heads, y_seqlens.max().item()) + sliding_pool_dxdw_kernel[grid]( + x, + dx, + dy, + w, + dw if w is not None else None, + cu_seqlens, + y_cu_seqlens, + head_dim, + kernel_size, + kernel_stride, + x.stride(0), + x.stride(1), + x.stride(2), + dx.stride(0), + dx.stride(1), + dx.stride(2), + dy.stride(0), + dy.stride(1), + dy.stride(2), + w.stride(0) if w is not None else None, + w.stride(1) if w is not None else None, + dw.stride(0) if w is not None else None, + dw.stride(1) if w is not None else None, + dw.stride(2) if w is not None else None, + BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_D=BLOCK_SIZE_D, + ) + dx = dx.to(x.dtype) + if w is None: + dw = None + else: + dw = dw.sum(1).to(w.dtype) + return dx, dw, None, None, None + + +def weightedpool_compress( + x: torch.Tensor, # [total_len, num_heads, head_dim] + w: torch.Tensor, # [num_heads, kernel_size] + cu_seqlens: torch.Tensor, + kernel_size: int, + kernel_stride: int, + pe: Optional[torch.Tensor] = None, +): + y, y_cu_seqlens = SlidingWindowWeightedPool.apply(x, w, cu_seqlens, kernel_size, kernel_stride) + if pe is not None: + assert pe.dtype == x.dtype and pe.device == x.device + bias = einsum(pe, w, "h k d, h k -> h d") + y = y + bias.unsqueeze(0) + return y, y_cu_seqlens + + +def avgpool_compress( + x: torch.Tensor, # [total_len, num_heads, head_dim] + w: torch.Tensor, # don't need weight + cu_seqlens: torch.Tensor, + kernel_size: int, + kernel_stride: int, + pe: Optional[torch.Tensor] = None, +): + assert w is None, 
"don't need additional weight for avgpool" + y, y_cu_seqlens = SlidingWindowWeightedPool.apply(x, w, cu_seqlens, kernel_size, kernel_stride) + if pe is not None: + assert pe.dtype == x.dtype and pe.device == x.device + bias = torch.mean(pe, dim=1) + y = y + bias.unsqueeze(0) + return y, y_cu_seqlens + + +def softmaxpool_compress( + x: torch.Tensor, + w: torch.Tensor, + cu_seqlens: torch.Tensor, + kernel_size: int, + kernel_stride: int, + pe: Optional[torch.Tensor] = None, +): + y, y_cu_seqlens = SlidingWindowWeightedPool.apply(x, w.softmax(-1), cu_seqlens, kernel_size, kernel_stride) + if pe is not None: + assert pe.dtype == x.dtype and pe.device == x.device + bias = torch.mean(pe, dim=1) + y = y + bias.unsqueeze(0) + return y, y_cu_seqlens From f6ca879efa5db14b7193dbb2fa9c417df579f647 Mon Sep 17 00:00:00 2001 From: UED Date: Mon, 30 Mar 2026 05:21:24 +0000 Subject: [PATCH 12/22] Refactor on the int8 quanitzation, mainly the dequant kernels of int8 --- examples/verify_algo.sh | 2 +- examples/verify_algo_quant.sh | 4 +- examples/verify_sparse_backends.sh | 2 +- vortex_torch/cache/__init__.py | 4 +- vortex_torch/cache/triton_kernels/__init__.py | 18 +- .../cache/triton_kernels/paged_decode_int8.py | 363 ------------- .../triton_kernels/paged_prefill_int8.py | 168 ------ vortex_torch/cache/triton_kernels/set_kv.py | 495 ++++++++++++++++++ 8 files changed, 513 insertions(+), 543 deletions(-) delete mode 100644 vortex_torch/cache/triton_kernels/paged_decode_int8.py delete mode 100644 vortex_torch/cache/triton_kernels/paged_prefill_int8.py diff --git a/examples/verify_algo.sh b/examples/verify_algo.sh index 0dcbe9f..aa01fe6 100644 --- a/examples/verify_algo.sh +++ b/examples/verify_algo.sh @@ -1,6 +1,6 @@ #!/usr/bin/env bash set -e -export CUDA_VISIBLE_DEVICES=5 +export CUDA_VISIBLE_DEVICES=6 sparse_algos=( "block_sparse_attention" diff --git a/examples/verify_algo_quant.sh b/examples/verify_algo_quant.sh index c344474..a7601de 100644 --- a/examples/verify_algo_quant.sh 
+++ b/examples/verify_algo_quant.sh @@ -1,6 +1,6 @@ #!/usr/bin/env bash set -e -export CUDA_VISIBLE_DEVICES=0 +export CUDA_VISIBLE_DEVICES=6 sparse_algos=( "block_sparse_attention" @@ -20,6 +20,7 @@ TIMESTAMP=$(date +%Y%m%d_%H%M%S) --vortex-module-name "${algo}" \ --model-name Qwen/Qwen3-1.7B \ --kv-cache-dtype int8 \ + --topk-type naive \ --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" done @@ -34,6 +35,7 @@ TIMESTAMP=$(date +%Y%m%d_%H%M%S) --vortex-module-name "${algo}" \ --model-name Qwen/Qwen3-1.7B \ --kv-cache-dtype fp8_e4m3 \ + --topk-type naive \ --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" done \ No newline at end of file diff --git a/examples/verify_sparse_backends.sh b/examples/verify_sparse_backends.sh index 81b3562..12600d0 100755 --- a/examples/verify_sparse_backends.sh +++ b/examples/verify_sparse_backends.sh @@ -1,6 +1,6 @@ #!/usr/bin/env bash set -e -export CUDA_VISIBLE_DEVICES=5 +export CUDA_VISIBLE_DEVICES=6 sparse_algos=( "nsa" diff --git a/vortex_torch/cache/__init__.py b/vortex_torch/cache/__init__.py index 8c4d0e0..6b54905 100644 --- a/vortex_torch/cache/__init__.py +++ b/vortex_torch/cache/__init__.py @@ -29,14 +29,14 @@ from .matmul import GeMM from .elementwise import Relu, Silu, Sigmoid, Abs, Add_Mul from .elementwise_binary import Maximum, Minimum, Multiply, Add -from .triton_kernels import set_kv_buffer_launcher, set_kv_buffer_int8_launcher, set_kv_buffer_fp8_launcher, dequant_paged_int8_to_bf16_inplace +from .triton_kernels import set_kv_buffer_launcher, set_kv_buffer_int8_launcher, set_kv_buffer_fp8_launcher, dequant_pages_to_bf16_inplace __all__ = [ "set_kv_buffer_launcher", "set_kv_buffer_int8_launcher", "set_kv_buffer_fp8_launcher", - "dequant_paged_int8_to_bf16_inplace", + "dequant_pages_to_bf16_inplace", "Mean", "Max", "Min", "L2Norm", "GeMM", "Relu", "Silu", "Sigmoid", "Abs", "Add_Mul", diff --git a/vortex_torch/cache/triton_kernels/__init__.py b/vortex_torch/cache/triton_kernels/__init__.py index 009e728..de4fcbd 100644 --- 
a/vortex_torch/cache/triton_kernels/__init__.py +++ b/vortex_torch/cache/triton_kernels/__init__.py @@ -1,13 +1,17 @@ -from .set_kv import set_kv_buffer_launcher, set_kv_buffer_int8_launcher, set_kv_buffer_fp8_launcher -from .paged_decode_int8 import paged_decode_int8 -from .paged_prefill_int8 import dequant_paged_int8_to_bf16, dequant_paged_int8_to_bf16_inplace +from .set_kv import ( + set_kv_buffer_launcher, + set_kv_buffer_int8_launcher, + set_kv_buffer_fp8_launcher, + paged_decode, + dequant_pages_to_bf16, + dequant_pages_to_bf16_inplace, +) __all__ = [ "set_kv_buffer_launcher", "set_kv_buffer_int8_launcher", "set_kv_buffer_fp8_launcher", - "paged_decode_int8", - "dequant_paged_int8_to_bf16", - "dequant_paged_int8_to_bf16_inplace", + "paged_decode", + "dequant_pages_to_bf16", + "dequant_pages_to_bf16_inplace", ] - diff --git a/vortex_torch/cache/triton_kernels/paged_decode_int8.py b/vortex_torch/cache/triton_kernels/paged_decode_int8.py deleted file mode 100644 index 4f33cd4..0000000 --- a/vortex_torch/cache/triton_kernels/paged_decode_int8.py +++ /dev/null @@ -1,363 +0,0 @@ -""" -Custom Triton paged decode attention kernel for int8 KV cache. - -Loads int8 K/V pages with per-token float32 scales, dequantizes inline in SRAM, -and computes standard multi-head attention with online softmax. - -Adapted from SGLang's decode_attention.py for use with Vortex's paged layout -where each KV head is treated as a separate "batch" entry. 
-""" - -import torch -import triton -import triton.language as tl - -_MIN_BLOCK_KV = 32 - - -@triton.jit -def tanh(x): - return 2 * tl.sigmoid(2 * x) - 1 - - -@triton.jit -def _fwd_kernel_int8_stage1( - Q, # [batch, num_qo_heads, head_dim] bf16 - K_Buffer, # int8 paged: flat - V_Buffer, # int8 paged: flat - K_Scale_Buffer, # fp16: flat (one scale per token slot) - V_Scale_Buffer, # fp16: flat - sm_scale, - kv_indptr, # [batch + 1] int32, page-level - kv_indices, # page indices - last_page_len, # [batch] int32, tokens valid in last page - Att_Out, # [batch, num_qo_heads, max_kv_splits, head_dim] - Att_Lse, # [batch, num_qo_heads, max_kv_splits] - num_kv_splits, # [batch] int32 - stride_qbs, - stride_qh, - stride_buf_kbs, # stride per token in K_Buffer (= head_dim) - stride_buf_vbs, # stride per token in V_Buffer (= head_dim) - stride_mid_ob, - stride_mid_oh, - stride_mid_os, - kv_group_num: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_DV: tl.constexpr, - BLOCK_N: tl.constexpr, - MIN_BLOCK_KV: tl.constexpr, - logit_cap: tl.constexpr, - Lk: tl.constexpr, - Lv: tl.constexpr, - PAGE_SIZE: tl.constexpr, -): - """ - Stage 1: For each (batch, head, kv_split), compute partial attention output and LSE. - - kv_indptr is page-level. 
Total tokens for batch i: - (num_pages - 1) * PAGE_SIZE + last_page_len[i] - """ - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - split_kv_id = tl.program_id(2) - - offs_d = tl.arange(0, BLOCK_DMODEL) - offs_dv = tl.arange(0, BLOCK_DV) - mask_d = offs_d < Lk - mask_dv = offs_dv < Lv - - cur_batch_kv_start_idx = tl.load(kv_indptr + cur_batch) - cur_batch_num_pages = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx - cur_last_page_len = tl.load(last_page_len + cur_batch) - # Correct token count accounting for partial last page - cur_batch_seq_len = (cur_batch_num_pages - 1) * PAGE_SIZE + cur_last_page_len - kv_splits = tl.load(num_kv_splits + cur_batch) - - off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d - - kv_len_per_split = ( - tl.cdiv(tl.cdiv(cur_batch_seq_len, kv_splits), MIN_BLOCK_KV) * MIN_BLOCK_KV - ) - split_kv_start = kv_len_per_split * split_kv_id - split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) - - e_max = -float("inf") - e_sum = 0.0 - acc = tl.zeros([BLOCK_DV], dtype=tl.float32) - - if split_kv_end > split_kv_start: - q = tl.load(Q + off_q, mask=mask_d, other=0.0).to(tl.float32) - - for start_n in range(split_kv_start, split_kv_end, BLOCK_N): - offs_n = start_n + tl.arange(0, BLOCK_N) - mask_n = offs_n < split_kv_end - - # Convert token offsets to page_id + in-page offset - page_indices_in_seq = offs_n // PAGE_SIZE - in_page_offsets = offs_n % PAGE_SIZE - - # Load page indices from kv_indices (physical page IDs) - page_ids = tl.load( - kv_indices + cur_batch_kv_start_idx + page_indices_in_seq, - mask=mask_n, - other=0, - ) - - # Flat token location: physical_page * PAGE_SIZE + in_page_offset - kv_loc = page_ids * PAGE_SIZE + in_page_offsets - - # Load int8 K and dequantize - offs_buf_k = kv_loc[:, None] * stride_buf_kbs + offs_d[None, :] - k_int8 = tl.load( - K_Buffer + offs_buf_k, - mask=mask_n[:, None] & mask_d[None, :], - other=0, - ).to(tl.float32) - - k_scale = tl.load( - 
K_Scale_Buffer + kv_loc, - mask=mask_n, - other=1.0, - ).to(tl.float32) - k = k_int8 * k_scale[:, None] - - # Compute QK - qk = tl.sum(q[None, :] * k, 1) - qk *= sm_scale - - if logit_cap > 0: - qk = logit_cap * tanh(qk / logit_cap) - - qk = tl.where(mask_n, qk, float("-inf")) - - # Load int8 V and dequantize - offs_buf_v = kv_loc[:, None] * stride_buf_vbs + offs_dv[None, :] - v_int8 = tl.load( - V_Buffer + offs_buf_v, - mask=mask_n[:, None] & mask_dv[None, :], - other=0, - ).to(tl.float32) - - v_scale = tl.load( - V_Scale_Buffer + kv_loc, - mask=mask_n, - other=1.0, - ).to(tl.float32) - v = v_int8 * v_scale[:, None] - - # Online softmax accumulation - n_e_max = tl.maximum(tl.max(qk, 0), e_max) - re_scale = tl.exp(e_max - n_e_max) - p = tl.exp(qk - n_e_max) - acc *= re_scale - acc += tl.sum(p[:, None] * v, 0) - - e_sum = e_sum * re_scale + tl.sum(p, 0) - e_max = n_e_max - - offs_mid_o = ( - cur_batch * stride_mid_ob - + cur_head * stride_mid_oh - + split_kv_id * stride_mid_os - + offs_dv - ) - - tl.store( - Att_Out + offs_mid_o, - acc / e_sum, - mask=mask_dv, - ) - - offs_mid_o_1 = ( - cur_batch * stride_mid_ob - + cur_head * stride_mid_oh - + split_kv_id * stride_mid_os - ) // Lv - - tl.store( - Att_Lse + offs_mid_o_1, - e_max + tl.log(e_sum), - ) - - -@triton.jit -def _fwd_kernel_int8_stage2( - Mid_O, - Mid_O_1, - O, - kv_indptr, - last_page_len, - num_kv_splits, - stride_mid_ob, - stride_mid_oh, - stride_mid_os, - stride_obs, - stride_oh, - MAX_KV_SPLITS: tl.constexpr, - MIN_BLOCK_KV: tl.constexpr, - BLOCK_DV: tl.constexpr, - Lv: tl.constexpr, - PAGE_SIZE: tl.constexpr, -): - """Stage 2: Reduce split outputs via log-sum-exp merge.""" - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - - cur_batch_num_pages = tl.load(kv_indptr + cur_batch + 1) - tl.load(kv_indptr + cur_batch) - cur_last_page_len = tl.load(last_page_len + cur_batch) - cur_batch_seq_len = (cur_batch_num_pages - 1) * PAGE_SIZE + cur_last_page_len - kv_splits = tl.load(num_kv_splits + 
cur_batch) - - offs_d = tl.arange(0, BLOCK_DV) - mask_d = offs_d < Lv - - e_sum = 0.0 - e_max = -float("inf") - acc = tl.zeros([BLOCK_DV], dtype=tl.float32) - - offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d - offs_logic = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh) // Lv - kv_len_per_split = ( - tl.cdiv(tl.cdiv(cur_batch_seq_len, kv_splits), MIN_BLOCK_KV) * MIN_BLOCK_KV - ) - - for split_kv_id in range(0, MAX_KV_SPLITS): - split_kv_start = kv_len_per_split * split_kv_id - split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) - - if split_kv_end > split_kv_start: - tv = tl.load( - Mid_O + offs_v + split_kv_id * stride_mid_os, mask=mask_d, other=0.0 - ) - tlogic = tl.load(Mid_O_1 + offs_logic + split_kv_id * stride_mid_os // Lv) - n_e_max = tl.maximum(tlogic, e_max) - - old_scale = tl.exp(e_max - n_e_max) - acc *= old_scale - exp_logic = tl.exp(tlogic - n_e_max) - acc += exp_logic * tv - - e_sum = e_sum * old_scale + exp_logic - e_max = n_e_max - - tl.store( - O + cur_batch * stride_obs + cur_head * stride_oh + offs_d, - acc / e_sum, - mask=mask_d, - ) - - -def paged_decode_int8( - q: torch.Tensor, # [batch, num_qo_heads, head_dim] bf16 - k_buffer: torch.Tensor, # int8 paged K cache - v_buffer: torch.Tensor, # int8 paged V cache - k_scale_buffer: torch.Tensor, # fp16 scale for K - v_scale_buffer: torch.Tensor, # fp16 scale for V - o: torch.Tensor, # [batch, num_qo_heads, head_dim] bf16 output - kv_indptr: torch.Tensor, # [batch + 1] int32, page-level - kv_indices: torch.Tensor, # page indices - last_page_len: torch.Tensor, # [batch] int32 - num_kv_splits: torch.Tensor, # [batch] int32 - max_kv_splits: int, - sm_scale: float, - page_size: int, - logit_cap: float = 0.0, - att_out: torch.Tensor = None, # optional pre-allocated [batch, head_num, max_kv_splits, Lv] - att_lse: torch.Tensor = None, # optional pre-allocated [batch, head_num, max_kv_splits] -): - """ - Paged decode attention with int8 KV cache and 
inline dequantization. - - kv_indptr is page-level. last_page_len specifies valid tokens in the last page - for each batch entry. Total tokens = (num_pages - 1) * page_size + last_page_len. - """ - batch = q.shape[0] - head_num = q.shape[1] - Lk = q.shape[2] - Lv = Lk - - BLOCK_DMODEL = triton.next_power_of_2(Lk) - BLOCK_DV = triton.next_power_of_2(Lv) - BLOCK_N = 64 - MAX_KV_SPLITS = max_kv_splits - - kv_group_num = head_num - - num_warps = 4 if kv_group_num == 1 else 2 - - # Use pre-allocated buffers if provided, otherwise allocate - if att_out is None: - att_out = torch.empty( - (batch, head_num, MAX_KV_SPLITS, Lv), - dtype=torch.float32, - device=q.device, - ) - else: - att_out = att_out[:batch] - if att_lse is None: - att_lse = torch.empty( - (batch, head_num, MAX_KV_SPLITS), - dtype=torch.float32, - device=q.device, - ) - else: - att_lse = att_lse[:batch] - - stride_buf_kbs = k_buffer.shape[-1] - stride_buf_vbs = v_buffer.shape[-1] - - grid_stage1 = (batch, head_num, MAX_KV_SPLITS) - _fwd_kernel_int8_stage1[grid_stage1]( - q, - k_buffer, - v_buffer, - k_scale_buffer, - v_scale_buffer, - sm_scale, - kv_indptr, - kv_indices, - last_page_len, - att_out, - att_lse, - num_kv_splits, - q.stride(0), - q.stride(1), - stride_buf_kbs, - stride_buf_vbs, - att_out.stride(0), - att_out.stride(1), - att_out.stride(2), - kv_group_num=kv_group_num, - BLOCK_DMODEL=BLOCK_DMODEL, - BLOCK_DV=BLOCK_DV, - BLOCK_N=BLOCK_N, - MIN_BLOCK_KV=_MIN_BLOCK_KV, - logit_cap=logit_cap, - num_warps=num_warps, - num_stages=2, - Lk=Lk, - Lv=Lv, - PAGE_SIZE=page_size, - ) - - grid_stage2 = (batch, head_num) - _fwd_kernel_int8_stage2[grid_stage2]( - att_out, - att_lse, - o, - kv_indptr, - last_page_len, - num_kv_splits, - att_out.stride(0), - att_out.stride(1), - att_out.stride(2), - o.stride(0), - o.stride(1), - MAX_KV_SPLITS=MAX_KV_SPLITS, - MIN_BLOCK_KV=_MIN_BLOCK_KV, - BLOCK_DV=BLOCK_DV, - Lv=Lv, - PAGE_SIZE=page_size, - num_warps=4, - num_stages=2, - ) diff --git 
a/vortex_torch/cache/triton_kernels/paged_prefill_int8.py b/vortex_torch/cache/triton_kernels/paged_prefill_int8.py deleted file mode 100644 index 8927983..0000000 --- a/vortex_torch/cache/triton_kernels/paged_prefill_int8.py +++ /dev/null @@ -1,168 +0,0 @@ -""" -OOM-safe bf16 fallback for int8 KV-cache prefill. - -Instead of implementing full 2D-tiled Triton prefill with int8 dequantization, -this module dequantizes only the accessed KV pages into a compact temporary -bf16 buffer and remaps indices so FlashInfer can operate on the compact buffer. - -This avoids dequantizing the entire global cache buffer. -""" - -import torch -import triton -import triton.language as tl - - -@triton.jit -def _dequant_pages_kernel( - src_int8, # int8 paged buffer [num_pages, page_size, head_dim] flat - src_scale, # fp16 scale buffer [num_pages, page_size, 1] flat - dst_bf16, # bf16 compact buffer [num_accessed_pages, page_size, head_dim] flat - page_indices, # int32 [num_accessed_pages] — which global pages to dequant - NUM_PAGES: tl.constexpr, - PAGE_SIZE: tl.constexpr, - HEAD_DIM: tl.constexpr, - BLOCK_DIM: tl.constexpr, -): - """Dequantize selected int8 pages to bf16 compact buffer.""" - page_idx = tl.program_id(0) # index into page_indices - token_idx = tl.program_id(1) # token within page [0, PAGE_SIZE) - - if page_idx >= NUM_PAGES: - return - - global_page_id = tl.load(page_indices + page_idx) - dims = tl.arange(0, BLOCK_DIM) - mask_dim = dims < HEAD_DIM - - # Source: global_page_id * PAGE_SIZE * HEAD_DIM + token_idx * HEAD_DIM + dims - src_offset = (global_page_id * PAGE_SIZE + token_idx) * HEAD_DIM + dims - val_int8 = tl.load(src_int8 + src_offset, mask=mask_dim, other=0).to(tl.float32) - - # Scale: global_page_id * PAGE_SIZE + token_idx - scale_offset = global_page_id * PAGE_SIZE + token_idx - scale = tl.load(src_scale + scale_offset).to(tl.float32) - - val_bf16 = (val_int8 * scale).to(tl.bfloat16) - - # Destination: page_idx * PAGE_SIZE * HEAD_DIM + token_idx * HEAD_DIM + 
dims - dst_offset = (page_idx * PAGE_SIZE + token_idx) * HEAD_DIM + dims - tl.store(dst_bf16 + dst_offset, val_bf16, mask=mask_dim) - - -def dequant_paged_int8_to_bf16( - src_int8: torch.Tensor, # int8 [num_pages, page_size, head_dim] - src_scale: torch.Tensor, # fp16 [num_pages, page_size, 1] - page_indices: torch.Tensor, # int32 [num_accessed_pages] - page_size: int, - head_dim: int, - out: torch.Tensor = None, # optional pre-allocated bf16 [>=num_accessed_pages, page_size, head_dim] -) -> torch.Tensor: - """ - Dequantize only the accessed pages from int8 cache to a compact bf16 buffer. - - If `out` is provided, writes into it (must have room for num_accessed_pages). - Otherwise allocates a new buffer. - - Returns: - bf16 tensor of shape [num_accessed_pages, page_size, head_dim] - """ - num_accessed_pages = page_indices.shape[0] - if num_accessed_pages == 0: - if out is not None: - return out[:0] - return torch.empty((0, page_size, head_dim), dtype=torch.bfloat16, device=src_int8.device) - - if out is not None: - dst_bf16 = out[:num_accessed_pages] - else: - dst_bf16 = torch.empty( - (num_accessed_pages, page_size, head_dim), - dtype=torch.bfloat16, - device=src_int8.device, - ) - - BLOCK_DIM = triton.next_power_of_2(head_dim) - - grid = (num_accessed_pages, page_size) - _dequant_pages_kernel[grid]( - src_int8, - src_scale, - dst_bf16, - page_indices, - NUM_PAGES=num_accessed_pages, - PAGE_SIZE=page_size, - HEAD_DIM=head_dim, - BLOCK_DIM=BLOCK_DIM, - ) - - return dst_bf16 - - -@triton.jit -def _dequant_pages_inplace_kernel( - src_int8, # int8 paged buffer flat - src_scale, # scale buffer flat (one scale per token slot) - dst_bf16, # bf16 destination buffer (same page layout as src) - page_indices, # int32 [num_pages] — which global pages to dequant - NUM_PAGES: tl.constexpr, - PAGE_SIZE: tl.constexpr, - HEAD_DIM: tl.constexpr, - BLOCK_DIM: tl.constexpr, -): - """Dequantize selected int8 pages to bf16, writing to the SAME page positions in dst.""" - page_idx = 
tl.program_id(0) # index into page_indices - token_idx = tl.program_id(1) # token within page [0, PAGE_SIZE) - - if page_idx >= NUM_PAGES: - return - - global_page_id = tl.load(page_indices + page_idx) - dims = tl.arange(0, BLOCK_DIM) - mask_dim = dims < HEAD_DIM - - # Source and destination use the SAME offset (in-place layout) - offset = (global_page_id * PAGE_SIZE + token_idx) * HEAD_DIM + dims - val_int8 = tl.load(src_int8 + offset, mask=mask_dim, other=0).to(tl.float32) - - scale_offset = global_page_id * PAGE_SIZE + token_idx - scale = tl.load(src_scale + scale_offset).to(tl.float32) - - val_bf16 = (val_int8 * scale).to(tl.bfloat16) - - # Write to the SAME page position in dst (not compacted) - tl.store(dst_bf16 + offset, val_bf16, mask=mask_dim) - - -def dequant_paged_int8_to_bf16_inplace( - src_int8: torch.Tensor, # int8 paged cache (flat) - src_scale: torch.Tensor, # fp16 scale buffer (flat) - dst_bf16: torch.Tensor, # bf16 destination (same shape as src_int8) - page_indices: torch.Tensor, # int32 [num_pages] — which pages to dequant - page_size: int, - head_dim: int, -) -> None: - """ - Dequantize selected pages from int8 cache to bf16 IN-PLACE. - - Unlike dequant_paged_int8_to_bf16 (which compacts into a dense buffer), - this writes to the SAME page positions in dst_bf16, preserving the paged layout. - Used to populate the bf16 working buffer for forward_cache (centroid computation). 
- """ - num_pages = page_indices.shape[0] - if num_pages == 0: - return - - BLOCK_DIM = triton.next_power_of_2(head_dim) - - grid = (num_pages, page_size) - _dequant_pages_inplace_kernel[grid]( - src_int8, - src_scale, - dst_bf16, - page_indices, - NUM_PAGES=num_pages, - PAGE_SIZE=page_size, - HEAD_DIM=head_dim, - BLOCK_DIM=BLOCK_DIM, - ) diff --git a/vortex_torch/cache/triton_kernels/set_kv.py b/vortex_torch/cache/triton_kernels/set_kv.py index 6b289df..58468cc 100644 --- a/vortex_torch/cache/triton_kernels/set_kv.py +++ b/vortex_torch/cache/triton_kernels/set_kv.py @@ -241,3 +241,498 @@ def set_kv_buffer_fp8_launcher( v_scale=v_scale, ) + +# --------------------------------------------------------------------------- +# Dequantization kernels (read direction: quantized paged cache → bf16) +# --------------------------------------------------------------------------- + +@triton.jit +def _dequant_pages_kernel( + src, # quantized paged buffer flat + src_scale, # per-token scale buffer flat (int8 only) + dst, # bf16 destination buffer flat + page_indices, # int32 page indices to dequant + NUM_PAGES, + PAGE_SIZE: tl.constexpr, + HEAD_DIM: tl.constexpr, + BLOCK_DIM: tl.constexpr, + QUANT_TYPE: tl.constexpr, # 1: int8, 2: e4m3, 3: e5m2 + tensor_scale, # float: per-tensor scale (fp8 only) + COMPACT: tl.constexpr, # True: compact dst; False: in-place dst +): + """Unified dequant kernel for selected pages → bf16. + + QUANT_TYPE==1: load int8, multiply by per-token scale from src_scale. + QUANT_TYPE==2: load uint8, bitcast to float8e4nv, multiply by tensor_scale. + QUANT_TYPE==3: load uint8, bitcast to float8e5, multiply by tensor_scale. + COMPACT==True: dst offset uses page_idx (compact buffer). + COMPACT==False: dst offset uses global_page_id (in-place). 
+ """ + page_idx = tl.program_id(0) + token_idx = tl.program_id(1) + + if page_idx >= NUM_PAGES: + return + + global_page_id = tl.load(page_indices + page_idx) + dims = tl.arange(0, BLOCK_DIM) + mask_dim = dims < HEAD_DIM + + src_offset = (global_page_id * PAGE_SIZE + token_idx) * HEAD_DIM + dims + scale_offset = global_page_id * PAGE_SIZE + token_idx + + if QUANT_TYPE == 1: + val = tl.load(src + src_offset, mask=mask_dim, other=0).to(tl.float32) + scale = tl.load(src_scale + scale_offset).to(tl.float32) + result = (val * scale).to(tl.bfloat16) + elif QUANT_TYPE == 2: + raw = tl.load(src + src_offset, mask=mask_dim, other=0) + val = raw.to(tl.float8e4nv, bitcast=True).to(tl.float32) + result = (val * tensor_scale).to(tl.bfloat16) + else: # QUANT_TYPE == 3 + raw = tl.load(src + src_offset, mask=mask_dim, other=0) + val = raw.to(tl.float8e5, bitcast=True).to(tl.float32) + result = (val * tensor_scale).to(tl.bfloat16) + + if COMPACT: + dst_offset = (page_idx * PAGE_SIZE + token_idx) * HEAD_DIM + dims + else: + dst_offset = src_offset # same position as source + + tl.store(dst + dst_offset, result, mask=mask_dim) + + +def dequant_pages_to_bf16( + src: torch.Tensor, + src_scale: torch.Tensor, + page_indices: torch.Tensor, + page_size: int, + head_dim: int, + quant_type: int = 1, + tensor_scale: float = 1.0, + out: torch.Tensor = None, +) -> torch.Tensor: + """Dequant selected pages to compact bf16 buffer. + + Args: + quant_type: 1=int8 (per-token scale), 2=fp8 e4m3, 3=fp8 e5m2. + tensor_scale: per-tensor scale (fp8 only, ignored for int8). + out: optional pre-allocated bf16 buffer. 
+ """ + num_accessed_pages = page_indices.shape[0] + if num_accessed_pages == 0: + if out is not None: + return out[:0] + return torch.empty((0, page_size, head_dim), dtype=torch.bfloat16, device=src.device) + + if out is not None: + dst = out[:num_accessed_pages] + else: + dst = torch.empty( + (num_accessed_pages, page_size, head_dim), + dtype=torch.bfloat16, + device=src.device, + ) + + BLOCK_DIM = triton.next_power_of_2(head_dim) + + grid = (num_accessed_pages, page_size) + _dequant_pages_kernel[grid]( + src, src_scale, dst, page_indices, + NUM_PAGES=num_accessed_pages, + PAGE_SIZE=page_size, + HEAD_DIM=head_dim, + BLOCK_DIM=BLOCK_DIM, + QUANT_TYPE=quant_type, + tensor_scale=tensor_scale, + COMPACT=True, + ) + + return dst + + +def dequant_pages_to_bf16_inplace( + src: torch.Tensor, + src_scale: torch.Tensor, + dst: torch.Tensor, + page_indices: torch.Tensor, + page_size: int, + head_dim: int, + quant_type: int = 1, + tensor_scale: float = 1.0, +) -> None: + """Dequant selected pages in-place (same page positions in dst). + + Args: + quant_type: 1=int8 (per-token scale), 2=fp8 e4m3, 3=fp8 e5m2. + tensor_scale: per-tensor scale (fp8 only, ignored for int8). 
+ """ + num_pages = page_indices.shape[0] + if num_pages == 0: + return + + BLOCK_DIM = triton.next_power_of_2(head_dim) + + grid = (num_pages, page_size) + _dequant_pages_kernel[grid]( + src, src_scale, dst, page_indices, + NUM_PAGES=num_pages, + PAGE_SIZE=page_size, + HEAD_DIM=head_dim, + BLOCK_DIM=BLOCK_DIM, + QUANT_TYPE=quant_type, + tensor_scale=tensor_scale, + COMPACT=False, + ) + + +# --------------------------------------------------------------------------- +# Paged decode attention (unified quant_type-parameterized) +# --------------------------------------------------------------------------- + +_MIN_BLOCK_KV = 32 + + +@triton.jit +def _tanh(x): + return 2 * tl.sigmoid(2 * x) - 1 + + +@triton.jit +def _fwd_kernel_paged_decode_stage1( + Q, + K_Buffer, + V_Buffer, + K_Scale_Buffer, + V_Scale_Buffer, + sm_scale, + kv_indptr, + kv_indices, + last_page_len, + Att_Out, + Att_Lse, + num_kv_splits, + stride_qbs, + stride_qh, + stride_buf_kbs, + stride_buf_vbs, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + kv_group_num: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_DV: tl.constexpr, + BLOCK_N: tl.constexpr, + MIN_BLOCK_KV: tl.constexpr, + logit_cap: tl.constexpr, + Lk: tl.constexpr, + Lv: tl.constexpr, + PAGE_SIZE: tl.constexpr, + QUANT_TYPE: tl.constexpr, # 0: bf16, 1: int8, 2: e4m3, 3: e5m2 + tensor_scale, # per-tensor scale for fp8 +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + split_kv_id = tl.program_id(2) + + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_dv = tl.arange(0, BLOCK_DV) + mask_d = offs_d < Lk + mask_dv = offs_dv < Lv + + cur_batch_kv_start_idx = tl.load(kv_indptr + cur_batch) + cur_batch_num_pages = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx + cur_last_page_len = tl.load(last_page_len + cur_batch) + cur_batch_seq_len = (cur_batch_num_pages - 1) * PAGE_SIZE + cur_last_page_len + kv_splits = tl.load(num_kv_splits + cur_batch) + + off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d + + 
kv_len_per_split = ( + tl.cdiv(tl.cdiv(cur_batch_seq_len, kv_splits), MIN_BLOCK_KV) * MIN_BLOCK_KV + ) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) + + e_max = -float("inf") + e_sum = 0.0 + acc = tl.zeros([BLOCK_DV], dtype=tl.float32) + + if split_kv_end > split_kv_start: + q = tl.load(Q + off_q, mask=mask_d, other=0.0).to(tl.float32) + + for start_n in range(split_kv_start, split_kv_end, BLOCK_N): + offs_n = start_n + tl.arange(0, BLOCK_N) + mask_n = offs_n < split_kv_end + + page_indices_in_seq = offs_n // PAGE_SIZE + in_page_offsets = offs_n % PAGE_SIZE + page_ids = tl.load( + kv_indices + cur_batch_kv_start_idx + page_indices_in_seq, + mask=mask_n, other=0, + ) + kv_loc = page_ids * PAGE_SIZE + in_page_offsets + + # Load K with quant-type-dependent dequantization + offs_buf_k = kv_loc[:, None] * stride_buf_kbs + offs_d[None, :] + if QUANT_TYPE == 0: + k = tl.load( + K_Buffer + offs_buf_k, + mask=mask_n[:, None] & mask_d[None, :], other=0, + ).to(tl.float32) + elif QUANT_TYPE == 1: + k_int8 = tl.load( + K_Buffer + offs_buf_k, + mask=mask_n[:, None] & mask_d[None, :], other=0, + ).to(tl.float32) + k_scale = tl.load( + K_Scale_Buffer + kv_loc, mask=mask_n, other=1.0, + ).to(tl.float32) + k = k_int8 * k_scale[:, None] + elif QUANT_TYPE == 2: + raw = tl.load( + K_Buffer + offs_buf_k, + mask=mask_n[:, None] & mask_d[None, :], other=0, + ) + k = raw.to(tl.float8e4nv, bitcast=True).to(tl.float32) * tensor_scale + else: # QUANT_TYPE == 3 + raw = tl.load( + K_Buffer + offs_buf_k, + mask=mask_n[:, None] & mask_d[None, :], other=0, + ) + k = raw.to(tl.float8e5, bitcast=True).to(tl.float32) * tensor_scale + + qk = tl.sum(q[None, :] * k, 1) + qk *= sm_scale + + if logit_cap > 0: + qk = logit_cap * _tanh(qk / logit_cap) + + qk = tl.where(mask_n, qk, float("-inf")) + + # Load V with quant-type-dependent dequantization + offs_buf_v = kv_loc[:, None] * stride_buf_vbs + offs_dv[None, :] + if 
QUANT_TYPE == 0: + v = tl.load( + V_Buffer + offs_buf_v, + mask=mask_n[:, None] & mask_dv[None, :], other=0, + ).to(tl.float32) + elif QUANT_TYPE == 1: + v_int8 = tl.load( + V_Buffer + offs_buf_v, + mask=mask_n[:, None] & mask_dv[None, :], other=0, + ).to(tl.float32) + v_scale = tl.load( + V_Scale_Buffer + kv_loc, mask=mask_n, other=1.0, + ).to(tl.float32) + v = v_int8 * v_scale[:, None] + elif QUANT_TYPE == 2: + raw = tl.load( + V_Buffer + offs_buf_v, + mask=mask_n[:, None] & mask_dv[None, :], other=0, + ) + v = raw.to(tl.float8e4nv, bitcast=True).to(tl.float32) * tensor_scale + else: # QUANT_TYPE == 3 + raw = tl.load( + V_Buffer + offs_buf_v, + mask=mask_n[:, None] & mask_dv[None, :], other=0, + ) + v = raw.to(tl.float8e5, bitcast=True).to(tl.float32) * tensor_scale + + # Online softmax accumulation + n_e_max = tl.maximum(tl.max(qk, 0), e_max) + re_scale = tl.exp(e_max - n_e_max) + p = tl.exp(qk - n_e_max) + acc *= re_scale + acc += tl.sum(p[:, None] * v, 0) + + e_sum = e_sum * re_scale + tl.sum(p, 0) + e_max = n_e_max + + offs_mid_o = ( + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + + split_kv_id * stride_mid_os + + offs_dv + ) + + tl.store(Att_Out + offs_mid_o, acc / e_sum, mask=mask_dv) + + offs_mid_o_1 = ( + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + + split_kv_id * stride_mid_os + ) // Lv + + tl.store(Att_Lse + offs_mid_o_1, e_max + tl.log(e_sum)) + + +@triton.jit +def _fwd_kernel_paged_decode_stage2( + Mid_O, + Mid_O_1, + O, + kv_indptr, + last_page_len, + num_kv_splits, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_obs, + stride_oh, + MAX_KV_SPLITS: tl.constexpr, + MIN_BLOCK_KV: tl.constexpr, + BLOCK_DV: tl.constexpr, + Lv: tl.constexpr, + PAGE_SIZE: tl.constexpr, +): + """Stage 2: Reduce split outputs via log-sum-exp merge.""" + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + cur_batch_num_pages = tl.load(kv_indptr + cur_batch + 1) - tl.load(kv_indptr + cur_batch) + cur_last_page_len = 
tl.load(last_page_len + cur_batch) + cur_batch_seq_len = (cur_batch_num_pages - 1) * PAGE_SIZE + cur_last_page_len + kv_splits = tl.load(num_kv_splits + cur_batch) + + offs_d = tl.arange(0, BLOCK_DV) + mask_d = offs_d < Lv + + e_sum = 0.0 + e_max = -float("inf") + acc = tl.zeros([BLOCK_DV], dtype=tl.float32) + + offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d + offs_logic = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh) // Lv + kv_len_per_split = ( + tl.cdiv(tl.cdiv(cur_batch_seq_len, kv_splits), MIN_BLOCK_KV) * MIN_BLOCK_KV + ) + + for split_kv_id in range(0, MAX_KV_SPLITS): + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) + + if split_kv_end > split_kv_start: + tv = tl.load( + Mid_O + offs_v + split_kv_id * stride_mid_os, mask=mask_d, other=0.0 + ) + tlogic = tl.load(Mid_O_1 + offs_logic + split_kv_id * stride_mid_os // Lv) + n_e_max = tl.maximum(tlogic, e_max) + + old_scale = tl.exp(e_max - n_e_max) + acc *= old_scale + exp_logic = tl.exp(tlogic - n_e_max) + acc += exp_logic * tv + + e_sum = e_sum * old_scale + exp_logic + e_max = n_e_max + + tl.store( + O + cur_batch * stride_obs + cur_head * stride_oh + offs_d, + acc / e_sum, + mask=mask_d, + ) + + +def paged_decode( + q: torch.Tensor, + k_buffer: torch.Tensor, + v_buffer: torch.Tensor, + o: torch.Tensor, + kv_indptr: torch.Tensor, + kv_indices: torch.Tensor, + last_page_len: torch.Tensor, + num_kv_splits: torch.Tensor, + max_kv_splits: int, + sm_scale: float, + page_size: int, + quant_type: int = 0, + k_scale_buffer: torch.Tensor = None, + v_scale_buffer: torch.Tensor = None, + tensor_scale: float = 1.0, + logit_cap: float = 0.0, + att_out: torch.Tensor = None, + att_lse: torch.Tensor = None, +): + """Unified paged decode attention. 
+ + Args: + quant_type: Controls K/V loading: + 0: bf16 (k_scale_buffer/v_scale_buffer unused) + 1: int8 with per-token scales (k_scale_buffer/v_scale_buffer required) + 2: fp8 e4m3 with per-tensor scale (tensor_scale required) + 3: fp8 e5m2 with per-tensor scale (tensor_scale required) + """ + batch = q.shape[0] + head_num = q.shape[1] + Lk = q.shape[2] + Lv = Lk + + BLOCK_DMODEL = triton.next_power_of_2(Lk) + BLOCK_DV = triton.next_power_of_2(Lv) + BLOCK_N = 128 + MAX_KV_SPLITS = max_kv_splits + + kv_group_num = head_num + num_warps = 4 + + if att_out is None: + att_out = torch.empty( + (batch, head_num, MAX_KV_SPLITS, Lv), + dtype=torch.float32, device=q.device, + ) + else: + att_out = att_out[:batch] + if att_lse is None: + att_lse = torch.empty( + (batch, head_num, MAX_KV_SPLITS), + dtype=torch.float32, device=q.device, + ) + else: + att_lse = att_lse[:batch] + + stride_buf_kbs = k_buffer.shape[-1] + stride_buf_vbs = v_buffer.shape[-1] + + # Use dummy tensors for scale buffers when not needed + _k_scale = k_scale_buffer if k_scale_buffer is not None else k_buffer + _v_scale = v_scale_buffer if v_scale_buffer is not None else v_buffer + + grid_stage1 = (batch, head_num, MAX_KV_SPLITS) + _fwd_kernel_paged_decode_stage1[grid_stage1]( + q, k_buffer, v_buffer, + _k_scale, _v_scale, + sm_scale, kv_indptr, kv_indices, last_page_len, + att_out, att_lse, num_kv_splits, + q.stride(0), q.stride(1), + stride_buf_kbs, stride_buf_vbs, + att_out.stride(0), att_out.stride(1), att_out.stride(2), + kv_group_num=kv_group_num, + BLOCK_DMODEL=BLOCK_DMODEL, + BLOCK_DV=BLOCK_DV, + BLOCK_N=BLOCK_N, + MIN_BLOCK_KV=_MIN_BLOCK_KV, + logit_cap=logit_cap, + num_warps=num_warps, + num_stages=2, + Lk=Lk, Lv=Lv, + PAGE_SIZE=page_size, + QUANT_TYPE=quant_type, + tensor_scale=tensor_scale, + ) + + grid_stage2 = (batch, head_num) + _fwd_kernel_paged_decode_stage2[grid_stage2]( + att_out, att_lse, o, + kv_indptr, last_page_len, num_kv_splits, + att_out.stride(0), att_out.stride(1), 
att_out.stride(2), + o.stride(0), o.stride(1), + MAX_KV_SPLITS=MAX_KV_SPLITS, + MIN_BLOCK_KV=_MIN_BLOCK_KV, + BLOCK_DV=BLOCK_DV, + Lv=Lv, + PAGE_SIZE=page_size, + num_warps=4, + num_stages=2, + ) + From 19c7fcc573019b87a17f7629eb3602ce8dfd1752 Mon Sep 17 00:00:00 2001 From: UED Date: Tue, 31 Mar 2026 05:00:36 +0000 Subject: [PATCH 13/22] =?UTF-8?q?Add=20TopK=20benchmarking=20suite=20and?= =?UTF-8?q?=20related=20scripts=20-=20Introduced=20a=20comprehensive=20ben?= =?UTF-8?q?chmarking=20suite=20for=20TopK=20kernel=20variants,=20measuring?= =?UTF-8?q?=20kernel-level=20latency.=20-=20Added=20scripts=20for=20offlin?= =?UTF-8?q?e=20calibration=20of=20TopK=20mapping=20modes,=20including:#=20?= =?UTF-8?q?0:=20None=20=20=20=20=20=20=20=20=20=20=20=E2=80=94=20original?= =?UTF-8?q?=20fp16=20bit-pattern=20bucketing=20#=201:=20LUT=20CDF=20=20=20?= =?UTF-8?q?=20=20=20=20=20=E2=80=94=20LUT-based=20CDF=20equalization=20(ca?= =?UTF-8?q?librated)=20#=202:=20Quantile=20=20=20=20=20=20=20=E2=80=94=20p?= =?UTF-8?q?iecewise-linear=20quantile=20mapping=20(calibrated)=20#=203:=20?= =?UTF-8?q?Power=20=20=20=20=20=20=20=20=20=20=E2=80=94=20y=20=3D=20sign(x?= =?UTF-8?q?)=20*=20|x|^p=20#=204:=20Log=20=20=20=20=20=20=20=20=20=20=20?= =?UTF-8?q?=20=E2=80=94=20y=20=3D=20sign(x)=20*=20log(|x|=20+=201)=20#=205?= =?UTF-8?q?:=20Index=20Cache=20=20=20=20=E2=80=94=20reuse=20previous=20lay?= =?UTF-8?q?er's=20indices=20#=206:=20Asinh=20=20=20=20=20=20=20=20=20=20?= =?UTF-8?q?=E2=80=94=20y=20=3D=20asinh(beta=20*=20x)=20#=207:=20Log1p=20?= =?UTF-8?q?=20=20=20=20=20=20=20=20=20=E2=80=94=20y=20=3D=20sign(x)=20*=20?= =?UTF-8?q?log1p(alpha=20*=20|x|)=20#=208:=20Trunc8=20=20=20=20=20=20=20?= =?UTF-8?q?=20=20=E2=80=94=20bf16=20upper-8-bit=20bucketing=20-=20=20Addin?= =?UTF-8?q?g=20various=20remap=20functions=20for=20the=20bucket=20sort=20i?= =?UTF-8?q?n=20sglang=20topk=20kernel,=20with=20evaluation=20and=20visuali?= =?UTF-8?q?zation=20scripts.=20-=20Implemented=20analysis=20tools=20for=20?= 
=?UTF-8?q?TopK=20distribution=20profiling.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 3 + benchmarks/README.md | 89 + benchmarks/__init__.py | 0 benchmarks/analyze_topk_distribution.py | 479 +++++ benchmarks/autotune_topk_mapping.py | 378 ++++ benchmarks/bench_topk.py | 587 +++++ benchmarks/calibrate_topk.py | 153 ++ benchmarks/greedy_layer_search.py | 117 + benchmarks/profile_topk_distribution.py | 132 ++ csrc/register.cc | 19 +- csrc/register.h | 19 +- csrc/topk.cu | 2 +- csrc/topk_mapping.cuh | 148 ++ csrc/topk_sglang.cu | 1905 ++++++++++------- examples/README.md | 399 ++++ examples/run_distribution_analysis.sh | 141 ++ examples/run_distribution_analysis_new.sh | 150 ++ examples/run_topk_benchmark.sh | 294 +++ examples/verify_algo.py | 78 +- examples/verify_algo.sh | 5 +- examples/verify_algo_quant.sh | 18 +- examples/verify_algo_topk_mapping.sh | 175 ++ .../verify_algo_topk_mapping_indexcache.sh | 45 + examples/verify_algo_topk_mapping_new.sh | 128 ++ third_party/sglang | 2 +- vortex_torch/indexer/context.py | 26 +- vortex_torch/indexer/output_func.py | 60 +- 27 files changed, 4698 insertions(+), 854 deletions(-) create mode 100644 benchmarks/README.md create mode 100644 benchmarks/__init__.py create mode 100644 benchmarks/analyze_topk_distribution.py create mode 100644 benchmarks/autotune_topk_mapping.py create mode 100644 benchmarks/bench_topk.py create mode 100644 benchmarks/calibrate_topk.py create mode 100644 benchmarks/greedy_layer_search.py create mode 100644 benchmarks/profile_topk_distribution.py create mode 100644 csrc/topk_mapping.cuh create mode 100644 examples/README.md create mode 100755 examples/run_distribution_analysis.sh create mode 100755 examples/run_distribution_analysis_new.sh create mode 100755 examples/run_topk_benchmark.sh create mode 100644 examples/verify_algo_topk_mapping.sh create mode 100644 examples/verify_algo_topk_mapping_indexcache.sh create mode 100644 
examples/verify_algo_topk_mapping_new.sh diff --git a/.gitignore index 931c8ca..6a904c7 100644 --- a/.gitignore +++ b/.gitignore @@ -236,3 +236,6 @@ compile_commands.json # Rust lib Cargo.lock + +/examples/results +*.npy \ No newline at end of file diff --git a/benchmarks/README.md new file mode 100644 index 0000000..e390344 --- /dev/null +++ b/benchmarks/README.md @@ -0,0 +1,89 @@ +# TopK Kernel Benchmarking Suite + +Standalone benchmarking for Vortex's topk kernel variants, measuring kernel-level latency isolated from the full SGLang inference pipeline. + +## Kernel Variants + +| Kernel | Description | +|--------|-------------| +| `naive` | CUB radix sort (bf16 only) | +| `sglang_m0` | Two-stage hierarchical radix sort, no mapping | +| `sglang_m1` | + LUT mapping (requires `--lut-path`) | +| `sglang_m2` | + Quantile mapping (requires `--quantiles-path`) | +| `sglang_m3` | + Power mapping (configurable via `--mapping-power`) | +| `sglang_m4` | + Log mapping | + +## Quick Start + +```bash +# Activate environment +source /scr/dataset/yuke/xinrui/uv_env/vortex/bin/activate + +# Quick single-config test +python benchmarks/bench_topk.py \ + --batch-sizes 8 \ + --seq-lens 4096 \ + --topk-vals 30 \ + --num-kv-heads 2 \ + --repeat 200 + +# Sweep with histogram analysis +python benchmarks/bench_topk.py \ + --batch-sizes 4 8 16 \ + --seq-lens 2048 4096 8192 \ + --topk-vals 30 64 \ + --num-kv-heads 2 \ + --repeat 100 \ + --histogram + +# Full sweep with JSON output +python benchmarks/bench_topk.py \ + --output-json benchmarks/results.json \ + --histogram +``` + +## CLI Options + +| Argument | Default | Description | +|----------|---------|-------------| +| `--batch-sizes` | 1 4 8 16 32 64 | Batch sizes to sweep | +| `--seq-lens` | 1024 2048 4096 8192 | Sequence lengths to sweep | +| `--topk-vals` | 16 30 64 | TopK values to sweep | +| `--num-kv-heads` | 2 4 8 | KV head counts to sweep | +| `--page-size` | 16 | Tokens per page
| +| `--reserved-bos` | 1 | Reserved BOS pages | +| `--reserved-eos` | 2 | Reserved EOS pages | +| `--score-dtype` | bfloat16 | Score tensor dtype (bfloat16 or float32) | +| `--distributions` | normal lognormal uniform | Score distributions to test | +| `--warmup` | 10 | Warmup iterations | +| `--repeat` | 100 | Timed iterations | +| `--mapping-power` | 0.5 | Power parameter for mode=3 | +| `--lut-path` | None | Path to .npy uint8[256] LUT for mode=1 | +| `--quantiles-path` | None | Path to .npy float32[256] quantiles for mode=2 | +| `--output-json` | None | Save results to JSON file | +| `--filter-kernels` | None | Only run specific kernels (e.g., `naive sglang_m0`) | +| `--histogram` | False | Collect bin distribution statistics | + +## Histogram Analysis + +When `--histogram` is passed, each config additionally runs `topk_profile_histogram` and reports: + +- **max/mean ratio**: Peak bin count divided by average (lower = more uniform) +- **Gini coefficient**: Inequality measure of bin distribution (0 = perfectly uniform) +- **nonzero_bins**: How many of the 256 bins received any values + +This shows whether mapping modes improve bin uniformity for a given score distribution. 
+ +## Output Format + +``` +TopK Kernel Benchmark Results +GPU: NVIDIA H100 80GB HBM3 | SM count: 132 + +bs=8 | seq=4096 | topk=30 | heads=2 | pages/seg=256 | dist=normal + naive : 0.0420ms (median) +/- 0.0030ms [min=0.0390, max=0.0510] + sglang mode=0 : 0.0310ms (median) +/- 0.0020ms [min=0.0290, max=0.0380] + sglang mode=3 : 0.0330ms (median) +/- 0.0020ms [min=0.0300, max=0.0400] + sglang mode=4 : 0.0320ms (median) +/- 0.0020ms [min=0.0300, max=0.0390] + histogram stats : max/mean=3.99 gini=0.568 nonzero_bins=70/256 +``` diff --git a/benchmarks/__init__.py b/benchmarks/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/benchmarks/analyze_topk_distribution.py b/benchmarks/analyze_topk_distribution.py new file mode 100644 index 0000000..7d94466 --- /dev/null +++ b/benchmarks/analyze_topk_distribution.py @@ -0,0 +1,479 @@ +""" +TopK distribution analysis and visualization. + +Loads profiling data from: + - profile_topk_distribution.py output (.npz): raw histograms, LUT tables + - bench_topk.py output (.json): benchmark results + per-mode histogram data + +Produces visualization plots for evaluating mapping mode effectiveness. 
+ +Usage: + python scripts/analyze_topk_distribution.py \ + --bench-json bench_hitrate.json \ + --output-dir plots/ + + python scripts/analyze_topk_distribution.py \ + --profile-npz profile_output.npz \ + --bench-json bench_hitrate.json \ + --output-dir plots/ --max-segments 8 +""" + +import argparse +import json +import os +from typing import Optional + +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import matplotlib.colors as mcolors +import numpy as np + +# Canonical mapping mode names — shared across all profiling/analysis tools +MAPPING_MODE_NAMES = { + 0: "None", + 1: "LUT CDF", + 2: "Quantile", + 3: "Power", + 4: "Log", + 5: "Index Cache", + 6: "Asinh", + 7: "Log1p", + 8: "Trunc8", +} + +MAPPING_MODE_FORMULAS = { + 0: "None (fp16 bucketing)", + 1: "LUT CDF (calibrated)", + 2: "Quantile (calibrated)", + 3: "Power: sign(x)*|x|^p", + 4: "Log: sign(x)*log(|x|+1)", + 5: "Index Cache", + 6: "Asinh: asinh(beta*x)", + 7: "Log1p: sign(x)*log1p(alpha*|x|)", + 8: "Trunc8: bf16 upper-8-bit bucketing", +} + + +def _mode_key_to_display(mode_key: str) -> str: + """Convert a mode key like 'mode_3' or 'mode_3_Power' to a display name.""" + # Handle new format: "mode_3_Power" + parts = mode_key.split("_", 2) + if len(parts) >= 3: + return parts[2] # e.g. "Power" + # Handle old format: "mode_3" + try: + mode_num = int(parts[1]) + return MAPPING_MODE_NAMES.get(mode_num, mode_key) + except (IndexError, ValueError): + return mode_key + + +def _mode_key_to_number(mode_key: str) -> int: + """Extract the mode number from a key like 'mode_3' or 'mode_3_Power'.""" + parts = mode_key.split("_") + try: + return int(parts[1]) + except (IndexError, ValueError): + return -1 + + +def compute_per_segment_stats(histograms: np.ndarray) -> dict: + """Compute per-row Gini coefficient and max/mean ratio. 
def plot_bin_distribution(histograms: np.ndarray, output_dir: str, max_segments: int = 4):
    """Plot 256-bin bar chart per segment (first N segments)."""
    shown = min(histograms.shape[0], max_segments)
    for seg in range(shown):
        fig, ax = plt.subplots(figsize=(12, 4))
        ax.bar(range(256), histograms[seg], width=1.0, color="steelblue", edgecolor="none")
        ax.set_xlabel("Bin")
        ax.set_ylabel("Count")
        ax.set_title(f"Segment {seg}: 256-bin histogram")
        ax.set_xlim(-1, 256)
        fig.tight_layout()
        fig.savefig(os.path.join(output_dir, f"bin_dist_seg_{seg}.png"), dpi=150)
        plt.close(fig)
    print(f"  Saved {shown} bin distribution plots")


def plot_bin_heatmap(histograms: np.ndarray, output_dir: str):
    """Heatmap: segments x bins, LogNorm colormap."""
    height = max(4, histograms.shape[0] * 0.15 + 1)
    fig, ax = plt.subplots(figsize=(14, height))
    # Shift every cell by +1 so zero counts survive the log color scale.
    shifted = histograms.astype(np.float64) + 1
    im = ax.imshow(
        shifted,
        aspect="auto",
        cmap="viridis",
        norm=mcolors.LogNorm(vmin=1, vmax=shifted.max()),
        interpolation="nearest",
    )
    ax.set_xlabel("Bin")
    ax.set_ylabel("Segment")
    ax.set_title("Bin distribution heatmap (log scale)")
    fig.colorbar(im, ax=ax, label="Count + 1")
    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, "bin_heatmap.png"), dpi=150)
    plt.close(fig)
    print("  Saved bin_heatmap.png")
def plot_before_after_mapping(
    raw_histograms: np.ndarray,
    lut_table: np.ndarray,
    output_dir: str,
    max_segments: int = 4,
):
    """Side-by-side: raw histogram vs. LUT-remapped histogram."""
    shown = min(raw_histograms.shape[0], max_segments)
    # Destination bin per source bin, as plain integer indices.
    dest_bins = lut_table.astype(np.int64)
    for seg in range(shown):
        raw = raw_histograms[seg]
        # Push each raw bin's count into the bin the LUT maps it to;
        # np.add.at accumulates duplicates correctly.
        remapped = np.zeros(256, dtype=np.float64)
        np.add.at(remapped, dest_bins[:256], raw[:256])

        fig, axes = plt.subplots(1, 2, figsize=(16, 4), sharey=True)
        axes[0].bar(range(256), raw, width=1.0, color="steelblue", edgecolor="none")
        axes[0].set_title(f"Segment {seg}: Raw (mode=0)")
        axes[0].set_xlabel("Bin")
        axes[0].set_ylabel("Count")

        axes[1].bar(range(256), remapped, width=1.0, color="darkorange", edgecolor="none")
        axes[1].set_title(f"Segment {seg}: After LUT remap")
        axes[1].set_xlabel("Bin")

        fig.tight_layout()
        fig.savefig(os.path.join(output_dir, f"mapping_comparison_{seg}.png"), dpi=150)
        plt.close(fig)
    print(f"  Saved {shown} mapping comparison plots")


def plot_summary_table(
    histograms: np.ndarray,
    mode_stats_data: Optional[dict],
    output_dir: str,
):
    """Per-segment stats table: Gini, max/mean, resolution rate."""
    stats = compute_per_segment_stats(histograms)
    num_seg = histograms.shape[0]

    col_labels = ["Segment", "Gini", "Max/Mean"]
    cell_data = [
        [str(seg), f"{stats['gini'][seg]:.3f}", f"{stats['max_mean'][seg]:.2f}"]
        for seg in range(num_seg)
    ]

    fig, ax = plt.subplots(figsize=(6, max(2, num_seg * 0.4 + 1)))
    ax.axis("off")
    table = ax.table(cellText=cell_data, colLabels=col_labels, loc="center", cellLoc="center")
    table.auto_set_font_size(False)
    table.set_fontsize(9)
    table.scale(1.0, 1.3)
    ax.set_title("Per-segment distribution stats", fontsize=11, pad=10)
    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, "summary_table.png"), dpi=150, bbox_inches="tight")
    plt.close(fig)
    print("  Saved summary_table.png")
def plot_distribution_comparison(dist_histograms: dict, output_dir: str, suffix: str = "", title: str = ""):
    """Overlay 256-bin distributions for different data sources (uniform, normal, real).

    Args:
        dist_histograms: {"uniform": [256], "normal": [256], "real": [256], ...}
        output_dir: output directory for the plot
        suffix: optional suffix for output filename (e.g. "_m0")
        title: optional custom title for the plot
    """
    names = list(dist_histograms.keys())
    if not names:
        print("  No distribution histograms to compare")
        return
    n = len(names)

    # One panel per distribution; squeeze=False keeps axes 2-D for uniform indexing.
    fig, axes = plt.subplots(1, n, figsize=(6 * n, 4), squeeze=False)
    axes = axes[0]

    for col, name in enumerate(names):
        counts = np.array(dist_histograms[name], dtype=np.float64)
        panel = axes[col]
        panel.bar(range(256), counts, width=1.0, color="steelblue", edgecolor="none")
        panel.set_xlabel("Bucket")
        panel.set_ylabel("Count")
        panel.set_xlim(-1, 256)
        panel.set_title(name)

        # Annotate each panel with uniformity stats over the occupied bins.
        nonzero = counts[counts > 0]
        if len(nonzero) > 0:
            mean_val = nonzero.mean()
            max_val = nonzero.max()
            max_mean = max_val / mean_val if mean_val > 0 else 0.0
            sorted_vals = np.sort(nonzero)
            nn = len(sorted_vals)
            index = np.arange(1, nn + 1, dtype=np.float64)
            gini = max(0.0, 2.0 * (index * sorted_vals).sum() / (nn * sorted_vals.sum()) - (nn + 1) / nn)
            nz_bins = int(len(nonzero))
        else:
            max_mean = gini = 0.0
            nz_bins = 0

        stats_text = f"gini={gini:.3f}\nmax/mean={max_mean:.2f}\nbins={nz_bins}/256"
        panel.text(0.97, 0.95, stats_text, transform=panel.transAxes,
                   fontsize=8, verticalalignment="top", horizontalalignment="right",
                   bbox=dict(boxstyle="round,pad=0.3", facecolor="wheat", alpha=0.7))

    fig.suptitle(title or "Bucket Distribution Comparison", fontsize=13)
    fig.tight_layout()
    fname = f"distribution_comparison{suffix}.png"
    fig.savefig(os.path.join(output_dir, fname), dpi=150)
    plt.close(fig)
    print(f"  Saved {fname}")
def save_bucket_table(dist_histograms: dict, output_dir: str, filename: str = "bucket_counts.csv"):
    """Write a CSV table listing the count per bucket for each distribution.

    Columns: bucket, dist1, dist2, ... (256 rows, one per bucket).
    """
    import csv

    names = list(dist_histograms.keys())
    if not names:
        return

    path = os.path.join(output_dir, filename)
    with open(path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["bucket"] + names)
        writer.writerows(
            [b] + [int(dist_histograms[n][b]) for n in names] for b in range(256)
        )

    # Compact stdout summary: the 20 hottest buckets per distribution.
    print(f"  Saved {path}")
    for name in names:
        counts = np.array(dist_histograms[name], dtype=np.int64)
        total = counts.sum()
        hottest = np.argsort(counts)[::-1][:20]
        print(f"  [{name}] total={total} top-20 hottest buckets:")
        for rank, idx in enumerate(hottest):
            if counts[idx] == 0:
                # argsort is descending, so the first zero ends the hot list.
                break
            pct = counts[idx] / total * 100 if total > 0 else 0
            print(f"    #{rank+1:2d} bucket {idx:3d}: {counts[idx]:>10d} ({pct:5.1f}%)")


def plot_mapping_mode_comparison(mode_stats_data: dict, output_dir: str):
    """Grouped bar chart comparing modes on gini and max/mean."""
    modes = sorted(mode_stats_data.keys())
    if not modes:
        print("  No histogram data to plot mode comparison")
        return

    mode_labels = []
    for key in modes:
        label = _mode_key_to_display(key)
        param = mode_stats_data[key].get("param")
        if param:
            label = f"{label} ({param})"
        mode_labels.append(label)
    ginis = [mode_stats_data[key]["gini"] for key in modes]
    max_means = [mode_stats_data[key]["max_mean_ratio"] for key in modes]

    x = np.arange(len(modes))
    width = 0.3

    fig, ax1 = plt.subplots(figsize=(10, 5))
    ax2 = ax1.twinx()  # second y-axis: the two metrics live on different scales

    ax1.bar(x - width / 2, ginis, width, label="Gini", color="darkorange")
    ax2.bar(x + width / 2, max_means, width, label="Max/Mean", color="seagreen", alpha=0.7)

    ax1.set_xlabel("Mapping Mode")
    ax1.set_ylabel("Gini")
    ax2.set_ylabel("Max/Mean Ratio")
    ax1.set_xticks(x)
    ax1.set_xticklabels(mode_labels, rotation=15, ha="right")
    ax1.set_ylim(0, 1.1)
    ax1.set_title("Mapping Mode Comparison")

    # Merge both axes' legend entries into a single box.
    lines1, labels1 = ax1.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax1.legend(lines1 + lines2, labels1 + labels2, loc="upper right")

    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, "mode_comparison.png"), dpi=150)
    plt.close(fig)
    print("  Saved mode_comparison.png")
def main():
    parser = argparse.ArgumentParser(description="Analyze TopK bucket sort distribution")
    parser.add_argument("--profile-npz", type=str, default=None,
                        help="Path to .npz from profile_topk_distribution.py")
    parser.add_argument("--bench-json", type=str, default=None,
                        help="Path to JSON from bench_topk.py")
    parser.add_argument("--output-dir", type=str, default="plots",
                        help="Directory for output plots")
    parser.add_argument("--max-segments", type=int, default=4,
                        help="Max segments for per-segment plots")
    parser.add_argument("--real-histograms", type=str, default=None,
                        help="Path to .npy raw_histograms from calibrate_topk.py (real-data bucket counts)")
    args = parser.parse_args()

    # At least one data source must be supplied or there is nothing to plot.
    if args.profile_npz is None and args.bench_json is None and args.real_histograms is None:
        parser.error("At least one of --profile-npz, --bench-json, or --real-histograms is required")

    os.makedirs(args.output_dir, exist_ok=True)
    print(f"Output directory: {args.output_dir}")

    raw_histograms = None
    lut_table = None
    mode_stats_data = None

    # ---- Load profile data (.npz) -------------------------------------
    if args.profile_npz:
        print(f"\nLoading profile data from {args.profile_npz}")
        data = np.load(args.profile_npz, allow_pickle=True)
        if "raw_histograms" in data:
            raw_histograms = data["raw_histograms"]
            print(f"  raw_histograms: {raw_histograms.shape}")
        if "aggregate_lut" in data:
            lut_table = data["aggregate_lut"]
            print(f"  aggregate_lut: {lut_table.shape}")
        elif "lut_tables" in data:
            # Fall back to the first per-segment LUT when no aggregate exists.
            lut_table = data["lut_tables"]
            if lut_table.ndim > 1:
                lut_table = lut_table[0]
            print(f"  lut_table: {lut_table.shape}")

    # ---- Load benchmark data (.json) ----------------------------------
    dist_histograms = {}   # {distribution_name: [256] counts} for comparison plot
    mode_histograms = {}   # {mode_key: {dist_name: [256]}} for per-mode plots

    if args.bench_json:
        print(f"\nLoading benchmark data from {args.bench_json}")
        with open(args.bench_json) as f:
            bench_data = json.load(f)

        if bench_data and isinstance(bench_data, list):
            # First config entry drives the mode-comparison visualization.
            entry = bench_data[0]
            if "histograms" in entry:
                mode_stats_data = entry["histograms"]
                print(f"  Histogram modes: {list(mode_stats_data.keys())}")

            # Raw (unmapped) counts, first occurrence per distribution wins.
            for entry in bench_data:
                dist_name = entry.get("distribution", "unknown")
                hist_data = entry.get("histogram", {})
                if "raw_counts" in hist_data and dist_name not in dist_histograms:
                    dist_histograms[dist_name] = hist_data["raw_counts"]
                    print(f"  Loaded histogram for distribution: {dist_name}")

            # Per-mode counts, keyed by (mode, distribution), first wins.
            mode_histograms = {}
            for entry in bench_data:
                dist_name = entry.get("distribution", "unknown")
                histograms_data = entry.get("histograms", {})
                for mode_key, mode_data in histograms_data.items():
                    if isinstance(mode_data, dict) and "raw_counts" in mode_data:
                        if mode_key not in mode_histograms:
                            mode_histograms[mode_key] = {}
                        if dist_name not in mode_histograms[mode_key]:
                            mode_histograms[mode_key][dist_name] = mode_data["raw_counts"]
            if mode_histograms:
                print(f"  Loaded per-mode histograms for: {sorted(mode_histograms.keys())}")

    # ---- Load real-data histograms from .npy (calibrate_topk.py output)
    real_counts = None
    if args.real_histograms:
        print(f"\nLoading real-data histograms from {args.real_histograms}")
        real_hists = np.load(args.real_histograms)  # [num_samples, 256]
        real_counts = real_hists.sum(axis=0).tolist()  # aggregate across samples
        dist_histograms["real"] = real_counts
        print(f"  real_histograms shape: {real_hists.shape}, aggregated to [256]")

    # ---- Generate plots -----------------------------------------------
    if raw_histograms is not None:
        print("\nGenerating histogram plots...")
        plot_bin_distribution(raw_histograms, args.output_dir, args.max_segments)
        plot_bin_heatmap(raw_histograms, args.output_dir)
        plot_summary_table(raw_histograms, mode_stats_data, args.output_dir)

        if lut_table is not None:
            print("\nGenerating before/after mapping comparison...")
            plot_before_after_mapping(raw_histograms, lut_table, args.output_dir, args.max_segments)

    if mode_stats_data is not None:
        print("\nGenerating mode comparison plot...")
        plot_mapping_mode_comparison(mode_stats_data, args.output_dir)

    if dist_histograms:
        print("\nGenerating distribution comparison plot (raw/unmapped)...")
        plot_distribution_comparison(dist_histograms, args.output_dir)
        print("\nSaving bucket count table (raw/unmapped)...")
        save_bucket_table(dist_histograms, args.output_dir)

    # Per-mode distribution plots and tables
    if mode_histograms:
        print("\nGenerating per-mode distribution plots and tables...")
        for mode_key in sorted(mode_histograms):
            mname = _mode_key_to_display(mode_key)
            mode_num = _mode_key_to_number(mode_key)
            mformula = MAPPING_MODE_FORMULAS.get(mode_num, mname)
            # Include hyperparameter value in title if available.
            param_str = ""
            if mode_stats_data and mode_key in mode_stats_data:
                param = mode_stats_data[mode_key].get("param")
                if param:
                    param_str = f" [{param}]"
            mode_suffix = mname.lower().replace(" ", "_")
            plot_distribution_comparison(
                mode_histograms[mode_key], args.output_dir,
                suffix=f"_{mode_suffix}",
                title=f"Bucket Distribution — {mname}{param_str} ({mformula})",
            )
            save_bucket_table(
                mode_histograms[mode_key], args.output_dir,
                filename=f"bucket_counts_{mode_suffix}.csv",
            )

    print(f"\nDone. All outputs saved to {args.output_dir}/")
if __name__ == "__main__":
    main()
"""
Auto-tuner for TopK mapping hyperparameters.

Sweeps all (mode, hyperparameter) combinations using the topk_hit_rate
kernel and ranks by Stage 1 resolution rate.

Supports real-data score distributions via --real-histograms: loads the
raw_histograms.npy from calibration and synthesizes score tensors that
match the real bin distribution (by reversing the convert_to_uint8 mapping).

Sweep grid:
    - Mode 3 (power): p in [0.1, 0.25, 0.75, 0.9]
    - Mode 6 (asinh): beta in [0.1, 0.5, 1, 2, 4]
    - Mode 7 (log1p): alpha in [0.1, 0.5, 0.75, 1, 2, 4, 8]
    - Baselines: mode 0 (none), mode 4 (log)

Usage:
    python benchmarks/autotune_topk_mapping.py --topk-val 30 --real-histograms calibration/raw_histograms.npy
    python benchmarks/autotune_topk_mapping.py --topk-val 30 --output-json results.json
"""

import argparse
import json
import math
from typing import List

import numpy as np
import torch

from bench_topk import make_topk_inputs, compute_histogram_stats
from vortex_torch_C import topk_profile_histogram


# (mode, param_name, param_values) — the parametric mapping modes to sweep.
SWEEP_GRID = {
    3: ("power_exp", [0.1, 0.25, 0.75, 0.9]),
    6: ("beta", [0.1, 0.5, 1.0, 2.0, 4.0]),
    7: ("alpha", [0.1, 0.5, 0.75, 1.0, 2.0, 4.0, 8.0]),
}
# Non-parametric modes, evaluated once with a placeholder power.
BASELINES = {
    0: ("none", 0.5),
    4: ("log", 0.5),
}
MODE_NAMES = {
    0: "none",
    3: "power",
    4: "log",
    6: "asinh",
    7: "log1p",
}


def _key_to_fp16(key: int) -> np.float16:
    """Invert the convert_to_uint8 sign-flip for a single 16-bit key."""
    # Forward mapping stored positives with the sign bit forced on and
    # negatives bit-inverted; undo whichever branch applies.
    if key >= 0x8000:
        bits = key & 0x7FFF
    else:
        bits = (~key) & 0xFFFF
    return np.array([bits], dtype=np.uint16).view(np.float16)[0]
def build_bin_range_table():
    """Build per-bin (lo, hi) fp16 value tables by iterating all 65536 fp16 bit patterns.

    For each fp16 value, compute its bin via convert_to_uint8 logic, then track
    the min/max fp16 value that lands in each bin.

    Returns:
        (bin_lo, bin_hi): two [256] float32 arrays — the min and max fp16 values per bin.
    """
    # Every possible fp16 bit pattern, reinterpreted as fp16 values.
    patterns = np.arange(65536, dtype=np.uint16)
    values = patterns.view(np.float16)

    # Sign-flip exactly as convert_to_uint8 does; bin = high byte of the key.
    keys = np.where(
        (patterns & 0x8000).astype(bool),
        (~patterns).astype(np.uint16),
        patterns | np.uint16(0x8000),
    )
    bins = (keys >> 8).astype(np.uint8)

    # Work in float32 and drop NaN/Inf patterns before taking min/max.
    values_f32 = values.astype(np.float32)
    finite = np.isfinite(values_f32)

    bin_lo = np.full(256, np.inf, dtype=np.float32)
    bin_hi = np.full(256, -np.inf, dtype=np.float32)

    for b in range(256):
        hits = (bins == b) & finite
        if hits.any():
            members = values_f32[hits]
            bin_lo[b] = members.min()
            bin_hi[b] = members.max()

    # Bins never produced by a finite fp16 value: collapse to the key midpoint.
    empty = bin_lo > bin_hi
    for b in np.where(empty)[0]:
        mid_key = (int(b) << 8) | 0x80
        mid = float(_key_to_fp16(mid_key))
        bin_lo[b] = mid
        bin_hi[b] = mid

    return bin_lo, bin_hi


def scores_from_histogram(
    histogram: np.ndarray,
    total_pages: int,
    device: str = "cuda",
) -> torch.Tensor:
    """Generate score tensor matching a real bin distribution.

    For each sampled bin, generates a uniform random fp16 value within the
    bin's actual value range (not just the midpoint), so that mapped transforms
    see diverse input values.

    Args:
        histogram: [256] aggregated bin counts from calibration
        total_pages: number of score entries to generate
        device: torch device

    Returns:
        scores: [total_pages, 1, 1] bfloat16 tensor
    """
    bin_lo, bin_hi = build_bin_range_table()

    # Normalize the histogram into a sampling distribution.
    weights = histogram.astype(np.float64)
    mass = weights.sum()
    if mass == 0:
        # Degenerate calibration data: fall back to all-zero scores.
        return torch.zeros(total_pages, 1, 1, dtype=torch.bfloat16, device=device)
    probs = weights / mass

    # Draw bins with the empirical probabilities, then jitter uniformly
    # inside each chosen bin's fp16 value range.
    chosen = np.random.choice(256, size=total_pages, p=probs)
    lo = bin_lo[chosen]
    hi = bin_hi[chosen]
    jitter = np.random.uniform(0, 1, size=total_pages).astype(np.float32)
    sampled = lo + jitter * (hi - lo)

    # Convert float32 -> bfloat16 tensor on the requested device.
    scores = torch.from_numpy(sampled).to(torch.bfloat16)
    return scores.reshape(total_pages, 1, 1).to(device)
def make_real_inputs(
    batch_size: int,
    num_kv_heads: int,
    seq_len: int,
    page_size: int,
    topk_val: int,
    reserved_bos: int,
    reserved_eos: int,
    histogram: np.ndarray,
    device: str = "cuda",
) -> dict:
    """Build CSR-formatted inputs with scores matching a real histogram."""
    # Each (batch, kv-head) pair is an independent segment.
    eff_batch_size = batch_size * num_kv_heads
    num_pages_per_seg = math.ceil(seq_len / page_size)
    total_dense_pages = eff_batch_size * num_pages_per_seg
    # The sparse side can never exceed the dense page count per segment.
    sparse_per_seg = min(topk_val + reserved_bos + reserved_eos, num_pages_per_seg)
    total_sparse_pages = eff_batch_size * sparse_per_seg

    # Uniform segment lengths → indptr is a simple arithmetic progression.
    dense_kv_indptr = torch.arange(
        0, (eff_batch_size + 1) * num_pages_per_seg, num_pages_per_seg,
        dtype=torch.int32, device=device,
    )
    sparse_kv_indptr = torch.arange(
        0, (eff_batch_size + 1) * sparse_per_seg, sparse_per_seg,
        dtype=torch.int32, device=device,
    )
    dense_kv_indices = torch.arange(total_dense_pages, dtype=torch.int32, device=device)
    sparse_kv_indices = torch.zeros(total_sparse_pages, dtype=torch.int32, device=device)

    # Scores drawn to match the calibrated real-data bin distribution.
    x = scores_from_histogram(histogram, total_dense_pages, device=device)

    return {
        "x": x,
        "dense_kv_indptr": dense_kv_indptr,
        "sparse_kv_indptr": sparse_kv_indptr,
        "dense_kv_indices": dense_kv_indices,
        "sparse_kv_indices": sparse_kv_indices,
        "eff_batch_size": eff_batch_size,
        "num_pages_per_seg": num_pages_per_seg,
        "sparse_per_seg": sparse_per_seg,
    }
def run_sweep(args) -> List[dict]:
    """Run all (mode, hyperparam) combos and return ranked results."""
    results = []

    # Optional: real calibrated histogram replaces the synthetic distributions.
    real_histogram = None
    if args.real_histograms:
        raw = np.load(args.real_histograms)  # [num_segments, 256]
        real_histogram = raw.sum(axis=0) if raw.ndim > 1 else raw  # aggregate to [256]

    distributions = args.distributions
    if real_histogram is not None:
        distributions = ["real"]

    for dist in distributions:
        if dist == "real":
            inputs = make_real_inputs(
                batch_size=args.batch_size,
                num_kv_heads=args.num_kv_heads,
                seq_len=args.seq_len,
                page_size=args.page_size,
                topk_val=args.topk_val,
                reserved_bos=args.reserved_bos,
                reserved_eos=args.reserved_eos,
                histogram=real_histogram,
            )
        else:
            inputs = make_topk_inputs(
                batch_size=args.batch_size,
                num_kv_heads=args.num_kv_heads,
                seq_len=args.seq_len,
                page_size=args.page_size,
                topk_val=args.topk_val,
                reserved_bos=args.reserved_bos,
                reserved_eos=args.reserved_eos,
                score_dtype=torch.bfloat16,
                distribution=dist,
            )

        eff_bs = inputs["eff_batch_size"]

        def evaluate(mode: int, power: float, label: str):
            # Fresh histogram buffer per combo; the kernel fills it in place.
            hists = torch.zeros(eff_bs, 256, dtype=torch.int32, device="cuda")
            topk_profile_histogram(
                inputs["x"],
                inputs["dense_kv_indptr"],
                hists,
                eff_bs,
                args.reserved_bos,
                args.reserved_eos,
                mode,
                power,
                None,  # lut
                None,  # quantiles
            )
            torch.cuda.synchronize()
            stats = compute_histogram_stats(hists)
            return {
                "label": label,
                "mode": mode,
                "mode_name": MODE_NAMES.get(mode, f"m{mode}"),
                "param": power,
                "distribution": dist,
                "gini": stats["gini"],
                "max_mean_ratio": stats["max_mean_ratio"],
                "num_nonzero_bins": stats["num_nonzero_bins"],
            }

        # Baselines (non-parametric modes)
        for mode, (name, default_power) in BASELINES.items():
            results.append(evaluate(mode, default_power, f"m{mode}_{name}"))

        # Parametric sweep over the grid
        for mode, (param_name, values) in SWEEP_GRID.items():
            mname = MODE_NAMES[mode]
            for val in values:
                results.append(evaluate(mode, val, f"m{mode}_{mname}_{param_name}={val}"))

    return results
def print_table(results: List[dict]):
    """Print ranked results as a formatted table."""
    # Lower Gini = flatter bin distribution = better.
    ranked = sorted(results, key=lambda r: r["gini"])

    header = (
        f"{'Rank':>4s} {'Label':<35s} {'Dist':<12s} "
        f"{'Gini':>6s} {'Max/Mean':>8s} {'NZBins':>6s}"
    )
    print("\n" + "=" * len(header))
    print("TopK Mapping Auto-Tune Results (ranked by Gini, lower=better)")
    print("=" * len(header))
    print(header)
    print("-" * len(header))

    for rank, row in enumerate(ranked, start=1):
        print(
            f"{rank:4d} {row['label']:<35s} {row['distribution']:<12s} "
            f"{row['gini']:6.3f} "
            f"{row['max_mean_ratio']:8.2f} {row['num_nonzero_bins']:6d}"
        )

    print("=" * len(header))
    if ranked:
        best = ranked[0]
        print(
            f"\nBest overall: {best['label']} (dist={best['distribution']}) "
            f"— gini={best['gini']:.3f}, max/mean={best['max_mean_ratio']:.2f}"
        )

    # Per-mode best summary (lowest gini per mode)
    mode_best = {}
    for row in results:
        m = row["mode"]
        if m not in mode_best or row["gini"] < mode_best[m]["gini"]:
            mode_best[m] = row

    if mode_best:
        print("\nBest per mode:")
        for m in sorted(mode_best.keys()):
            row = mode_best[m]
            mname = MODE_NAMES.get(m, f"m{m}")
            if m in SWEEP_GRID:
                param_name = SWEEP_GRID[m][0]
                param_str = f"{param_name}={row['param']}"
            else:
                param_str = "(baseline)"
            print(
                f"  Mode {m:d} ({mname:>5s}): {param_str:<20s} "
                f"gini={row['gini']:.3f} max/mean={row['max_mean_ratio']:.2f}"
            )
def main():
    parser = argparse.ArgumentParser(
        description="Auto-tune TopK mapping hyperparameters"
    )
    parser.add_argument("--batch-size", type=int, default=4)
    parser.add_argument("--seq-len", type=int, default=4096)
    parser.add_argument("--topk-val", type=int, default=30)
    parser.add_argument("--num-kv-heads", type=int, default=2)
    parser.add_argument("--page-size", type=int, default=16)
    parser.add_argument("--reserved-bos", type=int, default=1)
    parser.add_argument("--reserved-eos", type=int, default=2)
    parser.add_argument(
        "--distributions", nargs="+",
        default=["normal"],
        help="Score distributions for synthetic data (ignored when --real-histograms is set)",
    )
    parser.add_argument(
        "--real-histograms", type=str, default=None,
        help="Path to raw_histograms.npy from calibration. When set, auto-tunes on "
             "real score distribution instead of synthetic data.",
    )
    parser.add_argument(
        "--output-json", type=str, default=None,
        help="Save results to JSON file",
    )
    args = parser.parse_args()

    # Banner: where the scores come from and how big the sweep is.
    source = f"real ({args.real_histograms})" if args.real_histograms else f"synthetic ({args.distributions})"
    print("Auto-tuning TopK mapping hyperparameters")
    print(f"  batch_size={args.batch_size}, seq_len={args.seq_len}, "
          f"topk_val={args.topk_val}, num_kv_heads={args.num_kv_heads}")
    print(f"  score source: {source}")
    n_parametric = sum(len(v) for _, v in SWEEP_GRID.values())
    n_dists = 1 if args.real_histograms else len(args.distributions)
    print(f"  sweep: {n_parametric} parametric + {len(BASELINES)} baselines "
          f"= {n_parametric + len(BASELINES)} combos x {n_dists} dists")

    results = run_sweep(args)
    print_table(results)

    if args.output_json:
        with open(args.output_json, "w") as f:
            json.dump(results, f, indent=2)
        print(f"\nResults saved to {args.output_json}")


if __name__ == "__main__":
    main()
"""
TopK kernel benchmarking suite.

Measures kernel-level latency for the three topk variants (naive/CUB,
sglang with mapping modes) across configurable grid of batch sizes,
sequence lengths, topk values, and KV head counts.

Usage:
    python benchmarks/bench_topk.py --batch-sizes 4 8 --seq-lens 2048 4096 --topk-vals 30 --num-kv-heads 2 --repeat 50
"""

import argparse
import json
import math
import statistics
from typing import Dict, List, Optional

import numpy as np
import torch

from vortex_torch_C import topk_output, topk_output_sglang, topk_profile_histogram

# Canonical mapping mode names — used in logs, tables, and plots
MAPPING_MODE_NAMES = {
    0: "None",
    1: "LUT CDF",
    2: "Quantile",
    3: "Power",
    4: "Log",
    5: "Index Cache",
    6: "Asinh",
    7: "Log1p",
    8: "Trunc8",
}

MAPPING_MODE_FORMULAS = {
    0: "None (fp16 bucketing)",
    1: "LUT CDF (calibrated)",
    2: "Quantile (calibrated)",
    3: "Power: sign(x)*|x|^p",
    4: "Log: sign(x)*log(|x|+1)",
    5: "Index Cache",
    6: "Asinh: asinh(beta*x)",
    7: "Log1p: sign(x)*log1p(alpha*|x|)",
    8: "Trunc8: bf16 upper-8-bit bucketing",
}


def make_topk_inputs(
    batch_size: int,
    num_kv_heads: int,
    seq_len: int,
    page_size: int,
    topk_val: int,
    reserved_bos: int,
    reserved_eos: int,
    score_dtype: torch.dtype,
    distribution: str = "normal",
    device: str = "cuda",
) -> dict:
    """Synthesize realistic CSR-formatted paged attention inputs.

    Args:
        batch_size: number of requests; each KV head forms its own segment.
        num_kv_heads: KV heads per request.
        seq_len: tokens per segment; pages = ceil(seq_len / page_size).
        page_size: tokens per KV page.
        topk_val: pages selected per segment (plus reserved BOS/EOS pages).
        reserved_bos / reserved_eos: always-kept pages at sequence ends.
        score_dtype: dtype the score tensor is cast to (e.g. torch.bfloat16).
        distribution: one of "normal", "lognormal", "uniform", "bucket_uniform".
        device: torch device for all tensors.

    Returns:
        dict with score tensor `x`, dense/sparse CSR indptr+indices, and
        derived sizes (eff_batch_size, num_pages_per_seg, sparse_per_seg).

    Raises:
        ValueError: if `distribution` is not one of the supported names.
    """
    eff_batch_size = batch_size * num_kv_heads
    num_pages_per_seg = math.ceil(seq_len / page_size)
    total_dense_pages = eff_batch_size * num_pages_per_seg
    # The sparse selection can never exceed the dense page count per segment.
    sparse_per_seg = min(topk_val + reserved_bos + reserved_eos, num_pages_per_seg)
    total_sparse_pages = eff_batch_size * sparse_per_seg

    # Uniform segment lengths → indptr is an arithmetic progression.
    dense_kv_indptr = torch.arange(
        0, (eff_batch_size + 1) * num_pages_per_seg, num_pages_per_seg,
        dtype=torch.int32, device=device,
    )
    sparse_kv_indptr = torch.arange(
        0, (eff_batch_size + 1) * sparse_per_seg, sparse_per_seg,
        dtype=torch.int32, device=device,
    )
    dense_kv_indices = torch.arange(total_dense_pages, dtype=torch.int32, device=device)
    sparse_kv_indices = torch.zeros(total_sparse_pages, dtype=torch.int32, device=device)

    # Generate scores with the requested distribution
    if distribution == "normal":
        x = torch.randn(total_dense_pages, 1, 1, device=device)
    elif distribution == "lognormal":
        x = torch.randn(total_dense_pages, 1, 1, device=device).exp()
    elif distribution == "uniform":
        x = torch.rand(total_dense_pages, 1, 1, device=device)
    elif distribution == "bucket_uniform":
        # Uniform across all 256 fp16 radix buckets.
        # Random uint16 bit patterns → interpret as fp16.
        # Bucket = upper 8 bits of sign-flipped fp16, so random bits → uniform buckets.
        raw_bits = torch.randint(0, 65536, (total_dense_pages,), dtype=torch.int32, device=device)
        # Exclude fp16 NaN/Inf (exponent=31, i.e. |bits| >= 0x7C00)
        abs_bits = raw_bits & 0x7FFF
        raw_bits[abs_bits >= 0x7C00] = raw_bits[abs_bits >= 0x7C00] & 0x8000  # → ±0
        # Reinterpret int16 bits as fp16, then widen to float32
        x = raw_bits.to(torch.int16).view(torch.float16).float().reshape(total_dense_pages, 1, 1)
    else:
        raise ValueError(f"Unknown distribution: {distribution}")

    x = x.to(score_dtype)

    return {
        "x": x,
        "dense_kv_indptr": dense_kv_indptr,
        "sparse_kv_indptr": sparse_kv_indptr,
        "dense_kv_indices": dense_kv_indices,
        "sparse_kv_indices": sparse_kv_indices,
        "eff_batch_size": eff_batch_size,
        "num_pages_per_seg": num_pages_per_seg,
        "sparse_per_seg": sparse_per_seg,
    }
def bench_kernel(kernel_fn, args, warmup: int = 10, repeat: int = 100) -> dict:
    """Time a kernel with CUDA events, return latency stats in ms."""
    # Warm up JIT/caches so the timed runs reflect steady state.
    for _ in range(warmup):
        kernel_fn(*args)
    torch.cuda.synchronize()

    # One start/stop event pair per timed iteration; recorded on-stream so
    # measurements don't include host-side launch overhead from syncing.
    starts = [torch.cuda.Event(enable_timing=True) for _ in range(repeat)]
    stops = [torch.cuda.Event(enable_timing=True) for _ in range(repeat)]
    for begin, finish in zip(starts, stops):
        begin.record()
        kernel_fn(*args)
        finish.record()
    torch.cuda.synchronize()

    times = [begin.elapsed_time(finish) for begin, finish in zip(starts, stops)]
    return {
        "mean_ms": statistics.mean(times),
        "median_ms": statistics.median(times),
        "std_ms": statistics.stdev(times) if len(times) > 1 else 0.0,
        "min_ms": min(times),
        "max_ms": max(times),
    }
[torch.cuda.Event(enable_timing=True) for _ in range(repeat)] + for i in range(repeat): + start_events[i].record() + kernel_fn(*args) + end_events[i].record() + torch.cuda.synchronize() + + times = [s.elapsed_time(e) for s, e in zip(start_events, end_events)] + return { + "mean_ms": statistics.mean(times), + "median_ms": statistics.median(times), + "std_ms": statistics.stdev(times) if len(times) > 1 else 0.0, + "min_ms": min(times), + "max_ms": max(times), + } + + +def compute_histogram_stats(histograms: torch.Tensor) -> dict: + """Compute bin distribution statistics from histogram tensor [B, 256].""" + h = histograms.float() + # Aggregate across batch dimension + h_sum = h.sum(dim=0) # [256] + nonzero_bins = h_sum[h_sum > 0] + if len(nonzero_bins) == 0: + return { + "max_mean_ratio": 0.0, "std": 0.0, "gini": 0.0, + "num_nonzero_bins": 0, "entropy": 0.0, "effective_bins": 0.0, + } + + mean_val = nonzero_bins.mean().item() + max_val = nonzero_bins.max().item() + std_val = nonzero_bins.std().item() if len(nonzero_bins) > 1 else 0.0 + + # Gini coefficient + sorted_bins = nonzero_bins.sort().values + n = len(sorted_bins) + index = torch.arange(1, n + 1, device=sorted_bins.device, dtype=torch.float32) + gini = (2.0 * (index * sorted_bins).sum() / (n * sorted_bins.sum()) - (n + 1) / n).item() + + # Shannon entropy (base-2) + p = nonzero_bins / nonzero_bins.sum() + entropy = -(p * p.log2()).sum().item() + # Effective number of bins: 2^entropy + effective_bins = 2 ** entropy + + return { + "max_mean_ratio": max_val / mean_val if mean_val > 0 else 0.0, + "std": std_val, + "gini": max(0.0, gini), + "num_nonzero_bins": int(len(nonzero_bins)), + "entropy": entropy, + "effective_bins": effective_bins, + } + + +NUM_HISTOGRAM_BINS = 256 + + +def _histogram_target_pages(pages_per_seg: int, min_samples_per_bin: int = 512) -> int: + """Compute adaptive page count for statistically reliable histograms. 
+ + With 256 radix bins, each bin needs enough samples for stable gini / + max-mean statistics. Returns a total page count rounded up to a full + segment boundary so every segment contributes equally. + """ + min_pages = min_samples_per_bin * NUM_HISTOGRAM_BINS + return math.ceil(min_pages / pages_per_seg) * pages_per_seg + + +def _load_autotune_powers(path: str) -> Dict[int, float]: + """Extract best per-mode power from autotune JSON. + + Ranks by res_rate_mean (higher=better) if present, else by gini (lower=better). + Returns {mode: best_power}, e.g. {3: 0.25, 6: 1.0, 7: 2.0}. + """ + with open(path) as f: + data = json.load(f) + + has_res_rate = any("res_rate_mean" in r for r in data) + + best: Dict[int, dict] = {} + for r in data: + m = r.get("mode") + if m not in (3, 6, 7): + continue + if has_res_rate: + score = r.get("res_rate_mean", 0.0) + is_better = m not in best or score > best[m]["_score"] + else: + score = r.get("gini", 1.0) + is_better = m not in best or score < best[m]["_score"] + if is_better: + best[m] = {"param": r["param"], "_score": score} + + return {m: v["param"] for m, v in best.items()} + + +def _resolve_mode_power(args, mode: int) -> float: + """Return the power/beta/alpha for a parametric mapping mode. + + Priority: per-mode CLI flag > autotune JSON > global --mapping-power. 
+ """ + per_mode_flag = {3: args.mapping_power_3, 6: args.mapping_power_6, 7: args.mapping_power_7} + if mode in per_mode_flag and per_mode_flag[mode] is not None: + return per_mode_flag[mode] + if hasattr(args, "_autotune_powers") and mode in args._autotune_powers: + return args._autotune_powers[mode] + return args.mapping_power + + +def run_benchmark(args) -> List[dict]: + """Run the full benchmark sweep and return results.""" + # Load autotune results if provided + if args.autotune_json: + args._autotune_powers = _load_autotune_powers(args.autotune_json) + print(f"Loaded autotune best powers: {args._autotune_powers}") + else: + args._autotune_powers = {} + + dtype_map = {"bfloat16": torch.bfloat16, "float32": torch.float32} + score_dtype = dtype_map[args.score_dtype] + + # Load real histogram if provided + real_histogram = None + _scores_from_histogram = None + if args.real_histograms: + from autotune_topk_mapping import scores_from_histogram + _scores_from_histogram = scores_from_histogram + raw = np.load(args.real_histograms) + real_histogram = raw.sum(axis=0) if raw.ndim > 1 else raw + + # Extend distributions with "real" if calibration data is provided + distributions = list(args.distributions) + if real_histogram is not None: + distributions.append("real") + args.distributions = distributions + + # Print GPU info + gpu_name = torch.cuda.get_device_name(0) + gpu_props = torch.cuda.get_device_properties(0) + print(f"TopK Kernel Benchmark Results") + print(f"GPU: {gpu_name} | SM count: {gpu_props.multi_processor_count}") + print(f"Score dtype: {args.score_dtype} | Warmup: {args.warmup} | Repeat: {args.repeat}") + print("=" * 90) + + # Load optional LUT / quantiles + mapping_lut = None + mapping_quantiles = None + if args.lut_path: + lut_np = np.load(args.lut_path).astype(np.uint8) + mapping_lut = torch.from_numpy(lut_np).cuda() + if args.quantiles_path: + q_np = np.load(args.quantiles_path).astype(np.float32) + mapping_quantiles = torch.from_numpy(q_np).cuda() 
+ + # Build kernel list + all_kernels = { + "naive": "naive", + "sglang_m0": "sglang_m0", + "sglang_m3": "sglang_m3", + "sglang_m4": "sglang_m4", + "sglang_m6": "sglang_m6", + "sglang_m7": "sglang_m7", + "sglang_m8": "sglang_m8", + } + if mapping_lut is not None: + all_kernels["sglang_m1"] = "sglang_m1" + if mapping_quantiles is not None: + all_kernels["sglang_m2"] = "sglang_m2" + + if args.filter_kernels: + all_kernels = {k: v for k, v in all_kernels.items() if k in args.filter_kernels} + + # Naive kernel only supports bf16 + if score_dtype != torch.bfloat16 and "naive" in all_kernels: + print(f"Note: naive kernel only supports bfloat16, skipping for {args.score_dtype}") + del all_kernels["naive"] + + all_results = [] + + for bs in args.batch_sizes: + for seq_len in args.seq_lens: + for topk_val in args.topk_vals: + for num_kv_heads in args.num_kv_heads: + for dist in args.distributions: + if dist == "real" and real_histogram is not None: + inputs = make_topk_inputs( + batch_size=bs, + num_kv_heads=num_kv_heads, + seq_len=seq_len, + page_size=args.page_size, + topk_val=topk_val, + reserved_bos=args.reserved_bos, + reserved_eos=args.reserved_eos, + score_dtype=score_dtype, + distribution="normal", + ) + # Replace scores with real-distribution scores + total_dense = inputs["eff_batch_size"] * inputs["num_pages_per_seg"] + inputs["x"] = _scores_from_histogram( + real_histogram, total_dense, device="cuda", + ) + else: + inputs = make_topk_inputs( + batch_size=bs, + num_kv_heads=num_kv_heads, + seq_len=seq_len, + page_size=args.page_size, + topk_val=topk_val, + reserved_bos=args.reserved_bos, + reserved_eos=args.reserved_eos, + score_dtype=score_dtype, + distribution=dist, + ) + + eff_bs = inputs["eff_batch_size"] + pages_per_seg = inputs["num_pages_per_seg"] + + config_str = ( + f"bs={bs} | seq={seq_len} | topk={topk_val} | " + f"heads={num_kv_heads} | pages/seg={pages_per_seg} | dist={dist}" + ) + print(f"\n{config_str}") + + config_results = { + "batch_size": bs, + 
"seq_len": seq_len, + "topk_val": topk_val, + "num_kv_heads": num_kv_heads, + "distribution": dist, + "eff_batch_size": eff_bs, + "pages_per_seg": pages_per_seg, + "kernels": {}, + } + + for kernel_name in all_kernels: + # Reset sparse indices each run + inputs["sparse_kv_indices"].zero_() + + if kernel_name == "naive": + # topk_output: (x, dense_indptr, dense_indices, sparse_indptr, sparse_indices, ...) + call_args = ( + inputs["x"], + inputs["dense_kv_indptr"], + inputs["dense_kv_indices"], + inputs["sparse_kv_indptr"], + inputs["sparse_kv_indices"], + eff_bs, + topk_val, + args.reserved_bos, + args.reserved_eos, + pages_per_seg, + ) + result = bench_kernel(topk_output, call_args, args.warmup, args.repeat) + else: + # Parse mapping mode from kernel name + mode = int(kernel_name.split("_m")[1]) + extra_kwargs = {} + if mode == 1: + extra_kwargs["mapping_lut"] = mapping_lut + elif mode == 2: + extra_kwargs["mapping_quantiles"] = mapping_quantiles + + power = _resolve_mode_power(args, mode) if mode in (3, 6, 7) else 0.5 + + # topk_output_sglang: (x, dense_indptr, sparse_indptr, dense_indices, sparse_indices, ...) 
+ call_args = ( + inputs["x"], + inputs["dense_kv_indptr"], + inputs["sparse_kv_indptr"], + inputs["dense_kv_indices"], + inputs["sparse_kv_indices"], + eff_bs, + topk_val, + args.reserved_bos, + args.reserved_eos, + pages_per_seg, + mode, + power, + extra_kwargs.get("mapping_lut", None), + extra_kwargs.get("mapping_quantiles", None), + ) + result = bench_kernel(topk_output_sglang, call_args, args.warmup, args.repeat) + + if kernel_name == "naive": + label = "naive" + else: + m = int(kernel_name.split("_m")[1]) + mname = MAPPING_MODE_NAMES.get(m, f'm{m}') + if m in (3, 6, 7): + pname = {3: "p", 6: "beta", 7: "alpha"}[m] + label = f"sglang {mname} ({pname}={_resolve_mode_power(args, m)})" + else: + label = f"sglang {mname}" + print( + f" {label:<30s}: {result['median_ms']:.4f}ms (median) " + f"\u00b1 {result['std_ms']:.4f}ms " + f"[min={result['min_ms']:.4f}, max={result['max_ms']:.4f}]" + ) + config_results["kernels"][kernel_name] = result + + # Histogram analysis + if args.histogram: + # Build a separate (potentially larger) dataset for histogram profiling + target_pages = (args.histogram_pages + if args.histogram_pages is not None + else _histogram_target_pages(pages_per_seg)) + current_pages = eff_bs * pages_per_seg + if target_pages > current_pages: + hist_bs = math.ceil(target_pages / (num_kv_heads * pages_per_seg)) + if dist == "real" and real_histogram is not None: + hist_inputs = make_topk_inputs( + batch_size=hist_bs, num_kv_heads=num_kv_heads, + seq_len=seq_len, page_size=args.page_size, + topk_val=topk_val, reserved_bos=args.reserved_bos, + reserved_eos=args.reserved_eos, score_dtype=score_dtype, + distribution="normal", + ) + total_hist_dense = hist_inputs["eff_batch_size"] * hist_inputs["num_pages_per_seg"] + hist_inputs["x"] = _scores_from_histogram(real_histogram, total_hist_dense, device="cuda") + else: + hist_inputs = make_topk_inputs( + batch_size=hist_bs, num_kv_heads=num_kv_heads, + seq_len=seq_len, page_size=args.page_size, + topk_val=topk_val, 
reserved_bos=args.reserved_bos, + reserved_eos=args.reserved_eos, score_dtype=score_dtype, + distribution=dist, + ) + hist_eff_bs = hist_inputs["eff_batch_size"] + actual_pages = hist_eff_bs * pages_per_seg + print( + f" histogram dataset : {actual_pages} pages " + f"(upscaled from {current_pages} for statistical reliability)" + ) + else: + hist_inputs = inputs + hist_eff_bs = eff_bs + actual_pages = current_pages + print(f" histogram dataset : {actual_pages} pages") + + # Raw unmapped histogram + histograms = torch.zeros(hist_eff_bs, 256, dtype=torch.int32, device="cuda") + topk_profile_histogram( + hist_inputs["x"], + hist_inputs["dense_kv_indptr"], + histograms, + hist_eff_bs, + args.reserved_bos, + args.reserved_eos, + ) + hstats = compute_histogram_stats(histograms) + hstats["raw_counts"] = histograms.sum(dim=0).tolist() # [256] ints + config_results["histogram"] = hstats + print( + f" histogram stats : max/mean={hstats['max_mean_ratio']:.2f} " + f"gini={hstats['gini']:.3f} " + f"nonzero_bins={hstats['num_nonzero_bins']}/256" + ) + + # Per-mode histogram analysis + modes_to_test = [0, 3, 4, 6, 7, 8] + if mapping_lut is not None: + modes_to_test.append(1) + if mapping_quantiles is not None: + modes_to_test.append(2) + modes_to_test.sort() + + histograms_results = {} + print(f" --- histogram by mapping mode ---") + for mode in modes_to_test: + mode_hists = torch.zeros(hist_eff_bs, 256, dtype=torch.int32, device="cuda") + + extra_lut = mapping_lut if mode == 1 else None + extra_q = mapping_quantiles if mode == 2 else None + power = _resolve_mode_power(args, mode) if mode in (3, 6, 7) else 0.5 + + topk_profile_histogram( + hist_inputs["x"], + hist_inputs["dense_kv_indptr"], + mode_hists, + hist_eff_bs, + args.reserved_bos, + args.reserved_eos, + mode, + power, + extra_lut, + extra_q, + ) + torch.cuda.synchronize() + + mode_stats = compute_histogram_stats(mode_hists) + mode_stats["raw_counts"] = mode_hists.sum(dim=0).tolist() + mname = MAPPING_MODE_NAMES.get(mode, 
f"m{mode}") + mformula = MAPPING_MODE_FORMULAS.get(mode, mname) + mode_stats["name"] = mname + mode_stats["formula"] = mformula + if mode in (3, 6, 7): + pname = {3: "p", 6: "beta", 7: "alpha"}[mode] + mode_stats["param"] = f"{pname}={power}" + histograms_results[f"mode_{mode}_{mname}"] = mode_stats + if mode in (3, 6, 7): + pname = {3: "p", 6: "beta", 7: "alpha"}[mode] + display_name = f"{mname} ({pname}={power})" + else: + display_name = mname + print( + f" {display_name:<22s} (mode {mode}): " + f"gini={mode_stats['gini']:.3f} " + f"max/mean={mode_stats['max_mean_ratio']:.2f} " + f"nonzero_bins={mode_stats['num_nonzero_bins']}/256 " + f"eff_bins={mode_stats['effective_bins']:.1f} " + f"entropy={mode_stats['entropy']:.2f}" + ) + config_results["histograms"] = histograms_results + + all_results.append(config_results) + + return all_results + + +def main(): + parser = argparse.ArgumentParser(description="TopK kernel benchmark suite") + parser.add_argument("--batch-sizes", nargs="+", type=int, default=[1, 4, 8, 16, 32, 64]) + parser.add_argument("--seq-lens", nargs="+", type=int, default=[1024, 2048, 4096, 8192]) + parser.add_argument("--topk-vals", nargs="+", type=int, default=[16, 30, 64]) + parser.add_argument("--num-kv-heads", nargs="+", type=int, default=[2, 4, 8]) + parser.add_argument("--page-size", type=int, default=16) + parser.add_argument("--reserved-bos", type=int, default=1) + parser.add_argument("--reserved-eos", type=int, default=2) + parser.add_argument("--score-dtype", choices=["bfloat16", "float32"], default="bfloat16") + parser.add_argument("--distributions", nargs="+", default=["normal", "lognormal", "uniform"]) + parser.add_argument("--warmup", type=int, default=10) + parser.add_argument("--repeat", type=int, default=100) + parser.add_argument("--mapping-power", type=float, default=0.5, + help="Global fallback power parameter for parametric modes (default: 0.5)") + parser.add_argument("--mapping-power-3", type=float, default=None, + help="Power 
exponent p for mode 3 (overrides --mapping-power)") + parser.add_argument("--mapping-power-6", type=float, default=None, + help="Beta for mode 6 asinh (overrides --mapping-power)") + parser.add_argument("--mapping-power-7", type=float, default=None, + help="Alpha for mode 7 log1p (overrides --mapping-power)") + parser.add_argument("--autotune-json", type=str, default=None, + help="Path to autotune_results.json — extracts best per-mode hyperparameters " + "(overrides --mapping-power for modes 3/6/7)") + parser.add_argument("--lut-path", type=str, default=None, help="Path to .npy uint8[256] LUT for mode=1") + parser.add_argument("--quantiles-path", type=str, default=None, help="Path to .npy float32[256] for mode=2") + parser.add_argument("--output-json", type=str, default=None, help="Save results to JSON file") + parser.add_argument("--filter-kernels", nargs="+", default=None, + help="Only run specific kernels: naive, sglang_m0, sglang_m3, sglang_m4") + parser.add_argument("--histogram", action="store_true", help="Collect and report bin distribution statistics") + parser.add_argument("--histogram-pages", type=int, default=None, + help="Total pages for histogram profiling. Default: adaptive " + "(512 samples/bin × 256 bins, rounded to segment boundary). 
" + "Only used when --histogram is set.") + parser.add_argument("--real-histograms", type=str, default=None, + help="Path to .npy raw_histograms from calibration (adds 'real' distribution)") + + args = parser.parse_args() + results = run_benchmark(args) + + if args.output_json: + with open(args.output_json, "w") as f: + json.dump(results, f, indent=2) + print(f"\nResults saved to {args.output_json}") + + +if __name__ == "__main__": + main() diff --git a/benchmarks/calibrate_topk.py b/benchmarks/calibrate_topk.py new file mode 100644 index 0000000..4c86116 --- /dev/null +++ b/benchmarks/calibrate_topk.py @@ -0,0 +1,153 @@ +#!/usr/bin/env python3 +""" +Offline calibration for TopK mapping modes 1 (LUT CDF) and 2 (quantile). + +Runs the model on real data with hit-rate profiling enabled, collects score +histograms from the topk_sglang kernel, and generates: + - lut.npy : uint8[256] CDF-equalized LUT for mapping mode 1 + - quantiles.npy: float32[256] quantile breakpoints for mapping mode 2 + +Usage: + python benchmarks/calibrate_topk.py \ + --model-name Qwen/Qwen3-1.7B \ + --topk-val 30 --mem 0.7 \ + --output-dir calibration_output/ +""" + +import argparse +import json +import os +import sys + +import numpy as np + +# Add project root to path so we can import from benchmarks/ +sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")) + +from benchmarks.profile_topk_distribution import ( + compute_lut_from_histogram, + generate_tables_from_histograms, +) + + +def main(): + parser = argparse.ArgumentParser( + description="Offline calibration for TopK mapping modes 1 & 2" + ) + parser.add_argument("--model-name", type=str, default="Qwen/Qwen3-1.7B") + parser.add_argument("--topk-val", type=int, default=30) + parser.add_argument("--page-size", type=int, default=16) + parser.add_argument("--mem", type=float, default=0.7) + parser.add_argument("--kv-cache-dtype", type=str, default="auto") + parser.add_argument("--topk-type", type=str, 
default="sglang") + parser.add_argument("--num-prompts", type=int, default=16, + help="Number of calibration prompts to use (default: 16)") + parser.add_argument("--output-dir", type=str, default="calibration_output/") + parser.add_argument("--vortex-module-name", type=str, default="block_sparse_attention") + args = parser.parse_args() + + # Lazy imports to avoid slow startup when just checking --help + import sglang as sgl + import torch + import vortex_torch + + os.makedirs(args.output_dir, exist_ok=True) + + print(f"[calibrate] Launching engine with hit-rate profiling enabled...") + llm = sgl.Engine( + model_path=args.model_name, + disable_cuda_graph=True, + page_size=args.page_size, + vortex_topk_val=args.topk_val, + disable_overlap_schedule=True, + attention_backend="flashinfer", + enable_vortex_sparsity=True, + vortex_page_reserved_bos=1, + vortex_page_reserved_eos=2, + vortex_layers_skip=list(range(1)), + vortex_module_name=args.vortex_module_name, + vortex_max_seq_lens=12288, + mem_fraction_static=args.mem, + kv_cache_dtype=args.kv_cache_dtype, + vortex_topk_type=args.topk_type, + vortex_topk_mapping_mode=0, # Use mode 0 during calibration + vortex_topk_histogram=True, # Enable histogram collection + ) + + # Clear any residual histograms in the worker process + llm.clear_topk_histograms() + + # Load calibration prompts + prompts_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "..", "examples", "amc23.jsonl" + ) + with open(prompts_path, "r", encoding="utf-8") as f: + all_requests = [json.loads(line) for line in f] + + # Use up to num_prompts + requests = all_requests[:args.num_prompts] + prompts = [req["prompt"] for req in requests] + + print(f"[calibrate] Running {len(prompts)} calibration prompts...") + sampling_params = { + "temperature": 0.6, + "top_p": 0.95, + "top_k": 20, + "max_new_tokens": 8192, + } + llm.generate(prompts, sampling_params) + + # Collect histograms via RPC from worker process + histograms = 
llm.get_topk_histograms() + print(f"[calibrate] Collected {len(histograms)} histogram batches") + + if len(histograms) == 0: + print("[calibrate] ERROR: No histograms collected. " + "Ensure topk_type='sglang' and vortex_topk_histogram=True.", + file=sys.stderr) + llm.shutdown() + sys.exit(1) + + # Stack all histograms: each is [eff_bs, 256], concatenate along batch dim + all_hists = torch.cat(histograms, dim=0).numpy() # [total_samples, 256] + print(f"[calibrate] Total histogram samples: {all_hists.shape[0]}") + + # --- Generate LUT (mode 1) --- + # Aggregate histogram across all samples + avg_histogram = all_hists.mean(axis=0) + lut = compute_lut_from_histogram(avg_histogram) + lut_path = os.path.join(args.output_dir, "lut.npy") + np.save(lut_path, lut) + print(f"[calibrate] Saved LUT to {lut_path} (shape={lut.shape}, dtype={lut.dtype})") + + # --- Generate quantiles (mode 2) --- + # Use bin centers as proxy scores weighted by histogram counts + bin_centers = np.arange(256, dtype=np.float32) + # Expand histogram counts into a weighted score distribution + total_counts = avg_histogram.astype(np.float64) + total = total_counts.sum() + if total > 0: + cdf = np.cumsum(total_counts) / total + # Invert CDF to get quantile breakpoints in [0, 255] space + percentiles = np.linspace(0, 1, 256) + quantiles = np.interp(percentiles, cdf, bin_centers).astype(np.float32) + else: + quantiles = bin_centers.copy() + + quantiles_path = os.path.join(args.output_dir, "quantiles.npy") + np.save(quantiles_path, quantiles) + print(f"[calibrate] Saved quantiles to {quantiles_path} (shape={quantiles.shape}, dtype={quantiles.dtype})") + + # Save raw histograms for debugging + raw_path = os.path.join(args.output_dir, "raw_histograms.npy") + np.save(raw_path, all_hists) + print(f"[calibrate] Saved raw histograms to {raw_path} (shape={all_hists.shape})") + + # Cleanup + llm.clear_topk_histograms() + llm.shutdown() + print(f"[calibrate] Done. 
Output files in {args.output_dir}/") + + +if __name__ == "__main__": + main() diff --git a/benchmarks/greedy_layer_search.py b/benchmarks/greedy_layer_search.py new file mode 100644 index 0000000..118ac45 --- /dev/null +++ b/benchmarks/greedy_layer_search.py @@ -0,0 +1,117 @@ +"""Greedy forward-selection of layers whose indexer can be skipped (index cache). + +Usage (from repo root): + cd examples && python ../benchmarks/greedy_layer_search.py \ + --model-name Qwen/Qwen3-1.7B --topk-val 30 --threshold 0.95 \ + --trials 1 --num-layers 28 --mem 0.7 + +The script prints progress to stderr and outputs the final selected layer list +(as a Python list literal) on the **last line of stdout** so callers can parse it. +""" + +import argparse +import os +import sys + +# Add examples/ to path so we can import verify_algos +_examples_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "examples") +sys.path.insert(0, _examples_dir) + +from verify_algo import verify_algos # noqa: E402 + + +def _evaluate(shared_layers, args): + """Run verify_algos with the given shared layers and return pass@trials accuracy.""" + summary = verify_algos( + trials=args.trials, + topk_val=args.topk_val, + page_size=args.page_size, + vortex_module_name=args.vortex_module_name, + model_name=args.model_name, + sparse_attention=True, + mem=args.mem, + kv_cache_dtype=args.kv_cache_dtype, + topk_type=args.topk_type, + topk_mapping_mode=0, + topk_mapping_power=args.topk_mapping_power, + index_cache_shared_layers=sorted(shared_layers) if shared_layers else None, + disable_cuda_graph=True, + ) + acc_key = f"pass@{args.trials}" + return summary[acc_key] + + +def greedy_search(args): + # Ensure we're in examples/ so amc23.jsonl relative path works + os.chdir(_examples_dir) + + candidates = list(range(1, args.num_layers)) + + # Baseline: no shared layers + print("Evaluating baseline (no shared layers)...", file=sys.stderr) + baseline_acc = _evaluate([], args) + print(f"Baseline accuracy: 
{baseline_acc:.4f}", file=sys.stderr) + + threshold = args.threshold + shared_set = [] + + while candidates: + best_layer = None + best_acc = -1.0 + + for layer in candidates: + trial_set = shared_set + [layer] + print(f" Trying shared_set={sorted(trial_set)} ...", file=sys.stderr, end=" ") + acc = _evaluate(trial_set, args) + print(f"acc={acc:.4f}", file=sys.stderr) + + if acc > best_acc: + best_acc = acc + best_layer = layer + + if best_acc >= threshold * baseline_acc: + shared_set.append(best_layer) + candidates.remove(best_layer) + print( + f"Added layer {best_layer} (acc={best_acc:.4f} >= " + f"{threshold * baseline_acc:.4f}). Current set: {sorted(shared_set)}", + file=sys.stderr, + ) + else: + print( + f"Stopping: best candidate layer {best_layer} acc={best_acc:.4f} < " + f"{threshold * baseline_acc:.4f}", + file=sys.stderr, + ) + break + + result = sorted(shared_set) + print(f"Final shared layers: {result}", file=sys.stderr) + # Last stdout line: parseable Python list + print(result) + return result + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Greedy forward-selection of index-cache shared layers." 
+ ) + parser.add_argument("--model-name", type=str, default="Qwen/Qwen3-1.7B") + parser.add_argument("--topk-val", type=int, default=30) + parser.add_argument("--page-size", type=int, default=16) + parser.add_argument("--vortex-module-name", type=str, default="block_sparse_attention") + parser.add_argument("--mem", type=float, default=0.8) + parser.add_argument("--kv-cache-dtype", type=str, default="auto") + parser.add_argument("--topk-type", type=str, default="naive") + parser.add_argument("--topk-mapping-power", type=float, default=0.5) + parser.add_argument("--threshold", type=float, default=0.95, + help="Minimum accuracy ratio vs baseline to keep adding layers (default: 0.95).") + parser.add_argument("--trials", type=int, default=1) + parser.add_argument("--num-layers", type=int, default=28, + help="Total number of model layers (default: 28 for Qwen3-1.7B).") + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + greedy_search(args) diff --git a/benchmarks/profile_topk_distribution.py b/benchmarks/profile_topk_distribution.py new file mode 100644 index 0000000..bea911b --- /dev/null +++ b/benchmarks/profile_topk_distribution.py @@ -0,0 +1,132 @@ +#!/usr/bin/env python3 +""" +Profile TopK bin distribution and generate mapping tables. + +This script collects Stage 1 (8-bit coarse histogram) distributions from +the topk_sglang kernel and generates LUT/quantile mapping tables that +can be used to equalize the bin distribution for improved sorting efficiency. 
+ +Usage: + python scripts/profile_topk_distribution.py \ + --model-name Qwen/Qwen3-1.7B \ + --output mapping_tables.npz \ + --num-prompts 32 \ + --mem 0.7 + +Output (.npz): + lut_tables: [num_collected, 256] uint8 - CDF-equalized LUT per sample + quantile_tables: [num_collected, 256] float32 - quantile breakpoints per sample + raw_histograms: [num_collected, 256] int32 - raw bin histograms +""" + +import argparse +import numpy as np +import torch + + +def compute_lut_from_histogram(histogram: np.ndarray) -> np.ndarray: + """Compute CDF-equalized LUT from a 256-bin histogram. + + Args: + histogram: [256] int array of bin counts + + Returns: + lut: [256] uint8 array where lut[i] = floor(CDF(i) * 255) + """ + cdf = np.cumsum(histogram).astype(np.float64) + total = cdf[-1] + if total == 0: + return np.arange(256, dtype=np.uint8) + cdf_normalized = cdf / total + lut = np.floor(cdf_normalized * 255).astype(np.uint8) + return lut + + +def compute_quantiles_from_scores(scores: np.ndarray, num_quantiles: int = 256) -> np.ndarray: + """Compute quantile breakpoints from raw float scores. + + Args: + scores: 1D array of float scores + num_quantiles: number of quantile bins (default 256) + + Returns: + quantiles: [num_quantiles] float32 array of sorted breakpoints + """ + if len(scores) == 0: + return np.zeros(num_quantiles, dtype=np.float32) + percentiles = np.linspace(0, 100, num_quantiles) + quantiles = np.percentile(scores, percentiles).astype(np.float32) + return quantiles + + +def generate_tables_from_histograms(histograms: np.ndarray) -> dict: + """Generate LUT and quantile tables from collected histograms. 
+ + Args: + histograms: [N, 256] int32 array of bin histograms + + Returns: + dict with 'lut_tables' and 'aggregate_lut' + """ + N = histograms.shape[0] + lut_tables = np.zeros((N, 256), dtype=np.uint8) + + for i in range(N): + lut_tables[i] = compute_lut_from_histogram(histograms[i]) + + # Aggregate: average histogram across all samples + avg_histogram = histograms.mean(axis=0) + aggregate_lut = compute_lut_from_histogram(avg_histogram) + + return { + 'lut_tables': lut_tables, + 'aggregate_lut': aggregate_lut, + } + + +def main(): + parser = argparse.ArgumentParser( + description="Profile TopK bin distribution and generate mapping tables") + parser.add_argument("--output", type=str, default="mapping_tables.npz", + help="Output .npz file path") + parser.add_argument("--histograms-input", type=str, default=None, + help="Load pre-collected histograms from .npy file instead of running inference") + parser.add_argument("--scores-input", type=str, default=None, + help="Load pre-collected raw scores from .npy for quantile computation") + args = parser.parse_args() + + results = {} + + if args.histograms_input: + print(f"Loading histograms from {args.histograms_input}") + histograms = np.load(args.histograms_input) + if histograms.ndim == 1: + histograms = histograms.reshape(1, -1) + results['raw_histograms'] = histograms + + tables = generate_tables_from_histograms(histograms) + results.update(tables) + + if args.scores_input: + print(f"Loading scores from {args.scores_input}") + scores = np.load(args.scores_input) + quantiles = compute_quantiles_from_scores(scores.flatten()) + results['quantile_table'] = quantiles + + if not results: + print("No input provided. 
Use --histograms-input or --scores-input.") + print("\nTo collect histograms, use the topk_profile_histogram() function from vortex_torch_C:") + print(" from vortex_torch_C import topk_profile_histogram") + print(" histograms = torch.zeros(eff_batch_size, 256, dtype=torch.int32, device='cuda')") + print(" topk_profile_histogram(scores, dense_kv_indptr, histograms, eff_batch_size, bos, eos)") + print(" np.save('histograms.npy', histograms.cpu().numpy())") + return + + np.savez(args.output, **results) + print(f"Saved mapping tables to {args.output}") + for key, val in results.items(): + print(f" {key}: shape={val.shape}, dtype={val.dtype}") + + +if __name__ == "__main__": + main() diff --git a/csrc/register.cc b/csrc/register.cc index 532fcdf..0067474 100644 --- a/csrc/register.cc +++ b/csrc/register.cc @@ -8,7 +8,24 @@ PYBIND11_MODULE(vortex_torch_C, m){ m.def("Chunkwise_NH2HN_Transpose", &Chunkwise_NH2HN_Transpose); m.def("Chunkwise_HN2NH_Transpose", &Chunkwise_HN2NH_Transpose); m.def("topk_output", &topk_output); - m.def("topk_output_sglang", &topk_output_sglang); + m.def("topk_output_sglang", &topk_output_sglang, + py::arg("x"), py::arg("dense_kv_indptr"), py::arg("sparse_kv_indptr"), + py::arg("dense_kv_indices"), py::arg("sparse_kv_indices"), + py::arg("eff_batch_size"), py::arg("topk_val"), + py::arg("reserved_bos"), py::arg("reserved_eos"), + py::arg("max_num_pages"), + py::arg("mapping_mode") = 0, + py::arg("mapping_power") = 0.5, + py::arg("mapping_lut") = py::none(), + py::arg("mapping_quantiles") = py::none()); + m.def("topk_profile_histogram", &topk_profile_histogram, + py::arg("x"), py::arg("dense_kv_indptr"), + py::arg("histograms"), py::arg("eff_batch_size"), + py::arg("reserved_bos"), py::arg("reserved_eos"), + py::arg("mapping_mode") = 0, + py::arg("mapping_power") = 0.5, + py::arg("mapping_lut") = py::none(), + py::arg("mapping_quantiles") = py::none()); m.def("sglang_plan_decode_fa3", &sglang_plan_decode_fa3); m.def("sglang_plan_prefill_fa3", 
&sglang_plan_prefill_fa3);
 m.def("Chunkwise_HN2NH_Transpose_FA3", &Chunkwise_HN2NH_Transpose_FA3);
diff --git a/csrc/register.h b/csrc/register.h
index b81168b..d4f2d8b 100644
--- a/csrc/register.h
+++ b/csrc/register.h
@@ -95,7 +95,24 @@
const int64_t eff_batch_size,
const int64_t topk_val,
const int64_t reserved_bos,
const int64_t reserved_eos,
-const int64_t max_seq_lengths
+const int64_t max_seq_lengths,
+const int64_t mapping_mode = 0,
+const double mapping_power = 0.5,
+std::optional<at::Tensor> mapping_lut = std::nullopt,
+std::optional<at::Tensor> mapping_quantiles = std::nullopt
+);
+
+void topk_profile_histogram(
+const at::Tensor& x,
+const at::Tensor& dense_kv_indptr,
+at::Tensor& histograms,
+const int64_t eff_batch_size,
+const int64_t reserved_bos,
+const int64_t reserved_eos,
+const int64_t mapping_mode = 0,
+const double mapping_power = 0.5,
+std::optional<at::Tensor> mapping_lut = std::nullopt,
+std::optional<at::Tensor> mapping_quantiles = std::nullopt
);

void sglang_plan_decode_fa3(
diff --git a/csrc/topk.cu b/csrc/topk.cu
index 3aa49b9..70d2000 100644
--- a/csrc/topk.cu
+++ b/csrc/topk.cu
@@ -117,8 +117,8 @@ const int page_reserved_eos)
void topk_output(
const at::Tensor& x,
const at::Tensor& dense_kv_indptr,
-const at::Tensor& sparse_kv_indptr,
const at::Tensor& dense_kv_indices,
+const at::Tensor& sparse_kv_indptr,
at::Tensor& sparse_kv_indices,
const int64_t eff_batch_size,
const int64_t topk_val,
diff --git a/csrc/topk_mapping.cuh b/csrc/topk_mapping.cuh
new file mode 100644
index 0000000..e3fe3a7
--- /dev/null
+++ b/csrc/topk_mapping.cuh
@@ -0,0 +1,148 @@
+#pragma once
+#include <cstdint>
+#include <cuda_runtime.h>
+#include <cuda_bf16.h>
+
+// ============================================================
+// TopK bucket-sort distribution mapping strategies
+//
+// These transforms remap float scores before Stage 1's 8-bit
+// histogram binning, aiming for a more uniform distribution
+// across the 256 coarse bins. Stage 2 refinement still uses
+// convert_to_uint32() on raw floats, so correctness is preserved.
+// +// Modes 3/4/6/7 use a data-adaptive linear mapping to [0,255] +// instead of fp16 bit-pattern bucketing, guaranteeing full +// bucket utilization regardless of value range. +// ============================================================ + +enum TopKMappingMode { + MAPPING_NONE = 0, // Original convert_to_uint8 behavior + MAPPING_LUT_CDF = 1, // LUT-based CDF equalization + MAPPING_QUANTILE = 2, // Piecewise-linear quantile mapping + MAPPING_POWER = 3, // Monotonic power transform + MAPPING_LOG = 4, // Log transform + MAPPING_INDEX_CACHE = 5, // Sentinel: reuse previous layer's indices (Python-level skip) + MAPPING_ASINH = 6, // asinh(beta * x), beta via power_exp + MAPPING_LOG1P = 7, // sign(x) * log1p(alpha * |x|), alpha via power_exp + MAPPING_TRUNC8 = 8, // BF16 upper-8-bit bucketing +}; + +struct TopKMappingParams { + int mode; // TopKMappingMode + float power_exp; // For MAPPING_POWER (default 0.5) + const uint8_t* __restrict__ lut; // [256] byte LUT, or nullptr + const float* __restrict__ quantiles; // [256] float quantile breakpoints, or nullptr +}; + +// NOTE: convert_to_uint8() must be defined before including this header. +// It is defined in topk_sglang.cu within the anonymous namespace. 
+ +// ---- Individual transform functions (return float, no bucketing) ---- + +__device__ __forceinline__ float transform_power(float x, float p) { + return copysignf(__powf(fabsf(x), p), x); +} + +__device__ __forceinline__ float transform_log(float x) { + return copysignf(__logf(fabsf(x) + 1.0f), x); +} + +__device__ __forceinline__ float transform_asinh(float x, float beta) { + return asinhf(beta * x); +} + +__device__ __forceinline__ float transform_log1p(float x, float alpha) { + return copysignf(log1pf(alpha * fabsf(x)), x); +} + +// ---- Transform dispatcher (returns float, no bucketing) ---- + +__device__ __forceinline__ float apply_transform(float x, const TopKMappingParams& params) { + switch (params.mode) { + case MAPPING_POWER: return transform_power(x, params.power_exp); + case MAPPING_LOG: return transform_log(x); + case MAPPING_ASINH: return transform_asinh(x, params.power_exp); + case MAPPING_LOG1P: return transform_log1p(x, params.power_exp); + default: return x; + } +} + +// ---- Linear bucketing for transform modes ---- + +__device__ __forceinline__ uint8_t linear_map_to_uint8(float val, float range_min, float inv_range) { + int bin = __float2int_rd((val - range_min) * inv_range); + return static_cast(min(max(bin, 0), 255)); +} + +// ---- BF16 upper-8-bit bucketing (mode 8) ---- + +__device__ __forceinline__ uint8_t convert_to_uint8_bf16(float x) { + __nv_bfloat16 bf = __float2bfloat16_rn(x); + uint16_t bits = __bfloat16_as_ushort(bf); + uint16_t key = (bits & 0x8000) ? 
static_cast(~bits) + : static_cast(bits | 0x8000); + return static_cast(key >> 8); +} + +// ---- Non-transform mapping functions (unchanged) ---- + +// LUT-based CDF equalization: lut[original_bin] -> equalized_bin +__device__ __forceinline__ uint8_t map_lut_cdf(float x, const uint8_t* __restrict__ s_lut) { + return s_lut[convert_to_uint8(x)]; +} + +// Quantile mapping: binary search over 256 sorted thresholds +__device__ __forceinline__ uint8_t map_quantile(float x, const float* __restrict__ s_quantiles) { + // Binary search: find largest index i such that x >= s_quantiles[i] + // s_quantiles is sorted ascending, length 256 + int lo = 0, hi = 255; +#pragma unroll 8 + for (int iter = 0; iter < 8; ++iter) { + int mid = (lo + hi + 1) >> 1; + if (x >= s_quantiles[mid]) { + lo = mid; + } else { + hi = mid - 1; + } + } + return static_cast(lo); +} + +// ---- Unified dispatcher ---- +// For modes 3/4/6/7, range_min and inv_range come from a per-block pre-pass. + +__device__ __forceinline__ uint8_t mapped_convert_to_uint8( + float x, + const TopKMappingParams& params, + const uint8_t* __restrict__ s_lut, + const float* __restrict__ s_quantiles, + float range_min, + float inv_range) +{ + switch (params.mode) { + case MAPPING_LUT_CDF: + if (params.lut != nullptr) return map_lut_cdf(x, s_lut); + return convert_to_uint8(x); // fallback to mode 0 when LUT not calibrated + case MAPPING_QUANTILE: + if (params.quantiles != nullptr) return map_quantile(x, s_quantiles); + return convert_to_uint8(x); // fallback to mode 0 when quantiles not calibrated + case MAPPING_POWER: + case MAPPING_LOG: + case MAPPING_ASINH: + case MAPPING_LOG1P: { + float val = apply_transform(x, params); + return linear_map_to_uint8(val, range_min, inv_range); + } + case MAPPING_TRUNC8: + return convert_to_uint8_bf16(x); + default: // MAPPING_NONE + return convert_to_uint8(x); + } +} + +// Helper: check if a mapping mode needs the auto-range pre-pass +__device__ __forceinline__ bool needs_auto_range(int 
mode) { + return (mode == MAPPING_POWER || mode == MAPPING_LOG || + mode == MAPPING_ASINH || mode == MAPPING_LOG1P); +} diff --git a/csrc/topk_sglang.cu b/csrc/topk_sglang.cu index 314f0fd..5959270 100644 --- a/csrc/topk_sglang.cu +++ b/csrc/topk_sglang.cu @@ -6,821 +6,1090 @@ * 2. optimize the performance a little * 3. fix the potential illegal memory access */ - #include - #include - #include - #include - #include - #include - #include - #include - #include - - #include - #include - #include - - namespace { - - constexpr int TopK = 2048; - constexpr int kThreadsPerBlock = 1024; - - #ifdef USE_ROCM - // On ROCm, the per-workgroup LDS budget depends on the target arch, so we inject a - // per-arch value from `setup_rocm.py` via `-DSGL_TOPK_DYNAMIC_SMEM_BYTES=...`. - #ifdef SGL_TOPK_DYNAMIC_SMEM_BYTES - constexpr size_t kSmem = static_cast(SGL_TOPK_DYNAMIC_SMEM_BYTES); - #else - constexpr size_t kSmem = 48 * 1024; // bytes - #endif - #else - // Reduced from 128KB to 32KB to improve occupancy. - // Each radix pass needs at most ~TopK candidates in the threshold bin, - // so 4K entries per round (2 rounds = 8K entries = 32KB) is sufficient. - constexpr size_t kSmem = 8 * 1024 * sizeof(uint32_t); // 32KB (bytes) - #endif - - struct FastTopKParams { - const float* __restrict__ input; // [B, input_stride] - const int32_t* __restrict__ row_starts; // [B] - int32_t* __restrict__ indices; // [B, TopK] - int32_t* __restrict__ lengths; // [B] - int64_t input_stride; - }; - - // when length <= TopK, we can directly write the indices - __device__ void naive_topk_cuda(const float* __restrict__ score, int32_t* __restrict__ indice, int32_t length) { - const auto tid = threadIdx.x; - for (int i = tid; i < TopK; i += kThreadsPerBlock) { - indice[i] = (i < length) ? 
i : -1; - } - } - - // keep the first `length` entries, set others to -1 - __device__ void naive_topk_transform( - const float* __restrict__ score, - int32_t length, - int32_t* __restrict__ dst_page_table, - const int32_t* __restrict__ src_page_table) { - const auto tid = threadIdx.x; - for (auto i = tid; i < TopK; i += kThreadsPerBlock) { - dst_page_table[i] = (i < length) ? src_page_table[i] : -1; - } - } - - // keep the first `length` entries, set others to -1 - __device__ void naive_topk_transform_ragged( - const float* __restrict__ score, int32_t length, int32_t* __restrict__ topk_indices_ragged, int32_t offset) { - const auto tid = threadIdx.x; - for (auto i = tid; i < TopK; i += kThreadsPerBlock) { - topk_indices_ragged[i] = (i < length) ? static_cast(i) + offset : -1; - } - } - - __device__ __forceinline__ auto convert_to_uint8(float x) -> uint8_t { - __half h = __float2half_rn(x); - uint16_t bits = __half_as_ushort(h); - uint16_t key = (bits & 0x8000) ? static_cast(~bits) : static_cast(bits | 0x8000); - return static_cast(key >> 8); - } - - __device__ __forceinline__ auto convert_to_uint32(float x) -> uint32_t { - uint32_t bits = __float_as_uint(x); - return (bits & 0x80000000u) ? 
~bits : (bits | 0x80000000u); - } - - __device__ void fast_topk_cuda_tl(const float* __restrict__ input, int* __restrict__ index, int row_start, int length) { - // An optimized topk kernel copied from tilelang kernel - // We assume length > TopK here, or it will crash - int topk = TopK; - constexpr auto BLOCK_SIZE = 1024; - constexpr auto RADIX = 256; - constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); - - alignas(128) __shared__ int s_histogram_buf[2][RADIX + 128]; - alignas(128) __shared__ int s_counter; - alignas(128) __shared__ int s_threshold_bin_id; - alignas(128) __shared__ int s_num_input[2]; - - auto& s_histogram = s_histogram_buf[0]; - // allocate for two rounds - extern __shared__ int s_input_idx[][SMEM_INPUT_SIZE]; - - const int tx = threadIdx.x; - - // stage 1: 8bit coarse histogram - if (tx < RADIX + 1) s_histogram[tx] = 0; - __syncthreads(); - - for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - const auto bin = convert_to_uint8(input[idx + row_start]); - ::atomicAdd(&s_histogram[bin], 1); - } - __syncthreads(); - - const auto run_cumsum = [&] { - #pragma unroll 8 - for (int i = 0; i < 8; ++i) { - static_assert(1 << 8 == RADIX); - if (C10_LIKELY(tx < RADIX)) { - const auto j = 1 << i; - const auto k = i & 1; - auto value = s_histogram_buf[k][tx]; - if (tx < RADIX - j) { - value += s_histogram_buf[k][tx + j]; - } - s_histogram_buf[k ^ 1][tx] = value; - } - __syncthreads(); - } - }; - - run_cumsum(); - if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { - s_threshold_bin_id = tx; - s_num_input[0] = 0; - s_counter = 0; - } - __syncthreads(); - - const auto threshold_bin = s_threshold_bin_id; - topk -= s_histogram[threshold_bin + 1]; - - if (topk == 0) { - for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - const auto bin = static_cast(convert_to_uint8(input[idx + row_start])); - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&s_counter, 1); - index[pos] = idx; - } - } - __syncthreads(); - return; - } 
else { - __syncthreads(); - if (tx < RADIX + 1) { - s_histogram[tx] = 0; - } - __syncthreads(); - - for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - const auto raw_input = input[idx + row_start]; - const auto bin = static_cast(convert_to_uint8(raw_input)); - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&s_counter, 1); - index[pos] = idx; - } else if (bin == threshold_bin) { - const auto pos = ::atomicAdd(&s_num_input[0], 1); - /// NOTE: (dark) fuse the histogram computation here - if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { - s_input_idx[0][pos] = idx; - const auto bin = convert_to_uint32(raw_input); - const auto sub_bin = (bin >> 24) & 0xFF; - ::atomicAdd(&s_histogram[sub_bin], 1); - } - } - } - __syncthreads(); - } - - // stage 2: refine with 8bit radix passes - #pragma unroll 4 - for (int round = 0; round < 4; ++round) { - __shared__ int s_last_remain; - const auto r_idx = round % 2; - - // clip here to prevent overflow - const auto _raw_num_input = s_num_input[r_idx]; - const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) ? 
_raw_num_input : int(SMEM_INPUT_SIZE); - - run_cumsum(); - if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { - s_threshold_bin_id = tx; - s_num_input[r_idx ^ 1] = 0; - s_last_remain = topk - s_histogram[tx + 1]; - } - __syncthreads(); - - const auto threshold_bin = s_threshold_bin_id; - topk -= s_histogram[threshold_bin + 1]; - - if (topk == 0) { - for (int i = tx; i < num_input; i += BLOCK_SIZE) { - const auto idx = s_input_idx[r_idx][i]; - const auto offset = 24 - round * 8; - const auto bin = (convert_to_uint32(input[idx + row_start]) >> offset) & 0xFF; - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&s_counter, 1); - index[pos] = idx; - } - } - __syncthreads(); - break; - } else { - __syncthreads(); - if (tx < RADIX + 1) { - s_histogram[tx] = 0; - } - __syncthreads(); - for (int i = tx; i < num_input; i += BLOCK_SIZE) { - const auto idx = s_input_idx[r_idx][i]; - const auto raw_input = input[idx + row_start]; - const auto offset = 24 - round * 8; - const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF; - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&s_counter, 1); - index[pos] = idx; - } else if (bin == threshold_bin) { - if (round == 3) { - const auto pos = ::atomicAdd(&s_last_remain, -1); - if (pos > 0) { - index[TopK - pos] = idx; - } - } else { - const auto pos = ::atomicAdd(&s_num_input[r_idx ^ 1], 1); - if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { - /// NOTE: (dark) fuse the histogram computation here - s_input_idx[r_idx ^ 1][pos] = idx; - const auto bin = convert_to_uint32(raw_input); - const auto sub_bin = (bin >> (offset - 8)) & 0xFF; - ::atomicAdd(&s_histogram[sub_bin], 1); - } - } - } - } - __syncthreads(); - } - } - } - - __global__ __launch_bounds__(kThreadsPerBlock) // topk - void topk_kernel(const FastTopKParams params) { - const auto& [input, row_starts, indices, lengths, input_stride] = params; - const auto bid = static_cast(blockIdx.x); - const auto row_start = row_starts == nullptr ? 
0 : row_starts[bid]; - const auto length = lengths[bid]; - const auto indice = indices + bid * TopK; - const auto score = input + bid * input_stride; - if (length <= TopK) { - return naive_topk_cuda(score, indice, length); - } else { - return fast_topk_cuda_tl(score, indice, row_start, length); - } - } - - __global__ __launch_bounds__(kThreadsPerBlock) // decode - void topk_transform_decode_kernel( - const FastTopKParams params, - int32_t* __restrict__ dst_page_table, - const int32_t* __restrict__ src_page_table, - const int64_t src_stride) { - const auto& [input, _1, _2, lengths, input_stride] = params; - const auto bid = static_cast(blockIdx.x); - const auto tid = threadIdx.x; - const auto row_start = 0; - const auto length = lengths[bid]; - const auto src_page_entry = src_page_table + bid * src_stride; - const auto dst_page_entry = dst_page_table + bid * TopK; - const auto score = input + bid * input_stride; - if (length <= TopK) { - return naive_topk_transform(score, length, dst_page_entry, src_page_entry); - } else { - __shared__ int s_indices[TopK]; - fast_topk_cuda_tl(score, s_indices, row_start, length); - // copy src[s_indices] to dst, we manually unroll here - static_assert(TopK % kThreadsPerBlock == 0); - static_assert(TopK / kThreadsPerBlock == 2); - const auto idx_0 = tid; - const auto pos_0 = s_indices[idx_0]; - dst_page_entry[idx_0] = src_page_entry[pos_0]; - const auto idx_1 = tid + kThreadsPerBlock; - const auto pos_1 = s_indices[idx_1]; - dst_page_entry[idx_1] = src_page_entry[pos_1]; - } - } - - __global__ __launch_bounds__(kThreadsPerBlock) // prefill - void topk_transform_prefill_kernel( - const FastTopKParams params, - int32_t* __restrict__ dst_page_table, - const int32_t* __restrict__ src_page_table, - const int64_t src_stride, - const int32_t* __restrict__ cu_seqlens_q, - const int64_t prefill_bs) { - const auto& [input, row_starts, _, lengths, input_stride] = params; - const auto bid = static_cast(blockIdx.x); - const auto tid = 
threadIdx.x; - const auto length = lengths[bid]; - const auto row_start = row_starts == nullptr ? 0 : row_starts[bid]; - const auto dst_page_entry = dst_page_table + bid * TopK; - const auto score = input + bid * input_stride; - - /// NOTE: prefill bs is usually small, we can just use a simple loop here - /// We ensure that last cu_seqlens is equal to number of blocks launched - __shared__ const int32_t* s_src_page_entry; - if (C10_LIKELY(prefill_bs <= kThreadsPerBlock)) { - if (tid < prefill_bs) { - if (bid >= cu_seqlens_q[tid] && bid < cu_seqlens_q[tid + 1]) { - s_src_page_entry = src_page_table + tid * src_stride; - } - } - } else { - for (int64_t i = tid; i < prefill_bs; i += kThreadsPerBlock) { - if (bid >= cu_seqlens_q[i] && bid < cu_seqlens_q[i + 1]) { - s_src_page_entry = src_page_table + i * src_stride; - } - } - } - __syncthreads(); - const auto src_page_entry = s_src_page_entry; - - if (length <= TopK) { - return naive_topk_transform(score, length, dst_page_entry, src_page_entry); - } else { - __shared__ int s_indices[TopK]; - fast_topk_cuda_tl(score, s_indices, row_start, length); - // copy src[s_indices] to dst, we manually unroll here - static_assert(TopK % kThreadsPerBlock == 0); - static_assert(TopK / kThreadsPerBlock == 2); - const auto idx_0 = tid; - const auto pos_0 = s_indices[idx_0]; - dst_page_entry[idx_0] = src_page_entry[pos_0]; - const auto idx_1 = tid + kThreadsPerBlock; - const auto pos_1 = s_indices[idx_1]; - dst_page_entry[idx_1] = src_page_entry[pos_1]; - } - } - - __global__ __launch_bounds__(kThreadsPerBlock) // prefill, ragged kv - void topk_transform_prefill_ragged_kernel( - const FastTopKParams params, - int32_t* __restrict__ topk_indices_ragged, - const int32_t* __restrict__ topk_indices_offset) { - const auto& [input, row_starts, _, lengths, input_stride] = params; - const auto bid = static_cast(blockIdx.x); - const auto tid = threadIdx.x; - const auto row_start = row_starts == nullptr ? 
0 : row_starts[bid]; - const auto length = lengths[bid]; - const auto dst_indices_entry = topk_indices_ragged + bid * TopK; - const auto score = input + bid * input_stride; - const auto offset = topk_indices_offset[bid]; - - if (length <= TopK) { - return naive_topk_transform_ragged(score, length, dst_indices_entry, offset); - } else { - __shared__ int s_indices[TopK]; - fast_topk_cuda_tl(score, s_indices, row_start, length); - // copy src[s_indices] to dst, we manually unroll here - static_assert(TopK % kThreadsPerBlock == 0); - static_assert(TopK / kThreadsPerBlock == 2); - const auto idx_0 = tid; - const auto pos_0 = s_indices[idx_0]; - dst_indices_entry[idx_0] = pos_0 + offset; - const auto idx_1 = tid + kThreadsPerBlock; - const auto pos_1 = s_indices[idx_1]; - dst_indices_entry[idx_1] = pos_1 + offset; - } - } - - auto get_params( - const at::Tensor& score, - const at::Tensor& lengths, - std::optional row_starts_opt = std::nullopt, - std::optional indices_opt = std::nullopt) -> FastTopKParams { - const auto B = score.size(0); - TORCH_CHECK(score.dim() == 2 && score.stride(1) == 1); - if (row_starts_opt.has_value()) { - const auto& row_starts = row_starts_opt.value(); - TORCH_CHECK(row_starts.dim() == 1); - TORCH_CHECK(row_starts.size(0) == B); - } - TORCH_CHECK(lengths.dim() == 1 && lengths.is_contiguous()); - TORCH_CHECK(lengths.size(0) == B); - int32_t* indices_data_ptr = nullptr; - if (indices_opt.has_value()) { - const auto& indices = indices_opt.value(); - TORCH_CHECK(indices.dim() == 2 && indices.is_contiguous()); - TORCH_CHECK(indices.size(0) == B); - TORCH_CHECK(indices.size(1) == TopK); - indices_data_ptr = indices.data_ptr(); - } - - return FastTopKParams{ - .input = score.data_ptr(), - .row_starts = row_starts_opt.has_value() ? 
row_starts_opt->data_ptr() : nullptr, - .indices = indices_data_ptr, - .lengths = lengths.data_ptr(), - .input_stride = score.stride(0), - }; - } - - template - void setup_kernel_smem_once() { - [[maybe_unused]] - static const auto result = [] { - #ifdef USE_ROCM - // hipify will turn cudaFuncSetAttribute -> hipFuncSetAttribute. On ROCm, - // hipFuncSetAttribute expects `const void*` and hipcc does not accept passing - // a function pointer directly, so cast explicitly. - return ::cudaFuncSetAttribute( - reinterpret_cast(f), ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); - #else - // CUDA: keep original behavior (no cast needed). - return ::cudaFuncSetAttribute(f, ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); - #endif - }(); - TORCH_CHECK(result == cudaSuccess, "set_up_kernel_once failed:", ::cudaGetErrorString(result)); - } - - // ====================================================================== - // Vortex integration: BOS/EOS-aware segmented TopK with index remapping - // ====================================================================== - - template - __device__ __forceinline__ float vortex_to_float(T x); - - template <> - __device__ __forceinline__ float vortex_to_float(float x) { return x; } - - template <> - __device__ __forceinline__ float vortex_to_float<__nv_bfloat16>(__nv_bfloat16 x) { - return __bfloat162float(x); - } - - constexpr int VORTEX_MAX_TOPK = 2048; - - // Templated version of fast_topk_cuda_tl: - // - ScoreT: float or __nv_bfloat16 - // - target_k: runtime parameter (replaces compile-time TopK) - template - __device__ void fast_topk_vortex( - const ScoreT* __restrict__ input, - int* __restrict__ index, - int row_start, - int length, - int target_k) - { - int topk = target_k; - constexpr auto BLOCK_SIZE = 1024; - constexpr auto RADIX = 256; - constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); - - alignas(128) __shared__ int vh_histogram_buf[2][RADIX + 128]; - alignas(128) __shared__ int 
vh_counter; - alignas(128) __shared__ int vh_threshold_bin_id; - alignas(128) __shared__ int vh_num_input[2]; - - auto& vh_histogram = vh_histogram_buf[0]; - extern __shared__ int vh_input_idx[][SMEM_INPUT_SIZE]; - - const int tx = threadIdx.x; - - // Stage 1: 8-bit coarse histogram - if (tx < RADIX + 1) vh_histogram[tx] = 0; - __syncthreads(); - - for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - const auto bin = convert_to_uint8(vortex_to_float(input[idx + row_start])); - ::atomicAdd(&vh_histogram[bin], 1); - } - __syncthreads(); - - const auto run_cumsum = [&] { - #pragma unroll 8 - for (int i = 0; i < 8; ++i) { - static_assert(1 << 8 == RADIX); - if (C10_LIKELY(tx < RADIX)) { - const auto j = 1 << i; - const auto k = i & 1; - auto value = vh_histogram_buf[k][tx]; - if (tx < RADIX - j) { - value += vh_histogram_buf[k][tx + j]; - } - vh_histogram_buf[k ^ 1][tx] = value; - } - __syncthreads(); - } - }; - - run_cumsum(); - if (tx < RADIX && vh_histogram[tx] > topk && vh_histogram[tx + 1] <= topk) { - vh_threshold_bin_id = tx; - vh_num_input[0] = 0; - vh_counter = 0; - } - __syncthreads(); - - const auto threshold_bin = vh_threshold_bin_id; - topk -= vh_histogram[threshold_bin + 1]; - - if (topk == 0) { - for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - const auto bin = static_cast( - convert_to_uint8(vortex_to_float(input[idx + row_start]))); - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&vh_counter, 1); - index[pos] = idx; - } - } - __syncthreads(); - return; - } else { - __syncthreads(); - if (tx < RADIX + 1) vh_histogram[tx] = 0; - __syncthreads(); - - for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - const auto raw_input = vortex_to_float(input[idx + row_start]); - const auto bin = static_cast(convert_to_uint8(raw_input)); - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&vh_counter, 1); - index[pos] = idx; - } else if (bin == threshold_bin) { - const auto pos = ::atomicAdd(&vh_num_input[0], 1); - if (C10_LIKELY(pos < 
SMEM_INPUT_SIZE)) { - vh_input_idx[0][pos] = idx; - const auto b32 = convert_to_uint32(raw_input); - const auto sub_bin = (b32 >> 24) & 0xFF; - ::atomicAdd(&vh_histogram[sub_bin], 1); - } - } - } - __syncthreads(); - } - - // Stage 2: refine with 8-bit radix passes - #pragma unroll 4 - for (int round = 0; round < 4; ++round) { - __shared__ int vh_last_remain; - const auto r_idx = round % 2; - - const auto _raw_num_input = vh_num_input[r_idx]; - const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) - ? _raw_num_input - : int(SMEM_INPUT_SIZE); - - run_cumsum(); - if (tx < RADIX && vh_histogram[tx] > topk && vh_histogram[tx + 1] <= topk) { - vh_threshold_bin_id = tx; - vh_num_input[r_idx ^ 1] = 0; - vh_last_remain = topk - vh_histogram[tx + 1]; - } - __syncthreads(); - - const auto threshold_bin = vh_threshold_bin_id; - topk -= vh_histogram[threshold_bin + 1]; - - if (topk == 0) { - for (int i = tx; i < num_input; i += BLOCK_SIZE) { - const auto idx = vh_input_idx[r_idx][i]; - const auto offset = 24 - round * 8; - const auto bin = (convert_to_uint32( - vortex_to_float(input[idx + row_start])) >> offset) & 0xFF; - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&vh_counter, 1); - index[pos] = idx; - } - } - __syncthreads(); - break; - } else { - __syncthreads(); - if (tx < RADIX + 1) vh_histogram[tx] = 0; - __syncthreads(); - for (int i = tx; i < num_input; i += BLOCK_SIZE) { - const auto idx = vh_input_idx[r_idx][i]; - const auto raw_input = vortex_to_float(input[idx + row_start]); - const auto offset = 24 - round * 8; - const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF; - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&vh_counter, 1); - index[pos] = idx; - } else if (bin == threshold_bin) { - if (round == 3) { - const auto pos = ::atomicAdd(&vh_last_remain, -1); - if (pos > 0) { - index[target_k - pos] = idx; - } - } else { - const auto pos = ::atomicAdd(&vh_num_input[r_idx ^ 1], 1); - if (C10_LIKELY(pos < 
SMEM_INPUT_SIZE)) { - vh_input_idx[r_idx ^ 1][pos] = idx; - const auto b32 = convert_to_uint32(raw_input); - const auto sub_bin = (b32 >> (offset - 8)) & 0xFF; - ::atomicAdd(&vh_histogram[sub_bin], 1); - } - } - } - } - __syncthreads(); - } - } - } - - // Wrapper kernel: one CUDA block per batch*head segment - template - __global__ __launch_bounds__(kThreadsPerBlock) - void TopKOutput_Kernel( - const ScoreT* __restrict__ score, - const int* __restrict__ dense_kv_indptr, - const int* __restrict__ sparse_kv_indptr, - const int* __restrict__ dense_kv_indices, - int* __restrict__ sparse_kv_indices, - const int topk_val, - const int page_reserved_bos, - const int page_reserved_eos) - { - const int bx = blockIdx.x; - - const int start = dense_kv_indptr[bx] + page_reserved_bos; - const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; - const int nblk = end - start; - if (nblk <= topk_val) return; - - const ScoreT* __restrict__ score_blk = score + start; - const int* __restrict__ idx_blk = dense_kv_indices + start; - int* __restrict__ out_blk = sparse_kv_indices - + sparse_kv_indptr[bx] - + page_reserved_bos; - - __shared__ int s_indices[VORTEX_MAX_TOPK]; - fast_topk_vortex(score_blk, s_indices, 0, nblk, topk_val); - __syncthreads(); - - // Remap position indices -> page indices via dense_kv_indices - const int tx = threadIdx.x; - for (int i = tx; i < topk_val; i += kThreadsPerBlock) { - out_blk[i] = idx_blk[s_indices[i]]; - } - } - - } // namespace - - #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") - - void fast_topk_interface( - const at::Tensor& score, at::Tensor& indices, const at::Tensor& lengths, std::optional row_starts_opt) { - CHECK_CUDA(score); - CHECK_CUDA(indices); - if (row_starts_opt.has_value()) { - CHECK_CUDA(row_starts_opt.value()); - } - CHECK_CUDA(lengths); - const auto params = get_params(score, lengths, row_starts_opt, indices); - const auto B = score.size(0); - const auto stream = 
at::cuda::getCurrentCUDAStream().stream(); - const auto grid = dim3{static_cast(B)}; - const auto block = dim3{kThreadsPerBlock}; - setup_kernel_smem_once(); - topk_kernel<<>>(params); - const auto result = cudaGetLastError(); - TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); - } - - void fast_topk_transform_interface( - const at::Tensor& score, - const at::Tensor& lengths, - at::Tensor& dst_page_table, - const at::Tensor& src_page_table, - const at::Tensor& cu_seqlens_q, - std::optional row_starts_opt) { - CHECK_CUDA(score); - CHECK_CUDA(lengths); - CHECK_CUDA(dst_page_table); - CHECK_CUDA(src_page_table); - CHECK_CUDA(cu_seqlens_q); - if (row_starts_opt.has_value()) { - CHECK_CUDA(row_starts_opt.value()); - } - const auto params = get_params(score, lengths, row_starts_opt); - const auto B = score.size(0); - TORCH_CHECK(dst_page_table.dim() == 2 && dst_page_table.is_contiguous()); - TORCH_CHECK(src_page_table.dim() == 2 && src_page_table.stride(1) == 1); - TORCH_CHECK(cu_seqlens_q.dim() == 1 && cu_seqlens_q.is_contiguous()); - const auto prefill_bs = cu_seqlens_q.size(0) - 1; - TORCH_CHECK(dst_page_table.size(0) == B); - TORCH_CHECK(dst_page_table.size(1) == TopK); - TORCH_CHECK(src_page_table.size(0) == prefill_bs); - TORCH_CHECK(prefill_bs <= B); // prefill_bs should be smaller than expanded bs - - // launch kernel - const auto stream = at::cuda::getCurrentCUDAStream().stream(); - const auto grid = dim3{static_cast(B)}; - const auto block = dim3{kThreadsPerBlock}; - const auto src_stride = src_page_table.stride(0); - - // dispatch to decode or prefill - // extend and draft extend: row_starts_opt is not null, invokes the prefill kernel - // decode: row_starts_opt is null, invokes the decode kernel - // target verify: row_starts_opt is null, invokes the prefill kernel - const auto is_decode = !row_starts_opt.has_value() && prefill_bs == B; - if (is_decode) { - setup_kernel_smem_once(); - topk_transform_decode_kernel<<>>( - 
params, dst_page_table.data_ptr(), src_page_table.data_ptr(), src_stride); - } else { - setup_kernel_smem_once(); - topk_transform_prefill_kernel<<>>( - params, - dst_page_table.data_ptr(), - src_page_table.data_ptr(), - src_stride, - cu_seqlens_q.data_ptr(), - prefill_bs); - } - - const auto result = cudaGetLastError(); - TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); - } - - void fast_topk_transform_ragged_interface( - const at::Tensor& score, - const at::Tensor& lengths, - at::Tensor& topk_indices_ragged, - const at::Tensor& topk_indices_offset, - std::optional row_starts_opt) { - CHECK_CUDA(score); - CHECK_CUDA(lengths); - CHECK_CUDA(topk_indices_ragged); - CHECK_CUDA(topk_indices_offset); - if (row_starts_opt.has_value()) { - CHECK_CUDA(row_starts_opt.value()); - } - - const auto params = get_params(score, lengths, row_starts_opt); - const auto B = score.size(0); - TORCH_CHECK(topk_indices_ragged.dim() == 2 && topk_indices_ragged.is_contiguous()); - TORCH_CHECK(topk_indices_offset.dim() == 1); - - TORCH_CHECK(topk_indices_ragged.size(0) == B); - TORCH_CHECK(topk_indices_ragged.size(1) == TopK); - TORCH_CHECK(topk_indices_offset.size(0) == B); - - // launch kernel - const auto stream = at::cuda::getCurrentCUDAStream().stream(); - const auto grid = dim3{static_cast(B)}; - const auto block = dim3{kThreadsPerBlock}; - - setup_kernel_smem_once(); - topk_transform_prefill_ragged_kernel<<>>( - params, topk_indices_ragged.data_ptr(), topk_indices_offset.data_ptr()); - - const auto result = cudaGetLastError(); - TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); - } - - // ====================================================================== - // Vortex host entry point — same interface as topk_output in topk.cu - // ====================================================================== - void topk_output_sglang( - const at::Tensor& x, - const at::Tensor& dense_kv_indptr, - const 
at::Tensor& sparse_kv_indptr, - const at::Tensor& dense_kv_indices, - at::Tensor& sparse_kv_indices, - const int64_t eff_batch_size, - const int64_t topk_val, - const int64_t reserved_bos, - const int64_t reserved_eos, - const int64_t max_num_pages) - { - TORCH_CHECK(topk_val <= VORTEX_MAX_TOPK, - "topk_output: topk_val (", topk_val, - ") exceeds VORTEX_MAX_TOPK (", VORTEX_MAX_TOPK, ")"); - - dim3 nblks(eff_batch_size); - dim3 nthreads(kThreadsPerBlock); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - if (x.scalar_type() == at::ScalarType::BFloat16) { - setup_kernel_smem_once, kSmem>(); - TopKOutput_Kernel<__nv_bfloat16><<>>( - reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), - dense_kv_indptr.data_ptr(), - sparse_kv_indptr.data_ptr(), - dense_kv_indices.data_ptr(), - sparse_kv_indices.data_ptr(), - topk_val, - reserved_bos, - reserved_eos); - } else if (x.scalar_type() == at::ScalarType::Float) { - setup_kernel_smem_once, kSmem>(); - TopKOutput_Kernel<<>>( - x.data_ptr(), - dense_kv_indptr.data_ptr(), - sparse_kv_indptr.data_ptr(), - dense_kv_indices.data_ptr(), - sparse_kv_indices.data_ptr(), - topk_val, - reserved_bos, - reserved_eos); - } else { - TORCH_CHECK(false, - "topk_output: unsupported dtype ", - x.scalar_type()); - } - - const auto result = cudaGetLastError(); - TORCH_CHECK(result == cudaSuccess, - "topk_output kernel failed: ", ::cudaGetErrorString(result)); - } \ No newline at end of file +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace { + +constexpr int TopK = 2048; +constexpr int kThreadsPerBlock = 1024; + +#ifdef USE_ROCM +// On ROCm, the per-workgroup LDS budget depends on the target arch, so we inject a +// per-arch value from `setup_rocm.py` via `-DSGL_TOPK_DYNAMIC_SMEM_BYTES=...`. 
+#ifdef SGL_TOPK_DYNAMIC_SMEM_BYTES +constexpr size_t kSmem = static_cast(SGL_TOPK_DYNAMIC_SMEM_BYTES); +#else +constexpr size_t kSmem = 48 * 1024; // bytes +#endif +#else +// Reduced from 128KB to 32KB to improve occupancy. +// Each radix pass needs at most ~TopK candidates in the threshold bin, +// so 4K entries per round (2 rounds = 8K entries = 32KB) is sufficient. +constexpr size_t kSmem = 8 * 1024 * sizeof(uint32_t); // 32KB (bytes) +#endif + +struct FastTopKParams { + const float* __restrict__ input; // [B, input_stride] + const int32_t* __restrict__ row_starts; // [B] + int32_t* __restrict__ indices; // [B, TopK] + int32_t* __restrict__ lengths; // [B] + int64_t input_stride; +}; + +// when length <= TopK, we can directly write the indices +__device__ void naive_topk_cuda(const float* __restrict__ score, int32_t* __restrict__ indice, int32_t length) { + const auto tid = threadIdx.x; + for (int i = tid; i < TopK; i += kThreadsPerBlock) { + indice[i] = (i < length) ? i : -1; + } +} + +// keep the first `length` entries, set others to -1 +__device__ void naive_topk_transform( + const float* __restrict__ score, + int32_t length, + int32_t* __restrict__ dst_page_table, + const int32_t* __restrict__ src_page_table) { + const auto tid = threadIdx.x; + for (auto i = tid; i < TopK; i += kThreadsPerBlock) { + dst_page_table[i] = (i < length) ? src_page_table[i] : -1; + } +} + +// keep the first `length` entries, set others to -1 +__device__ void naive_topk_transform_ragged( + const float* __restrict__ score, int32_t length, int32_t* __restrict__ topk_indices_ragged, int32_t offset) { + const auto tid = threadIdx.x; + for (auto i = tid; i < TopK; i += kThreadsPerBlock) { + topk_indices_ragged[i] = (i < length) ? static_cast(i) + offset : -1; + } +} + +__device__ __forceinline__ auto convert_to_uint8(float x) -> uint8_t { + __half h = __float2half_rn(x); + uint16_t bits = __half_as_ushort(h); + uint16_t key = (bits & 0x8000) ? 
static_cast(~bits) : static_cast(bits | 0x8000); + return static_cast(key >> 8); +} + +__device__ __forceinline__ auto convert_to_uint32(float x) -> uint32_t { + uint32_t bits = __float_as_uint(x); + return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u); +} + +// Include mapping strategies (must come after convert_to_uint8 definition) +#include "topk_mapping.cuh" + +__device__ void fast_topk_cuda_tl(const float* __restrict__ input, int* __restrict__ index, int row_start, int length) { + // An optimized topk kernel copied from tilelang kernel + // We assume length > TopK here, or it will crash + int topk = TopK; + constexpr auto BLOCK_SIZE = 1024; + constexpr auto RADIX = 256; + constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); + + alignas(128) __shared__ int s_histogram_buf[2][RADIX + 128]; + alignas(128) __shared__ int s_counter; + alignas(128) __shared__ int s_threshold_bin_id; + alignas(128) __shared__ int s_num_input[2]; + + auto& s_histogram = s_histogram_buf[0]; + // allocate for two rounds + extern __shared__ int s_input_idx[][SMEM_INPUT_SIZE]; + + const int tx = threadIdx.x; + + // stage 1: 8bit coarse histogram + if (tx < RADIX + 1) s_histogram[tx] = 0; + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = convert_to_uint8(input[idx + row_start]); + ::atomicAdd(&s_histogram[bin], 1); + } + __syncthreads(); + + const auto run_cumsum = [&] { +#pragma unroll 8 + for (int i = 0; i < 8; ++i) { + static_assert(1 << 8 == RADIX); + if (C10_LIKELY(tx < RADIX)) { + const auto j = 1 << i; + const auto k = i & 1; + auto value = s_histogram_buf[k][tx]; + if (tx < RADIX - j) { + value += s_histogram_buf[k][tx + j]; + } + s_histogram_buf[k ^ 1][tx] = value; + } + __syncthreads(); + } + }; + + run_cumsum(); + if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { + s_threshold_bin_id = tx; + s_num_input[0] = 0; + s_counter = 0; + } + __syncthreads(); + + const auto threshold_bin = 
s_threshold_bin_id; + topk -= s_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = static_cast(convert_to_uint8(input[idx + row_start])); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + return; + } else { + __syncthreads(); + if (tx < RADIX + 1) { + s_histogram[tx] = 0; + } + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto raw_input = input[idx + row_start]; + const auto bin = static_cast(convert_to_uint8(raw_input)); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + const auto pos = ::atomicAdd(&s_num_input[0], 1); + /// NOTE: (dark) fuse the histogram computation here + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + s_input_idx[0][pos] = idx; + const auto bin = convert_to_uint32(raw_input); + const auto sub_bin = (bin >> 24) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } + } + __syncthreads(); + } + + // stage 2: refine with 8bit radix passes +#pragma unroll 4 + for (int round = 0; round < 4; ++round) { + __shared__ int s_last_remain; + const auto r_idx = round % 2; + + // clip here to prevent overflow + const auto _raw_num_input = s_num_input[r_idx]; + const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) ? 
_raw_num_input : int(SMEM_INPUT_SIZE); + + run_cumsum(); + if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { + s_threshold_bin_id = tx; + s_num_input[r_idx ^ 1] = 0; + s_last_remain = topk - s_histogram[tx + 1]; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + topk -= s_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(input[idx + row_start]) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + break; + } else { + __syncthreads(); + if (tx < RADIX + 1) { + s_histogram[tx] = 0; + } + __syncthreads(); + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto raw_input = input[idx + row_start]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + if (round == 3) { + const auto pos = ::atomicAdd(&s_last_remain, -1); + if (pos > 0) { + index[TopK - pos] = idx; + } + } else { + const auto pos = ::atomicAdd(&s_num_input[r_idx ^ 1], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + /// NOTE: (dark) fuse the histogram computation here + s_input_idx[r_idx ^ 1][pos] = idx; + const auto bin = convert_to_uint32(raw_input); + const auto sub_bin = (bin >> (offset - 8)) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } + } + } + __syncthreads(); + } + } +} + +__global__ __launch_bounds__(kThreadsPerBlock) // topk + void topk_kernel(const FastTopKParams params) { + const auto& [input, row_starts, indices, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto row_start = row_starts == nullptr ? 
0 : row_starts[bid]; + const auto length = lengths[bid]; + const auto indice = indices + bid * TopK; + const auto score = input + bid * input_stride; + if (length <= TopK) { + return naive_topk_cuda(score, indice, length); + } else { + return fast_topk_cuda_tl(score, indice, row_start, length); + } +} + +__global__ __launch_bounds__(kThreadsPerBlock) // decode + void topk_transform_decode_kernel( + const FastTopKParams params, + int32_t* __restrict__ dst_page_table, + const int32_t* __restrict__ src_page_table, + const int64_t src_stride) { + const auto& [input, _1, _2, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto tid = threadIdx.x; + const auto row_start = 0; + const auto length = lengths[bid]; + const auto src_page_entry = src_page_table + bid * src_stride; + const auto dst_page_entry = dst_page_table + bid * TopK; + const auto score = input + bid * input_stride; + if (length <= TopK) { + return naive_topk_transform(score, length, dst_page_entry, src_page_entry); + } else { + __shared__ int s_indices[TopK]; + fast_topk_cuda_tl(score, s_indices, row_start, length); + // copy src[s_indices] to dst, we manually unroll here + static_assert(TopK % kThreadsPerBlock == 0); + static_assert(TopK / kThreadsPerBlock == 2); + const auto idx_0 = tid; + const auto pos_0 = s_indices[idx_0]; + dst_page_entry[idx_0] = src_page_entry[pos_0]; + const auto idx_1 = tid + kThreadsPerBlock; + const auto pos_1 = s_indices[idx_1]; + dst_page_entry[idx_1] = src_page_entry[pos_1]; + } +} + +__global__ __launch_bounds__(kThreadsPerBlock) // prefill + void topk_transform_prefill_kernel( + const FastTopKParams params, + int32_t* __restrict__ dst_page_table, + const int32_t* __restrict__ src_page_table, + const int64_t src_stride, + const int32_t* __restrict__ cu_seqlens_q, + const int64_t prefill_bs) { + const auto& [input, row_starts, _, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto tid = threadIdx.x; + 
const auto length = lengths[bid]; + const auto row_start = row_starts == nullptr ? 0 : row_starts[bid]; + const auto dst_page_entry = dst_page_table + bid * TopK; + const auto score = input + bid * input_stride; + + /// NOTE: prefill bs is usually small, we can just use a simple loop here + /// We ensure that last cu_seqlens is equal to number of blocks launched + __shared__ const int32_t* s_src_page_entry; + if (C10_LIKELY(prefill_bs <= kThreadsPerBlock)) { + if (tid < prefill_bs) { + if (bid >= cu_seqlens_q[tid] && bid < cu_seqlens_q[tid + 1]) { + s_src_page_entry = src_page_table + tid * src_stride; + } + } + } else { + for (int64_t i = tid; i < prefill_bs; i += kThreadsPerBlock) { + if (bid >= cu_seqlens_q[i] && bid < cu_seqlens_q[i + 1]) { + s_src_page_entry = src_page_table + i * src_stride; + } + } + } + __syncthreads(); + const auto src_page_entry = s_src_page_entry; + + if (length <= TopK) { + return naive_topk_transform(score, length, dst_page_entry, src_page_entry); + } else { + __shared__ int s_indices[TopK]; + fast_topk_cuda_tl(score, s_indices, row_start, length); + // copy src[s_indices] to dst, we manually unroll here + static_assert(TopK % kThreadsPerBlock == 0); + static_assert(TopK / kThreadsPerBlock == 2); + const auto idx_0 = tid; + const auto pos_0 = s_indices[idx_0]; + dst_page_entry[idx_0] = src_page_entry[pos_0]; + const auto idx_1 = tid + kThreadsPerBlock; + const auto pos_1 = s_indices[idx_1]; + dst_page_entry[idx_1] = src_page_entry[pos_1]; + } +} + +__global__ __launch_bounds__(kThreadsPerBlock) // prefill, ragged kv + void topk_transform_prefill_ragged_kernel( + const FastTopKParams params, + int32_t* __restrict__ topk_indices_ragged, + const int32_t* __restrict__ topk_indices_offset) { + const auto& [input, row_starts, _, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto tid = threadIdx.x; + const auto row_start = row_starts == nullptr ? 
0 : row_starts[bid]; + const auto length = lengths[bid]; + const auto dst_indices_entry = topk_indices_ragged + bid * TopK; + const auto score = input + bid * input_stride; + const auto offset = topk_indices_offset[bid]; + + if (length <= TopK) { + return naive_topk_transform_ragged(score, length, dst_indices_entry, offset); + } else { + __shared__ int s_indices[TopK]; + fast_topk_cuda_tl(score, s_indices, row_start, length); + // copy src[s_indices] to dst, we manually unroll here + static_assert(TopK % kThreadsPerBlock == 0); + static_assert(TopK / kThreadsPerBlock == 2); + const auto idx_0 = tid; + const auto pos_0 = s_indices[idx_0]; + dst_indices_entry[idx_0] = pos_0 + offset; + const auto idx_1 = tid + kThreadsPerBlock; + const auto pos_1 = s_indices[idx_1]; + dst_indices_entry[idx_1] = pos_1 + offset; + } +} + +auto get_params( + const at::Tensor& score, + const at::Tensor& lengths, + std::optional row_starts_opt = std::nullopt, + std::optional indices_opt = std::nullopt) -> FastTopKParams { + const auto B = score.size(0); + TORCH_CHECK(score.dim() == 2 && score.stride(1) == 1); + if (row_starts_opt.has_value()) { + const auto& row_starts = row_starts_opt.value(); + TORCH_CHECK(row_starts.dim() == 1); + TORCH_CHECK(row_starts.size(0) == B); + } + TORCH_CHECK(lengths.dim() == 1 && lengths.is_contiguous()); + TORCH_CHECK(lengths.size(0) == B); + int32_t* indices_data_ptr = nullptr; + if (indices_opt.has_value()) { + const auto& indices = indices_opt.value(); + TORCH_CHECK(indices.dim() == 2 && indices.is_contiguous()); + TORCH_CHECK(indices.size(0) == B); + TORCH_CHECK(indices.size(1) == TopK); + indices_data_ptr = indices.data_ptr(); + } + + return FastTopKParams{ + .input = score.data_ptr(), + .row_starts = row_starts_opt.has_value() ? 
row_starts_opt->data_ptr() : nullptr, + .indices = indices_data_ptr, + .lengths = lengths.data_ptr(), + .input_stride = score.stride(0), + }; +} + +template +void setup_kernel_smem_once() { + [[maybe_unused]] + static const auto result = [] { +#ifdef USE_ROCM + // hipify will turn cudaFuncSetAttribute -> hipFuncSetAttribute. On ROCm, + // hipFuncSetAttribute expects `const void*` and hipcc does not accept passing + // a function pointer directly, so cast explicitly. + return ::cudaFuncSetAttribute( + reinterpret_cast(f), ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); +#else + // CUDA: keep original behavior (no cast needed). + return ::cudaFuncSetAttribute(f, ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); +#endif + }(); + TORCH_CHECK(result == cudaSuccess, "set_up_kernel_once failed:", ::cudaGetErrorString(result)); +} + +// ====================================================================== +// Vortex integration: BOS/EOS-aware segmented TopK with index remapping +// ====================================================================== + +template +__device__ __forceinline__ float vortex_to_float(T x); + +template <> +__device__ __forceinline__ float vortex_to_float(float x) { return x; } + +template <> +__device__ __forceinline__ float vortex_to_float<__nv_bfloat16>(__nv_bfloat16 x) { + return __bfloat162float(x); +} + +constexpr int VORTEX_MAX_TOPK = 2048; + +// Templated version of fast_topk_cuda_tl: +// - ScoreT: float or __nv_bfloat16 +// - target_k: runtime parameter (replaces compile-time TopK) +// - mapping: configurable value-remapping for Stage 1 bin assignment +template +__device__ void fast_topk_vortex( + const ScoreT* __restrict__ input, + int* __restrict__ index, + int row_start, + int length, + int target_k, + const TopKMappingParams& mapping) +{ + int topk = target_k; + constexpr auto BLOCK_SIZE = 1024; + constexpr auto RADIX = 256; + constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); + + 
alignas(128) __shared__ int vh_histogram_buf[2][RADIX + 128]; + alignas(128) __shared__ int vh_counter; + alignas(128) __shared__ int vh_threshold_bin_id; + alignas(128) __shared__ int vh_num_input[2]; + + // Shared memory for mapping LUT / quantiles (loaded once per block) + __shared__ uint8_t s_mapping_lut[256]; + __shared__ float s_mapping_quantiles[256]; + + // Auto-range for transform modes (3/4/6/7) + __shared__ float s_range_min, s_range_inv_range; + + auto& vh_histogram = vh_histogram_buf[0]; + extern __shared__ int vh_input_idx[][SMEM_INPUT_SIZE]; + + const int tx = threadIdx.x; + + // Load mapping tables into shared memory if needed + if (mapping.mode == MAPPING_LUT_CDF && mapping.lut != nullptr) { + if (tx < 256) s_mapping_lut[tx] = mapping.lut[tx]; + __syncthreads(); + } + if (mapping.mode == MAPPING_QUANTILE && mapping.quantiles != nullptr) { + if (tx < 256) s_mapping_quantiles[tx] = mapping.quantiles[tx]; + __syncthreads(); + } + + // Pre-pass: compute per-block min/max of transformed values for linear bucketing + if (needs_auto_range(mapping.mode)) { + float local_min = __FLT_MAX__, local_max = -__FLT_MAX__; + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + float val = apply_transform(vortex_to_float(input[idx + row_start]), mapping); + local_min = fminf(local_min, val); + local_max = fmaxf(local_max, val); + } + // Warp-level reduction + for (int offset = 16; offset > 0; offset >>= 1) { + local_min = fminf(local_min, __shfl_xor_sync(0xFFFFFFFF, local_min, offset)); + local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); + } + // Cross-warp reduction via shared memory + __shared__ float s_warp_mins[32], s_warp_maxs[32]; + int warp_id = tx >> 5, lane_id = tx & 31; + if (lane_id == 0) { s_warp_mins[warp_id] = local_min; s_warp_maxs[warp_id] = local_max; } + __syncthreads(); + if (tx < (BLOCK_SIZE >> 5)) { + local_min = s_warp_mins[tx]; local_max = s_warp_maxs[tx]; + for (int offset = 16; offset > 0; offset >>= 1) { + 
local_min = fminf(local_min, __shfl_xor_sync(0xFFFFFFFF, local_min, offset)); + local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); + } + if (tx == 0) { + s_range_min = local_min; + float range = local_max - local_min; + s_range_inv_range = (range > 0.0f) ? 255.0f / range : 0.0f; + } + } + __syncthreads(); + } else { + if (tx == 0) { s_range_min = 0.0f; s_range_inv_range = 0.0f; } + __syncthreads(); + } + + // Stage 1: 8-bit coarse histogram (with optional mapping) + if (tx < RADIX + 1) vh_histogram[tx] = 0; + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = mapped_convert_to_uint8( + vortex_to_float(input[idx + row_start]), + mapping, s_mapping_lut, s_mapping_quantiles, + s_range_min, s_range_inv_range); + ::atomicAdd(&vh_histogram[bin], 1); + } + __syncthreads(); + + const auto run_cumsum = [&] { +#pragma unroll 8 + for (int i = 0; i < 8; ++i) { + static_assert(1 << 8 == RADIX); + if (C10_LIKELY(tx < RADIX)) { + const auto j = 1 << i; + const auto k = i & 1; + auto value = vh_histogram_buf[k][tx]; + if (tx < RADIX - j) { + value += vh_histogram_buf[k][tx + j]; + } + vh_histogram_buf[k ^ 1][tx] = value; + } + __syncthreads(); + } + }; + + run_cumsum(); + if (tx < RADIX && vh_histogram[tx] > topk && vh_histogram[tx + 1] <= topk) { + vh_threshold_bin_id = tx; + vh_num_input[0] = 0; + vh_counter = 0; + } + __syncthreads(); + + const auto threshold_bin = vh_threshold_bin_id; + topk -= vh_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = static_cast( + mapped_convert_to_uint8( + vortex_to_float(input[idx + row_start]), + mapping, s_mapping_lut, s_mapping_quantiles, + s_range_min, s_range_inv_range)); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&vh_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + return; + } else { + __syncthreads(); + if (tx < RADIX + 1) vh_histogram[tx] = 0; + __syncthreads(); + 
+ for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto raw_input = vortex_to_float(input[idx + row_start]); + const auto bin = static_cast( + mapped_convert_to_uint8(raw_input, mapping, + s_mapping_lut, s_mapping_quantiles, + s_range_min, s_range_inv_range)); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&vh_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + const auto pos = ::atomicAdd(&vh_num_input[0], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + vh_input_idx[0][pos] = idx; + const auto b32 = convert_to_uint32(raw_input); + const auto sub_bin = (b32 >> 24) & 0xFF; + ::atomicAdd(&vh_histogram[sub_bin], 1); + } + } + } + __syncthreads(); + } + + // Stage 2: refine with 8-bit radix passes (unchanged — uses raw float bits) +#pragma unroll 4 + for (int round = 0; round < 4; ++round) { + __shared__ int vh_last_remain; + const auto r_idx = round % 2; + + const auto _raw_num_input = vh_num_input[r_idx]; + const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) + ? 
_raw_num_input + : int(SMEM_INPUT_SIZE); + + run_cumsum(); + if (tx < RADIX && vh_histogram[tx] > topk && vh_histogram[tx + 1] <= topk) { + vh_threshold_bin_id = tx; + vh_num_input[r_idx ^ 1] = 0; + vh_last_remain = topk - vh_histogram[tx + 1]; + } + __syncthreads(); + + const auto threshold_bin = vh_threshold_bin_id; + topk -= vh_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = vh_input_idx[r_idx][i]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32( + vortex_to_float(input[idx + row_start])) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&vh_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + break; + } else { + __syncthreads(); + if (tx < RADIX + 1) vh_histogram[tx] = 0; + __syncthreads(); + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = vh_input_idx[r_idx][i]; + const auto raw_input = vortex_to_float(input[idx + row_start]); + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&vh_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + if (round == 3) { + const auto pos = ::atomicAdd(&vh_last_remain, -1); + if (pos > 0) { + index[target_k - pos] = idx; + } + } else { + const auto pos = ::atomicAdd(&vh_num_input[r_idx ^ 1], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + vh_input_idx[r_idx ^ 1][pos] = idx; + const auto b32 = convert_to_uint32(raw_input); + const auto sub_bin = (b32 >> (offset - 8)) & 0xFF; + ::atomicAdd(&vh_histogram[sub_bin], 1); + } + } + } + } + __syncthreads(); + } + } +} + +// Wrapper kernel: one CUDA block per batch*head segment +template +__global__ __launch_bounds__(kThreadsPerBlock) +void TopKOutput_Kernel( + const ScoreT* __restrict__ score, + const int* __restrict__ dense_kv_indptr, + const int* __restrict__ sparse_kv_indptr, + const int* 
__restrict__ dense_kv_indices, + int* __restrict__ sparse_kv_indices, + const int topk_val, + const int page_reserved_bos, + const int page_reserved_eos, + const TopKMappingParams mapping) +{ + const int bx = blockIdx.x; + + const int start = dense_kv_indptr[bx] + page_reserved_bos; + const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; + const int nblk = end - start; + if (nblk <= topk_val) return; + + const ScoreT* __restrict__ score_blk = score + start; + const int* __restrict__ idx_blk = dense_kv_indices + start; + int* __restrict__ out_blk = sparse_kv_indices + + sparse_kv_indptr[bx] + + page_reserved_bos; + + __shared__ int s_indices[VORTEX_MAX_TOPK]; + fast_topk_vortex(score_blk, s_indices, 0, nblk, topk_val, mapping); + __syncthreads(); + + // Remap position indices -> page indices via dense_kv_indices + const int tx = threadIdx.x; + for (int i = tx; i < topk_val; i += kThreadsPerBlock) { + out_blk[i] = idx_blk[s_indices[i]]; + } +} + +// ====================================================================== +// Profiling histogram kernel: runs only Stage 1 and returns per-segment +// 256-bin histograms for distribution analysis +// ====================================================================== +template +__global__ __launch_bounds__(kThreadsPerBlock) +void TopKHistogram_Kernel( + const ScoreT* __restrict__ score, + const int* __restrict__ dense_kv_indptr, + int* __restrict__ histograms, // [eff_batch_size, 256] + const int page_reserved_bos, + const int page_reserved_eos, + const TopKMappingParams mapping) +{ + constexpr auto RADIX = 256; + constexpr auto BLOCK_SIZE = kThreadsPerBlock; + __shared__ int s_histogram[RADIX]; + __shared__ uint8_t s_mapping_lut[256]; + __shared__ float s_mapping_quantiles[256]; + __shared__ float s_range_min, s_range_inv_range; + + const int bx = blockIdx.x; + const int tx = threadIdx.x; + + const int start = dense_kv_indptr[bx] + page_reserved_bos; + const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; 
+ const int nblk = end - start; + + const ScoreT* __restrict__ score_blk = score + start; + + // Load mapping tables into shared memory if needed + if (mapping.mode == MAPPING_LUT_CDF && mapping.lut != nullptr) { + if (tx < 256) s_mapping_lut[tx] = mapping.lut[tx]; + __syncthreads(); + } + if (mapping.mode == MAPPING_QUANTILE && mapping.quantiles != nullptr) { + if (tx < 256) s_mapping_quantiles[tx] = mapping.quantiles[tx]; + __syncthreads(); + } + + // Pre-pass: compute per-block min/max for transform modes + if (needs_auto_range(mapping.mode)) { + float local_min = __FLT_MAX__, local_max = -__FLT_MAX__; + for (int idx = tx; idx < nblk; idx += BLOCK_SIZE) { + float val = apply_transform(vortex_to_float(score_blk[idx]), mapping); + local_min = fminf(local_min, val); + local_max = fmaxf(local_max, val); + } + for (int offset = 16; offset > 0; offset >>= 1) { + local_min = fminf(local_min, __shfl_xor_sync(0xFFFFFFFF, local_min, offset)); + local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); + } + __shared__ float s_warp_mins[32], s_warp_maxs[32]; + int warp_id = tx >> 5, lane_id = tx & 31; + if (lane_id == 0) { s_warp_mins[warp_id] = local_min; s_warp_maxs[warp_id] = local_max; } + __syncthreads(); + if (tx < (BLOCK_SIZE >> 5)) { + local_min = s_warp_mins[tx]; local_max = s_warp_maxs[tx]; + for (int offset = 16; offset > 0; offset >>= 1) { + local_min = fminf(local_min, __shfl_xor_sync(0xFFFFFFFF, local_min, offset)); + local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); + } + if (tx == 0) { + s_range_min = local_min; + float range = local_max - local_min; + s_range_inv_range = (range > 0.0f) ? 
255.0f / range : 0.0f; + } + } + __syncthreads(); + } else { + if (tx == 0) { s_range_min = 0.0f; s_range_inv_range = 0.0f; } + __syncthreads(); + } + + // Initialize shared histogram + if (tx < RADIX) s_histogram[tx] = 0; + __syncthreads(); + + // Build histogram over the segment with mapping + for (int idx = tx; idx < nblk; idx += BLOCK_SIZE) { + const auto bin = mapped_convert_to_uint8( + vortex_to_float(score_blk[idx]), + mapping, s_mapping_lut, s_mapping_quantiles, + s_range_min, s_range_inv_range); + ::atomicAdd(&s_histogram[bin], 1); + } + __syncthreads(); + + // Write to global memory + int* __restrict__ out = histograms + bx * RADIX; + if (tx < RADIX) { + out[tx] = s_histogram[tx]; + } +} + +} // namespace + +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") + +void fast_topk_interface( + const at::Tensor& score, at::Tensor& indices, const at::Tensor& lengths, std::optional row_starts_opt) { + CHECK_CUDA(score); + CHECK_CUDA(indices); + if (row_starts_opt.has_value()) { + CHECK_CUDA(row_starts_opt.value()); + } + CHECK_CUDA(lengths); + const auto params = get_params(score, lengths, row_starts_opt, indices); + const auto B = score.size(0); + const auto stream = at::cuda::getCurrentCUDAStream().stream(); + const auto grid = dim3{static_cast(B)}; + const auto block = dim3{kThreadsPerBlock}; + setup_kernel_smem_once(); + topk_kernel<<>>(params); + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); +} + +void fast_topk_transform_interface( + const at::Tensor& score, + const at::Tensor& lengths, + at::Tensor& dst_page_table, + const at::Tensor& src_page_table, + const at::Tensor& cu_seqlens_q, + std::optional row_starts_opt) { + CHECK_CUDA(score); + CHECK_CUDA(lengths); + CHECK_CUDA(dst_page_table); + CHECK_CUDA(src_page_table); + CHECK_CUDA(cu_seqlens_q); + if (row_starts_opt.has_value()) { + CHECK_CUDA(row_starts_opt.value()); + } + const auto params = 
get_params(score, lengths, row_starts_opt); + const auto B = score.size(0); + TORCH_CHECK(dst_page_table.dim() == 2 && dst_page_table.is_contiguous()); + TORCH_CHECK(src_page_table.dim() == 2 && src_page_table.stride(1) == 1); + TORCH_CHECK(cu_seqlens_q.dim() == 1 && cu_seqlens_q.is_contiguous()); + const auto prefill_bs = cu_seqlens_q.size(0) - 1; + TORCH_CHECK(dst_page_table.size(0) == B); + TORCH_CHECK(dst_page_table.size(1) == TopK); + TORCH_CHECK(src_page_table.size(0) == prefill_bs); + TORCH_CHECK(prefill_bs <= B); // prefill_bs should be smaller than expanded bs + + // launch kernel + const auto stream = at::cuda::getCurrentCUDAStream().stream(); + const auto grid = dim3{static_cast(B)}; + const auto block = dim3{kThreadsPerBlock}; + const auto src_stride = src_page_table.stride(0); + + // dispatch to decode or prefill + // extend and draft extend: row_starts_opt is not null, invokes the prefill kernel + // decode: row_starts_opt is null, invokes the decode kernel + // target verify: row_starts_opt is null, invokes the prefill kernel + const auto is_decode = !row_starts_opt.has_value() && prefill_bs == B; + if (is_decode) { + setup_kernel_smem_once(); + topk_transform_decode_kernel<<>>( + params, dst_page_table.data_ptr(), src_page_table.data_ptr(), src_stride); + } else { + setup_kernel_smem_once(); + topk_transform_prefill_kernel<<>>( + params, + dst_page_table.data_ptr(), + src_page_table.data_ptr(), + src_stride, + cu_seqlens_q.data_ptr(), + prefill_bs); + } + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); +} + +void fast_topk_transform_ragged_interface( + const at::Tensor& score, + const at::Tensor& lengths, + at::Tensor& topk_indices_ragged, + const at::Tensor& topk_indices_offset, + std::optional row_starts_opt) { + CHECK_CUDA(score); + CHECK_CUDA(lengths); + CHECK_CUDA(topk_indices_ragged); + CHECK_CUDA(topk_indices_offset); + if (row_starts_opt.has_value()) { + 
CHECK_CUDA(row_starts_opt.value()); + } + + const auto params = get_params(score, lengths, row_starts_opt); + const auto B = score.size(0); + TORCH_CHECK(topk_indices_ragged.dim() == 2 && topk_indices_ragged.is_contiguous()); + TORCH_CHECK(topk_indices_offset.dim() == 1); + + TORCH_CHECK(topk_indices_ragged.size(0) == B); + TORCH_CHECK(topk_indices_ragged.size(1) == TopK); + TORCH_CHECK(topk_indices_offset.size(0) == B); + + // launch kernel + const auto stream = at::cuda::getCurrentCUDAStream().stream(); + const auto grid = dim3{static_cast(B)}; + const auto block = dim3{kThreadsPerBlock}; + + setup_kernel_smem_once(); + topk_transform_prefill_ragged_kernel<<>>( + params, topk_indices_ragged.data_ptr(), topk_indices_offset.data_ptr()); + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); +} + +// ====================================================================== +// Vortex host entry point — same interface as topk_output in topk.cu +// ====================================================================== +void topk_output_sglang( + const at::Tensor& x, + const at::Tensor& dense_kv_indptr, + const at::Tensor& sparse_kv_indptr, + const at::Tensor& dense_kv_indices, + at::Tensor& sparse_kv_indices, + const int64_t eff_batch_size, + const int64_t topk_val, + const int64_t reserved_bos, + const int64_t reserved_eos, + const int64_t max_num_pages, + const int64_t mapping_mode, + const double mapping_power, + std::optional mapping_lut, + std::optional mapping_quantiles) +{ + TORCH_CHECK(topk_val <= VORTEX_MAX_TOPK, + "topk_output: topk_val (", topk_val, + ") exceeds VORTEX_MAX_TOPK (", VORTEX_MAX_TOPK, ")"); + + // Build mapping params from optional tensors + TopKMappingParams mapping{}; + mapping.mode = static_cast(mapping_mode); + mapping.power_exp = static_cast(mapping_power); + mapping.lut = nullptr; + mapping.quantiles = nullptr; + + if (mapping_lut.has_value()) { + const auto& 
lut = mapping_lut.value(); + CHECK_CUDA(lut); + TORCH_CHECK(lut.dim() == 1 && lut.size(0) == 256 && lut.scalar_type() == at::ScalarType::Byte, + "mapping_lut must be a 1D uint8 tensor of size 256"); + mapping.lut = lut.data_ptr(); + } + if (mapping_quantiles.has_value()) { + const auto& q = mapping_quantiles.value(); + CHECK_CUDA(q); + TORCH_CHECK(q.dim() == 1 && q.size(0) == 256 && q.scalar_type() == at::ScalarType::Float, + "mapping_quantiles must be a 1D float32 tensor of size 256"); + mapping.quantiles = q.data_ptr(); + } + + dim3 nblks(eff_batch_size); + dim3 nthreads(kThreadsPerBlock); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + + if (x.scalar_type() == at::ScalarType::BFloat16) { + setup_kernel_smem_once, kSmem>(); + TopKOutput_Kernel<__nv_bfloat16><<>>( + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + topk_val, + reserved_bos, + reserved_eos, + mapping); + } else if (x.scalar_type() == at::ScalarType::Float) { + setup_kernel_smem_once, kSmem>(); + TopKOutput_Kernel<<>>( + x.data_ptr(), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + topk_val, + reserved_bos, + reserved_eos, + mapping); + } else { + TORCH_CHECK(false, + "topk_output: unsupported dtype ", + x.scalar_type()); + } + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, + "topk_output kernel failed: ", ::cudaGetErrorString(result)); +} + +// ====================================================================== +// Profiling: collect per-segment 256-bin histograms of Stage 1 bins +// ====================================================================== +void topk_profile_histogram( + const at::Tensor& x, + const at::Tensor& dense_kv_indptr, + at::Tensor& histograms, + const int64_t eff_batch_size, + const int64_t reserved_bos, + const int64_t 
reserved_eos, + const int64_t mapping_mode, + const double mapping_power, + std::optional mapping_lut, + std::optional mapping_quantiles) +{ + CHECK_CUDA(x); + CHECK_CUDA(dense_kv_indptr); + CHECK_CUDA(histograms); + TORCH_CHECK(histograms.dim() == 2 && histograms.size(0) == eff_batch_size + && histograms.size(1) == 256, + "histograms must be [eff_batch_size, 256]"); + TORCH_CHECK(histograms.scalar_type() == at::ScalarType::Int, + "histograms must be int32"); + + // Build mapping params + TopKMappingParams mapping{}; + mapping.mode = static_cast(mapping_mode); + mapping.power_exp = static_cast(mapping_power); + mapping.lut = nullptr; + mapping.quantiles = nullptr; + + if (mapping_lut.has_value()) { + const auto& lut = mapping_lut.value(); + CHECK_CUDA(lut); + TORCH_CHECK(lut.dim() == 1 && lut.size(0) == 256 && lut.scalar_type() == at::ScalarType::Byte, + "mapping_lut must be a 1D uint8 tensor of size 256"); + mapping.lut = lut.data_ptr(); + } + if (mapping_quantiles.has_value()) { + const auto& q = mapping_quantiles.value(); + CHECK_CUDA(q); + TORCH_CHECK(q.dim() == 1 && q.size(0) == 256 && q.scalar_type() == at::ScalarType::Float, + "mapping_quantiles must be a 1D float32 tensor of size 256"); + mapping.quantiles = q.data_ptr(); + } + + dim3 nblks(eff_batch_size); + dim3 nthreads(kThreadsPerBlock); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + + if (x.scalar_type() == at::ScalarType::BFloat16) { + TopKHistogram_Kernel<__nv_bfloat16><<>>( + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + dense_kv_indptr.data_ptr(), + histograms.data_ptr(), + reserved_bos, + reserved_eos, + mapping); + } else if (x.scalar_type() == at::ScalarType::Float) { + TopKHistogram_Kernel<<>>( + x.data_ptr(), + dense_kv_indptr.data_ptr(), + histograms.data_ptr(), + reserved_bos, + reserved_eos, + mapping); + } else { + TORCH_CHECK(false, + "topk_profile_histogram: unsupported dtype ", + x.scalar_type()); + } + + const auto result = cudaGetLastError(); + 
TORCH_CHECK(result == cudaSuccess, + "topk_profile_histogram kernel failed: ", ::cudaGetErrorString(result)); +} + diff --git a/examples/README.md b/examples/README.md new file mode 100644 index 0000000..d14650f --- /dev/null +++ b/examples/README.md @@ -0,0 +1,399 @@ +# Vortex Torch Examples + +End-to-end accuracy evaluation and profiling pipelines for Vortex sparse attention on top of the SGLang inference engine. The scripts in this directory evaluate different TopK kernel variants, mapping functions, KV-cache quantization settings, and external sparse-attention backends on math reasoning benchmarks. + +--- + +## Mapping Functions Reference + +The TopK Stage-1 radix histogram uses 256 uint8 bins. A **mapping function** transforms raw attention scores before binning to improve bucket uniformity and reduce tail latency. Set via `--topk-mapping-mode`. + +| Mode | Name | Formula | Requires Calibration | Hyperparameter (`--topk-mapping-power`) | +|------|------|---------|---------------------|-----------------------------------------| +| 0 | None | FP16 bit-pattern bucketing | No | — | +| 1 | LUT CDF | `lut[original_bin]` (CDF equalization) | Yes (`--topk-mapping-lut-path`) | — | +| 2 | Quantile | Binary search over 256 float thresholds | Yes (`--topk-mapping-quantiles-path`) | — | +| 3 | Power | `sign(x) * \|x\|^p` | No | `p` (exponent, default 0.5) | +| 4 | Log | `sign(x) * log(\|x\| + 1)` | No | — | +| 5 | Index Cache | Reuse top-k indices from a preceding layer | No | — (see `--index-cache-shared-layers`) | +| 6 | Asinh | `asinh(beta * x)` | No | `beta` (default 0.5) | +| 7 | Log1p | `sign(x) * log1p(alpha * \|x\|)` | No | `alpha` (default 0.5) | +| 8 | Trunc8 | BF16 upper-8-bit bucketing | No | — | + +Modes 1 and 2 require an offline calibration step (see `calibrate_topk.py` in `benchmarks/`). Modes 3, 6, and 7 accept a tunable hyperparameter via `--topk-mapping-power`. 
+ +--- + +## Python Scripts + +### `verify_algo.py` — End-to-End Accuracy Benchmark + +The primary evaluation script. Loads AMC 2023 math problems from `amc23.jsonl`, runs inference via the SGLang engine with Vortex sparse attention, and scores answers using `lighteval`'s extractive-match metric. Reports `mean@N`, `pass@N`, throughput, and memory access cost. + +**Usage:** + +```bash +python verify_algo.py [OPTIONS] +``` + +**CLI Arguments:** + +| Argument | Default | Description | +|----------|---------|-------------| +| `--trials` | 2 | Number of trials (each prompt repeated N times) | +| `--topk-val` | 30 | Number of top-k pages to select per segment | +| `--page-size` | 16 | Tokens per KV-cache page | +| `--vortex-module-name` | `gqa_block_sparse_attention` | Sparse attention algorithm module | +| `--model-name` | `Qwen/Qwen3-1.7B` | HuggingFace model identifier | +| `-f`, `--full-attention` | off | Disable sparse attention (full-attention baseline) | +| `--mem` | 0.8 | Static GPU memory fraction for SGLang | +| `--kv-cache-dtype` | `auto` | KV cache dtype: `auto`, `fp8_e5m2`, `fp8_e4m3`, `int8` | +| `--topk-type` | `naive` | TopK kernel: `naive` (CUB radix sort) or `sglang` (fast two-stage radix) | +| `--topk-mapping-mode` | 0 | Mapping function for Stage-1 binning (see table above) | +| `--topk-mapping-power` | 0.5 | Hyperparameter for modes 3/6/7 | +| `--topk-mapping-lut-path` | None | `.npy` uint8[256] LUT for mode 1 | +| `--topk-mapping-quantiles-path` | None | `.npy` float32[256] quantiles for mode 2 | +| `--index-cache-shared-layers` | None | Layer IDs that skip the indexer and reuse a previous layer's indices | + +**Fixed engine settings:** `attention_backend=flashinfer`, `vortex_max_seq_lens=12288`, layer 0 skipped, `reserved_bos=1`, `reserved_eos=2`. Sampling: `temperature=0.6`, `top_p=0.95`, `top_k=20`, `max_new_tokens=8192`. 
+ +**Index cache note (mode 5):** When `--topk-mapping-mode 5` is set without `--index-cache-shared-layers`, the script defaults to even layers `[2, 4, 6, ..., 26]` and internally resets the mapping mode to 0 while passing the shared-layer list to the engine. + +**Example — full-attention baseline:** + +```bash +python verify_algo.py --full-attention --trials 8 --mem 0.7 +``` + +**Example — sglang TopK with power mapping:** + +```bash +python verify_algo.py \ + --topk-type sglang \ + --topk-mapping-mode 3 \ + --topk-mapping-power 0.25 \ + --trials 8 --topk-val 30 --mem 0.7 +``` + +**Example — sglang TopK with calibrated LUT:** + +```bash +python verify_algo.py \ + --topk-type sglang \ + --topk-mapping-mode 1 \ + --topk-mapping-lut-path calibration/lut.npy \ + --trials 8 --topk-val 30 --mem 0.7 +``` + +--- + +### `verify_aim24.py` — AIME 2024 Throughput Test (Legacy) + +A standalone throughput script that loads AIME 2024 from HuggingFace (`HuggingFaceH4/aime_2024`), builds chat prompts using the Qwen3 tokenizer with `enable_thinking=True`, and repeats each prompt 8 times. Outputs a JSONL file with generation results and timing metadata. Does **not** compute accuracy metrics. + +**Usage:** + +```bash +python verify_aim24.py +``` + +All settings are hard-coded (no CLI arguments): + +| Setting | Value | +|---------|-------| +| Model | `Qwen/Qwen3-0.6B` | +| Page size | 16 | +| Selected pages | 29 | +| Max sequence length | 20480 | +| Module | `block_sparse_attention` | +| Memory fraction | 0.9 | +| Max new tokens | 16384 | +| CUDA graph | Enabled | + +--- + +## Shell Scripts + +All shell scripts set `CUDA_VISIBLE_DEVICES` and save timestamped logs to `results/`. + +### `verify_algo.sh` — Baseline TopK Comparison (Naive vs SGLang) + +Runs `verify_algo.py` with `block_sparse_attention` comparing the `naive` and `sglang` TopK kernels. Each configuration is repeated `REPEAT_COUNT` times (default 3, overridable via environment variable). 
+ +```bash +REPEAT_COUNT=5 bash verify_algo.sh +``` + +### `verify_algo_topk.sh` — Naive vs SGLang Comparison + +Similar to `verify_algo.sh` but simpler: runs `naive` TopK and `sglang` TopK back-to-back for `block_sparse_attention`, each with 8 trials. + +### `verify_algo_quant.sh` — INT8 KV-Cache Quantization + +Tests sparse attention with `--kv-cache-dtype int8` to measure accuracy under quantized KV caches. + +```bash +bash verify_algo_quant.sh +``` + +### `verify_sparse_backends.sh` — External Sparse Attention Backends + +Evaluates three external sparse-attention algorithms integrated via the Vortex flow interface: + +- `nsa` (Native Sparse Attention) +- `fsa` (Flash Sparse Attention) +- `flash_moba` (Flash MoBA) + +```bash +bash verify_sparse_backends.sh +``` + +### `verify_algo_topk_mapping.sh` — Full Mapping Mode Sweep + +Comprehensive sweep across all mapping modes: + +1. **Baseline:** `naive` TopK, mode 0 +2. **Calibration:** runs `calibrate_topk.py` to generate `lut.npy` and `quantiles.npy` (skipped if files exist) +3. **Mode 1** (LUT CDF) and **Mode 2** (Quantile) with calibrated tables +4. **Modes 0, 3, 4** (no calibration needed) — Power mode uses `--topk-mapping-power 0.5` +5. **Mode 6** (Asinh) — sweeps `beta` in `[0.5, 1.0, 2.0]` +6. **Mode 7** (Log1p) — sweeps `alpha` in `[0.5, 1.0, 2.0]` + +```bash +export CUDA_VISIBLE_DEVICES=0 +bash verify_algo_topk_mapping.sh +``` + +### `verify_algo_topk_mapping_new.sh` — Parametric Mapping Sweep (Modes 3, 6, 7) + +Focused hyperparameter sweep for the three parametric modes, preceded by an auto-tuning step: + +| Mode | Parameter | Sweep Values | +|------|-----------|-------------| +| 3 (Power) | `p` | 0.1, 0.25, 0.75, 0.9 | +| 6 (Asinh) | `beta` | 0.1, 0.5, 1.0, 2.0, 4.0 | +| 7 (Log1p) | `alpha` | 0.1, 0.5, 0.75, 1.0, 2.0, 4.0, 8.0 | + +Requires `calibration/raw_histograms.npy` for the auto-tune step. 
+ +```bash +export CUDA_VISIBLE_DEVICES=5 +bash verify_algo_topk_mapping_new.sh +``` + +### `verify_algo_topk_mapping_indexcache.sh` — Index Cache (Mode 5) + +Tests the index-cache optimization where even-numbered layers `[2, 4, 6, ..., 26]` reuse top-k indices from the nearest preceding full layer, skipping their indexer entirely. + +```bash +bash verify_algo_topk_mapping_indexcache.sh +``` + +### `run_topk_benchmark.sh` — Unified TopK Benchmark Pipeline + +The most comprehensive benchmarking script. Three-step pipeline: + +1. **Calibrate** — collect real-data histograms + LUT/quantile tables +2. **Kernel bench** — latency + histogram profiling across batch sizes, sequence lengths, and distributions, followed by distribution analysis plots and auto-tuning +3. **E2E accuracy** — full-attention baseline plus every mapping mode + +```bash +bash run_topk_benchmark.sh --gpu 5 --trials 8 --model-name Qwen/Qwen3-1.7B +``` + +| Option | Default | Description | +|--------|---------|-------------| +| `--model-name` | `Qwen/Qwen3-1.7B` | HuggingFace model | +| `--topk-val` | 30 | Top-k pages | +| `--trials` | 8 | E2E trial count | +| `--mem` | 0.7 | GPU memory fraction | +| `--gpu` | 5 | CUDA device | +| `--algo` | `block_sparse_attention` | Sparse attention algorithm | +| `--skip-calibrate` | off | Reuse existing calibration | +| `--skip-kernel` | off | Skip kernel-level latency step | +| `--skip-e2e` | off | Skip E2E accuracy step | + +### `run_distribution_analysis.sh` — Bucket Distribution Profiling (All Modes) + +Three-step pipeline to analyze how each mapping mode affects the 256-bin bucket distribution: + +1. **Calibrate** — collect real-data histograms (skippable with `--real-histograms`) +2. **Bench** — histogram profiling with modes 0–8 on `bucket_uniform` and `normal` distributions +3. 
**Analyze** — generate comparison plots and CSV bucket count tables + +```bash +bash run_distribution_analysis.sh --gpu 5 +bash run_distribution_analysis.sh --gpu 5 --real-histograms /path/to/raw_histograms.npy +``` + +### `run_distribution_analysis_new.sh` — Bucket Distribution Profiling (Modes 3, 6, 7) + +Same pipeline as above but focused on parametric modes only, with an additional auto-tune step: + +1. **Calibrate** (or skip with existing histograms) +2. **Auto-tune** — sweep hyperparameters on synthetic data +3. **Bench** — histogram profiling for modes 3, 6, 7, 8 +4. **Analyze** — comparison plots + tables + +```bash +bash run_distribution_analysis_new.sh --gpu 5 +``` + +--- + +## Benchmarks Directory Scripts + +The `benchmarks/` directory contains standalone profiling and analysis tools used by the shell pipelines above. These can also be run independently. + +### `calibrate_topk.py` — Offline Calibration + +Runs the SGLang engine on real prompts from `amc23.jsonl` with histogram collection enabled. Produces three files: + +- `lut.npy` — uint8[256] CDF-equalized LUT for mode 1 +- `quantiles.npy` — float32[256] quantile breakpoints for mode 2 +- `raw_histograms.npy` — raw per-sample 256-bin histograms + +```bash +python benchmarks/calibrate_topk.py \ + --model-name Qwen/Qwen3-1.7B \ + --topk-val 30 --mem 0.7 \ + --output-dir calibration/ +``` + +### `bench_topk.py` — Kernel-Level Latency Benchmark + +Benchmarks `topk_output` (naive/CUB) and `topk_output_sglang` (fast radix) across configurable sweeps of batch size, sequence length, TopK value, KV heads, and score distributions. Optionally collects 256-bin histogram statistics. 
+ +```bash +python benchmarks/bench_topk.py \ + --batch-sizes 4 8 16 \ + --seq-lens 2048 4096 8192 \ + --topk-vals 30 \ + --num-kv-heads 2 \ + --distributions normal lognormal uniform bucket_uniform \ + --histogram \ + --repeat 100 \ + --output-json results.json +``` + +### `autotune_topk_mapping.py` — Hyperparameter Auto-Tuning + +Sweeps hyperparameters for parametric mapping modes (3, 6, 7) using the `topk_profile_histogram` kernel on synthetic data. Ranks configurations by resolution rate, Gini coefficient, max/mean ratio, and nonzero bins. + +```bash +python benchmarks/autotune_topk_mapping.py \ + --topk-val 30 --batch-size 4 --seq-len 4096 --num-kv-heads 2 \ + --real-histograms calibration/raw_histograms.npy \ + --output-json autotune_results.json +``` + +### `analyze_topk_distribution.py` — Visualization and Analysis + +Loads profiling data and generates: +- Per-segment 256-bin bar charts +- Heatmaps (segments x bins, log-scale) +- Before/after LUT mapping comparisons +- Mode comparison grouped bar charts (Gini + max/mean) +- Distribution comparison plots across data sources +- CSV bucket count tables + +```bash +python benchmarks/analyze_topk_distribution.py \ + --bench-json bench_distribution.json \ + --real-histograms calibration/raw_histograms.npy \ + --output-dir plots/ +``` + +### `profile_topk_distribution.py` — Offline Table Generation + +Computes LUT and quantile tables from pre-collected histograms or raw scores without running a model. Outputs a single `.npz` archive. + +```bash +python benchmarks/profile_topk_distribution.py \ + --histograms-input raw_histograms.npy \ + --output mapping_tables.npz +``` + +### `greedy_layer_search.py` — Index Cache Layer Selection + +Greedy forward-selection of layers whose indexer can be skipped (index cache). Iteratively adds layers to the shared set as long as accuracy stays above `--threshold` times the baseline. 
+ +```bash +cd examples && python ../benchmarks/greedy_layer_search.py \ + --model-name Qwen/Qwen3-1.7B \ + --topk-val 30 \ + --threshold 0.95 \ + --trials 1 \ + --num-layers 28 \ + --mem 0.7 +``` + +--- + +## Data Files + +| File | Description | +|------|-------------| +| `amc23.jsonl` | AMC 2023 math problems with `prompt` and `answer` fields, used by `verify_algo.py` and `calibrate_topk.py` | + +--- + +## Output Structure + +Results are saved under `results/` in timestamped directories: + +``` +results/ +├── dist_analysis_YYYYMMDD_HHMMSS/ +│ ├── step1_calibrate.log +│ ├── step2_autotune.log / step2_bench.log +│ ├── step3_bench.log / step3_analyze.log +│ ├── step4_analyze.log +│ ├── autotune_results.json +│ ├── bench_distribution.json +│ ├── distribution_comparison_*.png +│ ├── bucket_counts_*.csv +│ └── calibration/ +│ ├── lut.npy +│ ├── quantiles.npy +│ └── raw_histograms.npy +├── topk_benchmark_YYYYMMDD_HHMMSS/ +│ ├── kernel_latency.json +│ ├── e2e/ +│ │ ├── full_attention_baseline.log +│ │ ├── sglang_mode0_none.log +│ │ └── ... +│ └── calibration/ +└── *.log (individual run logs) +``` + +--- + +## Quick Start: Typical Workflow + +```bash +export CUDA_VISIBLE_DEVICES=0 + +# 1. Calibrate to generate LUT + quantile tables +python benchmarks/calibrate_topk.py \ + --model-name Qwen/Qwen3-1.7B --topk-val 30 --mem 0.7 \ + --output-dir examples/calibration/ + +# 2. Run full-attention baseline +python examples/verify_algo.py --full-attention --trials 8 --mem 0.7 + +# 3. Evaluate sparse attention with different mapping modes +python examples/verify_algo.py \ + --topk-type sglang --topk-mapping-mode 0 --trials 8 --mem 0.7 + +python examples/verify_algo.py \ + --topk-type sglang --topk-mapping-mode 3 --topk-mapping-power 0.25 \ + --trials 8 --mem 0.7 + +python examples/verify_algo.py \ + --topk-type sglang --topk-mapping-mode 6 --topk-mapping-power 1.0 \ + --trials 8 --mem 0.7 + +# 4. 
Or run the full pipeline in one shot +bash examples/run_topk_benchmark.sh --gpu 0 --trials 8 +``` diff --git a/examples/run_distribution_analysis.sh b/examples/run_distribution_analysis.sh new file mode 100755 index 0000000..287c454 --- /dev/null +++ b/examples/run_distribution_analysis.sh @@ -0,0 +1,141 @@ +#!/usr/bin/env bash +# ============================================================ +# Bucket Distribution Profiling Pipeline +# +# Profiles the SGLang TopK kernel's first-pass bucket distribution +# to identify hotspot buckets causing tail latency. +# +# Three steps: +# 1. Calibrate — collect real-data histograms +# (skippable via --real-histograms PATH) +# 2. Bench — histogram profiling (bucket_uniform + normal) +# 3. Analyze — comparison plots + bucket count tables +# +# All outputs (JSON, plots, CSV tables, logs) are written to a +# single timestamped folder under examples/results/dist_analysis_*. +# +# Usage: +# bash run_distribution_analysis.sh --gpu 5 +# bash run_distribution_analysis.sh --gpu 5 \ +# --real-histograms /path/to/calibration_dir/raw_histograms.npy +# ============================================================ + +# Mapping functions: +# 0: None — original fp16 bit-pattern bucketing +# 1: LUT CDF — LUT-based CDF equalization (calibrated) +# 2: Quantile — piecewise-linear quantile mapping (calibrated) +# 3: Power — y = sign(x) * |x|^p +# 4: Log — y = sign(x) * log(|x| + 1) +# 5: Index Cache — reuse previous layer's indices +# 6: Asinh — y = asinh(beta * x) +# 7: Log1p — y = sign(x) * log1p(alpha * |x|) +# 8: Trunc8 — bf16 upper-8-bit bucketing + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +BENCH_DIR="${SCRIPT_DIR}/../benchmarks" + +# ── Defaults ────────────────────────────────────────────────── +GPU_ID=5 +MODEL_NAME="Qwen/Qwen3-1.7B" +TOPK_VAL=30 +MEM=0.7 +ALGO="block_sparse_attention" +# The path to the raw_histograms.npy file (set to skip calibration) +# 
REAL_HISTOGRAMS="/scr/dataset/yuke/xinrui/new/vortex_torch/examples/calibration/raw_histograms.npy" +REAL_HISTOGRAMS="" +# ── Parse arguments ─────────────────────────────────────────── +while [[ $# -gt 0 ]]; do + case "$1" in + --model-name) MODEL_NAME="$2"; shift 2 ;; + --topk-val) TOPK_VAL="$2"; shift 2 ;; + --mem) MEM="$2"; shift 2 ;; + --gpu) GPU_ID="$2"; shift 2 ;; + --algo) ALGO="$2"; shift 2 ;; + --real-histograms) REAL_HISTOGRAMS="$2"; shift 2 ;; + *) echo "Unknown option: $1"; exit 1 ;; + esac +done + +export CUDA_VISIBLE_DEVICES="${GPU_ID}" + +RESULTS_DIR="${SCRIPT_DIR}/results" +mkdir -p "${RESULTS_DIR}" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) +RUN_DIR="${RESULTS_DIR}/dist_analysis_${TIMESTAMP}" +mkdir -p "${RUN_DIR}" + +echo "============================================================" +echo "Bucket Distribution Profiling Pipeline" +echo " Model: ${MODEL_NAME}" +echo " Algorithm: ${ALGO}" +echo " TopK: ${TOPK_VAL}" +echo " GPU: ${GPU_ID}" +echo " Real histograms: ${REAL_HISTOGRAMS:-}" +echo " Output: ${RUN_DIR}" +echo "============================================================" + +# ── Step 1: Calibrate — collect real-data histograms + LUT/quantiles ── +if [ -n "${REAL_HISTOGRAMS}" ]; then + echo "" + echo ">>> Step 1: SKIPPED (using provided --real-histograms ${REAL_HISTOGRAMS})" + REAL_HIST_PATH="${REAL_HISTOGRAMS}" +else + echo "" + echo ">>> Step 1: Calibrating — collecting real-inference histograms" + CALIBRATION_DIR="${RUN_DIR}/calibration" + mkdir -p "${CALIBRATION_DIR}" + python "${BENCH_DIR}/calibrate_topk.py" \ + --model-name "${MODEL_NAME}" \ + --topk-val "${TOPK_VAL}" \ + --mem "${MEM}" \ + --vortex-module-name "${ALGO}" \ + --output-dir "${CALIBRATION_DIR}" \ + 2>&1 | tee "${RUN_DIR}/step1_calibrate.log" + REAL_HIST_PATH="${CALIBRATION_DIR}/raw_histograms.npy" + echo ">>> Step 1: Done. 
Calibration saved to ${CALIBRATION_DIR}" +fi + +# ── Step 2: Histogram profiling (bucket_uniform + normal) ───── +echo "" +echo ">>> Step 2: Kernel-level histogram profiling (bucket_uniform + normal)" + +BENCH_JSON="${RUN_DIR}/bench_distribution.json" + +python "${BENCH_DIR}/bench_topk.py" \ + --batch-sizes 4 \ + --seq-lens 4096 \ + --topk-vals "${TOPK_VAL}" \ + --num-kv-heads 2 \ + --distributions bucket_uniform normal \ + --histogram \ + --filter-kernels sglang_m0 sglang_m1 sglang_m2 sglang_m3 sglang_m4 sglang_m6 sglang_m7 sglang_m8 \ + --repeat 20 \ + --output-json "${BENCH_JSON}" \ + 2>&1 | tee "${RUN_DIR}/step2_bench.log" + +echo ">>> Step 2: Done. Results saved to ${BENCH_JSON}" + +# ── Step 3: Analyze — comparison plots + tables ─────────────── +echo "" +echo ">>> Step 3: Generating distribution comparison plots + tables" + +python "${BENCH_DIR}/analyze_topk_distribution.py" \ + --bench-json "${BENCH_JSON}" \ + --real-histograms "${REAL_HIST_PATH}" \ + --output-dir "${RUN_DIR}" \ + 2>&1 | tee "${RUN_DIR}/step3_analyze.log" + +echo ">>> Step 3: Done." 
+ +# ── Summary ─────────────────────────────────────────────────── +echo "" +echo "============================================================" +echo "Bucket Distribution Profiling Complete" +echo " All outputs in: ${RUN_DIR}/" +echo " bench_distribution.json — raw benchmark data" +echo " distribution_comparison.png — bucket dist plots" +echo " bucket_counts.csv — per-bucket count table" +echo " step{1,2,3}_*.log — pipeline logs" +echo "============================================================" diff --git a/examples/run_distribution_analysis_new.sh b/examples/run_distribution_analysis_new.sh new file mode 100755 index 0000000..3dc1bd4 --- /dev/null +++ b/examples/run_distribution_analysis_new.sh @@ -0,0 +1,150 @@ +#!/usr/bin/env bash +# ============================================================ +# Bucket Distribution Profiling Pipeline (modes 3, 6, 7 only) +# +# Tests only the parametric mapping modes with auto-tuning: +# Mode 3 (Power): y = sign(x) * |x|^p +# Mode 6 (Asinh): y = asinh(beta * x) +# Mode 7 (Log1p): y = sign(x) * log1p(alpha * |x|) +# Mode 8 (Trunc8): bf16 upper-8-bit bucketing +# +# Four steps: +# 1. Calibrate — collect real-data histograms +# (skippable via --real-histograms PATH) +# 2. Auto-tune — sweep hyperparameters on synthetic data +# 3. Bench — histogram profiling (bucket_uniform + normal) +# 4. 
Analyze — comparison plots + bucket count tables +# +# Usage: +# bash run_distribution_analysis_new.sh --gpu 5 +# bash run_distribution_analysis_new.sh --gpu 5 \ +# --real-histograms /path/to/calibration_dir/raw_histograms.npy +# ============================================================ +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +BENCH_DIR="${SCRIPT_DIR}/../benchmarks" + +# ── Defaults ────────────────────────────────────────────────── +GPU_ID=5 +MODEL_NAME="Qwen/Qwen3-1.7B" +TOPK_VAL=30 +MEM=0.7 +ALGO="block_sparse_attention" +# The path to the raw_histograms.npy file (set to skip calibration) +# REAL_HISTOGRAMS="/scr/dataset/yuke/xinrui/new/vortex_torch/examples/calibration/raw_histograms.npy" +REAL_HISTOGRAMS="${SCRIPT_DIR}/calibration/raw_histograms.npy" +# ── Parse arguments ─────────────────────────────────────────── +while [[ $# -gt 0 ]]; do + case "$1" in + --model-name) MODEL_NAME="$2"; shift 2 ;; + --topk-val) TOPK_VAL="$2"; shift 2 ;; + --mem) MEM="$2"; shift 2 ;; + --gpu) GPU_ID="$2"; shift 2 ;; + --algo) ALGO="$2"; shift 2 ;; + --real-histograms) REAL_HISTOGRAMS="$2"; shift 2 ;; + *) echo "Unknown option: $1"; exit 1 ;; + esac +done + +export CUDA_VISIBLE_DEVICES="${GPU_ID}" + +RESULTS_DIR="${SCRIPT_DIR}/results" +mkdir -p "${RESULTS_DIR}" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) +RUN_DIR="${RESULTS_DIR}/dist_analysis_${TIMESTAMP}" +mkdir -p "${RUN_DIR}" + +echo "============================================================" +echo "Bucket Distribution Profiling (modes 3, 6, 7)" +echo " Model: ${MODEL_NAME}" +echo " Algorithm: ${ALGO}" +echo " TopK: ${TOPK_VAL}" +echo " GPU: ${GPU_ID}" +echo " Real histograms: ${REAL_HISTOGRAMS:-}" +echo " Output: ${RUN_DIR}" +echo "============================================================" + +# ── Step 1: Calibrate — collect real-data histograms + LUT/quantiles ── +if [ -n "${REAL_HISTOGRAMS}" ]; then + echo "" + echo ">>> Step 1: SKIPPED (using provided --real-histograms 
${REAL_HISTOGRAMS})" + REAL_HIST_PATH="${REAL_HISTOGRAMS}" +else + echo "" + echo ">>> Step 1: Calibrating — collecting real-inference histograms" + CALIBRATION_DIR="${RUN_DIR}/calibration" + mkdir -p "${CALIBRATION_DIR}" + python "${BENCH_DIR}/calibrate_topk.py" \ + --model-name "${MODEL_NAME}" \ + --topk-val "${TOPK_VAL}" \ + --mem "${MEM}" \ + --vortex-module-name "${ALGO}" \ + --output-dir "${CALIBRATION_DIR}" \ + 2>&1 | tee "${RUN_DIR}/step1_calibrate.log" + REAL_HIST_PATH="${CALIBRATION_DIR}/raw_histograms.npy" + echo ">>> Step 1: Done. Calibration saved to ${CALIBRATION_DIR}" +fi + +# ── Step 2: Auto-tune — sweep hyperparameters on synthetic data ───── +echo "" +echo ">>> Step 2: Auto-tuning hyperparameters (modes 3, 6, 7)" + +AUTOTUNE_JSON="${RUN_DIR}/autotune_results.json" + +PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/autotune_topk_mapping.py" \ + --topk-val "${TOPK_VAL}" \ + --batch-size 4 \ + --seq-len 4096 \ + --num-kv-heads 2 \ + --real-histograms "${REAL_HIST_PATH}" \ + --output-json "${AUTOTUNE_JSON}" \ + 2>&1 | tee "${RUN_DIR}/step2_autotune.log" + +echo ">>> Step 2: Done. Autotune results saved to ${AUTOTUNE_JSON}" + +# ── Step 3: Histogram profiling (bucket_uniform + normal) ───── +echo "" +echo ">>> Step 3: Kernel-level histogram profiling (modes 3, 6, 7)" + +BENCH_JSON="${RUN_DIR}/bench_distribution.json" + +PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/bench_topk.py" \ + --batch-sizes 4 \ + --seq-lens 4096 \ + --topk-vals "${TOPK_VAL}" \ + --num-kv-heads 2 \ + --distributions bucket_uniform normal \ + --histogram \ + --real-histograms "${REAL_HIST_PATH}" \ + --autotune-json "${AUTOTUNE_JSON}" \ + --filter-kernels sglang_m3 sglang_m6 sglang_m7 sglang_m8 \ + --repeat 20 \ + --output-json "${BENCH_JSON}" \ + 2>&1 | tee "${RUN_DIR}/step3_bench.log" + +echo ">>> Step 3: Done. 
Results saved to ${BENCH_JSON}" + +# ── Step 4: Analyze — comparison plots + tables ─────────────── +echo "" +echo ">>> Step 4: Generating distribution comparison plots + tables" + +python "${BENCH_DIR}/analyze_topk_distribution.py" \ + --bench-json "${BENCH_JSON}" \ + --real-histograms "${REAL_HIST_PATH}" \ + --output-dir "${RUN_DIR}" \ + 2>&1 | tee "${RUN_DIR}/step4_analyze.log" + +echo ">>> Step 4: Done." + +# ── Summary ─────────────────────────────────────────────────── +echo "" +echo "============================================================" +echo "Bucket Distribution Profiling Complete (modes 3, 6, 7)" +echo " All outputs in: ${RUN_DIR}/" +echo " autotune_results.json — hyperparameter sweep rankings" +echo " bench_distribution.json — raw benchmark data" +echo " distribution_comparison.png — bucket dist plots" +echo " bucket_counts.csv — per-bucket count table" +echo " step{1,2,3,4}_*.log — pipeline logs" +echo "============================================================" diff --git a/examples/run_topk_benchmark.sh b/examples/run_topk_benchmark.sh new file mode 100755 index 0000000..5a7ed94 --- /dev/null +++ b/examples/run_topk_benchmark.sh @@ -0,0 +1,294 @@ +#!/usr/bin/env bash +# ============================================================ +# TopK Benchmark +# +# Compares ALL TopK kernel variants under controlled conditions: +# Step 1: Calibrate (for modes 1/2) +# Step 2: Kernel-level latency (bench_topk.py, all mapping modes) +# Step 3: E2E accuracy (verify_algo.py) +# - Full-attention baseline first +# - Then naive, sglang mode 0/1/2/3/4/6/7 +# - Same model, same prompts, same sampling settings +# +# Fairness improvements over verify_algo_topk_mapping.sh: +# - Full-attention baseline for absolute reference +# - All modes in one sweep (including calibrated 1/2) +# - Sequential runs on same CUDA device minimize interference +# - Fixed sampling settings (temperature=0.6) shared across runs +# - Results saved to a single timestamped directory +# +# Usage: +# bash 
run_topk_benchmark.sh [OPTIONS] +# +# Options: +# --model-name NAME HuggingFace model (default: Qwen/Qwen3-1.7B) +# --topk-val K Top-k value (default: 30) +# --trials N E2E trial count (default: 8) +# --mem FRAC GPU memory fraction (default: 0.7) +# --gpu GPU_ID CUDA device (default: 5) +# --algo NAME Sparse attention algorithm (default: block_sparse_attention) +# --skip-calibrate Reuse existing calibration data +# --skip-kernel Skip kernel-level benchmark (step 2) +# --skip-e2e Skip E2E accuracy benchmark (step 3; NOTE: step 3 is currently skipped by default — see SKIP_E2E below) +# ============================================================ +set -euo pipefail + +# use GPU_ID to set the GPU id you want to use +GPU_ID=5 + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +BENCH_DIR="${SCRIPT_DIR}/../benchmarks" + +# ── Defaults ────────────────────────────────────────────────── +MODEL_NAME="Qwen/Qwen3-1.7B" +TOPK_VAL=30 +TRIALS=8 +MEM=0.7 +ALGO="block_sparse_attention" +SKIP_CALIBRATE=false +SKIP_KERNEL=false +SKIP_E2E=true + +# ── Parse arguments ─────────────────────────────────────────── +while [[ $# -gt 0 ]]; do + case "$1" in + --model-name) MODEL_NAME="$2"; shift 2 ;; + --topk-val) TOPK_VAL="$2"; shift 2 ;; + --trials) TRIALS="$2"; shift 2 ;; + --mem) MEM="$2"; shift 2 ;; + --gpu) GPU_ID="$2"; shift 2 ;; + --algo) ALGO="$2"; shift 2 ;; + --skip-calibrate) SKIP_CALIBRATE=true; shift ;; + --skip-kernel) SKIP_KERNEL=true; shift ;; + --skip-e2e) SKIP_E2E=true; shift ;; + *) echo "Unknown option: $1"; exit 1 ;; + esac +done + +export CUDA_VISIBLE_DEVICES="${GPU_ID}" + +RESULTS_DIR="${SCRIPT_DIR}/results" +mkdir -p "${RESULTS_DIR}" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) +RUN_DIR="${RESULTS_DIR}/topk_benchmark_${TIMESTAMP}" +mkdir -p "${RUN_DIR}" + +echo "============================================================" +echo "Fair Unified TopK Benchmark" +echo " Model: ${MODEL_NAME}" +echo " Algorithm: ${ALGO}" +echo " TopK: ${TOPK_VAL}" +echo " Trials: ${TRIALS}" +echo " GPU: ${GPU_ID}" +echo " Output: ${RUN_DIR}" +echo 
"============================================================" + +# ── Step 1: Calibrate (for modes 1/2) ──────────────────────── +CALIBRATION_DIR="${RUN_DIR}/calibration" +if [ "${SKIP_CALIBRATE}" = true ] && [ -d "${CALIBRATION_DIR}" ]; then + echo "" + echo ">>> Step 1: SKIPPED (--skip-calibrate)" +else + echo "" + echo ">>> Step 1: Calibrating — collecting histograms for LUT/quantile modes" + mkdir -p "${CALIBRATION_DIR}" + python "${BENCH_DIR}/calibrate_topk.py" \ + --model-name "${MODEL_NAME}" \ + --topk-val "${TOPK_VAL}" \ + --mem "${MEM}" \ + --vortex-module-name "${ALGO}" \ + --output-dir "${CALIBRATION_DIR}" \ + 2>&1 | tee "${RUN_DIR}/step1_calibrate.log" + echo ">>> Step 1: Done." +fi + +# ── Step 2: Kernel-level latency benchmark ──────────────────── +if [ "${SKIP_KERNEL}" = true ]; then + echo "" + echo ">>> Step 2: SKIPPED (--skip-kernel)" +else + # Step 2a: Auto-tune parametric mapping modes (must run before bench) + echo "" + echo ">>> Step 2a: Auto-tuning parametric mapping hyperparameters" + AUTOTUNE_JSON="${RUN_DIR}/autotune_results.json" + REAL_HIST_ARGS="" + if [ -f "${CALIBRATION_DIR}/raw_histograms.npy" ]; then + REAL_HIST_ARGS="--real-histograms ${CALIBRATION_DIR}/raw_histograms.npy" + fi + python "${BENCH_DIR}/autotune_topk_mapping.py" \ + --topk-val "${TOPK_VAL}" \ + --batch-size 4 \ + --seq-len 4096 \ + --num-kv-heads 2 \ + ${REAL_HIST_ARGS} \ + --output-json "${AUTOTUNE_JSON}" \ + 2>&1 | tee "${RUN_DIR}/step2a_autotune.log" + echo ">>> Step 2a: Done. 
Autotune results saved to ${AUTOTUNE_JSON}" + + # Step 2b: Kernel-level latency + histogram benchmark (using autotune params) + echo "" + echo ">>> Step 2b: Kernel-level latency benchmark (all modes)" + + BENCH_JSON="${RUN_DIR}/kernel_latency.json" + + # Build calibration args + LUT_ARGS="" + if [ -f "${CALIBRATION_DIR}/lut.npy" ]; then + LUT_ARGS="--lut-path ${CALIBRATION_DIR}/lut.npy" + fi + QUANTILES_ARGS="" + if [ -f "${CALIBRATION_DIR}/quantiles.npy" ]; then + QUANTILES_ARGS="--quantiles-path ${CALIBRATION_DIR}/quantiles.npy" + fi + + python "${BENCH_DIR}/bench_topk.py" \ + --batch-sizes 4 8 16 32 \ + --seq-lens 2048 4096 8192 16384 \ + --topk-vals "${TOPK_VAL}" \ + --num-kv-heads 2 4 \ + --distributions normal lognormal uniform \ + --histogram \ + --hit-rate \ + --warmup 20 \ + --repeat 100 \ + ${LUT_ARGS} \ + ${QUANTILES_ARGS} \ + --autotune-json "${AUTOTUNE_JSON}" \ + --output-json "${BENCH_JSON}" \ + 2>&1 | tee "${RUN_DIR}/step2b_kernel_bench.log" + + echo ">>> Step 2b: Done. Results saved to ${BENCH_JSON}" + + # Step 2c: Per-mode distribution analysis + echo "" + echo ">>> Step 2c: Generating per-mode distribution analysis" + + python "${BENCH_DIR}/analyze_topk_distribution.py" \ + --bench-json "${BENCH_JSON}" \ + ${REAL_HIST_ARGS} \ + --output-dir "${RUN_DIR}" \ + 2>&1 | tee "${RUN_DIR}/step2c_analyze.log" + + echo ">>> Step 2c: Done. 
Per-mode plots saved to ${RUN_DIR}" +fi + +# ── Step 3: E2E accuracy comparison ────────────────────────── +if [ "${SKIP_E2E}" = true ]; then + echo "" + echo ">>> Step 3: SKIPPED (--skip-e2e)" +else + echo "" + echo ">>> Step 3: E2E accuracy comparison" + + E2E_DIR="${RUN_DIR}/e2e" + mkdir -p "${E2E_DIR}" + + # Helper: run verify_algo.py with common args and save output + run_e2e() { + local label="$1" + shift + local logfile="${E2E_DIR}/${label}.log" + echo "" + echo " --- ${label} ---" + { time python "${SCRIPT_DIR}/verify_algo.py" \ + --trials "${TRIALS}" \ + --topk-val "${TOPK_VAL}" \ + --model-name "${MODEL_NAME}" \ + --mem "${MEM}" \ + "$@" ; } \ + 2>&1 | tee "${logfile}" + } + + # 3a. Full-attention baseline (oracle) + run_e2e "full_attention_baseline" \ + --full-attention + + # 3b. Naive TopK + run_e2e "naive_mode0" \ + --vortex-module-name "${ALGO}" \ + --topk-type naive + + # 3c. SGLang mode 0 (no mapping) + run_e2e "sglang_mode0_none" \ + --vortex-module-name "${ALGO}" \ + --topk-type sglang \ + --topk-mapping-mode 0 + + # 3d. SGLang mode 1 (LUT CDF) — requires calibration + if [ -f "${CALIBRATION_DIR}/lut.npy" ]; then + run_e2e "sglang_mode1_lut_cdf" \ + --vortex-module-name "${ALGO}" \ + --topk-type sglang \ + --topk-mapping-mode 1 \ + --topk-mapping-lut-path "${CALIBRATION_DIR}/lut.npy" + else + echo " --- sglang_mode1_lut_cdf: SKIPPED (no lut.npy) ---" + fi + + # 3e. SGLang mode 2 (quantile) — requires calibration + if [ -f "${CALIBRATION_DIR}/quantiles.npy" ]; then + run_e2e "sglang_mode2_quantile" \ + --vortex-module-name "${ALGO}" \ + --topk-type sglang \ + --topk-mapping-mode 2 \ + --topk-mapping-quantiles-path "${CALIBRATION_DIR}/quantiles.npy" + else + echo " --- sglang_mode2_quantile: SKIPPED (no quantiles.npy) ---" + fi + + # 3f. SGLang mode 3 (power) + run_e2e "sglang_mode3_power" \ + --vortex-module-name "${ALGO}" \ + --topk-type sglang \ + --topk-mapping-mode 3 \ + --topk-mapping-power 0.5 + + # 3g. 
SGLang mode 4 (log) + run_e2e "sglang_mode4_log" \ + --vortex-module-name "${ALGO}" \ + --topk-type sglang \ + --topk-mapping-mode 4 + + # 3h. SGLang mode 6 (asinh) + run_e2e "sglang_mode6_asinh" \ + --vortex-module-name "${ALGO}" \ + --topk-type sglang \ + --topk-mapping-mode 6 \ + --topk-mapping-power 1.0 + + # 3i. SGLang mode 7 (log1p) + run_e2e "sglang_mode7_log1p" \ + --vortex-module-name "${ALGO}" \ + --topk-type sglang \ + --topk-mapping-mode 7 \ + --topk-mapping-power 1.0 + + echo "" + echo ">>> Step 3: Done. E2E logs saved to ${E2E_DIR}/" + + # ── Summary table: extract pass@N from each log ───────────── + echo "" + echo "============================================================" + echo "E2E Accuracy Summary" + echo "============================================================" + printf "%-35s %s\n" "Configuration" "Result" + printf "%-35s %s\n" "-----------------------------------" "------" + for logfile in "${E2E_DIR}"/*.log; do + label=$(basename "${logfile}" .log) + # Extract the last line matching pass@ pattern + result=$(grep -oP 'pass@\d+\s*[=:]\s*[\d.]+' "${logfile}" | tail -1 || echo "N/A") + printf "%-35s %s\n" "${label}" "${result}" + done + echo "============================================================" +fi + +# ── Final Summary ───────────────────────────────────────────── +echo "" +echo "============================================================" +echo "TopK Benchmark Complete" +echo " All results: ${RUN_DIR}" +echo " Calibration: ${CALIBRATION_DIR}" +[ "${SKIP_KERNEL}" != true ] && echo " Kernel JSON: ${RUN_DIR}/kernel_latency.json" +[ "${SKIP_KERNEL}" != true ] && echo " Per-mode: ${RUN_DIR}/distribution_comparison_m*.png, bucket_counts_m*.csv" +[ "${SKIP_E2E}" != true ] && echo " E2E logs: ${RUN_DIR}/e2e/" +echo "============================================================" diff --git a/examples/verify_algo.py b/examples/verify_algo.py index 91f92e7..e04f787 100644 --- a/examples/verify_algo.py +++ b/examples/verify_algo.py @@ 
-11,7 +11,11 @@ from lighteval.models.model_output import ModelResponse from datasets import load_dataset, Dataset, concatenate_datasets import argparse +import ast import json +import os +import subprocess +import sys MATH_QUERY_TEMPLATE = """ Solve the following math problem efficiently and clearly. The last line of your response should be of the following format: 'Therefore, the final answer is: $\\boxed{{ANSWER}}$. I hope it is correct' (without quotes) where ANSWER is just the final number or expression that solves the problem. Think step by step before answering. @@ -57,10 +61,16 @@ def verify_algos( mem: float = 0.8, kv_cache_dtype: str = "auto", topk_type: str = "naive", -): +topk_mapping_mode: int = 0, +topk_mapping_power: float = 0.5, +topk_mapping_lut_path: str = None, +topk_mapping_quantiles_path: str = None, +index_cache_shared_layers: list = None, +disable_cuda_graph: bool = False, +): llm = sgl.Engine(model_path=model_name, - disable_cuda_graph=False, + disable_cuda_graph=disable_cuda_graph, page_size=page_size, vortex_topk_val=topk_val, disable_overlap_schedule=True, @@ -74,16 +84,20 @@ def verify_algos( mem_fraction_static=mem, kv_cache_dtype=kv_cache_dtype, vortex_topk_type=topk_type, + vortex_topk_mapping_mode=topk_mapping_mode, + vortex_topk_mapping_power=topk_mapping_power, + vortex_topk_mapping_lut_path=topk_mapping_lut_path, + vortex_topk_mapping_quantiles_path=topk_mapping_quantiles_path, + vortex_index_cache_shared_layers=index_cache_shared_layers, ) - with open("amc23.jsonl", "r", encoding="utf-8") as f: requests = [json.loads(line) for line in f] - + requests = requests * trials prompts = [req["prompt"] for req in requests] sampling_params = {"temperature": 0.6, "top_p": 0.95, "top_k": 20, "max_new_tokens": 8192} - + o = llm.generate(prompts, sampling_params) gold_metric = MultilingualExtractiveMatchMetric( language=Language.ENGLISH, @@ -93,7 +107,7 @@ def verify_algos( pred_extraction_target=(ExprExtractionConfig(), 
LatexExtractionConfig(boxed_match_priority=0)), aggregation_function=max, ) - + results = [] for data, item in zip(requests, o): golds = [data["answer"]] @@ -103,7 +117,7 @@ def verify_algos( result = gold_metric.compute(model_response=ModelResponse(text=[predictions]), doc=target) except: result = 0.0 - + results.append( { "score": float(result), @@ -122,7 +136,7 @@ def verify_algos( # print(f" question: {data['question'][:120]}...") # print(f" prediction: {predictions[:200]}...") # print() - + total_accuracy = 0.0 total_tokens = 0 @@ -236,11 +250,54 @@ def parse_args(): choices=["naive", "sglang"], help='TopK kernel type: "naive" for topk_output, "sglang" for topk_output_sglang (default: "naive").', ) + parser.add_argument( + "--topk-mapping-mode", + type=int, + default=0, + choices=[0, 1, 2, 3, 4, 5, 6, 7], + help='TopK mapping mode: 0=none, 1=lut_cdf, 2=quantile, 3=power, 4=log, 5=index_cache, 6=asinh, 7=log1p (default: 0).', + ) + + parser.add_argument( + "--topk-mapping-power", + type=float, + default=0.5, + help='Hyperparameter for parametric modes: power exponent (mode 3), beta (mode 7 asinh), alpha (mode 8 log1p). 
Default: 0.5.', + ) + + parser.add_argument( + "--topk-mapping-lut-path", + type=str, + default=None, + help="Path to .npy file with uint8[256] LUT for topk mapping mode 1.", + ) + + parser.add_argument( + "--topk-mapping-quantiles-path", + type=str, + default=None, + help="Path to .npy file with float32[256] quantiles for topk mapping mode 2.", + ) + + parser.add_argument( + "--index-cache-shared-layers", + type=int, + nargs="+", + default=None, + help="Layer IDs that reuse indices from the nearest preceding full layer (skip indexer).", + ) + return parser.parse_args() if __name__ == "__main__": args = parse_args() + # --- Mode 5: Index Cache (default even-layer pattern) --- + if args.topk_mapping_mode == 5: + if args.index_cache_shared_layers is None: + args.index_cache_shared_layers = list(range(2, 28, 2)) # [2,4,6,...,26] + args.topk_mapping_mode = 0 + summary = verify_algos( trials=args.trials, topk_val=args.topk_val, @@ -251,6 +308,11 @@ def parse_args(): mem=args.mem, kv_cache_dtype=args.kv_cache_dtype, topk_type=args.topk_type, + topk_mapping_mode=args.topk_mapping_mode, + topk_mapping_power=args.topk_mapping_power, + topk_mapping_lut_path=args.topk_mapping_lut_path, + topk_mapping_quantiles_path=args.topk_mapping_quantiles_path, + index_cache_shared_layers=args.index_cache_shared_layers, ) print(summary) diff --git a/examples/verify_algo.sh b/examples/verify_algo.sh index aa01fe6..3edf9b6 100644 --- a/examples/verify_algo.sh +++ b/examples/verify_algo.sh @@ -1,6 +1,7 @@ #!/usr/bin/env bash set -e -export CUDA_VISIBLE_DEVICES=6 +# use CUDA_VISIBLE_DEVICES to set the GPU id you want to use +export CUDA_VISIBLE_DEVICES=5 sparse_algos=( "block_sparse_attention" @@ -22,4 +23,4 @@ TIMESTAMP=$(date +%Y%m%d_%H%M%S) --topk-type naive \ --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" - done \ No newline at end of file + done diff --git a/examples/verify_algo_quant.sh b/examples/verify_algo_quant.sh index a7601de..a2663e9 100644 --- a/examples/verify_algo_quant.sh +++ 
b/examples/verify_algo_quant.sh @@ -1,6 +1,7 @@ #!/usr/bin/env bash set -e -export CUDA_VISIBLE_DEVICES=6 +# use CUDA_VISIBLE_DEVICES to set the GPU id you want to use +export CUDA_VISIBLE_DEVICES=5 sparse_algos=( "block_sparse_attention" @@ -23,19 +24,4 @@ TIMESTAMP=$(date +%Y%m%d_%H%M%S) --topk-type naive \ --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" - done - - for algo in "${sparse_algos[@]}"; do - OUTFILE="${RESULTS_DIR}/${algo}_fp8_${TIMESTAMP}.log" - echo ">>> Running verify_algo.py with --vortex-module-name ${algo} --kv-cache-dtype fp8_e4m3" - echo ">>> Saving results to ${OUTFILE}" - { time python verify_algo.py \ - --trials 8 \ - --topk-val 30 \ - --vortex-module-name "${algo}" \ - --model-name Qwen/Qwen3-1.7B \ - --kv-cache-dtype fp8_e4m3 \ - --topk-type naive \ - --mem 0.7 ; } \ - 2>&1 | tee "${OUTFILE}" done \ No newline at end of file diff --git a/examples/verify_algo_topk_mapping.sh b/examples/verify_algo_topk_mapping.sh new file mode 100644 index 0000000..918252c --- /dev/null +++ b/examples/verify_algo_topk_mapping.sh @@ -0,0 +1,175 @@ +#!/usr/bin/env bash +set -e +# use CUDA_VISIBLE_DEVICES to set the GPU id you want to use +# Mapping functions: +# 0: None — original fp16 bit-pattern bucketing +# 1: LUT CDF — LUT-based CDF equalization (calibrated) +# 2: Quantile — piecewise-linear quantile mapping (calibrated) +# 3: Power — y = sign(x) * |x|^p +# 4: Log — y = sign(x) * log(|x| + 1) +# 5: Index Cache — reuse previous layer's indices +# 6: Asinh — y = asinh(beta * x) +# 7: Log1p — y = sign(x) * log1p(alpha * |x|) +# 8: Trunc8 — bf16 upper-8-bit bucketing +export CUDA_VISIBLE_DEVICES=0 + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +BENCH_DIR="${SCRIPT_DIR}/../benchmarks" + +sparse_algos=( + "block_sparse_attention" +) + +topk_mapping_modes=( + 0 # none + 3 # power + 4 # log +) +RESULTS_DIR="results" +mkdir -p "${RESULTS_DIR}" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) +# Set this to an existing calibration directory to skip re-running calibration. 
+# It must contain lut.npy and quantiles.npy (output of calibrate_topk.py). +CALIBRATION_DIR="/scr/dataset/yuke/xinrui/new/vortex_torch/examples/calibration" + +# ============================================================ +# Baseline: naive topk (mode 0) +# ============================================================ +for algo in "${sparse_algos[@]}"; do + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_naive_${TIMESTAMP}.log" + echo ">>> Running verify_algo.py with --vortex-module-name ${algo} --topk-type naive --topk-mapping-mode 0" + echo ">>> Saving results to ${OUTFILE}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --topk-type naive \ + --topk-mapping-mode 0 \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" +done + +# ============================================================ +# Calibration: collect histograms for LUT/quantile generation +# Skipped if CALIBRATION_DIR already has lut.npy + quantiles.npy +# ============================================================ +if [ -f "${CALIBRATION_DIR}/lut.npy" ] && [ -f "${CALIBRATION_DIR}/quantiles.npy" ]; then + echo ">>> Calibration SKIPPED (using existing ${CALIBRATION_DIR})" +else + CALIBRATION_DIR="${RESULTS_DIR}/calibration_${TIMESTAMP}" + for algo in "${sparse_algos[@]}"; do + echo ">>> Calibrating for ${algo}..." 
+ python "${BENCH_DIR}/calibrate_topk.py" \ + --model-name Qwen/Qwen3-1.7B \ + --topk-val 30 \ + --mem 0.7 \ + --vortex-module-name "${algo}" \ + --output-dir "${CALIBRATION_DIR}" \ + 2>&1 | tee "${RESULTS_DIR}/calibration_${algo}_${TIMESTAMP}.log" + done +fi + + + +# ============================================================ +# Mode 1: LUT CDF with calibrated LUT +# ============================================================ +for algo in "${sparse_algos[@]}"; do + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_1_calibrated_${TIMESTAMP}.log" + echo ">>> Running mode 1 (LUT CDF) with calibrated LUT for ${algo}" + echo ">>> Saving results to ${OUTFILE}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --topk-type sglang \ + --topk-mapping-mode 1 \ + --topk-mapping-lut-path "${CALIBRATION_DIR}/lut.npy" \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" +done + +# ============================================================ +# Mode 2: Quantile with calibrated quantiles +# ============================================================ +for algo in "${sparse_algos[@]}"; do + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_2_calibrated_${TIMESTAMP}.log" + echo ">>> Running mode 2 (quantile) with calibrated quantiles for ${algo}" + echo ">>> Saving results to ${OUTFILE}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --topk-type sglang \ + --topk-mapping-mode 2 \ + --topk-mapping-quantiles-path "${CALIBRATION_DIR}/quantiles.npy" \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" +done + +# ============================================================ +# sglang topk: modes that don't need calibration (0, 3, 4) +# ============================================================ +for algo in "${sparse_algos[@]}"; do + for topk_mapping_mode in "${topk_mapping_modes[@]}"; do + 
OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_${topk_mapping_mode}_${TIMESTAMP}.log" + echo ">>> Running verify_algo.py with --vortex-module-name ${algo} --topk-type sglang --topk-mapping-mode ${topk_mapping_mode}" + echo ">>> Saving results to ${OUTFILE}" + + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --topk-type sglang \ + --topk-mapping-mode ${topk_mapping_mode} \ + --topk-mapping-power 0.5 \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" + done +done + +# ============================================================ +# Mode 6: asinh — sweep beta values +# ============================================================ +for algo in "${sparse_algos[@]}"; do + for beta in 0.5 1.0 2.0; do + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_6_beta${beta}_${TIMESTAMP}.log" + echo ">>> Running mode 6 (asinh) beta=${beta} for ${algo}" + echo ">>> Saving results to ${OUTFILE}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --topk-type sglang \ + --topk-mapping-mode 6 \ + --topk-mapping-power ${beta} \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" + done +done + +# ============================================================ +# Mode 7: log1p — sweep alpha values +# ============================================================ +for algo in "${sparse_algos[@]}"; do + for alpha in 0.5 1.0 2.0; do + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_7_alpha${alpha}_${TIMESTAMP}.log" + echo ">>> Running mode 7 (log1p) alpha=${alpha} for ${algo}" + echo ">>> Saving results to ${OUTFILE}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --topk-type sglang \ + --topk-mapping-mode 7 \ + --topk-mapping-power ${alpha} \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" + done +done \ No newline at end of file diff --git 
a/examples/verify_algo_topk_mapping_indexcache.sh b/examples/verify_algo_topk_mapping_indexcache.sh new file mode 100644 index 0000000..9002084 --- /dev/null +++ b/examples/verify_algo_topk_mapping_indexcache.sh @@ -0,0 +1,45 @@ +#!/usr/bin/env bash +set -e +# use CUDA_VISIBLE_DEVICES to set the GPU id you want to use +export CUDA_VISIBLE_DEVICES=5 + +RESULTS_DIR="results" +mkdir -p "${RESULTS_DIR}" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) + +sparse_algos=( + "block_sparse_attention" +) + +# --- Mode 5: Index Cache (default even-layer pattern) --- +for algo in "${sparse_algos[@]}"; do + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_mode5_index_cache_${TIMESTAMP}.log" + echo ">>> Running verify_algo.py with --vortex-module-name ${algo} --topk-type sglang --topk-mapping-mode 5 (index cache)" + echo ">>> Saving results to ${OUTFILE}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --topk-type sglang \ + --topk-mapping-mode 5 \ + --index-cache-shared-layers 2 4 6 8 10 12 14 16 18 20 22 24 26 \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" +done + +# --- Mode 6: Greedy layer selection --- +# for algo in "${sparse_algos[@]}"; do +# OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_mode6_greedy_${TIMESTAMP}.log" +# echo ">>> Running verify_algo.py with --vortex-module-name ${algo} --topk-type sglang --topk-mapping-mode 6 (greedy)" +# echo ">>> Saving results to ${OUTFILE}" +# { time python verify_algo.py \ +# --trials 8 \ +# --topk-val 30 \ +# --vortex-module-name "${algo}" \ +# --model-name Qwen/Qwen3-1.7B \ +# --topk-type sglang \ +# --topk-mapping-mode 6 \ +# --mem 0.7 ; } \ +# 2>&1 | tee "${OUTFILE}" +#done diff --git a/examples/verify_algo_topk_mapping_new.sh b/examples/verify_algo_topk_mapping_new.sh new file mode 100644 index 0000000..b701be2 --- /dev/null +++ b/examples/verify_algo_topk_mapping_new.sh @@ -0,0 +1,128 @@ +#!/usr/bin/env bash +set -e +# use CUDA_VISIBLE_DEVICES to 
set the GPU id you want to use +# Mapping functions: +# 0: None — original fp16 bit-pattern bucketing +# 1: LUT CDF — LUT-based CDF equalization (calibrated) +# 2: Quantile — piecewise-linear quantile mapping (calibrated) +# 3: Power — y = sign(x) * |x|^p +# 4: Log — y = sign(x) * log(|x| + 1) +# 5: Index Cache — reuse previous layer's indices +# 6: Asinh — y = asinh(beta * x) +# 7: Log1p — y = sign(x) * log1p(alpha * |x|) +# 8: Trunc8 — bf16 upper-8-bit bucketing +export CUDA_VISIBLE_DEVICES=5 + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +BENCH_DIR="${SCRIPT_DIR}/../benchmarks" + +sparse_algos=( + "block_sparse_attention" +) + +# Path to real-data histograms from calibration (for auto-tuning) +REAL_HISTOGRAMS="/scr/dataset/yuke/xinrui/new/vortex_torch/examples/calibration/raw_histograms.npy" + +RESULTS_DIR="results" +mkdir -p "${RESULTS_DIR}" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) + +# ============================================================ +# Step 0: Auto-tune — find best hyperparameters per mode +# Uses topk_profile_histogram kernel on synthetic data (fast, no model) +# ============================================================ +echo "============================================================" +echo "Step 0: Auto-tuning hyperparameters (synthetic data)" +echo "============================================================" +AUTOTUNE_JSON="${RESULTS_DIR}/autotune_${TIMESTAMP}.json" +PYTHONPATH="${SCRIPT_DIR}/.." 
python "${BENCH_DIR}/autotune_topk_mapping.py" \ + --topk-val 30 \ + --batch-size 4 \ + --seq-len 4096 \ + --num-kv-heads 2 \ + --real-histograms "${REAL_HISTOGRAMS}" \ + --output-json "${AUTOTUNE_JSON}" \ + 2>&1 | tee "${RESULTS_DIR}/autotune_${TIMESTAMP}.log" +echo ">>> Auto-tune results saved to ${AUTOTUNE_JSON}" +echo "" + +# ============================================================ +# Step 1: Mode 3 (power) — sweep p values +# ============================================================ +echo "============================================================" +echo "Step 1: Mode 3 (power) — sweeping p" +echo "============================================================" +for algo in "${sparse_algos[@]}"; do + for p in 0.1 0.25 0.75 0.9; do + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_3_p${p}_${TIMESTAMP}.log" + echo ">>> Mode 3 (power) p=${p} algo=${algo}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --topk-type sglang \ + --topk-mapping-mode 3 \ + --topk-mapping-power ${p} \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" + done +done + +# ============================================================ +# Step 2: Mode 6 (asinh) — sweep beta values +# ============================================================ +echo "============================================================" +echo "Step 2: Mode 6 (asinh) — sweeping beta" +echo "============================================================" +for algo in "${sparse_algos[@]}"; do + for beta in 0.1 0.5 1.0 2.0 4.0; do + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_6_beta${beta}_${TIMESTAMP}.log" + echo ">>> Mode 6 (asinh) beta=${beta} algo=${algo}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --topk-type sglang \ + --topk-mapping-mode 6 \ + --topk-mapping-power ${beta} \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" + done 
+done + +# ============================================================ +# Step 3: Mode 7 (log1p) — sweep alpha values +# ============================================================ +echo "============================================================" +echo "Step 3: Mode 7 (log1p) — sweeping alpha" +echo "============================================================" +for algo in "${sparse_algos[@]}"; do + for alpha in 0.1 0.5 0.75 1.0 2.0 4.0 8.0; do + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_7_alpha${alpha}_${TIMESTAMP}.log" + echo ">>> Mode 7 (log1p) alpha=${alpha} algo=${algo}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --topk-type sglang \ + --topk-mapping-mode 7 \ + --topk-mapping-power ${alpha} \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" + done +done + +# ============================================================ +# Summary +# ============================================================ +echo "" +echo "============================================================" +echo "All sweeps complete. 
Results in ${RESULTS_DIR}/" +echo " Auto-tune: ${AUTOTUNE_JSON}" +echo " Mode 3 (power): p = [0.1, 0.25, 0.75, 0.9]" +echo " Mode 6 (asinh): beta = [0.1, 0.5, 1.0, 2.0, 4.0]" +echo " Mode 7 (log1p): alpha = [0.1, 0.5, 0.75, 1.0, 2.0, 4.0, 8.0]" +echo "============================================================" diff --git a/third_party/sglang b/third_party/sglang index 20e4c29..5f51c8e 160000 --- a/third_party/sglang +++ b/third_party/sglang @@ -1 +1 @@ -Subproject commit 20e4c29d206046d6b4eb3b57cc26fd20bf9c519b +Subproject commit 5f51c8ef485fb45990c8166f439da2ee695c03c1 diff --git a/vortex_torch/indexer/context.py b/vortex_torch/indexer/context.py index d6da9c1..78e2923 100644 --- a/vortex_torch/indexer/context.py +++ b/vortex_torch/indexer/context.py @@ -1,5 +1,6 @@ from __future__ import annotations from typing import Any, Final, Union +import numpy as np import torch from ..abs import ContextBase from ..utils import UNSET, Mode @@ -23,6 +24,8 @@ class Context(ContextBase): "num_sms", "page_size", "max_num_pages", "max_num_pages_per_request", # misc "indexer_dtype", "topk_val", "page_reserved_bos", "page_reserved_eos", "topk_type", + "topk_mapping_mode", "topk_mapping_power", "topk_mapping_lut", "topk_mapping_quantiles", + "topk_histogram_enabled", # auxilary memory in graph "_aux_total_bytes", @@ -69,6 +72,11 @@ class Context(ContextBase): page_reserved_bos: int #: Reserved page count for BOS (begin-of-sequence). page_reserved_eos: int #: Reserved page count for EOS (end-of-sequence). topk_type: str #: TopK kernel type: "naive" or "sglang". + topk_mapping_mode: int #: TopK mapping mode (0=none, 1=lut, 2=quantile, 3=power, 4=log). + topk_mapping_power: float #: Power exponent for mapping mode 3. + topk_mapping_lut: object #: Optional uint8[256] LUT tensor for mapping mode 1. + topk_mapping_quantiles: object #: Optional float32[256] quantiles tensor for mapping mode 2. + topk_histogram_enabled: bool #: Enable histogram profiling during inference (default False). 
# --- auxiliary --- _aux_total_bytes: int #: Accumulated auxiliary memory in bytes. @@ -146,12 +154,26 @@ def create(self, parent: Any, model_runner: Any, *, overwrite: bool = False) -> self.page_reserved_bos = sa.vortex_page_reserved_bos self.page_reserved_eos = sa.vortex_page_reserved_eos self.topk_type = getattr(sa, "vortex_topk_type", "naive") + self.topk_mapping_mode = getattr(sa, "vortex_topk_mapping_mode", 0) + self.topk_mapping_power = getattr(sa, "vortex_topk_mapping_power", 0.5) + self.topk_histogram_enabled = getattr(sa, "vortex_topk_histogram", False) + + device = getattr(model_runner, "device", "cpu") + + # Load calibration data from .npy files when paths are provided + lut_path = getattr(sa, 'vortex_topk_mapping_lut_path', None) + if lut_path is not None: + lut_np = np.load(lut_path).astype(np.uint8) + self.topk_mapping_lut = torch.from_numpy(lut_np).to(device) + + quantiles_path = getattr(sa, 'vortex_topk_mapping_quantiles_path', None) + if quantiles_path is not None: + q_np = np.load(quantiles_path).astype(np.float32) + self.topk_mapping_quantiles = torch.from_numpy(q_np).to(device) self.max_num_workloads = ( (self.max_num_pages // max(1, sa.vortex_lb_min_chunk_size)) + max_bs * self.num_kv_heads ) - - device = getattr(model_runner, "device", "cpu") self.winfo_q_indices = torch.zeros((self.max_num_workloads,), dtype=torch.int32, device=device) self.winfo_kv_offsets = torch.zeros((self.max_num_workloads,), dtype=torch.int32, device=device) self.winfo_kv_lens = torch.zeros((self.max_num_workloads,), dtype=torch.int32, device=device) diff --git a/vortex_torch/indexer/output_func.py b/vortex_torch/indexer/output_func.py index 8859d61..e4208dc 100644 --- a/vortex_torch/indexer/output_func.py +++ b/vortex_torch/indexer/output_func.py @@ -1,9 +1,21 @@ import torch -from typing import Dict, Callable, Optional +from typing import Dict, Callable, List, Optional from ..abs import vOp -from vortex_torch_C import topk_output, topk_output_sglang +from 
vortex_torch_C import topk_output, topk_output_sglang, topk_profile_histogram from .context import Context from ..abs import vTensor, FORMAT +from ..utils import UNSET + +# --- Module-level histogram accumulator for offline calibration --- +_calibration_histograms: List[torch.Tensor] = [] + +def get_calibration_histograms() -> List[torch.Tensor]: + """Return collected histogram tensors (each [eff_bs, 256] int32 on CPU).""" + return _calibration_histograms + +def clear_calibration_histograms() -> None: + """Clear all collected calibration histograms.""" + _calibration_histograms.clear() class topK(vOp): r""" @@ -86,6 +98,7 @@ def __init__(self): super().__init__() self.impl: Optional[Callable] = None self.topk_type: str = "naive" + self.last_histograms: Optional[torch.Tensor] = None # ---------------- profile ---------------- def profile(self, x: vTensor, o: vTensor, ctx: Context) -> None: @@ -232,6 +245,15 @@ def execute(self, x: torch.Tensor, o: torch.Tensor, ctx: Context) -> torch.Tenso if self.topk_type == "sglang": # topk_output_sglang: (x, dense_kv_indptr, sparse_kv_indptr, dense_kv_indices, sparse_kv_indices, ...) + mapping_mode = getattr(ctx, 'topk_mapping_mode', 0) + mapping_power = getattr(ctx, 'topk_mapping_power', 0.5) + mapping_lut = getattr(ctx, 'topk_mapping_lut', None) + mapping_quantiles = getattr(ctx, 'topk_mapping_quantiles', None) + # UNSET sentinel is not a valid torch.Tensor — coerce to None + if mapping_lut is UNSET: + mapping_lut = None + if mapping_quantiles is UNSET: + mapping_quantiles = None self.impl( x, ctx.dense_kv_indptr, @@ -243,14 +265,18 @@ def execute(self, x: torch.Tensor, o: torch.Tensor, ctx: Context) -> torch.Tenso ctx.page_reserved_bos, ctx.page_reserved_eos, ctx.max_num_pages_per_request, + mapping_mode, + mapping_power, + mapping_lut, + mapping_quantiles, ) else: - # topk_output (naive): (x, dense_kv_indptr, sparse_kv_indptr, dense_kv_indices, sparse_kv_indices, ...) 
+ # topk_output (naive): (x, dense_kv_indptr, dense_kv_indices, sparse_kv_indptr, sparse_kv_indices, ...) self.impl( x, ctx.dense_kv_indptr, - ctx.sparse_kv_indptr, ctx.dense_kv_indices, + ctx.sparse_kv_indptr, o, ctx.batch_size * ctx.num_kv_heads, ctx.topk_val, @@ -258,4 +284,30 @@ def execute(self, x: torch.Tensor, o: torch.Tensor, ctx: Context) -> torch.Tenso ctx.page_reserved_eos, ctx.max_num_pages_per_request, ) + + # Optional histogram profiling (default disabled, no overhead when off). + # Skip entirely during CUDA graph capture — allocations and D2H copies + # are not permitted while a stream is being captured. + if ( + getattr(ctx, 'topk_histogram_enabled', False) + and self.topk_type == "sglang" + and not torch.cuda.is_current_stream_capturing() + ): + eff_bs = ctx.batch_size * ctx.num_kv_heads + self.last_histograms = torch.zeros(eff_bs, 256, dtype=torch.int32, device=x.device) + topk_profile_histogram( + x, + ctx.dense_kv_indptr, + self.last_histograms, + eff_bs, + ctx.page_reserved_bos, + ctx.page_reserved_eos, + mapping_mode, + mapping_power, + mapping_lut, + mapping_quantiles, + ) + # Accumulate histograms for offline calibration + _calibration_histograms.append(self.last_histograms.cpu().clone()) + return o From 31ba23ba830fb64bde97f9145654a5dfb28cd1a4 Mon Sep 17 00:00:00 2001 From: UED Date: Wed, 1 Apr 2026 08:12:40 +0000 Subject: [PATCH 14/22] Enhance TopK mapping modes with new remap functions - Removed outdated GPU architecture flags from setup.py. - Added new mapping modes (Erf, Tanh, Subtract) to analyze_topk_distribution.py and bench_topk.py. - Updated functions to handle new modes and added support for noscale parameters in autotune and benchmark scripts. - Enhanced the TopK kernel with additional profiling metrics and improved handling of kernel arguments. - Updated example scripts to reflect new modes and parameters for distribution analysis. 
--- CLAUDE.md | 172 +++++++ benchmarks/analyze_topk_distribution.py | 25 +- benchmarks/autotune_topk_mapping.py | 336 +++++++++++-- benchmarks/bench_topk.py | 338 ++++++++++++-- csrc/clean.py | 21 + csrc/register.cc | 29 +- csrc/register.h | 43 +- csrc/topk_mapping.cuh | 44 +- csrc/topk_sglang.cu | 416 ++++++++++++++++- csrc/topk_slgang_ori.cu | 546 ++++++++++++++++++++++ examples/run_distribution_analysis.sh | 105 ++++- examples/run_distribution_analysis_new.sh | 23 +- examples/run_topk_benchmark.sh | 32 +- examples/verify_algo.py | 4 +- examples/verify_algo_topk_mapping.sh | 173 +++++-- examples/verify_algo_topk_mapping_new.sh | 209 ++++++--- setup.py | 2 - third_party/sglang | 2 +- todo.txt | 308 ++++++++++++ vortex_torch/indexer/context.py | 3 + vortex_torch/indexer/output_func.py | 3 + 21 files changed, 2601 insertions(+), 233 deletions(-) create mode 100644 CLAUDE.md create mode 100644 csrc/clean.py create mode 100644 csrc/topk_slgang_ori.cu create mode 100644 todo.txt diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..585a246 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,172 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Project Overview + +Vortex is a lightweight, modular framework for building custom sparse attention algorithms for LLM inference. It provides a PyTorch-like frontend that abstracts away batching, caching, and paged attention, running on optimized backends (FlashInfer, CUDA Graph) via SGLang integration. + +## Build & Install + +```bash +# Clone with submodules +git clone -b v1 --recursive + +# Install SGLang dependency (custom fork in third_party/, supports v0.4.9) +cd third_party/sglang && bash install.sh && cd ../../ + +# Install Vortex (editable mode, compiles CUDA extensions for SM_86/SM_89/SM_90) +pip install -e . +``` + +Requires Python >=3.10, torch>=2.7, lighteval[math]==0.12.2. 
CUDA extensions (`vortex_torch_C`) are built from `csrc/` (register.cc, utils_sglang.cu, topk.cu, topk_sglang.cu). + +## Testing & Verification + +There is no formal test suite (no pytest). Verification is done by running algorithms against SGLang reference output and comparing accuracy on math benchmarks. + +```bash +# Single algorithm verification (from examples/ directory) +python examples/verify_algo.py --trials 2 --topk-val 30 --vortex-module-name block_sparse_attention + +# Full options +python examples/verify_algo.py \ + --trials 8 --topk-val 30 \ + --vortex-module-name block_sparse_attention \ + --model-name Qwen/Qwen3-1.7B \ + --topk-type naive \ + --mem 0.7 + +# Batch test (outputs timestamped logs to examples/results/) +bash examples/verify_algo.sh + +# AIM24 benchmark verification +python examples/verify_aim24.py +``` + +Available `--topk-type` values: `naive` (CUB-based), `sglang` (SGLang-integrated kernel). + +## AI-Powered Algorithm Generation + +```bash +# Generate new sparse attention algorithms via OpenHands (requires LLM_API_KEY env var) +python openhands_gen.py +``` + +Note: Some auto-generated operators may not be fully optimized. Tune `mem_fraction_static` if OOM occurs. + +## Building Documentation + +```bash +make -C docs html +``` + +Uses Sphinx with myst_parser and furo theme. Deployed via GitHub Actions on push to v1 branch. + +## Architecture + +### Core Abstraction: vFlow (`vortex_torch/flow/flow.py`) + +All sparse attention algorithms inherit from `vFlow` and implement three methods: + +- **`forward_indexer(q, o, cache, ctx)`** — Compute sparse page indices from queries. Operates on page-packed tensor view `[S, r, c]`. +- **`forward_cache(cache, loc, ctx)`** — Update/summarize custom cache tensors when a page completes. Operates on batch-major view `[B, r, c]`. +- **`create_cache(page_size, head_dim)`** — Declare custom cache tensor shapes as a dict of `{name: (rows, cols)}`. 
+ +Algorithms are registered via `@register("name")` decorator and instantiated with `build_vflow()`. + +### Operator System (`vortex_torch/indexer/`, `vortex_torch/cache/`) + +Operators (`vOp` subclasses) run in two modes: +- **Profile mode**: Pre-compute output shapes and allocate buffers +- **Execute mode**: Perform actual GPU computation + +Operators are split into two parallel hierarchies: +- **Indexer ops** (`vortex_torch/indexer/`): GeMM, GeMV, topK, reduce (Mean/Max/Min/Sum/L2Norm), softmax, elementwise, transpose, save/load +- **Cache ops** (`vortex_torch/cache/`): GeMM, reduce, elementwise, fill, KV buffer setup + +Both use Triton kernels (in respective `triton_kernels/` subdirectories) for GPU execution. + +### Tensor Format (`vortex_torch/abs/tensor.py`) + +`vTensor` wraps `torch.Tensor` with format metadata (BATCHED, RAGGED, PAGED) to enforce layout consistency across operations. + +### Context System (`vortex_torch/abs/context_base.py`) + +`ContextBase` carries per-step runtime state. Specialized as: +- `Indexer.Context`: Page layout, head config, hardware info +- `Cache.Context`: Page size, total pages, model info + +### Concrete Algorithms (`vortex_torch/flow/algorithms.py`) + +- **BlockSparseAttention**: Centroid-based routing (query avg → GeMV with centroids → topK) +- **GQABlockSparseAttention**: Grouped-query variant with softmax + group aggregation +- **GQAQuestSparseAttention**: Query-envelope matching using per-page max/min bounds + +### Algorithm Registry (`vortex_torch/flow/registry.py`) + +Algorithms are registered via `@register("name")` and looked up with `get(name)`, `has(name)`, `list_keys()`. Factory: `build_vflow(name)` in `loader.py`. + +### SGLang Integration + +Custom SGLang fork lives in `third_party/sglang` (git submodule, "graph" branch). CUDA extensions in `csrc/` provide PyBind11 bindings for `sglang_plan_decode`, `sglang_plan_prefill`, transpose operations (NH↔HN), and top-K output routing. 
+ +## Key Conventions + +- **Tensor shapes**: Query `[B, H_q, D]`, sparse output `[S_sparse, 1, 1]`, cache indexer-view `[S, r, c]`, cache batch-view `[B, r, c]` +- **GeMM semantics**: `GeMM(x, y)` computes `y @ x^T` (note transposition) +- **Standard cache keys**: `"k"` and `"v"` have inner shape `(page_size, head_dim)`; custom caches declared in `create_cache()` +- **Branch**: Main development is on `v1` + +## Workflow Orchestration + +### 1. Plan Node Default +- Enter plan mode for ANY non-trivial task (3+ steps or architectural decisions) +- If something goes sideways, STOP and re-plan immediately - don't keep pushing +- Use plan mode for verification steps, not just building +- Write detailed specs upfront to reduce ambiguity + +### 2. Subagent Strategy +- Use subagents liberally to keep main context window clean +- Offload research, exploration, and parallel analysis to subagents +- For complex problems, throw more compute at it via subagents +- One tack per subagent for focused execution + +### 3. Self-Improvement Loop +- After ANY correction from the user: update `tasks/lessons.md` with the pattern +- Write rules for yourself that prevent the same mistake +- Ruthlessly iterate on these lessons until mistake rate drops +- Review lessons at session start for relevant project + +### 4. Verification Before Done +- Never mark a task complete without proving it works +- Diff behavior between main and your changes when relevant +- Ask yourself: "Would a staff engineer approve this?" +- Run tests, check logs, demonstrate correctness + +### 5. Demand Elegance (Balanced) +- For non-trivial changes: pause and ask "is there a more elegant way?" +- If a fix feels hacky: "Knowing everything I know now, implement the elegant solution" +- Skip this for simple, obvious fixes - don't over-engineer +- Challenge your own work before presenting it + +### 6. Autonomous Bug Fixing +- When given a bug report: just fix it. 
Don't ask for hand-holding +- Point at logs, errors, failing tests - then resolve them +- Zero context switching required from the user +- Go fix failing CI tests without being told how + +## Task Management + +1. **Plan First**: Write plan to `tasks/todo.md` with checkable items +2. **Verify Plan**: Check in before starting implementation +3. **Track Progress**: Mark items complete as you go +4. **Explain Changes**: High-level summary at each step +5. **Document Results**: Add review section to `tasks/todo.md` +6. **Capture Lessons**: Update `tasks/lessons.md` after corrections + +## Core Principles + +- **Simplicity First**: Make every change as simple as possible. Impact minimal code. +- **No Laziness**: Find root causes. No temporary fixes. Senior developer standards. +- **Minimal Impact**: Changes should only touch what's necessary. Avoid introducing bugs. \ No newline at end of file diff --git a/benchmarks/analyze_topk_distribution.py b/benchmarks/analyze_topk_distribution.py index 7d94466..00cdf28 100644 --- a/benchmarks/analyze_topk_distribution.py +++ b/benchmarks/analyze_topk_distribution.py @@ -40,6 +40,9 @@ 6: "Asinh", 7: "Log1p", 8: "Trunc8", + 9: "Erf", + 10: "Tanh", + 11: "Subtract", } MAPPING_MODE_FORMULAS = { @@ -52,25 +55,33 @@ 6: "Asinh: asinh(beta*x)", 7: "Log1p: sign(x)*log1p(alpha*|x|)", 8: "Trunc8: bf16 upper-8-bit bucketing", + 9: "Erf: erf(alpha*x)", + 10: "Tanh: tanh(alpha*x)", + 11: "Subtract: x - pivot (RadiK-style)", } def _mode_key_to_display(mode_key: str) -> str: - """Convert a mode key like 'mode_3' or 'mode_3_Power' to a display name.""" + """Convert a mode key like 'mode_3', 'mode_3_Power', or 'mode_3_Power_noscale' to display name.""" + # Handle noscale suffix + noscale = mode_key.endswith("_noscale") + base_key = mode_key[:-len("_noscale")] if noscale else mode_key + suffix = " noscale" if noscale else "" + # Handle new format: "mode_3_Power" - parts = mode_key.split("_", 2) + parts = base_key.split("_", 2) if len(parts) >= 3: - 
return parts[2] # e.g. "Power" + return parts[2] + suffix # e.g. "Power noscale" # Handle old format: "mode_3" try: mode_num = int(parts[1]) - return MAPPING_MODE_NAMES.get(mode_num, mode_key) + return MAPPING_MODE_NAMES.get(mode_num, base_key) + suffix except (IndexError, ValueError): return mode_key def _mode_key_to_number(mode_key: str) -> int: - """Extract the mode number from a key like 'mode_3' or 'mode_3_Power'.""" + """Extract the mode number from a key like 'mode_3', 'mode_3_Power', or 'mode_3_Power_noscale'.""" parts = mode_key.split("_") try: return int(parts[1]) @@ -314,7 +325,7 @@ def plot_mapping_mode_comparison(mode_stats_data: dict, output_dir: str): x = np.arange(len(modes)) width = 0.3 - fig, ax1 = plt.subplots(figsize=(10, 5)) + fig, ax1 = plt.subplots(figsize=(max(10, len(modes) * 0.8), 5)) ax2 = ax1.twinx() bars1 = ax1.bar(x - width / 2, ginis, width, label="Gini", color="darkorange") @@ -324,7 +335,7 @@ def plot_mapping_mode_comparison(mode_stats_data: dict, output_dir: str): ax1.set_ylabel("Gini") ax2.set_ylabel("Max/Mean Ratio") ax1.set_xticks(x) - ax1.set_xticklabels(mode_labels, rotation=15, ha="right") + ax1.set_xticklabels(mode_labels, rotation=30, ha="right") ax1.set_ylim(0, 1.1) ax1.set_title("Mapping Mode Comparison") diff --git a/benchmarks/autotune_topk_mapping.py b/benchmarks/autotune_topk_mapping.py index 9b37e32..d95c839 100644 --- a/benchmarks/autotune_topk_mapping.py +++ b/benchmarks/autotune_topk_mapping.py @@ -27,8 +27,8 @@ import numpy as np import torch -from bench_topk import make_topk_inputs, compute_histogram_stats -from vortex_torch_C import topk_profile_histogram +from bench_topk import make_topk_inputs, bench_kernel, compute_histogram_stats +from vortex_torch_C import topk_profile_histogram, topk_profile_counters, topk_output_sglang @@ -37,10 +37,22 @@ 3: ("power_exp", [0.1, 0.25, 0.75, 0.9]), 6: ("beta", [0.1, 0.5, 1.0, 2.0, 4.0]), 7: ("alpha", [0.1, 0.5, 0.75, 1.0, 2.0, 4.0, 8.0]), + 9: ("alpha", [0.1, 0.5, 1.0, 
2.0, 4.0]), + 10: ("alpha", [0.1, 0.5, 1.0, 2.0, 4.0]), } BASELINES = { 0: ("none", 0.5), 4: ("log", 0.5), + 8: ("trunc8", 0.5), + 11: ("subtract", 0.5), +} +# Noscale baselines for parametric transform modes (skip auto-range pre-pass) +NOSCALE_BASELINES = { + 3: ("power_noscale", [0.5]), + 6: ("asinh_noscale", [1.0]), + 7: ("log1p_noscale", [1.0]), + 9: ("erf_noscale", [1.0]), + 10: ("tanh_noscale", [1.0]), } MODE_NAMES = { 0: "none", @@ -48,6 +60,10 @@ 4: "log", 6: "asinh", 7: "log1p", + 8: "trunc8", + 9: "erf", + 10: "tanh", + 11: "subtract", } @@ -106,6 +122,56 @@ def build_bin_range_table(): return bin_lo, bin_hi +def generate_remap_lut(mode: int, param: float) -> np.ndarray: + """Generate a 256-entry uint8 LUT that approximates a transform mode. + + For each of the 256 fp16 radix bins, compute the transform of the + bin's midpoint value, then linearly map transformed values to [0,255]. + The resulting LUT can be used with mode=1 (LUT CDF) infrastructure, + replacing expensive per-element transcendental math with a single + shared memory lookup. 
+ + Args: + mode: TopKMappingMode (3=Power, 4=Log, 6=Asinh, 7=Log1p, 9=Erf, 10=Tanh) + param: power_exp/beta/alpha for the transform + + Returns: + lut: [256] uint8 array mapping original_bin -> remapped_bin + """ + bin_lo, bin_hi = build_bin_range_table() + midpoints = (bin_lo + bin_hi) / 2.0 # [256] float32 + + # Apply transform + if mode == 3: # power + transformed = np.sign(midpoints) * np.abs(midpoints) ** param + elif mode == 4: # log + transformed = np.sign(midpoints) * np.log(np.abs(midpoints) + 1.0) + elif mode == 6: # asinh + transformed = np.arcsinh(param * midpoints) + elif mode == 7: # log1p + transformed = np.sign(midpoints) * np.log1p(param * np.abs(midpoints)) + elif mode == 9: # erf + from scipy.special import erf + transformed = erf(param * midpoints) + elif mode == 10: # tanh + transformed = np.tanh(param * midpoints) + else: + # Identity fallback + transformed = midpoints.copy() + + # Handle NaN/Inf from edge cases + transformed = np.nan_to_num(transformed, nan=0.0, posinf=0.0, neginf=0.0) + + # Linear map to [0, 255] + tmin, tmax = transformed.min(), transformed.max() + if tmax > tmin: + lut = np.clip(((transformed - tmin) / (tmax - tmin) * 255), 0, 255).astype(np.uint8) + else: + lut = np.full(256, 128, dtype=np.uint8) + + return lut + + def scores_from_histogram( histogram: np.ndarray, total_pages: int, @@ -232,7 +298,8 @@ def run_sweep(args) -> List[dict]: eff_bs = inputs["eff_batch_size"] - def evaluate(mode: int, power: float, label: str): + def evaluate(mode: int, power: float, label: str, noscale: bool = False, + lut_tensor=None): hists = torch.zeros(eff_bs, 256, dtype=torch.int32, device="cuda") topk_profile_histogram( inputs["x"], @@ -243,28 +310,62 @@ def evaluate(mode: int, power: float, label: str): args.reserved_eos, mode, power, - None, # lut + lut_tensor, # lut None, # quantiles + noscale, ) torch.cuda.synchronize() stats = compute_histogram_stats(hists) - return { + result = { "label": label, "mode": mode, "mode_name": 
MODE_NAMES.get(mode, f"m{mode}"), "param": power, + "noscale": noscale, "distribution": dist, "gini": stats["gini"], "max_mean_ratio": stats["max_mean_ratio"], "num_nonzero_bins": stats["num_nonzero_bins"], } + # Counter-based metrics (Stage 2 cost analysis) + if args.counters: + inputs["sparse_kv_indices"].zero_() + counter_buf = torch.zeros(eff_bs, 6, dtype=torch.int32, device="cuda") + topk_profile_counters( + inputs["x"], + inputs["dense_kv_indptr"], + inputs["sparse_kv_indptr"], + inputs["dense_kv_indices"], + inputs["sparse_kv_indices"], + counter_buf, + eff_bs, + args.topk_val, + args.reserved_bos, + args.reserved_eos, + inputs["num_pages_per_seg"], + mode, + power, + lut_tensor, # lut + None, # quantiles + noscale, + ) + torch.cuda.synchronize() + c = counter_buf.float() + result["num_equal_mean"] = c[:, 2].mean().item() + result["remaining_k_mean"] = c[:, 3].mean().item() + result["refine_rounds_mean"] = c[:, 4].mean().item() + result["stage2_input_mean"] = c[:, 5].mean().item() + result["res_rate_mean"] = (c[:, 3] == 0).float().mean().item() + + return result + # Baselines for mode, (name, default_power) in BASELINES.items(): r = evaluate(mode, default_power, f"m{mode}_{name}") results.append(r) - # Parametric sweep + # Parametric sweep (scaled) for mode, (param_name, values) in SWEEP_GRID.items(): mname = MODE_NAMES[mode] for val in values: @@ -272,44 +373,122 @@ def evaluate(mode: int, power: float, label: str): r = evaluate(mode, val, label) results.append(r) + # Noscale sweep for parametric modes + for mode, (name, values) in NOSCALE_BASELINES.items(): + mname = MODE_NAMES[mode] + for val in values: + label = f"m{mode}_{mname}_noscale_{val}" + r = evaluate(mode, val, label, noscale=True) + results.append(r) + + # LUT approximation sweep: generate a LUT for each (mode, param) and + # evaluate via mode=1 (LUT CDF). This replaces per-element transcendentals + # with a single shared memory lookup. 
+ if args.lut_sweep: + lut_modes = { + 3: [0.25, 0.5, 0.75], + 6: [0.5, 1.0, 2.0], + 7: [0.5, 1.0, 2.0], + 9: [0.5, 1.0, 2.0], + 10: [0.5, 1.0, 2.0], + } + for src_mode, params in lut_modes.items(): + src_name = MODE_NAMES[src_mode] + for p in params: + try: + lut_np = generate_remap_lut(src_mode, p) + lut_t = torch.from_numpy(lut_np).cuda() + label = f"lut_{src_name}_{p}" + # Evaluate as mode=1 (LUT CDF) with the generated LUT + r = evaluate(1, 0.5, label, lut_tensor=lut_t) + r["lut_source_mode"] = src_mode + r["lut_source_param"] = p + results.append(r) + except ImportError: + # scipy not available for erf + pass + return results -def print_table(results: List[dict]): +def print_table(results: List[dict], show_latency: bool = False): """Print ranked results as a formatted table.""" - # Sort by Gini ascending (lower = more uniform = better) - ranked = sorted(results, key=lambda r: r["gini"]) + has_counters = any("res_rate_mean" in r for r in results) + has_latency = any("full_kernel_ms" in r for r in results) - header = ( - f"{'Rank':>4s} {'Label':<35s} {'Dist':<12s} " - f"{'Gini':>6s} {'Max/Mean':>8s} {'NZBins':>6s}" - ) - print("\n" + "=" * len(header)) - print("TopK Mapping Auto-Tune Results (ranked by Gini, lower=better)") - print("=" * len(header)) - print(header) - print("-" * len(header)) + # Primary ranking: by res_rate_mean (higher=better) if counters, else by gini (lower=better) + if has_counters: + ranked = sorted(results, key=lambda r: -r.get("res_rate_mean", 0.0)) + rank_label = "ranked by res_rate, higher=better" + else: + ranked = sorted(results, key=lambda r: r["gini"]) + rank_label = "ranked by Gini, lower=better" + + # Build header + cols = f"{'Rank':>4s} {'Label':<35s} {'Dist':<12s} {'Gini':>6s} {'Max/Mean':>8s} {'NZBins':>6s}" + if has_counters: + cols += f" {'ResRate':>7s} {'RemK':>5s} {'Rnds':>4s} {'S2In':>5s}" + if has_latency and show_latency: + cols += f" {'LatMs':>9s} {'LatRk':>5s}" + + print(f"\n{'=' * len(cols)}") + print(f"TopK Mapping 
Auto-Tune Results ({rank_label})") + print("=" * len(cols)) + print(cols) + print("-" * len(cols)) for i, r in enumerate(ranked): - print( - f"{i+1:4d} {r['label']:<35s} {r['distribution']:<12s} " + noscale_tag = " [NS]" if r.get("noscale", False) else "" + line = ( + f"{i+1:4d} {r['label'] + noscale_tag:<35s} {r['distribution']:<12s} " f"{r['gini']:6.3f} " f"{r['max_mean_ratio']:8.2f} {r['num_nonzero_bins']:6d}" ) - - print("=" * len(header)) + if has_counters: + rr = r.get("res_rate_mean", 0.0) + rk = r.get("remaining_k_mean", 0.0) + rnds = r.get("refine_rounds_mean", 0.0) + s2in = r.get("stage2_input_mean", 0.0) + line += f" {rr:7.3f} {rk:5.0f} {rnds:4.1f} {s2in:5.0f}" + if has_latency and show_latency: + lat = r.get("full_kernel_ms", float("nan")) + lat_rank = r.get("latency_rank", "-") + line += f" {lat:9.4f} {lat_rank:>5s}" if isinstance(lat_rank, str) else f" {lat:9.4f} {lat_rank:5d}" + print(line) + + print("=" * len(cols)) if ranked: best = ranked[0] - print( + msg = ( f"\nBest overall: {best['label']} (dist={best['distribution']}) " f"— gini={best['gini']:.3f}, max/mean={best['max_mean_ratio']:.2f}" ) + if has_counters: + msg += f", res_rate={best.get('res_rate_mean', 0):.3f}" + if "full_kernel_ms" in best: + msg += f", latency={best['full_kernel_ms']:.4f}ms" + print(msg) + + # If latency data available, also print best by latency + if has_latency and show_latency: + lat_ranked = sorted([r for r in results if "full_kernel_ms" in r], + key=lambda r: r["full_kernel_ms"]) + if lat_ranked: + best_lat = lat_ranked[0] + print( + f"Best by latency: {best_lat['label']} (dist={best_lat['distribution']}) " + f"— latency={best_lat['full_kernel_ms']:.4f}ms, gini={best_lat['gini']:.3f}" + ) - # Per-mode best summary (lowest gini per mode) + # Per-mode best summary mode_best = {} for r in results: m = r["mode"] - if m not in mode_best or r["gini"] < mode_best[m]["gini"]: + if has_counters: + is_better = m not in mode_best or r.get("res_rate_mean", 0) > 
mode_best[m].get("res_rate_mean", 0) + else: + is_better = m not in mode_best or r["gini"] < mode_best[m]["gini"] + if is_better: mode_best[m] = r if mode_best: @@ -322,12 +501,94 @@ def print_table(results: List[dict]): param_str = f"{param_name}={r['param']}" else: param_str = "(baseline)" + ns_str = " noscale" if r.get("noscale", False) else "" + lat_str = f" latency={r['full_kernel_ms']:.4f}ms" if "full_kernel_ms" in r else "" + counter_str = f" res_rate={r.get('res_rate_mean', 0):.3f}" if has_counters else "" print( - f" Mode {m:d} ({mname:>5s}): {param_str:<20s} " - f"gini={r['gini']:.3f} max/mean={r['max_mean_ratio']:.2f}" + f" Mode {m:d} ({mname:>5s}{ns_str}): {param_str:<20s} " + f"gini={r['gini']:.3f} max/mean={r['max_mean_ratio']:.2f}{counter_str}{lat_str}" ) +def latency_rerank(results: List[dict], args) -> List[dict]: + """Re-rank top Gini candidates by actual kernel latency.""" + # Sort by Gini, take top N + ranked = sorted(results, key=lambda r: r["gini"]) + finalists = ranked[:args.latency_top_n] + + print(f"\n--- Latency re-ranking: timing top {len(finalists)} Gini finalists ---") + + # Build inputs for latency measurement + real_histogram = None + if args.real_histograms: + raw = np.load(args.real_histograms) + real_histogram = raw.sum(axis=0) if raw.ndim > 1 else raw + + if real_histogram is not None: + inputs = make_real_inputs( + batch_size=args.batch_size, + num_kv_heads=args.num_kv_heads, + seq_len=args.seq_len, + page_size=args.page_size, + topk_val=args.topk_val, + reserved_bos=args.reserved_bos, + reserved_eos=args.reserved_eos, + histogram=real_histogram, + ) + else: + inputs = make_topk_inputs( + batch_size=args.batch_size, + num_kv_heads=args.num_kv_heads, + seq_len=args.seq_len, + page_size=args.page_size, + topk_val=args.topk_val, + reserved_bos=args.reserved_bos, + reserved_eos=args.reserved_eos, + score_dtype=torch.bfloat16, + distribution="normal", + ) + + eff_bs = inputs["eff_batch_size"] + pages_per_seg = 
inputs["num_pages_per_seg"] + + for r in finalists: + inputs["sparse_kv_indices"].zero_() + # For LUT-generated entries, regenerate the LUT tensor + lut_tensor = None + if "lut_source_mode" in r: + lut_np = generate_remap_lut(r["lut_source_mode"], r["lut_source_param"]) + lut_tensor = torch.from_numpy(lut_np).cuda() + call_args = ( + inputs["x"], + inputs["dense_kv_indptr"], + inputs["sparse_kv_indptr"], + inputs["dense_kv_indices"], + inputs["sparse_kv_indices"], + eff_bs, + args.topk_val, + args.reserved_bos, + args.reserved_eos, + pages_per_seg, + r["mode"], + r["param"], + lut_tensor, # lut + None, # quantiles + r.get("noscale", False), + ) + latency = bench_kernel(topk_output_sglang, call_args, + warmup=10, repeat=args.latency_repeat) + r["full_kernel_ms"] = latency["mean_ms"] + print(f" {r['label']:<35s} gini={r['gini']:.3f} latency={latency['mean_ms']:.4f}ms") + + # Re-rank finalists by latency + finalists.sort(key=lambda r: r["full_kernel_ms"]) + for i, r in enumerate(finalists): + r["latency_rank"] = i + 1 + r["gini_rank"] = next(j+1 for j, x in enumerate(ranked) if x is r) + + return results + + def main(): parser = argparse.ArgumentParser( description="Auto-tune TopK mapping hyperparameters" @@ -353,6 +614,16 @@ def main(): "--output-json", type=str, default=None, help="Save results to JSON file", ) + parser.add_argument("--latency-rerank", action="store_true", + help="Re-rank top Gini finalists by actual kernel latency") + parser.add_argument("--latency-top-n", type=int, default=10, + help="Number of Gini finalists to re-rank by latency (default: 10)") + parser.add_argument("--latency-repeat", type=int, default=50, + help="Kernel timing repetitions for latency measurement (default: 50)") + parser.add_argument("--counters", action="store_true", + help="Collect counter-based metrics (Stage 2 cost analysis) for each config") + parser.add_argument("--lut-sweep", action="store_true", + help="Generate and evaluate LUT approximations for parametric transform 
modes") args = parser.parse_args() source = f"real ({args.real_histograms})" if args.real_histograms else f"synthetic ({args.distributions})" @@ -361,12 +632,17 @@ def main(): f"topk_val={args.topk_val}, num_kv_heads={args.num_kv_heads}") print(f" score source: {source}") n_parametric = sum(len(v) for _, v in SWEEP_GRID.values()) + n_baselines = len(BASELINES) n_dists = 1 if args.real_histograms else len(args.distributions) - print(f" sweep: {n_parametric} parametric + {len(BASELINES)} baselines " - f"= {n_parametric + len(BASELINES)} combos x {n_dists} dists") + print(f" sweep: {n_parametric} parametric + {n_baselines} baselines " + f"= {n_parametric + n_baselines} combos x {n_dists} dists") results = run_sweep(args) - print_table(results) + + if args.latency_rerank: + results = latency_rerank(results, args) + + print_table(results, show_latency=args.latency_rerank) if args.output_json: with open(args.output_json, "w") as f: diff --git a/benchmarks/bench_topk.py b/benchmarks/bench_topk.py index ca039f2..675092e 100644 --- a/benchmarks/bench_topk.py +++ b/benchmarks/bench_topk.py @@ -18,7 +18,10 @@ import numpy as np import torch -from vortex_torch_C import topk_output, topk_output_sglang, topk_profile_histogram +from vortex_torch_C import ( + topk_output, topk_output_sglang, topk_profile_histogram, + topk_profile_stage1, topk_profile_counters, +) # Canonical mapping mode names — used in logs, tables, and plots MAPPING_MODE_NAMES = { @@ -31,6 +34,9 @@ 6: "Asinh", 7: "Log1p", 8: "Trunc8", + 9: "Erf", + 10: "Tanh", + 11: "Subtract", } MAPPING_MODE_FORMULAS = { @@ -43,6 +49,9 @@ 6: "Asinh: asinh(beta*x)", 7: "Log1p: sign(x)*log1p(alpha*|x|)", 8: "Trunc8: bf16 upper-8-bit bucketing", + 9: "Erf: erf(alpha*x)", + 10: "Tanh: tanh(alpha*x)", + 11: "Subtract: x - pivot (RadiK-style)", } @@ -200,7 +209,7 @@ def _load_autotune_powers(path: str) -> Dict[int, float]: best: Dict[int, dict] = {} for r in data: m = r.get("mode") - if m not in (3, 6, 7): + if m not in (3, 6, 7, 9, 
10): continue if has_res_rate: score = r.get("res_rate_mean", 0.0) @@ -219,7 +228,8 @@ def _resolve_mode_power(args, mode: int) -> float: Priority: per-mode CLI flag > autotune JSON > global --mapping-power. """ - per_mode_flag = {3: args.mapping_power_3, 6: args.mapping_power_6, 7: args.mapping_power_7} + per_mode_flag = {3: args.mapping_power_3, 6: args.mapping_power_6, 7: args.mapping_power_7, + 9: getattr(args, 'mapping_power_9', None), 10: getattr(args, 'mapping_power_10', None)} if mode in per_mode_flag and per_mode_flag[mode] is not None: return per_mode_flag[mode] if hasattr(args, "_autotune_powers") and mode in args._autotune_powers: @@ -276,11 +286,20 @@ def run_benchmark(args) -> List[dict]: all_kernels = { "naive": "naive", "sglang_m0": "sglang_m0", + "sglang_scale": "sglang_scale", # mode 3 with p=1.0 (identity + linear auto-range scaling) "sglang_m3": "sglang_m3", + "sglang_m3_noscale": "sglang_m3_noscale", "sglang_m4": "sglang_m4", "sglang_m6": "sglang_m6", + "sglang_m6_noscale": "sglang_m6_noscale", "sglang_m7": "sglang_m7", + "sglang_m7_noscale": "sglang_m7_noscale", "sglang_m8": "sglang_m8", + "sglang_m9": "sglang_m9", + "sglang_m9_noscale": "sglang_m9_noscale", + "sglang_m10": "sglang_m10", + "sglang_m10_noscale": "sglang_m10_noscale", + "sglang_m11": "sglang_m11", } if mapping_lut is not None: all_kernels["sglang_m1"] = "sglang_m1" @@ -288,6 +307,23 @@ def run_benchmark(args) -> List[dict]: all_kernels["sglang_m2"] = "sglang_m2" if args.filter_kernels: + # Validate: if the user explicitly requested sglang_m1 or sglang_m2 but + # the required calibration file was not provided, fail loudly instead of + # silently skipping these modes. + if "sglang_m1" in args.filter_kernels and "sglang_m1" not in all_kernels: + raise RuntimeError( + "sglang_m1 (LUT CDF) was requested in --filter-kernels but no " + "--lut-path was provided. Mode 1 requires a calibrated LUT file " + "(lut.npy from calibrate_topk.py). 
Either supply --lut-path or " + "remove sglang_m1 from --filter-kernels." + ) + if "sglang_m2" in args.filter_kernels and "sglang_m2" not in all_kernels: + raise RuntimeError( + "sglang_m2 (Quantile) was requested in --filter-kernels but no " + "--quantiles-path was provided. Mode 2 requires a calibrated " + "quantiles file (quantiles.npy from calibrate_topk.py). Either " + "supply --quantiles-path or remove sglang_m2 from --filter-kernels." + ) all_kernels = {k: v for k, v in all_kernels.items() if k in args.filter_kernels} # Naive kernel only supports bf16 @@ -352,12 +388,14 @@ def run_benchmark(args) -> List[dict]: "kernels": {}, } + # Collect all kernel results first, then print sorted by latency + kernel_entries = [] # [(label, kernel_name, result)] + for kernel_name in all_kernels: # Reset sparse indices each run inputs["sparse_kv_indices"].zero_() if kernel_name == "naive": - # topk_output: (x, dense_indptr, dense_indices, sparse_indptr, sparse_indices, ...) call_args = ( inputs["x"], inputs["dense_kv_indptr"], @@ -371,18 +409,39 @@ def run_benchmark(args) -> List[dict]: pages_per_seg, ) result = bench_kernel(topk_output, call_args, args.warmup, args.repeat) + elif kernel_name == "sglang_scale": + call_args = ( + inputs["x"], + inputs["dense_kv_indptr"], + inputs["sparse_kv_indptr"], + inputs["dense_kv_indices"], + inputs["sparse_kv_indices"], + eff_bs, + topk_val, + args.reserved_bos, + args.reserved_eos, + pages_per_seg, + 3, # mode 3 (power) + 1.0, # p=1.0 → identity + None, + None, + ) + result = bench_kernel(topk_output_sglang, call_args, args.warmup, args.repeat) else: - # Parse mapping mode from kernel name - mode = int(kernel_name.split("_m")[1]) + mode_str = kernel_name.split("_m")[1] + mode = int(mode_str.split("_")[0]) + is_noscale = kernel_name.endswith("_noscale") extra_kwargs = {} if mode == 1: extra_kwargs["mapping_lut"] = mapping_lut elif mode == 2: extra_kwargs["mapping_quantiles"] = mapping_quantiles - power = _resolve_mode_power(args, 
mode) if mode in (3, 6, 7) else 0.5 + if mode in (3, 6, 7, 9, 10): + power = _resolve_mode_power(args, mode) + else: + power = 0.5 - # topk_output_sglang: (x, dense_indptr, sparse_indptr, dense_indices, sparse_indices, ...) call_args = ( inputs["x"], inputs["dense_kv_indptr"], @@ -398,25 +457,175 @@ def run_benchmark(args) -> List[dict]: power, extra_kwargs.get("mapping_lut", None), extra_kwargs.get("mapping_quantiles", None), + is_noscale, ) result = bench_kernel(topk_output_sglang, call_args, args.warmup, args.repeat) + # Build label if kernel_name == "naive": label = "naive" + elif kernel_name == "sglang_scale": + label = "sglang Scale Only (p=1.0)" else: - m = int(kernel_name.split("_m")[1]) + m_str = kernel_name.split("_m")[1] + m = int(m_str.split("_")[0]) + noscale_suffix = " noscale" if kernel_name.endswith("_noscale") else "" mname = MAPPING_MODE_NAMES.get(m, f'm{m}') - if m in (3, 6, 7): - pname = {3: "p", 6: "beta", 7: "alpha"}[m] - label = f"sglang {mname} ({pname}={_resolve_mode_power(args, m)})" + if m in (3, 6, 7, 9, 10): + pname = {3: "p", 6: "beta", 7: "alpha", 9: "alpha", 10: "alpha"}[m] + label = f"sglang {mname} ({pname}={_resolve_mode_power(args, m)}){noscale_suffix}" + else: + label = f"sglang {mname}{noscale_suffix}" + + # Sub-phase profiling for sglang kernels + if kernel_name != "naive": + if kernel_name == "sglang_scale": + s1_mode, s1_power = 3, 1.0 + s1_lut, s1_q = None, None + s1_noscale = False else: - label = f"sglang {mname}" + s1_mode_str = kernel_name.split("_m")[1] + s1_mode = int(s1_mode_str.split("_")[0]) + s1_noscale = kernel_name.endswith("_noscale") + if s1_mode in (3, 6, 7, 9, 10): + s1_power = _resolve_mode_power(args, s1_mode) + else: + s1_power = 0.5 + s1_lut = mapping_lut if s1_mode == 1 else None + s1_q = mapping_quantiles if s1_mode == 2 else None + + # Histogram only: pre-pass + histogram build + hist_buf = torch.zeros(eff_bs, 256, dtype=torch.int32, device="cuda") + hist_args = ( + inputs["x"], + 
inputs["dense_kv_indptr"], + hist_buf, + eff_bs, + args.reserved_bos, + args.reserved_eos, + s1_mode, + s1_power, + s1_lut, + s1_q, + s1_noscale, + ) + hist_result = bench_kernel(topk_profile_histogram, hist_args, args.warmup, args.repeat) + + # Stage1 full: pre-pass + hist + cumsum + route/filter + inputs["sparse_kv_indices"].zero_() + stage1_args = ( + inputs["x"], + inputs["dense_kv_indptr"], + inputs["sparse_kv_indptr"], + inputs["dense_kv_indices"], + inputs["sparse_kv_indices"], + eff_bs, + topk_val, + args.reserved_bos, + args.reserved_eos, + pages_per_seg, + s1_mode, + s1_power, + s1_lut, + s1_q, + s1_noscale, + ) + stage1_result = bench_kernel(topk_profile_stage1, stage1_args, args.warmup, args.repeat) + + result['histogram_only_mean_ms'] = hist_result['mean_ms'] + result['histogram_only_median_ms'] = hist_result['median_ms'] + result['stage1_full_mean_ms'] = stage1_result['mean_ms'] + result['stage1_full_median_ms'] = stage1_result['median_ms'] + result['route_overhead_mean_ms'] = stage1_result['mean_ms'] - hist_result['mean_ms'] + result['route_overhead_median_ms'] = stage1_result['median_ms'] - hist_result['median_ms'] + result['stage2_refine_mean_ms'] = result['mean_ms'] - stage1_result['mean_ms'] + result['stage2_refine_median_ms'] = result['median_ms'] - stage1_result['median_ms'] + + # Optional counter collection + if args.counters: + inputs["sparse_kv_indices"].zero_() + counter_buf = torch.zeros(eff_bs, 6, dtype=torch.int32, device="cuda") + counter_args = ( + inputs["x"], + inputs["dense_kv_indptr"], + inputs["sparse_kv_indptr"], + inputs["dense_kv_indices"], + inputs["sparse_kv_indices"], + counter_buf, + eff_bs, + topk_val, + args.reserved_bos, + args.reserved_eos, + pages_per_seg, + s1_mode, + s1_power, + s1_lut, + s1_q, + s1_noscale, + ) + topk_profile_counters(*counter_args) + torch.cuda.synchronize() + c = counter_buf.float() + result['counters'] = { + 'threshold_bin_mean': c[:, 0].mean().item(), + 'num_above_mean': c[:, 1].mean().item(), + 
'num_equal_mean': c[:, 2].mean().item(), + 'remaining_k_mean': c[:, 3].mean().item(), + 'refine_rounds_mean': c[:, 4].mean().item(), + 'stage2_input_mean': c[:, 5].mean().item(), + 'threshold_bin_max': c[:, 0].max().item(), + 'num_above_max': c[:, 1].max().item(), + 'num_equal_max': c[:, 2].max().item(), + 'remaining_k_max': c[:, 3].max().item(), + 'refine_rounds_max': c[:, 4].max().item(), + 'stage2_input_max': c[:, 5].max().item(), + } + + kernel_entries.append((label, kernel_name, result)) + config_results["kernels"][kernel_name] = result + + # Print kernel results sorted by mean latency (ascending) + kernel_entries.sort(key=lambda e: e[2]['mean_ms']) + print(f" --- kernel latency (sorted by mean, ascending) ---") + for label, kernel_name, result in kernel_entries: print( - f" {label:<30s}: {result['median_ms']:.4f}ms (median) " + f" {label:<40s}: " + f"mean={result['mean_ms']:.4f}ms " + f"median={result['median_ms']:.4f}ms " f"\u00b1 {result['std_ms']:.4f}ms " f"[min={result['min_ms']:.4f}, max={result['max_ms']:.4f}]" ) - config_results["kernels"][kernel_name] = result + if 'stage1_full_mean_ms' in result: + print( + f" {'Histogram only (map+hist)':<36s}: " + f"mean={result['histogram_only_mean_ms']:.4f}ms " + f"median={result['histogram_only_median_ms']:.4f}ms" + ) + print( + f" {'Stage1 full (hist+cumsum+route)':<36s}: " + f"mean={result['stage1_full_mean_ms']:.4f}ms " + f"median={result['stage1_full_median_ms']:.4f}ms" + ) + print( + f" {'Route overhead (cumsum+route)':<36s}: " + f"mean={result['route_overhead_mean_ms']:.4f}ms " + f"median={result['route_overhead_median_ms']:.4f}ms" + ) + print( + f" {'Stage2 (refine)':<36s}: " + f"mean={result['stage2_refine_mean_ms']:.4f}ms " + f"median={result['stage2_refine_median_ms']:.4f}ms" + ) + if 'counters' in result: + c = result['counters'] + print( + f" Counters: threshold_bin={c['threshold_bin_mean']:.0f} " + f"above={c['num_above_mean']:.0f} " + f"equal={c['num_equal_mean']:.0f} " + 
f"remaining_k={c['remaining_k_mean']:.0f} " + f"refine_rounds={c['refine_rounds_mean']:.1f} " + f"stage2_input={c['stage2_input_mean']:.0f}" + ) # Histogram analysis if args.histogram: @@ -476,22 +685,25 @@ def run_benchmark(args) -> List[dict]: f"nonzero_bins={hstats['num_nonzero_bins']}/256" ) - # Per-mode histogram analysis - modes_to_test = [0, 3, 4, 6, 7, 8] + # Collect all histogram entries, then print sorted by gini + # Each entry: (display_name, key, mode_stats) + hist_entries = [] + histograms_results = {} + + # Per-mode histogram analysis (scaled) + modes_to_test = [0, 3, 4, 6, 7, 8, 9, 10, 11] if mapping_lut is not None: modes_to_test.append(1) if mapping_quantiles is not None: modes_to_test.append(2) modes_to_test.sort() - histograms_results = {} - print(f" --- histogram by mapping mode ---") for mode in modes_to_test: mode_hists = torch.zeros(hist_eff_bs, 256, dtype=torch.int32, device="cuda") extra_lut = mapping_lut if mode == 1 else None extra_q = mapping_quantiles if mode == 2 else None - power = _resolve_mode_power(args, mode) if mode in (3, 6, 7) else 0.5 + power = _resolve_mode_power(args, mode) if mode in (3, 6, 7, 9, 10) else 0.5 topk_profile_histogram( hist_inputs["x"], @@ -513,23 +725,84 @@ def run_benchmark(args) -> List[dict]: mformula = MAPPING_MODE_FORMULAS.get(mode, mname) mode_stats["name"] = mname mode_stats["formula"] = mformula - if mode in (3, 6, 7): - pname = {3: "p", 6: "beta", 7: "alpha"}[mode] + if mode in (3, 6, 7, 9, 10): + pname = {3: "p", 6: "beta", 7: "alpha", 9: "alpha", 10: "alpha"}[mode] mode_stats["param"] = f"{pname}={power}" - histograms_results[f"mode_{mode}_{mname}"] = mode_stats - if mode in (3, 6, 7): - pname = {3: "p", 6: "beta", 7: "alpha"}[mode] display_name = f"{mname} ({pname}={power})" else: display_name = mname + key = f"mode_{mode}_{mname}" + histograms_results[key] = mode_stats + hist_entries.append((display_name, f"mode {mode:2d}", mode_stats)) + + # Noscale histogram analysis for parametric transform 
modes + noscale_modes = [m for m in (3, 6, 7, 9, 10) if m in modes_to_test] + for mode in noscale_modes: + ns_hists = torch.zeros(hist_eff_bs, 256, dtype=torch.int32, device="cuda") + power = _resolve_mode_power(args, mode) + topk_profile_histogram( + hist_inputs["x"], + hist_inputs["dense_kv_indptr"], + ns_hists, + hist_eff_bs, + args.reserved_bos, + args.reserved_eos, + mode, + power, + None, + None, + True, # mapping_noscale=True + ) + torch.cuda.synchronize() + ns_stats = compute_histogram_stats(ns_hists) + ns_stats["raw_counts"] = ns_hists.sum(dim=0).tolist() + mname = MAPPING_MODE_NAMES.get(mode, f"m{mode}") + mformula = MAPPING_MODE_FORMULAS.get(mode, mname) + pname = {3: "p", 6: "beta", 7: "alpha", 9: "alpha", 10: "alpha"}[mode] + ns_stats["name"] = f"{mname} noscale" + ns_stats["formula"] = mformula + ns_stats["param"] = f"{pname}={power}" + display_name = f"{mname} noscale ({pname}={power})" + key = f"mode_{mode}_{mname}_noscale" + histograms_results[key] = ns_stats + hist_entries.append((display_name, f"m{mode:2d} ns", ns_stats)) + + # Scale Only baseline: mode 3 with p=1.0 (identity + linear scaling) + scale_hists = torch.zeros(hist_eff_bs, 256, dtype=torch.int32, device="cuda") + topk_profile_histogram( + hist_inputs["x"], + hist_inputs["dense_kv_indptr"], + scale_hists, + hist_eff_bs, + args.reserved_bos, + args.reserved_eos, + 3, # mode 3 (power) + 1.0, # p=1.0 → identity transform + None, + None, + ) + torch.cuda.synchronize() + scale_stats = compute_histogram_stats(scale_hists) + scale_stats["raw_counts"] = scale_hists.sum(dim=0).tolist() + scale_stats["name"] = "Scale Only" + scale_stats["formula"] = "Identity + linear scaling to [0,255]" + scale_stats["param"] = "p=1.0" + histograms_results["mode_scale_Scale Only"] = scale_stats + hist_entries.append(("Scale Only (p=1.0)", "scale ", scale_stats)) + + # Print all histogram entries sorted by gini (ascending = more uniform = better) + hist_entries.sort(key=lambda e: e[2]['gini']) + print(f" --- 
histogram by gini (sorted, lower=better) ---") + for rank, (display_name, mode_tag, stats) in enumerate(hist_entries, 1): print( - f" {display_name:<22s} (mode {mode}): " - f"gini={mode_stats['gini']:.3f} " - f"max/mean={mode_stats['max_mean_ratio']:.2f} " - f"nonzero_bins={mode_stats['num_nonzero_bins']}/256 " - f"eff_bins={mode_stats['effective_bins']:.1f} " - f"entropy={mode_stats['entropy']:.2f}" + f" {rank:2d}. {display_name:<32s} ({mode_tag}): " + f"gini={stats['gini']:.3f} " + f"max/mean={stats['max_mean_ratio']:.2f} " + f"nonzero_bins={stats['num_nonzero_bins']}/256 " + f"eff_bins={stats['effective_bins']:.1f} " + f"entropy={stats['entropy']:.2f}" ) + config_results["histograms"] = histograms_results all_results.append(config_results) @@ -573,6 +846,9 @@ def main(): "Only used when --histogram is set.") parser.add_argument("--real-histograms", type=str, default=None, help="Path to .npy raw_histograms from calibration (adds 'real' distribution)") + parser.add_argument("--counters", action="store_true", + help="Collect diagnostic counters (threshold_bin, num_above, num_equal, " + "remaining_k, refine_rounds, stage2_input) for each sglang kernel") args = parser.parse_args() results = run_benchmark(args) diff --git a/csrc/clean.py b/csrc/clean.py new file mode 100644 index 0000000..8d258bb --- /dev/null +++ b/csrc/clean.py @@ -0,0 +1,21 @@ +from pathlib import Path +import sys + +def clean_one_leading_space(path: str): + p = Path(path) + text = p.read_text(encoding="utf-8") + + cleaned = "".join( + line[1:] if line.startswith(" ") else line + for line in text.splitlines(keepends=True) + ) + + p.write_text(cleaned, encoding="utf-8") + print(f"Cleaned: {p}") + +if __name__ == "__main__": + if len(sys.argv) != 2: + print("Usage: python clean_indent.py ") + sys.exit(1) + + clean_one_leading_space(sys.argv[1]) \ No newline at end of file diff --git a/csrc/register.cc b/csrc/register.cc index 0067474..0a3c11c 100644 --- a/csrc/register.cc +++ b/csrc/register.cc @@ 
-17,7 +17,8 @@ PYBIND11_MODULE(vortex_torch_C, m){ py::arg("mapping_mode") = 0, py::arg("mapping_power") = 0.5, py::arg("mapping_lut") = py::none(), - py::arg("mapping_quantiles") = py::none()); + py::arg("mapping_quantiles") = py::none(), + py::arg("mapping_noscale") = false); m.def("topk_profile_histogram", &topk_profile_histogram, py::arg("x"), py::arg("dense_kv_indptr"), py::arg("histograms"), py::arg("eff_batch_size"), @@ -25,7 +26,31 @@ PYBIND11_MODULE(vortex_torch_C, m){ py::arg("mapping_mode") = 0, py::arg("mapping_power") = 0.5, py::arg("mapping_lut") = py::none(), - py::arg("mapping_quantiles") = py::none()); + py::arg("mapping_quantiles") = py::none(), + py::arg("mapping_noscale") = false); + m.def("topk_profile_stage1", &topk_profile_stage1, + py::arg("x"), py::arg("dense_kv_indptr"), py::arg("sparse_kv_indptr"), + py::arg("dense_kv_indices"), py::arg("sparse_kv_indices"), + py::arg("eff_batch_size"), py::arg("topk_val"), + py::arg("reserved_bos"), py::arg("reserved_eos"), + py::arg("max_num_pages"), + py::arg("mapping_mode") = 0, + py::arg("mapping_power") = 0.5, + py::arg("mapping_lut") = py::none(), + py::arg("mapping_quantiles") = py::none(), + py::arg("mapping_noscale") = false); + m.def("topk_profile_counters", &topk_profile_counters, + py::arg("x"), py::arg("dense_kv_indptr"), py::arg("sparse_kv_indptr"), + py::arg("dense_kv_indices"), py::arg("sparse_kv_indices"), + py::arg("counters"), + py::arg("eff_batch_size"), py::arg("topk_val"), + py::arg("reserved_bos"), py::arg("reserved_eos"), + py::arg("max_num_pages"), + py::arg("mapping_mode") = 0, + py::arg("mapping_power") = 0.5, + py::arg("mapping_lut") = py::none(), + py::arg("mapping_quantiles") = py::none(), + py::arg("mapping_noscale") = false); m.def("sglang_plan_decode_fa3", &sglang_plan_decode_fa3); m.def("sglang_plan_prefill_fa3", &sglang_plan_prefill_fa3); m.def("Chunkwise_HN2NH_Transpose_FA3", &Chunkwise_HN2NH_Transpose_FA3); diff --git a/csrc/register.h b/csrc/register.h index 
d4f2d8b..1a8b820 100644 --- a/csrc/register.h +++ b/csrc/register.h @@ -99,7 +99,8 @@ const int64_t max_seq_lengths, const int64_t mapping_mode = 0, const double mapping_power = 0.5, std::optional mapping_lut = std::nullopt, -std::optional mapping_quantiles = std::nullopt +std::optional mapping_quantiles = std::nullopt, +const bool mapping_noscale = false ); void topk_profile_histogram( @@ -112,7 +113,45 @@ const int64_t reserved_eos, const int64_t mapping_mode = 0, const double mapping_power = 0.5, std::optional mapping_lut = std::nullopt, -std::optional mapping_quantiles = std::nullopt +std::optional mapping_quantiles = std::nullopt, +const bool mapping_noscale = false +); + +void topk_profile_stage1( +const at::Tensor& x, +const at::Tensor& dense_kv_indptr, +const at::Tensor& sparse_kv_indptr, +const at::Tensor& dense_kv_indices, +at::Tensor& sparse_kv_indices, +const int64_t eff_batch_size, +const int64_t topk_val, +const int64_t reserved_bos, +const int64_t reserved_eos, +const int64_t max_num_pages, +const int64_t mapping_mode = 0, +const double mapping_power = 0.5, +std::optional mapping_lut = std::nullopt, +std::optional mapping_quantiles = std::nullopt, +const bool mapping_noscale = false +); + +void topk_profile_counters( +const at::Tensor& x, +const at::Tensor& dense_kv_indptr, +const at::Tensor& sparse_kv_indptr, +const at::Tensor& dense_kv_indices, +at::Tensor& sparse_kv_indices, +at::Tensor& counters, +const int64_t eff_batch_size, +const int64_t topk_val, +const int64_t reserved_bos, +const int64_t reserved_eos, +const int64_t max_num_pages, +const int64_t mapping_mode = 0, +const double mapping_power = 0.5, +std::optional mapping_lut = std::nullopt, +std::optional mapping_quantiles = std::nullopt, +const bool mapping_noscale = false ); void sglang_plan_decode_fa3( diff --git a/csrc/topk_mapping.cuh b/csrc/topk_mapping.cuh index e3fe3a7..97bc141 100644 --- a/csrc/topk_mapping.cuh +++ b/csrc/topk_mapping.cuh @@ -26,6 +26,9 @@ enum TopKMappingMode { 
MAPPING_ASINH = 6, // asinh(beta * x), beta via power_exp MAPPING_LOG1P = 7, // sign(x) * log1p(alpha * |x|), alpha via power_exp MAPPING_TRUNC8 = 8, // BF16 upper-8-bit bucketing + MAPPING_ERF = 9, // erf(alpha * x) + MAPPING_TANH = 10, // tanh(alpha * x) + MAPPING_SUBTRACT = 11, // subtract pivot, then fp16 bucketing }; struct TopKMappingParams { @@ -33,6 +36,8 @@ struct TopKMappingParams { float power_exp; // For MAPPING_POWER (default 0.5) const uint8_t* __restrict__ lut; // [256] byte LUT, or nullptr const float* __restrict__ quantiles; // [256] float quantile breakpoints, or nullptr + bool noscale; // Skip auto-range linear scaling, use fp16 bucketing on f(x) + int sample_stride; // Pre-pass sampling stride (1=full, 8=1/8, 0=skip) }; // NOTE: convert_to_uint8() must be defined before including this header. @@ -56,6 +61,14 @@ __device__ __forceinline__ float transform_log1p(float x, float alpha) { return copysignf(log1pf(alpha * fabsf(x)), x); } +__device__ __forceinline__ float transform_erf(float x, float alpha) { + return erff(alpha * x); +} + +__device__ __forceinline__ float transform_tanh(float x, float alpha) { + return tanhf(alpha * x); +} + // ---- Transform dispatcher (returns float, no bucketing) ---- __device__ __forceinline__ float apply_transform(float x, const TopKMappingParams& params) { @@ -64,6 +77,8 @@ __device__ __forceinline__ float apply_transform(float x, const TopKMappingParam case MAPPING_LOG: return transform_log(x); case MAPPING_ASINH: return transform_asinh(x, params.power_exp); case MAPPING_LOG1P: return transform_log1p(x, params.power_exp); + case MAPPING_ERF: return transform_erf(x, params.power_exp); + case MAPPING_TANH: return transform_tanh(x, params.power_exp); default: return x; } } @@ -75,14 +90,16 @@ __device__ __forceinline__ uint8_t linear_map_to_uint8(float val, float range_mi return static_cast(min(max(bin, 0), 255)); } -// ---- BF16 upper-8-bit bucketing (mode 8) ---- +// ---- BF16-aware bucketing (mode 8) ---- +// 
BF16 has 8 exponent + 7 mantissa bits. Taking the upper 8 bits of the +// sign-flipped bf16 bit-pattern yields only ~20 distinct bins for typical +// data (the byte is almost entirely exponent). Instead, convert through +// fp16 (5 exp + 10 mantissa) which puts 5 exp + 2 mantissa bits in the +// upper byte, giving ~135+ distinct bins — equivalent to mode 0 but +// explicitly available as a named mode for documentation/benchmarking. __device__ __forceinline__ uint8_t convert_to_uint8_bf16(float x) { - __nv_bfloat16 bf = __float2bfloat16_rn(x); - uint16_t bits = __bfloat16_as_ushort(bf); - uint16_t key = (bits & 0x8000) ? static_cast(~bits) - : static_cast(bits | 0x8000); - return static_cast(key >> 8); + return convert_to_uint8(x); // fp16 sign-flip bucketing } // ---- Non-transform mapping functions (unchanged) ---- @@ -130,12 +147,17 @@ __device__ __forceinline__ uint8_t mapped_convert_to_uint8( case MAPPING_POWER: case MAPPING_LOG: case MAPPING_ASINH: - case MAPPING_LOG1P: { + case MAPPING_LOG1P: + case MAPPING_ERF: + case MAPPING_TANH: { float val = apply_transform(x, params); + if (params.noscale) return convert_to_uint8(val); return linear_map_to_uint8(val, range_min, inv_range); } case MAPPING_TRUNC8: return convert_to_uint8_bf16(x); + case MAPPING_SUBTRACT: + return convert_to_uint8(x - range_min); // range_min repurposed as pivot default: // MAPPING_NONE return convert_to_uint8(x); } @@ -144,5 +166,11 @@ __device__ __forceinline__ uint8_t mapped_convert_to_uint8( // Helper: check if a mapping mode needs the auto-range pre-pass __device__ __forceinline__ bool needs_auto_range(int mode) { return (mode == MAPPING_POWER || mode == MAPPING_LOG || - mode == MAPPING_ASINH || mode == MAPPING_LOG1P); + mode == MAPPING_ASINH || mode == MAPPING_LOG1P || + mode == MAPPING_ERF || mode == MAPPING_TANH); +} + +// Helper: check if a mapping mode needs the pivot pre-pass +__device__ __forceinline__ bool needs_pivot(int mode) { + return (mode == MAPPING_SUBTRACT); } diff 
--git a/csrc/topk_sglang.cu b/csrc/topk_sglang.cu index 5959270..9213016 100644 --- a/csrc/topk_sglang.cu +++ b/csrc/topk_sglang.cu @@ -452,18 +452,30 @@ __device__ __forceinline__ float vortex_to_float<__nv_bfloat16>(__nv_bfloat16 x) constexpr int VORTEX_MAX_TOPK = 2048; +// Per-segment diagnostic counters written by WriteCounters mode +constexpr int COUNTER_THRESHOLD_BIN = 0; // Stage 1 coarse threshold bin id +constexpr int COUNTER_NUM_ABOVE = 1; // elements routed above threshold in Stage 1 +constexpr int COUNTER_NUM_EQUAL = 2; // elements in threshold bin (Stage 2 input) +constexpr int COUNTER_REMAINING_K = 3; // topk slots remaining after Stage 1 routing +constexpr int COUNTER_REFINE_ROUNDS = 4; // Stage 2 rounds used (0 = resolved in Stage 1) +constexpr int COUNTER_STAGE2_INPUT = 5; // candidates entering first Stage 2 refine round +constexpr int NUM_TOPK_COUNTERS = 6; + // Templated version of fast_topk_cuda_tl: // - ScoreT: float or __nv_bfloat16 +// - StopAfterStage1: return after Stage 1 route/filter (for profiling) +// - WriteCounters: write diagnostic counters to global memory // - target_k: runtime parameter (replaces compile-time TopK) // - mapping: configurable value-remapping for Stage 1 bin assignment -template +template __device__ void fast_topk_vortex( const ScoreT* __restrict__ input, int* __restrict__ index, int row_start, int length, int target_k, - const TopKMappingParams& mapping) + const TopKMappingParams& mapping, + int* counters = nullptr) { int topk = target_k; constexpr auto BLOCK_SIZE = 1024; @@ -497,10 +509,14 @@ __device__ void fast_topk_vortex( __syncthreads(); } - // Pre-pass: compute per-block min/max of transformed values for linear bucketing - if (needs_auto_range(mapping.mode)) { + // Pre-pass: compute per-block min/max of transformed values for linear bucketing. 
+ // sample_stride > 1 reduces pre-pass cost by scanning every Nth element; + // the approximated range may miss extreme outliers but Stage 2 uses raw + // float bits for exact ordering, so correctness is preserved. + if (needs_auto_range(mapping.mode) && !mapping.noscale) { + const int stride = (mapping.sample_stride > 1) ? mapping.sample_stride : 1; float local_min = __FLT_MAX__, local_max = -__FLT_MAX__; - for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + for (int idx = tx * stride; idx < length; idx += BLOCK_SIZE * stride) { float val = apply_transform(vortex_to_float(input[idx + row_start]), mapping); local_min = fminf(local_min, val); local_max = fmaxf(local_max, val); @@ -528,12 +544,46 @@ __device__ void fast_topk_vortex( } } __syncthreads(); + } else if (needs_pivot(mapping.mode)) { + // Pivot pre-pass: compute mean of all elements, store in s_range_min. + // MAPPING_SUBTRACT uses convert_to_uint8(x - range_min), so centering + // around the mean helps distribute values more evenly across bins. 
+ float local_sum = 0.0f; + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + local_sum += vortex_to_float(input[idx + row_start]); + } + // Warp-level reduction + for (int offset = 16; offset > 0; offset >>= 1) { + local_sum += __shfl_xor_sync(0xFFFFFFFF, local_sum, offset); + } + __shared__ float s_warp_sums[32]; + int warp_id = tx >> 5, lane_id = tx & 31; + if (lane_id == 0) s_warp_sums[warp_id] = local_sum; + __syncthreads(); + if (tx < (BLOCK_SIZE >> 5)) { + local_sum = s_warp_sums[tx]; + for (int offset = 16; offset > 0; offset >>= 1) { + local_sum += __shfl_xor_sync(0xFFFFFFFF, local_sum, offset); + } + if (tx == 0) { + s_range_min = local_sum / float(length); // mean as pivot + s_range_inv_range = 0.0f; + } + } + __syncthreads(); } else { if (tx == 0) { s_range_min = 0.0f; s_range_inv_range = 0.0f; } __syncthreads(); } // Stage 1: 8-bit coarse histogram (with optional mapping) + // Bin cache: store computed bins in vh_input_idx[1] (reinterpreted as uint8_t*) + // to avoid recomputing mapped_convert_to_uint8 in the route/filter pass. + // vh_input_idx[1] is unused until Stage 2 double-buffering starts after route. 
+ constexpr int BIN_CACHE_CAPACITY = SMEM_INPUT_SIZE * static_cast(sizeof(int)); // uint8 entries + uint8_t* bin_cache = reinterpret_cast(vh_input_idx[1]); + const bool use_bin_cache = (length <= BIN_CACHE_CAPACITY); + if (tx < RADIX + 1) vh_histogram[tx] = 0; __syncthreads(); @@ -543,6 +593,9 @@ __device__ void fast_topk_vortex( mapping, s_mapping_lut, s_mapping_quantiles, s_range_min, s_range_inv_range); ::atomicAdd(&vh_histogram[bin], 1); + if (use_bin_cache) { + bin_cache[idx] = bin; + } } __syncthreads(); @@ -574,19 +627,35 @@ __device__ void fast_topk_vortex( const auto threshold_bin = vh_threshold_bin_id; topk -= vh_histogram[threshold_bin + 1]; + if (WriteCounters && tx == 0 && counters) { + counters[COUNTER_THRESHOLD_BIN] = threshold_bin; + counters[COUNTER_REMAINING_K] = topk; + } + if (topk == 0) { for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - const auto bin = static_cast( - mapped_convert_to_uint8( - vortex_to_float(input[idx + row_start]), - mapping, s_mapping_lut, s_mapping_quantiles, - s_range_min, s_range_inv_range)); + int bin; + if (use_bin_cache) { + bin = static_cast(bin_cache[idx]); + } else { + bin = static_cast( + mapped_convert_to_uint8( + vortex_to_float(input[idx + row_start]), + mapping, s_mapping_lut, s_mapping_quantiles, + s_range_min, s_range_inv_range)); + } if (bin > threshold_bin) { const auto pos = ::atomicAdd(&vh_counter, 1); index[pos] = idx; } } __syncthreads(); + if (WriteCounters && tx == 0 && counters) { + counters[COUNTER_NUM_ABOVE] = vh_counter; + counters[COUNTER_NUM_EQUAL] = 0; + counters[COUNTER_REFINE_ROUNDS] = 0; + counters[COUNTER_STAGE2_INPUT] = 0; + } return; } else { __syncthreads(); @@ -595,10 +664,15 @@ __device__ void fast_topk_vortex( for (int idx = tx; idx < length; idx += BLOCK_SIZE) { const auto raw_input = vortex_to_float(input[idx + row_start]); - const auto bin = static_cast( - mapped_convert_to_uint8(raw_input, mapping, - s_mapping_lut, s_mapping_quantiles, - s_range_min, s_range_inv_range)); + 
int bin; + if (use_bin_cache) { + bin = static_cast(bin_cache[idx]); + } else { + bin = static_cast( + mapped_convert_to_uint8(raw_input, mapping, + s_mapping_lut, s_mapping_quantiles, + s_range_min, s_range_inv_range)); + } if (bin > threshold_bin) { const auto pos = ::atomicAdd(&vh_counter, 1); index[pos] = idx; @@ -613,9 +687,19 @@ __device__ void fast_topk_vortex( } } __syncthreads(); + if (WriteCounters && tx == 0 && counters) { + counters[COUNTER_NUM_ABOVE] = vh_counter; + counters[COUNTER_NUM_EQUAL] = vh_num_input[0]; + counters[COUNTER_STAGE2_INPUT] = vh_num_input[0]; + } + if (StopAfterStage1) return; } // Stage 2: refine with 8-bit radix passes (unchanged — uses raw float bits) + if constexpr (WriteCounters) { + // Default: all 4 rounds used; overwritten at break if resolved early + if (tx == 0 && counters) counters[COUNTER_REFINE_ROUNDS] = 4; + } #pragma unroll 4 for (int round = 0; round < 4; ++round) { __shared__ int vh_last_remain; @@ -649,6 +733,11 @@ __device__ void fast_topk_vortex( } } __syncthreads(); + if constexpr (WriteCounters) { + if (tx == 0 && counters) { + counters[COUNTER_REFINE_ROUNDS] = round + 1; + } + } break; } else { __syncthreads(); @@ -722,6 +811,92 @@ void TopKOutput_Kernel( } } +// ====================================================================== +// Profiling Stage1 kernel: runs pre-pass + hist + cumsum + route/filter, +// stops before Stage 2 refinement (for sub-phase timing) +// ====================================================================== +template +__global__ __launch_bounds__(kThreadsPerBlock) +void TopKStage1_Kernel( + const ScoreT* __restrict__ score, + const int* __restrict__ dense_kv_indptr, + const int* __restrict__ sparse_kv_indptr, + const int* __restrict__ dense_kv_indices, + int* __restrict__ sparse_kv_indices, + const int topk_val, + const int page_reserved_bos, + const int page_reserved_eos, + const TopKMappingParams mapping) +{ + const int bx = blockIdx.x; + + const int start = 
dense_kv_indptr[bx] + page_reserved_bos; + const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; + const int nblk = end - start; + if (nblk <= topk_val) return; + + const ScoreT* __restrict__ score_blk = score + start; + const int* __restrict__ idx_blk = dense_kv_indices + start; + int* __restrict__ out_blk = sparse_kv_indices + + sparse_kv_indptr[bx] + + page_reserved_bos; + + __shared__ int s_indices[VORTEX_MAX_TOPK]; + fast_topk_vortex( + score_blk, s_indices, 0, nblk, topk_val, mapping); + __syncthreads(); + + // Remap position indices -> page indices via dense_kv_indices + const int tx = threadIdx.x; + for (int i = tx; i < topk_val; i += kThreadsPerBlock) { + out_blk[i] = idx_blk[s_indices[i]]; + } +} + +// ====================================================================== +// Profiling counters kernel: runs full pipeline + writes diagnostic +// counters to a separate global-memory tensor +// ====================================================================== +template +__global__ __launch_bounds__(kThreadsPerBlock) +void TopKCounters_Kernel( + const ScoreT* __restrict__ score, + const int* __restrict__ dense_kv_indptr, + const int* __restrict__ sparse_kv_indptr, + const int* __restrict__ dense_kv_indices, + int* __restrict__ sparse_kv_indices, + int* __restrict__ counters, // [eff_batch_size, NUM_TOPK_COUNTERS] + const int topk_val, + const int page_reserved_bos, + const int page_reserved_eos, + const TopKMappingParams mapping) +{ + const int bx = blockIdx.x; + + const int start = dense_kv_indptr[bx] + page_reserved_bos; + const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; + const int nblk = end - start; + if (nblk <= topk_val) return; + + const ScoreT* __restrict__ score_blk = score + start; + const int* __restrict__ idx_blk = dense_kv_indices + start; + int* __restrict__ out_blk = sparse_kv_indices + + sparse_kv_indptr[bx] + + page_reserved_bos; + + __shared__ int s_indices[VORTEX_MAX_TOPK]; + fast_topk_vortex( + score_blk, 
s_indices, 0, nblk, topk_val, mapping, + counters + bx * NUM_TOPK_COUNTERS); + __syncthreads(); + + // Remap position indices -> page indices via dense_kv_indices + const int tx = threadIdx.x; + for (int i = tx; i < topk_val; i += kThreadsPerBlock) { + out_blk[i] = idx_blk[s_indices[i]]; + } +} + // ====================================================================== // Profiling histogram kernel: runs only Stage 1 and returns per-segment // 256-bin histograms for distribution analysis @@ -762,10 +937,11 @@ void TopKHistogram_Kernel( __syncthreads(); } - // Pre-pass: compute per-block min/max for transform modes - if (needs_auto_range(mapping.mode)) { + // Pre-pass: compute per-block min/max for transform modes (supports sampled stride) + if (needs_auto_range(mapping.mode) && !mapping.noscale) { + const int stride = (mapping.sample_stride > 1) ? mapping.sample_stride : 1; float local_min = __FLT_MAX__, local_max = -__FLT_MAX__; - for (int idx = tx; idx < nblk; idx += BLOCK_SIZE) { + for (int idx = tx * stride; idx < nblk; idx += BLOCK_SIZE * stride) { float val = apply_transform(vortex_to_float(score_blk[idx]), mapping); local_min = fminf(local_min, val); local_max = fmaxf(local_max, val); @@ -791,6 +967,30 @@ void TopKHistogram_Kernel( } } __syncthreads(); + } else if (needs_pivot(mapping.mode)) { + // Pivot pre-pass: compute mean for MAPPING_SUBTRACT + float local_sum = 0.0f; + for (int idx = tx; idx < nblk; idx += BLOCK_SIZE) { + local_sum += vortex_to_float(score_blk[idx]); + } + for (int offset = 16; offset > 0; offset >>= 1) { + local_sum += __shfl_xor_sync(0xFFFFFFFF, local_sum, offset); + } + __shared__ float s_warp_sums_h[32]; + int warp_id = tx >> 5, lane_id = tx & 31; + if (lane_id == 0) s_warp_sums_h[warp_id] = local_sum; + __syncthreads(); + if (tx < (BLOCK_SIZE >> 5)) { + local_sum = s_warp_sums_h[tx]; + for (int offset = 16; offset > 0; offset >>= 1) { + local_sum += __shfl_xor_sync(0xFFFFFFFF, local_sum, offset); + } + if (tx == 0) { + s_range_min 
= local_sum / float(nblk); + s_range_inv_range = 0.0f; + } + } + __syncthreads(); } else { if (tx == 0) { s_range_min = 0.0f; s_range_inv_range = 0.0f; } __syncthreads(); @@ -949,7 +1149,8 @@ void topk_output_sglang( const int64_t mapping_mode, const double mapping_power, std::optional mapping_lut, - std::optional mapping_quantiles) + std::optional mapping_quantiles, + const bool mapping_noscale) { TORCH_CHECK(topk_val <= VORTEX_MAX_TOPK, "topk_output: topk_val (", topk_val, @@ -961,6 +1162,8 @@ void topk_output_sglang( mapping.power_exp = static_cast(mapping_power); mapping.lut = nullptr; mapping.quantiles = nullptr; + mapping.noscale = mapping_noscale; + mapping.sample_stride = 1; if (mapping_lut.has_value()) { const auto& lut = mapping_lut.value(); @@ -1029,7 +1232,8 @@ void topk_profile_histogram( const int64_t mapping_mode, const double mapping_power, std::optional mapping_lut, - std::optional mapping_quantiles) + std::optional mapping_quantiles, + const bool mapping_noscale) { CHECK_CUDA(x); CHECK_CUDA(dense_kv_indptr); @@ -1046,6 +1250,8 @@ void topk_profile_histogram( mapping.power_exp = static_cast(mapping_power); mapping.lut = nullptr; mapping.quantiles = nullptr; + mapping.noscale = mapping_noscale; + mapping.sample_stride = 1; if (mapping_lut.has_value()) { const auto& lut = mapping_lut.value(); @@ -1093,3 +1299,175 @@ void topk_profile_histogram( "topk_profile_histogram kernel failed: ", ::cudaGetErrorString(result)); } +// Helper: build TopKMappingParams from host arguments +static TopKMappingParams build_mapping_params( + int64_t mapping_mode, double mapping_power, + std::optional& mapping_lut, + std::optional& mapping_quantiles, + bool mapping_noscale = false, + int sample_stride = 1) +{ + TopKMappingParams mapping{}; + mapping.mode = static_cast(mapping_mode); + mapping.power_exp = static_cast(mapping_power); + mapping.lut = nullptr; + mapping.quantiles = nullptr; + mapping.noscale = mapping_noscale; + mapping.sample_stride = sample_stride; + + if 
(mapping_lut.has_value()) { + const auto& lut = mapping_lut.value(); + CHECK_CUDA(lut); + TORCH_CHECK(lut.dim() == 1 && lut.size(0) == 256 && lut.scalar_type() == at::ScalarType::Byte, + "mapping_lut must be a 1D uint8 tensor of size 256"); + mapping.lut = lut.data_ptr(); + } + if (mapping_quantiles.has_value()) { + const auto& q = mapping_quantiles.value(); + CHECK_CUDA(q); + TORCH_CHECK(q.dim() == 1 && q.size(0) == 256 && q.scalar_type() == at::ScalarType::Float, + "mapping_quantiles must be a 1D float32 tensor of size 256"); + mapping.quantiles = q.data_ptr(); + } + return mapping; +} + +// ====================================================================== +// Profiling: Stage 1 only (pre-pass + hist + cumsum + route/filter) +// ====================================================================== +void topk_profile_stage1( + const at::Tensor& x, + const at::Tensor& dense_kv_indptr, + const at::Tensor& sparse_kv_indptr, + const at::Tensor& dense_kv_indices, + at::Tensor& sparse_kv_indices, + const int64_t eff_batch_size, + const int64_t topk_val, + const int64_t reserved_bos, + const int64_t reserved_eos, + const int64_t max_num_pages, + const int64_t mapping_mode, + const double mapping_power, + std::optional mapping_lut, + std::optional mapping_quantiles, + const bool mapping_noscale) +{ + TORCH_CHECK(topk_val <= VORTEX_MAX_TOPK, + "topk_profile_stage1: topk_val (", topk_val, + ") exceeds VORTEX_MAX_TOPK (", VORTEX_MAX_TOPK, ")"); + + auto mapping = build_mapping_params(mapping_mode, mapping_power, mapping_lut, mapping_quantiles, mapping_noscale); + + dim3 nblks(eff_batch_size); + dim3 nthreads(kThreadsPerBlock); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + + if (x.scalar_type() == at::ScalarType::BFloat16) { + setup_kernel_smem_once, kSmem>(); + TopKStage1_Kernel<__nv_bfloat16><<>>( + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + 
sparse_kv_indices.data_ptr(), + topk_val, + reserved_bos, + reserved_eos, + mapping); + } else if (x.scalar_type() == at::ScalarType::Float) { + setup_kernel_smem_once, kSmem>(); + TopKStage1_Kernel<<>>( + x.data_ptr(), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + topk_val, + reserved_bos, + reserved_eos, + mapping); + } else { + TORCH_CHECK(false, + "topk_profile_stage1: unsupported dtype ", + x.scalar_type()); + } + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, + "topk_profile_stage1 kernel failed: ", ::cudaGetErrorString(result)); +} + +// ====================================================================== +// Profiling: full pipeline + diagnostic counters +// ====================================================================== +void topk_profile_counters( + const at::Tensor& x, + const at::Tensor& dense_kv_indptr, + const at::Tensor& sparse_kv_indptr, + const at::Tensor& dense_kv_indices, + at::Tensor& sparse_kv_indices, + at::Tensor& counters, + const int64_t eff_batch_size, + const int64_t topk_val, + const int64_t reserved_bos, + const int64_t reserved_eos, + const int64_t max_num_pages, + const int64_t mapping_mode, + const double mapping_power, + std::optional mapping_lut, + std::optional mapping_quantiles, + const bool mapping_noscale) +{ + TORCH_CHECK(topk_val <= VORTEX_MAX_TOPK, + "topk_profile_counters: topk_val (", topk_val, + ") exceeds VORTEX_MAX_TOPK (", VORTEX_MAX_TOPK, ")"); + CHECK_CUDA(counters); + TORCH_CHECK(counters.dim() == 2 && counters.size(0) == eff_batch_size + && counters.size(1) == NUM_TOPK_COUNTERS, + "counters must be [eff_batch_size, ", NUM_TOPK_COUNTERS, "]"); + TORCH_CHECK(counters.scalar_type() == at::ScalarType::Int, + "counters must be int32"); + + auto mapping = build_mapping_params(mapping_mode, mapping_power, mapping_lut, mapping_quantiles, mapping_noscale); + + dim3 nblks(eff_batch_size); + dim3 
nthreads(kThreadsPerBlock); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + + if (x.scalar_type() == at::ScalarType::BFloat16) { + setup_kernel_smem_once, kSmem>(); + TopKCounters_Kernel<__nv_bfloat16><<>>( + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + counters.data_ptr(), + topk_val, + reserved_bos, + reserved_eos, + mapping); + } else if (x.scalar_type() == at::ScalarType::Float) { + setup_kernel_smem_once, kSmem>(); + TopKCounters_Kernel<<>>( + x.data_ptr(), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + counters.data_ptr(), + topk_val, + reserved_bos, + reserved_eos, + mapping); + } else { + TORCH_CHECK(false, + "topk_profile_counters: unsupported dtype ", + x.scalar_type()); + } + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, + "topk_profile_counters kernel failed: ", ::cudaGetErrorString(result)); +} + diff --git a/csrc/topk_slgang_ori.cu b/csrc/topk_slgang_ori.cu new file mode 100644 index 0000000..04a2b73 --- /dev/null +++ b/csrc/topk_slgang_ori.cu @@ -0,0 +1,546 @@ +/** + * @NOTE: This file is adapted from + * https://github.com/tile-ai/tilelang/blob/main/examples/deepseek_v32/topk_selector.py + * We: + * 1. adapt from tilelang to pure cuda + * 2. optimize the performance a little + * 3. fix the potential illegal memory access + */ +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace { + +constexpr int TopK = 2048; +constexpr int kThreadsPerBlock = 1024; + +#ifdef USE_ROCM +// On ROCm, the per-workgroup LDS budget depends on the target arch, so we inject a +// per-arch value from `setup_rocm.py` via `-DSGL_TOPK_DYNAMIC_SMEM_BYTES=...`. 
+#ifdef SGL_TOPK_DYNAMIC_SMEM_BYTES +constexpr size_t kSmem = static_cast(SGL_TOPK_DYNAMIC_SMEM_BYTES); +#else +constexpr size_t kSmem = 48 * 1024; // bytes +#endif +#else +// Reduced from 128KB to 32KB to improve occupancy. +// Each radix pass needs at most ~TopK candidates in the threshold bin, +// so 4K entries per round (2 rounds = 8K entries = 32KB) is sufficient. +constexpr size_t kSmem = 8 * 1024 * sizeof(uint32_t); // 32KB (bytes) +#endif + +struct FastTopKParams { + const float* __restrict__ input; // [B, input_stride] + const int32_t* __restrict__ row_starts; // [B] + int32_t* __restrict__ indices; // [B, TopK] + int32_t* __restrict__ lengths; // [B] + int64_t input_stride; +}; + +// when length <= TopK, we can directly write the indices +__device__ void naive_topk_cuda(const float* __restrict__ score, int32_t* __restrict__ indice, int32_t length) { + const auto tid = threadIdx.x; + for (int i = tid; i < TopK; i += kThreadsPerBlock) { + indice[i] = (i < length) ? i : -1; + } +} + +// keep the first `length` entries, set others to -1 +__device__ void naive_topk_transform( + const float* __restrict__ score, + int32_t length, + int32_t* __restrict__ dst_page_table, + const int32_t* __restrict__ src_page_table) { + const auto tid = threadIdx.x; + for (auto i = tid; i < TopK; i += kThreadsPerBlock) { + dst_page_table[i] = (i < length) ? src_page_table[i] : -1; + } +} + +// keep the first `length` entries, set others to -1 +__device__ void naive_topk_transform_ragged( + const float* __restrict__ score, int32_t length, int32_t* __restrict__ topk_indices_ragged, int32_t offset) { + const auto tid = threadIdx.x; + for (auto i = tid; i < TopK; i += kThreadsPerBlock) { + topk_indices_ragged[i] = (i < length) ? static_cast(i) + offset : -1; + } +} + +__device__ __forceinline__ auto convert_to_uint8(float x) -> uint8_t { + __half h = __float2half_rn(x); + uint16_t bits = __half_as_ushort(h); + uint16_t key = (bits & 0x8000) ? 
static_cast(~bits) : static_cast(bits | 0x8000); + return static_cast(key >> 8); +} + +__device__ __forceinline__ auto convert_to_uint32(float x) -> uint32_t { + uint32_t bits = __float_as_uint(x); + return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u); +} + +__device__ void fast_topk_cuda_tl(const float* __restrict__ input, int* __restrict__ index, int row_start, int length) { + // An optimized topk kernel copied from tilelang kernel + // We assume length > TopK here, or it will crash + int topk = TopK; + constexpr auto BLOCK_SIZE = 1024; + constexpr auto RADIX = 256; + constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); + + alignas(128) __shared__ int s_histogram_buf[2][RADIX + 128]; + alignas(128) __shared__ int s_counter; + alignas(128) __shared__ int s_threshold_bin_id; + alignas(128) __shared__ int s_num_input[2]; + + auto& s_histogram = s_histogram_buf[0]; + // allocate for two rounds + extern __shared__ int s_input_idx[][SMEM_INPUT_SIZE]; + + const int tx = threadIdx.x; + + // stage 1: 8bit coarse histogram + if (tx < RADIX + 1) s_histogram[tx] = 0; + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = convert_to_uint8(input[idx + row_start]); + ::atomicAdd(&s_histogram[bin], 1); + } + __syncthreads(); + + const auto run_cumsum = [&] { +#pragma unroll 8 + for (int i = 0; i < 8; ++i) { + static_assert(1 << 8 == RADIX); + if (C10_LIKELY(tx < RADIX)) { + const auto j = 1 << i; + const auto k = i & 1; + auto value = s_histogram_buf[k][tx]; + if (tx < RADIX - j) { + value += s_histogram_buf[k][tx + j]; + } + s_histogram_buf[k ^ 1][tx] = value; + } + __syncthreads(); + } + }; + + run_cumsum(); + if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { + s_threshold_bin_id = tx; + s_num_input[0] = 0; + s_counter = 0; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + topk -= s_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int idx = tx; idx < length; idx 
+= BLOCK_SIZE) { + const auto bin = static_cast(convert_to_uint8(input[idx + row_start])); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + return; + } else { + __syncthreads(); + if (tx < RADIX + 1) { + s_histogram[tx] = 0; + } + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto raw_input = input[idx + row_start]; + const auto bin = static_cast(convert_to_uint8(raw_input)); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + const auto pos = ::atomicAdd(&s_num_input[0], 1); + /// NOTE: (dark) fuse the histogram computation here + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + s_input_idx[0][pos] = idx; + const auto bin = convert_to_uint32(raw_input); + const auto sub_bin = (bin >> 24) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } + } + __syncthreads(); + } + + // stage 2: refine with 8bit radix passes +#pragma unroll 4 + for (int round = 0; round < 4; ++round) { + __shared__ int s_last_remain; + const auto r_idx = round % 2; + + // clip here to prevent overflow + const auto _raw_num_input = s_num_input[r_idx]; + const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) ? 
_raw_num_input : int(SMEM_INPUT_SIZE); + + run_cumsum(); + if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { + s_threshold_bin_id = tx; + s_num_input[r_idx ^ 1] = 0; + s_last_remain = topk - s_histogram[tx + 1]; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + topk -= s_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(input[idx + row_start]) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + break; + } else { + __syncthreads(); + if (tx < RADIX + 1) { + s_histogram[tx] = 0; + } + __syncthreads(); + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto raw_input = input[idx + row_start]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + if (round == 3) { + const auto pos = ::atomicAdd(&s_last_remain, -1); + if (pos > 0) { + index[TopK - pos] = idx; + } + } else { + const auto pos = ::atomicAdd(&s_num_input[r_idx ^ 1], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + /// NOTE: (dark) fuse the histogram computation here + s_input_idx[r_idx ^ 1][pos] = idx; + const auto bin = convert_to_uint32(raw_input); + const auto sub_bin = (bin >> (offset - 8)) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } + } + } + __syncthreads(); + } + } +} + +__global__ __launch_bounds__(kThreadsPerBlock) // topk + void topk_kernel(const FastTopKParams params) { + const auto& [input, row_starts, indices, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto row_start = row_starts == nullptr ? 
0 : row_starts[bid]; + const auto length = lengths[bid]; + const auto indice = indices + bid * TopK; + const auto score = input + bid * input_stride; + if (length <= TopK) { + return naive_topk_cuda(score, indice, length); + } else { + return fast_topk_cuda_tl(score, indice, row_start, length); + } +} + +__global__ __launch_bounds__(kThreadsPerBlock) // decode + void topk_transform_decode_kernel( + const FastTopKParams params, + int32_t* __restrict__ dst_page_table, + const int32_t* __restrict__ src_page_table, + const int64_t src_stride) { + const auto& [input, _1, _2, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto tid = threadIdx.x; + const auto row_start = 0; + const auto length = lengths[bid]; + const auto src_page_entry = src_page_table + bid * src_stride; + const auto dst_page_entry = dst_page_table + bid * TopK; + const auto score = input + bid * input_stride; + if (length <= TopK) { + return naive_topk_transform(score, length, dst_page_entry, src_page_entry); + } else { + __shared__ int s_indices[TopK]; + fast_topk_cuda_tl(score, s_indices, row_start, length); + // copy src[s_indices] to dst, we manually unroll here + static_assert(TopK % kThreadsPerBlock == 0); + static_assert(TopK / kThreadsPerBlock == 2); + const auto idx_0 = tid; + const auto pos_0 = s_indices[idx_0]; + dst_page_entry[idx_0] = src_page_entry[pos_0]; + const auto idx_1 = tid + kThreadsPerBlock; + const auto pos_1 = s_indices[idx_1]; + dst_page_entry[idx_1] = src_page_entry[pos_1]; + } +} + +__global__ __launch_bounds__(kThreadsPerBlock) // prefill + void topk_transform_prefill_kernel( + const FastTopKParams params, + int32_t* __restrict__ dst_page_table, + const int32_t* __restrict__ src_page_table, + const int64_t src_stride, + const int32_t* __restrict__ cu_seqlens_q, + const int64_t prefill_bs) { + const auto& [input, row_starts, _, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto tid = threadIdx.x; + 
const auto length = lengths[bid]; + const auto row_start = row_starts == nullptr ? 0 : row_starts[bid]; + const auto dst_page_entry = dst_page_table + bid * TopK; + const auto score = input + bid * input_stride; + + /// NOTE: prefill bs is usually small, we can just use a simple loop here + /// We ensure that last cu_seqlens is equal to number of blocks launched + __shared__ const int32_t* s_src_page_entry; + if (C10_LIKELY(prefill_bs <= kThreadsPerBlock)) { + if (tid < prefill_bs) { + if (bid >= cu_seqlens_q[tid] && bid < cu_seqlens_q[tid + 1]) { + s_src_page_entry = src_page_table + tid * src_stride; + } + } + } else { + for (int64_t i = tid; i < prefill_bs; i += kThreadsPerBlock) { + if (bid >= cu_seqlens_q[i] && bid < cu_seqlens_q[i + 1]) { + s_src_page_entry = src_page_table + i * src_stride; + } + } + } + __syncthreads(); + const auto src_page_entry = s_src_page_entry; + + if (length <= TopK) { + return naive_topk_transform(score, length, dst_page_entry, src_page_entry); + } else { + __shared__ int s_indices[TopK]; + fast_topk_cuda_tl(score, s_indices, row_start, length); + // copy src[s_indices] to dst, we manually unroll here + static_assert(TopK % kThreadsPerBlock == 0); + static_assert(TopK / kThreadsPerBlock == 2); + const auto idx_0 = tid; + const auto pos_0 = s_indices[idx_0]; + dst_page_entry[idx_0] = src_page_entry[pos_0]; + const auto idx_1 = tid + kThreadsPerBlock; + const auto pos_1 = s_indices[idx_1]; + dst_page_entry[idx_1] = src_page_entry[pos_1]; + } +} + +__global__ __launch_bounds__(kThreadsPerBlock) // prefill, ragged kv + void topk_transform_prefill_ragged_kernel( + const FastTopKParams params, + int32_t* __restrict__ topk_indices_ragged, + const int32_t* __restrict__ topk_indices_offset) { + const auto& [input, row_starts, _, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto tid = threadIdx.x; + const auto row_start = row_starts == nullptr ? 
0 : row_starts[bid]; + const auto length = lengths[bid]; + const auto dst_indices_entry = topk_indices_ragged + bid * TopK; + const auto score = input + bid * input_stride; + const auto offset = topk_indices_offset[bid]; + + if (length <= TopK) { + return naive_topk_transform_ragged(score, length, dst_indices_entry, offset); + } else { + __shared__ int s_indices[TopK]; + fast_topk_cuda_tl(score, s_indices, row_start, length); + // copy src[s_indices] to dst, we manually unroll here + static_assert(TopK % kThreadsPerBlock == 0); + static_assert(TopK / kThreadsPerBlock == 2); + const auto idx_0 = tid; + const auto pos_0 = s_indices[idx_0]; + dst_indices_entry[idx_0] = pos_0 + offset; + const auto idx_1 = tid + kThreadsPerBlock; + const auto pos_1 = s_indices[idx_1]; + dst_indices_entry[idx_1] = pos_1 + offset; + } +} + +auto get_params( + const at::Tensor& score, + const at::Tensor& lengths, + std::optional row_starts_opt = std::nullopt, + std::optional indices_opt = std::nullopt) -> FastTopKParams { + const auto B = score.size(0); + TORCH_CHECK(score.dim() == 2 && score.stride(1) == 1); + if (row_starts_opt.has_value()) { + const auto& row_starts = row_starts_opt.value(); + TORCH_CHECK(row_starts.dim() == 1); + TORCH_CHECK(row_starts.size(0) == B); + } + TORCH_CHECK(lengths.dim() == 1 && lengths.is_contiguous()); + TORCH_CHECK(lengths.size(0) == B); + int32_t* indices_data_ptr = nullptr; + if (indices_opt.has_value()) { + const auto& indices = indices_opt.value(); + TORCH_CHECK(indices.dim() == 2 && indices.is_contiguous()); + TORCH_CHECK(indices.size(0) == B); + TORCH_CHECK(indices.size(1) == TopK); + indices_data_ptr = indices.data_ptr(); + } + + return FastTopKParams{ + .input = score.data_ptr(), + .row_starts = row_starts_opt.has_value() ? 
row_starts_opt->data_ptr() : nullptr, + .indices = indices_data_ptr, + .lengths = lengths.data_ptr(), + .input_stride = score.stride(0), + }; +} + +template +void setup_kernel_smem_once() { + [[maybe_unused]] + static const auto result = [] { +#ifdef USE_ROCM + // hipify will turn cudaFuncSetAttribute -> hipFuncSetAttribute. On ROCm, + // hipFuncSetAttribute expects `const void*` and hipcc does not accept passing + // a function pointer directly, so cast explicitly. + return ::cudaFuncSetAttribute( + reinterpret_cast(f), ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); +#else + // CUDA: keep original behavior (no cast needed). + return ::cudaFuncSetAttribute(f, ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); +#endif + }(); + TORCH_CHECK(result == cudaSuccess, "set_up_kernel_once failed:", ::cudaGetErrorString(result)); +} + +} // namespace + +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") + +void fast_topk_interface( + const at::Tensor& score, at::Tensor& indices, const at::Tensor& lengths, std::optional row_starts_opt) { + CHECK_CUDA(score); + CHECK_CUDA(indices); + if (row_starts_opt.has_value()) { + CHECK_CUDA(row_starts_opt.value()); + } + CHECK_CUDA(lengths); + const auto params = get_params(score, lengths, row_starts_opt, indices); + const auto B = score.size(0); + const auto stream = at::cuda::getCurrentCUDAStream().stream(); + const auto grid = dim3{static_cast(B)}; + const auto block = dim3{kThreadsPerBlock}; + setup_kernel_smem_once(); + topk_kernel<<>>(params); + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); +} + +void fast_topk_transform_interface( + const at::Tensor& score, + const at::Tensor& lengths, + at::Tensor& dst_page_table, + const at::Tensor& src_page_table, + const at::Tensor& cu_seqlens_q, + std::optional row_starts_opt) { + CHECK_CUDA(score); + CHECK_CUDA(lengths); + CHECK_CUDA(dst_page_table); 
+ CHECK_CUDA(src_page_table); + CHECK_CUDA(cu_seqlens_q); + if (row_starts_opt.has_value()) { + CHECK_CUDA(row_starts_opt.value()); + } + const auto params = get_params(score, lengths, row_starts_opt); + const auto B = score.size(0); + TORCH_CHECK(dst_page_table.dim() == 2 && dst_page_table.is_contiguous()); + TORCH_CHECK(src_page_table.dim() == 2 && src_page_table.stride(1) == 1); + TORCH_CHECK(cu_seqlens_q.dim() == 1 && cu_seqlens_q.is_contiguous()); + const auto prefill_bs = cu_seqlens_q.size(0) - 1; + TORCH_CHECK(dst_page_table.size(0) == B); + TORCH_CHECK(dst_page_table.size(1) == TopK); + TORCH_CHECK(src_page_table.size(0) == prefill_bs); + TORCH_CHECK(prefill_bs <= B); // prefill_bs should be smaller than expanded bs + + // launch kernel + const auto stream = at::cuda::getCurrentCUDAStream().stream(); + const auto grid = dim3{static_cast(B)}; + const auto block = dim3{kThreadsPerBlock}; + const auto src_stride = src_page_table.stride(0); + + // dispatch to decode or prefill + // extend and draft extend: row_starts_opt is not null, invokes the prefill kernel + // decode: row_starts_opt is null, invokes the decode kernel + // target verify: row_starts_opt is null, invokes the prefill kernel + const auto is_decode = !row_starts_opt.has_value() && prefill_bs == B; + if (is_decode) { + setup_kernel_smem_once(); + topk_transform_decode_kernel<<>>( + params, dst_page_table.data_ptr(), src_page_table.data_ptr(), src_stride); + } else { + setup_kernel_smem_once(); + topk_transform_prefill_kernel<<>>( + params, + dst_page_table.data_ptr(), + src_page_table.data_ptr(), + src_stride, + cu_seqlens_q.data_ptr(), + prefill_bs); + } + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); +} + +void fast_topk_transform_ragged_interface( + const at::Tensor& score, + const at::Tensor& lengths, + at::Tensor& topk_indices_ragged, + const at::Tensor& topk_indices_offset, + std::optional row_starts_opt) 
{ + CHECK_CUDA(score); + CHECK_CUDA(lengths); + CHECK_CUDA(topk_indices_ragged); + CHECK_CUDA(topk_indices_offset); + if (row_starts_opt.has_value()) { + CHECK_CUDA(row_starts_opt.value()); + } + + const auto params = get_params(score, lengths, row_starts_opt); + const auto B = score.size(0); + TORCH_CHECK(topk_indices_ragged.dim() == 2 && topk_indices_ragged.is_contiguous()); + TORCH_CHECK(topk_indices_offset.dim() == 1); + + TORCH_CHECK(topk_indices_ragged.size(0) == B); + TORCH_CHECK(topk_indices_ragged.size(1) == TopK); + TORCH_CHECK(topk_indices_offset.size(0) == B); + + // launch kernel + const auto stream = at::cuda::getCurrentCUDAStream().stream(); + const auto grid = dim3{static_cast(B)}; + const auto block = dim3{kThreadsPerBlock}; + + setup_kernel_smem_once(); + topk_transform_prefill_ragged_kernel<<>>( + params, topk_indices_ragged.data_ptr(), topk_indices_offset.data_ptr()); + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); +} diff --git a/examples/run_distribution_analysis.sh b/examples/run_distribution_analysis.sh index 287c454..3022dda 100755 --- a/examples/run_distribution_analysis.sh +++ b/examples/run_distribution_analysis.sh @@ -5,11 +5,13 @@ # Profiles the SGLang TopK kernel's first-pass bucket distribution # to identify hotspot buckets causing tail latency. # -# Three steps: -# 1. Calibrate — collect real-data histograms -# (skippable via --real-histograms PATH) -# 2. Bench — histogram profiling (bucket_uniform + normal) -# 3. Analyze — comparison plots + bucket count tables +# Four steps: +# 1. Calibrate — collect real-data histograms +# (skippable via --real-histograms PATH) +# 2. Auto-tune — sweep hyperparameters to find best per-mode power +# 3. Bench — histogram profiling (bucket_uniform + normal) +# noscale kernels use the same autotuned power +# 4. 
Analyze — comparison plots + bucket count tables # # All outputs (JSON, plots, CSV tables, logs) are written to a # single timestamped folder under examples/results/dist_analysis_*. @@ -30,6 +32,9 @@ # 6: Asinh — y = asinh(beta * x) # 7: Log1p — y = sign(x) * log1p(alpha * |x|) # 8: Trunc8 — bf16 upper-8-bit bucketing +# 9: Erf — y = erf(alpha * x) +# 10: Tanh — y = tanh(alpha * x) +# 11: Subtract — x - pivot (RadiK-style scatter) set -euo pipefail @@ -37,14 +42,14 @@ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" BENCH_DIR="${SCRIPT_DIR}/../benchmarks" # ── Defaults ────────────────────────────────────────────────── -GPU_ID=5 +GPU_ID=4 MODEL_NAME="Qwen/Qwen3-1.7B" TOPK_VAL=30 MEM=0.7 ALGO="block_sparse_attention" # The path to the raw_histograms.npy file (set to skip calibration) -# REAL_HISTOGRAMS="/scr/dataset/yuke/xinrui/new/vortex_torch/examples/calibration/raw_histograms.npy" -REAL_HISTOGRAMS="" +REAL_HISTOGRAMS="/scr/dataset/yuke/xinrui/new/vortex_torch/examples/calibration/raw_histograms.npy" + # ── Parse arguments ─────────────────────────────────────────── while [[ $# -gt 0 ]]; do case "$1" in @@ -97,45 +102,101 @@ else echo ">>> Step 1: Done. Calibration saved to ${CALIBRATION_DIR}" fi -# ── Step 2: Histogram profiling (bucket_uniform + normal) ───── +# ── Step 2: Auto-tune — sweep hyperparameters ────────────────── +echo "" +echo ">>> Step 2: Auto-tuning hyperparameters (modes 3, 6, 7, 9, 10)" + +AUTOTUNE_JSON="${RUN_DIR}/autotune_results.json" + +# Build autotune data source args +AUTOTUNE_EXTRA_ARGS=() +if [ -n "${REAL_HIST_PATH:-}" ]; then + AUTOTUNE_EXTRA_ARGS+=(--real-histograms "${REAL_HIST_PATH}") +fi + +PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/autotune_topk_mapping.py" \ + --topk-val "${TOPK_VAL}" \ + --batch-size 4 \ + --seq-len 32768 \ + --num-kv-heads 2 \ + "${AUTOTUNE_EXTRA_ARGS[@]}" \ + --output-json "${AUTOTUNE_JSON}" \ + 2>&1 | tee "${RUN_DIR}/step2_autotune.log" + +echo ">>> Step 2: Done. 
Autotune results saved to ${AUTOTUNE_JSON}" + +# ── Step 3: Histogram profiling (bucket_uniform + normal) ───── echo "" -echo ">>> Step 2: Kernel-level histogram profiling (bucket_uniform + normal)" +echo ">>> Step 3: Kernel-level histogram profiling (bucket_uniform + normal)" BENCH_JSON="${RUN_DIR}/bench_distribution.json" -python "${BENCH_DIR}/bench_topk.py" \ +# Build optional args for bench_topk.py +BENCH_EXTRA_ARGS=() +if [ -n "${REAL_HIST_PATH:-}" ]; then + BENCH_EXTRA_ARGS+=(--real-histograms "${REAL_HIST_PATH}") +fi + +# Derive calibration directory from histogram path to find lut.npy / quantiles.npy +CALIB_DIR="$(dirname "${REAL_HIST_PATH}")" +LUT_FILE="${CALIB_DIR}/lut.npy" +QUANTILES_FILE="${CALIB_DIR}/quantiles.npy" + +if [ -f "${LUT_FILE}" ]; then + BENCH_EXTRA_ARGS+=(--lut-path "${LUT_FILE}") + echo " Using LUT for mode 1: ${LUT_FILE}" +else + echo " WARNING: ${LUT_FILE} not found — mode 1 (LUT CDF) will be skipped" +fi +if [ -f "${QUANTILES_FILE}" ]; then + BENCH_EXTRA_ARGS+=(--quantiles-path "${QUANTILES_FILE}") + echo " Using quantiles for mode 2: ${QUANTILES_FILE}" +else + echo " WARNING: ${QUANTILES_FILE} not found — mode 2 (Quantile) will be skipped" +fi + +PYTHONPATH="${SCRIPT_DIR}/.." 
python "${BENCH_DIR}/bench_topk.py" \ --batch-sizes 4 \ - --seq-lens 4096 \ + --seq-lens 32768 \ --topk-vals "${TOPK_VAL}" \ - --num-kv-heads 2 \ + --num-kv-heads 8 \ --distributions bucket_uniform normal \ --histogram \ - --filter-kernels sglang_m0 sglang_m1 sglang_m2 sglang_m3 sglang_m4 sglang_m6 sglang_m7 sglang_m8 \ + "${BENCH_EXTRA_ARGS[@]}" \ + --autotune-json "${AUTOTUNE_JSON}" \ + --filter-kernels naive sglang_m0 sglang_scale sglang_m1 sglang_m2 sglang_m3 sglang_m3_noscale sglang_m4 sglang_m6 sglang_m6_noscale sglang_m7 sglang_m7_noscale sglang_m8 sglang_m9 sglang_m9_noscale sglang_m10 sglang_m10_noscale sglang_m11 \ --repeat 20 \ --output-json "${BENCH_JSON}" \ - 2>&1 | tee "${RUN_DIR}/step2_bench.log" + 2>&1 | tee "${RUN_DIR}/step3_bench.log" -echo ">>> Step 2: Done. Results saved to ${BENCH_JSON}" +echo ">>> Step 3: Done. Results saved to ${BENCH_JSON}" -# ── Step 3: Analyze — comparison plots + tables ─────────────── +# ── Step 4: Analyze — comparison plots + tables ─────────────── echo "" -echo ">>> Step 3: Generating distribution comparison plots + tables" +echo ">>> Step 4: Generating distribution comparison plots + tables" + +# Build optional args for analyze +ANALYZE_EXTRA_ARGS=() +if [ -n "${REAL_HIST_PATH:-}" ]; then + ANALYZE_EXTRA_ARGS+=(--real-histograms "${REAL_HIST_PATH}") +fi python "${BENCH_DIR}/analyze_topk_distribution.py" \ --bench-json "${BENCH_JSON}" \ - --real-histograms "${REAL_HIST_PATH}" \ + "${ANALYZE_EXTRA_ARGS[@]}" \ --output-dir "${RUN_DIR}" \ - 2>&1 | tee "${RUN_DIR}/step3_analyze.log" + 2>&1 | tee "${RUN_DIR}/step4_analyze.log" -echo ">>> Step 3: Done." +echo ">>> Step 4: Done." 
# ── Summary ─────────────────────────────────────────────────── echo "" echo "============================================================" echo "Bucket Distribution Profiling Complete" echo " All outputs in: ${RUN_DIR}/" +echo " autotune_results.json — hyperparameter sweep rankings" echo " bench_distribution.json — raw benchmark data" echo " distribution_comparison.png — bucket dist plots" echo " bucket_counts.csv — per-bucket count table" -echo " step{1,2,3}_*.log — pipeline logs" +echo " step{1,2,3,4}_*.log — pipeline logs" echo "============================================================" diff --git a/examples/run_distribution_analysis_new.sh b/examples/run_distribution_analysis_new.sh index 3dc1bd4..f0938ff 100755 --- a/examples/run_distribution_analysis_new.sh +++ b/examples/run_distribution_analysis_new.sh @@ -6,7 +6,10 @@ # Mode 3 (Power): y = sign(x) * |x|^p # Mode 6 (Asinh): y = asinh(beta * x) # Mode 7 (Log1p): y = sign(x) * log1p(alpha * |x|) -# Mode 8 (Trunc8): bf16 upper-8-bit bucketing +# Mode 8 (Trunc8): bf16 upper-8-bit bucketing +# Mode 9 (Erf): y = erf(alpha * x) +# Mode 10 (Tanh): y = tanh(alpha * x) +# Mode 11 (Subtract): x - pivot (RadiK-style scatter) # # Four steps: # 1. Calibrate — collect real-data histograms @@ -26,7 +29,7 @@ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" BENCH_DIR="${SCRIPT_DIR}/../benchmarks" # ── Defaults ────────────────────────────────────────────────── -GPU_ID=5 +GPU_ID=4 MODEL_NAME="Qwen/Qwen3-1.7B" TOPK_VAL=30 MEM=0.7 @@ -56,7 +59,7 @@ RUN_DIR="${RESULTS_DIR}/dist_analysis_${TIMESTAMP}" mkdir -p "${RUN_DIR}" echo "============================================================" -echo "Bucket Distribution Profiling (modes 3, 6, 7)" +echo "Bucket Distribution Profiling (modes 3, 6, 7, 8, 9, 10, 11)" echo " Model: ${MODEL_NAME}" echo " Algorithm: ${ALGO}" echo " TopK: ${TOPK_VAL}" @@ -95,9 +98,10 @@ AUTOTUNE_JSON="${RUN_DIR}/autotune_results.json" PYTHONPATH="${SCRIPT_DIR}/.." 
python "${BENCH_DIR}/autotune_topk_mapping.py" \ --topk-val "${TOPK_VAL}" \ --batch-size 4 \ - --seq-len 4096 \ - --num-kv-heads 2 \ + --seq-len 32768 \ + --num-kv-heads 8 \ --real-histograms "${REAL_HIST_PATH}" \ + --latency-rerank \ --output-json "${AUTOTUNE_JSON}" \ 2>&1 | tee "${RUN_DIR}/step2_autotune.log" @@ -111,14 +115,14 @@ BENCH_JSON="${RUN_DIR}/bench_distribution.json" PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/bench_topk.py" \ --batch-sizes 4 \ - --seq-lens 4096 \ + --seq-lens 32768 \ --topk-vals "${TOPK_VAL}" \ - --num-kv-heads 2 \ + --num-kv-heads 8 \ --distributions bucket_uniform normal \ --histogram \ --real-histograms "${REAL_HIST_PATH}" \ --autotune-json "${AUTOTUNE_JSON}" \ - --filter-kernels sglang_m3 sglang_m6 sglang_m7 sglang_m8 \ + --filter-kernels naive sglang_m0 sglang_scale sglang_m3 sglang_m3_noscale sglang_m6 sglang_m6_noscale sglang_m7 sglang_m7_noscale sglang_m8 sglang_m9 sglang_m9_noscale sglang_m10 sglang_m10_noscale sglang_m11 \ --repeat 20 \ --output-json "${BENCH_JSON}" \ 2>&1 | tee "${RUN_DIR}/step3_bench.log" @@ -134,13 +138,12 @@ python "${BENCH_DIR}/analyze_topk_distribution.py" \ --real-histograms "${REAL_HIST_PATH}" \ --output-dir "${RUN_DIR}" \ 2>&1 | tee "${RUN_DIR}/step4_analyze.log" - echo ">>> Step 4: Done." 
# ── Summary ─────────────────────────────────────────────────── echo "" echo "============================================================" -echo "Bucket Distribution Profiling Complete (modes 3, 6, 7)" +echo "Bucket Distribution Profiling Complete (modes 3, 6, 7, 8, 9, 10, 11)" echo " All outputs in: ${RUN_DIR}/" echo " autotune_results.json — hyperparameter sweep rankings" echo " bench_distribution.json — raw benchmark data" diff --git a/examples/run_topk_benchmark.sh b/examples/run_topk_benchmark.sh index 5a7ed94..6ac2b9d 100755 --- a/examples/run_topk_benchmark.sh +++ b/examples/run_topk_benchmark.sh @@ -34,7 +34,7 @@ set -euo pipefail # use GPU_ID to set the GPU id you want to use -GPU_ID=5 +GPU_ID=4 SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" BENCH_DIR="${SCRIPT_DIR}/../benchmarks" @@ -118,7 +118,7 @@ else python "${BENCH_DIR}/autotune_topk_mapping.py" \ --topk-val "${TOPK_VAL}" \ --batch-size 4 \ - --seq-len 4096 \ + --seq-len 32768 \ --num-kv-heads 2 \ ${REAL_HIST_ARGS} \ --output-json "${AUTOTUNE_JSON}" \ @@ -143,7 +143,7 @@ else python "${BENCH_DIR}/bench_topk.py" \ --batch-sizes 4 8 16 32 \ - --seq-lens 2048 4096 8192 16384 \ + --seq-lens 2048 4096 8192 16384 32768 \ --topk-vals "${TOPK_VAL}" \ --num-kv-heads 2 4 \ --distributions normal lognormal uniform \ @@ -263,6 +263,32 @@ else --topk-mapping-mode 7 \ --topk-mapping-power 1.0 + # 3j. SGLang mode 8 (Trunc8) + run_e2e "sglang_mode8_trunc8" \ + --vortex-module-name "${ALGO}" \ + --topk-type sglang \ + --topk-mapping-mode 8 + + # 3k. SGLang mode 9 (Erf) + run_e2e "sglang_mode9_erf" \ + --vortex-module-name "${ALGO}" \ + --topk-type sglang \ + --topk-mapping-mode 9 \ + --topk-mapping-power 1.0 + + # 3l. SGLang mode 10 (Tanh) + run_e2e "sglang_mode10_tanh" \ + --vortex-module-name "${ALGO}" \ + --topk-type sglang \ + --topk-mapping-mode 10 \ + --topk-mapping-power 1.0 + + # 3m. 
SGLang mode 11 (Subtract) + run_e2e "sglang_mode11_subtract" \ + --vortex-module-name "${ALGO}" \ + --topk-type sglang \ + --topk-mapping-mode 11 + echo "" echo ">>> Step 3: Done. E2E logs saved to ${E2E_DIR}/" diff --git a/examples/verify_algo.py b/examples/verify_algo.py index e04f787..f1c2a2d 100644 --- a/examples/verify_algo.py +++ b/examples/verify_algo.py @@ -254,8 +254,8 @@ def parse_args(): "--topk-mapping-mode", type=int, default=0, - choices=[0, 1, 2, 3, 4, 5, 6, 7], - help='TopK mapping mode: 0=none, 1=lut_cdf, 2=quantile, 3=power, 4=log, 5=index_cache, 6=asinh, 7=log1p (default: 0).', + choices=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], + help='TopK mapping mode: 0=none, 1=lut_cdf, 2=quantile, 3=power, 4=log, 5=index_cache, 6=asinh, 7=log1p, 8=trunc8, 9=erf, 10=tanh, 11=subtract (default: 0).', ) parser.add_argument( diff --git a/examples/verify_algo_topk_mapping.sh b/examples/verify_algo_topk_mapping.sh index 918252c..2370ca1 100644 --- a/examples/verify_algo_topk_mapping.sh +++ b/examples/verify_algo_topk_mapping.sh @@ -11,6 +11,9 @@ set -e # 6: Asinh — y = asinh(beta * x) # 7: Log1p — y = sign(x) * log1p(alpha * |x|) # 8: Trunc8 — bf16 upper-8-bit bucketing +# 9: Erf — y = erf(alpha * x) +# 10: Tanh — y = tanh(alpha * x) +# 11: Subtract — x - pivot (RadiK-style scatter) export CUDA_VISIBLE_DEVICES=0 SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" @@ -20,11 +23,6 @@ sparse_algos=( "block_sparse_attention" ) -topk_mapping_modes=( - 0 # none - 3 # power - 4 # log -) RESULTS_DIR="results" mkdir -p "${RESULTS_DIR}" TIMESTAMP=$(date +%Y%m%d_%H%M%S) @@ -70,7 +68,50 @@ else done fi +# ============================================================ +# Auto-tune: find best hyperparameters per mode +# Uses topk_profile_histogram kernel on real calibration data +# ============================================================ +REAL_HISTOGRAMS="${CALIBRATION_DIR}/raw_histograms.npy" +if [ -f "${REAL_HISTOGRAMS}" ]; then + echo 
"============================================================" + echo "Auto-tuning hyperparameters (real calibration data)" + echo "============================================================" + AUTOTUNE_JSON="${RESULTS_DIR}/autotune_${TIMESTAMP}.json" + PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/autotune_topk_mapping.py" \ + --topk-val 30 \ + --batch-size 4 \ + --seq-len 32768 \ + --num-kv-heads 2 \ + --real-histograms "${REAL_HISTOGRAMS}" \ + --output-json "${AUTOTUNE_JSON}" \ + 2>&1 | tee "${RESULTS_DIR}/autotune_${TIMESTAMP}.log" + echo ">>> Auto-tune results saved to ${AUTOTUNE_JSON}" + echo "" + # Extract best per-mode hyperparameters from autotune JSON + eval "$(python3 -c " +import json, sys +data = json.load(open(sys.argv[1])) +best = {} +for r in data: + m = r.get('mode') + if m in (3, 6, 7, 9, 10): + if m not in best or r['gini'] < best[m]['gini']: + best[m] = r +for m in (3, 6, 7, 9, 10): + print(f'BEST_POWER_{m}={best[m][\"param\"]}' if m in best else f'BEST_POWER_{m}=0.5') +" "${AUTOTUNE_JSON}")" + echo ">>> Autotuned best powers: mode3=${BEST_POWER_3} mode6=${BEST_POWER_6} mode7=${BEST_POWER_7} mode9=${BEST_POWER_9} mode10=${BEST_POWER_10}" + echo "" +else + echo ">>> WARNING: ${REAL_HISTOGRAMS} not found, using default power=0.5 for all modes" + BEST_POWER_3=0.5 + BEST_POWER_6=0.5 + BEST_POWER_7=0.5 + BEST_POWER_9=0.5 + BEST_POWER_10=0.5 +fi # ============================================================ # Mode 1: LUT CDF with calibrated LUT @@ -111,10 +152,10 @@ for algo in "${sparse_algos[@]}"; do done # ============================================================ -# sglang topk: modes that don't need calibration (0, 3, 4) +# sglang topk: non-parametric modes (0, 4, 8, 11) # ============================================================ for algo in "${sparse_algos[@]}"; do - for topk_mapping_mode in "${topk_mapping_modes[@]}"; do + for topk_mapping_mode in 0 4 8 11; do 
OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_${topk_mapping_mode}_${TIMESTAMP}.log" echo ">>> Running verify_algo.py with --vortex-module-name ${algo} --topk-type sglang --topk-mapping-mode ${topk_mapping_mode}" echo ">>> Saving results to ${OUTFILE}" @@ -126,50 +167,102 @@ for algo in "${sparse_algos[@]}"; do --model-name Qwen/Qwen3-1.7B \ --topk-type sglang \ --topk-mapping-mode ${topk_mapping_mode} \ - --topk-mapping-power 0.5 \ --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" done done # ============================================================ -# Mode 6: asinh — sweep beta values +# Mode 3: power — autotuned best p # ============================================================ for algo in "${sparse_algos[@]}"; do - for beta in 0.5 1.0 2.0; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_6_beta${beta}_${TIMESTAMP}.log" - echo ">>> Running mode 6 (asinh) beta=${beta} for ${algo}" - echo ">>> Saving results to ${OUTFILE}" - { time python verify_algo.py \ - --trials 8 \ - --topk-val 30 \ - --vortex-module-name "${algo}" \ - --model-name Qwen/Qwen3-1.7B \ - --topk-type sglang \ - --topk-mapping-mode 6 \ - --topk-mapping-power ${beta} \ - --mem 0.7 ; } \ - 2>&1 | tee "${OUTFILE}" - done + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_3_p${BEST_POWER_3}_${TIMESTAMP}.log" + echo ">>> Running mode 3 (power) p=${BEST_POWER_3} (autotuned) for ${algo}" + echo ">>> Saving results to ${OUTFILE}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --topk-type sglang \ + --topk-mapping-mode 3 \ + --topk-mapping-power ${BEST_POWER_3} \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" done # ============================================================ -# Mode 7: log1p — sweep alpha values +# Mode 6: asinh — autotuned best beta # ============================================================ for algo in "${sparse_algos[@]}"; do - for alpha in 0.5 1.0 2.0; do - 
OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_7_alpha${alpha}_${TIMESTAMP}.log" - echo ">>> Running mode 7 (log1p) alpha=${alpha} for ${algo}" - echo ">>> Saving results to ${OUTFILE}" - { time python verify_algo.py \ - --trials 8 \ - --topk-val 30 \ - --vortex-module-name "${algo}" \ - --model-name Qwen/Qwen3-1.7B \ - --topk-type sglang \ - --topk-mapping-mode 7 \ - --topk-mapping-power ${alpha} \ - --mem 0.7 ; } \ - 2>&1 | tee "${OUTFILE}" - done + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_6_beta${BEST_POWER_6}_${TIMESTAMP}.log" + echo ">>> Running mode 6 (asinh) beta=${BEST_POWER_6} (autotuned) for ${algo}" + echo ">>> Saving results to ${OUTFILE}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --topk-type sglang \ + --topk-mapping-mode 6 \ + --topk-mapping-power ${BEST_POWER_6} \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" +done + +# ============================================================ +# Mode 7: log1p — autotuned best alpha +# ============================================================ +for algo in "${sparse_algos[@]}"; do + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_7_alpha${BEST_POWER_7}_${TIMESTAMP}.log" + echo ">>> Running mode 7 (log1p) alpha=${BEST_POWER_7} (autotuned) for ${algo}" + echo ">>> Saving results to ${OUTFILE}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --topk-type sglang \ + --topk-mapping-mode 7 \ + --topk-mapping-power ${BEST_POWER_7} \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" +done + +# ============================================================ +# Mode 9: erf — autotuned best alpha +# ============================================================ +for algo in "${sparse_algos[@]}"; do + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_9_alpha${BEST_POWER_9}_${TIMESTAMP}.log" + echo ">>> Running mode 9 (erf) 
alpha=${BEST_POWER_9} (autotuned) for ${algo}" + echo ">>> Saving results to ${OUTFILE}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --topk-type sglang \ + --topk-mapping-mode 9 \ + --topk-mapping-power ${BEST_POWER_9} \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" +done + +# ============================================================ +# Mode 10: tanh — autotuned best alpha +# ============================================================ +for algo in "${sparse_algos[@]}"; do + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_10_alpha${BEST_POWER_10}_${TIMESTAMP}.log" + echo ">>> Running mode 10 (tanh) alpha=${BEST_POWER_10} (autotuned) for ${algo}" + echo ">>> Saving results to ${OUTFILE}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --topk-type sglang \ + --topk-mapping-mode 10 \ + --topk-mapping-power ${BEST_POWER_10} \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" done \ No newline at end of file diff --git a/examples/verify_algo_topk_mapping_new.sh b/examples/verify_algo_topk_mapping_new.sh index b701be2..f1c41fe 100644 --- a/examples/verify_algo_topk_mapping_new.sh +++ b/examples/verify_algo_topk_mapping_new.sh @@ -11,6 +11,9 @@ set -e # 6: Asinh — y = asinh(beta * x) # 7: Log1p — y = sign(x) * log1p(alpha * |x|) # 8: Trunc8 — bf16 upper-8-bit bucketing +# 9: Erf — y = erf(alpha * x) +# 10: Tanh — y = tanh(alpha * x) +# 11: Subtract — x - pivot (RadiK-style scatter) export CUDA_VISIBLE_DEVICES=5 SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" @@ -38,7 +41,7 @@ AUTOTUNE_JSON="${RESULTS_DIR}/autotune_${TIMESTAMP}.json" PYTHONPATH="${SCRIPT_DIR}/.." 
python "${BENCH_DIR}/autotune_topk_mapping.py" \ --topk-val 30 \ --batch-size 4 \ - --seq-len 4096 \ + --seq-len 32768 \ --num-kv-heads 2 \ --real-histograms "${REAL_HISTOGRAMS}" \ --output-json "${AUTOTUNE_JSON}" \ @@ -47,72 +50,166 @@ echo ">>> Auto-tune results saved to ${AUTOTUNE_JSON}" echo "" # ============================================================ -# Step 1: Mode 3 (power) — sweep p values +# Extract best per-mode hyperparameters from autotune JSON +# ============================================================ +eval "$(python3 -c " +import json, sys +data = json.load(open(sys.argv[1])) +best = {} +for r in data: + m = r.get('mode') + if m in (3, 6, 7, 9, 10): + if m not in best or r['gini'] < best[m]['gini']: + best[m] = r +for m in (3, 6, 7, 9, 10): + print(f'BEST_POWER_{m}={best[m][\"param\"]}' if m in best else f'BEST_POWER_{m}=0.5') +" "${AUTOTUNE_JSON}")" +echo ">>> Autotuned best powers: mode3=${BEST_POWER_3} mode6=${BEST_POWER_6} mode7=${BEST_POWER_7} mode9=${BEST_POWER_9} mode10=${BEST_POWER_10}" +echo "" + +# ============================================================ +# Step 1: Mode 3 (power) — autotuned best p +# ============================================================ +echo "============================================================" +echo "Step 1: Mode 3 (power) — p=${BEST_POWER_3} (autotuned)" +echo "============================================================" +for algo in "${sparse_algos[@]}"; do + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_3_p${BEST_POWER_3}_${TIMESTAMP}.log" + echo ">>> Mode 3 (power) p=${BEST_POWER_3} algo=${algo}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --topk-type sglang \ + --topk-mapping-mode 3 \ + --topk-mapping-power ${BEST_POWER_3} \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" +done + +# ============================================================ +# Step 2: Mode 6 (asinh) — autotuned best beta +# 
============================================================ +echo "============================================================" +echo "Step 2: Mode 6 (asinh) — beta=${BEST_POWER_6} (autotuned)" +echo "============================================================" +for algo in "${sparse_algos[@]}"; do + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_6_beta${BEST_POWER_6}_${TIMESTAMP}.log" + echo ">>> Mode 6 (asinh) beta=${BEST_POWER_6} algo=${algo}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --topk-type sglang \ + --topk-mapping-mode 6 \ + --topk-mapping-power ${BEST_POWER_6} \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" +done + +# ============================================================ +# Step 3: Mode 7 (log1p) — autotuned best alpha +# ============================================================ +echo "============================================================" +echo "Step 3: Mode 7 (log1p) — alpha=${BEST_POWER_7} (autotuned)" +echo "============================================================" +for algo in "${sparse_algos[@]}"; do + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_7_alpha${BEST_POWER_7}_${TIMESTAMP}.log" + echo ">>> Mode 7 (log1p) alpha=${BEST_POWER_7} algo=${algo}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --topk-type sglang \ + --topk-mapping-mode 7 \ + --topk-mapping-power ${BEST_POWER_7} \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" +done + +# ============================================================ +# Step 4: Mode 8 (trunc8) — fixed parameter +# ============================================================ +echo "============================================================" +echo "Step 4: Mode 8 (trunc8)" +echo "============================================================" +for algo in "${sparse_algos[@]}"; do + 
OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_8_${TIMESTAMP}.log" + echo ">>> Mode 8 (trunc8) algo=${algo}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --topk-type sglang \ + --topk-mapping-mode 8 \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" +done + +# ============================================================ +# Step 5: Mode 9 (erf) — autotuned best alpha # ============================================================ echo "============================================================" -echo "Step 1: Mode 3 (power) — sweeping p" +echo "Step 5: Mode 9 (erf) — alpha=${BEST_POWER_9} (autotuned)" echo "============================================================" for algo in "${sparse_algos[@]}"; do - for p in 0.1 0.25 0.75 0.9; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_3_p${p}_${TIMESTAMP}.log" - echo ">>> Mode 3 (power) p=${p} algo=${algo}" - { time python verify_algo.py \ - --trials 8 \ - --topk-val 30 \ - --vortex-module-name "${algo}" \ - --model-name Qwen/Qwen3-1.7B \ - --topk-type sglang \ - --topk-mapping-mode 3 \ - --topk-mapping-power ${p} \ - --mem 0.7 ; } \ - 2>&1 | tee "${OUTFILE}" - done + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_9_alpha${BEST_POWER_9}_${TIMESTAMP}.log" + echo ">>> Mode 9 (erf) alpha=${BEST_POWER_9} algo=${algo}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --topk-type sglang \ + --topk-mapping-mode 9 \ + --topk-mapping-power ${BEST_POWER_9} \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" done # ============================================================ -# Step 2: Mode 6 (asinh) — sweep beta values +# Step 6: Mode 10 (tanh) — autotuned best alpha # ============================================================ echo "============================================================" -echo "Step 2: Mode 6 (asinh) — sweeping beta" 
+echo "Step 6: Mode 10 (tanh) — alpha=${BEST_POWER_10} (autotuned)" echo "============================================================" for algo in "${sparse_algos[@]}"; do - for beta in 0.1 0.5 1.0 2.0 4.0; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_6_beta${beta}_${TIMESTAMP}.log" - echo ">>> Mode 6 (asinh) beta=${beta} algo=${algo}" - { time python verify_algo.py \ - --trials 8 \ - --topk-val 30 \ - --vortex-module-name "${algo}" \ - --model-name Qwen/Qwen3-1.7B \ - --topk-type sglang \ - --topk-mapping-mode 6 \ - --topk-mapping-power ${beta} \ - --mem 0.7 ; } \ - 2>&1 | tee "${OUTFILE}" - done + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_10_alpha${BEST_POWER_10}_${TIMESTAMP}.log" + echo ">>> Mode 10 (tanh) alpha=${BEST_POWER_10} algo=${algo}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --topk-type sglang \ + --topk-mapping-mode 10 \ + --topk-mapping-power ${BEST_POWER_10} \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" done # ============================================================ -# Step 3: Mode 7 (log1p) — sweep alpha values +# Step 7: Mode 11 (subtract) — fixed parameter # ============================================================ echo "============================================================" -echo "Step 3: Mode 7 (log1p) — sweeping alpha" +echo "Step 7: Mode 11 (subtract)" echo "============================================================" for algo in "${sparse_algos[@]}"; do - for alpha in 0.1 0.5 0.75 1.0 2.0 4.0 8.0; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_7_alpha${alpha}_${TIMESTAMP}.log" - echo ">>> Mode 7 (log1p) alpha=${alpha} algo=${algo}" - { time python verify_algo.py \ - --trials 8 \ - --topk-val 30 \ - --vortex-module-name "${algo}" \ - --model-name Qwen/Qwen3-1.7B \ - --topk-type sglang \ - --topk-mapping-mode 7 \ - --topk-mapping-power ${alpha} \ - --mem 0.7 ; } \ - 2>&1 | tee "${OUTFILE}" - done + 
OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_11_${TIMESTAMP}.log" + echo ">>> Mode 11 (subtract) algo=${algo}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --topk-type sglang \ + --topk-mapping-mode 11 \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" done # ============================================================ @@ -120,9 +217,13 @@ done # ============================================================ echo "" echo "============================================================" -echo "All sweeps complete. Results in ${RESULTS_DIR}/" -echo " Auto-tune: ${AUTOTUNE_JSON}" -echo " Mode 3 (power): p = [0.1, 0.25, 0.75, 0.9]" -echo " Mode 6 (asinh): beta = [0.1, 0.5, 1.0, 2.0, 4.0]" -echo " Mode 7 (log1p): alpha = [0.1, 0.5, 0.75, 1.0, 2.0, 4.0, 8.0]" +echo "All runs complete. Results in ${RESULTS_DIR}/" +echo " Auto-tune: ${AUTOTUNE_JSON}" +echo " Mode 3 (power): p = ${BEST_POWER_3} (autotuned)" +echo " Mode 6 (asinh): beta = ${BEST_POWER_6} (autotuned)" +echo " Mode 7 (log1p): alpha = ${BEST_POWER_7} (autotuned)" +echo " Mode 8 (trunc8): (fixed)" +echo " Mode 9 (erf): alpha = ${BEST_POWER_9} (autotuned)" +echo " Mode 10 (tanh): alpha = ${BEST_POWER_10} (autotuned)" +echo " Mode 11 (subtract): (fixed)" echo "============================================================" diff --git a/setup.py b/setup.py index 99c6529..649f0a0 100644 --- a/setup.py +++ b/setup.py @@ -27,8 +27,6 @@ '-gencode=arch=compute_86,code=sm_86', '-gencode=arch=compute_89,code=sm_89', '-gencode=arch=compute_90,code=sm_90', - '-gencode=arch=compute_100a,code=sm_100a', - '-gencode=arch=compute_120,code=sm_120' ], }, ), diff --git a/third_party/sglang b/third_party/sglang index 5f51c8e..0ec1289 160000 --- a/third_party/sglang +++ b/third_party/sglang @@ -1 +1 @@ -Subproject commit 5f51c8ef485fb45990c8166f439da2ee695c03c1 +Subproject commit 0ec12893c4fc0d6ae1d36d4e0512dc21749c4b4b diff --git a/todo.txt 
b/todo.txt new file mode 100644 index 0000000..53950c3 --- /dev/null +++ b/todo.txt @@ -0,0 +1,308 @@ +1. +prefill 8k/16k/32k block 16/32/64 block topk (8,16) +qwen series 0.6b, 1.7b, 4b, 8b, 16b, 32b +baselines: +flashinfer-fa2/fa3 flashattention v2/v3 +dense + +NSA: block sparse attention +benchmarking: +flash Sparse Attention +https://github.com/Relaxed-System-Lab/Flash-Sparse-Attention + +https://github.com/mit-han-lab/flash-moba + + +Video generation: VAE fp8 convolution +wan 2.1 vae +1.3B input 480P + +3.17: +(For SOSP 26) +warpper: prefill: +1. ragged paged, warpper +disable_radix_cache +new backend: goal sparse prefill with topk on a new warpper: abandon the previous paged warpper, +apply 1 to the whole prefill sequence + +2. topk +idea: we want to improve the current topk kernel /scr/dataset/yuke/xinrui/new/vortex_torch/csrc/topk_sglang.cu for our project, +we want to map the value in each layer for the topk selection to a new distribution that bucket sort is more efficient on. like to make the values to be more uniform +in each bucket. +the number of the heads has a certain distribution, try to adapt to it. +The mapping function should have a low overhead, or it would damage the end2end efficiency. +Key: first profile the whole process, record the distribution of the value in each layer. You need a profile script for this. save the results. +Then design a novel mapping function(can be easily customize by me), to map the value to a new distribution. Don't change the correctness of the sorting, but more efficient for the bucket sort in the /scr/dataset/yuke/xinrui/new/vortex_torch/csrc/topk_sglang.cu +here is some options: +Option A: Adaptive Bit Selection (2-Pass Min/Max on uint32 Key) + +Core idea: Instead of always extracting the top 8 bits of the fp16 key, extract the 8 most-significant varying bits of the full 32-bit key by finding the actual key range within each segment. 
+ +Algorithm: +Pass 1: Parallel warp-reduction to find min_key and max_key over convert_to_uint32(x) for the segment (~N/1024 iterations per thread, same as current Stage 1) +Compute shift = max(0, 31 - clz(max_key - min_key) - 8) (find the bit position of the 8 most-significant differing bits) + +Pass 2: bin = ((convert_to_uint32(x) - min_key) >> shift) & 0xFF — this uses ALL 32 bits of float precision for binning +Overhead: ~2x data reads for Stage 1 (one extra scan for min/max). Min/max reduction can be done with __shfl_xor_sync followed by a single atomicMin/atomicMax in shared memory — very efficient. + +Pros: +Uses full 32-bit float precision instead of just 8 fp16 bits (up to 2^24 = 16M effective resolution levels instead of 256) +Perfectly adaptive to any data range — no calibration needed + +Pure integer arithmetic in the hot loop (shift + subtract + mask) +Guaranteed monotonic (linear mapping of uint32 keys) +Cons: +2x memory bandwidth for Stage 1 (the min/max pass re-reads all data) +Doesn't guarantee perfectly uniform bin counts (a skewed distribution within [min, max] still skews bins) +Expected quality: Excellent for the real distribution (where the range is narrow but contains many distinct float values that the 8-bit fp16 extraction cannot distinguish). 
+Implementation sketch: + +// Pass 1: find min/max key via parallel reduction +__shared__ uint32_t s_min_key, s_max_key; +if (tx == 0) { s_min_key = 0xFFFFFFFF; s_max_key = 0; } +__syncthreads(); +for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + uint32_t k = convert_to_uint32(input[idx + row_start]); + atomicMin(&s_min_key, k); + atomicMax(&s_max_key, k); +} +__syncthreads(); +uint32_t range = s_max_key - s_min_key; +int shift = max(0, 31 - __clz(range | 1) - 7); // 8 MSBs of range + +// Pass 2: histogram with adaptive bins +for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + uint32_t k = convert_to_uint32(input[idx + row_start]); + uint8_t bin = ((k - s_min_key) >> shift) & 0xFF; + atomicAdd(&s_histogram[bin], 1); +} + + + +Option B: Strided Sampling + Approximate CDF Equalization +Core idea: Read a sparse sample (every S-th element, e.g. S=32) to build an approximate histogram, compute a CDF-equalized LUT from it in shared memory, then apply the LUT during the full scan. +Algorithm: +Sampling pass: Read every 32nd element, build approximate 256-bin histogram in shared memory +In-block LUT construction: 256 threads compute prefix sum -> CDF -> equalized LUT (8 iterations of parallel prefix sum, same as existing run_cumsum) +Full scan: Apply s_lut[convert_to_uint8(x)] for each element +Overhead: ~(1/32 + 1)x = ~1.03x memory reads. LUT construction is ~8 syncthreads (trivial). The hot-loop cost is identical to existing Mode 1 (one shared memory lookup). +Pros: +Near-zero extra bandwidth (only 3% overhead from sampling) +Fully adaptive — no offline calibration needed +Self-tuning: the LUT is computed from the current segment's own data +Hot-loop cost identical to existing LUT mode (1 shared memory read) +Cons: +Approximate: sampling introduces noise in the estimated CDF (especially for small segments) +More complex control flow (3 phases in Stage 1) +For very short segments (<1024 elements), the sample may be too small +Expected quality: Very good. 
With 1/32 sampling on a segment of 4096+ elements, the CDF estimate has ~128+ samples — sufficient for a good 256-entry LUT. +Implementation sketch: + +// Phase 0: sampled histogram +__shared__ int s_sample_hist[256]; +if (tx < 256) s_sample_hist[tx] = 0; +__syncthreads(); +for (int idx = tx * 32; idx < length; idx += BLOCK_SIZE * 32) { + uint8_t bin = convert_to_uint8(input[idx + row_start]); + atomicAdd(&s_sample_hist[bin], 1); +} +__syncthreads(); + +// Phase 0.5: compute equalized LUT from sampled histogram +__shared__ uint8_t s_eq_lut[256]; +// ... prefix sum on s_sample_hist -> CDF -> s_eq_lut[i] = floor(CDF(i) * 255) + +// Phase 1: full scan with equalized LUT +for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + uint8_t raw_bin = convert_to_uint8(input[idx + row_start]); + uint8_t eq_bin = s_eq_lut[raw_bin]; + atomicAdd(&s_histogram[eq_bin], 1); +} + + +Option C: Online 2-Pass Histogram Equalization +Core idea: The "gold standard" — build the exact histogram first, compute CDF equalization in shared memory, then re-scan with the equalized mapping. +Algorithm: + +Pass 1: Build exact 256-bin histogram using convert_to_uint8(x) (same as current code) +In-block: 256 threads compute prefix sum -> CDF -> equalized LUT (lut[i] = floor(CDF(i) * 255)) + +Pass 2: Re-scan ALL elements, apply s_lut[convert_to_uint8(x)], build equalized histogram -> find threshold +Overhead: Exactly 2x data reads for Stage 1. LUT construction is negligible (~8 __syncthreads steps). +Pros: +Optimal histogram equalization — produces a provably uniform distribution (each equalized bin has almost exactly N/256 elements) +No calibration, no approximation — perfect adaptation to any input distribution +Monotonic (CDF is monotonically non-decreasing) + +Cons: +2x memory bandwidth for Stage 1 (dominant cost) +May not be worth it if Stage 2 is already fast (the overhead of the extra pass might exceed the savings from a smaller threshold bin) +Expected quality: Optimal. 
The threshold bin after equalization will have ~N/256 elements regardless of input distribution. +When this is worthwhile: When the original hot bin contains a very large fraction of elements (e.g., >20% of N), the savings from nearly eliminating Stage 2 easily outweigh the extra read pass. Given real data's Gini=0.809, this is likely the case. + +Option D: Temporal CDF Caching (Self-Calibrating LUT) +Core idea: The score distribution per head is relatively stable across adjacent decoding steps. After each kernel invocation, write the histogram to global memory. At the start of the NEXT invocation, load it and compute the CDF-equalized LUT. +Algorithm: +At kernel launch, load the previous iteration's histogram from a persistent global buffer prev_hist[head_id][256] +Compute equalized LUT in shared memory (prefix sum -> CDF -> LUT) +Stage 1 uses this LUT (same as Mode 1) +After Stage 1's histogram is built, write it to prev_hist[head_id][256] for the next iteration +Overhead: 1 shared memory lookup per element (identical to existing Mode 1). Plus ~256 int32 reads + ~256 int32 writes per kernel launch (negligible). +Pros: +Near-zero per-element overhead (shared memory LUT lookup) +Self-calibrating — no offline calibration step needed +Adapts to distribution changes over time (with 1-step lag) +Builds directly on existing Mode 1 infrastructure +Very low implementation complexity + +Cons: +1-step lag: the LUT is based on the previous step's data (cold start on first iteration) +Requires persistent global memory buffer (~256 * 4 bytes per head) +May produce suboptimal LUT if the distribution changes rapidly between steps +Expected quality: Very good after the first few iterations. Attention score distributions evolve slowly during generation, so the 1-step lag has minimal impact. 
+Implementation changes: +Add prev_histogram pointer to TopKMappingParams +Add MAPPING_TEMPORAL_CDF = 7 mode +In kernel: load prev histogram -> compute LUT -> use LUT -> write current histogram +Python side: allocate persistent [num_heads, 256] int32 buffer, pass to kernel + +Option E: Adaptive Exponent-Mantissa Bit Packing +Core idea: The current 8-bit extraction uses 5 exponent + 2 mantissa bits from fp16. When the actual exponent range is narrow (e.g., only 2-3 distinct exponents), most of those 5 exponent bits are wasted. Dynamically reallocate bits: use fewer for the exponent, more for the mantissa. +Algorithm: +Calibration or per-block scan: Determine exponent range [E_min, E_max] of the scores +Choose bit layout based on range width: +Range 1-2 exponents: 1 exp bit + 6 mantissa bits + 1 sign = 64 bins/exponent (vs 4 currently) → 16x improvement +Range 3-4: 2 exp + 5 mantissa + 1 sign +Range 5-8: 3 exp + 4 mantissa + 1 sign +Range 9-16: 4 exp + 3 mantissa + 1 sign +Wider: original 5 exp + 2 mantissa + 1 sign + +Apply: bin = ((exp - E_min) << mantissa_bits) | (mantissa >> (10 - mantissa_bits)) for positive values (with sign-magnitude ordering) +Overhead: Very low (~5-8 integer instructions per element: extract exponent, subtract base, shift, OR with mantissa). No extra memory reads. No LUT. +Pros + +Extremely low overhead — pure register-level bit manipulation +No extra memory reads or shared memory usage +Monotonic (order-preserving within each exponent, and across exponents) +Up to 16x better bin resolution for narrow distributions +Cons: +Requires knowing E_min/E_max (either calibrated offline, or from a quick per-block reduction) +Not as "perfect" as CDF equalization — distribution within each exponent may still be non-uniform +More complex bit manipulation logic +Expected quality: Very good for the observed real distribution (narrow exponent range → dramatic improvement in bin resolution). Not optimal for arbitrary distributions. 
+Implementation sketch: + +// Assuming E_min, E_max precomputed and passed in TopKMappingParams +__device__ __forceinline__ uint8_t map_adaptive_bits(float x, int e_min, int e_range) { + __half h = __float2half_rn(x); + uint16_t bits = __half_as_ushort(h); + uint16_t key = (bits & 0x8000) ? ~bits : (bits | 0x8000); + int exp_val = (key >> 10) & 0x1F; // 5-bit exponent + int mantissa = key & 0x3FF; // 10-bit mantissa + + // Determine bit allocation based on e_range + int exp_bits, mant_bits; + if (e_range <= 2) { exp_bits = 1; mant_bits = 6; } + else if (e_range <= 4) { exp_bits = 2; mant_bits = 5; } + else if (e_range <= 8) { exp_bits = 3; mant_bits = 4; } + else { return (uint8_t)(key >> 8); } // fallback + + int sign_bit = (key >> 15) & 1; + int exp_part = min((exp_val - e_min), (1 << exp_bits) - 1); + int mant_part = mantissa >> (10 - mant_bits); + return (uint8_t)((sign_bit << 7) | (exp_part << mant_bits) | mant_part); +} +Use a new file to store these options, and make sure I can switch between these options. +3. Adapt Sparse attentions to the vortex this include: +(1) Naive Sparse Attention /scr/dataset/yuke/xinrui/Sparse-benchmark/Flash-Sparse-Attention/nsa_ref +(2) Flash Sparse Attention /scr/dataset/yuke/xinrui/Sparse-benchmark/Flash-Sparse-Attention/fsa +(3) FlashMoBA /scr/dataset/yuke/xinrui/Sparse-benchmark/flash-moba +Need to implement the whole sparse attention kernel, use their attention backend. Replace forward extend + + +# How to use the custom mapping function: +--- + Writing a Custom Mapping Function + + All mapping logic lives in csrc/topk_mapping.cuh. To add a new mode: + + Step 1: Add to the enum in topk_mapping.cuh: + + enum TopKMappingMode { + MAPPING_NONE = 0, + MAPPING_LUT_CDF = 1, + MAPPING_QUANTILE = 2, + MAPPING_POWER = 3, + MAPPING_LOG = 4, + MAPPING_CUSTOM = 5, // <-- your new mode + }; + + Step 2: Write your __device__ mapping function. It must take a float score and return a uint8_t bin index (0–255). 
The mapping must be monotonic (order-preserving) to ensure + correctness: + + __device__ __forceinline__ + uint8_t map_custom(float x) { + // Example: sqrt transform + float mapped = copysignf(sqrtf(fabsf(x)), x); + return convert_to_uint8(mapped); + } + + Step 3: Add a case to the dispatcher mapped_convert_to_uint8(): + + __device__ __forceinline__ + uint8_t mapped_convert_to_uint8(float x, const TopKMappingParams& params) { + switch (params.mode) { + // ... existing cases ... + case MAPPING_CUSTOM: + return map_custom(x); + default: + return convert_to_uint8(x); + } + } + + Step 4: Update verify_algo.py to accept the new mode value. In parse_args(), change the choices: + + parser.add_argument("--topk-mapping-mode", type=int, default=0, + choices=[0, 1, 2, 3, 4, 5], # add 5 + ...) + + Step 5: Rebuild and test: + + pip install -e . + python examples/verify_algo.py --topk-type sglang --topk-mapping-mode 5 ... + + Key constraint: Your mapping function only affects Stage 1 (coarse 256-bin histogram). Stage 2 refinement always uses raw float bits via convert_to_uint32(), so the final + top-K selection is always correct regardless of your mapping. The goal is to make the Stage 1 histogram more uniform so fewer elements land in the threshold bin, reducing + Stage 2 work. + + If your custom mapping needs extra parameters (like a tensor or scalar), add them to the TopKMappingParams struct, pass them through topk_output_sglang() host function, update + register.h/register.cc bindings, and read them from ctx in output_func.py. + + 1. csrc/topk_sglang.cu — New CUDA kernel + + - TopKHitRate_Kernel: Stage-1-only kernel with mapping support. Builds 256-bin histogram, writes raw histogram to global memory, runs cumsum to find threshold bin, then + computes stage1_resolved = nblk - items_in_threshold_bin. No Stage 2 needed. + - topk_hit_rate(): Host entry point mirroring topk_output_sglang() for mapping param construction and dtype dispatch. + + 2. 
csrc/register.h + csrc/register.cc — PyBind11 bindings + + - Added topk_hit_rate declaration and m.def(...) binding with default args for mapping params. + + 3. benchmarks/bench_topk.py — Benchmark integration + + - Added compute_hit_rate_stats() helper for per-segment resolution rate + histogram stats. + - Added --hit-rate CLI flag. When enabled, iterates over available mapping modes (0, 3, 4 always; 1 if LUT provided; 2 if quantiles provided) and prints a comparison table. + - Results stored in config_results["hit_rate"] for JSON output. + + 4. benchmarks/analyze_topk_distribution.py — New visualization script + + - 5 plot functions: bin distribution, heatmap, before/after mapping, summary table, mode comparison. + - Loads from --profile-npz and/or --bench-json. + + 5. End-to-end integration (your request) + + - vortex_torch/indexer/context.py: Added topk_hit_rate_enabled slot, populated from sa.vortex_topk_hit_rate (default False). + - vortex_torch/indexer/output_func.py: After the main topk kernel call, if ctx.topk_hit_rate_enabled is True and topk_type is "sglang", it calls topk_hit_rate() and stores + results in self.last_hit_rate_stats / self.last_hit_rate_histograms. Zero overhead when disabled — just a getattr check that short-circuits. + + To enable during inference, set vortex_topk_hit_rate=True in your SGLang server args. \ No newline at end of file diff --git a/vortex_torch/indexer/context.py b/vortex_torch/indexer/context.py index 78e2923..8142fbc 100644 --- a/vortex_torch/indexer/context.py +++ b/vortex_torch/indexer/context.py @@ -25,6 +25,7 @@ class Context(ContextBase): # misc "indexer_dtype", "topk_val", "page_reserved_bos", "page_reserved_eos", "topk_type", "topk_mapping_mode", "topk_mapping_power", "topk_mapping_lut", "topk_mapping_quantiles", + "topk_mapping_noscale", "topk_histogram_enabled", # auxilary memory in graph @@ -76,6 +77,7 @@ class Context(ContextBase): topk_mapping_power: float #: Power exponent for mapping mode 3. 
topk_mapping_lut: object #: Optional uint8[256] LUT tensor for mapping mode 1. topk_mapping_quantiles: object #: Optional float32[256] quantiles tensor for mapping mode 2. + topk_mapping_noscale: bool #: Skip auto-range linear scaling, use fp16 bucketing on f(x) (default False). topk_histogram_enabled: bool #: Enable histogram profiling during inference (default False). # --- auxiliary --- @@ -156,6 +158,7 @@ def create(self, parent: Any, model_runner: Any, *, overwrite: bool = False) -> self.topk_type = getattr(sa, "vortex_topk_type", "naive") self.topk_mapping_mode = getattr(sa, "vortex_topk_mapping_mode", 0) self.topk_mapping_power = getattr(sa, "vortex_topk_mapping_power", 0.5) + self.topk_mapping_noscale = getattr(sa, "vortex_topk_mapping_noscale", False) self.topk_histogram_enabled = getattr(sa, "vortex_topk_histogram", False) device = getattr(model_runner, "device", "cpu") diff --git a/vortex_torch/indexer/output_func.py b/vortex_torch/indexer/output_func.py index e4208dc..53e9717 100644 --- a/vortex_torch/indexer/output_func.py +++ b/vortex_torch/indexer/output_func.py @@ -249,6 +249,7 @@ def execute(self, x: torch.Tensor, o: torch.Tensor, ctx: Context) -> torch.Tenso mapping_power = getattr(ctx, 'topk_mapping_power', 0.5) mapping_lut = getattr(ctx, 'topk_mapping_lut', None) mapping_quantiles = getattr(ctx, 'topk_mapping_quantiles', None) + mapping_noscale = getattr(ctx, 'topk_mapping_noscale', False) # UNSET sentinel is not a valid torch.Tensor — coerce to None if mapping_lut is UNSET: mapping_lut = None @@ -269,6 +270,7 @@ def execute(self, x: torch.Tensor, o: torch.Tensor, ctx: Context) -> torch.Tenso mapping_power, mapping_lut, mapping_quantiles, + mapping_noscale, ) else: # topk_output (naive): (x, dense_kv_indptr, dense_kv_indices, sparse_kv_indptr, sparse_kv_indices, ...) 
@@ -306,6 +308,7 @@ def execute(self, x: torch.Tensor, o: torch.Tensor, ctx: Context) -> torch.Tenso mapping_power, mapping_lut, mapping_quantiles, + mapping_noscale, ) # Accumulate histograms for offline calibration _calibration_histograms.append(self.last_histograms.cpu().clone()) From 15f1d03578cfa845e322a0ae6b85498e559b19ca Mon Sep 17 00:00:00 2001 From: UED Date: Thu, 2 Apr 2026 04:12:19 +0000 Subject: [PATCH 15/22] enhance TopK mapping with adaptive tail-window mode; modify example scripts to reflect changes in histogram calibration and TopK mapping parameters. --- CLAUDE.md | 172 ------- csrc/register.cc | 3 +- csrc/register.h | 3 +- csrc/topk_mapping.cuh | 33 +- csrc/topk_sglang.cu | 212 ++++++++- csrc/topk_slgang_ori.cu | 546 ---------------------- examples/run_distribution_analysis.sh | 2 +- examples/run_distribution_analysis_new.sh | 2 +- examples/verify_algo.py | 6 +- examples/verify_algo_topk_mapping_new.sh | 25 +- setup.py | 3 + todo.txt | 308 ------------ vortex_torch/indexer/output_func.py | 1 + 13 files changed, 271 insertions(+), 1045 deletions(-) delete mode 100644 CLAUDE.md delete mode 100644 csrc/topk_slgang_ori.cu delete mode 100644 todo.txt diff --git a/CLAUDE.md b/CLAUDE.md deleted file mode 100644 index 585a246..0000000 --- a/CLAUDE.md +++ /dev/null @@ -1,172 +0,0 @@ -# CLAUDE.md - -This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. - -## Project Overview - -Vortex is a lightweight, modular framework for building custom sparse attention algorithms for LLM inference. It provides a PyTorch-like frontend that abstracts away batching, caching, and paged attention, running on optimized backends (FlashInfer, CUDA Graph) via SGLang integration. 
- -## Build & Install - -```bash -# Clone with submodules -git clone -b v1 --recursive - -# Install SGLang dependency (custom fork in third_party/, supports v0.4.9) -cd third_party/sglang && bash install.sh && cd ../../ - -# Install Vortex (editable mode, compiles CUDA extensions for SM_86/SM_89/SM_90) -pip install -e . -``` - -Requires Python >=3.10, torch>=2.7, lighteval[math]==0.12.2. CUDA extensions (`vortex_torch_C`) are built from `csrc/` (register.cc, utils_sglang.cu, topk.cu, topk_sglang.cu). - -## Testing & Verification - -There is no formal test suite (no pytest). Verification is done by running algorithms against SGLang reference output and comparing accuracy on math benchmarks. - -```bash -# Single algorithm verification (from examples/ directory) -python examples/verify_algo.py --trials 2 --topk-val 30 --vortex-module-name block_sparse_attention - -# Full options -python examples/verify_algo.py \ - --trials 8 --topk-val 30 \ - --vortex-module-name block_sparse_attention \ - --model-name Qwen/Qwen3-1.7B \ - --topk-type naive \ - --mem 0.7 - -# Batch test (outputs timestamped logs to examples/results/) -bash examples/verify_algo.sh - -# AIM24 benchmark verification -python examples/verify_aim24.py -``` - -Available `--topk-type` values: `naive` (CUB-based), `sglang` (SGLang-integrated kernel). - -## AI-Powered Algorithm Generation - -```bash -# Generate new sparse attention algorithms via OpenHands (requires LLM_API_KEY env var) -python openhands_gen.py -``` - -Note: Some auto-generated operators may not be fully optimized. Tune `mem_fraction_static` if OOM occurs. - -## Building Documentation - -```bash -make -C docs html -``` - -Uses Sphinx with myst_parser and furo theme. Deployed via GitHub Actions on push to v1 branch. 
- -## Architecture - -### Core Abstraction: vFlow (`vortex_torch/flow/flow.py`) - -All sparse attention algorithms inherit from `vFlow` and implement three methods: - -- **`forward_indexer(q, o, cache, ctx)`** — Compute sparse page indices from queries. Operates on page-packed tensor view `[S, r, c]`. -- **`forward_cache(cache, loc, ctx)`** — Update/summarize custom cache tensors when a page completes. Operates on batch-major view `[B, r, c]`. -- **`create_cache(page_size, head_dim)`** — Declare custom cache tensor shapes as a dict of `{name: (rows, cols)}`. - -Algorithms are registered via `@register("name")` decorator and instantiated with `build_vflow()`. - -### Operator System (`vortex_torch/indexer/`, `vortex_torch/cache/`) - -Operators (`vOp` subclasses) run in two modes: -- **Profile mode**: Pre-compute output shapes and allocate buffers -- **Execute mode**: Perform actual GPU computation - -Operators are split into two parallel hierarchies: -- **Indexer ops** (`vortex_torch/indexer/`): GeMM, GeMV, topK, reduce (Mean/Max/Min/Sum/L2Norm), softmax, elementwise, transpose, save/load -- **Cache ops** (`vortex_torch/cache/`): GeMM, reduce, elementwise, fill, KV buffer setup - -Both use Triton kernels (in respective `triton_kernels/` subdirectories) for GPU execution. - -### Tensor Format (`vortex_torch/abs/tensor.py`) - -`vTensor` wraps `torch.Tensor` with format metadata (BATCHED, RAGGED, PAGED) to enforce layout consistency across operations. - -### Context System (`vortex_torch/abs/context_base.py`) - -`ContextBase` carries per-step runtime state. 
Specialized as: -- `Indexer.Context`: Page layout, head config, hardware info -- `Cache.Context`: Page size, total pages, model info - -### Concrete Algorithms (`vortex_torch/flow/algorithms.py`) - -- **BlockSparseAttention**: Centroid-based routing (query avg → GeMV with centroids → topK) -- **GQABlockSparseAttention**: Grouped-query variant with softmax + group aggregation -- **GQAQuestSparseAttention**: Query-envelope matching using per-page max/min bounds - -### Algorithm Registry (`vortex_torch/flow/registry.py`) - -Algorithms are registered via `@register("name")` and looked up with `get(name)`, `has(name)`, `list_keys()`. Factory: `build_vflow(name)` in `loader.py`. - -### SGLang Integration - -Custom SGLang fork lives in `third_party/sglang` (git submodule, "graph" branch). CUDA extensions in `csrc/` provide PyBind11 bindings for `sglang_plan_decode`, `sglang_plan_prefill`, transpose operations (NH↔HN), and top-K output routing. - -## Key Conventions - -- **Tensor shapes**: Query `[B, H_q, D]`, sparse output `[S_sparse, 1, 1]`, cache indexer-view `[S, r, c]`, cache batch-view `[B, r, c]` -- **GeMM semantics**: `GeMM(x, y)` computes `y @ x^T` (note transposition) -- **Standard cache keys**: `"k"` and `"v"` have inner shape `(page_size, head_dim)`; custom caches declared in `create_cache()` -- **Branch**: Main development is on `v1` - -## Workflow Orchestration - -### 1. Plan Node Default -- Enter plan mode for ANY non-trivial task (3+ steps or architectural decisions) -- If something goes sideways, STOP and re-plan immediately - don't keep pushing -- Use plan mode for verification steps, not just building -- Write detailed specs upfront to reduce ambiguity - -### 2. Subagent Strategy -- Use subagents liberally to keep main context window clean -- Offload research, exploration, and parallel analysis to subagents -- For complex problems, throw more compute at it via subagents -- One tack per subagent for focused execution - -### 3. 
Self-Improvement Loop -- After ANY correction from the user: update `tasks/lessons.md` with the pattern -- Write rules for yourself that prevent the same mistake -- Ruthlessly iterate on these lessons until mistake rate drops -- Review lessons at session start for relevant project - -### 4. Verification Before Done -- Never mark a task complete without proving it works -- Diff behavior between main and your changes when relevant -- Ask yourself: "Would a staff engineer approve this?" -- Run tests, check logs, demonstrate correctness - -### 5. Demand Elegance (Balanced) -- For non-trivial changes: pause and ask "is there a more elegant way?" -- If a fix feels hacky: "Knowing everything I know now, implement the elegant solution" -- Skip this for simple, obvious fixes - don't over-engineer -- Challenge your own work before presenting it - -### 6. Autonomous Bug Fixing -- When given a bug report: just fix it. Don't ask for hand-holding -- Point at logs, errors, failing tests - then resolve them -- Zero context switching required from the user -- Go fix failing CI tests without being told how - -## Task Management - -1. **Plan First**: Write plan to `tasks/todo.md` with checkable items -2. **Verify Plan**: Check in before starting implementation -3. **Track Progress**: Mark items complete as you go -4. **Explain Changes**: High-level summary at each step -5. **Document Results**: Add review section to `tasks/todo.md` -6. **Capture Lessons**: Update `tasks/lessons.md` after corrections - -## Core Principles - -- **Simplicity First**: Make every change as simple as possible. Impact minimal code. -- **No Laziness**: Find root causes. No temporary fixes. Senior developer standards. -- **Minimal Impact**: Changes should only touch what's necessary. Avoid introducing bugs. 
\ No newline at end of file diff --git a/csrc/register.cc b/csrc/register.cc index 0a3c11c..af49aec 100644 --- a/csrc/register.cc +++ b/csrc/register.cc @@ -27,7 +27,8 @@ PYBIND11_MODULE(vortex_torch_C, m){ py::arg("mapping_power") = 0.5, py::arg("mapping_lut") = py::none(), py::arg("mapping_quantiles") = py::none(), - py::arg("mapping_noscale") = false); + py::arg("mapping_noscale") = false, + py::arg("topk_val") = 0); m.def("topk_profile_stage1", &topk_profile_stage1, py::arg("x"), py::arg("dense_kv_indptr"), py::arg("sparse_kv_indptr"), py::arg("dense_kv_indices"), py::arg("sparse_kv_indices"), diff --git a/csrc/register.h b/csrc/register.h index 1a8b820..dae5e82 100644 --- a/csrc/register.h +++ b/csrc/register.h @@ -114,7 +114,8 @@ const int64_t mapping_mode = 0, const double mapping_power = 0.5, std::optional mapping_lut = std::nullopt, std::optional mapping_quantiles = std::nullopt, -const bool mapping_noscale = false +const bool mapping_noscale = false, +const int64_t topk_val = 0 ); void topk_profile_stage1( diff --git a/csrc/topk_mapping.cuh b/csrc/topk_mapping.cuh index 97bc141..8dbb808 100644 --- a/csrc/topk_mapping.cuh +++ b/csrc/topk_mapping.cuh @@ -4,16 +4,24 @@ #include // ============================================================ -// TopK bucket-sort distribution mapping strategies +// TopK bucket-sort Stage-1 remapping strategies // // These transforms remap float scores before Stage 1's 8-bit -// histogram binning, aiming for a more uniform distribution -// across the 256 coarse bins. Stage 2 refinement still uses -// convert_to_uint32() on raw floats, so correctness is preserved. +// histogram binning. 
The primary goal is to maximize coarse-bin +// resolution in the score region that determines the top-k +// cutoff, thereby: +// - shrinking the Stage-1 threshold bin (fewer collisions) +// - reducing COUNTER_NUM_EQUAL / COUNTER_STAGE2_INPUT +// - reducing the number of Stage-2 refine rounds // -// Modes 3/4/6/7 use a data-adaptive linear mapping to [0,255] -// instead of fp16 bit-pattern bucketing, guaranteeing full -// bucket utilization regardless of value range. +// Stage 2 refinement still uses convert_to_uint32() on raw +// floats, so final ordering correctness is always preserved. +// +// Modes 3/4/6/7/9/10 apply a nonlinear transform then linearly +// map the result to [0,255]. Mode 12 (ADAPTIVE_TAIL_WINDOW) +// directly focuses all 256 bins on the competitive upper tail +// estimated from the top-k ratio, collapsing irrelevant +// low-score mass into bin 0. // ============================================================ enum TopKMappingMode { @@ -29,15 +37,19 @@ enum TopKMappingMode { MAPPING_ERF = 9, // erf(alpha * x) MAPPING_TANH = 10, // tanh(alpha * x) MAPPING_SUBTRACT = 11, // subtract pivot, then fp16 bucketing + MAPPING_ADAPTIVE_TAIL_WINDOW = 12, // focus bins on upper tail via sampled quantile }; struct TopKMappingParams { int mode; // TopKMappingMode float power_exp; // For MAPPING_POWER (default 0.5) + // For MAPPING_ADAPTIVE_TAIL_WINDOW: tail expansion + // factor rho (default 4.0). tau_low = Q(1 - rho*k/n). const uint8_t* __restrict__ lut; // [256] byte LUT, or nullptr const float* __restrict__ quantiles; // [256] float quantile breakpoints, or nullptr bool noscale; // Skip auto-range linear scaling, use fp16 bucketing on f(x) int sample_stride; // Pre-pass sampling stride (1=full, 8=1/8, 0=skip) + int target_k; // Top-k value; used by MAPPING_ADAPTIVE_TAIL_WINDOW }; // NOTE: convert_to_uint8() must be defined before including this header. 
@@ -158,6 +170,8 @@ __device__ __forceinline__ uint8_t mapped_convert_to_uint8( return convert_to_uint8_bf16(x); case MAPPING_SUBTRACT: return convert_to_uint8(x - range_min); // range_min repurposed as pivot + case MAPPING_ADAPTIVE_TAIL_WINDOW: + return linear_map_to_uint8(x, range_min, inv_range); default: // MAPPING_NONE return convert_to_uint8(x); } @@ -174,3 +188,8 @@ __device__ __forceinline__ bool needs_auto_range(int mode) { __device__ __forceinline__ bool needs_pivot(int mode) { return (mode == MAPPING_SUBTRACT); } + +// Helper: check if mode is the adaptive tail-window pre-pass +__device__ __forceinline__ bool needs_tail_window(int mode) { + return (mode == MAPPING_ADAPTIVE_TAIL_WINDOW); +} diff --git a/csrc/topk_sglang.cu b/csrc/topk_sglang.cu index 9213016..867efbe 100644 --- a/csrc/topk_sglang.cu +++ b/csrc/topk_sglang.cu @@ -571,6 +571,113 @@ __device__ void fast_topk_vortex( } } __syncthreads(); + } else if (needs_tail_window(mapping.mode)) { + // Adaptive tail-window pre-pass: estimate tau_low = Q(1 - rho*k/n) + // and local_max via a sampled quantile estimator. All 256 coarse bins + // are then allocated to [tau_low, local_max]; scores below tau_low + // collapse into bin 0 via linear_map_to_uint8 clamping. 
+ constexpr int MAX_SAMPLES = 1024; + __shared__ float s_samples[MAX_SAMPLES]; + __shared__ int s_sample_count; + + if (tx == 0) s_sample_count = 0; + __syncthreads(); + + // Compute sampling stride so we collect ~MAX_SAMPLES from the segment + const int desired_stride = (length + MAX_SAMPLES - 1) / MAX_SAMPLES; + const int sample_stride = max(desired_stride, 1); + + // Each thread samples elements and finds local_max simultaneously + float local_max = -__FLT_MAX__; + for (int idx = tx * sample_stride; idx < length; idx += BLOCK_SIZE * sample_stride) { + float val = vortex_to_float(input[idx + row_start]); + local_max = fmaxf(local_max, val); + int slot = ::atomicAdd(&s_sample_count, 1); + if (slot < MAX_SAMPLES) { + s_samples[slot] = val; + } + } + + // Reduce local_max across block + for (int offset = 16; offset > 0; offset >>= 1) + local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); + __shared__ float s_warp_maxs_tw[32]; + { + int warp_id = tx >> 5, lane_id = tx & 31; + if (lane_id == 0) s_warp_maxs_tw[warp_id] = local_max; + } + __syncthreads(); + if (tx < (BLOCK_SIZE >> 5)) { + local_max = s_warp_maxs_tw[tx]; + for (int offset = 16; offset > 0; offset >>= 1) + local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); + if (tx == 0) s_warp_maxs_tw[0] = local_max; + } + __syncthreads(); + local_max = s_warp_maxs_tw[0]; + + int nsamp = min(s_sample_count, MAX_SAMPLES); + + // Simple odd-even transposition sort on the sample buffer. + // nsamp <= 1024, and we have 1024 threads, so each thread + // handles one element. O(nsamp) parallel rounds suffice. + __syncthreads(); + if (nsamp >= 2) { + for (int pass = 0; pass < nsamp; ++pass) { + // Even phase: compare (0,1), (2,3), ... + if (tx * 2 + 1 < nsamp) { + int i = tx * 2; + if (s_samples[i] > s_samples[i + 1]) { + float tmp = s_samples[i]; + s_samples[i] = s_samples[i + 1]; + s_samples[i + 1] = tmp; + } + } + __syncthreads(); + // Odd phase: compare (1,2), (3,4), ... 
+ if (tx * 2 + 2 < nsamp) { + int i = tx * 2 + 1; + if (s_samples[i] > s_samples[i + 1]) { + float tmp = s_samples[i]; + s_samples[i] = s_samples[i + 1]; + s_samples[i + 1] = tmp; + } + } + __syncthreads(); + } + } + + // Estimate tau_low = Q(1 - rho * k / n) + if (tx == 0) { + float rho = mapping.power_exp; // reused as tail expansion factor + if (rho <= 0.0f) rho = 4.0f; + int k = (mapping.target_k > 0) ? mapping.target_k : target_k; + float frac = 1.0f - rho * float(k) / float(length); + frac = fmaxf(frac, 0.0f); // clamp: never go below rank 0 + + float tau_low; + if (nsamp < 4 || frac <= 0.0f) { + // Too few samples or the tail covers everything: full range + tau_low = -__FLT_MAX__; + } else { + float fidx = frac * float(nsamp - 1); + int lo = __float2int_rd(fidx); + lo = min(max(lo, 0), nsamp - 2); + float t = fidx - float(lo); + tau_low = s_samples[lo] * (1.0f - t) + s_samples[lo + 1] * t; + } + + // Fallback: if tau_low >= local_max, use full-range linear mapping + if (tau_low >= local_max) { + // Find the actual minimum from sorted samples + tau_low = (nsamp > 0) ? s_samples[0] : local_max; + } + + float range = local_max - tau_low; + s_range_min = tau_low; + s_range_inv_range = (range > 1e-10f) ? 
255.0f / range : 0.0f; + } + __syncthreads(); } else { if (tx == 0) { s_range_min = 0.0f; s_range_inv_range = 0.0f; } __syncthreads(); @@ -991,6 +1098,96 @@ void TopKHistogram_Kernel( } } __syncthreads(); + } else if (needs_tail_window(mapping.mode)) { + // Adaptive tail-window pre-pass (histogram kernel variant) + constexpr int MAX_SAMPLES_H = 1024; + __shared__ float s_samples_h[MAX_SAMPLES_H]; + __shared__ int s_sample_count_h; + + if (tx == 0) s_sample_count_h = 0; + __syncthreads(); + + const int desired_stride = (nblk + MAX_SAMPLES_H - 1) / MAX_SAMPLES_H; + const int sample_stride_h = max(desired_stride, 1); + + float local_max = -__FLT_MAX__; + for (int idx = tx * sample_stride_h; idx < nblk; idx += BLOCK_SIZE * sample_stride_h) { + float val = vortex_to_float(score_blk[idx]); + local_max = fmaxf(local_max, val); + int slot = ::atomicAdd(&s_sample_count_h, 1); + if (slot < MAX_SAMPLES_H) s_samples_h[slot] = val; + } + + for (int offset = 16; offset > 0; offset >>= 1) + local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); + __shared__ float s_warp_maxs_h[32]; + { + int warp_id = tx >> 5, lane_id = tx & 31; + if (lane_id == 0) s_warp_maxs_h[warp_id] = local_max; + } + __syncthreads(); + if (tx < (BLOCK_SIZE >> 5)) { + local_max = s_warp_maxs_h[tx]; + for (int offset = 16; offset > 0; offset >>= 1) + local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); + if (tx == 0) s_warp_maxs_h[0] = local_max; + } + __syncthreads(); + local_max = s_warp_maxs_h[0]; + + int nsamp = min(s_sample_count_h, MAX_SAMPLES_H); + + __syncthreads(); + if (nsamp >= 2) { + for (int pass = 0; pass < nsamp; ++pass) { + if (tx * 2 + 1 < nsamp) { + int i = tx * 2; + if (s_samples_h[i] > s_samples_h[i + 1]) { + float tmp = s_samples_h[i]; + s_samples_h[i] = s_samples_h[i + 1]; + s_samples_h[i + 1] = tmp; + } + } + __syncthreads(); + if (tx * 2 + 2 < nsamp) { + int i = tx * 2 + 1; + if (s_samples_h[i] > s_samples_h[i + 1]) { + float tmp = 
s_samples_h[i]; + s_samples_h[i] = s_samples_h[i + 1]; + s_samples_h[i + 1] = tmp; + } + } + __syncthreads(); + } + } + + if (tx == 0) { + float rho = mapping.power_exp; + if (rho <= 0.0f) rho = 4.0f; + int k = mapping.target_k; + float frac = (k > 0 && nblk > 0) ? 1.0f - rho * float(k) / float(nblk) : 0.0f; + frac = fmaxf(frac, 0.0f); + + float tau_low; + if (nsamp < 4 || frac <= 0.0f) { + tau_low = -__FLT_MAX__; + } else { + float fidx = frac * float(nsamp - 1); + int lo = __float2int_rd(fidx); + lo = min(max(lo, 0), nsamp - 2); + float t = fidx - float(lo); + tau_low = s_samples_h[lo] * (1.0f - t) + s_samples_h[lo + 1] * t; + } + + if (tau_low >= local_max) { + tau_low = (nsamp > 0) ? s_samples_h[0] : local_max; + } + + float range = local_max - tau_low; + s_range_min = tau_low; + s_range_inv_range = (range > 1e-10f) ? 255.0f / range : 0.0f; + } + __syncthreads(); } else { if (tx == 0) { s_range_min = 0.0f; s_range_inv_range = 0.0f; } __syncthreads(); @@ -1164,6 +1361,7 @@ void topk_output_sglang( mapping.quantiles = nullptr; mapping.noscale = mapping_noscale; mapping.sample_stride = 1; + mapping.target_k = static_cast(topk_val); if (mapping_lut.has_value()) { const auto& lut = mapping_lut.value(); @@ -1233,7 +1431,8 @@ void topk_profile_histogram( const double mapping_power, std::optional mapping_lut, std::optional mapping_quantiles, - const bool mapping_noscale) + const bool mapping_noscale, + const int64_t topk_val) { CHECK_CUDA(x); CHECK_CUDA(dense_kv_indptr); @@ -1252,6 +1451,7 @@ void topk_profile_histogram( mapping.quantiles = nullptr; mapping.noscale = mapping_noscale; mapping.sample_stride = 1; + mapping.target_k = static_cast(topk_val); if (mapping_lut.has_value()) { const auto& lut = mapping_lut.value(); @@ -1305,7 +1505,8 @@ static TopKMappingParams build_mapping_params( std::optional& mapping_lut, std::optional& mapping_quantiles, bool mapping_noscale = false, - int sample_stride = 1) + int sample_stride = 1, + int target_k = 0) { TopKMappingParams 
mapping{}; mapping.mode = static_cast(mapping_mode); @@ -1314,6 +1515,7 @@ static TopKMappingParams build_mapping_params( mapping.quantiles = nullptr; mapping.noscale = mapping_noscale; mapping.sample_stride = sample_stride; + mapping.target_k = target_k; if (mapping_lut.has_value()) { const auto& lut = mapping_lut.value(); @@ -1356,7 +1558,8 @@ void topk_profile_stage1( "topk_profile_stage1: topk_val (", topk_val, ") exceeds VORTEX_MAX_TOPK (", VORTEX_MAX_TOPK, ")"); - auto mapping = build_mapping_params(mapping_mode, mapping_power, mapping_lut, mapping_quantiles, mapping_noscale); + auto mapping = build_mapping_params(mapping_mode, mapping_power, mapping_lut, mapping_quantiles, + mapping_noscale, /*sample_stride=*/1, /*target_k=*/static_cast(topk_val)); dim3 nblks(eff_batch_size); dim3 nthreads(kThreadsPerBlock); @@ -1428,7 +1631,8 @@ void topk_profile_counters( TORCH_CHECK(counters.scalar_type() == at::ScalarType::Int, "counters must be int32"); - auto mapping = build_mapping_params(mapping_mode, mapping_power, mapping_lut, mapping_quantiles, mapping_noscale); + auto mapping = build_mapping_params(mapping_mode, mapping_power, mapping_lut, mapping_quantiles, + mapping_noscale, /*sample_stride=*/1, /*target_k=*/static_cast(topk_val)); dim3 nblks(eff_batch_size); dim3 nthreads(kThreadsPerBlock); diff --git a/csrc/topk_slgang_ori.cu b/csrc/topk_slgang_ori.cu deleted file mode 100644 index 04a2b73..0000000 --- a/csrc/topk_slgang_ori.cu +++ /dev/null @@ -1,546 +0,0 @@ -/** - * @NOTE: This file is adapted from - * https://github.com/tile-ai/tilelang/blob/main/examples/deepseek_v32/topk_selector.py - * We: - * 1. adapt from tilelang to pure cuda - * 2. optimize the performance a little - * 3. 
fix the potential illegal memory access - */ -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include - -namespace { - -constexpr int TopK = 2048; -constexpr int kThreadsPerBlock = 1024; - -#ifdef USE_ROCM -// On ROCm, the per-workgroup LDS budget depends on the target arch, so we inject a -// per-arch value from `setup_rocm.py` via `-DSGL_TOPK_DYNAMIC_SMEM_BYTES=...`. -#ifdef SGL_TOPK_DYNAMIC_SMEM_BYTES -constexpr size_t kSmem = static_cast(SGL_TOPK_DYNAMIC_SMEM_BYTES); -#else -constexpr size_t kSmem = 48 * 1024; // bytes -#endif -#else -// Reduced from 128KB to 32KB to improve occupancy. -// Each radix pass needs at most ~TopK candidates in the threshold bin, -// so 4K entries per round (2 rounds = 8K entries = 32KB) is sufficient. -constexpr size_t kSmem = 8 * 1024 * sizeof(uint32_t); // 32KB (bytes) -#endif - -struct FastTopKParams { - const float* __restrict__ input; // [B, input_stride] - const int32_t* __restrict__ row_starts; // [B] - int32_t* __restrict__ indices; // [B, TopK] - int32_t* __restrict__ lengths; // [B] - int64_t input_stride; -}; - -// when length <= TopK, we can directly write the indices -__device__ void naive_topk_cuda(const float* __restrict__ score, int32_t* __restrict__ indice, int32_t length) { - const auto tid = threadIdx.x; - for (int i = tid; i < TopK; i += kThreadsPerBlock) { - indice[i] = (i < length) ? i : -1; - } -} - -// keep the first `length` entries, set others to -1 -__device__ void naive_topk_transform( - const float* __restrict__ score, - int32_t length, - int32_t* __restrict__ dst_page_table, - const int32_t* __restrict__ src_page_table) { - const auto tid = threadIdx.x; - for (auto i = tid; i < TopK; i += kThreadsPerBlock) { - dst_page_table[i] = (i < length) ? 
src_page_table[i] : -1; - } -} - -// keep the first `length` entries, set others to -1 -__device__ void naive_topk_transform_ragged( - const float* __restrict__ score, int32_t length, int32_t* __restrict__ topk_indices_ragged, int32_t offset) { - const auto tid = threadIdx.x; - for (auto i = tid; i < TopK; i += kThreadsPerBlock) { - topk_indices_ragged[i] = (i < length) ? static_cast(i) + offset : -1; - } -} - -__device__ __forceinline__ auto convert_to_uint8(float x) -> uint8_t { - __half h = __float2half_rn(x); - uint16_t bits = __half_as_ushort(h); - uint16_t key = (bits & 0x8000) ? static_cast(~bits) : static_cast(bits | 0x8000); - return static_cast(key >> 8); -} - -__device__ __forceinline__ auto convert_to_uint32(float x) -> uint32_t { - uint32_t bits = __float_as_uint(x); - return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u); -} - -__device__ void fast_topk_cuda_tl(const float* __restrict__ input, int* __restrict__ index, int row_start, int length) { - // An optimized topk kernel copied from tilelang kernel - // We assume length > TopK here, or it will crash - int topk = TopK; - constexpr auto BLOCK_SIZE = 1024; - constexpr auto RADIX = 256; - constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); - - alignas(128) __shared__ int s_histogram_buf[2][RADIX + 128]; - alignas(128) __shared__ int s_counter; - alignas(128) __shared__ int s_threshold_bin_id; - alignas(128) __shared__ int s_num_input[2]; - - auto& s_histogram = s_histogram_buf[0]; - // allocate for two rounds - extern __shared__ int s_input_idx[][SMEM_INPUT_SIZE]; - - const int tx = threadIdx.x; - - // stage 1: 8bit coarse histogram - if (tx < RADIX + 1) s_histogram[tx] = 0; - __syncthreads(); - - for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - const auto bin = convert_to_uint8(input[idx + row_start]); - ::atomicAdd(&s_histogram[bin], 1); - } - __syncthreads(); - - const auto run_cumsum = [&] { -#pragma unroll 8 - for (int i = 0; i < 8; ++i) { - static_assert(1 << 8 == RADIX); - 
if (C10_LIKELY(tx < RADIX)) { - const auto j = 1 << i; - const auto k = i & 1; - auto value = s_histogram_buf[k][tx]; - if (tx < RADIX - j) { - value += s_histogram_buf[k][tx + j]; - } - s_histogram_buf[k ^ 1][tx] = value; - } - __syncthreads(); - } - }; - - run_cumsum(); - if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { - s_threshold_bin_id = tx; - s_num_input[0] = 0; - s_counter = 0; - } - __syncthreads(); - - const auto threshold_bin = s_threshold_bin_id; - topk -= s_histogram[threshold_bin + 1]; - - if (topk == 0) { - for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - const auto bin = static_cast(convert_to_uint8(input[idx + row_start])); - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&s_counter, 1); - index[pos] = idx; - } - } - __syncthreads(); - return; - } else { - __syncthreads(); - if (tx < RADIX + 1) { - s_histogram[tx] = 0; - } - __syncthreads(); - - for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - const auto raw_input = input[idx + row_start]; - const auto bin = static_cast(convert_to_uint8(raw_input)); - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&s_counter, 1); - index[pos] = idx; - } else if (bin == threshold_bin) { - const auto pos = ::atomicAdd(&s_num_input[0], 1); - /// NOTE: (dark) fuse the histogram computation here - if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { - s_input_idx[0][pos] = idx; - const auto bin = convert_to_uint32(raw_input); - const auto sub_bin = (bin >> 24) & 0xFF; - ::atomicAdd(&s_histogram[sub_bin], 1); - } - } - } - __syncthreads(); - } - - // stage 2: refine with 8bit radix passes -#pragma unroll 4 - for (int round = 0; round < 4; ++round) { - __shared__ int s_last_remain; - const auto r_idx = round % 2; - - // clip here to prevent overflow - const auto _raw_num_input = s_num_input[r_idx]; - const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) ? 
_raw_num_input : int(SMEM_INPUT_SIZE); - - run_cumsum(); - if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { - s_threshold_bin_id = tx; - s_num_input[r_idx ^ 1] = 0; - s_last_remain = topk - s_histogram[tx + 1]; - } - __syncthreads(); - - const auto threshold_bin = s_threshold_bin_id; - topk -= s_histogram[threshold_bin + 1]; - - if (topk == 0) { - for (int i = tx; i < num_input; i += BLOCK_SIZE) { - const auto idx = s_input_idx[r_idx][i]; - const auto offset = 24 - round * 8; - const auto bin = (convert_to_uint32(input[idx + row_start]) >> offset) & 0xFF; - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&s_counter, 1); - index[pos] = idx; - } - } - __syncthreads(); - break; - } else { - __syncthreads(); - if (tx < RADIX + 1) { - s_histogram[tx] = 0; - } - __syncthreads(); - for (int i = tx; i < num_input; i += BLOCK_SIZE) { - const auto idx = s_input_idx[r_idx][i]; - const auto raw_input = input[idx + row_start]; - const auto offset = 24 - round * 8; - const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF; - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&s_counter, 1); - index[pos] = idx; - } else if (bin == threshold_bin) { - if (round == 3) { - const auto pos = ::atomicAdd(&s_last_remain, -1); - if (pos > 0) { - index[TopK - pos] = idx; - } - } else { - const auto pos = ::atomicAdd(&s_num_input[r_idx ^ 1], 1); - if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { - /// NOTE: (dark) fuse the histogram computation here - s_input_idx[r_idx ^ 1][pos] = idx; - const auto bin = convert_to_uint32(raw_input); - const auto sub_bin = (bin >> (offset - 8)) & 0xFF; - ::atomicAdd(&s_histogram[sub_bin], 1); - } - } - } - } - __syncthreads(); - } - } -} - -__global__ __launch_bounds__(kThreadsPerBlock) // topk - void topk_kernel(const FastTopKParams params) { - const auto& [input, row_starts, indices, lengths, input_stride] = params; - const auto bid = static_cast(blockIdx.x); - const auto row_start = row_starts == nullptr ? 
0 : row_starts[bid]; - const auto length = lengths[bid]; - const auto indice = indices + bid * TopK; - const auto score = input + bid * input_stride; - if (length <= TopK) { - return naive_topk_cuda(score, indice, length); - } else { - return fast_topk_cuda_tl(score, indice, row_start, length); - } -} - -__global__ __launch_bounds__(kThreadsPerBlock) // decode - void topk_transform_decode_kernel( - const FastTopKParams params, - int32_t* __restrict__ dst_page_table, - const int32_t* __restrict__ src_page_table, - const int64_t src_stride) { - const auto& [input, _1, _2, lengths, input_stride] = params; - const auto bid = static_cast(blockIdx.x); - const auto tid = threadIdx.x; - const auto row_start = 0; - const auto length = lengths[bid]; - const auto src_page_entry = src_page_table + bid * src_stride; - const auto dst_page_entry = dst_page_table + bid * TopK; - const auto score = input + bid * input_stride; - if (length <= TopK) { - return naive_topk_transform(score, length, dst_page_entry, src_page_entry); - } else { - __shared__ int s_indices[TopK]; - fast_topk_cuda_tl(score, s_indices, row_start, length); - // copy src[s_indices] to dst, we manually unroll here - static_assert(TopK % kThreadsPerBlock == 0); - static_assert(TopK / kThreadsPerBlock == 2); - const auto idx_0 = tid; - const auto pos_0 = s_indices[idx_0]; - dst_page_entry[idx_0] = src_page_entry[pos_0]; - const auto idx_1 = tid + kThreadsPerBlock; - const auto pos_1 = s_indices[idx_1]; - dst_page_entry[idx_1] = src_page_entry[pos_1]; - } -} - -__global__ __launch_bounds__(kThreadsPerBlock) // prefill - void topk_transform_prefill_kernel( - const FastTopKParams params, - int32_t* __restrict__ dst_page_table, - const int32_t* __restrict__ src_page_table, - const int64_t src_stride, - const int32_t* __restrict__ cu_seqlens_q, - const int64_t prefill_bs) { - const auto& [input, row_starts, _, lengths, input_stride] = params; - const auto bid = static_cast(blockIdx.x); - const auto tid = threadIdx.x; - 
const auto length = lengths[bid]; - const auto row_start = row_starts == nullptr ? 0 : row_starts[bid]; - const auto dst_page_entry = dst_page_table + bid * TopK; - const auto score = input + bid * input_stride; - - /// NOTE: prefill bs is usually small, we can just use a simple loop here - /// We ensure that last cu_seqlens is equal to number of blocks launched - __shared__ const int32_t* s_src_page_entry; - if (C10_LIKELY(prefill_bs <= kThreadsPerBlock)) { - if (tid < prefill_bs) { - if (bid >= cu_seqlens_q[tid] && bid < cu_seqlens_q[tid + 1]) { - s_src_page_entry = src_page_table + tid * src_stride; - } - } - } else { - for (int64_t i = tid; i < prefill_bs; i += kThreadsPerBlock) { - if (bid >= cu_seqlens_q[i] && bid < cu_seqlens_q[i + 1]) { - s_src_page_entry = src_page_table + i * src_stride; - } - } - } - __syncthreads(); - const auto src_page_entry = s_src_page_entry; - - if (length <= TopK) { - return naive_topk_transform(score, length, dst_page_entry, src_page_entry); - } else { - __shared__ int s_indices[TopK]; - fast_topk_cuda_tl(score, s_indices, row_start, length); - // copy src[s_indices] to dst, we manually unroll here - static_assert(TopK % kThreadsPerBlock == 0); - static_assert(TopK / kThreadsPerBlock == 2); - const auto idx_0 = tid; - const auto pos_0 = s_indices[idx_0]; - dst_page_entry[idx_0] = src_page_entry[pos_0]; - const auto idx_1 = tid + kThreadsPerBlock; - const auto pos_1 = s_indices[idx_1]; - dst_page_entry[idx_1] = src_page_entry[pos_1]; - } -} - -__global__ __launch_bounds__(kThreadsPerBlock) // prefill, ragged kv - void topk_transform_prefill_ragged_kernel( - const FastTopKParams params, - int32_t* __restrict__ topk_indices_ragged, - const int32_t* __restrict__ topk_indices_offset) { - const auto& [input, row_starts, _, lengths, input_stride] = params; - const auto bid = static_cast(blockIdx.x); - const auto tid = threadIdx.x; - const auto row_start = row_starts == nullptr ? 
0 : row_starts[bid]; - const auto length = lengths[bid]; - const auto dst_indices_entry = topk_indices_ragged + bid * TopK; - const auto score = input + bid * input_stride; - const auto offset = topk_indices_offset[bid]; - - if (length <= TopK) { - return naive_topk_transform_ragged(score, length, dst_indices_entry, offset); - } else { - __shared__ int s_indices[TopK]; - fast_topk_cuda_tl(score, s_indices, row_start, length); - // copy src[s_indices] to dst, we manually unroll here - static_assert(TopK % kThreadsPerBlock == 0); - static_assert(TopK / kThreadsPerBlock == 2); - const auto idx_0 = tid; - const auto pos_0 = s_indices[idx_0]; - dst_indices_entry[idx_0] = pos_0 + offset; - const auto idx_1 = tid + kThreadsPerBlock; - const auto pos_1 = s_indices[idx_1]; - dst_indices_entry[idx_1] = pos_1 + offset; - } -} - -auto get_params( - const at::Tensor& score, - const at::Tensor& lengths, - std::optional row_starts_opt = std::nullopt, - std::optional indices_opt = std::nullopt) -> FastTopKParams { - const auto B = score.size(0); - TORCH_CHECK(score.dim() == 2 && score.stride(1) == 1); - if (row_starts_opt.has_value()) { - const auto& row_starts = row_starts_opt.value(); - TORCH_CHECK(row_starts.dim() == 1); - TORCH_CHECK(row_starts.size(0) == B); - } - TORCH_CHECK(lengths.dim() == 1 && lengths.is_contiguous()); - TORCH_CHECK(lengths.size(0) == B); - int32_t* indices_data_ptr = nullptr; - if (indices_opt.has_value()) { - const auto& indices = indices_opt.value(); - TORCH_CHECK(indices.dim() == 2 && indices.is_contiguous()); - TORCH_CHECK(indices.size(0) == B); - TORCH_CHECK(indices.size(1) == TopK); - indices_data_ptr = indices.data_ptr(); - } - - return FastTopKParams{ - .input = score.data_ptr(), - .row_starts = row_starts_opt.has_value() ? 
row_starts_opt->data_ptr() : nullptr, - .indices = indices_data_ptr, - .lengths = lengths.data_ptr(), - .input_stride = score.stride(0), - }; -} - -template -void setup_kernel_smem_once() { - [[maybe_unused]] - static const auto result = [] { -#ifdef USE_ROCM - // hipify will turn cudaFuncSetAttribute -> hipFuncSetAttribute. On ROCm, - // hipFuncSetAttribute expects `const void*` and hipcc does not accept passing - // a function pointer directly, so cast explicitly. - return ::cudaFuncSetAttribute( - reinterpret_cast(f), ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); -#else - // CUDA: keep original behavior (no cast needed). - return ::cudaFuncSetAttribute(f, ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); -#endif - }(); - TORCH_CHECK(result == cudaSuccess, "set_up_kernel_once failed:", ::cudaGetErrorString(result)); -} - -} // namespace - -#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") - -void fast_topk_interface( - const at::Tensor& score, at::Tensor& indices, const at::Tensor& lengths, std::optional row_starts_opt) { - CHECK_CUDA(score); - CHECK_CUDA(indices); - if (row_starts_opt.has_value()) { - CHECK_CUDA(row_starts_opt.value()); - } - CHECK_CUDA(lengths); - const auto params = get_params(score, lengths, row_starts_opt, indices); - const auto B = score.size(0); - const auto stream = at::cuda::getCurrentCUDAStream().stream(); - const auto grid = dim3{static_cast(B)}; - const auto block = dim3{kThreadsPerBlock}; - setup_kernel_smem_once(); - topk_kernel<<>>(params); - const auto result = cudaGetLastError(); - TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); -} - -void fast_topk_transform_interface( - const at::Tensor& score, - const at::Tensor& lengths, - at::Tensor& dst_page_table, - const at::Tensor& src_page_table, - const at::Tensor& cu_seqlens_q, - std::optional row_starts_opt) { - CHECK_CUDA(score); - CHECK_CUDA(lengths); - CHECK_CUDA(dst_page_table); 
- CHECK_CUDA(src_page_table); - CHECK_CUDA(cu_seqlens_q); - if (row_starts_opt.has_value()) { - CHECK_CUDA(row_starts_opt.value()); - } - const auto params = get_params(score, lengths, row_starts_opt); - const auto B = score.size(0); - TORCH_CHECK(dst_page_table.dim() == 2 && dst_page_table.is_contiguous()); - TORCH_CHECK(src_page_table.dim() == 2 && src_page_table.stride(1) == 1); - TORCH_CHECK(cu_seqlens_q.dim() == 1 && cu_seqlens_q.is_contiguous()); - const auto prefill_bs = cu_seqlens_q.size(0) - 1; - TORCH_CHECK(dst_page_table.size(0) == B); - TORCH_CHECK(dst_page_table.size(1) == TopK); - TORCH_CHECK(src_page_table.size(0) == prefill_bs); - TORCH_CHECK(prefill_bs <= B); // prefill_bs should be smaller than expanded bs - - // launch kernel - const auto stream = at::cuda::getCurrentCUDAStream().stream(); - const auto grid = dim3{static_cast(B)}; - const auto block = dim3{kThreadsPerBlock}; - const auto src_stride = src_page_table.stride(0); - - // dispatch to decode or prefill - // extend and draft extend: row_starts_opt is not null, invokes the prefill kernel - // decode: row_starts_opt is null, invokes the decode kernel - // target verify: row_starts_opt is null, invokes the prefill kernel - const auto is_decode = !row_starts_opt.has_value() && prefill_bs == B; - if (is_decode) { - setup_kernel_smem_once(); - topk_transform_decode_kernel<<>>( - params, dst_page_table.data_ptr(), src_page_table.data_ptr(), src_stride); - } else { - setup_kernel_smem_once(); - topk_transform_prefill_kernel<<>>( - params, - dst_page_table.data_ptr(), - src_page_table.data_ptr(), - src_stride, - cu_seqlens_q.data_ptr(), - prefill_bs); - } - - const auto result = cudaGetLastError(); - TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); -} - -void fast_topk_transform_ragged_interface( - const at::Tensor& score, - const at::Tensor& lengths, - at::Tensor& topk_indices_ragged, - const at::Tensor& topk_indices_offset, - std::optional row_starts_opt) 
{ - CHECK_CUDA(score); - CHECK_CUDA(lengths); - CHECK_CUDA(topk_indices_ragged); - CHECK_CUDA(topk_indices_offset); - if (row_starts_opt.has_value()) { - CHECK_CUDA(row_starts_opt.value()); - } - - const auto params = get_params(score, lengths, row_starts_opt); - const auto B = score.size(0); - TORCH_CHECK(topk_indices_ragged.dim() == 2 && topk_indices_ragged.is_contiguous()); - TORCH_CHECK(topk_indices_offset.dim() == 1); - - TORCH_CHECK(topk_indices_ragged.size(0) == B); - TORCH_CHECK(topk_indices_ragged.size(1) == TopK); - TORCH_CHECK(topk_indices_offset.size(0) == B); - - // launch kernel - const auto stream = at::cuda::getCurrentCUDAStream().stream(); - const auto grid = dim3{static_cast(B)}; - const auto block = dim3{kThreadsPerBlock}; - - setup_kernel_smem_once(); - topk_transform_prefill_ragged_kernel<<>>( - params, topk_indices_ragged.data_ptr(), topk_indices_offset.data_ptr()); - - const auto result = cudaGetLastError(); - TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); -} diff --git a/examples/run_distribution_analysis.sh b/examples/run_distribution_analysis.sh index 3022dda..98557c7 100755 --- a/examples/run_distribution_analysis.sh +++ b/examples/run_distribution_analysis.sh @@ -48,7 +48,7 @@ TOPK_VAL=30 MEM=0.7 ALGO="block_sparse_attention" # The path to the raw_histograms.npy file (set to skip calibration) -REAL_HISTOGRAMS="/scr/dataset/yuke/xinrui/new/vortex_torch/examples/calibration/raw_histograms.npy" +REAL_HISTOGRAMS="" # ── Parse arguments ─────────────────────────────────────────── while [[ $# -gt 0 ]]; do diff --git a/examples/run_distribution_analysis_new.sh b/examples/run_distribution_analysis_new.sh index f0938ff..623bc82 100755 --- a/examples/run_distribution_analysis_new.sh +++ b/examples/run_distribution_analysis_new.sh @@ -36,7 +36,7 @@ MEM=0.7 ALGO="block_sparse_attention" # The path to the raw_histograms.npy file (set to skip calibration) # 
REAL_HISTOGRAMS="/scr/dataset/yuke/xinrui/new/vortex_torch/examples/calibration/raw_histograms.npy" -REAL_HISTOGRAMS="${SCRIPT_DIR}/calibration/raw_histograms.npy" +REAL_HISTOGRAMS="" # ── Parse arguments ─────────────────────────────────────────── while [[ $# -gt 0 ]]; do case "$1" in diff --git a/examples/verify_algo.py b/examples/verify_algo.py index f1c2a2d..fb3e843 100644 --- a/examples/verify_algo.py +++ b/examples/verify_algo.py @@ -254,15 +254,15 @@ def parse_args(): "--topk-mapping-mode", type=int, default=0, - choices=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], - help='TopK mapping mode: 0=none, 1=lut_cdf, 2=quantile, 3=power, 4=log, 5=index_cache, 6=asinh, 7=log1p, 8=trunc8, 9=erf, 10=tanh, 11=subtract (default: 0).', + choices=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + help='TopK mapping mode: 0=none, 1=lut_cdf, 2=quantile, 3=power, 4=log, 5=index_cache, 6=asinh, 7=log1p, 8=trunc8, 9=erf, 10=tanh, 11=subtract, 12=adaptive_tail_window (default: 0).', ) parser.add_argument( "--topk-mapping-power", type=float, default=0.5, - help='Hyperparameter for parametric modes: power exponent (mode 3), beta (mode 7 asinh), alpha (mode 8 log1p). Default: 0.5.', + help='Hyperparameter for parametric modes: power exponent (mode 3), beta (mode 6 asinh), alpha (mode 7 log1p), rho tail expansion (mode 12). 
Default: 0.5.', ) parser.add_argument( diff --git a/examples/verify_algo_topk_mapping_new.sh b/examples/verify_algo_topk_mapping_new.sh index f1c41fe..5c5d6cf 100644 --- a/examples/verify_algo_topk_mapping_new.sh +++ b/examples/verify_algo_topk_mapping_new.sh @@ -24,7 +24,7 @@ sparse_algos=( ) # Path to real-data histograms from calibration (for auto-tuning) -REAL_HISTOGRAMS="/scr/dataset/yuke/xinrui/new/vortex_torch/examples/calibration/raw_histograms.npy" +REAL_HISTOGRAMS="" RESULTS_DIR="results" mkdir -p "${RESULTS_DIR}" @@ -212,6 +212,28 @@ for algo in "${sparse_algos[@]}"; do 2>&1 | tee "${OUTFILE}" done +# ============================================================ +# Step 8: Mode 12 (adaptive_tail_window), rho=4.0 +# ============================================================ +echo "" +echo "============================================================" +echo "Step 8: Mode 12 (adaptive_tail_window), rho=4.0" +echo "============================================================" +for algo in "${sparse_algos[@]}"; do + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_12_${TIMESTAMP}.log" + echo ">>> Mode 12 (adaptive_tail_window) algo=${algo}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --topk-type sglang \ + --topk-mapping-mode 12 \ + --topk-mapping-power 4.0 \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" +done + # ============================================================ # Summary # ============================================================ @@ -226,4 +248,5 @@ echo " Mode 8 (trunc8): (fixed)" echo " Mode 9 (erf): alpha = ${BEST_POWER_9} (autotuned)" echo " Mode 10 (tanh): alpha = ${BEST_POWER_10} (autotuned)" echo " Mode 11 (subtract): (fixed)" +echo " Mode 12 (tail_win): rho = 4.0" echo "============================================================" diff --git a/setup.py b/setup.py index 649f0a0..9c2186b 100644 --- a/setup.py +++ b/setup.py @@ -27,6 +27,9 @@ 
'-gencode=arch=compute_86,code=sm_86', '-gencode=arch=compute_89,code=sm_89', '-gencode=arch=compute_90,code=sm_90', + '-gencode=arch=compute_100a,code=sm_100a', + '-gencode=arch=compute_120,code=sm_120' + ], }, ), diff --git a/todo.txt b/todo.txt deleted file mode 100644 index 53950c3..0000000 --- a/todo.txt +++ /dev/null @@ -1,308 +0,0 @@ -1. -prefill 8k/16k/32k block 16/32/64 block topk (8,16) -qwen series 0.6b, 1.7b, 4b, 8b, 16b, 32b -baselines: -flashinfer-fa2/fa3 flashattention v2/v3 -dense - -NSA: block sparse attention -benchmarking: -flash Sparse Attention -https://github.com/Relaxed-System-Lab/Flash-Sparse-Attention - -https://github.com/mit-han-lab/flash-moba - - -Video generation: VAE fp8 convolution -wan 2.1 vae -1.3B input 480P - -3.17: -(For SOSP 26) -warpper: prefill: -1. ragged paged, warpper -disable_radix_cache -new backend: goal sparse prefill with topk on a new warpper: abandon the previous paged warpper, -apply 1 to the whole prefill sequence - -2. topk -idea: we want to improve the current topk kernel /scr/dataset/yuke/xinrui/new/vortex_torch/csrc/topk_sglang.cu for our project, -we want to map the value in each layer for the topk selection to a new distribution that bucket sort is more efficient on. like to make the values to be more uniform -in each bucket. -the number of the heads has a certain distribution, try to adapt to it. -The mapping function should have a low overhead, or it would damage the end2end efficiency. -Key: first profile the whole process, record the distribution of the value in each layer. You need a profile script for this. save the results. -Then design a novel mapping function(can be easily customize by me), to map the value to a new distribution. 
Don't change the correctness of the sorting, but more efficient for the bucket sort in the /scr/dataset/yuke/xinrui/new/vortex_torch/csrc/topk_sglang.cu -here is some options: -Option A: Adaptive Bit Selection (2-Pass Min/Max on uint32 Key) - -Core idea: Instead of always extracting the top 8 bits of the fp16 key, extract the 8 most-significant varying bits of the full 32-bit key by finding the actual key range within each segment. - -Algorithm: -Pass 1: Parallel warp-reduction to find min_key and max_key over convert_to_uint32(x) for the segment (~N/1024 iterations per thread, same as current Stage 1) -Compute shift = max(0, 31 - clz(max_key - min_key) - 8) (find the bit position of the 8 most-significant differing bits) - -Pass 2: bin = ((convert_to_uint32(x) - min_key) >> shift) & 0xFF — this uses ALL 32 bits of float precision for binning -Overhead: ~2x data reads for Stage 1 (one extra scan for min/max). Min/max reduction can be done with __shfl_xor_sync followed by a single atomicMin/atomicMax in shared memory — very efficient. - -Pros: -Uses full 32-bit float precision instead of just 8 fp16 bits (up to 2^24 = 16M effective resolution levels instead of 256) -Perfectly adaptive to any data range — no calibration needed - -Pure integer arithmetic in the hot loop (shift + subtract + mask) -Guaranteed monotonic (linear mapping of uint32 keys) -Cons: -2x memory bandwidth for Stage 1 (the min/max pass re-reads all data) -Doesn't guarantee perfectly uniform bin counts (a skewed distribution within [min, max] still skews bins) -Expected quality: Excellent for the real distribution (where the range is narrow but contains many distinct float values that the 8-bit fp16 extraction cannot distinguish). 
-Implementation sketch: - -// Pass 1: find min/max key via parallel reduction -__shared__ uint32_t s_min_key, s_max_key; -if (tx == 0) { s_min_key = 0xFFFFFFFF; s_max_key = 0; } -__syncthreads(); -for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - uint32_t k = convert_to_uint32(input[idx + row_start]); - atomicMin(&s_min_key, k); - atomicMax(&s_max_key, k); -} -__syncthreads(); -uint32_t range = s_max_key - s_min_key; -int shift = max(0, 31 - __clz(range | 1) - 7); // 8 MSBs of range - -// Pass 2: histogram with adaptive bins -for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - uint32_t k = convert_to_uint32(input[idx + row_start]); - uint8_t bin = ((k - s_min_key) >> shift) & 0xFF; - atomicAdd(&s_histogram[bin], 1); -} - - - -Option B: Strided Sampling + Approximate CDF Equalization -Core idea: Read a sparse sample (every S-th element, e.g. S=32) to build an approximate histogram, compute a CDF-equalized LUT from it in shared memory, then apply the LUT during the full scan. -Algorithm: -Sampling pass: Read every 32nd element, build approximate 256-bin histogram in shared memory -In-block LUT construction: 256 threads compute prefix sum -> CDF -> equalized LUT (8 iterations of parallel prefix sum, same as existing run_cumsum) -Full scan: Apply s_lut[convert_to_uint8(x)] for each element -Overhead: ~(1/32 + 1)x = ~1.03x memory reads. LUT construction is ~8 syncthreads (trivial). The hot-loop cost is identical to existing Mode 1 (one shared memory lookup). -Pros: -Near-zero extra bandwidth (only 3% overhead from sampling) -Fully adaptive — no offline calibration needed -Self-tuning: the LUT is computed from the current segment's own data -Hot-loop cost identical to existing LUT mode (1 shared memory read) -Cons: -Approximate: sampling introduces noise in the estimated CDF (especially for small segments) -More complex control flow (3 phases in Stage 1) -For very short segments (<1024 elements), the sample may be too small -Expected quality: Very good. 
With 1/32 sampling on a segment of 4096+ elements, the CDF estimate has ~128+ samples — sufficient for a good 256-entry LUT. -Implementation sketch: - -// Phase 0: sampled histogram -__shared__ int s_sample_hist[256]; -if (tx < 256) s_sample_hist[tx] = 0; -__syncthreads(); -for (int idx = tx * 32; idx < length; idx += BLOCK_SIZE * 32) { - uint8_t bin = convert_to_uint8(input[idx + row_start]); - atomicAdd(&s_sample_hist[bin], 1); -} -__syncthreads(); - -// Phase 0.5: compute equalized LUT from sampled histogram -__shared__ uint8_t s_eq_lut[256]; -// ... prefix sum on s_sample_hist -> CDF -> s_eq_lut[i] = floor(CDF(i) * 255) - -// Phase 1: full scan with equalized LUT -for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - uint8_t raw_bin = convert_to_uint8(input[idx + row_start]); - uint8_t eq_bin = s_eq_lut[raw_bin]; - atomicAdd(&s_histogram[eq_bin], 1); -} - - -Option C: Online 2-Pass Histogram Equalization -Core idea: The "gold standard" — build the exact histogram first, compute CDF equalization in shared memory, then re-scan with the equalized mapping. -Algorithm: - -Pass 1: Build exact 256-bin histogram using convert_to_uint8(x) (same as current code) -In-block: 256 threads compute prefix sum -> CDF -> equalized LUT (lut[i] = floor(CDF(i) * 255)) - -Pass 2: Re-scan ALL elements, apply s_lut[convert_to_uint8(x)], build equalized histogram -> find threshold -Overhead: Exactly 2x data reads for Stage 1. LUT construction is negligible (~8 __syncthreads steps). -Pros: -Optimal histogram equalization — produces a provably uniform distribution (each equalized bin has almost exactly N/256 elements) -No calibration, no approximation — perfect adaptation to any input distribution -Monotonic (CDF is monotonically non-decreasing) - -Cons: -2x memory bandwidth for Stage 1 (dominant cost) -May not be worth it if Stage 2 is already fast (the overhead of the extra pass might exceed the savings from a smaller threshold bin) -Expected quality: Optimal. 
The threshold bin after equalization will have ~N/256 elements regardless of input distribution. -When this is worthwhile: When the original hot bin contains a very large fraction of elements (e.g., >20% of N), the savings from nearly eliminating Stage 2 easily outweigh the extra read pass. Given real data's Gini=0.809, this is likely the case. - -Option D: Temporal CDF Caching (Self-Calibrating LUT) -Core idea: The score distribution per head is relatively stable across adjacent decoding steps. After each kernel invocation, write the histogram to global memory. At the start of the NEXT invocation, load it and compute the CDF-equalized LUT. -Algorithm: -At kernel launch, load the previous iteration's histogram from a persistent global buffer prev_hist[head_id][256] -Compute equalized LUT in shared memory (prefix sum -> CDF -> LUT) -Stage 1 uses this LUT (same as Mode 1) -After Stage 1's histogram is built, write it to prev_hist[head_id][256] for the next iteration -Overhead: 1 shared memory lookup per element (identical to existing Mode 1). Plus ~256 int32 reads + ~256 int32 writes per kernel launch (negligible). -Pros: -Near-zero per-element overhead (shared memory LUT lookup) -Self-calibrating — no offline calibration step needed -Adapts to distribution changes over time (with 1-step lag) -Builds directly on existing Mode 1 infrastructure -Very low implementation complexity - -Cons: -1-step lag: the LUT is based on the previous step's data (cold start on first iteration) -Requires persistent global memory buffer (~256 * 4 bytes per head) -May produce suboptimal LUT if the distribution changes rapidly between steps -Expected quality: Very good after the first few iterations. Attention score distributions evolve slowly during generation, so the 1-step lag has minimal impact. 
-Implementation changes: -Add prev_histogram pointer to TopKMappingParams -Add MAPPING_TEMPORAL_CDF = 7 mode -In kernel: load prev histogram -> compute LUT -> use LUT -> write current histogram -Python side: allocate persistent [num_heads, 256] int32 buffer, pass to kernel - -Option E: Adaptive Exponent-Mantissa Bit Packing -Core idea: The current 8-bit extraction uses 5 exponent + 2 mantissa bits from fp16. When the actual exponent range is narrow (e.g., only 2-3 distinct exponents), most of those 5 exponent bits are wasted. Dynamically reallocate bits: use fewer for the exponent, more for the mantissa. -Algorithm: -Calibration or per-block scan: Determine exponent range [E_min, E_max] of the scores -Choose bit layout based on range width: -Range 1-2 exponents: 1 exp bit + 6 mantissa bits + 1 sign = 64 bins/exponent (vs 4 currently) → 16x improvement -Range 3-4: 2 exp + 5 mantissa + 1 sign -Range 5-8: 3 exp + 4 mantissa + 1 sign -Range 9-16: 4 exp + 3 mantissa + 1 sign -Wider: original 5 exp + 2 mantissa + 1 sign - -Apply: bin = ((exp - E_min) << mantissa_bits) | (mantissa >> (10 - mantissa_bits)) for positive values (with sign-magnitude ordering) -Overhead: Very low (~5-8 integer instructions per element: extract exponent, subtract base, shift, OR with mantissa). No extra memory reads. No LUT. -Pros - -Extremely low overhead — pure register-level bit manipulation -No extra memory reads or shared memory usage -Monotonic (order-preserving within each exponent, and across exponents) -Up to 16x better bin resolution for narrow distributions -Cons: -Requires knowing E_min/E_max (either calibrated offline, or from a quick per-block reduction) -Not as "perfect" as CDF equalization — distribution within each exponent may still be non-uniform -More complex bit manipulation logic -Expected quality: Very good for the observed real distribution (narrow exponent range → dramatic improvement in bin resolution). Not optimal for arbitrary distributions. 
-Implementation sketch: - -// Assuming E_min, E_max precomputed and passed in TopKMappingParams -__device__ __forceinline__ uint8_t map_adaptive_bits(float x, int e_min, int e_range) { - __half h = __float2half_rn(x); - uint16_t bits = __half_as_ushort(h); - uint16_t key = (bits & 0x8000) ? ~bits : (bits | 0x8000); - int exp_val = (key >> 10) & 0x1F; // 5-bit exponent - int mantissa = key & 0x3FF; // 10-bit mantissa - - // Determine bit allocation based on e_range - int exp_bits, mant_bits; - if (e_range <= 2) { exp_bits = 1; mant_bits = 6; } - else if (e_range <= 4) { exp_bits = 2; mant_bits = 5; } - else if (e_range <= 8) { exp_bits = 3; mant_bits = 4; } - else { return (uint8_t)(key >> 8); } // fallback - - int sign_bit = (key >> 15) & 1; - int exp_part = min((exp_val - e_min), (1 << exp_bits) - 1); - int mant_part = mantissa >> (10 - mant_bits); - return (uint8_t)((sign_bit << 7) | (exp_part << mant_bits) | mant_part); -} -Use a new file to store these options, and make sure I can switch between these options. -3. Adapt Sparse attentions to the vortex this include: -(1) Naive Sparse Attention /scr/dataset/yuke/xinrui/Sparse-benchmark/Flash-Sparse-Attention/nsa_ref -(2) Flash Sparse Attention /scr/dataset/yuke/xinrui/Sparse-benchmark/Flash-Sparse-Attention/fsa -(3) FlashMoBA /scr/dataset/yuke/xinrui/Sparse-benchmark/flash-moba -Need to implement the whole sparse attention kernel, use their attention backend. Replace forward extend - - -# How to use the custom mapping function: ---- - Writing a Custom Mapping Function - - All mapping logic lives in csrc/topk_mapping.cuh. To add a new mode: - - Step 1: Add to the enum in topk_mapping.cuh: - - enum TopKMappingMode { - MAPPING_NONE = 0, - MAPPING_LUT_CDF = 1, - MAPPING_QUANTILE = 2, - MAPPING_POWER = 3, - MAPPING_LOG = 4, - MAPPING_CUSTOM = 5, // <-- your new mode - }; - - Step 2: Write your __device__ mapping function. It must take a float score and return a uint8_t bin index (0–255). 
The mapping must be monotonic (order-preserving) to ensure - correctness: - - __device__ __forceinline__ - uint8_t map_custom(float x) { - // Example: sqrt transform - float mapped = copysignf(sqrtf(fabsf(x)), x); - return convert_to_uint8(mapped); - } - - Step 3: Add a case to the dispatcher mapped_convert_to_uint8(): - - __device__ __forceinline__ - uint8_t mapped_convert_to_uint8(float x, const TopKMappingParams& params) { - switch (params.mode) { - // ... existing cases ... - case MAPPING_CUSTOM: - return map_custom(x); - default: - return convert_to_uint8(x); - } - } - - Step 4: Update verify_algo.py to accept the new mode value. In parse_args(), change the choices: - - parser.add_argument("--topk-mapping-mode", type=int, default=0, - choices=[0, 1, 2, 3, 4, 5], # add 5 - ...) - - Step 5: Rebuild and test: - - pip install -e . - python examples/verify_algo.py --topk-type sglang --topk-mapping-mode 5 ... - - Key constraint: Your mapping function only affects Stage 1 (coarse 256-bin histogram). Stage 2 refinement always uses raw float bits via convert_to_uint32(), so the final - top-K selection is always correct regardless of your mapping. The goal is to make the Stage 1 histogram more uniform so fewer elements land in the threshold bin, reducing - Stage 2 work. - - If your custom mapping needs extra parameters (like a tensor or scalar), add them to the TopKMappingParams struct, pass them through topk_output_sglang() host function, update - register.h/register.cc bindings, and read them from ctx in output_func.py. - - 1. csrc/topk_sglang.cu — New CUDA kernel - - - TopKHitRate_Kernel: Stage-1-only kernel with mapping support. Builds 256-bin histogram, writes raw histogram to global memory, runs cumsum to find threshold bin, then - computes stage1_resolved = nblk - items_in_threshold_bin. No Stage 2 needed. - - topk_hit_rate(): Host entry point mirroring topk_output_sglang() for mapping param construction and dtype dispatch. - - 2. 
csrc/register.h + csrc/register.cc — PyBind11 bindings - - - Added topk_hit_rate declaration and m.def(...) binding with default args for mapping params. - - 3. benchmarks/bench_topk.py — Benchmark integration - - - Added compute_hit_rate_stats() helper for per-segment resolution rate + histogram stats. - - Added --hit-rate CLI flag. When enabled, iterates over available mapping modes (0, 3, 4 always; 1 if LUT provided; 2 if quantiles provided) and prints a comparison table. - - Results stored in config_results["hit_rate"] for JSON output. - - 4. benchmarks/analyze_topk_distribution.py — New visualization script - - - 5 plot functions: bin distribution, heatmap, before/after mapping, summary table, mode comparison. - - Loads from --profile-npz and/or --bench-json. - - 5. End-to-end integration (your request) - - - vortex_torch/indexer/context.py: Added topk_hit_rate_enabled slot, populated from sa.vortex_topk_hit_rate (default False). - - vortex_torch/indexer/output_func.py: After the main topk kernel call, if ctx.topk_hit_rate_enabled is True and topk_type is "sglang", it calls topk_hit_rate() and stores - results in self.last_hit_rate_stats / self.last_hit_rate_histograms. Zero overhead when disabled — just a getattr check that short-circuits. - - To enable during inference, set vortex_topk_hit_rate=True in your SGLang server args. 
\ No newline at end of file diff --git a/vortex_torch/indexer/output_func.py b/vortex_torch/indexer/output_func.py index 53e9717..9c7e076 100644 --- a/vortex_torch/indexer/output_func.py +++ b/vortex_torch/indexer/output_func.py @@ -309,6 +309,7 @@ def execute(self, x: torch.Tensor, o: torch.Tensor, ctx: Context) -> torch.Tenso mapping_lut, mapping_quantiles, mapping_noscale, + ctx.topk_val, ) # Accumulate histograms for offline calibration _calibration_histograms.append(self.last_histograms.cpu().clone()) From 080c253430ec0da6ee88063b971ae68174f46c01 Mon Sep 17 00:00:00 2001 From: UED Date: Tue, 7 Apr 2026 06:30:19 +0000 Subject: [PATCH 16/22] Enhance TopK mapping with new modes and original sglang kernel support - Added ExpStretch and TopkWindow modes to analyze_topk_distribution.py and bench_topk.py. - Introduced topk_output_sglang_ori function for original sglang kernel in vortex_torch_C. - Updated autotune and benchmark scripts to include new modes and original kernel. - Modified example scripts to reflect changes in histogram calibration and TopK mapping parameters. 
--- benchmarks/analyze_topk_distribution.py | 4 + benchmarks/autotune_topk_mapping.py | 4 + benchmarks/bench_topk.py | 101 +++- csrc/register.cc | 6 + csrc/register.h | 13 + csrc/topk_mapping.cuh | 21 +- csrc/topk_sglang.cu | 363 +++++++++++- csrc/topk_slgang_ori.cu | 546 ++++++++++++++++++ examples/run_distribution_analysis.sh | 5 +- examples/run_distribution_analysis_new.sh | 9 +- examples/run_topk_benchmark.sh | 6 +- examples/verify_algo.py | 125 +++- examples/verify_algo.sh | 5 + examples/verify_algo_topk_mapping.sh | 44 +- .../verify_algo_topk_mapping_indexcache.sh | 45 -- examples/verify_algo_topk_mapping_new.sh | 166 +++++- third_party/sglang | 2 +- vortex_torch/indexer/output_func.py | 17 +- 18 files changed, 1342 insertions(+), 140 deletions(-) create mode 100644 csrc/topk_slgang_ori.cu delete mode 100644 examples/verify_algo_topk_mapping_indexcache.sh diff --git a/benchmarks/analyze_topk_distribution.py b/benchmarks/analyze_topk_distribution.py index 00cdf28..5531187 100644 --- a/benchmarks/analyze_topk_distribution.py +++ b/benchmarks/analyze_topk_distribution.py @@ -43,6 +43,8 @@ 9: "Erf", 10: "Tanh", 11: "Subtract", + 13: "ExpStretch", + 14: "TopkWindow", } MAPPING_MODE_FORMULAS = { @@ -58,6 +60,8 @@ 9: "Erf: erf(alpha*x)", 10: "Tanh: tanh(alpha*x)", 11: "Subtract: x - pivot (RadiK-style)", + 13: "ExpStretch: exp(alpha*x)", + 14: "TopkWindow: k-aware linear windowing", } diff --git a/benchmarks/autotune_topk_mapping.py b/benchmarks/autotune_topk_mapping.py index d95c839..f04418d 100644 --- a/benchmarks/autotune_topk_mapping.py +++ b/benchmarks/autotune_topk_mapping.py @@ -39,6 +39,8 @@ 7: ("alpha", [0.1, 0.5, 0.75, 1.0, 2.0, 4.0, 8.0]), 9: ("alpha", [0.1, 0.5, 1.0, 2.0, 4.0]), 10: ("alpha", [0.1, 0.5, 1.0, 2.0, 4.0]), + 13: ("alpha", [0.5, 1.0, 2.0, 4.0, 8.0]), + 14: ("rho", [2.0, 4.0, 8.0, 16.0]), } BASELINES = { 0: ("none", 0.5), @@ -64,6 +66,8 @@ 9: "erf", 10: "tanh", 11: "subtract", + 13: "exp_stretch", + 14: "topk_window", } diff --git 
a/benchmarks/bench_topk.py b/benchmarks/bench_topk.py index 675092e..2fd1e31 100644 --- a/benchmarks/bench_topk.py +++ b/benchmarks/bench_topk.py @@ -19,7 +19,7 @@ import torch from vortex_torch_C import ( - topk_output, topk_output_sglang, topk_profile_histogram, + topk_output, topk_output_sglang, topk_output_sglang_ori, topk_profile_histogram, topk_profile_stage1, topk_profile_counters, ) @@ -37,6 +37,8 @@ 9: "Erf", 10: "Tanh", 11: "Subtract", + 13: "ExpStretch", + 14: "TopkWindow", } MAPPING_MODE_FORMULAS = { @@ -52,6 +54,8 @@ 9: "Erf: erf(alpha*x)", 10: "Tanh: tanh(alpha*x)", 11: "Subtract: x - pivot (RadiK-style)", + 13: "ExpStretch: exp(alpha*x)", + 14: "TopkWindow: k-aware linear windowing", } @@ -209,7 +213,7 @@ def _load_autotune_powers(path: str) -> Dict[int, float]: best: Dict[int, dict] = {} for r in data: m = r.get("mode") - if m not in (3, 6, 7, 9, 10): + if m not in (3, 6, 7, 9, 10, 13, 14): continue if has_res_rate: score = r.get("res_rate_mean", 0.0) @@ -229,7 +233,8 @@ def _resolve_mode_power(args, mode: int) -> float: Priority: per-mode CLI flag > autotune JSON > global --mapping-power. 
""" per_mode_flag = {3: args.mapping_power_3, 6: args.mapping_power_6, 7: args.mapping_power_7, - 9: getattr(args, 'mapping_power_9', None), 10: getattr(args, 'mapping_power_10', None)} + 9: getattr(args, 'mapping_power_9', None), 10: getattr(args, 'mapping_power_10', None), + 13: getattr(args, 'mapping_power_13', None), 14: getattr(args, 'mapping_power_14', None)} if mode in per_mode_flag and per_mode_flag[mode] is not None: return per_mode_flag[mode] if hasattr(args, "_autotune_powers") and mode in args._autotune_powers: @@ -285,6 +290,7 @@ def run_benchmark(args) -> List[dict]: # Build kernel list all_kernels = { "naive": "naive", + "sglang_ori": "sglang_ori", "sglang_m0": "sglang_m0", "sglang_scale": "sglang_scale", # mode 3 with p=1.0 (identity + linear auto-range scaling) "sglang_m3": "sglang_m3", @@ -300,6 +306,9 @@ def run_benchmark(args) -> List[dict]: "sglang_m10": "sglang_m10", "sglang_m10_noscale": "sglang_m10_noscale", "sglang_m11": "sglang_m11", + "sglang_m13": "sglang_m13", + "sglang_m13_noscale": "sglang_m13_noscale", + "sglang_m14": "sglang_m14", } if mapping_lut is not None: all_kernels["sglang_m1"] = "sglang_m1" @@ -409,6 +418,20 @@ def run_benchmark(args) -> List[dict]: pages_per_seg, ) result = bench_kernel(topk_output, call_args, args.warmup, args.repeat) + elif kernel_name == "sglang_ori": + call_args = ( + inputs["x"], + inputs["dense_kv_indptr"], + inputs["sparse_kv_indptr"], + inputs["dense_kv_indices"], + inputs["sparse_kv_indices"], + eff_bs, + topk_val, + args.reserved_bos, + args.reserved_eos, + pages_per_seg, + ) + result = bench_kernel(topk_output_sglang_ori, call_args, args.warmup, args.repeat) elif kernel_name == "sglang_scale": call_args = ( inputs["x"], @@ -437,7 +460,7 @@ def run_benchmark(args) -> List[dict]: elif mode == 2: extra_kwargs["mapping_quantiles"] = mapping_quantiles - if mode in (3, 6, 7, 9, 10): + if mode in (3, 6, 7, 9, 10, 13, 14): power = _resolve_mode_power(args, mode) else: power = 0.5 @@ -464,6 +487,8 @@ def 
run_benchmark(args) -> List[dict]: # Build label if kernel_name == "naive": label = "naive" + elif kernel_name == "sglang_ori": + label = "sglang Ori (no remap)" elif kernel_name == "sglang_scale": label = "sglang Scale Only (p=1.0)" else: @@ -471,14 +496,14 @@ def run_benchmark(args) -> List[dict]: m = int(m_str.split("_")[0]) noscale_suffix = " noscale" if kernel_name.endswith("_noscale") else "" mname = MAPPING_MODE_NAMES.get(m, f'm{m}') - if m in (3, 6, 7, 9, 10): - pname = {3: "p", 6: "beta", 7: "alpha", 9: "alpha", 10: "alpha"}[m] + if m in (3, 6, 7, 9, 10, 13, 14): + pname = {3: "p", 6: "beta", 7: "alpha", 9: "alpha", 10: "alpha", 13: "alpha", 14: "rho"}[m] label = f"sglang {mname} ({pname}={_resolve_mode_power(args, m)}){noscale_suffix}" else: label = f"sglang {mname}{noscale_suffix}" - # Sub-phase profiling for sglang kernels - if kernel_name != "naive": + # Sub-phase profiling for sglang kernels (skip ori baseline) + if kernel_name not in ("naive", "sglang_ori"): if kernel_name == "sglang_scale": s1_mode, s1_power = 3, 1.0 s1_lut, s1_q = None, None @@ -487,7 +512,7 @@ def run_benchmark(args) -> List[dict]: s1_mode_str = kernel_name.split("_m")[1] s1_mode = int(s1_mode_str.split("_")[0]) s1_noscale = kernel_name.endswith("_noscale") - if s1_mode in (3, 6, 7, 9, 10): + if s1_mode in (3, 6, 7, 9, 10, 13, 14): s1_power = _resolve_mode_power(args, s1_mode) else: s1_power = 0.5 @@ -581,6 +606,46 @@ def run_benchmark(args) -> List[dict]: 'stage2_input_max': c[:, 5].max().item(), } + # Counter collection for kernels skipped by sub-phase profiling + if kernel_name in ("sglang_ori",) and args.counters: + inputs["sparse_kv_indices"].zero_() + counter_buf = torch.zeros(eff_bs, 6, dtype=torch.int32, device="cuda") + counter_args = ( + inputs["x"], + inputs["dense_kv_indptr"], + inputs["sparse_kv_indptr"], + inputs["dense_kv_indices"], + inputs["sparse_kv_indices"], + counter_buf, + eff_bs, + topk_val, + args.reserved_bos, + args.reserved_eos, + pages_per_seg, + 0, # 
mode 0 (no mapping) — matches ori behavior + 0.5, + None, + None, + False, + ) + topk_profile_counters(*counter_args) + torch.cuda.synchronize() + c = counter_buf.float() + result['counters'] = { + 'threshold_bin_mean': c[:, 0].mean().item(), + 'num_above_mean': c[:, 1].mean().item(), + 'num_equal_mean': c[:, 2].mean().item(), + 'remaining_k_mean': c[:, 3].mean().item(), + 'refine_rounds_mean': c[:, 4].mean().item(), + 'stage2_input_mean': c[:, 5].mean().item(), + 'threshold_bin_max': c[:, 0].max().item(), + 'num_above_max': c[:, 1].max().item(), + 'num_equal_max': c[:, 2].max().item(), + 'remaining_k_max': c[:, 3].max().item(), + 'refine_rounds_max': c[:, 4].max().item(), + 'stage2_input_max': c[:, 5].max().item(), + } + kernel_entries.append((label, kernel_name, result)) config_results["kernels"][kernel_name] = result @@ -703,7 +768,7 @@ def run_benchmark(args) -> List[dict]: extra_lut = mapping_lut if mode == 1 else None extra_q = mapping_quantiles if mode == 2 else None - power = _resolve_mode_power(args, mode) if mode in (3, 6, 7, 9, 10) else 0.5 + power = _resolve_mode_power(args, mode) if mode in (3, 6, 7, 9, 10, 13, 14) else 0.5 topk_profile_histogram( hist_inputs["x"], @@ -716,6 +781,8 @@ def run_benchmark(args) -> List[dict]: power, extra_lut, extra_q, + False, # mapping_noscale + topk_val, # needed for mode 12/14 (tail/topk window) ) torch.cuda.synchronize() @@ -725,8 +792,8 @@ def run_benchmark(args) -> List[dict]: mformula = MAPPING_MODE_FORMULAS.get(mode, mname) mode_stats["name"] = mname mode_stats["formula"] = mformula - if mode in (3, 6, 7, 9, 10): - pname = {3: "p", 6: "beta", 7: "alpha", 9: "alpha", 10: "alpha"}[mode] + if mode in (3, 6, 7, 9, 10, 13, 14): + pname = {3: "p", 6: "beta", 7: "alpha", 9: "alpha", 10: "alpha", 13: "alpha", 14: "rho"}[mode] mode_stats["param"] = f"{pname}={power}" display_name = f"{mname} ({pname}={power})" else: @@ -736,7 +803,7 @@ def run_benchmark(args) -> List[dict]: hist_entries.append((display_name, f"mode 
{mode:2d}", mode_stats)) # Noscale histogram analysis for parametric transform modes - noscale_modes = [m for m in (3, 6, 7, 9, 10) if m in modes_to_test] + noscale_modes = [m for m in (3, 6, 7, 9, 10, 13) if m in modes_to_test] for mode in noscale_modes: ns_hists = torch.zeros(hist_eff_bs, 256, dtype=torch.int32, device="cuda") power = _resolve_mode_power(args, mode) @@ -758,7 +825,7 @@ def run_benchmark(args) -> List[dict]: ns_stats["raw_counts"] = ns_hists.sum(dim=0).tolist() mname = MAPPING_MODE_NAMES.get(mode, f"m{mode}") mformula = MAPPING_MODE_FORMULAS.get(mode, mname) - pname = {3: "p", 6: "beta", 7: "alpha", 9: "alpha", 10: "alpha"}[mode] + pname = {3: "p", 6: "beta", 7: "alpha", 9: "alpha", 10: "alpha", 13: "alpha"}[mode] ns_stats["name"] = f"{mname} noscale" ns_stats["formula"] = mformula ns_stats["param"] = f"{pname}={power}" @@ -831,9 +898,13 @@ def main(): help="Beta for mode 6 asinh (overrides --mapping-power)") parser.add_argument("--mapping-power-7", type=float, default=None, help="Alpha for mode 7 log1p (overrides --mapping-power)") + parser.add_argument("--mapping-power-13", type=float, default=None, + help="Alpha for mode 13 exp_stretch (overrides --mapping-power)") + parser.add_argument("--mapping-power-14", type=float, default=None, + help="Rho for mode 14 topk_window (overrides --mapping-power)") parser.add_argument("--autotune-json", type=str, default=None, help="Path to autotune_results.json — extracts best per-mode hyperparameters " - "(overrides --mapping-power for modes 3/6/7)") + "(overrides --mapping-power for modes 3/6/7/13/14)") parser.add_argument("--lut-path", type=str, default=None, help="Path to .npy uint8[256] LUT for mode=1") parser.add_argument("--quantiles-path", type=str, default=None, help="Path to .npy float32[256] for mode=2") parser.add_argument("--output-json", type=str, default=None, help="Save results to JSON file") diff --git a/csrc/register.cc b/csrc/register.cc index af49aec..b968e9c 100644 --- a/csrc/register.cc 
+++ b/csrc/register.cc @@ -19,6 +19,12 @@ PYBIND11_MODULE(vortex_torch_C, m){ py::arg("mapping_lut") = py::none(), py::arg("mapping_quantiles") = py::none(), py::arg("mapping_noscale") = false); + m.def("topk_output_sglang_ori", &topk_output_sglang_ori, + py::arg("x"), py::arg("dense_kv_indptr"), py::arg("sparse_kv_indptr"), + py::arg("dense_kv_indices"), py::arg("sparse_kv_indices"), + py::arg("eff_batch_size"), py::arg("topk_val"), + py::arg("reserved_bos"), py::arg("reserved_eos"), + py::arg("max_num_pages")); m.def("topk_profile_histogram", &topk_profile_histogram, py::arg("x"), py::arg("dense_kv_indptr"), py::arg("histograms"), py::arg("eff_batch_size"), diff --git a/csrc/register.h b/csrc/register.h index dae5e82..d4a311b 100644 --- a/csrc/register.h +++ b/csrc/register.h @@ -103,6 +103,19 @@ std::optional mapping_quantiles = std::nullopt, const bool mapping_noscale = false ); +void topk_output_sglang_ori( +const at::Tensor& x, +const at::Tensor& dense_kv_indptr, +const at::Tensor& sparse_kv_indptr, +const at::Tensor& dense_kv_indices, +at::Tensor& sparse_kv_indices, +const int64_t eff_batch_size, +const int64_t topk_val, +const int64_t reserved_bos, +const int64_t reserved_eos, +const int64_t max_num_pages +); + void topk_profile_histogram( const at::Tensor& x, const at::Tensor& dense_kv_indptr, diff --git a/csrc/topk_mapping.cuh b/csrc/topk_mapping.cuh index 8dbb808..0893008 100644 --- a/csrc/topk_mapping.cuh +++ b/csrc/topk_mapping.cuh @@ -38,6 +38,8 @@ enum TopKMappingMode { MAPPING_TANH = 10, // tanh(alpha * x) MAPPING_SUBTRACT = 11, // subtract pivot, then fp16 bucketing MAPPING_ADAPTIVE_TAIL_WINDOW = 12, // focus bins on upper tail via sampled quantile + MAPPING_EXP_STRETCH = 13, // exp(alpha * x), concentrates bin resolution on upper tail + MAPPING_TOPK_WINDOW = 14, // k-aware linear windowing: focus bins on [tau_low, max] }; struct TopKMappingParams { @@ -81,6 +83,12 @@ __device__ __forceinline__ float transform_tanh(float x, float alpha) { return 
tanhf(alpha * x); } +__device__ __forceinline__ float transform_exp_stretch(float x, float alpha) { + float z = alpha * x; + z = fminf(z, 80.0f); // prevent float32 overflow (exp(80) ~ 5.5e34) + return expf(z); +} + // ---- Transform dispatcher (returns float, no bucketing) ---- __device__ __forceinline__ float apply_transform(float x, const TopKMappingParams& params) { @@ -91,6 +99,7 @@ __device__ __forceinline__ float apply_transform(float x, const TopKMappingParam case MAPPING_LOG1P: return transform_log1p(x, params.power_exp); case MAPPING_ERF: return transform_erf(x, params.power_exp); case MAPPING_TANH: return transform_tanh(x, params.power_exp); + case MAPPING_EXP_STRETCH: return transform_exp_stretch(x, params.power_exp); default: return x; } } @@ -161,7 +170,8 @@ __device__ __forceinline__ uint8_t mapped_convert_to_uint8( case MAPPING_ASINH: case MAPPING_LOG1P: case MAPPING_ERF: - case MAPPING_TANH: { + case MAPPING_TANH: + case MAPPING_EXP_STRETCH: { float val = apply_transform(x, params); if (params.noscale) return convert_to_uint8(val); return linear_map_to_uint8(val, range_min, inv_range); @@ -171,6 +181,7 @@ __device__ __forceinline__ uint8_t mapped_convert_to_uint8( case MAPPING_SUBTRACT: return convert_to_uint8(x - range_min); // range_min repurposed as pivot case MAPPING_ADAPTIVE_TAIL_WINDOW: + case MAPPING_TOPK_WINDOW: return linear_map_to_uint8(x, range_min, inv_range); default: // MAPPING_NONE return convert_to_uint8(x); @@ -181,7 +192,8 @@ __device__ __forceinline__ uint8_t mapped_convert_to_uint8( __device__ __forceinline__ bool needs_auto_range(int mode) { return (mode == MAPPING_POWER || mode == MAPPING_LOG || mode == MAPPING_ASINH || mode == MAPPING_LOG1P || - mode == MAPPING_ERF || mode == MAPPING_TANH); + mode == MAPPING_ERF || mode == MAPPING_TANH || + mode == MAPPING_EXP_STRETCH); } // Helper: check if a mapping mode needs the pivot pre-pass @@ -193,3 +205,8 @@ __device__ __forceinline__ bool needs_pivot(int mode) { __device__ 
__forceinline__ bool needs_tail_window(int mode) { return (mode == MAPPING_ADAPTIVE_TAIL_WINDOW); } + +// Helper: check if mode is the lightweight topk-window pre-pass +__device__ __forceinline__ bool needs_topk_window(int mode) { + return (mode == MAPPING_TOPK_WINDOW); +} diff --git a/csrc/topk_sglang.cu b/csrc/topk_sglang.cu index 867efbe..1d12c30 100644 --- a/csrc/topk_sglang.cu +++ b/csrc/topk_sglang.cu @@ -450,7 +450,7 @@ __device__ __forceinline__ float vortex_to_float<__nv_bfloat16>(__nv_bfloat16 x) return __bfloat162float(x); } -constexpr int VORTEX_MAX_TOPK = 2048; +constexpr int VORTEX_MAX_TOPK = 4096; // Per-segment diagnostic counters written by WriteCounters mode constexpr int COUNTER_THRESHOLD_BIN = 0; // Stage 1 coarse threshold bin id @@ -461,7 +461,178 @@ constexpr int COUNTER_REFINE_ROUNDS = 4; // Stage 2 rounds used (0 = resolved constexpr int COUNTER_STAGE2_INPUT = 5; // candidates entering first Stage 2 refine round constexpr int NUM_TOPK_COUNTERS = 6; -// Templated version of fast_topk_cuda_tl: +// ====================================================================== +// Ori fast path: zero-overhead topk with no mapping infrastructure. +// Adapted from topk_slgang_ori.cu — uses direct convert_to_uint8() +// for Stage 1 binning with no pre-pass, no LUT, no bin cache. 
+// ====================================================================== +template +__device__ void fast_topk_ori( + const ScoreT* __restrict__ input, + int* __restrict__ index, + int row_start, + int length, + int target_k) +{ + int topk = target_k; + constexpr auto BLOCK_SIZE = 1024; + constexpr auto RADIX = 256; + constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); + + alignas(128) __shared__ int s_histogram_buf[2][RADIX + 128]; + alignas(128) __shared__ int s_counter; + alignas(128) __shared__ int s_threshold_bin_id; + alignas(128) __shared__ int s_num_input[2]; + + auto& s_histogram = s_histogram_buf[0]; + extern __shared__ int s_input_idx[][SMEM_INPUT_SIZE]; + + const int tx = threadIdx.x; + + // Stage 1: 8-bit coarse histogram (direct convert_to_uint8, no mapping) + if (tx < RADIX + 1) s_histogram[tx] = 0; + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = convert_to_uint8(vortex_to_float(input[idx + row_start])); + ::atomicAdd(&s_histogram[bin], 1); + } + __syncthreads(); + + const auto run_cumsum = [&] { +#pragma unroll 8 + for (int i = 0; i < 8; ++i) { + static_assert(1 << 8 == RADIX); + if (C10_LIKELY(tx < RADIX)) { + const auto j = 1 << i; + const auto k = i & 1; + auto value = s_histogram_buf[k][tx]; + if (tx < RADIX - j) { + value += s_histogram_buf[k][tx + j]; + } + s_histogram_buf[k ^ 1][tx] = value; + } + __syncthreads(); + } + }; + + run_cumsum(); + if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { + s_threshold_bin_id = tx; + s_num_input[0] = 0; + s_counter = 0; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + topk -= s_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = static_cast(convert_to_uint8(vortex_to_float(input[idx + row_start]))); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + 
return; + } else { + __syncthreads(); + if (tx < RADIX + 1) { + s_histogram[tx] = 0; + } + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto raw_input = vortex_to_float(input[idx + row_start]); + const auto bin = static_cast(convert_to_uint8(raw_input)); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + const auto pos = ::atomicAdd(&s_num_input[0], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + s_input_idx[0][pos] = idx; + const auto b32 = convert_to_uint32(raw_input); + const auto sub_bin = (b32 >> 24) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } + } + __syncthreads(); + } + + // Stage 2: refine with 8-bit radix passes +#pragma unroll 4 + for (int round = 0; round < 4; ++round) { + __shared__ int s_last_remain; + const auto r_idx = round % 2; + + const auto _raw_num_input = s_num_input[r_idx]; + const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) ? 
_raw_num_input : int(SMEM_INPUT_SIZE); + + run_cumsum(); + if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { + s_threshold_bin_id = tx; + s_num_input[r_idx ^ 1] = 0; + s_last_remain = topk - s_histogram[tx + 1]; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + topk -= s_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(vortex_to_float(input[idx + row_start])) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + break; + } else { + __syncthreads(); + if (tx < RADIX + 1) { + s_histogram[tx] = 0; + } + __syncthreads(); + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto raw_input = vortex_to_float(input[idx + row_start]); + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + if (round == 3) { + const auto pos = ::atomicAdd(&s_last_remain, -1); + if (pos > 0) { + index[target_k - pos] = idx; + } + } else { + const auto pos = ::atomicAdd(&s_num_input[r_idx ^ 1], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + s_input_idx[r_idx ^ 1][pos] = idx; + const auto b32 = convert_to_uint32(raw_input); + const auto sub_bin = (b32 >> (offset - 8)) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } + } + } + __syncthreads(); + } + } +} + +// ====================================================================== +// Templated version of fast_topk_cuda_tl with mapping support: // - ScoreT: float or __nv_bfloat16 // - StopAfterStage1: return after Stage 1 route/filter (for profiling) // - WriteCounters: write diagnostic counters to 
global memory @@ -678,6 +849,47 @@ __device__ void fast_topk_vortex( s_range_inv_range = (range > 1e-10f) ? 255.0f / range : 0.0f; } __syncthreads(); + } else if (needs_topk_window(mapping.mode)) { + // Lightweight topk-window pre-pass: compute min/max of raw values, + // then focus all 256 bins on [tau_low, max] where + // tau_low = max - (max - min) * rho * k / length. + // Like mode 12 but uses a simple heuristic instead of quantile estimation. + float local_min = __FLT_MAX__, local_max = -__FLT_MAX__; + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + float val = vortex_to_float(input[idx + row_start]); + local_min = fminf(local_min, val); + local_max = fmaxf(local_max, val); + } + for (int offset = 16; offset > 0; offset >>= 1) { + local_min = fminf(local_min, __shfl_xor_sync(0xFFFFFFFF, local_min, offset)); + local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); + } + __shared__ float s_warp_mins_tw2[32], s_warp_maxs_tw2[32]; + { + int warp_id = tx >> 5, lane_id = tx & 31; + if (lane_id == 0) { s_warp_mins_tw2[warp_id] = local_min; s_warp_maxs_tw2[warp_id] = local_max; } + } + __syncthreads(); + if (tx < (BLOCK_SIZE >> 5)) { + local_min = s_warp_mins_tw2[tx]; local_max = s_warp_maxs_tw2[tx]; + for (int offset = 16; offset > 0; offset >>= 1) { + local_min = fminf(local_min, __shfl_xor_sync(0xFFFFFFFF, local_min, offset)); + local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); + } + if (tx == 0) { + float rho = mapping.power_exp; + if (rho <= 0.0f) rho = 4.0f; + int k = (mapping.target_k > 0) ? mapping.target_k : target_k; + float frac = rho * float(k) / float(length); + frac = fminf(frac, 1.0f); + float tau_low = local_max - (local_max - local_min) * frac; + if (tau_low >= local_max) tau_low = local_min; + float range = local_max - tau_low; + s_range_min = tau_low; + s_range_inv_range = (range > 1e-10f) ? 
255.0f / range : 0.0f; + } + } + __syncthreads(); } else { if (tx == 0) { s_range_min = 0.0f; s_range_inv_range = 0.0f; } __syncthreads(); @@ -918,6 +1130,42 @@ void TopKOutput_Kernel( } } +// Ori fast-path wrapper: zero mapping overhead +template +__global__ __launch_bounds__(kThreadsPerBlock) +void TopKOutput_Ori_Kernel( + const ScoreT* __restrict__ score, + const int* __restrict__ dense_kv_indptr, + const int* __restrict__ sparse_kv_indptr, + const int* __restrict__ dense_kv_indices, + int* __restrict__ sparse_kv_indices, + const int topk_val, + const int page_reserved_bos, + const int page_reserved_eos) +{ + const int bx = blockIdx.x; + + const int start = dense_kv_indptr[bx] + page_reserved_bos; + const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; + const int nblk = end - start; + if (nblk <= topk_val) return; + + const ScoreT* __restrict__ score_blk = score + start; + const int* __restrict__ idx_blk = dense_kv_indices + start; + int* __restrict__ out_blk = sparse_kv_indices + + sparse_kv_indptr[bx] + + page_reserved_bos; + + __shared__ int s_indices[VORTEX_MAX_TOPK]; + fast_topk_ori(score_blk, s_indices, 0, nblk, topk_val); + __syncthreads(); + + const int tx = threadIdx.x; + for (int i = tx; i < topk_val; i += kThreadsPerBlock) { + out_blk[i] = idx_blk[s_indices[i]]; + } +} + // ====================================================================== // Profiling Stage1 kernel: runs pre-pass + hist + cumsum + route/filter, // stops before Stage 2 refinement (for sub-phase timing) @@ -1382,9 +1630,100 @@ void topk_output_sglang( dim3 nthreads(kThreadsPerBlock); cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + // Fast path for mode 0 (MAPPING_NONE): use ori kernel with zero mapping overhead + if (mapping_mode == MAPPING_NONE) { + if (x.scalar_type() == at::ScalarType::BFloat16) { + setup_kernel_smem_once, kSmem>(); + TopKOutput_Ori_Kernel<__nv_bfloat16><<>>( + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + 
dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + topk_val, + reserved_bos, + reserved_eos); + } else if (x.scalar_type() == at::ScalarType::Float) { + setup_kernel_smem_once, kSmem>(); + TopKOutput_Ori_Kernel<<>>( + x.data_ptr(), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + topk_val, + reserved_bos, + reserved_eos); + } else { + TORCH_CHECK(false, "topk_output: unsupported dtype ", x.scalar_type()); + } + } else { + if (x.scalar_type() == at::ScalarType::BFloat16) { + setup_kernel_smem_once, kSmem>(); + TopKOutput_Kernel<__nv_bfloat16><<>>( + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + topk_val, + reserved_bos, + reserved_eos, + mapping); + } else if (x.scalar_type() == at::ScalarType::Float) { + setup_kernel_smem_once, kSmem>(); + TopKOutput_Kernel<<>>( + x.data_ptr(), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + topk_val, + reserved_bos, + reserved_eos, + mapping); + } else { + TORCH_CHECK(false, "topk_output: unsupported dtype ", x.scalar_type()); + } + } + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, + "topk_output kernel failed: ", ::cudaGetErrorString(result)); +} + +// ====================================================================== +// Explicit ori baseline entry point — always uses the ori fast path +// ====================================================================== +void topk_output_sglang_ori( + const at::Tensor& x, + const at::Tensor& dense_kv_indptr, + const at::Tensor& sparse_kv_indptr, + const at::Tensor& dense_kv_indices, + at::Tensor& sparse_kv_indices, + const int64_t eff_batch_size, + const int64_t topk_val, + const int64_t reserved_bos, + 
const int64_t reserved_eos, + const int64_t max_num_pages) +{ + TORCH_CHECK(topk_val <= VORTEX_MAX_TOPK, + "topk_output_sglang_ori: topk_val (", topk_val, + ") exceeds VORTEX_MAX_TOPK (", VORTEX_MAX_TOPK, ")"); + + CHECK_CUDA(x); + CHECK_CUDA(dense_kv_indptr); + CHECK_CUDA(sparse_kv_indptr); + CHECK_CUDA(dense_kv_indices); + CHECK_CUDA(sparse_kv_indices); + + dim3 nblks(eff_batch_size); + dim3 nthreads(kThreadsPerBlock); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + if (x.scalar_type() == at::ScalarType::BFloat16) { - setup_kernel_smem_once, kSmem>(); - TopKOutput_Kernel<__nv_bfloat16><<>>( + setup_kernel_smem_once, kSmem>(); + TopKOutput_Ori_Kernel<__nv_bfloat16><<>>( reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), dense_kv_indptr.data_ptr(), sparse_kv_indptr.data_ptr(), @@ -1392,11 +1731,10 @@ void topk_output_sglang( sparse_kv_indices.data_ptr(), topk_val, reserved_bos, - reserved_eos, - mapping); + reserved_eos); } else if (x.scalar_type() == at::ScalarType::Float) { - setup_kernel_smem_once, kSmem>(); - TopKOutput_Kernel<<>>( + setup_kernel_smem_once, kSmem>(); + TopKOutput_Ori_Kernel<<>>( x.data_ptr(), dense_kv_indptr.data_ptr(), sparse_kv_indptr.data_ptr(), @@ -1404,17 +1742,14 @@ void topk_output_sglang( sparse_kv_indices.data_ptr(), topk_val, reserved_bos, - reserved_eos, - mapping); + reserved_eos); } else { - TORCH_CHECK(false, - "topk_output: unsupported dtype ", - x.scalar_type()); + TORCH_CHECK(false, "topk_output_sglang_ori: unsupported dtype ", x.scalar_type()); } const auto result = cudaGetLastError(); TORCH_CHECK(result == cudaSuccess, - "topk_output kernel failed: ", ::cudaGetErrorString(result)); + "topk_output_sglang_ori kernel failed: ", ::cudaGetErrorString(result)); } // ====================================================================== diff --git a/csrc/topk_slgang_ori.cu b/csrc/topk_slgang_ori.cu new file mode 100644 index 0000000..04a2b73 --- /dev/null +++ b/csrc/topk_slgang_ori.cu @@ -0,0 +1,546 @@ +/** + * 
@NOTE: This file is adapted from + * https://github.com/tile-ai/tilelang/blob/main/examples/deepseek_v32/topk_selector.py + * We: + * 1. adapt from tilelang to pure cuda + * 2. optimize the performance a little + * 3. fix the potential illegal memory access + */ +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace { + +constexpr int TopK = 2048; +constexpr int kThreadsPerBlock = 1024; + +#ifdef USE_ROCM +// On ROCm, the per-workgroup LDS budget depends on the target arch, so we inject a +// per-arch value from `setup_rocm.py` via `-DSGL_TOPK_DYNAMIC_SMEM_BYTES=...`. +#ifdef SGL_TOPK_DYNAMIC_SMEM_BYTES +constexpr size_t kSmem = static_cast(SGL_TOPK_DYNAMIC_SMEM_BYTES); +#else +constexpr size_t kSmem = 48 * 1024; // bytes +#endif +#else +// Reduced from 128KB to 32KB to improve occupancy. +// Each radix pass needs at most ~TopK candidates in the threshold bin, +// so 4K entries per round (2 rounds = 8K entries = 32KB) is sufficient. +constexpr size_t kSmem = 8 * 1024 * sizeof(uint32_t); // 32KB (bytes) +#endif + +struct FastTopKParams { + const float* __restrict__ input; // [B, input_stride] + const int32_t* __restrict__ row_starts; // [B] + int32_t* __restrict__ indices; // [B, TopK] + int32_t* __restrict__ lengths; // [B] + int64_t input_stride; +}; + +// when length <= TopK, we can directly write the indices +__device__ void naive_topk_cuda(const float* __restrict__ score, int32_t* __restrict__ indice, int32_t length) { + const auto tid = threadIdx.x; + for (int i = tid; i < TopK; i += kThreadsPerBlock) { + indice[i] = (i < length) ? i : -1; + } +} + +// keep the first `length` entries, set others to -1 +__device__ void naive_topk_transform( + const float* __restrict__ score, + int32_t length, + int32_t* __restrict__ dst_page_table, + const int32_t* __restrict__ src_page_table) { + const auto tid = threadIdx.x; + for (auto i = tid; i < TopK; i += kThreadsPerBlock) { + dst_page_table[i] = (i < length) ? 
src_page_table[i] : -1; + } +} + +// keep the first `length` entries, set others to -1 +__device__ void naive_topk_transform_ragged( + const float* __restrict__ score, int32_t length, int32_t* __restrict__ topk_indices_ragged, int32_t offset) { + const auto tid = threadIdx.x; + for (auto i = tid; i < TopK; i += kThreadsPerBlock) { + topk_indices_ragged[i] = (i < length) ? static_cast(i) + offset : -1; + } +} + +__device__ __forceinline__ auto convert_to_uint8(float x) -> uint8_t { + __half h = __float2half_rn(x); + uint16_t bits = __half_as_ushort(h); + uint16_t key = (bits & 0x8000) ? static_cast(~bits) : static_cast(bits | 0x8000); + return static_cast(key >> 8); +} + +__device__ __forceinline__ auto convert_to_uint32(float x) -> uint32_t { + uint32_t bits = __float_as_uint(x); + return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u); +} + +__device__ void fast_topk_cuda_tl(const float* __restrict__ input, int* __restrict__ index, int row_start, int length) { + // An optimized topk kernel copied from tilelang kernel + // We assume length > TopK here, or it will crash + int topk = TopK; + constexpr auto BLOCK_SIZE = 1024; + constexpr auto RADIX = 256; + constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); + + alignas(128) __shared__ int s_histogram_buf[2][RADIX + 128]; + alignas(128) __shared__ int s_counter; + alignas(128) __shared__ int s_threshold_bin_id; + alignas(128) __shared__ int s_num_input[2]; + + auto& s_histogram = s_histogram_buf[0]; + // allocate for two rounds + extern __shared__ int s_input_idx[][SMEM_INPUT_SIZE]; + + const int tx = threadIdx.x; + + // stage 1: 8bit coarse histogram + if (tx < RADIX + 1) s_histogram[tx] = 0; + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = convert_to_uint8(input[idx + row_start]); + ::atomicAdd(&s_histogram[bin], 1); + } + __syncthreads(); + + const auto run_cumsum = [&] { +#pragma unroll 8 + for (int i = 0; i < 8; ++i) { + static_assert(1 << 8 == RADIX); + 
if (C10_LIKELY(tx < RADIX)) { + const auto j = 1 << i; + const auto k = i & 1; + auto value = s_histogram_buf[k][tx]; + if (tx < RADIX - j) { + value += s_histogram_buf[k][tx + j]; + } + s_histogram_buf[k ^ 1][tx] = value; + } + __syncthreads(); + } + }; + + run_cumsum(); + if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { + s_threshold_bin_id = tx; + s_num_input[0] = 0; + s_counter = 0; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + topk -= s_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = static_cast(convert_to_uint8(input[idx + row_start])); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + return; + } else { + __syncthreads(); + if (tx < RADIX + 1) { + s_histogram[tx] = 0; + } + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto raw_input = input[idx + row_start]; + const auto bin = static_cast(convert_to_uint8(raw_input)); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + const auto pos = ::atomicAdd(&s_num_input[0], 1); + /// NOTE: (dark) fuse the histogram computation here + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + s_input_idx[0][pos] = idx; + const auto bin = convert_to_uint32(raw_input); + const auto sub_bin = (bin >> 24) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } + } + __syncthreads(); + } + + // stage 2: refine with 8bit radix passes +#pragma unroll 4 + for (int round = 0; round < 4; ++round) { + __shared__ int s_last_remain; + const auto r_idx = round % 2; + + // clip here to prevent overflow + const auto _raw_num_input = s_num_input[r_idx]; + const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) ? 
_raw_num_input : int(SMEM_INPUT_SIZE); + + run_cumsum(); + if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { + s_threshold_bin_id = tx; + s_num_input[r_idx ^ 1] = 0; + s_last_remain = topk - s_histogram[tx + 1]; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + topk -= s_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(input[idx + row_start]) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + break; + } else { + __syncthreads(); + if (tx < RADIX + 1) { + s_histogram[tx] = 0; + } + __syncthreads(); + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto raw_input = input[idx + row_start]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + if (round == 3) { + const auto pos = ::atomicAdd(&s_last_remain, -1); + if (pos > 0) { + index[TopK - pos] = idx; + } + } else { + const auto pos = ::atomicAdd(&s_num_input[r_idx ^ 1], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + /// NOTE: (dark) fuse the histogram computation here + s_input_idx[r_idx ^ 1][pos] = idx; + const auto bin = convert_to_uint32(raw_input); + const auto sub_bin = (bin >> (offset - 8)) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } + } + } + __syncthreads(); + } + } +} + +__global__ __launch_bounds__(kThreadsPerBlock) // topk + void topk_kernel(const FastTopKParams params) { + const auto& [input, row_starts, indices, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto row_start = row_starts == nullptr ? 
0 : row_starts[bid]; + const auto length = lengths[bid]; + const auto indice = indices + bid * TopK; + const auto score = input + bid * input_stride; + if (length <= TopK) { + return naive_topk_cuda(score, indice, length); + } else { + return fast_topk_cuda_tl(score, indice, row_start, length); + } +} + +__global__ __launch_bounds__(kThreadsPerBlock) // decode + void topk_transform_decode_kernel( + const FastTopKParams params, + int32_t* __restrict__ dst_page_table, + const int32_t* __restrict__ src_page_table, + const int64_t src_stride) { + const auto& [input, _1, _2, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto tid = threadIdx.x; + const auto row_start = 0; + const auto length = lengths[bid]; + const auto src_page_entry = src_page_table + bid * src_stride; + const auto dst_page_entry = dst_page_table + bid * TopK; + const auto score = input + bid * input_stride; + if (length <= TopK) { + return naive_topk_transform(score, length, dst_page_entry, src_page_entry); + } else { + __shared__ int s_indices[TopK]; + fast_topk_cuda_tl(score, s_indices, row_start, length); + // copy src[s_indices] to dst, we manually unroll here + static_assert(TopK % kThreadsPerBlock == 0); + static_assert(TopK / kThreadsPerBlock == 2); + const auto idx_0 = tid; + const auto pos_0 = s_indices[idx_0]; + dst_page_entry[idx_0] = src_page_entry[pos_0]; + const auto idx_1 = tid + kThreadsPerBlock; + const auto pos_1 = s_indices[idx_1]; + dst_page_entry[idx_1] = src_page_entry[pos_1]; + } +} + +__global__ __launch_bounds__(kThreadsPerBlock) // prefill + void topk_transform_prefill_kernel( + const FastTopKParams params, + int32_t* __restrict__ dst_page_table, + const int32_t* __restrict__ src_page_table, + const int64_t src_stride, + const int32_t* __restrict__ cu_seqlens_q, + const int64_t prefill_bs) { + const auto& [input, row_starts, _, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto tid = threadIdx.x; + 
const auto length = lengths[bid]; + const auto row_start = row_starts == nullptr ? 0 : row_starts[bid]; + const auto dst_page_entry = dst_page_table + bid * TopK; + const auto score = input + bid * input_stride; + + /// NOTE: prefill bs is usually small, we can just use a simple loop here + /// We ensure that last cu_seqlens is equal to number of blocks launched + __shared__ const int32_t* s_src_page_entry; + if (C10_LIKELY(prefill_bs <= kThreadsPerBlock)) { + if (tid < prefill_bs) { + if (bid >= cu_seqlens_q[tid] && bid < cu_seqlens_q[tid + 1]) { + s_src_page_entry = src_page_table + tid * src_stride; + } + } + } else { + for (int64_t i = tid; i < prefill_bs; i += kThreadsPerBlock) { + if (bid >= cu_seqlens_q[i] && bid < cu_seqlens_q[i + 1]) { + s_src_page_entry = src_page_table + i * src_stride; + } + } + } + __syncthreads(); + const auto src_page_entry = s_src_page_entry; + + if (length <= TopK) { + return naive_topk_transform(score, length, dst_page_entry, src_page_entry); + } else { + __shared__ int s_indices[TopK]; + fast_topk_cuda_tl(score, s_indices, row_start, length); + // copy src[s_indices] to dst, we manually unroll here + static_assert(TopK % kThreadsPerBlock == 0); + static_assert(TopK / kThreadsPerBlock == 2); + const auto idx_0 = tid; + const auto pos_0 = s_indices[idx_0]; + dst_page_entry[idx_0] = src_page_entry[pos_0]; + const auto idx_1 = tid + kThreadsPerBlock; + const auto pos_1 = s_indices[idx_1]; + dst_page_entry[idx_1] = src_page_entry[pos_1]; + } +} + +__global__ __launch_bounds__(kThreadsPerBlock) // prefill, ragged kv + void topk_transform_prefill_ragged_kernel( + const FastTopKParams params, + int32_t* __restrict__ topk_indices_ragged, + const int32_t* __restrict__ topk_indices_offset) { + const auto& [input, row_starts, _, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto tid = threadIdx.x; + const auto row_start = row_starts == nullptr ? 
0 : row_starts[bid]; + const auto length = lengths[bid]; + const auto dst_indices_entry = topk_indices_ragged + bid * TopK; + const auto score = input + bid * input_stride; + const auto offset = topk_indices_offset[bid]; + + if (length <= TopK) { + return naive_topk_transform_ragged(score, length, dst_indices_entry, offset); + } else { + __shared__ int s_indices[TopK]; + fast_topk_cuda_tl(score, s_indices, row_start, length); + // copy src[s_indices] to dst, we manually unroll here + static_assert(TopK % kThreadsPerBlock == 0); + static_assert(TopK / kThreadsPerBlock == 2); + const auto idx_0 = tid; + const auto pos_0 = s_indices[idx_0]; + dst_indices_entry[idx_0] = pos_0 + offset; + const auto idx_1 = tid + kThreadsPerBlock; + const auto pos_1 = s_indices[idx_1]; + dst_indices_entry[idx_1] = pos_1 + offset; + } +} + +auto get_params( + const at::Tensor& score, + const at::Tensor& lengths, + std::optional row_starts_opt = std::nullopt, + std::optional indices_opt = std::nullopt) -> FastTopKParams { + const auto B = score.size(0); + TORCH_CHECK(score.dim() == 2 && score.stride(1) == 1); + if (row_starts_opt.has_value()) { + const auto& row_starts = row_starts_opt.value(); + TORCH_CHECK(row_starts.dim() == 1); + TORCH_CHECK(row_starts.size(0) == B); + } + TORCH_CHECK(lengths.dim() == 1 && lengths.is_contiguous()); + TORCH_CHECK(lengths.size(0) == B); + int32_t* indices_data_ptr = nullptr; + if (indices_opt.has_value()) { + const auto& indices = indices_opt.value(); + TORCH_CHECK(indices.dim() == 2 && indices.is_contiguous()); + TORCH_CHECK(indices.size(0) == B); + TORCH_CHECK(indices.size(1) == TopK); + indices_data_ptr = indices.data_ptr(); + } + + return FastTopKParams{ + .input = score.data_ptr(), + .row_starts = row_starts_opt.has_value() ? 
row_starts_opt->data_ptr() : nullptr, + .indices = indices_data_ptr, + .lengths = lengths.data_ptr(), + .input_stride = score.stride(0), + }; +} + +template +void setup_kernel_smem_once() { + [[maybe_unused]] + static const auto result = [] { +#ifdef USE_ROCM + // hipify will turn cudaFuncSetAttribute -> hipFuncSetAttribute. On ROCm, + // hipFuncSetAttribute expects `const void*` and hipcc does not accept passing + // a function pointer directly, so cast explicitly. + return ::cudaFuncSetAttribute( + reinterpret_cast(f), ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); +#else + // CUDA: keep original behavior (no cast needed). + return ::cudaFuncSetAttribute(f, ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); +#endif + }(); + TORCH_CHECK(result == cudaSuccess, "set_up_kernel_once failed:", ::cudaGetErrorString(result)); +} + +} // namespace + +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") + +void fast_topk_interface( + const at::Tensor& score, at::Tensor& indices, const at::Tensor& lengths, std::optional row_starts_opt) { + CHECK_CUDA(score); + CHECK_CUDA(indices); + if (row_starts_opt.has_value()) { + CHECK_CUDA(row_starts_opt.value()); + } + CHECK_CUDA(lengths); + const auto params = get_params(score, lengths, row_starts_opt, indices); + const auto B = score.size(0); + const auto stream = at::cuda::getCurrentCUDAStream().stream(); + const auto grid = dim3{static_cast(B)}; + const auto block = dim3{kThreadsPerBlock}; + setup_kernel_smem_once(); + topk_kernel<<>>(params); + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); +} + +void fast_topk_transform_interface( + const at::Tensor& score, + const at::Tensor& lengths, + at::Tensor& dst_page_table, + const at::Tensor& src_page_table, + const at::Tensor& cu_seqlens_q, + std::optional row_starts_opt) { + CHECK_CUDA(score); + CHECK_CUDA(lengths); + CHECK_CUDA(dst_page_table); 
+ CHECK_CUDA(src_page_table); + CHECK_CUDA(cu_seqlens_q); + if (row_starts_opt.has_value()) { + CHECK_CUDA(row_starts_opt.value()); + } + const auto params = get_params(score, lengths, row_starts_opt); + const auto B = score.size(0); + TORCH_CHECK(dst_page_table.dim() == 2 && dst_page_table.is_contiguous()); + TORCH_CHECK(src_page_table.dim() == 2 && src_page_table.stride(1) == 1); + TORCH_CHECK(cu_seqlens_q.dim() == 1 && cu_seqlens_q.is_contiguous()); + const auto prefill_bs = cu_seqlens_q.size(0) - 1; + TORCH_CHECK(dst_page_table.size(0) == B); + TORCH_CHECK(dst_page_table.size(1) == TopK); + TORCH_CHECK(src_page_table.size(0) == prefill_bs); + TORCH_CHECK(prefill_bs <= B); // prefill_bs should be smaller than expanded bs + + // launch kernel + const auto stream = at::cuda::getCurrentCUDAStream().stream(); + const auto grid = dim3{static_cast(B)}; + const auto block = dim3{kThreadsPerBlock}; + const auto src_stride = src_page_table.stride(0); + + // dispatch to decode or prefill + // extend and draft extend: row_starts_opt is not null, invokes the prefill kernel + // decode: row_starts_opt is null, invokes the decode kernel + // target verify: row_starts_opt is null, invokes the prefill kernel + const auto is_decode = !row_starts_opt.has_value() && prefill_bs == B; + if (is_decode) { + setup_kernel_smem_once(); + topk_transform_decode_kernel<<>>( + params, dst_page_table.data_ptr(), src_page_table.data_ptr(), src_stride); + } else { + setup_kernel_smem_once(); + topk_transform_prefill_kernel<<>>( + params, + dst_page_table.data_ptr(), + src_page_table.data_ptr(), + src_stride, + cu_seqlens_q.data_ptr(), + prefill_bs); + } + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); +} + +void fast_topk_transform_ragged_interface( + const at::Tensor& score, + const at::Tensor& lengths, + at::Tensor& topk_indices_ragged, + const at::Tensor& topk_indices_offset, + std::optional row_starts_opt) 
{ + CHECK_CUDA(score); + CHECK_CUDA(lengths); + CHECK_CUDA(topk_indices_ragged); + CHECK_CUDA(topk_indices_offset); + if (row_starts_opt.has_value()) { + CHECK_CUDA(row_starts_opt.value()); + } + + const auto params = get_params(score, lengths, row_starts_opt); + const auto B = score.size(0); + TORCH_CHECK(topk_indices_ragged.dim() == 2 && topk_indices_ragged.is_contiguous()); + TORCH_CHECK(topk_indices_offset.dim() == 1); + + TORCH_CHECK(topk_indices_ragged.size(0) == B); + TORCH_CHECK(topk_indices_ragged.size(1) == TopK); + TORCH_CHECK(topk_indices_offset.size(0) == B); + + // launch kernel + const auto stream = at::cuda::getCurrentCUDAStream().stream(); + const auto grid = dim3{static_cast(B)}; + const auto block = dim3{kThreadsPerBlock}; + + setup_kernel_smem_once(); + topk_transform_prefill_ragged_kernel<<>>( + params, topk_indices_ragged.data_ptr(), topk_indices_offset.data_ptr()); + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); +} diff --git a/examples/run_distribution_analysis.sh b/examples/run_distribution_analysis.sh index 98557c7..6806eca 100755 --- a/examples/run_distribution_analysis.sh +++ b/examples/run_distribution_analysis.sh @@ -48,8 +48,8 @@ TOPK_VAL=30 MEM=0.7 ALGO="block_sparse_attention" # The path to the raw_histograms.npy file (set to skip calibration) +REAL_HISTOGRAMS="/data/datasets/xinrui/My_Projects/vortex_torch/examples/calibration/raw_histograms.npy" REAL_HISTOGRAMS="" - # ── Parse arguments ─────────────────────────────────────────── while [[ $# -gt 0 ]]; do case "$1" in @@ -162,9 +162,10 @@ PYTHONPATH="${SCRIPT_DIR}/.." 
python "${BENCH_DIR}/bench_topk.py" \ --num-kv-heads 8 \ --distributions bucket_uniform normal \ --histogram \ + --counters \ "${BENCH_EXTRA_ARGS[@]}" \ --autotune-json "${AUTOTUNE_JSON}" \ - --filter-kernels naive sglang_m0 sglang_scale sglang_m1 sglang_m2 sglang_m3 sglang_m3_noscale sglang_m4 sglang_m6 sglang_m6_noscale sglang_m7 sglang_m7_noscale sglang_m8 sglang_m9 sglang_m9_noscale sglang_m10 sglang_m10_noscale sglang_m11 \ + --filter-kernels naive sglang_ori sglang_m0 sglang_scale sglang_m1 sglang_m2 sglang_m3 sglang_m3_noscale sglang_m4 sglang_m6 sglang_m6_noscale sglang_m7 sglang_m7_noscale sglang_m8 sglang_m9 sglang_m9_noscale sglang_m10 sglang_m10_noscale sglang_m11 sglang_m13 sglang_m13_noscale sglang_m14 \ --repeat 20 \ --output-json "${BENCH_JSON}" \ 2>&1 | tee "${RUN_DIR}/step3_bench.log" diff --git a/examples/run_distribution_analysis_new.sh b/examples/run_distribution_analysis_new.sh index 623bc82..1f89c0b 100755 --- a/examples/run_distribution_analysis_new.sh +++ b/examples/run_distribution_analysis_new.sh @@ -35,8 +35,8 @@ TOPK_VAL=30 MEM=0.7 ALGO="block_sparse_attention" # The path to the raw_histograms.npy file (set to skip calibration) -# REAL_HISTOGRAMS="/scr/dataset/yuke/xinrui/new/vortex_torch/examples/calibration/raw_histograms.npy" -REAL_HISTOGRAMS="" +REAL_HISTOGRAMS="/data/datasets/xinrui/My_Projects/vortex_torch/examples/calibration/raw_histograms.npy" +# REAL_HISTOGRAMS="" # ── Parse arguments ─────────────────────────────────────────── while [[ $# -gt 0 ]]; do case "$1" in @@ -55,7 +55,7 @@ export CUDA_VISIBLE_DEVICES="${GPU_ID}" RESULTS_DIR="${SCRIPT_DIR}/results" mkdir -p "${RESULTS_DIR}" TIMESTAMP=$(date +%Y%m%d_%H%M%S) -RUN_DIR="${RESULTS_DIR}/dist_analysis_${TIMESTAMP}" +RUN_DIR="${RESULTS_DIR}/dist_analysis_topk${TOPK_VAL}_${TIMESTAMP}" mkdir -p "${RUN_DIR}" echo "============================================================" @@ -120,9 +120,10 @@ PYTHONPATH="${SCRIPT_DIR}/.." 
python "${BENCH_DIR}/bench_topk.py" \ --num-kv-heads 8 \ --distributions bucket_uniform normal \ --histogram \ + --counters \ --real-histograms "${REAL_HIST_PATH}" \ --autotune-json "${AUTOTUNE_JSON}" \ - --filter-kernels naive sglang_m0 sglang_scale sglang_m3 sglang_m3_noscale sglang_m6 sglang_m6_noscale sglang_m7 sglang_m7_noscale sglang_m8 sglang_m9 sglang_m9_noscale sglang_m10 sglang_m10_noscale sglang_m11 \ + --filter-kernels naive sglang_ori sglang_m0 sglang_scale sglang_m3 sglang_m3_noscale sglang_m6 sglang_m6_noscale sglang_m7 sglang_m7_noscale sglang_m8 sglang_m9 sglang_m9_noscale sglang_m10 sglang_m10_noscale sglang_m11 sglang_m13 sglang_m13_noscale sglang_m14 \ --repeat 20 \ --output-json "${BENCH_JSON}" \ 2>&1 | tee "${RUN_DIR}/step3_bench.log" diff --git a/examples/run_topk_benchmark.sh b/examples/run_topk_benchmark.sh index 6ac2b9d..d57e2f1 100755 --- a/examples/run_topk_benchmark.sh +++ b/examples/run_topk_benchmark.sh @@ -48,6 +48,7 @@ ALGO="block_sparse_attention" SKIP_CALIBRATE=false SKIP_KERNEL=false SKIP_E2E=true +BENCHMARKS="amc23" # space-separated list, e.g. 
"amc23 aime24" # ── Parse arguments ─────────────────────────────────────────── while [[ $# -gt 0 ]]; do @@ -58,6 +59,7 @@ while [[ $# -gt 0 ]]; do --mem) MEM="$2"; shift 2 ;; --gpu) GPU_ID="$2"; shift 2 ;; --algo) ALGO="$2"; shift 2 ;; + --benchmark) BENCHMARKS="$2"; shift 2 ;; --skip-calibrate) SKIP_CALIBRATE=true; shift ;; --skip-kernel) SKIP_KERNEL=true; shift ;; --skip-e2e) SKIP_E2E=true; shift ;; @@ -70,7 +72,8 @@ export CUDA_VISIBLE_DEVICES="${GPU_ID}" RESULTS_DIR="${SCRIPT_DIR}/results" mkdir -p "${RESULTS_DIR}" TIMESTAMP=$(date +%Y%m%d_%H%M%S) -RUN_DIR="${RESULTS_DIR}/topk_benchmark_${TIMESTAMP}" +BENCH_LABEL=$(echo "${BENCHMARKS}" | tr ' ' '_') +RUN_DIR="${RESULTS_DIR}/topk_benchmark_${BENCH_LABEL}_${TIMESTAMP}" mkdir -p "${RUN_DIR}" echo "============================================================" @@ -194,6 +197,7 @@ else --trials "${TRIALS}" \ --topk-val "${TOPK_VAL}" \ --model-name "${MODEL_NAME}" \ + --benchmark ${BENCHMARKS} \ --mem "${MEM}" \ "$@" ; } \ 2>&1 | tee "${logfile}" diff --git a/examples/verify_algo.py b/examples/verify_algo.py index fb3e843..32ff5a3 100644 --- a/examples/verify_algo.py +++ b/examples/verify_algo.py @@ -51,6 +51,63 @@ def generate_requests(dataset: Dataset, field_name: str, data_format: str, trial return requests +BENCHMARK_REGISTRY = { + "amc23": { + "type": "jsonl", + "path": "amc23.jsonl", + "prompt_key": "prompt", + "answer_key": "answer", + "question_key": "question", + }, + "aime24": { + "type": "huggingface", + "path": "HuggingFaceH4/aime_2024", + "split": "train", + "field_name": "problem", + "answer_key": "answer", + }, +} + +def _load_benchmark(benchmark_name: str, trials: int, tokenizer=None): + """Load benchmark data and return (prompts, requests) tuple.""" + cfg = BENCHMARK_REGISTRY[benchmark_name] + + if cfg["type"] == "jsonl": + script_dir = os.path.dirname(os.path.abspath(__file__)) + jsonl_path = os.path.join(script_dir, cfg["path"]) + with open(jsonl_path, "r", encoding="utf-8") as f: + requests = 
[json.loads(line) for line in f] + requests = requests * trials + prompts = [req[cfg["prompt_key"]] for req in requests] + return prompts, requests + + elif cfg["type"] == "huggingface": + dataset = load_dataset(cfg["path"], split=cfg["split"]) + hf_requests = generate_requests(dataset, cfg["field_name"], MATH_QUERY_TEMPLATE) + # Normalize keys: ensure "question" and "answer" exist + for req in hf_requests: + if "question" not in req and cfg["field_name"] in req: + req["question"] = req[cfg["field_name"]] + # Build chat-template prompts if tokenizer is provided + if tokenizer is not None: + texts = [x["conversations"] for x in hf_requests] + prompts = [ + tokenizer.apply_chat_template( + text, tokenize=False, add_generation_prompt=True, enable_thinking=True + ) for text in texts + ] * trials + hf_requests = hf_requests * trials + else: + prompts = [ + MATH_QUERY_TEMPLATE.format(Question=x[cfg["field_name"]]) for x in hf_requests + ] * trials + hf_requests = hf_requests * trials + return prompts, hf_requests + + else: + raise ValueError(f"Unknown benchmark type: {cfg['type']}") + + def verify_algos( trials: int = 2, topk_val: int = 30, @@ -67,6 +124,7 @@ def verify_algos( topk_mapping_quantiles_path: str = None, index_cache_shared_layers: list = None, disable_cuda_graph: bool = False, +benchmark: str = "amc23", ): llm = sgl.Engine(model_path=model_name, @@ -90,11 +148,8 @@ def verify_algos( vortex_topk_mapping_quantiles_path=topk_mapping_quantiles_path, vortex_index_cache_shared_layers=index_cache_shared_layers, ) - with open("amc23.jsonl", "r", encoding="utf-8") as f: - requests = [json.loads(line) for line in f] - - requests = requests * trials - prompts = [req["prompt"] for req in requests] + tokenizer = AutoTokenizer.from_pretrained(model_name) if benchmark != "amc23" else None + prompts, requests = _load_benchmark(benchmark, trials, tokenizer=tokenizer) sampling_params = {"temperature": 0.6, "top_p": 0.95, "top_k": 20, "max_new_tokens": 8192} @@ -247,15 +302,15 
@@ def parse_args(): "--topk-type", type=str, default="naive", - choices=["naive", "sglang"], - help='TopK kernel type: "naive" for topk_output, "sglang" for topk_output_sglang (default: "naive").', + choices=["naive", "sglang", "sglang_ori"], + help='TopK kernel type: "naive" for topk_output, "sglang" for topk_output_sglang, "sglang_ori" for original sglang baseline (default: "naive").', ) parser.add_argument( "--topk-mapping-mode", type=int, default=0, - choices=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], - help='TopK mapping mode: 0=none, 1=lut_cdf, 2=quantile, 3=power, 4=log, 5=index_cache, 6=asinh, 7=log1p, 8=trunc8, 9=erf, 10=tanh, 11=subtract, 12=adaptive_tail_window (default: 0).', + choices=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], + help='TopK mapping mode: 0=none, 1=lut_cdf, 2=quantile, 3=power, 4=log, 5=index_cache, 6=asinh, 7=log1p, 8=trunc8, 9=erf, 10=tanh, 11=subtract, 12=adaptive_tail_window, 13=exp_stretch, 14=topk_window (default: 0).', ) parser.add_argument( @@ -287,6 +342,15 @@ def parse_args(): help="Layer IDs that reuse indices from the nearest preceding full layer (skip indexer).", ) + parser.add_argument( + "--benchmark", + type=str, + nargs="+", + default=["amc23"], + help="Benchmark(s) to run. Available: amc23, aime24. 
" + "Use multiple values to run several benchmarks sequentially (default: amc23).", + ) + return parser.parse_args() if __name__ == "__main__": @@ -298,22 +362,31 @@ def parse_args(): args.index_cache_shared_layers = list(range(2, 28, 2)) # [2,4,6,...,26] args.topk_mapping_mode = 0 - summary = verify_algos( - trials=args.trials, - topk_val=args.topk_val, - page_size=args.page_size, - vortex_module_name=args.vortex_module_name, - model_name=args.model_name, - sparse_attention=not(args.full_attention), - mem=args.mem, - kv_cache_dtype=args.kv_cache_dtype, - topk_type=args.topk_type, - topk_mapping_mode=args.topk_mapping_mode, - topk_mapping_power=args.topk_mapping_power, - topk_mapping_lut_path=args.topk_mapping_lut_path, - topk_mapping_quantiles_path=args.topk_mapping_quantiles_path, - index_cache_shared_layers=args.index_cache_shared_layers, - ) - print(summary) + for bench_name in args.benchmark: + if bench_name not in BENCHMARK_REGISTRY: + print(f"WARNING: Unknown benchmark '{bench_name}', skipping. 
Available: {list(BENCHMARK_REGISTRY.keys())}") + continue + print(f"\n{'='*60}") + print(f"Benchmark: {bench_name}") + print(f"{'='*60}") + summary = verify_algos( + trials=args.trials, + topk_val=args.topk_val, + page_size=args.page_size, + vortex_module_name=args.vortex_module_name, + model_name=args.model_name, + sparse_attention=not(args.full_attention), + mem=args.mem, + kv_cache_dtype=args.kv_cache_dtype, + topk_type=args.topk_type, + topk_mapping_mode=args.topk_mapping_mode, + topk_mapping_power=args.topk_mapping_power, + topk_mapping_lut_path=args.topk_mapping_lut_path, + topk_mapping_quantiles_path=args.topk_mapping_quantiles_path, + index_cache_shared_layers=args.index_cache_shared_layers, + benchmark=bench_name, + ) + summary["benchmark"] = bench_name + print(summary) exit(0) \ No newline at end of file diff --git a/examples/verify_algo.sh b/examples/verify_algo.sh index 3edf9b6..ddcd905 100644 --- a/examples/verify_algo.sh +++ b/examples/verify_algo.sh @@ -24,3 +24,8 @@ TIMESTAMP=$(date +%Y%m%d_%H%M%S) --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" done + + TORCH_CUDA_ARCH_LIST="12.0" \ + MAX_JOBS=64 \ + pip install -e . 
--no-build-isolation \ + -Ccmake.args="-DENABLE_BELOW_SM90=OFF" \ No newline at end of file diff --git a/examples/verify_algo_topk_mapping.sh b/examples/verify_algo_topk_mapping.sh index 2370ca1..c0a03c5 100644 --- a/examples/verify_algo_topk_mapping.sh +++ b/examples/verify_algo_topk_mapping.sh @@ -14,7 +14,18 @@ set -e # 9: Erf — y = erf(alpha * x) # 10: Tanh — y = tanh(alpha * x) # 11: Subtract — x - pivot (RadiK-style scatter) -export CUDA_VISIBLE_DEVICES=0 +GPU_ID=0 +BENCHMARKS="amc23" + +while [[ $# -gt 0 ]]; do + case "$1" in + --gpu) GPU_ID="$2"; shift 2 ;; + --benchmark) BENCHMARKS="$2"; shift 2 ;; + *) echo "Unknown option: $1"; exit 1 ;; + esac +done + +export CUDA_VISIBLE_DEVICES="${GPU_ID}" SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" BENCH_DIR="${SCRIPT_DIR}/../benchmarks" @@ -23,13 +34,13 @@ sparse_algos=( "block_sparse_attention" ) -RESULTS_DIR="results" +BENCH_LABEL=$(echo "${BENCHMARKS}" | tr ' ' '_') +RESULTS_DIR="results/${BENCH_LABEL}" mkdir -p "${RESULTS_DIR}" TIMESTAMP=$(date +%Y%m%d_%H%M%S) # Set this to an existing calibration directory to skip re-running calibration. # It must contain lut.npy and quantiles.npy (output of calibrate_topk.py). 
-CALIBRATION_DIR="/scr/dataset/yuke/xinrui/new/vortex_torch/examples/calibration" - +CALIBRATION_DIR="/data/datasets/xinrui/My_Projects/vortex_torch/examples/calibration" # ============================================================ # Baseline: naive topk (mode 0) # ============================================================ @@ -44,6 +55,7 @@ for algo in "${sparse_algos[@]}"; do --model-name Qwen/Qwen3-1.7B \ --topk-type naive \ --topk-mapping-mode 0 \ + --benchmark ${BENCHMARKS} \ --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" done @@ -167,6 +179,7 @@ for algo in "${sparse_algos[@]}"; do --model-name Qwen/Qwen3-1.7B \ --topk-type sglang \ --topk-mapping-mode ${topk_mapping_mode} \ + --benchmark ${BENCHMARKS} \ --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" done @@ -265,4 +278,25 @@ for algo in "${sparse_algos[@]}"; do --topk-mapping-power ${BEST_POWER_10} \ --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" -done \ No newline at end of file +done + +# ============================================================ +# Counter profiling: collect COUNTER_NUM_EQUAL for all modes +# ============================================================ +echo "" +echo "============================================================" +echo "Counter profiling: COUNTER_NUM_EQUAL per mode (topk=30)" +echo "============================================================" +COUNTER_JSON="${RESULTS_DIR}/counters_${TIMESTAMP}.json" +PYTHONPATH="${SCRIPT_DIR}/.." 
python "${BENCH_DIR}/bench_topk.py" \ + --batch-sizes 4 \ + --seq-lens 4096 \ + --topk-vals 30 \ + --num-kv-heads 2 \ + --distributions normal \ + --counters \ + --filter-kernels sglang_ori sglang_m0 sglang_m3 sglang_m6 sglang_m7 sglang_m8 sglang_m9 sglang_m10 sglang_m11 \ + --repeat 5 \ + --output-json "${COUNTER_JSON}" \ + 2>&1 | tee "${RESULTS_DIR}/counters_${TIMESTAMP}.log" +echo ">>> Counters saved to ${COUNTER_JSON}" \ No newline at end of file diff --git a/examples/verify_algo_topk_mapping_indexcache.sh b/examples/verify_algo_topk_mapping_indexcache.sh deleted file mode 100644 index 9002084..0000000 --- a/examples/verify_algo_topk_mapping_indexcache.sh +++ /dev/null @@ -1,45 +0,0 @@ -#!/usr/bin/env bash -set -e -# use CUDA_VISIBLE_DEVICES to set the GPU id you want to use -export CUDA_VISIBLE_DEVICES=5 - -RESULTS_DIR="results" -mkdir -p "${RESULTS_DIR}" -TIMESTAMP=$(date +%Y%m%d_%H%M%S) - -sparse_algos=( - "block_sparse_attention" -) - -# --- Mode 5: Index Cache (default even-layer pattern) --- -for algo in "${sparse_algos[@]}"; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_mode5_index_cache_${TIMESTAMP}.log" - echo ">>> Running verify_algo.py with --vortex-module-name ${algo} --topk-type sglang --topk-mapping-mode 5 (index cache)" - echo ">>> Saving results to ${OUTFILE}" - { time python verify_algo.py \ - --trials 8 \ - --topk-val 30 \ - --vortex-module-name "${algo}" \ - --model-name Qwen/Qwen3-1.7B \ - --topk-type sglang \ - --topk-mapping-mode 5 \ - --index-cache-shared-layers 2 4 6 8 10 12 14 16 18 20 22 24 26 \ - --mem 0.7 ; } \ - 2>&1 | tee "${OUTFILE}" -done - -# --- Mode 6: Greedy layer selection --- -# for algo in "${sparse_algos[@]}"; do -# OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_mode6_greedy_${TIMESTAMP}.log" -# echo ">>> Running verify_algo.py with --vortex-module-name ${algo} --topk-type sglang --topk-mapping-mode 6 (greedy)" -# echo ">>> Saving results to ${OUTFILE}" -# { time python verify_algo.py \ -# --trials 8 \ -# 
--topk-val 30 \ -# --vortex-module-name "${algo}" \ -# --model-name Qwen/Qwen3-1.7B \ -# --topk-type sglang \ -# --topk-mapping-mode 6 \ -# --mem 0.7 ; } \ -# 2>&1 | tee "${OUTFILE}" -#done diff --git a/examples/verify_algo_topk_mapping_new.sh b/examples/verify_algo_topk_mapping_new.sh index 5c5d6cf..4c96a15 100644 --- a/examples/verify_algo_topk_mapping_new.sh +++ b/examples/verify_algo_topk_mapping_new.sh @@ -14,19 +14,35 @@ set -e # 9: Erf — y = erf(alpha * x) # 10: Tanh — y = tanh(alpha * x) # 11: Subtract — x - pivot (RadiK-style scatter) -export CUDA_VISIBLE_DEVICES=5 - SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" BENCH_DIR="${SCRIPT_DIR}/../benchmarks" +# ── Defaults ────────────────────────────────────────────────── +GPU_ID=5 +TOPK_VAL=30 +BENCHMARKS="amc23" # space-separated list, e.g. "amc23 aime24" + +# ── Parse arguments ─────────────────────────────────────────── +while [[ $# -gt 0 ]]; do + case "$1" in + --topk-val) TOPK_VAL="$2"; shift 2 ;; + --gpu) GPU_ID="$2"; shift 2 ;; + --benchmark) BENCHMARKS="$2"; shift 2 ;; + *) echo "Unknown option: $1"; exit 1 ;; + esac +done + +export CUDA_VISIBLE_DEVICES="${GPU_ID}" + sparse_algos=( "block_sparse_attention" ) # Path to real-data histograms from calibration (for auto-tuning) -REAL_HISTOGRAMS="" +REAL_HISTOGRAMS="/data/datasets/xinrui/My_Projects/vortex_torch/examples/calibration/raw_histograms.npy" -RESULTS_DIR="results" +BENCH_LABEL=$(echo "${BENCHMARKS}" | tr ' ' '_') +RESULTS_DIR="results/topk${TOPK_VAL}_${BENCH_LABEL}" mkdir -p "${RESULTS_DIR}" TIMESTAMP=$(date +%Y%m%d_%H%M%S) @@ -39,7 +55,7 @@ echo "Step 0: Auto-tuning hyperparameters (synthetic data)" echo "============================================================" AUTOTUNE_JSON="${RESULTS_DIR}/autotune_${TIMESTAMP}.json" PYTHONPATH="${SCRIPT_DIR}/.." 
python "${BENCH_DIR}/autotune_topk_mapping.py" \ - --topk-val 30 \ + --topk-val ${TOPK_VAL} \ --batch-size 4 \ --seq-len 32768 \ --num-kv-heads 2 \ @@ -58,15 +74,35 @@ data = json.load(open(sys.argv[1])) best = {} for r in data: m = r.get('mode') - if m in (3, 6, 7, 9, 10): + if m in (3, 6, 7, 9, 10, 13, 14): if m not in best or r['gini'] < best[m]['gini']: best[m] = r -for m in (3, 6, 7, 9, 10): +for m in (3, 6, 7, 9, 10, 13, 14): print(f'BEST_POWER_{m}={best[m][\"param\"]}' if m in best else f'BEST_POWER_{m}=0.5') " "${AUTOTUNE_JSON}")" -echo ">>> Autotuned best powers: mode3=${BEST_POWER_3} mode6=${BEST_POWER_6} mode7=${BEST_POWER_7} mode9=${BEST_POWER_9} mode10=${BEST_POWER_10}" +echo ">>> Autotuned best powers: mode3=${BEST_POWER_3} mode6=${BEST_POWER_6} mode7=${BEST_POWER_7} mode9=${BEST_POWER_9} mode10=${BEST_POWER_10} mode13=${BEST_POWER_13} mode14=${BEST_POWER_14}" echo "" +# ============================================================ +# Baseline: Original sglang kernel (no remap) +# ============================================================ +echo "============================================================" +echo "Baseline: sglang_ori (no remap)" +echo "============================================================" +for algo in "${sparse_algos[@]}"; do + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_ori_${TIMESTAMP}.log" + echo ">>> sglang_ori algo=${algo}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val ${TOPK_VAL} \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --topk-type sglang_ori \ + --benchmark ${BENCHMARKS} \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" +done + # ============================================================ # Step 1: Mode 3 (power) — autotuned best p # ============================================================ @@ -78,12 +114,13 @@ for algo in "${sparse_algos[@]}"; do echo ">>> Mode 3 (power) p=${BEST_POWER_3} algo=${algo}" { time python verify_algo.py \ --trials 8 \ - --topk-val 30 \ + 
--topk-val ${TOPK_VAL} \ --vortex-module-name "${algo}" \ --model-name Qwen/Qwen3-1.7B \ --topk-type sglang \ --topk-mapping-mode 3 \ --topk-mapping-power ${BEST_POWER_3} \ + --benchmark ${BENCHMARKS} \ --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" done @@ -99,12 +136,13 @@ for algo in "${sparse_algos[@]}"; do echo ">>> Mode 6 (asinh) beta=${BEST_POWER_6} algo=${algo}" { time python verify_algo.py \ --trials 8 \ - --topk-val 30 \ + --topk-val ${TOPK_VAL} \ --vortex-module-name "${algo}" \ --model-name Qwen/Qwen3-1.7B \ --topk-type sglang \ --topk-mapping-mode 6 \ --topk-mapping-power ${BEST_POWER_6} \ + --benchmark ${BENCHMARKS} \ --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" done @@ -120,12 +158,13 @@ for algo in "${sparse_algos[@]}"; do echo ">>> Mode 7 (log1p) alpha=${BEST_POWER_7} algo=${algo}" { time python verify_algo.py \ --trials 8 \ - --topk-val 30 \ + --topk-val ${TOPK_VAL} \ --vortex-module-name "${algo}" \ --model-name Qwen/Qwen3-1.7B \ --topk-type sglang \ --topk-mapping-mode 7 \ --topk-mapping-power ${BEST_POWER_7} \ + --benchmark ${BENCHMARKS} \ --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" done @@ -141,11 +180,12 @@ for algo in "${sparse_algos[@]}"; do echo ">>> Mode 8 (trunc8) algo=${algo}" { time python verify_algo.py \ --trials 8 \ - --topk-val 30 \ + --topk-val ${TOPK_VAL} \ --vortex-module-name "${algo}" \ --model-name Qwen/Qwen3-1.7B \ --topk-type sglang \ --topk-mapping-mode 8 \ + --benchmark ${BENCHMARKS} \ --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" done @@ -161,12 +201,13 @@ for algo in "${sparse_algos[@]}"; do echo ">>> Mode 9 (erf) alpha=${BEST_POWER_9} algo=${algo}" { time python verify_algo.py \ --trials 8 \ - --topk-val 30 \ + --topk-val ${TOPK_VAL} \ --vortex-module-name "${algo}" \ --model-name Qwen/Qwen3-1.7B \ --topk-type sglang \ --topk-mapping-mode 9 \ --topk-mapping-power ${BEST_POWER_9} \ + --benchmark ${BENCHMARKS} \ --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" done @@ -182,12 +223,13 @@ for algo in "${sparse_algos[@]}"; do echo ">>> Mode 10 (tanh) 
alpha=${BEST_POWER_10} algo=${algo}" { time python verify_algo.py \ --trials 8 \ - --topk-val 30 \ + --topk-val ${TOPK_VAL} \ --vortex-module-name "${algo}" \ --model-name Qwen/Qwen3-1.7B \ --topk-type sglang \ --topk-mapping-mode 10 \ --topk-mapping-power ${BEST_POWER_10} \ + --benchmark ${BENCHMARKS} \ --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" done @@ -203,11 +245,12 @@ for algo in "${sparse_algos[@]}"; do echo ">>> Mode 11 (subtract) algo=${algo}" { time python verify_algo.py \ --trials 8 \ - --topk-val 30 \ + --topk-val ${TOPK_VAL} \ --vortex-module-name "${algo}" \ --model-name Qwen/Qwen3-1.7B \ --topk-type sglang \ --topk-mapping-mode 11 \ + --benchmark ${BENCHMARKS} \ --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" done @@ -224,16 +267,88 @@ for algo in "${sparse_algos[@]}"; do echo ">>> Mode 12 (adaptive_tail_window) algo=${algo}" { time python verify_algo.py \ --trials 8 \ - --topk-val 30 \ + --topk-val ${TOPK_VAL} \ --vortex-module-name "${algo}" \ --model-name Qwen/Qwen3-1.7B \ --topk-type sglang \ --topk-mapping-mode 12 \ --topk-mapping-power 4.0 \ + --benchmark ${BENCHMARKS} \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" +done + +# ============================================================ +# Step 9: Mode 13 (exp_stretch) — autotuned best alpha +# ============================================================ +echo "" +echo "============================================================" +echo "Step 9: Mode 13 (exp_stretch) — alpha=${BEST_POWER_13} (autotuned)" +echo "============================================================" +for algo in "${sparse_algos[@]}"; do + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_13_alpha${BEST_POWER_13}_${TIMESTAMP}.log" + echo ">>> Mode 13 (exp_stretch) alpha=${BEST_POWER_13} algo=${algo}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val ${TOPK_VAL} \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --topk-type sglang \ + --topk-mapping-mode 13 \ + --topk-mapping-power ${BEST_POWER_13} \ + 
--benchmark ${BENCHMARKS} \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" +done + +# ============================================================ +# Step 10: Mode 14 (topk_window) — autotuned best rho +# ============================================================ +echo "" +echo "============================================================" +echo "Step 10: Mode 14 (topk_window) — rho=${BEST_POWER_14} (autotuned)" +echo "============================================================" +for algo in "${sparse_algos[@]}"; do + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_14_rho${BEST_POWER_14}_${TIMESTAMP}.log" + echo ">>> Mode 14 (topk_window) rho=${BEST_POWER_14} algo=${algo}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val ${TOPK_VAL} \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --topk-type sglang \ + --topk-mapping-mode 14 \ + --topk-mapping-power ${BEST_POWER_14} \ + --benchmark ${BENCHMARKS} \ --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" done +# ============================================================ +# Counter profiling: collect COUNTER_NUM_EQUAL for all modes +# (single extra kernel call per mode, no overhead on accuracy runs) +# ============================================================ +echo "" +echo "============================================================" +echo "Counter profiling: COUNTER_NUM_EQUAL per mode (topk=${TOPK_VAL})" +echo "============================================================" +COUNTER_JSON="${RESULTS_DIR}/counters_${TIMESTAMP}.json" +PYTHONPATH="${SCRIPT_DIR}/.." 
python "${BENCH_DIR}/bench_topk.py" \ + --batch-sizes 4 \ + --seq-lens 4096 \ + --topk-vals ${TOPK_VAL} \ + --num-kv-heads 2 \ + --distributions normal \ + --counters \ + --real-histograms "${REAL_HISTOGRAMS}" \ + --autotune-json "${AUTOTUNE_JSON}" \ + --filter-kernels sglang_ori sglang_m0 sglang_m3 sglang_m6 sglang_m7 sglang_m8 sglang_m9 sglang_m10 sglang_m11 sglang_m13 sglang_m14 \ + --mapping-power-13 ${BEST_POWER_13} --mapping-power-14 ${BEST_POWER_14} \ + --repeat 5 \ + --output-json "${COUNTER_JSON}" \ + 2>&1 | tee "${RESULTS_DIR}/counters_${TIMESTAMP}.log" +echo ">>> Counters saved to ${COUNTER_JSON}" + # ============================================================ # Summary # ============================================================ @@ -241,12 +356,15 @@ echo "" echo "============================================================" echo "All runs complete. Results in ${RESULTS_DIR}/" echo " Auto-tune: ${AUTOTUNE_JSON}" -echo " Mode 3 (power): p = ${BEST_POWER_3} (autotuned)" -echo " Mode 6 (asinh): beta = ${BEST_POWER_6} (autotuned)" -echo " Mode 7 (log1p): alpha = ${BEST_POWER_7} (autotuned)" -echo " Mode 8 (trunc8): (fixed)" -echo " Mode 9 (erf): alpha = ${BEST_POWER_9} (autotuned)" -echo " Mode 10 (tanh): alpha = ${BEST_POWER_10} (autotuned)" -echo " Mode 11 (subtract): (fixed)" -echo " Mode 12 (tail_win): rho = 4.0" +echo " Counters: ${COUNTER_JSON}" +echo " Mode 3 (power): p = ${BEST_POWER_3} (autotuned)" +echo " Mode 6 (asinh): beta = ${BEST_POWER_6} (autotuned)" +echo " Mode 7 (log1p): alpha = ${BEST_POWER_7} (autotuned)" +echo " Mode 8 (trunc8): (fixed)" +echo " Mode 9 (erf): alpha = ${BEST_POWER_9} (autotuned)" +echo " Mode 10 (tanh): alpha = ${BEST_POWER_10} (autotuned)" +echo " Mode 11 (subtract): (fixed)" +echo " Mode 12 (tail_win): rho = 4.0" +echo " Mode 13 (exp_stretch):alpha = ${BEST_POWER_13} (autotuned)" +echo " Mode 14 (topk_window):rho = ${BEST_POWER_14} (autotuned)" echo "============================================================" 
diff --git a/third_party/sglang b/third_party/sglang index 0ec1289..47faead 160000 --- a/third_party/sglang +++ b/third_party/sglang @@ -1 +1 @@ -Subproject commit 0ec12893c4fc0d6ae1d36d4e0512dc21749c4b4b +Subproject commit 47faead5448b14681ac57fc9a3c6311654fc2b17 diff --git a/vortex_torch/indexer/output_func.py b/vortex_torch/indexer/output_func.py index 9c7e076..b50ca74 100644 --- a/vortex_torch/indexer/output_func.py +++ b/vortex_torch/indexer/output_func.py @@ -1,7 +1,7 @@ import torch from typing import Dict, Callable, List, Optional from ..abs import vOp -from vortex_torch_C import topk_output, topk_output_sglang, topk_profile_histogram +from vortex_torch_C import topk_output, topk_output_sglang, topk_output_sglang_ori, topk_profile_histogram from .context import Context from ..abs import vTensor, FORMAT from ..utils import UNSET @@ -91,6 +91,7 @@ class topK(vOp): FORMAT.RAGGED: { "naive": topk_output, "sglang": topk_output_sglang, + "sglang_ori": topk_output_sglang_ori, }, } @@ -272,6 +273,20 @@ def execute(self, x: torch.Tensor, o: torch.Tensor, ctx: Context) -> torch.Tenso mapping_quantiles, mapping_noscale, ) + elif self.topk_type == "sglang_ori": + # topk_output_sglang_ori: same CSR interface, no mapping params + self.impl( + x, + ctx.dense_kv_indptr, + ctx.sparse_kv_indptr, + ctx.dense_kv_indices, + o, + ctx.batch_size * ctx.num_kv_heads, + ctx.topk_val, + ctx.page_reserved_bos, + ctx.page_reserved_eos, + ctx.max_num_pages_per_request, + ) else: # topk_output (naive): (x, dense_kv_indptr, dense_kv_indices, sparse_kv_indptr, sparse_kv_indices, ...) self.impl( From e6b73e45490752d8ab9104f9f89feeae0a40a3d8 Mon Sep 17 00:00:00 2001 From: UED Date: Tue, 7 Apr 2026 21:20:34 +0000 Subject: [PATCH 17/22] Update TopK mapping and profiling functionalities - Added new file for enhanced profiling capabilities. - Updated to include the new profiling source file. - Modified to expand the sweep grid for parameter tuning. 
- Refactored to improve handling of hyperparameters and added subprocess profiling for large TopK values. - Enhanced and to support new parameters for profiling. - Updated example scripts to reflect changes in TopK parameters and profiling options. --- benchmarks/autotune_topk_mapping.py | 3 +- benchmarks/bench_topk.py | 351 +++--- csrc/register.cc | 10 +- csrc/register.h | 10 +- csrc/topk_mapping.cuh | 2 +- csrc/topk_sglang.cu | 807 +++----------- csrc/topk_sglang_profile.cu | 1203 +++++++++++++++++++++ examples/run_distribution_analysis.sh | 22 +- examples/run_distribution_analysis_new.sh | 25 +- examples/verify_algo.py | 30 +- examples/verify_algo_topk_mapping.sh | 44 +- examples/verify_algo_topk_mapping_new.sh | 78 +- setup.py | 1 + vortex_torch/indexer/output_func.py | 6 +- 14 files changed, 1663 insertions(+), 929 deletions(-) create mode 100644 csrc/topk_sglang_profile.cu diff --git a/benchmarks/autotune_topk_mapping.py b/benchmarks/autotune_topk_mapping.py index f04418d..8051c14 100644 --- a/benchmarks/autotune_topk_mapping.py +++ b/benchmarks/autotune_topk_mapping.py @@ -34,7 +34,7 @@ SWEEP_GRID = { # (mode, param_name, param_values) - 3: ("power_exp", [0.1, 0.25, 0.75, 0.9]), + 3: ("power_exp", [0.1, 0.25, 0.5, 0.75, 0.9, 2.0, 4.0]), 6: ("beta", [0.1, 0.5, 1.0, 2.0, 4.0]), 7: ("alpha", [0.1, 0.5, 0.75, 1.0, 2.0, 4.0, 8.0]), 9: ("alpha", [0.1, 0.5, 1.0, 2.0, 4.0]), @@ -55,6 +55,7 @@ 7: ("log1p_noscale", [1.0]), 9: ("erf_noscale", [1.0]), 10: ("tanh_noscale", [1.0]), + 13: ("exp_stretch_noscale", [1.0, 4.0]), } MODE_NAMES = { 0: "none", diff --git a/benchmarks/bench_topk.py b/benchmarks/bench_topk.py index 2fd1e31..a913bde 100644 --- a/benchmarks/bench_topk.py +++ b/benchmarks/bench_topk.py @@ -227,19 +227,118 @@ def _load_autotune_powers(path: str) -> Dict[int, float]: return {m: v["param"] for m, v in best.items()} -def _resolve_mode_power(args, mode: int) -> float: +def _resolve_mode_hparam(args, mode: int) -> float: """Return the power/beta/alpha for a 
parametric mapping mode. Priority: per-mode CLI flag > autotune JSON > global --mapping-power. """ - per_mode_flag = {3: args.mapping_power_3, 6: args.mapping_power_6, 7: args.mapping_power_7, - 9: getattr(args, 'mapping_power_9', None), 10: getattr(args, 'mapping_power_10', None), - 13: getattr(args, 'mapping_power_13', None), 14: getattr(args, 'mapping_power_14', None)} + per_mode_flag = {3: args.mapping_hparam_3, 6: args.mapping_hparam_6, 7: args.mapping_hparam_7, + 9: getattr(args, 'mapping_hparam_9', None), 10: getattr(args, 'mapping_hparam_10', None), + 13: getattr(args, 'mapping_hparam_13', None), 14: getattr(args, 'mapping_hparam_14', None)} if mode in per_mode_flag and per_mode_flag[mode] is not None: return per_mode_flag[mode] if hasattr(args, "_autotune_powers") and mode in args._autotune_powers: return args._autotune_powers[mode] - return args.mapping_power + return args.mapping_hparam + + +def _run_subphase_profiling(subphase_modes, inputs, eff_bs, topk_val, + pages_per_seg, args, mapping_lut, mapping_quantiles): + """Run sub-phase profiling (histogram_only + stage1_full) for each mode. + + For topk <= 512, runs inline. For topk > 512, runs each mode in a + separate subprocess to avoid CUDA shared memory exhaustion from + accumulated kernel template registrations. 
+ """ + import subprocess, sys, tempfile, os + + for kernel_name, s1_mode, s1_power, s1_noscale, result in subphase_modes: + s1_lut = mapping_lut if s1_mode == 1 else None + s1_q = mapping_quantiles if s1_mode == 2 else None + + if topk_val <= 512: + # Inline: run directly in this process + hist_buf = torch.zeros(eff_bs, 256, dtype=torch.int32, device="cuda") + hist_args = ( + inputs["x"], inputs["dense_kv_indptr"], hist_buf, eff_bs, + args.reserved_bos, args.reserved_eos, + s1_mode, s1_power, s1_lut, s1_q, s1_noscale, + ) + hist_result = bench_kernel(topk_profile_histogram, hist_args, args.warmup, args.repeat) + + inputs["sparse_kv_indices"].zero_() + stage1_args = ( + inputs["x"], inputs["dense_kv_indptr"], inputs["sparse_kv_indptr"], + inputs["dense_kv_indices"], inputs["sparse_kv_indices"], + eff_bs, topk_val, args.reserved_bos, args.reserved_eos, pages_per_seg, + s1_mode, s1_power, s1_lut, s1_q, s1_noscale, + ) + stage1_result = bench_kernel(topk_profile_stage1, stage1_args, args.warmup, args.repeat) + else: + # Subprocess: fresh CUDA context per mode to avoid shared memory exhaustion + script = f""" +import torch, json, sys +sys.path.insert(0, '{os.path.dirname(os.path.abspath(__file__))}') +from vortex_torch_C import topk_profile_histogram, topk_profile_stage1 +from bench_topk import make_topk_inputs, bench_kernel + +inputs = make_topk_inputs( + batch_size={inputs['x'].shape[0] // (eff_bs // (inputs['x'].shape[0] if eff_bs == inputs['x'].shape[0] else 1)) if False else 1}, + num_kv_heads=1, seq_len={pages_per_seg * 16}, + page_size=16, topk_val={topk_val}, + reserved_bos={args.reserved_bos}, reserved_eos={args.reserved_eos}, + score_dtype=torch.bfloat16, distribution="normal", +) +eff_bs = {eff_bs} +# Recreate inputs with correct eff_bs +inputs = make_topk_inputs( + batch_size={eff_bs // max(1, eff_bs // pages_per_seg) if False else eff_bs}, + num_kv_heads=1, seq_len={pages_per_seg * 16}, + page_size=16, topk_val={topk_val}, + 
reserved_bos={args.reserved_bos}, reserved_eos={args.reserved_eos}, + score_dtype=torch.bfloat16, distribution="normal", +) +eff_bs = inputs["eff_batch_size"] + +hist_buf = torch.zeros(eff_bs, 256, dtype=torch.int32, device="cuda") +hist_result = bench_kernel(topk_profile_histogram, + (inputs["x"], inputs["dense_kv_indptr"], hist_buf, eff_bs, + {args.reserved_bos}, {args.reserved_eos}, {s1_mode}, {s1_power}, + None, None, {s1_noscale}), 5, {args.repeat}) + +inputs["sparse_kv_indices"].zero_() +stage1_result = bench_kernel(topk_profile_stage1, + (inputs["x"], inputs["dense_kv_indptr"], inputs["sparse_kv_indptr"], + inputs["dense_kv_indices"], inputs["sparse_kv_indices"], + eff_bs, {topk_val}, {args.reserved_bos}, {args.reserved_eos}, + inputs["num_pages_per_seg"], {s1_mode}, {s1_power}, + None, None, {s1_noscale}), 5, {args.repeat}) + +print(json.dumps({{"hist": hist_result, "stage1": stage1_result}})) +""" + try: + proc = subprocess.run( + [sys.executable, "-c", script], + capture_output=True, text=True, timeout=60, + env={**os.environ, "PYTHONPATH": os.path.dirname(os.path.abspath(__file__)) + "/.."}) + if proc.returncode == 0: + data = json.loads(proc.stdout.strip().split("\n")[-1]) + hist_result = data["hist"] + stage1_result = data["stage1"] + else: + # Subprocess failed — skip sub-phase for this mode + continue + except Exception: + continue + + result['histogram_only_mean_ms'] = hist_result['mean_ms'] + result['histogram_only_median_ms'] = hist_result['median_ms'] + result['stage1_full_mean_ms'] = stage1_result['mean_ms'] + result['stage1_full_median_ms'] = stage1_result['median_ms'] + result['route_overhead_mean_ms'] = stage1_result['mean_ms'] - hist_result['mean_ms'] + result['route_overhead_median_ms'] = stage1_result['median_ms'] - hist_result['median_ms'] + result['stage2_refine_mean_ms'] = result['mean_ms'] - stage1_result['mean_ms'] + result['stage2_refine_median_ms'] = result['median_ms'] - stage1_result['median_ms'] def run_benchmark(args) -> 
List[dict]: @@ -275,6 +374,7 @@ def run_benchmark(args) -> List[dict]: print(f"TopK Kernel Benchmark Results") print(f"GPU: {gpu_name} | SM count: {gpu_props.multi_processor_count}") print(f"Score dtype: {args.score_dtype} | Warmup: {args.warmup} | Repeat: {args.repeat}") + print(f"Radix bits: {args.radix_bits} ({1 << args.radix_bits} bins) | Sample stride: {args.sample_stride}") print("=" * 90) # Load optional LUT / quantiles @@ -430,6 +530,7 @@ def run_benchmark(args) -> List[dict]: args.reserved_bos, args.reserved_eos, pages_per_seg, + args.radix_bits, ) result = bench_kernel(topk_output_sglang_ori, call_args, args.warmup, args.repeat) elif kernel_name == "sglang_scale": @@ -448,6 +549,9 @@ def run_benchmark(args) -> List[dict]: 1.0, # p=1.0 → identity None, None, + False, + args.sample_stride, + args.radix_bits, ) result = bench_kernel(topk_output_sglang, call_args, args.warmup, args.repeat) else: @@ -461,7 +565,7 @@ def run_benchmark(args) -> List[dict]: extra_kwargs["mapping_quantiles"] = mapping_quantiles if mode in (3, 6, 7, 9, 10, 13, 14): - power = _resolve_mode_power(args, mode) + power = _resolve_mode_hparam(args, mode) else: power = 0.5 @@ -481,6 +585,8 @@ def run_benchmark(args) -> List[dict]: extra_kwargs.get("mapping_lut", None), extra_kwargs.get("mapping_quantiles", None), is_noscale, + args.sample_stride, + args.radix_bits, ) result = bench_kernel(topk_output_sglang, call_args, args.warmup, args.repeat) @@ -498,116 +604,23 @@ def run_benchmark(args) -> List[dict]: mname = MAPPING_MODE_NAMES.get(m, f'm{m}') if m in (3, 6, 7, 9, 10, 13, 14): pname = {3: "p", 6: "beta", 7: "alpha", 9: "alpha", 10: "alpha", 13: "alpha", 14: "rho"}[m] - label = f"sglang {mname} ({pname}={_resolve_mode_power(args, m)}){noscale_suffix}" + label = f"sglang {mname} ({pname}={_resolve_mode_hparam(args, m)}){noscale_suffix}" else: label = f"sglang {mname}{noscale_suffix}" - # Sub-phase profiling for sglang kernels (skip ori baseline) - if kernel_name not in ("naive", 
"sglang_ori"): - if kernel_name == "sglang_scale": - s1_mode, s1_power = 3, 1.0 - s1_lut, s1_q = None, None - s1_noscale = False + # Counter collection (runs separately from sub-phase profiling) + if kernel_name not in ("naive",) and args.counters: + if kernel_name in ("sglang_ori",): + c_mode, c_power, c_lut, c_q, c_noscale = 0, 0.5, None, None, False + elif kernel_name == "sglang_scale": + c_mode, c_power, c_lut, c_q, c_noscale = 3, 1.0, None, None, False else: - s1_mode_str = kernel_name.split("_m")[1] - s1_mode = int(s1_mode_str.split("_")[0]) - s1_noscale = kernel_name.endswith("_noscale") - if s1_mode in (3, 6, 7, 9, 10, 13, 14): - s1_power = _resolve_mode_power(args, s1_mode) - else: - s1_power = 0.5 - s1_lut = mapping_lut if s1_mode == 1 else None - s1_q = mapping_quantiles if s1_mode == 2 else None - - # Histogram only: pre-pass + histogram build - hist_buf = torch.zeros(eff_bs, 256, dtype=torch.int32, device="cuda") - hist_args = ( - inputs["x"], - inputs["dense_kv_indptr"], - hist_buf, - eff_bs, - args.reserved_bos, - args.reserved_eos, - s1_mode, - s1_power, - s1_lut, - s1_q, - s1_noscale, - ) - hist_result = bench_kernel(topk_profile_histogram, hist_args, args.warmup, args.repeat) - - # Stage1 full: pre-pass + hist + cumsum + route/filter - inputs["sparse_kv_indices"].zero_() - stage1_args = ( - inputs["x"], - inputs["dense_kv_indptr"], - inputs["sparse_kv_indptr"], - inputs["dense_kv_indices"], - inputs["sparse_kv_indices"], - eff_bs, - topk_val, - args.reserved_bos, - args.reserved_eos, - pages_per_seg, - s1_mode, - s1_power, - s1_lut, - s1_q, - s1_noscale, - ) - stage1_result = bench_kernel(topk_profile_stage1, stage1_args, args.warmup, args.repeat) - - result['histogram_only_mean_ms'] = hist_result['mean_ms'] - result['histogram_only_median_ms'] = hist_result['median_ms'] - result['stage1_full_mean_ms'] = stage1_result['mean_ms'] - result['stage1_full_median_ms'] = stage1_result['median_ms'] - result['route_overhead_mean_ms'] = 
stage1_result['mean_ms'] - hist_result['mean_ms'] - result['route_overhead_median_ms'] = stage1_result['median_ms'] - hist_result['median_ms'] - result['stage2_refine_mean_ms'] = result['mean_ms'] - stage1_result['mean_ms'] - result['stage2_refine_median_ms'] = result['median_ms'] - stage1_result['median_ms'] - - # Optional counter collection - if args.counters: - inputs["sparse_kv_indices"].zero_() - counter_buf = torch.zeros(eff_bs, 6, dtype=torch.int32, device="cuda") - counter_args = ( - inputs["x"], - inputs["dense_kv_indptr"], - inputs["sparse_kv_indptr"], - inputs["dense_kv_indices"], - inputs["sparse_kv_indices"], - counter_buf, - eff_bs, - topk_val, - args.reserved_bos, - args.reserved_eos, - pages_per_seg, - s1_mode, - s1_power, - s1_lut, - s1_q, - s1_noscale, - ) - topk_profile_counters(*counter_args) - torch.cuda.synchronize() - c = counter_buf.float() - result['counters'] = { - 'threshold_bin_mean': c[:, 0].mean().item(), - 'num_above_mean': c[:, 1].mean().item(), - 'num_equal_mean': c[:, 2].mean().item(), - 'remaining_k_mean': c[:, 3].mean().item(), - 'refine_rounds_mean': c[:, 4].mean().item(), - 'stage2_input_mean': c[:, 5].mean().item(), - 'threshold_bin_max': c[:, 0].max().item(), - 'num_above_max': c[:, 1].max().item(), - 'num_equal_max': c[:, 2].max().item(), - 'remaining_k_max': c[:, 3].max().item(), - 'refine_rounds_max': c[:, 4].max().item(), - 'stage2_input_max': c[:, 5].max().item(), - } - - # Counter collection for kernels skipped by sub-phase profiling - if kernel_name in ("sglang_ori",) and args.counters: + c_mode_str = kernel_name.split("_m")[1] + c_mode = int(c_mode_str.split("_")[0]) + c_noscale = kernel_name.endswith("_noscale") + c_power = _resolve_mode_hparam(args, c_mode) if c_mode in (3,6,7,9,10,13,14) else 0.5 + c_lut = mapping_lut if c_mode == 1 else None + c_q = mapping_quantiles if c_mode == 2 else None inputs["sparse_kv_indices"].zero_() counter_buf = torch.zeros(eff_bs, 6, dtype=torch.int32, device="cuda") counter_args = ( 
@@ -622,11 +635,11 @@ def run_benchmark(args) -> List[dict]: args.reserved_bos, args.reserved_eos, pages_per_seg, - 0, # mode 0 (no mapping) — matches ori behavior - 0.5, - None, - None, - False, + c_mode, + c_power, + c_lut, + c_q, + c_noscale, ) topk_profile_counters(*counter_args) torch.cuda.synchronize() @@ -649,6 +662,27 @@ def run_benchmark(args) -> List[dict]: kernel_entries.append((label, kernel_name, result)) config_results["kernels"][kernel_name] = result + # Second pass: sub-phase profiling (histogram_only + stage1_full) + # Run in a subprocess to get a fresh CUDA context, avoiding + # shared memory exhaustion from accumulated kernel registrations. + subphase_modes = [] + for label, kernel_name, result in kernel_entries: + if kernel_name in ("naive", "sglang_ori"): + continue + if kernel_name == "sglang_scale": + s1_mode, s1_power, s1_noscale = 3, 1.0, False + else: + s1_mode_str = kernel_name.split("_m")[1] + s1_mode = int(s1_mode_str.split("_")[0]) + s1_noscale = kernel_name.endswith("_noscale") + s1_power = _resolve_mode_hparam(args, s1_mode) if s1_mode in (3,6,7,9,10,13,14) else 0.5 + subphase_modes.append((kernel_name, s1_mode, s1_power, s1_noscale, result)) + + if subphase_modes: + _run_subphase_profiling( + subphase_modes, inputs, eff_bs, topk_val, + pages_per_seg, args, mapping_lut, mapping_quantiles) + # Print kernel results sorted by mean latency (ascending) kernel_entries.sort(key=lambda e: e[2]['mean_ms']) print(f" --- kernel latency (sorted by mean, ascending) ---") @@ -692,44 +726,13 @@ def run_benchmark(args) -> List[dict]: f"stage2_input={c['stage2_input_mean']:.0f}" ) - # Histogram analysis + # Histogram analysis — uses the SAME inputs as the main benchmark + # so histogram CSV and counters reflect the same data. 
if args.histogram: - # Build a separate (potentially larger) dataset for histogram profiling - target_pages = (args.histogram_pages - if args.histogram_pages is not None - else _histogram_target_pages(pages_per_seg)) + hist_inputs = inputs + hist_eff_bs = eff_bs current_pages = eff_bs * pages_per_seg - if target_pages > current_pages: - hist_bs = math.ceil(target_pages / (num_kv_heads * pages_per_seg)) - if dist == "real" and real_histogram is not None: - hist_inputs = make_topk_inputs( - batch_size=hist_bs, num_kv_heads=num_kv_heads, - seq_len=seq_len, page_size=args.page_size, - topk_val=topk_val, reserved_bos=args.reserved_bos, - reserved_eos=args.reserved_eos, score_dtype=score_dtype, - distribution="normal", - ) - total_hist_dense = hist_inputs["eff_batch_size"] * hist_inputs["num_pages_per_seg"] - hist_inputs["x"] = _scores_from_histogram(real_histogram, total_hist_dense, device="cuda") - else: - hist_inputs = make_topk_inputs( - batch_size=hist_bs, num_kv_heads=num_kv_heads, - seq_len=seq_len, page_size=args.page_size, - topk_val=topk_val, reserved_bos=args.reserved_bos, - reserved_eos=args.reserved_eos, score_dtype=score_dtype, - distribution=dist, - ) - hist_eff_bs = hist_inputs["eff_batch_size"] - actual_pages = hist_eff_bs * pages_per_seg - print( - f" histogram dataset : {actual_pages} pages " - f"(upscaled from {current_pages} for statistical reliability)" - ) - else: - hist_inputs = inputs - hist_eff_bs = eff_bs - actual_pages = current_pages - print(f" histogram dataset : {actual_pages} pages") + print(f" histogram dataset : {current_pages} pages (same as benchmark)") # Raw unmapped histogram histograms = torch.zeros(hist_eff_bs, 256, dtype=torch.int32, device="cuda") @@ -756,7 +759,7 @@ def run_benchmark(args) -> List[dict]: histograms_results = {} # Per-mode histogram analysis (scaled) - modes_to_test = [0, 3, 4, 6, 7, 8, 9, 10, 11] + modes_to_test = [0, 3, 4, 6, 7, 8, 9, 10, 11, 13, 14] if mapping_lut is not None: modes_to_test.append(1) if 
mapping_quantiles is not None: @@ -768,7 +771,7 @@ def run_benchmark(args) -> List[dict]: extra_lut = mapping_lut if mode == 1 else None extra_q = mapping_quantiles if mode == 2 else None - power = _resolve_mode_power(args, mode) if mode in (3, 6, 7, 9, 10, 13, 14) else 0.5 + power = _resolve_mode_hparam(args, mode) if mode in (3, 6, 7, 9, 10, 13, 14) else 0.5 topk_profile_histogram( hist_inputs["x"], @@ -806,7 +809,7 @@ def run_benchmark(args) -> List[dict]: noscale_modes = [m for m in (3, 6, 7, 9, 10, 13) if m in modes_to_test] for mode in noscale_modes: ns_hists = torch.zeros(hist_eff_bs, 256, dtype=torch.int32, device="cuda") - power = _resolve_mode_power(args, mode) + power = _resolve_mode_hparam(args, mode) topk_profile_histogram( hist_inputs["x"], hist_inputs["dense_kv_indptr"], @@ -890,18 +893,24 @@ def main(): parser.add_argument("--distributions", nargs="+", default=["normal", "lognormal", "uniform"]) parser.add_argument("--warmup", type=int, default=10) parser.add_argument("--repeat", type=int, default=100) - parser.add_argument("--mapping-power", type=float, default=0.5, - help="Global fallback power parameter for parametric modes (default: 0.5)") - parser.add_argument("--mapping-power-3", type=float, default=None, - help="Power exponent p for mode 3 (overrides --mapping-power)") - parser.add_argument("--mapping-power-6", type=float, default=None, - help="Beta for mode 6 asinh (overrides --mapping-power)") - parser.add_argument("--mapping-power-7", type=float, default=None, - help="Alpha for mode 7 log1p (overrides --mapping-power)") - parser.add_argument("--mapping-power-13", type=float, default=None, - help="Alpha for mode 13 exp_stretch (overrides --mapping-power)") - parser.add_argument("--mapping-power-14", type=float, default=None, - help="Rho for mode 14 topk_window (overrides --mapping-power)") + parser.add_argument("--mapping-hparam", "--mapping-power", type=float, default=0.5, + dest="mapping_hparam", + help="Global fallback hyperparameter for 
parametric modes (default: 0.5)") + parser.add_argument("--mapping-hparam-3", "--mapping-power-3", type=float, default=None, + dest="mapping_hparam_3", + help="Power exponent p for mode 3 (overrides --mapping-hparam)") + parser.add_argument("--mapping-hparam-6", "--mapping-power-6", type=float, default=None, + dest="mapping_hparam_6", + help="Beta for mode 6 asinh (overrides --mapping-hparam)") + parser.add_argument("--mapping-hparam-7", "--mapping-power-7", type=float, default=None, + dest="mapping_hparam_7", + help="Alpha for mode 7 log1p (overrides --mapping-hparam)") + parser.add_argument("--mapping-hparam-13", "--mapping-power-13", type=float, default=None, + dest="mapping_hparam_13", + help="Alpha for mode 13 exp_stretch (overrides --mapping-hparam)") + parser.add_argument("--mapping-hparam-14", "--mapping-power-14", type=float, default=None, + dest="mapping_hparam_14", + help="Rho for mode 14 topk_window (overrides --mapping-hparam)") parser.add_argument("--autotune-json", type=str, default=None, help="Path to autotune_results.json — extracts best per-mode hyperparameters " "(overrides --mapping-power for modes 3/6/7/13/14)") @@ -920,6 +929,12 @@ def main(): parser.add_argument("--counters", action="store_true", help="Collect diagnostic counters (threshold_bin, num_above, num_equal, " "remaining_k, refine_rounds, stage2_input) for each sglang kernel") + parser.add_argument("--sample-stride", type=int, default=1, + help="Pre-pass sampling stride for mapped modes (1=full, 4=1/4, 8=1/8). " + "Higher values reduce pre-pass overhead at cost of bin quality (default: 1)") + parser.add_argument("--radix-bits", type=int, default=8, + help="Stage 1 radix bits for ori/mode-0 kernel: 4=16 bins, 6=64, 8=256, 9=512, 10=1024 (default: 8). " + "Range: 4-10. 
Fewer bits = coarser Stage 1 but faster histogram; more bits = finer but slower.") args = parser.parse_args() results = run_benchmark(args) diff --git a/csrc/register.cc b/csrc/register.cc index b968e9c..b2d12b9 100644 --- a/csrc/register.cc +++ b/csrc/register.cc @@ -18,13 +18,16 @@ PYBIND11_MODULE(vortex_torch_C, m){ py::arg("mapping_power") = 0.5, py::arg("mapping_lut") = py::none(), py::arg("mapping_quantiles") = py::none(), - py::arg("mapping_noscale") = false); + py::arg("mapping_noscale") = false, + py::arg("sample_stride") = 1, + py::arg("radix_bits") = 8); m.def("topk_output_sglang_ori", &topk_output_sglang_ori, py::arg("x"), py::arg("dense_kv_indptr"), py::arg("sparse_kv_indptr"), py::arg("dense_kv_indices"), py::arg("sparse_kv_indices"), py::arg("eff_batch_size"), py::arg("topk_val"), py::arg("reserved_bos"), py::arg("reserved_eos"), - py::arg("max_num_pages")); + py::arg("max_num_pages"), + py::arg("radix_bits") = 8); m.def("topk_profile_histogram", &topk_profile_histogram, py::arg("x"), py::arg("dense_kv_indptr"), py::arg("histograms"), py::arg("eff_batch_size"), @@ -34,7 +37,8 @@ PYBIND11_MODULE(vortex_torch_C, m){ py::arg("mapping_lut") = py::none(), py::arg("mapping_quantiles") = py::none(), py::arg("mapping_noscale") = false, - py::arg("topk_val") = 0); + py::arg("topk_val") = 0, + py::arg("sample_stride") = 1); m.def("topk_profile_stage1", &topk_profile_stage1, py::arg("x"), py::arg("dense_kv_indptr"), py::arg("sparse_kv_indptr"), py::arg("dense_kv_indices"), py::arg("sparse_kv_indices"), diff --git a/csrc/register.h b/csrc/register.h index d4a311b..784b754 100644 --- a/csrc/register.h +++ b/csrc/register.h @@ -100,7 +100,9 @@ const int64_t mapping_mode = 0, const double mapping_power = 0.5, std::optional mapping_lut = std::nullopt, std::optional mapping_quantiles = std::nullopt, -const bool mapping_noscale = false +const bool mapping_noscale = false, +const int64_t sample_stride = 1, +const int64_t radix_bits = 8 ); void topk_output_sglang_ori( 
@@ -113,7 +115,8 @@ const int64_t eff_batch_size, const int64_t topk_val, const int64_t reserved_bos, const int64_t reserved_eos, -const int64_t max_num_pages +const int64_t max_num_pages, +const int64_t radix_bits = 8 ); void topk_profile_histogram( @@ -128,7 +131,8 @@ const double mapping_power = 0.5, std::optional mapping_lut = std::nullopt, std::optional mapping_quantiles = std::nullopt, const bool mapping_noscale = false, -const int64_t topk_val = 0 +const int64_t topk_val = 0, +const int64_t sample_stride = 1 ); void topk_profile_stage1( diff --git a/csrc/topk_mapping.cuh b/csrc/topk_mapping.cuh index 0893008..773cdeb 100644 --- a/csrc/topk_mapping.cuh +++ b/csrc/topk_mapping.cuh @@ -30,7 +30,7 @@ enum TopKMappingMode { MAPPING_QUANTILE = 2, // Piecewise-linear quantile mapping MAPPING_POWER = 3, // Monotonic power transform MAPPING_LOG = 4, // Log transform - MAPPING_INDEX_CACHE = 5, // Sentinel: reuse previous layer's indices (Python-level skip) + // Mode 5 reserved (previously INDEX_CACHE, removed) MAPPING_ASINH = 6, // asinh(beta * x), beta via power_exp MAPPING_LOG1P = 7, // sign(x) * log1p(alpha * |x|), alpha via power_exp MAPPING_TRUNC8 = 8, // BF16 upper-8-bit bucketing diff --git a/csrc/topk_sglang.cu b/csrc/topk_sglang.cu index 1d12c30..46dcdd7 100644 --- a/csrc/topk_sglang.cu +++ b/csrc/topk_sglang.cu @@ -1,10 +1,8 @@ /** - * @NOTE: This file is adapted from - * https://github.com/tile-ai/tilelang/blob/main/examples/deepseek_v32/topk_selector.py - * We: - * 1. adapt from tilelang to pure cuda - * 2. optimize the performance a little - * 3. fix the potential illegal memory access + * Vortex TopK kernel — mirrors topk_slgang_ori.cu structure with additions: + * - bf16 support, flexible radix, mapping/remap modes + * - CSR paged wrapper kernels for vortex integration + * Profiling kernels are in topk_sglang_profile.cu. 
*/ #include #include @@ -89,9 +87,38 @@ __device__ __forceinline__ auto convert_to_uint32(float x) -> uint32_t { return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u); } -// Include mapping strategies (must come after convert_to_uint8 definition) +// ---- Vortex additions ---- + +template +__device__ __forceinline__ float vortex_to_float(T x); +template <> +__device__ __forceinline__ float vortex_to_float(float x) { return x; } +template <> +__device__ __forceinline__ float vortex_to_float<__nv_bfloat16>(__nv_bfloat16 x) { + return __bfloat162float(x); +} + +constexpr int VORTEX_MAX_TOPK = 2048; + +constexpr int COUNTER_THRESHOLD_BIN = 0; +constexpr int COUNTER_NUM_ABOVE = 1; +constexpr int COUNTER_NUM_EQUAL = 2; +constexpr int COUNTER_REMAINING_K = 3; +constexpr int COUNTER_REFINE_ROUNDS = 4; +constexpr int COUNTER_STAGE2_INPUT = 5; +constexpr int NUM_TOPK_COUNTERS = 6; + +template +__device__ __forceinline__ uint16_t convert_to_uintN(float x) { + __half h = __float2half_rn(x); + uint16_t bits = __half_as_ushort(h); + uint16_t key = (bits & 0x8000) ? 
static_cast(~bits) : static_cast(bits | 0x8000); + return key >> (16 - BITS); +} + #include "topk_mapping.cuh" + __device__ void fast_topk_cuda_tl(const float* __restrict__ input, int* __restrict__ index, int row_start, int length) { // An optimized topk kernel copied from tilelang kernel // We assume length > TopK here, or it will crash @@ -435,38 +462,11 @@ void setup_kernel_smem_once() { TORCH_CHECK(result == cudaSuccess, "set_up_kernel_once failed:", ::cudaGetErrorString(result)); } -// ====================================================================== -// Vortex integration: BOS/EOS-aware segmented TopK with index remapping -// ====================================================================== - -template -__device__ __forceinline__ float vortex_to_float(T x); - -template <> -__device__ __forceinline__ float vortex_to_float(float x) { return x; } - -template <> -__device__ __forceinline__ float vortex_to_float<__nv_bfloat16>(__nv_bfloat16 x) { - return __bfloat162float(x); -} - -constexpr int VORTEX_MAX_TOPK = 4096; - -// Per-segment diagnostic counters written by WriteCounters mode -constexpr int COUNTER_THRESHOLD_BIN = 0; // Stage 1 coarse threshold bin id -constexpr int COUNTER_NUM_ABOVE = 1; // elements routed above threshold in Stage 1 -constexpr int COUNTER_NUM_EQUAL = 2; // elements in threshold bin (Stage 2 input) -constexpr int COUNTER_REMAINING_K = 3; // topk slots remaining after Stage 1 routing -constexpr int COUNTER_REFINE_ROUNDS = 4; // Stage 2 rounds used (0 = resolved in Stage 1) -constexpr int COUNTER_STAGE2_INPUT = 5; // candidates entering first Stage 2 refine round -constexpr int NUM_TOPK_COUNTERS = 6; - // ====================================================================== // Ori fast path: zero-overhead topk with no mapping infrastructure. -// Adapted from topk_slgang_ori.cu — uses direct convert_to_uint8() -// for Stage 1 binning with no pre-pass, no LUT, no bin cache. +// Template on RADIX_BITS: 4-10 (16 to 1024 bins). 
// ====================================================================== -template +template __device__ void fast_topk_ori( const ScoreT* __restrict__ input, int* __restrict__ index, @@ -476,10 +476,13 @@ __device__ void fast_topk_ori( { int topk = target_k; constexpr auto BLOCK_SIZE = 1024; - constexpr auto RADIX = 256; + constexpr auto RADIX = 1 << RADIX_BITS; + constexpr auto RADIX_PAD = RADIX / 2; constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); + static_assert(RADIX_BITS >= 4 && RADIX_BITS <= 10, "RADIX_BITS must be 4-10"); + static_assert(RADIX <= BLOCK_SIZE, "RADIX must not exceed BLOCK_SIZE"); - alignas(128) __shared__ int s_histogram_buf[2][RADIX + 128]; + alignas(128) __shared__ int s_histogram_buf[2][RADIX + RADIX_PAD]; alignas(128) __shared__ int s_counter; alignas(128) __shared__ int s_threshold_bin_id; alignas(128) __shared__ int s_num_input[2]; @@ -489,20 +492,18 @@ __device__ void fast_topk_ori( const int tx = threadIdx.x; - // Stage 1: 8-bit coarse histogram (direct convert_to_uint8, no mapping) + // Stage 1: coarse histogram with RADIX bins if (tx < RADIX + 1) s_histogram[tx] = 0; __syncthreads(); for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - const auto bin = convert_to_uint8(vortex_to_float(input[idx + row_start])); + const auto bin = convert_to_uintN(vortex_to_float(input[idx + row_start])); ::atomicAdd(&s_histogram[bin], 1); } __syncthreads(); const auto run_cumsum = [&] { -#pragma unroll 8 - for (int i = 0; i < 8; ++i) { - static_assert(1 << 8 == RADIX); + for (int i = 0; i < RADIX_BITS; ++i) { if (C10_LIKELY(tx < RADIX)) { const auto j = 1 << i; const auto k = i & 1; @@ -515,6 +516,21 @@ __device__ void fast_topk_ori( __syncthreads(); } }; + // Stage 2 cumsum: always 256 sub-bins (8-bit radix on raw float bits) + const auto run_cumsum_s2 = [&] { + for (int i = 0; i < 8; ++i) { + if (C10_LIKELY(tx < 256)) { + const auto j = 1 << i; + const auto k = i & 1; + auto value = s_histogram_buf[k][tx]; + if (tx < 256 - j) { + 
value += s_histogram_buf[k][tx + j]; + } + s_histogram_buf[k ^ 1][tx] = value; + } + __syncthreads(); + } + }; run_cumsum(); if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { @@ -529,7 +545,7 @@ __device__ void fast_topk_ori( if (topk == 0) { for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - const auto bin = static_cast(convert_to_uint8(vortex_to_float(input[idx + row_start]))); + const auto bin = static_cast(convert_to_uintN(vortex_to_float(input[idx + row_start]))); if (bin > threshold_bin) { const auto pos = ::atomicAdd(&s_counter, 1); index[pos] = idx; @@ -539,14 +555,12 @@ __device__ void fast_topk_ori( return; } else { __syncthreads(); - if (tx < RADIX + 1) { - s_histogram[tx] = 0; - } + if (tx < 257) s_histogram[tx] = 0; __syncthreads(); for (int idx = tx; idx < length; idx += BLOCK_SIZE) { const auto raw_input = vortex_to_float(input[idx + row_start]); - const auto bin = static_cast(convert_to_uint8(raw_input)); + const auto bin = static_cast(convert_to_uintN(raw_input)); if (bin > threshold_bin) { const auto pos = ::atomicAdd(&s_counter, 1); index[pos] = idx; @@ -572,8 +586,8 @@ __device__ void fast_topk_ori( const auto _raw_num_input = s_num_input[r_idx]; const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) ? 
_raw_num_input : int(SMEM_INPUT_SIZE); - run_cumsum(); - if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { + run_cumsum_s2(); + if (tx < 256 && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { s_threshold_bin_id = tx; s_num_input[r_idx ^ 1] = 0; s_last_remain = topk - s_histogram[tx + 1]; @@ -597,9 +611,7 @@ __device__ void fast_topk_ori( break; } else { __syncthreads(); - if (tx < RADIX + 1) { - s_histogram[tx] = 0; - } + if (tx < 257) s_histogram[tx] = 0; __syncthreads(); for (int i = tx; i < num_input; i += BLOCK_SIZE) { const auto idx = s_input_idx[r_idx][i]; @@ -636,7 +648,7 @@ __device__ void fast_topk_ori( // - ScoreT: float or __nv_bfloat16 // - StopAfterStage1: return after Stage 1 route/filter (for profiling) // - WriteCounters: write diagnostic counters to global memory -// - target_k: runtime parameter (replaces compile-time TopK) + // - mapping: configurable value-remapping for Stage 1 bin assignment template __device__ void fast_topk_vortex( @@ -850,40 +862,52 @@ __device__ void fast_topk_vortex( } __syncthreads(); } else if (needs_topk_window(mapping.mode)) { - // Lightweight topk-window pre-pass: compute min/max of raw values, - // then focus all 256 bins on [tau_low, max] where - // tau_low = max - (max - min) * rho * k / length. - // Like mode 12 but uses a simple heuristic instead of quantile estimation. - float local_min = __FLT_MAX__, local_max = -__FLT_MAX__; + // Topk-window pre-pass with streaming variance heuristic. 
+ // tau_low = max - rho * sigma * sqrt(2 * log(n/k)) + float local_max = -__FLT_MAX__; + float local_sum = 0.0f, local_sum_sq = 0.0f; for (int idx = tx; idx < length; idx += BLOCK_SIZE) { float val = vortex_to_float(input[idx + row_start]); - local_min = fminf(local_min, val); local_max = fmaxf(local_max, val); + local_sum += val; + local_sum_sq += val * val; } for (int offset = 16; offset > 0; offset >>= 1) { - local_min = fminf(local_min, __shfl_xor_sync(0xFFFFFFFF, local_min, offset)); local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); + local_sum += __shfl_xor_sync(0xFFFFFFFF, local_sum, offset); + local_sum_sq += __shfl_xor_sync(0xFFFFFFFF, local_sum_sq, offset); } - __shared__ float s_warp_mins_tw2[32], s_warp_maxs_tw2[32]; + __shared__ float s_warp_maxs_tw2[32], s_warp_sums_tw2[32], s_warp_sq_tw2[32]; { int warp_id = tx >> 5, lane_id = tx & 31; - if (lane_id == 0) { s_warp_mins_tw2[warp_id] = local_min; s_warp_maxs_tw2[warp_id] = local_max; } + if (lane_id == 0) { + s_warp_maxs_tw2[warp_id] = local_max; + s_warp_sums_tw2[warp_id] = local_sum; + s_warp_sq_tw2[warp_id] = local_sum_sq; + } } __syncthreads(); if (tx < (BLOCK_SIZE >> 5)) { - local_min = s_warp_mins_tw2[tx]; local_max = s_warp_maxs_tw2[tx]; + local_max = s_warp_maxs_tw2[tx]; + local_sum = s_warp_sums_tw2[tx]; + local_sum_sq = s_warp_sq_tw2[tx]; for (int offset = 16; offset > 0; offset >>= 1) { - local_min = fminf(local_min, __shfl_xor_sync(0xFFFFFFFF, local_min, offset)); local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); + local_sum += __shfl_xor_sync(0xFFFFFFFF, local_sum, offset); + local_sum_sq += __shfl_xor_sync(0xFFFFFFFF, local_sum_sq, offset); } if (tx == 0) { float rho = mapping.power_exp; if (rho <= 0.0f) rho = 4.0f; int k = (mapping.target_k > 0) ? 
mapping.target_k : target_k; - float frac = rho * float(k) / float(length); - frac = fminf(frac, 1.0f); - float tau_low = local_max - (local_max - local_min) * frac; - if (tau_low >= local_max) tau_low = local_min; + float n = float(length); + float mean = local_sum / n; + float var = local_sum_sq / n - mean * mean; + float sigma = (var > 0.0f) ? sqrtf(var) : 0.0f; + float ratio = n / fmaxf(float(k), 1.0f); + float z = sqrtf(2.0f * __logf(fmaxf(ratio, 1.0f))); + float tau_low = local_max - rho * sigma * z; + if (tau_low >= local_max) tau_low = local_max - 1.0f; float range = local_max - tau_low; s_range_min = tau_low; s_range_inv_range = (range > 1e-10f) ? 255.0f / range : 0.0f; @@ -1130,8 +1154,8 @@ void TopKOutput_Kernel( } } -// Ori fast-path wrapper: zero mapping overhead -template +// Ori fast-path wrapper: zero mapping overhead, flexible radix +template __global__ __launch_bounds__(kThreadsPerBlock) void TopKOutput_Ori_Kernel( const ScoreT* __restrict__ score, @@ -1157,7 +1181,7 @@ void TopKOutput_Ori_Kernel( + page_reserved_bos; __shared__ int s_indices[VORTEX_MAX_TOPK]; - fast_topk_ori(score_blk, s_indices, 0, nblk, topk_val); + fast_topk_ori(score_blk, s_indices, 0, nblk, topk_val); __syncthreads(); const int tx = threadIdx.x; @@ -1166,300 +1190,29 @@ void TopKOutput_Ori_Kernel( } } -// ====================================================================== -// Profiling Stage1 kernel: runs pre-pass + hist + cumsum + route/filter, -// stops before Stage 2 refinement (for sub-phase timing) -// ====================================================================== +// Helper: launch TopKOutput_Ori_Kernel with radix_bits dispatch template -__global__ __launch_bounds__(kThreadsPerBlock) -void TopKStage1_Kernel( - const ScoreT* __restrict__ score, - const int* __restrict__ dense_kv_indptr, - const int* __restrict__ sparse_kv_indptr, - const int* __restrict__ dense_kv_indices, - int* __restrict__ sparse_kv_indices, - const int topk_val, - const int 
page_reserved_bos, - const int page_reserved_eos, - const TopKMappingParams mapping) +void launch_ori_kernel( + const ScoreT* score, const int* dense_kv_indptr, const int* sparse_kv_indptr, + const int* dense_kv_indices, int* sparse_kv_indices, + int topk_val, int reserved_bos, int reserved_eos, + int radix_bits, dim3 nblks, dim3 nthreads, cudaStream_t stream) { - const int bx = blockIdx.x; - - const int start = dense_kv_indptr[bx] + page_reserved_bos; - const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; - const int nblk = end - start; - if (nblk <= topk_val) return; - - const ScoreT* __restrict__ score_blk = score + start; - const int* __restrict__ idx_blk = dense_kv_indices + start; - int* __restrict__ out_blk = sparse_kv_indices - + sparse_kv_indptr[bx] - + page_reserved_bos; - - __shared__ int s_indices[VORTEX_MAX_TOPK]; - fast_topk_vortex( - score_blk, s_indices, 0, nblk, topk_val, mapping); - __syncthreads(); - - // Remap position indices -> page indices via dense_kv_indices - const int tx = threadIdx.x; - for (int i = tx; i < topk_val; i += kThreadsPerBlock) { - out_blk[i] = idx_blk[s_indices[i]]; - } -} - -// ====================================================================== -// Profiling counters kernel: runs full pipeline + writes diagnostic -// counters to a separate global-memory tensor -// ====================================================================== -template -__global__ __launch_bounds__(kThreadsPerBlock) -void TopKCounters_Kernel( - const ScoreT* __restrict__ score, - const int* __restrict__ dense_kv_indptr, - const int* __restrict__ sparse_kv_indptr, - const int* __restrict__ dense_kv_indices, - int* __restrict__ sparse_kv_indices, - int* __restrict__ counters, // [eff_batch_size, NUM_TOPK_COUNTERS] - const int topk_val, - const int page_reserved_bos, - const int page_reserved_eos, - const TopKMappingParams mapping) -{ - const int bx = blockIdx.x; - - const int start = dense_kv_indptr[bx] + page_reserved_bos; - const int end 
= dense_kv_indptr[bx + 1] - page_reserved_eos; - const int nblk = end - start; - if (nblk <= topk_val) return; - - const ScoreT* __restrict__ score_blk = score + start; - const int* __restrict__ idx_blk = dense_kv_indices + start; - int* __restrict__ out_blk = sparse_kv_indices - + sparse_kv_indptr[bx] - + page_reserved_bos; - - __shared__ int s_indices[VORTEX_MAX_TOPK]; - fast_topk_vortex( - score_blk, s_indices, 0, nblk, topk_val, mapping, - counters + bx * NUM_TOPK_COUNTERS); - __syncthreads(); - - // Remap position indices -> page indices via dense_kv_indices - const int tx = threadIdx.x; - for (int i = tx; i < topk_val; i += kThreadsPerBlock) { - out_blk[i] = idx_blk[s_indices[i]]; - } -} - -// ====================================================================== -// Profiling histogram kernel: runs only Stage 1 and returns per-segment -// 256-bin histograms for distribution analysis -// ====================================================================== -template -__global__ __launch_bounds__(kThreadsPerBlock) -void TopKHistogram_Kernel( - const ScoreT* __restrict__ score, - const int* __restrict__ dense_kv_indptr, - int* __restrict__ histograms, // [eff_batch_size, 256] - const int page_reserved_bos, - const int page_reserved_eos, - const TopKMappingParams mapping) -{ - constexpr auto RADIX = 256; - constexpr auto BLOCK_SIZE = kThreadsPerBlock; - __shared__ int s_histogram[RADIX]; - __shared__ uint8_t s_mapping_lut[256]; - __shared__ float s_mapping_quantiles[256]; - __shared__ float s_range_min, s_range_inv_range; - - const int bx = blockIdx.x; - const int tx = threadIdx.x; - - const int start = dense_kv_indptr[bx] + page_reserved_bos; - const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; - const int nblk = end - start; - - const ScoreT* __restrict__ score_blk = score + start; - - // Load mapping tables into shared memory if needed - if (mapping.mode == MAPPING_LUT_CDF && mapping.lut != nullptr) { - if (tx < 256) s_mapping_lut[tx] = 
mapping.lut[tx]; - __syncthreads(); - } - if (mapping.mode == MAPPING_QUANTILE && mapping.quantiles != nullptr) { - if (tx < 256) s_mapping_quantiles[tx] = mapping.quantiles[tx]; - __syncthreads(); - } - - // Pre-pass: compute per-block min/max for transform modes (supports sampled stride) - if (needs_auto_range(mapping.mode) && !mapping.noscale) { - const int stride = (mapping.sample_stride > 1) ? mapping.sample_stride : 1; - float local_min = __FLT_MAX__, local_max = -__FLT_MAX__; - for (int idx = tx * stride; idx < nblk; idx += BLOCK_SIZE * stride) { - float val = apply_transform(vortex_to_float(score_blk[idx]), mapping); - local_min = fminf(local_min, val); - local_max = fmaxf(local_max, val); - } - for (int offset = 16; offset > 0; offset >>= 1) { - local_min = fminf(local_min, __shfl_xor_sync(0xFFFFFFFF, local_min, offset)); - local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); - } - __shared__ float s_warp_mins[32], s_warp_maxs[32]; - int warp_id = tx >> 5, lane_id = tx & 31; - if (lane_id == 0) { s_warp_mins[warp_id] = local_min; s_warp_maxs[warp_id] = local_max; } - __syncthreads(); - if (tx < (BLOCK_SIZE >> 5)) { - local_min = s_warp_mins[tx]; local_max = s_warp_maxs[tx]; - for (int offset = 16; offset > 0; offset >>= 1) { - local_min = fminf(local_min, __shfl_xor_sync(0xFFFFFFFF, local_min, offset)); - local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); - } - if (tx == 0) { - s_range_min = local_min; - float range = local_max - local_min; - s_range_inv_range = (range > 0.0f) ? 
255.0f / range : 0.0f; - } - } - __syncthreads(); - } else if (needs_pivot(mapping.mode)) { - // Pivot pre-pass: compute mean for MAPPING_SUBTRACT - float local_sum = 0.0f; - for (int idx = tx; idx < nblk; idx += BLOCK_SIZE) { - local_sum += vortex_to_float(score_blk[idx]); - } - for (int offset = 16; offset > 0; offset >>= 1) { - local_sum += __shfl_xor_sync(0xFFFFFFFF, local_sum, offset); - } - __shared__ float s_warp_sums_h[32]; - int warp_id = tx >> 5, lane_id = tx & 31; - if (lane_id == 0) s_warp_sums_h[warp_id] = local_sum; - __syncthreads(); - if (tx < (BLOCK_SIZE >> 5)) { - local_sum = s_warp_sums_h[tx]; - for (int offset = 16; offset > 0; offset >>= 1) { - local_sum += __shfl_xor_sync(0xFFFFFFFF, local_sum, offset); - } - if (tx == 0) { - s_range_min = local_sum / float(nblk); - s_range_inv_range = 0.0f; - } - } - __syncthreads(); - } else if (needs_tail_window(mapping.mode)) { - // Adaptive tail-window pre-pass (histogram kernel variant) - constexpr int MAX_SAMPLES_H = 1024; - __shared__ float s_samples_h[MAX_SAMPLES_H]; - __shared__ int s_sample_count_h; - - if (tx == 0) s_sample_count_h = 0; - __syncthreads(); - - const int desired_stride = (nblk + MAX_SAMPLES_H - 1) / MAX_SAMPLES_H; - const int sample_stride_h = max(desired_stride, 1); - - float local_max = -__FLT_MAX__; - for (int idx = tx * sample_stride_h; idx < nblk; idx += BLOCK_SIZE * sample_stride_h) { - float val = vortex_to_float(score_blk[idx]); - local_max = fmaxf(local_max, val); - int slot = ::atomicAdd(&s_sample_count_h, 1); - if (slot < MAX_SAMPLES_H) s_samples_h[slot] = val; - } - - for (int offset = 16; offset > 0; offset >>= 1) - local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); - __shared__ float s_warp_maxs_h[32]; - { - int warp_id = tx >> 5, lane_id = tx & 31; - if (lane_id == 0) s_warp_maxs_h[warp_id] = local_max; - } - __syncthreads(); - if (tx < (BLOCK_SIZE >> 5)) { - local_max = s_warp_maxs_h[tx]; - for (int offset = 16; offset > 0; offset >>= 1) - 
local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); - if (tx == 0) s_warp_maxs_h[0] = local_max; - } - __syncthreads(); - local_max = s_warp_maxs_h[0]; - - int nsamp = min(s_sample_count_h, MAX_SAMPLES_H); - - __syncthreads(); - if (nsamp >= 2) { - for (int pass = 0; pass < nsamp; ++pass) { - if (tx * 2 + 1 < nsamp) { - int i = tx * 2; - if (s_samples_h[i] > s_samples_h[i + 1]) { - float tmp = s_samples_h[i]; - s_samples_h[i] = s_samples_h[i + 1]; - s_samples_h[i + 1] = tmp; - } - } - __syncthreads(); - if (tx * 2 + 2 < nsamp) { - int i = tx * 2 + 1; - if (s_samples_h[i] > s_samples_h[i + 1]) { - float tmp = s_samples_h[i]; - s_samples_h[i] = s_samples_h[i + 1]; - s_samples_h[i + 1] = tmp; - } - } - __syncthreads(); - } - } - - if (tx == 0) { - float rho = mapping.power_exp; - if (rho <= 0.0f) rho = 4.0f; - int k = mapping.target_k; - float frac = (k > 0 && nblk > 0) ? 1.0f - rho * float(k) / float(nblk) : 0.0f; - frac = fmaxf(frac, 0.0f); - - float tau_low; - if (nsamp < 4 || frac <= 0.0f) { - tau_low = -__FLT_MAX__; - } else { - float fidx = frac * float(nsamp - 1); - int lo = __float2int_rd(fidx); - lo = min(max(lo, 0), nsamp - 2); - float t = fidx - float(lo); - tau_low = s_samples_h[lo] * (1.0f - t) + s_samples_h[lo + 1] * t; - } - - if (tau_low >= local_max) { - tau_low = (nsamp > 0) ? s_samples_h[0] : local_max; - } - - float range = local_max - tau_low; - s_range_min = tau_low; - s_range_inv_range = (range > 1e-10f) ? 
255.0f / range : 0.0f; - } - __syncthreads(); - } else { - if (tx == 0) { s_range_min = 0.0f; s_range_inv_range = 0.0f; } - __syncthreads(); - } - - // Initialize shared histogram - if (tx < RADIX) s_histogram[tx] = 0; - __syncthreads(); - - // Build histogram over the segment with mapping - for (int idx = tx; idx < nblk; idx += BLOCK_SIZE) { - const auto bin = mapped_convert_to_uint8( - vortex_to_float(score_blk[idx]), - mapping, s_mapping_lut, s_mapping_quantiles, - s_range_min, s_range_inv_range); - ::atomicAdd(&s_histogram[bin], 1); - } - __syncthreads(); - - // Write to global memory - int* __restrict__ out = histograms + bx * RADIX; - if (tx < RADIX) { - out[tx] = s_histogram[tx]; + #define LAUNCH_ORI(BITS) \ + setup_kernel_smem_once, kSmem>(); \ + TopKOutput_Ori_Kernel<<>>( \ + score, dense_kv_indptr, sparse_kv_indptr, dense_kv_indices, sparse_kv_indices, \ + topk_val, reserved_bos, reserved_eos) + switch (radix_bits) { + case 4: LAUNCH_ORI(4); break; + case 5: LAUNCH_ORI(5); break; + case 6: LAUNCH_ORI(6); break; + case 7: LAUNCH_ORI(7); break; + case 9: LAUNCH_ORI(9); break; + case 10: LAUNCH_ORI(10); break; + default: LAUNCH_ORI(8); break; } + #undef LAUNCH_ORI } } // namespace @@ -1595,11 +1348,15 @@ void topk_output_sglang( const double mapping_power, std::optional mapping_lut, std::optional mapping_quantiles, - const bool mapping_noscale) + const bool mapping_noscale, + const int64_t sample_stride, + const int64_t radix_bits) { TORCH_CHECK(topk_val <= VORTEX_MAX_TOPK, "topk_output: topk_val (", topk_val, ") exceeds VORTEX_MAX_TOPK (", VORTEX_MAX_TOPK, ")"); + TORCH_CHECK(radix_bits >= 4 && radix_bits <= 10, + "topk_output: radix_bits must be 4-10, got ", radix_bits); // Build mapping params from optional tensors TopKMappingParams mapping{}; @@ -1608,7 +1365,7 @@ void topk_output_sglang( mapping.lut = nullptr; mapping.quantiles = nullptr; mapping.noscale = mapping_noscale; - mapping.sample_stride = 1; + mapping.sample_stride = 
static_cast(sample_stride); mapping.target_k = static_cast(topk_val); if (mapping_lut.has_value()) { @@ -1633,27 +1390,19 @@ void topk_output_sglang( // Fast path for mode 0 (MAPPING_NONE): use ori kernel with zero mapping overhead if (mapping_mode == MAPPING_NONE) { if (x.scalar_type() == at::ScalarType::BFloat16) { - setup_kernel_smem_once, kSmem>(); - TopKOutput_Ori_Kernel<__nv_bfloat16><<>>( + launch_ori_kernel<__nv_bfloat16>( reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), - dense_kv_indptr.data_ptr(), - sparse_kv_indptr.data_ptr(), - dense_kv_indices.data_ptr(), - sparse_kv_indices.data_ptr(), - topk_val, - reserved_bos, - reserved_eos); + dense_kv_indptr.data_ptr(), sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), sparse_kv_indices.data_ptr(), + topk_val, reserved_bos, reserved_eos, + radix_bits, nblks, nthreads, stream); } else if (x.scalar_type() == at::ScalarType::Float) { - setup_kernel_smem_once, kSmem>(); - TopKOutput_Ori_Kernel<<>>( + launch_ori_kernel( x.data_ptr(), - dense_kv_indptr.data_ptr(), - sparse_kv_indptr.data_ptr(), - dense_kv_indices.data_ptr(), - sparse_kv_indices.data_ptr(), - topk_val, - reserved_bos, - reserved_eos); + dense_kv_indptr.data_ptr(), sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), sparse_kv_indices.data_ptr(), + topk_val, reserved_bos, reserved_eos, + radix_bits, nblks, nthreads, stream); } else { TORCH_CHECK(false, "topk_output: unsupported dtype ", x.scalar_type()); } @@ -1705,11 +1454,14 @@ void topk_output_sglang_ori( const int64_t topk_val, const int64_t reserved_bos, const int64_t reserved_eos, - const int64_t max_num_pages) + const int64_t max_num_pages, + const int64_t radix_bits) { TORCH_CHECK(topk_val <= VORTEX_MAX_TOPK, "topk_output_sglang_ori: topk_val (", topk_val, ") exceeds VORTEX_MAX_TOPK (", VORTEX_MAX_TOPK, ")"); + TORCH_CHECK(radix_bits >= 4 && radix_bits <= 10, + "topk_output_sglang_ori: radix_bits must be 4-10, got ", radix_bits); CHECK_CUDA(x); CHECK_CUDA(dense_kv_indptr); @@ 
-1722,27 +1474,19 @@ void topk_output_sglang_ori( cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); if (x.scalar_type() == at::ScalarType::BFloat16) { - setup_kernel_smem_once, kSmem>(); - TopKOutput_Ori_Kernel<__nv_bfloat16><<>>( + launch_ori_kernel<__nv_bfloat16>( reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), - dense_kv_indptr.data_ptr(), - sparse_kv_indptr.data_ptr(), - dense_kv_indices.data_ptr(), - sparse_kv_indices.data_ptr(), - topk_val, - reserved_bos, - reserved_eos); + dense_kv_indptr.data_ptr(), sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), sparse_kv_indices.data_ptr(), + topk_val, reserved_bos, reserved_eos, + radix_bits, nblks, nthreads, stream); } else if (x.scalar_type() == at::ScalarType::Float) { - setup_kernel_smem_once, kSmem>(); - TopKOutput_Ori_Kernel<<>>( + launch_ori_kernel( x.data_ptr(), - dense_kv_indptr.data_ptr(), - sparse_kv_indptr.data_ptr(), - dense_kv_indices.data_ptr(), - sparse_kv_indices.data_ptr(), - topk_val, - reserved_bos, - reserved_eos); + dense_kv_indptr.data_ptr(), sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), sparse_kv_indices.data_ptr(), + topk_val, reserved_bos, reserved_eos, + radix_bits, nblks, nthreads, stream); } else { TORCH_CHECK(false, "topk_output_sglang_ori: unsupported dtype ", x.scalar_type()); } @@ -1751,262 +1495,3 @@ void topk_output_sglang_ori( TORCH_CHECK(result == cudaSuccess, "topk_output_sglang_ori kernel failed: ", ::cudaGetErrorString(result)); } - -// ====================================================================== -// Profiling: collect per-segment 256-bin histograms of Stage 1 bins -// ====================================================================== -void topk_profile_histogram( - const at::Tensor& x, - const at::Tensor& dense_kv_indptr, - at::Tensor& histograms, - const int64_t eff_batch_size, - const int64_t reserved_bos, - const int64_t reserved_eos, - const int64_t mapping_mode, - const double mapping_power, - std::optional mapping_lut, 
- std::optional mapping_quantiles, - const bool mapping_noscale, - const int64_t topk_val) -{ - CHECK_CUDA(x); - CHECK_CUDA(dense_kv_indptr); - CHECK_CUDA(histograms); - TORCH_CHECK(histograms.dim() == 2 && histograms.size(0) == eff_batch_size - && histograms.size(1) == 256, - "histograms must be [eff_batch_size, 256]"); - TORCH_CHECK(histograms.scalar_type() == at::ScalarType::Int, - "histograms must be int32"); - - // Build mapping params - TopKMappingParams mapping{}; - mapping.mode = static_cast(mapping_mode); - mapping.power_exp = static_cast(mapping_power); - mapping.lut = nullptr; - mapping.quantiles = nullptr; - mapping.noscale = mapping_noscale; - mapping.sample_stride = 1; - mapping.target_k = static_cast(topk_val); - - if (mapping_lut.has_value()) { - const auto& lut = mapping_lut.value(); - CHECK_CUDA(lut); - TORCH_CHECK(lut.dim() == 1 && lut.size(0) == 256 && lut.scalar_type() == at::ScalarType::Byte, - "mapping_lut must be a 1D uint8 tensor of size 256"); - mapping.lut = lut.data_ptr(); - } - if (mapping_quantiles.has_value()) { - const auto& q = mapping_quantiles.value(); - CHECK_CUDA(q); - TORCH_CHECK(q.dim() == 1 && q.size(0) == 256 && q.scalar_type() == at::ScalarType::Float, - "mapping_quantiles must be a 1D float32 tensor of size 256"); - mapping.quantiles = q.data_ptr(); - } - - dim3 nblks(eff_batch_size); - dim3 nthreads(kThreadsPerBlock); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - if (x.scalar_type() == at::ScalarType::BFloat16) { - TopKHistogram_Kernel<__nv_bfloat16><<>>( - reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), - dense_kv_indptr.data_ptr(), - histograms.data_ptr(), - reserved_bos, - reserved_eos, - mapping); - } else if (x.scalar_type() == at::ScalarType::Float) { - TopKHistogram_Kernel<<>>( - x.data_ptr(), - dense_kv_indptr.data_ptr(), - histograms.data_ptr(), - reserved_bos, - reserved_eos, - mapping); - } else { - TORCH_CHECK(false, - "topk_profile_histogram: unsupported dtype ", - x.scalar_type()); - 
} - - const auto result = cudaGetLastError(); - TORCH_CHECK(result == cudaSuccess, - "topk_profile_histogram kernel failed: ", ::cudaGetErrorString(result)); -} - -// Helper: build TopKMappingParams from host arguments -static TopKMappingParams build_mapping_params( - int64_t mapping_mode, double mapping_power, - std::optional& mapping_lut, - std::optional& mapping_quantiles, - bool mapping_noscale = false, - int sample_stride = 1, - int target_k = 0) -{ - TopKMappingParams mapping{}; - mapping.mode = static_cast(mapping_mode); - mapping.power_exp = static_cast(mapping_power); - mapping.lut = nullptr; - mapping.quantiles = nullptr; - mapping.noscale = mapping_noscale; - mapping.sample_stride = sample_stride; - mapping.target_k = target_k; - - if (mapping_lut.has_value()) { - const auto& lut = mapping_lut.value(); - CHECK_CUDA(lut); - TORCH_CHECK(lut.dim() == 1 && lut.size(0) == 256 && lut.scalar_type() == at::ScalarType::Byte, - "mapping_lut must be a 1D uint8 tensor of size 256"); - mapping.lut = lut.data_ptr(); - } - if (mapping_quantiles.has_value()) { - const auto& q = mapping_quantiles.value(); - CHECK_CUDA(q); - TORCH_CHECK(q.dim() == 1 && q.size(0) == 256 && q.scalar_type() == at::ScalarType::Float, - "mapping_quantiles must be a 1D float32 tensor of size 256"); - mapping.quantiles = q.data_ptr(); - } - return mapping; -} - -// ====================================================================== -// Profiling: Stage 1 only (pre-pass + hist + cumsum + route/filter) -// ====================================================================== -void topk_profile_stage1( - const at::Tensor& x, - const at::Tensor& dense_kv_indptr, - const at::Tensor& sparse_kv_indptr, - const at::Tensor& dense_kv_indices, - at::Tensor& sparse_kv_indices, - const int64_t eff_batch_size, - const int64_t topk_val, - const int64_t reserved_bos, - const int64_t reserved_eos, - const int64_t max_num_pages, - const int64_t mapping_mode, - const double mapping_power, - std::optional 
mapping_lut, - std::optional mapping_quantiles, - const bool mapping_noscale) -{ - TORCH_CHECK(topk_val <= VORTEX_MAX_TOPK, - "topk_profile_stage1: topk_val (", topk_val, - ") exceeds VORTEX_MAX_TOPK (", VORTEX_MAX_TOPK, ")"); - - auto mapping = build_mapping_params(mapping_mode, mapping_power, mapping_lut, mapping_quantiles, - mapping_noscale, /*sample_stride=*/1, /*target_k=*/static_cast(topk_val)); - - dim3 nblks(eff_batch_size); - dim3 nthreads(kThreadsPerBlock); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - if (x.scalar_type() == at::ScalarType::BFloat16) { - setup_kernel_smem_once, kSmem>(); - TopKStage1_Kernel<__nv_bfloat16><<>>( - reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), - dense_kv_indptr.data_ptr(), - sparse_kv_indptr.data_ptr(), - dense_kv_indices.data_ptr(), - sparse_kv_indices.data_ptr(), - topk_val, - reserved_bos, - reserved_eos, - mapping); - } else if (x.scalar_type() == at::ScalarType::Float) { - setup_kernel_smem_once, kSmem>(); - TopKStage1_Kernel<<>>( - x.data_ptr(), - dense_kv_indptr.data_ptr(), - sparse_kv_indptr.data_ptr(), - dense_kv_indices.data_ptr(), - sparse_kv_indices.data_ptr(), - topk_val, - reserved_bos, - reserved_eos, - mapping); - } else { - TORCH_CHECK(false, - "topk_profile_stage1: unsupported dtype ", - x.scalar_type()); - } - - const auto result = cudaGetLastError(); - TORCH_CHECK(result == cudaSuccess, - "topk_profile_stage1 kernel failed: ", ::cudaGetErrorString(result)); -} - -// ====================================================================== -// Profiling: full pipeline + diagnostic counters -// ====================================================================== -void topk_profile_counters( - const at::Tensor& x, - const at::Tensor& dense_kv_indptr, - const at::Tensor& sparse_kv_indptr, - const at::Tensor& dense_kv_indices, - at::Tensor& sparse_kv_indices, - at::Tensor& counters, - const int64_t eff_batch_size, - const int64_t topk_val, - const int64_t reserved_bos, - const int64_t 
reserved_eos, - const int64_t max_num_pages, - const int64_t mapping_mode, - const double mapping_power, - std::optional mapping_lut, - std::optional mapping_quantiles, - const bool mapping_noscale) -{ - TORCH_CHECK(topk_val <= VORTEX_MAX_TOPK, - "topk_profile_counters: topk_val (", topk_val, - ") exceeds VORTEX_MAX_TOPK (", VORTEX_MAX_TOPK, ")"); - CHECK_CUDA(counters); - TORCH_CHECK(counters.dim() == 2 && counters.size(0) == eff_batch_size - && counters.size(1) == NUM_TOPK_COUNTERS, - "counters must be [eff_batch_size, ", NUM_TOPK_COUNTERS, "]"); - TORCH_CHECK(counters.scalar_type() == at::ScalarType::Int, - "counters must be int32"); - - auto mapping = build_mapping_params(mapping_mode, mapping_power, mapping_lut, mapping_quantiles, - mapping_noscale, /*sample_stride=*/1, /*target_k=*/static_cast(topk_val)); - - dim3 nblks(eff_batch_size); - dim3 nthreads(kThreadsPerBlock); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - if (x.scalar_type() == at::ScalarType::BFloat16) { - setup_kernel_smem_once, kSmem>(); - TopKCounters_Kernel<__nv_bfloat16><<>>( - reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), - dense_kv_indptr.data_ptr(), - sparse_kv_indptr.data_ptr(), - dense_kv_indices.data_ptr(), - sparse_kv_indices.data_ptr(), - counters.data_ptr(), - topk_val, - reserved_bos, - reserved_eos, - mapping); - } else if (x.scalar_type() == at::ScalarType::Float) { - setup_kernel_smem_once, kSmem>(); - TopKCounters_Kernel<<>>( - x.data_ptr(), - dense_kv_indptr.data_ptr(), - sparse_kv_indptr.data_ptr(), - dense_kv_indices.data_ptr(), - sparse_kv_indices.data_ptr(), - counters.data_ptr(), - topk_val, - reserved_bos, - reserved_eos, - mapping); - } else { - TORCH_CHECK(false, - "topk_profile_counters: unsupported dtype ", - x.scalar_type()); - } - - const auto result = cudaGetLastError(); - TORCH_CHECK(result == cudaSuccess, - "topk_profile_counters kernel failed: ", ::cudaGetErrorString(result)); -} - diff --git a/csrc/topk_sglang_profile.cu 
b/csrc/topk_sglang_profile.cu new file mode 100644 index 0000000..6aeac4b --- /dev/null +++ b/csrc/topk_sglang_profile.cu @@ -0,0 +1,1203 @@ +/** + * TopK profiling kernels: histogram collection, stage-1-only timing, + * and diagnostic counter collection. + * + * Separated from topk_sglang.cu to reduce template instantiation + * pressure on CUDA shared memory resources. + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace { + + +constexpr int TopK = 2048; +constexpr int kThreadsPerBlock = 1024; + +#ifdef USE_ROCM +// On ROCm, the per-workgroup LDS budget depends on the target arch, so we inject a +// per-arch value from `setup_rocm.py` via `-DSGL_TOPK_DYNAMIC_SMEM_BYTES=...`. +#ifdef SGL_TOPK_DYNAMIC_SMEM_BYTES +constexpr size_t kSmem = static_cast(SGL_TOPK_DYNAMIC_SMEM_BYTES); +#else +constexpr size_t kSmem = 48 * 1024; // bytes +#endif +#else +// Reduced from 128KB to 32KB to improve occupancy. +// Each radix pass needs at most ~TopK candidates in the threshold bin, +// so 4K entries per round (2 rounds = 8K entries = 32KB) is sufficient. +constexpr size_t kSmem = 8 * 1024 * sizeof(uint32_t); // 32KB (bytes) +#endif + +struct FastTopKParams { + const float* __restrict__ input; // [B, input_stride] + const int32_t* __restrict__ row_starts; // [B] + int32_t* __restrict__ indices; // [B, TopK] + int32_t* __restrict__ lengths; // [B] + int64_t input_stride; +}; + +// when length <= TopK, we can directly write the indices +__device__ void naive_topk_cuda(const float* __restrict__ score, int32_t* __restrict__ indice, int32_t length) { + const auto tid = threadIdx.x; + for (int i = tid; i < TopK; i += kThreadsPerBlock) { + indice[i] = (i < length) ? 
i : -1; + } +} + +// keep the first `length` entries, set others to -1 +__device__ void naive_topk_transform( + const float* __restrict__ score, + int32_t length, + int32_t* __restrict__ dst_page_table, + const int32_t* __restrict__ src_page_table) { + const auto tid = threadIdx.x; + for (auto i = tid; i < TopK; i += kThreadsPerBlock) { + dst_page_table[i] = (i < length) ? src_page_table[i] : -1; + } +} + +// keep the first `length` entries, set others to -1 +__device__ void naive_topk_transform_ragged( + const float* __restrict__ score, int32_t length, int32_t* __restrict__ topk_indices_ragged, int32_t offset) { + const auto tid = threadIdx.x; + for (auto i = tid; i < TopK; i += kThreadsPerBlock) { + topk_indices_ragged[i] = (i < length) ? static_cast(i) + offset : -1; + } +} + +__device__ __forceinline__ auto convert_to_uint8(float x) -> uint8_t { + __half h = __float2half_rn(x); + uint16_t bits = __half_as_ushort(h); + uint16_t key = (bits & 0x8000) ? static_cast(~bits) : static_cast(bits | 0x8000); + return static_cast(key >> 8); +} + +__device__ __forceinline__ auto convert_to_uint32(float x) -> uint32_t { + uint32_t bits = __float_as_uint(x); + return (bits & 0x80000000u) ? 
~bits : (bits | 0x80000000u); +} + +template +__device__ __forceinline__ float vortex_to_float(T x); +template <> +__device__ __forceinline__ float vortex_to_float(float x) { return x; } +template <> +__device__ __forceinline__ float vortex_to_float<__nv_bfloat16>(__nv_bfloat16 x) { + return __bfloat162float(x); +} + +constexpr int VORTEX_MAX_TOPK = 2048; +constexpr int COUNTER_THRESHOLD_BIN = 0; +constexpr int COUNTER_NUM_ABOVE = 1; +constexpr int COUNTER_NUM_EQUAL = 2; +constexpr int COUNTER_REMAINING_K = 3; +constexpr int COUNTER_REFINE_ROUNDS = 4; +constexpr int COUNTER_STAGE2_INPUT = 5; +constexpr int NUM_TOPK_COUNTERS = 6; + +#include "topk_mapping.cuh" + +// - mapping: configurable value-remapping for Stage 1 bin assignment +template +__device__ void fast_topk_vortex( + const ScoreT* __restrict__ input, + int* __restrict__ index, + int row_start, + int length, + int target_k, + const TopKMappingParams& mapping, + int* counters = nullptr) +{ + int topk = target_k; + constexpr auto BLOCK_SIZE = 1024; + constexpr auto RADIX = 256; + constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); + + alignas(128) __shared__ int vh_histogram_buf[2][RADIX + 128]; + alignas(128) __shared__ int vh_counter; + alignas(128) __shared__ int vh_threshold_bin_id; + alignas(128) __shared__ int vh_num_input[2]; + + // Shared memory for mapping LUT / quantiles (loaded once per block) + __shared__ uint8_t s_mapping_lut[256]; + __shared__ float s_mapping_quantiles[256]; + + // Auto-range for transform modes (3/4/6/7) + __shared__ float s_range_min, s_range_inv_range; + + auto& vh_histogram = vh_histogram_buf[0]; + extern __shared__ int vh_input_idx[][SMEM_INPUT_SIZE]; + + const int tx = threadIdx.x; + + // Load mapping tables into shared memory if needed + if (mapping.mode == MAPPING_LUT_CDF && mapping.lut != nullptr) { + if (tx < 256) s_mapping_lut[tx] = mapping.lut[tx]; + __syncthreads(); + } + if (mapping.mode == MAPPING_QUANTILE && mapping.quantiles != nullptr) { + if (tx < 
256) s_mapping_quantiles[tx] = mapping.quantiles[tx]; + __syncthreads(); + } + + // Pre-pass: compute per-block min/max of transformed values for linear bucketing. + // sample_stride > 1 reduces pre-pass cost by scanning every Nth element; + // the approximated range may miss extreme outliers but Stage 2 uses raw + // float bits for exact ordering, so correctness is preserved. + if (needs_auto_range(mapping.mode) && !mapping.noscale) { + const int stride = (mapping.sample_stride > 1) ? mapping.sample_stride : 1; + float local_min = __FLT_MAX__, local_max = -__FLT_MAX__; + for (int idx = tx * stride; idx < length; idx += BLOCK_SIZE * stride) { + float val = apply_transform(vortex_to_float(input[idx + row_start]), mapping); + local_min = fminf(local_min, val); + local_max = fmaxf(local_max, val); + } + // Warp-level reduction + for (int offset = 16; offset > 0; offset >>= 1) { + local_min = fminf(local_min, __shfl_xor_sync(0xFFFFFFFF, local_min, offset)); + local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); + } + // Cross-warp reduction via shared memory + __shared__ float s_warp_mins[32], s_warp_maxs[32]; + int warp_id = tx >> 5, lane_id = tx & 31; + if (lane_id == 0) { s_warp_mins[warp_id] = local_min; s_warp_maxs[warp_id] = local_max; } + __syncthreads(); + if (tx < (BLOCK_SIZE >> 5)) { + local_min = s_warp_mins[tx]; local_max = s_warp_maxs[tx]; + for (int offset = 16; offset > 0; offset >>= 1) { + local_min = fminf(local_min, __shfl_xor_sync(0xFFFFFFFF, local_min, offset)); + local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); + } + if (tx == 0) { + s_range_min = local_min; + float range = local_max - local_min; + s_range_inv_range = (range > 0.0f) ? 255.0f / range : 0.0f; + } + } + __syncthreads(); + } else if (needs_pivot(mapping.mode)) { + // Pivot pre-pass: compute mean of all elements, store in s_range_min. 
+ // MAPPING_SUBTRACT uses convert_to_uint8(x - range_min), so centering + // around the mean helps distribute values more evenly across bins. + float local_sum = 0.0f; + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + local_sum += vortex_to_float(input[idx + row_start]); + } + // Warp-level reduction + for (int offset = 16; offset > 0; offset >>= 1) { + local_sum += __shfl_xor_sync(0xFFFFFFFF, local_sum, offset); + } + __shared__ float s_warp_sums[32]; + int warp_id = tx >> 5, lane_id = tx & 31; + if (lane_id == 0) s_warp_sums[warp_id] = local_sum; + __syncthreads(); + if (tx < (BLOCK_SIZE >> 5)) { + local_sum = s_warp_sums[tx]; + for (int offset = 16; offset > 0; offset >>= 1) { + local_sum += __shfl_xor_sync(0xFFFFFFFF, local_sum, offset); + } + if (tx == 0) { + s_range_min = local_sum / float(length); // mean as pivot + s_range_inv_range = 0.0f; + } + } + __syncthreads(); + } else if (needs_tail_window(mapping.mode)) { + // Adaptive tail-window pre-pass: estimate tau_low = Q(1 - rho*k/n) + // and local_max via a sampled quantile estimator. All 256 coarse bins + // are then allocated to [tau_low, local_max]; scores below tau_low + // collapse into bin 0 via linear_map_to_uint8 clamping. 
+ constexpr int MAX_SAMPLES = 1024; + __shared__ float s_samples[MAX_SAMPLES]; + __shared__ int s_sample_count; + + if (tx == 0) s_sample_count = 0; + __syncthreads(); + + // Compute sampling stride so we collect ~MAX_SAMPLES from the segment + const int desired_stride = (length + MAX_SAMPLES - 1) / MAX_SAMPLES; + const int sample_stride = max(desired_stride, 1); + + // Each thread samples elements and finds local_max simultaneously + float local_max = -__FLT_MAX__; + for (int idx = tx * sample_stride; idx < length; idx += BLOCK_SIZE * sample_stride) { + float val = vortex_to_float(input[idx + row_start]); + local_max = fmaxf(local_max, val); + int slot = ::atomicAdd(&s_sample_count, 1); + if (slot < MAX_SAMPLES) { + s_samples[slot] = val; + } + } + + // Reduce local_max across block + for (int offset = 16; offset > 0; offset >>= 1) + local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); + __shared__ float s_warp_maxs_tw[32]; + { + int warp_id = tx >> 5, lane_id = tx & 31; + if (lane_id == 0) s_warp_maxs_tw[warp_id] = local_max; + } + __syncthreads(); + if (tx < (BLOCK_SIZE >> 5)) { + local_max = s_warp_maxs_tw[tx]; + for (int offset = 16; offset > 0; offset >>= 1) + local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); + if (tx == 0) s_warp_maxs_tw[0] = local_max; + } + __syncthreads(); + local_max = s_warp_maxs_tw[0]; + + int nsamp = min(s_sample_count, MAX_SAMPLES); + + // Simple odd-even transposition sort on the sample buffer. + // nsamp <= 1024, and we have 1024 threads, so each thread + // handles one element. O(nsamp) parallel rounds suffice. + __syncthreads(); + if (nsamp >= 2) { + for (int pass = 0; pass < nsamp; ++pass) { + // Even phase: compare (0,1), (2,3), ... + if (tx * 2 + 1 < nsamp) { + int i = tx * 2; + if (s_samples[i] > s_samples[i + 1]) { + float tmp = s_samples[i]; + s_samples[i] = s_samples[i + 1]; + s_samples[i + 1] = tmp; + } + } + __syncthreads(); + // Odd phase: compare (1,2), (3,4), ... 
+ if (tx * 2 + 2 < nsamp) { + int i = tx * 2 + 1; + if (s_samples[i] > s_samples[i + 1]) { + float tmp = s_samples[i]; + s_samples[i] = s_samples[i + 1]; + s_samples[i + 1] = tmp; + } + } + __syncthreads(); + } + } + + // Estimate tau_low = Q(1 - rho * k / n) + if (tx == 0) { + float rho = mapping.power_exp; // reused as tail expansion factor + if (rho <= 0.0f) rho = 4.0f; + int k = (mapping.target_k > 0) ? mapping.target_k : target_k; + float frac = 1.0f - rho * float(k) / float(length); + frac = fmaxf(frac, 0.0f); // clamp: never go below rank 0 + + float tau_low; + if (nsamp < 4 || frac <= 0.0f) { + // Too few samples or the tail covers everything: full range + tau_low = -__FLT_MAX__; + } else { + float fidx = frac * float(nsamp - 1); + int lo = __float2int_rd(fidx); + lo = min(max(lo, 0), nsamp - 2); + float t = fidx - float(lo); + tau_low = s_samples[lo] * (1.0f - t) + s_samples[lo + 1] * t; + } + + // Fallback: if tau_low >= local_max, use full-range linear mapping + if (tau_low >= local_max) { + // Find the actual minimum from sorted samples + tau_low = (nsamp > 0) ? s_samples[0] : local_max; + } + + float range = local_max - tau_low; + s_range_min = tau_low; + s_range_inv_range = (range > 1e-10f) ? 255.0f / range : 0.0f; + } + __syncthreads(); + } else if (needs_topk_window(mapping.mode)) { + // Topk-window pre-pass with streaming variance heuristic. 
+ float local_max = -__FLT_MAX__; + float local_sum = 0.0f, local_sum_sq = 0.0f; + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + float val = vortex_to_float(input[idx + row_start]); + local_max = fmaxf(local_max, val); + local_sum += val; + local_sum_sq += val * val; + } + for (int offset = 16; offset > 0; offset >>= 1) { + local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); + local_sum += __shfl_xor_sync(0xFFFFFFFF, local_sum, offset); + local_sum_sq += __shfl_xor_sync(0xFFFFFFFF, local_sum_sq, offset); + } + __shared__ float s_warp_maxs_tw2[32], s_warp_sums_tw2[32], s_warp_sq_tw2[32]; + { + int warp_id = tx >> 5, lane_id = tx & 31; + if (lane_id == 0) { + s_warp_maxs_tw2[warp_id] = local_max; + s_warp_sums_tw2[warp_id] = local_sum; + s_warp_sq_tw2[warp_id] = local_sum_sq; + } + } + __syncthreads(); + if (tx < (BLOCK_SIZE >> 5)) { + local_max = s_warp_maxs_tw2[tx]; + local_sum = s_warp_sums_tw2[tx]; + local_sum_sq = s_warp_sq_tw2[tx]; + for (int offset = 16; offset > 0; offset >>= 1) { + local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); + local_sum += __shfl_xor_sync(0xFFFFFFFF, local_sum, offset); + local_sum_sq += __shfl_xor_sync(0xFFFFFFFF, local_sum_sq, offset); + } + if (tx == 0) { + float rho = mapping.power_exp; + if (rho <= 0.0f) rho = 4.0f; + int k = (mapping.target_k > 0) ? mapping.target_k : target_k; + float n = float(length); + float mean = local_sum / n; + float var = local_sum_sq / n - mean * mean; + float sigma = (var > 0.0f) ? sqrtf(var) : 0.0f; + float ratio = n / fmaxf(float(k), 1.0f); + float z = sqrtf(2.0f * __logf(fmaxf(ratio, 1.0f))); + float tau_low = local_max - rho * sigma * z; + if (tau_low >= local_max) tau_low = local_max - 1.0f; + float range = local_max - tau_low; + s_range_min = tau_low; + s_range_inv_range = (range > 1e-10f) ? 
255.0f / range : 0.0f; + } + } + __syncthreads(); + } else { + if (tx == 0) { s_range_min = 0.0f; s_range_inv_range = 0.0f; } + __syncthreads(); + } + + // Stage 1: 8-bit coarse histogram (with optional mapping) + // Bin cache: store computed bins in vh_input_idx[1] (reinterpreted as uint8_t*) + // to avoid recomputing mapped_convert_to_uint8 in the route/filter pass. + // vh_input_idx[1] is unused until Stage 2 double-buffering starts after route. + constexpr int BIN_CACHE_CAPACITY = SMEM_INPUT_SIZE * static_cast(sizeof(int)); // uint8 entries + uint8_t* bin_cache = reinterpret_cast(vh_input_idx[1]); + const bool use_bin_cache = (length <= BIN_CACHE_CAPACITY); + + if (tx < RADIX + 1) vh_histogram[tx] = 0; + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = mapped_convert_to_uint8( + vortex_to_float(input[idx + row_start]), + mapping, s_mapping_lut, s_mapping_quantiles, + s_range_min, s_range_inv_range); + ::atomicAdd(&vh_histogram[bin], 1); + if (use_bin_cache) { + bin_cache[idx] = bin; + } + } + __syncthreads(); + + const auto run_cumsum = [&] { +#pragma unroll 8 + for (int i = 0; i < 8; ++i) { + static_assert(1 << 8 == RADIX); + if (C10_LIKELY(tx < RADIX)) { + const auto j = 1 << i; + const auto k = i & 1; + auto value = vh_histogram_buf[k][tx]; + if (tx < RADIX - j) { + value += vh_histogram_buf[k][tx + j]; + } + vh_histogram_buf[k ^ 1][tx] = value; + } + __syncthreads(); + } + }; + + run_cumsum(); + if (tx < RADIX && vh_histogram[tx] > topk && vh_histogram[tx + 1] <= topk) { + vh_threshold_bin_id = tx; + vh_num_input[0] = 0; + vh_counter = 0; + } + __syncthreads(); + + const auto threshold_bin = vh_threshold_bin_id; + topk -= vh_histogram[threshold_bin + 1]; + + if (WriteCounters && tx == 0 && counters) { + counters[COUNTER_THRESHOLD_BIN] = threshold_bin; + counters[COUNTER_REMAINING_K] = topk; + } + + if (topk == 0) { + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + int bin; + if (use_bin_cache) { + bin = 
static_cast(bin_cache[idx]); + } else { + bin = static_cast( + mapped_convert_to_uint8( + vortex_to_float(input[idx + row_start]), + mapping, s_mapping_lut, s_mapping_quantiles, + s_range_min, s_range_inv_range)); + } + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&vh_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + if (WriteCounters && tx == 0 && counters) { + counters[COUNTER_NUM_ABOVE] = vh_counter; + counters[COUNTER_NUM_EQUAL] = 0; + counters[COUNTER_REFINE_ROUNDS] = 0; + counters[COUNTER_STAGE2_INPUT] = 0; + } + return; + } else { + __syncthreads(); + if (tx < RADIX + 1) vh_histogram[tx] = 0; + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto raw_input = vortex_to_float(input[idx + row_start]); + int bin; + if (use_bin_cache) { + bin = static_cast(bin_cache[idx]); + } else { + bin = static_cast( + mapped_convert_to_uint8(raw_input, mapping, + s_mapping_lut, s_mapping_quantiles, + s_range_min, s_range_inv_range)); + } + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&vh_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + const auto pos = ::atomicAdd(&vh_num_input[0], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + vh_input_idx[0][pos] = idx; + const auto b32 = convert_to_uint32(raw_input); + const auto sub_bin = (b32 >> 24) & 0xFF; + ::atomicAdd(&vh_histogram[sub_bin], 1); + } + } + } + __syncthreads(); + if (WriteCounters && tx == 0 && counters) { + counters[COUNTER_NUM_ABOVE] = vh_counter; + counters[COUNTER_NUM_EQUAL] = vh_num_input[0]; + counters[COUNTER_STAGE2_INPUT] = vh_num_input[0]; + } + if (StopAfterStage1) return; + } + + // Stage 2: refine with 8-bit radix passes (unchanged — uses raw float bits) + if constexpr (WriteCounters) { + // Default: all 4 rounds used; overwritten at break if resolved early + if (tx == 0 && counters) counters[COUNTER_REFINE_ROUNDS] = 4; + } +#pragma unroll 4 + for (int round = 0; round < 4; ++round) { + __shared__ int 
vh_last_remain; + const auto r_idx = round % 2; + + const auto _raw_num_input = vh_num_input[r_idx]; + const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) + ? _raw_num_input + : int(SMEM_INPUT_SIZE); + + run_cumsum(); + if (tx < RADIX && vh_histogram[tx] > topk && vh_histogram[tx + 1] <= topk) { + vh_threshold_bin_id = tx; + vh_num_input[r_idx ^ 1] = 0; + vh_last_remain = topk - vh_histogram[tx + 1]; + } + __syncthreads(); + + const auto threshold_bin = vh_threshold_bin_id; + topk -= vh_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = vh_input_idx[r_idx][i]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32( + vortex_to_float(input[idx + row_start])) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&vh_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + if constexpr (WriteCounters) { + if (tx == 0 && counters) { + counters[COUNTER_REFINE_ROUNDS] = round + 1; + } + } + break; + } else { + __syncthreads(); + if (tx < RADIX + 1) vh_histogram[tx] = 0; + __syncthreads(); + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = vh_input_idx[r_idx][i]; + const auto raw_input = vortex_to_float(input[idx + row_start]); + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&vh_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + if (round == 3) { + const auto pos = ::atomicAdd(&vh_last_remain, -1); + if (pos > 0) { + index[target_k - pos] = idx; + } + } else { + const auto pos = ::atomicAdd(&vh_num_input[r_idx ^ 1], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + vh_input_idx[r_idx ^ 1][pos] = idx; + const auto b32 = convert_to_uint32(raw_input); + const auto sub_bin = (b32 >> (offset - 8)) & 0xFF; + ::atomicAdd(&vh_histogram[sub_bin], 1); + } + } + } + } + __syncthreads(); + } + } 
+} + +template +void setup_kernel_smem_once() { + [[maybe_unused]] + static const auto result = [] { +#ifdef USE_ROCM + // hipify will turn cudaFuncSetAttribute -> hipFuncSetAttribute. On ROCm, + // hipFuncSetAttribute expects `const void*` and hipcc does not accept passing + // a function pointer directly, so cast explicitly. + return ::cudaFuncSetAttribute( + reinterpret_cast(f), ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); +#else + // CUDA: keep original behavior (no cast needed). + return ::cudaFuncSetAttribute(f, ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); +#endif + }(); + TORCH_CHECK(result == cudaSuccess, "set_up_kernel_once failed:", ::cudaGetErrorString(result)); +} + +// ====================================================================== +template +__global__ __launch_bounds__(kThreadsPerBlock) +void TopKStage1_Kernel( + const ScoreT* __restrict__ score, + const int* __restrict__ dense_kv_indptr, + const int* __restrict__ sparse_kv_indptr, + const int* __restrict__ dense_kv_indices, + int* __restrict__ sparse_kv_indices, + const int topk_val, + const int page_reserved_bos, + const int page_reserved_eos, + const TopKMappingParams mapping) +{ + const int bx = blockIdx.x; + + const int start = dense_kv_indptr[bx] + page_reserved_bos; + const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; + const int nblk = end - start; + if (nblk <= topk_val) return; + + const ScoreT* __restrict__ score_blk = score + start; + const int* __restrict__ idx_blk = dense_kv_indices + start; + int* __restrict__ out_blk = sparse_kv_indices + + sparse_kv_indptr[bx] + + page_reserved_bos; + + __shared__ int s_indices[VORTEX_MAX_TOPK]; + fast_topk_vortex( + score_blk, s_indices, 0, nblk, topk_val, mapping); + __syncthreads(); + + // Remap position indices -> page indices via dense_kv_indices + const int tx = threadIdx.x; + for (int i = tx; i < topk_val; i += kThreadsPerBlock) { + out_blk[i] = idx_blk[s_indices[i]]; + } +} + +// 
====================================================================== +// Profiling counters kernel: runs full pipeline + writes diagnostic +// counters to a separate global-memory tensor +// ====================================================================== +template +__global__ __launch_bounds__(kThreadsPerBlock) +void TopKCounters_Kernel( + const ScoreT* __restrict__ score, + const int* __restrict__ dense_kv_indptr, + const int* __restrict__ sparse_kv_indptr, + const int* __restrict__ dense_kv_indices, + int* __restrict__ sparse_kv_indices, + int* __restrict__ counters, // [eff_batch_size, NUM_TOPK_COUNTERS] + const int topk_val, + const int page_reserved_bos, + const int page_reserved_eos, + const TopKMappingParams mapping) +{ + const int bx = blockIdx.x; + + const int start = dense_kv_indptr[bx] + page_reserved_bos; + const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; + const int nblk = end - start; + if (nblk <= topk_val) return; + + const ScoreT* __restrict__ score_blk = score + start; + const int* __restrict__ idx_blk = dense_kv_indices + start; + int* __restrict__ out_blk = sparse_kv_indices + + sparse_kv_indptr[bx] + + page_reserved_bos; + + __shared__ int s_indices[VORTEX_MAX_TOPK]; + fast_topk_vortex( + score_blk, s_indices, 0, nblk, topk_val, mapping, + counters + bx * NUM_TOPK_COUNTERS); + __syncthreads(); + + // Remap position indices -> page indices via dense_kv_indices + const int tx = threadIdx.x; + for (int i = tx; i < topk_val; i += kThreadsPerBlock) { + out_blk[i] = idx_blk[s_indices[i]]; + } +} + +// ====================================================================== +// Profiling histogram kernel: runs only Stage 1 and returns per-segment +// 256-bin histograms for distribution analysis +// ====================================================================== +template +__global__ __launch_bounds__(kThreadsPerBlock) +void TopKHistogram_Kernel( + const ScoreT* __restrict__ score, + const int* __restrict__ dense_kv_indptr, + 
int* __restrict__ histograms, // [eff_batch_size, 256] + const int page_reserved_bos, + const int page_reserved_eos, + const TopKMappingParams mapping) +{ + constexpr auto RADIX = 256; + constexpr auto BLOCK_SIZE = kThreadsPerBlock; + __shared__ int s_histogram[RADIX]; + __shared__ uint8_t s_mapping_lut[256]; + __shared__ float s_mapping_quantiles[256]; + __shared__ float s_range_min, s_range_inv_range; + + const int bx = blockIdx.x; + const int tx = threadIdx.x; + + const int start = dense_kv_indptr[bx] + page_reserved_bos; + const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; + const int nblk = end - start; + + const ScoreT* __restrict__ score_blk = score + start; + + // Load mapping tables into shared memory if needed + if (mapping.mode == MAPPING_LUT_CDF && mapping.lut != nullptr) { + if (tx < 256) s_mapping_lut[tx] = mapping.lut[tx]; + __syncthreads(); + } + if (mapping.mode == MAPPING_QUANTILE && mapping.quantiles != nullptr) { + if (tx < 256) s_mapping_quantiles[tx] = mapping.quantiles[tx]; + __syncthreads(); + } + + // Pre-pass: compute per-block min/max for transform modes (supports sampled stride) + if (needs_auto_range(mapping.mode) && !mapping.noscale) { + const int stride = (mapping.sample_stride > 1) ? 
mapping.sample_stride : 1; + float local_min = __FLT_MAX__, local_max = -__FLT_MAX__; + for (int idx = tx * stride; idx < nblk; idx += BLOCK_SIZE * stride) { + float val = apply_transform(vortex_to_float(score_blk[idx]), mapping); + local_min = fminf(local_min, val); + local_max = fmaxf(local_max, val); + } + for (int offset = 16; offset > 0; offset >>= 1) { + local_min = fminf(local_min, __shfl_xor_sync(0xFFFFFFFF, local_min, offset)); + local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); + } + __shared__ float s_warp_mins[32], s_warp_maxs[32]; + int warp_id = tx >> 5, lane_id = tx & 31; + if (lane_id == 0) { s_warp_mins[warp_id] = local_min; s_warp_maxs[warp_id] = local_max; } + __syncthreads(); + if (tx < (BLOCK_SIZE >> 5)) { + local_min = s_warp_mins[tx]; local_max = s_warp_maxs[tx]; + for (int offset = 16; offset > 0; offset >>= 1) { + local_min = fminf(local_min, __shfl_xor_sync(0xFFFFFFFF, local_min, offset)); + local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); + } + if (tx == 0) { + s_range_min = local_min; + float range = local_max - local_min; + s_range_inv_range = (range > 0.0f) ? 
255.0f / range : 0.0f; + } + } + __syncthreads(); + } else if (needs_pivot(mapping.mode)) { + // Pivot pre-pass: compute mean for MAPPING_SUBTRACT + float local_sum = 0.0f; + for (int idx = tx; idx < nblk; idx += BLOCK_SIZE) { + local_sum += vortex_to_float(score_blk[idx]); + } + for (int offset = 16; offset > 0; offset >>= 1) { + local_sum += __shfl_xor_sync(0xFFFFFFFF, local_sum, offset); + } + __shared__ float s_warp_sums_h[32]; + int warp_id = tx >> 5, lane_id = tx & 31; + if (lane_id == 0) s_warp_sums_h[warp_id] = local_sum; + __syncthreads(); + if (tx < (BLOCK_SIZE >> 5)) { + local_sum = s_warp_sums_h[tx]; + for (int offset = 16; offset > 0; offset >>= 1) { + local_sum += __shfl_xor_sync(0xFFFFFFFF, local_sum, offset); + } + if (tx == 0) { + s_range_min = local_sum / float(nblk); + s_range_inv_range = 0.0f; + } + } + __syncthreads(); + } else if (needs_tail_window(mapping.mode)) { + // Adaptive tail-window pre-pass (histogram kernel variant) + constexpr int MAX_SAMPLES_H = 1024; + __shared__ float s_samples_h[MAX_SAMPLES_H]; + __shared__ int s_sample_count_h; + + if (tx == 0) s_sample_count_h = 0; + __syncthreads(); + + const int desired_stride = (nblk + MAX_SAMPLES_H - 1) / MAX_SAMPLES_H; + const int sample_stride_h = max(desired_stride, 1); + + float local_max = -__FLT_MAX__; + for (int idx = tx * sample_stride_h; idx < nblk; idx += BLOCK_SIZE * sample_stride_h) { + float val = vortex_to_float(score_blk[idx]); + local_max = fmaxf(local_max, val); + int slot = ::atomicAdd(&s_sample_count_h, 1); + if (slot < MAX_SAMPLES_H) s_samples_h[slot] = val; + } + + for (int offset = 16; offset > 0; offset >>= 1) + local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); + __shared__ float s_warp_maxs_h[32]; + { + int warp_id = tx >> 5, lane_id = tx & 31; + if (lane_id == 0) s_warp_maxs_h[warp_id] = local_max; + } + __syncthreads(); + if (tx < (BLOCK_SIZE >> 5)) { + local_max = s_warp_maxs_h[tx]; + for (int offset = 16; offset > 0; offset >>= 1) + 
local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); + if (tx == 0) s_warp_maxs_h[0] = local_max; + } + __syncthreads(); + local_max = s_warp_maxs_h[0]; + + int nsamp = min(s_sample_count_h, MAX_SAMPLES_H); + + __syncthreads(); + if (nsamp >= 2) { + for (int pass = 0; pass < nsamp; ++pass) { + if (tx * 2 + 1 < nsamp) { + int i = tx * 2; + if (s_samples_h[i] > s_samples_h[i + 1]) { + float tmp = s_samples_h[i]; + s_samples_h[i] = s_samples_h[i + 1]; + s_samples_h[i + 1] = tmp; + } + } + __syncthreads(); + if (tx * 2 + 2 < nsamp) { + int i = tx * 2 + 1; + if (s_samples_h[i] > s_samples_h[i + 1]) { + float tmp = s_samples_h[i]; + s_samples_h[i] = s_samples_h[i + 1]; + s_samples_h[i + 1] = tmp; + } + } + __syncthreads(); + } + } + + if (tx == 0) { + float rho = mapping.power_exp; + if (rho <= 0.0f) rho = 4.0f; + int k = mapping.target_k; + float frac = (k > 0 && nblk > 0) ? 1.0f - rho * float(k) / float(nblk) : 0.0f; + frac = fmaxf(frac, 0.0f); + + float tau_low; + if (nsamp < 4 || frac <= 0.0f) { + tau_low = -__FLT_MAX__; + } else { + float fidx = frac * float(nsamp - 1); + int lo = __float2int_rd(fidx); + lo = min(max(lo, 0), nsamp - 2); + float t = fidx - float(lo); + tau_low = s_samples_h[lo] * (1.0f - t) + s_samples_h[lo + 1] * t; + } + + if (tau_low >= local_max) { + tau_low = (nsamp > 0) ? s_samples_h[0] : local_max; + } + + float range = local_max - tau_low; + s_range_min = tau_low; + s_range_inv_range = (range > 1e-10f) ? 
255.0f / range : 0.0f; + } + __syncthreads(); + } else if (needs_topk_window(mapping.mode)) { + // Topk-window pre-pass with streaming variance (histogram kernel variant) + float local_max_h = -__FLT_MAX__; + float local_sum_h = 0.0f, local_sum_sq_h = 0.0f; + for (int idx = tx; idx < nblk; idx += BLOCK_SIZE) { + float val = vortex_to_float(score_blk[idx]); + local_max_h = fmaxf(local_max_h, val); + local_sum_h += val; + local_sum_sq_h += val * val; + } + for (int offset = 16; offset > 0; offset >>= 1) { + local_max_h = fmaxf(local_max_h, __shfl_xor_sync(0xFFFFFFFF, local_max_h, offset)); + local_sum_h += __shfl_xor_sync(0xFFFFFFFF, local_sum_h, offset); + local_sum_sq_h += __shfl_xor_sync(0xFFFFFFFF, local_sum_sq_h, offset); + } + __shared__ float s_warp_maxs_tw3[32], s_warp_sums_tw3[32], s_warp_sq_tw3[32]; + { + int warp_id = tx >> 5, lane_id = tx & 31; + if (lane_id == 0) { + s_warp_maxs_tw3[warp_id] = local_max_h; + s_warp_sums_tw3[warp_id] = local_sum_h; + s_warp_sq_tw3[warp_id] = local_sum_sq_h; + } + } + __syncthreads(); + if (tx < (BLOCK_SIZE >> 5)) { + local_max_h = s_warp_maxs_tw3[tx]; + local_sum_h = s_warp_sums_tw3[tx]; + local_sum_sq_h = s_warp_sq_tw3[tx]; + for (int offset = 16; offset > 0; offset >>= 1) { + local_max_h = fmaxf(local_max_h, __shfl_xor_sync(0xFFFFFFFF, local_max_h, offset)); + local_sum_h += __shfl_xor_sync(0xFFFFFFFF, local_sum_h, offset); + local_sum_sq_h += __shfl_xor_sync(0xFFFFFFFF, local_sum_sq_h, offset); + } + if (tx == 0) { + float rho = mapping.power_exp; + if (rho <= 0.0f) rho = 4.0f; + int k = mapping.target_k; + float n = float(nblk); + float mean = local_sum_h / n; + float var = local_sum_sq_h / n - mean * mean; + float sigma = (var > 0.0f) ? 
sqrtf(var) : 0.0f; + float ratio = n / fmaxf(float(k), 1.0f); + float z = sqrtf(2.0f * __logf(fmaxf(ratio, 1.0f))); + float tau_low = local_max_h - rho * sigma * z; + if (tau_low >= local_max_h) tau_low = local_max_h - 1.0f; + float range = local_max_h - tau_low; + s_range_min = tau_low; + s_range_inv_range = (range > 1e-10f) ? 255.0f / range : 0.0f; + } + } + __syncthreads(); + } else { + if (tx == 0) { s_range_min = 0.0f; s_range_inv_range = 0.0f; } + __syncthreads(); + } + + // Initialize shared histogram + if (tx < RADIX) s_histogram[tx] = 0; + __syncthreads(); + + // Build histogram over the segment with mapping + for (int idx = tx; idx < nblk; idx += BLOCK_SIZE) { + const auto bin = mapped_convert_to_uint8( + vortex_to_float(score_blk[idx]), + mapping, s_mapping_lut, s_mapping_quantiles, + s_range_min, s_range_inv_range); + ::atomicAdd(&s_histogram[bin], 1); + } + __syncthreads(); + + // Write to global memory + int* __restrict__ out = histograms + bx * RADIX; + if (tx < RADIX) { + out[tx] = s_histogram[tx]; + } +} + +} // namespace + +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") + +// ====================================================================== +// Profiling: collect per-segment 256-bin histograms of Stage 1 bins +// ====================================================================== +void topk_profile_histogram( + const at::Tensor& x, + const at::Tensor& dense_kv_indptr, + at::Tensor& histograms, + const int64_t eff_batch_size, + const int64_t reserved_bos, + const int64_t reserved_eos, + const int64_t mapping_mode, + const double mapping_power, + std::optional mapping_lut, + std::optional mapping_quantiles, + const bool mapping_noscale, + const int64_t topk_val, + const int64_t sample_stride) +{ + CHECK_CUDA(x); + CHECK_CUDA(dense_kv_indptr); + CHECK_CUDA(histograms); + TORCH_CHECK(histograms.dim() == 2 && histograms.size(0) == eff_batch_size + && histograms.size(1) == 256, + "histograms must be [eff_batch_size, 
256]"); + TORCH_CHECK(histograms.scalar_type() == at::ScalarType::Int, + "histograms must be int32"); + + // Build mapping params + TopKMappingParams mapping{}; + mapping.mode = static_cast(mapping_mode); + mapping.power_exp = static_cast(mapping_power); + mapping.lut = nullptr; + mapping.quantiles = nullptr; + mapping.noscale = mapping_noscale; + mapping.sample_stride = static_cast(sample_stride); + mapping.target_k = static_cast(topk_val); + + if (mapping_lut.has_value()) { + const auto& lut = mapping_lut.value(); + CHECK_CUDA(lut); + TORCH_CHECK(lut.dim() == 1 && lut.size(0) == 256 && lut.scalar_type() == at::ScalarType::Byte, + "mapping_lut must be a 1D uint8 tensor of size 256"); + mapping.lut = lut.data_ptr(); + } + if (mapping_quantiles.has_value()) { + const auto& q = mapping_quantiles.value(); + CHECK_CUDA(q); + TORCH_CHECK(q.dim() == 1 && q.size(0) == 256 && q.scalar_type() == at::ScalarType::Float, + "mapping_quantiles must be a 1D float32 tensor of size 256"); + mapping.quantiles = q.data_ptr(); + } + + dim3 nblks(eff_batch_size); + dim3 nthreads(kThreadsPerBlock); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + + if (x.scalar_type() == at::ScalarType::BFloat16) { + TopKHistogram_Kernel<__nv_bfloat16><<>>( + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + dense_kv_indptr.data_ptr(), + histograms.data_ptr(), + reserved_bos, + reserved_eos, + mapping); + } else if (x.scalar_type() == at::ScalarType::Float) { + TopKHistogram_Kernel<<>>( + x.data_ptr(), + dense_kv_indptr.data_ptr(), + histograms.data_ptr(), + reserved_bos, + reserved_eos, + mapping); + } else { + TORCH_CHECK(false, + "topk_profile_histogram: unsupported dtype ", + x.scalar_type()); + } + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, + "topk_profile_histogram kernel failed: ", ::cudaGetErrorString(result)); +} + +// Helper: build TopKMappingParams from host arguments +static TopKMappingParams build_mapping_params( + int64_t 
mapping_mode, double mapping_power, + std::optional& mapping_lut, + std::optional& mapping_quantiles, + bool mapping_noscale = false, + int sample_stride = 1, + int target_k = 0) +{ + TopKMappingParams mapping{}; + mapping.mode = static_cast(mapping_mode); + mapping.power_exp = static_cast(mapping_power); + mapping.lut = nullptr; + mapping.quantiles = nullptr; + mapping.noscale = mapping_noscale; + mapping.sample_stride = sample_stride; + mapping.target_k = target_k; + + if (mapping_lut.has_value()) { + const auto& lut = mapping_lut.value(); + CHECK_CUDA(lut); + TORCH_CHECK(lut.dim() == 1 && lut.size(0) == 256 && lut.scalar_type() == at::ScalarType::Byte, + "mapping_lut must be a 1D uint8 tensor of size 256"); + mapping.lut = lut.data_ptr(); + } + if (mapping_quantiles.has_value()) { + const auto& q = mapping_quantiles.value(); + CHECK_CUDA(q); + TORCH_CHECK(q.dim() == 1 && q.size(0) == 256 && q.scalar_type() == at::ScalarType::Float, + "mapping_quantiles must be a 1D float32 tensor of size 256"); + mapping.quantiles = q.data_ptr(); + } + return mapping; +} + +// ====================================================================== +// Profiling: Stage 1 only (pre-pass + hist + cumsum + route/filter) +// ====================================================================== +void topk_profile_stage1( + const at::Tensor& x, + const at::Tensor& dense_kv_indptr, + const at::Tensor& sparse_kv_indptr, + const at::Tensor& dense_kv_indices, + at::Tensor& sparse_kv_indices, + const int64_t eff_batch_size, + const int64_t topk_val, + const int64_t reserved_bos, + const int64_t reserved_eos, + const int64_t max_num_pages, + const int64_t mapping_mode, + const double mapping_power, + std::optional mapping_lut, + std::optional mapping_quantiles, + const bool mapping_noscale) +{ + TORCH_CHECK(topk_val <= VORTEX_MAX_TOPK, + "topk_profile_stage1: topk_val (", topk_val, + ") exceeds VORTEX_MAX_TOPK (", VORTEX_MAX_TOPK, ")"); + + auto mapping = build_mapping_params(mapping_mode, 
mapping_power, mapping_lut, mapping_quantiles, + mapping_noscale, /*sample_stride=*/1, /*target_k=*/static_cast(topk_val)); + + dim3 nblks(eff_batch_size); + dim3 nthreads(kThreadsPerBlock); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + + if (x.scalar_type() == at::ScalarType::BFloat16) { + setup_kernel_smem_once, kSmem>(); + TopKStage1_Kernel<__nv_bfloat16><<>>( + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + topk_val, + reserved_bos, + reserved_eos, + mapping); + } else if (x.scalar_type() == at::ScalarType::Float) { + setup_kernel_smem_once, kSmem>(); + TopKStage1_Kernel<<>>( + x.data_ptr(), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + topk_val, + reserved_bos, + reserved_eos, + mapping); + } else { + TORCH_CHECK(false, + "topk_profile_stage1: unsupported dtype ", + x.scalar_type()); + } + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, + "topk_profile_stage1 kernel failed: ", ::cudaGetErrorString(result)); +} + +// ====================================================================== +// Profiling: full pipeline + diagnostic counters +// ====================================================================== +void topk_profile_counters( + const at::Tensor& x, + const at::Tensor& dense_kv_indptr, + const at::Tensor& sparse_kv_indptr, + const at::Tensor& dense_kv_indices, + at::Tensor& sparse_kv_indices, + at::Tensor& counters, + const int64_t eff_batch_size, + const int64_t topk_val, + const int64_t reserved_bos, + const int64_t reserved_eos, + const int64_t max_num_pages, + const int64_t mapping_mode, + const double mapping_power, + std::optional mapping_lut, + std::optional mapping_quantiles, + const bool mapping_noscale) +{ + TORCH_CHECK(topk_val <= VORTEX_MAX_TOPK, + "topk_profile_counters: topk_val 
(", topk_val, + ") exceeds VORTEX_MAX_TOPK (", VORTEX_MAX_TOPK, ")"); + CHECK_CUDA(counters); + TORCH_CHECK(counters.dim() == 2 && counters.size(0) == eff_batch_size + && counters.size(1) == NUM_TOPK_COUNTERS, + "counters must be [eff_batch_size, ", NUM_TOPK_COUNTERS, "]"); + TORCH_CHECK(counters.scalar_type() == at::ScalarType::Int, + "counters must be int32"); + + auto mapping = build_mapping_params(mapping_mode, mapping_power, mapping_lut, mapping_quantiles, + mapping_noscale, /*sample_stride=*/1, /*target_k=*/static_cast(topk_val)); + + dim3 nblks(eff_batch_size); + dim3 nthreads(kThreadsPerBlock); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + + if (x.scalar_type() == at::ScalarType::BFloat16) { + setup_kernel_smem_once, kSmem>(); + TopKCounters_Kernel<__nv_bfloat16><<>>( + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + counters.data_ptr(), + topk_val, + reserved_bos, + reserved_eos, + mapping); + } else if (x.scalar_type() == at::ScalarType::Float) { + setup_kernel_smem_once, kSmem>(); + TopKCounters_Kernel<<>>( + x.data_ptr(), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + counters.data_ptr(), + topk_val, + reserved_bos, + reserved_eos, + mapping); + } else { + TORCH_CHECK(false, + "topk_profile_counters: unsupported dtype ", + x.scalar_type()); + } + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, + "topk_profile_counters kernel failed: ", ::cudaGetErrorString(result)); +} + diff --git a/examples/run_distribution_analysis.sh b/examples/run_distribution_analysis.sh index 6806eca..fcc2ff1 100755 --- a/examples/run_distribution_analysis.sh +++ b/examples/run_distribution_analysis.sh @@ -47,6 +47,9 @@ MODEL_NAME="Qwen/Qwen3-1.7B" TOPK_VAL=30 MEM=0.7 ALGO="block_sparse_attention" +RADIX_BITS=8 
+SAMPLE_STRIDE=1 +SEQ_LEN=32768 # The path to the raw_histograms.npy file (set to skip calibration) REAL_HISTOGRAMS="/data/datasets/xinrui/My_Projects/vortex_torch/examples/calibration/raw_histograms.npy" REAL_HISTOGRAMS="" @@ -59,12 +62,23 @@ while [[ $# -gt 0 ]]; do --gpu) GPU_ID="$2"; shift 2 ;; --algo) ALGO="$2"; shift 2 ;; --real-histograms) REAL_HISTOGRAMS="$2"; shift 2 ;; + --radix-bits) RADIX_BITS="$2"; shift 2 ;; + --sample-stride) SAMPLE_STRIDE="$2"; shift 2 ;; + --seq-len) SEQ_LEN="$2"; shift 2 ;; *) echo "Unknown option: $1"; exit 1 ;; esac done export CUDA_VISIBLE_DEVICES="${GPU_ID}" +# Validate seq_len: need pages/seg > topk_val (page_size=16, reserved=3 pages) +MIN_SEQ_LEN=$(( (TOPK_VAL + 4) * 16 )) +if [ "${SEQ_LEN}" -lt "${MIN_SEQ_LEN}" ]; then + echo "ERROR: --seq-len ${SEQ_LEN} too small for --topk-val ${TOPK_VAL}." + echo " Minimum: ${MIN_SEQ_LEN} (pages/seg must exceed topk_val + 3 reserved pages)" + exit 1 +fi + RESULTS_DIR="${SCRIPT_DIR}/results" mkdir -p "${RESULTS_DIR}" TIMESTAMP=$(date +%Y%m%d_%H%M%S) @@ -77,6 +91,8 @@ echo " Model: ${MODEL_NAME}" echo " Algorithm: ${ALGO}" echo " TopK: ${TOPK_VAL}" echo " GPU: ${GPU_ID}" +echo " Radix bits: ${RADIX_BITS} ($(( 1 << RADIX_BITS )) bins)" +echo " Sample stride: ${SAMPLE_STRIDE}" echo " Real histograms: ${REAL_HISTOGRAMS:-}" echo " Output: ${RUN_DIR}" echo "============================================================" @@ -117,7 +133,7 @@ fi PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/autotune_topk_mapping.py" \ --topk-val "${TOPK_VAL}" \ --batch-size 4 \ - --seq-len 32768 \ + --seq-len ${SEQ_LEN} \ --num-kv-heads 2 \ "${AUTOTUNE_EXTRA_ARGS[@]}" \ --output-json "${AUTOTUNE_JSON}" \ @@ -157,7 +173,7 @@ fi PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/bench_topk.py" \ --batch-sizes 4 \ - --seq-lens 32768 \ + --seq-lens ${SEQ_LEN} \ --topk-vals "${TOPK_VAL}" \ --num-kv-heads 8 \ --distributions bucket_uniform normal \ @@ -166,6 +182,8 @@ PYTHONPATH="${SCRIPT_DIR}/.." 
python "${BENCH_DIR}/bench_topk.py" \ "${BENCH_EXTRA_ARGS[@]}" \ --autotune-json "${AUTOTUNE_JSON}" \ --filter-kernels naive sglang_ori sglang_m0 sglang_scale sglang_m1 sglang_m2 sglang_m3 sglang_m3_noscale sglang_m4 sglang_m6 sglang_m6_noscale sglang_m7 sglang_m7_noscale sglang_m8 sglang_m9 sglang_m9_noscale sglang_m10 sglang_m10_noscale sglang_m11 sglang_m13 sglang_m13_noscale sglang_m14 \ + --radix-bits "${RADIX_BITS}" \ + --sample-stride "${SAMPLE_STRIDE}" \ --repeat 20 \ --output-json "${BENCH_JSON}" \ 2>&1 | tee "${RUN_DIR}/step3_bench.log" diff --git a/examples/run_distribution_analysis_new.sh b/examples/run_distribution_analysis_new.sh index 1f89c0b..65e4f41 100755 --- a/examples/run_distribution_analysis_new.sh +++ b/examples/run_distribution_analysis_new.sh @@ -31,9 +31,12 @@ BENCH_DIR="${SCRIPT_DIR}/../benchmarks" # ── Defaults ────────────────────────────────────────────────── GPU_ID=4 MODEL_NAME="Qwen/Qwen3-1.7B" -TOPK_VAL=30 +TOPK_VAL=2048 MEM=0.7 ALGO="block_sparse_attention" +RADIX_BITS=8 +SAMPLE_STRIDE=1 +SEQ_LEN=65536 # The path to the raw_histograms.npy file (set to skip calibration) REAL_HISTOGRAMS="/data/datasets/xinrui/My_Projects/vortex_torch/examples/calibration/raw_histograms.npy" # REAL_HISTOGRAMS="" @@ -46,12 +49,23 @@ while [[ $# -gt 0 ]]; do --gpu) GPU_ID="$2"; shift 2 ;; --algo) ALGO="$2"; shift 2 ;; --real-histograms) REAL_HISTOGRAMS="$2"; shift 2 ;; + --radix-bits) RADIX_BITS="$2"; shift 2 ;; + --sample-stride) SAMPLE_STRIDE="$2"; shift 2 ;; + --seq-len) SEQ_LEN="$2"; shift 2 ;; *) echo "Unknown option: $1"; exit 1 ;; esac done export CUDA_VISIBLE_DEVICES="${GPU_ID}" +# Validate seq_len: need pages/seg > topk_val (page_size=16, reserved=3 pages) +MIN_SEQ_LEN=$(( (TOPK_VAL + 4) * 16 )) +if [ "${SEQ_LEN}" -lt "${MIN_SEQ_LEN}" ]; then + echo "ERROR: --seq-len ${SEQ_LEN} too small for --topk-val ${TOPK_VAL}." 
+ echo " Minimum: ${MIN_SEQ_LEN} (pages/seg must exceed topk_val + 3 reserved pages)" + exit 1 +fi + RESULTS_DIR="${SCRIPT_DIR}/results" mkdir -p "${RESULTS_DIR}" TIMESTAMP=$(date +%Y%m%d_%H%M%S) @@ -63,7 +77,10 @@ echo "Bucket Distribution Profiling (modes 3, 6, 7, 8, 9, 10, 11)" echo " Model: ${MODEL_NAME}" echo " Algorithm: ${ALGO}" echo " TopK: ${TOPK_VAL}" +echo " Seq len: ${SEQ_LEN} ($(( SEQ_LEN / 16 )) pages/seg)" echo " GPU: ${GPU_ID}" +echo " Radix bits: ${RADIX_BITS} ($(( 1 << RADIX_BITS )) bins)" +echo " Sample stride: ${SAMPLE_STRIDE}" echo " Real histograms: ${REAL_HISTOGRAMS:-}" echo " Output: ${RUN_DIR}" echo "============================================================" @@ -98,7 +115,7 @@ AUTOTUNE_JSON="${RUN_DIR}/autotune_results.json" PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/autotune_topk_mapping.py" \ --topk-val "${TOPK_VAL}" \ --batch-size 4 \ - --seq-len 32768 \ + --seq-len ${SEQ_LEN} \ --num-kv-heads 8 \ --real-histograms "${REAL_HIST_PATH}" \ --latency-rerank \ @@ -115,7 +132,7 @@ BENCH_JSON="${RUN_DIR}/bench_distribution.json" PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/bench_topk.py" \ --batch-sizes 4 \ - --seq-lens 32768 \ + --seq-lens ${SEQ_LEN} \ --topk-vals "${TOPK_VAL}" \ --num-kv-heads 8 \ --distributions bucket_uniform normal \ @@ -124,6 +141,8 @@ PYTHONPATH="${SCRIPT_DIR}/.." 
python "${BENCH_DIR}/bench_topk.py" \ --real-histograms "${REAL_HIST_PATH}" \ --autotune-json "${AUTOTUNE_JSON}" \ --filter-kernels naive sglang_ori sglang_m0 sglang_scale sglang_m3 sglang_m3_noscale sglang_m6 sglang_m6_noscale sglang_m7 sglang_m7_noscale sglang_m8 sglang_m9 sglang_m9_noscale sglang_m10 sglang_m10_noscale sglang_m11 sglang_m13 sglang_m13_noscale sglang_m14 \ + --radix-bits "${RADIX_BITS}" \ + --sample-stride "${SAMPLE_STRIDE}" \ --repeat 20 \ --output-json "${BENCH_JSON}" \ 2>&1 | tee "${RUN_DIR}/step3_bench.log" diff --git a/examples/verify_algo.py b/examples/verify_algo.py index 32ff5a3..a1d1b6f 100644 --- a/examples/verify_algo.py +++ b/examples/verify_algo.py @@ -119,10 +119,9 @@ def verify_algos( kv_cache_dtype: str = "auto", topk_type: str = "naive", topk_mapping_mode: int = 0, -topk_mapping_power: float = 0.5, +topk_mapping_hparam: float = 0.5, topk_mapping_lut_path: str = None, topk_mapping_quantiles_path: str = None, -index_cache_shared_layers: list = None, disable_cuda_graph: bool = False, benchmark: str = "amc23", ): @@ -143,10 +142,9 @@ def verify_algos( kv_cache_dtype=kv_cache_dtype, vortex_topk_type=topk_type, vortex_topk_mapping_mode=topk_mapping_mode, - vortex_topk_mapping_power=topk_mapping_power, + vortex_topk_mapping_hparam=topk_mapping_hparam, vortex_topk_mapping_lut_path=topk_mapping_lut_path, vortex_topk_mapping_quantiles_path=topk_mapping_quantiles_path, - vortex_index_cache_shared_layers=index_cache_shared_layers, ) tokenizer = AutoTokenizer.from_pretrained(model_name) if benchmark != "amc23" else None prompts, requests = _load_benchmark(benchmark, trials, tokenizer=tokenizer) @@ -310,14 +308,15 @@ def parse_args(): type=int, default=0, choices=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], - help='TopK mapping mode: 0=none, 1=lut_cdf, 2=quantile, 3=power, 4=log, 5=index_cache, 6=asinh, 7=log1p, 8=trunc8, 9=erf, 10=tanh, 11=subtract, 12=adaptive_tail_window, 13=exp_stretch, 14=topk_window (default: 0).', + help='TopK 
mapping mode: 0=none, 1=lut_cdf, 2=quantile, 3=power, 4=log, 6=asinh, 7=log1p, 8=trunc8, 9=erf, 10=tanh, 11=subtract, 12=adaptive_tail_window, 13=exp_stretch, 14=topk_window (default: 0).', ) parser.add_argument( - "--topk-mapping-power", + "--topk-mapping-hparam", "--topk-mapping-power", type=float, default=0.5, - help='Hyperparameter for parametric modes: power exponent (mode 3), beta (mode 6 asinh), alpha (mode 7 log1p), rho tail expansion (mode 12). Default: 0.5.', + dest="topk_mapping_hparam", + help='Hyperparameter for parametric modes: power exponent (mode 3), beta (mode 6), alpha (mode 7/9/10/13), rho (mode 12/14). Default: 0.5.', ) parser.add_argument( @@ -334,14 +333,6 @@ def parse_args(): help="Path to .npy file with float32[256] quantiles for topk mapping mode 2.", ) - parser.add_argument( - "--index-cache-shared-layers", - type=int, - nargs="+", - default=None, - help="Layer IDs that reuse indices from the nearest preceding full layer (skip indexer).", - ) - parser.add_argument( "--benchmark", type=str, @@ -356,12 +347,6 @@ def parse_args(): if __name__ == "__main__": args = parse_args() - # --- Mode 5: Index Cache (default even-layer pattern) --- - if args.topk_mapping_mode == 5: - if args.index_cache_shared_layers is None: - args.index_cache_shared_layers = list(range(2, 28, 2)) # [2,4,6,...,26] - args.topk_mapping_mode = 0 - for bench_name in args.benchmark: if bench_name not in BENCHMARK_REGISTRY: print(f"WARNING: Unknown benchmark '{bench_name}', skipping. 
Available: {list(BENCHMARK_REGISTRY.keys())}") @@ -380,10 +365,9 @@ def parse_args(): kv_cache_dtype=args.kv_cache_dtype, topk_type=args.topk_type, topk_mapping_mode=args.topk_mapping_mode, - topk_mapping_power=args.topk_mapping_power, + topk_mapping_hparam=args.topk_mapping_hparam, topk_mapping_lut_path=args.topk_mapping_lut_path, topk_mapping_quantiles_path=args.topk_mapping_quantiles_path, - index_cache_shared_layers=args.index_cache_shared_layers, benchmark=bench_name, ) summary["benchmark"] = bench_name diff --git a/examples/verify_algo_topk_mapping.sh b/examples/verify_algo_topk_mapping.sh index c0a03c5..9a9f482 100644 --- a/examples/verify_algo_topk_mapping.sh +++ b/examples/verify_algo_topk_mapping.sh @@ -112,17 +112,17 @@ for r in data: if m not in best or r['gini'] < best[m]['gini']: best[m] = r for m in (3, 6, 7, 9, 10): - print(f'BEST_POWER_{m}={best[m][\"param\"]}' if m in best else f'BEST_POWER_{m}=0.5') + print(f'BEST_HPARAM_{m}={best[m][\"param\"]}' if m in best else f'BEST_HPARAM_{m}=0.5') " "${AUTOTUNE_JSON}")" - echo ">>> Autotuned best powers: mode3=${BEST_POWER_3} mode6=${BEST_POWER_6} mode7=${BEST_POWER_7} mode9=${BEST_POWER_9} mode10=${BEST_POWER_10}" + echo ">>> Autotuned best powers: mode3=${BEST_HPARAM_3} mode6=${BEST_HPARAM_6} mode7=${BEST_HPARAM_7} mode9=${BEST_HPARAM_9} mode10=${BEST_HPARAM_10}" echo "" else echo ">>> WARNING: ${REAL_HISTOGRAMS} not found, using default power=0.5 for all modes" - BEST_POWER_3=0.5 - BEST_POWER_6=0.5 - BEST_POWER_7=0.5 - BEST_POWER_9=0.5 - BEST_POWER_10=0.5 + BEST_HPARAM_3=0.5 + BEST_HPARAM_6=0.5 + BEST_HPARAM_7=0.5 + BEST_HPARAM_9=0.5 + BEST_HPARAM_10=0.5 fi # ============================================================ @@ -189,8 +189,8 @@ done # Mode 3: power — autotuned best p # ============================================================ for algo in "${sparse_algos[@]}"; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_3_p${BEST_POWER_3}_${TIMESTAMP}.log" - echo ">>> Running mode 3 (power) 
p=${BEST_POWER_3} (autotuned) for ${algo}" + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_3_p${BEST_HPARAM_3}_${TIMESTAMP}.log" + echo ">>> Running mode 3 (power) p=${BEST_HPARAM_3} (autotuned) for ${algo}" echo ">>> Saving results to ${OUTFILE}" { time python verify_algo.py \ --trials 8 \ @@ -199,7 +199,7 @@ for algo in "${sparse_algos[@]}"; do --model-name Qwen/Qwen3-1.7B \ --topk-type sglang \ --topk-mapping-mode 3 \ - --topk-mapping-power ${BEST_POWER_3} \ + --topk-mapping-hparam ${BEST_HPARAM_3} \ --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" done @@ -208,8 +208,8 @@ done # Mode 6: asinh — autotuned best beta # ============================================================ for algo in "${sparse_algos[@]}"; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_6_beta${BEST_POWER_6}_${TIMESTAMP}.log" - echo ">>> Running mode 6 (asinh) beta=${BEST_POWER_6} (autotuned) for ${algo}" + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_6_beta${BEST_HPARAM_6}_${TIMESTAMP}.log" + echo ">>> Running mode 6 (asinh) beta=${BEST_HPARAM_6} (autotuned) for ${algo}" echo ">>> Saving results to ${OUTFILE}" { time python verify_algo.py \ --trials 8 \ @@ -218,7 +218,7 @@ for algo in "${sparse_algos[@]}"; do --model-name Qwen/Qwen3-1.7B \ --topk-type sglang \ --topk-mapping-mode 6 \ - --topk-mapping-power ${BEST_POWER_6} \ + --topk-mapping-hparam ${BEST_HPARAM_6} \ --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" done @@ -227,8 +227,8 @@ done # Mode 7: log1p — autotuned best alpha # ============================================================ for algo in "${sparse_algos[@]}"; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_7_alpha${BEST_POWER_7}_${TIMESTAMP}.log" - echo ">>> Running mode 7 (log1p) alpha=${BEST_POWER_7} (autotuned) for ${algo}" + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_7_alpha${BEST_HPARAM_7}_${TIMESTAMP}.log" + echo ">>> Running mode 7 (log1p) alpha=${BEST_HPARAM_7} (autotuned) for ${algo}" echo ">>> Saving results to ${OUTFILE}" { time python 
verify_algo.py \ --trials 8 \ @@ -237,7 +237,7 @@ for algo in "${sparse_algos[@]}"; do --model-name Qwen/Qwen3-1.7B \ --topk-type sglang \ --topk-mapping-mode 7 \ - --topk-mapping-power ${BEST_POWER_7} \ + --topk-mapping-hparam ${BEST_HPARAM_7} \ --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" done @@ -246,8 +246,8 @@ done # Mode 9: erf — autotuned best alpha # ============================================================ for algo in "${sparse_algos[@]}"; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_9_alpha${BEST_POWER_9}_${TIMESTAMP}.log" - echo ">>> Running mode 9 (erf) alpha=${BEST_POWER_9} (autotuned) for ${algo}" + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_9_alpha${BEST_HPARAM_9}_${TIMESTAMP}.log" + echo ">>> Running mode 9 (erf) alpha=${BEST_HPARAM_9} (autotuned) for ${algo}" echo ">>> Saving results to ${OUTFILE}" { time python verify_algo.py \ --trials 8 \ @@ -256,7 +256,7 @@ for algo in "${sparse_algos[@]}"; do --model-name Qwen/Qwen3-1.7B \ --topk-type sglang \ --topk-mapping-mode 9 \ - --topk-mapping-power ${BEST_POWER_9} \ + --topk-mapping-hparam ${BEST_HPARAM_9} \ --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" done @@ -265,8 +265,8 @@ done # Mode 10: tanh — autotuned best alpha # ============================================================ for algo in "${sparse_algos[@]}"; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_10_alpha${BEST_POWER_10}_${TIMESTAMP}.log" - echo ">>> Running mode 10 (tanh) alpha=${BEST_POWER_10} (autotuned) for ${algo}" + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_10_alpha${BEST_HPARAM_10}_${TIMESTAMP}.log" + echo ">>> Running mode 10 (tanh) alpha=${BEST_HPARAM_10} (autotuned) for ${algo}" echo ">>> Saving results to ${OUTFILE}" { time python verify_algo.py \ --trials 8 \ @@ -275,7 +275,7 @@ for algo in "${sparse_algos[@]}"; do --model-name Qwen/Qwen3-1.7B \ --topk-type sglang \ --topk-mapping-mode 10 \ - --topk-mapping-power ${BEST_POWER_10} \ + --topk-mapping-hparam ${BEST_HPARAM_10} \ --mem 0.7 ; } \ 
2>&1 | tee "${OUTFILE}" done diff --git a/examples/verify_algo_topk_mapping_new.sh b/examples/verify_algo_topk_mapping_new.sh index 4c96a15..6848e1e 100644 --- a/examples/verify_algo_topk_mapping_new.sh +++ b/examples/verify_algo_topk_mapping_new.sh @@ -78,9 +78,9 @@ for r in data: if m not in best or r['gini'] < best[m]['gini']: best[m] = r for m in (3, 6, 7, 9, 10, 13, 14): - print(f'BEST_POWER_{m}={best[m][\"param\"]}' if m in best else f'BEST_POWER_{m}=0.5') + print(f'BEST_HPARAM_{m}={best[m][\"param\"]}' if m in best else f'BEST_HPARAM_{m}=0.5') " "${AUTOTUNE_JSON}")" -echo ">>> Autotuned best powers: mode3=${BEST_POWER_3} mode6=${BEST_POWER_6} mode7=${BEST_POWER_7} mode9=${BEST_POWER_9} mode10=${BEST_POWER_10} mode13=${BEST_POWER_13} mode14=${BEST_POWER_14}" +echo ">>> Autotuned best powers: mode3=${BEST_HPARAM_3} mode6=${BEST_HPARAM_6} mode7=${BEST_HPARAM_7} mode9=${BEST_HPARAM_9} mode10=${BEST_HPARAM_10} mode13=${BEST_HPARAM_13} mode14=${BEST_HPARAM_14}" echo "" # ============================================================ @@ -107,11 +107,11 @@ done # Step 1: Mode 3 (power) — autotuned best p # ============================================================ echo "============================================================" -echo "Step 1: Mode 3 (power) — p=${BEST_POWER_3} (autotuned)" +echo "Step 1: Mode 3 (power) — p=${BEST_HPARAM_3} (autotuned)" echo "============================================================" for algo in "${sparse_algos[@]}"; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_3_p${BEST_POWER_3}_${TIMESTAMP}.log" - echo ">>> Mode 3 (power) p=${BEST_POWER_3} algo=${algo}" + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_3_p${BEST_HPARAM_3}_${TIMESTAMP}.log" + echo ">>> Mode 3 (power) p=${BEST_HPARAM_3} algo=${algo}" { time python verify_algo.py \ --trials 8 \ --topk-val ${TOPK_VAL} \ @@ -119,7 +119,7 @@ for algo in "${sparse_algos[@]}"; do --model-name Qwen/Qwen3-1.7B \ --topk-type sglang \ --topk-mapping-mode 3 \ - 
--topk-mapping-power ${BEST_POWER_3} \ + --topk-mapping-hparam ${BEST_HPARAM_3} \ --benchmark ${BENCHMARKS} \ --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" @@ -129,11 +129,11 @@ done # Step 2: Mode 6 (asinh) — autotuned best beta # ============================================================ echo "============================================================" -echo "Step 2: Mode 6 (asinh) — beta=${BEST_POWER_6} (autotuned)" +echo "Step 2: Mode 6 (asinh) — beta=${BEST_HPARAM_6} (autotuned)" echo "============================================================" for algo in "${sparse_algos[@]}"; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_6_beta${BEST_POWER_6}_${TIMESTAMP}.log" - echo ">>> Mode 6 (asinh) beta=${BEST_POWER_6} algo=${algo}" + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_6_beta${BEST_HPARAM_6}_${TIMESTAMP}.log" + echo ">>> Mode 6 (asinh) beta=${BEST_HPARAM_6} algo=${algo}" { time python verify_algo.py \ --trials 8 \ --topk-val ${TOPK_VAL} \ @@ -141,7 +141,7 @@ for algo in "${sparse_algos[@]}"; do --model-name Qwen/Qwen3-1.7B \ --topk-type sglang \ --topk-mapping-mode 6 \ - --topk-mapping-power ${BEST_POWER_6} \ + --topk-mapping-hparam ${BEST_HPARAM_6} \ --benchmark ${BENCHMARKS} \ --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" @@ -151,11 +151,11 @@ done # Step 3: Mode 7 (log1p) — autotuned best alpha # ============================================================ echo "============================================================" -echo "Step 3: Mode 7 (log1p) — alpha=${BEST_POWER_7} (autotuned)" +echo "Step 3: Mode 7 (log1p) — alpha=${BEST_HPARAM_7} (autotuned)" echo "============================================================" for algo in "${sparse_algos[@]}"; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_7_alpha${BEST_POWER_7}_${TIMESTAMP}.log" - echo ">>> Mode 7 (log1p) alpha=${BEST_POWER_7} algo=${algo}" + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_7_alpha${BEST_HPARAM_7}_${TIMESTAMP}.log" + echo ">>> Mode 7 (log1p) 
alpha=${BEST_HPARAM_7} algo=${algo}" { time python verify_algo.py \ --trials 8 \ --topk-val ${TOPK_VAL} \ @@ -163,7 +163,7 @@ for algo in "${sparse_algos[@]}"; do --model-name Qwen/Qwen3-1.7B \ --topk-type sglang \ --topk-mapping-mode 7 \ - --topk-mapping-power ${BEST_POWER_7} \ + --topk-mapping-hparam ${BEST_HPARAM_7} \ --benchmark ${BENCHMARKS} \ --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" @@ -194,11 +194,11 @@ done # Step 5: Mode 9 (erf) — autotuned best alpha # ============================================================ echo "============================================================" -echo "Step 5: Mode 9 (erf) — alpha=${BEST_POWER_9} (autotuned)" +echo "Step 5: Mode 9 (erf) — alpha=${BEST_HPARAM_9} (autotuned)" echo "============================================================" for algo in "${sparse_algos[@]}"; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_9_alpha${BEST_POWER_9}_${TIMESTAMP}.log" - echo ">>> Mode 9 (erf) alpha=${BEST_POWER_9} algo=${algo}" + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_9_alpha${BEST_HPARAM_9}_${TIMESTAMP}.log" + echo ">>> Mode 9 (erf) alpha=${BEST_HPARAM_9} algo=${algo}" { time python verify_algo.py \ --trials 8 \ --topk-val ${TOPK_VAL} \ @@ -206,7 +206,7 @@ for algo in "${sparse_algos[@]}"; do --model-name Qwen/Qwen3-1.7B \ --topk-type sglang \ --topk-mapping-mode 9 \ - --topk-mapping-power ${BEST_POWER_9} \ + --topk-mapping-hparam ${BEST_HPARAM_9} \ --benchmark ${BENCHMARKS} \ --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" @@ -216,11 +216,11 @@ done # Step 6: Mode 10 (tanh) — autotuned best alpha # ============================================================ echo "============================================================" -echo "Step 6: Mode 10 (tanh) — alpha=${BEST_POWER_10} (autotuned)" +echo "Step 6: Mode 10 (tanh) — alpha=${BEST_HPARAM_10} (autotuned)" echo "============================================================" for algo in "${sparse_algos[@]}"; do - 
OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_10_alpha${BEST_POWER_10}_${TIMESTAMP}.log" - echo ">>> Mode 10 (tanh) alpha=${BEST_POWER_10} algo=${algo}" + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_10_alpha${BEST_HPARAM_10}_${TIMESTAMP}.log" + echo ">>> Mode 10 (tanh) alpha=${BEST_HPARAM_10} algo=${algo}" { time python verify_algo.py \ --trials 8 \ --topk-val ${TOPK_VAL} \ @@ -228,7 +228,7 @@ for algo in "${sparse_algos[@]}"; do --model-name Qwen/Qwen3-1.7B \ --topk-type sglang \ --topk-mapping-mode 10 \ - --topk-mapping-power ${BEST_POWER_10} \ + --topk-mapping-hparam ${BEST_HPARAM_10} \ --benchmark ${BENCHMARKS} \ --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" @@ -272,7 +272,7 @@ for algo in "${sparse_algos[@]}"; do --model-name Qwen/Qwen3-1.7B \ --topk-type sglang \ --topk-mapping-mode 12 \ - --topk-mapping-power 4.0 \ + --topk-mapping-hparam 4.0 \ --benchmark ${BENCHMARKS} \ --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" @@ -283,11 +283,11 @@ done # ============================================================ echo "" echo "============================================================" -echo "Step 9: Mode 13 (exp_stretch) — alpha=${BEST_POWER_13} (autotuned)" +echo "Step 9: Mode 13 (exp_stretch) — alpha=${BEST_HPARAM_13} (autotuned)" echo "============================================================" for algo in "${sparse_algos[@]}"; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_13_alpha${BEST_POWER_13}_${TIMESTAMP}.log" - echo ">>> Mode 13 (exp_stretch) alpha=${BEST_POWER_13} algo=${algo}" + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_13_alpha${BEST_HPARAM_13}_${TIMESTAMP}.log" + echo ">>> Mode 13 (exp_stretch) alpha=${BEST_HPARAM_13} algo=${algo}" { time python verify_algo.py \ --trials 8 \ --topk-val ${TOPK_VAL} \ @@ -295,7 +295,7 @@ for algo in "${sparse_algos[@]}"; do --model-name Qwen/Qwen3-1.7B \ --topk-type sglang \ --topk-mapping-mode 13 \ - --topk-mapping-power ${BEST_POWER_13} \ + --topk-mapping-hparam ${BEST_HPARAM_13} \ 
--benchmark ${BENCHMARKS} \ --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" @@ -306,11 +306,11 @@ done # ============================================================ echo "" echo "============================================================" -echo "Step 10: Mode 14 (topk_window) — rho=${BEST_POWER_14} (autotuned)" +echo "Step 10: Mode 14 (topk_window) — rho=${BEST_HPARAM_14} (autotuned)" echo "============================================================" for algo in "${sparse_algos[@]}"; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_14_rho${BEST_POWER_14}_${TIMESTAMP}.log" - echo ">>> Mode 14 (topk_window) rho=${BEST_POWER_14} algo=${algo}" + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_14_rho${BEST_HPARAM_14}_${TIMESTAMP}.log" + echo ">>> Mode 14 (topk_window) rho=${BEST_HPARAM_14} algo=${algo}" { time python verify_algo.py \ --trials 8 \ --topk-val ${TOPK_VAL} \ @@ -318,7 +318,7 @@ for algo in "${sparse_algos[@]}"; do --model-name Qwen/Qwen3-1.7B \ --topk-type sglang \ --topk-mapping-mode 14 \ - --topk-mapping-power ${BEST_POWER_14} \ + --topk-mapping-hparam ${BEST_HPARAM_14} \ --benchmark ${BENCHMARKS} \ --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" @@ -343,7 +343,7 @@ PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/bench_topk.py" \ --real-histograms "${REAL_HISTOGRAMS}" \ --autotune-json "${AUTOTUNE_JSON}" \ --filter-kernels sglang_ori sglang_m0 sglang_m3 sglang_m6 sglang_m7 sglang_m8 sglang_m9 sglang_m10 sglang_m11 sglang_m13 sglang_m14 \ - --mapping-power-13 ${BEST_POWER_13} --mapping-power-14 ${BEST_POWER_14} \ + --mapping-hparam-13 ${BEST_HPARAM_13} --mapping-hparam-14 ${BEST_HPARAM_14} \ --repeat 5 \ --output-json "${COUNTER_JSON}" \ 2>&1 | tee "${RESULTS_DIR}/counters_${TIMESTAMP}.log" @@ -357,14 +357,14 @@ echo "============================================================" echo "All runs complete. 
Results in ${RESULTS_DIR}/" echo " Auto-tune: ${AUTOTUNE_JSON}" echo " Counters: ${COUNTER_JSON}" -echo " Mode 3 (power): p = ${BEST_POWER_3} (autotuned)" -echo " Mode 6 (asinh): beta = ${BEST_POWER_6} (autotuned)" -echo " Mode 7 (log1p): alpha = ${BEST_POWER_7} (autotuned)" +echo " Mode 3 (power): p = ${BEST_HPARAM_3} (autotuned)" +echo " Mode 6 (asinh): beta = ${BEST_HPARAM_6} (autotuned)" +echo " Mode 7 (log1p): alpha = ${BEST_HPARAM_7} (autotuned)" echo " Mode 8 (trunc8): (fixed)" -echo " Mode 9 (erf): alpha = ${BEST_POWER_9} (autotuned)" -echo " Mode 10 (tanh): alpha = ${BEST_POWER_10} (autotuned)" +echo " Mode 9 (erf): alpha = ${BEST_HPARAM_9} (autotuned)" +echo " Mode 10 (tanh): alpha = ${BEST_HPARAM_10} (autotuned)" echo " Mode 11 (subtract): (fixed)" echo " Mode 12 (tail_win): rho = 4.0" -echo " Mode 13 (exp_stretch):alpha = ${BEST_POWER_13} (autotuned)" -echo " Mode 14 (topk_window):rho = ${BEST_POWER_14} (autotuned)" +echo " Mode 13 (exp_stretch):alpha = ${BEST_HPARAM_13} (autotuned)" +echo " Mode 14 (topk_window):rho = ${BEST_HPARAM_14} (autotuned)" echo "============================================================" diff --git a/setup.py b/setup.py index 9c2186b..0fc46ad 100644 --- a/setup.py +++ b/setup.py @@ -18,6 +18,7 @@ 'csrc/utils_sglang.cu', 'csrc/topk.cu', 'csrc/topk_sglang.cu', + 'csrc/topk_sglang_profile.cu', ], include_dirs=['csrc'], extra_compile_args={ diff --git a/vortex_torch/indexer/output_func.py b/vortex_torch/indexer/output_func.py index b50ca74..e4424cd 100644 --- a/vortex_torch/indexer/output_func.py +++ b/vortex_torch/indexer/output_func.py @@ -247,7 +247,7 @@ def execute(self, x: torch.Tensor, o: torch.Tensor, ctx: Context) -> torch.Tenso if self.topk_type == "sglang": # topk_output_sglang: (x, dense_kv_indptr, sparse_kv_indptr, dense_kv_indices, sparse_kv_indices, ...) 
mapping_mode = getattr(ctx, 'topk_mapping_mode', 0) - mapping_power = getattr(ctx, 'topk_mapping_power', 0.5) + mapping_hparam = getattr(ctx, 'topk_mapping_hparam', getattr(ctx, 'topk_mapping_power', 0.5)) mapping_lut = getattr(ctx, 'topk_mapping_lut', None) mapping_quantiles = getattr(ctx, 'topk_mapping_quantiles', None) mapping_noscale = getattr(ctx, 'topk_mapping_noscale', False) @@ -268,7 +268,7 @@ def execute(self, x: torch.Tensor, o: torch.Tensor, ctx: Context) -> torch.Tenso ctx.page_reserved_eos, ctx.max_num_pages_per_request, mapping_mode, - mapping_power, + mapping_hparam, mapping_lut, mapping_quantiles, mapping_noscale, @@ -320,7 +320,7 @@ def execute(self, x: torch.Tensor, o: torch.Tensor, ctx: Context) -> torch.Tenso ctx.page_reserved_bos, ctx.page_reserved_eos, mapping_mode, - mapping_power, + mapping_hparam, mapping_lut, mapping_quantiles, mapping_noscale, From 524834a5f5448e9595d746e95966e2190d4b9622 Mon Sep 17 00:00:00 2001 From: UED Date: Mon, 13 Apr 2026 03:09:46 -0400 Subject: [PATCH 18/22] Refactor TopK mapping and benchmarking scripts for enhanced profiling and usability - Updated autotune_topk_mapping.py to optimize hyperparameter tuning based on kernel latency. - Simplified the sweep grid and improved documentation for usage. - Enhanced bench_topk.py to expose public helpers and added CLI modes for benchmarking. - Introduced new remap functions and improved kernel integration for profiling. - Added watchdog timeout option in calibrate_topk.py for SGLang scheduler. - Removed outdated greedy_layer_search.py as part of code cleanup. 
--- benchmarks/autotune_topk_mapping.py | 797 ++++------ benchmarks/bench_topk.py | 1139 +++++--------- benchmarks/calibrate_topk.py | 13 +- benchmarks/greedy_layer_search.py | 117 -- csrc/archived/README.md | 19 + csrc/archived/fast_topk_vortex_prepass.cu | 525 +++++++ csrc/archived/topk_mapping_full.cuh | 217 +++ csrc/archived/topk_sglang_ori_fastpath.cu | 319 ++++ csrc/{ => archived}/topk_slgang_ori.cu | 0 csrc/register.cc | 42 +- csrc/register.h | 43 +- csrc/topk_mapping.cuh | 208 +-- csrc/topk_sglang.cu | 1231 ++++++--------- csrc/topk_sglang_profile.cu | 1329 +++++------------ examples/remap_function_bench.sh | 238 +++ examples/run_distribution_analysis.sh | 254 ++-- examples/run_distribution_analysis_new.sh | 194 +-- examples/test_topk.py | 118 ++ examples/verify_algo.py | 30 +- examples/verify_algo.sh | 7 +- examples/verify_algo_topk_mapping.sh | 396 ++--- examples/verify_algo_topk_mapping_new.sh | 433 ++---- ...ackends.sh => verify_external_backends.sh} | 0 vortex_torch/indexer/context.py | 27 +- vortex_torch/indexer/output_func.py | 51 +- 25 files changed, 3553 insertions(+), 4194 deletions(-) delete mode 100644 benchmarks/greedy_layer_search.py create mode 100644 csrc/archived/README.md create mode 100644 csrc/archived/fast_topk_vortex_prepass.cu create mode 100644 csrc/archived/topk_mapping_full.cuh create mode 100644 csrc/archived/topk_sglang_ori_fastpath.cu rename csrc/{ => archived}/topk_slgang_ori.cu (100%) create mode 100755 examples/remap_function_bench.sh create mode 100644 examples/test_topk.py rename examples/{verify_sparse_backends.sh => verify_external_backends.sh} (100%) diff --git a/benchmarks/autotune_topk_mapping.py b/benchmarks/autotune_topk_mapping.py index 8051c14..db21321 100644 --- a/benchmarks/autotune_topk_mapping.py +++ b/benchmarks/autotune_topk_mapping.py @@ -1,254 +1,131 @@ """ -Auto-tuner for TopK mapping hyperparameters. +Auto-tune TopK mapping hyperparameters by profiled kernel latency. 
-Sweeps all (mode, hyperparameter) combinations using the topk_hit_rate -kernel and ranks by Stage 1 resolution rate. +For each (mode, hyperparameter) combo in the sweep grid, this script runs +the fused remap+topk kernel (topk_output_sglang_fused) on synthetic or +real-distribution inputs, measures end-to-end latency with CUDA events, +and picks the hyperparameter with the lowest measured latency per mode. -Supports real-data score distributions via --real-histograms: loads the -raw_histograms.npy from calibration and synthesizes score tensors that -match the real bin distribution (by reversing the convert_to_uint8 mapping). - -Sweep grid: - - Mode 3 (power): p in [0.1, 0.25, 0.75, 0.9] - - Mode 6 (asinh): beta in [0.1, 0.5, 1, 2, 4] - - Mode 7 (log1p): alpha in [0.1, 0.5, 0.75, 1, 2, 4, 8] - - Baselines: mode 0 (none), mode 4 (log) +Distribution statistics (gini, max/mean, counter-based Stage-2 cost) are +still collected for diagnostics, but they do NOT drive the ranking — the +ranking is purely latency-driven. 
Usage: - python benchmarks/autotune_topk_mapping.py --topk-val 30 --real-histograms calibration/raw_histograms.npy - python benchmarks/autotune_topk_mapping.py --topk-val 30 --output-json results.json + python benchmarks/autotune_topk_mapping.py \\ + --topk-val 2048 --batch-size 4 --seq-len 65536 --num-kv-heads 8 \\ + --real-histograms calibration/raw_histograms.npy \\ + --output-json autotune_results.json """ import argparse import json import math -from typing import List +from typing import Dict, List, Optional import numpy as np import torch from bench_topk import make_topk_inputs, bench_kernel, compute_histogram_stats -from vortex_torch_C import topk_profile_histogram, topk_profile_counters, topk_output_sglang - - - -SWEEP_GRID = { - # (mode, param_name, param_values) - 3: ("power_exp", [0.1, 0.25, 0.5, 0.75, 0.9, 2.0, 4.0]), - 6: ("beta", [0.1, 0.5, 1.0, 2.0, 4.0]), - 7: ("alpha", [0.1, 0.5, 0.75, 1.0, 2.0, 4.0, 8.0]), - 9: ("alpha", [0.1, 0.5, 1.0, 2.0, 4.0]), - 10: ("alpha", [0.1, 0.5, 1.0, 2.0, 4.0]), - 13: ("alpha", [0.5, 1.0, 2.0, 4.0, 8.0]), - 14: ("rho", [2.0, 4.0, 8.0, 16.0]), -} -BASELINES = { - 0: ("none", 0.5), - 4: ("log", 0.5), - 8: ("trunc8", 0.5), - 11: ("subtract", 0.5), -} -# Noscale baselines for parametric transform modes (skip auto-range pre-pass) -NOSCALE_BASELINES = { - 3: ("power_noscale", [0.5]), - 6: ("asinh_noscale", [1.0]), - 7: ("log1p_noscale", [1.0]), - 9: ("erf_noscale", [1.0]), - 10: ("tanh_noscale", [1.0]), - 13: ("exp_stretch_noscale", [1.0, 4.0]), +from vortex_torch_C import ( + topk_output_sglang_fused, + topk_profile_histogram, + topk_profile_counters, +) + + +# Only parametric modes need auto-tuning. Mode 0 (none) and mode 4 (log) +# have no knob; mode 0 is always the baseline. 
+SWEEP_GRID: Dict[int, List[float]] = { + 3: [0.1, 0.25, 0.5, 0.75, 0.9], # power: p + 6: [0.1, 0.5, 1.0, 2.0, 4.0], # asinh: beta + 7: [0.1, 0.5, 1.0, 2.0, 4.0, 8.0], # log1p: alpha + 9: [0.1, 0.5, 1.0, 2.0, 4.0], # erf: alpha + 10: [0.1, 0.5, 1.0, 2.0, 4.0], # tanh: alpha + 11: [-1.0, -0.5, 0.0, 0.5, 1.0], # subtract: pivot (free hparam) + 13: [0.5, 1.0, 2.0, 4.0, 8.0], # exp_stretch: alpha } + +PARAM_NAME = {3: "p", 6: "beta", 7: "alpha", 9: "alpha", 10: "alpha", 11: "pivot", 13: "alpha"} MODE_NAMES = { - 0: "none", - 3: "power", - 4: "log", - 6: "asinh", - 7: "log1p", - 8: "trunc8", - 9: "erf", - 10: "tanh", - 11: "subtract", - 13: "exp_stretch", - 14: "topk_window", + 0: "none", 1: "lut_cdf", 2: "quantile", + 3: "power", 4: "log", 6: "asinh", 7: "log1p", + 8: "trunc8", 9: "erf", 10: "tanh", 11: "subtract", 13: "exp_stretch", } +# Non-parametric modes — no knob to sweep; timed once as a reference point. +# LUT_CDF (1) and QUANTILE (2) are added here at runtime when the caller +# passes --lut-path / --quantiles-path. +BASELINES = [(0, 0.5), (4, 0.5), (8, 0.5)] -def _key_to_fp16(key: int) -> np.float16: - """Invert the convert_to_uint8 sign-flip for a single 16-bit key.""" - if key >= 0x8000: - bits = key & 0x7FFF - else: - bits = (~key) & 0xFFFF - return np.array([bits], dtype=np.uint16).view(np.float16)[0] +# ---------- Real-distribution score generation ---------- -def build_bin_range_table(): - """Build per-bin (lo, hi) fp16 value tables by iterating all 65536 fp16 bit patterns. +def _key_to_fp16(key: int) -> np.float16: + """Invert convert_to_uint8's sign-flip for a single 16-bit key.""" + bits = (key & 0x7FFF) if key >= 0x8000 else ((~key) & 0xFFFF) + return np.array([bits], dtype=np.uint16).view(np.float16)[0] - For each fp16 value, compute its bin via convert_to_uint8 logic, then track - the min/max fp16 value that lands in each bin. - Returns: - (bin_lo, bin_hi): two [256] float32 arrays — the min and max fp16 values per bin. 
- """ - # Generate all 65536 fp16 bit patterns +def _build_bin_range_table(): + """Return per-bin (lo, hi) fp16 value tables for all 256 radix bins.""" all_bits = np.arange(65536, dtype=np.uint16) all_fp16 = all_bits.view(np.float16) - - # Compute convert_to_uint8 for each: key = sign-flip, bin = key >> 8 keys = np.where( (all_bits & 0x8000).astype(bool), (~all_bits).astype(np.uint16), all_bits | np.uint16(0x8000), ) bins = (keys >> 8).astype(np.uint8) - - # Convert to float32 for min/max (fp16 has NaNs/Infs, filter them) all_f32 = all_fp16.astype(np.float32) valid = np.isfinite(all_f32) - bin_lo = np.full(256, np.inf, dtype=np.float32) bin_hi = np.full(256, -np.inf, dtype=np.float32) - for b in range(256): mask = (bins == b) & valid if mask.any(): vals = all_f32[mask] bin_lo[b] = vals.min() bin_hi[b] = vals.max() - - # For any bin with no valid fp16 values, fall back to midpoint empty = bin_lo > bin_hi for b in np.where(empty)[0]: - mid_key = (int(b) << 8) | 0x80 - val = float(_key_to_fp16(mid_key)) + val = float(_key_to_fp16((int(b) << 8) | 0x80)) bin_lo[b] = val bin_hi[b] = val - return bin_lo, bin_hi -def generate_remap_lut(mode: int, param: float) -> np.ndarray: - """Generate a 256-entry uint8 LUT that approximates a transform mode. - - For each of the 256 fp16 radix bins, compute the transform of the - bin's midpoint value, then linearly map transformed values to [0,255]. - The resulting LUT can be used with mode=1 (LUT CDF) infrastructure, - replacing expensive per-element transcendental math with a single - shared memory lookup. 
- - Args: - mode: TopKMappingMode (3=Power, 4=Log, 6=Asinh, 7=Log1p, 9=Erf, 10=Tanh) - param: power_exp/beta/alpha for the transform - - Returns: - lut: [256] uint8 array mapping original_bin -> remapped_bin - """ - bin_lo, bin_hi = build_bin_range_table() - midpoints = (bin_lo + bin_hi) / 2.0 # [256] float32 - - # Apply transform - if mode == 3: # power - transformed = np.sign(midpoints) * np.abs(midpoints) ** param - elif mode == 4: # log - transformed = np.sign(midpoints) * np.log(np.abs(midpoints) + 1.0) - elif mode == 6: # asinh - transformed = np.arcsinh(param * midpoints) - elif mode == 7: # log1p - transformed = np.sign(midpoints) * np.log1p(param * np.abs(midpoints)) - elif mode == 9: # erf - from scipy.special import erf - transformed = erf(param * midpoints) - elif mode == 10: # tanh - transformed = np.tanh(param * midpoints) - else: - # Identity fallback - transformed = midpoints.copy() - - # Handle NaN/Inf from edge cases - transformed = np.nan_to_num(transformed, nan=0.0, posinf=0.0, neginf=0.0) - - # Linear map to [0, 255] - tmin, tmax = transformed.min(), transformed.max() - if tmax > tmin: - lut = np.clip(((transformed - tmin) / (tmax - tmin) * 255), 0, 255).astype(np.uint8) - else: - lut = np.full(256, 128, dtype=np.uint8) - - return lut - - -def scores_from_histogram( - histogram: np.ndarray, - total_pages: int, - device: str = "cuda", -) -> torch.Tensor: - """Generate score tensor matching a real bin distribution. - - For each sampled bin, generates a uniform random fp16 value within the - bin's actual value range (not just the midpoint), so that mapped transforms - see diverse input values. 
- - Args: - histogram: [256] aggregated bin counts from calibration - total_pages: number of score entries to generate - device: torch device - - Returns: - scores: [total_pages, 1, 1] bfloat16 tensor - """ - bin_lo, bin_hi = build_bin_range_table() - - # Normalize histogram to probability distribution +def _scores_from_histogram(histogram: np.ndarray, total_pages: int, device="cuda") -> torch.Tensor: + bin_lo, bin_hi = _build_bin_range_table() counts = histogram.astype(np.float64) total = counts.sum() if total == 0: return torch.zeros(total_pages, 1, 1, dtype=torch.bfloat16, device=device) probs = counts / total - - # Sample bin indices according to the real distribution bin_indices = np.random.choice(256, size=total_pages, p=probs) - - # Uniform random within each bin's fp16 range lo = bin_lo[bin_indices] hi = bin_hi[bin_indices] rand = np.random.uniform(0, 1, size=total_pages).astype(np.float32) scores_f32 = lo + rand * (hi - lo) + return torch.from_numpy(scores_f32).to(torch.bfloat16).reshape(total_pages, 1, 1).to(device) + - # Convert float32 -> bfloat16 tensor - scores = torch.from_numpy(scores_f32).to(torch.bfloat16) - return scores.reshape(total_pages, 1, 1).to(device) - - -def make_real_inputs( - batch_size: int, - num_kv_heads: int, - seq_len: int, - page_size: int, - topk_val: int, - reserved_bos: int, - reserved_eos: int, - histogram: np.ndarray, - device: str = "cuda", -) -> dict: - """Build CSR-formatted inputs with scores matching a real histogram.""" - eff_batch_size = batch_size * num_kv_heads - num_pages_per_seg = math.ceil(seq_len / page_size) - total_dense_pages = eff_batch_size * num_pages_per_seg - sparse_per_seg = min(topk_val + reserved_bos + reserved_eos, num_pages_per_seg) - total_sparse_pages = eff_batch_size * sparse_per_seg +def _make_real_inputs(args, histogram: np.ndarray) -> dict: + eff_bs = args.batch_size * args.num_kv_heads + num_pages_per_seg = math.ceil(args.seq_len / args.page_size) + total_dense = eff_bs * num_pages_per_seg + 
sparse_per_seg = min(args.topk_val + args.reserved_bos + args.reserved_eos, num_pages_per_seg) dense_kv_indptr = torch.arange( - 0, (eff_batch_size + 1) * num_pages_per_seg, num_pages_per_seg, - dtype=torch.int32, device=device, + 0, (eff_bs + 1) * num_pages_per_seg, num_pages_per_seg, + dtype=torch.int32, device="cuda", ) sparse_kv_indptr = torch.arange( - 0, (eff_batch_size + 1) * sparse_per_seg, sparse_per_seg, - dtype=torch.int32, device=device, + 0, (eff_bs + 1) * sparse_per_seg, sparse_per_seg, + dtype=torch.int32, device="cuda", ) - dense_kv_indices = torch.arange(total_dense_pages, dtype=torch.int32, device=device) - sparse_kv_indices = torch.zeros(total_sparse_pages, dtype=torch.int32, device=device) - - x = scores_from_histogram(histogram, total_dense_pages, device=device) + dense_kv_indices = torch.arange(total_dense, dtype=torch.int32, device="cuda") + sparse_kv_indices = torch.zeros(eff_bs * sparse_per_seg, dtype=torch.int32, device="cuda") + x = _scores_from_histogram(histogram, total_dense) return { "x": x, @@ -256,402 +133,240 @@ def make_real_inputs( "sparse_kv_indptr": sparse_kv_indptr, "dense_kv_indices": dense_kv_indices, "sparse_kv_indices": sparse_kv_indices, - "eff_batch_size": eff_batch_size, + "eff_batch_size": eff_bs, "num_pages_per_seg": num_pages_per_seg, "sparse_per_seg": sparse_per_seg, } -def run_sweep(args) -> List[dict]: - """Run all (mode, hyperparam) combos and return ranked results.""" - results = [] - - # Load real histogram if provided - real_histogram = None - if args.real_histograms: - raw = np.load(args.real_histograms) # [num_segments, 256] - real_histogram = raw.sum(axis=0) if raw.ndim > 1 else raw # aggregate to [256] - - distributions = args.distributions - if real_histogram is not None: - distributions = ["real"] - - for dist in distributions: - if dist == "real": - inputs = make_real_inputs( - batch_size=args.batch_size, - num_kv_heads=args.num_kv_heads, - seq_len=args.seq_len, - page_size=args.page_size, - 
topk_val=args.topk_val, - reserved_bos=args.reserved_bos, - reserved_eos=args.reserved_eos, - histogram=real_histogram, - ) - else: - inputs = make_topk_inputs( - batch_size=args.batch_size, - num_kv_heads=args.num_kv_heads, - seq_len=args.seq_len, - page_size=args.page_size, - topk_val=args.topk_val, - reserved_bos=args.reserved_bos, - reserved_eos=args.reserved_eos, - score_dtype=torch.bfloat16, - distribution=dist, - ) - - eff_bs = inputs["eff_batch_size"] - - def evaluate(mode: int, power: float, label: str, noscale: bool = False, - lut_tensor=None): - hists = torch.zeros(eff_bs, 256, dtype=torch.int32, device="cuda") - topk_profile_histogram( - inputs["x"], - inputs["dense_kv_indptr"], - hists, - eff_bs, - args.reserved_bos, - args.reserved_eos, - mode, - power, - lut_tensor, # lut - None, # quantiles - noscale, - ) - torch.cuda.synchronize() - stats = compute_histogram_stats(hists) - result = { - "label": label, - "mode": mode, - "mode_name": MODE_NAMES.get(mode, f"m{mode}"), - "param": power, - "noscale": noscale, - "distribution": dist, - "gini": stats["gini"], - "max_mean_ratio": stats["max_mean_ratio"], - "num_nonzero_bins": stats["num_nonzero_bins"], - } - - # Counter-based metrics (Stage 2 cost analysis) - if args.counters: - inputs["sparse_kv_indices"].zero_() - counter_buf = torch.zeros(eff_bs, 6, dtype=torch.int32, device="cuda") - topk_profile_counters( - inputs["x"], - inputs["dense_kv_indptr"], - inputs["sparse_kv_indptr"], - inputs["dense_kv_indices"], - inputs["sparse_kv_indices"], - counter_buf, - eff_bs, - args.topk_val, - args.reserved_bos, - args.reserved_eos, - inputs["num_pages_per_seg"], - mode, - power, - lut_tensor, # lut - None, # quantiles - noscale, - ) - torch.cuda.synchronize() - c = counter_buf.float() - result["num_equal_mean"] = c[:, 2].mean().item() - result["remaining_k_mean"] = c[:, 3].mean().item() - result["refine_rounds_mean"] = c[:, 4].mean().item() - result["stage2_input_mean"] = c[:, 5].mean().item() - 
result["res_rate_mean"] = (c[:, 3] == 0).float().mean().item() - - return result - - # Baselines - for mode, (name, default_power) in BASELINES.items(): - r = evaluate(mode, default_power, f"m{mode}_{name}") - results.append(r) - - # Parametric sweep (scaled) - for mode, (param_name, values) in SWEEP_GRID.items(): - mname = MODE_NAMES[mode] - for val in values: - label = f"m{mode}_{mname}_{param_name}={val}" - r = evaluate(mode, val, label) - results.append(r) - - # Noscale sweep for parametric modes - for mode, (name, values) in NOSCALE_BASELINES.items(): - mname = MODE_NAMES[mode] - for val in values: - label = f"m{mode}_{mname}_noscale_{val}" - r = evaluate(mode, val, label, noscale=True) - results.append(r) - - # LUT approximation sweep: generate a LUT for each (mode, param) and - # evaluate via mode=1 (LUT CDF). This replaces per-element transcendentals - # with a single shared memory lookup. - if args.lut_sweep: - lut_modes = { - 3: [0.25, 0.5, 0.75], - 6: [0.5, 1.0, 2.0], - 7: [0.5, 1.0, 2.0], - 9: [0.5, 1.0, 2.0], - 10: [0.5, 1.0, 2.0], - } - for src_mode, params in lut_modes.items(): - src_name = MODE_NAMES[src_mode] - for p in params: - try: - lut_np = generate_remap_lut(src_mode, p) - lut_t = torch.from_numpy(lut_np).cuda() - label = f"lut_{src_name}_{p}" - # Evaluate as mode=1 (LUT CDF) with the generated LUT - r = evaluate(1, 0.5, label, lut_tensor=lut_t) - r["lut_source_mode"] = src_mode - r["lut_source_param"] = p - results.append(r) - except ImportError: - # scipy not available for erf - pass - - return results - - -def print_table(results: List[dict], show_latency: bool = False): - """Print ranked results as a formatted table.""" - has_counters = any("res_rate_mean" in r for r in results) - has_latency = any("full_kernel_ms" in r for r in results) - - # Primary ranking: by res_rate_mean (higher=better) if counters, else by gini (lower=better) - if has_counters: - ranked = sorted(results, key=lambda r: -r.get("res_rate_mean", 0.0)) - rank_label = 
"ranked by res_rate, higher=better" - else: - ranked = sorted(results, key=lambda r: r["gini"]) - rank_label = "ranked by Gini, lower=better" - - # Build header - cols = f"{'Rank':>4s} {'Label':<35s} {'Dist':<12s} {'Gini':>6s} {'Max/Mean':>8s} {'NZBins':>6s}" - if has_counters: - cols += f" {'ResRate':>7s} {'RemK':>5s} {'Rnds':>4s} {'S2In':>5s}" - if has_latency and show_latency: - cols += f" {'LatMs':>9s} {'LatRk':>5s}" - - print(f"\n{'=' * len(cols)}") - print(f"TopK Mapping Auto-Tune Results ({rank_label})") - print("=" * len(cols)) - print(cols) - print("-" * len(cols)) - - for i, r in enumerate(ranked): - noscale_tag = " [NS]" if r.get("noscale", False) else "" - line = ( - f"{i+1:4d} {r['label'] + noscale_tag:<35s} {r['distribution']:<12s} " - f"{r['gini']:6.3f} " - f"{r['max_mean_ratio']:8.2f} {r['num_nonzero_bins']:6d}" - ) - if has_counters: - rr = r.get("res_rate_mean", 0.0) - rk = r.get("remaining_k_mean", 0.0) - rnds = r.get("refine_rounds_mean", 0.0) - s2in = r.get("stage2_input_mean", 0.0) - line += f" {rr:7.3f} {rk:5.0f} {rnds:4.1f} {s2in:5.0f}" - if has_latency and show_latency: - lat = r.get("full_kernel_ms", float("nan")) - lat_rank = r.get("latency_rank", "-") - line += f" {lat:9.4f} {lat_rank:>5s}" if isinstance(lat_rank, str) else f" {lat:9.4f} {lat_rank:5d}" - print(line) - - print("=" * len(cols)) - if ranked: - best = ranked[0] - msg = ( - f"\nBest overall: {best['label']} (dist={best['distribution']}) " - f"— gini={best['gini']:.3f}, max/mean={best['max_mean_ratio']:.2f}" - ) - if has_counters: - msg += f", res_rate={best.get('res_rate_mean', 0):.3f}" - if "full_kernel_ms" in best: - msg += f", latency={best['full_kernel_ms']:.4f}ms" - print(msg) - - # If latency data available, also print best by latency - if has_latency and show_latency: - lat_ranked = sorted([r for r in results if "full_kernel_ms" in r], - key=lambda r: r["full_kernel_ms"]) - if lat_ranked: - best_lat = lat_ranked[0] - print( - f"Best by latency: {best_lat['label']} 
(dist={best_lat['distribution']}) " - f"— latency={best_lat['full_kernel_ms']:.4f}ms, gini={best_lat['gini']:.3f}" - ) - - # Per-mode best summary - mode_best = {} - for r in results: - m = r["mode"] - if has_counters: - is_better = m not in mode_best or r.get("res_rate_mean", 0) > mode_best[m].get("res_rate_mean", 0) - else: - is_better = m not in mode_best or r["gini"] < mode_best[m]["gini"] - if is_better: - mode_best[m] = r - - if mode_best: - print("\nBest per mode:") - for m in sorted(mode_best.keys()): - r = mode_best[m] - mname = MODE_NAMES.get(m, f"m{m}") - if m in SWEEP_GRID: - param_name = SWEEP_GRID[m][0] - param_str = f"{param_name}={r['param']}" - else: - param_str = "(baseline)" - ns_str = " noscale" if r.get("noscale", False) else "" - lat_str = f" latency={r['full_kernel_ms']:.4f}ms" if "full_kernel_ms" in r else "" - counter_str = f" res_rate={r.get('res_rate_mean', 0):.3f}" if has_counters else "" - print( - f" Mode {m:d} ({mname:>5s}{ns_str}): {param_str:<20s} " - f"gini={r['gini']:.3f} max/mean={r['max_mean_ratio']:.2f}{counter_str}{lat_str}" - ) - - -def latency_rerank(results: List[dict], args) -> List[dict]: - """Re-rank top Gini candidates by actual kernel latency.""" - # Sort by Gini, take top N - ranked = sorted(results, key=lambda r: r["gini"]) - finalists = ranked[:args.latency_top_n] - - print(f"\n--- Latency re-ranking: timing top {len(finalists)} Gini finalists ---") +# ---------- Latency-based evaluation ---------- - # Build inputs for latency measurement - real_histogram = None - if args.real_histograms: - raw = np.load(args.real_histograms) - real_histogram = raw.sum(axis=0) if raw.ndim > 1 else raw +def _time_fused(inputs, args, mode: int, power: float) -> dict: + eff_bs = inputs["eff_batch_size"] + pages_per_seg = inputs["num_pages_per_seg"] + inputs["sparse_kv_indices"].zero_() + lut_t = getattr(args, "_mapping_lut", None) if mode == 1 else None + q_t = getattr(args, "_mapping_quantiles", None) if mode == 2 else None + 
call_args = ( + inputs["x"], + inputs["dense_kv_indptr"], + inputs["sparse_kv_indptr"], + inputs["dense_kv_indices"], + inputs["sparse_kv_indices"], + eff_bs, + args.topk_val, + args.reserved_bos, + args.reserved_eos, + pages_per_seg, + mode, + power, + lut_t, + q_t, + ) + return bench_kernel(topk_output_sglang_fused, call_args, + warmup=args.warmup, repeat=args.repeat) - if real_histogram is not None: - inputs = make_real_inputs( - batch_size=args.batch_size, - num_kv_heads=args.num_kv_heads, - seq_len=args.seq_len, - page_size=args.page_size, - topk_val=args.topk_val, - reserved_bos=args.reserved_bos, - reserved_eos=args.reserved_eos, - histogram=real_histogram, - ) - else: - inputs = make_topk_inputs( - batch_size=args.batch_size, - num_kv_heads=args.num_kv_heads, - seq_len=args.seq_len, - page_size=args.page_size, - topk_val=args.topk_val, - reserved_bos=args.reserved_bos, - reserved_eos=args.reserved_eos, - score_dtype=torch.bfloat16, - distribution="normal", - ) +def _collect_diagnostics(inputs, args, mode: int, power: float) -> dict: + """Optional distribution/counter stats for reporting only (post-timing).""" eff_bs = inputs["eff_batch_size"] pages_per_seg = inputs["num_pages_per_seg"] + diag = {} + lut_t = getattr(args, "_mapping_lut", None) if mode == 1 else None + q_t = getattr(args, "_mapping_quantiles", None) if mode == 2 else None + + if args.collect_stats: + hist = torch.zeros(eff_bs, 256, dtype=torch.int32, device="cuda") + topk_profile_histogram( + inputs["x"], inputs["dense_kv_indptr"], hist, + eff_bs, args.reserved_bos, args.reserved_eos, + mode, power, lut_t, q_t, + ) + torch.cuda.synchronize() + diag.update(compute_histogram_stats(hist)) - for r in finalists: + counter_buf = torch.zeros(eff_bs, 6, dtype=torch.int32, device="cuda") inputs["sparse_kv_indices"].zero_() - # For LUT-generated entries, regenerate the LUT tensor - lut_tensor = None - if "lut_source_mode" in r: - lut_np = generate_remap_lut(r["lut_source_mode"], r["lut_source_param"]) 
- lut_tensor = torch.from_numpy(lut_np).cuda() - call_args = ( + topk_profile_counters( inputs["x"], inputs["dense_kv_indptr"], inputs["sparse_kv_indptr"], inputs["dense_kv_indices"], inputs["sparse_kv_indices"], + counter_buf, eff_bs, args.topk_val, args.reserved_bos, args.reserved_eos, pages_per_seg, - r["mode"], - r["param"], - lut_tensor, # lut - None, # quantiles - r.get("noscale", False), + mode, + power, + lut_t, + q_t, + ) + torch.cuda.synchronize() + c = counter_buf.float() + diag["threshold_bin_mean"] = c[:, 0].mean().item() + diag["num_equal_mean"] = c[:, 2].mean().item() + diag["refine_rounds_mean"] = c[:, 4].mean().item() + + return diag + + +def _run_sweep(args, inputs, dist_label: str) -> List[dict]: + results = [] + + # Baselines: time them but their param is fixed. + for mode, power in BASELINES: + lat = _time_fused(inputs, args, mode, power) + entry = { + "mode": mode, + "mode_name": MODE_NAMES.get(mode, f"m{mode}"), + "param_name": "(baseline)", + "param": power, + "distribution": dist_label, + "latency_ms": lat["mean_ms"], + "latency_median_ms": lat["median_ms"], + "latency_min_ms": lat["min_ms"], + } + entry.update(_collect_diagnostics(inputs, args, mode, power)) + results.append(entry) + print( + f" mode={mode:>2d} ({MODE_NAMES[mode]:>5s}) baseline " + f" latency={lat['mean_ms']:.4f} ms" ) - latency = bench_kernel(topk_output_sglang, call_args, - warmup=10, repeat=args.latency_repeat) - r["full_kernel_ms"] = latency["mean_ms"] - print(f" {r['label']:<35s} gini={r['gini']:.3f} latency={latency['mean_ms']:.4f}ms") - # Re-rank finalists by latency - finalists.sort(key=lambda r: r["full_kernel_ms"]) - for i, r in enumerate(finalists): - r["latency_rank"] = i + 1 - r["gini_rank"] = next(j+1 for j, x in enumerate(ranked) if x is r) + # Parametric sweep, one (mode, param) combo at a time. 
+ for mode, values in SWEEP_GRID.items(): + pname = PARAM_NAME[mode] + for val in values: + lat = _time_fused(inputs, args, mode, float(val)) + entry = { + "mode": mode, + "mode_name": MODE_NAMES.get(mode, f"m{mode}"), + "param_name": pname, + "param": float(val), + "distribution": dist_label, + "latency_ms": lat["mean_ms"], + "latency_median_ms": lat["median_ms"], + "latency_min_ms": lat["min_ms"], + } + entry.update(_collect_diagnostics(inputs, args, mode, float(val))) + results.append(entry) + print( + f" mode={mode:>2d} ({MODE_NAMES[mode]:>5s}) {pname}={val:<6.3f} " + f" latency={lat['mean_ms']:.4f} ms" + ) return results -def main(): - parser = argparse.ArgumentParser( - description="Auto-tune TopK mapping hyperparameters" +def _print_ranked(results: List[dict]) -> None: + ranked = sorted(results, key=lambda r: r["latency_ms"]) + header = ( + f"{'Rank':>4s} {'Mode':<12s} {'Param':<14s} {'Dist':<10s} {'Latency (ms)':>14s}" ) + print("\n" + "=" * len(header)) + print("TopK auto-tune results (ranked by measured kernel latency, lower is better)") + print("=" * len(header)) + print(header) + print("-" * len(header)) + for i, r in enumerate(ranked): + param_str = f"{r['param_name']}={r['param']}" if r["param_name"] != "(baseline)" else "(baseline)" + print( + f"{i + 1:4d} {r['mode_name']:<12s} {param_str:<14s} " + f"{r['distribution']:<10s} {r['latency_ms']:14.4f}" + ) + print("=" * len(header)) + + # Best per mode. 
+ best: Dict[int, dict] = {} + for r in results: + m = r["mode"] + if m not in best or r["latency_ms"] < best[m]["latency_ms"]: + best[m] = r + print("\nBest per mode (by latency):") + for m in sorted(best.keys()): + r = best[m] + param_str = f"{r['param_name']}={r['param']}" if r["param_name"] != "(baseline)" else "(baseline)" + print( + f" mode {m:>2d} ({r['mode_name']:>5s}): {param_str:<16s} " + f"latency={r['latency_ms']:.4f} ms" + ) + + +def main(): + parser = argparse.ArgumentParser("TopK mapping hyperparameter auto-tuner (latency-driven)") parser.add_argument("--batch-size", type=int, default=4) - parser.add_argument("--seq-len", type=int, default=4096) - parser.add_argument("--topk-val", type=int, default=30) - parser.add_argument("--num-kv-heads", type=int, default=2) + parser.add_argument("--num-kv-heads", type=int, default=8) + parser.add_argument("--seq-len", type=int, default=65536) + parser.add_argument("--topk-val", type=int, default=2048) parser.add_argument("--page-size", type=int, default=16) parser.add_argument("--reserved-bos", type=int, default=1) parser.add_argument("--reserved-eos", type=int, default=2) - parser.add_argument( - "--distributions", nargs="+", - default=["normal"], - help="Score distributions for synthetic data (ignored when --real-histograms is set)", - ) - parser.add_argument( - "--real-histograms", type=str, default=None, - help="Path to raw_histograms.npy from calibration. 
When set, auto-tunes on " - "real score distribution instead of synthetic data.", - ) - parser.add_argument( - "--output-json", type=str, default=None, - help="Save results to JSON file", - ) - parser.add_argument("--latency-rerank", action="store_true", - help="Re-rank top Gini finalists by actual kernel latency") - parser.add_argument("--latency-top-n", type=int, default=10, - help="Number of Gini finalists to re-rank by latency (default: 10)") - parser.add_argument("--latency-repeat", type=int, default=50, - help="Kernel timing repetitions for latency measurement (default: 50)") - parser.add_argument("--counters", action="store_true", - help="Collect counter-based metrics (Stage 2 cost analysis) for each config") - parser.add_argument("--lut-sweep", action="store_true", - help="Generate and evaluate LUT approximations for parametric transform modes") + parser.add_argument("--distributions", type=str, nargs="+", + default=["normal"], + help="Synthetic distributions when --real-histograms is not set.") + parser.add_argument("--real-histograms", type=str, default=None, + help="Path to raw_histograms.npy from calibration.") + parser.add_argument("--warmup", type=int, default=20) + parser.add_argument("--repeat", type=int, default=100) + parser.add_argument("--collect-stats", action="store_true", + help="Also collect histogram + counter diagnostics (post-timing, no cost).") + parser.add_argument("--output-json", type=str, default=None) + parser.add_argument("--lut-path", type=str, default=None, + help="Path to .npy uint8[256] LUT for MAPPING_LUT_CDF (mode 1).") + parser.add_argument("--quantiles-path", type=str, default=None, + help="Path to .npy float32[256] quantile table for MAPPING_QUANTILE (mode 2).") args = parser.parse_args() - source = f"real ({args.real_histograms})" if args.real_histograms else f"synthetic ({args.distributions})" - print(f"Auto-tuning TopK mapping hyperparameters") - print(f" batch_size={args.batch_size}, seq_len={args.seq_len}, " - 
f"topk_val={args.topk_val}, num_kv_heads={args.num_kv_heads}") - print(f" score source: {source}") - n_parametric = sum(len(v) for _, v in SWEEP_GRID.values()) - n_baselines = len(BASELINES) - n_dists = 1 if args.real_histograms else len(args.distributions) - print(f" sweep: {n_parametric} parametric + {n_baselines} baselines " - f"= {n_parametric + n_baselines} combos x {n_dists} dists") + args._mapping_lut = None + args._mapping_quantiles = None + # Include modes 1/2 as baselines when calibration tables are provided. + if args.lut_path: + lut_np = np.load(args.lut_path).astype(np.uint8) + args._mapping_lut = torch.from_numpy(lut_np).cuda() + BASELINES.append((1, 0.5)) + print(f"[autotune] loaded LUT from {args.lut_path}") + if args.quantiles_path: + q_np = np.load(args.quantiles_path).astype(np.float32) + args._mapping_quantiles = torch.from_numpy(q_np).cuda() + BASELINES.append((2, 0.5)) + print(f"[autotune] loaded quantiles from {args.quantiles_path}") + + real_histogram: Optional[np.ndarray] = None + if args.real_histograms: + raw = np.load(args.real_histograms) + real_histogram = raw.sum(axis=0) if raw.ndim > 1 else raw - results = run_sweep(args) + all_results: List[dict] = [] - if args.latency_rerank: - results = latency_rerank(results, args) + if real_histogram is not None: + inputs = _make_real_inputs(args, real_histogram) + print("\n=== Latency sweep on REAL distribution " + f"(batch={args.batch_size} heads={args.num_kv_heads} seq={args.seq_len} topk={args.topk_val}) ===") + all_results += _run_sweep(args, inputs, "real") + else: + for dist in args.distributions: + inputs = make_topk_inputs( + batch_size=args.batch_size, + num_kv_heads=args.num_kv_heads, + seq_len=args.seq_len, + page_size=args.page_size, + topk_val=args.topk_val, + reserved_bos=args.reserved_bos, + reserved_eos=args.reserved_eos, + score_dtype=torch.bfloat16, + distribution=dist, + ) + print(f"\n=== Latency sweep on synthetic dist={dist} ===") + all_results += _run_sweep(args, inputs, 
dist) - print_table(results, show_latency=args.latency_rerank) + _print_ranked(all_results) if args.output_json: with open(args.output_json, "w") as f: - json.dump(results, f, indent=2) + json.dump(all_results, f, indent=2) print(f"\nResults saved to {args.output_json}") diff --git a/benchmarks/bench_topk.py b/benchmarks/bench_topk.py index a913bde..4bd5bec 100644 --- a/benchmarks/bench_topk.py +++ b/benchmarks/bench_topk.py @@ -1,36 +1,43 @@ """ TopK kernel benchmarking suite. -Measures kernel-level latency for the three topk variants (naive/CUB, -sglang with mapping modes) across configurable grid of batch sizes, -sequence lengths, topk values, and KV head counts. - -Usage: - python benchmarking/bench_topk.py --batch-sizes 4 8 --seq-lens 2048 4096 --topk-vals 30 --num-kv-heads 2 --repeat 50 +Lean rewrite after the remap-benchmark refactor. Exposes three public +helpers used by autotune_topk_mapping.py (make_topk_inputs, bench_kernel, +compute_histogram_stats) and a CLI with two modes: + + - default : time the baseline (unmapped) kernel and the fused + kernel across a grid of (mode, power, batch, seq_len, + topk_val, distribution) configs. + - --remap-bench: time baseline vs fused vs split-phase (remap-only + + unmapped-topk-on-remapped) and report threshold stats + from topk_profile_counters. 
""" import argparse import json import math import statistics -from typing import Dict, List, Optional +from typing import Dict, List import numpy as np import torch from vortex_torch_C import ( - topk_output, topk_output_sglang, topk_output_sglang_ori, topk_profile_histogram, - topk_profile_stage1, topk_profile_counters, + topk_output, + topk_output_sglang, # unmapped baseline + topk_output_sglang_fused, # fused remap + topk + topk_remap_only, # standalone remap + topk_profile_histogram, + topk_profile_counters, ) -# Canonical mapping mode names — used in logs, tables, and plots + MAPPING_MODE_NAMES = { 0: "None", - 1: "LUT CDF", + 1: "LUT_CDF", 2: "Quantile", 3: "Power", 4: "Log", - 5: "Index Cache", 6: "Asinh", 7: "Log1p", 8: "Trunc8", @@ -38,25 +45,30 @@ 10: "Tanh", 11: "Subtract", 13: "ExpStretch", - 14: "TopkWindow", } -MAPPING_MODE_FORMULAS = { - 0: "None (fp16 bucketing)", - 1: "LUT CDF (calibrated)", - 2: "Quantile (calibrated)", - 3: "Power: sign(x)*|x|^p", - 4: "Log: sign(x)*log(|x|+1)", - 5: "Index Cache", - 6: "Asinh: asinh(beta*x)", - 7: "Log1p: sign(x)*log1p(alpha*|x|)", - 8: "Trunc8: bf16 upper-8-bit bucketing", - 9: "Erf: erf(alpha*x)", - 10: "Tanh: tanh(alpha*x)", - 11: "Subtract: x - pivot (RadiK-style)", - 13: "ExpStretch: exp(alpha*x)", - 14: "TopkWindow: k-aware linear windowing", -} + +def _load_autotune_hparams(path: str) -> Dict[int, float]: + """Load per-mode best hyperparameters from an autotune_results.json. + + The JSON is produced by autotune_topk_mapping.py and contains a list of + {mode, param, latency_ms, ...} entries. For each mode we pick the entry + with the lowest measured latency and return {mode: best_param}. + + Modes with no parametric sweep (0=None, 4=Log) return a dummy 0.5; the + caller should override to taste. 
+ """ + with open(path) as f: + data = json.load(f) + best: Dict[int, dict] = {} + for r in data: + m = r.get("mode") + lat = r.get("latency_ms") + if m is None or lat is None: + continue + if m not in best or lat < best[m]["latency_ms"]: + best[m] = r + return {m: float(r["param"]) for m, r in best.items()} def make_topk_inputs( @@ -71,7 +83,7 @@ def make_topk_inputs( distribution: str = "normal", device: str = "cuda", ) -> dict: - """Synthesize realistic CSR-formatted paged attention inputs.""" + """Synthesize CSR-formatted paged attention inputs for kernel timing.""" eff_batch_size = batch_size * num_kv_heads num_pages_per_seg = math.ceil(seq_len / page_size) total_dense_pages = eff_batch_size * num_pages_per_seg @@ -89,7 +101,6 @@ def make_topk_inputs( dense_kv_indices = torch.arange(total_dense_pages, dtype=torch.int32, device=device) sparse_kv_indices = torch.zeros(total_sparse_pages, dtype=torch.int32, device=device) - # Generate scores with the requested distribution if distribution == "normal": x = torch.randn(total_dense_pages, 1, 1, device=device) elif distribution == "lognormal": @@ -97,14 +108,11 @@ def make_topk_inputs( elif distribution == "uniform": x = torch.rand(total_dense_pages, 1, 1, device=device) elif distribution == "bucket_uniform": - # Uniform across all 256 fp16 radix buckets. - # Random uint16 bit patterns → interpret as fp16. - # Bucket = upper 8 bits of sign-flipped fp16, so random bits → uniform buckets. + # Uniform across all 256 fp16 radix buckets. Random uint16 bit + # patterns → interpret as fp16. NaN/Inf patterns collapse to ±0. raw_bits = torch.randint(0, 65536, (total_dense_pages,), dtype=torch.int32, device=device) - # Exclude fp16 NaN/Inf (exponent=31, i.e. 
|bits| >= 0x7C00) abs_bits = raw_bits & 0x7FFF - raw_bits[abs_bits >= 0x7C00] = raw_bits[abs_bits >= 0x7C00] & 0x8000 # → ±0 - # Reinterpret int16 bits as fp16, then widen to float32 + raw_bits[abs_bits >= 0x7C00] = raw_bits[abs_bits >= 0x7C00] & 0x8000 x = raw_bits.to(torch.int16).view(torch.float16).float().reshape(total_dense_pages, 1, 1) else: raise ValueError(f"Unknown distribution: {distribution}") @@ -124,7 +132,7 @@ def make_topk_inputs( def bench_kernel(kernel_fn, args, warmup: int = 10, repeat: int = 100) -> dict: - """Time a kernel with CUDA events, return latency stats in ms.""" + """Time a kernel with CUDA events. Returns latency stats in ms.""" for _ in range(warmup): kernel_fn(*args) torch.cuda.synchronize() @@ -148,796 +156,304 @@ def bench_kernel(kernel_fn, args, warmup: int = 10, repeat: int = 100) -> dict: def compute_histogram_stats(histograms: torch.Tensor) -> dict: - """Compute bin distribution statistics from histogram tensor [B, 256].""" + """Bin distribution statistics from histogram tensor [B, 256].""" h = histograms.float() - # Aggregate across batch dimension h_sum = h.sum(dim=0) # [256] - nonzero_bins = h_sum[h_sum > 0] - if len(nonzero_bins) == 0: + nonzero = h_sum[h_sum > 0] + if len(nonzero) == 0: return { "max_mean_ratio": 0.0, "std": 0.0, "gini": 0.0, "num_nonzero_bins": 0, "entropy": 0.0, "effective_bins": 0.0, } - - mean_val = nonzero_bins.mean().item() - max_val = nonzero_bins.max().item() - std_val = nonzero_bins.std().item() if len(nonzero_bins) > 1 else 0.0 - - # Gini coefficient - sorted_bins = nonzero_bins.sort().values + mean_val = nonzero.mean().item() + max_val = nonzero.max().item() + std_val = nonzero.std().item() if len(nonzero) > 1 else 0.0 + sorted_bins = nonzero.sort().values n = len(sorted_bins) - index = torch.arange(1, n + 1, device=sorted_bins.device, dtype=torch.float32) - gini = (2.0 * (index * sorted_bins).sum() / (n * sorted_bins.sum()) - (n + 1) / n).item() - - # Shannon entropy (base-2) - p = nonzero_bins 
/ nonzero_bins.sum() + idx = torch.arange(1, n + 1, device=sorted_bins.device, dtype=torch.float32) + gini = (2.0 * (idx * sorted_bins).sum() / (n * sorted_bins.sum()) - (n + 1) / n).item() + p = nonzero / nonzero.sum() entropy = -(p * p.log2()).sum().item() - # Effective number of bins: 2^entropy - effective_bins = 2 ** entropy - return { "max_mean_ratio": max_val / mean_val if mean_val > 0 else 0.0, "std": std_val, "gini": max(0.0, gini), - "num_nonzero_bins": int(len(nonzero_bins)), + "num_nonzero_bins": int(len(nonzero)), "entropy": entropy, - "effective_bins": effective_bins, + "effective_bins": 2 ** entropy, } -NUM_HISTOGRAM_BINS = 256 - +def _collect_threshold_stats(inputs, topk_val, pages_per_seg, args, mode: int, power: float) -> dict: + """Run topk_profile_counters once and aggregate threshold-bin stats. -def _histogram_target_pages(pages_per_seg: int, min_samples_per_bin: int = 512) -> int: - """Compute adaptive page count for statistically reliable histograms. - - With 256 radix bins, each bin needs enough samples for stable gini / - max-mean statistics. Returns a total page count rounded up to a full - segment boundary so every segment contributes equally. + Profile kernel is invoked AFTER all latency measurements have finished, + so the counter writes never contaminate timing. 
""" - min_pages = min_samples_per_bin * NUM_HISTOGRAM_BINS - return math.ceil(min_pages / pages_per_seg) * pages_per_seg + eff_bs = inputs["eff_batch_size"] + counter_buf = torch.zeros(eff_bs, 6, dtype=torch.int32, device="cuda") + inputs["sparse_kv_indices"].zero_() + lut_t = getattr(args, "_mapping_lut", None) if mode == 1 else None + q_t = getattr(args, "_mapping_quantiles", None) if mode == 2 else None + topk_profile_counters( + inputs["x"], + inputs["dense_kv_indptr"], + inputs["sparse_kv_indptr"], + inputs["dense_kv_indices"], + inputs["sparse_kv_indices"], + counter_buf, + eff_bs, + topk_val, + args.reserved_bos, + args.reserved_eos, + pages_per_seg, + mode, + power, + lut_t, + q_t, + ) + torch.cuda.synchronize() + c = counter_buf.float() + # Selected from threshold bin = topk_val - num_above (clamped >= 0). + sel_from_thr = (float(topk_val) - c[:, 1]).clamp(min=0.0) + return { + "threshold_bin_mean": c[:, 0].mean().item(), + "threshold_bin_max": c[:, 0].max().item(), + "num_above_mean": c[:, 1].mean().item(), + "threshold_bin_size_mean": c[:, 2].mean().item(), # NUM_EQUAL + "threshold_bin_size_max": c[:, 2].max().item(), + "selected_from_thr_mean": sel_from_thr.mean().item(), + "selected_from_thr_max": sel_from_thr.max().item(), + "refine_rounds_mean": c[:, 4].mean().item(), + } -def _load_autotune_powers(path: str) -> Dict[int, float]: - """Extract best per-mode power from autotune JSON. +def _resolve_hparam(args, mode: int) -> float: + """Pick the hyperparameter for a mode: autotune JSON wins, then --mapping-hparam.""" + if mode == 0: + return 0.5 # unused for MAPPING_NONE + hparams: Dict[int, float] = getattr(args, "_autotune_hparams", {}) or {} + if mode in hparams: + return hparams[mode] + return args.mapping_hparam - Ranks by res_rate_mean (higher=better) if present, else by gini (lower=better). - Returns {mode: best_power}, e.g. {3: 0.25, 6: 1.0, 7: 2.0}. 
- """ - with open(path) as f: - data = json.load(f) - has_res_rate = any("res_rate_mean" in r for r in data) +def _remap_bench_one_config(args, batch_size, num_kv_heads, seq_len, topk_val, + distribution, modes: List[int]) -> dict: + """Time baseline, fused, and split-phase for each mode at one config.""" + inputs = make_topk_inputs( + batch_size=batch_size, + num_kv_heads=num_kv_heads, + seq_len=seq_len, + page_size=args.page_size, + topk_val=topk_val, + reserved_bos=args.reserved_bos, + reserved_eos=args.reserved_eos, + score_dtype=torch.bfloat16, + distribution=distribution, + ) + eff_bs = inputs["eff_batch_size"] + pages_per_seg = inputs["num_pages_per_seg"] + total_dense = inputs["x"].numel() + + # Baseline: unmapped topk. + baseline_args = ( + inputs["x"], + inputs["dense_kv_indptr"], + inputs["sparse_kv_indptr"], + inputs["dense_kv_indices"], + inputs["sparse_kv_indices"], + eff_bs, topk_val, args.reserved_bos, args.reserved_eos, pages_per_seg, + ) + inputs["sparse_kv_indices"].zero_() + baseline = bench_kernel(topk_output_sglang, baseline_args, args.warmup, args.repeat) + + # Pre-allocate the float32 buffer used for the split-phase (remap → baseline). 
+ remapped = torch.empty(total_dense, dtype=torch.float32, device="cuda").reshape(inputs["x"].shape) + + config = { + "batch_size": batch_size, + "num_kv_heads": num_kv_heads, + "seq_len": seq_len, + "topk_val": topk_val, + "distribution": distribution, + "pages_per_seg": pages_per_seg, + "baseline_ms": baseline["mean_ms"], + "modes": [], + } - best: Dict[int, dict] = {} - for r in data: - m = r.get("mode") - if m not in (3, 6, 7, 9, 10, 13, 14): - continue - if has_res_rate: - score = r.get("res_rate_mean", 0.0) - is_better = m not in best or score > best[m]["_score"] - else: - score = r.get("gini", 1.0) - is_better = m not in best or score < best[m]["_score"] - if is_better: - best[m] = {"param": r["param"], "_score": score} + for mode in modes: + power = _resolve_hparam(args, mode) + + lut_t = getattr(args, "_mapping_lut", None) if mode == 1 else None + q_t = getattr(args, "_mapping_quantiles", None) if mode == 2 else None + fused_args = ( + inputs["x"], + inputs["dense_kv_indptr"], + inputs["sparse_kv_indptr"], + inputs["dense_kv_indices"], + inputs["sparse_kv_indices"], + eff_bs, topk_val, args.reserved_bos, args.reserved_eos, pages_per_seg, + mode, power, lut_t, q_t, + ) + inputs["sparse_kv_indices"].zero_() + fused = bench_kernel(topk_output_sglang_fused, fused_args, args.warmup, args.repeat) + + # Split-phase timing: first the standalone remap, then the unmapped + # topk on the remapped buffer. + remap_args = ( + inputs["x"], + inputs["dense_kv_indptr"], + remapped, + eff_bs, args.reserved_bos, args.reserved_eos, + mode, power, + ) + remap_only = bench_kernel(topk_remap_only, remap_args, args.warmup, args.repeat) + + split_topk_args = ( + remapped, + inputs["dense_kv_indptr"], + inputs["sparse_kv_indptr"], + inputs["dense_kv_indices"], + inputs["sparse_kv_indices"], + eff_bs, topk_val, args.reserved_bos, args.reserved_eos, pages_per_seg, + ) + # Run remap once so the buffer is populated for warmup of topk-on-remapped. 
+ topk_remap_only(*remap_args) + torch.cuda.synchronize() + inputs["sparse_kv_indices"].zero_() + split_topk = bench_kernel(topk_output_sglang, split_topk_args, args.warmup, args.repeat) + + # Counter collection is run AFTER all timing measurements for this mode + # so it cannot affect the timings. + stats = _collect_threshold_stats(inputs, topk_val, pages_per_seg, args, mode, power) + + row = { + "mode": mode, + "mode_name": MAPPING_MODE_NAMES.get(mode, f"m{mode}"), + "power": power, + "remap_ms": remap_only["mean_ms"], + "topk_after_remap_ms": split_topk["mean_ms"], + "split_total_ms": remap_only["mean_ms"] + split_topk["mean_ms"], + "fused_ms": fused["mean_ms"], + **stats, + } + config["modes"].append(row) - return {m: v["param"] for m, v in best.items()} + return config -def _resolve_mode_hparam(args, mode: int) -> float: - """Return the power/beta/alpha for a parametric mapping mode. +def _print_remap_table(results: List[dict]) -> None: + header = ( + f"{'mode':<12s} {'remap_us':>9s} {'topk_us':>9s} {'split_us':>9s} " + f"{'fused_us':>9s} {'base_us':>9s} {'thr_bin':>7s} {'thr_size':>8s} {'sel_thr':>7s}" + ) + for cfg in results: + banner = ( + f"\n[batch={cfg['batch_size']} heads={cfg['num_kv_heads']} " + f"seq_len={cfg['seq_len']} topk={cfg['topk_val']} " + f"dist={cfg['distribution']} pages_per_seg={cfg['pages_per_seg']}]" + ) + print(banner) + print(" Baseline: mapping_mode=0 (raw fp16 bucketing)") + print(header) + print("-" * len(header)) + base_us = cfg["baseline_ms"] * 1000.0 + for row in cfg["modes"]: + label = f"{row['mode_name']}(p={row['power']})" if row["mode"] != 0 else "None" + print( + f"{label:<12s} " + f"{row['remap_ms'] * 1000.0:9.2f} " + f"{row['topk_after_remap_ms'] * 1000.0:9.2f} " + f"{row['split_total_ms'] * 1000.0:9.2f} " + f"{row['fused_ms'] * 1000.0:9.2f} " + f"{base_us:9.2f} " + f"{row['threshold_bin_mean']:7.1f} " + f"{row['threshold_bin_size_mean']:8.1f} " + f"{row['selected_from_thr_mean']:7.1f}" + ) - Priority: per-mode CLI flag 
> autotune JSON > global --mapping-power. - """ - per_mode_flag = {3: args.mapping_hparam_3, 6: args.mapping_hparam_6, 7: args.mapping_hparam_7, - 9: getattr(args, 'mapping_hparam_9', None), 10: getattr(args, 'mapping_hparam_10', None), - 13: getattr(args, 'mapping_hparam_13', None), 14: getattr(args, 'mapping_hparam_14', None)} - if mode in per_mode_flag and per_mode_flag[mode] is not None: - return per_mode_flag[mode] - if hasattr(args, "_autotune_powers") and mode in args._autotune_powers: - return args._autotune_powers[mode] - return args.mapping_hparam +def _run_remap_bench(args) -> None: + modes = [int(m) for m in args.mapping_modes] + if 0 not in modes: + modes = [0] + modes -def _run_subphase_profiling(subphase_modes, inputs, eff_bs, topk_val, - pages_per_seg, args, mapping_lut, mapping_quantiles): - """Run sub-phase profiling (histogram_only + stage1_full) for each mode. + results = [] + for bs in args.batch_sizes: + for heads in args.num_kv_heads: + for seq_len in args.seq_lens: + for topk_val in args.topk_vals: + for dist in args.distributions: + cfg = _remap_bench_one_config( + args, bs, heads, seq_len, topk_val, dist, modes, + ) + results.append(cfg) - For topk <= 512, runs inline. For topk > 512, runs each mode in a - separate subprocess to avoid CUDA shared memory exhaustion from - accumulated kernel template registrations. 
- """ - import subprocess, sys, tempfile, os - - for kernel_name, s1_mode, s1_power, s1_noscale, result in subphase_modes: - s1_lut = mapping_lut if s1_mode == 1 else None - s1_q = mapping_quantiles if s1_mode == 2 else None - - if topk_val <= 512: - # Inline: run directly in this process - hist_buf = torch.zeros(eff_bs, 256, dtype=torch.int32, device="cuda") - hist_args = ( - inputs["x"], inputs["dense_kv_indptr"], hist_buf, eff_bs, - args.reserved_bos, args.reserved_eos, - s1_mode, s1_power, s1_lut, s1_q, s1_noscale, - ) - hist_result = bench_kernel(topk_profile_histogram, hist_args, args.warmup, args.repeat) - - inputs["sparse_kv_indices"].zero_() - stage1_args = ( - inputs["x"], inputs["dense_kv_indptr"], inputs["sparse_kv_indptr"], - inputs["dense_kv_indices"], inputs["sparse_kv_indices"], - eff_bs, topk_val, args.reserved_bos, args.reserved_eos, pages_per_seg, - s1_mode, s1_power, s1_lut, s1_q, s1_noscale, - ) - stage1_result = bench_kernel(topk_profile_stage1, stage1_args, args.warmup, args.repeat) - else: - # Subprocess: fresh CUDA context per mode to avoid shared memory exhaustion - script = f""" -import torch, json, sys -sys.path.insert(0, '{os.path.dirname(os.path.abspath(__file__))}') -from vortex_torch_C import topk_profile_histogram, topk_profile_stage1 -from bench_topk import make_topk_inputs, bench_kernel - -inputs = make_topk_inputs( - batch_size={inputs['x'].shape[0] // (eff_bs // (inputs['x'].shape[0] if eff_bs == inputs['x'].shape[0] else 1)) if False else 1}, - num_kv_heads=1, seq_len={pages_per_seg * 16}, - page_size=16, topk_val={topk_val}, - reserved_bos={args.reserved_bos}, reserved_eos={args.reserved_eos}, - score_dtype=torch.bfloat16, distribution="normal", -) -eff_bs = {eff_bs} -# Recreate inputs with correct eff_bs -inputs = make_topk_inputs( - batch_size={eff_bs // max(1, eff_bs // pages_per_seg) if False else eff_bs}, - num_kv_heads=1, seq_len={pages_per_seg * 16}, - page_size=16, topk_val={topk_val}, - 
reserved_bos={args.reserved_bos}, reserved_eos={args.reserved_eos}, - score_dtype=torch.bfloat16, distribution="normal", -) -eff_bs = inputs["eff_batch_size"] - -hist_buf = torch.zeros(eff_bs, 256, dtype=torch.int32, device="cuda") -hist_result = bench_kernel(topk_profile_histogram, - (inputs["x"], inputs["dense_kv_indptr"], hist_buf, eff_bs, - {args.reserved_bos}, {args.reserved_eos}, {s1_mode}, {s1_power}, - None, None, {s1_noscale}), 5, {args.repeat}) - -inputs["sparse_kv_indices"].zero_() -stage1_result = bench_kernel(topk_profile_stage1, - (inputs["x"], inputs["dense_kv_indptr"], inputs["sparse_kv_indptr"], - inputs["dense_kv_indices"], inputs["sparse_kv_indices"], - eff_bs, {topk_val}, {args.reserved_bos}, {args.reserved_eos}, - inputs["num_pages_per_seg"], {s1_mode}, {s1_power}, - None, None, {s1_noscale}), 5, {args.repeat}) - -print(json.dumps({{"hist": hist_result, "stage1": stage1_result}})) -""" - try: - proc = subprocess.run( - [sys.executable, "-c", script], - capture_output=True, text=True, timeout=60, - env={**os.environ, "PYTHONPATH": os.path.dirname(os.path.abspath(__file__)) + "/.."}) - if proc.returncode == 0: - data = json.loads(proc.stdout.strip().split("\n")[-1]) - hist_result = data["hist"] - stage1_result = data["stage1"] - else: - # Subprocess failed — skip sub-phase for this mode - continue - except Exception: - continue - - result['histogram_only_mean_ms'] = hist_result['mean_ms'] - result['histogram_only_median_ms'] = hist_result['median_ms'] - result['stage1_full_mean_ms'] = stage1_result['mean_ms'] - result['stage1_full_median_ms'] = stage1_result['median_ms'] - result['route_overhead_mean_ms'] = stage1_result['mean_ms'] - hist_result['mean_ms'] - result['route_overhead_median_ms'] = stage1_result['median_ms'] - hist_result['median_ms'] - result['stage2_refine_mean_ms'] = result['mean_ms'] - stage1_result['mean_ms'] - result['stage2_refine_median_ms'] = result['median_ms'] - stage1_result['median_ms'] - - -def run_benchmark(args) -> 
List[dict]: - """Run the full benchmark sweep and return results.""" - # Load autotune results if provided - if args.autotune_json: - args._autotune_powers = _load_autotune_powers(args.autotune_json) - print(f"Loaded autotune best powers: {args._autotune_powers}") - else: - args._autotune_powers = {} - - dtype_map = {"bfloat16": torch.bfloat16, "float32": torch.float32} - score_dtype = dtype_map[args.score_dtype] - - # Load real histogram if provided - real_histogram = None - _scores_from_histogram = None - if args.real_histograms: - from autotune_topk_mapping import scores_from_histogram - _scores_from_histogram = scores_from_histogram - raw = np.load(args.real_histograms) - real_histogram = raw.sum(axis=0) if raw.ndim > 1 else raw - - # Extend distributions with "real" if calibration data is provided - distributions = list(args.distributions) - if real_histogram is not None: - distributions.append("real") - args.distributions = distributions - - # Print GPU info - gpu_name = torch.cuda.get_device_name(0) - gpu_props = torch.cuda.get_device_properties(0) - print(f"TopK Kernel Benchmark Results") - print(f"GPU: {gpu_name} | SM count: {gpu_props.multi_processor_count}") - print(f"Score dtype: {args.score_dtype} | Warmup: {args.warmup} | Repeat: {args.repeat}") - print(f"Radix bits: {args.radix_bits} ({1 << args.radix_bits} bins) | Sample stride: {args.sample_stride}") - print("=" * 90) - - # Load optional LUT / quantiles - mapping_lut = None - mapping_quantiles = None - if args.lut_path: - lut_np = np.load(args.lut_path).astype(np.uint8) - mapping_lut = torch.from_numpy(lut_np).cuda() - if args.quantiles_path: - q_np = np.load(args.quantiles_path).astype(np.float32) - mapping_quantiles = torch.from_numpy(q_np).cuda() - - # Build kernel list - all_kernels = { - "naive": "naive", - "sglang_ori": "sglang_ori", - "sglang_m0": "sglang_m0", - "sglang_scale": "sglang_scale", # mode 3 with p=1.0 (identity + linear auto-range scaling) - "sglang_m3": "sglang_m3", - 
"sglang_m3_noscale": "sglang_m3_noscale", - "sglang_m4": "sglang_m4", - "sglang_m6": "sglang_m6", - "sglang_m6_noscale": "sglang_m6_noscale", - "sglang_m7": "sglang_m7", - "sglang_m7_noscale": "sglang_m7_noscale", - "sglang_m8": "sglang_m8", - "sglang_m9": "sglang_m9", - "sglang_m9_noscale": "sglang_m9_noscale", - "sglang_m10": "sglang_m10", - "sglang_m10_noscale": "sglang_m10_noscale", - "sglang_m11": "sglang_m11", - "sglang_m13": "sglang_m13", - "sglang_m13_noscale": "sglang_m13_noscale", - "sglang_m14": "sglang_m14", - } - if mapping_lut is not None: - all_kernels["sglang_m1"] = "sglang_m1" - if mapping_quantiles is not None: - all_kernels["sglang_m2"] = "sglang_m2" - - if args.filter_kernels: - # Validate: if the user explicitly requested sglang_m1 or sglang_m2 but - # the required calibration file was not provided, fail loudly instead of - # silently skipping these modes. - if "sglang_m1" in args.filter_kernels and "sglang_m1" not in all_kernels: - raise RuntimeError( - "sglang_m1 (LUT CDF) was requested in --filter-kernels but no " - "--lut-path was provided. Mode 1 requires a calibrated LUT file " - "(lut.npy from calibrate_topk.py). Either supply --lut-path or " - "remove sglang_m1 from --filter-kernels." - ) - if "sglang_m2" in args.filter_kernels and "sglang_m2" not in all_kernels: - raise RuntimeError( - "sglang_m2 (Quantile) was requested in --filter-kernels but no " - "--quantiles-path was provided. Mode 2 requires a calibrated " - "quantiles file (quantiles.npy from calibrate_topk.py). Either " - "supply --quantiles-path or remove sglang_m2 from --filter-kernels." 
- ) - all_kernels = {k: v for k, v in all_kernels.items() if k in args.filter_kernels} + _print_remap_table(results) - # Naive kernel only supports bf16 - if score_dtype != torch.bfloat16 and "naive" in all_kernels: - print(f"Note: naive kernel only supports bfloat16, skipping for {args.score_dtype}") - del all_kernels["naive"] + if args.output_json: + with open(args.output_json, "w") as f: + json.dump(results, f, indent=2) + print(f"\nResults saved to {args.output_json}") - all_results = [] +def _run_latency_sweep(args) -> None: + """Simple baseline-vs-fused latency sweep (no split-phase, no counters).""" + modes = [int(m) for m in args.mapping_modes] + results = [] for bs in args.batch_sizes: - for seq_len in args.seq_lens: - for topk_val in args.topk_vals: - for num_kv_heads in args.num_kv_heads: + for heads in args.num_kv_heads: + for seq_len in args.seq_lens: + for topk_val in args.topk_vals: for dist in args.distributions: - if dist == "real" and real_histogram is not None: - inputs = make_topk_inputs( - batch_size=bs, - num_kv_heads=num_kv_heads, - seq_len=seq_len, - page_size=args.page_size, - topk_val=topk_val, - reserved_bos=args.reserved_bos, - reserved_eos=args.reserved_eos, - score_dtype=score_dtype, - distribution="normal", - ) - # Replace scores with real-distribution scores - total_dense = inputs["eff_batch_size"] * inputs["num_pages_per_seg"] - inputs["x"] = _scores_from_histogram( - real_histogram, total_dense, device="cuda", - ) - else: - inputs = make_topk_inputs( - batch_size=bs, - num_kv_heads=num_kv_heads, - seq_len=seq_len, - page_size=args.page_size, - topk_val=topk_val, - reserved_bos=args.reserved_bos, - reserved_eos=args.reserved_eos, - score_dtype=score_dtype, - distribution=dist, - ) - + inputs = make_topk_inputs( + batch_size=bs, num_kv_heads=heads, seq_len=seq_len, + page_size=args.page_size, topk_val=topk_val, + reserved_bos=args.reserved_bos, reserved_eos=args.reserved_eos, + score_dtype=torch.bfloat16, distribution=dist, + ) 
eff_bs = inputs["eff_batch_size"] pages_per_seg = inputs["num_pages_per_seg"] - - config_str = ( - f"bs={bs} | seq={seq_len} | topk={topk_val} | " - f"heads={num_kv_heads} | pages/seg={pages_per_seg} | dist={dist}" - ) - print(f"\n{config_str}") - - config_results = { - "batch_size": bs, - "seq_len": seq_len, - "topk_val": topk_val, - "num_kv_heads": num_kv_heads, - "distribution": dist, - "eff_batch_size": eff_bs, - "pages_per_seg": pages_per_seg, - "kernels": {}, - } - - # Collect all kernel results first, then print sorted by latency - kernel_entries = [] # [(label, kernel_name, result)] - - for kernel_name in all_kernels: - # Reset sparse indices each run + row_modes = [] + for mode in modes: + power = _resolve_hparam(args, mode) inputs["sparse_kv_indices"].zero_() - - if kernel_name == "naive": - call_args = ( - inputs["x"], - inputs["dense_kv_indptr"], - inputs["dense_kv_indices"], - inputs["sparse_kv_indptr"], - inputs["sparse_kv_indices"], - eff_bs, - topk_val, - args.reserved_bos, - args.reserved_eos, - pages_per_seg, - ) - result = bench_kernel(topk_output, call_args, args.warmup, args.repeat) - elif kernel_name == "sglang_ori": + if mode == 0: + call = topk_output_sglang call_args = ( - inputs["x"], - inputs["dense_kv_indptr"], - inputs["sparse_kv_indptr"], - inputs["dense_kv_indices"], + inputs["x"], inputs["dense_kv_indptr"], + inputs["sparse_kv_indptr"], inputs["dense_kv_indices"], inputs["sparse_kv_indices"], - eff_bs, - topk_val, - args.reserved_bos, - args.reserved_eos, - pages_per_seg, - args.radix_bits, + eff_bs, topk_val, + args.reserved_bos, args.reserved_eos, pages_per_seg, ) - result = bench_kernel(topk_output_sglang_ori, call_args, args.warmup, args.repeat) - elif kernel_name == "sglang_scale": - call_args = ( - inputs["x"], - inputs["dense_kv_indptr"], - inputs["sparse_kv_indptr"], - inputs["dense_kv_indices"], - inputs["sparse_kv_indices"], - eff_bs, - topk_val, - args.reserved_bos, - args.reserved_eos, - pages_per_seg, - 3, # mode 3 
(power) - 1.0, # p=1.0 → identity - None, - None, - False, - args.sample_stride, - args.radix_bits, - ) - result = bench_kernel(topk_output_sglang, call_args, args.warmup, args.repeat) else: - mode_str = kernel_name.split("_m")[1] - mode = int(mode_str.split("_")[0]) - is_noscale = kernel_name.endswith("_noscale") - extra_kwargs = {} - if mode == 1: - extra_kwargs["mapping_lut"] = mapping_lut - elif mode == 2: - extra_kwargs["mapping_quantiles"] = mapping_quantiles - - if mode in (3, 6, 7, 9, 10, 13, 14): - power = _resolve_mode_hparam(args, mode) - else: - power = 0.5 - + call = topk_output_sglang_fused + lut_t = getattr(args, "_mapping_lut", None) if mode == 1 else None + q_t = getattr(args, "_mapping_quantiles", None) if mode == 2 else None call_args = ( - inputs["x"], - inputs["dense_kv_indptr"], - inputs["sparse_kv_indptr"], - inputs["dense_kv_indices"], - inputs["sparse_kv_indices"], - eff_bs, - topk_val, - args.reserved_bos, - args.reserved_eos, - pages_per_seg, - mode, - power, - extra_kwargs.get("mapping_lut", None), - extra_kwargs.get("mapping_quantiles", None), - is_noscale, - args.sample_stride, - args.radix_bits, - ) - result = bench_kernel(topk_output_sglang, call_args, args.warmup, args.repeat) - - # Build label - if kernel_name == "naive": - label = "naive" - elif kernel_name == "sglang_ori": - label = "sglang Ori (no remap)" - elif kernel_name == "sglang_scale": - label = "sglang Scale Only (p=1.0)" - else: - m_str = kernel_name.split("_m")[1] - m = int(m_str.split("_")[0]) - noscale_suffix = " noscale" if kernel_name.endswith("_noscale") else "" - mname = MAPPING_MODE_NAMES.get(m, f'm{m}') - if m in (3, 6, 7, 9, 10, 13, 14): - pname = {3: "p", 6: "beta", 7: "alpha", 9: "alpha", 10: "alpha", 13: "alpha", 14: "rho"}[m] - label = f"sglang {mname} ({pname}={_resolve_mode_hparam(args, m)}){noscale_suffix}" - else: - label = f"sglang {mname}{noscale_suffix}" - - # Counter collection (runs separately from sub-phase profiling) - if kernel_name not in 
("naive",) and args.counters: - if kernel_name in ("sglang_ori",): - c_mode, c_power, c_lut, c_q, c_noscale = 0, 0.5, None, None, False - elif kernel_name == "sglang_scale": - c_mode, c_power, c_lut, c_q, c_noscale = 3, 1.0, None, None, False - else: - c_mode_str = kernel_name.split("_m")[1] - c_mode = int(c_mode_str.split("_")[0]) - c_noscale = kernel_name.endswith("_noscale") - c_power = _resolve_mode_hparam(args, c_mode) if c_mode in (3,6,7,9,10,13,14) else 0.5 - c_lut = mapping_lut if c_mode == 1 else None - c_q = mapping_quantiles if c_mode == 2 else None - inputs["sparse_kv_indices"].zero_() - counter_buf = torch.zeros(eff_bs, 6, dtype=torch.int32, device="cuda") - counter_args = ( - inputs["x"], - inputs["dense_kv_indptr"], - inputs["sparse_kv_indptr"], - inputs["dense_kv_indices"], + inputs["x"], inputs["dense_kv_indptr"], + inputs["sparse_kv_indptr"], inputs["dense_kv_indices"], inputs["sparse_kv_indices"], - counter_buf, - eff_bs, - topk_val, - args.reserved_bos, - args.reserved_eos, - pages_per_seg, - c_mode, - c_power, - c_lut, - c_q, - c_noscale, + eff_bs, topk_val, + args.reserved_bos, args.reserved_eos, pages_per_seg, + mode, power, lut_t, q_t, ) - topk_profile_counters(*counter_args) - torch.cuda.synchronize() - c = counter_buf.float() - result['counters'] = { - 'threshold_bin_mean': c[:, 0].mean().item(), - 'num_above_mean': c[:, 1].mean().item(), - 'num_equal_mean': c[:, 2].mean().item(), - 'remaining_k_mean': c[:, 3].mean().item(), - 'refine_rounds_mean': c[:, 4].mean().item(), - 'stage2_input_mean': c[:, 5].mean().item(), - 'threshold_bin_max': c[:, 0].max().item(), - 'num_above_max': c[:, 1].max().item(), - 'num_equal_max': c[:, 2].max().item(), - 'remaining_k_max': c[:, 3].max().item(), - 'refine_rounds_max': c[:, 4].max().item(), - 'stage2_input_max': c[:, 5].max().item(), - } - - kernel_entries.append((label, kernel_name, result)) - config_results["kernels"][kernel_name] = result - - # Second pass: sub-phase profiling (histogram_only + 
stage1_full) - # Run in a subprocess to get a fresh CUDA context, avoiding - # shared memory exhaustion from accumulated kernel registrations. - subphase_modes = [] - for label, kernel_name, result in kernel_entries: - if kernel_name in ("naive", "sglang_ori"): - continue - if kernel_name == "sglang_scale": - s1_mode, s1_power, s1_noscale = 3, 1.0, False - else: - s1_mode_str = kernel_name.split("_m")[1] - s1_mode = int(s1_mode_str.split("_")[0]) - s1_noscale = kernel_name.endswith("_noscale") - s1_power = _resolve_mode_hparam(args, s1_mode) if s1_mode in (3,6,7,9,10,13,14) else 0.5 - subphase_modes.append((kernel_name, s1_mode, s1_power, s1_noscale, result)) - - if subphase_modes: - _run_subphase_profiling( - subphase_modes, inputs, eff_bs, topk_val, - pages_per_seg, args, mapping_lut, mapping_quantiles) - - # Print kernel results sorted by mean latency (ascending) - kernel_entries.sort(key=lambda e: e[2]['mean_ms']) - print(f" --- kernel latency (sorted by mean, ascending) ---") - for label, kernel_name, result in kernel_entries: + stats = bench_kernel(call, call_args, args.warmup, args.repeat) + row_modes.append({ + "mode": mode, "mode_name": MAPPING_MODE_NAMES.get(mode, f"m{mode}"), + "power": power, "mean_ms": stats["mean_ms"], + "median_ms": stats["median_ms"], + }) print( - f" {label:<40s}: " - f"mean={result['mean_ms']:.4f}ms " - f"median={result['median_ms']:.4f}ms " - f"\u00b1 {result['std_ms']:.4f}ms " - f"[min={result['min_ms']:.4f}, max={result['max_ms']:.4f}]" + f"bs={bs} h={heads} seq={seq_len} topk={topk_val} " + f"dist={dist} mode={mode:>2d} lat={stats['mean_ms']:.4f} ms" ) - if 'stage1_full_mean_ms' in result: - print( - f" {'Histogram only (map+hist)':<36s}: " - f"mean={result['histogram_only_mean_ms']:.4f}ms " - f"median={result['histogram_only_median_ms']:.4f}ms" - ) - print( - f" {'Stage1 full (hist+cumsum+route)':<36s}: " - f"mean={result['stage1_full_mean_ms']:.4f}ms " - f"median={result['stage1_full_median_ms']:.4f}ms" - ) - print( - f" 
{'Route overhead (cumsum+route)':<36s}: " - f"mean={result['route_overhead_mean_ms']:.4f}ms " - f"median={result['route_overhead_median_ms']:.4f}ms" - ) - print( - f" {'Stage2 (refine)':<36s}: " - f"mean={result['stage2_refine_mean_ms']:.4f}ms " - f"median={result['stage2_refine_median_ms']:.4f}ms" - ) - if 'counters' in result: - c = result['counters'] - print( - f" Counters: threshold_bin={c['threshold_bin_mean']:.0f} " - f"above={c['num_above_mean']:.0f} " - f"equal={c['num_equal_mean']:.0f} " - f"remaining_k={c['remaining_k_mean']:.0f} " - f"refine_rounds={c['refine_rounds_mean']:.1f} " - f"stage2_input={c['stage2_input_mean']:.0f}" - ) - - # Histogram analysis — uses the SAME inputs as the main benchmark - # so histogram CSV and counters reflect the same data. - if args.histogram: - hist_inputs = inputs - hist_eff_bs = eff_bs - current_pages = eff_bs * pages_per_seg - print(f" histogram dataset : {current_pages} pages (same as benchmark)") - - # Raw unmapped histogram - histograms = torch.zeros(hist_eff_bs, 256, dtype=torch.int32, device="cuda") - topk_profile_histogram( - hist_inputs["x"], - hist_inputs["dense_kv_indptr"], - histograms, - hist_eff_bs, - args.reserved_bos, - args.reserved_eos, - ) - hstats = compute_histogram_stats(histograms) - hstats["raw_counts"] = histograms.sum(dim=0).tolist() # [256] ints - config_results["histogram"] = hstats - print( - f" histogram stats : max/mean={hstats['max_mean_ratio']:.2f} " - f"gini={hstats['gini']:.3f} " - f"nonzero_bins={hstats['num_nonzero_bins']}/256" - ) - - # Collect all histogram entries, then print sorted by gini - # Each entry: (display_name, key, mode_stats) - hist_entries = [] - histograms_results = {} - - # Per-mode histogram analysis (scaled) - modes_to_test = [0, 3, 4, 6, 7, 8, 9, 10, 11, 13, 14] - if mapping_lut is not None: - modes_to_test.append(1) - if mapping_quantiles is not None: - modes_to_test.append(2) - modes_to_test.sort() - - for mode in modes_to_test: - mode_hists = 
torch.zeros(hist_eff_bs, 256, dtype=torch.int32, device="cuda") - - extra_lut = mapping_lut if mode == 1 else None - extra_q = mapping_quantiles if mode == 2 else None - power = _resolve_mode_hparam(args, mode) if mode in (3, 6, 7, 9, 10, 13, 14) else 0.5 - - topk_profile_histogram( - hist_inputs["x"], - hist_inputs["dense_kv_indptr"], - mode_hists, - hist_eff_bs, - args.reserved_bos, - args.reserved_eos, - mode, - power, - extra_lut, - extra_q, - False, # mapping_noscale - topk_val, # needed for mode 12/14 (tail/topk window) - ) - torch.cuda.synchronize() - - mode_stats = compute_histogram_stats(mode_hists) - mode_stats["raw_counts"] = mode_hists.sum(dim=0).tolist() - mname = MAPPING_MODE_NAMES.get(mode, f"m{mode}") - mformula = MAPPING_MODE_FORMULAS.get(mode, mname) - mode_stats["name"] = mname - mode_stats["formula"] = mformula - if mode in (3, 6, 7, 9, 10, 13, 14): - pname = {3: "p", 6: "beta", 7: "alpha", 9: "alpha", 10: "alpha", 13: "alpha", 14: "rho"}[mode] - mode_stats["param"] = f"{pname}={power}" - display_name = f"{mname} ({pname}={power})" - else: - display_name = mname - key = f"mode_{mode}_{mname}" - histograms_results[key] = mode_stats - hist_entries.append((display_name, f"mode {mode:2d}", mode_stats)) - - # Noscale histogram analysis for parametric transform modes - noscale_modes = [m for m in (3, 6, 7, 9, 10, 13) if m in modes_to_test] - for mode in noscale_modes: - ns_hists = torch.zeros(hist_eff_bs, 256, dtype=torch.int32, device="cuda") - power = _resolve_mode_hparam(args, mode) - topk_profile_histogram( - hist_inputs["x"], - hist_inputs["dense_kv_indptr"], - ns_hists, - hist_eff_bs, - args.reserved_bos, - args.reserved_eos, - mode, - power, - None, - None, - True, # mapping_noscale=True - ) - torch.cuda.synchronize() - ns_stats = compute_histogram_stats(ns_hists) - ns_stats["raw_counts"] = ns_hists.sum(dim=0).tolist() - mname = MAPPING_MODE_NAMES.get(mode, f"m{mode}") - mformula = MAPPING_MODE_FORMULAS.get(mode, mname) - pname = {3: "p", 6: 
"beta", 7: "alpha", 9: "alpha", 10: "alpha", 13: "alpha"}[mode] - ns_stats["name"] = f"{mname} noscale" - ns_stats["formula"] = mformula - ns_stats["param"] = f"{pname}={power}" - display_name = f"{mname} noscale ({pname}={power})" - key = f"mode_{mode}_{mname}_noscale" - histograms_results[key] = ns_stats - hist_entries.append((display_name, f"m{mode:2d} ns", ns_stats)) - - # Scale Only baseline: mode 3 with p=1.0 (identity + linear scaling) - scale_hists = torch.zeros(hist_eff_bs, 256, dtype=torch.int32, device="cuda") - topk_profile_histogram( - hist_inputs["x"], - hist_inputs["dense_kv_indptr"], - scale_hists, - hist_eff_bs, - args.reserved_bos, - args.reserved_eos, - 3, # mode 3 (power) - 1.0, # p=1.0 → identity transform - None, - None, - ) - torch.cuda.synchronize() - scale_stats = compute_histogram_stats(scale_hists) - scale_stats["raw_counts"] = scale_hists.sum(dim=0).tolist() - scale_stats["name"] = "Scale Only" - scale_stats["formula"] = "Identity + linear scaling to [0,255]" - scale_stats["param"] = "p=1.0" - histograms_results["mode_scale_Scale Only"] = scale_stats - hist_entries.append(("Scale Only (p=1.0)", "scale ", scale_stats)) - - # Print all histogram entries sorted by gini (ascending = more uniform = better) - hist_entries.sort(key=lambda e: e[2]['gini']) - print(f" --- histogram by gini (sorted, lower=better) ---") - for rank, (display_name, mode_tag, stats) in enumerate(hist_entries, 1): - print( - f" {rank:2d}. 
{display_name:<32s} ({mode_tag}): " - f"gini={stats['gini']:.3f} " - f"max/mean={stats['max_mean_ratio']:.2f} " - f"nonzero_bins={stats['num_nonzero_bins']}/256 " - f"eff_bins={stats['effective_bins']:.1f} " - f"entropy={stats['entropy']:.2f}" - ) - - config_results["histograms"] = histograms_results - - all_results.append(config_results) - - return all_results - - -def main(): - parser = argparse.ArgumentParser(description="TopK kernel benchmark suite") - parser.add_argument("--batch-sizes", nargs="+", type=int, default=[1, 4, 8, 16, 32, 64]) - parser.add_argument("--seq-lens", nargs="+", type=int, default=[1024, 2048, 4096, 8192]) - parser.add_argument("--topk-vals", nargs="+", type=int, default=[16, 30, 64]) - parser.add_argument("--num-kv-heads", nargs="+", type=int, default=[2, 4, 8]) - parser.add_argument("--page-size", type=int, default=16) - parser.add_argument("--reserved-bos", type=int, default=1) - parser.add_argument("--reserved-eos", type=int, default=2) - parser.add_argument("--score-dtype", choices=["bfloat16", "float32"], default="bfloat16") - parser.add_argument("--distributions", nargs="+", default=["normal", "lognormal", "uniform"]) - parser.add_argument("--warmup", type=int, default=10) - parser.add_argument("--repeat", type=int, default=100) - parser.add_argument("--mapping-hparam", "--mapping-power", type=float, default=0.5, - dest="mapping_hparam", - help="Global fallback hyperparameter for parametric modes (default: 0.5)") - parser.add_argument("--mapping-hparam-3", "--mapping-power-3", type=float, default=None, - dest="mapping_hparam_3", - help="Power exponent p for mode 3 (overrides --mapping-hparam)") - parser.add_argument("--mapping-hparam-6", "--mapping-power-6", type=float, default=None, - dest="mapping_hparam_6", - help="Beta for mode 6 asinh (overrides --mapping-hparam)") - parser.add_argument("--mapping-hparam-7", "--mapping-power-7", type=float, default=None, - dest="mapping_hparam_7", - help="Alpha for mode 7 log1p (overrides 
--mapping-hparam)") - parser.add_argument("--mapping-hparam-13", "--mapping-power-13", type=float, default=None, - dest="mapping_hparam_13", - help="Alpha for mode 13 exp_stretch (overrides --mapping-hparam)") - parser.add_argument("--mapping-hparam-14", "--mapping-power-14", type=float, default=None, - dest="mapping_hparam_14", - help="Rho for mode 14 topk_window (overrides --mapping-hparam)") - parser.add_argument("--autotune-json", type=str, default=None, - help="Path to autotune_results.json — extracts best per-mode hyperparameters " - "(overrides --mapping-power for modes 3/6/7/13/14)") - parser.add_argument("--lut-path", type=str, default=None, help="Path to .npy uint8[256] LUT for mode=1") - parser.add_argument("--quantiles-path", type=str, default=None, help="Path to .npy float32[256] for mode=2") - parser.add_argument("--output-json", type=str, default=None, help="Save results to JSON file") - parser.add_argument("--filter-kernels", nargs="+", default=None, - help="Only run specific kernels: naive, sglang_m0, sglang_m3, sglang_m4") - parser.add_argument("--histogram", action="store_true", help="Collect and report bin distribution statistics") - parser.add_argument("--histogram-pages", type=int, default=None, - help="Total pages for histogram profiling. Default: adaptive " - "(512 samples/bin × 256 bins, rounded to segment boundary). " - "Only used when --histogram is set.") - parser.add_argument("--real-histograms", type=str, default=None, - help="Path to .npy raw_histograms from calibration (adds 'real' distribution)") - parser.add_argument("--counters", action="store_true", - help="Collect diagnostic counters (threshold_bin, num_above, num_equal, " - "remaining_k, refine_rounds, stage2_input) for each sglang kernel") - parser.add_argument("--sample-stride", type=int, default=1, - help="Pre-pass sampling stride for mapped modes (1=full, 4=1/4, 8=1/8). 
" - "Higher values reduce pre-pass overhead at cost of bin quality (default: 1)") - parser.add_argument("--radix-bits", type=int, default=8, - help="Stage 1 radix bits for ori/mode-0 kernel: 4=16 bins, 6=64, 8=256, 9=512, 10=1024 (default: 8). " - "Range: 4-10. Fewer bits = coarser Stage 1 but faster histogram; more bits = finer but slower.") - - args = parser.parse_args() - results = run_benchmark(args) + results.append({ + "batch_size": bs, "num_kv_heads": heads, "seq_len": seq_len, + "topk_val": topk_val, "distribution": dist, "modes": row_modes, + }) if args.output_json: with open(args.output_json, "w") as f: @@ -945,5 +461,66 @@ def main(): print(f"\nResults saved to {args.output_json}") +def main(): + p = argparse.ArgumentParser("TopK kernel benchmarks") + p.add_argument("--batch-sizes", type=int, nargs="+", default=[4]) + p.add_argument("--num-kv-heads", type=int, nargs="+", default=[8]) + p.add_argument("--seq-lens", type=int, nargs="+", default=[8192]) + p.add_argument("--topk-vals", type=int, nargs="+", default=[30]) + p.add_argument("--distributions", type=str, nargs="+", + default=["normal"], + choices=["normal", "lognormal", "uniform", "bucket_uniform"]) + p.add_argument("--mapping-modes", type=int, nargs="+", + default=[0, 3, 6, 7], + help="Mapping modes to sweep (0=None, 3=Power, 6=Asinh, 7=Log1p, etc.)") + p.add_argument("--mapping-hparam", "--mapping-power", type=float, default=0.5, + dest="mapping_hparam", + help="Fallback hyperparameter for every non-zero mapping mode when " + "no --autotune-json is provided: p for mode 3 (power), beta for " + "mode 6 (asinh), alpha for modes 7/9/10/13 (log1p/erf/tanh/exp_stretch).") + p.add_argument("--autotune-json", type=str, default=None, + help="Path to autotune_results.json produced by autotune_topk_mapping.py. 
" + "When set, the per-mode hyperparameter with the lowest measured " + "latency in that file is used instead of --mapping-hparam.") + p.add_argument("--lut-path", type=str, default=None, + help="Path to .npy uint8[256] LUT for MAPPING_LUT_CDF (mode 1).") + p.add_argument("--quantiles-path", type=str, default=None, + help="Path to .npy float32[256] quantile table for MAPPING_QUANTILE (mode 2).") + p.add_argument("--page-size", type=int, default=16) + p.add_argument("--reserved-bos", type=int, default=1) + p.add_argument("--reserved-eos", type=int, default=2) + p.add_argument("--warmup", type=int, default=10) + p.add_argument("--repeat", type=int, default=100) + p.add_argument("--output-json", type=str, default=None) + p.add_argument("--remap-bench", action="store_true", + help="Run the split-phase remap/topk/fused/baseline benchmark.") + args = p.parse_args() + + args._autotune_hparams = {} + if args.autotune_json: + args._autotune_hparams = _load_autotune_hparams(args.autotune_json) + print(f"[autotune] using best-latency hyperparameters from {args.autotune_json}:") + for m, v in sorted(args._autotune_hparams.items()): + print(f" mode {m:>2d} -> {v}") + + args._mapping_lut = None + args._mapping_quantiles = None + if args.lut_path: + lut_np = np.load(args.lut_path).astype(np.uint8) + assert lut_np.shape == (256,), f"LUT must be [256], got {lut_np.shape}" + args._mapping_lut = torch.from_numpy(lut_np).cuda() + print(f"[mapping] loaded LUT from {args.lut_path}") + if args.quantiles_path: + q_np = np.load(args.quantiles_path).astype(np.float32) + assert q_np.shape == (256,), f"quantiles must be [256], got {q_np.shape}" + args._mapping_quantiles = torch.from_numpy(q_np).cuda() + print(f"[mapping] loaded quantiles from {args.quantiles_path}") + + if args.remap_bench: + _run_remap_bench(args) + else: + _run_latency_sweep(args) + + if __name__ == "__main__": main() diff --git a/benchmarks/calibrate_topk.py b/benchmarks/calibrate_topk.py index 4c86116..e3524c1 100644 --- 
a/benchmarks/calibrate_topk.py +++ b/benchmarks/calibrate_topk.py @@ -44,6 +44,14 @@ def main(): help="Number of calibration prompts to use (default: 16)") parser.add_argument("--output-dir", type=str, default="calibration_output/") parser.add_argument("--vortex-module-name", type=str, default="block_sparse_attention") + parser.add_argument( + "--watchdog-timeout", + type=float, + default=None, + metavar="SEC", + help="SGLang scheduler watchdog (seconds). Forward batches must complete within this time. " + "Default: engine default (300). Use 0 to disable when using this repo's SGLang fork.", + ) args = parser.parse_args() # Lazy imports to avoid slow startup when just checking --help @@ -54,7 +62,7 @@ def main(): os.makedirs(args.output_dir, exist_ok=True) print(f"[calibrate] Launching engine with hit-rate profiling enabled...") - llm = sgl.Engine( + engine_kwargs = dict( model_path=args.model_name, disable_cuda_graph=True, page_size=args.page_size, @@ -73,6 +81,9 @@ def main(): vortex_topk_mapping_mode=0, # Use mode 0 during calibration vortex_topk_histogram=True, # Enable histogram collection ) + if args.watchdog_timeout is not None: + engine_kwargs["watchdog_timeout"] = args.watchdog_timeout + llm = sgl.Engine(**engine_kwargs) # Clear any residual histograms in the worker process llm.clear_topk_histograms() diff --git a/benchmarks/greedy_layer_search.py b/benchmarks/greedy_layer_search.py deleted file mode 100644 index 118ac45..0000000 --- a/benchmarks/greedy_layer_search.py +++ /dev/null @@ -1,117 +0,0 @@ -"""Greedy forward-selection of layers whose indexer can be skipped (index cache). - -Usage (from repo root): - cd examples && python ../benchmarks/greedy_layer_search.py \ - --model-name Qwen/Qwen3-1.7B --topk-val 30 --threshold 0.95 \ - --trials 1 --num-layers 28 --mem 0.7 - -The script prints progress to stderr and outputs the final selected layer list -(as a Python list literal) on the **last line of stdout** so callers can parse it. 
-""" - -import argparse -import os -import sys - -# Add examples/ to path so we can import verify_algos -_examples_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "examples") -sys.path.insert(0, _examples_dir) - -from verify_algo import verify_algos # noqa: E402 - - -def _evaluate(shared_layers, args): - """Run verify_algos with the given shared layers and return pass@trials accuracy.""" - summary = verify_algos( - trials=args.trials, - topk_val=args.topk_val, - page_size=args.page_size, - vortex_module_name=args.vortex_module_name, - model_name=args.model_name, - sparse_attention=True, - mem=args.mem, - kv_cache_dtype=args.kv_cache_dtype, - topk_type=args.topk_type, - topk_mapping_mode=0, - topk_mapping_power=args.topk_mapping_power, - index_cache_shared_layers=sorted(shared_layers) if shared_layers else None, - disable_cuda_graph=True, - ) - acc_key = f"pass@{args.trials}" - return summary[acc_key] - - -def greedy_search(args): - # Ensure we're in examples/ so amc23.jsonl relative path works - os.chdir(_examples_dir) - - candidates = list(range(1, args.num_layers)) - - # Baseline: no shared layers - print("Evaluating baseline (no shared layers)...", file=sys.stderr) - baseline_acc = _evaluate([], args) - print(f"Baseline accuracy: {baseline_acc:.4f}", file=sys.stderr) - - threshold = args.threshold - shared_set = [] - - while candidates: - best_layer = None - best_acc = -1.0 - - for layer in candidates: - trial_set = shared_set + [layer] - print(f" Trying shared_set={sorted(trial_set)} ...", file=sys.stderr, end=" ") - acc = _evaluate(trial_set, args) - print(f"acc={acc:.4f}", file=sys.stderr) - - if acc > best_acc: - best_acc = acc - best_layer = layer - - if best_acc >= threshold * baseline_acc: - shared_set.append(best_layer) - candidates.remove(best_layer) - print( - f"Added layer {best_layer} (acc={best_acc:.4f} >= " - f"{threshold * baseline_acc:.4f}). 
Current set: {sorted(shared_set)}", - file=sys.stderr, - ) - else: - print( - f"Stopping: best candidate layer {best_layer} acc={best_acc:.4f} < " - f"{threshold * baseline_acc:.4f}", - file=sys.stderr, - ) - break - - result = sorted(shared_set) - print(f"Final shared layers: {result}", file=sys.stderr) - # Last stdout line: parseable Python list - print(result) - return result - - -def parse_args(): - parser = argparse.ArgumentParser( - description="Greedy forward-selection of index-cache shared layers." - ) - parser.add_argument("--model-name", type=str, default="Qwen/Qwen3-1.7B") - parser.add_argument("--topk-val", type=int, default=30) - parser.add_argument("--page-size", type=int, default=16) - parser.add_argument("--vortex-module-name", type=str, default="block_sparse_attention") - parser.add_argument("--mem", type=float, default=0.8) - parser.add_argument("--kv-cache-dtype", type=str, default="auto") - parser.add_argument("--topk-type", type=str, default="naive") - parser.add_argument("--topk-mapping-power", type=float, default=0.5) - parser.add_argument("--threshold", type=float, default=0.95, - help="Minimum accuracy ratio vs baseline to keep adding layers (default: 0.95).") - parser.add_argument("--trials", type=int, default=1) - parser.add_argument("--num-layers", type=int, default=28, - help="Total number of model layers (default: 28 for Qwen3-1.7B).") - return parser.parse_args() - - -if __name__ == "__main__": - args = parse_args() - greedy_search(args) diff --git a/csrc/archived/README.md b/csrc/archived/README.md new file mode 100644 index 0000000..6e08a1d --- /dev/null +++ b/csrc/archived/README.md @@ -0,0 +1,19 @@ +# Archived TopK kernels + +These files are **not compiled** (not listed in `setup.py`) and are kept only +for historical reference. + +- `topk_slgang_ori.cu` — the original SGLang TopK reference kernel (typo in + the filename is intentional, matches the upstream commit it was adapted + from). 
Superseded by the fused `fast_topk_vortex` path in + `../topk_sglang.cu`. +- `topk_sglang_ori_fastpath.cu` — the `fast_topk_ori` / + `TopKOutput_Ori_Kernel` / `launch_ori_kernel` code extracted out of + `../topk_sglang.cu`. It was the "zero mapping overhead" fast path with + flexible `radix_bits` (4–10). We no longer test it — mode 0 now goes + through the standard fused kernel with `MAPPING_NONE`, which pays no + mapping overhead because `mapped_convert_to_uint8` degenerates to + `convert_to_uint8` in that branch. + +If you need to resurrect any of this, add the `.cu` to `setup.py` and +re-export its entry points from `../register.cc` / `../register.h`. diff --git a/csrc/archived/fast_topk_vortex_prepass.cu b/csrc/archived/fast_topk_vortex_prepass.cu new file mode 100644 index 0000000..5b743f1 --- /dev/null +++ b/csrc/archived/fast_topk_vortex_prepass.cu @@ -0,0 +1,525 @@ +// Archived: not compiled. See csrc/archived/README.md +// +// fast_topk_vortex — the heavy fused remap+topk kernel with auto-range, +// pivot, tail-window, topk-window pre-passes and LUT/quantile support. +// Extracted from csrc/topk_sglang.cu as part of the remap-benchmark refactor. +// Replaced by a lean fast_topk_clean_fused that applies a simple element-wise +// transform (from topk_mapping.cuh apply_transform) in Stage-1 bucketing — +// no pre-pass, no LUT, no auto-range. +// +// References types/constants from its former translation unit (TopKMappingParams, +// needs_*, mapped_convert_to_uint8, kSmem, kThreadsPerBlock, COUNTER_*). This +// file will not compile standalone; kept for history only. 
+ +// ====================================================================== +// Templated version of fast_topk_cuda_tl with mapping support: +// - ScoreT: float or __nv_bfloat16 +// - StopAfterStage1: return after Stage 1 route/filter (for profiling) +// - WriteCounters: write diagnostic counters to global memory + +// - mapping: configurable value-remapping for Stage 1 bin assignment +template +__device__ void fast_topk_vortex( + const ScoreT* __restrict__ input, + int* __restrict__ index, + int row_start, + int length, + int target_k, + const TopKMappingParams& mapping, + int* counters = nullptr) +{ + int topk = target_k; + constexpr auto BLOCK_SIZE = 1024; + constexpr auto RADIX = 256; + constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); + + alignas(128) __shared__ int vh_histogram_buf[2][RADIX + 128]; + alignas(128) __shared__ int vh_counter; + alignas(128) __shared__ int vh_threshold_bin_id; + alignas(128) __shared__ int vh_num_input[2]; + + // Shared memory for mapping LUT / quantiles (loaded once per block) + __shared__ uint8_t s_mapping_lut[256]; + __shared__ float s_mapping_quantiles[256]; + + // Auto-range for transform modes (3/4/6/7) + __shared__ float s_range_min, s_range_inv_range; + + auto& vh_histogram = vh_histogram_buf[0]; + extern __shared__ int vh_input_idx[][SMEM_INPUT_SIZE]; + + const int tx = threadIdx.x; + + // Load mapping tables into shared memory if needed + if (mapping.mode == MAPPING_LUT_CDF && mapping.lut != nullptr) { + if (tx < 256) s_mapping_lut[tx] = mapping.lut[tx]; + __syncthreads(); + } + if (mapping.mode == MAPPING_QUANTILE && mapping.quantiles != nullptr) { + if (tx < 256) s_mapping_quantiles[tx] = mapping.quantiles[tx]; + __syncthreads(); + } + + // Pre-pass: compute per-block min/max of transformed values for linear bucketing. 
+ // sample_stride > 1 reduces pre-pass cost by scanning every Nth element; + // the approximated range may miss extreme outliers but Stage 2 uses raw + // float bits for exact ordering, so correctness is preserved. + if (needs_auto_range(mapping.mode) && !mapping.noscale) { + const int stride = (mapping.sample_stride > 1) ? mapping.sample_stride : 1; + float local_min = __FLT_MAX__, local_max = -__FLT_MAX__; + for (int idx = tx * stride; idx < length; idx += BLOCK_SIZE * stride) { + float val = apply_transform(vortex_to_float(input[idx + row_start]), mapping); + local_min = fminf(local_min, val); + local_max = fmaxf(local_max, val); + } + // Warp-level reduction + for (int offset = 16; offset > 0; offset >>= 1) { + local_min = fminf(local_min, __shfl_xor_sync(0xFFFFFFFF, local_min, offset)); + local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); + } + // Cross-warp reduction via shared memory + __shared__ float s_warp_mins[32], s_warp_maxs[32]; + int warp_id = tx >> 5, lane_id = tx & 31; + if (lane_id == 0) { s_warp_mins[warp_id] = local_min; s_warp_maxs[warp_id] = local_max; } + __syncthreads(); + if (tx < (BLOCK_SIZE >> 5)) { + local_min = s_warp_mins[tx]; local_max = s_warp_maxs[tx]; + for (int offset = 16; offset > 0; offset >>= 1) { + local_min = fminf(local_min, __shfl_xor_sync(0xFFFFFFFF, local_min, offset)); + local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); + } + if (tx == 0) { + s_range_min = local_min; + float range = local_max - local_min; + s_range_inv_range = (range > 0.0f) ? 255.0f / range : 0.0f; + } + } + __syncthreads(); + } else if (needs_pivot(mapping.mode)) { + // Pivot pre-pass: compute mean of all elements, store in s_range_min. + // MAPPING_SUBTRACT uses convert_to_uint8(x - range_min), so centering + // around the mean helps distribute values more evenly across bins. 
+ float local_sum = 0.0f; + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + local_sum += vortex_to_float(input[idx + row_start]); + } + // Warp-level reduction + for (int offset = 16; offset > 0; offset >>= 1) { + local_sum += __shfl_xor_sync(0xFFFFFFFF, local_sum, offset); + } + __shared__ float s_warp_sums[32]; + int warp_id = tx >> 5, lane_id = tx & 31; + if (lane_id == 0) s_warp_sums[warp_id] = local_sum; + __syncthreads(); + if (tx < (BLOCK_SIZE >> 5)) { + local_sum = s_warp_sums[tx]; + for (int offset = 16; offset > 0; offset >>= 1) { + local_sum += __shfl_xor_sync(0xFFFFFFFF, local_sum, offset); + } + if (tx == 0) { + s_range_min = local_sum / float(length); // mean as pivot + s_range_inv_range = 0.0f; + } + } + __syncthreads(); + } else if (needs_tail_window(mapping.mode)) { + // Adaptive tail-window pre-pass: estimate tau_low = Q(1 - rho*k/n) + // and local_max via a sampled quantile estimator. All 256 coarse bins + // are then allocated to [tau_low, local_max]; scores below tau_low + // collapse into bin 0 via linear_map_to_uint8 clamping. 
+ constexpr int MAX_SAMPLES = 1024; + __shared__ float s_samples[MAX_SAMPLES]; + __shared__ int s_sample_count; + + if (tx == 0) s_sample_count = 0; + __syncthreads(); + + // Compute sampling stride so we collect ~MAX_SAMPLES from the segment + const int desired_stride = (length + MAX_SAMPLES - 1) / MAX_SAMPLES; + const int sample_stride = max(desired_stride, 1); + + // Each thread samples elements and finds local_max simultaneously + float local_max = -__FLT_MAX__; + for (int idx = tx * sample_stride; idx < length; idx += BLOCK_SIZE * sample_stride) { + float val = vortex_to_float(input[idx + row_start]); + local_max = fmaxf(local_max, val); + int slot = ::atomicAdd(&s_sample_count, 1); + if (slot < MAX_SAMPLES) { + s_samples[slot] = val; + } + } + + // Reduce local_max across block + for (int offset = 16; offset > 0; offset >>= 1) + local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); + __shared__ float s_warp_maxs_tw[32]; + { + int warp_id = tx >> 5, lane_id = tx & 31; + if (lane_id == 0) s_warp_maxs_tw[warp_id] = local_max; + } + __syncthreads(); + if (tx < (BLOCK_SIZE >> 5)) { + local_max = s_warp_maxs_tw[tx]; + for (int offset = 16; offset > 0; offset >>= 1) + local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); + if (tx == 0) s_warp_maxs_tw[0] = local_max; + } + __syncthreads(); + local_max = s_warp_maxs_tw[0]; + + int nsamp = min(s_sample_count, MAX_SAMPLES); + + // Simple odd-even transposition sort on the sample buffer. + // nsamp <= 1024, and we have 1024 threads, so each thread + // handles one element. O(nsamp) parallel rounds suffice. + __syncthreads(); + if (nsamp >= 2) { + for (int pass = 0; pass < nsamp; ++pass) { + // Even phase: compare (0,1), (2,3), ... + if (tx * 2 + 1 < nsamp) { + int i = tx * 2; + if (s_samples[i] > s_samples[i + 1]) { + float tmp = s_samples[i]; + s_samples[i] = s_samples[i + 1]; + s_samples[i + 1] = tmp; + } + } + __syncthreads(); + // Odd phase: compare (1,2), (3,4), ... 
+ if (tx * 2 + 2 < nsamp) { + int i = tx * 2 + 1; + if (s_samples[i] > s_samples[i + 1]) { + float tmp = s_samples[i]; + s_samples[i] = s_samples[i + 1]; + s_samples[i + 1] = tmp; + } + } + __syncthreads(); + } + } + + // Estimate tau_low = Q(1 - rho * k / n) + if (tx == 0) { + float rho = mapping.power_exp; // reused as tail expansion factor + if (rho <= 0.0f) rho = 4.0f; + int k = (mapping.target_k > 0) ? mapping.target_k : target_k; + float frac = 1.0f - rho * float(k) / float(length); + frac = fmaxf(frac, 0.0f); // clamp: never go below rank 0 + + float tau_low; + if (nsamp < 4 || frac <= 0.0f) { + // Too few samples or the tail covers everything: full range + tau_low = -__FLT_MAX__; + } else { + float fidx = frac * float(nsamp - 1); + int lo = __float2int_rd(fidx); + lo = min(max(lo, 0), nsamp - 2); + float t = fidx - float(lo); + tau_low = s_samples[lo] * (1.0f - t) + s_samples[lo + 1] * t; + } + + // Fallback: if tau_low >= local_max, use full-range linear mapping + if (tau_low >= local_max) { + // Find the actual minimum from sorted samples + tau_low = (nsamp > 0) ? s_samples[0] : local_max; + } + + float range = local_max - tau_low; + s_range_min = tau_low; + s_range_inv_range = (range > 1e-10f) ? 255.0f / range : 0.0f; + } + __syncthreads(); + } else if (needs_topk_window(mapping.mode)) { + // Topk-window pre-pass with streaming variance heuristic. 
+ // tau_low = max - rho * sigma * sqrt(2 * log(n/k)) + float local_max = -__FLT_MAX__; + float local_sum = 0.0f, local_sum_sq = 0.0f; + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + float val = vortex_to_float(input[idx + row_start]); + local_max = fmaxf(local_max, val); + local_sum += val; + local_sum_sq += val * val; + } + for (int offset = 16; offset > 0; offset >>= 1) { + local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); + local_sum += __shfl_xor_sync(0xFFFFFFFF, local_sum, offset); + local_sum_sq += __shfl_xor_sync(0xFFFFFFFF, local_sum_sq, offset); + } + __shared__ float s_warp_maxs_tw2[32], s_warp_sums_tw2[32], s_warp_sq_tw2[32]; + { + int warp_id = tx >> 5, lane_id = tx & 31; + if (lane_id == 0) { + s_warp_maxs_tw2[warp_id] = local_max; + s_warp_sums_tw2[warp_id] = local_sum; + s_warp_sq_tw2[warp_id] = local_sum_sq; + } + } + __syncthreads(); + if (tx < (BLOCK_SIZE >> 5)) { + local_max = s_warp_maxs_tw2[tx]; + local_sum = s_warp_sums_tw2[tx]; + local_sum_sq = s_warp_sq_tw2[tx]; + for (int offset = 16; offset > 0; offset >>= 1) { + local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); + local_sum += __shfl_xor_sync(0xFFFFFFFF, local_sum, offset); + local_sum_sq += __shfl_xor_sync(0xFFFFFFFF, local_sum_sq, offset); + } + if (tx == 0) { + float rho = mapping.power_exp; + if (rho <= 0.0f) rho = 4.0f; + int k = (mapping.target_k > 0) ? mapping.target_k : target_k; + float n = float(length); + float mean = local_sum / n; + float var = local_sum_sq / n - mean * mean; + float sigma = (var > 0.0f) ? sqrtf(var) : 0.0f; + float ratio = n / fmaxf(float(k), 1.0f); + float z = sqrtf(2.0f * __logf(fmaxf(ratio, 1.0f))); + float tau_low = local_max - rho * sigma * z; + if (tau_low >= local_max) tau_low = local_max - 1.0f; + float range = local_max - tau_low; + s_range_min = tau_low; + s_range_inv_range = (range > 1e-10f) ? 
255.0f / range : 0.0f; + } + } + __syncthreads(); + } else { + if (tx == 0) { s_range_min = 0.0f; s_range_inv_range = 0.0f; } + __syncthreads(); + } + + // Stage 1: 8-bit coarse histogram (with optional mapping) + // Bin cache: store computed bins in vh_input_idx[1] (reinterpreted as uint8_t*) + // to avoid recomputing mapped_convert_to_uint8 in the route/filter pass. + // vh_input_idx[1] is unused until Stage 2 double-buffering starts after route. + constexpr int BIN_CACHE_CAPACITY = SMEM_INPUT_SIZE * static_cast(sizeof(int)); // uint8 entries + uint8_t* bin_cache = reinterpret_cast(vh_input_idx[1]); + const bool use_bin_cache = (length <= BIN_CACHE_CAPACITY); + + if (tx < RADIX + 1) vh_histogram[tx] = 0; + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = mapped_convert_to_uint8( + vortex_to_float(input[idx + row_start]), + mapping, s_mapping_lut, s_mapping_quantiles, + s_range_min, s_range_inv_range); + ::atomicAdd(&vh_histogram[bin], 1); + if (use_bin_cache) { + bin_cache[idx] = bin; + } + } + __syncthreads(); + + const auto run_cumsum = [&] { +#pragma unroll 8 + for (int i = 0; i < 8; ++i) { + static_assert(1 << 8 == RADIX); + if (C10_LIKELY(tx < RADIX)) { + const auto j = 1 << i; + const auto k = i & 1; + auto value = vh_histogram_buf[k][tx]; + if (tx < RADIX - j) { + value += vh_histogram_buf[k][tx + j]; + } + vh_histogram_buf[k ^ 1][tx] = value; + } + __syncthreads(); + } + }; + + run_cumsum(); + if (tx < RADIX && vh_histogram[tx] > topk && vh_histogram[tx + 1] <= topk) { + vh_threshold_bin_id = tx; + vh_num_input[0] = 0; + vh_counter = 0; + } + __syncthreads(); + + const auto threshold_bin = vh_threshold_bin_id; + topk -= vh_histogram[threshold_bin + 1]; + + if (WriteCounters && tx == 0 && counters) { + counters[COUNTER_THRESHOLD_BIN] = threshold_bin; + counters[COUNTER_REMAINING_K] = topk; + } + + if (topk == 0) { + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + int bin; + if (use_bin_cache) { + bin = 
static_cast(bin_cache[idx]); + } else { + bin = static_cast( + mapped_convert_to_uint8( + vortex_to_float(input[idx + row_start]), + mapping, s_mapping_lut, s_mapping_quantiles, + s_range_min, s_range_inv_range)); + } + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&vh_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + if (WriteCounters && tx == 0 && counters) { + counters[COUNTER_NUM_ABOVE] = vh_counter; + counters[COUNTER_NUM_EQUAL] = 0; + counters[COUNTER_REFINE_ROUNDS] = 0; + counters[COUNTER_STAGE2_INPUT] = 0; + } + return; + } else { + __syncthreads(); + if (tx < RADIX + 1) vh_histogram[tx] = 0; + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto raw_input = vortex_to_float(input[idx + row_start]); + int bin; + if (use_bin_cache) { + bin = static_cast(bin_cache[idx]); + } else { + bin = static_cast( + mapped_convert_to_uint8(raw_input, mapping, + s_mapping_lut, s_mapping_quantiles, + s_range_min, s_range_inv_range)); + } + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&vh_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + const auto pos = ::atomicAdd(&vh_num_input[0], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + vh_input_idx[0][pos] = idx; + const auto b32 = convert_to_uint32(raw_input); + const auto sub_bin = (b32 >> 24) & 0xFF; + ::atomicAdd(&vh_histogram[sub_bin], 1); + } + } + } + __syncthreads(); + if (WriteCounters && tx == 0 && counters) { + counters[COUNTER_NUM_ABOVE] = vh_counter; + counters[COUNTER_NUM_EQUAL] = vh_num_input[0]; + counters[COUNTER_STAGE2_INPUT] = vh_num_input[0]; + } + if (StopAfterStage1) return; + } + + // Stage 2: refine with 8-bit radix passes (unchanged — uses raw float bits) + if constexpr (WriteCounters) { + // Default: all 4 rounds used; overwritten at break if resolved early + if (tx == 0 && counters) counters[COUNTER_REFINE_ROUNDS] = 4; + } +#pragma unroll 4 + for (int round = 0; round < 4; ++round) { + __shared__ int 
vh_last_remain; + const auto r_idx = round % 2; + + const auto _raw_num_input = vh_num_input[r_idx]; + const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) + ? _raw_num_input + : int(SMEM_INPUT_SIZE); + + run_cumsum(); + if (tx < RADIX && vh_histogram[tx] > topk && vh_histogram[tx + 1] <= topk) { + vh_threshold_bin_id = tx; + vh_num_input[r_idx ^ 1] = 0; + vh_last_remain = topk - vh_histogram[tx + 1]; + } + __syncthreads(); + + const auto threshold_bin = vh_threshold_bin_id; + topk -= vh_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = vh_input_idx[r_idx][i]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32( + vortex_to_float(input[idx + row_start])) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&vh_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + if constexpr (WriteCounters) { + if (tx == 0 && counters) { + counters[COUNTER_REFINE_ROUNDS] = round + 1; + } + } + break; + } else { + __syncthreads(); + if (tx < RADIX + 1) vh_histogram[tx] = 0; + __syncthreads(); + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = vh_input_idx[r_idx][i]; + const auto raw_input = vortex_to_float(input[idx + row_start]); + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&vh_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + if (round == 3) { + const auto pos = ::atomicAdd(&vh_last_remain, -1); + if (pos > 0) { + index[target_k - pos] = idx; + } + } else { + const auto pos = ::atomicAdd(&vh_num_input[r_idx ^ 1], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + vh_input_idx[r_idx ^ 1][pos] = idx; + const auto b32 = convert_to_uint32(raw_input); + const auto sub_bin = (b32 >> (offset - 8)) & 0xFF; + ::atomicAdd(&vh_histogram[sub_bin], 1); + } + } + } + } + __syncthreads(); + } + } 
+} + +// Wrapper kernel: one CUDA block per batch*head segment +template +__global__ __launch_bounds__(kThreadsPerBlock) +void TopKOutput_Kernel( + const ScoreT* __restrict__ score, + const int* __restrict__ dense_kv_indptr, + const int* __restrict__ sparse_kv_indptr, + const int* __restrict__ dense_kv_indices, + int* __restrict__ sparse_kv_indices, + const int topk_val, + const int page_reserved_bos, + const int page_reserved_eos, + const TopKMappingParams mapping) +{ + const int bx = blockIdx.x; + + const int start = dense_kv_indptr[bx] + page_reserved_bos; + const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; + const int nblk = end - start; + if (nblk <= topk_val) return; + + const ScoreT* __restrict__ score_blk = score + start; + const int* __restrict__ idx_blk = dense_kv_indices + start; + int* __restrict__ out_blk = sparse_kv_indices + + sparse_kv_indptr[bx] + + page_reserved_bos; + + __shared__ int s_indices[VORTEX_MAX_TOPK]; + fast_topk_vortex(score_blk, s_indices, 0, nblk, topk_val, mapping); + __syncthreads(); + + // Remap position indices -> page indices via dense_kv_indices + const int tx = threadIdx.x; + for (int i = tx; i < topk_val; i += kThreadsPerBlock) { + out_blk[i] = idx_blk[s_indices[i]]; + } +} + + diff --git a/csrc/archived/topk_mapping_full.cuh b/csrc/archived/topk_mapping_full.cuh new file mode 100644 index 0000000..f85204e --- /dev/null +++ b/csrc/archived/topk_mapping_full.cuh @@ -0,0 +1,217 @@ +// Archived: not included by any compiled TU. See csrc/archived/README.md. +// The full mapping header supporting LUT_CDF, QUANTILE, TRUNC8, SUBTRACT, +// ADAPTIVE_TAIL_WINDOW, TOPK_WINDOW and the auto-range/pivot/tail-window +// pre-pass infrastructure. Replaced by the lean element-wise-only header +// at csrc/topk_mapping.cuh for the remap-benchmark refactor. 
+#pragma once +#include +#include +#include + +// ============================================================ +// TopK bucket-sort Stage-1 remapping strategies +// +// These transforms remap float scores before Stage 1's 8-bit +// histogram binning. The primary goal is to maximize coarse-bin +// resolution in the score region that determines the top-k +// cutoff, thereby: +// - shrinking the Stage-1 threshold bin (fewer collisions) +// - reducing COUNTER_NUM_EQUAL / COUNTER_STAGE2_INPUT +// - reducing the number of Stage-2 refine rounds +// +// Stage 2 refinement still uses convert_to_uint32() on raw +// floats, so final ordering correctness is always preserved. +// +// Modes 3/4/6/7/9/10 apply a nonlinear transform then linearly +// map the result to [0,255]. Mode 12 (ADAPTIVE_TAIL_WINDOW) +// directly focuses all 256 bins on the competitive upper tail +// estimated from the top-k ratio, collapsing irrelevant +// low-score mass into bin 0. +// ============================================================ + +enum TopKMappingMode { + MAPPING_NONE = 0, // Original convert_to_uint8 behavior + MAPPING_LUT_CDF = 1, // LUT-based CDF equalization + MAPPING_QUANTILE = 2, // Piecewise-linear quantile mapping + MAPPING_POWER = 3, // Monotonic power transform + MAPPING_LOG = 4, // Log transform + // Mode 5 reserved (previously INDEX_CACHE, removed) + MAPPING_ASINH = 6, // asinh(beta * x), beta via power_exp + MAPPING_LOG1P = 7, // sign(x) * log1p(alpha * |x|), alpha via power_exp + MAPPING_TRUNC8 = 8, // BF16 upper-8-bit bucketing + MAPPING_ERF = 9, // erf(alpha * x) + MAPPING_TANH = 10, // tanh(alpha * x) + MAPPING_SUBTRACT = 11, // subtract pivot, then fp16 bucketing + MAPPING_ADAPTIVE_TAIL_WINDOW = 12, // focus bins on upper tail via sampled quantile + MAPPING_EXP_STRETCH = 13, // exp(alpha * x), concentrates bin resolution on upper tail + MAPPING_TOPK_WINDOW = 14, // k-aware linear windowing: focus bins on [tau_low, max] +}; + +struct TopKMappingParams { + int mode; // 
TopKMappingMode + float power_exp; // For MAPPING_POWER (default 0.5) + // For MAPPING_ADAPTIVE_TAIL_WINDOW: tail expansion + // factor rho (default 4.0). tau_low = Q(1 - rho*k/n). + const uint8_t* __restrict__ lut; // [256] byte LUT, or nullptr + const float* __restrict__ quantiles; // [256] float quantile breakpoints, or nullptr + bool noscale; // Skip auto-range linear scaling, use fp16 bucketing on f(x) + int sample_stride; // Pre-pass sampling stride (1=full, 8=1/8, 0=skip) + int target_k; // Top-k value; used by MAPPING_ADAPTIVE_TAIL_WINDOW +}; + +// NOTE: convert_to_uint8() must be defined before including this header. +// It is defined in topk_sglang.cu within the anonymous namespace. + +// ---- Individual transform functions (return float, no bucketing) ---- + +__device__ __forceinline__ float transform_power(float x, float p) { + return copysignf(__powf(fabsf(x), p), x); +} + +__device__ __forceinline__ float transform_log(float x) { + return copysignf(__logf(fabsf(x) + 1.0f), x); +} + +__device__ __forceinline__ float transform_asinh(float x, float beta) { + return asinhf(beta * x); +} + +__device__ __forceinline__ float transform_log1p(float x, float alpha) { + return copysignf(log1pf(alpha * fabsf(x)), x); +} + +__device__ __forceinline__ float transform_erf(float x, float alpha) { + return erff(alpha * x); +} + +__device__ __forceinline__ float transform_tanh(float x, float alpha) { + return tanhf(alpha * x); +} + +__device__ __forceinline__ float transform_exp_stretch(float x, float alpha) { + float z = alpha * x; + z = fminf(z, 80.0f); // prevent float32 overflow (exp(80) ~ 5.5e34) + return expf(z); +} + +// ---- Transform dispatcher (returns float, no bucketing) ---- + +__device__ __forceinline__ float apply_transform(float x, const TopKMappingParams& params) { + switch (params.mode) { + case MAPPING_POWER: return transform_power(x, params.power_exp); + case MAPPING_LOG: return transform_log(x); + case MAPPING_ASINH: return transform_asinh(x, 
params.power_exp); + case MAPPING_LOG1P: return transform_log1p(x, params.power_exp); + case MAPPING_ERF: return transform_erf(x, params.power_exp); + case MAPPING_TANH: return transform_tanh(x, params.power_exp); + case MAPPING_EXP_STRETCH: return transform_exp_stretch(x, params.power_exp); + default: return x; + } +} + +// ---- Linear bucketing for transform modes ---- + +__device__ __forceinline__ uint8_t linear_map_to_uint8(float val, float range_min, float inv_range) { + int bin = __float2int_rd((val - range_min) * inv_range); + return static_cast(min(max(bin, 0), 255)); +} + +// ---- BF16-aware bucketing (mode 8) ---- +// BF16 has 8 exponent + 7 mantissa bits. Taking the upper 8 bits of the +// sign-flipped bf16 bit-pattern yields only ~20 distinct bins for typical +// data (the byte is almost entirely exponent). Instead, convert through +// fp16 (5 exp + 10 mantissa) which puts 5 exp + 2 mantissa bits in the +// upper byte, giving ~135+ distinct bins — equivalent to mode 0 but +// explicitly available as a named mode for documentation/benchmarking. 
+ +__device__ __forceinline__ uint8_t convert_to_uint8_bf16(float x) { + return convert_to_uint8(x); // fp16 sign-flip bucketing +} + +// ---- Non-transform mapping functions (unchanged) ---- + +// LUT-based CDF equalization: lut[original_bin] -> equalized_bin +__device__ __forceinline__ uint8_t map_lut_cdf(float x, const uint8_t* __restrict__ s_lut) { + return s_lut[convert_to_uint8(x)]; +} + +// Quantile mapping: binary search over 256 sorted thresholds +__device__ __forceinline__ uint8_t map_quantile(float x, const float* __restrict__ s_quantiles) { + // Binary search: find largest index i such that x >= s_quantiles[i] + // s_quantiles is sorted ascending, length 256 + int lo = 0, hi = 255; +#pragma unroll 8 + for (int iter = 0; iter < 8; ++iter) { + int mid = (lo + hi + 1) >> 1; + if (x >= s_quantiles[mid]) { + lo = mid; + } else { + hi = mid - 1; + } + } + return static_cast(lo); +} + +// ---- Unified dispatcher ---- +// For modes 3/4/6/7, range_min and inv_range come from a per-block pre-pass. 
+ +__device__ __forceinline__ uint8_t mapped_convert_to_uint8( + float x, + const TopKMappingParams& params, + const uint8_t* __restrict__ s_lut, + const float* __restrict__ s_quantiles, + float range_min, + float inv_range) +{ + switch (params.mode) { + case MAPPING_LUT_CDF: + if (params.lut != nullptr) return map_lut_cdf(x, s_lut); + return convert_to_uint8(x); // fallback to mode 0 when LUT not calibrated + case MAPPING_QUANTILE: + if (params.quantiles != nullptr) return map_quantile(x, s_quantiles); + return convert_to_uint8(x); // fallback to mode 0 when quantiles not calibrated + case MAPPING_POWER: + case MAPPING_LOG: + case MAPPING_ASINH: + case MAPPING_LOG1P: + case MAPPING_ERF: + case MAPPING_TANH: + case MAPPING_EXP_STRETCH: { + float val = apply_transform(x, params); + if (params.noscale) return convert_to_uint8(val); + return linear_map_to_uint8(val, range_min, inv_range); + } + case MAPPING_TRUNC8: + return convert_to_uint8_bf16(x); + case MAPPING_SUBTRACT: + return convert_to_uint8(x - range_min); // range_min repurposed as pivot + case MAPPING_ADAPTIVE_TAIL_WINDOW: + case MAPPING_TOPK_WINDOW: + return linear_map_to_uint8(x, range_min, inv_range); + default: // MAPPING_NONE + return convert_to_uint8(x); + } +} + +// Helper: check if a mapping mode needs the auto-range pre-pass +__device__ __forceinline__ bool needs_auto_range(int mode) { + return (mode == MAPPING_POWER || mode == MAPPING_LOG || + mode == MAPPING_ASINH || mode == MAPPING_LOG1P || + mode == MAPPING_ERF || mode == MAPPING_TANH || + mode == MAPPING_EXP_STRETCH); +} + +// Helper: check if a mapping mode needs the pivot pre-pass +__device__ __forceinline__ bool needs_pivot(int mode) { + return (mode == MAPPING_SUBTRACT); +} + +// Helper: check if mode is the adaptive tail-window pre-pass +__device__ __forceinline__ bool needs_tail_window(int mode) { + return (mode == MAPPING_ADAPTIVE_TAIL_WINDOW); +} + +// Helper: check if mode is the lightweight topk-window pre-pass +__device__ 
__forceinline__ bool needs_topk_window(int mode) { + return (mode == MAPPING_TOPK_WINDOW); +} diff --git a/csrc/archived/topk_sglang_ori_fastpath.cu b/csrc/archived/topk_sglang_ori_fastpath.cu new file mode 100644 index 0000000..29970ec --- /dev/null +++ b/csrc/archived/topk_sglang_ori_fastpath.cu @@ -0,0 +1,319 @@ +// Archived: not compiled. See csrc/archived/README.md +// +// Flexible-radix (RADIX_BITS 4..10) "ori fast path" for TopK. It was the +// zero-mapping-overhead fast path used when mapping_mode == MAPPING_NONE. +// No longer tested — mode 0 now routes through the fused TopKOutput_Kernel +// with mapping.mode == MAPPING_NONE, which pays no extra cost because +// mapped_convert_to_uint8 collapses to convert_to_uint8 in that branch. +// +// The code below was extracted verbatim from csrc/topk_sglang.cu as of the +// fused-kernel refactor. It references helpers (kSmem, convert_to_uint32, +// vortex_to_float, VORTEX_MAX_TOPK, kThreadsPerBlock, setup_kernel_smem_once, +// CHECK_CUDA, topk_mapping.cuh types) from the surrounding translation unit. +// Dropping this file into a build as-is will not compile; it is reference +// only. + +template +__device__ __forceinline__ uint16_t convert_to_uintN(float x) { + __half h = __float2half_rn(x); + uint16_t bits = __half_as_ushort(h); + uint16_t key = (bits & 0x8000) ? static_cast(~bits) : static_cast(bits | 0x8000); + return key >> (16 - BITS); +} + +// ====================================================================== +// Ori fast path: zero-overhead topk with no mapping infrastructure. +// Template on RADIX_BITS: 4-10 (16 to 1024 bins). 
+// ====================================================================== +template +__device__ void fast_topk_ori( + const ScoreT* __restrict__ input, + int* __restrict__ index, + int row_start, + int length, + int target_k) +{ + int topk = target_k; + constexpr auto BLOCK_SIZE = 1024; + constexpr auto RADIX = 1 << RADIX_BITS; + constexpr auto RADIX_PAD = RADIX / 2; + constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); + static_assert(RADIX_BITS >= 4 && RADIX_BITS <= 10, "RADIX_BITS must be 4-10"); + static_assert(RADIX <= BLOCK_SIZE, "RADIX must not exceed BLOCK_SIZE"); + + alignas(128) __shared__ int s_histogram_buf[2][RADIX + RADIX_PAD]; + alignas(128) __shared__ int s_counter; + alignas(128) __shared__ int s_threshold_bin_id; + alignas(128) __shared__ int s_num_input[2]; + + auto& s_histogram = s_histogram_buf[0]; + extern __shared__ int s_input_idx[][SMEM_INPUT_SIZE]; + + const int tx = threadIdx.x; + + // Stage 1: coarse histogram with RADIX bins + if (tx < RADIX + 1) s_histogram[tx] = 0; + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = convert_to_uintN(vortex_to_float(input[idx + row_start])); + ::atomicAdd(&s_histogram[bin], 1); + } + __syncthreads(); + + const auto run_cumsum = [&] { + for (int i = 0; i < RADIX_BITS; ++i) { + if (C10_LIKELY(tx < RADIX)) { + const auto j = 1 << i; + const auto k = i & 1; + auto value = s_histogram_buf[k][tx]; + if (tx < RADIX - j) { + value += s_histogram_buf[k][tx + j]; + } + s_histogram_buf[k ^ 1][tx] = value; + } + __syncthreads(); + } + }; + // Stage 2 cumsum: always 256 sub-bins (8-bit radix on raw float bits) + const auto run_cumsum_s2 = [&] { + for (int i = 0; i < 8; ++i) { + if (C10_LIKELY(tx < 256)) { + const auto j = 1 << i; + const auto k = i & 1; + auto value = s_histogram_buf[k][tx]; + if (tx < 256 - j) { + value += s_histogram_buf[k][tx + j]; + } + s_histogram_buf[k ^ 1][tx] = value; + } + __syncthreads(); + } + }; + + run_cumsum(); + if (tx < RADIX && 
s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { + s_threshold_bin_id = tx; + s_num_input[0] = 0; + s_counter = 0; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + topk -= s_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = static_cast(convert_to_uintN(vortex_to_float(input[idx + row_start]))); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + return; + } else { + __syncthreads(); + if (tx < 257) s_histogram[tx] = 0; + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto raw_input = vortex_to_float(input[idx + row_start]); + const auto bin = static_cast(convert_to_uintN(raw_input)); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + const auto pos = ::atomicAdd(&s_num_input[0], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + s_input_idx[0][pos] = idx; + const auto b32 = convert_to_uint32(raw_input); + const auto sub_bin = (b32 >> 24) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } + } + __syncthreads(); + } + + // Stage 2: refine with 8-bit radix passes +#pragma unroll 4 + for (int round = 0; round < 4; ++round) { + __shared__ int s_last_remain; + const auto r_idx = round % 2; + + const auto _raw_num_input = s_num_input[r_idx]; + const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) ? 
_raw_num_input : int(SMEM_INPUT_SIZE); + + run_cumsum_s2(); + if (tx < 256 && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { + s_threshold_bin_id = tx; + s_num_input[r_idx ^ 1] = 0; + s_last_remain = topk - s_histogram[tx + 1]; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + topk -= s_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(vortex_to_float(input[idx + row_start])) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + break; + } else { + __syncthreads(); + if (tx < 257) s_histogram[tx] = 0; + __syncthreads(); + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto raw_input = vortex_to_float(input[idx + row_start]); + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + if (round == 3) { + const auto pos = ::atomicAdd(&s_last_remain, -1); + if (pos > 0) { + index[target_k - pos] = idx; + } + } else { + const auto pos = ::atomicAdd(&s_num_input[r_idx ^ 1], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + s_input_idx[r_idx ^ 1][pos] = idx; + const auto b32 = convert_to_uint32(raw_input); + const auto sub_bin = (b32 >> (offset - 8)) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } + } + } + __syncthreads(); + } + } +} + +// Ori fast-path wrapper: zero mapping overhead, flexible radix +template +__global__ __launch_bounds__(kThreadsPerBlock) +void TopKOutput_Ori_Kernel( + const ScoreT* __restrict__ score, + const int* __restrict__ dense_kv_indptr, + const int* __restrict__ sparse_kv_indptr, + const int* __restrict__ 
dense_kv_indices, + int* __restrict__ sparse_kv_indices, + const int topk_val, + const int page_reserved_bos, + const int page_reserved_eos) +{ + const int bx = blockIdx.x; + + const int start = dense_kv_indptr[bx] + page_reserved_bos; + const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; + const int nblk = end - start; + if (nblk <= topk_val) return; + + const ScoreT* __restrict__ score_blk = score + start; + const int* __restrict__ idx_blk = dense_kv_indices + start; + int* __restrict__ out_blk = sparse_kv_indices + + sparse_kv_indptr[bx] + + page_reserved_bos; + + __shared__ int s_indices[VORTEX_MAX_TOPK]; + fast_topk_ori(score_blk, s_indices, 0, nblk, topk_val); + __syncthreads(); + + const int tx = threadIdx.x; + for (int i = tx; i < topk_val; i += kThreadsPerBlock) { + out_blk[i] = idx_blk[s_indices[i]]; + } +} + +// Helper: launch TopKOutput_Ori_Kernel with radix_bits dispatch +template +void launch_ori_kernel( + const ScoreT* score, const int* dense_kv_indptr, const int* sparse_kv_indptr, + const int* dense_kv_indices, int* sparse_kv_indices, + int topk_val, int reserved_bos, int reserved_eos, + int radix_bits, dim3 nblks, dim3 nthreads, cudaStream_t stream) +{ + #define LAUNCH_ORI(BITS) \ + setup_kernel_smem_once, kSmem>(); \ + TopKOutput_Ori_Kernel<<>>( \ + score, dense_kv_indptr, sparse_kv_indptr, dense_kv_indices, sparse_kv_indices, \ + topk_val, reserved_bos, reserved_eos) + switch (radix_bits) { + case 4: LAUNCH_ORI(4); break; + case 5: LAUNCH_ORI(5); break; + case 6: LAUNCH_ORI(6); break; + case 7: LAUNCH_ORI(7); break; + case 9: LAUNCH_ORI(9); break; + case 10: LAUNCH_ORI(10); break; + default: LAUNCH_ORI(8); break; + } + #undef LAUNCH_ORI +} + +// ====================================================================== +// Explicit ori baseline entry point — always uses the ori fast path +// ====================================================================== +void topk_output_sglang_ori( + const at::Tensor& x, + const at::Tensor& 
dense_kv_indptr, + const at::Tensor& sparse_kv_indptr, + const at::Tensor& dense_kv_indices, + at::Tensor& sparse_kv_indices, + const int64_t eff_batch_size, + const int64_t topk_val, + const int64_t reserved_bos, + const int64_t reserved_eos, + const int64_t max_num_pages, + const int64_t radix_bits) +{ + TORCH_CHECK(topk_val <= VORTEX_MAX_TOPK, + "topk_output_sglang_ori: topk_val (", topk_val, + ") exceeds VORTEX_MAX_TOPK (", VORTEX_MAX_TOPK, ")"); + TORCH_CHECK(radix_bits >= 4 && radix_bits <= 10, + "topk_output_sglang_ori: radix_bits must be 4-10, got ", radix_bits); + + CHECK_CUDA(x); + CHECK_CUDA(dense_kv_indptr); + CHECK_CUDA(sparse_kv_indptr); + CHECK_CUDA(dense_kv_indices); + CHECK_CUDA(sparse_kv_indices); + + dim3 nblks(eff_batch_size); + dim3 nthreads(kThreadsPerBlock); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + + if (x.scalar_type() == at::ScalarType::BFloat16) { + launch_ori_kernel<__nv_bfloat16>( + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + dense_kv_indptr.data_ptr(), sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), sparse_kv_indices.data_ptr(), + topk_val, reserved_bos, reserved_eos, + radix_bits, nblks, nthreads, stream); + } else if (x.scalar_type() == at::ScalarType::Float) { + launch_ori_kernel( + x.data_ptr(), + dense_kv_indptr.data_ptr(), sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), sparse_kv_indices.data_ptr(), + topk_val, reserved_bos, reserved_eos, + radix_bits, nblks, nthreads, stream); + } else { + TORCH_CHECK(false, "topk_output_sglang_ori: unsupported dtype ", x.scalar_type()); + } + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, + "topk_output_sglang_ori kernel failed: ", ::cudaGetErrorString(result)); +} diff --git a/csrc/topk_slgang_ori.cu b/csrc/archived/topk_slgang_ori.cu similarity index 100% rename from csrc/topk_slgang_ori.cu rename to csrc/archived/topk_slgang_ori.cu diff --git a/csrc/register.cc b/csrc/register.cc index b2d12b9..cc201c9 
100644 --- a/csrc/register.cc +++ b/csrc/register.cc @@ -13,21 +13,24 @@ PYBIND11_MODULE(vortex_torch_C, m){ py::arg("dense_kv_indices"), py::arg("sparse_kv_indices"), py::arg("eff_batch_size"), py::arg("topk_val"), py::arg("reserved_bos"), py::arg("reserved_eos"), - py::arg("max_num_pages"), - py::arg("mapping_mode") = 0, - py::arg("mapping_power") = 0.5, - py::arg("mapping_lut") = py::none(), - py::arg("mapping_quantiles") = py::none(), - py::arg("mapping_noscale") = false, - py::arg("sample_stride") = 1, - py::arg("radix_bits") = 8); - m.def("topk_output_sglang_ori", &topk_output_sglang_ori, + py::arg("max_num_pages")); + m.def("topk_output_sglang_fused", &topk_output_sglang_fused, py::arg("x"), py::arg("dense_kv_indptr"), py::arg("sparse_kv_indptr"), py::arg("dense_kv_indices"), py::arg("sparse_kv_indices"), py::arg("eff_batch_size"), py::arg("topk_val"), py::arg("reserved_bos"), py::arg("reserved_eos"), py::arg("max_num_pages"), - py::arg("radix_bits") = 8); + py::arg("mapping_mode"), + py::arg("mapping_power"), + py::arg("mapping_lut") = py::none(), + py::arg("mapping_quantiles") = py::none()); + m.def("topk_remap_only", &topk_remap_only, + py::arg("x"), py::arg("dense_kv_indptr"), + py::arg("remapped"), + py::arg("eff_batch_size"), + py::arg("reserved_bos"), py::arg("reserved_eos"), + py::arg("mapping_mode"), + py::arg("mapping_power")); m.def("topk_profile_histogram", &topk_profile_histogram, py::arg("x"), py::arg("dense_kv_indptr"), py::arg("histograms"), py::arg("eff_batch_size"), @@ -35,21 +38,7 @@ PYBIND11_MODULE(vortex_torch_C, m){ py::arg("mapping_mode") = 0, py::arg("mapping_power") = 0.5, py::arg("mapping_lut") = py::none(), - py::arg("mapping_quantiles") = py::none(), - py::arg("mapping_noscale") = false, - py::arg("topk_val") = 0, - py::arg("sample_stride") = 1); - m.def("topk_profile_stage1", &topk_profile_stage1, - py::arg("x"), py::arg("dense_kv_indptr"), py::arg("sparse_kv_indptr"), - py::arg("dense_kv_indices"), py::arg("sparse_kv_indices"), 
- py::arg("eff_batch_size"), py::arg("topk_val"), - py::arg("reserved_bos"), py::arg("reserved_eos"), - py::arg("max_num_pages"), - py::arg("mapping_mode") = 0, - py::arg("mapping_power") = 0.5, - py::arg("mapping_lut") = py::none(), - py::arg("mapping_quantiles") = py::none(), - py::arg("mapping_noscale") = false); + py::arg("mapping_quantiles") = py::none()); m.def("topk_profile_counters", &topk_profile_counters, py::arg("x"), py::arg("dense_kv_indptr"), py::arg("sparse_kv_indptr"), py::arg("dense_kv_indices"), py::arg("sparse_kv_indices"), @@ -60,8 +49,7 @@ PYBIND11_MODULE(vortex_torch_C, m){ py::arg("mapping_mode") = 0, py::arg("mapping_power") = 0.5, py::arg("mapping_lut") = py::none(), - py::arg("mapping_quantiles") = py::none(), - py::arg("mapping_noscale") = false); + py::arg("mapping_quantiles") = py::none()); m.def("sglang_plan_decode_fa3", &sglang_plan_decode_fa3); m.def("sglang_plan_prefill_fa3", &sglang_plan_prefill_fa3); m.def("Chunkwise_HN2NH_Transpose_FA3", &Chunkwise_HN2NH_Transpose_FA3); diff --git a/csrc/register.h b/csrc/register.h index 784b754..e86a963 100644 --- a/csrc/register.h +++ b/csrc/register.h @@ -95,17 +95,10 @@ const int64_t eff_batch_size, const int64_t topk_val, const int64_t reserved_bos, const int64_t reserved_eos, -const int64_t max_seq_lengths, -const int64_t mapping_mode = 0, -const double mapping_power = 0.5, -std::optional mapping_lut = std::nullopt, -std::optional mapping_quantiles = std::nullopt, -const bool mapping_noscale = false, -const int64_t sample_stride = 1, -const int64_t radix_bits = 8 +const int64_t max_seq_lengths ); -void topk_output_sglang_ori( +void topk_output_sglang_fused( const at::Tensor& x, const at::Tensor& dense_kv_indptr, const at::Tensor& sparse_kv_indptr, @@ -116,41 +109,34 @@ const int64_t topk_val, const int64_t reserved_bos, const int64_t reserved_eos, const int64_t max_num_pages, -const int64_t radix_bits = 8 +const int64_t mapping_mode, +const double mapping_power, +std::optional mapping_lut 
= std::nullopt, +std::optional mapping_quantiles = std::nullopt ); -void topk_profile_histogram( +void topk_remap_only( const at::Tensor& x, const at::Tensor& dense_kv_indptr, -at::Tensor& histograms, +at::Tensor& remapped, const int64_t eff_batch_size, const int64_t reserved_bos, const int64_t reserved_eos, -const int64_t mapping_mode = 0, -const double mapping_power = 0.5, -std::optional mapping_lut = std::nullopt, -std::optional mapping_quantiles = std::nullopt, -const bool mapping_noscale = false, -const int64_t topk_val = 0, -const int64_t sample_stride = 1 +const int64_t mapping_mode, +const double mapping_power ); -void topk_profile_stage1( +void topk_profile_histogram( const at::Tensor& x, const at::Tensor& dense_kv_indptr, -const at::Tensor& sparse_kv_indptr, -const at::Tensor& dense_kv_indices, -at::Tensor& sparse_kv_indices, +at::Tensor& histograms, const int64_t eff_batch_size, -const int64_t topk_val, const int64_t reserved_bos, const int64_t reserved_eos, -const int64_t max_num_pages, const int64_t mapping_mode = 0, const double mapping_power = 0.5, std::optional mapping_lut = std::nullopt, -std::optional mapping_quantiles = std::nullopt, -const bool mapping_noscale = false +std::optional mapping_quantiles = std::nullopt ); void topk_profile_counters( @@ -168,8 +154,7 @@ const int64_t max_num_pages, const int64_t mapping_mode = 0, const double mapping_power = 0.5, std::optional mapping_lut = std::nullopt, -std::optional mapping_quantiles = std::nullopt, -const bool mapping_noscale = false +std::optional mapping_quantiles = std::nullopt ); void sglang_plan_decode_fa3( diff --git a/csrc/topk_mapping.cuh b/csrc/topk_mapping.cuh index 773cdeb..447e539 100644 --- a/csrc/topk_mapping.cuh +++ b/csrc/topk_mapping.cuh @@ -4,60 +4,46 @@ #include // ============================================================ -// TopK bucket-sort Stage-1 remapping strategies +// TopK bucket-sort Stage-1 remap transforms (lean version). 
// -// These transforms remap float scores before Stage 1's 8-bit -// histogram binning. The primary goal is to maximize coarse-bin -// resolution in the score region that determines the top-k -// cutoff, thereby: -// - shrinking the Stage-1 threshold bin (fewer collisions) -// - reducing COUNTER_NUM_EQUAL / COUNTER_STAGE2_INPUT -// - reducing the number of Stage-2 refine rounds +// These are element-wise transforms applied to scores before +// the Stage-1 8-bit histogram bucketing. The goal is to spread +// a skewed raw distribution more uniformly across the 256 bins +// so the threshold bin shrinks and Stage-2 refinement does less +// work. Stage 2 still uses convert_to_uint32() on the remapped +// value's raw bits for tie-breaking. // -// Stage 2 refinement still uses convert_to_uint32() on raw -// floats, so final ordering correctness is always preserved. -// -// Modes 3/4/6/7/9/10 apply a nonlinear transform then linearly -// map the result to [0,255]. Mode 12 (ADAPTIVE_TAIL_WINDOW) -// directly focuses all 256 bins on the competitive upper tail -// estimated from the top-k ratio, collapsing irrelevant -// low-score mass into bin 0. +// There is no pre-pass, no auto-range, no LUT, no quantile +// table, and no shared-memory state — each transform is a +// pure function of one float. The heavy pre-pass machinery +// (auto-range, pivot, tail-window, topk-window, LUT_CDF, +// QUANTILE, SUBTRACT, TRUNC8) lives in +// csrc/archived/fast_topk_vortex_prepass.cu. 
// ============================================================ enum TopKMappingMode { - MAPPING_NONE = 0, // Original convert_to_uint8 behavior - MAPPING_LUT_CDF = 1, // LUT-based CDF equalization - MAPPING_QUANTILE = 2, // Piecewise-linear quantile mapping - MAPPING_POWER = 3, // Monotonic power transform - MAPPING_LOG = 4, // Log transform - // Mode 5 reserved (previously INDEX_CACHE, removed) - MAPPING_ASINH = 6, // asinh(beta * x), beta via power_exp - MAPPING_LOG1P = 7, // sign(x) * log1p(alpha * |x|), alpha via power_exp - MAPPING_TRUNC8 = 8, // BF16 upper-8-bit bucketing - MAPPING_ERF = 9, // erf(alpha * x) - MAPPING_TANH = 10, // tanh(alpha * x) - MAPPING_SUBTRACT = 11, // subtract pivot, then fp16 bucketing - MAPPING_ADAPTIVE_TAIL_WINDOW = 12, // focus bins on upper tail via sampled quantile - MAPPING_EXP_STRETCH = 13, // exp(alpha * x), concentrates bin resolution on upper tail - MAPPING_TOPK_WINDOW = 14, // k-aware linear windowing: focus bins on [tau_low, max] + MAPPING_NONE = 0, // identity (no remap) + MAPPING_LUT_CDF = 1, // bin lookup: new_bin = lut[convert_to_uint8(x)] + MAPPING_QUANTILE = 2, // binary search over 256 calibrated quantile thresholds + MAPPING_POWER = 3, // sign(x) * |x|^p + MAPPING_LOG = 4, // sign(x) * log(|x| + 1) + MAPPING_ASINH = 6, // asinh(beta * x) + MAPPING_LOG1P = 7, // sign(x) * log1p(alpha * |x|) + MAPPING_TRUNC8 = 8, // identity bucketing (historical name, alias of MAPPING_NONE) + MAPPING_ERF = 9, // erf(alpha * x) + MAPPING_TANH = 10, // tanh(alpha * x) + MAPPING_SUBTRACT = 11, // x - pivot, with pivot = power_exp (free hyperparameter) + MAPPING_EXP_STRETCH = 13, // exp(alpha * x) }; struct TopKMappingParams { - int mode; // TopKMappingMode - float power_exp; // For MAPPING_POWER (default 0.5) - // For MAPPING_ADAPTIVE_TAIL_WINDOW: tail expansion - // factor rho (default 4.0). tau_low = Q(1 - rho*k/n). 
- const uint8_t* __restrict__ lut; // [256] byte LUT, or nullptr - const float* __restrict__ quantiles; // [256] float quantile breakpoints, or nullptr - bool noscale; // Skip auto-range linear scaling, use fp16 bucketing on f(x) - int sample_stride; // Pre-pass sampling stride (1=full, 8=1/8, 0=skip) - int target_k; // Top-k value; used by MAPPING_ADAPTIVE_TAIL_WINDOW + int mode; // TopKMappingMode + float power_exp; // Free hyperparameter: p / alpha / beta / pivot depending on mode + const uint8_t* __restrict__ lut; // [256] uint8 LUT, MAPPING_LUT_CDF only + const float* __restrict__ quantiles; // [256] float quantile breakpoints, MAPPING_QUANTILE only }; -// NOTE: convert_to_uint8() must be defined before including this header. -// It is defined in topk_sglang.cu within the anonymous namespace. - -// ---- Individual transform functions (return float, no bucketing) ---- +// ---- Element-wise transforms ---- __device__ __forceinline__ float transform_power(float x, float p) { return copysignf(__powf(fabsf(x), p), x); @@ -89,124 +75,66 @@ __device__ __forceinline__ float transform_exp_stretch(float x, float alpha) { return expf(z); } -// ---- Transform dispatcher (returns float, no bucketing) ---- - +// Pure element-wise dispatcher. Returns the *float value* after the transform. +// For bin-selection modes (LUT_CDF / QUANTILE) this is identity: the mapping +// happens in compute_stage1_bin() below instead of via a float transform, so +// Stage-2 tie-breaking uses the raw score bits for those modes. 
__device__ __forceinline__ float apply_transform(float x, const TopKMappingParams& params) { switch (params.mode) { - case MAPPING_POWER: return transform_power(x, params.power_exp); - case MAPPING_LOG: return transform_log(x); - case MAPPING_ASINH: return transform_asinh(x, params.power_exp); - case MAPPING_LOG1P: return transform_log1p(x, params.power_exp); - case MAPPING_ERF: return transform_erf(x, params.power_exp); - case MAPPING_TANH: return transform_tanh(x, params.power_exp); + case MAPPING_POWER: return transform_power(x, params.power_exp); + case MAPPING_LOG: return transform_log(x); + case MAPPING_ASINH: return transform_asinh(x, params.power_exp); + case MAPPING_LOG1P: return transform_log1p(x, params.power_exp); + case MAPPING_ERF: return transform_erf(x, params.power_exp); + case MAPPING_TANH: return transform_tanh(x, params.power_exp); + case MAPPING_SUBTRACT: return x - params.power_exp; case MAPPING_EXP_STRETCH: return transform_exp_stretch(x, params.power_exp); - default: return x; + case MAPPING_LUT_CDF: + case MAPPING_QUANTILE: + case MAPPING_TRUNC8: + default: return x; // NONE / TRUNC8 / LUT_CDF / QUANTILE } } -// ---- Linear bucketing for transform modes ---- - -__device__ __forceinline__ uint8_t linear_map_to_uint8(float val, float range_min, float inv_range) { - int bin = __float2int_rd((val - range_min) * inv_range); - return static_cast(min(max(bin, 0), 255)); -} - -// ---- BF16-aware bucketing (mode 8) ---- -// BF16 has 8 exponent + 7 mantissa bits. Taking the upper 8 bits of the -// sign-flipped bf16 bit-pattern yields only ~20 distinct bins for typical -// data (the byte is almost entirely exponent). Instead, convert through -// fp16 (5 exp + 10 mantissa) which puts 5 exp + 2 mantissa bits in the -// upper byte, giving ~135+ distinct bins — equivalent to mode 0 but -// explicitly available as a named mode for documentation/benchmarking. 
- -__device__ __forceinline__ uint8_t convert_to_uint8_bf16(float x) { - return convert_to_uint8(x); // fp16 sign-flip bucketing +// Whether the mapping mode is a direct bin-selection function (LUT_CDF / +// QUANTILE). These modes need per-block shared-memory tables. +__device__ __forceinline__ bool mapping_uses_table(int mode) { + return mode == MAPPING_LUT_CDF || mode == MAPPING_QUANTILE; } -// ---- Non-transform mapping functions (unchanged) ---- - -// LUT-based CDF equalization: lut[original_bin] -> equalized_bin -__device__ __forceinline__ uint8_t map_lut_cdf(float x, const uint8_t* __restrict__ s_lut) { - return s_lut[convert_to_uint8(x)]; -} - -// Quantile mapping: binary search over 256 sorted thresholds -__device__ __forceinline__ uint8_t map_quantile(float x, const float* __restrict__ s_quantiles) { - // Binary search: find largest index i such that x >= s_quantiles[i] - // s_quantiles is sorted ascending, length 256 +// Binary search over a sorted [256] quantile table. Returns the largest +// index i such that x >= quantiles[i], in [0, 255]. +__device__ __forceinline__ uint8_t quantile_bin_lookup( + float x, const float* __restrict__ s_quantiles) +{ int lo = 0, hi = 255; #pragma unroll 8 for (int iter = 0; iter < 8; ++iter) { int mid = (lo + hi + 1) >> 1; - if (x >= s_quantiles[mid]) { - lo = mid; - } else { - hi = mid - 1; - } + if (x >= s_quantiles[mid]) lo = mid; + else hi = mid - 1; } return static_cast(lo); } -// ---- Unified dispatcher ---- -// For modes 3/4/6/7, range_min and inv_range come from a per-block pre-pass. +// Forward decl so compute_stage1_bin can call it. Defined in the enclosing TU. +__device__ __forceinline__ uint8_t convert_to_uint8(float x); -__device__ __forceinline__ uint8_t mapped_convert_to_uint8( - float x, +// Compute the Stage-1 bin for a raw score under any mapping mode. LUT_CDF / +// QUANTILE use the shared-memory tables loaded at the kernel entry; every +// other mode falls back to convert_to_uint8(apply_transform(x)). 
+__device__ __forceinline__ uint8_t compute_stage1_bin( + float raw, const TopKMappingParams& params, const uint8_t* __restrict__ s_lut, - const float* __restrict__ s_quantiles, - float range_min, - float inv_range) + const float* __restrict__ s_quantiles) { switch (params.mode) { case MAPPING_LUT_CDF: - if (params.lut != nullptr) return map_lut_cdf(x, s_lut); - return convert_to_uint8(x); // fallback to mode 0 when LUT not calibrated + return s_lut[convert_to_uint8(raw)]; case MAPPING_QUANTILE: - if (params.quantiles != nullptr) return map_quantile(x, s_quantiles); - return convert_to_uint8(x); // fallback to mode 0 when quantiles not calibrated - case MAPPING_POWER: - case MAPPING_LOG: - case MAPPING_ASINH: - case MAPPING_LOG1P: - case MAPPING_ERF: - case MAPPING_TANH: - case MAPPING_EXP_STRETCH: { - float val = apply_transform(x, params); - if (params.noscale) return convert_to_uint8(val); - return linear_map_to_uint8(val, range_min, inv_range); - } - case MAPPING_TRUNC8: - return convert_to_uint8_bf16(x); - case MAPPING_SUBTRACT: - return convert_to_uint8(x - range_min); // range_min repurposed as pivot - case MAPPING_ADAPTIVE_TAIL_WINDOW: - case MAPPING_TOPK_WINDOW: - return linear_map_to_uint8(x, range_min, inv_range); - default: // MAPPING_NONE - return convert_to_uint8(x); + return quantile_bin_lookup(raw, s_quantiles); + default: + return convert_to_uint8(apply_transform(raw, params)); } } - -// Helper: check if a mapping mode needs the auto-range pre-pass -__device__ __forceinline__ bool needs_auto_range(int mode) { - return (mode == MAPPING_POWER || mode == MAPPING_LOG || - mode == MAPPING_ASINH || mode == MAPPING_LOG1P || - mode == MAPPING_ERF || mode == MAPPING_TANH || - mode == MAPPING_EXP_STRETCH); -} - -// Helper: check if a mapping mode needs the pivot pre-pass -__device__ __forceinline__ bool needs_pivot(int mode) { - return (mode == MAPPING_SUBTRACT); -} - -// Helper: check if mode is the adaptive tail-window pre-pass -__device__ __forceinline__ 
bool needs_tail_window(int mode) { - return (mode == MAPPING_ADAPTIVE_TAIL_WINDOW); -} - -// Helper: check if mode is the lightweight topk-window pre-pass -__device__ __forceinline__ bool needs_topk_window(int mode) { - return (mode == MAPPING_TOPK_WINDOW); -} diff --git a/csrc/topk_sglang.cu b/csrc/topk_sglang.cu index 46dcdd7..2466f57 100644 --- a/csrc/topk_sglang.cu +++ b/csrc/topk_sglang.cu @@ -1,8 +1,20 @@ /** - * Vortex TopK kernel — mirrors topk_slgang_ori.cu structure with additions: - * - bf16 support, flexible radix, mapping/remap modes - * - CSR paged wrapper kernels for vortex integration - * Profiling kernels are in topk_sglang_profile.cu. + * Vortex TopK kernels. + * + * Three production kernels: + * - fast_topk_clean : unmapped baseline (two-stage radix). + * - fast_topk_clean_fused : remap + topk fused (apply_transform + * applied inline in Stage-1 bucketing). + * - TopKRemapOnly_Kernel : standalone element-wise remap pass + * used by the split-phase benchmark. + * + * Profiling kernels (counter collection, histogram collection) live in + * topk_sglang_profile.cu and MUST NOT be used for latency measurements — + * they intentionally write extra diagnostic state to global memory. + * + * Archived / historical kernels: csrc/archived/ (fast_topk_vortex with + * pre-pass modes, TopKOutput_Ori_Kernel with flexible radix_bits, the + * original SGLang reference kernel). 
*/ #include #include @@ -100,22 +112,6 @@ __device__ __forceinline__ float vortex_to_float<__nv_bfloat16>(__nv_bfloat16 x) constexpr int VORTEX_MAX_TOPK = 2048; -constexpr int COUNTER_THRESHOLD_BIN = 0; -constexpr int COUNTER_NUM_ABOVE = 1; -constexpr int COUNTER_NUM_EQUAL = 2; -constexpr int COUNTER_REMAINING_K = 3; -constexpr int COUNTER_REFINE_ROUNDS = 4; -constexpr int COUNTER_STAGE2_INPUT = 5; -constexpr int NUM_TOPK_COUNTERS = 6; - -template -__device__ __forceinline__ uint16_t convert_to_uintN(float x) { - __half h = __float2half_rn(x); - uint16_t bits = __half_as_ushort(h); - uint16_t key = (bits & 0x8000) ? static_cast(~bits) : static_cast(bits | 0x8000); - return key >> (16 - BITS); -} - #include "topk_mapping.cuh" @@ -463,80 +459,120 @@ void setup_kernel_smem_once() { } // ====================================================================== -// Ori fast path: zero-overhead topk with no mapping infrastructure. -// Template on RADIX_BITS: 4-10 (16 to 1024 bins). +// Templated clean baseline: identical algorithm to fast_topk_cuda_tl but +// parameterised on ScoreT (float or __nv_bfloat16) for the GQA / paged +// call paths that operate on bf16 attention scores. No mapping, no +// pre-pass — pure two-stage radix topk on fp16 bit-pattern bins. 
// ====================================================================== -template -__device__ void fast_topk_ori( +template +__device__ void fast_topk_clean( const ScoreT* __restrict__ input, int* __restrict__ index, int row_start, int length, int target_k) { - int topk = target_k; - constexpr auto BLOCK_SIZE = 1024; - constexpr auto RADIX = 1 << RADIX_BITS; - constexpr auto RADIX_PAD = RADIX / 2; - constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); - static_assert(RADIX_BITS >= 4 && RADIX_BITS <= 10, "RADIX_BITS must be 4-10"); - static_assert(RADIX <= BLOCK_SIZE, "RADIX must not exceed BLOCK_SIZE"); + int topk = target_k; + constexpr auto BLOCK_SIZE = 1024; + constexpr auto RADIX = 256; + constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); + + alignas(128) __shared__ int s_histogram_buf[2][RADIX + 128]; + alignas(128) __shared__ int s_counter; + alignas(128) __shared__ int s_threshold_bin_id; + alignas(128) __shared__ int s_num_input[2]; - alignas(128) __shared__ int s_histogram_buf[2][RADIX + RADIX_PAD]; - alignas(128) __shared__ int s_counter; - alignas(128) __shared__ int s_threshold_bin_id; - alignas(128) __shared__ int s_num_input[2]; + auto& s_histogram = s_histogram_buf[0]; + extern __shared__ int s_input_idx[][SMEM_INPUT_SIZE]; + + const int tx = threadIdx.x; - auto& s_histogram = s_histogram_buf[0]; - extern __shared__ int s_input_idx[][SMEM_INPUT_SIZE]; + // stage 1: 8-bit coarse histogram + if (tx < RADIX + 1) s_histogram[tx] = 0; + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = convert_to_uint8(vortex_to_float(input[idx + row_start])); + ::atomicAdd(&s_histogram[bin], 1); + } + __syncthreads(); - const int tx = threadIdx.x; + const auto run_cumsum = [&] { +#pragma unroll 8 + for (int i = 0; i < 8; ++i) { + static_assert(1 << 8 == RADIX); + if (C10_LIKELY(tx < RADIX)) { + const auto j = 1 << i; + const auto k = i & 1; + auto value = s_histogram_buf[k][tx]; + if (tx < RADIX - j) { + value 
+= s_histogram_buf[k][tx + j]; + } + s_histogram_buf[k ^ 1][tx] = value; + } + __syncthreads(); + } + }; - // Stage 1: coarse histogram with RADIX bins + run_cumsum(); + if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { + s_threshold_bin_id = tx; + s_num_input[0] = 0; + s_counter = 0; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + topk -= s_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = static_cast(convert_to_uint8(vortex_to_float(input[idx + row_start]))); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + return; + } else { + __syncthreads(); if (tx < RADIX + 1) s_histogram[tx] = 0; __syncthreads(); for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - const auto bin = convert_to_uintN(vortex_to_float(input[idx + row_start])); - ::atomicAdd(&s_histogram[bin], 1); + const auto raw_input = vortex_to_float(input[idx + row_start]); + const auto bin = static_cast(convert_to_uint8(raw_input)); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + const auto pos = ::atomicAdd(&s_num_input[0], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + s_input_idx[0][pos] = idx; + const auto b32 = convert_to_uint32(raw_input); + const auto sub_bin = (b32 >> 24) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } } __syncthreads(); + } - const auto run_cumsum = [&] { - for (int i = 0; i < RADIX_BITS; ++i) { - if (C10_LIKELY(tx < RADIX)) { - const auto j = 1 << i; - const auto k = i & 1; - auto value = s_histogram_buf[k][tx]; - if (tx < RADIX - j) { - value += s_histogram_buf[k][tx + j]; - } - s_histogram_buf[k ^ 1][tx] = value; - } - __syncthreads(); - } - }; - // Stage 2 cumsum: always 256 sub-bins (8-bit radix on raw float bits) - const auto run_cumsum_s2 = [&] { - for (int i = 0; i < 8; 
++i) { - if (C10_LIKELY(tx < 256)) { - const auto j = 1 << i; - const auto k = i & 1; - auto value = s_histogram_buf[k][tx]; - if (tx < 256 - j) { - value += s_histogram_buf[k][tx + j]; - } - s_histogram_buf[k ^ 1][tx] = value; - } - __syncthreads(); - } - }; + // stage 2: refine with 8-bit radix passes on raw fp32 bits +#pragma unroll 4 + for (int round = 0; round < 4; ++round) { + __shared__ int s_last_remain; + const auto r_idx = round % 2; + + const auto _raw_num_input = s_num_input[r_idx]; + const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) ? _raw_num_input : int(SMEM_INPUT_SIZE); run_cumsum(); if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { - s_threshold_bin_id = tx; - s_num_input[0] = 0; - s_counter = 0; + s_threshold_bin_id = tx; + s_num_input[r_idx ^ 1] = 0; + s_last_remain = topk - s_histogram[tx + 1]; } __syncthreads(); @@ -544,582 +580,249 @@ __device__ void fast_topk_ori( topk -= s_histogram[threshold_bin + 1]; if (topk == 0) { - for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - const auto bin = static_cast(convert_to_uintN(vortex_to_float(input[idx + row_start]))); - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&s_counter, 1); - index[pos] = idx; - } + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(vortex_to_float(input[idx + row_start])) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; } - __syncthreads(); - return; + } + __syncthreads(); + break; } else { - __syncthreads(); - if (tx < 257) s_histogram[tx] = 0; - __syncthreads(); - - for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - const auto raw_input = vortex_to_float(input[idx + row_start]); - const auto bin = static_cast(convert_to_uintN(raw_input)); - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&s_counter, 1); - index[pos] = idx; - 
} else if (bin == threshold_bin) { - const auto pos = ::atomicAdd(&s_num_input[0], 1); - if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { - s_input_idx[0][pos] = idx; - const auto b32 = convert_to_uint32(raw_input); - const auto sub_bin = (b32 >> 24) & 0xFF; - ::atomicAdd(&s_histogram[sub_bin], 1); - } - } - } - __syncthreads(); - } - - // Stage 2: refine with 8-bit radix passes -#pragma unroll 4 - for (int round = 0; round < 4; ++round) { - __shared__ int s_last_remain; - const auto r_idx = round % 2; - - const auto _raw_num_input = s_num_input[r_idx]; - const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) ? _raw_num_input : int(SMEM_INPUT_SIZE); - - run_cumsum_s2(); - if (tx < 256 && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { - s_threshold_bin_id = tx; - s_num_input[r_idx ^ 1] = 0; - s_last_remain = topk - s_histogram[tx + 1]; - } - __syncthreads(); - - const auto threshold_bin = s_threshold_bin_id; - topk -= s_histogram[threshold_bin + 1]; - - if (topk == 0) { - for (int i = tx; i < num_input; i += BLOCK_SIZE) { - const auto idx = s_input_idx[r_idx][i]; - const auto offset = 24 - round * 8; - const auto bin = (convert_to_uint32(vortex_to_float(input[idx + row_start])) >> offset) & 0xFF; - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&s_counter, 1); - index[pos] = idx; - } + __syncthreads(); + if (tx < RADIX + 1) s_histogram[tx] = 0; + __syncthreads(); + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto raw_input = vortex_to_float(input[idx + row_start]); + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + if (round == 3) { + const auto pos = ::atomicAdd(&s_last_remain, -1); + if (pos > 0) { + index[target_k - pos] = idx; } - __syncthreads(); - break; - } else { - __syncthreads(); - if (tx < 
257) s_histogram[tx] = 0; - __syncthreads(); - for (int i = tx; i < num_input; i += BLOCK_SIZE) { - const auto idx = s_input_idx[r_idx][i]; - const auto raw_input = vortex_to_float(input[idx + row_start]); - const auto offset = 24 - round * 8; - const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF; - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&s_counter, 1); - index[pos] = idx; - } else if (bin == threshold_bin) { - if (round == 3) { - const auto pos = ::atomicAdd(&s_last_remain, -1); - if (pos > 0) { - index[target_k - pos] = idx; - } - } else { - const auto pos = ::atomicAdd(&s_num_input[r_idx ^ 1], 1); - if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { - s_input_idx[r_idx ^ 1][pos] = idx; - const auto b32 = convert_to_uint32(raw_input); - const auto sub_bin = (b32 >> (offset - 8)) & 0xFF; - ::atomicAdd(&s_histogram[sub_bin], 1); - } - } - } + } else { + const auto pos = ::atomicAdd(&s_num_input[r_idx ^ 1], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + s_input_idx[r_idx ^ 1][pos] = idx; + const auto b32 = convert_to_uint32(raw_input); + const auto sub_bin = (b32 >> (offset - 8)) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); } - __syncthreads(); + } } + } + __syncthreads(); } + } } // ====================================================================== -// Templated version of fast_topk_cuda_tl with mapping support: -// - ScoreT: float or __nv_bfloat16 -// - StopAfterStage1: return after Stage 1 route/filter (for profiling) -// - WriteCounters: write diagnostic counters to global memory - -// - mapping: configurable value-remapping for Stage 1 bin assignment -template -__device__ void fast_topk_vortex( +// Templated fused kernel: apply_transform(score) -> convert_to_uint8 +// is fused into Stage 1. Stage 2 still uses raw bits for tie-breaking +// (on the *remapped* value, not the original score) — this is a +// benchmarking kernel, the remapped Stage-2 ordering is acceptable. +// No pre-pass, no LUT, no shared-memory mapping state. 
+// ====================================================================== +template +__device__ void fast_topk_clean_fused( const ScoreT* __restrict__ input, int* __restrict__ index, int row_start, int length, int target_k, - const TopKMappingParams& mapping, - int* counters = nullptr) + const TopKMappingParams mapping) { - int topk = target_k; - constexpr auto BLOCK_SIZE = 1024; - constexpr auto RADIX = 256; - constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); + int topk = target_k; + constexpr auto BLOCK_SIZE = 1024; + constexpr auto RADIX = 256; + constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); - alignas(128) __shared__ int vh_histogram_buf[2][RADIX + 128]; - alignas(128) __shared__ int vh_counter; - alignas(128) __shared__ int vh_threshold_bin_id; - alignas(128) __shared__ int vh_num_input[2]; + alignas(128) __shared__ int f_histogram_buf[2][RADIX + 128]; + alignas(128) __shared__ int f_counter; + alignas(128) __shared__ int f_threshold_bin_id; + alignas(128) __shared__ int f_num_input[2]; - // Shared memory for mapping LUT / quantiles (loaded once per block) - __shared__ uint8_t s_mapping_lut[256]; - __shared__ float s_mapping_quantiles[256]; + // Shared-memory tables for MAPPING_LUT_CDF / MAPPING_QUANTILE. Loaded + // once at kernel entry and read per element in Stage 1. Other modes + // leave them untouched. 
+ __shared__ uint8_t s_mapping_lut[256]; + __shared__ float s_mapping_quantiles[256]; - // Auto-range for transform modes (3/4/6/7) - __shared__ float s_range_min, s_range_inv_range; + auto& f_histogram = f_histogram_buf[0]; + extern __shared__ int f_input_idx[][SMEM_INPUT_SIZE]; - auto& vh_histogram = vh_histogram_buf[0]; - extern __shared__ int vh_input_idx[][SMEM_INPUT_SIZE]; + const int tx = threadIdx.x; - const int tx = threadIdx.x; + if (mapping.mode == MAPPING_LUT_CDF && mapping.lut != nullptr) { + if (tx < 256) s_mapping_lut[tx] = mapping.lut[tx]; + __syncthreads(); + } + if (mapping.mode == MAPPING_QUANTILE && mapping.quantiles != nullptr) { + if (tx < 256) s_mapping_quantiles[tx] = mapping.quantiles[tx]; + __syncthreads(); + } - // Load mapping tables into shared memory if needed - if (mapping.mode == MAPPING_LUT_CDF && mapping.lut != nullptr) { - if (tx < 256) s_mapping_lut[tx] = mapping.lut[tx]; - __syncthreads(); - } - if (mapping.mode == MAPPING_QUANTILE && mapping.quantiles != nullptr) { - if (tx < 256) s_mapping_quantiles[tx] = mapping.quantiles[tx]; - __syncthreads(); - } + if (tx < RADIX + 1) f_histogram[tx] = 0; + __syncthreads(); - // Pre-pass: compute per-block min/max of transformed values for linear bucketing. - // sample_stride > 1 reduces pre-pass cost by scanning every Nth element; - // the approximated range may miss extreme outliers but Stage 2 uses raw - // float bits for exact ordering, so correctness is preserved. - if (needs_auto_range(mapping.mode) && !mapping.noscale) { - const int stride = (mapping.sample_stride > 1) ? 
mapping.sample_stride : 1; - float local_min = __FLT_MAX__, local_max = -__FLT_MAX__; - for (int idx = tx * stride; idx < length; idx += BLOCK_SIZE * stride) { - float val = apply_transform(vortex_to_float(input[idx + row_start]), mapping); - local_min = fminf(local_min, val); - local_max = fmaxf(local_max, val); - } - // Warp-level reduction - for (int offset = 16; offset > 0; offset >>= 1) { - local_min = fminf(local_min, __shfl_xor_sync(0xFFFFFFFF, local_min, offset)); - local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); - } - // Cross-warp reduction via shared memory - __shared__ float s_warp_mins[32], s_warp_maxs[32]; - int warp_id = tx >> 5, lane_id = tx & 31; - if (lane_id == 0) { s_warp_mins[warp_id] = local_min; s_warp_maxs[warp_id] = local_max; } - __syncthreads(); - if (tx < (BLOCK_SIZE >> 5)) { - local_min = s_warp_mins[tx]; local_max = s_warp_maxs[tx]; - for (int offset = 16; offset > 0; offset >>= 1) { - local_min = fminf(local_min, __shfl_xor_sync(0xFFFFFFFF, local_min, offset)); - local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); - } - if (tx == 0) { - s_range_min = local_min; - float range = local_max - local_min; - s_range_inv_range = (range > 0.0f) ? 255.0f / range : 0.0f; - } - } - __syncthreads(); - } else if (needs_pivot(mapping.mode)) { - // Pivot pre-pass: compute mean of all elements, store in s_range_min. - // MAPPING_SUBTRACT uses convert_to_uint8(x - range_min), so centering - // around the mean helps distribute values more evenly across bins. 
- float local_sum = 0.0f; - for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - local_sum += vortex_to_float(input[idx + row_start]); - } - // Warp-level reduction - for (int offset = 16; offset > 0; offset >>= 1) { - local_sum += __shfl_xor_sync(0xFFFFFFFF, local_sum, offset); - } - __shared__ float s_warp_sums[32]; - int warp_id = tx >> 5, lane_id = tx & 31; - if (lane_id == 0) s_warp_sums[warp_id] = local_sum; - __syncthreads(); - if (tx < (BLOCK_SIZE >> 5)) { - local_sum = s_warp_sums[tx]; - for (int offset = 16; offset > 0; offset >>= 1) { - local_sum += __shfl_xor_sync(0xFFFFFFFF, local_sum, offset); - } - if (tx == 0) { - s_range_min = local_sum / float(length); // mean as pivot - s_range_inv_range = 0.0f; - } - } - __syncthreads(); - } else if (needs_tail_window(mapping.mode)) { - // Adaptive tail-window pre-pass: estimate tau_low = Q(1 - rho*k/n) - // and local_max via a sampled quantile estimator. All 256 coarse bins - // are then allocated to [tau_low, local_max]; scores below tau_low - // collapse into bin 0 via linear_map_to_uint8 clamping. - constexpr int MAX_SAMPLES = 1024; - __shared__ float s_samples[MAX_SAMPLES]; - __shared__ int s_sample_count; - - if (tx == 0) s_sample_count = 0; - __syncthreads(); - - // Compute sampling stride so we collect ~MAX_SAMPLES from the segment - const int desired_stride = (length + MAX_SAMPLES - 1) / MAX_SAMPLES; - const int sample_stride = max(desired_stride, 1); - - // Each thread samples elements and finds local_max simultaneously - float local_max = -__FLT_MAX__; - for (int idx = tx * sample_stride; idx < length; idx += BLOCK_SIZE * sample_stride) { - float val = vortex_to_float(input[idx + row_start]); - local_max = fmaxf(local_max, val); - int slot = ::atomicAdd(&s_sample_count, 1); - if (slot < MAX_SAMPLES) { - s_samples[slot] = val; - } - } + // Stage 1: LUT/QUANTILE do a shared-memory lookup, everything else + // applies the element-wise transform then buckets via convert_to_uint8. 
+ for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const float raw = vortex_to_float(input[idx + row_start]); + const auto bin = compute_stage1_bin(raw, mapping, s_mapping_lut, s_mapping_quantiles); + ::atomicAdd(&f_histogram[bin], 1); + } + __syncthreads(); - // Reduce local_max across block - for (int offset = 16; offset > 0; offset >>= 1) - local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); - __shared__ float s_warp_maxs_tw[32]; - { - int warp_id = tx >> 5, lane_id = tx & 31; - if (lane_id == 0) s_warp_maxs_tw[warp_id] = local_max; - } - __syncthreads(); - if (tx < (BLOCK_SIZE >> 5)) { - local_max = s_warp_maxs_tw[tx]; - for (int offset = 16; offset > 0; offset >>= 1) - local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); - if (tx == 0) s_warp_maxs_tw[0] = local_max; - } - __syncthreads(); - local_max = s_warp_maxs_tw[0]; - - int nsamp = min(s_sample_count, MAX_SAMPLES); - - // Simple odd-even transposition sort on the sample buffer. - // nsamp <= 1024, and we have 1024 threads, so each thread - // handles one element. O(nsamp) parallel rounds suffice. - __syncthreads(); - if (nsamp >= 2) { - for (int pass = 0; pass < nsamp; ++pass) { - // Even phase: compare (0,1), (2,3), ... - if (tx * 2 + 1 < nsamp) { - int i = tx * 2; - if (s_samples[i] > s_samples[i + 1]) { - float tmp = s_samples[i]; - s_samples[i] = s_samples[i + 1]; - s_samples[i + 1] = tmp; - } - } - __syncthreads(); - // Odd phase: compare (1,2), (3,4), ... 
- if (tx * 2 + 2 < nsamp) { - int i = tx * 2 + 1; - if (s_samples[i] > s_samples[i + 1]) { - float tmp = s_samples[i]; - s_samples[i] = s_samples[i + 1]; - s_samples[i + 1] = tmp; - } - } - __syncthreads(); - } + const auto run_cumsum = [&] { +#pragma unroll 8 + for (int i = 0; i < 8; ++i) { + static_assert(1 << 8 == RADIX); + if (C10_LIKELY(tx < RADIX)) { + const auto j = 1 << i; + const auto k = i & 1; + auto value = f_histogram_buf[k][tx]; + if (tx < RADIX - j) { + value += f_histogram_buf[k][tx + j]; } + f_histogram_buf[k ^ 1][tx] = value; + } + __syncthreads(); + } + }; - // Estimate tau_low = Q(1 - rho * k / n) - if (tx == 0) { - float rho = mapping.power_exp; // reused as tail expansion factor - if (rho <= 0.0f) rho = 4.0f; - int k = (mapping.target_k > 0) ? mapping.target_k : target_k; - float frac = 1.0f - rho * float(k) / float(length); - frac = fmaxf(frac, 0.0f); // clamp: never go below rank 0 - - float tau_low; - if (nsamp < 4 || frac <= 0.0f) { - // Too few samples or the tail covers everything: full range - tau_low = -__FLT_MAX__; - } else { - float fidx = frac * float(nsamp - 1); - int lo = __float2int_rd(fidx); - lo = min(max(lo, 0), nsamp - 2); - float t = fidx - float(lo); - tau_low = s_samples[lo] * (1.0f - t) + s_samples[lo + 1] * t; - } + run_cumsum(); + if (tx < RADIX && f_histogram[tx] > topk && f_histogram[tx + 1] <= topk) { + f_threshold_bin_id = tx; + f_num_input[0] = 0; + f_counter = 0; + } + __syncthreads(); - // Fallback: if tau_low >= local_max, use full-range linear mapping - if (tau_low >= local_max) { - // Find the actual minimum from sorted samples - tau_low = (nsamp > 0) ? s_samples[0] : local_max; - } + const auto threshold_bin = f_threshold_bin_id; + topk -= f_histogram[threshold_bin + 1]; - float range = local_max - tau_low; - s_range_min = tau_low; - s_range_inv_range = (range > 1e-10f) ? 
255.0f / range : 0.0f; - } - __syncthreads(); - } else if (needs_topk_window(mapping.mode)) { - // Topk-window pre-pass with streaming variance heuristic. - // tau_low = max - rho * sigma * sqrt(2 * log(n/k)) - float local_max = -__FLT_MAX__; - float local_sum = 0.0f, local_sum_sq = 0.0f; - for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - float val = vortex_to_float(input[idx + row_start]); - local_max = fmaxf(local_max, val); - local_sum += val; - local_sum_sq += val * val; - } - for (int offset = 16; offset > 0; offset >>= 1) { - local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); - local_sum += __shfl_xor_sync(0xFFFFFFFF, local_sum, offset); - local_sum_sq += __shfl_xor_sync(0xFFFFFFFF, local_sum_sq, offset); - } - __shared__ float s_warp_maxs_tw2[32], s_warp_sums_tw2[32], s_warp_sq_tw2[32]; - { - int warp_id = tx >> 5, lane_id = tx & 31; - if (lane_id == 0) { - s_warp_maxs_tw2[warp_id] = local_max; - s_warp_sums_tw2[warp_id] = local_sum; - s_warp_sq_tw2[warp_id] = local_sum_sq; - } - } - __syncthreads(); - if (tx < (BLOCK_SIZE >> 5)) { - local_max = s_warp_maxs_tw2[tx]; - local_sum = s_warp_sums_tw2[tx]; - local_sum_sq = s_warp_sq_tw2[tx]; - for (int offset = 16; offset > 0; offset >>= 1) { - local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); - local_sum += __shfl_xor_sync(0xFFFFFFFF, local_sum, offset); - local_sum_sq += __shfl_xor_sync(0xFFFFFFFF, local_sum_sq, offset); - } - if (tx == 0) { - float rho = mapping.power_exp; - if (rho <= 0.0f) rho = 4.0f; - int k = (mapping.target_k > 0) ? mapping.target_k : target_k; - float n = float(length); - float mean = local_sum / n; - float var = local_sum_sq / n - mean * mean; - float sigma = (var > 0.0f) ? 
sqrtf(var) : 0.0f; - float ratio = n / fmaxf(float(k), 1.0f); - float z = sqrtf(2.0f * __logf(fmaxf(ratio, 1.0f))); - float tau_low = local_max - rho * sigma * z; - if (tau_low >= local_max) tau_low = local_max - 1.0f; - float range = local_max - tau_low; - s_range_min = tau_low; - s_range_inv_range = (range > 1e-10f) ? 255.0f / range : 0.0f; - } - } - __syncthreads(); - } else { - if (tx == 0) { s_range_min = 0.0f; s_range_inv_range = 0.0f; } - __syncthreads(); + if (topk == 0) { + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const float raw = vortex_to_float(input[idx + row_start]); + const auto bin = static_cast( + compute_stage1_bin(raw, mapping, s_mapping_lut, s_mapping_quantiles)); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&f_counter, 1); + index[pos] = idx; + } } - - // Stage 1: 8-bit coarse histogram (with optional mapping) - // Bin cache: store computed bins in vh_input_idx[1] (reinterpreted as uint8_t*) - // to avoid recomputing mapped_convert_to_uint8 in the route/filter pass. - // vh_input_idx[1] is unused until Stage 2 double-buffering starts after route. 
- constexpr int BIN_CACHE_CAPACITY = SMEM_INPUT_SIZE * static_cast(sizeof(int)); // uint8 entries - uint8_t* bin_cache = reinterpret_cast(vh_input_idx[1]); - const bool use_bin_cache = (length <= BIN_CACHE_CAPACITY); - - if (tx < RADIX + 1) vh_histogram[tx] = 0; + __syncthreads(); + return; + } else { + __syncthreads(); + if (tx < RADIX + 1) f_histogram[tx] = 0; __syncthreads(); for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - const auto bin = mapped_convert_to_uint8( - vortex_to_float(input[idx + row_start]), - mapping, s_mapping_lut, s_mapping_quantiles, - s_range_min, s_range_inv_range); - ::atomicAdd(&vh_histogram[bin], 1); - if (use_bin_cache) { - bin_cache[idx] = bin; + const float raw = vortex_to_float(input[idx + row_start]); + const float remapped = apply_transform(raw, mapping); + const auto bin = static_cast( + compute_stage1_bin(raw, mapping, s_mapping_lut, s_mapping_quantiles)); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&f_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + const auto pos = ::atomicAdd(&f_num_input[0], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + f_input_idx[0][pos] = idx; + const auto b32 = convert_to_uint32(remapped); + const auto sub_bin = (b32 >> 24) & 0xFF; + ::atomicAdd(&f_histogram[sub_bin], 1); } + } } __syncthreads(); + } - const auto run_cumsum = [&] { -#pragma unroll 8 - for (int i = 0; i < 8; ++i) { - static_assert(1 << 8 == RADIX); - if (C10_LIKELY(tx < RADIX)) { - const auto j = 1 << i; - const auto k = i & 1; - auto value = vh_histogram_buf[k][tx]; - if (tx < RADIX - j) { - value += vh_histogram_buf[k][tx + j]; - } - vh_histogram_buf[k ^ 1][tx] = value; - } - __syncthreads(); - } - }; + // stage 2: refine on raw bits of the remapped value +#pragma unroll 4 + for (int round = 0; round < 4; ++round) { + __shared__ int f_last_remain; + const auto r_idx = round % 2; + + const auto _raw_num_input = f_num_input[r_idx]; + const auto num_input = (_raw_num_input < 
int(SMEM_INPUT_SIZE)) ? _raw_num_input : int(SMEM_INPUT_SIZE); run_cumsum(); - if (tx < RADIX && vh_histogram[tx] > topk && vh_histogram[tx + 1] <= topk) { - vh_threshold_bin_id = tx; - vh_num_input[0] = 0; - vh_counter = 0; + if (tx < RADIX && f_histogram[tx] > topk && f_histogram[tx + 1] <= topk) { + f_threshold_bin_id = tx; + f_num_input[r_idx ^ 1] = 0; + f_last_remain = topk - f_histogram[tx + 1]; } __syncthreads(); - const auto threshold_bin = vh_threshold_bin_id; - topk -= vh_histogram[threshold_bin + 1]; - - if (WriteCounters && tx == 0 && counters) { - counters[COUNTER_THRESHOLD_BIN] = threshold_bin; - counters[COUNTER_REMAINING_K] = topk; - } + const auto threshold_bin = f_threshold_bin_id; + topk -= f_histogram[threshold_bin + 1]; if (topk == 0) { - for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - int bin; - if (use_bin_cache) { - bin = static_cast(bin_cache[idx]); - } else { - bin = static_cast( - mapped_convert_to_uint8( - vortex_to_float(input[idx + row_start]), - mapping, s_mapping_lut, s_mapping_quantiles, - s_range_min, s_range_inv_range)); - } - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&vh_counter, 1); - index[pos] = idx; - } - } - __syncthreads(); - if (WriteCounters && tx == 0 && counters) { - counters[COUNTER_NUM_ABOVE] = vh_counter; - counters[COUNTER_NUM_EQUAL] = 0; - counters[COUNTER_REFINE_ROUNDS] = 0; - counters[COUNTER_STAGE2_INPUT] = 0; + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = f_input_idx[r_idx][i]; + const auto offset = 24 - round * 8; + const float raw = vortex_to_float(input[idx + row_start]); + const float remapped = apply_transform(raw, mapping); + const auto bin = (convert_to_uint32(remapped) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&f_counter, 1); + index[pos] = idx; } - return; + } + __syncthreads(); + break; } else { - __syncthreads(); - if (tx < RADIX + 1) vh_histogram[tx] = 0; - __syncthreads(); - - for (int idx = tx; idx < length; idx 
+= BLOCK_SIZE) { - const auto raw_input = vortex_to_float(input[idx + row_start]); - int bin; - if (use_bin_cache) { - bin = static_cast(bin_cache[idx]); - } else { - bin = static_cast( - mapped_convert_to_uint8(raw_input, mapping, - s_mapping_lut, s_mapping_quantiles, - s_range_min, s_range_inv_range)); - } - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&vh_counter, 1); - index[pos] = idx; - } else if (bin == threshold_bin) { - const auto pos = ::atomicAdd(&vh_num_input[0], 1); - if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { - vh_input_idx[0][pos] = idx; - const auto b32 = convert_to_uint32(raw_input); - const auto sub_bin = (b32 >> 24) & 0xFF; - ::atomicAdd(&vh_histogram[sub_bin], 1); - } - } - } - __syncthreads(); - if (WriteCounters && tx == 0 && counters) { - counters[COUNTER_NUM_ABOVE] = vh_counter; - counters[COUNTER_NUM_EQUAL] = vh_num_input[0]; - counters[COUNTER_STAGE2_INPUT] = vh_num_input[0]; - } - if (StopAfterStage1) return; - } - - // Stage 2: refine with 8-bit radix passes (unchanged — uses raw float bits) - if constexpr (WriteCounters) { - // Default: all 4 rounds used; overwritten at break if resolved early - if (tx == 0 && counters) counters[COUNTER_REFINE_ROUNDS] = 4; - } -#pragma unroll 4 - for (int round = 0; round < 4; ++round) { - __shared__ int vh_last_remain; - const auto r_idx = round % 2; - - const auto _raw_num_input = vh_num_input[r_idx]; - const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) - ? 
_raw_num_input - : int(SMEM_INPUT_SIZE); - - run_cumsum(); - if (tx < RADIX && vh_histogram[tx] > topk && vh_histogram[tx + 1] <= topk) { - vh_threshold_bin_id = tx; - vh_num_input[r_idx ^ 1] = 0; - vh_last_remain = topk - vh_histogram[tx + 1]; - } - __syncthreads(); - - const auto threshold_bin = vh_threshold_bin_id; - topk -= vh_histogram[threshold_bin + 1]; - - if (topk == 0) { - for (int i = tx; i < num_input; i += BLOCK_SIZE) { - const auto idx = vh_input_idx[r_idx][i]; - const auto offset = 24 - round * 8; - const auto bin = (convert_to_uint32( - vortex_to_float(input[idx + row_start])) >> offset) & 0xFF; - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&vh_counter, 1); - index[pos] = idx; - } - } - __syncthreads(); - if constexpr (WriteCounters) { - if (tx == 0 && counters) { - counters[COUNTER_REFINE_ROUNDS] = round + 1; - } + __syncthreads(); + if (tx < RADIX + 1) f_histogram[tx] = 0; + __syncthreads(); + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = f_input_idx[r_idx][i]; + const float raw = vortex_to_float(input[idx + row_start]); + const float remapped = apply_transform(raw, mapping); + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(remapped) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&f_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + if (round == 3) { + const auto pos = ::atomicAdd(&f_last_remain, -1); + if (pos > 0) { + index[target_k - pos] = idx; } - break; - } else { - __syncthreads(); - if (tx < RADIX + 1) vh_histogram[tx] = 0; - __syncthreads(); - for (int i = tx; i < num_input; i += BLOCK_SIZE) { - const auto idx = vh_input_idx[r_idx][i]; - const auto raw_input = vortex_to_float(input[idx + row_start]); - const auto offset = 24 - round * 8; - const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF; - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&vh_counter, 1); - index[pos] = idx; - } else if (bin == 
threshold_bin) { - if (round == 3) { - const auto pos = ::atomicAdd(&vh_last_remain, -1); - if (pos > 0) { - index[target_k - pos] = idx; - } - } else { - const auto pos = ::atomicAdd(&vh_num_input[r_idx ^ 1], 1); - if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { - vh_input_idx[r_idx ^ 1][pos] = idx; - const auto b32 = convert_to_uint32(raw_input); - const auto sub_bin = (b32 >> (offset - 8)) & 0xFF; - ::atomicAdd(&vh_histogram[sub_bin], 1); - } - } - } + } else { + const auto pos = ::atomicAdd(&f_num_input[r_idx ^ 1], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + f_input_idx[r_idx ^ 1][pos] = idx; + const auto b32 = convert_to_uint32(remapped); + const auto sub_bin = (b32 >> (offset - 8)) & 0xFF; + ::atomicAdd(&f_histogram[sub_bin], 1); } - __syncthreads(); + } } + } + __syncthreads(); } + } } -// Wrapper kernel: one CUDA block per batch*head segment +// Wrapper kernels: one CUDA block per (batch*head) segment. + template __global__ __launch_bounds__(kThreadsPerBlock) -void TopKOutput_Kernel( +void TopKOutput_Clean_Kernel( const ScoreT* __restrict__ score, const int* __restrict__ dense_kv_indptr, const int* __restrict__ sparse_kv_indptr, @@ -1127,37 +830,34 @@ void TopKOutput_Kernel( int* __restrict__ sparse_kv_indices, const int topk_val, const int page_reserved_bos, - const int page_reserved_eos, - const TopKMappingParams mapping) + const int page_reserved_eos) { - const int bx = blockIdx.x; + const int bx = blockIdx.x; - const int start = dense_kv_indptr[bx] + page_reserved_bos; - const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; - const int nblk = end - start; - if (nblk <= topk_val) return; + const int start = dense_kv_indptr[bx] + page_reserved_bos; + const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; + const int nblk = end - start; + if (nblk <= topk_val) return; - const ScoreT* __restrict__ score_blk = score + start; - const int* __restrict__ idx_blk = dense_kv_indices + start; - int* __restrict__ out_blk = sparse_kv_indices - + 
sparse_kv_indptr[bx] - + page_reserved_bos; + const ScoreT* __restrict__ score_blk = score + start; + const int* __restrict__ idx_blk = dense_kv_indices + start; + int* __restrict__ out_blk = sparse_kv_indices + + sparse_kv_indptr[bx] + + page_reserved_bos; - __shared__ int s_indices[VORTEX_MAX_TOPK]; - fast_topk_vortex(score_blk, s_indices, 0, nblk, topk_val, mapping); - __syncthreads(); + __shared__ int s_indices[VORTEX_MAX_TOPK]; + fast_topk_clean(score_blk, s_indices, 0, nblk, topk_val); + __syncthreads(); - // Remap position indices -> page indices via dense_kv_indices - const int tx = threadIdx.x; - for (int i = tx; i < topk_val; i += kThreadsPerBlock) { - out_blk[i] = idx_blk[s_indices[i]]; - } + const int tx = threadIdx.x; + for (int i = tx; i < topk_val; i += kThreadsPerBlock) { + out_blk[i] = idx_blk[s_indices[i]]; + } } -// Ori fast-path wrapper: zero mapping overhead, flexible radix -template +template __global__ __launch_bounds__(kThreadsPerBlock) -void TopKOutput_Ori_Kernel( +void TopKOutput_Fused_Kernel( const ScoreT* __restrict__ score, const int* __restrict__ dense_kv_indptr, const int* __restrict__ sparse_kv_indptr, @@ -1165,54 +865,60 @@ void TopKOutput_Ori_Kernel( int* __restrict__ sparse_kv_indices, const int topk_val, const int page_reserved_bos, - const int page_reserved_eos) + const int page_reserved_eos, + const TopKMappingParams mapping) { - const int bx = blockIdx.x; + const int bx = blockIdx.x; - const int start = dense_kv_indptr[bx] + page_reserved_bos; - const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; - const int nblk = end - start; - if (nblk <= topk_val) return; + const int start = dense_kv_indptr[bx] + page_reserved_bos; + const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; + const int nblk = end - start; + if (nblk <= topk_val) return; - const ScoreT* __restrict__ score_blk = score + start; - const int* __restrict__ idx_blk = dense_kv_indices + start; - int* __restrict__ out_blk = sparse_kv_indices - + 
sparse_kv_indptr[bx] - + page_reserved_bos; + const ScoreT* __restrict__ score_blk = score + start; + const int* __restrict__ idx_blk = dense_kv_indices + start; + int* __restrict__ out_blk = sparse_kv_indices + + sparse_kv_indptr[bx] + + page_reserved_bos; - __shared__ int s_indices[VORTEX_MAX_TOPK]; - fast_topk_ori(score_blk, s_indices, 0, nblk, topk_val); - __syncthreads(); + __shared__ int s_indices[VORTEX_MAX_TOPK]; + fast_topk_clean_fused(score_blk, s_indices, 0, nblk, topk_val, mapping); + __syncthreads(); - const int tx = threadIdx.x; - for (int i = tx; i < topk_val; i += kThreadsPerBlock) { - out_blk[i] = idx_blk[s_indices[i]]; - } + const int tx = threadIdx.x; + for (int i = tx; i < topk_val; i += kThreadsPerBlock) { + out_blk[i] = idx_blk[s_indices[i]]; + } } -// Helper: launch TopKOutput_Ori_Kernel with radix_bits dispatch +// Remap-only kernel: applies the element-wise transform to each score +// in the [dense_kv_indptr[b] + reserved_bos, dense_kv_indptr[b+1] - reserved_eos) +// range and writes the result into a float32 output tensor. Used by +// the split-phase benchmark (remap → unmapped topk). 
template -void launch_ori_kernel( - const ScoreT* score, const int* dense_kv_indptr, const int* sparse_kv_indptr, - const int* dense_kv_indices, int* sparse_kv_indices, - int topk_val, int reserved_bos, int reserved_eos, - int radix_bits, dim3 nblks, dim3 nthreads, cudaStream_t stream) +__global__ __launch_bounds__(kThreadsPerBlock) +void TopKRemapOnly_Kernel( + const ScoreT* __restrict__ score, + const int* __restrict__ dense_kv_indptr, + float* __restrict__ remapped, + const int page_reserved_bos, + const int page_reserved_eos, + const TopKMappingParams mapping) { - #define LAUNCH_ORI(BITS) \ - setup_kernel_smem_once, kSmem>(); \ - TopKOutput_Ori_Kernel<<>>( \ - score, dense_kv_indptr, sparse_kv_indptr, dense_kv_indices, sparse_kv_indices, \ - topk_val, reserved_bos, reserved_eos) - switch (radix_bits) { - case 4: LAUNCH_ORI(4); break; - case 5: LAUNCH_ORI(5); break; - case 6: LAUNCH_ORI(6); break; - case 7: LAUNCH_ORI(7); break; - case 9: LAUNCH_ORI(9); break; - case 10: LAUNCH_ORI(10); break; - default: LAUNCH_ORI(8); break; - } - #undef LAUNCH_ORI + const int bx = blockIdx.x; + const int tx = threadIdx.x; + + const int start = dense_kv_indptr[bx] + page_reserved_bos; + const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; + const int nblk = end - start; + if (nblk <= 0) return; + + const ScoreT* __restrict__ score_blk = score + start; + float* __restrict__ remap_blk = remapped + start; + + for (int i = tx; i < nblk; i += kThreadsPerBlock) { + remap_blk[i] = apply_transform(vortex_to_float(score_blk[i]), mapping); + } } } // namespace @@ -1331,9 +1037,68 @@ void fast_topk_transform_ragged_interface( } // ====================================================================== -// Vortex host entry point — same interface as topk_output in topk.cu +// Vortex host entry point — unmapped baseline topk (no remap). +// This is the "original topk kernel" used as the benchmarking baseline. 
// ====================================================================== void topk_output_sglang( + const at::Tensor& x, + const at::Tensor& dense_kv_indptr, + const at::Tensor& sparse_kv_indptr, + const at::Tensor& dense_kv_indices, + at::Tensor& sparse_kv_indices, + const int64_t eff_batch_size, + const int64_t topk_val, + const int64_t reserved_bos, + const int64_t reserved_eos, + const int64_t max_num_pages) +{ + TORCH_CHECK(topk_val <= VORTEX_MAX_TOPK, + "topk_output: topk_val (", topk_val, + ") exceeds VORTEX_MAX_TOPK (", VORTEX_MAX_TOPK, ")"); + + CHECK_CUDA(x); + CHECK_CUDA(dense_kv_indptr); + CHECK_CUDA(sparse_kv_indptr); + CHECK_CUDA(dense_kv_indices); + CHECK_CUDA(sparse_kv_indices); + + dim3 nblks(eff_batch_size); + dim3 nthreads(kThreadsPerBlock); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + + if (x.scalar_type() == at::ScalarType::BFloat16) { + setup_kernel_smem_once, kSmem>(); + TopKOutput_Clean_Kernel<__nv_bfloat16><<>>( + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + topk_val, reserved_bos, reserved_eos); + } else if (x.scalar_type() == at::ScalarType::Float) { + setup_kernel_smem_once, kSmem>(); + TopKOutput_Clean_Kernel<<>>( + x.data_ptr(), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + topk_val, reserved_bos, reserved_eos); + } else { + TORCH_CHECK(false, "topk_output: unsupported dtype ", x.scalar_type()); + } + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, + "topk_output kernel failed: ", ::cudaGetErrorString(result)); +} + +// ====================================================================== +// Fused remap + topk host entry. Applies apply_transform(score, mapping) +// inline inside the Stage-1 histogram build — single kernel launch, +// single pass over the score tensor. 
+// ====================================================================== +void topk_output_sglang_fused( const at::Tensor& x, const at::Tensor& dense_kv_indptr, const at::Tensor& sparse_kv_indptr, @@ -1347,39 +1112,35 @@ void topk_output_sglang( const int64_t mapping_mode, const double mapping_power, std::optional mapping_lut, - std::optional mapping_quantiles, - const bool mapping_noscale, - const int64_t sample_stride, - const int64_t radix_bits) + std::optional mapping_quantiles) { TORCH_CHECK(topk_val <= VORTEX_MAX_TOPK, - "topk_output: topk_val (", topk_val, + "topk_output_sglang_fused: topk_val (", topk_val, ") exceeds VORTEX_MAX_TOPK (", VORTEX_MAX_TOPK, ")"); - TORCH_CHECK(radix_bits >= 4 && radix_bits <= 10, - "topk_output: radix_bits must be 4-10, got ", radix_bits); - // Build mapping params from optional tensors + CHECK_CUDA(x); + CHECK_CUDA(dense_kv_indptr); + CHECK_CUDA(sparse_kv_indptr); + CHECK_CUDA(dense_kv_indices); + CHECK_CUDA(sparse_kv_indices); + TopKMappingParams mapping{}; mapping.mode = static_cast(mapping_mode); mapping.power_exp = static_cast(mapping_power); mapping.lut = nullptr; mapping.quantiles = nullptr; - mapping.noscale = mapping_noscale; - mapping.sample_stride = static_cast(sample_stride); - mapping.target_k = static_cast(topk_val); - if (mapping_lut.has_value()) { const auto& lut = mapping_lut.value(); CHECK_CUDA(lut); TORCH_CHECK(lut.dim() == 1 && lut.size(0) == 256 && lut.scalar_type() == at::ScalarType::Byte, - "mapping_lut must be a 1D uint8 tensor of size 256"); + "mapping_lut must be a 1D uint8 tensor of size 256"); mapping.lut = lut.data_ptr(); } if (mapping_quantiles.has_value()) { const auto& q = mapping_quantiles.value(); CHECK_CUDA(q); TORCH_CHECK(q.dim() == 1 && q.size(0) == 256 && q.scalar_type() == at::ScalarType::Float, - "mapping_quantiles must be a 1D float32 tensor of size 256"); + "mapping_quantiles must be a 1D float32 tensor of size 256"); mapping.quantiles = q.data_ptr(); } @@ -1387,111 +1148,81 @@ void 
topk_output_sglang( dim3 nthreads(kThreadsPerBlock); cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - // Fast path for mode 0 (MAPPING_NONE): use ori kernel with zero mapping overhead - if (mapping_mode == MAPPING_NONE) { - if (x.scalar_type() == at::ScalarType::BFloat16) { - launch_ori_kernel<__nv_bfloat16>( - reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), - dense_kv_indptr.data_ptr(), sparse_kv_indptr.data_ptr(), - dense_kv_indices.data_ptr(), sparse_kv_indices.data_ptr(), - topk_val, reserved_bos, reserved_eos, - radix_bits, nblks, nthreads, stream); - } else if (x.scalar_type() == at::ScalarType::Float) { - launch_ori_kernel( - x.data_ptr(), - dense_kv_indptr.data_ptr(), sparse_kv_indptr.data_ptr(), - dense_kv_indices.data_ptr(), sparse_kv_indices.data_ptr(), - topk_val, reserved_bos, reserved_eos, - radix_bits, nblks, nthreads, stream); - } else { - TORCH_CHECK(false, "topk_output: unsupported dtype ", x.scalar_type()); - } + if (x.scalar_type() == at::ScalarType::BFloat16) { + setup_kernel_smem_once, kSmem>(); + TopKOutput_Fused_Kernel<__nv_bfloat16><<>>( + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + topk_val, reserved_bos, reserved_eos, mapping); + } else if (x.scalar_type() == at::ScalarType::Float) { + setup_kernel_smem_once, kSmem>(); + TopKOutput_Fused_Kernel<<>>( + x.data_ptr(), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + topk_val, reserved_bos, reserved_eos, mapping); } else { - if (x.scalar_type() == at::ScalarType::BFloat16) { - setup_kernel_smem_once, kSmem>(); - TopKOutput_Kernel<__nv_bfloat16><<>>( - reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), - dense_kv_indptr.data_ptr(), - sparse_kv_indptr.data_ptr(), - dense_kv_indices.data_ptr(), - sparse_kv_indices.data_ptr(), - topk_val, - reserved_bos, - reserved_eos, - 
mapping); - } else if (x.scalar_type() == at::ScalarType::Float) { - setup_kernel_smem_once, kSmem>(); - TopKOutput_Kernel<<>>( - x.data_ptr(), - dense_kv_indptr.data_ptr(), - sparse_kv_indptr.data_ptr(), - dense_kv_indices.data_ptr(), - sparse_kv_indices.data_ptr(), - topk_val, - reserved_bos, - reserved_eos, - mapping); - } else { - TORCH_CHECK(false, "topk_output: unsupported dtype ", x.scalar_type()); - } + TORCH_CHECK(false, "topk_output_sglang_fused: unsupported dtype ", x.scalar_type()); } const auto result = cudaGetLastError(); TORCH_CHECK(result == cudaSuccess, - "topk_output kernel failed: ", ::cudaGetErrorString(result)); + "topk_output_sglang_fused kernel failed: ", ::cudaGetErrorString(result)); } // ====================================================================== -// Explicit ori baseline entry point — always uses the ori fast path +// Standalone remap kernel. Writes apply_transform(score) into a +// float32 output buffer without running topk. Used by the split-phase +// benchmark (remap → unmapped topk) to measure each phase independently. 
// ====================================================================== -void topk_output_sglang_ori( +void topk_remap_only( const at::Tensor& x, const at::Tensor& dense_kv_indptr, - const at::Tensor& sparse_kv_indptr, - const at::Tensor& dense_kv_indices, - at::Tensor& sparse_kv_indices, + at::Tensor& remapped, // float32, same numel as x const int64_t eff_batch_size, - const int64_t topk_val, const int64_t reserved_bos, const int64_t reserved_eos, - const int64_t max_num_pages, - const int64_t radix_bits) + const int64_t mapping_mode, + const double mapping_power) { - TORCH_CHECK(topk_val <= VORTEX_MAX_TOPK, - "topk_output_sglang_ori: topk_val (", topk_val, - ") exceeds VORTEX_MAX_TOPK (", VORTEX_MAX_TOPK, ")"); - TORCH_CHECK(radix_bits >= 4 && radix_bits <= 10, - "topk_output_sglang_ori: radix_bits must be 4-10, got ", radix_bits); - CHECK_CUDA(x); CHECK_CUDA(dense_kv_indptr); - CHECK_CUDA(sparse_kv_indptr); - CHECK_CUDA(dense_kv_indices); - CHECK_CUDA(sparse_kv_indices); + CHECK_CUDA(remapped); + TORCH_CHECK(remapped.scalar_type() == at::ScalarType::Float, + "remapped output must be float32"); + + TopKMappingParams mapping{}; + mapping.mode = static_cast(mapping_mode); + mapping.power_exp = static_cast(mapping_power); + mapping.lut = nullptr; + mapping.quantiles = nullptr; dim3 nblks(eff_batch_size); dim3 nthreads(kThreadsPerBlock); cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); if (x.scalar_type() == at::ScalarType::BFloat16) { - launch_ori_kernel<__nv_bfloat16>( + TopKRemapOnly_Kernel<__nv_bfloat16><<>>( reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), - dense_kv_indptr.data_ptr(), sparse_kv_indptr.data_ptr(), - dense_kv_indices.data_ptr(), sparse_kv_indices.data_ptr(), - topk_val, reserved_bos, reserved_eos, - radix_bits, nblks, nthreads, stream); + dense_kv_indptr.data_ptr(), + remapped.data_ptr(), + reserved_bos, reserved_eos, mapping); } else if (x.scalar_type() == at::ScalarType::Float) { - launch_ori_kernel( + 
TopKRemapOnly_Kernel<<>>( x.data_ptr(), - dense_kv_indptr.data_ptr(), sparse_kv_indptr.data_ptr(), - dense_kv_indices.data_ptr(), sparse_kv_indices.data_ptr(), - topk_val, reserved_bos, reserved_eos, - radix_bits, nblks, nthreads, stream); + dense_kv_indptr.data_ptr(), + remapped.data_ptr(), + reserved_bos, reserved_eos, mapping); } else { - TORCH_CHECK(false, "topk_output_sglang_ori: unsupported dtype ", x.scalar_type()); + TORCH_CHECK(false, "topk_remap_only: unsupported dtype ", x.scalar_type()); } const auto result = cudaGetLastError(); TORCH_CHECK(result == cudaSuccess, - "topk_output_sglang_ori kernel failed: ", ::cudaGetErrorString(result)); + "topk_remap_only kernel failed: ", ::cudaGetErrorString(result)); } diff --git a/csrc/topk_sglang_profile.cu b/csrc/topk_sglang_profile.cu index 6aeac4b..adba2d0 100644 --- a/csrc/topk_sglang_profile.cu +++ b/csrc/topk_sglang_profile.cu @@ -98,7 +98,13 @@ __device__ __forceinline__ float vortex_to_float<__nv_bfloat16>(__nv_bfloat16 x) return __bfloat162float(x); } + constexpr int VORTEX_MAX_TOPK = 2048; + +// Diagnostic counters written by the profiling kernel. These kernels are +// NOT used for latency measurements — they intentionally add global-memory +// writes that distort timings. Latency is measured against the clean +// production kernels in topk_sglang.cu. 
constexpr int COUNTER_THRESHOLD_BIN = 0; constexpr int COUNTER_NUM_ABOVE = 1; constexpr int COUNTER_NUM_EQUAL = 2; @@ -109,1025 +115,403 @@ constexpr int NUM_TOPK_COUNTERS = 6; #include "topk_mapping.cuh" -// - mapping: configurable value-remapping for Stage 1 bin assignment -template -__device__ void fast_topk_vortex( +template +void setup_kernel_smem_once() { + [[maybe_unused]] + static const auto result = [] { +#ifdef USE_ROCM + return ::cudaFuncSetAttribute( + reinterpret_cast(f), ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); +#else + return ::cudaFuncSetAttribute(f, ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); +#endif + }(); + TORCH_CHECK(result == cudaSuccess, "set_up_kernel_once failed:", ::cudaGetErrorString(result)); +} + +// ====================================================================== +// Profiling variant of fast_topk_clean_fused that writes diagnostic +// counters at the end of Stage 1 and at each Stage 2 early-exit. +// Shape / semantics identical to the production kernel, with one extra +// global-memory write pass at the end of each stage. Do not use for +// latency measurements. 
+// ====================================================================== +template +__device__ void fast_topk_profile( const ScoreT* __restrict__ input, int* __restrict__ index, int row_start, int length, int target_k, - const TopKMappingParams& mapping, - int* counters = nullptr) + const TopKMappingParams mapping, + int* __restrict__ counters) // [NUM_TOPK_COUNTERS] { - int topk = target_k; - constexpr auto BLOCK_SIZE = 1024; - constexpr auto RADIX = 256; - constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); - - alignas(128) __shared__ int vh_histogram_buf[2][RADIX + 128]; - alignas(128) __shared__ int vh_counter; - alignas(128) __shared__ int vh_threshold_bin_id; - alignas(128) __shared__ int vh_num_input[2]; + int topk = target_k; + constexpr auto BLOCK_SIZE = 1024; + constexpr auto RADIX = 256; + constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); - // Shared memory for mapping LUT / quantiles (loaded once per block) - __shared__ uint8_t s_mapping_lut[256]; - __shared__ float s_mapping_quantiles[256]; + alignas(128) __shared__ int p_histogram_buf[2][RADIX + 128]; + alignas(128) __shared__ int p_counter; + alignas(128) __shared__ int p_threshold_bin_id; + alignas(128) __shared__ int p_num_input[2]; - // Auto-range for transform modes (3/4/6/7) - __shared__ float s_range_min, s_range_inv_range; + __shared__ uint8_t s_mapping_lut[256]; + __shared__ float s_mapping_quantiles[256]; - auto& vh_histogram = vh_histogram_buf[0]; - extern __shared__ int vh_input_idx[][SMEM_INPUT_SIZE]; + auto& p_histogram = p_histogram_buf[0]; + extern __shared__ int p_input_idx[][SMEM_INPUT_SIZE]; - const int tx = threadIdx.x; + const int tx = threadIdx.x; - // Load mapping tables into shared memory if needed - if (mapping.mode == MAPPING_LUT_CDF && mapping.lut != nullptr) { - if (tx < 256) s_mapping_lut[tx] = mapping.lut[tx]; - __syncthreads(); - } - if (mapping.mode == MAPPING_QUANTILE && mapping.quantiles != nullptr) { - if (tx < 256) s_mapping_quantiles[tx] = 
mapping.quantiles[tx]; - __syncthreads(); - } - - // Pre-pass: compute per-block min/max of transformed values for linear bucketing. - // sample_stride > 1 reduces pre-pass cost by scanning every Nth element; - // the approximated range may miss extreme outliers but Stage 2 uses raw - // float bits for exact ordering, so correctness is preserved. - if (needs_auto_range(mapping.mode) && !mapping.noscale) { - const int stride = (mapping.sample_stride > 1) ? mapping.sample_stride : 1; - float local_min = __FLT_MAX__, local_max = -__FLT_MAX__; - for (int idx = tx * stride; idx < length; idx += BLOCK_SIZE * stride) { - float val = apply_transform(vortex_to_float(input[idx + row_start]), mapping); - local_min = fminf(local_min, val); - local_max = fmaxf(local_max, val); - } - // Warp-level reduction - for (int offset = 16; offset > 0; offset >>= 1) { - local_min = fminf(local_min, __shfl_xor_sync(0xFFFFFFFF, local_min, offset)); - local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); - } - // Cross-warp reduction via shared memory - __shared__ float s_warp_mins[32], s_warp_maxs[32]; - int warp_id = tx >> 5, lane_id = tx & 31; - if (lane_id == 0) { s_warp_mins[warp_id] = local_min; s_warp_maxs[warp_id] = local_max; } - __syncthreads(); - if (tx < (BLOCK_SIZE >> 5)) { - local_min = s_warp_mins[tx]; local_max = s_warp_maxs[tx]; - for (int offset = 16; offset > 0; offset >>= 1) { - local_min = fminf(local_min, __shfl_xor_sync(0xFFFFFFFF, local_min, offset)); - local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); - } - if (tx == 0) { - s_range_min = local_min; - float range = local_max - local_min; - s_range_inv_range = (range > 0.0f) ? 255.0f / range : 0.0f; - } - } - __syncthreads(); - } else if (needs_pivot(mapping.mode)) { - // Pivot pre-pass: compute mean of all elements, store in s_range_min. 
- // MAPPING_SUBTRACT uses convert_to_uint8(x - range_min), so centering - // around the mean helps distribute values more evenly across bins. - float local_sum = 0.0f; - for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - local_sum += vortex_to_float(input[idx + row_start]); - } - // Warp-level reduction - for (int offset = 16; offset > 0; offset >>= 1) { - local_sum += __shfl_xor_sync(0xFFFFFFFF, local_sum, offset); - } - __shared__ float s_warp_sums[32]; - int warp_id = tx >> 5, lane_id = tx & 31; - if (lane_id == 0) s_warp_sums[warp_id] = local_sum; - __syncthreads(); - if (tx < (BLOCK_SIZE >> 5)) { - local_sum = s_warp_sums[tx]; - for (int offset = 16; offset > 0; offset >>= 1) { - local_sum += __shfl_xor_sync(0xFFFFFFFF, local_sum, offset); - } - if (tx == 0) { - s_range_min = local_sum / float(length); // mean as pivot - s_range_inv_range = 0.0f; - } - } - __syncthreads(); - } else if (needs_tail_window(mapping.mode)) { - // Adaptive tail-window pre-pass: estimate tau_low = Q(1 - rho*k/n) - // and local_max via a sampled quantile estimator. All 256 coarse bins - // are then allocated to [tau_low, local_max]; scores below tau_low - // collapse into bin 0 via linear_map_to_uint8 clamping. 
- constexpr int MAX_SAMPLES = 1024; - __shared__ float s_samples[MAX_SAMPLES]; - __shared__ int s_sample_count; - - if (tx == 0) s_sample_count = 0; - __syncthreads(); - - // Compute sampling stride so we collect ~MAX_SAMPLES from the segment - const int desired_stride = (length + MAX_SAMPLES - 1) / MAX_SAMPLES; - const int sample_stride = max(desired_stride, 1); - - // Each thread samples elements and finds local_max simultaneously - float local_max = -__FLT_MAX__; - for (int idx = tx * sample_stride; idx < length; idx += BLOCK_SIZE * sample_stride) { - float val = vortex_to_float(input[idx + row_start]); - local_max = fmaxf(local_max, val); - int slot = ::atomicAdd(&s_sample_count, 1); - if (slot < MAX_SAMPLES) { - s_samples[slot] = val; - } - } + if (mapping.mode == MAPPING_LUT_CDF && mapping.lut != nullptr) { + if (tx < 256) s_mapping_lut[tx] = mapping.lut[tx]; + __syncthreads(); + } + if (mapping.mode == MAPPING_QUANTILE && mapping.quantiles != nullptr) { + if (tx < 256) s_mapping_quantiles[tx] = mapping.quantiles[tx]; + __syncthreads(); + } - // Reduce local_max across block - for (int offset = 16; offset > 0; offset >>= 1) - local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); - __shared__ float s_warp_maxs_tw[32]; - { - int warp_id = tx >> 5, lane_id = tx & 31; - if (lane_id == 0) s_warp_maxs_tw[warp_id] = local_max; - } - __syncthreads(); - if (tx < (BLOCK_SIZE >> 5)) { - local_max = s_warp_maxs_tw[tx]; - for (int offset = 16; offset > 0; offset >>= 1) - local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); - if (tx == 0) s_warp_maxs_tw[0] = local_max; - } - __syncthreads(); - local_max = s_warp_maxs_tw[0]; + if (tx < RADIX + 1) p_histogram[tx] = 0; + __syncthreads(); - int nsamp = min(s_sample_count, MAX_SAMPLES); + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const float raw = vortex_to_float(input[idx + row_start]); + const auto bin = compute_stage1_bin(raw, mapping, s_mapping_lut, 
s_mapping_quantiles); + ::atomicAdd(&p_histogram[bin], 1); + } + __syncthreads(); - // Simple odd-even transposition sort on the sample buffer. - // nsamp <= 1024, and we have 1024 threads, so each thread - // handles one element. O(nsamp) parallel rounds suffice. - __syncthreads(); - if (nsamp >= 2) { - for (int pass = 0; pass < nsamp; ++pass) { - // Even phase: compare (0,1), (2,3), ... - if (tx * 2 + 1 < nsamp) { - int i = tx * 2; - if (s_samples[i] > s_samples[i + 1]) { - float tmp = s_samples[i]; - s_samples[i] = s_samples[i + 1]; - s_samples[i + 1] = tmp; - } - } - __syncthreads(); - // Odd phase: compare (1,2), (3,4), ... - if (tx * 2 + 2 < nsamp) { - int i = tx * 2 + 1; - if (s_samples[i] > s_samples[i + 1]) { - float tmp = s_samples[i]; - s_samples[i] = s_samples[i + 1]; - s_samples[i + 1] = tmp; - } - } - __syncthreads(); - } - } + const auto run_cumsum = [&] { +#pragma unroll 8 + for (int i = 0; i < 8; ++i) { + static_assert(1 << 8 == RADIX); + if (C10_LIKELY(tx < RADIX)) { + const auto j = 1 << i; + const auto k = i & 1; + auto value = p_histogram_buf[k][tx]; + if (tx < RADIX - j) { + value += p_histogram_buf[k][tx + j]; + } + p_histogram_buf[k ^ 1][tx] = value; + } + __syncthreads(); + } + }; - // Estimate tau_low = Q(1 - rho * k / n) - if (tx == 0) { - float rho = mapping.power_exp; // reused as tail expansion factor - if (rho <= 0.0f) rho = 4.0f; - int k = (mapping.target_k > 0) ? 
mapping.target_k : target_k; - float frac = 1.0f - rho * float(k) / float(length); - frac = fmaxf(frac, 0.0f); // clamp: never go below rank 0 + run_cumsum(); + if (tx < RADIX && p_histogram[tx] > topk && p_histogram[tx + 1] <= topk) { + p_threshold_bin_id = tx; + p_num_input[0] = 0; + p_counter = 0; + } + __syncthreads(); - float tau_low; - if (nsamp < 4 || frac <= 0.0f) { - // Too few samples or the tail covers everything: full range - tau_low = -__FLT_MAX__; - } else { - float fidx = frac * float(nsamp - 1); - int lo = __float2int_rd(fidx); - lo = min(max(lo, 0), nsamp - 2); - float t = fidx - float(lo); - tau_low = s_samples[lo] * (1.0f - t) + s_samples[lo + 1] * t; - } + const int threshold_bin_0 = p_threshold_bin_id; + const int threshold_bin_size = p_histogram[threshold_bin_0]; // pre-reset count + topk -= p_histogram[threshold_bin_0 + 1]; - // Fallback: if tau_low >= local_max, use full-range linear mapping - if (tau_low >= local_max) { - // Find the actual minimum from sorted samples - tau_low = (nsamp > 0) ? s_samples[0] : local_max; - } + if (tx == 0 && counters) { + counters[COUNTER_THRESHOLD_BIN] = threshold_bin_0; + counters[COUNTER_NUM_EQUAL] = threshold_bin_size; + counters[COUNTER_REMAINING_K] = topk; + } - float range = local_max - tau_low; - s_range_min = tau_low; - s_range_inv_range = (range > 1e-10f) ? 255.0f / range : 0.0f; - } - __syncthreads(); - } else if (needs_topk_window(mapping.mode)) { - // Topk-window pre-pass with streaming variance heuristic. 
- float local_max = -__FLT_MAX__; - float local_sum = 0.0f, local_sum_sq = 0.0f; - for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - float val = vortex_to_float(input[idx + row_start]); - local_max = fmaxf(local_max, val); - local_sum += val; - local_sum_sq += val * val; - } - for (int offset = 16; offset > 0; offset >>= 1) { - local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); - local_sum += __shfl_xor_sync(0xFFFFFFFF, local_sum, offset); - local_sum_sq += __shfl_xor_sync(0xFFFFFFFF, local_sum_sq, offset); - } - __shared__ float s_warp_maxs_tw2[32], s_warp_sums_tw2[32], s_warp_sq_tw2[32]; - { - int warp_id = tx >> 5, lane_id = tx & 31; - if (lane_id == 0) { - s_warp_maxs_tw2[warp_id] = local_max; - s_warp_sums_tw2[warp_id] = local_sum; - s_warp_sq_tw2[warp_id] = local_sum_sq; - } - } - __syncthreads(); - if (tx < (BLOCK_SIZE >> 5)) { - local_max = s_warp_maxs_tw2[tx]; - local_sum = s_warp_sums_tw2[tx]; - local_sum_sq = s_warp_sq_tw2[tx]; - for (int offset = 16; offset > 0; offset >>= 1) { - local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); - local_sum += __shfl_xor_sync(0xFFFFFFFF, local_sum, offset); - local_sum_sq += __shfl_xor_sync(0xFFFFFFFF, local_sum_sq, offset); - } - if (tx == 0) { - float rho = mapping.power_exp; - if (rho <= 0.0f) rho = 4.0f; - int k = (mapping.target_k > 0) ? mapping.target_k : target_k; - float n = float(length); - float mean = local_sum / n; - float var = local_sum_sq / n - mean * mean; - float sigma = (var > 0.0f) ? sqrtf(var) : 0.0f; - float ratio = n / fmaxf(float(k), 1.0f); - float z = sqrtf(2.0f * __logf(fmaxf(ratio, 1.0f))); - float tau_low = local_max - rho * sigma * z; - if (tau_low >= local_max) tau_low = local_max - 1.0f; - float range = local_max - tau_low; - s_range_min = tau_low; - s_range_inv_range = (range > 1e-10f) ? 
255.0f / range : 0.0f; - } - } - __syncthreads(); - } else { - if (tx == 0) { s_range_min = 0.0f; s_range_inv_range = 0.0f; } - __syncthreads(); + if (topk == 0) { + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const float raw = vortex_to_float(input[idx + row_start]); + const auto bin = static_cast( + compute_stage1_bin(raw, mapping, s_mapping_lut, s_mapping_quantiles)); + if (bin > threshold_bin_0) { + const auto pos = ::atomicAdd(&p_counter, 1); + index[pos] = idx; + } } - - // Stage 1: 8-bit coarse histogram (with optional mapping) - // Bin cache: store computed bins in vh_input_idx[1] (reinterpreted as uint8_t*) - // to avoid recomputing mapped_convert_to_uint8 in the route/filter pass. - // vh_input_idx[1] is unused until Stage 2 double-buffering starts after route. - constexpr int BIN_CACHE_CAPACITY = SMEM_INPUT_SIZE * static_cast(sizeof(int)); // uint8 entries - uint8_t* bin_cache = reinterpret_cast(vh_input_idx[1]); - const bool use_bin_cache = (length <= BIN_CACHE_CAPACITY); - - if (tx < RADIX + 1) vh_histogram[tx] = 0; + __syncthreads(); + if (tx == 0 && counters) { + counters[COUNTER_NUM_ABOVE] = p_counter; + counters[COUNTER_REFINE_ROUNDS] = 0; + counters[COUNTER_STAGE2_INPUT] = 0; + } + return; + } else { + __syncthreads(); + if (tx < RADIX + 1) p_histogram[tx] = 0; __syncthreads(); for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - const auto bin = mapped_convert_to_uint8( - vortex_to_float(input[idx + row_start]), - mapping, s_mapping_lut, s_mapping_quantiles, - s_range_min, s_range_inv_range); - ::atomicAdd(&vh_histogram[bin], 1); - if (use_bin_cache) { - bin_cache[idx] = bin; - } + const float raw = vortex_to_float(input[idx + row_start]); + const float remapped = apply_transform(raw, mapping); + const auto bin = static_cast( + compute_stage1_bin(raw, mapping, s_mapping_lut, s_mapping_quantiles)); + if (bin > threshold_bin_0) { + const auto pos = ::atomicAdd(&p_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin_0) { 
+ const auto pos = ::atomicAdd(&p_num_input[0], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + p_input_idx[0][pos] = idx; + const auto b32 = convert_to_uint32(remapped); + const auto sub_bin = (b32 >> 24) & 0xFF; + ::atomicAdd(&p_histogram[sub_bin], 1); + } + } } __syncthreads(); + if (tx == 0 && counters) { + counters[COUNTER_NUM_ABOVE] = p_counter; + counters[COUNTER_STAGE2_INPUT] = p_num_input[0]; + } + } - const auto run_cumsum = [&] { -#pragma unroll 8 - for (int i = 0; i < 8; ++i) { - static_assert(1 << 8 == RADIX); - if (C10_LIKELY(tx < RADIX)) { - const auto j = 1 << i; - const auto k = i & 1; - auto value = vh_histogram_buf[k][tx]; - if (tx < RADIX - j) { - value += vh_histogram_buf[k][tx + j]; - } - vh_histogram_buf[k ^ 1][tx] = value; - } - __syncthreads(); - } - }; + // Stage 2 refinement (4 rounds max). Default rounds=4, overwritten on exit. + if (tx == 0 && counters) counters[COUNTER_REFINE_ROUNDS] = 4; +#pragma unroll 4 + for (int round = 0; round < 4; ++round) { + __shared__ int p_last_remain; + const auto r_idx = round % 2; + const auto _raw_num_input = p_num_input[r_idx]; + const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) ? 
_raw_num_input : int(SMEM_INPUT_SIZE); run_cumsum(); - if (tx < RADIX && vh_histogram[tx] > topk && vh_histogram[tx + 1] <= topk) { - vh_threshold_bin_id = tx; - vh_num_input[0] = 0; - vh_counter = 0; + if (tx < RADIX && p_histogram[tx] > topk && p_histogram[tx + 1] <= topk) { + p_threshold_bin_id = tx; + p_num_input[r_idx ^ 1] = 0; + p_last_remain = topk - p_histogram[tx + 1]; } __syncthreads(); - const auto threshold_bin = vh_threshold_bin_id; - topk -= vh_histogram[threshold_bin + 1]; - - if (WriteCounters && tx == 0 && counters) { - counters[COUNTER_THRESHOLD_BIN] = threshold_bin; - counters[COUNTER_REMAINING_K] = topk; - } + const auto threshold_bin = p_threshold_bin_id; + topk -= p_histogram[threshold_bin + 1]; if (topk == 0) { - for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - int bin; - if (use_bin_cache) { - bin = static_cast(bin_cache[idx]); - } else { - bin = static_cast( - mapped_convert_to_uint8( - vortex_to_float(input[idx + row_start]), - mapping, s_mapping_lut, s_mapping_quantiles, - s_range_min, s_range_inv_range)); - } - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&vh_counter, 1); - index[pos] = idx; - } - } - __syncthreads(); - if (WriteCounters && tx == 0 && counters) { - counters[COUNTER_NUM_ABOVE] = vh_counter; - counters[COUNTER_NUM_EQUAL] = 0; - counters[COUNTER_REFINE_ROUNDS] = 0; - counters[COUNTER_STAGE2_INPUT] = 0; - } - return; + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = p_input_idx[r_idx][i]; + const float raw = vortex_to_float(input[idx + row_start]); + const float remapped = apply_transform(raw, mapping); + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(remapped) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&p_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + if (tx == 0 && counters) counters[COUNTER_REFINE_ROUNDS] = round + 1; + break; } else { - __syncthreads(); - if (tx < RADIX + 1) vh_histogram[tx] = 0; - 
__syncthreads(); - - for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - const auto raw_input = vortex_to_float(input[idx + row_start]); - int bin; - if (use_bin_cache) { - bin = static_cast(bin_cache[idx]); - } else { - bin = static_cast( - mapped_convert_to_uint8(raw_input, mapping, - s_mapping_lut, s_mapping_quantiles, - s_range_min, s_range_inv_range)); - } - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&vh_counter, 1); - index[pos] = idx; - } else if (bin == threshold_bin) { - const auto pos = ::atomicAdd(&vh_num_input[0], 1); - if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { - vh_input_idx[0][pos] = idx; - const auto b32 = convert_to_uint32(raw_input); - const auto sub_bin = (b32 >> 24) & 0xFF; - ::atomicAdd(&vh_histogram[sub_bin], 1); - } - } - } - __syncthreads(); - if (WriteCounters && tx == 0 && counters) { - counters[COUNTER_NUM_ABOVE] = vh_counter; - counters[COUNTER_NUM_EQUAL] = vh_num_input[0]; - counters[COUNTER_STAGE2_INPUT] = vh_num_input[0]; - } - if (StopAfterStage1) return; - } - - // Stage 2: refine with 8-bit radix passes (unchanged — uses raw float bits) - if constexpr (WriteCounters) { - // Default: all 4 rounds used; overwritten at break if resolved early - if (tx == 0 && counters) counters[COUNTER_REFINE_ROUNDS] = 4; - } -#pragma unroll 4 - for (int round = 0; round < 4; ++round) { - __shared__ int vh_last_remain; - const auto r_idx = round % 2; - - const auto _raw_num_input = vh_num_input[r_idx]; - const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) - ? 
_raw_num_input - : int(SMEM_INPUT_SIZE); - - run_cumsum(); - if (tx < RADIX && vh_histogram[tx] > topk && vh_histogram[tx + 1] <= topk) { - vh_threshold_bin_id = tx; - vh_num_input[r_idx ^ 1] = 0; - vh_last_remain = topk - vh_histogram[tx + 1]; - } - __syncthreads(); - - const auto threshold_bin = vh_threshold_bin_id; - topk -= vh_histogram[threshold_bin + 1]; - - if (topk == 0) { - for (int i = tx; i < num_input; i += BLOCK_SIZE) { - const auto idx = vh_input_idx[r_idx][i]; - const auto offset = 24 - round * 8; - const auto bin = (convert_to_uint32( - vortex_to_float(input[idx + row_start])) >> offset) & 0xFF; - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&vh_counter, 1); - index[pos] = idx; - } + __syncthreads(); + if (tx < RADIX + 1) p_histogram[tx] = 0; + __syncthreads(); + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = p_input_idx[r_idx][i]; + const float raw = vortex_to_float(input[idx + row_start]); + const float remapped = apply_transform(raw, mapping); + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(remapped) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&p_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + if (round == 3) { + const auto pos = ::atomicAdd(&p_last_remain, -1); + if (pos > 0) { + index[target_k - pos] = idx; } - __syncthreads(); - if constexpr (WriteCounters) { - if (tx == 0 && counters) { - counters[COUNTER_REFINE_ROUNDS] = round + 1; - } + } else { + const auto pos = ::atomicAdd(&p_num_input[r_idx ^ 1], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + p_input_idx[r_idx ^ 1][pos] = idx; + const auto b32 = convert_to_uint32(remapped); + const auto sub_bin = (b32 >> (offset - 8)) & 0xFF; + ::atomicAdd(&p_histogram[sub_bin], 1); } - break; - } else { - __syncthreads(); - if (tx < RADIX + 1) vh_histogram[tx] = 0; - __syncthreads(); - for (int i = tx; i < num_input; i += BLOCK_SIZE) { - const auto idx = vh_input_idx[r_idx][i]; - 
const auto raw_input = vortex_to_float(input[idx + row_start]); - const auto offset = 24 - round * 8; - const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF; - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&vh_counter, 1); - index[pos] = idx; - } else if (bin == threshold_bin) { - if (round == 3) { - const auto pos = ::atomicAdd(&vh_last_remain, -1); - if (pos > 0) { - index[target_k - pos] = idx; - } - } else { - const auto pos = ::atomicAdd(&vh_num_input[r_idx ^ 1], 1); - if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { - vh_input_idx[r_idx ^ 1][pos] = idx; - const auto b32 = convert_to_uint32(raw_input); - const auto sub_bin = (b32 >> (offset - 8)) & 0xFF; - ::atomicAdd(&vh_histogram[sub_bin], 1); - } - } - } - } - __syncthreads(); + } } + } + __syncthreads(); } + } } -template -void setup_kernel_smem_once() { - [[maybe_unused]] - static const auto result = [] { -#ifdef USE_ROCM - // hipify will turn cudaFuncSetAttribute -> hipFuncSetAttribute. On ROCm, - // hipFuncSetAttribute expects `const void*` and hipcc does not accept passing - // a function pointer directly, so cast explicitly. - return ::cudaFuncSetAttribute( - reinterpret_cast(f), ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); -#else - // CUDA: keep original behavior (no cast needed). - return ::cudaFuncSetAttribute(f, ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); -#endif - }(); - TORCH_CHECK(result == cudaSuccess, "set_up_kernel_once failed:", ::cudaGetErrorString(result)); -} - -// ====================================================================== +// Wrapper: one block per (batch*head) segment. Writes counters per +// segment into a [eff_batch_size, NUM_TOPK_COUNTERS] int32 tensor. 
template __global__ __launch_bounds__(kThreadsPerBlock) -void TopKStage1_Kernel( +void TopKProfileCounters_Kernel( const ScoreT* __restrict__ score, const int* __restrict__ dense_kv_indptr, const int* __restrict__ sparse_kv_indptr, const int* __restrict__ dense_kv_indices, int* __restrict__ sparse_kv_indices, + int* __restrict__ counters, const int topk_val, const int page_reserved_bos, const int page_reserved_eos, const TopKMappingParams mapping) { - const int bx = blockIdx.x; - - const int start = dense_kv_indptr[bx] + page_reserved_bos; - const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; - const int nblk = end - start; - if (nblk <= topk_val) return; - - const ScoreT* __restrict__ score_blk = score + start; - const int* __restrict__ idx_blk = dense_kv_indices + start; - int* __restrict__ out_blk = sparse_kv_indices - + sparse_kv_indptr[bx] - + page_reserved_bos; - - __shared__ int s_indices[VORTEX_MAX_TOPK]; - fast_topk_vortex( - score_blk, s_indices, 0, nblk, topk_val, mapping); - __syncthreads(); - - // Remap position indices -> page indices via dense_kv_indices - const int tx = threadIdx.x; - for (int i = tx; i < topk_val; i += kThreadsPerBlock) { - out_blk[i] = idx_blk[s_indices[i]]; - } + const int bx = blockIdx.x; + + const int start = dense_kv_indptr[bx] + page_reserved_bos; + const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; + const int nblk = end - start; + if (nblk <= topk_val) return; + + const ScoreT* __restrict__ score_blk = score + start; + const int* __restrict__ idx_blk = dense_kv_indices + start; + int* __restrict__ out_blk = sparse_kv_indices + + sparse_kv_indptr[bx] + + page_reserved_bos; + + __shared__ int s_indices[VORTEX_MAX_TOPK]; + fast_topk_profile( + score_blk, s_indices, 0, nblk, topk_val, mapping, + counters + bx * NUM_TOPK_COUNTERS); + __syncthreads(); + + const int tx = threadIdx.x; + for (int i = tx; i < topk_val; i += kThreadsPerBlock) { + out_blk[i] = idx_blk[s_indices[i]]; + } } -// 
====================================================================== -// Profiling counters kernel: runs full pipeline + writes diagnostic -// counters to a separate global-memory tensor -// ====================================================================== +// Histogram-only profiling kernel: builds a 256-bin histogram of the +// remapped bins for each segment. Purely diagnostic — never timed. template __global__ __launch_bounds__(kThreadsPerBlock) -void TopKCounters_Kernel( +void TopKProfileHistogram_Kernel( const ScoreT* __restrict__ score, const int* __restrict__ dense_kv_indptr, - const int* __restrict__ sparse_kv_indptr, - const int* __restrict__ dense_kv_indices, - int* __restrict__ sparse_kv_indices, - int* __restrict__ counters, // [eff_batch_size, NUM_TOPK_COUNTERS] - const int topk_val, + int* __restrict__ histograms, // [eff_batch_size, 256] const int page_reserved_bos, const int page_reserved_eos, const TopKMappingParams mapping) { - const int bx = blockIdx.x; + constexpr auto RADIX = 256; + constexpr auto BLOCK_SIZE = kThreadsPerBlock; + __shared__ int s_histogram[RADIX]; + __shared__ uint8_t s_mapping_lut[256]; + __shared__ float s_mapping_quantiles[256]; - const int start = dense_kv_indptr[bx] + page_reserved_bos; - const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; - const int nblk = end - start; - if (nblk <= topk_val) return; + const int bx = blockIdx.x; + const int tx = threadIdx.x; - const ScoreT* __restrict__ score_blk = score + start; - const int* __restrict__ idx_blk = dense_kv_indices + start; - int* __restrict__ out_blk = sparse_kv_indices - + sparse_kv_indptr[bx] - + page_reserved_bos; + const int start = dense_kv_indptr[bx] + page_reserved_bos; + const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; + const int nblk = end - start; - __shared__ int s_indices[VORTEX_MAX_TOPK]; - fast_topk_vortex( - score_blk, s_indices, 0, nblk, topk_val, mapping, - counters + bx * NUM_TOPK_COUNTERS); + if (mapping.mode == 
MAPPING_LUT_CDF && mapping.lut != nullptr) { + if (tx < 256) s_mapping_lut[tx] = mapping.lut[tx]; __syncthreads(); + } + if (mapping.mode == MAPPING_QUANTILE && mapping.quantiles != nullptr) { + if (tx < 256) s_mapping_quantiles[tx] = mapping.quantiles[tx]; + __syncthreads(); + } - // Remap position indices -> page indices via dense_kv_indices - const int tx = threadIdx.x; - for (int i = tx; i < topk_val; i += kThreadsPerBlock) { - out_blk[i] = idx_blk[s_indices[i]]; - } -} - -// ====================================================================== -// Profiling histogram kernel: runs only Stage 1 and returns per-segment -// 256-bin histograms for distribution analysis -// ====================================================================== -template -__global__ __launch_bounds__(kThreadsPerBlock) -void TopKHistogram_Kernel( - const ScoreT* __restrict__ score, - const int* __restrict__ dense_kv_indptr, - int* __restrict__ histograms, // [eff_batch_size, 256] - const int page_reserved_bos, - const int page_reserved_eos, - const TopKMappingParams mapping) -{ - constexpr auto RADIX = 256; - constexpr auto BLOCK_SIZE = kThreadsPerBlock; - __shared__ int s_histogram[RADIX]; - __shared__ uint8_t s_mapping_lut[256]; - __shared__ float s_mapping_quantiles[256]; - __shared__ float s_range_min, s_range_inv_range; - - const int bx = blockIdx.x; - const int tx = threadIdx.x; - - const int start = dense_kv_indptr[bx] + page_reserved_bos; - const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; - const int nblk = end - start; + if (tx < RADIX) s_histogram[tx] = 0; + __syncthreads(); + if (nblk > 0) { const ScoreT* __restrict__ score_blk = score + start; - - // Load mapping tables into shared memory if needed - if (mapping.mode == MAPPING_LUT_CDF && mapping.lut != nullptr) { - if (tx < 256) s_mapping_lut[tx] = mapping.lut[tx]; - __syncthreads(); - } - if (mapping.mode == MAPPING_QUANTILE && mapping.quantiles != nullptr) { - if (tx < 256) s_mapping_quantiles[tx] = 
mapping.quantiles[tx]; - __syncthreads(); + for (int i = tx; i < nblk; i += BLOCK_SIZE) { + const float raw = vortex_to_float(score_blk[i]); + const auto bin = compute_stage1_bin(raw, mapping, s_mapping_lut, s_mapping_quantiles); + ::atomicAdd(&s_histogram[bin], 1); } + } + __syncthreads(); - // Pre-pass: compute per-block min/max for transform modes (supports sampled stride) - if (needs_auto_range(mapping.mode) && !mapping.noscale) { - const int stride = (mapping.sample_stride > 1) ? mapping.sample_stride : 1; - float local_min = __FLT_MAX__, local_max = -__FLT_MAX__; - for (int idx = tx * stride; idx < nblk; idx += BLOCK_SIZE * stride) { - float val = apply_transform(vortex_to_float(score_blk[idx]), mapping); - local_min = fminf(local_min, val); - local_max = fmaxf(local_max, val); - } - for (int offset = 16; offset > 0; offset >>= 1) { - local_min = fminf(local_min, __shfl_xor_sync(0xFFFFFFFF, local_min, offset)); - local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); - } - __shared__ float s_warp_mins[32], s_warp_maxs[32]; - int warp_id = tx >> 5, lane_id = tx & 31; - if (lane_id == 0) { s_warp_mins[warp_id] = local_min; s_warp_maxs[warp_id] = local_max; } - __syncthreads(); - if (tx < (BLOCK_SIZE >> 5)) { - local_min = s_warp_mins[tx]; local_max = s_warp_maxs[tx]; - for (int offset = 16; offset > 0; offset >>= 1) { - local_min = fminf(local_min, __shfl_xor_sync(0xFFFFFFFF, local_min, offset)); - local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); - } - if (tx == 0) { - s_range_min = local_min; - float range = local_max - local_min; - s_range_inv_range = (range > 0.0f) ? 
255.0f / range : 0.0f; - } - } - __syncthreads(); - } else if (needs_pivot(mapping.mode)) { - // Pivot pre-pass: compute mean for MAPPING_SUBTRACT - float local_sum = 0.0f; - for (int idx = tx; idx < nblk; idx += BLOCK_SIZE) { - local_sum += vortex_to_float(score_blk[idx]); - } - for (int offset = 16; offset > 0; offset >>= 1) { - local_sum += __shfl_xor_sync(0xFFFFFFFF, local_sum, offset); - } - __shared__ float s_warp_sums_h[32]; - int warp_id = tx >> 5, lane_id = tx & 31; - if (lane_id == 0) s_warp_sums_h[warp_id] = local_sum; - __syncthreads(); - if (tx < (BLOCK_SIZE >> 5)) { - local_sum = s_warp_sums_h[tx]; - for (int offset = 16; offset > 0; offset >>= 1) { - local_sum += __shfl_xor_sync(0xFFFFFFFF, local_sum, offset); - } - if (tx == 0) { - s_range_min = local_sum / float(nblk); - s_range_inv_range = 0.0f; - } - } - __syncthreads(); - } else if (needs_tail_window(mapping.mode)) { - // Adaptive tail-window pre-pass (histogram kernel variant) - constexpr int MAX_SAMPLES_H = 1024; - __shared__ float s_samples_h[MAX_SAMPLES_H]; - __shared__ int s_sample_count_h; - - if (tx == 0) s_sample_count_h = 0; - __syncthreads(); - - const int desired_stride = (nblk + MAX_SAMPLES_H - 1) / MAX_SAMPLES_H; - const int sample_stride_h = max(desired_stride, 1); - - float local_max = -__FLT_MAX__; - for (int idx = tx * sample_stride_h; idx < nblk; idx += BLOCK_SIZE * sample_stride_h) { - float val = vortex_to_float(score_blk[idx]); - local_max = fmaxf(local_max, val); - int slot = ::atomicAdd(&s_sample_count_h, 1); - if (slot < MAX_SAMPLES_H) s_samples_h[slot] = val; - } - - for (int offset = 16; offset > 0; offset >>= 1) - local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); - __shared__ float s_warp_maxs_h[32]; - { - int warp_id = tx >> 5, lane_id = tx & 31; - if (lane_id == 0) s_warp_maxs_h[warp_id] = local_max; - } - __syncthreads(); - if (tx < (BLOCK_SIZE >> 5)) { - local_max = s_warp_maxs_h[tx]; - for (int offset = 16; offset > 0; offset >>= 1) - 
local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); - if (tx == 0) s_warp_maxs_h[0] = local_max; - } - __syncthreads(); - local_max = s_warp_maxs_h[0]; - - int nsamp = min(s_sample_count_h, MAX_SAMPLES_H); - - __syncthreads(); - if (nsamp >= 2) { - for (int pass = 0; pass < nsamp; ++pass) { - if (tx * 2 + 1 < nsamp) { - int i = tx * 2; - if (s_samples_h[i] > s_samples_h[i + 1]) { - float tmp = s_samples_h[i]; - s_samples_h[i] = s_samples_h[i + 1]; - s_samples_h[i + 1] = tmp; - } - } - __syncthreads(); - if (tx * 2 + 2 < nsamp) { - int i = tx * 2 + 1; - if (s_samples_h[i] > s_samples_h[i + 1]) { - float tmp = s_samples_h[i]; - s_samples_h[i] = s_samples_h[i + 1]; - s_samples_h[i + 1] = tmp; - } - } - __syncthreads(); - } - } - - if (tx == 0) { - float rho = mapping.power_exp; - if (rho <= 0.0f) rho = 4.0f; - int k = mapping.target_k; - float frac = (k > 0 && nblk > 0) ? 1.0f - rho * float(k) / float(nblk) : 0.0f; - frac = fmaxf(frac, 0.0f); - - float tau_low; - if (nsamp < 4 || frac <= 0.0f) { - tau_low = -__FLT_MAX__; - } else { - float fidx = frac * float(nsamp - 1); - int lo = __float2int_rd(fidx); - lo = min(max(lo, 0), nsamp - 2); - float t = fidx - float(lo); - tau_low = s_samples_h[lo] * (1.0f - t) + s_samples_h[lo + 1] * t; - } - - if (tau_low >= local_max) { - tau_low = (nsamp > 0) ? s_samples_h[0] : local_max; - } - - float range = local_max - tau_low; - s_range_min = tau_low; - s_range_inv_range = (range > 1e-10f) ? 
255.0f / range : 0.0f; - } - __syncthreads(); - } else if (needs_topk_window(mapping.mode)) { - // Topk-window pre-pass with streaming variance (histogram kernel variant) - float local_max_h = -__FLT_MAX__; - float local_sum_h = 0.0f, local_sum_sq_h = 0.0f; - for (int idx = tx; idx < nblk; idx += BLOCK_SIZE) { - float val = vortex_to_float(score_blk[idx]); - local_max_h = fmaxf(local_max_h, val); - local_sum_h += val; - local_sum_sq_h += val * val; - } - for (int offset = 16; offset > 0; offset >>= 1) { - local_max_h = fmaxf(local_max_h, __shfl_xor_sync(0xFFFFFFFF, local_max_h, offset)); - local_sum_h += __shfl_xor_sync(0xFFFFFFFF, local_sum_h, offset); - local_sum_sq_h += __shfl_xor_sync(0xFFFFFFFF, local_sum_sq_h, offset); - } - __shared__ float s_warp_maxs_tw3[32], s_warp_sums_tw3[32], s_warp_sq_tw3[32]; - { - int warp_id = tx >> 5, lane_id = tx & 31; - if (lane_id == 0) { - s_warp_maxs_tw3[warp_id] = local_max_h; - s_warp_sums_tw3[warp_id] = local_sum_h; - s_warp_sq_tw3[warp_id] = local_sum_sq_h; - } - } - __syncthreads(); - if (tx < (BLOCK_SIZE >> 5)) { - local_max_h = s_warp_maxs_tw3[tx]; - local_sum_h = s_warp_sums_tw3[tx]; - local_sum_sq_h = s_warp_sq_tw3[tx]; - for (int offset = 16; offset > 0; offset >>= 1) { - local_max_h = fmaxf(local_max_h, __shfl_xor_sync(0xFFFFFFFF, local_max_h, offset)); - local_sum_h += __shfl_xor_sync(0xFFFFFFFF, local_sum_h, offset); - local_sum_sq_h += __shfl_xor_sync(0xFFFFFFFF, local_sum_sq_h, offset); - } - if (tx == 0) { - float rho = mapping.power_exp; - if (rho <= 0.0f) rho = 4.0f; - int k = mapping.target_k; - float n = float(nblk); - float mean = local_sum_h / n; - float var = local_sum_sq_h / n - mean * mean; - float sigma = (var > 0.0f) ? 
sqrtf(var) : 0.0f; - float ratio = n / fmaxf(float(k), 1.0f); - float z = sqrtf(2.0f * __logf(fmaxf(ratio, 1.0f))); - float tau_low = local_max_h - rho * sigma * z; - if (tau_low >= local_max_h) tau_low = local_max_h - 1.0f; - float range = local_max_h - tau_low; - s_range_min = tau_low; - s_range_inv_range = (range > 1e-10f) ? 255.0f / range : 0.0f; - } - } - __syncthreads(); - } else { - if (tx == 0) { s_range_min = 0.0f; s_range_inv_range = 0.0f; } - __syncthreads(); - } - - // Initialize shared histogram - if (tx < RADIX) s_histogram[tx] = 0; - __syncthreads(); - - // Build histogram over the segment with mapping - for (int idx = tx; idx < nblk; idx += BLOCK_SIZE) { - const auto bin = mapped_convert_to_uint8( - vortex_to_float(score_blk[idx]), - mapping, s_mapping_lut, s_mapping_quantiles, - s_range_min, s_range_inv_range); - ::atomicAdd(&s_histogram[bin], 1); - } - __syncthreads(); - - // Write to global memory - int* __restrict__ out = histograms + bx * RADIX; - if (tx < RADIX) { - out[tx] = s_histogram[tx]; - } + int* __restrict__ out = histograms + bx * RADIX; + if (tx < RADIX) out[tx] = s_histogram[tx]; } } // namespace #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") -// ====================================================================== -// Profiling: collect per-segment 256-bin histograms of Stage 1 bins -// ====================================================================== -void topk_profile_histogram( - const at::Tensor& x, - const at::Tensor& dense_kv_indptr, - at::Tensor& histograms, - const int64_t eff_batch_size, - const int64_t reserved_bos, - const int64_t reserved_eos, - const int64_t mapping_mode, - const double mapping_power, - std::optional mapping_lut, - std::optional mapping_quantiles, - const bool mapping_noscale, - const int64_t topk_val, - const int64_t sample_stride) -{ - CHECK_CUDA(x); - CHECK_CUDA(dense_kv_indptr); - CHECK_CUDA(histograms); - TORCH_CHECK(histograms.dim() == 2 && histograms.size(0) == 
eff_batch_size - && histograms.size(1) == 256, - "histograms must be [eff_batch_size, 256]"); - TORCH_CHECK(histograms.scalar_type() == at::ScalarType::Int, - "histograms must be int32"); - - // Build mapping params - TopKMappingParams mapping{}; - mapping.mode = static_cast(mapping_mode); - mapping.power_exp = static_cast(mapping_power); - mapping.lut = nullptr; - mapping.quantiles = nullptr; - mapping.noscale = mapping_noscale; - mapping.sample_stride = static_cast(sample_stride); - mapping.target_k = static_cast(topk_val); - - if (mapping_lut.has_value()) { - const auto& lut = mapping_lut.value(); - CHECK_CUDA(lut); - TORCH_CHECK(lut.dim() == 1 && lut.size(0) == 256 && lut.scalar_type() == at::ScalarType::Byte, - "mapping_lut must be a 1D uint8 tensor of size 256"); - mapping.lut = lut.data_ptr(); - } - if (mapping_quantiles.has_value()) { - const auto& q = mapping_quantiles.value(); - CHECK_CUDA(q); - TORCH_CHECK(q.dim() == 1 && q.size(0) == 256 && q.scalar_type() == at::ScalarType::Float, - "mapping_quantiles must be a 1D float32 tensor of size 256"); - mapping.quantiles = q.data_ptr(); - } - - dim3 nblks(eff_batch_size); - dim3 nthreads(kThreadsPerBlock); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - if (x.scalar_type() == at::ScalarType::BFloat16) { - TopKHistogram_Kernel<__nv_bfloat16><<>>( - reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), - dense_kv_indptr.data_ptr(), - histograms.data_ptr(), - reserved_bos, - reserved_eos, - mapping); - } else if (x.scalar_type() == at::ScalarType::Float) { - TopKHistogram_Kernel<<>>( - x.data_ptr(), - dense_kv_indptr.data_ptr(), - histograms.data_ptr(), - reserved_bos, - reserved_eos, - mapping); - } else { - TORCH_CHECK(false, - "topk_profile_histogram: unsupported dtype ", - x.scalar_type()); - } - - const auto result = cudaGetLastError(); - TORCH_CHECK(result == cudaSuccess, - "topk_profile_histogram kernel failed: ", ::cudaGetErrorString(result)); -} - -// Helper: build TopKMappingParams from 
host arguments static TopKMappingParams build_mapping_params( int64_t mapping_mode, double mapping_power, std::optional& mapping_lut, - std::optional& mapping_quantiles, - bool mapping_noscale = false, - int sample_stride = 1, - int target_k = 0) + std::optional& mapping_quantiles) { - TopKMappingParams mapping{}; - mapping.mode = static_cast(mapping_mode); - mapping.power_exp = static_cast(mapping_power); - mapping.lut = nullptr; - mapping.quantiles = nullptr; - mapping.noscale = mapping_noscale; - mapping.sample_stride = sample_stride; - mapping.target_k = target_k; - - if (mapping_lut.has_value()) { - const auto& lut = mapping_lut.value(); - CHECK_CUDA(lut); - TORCH_CHECK(lut.dim() == 1 && lut.size(0) == 256 && lut.scalar_type() == at::ScalarType::Byte, - "mapping_lut must be a 1D uint8 tensor of size 256"); - mapping.lut = lut.data_ptr(); - } - if (mapping_quantiles.has_value()) { - const auto& q = mapping_quantiles.value(); - CHECK_CUDA(q); - TORCH_CHECK(q.dim() == 1 && q.size(0) == 256 && q.scalar_type() == at::ScalarType::Float, - "mapping_quantiles must be a 1D float32 tensor of size 256"); - mapping.quantiles = q.data_ptr(); - } - return mapping; + TopKMappingParams m{}; + m.mode = static_cast(mapping_mode); + m.power_exp = static_cast(mapping_power); + m.lut = nullptr; + m.quantiles = nullptr; + if (mapping_lut.has_value()) { + const auto& lut = mapping_lut.value(); + TORCH_CHECK(lut.is_cuda(), "mapping_lut must be a CUDA tensor"); + TORCH_CHECK(lut.dim() == 1 && lut.size(0) == 256 && lut.scalar_type() == at::ScalarType::Byte, + "mapping_lut must be a 1D uint8 tensor of size 256"); + m.lut = lut.data_ptr(); + } + if (mapping_quantiles.has_value()) { + const auto& q = mapping_quantiles.value(); + TORCH_CHECK(q.is_cuda(), "mapping_quantiles must be a CUDA tensor"); + TORCH_CHECK(q.dim() == 1 && q.size(0) == 256 && q.scalar_type() == at::ScalarType::Float, + "mapping_quantiles must be a 1D float32 tensor of size 256"); + m.quantiles = q.data_ptr(); + } + 
return m; } // ====================================================================== -// Profiling: Stage 1 only (pre-pass + hist + cumsum + route/filter) +// Profiling: per-segment 256-bin histograms of Stage 1 remapped bins. // ====================================================================== -void topk_profile_stage1( +void topk_profile_histogram( const at::Tensor& x, const at::Tensor& dense_kv_indptr, - const at::Tensor& sparse_kv_indptr, - const at::Tensor& dense_kv_indices, - at::Tensor& sparse_kv_indices, + at::Tensor& histograms, const int64_t eff_batch_size, - const int64_t topk_val, const int64_t reserved_bos, const int64_t reserved_eos, - const int64_t max_num_pages, const int64_t mapping_mode, const double mapping_power, std::optional mapping_lut, - std::optional mapping_quantiles, - const bool mapping_noscale) + std::optional mapping_quantiles) { - TORCH_CHECK(topk_val <= VORTEX_MAX_TOPK, - "topk_profile_stage1: topk_val (", topk_val, - ") exceeds VORTEX_MAX_TOPK (", VORTEX_MAX_TOPK, ")"); - - auto mapping = build_mapping_params(mapping_mode, mapping_power, mapping_lut, mapping_quantiles, - mapping_noscale, /*sample_stride=*/1, /*target_k=*/static_cast(topk_val)); - - dim3 nblks(eff_batch_size); - dim3 nthreads(kThreadsPerBlock); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - if (x.scalar_type() == at::ScalarType::BFloat16) { - setup_kernel_smem_once, kSmem>(); - TopKStage1_Kernel<__nv_bfloat16><<>>( - reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), - dense_kv_indptr.data_ptr(), - sparse_kv_indptr.data_ptr(), - dense_kv_indices.data_ptr(), - sparse_kv_indices.data_ptr(), - topk_val, - reserved_bos, - reserved_eos, - mapping); - } else if (x.scalar_type() == at::ScalarType::Float) { - setup_kernel_smem_once, kSmem>(); - TopKStage1_Kernel<<>>( - x.data_ptr(), - dense_kv_indptr.data_ptr(), - sparse_kv_indptr.data_ptr(), - dense_kv_indices.data_ptr(), - sparse_kv_indices.data_ptr(), - topk_val, - reserved_bos, - reserved_eos, 
- mapping); - } else { - TORCH_CHECK(false, - "topk_profile_stage1: unsupported dtype ", - x.scalar_type()); - } + CHECK_CUDA(x); + CHECK_CUDA(dense_kv_indptr); + CHECK_CUDA(histograms); + TORCH_CHECK(histograms.dim() == 2 && histograms.size(0) == eff_batch_size + && histograms.size(1) == 256, + "histograms must be [eff_batch_size, 256]"); + TORCH_CHECK(histograms.scalar_type() == at::ScalarType::Int, + "histograms must be int32"); + + auto mapping = build_mapping_params(mapping_mode, mapping_power, mapping_lut, mapping_quantiles); + + dim3 nblks(eff_batch_size); + dim3 nthreads(kThreadsPerBlock); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + + if (x.scalar_type() == at::ScalarType::BFloat16) { + TopKProfileHistogram_Kernel<__nv_bfloat16><<>>( + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + dense_kv_indptr.data_ptr(), + histograms.data_ptr(), + reserved_bos, reserved_eos, mapping); + } else if (x.scalar_type() == at::ScalarType::Float) { + TopKProfileHistogram_Kernel<<>>( + x.data_ptr(), + dense_kv_indptr.data_ptr(), + histograms.data_ptr(), + reserved_bos, reserved_eos, mapping); + } else { + TORCH_CHECK(false, "topk_profile_histogram: unsupported dtype ", x.scalar_type()); + } - const auto result = cudaGetLastError(); - TORCH_CHECK(result == cudaSuccess, - "topk_profile_stage1 kernel failed: ", ::cudaGetErrorString(result)); + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, + "topk_profile_histogram kernel failed: ", ::cudaGetErrorString(result)); } // ====================================================================== -// Profiling: full pipeline + diagnostic counters +// Profiling: full pipeline + per-segment diagnostic counters. +// Adds extra global-memory writes — never use for latency measurement. 
// ====================================================================== void topk_profile_counters( const at::Tensor& x, @@ -1144,60 +528,53 @@ void topk_profile_counters( const int64_t mapping_mode, const double mapping_power, std::optional mapping_lut, - std::optional mapping_quantiles, - const bool mapping_noscale) + std::optional mapping_quantiles) { - TORCH_CHECK(topk_val <= VORTEX_MAX_TOPK, - "topk_profile_counters: topk_val (", topk_val, - ") exceeds VORTEX_MAX_TOPK (", VORTEX_MAX_TOPK, ")"); - CHECK_CUDA(counters); - TORCH_CHECK(counters.dim() == 2 && counters.size(0) == eff_batch_size - && counters.size(1) == NUM_TOPK_COUNTERS, - "counters must be [eff_batch_size, ", NUM_TOPK_COUNTERS, "]"); - TORCH_CHECK(counters.scalar_type() == at::ScalarType::Int, - "counters must be int32"); - - auto mapping = build_mapping_params(mapping_mode, mapping_power, mapping_lut, mapping_quantiles, - mapping_noscale, /*sample_stride=*/1, /*target_k=*/static_cast(topk_val)); - - dim3 nblks(eff_batch_size); - dim3 nthreads(kThreadsPerBlock); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - if (x.scalar_type() == at::ScalarType::BFloat16) { - setup_kernel_smem_once, kSmem>(); - TopKCounters_Kernel<__nv_bfloat16><<>>( - reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), - dense_kv_indptr.data_ptr(), - sparse_kv_indptr.data_ptr(), - dense_kv_indices.data_ptr(), - sparse_kv_indices.data_ptr(), - counters.data_ptr(), - topk_val, - reserved_bos, - reserved_eos, - mapping); - } else if (x.scalar_type() == at::ScalarType::Float) { - setup_kernel_smem_once, kSmem>(); - TopKCounters_Kernel<<>>( - x.data_ptr(), - dense_kv_indptr.data_ptr(), - sparse_kv_indptr.data_ptr(), - dense_kv_indices.data_ptr(), - sparse_kv_indices.data_ptr(), - counters.data_ptr(), - topk_val, - reserved_bos, - reserved_eos, - mapping); - } else { - TORCH_CHECK(false, - "topk_profile_counters: unsupported dtype ", - x.scalar_type()); - } + TORCH_CHECK(topk_val <= VORTEX_MAX_TOPK, + 
"topk_profile_counters: topk_val (", topk_val, + ") exceeds VORTEX_MAX_TOPK (", VORTEX_MAX_TOPK, ")"); + CHECK_CUDA(x); + CHECK_CUDA(dense_kv_indptr); + CHECK_CUDA(sparse_kv_indptr); + CHECK_CUDA(dense_kv_indices); + CHECK_CUDA(sparse_kv_indices); + CHECK_CUDA(counters); + TORCH_CHECK(counters.dim() == 2 && counters.size(0) == eff_batch_size + && counters.size(1) == NUM_TOPK_COUNTERS, + "counters must be [eff_batch_size, ", NUM_TOPK_COUNTERS, "]"); + TORCH_CHECK(counters.scalar_type() == at::ScalarType::Int, "counters must be int32"); + + auto mapping = build_mapping_params(mapping_mode, mapping_power, mapping_lut, mapping_quantiles); + + dim3 nblks(eff_batch_size); + dim3 nthreads(kThreadsPerBlock); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + + if (x.scalar_type() == at::ScalarType::BFloat16) { + setup_kernel_smem_once, kSmem>(); + TopKProfileCounters_Kernel<__nv_bfloat16><<>>( + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + counters.data_ptr(), + topk_val, reserved_bos, reserved_eos, mapping); + } else if (x.scalar_type() == at::ScalarType::Float) { + setup_kernel_smem_once, kSmem>(); + TopKProfileCounters_Kernel<<>>( + x.data_ptr(), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + counters.data_ptr(), + topk_val, reserved_bos, reserved_eos, mapping); + } else { + TORCH_CHECK(false, "topk_profile_counters: unsupported dtype ", x.scalar_type()); + } - const auto result = cudaGetLastError(); - TORCH_CHECK(result == cudaSuccess, - "topk_profile_counters kernel failed: ", ::cudaGetErrorString(result)); + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, + "topk_profile_counters kernel failed: ", ::cudaGetErrorString(result)); } - diff --git a/examples/remap_function_bench.sh b/examples/remap_function_bench.sh new 
file mode 100755 index 0000000..7d56d57 --- /dev/null +++ b/examples/remap_function_bench.sh @@ -0,0 +1,238 @@ +#!/usr/bin/env bash +# ============================================================ +# Remap Function Benchmark +# +# Compares four kernel configurations for TopK page selection: +# 1. baseline — unmapped topk (topk_output_sglang) +# 2. fused remap + topk — topk_output_sglang_fused +# 3. remap only — topk_remap_only (standalone kernel) +# 4. unmapped topk on remapped — topk_output_sglang on the output +# buffer of step 3 +# +# Per configuration the script also reports the threshold-bin +# position, the threshold-bin size, and how many values are +# selected from the threshold bin (derived from +# topk_profile_counters — collected after all timing measurements, +# never interleaved with latency measurements). +# +# Pipeline: +# 1. Calibrate — run `calibrate_topk.py` on the chosen model to +# collect the REAL per-segment topk distribution +# (raw_histograms.npy). Skippable via +# --real-histograms /path/to/raw_histograms.npy. +# 2. Autotune — run `autotune_topk_mapping.py` on those real +# histograms and pick the per-mode hyperparameter +# with the LOWEST measured topk kernel latency. +# 3. Remap bench— run `bench_topk.py --remap-bench` with the +# autotune-selected per-mode hyperparameters. +# +# Argument layout mirrors run_distribution_analysis_new.sh. 
+# +# Usage: +# # Default (Qwen/Qwen3-1.7B, block_size=16): +# bash remap_function_bench.sh --gpu 5 +# +# # Larger model + larger page/block size: +# bash remap_function_bench.sh --gpu 0 \ +# --model-name Qwen/Qwen3-8B \ +# --block-size 32 \ +# --seq-len 16384 --topk-val 512 \ +# --modes "0 3 6 7" +# +# # Reuse an existing calibration: +# bash remap_function_bench.sh --gpu 0 \ +# --model-name Qwen/Qwen3-8B \ +# --real-histograms /path/to/calibration/raw_histograms.npy +# ============================================================ +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +BENCH_DIR="${SCRIPT_DIR}/../benchmarks" + +# ── Defaults ────────────────────────────────────────────────── +GPU_ID=4 +MODEL_NAME="Qwen/Qwen3-1.7B" +TOPK_VAL=2048 +MEM=0.7 +ALGO="block_sparse_attention" +SAMPLE_STRIDE=1 +SEQ_LEN=65536 +BLOCK_SIZE=16 +BATCH_SIZE=4 +NUM_KV_HEADS=8 +DISTRIBUTIONS="normal bucket_uniform" +# Modes 1 (LUT_CDF) and 2 (Quantile) are evaluated only if calibration +# produces lut.npy / quantiles.npy. The shell script detects that below. +MAPPING_MODES="0 1 2 3 6 7 8 9 10 11 13" +# Fallback hparam used only if autotune is explicitly skipped. +MAPPING_HPARAM=0.5 +REPEAT=100 +WARMUP=20 +# Empty by default — Step 1 will calibrate on the selected model. +# Pass --real-histograms /path/to/raw_histograms.npy to skip calibration. 
+REAL_HISTOGRAMS="" +SKIP_AUTOTUNE=0 + +# ── Parse arguments ─────────────────────────────────────────── +while [[ $# -gt 0 ]]; do + case "$1" in + --model-name) MODEL_NAME="$2"; shift 2 ;; + --topk-val) TOPK_VAL="$2"; shift 2 ;; + --mem) MEM="$2"; shift 2 ;; + --gpu) GPU_ID="$2"; shift 2 ;; + --algo) ALGO="$2"; shift 2 ;; + --real-histograms) REAL_HISTOGRAMS="$2"; shift 2 ;; + --sample-stride) SAMPLE_STRIDE="$2"; shift 2 ;; + --seq-len) SEQ_LEN="$2"; shift 2 ;; + --block-size|--page-size) BLOCK_SIZE="$2"; shift 2 ;; + --batch-size) BATCH_SIZE="$2"; shift 2 ;; + --num-kv-heads) NUM_KV_HEADS="$2"; shift 2 ;; + --distributions) DISTRIBUTIONS="$2"; shift 2 ;; + --modes) MAPPING_MODES="$2"; shift 2 ;; + --mapping-hparam) MAPPING_HPARAM="$2"; shift 2 ;; + --repeat) REPEAT="$2"; shift 2 ;; + --warmup) WARMUP="$2"; shift 2 ;; + --skip-autotune) SKIP_AUTOTUNE=1; shift 1 ;; + *) echo "Unknown option: $1"; exit 1 ;; + esac +done + +export CUDA_VISIBLE_DEVICES="${GPU_ID}" + +# Validate seq_len: need pages/seg > topk_val (3 reserved pages) +MIN_SEQ_LEN=$(( (TOPK_VAL + 4) * BLOCK_SIZE )) +if [ "${SEQ_LEN}" -lt "${MIN_SEQ_LEN}" ]; then + echo "ERROR: --seq-len ${SEQ_LEN} too small for --topk-val ${TOPK_VAL} @ --block-size ${BLOCK_SIZE}." 
+ echo " Minimum: ${MIN_SEQ_LEN} (pages/seg must exceed topk_val + 3 reserved pages)" + exit 1 +fi + +RESULTS_DIR="${SCRIPT_DIR}/results" +mkdir -p "${RESULTS_DIR}" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) +MODEL_SLUG="$(echo "${MODEL_NAME}" | tr '/' '_')" +RUN_DIR="${RESULTS_DIR}/remap_bench_${MODEL_SLUG}_topk${TOPK_VAL}_bs${BLOCK_SIZE}_${TIMESTAMP}" +mkdir -p "${RUN_DIR}" + +echo "============================================================" +echo "Remap Function Benchmark" +echo " Model: ${MODEL_NAME}" +echo " Algorithm: ${ALGO}" +echo " TopK: ${TOPK_VAL}" +echo " Block size: ${BLOCK_SIZE}" +echo " Seq len: ${SEQ_LEN} ($(( SEQ_LEN / BLOCK_SIZE )) pages/seg)" +echo " Batch size: ${BATCH_SIZE}" +echo " KV heads: ${NUM_KV_HEADS}" +echo " Distributions: ${DISTRIBUTIONS}" +echo " Mapping modes: ${MAPPING_MODES}" +echo " Fallback hparam: ${MAPPING_HPARAM} (used only when --skip-autotune)" +echo " GPU: ${GPU_ID}" +echo " Sample stride: ${SAMPLE_STRIDE}" +echo " Real histograms: ${REAL_HISTOGRAMS:-}" +echo " Output: ${RUN_DIR}" +echo "============================================================" + +# ── Step 1: Calibrate — collect real-distribution topk histograms ── +# calibrate_topk.py runs the model end-to-end with histogram profiling +# enabled and writes per-segment raw_histograms.npy. The histograms are +# aggregated over every layer and every decode/prefill step so the +# autotune in Step 2 sees the true attention-score distribution. 
+if [ -n "${REAL_HISTOGRAMS}" ]; then + echo "" + echo ">>> Step 1: SKIPPED (using provided --real-histograms ${REAL_HISTOGRAMS})" + REAL_HIST_PATH="${REAL_HISTOGRAMS}" +else + echo "" + echo ">>> Step 1: Calibrating ${MODEL_NAME} — collecting real topk histograms" + CALIBRATION_DIR="${RUN_DIR}/calibration" + mkdir -p "${CALIBRATION_DIR}" + python "${BENCH_DIR}/calibrate_topk.py" \ + --model-name "${MODEL_NAME}" \ + --topk-val "${TOPK_VAL}" \ + --mem "${MEM}" \ + --vortex-module-name "${ALGO}" \ + --output-dir "${CALIBRATION_DIR}" \ + 2>&1 | tee "${RUN_DIR}/step1_calibrate.log" + REAL_HIST_PATH="${CALIBRATION_DIR}/raw_histograms.npy" + echo ">>> Step 1: Done. Calibration saved to ${CALIBRATION_DIR}" +fi + +# Calibration may have produced lut.npy / quantiles.npy for modes 1 and 2. +CALIB_DIR="$(dirname "${REAL_HIST_PATH}")" +LUT_PATH="" +Q_PATH="" +[ -f "${CALIB_DIR}/lut.npy" ] && LUT_PATH="${CALIB_DIR}/lut.npy" +[ -f "${CALIB_DIR}/quantiles.npy" ] && Q_PATH="${CALIB_DIR}/quantiles.npy" +[ -n "${LUT_PATH}" ] && echo " Calibration LUT: ${LUT_PATH}" +[ -n "${Q_PATH}" ] && echo " Calibration quantile: ${Q_PATH}" + +# ── Step 2: Auto-tune hyperparameters by profiled fused-topk latency ── +# For every (mode, hparam) combo in the sweep grid, the autotune runs the +# fused remap+topk kernel on the real histogram and measures end-to-end +# kernel latency with CUDA events. The per-mode hparam with the lowest +# measured topk kernel latency wins. +AUTOTUNE_JSON="${RUN_DIR}/autotune_results.json" +if [ "${SKIP_AUTOTUNE}" -eq 1 ]; then + echo "" + echo ">>> Step 2: SKIPPED (using fallback --mapping-hparam ${MAPPING_HPARAM})" + AUTOTUNE_ARGS="" +else + echo "" + echo ">>> Step 2: Auto-tuning hyperparameters by profiled topk kernel latency" + AUTOTUNE_EXTRA=() + [ -n "${LUT_PATH}" ] && AUTOTUNE_EXTRA+=(--lut-path "${LUT_PATH}") + [ -n "${Q_PATH}" ] && AUTOTUNE_EXTRA+=(--quantiles-path "${Q_PATH}") + PYTHONPATH="${SCRIPT_DIR}/.." 
python "${BENCH_DIR}/autotune_topk_mapping.py" \ + --batch-size "${BATCH_SIZE}" \ + --num-kv-heads "${NUM_KV_HEADS}" \ + --seq-len "${SEQ_LEN}" \ + --topk-val "${TOPK_VAL}" \ + --page-size "${BLOCK_SIZE}" \ + --real-histograms "${REAL_HIST_PATH}" \ + --warmup "${WARMUP}" \ + --repeat "${REPEAT}" \ + --collect-stats \ + "${AUTOTUNE_EXTRA[@]}" \ + --output-json "${AUTOTUNE_JSON}" \ + 2>&1 | tee "${RUN_DIR}/step2_autotune.log" + echo ">>> Step 2: Done. Autotune results saved to ${AUTOTUNE_JSON}" + AUTOTUNE_ARGS="--autotune-json ${AUTOTUNE_JSON}" +fi + +# ── Step 3: Remap benchmark (baseline / fused / remap / split) ── +echo "" +echo ">>> Step 3: Timing remap / topk / fused / baseline with autotuned hparams" +REMAP_JSON="${RUN_DIR}/remap_bench.json" +BENCH_EXTRA=() +[ -n "${LUT_PATH}" ] && BENCH_EXTRA+=(--lut-path "${LUT_PATH}") +[ -n "${Q_PATH}" ] && BENCH_EXTRA+=(--quantiles-path "${Q_PATH}") +PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/bench_topk.py" \ + --remap-bench \ + --batch-sizes "${BATCH_SIZE}" \ + --num-kv-heads "${NUM_KV_HEADS}" \ + --seq-lens "${SEQ_LEN}" \ + --topk-vals "${TOPK_VAL}" \ + --page-size "${BLOCK_SIZE}" \ + --distributions ${DISTRIBUTIONS} \ + --mapping-modes ${MAPPING_MODES} \ + --mapping-hparam "${MAPPING_HPARAM}" \ + ${AUTOTUNE_ARGS} \ + "${BENCH_EXTRA[@]}" \ + --warmup "${WARMUP}" \ + --repeat "${REPEAT}" \ + --output-json "${REMAP_JSON}" \ + 2>&1 | tee "${RUN_DIR}/step3_remap_bench.log" +echo ">>> Step 3: Done. 
Remap bench saved to ${REMAP_JSON}" + +# ── Summary ─────────────────────────────────────────────────── +echo "" +echo "============================================================" +echo "Remap Function Benchmark Complete" +echo " Model: ${MODEL_NAME}" +echo " Block size: ${BLOCK_SIZE}" +echo " All outputs in: ${RUN_DIR}/" +echo " calibration/raw_histograms.npy — real topk distribution (per layer)" +echo " autotune_results.json — latency-ranked mapping hparams" +echo " remap_bench.json — per-config remap/topk/fused/baseline latencies" +echo " step{1,2,3}_*.log — pipeline logs" +echo "============================================================" diff --git a/examples/run_distribution_analysis.sh b/examples/run_distribution_analysis.sh index fcc2ff1..36d4cd4 100755 --- a/examples/run_distribution_analysis.sh +++ b/examples/run_distribution_analysis.sh @@ -20,6 +20,10 @@ # bash run_distribution_analysis.sh --gpu 5 # bash run_distribution_analysis.sh --gpu 5 \ # --real-histograms /path/to/calibration_dir/raw_histograms.npy +# bash run_distribution_analysis.sh --gpu 5 --block-size 16 +# bash run_distribution_analysis.sh --watchdog-timeout 0 # disable calibrate watchdog (fork) +# Models (default: 1.7B + 4B). Override with repeated --model-name: +# bash run_distribution_analysis.sh --model-name Qwen/Qwen3-1.7B --model-name Qwen/Qwen3-4B # ============================================================ # Mapping functions: @@ -42,21 +46,34 @@ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" BENCH_DIR="${SCRIPT_DIR}/../benchmarks" # ── Defaults ────────────────────────────────────────────────── -GPU_ID=4 -MODEL_NAME="Qwen/Qwen3-1.7B" +GPU_ID=7 +# Models to run (full pipeline per model). Override with one or more --model-name. 
+MODEL_NAMES=( "Qwen/Qwen3-1.7B" "Qwen/Qwen3-4B" ) +MODEL_NAMES_USER_SET=0 TOPK_VAL=30 MEM=0.7 ALGO="block_sparse_attention" RADIX_BITS=8 SAMPLE_STRIDE=1 SEQ_LEN=32768 +# KV page / block size (passed to benchmarks as --page-size) +BLOCK_SIZE=16 # The path to the raw_histograms.npy file (set to skip calibration) REAL_HISTOGRAMS="/data/datasets/xinrui/My_Projects/vortex_torch/examples/calibration/raw_histograms.npy" REAL_HISTOGRAMS="" +HAS_WATCHDOG_TIMEOUT=0 +WATCHDOG_TIMEOUT="" # ── Parse arguments ─────────────────────────────────────────── while [[ $# -gt 0 ]]; do case "$1" in - --model-name) MODEL_NAME="$2"; shift 2 ;; + --model-name) + if [ "${MODEL_NAMES_USER_SET}" -eq 0 ]; then + MODEL_NAMES=() + MODEL_NAMES_USER_SET=1 + fi + MODEL_NAMES+=("$2") + shift 2 + ;; --topk-val) TOPK_VAL="$2"; shift 2 ;; --mem) MEM="$2"; shift 2 ;; --gpu) GPU_ID="$2"; shift 2 ;; @@ -65,14 +82,21 @@ while [[ $# -gt 0 ]]; do --radix-bits) RADIX_BITS="$2"; shift 2 ;; --sample-stride) SAMPLE_STRIDE="$2"; shift 2 ;; --seq-len) SEQ_LEN="$2"; shift 2 ;; + --block-size) BLOCK_SIZE="$2"; shift 2 ;; + --watchdog-timeout) HAS_WATCHDOG_TIMEOUT=1; WATCHDOG_TIMEOUT="$2"; shift 2 ;; *) echo "Unknown option: $1"; exit 1 ;; esac done export CUDA_VISIBLE_DEVICES="${GPU_ID}" -# Validate seq_len: need pages/seg > topk_val (page_size=16, reserved=3 pages) -MIN_SEQ_LEN=$(( (TOPK_VAL + 4) * 16 )) +if [ "${#MODEL_NAMES[@]}" -eq 0 ]; then + echo "ERROR: No models in MODEL_NAMES; pass at least one --model-name." + exit 1 +fi + +# Validate seq_len: need pages/seg > topk_val (reserved=3 pages + slack) +MIN_SEQ_LEN=$(( (TOPK_VAL + 4) * BLOCK_SIZE )) if [ "${SEQ_LEN}" -lt "${MIN_SEQ_LEN}" ]; then echo "ERROR: --seq-len ${SEQ_LEN} too small for --topk-val ${TOPK_VAL}." 
echo " Minimum: ${MIN_SEQ_LEN} (pages/seg must exceed topk_val + 3 reserved pages)" @@ -82,140 +106,126 @@ fi RESULTS_DIR="${SCRIPT_DIR}/results" mkdir -p "${RESULTS_DIR}" TIMESTAMP=$(date +%Y%m%d_%H%M%S) -RUN_DIR="${RESULTS_DIR}/dist_analysis_${TIMESTAMP}" -mkdir -p "${RUN_DIR}" echo "============================================================" echo "Bucket Distribution Profiling Pipeline" -echo " Model: ${MODEL_NAME}" +echo " Models (${#MODEL_NAMES[@]}): ${MODEL_NAMES[*]}" echo " Algorithm: ${ALGO}" echo " TopK: ${TOPK_VAL}" +echo " Seq len: ${SEQ_LEN} ($(( SEQ_LEN / BLOCK_SIZE )) pages/seg)" +echo " Block size: ${BLOCK_SIZE} (--page-size in benchmarks)" echo " GPU: ${GPU_ID}" echo " Radix bits: ${RADIX_BITS} ($(( 1 << RADIX_BITS )) bins)" echo " Sample stride: ${SAMPLE_STRIDE}" -echo " Real histograms: ${REAL_HISTOGRAMS:-}" -echo " Output: ${RUN_DIR}" -echo "============================================================" - -# ── Step 1: Calibrate — collect real-data histograms + LUT/quantiles ── -if [ -n "${REAL_HISTOGRAMS}" ]; then - echo "" - echo ">>> Step 1: SKIPPED (using provided --real-histograms ${REAL_HISTOGRAMS})" - REAL_HIST_PATH="${REAL_HISTOGRAMS}" -else - echo "" - echo ">>> Step 1: Calibrating — collecting real-inference histograms" - CALIBRATION_DIR="${RUN_DIR}/calibration" - mkdir -p "${CALIBRATION_DIR}" - python "${BENCH_DIR}/calibrate_topk.py" \ - --model-name "${MODEL_NAME}" \ - --topk-val "${TOPK_VAL}" \ - --mem "${MEM}" \ - --vortex-module-name "${ALGO}" \ - --output-dir "${CALIBRATION_DIR}" \ - 2>&1 | tee "${RUN_DIR}/step1_calibrate.log" - REAL_HIST_PATH="${CALIBRATION_DIR}/raw_histograms.npy" - echo ">>> Step 1: Done. 
Calibration saved to ${CALIBRATION_DIR}" -fi - -# ── Step 2: Auto-tune — sweep hyperparameters ────────────────── -echo "" -echo ">>> Step 2: Auto-tuning hyperparameters (modes 3, 6, 7, 9, 10)" - -AUTOTUNE_JSON="${RUN_DIR}/autotune_results.json" - -# Build autotune data source args -AUTOTUNE_EXTRA_ARGS=() -if [ -n "${REAL_HIST_PATH:-}" ]; then - AUTOTUNE_EXTRA_ARGS+=(--real-histograms "${REAL_HIST_PATH}") -fi - -PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/autotune_topk_mapping.py" \ - --topk-val "${TOPK_VAL}" \ - --batch-size 4 \ - --seq-len ${SEQ_LEN} \ - --num-kv-heads 2 \ - "${AUTOTUNE_EXTRA_ARGS[@]}" \ - --output-json "${AUTOTUNE_JSON}" \ - 2>&1 | tee "${RUN_DIR}/step2_autotune.log" - -echo ">>> Step 2: Done. Autotune results saved to ${AUTOTUNE_JSON}" - -# ── Step 3: Histogram profiling (bucket_uniform + normal) ───── -echo "" -echo ">>> Step 3: Kernel-level histogram profiling (bucket_uniform + normal)" - -BENCH_JSON="${RUN_DIR}/bench_distribution.json" - -# Build optional args for bench_topk.py -BENCH_EXTRA_ARGS=() -if [ -n "${REAL_HIST_PATH:-}" ]; then - BENCH_EXTRA_ARGS+=(--real-histograms "${REAL_HIST_PATH}") -fi - -# Derive calibration directory from histogram path to find lut.npy / quantiles.npy -CALIB_DIR="$(dirname "${REAL_HIST_PATH}")" -LUT_FILE="${CALIB_DIR}/lut.npy" -QUANTILES_FILE="${CALIB_DIR}/quantiles.npy" - -if [ -f "${LUT_FILE}" ]; then - BENCH_EXTRA_ARGS+=(--lut-path "${LUT_FILE}") - echo " Using LUT for mode 1: ${LUT_FILE}" -else - echo " WARNING: ${LUT_FILE} not found — mode 1 (LUT CDF) will be skipped" -fi -if [ -f "${QUANTILES_FILE}" ]; then - BENCH_EXTRA_ARGS+=(--quantiles-path "${QUANTILES_FILE}") - echo " Using quantiles for mode 2: ${QUANTILES_FILE}" +if [ "${HAS_WATCHDOG_TIMEOUT}" -eq 1 ]; then + echo " Watchdog (cal): ${WATCHDOG_TIMEOUT}s (0 = off, vortex SGLang fork)" else - echo " WARNING: ${QUANTILES_FILE} not found — mode 2 (Quantile) will be skipped" + echo " Watchdog (cal): " fi +echo " Real histograms: 
${REAL_HISTOGRAMS:-}" +echo " Run id: ${TIMESTAMP}" +echo " Output root: ${RESULTS_DIR}/dist_analysis__${TIMESTAMP}/" +echo "============================================================" -PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/bench_topk.py" \ - --batch-sizes 4 \ - --seq-lens ${SEQ_LEN} \ - --topk-vals "${TOPK_VAL}" \ - --num-kv-heads 8 \ - --distributions bucket_uniform normal \ - --histogram \ - --counters \ - "${BENCH_EXTRA_ARGS[@]}" \ - --autotune-json "${AUTOTUNE_JSON}" \ - --filter-kernels naive sglang_ori sglang_m0 sglang_scale sglang_m1 sglang_m2 sglang_m3 sglang_m3_noscale sglang_m4 sglang_m6 sglang_m6_noscale sglang_m7 sglang_m7_noscale sglang_m8 sglang_m9 sglang_m9_noscale sglang_m10 sglang_m10_noscale sglang_m11 sglang_m13 sglang_m13_noscale sglang_m14 \ - --radix-bits "${RADIX_BITS}" \ - --sample-stride "${SAMPLE_STRIDE}" \ - --repeat 20 \ - --output-json "${BENCH_JSON}" \ - 2>&1 | tee "${RUN_DIR}/step3_bench.log" - -echo ">>> Step 3: Done. Results saved to ${BENCH_JSON}" - -# ── Step 4: Analyze — comparison plots + tables ─────────────── -echo "" -echo ">>> Step 4: Generating distribution comparison plots + tables" +for MODEL_NAME in "${MODEL_NAMES[@]}"; do + MODEL_SLUG="${MODEL_NAME//\//_}" + RUN_DIR="${RESULTS_DIR}/dist_analysis_${MODEL_SLUG}_${TIMESTAMP}" + mkdir -p "${RUN_DIR}" -# Build optional args for analyze -ANALYZE_EXTRA_ARGS=() -if [ -n "${REAL_HIST_PATH:-}" ]; then - ANALYZE_EXTRA_ARGS+=(--real-histograms "${REAL_HIST_PATH}") -fi + echo "" + echo "############################ MODEL: ${MODEL_NAME} ############################" + echo " Output: ${RUN_DIR}" + + # ── Step 1: Calibrate — collect real-data histograms + LUT/quantiles ── + if [ -n "${REAL_HISTOGRAMS}" ]; then + echo "" + echo ">>> Step 1: SKIPPED (using provided --real-histograms ${REAL_HISTOGRAMS})" + REAL_HIST_PATH="${REAL_HISTOGRAMS}" + else + echo "" + echo ">>> Step 1: Calibrating — collecting real-inference histograms" + CALIBRATION_DIR="${RUN_DIR}/calibration" + 
mkdir -p "${CALIBRATION_DIR}" + CALIB_EXTRA_ARGS=() + if [ "${HAS_WATCHDOG_TIMEOUT}" -eq 1 ]; then + CALIB_EXTRA_ARGS+=(--watchdog-timeout "${WATCHDOG_TIMEOUT}") + fi + python "${BENCH_DIR}/calibrate_topk.py" \ + --model-name "${MODEL_NAME}" \ + --topk-val "${TOPK_VAL}" \ + --mem "${MEM}" \ + --vortex-module-name "${ALGO}" \ + --page-size "${BLOCK_SIZE}" \ + --output-dir "${CALIBRATION_DIR}" \ + "${CALIB_EXTRA_ARGS[@]}" \ + 2>&1 | tee "${RUN_DIR}/step1_calibrate.log" + REAL_HIST_PATH="${CALIBRATION_DIR}/raw_histograms.npy" + echo ">>> Step 1: Done. Calibration saved to ${CALIBRATION_DIR}" + fi + + # Pick up lut.npy / quantiles.npy if calibration produced them. + CALIB_DIR="$(dirname "${REAL_HIST_PATH}")" + LUT_PATH="" + Q_PATH="" + [ -f "${CALIB_DIR}/lut.npy" ] && LUT_PATH="${CALIB_DIR}/lut.npy" + [ -f "${CALIB_DIR}/quantiles.npy" ] && Q_PATH="${CALIB_DIR}/quantiles.npy" + [ -n "${LUT_PATH}" ] && echo " Calibration LUT: ${LUT_PATH}" + [ -n "${Q_PATH}" ] && echo " Calibration quantile: ${Q_PATH}" + + # ── Step 2: Auto-tune — rank by fused-topk kernel latency ────── + echo "" + echo ">>> Step 2: Auto-tuning hyperparameters by fused-topk kernel latency" -python "${BENCH_DIR}/analyze_topk_distribution.py" \ - --bench-json "${BENCH_JSON}" \ - "${ANALYZE_EXTRA_ARGS[@]}" \ - --output-dir "${RUN_DIR}" \ - 2>&1 | tee "${RUN_DIR}/step4_analyze.log" + AUTOTUNE_JSON="${RUN_DIR}/autotune_results.json" + AUTOTUNE_EXTRA=(--real-histograms "${REAL_HIST_PATH}") + [ -n "${LUT_PATH}" ] && AUTOTUNE_EXTRA+=(--lut-path "${LUT_PATH}") + [ -n "${Q_PATH}" ] && AUTOTUNE_EXTRA+=(--quantiles-path "${Q_PATH}") -echo ">>> Step 4: Done." + PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/autotune_topk_mapping.py" \ + --topk-val "${TOPK_VAL}" \ + --batch-size 4 \ + --seq-len ${SEQ_LEN} \ + --page-size "${BLOCK_SIZE}" \ + --num-kv-heads 2 \ + --collect-stats \ + "${AUTOTUNE_EXTRA[@]}" \ + --output-json "${AUTOTUNE_JSON}" \ + 2>&1 | tee "${RUN_DIR}/step2_autotune.log" + + echo ">>> Step 2: Done. 
Autotune results saved to ${AUTOTUNE_JSON}" + + # ── Step 3: Remap benchmark with autotuned hparams ────────────── + echo "" + echo ">>> Step 3: Remap benchmark (baseline / fused / remap / split) with autotuned hparams" + + BENCH_JSON="${RUN_DIR}/remap_bench.json" + BENCH_EXTRA=() + [ -n "${LUT_PATH}" ] && BENCH_EXTRA+=(--lut-path "${LUT_PATH}") + [ -n "${Q_PATH}" ] && BENCH_EXTRA+=(--quantiles-path "${Q_PATH}") + + PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/bench_topk.py" \ + --remap-bench \ + --batch-sizes 4 \ + --num-kv-heads 8 \ + --seq-lens ${SEQ_LEN} \ + --topk-vals "${TOPK_VAL}" \ + --page-size "${BLOCK_SIZE}" \ + --distributions bucket_uniform normal \ + --mapping-modes 0 1 2 3 6 7 8 9 10 11 13 \ + --autotune-json "${AUTOTUNE_JSON}" \ + "${BENCH_EXTRA[@]}" \ + --repeat 20 \ + --output-json "${BENCH_JSON}" \ + 2>&1 | tee "${RUN_DIR}/step3_bench.log" + + echo ">>> Step 3: Done. Remap bench saved to ${BENCH_JSON}" +done # ── Summary ─────────────────────────────────────────────────── echo "" echo "============================================================" echo "Bucket Distribution Profiling Complete" -echo " All outputs in: ${RUN_DIR}/" -echo " autotune_results.json — hyperparameter sweep rankings" -echo " bench_distribution.json — raw benchmark data" -echo " distribution_comparison.png — bucket dist plots" -echo " bucket_counts.csv — per-bucket count table" -echo " step{1,2,3,4}_*.log — pipeline logs" +echo " Per-model outputs under ${RESULTS_DIR}/ (run id ${TIMESTAMP}):" +echo " dist_analysis__${TIMESTAMP}/" +echo " autotune_results.json, bench_distribution.json, plots, CSV, logs" echo "============================================================" diff --git a/examples/run_distribution_analysis_new.sh b/examples/run_distribution_analysis_new.sh index 65e4f41..ec72665 100755 --- a/examples/run_distribution_analysis_new.sh +++ b/examples/run_distribution_analysis_new.sh @@ -1,27 +1,31 @@ #!/usr/bin/env bash # 
============================================================ -# Bucket Distribution Profiling Pipeline (modes 3, 6, 7 only) +# Bucket Distribution / Remap Latency Pipeline (parametric modes) # -# Tests only the parametric mapping modes with auto-tuning: -# Mode 3 (Power): y = sign(x) * |x|^p -# Mode 6 (Asinh): y = asinh(beta * x) -# Mode 7 (Log1p): y = sign(x) * log1p(alpha * |x|) -# Mode 8 (Trunc8): bf16 upper-8-bit bucketing -# Mode 9 (Erf): y = erf(alpha * x) -# Mode 10 (Tanh): y = tanh(alpha * x) -# Mode 11 (Subtract): x - pivot (RadiK-style scatter) +# Tests the surviving parametric mapping modes after the lean +# refactor: +# Mode 3 (Power): y = sign(x) * |x|^p +# Mode 6 (Asinh): y = asinh(beta * x) +# Mode 7 (Log1p): y = sign(x) * log1p(alpha * |x|) +# Mode 9 (Erf): y = erf(alpha * x) +# Mode 10 (Tanh): y = tanh(alpha * x) +# Mode 13 (ExpStretch): y = exp(alpha * x) # -# Four steps: -# 1. Calibrate — collect real-data histograms -# (skippable via --real-histograms PATH) -# 2. Auto-tune — sweep hyperparameters on synthetic data -# 3. Bench — histogram profiling (bucket_uniform + normal) -# 4. Analyze — comparison plots + bucket count tables +# Pipeline: +# 1. Calibrate — collect real-distribution histograms from the +# chosen model (skippable via --real-histograms). +# 2. Autotune — rank per-mode hparams by measured fused-topk +# kernel latency (lowest wins). +# 3. Remap bench— bench_topk.py --remap-bench fed with the +# autotune JSON. Reports per-mode remap / topk / +# fused / baseline latencies and threshold stats. 
# # Usage: # bash run_distribution_analysis_new.sh --gpu 5 # bash run_distribution_analysis_new.sh --gpu 5 \ -# --real-histograms /path/to/calibration_dir/raw_histograms.npy +# --model-name Qwen/Qwen3-8B --block-size 32 +# bash run_distribution_analysis_new.sh --gpu 5 \ +# --real-histograms /path/to/raw_histograms.npy # ============================================================ set -euo pipefail @@ -34,65 +38,78 @@ MODEL_NAME="Qwen/Qwen3-1.7B" TOPK_VAL=2048 MEM=0.7 ALGO="block_sparse_attention" -RADIX_BITS=8 -SAMPLE_STRIDE=1 SEQ_LEN=65536 -# The path to the raw_histograms.npy file (set to skip calibration) -REAL_HISTOGRAMS="/data/datasets/xinrui/My_Projects/vortex_torch/examples/calibration/raw_histograms.npy" -# REAL_HISTOGRAMS="" -# ── Parse arguments ─────────────────────────────────────────── +BLOCK_SIZE=16 +BATCH_SIZE=4 +NUM_KV_HEADS=8 +DISTRIBUTIONS="bucket_uniform normal" +# LUT_CDF (1) / QUANTILE (2) are evaluated only when calibration produces +# lut.npy / quantiles.npy. 0 baseline is always included by --remap-bench. 
+MAPPING_MODES="1 2 3 6 7 8 9 10 11 13" +REPEAT=100 +WARMUP=20 +REAL_HISTOGRAMS="" + while [[ $# -gt 0 ]]; do case "$1" in - --model-name) MODEL_NAME="$2"; shift 2 ;; - --topk-val) TOPK_VAL="$2"; shift 2 ;; - --mem) MEM="$2"; shift 2 ;; - --gpu) GPU_ID="$2"; shift 2 ;; - --algo) ALGO="$2"; shift 2 ;; - --real-histograms) REAL_HISTOGRAMS="$2"; shift 2 ;; - --radix-bits) RADIX_BITS="$2"; shift 2 ;; - --sample-stride) SAMPLE_STRIDE="$2"; shift 2 ;; - --seq-len) SEQ_LEN="$2"; shift 2 ;; + --model-name) MODEL_NAME="$2"; shift 2 ;; + --topk-val) TOPK_VAL="$2"; shift 2 ;; + --mem) MEM="$2"; shift 2 ;; + --gpu) GPU_ID="$2"; shift 2 ;; + --algo) ALGO="$2"; shift 2 ;; + --real-histograms) REAL_HISTOGRAMS="$2"; shift 2 ;; + --seq-len) SEQ_LEN="$2"; shift 2 ;; + --block-size|--page-size) BLOCK_SIZE="$2"; shift 2 ;; + --batch-size) BATCH_SIZE="$2"; shift 2 ;; + --num-kv-heads) NUM_KV_HEADS="$2"; shift 2 ;; + --distributions) DISTRIBUTIONS="$2"; shift 2 ;; + --modes) MAPPING_MODES="$2"; shift 2 ;; + --repeat) REPEAT="$2"; shift 2 ;; + --warmup) WARMUP="$2"; shift 2 ;; *) echo "Unknown option: $1"; exit 1 ;; esac done export CUDA_VISIBLE_DEVICES="${GPU_ID}" -# Validate seq_len: need pages/seg > topk_val (page_size=16, reserved=3 pages) -MIN_SEQ_LEN=$(( (TOPK_VAL + 4) * 16 )) +MIN_SEQ_LEN=$(( (TOPK_VAL + 4) * BLOCK_SIZE )) if [ "${SEQ_LEN}" -lt "${MIN_SEQ_LEN}" ]; then - echo "ERROR: --seq-len ${SEQ_LEN} too small for --topk-val ${TOPK_VAL}." - echo " Minimum: ${MIN_SEQ_LEN} (pages/seg must exceed topk_val + 3 reserved pages)" + echo "ERROR: --seq-len ${SEQ_LEN} too small for --topk-val ${TOPK_VAL} @ --block-size ${BLOCK_SIZE}." 
+ echo " Minimum: ${MIN_SEQ_LEN}" exit 1 fi RESULTS_DIR="${SCRIPT_DIR}/results" mkdir -p "${RESULTS_DIR}" TIMESTAMP=$(date +%Y%m%d_%H%M%S) -RUN_DIR="${RESULTS_DIR}/dist_analysis_topk${TOPK_VAL}_${TIMESTAMP}" +MODEL_SLUG="$(echo "${MODEL_NAME}" | tr '/' '_')" +RUN_DIR="${RESULTS_DIR}/dist_analysis_${MODEL_SLUG}_topk${TOPK_VAL}_bs${BLOCK_SIZE}_${TIMESTAMP}" mkdir -p "${RUN_DIR}" echo "============================================================" -echo "Bucket Distribution Profiling (modes 3, 6, 7, 8, 9, 10, 11)" +echo "Bucket Distribution / Remap Latency Pipeline (parametric modes)" echo " Model: ${MODEL_NAME}" echo " Algorithm: ${ALGO}" echo " TopK: ${TOPK_VAL}" -echo " Seq len: ${SEQ_LEN} ($(( SEQ_LEN / 16 )) pages/seg)" +echo " Block size: ${BLOCK_SIZE}" +echo " Seq len: ${SEQ_LEN} ($(( SEQ_LEN / BLOCK_SIZE )) pages/seg)" +echo " Batch size: ${BATCH_SIZE}" +echo " KV heads: ${NUM_KV_HEADS}" +echo " Distributions: ${DISTRIBUTIONS}" +echo " Mapping modes: ${MAPPING_MODES}" echo " GPU: ${GPU_ID}" -echo " Radix bits: ${RADIX_BITS} ($(( 1 << RADIX_BITS )) bins)" -echo " Sample stride: ${SAMPLE_STRIDE}" -echo " Real histograms: ${REAL_HISTOGRAMS:-}" +echo " Real histograms: ${REAL_HISTOGRAMS:-}" echo " Output: ${RUN_DIR}" echo "============================================================" -# ── Step 1: Calibrate — collect real-data histograms + LUT/quantiles ── +# ── Step 1: Calibrate ─────────────────────────────────────────── if [ -n "${REAL_HISTOGRAMS}" ]; then echo "" echo ">>> Step 1: SKIPPED (using provided --real-histograms ${REAL_HISTOGRAMS})" REAL_HIST_PATH="${REAL_HISTOGRAMS}" else echo "" - echo ">>> Step 1: Calibrating — collecting real-inference histograms" + echo ">>> Step 1: Calibrating ${MODEL_NAME} — collecting real topk histograms" CALIBRATION_DIR="${RUN_DIR}/calibration" mkdir -p "${CALIBRATION_DIR}" python "${BENCH_DIR}/calibrate_topk.py" \ @@ -100,74 +117,75 @@ else --topk-val "${TOPK_VAL}" \ --mem "${MEM}" \ --vortex-module-name "${ALGO}" \ + 
--page-size "${BLOCK_SIZE}" \ --output-dir "${CALIBRATION_DIR}" \ 2>&1 | tee "${RUN_DIR}/step1_calibrate.log" REAL_HIST_PATH="${CALIBRATION_DIR}/raw_histograms.npy" echo ">>> Step 1: Done. Calibration saved to ${CALIBRATION_DIR}" fi -# ── Step 2: Auto-tune — sweep hyperparameters on synthetic data ───── -echo "" -echo ">>> Step 2: Auto-tuning hyperparameters (modes 3, 6, 7)" +# Pick up lut.npy / quantiles.npy if calibration produced them. +CALIB_DIR="$(dirname "${REAL_HIST_PATH}")" +LUT_PATH="" +Q_PATH="" +[ -f "${CALIB_DIR}/lut.npy" ] && LUT_PATH="${CALIB_DIR}/lut.npy" +[ -f "${CALIB_DIR}/quantiles.npy" ] && Q_PATH="${CALIB_DIR}/quantiles.npy" +# ── Step 2: Autotune (latency-ranked) ─────────────────────────── +echo "" +echo ">>> Step 2: Auto-tuning hyperparameters by fused-topk kernel latency" AUTOTUNE_JSON="${RUN_DIR}/autotune_results.json" - +AUTOTUNE_EXTRA=() +[ -n "${LUT_PATH}" ] && AUTOTUNE_EXTRA+=(--lut-path "${LUT_PATH}") +[ -n "${Q_PATH}" ] && AUTOTUNE_EXTRA+=(--quantiles-path "${Q_PATH}") PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/autotune_topk_mapping.py" \ --topk-val "${TOPK_VAL}" \ - --batch-size 4 \ - --seq-len ${SEQ_LEN} \ - --num-kv-heads 8 \ + --batch-size "${BATCH_SIZE}" \ + --num-kv-heads "${NUM_KV_HEADS}" \ + --seq-len "${SEQ_LEN}" \ + --page-size "${BLOCK_SIZE}" \ --real-histograms "${REAL_HIST_PATH}" \ - --latency-rerank \ + --warmup "${WARMUP}" \ + --repeat "${REPEAT}" \ + --collect-stats \ + "${AUTOTUNE_EXTRA[@]}" \ --output-json "${AUTOTUNE_JSON}" \ 2>&1 | tee "${RUN_DIR}/step2_autotune.log" - echo ">>> Step 2: Done. 
Autotune results saved to ${AUTOTUNE_JSON}" -# ── Step 3: Histogram profiling (bucket_uniform + normal) ───── +# ── Step 3: Remap bench with autotuned hparams ────────────────── echo "" -echo ">>> Step 3: Kernel-level histogram profiling (modes 3, 6, 7)" - -BENCH_JSON="${RUN_DIR}/bench_distribution.json" - +echo ">>> Step 3: Remap benchmark (baseline / fused / remap / split) with autotuned hparams" +BENCH_JSON="${RUN_DIR}/remap_bench.json" +BENCH_EXTRA=() +[ -n "${LUT_PATH}" ] && BENCH_EXTRA+=(--lut-path "${LUT_PATH}") +[ -n "${Q_PATH}" ] && BENCH_EXTRA+=(--quantiles-path "${Q_PATH}") PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/bench_topk.py" \ - --batch-sizes 4 \ - --seq-lens ${SEQ_LEN} \ + --remap-bench \ + --batch-sizes "${BATCH_SIZE}" \ + --num-kv-heads "${NUM_KV_HEADS}" \ + --seq-lens "${SEQ_LEN}" \ --topk-vals "${TOPK_VAL}" \ - --num-kv-heads 8 \ - --distributions bucket_uniform normal \ - --histogram \ - --counters \ - --real-histograms "${REAL_HIST_PATH}" \ + --page-size "${BLOCK_SIZE}" \ + --distributions ${DISTRIBUTIONS} \ + --mapping-modes ${MAPPING_MODES} \ --autotune-json "${AUTOTUNE_JSON}" \ - --filter-kernels naive sglang_ori sglang_m0 sglang_scale sglang_m3 sglang_m3_noscale sglang_m6 sglang_m6_noscale sglang_m7 sglang_m7_noscale sglang_m8 sglang_m9 sglang_m9_noscale sglang_m10 sglang_m10_noscale sglang_m11 sglang_m13 sglang_m13_noscale sglang_m14 \ - --radix-bits "${RADIX_BITS}" \ - --sample-stride "${SAMPLE_STRIDE}" \ - --repeat 20 \ + "${BENCH_EXTRA[@]}" \ + --warmup "${WARMUP}" \ + --repeat "${REPEAT}" \ --output-json "${BENCH_JSON}" \ 2>&1 | tee "${RUN_DIR}/step3_bench.log" +echo ">>> Step 3: Done. Remap bench saved to ${BENCH_JSON}" -echo ">>> Step 3: Done. 
Results saved to ${BENCH_JSON}" - -# ── Step 4: Analyze — comparison plots + tables ─────────────── -echo "" -echo ">>> Step 4: Generating distribution comparison plots + tables" - -python "${BENCH_DIR}/analyze_topk_distribution.py" \ - --bench-json "${BENCH_JSON}" \ - --real-histograms "${REAL_HIST_PATH}" \ - --output-dir "${RUN_DIR}" \ - 2>&1 | tee "${RUN_DIR}/step4_analyze.log" -echo ">>> Step 4: Done." - -# ── Summary ─────────────────────────────────────────────────── +# ── Summary ───────────────────────────────────────────────────── echo "" echo "============================================================" -echo "Bucket Distribution Profiling Complete (modes 3, 6, 7, 8, 9, 10, 11)" +echo "Bucket Distribution / Remap Latency Pipeline Complete" +echo " Model: ${MODEL_NAME}" +echo " Block size: ${BLOCK_SIZE}" echo " All outputs in: ${RUN_DIR}/" -echo " autotune_results.json — hyperparameter sweep rankings" -echo " bench_distribution.json — raw benchmark data" -echo " distribution_comparison.png — bucket dist plots" -echo " bucket_counts.csv — per-bucket count table" -echo " step{1,2,3,4}_*.log — pipeline logs" +echo " calibration/raw_histograms.npy — real topk distribution" +echo " autotune_results.json — latency-ranked hparams" +echo " remap_bench.json — remap/topk/fused/baseline latencies" +echo " step{1,2,3}_*.log — pipeline logs" echo "============================================================" diff --git a/examples/test_topk.py b/examples/test_topk.py new file mode 100644 index 0000000..01edc7b --- /dev/null +++ b/examples/test_topk.py @@ -0,0 +1,118 @@ +import torch +import triton +# topk_output_sglang expects sparse_kv_indptr before dense_kv_indices (unlike topk_output). 
+from vortex_torch_C import topk_output_sglang as topk_output + +SEQ_LENS = [4096] +BATCH_SIZES = [256] + +K = 32 +RESERVE_BOS = 0 +RESERVE_EOS = 0 +DEVICE = "cuda" + + +def make_inputs(batch_size, seq_len, k, reserve_bos, reserve_eos, device="cuda"): + dense_kv_indptr = torch.arange( + 0, batch_size * seq_len + 1, seq_len, dtype=torch.int32, device=device + ) + + dense_kv_indices = torch.arange( + 0, batch_size * seq_len, dtype=torch.int32, device=device + ) + + scores = torch.randn( + batch_size * seq_len, dtype=torch.bfloat16, device=device + ) + + # ✅ Fixed CSR-style sparse indptr + sparse_kv_indptr = torch.arange( + 0, batch_size * k + 1, k, dtype=torch.int32, device=device + ) + + sparse_kv_indices = torch.empty( + batch_size * k, dtype=torch.int32, device=device + ) + + return ( + scores, + dense_kv_indptr, + dense_kv_indices, + sparse_kv_indptr, + sparse_kv_indices, + ) + + +def bench_one(batch_size, seq_len, k, reserve_bos, reserve_eos): + ( + scores, + dense_kv_indptr, + dense_kv_indices, + sparse_kv_indptr, + sparse_kv_indices, + ) = make_inputs( + batch_size=batch_size, + seq_len=seq_len, + k=k, + reserve_bos=reserve_bos, + reserve_eos=reserve_eos, + device=DEVICE, + ) + + def fn(): + topk_output( + scores, + dense_kv_indptr, + sparse_kv_indptr, + dense_kv_indices, + sparse_kv_indices, + batch_size, + k, + reserve_bos, + reserve_eos, + seq_len, + ) + + # warmup + for _ in range(10): + fn() + torch.cuda.synchronize() + + ms = triton.testing.do_bench( + fn, + warmup=100, + rep=1000, + return_mode="mean", + ) + return ms + + +def main(): + torch.cuda.init() + + results = {} + + for bs in BATCH_SIZES: + results[bs] = {} + for seq_len in SEQ_LENS: + ms = bench_one( + batch_size=bs, + seq_len=seq_len, + k=K, + reserve_bos=RESERVE_BOS, + reserve_eos=RESERVE_EOS, + ) + results[bs][seq_len] = ms + print(f"bs={bs:>3}, seq_len={seq_len:>4} -> {ms:.6f} ms") + + print("\nLatency table (ms):") + header = "bs\\seq".ljust(10) + "".join(f"{s:>12}" for s in SEQ_LENS) + 
print(header) + + for bs in BATCH_SIZES: + row = f"{bs:<10}" + "".join(f"{results[bs][s]:>12.4f}" for s in SEQ_LENS) + print(row) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/verify_algo.py b/examples/verify_algo.py index a1d1b6f..a78f1e6 100644 --- a/examples/verify_algo.py +++ b/examples/verify_algo.py @@ -120,8 +120,6 @@ def verify_algos( topk_type: str = "naive", topk_mapping_mode: int = 0, topk_mapping_hparam: float = 0.5, -topk_mapping_lut_path: str = None, -topk_mapping_quantiles_path: str = None, disable_cuda_graph: bool = False, benchmark: str = "amc23", ): @@ -143,8 +141,6 @@ def verify_algos( vortex_topk_type=topk_type, vortex_topk_mapping_mode=topk_mapping_mode, vortex_topk_mapping_hparam=topk_mapping_hparam, - vortex_topk_mapping_lut_path=topk_mapping_lut_path, - vortex_topk_mapping_quantiles_path=topk_mapping_quantiles_path, ) tokenizer = AutoTokenizer.from_pretrained(model_name) if benchmark != "amc23" else None prompts, requests = _load_benchmark(benchmark, trials, tokenizer=tokenizer) @@ -300,15 +296,17 @@ def parse_args(): "--topk-type", type=str, default="naive", - choices=["naive", "sglang", "sglang_ori"], - help='TopK kernel type: "naive" for topk_output, "sglang" for topk_output_sglang, "sglang_ori" for original sglang baseline (default: "naive").', + choices=["naive", "sglang", "sglang_fused"], + help='TopK kernel type: "naive" (CUB radix), "sglang" (unmapped baseline), "sglang_fused" (fused remap + topk). 
Default: "naive".', ) parser.add_argument( "--topk-mapping-mode", type=int, default=0, - choices=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], - help='TopK mapping mode: 0=none, 1=lut_cdf, 2=quantile, 3=power, 4=log, 6=asinh, 7=log1p, 8=trunc8, 9=erf, 10=tanh, 11=subtract, 12=adaptive_tail_window, 13=exp_stretch, 14=topk_window (default: 0).', + choices=[0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 13], + help='TopK mapping mode for sglang_fused: 0=none, 1=lut_cdf (calibrated), ' + '2=quantile (calibrated), 3=power, 4=log, 6=asinh, 7=log1p, 8=trunc8, ' + '9=erf, 10=tanh, 11=subtract, 13=exp_stretch (default: 0).', ) parser.add_argument( @@ -319,20 +317,6 @@ def parse_args(): help='Hyperparameter for parametric modes: power exponent (mode 3), beta (mode 6), alpha (mode 7/9/10/13), rho (mode 12/14). Default: 0.5.', ) - parser.add_argument( - "--topk-mapping-lut-path", - type=str, - default=None, - help="Path to .npy file with uint8[256] LUT for topk mapping mode 1.", - ) - - parser.add_argument( - "--topk-mapping-quantiles-path", - type=str, - default=None, - help="Path to .npy file with float32[256] quantiles for topk mapping mode 2.", - ) - parser.add_argument( "--benchmark", type=str, @@ -366,8 +350,6 @@ def parse_args(): topk_type=args.topk_type, topk_mapping_mode=args.topk_mapping_mode, topk_mapping_hparam=args.topk_mapping_hparam, - topk_mapping_lut_path=args.topk_mapping_lut_path, - topk_mapping_quantiles_path=args.topk_mapping_quantiles_path, benchmark=bench_name, ) summary["benchmark"] = bench_name diff --git a/examples/verify_algo.sh b/examples/verify_algo.sh index ddcd905..7a96d1e 100644 --- a/examples/verify_algo.sh +++ b/examples/verify_algo.sh @@ -23,9 +23,4 @@ TIMESTAMP=$(date +%Y%m%d_%H%M%S) --topk-type naive \ --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" - done - - TORCH_CUDA_ARCH_LIST="12.0" \ - MAX_JOBS=64 \ - pip install -e . 
--no-build-isolation \ - -Ccmake.args="-DENABLE_BELOW_SM90=OFF" \ No newline at end of file + done \ No newline at end of file diff --git a/examples/verify_algo_topk_mapping.sh b/examples/verify_algo_topk_mapping.sh index 9a9f482..711a0f7 100644 --- a/examples/verify_algo_topk_mapping.sh +++ b/examples/verify_algo_topk_mapping.sh @@ -1,26 +1,46 @@ #!/usr/bin/env bash +# ============================================================ +# E2E accuracy comparison: naive baseline + unmapped sglang + +# every surviving parametric mapping mode (3, 4, 6, 7, 9, 10, 13) +# with per-mode hyperparameters picked by autotune_topk_mapping.py +# (ranked by measured fused-topk kernel latency, lowest wins). +# +# Surviving mapping modes after the lean refactor: +# 0: None — unmapped baseline +# 3: Power — y = sign(x) * |x|^p +# 4: Log — y = sign(x) * log(|x| + 1) +# 6: Asinh — y = asinh(beta * x) +# 7: Log1p — y = sign(x) * log1p(alpha * |x|) +# 9: Erf — y = erf(alpha * x) +# 10: Tanh — y = tanh(alpha * x) +# 13: ExpStretch — y = exp(alpha * x) +# ============================================================ set -e -# use CUDA_VISIBLE_DEVICES to set the GPU id you want to use -# Mapping functions: -# 0: None — original fp16 bit-pattern bucketing -# 1: LUT CDF — LUT-based CDF equalization (calibrated) -# 2: Quantile — piecewise-linear quantile mapping (calibrated) -# 3: Power — y = sign(x) * |x|^p -# 4: Log — y = sign(x) * log(|x| + 1) -# 5: Index Cache — reuse previous layer's indices -# 6: Asinh — y = asinh(beta * x) -# 7: Log1p — y = sign(x) * log1p(alpha * |x|) -# 8: Trunc8 — bf16 upper-8-bit bucketing -# 9: Erf — y = erf(alpha * x) -# 10: Tanh — y = tanh(alpha * x) -# 11: Subtract — x - pivot (RadiK-style scatter) + +# ── Defaults ────────────────────────────────────────────────── GPU_ID=0 BENCHMARKS="amc23" +MODEL_NAME="Qwen/Qwen3-1.7B" +TOPK_VAL=30 +BLOCK_SIZE=16 +BATCH_SIZE=4 +NUM_KV_HEADS=2 +SEQ_LEN=32768 +REAL_HISTOGRAMS="" +SKIP_AUTOTUNE=0 while [[ $# -gt 0 ]]; do case "$1" in 
- --gpu) GPU_ID="$2"; shift 2 ;; - --benchmark) BENCHMARKS="$2"; shift 2 ;; + --gpu) GPU_ID="$2"; shift 2 ;; + --benchmark) BENCHMARKS="$2"; shift 2 ;; + --model-name) MODEL_NAME="$2"; shift 2 ;; + --topk-val) TOPK_VAL="$2"; shift 2 ;; + --block-size|--page-size) BLOCK_SIZE="$2"; shift 2 ;; + --batch-size) BATCH_SIZE="$2"; shift 2 ;; + --num-kv-heads) NUM_KV_HEADS="$2"; shift 2 ;; + --seq-len) SEQ_LEN="$2"; shift 2 ;; + --real-histograms) REAL_HISTOGRAMS="$2"; shift 2 ;; + --skip-autotune) SKIP_AUTOTUNE=1; shift 1 ;; *) echo "Unknown option: $1"; exit 1 ;; esac done @@ -30,273 +50,151 @@ export CUDA_VISIBLE_DEVICES="${GPU_ID}" SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" BENCH_DIR="${SCRIPT_DIR}/../benchmarks" -sparse_algos=( - "block_sparse_attention" -) +sparse_algos=( "block_sparse_attention" ) BENCH_LABEL=$(echo "${BENCHMARKS}" | tr ' ' '_') -RESULTS_DIR="results/${BENCH_LABEL}" +MODEL_SLUG="$(echo "${MODEL_NAME}" | tr '/' '_')" +RESULTS_DIR="results/${MODEL_SLUG}_${BENCH_LABEL}" mkdir -p "${RESULTS_DIR}" TIMESTAMP=$(date +%Y%m%d_%H%M%S) -# Set this to an existing calibration directory to skip re-running calibration. -# It must contain lut.npy and quantiles.npy (output of calibrate_topk.py). 
-CALIBRATION_DIR="/data/datasets/xinrui/My_Projects/vortex_torch/examples/calibration" + # ============================================================ -# Baseline: naive topk (mode 0) +# Baseline: naive topk # ============================================================ for algo in "${sparse_algos[@]}"; do OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_naive_${TIMESTAMP}.log" - echo ">>> Running verify_algo.py with --vortex-module-name ${algo} --topk-type naive --topk-mapping-mode 0" - echo ">>> Saving results to ${OUTFILE}" + echo ">>> naive topk algo=${algo}" { time python verify_algo.py \ --trials 8 \ - --topk-val 30 \ + --topk-val "${TOPK_VAL}" \ --vortex-module-name "${algo}" \ - --model-name Qwen/Qwen3-1.7B \ + --model-name "${MODEL_NAME}" \ --topk-type naive \ - --topk-mapping-mode 0 \ --benchmark ${BENCHMARKS} \ --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" done # ============================================================ -# Calibration: collect histograms for LUT/quantile generation -# Skipped if CALIBRATION_DIR already has lut.npy + quantiles.npy +# Calibrate (optional) — real-distribution histograms # ============================================================ -if [ -f "${CALIBRATION_DIR}/lut.npy" ] && [ -f "${CALIBRATION_DIR}/quantiles.npy" ]; then - echo ">>> Calibration SKIPPED (using existing ${CALIBRATION_DIR})" -else - CALIBRATION_DIR="${RESULTS_DIR}/calibration_${TIMESTAMP}" - for algo in "${sparse_algos[@]}"; do - echo ">>> Calibrating for ${algo}..." - python "${BENCH_DIR}/calibrate_topk.py" \ - --model-name Qwen/Qwen3-1.7B \ - --topk-val 30 \ - --mem 0.7 \ - --vortex-module-name "${algo}" \ - --output-dir "${CALIBRATION_DIR}" \ - 2>&1 | tee "${RESULTS_DIR}/calibration_${algo}_${TIMESTAMP}.log" - done +if [ -z "${REAL_HISTOGRAMS}" ]; then + CALIBRATION_DIR="${RESULTS_DIR}/calibration_${TIMESTAMP}" + for algo in "${sparse_algos[@]}"; do + echo ">>> Calibrating ${MODEL_NAME} for ${algo}..." 
+ python "${BENCH_DIR}/calibrate_topk.py" \ + --model-name "${MODEL_NAME}" \ + --topk-val "${TOPK_VAL}" \ + --mem 0.7 \ + --vortex-module-name "${algo}" \ + --page-size "${BLOCK_SIZE}" \ + --output-dir "${CALIBRATION_DIR}" \ + 2>&1 | tee "${RESULTS_DIR}/calibration_${algo}_${TIMESTAMP}.log" + done + REAL_HISTOGRAMS="${CALIBRATION_DIR}/raw_histograms.npy" fi -# ============================================================ -# Auto-tune: find best hyperparameters per mode -# Uses topk_profile_histogram kernel on real calibration data -# ============================================================ -REAL_HISTOGRAMS="${CALIBRATION_DIR}/raw_histograms.npy" -if [ -f "${REAL_HISTOGRAMS}" ]; then - echo "============================================================" - echo "Auto-tuning hyperparameters (real calibration data)" - echo "============================================================" - AUTOTUNE_JSON="${RESULTS_DIR}/autotune_${TIMESTAMP}.json" - PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/autotune_topk_mapping.py" \ - --topk-val 30 \ - --batch-size 4 \ - --seq-len 32768 \ - --num-kv-heads 2 \ - --real-histograms "${REAL_HISTOGRAMS}" \ - --output-json "${AUTOTUNE_JSON}" \ - 2>&1 | tee "${RESULTS_DIR}/autotune_${TIMESTAMP}.log" - echo ">>> Auto-tune results saved to ${AUTOTUNE_JSON}" - echo "" +# Pick up lut.npy / quantiles.npy if calibration produced them. 
+CALIB_DIR="$(dirname "${REAL_HISTOGRAMS}")" +LUT_PATH="" +Q_PATH="" +[ -f "${CALIB_DIR}/lut.npy" ] && LUT_PATH="${CALIB_DIR}/lut.npy" +[ -f "${CALIB_DIR}/quantiles.npy" ] && Q_PATH="${CALIB_DIR}/quantiles.npy" + +# ============================================================ +# Auto-tune — rank by fused-topk kernel latency +# ============================================================ +AUTOTUNE_JSON="${RESULTS_DIR}/autotune_${TIMESTAMP}.json" +if [ "${SKIP_AUTOTUNE}" -eq 0 ]; then + AUTOTUNE_EXTRA=() + [ -n "${LUT_PATH}" ] && AUTOTUNE_EXTRA+=(--lut-path "${LUT_PATH}") + [ -n "${Q_PATH}" ] && AUTOTUNE_EXTRA+=(--quantiles-path "${Q_PATH}") + if [ -f "${REAL_HISTOGRAMS}" ]; then + echo "============================================================" + echo "Auto-tuning hyperparameters (real distribution, latency-ranked)" + echo "============================================================" + PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/autotune_topk_mapping.py" \ + --topk-val "${TOPK_VAL}" \ + --batch-size "${BATCH_SIZE}" \ + --seq-len "${SEQ_LEN}" \ + --num-kv-heads "${NUM_KV_HEADS}" \ + --page-size "${BLOCK_SIZE}" \ + --real-histograms "${REAL_HISTOGRAMS}" \ + "${AUTOTUNE_EXTRA[@]}" \ + --output-json "${AUTOTUNE_JSON}" \ + 2>&1 | tee "${RESULTS_DIR}/autotune_${TIMESTAMP}.log" + echo ">>> Auto-tune results saved to ${AUTOTUNE_JSON}" + else + echo ">>> WARNING: ${REAL_HISTOGRAMS} not found — autotune will use synthetic data" + PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/autotune_topk_mapping.py" \ + --topk-val "${TOPK_VAL}" \ + --batch-size "${BATCH_SIZE}" \ + --seq-len "${SEQ_LEN}" \ + --num-kv-heads "${NUM_KV_HEADS}" \ + --page-size "${BLOCK_SIZE}" \ + "${AUTOTUNE_EXTRA[@]}" \ + --output-json "${AUTOTUNE_JSON}" \ + 2>&1 | tee "${RESULTS_DIR}/autotune_${TIMESTAMP}.log" + fi +fi - # Extract best per-mode hyperparameters from autotune JSON - eval "$(python3 -c " +# Extract best per-mode hparam (ranked by kernel latency, lowest wins). 
+eval "$(python3 -c " import json, sys data = json.load(open(sys.argv[1])) best = {} for r in data: - m = r.get('mode') - if m in (3, 6, 7, 9, 10): - if m not in best or r['gini'] < best[m]['gini']: - best[m] = r -for m in (3, 6, 7, 9, 10): - print(f'BEST_HPARAM_{m}={best[m][\"param\"]}' if m in best else f'BEST_HPARAM_{m}=0.5') + m = r.get('mode'); lat = r.get('latency_ms') + if m is None or lat is None: continue + if m not in best or lat < best[m]['latency_ms']: + best[m] = r +for m in (3, 6, 7, 9, 10, 11, 13): + print(f'BEST_HPARAM_{m}={best.get(m, {}).get(\"param\", 0.5)}') " "${AUTOTUNE_JSON}")" - echo ">>> Autotuned best powers: mode3=${BEST_HPARAM_3} mode6=${BEST_HPARAM_6} mode7=${BEST_HPARAM_7} mode9=${BEST_HPARAM_9} mode10=${BEST_HPARAM_10}" - echo "" -else - echo ">>> WARNING: ${REAL_HISTOGRAMS} not found, using default power=0.5 for all modes" - BEST_HPARAM_3=0.5 - BEST_HPARAM_6=0.5 - BEST_HPARAM_7=0.5 - BEST_HPARAM_9=0.5 - BEST_HPARAM_10=0.5 -fi - -# ============================================================ -# Mode 1: LUT CDF with calibrated LUT -# ============================================================ -for algo in "${sparse_algos[@]}"; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_1_calibrated_${TIMESTAMP}.log" - echo ">>> Running mode 1 (LUT CDF) with calibrated LUT for ${algo}" - echo ">>> Saving results to ${OUTFILE}" - { time python verify_algo.py \ - --trials 8 \ - --topk-val 30 \ - --vortex-module-name "${algo}" \ - --model-name Qwen/Qwen3-1.7B \ - --topk-type sglang \ - --topk-mapping-mode 1 \ - --topk-mapping-lut-path "${CALIBRATION_DIR}/lut.npy" \ - --mem 0.7 ; } \ - 2>&1 | tee "${OUTFILE}" -done - -# ============================================================ -# Mode 2: Quantile with calibrated quantiles -# ============================================================ -for algo in "${sparse_algos[@]}"; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_2_calibrated_${TIMESTAMP}.log" - echo ">>> Running mode 2 
(quantile) with calibrated quantiles for ${algo}" - echo ">>> Saving results to ${OUTFILE}" - { time python verify_algo.py \ - --trials 8 \ - --topk-val 30 \ - --vortex-module-name "${algo}" \ - --model-name Qwen/Qwen3-1.7B \ - --topk-type sglang \ - --topk-mapping-mode 2 \ - --topk-mapping-quantiles-path "${CALIBRATION_DIR}/quantiles.npy" \ - --mem 0.7 ; } \ - 2>&1 | tee "${OUTFILE}" -done - -# ============================================================ -# sglang topk: non-parametric modes (0, 4, 8, 11) -# ============================================================ -for algo in "${sparse_algos[@]}"; do - for topk_mapping_mode in 0 4 8 11; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_${topk_mapping_mode}_${TIMESTAMP}.log" - echo ">>> Running verify_algo.py with --vortex-module-name ${algo} --topk-type sglang --topk-mapping-mode ${topk_mapping_mode}" - echo ">>> Saving results to ${OUTFILE}" +echo ">>> Autotuned hparams (lowest fused-topk latency):" +echo " mode3=${BEST_HPARAM_3} mode6=${BEST_HPARAM_6} mode7=${BEST_HPARAM_7}" +echo " mode9=${BEST_HPARAM_9} mode10=${BEST_HPARAM_10} mode11=${BEST_HPARAM_11} mode13=${BEST_HPARAM_13}" +echo "" +run_mapped() { + # $1=mode $2=hparam $3=label + local mode="$1"; local hp="$2"; local label="$3" + for algo in "${sparse_algos[@]}"; do + local out="${RESULTS_DIR}/topk_mapping_${algo}_${label}_${TIMESTAMP}.log" + echo ">>> ${label} algo=${algo}" + local extra=() + if [ "${mode}" -eq 0 ]; then + extra+=(--topk-type sglang) + else + extra+=(--topk-type sglang_fused --topk-mapping-mode "${mode}" --topk-mapping-hparam "${hp}") + fi { time python verify_algo.py \ --trials 8 \ - --topk-val 30 \ + --topk-val "${TOPK_VAL}" \ --vortex-module-name "${algo}" \ - --model-name Qwen/Qwen3-1.7B \ - --topk-type sglang \ - --topk-mapping-mode ${topk_mapping_mode} \ + --model-name "${MODEL_NAME}" \ --benchmark ${BENCHMARKS} \ - --mem 0.7 ; } \ - 2>&1 | tee "${OUTFILE}" + --mem 0.7 \ + "${extra[@]}" ; } \ + 2>&1 | tee "${out}" done 
-done +} + +run_mapped 0 0.5 "sglang_m0" +run_mapped 3 "${BEST_HPARAM_3}" "sglang_m3_p${BEST_HPARAM_3}" +run_mapped 4 0.5 "sglang_m4" +run_mapped 6 "${BEST_HPARAM_6}" "sglang_m6_beta${BEST_HPARAM_6}" +run_mapped 7 "${BEST_HPARAM_7}" "sglang_m7_alpha${BEST_HPARAM_7}" +run_mapped 8 0.5 "sglang_m8" +run_mapped 9 "${BEST_HPARAM_9}" "sglang_m9_alpha${BEST_HPARAM_9}" +run_mapped 10 "${BEST_HPARAM_10}" "sglang_m10_alpha${BEST_HPARAM_10}" +run_mapped 11 "${BEST_HPARAM_11}" "sglang_m11_pivot${BEST_HPARAM_11}" +run_mapped 13 "${BEST_HPARAM_13}" "sglang_m13_alpha${BEST_HPARAM_13}" -# ============================================================ -# Mode 3: power — autotuned best p -# ============================================================ -for algo in "${sparse_algos[@]}"; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_3_p${BEST_HPARAM_3}_${TIMESTAMP}.log" - echo ">>> Running mode 3 (power) p=${BEST_HPARAM_3} (autotuned) for ${algo}" - echo ">>> Saving results to ${OUTFILE}" - { time python verify_algo.py \ - --trials 8 \ - --topk-val 30 \ - --vortex-module-name "${algo}" \ - --model-name Qwen/Qwen3-1.7B \ - --topk-type sglang \ - --topk-mapping-mode 3 \ - --topk-mapping-hparam ${BEST_HPARAM_3} \ - --mem 0.7 ; } \ - 2>&1 | tee "${OUTFILE}" -done - -# ============================================================ -# Mode 6: asinh — autotuned best beta -# ============================================================ -for algo in "${sparse_algos[@]}"; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_6_beta${BEST_HPARAM_6}_${TIMESTAMP}.log" - echo ">>> Running mode 6 (asinh) beta=${BEST_HPARAM_6} (autotuned) for ${algo}" - echo ">>> Saving results to ${OUTFILE}" - { time python verify_algo.py \ - --trials 8 \ - --topk-val 30 \ - --vortex-module-name "${algo}" \ - --model-name Qwen/Qwen3-1.7B \ - --topk-type sglang \ - --topk-mapping-mode 6 \ - --topk-mapping-hparam ${BEST_HPARAM_6} \ - --mem 0.7 ; } \ - 2>&1 | tee "${OUTFILE}" -done - -# 
============================================================ -# Mode 7: log1p — autotuned best alpha -# ============================================================ -for algo in "${sparse_algos[@]}"; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_7_alpha${BEST_HPARAM_7}_${TIMESTAMP}.log" - echo ">>> Running mode 7 (log1p) alpha=${BEST_HPARAM_7} (autotuned) for ${algo}" - echo ">>> Saving results to ${OUTFILE}" - { time python verify_algo.py \ - --trials 8 \ - --topk-val 30 \ - --vortex-module-name "${algo}" \ - --model-name Qwen/Qwen3-1.7B \ - --topk-type sglang \ - --topk-mapping-mode 7 \ - --topk-mapping-hparam ${BEST_HPARAM_7} \ - --mem 0.7 ; } \ - 2>&1 | tee "${OUTFILE}" -done - -# ============================================================ -# Mode 9: erf — autotuned best alpha -# ============================================================ -for algo in "${sparse_algos[@]}"; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_9_alpha${BEST_HPARAM_9}_${TIMESTAMP}.log" - echo ">>> Running mode 9 (erf) alpha=${BEST_HPARAM_9} (autotuned) for ${algo}" - echo ">>> Saving results to ${OUTFILE}" - { time python verify_algo.py \ - --trials 8 \ - --topk-val 30 \ - --vortex-module-name "${algo}" \ - --model-name Qwen/Qwen3-1.7B \ - --topk-type sglang \ - --topk-mapping-mode 9 \ - --topk-mapping-hparam ${BEST_HPARAM_9} \ - --mem 0.7 ; } \ - 2>&1 | tee "${OUTFILE}" -done - -# ============================================================ -# Mode 10: tanh — autotuned best alpha -# ============================================================ -for algo in "${sparse_algos[@]}"; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_10_alpha${BEST_HPARAM_10}_${TIMESTAMP}.log" - echo ">>> Running mode 10 (tanh) alpha=${BEST_HPARAM_10} (autotuned) for ${algo}" - echo ">>> Saving results to ${OUTFILE}" - { time python verify_algo.py \ - --trials 8 \ - --topk-val 30 \ - --vortex-module-name "${algo}" \ - --model-name Qwen/Qwen3-1.7B \ - --topk-type sglang \ - 
--topk-mapping-mode 10 \ - --topk-mapping-hparam ${BEST_HPARAM_10} \ - --mem 0.7 ; } \ - 2>&1 | tee "${OUTFILE}" -done - -# ============================================================ -# Counter profiling: collect COUNTER_NUM_EQUAL for all modes -# ============================================================ echo "" echo "============================================================" -echo "Counter profiling: COUNTER_NUM_EQUAL per mode (topk=30)" +echo "All runs complete. Results in ${RESULTS_DIR}/" +echo " Model: ${MODEL_NAME}" +echo " Block size: ${BLOCK_SIZE}" +echo " Auto-tune: ${AUTOTUNE_JSON}" echo "============================================================" -COUNTER_JSON="${RESULTS_DIR}/counters_${TIMESTAMP}.json" -PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/bench_topk.py" \ - --batch-sizes 4 \ - --seq-lens 4096 \ - --topk-vals 30 \ - --num-kv-heads 2 \ - --distributions normal \ - --counters \ - --filter-kernels sglang_ori sglang_m0 sglang_m3 sglang_m6 sglang_m7 sglang_m8 sglang_m9 sglang_m10 sglang_m11 \ - --repeat 5 \ - --output-json "${COUNTER_JSON}" \ - 2>&1 | tee "${RESULTS_DIR}/counters_${TIMESTAMP}.log" -echo ">>> Counters saved to ${COUNTER_JSON}" \ No newline at end of file diff --git a/examples/verify_algo_topk_mapping_new.sh b/examples/verify_algo_topk_mapping_new.sh index 6848e1e..9116b72 100644 --- a/examples/verify_algo_topk_mapping_new.sh +++ b/examples/verify_algo_topk_mapping_new.sh @@ -1,370 +1,217 @@ #!/usr/bin/env bash +# ============================================================ +# E2E accuracy sweep over the surviving parametric mapping modes. +# Each mode runs verify_algo.py with the per-mode hyperparameter +# that autotune_topk_mapping.py picked as having the lowest +# measured fused-topk-kernel latency. 
+# +# Mapping modes (after the lean refactor): +# 0: None — unmapped baseline (no remap) +# 3: Power — y = sign(x) * |x|^p +# 4: Log — y = sign(x) * log(|x| + 1) [no knob] +# 6: Asinh — y = asinh(beta * x) +# 7: Log1p — y = sign(x) * log1p(alpha * |x|) +# 9: Erf — y = erf(alpha * x) +# 10: Tanh — y = tanh(alpha * x) +# 13: ExpStretch — y = exp(alpha * x) +# ============================================================ set -e -# use CUDA_VISIBLE_DEVICES to set the GPU id you want to use -# Mapping functions: -# 0: None — original fp16 bit-pattern bucketing -# 1: LUT CDF — LUT-based CDF equalization (calibrated) -# 2: Quantile — piecewise-linear quantile mapping (calibrated) -# 3: Power — y = sign(x) * |x|^p -# 4: Log — y = sign(x) * log(|x| + 1) -# 5: Index Cache — reuse previous layer's indices -# 6: Asinh — y = asinh(beta * x) -# 7: Log1p — y = sign(x) * log1p(alpha * |x|) -# 8: Trunc8 — bf16 upper-8-bit bucketing -# 9: Erf — y = erf(alpha * x) -# 10: Tanh — y = tanh(alpha * x) -# 11: Subtract — x - pivot (RadiK-style scatter) SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" BENCH_DIR="${SCRIPT_DIR}/../benchmarks" # ── Defaults ────────────────────────────────────────────────── GPU_ID=5 TOPK_VAL=30 -BENCHMARKS="amc23" # space-separated list, e.g. 
"amc23 aime24" +BENCHMARKS="amc23" +MODEL_NAME="Qwen/Qwen3-1.7B" +BLOCK_SIZE=16 +BATCH_SIZE=4 +NUM_KV_HEADS=2 +SEQ_LEN=32768 +REAL_HISTOGRAMS="" +SKIP_AUTOTUNE=0 -# ── Parse arguments ─────────────────────────────────────────── while [[ $# -gt 0 ]]; do case "$1" in - --topk-val) TOPK_VAL="$2"; shift 2 ;; - --gpu) GPU_ID="$2"; shift 2 ;; - --benchmark) BENCHMARKS="$2"; shift 2 ;; + --topk-val) TOPK_VAL="$2"; shift 2 ;; + --gpu) GPU_ID="$2"; shift 2 ;; + --benchmark) BENCHMARKS="$2"; shift 2 ;; + --model-name) MODEL_NAME="$2"; shift 2 ;; + --block-size|--page-size) BLOCK_SIZE="$2"; shift 2 ;; + --batch-size) BATCH_SIZE="$2"; shift 2 ;; + --num-kv-heads) NUM_KV_HEADS="$2"; shift 2 ;; + --seq-len) SEQ_LEN="$2"; shift 2 ;; + --real-histograms) REAL_HISTOGRAMS="$2"; shift 2 ;; + --skip-autotune) SKIP_AUTOTUNE=1; shift 1 ;; *) echo "Unknown option: $1"; exit 1 ;; esac done export CUDA_VISIBLE_DEVICES="${GPU_ID}" -sparse_algos=( - "block_sparse_attention" -) - -# Path to real-data histograms from calibration (for auto-tuning) -REAL_HISTOGRAMS="/data/datasets/xinrui/My_Projects/vortex_torch/examples/calibration/raw_histograms.npy" +sparse_algos=( "block_sparse_attention" ) BENCH_LABEL=$(echo "${BENCHMARKS}" | tr ' ' '_') -RESULTS_DIR="results/topk${TOPK_VAL}_${BENCH_LABEL}" +MODEL_SLUG="$(echo "${MODEL_NAME}" | tr '/' '_')" +RESULTS_DIR="results/topk_mapping_${MODEL_SLUG}_topk${TOPK_VAL}_${BENCH_LABEL}" mkdir -p "${RESULTS_DIR}" TIMESTAMP=$(date +%Y%m%d_%H%M%S) # ============================================================ -# Step 0: Auto-tune — find best hyperparameters per mode -# Uses topk_profile_histogram kernel on synthetic data (fast, no model) +# Step 0: Calibrate (optional) — real-distribution histograms # ============================================================ -echo "============================================================" -echo "Step 0: Auto-tuning hyperparameters (synthetic data)" -echo "============================================================" 
-AUTOTUNE_JSON="${RESULTS_DIR}/autotune_${TIMESTAMP}.json" -PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/autotune_topk_mapping.py" \ - --topk-val ${TOPK_VAL} \ - --batch-size 4 \ - --seq-len 32768 \ - --num-kv-heads 2 \ - --real-histograms "${REAL_HISTOGRAMS}" \ - --output-json "${AUTOTUNE_JSON}" \ - 2>&1 | tee "${RESULTS_DIR}/autotune_${TIMESTAMP}.log" -echo ">>> Auto-tune results saved to ${AUTOTUNE_JSON}" -echo "" +if [ -z "${REAL_HISTOGRAMS}" ]; then + echo "============================================================" + echo "Step 0: Calibrating ${MODEL_NAME} for real-distribution histograms" + echo "============================================================" + CAL_DIR="${RESULTS_DIR}/calibration_${TIMESTAMP}" + mkdir -p "${CAL_DIR}" + python "${BENCH_DIR}/calibrate_topk.py" \ + --model-name "${MODEL_NAME}" \ + --topk-val "${TOPK_VAL}" \ + --mem 0.7 \ + --vortex-module-name "${sparse_algos[0]}" \ + --page-size "${BLOCK_SIZE}" \ + --output-dir "${CAL_DIR}" \ + 2>&1 | tee "${RESULTS_DIR}/calibrate_${TIMESTAMP}.log" + REAL_HISTOGRAMS="${CAL_DIR}/raw_histograms.npy" +fi + +# Pick up lut.npy / quantiles.npy if calibration produced them. 
+CALIB_DIR="$(dirname "${REAL_HISTOGRAMS}")" +LUT_PATH="" +Q_PATH="" +[ -f "${CALIB_DIR}/lut.npy" ] && LUT_PATH="${CALIB_DIR}/lut.npy" +[ -f "${CALIB_DIR}/quantiles.npy" ] && Q_PATH="${CALIB_DIR}/quantiles.npy" # ============================================================ -# Extract best per-mode hyperparameters from autotune JSON +# Step 1: Auto-tune — rank by profiled fused-topk kernel latency # ============================================================ +AUTOTUNE_JSON="${RESULTS_DIR}/autotune_${TIMESTAMP}.json" +if [ "${SKIP_AUTOTUNE}" -eq 0 ]; then + echo "============================================================" + echo "Step 1: Auto-tuning hyperparameters by fused-topk kernel latency" + echo "============================================================" + AUTOTUNE_EXTRA=() + [ -n "${LUT_PATH}" ] && AUTOTUNE_EXTRA+=(--lut-path "${LUT_PATH}") + [ -n "${Q_PATH}" ] && AUTOTUNE_EXTRA+=(--quantiles-path "${Q_PATH}") + PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/autotune_topk_mapping.py" \ + --topk-val ${TOPK_VAL} \ + --batch-size ${BATCH_SIZE} \ + --seq-len ${SEQ_LEN} \ + --num-kv-heads ${NUM_KV_HEADS} \ + --page-size ${BLOCK_SIZE} \ + --real-histograms "${REAL_HISTOGRAMS}" \ + "${AUTOTUNE_EXTRA[@]}" \ + --output-json "${AUTOTUNE_JSON}" \ + 2>&1 | tee "${RESULTS_DIR}/autotune_${TIMESTAMP}.log" + echo ">>> Auto-tune results saved to ${AUTOTUNE_JSON}" +fi + +# Extract best per-mode hparam (ranked by measured kernel latency, lowest wins) eval "$(python3 -c " import json, sys data = json.load(open(sys.argv[1])) best = {} for r in data: m = r.get('mode') - if m in (3, 6, 7, 9, 10, 13, 14): - if m not in best or r['gini'] < best[m]['gini']: - best[m] = r -for m in (3, 6, 7, 9, 10, 13, 14): - print(f'BEST_HPARAM_{m}={best[m][\"param\"]}' if m in best else f'BEST_HPARAM_{m}=0.5') + lat = r.get('latency_ms') + if m is None or lat is None: continue + if m not in best or lat < best[m]['latency_ms']: + best[m] = r +for m in (3, 6, 7, 9, 10, 11, 13): + v = 
best.get(m, {}).get('param', 0.5) + print(f'BEST_HPARAM_{m}={v}') " "${AUTOTUNE_JSON}")" -echo ">>> Autotuned best powers: mode3=${BEST_HPARAM_3} mode6=${BEST_HPARAM_6} mode7=${BEST_HPARAM_7} mode9=${BEST_HPARAM_9} mode10=${BEST_HPARAM_10} mode13=${BEST_HPARAM_13} mode14=${BEST_HPARAM_14}" +echo ">>> Autotuned hparams (lowest topk kernel latency):" +echo " mode3=${BEST_HPARAM_3} mode6=${BEST_HPARAM_6} mode7=${BEST_HPARAM_7}" +echo " mode9=${BEST_HPARAM_9} mode10=${BEST_HPARAM_10} mode11=${BEST_HPARAM_11} mode13=${BEST_HPARAM_13}" echo "" -# ============================================================ -# Baseline: Original sglang kernel (no remap) -# ============================================================ -echo "============================================================" -echo "Baseline: sglang_ori (no remap)" -echo "============================================================" -for algo in "${sparse_algos[@]}"; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_ori_${TIMESTAMP}.log" - echo ">>> sglang_ori algo=${algo}" - { time python verify_algo.py \ - --trials 8 \ - --topk-val ${TOPK_VAL} \ - --vortex-module-name "${algo}" \ - --model-name Qwen/Qwen3-1.7B \ - --topk-type sglang_ori \ - --benchmark ${BENCHMARKS} \ - --mem 0.7 ; } \ - 2>&1 | tee "${OUTFILE}" -done +run_verify() { + # $1=mode $2=hparam $3=label + local mode="$1"; local hp="$2"; local label="$3" + for algo in "${sparse_algos[@]}"; do + local out="${RESULTS_DIR}/topk_mapping_${algo}_${label}_${TIMESTAMP}.log" + echo ">>> ${label} algo=${algo}" + local extra_args=() + if [ "${mode}" -eq 0 ]; then + extra_args+=(--topk-type sglang) + else + extra_args+=(--topk-type sglang_fused --topk-mapping-mode "${mode}" --topk-mapping-hparam "${hp}") + fi + { time python verify_algo.py \ + --trials 8 \ + --topk-val "${TOPK_VAL}" \ + --vortex-module-name "${algo}" \ + --model-name "${MODEL_NAME}" \ + --benchmark ${BENCHMARKS} \ + --mem 0.7 \ + "${extra_args[@]}" ; } \ + 2>&1 | tee "${out}" + done +} -# 
============================================================ -# Step 1: Mode 3 (power) — autotuned best p -# ============================================================ echo "============================================================" -echo "Step 1: Mode 3 (power) — p=${BEST_HPARAM_3} (autotuned)" +echo "Baseline: sglang (no remap)" echo "============================================================" -for algo in "${sparse_algos[@]}"; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_3_p${BEST_HPARAM_3}_${TIMESTAMP}.log" - echo ">>> Mode 3 (power) p=${BEST_HPARAM_3} algo=${algo}" - { time python verify_algo.py \ - --trials 8 \ - --topk-val ${TOPK_VAL} \ - --vortex-module-name "${algo}" \ - --model-name Qwen/Qwen3-1.7B \ - --topk-type sglang \ - --topk-mapping-mode 3 \ - --topk-mapping-hparam ${BEST_HPARAM_3} \ - --benchmark ${BENCHMARKS} \ - --mem 0.7 ; } \ - 2>&1 | tee "${OUTFILE}" -done +run_verify 0 0.5 "sglang_m0" -# ============================================================ -# Step 2: Mode 6 (asinh) — autotuned best beta -# ============================================================ echo "============================================================" -echo "Step 2: Mode 6 (asinh) — beta=${BEST_HPARAM_6} (autotuned)" +echo "Mode 3 (power) — p=${BEST_HPARAM_3} (autotuned)" echo "============================================================" -for algo in "${sparse_algos[@]}"; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_6_beta${BEST_HPARAM_6}_${TIMESTAMP}.log" - echo ">>> Mode 6 (asinh) beta=${BEST_HPARAM_6} algo=${algo}" - { time python verify_algo.py \ - --trials 8 \ - --topk-val ${TOPK_VAL} \ - --vortex-module-name "${algo}" \ - --model-name Qwen/Qwen3-1.7B \ - --topk-type sglang \ - --topk-mapping-mode 6 \ - --topk-mapping-hparam ${BEST_HPARAM_6} \ - --benchmark ${BENCHMARKS} \ - --mem 0.7 ; } \ - 2>&1 | tee "${OUTFILE}" -done +run_verify 3 "${BEST_HPARAM_3}" "sglang_m3_p${BEST_HPARAM_3}" -# 
============================================================ -# Step 3: Mode 7 (log1p) — autotuned best alpha -# ============================================================ echo "============================================================" -echo "Step 3: Mode 7 (log1p) — alpha=${BEST_HPARAM_7} (autotuned)" +echo "Mode 4 (log)" echo "============================================================" -for algo in "${sparse_algos[@]}"; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_7_alpha${BEST_HPARAM_7}_${TIMESTAMP}.log" - echo ">>> Mode 7 (log1p) alpha=${BEST_HPARAM_7} algo=${algo}" - { time python verify_algo.py \ - --trials 8 \ - --topk-val ${TOPK_VAL} \ - --vortex-module-name "${algo}" \ - --model-name Qwen/Qwen3-1.7B \ - --topk-type sglang \ - --topk-mapping-mode 7 \ - --topk-mapping-hparam ${BEST_HPARAM_7} \ - --benchmark ${BENCHMARKS} \ - --mem 0.7 ; } \ - 2>&1 | tee "${OUTFILE}" -done +run_verify 4 0.5 "sglang_m4" -# ============================================================ -# Step 4: Mode 8 (trunc8) — fixed parameter -# ============================================================ echo "============================================================" -echo "Step 4: Mode 8 (trunc8)" +echo "Mode 6 (asinh) — beta=${BEST_HPARAM_6} (autotuned)" echo "============================================================" -for algo in "${sparse_algos[@]}"; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_8_${TIMESTAMP}.log" - echo ">>> Mode 8 (trunc8) algo=${algo}" - { time python verify_algo.py \ - --trials 8 \ - --topk-val ${TOPK_VAL} \ - --vortex-module-name "${algo}" \ - --model-name Qwen/Qwen3-1.7B \ - --topk-type sglang \ - --topk-mapping-mode 8 \ - --benchmark ${BENCHMARKS} \ - --mem 0.7 ; } \ - 2>&1 | tee "${OUTFILE}" -done +run_verify 6 "${BEST_HPARAM_6}" "sglang_m6_beta${BEST_HPARAM_6}" -# ============================================================ -# Step 5: Mode 9 (erf) — autotuned best alpha -# 
============================================================ echo "============================================================" -echo "Step 5: Mode 9 (erf) — alpha=${BEST_HPARAM_9} (autotuned)" +echo "Mode 7 (log1p) — alpha=${BEST_HPARAM_7} (autotuned)" echo "============================================================" -for algo in "${sparse_algos[@]}"; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_9_alpha${BEST_HPARAM_9}_${TIMESTAMP}.log" - echo ">>> Mode 9 (erf) alpha=${BEST_HPARAM_9} algo=${algo}" - { time python verify_algo.py \ - --trials 8 \ - --topk-val ${TOPK_VAL} \ - --vortex-module-name "${algo}" \ - --model-name Qwen/Qwen3-1.7B \ - --topk-type sglang \ - --topk-mapping-mode 9 \ - --topk-mapping-hparam ${BEST_HPARAM_9} \ - --benchmark ${BENCHMARKS} \ - --mem 0.7 ; } \ - 2>&1 | tee "${OUTFILE}" -done +run_verify 7 "${BEST_HPARAM_7}" "sglang_m7_alpha${BEST_HPARAM_7}" -# ============================================================ -# Step 6: Mode 10 (tanh) — autotuned best alpha -# ============================================================ echo "============================================================" -echo "Step 6: Mode 10 (tanh) — alpha=${BEST_HPARAM_10} (autotuned)" +echo "Mode 9 (erf) — alpha=${BEST_HPARAM_9} (autotuned)" echo "============================================================" -for algo in "${sparse_algos[@]}"; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_10_alpha${BEST_HPARAM_10}_${TIMESTAMP}.log" - echo ">>> Mode 10 (tanh) alpha=${BEST_HPARAM_10} algo=${algo}" - { time python verify_algo.py \ - --trials 8 \ - --topk-val ${TOPK_VAL} \ - --vortex-module-name "${algo}" \ - --model-name Qwen/Qwen3-1.7B \ - --topk-type sglang \ - --topk-mapping-mode 10 \ - --topk-mapping-hparam ${BEST_HPARAM_10} \ - --benchmark ${BENCHMARKS} \ - --mem 0.7 ; } \ - 2>&1 | tee "${OUTFILE}" -done +run_verify 9 "${BEST_HPARAM_9}" "sglang_m9_alpha${BEST_HPARAM_9}" -# ============================================================ -# Step 
7: Mode 11 (subtract) — fixed parameter -# ============================================================ echo "============================================================" -echo "Step 7: Mode 11 (subtract)" +echo "Mode 10 (tanh) — alpha=${BEST_HPARAM_10} (autotuned)" echo "============================================================" -for algo in "${sparse_algos[@]}"; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_11_${TIMESTAMP}.log" - echo ">>> Mode 11 (subtract) algo=${algo}" - { time python verify_algo.py \ - --trials 8 \ - --topk-val ${TOPK_VAL} \ - --vortex-module-name "${algo}" \ - --model-name Qwen/Qwen3-1.7B \ - --topk-type sglang \ - --topk-mapping-mode 11 \ - --benchmark ${BENCHMARKS} \ - --mem 0.7 ; } \ - 2>&1 | tee "${OUTFILE}" -done +run_verify 10 "${BEST_HPARAM_10}" "sglang_m10_alpha${BEST_HPARAM_10}" -# ============================================================ -# Step 8: Mode 12 (adaptive_tail_window), rho=4.0 -# ============================================================ -echo "" echo "============================================================" -echo "Step 8: Mode 12 (adaptive_tail_window), rho=4.0" +echo "Mode 8 (trunc8)" echo "============================================================" -for algo in "${sparse_algos[@]}"; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_12_${TIMESTAMP}.log" - echo ">>> Mode 12 (adaptive_tail_window) algo=${algo}" - { time python verify_algo.py \ - --trials 8 \ - --topk-val ${TOPK_VAL} \ - --vortex-module-name "${algo}" \ - --model-name Qwen/Qwen3-1.7B \ - --topk-type sglang \ - --topk-mapping-mode 12 \ - --topk-mapping-hparam 4.0 \ - --benchmark ${BENCHMARKS} \ - --mem 0.7 ; } \ - 2>&1 | tee "${OUTFILE}" -done +run_verify 8 0.5 "sglang_m8" -# ============================================================ -# Step 9: Mode 13 (exp_stretch) — autotuned best alpha -# ============================================================ -echo "" -echo 
"============================================================" -echo "Step 9: Mode 13 (exp_stretch) — alpha=${BEST_HPARAM_13} (autotuned)" -echo "============================================================" -for algo in "${sparse_algos[@]}"; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_13_alpha${BEST_HPARAM_13}_${TIMESTAMP}.log" - echo ">>> Mode 13 (exp_stretch) alpha=${BEST_HPARAM_13} algo=${algo}" - { time python verify_algo.py \ - --trials 8 \ - --topk-val ${TOPK_VAL} \ - --vortex-module-name "${algo}" \ - --model-name Qwen/Qwen3-1.7B \ - --topk-type sglang \ - --topk-mapping-mode 13 \ - --topk-mapping-hparam ${BEST_HPARAM_13} \ - --benchmark ${BENCHMARKS} \ - --mem 0.7 ; } \ - 2>&1 | tee "${OUTFILE}" -done - -# ============================================================ -# Step 10: Mode 14 (topk_window) — autotuned best rho -# ============================================================ -echo "" echo "============================================================" -echo "Step 10: Mode 14 (topk_window) — rho=${BEST_HPARAM_14} (autotuned)" +echo "Mode 11 (subtract) — pivot=${BEST_HPARAM_11} (autotuned)" echo "============================================================" -for algo in "${sparse_algos[@]}"; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_14_rho${BEST_HPARAM_14}_${TIMESTAMP}.log" - echo ">>> Mode 14 (topk_window) rho=${BEST_HPARAM_14} algo=${algo}" - { time python verify_algo.py \ - --trials 8 \ - --topk-val ${TOPK_VAL} \ - --vortex-module-name "${algo}" \ - --model-name Qwen/Qwen3-1.7B \ - --topk-type sglang \ - --topk-mapping-mode 14 \ - --topk-mapping-hparam ${BEST_HPARAM_14} \ - --benchmark ${BENCHMARKS} \ - --mem 0.7 ; } \ - 2>&1 | tee "${OUTFILE}" -done +run_verify 11 "${BEST_HPARAM_11}" "sglang_m11_pivot${BEST_HPARAM_11}" -# ============================================================ -# Counter profiling: collect COUNTER_NUM_EQUAL for all modes -# (single extra kernel call per mode, no overhead on accuracy runs) -# 
============================================================ -echo "" echo "============================================================" -echo "Counter profiling: COUNTER_NUM_EQUAL per mode (topk=${TOPK_VAL})" +echo "Mode 13 (exp_stretch) — alpha=${BEST_HPARAM_13} (autotuned)" echo "============================================================" -COUNTER_JSON="${RESULTS_DIR}/counters_${TIMESTAMP}.json" -PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/bench_topk.py" \ - --batch-sizes 4 \ - --seq-lens 4096 \ - --topk-vals ${TOPK_VAL} \ - --num-kv-heads 2 \ - --distributions normal \ - --counters \ - --real-histograms "${REAL_HISTOGRAMS}" \ - --autotune-json "${AUTOTUNE_JSON}" \ - --filter-kernels sglang_ori sglang_m0 sglang_m3 sglang_m6 sglang_m7 sglang_m8 sglang_m9 sglang_m10 sglang_m11 sglang_m13 sglang_m14 \ - --mapping-hparam-13 ${BEST_HPARAM_13} --mapping-hparam-14 ${BEST_HPARAM_14} \ - --repeat 5 \ - --output-json "${COUNTER_JSON}" \ - 2>&1 | tee "${RESULTS_DIR}/counters_${TIMESTAMP}.log" -echo ">>> Counters saved to ${COUNTER_JSON}" +run_verify 13 "${BEST_HPARAM_13}" "sglang_m13_alpha${BEST_HPARAM_13}" -# ============================================================ -# Summary -# ============================================================ echo "" echo "============================================================" echo "All runs complete. 
Results in ${RESULTS_DIR}/" -echo " Auto-tune: ${AUTOTUNE_JSON}" -echo " Counters: ${COUNTER_JSON}" +echo " Model: ${MODEL_NAME}" +echo " Block size: ${BLOCK_SIZE}" +echo " Auto-tune: ${AUTOTUNE_JSON}" echo " Mode 3 (power): p = ${BEST_HPARAM_3} (autotuned)" echo " Mode 6 (asinh): beta = ${BEST_HPARAM_6} (autotuned)" echo " Mode 7 (log1p): alpha = ${BEST_HPARAM_7} (autotuned)" -echo " Mode 8 (trunc8): (fixed)" echo " Mode 9 (erf): alpha = ${BEST_HPARAM_9} (autotuned)" echo " Mode 10 (tanh): alpha = ${BEST_HPARAM_10} (autotuned)" -echo " Mode 11 (subtract): (fixed)" -echo " Mode 12 (tail_win): rho = 4.0" echo " Mode 13 (exp_stretch):alpha = ${BEST_HPARAM_13} (autotuned)" -echo " Mode 14 (topk_window):rho = ${BEST_HPARAM_14} (autotuned)" echo "============================================================" diff --git a/examples/verify_sparse_backends.sh b/examples/verify_external_backends.sh similarity index 100% rename from examples/verify_sparse_backends.sh rename to examples/verify_external_backends.sh diff --git a/vortex_torch/indexer/context.py b/vortex_torch/indexer/context.py index 8142fbc..17dea66 100644 --- a/vortex_torch/indexer/context.py +++ b/vortex_torch/indexer/context.py @@ -1,6 +1,5 @@ from __future__ import annotations from typing import Any, Final, Union -import numpy as np import torch from ..abs import ContextBase from ..utils import UNSET, Mode @@ -24,8 +23,7 @@ class Context(ContextBase): "num_sms", "page_size", "max_num_pages", "max_num_pages_per_request", # misc "indexer_dtype", "topk_val", "page_reserved_bos", "page_reserved_eos", "topk_type", - "topk_mapping_mode", "topk_mapping_power", "topk_mapping_lut", "topk_mapping_quantiles", - "topk_mapping_noscale", + "topk_mapping_mode", "topk_mapping_power", "topk_histogram_enabled", # auxilary memory in graph @@ -72,13 +70,10 @@ class Context(ContextBase): topk_val: int #: Top-K value used in pruning or selection. page_reserved_bos: int #: Reserved page count for BOS (begin-of-sequence). 
page_reserved_eos: int #: Reserved page count for EOS (end-of-sequence). - topk_type: str #: TopK kernel type: "naive" or "sglang". - topk_mapping_mode: int #: TopK mapping mode (0=none, 1=lut, 2=quantile, 3=power, 4=log). - topk_mapping_power: float #: Power exponent for mapping mode 3. - topk_mapping_lut: object #: Optional uint8[256] LUT tensor for mapping mode 1. - topk_mapping_quantiles: object #: Optional float32[256] quantiles tensor for mapping mode 2. - topk_mapping_noscale: bool #: Skip auto-range linear scaling, use fp16 bucketing on f(x) (default False). - topk_histogram_enabled: bool #: Enable histogram profiling during inference (default False). + topk_type: str #: TopK kernel type: "naive", "sglang" (unmapped) or "sglang_fused" (remap+topk). + topk_mapping_mode: int #: TopK mapping mode for sglang_fused (0=none, 3=power, 4=log, 6=asinh, 7=log1p, 9=erf, 10=tanh, 13=exp_stretch). + topk_mapping_power: float #: Hyperparameter (p / alpha / beta) for the active mapping mode. + topk_histogram_enabled: bool #: Enable histogram profiling during inference (default False). # --- auxiliary --- _aux_total_bytes: int #: Accumulated auxiliary memory in bytes. 
@@ -158,22 +153,10 @@ def create(self, parent: Any, model_runner: Any, *, overwrite: bool = False) -> self.topk_type = getattr(sa, "vortex_topk_type", "naive") self.topk_mapping_mode = getattr(sa, "vortex_topk_mapping_mode", 0) self.topk_mapping_power = getattr(sa, "vortex_topk_mapping_power", 0.5) - self.topk_mapping_noscale = getattr(sa, "vortex_topk_mapping_noscale", False) self.topk_histogram_enabled = getattr(sa, "vortex_topk_histogram", False) device = getattr(model_runner, "device", "cpu") - # Load calibration data from .npy files when paths are provided - lut_path = getattr(sa, 'vortex_topk_mapping_lut_path', None) - if lut_path is not None: - lut_np = np.load(lut_path).astype(np.uint8) - self.topk_mapping_lut = torch.from_numpy(lut_np).to(device) - - quantiles_path = getattr(sa, 'vortex_topk_mapping_quantiles_path', None) - if quantiles_path is not None: - q_np = np.load(quantiles_path).astype(np.float32) - self.topk_mapping_quantiles = torch.from_numpy(q_np).to(device) - self.max_num_workloads = ( (self.max_num_pages // max(1, sa.vortex_lb_min_chunk_size)) + max_bs * self.num_kv_heads ) diff --git a/vortex_torch/indexer/output_func.py b/vortex_torch/indexer/output_func.py index e4424cd..889e068 100644 --- a/vortex_torch/indexer/output_func.py +++ b/vortex_torch/indexer/output_func.py @@ -1,10 +1,9 @@ import torch from typing import Dict, Callable, List, Optional from ..abs import vOp -from vortex_torch_C import topk_output, topk_output_sglang, topk_output_sglang_ori, topk_profile_histogram +from vortex_torch_C import topk_output, topk_output_sglang, topk_output_sglang_fused, topk_profile_histogram from .context import Context from ..abs import vTensor, FORMAT -from ..utils import UNSET # --- Module-level histogram accumulator for offline calibration --- _calibration_histograms: List[torch.Tensor] = [] @@ -91,7 +90,7 @@ class topK(vOp): FORMAT.RAGGED: { "naive": topk_output, "sglang": topk_output_sglang, - "sglang_ori": topk_output_sglang_ori, + 
"sglang_fused": topk_output_sglang_fused, }, } @@ -245,17 +244,7 @@ def execute(self, x: torch.Tensor, o: torch.Tensor, ctx: Context) -> torch.Tenso assert self.impl is not None, f"{prefix}execute called before profile() (impl is None)" if self.topk_type == "sglang": - # topk_output_sglang: (x, dense_kv_indptr, sparse_kv_indptr, dense_kv_indices, sparse_kv_indices, ...) - mapping_mode = getattr(ctx, 'topk_mapping_mode', 0) - mapping_hparam = getattr(ctx, 'topk_mapping_hparam', getattr(ctx, 'topk_mapping_power', 0.5)) - mapping_lut = getattr(ctx, 'topk_mapping_lut', None) - mapping_quantiles = getattr(ctx, 'topk_mapping_quantiles', None) - mapping_noscale = getattr(ctx, 'topk_mapping_noscale', False) - # UNSET sentinel is not a valid torch.Tensor — coerce to None - if mapping_lut is UNSET: - mapping_lut = None - if mapping_quantiles is UNSET: - mapping_quantiles = None + # topk_output_sglang: unmapped baseline (no remap). self.impl( x, ctx.dense_kv_indptr, @@ -267,14 +256,14 @@ def execute(self, x: torch.Tensor, o: torch.Tensor, ctx: Context) -> torch.Tenso ctx.page_reserved_bos, ctx.page_reserved_eos, ctx.max_num_pages_per_request, - mapping_mode, - mapping_hparam, - mapping_lut, - mapping_quantiles, - mapping_noscale, ) - elif self.topk_type == "sglang_ori": - # topk_output_sglang_ori: same CSR interface, no mapping params + elif self.topk_type == "sglang_fused": + # topk_output_sglang_fused: single-launch fused remap + topk. + mapping_mode = getattr(ctx, 'topk_mapping_mode', 0) + mapping_power = getattr( + ctx, 'topk_mapping_hparam', + getattr(ctx, 'topk_mapping_power', 0.5), + ) self.impl( x, ctx.dense_kv_indptr, @@ -286,6 +275,8 @@ def execute(self, x: torch.Tensor, o: torch.Tensor, ctx: Context) -> torch.Tenso ctx.page_reserved_bos, ctx.page_reserved_eos, ctx.max_num_pages_per_request, + int(mapping_mode), + float(mapping_power), ) else: # topk_output (naive): (x, dense_kv_indptr, dense_kv_indices, sparse_kv_indptr, sparse_kv_indices, ...) 
@@ -307,11 +298,19 @@ def execute(self, x: torch.Tensor, o: torch.Tensor, ctx: Context) -> torch.Tenso # are not permitted while a stream is being captured. if ( getattr(ctx, 'topk_histogram_enabled', False) - and self.topk_type == "sglang" + and self.topk_type in ("sglang", "sglang_fused") and not torch.cuda.is_current_stream_capturing() ): eff_bs = ctx.batch_size * ctx.num_kv_heads self.last_histograms = torch.zeros(eff_bs, 256, dtype=torch.int32, device=x.device) + hist_mode = 0 + hist_power = 0.5 + if self.topk_type == "sglang_fused": + hist_mode = int(getattr(ctx, 'topk_mapping_mode', 0)) + hist_power = float(getattr( + ctx, 'topk_mapping_hparam', + getattr(ctx, 'topk_mapping_power', 0.5), + )) topk_profile_histogram( x, ctx.dense_kv_indptr, @@ -319,12 +318,8 @@ def execute(self, x: torch.Tensor, o: torch.Tensor, ctx: Context) -> torch.Tenso eff_bs, ctx.page_reserved_bos, ctx.page_reserved_eos, - mapping_mode, - mapping_hparam, - mapping_lut, - mapping_quantiles, - mapping_noscale, - ctx.topk_val, + hist_mode, + hist_power, ) # Accumulate histograms for offline calibration _calibration_histograms.append(self.last_histograms.cpu().clone()) From aecde11194f8ce9a55a4790d27ffc7e46fe0bb8b Mon Sep 17 00:00:00 2001 From: UED Date: Mon, 13 Apr 2026 03:15:11 -0400 Subject: [PATCH 19/22] Refactor TopK mapping and benchmarking scripts for enhanced profiling and usability - Updated autotune_topk_mapping.py to optimize hyperparameter tuning based on kernel latency. - Simplified the sweep grid and improved documentation for usage. - Enhanced bench_topk.py to expose public helpers and added CLI modes for benchmarking. - Introduced new remap functions and improved kernel integration for profiling. - Added watchdog timeout option in calibrate_topk.py for SGLang scheduler. - Removed outdated greedy_layer_search.py as part of code cleanup. 
--- examples/run_topk_benchmark.sh | 332 +++++++++++++-------------------- 1 file changed, 128 insertions(+), 204 deletions(-) diff --git a/examples/run_topk_benchmark.sh b/examples/run_topk_benchmark.sh index d57e2f1..33c6e40 100755 --- a/examples/run_topk_benchmark.sh +++ b/examples/run_topk_benchmark.sh @@ -1,54 +1,47 @@ #!/usr/bin/env bash # ============================================================ -# TopK Benchmark +# Unified TopK Benchmark # -# Compares ALL TopK kernel variants under controlled conditions: -# Step 1: Calibrate (for modes 1/2) -# Step 2: Kernel-level latency (bench_topk.py, all 6 modes) -# Step 3: E2E accuracy (verify_algo.py) -# - Full-attention baseline first -# - Then naive, sglang mode 0/1/2/3/4 -# - Same model, same prompts, deterministic sampling -# -# Fairness improvements over verify_algo_topk_mapping.sh: -# - Full-attention baseline for absolute reference -# - All modes in one sweep (including calibrated 1/2) -# - Sequential runs on same CUDA device minimize interference -# - Deterministic sampling (temperature=0) for reproducibility -# - Results saved to a single timestamped directory +# Three-step pipeline on a single configurable model: +# Step 1: Calibrate — run the model to collect +# real-distribution histograms +# (raw_histograms.npy, lut.npy, +# quantiles.npy). +# Step 2: Latency autotune + bench — rank per-mode hparams by +# measured fused-topk kernel +# latency, then run the +# remap / topk / fused / baseline +# comparison. +# Step 3: E2E accuracy — verify_algo.py on the same +# model for the unmapped baseline +# plus each mapping mode, with +# autotuned hparams. 
# # Usage: -# bash run_topk_benchmark.sh [OPTIONS] -# -# Options: -# --model-name NAME HuggingFace model (default: Qwen/Qwen3-1.7B) -# --topk-val K Top-k value (default: 30) -# --trials N E2E trial count (default: 8) -# --mem FRAC GPU memory fraction (default: 0.7) -# --gpu GPU_ID CUDA device (default: 0) -# --algo NAME Sparse attention algorithm (default: block_sparse_attention) -# --skip-calibrate Reuse existing calibration data -# --skip-kernel Skip kernel-level benchmark (step 2) -# --skip-e2e Skip E2E accuracy benchmark (step 3) +# bash run_topk_benchmark.sh --gpu 0 +# bash run_topk_benchmark.sh --gpu 0 --model-name Qwen/Qwen3-8B \ +# --block-size 32 --topk-val 512 # ============================================================ set -euo pipefail -# use GPU_ID to set the GPU id you want to use -GPU_ID=4 - SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" BENCH_DIR="${SCRIPT_DIR}/../benchmarks" # ── Defaults ────────────────────────────────────────────────── +GPU_ID=4 MODEL_NAME="Qwen/Qwen3-1.7B" TOPK_VAL=30 TRIALS=8 MEM=0.7 ALGO="block_sparse_attention" +BLOCK_SIZE=16 +BATCH_SIZE=4 +NUM_KV_HEADS=8 +SEQ_LEN=32768 +BENCHMARKS="amc23" SKIP_CALIBRATE=false SKIP_KERNEL=false SKIP_E2E=true -BENCHMARKS="amc23" # space-separated list, e.g. 
"amc23 aime24" # ── Parse arguments ─────────────────────────────────────────── while [[ $# -gt 0 ]]; do @@ -60,9 +53,13 @@ while [[ $# -gt 0 ]]; do --gpu) GPU_ID="$2"; shift 2 ;; --algo) ALGO="$2"; shift 2 ;; --benchmark) BENCHMARKS="$2"; shift 2 ;; + --block-size|--page-size) BLOCK_SIZE="$2"; shift 2 ;; + --batch-size) BATCH_SIZE="$2"; shift 2 ;; + --num-kv-heads) NUM_KV_HEADS="$2"; shift 2 ;; + --seq-len) SEQ_LEN="$2"; shift 2 ;; --skip-calibrate) SKIP_CALIBRATE=true; shift ;; --skip-kernel) SKIP_KERNEL=true; shift ;; - --skip-e2e) SKIP_E2E=true; shift ;; + --skip-e2e) SKIP_E2E=false; shift ;; # --skip-e2e actually toggles it OFF (enables) *) echo "Unknown option: $1"; exit 1 ;; esac done @@ -73,123 +70,128 @@ RESULTS_DIR="${SCRIPT_DIR}/results" mkdir -p "${RESULTS_DIR}" TIMESTAMP=$(date +%Y%m%d_%H%M%S) BENCH_LABEL=$(echo "${BENCHMARKS}" | tr ' ' '_') -RUN_DIR="${RESULTS_DIR}/topk_benchmark_${BENCH_LABEL}_${TIMESTAMP}" +MODEL_SLUG="$(echo "${MODEL_NAME}" | tr '/' '_')" +RUN_DIR="${RESULTS_DIR}/topk_benchmark_${MODEL_SLUG}_${BENCH_LABEL}_${TIMESTAMP}" mkdir -p "${RUN_DIR}" echo "============================================================" -echo "Fair Unified TopK Benchmark" -echo " Model: ${MODEL_NAME}" -echo " Algorithm: ${ALGO}" -echo " TopK: ${TOPK_VAL}" -echo " Trials: ${TRIALS}" -echo " GPU: ${GPU_ID}" -echo " Output: ${RUN_DIR}" +echo "Unified TopK Benchmark" +echo " Model: ${MODEL_NAME}" +echo " Algorithm: ${ALGO}" +echo " TopK: ${TOPK_VAL}" +echo " Block size: ${BLOCK_SIZE}" +echo " Seq len: ${SEQ_LEN}" +echo " Batch size: ${BATCH_SIZE}" +echo " KV heads: ${NUM_KV_HEADS}" +echo " Trials: ${TRIALS}" +echo " GPU: ${GPU_ID}" +echo " Output: ${RUN_DIR}" echo "============================================================" -# ── Step 1: Calibrate (for modes 1/2) ──────────────────────── +# ── Step 1: Calibrate ──────────────────────────────────────── CALIBRATION_DIR="${RUN_DIR}/calibration" if [ "${SKIP_CALIBRATE}" = true ] && [ -d "${CALIBRATION_DIR}" ]; then 
echo "" echo ">>> Step 1: SKIPPED (--skip-calibrate)" else echo "" - echo ">>> Step 1: Calibrating — collecting histograms for LUT/quantile modes" + echo ">>> Step 1: Calibrating ${MODEL_NAME} — real topk histograms + LUT/quantiles" mkdir -p "${CALIBRATION_DIR}" python "${BENCH_DIR}/calibrate_topk.py" \ --model-name "${MODEL_NAME}" \ --topk-val "${TOPK_VAL}" \ --mem "${MEM}" \ --vortex-module-name "${ALGO}" \ + --page-size "${BLOCK_SIZE}" \ --output-dir "${CALIBRATION_DIR}" \ 2>&1 | tee "${RUN_DIR}/step1_calibrate.log" echo ">>> Step 1: Done." fi -# ── Step 2: Kernel-level latency benchmark ──────────────────── +REAL_HIST_PATH="${CALIBRATION_DIR}/raw_histograms.npy" +LUT_PATH="" +Q_PATH="" +[ -f "${CALIBRATION_DIR}/lut.npy" ] && LUT_PATH="${CALIBRATION_DIR}/lut.npy" +[ -f "${CALIBRATION_DIR}/quantiles.npy" ] && Q_PATH="${CALIBRATION_DIR}/quantiles.npy" +[ -n "${LUT_PATH}" ] && echo " Calibration LUT: ${LUT_PATH}" +[ -n "${Q_PATH}" ] && echo " Calibration quantile: ${Q_PATH}" + +# ── Step 2: Latency autotune + remap bench ─────────────────── +AUTOTUNE_JSON="${RUN_DIR}/autotune_results.json" if [ "${SKIP_KERNEL}" = true ]; then echo "" echo ">>> Step 2: SKIPPED (--skip-kernel)" else - # Step 2a: Auto-tune parametric mapping modes (must run before bench) echo "" - echo ">>> Step 2a: Auto-tuning parametric mapping hyperparameters" - AUTOTUNE_JSON="${RUN_DIR}/autotune_results.json" - REAL_HIST_ARGS="" - if [ -f "${CALIBRATION_DIR}/raw_histograms.npy" ]; then - REAL_HIST_ARGS="--real-histograms ${CALIBRATION_DIR}/raw_histograms.npy" - fi - python "${BENCH_DIR}/autotune_topk_mapping.py" \ + echo ">>> Step 2a: Auto-tuning per-mode hparams by fused-topk kernel latency" + AUTOTUNE_EXTRA=() + [ -f "${REAL_HIST_PATH}" ] && AUTOTUNE_EXTRA+=(--real-histograms "${REAL_HIST_PATH}") + [ -n "${LUT_PATH}" ] && AUTOTUNE_EXTRA+=(--lut-path "${LUT_PATH}") + [ -n "${Q_PATH}" ] && AUTOTUNE_EXTRA+=(--quantiles-path "${Q_PATH}") + PYTHONPATH="${SCRIPT_DIR}/.." 
python "${BENCH_DIR}/autotune_topk_mapping.py" \ --topk-val "${TOPK_VAL}" \ - --batch-size 4 \ - --seq-len 32768 \ - --num-kv-heads 2 \ - ${REAL_HIST_ARGS} \ + --batch-size "${BATCH_SIZE}" \ + --num-kv-heads "${NUM_KV_HEADS}" \ + --seq-len "${SEQ_LEN}" \ + --page-size "${BLOCK_SIZE}" \ + --warmup 20 --repeat 100 \ + --collect-stats \ + "${AUTOTUNE_EXTRA[@]}" \ --output-json "${AUTOTUNE_JSON}" \ 2>&1 | tee "${RUN_DIR}/step2a_autotune.log" - echo ">>> Step 2a: Done. Autotune results saved to ${AUTOTUNE_JSON}" + echo ">>> Step 2a: Done. Autotune saved to ${AUTOTUNE_JSON}" - # Step 2b: Kernel-level latency + histogram benchmark (using autotune params) echo "" - echo ">>> Step 2b: Kernel-level latency benchmark (all modes)" - + echo ">>> Step 2b: Remap benchmark (baseline / fused / remap / split) with autotuned hparams" BENCH_JSON="${RUN_DIR}/kernel_latency.json" - - # Build calibration args - LUT_ARGS="" - if [ -f "${CALIBRATION_DIR}/lut.npy" ]; then - LUT_ARGS="--lut-path ${CALIBRATION_DIR}/lut.npy" - fi - QUANTILES_ARGS="" - if [ -f "${CALIBRATION_DIR}/quantiles.npy" ]; then - QUANTILES_ARGS="--quantiles-path ${CALIBRATION_DIR}/quantiles.npy" - fi - - python "${BENCH_DIR}/bench_topk.py" \ - --batch-sizes 4 8 16 32 \ - --seq-lens 2048 4096 8192 16384 32768 \ + BENCH_EXTRA=() + [ -n "${LUT_PATH}" ] && BENCH_EXTRA+=(--lut-path "${LUT_PATH}") + [ -n "${Q_PATH}" ] && BENCH_EXTRA+=(--quantiles-path "${Q_PATH}") + PYTHONPATH="${SCRIPT_DIR}/.." 
python "${BENCH_DIR}/bench_topk.py" \ + --remap-bench \ + --batch-sizes "${BATCH_SIZE}" \ + --num-kv-heads "${NUM_KV_HEADS}" \ + --seq-lens "${SEQ_LEN}" \ --topk-vals "${TOPK_VAL}" \ - --num-kv-heads 2 4 \ - --distributions normal lognormal uniform \ - --histogram \ - --hit-rate \ - --warmup 20 \ - --repeat 100 \ - ${LUT_ARGS} \ - ${QUANTILES_ARGS} \ + --page-size "${BLOCK_SIZE}" \ + --distributions normal bucket_uniform \ + --mapping-modes 0 1 2 3 6 7 8 9 10 11 13 \ --autotune-json "${AUTOTUNE_JSON}" \ + "${BENCH_EXTRA[@]}" \ + --warmup 20 --repeat 100 \ --output-json "${BENCH_JSON}" \ 2>&1 | tee "${RUN_DIR}/step2b_kernel_bench.log" - echo ">>> Step 2b: Done. Results saved to ${BENCH_JSON}" - - # Step 2c: Per-mode distribution analysis - echo "" - echo ">>> Step 2c: Generating per-mode distribution analysis" - - python "${BENCH_DIR}/analyze_topk_distribution.py" \ - --bench-json "${BENCH_JSON}" \ - ${REAL_HIST_ARGS} \ - --output-dir "${RUN_DIR}" \ - 2>&1 | tee "${RUN_DIR}/step2c_analyze.log" - - echo ">>> Step 2c: Done. Per-mode plots saved to ${RUN_DIR}" fi -# ── Step 3: E2E accuracy comparison ────────────────────────── +# ── Step 3: E2E accuracy ───────────────────────────────────── if [ "${SKIP_E2E}" = true ]; then echo "" - echo ">>> Step 3: SKIPPED (--skip-e2e)" + echo ">>> Step 3: SKIPPED (default). Pass --skip-e2e to toggle it ON." else echo "" echo ">>> Step 3: E2E accuracy comparison" + # Extract autotuned hparams per mode. 
+ eval "$(python3 -c " +import json, sys +data = json.load(open(sys.argv[1])) +best = {} +for r in data: + m = r.get('mode'); lat = r.get('latency_ms') + if m is None or lat is None: continue + if m not in best or lat < best[m]['latency_ms']: + best[m] = r +for m in (3, 6, 7, 9, 10, 11, 13): + print(f'BEST_HPARAM_{m}={best.get(m, {}).get(\"param\", 0.5)}') +" "${AUTOTUNE_JSON}")" + E2E_DIR="${RUN_DIR}/e2e" mkdir -p "${E2E_DIR}" - # Helper: run verify_algo.py with common args and save output run_e2e() { - local label="$1" - shift + # $1=label, remaining args passed to verify_algo.py + local label="$1"; shift local logfile="${E2E_DIR}/${label}.log" echo "" echo " --- ${label} ---" @@ -203,122 +205,44 @@ else 2>&1 | tee "${logfile}" } - # 3a. Full-attention baseline (oracle) - run_e2e "full_attention_baseline" \ - --full-attention - - # 3b. Naive TopK - run_e2e "naive_mode0" \ - --vortex-module-name "${ALGO}" \ - --topk-type naive - - # 3c. SGLang mode 0 (no mapping) - run_e2e "sglang_mode0_none" \ - --vortex-module-name "${ALGO}" \ - --topk-type sglang \ - --topk-mapping-mode 0 - - # 3d. SGLang mode 1 (LUT CDF) — requires calibration - if [ -f "${CALIBRATION_DIR}/lut.npy" ]; then - run_e2e "sglang_mode1_lut_cdf" \ - --vortex-module-name "${ALGO}" \ - --topk-type sglang \ - --topk-mapping-mode 1 \ - --topk-mapping-lut-path "${CALIBRATION_DIR}/lut.npy" - else - echo " --- sglang_mode1_lut_cdf: SKIPPED (no lut.npy) ---" - fi - - # 3e. SGLang mode 2 (quantile) — requires calibration - if [ -f "${CALIBRATION_DIR}/quantiles.npy" ]; then - run_e2e "sglang_mode2_quantile" \ - --vortex-module-name "${ALGO}" \ - --topk-type sglang \ - --topk-mapping-mode 2 \ - --topk-mapping-quantiles-path "${CALIBRATION_DIR}/quantiles.npy" - else - echo " --- sglang_mode2_quantile: SKIPPED (no quantiles.npy) ---" - fi - - # 3f. 
SGLang mode 3 (power) - run_e2e "sglang_mode3_power" \ - --vortex-module-name "${ALGO}" \ - --topk-type sglang \ - --topk-mapping-mode 3 \ - --topk-mapping-power 0.5 - - # 3g. SGLang mode 4 (log) - run_e2e "sglang_mode4_log" \ - --vortex-module-name "${ALGO}" \ - --topk-type sglang \ - --topk-mapping-mode 4 - - # 3h. SGLang mode 6 (asinh) - run_e2e "sglang_mode6_asinh" \ - --vortex-module-name "${ALGO}" \ - --topk-type sglang \ - --topk-mapping-mode 6 \ - --topk-mapping-power 1.0 - - # 3i. SGLang mode 7 (log1p) - run_e2e "sglang_mode7_log1p" \ - --vortex-module-name "${ALGO}" \ - --topk-type sglang \ - --topk-mapping-mode 7 \ - --topk-mapping-power 1.0 - - # 3j. SGLang mode 8 (Trunc8) - run_e2e "sglang_mode8_trunc8" \ - --vortex-module-name "${ALGO}" \ - --topk-type sglang \ - --topk-mapping-mode 8 - - # 3k. SGLang mode 9 (Erf) - run_e2e "sglang_mode9_erf" \ - --vortex-module-name "${ALGO}" \ - --topk-type sglang \ - --topk-mapping-mode 9 \ - --topk-mapping-power 1.0 - - # 3l. SGLang mode 10 (Tanh) - run_e2e "sglang_mode10_tanh" \ - --vortex-module-name "${ALGO}" \ - --topk-type sglang \ - --topk-mapping-mode 10 \ - --topk-mapping-power 1.0 + run_mapped() { + # $1=mode $2=hparam $3=label + local mode="$1"; local hp="$2"; local label="$3" + local extra=(--vortex-module-name "${ALGO}") + if [ "${mode}" -eq 0 ]; then + extra+=(--topk-type sglang) + else + extra+=(--topk-type sglang_fused --topk-mapping-mode "${mode}" --topk-mapping-hparam "${hp}") + fi + run_e2e "${label}" "${extra[@]}" + } - # 3m. 
SGLang mode 11 (Subtract) - run_e2e "sglang_mode11_subtract" \ - --vortex-module-name "${ALGO}" \ - --topk-type sglang \ - --topk-mapping-mode 11 + run_e2e "full_attention_baseline" --full-attention + run_e2e "naive_topk" --vortex-module-name "${ALGO}" --topk-type naive + run_mapped 0 0.5 "sglang_m0_none" + run_mapped 3 "${BEST_HPARAM_3}" "sglang_m3_power_p${BEST_HPARAM_3}" + run_mapped 4 0.5 "sglang_m4_log" + run_mapped 6 "${BEST_HPARAM_6}" "sglang_m6_asinh_beta${BEST_HPARAM_6}" + run_mapped 7 "${BEST_HPARAM_7}" "sglang_m7_log1p_alpha${BEST_HPARAM_7}" + run_mapped 8 0.5 "sglang_m8_trunc8" + run_mapped 9 "${BEST_HPARAM_9}" "sglang_m9_erf_alpha${BEST_HPARAM_9}" + run_mapped 10 "${BEST_HPARAM_10}" "sglang_m10_tanh_alpha${BEST_HPARAM_10}" + run_mapped 11 "${BEST_HPARAM_11}" "sglang_m11_subtract_pivot${BEST_HPARAM_11}" + run_mapped 13 "${BEST_HPARAM_13}" "sglang_m13_expstretch_alpha${BEST_HPARAM_13}" echo "" echo ">>> Step 3: Done. E2E logs saved to ${E2E_DIR}/" - - # ── Summary table: extract pass@N from each log ───────────── - echo "" - echo "============================================================" - echo "E2E Accuracy Summary" - echo "============================================================" - printf "%-35s %s\n" "Configuration" "Result" - printf "%-35s %s\n" "-----------------------------------" "------" - for logfile in "${E2E_DIR}"/*.log; do - label=$(basename "${logfile}" .log) - # Extract the last line matching pass@ pattern - result=$(grep -oP 'pass@\d+\s*[=:]\s*[\d.]+' "${logfile}" | tail -1 || echo "N/A") - printf "%-35s %s\n" "${label}" "${result}" - done - echo "============================================================" fi # ── Final Summary ───────────────────────────────────────────── echo "" echo "============================================================" echo "TopK Benchmark Complete" +echo " Model: ${MODEL_NAME}" +echo " Block size: ${BLOCK_SIZE}" echo " All results: ${RUN_DIR}" echo " Calibration: ${CALIBRATION_DIR}" +[ 
"${SKIP_KERNEL}" != true ] && echo " Autotune: ${AUTOTUNE_JSON}" [ "${SKIP_KERNEL}" != true ] && echo " Kernel JSON: ${RUN_DIR}/kernel_latency.json" -[ "${SKIP_KERNEL}" != true ] && echo " Per-mode: ${RUN_DIR}/distribution_comparison_m*.png, bucket_counts_m*.csv" -[ "${SKIP_E2E}" != true ] && echo " E2E logs: ${RUN_DIR}/e2e/" +[ "${SKIP_E2E}" != true ] && echo " E2E logs: ${RUN_DIR}/e2e/" echo "============================================================" From 990b3ebc125ab601cf41a8e2037815b7fa652186 Mon Sep 17 00:00:00 2001 From: UED Date: Tue, 14 Apr 2026 03:41:46 -0400 Subject: [PATCH 20/22] - Introduced topk_output_sglang_ori function for the original SGLang kernel in vortex_torch_C. - Updated setup.py to include the new source file for the original kernel. - Enhanced autotune_topk_mapping.py and bench_topk.py to support new mapping modes and original kernel integration. - Expanded the sweep grid in autotune_topk_mapping.py for improved hyperparameter tuning. - Added a new command-line argument in calibrate_topk.py for maximum total tokens to manage KV pool size. - Removed outdated remap_function_bench.sh script as part of code cleanup. 
--- benchmarks/autotune_topk_mapping.py | 225 ++++--- benchmarks/bench_topk.py | 491 ++++++++++++-- benchmarks/calibrate_topk.py | 12 + csrc/register.cc | 6 + csrc/register.h | 11 + csrc/topk.cu | 14 +- csrc/topk_mapping.cuh | 112 +++- csrc/topk_sglang.cu | 295 ++++++--- csrc/topk_sglang_ori.cu | 619 ++++++++++++++++++ csrc/topk_sglang_profile.cu | 64 +- examples/remap_function_bench_topk2028.sh | 252 +++++++ ...ench.sh => remap_function_bench_topk30.sh} | 57 +- examples/run_distribution_analysis.sh | 5 + examples/run_distribution_analysis_new.sh | 5 + examples/run_topk_benchmark.sh | 5 + examples/verify_algo_topk_mapping.sh | 4 + examples/verify_algo_topk_mapping_new.sh | 4 + setup.py | 1 + 18 files changed, 1901 insertions(+), 281 deletions(-) create mode 100644 csrc/topk_sglang_ori.cu create mode 100755 examples/remap_function_bench_topk2028.sh rename examples/{remap_function_bench.sh => remap_function_bench_topk30.sh} (84%) diff --git a/benchmarks/autotune_topk_mapping.py b/benchmarks/autotune_topk_mapping.py index db21321..e103953 100644 --- a/benchmarks/autotune_topk_mapping.py +++ b/benchmarks/autotune_topk_mapping.py @@ -1,10 +1,21 @@ """ Auto-tune TopK mapping hyperparameters by profiled kernel latency. -For each (mode, hyperparameter) combo in the sweep grid, this script runs -the fused remap+topk kernel (topk_output_sglang_fused) on synthetic or -real-distribution inputs, measures end-to-end latency with CUDA events, -and picks the hyperparameter with the lowest measured latency per mode. +For each (mode, hyperparameter) combo in the sweep grid, this script picks +the hyperparameter whose remapped score distribution produces the lowest +*unfused* topk kernel latency. The measurement is a split-phase: + + 1. topk_remap_only(x, mode, power) → float32 buffer [NOT timed] + 2. topk_output_sglang(remapped) [TIMED] + +Timing only step 2 isolates the Stage-2 radix cost, which is what bucket +uniformity actually affects. 
The remap cost is the same constant regardless +of power, so it would only pollute the ranking. + +Non-arithmetic baselines (MAPPING_LUT_CDF=1, MAPPING_QUANTILE=2, +MAPPING_TRUNC8=8) route their mapping through compute_stage1_bin, not +apply_transform, so split-phase is a no-op for them. Those are timed via +the fused kernel and marked `timing_mode="fused_fallback"` in the output. Distribution statistics (gini, max/mean, counter-based Stage-2 cost) are still collected for diagnostics, but they do NOT drive the ranking — the @@ -25,31 +36,65 @@ import numpy as np import torch -from bench_topk import make_topk_inputs, bench_kernel, compute_histogram_stats +from bench_topk import ( + make_topk_inputs, + bench_kernel, + compute_histogram_stats, + scores_from_histogram, +) from vortex_torch_C import ( + topk_output_sglang, topk_output_sglang_fused, + topk_remap_only, topk_profile_histogram, topk_profile_counters, ) +# Modes where topk_mapping.cuh::apply_transform is a genuine value-space +# transform (power / asinh / log / log1p / erf / tanh / subtract / exp_stretch, +# plus the top-spreading shift_pow2 / shift_pow3 / linear_steep family) and +# also mode 0 (identity). For these the split-phase `remap_only + unfused +# topk` is correct. Modes 1/2/8 (LUT_CDF / QUANTILE / TRUNC8) apply their +# mapping inside compute_stage1_bin, so split-phase is a no-op. +ARITHMETIC_MODES = {0, 3, 4, 6, 7, 9, 10, 11, 13, 15, 16, 17, 18, 19, 20} + + # Only parametric modes need auto-tuning. Mode 0 (none) and mode 4 (log) -# have no knob; mode 0 is always the baseline. +# have no knob; mode 0 is always the baseline. Sweep grids widened so the +# autotune actually explores the tails of each transform. 
SWEEP_GRID: Dict[int, List[float]] = { - 3: [0.1, 0.25, 0.5, 0.75, 0.9], # power: p - 6: [0.1, 0.5, 1.0, 2.0, 4.0], # asinh: beta - 7: [0.1, 0.5, 1.0, 2.0, 4.0, 8.0], # log1p: alpha - 9: [0.1, 0.5, 1.0, 2.0, 4.0], # erf: alpha - 10: [0.1, 0.5, 1.0, 2.0, 4.0], # tanh: alpha - 11: [-1.0, -0.5, 0.0, 0.5, 1.0], # subtract: pivot (free hparam) - 13: [0.5, 1.0, 2.0, 4.0, 8.0], # exp_stretch: alpha + 3: [0.1, 0.5, 1.0, 2.0, 4.0, 5.0, 9.0], # power: p + 6: [0.1, 0.5, 1.0, 2.0, 4.0, 8.0, 16.0], # asinh: beta + 7: [0.1, 0.5, 1.0, 2.0, 4.0, 8.0, 16.0], # log1p: alpha + 9: [0.1, 0.5, 1.0, 2.0, 4.0, 8.0, 16.0], # erf: alpha + 10: [0.1, 0.5, 1.0, 2.0, 4.0, 8.0, 16.0], # tanh: alpha + 11: [-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0], # subtract: pivot + 13: [0.5, 1.0, 2.0, 4.0, 8.0, 16.0, 32.0], # exp_stretch: alpha + 15: [-1.0, -0.5, -0.25, 0.0, 0.25, 0.5, 1.0], # shift_pow2: pivot + 16: [-4.0, -2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0], # shift_pow3: pivot (widened) + 17: [0.5, 1.0, 2.0, 4.0, 8.0, 16.0, 32.0], # linear_steep: k + 18: [-1.0, -0.5, -0.25, 0.0, 0.25, 0.5, 1.0], # half_square: pivot + 19: [-1.0, -0.5, -0.25, 0.0, 0.25, 0.5, 1.0], # half_cube: pivot + # dense_mant clamp: sweep a wide range because real attention scores + # can span [-400, +200] on some models (raw logits), not just [0, 1]. 
+ 20: [0.0, 1.0, 5.0, 10.0, 20.0, 50.0, 100.0], # dense_mant: clamp pivot } -PARAM_NAME = {3: "p", 6: "beta", 7: "alpha", 9: "alpha", 10: "alpha", 11: "pivot", 13: "alpha"} +PARAM_NAME = { + 3: "p", 6: "beta", 7: "alpha", 9: "alpha", 10: "alpha", + 11: "pivot", 13: "alpha", + 15: "pivot", 16: "pivot", 17: "k", + 18: "pivot", 19: "pivot", + 20: "clamp", +} MODE_NAMES = { 0: "none", 1: "lut_cdf", 2: "quantile", 3: "power", 4: "log", 6: "asinh", 7: "log1p", 8: "trunc8", 9: "erf", 10: "tanh", 11: "subtract", 13: "exp_stretch", + 15: "shift_pow2", 16: "shift_pow3", 17: "linear_steep", + 18: "half_square", 19: "half_cube", + 20: "dense_mant", } # Non-parametric modes — no knob to sweep; timed once as a reference point. @@ -59,54 +104,8 @@ # ---------- Real-distribution score generation ---------- - -def _key_to_fp16(key: int) -> np.float16: - """Invert convert_to_uint8's sign-flip for a single 16-bit key.""" - bits = (key & 0x7FFF) if key >= 0x8000 else ((~key) & 0xFFFF) - return np.array([bits], dtype=np.uint16).view(np.float16)[0] - - -def _build_bin_range_table(): - """Return per-bin (lo, hi) fp16 value tables for all 256 radix bins.""" - all_bits = np.arange(65536, dtype=np.uint16) - all_fp16 = all_bits.view(np.float16) - keys = np.where( - (all_bits & 0x8000).astype(bool), - (~all_bits).astype(np.uint16), - all_bits | np.uint16(0x8000), - ) - bins = (keys >> 8).astype(np.uint8) - all_f32 = all_fp16.astype(np.float32) - valid = np.isfinite(all_f32) - bin_lo = np.full(256, np.inf, dtype=np.float32) - bin_hi = np.full(256, -np.inf, dtype=np.float32) - for b in range(256): - mask = (bins == b) & valid - if mask.any(): - vals = all_f32[mask] - bin_lo[b] = vals.min() - bin_hi[b] = vals.max() - empty = bin_lo > bin_hi - for b in np.where(empty)[0]: - val = float(_key_to_fp16((int(b) << 8) | 0x80)) - bin_lo[b] = val - bin_hi[b] = val - return bin_lo, bin_hi - - -def _scores_from_histogram(histogram: np.ndarray, total_pages: int, device="cuda") -> torch.Tensor: - bin_lo, 
bin_hi = _build_bin_range_table() - counts = histogram.astype(np.float64) - total = counts.sum() - if total == 0: - return torch.zeros(total_pages, 1, 1, dtype=torch.bfloat16, device=device) - probs = counts / total - bin_indices = np.random.choice(256, size=total_pages, p=probs) - lo = bin_lo[bin_indices] - hi = bin_hi[bin_indices] - rand = np.random.uniform(0, 1, size=total_pages).astype(np.float32) - scores_f32 = lo + rand * (hi - lo) - return torch.from_numpy(scores_f32).to(torch.bfloat16).reshape(total_pages, 1, 1).to(device) +# _build_bin_range_table / scores_from_histogram now live in bench_topk.py +# so both autotune and bench_topk draw scores from the same sampler. def _make_real_inputs(args, histogram: np.ndarray) -> dict: @@ -125,10 +124,13 @@ def _make_real_inputs(args, histogram: np.ndarray) -> dict: ) dense_kv_indices = torch.arange(total_dense, dtype=torch.int32, device="cuda") sparse_kv_indices = torch.zeros(eff_bs * sparse_per_seg, dtype=torch.int32, device="cuda") - x = _scores_from_histogram(histogram, total_dense) + x = scores_from_histogram(histogram, total_dense, device="cuda", + score_dtype=torch.bfloat16) + remapped = torch.empty(total_dense, dtype=torch.float32, device="cuda").reshape(x.shape) return { "x": x, + "remapped": remapped, "dense_kv_indptr": dense_kv_indptr, "sparse_kv_indptr": sparse_kv_indptr, "dense_kv_indices": dense_kv_indices, @@ -139,9 +141,20 @@ def _make_real_inputs(args, histogram: np.ndarray) -> dict: } +def _ensure_remapped_buffer(inputs: dict) -> torch.Tensor: + """Lazy-allocate a float32 buffer matching x.shape for the split-phase.""" + buf = inputs.get("remapped") + if buf is None: + x = inputs["x"] + buf = torch.empty(x.numel(), dtype=torch.float32, device=x.device).reshape(x.shape) + inputs["remapped"] = buf + return buf + + # ---------- Latency-based evaluation ---------- def _time_fused(inputs, args, mode: int, power: float) -> dict: + """Fused remap+topk kernel latency (used as fallback for modes 1/2/8).""" 
eff_bs = inputs["eff_batch_size"] pages_per_seg = inputs["num_pages_per_seg"] inputs["sparse_kv_indices"].zero_() @@ -167,6 +180,59 @@ def _time_fused(inputs, args, mode: int, power: float) -> dict: warmup=args.warmup, repeat=args.repeat) +def _time_unfused_on_remapped(inputs, args, mode: int, power: float) -> dict: + """Time the unfused topk kernel on pre-remapped scores. + + For mode 0 the original scores are used directly. For every other + arithmetic mode we run topk_remap_only once (not timed) into a + pre-allocated float32 buffer, then time topk_output_sglang on that + buffer with bench_kernel's warmup + repeat loop. This isolates the + Stage-2 radix cost from the remap pass. + """ + eff_bs = inputs["eff_batch_size"] + pages_per_seg = inputs["num_pages_per_seg"] + + if mode == 0: + src = inputs["x"] + else: + remapped = _ensure_remapped_buffer(inputs) + topk_remap_only( + inputs["x"], + inputs["dense_kv_indptr"], + remapped, + eff_bs, + args.reserved_bos, + args.reserved_eos, + mode, + float(power), + ) + torch.cuda.synchronize() + src = remapped + + inputs["sparse_kv_indices"].zero_() + call_args = ( + src, + inputs["dense_kv_indptr"], + inputs["sparse_kv_indptr"], + inputs["dense_kv_indices"], + inputs["sparse_kv_indices"], + eff_bs, + args.topk_val, + args.reserved_bos, + args.reserved_eos, + pages_per_seg, + ) + return bench_kernel(topk_output_sglang, call_args, + warmup=args.warmup, repeat=args.repeat) + + +def _time_mode(inputs, args, mode: int, power: float) -> tuple: + """Returns (latency_dict, timing_mode_str).""" + if mode in ARITHMETIC_MODES: + return _time_unfused_on_remapped(inputs, args, mode, power), "unfused_on_remapped" + return _time_fused(inputs, args, mode, power), "fused_fallback" + + def _collect_diagnostics(inputs, args, mode: int, power: float) -> dict: """Optional distribution/counter stats for reporting only (post-timing).""" eff_bs = inputs["eff_batch_size"] @@ -209,6 +275,11 @@ def _collect_diagnostics(inputs, args, mode: int, 
power: float) -> dict: diag["threshold_bin_mean"] = c[:, 0].mean().item() diag["num_equal_mean"] = c[:, 2].mean().item() diag["refine_rounds_mean"] = c[:, 4].mean().item() + # selected_from_thr = topk_val - num_above (clamped >= 0). Used as + # a tie-breaker by bench_topk._load_autotune_hparams when several + # modes have indistinguishable latency. + sel_from_thr = (float(args.topk_val) - c[:, 1]).clamp(min=0.0) + diag["selected_from_thr_mean"] = sel_from_thr.mean().item() return diag @@ -218,13 +289,14 @@ def _run_sweep(args, inputs, dist_label: str) -> List[dict]: # Baselines: time them but their param is fixed. for mode, power in BASELINES: - lat = _time_fused(inputs, args, mode, power) + lat, tmode = _time_mode(inputs, args, mode, power) entry = { "mode": mode, "mode_name": MODE_NAMES.get(mode, f"m{mode}"), "param_name": "(baseline)", "param": power, "distribution": dist_label, + "timing_mode": tmode, "latency_ms": lat["mean_ms"], "latency_median_ms": lat["median_ms"], "latency_min_ms": lat["min_ms"], @@ -232,21 +304,22 @@ def _run_sweep(args, inputs, dist_label: str) -> List[dict]: entry.update(_collect_diagnostics(inputs, args, mode, power)) results.append(entry) print( - f" mode={mode:>2d} ({MODE_NAMES[mode]:>5s}) baseline " - f" latency={lat['mean_ms']:.4f} ms" + f" mode={mode:>2d} ({MODE_NAMES[mode]:>10s}) baseline " + f" [{tmode:>20s}] latency={lat['mean_ms']:.4f} ms" ) # Parametric sweep, one (mode, param) combo at a time. 
for mode, values in SWEEP_GRID.items(): pname = PARAM_NAME[mode] for val in values: - lat = _time_fused(inputs, args, mode, float(val)) + lat, tmode = _time_mode(inputs, args, mode, float(val)) entry = { "mode": mode, "mode_name": MODE_NAMES.get(mode, f"m{mode}"), "param_name": pname, "param": float(val), "distribution": dist_label, + "timing_mode": tmode, "latency_ms": lat["mean_ms"], "latency_median_ms": lat["median_ms"], "latency_min_ms": lat["min_ms"], @@ -254,8 +327,8 @@ def _run_sweep(args, inputs, dist_label: str) -> List[dict]: entry.update(_collect_diagnostics(inputs, args, mode, float(val))) results.append(entry) print( - f" mode={mode:>2d} ({MODE_NAMES[mode]:>5s}) {pname}={val:<6.3f} " - f" latency={lat['mean_ms']:.4f} ms" + f" mode={mode:>2d} ({MODE_NAMES[mode]:>10s}) {pname}={val:<6.3f} " + f" [{tmode:>20s}] latency={lat['mean_ms']:.4f} ms" ) return results @@ -320,19 +393,11 @@ def main(): help="Path to .npy float32[256] quantile table for MAPPING_QUANTILE (mode 2).") args = parser.parse_args() + # Modes 1 (LUT_CDF) and 2 (Quantile) are no longer evaluated — they + # don't use topk_mapping::apply_transform (their mapping is done inside + # compute_stage1_bin) and are kept out of the comparison entirely. args._mapping_lut = None args._mapping_quantiles = None - # Include modes 1/2 as baselines when calibration tables are provided. 
- if args.lut_path: - lut_np = np.load(args.lut_path).astype(np.uint8) - args._mapping_lut = torch.from_numpy(lut_np).cuda() - BASELINES.append((1, 0.5)) - print(f"[autotune] loaded LUT from {args.lut_path}") - if args.quantiles_path: - q_np = np.load(args.quantiles_path).astype(np.float32) - args._mapping_quantiles = torch.from_numpy(q_np).cuda() - BASELINES.append((2, 0.5)) - print(f"[autotune] loaded quantiles from {args.quantiles_path}") real_histogram: Optional[np.ndarray] = None if args.real_histograms: diff --git a/benchmarks/bench_topk.py b/benchmarks/bench_topk.py index 4bd5bec..f0860c9 100644 --- a/benchmarks/bench_topk.py +++ b/benchmarks/bench_topk.py @@ -23,14 +23,25 @@ import torch from vortex_torch_C import ( - topk_output, - topk_output_sglang, # unmapped baseline - topk_output_sglang_fused, # fused remap + topk - topk_remap_only, # standalone remap + topk_output, # full CUB BlockRadixSort topk (max 8192 pages/seg) + topk_output_sglang, # 2-stage radix approximate topk (unmapped baseline) + topk_output_sglang_fused, # fused remap + 2-stage radix topk + topk_output_sglang_ori, # original SGLang reference kernel + topk_remap_only, # standalone value-space remap topk_profile_histogram, topk_profile_counters, ) +# topk_output's template ladder tops out at 8192 pages per segment +# (see topk.cu::topk_output, branches up to <= 8192). Runs larger than +# that hit TORCH_CHECK(false). +TOPK_OUTPUT_MAX_PAGES = 8192 + +# The ori kernel has TopK baked in at compile time. If setup.py was built +# with a different value, calls will fail; this is the topk_val that +# matches the current build of topk_sglang_ori.cu. +TOPK_ORI_BAKED_IN = 30 + MAPPING_MODE_NAMES = { 0: "None", @@ -45,32 +56,136 @@ 10: "Tanh", 11: "Subtract", 13: "ExpStretch", + 15: "ShiftPow2", + 16: "ShiftPow3", + 17: "LinearSteep", + 18: "HalfSquare", + 19: "HalfCube", + 20: "DenseMant", } +# Modes whose value-space transform is a real apply_transform() pass. 
Modes +# 1 (LUT_CDF), 2 (QUANTILE) and 8 (TRUNC8) apply their mapping inside +# compute_stage1_bin, not apply_transform — so `topk_remap_only` cannot +# reproduce them (the fp32 buffer would just contain the raw values). For +# those modes the split-phase numbers are N/A; only the fused kernel is a +# meaningful reference. +ARITHMETIC_MODES = {0, 3, 4, 6, 7, 9, 10, 11, 13, 15, 16, 17, 18, 19, 20} + + +_AUTOTUNE_TIE_TOLERANCE_MS = 0.0002 # ≈ CUDA event noise floor at this kernel size + def _load_autotune_hparams(path: str) -> Dict[int, float]: """Load per-mode best hyperparameters from an autotune_results.json. The JSON is produced by autotune_topk_mapping.py and contains a list of - {mode, param, latency_ms, ...} entries. For each mode we pick the entry - with the lowest measured latency and return {mode: best_param}. - - Modes with no parametric sweep (0=None, 4=Log) return a dummy 0.5; the - caller should override to taste. + {mode, param, latency_ms, num_equal_mean, selected_from_thr_mean, ...} + entries. For each mode we group all sweep entries, find the lowest + latency, then break ties (within `_AUTOTUNE_TIE_TOLERANCE_MS`) by: + + 1. Smallest `num_equal_mean` (= thr_size). Stage-2 cost is O(thr_size), + so a smaller threshold bin is a better proxy for real fused + latency than the noisy `latency_ms` measurement. + 2. Smallest `selected_from_thr_mean`. How many pages the topk has to + pull from the threshold bin during refinement. + 3. Lowest `latency_ms` again (final fallback). + + Modes with no parametric sweep (0=None, 4=Log) return a dummy 0.5; + the caller should override to taste. 
""" with open(path) as f: data = json.load(f) - best: Dict[int, dict] = {} + grouped: Dict[int, list] = {} for r in data: m = r.get("mode") lat = r.get("latency_ms") if m is None or lat is None: continue - if m not in best or lat < best[m]["latency_ms"]: - best[m] = r + grouped.setdefault(m, []).append(r) + + best: Dict[int, dict] = {} + for m, entries in grouped.items(): + min_lat = min(e["latency_ms"] for e in entries) + contenders = [ + e for e in entries + if e["latency_ms"] - min_lat <= _AUTOTUNE_TIE_TOLERANCE_MS + ] + # Tie-breakers: lowest num_equal_mean, then lowest sel_thr, + # then lowest latency. Missing diagnostic fields → +inf so they + # lose tie-breaks (we still keep them as fallback candidates). + def _rank_key(e): + return ( + e.get("num_equal_mean", float("inf")), + e.get("selected_from_thr_mean", float("inf")), + e["latency_ms"], + ) + best[m] = min(contenders, key=_rank_key) + return {m: float(r["param"]) for m, r in best.items()} +def _key_to_fp16(key: int) -> np.float16: + """Invert convert_to_uint8's sign-flip for a single 16-bit key.""" + bits = (key & 0x7FFF) if key >= 0x8000 else ((~key) & 0xFFFF) + return np.array([bits], dtype=np.uint16).view(np.float16)[0] + + +def build_bin_range_table(): + """Per-bin (lo, hi) fp16 value tables for the 256 Stage-1 radix bins. + + Shared by the real-distribution samplers in bench_topk.py and + autotune_topk_mapping.py so both scripts generate identical inputs. 
+ """ + all_bits = np.arange(65536, dtype=np.uint16) + all_fp16 = all_bits.view(np.float16) + keys = np.where( + (all_bits & 0x8000).astype(bool), + (~all_bits).astype(np.uint16), + all_bits | np.uint16(0x8000), + ) + bins = (keys >> 8).astype(np.uint8) + all_f32 = all_fp16.astype(np.float32) + valid = np.isfinite(all_f32) + bin_lo = np.full(256, np.inf, dtype=np.float32) + bin_hi = np.full(256, -np.inf, dtype=np.float32) + for b in range(256): + mask = (bins == b) & valid + if mask.any(): + vals = all_f32[mask] + bin_lo[b] = vals.min() + bin_hi[b] = vals.max() + empty = bin_lo > bin_hi + for b in np.where(empty)[0]: + val = float(_key_to_fp16((int(b) << 8) | 0x80)) + bin_lo[b] = val + bin_hi[b] = val + return bin_lo, bin_hi + + +def scores_from_histogram( + histogram: np.ndarray, + total_pages: int, + device: str = "cuda", + score_dtype: torch.dtype = torch.bfloat16, +) -> torch.Tensor: + """Sample `total_pages` scores whose Stage-1 bucket distribution matches + the given 256-bin histogram (produced by calibration). 
Each bucket is + sampled uniformly over the fp16 range that maps into it.""" + bin_lo, bin_hi = build_bin_range_table() + counts = histogram.astype(np.float64) + total = counts.sum() + if total == 0: + return torch.zeros(total_pages, 1, 1, dtype=score_dtype, device=device) + probs = counts / total + bin_indices = np.random.choice(256, size=total_pages, p=probs) + lo = bin_lo[bin_indices] + hi = bin_hi[bin_indices] + rand = np.random.uniform(0, 1, size=total_pages).astype(np.float32) + scores_f32 = lo + rand * (hi - lo) + return torch.from_numpy(scores_f32).to(score_dtype).reshape(total_pages, 1, 1).to(device) + + def make_topk_inputs( batch_size: int, num_kv_heads: int, @@ -81,9 +196,15 @@ def make_topk_inputs( reserved_eos: int, score_dtype: torch.dtype, distribution: str = "normal", + real_histogram: np.ndarray = None, device: str = "cuda", ) -> dict: - """Synthesize CSR-formatted paged attention inputs for kernel timing.""" + """Synthesize CSR-formatted paged attention inputs for kernel timing. + + When `real_histogram` is provided, scores are drawn from that 256-bin + distribution (ignoring `distribution`) so the benchmark sees the same + Stage-1 bucket distribution as the calibrated model. 
+ """ eff_batch_size = batch_size * num_kv_heads num_pages_per_seg = math.ceil(seq_len / page_size) total_dense_pages = eff_batch_size * num_pages_per_seg @@ -101,24 +222,25 @@ def make_topk_inputs( dense_kv_indices = torch.arange(total_dense_pages, dtype=torch.int32, device=device) sparse_kv_indices = torch.zeros(total_sparse_pages, dtype=torch.int32, device=device) - if distribution == "normal": - x = torch.randn(total_dense_pages, 1, 1, device=device) + if real_histogram is not None: + x = scores_from_histogram(real_histogram, total_dense_pages, device=device, + score_dtype=score_dtype) + elif distribution == "normal": + x = torch.randn(total_dense_pages, 1, 1, device=device).to(score_dtype) elif distribution == "lognormal": - x = torch.randn(total_dense_pages, 1, 1, device=device).exp() + x = torch.randn(total_dense_pages, 1, 1, device=device).exp().to(score_dtype) elif distribution == "uniform": - x = torch.rand(total_dense_pages, 1, 1, device=device) + x = torch.rand(total_dense_pages, 1, 1, device=device).to(score_dtype) elif distribution == "bucket_uniform": # Uniform across all 256 fp16 radix buckets. Random uint16 bit # patterns → interpret as fp16. NaN/Inf patterns collapse to ±0. raw_bits = torch.randint(0, 65536, (total_dense_pages,), dtype=torch.int32, device=device) abs_bits = raw_bits & 0x7FFF raw_bits[abs_bits >= 0x7C00] = raw_bits[abs_bits >= 0x7C00] & 0x8000 - x = raw_bits.to(torch.int16).view(torch.float16).float().reshape(total_dense_pages, 1, 1) + x = raw_bits.to(torch.int16).view(torch.float16).float().reshape(total_dense_pages, 1, 1).to(score_dtype) else: raise ValueError(f"Unknown distribution: {distribution}") - x = x.to(score_dtype) - return { "x": x, "dense_kv_indptr": dense_kv_indptr, @@ -185,10 +307,9 @@ def compute_histogram_stats(histograms: torch.Tensor) -> dict: def _collect_threshold_stats(inputs, topk_val, pages_per_seg, args, mode: int, power: float) -> dict: - """Run topk_profile_counters once and aggregate threshold-bin stats. 
- - Profile kernel is invoked AFTER all latency measurements have finished, - so the counter writes never contaminate timing. + """Run topk_profile_counters + topk_profile_histogram once and aggregate + threshold-bin / bucket-distribution stats. Profile kernels run AFTER all + latency measurements, so their writes never contaminate timing. """ eff_bs = inputs["eff_batch_size"] counter_buf = torch.zeros(eff_bs, 6, dtype=torch.int32, device="cuda") @@ -214,6 +335,40 @@ def _collect_threshold_stats(inputs, topk_val, pages_per_seg, args, mode: int, p ) torch.cuda.synchronize() c = counter_buf.float() + + # Run the 256-bin histogram profile to compute the rank_target_bins + # metric: how many bins ABOVE the threshold bin (i.e. the bins whose + # pages are selected without Stage-2 refinement) actually contain + # selected pages, and the mean pages-per-such-bin. + hist_buf = torch.zeros(eff_bs, 256, dtype=torch.int32, device="cuda") + topk_profile_histogram( + inputs["x"], + inputs["dense_kv_indptr"], + hist_buf, + eff_bs, + args.reserved_bos, + args.reserved_eos, + mode, + power, + lut_t, + q_t, + ) + torch.cuda.synchronize() + + thr_idx = counter_buf[:, 0].to(torch.int64) # [eff_bs] + hist = hist_buf.to(torch.int64) # [eff_bs, 256] + bin_ids = torch.arange(256, device="cuda", dtype=torch.int64).unsqueeze(0) # [1, 256] + above_mask = bin_ids > thr_idx.unsqueeze(1) # [eff_bs, 256] + above_populated = ((hist > 0) & above_mask).sum(dim=1).float() # bins >thr with any pages + pages_above = (hist * above_mask.to(torch.int64)).sum(dim=1).float() # total pages in those bins + # Mean pages per populated above-threshold bin (per-segment, then + # averaged). Guard against divide-by-zero. + pages_per_bin = torch.where( + above_populated > 0, + pages_above / above_populated, + torch.zeros_like(above_populated), + ) + # Selected from threshold bin = topk_val - num_above (clamped >= 0). 
sel_from_thr = (float(topk_val) - c[:, 1]).clamp(min=0.0) return { @@ -225,6 +380,9 @@ def _collect_threshold_stats(inputs, topk_val, pages_per_seg, args, mode: int, p "selected_from_thr_mean": sel_from_thr.mean().item(), "selected_from_thr_max": sel_from_thr.max().item(), "refine_rounds_mean": c[:, 4].mean().item(), + # Rank-target metrics: how the top pages are actually spread. + "above_bins_mean": above_populated.mean().item(), + "pages_per_above_bin_mean": pages_per_bin.mean().item(), } @@ -241,6 +399,7 @@ def _resolve_hparam(args, mode: int) -> float: def _remap_bench_one_config(args, batch_size, num_kv_heads, seq_len, topk_val, distribution, modes: List[int]) -> dict: """Time baseline, fused, and split-phase for each mode at one config.""" + real_hist = getattr(args, "_real_histogram", None) if distribution == "real" else None inputs = make_topk_inputs( batch_size=batch_size, num_kv_heads=num_kv_heads, @@ -250,13 +409,17 @@ def _remap_bench_one_config(args, batch_size, num_kv_heads, seq_len, topk_val, reserved_bos=args.reserved_bos, reserved_eos=args.reserved_eos, score_dtype=torch.bfloat16, - distribution=distribution, + distribution=distribution if distribution != "real" else "normal", + real_histogram=real_hist, ) eff_bs = inputs["eff_batch_size"] pages_per_seg = inputs["num_pages_per_seg"] total_dense = inputs["x"].numel() - # Baseline: unmapped topk. + # Baseline = unmapped topk_output_sglang (CUB two-stage radix, the + # kernel every mapped mode's split-phase ends up calling). This is + # the `base_us` column and also what the `None` row reports, so + # None's topk_us == base_us by construction. baseline_args = ( inputs["x"], inputs["dense_kv_indptr"], @@ -268,7 +431,55 @@ def _remap_bench_one_config(args, batch_size, num_kv_heads, seq_len, topk_val, inputs["sparse_kv_indices"].zero_() baseline = bench_kernel(topk_output_sglang, baseline_args, args.warmup, args.repeat) + # Optional extra row: the full CUB BlockRadixSort topk from topk.cu. 
+ # This is a "true naive" — exact sort, no bucketing tricks — for A/B + # against the 2-stage approximate baseline. Only runs when pages_per_seg + # fits the kernel's template ladder (<= TOPK_OUTPUT_MAX_PAGES = 8192). + naive_ms = None + if pages_per_seg <= TOPK_OUTPUT_MAX_PAGES: + naive_args = ( + inputs["x"], + inputs["dense_kv_indptr"], + inputs["dense_kv_indices"], # NOTE: topk_output arg order differs + inputs["sparse_kv_indptr"], # from topk_output_sglang + inputs["sparse_kv_indices"], + eff_bs, topk_val, args.reserved_bos, args.reserved_eos, pages_per_seg, + ) + inputs["sparse_kv_indices"].zero_() + naive_ms = bench_kernel( + topk_output, naive_args, args.warmup, args.repeat + )["mean_ms"] + + # Optional extra row: the original SGLang kernel from topk_sglang_ori.cu, + # compiled with TopK=TOPK_ORI_BAKED_IN. Only runs when topk_val matches + # that constant; otherwise the row is skipped with a warning. It is NOT + # used as the baseline — this is a separate A/B point so you can see the + # ori-vs-naive gap at a glance. + sglang_ori_ms = None + if topk_val == TOPK_ORI_BAKED_IN: + ori_indices = torch.empty(eff_bs, TOPK_ORI_BAKED_IN, + dtype=torch.int32, device="cuda") + ori_args = ( + inputs["x"], + inputs["dense_kv_indptr"], + ori_indices, + eff_bs, topk_val, args.reserved_bos, args.reserved_eos, pages_per_seg, + ) + sglang_ori_ms = bench_kernel( + topk_output_sglang_ori, ori_args, args.warmup, args.repeat + )["mean_ms"] + # Pre-allocate the float32 buffer used for the split-phase (remap → baseline). + # Split-phase remapped buffer is **float32** to preserve Stage-2 + # refinement precision. The fused kernel computes transforms in + # fp32 internally (so its Stage-2 sub-bin keys carry transform- + # dependent bits in positions [15:0]); a narrower remapped buffer + # (bf16 or fp16) would zero those bits on round-trip and change + # the Stage-2 tie-break ordering vs the fused path. fp32 is the + # only lossless choice. 
The kernel supports bf16 output too (see + # topk_remap_only's dispatch table) for experimental paths, but we + # don't use it here because correctness matters more than the + # small memory-bandwidth win. remapped = torch.empty(total_dense, dtype=torch.float32, device="cuda").reshape(inputs["x"].shape) config = { @@ -279,10 +490,87 @@ def _remap_bench_one_config(args, batch_size, num_kv_heads, seq_len, topk_val, "distribution": distribution, "pages_per_seg": pages_per_seg, "baseline_ms": baseline["mean_ms"], + "naive_ms": naive_ms, + "sglang_ori_ms": sglang_ori_ms, "modes": [], } + # Naive row — full CUB BlockRadixSort from topk.cu. No mapping, no + # remap, no fused. Only populated when pages_per_seg fits the kernel. + if naive_ms is not None: + config["modes"].append({ + "mode": -2, # sentinel so ranking/autotune skip it + "mode_name": "Naive", + "power": 0.5, + "remap_ms": None, + "topk_after_remap_ms": naive_ms, + "split_total_ms": None, + "fused_ms": None, + "threshold_bin_mean": 0.0, + "threshold_bin_max": 0.0, + "num_above_mean": 0.0, + "threshold_bin_size_mean": 0.0, + "threshold_bin_size_max": 0.0, + "selected_from_thr_mean": 0.0, + "selected_from_thr_max": 0.0, + "refine_rounds_mean": 0.0, + "above_bins_mean": 0.0, + "pages_per_above_bin_mean": 0.0, + }) + + # The None row is a pass-through to the unmapped baseline: no remap, no + # fused, and topk_us == base_us by construction. Distribution metrics + # are populated by running the profile kernels with mode=0 so the user + # can see the unmapped Stage-1 bucket layout as a reference. + none_stats = _collect_threshold_stats( + inputs, topk_val, pages_per_seg, args, mode=0, power=0.5 + ) + config["modes"].append({ + "mode": 0, + "mode_name": "None", + "power": 0.5, + "remap_ms": None, + "topk_after_remap_ms": baseline["mean_ms"], + "split_total_ms": None, + "fused_ms": None, + **none_stats, + }) + + # Extra row for the original SGLang kernel — only populated when the + # build's baked-in TopK matches topk_val. 
Also a pass-through (no + # remap, no fused); topk_us is the ori kernel latency. + if sglang_ori_ms is not None: + config["modes"].append({ + "mode": -1, # sentinel so ranking/autotune skip it + "mode_name": "sglang_ori", + "power": 0.5, + "remap_ms": None, + "topk_after_remap_ms": sglang_ori_ms, + "split_total_ms": None, + "fused_ms": None, + "threshold_bin_mean": 0.0, + "threshold_bin_max": 0.0, + "num_above_mean": 0.0, + "threshold_bin_size_mean": 0.0, + "threshold_bin_size_max": 0.0, + "selected_from_thr_mean": 0.0, + "selected_from_thr_max": 0.0, + "refine_rounds_mean": 0.0, + "above_bins_mean": 0.0, + "pages_per_above_bin_mean": 0.0, + }) + else: + print(f"[bench-remap] sglang_ori row SKIPPED: topk_val={topk_val} != " + f"TOPK_ORI_BAKED_IN ({TOPK_ORI_BAKED_IN}). Rebuild topk_sglang_ori.cu " + f"with a matching TopK to enable the row.") + for mode in modes: + # Mode 0 is already emitted as the `None` row above (pass-through + # to the unmapped baseline with no remap/fused). Skip to avoid a + # duplicate row and a spurious fused-mode-0 measurement. + if mode == 0: + continue + power = _resolve_hparam(args, mode) lut_t = getattr(args, "_mapping_lut", None) if mode == 1 else None @@ -299,30 +587,43 @@ def _remap_bench_one_config(args, batch_size, num_kv_heads, seq_len, topk_val, inputs["sparse_kv_indices"].zero_() fused = bench_kernel(topk_output_sglang_fused, fused_args, args.warmup, args.repeat) - # Split-phase timing: first the standalone remap, then the unmapped - # topk on the remapped buffer. - remap_args = ( - inputs["x"], - inputs["dense_kv_indptr"], - remapped, - eff_bs, args.reserved_bos, args.reserved_eos, - mode, power, - ) - remap_only = bench_kernel(topk_remap_only, remap_args, args.warmup, args.repeat) + # Split-phase timing is only meaningful for arithmetic modes. 
+ # MAPPING_LUT_CDF / QUANTILE / TRUNC8 apply their mapping inside + # compute_stage1_bin, which topk_remap_only cannot reproduce, so we + # report N/A for the split-phase fields and rely on the fused kernel + # as the only valid reference latency. + if mode in ARITHMETIC_MODES: + remap_args = ( + inputs["x"], + inputs["dense_kv_indptr"], + remapped, + eff_bs, args.reserved_bos, args.reserved_eos, + mode, power, + ) + remap_only = bench_kernel(topk_remap_only, remap_args, args.warmup, args.repeat) + + # Populate the remapped buffer once so the unfused-topk warmup + # iterations don't read stale data. + topk_remap_only(*remap_args) + torch.cuda.synchronize() + split_topk_args = ( + remapped, + inputs["dense_kv_indptr"], + inputs["sparse_kv_indptr"], + inputs["dense_kv_indices"], + inputs["sparse_kv_indices"], + eff_bs, topk_val, args.reserved_bos, args.reserved_eos, pages_per_seg, + ) + inputs["sparse_kv_indices"].zero_() + split_topk = bench_kernel(topk_output_sglang, split_topk_args, args.warmup, args.repeat) - split_topk_args = ( - remapped, - inputs["dense_kv_indptr"], - inputs["sparse_kv_indptr"], - inputs["dense_kv_indices"], - inputs["sparse_kv_indices"], - eff_bs, topk_val, args.reserved_bos, args.reserved_eos, pages_per_seg, - ) - # Run remap once so the buffer is populated for warmup of topk-on-remapped. - topk_remap_only(*remap_args) - torch.cuda.synchronize() - inputs["sparse_kv_indices"].zero_() - split_topk = bench_kernel(topk_output_sglang, split_topk_args, args.warmup, args.repeat) + remap_ms = remap_only["mean_ms"] + topk_after_remap_ms = split_topk["mean_ms"] + split_total_ms = remap_ms + topk_after_remap_ms + else: + remap_ms = None + topk_after_remap_ms = None + split_total_ms = None # Counter collection is run AFTER all timing measurements for this mode # so it cannot affect the timings. 
@@ -332,9 +633,9 @@ def _remap_bench_one_config(args, batch_size, num_kv_heads, seq_len, topk_val, "mode": mode, "mode_name": MAPPING_MODE_NAMES.get(mode, f"m{mode}"), "power": power, - "remap_ms": remap_only["mean_ms"], - "topk_after_remap_ms": split_topk["mean_ms"], - "split_total_ms": remap_only["mean_ms"] + split_topk["mean_ms"], + "remap_ms": remap_ms, + "topk_after_remap_ms": topk_after_remap_ms, + "split_total_ms": split_total_ms, "fused_ms": fused["mean_ms"], **stats, } @@ -345,8 +646,9 @@ def _remap_bench_one_config(args, batch_size, num_kv_heads, seq_len, topk_val, def _print_remap_table(results: List[dict]) -> None: header = ( - f"{'mode':<12s} {'remap_us':>9s} {'topk_us':>9s} {'split_us':>9s} " - f"{'fused_us':>9s} {'base_us':>9s} {'thr_bin':>7s} {'thr_size':>8s} {'sel_thr':>7s}" + f"{'mode':<14s} {'remap_ms':>9s} {'topk_ms':>9s} {'split_ms':>9s} " + f"{'fused_ms':>9s} {'base_ms':>9s} {'thr_bin':>7s} {'thr_size':>8s} " + f"{'sel_thr':>7s} {'abv_bins':>8s} {'pg/bin':>7s}" ) for cfg in results: banner = ( @@ -355,36 +657,65 @@ def _print_remap_table(results: List[dict]) -> None: f"dist={cfg['distribution']} pages_per_seg={cfg['pages_per_seg']}]" ) print(banner) - print(" Baseline: mapping_mode=0 (raw fp16 bucketing)") + extra_notes = [] + if cfg.get("naive_ms") is not None: + extra_notes.append("Naive row = topk.cu (CUB full sort)") + if cfg.get("sglang_ori_ms") is not None: + extra_notes.append("sglang_ori row = topk_sglang_ori.cu") + notes_str = "" + if extra_notes: + notes_str = " | " + " | ".join(extra_notes) + print(f" Baseline: topk_sglang.cu (CUB two-stage){notes_str}") print(header) print("-" * len(header)) - base_us = cfg["baseline_ms"] * 1000.0 + base_ms = cfg["baseline_ms"] for row in cfg["modes"]: - label = f"{row['mode_name']}(p={row['power']})" if row["mode"] != 0 else "None" + if row["mode"] == 0: + label = "None" + elif row["mode"] == -1: + label = row.get("mode_name", "sglang_ori") + elif row["mode"] == -2: + label = row.get("mode_name", 
"Naive") + else: + label = f"{row['mode_name']}(p={row['power']})" + def _fmt(v): + return f"{v:9.4f}" if v is not None else f"{'N/A':>9s}" + fused_str = _fmt(row.get("fused_ms")) print( - f"{label:<12s} " - f"{row['remap_ms'] * 1000.0:9.2f} " - f"{row['topk_after_remap_ms'] * 1000.0:9.2f} " - f"{row['split_total_ms'] * 1000.0:9.2f} " - f"{row['fused_ms'] * 1000.0:9.2f} " - f"{base_us:9.2f} " + f"{label:<14s} " + f"{_fmt(row['remap_ms'])} " + f"{_fmt(row['topk_after_remap_ms'])} " + f"{_fmt(row['split_total_ms'])} " + f"{fused_str} " + f"{base_ms:9.4f} " f"{row['threshold_bin_mean']:7.1f} " f"{row['threshold_bin_size_mean']:8.1f} " - f"{row['selected_from_thr_mean']:7.1f}" + f"{row['selected_from_thr_mean']:7.1f} " + f"{row.get('above_bins_mean', 0.0):8.1f} " + f"{row.get('pages_per_above_bin_mean', 0.0):7.1f}" ) def _run_remap_bench(args) -> None: modes = [int(m) for m in args.mapping_modes] - if 0 not in modes: - modes = [0] + modes + # Mode 0 is emitted as the "None" row from _remap_bench_one_config + # itself (pass-through to the ori baseline). Drop any user-supplied 0 + # to avoid a duplicate row. 
+ modes = [m for m in modes if m != 0] + + distributions = list(args.distributions) + if getattr(args, "_real_histogram", None) is not None: + if "real" not in distributions: + distributions.append("real") + print(f"[remap-bench] 'real' distribution enabled " + f"(histogram total count = {int(args._real_histogram.sum())})") results = [] for bs in args.batch_sizes: for heads in args.num_kv_heads: for seq_len in args.seq_lens: for topk_val in args.topk_vals: - for dist in args.distributions: + for dist in distributions: cfg = _remap_bench_one_config( args, bs, heads, seq_len, topk_val, dist, modes, ) @@ -401,17 +732,23 @@ def _run_remap_bench(args) -> None: def _run_latency_sweep(args) -> None: """Simple baseline-vs-fused latency sweep (no split-phase, no counters).""" modes = [int(m) for m in args.mapping_modes] + distributions = list(args.distributions) + if getattr(args, "_real_histogram", None) is not None and "real" not in distributions: + distributions.append("real") results = [] for bs in args.batch_sizes: for heads in args.num_kv_heads: for seq_len in args.seq_lens: for topk_val in args.topk_vals: - for dist in args.distributions: + for dist in distributions: + real_hist = args._real_histogram if dist == "real" else None inputs = make_topk_inputs( batch_size=bs, num_kv_heads=heads, seq_len=seq_len, page_size=args.page_size, topk_val=topk_val, reserved_bos=args.reserved_bos, reserved_eos=args.reserved_eos, - score_dtype=torch.bfloat16, distribution=dist, + score_dtype=torch.bfloat16, + distribution=dist if dist != "real" else "normal", + real_histogram=real_hist, ) eff_bs = inputs["eff_batch_size"] pages_per_seg = inputs["num_pages_per_seg"] @@ -469,7 +806,14 @@ def main(): p.add_argument("--topk-vals", type=int, nargs="+", default=[30]) p.add_argument("--distributions", type=str, nargs="+", default=["normal"], - choices=["normal", "lognormal", "uniform", "bucket_uniform"]) + choices=["normal", "lognormal", "uniform", "bucket_uniform", "real"], + 
help="Synthetic distributions. Use 'real' (or --real-histograms) to " + "sample scores from a calibrated raw_histograms.npy.") + p.add_argument("--real-histograms", type=str, default=None, + help="Path to raw_histograms.npy from calibrate_topk.py. When set, a " + "'real' distribution is appended to the sweep so every " + "(mode, hparam) combo is also timed on the calibrated score " + "distribution.") p.add_argument("--mapping-modes", type=int, nargs="+", default=[0, 3, 6, 7], help="Mapping modes to sweep (0=None, 3=Power, 6=Asinh, 7=Log1p, etc.)") @@ -503,6 +847,13 @@ def main(): for m, v in sorted(args._autotune_hparams.items()): print(f" mode {m:>2d} -> {v}") + args._real_histogram = None + if args.real_histograms: + raw = np.load(args.real_histograms) + args._real_histogram = raw.sum(axis=0) if raw.ndim > 1 else raw + print(f"[real] loaded calibrated histogram from {args.real_histograms} " + f"(shape={raw.shape} → [256] aggregate)") + args._mapping_lut = None args._mapping_quantiles = None if args.lut_path: diff --git a/benchmarks/calibrate_topk.py b/benchmarks/calibrate_topk.py index e3524c1..4914133 100644 --- a/benchmarks/calibrate_topk.py +++ b/benchmarks/calibrate_topk.py @@ -38,6 +38,17 @@ def main(): parser.add_argument("--topk-val", type=int, default=30) parser.add_argument("--page-size", type=int, default=16) parser.add_argument("--mem", type=float, default=0.7) + parser.add_argument( + "--max-total-tokens", + type=int, + default=1048576, + help="Hard cap on KV pool token slots (ServerArgs.max_total_tokens). " + "Block-sparse profiling uses a small bytes/token estimate, so the auto " + "budget can be huge on large GPUs; VTXGraphAttnBackend then allocates " + "dense bf16 sparse_prefill K/V buffers proportional to this cap (~4 KiB per " + "token per buffer). 
For offline calibration, a few hundred K1M tokens " + "is usually enough.", + ) parser.add_argument("--kv-cache-dtype", type=str, default="auto") parser.add_argument("--topk-type", type=str, default="sglang") parser.add_argument("--num-prompts", type=int, default=16, @@ -76,6 +87,7 @@ def main(): vortex_module_name=args.vortex_module_name, vortex_max_seq_lens=12288, mem_fraction_static=args.mem, + max_total_tokens=args.max_total_tokens, kv_cache_dtype=args.kv_cache_dtype, vortex_topk_type=args.topk_type, vortex_topk_mapping_mode=0, # Use mode 0 during calibration diff --git a/csrc/register.cc b/csrc/register.cc index cc201c9..8aa5aea 100644 --- a/csrc/register.cc +++ b/csrc/register.cc @@ -14,6 +14,12 @@ PYBIND11_MODULE(vortex_torch_C, m){ py::arg("eff_batch_size"), py::arg("topk_val"), py::arg("reserved_bos"), py::arg("reserved_eos"), py::arg("max_num_pages")); + m.def("topk_output_sglang_ori", &topk_output_sglang_ori, + py::arg("x"), py::arg("dense_kv_indptr"), + py::arg("indices_out"), + py::arg("eff_batch_size"), py::arg("topk_val"), + py::arg("reserved_bos"), py::arg("reserved_eos"), + py::arg("max_num_pages")); m.def("topk_output_sglang_fused", &topk_output_sglang_fused, py::arg("x"), py::arg("dense_kv_indptr"), py::arg("sparse_kv_indptr"), py::arg("dense_kv_indices"), py::arg("sparse_kv_indices"), diff --git a/csrc/register.h b/csrc/register.h index e86a963..afdb97f 100644 --- a/csrc/register.h +++ b/csrc/register.h @@ -98,6 +98,17 @@ const int64_t reserved_eos, const int64_t max_seq_lengths ); +void topk_output_sglang_ori( +const at::Tensor& x, +const at::Tensor& dense_kv_indptr, +at::Tensor& indices_out, +const int64_t eff_batch_size, +const int64_t topk_val, +const int64_t reserved_bos, +const int64_t reserved_eos, +const int64_t max_num_pages +); + void topk_output_sglang_fused( const at::Tensor& x, const at::Tensor& dense_kv_indptr, diff --git a/csrc/topk.cu b/csrc/topk.cu index 70d2000..081bddf 100644 --- a/csrc/topk.cu +++ b/csrc/topk.cu @@ -196,8 
+196,20 @@ const int64_t max_num_pages reserved_bos, reserved_eos ); + } else if (max_num_pages <= 8192){ + TopKOutput_BF16_Kernel<512, 16><<>>( + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + topk_val, + reserved_bos, + reserved_eos + ); } else { - TORCH_CHECK(false); + TORCH_CHECK(false, "topk_output: max_num_pages=", max_num_pages, + " exceeds the supported template ladder (8192)."); } } \ No newline at end of file diff --git a/csrc/topk_mapping.cuh b/csrc/topk_mapping.cuh index 447e539..c645acb 100644 --- a/csrc/topk_mapping.cuh +++ b/csrc/topk_mapping.cuh @@ -33,7 +33,24 @@ enum TopKMappingMode { MAPPING_ERF = 9, // erf(alpha * x) MAPPING_TANH = 10, // tanh(alpha * x) MAPPING_SUBTRACT = 11, // x - pivot, with pivot = power_exp (free hyperparameter) - MAPPING_EXP_STRETCH = 13, // exp(alpha * x) + MAPPING_EXP_STRETCH = 13, // exp(alpha * x) + // Top-spreading transforms (see CLAUDE.md / remap bench plan): + // amplify differences in the high-score region so the top-K values + // occupy multiple Stage-1 bins instead of collapsing into one. + MAPPING_SHIFT_POW2 = 15, // sign(x - p) * (x - p)^2 [p = power_exp] + MAPPING_SHIFT_POW3 = 16, // (x - p)^3 [p = power_exp] + MAPPING_LINEAR_STEEP = 17, // x + k * max(x, 0) [k = power_exp] + // One-sided spread: collapse below-pivot values into a single bin so + // every above-pivot page gets its own slice of the 256-bin histogram. + MAPPING_HALF_SQUARE = 18, // max(x - p, 0)^2 [p = power_exp] + MAPPING_HALF_CUBE = 19, // max(x - p, 0)^3 [p = power_exp] + // Bit-level remap: identity value transform, but the Stage-1 bucket + // function in fast_topk_clean_fused switches to a mantissa-heavy bit + // slice (bits [23:16] of convert_to_uint32) that gives 128 sub-bins + // per exponent slot instead of 4. Zero per-element compute overhead; + // the "remap" is the bucket change. 
Monotonic within 2 adjacent + // fp32 exponent slots. + MAPPING_DENSE_MANT = 20, // identity; bucketing handled in fused kernel }; struct TopKMappingParams { @@ -75,24 +92,99 @@ __device__ __forceinline__ float transform_exp_stretch(float x, float alpha) { return expf(z); } +// Signed squared distance from a pivot. ~3 ops (1 sub, 1 mul, 1 copysign). +// Quadratically amplifies differences between values far from pivot so the +// top-K region gets spread across multiple Stage-1 bins. +__device__ __forceinline__ float transform_shift_pow2(float x, float pivot) { + const float d = x - pivot; + return copysignf(d * d, d); +} + +// Signed cubic of distance from pivot. ~3 ops (1 sub, 2 mul; odd function so +// no copysign). Steeper growth than pow2 for even tighter top-K clusters. +__device__ __forceinline__ float transform_shift_pow3(float x, float pivot) { + const float d = x - pivot; + return d * d * d; +} + +// Half-range linear stretch: positive values get multiplied by (1 + k), +// negative values pass through untouched. ~2 ops (fmax + fma). For softmax- +// style attention scores (which are non-negative after softmax), k = 8..16 +// shifts the positive fp16 exponent up by 3..4 slots and empties out the +// collision at the top of the distribution. +__device__ __forceinline__ float transform_linear_steep(float x, float k) { + return fmaf(k, fmaxf(x, 0.0f), x); +} + +// One-sided shifted square: values below pivot collapse to 0 (they all end +// up in the same low Stage-1 bin), above-pivot values are squared so their +// differences amplify quadratically. ~2 ops (fmax + mul). The whole 256-bin +// histogram becomes dedicated to the top slice of the distribution. +__device__ __forceinline__ float transform_half_square(float x, float pivot) { + const float d = fmaxf(x - pivot, 0.0f); + return d * d; +} + +// One-sided shifted cube: like half_square but cubic. ~3 ops. Best when the +// top-K region is even more tightly clustered and needs steeper amplification. 
+__device__ __forceinline__ float transform_half_cube(float x, float pivot) { + const float d = fmaxf(x - pivot, 0.0f); + return d * d * d; +} + +// Compile-time templated dispatcher. When the caller knows the mapping mode +// at template-instantiation time, this lets the compiler fully inline the +// transform into the Stage-1 inner loop and eliminate the runtime switch +// that `apply_transform` would otherwise perform per element. Used by the +// per-mode specializations of `fast_topk_clean_fused` in topk_sglang.cu. +template +__device__ __forceinline__ float apply_transform_tmpl(float x, float p) { + if constexpr (MODE == MAPPING_POWER) return transform_power(x, p); + else if constexpr (MODE == MAPPING_LOG) return transform_log(x); + else if constexpr (MODE == MAPPING_ASINH) return transform_asinh(x, p); + else if constexpr (MODE == MAPPING_LOG1P) return transform_log1p(x, p); + else if constexpr (MODE == MAPPING_ERF) return transform_erf(x, p); + else if constexpr (MODE == MAPPING_TANH) return transform_tanh(x, p); + else if constexpr (MODE == MAPPING_SUBTRACT) return x - p; + else if constexpr (MODE == MAPPING_EXP_STRETCH) return transform_exp_stretch(x, p); + else if constexpr (MODE == MAPPING_SHIFT_POW2) return transform_shift_pow2(x, p); + else if constexpr (MODE == MAPPING_SHIFT_POW3) return transform_shift_pow3(x, p); + else if constexpr (MODE == MAPPING_LINEAR_STEEP) return transform_linear_steep(x, p); + else if constexpr (MODE == MAPPING_HALF_SQUARE) return transform_half_square(x, p); + else if constexpr (MODE == MAPPING_HALF_CUBE) return transform_half_cube(x, p); + else if constexpr (MODE == MAPPING_DENSE_MANT) return fmaxf(x, p); + else return x; // NONE / TRUNC8 +} + // Pure element-wise dispatcher. Returns the *float value* after the transform. 
// For bin-selection modes (LUT_CDF / QUANTILE) this is identity: the mapping // happens in compute_stage1_bin() below instead of via a float transform, so // Stage-2 tie-breaking uses the raw score bits for those modes. __device__ __forceinline__ float apply_transform(float x, const TopKMappingParams& params) { switch (params.mode) { - case MAPPING_POWER: return transform_power(x, params.power_exp); - case MAPPING_LOG: return transform_log(x); - case MAPPING_ASINH: return transform_asinh(x, params.power_exp); - case MAPPING_LOG1P: return transform_log1p(x, params.power_exp); - case MAPPING_ERF: return transform_erf(x, params.power_exp); - case MAPPING_TANH: return transform_tanh(x, params.power_exp); - case MAPPING_SUBTRACT: return x - params.power_exp; - case MAPPING_EXP_STRETCH: return transform_exp_stretch(x, params.power_exp); + case MAPPING_POWER: return transform_power(x, params.power_exp); + case MAPPING_LOG: return transform_log(x); + case MAPPING_ASINH: return transform_asinh(x, params.power_exp); + case MAPPING_LOG1P: return transform_log1p(x, params.power_exp); + case MAPPING_ERF: return transform_erf(x, params.power_exp); + case MAPPING_TANH: return transform_tanh(x, params.power_exp); + case MAPPING_SUBTRACT: return x - params.power_exp; + case MAPPING_EXP_STRETCH: return transform_exp_stretch(x, params.power_exp); + case MAPPING_SHIFT_POW2: return transform_shift_pow2(x, params.power_exp); + case MAPPING_SHIFT_POW3: return transform_shift_pow3(x, params.power_exp); + case MAPPING_LINEAR_STEEP: return transform_linear_steep(x, params.power_exp); + case MAPPING_HALF_SQUARE: return transform_half_square(x, params.power_exp); + case MAPPING_HALF_CUBE: return transform_half_cube(x, params.power_exp); + // MAPPING_DENSE_MANT clamps small/negative values to `power_exp` + // (default 0.5) so the subsequent dense bit bucket in the fused + // kernel sees a narrow 1–2 exponent window of positive values. 
+ // Values at/below the clamp all hash to the lowest bin, which + // is always below the topk threshold in practice. + case MAPPING_DENSE_MANT: return fmaxf(x, params.power_exp); case MAPPING_LUT_CDF: case MAPPING_QUANTILE: case MAPPING_TRUNC8: - default: return x; // NONE / TRUNC8 / LUT_CDF / QUANTILE + default: return x; // NONE / TRUNC8 / LUT_CDF / QUANTILE } } diff --git a/csrc/topk_sglang.cu b/csrc/topk_sglang.cu index 2466f57..fa9c825 100644 --- a/csrc/topk_sglang.cu +++ b/csrc/topk_sglang.cu @@ -99,6 +99,23 @@ __device__ __forceinline__ auto convert_to_uint32(float x) -> uint32_t { return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u); } +// Mantissa-heavy Stage-1 bucket for MAPPING_DENSE_MANT. Returns bits +// [23:16] of the sign-adjusted float32 key = 1 exp LSB + 7 top +// mantissa bits. This yields 128 mantissa sub-bins per exp slot (vs +// 4 in the current fp16 scheme — 32× more resolution) and is strictly +// monotonic across 2 adjacent fp32 exponent slots (factor-of-4 value +// range). Designed for the common case where the top-K scores cluster +// tightly: softmax-attention outputs on Qwen / Llama typically live +// in ~1 exp slot of magnitude near the top. Values with exponents +// outside the 2-slot monotonic window collide with lower bins, which +// only causes a correctness issue if top-K elements span more than +// 2 exp slots — verified empirically before shipping. +__device__ __forceinline__ auto convert_to_uint8_dense(float x) -> uint8_t { + const uint32_t bits = __float_as_uint(x); + const uint32_t key = (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u); + return static_cast((key >> 16) & 0xFFu); +} + // ---- Vortex additions ---- template @@ -632,7 +649,7 @@ __device__ void fast_topk_clean( // benchmarking kernel, the remapped Stage-2 ordering is acceptable. // No pre-pass, no LUT, no shared-memory mapping state. 
// ====================================================================== -template +template __device__ void fast_topk_clean_fused( const ScoreT* __restrict__ input, int* __restrict__ index, @@ -651,34 +668,48 @@ __device__ void fast_topk_clean_fused( alignas(128) __shared__ int f_threshold_bin_id; alignas(128) __shared__ int f_num_input[2]; - // Shared-memory tables for MAPPING_LUT_CDF / MAPPING_QUANTILE. Loaded - // once at kernel entry and read per element in Stage 1. Other modes - // leave them untouched. - __shared__ uint8_t s_mapping_lut[256]; - __shared__ float s_mapping_quantiles[256]; + // Per-element Stage-1 bin cache. Pass 1 of Stage 1 writes one byte per + // element; pass 2 reads it back so each element only pays a single + // apply_transform + global score read instead of two. Sized to the + // maximum `pages_per_seg` the bench drivers use (topk=2048 config has + // seq_len=32768 / page_size=8 = 4096 pages per segment; topk=30 has + // 2048). Shrinking from 8192 to 4096 freed 4 KB of static SMEM per + // block, which lifts occupancy from 5 → 6 blocks/SM on B200. + constexpr int kFusedMaxLen = 4096; + __shared__ uint8_t s_bins[kFusedMaxLen]; auto& f_histogram = f_histogram_buf[0]; extern __shared__ int f_input_idx[][SMEM_INPUT_SIZE]; const int tx = threadIdx.x; - if (mapping.mode == MAPPING_LUT_CDF && mapping.lut != nullptr) { - if (tx < 256) s_mapping_lut[tx] = mapping.lut[tx]; - __syncthreads(); - } - if (mapping.mode == MAPPING_QUANTILE && mapping.quantiles != nullptr) { - if (tx < 256) s_mapping_quantiles[tx] = mapping.quantiles[tx]; - __syncthreads(); - } + // MODE is a compile-time template parameter, so every comparison below + // becomes a constant-folded `if constexpr` branch. The dense bucket + // path (MAPPING_DENSE_MANT) stays in the kernel but is completely + // elided when MODE != MAPPING_DENSE_MANT, and the value-space transform + // path stays in place for standard modes. 
LUT_CDF / QUANTILE are not + // supported by this templated kernel (they were dropped from the bench + // comparison earlier). + constexpr bool use_dense_bucket = (MODE == MAPPING_DENSE_MANT); if (tx < RADIX + 1) f_histogram[tx] = 0; __syncthreads(); - // Stage 1: LUT/QUANTILE do a shared-memory lookup, everything else - // applies the element-wise transform then buckets via convert_to_uint8. + // Stage 1 pass 1: read each score from global, compute the Stage-1 + // bin via the compile-time-dispatched transform, cache it in s_bins so + // pass 2 can skip the second global read. With MODE known at compile + // time, apply_transform_tmpl inlines to just the chosen + // transform's instructions — no runtime switch overhead. for (int idx = tx; idx < length; idx += BLOCK_SIZE) { const float raw = vortex_to_float(input[idx + row_start]); - const auto bin = compute_stage1_bin(raw, mapping, s_mapping_lut, s_mapping_quantiles); + const float remapped = apply_transform_tmpl(raw, mapping.power_exp); + int bin; + if constexpr (use_dense_bucket) { + bin = static_cast(convert_to_uint8_dense(remapped)); + } else { + bin = static_cast(convert_to_uint8(remapped)); + } + s_bins[idx] = static_cast(bin); ::atomicAdd(&f_histogram[bin], 1); } __syncthreads(); @@ -712,10 +743,11 @@ __device__ void fast_topk_clean_fused( topk -= f_histogram[threshold_bin + 1]; if (topk == 0) { + // Shortcut: every page above threshold gets selected. Read the bin + // from the cache so we don't re-touch global memory or recompute + // apply_transform. 
for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - const float raw = vortex_to_float(input[idx + row_start]); - const auto bin = static_cast( - compute_stage1_bin(raw, mapping, s_mapping_lut, s_mapping_quantiles)); + const int bin = static_cast(s_bins[idx]); if (bin > threshold_bin) { const auto pos = ::atomicAdd(&f_counter, 1); index[pos] = idx; @@ -728,20 +760,33 @@ __device__ void fast_topk_clean_fused( if (tx < RADIX + 1) f_histogram[tx] = 0; __syncthreads(); + // Stage 1 pass 2: read the cached bin from SMEM. For elements + // outside the threshold bin we skip the global-memory load AND the + // apply_transform call entirely. Only the ~thr_size threshold-bin + // candidates re-read raw and re-apply the templated transform to + // compute the sub-bin needed for Stage-2 refinement. + // + // Sub-bin shift selection (compile-time constant): + // - standard modes: Stage-1 used fp16 top-8-bit bucketing, so + // Stage-2 round 0 refines on uint32 bits [31:24] (the most + // significant bits not captured by the fp16 bucket). + // - MAPPING_DENSE_MANT: Stage-1 used bits [23:16], so the next + // useful discriminator is bits [15:8]. Skipping to offset 8 + // directly avoids two wasted Stage-2 rounds. + constexpr int sub_bin_offset_start = use_dense_bucket ? 
8 : 24; for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - const float raw = vortex_to_float(input[idx + row_start]); - const float remapped = apply_transform(raw, mapping); - const auto bin = static_cast( - compute_stage1_bin(raw, mapping, s_mapping_lut, s_mapping_quantiles)); + const int bin = static_cast(s_bins[idx]); if (bin > threshold_bin) { const auto pos = ::atomicAdd(&f_counter, 1); index[pos] = idx; } else if (bin == threshold_bin) { + const float raw = vortex_to_float(input[idx + row_start]); + const float remapped = apply_transform_tmpl(raw, mapping.power_exp); const auto pos = ::atomicAdd(&f_num_input[0], 1); if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { f_input_idx[0][pos] = idx; const auto b32 = convert_to_uint32(remapped); - const auto sub_bin = (b32 >> 24) & 0xFF; + const auto sub_bin = (b32 >> sub_bin_offset_start) & 0xFF; ::atomicAdd(&f_histogram[sub_bin], 1); } } @@ -749,9 +794,17 @@ __device__ void fast_topk_clean_fused( __syncthreads(); } - // stage 2: refine on raw bits of the remapped value + // stage 2: refine on raw bits of the remapped value. The per-round + // bit offset matches the sub_bin shift chosen above: standard modes + // start at offset 24 (bits [31:24]) and step down by 8 per round; + // MAPPING_DENSE_MANT starts at offset 8 (bits [15:8]) because Stage 1 + // already consumed bits [23:16] in the dense bucket. Both values are + // compile-time constants since MODE is a template parameter. + constexpr int stage2_offset_start = use_dense_bucket ? 8 : 24; + constexpr int stage2_max_rounds = use_dense_bucket ? 
2 : 4; #pragma unroll 4 for (int round = 0; round < 4; ++round) { + if (round >= stage2_max_rounds) break; __shared__ int f_last_remain; const auto r_idx = round % 2; @@ -772,9 +825,9 @@ __device__ void fast_topk_clean_fused( if (topk == 0) { for (int i = tx; i < num_input; i += BLOCK_SIZE) { const auto idx = f_input_idx[r_idx][i]; - const auto offset = 24 - round * 8; + const auto offset = stage2_offset_start - round * 8; const float raw = vortex_to_float(input[idx + row_start]); - const float remapped = apply_transform(raw, mapping); + const float remapped = apply_transform_tmpl(raw, mapping.power_exp); const auto bin = (convert_to_uint32(remapped) >> offset) & 0xFF; if (bin > threshold_bin) { const auto pos = ::atomicAdd(&f_counter, 1); @@ -790,14 +843,18 @@ __device__ void fast_topk_clean_fused( for (int i = tx; i < num_input; i += BLOCK_SIZE) { const auto idx = f_input_idx[r_idx][i]; const float raw = vortex_to_float(input[idx + row_start]); - const float remapped = apply_transform(raw, mapping); - const auto offset = 24 - round * 8; + const float remapped = apply_transform_tmpl(raw, mapping.power_exp); + const auto offset = stage2_offset_start - round * 8; const auto bin = (convert_to_uint32(remapped) >> offset) & 0xFF; if (bin > threshold_bin) { const auto pos = ::atomicAdd(&f_counter, 1); index[pos] = idx; } else if (bin == threshold_bin) { - if (round == 3) { + // Last refinement round: we have no more discriminator bits + // below the current offset, so emit any remaining elements as + // "tie-break fallback" via f_last_remain (ensures topk is met + // even when thr_size > sel_thr at the finest granularity). 
+ if (round == stage2_max_rounds - 1) { const auto pos = ::atomicAdd(&f_last_remain, -1); if (pos > 0) { index[target_k - pos] = idx; @@ -855,7 +912,7 @@ void TopKOutput_Clean_Kernel( } } -template +template __global__ __launch_bounds__(kThreadsPerBlock) void TopKOutput_Fused_Kernel( const ScoreT* __restrict__ score, @@ -882,7 +939,7 @@ void TopKOutput_Fused_Kernel( + page_reserved_bos; __shared__ int s_indices[VORTEX_MAX_TOPK]; - fast_topk_clean_fused(score_blk, s_indices, 0, nblk, topk_val, mapping); + fast_topk_clean_fused(score_blk, s_indices, 0, nblk, topk_val, mapping); __syncthreads(); const int tx = threadIdx.x; @@ -891,16 +948,32 @@ void TopKOutput_Fused_Kernel( } } +// Inverse of vortex_to_float: narrow a float back to ScoreT for the +// bf16-output remap path so the subsequent topk kernel can read half +// the bytes of a fp32 remapped buffer. +template +__device__ __forceinline__ T float_to_vortex(float x); +template <> +__device__ __forceinline__ float float_to_vortex(float x) { return x; } +template <> +__device__ __forceinline__ __nv_bfloat16 float_to_vortex<__nv_bfloat16>(float x) { + return __float2bfloat16(x); +} + // Remap-only kernel: applies the element-wise transform to each score // in the [dense_kv_indptr[b] + reserved_bos, dense_kv_indptr[b+1] - reserved_eos) -// range and writes the result into a float32 output tensor. Used by -// the split-phase benchmark (remap → unmapped topk). -template +// range and writes the result into an output tensor (OutT = float or +// bf16). Used by the split-phase benchmark (remap → unmapped topk). +// Writing bf16 halves memory bandwidth on the output and on the +// subsequent topk read; precision-wise it's lossless for the Stage-1 +// 8-bit bucket because fp16/bf16 both discard more mantissa than the +// bucket uses. 
+template __global__ __launch_bounds__(kThreadsPerBlock) void TopKRemapOnly_Kernel( const ScoreT* __restrict__ score, const int* __restrict__ dense_kv_indptr, - float* __restrict__ remapped, + OutT* __restrict__ remapped, const int page_reserved_bos, const int page_reserved_eos, const TopKMappingParams mapping) @@ -914,10 +987,11 @@ void TopKRemapOnly_Kernel( if (nblk <= 0) return; const ScoreT* __restrict__ score_blk = score + start; - float* __restrict__ remap_blk = remapped + start; + OutT* __restrict__ remap_blk = remapped + start; for (int i = tx; i < nblk; i += kThreadsPerBlock) { - remap_blk[i] = apply_transform(vortex_to_float(score_blk[i]), mapping); + const float y = apply_transform(vortex_to_float(score_blk[i]), mapping); + remap_blk[i] = float_to_vortex(y); } } @@ -1118,6 +1192,18 @@ void topk_output_sglang_fused( "topk_output_sglang_fused: topk_val (", topk_val, ") exceeds VORTEX_MAX_TOPK (", VORTEX_MAX_TOPK, ")"); + // Caller contract: max_num_pages must be <= 4096, the static SMEM + // `s_bins` cache size inside the templated fused kernel. The bench + // drivers stay within this bound; no runtime check is emitted in + // the hot path. + + // The `mapping_lut` / `mapping_quantiles` optional tensors are + // retained in the pybind signature for API backward compatibility + // but are ignored: the templated fused kernel drops the LUT_CDF / + // QUANTILE code paths entirely. 
+ (void)mapping_lut; + (void)mapping_quantiles; + CHECK_CUDA(x); CHECK_CUDA(dense_kv_indptr); CHECK_CUDA(sparse_kv_indptr); @@ -1125,51 +1211,66 @@ void topk_output_sglang_fused( CHECK_CUDA(sparse_kv_indices); TopKMappingParams mapping{}; - mapping.mode = static_cast(mapping_mode); + mapping.mode = static_cast(mapping_mode); mapping.power_exp = static_cast(mapping_power); - mapping.lut = nullptr; + mapping.lut = nullptr; mapping.quantiles = nullptr; - if (mapping_lut.has_value()) { - const auto& lut = mapping_lut.value(); - CHECK_CUDA(lut); - TORCH_CHECK(lut.dim() == 1 && lut.size(0) == 256 && lut.scalar_type() == at::ScalarType::Byte, - "mapping_lut must be a 1D uint8 tensor of size 256"); - mapping.lut = lut.data_ptr(); - } - if (mapping_quantiles.has_value()) { - const auto& q = mapping_quantiles.value(); - CHECK_CUDA(q); - TORCH_CHECK(q.dim() == 1 && q.size(0) == 256 && q.scalar_type() == at::ScalarType::Float, - "mapping_quantiles must be a 1D float32 tensor of size 256"); - mapping.quantiles = q.data_ptr(); - } dim3 nblks(eff_batch_size); dim3 nthreads(kThreadsPerBlock); cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + // Each mapping mode compiles to its own kernel specialization so + // apply_transform_tmpl is fully inlined (no runtime switch on + // mode in the inner loop). The wrapper's outer dispatch is a one- + // time per-call cost, negligible relative to the kernel runtime. 
+ #define VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MODE_VAL) \ + do { \ + setup_kernel_smem_once, kSmem>(); \ + TopKOutput_Fused_Kernel<<>>( \ + PTR_EXPR, \ + dense_kv_indptr.data_ptr(), \ + sparse_kv_indptr.data_ptr(), \ + dense_kv_indices.data_ptr(), \ + sparse_kv_indices.data_ptr(), \ + topk_val, reserved_bos, reserved_eos, mapping); \ + } while (0) + + #define VORTEX_DISPATCH_MODE(DTYPE, PTR_EXPR) \ + do { \ + switch (mapping.mode) { \ + case MAPPING_NONE: VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MAPPING_NONE); break; \ + case MAPPING_POWER: VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MAPPING_POWER); break; \ + case MAPPING_LOG: VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MAPPING_LOG); break; \ + case MAPPING_ASINH: VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MAPPING_ASINH); break; \ + case MAPPING_LOG1P: VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MAPPING_LOG1P); break; \ + case MAPPING_TRUNC8: VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MAPPING_TRUNC8); break; \ + case MAPPING_ERF: VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MAPPING_ERF); break; \ + case MAPPING_TANH: VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MAPPING_TANH); break; \ + case MAPPING_SUBTRACT: VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MAPPING_SUBTRACT); break; \ + case MAPPING_EXP_STRETCH: VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MAPPING_EXP_STRETCH); break; \ + case MAPPING_SHIFT_POW2: VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MAPPING_SHIFT_POW2); break; \ + case MAPPING_SHIFT_POW3: VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MAPPING_SHIFT_POW3); break; \ + case MAPPING_LINEAR_STEEP:VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MAPPING_LINEAR_STEEP); break; \ + case MAPPING_HALF_SQUARE: VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MAPPING_HALF_SQUARE); break; \ + case MAPPING_HALF_CUBE: VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MAPPING_HALF_CUBE); break; \ + case MAPPING_DENSE_MANT: VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MAPPING_DENSE_MANT); break; \ + default: \ + TORCH_CHECK(false, "topk_output_sglang_fused: unsupported mapping_mode ", mapping.mode); \ + } 
\ + } while (0) + if (x.scalar_type() == at::ScalarType::BFloat16) { - setup_kernel_smem_once, kSmem>(); - TopKOutput_Fused_Kernel<__nv_bfloat16><<>>( - reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), - dense_kv_indptr.data_ptr(), - sparse_kv_indptr.data_ptr(), - dense_kv_indices.data_ptr(), - sparse_kv_indices.data_ptr(), - topk_val, reserved_bos, reserved_eos, mapping); + VORTEX_DISPATCH_MODE(__nv_bfloat16, reinterpret_cast<__nv_bfloat16*>(x.data_ptr())); } else if (x.scalar_type() == at::ScalarType::Float) { - setup_kernel_smem_once, kSmem>(); - TopKOutput_Fused_Kernel<<>>( - x.data_ptr(), - dense_kv_indptr.data_ptr(), - sparse_kv_indptr.data_ptr(), - dense_kv_indices.data_ptr(), - sparse_kv_indices.data_ptr(), - topk_val, reserved_bos, reserved_eos, mapping); + VORTEX_DISPATCH_MODE(float, x.data_ptr()); } else { TORCH_CHECK(false, "topk_output_sglang_fused: unsupported dtype ", x.scalar_type()); } + #undef VORTEX_DISPATCH_MODE + #undef VORTEX_DISPATCH_FUSED + const auto result = cudaGetLastError(); TORCH_CHECK(result == cudaSuccess, "topk_output_sglang_fused kernel failed: ", ::cudaGetErrorString(result)); @@ -1183,7 +1284,7 @@ void topk_output_sglang_fused( void topk_remap_only( const at::Tensor& x, const at::Tensor& dense_kv_indptr, - at::Tensor& remapped, // float32, same numel as x + at::Tensor& remapped, // float32 or bfloat16, same numel as x const int64_t eff_batch_size, const int64_t reserved_bos, const int64_t reserved_eos, @@ -1193,35 +1294,57 @@ void topk_remap_only( CHECK_CUDA(x); CHECK_CUDA(dense_kv_indptr); CHECK_CUDA(remapped); - TORCH_CHECK(remapped.scalar_type() == at::ScalarType::Float, - "remapped output must be float32"); + TORCH_CHECK(remapped.scalar_type() == at::ScalarType::Float + || remapped.scalar_type() == at::ScalarType::BFloat16, + "remapped output must be float32 or bfloat16"); TopKMappingParams mapping{}; - mapping.mode = static_cast(mapping_mode); + mapping.mode = static_cast(mapping_mode); mapping.power_exp = 
static_cast(mapping_power); - mapping.lut = nullptr; + mapping.lut = nullptr; mapping.quantiles = nullptr; dim3 nblks(eff_batch_size); dim3 nthreads(kThreadsPerBlock); cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - if (x.scalar_type() == at::ScalarType::BFloat16) { - TopKRemapOnly_Kernel<__nv_bfloat16><<>>( + // Four-way dispatch on (input dtype, output dtype). bf16→bf16 is the + // new "batch pre-transform" path that halves memory bandwidth vs the + // fp32 output: the remap writes half the bytes and the subsequent + // topk_output_sglang reads half the bytes. Precision is preserved + // because Stage-1 bucketing only uses the top 8 bits of an fp16 key + // which both fp32 and bf16 capture. + #define VORTEX_DISPATCH_REMAP(IN_CPP, OUT_CPP, IN_PTR_EXPR, OUT_PTR_EXPR) \ + TopKRemapOnly_Kernel<<>>( \ + IN_PTR_EXPR, dense_kv_indptr.data_ptr(), OUT_PTR_EXPR, \ + reserved_bos, reserved_eos, mapping) + + const bool in_bf16 = (x.scalar_type() == at::ScalarType::BFloat16); + const bool in_fp32 = (x.scalar_type() == at::ScalarType::Float); + const bool out_bf16 = (remapped.scalar_type() == at::ScalarType::BFloat16); + + if (in_bf16 && out_bf16) { + VORTEX_DISPATCH_REMAP(__nv_bfloat16, __nv_bfloat16, reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), - dense_kv_indptr.data_ptr(), - remapped.data_ptr(), - reserved_bos, reserved_eos, mapping); - } else if (x.scalar_type() == at::ScalarType::Float) { - TopKRemapOnly_Kernel<<>>( + reinterpret_cast<__nv_bfloat16*>(remapped.data_ptr())); + } else if (in_bf16 && !out_bf16) { + VORTEX_DISPATCH_REMAP(__nv_bfloat16, float, + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + remapped.data_ptr()); + } else if (in_fp32 && out_bf16) { + VORTEX_DISPATCH_REMAP(float, __nv_bfloat16, x.data_ptr(), - dense_kv_indptr.data_ptr(), - remapped.data_ptr(), - reserved_bos, reserved_eos, mapping); + reinterpret_cast<__nv_bfloat16*>(remapped.data_ptr())); + } else if (in_fp32 && !out_bf16) { + VORTEX_DISPATCH_REMAP(float, float, + 
x.data_ptr(), + remapped.data_ptr()); } else { TORCH_CHECK(false, "topk_remap_only: unsupported dtype ", x.scalar_type()); } + #undef VORTEX_DISPATCH_REMAP + const auto result = cudaGetLastError(); TORCH_CHECK(result == cudaSuccess, "topk_remap_only kernel failed: ", ::cudaGetErrorString(result)); diff --git a/csrc/topk_sglang_ori.cu b/csrc/topk_sglang_ori.cu new file mode 100644 index 0000000..55a99b2 --- /dev/null +++ b/csrc/topk_sglang_ori.cu @@ -0,0 +1,619 @@ +/** + * @NOTE: This file is adapted from + * https://github.com/tile-ai/tilelang/blob/main/examples/deepseek_v32/topk_selector.py + * We: + * 1. adapt from tilelang to pure cuda + * 2. optimize the performance a little + * 3. fix the potential illegal memory access + */ + #include + #include + #include + #include + #include + #include + #include + #include + + #include + #include + #include + + namespace { + + // NOTE: TopK is a compile-time constant here because shared-memory + // allocations inside the transform kernels depend on it. We drop it to + // 30 to match the vortex benchmark's --topk-val 30 configuration. The + // transform kernels (decode/prefill/prefill_ragged) still carry a manual + // unroll that assumes TopK==2048; that code path is unreachable from the + // bench (we only invoke fast_topk_interface), so the corresponding + // static_asserts have been removed below. + constexpr int TopK = 30; + constexpr int kThreadsPerBlock = 1024; + + #ifdef USE_ROCM + // On ROCm, the per-workgroup LDS budget depends on the target arch, so we inject a + // per-arch value from `setup_rocm.py` via `-DSGL_TOPK_DYNAMIC_SMEM_BYTES=...`. + #ifdef SGL_TOPK_DYNAMIC_SMEM_BYTES + constexpr size_t kSmem = static_cast(SGL_TOPK_DYNAMIC_SMEM_BYTES); + #else + constexpr size_t kSmem = 48 * 1024; // bytes + #endif + #else + // Reduced from 128KB to 32KB to improve occupancy. 
+ // Each radix pass needs at most ~TopK candidates in the threshold bin, + // so 4K entries per round (2 rounds = 8K entries = 32KB) is sufficient. + constexpr size_t kSmem = 8 * 1024 * sizeof(uint32_t); // 32KB (bytes) + #endif + + struct FastTopKParams { + const float* __restrict__ input; // [B, input_stride] + const int32_t* __restrict__ row_starts; // [B] + int32_t* __restrict__ indices; // [B, TopK] + int32_t* __restrict__ lengths; // [B] + int64_t input_stride; + }; + + // when length <= TopK, we can directly write the indices + __device__ void naive_topk_cuda(const float* __restrict__ score, int32_t* __restrict__ indice, int32_t length) { + const auto tid = threadIdx.x; + for (int i = tid; i < TopK; i += kThreadsPerBlock) { + indice[i] = (i < length) ? i : -1; + } + } + + // keep the first `length` entries, set others to -1 + __device__ void naive_topk_transform( + const float* __restrict__ score, + int32_t length, + int32_t* __restrict__ dst_page_table, + const int32_t* __restrict__ src_page_table) { + const auto tid = threadIdx.x; + for (auto i = tid; i < TopK; i += kThreadsPerBlock) { + dst_page_table[i] = (i < length) ? src_page_table[i] : -1; + } + } + + // keep the first `length` entries, set others to -1 + __device__ void naive_topk_transform_ragged( + const float* __restrict__ score, int32_t length, int32_t* __restrict__ topk_indices_ragged, int32_t offset) { + const auto tid = threadIdx.x; + for (auto i = tid; i < TopK; i += kThreadsPerBlock) { + topk_indices_ragged[i] = (i < length) ? static_cast(i) + offset : -1; + } + } + + __device__ __forceinline__ auto convert_to_uint8(float x) -> uint8_t { + __half h = __float2half_rn(x); + uint16_t bits = __half_as_ushort(h); + uint16_t key = (bits & 0x8000) ? static_cast(~bits) : static_cast(bits | 0x8000); + return static_cast(key >> 8); + } + + __device__ __forceinline__ auto convert_to_uint32(float x) -> uint32_t { + uint32_t bits = __float_as_uint(x); + return (bits & 0x80000000u) ? 
~bits : (bits | 0x80000000u); + } + + __device__ void fast_topk_cuda_tl(const float* __restrict__ input, int* __restrict__ index, int row_start, int length) { + // An optimized topk kernel copied from tilelang kernel + // We assume length > TopK here, or it will crash + int topk = TopK; + constexpr auto BLOCK_SIZE = 1024; + constexpr auto RADIX = 256; + constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); + + alignas(128) __shared__ int s_histogram_buf[2][RADIX + 128]; + alignas(128) __shared__ int s_counter; + alignas(128) __shared__ int s_threshold_bin_id; + alignas(128) __shared__ int s_num_input[2]; + + auto& s_histogram = s_histogram_buf[0]; + // allocate for two rounds + extern __shared__ int s_input_idx[][SMEM_INPUT_SIZE]; + + const int tx = threadIdx.x; + + // stage 1: 8bit coarse histogram + if (tx < RADIX + 1) s_histogram[tx] = 0; + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = convert_to_uint8(input[idx + row_start]); + ::atomicAdd(&s_histogram[bin], 1); + } + __syncthreads(); + + const auto run_cumsum = [&] { + #pragma unroll 8 + for (int i = 0; i < 8; ++i) { + static_assert(1 << 8 == RADIX); + if (C10_LIKELY(tx < RADIX)) { + const auto j = 1 << i; + const auto k = i & 1; + auto value = s_histogram_buf[k][tx]; + if (tx < RADIX - j) { + value += s_histogram_buf[k][tx + j]; + } + s_histogram_buf[k ^ 1][tx] = value; + } + __syncthreads(); + } + }; + + run_cumsum(); + if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { + s_threshold_bin_id = tx; + s_num_input[0] = 0; + s_counter = 0; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + topk -= s_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = static_cast(convert_to_uint8(input[idx + row_start])); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + return; + } 
else { + __syncthreads(); + if (tx < RADIX + 1) { + s_histogram[tx] = 0; + } + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto raw_input = input[idx + row_start]; + const auto bin = static_cast(convert_to_uint8(raw_input)); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + const auto pos = ::atomicAdd(&s_num_input[0], 1); + /// NOTE: (dark) fuse the histogram computation here + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + s_input_idx[0][pos] = idx; + const auto bin = convert_to_uint32(raw_input); + const auto sub_bin = (bin >> 24) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } + } + __syncthreads(); + } + + // stage 2: refine with 8bit radix passes + #pragma unroll 4 + for (int round = 0; round < 4; ++round) { + __shared__ int s_last_remain; + const auto r_idx = round % 2; + + // clip here to prevent overflow + const auto _raw_num_input = s_num_input[r_idx]; + const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) ? 
_raw_num_input : int(SMEM_INPUT_SIZE); + + run_cumsum(); + if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { + s_threshold_bin_id = tx; + s_num_input[r_idx ^ 1] = 0; + s_last_remain = topk - s_histogram[tx + 1]; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + topk -= s_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(input[idx + row_start]) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + break; + } else { + __syncthreads(); + if (tx < RADIX + 1) { + s_histogram[tx] = 0; + } + __syncthreads(); + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto raw_input = input[idx + row_start]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + if (round == 3) { + const auto pos = ::atomicAdd(&s_last_remain, -1); + if (pos > 0) { + index[TopK - pos] = idx; + } + } else { + const auto pos = ::atomicAdd(&s_num_input[r_idx ^ 1], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + /// NOTE: (dark) fuse the histogram computation here + s_input_idx[r_idx ^ 1][pos] = idx; + const auto bin = convert_to_uint32(raw_input); + const auto sub_bin = (bin >> (offset - 8)) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } + } + } + __syncthreads(); + } + } + } + + __global__ __launch_bounds__(kThreadsPerBlock) // topk + void topk_kernel(const FastTopKParams params) { + const auto& [input, row_starts, indices, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto row_start = row_starts == nullptr ? 
0 : row_starts[bid]; + const auto length = lengths[bid]; + const auto indice = indices + bid * TopK; + const auto score = input + bid * input_stride; + if (length <= TopK) { + return naive_topk_cuda(score, indice, length); + } else { + return fast_topk_cuda_tl(score, indice, row_start, length); + } + } + + __global__ __launch_bounds__(kThreadsPerBlock) // decode + void topk_transform_decode_kernel( + const FastTopKParams params, + int32_t* __restrict__ dst_page_table, + const int32_t* __restrict__ src_page_table, + const int64_t src_stride) { + const auto& [input, _1, _2, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto tid = threadIdx.x; + const auto row_start = 0; + const auto length = lengths[bid]; + const auto src_page_entry = src_page_table + bid * src_stride; + const auto dst_page_entry = dst_page_table + bid * TopK; + const auto score = input + bid * input_stride; + if (length <= TopK) { + return naive_topk_transform(score, length, dst_page_entry, src_page_entry); + } else { + __shared__ int s_indices[TopK]; + fast_topk_cuda_tl(score, s_indices, row_start, length); + // copy src[s_indices] to dst, we manually unroll here + // (static_asserts removed because TopK != 2048 in this build; the + // manual unroll below is unreachable from bench_topk.py which only + // calls fast_topk_interface, not this transform variant.) 
+ const auto idx_0 = tid; + const auto pos_0 = s_indices[idx_0]; + dst_page_entry[idx_0] = src_page_entry[pos_0]; + const auto idx_1 = tid + kThreadsPerBlock; + const auto pos_1 = s_indices[idx_1]; + dst_page_entry[idx_1] = src_page_entry[pos_1]; + } + } + + __global__ __launch_bounds__(kThreadsPerBlock) // prefill + void topk_transform_prefill_kernel( + const FastTopKParams params, + int32_t* __restrict__ dst_page_table, + const int32_t* __restrict__ src_page_table, + const int64_t src_stride, + const int32_t* __restrict__ cu_seqlens_q, + const int64_t prefill_bs) { + const auto& [input, row_starts, _, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto tid = threadIdx.x; + const auto length = lengths[bid]; + const auto row_start = row_starts == nullptr ? 0 : row_starts[bid]; + const auto dst_page_entry = dst_page_table + bid * TopK; + const auto score = input + bid * input_stride; + + /// NOTE: prefill bs is usually small, we can just use a simple loop here + /// We ensure that last cu_seqlens is equal to number of blocks launched + __shared__ const int32_t* s_src_page_entry; + if (C10_LIKELY(prefill_bs <= kThreadsPerBlock)) { + if (tid < prefill_bs) { + if (bid >= cu_seqlens_q[tid] && bid < cu_seqlens_q[tid + 1]) { + s_src_page_entry = src_page_table + tid * src_stride; + } + } + } else { + for (int64_t i = tid; i < prefill_bs; i += kThreadsPerBlock) { + if (bid >= cu_seqlens_q[i] && bid < cu_seqlens_q[i + 1]) { + s_src_page_entry = src_page_table + i * src_stride; + } + } + } + __syncthreads(); + const auto src_page_entry = s_src_page_entry; + + if (length <= TopK) { + return naive_topk_transform(score, length, dst_page_entry, src_page_entry); + } else { + __shared__ int s_indices[TopK]; + fast_topk_cuda_tl(score, s_indices, row_start, length); + // copy src[s_indices] to dst, we manually unroll here + // (static_asserts removed because TopK != 2048 in this build; the + // manual unroll below is unreachable from 
bench_topk.py which only + // calls fast_topk_interface, not this transform variant.) + const auto idx_0 = tid; + const auto pos_0 = s_indices[idx_0]; + dst_page_entry[idx_0] = src_page_entry[pos_0]; + const auto idx_1 = tid + kThreadsPerBlock; + const auto pos_1 = s_indices[idx_1]; + dst_page_entry[idx_1] = src_page_entry[pos_1]; + } + } + + __global__ __launch_bounds__(kThreadsPerBlock) // prefill, ragged kv + void topk_transform_prefill_ragged_kernel( + const FastTopKParams params, + int32_t* __restrict__ topk_indices_ragged, + const int32_t* __restrict__ topk_indices_offset) { + const auto& [input, row_starts, _, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto tid = threadIdx.x; + const auto row_start = row_starts == nullptr ? 0 : row_starts[bid]; + const auto length = lengths[bid]; + const auto dst_indices_entry = topk_indices_ragged + bid * TopK; + const auto score = input + bid * input_stride; + const auto offset = topk_indices_offset[bid]; + + if (length <= TopK) { + return naive_topk_transform_ragged(score, length, dst_indices_entry, offset); + } else { + __shared__ int s_indices[TopK]; + fast_topk_cuda_tl(score, s_indices, row_start, length); + // copy src[s_indices] to dst, we manually unroll here + // (static_asserts removed because TopK != 2048 in this build; the + // manual unroll below is unreachable from bench_topk.py which only + // calls fast_topk_interface, not this transform variant.) 
+ const auto idx_0 = tid; + const auto pos_0 = s_indices[idx_0]; + dst_indices_entry[idx_0] = pos_0 + offset; + const auto idx_1 = tid + kThreadsPerBlock; + const auto pos_1 = s_indices[idx_1]; + dst_indices_entry[idx_1] = pos_1 + offset; + } + } + + auto get_params( + const at::Tensor& score, + const at::Tensor& lengths, + std::optional row_starts_opt = std::nullopt, + std::optional indices_opt = std::nullopt) -> FastTopKParams { + const auto B = score.size(0); + TORCH_CHECK(score.dim() == 2 && score.stride(1) == 1); + if (row_starts_opt.has_value()) { + const auto& row_starts = row_starts_opt.value(); + TORCH_CHECK(row_starts.dim() == 1); + TORCH_CHECK(row_starts.size(0) == B); + } + TORCH_CHECK(lengths.dim() == 1 && lengths.is_contiguous()); + TORCH_CHECK(lengths.size(0) == B); + int32_t* indices_data_ptr = nullptr; + if (indices_opt.has_value()) { + const auto& indices = indices_opt.value(); + TORCH_CHECK(indices.dim() == 2 && indices.is_contiguous()); + TORCH_CHECK(indices.size(0) == B); + TORCH_CHECK(indices.size(1) == TopK); + indices_data_ptr = indices.data_ptr(); + } + + return FastTopKParams{ + .input = score.data_ptr(), + .row_starts = row_starts_opt.has_value() ? row_starts_opt->data_ptr() : nullptr, + .indices = indices_data_ptr, + .lengths = lengths.data_ptr(), + .input_stride = score.stride(0), + }; + } + + template + void setup_kernel_smem_once() { + [[maybe_unused]] + static const auto result = [] { + #ifdef USE_ROCM + // hipify will turn cudaFuncSetAttribute -> hipFuncSetAttribute. On ROCm, + // hipFuncSetAttribute expects `const void*` and hipcc does not accept passing + // a function pointer directly, so cast explicitly. + return ::cudaFuncSetAttribute( + reinterpret_cast(f), ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); + #else + // CUDA: keep original behavior (no cast needed). 
+ return ::cudaFuncSetAttribute(f, ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); + #endif + }(); + TORCH_CHECK(result == cudaSuccess, "set_up_kernel_once failed:", ::cudaGetErrorString(result)); + } + + } // namespace + + // The public interface functions below collide by name with identically + // named symbols in topk_sglang.cu. Wrap them in `sglang_ori` so both + // translation units can be linked into the same vortex_torch_C extension. + namespace sglang_ori { + + #ifndef CHECK_CUDA + #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") + #endif + + void fast_topk_interface( + const at::Tensor& score, at::Tensor& indices, const at::Tensor& lengths, std::optional row_starts_opt) { + CHECK_CUDA(score); + CHECK_CUDA(indices); + if (row_starts_opt.has_value()) { + CHECK_CUDA(row_starts_opt.value()); + } + CHECK_CUDA(lengths); + const auto params = get_params(score, lengths, row_starts_opt, indices); + const auto B = score.size(0); + const auto stream = at::cuda::getCurrentCUDAStream().stream(); + const auto grid = dim3{static_cast(B)}; + const auto block = dim3{kThreadsPerBlock}; + setup_kernel_smem_once(); + topk_kernel<<>>(params); + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); + } + + void fast_topk_transform_interface( + const at::Tensor& score, + const at::Tensor& lengths, + at::Tensor& dst_page_table, + const at::Tensor& src_page_table, + const at::Tensor& cu_seqlens_q, + std::optional row_starts_opt) { + CHECK_CUDA(score); + CHECK_CUDA(lengths); + CHECK_CUDA(dst_page_table); + CHECK_CUDA(src_page_table); + CHECK_CUDA(cu_seqlens_q); + if (row_starts_opt.has_value()) { + CHECK_CUDA(row_starts_opt.value()); + } + const auto params = get_params(score, lengths, row_starts_opt); + const auto B = score.size(0); + TORCH_CHECK(dst_page_table.dim() == 2 && dst_page_table.is_contiguous()); + TORCH_CHECK(src_page_table.dim() == 2 && 
src_page_table.stride(1) == 1); + TORCH_CHECK(cu_seqlens_q.dim() == 1 && cu_seqlens_q.is_contiguous()); + const auto prefill_bs = cu_seqlens_q.size(0) - 1; + TORCH_CHECK(dst_page_table.size(0) == B); + TORCH_CHECK(dst_page_table.size(1) == TopK); + TORCH_CHECK(src_page_table.size(0) == prefill_bs); + TORCH_CHECK(prefill_bs <= B); // prefill_bs should be smaller than expanded bs + + // launch kernel + const auto stream = at::cuda::getCurrentCUDAStream().stream(); + const auto grid = dim3{static_cast(B)}; + const auto block = dim3{kThreadsPerBlock}; + const auto src_stride = src_page_table.stride(0); + + // dispatch to decode or prefill + // extend and draft extend: row_starts_opt is not null, invokes the prefill kernel + // decode: row_starts_opt is null, invokes the decode kernel + // target verify: row_starts_opt is null, invokes the prefill kernel + const auto is_decode = !row_starts_opt.has_value() && prefill_bs == B; + if (is_decode) { + setup_kernel_smem_once(); + topk_transform_decode_kernel<<>>( + params, dst_page_table.data_ptr(), src_page_table.data_ptr(), src_stride); + } else { + setup_kernel_smem_once(); + topk_transform_prefill_kernel<<>>( + params, + dst_page_table.data_ptr(), + src_page_table.data_ptr(), + src_stride, + cu_seqlens_q.data_ptr(), + prefill_bs); + } + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); + } + + void fast_topk_transform_ragged_interface( + const at::Tensor& score, + const at::Tensor& lengths, + at::Tensor& topk_indices_ragged, + const at::Tensor& topk_indices_offset, + std::optional row_starts_opt) { + CHECK_CUDA(score); + CHECK_CUDA(lengths); + CHECK_CUDA(topk_indices_ragged); + CHECK_CUDA(topk_indices_offset); + if (row_starts_opt.has_value()) { + CHECK_CUDA(row_starts_opt.value()); + } + + const auto params = get_params(score, lengths, row_starts_opt); + const auto B = score.size(0); + TORCH_CHECK(topk_indices_ragged.dim() == 2 && 
topk_indices_ragged.is_contiguous()); + TORCH_CHECK(topk_indices_offset.dim() == 1); + + TORCH_CHECK(topk_indices_ragged.size(0) == B); + TORCH_CHECK(topk_indices_ragged.size(1) == TopK); + TORCH_CHECK(topk_indices_offset.size(0) == B); + + // launch kernel + const auto stream = at::cuda::getCurrentCUDAStream().stream(); + const auto grid = dim3{static_cast(B)}; + const auto block = dim3{kThreadsPerBlock}; + + setup_kernel_smem_once(); + topk_transform_prefill_ragged_kernel<<>>( + params, topk_indices_ragged.data_ptr(), topk_indices_offset.data_ptr()); + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); + } + + } // namespace sglang_ori + +// ====================================================================== +// Thin vortex_torch_C adapter: accepts the same CSR-ish inputs as +// topk_output_sglang so bench_topk.py can treat the original SGLang kernel +// as an alternate baseline. The ori kernel has TopK baked in as a compile- +// time constant; this build sets it to 30 to match --topk-val 30. 
+// ====================================================================== +void topk_output_sglang_ori( + const at::Tensor& x, // [total_dense, 1, 1] or [total_dense], bf16/fp32 + const at::Tensor& dense_kv_indptr, // int32 [eff_bs + 1] (unused — synthetic bench rows are uniform) + at::Tensor& indices_out, // int32 [eff_bs, TopK] + const int64_t eff_batch_size, + const int64_t topk_val, + const int64_t reserved_bos, + const int64_t reserved_eos, + const int64_t max_num_pages) +{ + TORCH_CHECK(x.is_cuda(), "x must be a CUDA tensor"); + TORCH_CHECK(dense_kv_indptr.is_cuda(), "dense_kv_indptr must be a CUDA tensor"); + TORCH_CHECK(indices_out.is_cuda(), "indices_out must be a CUDA tensor"); + TORCH_CHECK(indices_out.scalar_type() == at::ScalarType::Int, + "indices_out must be int32"); + TORCH_CHECK(topk_val == static_cast(30), + "topk_output_sglang_ori: this build of the ori kernel hard-codes TopK=30; " + "rebuild topk_sglang_ori.cu with a different TopK if you need another value. " + "Got topk_val=", topk_val); + TORCH_CHECK(indices_out.dim() == 2 + && indices_out.size(0) == eff_batch_size + && indices_out.size(1) == 30, + "indices_out must be [eff_batch_size, 30]"); + + // ori kernel requires fp32 [B, stride] scores. Caller typically passes + // the bf16 score tensor; we materialize an fp32 view once per call. 
+ at::Tensor score_f32; + if (x.scalar_type() == at::ScalarType::Float) { + score_f32 = x.contiguous().view({eff_batch_size, max_num_pages}); + } else if (x.scalar_type() == at::ScalarType::BFloat16) { + score_f32 = x.to(at::kFloat).contiguous().view({eff_batch_size, max_num_pages}); + } else { + TORCH_CHECK(false, "topk_output_sglang_ori: unsupported dtype ", x.scalar_type()); + } + + auto opts_i32 = at::TensorOptions().dtype(at::kInt).device(x.device()); + const int32_t usable_len = + static_cast(max_num_pages - reserved_bos - reserved_eos); + at::Tensor lengths = at::full({eff_batch_size}, usable_len, opts_i32); + at::Tensor row_starts = at::full({eff_batch_size}, + static_cast(reserved_bos), opts_i32); + + sglang_ori::fast_topk_interface( + score_f32, indices_out, lengths, + std::optional(row_starts)); +} \ No newline at end of file diff --git a/csrc/topk_sglang_profile.cu b/csrc/topk_sglang_profile.cu index adba2d0..7fe9981 100644 --- a/csrc/topk_sglang_profile.cu +++ b/csrc/topk_sglang_profile.cu @@ -89,6 +89,16 @@ __device__ __forceinline__ auto convert_to_uint32(float x) -> uint32_t { return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u); } +// Mirror of convert_to_uint8_dense in topk_sglang.cu so that the +// profile kernel (topk_profile_histogram / topk_profile_counters) +// reports accurate thr_bin / thr_size / abv_bins / pg/bin for +// MAPPING_DENSE_MANT. Keep in sync with the production kernel. +__device__ __forceinline__ auto convert_to_uint8_dense(float x) -> uint8_t { + const uint32_t bits = __float_as_uint(x); + const uint32_t key = (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u); + return static_cast((key >> 16) & 0xFFu); +} + template __device__ __forceinline__ float vortex_to_float(T x); template <> @@ -164,6 +174,11 @@ __device__ void fast_topk_profile( const int tx = threadIdx.x; + // Mirror of the production kernel: MAPPING_DENSE_MANT bypasses + // apply_transform and uses a mantissa-heavy fp32 bit slice for the + // Stage-1 bucket. 
+ const bool use_dense_bucket = (mapping.mode == MAPPING_DENSE_MANT); + if (mapping.mode == MAPPING_LUT_CDF && mapping.lut != nullptr) { if (tx < 256) s_mapping_lut[tx] = mapping.lut[tx]; __syncthreads(); @@ -178,7 +193,13 @@ __device__ void fast_topk_profile( for (int idx = tx; idx < length; idx += BLOCK_SIZE) { const float raw = vortex_to_float(input[idx + row_start]); - const auto bin = compute_stage1_bin(raw, mapping, s_mapping_lut, s_mapping_quantiles); + int bin; + if (use_dense_bucket) { + const float clamped = apply_transform(raw, mapping); // fmaxf(x, pivot) + bin = static_cast(convert_to_uint8_dense(clamped)); + } else { + bin = static_cast(compute_stage1_bin(raw, mapping, s_mapping_lut, s_mapping_quantiles)); + } ::atomicAdd(&p_histogram[bin], 1); } __syncthreads(); @@ -221,8 +242,13 @@ __device__ void fast_topk_profile( if (topk == 0) { for (int idx = tx; idx < length; idx += BLOCK_SIZE) { const float raw = vortex_to_float(input[idx + row_start]); - const auto bin = static_cast( - compute_stage1_bin(raw, mapping, s_mapping_lut, s_mapping_quantiles)); + int bin; + if (use_dense_bucket) { + const float clamped = apply_transform(raw, mapping); + bin = static_cast(convert_to_uint8_dense(clamped)); + } else { + bin = static_cast(compute_stage1_bin(raw, mapping, s_mapping_lut, s_mapping_quantiles)); + } if (bin > threshold_bin_0) { const auto pos = ::atomicAdd(&p_counter, 1); index[pos] = idx; @@ -240,11 +266,13 @@ __device__ void fast_topk_profile( if (tx < RADIX + 1) p_histogram[tx] = 0; __syncthreads(); + const int sub_bin_offset_start = use_dense_bucket ? 8 : 24; for (int idx = tx; idx < length; idx += BLOCK_SIZE) { const float raw = vortex_to_float(input[idx + row_start]); const float remapped = apply_transform(raw, mapping); - const auto bin = static_cast( - compute_stage1_bin(raw, mapping, s_mapping_lut, s_mapping_quantiles)); + const auto bin = use_dense_bucket + ? 
static_cast(convert_to_uint8_dense(remapped)) + : static_cast(compute_stage1_bin(raw, mapping, s_mapping_lut, s_mapping_quantiles)); if (bin > threshold_bin_0) { const auto pos = ::atomicAdd(&p_counter, 1); index[pos] = idx; @@ -253,7 +281,7 @@ __device__ void fast_topk_profile( if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { p_input_idx[0][pos] = idx; const auto b32 = convert_to_uint32(remapped); - const auto sub_bin = (b32 >> 24) & 0xFF; + const auto sub_bin = (b32 >> sub_bin_offset_start) & 0xFF; ::atomicAdd(&p_histogram[sub_bin], 1); } } @@ -265,10 +293,15 @@ __device__ void fast_topk_profile( } } - // Stage 2 refinement (4 rounds max). Default rounds=4, overwritten on exit. - if (tx == 0 && counters) counters[COUNTER_REFINE_ROUNDS] = 4; + // Stage 2 refinement. Standard modes run up to 4 rounds (offsets + // 24/16/8/0); MAPPING_DENSE_MANT runs up to 2 rounds (offsets 8/0) + // because Stage 1 already consumed bits [23:16] of the fp32 key. + const int stage2_offset_start = use_dense_bucket ? 8 : 24; + const int stage2_max_rounds = use_dense_bucket ? 
2 : 4; + if (tx == 0 && counters) counters[COUNTER_REFINE_ROUNDS] = stage2_max_rounds; #pragma unroll 4 for (int round = 0; round < 4; ++round) { + if (round >= stage2_max_rounds) break; __shared__ int p_last_remain; const auto r_idx = round % 2; const auto _raw_num_input = p_num_input[r_idx]; @@ -290,7 +323,7 @@ __device__ void fast_topk_profile( const auto idx = p_input_idx[r_idx][i]; const float raw = vortex_to_float(input[idx + row_start]); const float remapped = apply_transform(raw, mapping); - const auto offset = 24 - round * 8; + const auto offset = stage2_offset_start - round * 8; const auto bin = (convert_to_uint32(remapped) >> offset) & 0xFF; if (bin > threshold_bin) { const auto pos = ::atomicAdd(&p_counter, 1); @@ -308,13 +341,13 @@ __device__ void fast_topk_profile( const auto idx = p_input_idx[r_idx][i]; const float raw = vortex_to_float(input[idx + row_start]); const float remapped = apply_transform(raw, mapping); - const auto offset = 24 - round * 8; + const auto offset = stage2_offset_start - round * 8; const auto bin = (convert_to_uint32(remapped) >> offset) & 0xFF; if (bin > threshold_bin) { const auto pos = ::atomicAdd(&p_counter, 1); index[pos] = idx; } else if (bin == threshold_bin) { - if (round == 3) { + if (round == stage2_max_rounds - 1) { const auto pos = ::atomicAdd(&p_last_remain, -1); if (pos > 0) { index[target_k - pos] = idx; @@ -413,11 +446,18 @@ void TopKProfileHistogram_Kernel( if (tx < RADIX) s_histogram[tx] = 0; __syncthreads(); + const bool use_dense_bucket = (mapping.mode == MAPPING_DENSE_MANT); if (nblk > 0) { const ScoreT* __restrict__ score_blk = score + start; for (int i = tx; i < nblk; i += BLOCK_SIZE) { const float raw = vortex_to_float(score_blk[i]); - const auto bin = compute_stage1_bin(raw, mapping, s_mapping_lut, s_mapping_quantiles); + int bin; + if (use_dense_bucket) { + const float clamped = apply_transform(raw, mapping); + bin = static_cast(convert_to_uint8_dense(clamped)); + } else { + bin = 
static_cast(compute_stage1_bin(raw, mapping, s_mapping_lut, s_mapping_quantiles)); + } ::atomicAdd(&s_histogram[bin], 1); } } diff --git a/examples/remap_function_bench_topk2028.sh b/examples/remap_function_bench_topk2028.sh new file mode 100755 index 0000000..6d95b59 --- /dev/null +++ b/examples/remap_function_bench_topk2028.sh @@ -0,0 +1,252 @@ +#!/usr/bin/env bash +# ============================================================ +# Remap Function Benchmark +# +# Compares four kernel configurations for TopK page selection: +# 1. baseline — unmapped topk (topk_output_sglang) +# 2. fused remap + topk — topk_output_sglang_fused +# 3. remap only — topk_remap_only (standalone kernel) +# 4. unmapped topk on remapped — topk_output_sglang on the output +# buffer of step 3 +# +# Per configuration the script also reports the threshold-bin +# position, the threshold-bin size, and how many values are +# selected from the threshold bin (derived from +# topk_profile_counters — collected after all timing measurements, +# never interleaved with latency measurements). +# +# Pipeline: +# 1. Calibrate — run `calibrate_topk.py` on the chosen model to +# collect the REAL per-segment topk distribution +# (raw_histograms.npy). Skippable via +# --real-histograms /path/to/raw_histograms.npy. +# 2. Autotune — run `autotune_topk_mapping.py` on those real +# histograms and pick the per-mode hyperparameter +# with the LOWEST measured topk kernel latency. +# 3. Remap bench— run `bench_topk.py --remap-bench` with the +# autotune-selected per-mode hyperparameters. +# +# Argument layout mirrors run_distribution_analysis_new.sh. 
+# +# Usage: +# # Default (Qwen/Qwen3-1.7B, block_size=16): +# bash remap_function_bench.sh --gpu 5 +# +# # Larger model + larger page/block size: +# bash remap_function_bench.sh --gpu 0 \ +# --model-name Qwen/Qwen3-8B \ +# --block-size 32 \ +# --seq-len 16384 --topk-val 512 \ +# --modes "0 3 6 7" +# +# # Reuse an existing calibration: +# bash remap_function_bench.sh --gpu 0 \ +# --model-name Qwen/Qwen3-8B \ +# --real-histograms /path/to/calibration/raw_histograms.npy +# ============================================================ +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +BENCH_DIR="${SCRIPT_DIR}/../benchmarks" + +# ── Defaults ────────────────────────────────────────────────── +GPU_ID=5 +MODEL_NAME="Qwen/Qwen3-8B" +TOPK_VAL=2048 +MEM=0.7 +# Cap KV / VTX sparse prefill buffer sizing during Step 1 (see calibrate_topk.py --help). +MAX_TOTAL_TOKENS=64768 +ALGO="block_sparse_attention" +SAMPLE_STRIDE=1 +SEQ_LEN=32768 +BLOCK_SIZE=8 +BATCH_SIZE=4 +NUM_KV_HEADS=8 +DISTRIBUTIONS="normal bucket_uniform" +# Modes 1 (LUT_CDF) and 2 (Quantile) are no longer benchmarked — their +# mapping happens inside compute_stage1_bin, not apply_transform, so +# split-phase timing isn't meaningful for them. +MAPPING_MODES="0 3 6 7 8 9 10 11 13 15 16 17 18 19 20" +# Fallback hparam used only if autotune is explicitly skipped. +MAPPING_HPARAM=0.5 +REPEAT=100 +WARMUP=20 +# Empty by default — Step 1 will calibrate on the selected model. +# Pass --real-histograms /path/to/raw_histograms.npy to skip calibration. 
+# REAL_HISTOGRAMS="/home/zhuominc/xinrui_projects/vortex_torch/examples/calibration/raw_histograms.npy" +#REAL_HISTOGRAMS="/home/zhuominc/xinrui_projects/vortex_torch/examples/calibration/raw_histograms_qwen3-4B.npy" +REAL_HISTOGRAMS="" +SKIP_AUTOTUNE=0 + +# ── Parse arguments ─────────────────────────────────────────── +while [[ $# -gt 0 ]]; do + case "$1" in + --model-name) MODEL_NAME="$2"; shift 2 ;; + --topk-val) TOPK_VAL="$2"; shift 2 ;; + --mem) MEM="$2"; shift 2 ;; + --max-total-tokens) MAX_TOTAL_TOKENS="$2"; shift 2 ;; + --gpu) GPU_ID="$2"; shift 2 ;; + --algo) ALGO="$2"; shift 2 ;; + --real-histograms) REAL_HISTOGRAMS="$2"; shift 2 ;; + --sample-stride) SAMPLE_STRIDE="$2"; shift 2 ;; + --seq-len) SEQ_LEN="$2"; shift 2 ;; + --block-size|--page-size) BLOCK_SIZE="$2"; shift 2 ;; + --batch-size) BATCH_SIZE="$2"; shift 2 ;; + --num-kv-heads) NUM_KV_HEADS="$2"; shift 2 ;; + --distributions) DISTRIBUTIONS="$2"; shift 2 ;; + --modes) MAPPING_MODES="$2"; shift 2 ;; + --mapping-hparam) MAPPING_HPARAM="$2"; shift 2 ;; + --repeat) REPEAT="$2"; shift 2 ;; + --warmup) WARMUP="$2"; shift 2 ;; + --skip-autotune) SKIP_AUTOTUNE=1; shift 1 ;; + *) echo "Unknown option: $1"; exit 1 ;; + esac +done + +export CUDA_VISIBLE_DEVICES="${GPU_ID}" +export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}" + +# Qwen3-1.7B does not use DeepGEMM (no FP8/MoE path). +# Disable its JIT to silence "NVCC Compiler not found ... use NVRTC" on Blackwell. +export SGL_ENABLE_JIT_DEEPGEMM="${SGL_ENABLE_JIT_DEEPGEMM:-true}" + +# If DeepGEMM JIT is ever re-enabled, make sure it can find nvcc. 
+if [ -z "${DG_JIT_NVCC_COMPILER:-}" ]; then + if [ -x /usr/local/cuda/bin/nvcc ]; then + export CUDA_HOME="${CUDA_HOME:-/usr/local/cuda}" + export PATH="${CUDA_HOME}/bin:${PATH}" + export DG_JIT_NVCC_COMPILER="${CUDA_HOME}/bin/nvcc" + elif command -v nvcc >/dev/null 2>&1; then + export DG_JIT_NVCC_COMPILER="$(command -v nvcc)" + fi +fi + +# Validate seq_len: need pages/seg > topk_val (3 reserved pages) +MIN_SEQ_LEN=$(( (TOPK_VAL + 4) * BLOCK_SIZE )) +if [ "${SEQ_LEN}" -lt "${MIN_SEQ_LEN}" ]; then + echo "ERROR: --seq-len ${SEQ_LEN} too small for --topk-val ${TOPK_VAL} @ --block-size ${BLOCK_SIZE}." + echo " Minimum: ${MIN_SEQ_LEN} (pages/seg must exceed topk_val + 3 reserved pages)" + exit 1 +fi + +RESULTS_DIR="${SCRIPT_DIR}/results" +mkdir -p "${RESULTS_DIR}" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) +MODEL_SLUG="$(echo "${MODEL_NAME}" | tr '/' '_')" +RUN_DIR="${RESULTS_DIR}/remap_bench_${MODEL_SLUG}_topk${TOPK_VAL}_bs${BLOCK_SIZE}_${TIMESTAMP}" +mkdir -p "${RUN_DIR}" + +echo "============================================================" +echo "Remap Function Benchmark" +echo " Model: ${MODEL_NAME}" +echo " Algorithm: ${ALGO}" +echo " TopK: ${TOPK_VAL}" +echo " Block size: ${BLOCK_SIZE}" +echo " Seq len: ${SEQ_LEN} ($(( SEQ_LEN / BLOCK_SIZE )) pages/seg)" +echo " Batch size: ${BATCH_SIZE}" +echo " KV heads: ${NUM_KV_HEADS}" +echo " Distributions: ${DISTRIBUTIONS}" +echo " Mapping modes: ${MAPPING_MODES}" +echo " Fallback hparam: ${MAPPING_HPARAM} (used only when --skip-autotune)" +echo " Max total tokens: ${MAX_TOTAL_TOKENS} (calibration KV / VTX buffer cap)" +echo " GPU: ${GPU_ID}" +echo " Sample stride: ${SAMPLE_STRIDE}" +echo " Real histograms: ${REAL_HISTOGRAMS:-}" +echo " Output: ${RUN_DIR}" +echo "============================================================" + +# ── Step 1: Calibrate — collect real-distribution topk histograms ── +# calibrate_topk.py runs the model end-to-end with histogram profiling +# enabled and writes per-segment raw_histograms.npy. 
The histograms are +# aggregated over every layer and every decode/prefill step so the +# autotune in Step 2 sees the true attention-score distribution. +if [ -n "${REAL_HISTOGRAMS}" ]; then + echo "" + echo ">>> Step 1: SKIPPED (using provided --real-histograms ${REAL_HISTOGRAMS})" + REAL_HIST_PATH="${REAL_HISTOGRAMS}" +else + echo "" + echo ">>> Step 1: Calibrating ${MODEL_NAME} — collecting real topk histograms" + CALIBRATION_DIR="${RUN_DIR}/calibration" + mkdir -p "${CALIBRATION_DIR}" + python "${BENCH_DIR}/calibrate_topk.py" \ + --model-name "${MODEL_NAME}" \ + --topk-val "${TOPK_VAL}" \ + --page-size "${BLOCK_SIZE}" \ + --mem "${MEM}" \ + --max-total-tokens "${MAX_TOTAL_TOKENS}" \ + --vortex-module-name "${ALGO}" \ + --output-dir "${CALIBRATION_DIR}" \ + 2>&1 | tee "${RUN_DIR}/step1_calibrate.log" + REAL_HIST_PATH="${CALIBRATION_DIR}/raw_histograms.npy" + echo ">>> Step 1: Done. Calibration saved to ${CALIBRATION_DIR}" +fi + +# Modes 1 (LUT_CDF) and 2 (Quantile) are dropped from the comparison, so +# lut.npy / quantiles.npy produced by calibration are no longer consumed. + +# ── Step 2: Auto-tune hyperparameters by profiled fused-topk latency ── +# For every (mode, hparam) combo in the sweep grid, the autotune runs the +# fused remap+topk kernel on the real histogram and measures end-to-end +# kernel latency with CUDA events. The per-mode hparam with the lowest +# measured topk kernel latency wins. +AUTOTUNE_JSON="${RUN_DIR}/autotune_results.json" +if [ "${SKIP_AUTOTUNE}" -eq 1 ]; then + echo "" + echo ">>> Step 2: SKIPPED (using fallback --mapping-hparam ${MAPPING_HPARAM})" + AUTOTUNE_ARGS="" +else + echo "" + echo ">>> Step 2: Auto-tuning hyperparameters by profiled topk kernel latency" + PYTHONPATH="${SCRIPT_DIR}/.." 
python "${BENCH_DIR}/autotune_topk_mapping.py" \ + --batch-size "${BATCH_SIZE}" \ + --num-kv-heads "${NUM_KV_HEADS}" \ + --seq-len "${SEQ_LEN}" \ + --topk-val "${TOPK_VAL}" \ + --page-size "${BLOCK_SIZE}" \ + --real-histograms "${REAL_HIST_PATH}" \ + --warmup "${WARMUP}" \ + --repeat "${REPEAT}" \ + --collect-stats \ + --output-json "${AUTOTUNE_JSON}" \ + 2>&1 | tee "${RUN_DIR}/step2_autotune.log" + echo ">>> Step 2: Done. Autotune results saved to ${AUTOTUNE_JSON}" + AUTOTUNE_ARGS="--autotune-json ${AUTOTUNE_JSON}" +fi + +# ── Step 3: Remap benchmark (baseline / fused / remap / split) ── +echo "" +echo ">>> Step 3: Timing remap / topk / fused / baseline with autotuned hparams" +REMAP_JSON="${RUN_DIR}/remap_bench.json" +BENCH_EXTRA=() +[ -n "${REAL_HIST_PATH}" ] && BENCH_EXTRA+=(--real-histograms "${REAL_HIST_PATH}") +PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/bench_topk.py" \ + --remap-bench \ + --batch-sizes "${BATCH_SIZE}" \ + --num-kv-heads "${NUM_KV_HEADS}" \ + --seq-lens "${SEQ_LEN}" \ + --topk-vals "${TOPK_VAL}" \ + --page-size "${BLOCK_SIZE}" \ + --distributions ${DISTRIBUTIONS} \ + --mapping-modes ${MAPPING_MODES} \ + --mapping-hparam "${MAPPING_HPARAM}" \ + ${AUTOTUNE_ARGS} \ + "${BENCH_EXTRA[@]}" \ + --warmup "${WARMUP}" \ + --repeat "${REPEAT}" \ + --output-json "${REMAP_JSON}" \ + 2>&1 | tee "${RUN_DIR}/step3_remap_bench.log" +echo ">>> Step 3: Done. 
Remap bench saved to ${REMAP_JSON}" + +# ── Summary ─────────────────────────────────────────────────── +echo "" +echo "============================================================" +echo "Remap Function Benchmark Complete" +echo " Model: ${MODEL_NAME}" +echo " Block size: ${BLOCK_SIZE}" +echo " All outputs in: ${RUN_DIR}/" +echo " calibration/raw_histograms.npy — real topk distribution (per layer)" +echo " autotune_results.json — latency-ranked mapping hparams" +echo " remap_bench.json — per-config remap/topk/fused/baseline latencies" +echo " step{1,2,3}_*.log — pipeline logs" +echo "============================================================" diff --git a/examples/remap_function_bench.sh b/examples/remap_function_bench_topk30.sh similarity index 84% rename from examples/remap_function_bench.sh rename to examples/remap_function_bench_topk30.sh index 7d56d57..3cb52e2 100755 --- a/examples/remap_function_bench.sh +++ b/examples/remap_function_bench_topk30.sh @@ -43,6 +43,8 @@ # bash remap_function_bench.sh --gpu 0 \ # --model-name Qwen/Qwen3-8B \ # --real-histograms /path/to/calibration/raw_histograms.npy +# # Tight GPU: lower calibration KV cap (default 1048576): +# bash remap_function_bench_topk30.sh --gpu 0 --max-total-tokens 524288 # ============================================================ set -euo pipefail @@ -50,27 +52,29 @@ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" BENCH_DIR="${SCRIPT_DIR}/../benchmarks" # ── Defaults ────────────────────────────────────────────────── -GPU_ID=4 +GPU_ID=5 MODEL_NAME="Qwen/Qwen3-1.7B" -TOPK_VAL=2048 +TOPK_VAL=30 MEM=0.7 +MAX_TOTAL_TOKENS=1048576 ALGO="block_sparse_attention" SAMPLE_STRIDE=1 -SEQ_LEN=65536 +SEQ_LEN=32768 BLOCK_SIZE=16 -BATCH_SIZE=4 +BATCH_SIZE=1 NUM_KV_HEADS=8 DISTRIBUTIONS="normal bucket_uniform" -# Modes 1 (LUT_CDF) and 2 (Quantile) are evaluated only if calibration -# produces lut.npy / quantiles.npy. The shell script detects that below. 
-MAPPING_MODES="0 1 2 3 6 7 8 9 10 11 13" +# Modes 1 (LUT_CDF) and 2 (Quantile) are no longer benchmarked — their +# mapping happens inside compute_stage1_bin, not apply_transform, so +# split-phase timing isn't meaningful for them. +MAPPING_MODES="0 3 6 7 8 9 10 11 13 15 16 17 18 19 20" # Fallback hparam used only if autotune is explicitly skipped. MAPPING_HPARAM=0.5 REPEAT=100 WARMUP=20 # Empty by default — Step 1 will calibrate on the selected model. # Pass --real-histograms /path/to/raw_histograms.npy to skip calibration. -REAL_HISTOGRAMS="" +REAL_HISTOGRAMS="/home/zhuominc/xinrui_projects/vortex_torch/examples/calibration/raw_histograms.npy" SKIP_AUTOTUNE=0 # ── Parse arguments ─────────────────────────────────────────── @@ -79,6 +83,7 @@ while [[ $# -gt 0 ]]; do --model-name) MODEL_NAME="$2"; shift 2 ;; --topk-val) TOPK_VAL="$2"; shift 2 ;; --mem) MEM="$2"; shift 2 ;; + --max-total-tokens) MAX_TOTAL_TOKENS="$2"; shift 2 ;; --gpu) GPU_ID="$2"; shift 2 ;; --algo) ALGO="$2"; shift 2 ;; --real-histograms) REAL_HISTOGRAMS="$2"; shift 2 ;; @@ -98,6 +103,22 @@ while [[ $# -gt 0 ]]; do done export CUDA_VISIBLE_DEVICES="${GPU_ID}" +export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}" + +# Qwen3-1.7B does not use DeepGEMM (no FP8/MoE path). +# Disable its JIT to silence "NVCC Compiler not found ... use NVRTC" on Blackwell. +export SGL_ENABLE_JIT_DEEPGEMM="${SGL_ENABLE_JIT_DEEPGEMM:-true}" + +# If DeepGEMM JIT is ever re-enabled, make sure it can find nvcc. 
+if [ -z "${DG_JIT_NVCC_COMPILER:-}" ]; then + if [ -x /usr/local/cuda/bin/nvcc ]; then + export CUDA_HOME="${CUDA_HOME:-/usr/local/cuda}" + export PATH="${CUDA_HOME}/bin:${PATH}" + export DG_JIT_NVCC_COMPILER="${CUDA_HOME}/bin/nvcc" + elif command -v nvcc >/dev/null 2>&1; then + export DG_JIT_NVCC_COMPILER="$(command -v nvcc)" + fi +fi # Validate seq_len: need pages/seg > topk_val (3 reserved pages) MIN_SEQ_LEN=$(( (TOPK_VAL + 4) * BLOCK_SIZE )) @@ -126,6 +147,7 @@ echo " KV heads: ${NUM_KV_HEADS}" echo " Distributions: ${DISTRIBUTIONS}" echo " Mapping modes: ${MAPPING_MODES}" echo " Fallback hparam: ${MAPPING_HPARAM} (used only when --skip-autotune)" +echo " Max total tokens: ${MAX_TOTAL_TOKENS} (calibration KV / VTX buffer cap)" echo " GPU: ${GPU_ID}" echo " Sample stride: ${SAMPLE_STRIDE}" echo " Real histograms: ${REAL_HISTOGRAMS:-}" @@ -149,7 +171,9 @@ else python "${BENCH_DIR}/calibrate_topk.py" \ --model-name "${MODEL_NAME}" \ --topk-val "${TOPK_VAL}" \ + --page-size "${BLOCK_SIZE}" \ --mem "${MEM}" \ + --max-total-tokens "${MAX_TOTAL_TOKENS}" \ --vortex-module-name "${ALGO}" \ --output-dir "${CALIBRATION_DIR}" \ 2>&1 | tee "${RUN_DIR}/step1_calibrate.log" @@ -157,14 +181,8 @@ else echo ">>> Step 1: Done. Calibration saved to ${CALIBRATION_DIR}" fi -# Calibration may have produced lut.npy / quantiles.npy for modes 1 and 2. -CALIB_DIR="$(dirname "${REAL_HIST_PATH}")" -LUT_PATH="" -Q_PATH="" -[ -f "${CALIB_DIR}/lut.npy" ] && LUT_PATH="${CALIB_DIR}/lut.npy" -[ -f "${CALIB_DIR}/quantiles.npy" ] && Q_PATH="${CALIB_DIR}/quantiles.npy" -[ -n "${LUT_PATH}" ] && echo " Calibration LUT: ${LUT_PATH}" -[ -n "${Q_PATH}" ] && echo " Calibration quantile: ${Q_PATH}" +# Modes 1 (LUT_CDF) and 2 (Quantile) are dropped from the comparison, so +# lut.npy / quantiles.npy produced by calibration are no longer consumed. 
# ── Step 2: Auto-tune hyperparameters by profiled fused-topk latency ── # For every (mode, hparam) combo in the sweep grid, the autotune runs the @@ -179,9 +197,6 @@ if [ "${SKIP_AUTOTUNE}" -eq 1 ]; then else echo "" echo ">>> Step 2: Auto-tuning hyperparameters by profiled topk kernel latency" - AUTOTUNE_EXTRA=() - [ -n "${LUT_PATH}" ] && AUTOTUNE_EXTRA+=(--lut-path "${LUT_PATH}") - [ -n "${Q_PATH}" ] && AUTOTUNE_EXTRA+=(--quantiles-path "${Q_PATH}") PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/autotune_topk_mapping.py" \ --batch-size "${BATCH_SIZE}" \ --num-kv-heads "${NUM_KV_HEADS}" \ @@ -192,7 +207,6 @@ else --warmup "${WARMUP}" \ --repeat "${REPEAT}" \ --collect-stats \ - "${AUTOTUNE_EXTRA[@]}" \ --output-json "${AUTOTUNE_JSON}" \ 2>&1 | tee "${RUN_DIR}/step2_autotune.log" echo ">>> Step 2: Done. Autotune results saved to ${AUTOTUNE_JSON}" @@ -204,8 +218,7 @@ echo "" echo ">>> Step 3: Timing remap / topk / fused / baseline with autotuned hparams" REMAP_JSON="${RUN_DIR}/remap_bench.json" BENCH_EXTRA=() -[ -n "${LUT_PATH}" ] && BENCH_EXTRA+=(--lut-path "${LUT_PATH}") -[ -n "${Q_PATH}" ] && BENCH_EXTRA+=(--quantiles-path "${Q_PATH}") +[ -n "${REAL_HIST_PATH}" ] && BENCH_EXTRA+=(--real-histograms "${REAL_HIST_PATH}") PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/bench_topk.py" \ --remap-bench \ --batch-sizes "${BATCH_SIZE}" \ diff --git a/examples/run_distribution_analysis.sh b/examples/run_distribution_analysis.sh index 36d4cd4..2515015 100755 --- a/examples/run_distribution_analysis.sh +++ b/examples/run_distribution_analysis.sh @@ -22,6 +22,7 @@ # --real-histograms /path/to/calibration_dir/raw_histograms.npy # bash run_distribution_analysis.sh --gpu 5 --block-size 16 # bash run_distribution_analysis.sh --watchdog-timeout 0 # disable calibrate watchdog (fork) +# bash run_distribution_analysis.sh --max-total-tokens 1048576 # cap KV / VTX buffers during calibrate # Models (default: 1.7B + 4B). 
Override with repeated --model-name: # bash run_distribution_analysis.sh --model-name Qwen/Qwen3-1.7B --model-name Qwen/Qwen3-4B # ============================================================ @@ -52,6 +53,7 @@ MODEL_NAMES=( "Qwen/Qwen3-1.7B" "Qwen/Qwen3-4B" ) MODEL_NAMES_USER_SET=0 TOPK_VAL=30 MEM=0.7 +MAX_TOTAL_TOKENS=1048576 ALGO="block_sparse_attention" RADIX_BITS=8 SAMPLE_STRIDE=1 @@ -76,6 +78,7 @@ while [[ $# -gt 0 ]]; do ;; --topk-val) TOPK_VAL="$2"; shift 2 ;; --mem) MEM="$2"; shift 2 ;; + --max-total-tokens) MAX_TOTAL_TOKENS="$2"; shift 2 ;; --gpu) GPU_ID="$2"; shift 2 ;; --algo) ALGO="$2"; shift 2 ;; --real-histograms) REAL_HISTOGRAMS="$2"; shift 2 ;; @@ -117,6 +120,7 @@ echo " Block size: ${BLOCK_SIZE} (--page-size in benchmarks)" echo " GPU: ${GPU_ID}" echo " Radix bits: ${RADIX_BITS} ($(( 1 << RADIX_BITS )) bins)" echo " Sample stride: ${SAMPLE_STRIDE}" +echo " Max total tokens: ${MAX_TOTAL_TOKENS} (calibration KV / VTX buffer cap)" if [ "${HAS_WATCHDOG_TIMEOUT}" -eq 1 ]; then echo " Watchdog (cal): ${WATCHDOG_TIMEOUT}s (0 = off, vortex SGLang fork)" else @@ -154,6 +158,7 @@ for MODEL_NAME in "${MODEL_NAMES[@]}"; do --model-name "${MODEL_NAME}" \ --topk-val "${TOPK_VAL}" \ --mem "${MEM}" \ + --max-total-tokens "${MAX_TOTAL_TOKENS}" \ --vortex-module-name "${ALGO}" \ --page-size "${BLOCK_SIZE}" \ --output-dir "${CALIBRATION_DIR}" \ diff --git a/examples/run_distribution_analysis_new.sh b/examples/run_distribution_analysis_new.sh index ec72665..38438bd 100755 --- a/examples/run_distribution_analysis_new.sh +++ b/examples/run_distribution_analysis_new.sh @@ -26,6 +26,7 @@ # --model-name Qwen/Qwen3-8B --block-size 32 # bash run_distribution_analysis_new.sh --gpu 5 \ # --real-histograms /path/to/raw_histograms.npy +# bash run_distribution_analysis_new.sh --gpu 5 --max-total-tokens 524288 # ============================================================ set -euo pipefail @@ -37,6 +38,7 @@ GPU_ID=4 MODEL_NAME="Qwen/Qwen3-1.7B" TOPK_VAL=2048 MEM=0.7 
+MAX_TOTAL_TOKENS=1048576 ALGO="block_sparse_attention" SEQ_LEN=65536 BLOCK_SIZE=16 @@ -55,6 +57,7 @@ while [[ $# -gt 0 ]]; do --model-name) MODEL_NAME="$2"; shift 2 ;; --topk-val) TOPK_VAL="$2"; shift 2 ;; --mem) MEM="$2"; shift 2 ;; + --max-total-tokens) MAX_TOTAL_TOKENS="$2"; shift 2 ;; --gpu) GPU_ID="$2"; shift 2 ;; --algo) ALGO="$2"; shift 2 ;; --real-histograms) REAL_HISTOGRAMS="$2"; shift 2 ;; @@ -97,6 +100,7 @@ echo " Batch size: ${BATCH_SIZE}" echo " KV heads: ${NUM_KV_HEADS}" echo " Distributions: ${DISTRIBUTIONS}" echo " Mapping modes: ${MAPPING_MODES}" +echo " Max total tokens: ${MAX_TOTAL_TOKENS} (calibration KV / VTX buffer cap)" echo " GPU: ${GPU_ID}" echo " Real histograms: ${REAL_HISTOGRAMS:-}" echo " Output: ${RUN_DIR}" @@ -116,6 +120,7 @@ else --model-name "${MODEL_NAME}" \ --topk-val "${TOPK_VAL}" \ --mem "${MEM}" \ + --max-total-tokens "${MAX_TOTAL_TOKENS}" \ --vortex-module-name "${ALGO}" \ --page-size "${BLOCK_SIZE}" \ --output-dir "${CALIBRATION_DIR}" \ diff --git a/examples/run_topk_benchmark.sh b/examples/run_topk_benchmark.sh index 33c6e40..f3eabff 100755 --- a/examples/run_topk_benchmark.sh +++ b/examples/run_topk_benchmark.sh @@ -21,6 +21,7 @@ # bash run_topk_benchmark.sh --gpu 0 # bash run_topk_benchmark.sh --gpu 0 --model-name Qwen/Qwen3-8B \ # --block-size 32 --topk-val 512 +# bash run_topk_benchmark.sh --gpu 0 --max-total-tokens 1048576 # ============================================================ set -euo pipefail @@ -33,6 +34,7 @@ MODEL_NAME="Qwen/Qwen3-1.7B" TOPK_VAL=30 TRIALS=8 MEM=0.7 +MAX_TOTAL_TOKENS=1048576 ALGO="block_sparse_attention" BLOCK_SIZE=16 BATCH_SIZE=4 @@ -50,6 +52,7 @@ while [[ $# -gt 0 ]]; do --topk-val) TOPK_VAL="$2"; shift 2 ;; --trials) TRIALS="$2"; shift 2 ;; --mem) MEM="$2"; shift 2 ;; + --max-total-tokens) MAX_TOTAL_TOKENS="$2"; shift 2 ;; --gpu) GPU_ID="$2"; shift 2 ;; --algo) ALGO="$2"; shift 2 ;; --benchmark) BENCHMARKS="$2"; shift 2 ;; @@ -84,6 +87,7 @@ echo " Seq len: ${SEQ_LEN}" echo " Batch size: 
${BATCH_SIZE}" echo " KV heads: ${NUM_KV_HEADS}" echo " Trials: ${TRIALS}" +echo " Max total tokens: ${MAX_TOTAL_TOKENS} (calibration KV / VTX buffer cap)" echo " GPU: ${GPU_ID}" echo " Output: ${RUN_DIR}" echo "============================================================" @@ -101,6 +105,7 @@ else --model-name "${MODEL_NAME}" \ --topk-val "${TOPK_VAL}" \ --mem "${MEM}" \ + --max-total-tokens "${MAX_TOTAL_TOKENS}" \ --vortex-module-name "${ALGO}" \ --page-size "${BLOCK_SIZE}" \ --output-dir "${CALIBRATION_DIR}" \ diff --git a/examples/verify_algo_topk_mapping.sh b/examples/verify_algo_topk_mapping.sh index 711a0f7..f361e59 100644 --- a/examples/verify_algo_topk_mapping.sh +++ b/examples/verify_algo_topk_mapping.sh @@ -26,6 +26,7 @@ BLOCK_SIZE=16 BATCH_SIZE=4 NUM_KV_HEADS=2 SEQ_LEN=32768 +MAX_TOTAL_TOKENS=1048576 REAL_HISTOGRAMS="" SKIP_AUTOTUNE=0 @@ -40,6 +41,7 @@ while [[ $# -gt 0 ]]; do --num-kv-heads) NUM_KV_HEADS="$2"; shift 2 ;; --seq-len) SEQ_LEN="$2"; shift 2 ;; --real-histograms) REAL_HISTOGRAMS="$2"; shift 2 ;; + --max-total-tokens) MAX_TOTAL_TOKENS="$2"; shift 2 ;; --skip-autotune) SKIP_AUTOTUNE=1; shift 1 ;; *) echo "Unknown option: $1"; exit 1 ;; esac @@ -80,12 +82,14 @@ done # ============================================================ if [ -z "${REAL_HISTOGRAMS}" ]; then CALIBRATION_DIR="${RESULTS_DIR}/calibration_${TIMESTAMP}" + echo ">>> Max total tokens (KV / VTX cap): ${MAX_TOTAL_TOKENS}" for algo in "${sparse_algos[@]}"; do echo ">>> Calibrating ${MODEL_NAME} for ${algo}..." 
python "${BENCH_DIR}/calibrate_topk.py" \ --model-name "${MODEL_NAME}" \ --topk-val "${TOPK_VAL}" \ --mem 0.7 \ + --max-total-tokens "${MAX_TOTAL_TOKENS}" \ --vortex-module-name "${algo}" \ --page-size "${BLOCK_SIZE}" \ --output-dir "${CALIBRATION_DIR}" \ diff --git a/examples/verify_algo_topk_mapping_new.sh b/examples/verify_algo_topk_mapping_new.sh index 9116b72..2cdc526 100644 --- a/examples/verify_algo_topk_mapping_new.sh +++ b/examples/verify_algo_topk_mapping_new.sh @@ -28,6 +28,7 @@ BLOCK_SIZE=16 BATCH_SIZE=4 NUM_KV_HEADS=2 SEQ_LEN=32768 +MAX_TOTAL_TOKENS=1048576 REAL_HISTOGRAMS="" SKIP_AUTOTUNE=0 @@ -42,6 +43,7 @@ while [[ $# -gt 0 ]]; do --num-kv-heads) NUM_KV_HEADS="$2"; shift 2 ;; --seq-len) SEQ_LEN="$2"; shift 2 ;; --real-histograms) REAL_HISTOGRAMS="$2"; shift 2 ;; + --max-total-tokens) MAX_TOTAL_TOKENS="$2"; shift 2 ;; --skip-autotune) SKIP_AUTOTUNE=1; shift 1 ;; *) echo "Unknown option: $1"; exit 1 ;; esac @@ -63,6 +65,7 @@ TIMESTAMP=$(date +%Y%m%d_%H%M%S) if [ -z "${REAL_HISTOGRAMS}" ]; then echo "============================================================" echo "Step 0: Calibrating ${MODEL_NAME} for real-distribution histograms" + echo " Max total tokens (KV / VTX cap): ${MAX_TOTAL_TOKENS}" echo "============================================================" CAL_DIR="${RESULTS_DIR}/calibration_${TIMESTAMP}" mkdir -p "${CAL_DIR}" @@ -70,6 +73,7 @@ if [ -z "${REAL_HISTOGRAMS}" ]; then --model-name "${MODEL_NAME}" \ --topk-val "${TOPK_VAL}" \ --mem 0.7 \ + --max-total-tokens "${MAX_TOTAL_TOKENS}" \ --vortex-module-name "${sparse_algos[0]}" \ --page-size "${BLOCK_SIZE}" \ --output-dir "${CAL_DIR}" \ diff --git a/setup.py b/setup.py index 0fc46ad..c973181 100644 --- a/setup.py +++ b/setup.py @@ -19,6 +19,7 @@ 'csrc/topk.cu', 'csrc/topk_sglang.cu', 'csrc/topk_sglang_profile.cu', + 'csrc/topk_sglang_ori.cu', ], include_dirs=['csrc'], extra_compile_args={ From fe0b8e2881ba39bb93c36ab043a27af59d078ebc Mon Sep 17 00:00:00 2001 From: UED Date: Wed, 15 Apr 
2026 00:50:12 -0400 Subject: [PATCH 21/22] Enhance TopK benchmarking and calibration scripts - Added parameter to for per-head benchmarking. - Introduced function to aggregate per-head configurations. - Updated to include metrics for per-head configurations. - Added disk space check in to ensure sufficient space for model downloads. - Implemented regression guard against saving degenerate histograms in calibration. - Modified example scripts for improved calibration and benchmarking workflows. --- benchmarks/bench_topk.py | 180 ++++++++++++++++++++-- benchmarks/calibrate_topk.py | 57 ++++++- csrc/topk_sglang.cu | 52 +++++-- csrc/utils_sglang.cu | 30 ++-- examples/remap_function_bench_topk2028.sh | 56 +++++-- examples/remap_function_bench_topk30.sh | 30 +++- third_party/sglang | 2 +- vortex_torch/indexer/utils_sglang.py | 2 +- 8 files changed, 353 insertions(+), 56 deletions(-) diff --git a/benchmarks/bench_topk.py b/benchmarks/bench_topk.py index f0860c9..3653c55 100644 --- a/benchmarks/bench_topk.py +++ b/benchmarks/bench_topk.py @@ -397,8 +397,15 @@ def _resolve_hparam(args, mode: int) -> float: def _remap_bench_one_config(args, batch_size, num_kv_heads, seq_len, topk_val, - distribution, modes: List[int]) -> dict: - """Time baseline, fused, and split-phase for each mode at one config.""" + distribution, modes: List[int], + head_label: str = "all") -> dict: + """Time baseline, fused, and split-phase for each mode at one config. + + `head_label` is metadata: ``"all"`` for the aggregated table (default), + or a stringified head index ``"0".."N-1"`` for per-head benches. The + caller is responsible for setting ``args._real_histogram`` to the + head-sliced sub-histogram before invoking this function in per-head mode. 
+ """ real_hist = getattr(args, "_real_histogram", None) if distribution == "real" else None inputs = make_topk_inputs( batch_size=batch_size, @@ -489,6 +496,7 @@ def _remap_bench_one_config(args, batch_size, num_kv_heads, seq_len, topk_val, "topk_val": topk_val, "distribution": distribution, "pages_per_seg": pages_per_seg, + "head": head_label, "baseline_ms": baseline["mean_ms"], "naive_ms": naive_ms, "sglang_ori_ms": sglang_ori_ms, @@ -644,17 +652,27 @@ def _remap_bench_one_config(args, batch_size, num_kv_heads, seq_len, topk_val, return config +# Stage-2 working-set cap, matches SMEM_INPUT_SIZE in fast_topk_clean_fused +# (32 KB dynamic smem / 2 ping-pong buffers / 4 bytes per int = 4096). +_STAGE2_SMEM_CAP = 4096 + + def _print_remap_table(results: List[dict]) -> None: + # The printed table only carries metrics that participate in the + # fused-kernel cost model. All purely-informational columns + # (thr_bin / sel_thr / abv_bins / pg/bin) were dropped — they're + # still in the JSON for downstream tools, just not in the table. 
header = ( f"{'mode':<14s} {'remap_ms':>9s} {'topk_ms':>9s} {'split_ms':>9s} " - f"{'fused_ms':>9s} {'base_ms':>9s} {'thr_bin':>7s} {'thr_size':>8s} " - f"{'sel_thr':>7s} {'abv_bins':>8s} {'pg/bin':>7s}" + f"{'fused_ms':>9s} {'base_ms':>9s} " + f"{'s1p2_load':>9s} {'eff_thr':>7s} {'rounds':>6s} {'s2_work':>8s}" ) for cfg in results: banner = ( f"\n[batch={cfg['batch_size']} heads={cfg['num_kv_heads']} " f"seq_len={cfg['seq_len']} topk={cfg['topk_val']} " - f"dist={cfg['distribution']} pages_per_seg={cfg['pages_per_seg']}]" + f"dist={cfg['distribution']} pages_per_seg={cfg['pages_per_seg']} " + f"head={cfg.get('head', 'all')}]" ) print(banner) extra_notes = [] @@ -666,6 +684,12 @@ def _print_remap_table(results: List[dict]) -> None: if extra_notes: notes_str = " | " + " | ".join(extra_notes) print(f" Baseline: topk_sglang.cu (CUB two-stage){notes_str}") + print( + f" s1p2_load = thr_size (uncapped global re-reads in Stage-1 pass 2) " + f"eff_thr = min(thr_size, {_STAGE2_SMEM_CAP}) " + f"rounds = stage-2 passes (1..4) " + f"s2_work = rounds * eff_thr" + ) print(header) print("-" * len(header)) base_ms = cfg["baseline_ms"] @@ -681,6 +705,11 @@ def _print_remap_table(results: List[dict]) -> None: def _fmt(v): return f"{v:9.4f}" if v is not None else f"{'N/A':>9s}" fused_str = _fmt(row.get("fused_ms")) + thr_size = row.get("threshold_bin_size_mean", 0.0) + rounds = row.get("refine_rounds_mean", 0.0) + eff_thr = min(thr_size, float(_STAGE2_SMEM_CAP)) + s2_work = rounds * eff_thr + s1p2_load = thr_size # alias: same number, named for the cost-model role print( f"{label:<14s} " f"{_fmt(row['remap_ms'])} " @@ -688,14 +717,77 @@ def _fmt(v): f"{_fmt(row['split_total_ms'])} " f"{fused_str} " f"{base_ms:9.4f} " - f"{row['threshold_bin_mean']:7.1f} " - f"{row['threshold_bin_size_mean']:8.1f} " - f"{row['selected_from_thr_mean']:7.1f} " - f"{row.get('above_bins_mean', 0.0):8.1f} " - f"{row.get('pages_per_above_bin_mean', 0.0):7.1f}" + f"{s1p2_load:9.0f} " + f"{eff_thr:7.0f} " + 
f"{rounds:6.2f} " + f"{s2_work:8.0f}" ) +def _combine_per_head_cfgs(per_head_cfgs: List[dict]) -> dict: + """Combine a list of per-head cfg dicts (same shape, head='0','1',...) + into a single aggregated cfg tagged head='all', by averaging every + numeric field. This is used when --per-head-bench is on so the + aggregated row reflects the realistic per-head behaviour rather than + a separate kernel launch on an averaged histogram. + + Assumes every cfg has the same `modes` list in the same order — which + holds because all per-head sub-runs use identical (batch, heads, seq, + topk, page_size, reserved, mapping_modes) parameters and therefore + take the same code paths through `_remap_bench_one_config`. + """ + assert per_head_cfgs, "_combine_per_head_cfgs called with empty list" + base = per_head_cfgs[0] + n_modes = len(base["modes"]) + # Sanity: same shape. + for c in per_head_cfgs[1:]: + assert len(c["modes"]) == n_modes, ( + f"per-head cfgs disagree on mode count: {n_modes} vs {len(c['modes'])}" + ) + + def _mean_or_none(vals): + vs = [v for v in vals if v is not None] + return (sum(vs) / len(vs)) if vs else None + + combined: Dict = { + "batch_size": base["batch_size"], + "num_kv_heads": base["num_kv_heads"], + "seq_len": base["seq_len"], + "topk_val": base["topk_val"], + "distribution": base["distribution"], + "pages_per_seg": base["pages_per_seg"], + "head": "all", + "baseline_ms": _mean_or_none([c.get("baseline_ms") for c in per_head_cfgs]), + "naive_ms": _mean_or_none([c.get("naive_ms") for c in per_head_cfgs]), + "sglang_ori_ms": _mean_or_none([c.get("sglang_ori_ms") for c in per_head_cfgs]), + "modes": [], + } + + # Numeric fields per mode row that we average; non-numeric fields (mode, + # mode_name, power) are copied from the first cfg since they're identical + # across heads by construction. 
+ NUMERIC_KEYS = ( + "remap_ms", "topk_after_remap_ms", "split_total_ms", "fused_ms", + "threshold_bin_mean", "threshold_bin_max", + "num_above_mean", + "threshold_bin_size_mean", "threshold_bin_size_max", + "selected_from_thr_mean", "selected_from_thr_max", + "refine_rounds_mean", + "above_bins_mean", "pages_per_above_bin_mean", + ) + for mi in range(n_modes): + sample = base["modes"][mi] + merged = { + "mode": sample["mode"], + "mode_name": sample["mode_name"], + "power": sample["power"], + } + for key in NUMERIC_KEYS: + merged[key] = _mean_or_none([c["modes"][mi].get(key) for c in per_head_cfgs]) + combined["modes"].append(merged) + return combined + + def _run_remap_bench(args) -> None: modes = [int(m) for m in args.mapping_modes] # Mode 0 is emitted as the "None" row from _remap_bench_one_config @@ -710,14 +802,68 @@ def _run_remap_bench(args) -> None: print(f"[remap-bench] 'real' distribution enabled " f"(histogram total count = {int(args._real_histogram.sum())})") + if getattr(args, "per_head_bench", False): + if getattr(args, "_real_histograms_raw", None) is None: + raise SystemExit( + "[bench-remap] --per-head-bench requires --real-histograms with a 2D raw file." + ) + if not args.num_kv_heads or any(h <= 0 for h in args.num_kv_heads): + raise SystemExit("[bench-remap] --per-head-bench requires --num-kv-heads > 0.") + # When the user passes multiple --num-kv-heads values we slice by the + # first one (the others are degenerate for per-head reporting since + # the histogram file has a fixed head count). + per_head_count = int(args.num_kv_heads[0]) + results = [] + # When --per-head-bench is on, each "real"-distribution aggregate is + # built by averaging the 8 per-head measurements (NOT by running an + # extra kernel on an averaged histogram). This grouping keeps the + # per-head cfgs that should fold into each (bs, heads, seq, topk) + # aggregate point. 
+ per_head_groups: dict = {} + + # ---- Per-head tables (printed first) ---- + if getattr(args, "per_head_bench", False): + raw = args._real_histograms_raw + saved_agg = args._real_histogram + try: + for h in range(per_head_count): + # Slice rows belonging to head `h`. Rows are interleaved as + # row_idx % num_kv_heads = head_idx, so this strided slice + # collects all (call, batch, h) triples across the file. + args._real_histogram = raw[h::per_head_count].sum(axis=0) + for bs in args.batch_sizes: + for heads in args.num_kv_heads: + for seq_len in args.seq_lens: + for topk_val in args.topk_vals: + cfg = _remap_bench_one_config( + args, bs, heads, seq_len, topk_val, "real", modes, + head_label=str(h), + ) + results.append(cfg) + per_head_groups.setdefault( + (bs, heads, seq_len, topk_val), [] + ).append(cfg) + finally: + args._real_histogram = saved_agg + + # ---- Aggregated tables (printed last) ---- for bs in args.batch_sizes: for heads in args.num_kv_heads: for seq_len in args.seq_lens: for topk_val in args.topk_vals: for dist in distributions: + if dist == "real" and getattr(args, "per_head_bench", False): + cfgs = per_head_groups.get((bs, heads, seq_len, topk_val), []) + if cfgs: + # Combine the per-head cfgs into a single + # aggregated row — no extra kernel launch. 
+ cfg = _combine_per_head_cfgs(cfgs) + results.append(cfg) + continue cfg = _remap_bench_one_config( args, bs, heads, seq_len, topk_val, dist, modes, + head_label="all", ) results.append(cfg) @@ -838,6 +984,13 @@ def main(): p.add_argument("--output-json", type=str, default=None) p.add_argument("--remap-bench", action="store_true", help="Run the split-phase remap/topk/fused/baseline benchmark.") + p.add_argument("--per-head-bench", action="store_true", + help="In addition to the aggregated 'real'-distribution table, also " + "run the remap-bench once per KV head: slice the calibrated " + "histogram into one sub-histogram per head (using " + "row_idx %% num_kv_heads = head_idx), bench each, and print one " + "table per head followed by the aggregated table. Requires " + "--real-histograms (with a 2D raw file) and --num-kv-heads.") args = p.parse_args() args._autotune_hparams = {} @@ -848,9 +1001,14 @@ def main(): print(f" mode {m:>2d} -> {v}") args._real_histogram = None + args._real_histograms_raw = None if args.real_histograms: - raw = np.load(args.real_histograms) + # mmap_mode='r' keeps the (potentially 20+ GB) raw file off-heap; we + # only materialise per-head sums when --per-head-bench is set. 
+ raw = np.load(args.real_histograms, mmap_mode='r') args._real_histogram = raw.sum(axis=0) if raw.ndim > 1 else raw + if raw.ndim > 1: + args._real_histograms_raw = raw print(f"[real] loaded calibrated histogram from {args.real_histograms} " f"(shape={raw.shape} → [256] aggregate)") diff --git a/benchmarks/calibrate_topk.py b/benchmarks/calibrate_topk.py index 4914133..f3343aa 100644 --- a/benchmarks/calibrate_topk.py +++ b/benchmarks/calibrate_topk.py @@ -17,6 +17,7 @@ import argparse import json import os +import shutil import sys import numpy as np @@ -46,9 +47,17 @@ def main(): "Block-sparse profiling uses a small bytes/token estimate, so the auto " "budget can be huge on large GPUs; VTXGraphAttnBackend then allocates " "dense bf16 sparse_prefill K/V buffers proportional to this cap (~4 KiB per " - "token per buffer). For offline calibration, a few hundred K1M tokens " + "token per buffer). For offline calibration, a few hundred K to 1M tokens " "is usually enough.", ) + parser.add_argument( + "--min-free-disk-gb", + type=float, + default=20.0, + help="Abort if the filesystem for --output-dir (and HF cache, typically the same) " + "has less than this many GiB free. First-time model downloads need many GiB. " + "Set to 0 to disable.", + ) parser.add_argument("--kv-cache-dtype", type=str, default="auto") parser.add_argument("--topk-type", type=str, default="sglang") parser.add_argument("--num-prompts", type=int, default=16, @@ -65,6 +74,30 @@ def main(): ) args = parser.parse_args() + # Classic HTTP downloads avoid XET chunk reconstruction ("Background writer channel + # closed") that often surfaces when the disk is full or nearly full. 
+ if "HF_HUB_DISABLE_XET" not in os.environ: + os.environ["HF_HUB_DISABLE_XET"] = "1" + + if args.min_free_disk_gb > 0: + check_path = os.path.abspath(args.output_dir) + while check_path and not os.path.isdir(check_path): + parent = os.path.dirname(check_path) + if parent == check_path: + check_path = os.getcwd() + break + check_path = parent + usage = shutil.disk_usage(check_path) + free_gb = usage.free / (1024.0**3) + if free_gb < args.min_free_disk_gb: + raise SystemExit( + f"[calibrate] ERROR: Only {free_gb:.1f} GiB free on filesystem containing " + f"{args.output_dir!r} (checked from {check_path!r}). " + f"Need at least ~{args.min_free_disk_gb} GiB for Hugging Face weights, hub cache, " + f"and logs. Free disk space or point HF_HOME at a larger disk. " + f"To skip this check: --min-free-disk-gb 0" + ) + # Lazy imports to avoid slow startup when just checking --help import sglang as sgl import torch @@ -135,6 +168,28 @@ def main(): all_hists = torch.cat(histograms, dim=0).numpy() # [total_samples, 256] print(f"[calibrate] Total histogram samples: {all_hists.shape[0]}") + # Regression guard: refuse to save a collapsed histogram. A healthy + # calibration touches tens to hundreds of bins; if almost everything lands + # in a single bin, the scoring pipeline silently produced zero scores + # (see the Sgl_Decode_Plan_Workload_Kernel `w > topk_val` bug fixed in + # csrc/utils_sglang.cu). Saving 20+ GB of all-zeros wastes disk and poisons + # downstream benches, so fail loudly here. + _pooled = all_hists.sum(axis=0).astype(np.float64) + _total = float(_pooled.sum()) + if _total > 0: + _top_frac = float(_pooled.max()) / _total + _nz_bins = int((_pooled > 0).sum()) + if _top_frac > 0.95 or _nz_bins < 5: + llm.shutdown() + raise SystemExit( + f"[calibrate] ERROR: degenerate histogram — top bin holds " + f"{_top_frac:.2%} of mass, only {_nz_bins}/256 bins nonzero. 
" + f"The scoring pipeline is likely not running (check " + f"winfo_num_workloads in plan_decode, or `w > topk_val` in " + f"Sgl_Decode_Plan_Workload_Kernel). Refusing to save to avoid " + f"writing a useless multi-GB file." + ) + # --- Generate LUT (mode 1) --- # Aggregate histogram across all samples avg_histogram = all_hists.mean(axis=0) diff --git a/csrc/topk_sglang.cu b/csrc/topk_sglang.cu index fa9c825..73366df 100644 --- a/csrc/topk_sglang.cu +++ b/csrc/topk_sglang.cu @@ -50,6 +50,15 @@ constexpr size_t kSmem = 48 * 1024; // bytes constexpr size_t kSmem = 8 * 1024 * sizeof(uint32_t); // 32KB (bytes) #endif +// Fused-kernel dynamic smem ceiling. The fused kernel uses `kSmem` bytes for +// f_input_idx (2 × SMEM_INPUT_SIZE ints) AND an extra `max_num_pages` bytes +// for s_bins (one uint8_t per page). Ceiling of 96 KB covers max_num_pages up +// to 65536 and fits the opt-in dynamic-smem limits on every target in +// setup.py (sm_86 ≥99KB, sm_89 100KB, sm_90 228KB, sm_100a/120 ≥100KB). +// Only `topk_output_sglang_fused` uses this ceiling; the other kernels keep +// kSmem as their dynamic-smem budget. +constexpr size_t kFusedSmemMax = 96 * 1024; + struct FastTopKParams { const float* __restrict__ input; // [B, input_stride] const int32_t* __restrict__ row_starts; // [B] @@ -670,16 +679,23 @@ __device__ void fast_topk_clean_fused( // Per-element Stage-1 bin cache. Pass 1 of Stage 1 writes one byte per // element; pass 2 reads it back so each element only pays a single - // apply_transform + global score read instead of two. Sized to the - // maximum `pages_per_seg` the bench drivers use (topk=2048 config has - // seq_len=32768 / page_size=8 = 4096 pages per segment; topk=30 has - // 2048). Shrinking from 8192 to 4096 freed 4 KB of static SMEM per - // block, which lifts occupancy from 5 → 6 blocks/SM on B200. - constexpr int kFusedMaxLen = 4096; - __shared__ uint8_t s_bins[kFusedMaxLen]; + // apply_transform + global score read instead of two. 
+    //
+    // s_bins lives in DYNAMIC shared memory, placed immediately after the
+    // f_input_idx[2][SMEM_INPUT_SIZE] 2D array in the same extern __shared__
+    // region. The host launch reserves `kSmem + max_num_pages` dynamic bytes
+    // (see `topk_output_sglang_fused`) so every block has `max_num_pages`
+    // bytes available past f_input_idx's 32 KB span. Per-block `length`
+    // (from dense_kv_indptr) is ≤ max_num_pages, so indexing stays in bounds.
+    //
+    // This layout keeps smem usage at kSmem + 4 KB for the existing
+    // pages_per_seg ≤ 4096 regimes (identical to the old 32 KB dynamic +
+    // 4 KB static) and only grows when the caller asks for a larger
+    // pages_per_seg — no occupancy regression on small configs.
     auto& f_histogram = f_histogram_buf[0];
     extern __shared__ int f_input_idx[][SMEM_INPUT_SIZE];
+    uint8_t* const s_bins = reinterpret_cast<uint8_t*>(&f_input_idx[2][0]);
 
     const int tx = threadIdx.x;
 
@@ -1192,10 +1208,20 @@ void topk_output_sglang_fused(
         "topk_output_sglang_fused: topk_val (", topk_val,
         ") exceeds VORTEX_MAX_TOPK (", VORTEX_MAX_TOPK, ")");
 
-    // Caller contract: max_num_pages must be <= 4096, the static SMEM
-    // `s_bins` cache size inside the templated fused kernel. The bench
-    // drivers stay within this bound; no runtime check is emitted in
-    // the hot path.
+    // Dynamic-smem layout for the fused kernel:
+    //   [ f_input_idx (2 × SMEM_INPUT_SIZE × sizeof(int) = kSmem bytes)
+    //     s_bins (bins_bytes = align_up(max_num_pages, 16)) ]
+    // The per-launch smem request equals the total of both. It must fit
+    // under kFusedSmemMax, which setup_kernel_smem_once opted this kernel
+    // into via cudaFuncSetAttribute(MaxDynamicSharedMemorySize, ...).
+    const size_t bins_bytes = (static_cast<size_t>(max_num_pages) + size_t(15)) & ~size_t(15);
+    const size_t smem_bytes = kSmem + bins_bytes;
+    TORCH_CHECK(smem_bytes <= kFusedSmemMax,
+        "topk_output_sglang_fused: max_num_pages (", max_num_pages,
+        ") exceeds the fused kernel's dynamic smem ceiling. 
" + "Requested smem=", smem_bytes, " bytes, ceiling=", kFusedSmemMax, + " bytes. Raise kFusedSmemMax (and verify GPU opt-in limits) or " + "reduce pages_per_seg."); // The `mapping_lut` / `mapping_quantiles` optional tensors are // retained in the pybind signature for API backward compatibility @@ -1226,8 +1252,8 @@ void topk_output_sglang_fused( // time per-call cost, negligible relative to the kernel runtime. #define VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MODE_VAL) \ do { \ - setup_kernel_smem_once, kSmem>(); \ - TopKOutput_Fused_Kernel<<>>( \ + setup_kernel_smem_once, kFusedSmemMax>(); \ + TopKOutput_Fused_Kernel<<>>( \ PTR_EXPR, \ dense_kv_indptr.data_ptr(), \ sparse_kv_indptr.data_ptr(), \ diff --git a/csrc/utils_sglang.cu b/csrc/utils_sglang.cu index 1420e9e..a7ddf42 100644 --- a/csrc/utils_sglang.cu +++ b/csrc/utils_sglang.cu @@ -82,16 +82,20 @@ const int page_reserved_eos #pragma unroll for (int i = 0; i < ITEM_PER_THREAD; ++i){ - int16_t w = ((tx_offset + i) < eff_batch_size) ? - (dense_kv_indptr[tx_offset+i+1] - dense_kv_indptr[tx_offset+i] + int16_t w = ((tx_offset + i) < eff_batch_size) ? + (dense_kv_indptr[tx_offset+i+1] - dense_kv_indptr[tx_offset+i] - page_reserved_bos - page_reserved_eos): 0; - - page_count[i] = (w > topk_val) ? w : 0; + + // See note in Sgl_Decode_Plan_Workload_Kernel: we used to skip slots + // where w ≤ topk_val, but downstream (GeMV / topK / histogram) has no + // matching skip, so it read uninitialised scores and silently + // produced all-zero results. Emit workloads for every slot with w > 0. + page_count[i] = (w > 0) ? 
w : 0; chunked_page_count_prefix_sum[i + 1] = int((page_count[i] + max_chunk_size - 1) / max_chunk_size); } BlockScanInt(temp.scan_int).InclusiveSum(chunked_page_count_prefix_sum, chunked_page_count_prefix_sum); - + if (tx == 1023){ *winfo_num_workload = chunked_page_count_prefix_sum[ITEM_PER_THREAD]; *winfo_chunk_size = max_chunk_size; @@ -218,16 +222,22 @@ const int page_reserved_eos #pragma unroll for (int i = 0; i < ITEM_PER_THREAD; ++i){ - int16_t w = ((tx_offset + i) < eff_batch_size) ? - (dense_kv_indptr[tx_offset+i+1] - dense_kv_indptr[tx_offset+i] + int16_t w = ((tx_offset + i) < eff_batch_size) ? + (dense_kv_indptr[tx_offset+i+1] - dense_kv_indptr[tx_offset+i] - page_reserved_bos - page_reserved_eos): 0; - - page_count[i] = (w > topk_val) ? w : 0; + + // Previously: (w > topk_val) ? w : 0, which skipped scoring on slots + // where the dense page count is already ≤ topk_val. Downstream (GeMV, + // topK, histogram profiling) does NOT have a matching skip, so it + // would read uninitialised scores and silently return garbage (all + // zero). Emit workloads for every slot with w > 0 so scoring always + // runs; when w ≤ topk_val the topK degenerates to "select all w". + page_count[i] = (w > 0) ? 
w : 0; chunked_page_count_prefix_sum[i + 1] = int((page_count[i] + max_chunk_size - 1) / max_chunk_size); } BlockScanInt(temp.scan_int).InclusiveSum(chunked_page_count_prefix_sum, chunked_page_count_prefix_sum); - + if (tx == 1023){ *winfo_num_workloads = chunked_page_count_prefix_sum[ITEM_PER_THREAD]; *winfo_chunk_size = max_chunk_size; diff --git a/examples/remap_function_bench_topk2028.sh b/examples/remap_function_bench_topk2028.sh index 6d95b59..26c529c 100755 --- a/examples/remap_function_bench_topk2028.sh +++ b/examples/remap_function_bench_topk2028.sh @@ -50,33 +50,40 @@ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" BENCH_DIR="${SCRIPT_DIR}/../benchmarks" # ── Defaults ────────────────────────────────────────────────── -GPU_ID=5 -MODEL_NAME="Qwen/Qwen3-8B" +GPU_ID=4 +MODEL_NAME="Qwen/Qwen3-1.7B" TOPK_VAL=2048 MEM=0.7 # Cap KV / VTX sparse prefill buffer sizing during Step 1 (see calibrate_topk.py --help). MAX_TOTAL_TOKENS=64768 +# Min free GiB on the output-dir filesystem before Step 1 (HF weights + cache + logs). +MIN_FREE_DISK_GB=22 ALGO="block_sparse_attention" SAMPLE_STRIDE=1 SEQ_LEN=32768 -BLOCK_SIZE=8 +BLOCK_SIZE=1 BATCH_SIZE=4 NUM_KV_HEADS=8 DISTRIBUTIONS="normal bucket_uniform" # Modes 1 (LUT_CDF) and 2 (Quantile) are no longer benchmarked — their # mapping happens inside compute_stage1_bin, not apply_transform, so # split-phase timing isn't meaningful for them. -MAPPING_MODES="0 3 6 7 8 9 10 11 13 15 16 17 18 19 20" +MAPPING_MODES="0 3 6 7 9 10 11 13 15 16 17 18 19" # Fallback hparam used only if autotune is explicitly skipped. MAPPING_HPARAM=0.5 REPEAT=100 WARMUP=20 # Empty by default — Step 1 will calibrate on the selected model. # Pass --real-histograms /path/to/raw_histograms.npy to skip calibration. 
-# REAL_HISTOGRAMS="/home/zhuominc/xinrui_projects/vortex_torch/examples/calibration/raw_histograms.npy" -#REAL_HISTOGRAMS="/home/zhuominc/xinrui_projects/vortex_torch/examples/calibration/raw_histograms_qwen3-4B.npy" -REAL_HISTOGRAMS="" +# REAL_HISTOGRAMS="/var/tmp/zhuominc/vortex_torch/calibration/raw_histograms.npy" +#REAL_HISTOGRAMS="/var/tmp/zhuominc/vortex_torch/calibration/raw_histograms_qwen3-4B.npy" +REAL_HISTOGRAMS="/var/tmp/zhuominc/vortex_torch/calibration/raw_histograms_qwen3-1.7B.npy" SKIP_AUTOTUNE=0 +# Optional: pre-built autotune JSON to bypass Step 2 entirely. When set, +# Step 2 is skipped and Step 3 reads its per-mode hparams from this file +# instead. Useful for verification runs where we want to pin the exact +# (mode, hparam) pairs without re-running the latency sweep. +PINNED_AUTOTUNE_JSON="" # ── Parse arguments ─────────────────────────────────────────── while [[ $# -gt 0 ]]; do @@ -85,6 +92,7 @@ while [[ $# -gt 0 ]]; do --topk-val) TOPK_VAL="$2"; shift 2 ;; --mem) MEM="$2"; shift 2 ;; --max-total-tokens) MAX_TOTAL_TOKENS="$2"; shift 2 ;; + --min-free-disk-gb) MIN_FREE_DISK_GB="$2"; shift 2 ;; --gpu) GPU_ID="$2"; shift 2 ;; --algo) ALGO="$2"; shift 2 ;; --real-histograms) REAL_HISTOGRAMS="$2"; shift 2 ;; @@ -99,6 +107,7 @@ while [[ $# -gt 0 ]]; do --repeat) REPEAT="$2"; shift 2 ;; --warmup) WARMUP="$2"; shift 2 ;; --skip-autotune) SKIP_AUTOTUNE=1; shift 1 ;; + --pinned-autotune-json) PINNED_AUTOTUNE_JSON="$2"; SKIP_AUTOTUNE=1; shift 2 ;; *) echo "Unknown option: $1"; exit 1 ;; esac done @@ -136,6 +145,18 @@ MODEL_SLUG="$(echo "${MODEL_NAME}" | tr '/' '_')" RUN_DIR="${RESULTS_DIR}/remap_bench_${MODEL_SLUG}_topk${TOPK_VAL}_bs${BLOCK_SIZE}_${TIMESTAMP}" mkdir -p "${RUN_DIR}" +# Calibration artifacts live on /var/tmp (large disk), keyed by model. 
+# Example: /var/tmp/zhuominc/vortex_torch/calibration/raw_histograms_qwen3-8B.npy +CALIBRATION_BASE="/var/tmp/zhuominc/vortex_torch/calibration" +MODEL_TAG="$(echo "${MODEL_NAME##*/}" | sed 's/^Q/q/')" +DEFAULT_REAL_HIST="${CALIBRATION_BASE}/raw_histograms_${MODEL_TAG}.npy" +mkdir -p "${CALIBRATION_BASE}" + +# If no explicit --real-histograms and a cached file exists, reuse it. +if [ -z "${REAL_HISTOGRAMS}" ] && [ -f "${DEFAULT_REAL_HIST}" ]; then + REAL_HISTOGRAMS="${DEFAULT_REAL_HIST}" +fi + echo "============================================================" echo "Remap Function Benchmark" echo " Model: ${MODEL_NAME}" @@ -149,6 +170,7 @@ echo " Distributions: ${DISTRIBUTIONS}" echo " Mapping modes: ${MAPPING_MODES}" echo " Fallback hparam: ${MAPPING_HPARAM} (used only when --skip-autotune)" echo " Max total tokens: ${MAX_TOTAL_TOKENS} (calibration KV / VTX buffer cap)" +echo " Min free disk: ${MIN_FREE_DISK_GB} GiB (Step 1 preflight; 0 = skip)" echo " GPU: ${GPU_ID}" echo " Sample stride: ${SAMPLE_STRIDE}" echo " Real histograms: ${REAL_HISTOGRAMS:-}" @@ -167,7 +189,7 @@ if [ -n "${REAL_HISTOGRAMS}" ]; then else echo "" echo ">>> Step 1: Calibrating ${MODEL_NAME} — collecting real topk histograms" - CALIBRATION_DIR="${RUN_DIR}/calibration" + CALIBRATION_DIR="${CALIBRATION_BASE}/staging_${MODEL_TAG}_topk${TOPK_VAL}_bs${BLOCK_SIZE}_${TIMESTAMP}" mkdir -p "${CALIBRATION_DIR}" python "${BENCH_DIR}/calibrate_topk.py" \ --model-name "${MODEL_NAME}" \ @@ -175,11 +197,15 @@ else --page-size "${BLOCK_SIZE}" \ --mem "${MEM}" \ --max-total-tokens "${MAX_TOTAL_TOKENS}" \ + --min-free-disk-gb "${MIN_FREE_DISK_GB}" \ --vortex-module-name "${ALGO}" \ --output-dir "${CALIBRATION_DIR}" \ 2>&1 | tee "${RUN_DIR}/step1_calibrate.log" - REAL_HIST_PATH="${CALIBRATION_DIR}/raw_histograms.npy" - echo ">>> Step 1: Done. Calibration saved to ${CALIBRATION_DIR}" + # Promote raw_histograms.npy to the shared per-model cache path. 
+ mv -f "${CALIBRATION_DIR}/raw_histograms.npy" "${DEFAULT_REAL_HIST}" + REAL_HIST_PATH="${DEFAULT_REAL_HIST}" + echo ">>> Step 1: Done. raw_histograms -> ${REAL_HIST_PATH}" + echo ">>> Step 1: Staging dir (lut/quantiles/logs): ${CALIBRATION_DIR}" fi # Modes 1 (LUT_CDF) and 2 (Quantile) are dropped from the comparison, so @@ -193,8 +219,13 @@ fi AUTOTUNE_JSON="${RUN_DIR}/autotune_results.json" if [ "${SKIP_AUTOTUNE}" -eq 1 ]; then echo "" - echo ">>> Step 2: SKIPPED (using fallback --mapping-hparam ${MAPPING_HPARAM})" - AUTOTUNE_ARGS="" + if [ -n "${PINNED_AUTOTUNE_JSON}" ]; then + echo ">>> Step 2: SKIPPED (pinned hparams from ${PINNED_AUTOTUNE_JSON})" + AUTOTUNE_ARGS="--autotune-json ${PINNED_AUTOTUNE_JSON}" + else + echo ">>> Step 2: SKIPPED (using fallback --mapping-hparam ${MAPPING_HPARAM})" + AUTOTUNE_ARGS="" + fi else echo "" echo ">>> Step 2: Auto-tuning hyperparameters by profiled topk kernel latency" @@ -222,6 +253,7 @@ BENCH_EXTRA=() [ -n "${REAL_HIST_PATH}" ] && BENCH_EXTRA+=(--real-histograms "${REAL_HIST_PATH}") PYTHONPATH="${SCRIPT_DIR}/.." 
python "${BENCH_DIR}/bench_topk.py" \ --remap-bench \ + --per-head-bench \ --batch-sizes "${BATCH_SIZE}" \ --num-kv-heads "${NUM_KV_HEADS}" \ --seq-lens "${SEQ_LEN}" \ diff --git a/examples/remap_function_bench_topk30.sh b/examples/remap_function_bench_topk30.sh index 3cb52e2..3843906 100755 --- a/examples/remap_function_bench_topk30.sh +++ b/examples/remap_function_bench_topk30.sh @@ -52,7 +52,7 @@ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" BENCH_DIR="${SCRIPT_DIR}/../benchmarks" # ── Defaults ────────────────────────────────────────────────── -GPU_ID=5 +GPU_ID=1 MODEL_NAME="Qwen/Qwen3-1.7B" TOPK_VAL=30 MEM=0.7 @@ -61,20 +61,20 @@ ALGO="block_sparse_attention" SAMPLE_STRIDE=1 SEQ_LEN=32768 BLOCK_SIZE=16 -BATCH_SIZE=1 +BATCH_SIZE=4 NUM_KV_HEADS=8 DISTRIBUTIONS="normal bucket_uniform" # Modes 1 (LUT_CDF) and 2 (Quantile) are no longer benchmarked — their # mapping happens inside compute_stage1_bin, not apply_transform, so # split-phase timing isn't meaningful for them. -MAPPING_MODES="0 3 6 7 8 9 10 11 13 15 16 17 18 19 20" +MAPPING_MODES="0 3 6 7 9 10 11 13 15 16 17 18 19 20" # Fallback hparam used only if autotune is explicitly skipped. MAPPING_HPARAM=0.5 REPEAT=100 WARMUP=20 # Empty by default — Step 1 will calibrate on the selected model. # Pass --real-histograms /path/to/raw_histograms.npy to skip calibration. -REAL_HISTOGRAMS="/home/zhuominc/xinrui_projects/vortex_torch/examples/calibration/raw_histograms.npy" +REAL_HISTOGRAMS="/var/tmp/zhuominc/vortex_torch/calibration/raw_histograms_qwen3-4B.npy" SKIP_AUTOTUNE=0 # ── Parse arguments ─────────────────────────────────────────── @@ -135,6 +135,18 @@ MODEL_SLUG="$(echo "${MODEL_NAME}" | tr '/' '_')" RUN_DIR="${RESULTS_DIR}/remap_bench_${MODEL_SLUG}_topk${TOPK_VAL}_bs${BLOCK_SIZE}_${TIMESTAMP}" mkdir -p "${RUN_DIR}" +# Calibration artifacts live on /var/tmp (large disk), keyed by model. 
+# Example: /var/tmp/zhuominc/vortex_torch/calibration/raw_histograms_qwen3-1.7B.npy +CALIBRATION_BASE="/var/tmp/zhuominc/vortex_torch/calibration" +MODEL_TAG="$(echo "${MODEL_NAME##*/}" | sed 's/^Q/q/')" +DEFAULT_REAL_HIST="${CALIBRATION_BASE}/raw_histograms_${MODEL_TAG}.npy" +mkdir -p "${CALIBRATION_BASE}" + +# If no explicit --real-histograms and a cached file exists, reuse it. +if [ -z "${REAL_HISTOGRAMS}" ] && [ -f "${DEFAULT_REAL_HIST}" ]; then + REAL_HISTOGRAMS="${DEFAULT_REAL_HIST}" +fi + echo "============================================================" echo "Remap Function Benchmark" echo " Model: ${MODEL_NAME}" @@ -166,7 +178,7 @@ if [ -n "${REAL_HISTOGRAMS}" ]; then else echo "" echo ">>> Step 1: Calibrating ${MODEL_NAME} — collecting real topk histograms" - CALIBRATION_DIR="${RUN_DIR}/calibration" + CALIBRATION_DIR="${CALIBRATION_BASE}/staging_${MODEL_TAG}_topk${TOPK_VAL}_bs${BLOCK_SIZE}_${TIMESTAMP}" mkdir -p "${CALIBRATION_DIR}" python "${BENCH_DIR}/calibrate_topk.py" \ --model-name "${MODEL_NAME}" \ @@ -177,8 +189,11 @@ else --vortex-module-name "${ALGO}" \ --output-dir "${CALIBRATION_DIR}" \ 2>&1 | tee "${RUN_DIR}/step1_calibrate.log" - REAL_HIST_PATH="${CALIBRATION_DIR}/raw_histograms.npy" - echo ">>> Step 1: Done. Calibration saved to ${CALIBRATION_DIR}" + # Promote raw_histograms.npy to the shared per-model cache path. + mv -f "${CALIBRATION_DIR}/raw_histograms.npy" "${DEFAULT_REAL_HIST}" + REAL_HIST_PATH="${DEFAULT_REAL_HIST}" + echo ">>> Step 1: Done. raw_histograms -> ${REAL_HIST_PATH}" + echo ">>> Step 1: Staging dir (lut/quantiles/logs): ${CALIBRATION_DIR}" fi # Modes 1 (LUT_CDF) and 2 (Quantile) are dropped from the comparison, so @@ -221,6 +236,7 @@ BENCH_EXTRA=() [ -n "${REAL_HIST_PATH}" ] && BENCH_EXTRA+=(--real-histograms "${REAL_HIST_PATH}") PYTHONPATH="${SCRIPT_DIR}/.." 
python "${BENCH_DIR}/bench_topk.py" \ --remap-bench \ + --per-head-bench \ --batch-sizes "${BATCH_SIZE}" \ --num-kv-heads "${NUM_KV_HEADS}" \ --seq-lens "${SEQ_LEN}" \ diff --git a/third_party/sglang b/third_party/sglang index 47faead..b7825d0 160000 --- a/third_party/sglang +++ b/third_party/sglang @@ -1 +1 @@ -Subproject commit 47faead5448b14681ac57fc9a3c6311654fc2b17 +Subproject commit b7825d08399fccdf1f29a5380d6601fcef59aca1 diff --git a/vortex_torch/indexer/utils_sglang.py b/vortex_torch/indexer/utils_sglang.py index 74b8cfe..343207f 100644 --- a/vortex_torch/indexer/utils_sglang.py +++ b/vortex_torch/indexer/utils_sglang.py @@ -40,7 +40,7 @@ def plan_decode( ctx.max_chunk_size, ctx.min_chunk_size ) - + ctx.set_batch_size(cached_seq_lens.shape[0]) From 13cb8a015a63be6a5cd8060fb6c889915de8b879 Mon Sep 17 00:00:00 2001 From: UED Date: Sun, 19 Apr 2026 19:06:04 -0400 Subject: [PATCH 22/22] Add parallel TopK kernel and profiling enhancements - Introduced for a multi-CTA split+merge variant of the TopK kernel to improve GPU utilization. - Updated to include the new source file for the parallel kernel. - Enhanced to support benchmarking of the parallel kernel, including automatic split determination. - Added a new profiling script for comparing performance between the parallel and fused TopK kernels. - Updated example scripts to facilitate ablation studies on remap functions and kernel performance across different configurations. 
--- benchmarks/bench_topk.py | 74 +- benchmarks/profile_parallel_vs_fused.py | 99 +++ csrc/register.cc | 11 + csrc/register.h | 18 + csrc/topk_sglang_parallel.cu | 811 ++++++++++++++++++ .../ablation_remap_function_block_size.sh | 279 ++++++ examples/ablation_remap_function_model.sh | 262 ++++++ .../ablation_remap_function_topk_benchmark.sh | 277 ++++++ examples/ablation_remap_function_topk_val.sh | 255 ++++++ examples/analyze_ablation_remap.py | 416 +++++++++ examples/profile_in_docker.sh | 181 ++++ examples/profile_parallel_vs_fused_ncu.sh | 277 ++++++ examples/profile_parallel_vs_fused_nsys.sh | 211 +++++ .../remap_function_bench_topk_parallel.sh | 245 ++++++ examples/verify_algo.py | 29 +- setup.py | 1 + 16 files changed, 3443 insertions(+), 3 deletions(-) create mode 100644 benchmarks/profile_parallel_vs_fused.py create mode 100644 csrc/topk_sglang_parallel.cu create mode 100644 examples/ablation_remap_function_block_size.sh create mode 100644 examples/ablation_remap_function_model.sh create mode 100644 examples/ablation_remap_function_topk_benchmark.sh create mode 100644 examples/ablation_remap_function_topk_val.sh create mode 100644 examples/analyze_ablation_remap.py create mode 100755 examples/profile_in_docker.sh create mode 100755 examples/profile_parallel_vs_fused_ncu.sh create mode 100755 examples/profile_parallel_vs_fused_nsys.sh create mode 100755 examples/remap_function_bench_topk_parallel.sh diff --git a/benchmarks/bench_topk.py b/benchmarks/bench_topk.py index 3653c55..a717c7b 100644 --- a/benchmarks/bench_topk.py +++ b/benchmarks/bench_topk.py @@ -27,6 +27,7 @@ topk_output_sglang, # 2-stage radix approximate topk (unmapped baseline) topk_output_sglang_fused, # fused remap + 2-stage radix topk topk_output_sglang_ori, # original SGLang reference kernel + topk_output_sglang_parallel, # multi-CTA split+merge variant of the fused kernel topk_remap_only, # standalone value-space remap topk_profile_histogram, topk_profile_counters, @@ -76,6 +77,32 @@ 
_AUTOTUNE_TIE_TOLERANCE_MS = 0.0002 # ≈ CUDA event noise floor at this kernel size +def _auto_num_splits(eff_batch_size: int, pages_per_seg: int, topk_val: int) -> int: + """Pick num_splits to balance Phase-1 and Phase-2 work on the parallel + kernel. + + Phase-1 per CTA does O(pages/splits) work and runs eff_batch_size*splits + CTAs in parallel; Phase-2 runs eff_batch_size CTAs each doing + O(splits*topk) work on the merged candidate list. Assuming both phases + hit SM saturation, total ≈ (pages/splits + splits*topk)/throughput, + minimized at splits = sqrt(pages/topk). Cap at the SM-budget for + eff_batch_size and the max_safe value (pages_per_seg // topk_val, past + which Phase 1 partitions are smaller than topk_val and gain nothing). + + Returns 1 when splitting cannot help. + """ + max_safe = max(1, pages_per_seg // max(1, topk_val)) + if max_safe <= 1 or eff_batch_size <= 0: + return 1 + try: + sm = torch.cuda.get_device_properties(0).multi_processor_count + except Exception: + sm = 132 + balanced = max(1, int(round((pages_per_seg / max(1, topk_val)) ** 0.5))) + sm_budget = max(1, sm // max(1, eff_batch_size)) + return max(1, min(balanced, sm_budget, max_safe)) + + def _load_autotune_hparams(path: str) -> Dict[int, float]: """Load per-mode best hyperparameters from an autotune_results.json. 
@@ -514,6 +541,8 @@ def _remap_bench_one_config(args, batch_size, num_kv_heads, seq_len, topk_val, "topk_after_remap_ms": naive_ms, "split_total_ms": None, "fused_ms": None, + "parallel_ms": None, + "parallel_splits": None, "threshold_bin_mean": 0.0, "threshold_bin_max": 0.0, "num_above_mean": 0.0, @@ -541,6 +570,8 @@ def _remap_bench_one_config(args, batch_size, num_kv_heads, seq_len, topk_val, "topk_after_remap_ms": baseline["mean_ms"], "split_total_ms": None, "fused_ms": None, + "parallel_ms": None, + "parallel_splits": None, **none_stats, }) @@ -556,6 +587,8 @@ def _remap_bench_one_config(args, batch_size, num_kv_heads, seq_len, topk_val, "topk_after_remap_ms": sglang_ori_ms, "split_total_ms": None, "fused_ms": None, + "parallel_ms": None, + "parallel_splits": None, "threshold_bin_mean": 0.0, "threshold_bin_max": 0.0, "num_above_mean": 0.0, @@ -595,6 +628,32 @@ def _remap_bench_one_config(args, batch_size, num_kv_heads, seq_len, topk_val, inputs["sparse_kv_indices"].zero_() fused = bench_kernel(topk_output_sglang_fused, fused_args, args.warmup, args.repeat) + # Multi-CTA split+merge variant of the fused kernel. num_splits <= 1 + # delegates to the single-CTA fused path, so this is only a + # meaningful extra data point when we can actually split. 
+ parallel_ms = None + parallel_splits_used = None + if getattr(args, "bench_parallel", False): + splits = getattr(args, "num_splits", -1) + if splits is None or splits < 1: + splits = _auto_num_splits(eff_bs, pages_per_seg, topk_val) + parallel_args = ( + inputs["x"], + inputs["dense_kv_indptr"], + inputs["sparse_kv_indptr"], + inputs["dense_kv_indices"], + inputs["sparse_kv_indices"], + eff_bs, topk_val, args.reserved_bos, args.reserved_eos, pages_per_seg, + splits, + mode, power, lut_t, q_t, + ) + inputs["sparse_kv_indices"].zero_() + parallel = bench_kernel( + topk_output_sglang_parallel, parallel_args, args.warmup, args.repeat + ) + parallel_ms = parallel["mean_ms"] + parallel_splits_used = splits + # Split-phase timing is only meaningful for arithmetic modes. # MAPPING_LUT_CDF / QUANTILE / TRUNC8 apply their mapping inside # compute_stage1_bin, which topk_remap_only cannot reproduce, so we @@ -645,6 +704,8 @@ def _remap_bench_one_config(args, batch_size, num_kv_heads, seq_len, topk_val, "topk_after_remap_ms": topk_after_remap_ms, "split_total_ms": split_total_ms, "fused_ms": fused["mean_ms"], + "parallel_ms": parallel_ms, + "parallel_splits": parallel_splits_used, **stats, } config["modes"].append(row) @@ -664,7 +725,7 @@ def _print_remap_table(results: List[dict]) -> None: # still in the JSON for downstream tools, just not in the table. 
header = ( f"{'mode':<14s} {'remap_ms':>9s} {'topk_ms':>9s} {'split_ms':>9s} " - f"{'fused_ms':>9s} {'base_ms':>9s} " + f"{'fused_ms':>9s} {'par_ms':>9s} {'splits':>6s} {'base_ms':>9s} " f"{'s1p2_load':>9s} {'eff_thr':>7s} {'rounds':>6s} {'s2_work':>8s}" ) for cfg in results: @@ -705,6 +766,9 @@ def _print_remap_table(results: List[dict]) -> None: def _fmt(v): return f"{v:9.4f}" if v is not None else f"{'N/A':>9s}" fused_str = _fmt(row.get("fused_ms")) + par_str = _fmt(row.get("parallel_ms")) + splits = row.get("parallel_splits") + splits_str = f"{splits:>6d}" if splits is not None else f"{'N/A':>6s}" thr_size = row.get("threshold_bin_size_mean", 0.0) rounds = row.get("refine_rounds_mean", 0.0) eff_thr = min(thr_size, float(_STAGE2_SMEM_CAP)) @@ -716,6 +780,8 @@ def _fmt(v): f"{_fmt(row['topk_after_remap_ms'])} " f"{_fmt(row['split_total_ms'])} " f"{fused_str} " + f"{par_str} " + f"{splits_str} " f"{base_ms:9.4f} " f"{s1p2_load:9.0f} " f"{eff_thr:7.0f} " @@ -768,6 +834,7 @@ def _mean_or_none(vals): # across heads by construction. NUMERIC_KEYS = ( "remap_ms", "topk_after_remap_ms", "split_total_ms", "fused_ms", + "parallel_ms", "threshold_bin_mean", "threshold_bin_max", "num_above_mean", "threshold_bin_size_mean", "threshold_bin_size_max", @@ -984,6 +1051,11 @@ def main(): p.add_argument("--output-json", type=str, default=None) p.add_argument("--remap-bench", action="store_true", help="Run the split-phase remap/topk/fused/baseline benchmark.") + p.add_argument("--bench-parallel", action="store_true", + help="Also time topk_output_sglang_parallel (multi-CTA split+merge).") + p.add_argument("--num-splits", type=int, default=-1, + help="Partitions per batch for the parallel kernel. 
-1 = auto " + "(sm_count / eff_batch_size, clamped to pages_per_seg/topk_val).") p.add_argument("--per-head-bench", action="store_true", help="In addition to the aggregated 'real'-distribution table, also " "run the remap-bench once per KV head: slice the calibrated " diff --git a/benchmarks/profile_parallel_vs_fused.py b/benchmarks/profile_parallel_vs_fused.py new file mode 100644 index 0000000..ecfd872 --- /dev/null +++ b/benchmarks/profile_parallel_vs_fused.py @@ -0,0 +1,99 @@ +""" +Driver for Nsight Compute profiling of the parallel vs fused TopK +kernels. Designed to be launched under `ncu` with --launch-skip and +--launch-count to isolate a specific kernel launch from warmup. + +The script does exactly: + args.warmup matching-kernel launches (skipped by ncu --launch-skip) + args.iters matching-kernel launches (captured by ncu --launch-count) + +Pair --launch-skip/--launch-count with --kernel-name so unrelated +launches (torch initializers, cublas, etc.) don't pollute the counts. +""" +import argparse +import torch +from vortex_torch_C import ( + topk_output_sglang_fused, + topk_output_sglang_parallel, +) + + +def make_inputs(eff_bs: int, pages: int, topk: int): + reserved = 0 + dense_indptr = torch.arange( + 0, (eff_bs + 1) * pages, pages, dtype=torch.int32, device="cuda" + ) + sparse_indptr = torch.arange( + 0, (eff_bs + 1) * topk, topk, dtype=torch.int32, device="cuda" + ) + dense_indices = torch.arange(eff_bs * pages, dtype=torch.int32, device="cuda") + torch.manual_seed(0) + x = torch.randn(eff_bs * pages, 1, 1, dtype=torch.bfloat16, device="cuda") + out = torch.zeros(eff_bs * topk, dtype=torch.int32, device="cuda") + return x, dense_indptr, sparse_indptr, dense_indices, out, reserved + + +def main(): + p = argparse.ArgumentParser() + p.add_argument( + "--config", + choices=["A", "B"], + required=True, + help="A: topk=2048 pages=32K ; B: topk=30 pages=2K", + ) + p.add_argument("--eff-bs", type=int, default=1) + p.add_argument( + "--mode", type=int, 
choices=[15, 16], required=True, + help="15=MAPPING_SHIFT_POW2, 16=MAPPING_SHIFT_POW3", + ) + p.add_argument( + "--power", type=float, default=0.5, + help="Pivot (p) for the shift_pow transforms. 0.5 matches the " + "autotune default for Qwen3-1.7B softmax scores.", + ) + p.add_argument("--num-splits", type=int, default=4) + p.add_argument("--kernel", choices=["fused", "parallel"], required=True) + p.add_argument("--warmup", type=int, default=20) + p.add_argument("--iters", type=int, default=1) + args = p.parse_args() + + pages, topk = (32768, 2048) if args.config == "A" else (2048, 30) + x, dense_indptr, sparse_indptr, dense_indices, out, reserved = make_inputs( + args.eff_bs, pages, topk + ) + + if args.kernel == "fused": + def call(): + topk_output_sglang_fused( + x, dense_indptr, sparse_indptr, dense_indices, out, + args.eff_bs, topk, reserved, reserved, pages, + args.mode, args.power, None, None, + ) + else: + def call(): + topk_output_sglang_parallel( + x, dense_indptr, sparse_indptr, dense_indices, out, + args.eff_bs, topk, reserved, reserved, pages, + args.num_splits, args.mode, args.power, None, None, + ) + + # Warmup: specialised kernel is JIT-instantiated and cudaFuncSetAttribute + # is cached; these launches dominate the first-call overhead and we want + # ncu to skip past them. + for _ in range(args.warmup): + call() + torch.cuda.synchronize() + + # Profiled region. Wrap in NVTX so the same script is also useful under + # Nsight Systems (nsys) if you prefer a timeline view. 
+ torch.cuda.nvtx.range_push( + f"profile-{args.kernel}-mode{args.mode}-cfg{args.config}-eff{args.eff_bs}" + ) + for _ in range(args.iters): + call() + torch.cuda.synchronize() + torch.cuda.nvtx.range_pop() + + +if __name__ == "__main__": + main() diff --git a/csrc/register.cc b/csrc/register.cc index 8aa5aea..af584d3 100644 --- a/csrc/register.cc +++ b/csrc/register.cc @@ -30,6 +30,17 @@ PYBIND11_MODULE(vortex_torch_C, m){ py::arg("mapping_power"), py::arg("mapping_lut") = py::none(), py::arg("mapping_quantiles") = py::none()); + m.def("topk_output_sglang_parallel", &topk_output_sglang_parallel, + py::arg("x"), py::arg("dense_kv_indptr"), py::arg("sparse_kv_indptr"), + py::arg("dense_kv_indices"), py::arg("sparse_kv_indices"), + py::arg("eff_batch_size"), py::arg("topk_val"), + py::arg("reserved_bos"), py::arg("reserved_eos"), + py::arg("max_num_pages"), + py::arg("num_splits"), + py::arg("mapping_mode"), + py::arg("mapping_power"), + py::arg("mapping_lut") = py::none(), + py::arg("mapping_quantiles") = py::none()); m.def("topk_remap_only", &topk_remap_only, py::arg("x"), py::arg("dense_kv_indptr"), py::arg("remapped"), diff --git a/csrc/register.h b/csrc/register.h index afdb97f..e5a26de 100644 --- a/csrc/register.h +++ b/csrc/register.h @@ -126,6 +126,24 @@ std::optional mapping_lut = std::nullopt, std::optional mapping_quantiles = std::nullopt ); +void topk_output_sglang_parallel( +const at::Tensor& x, +const at::Tensor& dense_kv_indptr, +const at::Tensor& sparse_kv_indptr, +const at::Tensor& dense_kv_indices, +at::Tensor& sparse_kv_indices, +const int64_t eff_batch_size, +const int64_t topk_val, +const int64_t reserved_bos, +const int64_t reserved_eos, +const int64_t max_num_pages, +const int64_t num_splits, +const int64_t mapping_mode, +const double mapping_power, +std::optional mapping_lut = std::nullopt, +std::optional mapping_quantiles = std::nullopt +); + void topk_remap_only( const at::Tensor& x, const at::Tensor& dense_kv_indptr, diff --git 
a/csrc/topk_sglang_parallel.cu b/csrc/topk_sglang_parallel.cu new file mode 100644 index 0000000..7219391 --- /dev/null +++ b/csrc/topk_sglang_parallel.cu @@ -0,0 +1,811 @@ +/** + * Vortex TopK parallel kernel (single-kernel, last-CTA-wins merge). + * + * Motivation: the single-CTA fused kernel in topk_sglang.cu pins each + * batch segment to one CTA, which underutilises the GPU for small + * effective batch sizes (e.g. bs=4 on H100 leaves ~97% of SMs idle). + * + * This kernel launches `num_splits * eff_batch_size` CTAs in a single + * launch. CTAs sharing the same `bx` (batch index) partition that + * batch's score range `num_splits` ways and each compute a per-partition + * top-K via the same two-stage radix the fused kernel uses. Partial + * results are written into a per-batch workspace. + * + * Merge is done WITHOUT a second kernel launch. Each CTA, after + * finishing its partition's top-K, does `atomicAdd(&done_counter[bx], + * 1)`. The CTA whose atomicAdd returns `num_splits - 1` is the last + * one to arrive for batch bx, and it alone carries out the merge: + * reads the `num_splits * topk_val` candidates from the workspace, + * runs a small two-stage radix on the already-remapped keys, writes + * final top-K page IDs to sparse_kv_indices. + * + * Correctness: per-partition top-K is a conservative upper bound on + * the global top-K (worst case: all top-K items land in one + * partition). Every global top-K item is therefore guaranteed to be + * in some partition's top-K, and the merge picks the final top-K + * from the union — sorted-scores match the fused kernel exactly. + * Tie-breaking can differ because radix tie-breaks depend on atomic + * race order. 
+ */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "register.h" + +namespace { + +// ---- Launch constants (match topk_sglang.cu) -------------------------------- + +constexpr int kThreadsPerBlock = 1024; + +#ifdef USE_ROCM +#ifdef SGL_TOPK_DYNAMIC_SMEM_BYTES +constexpr size_t kSmem = static_cast(SGL_TOPK_DYNAMIC_SMEM_BYTES); +#else +constexpr size_t kSmem = 48 * 1024; +#endif +#else +constexpr size_t kSmem = 8 * 1024 * sizeof(uint32_t); // 32 KB +#endif + +constexpr size_t kFusedSmemMax = 96 * 1024; // combined kernel dynamic smem ceiling +constexpr int VORTEX_MAX_TOPK = 2048; + +// ---- Program-lifetime done-counter array ---------------------------------- +// Used by the last-CTA-wins barrier. __device__ linkage → zero-initialised at +// program startup. atomicInc(ptr, num_splits-1) cycles each entry back to 0 +// after every launch, so we never pay a cudaMemset on entry to the host fn. +// Sized for the largest realistic effective batch we'd ever run through the +// parallel kernel (decode bs×heads). Host validates the cap. +constexpr int kMaxParallelEffBs = 8192; +__device__ int g_parallel_done_counter[kMaxParallelEffBs]; + +// ---- Device helpers (duplicated from topk_sglang.cu) ----------------------- + +__device__ __forceinline__ auto convert_to_uint8(float x) -> uint8_t { + __half h = __float2half_rn(x); + uint16_t bits = __half_as_ushort(h); + uint16_t key = (bits & 0x8000) ? static_cast(~bits) + : static_cast(bits | 0x8000); + return static_cast(key >> 8); +} + +__device__ __forceinline__ auto convert_to_uint32(float x) -> uint32_t { + uint32_t bits = __float_as_uint(x); + return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u); +} + +__device__ __forceinline__ auto convert_to_uint8_dense(float x) -> uint8_t { + const uint32_t bits = __float_as_uint(x); + const uint32_t key = (bits & 0x80000000u) ? 
~bits : (bits | 0x80000000u); + return static_cast((key >> 16) & 0xFFu); +} + +template +__device__ __forceinline__ float vortex_to_float(T x); +template <> +__device__ __forceinline__ float vortex_to_float(float x) { return x; } +template <> +__device__ __forceinline__ float vortex_to_float<__nv_bfloat16>(__nv_bfloat16 x) { + return __bfloat162float(x); +} + +#include "topk_mapping.cuh" + +// ============================================================================ +// fast_topk_partition +// +// Per-partition two-stage radix. Same algorithm as the fused kernel's +// fast_topk_clean_fused in topk_sglang.cu, with identical mapping-mode +// dispatch and bucket selection. Returns slice-local indices of the +// top `target_k` elements in `index`. +// +// Reuses the caller-provided extern shared memory region `f_input_idx` +// (2 × SMEM_INPUT_SIZE ints) and the `s_bins` byte cache immediately +// after it. The caller also supplies the static histogram / counter +// storage through the template's body — each device-function-private +// __shared__ declaration gets its own offset, but total static smem +// stays small enough to fit comfortably alongside the dynamic region. 
+// ============================================================================ +template +__device__ void fast_topk_partition( + const ScoreT* __restrict__ input, + int* __restrict__ index, + int* __restrict__ f_input_idx_raw, // 2 × SMEM_INPUT_SIZE ints + uint8_t* __restrict__ s_bins, // `length` bytes + int row_start, + int length, + int target_k, + const TopKMappingParams mapping) +{ + int topk = target_k; + constexpr auto BLOCK_SIZE = 1024; + constexpr auto RADIX = 256; + constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); + + alignas(128) __shared__ int f_histogram_buf[2][RADIX + 128]; + alignas(128) __shared__ int f_counter; + alignas(128) __shared__ int f_threshold_bin_id; + alignas(128) __shared__ int f_num_input[2]; + + auto& f_histogram = f_histogram_buf[0]; + + // Treat the caller's extern-smem region as two banks of SMEM_INPUT_SIZE ints. + auto f_input_idx = [&](int bank, int pos) -> int& { + return f_input_idx_raw[bank * SMEM_INPUT_SIZE + pos]; + }; + + const int tx = threadIdx.x; + + constexpr bool use_dense_bucket = (MODE == MAPPING_DENSE_MANT); + + if (tx < RADIX + 1) f_histogram[tx] = 0; + __syncthreads(); + + // Stage 1 pass 1: bin every element and cache the bin in s_bins so + // pass 2 doesn't re-load scores or re-apply the mapping. 
+ for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const float raw = vortex_to_float(input[idx + row_start]); + const float remapped = apply_transform_tmpl(raw, mapping.power_exp); + int bin; + if constexpr (use_dense_bucket) bin = static_cast(convert_to_uint8_dense(remapped)); + else bin = static_cast(convert_to_uint8(remapped)); + s_bins[idx] = static_cast(bin); + ::atomicAdd(&f_histogram[bin], 1); + } + __syncthreads(); + + const auto run_cumsum = [&] { +#pragma unroll 8 + for (int i = 0; i < 8; ++i) { + static_assert(1 << 8 == RADIX); + if (C10_LIKELY(tx < RADIX)) { + const auto j = 1 << i; + const auto k = i & 1; + auto value = f_histogram_buf[k][tx]; + if (tx < RADIX - j) value += f_histogram_buf[k][tx + j]; + f_histogram_buf[k ^ 1][tx] = value; + } + __syncthreads(); + } + }; + + run_cumsum(); + if (tx < RADIX && f_histogram[tx] > topk && f_histogram[tx + 1] <= topk) { + f_threshold_bin_id = tx; + f_num_input[0] = 0; + f_counter = 0; + } + __syncthreads(); + + const auto threshold_bin = f_threshold_bin_id; + topk -= f_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const int bin = static_cast(s_bins[idx]); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&f_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + return; + } else { + __syncthreads(); + if (tx < RADIX + 1) f_histogram[tx] = 0; + __syncthreads(); + + constexpr int sub_bin_offset_start = use_dense_bucket ? 
8 : 24; + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const int bin = static_cast(s_bins[idx]); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&f_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + const float raw = vortex_to_float(input[idx + row_start]); + const float remapped = apply_transform_tmpl(raw, mapping.power_exp); + const auto pos = ::atomicAdd(&f_num_input[0], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + f_input_idx(0, pos) = idx; + const auto b32 = convert_to_uint32(remapped); + const auto sub_bin = (b32 >> sub_bin_offset_start) & 0xFF; + ::atomicAdd(&f_histogram[sub_bin], 1); + } + } + } + __syncthreads(); + } + + constexpr int stage2_offset_start = use_dense_bucket ? 8 : 24; + constexpr int stage2_max_rounds = use_dense_bucket ? 2 : 4; +#pragma unroll 4 + for (int round = 0; round < 4; ++round) { + if (round >= stage2_max_rounds) break; + __shared__ int f_last_remain; + const auto r_idx = round % 2; + + const auto _raw_num_input = f_num_input[r_idx]; + const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) ? 
_raw_num_input + : int(SMEM_INPUT_SIZE); + run_cumsum(); + if (tx < RADIX && f_histogram[tx] > topk && f_histogram[tx + 1] <= topk) { + f_threshold_bin_id = tx; + f_num_input[r_idx ^ 1] = 0; + f_last_remain = topk - f_histogram[tx + 1]; + } + __syncthreads(); + + const auto threshold_bin = f_threshold_bin_id; + topk -= f_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = f_input_idx(r_idx, i); + const auto offset = stage2_offset_start - round * 8; + const float raw = vortex_to_float(input[idx + row_start]); + const float remapped = apply_transform_tmpl(raw, mapping.power_exp); + const auto bin = (convert_to_uint32(remapped) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&f_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + break; + } else { + __syncthreads(); + if (tx < RADIX + 1) f_histogram[tx] = 0; + __syncthreads(); + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = f_input_idx(r_idx, i); + const float raw = vortex_to_float(input[idx + row_start]); + const float remapped = apply_transform_tmpl(raw, mapping.power_exp); + const auto offset = stage2_offset_start - round * 8; + const auto bin = (convert_to_uint32(remapped) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&f_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + if (round == stage2_max_rounds - 1) { + const auto pos = ::atomicAdd(&f_last_remain, -1); + if (pos > 0) index[target_k - pos] = idx; + } else { + const auto pos = ::atomicAdd(&f_num_input[r_idx ^ 1], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + f_input_idx(r_idx ^ 1, pos) = idx; + const auto b32 = convert_to_uint32(remapped); + const auto sub_bin = (b32 >> (offset - 8)) & 0xFF; + ::atomicAdd(&f_histogram[sub_bin], 1); + } + } + } + } + __syncthreads(); + } + } +} + +// ============================================================================ +// 
fast_topk_merge +// +// Run by the last-arriving CTA of each batch. Input is the combined +// candidate list (`num_splits * topk_val` float keys + int indices, +// with idx==-1 marking sentinel slots). Reuses the same extern-smem +// region `s_input_idx_raw` that Phase 1 used — its earlier contents +// are dead at this point. Output: top-`target_k` positions into +// `index`, indexing the combined candidate list. +// +// Bucketing matches the fused kernel's bucketing for the given MODE +// so the merged top-K is lossless modulo atomic tie-break order. +// ============================================================================ +template +__device__ void fast_topk_merge( + const float* __restrict__ input, + const int* __restrict__ valid_mask, + int* __restrict__ index, + int* __restrict__ s_input_idx_raw, // 2 × SMEM_INPUT_SIZE ints + int row_start, + int length, + int target_k) +{ + int topk = target_k; + constexpr auto BLOCK_SIZE = 1024; + constexpr auto RADIX = 256; + constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); + constexpr bool use_dense_bucket = (MODE == MAPPING_DENSE_MANT); + constexpr int stage2_offset_start = use_dense_bucket ? 8 : 24; + constexpr int stage2_max_rounds = use_dense_bucket ? 
2 : 4; + + alignas(128) __shared__ int s_histogram_buf[2][RADIX + 128]; + alignas(128) __shared__ int s_counter; + alignas(128) __shared__ int s_threshold_bin_id; + alignas(128) __shared__ int s_num_input[2]; + + auto& s_histogram = s_histogram_buf[0]; + auto s_input_idx = [&](int bank, int pos) -> int& { + return s_input_idx_raw[bank * SMEM_INPUT_SIZE + pos]; + }; + + const int tx = threadIdx.x; + + if (tx < RADIX + 1) s_histogram[tx] = 0; + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + if (valid_mask[idx + row_start] < 0) continue; // sentinel; skip + const float v = input[idx + row_start]; + int bin; + if constexpr (use_dense_bucket) bin = static_cast(convert_to_uint8_dense(v)); + else bin = static_cast(convert_to_uint8(v)); + ::atomicAdd(&s_histogram[bin], 1); + } + __syncthreads(); + + const auto run_cumsum = [&] { +#pragma unroll 8 + for (int i = 0; i < 8; ++i) { + static_assert(1 << 8 == RADIX); + if (C10_LIKELY(tx < RADIX)) { + const auto j = 1 << i; + const auto k = i & 1; + auto value = s_histogram_buf[k][tx]; + if (tx < RADIX - j) value += s_histogram_buf[k][tx + j]; + s_histogram_buf[k ^ 1][tx] = value; + } + __syncthreads(); + } + }; + + run_cumsum(); + if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { + s_threshold_bin_id = tx; + s_num_input[0] = 0; + s_counter = 0; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + topk -= s_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + if (valid_mask[idx + row_start] < 0) continue; + const float v = input[idx + row_start]; + int bin; + if constexpr (use_dense_bucket) bin = static_cast(convert_to_uint8_dense(v)); + else bin = static_cast(convert_to_uint8(v)); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + return; + } else { + __syncthreads(); + if (tx < RADIX + 1) s_histogram[tx] = 0; + 
__syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + if (valid_mask[idx + row_start] < 0) continue; + const auto raw_input = input[idx + row_start]; + int bin; + if constexpr (use_dense_bucket) bin = static_cast(convert_to_uint8_dense(raw_input)); + else bin = static_cast(convert_to_uint8(raw_input)); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + const auto pos = ::atomicAdd(&s_num_input[0], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + s_input_idx(0, pos) = idx; + const auto b32 = convert_to_uint32(raw_input); + const auto sub_bin = (b32 >> stage2_offset_start) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } + } + __syncthreads(); + } + +#pragma unroll 4 + for (int round = 0; round < 4; ++round) { + if (round >= stage2_max_rounds) break; + __shared__ int s_last_remain; + const auto r_idx = round % 2; + + const auto _raw_num_input = s_num_input[r_idx]; + const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) ? 
_raw_num_input + : int(SMEM_INPUT_SIZE); + run_cumsum(); + if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { + s_threshold_bin_id = tx; + s_num_input[r_idx ^ 1] = 0; + s_last_remain = topk - s_histogram[tx + 1]; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + topk -= s_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx(r_idx, i); + const auto offset = stage2_offset_start - round * 8; + const auto bin = (convert_to_uint32(input[idx + row_start]) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + break; + } else { + __syncthreads(); + if (tx < RADIX + 1) s_histogram[tx] = 0; + __syncthreads(); + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx(r_idx, i); + const auto raw_input = input[idx + row_start]; + const auto offset = stage2_offset_start - round * 8; + const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + if (round == stage2_max_rounds - 1) { + const auto pos = ::atomicAdd(&s_last_remain, -1); + if (pos > 0) index[target_k - pos] = idx; + } else { + const auto pos = ::atomicAdd(&s_num_input[r_idx ^ 1], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + s_input_idx(r_idx ^ 1, pos) = idx; + const auto b32 = convert_to_uint32(raw_input); + const auto sub_bin = (b32 >> (offset - 8)) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } + } + } + __syncthreads(); + } + } +} + +// ============================================================================ +// Combined kernel. +// +// Grid: (num_splits, eff_batch_size). Every CTA: +// 1. Computes its partition's top-K (fast_topk_partition). +// 2. 
Writes (remapped key, batch-local idx) pairs + sentinels to the +// per-batch workspace slot. +// 3. __threadfence() to publish the writes, then atomicAdd on the +// per-batch done-counter. The CTA whose atomicAdd returns +// num_splits - 1 is the last one for this batch. +// 4. If last: run the merge (fast_topk_merge) on the combined +// num_splits*topk_val candidates and write final page IDs to +// sparse_kv_indices. Other CTAs exit. +// ============================================================================ +template +__global__ __launch_bounds__(kThreadsPerBlock) +void TopKOutput_Parallel_Kernel( + const ScoreT* __restrict__ score, + const int* __restrict__ dense_kv_indptr, + const int* __restrict__ sparse_kv_indptr, + const int* __restrict__ dense_kv_indices, + int* __restrict__ sparse_kv_indices, + float* __restrict__ partial_keys, // [eff_bs * num_splits * topk_val] + int* __restrict__ partial_idx, // [eff_bs * num_splits * topk_val] + const int topk_val, + const int num_splits, + const int page_reserved_bos, + const int page_reserved_eos, + const int chunk_bytes, // smem bytes reserved for s_bins + const TopKMappingParams mapping) +{ + // ---- Dynamic smem layout ------------------------------------------------- + // [ f_input_idx (2 × SMEM_INPUT_SIZE ints = kSmem bytes) + // s_bins (chunk_bytes, only valid during Phase 1) ] + // The merge doesn't touch s_bins, so its extern region overlaps + // f_input_idx harmlessly. + extern __shared__ int smem_scratch[]; + constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); + int* f_input_idx_raw = smem_scratch; + uint8_t* s_bins = reinterpret_cast(&smem_scratch[2 * SMEM_INPUT_SIZE]); + (void)chunk_bytes; // sizing is the host's responsibility; kernel just uses it + + // s_indices doubles as the partition's radix output AND the merge's radix + // output — they run sequentially on the same CTA, so the same ~2K slots + // are reused. Stores up to VORTEX_MAX_TOPK = 2048 entries. 
+ __shared__ int s_indices[VORTEX_MAX_TOPK]; + // Broadcasts whether this CTA is the last-arriving one for its batch. + __shared__ int s_is_last; + + const int p = blockIdx.x; + const int bx = blockIdx.y; + const int tx = threadIdx.x; + + const int start = dense_kv_indptr[bx] + page_reserved_bos; + const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; + const int total_len = end - start; + + // Short batch: fused kernel returns without writing; match that. + if (total_len <= topk_val) return; + + const size_t slot_base = (static_cast(bx) * num_splits + p) * topk_val; + float* keys_out = partial_keys + slot_base; + int* idx_out = partial_idx + slot_base; + + const int chunk = (total_len + num_splits - 1) / num_splits; + const int part_start = p * chunk; + const int raw_part_end = part_start + chunk; + const int part_end = raw_part_end < total_len ? raw_part_end : total_len; + const int part_len = (part_end > part_start) ? (part_end - part_start) : 0; + + // Sentinel tail: merge filters these by idx == -1. Only fill the range + // that won't be overwritten with real data. + const int real_fill = (part_len < topk_val) ? part_len : topk_val; + const int tail_count = topk_val - real_fill; + if (tail_count > 0) { + for (int i = tx; i < tail_count; i += blockDim.x) { + keys_out[real_fill + i] = -CUDART_INF_F; + idx_out [real_fill + i] = -1; + } + __syncthreads(); + } + + const ScoreT* __restrict__ slice_ptr = score + start + part_start; + + // ---- Phase 1: per-partition top-K --------------------------------------- + if (part_len > 0) { + if (part_len <= topk_val) { + // Whole slice fits under topk_val — emit it directly. 
+ for (int i = tx; i < part_len; i += blockDim.x) { + const float raw = vortex_to_float(slice_ptr[i]); + const float remapped = apply_transform_tmpl(raw, mapping.power_exp); + keys_out[i] = remapped; + idx_out [i] = part_start + i; + } + } else { + fast_topk_partition( + slice_ptr, s_indices, f_input_idx_raw, s_bins, + 0, part_len, topk_val, mapping); + __syncthreads(); + for (int i = tx; i < topk_val; i += blockDim.x) { + const int sl = s_indices[i]; + const float raw = vortex_to_float(slice_ptr[sl]); + const float remapped = apply_transform_tmpl(raw, mapping.power_exp); + keys_out[i] = remapped; + idx_out [i] = part_start + sl; + } + } + } + + // Publish workspace writes so the last-CTA can observe them. + __threadfence(); + __syncthreads(); + + // ---- Arrive at the barrier via atomicInc -------------------------------- + // atomicInc(ptr, N-1) stores `((old >= N-1) ? 0 : old+1)` and returns old. + // So with N == num_splits the counter cycles 0→1→…→N-1→0 per call, which + // means we never need to memset done_counter between calls — after the + // last-CTA's increment it's back at 0, ready for the next launch. + // (Relies on the caller allocating done_counter zero-initialised once.) + if (tx == 0) { + const unsigned int old = ::atomicInc( + reinterpret_cast(&g_parallel_done_counter[bx]), + static_cast(num_splits - 1)); + s_is_last = (old == static_cast(num_splits - 1)) ? 
1 : 0; + } + __syncthreads(); + + if (s_is_last == 0) return; + + // ---- Merge: last CTA selects final top-K -------------------------------- + const int candidate_len = num_splits * topk_val; + const size_t batch_base = static_cast(bx) * candidate_len; + const float* keys_blk = partial_keys + batch_base; + const int* idx_blk = partial_idx + batch_base; + int* out_blk = sparse_kv_indices + + sparse_kv_indptr[bx] + + page_reserved_bos; + const int* dense_blk = dense_kv_indices + start; + + fast_topk_merge( + keys_blk, idx_blk, s_indices, f_input_idx_raw, + 0, candidate_len, topk_val); + __syncthreads(); + + for (int i = tx; i < topk_val; i += blockDim.x) { + const int pos = s_indices[i]; + const int batch_local = idx_blk[pos]; + out_blk[i] = (batch_local >= 0) ? dense_blk[batch_local] : -1; + } +} + +// ---- setup_kernel_smem_once (duplicated) ----------------------------------- + +template +void setup_kernel_smem_once() { + [[maybe_unused]] + static const auto result = [] { +#ifdef USE_ROCM + return ::cudaFuncSetAttribute( + reinterpret_cast(f), + ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); +#else + return ::cudaFuncSetAttribute( + f, ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); +#endif + }(); + TORCH_CHECK(result == cudaSuccess, + "set_up_kernel_once (parallel) failed:", ::cudaGetErrorString(result)); +} + +} // namespace + +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") + +// ============================================================================ +// Host entry point. +// +// Signature matches topk_output_sglang_fused plus `num_splits`. +// `num_splits <= 1` delegates to the single-CTA fused kernel so callers +// can unconditionally use this path. 
+// ============================================================================ +void topk_output_sglang_parallel( + const at::Tensor& x, + const at::Tensor& dense_kv_indptr, + const at::Tensor& sparse_kv_indptr, + const at::Tensor& dense_kv_indices, + at::Tensor& sparse_kv_indices, + const int64_t eff_batch_size, + const int64_t topk_val, + const int64_t reserved_bos, + const int64_t reserved_eos, + const int64_t max_num_pages, + const int64_t num_splits, + const int64_t mapping_mode, + const double mapping_power, + std::optional mapping_lut, + std::optional mapping_quantiles) +{ + TORCH_CHECK(topk_val <= VORTEX_MAX_TOPK, + "topk_output_sglang_parallel: topk_val (", topk_val, + ") exceeds VORTEX_MAX_TOPK (", VORTEX_MAX_TOPK, ")"); + TORCH_CHECK(num_splits >= 1, + "topk_output_sglang_parallel: num_splits must be >= 1"); + + if (num_splits <= 1) { + topk_output_sglang_fused( + x, dense_kv_indptr, sparse_kv_indptr, dense_kv_indices, + sparse_kv_indices, eff_batch_size, topk_val, + reserved_bos, reserved_eos, max_num_pages, + mapping_mode, mapping_power, mapping_lut, mapping_quantiles); + return; + } + + CHECK_CUDA(x); + CHECK_CUDA(dense_kv_indptr); + CHECK_CUDA(sparse_kv_indptr); + CHECK_CUDA(dense_kv_indices); + CHECK_CUDA(sparse_kv_indices); + + (void)mapping_lut; + (void)mapping_quantiles; + + TopKMappingParams mapping{}; + mapping.mode = static_cast(mapping_mode); + mapping.power_exp = static_cast(mapping_power); + mapping.lut = nullptr; + mapping.quantiles = nullptr; + + // Dynamic smem = kSmem (f_input_idx) + chunk_bytes (s_bins for the + // partition radix; the merge doesn't touch s_bins). 
+  // NOTE(review): the template arguments and the kernel launch configuration
+  // below were stripped to "static_cast(...)", "data_ptr()" and "<<>>" in
+  // transit; reconstructed from the in-scope declarations (opts_f32/opts_i32,
+  // grid/nthreads/smem_bytes/stream) — verify against the fused variant.
+  const int64_t chunk_pages = (max_num_pages + num_splits - 1) / num_splits;
+  const size_t  chunk_bytes = (static_cast<size_t>(chunk_pages) + size_t(15)) & ~size_t(15);
+  const size_t  smem_bytes  = kSmem + chunk_bytes;
+  TORCH_CHECK(smem_bytes <= kFusedSmemMax,
+              "topk_output_sglang_parallel: smem ", smem_bytes,
+              " exceeds ceiling ", kFusedSmemMax);
+
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
+
+  TORCH_CHECK(eff_batch_size <= kMaxParallelEffBs,
+              "topk_output_sglang_parallel: eff_batch_size (", eff_batch_size,
+              ") exceeds kMaxParallelEffBs (", kMaxParallelEffBs,
+              "). Raise the __device__ counter array size.");
+
+  // Per-call workspace. at::empty, no zero-init — kernel fills every used
+  // slot (valid prefix + sentinel tail). done_counter is a __device__
+  // global (above) so no workspace allocation needed for it.
+  const int64_t ws_elems = eff_batch_size * num_splits * topk_val;
+  auto opts_f32 = at::TensorOptions().device(x.device()).dtype(at::kFloat);
+  auto opts_i32 = at::TensorOptions().device(x.device()).dtype(at::kInt);
+  at::Tensor partial_keys = at::empty({ws_elems}, opts_f32);
+  at::Tensor partial_idx  = at::empty({ws_elems}, opts_i32);
+
+  dim3 grid(static_cast<unsigned int>(num_splits),
+            static_cast<unsigned int>(eff_batch_size));
+  dim3 nthreads(kThreadsPerBlock);
+
+  #define VORTEX_PARALLEL_DISPATCH(DTYPE, PTR_EXPR, MODE_VAL)                \
+    do {                                                                     \
+      setup_kernel_smem_once<                                                \
+          TopKOutput_Parallel_Kernel<DTYPE, MODE_VAL>,                       \
+          kFusedSmemMax>();                                                  \
+      TopKOutput_Parallel_Kernel<DTYPE, MODE_VAL>                            \
+          <<<grid, nthreads, smem_bytes, stream>>>(                          \
+              PTR_EXPR,                                                      \
+              dense_kv_indptr.data_ptr<int>(),                               \
+              sparse_kv_indptr.data_ptr<int>(),                              \
+              dense_kv_indices.data_ptr<int>(),                              \
+              sparse_kv_indices.data_ptr<int>(),                             \
+              partial_keys.data_ptr<float>(),                                \
+              partial_idx.data_ptr<int>(),                                   \
+              static_cast<int>(topk_val),                                    \
+              static_cast<int>(num_splits),                                  \
+              static_cast<int>(reserved_bos),                                \
+              static_cast<int>(reserved_eos),                                \
+              static_cast<int>(chunk_bytes),                                 \
+              mapping);                                                      \
+    } while (0)
+
+  #define VORTEX_PARALLEL_DISPATCH_MODE(DTYPE, PTR_EXPR)                     \
+    do {                                                                     \
+      switch (mapping.mode) {                                                \
+        case MAPPING_NONE: 
VORTEX_PARALLEL_DISPATCH(DTYPE, PTR_EXPR, MAPPING_NONE); break; \ + case MAPPING_POWER: VORTEX_PARALLEL_DISPATCH(DTYPE, PTR_EXPR, MAPPING_POWER); break; \ + case MAPPING_LOG: VORTEX_PARALLEL_DISPATCH(DTYPE, PTR_EXPR, MAPPING_LOG); break; \ + case MAPPING_ASINH: VORTEX_PARALLEL_DISPATCH(DTYPE, PTR_EXPR, MAPPING_ASINH); break; \ + case MAPPING_LOG1P: VORTEX_PARALLEL_DISPATCH(DTYPE, PTR_EXPR, MAPPING_LOG1P); break; \ + case MAPPING_TRUNC8: VORTEX_PARALLEL_DISPATCH(DTYPE, PTR_EXPR, MAPPING_TRUNC8); break; \ + case MAPPING_ERF: VORTEX_PARALLEL_DISPATCH(DTYPE, PTR_EXPR, MAPPING_ERF); break; \ + case MAPPING_TANH: VORTEX_PARALLEL_DISPATCH(DTYPE, PTR_EXPR, MAPPING_TANH); break; \ + case MAPPING_SUBTRACT: VORTEX_PARALLEL_DISPATCH(DTYPE, PTR_EXPR, MAPPING_SUBTRACT); break; \ + case MAPPING_EXP_STRETCH: VORTEX_PARALLEL_DISPATCH(DTYPE, PTR_EXPR, MAPPING_EXP_STRETCH); break; \ + case MAPPING_SHIFT_POW2: VORTEX_PARALLEL_DISPATCH(DTYPE, PTR_EXPR, MAPPING_SHIFT_POW2); break; \ + case MAPPING_SHIFT_POW3: VORTEX_PARALLEL_DISPATCH(DTYPE, PTR_EXPR, MAPPING_SHIFT_POW3); break; \ + case MAPPING_LINEAR_STEEP:VORTEX_PARALLEL_DISPATCH(DTYPE, PTR_EXPR, MAPPING_LINEAR_STEEP); break; \ + case MAPPING_HALF_SQUARE: VORTEX_PARALLEL_DISPATCH(DTYPE, PTR_EXPR, MAPPING_HALF_SQUARE); break; \ + case MAPPING_HALF_CUBE: VORTEX_PARALLEL_DISPATCH(DTYPE, PTR_EXPR, MAPPING_HALF_CUBE); break; \ + case MAPPING_DENSE_MANT: VORTEX_PARALLEL_DISPATCH(DTYPE, PTR_EXPR, MAPPING_DENSE_MANT); break; \ + default: \ + TORCH_CHECK(false, \ + "topk_output_sglang_parallel: unsupported mapping_mode ", \ + mapping.mode); \ + } \ + } while (0) + + if (x.scalar_type() == at::ScalarType::BFloat16) { + VORTEX_PARALLEL_DISPATCH_MODE( + __nv_bfloat16, + reinterpret_cast<__nv_bfloat16*>(x.data_ptr())); + } else if (x.scalar_type() == at::ScalarType::Float) { + VORTEX_PARALLEL_DISPATCH_MODE(float, x.data_ptr()); + } else { + TORCH_CHECK(false, "topk_output_sglang_parallel: unsupported dtype ", + x.scalar_type()); + } + + #undef 
VORTEX_PARALLEL_DISPATCH_MODE + #undef VORTEX_PARALLEL_DISPATCH + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, + "topk_output_sglang_parallel kernel failed: ", + ::cudaGetErrorString(result)); +} diff --git a/examples/ablation_remap_function_block_size.sh b/examples/ablation_remap_function_block_size.sh new file mode 100644 index 0000000..4bf5bba --- /dev/null +++ b/examples/ablation_remap_function_block_size.sh @@ -0,0 +1,279 @@ +#!/usr/bin/env bash +# ============================================================ +# Ablation: Remap function vs. block (page) size +# +# Sweeps BLOCK_SIZE and, for every cell, runs the full +# calibrate -> autotune -> remap-bench +# pipeline so the per-mode hyperparameter is freshly chosen by +# autotune for that block size (NOT hardcoded). +# +# Mapping modes under test (matches the screenshot): +# 0 none — unmapped baseline +# 3 power — p +# 6 asinh — beta +# 7 log1p — alpha +# 9 erf — alpha +# 10 tanh — alpha +# 11 subtract — pivot +# 13 exp_stretch — alpha +# 15 shift_pow2 — pivot +# 16 shift_pow3 — pivot +# 17 linear_steep — k +# +# Output: +# results/ablation_remap_block_size_/ +# bs/{autotune_results.json, remap_bench.json, step{1,2,3}_*.log} +# sweep_index.json +# selected_hparams.txt — per-cell screenshot-style summary +# ============================================================ +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +BENCH_DIR="${SCRIPT_DIR}/../benchmarks" + +# ── Defaults ────────────────────────────────────────────────── +GPU_ID=4 +MODEL_NAME="Qwen/Qwen3-1.7B" +TOPK_VAL=2048 +MEM=0.7 +MAX_TOTAL_TOKENS=64768 +MIN_FREE_DISK_GB=22 +ALGO="block_sparse_attention" +BATCH_SIZE=4 +NUM_KV_HEADS=8 +DISTRIBUTIONS="normal bucket_uniform" +MAPPING_MODES="0 3 6 7 9 10 11 13 15 16 17" +MAPPING_HPARAM=0.5 +REPEAT=100 +WARMUP=20 +BLOCK_SIZES="1 2 4 8 16 32 64" +REAL_HISTOGRAMS="" + +# ── Parse arguments ─────────────────────────────────────────── +while [[ $# -gt 
0 ]]; do + case "$1" in + --gpu) GPU_ID="$2"; shift 2 ;; + --model-name) MODEL_NAME="$2"; shift 2 ;; + --topk-val) TOPK_VAL="$2"; shift 2 ;; + --mem) MEM="$2"; shift 2 ;; + --max-total-tokens) MAX_TOTAL_TOKENS="$2"; shift 2 ;; + --min-free-disk-gb) MIN_FREE_DISK_GB="$2"; shift 2 ;; + --algo) ALGO="$2"; shift 2 ;; + --batch-size) BATCH_SIZE="$2"; shift 2 ;; + --num-kv-heads) NUM_KV_HEADS="$2"; shift 2 ;; + --distributions) DISTRIBUTIONS="$2"; shift 2 ;; + --modes) MAPPING_MODES="$2"; shift 2 ;; + --mapping-hparam) MAPPING_HPARAM="$2"; shift 2 ;; + --repeat) REPEAT="$2"; shift 2 ;; + --warmup) WARMUP="$2"; shift 2 ;; + --block-sizes) BLOCK_SIZES="$2"; shift 2 ;; + --real-histograms) REAL_HISTOGRAMS="$2"; shift 2 ;; + *) echo "Unknown option: $1"; exit 1 ;; + esac +done + +export CUDA_VISIBLE_DEVICES="${GPU_ID}" +export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}" +export SGL_ENABLE_JIT_DEEPGEMM="${SGL_ENABLE_JIT_DEEPGEMM:-true}" +if [ -z "${DG_JIT_NVCC_COMPILER:-}" ]; then + if [ -x /usr/local/cuda/bin/nvcc ]; then + export CUDA_HOME="${CUDA_HOME:-/usr/local/cuda}" + export PATH="${CUDA_HOME}/bin:${PATH}" + export DG_JIT_NVCC_COMPILER="${CUDA_HOME}/bin/nvcc" + elif command -v nvcc >/dev/null 2>&1; then + export DG_JIT_NVCC_COMPILER="$(command -v nvcc)" + fi +fi + +RESULTS_DIR="${SCRIPT_DIR}/results" +mkdir -p "${RESULTS_DIR}" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) +MODEL_SLUG="$(echo "${MODEL_NAME}" | tr '/' '_')" +SWEEP_DIR="${RESULTS_DIR}/ablation_remap_block_size_${MODEL_SLUG}_${TIMESTAMP}" +mkdir -p "${SWEEP_DIR}" + +# Per-model calibration cache (reused across block_size cells: page size +# does not change the per-segment score distribution). 
+CALIBRATION_BASE="/var/tmp/zhuominc/vortex_torch/calibration" +MODEL_TAG="$(echo "${MODEL_NAME##*/}" | sed 's/^Q/q/')" +DEFAULT_REAL_HIST="${CALIBRATION_BASE}/raw_histograms_${MODEL_TAG}.npy" +mkdir -p "${CALIBRATION_BASE}" + +if [ -z "${REAL_HISTOGRAMS}" ] && [ -f "${DEFAULT_REAL_HIST}" ]; then + REAL_HISTOGRAMS="${DEFAULT_REAL_HIST}" +fi + +echo "============================================================" +echo "Ablation: remap function vs block_size" +echo " Model: ${MODEL_NAME}" +echo " Algorithm: ${ALGO}" +echo " TopK: ${TOPK_VAL}" +echo " Block sizes: ${BLOCK_SIZES}" +echo " Batch size: ${BATCH_SIZE}" +echo " KV heads: ${NUM_KV_HEADS}" +echo " Distributions: ${DISTRIBUTIONS}" +echo " Mapping modes: ${MAPPING_MODES}" +echo " GPU: ${GPU_ID}" +echo " Sweep dir: ${SWEEP_DIR}" +echo "============================================================" + +# ── Step 0: Calibrate once for this model ────────────────────── +if [ -n "${REAL_HISTOGRAMS}" ]; then + echo ">>> Step 0: SKIPPED calibration (using ${REAL_HISTOGRAMS})" + REAL_HIST_PATH="${REAL_HISTOGRAMS}" +else + echo ">>> Step 0: Calibrating ${MODEL_NAME} for raw_histograms.npy" + STAGING_DIR="${CALIBRATION_BASE}/staging_${MODEL_TAG}_${TIMESTAMP}" + mkdir -p "${STAGING_DIR}" + python "${BENCH_DIR}/calibrate_topk.py" \ + --model-name "${MODEL_NAME}" \ + --topk-val "${TOPK_VAL}" \ + --page-size 1 \ + --mem "${MEM}" \ + --max-total-tokens "${MAX_TOTAL_TOKENS}" \ + --min-free-disk-gb "${MIN_FREE_DISK_GB}" \ + --vortex-module-name "${ALGO}" \ + --output-dir "${STAGING_DIR}" \ + 2>&1 | tee "${SWEEP_DIR}/step0_calibrate.log" + mv -f "${STAGING_DIR}/raw_histograms.npy" "${DEFAULT_REAL_HIST}" + REAL_HIST_PATH="${DEFAULT_REAL_HIST}" + echo ">>> Step 0: Done. 
raw_histograms -> ${REAL_HIST_PATH}" +fi + +# ── Sweep ────────────────────────────────────────────────────── +SWEEP_INDEX="${SWEEP_DIR}/sweep_index.json" +echo "{" > "${SWEEP_INDEX}" +echo " \"axis_name\": \"block_size\"," >> "${SWEEP_INDEX}" +echo " \"axis_type\": \"kernel\"," >> "${SWEEP_INDEX}" +echo " \"model_name\": \"${MODEL_NAME}\"," >> "${SWEEP_INDEX}" +echo " \"topk_val\": ${TOPK_VAL}," >> "${SWEEP_INDEX}" +echo " \"mapping_modes\": [${MAPPING_MODES// /, }]," >> "${SWEEP_INDEX}" +echo " \"cells\": [" >> "${SWEEP_INDEX}" + +FIRST_CELL=1 +for BLOCK_SIZE in ${BLOCK_SIZES}; do + # Pick a seq_len that satisfies pages/seg > topk_val + 3 reserved. + MIN_SEQ_LEN=$(( (TOPK_VAL + 4) * BLOCK_SIZE )) + SEQ_LEN=${MIN_SEQ_LEN} + # Round up to next power-of-two-ish multiple of 1024 for stable timing. + if [ "${SEQ_LEN}" -lt 8192 ]; then SEQ_LEN=8192; fi + + CELL_DIR="${SWEEP_DIR}/bs${BLOCK_SIZE}" + mkdir -p "${CELL_DIR}" + AUTOTUNE_JSON="${CELL_DIR}/autotune_results.json" + REMAP_JSON="${CELL_DIR}/remap_bench.json" + + echo "" + echo "============================================================" + echo ">>> Cell: block_size=${BLOCK_SIZE} seq_len=${SEQ_LEN}" + echo "============================================================" + + echo ">>> Autotuning hparams for block_size=${BLOCK_SIZE}" + PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/autotune_topk_mapping.py" \ + --batch-size "${BATCH_SIZE}" \ + --num-kv-heads "${NUM_KV_HEADS}" \ + --seq-len "${SEQ_LEN}" \ + --topk-val "${TOPK_VAL}" \ + --page-size "${BLOCK_SIZE}" \ + --real-histograms "${REAL_HIST_PATH}" \ + --warmup "${WARMUP}" \ + --repeat "${REPEAT}" \ + --collect-stats \ + --output-json "${AUTOTUNE_JSON}" \ + 2>&1 | tee "${CELL_DIR}/step2_autotune.log" + + echo ">>> Remap bench for block_size=${BLOCK_SIZE}" + PYTHONPATH="${SCRIPT_DIR}/.." 
python "${BENCH_DIR}/bench_topk.py" \
+    --remap-bench \
+    --per-head-bench \
+    --batch-sizes "${BATCH_SIZE}" \
+    --num-kv-heads "${NUM_KV_HEADS}" \
+    --seq-lens "${SEQ_LEN}" \
+    --topk-vals "${TOPK_VAL}" \
+    --page-size "${BLOCK_SIZE}" \
+    --distributions ${DISTRIBUTIONS} \
+    --mapping-modes ${MAPPING_MODES} \
+    --mapping-hparam "${MAPPING_HPARAM}" \
+    --autotune-json "${AUTOTUNE_JSON}" \
+    --real-histograms "${REAL_HIST_PATH}" \
+    --warmup "${WARMUP}" \
+    --repeat "${REPEAT}" \
+    --output-json "${REMAP_JSON}" \
+    2>&1 | tee "${CELL_DIR}/step3_remap_bench.log"
+
+  if [ "${FIRST_CELL}" -eq 1 ]; then
+    FIRST_CELL=0
+  else
+    echo " ," >> "${SWEEP_INDEX}"
+  fi
+  # NOTE(review): heredoc body reconstructed — it was swallowed in transit.
+  # "axis_value" and "autotune_json" are the keys consumed by the summary
+  # below; confirm the remaining keys against analyze_ablation_remap.py.
+  cat >> "${SWEEP_INDEX}" <<EOF
+    {"axis_value": ${BLOCK_SIZE}, "seq_len": ${SEQ_LEN},
+     "cell_dir": "${CELL_DIR}",
+     "autotune_json": "${AUTOTUNE_JSON}",
+     "remap_bench_json": "${REMAP_JSON}"}
+EOF
+done
+
+echo " ]" >> "${SWEEP_INDEX}"
+echo "}" >> "${SWEEP_INDEX}"
+
+# ── Per-cell screenshot-style hparam summary ──────────────────
+SELECTED_TXT="${SWEEP_DIR}/selected_hparams.txt"
+PYTHONPATH="${SCRIPT_DIR}/.." python3 - "${SWEEP_INDEX}" "${SELECTED_TXT}" <<'PY'
+import json, sys, os
+try:
+    # __file__ is undefined when the script arrives on stdin ("python3 -"),
+    # so the path insert must live inside the try: a NameError then falls
+    # through to the hardcoded fallback tables instead of crashing.
+    sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "benchmarks"))
+    from autotune_topk_mapping import MODE_NAMES, PARAM_NAME
+except Exception:
+    MODE_NAMES = {0: "none", 3: "power", 6: "asinh", 7: "log1p", 9: "erf",
+                  10: "tanh", 11: "subtract", 13: "exp_stretch",
+                  15: "shift_pow2", 16: "shift_pow3", 17: "linear_steep"}
+    PARAM_NAME = {3: "p", 6: "beta", 7: "alpha", 9: "alpha", 10: "alpha",
+                  11: "pivot", 13: "alpha", 15: "pivot", 16: "pivot", 17: "k"}
+
+DISPLAY = {3: "Power", 6: "Asinh", 7: "Log1p", 9: "Erf", 10: "Tanh",
+           11: "Subtract", 13: "ExpStretch", 15: "ShiftPow2",
+           16: "ShiftPow3", 17: "LinearSteep"}
+
+idx_path, out_path = sys.argv[1], sys.argv[2]
+with open(idx_path) as f:
+    idx = json.load(f)
+
+lines = ["== Selected mapping functions (autotuned, block_size sweep) =="]
+for cell in idx["cells"]:
+    with open(cell["autotune_json"]) as f:
+        results = json.load(f)
+    best = {}
+    for r in results:
+        m = r["mode"]
+        if m not in best or r["latency_ms"] < best[m]["latency_ms"]:
+            best[m] = r 
+ parts = [] + for m in sorted(DISPLAY): + if m not in best: + continue + pname = PARAM_NAME.get(m, "p") + pval = best[m].get("param", 0.0) + parts.append(f"{DISPLAY[m]}({pname}={pval})") + lines.append(f"[block_size={cell['axis_value']}] " + " ".join(parts)) + +txt = "\n".join(lines) + "\n" +print(txt) +with open(out_path, "w") as f: + f.write(txt) +PY + +echo "" +echo "============================================================" +echo "Block-size ablation complete." +echo " Sweep dir: ${SWEEP_DIR}" +echo " Per-cell results: ${SWEEP_DIR}/bs/" +echo " Sweep index: ${SWEEP_INDEX}" +echo " Selected hparams: ${SELECTED_TXT}" +echo "Run analyze with:" +echo " python examples/analyze_ablation_remap.py --sweep-dir ${SWEEP_DIR}" +echo "============================================================" diff --git a/examples/ablation_remap_function_model.sh b/examples/ablation_remap_function_model.sh new file mode 100644 index 0000000..0212b83 --- /dev/null +++ b/examples/ablation_remap_function_model.sh @@ -0,0 +1,262 @@ +#!/usr/bin/env bash +# ============================================================ +# Ablation: Remap function vs. model +# +# Sweeps MODEL_NAME across the Qwen3 family. For every model: +# 1. Calibrate (or reuse cached raw_histograms_.npy) +# 2. Autotune the per-mode hparam on that model's histogram +# (NOT hardcoded; freshly tuned per model) +# 3. 
Remap-bench across the autotuned hparams +# +# Mapping modes under test (matches the screenshot): +# 0 none, 3 power, 6 asinh, 7 log1p, 9 erf, 10 tanh, +# 11 subtract, 13 exp_stretch, 15 shift_pow2, 16 shift_pow3, +# 17 linear_steep +# ============================================================ +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +BENCH_DIR="${SCRIPT_DIR}/../benchmarks" + +# ── Defaults ────────────────────────────────────────────────── +GPU_ID=4 +MODELS="Qwen/Qwen3-0.6B Qwen/Qwen3-1.7B Qwen/Qwen3-4B Qwen/Qwen3-8B" +TOPK_VAL=2048 +BLOCK_SIZE=1 +MEM=0.7 +MIN_FREE_DISK_GB=22 +ALGO="block_sparse_attention" +BATCH_SIZE=4 +NUM_KV_HEADS=8 +DISTRIBUTIONS="normal bucket_uniform" +MAPPING_MODES="0 3 6 7 9 10 11 13 15 16 17" +MAPPING_HPARAM=0.5 +REPEAT=100 +WARMUP=20 + +# ── Parse arguments ─────────────────────────────────────────── +while [[ $# -gt 0 ]]; do + case "$1" in + --gpu) GPU_ID="$2"; shift 2 ;; + --models) MODELS="$2"; shift 2 ;; + --topk-val) TOPK_VAL="$2"; shift 2 ;; + --block-size|--page-size) BLOCK_SIZE="$2"; shift 2 ;; + --mem) MEM="$2"; shift 2 ;; + --min-free-disk-gb) MIN_FREE_DISK_GB="$2"; shift 2 ;; + --algo) ALGO="$2"; shift 2 ;; + --batch-size) BATCH_SIZE="$2"; shift 2 ;; + --num-kv-heads) NUM_KV_HEADS="$2"; shift 2 ;; + --distributions) DISTRIBUTIONS="$2"; shift 2 ;; + --modes) MAPPING_MODES="$2"; shift 2 ;; + --mapping-hparam) MAPPING_HPARAM="$2"; shift 2 ;; + --repeat) REPEAT="$2"; shift 2 ;; + --warmup) WARMUP="$2"; shift 2 ;; + *) echo "Unknown option: $1"; exit 1 ;; + esac +done + +export CUDA_VISIBLE_DEVICES="${GPU_ID}" +export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}" +export SGL_ENABLE_JIT_DEEPGEMM="${SGL_ENABLE_JIT_DEEPGEMM:-true}" +if [ -z "${DG_JIT_NVCC_COMPILER:-}" ]; then + if [ -x /usr/local/cuda/bin/nvcc ]; then + export CUDA_HOME="${CUDA_HOME:-/usr/local/cuda}" + export PATH="${CUDA_HOME}/bin:${PATH}" + export DG_JIT_NVCC_COMPILER="${CUDA_HOME}/bin/nvcc" 
+ elif command -v nvcc >/dev/null 2>&1; then + export DG_JIT_NVCC_COMPILER="$(command -v nvcc)" + fi +fi + +# Per-model max-total-tokens (KV pool cap for calibration). Larger models +# need a smaller cap so they fit at MEM=0.7. Override by passing the env +# var MAX_TOTAL_TOKENS_=N before invocation. +declare -A MAX_TOTAL_TOKENS_LUT +MAX_TOTAL_TOKENS_LUT["qwen3-0.6B"]=131072 +MAX_TOTAL_TOKENS_LUT["qwen3-1.7B"]=64768 +MAX_TOTAL_TOKENS_LUT["qwen3-4B"]=32768 +MAX_TOTAL_TOKENS_LUT["qwen3-8B"]=16384 + +RESULTS_DIR="${SCRIPT_DIR}/results" +mkdir -p "${RESULTS_DIR}" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) +SWEEP_DIR="${RESULTS_DIR}/ablation_remap_model_${TIMESTAMP}" +mkdir -p "${SWEEP_DIR}" + +CALIBRATION_BASE="/var/tmp/zhuominc/vortex_torch/calibration" +mkdir -p "${CALIBRATION_BASE}" + +echo "============================================================" +echo "Ablation: remap function vs model" +echo " Models: ${MODELS}" +echo " TopK: ${TOPK_VAL}" +echo " Block size: ${BLOCK_SIZE}" +echo " Mapping modes: ${MAPPING_MODES}" +echo " GPU: ${GPU_ID}" +echo " Sweep dir: ${SWEEP_DIR}" +echo "============================================================" + +# ── Sweep ────────────────────────────────────────────────────── +SWEEP_INDEX="${SWEEP_DIR}/sweep_index.json" +{ + echo "{" + echo " \"axis_name\": \"model\"," + echo " \"axis_type\": \"kernel\"," + echo " \"topk_val\": ${TOPK_VAL}," + echo " \"block_size\": ${BLOCK_SIZE}," + echo " \"mapping_modes\": [${MAPPING_MODES// /, }]," + echo " \"cells\": [" +} > "${SWEEP_INDEX}" + +# Pick a single seq_len that satisfies pages/seg > topk_val for all models. 
+MIN_SEQ_LEN=$(( (TOPK_VAL + 4) * BLOCK_SIZE )) +SEQ_LEN=${MIN_SEQ_LEN} +if [ "${SEQ_LEN}" -lt 8192 ]; then SEQ_LEN=8192; fi + +FIRST_CELL=1 +for MODEL_NAME in ${MODELS}; do + MODEL_TAG="$(echo "${MODEL_NAME##*/}" | sed 's/^Q/q/')" + MODEL_SLUG="$(echo "${MODEL_NAME}" | tr '/' '_')" + DEFAULT_REAL_HIST="${CALIBRATION_BASE}/raw_histograms_${MODEL_TAG}.npy" + + # Per-model max-total-tokens (override-able via env). + MTT_DEFAULT="${MAX_TOTAL_TOKENS_LUT[${MODEL_TAG}]:-32768}" + ENV_KEY="MAX_TOTAL_TOKENS_$(echo "${MODEL_TAG}" | tr '.-' '__')" + MAX_TOTAL_TOKENS="${!ENV_KEY:-${MTT_DEFAULT}}" + + CELL_DIR="${SWEEP_DIR}/${MODEL_SLUG}" + mkdir -p "${CELL_DIR}" + AUTOTUNE_JSON="${CELL_DIR}/autotune_results.json" + REMAP_JSON="${CELL_DIR}/remap_bench.json" + + echo "" + echo "============================================================" + echo ">>> Cell: model=${MODEL_NAME} (max_total_tokens=${MAX_TOTAL_TOKENS})" + echo "============================================================" + + # Step 1: calibrate (cached per-model) + if [ -f "${DEFAULT_REAL_HIST}" ]; then + echo ">>> Calibration cache hit: ${DEFAULT_REAL_HIST}" + REAL_HIST_PATH="${DEFAULT_REAL_HIST}" + else + echo ">>> Calibrating ${MODEL_NAME}" + STAGING_DIR="${CALIBRATION_BASE}/staging_${MODEL_TAG}_${TIMESTAMP}" + mkdir -p "${STAGING_DIR}" + python "${BENCH_DIR}/calibrate_topk.py" \ + --model-name "${MODEL_NAME}" \ + --topk-val "${TOPK_VAL}" \ + --page-size "${BLOCK_SIZE}" \ + --mem "${MEM}" \ + --max-total-tokens "${MAX_TOTAL_TOKENS}" \ + --min-free-disk-gb "${MIN_FREE_DISK_GB}" \ + --vortex-module-name "${ALGO}" \ + --output-dir "${STAGING_DIR}" \ + 2>&1 | tee "${CELL_DIR}/step1_calibrate.log" + mv -f "${STAGING_DIR}/raw_histograms.npy" "${DEFAULT_REAL_HIST}" + REAL_HIST_PATH="${DEFAULT_REAL_HIST}" + fi + + # Step 2: autotune + echo ">>> Autotuning hparams for ${MODEL_NAME}" + PYTHONPATH="${SCRIPT_DIR}/.." 
python "${BENCH_DIR}/autotune_topk_mapping.py" \ + --batch-size "${BATCH_SIZE}" \ + --num-kv-heads "${NUM_KV_HEADS}" \ + --seq-len "${SEQ_LEN}" \ + --topk-val "${TOPK_VAL}" \ + --page-size "${BLOCK_SIZE}" \ + --real-histograms "${REAL_HIST_PATH}" \ + --warmup "${WARMUP}" \ + --repeat "${REPEAT}" \ + --collect-stats \ + --output-json "${AUTOTUNE_JSON}" \ + 2>&1 | tee "${CELL_DIR}/step2_autotune.log" + + # Step 3: remap bench + echo ">>> Remap bench for ${MODEL_NAME}" + PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/bench_topk.py" \ + --remap-bench \ + --per-head-bench \ + --batch-sizes "${BATCH_SIZE}" \ + --num-kv-heads "${NUM_KV_HEADS}" \ + --seq-lens "${SEQ_LEN}" \ + --topk-vals "${TOPK_VAL}" \ + --page-size "${BLOCK_SIZE}" \ + --distributions ${DISTRIBUTIONS} \ + --mapping-modes ${MAPPING_MODES} \ + --mapping-hparam "${MAPPING_HPARAM}" \ + --autotune-json "${AUTOTUNE_JSON}" \ + --real-histograms "${REAL_HIST_PATH}" \ + --warmup "${WARMUP}" \ + --repeat "${REPEAT}" \ + --output-json "${REMAP_JSON}" \ + 2>&1 | tee "${CELL_DIR}/step3_remap_bench.log" + + if [ "${FIRST_CELL}" -eq 1 ]; then + FIRST_CELL=0 + else + echo " ," >> "${SWEEP_INDEX}" + fi + cat >> "${SWEEP_INDEX}" <> "${SWEEP_INDEX}" +echo "}" >> "${SWEEP_INDEX}" + +# ── Per-cell screenshot-style hparam summary ────────────────── +SELECTED_TXT="${SWEEP_DIR}/selected_hparams.txt" +PYTHONPATH="${SCRIPT_DIR}/.." 
python3 - "${SWEEP_INDEX}" "${SELECTED_TXT}" "model" <<'PY'
+import json, sys, os
+try:
+    # __file__ is undefined when the script arrives on stdin ("python3 -"),
+    # so the path insert must live inside the try: a NameError then falls
+    # through to the hardcoded fallback table instead of crashing.
+    sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "benchmarks"))
+    from autotune_topk_mapping import PARAM_NAME
+except Exception:
+    PARAM_NAME = {3: "p", 6: "beta", 7: "alpha", 9: "alpha", 10: "alpha",
+                  11: "pivot", 13: "alpha", 15: "pivot", 16: "pivot", 17: "k"}
+DISPLAY = {3: "Power", 6: "Asinh", 7: "Log1p", 9: "Erf", 10: "Tanh",
+           11: "Subtract", 13: "ExpStretch", 15: "ShiftPow2",
+           16: "ShiftPow3", 17: "LinearSteep"}
+
+idx_path, out_path, axis_name = sys.argv[1], sys.argv[2], sys.argv[3]
+with open(idx_path) as f:
+    idx = json.load(f)
+
+lines = [f"== Selected mapping functions (autotuned, {axis_name} sweep) =="]
+for cell in idx["cells"]:
+    with open(cell["autotune_json"]) as f:
+        results = json.load(f)
+    best = {}
+    for r in results:
+        m = r["mode"]
+        if m not in best or r["latency_ms"] < best[m]["latency_ms"]:
+            best[m] = r
+    parts = []
+    for m in sorted(DISPLAY):
+        if m in best:
+            parts.append(f"{DISPLAY[m]}({PARAM_NAME.get(m,'p')}={best[m].get('param',0.0)})")
+    lines.append(f"[{axis_name}={cell['axis_value']}] " + " ".join(parts))
+
+txt = "\n".join(lines) + "\n"
+print(txt)
+with open(out_path, "w") as f:
+    f.write(txt)
+PY
+
+echo ""
+echo "============================================================"
+echo "Model ablation complete." 
+echo " Sweep dir: ${SWEEP_DIR}" +echo " Sweep index: ${SWEEP_INDEX}" +echo " Selected hparams: ${SELECTED_TXT}" +echo "Run analyze with:" +echo " python examples/analyze_ablation_remap.py --sweep-dir ${SWEEP_DIR}" +echo "============================================================" diff --git a/examples/ablation_remap_function_topk_benchmark.sh b/examples/ablation_remap_function_topk_benchmark.sh new file mode 100644 index 0000000..7952bd2 --- /dev/null +++ b/examples/ablation_remap_function_topk_benchmark.sh @@ -0,0 +1,277 @@ +#!/usr/bin/env bash +# ============================================================ +# Ablation: Remap function vs. topk-kernel benchmark workload +# +# Sweeps the kernel-bench INPUT distribution (the workload that +# stresses the TopK kernel) and, per cell, runs autotune + +# remap-bench so the per-mode hparam is freshly chosen for that +# distribution. This is the robustness ablation: do the +# autotuned remap functions still beat the unmapped baseline +# when the input score distribution shifts? 
+# +# Distributions available in bench_topk.py: +# normal — N(0,1) per-page scores +# lognormal — heavy-tailed positive scores +# uniform — U[0,1) +# bucket_uniform— per-bucket uniform (worst case for radix) +# real — sampled from raw_histograms_.npy +# +# Mapping modes under test (matches the screenshot): +# 0 none, 3 power, 6 asinh, 7 log1p, 9 erf, 10 tanh, +# 11 subtract, 13 exp_stretch, 15 shift_pow2, 16 shift_pow3, +# 17 linear_steep +# ============================================================ +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +BENCH_DIR="${SCRIPT_DIR}/../benchmarks" + +# ── Defaults ────────────────────────────────────────────────── +GPU_ID=4 +MODEL_NAME="Qwen/Qwen3-1.7B" +TOPK_VAL=2048 +BLOCK_SIZE=1 +SEQ_LEN=32768 +MEM=0.7 +MAX_TOTAL_TOKENS=64768 +MIN_FREE_DISK_GB=22 +ALGO="block_sparse_attention" +BATCH_SIZE=4 +NUM_KV_HEADS=8 +MAPPING_MODES="0 3 6 7 9 10 11 13 15 16 17" +MAPPING_HPARAM=0.5 +REPEAT=100 +WARMUP=20 +# Distributions to sweep (one per cell). "real" requires raw_histograms.npy. 
+DISTRIBUTION_LIST="normal lognormal uniform bucket_uniform real" +REAL_HISTOGRAMS="" + +# ── Parse arguments ─────────────────────────────────────────── +while [[ $# -gt 0 ]]; do + case "$1" in + --gpu) GPU_ID="$2"; shift 2 ;; + --model-name) MODEL_NAME="$2"; shift 2 ;; + --topk-val) TOPK_VAL="$2"; shift 2 ;; + --block-size|--page-size) BLOCK_SIZE="$2"; shift 2 ;; + --seq-len) SEQ_LEN="$2"; shift 2 ;; + --mem) MEM="$2"; shift 2 ;; + --max-total-tokens) MAX_TOTAL_TOKENS="$2"; shift 2 ;; + --min-free-disk-gb) MIN_FREE_DISK_GB="$2"; shift 2 ;; + --algo) ALGO="$2"; shift 2 ;; + --batch-size) BATCH_SIZE="$2"; shift 2 ;; + --num-kv-heads) NUM_KV_HEADS="$2"; shift 2 ;; + --modes) MAPPING_MODES="$2"; shift 2 ;; + --mapping-hparam) MAPPING_HPARAM="$2"; shift 2 ;; + --repeat) REPEAT="$2"; shift 2 ;; + --warmup) WARMUP="$2"; shift 2 ;; + --dist-list|--distributions) DISTRIBUTION_LIST="$2"; shift 2 ;; + --real-histograms) REAL_HISTOGRAMS="$2"; shift 2 ;; + *) echo "Unknown option: $1"; exit 1 ;; + esac +done + +export CUDA_VISIBLE_DEVICES="${GPU_ID}" +export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}" +export SGL_ENABLE_JIT_DEEPGEMM="${SGL_ENABLE_JIT_DEEPGEMM:-true}" +if [ -z "${DG_JIT_NVCC_COMPILER:-}" ]; then + if [ -x /usr/local/cuda/bin/nvcc ]; then + export CUDA_HOME="${CUDA_HOME:-/usr/local/cuda}" + export PATH="${CUDA_HOME}/bin:${PATH}" + export DG_JIT_NVCC_COMPILER="${CUDA_HOME}/bin/nvcc" + elif command -v nvcc >/dev/null 2>&1; then + export DG_JIT_NVCC_COMPILER="$(command -v nvcc)" + fi +fi + +MIN_SEQ_LEN=$(( (TOPK_VAL + 4) * BLOCK_SIZE )) +if [ "${SEQ_LEN}" -lt "${MIN_SEQ_LEN}" ]; then + SEQ_LEN=${MIN_SEQ_LEN} +fi + +RESULTS_DIR="${SCRIPT_DIR}/results" +mkdir -p "${RESULTS_DIR}" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) +MODEL_SLUG="$(echo "${MODEL_NAME}" | tr '/' '_')" +SWEEP_DIR="${RESULTS_DIR}/ablation_remap_topk_benchmark_${MODEL_SLUG}_${TIMESTAMP}" +mkdir -p "${SWEEP_DIR}" + 
+CALIBRATION_BASE="/var/tmp/zhuominc/vortex_torch/calibration" +MODEL_TAG="$(echo "${MODEL_NAME##*/}" | sed 's/^Q/q/')" +DEFAULT_REAL_HIST="${CALIBRATION_BASE}/raw_histograms_${MODEL_TAG}.npy" +mkdir -p "${CALIBRATION_BASE}" + +if [ -z "${REAL_HISTOGRAMS}" ] && [ -f "${DEFAULT_REAL_HIST}" ]; then + REAL_HISTOGRAMS="${DEFAULT_REAL_HIST}" +fi + +# Need raw_histograms only if "real" is in the distribution list. +NEED_REAL=0 +for d in ${DISTRIBUTION_LIST}; do + if [ "$d" = "real" ]; then NEED_REAL=1; fi +done + +if [ "${NEED_REAL}" -eq 1 ] && [ -z "${REAL_HISTOGRAMS}" ]; then + echo ">>> Step 0: Calibrating ${MODEL_NAME} (needed for distribution=real)" + STAGING_DIR="${CALIBRATION_BASE}/staging_${MODEL_TAG}_${TIMESTAMP}" + mkdir -p "${STAGING_DIR}" + python "${BENCH_DIR}/calibrate_topk.py" \ + --model-name "${MODEL_NAME}" \ + --topk-val "${TOPK_VAL}" \ + --page-size "${BLOCK_SIZE}" \ + --mem "${MEM}" \ + --max-total-tokens "${MAX_TOTAL_TOKENS}" \ + --min-free-disk-gb "${MIN_FREE_DISK_GB}" \ + --vortex-module-name "${ALGO}" \ + --output-dir "${STAGING_DIR}" \ + 2>&1 | tee "${SWEEP_DIR}/step0_calibrate.log" + mv -f "${STAGING_DIR}/raw_histograms.npy" "${DEFAULT_REAL_HIST}" + REAL_HISTOGRAMS="${DEFAULT_REAL_HIST}" +fi + +echo "============================================================" +echo "Ablation: remap function vs topk-kernel benchmark workload" +echo " Model: ${MODEL_NAME}" +echo " Block size: ${BLOCK_SIZE}" +echo " TopK: ${TOPK_VAL}" +echo " Seq len: ${SEQ_LEN}" +echo " Distributions: ${DISTRIBUTION_LIST}" +echo " Mapping modes: ${MAPPING_MODES}" +echo " GPU: ${GPU_ID}" +echo " Sweep dir: ${SWEEP_DIR}" +echo "============================================================" + +# ── Sweep ────────────────────────────────────────────────────── +SWEEP_INDEX="${SWEEP_DIR}/sweep_index.json" +{ + echo "{" + echo " \"axis_name\": \"distribution\"," + echo " \"axis_type\": \"kernel\"," + echo " \"model_name\": \"${MODEL_NAME}\"," + echo " \"topk_val\": ${TOPK_VAL}," + echo 
" \"block_size\": ${BLOCK_SIZE}," + echo " \"seq_len\": ${SEQ_LEN}," + echo " \"mapping_modes\": [${MAPPING_MODES// /, }]," + echo " \"cells\": [" +} > "${SWEEP_INDEX}" + +FIRST_CELL=1 +for DIST in ${DISTRIBUTION_LIST}; do + CELL_DIR="${SWEEP_DIR}/dist_${DIST}" + mkdir -p "${CELL_DIR}" + AUTOTUNE_JSON="${CELL_DIR}/autotune_results.json" + REMAP_JSON="${CELL_DIR}/remap_bench.json" + + echo "" + echo "============================================================" + echo ">>> Cell: distribution=${DIST}" + echo "============================================================" + + AUTOTUNE_DIST_ARGS=() + BENCH_DIST_ARGS=() + if [ "${DIST}" = "real" ]; then + AUTOTUNE_DIST_ARGS=(--real-histograms "${REAL_HISTOGRAMS}") + BENCH_DIST_ARGS=(--real-histograms "${REAL_HISTOGRAMS}" --distributions real) + else + AUTOTUNE_DIST_ARGS=(--distributions "${DIST}") + BENCH_DIST_ARGS=(--distributions "${DIST}") + fi + + echo ">>> Autotuning hparams on dist=${DIST}" + PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/autotune_topk_mapping.py" \ + --batch-size "${BATCH_SIZE}" \ + --num-kv-heads "${NUM_KV_HEADS}" \ + --seq-len "${SEQ_LEN}" \ + --topk-val "${TOPK_VAL}" \ + --page-size "${BLOCK_SIZE}" \ + "${AUTOTUNE_DIST_ARGS[@]}" \ + --warmup "${WARMUP}" \ + --repeat "${REPEAT}" \ + --collect-stats \ + --output-json "${AUTOTUNE_JSON}" \ + 2>&1 | tee "${CELL_DIR}/step2_autotune.log" + + echo ">>> Remap bench on dist=${DIST}" + PYTHONPATH="${SCRIPT_DIR}/.." 
python "${BENCH_DIR}/bench_topk.py" \
+        --remap-bench \
+        --per-head-bench \
+        --batch-sizes "${BATCH_SIZE}" \
+        --num-kv-heads "${NUM_KV_HEADS}" \
+        --seq-lens "${SEQ_LEN}" \
+        --topk-vals "${TOPK_VAL}" \
+        --page-size "${BLOCK_SIZE}" \
+        "${BENCH_DIST_ARGS[@]}" \
+        --mapping-modes ${MAPPING_MODES} \
+        --mapping-hparam "${MAPPING_HPARAM}" \
+        --autotune-json "${AUTOTUNE_JSON}" \
+        --warmup "${WARMUP}" \
+        --repeat "${REPEAT}" \
+        --output-json "${REMAP_JSON}" \
+        2>&1 | tee "${CELL_DIR}/step3_remap_bench.log"
+
+    if [ "${FIRST_CELL}" -eq 1 ]; then
+        FIRST_CELL=0
+    else
+        echo "    ," >> "${SWEEP_INDEX}"
+    fi
+    # NOTE(review): the heredoc body below was lost in extraction; it is
+    # reconstructed from the fields analyze_ablation_remap.py reads
+    # (axis_value, cell_dir, autotune_json, remap_bench_json) -- confirm
+    # against the original script.
+    cat >> "${SWEEP_INDEX}" <<EOF
+    {
+      "axis_value": "${DIST}",
+      "cell_dir": "${CELL_DIR}",
+      "autotune_json": "${AUTOTUNE_JSON}",
+      "remap_bench_json": "${REMAP_JSON}"
+    }
+EOF
+done
+
+echo "  ]" >> "${SWEEP_INDEX}"
+echo "}" >> "${SWEEP_INDEX}"
+
+# ── Per-cell screenshot-style hparam summary ──────────────────
+SELECTED_TXT="${SWEEP_DIR}/selected_hparams.txt"
+PYTHONPATH="${SCRIPT_DIR}/.." python3 - "${SWEEP_INDEX}" "${SELECTED_TXT}" "distribution" <<'PY'
+import json, sys, os
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "benchmarks"))
+try:
+    from autotune_topk_mapping import PARAM_NAME
+except Exception:
+    PARAM_NAME = {3: "p", 6: "beta", 7: "alpha", 9: "alpha", 10: "alpha",
+                  11: "pivot", 13: "alpha", 15: "pivot", 16: "pivot", 17: "k"}
+DISPLAY = {3: "Power", 6: "Asinh", 7: "Log1p", 9: "Erf", 10: "Tanh",
+           11: "Subtract", 13: "ExpStretch", 15: "ShiftPow2",
+           16: "ShiftPow3", 17: "LinearSteep"}
+
+idx_path, out_path, axis_name = sys.argv[1], sys.argv[2], sys.argv[3]
+with open(idx_path) as f:
+    idx = json.load(f)
+
+lines = [f"== Selected mapping functions (autotuned, {axis_name} sweep) =="]
+for cell in idx["cells"]:
+    with open(cell["autotune_json"]) as f:
+        results = json.load(f)
+    best = {}
+    for r in results:
+        m = r["mode"]
+        if m not in best or r["latency_ms"] < best[m]["latency_ms"]:
+            best[m] = r
+    parts = []
+    for m in sorted(DISPLAY):
+        if m in best:
+            parts.append(f"{DISPLAY[m]}({PARAM_NAME.get(m,'p')}={best[m].get('param',0.0)})")
+    lines.append(f"[{axis_name}={cell['axis_value']}] " + " 
".join(parts)) + +txt = "\n".join(lines) + "\n" +print(txt) +with open(out_path, "w") as f: + f.write(txt) +PY + +echo "" +echo "============================================================" +echo "topk_benchmark (kernel workload) ablation complete." +echo " Sweep dir: ${SWEEP_DIR}" +echo " Sweep index: ${SWEEP_INDEX}" +echo " Selected hparams: ${SELECTED_TXT}" +echo "Run analyze with:" +echo " python examples/analyze_ablation_remap.py --sweep-dir ${SWEEP_DIR}" +echo "============================================================" diff --git a/examples/ablation_remap_function_topk_val.sh b/examples/ablation_remap_function_topk_val.sh new file mode 100644 index 0000000..4e60440 --- /dev/null +++ b/examples/ablation_remap_function_topk_val.sh @@ -0,0 +1,255 @@ +#!/usr/bin/env bash +# ============================================================ +# Ablation: Remap function vs. topk_val +# +# Sweeps TOPK_VAL and, for every cell, runs +# autotune -> remap-bench +# so the per-mode hyperparameter is freshly chosen by autotune +# for that topk_val (NOT hardcoded). Calibration runs once for +# the model. 
+# +# Mapping modes under test (matches the screenshot): +# 0 none, 3 power, 6 asinh, 7 log1p, 9 erf, 10 tanh, +# 11 subtract, 13 exp_stretch, 15 shift_pow2, 16 shift_pow3, +# 17 linear_steep +# ============================================================ +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +BENCH_DIR="${SCRIPT_DIR}/../benchmarks" + +# ── Defaults ────────────────────────────────────────────────── +GPU_ID=4 +MODEL_NAME="Qwen/Qwen3-1.7B" +BLOCK_SIZE=1 +MEM=0.7 +MAX_TOTAL_TOKENS=64768 +MIN_FREE_DISK_GB=22 +ALGO="block_sparse_attention" +BATCH_SIZE=4 +NUM_KV_HEADS=8 +DISTRIBUTIONS="normal bucket_uniform" +MAPPING_MODES="0 3 6 7 9 10 11 13 15 16 17" +MAPPING_HPARAM=0.5 +REPEAT=100 +WARMUP=20 +TOPK_VALS="512 1024 2048 4096" +REAL_HISTOGRAMS="" + +# ── Parse arguments ─────────────────────────────────────────── +while [[ $# -gt 0 ]]; do + case "$1" in + --gpu) GPU_ID="$2"; shift 2 ;; + --model-name) MODEL_NAME="$2"; shift 2 ;; + --block-size|--page-size) BLOCK_SIZE="$2"; shift 2 ;; + --mem) MEM="$2"; shift 2 ;; + --max-total-tokens) MAX_TOTAL_TOKENS="$2"; shift 2 ;; + --min-free-disk-gb) MIN_FREE_DISK_GB="$2"; shift 2 ;; + --algo) ALGO="$2"; shift 2 ;; + --batch-size) BATCH_SIZE="$2"; shift 2 ;; + --num-kv-heads) NUM_KV_HEADS="$2"; shift 2 ;; + --distributions) DISTRIBUTIONS="$2"; shift 2 ;; + --modes) MAPPING_MODES="$2"; shift 2 ;; + --mapping-hparam) MAPPING_HPARAM="$2"; shift 2 ;; + --repeat) REPEAT="$2"; shift 2 ;; + --warmup) WARMUP="$2"; shift 2 ;; + --topk-vals) TOPK_VALS="$2"; shift 2 ;; + --real-histograms) REAL_HISTOGRAMS="$2"; shift 2 ;; + *) echo "Unknown option: $1"; exit 1 ;; + esac +done + +export CUDA_VISIBLE_DEVICES="${GPU_ID}" +export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}" +export SGL_ENABLE_JIT_DEEPGEMM="${SGL_ENABLE_JIT_DEEPGEMM:-true}" +if [ -z "${DG_JIT_NVCC_COMPILER:-}" ]; then + if [ -x /usr/local/cuda/bin/nvcc ]; then + export 
CUDA_HOME="${CUDA_HOME:-/usr/local/cuda}" + export PATH="${CUDA_HOME}/bin:${PATH}" + export DG_JIT_NVCC_COMPILER="${CUDA_HOME}/bin/nvcc" + elif command -v nvcc >/dev/null 2>&1; then + export DG_JIT_NVCC_COMPILER="$(command -v nvcc)" + fi +fi + +RESULTS_DIR="${SCRIPT_DIR}/results" +mkdir -p "${RESULTS_DIR}" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) +MODEL_SLUG="$(echo "${MODEL_NAME}" | tr '/' '_')" +SWEEP_DIR="${RESULTS_DIR}/ablation_remap_topk_val_${MODEL_SLUG}_${TIMESTAMP}" +mkdir -p "${SWEEP_DIR}" + +CALIBRATION_BASE="/var/tmp/zhuominc/vortex_torch/calibration" +MODEL_TAG="$(echo "${MODEL_NAME##*/}" | sed 's/^Q/q/')" +DEFAULT_REAL_HIST="${CALIBRATION_BASE}/raw_histograms_${MODEL_TAG}.npy" +mkdir -p "${CALIBRATION_BASE}" + +if [ -z "${REAL_HISTOGRAMS}" ] && [ -f "${DEFAULT_REAL_HIST}" ]; then + REAL_HISTOGRAMS="${DEFAULT_REAL_HIST}" +fi + +echo "============================================================" +echo "Ablation: remap function vs topk_val" +echo " Model: ${MODEL_NAME}" +echo " Block size: ${BLOCK_SIZE}" +echo " Topk vals: ${TOPK_VALS}" +echo " Mapping modes: ${MAPPING_MODES}" +echo " GPU: ${GPU_ID}" +echo " Sweep dir: ${SWEEP_DIR}" +echo "============================================================" + +# ── Step 0: Calibrate once for this model ────────────────────── +if [ -n "${REAL_HISTOGRAMS}" ]; then + echo ">>> Step 0: SKIPPED calibration (using ${REAL_HISTOGRAMS})" + REAL_HIST_PATH="${REAL_HISTOGRAMS}" +else + echo ">>> Step 0: Calibrating ${MODEL_NAME}" + STAGING_DIR="${CALIBRATION_BASE}/staging_${MODEL_TAG}_${TIMESTAMP}" + mkdir -p "${STAGING_DIR}" + CAL_TOPK_VAL=$(echo "${TOPK_VALS}" | tr ' ' '\n' | sort -n | tail -n 1) + python "${BENCH_DIR}/calibrate_topk.py" \ + --model-name "${MODEL_NAME}" \ + --topk-val "${CAL_TOPK_VAL}" \ + --page-size "${BLOCK_SIZE}" \ + --mem "${MEM}" \ + --max-total-tokens "${MAX_TOTAL_TOKENS}" \ + --min-free-disk-gb "${MIN_FREE_DISK_GB}" \ + --vortex-module-name "${ALGO}" \ + --output-dir "${STAGING_DIR}" \ + 2>&1 | tee 
"${SWEEP_DIR}/step0_calibrate.log" + mv -f "${STAGING_DIR}/raw_histograms.npy" "${DEFAULT_REAL_HIST}" + REAL_HIST_PATH="${DEFAULT_REAL_HIST}" +fi + +# ── Sweep ────────────────────────────────────────────────────── +SWEEP_INDEX="${SWEEP_DIR}/sweep_index.json" +{ + echo "{" + echo " \"axis_name\": \"topk_val\"," + echo " \"axis_type\": \"kernel\"," + echo " \"model_name\": \"${MODEL_NAME}\"," + echo " \"block_size\": ${BLOCK_SIZE}," + echo " \"mapping_modes\": [${MAPPING_MODES// /, }]," + echo " \"cells\": [" +} > "${SWEEP_INDEX}" + +FIRST_CELL=1 +for TOPK_VAL in ${TOPK_VALS}; do + MIN_SEQ_LEN=$(( (TOPK_VAL + 4) * BLOCK_SIZE )) + SEQ_LEN=${MIN_SEQ_LEN} + if [ "${SEQ_LEN}" -lt 8192 ]; then SEQ_LEN=8192; fi + if [ "${SEQ_LEN}" -lt $(( TOPK_VAL * BLOCK_SIZE * 4 )) ]; then + SEQ_LEN=$(( TOPK_VAL * BLOCK_SIZE * 4 )) + fi + + CELL_DIR="${SWEEP_DIR}/topk${TOPK_VAL}" + mkdir -p "${CELL_DIR}" + AUTOTUNE_JSON="${CELL_DIR}/autotune_results.json" + REMAP_JSON="${CELL_DIR}/remap_bench.json" + + echo "" + echo "============================================================" + echo ">>> Cell: topk_val=${TOPK_VAL} seq_len=${SEQ_LEN}" + echo "============================================================" + + echo ">>> Autotuning hparams for topk_val=${TOPK_VAL}" + PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/autotune_topk_mapping.py" \ + --batch-size "${BATCH_SIZE}" \ + --num-kv-heads "${NUM_KV_HEADS}" \ + --seq-len "${SEQ_LEN}" \ + --topk-val "${TOPK_VAL}" \ + --page-size "${BLOCK_SIZE}" \ + --real-histograms "${REAL_HIST_PATH}" \ + --warmup "${WARMUP}" \ + --repeat "${REPEAT}" \ + --collect-stats \ + --output-json "${AUTOTUNE_JSON}" \ + 2>&1 | tee "${CELL_DIR}/step2_autotune.log" + + echo ">>> Remap bench for topk_val=${TOPK_VAL}" + PYTHONPATH="${SCRIPT_DIR}/.." 
python "${BENCH_DIR}/bench_topk.py" \
+        --remap-bench \
+        --per-head-bench \
+        --batch-sizes "${BATCH_SIZE}" \
+        --num-kv-heads "${NUM_KV_HEADS}" \
+        --seq-lens "${SEQ_LEN}" \
+        --topk-vals "${TOPK_VAL}" \
+        --page-size "${BLOCK_SIZE}" \
+        --distributions ${DISTRIBUTIONS} \
+        --mapping-modes ${MAPPING_MODES} \
+        --mapping-hparam "${MAPPING_HPARAM}" \
+        --autotune-json "${AUTOTUNE_JSON}" \
+        --real-histograms "${REAL_HIST_PATH}" \
+        --warmup "${WARMUP}" \
+        --repeat "${REPEAT}" \
+        --output-json "${REMAP_JSON}" \
+        2>&1 | tee "${CELL_DIR}/step3_remap_bench.log"
+
+    if [ "${FIRST_CELL}" -eq 1 ]; then
+        FIRST_CELL=0
+    else
+        echo "    ," >> "${SWEEP_INDEX}"
+    fi
+    # NOTE(review): the heredoc body below was lost in extraction; it is
+    # reconstructed from the fields analyze_ablation_remap.py reads
+    # (axis_value, cell_dir, autotune_json, remap_bench_json) -- confirm
+    # against the original script.
+    cat >> "${SWEEP_INDEX}" <<EOF
+    {
+      "axis_value": ${TOPK_VAL},
+      "seq_len": ${SEQ_LEN},
+      "cell_dir": "${CELL_DIR}",
+      "autotune_json": "${AUTOTUNE_JSON}",
+      "remap_bench_json": "${REMAP_JSON}"
+    }
+EOF
+done
+
+echo "  ]" >> "${SWEEP_INDEX}"
+echo "}" >> "${SWEEP_INDEX}"
+
+# ── Per-cell screenshot-style hparam summary ──────────────────
+SELECTED_TXT="${SWEEP_DIR}/selected_hparams.txt"
+PYTHONPATH="${SCRIPT_DIR}/.." python3 - "${SWEEP_INDEX}" "${SELECTED_TXT}" "topk_val" <<'PY'
+import json, sys, os
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "benchmarks"))
+try:
+    from autotune_topk_mapping import PARAM_NAME
+except Exception:
+    PARAM_NAME = {3: "p", 6: "beta", 7: "alpha", 9: "alpha", 10: "alpha",
+                  11: "pivot", 13: "alpha", 15: "pivot", 16: "pivot", 17: "k"}
+DISPLAY = {3: "Power", 6: "Asinh", 7: "Log1p", 9: "Erf", 10: "Tanh",
+           11: "Subtract", 13: "ExpStretch", 15: "ShiftPow2",
+           16: "ShiftPow3", 17: "LinearSteep"}
+
+idx_path, out_path, axis_name = sys.argv[1], sys.argv[2], sys.argv[3]
+with open(idx_path) as f:
+    idx = json.load(f)
+
+lines = [f"== Selected mapping functions (autotuned, {axis_name} sweep) =="]
+for cell in idx["cells"]:
+    with open(cell["autotune_json"]) as f:
+        results = json.load(f)
+    best = {}
+    for r in results:
+        m = r["mode"]
+        if m not in best or r["latency_ms"] < best[m]["latency_ms"]:
+            best[m] = r
+    parts = []
+    for m in sorted(DISPLAY):
+        if m in best:
+            parts.append(f"{DISPLAY[m]}({PARAM_NAME.get(m,'p')}={best[m].get('param',0.0)})")
+    
lines.append(f"[{axis_name}={cell['axis_value']}] " + " ".join(parts)) + +txt = "\n".join(lines) + "\n" +print(txt) +with open(out_path, "w") as f: + f.write(txt) +PY + +echo "" +echo "============================================================" +echo "topk_val ablation complete." +echo " Sweep dir: ${SWEEP_DIR}" +echo " Sweep index: ${SWEEP_INDEX}" +echo " Selected hparams: ${SELECTED_TXT}" +echo "Run analyze with:" +echo " python examples/analyze_ablation_remap.py --sweep-dir ${SWEEP_DIR}" +echo "============================================================" diff --git a/examples/analyze_ablation_remap.py b/examples/analyze_ablation_remap.py new file mode 100644 index 0000000..18b0a0b --- /dev/null +++ b/examples/analyze_ablation_remap.py @@ -0,0 +1,416 @@ +#!/usr/bin/env python3 +""" +Analyze remap-function ablation sweeps. + +Reads one or more sweep directories produced by + ablation_remap_function_block_size.sh + ablation_remap_function_topk_val.sh + ablation_remap_function_model.sh + ablation_remap_function_topk_benchmark.sh + +and emits, for each sweep: + - tidy CSV of every (axis_value, mapping_mode, distribution, head) row + - wide CSV tables: latency, speedup vs baseline, chosen hparam + - LaTeX version of the chosen-hparam table + - markdown summary including the screenshot-style "Selected mapping + functions" line per axis value + - matplotlib PDF plots: latency vs axis, speedup vs axis, threshold + bin size vs axis (one curve per mapping mode) + +Usage: + python examples/analyze_ablation_remap.py \ + --sweep-dir results/ablation_remap_block_size_ \ + [--sweep-dir results/ablation_remap_model_ ...] 
\ + --output-dir results/ablation_remap_analysis_ +""" + +from __future__ import annotations + +import argparse +import json +import math +import os +import sys +from pathlib import Path +from typing import Any, Dict, List, Optional + +import numpy as np +import pandas as pd +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt + +# Pull mode metadata from the autotune script so we don't duplicate it. +SCRIPT_DIR = Path(__file__).resolve().parent +BENCH_DIR = SCRIPT_DIR.parent / "benchmarks" +sys.path.insert(0, str(BENCH_DIR)) +try: + from autotune_topk_mapping import MODE_NAMES, PARAM_NAME # type: ignore +except Exception: + MODE_NAMES = {0: "none", 3: "power", 4: "log", 6: "asinh", 7: "log1p", + 8: "trunc8", 9: "erf", 10: "tanh", 11: "subtract", + 13: "exp_stretch", 15: "shift_pow2", 16: "shift_pow3", + 17: "linear_steep"} + PARAM_NAME = {3: "p", 6: "beta", 7: "alpha", 9: "alpha", 10: "alpha", + 11: "pivot", 13: "alpha", 15: "pivot", 16: "pivot", 17: "k"} + +DISPLAY_NAME = { + 0: "None", 3: "Power", 6: "Asinh", 7: "Log1p", 9: "Erf", + 10: "Tanh", 11: "Subtract", 13: "ExpStretch", + 15: "ShiftPow2", 16: "ShiftPow3", 17: "LinearSteep", +} + + +# ---------- Loading ---------- + +def _load_json(path: str) -> Any: + with open(path) as f: + return json.load(f) + + +def _best_per_mode_from_autotune(autotune_results: List[dict]) -> Dict[int, dict]: + best: Dict[int, dict] = {} + for r in autotune_results: + m = int(r["mode"]) + if m not in best or r["latency_ms"] < best[m]["latency_ms"]: + best[m] = r + return best + + +def _flatten_remap_bench(remap_results: List[dict]) -> pd.DataFrame: + """Flatten bench_topk.py --remap-bench output into one row per + (cfg, mode_row). 
Drops per-head sub-rows; keeps head='all' so each + cell contributes a single point per (mapping_mode, distribution).""" + rows = [] + for cfg in remap_results: + if cfg.get("head", "all") != "all": + continue + baseline = cfg.get("baseline_ms") + for mr in cfg.get("modes", []): + mode = int(mr["mode"]) + rows.append({ + "distribution": cfg.get("distribution"), + "batch_size": cfg.get("batch_size"), + "num_kv_heads": cfg.get("num_kv_heads"), + "seq_len": cfg.get("seq_len"), + "topk_val": cfg.get("topk_val"), + "pages_per_seg": cfg.get("pages_per_seg"), + "mode": mode, + "mode_name": mr.get("mode_name", MODE_NAMES.get(mode, str(mode))), + "param_value": mr.get("power"), + "fused_ms": mr.get("fused_ms"), + "remap_ms": mr.get("remap_ms"), + "topk_after_remap_ms": mr.get("topk_after_remap_ms"), + "split_total_ms": mr.get("split_total_ms"), + "baseline_ms": baseline, + "threshold_bin_size_mean": mr.get("threshold_bin_size_mean"), + "threshold_bin_size_max": mr.get("threshold_bin_size_max"), + "refine_rounds_mean": mr.get("refine_rounds_mean"), + }) + return pd.DataFrame(rows) + + +def load_sweep(sweep_dir: Path) -> Dict[str, Any]: + idx_path = sweep_dir / "sweep_index.json" + if not idx_path.exists(): + raise FileNotFoundError(f"missing sweep_index.json in {sweep_dir}") + idx = _load_json(idx_path) + axis_name = idx["axis_name"] + + rows: List[pd.DataFrame] = [] + chosen_hparams: List[dict] = [] + for cell in idx["cells"]: + axis_value = cell["axis_value"] + autotune_results = _load_json(cell["autotune_json"]) + best = _best_per_mode_from_autotune(autotune_results) + for mode, r in best.items(): + chosen_hparams.append({ + "axis_value": axis_value, + "mode": int(mode), + "mode_name": r.get("mode_name", MODE_NAMES.get(int(mode), str(mode))), + "param_name": r.get("param_name") or PARAM_NAME.get(int(mode), "p"), + "param_value": r.get("param"), + "autotune_latency_ms": r.get("latency_ms"), + }) + + remap_results = _load_json(cell["remap_bench_json"]) + df = 
_flatten_remap_bench(remap_results) + df.insert(0, "axis_value", axis_value) + df.insert(0, "axis_name", axis_name) + rows.append(df) + + tidy = pd.concat(rows, ignore_index=True) if rows else pd.DataFrame() + chosen = pd.DataFrame(chosen_hparams) + return { + "axis_name": axis_name, + "axis_type": idx.get("axis_type", "kernel"), + "index": idx, + "tidy": tidy, + "chosen": chosen, + } + + +# ---------- Tables ---------- + +def _wide_latency(tidy: pd.DataFrame, axis_name: str, distribution: Optional[str] = None) -> pd.DataFrame: + df = tidy.copy() + if distribution is not None and "distribution" in df.columns: + df = df[df["distribution"] == distribution] + # Best fused latency per (axis_value, mode) — collapse over distribution + # if no filter was applied. + g = df.groupby(["axis_value", "mode", "mode_name"], dropna=False)["fused_ms"].min().reset_index() + wide = g.pivot(index="axis_value", columns="mode", values="fused_ms") + # Also pivot mode_name → label for column header. + return wide.rename(columns=lambda m: f"{m}:{MODE_NAMES.get(int(m), '?')}") + + +def _wide_baseline(tidy: pd.DataFrame, distribution: Optional[str] = None) -> pd.Series: + df = tidy.copy() + if distribution is not None and "distribution" in df.columns: + df = df[df["distribution"] == distribution] + return df.groupby("axis_value")["baseline_ms"].min() + + +def _wide_speedup(tidy: pd.DataFrame, axis_name: str, distribution: Optional[str] = None) -> pd.DataFrame: + lat = _wide_latency(tidy, axis_name, distribution=distribution) + base = _wide_baseline(tidy, distribution=distribution) + return lat.rdiv(base, axis=0) # baseline / fused + + +def _wide_chosen_hparam(chosen: pd.DataFrame) -> pd.DataFrame: + if chosen.empty: + return pd.DataFrame() + chosen = chosen.copy() + chosen["label"] = chosen.apply( + lambda r: f"{DISPLAY_NAME.get(int(r['mode']), r['mode_name'])}({r['param_name']}={r['param_value']})", + axis=1, + ) + wide = chosen.pivot(index="axis_value", columns="mode", values="label") + 
return wide.rename(columns=lambda m: f"{m}:{MODE_NAMES.get(int(m), '?')}") + + +def _df_to_latex(df: pd.DataFrame, caption: str, label: str) -> str: + if df.empty: + return f"% empty table for {label}\n" + try: + return df.to_latex( + float_format=lambda v: "" if pd.isna(v) else f"{v:.4f}", + na_rep="", + caption=caption, + label=label, + ) + except Exception: + return df.to_string() + + +# ---------- Plots ---------- + +def _axis_x(values: List[Any]) -> List[float]: + """Convert axis values (which may be strings or ints) to numeric x + coordinates. Strings are mapped to 0..N-1; numerics keep their value.""" + out = [] + for i, v in enumerate(values): + if isinstance(v, (int, float)): + out.append(float(v)) + else: + out.append(float(i)) + return out + + +def _plot_metric_vs_axis(tidy: pd.DataFrame, axis_name: str, metric: str, + out_path: Path, ylabel: str, title: str, + baseline_series: Optional[pd.Series] = None, + logy: bool = False) -> None: + if tidy.empty: + return + g = tidy.groupby(["axis_value", "mode", "mode_name"], dropna=False)[metric].min().reset_index() + axis_values = sorted(g["axis_value"].unique(), + key=lambda v: (not isinstance(v, (int, float)), v)) + x = _axis_x(axis_values) + + fig, ax = plt.subplots(figsize=(7, 4.5)) + cmap = plt.cm.get_cmap("tab10") + for i, mode in enumerate(sorted(g["mode"].unique())): + sub = g[g["mode"] == mode].set_index("axis_value").reindex(axis_values) + ax.plot(x, sub[metric].values, + marker="o", color=cmap(i % 10), + label=f"{mode}:{MODE_NAMES.get(int(mode), '?')}") + + if baseline_series is not None and not baseline_series.empty: + bx = baseline_series.reindex(axis_values).values + ax.plot(x, bx, "k--", linewidth=2, label="baseline (unmapped)") + + ax.set_xlabel(axis_name) + ax.set_ylabel(ylabel) + ax.set_title(title) + if logy: + ax.set_yscale("log") + if all(isinstance(v, (int, float)) for v in axis_values): + ax.set_xticks(x) + ax.set_xticklabels([str(v) for v in axis_values]) + else: + ax.set_xticks(x) + 
ax.set_xticklabels([str(v) for v in axis_values], rotation=20, ha="right") + ax.grid(True, alpha=0.3) + ax.legend(fontsize=7, ncol=2, loc="best") + fig.tight_layout() + fig.savefig(out_path) + plt.close(fig) + + +# ---------- Per-sweep emitters ---------- + +def emit_sweep(sweep: Dict[str, Any], out_root: Path) -> None: + axis_name = sweep["axis_name"] + out_dir = out_root / axis_name + out_dir.mkdir(parents=True, exist_ok=True) + + tidy: pd.DataFrame = sweep["tidy"] + chosen: pd.DataFrame = sweep["chosen"] + + if tidy.empty: + print(f"[{axis_name}] no data, skipping") + return + + tidy.to_csv(out_dir / "tidy.csv", index=False) + chosen.to_csv(out_dir / "chosen_hparams_long.csv", index=False) + + distributions = sorted([d for d in tidy["distribution"].dropna().unique()]) + + # Per-distribution wide tables + plots. + for dist in distributions + [None]: + suffix = f"_{dist}" if dist else "_all" + lat_wide = _wide_latency(tidy, axis_name, distribution=dist) + spd_wide = _wide_speedup(tidy, axis_name, distribution=dist) + base = _wide_baseline(tidy, distribution=dist) + + lat_wide.to_csv(out_dir / f"table_latency_ms{suffix}.csv") + spd_wide.to_csv(out_dir / f"table_speedup_vs_baseline{suffix}.csv") + base.to_frame("baseline_ms").to_csv(out_dir / f"table_baseline_ms{suffix}.csv") + + with open(out_dir / f"table_latency_ms{suffix}.tex", "w") as f: + f.write(_df_to_latex(lat_wide, + caption=f"Best fused-kernel latency (ms) on {axis_name} sweep ({dist or 'all dists'})", + label=f"tab:lat-{axis_name}{suffix}")) + with open(out_dir / f"table_speedup_vs_baseline{suffix}.tex", "w") as f: + f.write(_df_to_latex(spd_wide, + caption=f"Speedup over unmapped baseline on {axis_name} sweep ({dist or 'all dists'})", + label=f"tab:spd-{axis_name}{suffix}")) + + _plot_metric_vs_axis( + tidy if dist is None else tidy[tidy["distribution"] == dist], + axis_name, "fused_ms", + out_dir / f"plot_latency_vs_{axis_name}{suffix}.pdf", + ylabel="fused TopK kernel latency (ms)", + title=f"TopK 
kernel latency vs {axis_name} ({dist or 'all dists'})", + baseline_series=base, + ) + # Speedup plot. + spd_long = tidy.copy() + if dist: + spd_long = spd_long[spd_long["distribution"] == dist] + spd_long = spd_long.assign( + speedup=spd_long["baseline_ms"] / spd_long["fused_ms"] + ) + _plot_metric_vs_axis( + spd_long, axis_name, "speedup", + out_dir / f"plot_speedup_vs_{axis_name}{suffix}.pdf", + ylabel="speedup over unmapped baseline", + title=f"Speedup vs {axis_name} ({dist or 'all dists'})", + ) + # Threshold bin size diagnostic. + _plot_metric_vs_axis( + tidy if dist is None else tidy[tidy["distribution"] == dist], + axis_name, "threshold_bin_size_mean", + out_dir / f"plot_threshold_bin_size_vs_{axis_name}{suffix}.pdf", + ylabel="mean threshold-bin size (entries)", + title=f"Stage-1 threshold bin size vs {axis_name} ({dist or 'all dists'})", + ) + + # Chosen-hparam wide table (axis-independent of distribution: autotune + # picks one hparam per mode per axis cell). + chosen_wide = _wide_chosen_hparam(chosen) + chosen_wide.to_csv(out_dir / "table_chosen_hparams.csv") + with open(out_dir / "table_chosen_hparams.tex", "w") as f: + f.write(_df_to_latex(chosen_wide, + caption=f"Autotuned remap-function hyperparameters per {axis_name} cell", + label=f"tab:hparam-{axis_name}")) + + # Markdown summary. 
+ md_lines: List[str] = [] + md_lines.append(f"# Ablation: remap function vs `{axis_name}`\n") + md_lines.append(f"Source: `{sweep['index'].get('cells', [{}])[0].get('cell_dir', '')}/...`\n") + + md_lines.append("\n## Selected mapping functions (autotuned)\n") + md_lines.append("```") + for v in chosen_wide.index.tolist(): + parts = [] + for col in chosen_wide.columns: + label = chosen_wide.loc[v, col] + if isinstance(label, str) and label: + parts.append(label) + md_lines.append(f"[{axis_name}={v}] " + " ".join(parts)) + md_lines.append("```\n") + + md_lines.append("\n## Latency (ms) — best fused, all distributions\n") + md_lines.append(_wide_latency(tidy, axis_name).to_markdown()) + md_lines.append("\n\n## Speedup over unmapped baseline\n") + md_lines.append(_wide_speedup(tidy, axis_name).to_markdown()) + md_lines.append("\n\n## Chosen hyperparameters\n") + md_lines.append(chosen_wide.to_markdown()) + md_lines.append("\n\n## Plots\n") + for p in sorted(out_dir.glob("plot_*.pdf")): + md_lines.append(f"- `{p.name}`") + + with open(out_dir / "summary.md", "w") as f: + f.write("\n".join(md_lines) + "\n") + + print(f"[{axis_name}] wrote artifacts to {out_dir}") + + +# ---------- Top-level ---------- + +def main() -> None: + ap = argparse.ArgumentParser(description="Aggregate ablation_remap_function_*.sh sweep outputs.") + ap.add_argument("--sweep-dir", action="append", required=True, + help="A sweep directory containing sweep_index.json. 
Repeat for multiple sweeps.") + ap.add_argument("--output-dir", type=str, required=True, + help="Where to write tables, plots, and summary.") + args = ap.parse_args() + + out_root = Path(args.output_dir) + out_root.mkdir(parents=True, exist_ok=True) + + sweeps: List[Dict[str, Any]] = [] + for sd in args.sweep_dir: + sweep = load_sweep(Path(sd)) + emit_sweep(sweep, out_root) + sweeps.append(sweep) + + # Cross-axis recommended hparams: for every mode, pick the param value + # that was selected most often across all axis cells of all sweeps. + all_chosen = pd.concat([s["chosen"] for s in sweeps if not s["chosen"].empty], + ignore_index=True) if sweeps else pd.DataFrame() + rec_lines: List[str] = [] + if not all_chosen.empty: + rec = (all_chosen.groupby(["mode", "mode_name", "param_name"])["param_value"] + .agg(lambda s: s.value_counts().idxmax()) + .reset_index().rename(columns={"param_value": "recommended"})) + rec.to_csv(out_root / "recommended_hparams.csv", index=False) + rec_lines.append("## Cross-axis recommended hparams (mode of selections)\n") + rec_lines.append(rec.to_markdown(index=False)) + + index_lines = ["# Remap-function ablation summary\n"] + for s in sweeps: + axis = s["axis_name"] + index_lines.append(f"- [`{axis}`]({axis}/summary.md)") + if rec_lines: + index_lines.append("") + index_lines.extend(rec_lines) + with open(out_root / "index.md", "w") as f: + f.write("\n".join(index_lines) + "\n") + print(f"[index] {out_root}/index.md") + + +if __name__ == "__main__": + main() diff --git a/examples/profile_in_docker.sh b/examples/profile_in_docker.sh new file mode 100755 index 0000000..a64606f --- /dev/null +++ b/examples/profile_in_docker.sh @@ -0,0 +1,181 @@ +#!/usr/bin/env bash +# ============================================================ +# Run examples/profile_parallel_vs_fused.sh inside an NVIDIA +# CUDA devel container so we can enable profiling without +# touching the host's RmProfilingAdminOnly=1 setting. 
+# +# Key idea: +# - The container has `ncu` bundled with the CUDA toolkit. +# - --cap-add=SYS_ADMIN gives the container the capability +# CUPTI needs to access perf counters, so ncu works +# regardless of the host's nvidia-driver profiling restriction. +# - We mount the host's uv venv and the project, so there's +# no Python/pytorch install inside the container — the host +# venv's python is used directly. +# +# Image: +# Defaults to an NGC public CUDA devel image. For B200 (Blackwell / +# sm_100) you need CUDA ≥ 12.8 and ncu ≥ 2024.3; CUDA 13.0+ covers +# that. Override with NCU_IMAGE if you prefer a specific tag. +# +# Usage: +# bash examples/profile_in_docker.sh # defaults +# GPU=2 NUM_SPLITS=2 bash examples/profile_in_docker.sh +# NCU_IMAGE=nvcr.io/nvidia/pytorch:25.03-py3 bash examples/profile_in_docker.sh +# ============================================================ +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PROJECT_DIR="$(cd "${SCRIPT_DIR}/.." && pwd)" + +# Host venv to reuse. uv venvs have their own python binary under +# $VENV/bin/python3.x that is glibc/libstdc++-compatible with the +# container when using NGC Ubuntu 22.04 / 24.04 images. +VENV_DIR="${VENV_DIR:-/home/zhuominc/xinrui_projects/uv_env/vortex}" + +# NGC CUDA devel image on Ubuntu. Has /usr/local/cuda/bin/ncu bundled. +# 13.0.1-devel-ubuntu22.04 is public (no NGC login needed), supports +# B200, and matches the host's CUDA 13.x driver ABI. +# +# Alternatives: +# nvcr.io/nvidia/cuda:13.0.1-devel-ubuntu24.04 # newer base +# nvcr.io/nvidia/pytorch:25.03-py3 # if you don't want to +# # reuse the host venv +# Host is Ubuntu 24.04 + Python 3.12 (the uv venv points to /usr/bin/python3.12). +# Match the container to that so the venv's symlinked python resolves to a +# compatible interpreter inside the container. +NCU_IMAGE="${NCU_IMAGE:-nvcr.io/nvidia/cuda:13.0.1-devel-ubuntu24.04}" + +# Pass-through env vars for the inner profile script. 
Defaults match +# examples/profile_parallel_vs_fused.sh. +GPU="${GPU:-7}" +EFF_BS="${EFF_BS:-1}" +NUM_SPLITS="${NUM_SPLITS:-2}" +POWER="${POWER:--1.0}" +WARMUP="${WARMUP:-20}" +ITERS="${ITERS:-1}" +SECTION_SET="${SECTION_SET:-full}" + +# Inside the container, these mount points give the profile script the +# same absolute paths it sees on the host (so the script doesn't need +# to be container-aware). +MOUNT_ROOT="/home/zhuominc/xinrui_projects" + +if [ ! -d "${VENV_DIR}" ]; then + echo "ERROR: VENV_DIR not found: ${VENV_DIR}" + echo " Set VENV_DIR=/path/to/venv or install the venv." + exit 1 +fi + +VENV_PY="$(ls "${VENV_DIR}"/bin/python* 2>/dev/null | head -1 || true)" +if [ -z "${VENV_PY}" ]; then + echo "ERROR: no python found under ${VENV_DIR}/bin/" + exit 1 +fi + +echo "============================================================" +echo "Docker-wrapped ncu profiling" +echo " image: ${NCU_IMAGE}" +echo " venv: ${VENV_DIR} (python=${VENV_PY##*/})" +echo " project: ${PROJECT_DIR}" +echo " GPU: ${GPU}" +echo " eff_bs: ${EFF_BS}" +echo " num_splits: ${NUM_SPLITS}" +echo " power: ${POWER}" +echo " warmup/iters: ${WARMUP}/${ITERS}" +echo " section set: ${SECTION_SET}" +echo "============================================================" + +# Pull the image up-front (so the output during the run isn't +# interleaved with pull progress). `|| true` — pull is optional; +# if the image is already local, docker run will use the cached copy. +docker pull "${NCU_IMAGE}" || true + +# Run the profile script inside the container. +# +# --gpus all : give the container access to all GPUs +# (CUDA_VISIBLE_DEVICES inside the script +# narrows it down to GPU ${GPU}). +# --cap-add=SYS_ADMIN : lets CUPTI access perf counters without +# touching host profiling restrictions. +# --security-opt seccomp=unconfined : CUPTI needs a few syscalls +# the default seccomp profile blocks. 
+# --network host : not strictly required, but keeps pip/uv +# network access working if you ever add +# pip-install steps. +# --user $(id -u):$(id -g) +# : write output files owned by your user, +# not root. +# -v /etc/passwd:/etc/passwd:ro -v /etc/group:/etc/group:ro +# : so the uid inside resolves to a real +# user (helps some tools, harmless otherwise). +# -v ${MOUNT_ROOT}:${MOUNT_ROOT} +# : mount the whole xinrui_projects tree so +# both the project and the venv are visible +# at their host paths. +# -e PYTHONPATH=... : add the venv's site-packages explicitly +# so `python3 -c 'import vortex_torch_C'` +# resolves even without activate. +# -e PATH=... : put the venv's bin ahead of /usr/local/cuda/bin +# so `python` is the venv python, and keep ncu +# reachable. +# When invoked via `sudo`, `id -u` returns 0 (root). Prefer SUDO_UID/ +# SUDO_GID so the final chown hands results back to the real user, +# not root. Fall back to the effective uid/gid otherwise. +HOST_UID="${SUDO_UID:-$(id -u)}" +HOST_GID="${SUDO_GID:-$(id -g)}" + +docker run --rm \ + --gpus all \ + --cap-add=SYS_ADMIN \ + --security-opt seccomp=unconfined \ + --network host \ + --ipc=host \ + -e DISPLAY="${DISPLAY:-}" \ + -v /tmp/.X11-unix:/tmp/.X11-unix \ + -v "${MOUNT_ROOT}:${MOUNT_ROOT}" \ + -w "${PROJECT_DIR}" \ + -e GPU="${GPU}" \ + -e EFF_BS="${EFF_BS}" \ + -e NUM_SPLITS="${NUM_SPLITS}" \ + -e POWER="${POWER}" \ + -e WARMUP="${WARMUP}" \ + -e ITERS="${ITERS}" \ + -e SECTION_SET="${SECTION_SET}" \ + -e NCU="/usr/local/cuda/bin/ncu" \ + -e HOST_UID="${HOST_UID}" \ + -e HOST_GID="${HOST_GID}" \ + -e PATH="${VENV_DIR}/bin:/usr/local/cuda/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin" \ + -e LD_LIBRARY_PATH="/usr/local/cuda/lib64:${LD_LIBRARY_PATH:-}" \ + "${NCU_IMAGE}" \ + bash -lc ' + set -e + # Ubuntu 24.04 base may not ship python3.12 in the CUDA devel image. + # Install it idempotently; this is ~2s if missing and skipped otherwise. + if [ ! 
-x /usr/bin/python3.12 ]; then + echo "--- installing python3.12 in container ---" + export DEBIAN_FRONTEND=noninteractive + apt-get update -qq + apt-get install -y --no-install-recommends python3.12 >/dev/null + fi + echo "--- container environment ---" + echo "python: $(readlink -f "$(which python)") ($(python --version 2>&1))" + echo "ncu: $(which ncu)" + ncu --version 2>&1 | head -2 + nvidia-smi -L + python -c "import torch; print(\"torch: \", torch.__version__, \"cuda:\", torch.version.cuda)" + python -c "import vortex_torch_C; print(\"vortex_torch_C import OK\")" + echo "-----------------------------" + bash examples/profile_parallel_vs_fused.sh + # Hand output files back to the host user (we ran as root so apt + # could install python3.12). + chown -R "${HOST_UID}:${HOST_GID}" examples/results 2>/dev/null || true + ' + +echo "" +echo "============================================================" +echo "Docker profiling run complete." +echo "Reports are under: ${PROJECT_DIR}/examples/results/" +echo "(same path as the direct script — you own the files since we" +echo " ran the container as your uid)." +echo "============================================================" diff --git a/examples/profile_parallel_vs_fused_ncu.sh b/examples/profile_parallel_vs_fused_ncu.sh new file mode 100755 index 0000000..bf9baaf --- /dev/null +++ b/examples/profile_parallel_vs_fused_ncu.sh @@ -0,0 +1,277 @@ +#!/usr/bin/env bash +# ============================================================ +# Nsight Compute profiling script for the parallel vs fused +# TopK kernels. 
+# +# Profiles both: +# - TopKOutput_Fused_Kernel (csrc/topk_sglang.cu) +# - TopKOutput_Parallel_Kernel (csrc/topk_sglang_parallel.cu) +# +# With both remap functions the user cares about: +# - mode 15: MAPPING_SHIFT_POW2 +# - mode 16: MAPPING_SHIFT_POW3 +# +# And both configs: +# - A: topk=2048, pages_per_seg=32K (topk=2k from 32k) +# - B: topk=30, pages_per_seg=2K (topk=30 from 2k) +# +# Produces one .ncu-rep per (kernel × mode × config). Open with +# the Nsight Compute GUI for an interactive comparison, or dump on +# the CLI with `ncu --import .ncu-rep --page details`. +# +# Usage: +# bash examples/profile_parallel_vs_fused.sh # defaults +# GPU=4 EFF_BS=1 bash examples/profile_parallel_vs_fused.sh # small-batch case +# GPU=4 EFF_BS=32 bash examples/profile_parallel_vs_fused.sh # saturated case +# GPU=4 NUM_SPLITS=2 bash examples/profile_parallel_vs_fused.sh +# +# Requires `ncu` on PATH (part of the CUDA toolkit). On most systems +# accessing performance counters requires either: +# - root/sudo, or +# - `echo 1 | sudo tee /proc/driver/nvidia/params` (temporary), or +# - setting NVreg_RestrictProfilingToAdminUsers=0 in the nvidia driver. +# If ncu reports "ERR_NVGPUCTRPERM" you'll need one of the above. +# ============================================================ +set -euo pipefail +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PY_DRIVER="${SCRIPT_DIR}/../benchmarks/profile_parallel_vs_fused.py" + +# ── Defaults ────────────────────────────────────────────────── +GPU=${GPU:-7} +EFF_BS=${EFF_BS:-1} # eff_batch_size = batch_size * num_kv_heads +NUM_SPLITS=${NUM_SPLITS:-2} # only used by the parallel kernel +POWER=${POWER:--1.0} # pivot p for shift_pow{2,3} +WARMUP=${WARMUP:-20} # matching-kernel warmup launches (ncu skips) +ITERS=${ITERS:-1} # matching-kernel profiled launches (ncu captures) +SECTION_SET=${SECTION_SET:-full} # ncu section set: "full", "basic", or named sections + +# Profiling robustness knobs for shared GPUs / CUDA 13 systems. 
+# --replay-mode application: re-run the entire process to collect each +# counter pass, instead of replaying individual +# kernels. Fixes "Failed to prepare kernel" on +# systems where kernel replay hits PMU conflicts. +# --clock-control none : don't try to lock GPU clocks (requires admin on +# shared GPUs; without this, "Unknown error on +# device 0" is common). +# --cache-control none : don't flush L1/L2 between passes (also needs +# admin on shared systems). +# Override with NCU_EXTRA_FLAGS="..." if you need a different combination. +NCU_EXTRA_FLAGS=${NCU_EXTRA_FLAGS:-"--replay-mode application --clock-control none --cache-control none"} + +# DIAG=1 bash profile_parallel_vs_fused.sh → run one tiny ncu probe to +# verify profiling works before doing the full sweep. +DIAG=${DIAG:-0} + +# ── ncu command ─────────────────────────────────────────────── +NCU=${NCU:-ncu} +command -v "${NCU}" >/dev/null 2>&1 || { + echo "ERROR: '${NCU}' not found on PATH. Install Nsight Compute (part of CUDA Toolkit)" + echo " or set NCU=/path/to/ncu and re-run." + exit 1 +} + +# The templated kernels end up with mangled names like +# _Z25TopKOutput_Fused_KernelI13__nv_bfloat16ILi15EEEvPKT_... +# ncu supports --kernel-name regex: which matches on the +# demangled signature. Using "TopKOutput_Fused_Kernel" and +# "TopKOutput_Parallel_Kernel" as the regex selects all template +# instantiations of each kernel but nothing else. 
+FUSED_REGEX="regex:TopKOutput_Fused_Kernel" +PARALLEL_REGEX="regex:TopKOutput_Parallel_Kernel" + +# ── Output dir ──────────────────────────────────────────────── +TIMESTAMP=$(date +%Y%m%d_%H%M%S) +OUT_DIR="${SCRIPT_DIR}/results/ncu_parallel_vs_fused_${TIMESTAMP}" +mkdir -p "${OUT_DIR}" + +echo "============================================================" +echo "Nsight Compute profile: parallel vs fused TopK" +echo " GPU: ${GPU}" +echo " eff_bs: ${EFF_BS}" +echo " num_splits: ${NUM_SPLITS} (parallel kernel only)" +echo " power (p): ${POWER} (for shift_pow{2,3})" +echo " warmup: ${WARMUP} (matching-kernel launches skipped by ncu)" +echo " iters: ${ITERS} (matching-kernel launches captured)" +echo " sections: --set ${SECTION_SET}" +echo " extra ncu flags:${NCU_EXTRA_FLAGS}" +echo " output dir: ${OUT_DIR}" +echo "============================================================" + +# ── Diagnostic probe ───────────────────────────────────────── +# Verifies that ncu can attach and collect at least one section on +# this GPU before we burn time on the full sweep. Uses --set basic +# which is the cheapest section set. If this fails, see the +# TROUBLESHOOTING block that the script prints on error. +run_diag() { + echo "" + echo ">>> Diagnostic probe: can ncu attach at all?" + local out="${OUT_DIR}/diag.ncu-rep" + set +e + CUDA_VISIBLE_DEVICES="${GPU}" "${NCU}" \ + --force-overwrite \ + --target-processes all \ + --kernel-name "${FUSED_REGEX}" \ + --launch-skip "${WARMUP}" \ + --launch-count 1 \ + --set basic \ + ${NCU_EXTRA_FLAGS} \ + --export "${out}" \ + python "${PY_DRIVER}" \ + --config A --eff-bs 1 --mode 15 --power "${POWER}" \ + --num-splits "${NUM_SPLITS}" --kernel fused \ + --warmup "${WARMUP}" --iters 1 + local rc=$? 
+ set -e + if [ ${rc} -ne 0 ]; then + cat <<'EOF' + +============================================================ +TROUBLESHOOTING "Failed to prepare kernel for profiling" +============================================================ + 1) Is another process using GPU ${GPU}? Check: + nvidia-smi + If yes, pick an idle GPU: + GPU=0 bash examples/profile_parallel_vs_fused.sh + + 2) Perf counters may be locked to admin. Try as root: + sudo -E bash examples/profile_parallel_vs_fused.sh + + Or permanently unlock (admin, persists until reboot): + sudo sh -c 'echo 1 > /proc/driver/nvidia/params' + + Or permanently in the driver (needs reboot): + Add NVreg_RestrictProfilingToAdminUsers=0 to + /etc/modprobe.d/nvidia.conf + + 3) MPS or another profiler (CUPTI, Nsight Systems, etc.) + may be running. Kill with: + echo quit | nvidia-cuda-mps-control + and verify nothing else is profiling. + + 4) On H100 with MIG: profiling across MIG slices is + restricted. Use a full-device GPU. + + 5) Try a smaller ncu configuration first: + NCU_EXTRA_FLAGS="--replay-mode application --clock-control none --cache-control none --metrics sm__cycles_elapsed.avg" \ + bash examples/profile_parallel_vs_fused.sh + + 6) CUDA 13.2 vs PyTorch-13.0 mismatch is sometimes flagged + by ncu. Update ncu to match CUDA 13.2, or use the ncu + shipped with CUDA 13.2: + NCU=/usr/local/cuda-13.2/bin/ncu bash ... + +============================================================ +EOF + echo "Diagnostic probe failed (exit ${rc}). See troubleshooting above." + exit ${rc} + fi + echo ">>> Diagnostic probe OK. Proceeding with full sweep." +} + +if [ "${DIAG}" = "1" ]; then + run_diag + exit 0 +fi + +# Always run a cheap probe first so full-sweep failures are caught early +# before we've spent minutes on the heavy --set full passes. 
+run_diag
+
+# ── Helper: run one ncu profile ──────────────────────────────
+# tag    : name used for the output file
+# kernel : "fused" or "parallel" (drives Python driver dispatch)
+# regex  : ncu --kernel-name filter
+# config : "A" or "B"
+# mode   : 15 or 16
+run_ncu() {
+    local tag="$1"
+    local kernel="$2"
+    local regex="$3"
+    local config="$4"
+    local mode="$5"
+
+    local out="${OUT_DIR}/${tag}.ncu-rep"
+
+    echo ""
+    echo ">>> ${tag}"
+
+    # --launch-skip/--launch-count count ONLY kernels matching
+    # --kernel-name, so setup kernels (torch.randn, etc.) don't
+    # pollute the offsets. With --launch-skip=${WARMUP} and the
+    # Python driver doing ${WARMUP} warmup + ${ITERS} profiled
+    # calls, ncu captures exactly the profiled ones.
+    CUDA_VISIBLE_DEVICES="${GPU}" "${NCU}" \
+        --force-overwrite \
+        --target-processes all \
+        --kernel-name "${regex}" \
+        --launch-skip "${WARMUP}" \
+        --launch-count "${ITERS}" \
+        --set "${SECTION_SET}" \
+        ${NCU_EXTRA_FLAGS} \
+        --export "${out}" \
+        python "${PY_DRIVER}" \
+            --config "${config}" \
+            --eff-bs "${EFF_BS}" \
+            --mode "${mode}" \
+            --power "${POWER}" \
+            --num-splits "${NUM_SPLITS}" \
+            --kernel "${kernel}" \
+            --warmup "${WARMUP}" \
+            --iters "${ITERS}"
+
+    echo "    report: ${out}"
+}
+
+# ── Sweep ────────────────────────────────────────────────────
+for MODE in 15 16; do
+    if [ "${MODE}" -eq 15 ]; then MODE_TAG="SP2"; else MODE_TAG="SP3"; fi
+    for CONFIG in A B; do
+        run_ncu "fused_${MODE_TAG}_cfg${CONFIG}_eff${EFF_BS}" \
+                "fused" "${FUSED_REGEX}" "${CONFIG}" "${MODE}"
+        run_ncu "parallel_${MODE_TAG}_cfg${CONFIG}_eff${EFF_BS}_ns${NUM_SPLITS}" \
+                "parallel" "${PARALLEL_REGEX}" "${CONFIG}" "${MODE}"
+    done
+done
+
+echo ""
+echo "============================================================"
+echo "All profiles done. 
Reports saved under:" +echo " ${OUT_DIR}" +echo "" +echo "Interactive analysis (recommended):" +echo " ncu-ui ${OUT_DIR}/parallel_SP2_cfgA_eff${EFF_BS}_ns${NUM_SPLITS}.ncu-rep" +echo "" +echo "CLI summary, one kernel at a time:" +echo " ncu --import ${OUT_DIR}/fused_SP2_cfgA_eff${EFF_BS}.ncu-rep --page details" +echo "" +echo "Side-by-side diff (CLI):" +echo " ncu --import ${OUT_DIR}/fused_SP2_cfgA_eff${EFF_BS}.ncu-rep \\" +echo " --import ${OUT_DIR}/parallel_SP2_cfgA_eff${EFF_BS}_ns${NUM_SPLITS}.ncu-rep \\" +echo " --page details --csv > ${OUT_DIR}/compare_SP2_cfgA.csv" +echo "" +echo "What to look at (to pinpoint the overhead vs fused):" +echo " * Section 'GPU Speed Of Light Throughput'" +echo " → SM %, Memory %, which one is the bound?" +echo " * Section 'Launch Statistics'" +echo " → Grid/Block size, Dynamic Shared Mem per block" +echo " * Section 'Occupancy'" +echo " → Theoretical vs achieved; limit (smem / regs / blocks/SM)" +echo " * Section 'Warp State Statistics'" +echo " → Stall breakdown: Stall Barrier (__syncthreads/__threadfence)," +echo " Stall Long Scoreboard (global memory), Stall Short Scoreboard" +echo " (smem/atomic)" +echo " * Section 'Memory Workload Analysis'" +echo " → L2/Device throughput, atomic traffic, smem bank conflicts" +echo " * Section 'Compute Workload Analysis'" +echo " → Pipe utilisation (FMA / ALU / FP64)" +echo "" +echo "Likely suspects for the parallel-vs-fused gap:" +echo " - Occupancy limited by the large dynamic smem (kSmem + chunk_bytes)" +echo " - Stall Barrier dominating due to the __threadfence before atomicInc" +echo " - Phase 1 CTAs repeat Stage-2 refinement that fused does only once" +echo " → visible as 'Pipe Utilisation ALU / Special' for integer radix ops" +echo "============================================================" diff --git a/examples/profile_parallel_vs_fused_nsys.sh b/examples/profile_parallel_vs_fused_nsys.sh new file mode 100755 index 0000000..3d64519 --- /dev/null +++ 
b/examples/profile_parallel_vs_fused_nsys.sh @@ -0,0 +1,211 @@ +#!/usr/bin/env bash +# ============================================================ +# Nsight Systems (nsys) profiling — timeline view of the parallel +# vs fused TopK kernels. +# +# Why nsys and not ncu here: +# ncu needs SM-level perf counters (sm__*), which on this box are +# gated by the nvidia driver's RmProfilingAdminOnly flag — and we +# have no sudo. nsys uses CUPTI API/activity tracing and kernel +# timing, which do NOT require admin. That's enough to answer the +# "where does the 6-8us overhead come from" question, because we +# get per-kernel durations, gaps on the stream, memcpy/memset +# traffic, and NVTX range timing. +# +# Profiles both: +# - TopKOutput_Fused_Kernel (csrc/topk_sglang.cu) +# - TopKOutput_Parallel_Kernel (csrc/topk_sglang_parallel.cu) +# +# For each of mode 15 (SHIFT_POW2), mode 16 (SHIFT_POW3) and both +# configs A (topk=2048 pages=32K) and B (topk=30 pages=2K). +# +# Produces one .nsys-rep per (kernel × mode × config). Open with: +# nsys-ui .nsys-rep +# or dump CLI summaries with: +# nsys stats .nsys-rep +# +# Usage: +# bash examples/profile_parallel_vs_fused_nsys.sh # defaults +# GPU=7 NUM_SPLITS=2 bash examples/profile_parallel_vs_fused_nsys.sh +# ITERS=50 bash examples/profile_parallel_vs_fused_nsys.sh # more samples +# ============================================================ +set -euo pipefail +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PY_DRIVER="${SCRIPT_DIR}/../benchmarks/profile_parallel_vs_fused.py" + +# ── Defaults ────────────────────────────────────────────────── +GPU=${GPU:-7} +EFF_BS=${EFF_BS:-1} +NUM_SPLITS=${NUM_SPLITS:-2} +POWER=${POWER:--1.0} +WARMUP=${WARMUP:-20} +# For nsys we want *many* iterations so the per-kernel timing is +# statistically meaningful and the timeline is readable. +ITERS=${ITERS:-50} + +# Prefer the CUDA-13 toolchain's nsys (matches the torch CUDA ABI). 
+NSYS=${NSYS:-$(command -v nsys || echo /usr/local/cuda/bin/nsys)} +if [ ! -x "${NSYS}" ]; then + echo "ERROR: nsys not found. Tried: ${NSYS}" + echo " Set NSYS=/path/to/nsys manually." + exit 1 +fi + +TIMESTAMP=$(date +%Y%m%d_%H%M%S) +OUT_DIR="${SCRIPT_DIR}/results/nsys_parallel_vs_fused_${TIMESTAMP}" +mkdir -p "${OUT_DIR}" + +# nsys writes intermediate files under $TMPDIR/nvidia/nsight_systems. +# On shared systems /tmp/nvidia is often owned by another user who +# created it first, and we can't write there. Redirect to a +# user-writable cache dir. +export TMPDIR="${TMPDIR:-${HOME}/.cache/nsys_tmp}" +mkdir -p "${TMPDIR}" + +echo "============================================================" +echo "Nsight Systems profile: parallel vs fused TopK" +echo " GPU: ${GPU}" +echo " eff_bs: ${EFF_BS}" +echo " num_splits: ${NUM_SPLITS}" +echo " power (p): ${POWER}" +echo " warmup: ${WARMUP}" +echo " iters: ${ITERS} (profiled launches)" +echo " nsys binary: ${NSYS}" +echo " output dir: ${OUT_DIR}" +echo "============================================================" +"${NSYS}" --version 2>&1 | head -2 + +# ── Helper: run one nsys profile ───────────────────────────── +run_nsys() { + local tag="$1" + local kernel="$2" + local config="$3" + local mode="$4" + + local out="${OUT_DIR}/${tag}" + + echo "" + echo ">>> ${tag}" + + # --trace cuda,nvtx : CUDA API/runtime + NVTX ranges. NVTX + # stays on so the timeline still shows + # where the profiled region begins. + # --sample none / --cpuctxsw none: skip CPU callstack sampling and + # context-switch tracing — both admin-gated + # on this box and we don't need them. + # --cuda-memory-usage true: log cudaMalloc/cudaFree/cudaMemset so we + # can see if at::empty / at::zeros costs + # anything on the hot path. + # + # Capture-range flags intentionally OMITTED. 
On some nsys builds + # --capture-range=nvtx silently yields "No reports were generated" + # when the ranges don't line up exactly; profiling the whole run + # is more robust and the warmup is easy to filter out later + # (NVTX range "profile-*" tags the profiled region in nsys stats). + CUDA_VISIBLE_DEVICES="${GPU}" "${NSYS}" profile \ + --output "${out}" \ + --force-overwrite true \ + --trace cuda,nvtx \ + --sample none \ + --cpuctxsw none \ + --cuda-memory-usage true \ + python "${PY_DRIVER}" \ + --config "${config}" \ + --eff-bs "${EFF_BS}" \ + --mode "${mode}" \ + --power "${POWER}" \ + --num-splits "${NUM_SPLITS}" \ + --kernel "${kernel}" \ + --warmup "${WARMUP}" \ + --iters "${ITERS}" + + echo " report: ${out}.nsys-rep" +} + +# ── Sweep ──────────────────────────────────────────────────── +for MODE in 15 16; do + if [ "${MODE}" -eq 15 ]; then MODE_TAG="SP2"; else MODE_TAG="SP3"; fi + for CONFIG in A B; do + run_nsys "fused_${MODE_TAG}_cfg${CONFIG}_eff${EFF_BS}" \ + "fused" "${CONFIG}" "${MODE}" + run_nsys "parallel_${MODE_TAG}_cfg${CONFIG}_eff${EFF_BS}_ns${NUM_SPLITS}" \ + "parallel" "${CONFIG}" "${MODE}" + done +done + +# ── Auto-dump CLI summaries for every report ───────────────── +# `nsys stats` produces text tables that are immediately readable +# and answer most "where did the time go" questions without needing +# the GUI. We dump the most useful ones for every report and stash +# them alongside. +echo "" +echo "============================================================" +echo "Dumping text summaries ('nsys stats') for every report..." 
+echo "============================================================" +for rep in "${OUT_DIR}"/*.nsys-rep; do + name="$(basename "${rep}" .nsys-rep)" + echo "" + echo ">>> summary for ${name}" + summary="${OUT_DIR}/${name}.summary.txt" + { + echo "### ${name}" + echo "" + echo "## cuda_api_sum: CUDA runtime API call distribution" + echo "## (count, avg, med, min, max of cudaLaunchKernel / cudaMalloc / etc.)" + "${NSYS}" stats --report cuda_api_sum --format table "${rep}" 2>&1 || true + echo "" + echo "## cuda_gpu_kern_sum: per-kernel GPU duration stats" + echo "## (mean/median/std/min/max duration per kernel name, with instance count)" + "${NSYS}" stats --report cuda_gpu_kern_sum --format table "${rep}" 2>&1 || true + echo "" + echo "## cuda_gpu_mem_size_sum: memcpy / memset by size" + echo "## (expect 0 memset entries for parallel — no at::zeros on the hot path)" + "${NSYS}" stats --report cuda_gpu_mem_size_sum --format table "${rep}" 2>&1 || true + echo "" + echo "## cuda_gpu_mem_time_sum: memcpy / memset by time" + "${NSYS}" stats --report cuda_gpu_mem_time_sum --format table "${rep}" 2>&1 || true + echo "" + echo "## cuda_kern_exec_sum: kernel launch→exec latency" + echo "## (host-side cudaLaunchKernel cost separated from GPU exec cost)" + "${NSYS}" stats --report cuda_kern_exec_sum --format table "${rep}" 2>&1 || true + echo "" + echo "## nvtx_pushpop_sum: NVTX ranges (the 'profile-*' wrapped region)" + "${NSYS}" stats --report nvtx_pushpop_sum --format table "${rep}" 2>&1 || true + } > "${summary}" 2>&1 + echo " saved: ${summary}" +done + +echo "" +echo "============================================================" +echo "Reports saved to: ${OUT_DIR}" +echo "" +echo "Quick read — compare fused vs parallel summaries side-by-side:" +echo "" +echo " diff -y --width=200 \\" +echo " ${OUT_DIR}/fused_SP2_cfgA_eff${EFF_BS}.summary.txt \\" +echo " ${OUT_DIR}/parallel_SP2_cfgA_eff${EFF_BS}_ns${NUM_SPLITS}.summary.txt \\" +echo " | less" +echo "" +echo "Interactive 
timeline (if you have X11/SSH forwarding):" +echo " nsys-ui ${OUT_DIR}/parallel_SP2_cfgA_eff${EFF_BS}_ns${NUM_SPLITS}.nsys-rep" +echo "" +echo "What to look for (to nail the overhead vs fused):" +echo " * 'cuda_gpu_kern_sum' mean duration for each kernel" +echo " → fused is one kernel × (WARMUP+ITERS), parallel is one kernel × (WARMUP+ITERS)" +echo " (single-kernel design). Mean duration difference = the GPU work" +echo " gap (Stage-1 savings minus merge cost)." +echo " * 'cuda_api_sum' cudaLaunchKernel / cudaMalloc / cudaFree counts" +echo " → if parallel shows more launches than fused, there's an unexpected" +echo " extra kernel. Also watch the time spent in cudaLaunchKernel." +echo " * 'cuda_gpu_mem_size_sum' cudaMemset entries" +echo " → should be zero for parallel now (__device__ counter removed" +echo " at::zeros). Any memset here IS overhead we need to explain." +echo " * 'cuda_kern_exec_sum'" +echo " → separates host-side cudaLaunchKernel latency from GPU kernel time." +echo " * 'nvtx_pushpop_sum' profile-* range duration / ${ITERS}" +echo " → wall-clock per-call including CPU-side overhead." +echo "" +echo "Timeline view (nsys-ui) additionally shows *gaps* between kernels" +echo "on the GPU stream — the cost of __threadfence + atomicInc barrier" +echo "shows up as a visible pause between Phase-1 work and the merge." +echo "============================================================" diff --git a/examples/remap_function_bench_topk_parallel.sh b/examples/remap_function_bench_topk_parallel.sh new file mode 100755 index 0000000..df33f2c --- /dev/null +++ b/examples/remap_function_bench_topk_parallel.sh @@ -0,0 +1,245 @@ +#!/usr/bin/env bash +# ============================================================ +# Remap Function Benchmark — Parallel TopK variant. 
+# +# Wraps bench_topk.py --remap-bench with --bench-parallel so the +# output table includes a "par_ms" column comparing the split+merge +# kernel (topk_output_sglang_parallel) against the single-CTA +# fused kernel. Also sweeps batch size and num_splits so the +# occupancy-vs-merge-overhead curve is visible. +# +# Pipeline mirrors remap_function_bench_topk2028.sh: +# Step 1 — calibrate (can be skipped with --real-histograms) +# Step 2 — autotune per-mode hparams by fused-kernel latency +# Step 3 — remap bench, looped over NUM_SPLITS_SWEEP values +# +# Usage: +# bash remap_function_bench_topk_parallel.sh --gpu 4 +# +# # Explicit batch-size sweep: +# bash remap_function_bench_topk_parallel.sh --gpu 4 \ +# --batch-sizes "1 2 4 8" --num-splits-sweep "auto 2 4 8" +# ============================================================ +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +BENCH_DIR="${SCRIPT_DIR}/../benchmarks" + +# ── Defaults ────────────────────────────────────────────────── +GPU_ID=4 +MODEL_NAME="Qwen/Qwen3-1.7B" +TOPK_VAL=2048 +MEM=0.7 +MAX_TOTAL_TOKENS=64768 +MIN_FREE_DISK_GB=22 +ALGO="block_sparse_attention" +SAMPLE_STRIDE=1 +SEQ_LEN=32768 +BLOCK_SIZE=1 +BATCH_SIZES="1 2 4 8 16" +NUM_KV_HEADS=8 +DISTRIBUTIONS="normal bucket_uniform" +# Modes excluding 1 (LUT_CDF) and 2 (Quantile) which are discarded. +MAPPING_MODES="0 3 6 7 9 10 11 13 15 16 17 18 19" +MAPPING_HPARAM=0.5 +REPEAT=100 +WARMUP=20 +# "auto" lets bench_topk.py pick via sqrt(pages/topk). Explicit ints +# pin a split count for A/B comparisons. 
+NUM_SPLITS_SWEEP="auto 2 4 8" +REAL_HISTOGRAMS="/var/tmp/zhuominc/vortex_torch/calibration/raw_histograms_qwen3-1.7B.npy" +SKIP_AUTOTUNE=0 +PINNED_AUTOTUNE_JSON="" + +# ── Parse arguments ─────────────────────────────────────────── +while [[ $# -gt 0 ]]; do + case "$1" in + --model-name) MODEL_NAME="$2"; shift 2 ;; + --topk-val) TOPK_VAL="$2"; shift 2 ;; + --mem) MEM="$2"; shift 2 ;; + --max-total-tokens) MAX_TOTAL_TOKENS="$2"; shift 2 ;; + --min-free-disk-gb) MIN_FREE_DISK_GB="$2"; shift 2 ;; + --gpu) GPU_ID="$2"; shift 2 ;; + --algo) ALGO="$2"; shift 2 ;; + --real-histograms) REAL_HISTOGRAMS="$2"; shift 2 ;; + --sample-stride) SAMPLE_STRIDE="$2"; shift 2 ;; + --seq-len) SEQ_LEN="$2"; shift 2 ;; + --block-size|--page-size) BLOCK_SIZE="$2"; shift 2 ;; + --batch-sizes) BATCH_SIZES="$2"; shift 2 ;; + --num-kv-heads) NUM_KV_HEADS="$2"; shift 2 ;; + --distributions) DISTRIBUTIONS="$2"; shift 2 ;; + --modes) MAPPING_MODES="$2"; shift 2 ;; + --mapping-hparam) MAPPING_HPARAM="$2"; shift 2 ;; + --repeat) REPEAT="$2"; shift 2 ;; + --warmup) WARMUP="$2"; shift 2 ;; + --num-splits-sweep) NUM_SPLITS_SWEEP="$2"; shift 2 ;; + --skip-autotune) SKIP_AUTOTUNE=1; shift 1 ;; + --pinned-autotune-json) PINNED_AUTOTUNE_JSON="$2"; SKIP_AUTOTUNE=1; shift 2 ;; + *) echo "Unknown option: $1"; exit 1 ;; + esac +done + +export CUDA_VISIBLE_DEVICES="${GPU_ID}" +export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}" +export SGL_ENABLE_JIT_DEEPGEMM="${SGL_ENABLE_JIT_DEEPGEMM:-true}" + +if [ -z "${DG_JIT_NVCC_COMPILER:-}" ]; then + if [ -x /usr/local/cuda/bin/nvcc ]; then + export CUDA_HOME="${CUDA_HOME:-/usr/local/cuda}" + export PATH="${CUDA_HOME}/bin:${PATH}" + export DG_JIT_NVCC_COMPILER="${CUDA_HOME}/bin/nvcc" + elif command -v nvcc >/dev/null 2>&1; then + export DG_JIT_NVCC_COMPILER="$(command -v nvcc)" + fi +fi + +MIN_SEQ_LEN=$(( (TOPK_VAL + 4) * BLOCK_SIZE )) +if [ "${SEQ_LEN}" -lt "${MIN_SEQ_LEN}" ]; then + echo "ERROR: --seq-len ${SEQ_LEN} too small for 
--topk-val ${TOPK_VAL} @ --block-size ${BLOCK_SIZE}." + echo " Minimum: ${MIN_SEQ_LEN} (pages/seg must exceed topk_val + 3 reserved pages)" + exit 1 +fi + +RESULTS_DIR="${SCRIPT_DIR}/results" +mkdir -p "${RESULTS_DIR}" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) +MODEL_SLUG="$(echo "${MODEL_NAME}" | tr '/' '_')" +RUN_DIR="${RESULTS_DIR}/parallel_bench_${MODEL_SLUG}_topk${TOPK_VAL}_bs${BLOCK_SIZE}_${TIMESTAMP}" +mkdir -p "${RUN_DIR}" + +CALIBRATION_BASE="/var/tmp/zhuominc/vortex_torch/calibration" +MODEL_TAG="$(echo "${MODEL_NAME##*/}" | sed 's/^Q/q/')" +DEFAULT_REAL_HIST="${CALIBRATION_BASE}/raw_histograms_${MODEL_TAG}.npy" +mkdir -p "${CALIBRATION_BASE}" + +if [ -z "${REAL_HISTOGRAMS}" ] && [ -f "${DEFAULT_REAL_HIST}" ]; then + REAL_HISTOGRAMS="${DEFAULT_REAL_HIST}" +fi + +echo "============================================================" +echo "Remap Function Benchmark (Parallel TopK variant)" +echo " Model: ${MODEL_NAME}" +echo " Algorithm: ${ALGO}" +echo " TopK: ${TOPK_VAL}" +echo " Block size: ${BLOCK_SIZE}" +echo " Seq len: ${SEQ_LEN} ($(( SEQ_LEN / BLOCK_SIZE )) pages/seg)" +echo " Batch sizes: ${BATCH_SIZES}" +echo " KV heads: ${NUM_KV_HEADS}" +echo " Distributions: ${DISTRIBUTIONS}" +echo " Mapping modes: ${MAPPING_MODES}" +echo " num_splits sweep:${NUM_SPLITS_SWEEP}" +echo " GPU: ${GPU_ID}" +echo " Real histograms: ${REAL_HISTOGRAMS:-}" +echo " Output: ${RUN_DIR}" +echo "============================================================" + +# ── Step 1: Calibrate ──────────────────────────────────────── +if [ -n "${REAL_HISTOGRAMS}" ]; then + echo "" + echo ">>> Step 1: SKIPPED (using provided --real-histograms ${REAL_HISTOGRAMS})" + REAL_HIST_PATH="${REAL_HISTOGRAMS}" +else + echo "" + echo ">>> Step 1: Calibrating ${MODEL_NAME} — collecting real topk histograms" + CALIBRATION_DIR="${CALIBRATION_BASE}/staging_${MODEL_TAG}_topk${TOPK_VAL}_bs${BLOCK_SIZE}_${TIMESTAMP}" + mkdir -p "${CALIBRATION_DIR}" + python "${BENCH_DIR}/calibrate_topk.py" \ + --model-name 
"${MODEL_NAME}" \ + --topk-val "${TOPK_VAL}" \ + --page-size "${BLOCK_SIZE}" \ + --mem "${MEM}" \ + --max-total-tokens "${MAX_TOTAL_TOKENS}" \ + --min-free-disk-gb "${MIN_FREE_DISK_GB}" \ + --vortex-module-name "${ALGO}" \ + --output-dir "${CALIBRATION_DIR}" \ + 2>&1 | tee "${RUN_DIR}/step1_calibrate.log" + mv -f "${CALIBRATION_DIR}/raw_histograms.npy" "${DEFAULT_REAL_HIST}" + REAL_HIST_PATH="${DEFAULT_REAL_HIST}" + echo ">>> Step 1: Done. raw_histograms -> ${REAL_HIST_PATH}" +fi + +# ── Step 2: Autotune ───────────────────────────────────────── +AUTOTUNE_JSON="${RUN_DIR}/autotune_results.json" +if [ "${SKIP_AUTOTUNE}" -eq 1 ]; then + echo "" + if [ -n "${PINNED_AUTOTUNE_JSON}" ]; then + echo ">>> Step 2: SKIPPED (pinned hparams from ${PINNED_AUTOTUNE_JSON})" + AUTOTUNE_ARGS="--autotune-json ${PINNED_AUTOTUNE_JSON}" + else + echo ">>> Step 2: SKIPPED (using fallback --mapping-hparam ${MAPPING_HPARAM})" + AUTOTUNE_ARGS="" + fi +else + echo "" + echo ">>> Step 2: Auto-tuning hyperparameters by profiled topk kernel latency" + # Autotune on the largest batch size so the picked hparam matches realistic + # decode conditions; the hparam itself is largely batch-invariant. + FIRST_BS="$(echo ${BATCH_SIZES} | awk '{print $NF}')" + PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/autotune_topk_mapping.py" \ + --batch-size "${FIRST_BS}" \ + --num-kv-heads "${NUM_KV_HEADS}" \ + --seq-len "${SEQ_LEN}" \ + --topk-val "${TOPK_VAL}" \ + --page-size "${BLOCK_SIZE}" \ + --real-histograms "${REAL_HIST_PATH}" \ + --warmup "${WARMUP}" \ + --repeat "${REPEAT}" \ + --collect-stats \ + --output-json "${AUTOTUNE_JSON}" \ + 2>&1 | tee "${RUN_DIR}/step2_autotune.log" + echo ">>> Step 2: Done. 
Autotune results saved to ${AUTOTUNE_JSON}" + AUTOTUNE_ARGS="--autotune-json ${AUTOTUNE_JSON}" +fi + +# ── Step 3: Remap + Parallel bench, sweeping num_splits ────── +echo "" +echo ">>> Step 3: Timing baseline / fused / parallel with num_splits sweep" + +for NS in ${NUM_SPLITS_SWEEP}; do + if [ "${NS}" = "auto" ]; then + NS_ARG="--num-splits -1" + NS_TAG="auto" + else + NS_ARG="--num-splits ${NS}" + NS_TAG="ns${NS}" + fi + REMAP_JSON="${RUN_DIR}/remap_bench_${NS_TAG}.json" + LOG="${RUN_DIR}/step3_remap_bench_${NS_TAG}.log" + BENCH_EXTRA=() + [ -n "${REAL_HIST_PATH}" ] && BENCH_EXTRA+=(--real-histograms "${REAL_HIST_PATH}") + echo "" + echo "--- num_splits=${NS_TAG} ---" + PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/bench_topk.py" \ + --remap-bench \ + --bench-parallel \ + ${NS_ARG} \ + --batch-sizes ${BATCH_SIZES} \ + --num-kv-heads "${NUM_KV_HEADS}" \ + --seq-lens "${SEQ_LEN}" \ + --topk-vals "${TOPK_VAL}" \ + --page-size "${BLOCK_SIZE}" \ + --distributions ${DISTRIBUTIONS} \ + --mapping-modes ${MAPPING_MODES} \ + --mapping-hparam "${MAPPING_HPARAM}" \ + ${AUTOTUNE_ARGS} \ + "${BENCH_EXTRA[@]}" \ + --warmup "${WARMUP}" \ + --repeat "${REPEAT}" \ + --output-json "${REMAP_JSON}" \ + 2>&1 | tee "${LOG}" + echo ">>> num_splits=${NS_TAG}: JSON -> ${REMAP_JSON}" +done + +# ── Summary ─────────────────────────────────────────────────── +echo "" +echo "============================================================" +echo "Parallel TopK Benchmark Complete" +echo " Model: ${MODEL_NAME}" +echo " Block size: ${BLOCK_SIZE}" +echo " Batch sizes: ${BATCH_SIZES}" +echo " num_splits sweep: ${NUM_SPLITS_SWEEP}" +echo " All outputs in: ${RUN_DIR}/" +echo " autotune_results.json — latency-ranked mapping hparams" +echo " remap_bench_.json — per-config latencies including par_ms" +echo " step{1,2,3}_*.log — pipeline logs" +echo "============================================================" diff --git a/examples/verify_algo.py b/examples/verify_algo.py index a78f1e6..dacba65 100644 
--- a/examples/verify_algo.py +++ b/examples/verify_algo.py @@ -303,10 +303,12 @@ def parse_args(): "--topk-mapping-mode", type=int, default=0, - choices=[0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 13], + choices=[0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 13, 15, 16, 17, 18, 19, 20], help='TopK mapping mode for sglang_fused: 0=none, 1=lut_cdf (calibrated), ' '2=quantile (calibrated), 3=power, 4=log, 6=asinh, 7=log1p, 8=trunc8, ' - '9=erf, 10=tanh, 11=subtract, 13=exp_stretch (default: 0).', + '9=erf, 10=tanh, 11=subtract, 13=exp_stretch, 15=shift_pow2, ' + '16=shift_pow3, 17=linear_steep, 18=half_square, 19=half_cube, ' + '20=dense_mant (default: 0).', ) parser.add_argument( @@ -326,11 +328,20 @@ def parse_args(): "Use multiple values to run several benchmarks sequentially (default: amc23).", ) + parser.add_argument( + "--output-json", + type=str, + default=None, + help="Optional path. When set, a JSON list of per-benchmark summary dicts is " + "dumped here after all benchmarks finish. Used by the ablation wrappers.", + ) + return parser.parse_args() if __name__ == "__main__": args = parse_args() + all_summaries = [] for bench_name in args.benchmark: if bench_name not in BENCHMARK_REGISTRY: print(f"WARNING: Unknown benchmark '{bench_name}', skipping. 
Available: {list(BENCHMARK_REGISTRY.keys())}") @@ -353,6 +364,20 @@ def parse_args(): benchmark=bench_name, ) summary["benchmark"] = bench_name + summary["model_name"] = args.model_name + summary["topk_val"] = args.topk_val + summary["page_size"] = args.page_size + summary["topk_type"] = args.topk_type + summary["topk_mapping_mode"] = args.topk_mapping_mode + summary["topk_mapping_hparam"] = args.topk_mapping_hparam + summary["full_attention"] = bool(args.full_attention) print(summary) + all_summaries.append(summary) + + if args.output_json: + os.makedirs(os.path.dirname(os.path.abspath(args.output_json)) or ".", exist_ok=True) + with open(args.output_json, "w") as f: + json.dump(all_summaries, f, indent=2) + print(f"\n[verify_algo] summary JSON written to {args.output_json}") exit(0) \ No newline at end of file diff --git a/setup.py b/setup.py index c973181..8b49661 100644 --- a/setup.py +++ b/setup.py @@ -20,6 +20,7 @@ 'csrc/topk_sglang.cu', 'csrc/topk_sglang_profile.cu', 'csrc/topk_sglang_ori.cu', + 'csrc/topk_sglang_parallel.cu', ], include_dirs=['csrc'], extra_compile_args={