-
Notifications
You must be signed in to change notification settings - Fork 1.6k
Add 1-bit affine quantization support (Metal) #3161
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
784203f
4d6ca4d
d155e95
644a8cd
b194cb9
a386acc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -28,13 +28,28 @@ inline constexpr short get_bytes_per_pack() { | |||||||||||||||||||||
| template <typename T, typename U, int values_per_thread, int bits> | ||||||||||||||||||||||
| inline U load_vector(const device T* x, thread U* x_thread) { | ||||||||||||||||||||||
| static_assert( | ||||||||||||||||||||||
| bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || | ||||||||||||||||||||||
| bits == 8, | ||||||||||||||||||||||
| "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); | ||||||||||||||||||||||
| bits == 1 || bits == 2 || bits == 3 || bits == 4 || bits == 5 || | ||||||||||||||||||||||
| bits == 6 || bits == 8, | ||||||||||||||||||||||
| "Template undefined for bits not in {1, 2, 3, 4, 5, 6, 8}"); | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| U sum = 0; | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| if (bits == 2) { | ||||||||||||||||||||||
| if (bits == 1) { | ||||||||||||||||||||||
| for (int i = 0; i < values_per_thread; i += 8) { | ||||||||||||||||||||||
| sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + | ||||||||||||||||||||||
| x[i + 6] + x[i + 7]; | ||||||||||||||||||||||
| x_thread[i] = x[i]; | ||||||||||||||||||||||
| x_thread[i + 1] = x[i + 1]; | ||||||||||||||||||||||
| x_thread[i + 2] = x[i + 2]; | ||||||||||||||||||||||
| x_thread[i + 3] = x[i + 3]; | ||||||||||||||||||||||
| x_thread[i + 4] = x[i + 4]; | ||||||||||||||||||||||
| x_thread[i + 5] = x[i + 5]; | ||||||||||||||||||||||
| x_thread[i + 6] = x[i + 6]; | ||||||||||||||||||||||
| x_thread[i + 7] = x[i + 7]; | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| else if (bits == 2) { | ||||||||||||||||||||||
| for (int i = 0; i < values_per_thread; i += 4) { | ||||||||||||||||||||||
| sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; | ||||||||||||||||||||||
| x_thread[i] = x[i]; | ||||||||||||||||||||||
|
|
@@ -107,13 +122,28 @@ inline U load_vector(const device T* x, thread U* x_thread) { | |||||||||||||||||||||
| template <typename T, typename U, int values_per_thread, int bits> | ||||||||||||||||||||||
| inline U load_vector_safe(const device T* x, thread U* x_thread, int N) { | ||||||||||||||||||||||
| static_assert( | ||||||||||||||||||||||
| bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || | ||||||||||||||||||||||
| bits == 8, | ||||||||||||||||||||||
| "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); | ||||||||||||||||||||||
| bits == 1 || bits == 2 || bits == 3 || bits == 4 || bits == 5 || | ||||||||||||||||||||||
| bits == 6 || bits == 8, | ||||||||||||||||||||||
| "Template undefined for bits not in {1, 2, 3, 4, 5, 6, 8}"); | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| U sum = 0; | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| if (bits == 2) { | ||||||||||||||||||||||
| if (bits == 1) { | ||||||||||||||||||||||
| for (int i = 0; i < N; i += 8) { | ||||||||||||||||||||||
| sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + | ||||||||||||||||||||||
| x[i + 6] + x[i + 7]; | ||||||||||||||||||||||
| x_thread[i] = x[i]; | ||||||||||||||||||||||
| x_thread[i + 1] = x[i + 1]; | ||||||||||||||||||||||
| x_thread[i + 2] = x[i + 2]; | ||||||||||||||||||||||
| x_thread[i + 3] = x[i + 3]; | ||||||||||||||||||||||
| x_thread[i + 4] = x[i + 4]; | ||||||||||||||||||||||
| x_thread[i + 5] = x[i + 5]; | ||||||||||||||||||||||
| x_thread[i + 6] = x[i + 6]; | ||||||||||||||||||||||
| x_thread[i + 7] = x[i + 7]; | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| else if (bits == 2) { | ||||||||||||||||||||||
| for (int i = 0; i < N; i += 4) { | ||||||||||||||||||||||
| sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; | ||||||||||||||||||||||
| x_thread[i] = x[i]; | ||||||||||||||||||||||
|
|
@@ -196,13 +226,27 @@ inline U qdot( | |||||||||||||||||||||
| U bias, | ||||||||||||||||||||||
| U sum) { | ||||||||||||||||||||||
| static_assert( | ||||||||||||||||||||||
| bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || | ||||||||||||||||||||||
| bits == 8, | ||||||||||||||||||||||
| "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); | ||||||||||||||||||||||
| bits == 1 || bits == 2 || bits == 3 || bits == 4 || bits == 5 || | ||||||||||||||||||||||
| bits == 6 || bits == 8, | ||||||||||||||||||||||
| "Template undefined for bits not in {1, 2, 3, 4, 5, 6, 8}"); | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| U accum = 0; | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| if (bits == 2) { | ||||||||||||||||||||||
| if (bits == 1) { | ||||||||||||||||||||||
| for (int i = 0; i < (values_per_thread / 8); i++) { | ||||||||||||||||||||||
| uint8_t wb = w[i]; | ||||||||||||||||||||||
| accum += select(U(0), x_thread[8 * i], bool(wb & 0x01)); | ||||||||||||||||||||||
| accum += select(U(0), x_thread[8 * i + 1], bool(wb & 0x02)); | ||||||||||||||||||||||
| accum += select(U(0), x_thread[8 * i + 2], bool(wb & 0x04)); | ||||||||||||||||||||||
| accum += select(U(0), x_thread[8 * i + 3], bool(wb & 0x08)); | ||||||||||||||||||||||
| accum += select(U(0), x_thread[8 * i + 4], bool(wb & 0x10)); | ||||||||||||||||||||||
| accum += select(U(0), x_thread[8 * i + 5], bool(wb & 0x20)); | ||||||||||||||||||||||
| accum += select(U(0), x_thread[8 * i + 6], bool(wb & 0x40)); | ||||||||||||||||||||||
| accum += select(U(0), x_thread[8 * i + 7], bool(wb & 0x80)); | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| else if (bits == 2) { | ||||||||||||||||||||||
| for (int i = 0; i < (values_per_thread / 4); i++) { | ||||||||||||||||||||||
| accum += | ||||||||||||||||||||||
| (x_thread[4 * i] * (w[i] & 0x03) + | ||||||||||||||||||||||
|
|
@@ -298,13 +342,27 @@ inline U qdot_safe( | |||||||||||||||||||||
| U sum, | ||||||||||||||||||||||
| int N) { | ||||||||||||||||||||||
| static_assert( | ||||||||||||||||||||||
| bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || | ||||||||||||||||||||||
| bits == 8, | ||||||||||||||||||||||
| "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); | ||||||||||||||||||||||
| bits == 1 || bits == 2 || bits == 3 || bits == 4 || bits == 5 || | ||||||||||||||||||||||
| bits == 6 || bits == 8, | ||||||||||||||||||||||
| "Template undefined for bits not in {1, 2, 3, 4, 5, 6, 8}"); | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| U accum = 0; | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| if (bits == 2) { | ||||||||||||||||||||||
| if (bits == 1) { | ||||||||||||||||||||||
| for (int i = 0; i < (N / 8); i++) { | ||||||||||||||||||||||
| uint8_t wb = w[i]; | ||||||||||||||||||||||
| accum += select(U(0), x_thread[8 * i], bool(wb & 0x01)); | ||||||||||||||||||||||
| accum += select(U(0), x_thread[8 * i + 1], bool(wb & 0x02)); | ||||||||||||||||||||||
| accum += select(U(0), x_thread[8 * i + 2], bool(wb & 0x04)); | ||||||||||||||||||||||
| accum += select(U(0), x_thread[8 * i + 3], bool(wb & 0x08)); | ||||||||||||||||||||||
| accum += select(U(0), x_thread[8 * i + 4], bool(wb & 0x10)); | ||||||||||||||||||||||
| accum += select(U(0), x_thread[8 * i + 5], bool(wb & 0x20)); | ||||||||||||||||||||||
| accum += select(U(0), x_thread[8 * i + 6], bool(wb & 0x40)); | ||||||||||||||||||||||
| accum += select(U(0), x_thread[8 * i + 7], bool(wb & 0x80)); | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| else if (bits == 2) { | ||||||||||||||||||||||
| for (int i = 0; i < (N / 4); i++) { | ||||||||||||||||||||||
| accum += | ||||||||||||||||||||||
| (x_thread[4 * i] * (w[i] & 0x03) + | ||||||||||||||||||||||
|
|
@@ -395,11 +453,25 @@ template <typename U, int values_per_thread, int bits> | |||||||||||||||||||||
| inline void | ||||||||||||||||||||||
| qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) { | ||||||||||||||||||||||
| static_assert( | ||||||||||||||||||||||
| bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || | ||||||||||||||||||||||
| bits == 8, | ||||||||||||||||||||||
| "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); | ||||||||||||||||||||||
| bits == 1 || bits == 2 || bits == 3 || bits == 4 || bits == 5 || | ||||||||||||||||||||||
| bits == 6 || bits == 8, | ||||||||||||||||||||||
| "Template undefined for bits not in {1, 2, 3, 4, 5, 6, 8}"); | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| if (bits == 2) { | ||||||||||||||||||||||
| if (bits == 1) { | ||||||||||||||||||||||
| for (int i = 0; i < (values_per_thread / 8); i++) { | ||||||||||||||||||||||
| uint8_t wb = w[i]; | ||||||||||||||||||||||
| result[8 * i] += x * (select(U(0), scale, bool(wb & 0x01)) + bias); | ||||||||||||||||||||||
| result[8 * i + 1] += x * (select(U(0), scale, bool(wb & 0x02)) + bias); | ||||||||||||||||||||||
| result[8 * i + 2] += x * (select(U(0), scale, bool(wb & 0x04)) + bias); | ||||||||||||||||||||||
| result[8 * i + 3] += x * (select(U(0), scale, bool(wb & 0x08)) + bias); | ||||||||||||||||||||||
| result[8 * i + 4] += x * (select(U(0), scale, bool(wb & 0x10)) + bias); | ||||||||||||||||||||||
| result[8 * i + 5] += x * (select(U(0), scale, bool(wb & 0x20)) + bias); | ||||||||||||||||||||||
| result[8 * i + 6] += x * (select(U(0), scale, bool(wb & 0x40)) + bias); | ||||||||||||||||||||||
| result[8 * i + 7] += x * (select(U(0), scale, bool(wb & 0x80)) + bias); | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| else if (bits == 2) { | ||||||||||||||||||||||
| U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f}; | ||||||||||||||||||||||
| for (int i = 0; i < (values_per_thread / 4); i++) { | ||||||||||||||||||||||
| result[4 * i] += x * (s[0] * (w[i] & 0x03) + bias); | ||||||||||||||||||||||
|
|
@@ -484,11 +556,33 @@ template <typename U, int N, int bits> | |||||||||||||||||||||
| inline void | ||||||||||||||||||||||
| dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) { | ||||||||||||||||||||||
| static_assert( | ||||||||||||||||||||||
| bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || | ||||||||||||||||||||||
| bits == 8, | ||||||||||||||||||||||
| "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); | ||||||||||||||||||||||
| bits == 1 || bits == 2 || bits == 3 || bits == 4 || bits == 5 || | ||||||||||||||||||||||
| bits == 6 || bits == 8, | ||||||||||||||||||||||
| "Template undefined for bits not in {1, 2, 3, 4, 5, 6, 8}"); | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| if (bits == 1) { | ||||||||||||||||||||||
| U s[8] = { | ||||||||||||||||||||||
| scale, | ||||||||||||||||||||||
| scale / static_cast<U>(2.0f), | ||||||||||||||||||||||
| scale / static_cast<U>(4.0f), | ||||||||||||||||||||||
| scale / static_cast<U>(8.0f), | ||||||||||||||||||||||
| scale / static_cast<U>(16.0f), | ||||||||||||||||||||||
| scale / static_cast<U>(32.0f), | ||||||||||||||||||||||
| scale / static_cast<U>(64.0f), | ||||||||||||||||||||||
| scale / static_cast<U>(128.0f)}; | ||||||||||||||||||||||
| for (int i = 0; i < (N / 8); i++) { | ||||||||||||||||||||||
| w_local[8 * i] = s[0] * (w[i] & 0x01) + bias; | ||||||||||||||||||||||
| w_local[8 * i + 1] = s[1] * (w[i] & 0x02) + bias; | ||||||||||||||||||||||
| w_local[8 * i + 2] = s[2] * (w[i] & 0x04) + bias; | ||||||||||||||||||||||
| w_local[8 * i + 3] = s[3] * (w[i] & 0x08) + bias; | ||||||||||||||||||||||
| w_local[8 * i + 4] = s[4] * (w[i] & 0x10) + bias; | ||||||||||||||||||||||
| w_local[8 * i + 5] = s[5] * (w[i] & 0x20) + bias; | ||||||||||||||||||||||
| w_local[8 * i + 6] = s[6] * (w[i] & 0x40) + bias; | ||||||||||||||||||||||
| w_local[8 * i + 7] = s[7] * (w[i] & 0x80) + bias; | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| if (bits == 2) { | ||||||||||||||||||||||
| else if (bits == 2) { | ||||||||||||||||||||||
| U s[4] = { | ||||||||||||||||||||||
| scale, | ||||||||||||||||||||||
| scale / static_cast<U>(4.0f), | ||||||||||||||||||||||
|
|
@@ -577,9 +671,9 @@ struct QuantizedBlockLoader { | |||||||||||||||||||||
| group_size % BCOLS == 0, | ||||||||||||||||||||||
| "The group size should be divisible by the columns"); | ||||||||||||||||||||||
| static_assert( | ||||||||||||||||||||||
| bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || | ||||||||||||||||||||||
| bits == 8, | ||||||||||||||||||||||
| "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); | ||||||||||||||||||||||
| bits == 1 || bits == 2 || bits == 3 || bits == 4 || bits == 5 || | ||||||||||||||||||||||
| bits == 6 || bits == 8, | ||||||||||||||||||||||
| "Template undefined for bits not in {1, 2, 3, 4, 5, 6, 8}"); | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| MLX_MTL_CONST short pack_factor = get_pack_factor<bits, 8>(); | ||||||||||||||||||||||
| MLX_MTL_CONST short bytes_per_pack = get_bytes_per_pack<bits>(); | ||||||||||||||||||||||
|
|
@@ -786,7 +880,9 @@ METAL_FUNC void qmv_fast_impl( | |||||||||||||||||||||
| x += tid.x * in_vec_size + simd_lid * values_per_thread; | ||||||||||||||||||||||
| y += tid.x * out_vec_size + out_row; | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| for (int k = 0; k < in_vec_size; k += block_size) { | ||||||||||||||||||||||
| const int aligned_end = (in_vec_size / block_size) * block_size; | ||||||||||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why is that needed? Is something changed for the 1 bit compared to the 2 bits in the launch parameters? I am not against this addition but it is not clear why it is needed for this PR. It might generally be a better choice to route more cases to qmv_fast with an epilogue such as this instead of the plain qmv . Either way it should be clear why that is needed and likely not in this PR.
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah this I added later on, forgot to include in the PR. The main reason is the shapes are not divisible by 2048 for the 4B variant. But block size becomes 2048 for 1-bit the way we did it (at least if I understood correctly) so they will be some left overs not handled, the other for loop tries to handle that, for larger bit this won't be an issue as block size is less. The sizes are "hidden_size=2560 and intermediate_size=9728" neither is divisible by 2048 but they are divisible by 512 so 2-bit works okay without this
oh maybe we can fix this in simpler way by packs_per_thread=1 also for 1-bit? bcomes: need to test correctness and speed of kernels, will try if that works can simplify here need to think more how to generalize (does group size matter here? for us we were doing 128).
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. on second thought even packs_per_thread=1 will need things to be divisible by 1024 which still has issues for the 4B sizes.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So you kind of need to look closer into the whole kernel routing. It might warrant a more general change and this epilogue is not bad. Here is what I mean: The
In all of the above you should run some micro and macro benchmarks to evaluate perf.
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks, yeah need to take a closer look
Thought about this a bit, but even if I set
I think tried something similar to this and then 4B model was going through the slower kernel paths, and even was slower than the 8B model, need to double check my notes.
Yeah this could work, just want to make sure it does not affect other stuff. I mainly ran benchmarks for the 3 models we have (in 1-bit and in 16-bit). For 8B MLX 1-bit makes us 8.4x faster, and for 4B/1.7BB around 4-5x faster. The 4B is also not as fast as I was expecting, could be due to the epilogue. Overall, which one do you think is the best option short term? Long term I think there is a lot more room for tuning these.
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| for (int k = 0; k < aligned_end; k += block_size) { | ||||||||||||||||||||||
| U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread); | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| for (int row = 0; row < results_per_simdgroup; row++) { | ||||||||||||||||||||||
|
|
@@ -805,6 +901,27 @@ METAL_FUNC void qmv_fast_impl( | |||||||||||||||||||||
| x += block_size; | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| if (aligned_end < in_vec_size) { | ||||||||||||||||||||||
| bool in_bounds = (aligned_end + simd_lid * values_per_thread) < in_vec_size; | ||||||||||||||||||||||
| U sum = 0; | ||||||||||||||||||||||
| if (in_bounds) { | ||||||||||||||||||||||
| sum = load_vector<T, U, values_per_thread, bits>(x, x_thread); | ||||||||||||||||||||||
| } else { | ||||||||||||||||||||||
| for (int i = 0; i < values_per_thread; i++) | ||||||||||||||||||||||
| x_thread[i] = 0; | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| for (int row = 0; row < results_per_simdgroup; row++) { | ||||||||||||||||||||||
| auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); | ||||||||||||||||||||||
| const device T* sl = scales + row * in_vec_size_g; | ||||||||||||||||||||||
| const device T* bl = biases + row * in_vec_size_g; | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| U s = in_bounds ? (U)sl[0] : (U)0; | ||||||||||||||||||||||
| U b = in_bounds ? (U)bl[0] : (U)0; | ||||||||||||||||||||||
| result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum); | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| for (int row = 0; row < results_per_simdgroup; row++) { | ||||||||||||||||||||||
| result[row] = simd_sum(result[row]); | ||||||||||||||||||||||
| if (simd_lid == 0) { | ||||||||||||||||||||||
|
|
@@ -2472,14 +2589,23 @@ template <typename T, const int group_size, const int bits> | |||||||||||||||||||||
| w_min = simd_min(w_min); | ||||||||||||||||||||||
| w_max = simd_max(w_max); | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| float scale = max((w_max - w_min) / n_bins, eps); | ||||||||||||||||||||||
| bool side = abs(w_min) > abs(w_max); | ||||||||||||||||||||||
| scale = side ? scale : -scale; | ||||||||||||||||||||||
| float edge = side ? w_min : w_max; | ||||||||||||||||||||||
| float q0 = round(edge / scale); | ||||||||||||||||||||||
| bool at_zero = q0 == 0.0f; | ||||||||||||||||||||||
| scale = at_zero ? scale : edge / q0; | ||||||||||||||||||||||
| float bias = at_zero ? 0 : edge; | ||||||||||||||||||||||
| float scale; | ||||||||||||||||||||||
| float bias; | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| if (bits == 1) { | ||||||||||||||||||||||
| // Affine 1-bit: bit 0 -> w_min, bit 1 -> w_max | ||||||||||||||||||||||
| scale = max(w_max - w_min, eps); | ||||||||||||||||||||||
| bias = w_min; | ||||||||||||||||||||||
| } else { | ||||||||||||||||||||||
| scale = max((w_max - w_min) / n_bins, eps); | ||||||||||||||||||||||
| bool side = abs(w_min) > abs(w_max); | ||||||||||||||||||||||
| scale = side ? scale : -scale; | ||||||||||||||||||||||
| float edge = side ? w_min : w_max; | ||||||||||||||||||||||
| float q0 = round(edge / scale); | ||||||||||||||||||||||
| bool at_zero = q0 == 0.0f; | ||||||||||||||||||||||
| scale = at_zero ? scale : edge / q0; | ||||||||||||||||||||||
| bias = at_zero ? 0 : edge; | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| // Write out the scales and biases | ||||||||||||||||||||||
| size_t gindex = in_index / group_size; | ||||||||||||||||||||||
|
|
@@ -2583,7 +2709,9 @@ template <typename T, const int group_size, const int bits> | |||||||||||||||||||||
| #pragma clang loop unroll(full) | ||||||||||||||||||||||
| for (int i = 0; i < pack_factor; i++) { | ||||||||||||||||||||||
| uint8_t d; | ||||||||||||||||||||||
| if (bits == 2) { | ||||||||||||||||||||||
| if (bits == 1) { | ||||||||||||||||||||||
| d = (val >> i) & 0x01; | ||||||||||||||||||||||
| } else if (bits == 2) { | ||||||||||||||||||||||
| d = (val >> (bits * i)) & 0x03; | ||||||||||||||||||||||
| } else if (bits == 4) { | ||||||||||||||||||||||
| d = (val >> (bits * i)) & 0x0f; | ||||||||||||||||||||||
|
|
||||||||||||||||||||||


There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah we probably need to simply add loops that create these.