
Add TurboQuant KV cache compression with native Metal SDPA kernel#3328

Open
arozanov wants to merge 6 commits into ml-explore:main from arozanov:feature/turboquant-kv-cache

Conversation


@arozanov arozanov commented Mar 28, 2026

Proposed changes

Adds TurboQuant (arXiv 2504.19874) as a native Metal SDPA kernel for KV cache compression.

  • New QuantizationMode::TurboQuant
  • sdpa_vector_turbo Metal kernel: reads 3-bit packed K indices with codebook dequant
  • Pre-rotated query optimization: no WHT butterfly in attention inner loop
  • TurboQuantSDPA primitive with full eval_gpu dispatch
  • Python API: mx.fast.turboquant_sdpa()
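The pre-rotated query optimization works because the Walsh–Hadamard transform is orthogonal: rotating both Q and K preserves dot products, so the query can be rotated once before the decode loop instead of running a WHT butterfly against every key. A minimal pure-Python sketch of that identity (illustrative only, not the kernel code):

```python
import math
import random

def fwht(v):
    """Orthonormal fast Walsh-Hadamard transform via the butterfly network."""
    v = list(v)
    n = len(v)  # must be a power of two
    h = 1
    while h < n:
        for i in range(0, n, h * 2):
            for j in range(i, i + h):
                x, y = v[j], v[j + h]
                v[j], v[j + h] = x + y, x - y
        h *= 2
    scale = 1.0 / math.sqrt(n)
    return [x * scale for x in v]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

random.seed(0)
q = [random.gauss(0, 1) for _ in range(128)]
k = [random.gauss(0, 1) for _ in range(128)]

# Orthonormality means <fwht(q), fwht(k)> == <q, k>: rotate Q once up
# front, and scores against WHT-rotated (then quantized) keys need no
# butterfly in the attention inner loop.
assert abs(dot(fwht(q), fwht(k)) - dot(q, k)) < 1e-9
```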

Benchmarks (M4 Pro 48GB, 28 query heads, 4 KV heads, D=128):

Context   Apple SDPA   TurboQuant   Speedup
1K        0.161 ms     0.108 ms     1.5x
4K        0.164 ms     0.109 ms     1.5x
8K        0.209 ms     0.106 ms     2.0x
16K       0.511 ms     0.104 ms     4.9x

TurboQuant reads 3-bit packed K data, using about 4.6x less memory bandwidth than fp16. Kernel time stays nearly constant at ~0.1 ms regardless of context length.
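A raw 16-bit vs 3-bit comparison would give about 5.3x, so the 4.6x figure presumably includes some per-element metadata cost. A back-of-the-envelope check (the ~0.5 bit/element overhead for norms/codebook metadata is my assumption, not a number from the PR):

```python
# Bandwidth-ratio sanity check. overhead_bits is an assumed amortized
# cost of per-block scales / norms, chosen only to illustrate why the
# effective ratio lands below the raw 16/3.
fp16_bits = 16.0
index_bits = 3.0
overhead_bits = 0.5

raw_ratio = fp16_bits / index_bits                           # ~5.33x, indices alone
effective_ratio = fp16_bits / (index_bits + overhead_bits)   # ~4.57x, near the 4.6x claimed
print(round(raw_ratio, 2), round(effective_ratio, 2))        # 5.33 4.57
```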

Related: standalone package at https://github.com/arozanov/turboquant-mlx and mlx-lm PR at ml-explore/mlx-lm#1067

Checklist

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

Adds TurboQuant (arXiv 2504.19874) as a new quantization mode for
KV cache compression in MLX core.

Changes:
- QuantizationMode::TurboQuant enum + string conversion
- sdpa_vector_turbo Metal kernel: reads bit-packed uint32 K indices
  with codebook dequant, pre-rotated query optimization (no WHT
  in inner loop). Instantiated for fp16/bf16 x 64/128 dim x 3/4 bit.
- C++ dispatch function sdpa_vector_turbo() in SDPA backend
- Python binding mx.fast.turboquant_sdpa()
- CMake fix: removed -sdk macosx from xcrun metal invocation
  (Metal Toolchain installed via xcodebuild -downloadComponent)
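As an illustration of the bit-packed layout such a kernel consumes, here is a pure-Python model of 3-bit index packing with codebook dequant. The exact word layout in the Metal kernel may differ; 10 indices per uint32 with 2 bits unused is an assumption for this sketch:

```python
def pack3(indices):
    """Pack 3-bit indices into 32-bit words, 10 per word (2 bits unused)."""
    words = []
    for base in range(0, len(indices), 10):
        w = 0
        for slot, idx in enumerate(indices[base:base + 10]):
            assert 0 <= idx < 8  # 3-bit index range
            w |= idx << (3 * slot)
        words.append(w)
    return words

def unpack3(words, n):
    """Recover the first n 3-bit indices from packed 32-bit words."""
    out = []
    for w in words:
        for slot in range(10):
            out.append((w >> (3 * slot)) & 0x7)
    return out[:n]

# Toy 8-entry codebook; the kernel would look dequantized values up here.
codebook = [-1.5, -1.0, -0.5, -0.1, 0.1, 0.5, 1.0, 1.5]

idx = [3, 7, 0, 5, 2, 6, 1, 4, 3, 3, 7, 2]
words = pack3(idx)
assert unpack3(words, len(idx)) == idx          # lossless round trip
dequant = [codebook[i] for i in unpack3(words, len(idx))]
```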

Status: Metal kernel compiled and instantiated. C++ dispatch ready.
Python binding exposed. Currently falls back to regular SDPA —
full native dispatch needs TurboQuantSDPA Primitive subclass
to wire eval_gpu to the turbo kernel.
- TurboQuantSDPA primitive class in fast_primitives.h
- eval_gpu() routes to sdpa_vector_turbo Metal kernel
- Full pipeline: Python mx.fast.turboquant_sdpa() → C++ → Metal
- Pre-rotated query: no WHT butterfly in attention inner loop
- Kernel reads bit-packed uint32 K indices + codebook directly
Native Metal kernel benchmarks (turbo kernel time as a fraction of standard SDPA time, lower is better):
  256 tokens: 0.83x
  1K tokens:  0.71x (turbo faster)
  4K tokens:  0.49x (turbo ~2x faster)

TurboQuant reads 3-bit packed data, so it needs far less memory bandwidth than fp16.
Native Metal kernel benchmarks, speedup vs standard SDPA (28 query heads, 4 KV heads, D=128):
  256 tokens:  0.8x (slower; fixed overhead dominates at short contexts)
  1K tokens:   1.5x faster
  4K tokens:   1.5x faster
  8K tokens:   2.0x faster
  16K tokens:  4.9x faster

TurboQuant kernel stays at ~0.1ms regardless of context length.
Apple SDPA grows linearly with context (memory bandwidth limited).

Changes:
- Proper buffer allocation with donation in eval_gpu
- Contiguous copy handling
- CPU fallback for non-GPU paths
@yzamari

This comment was marked as off-topic.

@zcbenz zcbenz (Collaborator) commented Mar 31, 2026

Thanks for the PR and I think it is nicely written.

But I think we need to have a generic quantized sdpa kernel first and then add things like turbo quant as part of the API, rather than a specialized turbo quant kernel directly. #3026 is a nice step but as you can see we don't quite have enough maintainers to review.

On turbo quant itself, we don't just add builtin kernels for every new hot thing, users should use custom extensions or custom kernels instead.

@arozanov (Author)

Makes sense. Happy to help with #3026 if useful - the packed-integer codepath here (3-bit indices + codebook lookup) could serve as a reference for one of the quantization modes a generic API would need.
I'll package this as a custom Metal kernel extension via mx.fast.metal_kernel() in the meantime.
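For reference, the custom-extension route could look roughly like the sketch below with mx.fast.metal_kernel. The kernel body is a trivial element-wise stand-in, not the TurboQuant dequant+SDPA body, and construction is wrapped in a function so MLX (Apple silicon only) is required just when it is called:

```python
# Metal source body for mx.fast.metal_kernel: MLX generates the kernel
# signature, so only the body is supplied. A doubling op stands in for
# the real dequantize-and-attend logic.
SOURCE = """
    uint elem = thread_position_in_grid.x;
    out[elem] = 2.0 * a[elem];
"""

def make_double_kernel():
    import mlx.core as mx  # deferred: requires MLX on Apple silicon

    kernel = mx.fast.metal_kernel(
        name="double",
        input_names=["a"],
        output_names=["out"],
        source=SOURCE,
    )

    def double(a):
        (out,) = kernel(
            inputs=[a],
            grid=(a.size, 1, 1),
            threadgroup=(min(a.size, 256), 1, 1),
            output_shapes=[a.shape],
            output_dtypes=[a.dtype],
        )
        return out

    return double
```

A packaged TurboQuant extension would follow the same shape: packed-index and codebook arrays as inputs, the attention output as the single output.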

@deceptech-packet-ninja

Bug: N reads wrong dimension of k_norms

While testing this kernel locally (built from source), I found that line 450 in scaled_dot_product_attention.cpp:

int N = k_norms.shape(1); // kv sequence length

reads shape(1), which is the number of KV heads (H_kv), not the sequence length (T). For k_norms of shape (B, H_kv, T), this should be:

int N = k_norms.shape(2); // kv sequence length

With shape(1), the kernel only processes H_kv keys (e.g., 8 instead of 500), causing incorrect attention output. After changing to shape(2), the kernel produces exact match (1.000 cosine similarity) against the reference mx.fast.scaled_dot_product_attention.
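The off-by-one-dimension mistake is easy to reproduce in miniature (shapes taken from the report above):

```python
# k_norms has shape (B, H_kv, T): batch, number of KV heads, sequence length.
k_norms_shape = (1, 8, 500)

buggy_N = k_norms_shape[1]  # 8   -> only H_kv keys ever get attended
fixed_N = k_norms_shape[2]  # 500 -> the full KV sequence length

assert buggy_N == 8 and fixed_N == 500
```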

Kernel benchmarks after the fix (M-series, Llama-3-8B config: B=1, H_q=32, H_kv=8, D=128):

Context   Standard SDPA   TurboQuant SDPA   Speedup
256       0.43 ms         0.44 ms           1.0x
512       0.49 ms         0.22 ms           2.2x
1024      0.52 ms         0.20 ms           2.6x
4096      1.91 ms         0.30 ms           6.3x

The kernel is impressively fast — essentially constant time regardless of context length.

@arozanov (Author)

Good catch, thanks! Fixed: k_norms.shape(1) → k_norms.shape(2). The shape is (B, H_kv, T) so dimension 2 is the sequence length.

4 participants