Skip to content

Add 1-bit affine quantization support (Metal)#3161

Open
khosravipasha wants to merge 6 commits intoml-explore:mainfrom
PrismML-Eng:1bit-affine-quantization
Open

Add 1-bit affine quantization support (Metal)#3161
khosravipasha wants to merge 6 commits intoml-explore:mainfrom
PrismML-Eng:1bit-affine-quantization

Conversation

@khosravipasha
Copy link
Copy Markdown

Add 1-bit affine quantization support (Metal)

Proposed changes

This PR adds 1-bit support to MLX's affine quantization mode, extending the supported bit-widths from {2, 3, 4, 5, 6, 8} to {1, 2, 3, 4, 5, 6, 8}.

MLX already supports affine quantization at 2, 3, 4, 5, 6, and 8 bits via w_hat = scale * w_q + bias. This PR extends that same framework to 1-bit, adding full kernel support for 1-bit affine dequantization and quantized matmul across CPU and Metal backends.

This assumes the model has already been quantized externally (e.g. during training) — the contribution here is efficient packing and inference on Apple Silicon. It supports packing for both affine and symmetric 1-bit weights:

Affine 1-bit — weights have arbitrary per-group min/max:

scale = w_max - w_min,  bias = w_min
bit 0 → w_min,  bit 1 → w_max

Symmetric 1-bit — weights are {-d, +d} per group, automatically handled by the affine formula above since w_min = -d, w_max = +d:

scale = w_max - w_min = 2d,  bias = w_min = -d
bit 0 → 0·(2d) + (-d) = -d
bit 1 → 1·(2d) + (-d) = +d

A dedicated symmetric 1-bit mode (scale only, no bias) could save memory and skip the bias addition in the matmul kernels, but for now both cases run through the same affine path.

What's included

  • CPU backend: 1-bit quantize, dequantize, and quantized matmul (qmm dispatch)
  • Metal backend: Full 1-bit support in all quantized kernels — both non-NAX (quantized.h) and NAX (quantized_nax.h, quantized_nax.metal) paths, plus quantize/dequantize kernels
  • Python bindings: Updated mx.quantize(w, bits=1), mx.dequantize(...), and mx.quantized_matmul(...) documentation
  • Unit tests: Added 1-bit to test_quantize_dequantize, test_qmm, and a dedicated test_1bit_quantize_dequantize covering round-trip accuracy, zero handling, and quantized matmul correctness. Full test suite passes (672 tests, 0 failures).
  • No CUDA support: 1-bit is not yet supported on the CUDA backend. The CUDA dispatch_bits does not include a case 1: path. The new 1-bit test is added to cuda_skip.py.

Expected model-level performance (hypothetical, 8B parameter model, Apple M4 Pro 48 GB)

Based on the kernel-level benchmarks below, a hypothetical 8B parameter model at 1-bit would see roughly (varying by group size due to scale/bias metadata overhead):

Configuration Memory Expected Throughput
FP16 (baseline) ~15.3 GB ~15 tok/s
1-bit (group size 128) ~1.3 GB ~100–130 tok/s
1-bit (group size 64) ~1.6 GB ~90–115 tok/s
1-bit (group size 32) ~2.0 GB ~80–100 tok/s

We verified one scenario (group size 128, all weights quantized) and observed throughput in the ballpark of the estimates above. The primary purpose of this table is to give a sense of the runtime speed that 1-bit quantization enables. These are back-of-the-envelope numbers — actual end-to-end performance will vary depending on which layers are quantized, group size, attention overhead, and other non-quantized computation.

Kernel Corretness validation (KL divergence, 8B parameter model, WikiText-2)

To validate matmul kernel correctness, we compared two runs of the same 1-bit quantized model: one using the quantized matmul kernels (weights stay packed in 1-bit), and the other with the 1-bit weights dequantized to FP16 first and run through standard FP16 matmul. This is not a comparison between an FP16 model and its quantized version — both sides use identical weight values, so any divergence would indicate a kernel bug. Both the prompt processing (qmm) and token generation (qmv) paths were tested.

Prompt processing (qmm path) — 20 WikiText-2 chunks:

Metric Value
Forward KL(P||Q) Mean 0.000024
Reverse KL(Q||P) Mean 0.000017
Mean Top-1 Agreement 99.85%
Min Top-1 Agreement 99.29%

Token generation (qmv path) — 113 autoregressive steps (single-token qmv) across 5 prompts:

Metric Value
Forward KL(P||Q) Mean 0.000067
Reverse KL(Q||P) Mean -0.000038
Mean Top-1 Agreement 100.0%
Min Top-1 Agreement 100%

Both forward and reverse KL are near-zero, confirming the quantized kernels produce results consistent with the dequantized FP16 reference in both qmm and qmv code paths.

Changes

  • mlx/backend/cpu/quantized.cpp - 1-bit quantization logic and qmm dispatch
  • mlx/backend/metal/kernels/quantized.h - Metal 1-bit load_vector, qdot, qdot_safe, qouter, dequantize
  • mlx/backend/metal/kernels/quantized_nax.h - Same for NAX kernels
  • mlx/ops.cpp - Validation to accept bits=1
  • python/src/ops.cpp - Updated docstring table
  • python/tests/test_quantized.py - Added 1-bit to existing tests + dedicated 1-bit test
  • python/tests/cuda_skip.py - Skip 1-bit test on CUDA
  • benchmarks/python/comparative/bench_mlx.py - Added 1-bit entries to quant_matmul dict; auto-quantizes weight from --size args
  • benchmarks/python/comparative/compare.py - Added quant_matmul benchmark entries comparing 1/2/4/8-bit across qmv and qmm paths

Notes

  1. The Metal qmv_quad_impl kernel has a minor edge case with 1-bit when the inner dimension is < 128. In practice this should never come up — virtually all models have dimensions well above 128.

  2. If all weights in a group are exactly 0, the affine 1-bit quantization computes scale = eps (floored) and bias = 0, which dequantizes all values to near-zero (correct behavior).

  3. Kernel-level quantized_matmul benchmarks (Apple M4 Pro 48 GB, GPU, NAX path, group_size=128, 1000 calls, weight shape in parentheses):

    qmv path (M=1, single-token generation, memory-bandwidth bound):

    Layer FP16 1-bit 2-bit 4-bit 1-bit speedup vs FP16
    attn_proj (4096×4096) 167 µs 27 µs 39 µs 38 µs 6.2×
    ffn_gate (11008×4096) 391 µs 46 µs 72 µs 110 µs 8.6×
    ffn_down (4096×11008) 409 µs 59 µs 75 µs 98 µs 6.9×

    qmm path (M=32, prompt processing, more compute-bound):

    Layer FP16 1-bit 2-bit 4-bit 1-bit speedup vs FP16
    attn_proj (4096×4096) 308 µs 178 µs 179 µs 182 µs 1.7×
    ffn_gate (11008×4096) 841 µs 430 µs 438 µs 433 µs 2.0×
    ffn_down (4096×11008) 649 µs 441 µs 435 µs 440 µs 1.5×

    1-bit entries have been added to benchmarks/python/comparative/bench_mlx.py and compare.py. To reproduce (from repo root):

    # run all quant_matmul benchmarks (1/2/4/8-bit, qmv M=1, qmm M=32 & M=512)
    python benchmarks/python/comparative/compare.py --filter quant_matmul
    
    # or run individual benchmarks
    python benchmarks/python/comparative/bench_mlx.py quant_matmul_t_128_1 --size 1x4096 --size 4096x4096
    python benchmarks/python/comparative/bench_mlx.py quant_matmul_t_128_4 --size 1x4096 --size 4096x4096
    
    # unit tests
    python -m pytest python/tests/test_quantized.py::TestQuantized::test_1bit_quantize_dequantize -v
    python -m pytest python/tests/test_quantized.py::TestQuantized::test_quantize_dequantize -v
    python -m pytest python/tests/test_quantized.py::TestQuantized::test_qmm -v

Questions for reviewers

  1. NAX vs non-NAX testing: All benchmarks and the full test suite were run on macOS 26.2 (M4 Pro 48 GB), where NAX is active. The non-NAX path was partially validated by rebuilding with -DCMAKE_CXX_FLAGS=-DMLX_METAL_NO_NAX — unit tests pass, but full benchmarking was only done on the NAX path. We only have access to an M4. Is the -DMLX_METAL_NO_NAX build flag sufficient to validate the non-NAX path, or would you recommend testing on actual older hardware (M1/M2/M3)?

  2. Test coverage: The full test suite passes (672 tests, 0 failures), including dedicated 1-bit tests for both symmetric and asymmetric weight round-trip accuracy, zero handling, and quantized matmul correctness across both qmm and qmv paths. Is there any additional testing you'd like to see before merging?

Future work

  • Dedicated symmetric 1-bit mode
  • CUDA support

Checklist

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

Copilot AI review requested due to automatic review settings February 24, 2026 04:57
Copy link
Copy Markdown

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR adds 1-bit affine quantization support to MLX, extending the existing quantization framework from {2, 3, 4, 5, 6, 8} bits to include 1-bit. The implementation provides efficient packing and inference for 1-bit quantized weights on Apple Silicon (Metal) and CPU backends.

Changes:

  • Added 1-bit support to affine quantization with formula: scale = w_max - w_min, bias = w_min where bit 0 → w_min, bit 1 → w_max
  • Implemented full Metal kernel support for 1-bit in both NAX and non-NAX paths across all quantized operations (quantize, dequantize, qmm, qmv)
  • Extended CPU backend with 1-bit quantization, dequantization, and quantized matmul dispatch
  • Added comprehensive test coverage for 1-bit symmetric/asymmetric weights, zero handling, and quantized matmul correctness
  • Updated Python bindings documentation and validation to accept bits=1
  • Added 1-bit benchmark entries for performance comparison across different group sizes
  • Excluded CUDA backend from 1-bit support (added to cuda_skip.py)

Reviewed changes

Copilot reviewed 11 out of 11 changed files in this pull request and generated no comments.

Show a summary per file
File Description
python/tests/test_quantized.py Added 1-bit to existing parameterized tests and new dedicated test_1bit_quantize_dequantize with symmetric/asymmetric weights, zero handling, and qmm/qmv correctness tests
python/tests/cuda_skip.py Added 1-bit test to CUDA skip list since CUDA backend doesn't support 1-bit
python/src/ops.cpp Updated quantization mode documentation table to include 1-bit in supported bits
mlx/ops.cpp Modified validation to accept bits >= 1 and added 1-bit quantization formula (scale = w_max - w_min, bias = w_min)
mlx/backend/metal/kernels/quantized_nax.metal Added 1-bit kernel instantiation macro for NAX path
mlx/backend/metal/kernels/quantized_nax.h Implemented 1-bit versions of load_vector, load_vector_safe, qdot, qdot_safe, qouter, and dequantize for NAX optimized kernels
mlx/backend/metal/kernels/quantized.metal Added 1-bit kernel instantiation macro and quantize/dequantize logic for non-NAX path
mlx/backend/metal/kernels/quantized.h Implemented 1-bit versions of all quantization primitives for non-NAX kernels
mlx/backend/cpu/quantized.cpp Added 1-bit case to qmm dispatch and quantization logic matching Metal implementation
benchmarks/python/comparative/compare.py Added compare_mlx_quant function and 1-bit benchmark entries for qmv and qmm paths
benchmarks/python/comparative/bench_mlx.py Added 1-bit entries to quant_matmul dictionary for all group sizes and transpose modes, plus auto-quantization logic

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

@angeloskath
Copy link
Copy Markdown
Member

Hi @khosravipasha that is pretty cool. I am not sure we want to support 1-bit quants natively in MLX, even 2 bits are not really used out there.

The options I see are:

  1. You could leave it as an open PR for people to chime in if they want this or think it would be useful in any way
  2. You could make it an extension that we 'd be happy to link to at MLX Community Projects #654 and anybody could simply pip install it and use it.

@khosravipasha khosravipasha deleted the 1bit-affine-quantization branch March 9, 2026 19:50
@khosravipasha khosravipasha restored the 1bit-affine-quantization branch March 31, 2026 19:25
@khosravipasha khosravipasha reopened this Mar 31, 2026
@khosravipasha
Copy link
Copy Markdown
Author

khosravipasha commented Mar 31, 2026

@angeloskath Thanks for the comment, we actually do have a native 1-bit modelx just released today, so would be amazing to have it in the official mlx repo.

Sorry for delayed response, we were waiting to come out of stealth before resending the PR.

Checkout our launch: https://huggingface.co/prism-ml
For now we are hosting it in out public fork: https://github.com/PrismML-Eng/mlx

I made a few changes I will send a the fresh code in a bit, we also have changes on mlx-swift for iPhones.

Copy link
Copy Markdown
Member

@angeloskath angeloskath left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did a first pass and left some comments.

The main one is on the change in qmv_fast and why it is needed it isn't clear from the PR.

There is also linter errors and so on. But let's fix these and get it merged.



quant_matmul = {
"quant_matmul_32_1": partial(_quant_matmul, transpose=False, group_size=32, bits=1),
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah we probably need to simply add loops that create these.

Comment on lines +437 to +452
# Parse group_size and bits from the benchmark name, e.g.
# "quant_matmul_128_4" or "quant_matmul_t_128_4"
fn = quant_matmul[args.benchmark]
gs = fn.keywords["group_size"]
bits = fn.keywords["bits"]
transpose = fn.keywords["transpose"]

# xs[0] = activation x, xs[1] = original (float) weight matrix
# Quantize the weight internally so the caller only needs:
# --size MxK --size NxK (transpose=True) or --size MxK --size KxN
w_float = xs[1].astype(mx.float16)
w_q, scales, biases = mx.quantize(w_float, group_size=gs, bits=bits)
mx.eval(w_q, scales, biases)
x_input = xs[0].astype(mx.float16)
mx.eval(x_input)
print(bench(_quant_matmul, x_input, w_q, scales, biases, transpose, gs, bits))
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would actually leave this as it were. This makes a lot of choices that we don't want to pre-make. For instance x should be able to be bfloat16 and also float32. Similarly the arrays could potentially be transposed or anything.

I understand it is annoying to have to pass 4 arrays with the correct sizes. A better solution would be to do another if section say simple_quant_matmul_* that just delegates to quant_matmul_* and creates the quantized input itself. But the dtypes should definitely not be changed.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thats fair, this is not necessary,will revert to original for this

One question when I was testing bf16 was a lot slower than fp16 on my M4 Pro mac, is that expected?
That's why I was doing this for testing.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It depends what you were testing but on the GPU on M4 it shouldn't be. On pre M3 it should be slower indeed.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please make a repro and file an issue?

Copy link
Copy Markdown
Author

@khosravipasha khosravipasha Apr 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just did a quick test and they are within <1% of each other on the M4. I probably miss-remembered from another backend (maybe gguf), or when I used bf16 scales when packing into 1-bit.

Been juggling around many backend recently.

Screenshot 2026-04-02 at 01 03 44

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh yeah the slow down was after I pack into 1-bit. If I do bf16 then its roughly 30% slower and peak memory usage a bit higher, might be some hidden converstion to fp32 somewhere
Not really an issue for us since we went with fp16 scales

Screenshot 2026-04-02 at 16 39 46

y += tid.x * out_vec_size + out_row;

for (int k = 0; k < in_vec_size; k += block_size) {
const int aligned_end = (in_vec_size / block_size) * block_size;
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is that needed? Is something changed for the 1 bit compared to the 2 bits in the launch parameters? I am not against this addition but it is not clear why it is needed for this PR. It might generally be a better choice to route more cases to qmv_fast with an epilogue such as this instead of the plain qmv .

Either way it should be clear why that is needed and likely not in this PR.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah this I added later on, forgot to include in the PR.
Main issue was while testing the 4B model we were getting gibberish output after packing into 1-bit (packing into 2-bit was giving me good results, when packing our model into 2-bit mlx with this formula
2-bit: {-d,+d} → {0, 3} scale=2d/3, bias=-d (16 vals/uint32)

The main reason is the shapes are not divisible by 2048 for the 4B variant. But block size becomes 2048 for 1-bit the way we did it (at least if I understood correctly) so they will be some left overs not handled, the other for loop tries to handle that, for larger bit this won't be an issue as block size is less. The sizes are "hidden_size=2560 and intermediate_size=9728" neither is divisible by 2048 but they are divisible by 512 so 2-bit works okay without this

Constant Formula 1-bit 2-bit
packs_per_thread bits==2 ? 1 : 2 2 1
pack_factor 32 / bits 32 16
values_per_thread pack_factor × packs_per_thread 64 16
block_size values_per_thread × SIMD_SIZE(32) 2048 512

oh maybe we can fix this in simpler way by packs_per_thread=1 also for 1-bit?
did not notice this till now

constexpr int packs_per_thread = bits == 2 ? 1 : 2;

bcomes:

constexpr int packs_per_thread = (bits == 1 || bits == 2) ? 1 : 2

need to test correctness and speed of kernels, will try if that works can simplify here

need to think more how to generalize (does group size matter here? for us we were doing 128).

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

on second thought even packs_per_thread=1 will need things to be divisible by 1024 which still has issues for the 4B sizes.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So you kind of need to look closer into the whole kernel routing. It might warrant a more general change and this epilogue is not bad. Here is what I mean:

The qmv_fast path is selected when the input is divisible by 512. All the kernels there assume that. You can do the following:

  • Change the 1 bit kernel to work on 512 block size
  • Change the launch code to check for 2048 or 1024 divisibility for the 1 bit case
  • Change the kernel to check for remaining blocks as you have but that requires re-evaluating the launch code and removing the % 512 check. This might be a better choice overall .

In all of the above you should run some micro and macro benchmarks to evaluate perf.

Copy link
Copy Markdown
Author

@khosravipasha khosravipasha Apr 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, yeah need to take a closer look

Change the 1 bit kernel to work on 512 block size

Thought about this a bit, but even if I set packs_per_thread=1, still needs to be divisible by 1024. Need to see if can switch any other tuning params.

Change the launch code to check for 2048 or 1024 divisibility for the 1 bit case

I think tried something similar to this and then 4B model was going through the slower kernel paths, and even was slower than the 8B model, need to double check my notes.

Change the kernel to check for remaining blocks as you have but that requires re-evaluating the launch code and removing the % 512 check. This might be a better choice overall .

Yeah this could work, just want to make sure it does not affect other stuff. I mainly ran benchmarks for the 3 models we have (in 1-bit and in 16-bit). For 8B MLX 1-bit makes us 8.4x faster, and for 4B/1.7BB around 4-5x faster. The 4B is also not as fast as I was expecting, could be due to the epilogue.

Overall, which one do you think is the best option short term? Long term I think there is a lot more room for tuning these.

Screenshot 2026-04-02 at 01 11 16 Screenshot 2026-04-02 at 01 12 01

1-bit-bonsai-8b-whitepaper.pdf

@khosravipasha
Copy link
Copy Markdown
Author

Thanks for the feedback, addressed the easy ones, the main one is to figure out a good way to handle qmv_fast changes (4B model has shapes that are not divisible by 2028 nor 1024, that causes block_size for loops to have some left overs; see detail in the comment)

Had a question about NAX vs non-NAX, is there something special need to do, not too familiar what their difference is, I mainly tested in Mac M4 Pro. Want to make sure works well on older Mac (M1-M3).

@angeloskath
Copy link
Copy Markdown
Member

Had a question about NAX vs non-NAX

NAX refers to the neural accelerators for M5. I can run some benchmarks. There isn't anything overly special per se, just the matmuls are faster so we need to make sure we can keep feeding the NAX quickly enough so dequantizing should be efficient.

There could be a lot of room for tuning this but let's get something working that isn't too far from bf16 for a largeish matrix and we should be good to go.

@khosravipasha
Copy link
Copy Markdown
Author

NAX refers to the neural accelerators for M5. I can run some benchmarks. There isn't anything overly special per se, just the matmuls are faster so we need to make sure we can keep feeding the NAX quickly enough so dequantizing should be efficient.
There could be a lot of room for tuning this but let's get something working that isn't too far from bf16 for a largeish matrix and we should be good to go.

I see thanks, yeah more neural chips on M5 chips seems exciting from what I have read, have not had a chance to try them yet myself. For now mostly care about correctness, and not being very slow. Definetly can be tuned,

I saw Ivan ran speed benchmarks with M5 Max already, looks very fast
https://x.com/ivanfioravanti/status/2039077744114319461

lyonsno added a commit to lyonsno/mlx-lm that referenced this pull request Apr 4, 2026
Point at PrismML-Eng/mlx@prism which adds 1-bit affine
quantization Metal kernels, enabling Bonsai-8B-mlx-1bit
(1.2 GB Qwen3-8B) to run locally on Apple Silicon.

Temporary pin until ml-explore/mlx#3161 merges upstream.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
lyonsno added a commit to lyonsno/mlx-lm that referenced this pull request Apr 4, 2026
Based on upstream mlx-lm HEAD (Gemma 4, BatchGenerator refactor).
Points mlx dep at PrismML-Eng/mlx@prism for 1-bit affine quantization
Metal kernels. Enables Bonsai-8B-mlx-1bit locally.

Temporary pin until ml-explore/mlx#3161 merges upstream.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants