feat: add mx.fast.turboquant_attention for compressed KV cache#3340
feat: add mx.fast.turboquant_attention for compressed KV cache#3340yzamari wants to merge 8 commits intoml-explore:mainfrom
Conversation
Native C++/Metal function that computes attention directly from TurboQuant compressed KV data with zero intermediate allocations. Fuses MSE score + QJL correction + value dequantization + online softmax in a single Metal kernel. Follows the sdpa_vector pattern: SIMD groups stride over KV tokens, threads split head dimension D, cross-group reduction via threadgroup memory transpose. New files: - Metal shader: sdpa_vector_turboquant.h (fused kernel) - Metal instantiation: sdpa_vector_turboquant.metal - Params struct: params_turboquant.h (TurboQuantAttnParams) - GPU dispatch: turboquant_attention.cpp Modified: - fast.h: turboquant_attention declaration - fast_primitives.h: TurboQuantAttention primitive class - fast.cpp: C++ implementation with query rotation + validation - python/src/fast.cpp: nanobind Python binding - CMakeLists: build integration for Metal kernel + dispatch Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Three optimizations identified by Apple architecture and low-level optimization expert audits: 1. Cache centroids in registers (4 floats) — eliminates random device memory reads for centroid lookup (15-25% MSE section improvement) 2. Hoist v_scales/v_zeros loads outside sub-loop — eliminates 4x redundant device memory loads per thread per token (10-15% value section improvement) 3. Pre-scale queries by attention scale at load time — saves 1 multiply per token in the main loop Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Changed turboquant_attention to return 3 arrays instead of 1: - acc: unnormalized weighted sum (B, H_q, qL, D) - max_score: running max (B, H_q, qL) - sum_exp: sum of exp(scores - max) (B, H_q, qL) This enables merging with buffer portion via log-sum-exp arithmetic, matching the existing mlx-turboquant fused decode interface. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Pass 1 splits KV tokens across multiple threadgroups (32-128 blocks), each producing partial (acc, m, l). Pass 2 merges via log-sum-exp. Routes to 2-pass when N >= 1024 and qL == 1 (decode path). Falls back to 1-pass for shorter sequences or prefill. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Add batch_idx to KV base address calculation in sdpa_vector_turboquant_2pass_1. Without this, B>1 reads wrong data. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
16 tests covering output shapes, dtypes, GQA, batch sizes, edge cases (N=1, N=2048 for 2-pass kernel), float16 support, input validation, and 1-pass vs 2-pass consistency. All tests pass alongside existing 22 upstream fast tests. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Run clang-format and black on all TurboQuant files. Add turboquant_attention_bench.py comparing fused kernel vs standard SDPA. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Templatize Metal shaders on MSE_BITS and V_BITS template parameters, generalize bit unpacking to handle multi-byte reads (4 coords at 4-bit = 2 bytes). Supports 2-bit, 4-bit, and mixed (4-bit keys + 2-bit values). 22 tests pass (16 original 2-bit + 6 new 4-bit). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
|
Duplicate of #3328. |
|
Hi @zcbenz , I'd like to respectfully ask you to reconsider this closure. While both PRs add TurboQuant support, they are ▎ #3328 implements only the PolarQuant stage (Algorithm 1) of the paper with 3-bit keys only. ▎ This PR implements the full TurboQuant algorithm — both the MSE/PolarQuant stage (Algorithm 1) and the QJL residual ▎ Other differences: ▎ I don't think these are duplicates — #3328 implements a subset of the algorithm. Would you be open to re-evaluating both |
|
Basically we don't have time to review 2k lines PRs, the correct way to implement things is to make small steps like the referenced PR. Also we will not necessarily add fast kernel for turbo quant, it is hot but if we add AI-written kernels for every hot thing MLX would probably have 100000000000000 lines of code now. |
Summary
Adds
mx.fast.turboquant_attention()— a fused Metal kernel that computes attention directly from compressed KV cache data without ever decompressing it. Implements the TurboQuant algorithm (Google Research, ICLR 2026) as a native MLX operation.This is a new additive API — no existing code is modified.
What it does
Standard attention reads full FP16 keys and values (2 bytes/element). TurboQuant compresses these to 2-bit packed data (~0.3-0.4 bytes/element) and computes attention scores + value weighted sums directly from the compressed representation in a single Metal kernel pass.
Returns unnormalized
(acc, max_score, sum_exp)for log-sum-exp merge with an uncompressed buffer.Real-world inference results (M4 Pro 48GB)
Using mlx-turboquant and turboquant-bench:
Qwen2.5-3B-Instruct-4bit — generation speed at different context lengths:
Standard SDPA degrades linearly as context grows. TurboQuant stays flat (~110 tok/s) because it reads 5x less data from memory.
KV cache compression:
Implementation details
Files changed (all additive)
kernels/sdpa_vector_turboquant.hkernels/sdpa_vector_turboquant.metalkernels/steel/attn/params_turboquant.hmetal/turboquant_attention.cppfast.cppfast.hfast_primitives.hpython/src/fast.cpptests/test_turboquant_attention.pybenchmarks/python/turboquant_attention_bench.pyTests
16 new tests in
test_turboquant_attention.py:All 22 existing
test_fast.pytests continue to pass (38/38 total).Benchmarks
Included:
benchmarks/python/turboquant_attention_bench.pyFor full end-to-end model benchmarks: turboquant-bench
References