Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions benchmarks/python/comparative/bench_mlx.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,17 @@ def _quant_matmul(x, w, s, b, transpose, group_size, bits):


quant_matmul = {
"quant_matmul_32_1": partial(_quant_matmul, transpose=False, group_size=32, bits=1),
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah we probably need to simply add loops that create these.

"quant_matmul_32_2": partial(_quant_matmul, transpose=False, group_size=32, bits=2),
"quant_matmul_32_4": partial(_quant_matmul, transpose=False, group_size=32, bits=4),
"quant_matmul_32_8": partial(_quant_matmul, transpose=False, group_size=32, bits=8),
"quant_matmul_64_1": partial(_quant_matmul, transpose=False, group_size=64, bits=1),
"quant_matmul_64_2": partial(_quant_matmul, transpose=False, group_size=64, bits=2),
"quant_matmul_64_4": partial(_quant_matmul, transpose=False, group_size=64, bits=4),
"quant_matmul_64_8": partial(_quant_matmul, transpose=False, group_size=64, bits=8),
"quant_matmul_128_1": partial(
_quant_matmul, transpose=False, group_size=128, bits=1
),
"quant_matmul_128_2": partial(
_quant_matmul, transpose=False, group_size=128, bits=2
),
Expand All @@ -87,6 +92,9 @@ def _quant_matmul(x, w, s, b, transpose, group_size, bits):
"quant_matmul_128_8": partial(
_quant_matmul, transpose=False, group_size=128, bits=8
),
"quant_matmul_t_32_1": partial(
_quant_matmul, transpose=True, group_size=32, bits=1
),
"quant_matmul_t_32_2": partial(
_quant_matmul, transpose=True, group_size=32, bits=2
),
Expand All @@ -96,6 +104,9 @@ def _quant_matmul(x, w, s, b, transpose, group_size, bits):
"quant_matmul_t_32_8": partial(
_quant_matmul, transpose=True, group_size=32, bits=8
),
"quant_matmul_t_64_1": partial(
_quant_matmul, transpose=True, group_size=64, bits=1
),
"quant_matmul_t_64_2": partial(
_quant_matmul, transpose=True, group_size=64, bits=2
),
Expand All @@ -105,6 +116,9 @@ def _quant_matmul(x, w, s, b, transpose, group_size, bits):
"quant_matmul_t_64_8": partial(
_quant_matmul, transpose=True, group_size=64, bits=8
),
"quant_matmul_t_128_1": partial(
_quant_matmul, transpose=True, group_size=128, bits=1
),
"quant_matmul_t_128_2": partial(
_quant_matmul, transpose=True, group_size=128, bits=2
),
Expand Down
34 changes: 24 additions & 10 deletions mlx/backend/cpu/quantized.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,10 @@ void _qmm_dispatch_typed(
int bits,
bool transposed_w) {
switch (bits) {
case 1:
_qmm_dispatch_group<T, 1>(
result, x, w, scales, biases, M, N, K, group_size, transposed_w);
break;
case 2:
_qmm_dispatch_group<T, 2>(
result, x, w, scales, biases, M, N, K, group_size, transposed_w);
Expand All @@ -376,7 +380,8 @@ void _qmm_dispatch_typed(
result, x, w, scales, biases, M, N, K, group_size, transposed_w);
break;
default:
throw std::invalid_argument("Quantization bits must be 2, 3, 4, 6 or 8.");
throw std::invalid_argument(
"Quantization bits must be 1, 2, 3, 4, 5, 6 or 8.");
}
}

Expand Down Expand Up @@ -1172,15 +1177,24 @@ void quantize(
w_min = std::min(w_min, (float)w[w_idx + j]);
}
bool mask = std::abs(w_min) > std::abs(w_max);
float scale = std::max((w_max - w_min) / n_bins, eps);
scale = mask ? scale : -scale;

float edge = mask ? w_min : w_max;
float q0 = std::rint(edge / scale);
float bias = 0;
if (q0 != 0) {
scale = edge / q0;
bias = edge;
float scale;
float bias;

if (bits == 1) {
// Affine 1-bit: bit 0 -> w_min, bit 1 -> w_max
scale = std::max(w_max - w_min, eps);
bias = w_min;
} else {
scale = std::max((w_max - w_min) / n_bins, eps);
scale = mask ? scale : -scale;

float edge = mask ? w_min : w_max;
float q0 = std::rint(edge / scale);
bias = 0;
if (q0 != 0) {
scale = edge / q0;
bias = edge;
}
}
size_t out_idx = i * int_per_group;
for (int j = 0; j < int_per_group / bytes_per_pack; ++j) {
Expand Down
202 changes: 165 additions & 37 deletions mlx/backend/metal/kernels/quantized.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,28 @@ inline constexpr short get_bytes_per_pack() {
template <typename T, typename U, int values_per_thread, int bits>
inline U load_vector(const device T* x, thread U* x_thread) {
static_assert(
bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
bits == 8,
"Template undefined for bits not in {2, 3, 4, 5, 6, 8}");
bits == 1 || bits == 2 || bits == 3 || bits == 4 || bits == 5 ||
bits == 6 || bits == 8,
"Template undefined for bits not in {1, 2, 3, 4, 5, 6, 8}");

U sum = 0;

if (bits == 2) {
if (bits == 1) {
for (int i = 0; i < values_per_thread; i += 8) {
sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] +
x[i + 6] + x[i + 7];
x_thread[i] = x[i];
x_thread[i + 1] = x[i + 1];
x_thread[i + 2] = x[i + 2];
x_thread[i + 3] = x[i + 3];
x_thread[i + 4] = x[i + 4];
x_thread[i + 5] = x[i + 5];
x_thread[i + 6] = x[i + 6];
x_thread[i + 7] = x[i + 7];
}
}

else if (bits == 2) {
for (int i = 0; i < values_per_thread; i += 4) {
sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
x_thread[i] = x[i];
Expand Down Expand Up @@ -107,13 +122,28 @@ inline U load_vector(const device T* x, thread U* x_thread) {
template <typename T, typename U, int values_per_thread, int bits>
inline U load_vector_safe(const device T* x, thread U* x_thread, int N) {
static_assert(
bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
bits == 8,
"Template undefined for bits not in {2, 3, 4, 5, 6, 8}");
bits == 1 || bits == 2 || bits == 3 || bits == 4 || bits == 5 ||
bits == 6 || bits == 8,
"Template undefined for bits not in {1, 2, 3, 4, 5, 6, 8}");

U sum = 0;

if (bits == 2) {
if (bits == 1) {
for (int i = 0; i < N; i += 8) {
sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] +
x[i + 6] + x[i + 7];
x_thread[i] = x[i];
x_thread[i + 1] = x[i + 1];
x_thread[i + 2] = x[i + 2];
x_thread[i + 3] = x[i + 3];
x_thread[i + 4] = x[i + 4];
x_thread[i + 5] = x[i + 5];
x_thread[i + 6] = x[i + 6];
x_thread[i + 7] = x[i + 7];
}
}

else if (bits == 2) {
for (int i = 0; i < N; i += 4) {
sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
x_thread[i] = x[i];
Expand Down Expand Up @@ -196,13 +226,27 @@ inline U qdot(
U bias,
U sum) {
static_assert(
bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
bits == 8,
"Template undefined for bits not in {2, 3, 4, 5, 6, 8}");
bits == 1 || bits == 2 || bits == 3 || bits == 4 || bits == 5 ||
bits == 6 || bits == 8,
"Template undefined for bits not in {1, 2, 3, 4, 5, 6, 8}");

U accum = 0;

if (bits == 2) {
if (bits == 1) {
for (int i = 0; i < (values_per_thread / 8); i++) {
uint8_t wb = w[i];
accum += select(U(0), x_thread[8 * i], bool(wb & 0x01));
accum += select(U(0), x_thread[8 * i + 1], bool(wb & 0x02));
accum += select(U(0), x_thread[8 * i + 2], bool(wb & 0x04));
accum += select(U(0), x_thread[8 * i + 3], bool(wb & 0x08));
accum += select(U(0), x_thread[8 * i + 4], bool(wb & 0x10));
accum += select(U(0), x_thread[8 * i + 5], bool(wb & 0x20));
accum += select(U(0), x_thread[8 * i + 6], bool(wb & 0x40));
accum += select(U(0), x_thread[8 * i + 7], bool(wb & 0x80));
}
}

else if (bits == 2) {
for (int i = 0; i < (values_per_thread / 4); i++) {
accum +=
(x_thread[4 * i] * (w[i] & 0x03) +
Expand Down Expand Up @@ -298,13 +342,27 @@ inline U qdot_safe(
U sum,
int N) {
static_assert(
bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
bits == 8,
"Template undefined for bits not in {2, 3, 4, 5, 6, 8}");
bits == 1 || bits == 2 || bits == 3 || bits == 4 || bits == 5 ||
bits == 6 || bits == 8,
"Template undefined for bits not in {1, 2, 3, 4, 5, 6, 8}");

U accum = 0;

if (bits == 2) {
if (bits == 1) {
for (int i = 0; i < (N / 8); i++) {
uint8_t wb = w[i];
accum += select(U(0), x_thread[8 * i], bool(wb & 0x01));
accum += select(U(0), x_thread[8 * i + 1], bool(wb & 0x02));
accum += select(U(0), x_thread[8 * i + 2], bool(wb & 0x04));
accum += select(U(0), x_thread[8 * i + 3], bool(wb & 0x08));
accum += select(U(0), x_thread[8 * i + 4], bool(wb & 0x10));
accum += select(U(0), x_thread[8 * i + 5], bool(wb & 0x20));
accum += select(U(0), x_thread[8 * i + 6], bool(wb & 0x40));
accum += select(U(0), x_thread[8 * i + 7], bool(wb & 0x80));
}
}

else if (bits == 2) {
for (int i = 0; i < (N / 4); i++) {
accum +=
(x_thread[4 * i] * (w[i] & 0x03) +
Expand Down Expand Up @@ -395,11 +453,25 @@ template <typename U, int values_per_thread, int bits>
inline void
qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
static_assert(
bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
bits == 8,
"Template undefined for bits not in {2, 3, 4, 5, 6, 8}");
bits == 1 || bits == 2 || bits == 3 || bits == 4 || bits == 5 ||
bits == 6 || bits == 8,
"Template undefined for bits not in {1, 2, 3, 4, 5, 6, 8}");

if (bits == 2) {
if (bits == 1) {
for (int i = 0; i < (values_per_thread / 8); i++) {
uint8_t wb = w[i];
result[8 * i] += x * (select(U(0), scale, bool(wb & 0x01)) + bias);
result[8 * i + 1] += x * (select(U(0), scale, bool(wb & 0x02)) + bias);
result[8 * i + 2] += x * (select(U(0), scale, bool(wb & 0x04)) + bias);
result[8 * i + 3] += x * (select(U(0), scale, bool(wb & 0x08)) + bias);
result[8 * i + 4] += x * (select(U(0), scale, bool(wb & 0x10)) + bias);
result[8 * i + 5] += x * (select(U(0), scale, bool(wb & 0x20)) + bias);
result[8 * i + 6] += x * (select(U(0), scale, bool(wb & 0x40)) + bias);
result[8 * i + 7] += x * (select(U(0), scale, bool(wb & 0x80)) + bias);
}
}

else if (bits == 2) {
U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f};
for (int i = 0; i < (values_per_thread / 4); i++) {
result[4 * i] += x * (s[0] * (w[i] & 0x03) + bias);
Expand Down Expand Up @@ -484,11 +556,33 @@ template <typename U, int N, int bits>
inline void
dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {
static_assert(
bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
bits == 8,
"Template undefined for bits not in {2, 3, 4, 5, 6, 8}");
bits == 1 || bits == 2 || bits == 3 || bits == 4 || bits == 5 ||
bits == 6 || bits == 8,
"Template undefined for bits not in {1, 2, 3, 4, 5, 6, 8}");

if (bits == 1) {
U s[8] = {
scale,
scale / static_cast<U>(2.0f),
scale / static_cast<U>(4.0f),
scale / static_cast<U>(8.0f),
scale / static_cast<U>(16.0f),
scale / static_cast<U>(32.0f),
scale / static_cast<U>(64.0f),
scale / static_cast<U>(128.0f)};
for (int i = 0; i < (N / 8); i++) {
w_local[8 * i] = s[0] * (w[i] & 0x01) + bias;
w_local[8 * i + 1] = s[1] * (w[i] & 0x02) + bias;
w_local[8 * i + 2] = s[2] * (w[i] & 0x04) + bias;
w_local[8 * i + 3] = s[3] * (w[i] & 0x08) + bias;
w_local[8 * i + 4] = s[4] * (w[i] & 0x10) + bias;
w_local[8 * i + 5] = s[5] * (w[i] & 0x20) + bias;
w_local[8 * i + 6] = s[6] * (w[i] & 0x40) + bias;
w_local[8 * i + 7] = s[7] * (w[i] & 0x80) + bias;
}
}

if (bits == 2) {
else if (bits == 2) {
U s[4] = {
scale,
scale / static_cast<U>(4.0f),
Expand Down Expand Up @@ -577,9 +671,9 @@ struct QuantizedBlockLoader {
group_size % BCOLS == 0,
"The group size should be divisible by the columns");
static_assert(
bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
bits == 8,
"Template undefined for bits not in {2, 3, 4, 5, 6, 8}");
bits == 1 || bits == 2 || bits == 3 || bits == 4 || bits == 5 ||
bits == 6 || bits == 8,
"Template undefined for bits not in {1, 2, 3, 4, 5, 6, 8}");

MLX_MTL_CONST short pack_factor = get_pack_factor<bits, 8>();
MLX_MTL_CONST short bytes_per_pack = get_bytes_per_pack<bits>();
Expand Down Expand Up @@ -786,7 +880,9 @@ METAL_FUNC void qmv_fast_impl(
x += tid.x * in_vec_size + simd_lid * values_per_thread;
y += tid.x * out_vec_size + out_row;

for (int k = 0; k < in_vec_size; k += block_size) {
const int aligned_end = (in_vec_size / block_size) * block_size;
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is that needed? Has something changed for the 1-bit case compared to the 2-bit case in the launch parameters? I am not against this addition, but it is not clear why it is needed for this PR. It might generally be a better choice to route more cases to qmv_fast with an epilogue such as this instead of the plain qmv.

Either way it should be clear why that is needed and likely not in this PR.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah this I added later on, forgot to include in the PR.
The main issue was that while testing the 4B model we were getting gibberish output after packing into 1-bit. Packing into 2-bit was giving me good results, using this formula when packing our model into 2-bit mlx:
2-bit: {-d, +d} → {0, 3}, scale = 2d/3, bias = -d (16 vals/uint32)

The main reason is that the shapes are not divisible by 2048 for the 4B variant, but the block size becomes 2048 for 1-bit the way we did it (at least if I understood correctly), so there will be some leftovers not handled; the other for loop tries to handle that. For larger bit widths this won't be an issue, as the block size is smaller. The sizes are "hidden_size=2560 and intermediate_size=9728"; neither is divisible by 2048, but they are divisible by 512, so 2-bit works okay without this.

Constant Formula 1-bit 2-bit
packs_per_thread bits==2 ? 1 : 2 2 1
pack_factor 32 / bits 32 16
values_per_thread pack_factor × packs_per_thread 64 16
block_size values_per_thread × SIMD_SIZE(32) 2048 512

Maybe we can fix this in a simpler way by using packs_per_thread=1 for 1-bit as well?
I did not notice this until now.

constexpr int packs_per_thread = bits == 2 ? 1 : 2;

becomes:

constexpr int packs_per_thread = (bits == 1 || bits == 2) ? 1 : 2

I need to test the correctness and speed of the kernels; if that works, we can simplify here.

I need to think more about how to generalize (does group size matter here? For us we were using 128).

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On second thought, even packs_per_thread=1 would need sizes to be divisible by 1024, which still has issues for the 4B shapes.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So you kind of need to look closer into the whole kernel routing. It might warrant a more general change and this epilogue is not bad. Here is what I mean:

The qmv_fast path is selected when the input is divisible by 512. All the kernels there assume that. You can do the following:

  • Change the 1 bit kernel to work on 512 block size
  • Change the launch code to check for 2048 or 1024 divisibility for the 1 bit case
  • Change the kernel to check for remaining blocks as you have, but that requires re-evaluating the launch code and removing the % 512 check. This might be a better choice overall.

In all of the above you should run some micro and macro benchmarks to evaluate perf.

Copy link
Copy Markdown
Author

@khosravipasha khosravipasha Apr 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, yeah need to take a closer look

Change the 1 bit kernel to work on 512 block size

Thought about this a bit, but even if I set packs_per_thread=1, still needs to be divisible by 1024. Need to see if can switch any other tuning params.

Change the launch code to check for 2048 or 1024 divisibility for the 1 bit case

I think I tried something similar to this, and then the 4B model was going through the slower kernel paths and was even slower than the 8B model; I need to double-check my notes.

Change the kernel to check for remaining blocks as you have, but that requires re-evaluating the launch code and removing the % 512 check. This might be a better choice overall.

Yeah, this could work; I just want to make sure it does not affect other stuff. I mainly ran benchmarks for the 3 models we have (in 1-bit and in 16-bit). For 8B, MLX 1-bit makes us 8.4x faster, and for 4B/1.7B around 4-5x faster. The 4B is also not as fast as I was expecting, which could be due to the epilogue.

Overall, which one do you think is the best option short term? Long term I think there is a lot more room for tuning these.

Screenshot 2026-04-02 at 01 11 16 Screenshot 2026-04-02 at 01 12 01

1-bit-bonsai-8b-whitepaper.pdf


for (int k = 0; k < aligned_end; k += block_size) {
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);

for (int row = 0; row < results_per_simdgroup; row++) {
Expand All @@ -805,6 +901,27 @@ METAL_FUNC void qmv_fast_impl(
x += block_size;
}

if (aligned_end < in_vec_size) {
bool in_bounds = (aligned_end + simd_lid * values_per_thread) < in_vec_size;
U sum = 0;
if (in_bounds) {
sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
} else {
for (int i = 0; i < values_per_thread; i++)
x_thread[i] = 0;
}

for (int row = 0; row < results_per_simdgroup; row++) {
auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
const device T* sl = scales + row * in_vec_size_g;
const device T* bl = biases + row * in_vec_size_g;

U s = in_bounds ? (U)sl[0] : (U)0;
U b = in_bounds ? (U)bl[0] : (U)0;
result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
}
}

for (int row = 0; row < results_per_simdgroup; row++) {
result[row] = simd_sum(result[row]);
if (simd_lid == 0) {
Expand Down Expand Up @@ -2472,14 +2589,23 @@ template <typename T, const int group_size, const int bits>
w_min = simd_min(w_min);
w_max = simd_max(w_max);

float scale = max((w_max - w_min) / n_bins, eps);
bool side = abs(w_min) > abs(w_max);
scale = side ? scale : -scale;
float edge = side ? w_min : w_max;
float q0 = round(edge / scale);
bool at_zero = q0 == 0.0f;
scale = at_zero ? scale : edge / q0;
float bias = at_zero ? 0 : edge;
float scale;
float bias;

if (bits == 1) {
// Affine 1-bit: bit 0 -> w_min, bit 1 -> w_max
scale = max(w_max - w_min, eps);
bias = w_min;
} else {
scale = max((w_max - w_min) / n_bins, eps);
bool side = abs(w_min) > abs(w_max);
scale = side ? scale : -scale;
float edge = side ? w_min : w_max;
float q0 = round(edge / scale);
bool at_zero = q0 == 0.0f;
scale = at_zero ? scale : edge / q0;
bias = at_zero ? 0 : edge;
}

// Write out the scales and biases
size_t gindex = in_index / group_size;
Expand Down Expand Up @@ -2583,7 +2709,9 @@ template <typename T, const int group_size, const int bits>
#pragma clang loop unroll(full)
for (int i = 0; i < pack_factor; i++) {
uint8_t d;
if (bits == 2) {
if (bits == 1) {
d = (val >> i) & 0x01;
} else if (bits == 2) {
d = (val >> (bits * i)) & 0x03;
} else if (bits == 4) {
d = (val >> (bits * i)) & 0x0f;
Expand Down
Loading
Loading