Skip to content

feat: add mx.fast.turboquant_attention for compressed KV cache#3340

Closed
yzamari wants to merge 8 commits intoml-explore:mainfrom
yzamari:feature/turboquant-attention
Closed

feat: add mx.fast.turboquant_attention for compressed KV cache#3340
yzamari wants to merge 8 commits intoml-explore:mainfrom
yzamari:feature/turboquant-attention

Conversation

@yzamari
Copy link
Copy Markdown

@yzamari yzamari commented Mar 31, 2026

Summary

Adds mx.fast.turboquant_attention() — a fused Metal kernel that computes attention directly from compressed KV cache data without ever decompressing it. Implements the TurboQuant algorithm (Google Research, ICLR 2026) as a native MLX operation.

This is a new additive API — no existing code is modified.

What it does

Standard attention reads full FP16 keys and values (2 bytes/element). TurboQuant compresses these to 2-bit packed data (~0.3-0.4 bytes/element) and computes attention scores + value weighted sums directly from the compressed representation in a single Metal kernel pass.

# Standard path: 16-bit keys/values → attention
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)

# TurboQuant path: 2-bit packed data → attention (no decompression)
acc, m, l = mx.fast.turboquant_attention(
    queries, k_packed, k_signs, k_norms, k_res_norms,
    centroids, v_packed, v_scales, v_zeros,
    rotation_matrix, sketch_matrix,
    scale=scale, qjl_scale=qjl_scale,
)
output = acc / l[..., None]  # normalize

Returns unnormalized (acc, max_score, sum_exp) for log-sum-exp merge with an uncompressed buffer.

Real-world inference results (M4 Pro 48GB)

Using mlx-turboquant and turboquant-bench:

Qwen2.5-3B-Instruct-4bit — generation speed at different context lengths:

Context Standard SDPA TurboQuant Speedup
4K 83.5 tok/s 111.8 tok/s 1.3x
8K 73.3 tok/s 117.6 tok/s 1.6x
16K 57.8 tok/s 113.4 tok/s 2.0x
32K 27.9 tok/s 110.4 tok/s 4.0x

Standard SDPA degrades linearly as context grows. TurboQuant stays flat (~110 tok/s) because it reads 5x less data from memory.

KV cache compression:

Tokens FP16 Cache TurboQuant 2-bit Compression
4K 512 MB 113 MB 4.5x
16K 2,048 MB 413 MB 5.0x

Implementation details

  • Two kernel variants: 1-pass (N < 1024) and 2-pass (N ≥ 1024) for different sequence lengths
  • Fused operations per token: MSE centroid lookup → QJL sign correction → online softmax update → value dequant → weighted accumulation
  • Zero intermediate allocations: everything computed in GPU registers
  • GQA support: H_q can be any multiple of H_kv
  • Head dimensions: D=64 and D=128
  • Supported dtypes: float16, bfloat16, float32

Files changed (all additive)

File Lines Description
kernels/sdpa_vector_turboquant.h ~450 Metal shader (1-pass + 2-pass)
kernels/sdpa_vector_turboquant.metal ~90 Kernel instantiation
kernels/steel/attn/params_turboquant.h ~28 Parameter struct
metal/turboquant_attention.cpp ~242 Dispatch logic
fast.cpp ~151 C++ API + validation
fast.h ~23 Declaration
fast_primitives.h ~49 Primitive class
python/src/fast.cpp ~84 Python bindings
tests/test_turboquant_attention.py ~300 16 tests
benchmarks/python/turboquant_attention_bench.py ~110 Benchmark vs SDPA

Tests

16 new tests in test_turboquant_attention.py:

  • Output shapes, dtypes, finiteness
  • GQA (H_q=8, H_kv=2)
  • Batch size > 1
  • Edge cases (N=1, N=2048 for 2-pass kernel)
  • 1-pass vs 2-pass consistency
  • float16 queries
  • Input validation (wrong rank, unsupported D, CPU rejection)

All 22 existing test_fast.py tests continue to pass (38/38 total).

Benchmarks

Included: benchmarks/python/turboquant_attention_bench.py

For full end-to-end model benchmarks: turboquant-bench

References

Yahav Zamari and others added 8 commits March 31, 2026 03:30
Native C++/Metal function that computes attention directly from
TurboQuant compressed KV data with zero intermediate allocations.
Fuses MSE score + QJL correction + value dequantization + online
softmax in a single Metal kernel.

Follows the sdpa_vector pattern: SIMD groups stride over KV tokens,
threads split head dimension D, cross-group reduction via threadgroup
memory transpose.

New files:
- Metal shader: sdpa_vector_turboquant.h (fused kernel)
- Metal instantiation: sdpa_vector_turboquant.metal
- Params struct: params_turboquant.h (TurboQuantAttnParams)
- GPU dispatch: turboquant_attention.cpp

Modified:
- fast.h: turboquant_attention declaration
- fast_primitives.h: TurboQuantAttention primitive class
- fast.cpp: C++ implementation with query rotation + validation
- python/src/fast.cpp: nanobind Python binding
- CMakeLists: build integration for Metal kernel + dispatch

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Three optimizations identified by Apple architecture and low-level
optimization expert audits:

1. Cache centroids in registers (4 floats) — eliminates random device
   memory reads for centroid lookup (15-25% MSE section improvement)
2. Hoist v_scales/v_zeros loads outside sub-loop — eliminates 4x
   redundant device memory loads per thread per token (10-15% value
   section improvement)
3. Pre-scale queries by attention scale at load time — saves 1
   multiply per token in the main loop

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Changed turboquant_attention to return 3 arrays instead of 1:
- acc: unnormalized weighted sum (B, H_q, qL, D)
- max_score: running max (B, H_q, qL)
- sum_exp: sum of exp(scores - max) (B, H_q, qL)

This enables merging with buffer portion via log-sum-exp arithmetic,
matching the existing mlx-turboquant fused decode interface.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Pass 1 splits KV tokens across multiple threadgroups (32-128 blocks),
each producing partial (acc, m, l). Pass 2 merges via log-sum-exp.

Routes to 2-pass when N >= 1024 and qL == 1 (decode path). Falls back
to 1-pass for shorter sequences or prefill.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Add batch_idx to KV base address calculation in
sdpa_vector_turboquant_2pass_1. Without this, B>1 reads wrong data.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
16 tests covering output shapes, dtypes, GQA, batch sizes, edge cases
(N=1, N=2048 for 2-pass kernel), float16 support, input validation,
and 1-pass vs 2-pass consistency. All tests pass alongside existing
22 upstream fast tests.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Run clang-format and black on all TurboQuant files.
Add turboquant_attention_bench.py comparing fused kernel vs standard SDPA.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Templatize Metal shaders on MSE_BITS and V_BITS template parameters,
generalize bit unpacking to handle multi-byte reads (4 coords at 4-bit
= 2 bytes). Supports 2-bit, 4-bit, and mixed (4-bit keys + 2-bit values).

22 tests pass (16 original 2-bit + 6 new 4-bit).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@zcbenz
Copy link
Copy Markdown
Collaborator

zcbenz commented Mar 31, 2026

Duplicate of #3328.

@zcbenz zcbenz closed this Mar 31, 2026
@yzamari
Copy link
Copy Markdown
Author

yzamari commented Mar 31, 2026

Hi @zcbenz ,

I'd like to respectfully ask you to reconsider this closure. While both PRs add TurboQuant support, they are
substantially different implementations with different scope:

#3328 implements only the PolarQuant stage (Algorithm 1) of the paper with 3-bit keys only.

▎ This PR implements the full TurboQuant algorithm — both the MSE/PolarQuant stage (Algorithm 1) and the QJL residual
correction stage (Algorithm 2), which provides unbiased attention score estimation. This is the core contribution of the
paper (Zandieh et al., ICLR 2026).

▎ Other differences:
▎ - 2-bit and 4-bit support (vs 3-bit only)
▎ - Two kernel variants: 1-pass (N < 1024) and 2-pass (N ≥ 1024) for long-context scaling
▎ - 16 tests covering GQA, edge cases, dtype validation (vs 1 test)
▎ - Purely additive — zero modifications to existing MLX code
▎ - End-to-end model benchmarks (Qwen2.5-3B tok/s)

▎ I don't think these are duplicates — #3328 implements a subset of the algorithm. Would you be open to re-evaluating both
PRs on technical merit? Happy to discuss further or adapt anything to fit MLX's conventions.

@zcbenz
Copy link
Copy Markdown
Collaborator

zcbenz commented Mar 31, 2026

Basically we don't have time to review 2k lines PRs, the correct way to implement things is to make small steps like the referenced PR.

Also we will not necessarily add fast kernel for turbo quant, it is hot but if we add AI-written kernels for every hot thing MLX would probably have 100000000000000 lines of code now.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants