Add 1-bit affine quantization support (Metal) #3161
khosravipasha wants to merge 6 commits into ml-explore:main
Conversation
Pull request overview
This PR adds 1-bit affine quantization support to MLX, extending the existing quantization framework from {2, 3, 4, 5, 6, 8} bits to include 1-bit. The implementation provides efficient packing and inference for 1-bit quantized weights on Apple Silicon (Metal) and CPU backends.
Changes:
- Added 1-bit support to affine quantization with the formula `scale = w_max - w_min`, `bias = w_min`, where bit 0 → w_min and bit 1 → w_max
- Implemented full Metal kernel support for 1-bit in both NAX and non-NAX paths across all quantized operations (quantize, dequantize, qmm, qmv)
- Extended CPU backend with 1-bit quantization, dequantization, and quantized matmul dispatch
- Added comprehensive test coverage for 1-bit symmetric/asymmetric weights, zero handling, and quantized matmul correctness
- Updated Python bindings documentation and validation to accept bits=1
- Added 1-bit benchmark entries for performance comparison across different group sizes
- Excluded CUDA backend from 1-bit support (added to cuda_skip.py)
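The 1-bit affine formula above can be sketched as a plain NumPy reference (a hypothetical illustration, not the actual MLX implementation; `group_size` default and the epsilon floor are assumptions):

```python
import numpy as np

def quantize_1bit(w, group_size=32, eps=1e-7):
    # Hypothetical NumPy reference for the 1-bit affine formula:
    # scale = w_max - w_min, bias = w_min; bit 0 -> w_min, bit 1 -> w_max.
    g = w.reshape(-1, group_size)
    w_min = g.min(axis=1, keepdims=True)
    w_max = g.max(axis=1, keepdims=True)
    scale = np.maximum(w_max - w_min, eps)  # floor so all-zero groups stay valid
    bias = w_min
    # Round each value to the nearer of {w_min, w_max}.
    q = (g - w_min > scale / 2).astype(np.uint8)
    return q, scale, bias

def dequantize_1bit(q, scale, bias):
    # w_hat = scale * q + bias, so q = 0 -> w_min and q = 1 -> w_max.
    return scale * q + bias
```

For symmetric weights `{-d, +d}` the round trip is exact, since every value already equals the group min or max; an all-zero group gets `scale = eps` and `bias = 0` and dequantizes to near-zero.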
Reviewed changes
Copilot reviewed 11 out of 11 changed files in this pull request and generated no comments.
Show a summary per file
| File | Description |
|---|---|
| `python/tests/test_quantized.py` | Added 1-bit to existing parameterized tests and a new dedicated `test_1bit_quantize_dequantize` with symmetric/asymmetric weights, zero handling, and qmm/qmv correctness tests |
| `python/tests/cuda_skip.py` | Added the 1-bit test to the CUDA skip list since the CUDA backend doesn't support 1-bit |
| `python/src/ops.cpp` | Updated the quantization mode documentation table to include 1-bit in supported bits |
| `mlx/ops.cpp` | Modified validation to accept `bits >= 1` and added the 1-bit quantization formula (`scale = w_max - w_min`, `bias = w_min`) |
| `mlx/backend/metal/kernels/quantized_nax.metal` | Added a 1-bit kernel instantiation macro for the NAX path |
| `mlx/backend/metal/kernels/quantized_nax.h` | Implemented 1-bit versions of `load_vector`, `load_vector_safe`, `qdot`, `qdot_safe`, `qouter`, and `dequantize` for NAX optimized kernels |
| `mlx/backend/metal/kernels/quantized.metal` | Added a 1-bit kernel instantiation macro and quantize/dequantize logic for the non-NAX path |
| `mlx/backend/metal/kernels/quantized.h` | Implemented 1-bit versions of all quantization primitives for non-NAX kernels |
| `mlx/backend/cpu/quantized.cpp` | Added a 1-bit case to the qmm dispatch and quantization logic matching the Metal implementation |
| `benchmarks/python/comparative/compare.py` | Added a `compare_mlx_quant` function and 1-bit benchmark entries for qmv and qmm paths |
| `benchmarks/python/comparative/bench_mlx.py` | Added 1-bit entries to the `quant_matmul` dictionary for all group sizes and transpose modes, plus auto-quantization logic |
Hi @khosravipasha that is pretty cool. I am not sure we want to support 1-bit quants natively in MLX, even 2 bits are not really used out there. The options I see are:
@angeloskath Thanks for the comment, we actually do have a native 1-bit model just released today, so it would be amazing to have it in the official mlx repo. Sorry for the delayed response, we were waiting to come out of stealth before resending the PR. Check out our launch: https://huggingface.co/prism-ml I made a few changes; I will send the fresh code in a bit. We also have changes on mlx-swift for iPhones.
Force-pushed 3155f5c to 644a8cd
angeloskath left a comment:
Did a first pass and left some comments.
The main one is on the change in `qmv_fast`; why it is needed isn't clear from the PR.
There are also linter errors and so on. But let's fix these and get it merged.
```python
quant_matmul = {
    "quant_matmul_32_1": partial(_quant_matmul, transpose=False, group_size=32, bits=1),
```
Yeah we probably need to simply add loops that create these.
```python
# Parse group_size and bits from the benchmark name, e.g.
# "quant_matmul_128_4" or "quant_matmul_t_128_4"
fn = quant_matmul[args.benchmark]
gs = fn.keywords["group_size"]
bits = fn.keywords["bits"]
transpose = fn.keywords["transpose"]

# xs[0] = activation x, xs[1] = original (float) weight matrix
# Quantize the weight internally so the caller only needs:
# --size MxK --size NxK (transpose=True) or --size MxK --size KxN
w_float = xs[1].astype(mx.float16)
w_q, scales, biases = mx.quantize(w_float, group_size=gs, bits=bits)
mx.eval(w_q, scales, biases)
x_input = xs[0].astype(mx.float16)
mx.eval(x_input)
print(bench(_quant_matmul, x_input, w_q, scales, biases, transpose, gs, bits))
```
I would actually leave this as it was. This makes a lot of choices that we don't want to pre-make. For instance, x should be able to be bfloat16 and also float32. Similarly, the arrays could potentially be transposed or anything.
I understand it is annoying to have to pass 4 arrays with the correct sizes. A better solution would be to add another if section, say `simple_quant_matmul_*`, that just delegates to `quant_matmul_*` and creates the quantized input itself. But the dtypes should definitely not be changed.
That's fair, this is not necessary; I will revert to the original for this.
One question: when I was testing, bf16 was a lot slower than fp16 on my M4 Pro Mac. Is that expected? That's why I was doing this for testing.
It depends what you were testing, but on the GPU on M4 it shouldn't be. On pre-M3 it should indeed be slower.
Can you please make a repro and file an issue?
```cpp
y += tid.x * out_vec_size + out_row;

for (int k = 0; k < in_vec_size; k += block_size) {
  const int aligned_end = (in_vec_size / block_size) * block_size;
```
Why is that needed? Has something changed for 1-bit compared to 2-bit in the launch parameters? I am not against this addition, but it is not clear why it is needed for this PR. It might generally be a better choice to route more cases to `qmv_fast` with an epilogue such as this instead of the plain `qmv`.
Either way, it should be clear why that is needed, and likely not in this PR.
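The aligned-end epilogue under discussion follows a standard blocked-loop pattern; a minimal Python sketch (illustrative only, not the Metal kernel):

```python
def blocked_sum(xs, block_size):
    # Main loop walks only the block-aligned prefix, mirroring the
    # `aligned_end = (in_vec_size / block_size) * block_size` computation.
    aligned_end = (len(xs) // block_size) * block_size
    total = 0
    for k in range(0, aligned_end, block_size):
        total += sum(xs[k:k + block_size])
    # Epilogue: handle the leftover tail when len(xs) % block_size != 0.
    total += sum(xs[aligned_end:])
    return total
```

Without the epilogue, any elements past the last full block would simply be skipped, which is the kind of silent truncation being debated here.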
Yeah, I added this later on and forgot to include it in the PR.
The main issue was that while testing the 4B model we were getting gibberish output after packing into 1-bit (packing into 2-bit was giving me good results, when packing our model into 2-bit mlx with this formula:
2-bit: {-d, +d} → {0, 3}, scale = 2d/3, bias = -d (16 vals/uint32)).
The main reason is that the shapes are not divisible by 2048 for the 4B variant. But the block size becomes 2048 for 1-bit the way we did it (at least if I understood correctly), so there will be some leftovers not handled; the other for loop tries to handle that. For larger bit-widths this won't be an issue as the block size is smaller. The sizes are hidden_size=2560 and intermediate_size=9728; neither is divisible by 2048, but they are divisible by 512, so 2-bit works okay without this.
| Constant | Formula | 1-bit | 2-bit |
|---|---|---|---|
| `packs_per_thread` | `bits == 2 ? 1 : 2` | 2 | 1 |
| `pack_factor` | `32 / bits` | 32 | 16 |
| `values_per_thread` | `pack_factor × packs_per_thread` | 64 | 16 |
| `block_size` | `values_per_thread × SIMD_SIZE (32)` | 2048 | 512 |
Oh, maybe we can fix this in a simpler way with `packs_per_thread = 1` also for 1-bit? I did not notice this till now.

```cpp
constexpr int packs_per_thread = bits == 2 ? 1 : 2;
```

becomes:

```cpp
constexpr int packs_per_thread = (bits == 1 || bits == 2) ? 1 : 2;
```

I need to test correctness and speed of the kernels; will try, and if that works it can simplify things here. Need to think more about how to generalize (does group size matter here? For us we were doing 128).
On second thought, even `packs_per_thread = 1` will need things to be divisible by 1024, which still has issues for the 4B sizes.
So you kind of need to look closer into the whole kernel routing. It might warrant a more general change, and this epilogue is not bad. Here is what I mean:
The `qmv_fast` path is selected when the input is divisible by 512. All the kernels there assume that. You can do the following:
- Change the 1-bit kernel to work on a 512 block size
- Change the launch code to check for 2048 or 1024 divisibility for the 1-bit case
- Change the kernel to check for remaining blocks as you have, but that requires re-evaluating the launch code and removing the % 512 check. This might be a better choice overall.
In all of the above you should run some micro and macro benchmarks to evaluate perf.
Thanks, yeah, I need to take a closer look.

> Change the 1 bit kernel to work on 512 block size

Thought about this a bit, but even if I set `packs_per_thread = 1`, it still needs to be divisible by 1024. Need to see if I can switch any other tuning params.

> Change the launch code to check for 2048 or 1024 divisibility for the 1 bit case

I think I tried something similar to this and then the 4B model was going through the slower kernel paths, and was even slower than the 8B model; I need to double-check my notes.

> Change the kernel to check for remaining blocks as you have but that requires re-evaluating the launch code and removing the % 512 check. This might be a better choice overall.

Yeah, this could work; I just want to make sure it does not affect other stuff. I mainly ran benchmarks for the 3 models we have (in 1-bit and in 16-bit). For the 8B model, MLX 1-bit makes us 8.4x faster, and for 4B/1.7B around 4-5x faster. The 4B is also not as fast as I was expecting, which could be due to the epilogue.
Overall, which one do you think is the best option short term? Long term I think there is a lot more room for tuning these.
Thanks for the feedback, I addressed the easy ones. The main remaining one is to figure out a good way to handle the `qmv_fast` changes (the 4B model has shapes that are not divisible by 2048 nor 1024, which causes the block_size for loops to have some leftovers; see details in the comment above). I had a question about NAX vs non-NAX: is there something special I need to do? I'm not too familiar with the difference. I mainly tested on a Mac M4 Pro and want to make sure it works well on older Macs (M1-M3).
NAX refers to the neural accelerators on the M5. I can run some benchmarks. There isn't anything overly special per se; the matmuls are just faster, so we need to make sure we can keep feeding the NAX quickly enough, which means dequantizing should be efficient. There could be a lot of room for tuning this, but let's get something working that isn't too far from bf16 for a largeish matrix and we should be good to go.
I see, thanks. Yeah, the neural accelerators on the M5 chips seem exciting from what I have read; I have not had a chance to try them myself yet. For now I mostly care about correctness and not being very slow. It can definitely be tuned; I saw Ivan already ran speed benchmarks with the M5 Max, and it looks very fast.
Point at PrismML-Eng/mlx@prism which adds 1-bit affine quantization Metal kernels, enabling Bonsai-8B-mlx-1bit (1.2 GB Qwen3-8B) to run locally on Apple Silicon. Temporary pin until ml-explore/mlx#3161 merges upstream. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Based on upstream mlx-lm HEAD (Gemma 4, BatchGenerator refactor). Points mlx dep at PrismML-Eng/mlx@prism for 1-bit affine quantization Metal kernels. Enables Bonsai-8B-mlx-1bit locally. Temporary pin until ml-explore/mlx#3161 merges upstream. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>


Add 1-bit affine quantization support (Metal)
Proposed changes
This PR adds 1-bit support to MLX's affine quantization mode, extending the supported bit-widths from `{2, 3, 4, 5, 6, 8}` to `{1, 2, 3, 4, 5, 6, 8}`.

MLX already supports affine quantization at 2, 3, 4, 5, 6, and 8 bits via `w_hat = scale * w_q + bias`. This PR extends that same framework to 1-bit, adding full kernel support for 1-bit affine dequantization and quantized matmul across CPU and Metal backends. This assumes the model has already been quantized externally (e.g. during training) — the contribution here is efficient packing and inference on Apple Silicon. It supports packing for both affine and symmetric 1-bit weights:

Affine 1-bit — weights have arbitrary per-group min/max:

Symmetric 1-bit — weights are `{-d, +d}` per group, automatically handled by the affine formula above since `w_min = -d, w_max = +d`:

A dedicated symmetric 1-bit mode (scale only, no bias) could save memory and skip the bias addition in the matmul kernels, but for now both cases run through the same affine path.
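The symmetric case reduces to the affine formula as follows (a worked example; d = 0.25 is an arbitrary value chosen so the arithmetic is exact in binary floating point):

```python
d = 0.25                  # symmetric weights are {-d, +d} per group
w_min, w_max = -d, +d
scale = w_max - w_min     # = 2d = 0.5
bias = w_min              # = -d
# w_hat = scale * q + bias: q = 0 -> -d, q = 1 -> +d, so the
# symmetric round trip is exact under the affine path.
assert scale * 0 + bias == -d
assert scale * 1 + bias == +d
```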
What's included
- CPU backend (`qmm` dispatch)
- Metal kernels for both the standard (`quantized.h`) and NAX (`quantized_nax.h`, `quantized_nax.metal`) paths, plus quantize/dequantize kernels
- `mx.quantize(w, bits=1)`, `mx.dequantize(...)`, and `mx.quantized_matmul(...)` documentation
- `test_quantize_dequantize`, `test_qmm`, and a dedicated `test_1bit_quantize_dequantize` covering round-trip accuracy, zero handling, and quantized matmul correctness. Full test suite passes (672 tests, 0 failures).
- CUDA: `dispatch_bits` does not include a `case 1:` path. The new 1-bit test is added to `cuda_skip.py`.

Expected model-level performance (hypothetical, 8B parameter model, Apple M4 Pro 48 GB)
Based on the kernel-level benchmarks below, a hypothetical 8B parameter model at 1-bit would see roughly (varying by group size due to scale/bias metadata overhead):
We verified one scenario (group size 128, all weights quantized) and observed throughput in the ballpark of the estimates above. The primary purpose of this table is to give a sense of the runtime speed that 1-bit quantization enables. These are back-of-the-envelope numbers — actual end-to-end performance will vary depending on which layers are quantized, group size, attention overhead, and other non-quantized computation.
Kernel correctness validation (KL divergence, 8B parameter model, WikiText-2)
To validate matmul kernel correctness, we compared two runs of the same 1-bit quantized model: one using the quantized matmul kernels (weights stay packed in 1-bit), and the other with the 1-bit weights dequantized to FP16 first and run through standard FP16 matmul. This is not a comparison between an FP16 model and its quantized version — both sides use identical weight values, so any divergence would indicate a kernel bug. Both the prompt processing (qmm) and token generation (qmv) paths were tested.
Prompt processing (qmm path) — 20 WikiText-2 chunks:
Token generation (qmv path) — 113 autoregressive steps (single-token qmv) across 5 prompts:
Both forward and reverse KL are near-zero, confirming the quantized kernels produce results consistent with the dequantized FP16 reference in both qmm and qmv code paths.
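The forward/reverse KL numbers above can be computed from the two runs' logits with a routine like this (a generic NumPy sketch of the metric, not the script actually used for the validation):

```python
import numpy as np

def softmax(logits):
    # Numerically stable softmax over the last axis.
    e = np.exp(logits - logits.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def kl_divergence(p_logits, q_logits, eps=1e-12):
    # KL(p || q) per row. Forward KL uses the dequantized FP16 reference
    # as p and the quantized-kernel run as q; reverse KL swaps them.
    p, q = softmax(p_logits), softmax(q_logits)
    return np.sum(p * (np.log(p + eps) - np.log(q + eps)), axis=-1)
```

Identical logits give (numerically) zero KL, so any kernel bug shows up as a nonzero divergence.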
Changes
- `mlx/backend/cpu/quantized.cpp` - 1-bit quantization logic and `qmm` dispatch
- `mlx/backend/metal/kernels/quantized.h` - Metal 1-bit `load_vector`, `qdot`, `qdot_safe`, `qouter`, `dequantize`
- `mlx/backend/metal/kernels/quantized_nax.h` - Same for NAX kernels
- `mlx/ops.cpp` - Validation to accept `bits=1`
- `python/src/ops.cpp` - Updated docstring table
- `python/tests/test_quantized.py` - Added 1-bit to existing tests + dedicated 1-bit test
- `python/tests/cuda_skip.py` - Skip 1-bit test on CUDA
- `benchmarks/python/comparative/bench_mlx.py` - Added 1-bit entries to the `quant_matmul` dict; auto-quantizes weight from `--size` args
- `benchmarks/python/comparative/compare.py` - Added `quant_matmul` benchmark entries comparing 1/2/4/8-bit across qmv and qmm paths
The Metal `qmv_quad_impl` kernel has a minor edge case with 1-bit when the inner dimension is < 128. In practice this should never come up — virtually all models have dimensions well above 128.

If all weights in a group are exactly 0, the affine 1-bit quantization computes `scale = eps` (floored) and `bias = 0`, which dequantizes all values to near-zero (correct behavior).

Kernel-level `quantized_matmul` benchmarks (Apple M4 Pro 48 GB, GPU, NAX path, `group_size=128`, 1000 calls, weight shape in parentheses):
qmm path (M=32, prompt processing, more compute-bound):
1-bit entries have been added to `benchmarks/python/comparative/bench_mlx.py` and `compare.py`. To reproduce (from repo root):
NAX vs non-NAX testing: All benchmarks and the full test suite were run on macOS 26.2 (M4 Pro 48 GB), where NAX is active. The non-NAX path was partially validated by rebuilding with `-DCMAKE_CXX_FLAGS=-DMLX_METAL_NO_NAX` — unit tests pass, but full benchmarking was only done on the NAX path. We only have access to an M4. Is the `-DMLX_METAL_NO_NAX` build flag sufficient to validate the non-NAX path, or would you recommend testing on actual older hardware (M1/M2/M3)?

Test coverage: The full test suite passes (672 tests, 0 failures), including dedicated 1-bit tests for both symmetric and asymmetric weight round-trip accuracy, zero handling, and quantized matmul correctness across both qmm and qmv paths. Is there any additional testing you'd like to see before merging?
Future work
Checklist
- I have run `pre-commit run --all-files` to format my code / installed pre-commit prior to committing changes