diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp
index 3d68b8004..81260bad7 100644
--- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp
+++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp
@@ -330,6 +330,242 @@ static void dump_packed_block_q4x4x2(const uint8_t * v, unsigned int i, size_t k
         GGML_FP16_TO_FP32(d[4]), GGML_FP16_TO_FP32(d[5]), GGML_FP16_TO_FP32(d[6]), GGML_FP16_TO_FP32(d[7]));
 }
+
+
+
+
+
+// initialize q4_1 buffer (padding with zeros)
+static void init_row_q4_1x4x2(block_q4_1 * x, int64_t k) {
+    static const int qk = QK_Q4_1x4x2;
+    const int nb = (k + qk - 1) / qk;
+
+    for (int i = 0; i < nb; i++) {
+        x[i * 8 + 0].d = 0;
+        x[i * 8 + 1].d = 0;
+        x[i * 8 + 2].d = 0;
+        x[i * 8 + 3].d = 0;
+        x[i * 8 + 4].d = 0;
+        x[i * 8 + 5].d = 0;
+        x[i * 8 + 6].d = 0;
+        x[i * 8 + 7].d = 0;
+        x[i * 8 + 0].m = 0;
+        x[i * 8 + 1].m = 0;
+        x[i * 8 + 2].m = 0;
+        x[i * 8 + 3].m = 0;
+        x[i * 8 + 4].m = 0;
+        x[i * 8 + 5].m = 0;
+        x[i * 8 + 6].m = 0;
+        x[i * 8 + 7].m = 0;
+    }
+}
+
+static void unpack_q4_1_quants(uint8_t * x, const block_q4_1 * b, unsigned int bi) {
+    static const int qk = QK4_1;
+
+    for (unsigned int i = 0; i < qk / 2; ++i) {
+        const uint8_t q0 = b->qs[i] & 0x0F;
+        const uint8_t q1 = b->qs[i] >> 4;
+        x[bi * qk + i + 0]      = q0;
+        x[bi * qk + i + qk / 2] = q1;
+    }
+}
+
+static void pack_q4_1_quants(block_q4_1 * x, const uint8_t * qs, unsigned int bi) {
+    static const int qk = QK4_1;
+
+    for (unsigned int i = 0; i < qk / 2; ++i) {
+        const uint8_t x0 = qs[bi * qk + i + 0];
+        const uint8_t x1 = qs[bi * qk + i + qk / 2];
+        x->qs[i] = x0 | (x1 << 4);
+    }
+}
+
+
+static void repack_row_q4_1x4x2(uint8_t * y, const block_q4_1 * x, int64_t k) {
+    static const int qk = QK_Q4_1x4x2;
+    const int nb   = (k + qk - 1) / qk;
+    const int nloe = k % qk;
+
+    const int dblk_size = 8 * 2;
+    const int qblk_size = qk / 2;
+    const int qrow_size = k / 2;
+    const int drow_size = k / 32 * 2; // k/32 blocks, each has 2 bytes d
+
+    uint8_t * y_q = y + 0;
+    uint8_t * y_d = y + qrow_size;
+    uint8_t * y_m = y_d + drow_size;
+
+    for (int i = 0; i < nb; i++) {
+        uint8_t qs[QK_Q4_1x4x2];
+        unpack_q4_1_quants(qs, &x[i * 8 + 0], 0);
+        unpack_q4_1_quants(qs, &x[i * 8 + 1], 1);
+        unpack_q4_1_quants(qs, &x[i * 8 + 2], 2);
+        unpack_q4_1_quants(qs, &x[i * 8 + 3], 3);
+        unpack_q4_1_quants(qs, &x[i * 8 + 4], 4);
+        unpack_q4_1_quants(qs, &x[i * 8 + 5], 5);
+        unpack_q4_1_quants(qs, &x[i * 8 + 6], 6);
+        unpack_q4_1_quants(qs, &x[i * 8 + 7], 7);
+
+        bool partial = (nloe && i == nb-1);
+
+        uint8_t * q = y_q + (i * qblk_size);
+        for (int j = 0; j < qk / 2; j++) {
+            if (partial && j >= nloe/2) break;
+            const uint8_t x0 = qs[j + 0];
+            const uint8_t x1 = qs[j + qk / 2];
+            q[j] = x0 | (x1 << 4);
+        }
+
+        uint8_t * d = y_d + (i * dblk_size);
+        uint8_t * m = y_m + (i * dblk_size);
+
+        int max_blks = partial ? nloe / 32 : 8;
+        for (int j = 0; j < 8; j++) {
+            if (j < max_blks) {
+                ((ggml_fp16_t*)d)[j] = x[i * 8 + j].d;
+                ((ggml_fp16_t*)m)[j] = x[i * 8 + j].m;
+            } else {
+                ((ggml_fp16_t*)d)[j] = 0;
+                ((ggml_fp16_t*)m)[j] = 0;
+            }
+        }
+    }
+}
+
+static void unpack_row_q4_1x4x2(uint8_t * y, const block_q4_1 * x, int64_t k) {
+    static const int qk = QK_Q4_1x4x2;
+    const int nb   = (k + qk - 1) / qk;
+    const int nloe = k % qk;
+
+    const int dblk_size = 8 * 2;
+    const int qblk_size = qk / 2;
+    const int qrow_size = k / 2;
+    const int drow_size = k / 32 * 2;
+
+    const uint8_t * y_q = (const uint8_t *)x;
+    const uint8_t * y_d = y_q + qrow_size;
+    const uint8_t * y_m = y_d + drow_size;
+
+    block_q4_1 * out = (block_q4_1 *)y;
+
+    for (int i = 0; i < nb; i++) {
+        uint8_t qs[QK_Q4_1x4x2];
+        const uint8_t * q = y_q + (i * qblk_size);
+
+        bool partial = (nloe && i == nb-1);
+
+        for (int j = 0; j < qk / 2; j++) {
+            if (partial && j >= nloe/2) {
+                qs[j + 0]      = 0;
+                qs[j + qk / 2] = 0;
+            } else {
+                qs[j + 0]      = q[j] & 0x0f;
+                qs[j + qk / 2] = q[j] >> 4;
+            }
+        }
+
+        int max_blks = partial ? nloe / 32 : 8;
+        for (int j = 0; j < max_blks; j++) {
+            pack_q4_1_quants(&out[i * 8 + j], qs, j);
+        }
+
+        const uint8_t * d = y_d + (i * dblk_size);
+        const uint8_t * m = y_m + (i * dblk_size);
+
+        for (int j = 0; j < max_blks; j++) {
+            out[i * 8 + j].d = ((ggml_fp16_t*)d)[j];
+            out[i * 8 + j].m = ((ggml_fp16_t*)m)[j];
+        }
+    }
+}
+
+static void repack_q4_1_q4_1x4x2(ggml_tensor * t, const void * data, size_t size) {
+    int64_t nrows = ggml_nrows(t);
+
+    size_t row_size    = ggml_row_size(t->type, t->ne[0]);
+    size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_1x4x2)); // extra elements for the pad
+    size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any)
+
+    const size_t total_tensor_size = (size_t)nrows * row_size;
+    const size_t n_bytes_to_copy   = size < total_tensor_size ? size : total_tensor_size;
+
+    const int64_t n_full_rows = n_bytes_to_copy / row_size;
+    const size_t  n_rem_bytes = n_bytes_to_copy % row_size;
+
+    void * buf_pd = ggml_aligned_malloc(row_size_pd);
+    GGML_ASSERT(buf_pd != NULL);
+
+    void * buf_rp = ggml_aligned_malloc(row_size_rp);
+    GGML_ASSERT(buf_rp != NULL);
+
+    HEX_VERBOSE("ggml-hex: repack-q4_1-q4_1x4x2 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data, size,
+                t->ne[0], nrows, row_size);
+
+    init_row_q4_1x4x2((block_q4_1 *) buf_pd, t->ne[0]);
+
+    for (int64_t i = 0; i < n_full_rows; i++) {
+        const uint8_t * src = (const uint8_t *) data + (i * row_size);
+        uint8_t * dst = (uint8_t *) t->data + (i * row_size);
+
+        memcpy(buf_pd, src, row_size);
+        repack_row_q4_1x4x2((uint8_t *) buf_rp, (const block_q4_1 *) buf_pd, t->ne[0]);
+        memcpy(dst, buf_rp, row_size);
+    }
+
+    if (n_rem_bytes > 0) {
+        const int64_t i = n_full_rows;
+        const uint8_t * src = (const uint8_t *) data + (i * row_size);
+        uint8_t * dst = (uint8_t *) t->data + (i * row_size);
+
+        memcpy(buf_pd, src, n_rem_bytes);
+        repack_row_q4_1x4x2((uint8_t *) buf_rp, (const block_q4_1 *) buf_pd, t->ne[0]);
+        memcpy(dst, buf_rp, row_size);
+    }
+
+    ggml_aligned_free(buf_rp);
+    ggml_aligned_free(buf_pd);
+}
+
+static void repack_q4_1x4x2_q4_1(void * data, const ggml_tensor * t, size_t size) {
+    int64_t nrows = ggml_nrows(t);
+
+    size_t row_size    = ggml_row_size(t->type, t->ne[0]);
+    size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_1x4x2));
+
+    const size_t total_tensor_size = (size_t)nrows * row_size;
+    const size_t n_bytes_to_copy   = size < total_tensor_size ? size : total_tensor_size;
+
+    const int64_t n_full_rows = n_bytes_to_copy / row_size;
+    const size_t  n_rem_bytes = n_bytes_to_copy % row_size;
+
+    void * buf_pd = ggml_aligned_malloc(row_size_pd);
+    GGML_ASSERT(buf_pd != NULL);
+
+    for (int64_t i = 0; i < n_full_rows; i++) {
+        const uint8_t * src = (const uint8_t *) t->data + (i * row_size);
+        uint8_t * dst = (uint8_t *) data + (i * row_size);
+
+        unpack_row_q4_1x4x2((uint8_t *) buf_pd, (const block_q4_1 *) src, t->ne[0]);
+        memcpy(dst, buf_pd, row_size);
+    }
+
+    if (n_rem_bytes > 0) {
+        const int64_t i = n_full_rows;
+        const uint8_t * src = (const uint8_t *) t->data + (i * row_size);
+        uint8_t * dst = (uint8_t *) data + (i * row_size);
+
+        unpack_row_q4_1x4x2((uint8_t *) buf_pd, (const block_q4_1 *) src, t->ne[0]);
+        memcpy(dst, buf_pd, n_rem_bytes);
+    }
+
+    ggml_aligned_free(buf_pd);
+}
+
+
+
 
 static void unpack_q4_0_quants(uint8_t * qs, const block_q4_0 * x, unsigned int bi) {
     static const int qk = QK4_0;
@@ -1349,6 +1585,11 @@ static void ggml_backend_hexagon_buffer_set_tensor(ggml_backend_buffer_t buffer,
             GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
             repack_q4_0_q4x4x2(tensor, data, size);
             break;
+        case GGML_TYPE_Q4_1:
+            GGML_ASSERT(offset == 0);
+            GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
+            repack_q4_1_q4_1x4x2(tensor, data, size);
+            break;
 
         case GGML_TYPE_Q8_0:
             GGML_ASSERT(offset == 0);
@@ -1391,6 +1632,11 @@ static void ggml_backend_hexagon_buffer_get_tensor(ggml_backend_buffer_t buffer,
             GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
             repack_q4x4x2_q4_0(data, tensor, size);
             break;
+        case GGML_TYPE_Q4_1:
+            GGML_ASSERT(offset == 0);
+            GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
+            repack_q4_1x4x2_q4_1(data, tensor, size);
+            break;
 
         case GGML_TYPE_Q8_0:
             GGML_ASSERT(offset == 0);
@@ -2163,6 +2409,7 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s
 
     switch (src0->type) {
         case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_IQ4_NL:
         case GGML_TYPE_MXFP4:
@@ -2213,6 +2460,7 @@ static bool ggml_hexagon_supported_mul_mat_id(const struct ggml_hexagon_session
 
     switch (src0->type) {
        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_IQ4_NL:
         case GGML_TYPE_MXFP4:
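Note on the flat layout built above: repack_row_q4_1x4x2 stores all packed nibbles of a row first, then all per-32-element scales d, then all mins m, so one 256-element super-block occupies 128 + 16 + 16 = 160 bytes. The scalar routine below is a minimal reference dequantizer for that layout (not part of the patch; it assumes the fp16 scales and mins have already been converted to float), useful for cross-checking the repacked data against ggml's dequantize_row_q4_1.

    #include <stdint.h>

    // Reference dequant for one flat q4_1x4x2 row; k must be a multiple of 256.
    // qrow: k/2 bytes of packed nibbles, d/m: k/32 scales and mins in block order.
    static void dequant_row_q4_1x4x2_ref(const uint8_t * qrow, const float * d, const float * m,
                                         float * out, int k) {
        for (int i = 0; i < k / 256; i++) {
            const uint8_t * q = qrow + i * 128;      // 128 packed bytes per super-block
            for (int j = 0; j < 128; j++) {
                const int blk_lo = i * 8 + j / 32;   // blocks 0..3 hold the low nibbles
                const int blk_hi = blk_lo + 4;       // blocks 4..7 hold the high nibbles
                out[i * 256 + j]       = (q[j] & 0x0F) * d[blk_lo] + m[blk_lo];
                out[i * 256 + j + 128] = (q[j] >> 4)   * d[blk_hi] + m[blk_hi];
            }
        }
    }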
diff --git a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c
index ec191c149..f5d05bcbd 100644
--- a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c
+++ b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c
@@ -86,6 +86,8 @@ static inline size_t get_x4x2_row_stride(int weight_type, int k) {
         case HTP_TYPE_Q4_0:
         case HTP_TYPE_IQ4_NL:
             return (size_t) nb * (QK_Q4_0x4x2 / 2 + HMX_X4X2_DBLK_SIZE); // 144 * nb
+        case HTP_TYPE_Q4_1:
+            return (size_t) nb * (QK_Q4_0x4x2 / 2 + HMX_X4X2_DBLK_SIZE * 2); // 160 * nb
         case HTP_TYPE_Q8_0:
             return (size_t) nb * (QK_Q8_0x4x2 + HMX_X4X2_DBLK_SIZE); // 272 * nb
         case HTP_TYPE_MXFP4:
@@ -430,6 +432,18 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
             v1[0] = v1[1] = v1[2] = v1[3] = Q6_V_vzero();
         }
 
+        if (weight_type == HTP_TYPE_Q4_1) {
+            int m_off = (k_block / 32) * 2;
+            for (int g = 0; g < 4; g++) {
+                HVX_Vector m0 = Q6_V_vsplat_R(*(uint32_t *)(r0 + scale_off + g * sizeof(__fp16) + m_off));
+                v0[g] = Q6_Vhf_vadd_VhfVhf(v0[g], m0);
+                if (row1 < n_cols) {
+                    HVX_Vector m1 = Q6_V_vsplat_R(*(uint32_t *)(r1 + scale_off + g * sizeof(__fp16) + m_off));
+                    v1[g] = Q6_Vhf_vadd_VhfVhf(v1[g], m1);
+                }
+            }
+        }
+
         for (int g = 0; g < 4; g++) {
             Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, v0[g]);
         }
         v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
         for (int g = 0; g < 4; g++) {
             Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, v1[g]);
         }
@@ -517,6 +531,16 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
                               r1 + byte_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt)
                         : Q6_V_vzero();
 
+        if (weight_type == HTP_TYPE_Q4_1) {
+            int m_off = (k_block / 32) * 2;
+            HVX_Vector m0 = Q6_V_vsplat_R(*(uint32_t *)(r0 + scale_off + m_off));
+            v0 = Q6_Vhf_vadd_VhfVhf(v0, m0);
+            if (row1 < n_cols) {
+                HVX_Vector m1 = Q6_V_vsplat_R(*(uint32_t *)(r1 + scale_off + m_off));
+                v1 = Q6_Vhf_vadd_VhfVhf(v1, m1);
+            }
+        }
+
         Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0);
         v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
         Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1);
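The two hunks above rely on the fact that q4_1 dequantizes as x = d*q + m with an unsigned 4-bit q, so the existing lookup-and-scale path can be reused unchanged and only the per-block min has to be added to the fp16 values before they are scattered into the HMX tiles. For comparison, a scalar reference for one 32-element q4_1 block (standard ggml block_q4_1 semantics, not the HVX code):

    #include <stdint.h>

    // Scalar q4_1 block dequant: 32 elements packed into 16 bytes, x[j] = d * q[j] + m.
    static void dequant_block_q4_1_ref(const uint8_t qs[16], float d, float m, float out[32]) {
        for (int j = 0; j < 16; j++) {
            out[j]      = d * (float)(qs[j] & 0x0F) + m;  // elements 0..15  (low nibbles)
            out[j + 16] = d * (float)(qs[j] >> 4)   + m;  // elements 16..31 (high nibbles)
        }
    }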
diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h
index 44a6ab4f7..f0def9830 100644
--- a/ggml/src/ggml-hexagon/htp/htp-ops.h
+++ b/ggml/src/ggml-hexagon/htp/htp-ops.h
@@ -20,6 +20,7 @@ enum htp_data_type {
     HTP_TYPE_F32    = 0,
     HTP_TYPE_F16    = 1,
     HTP_TYPE_Q4_0   = 2,
+    HTP_TYPE_Q4_1   = 3,
     HTP_TYPE_Q8_0   = 8,
     HTP_TYPE_IQ4_NL = 20,
     HTP_TYPE_I32    = 26,
@@ -28,6 +29,8 @@ enum htp_data_type {
 
     // types used internally for repack, dyn.quant, etc
     HTP_TYPE_Q4_0x4x2 = 200,
+    HTP_TYPE_Q4_1x4x2,
+    HTP_TYPE_Q8_1x4x2,
     HTP_TYPE_Q8_0x4x2,
     HTP_TYPE_MXFP4x4x2,
 
@@ -35,7 +38,8 @@ enum htp_data_type {
 };
 
 // Constats for internal types
-#define QK_Q4_0x4x2 256 // 4x Q4_0 blocks packed with next 4x Q4_0 blocks (size in bytes 128)
+#define QK_Q4_0x4x2 256 // 4x Q4_0 blocks packed with next 4x Q4_0 blocks (size in bytes 128)
+#define QK_Q4_1x4x2 256 // 4x Q4_1 blocks packed with next 4x Q4_1 blocks (size in bytes 128)
 #define QK_Q8_0x4x2 256 // 4x Q8_0 blocks concat with next 4x Q8_0 blocks
 #define QK_MXFP4x4x2 256 // 4x MXFP4 blocks concat with next 4x MXFP4 blocks
 
diff --git a/ggml/src/ggml-hexagon/htp/matmul-ops.c b/ggml/src/ggml-hexagon/htp/matmul-ops.c
index bac06693d..cfad4bdff 100644
--- a/ggml/src/ggml-hexagon/htp/matmul-ops.c
+++ b/ggml/src/ggml-hexagon/htp/matmul-ops.c
@@ -148,6 +148,130 @@ static inline HVX_Vector_x8 hvx_vec_load_iq4nlx4x8_partial(const uint8_t * restr
 
 // q4x4x2 and q8x4x2 are the flat q4/8_0 formats where all quants are stored first followed by all scales
+
+
+static inline size_t q8_1x4x2_row_size(int ne0) {
+    int nb = ne0 / QK_Q4_1x4x2;
+    // Each QK_Q4_1x4x2 (256 elements) has 8 blocks of 32 elements.
+    // For q8_1x4x2 we store:
+    //   - quants: 256 bytes
+    //   - scales: 8 * sizeof(ggml_fp16_t) = 16 bytes
+    //   - mins/padding: 8 * sizeof(ggml_fp16_t) = 16 bytes
+    //   - sums: 8 * sizeof(float) = 32 bytes
+    return nb * (QK_Q4_1x4x2 + 8 * sizeof(ggml_fp16_t) * 2 + 8 * sizeof(float));
+}
+
+static void quantize_row_f32_q8_1x4x2(const float * src, void * dst, int ne0) {
+    const int nb = ne0 / QK_Q4_1x4x2;
+
+    int8_t * out_qs = (int8_t *)dst;
+    ggml_fp16_t * out_ds = (ggml_fp16_t *)(out_qs + nb * QK_Q4_1x4x2);
+    ggml_fp16_t * out_ms = out_ds + nb * 8;
+    float * out_sums = (float *)(out_ms + nb * 8);
+
+    for (int i = 0; i < nb; i++) {
+        for (int b = 0; b < 8; b++) {
+            float max  = 0.0f;
+            float amax = 0.0f;
+            float sum  = 0.0f;
+            for (int j = 0; j < 32; j++) {
+                float v = src[i * QK_Q4_1x4x2 + b * 32 + j];
+                if (amax < fabsf(v)) {
+                    amax = fabsf(v);
+                    max  = v;
+                }
+                sum += v;
+            }
+
+            const float d  = max / 127.0f;
+            const float id = d ? 1.0f / d : 0.0f;
+
+            out_ds[i * 8 + b]   = ((union { __fp16 f; uint16_t i; }){ .f = d }).i;
+            out_ms[i * 8 + b]   = ((union { __fp16 f; uint16_t i; }){ .f = 0.0f }).i;
+            out_sums[i * 8 + b] = sum;
+
+            for (int j = 0; j < 32; j++) {
+                const float v = src[i * QK_Q4_1x4x2 + b * 32 + j] * id;
+                out_qs[i * QK_Q4_1x4x2 + b * 32 + j] = (int8_t)roundf(v);
+            }
+        }
+    }
+}
+
+// Dummy vector dot for testing
+static void vec_dot_q4_1x4x2_q8_1x4x2_1x1(const void * vsrc0, const void * vsrc1, float * dst, int ne00) {
+    const uint8_t * src0 = (const uint8_t *)vsrc0;
+    const uint8_t * src1 = (const uint8_t *)vsrc1;
+
+    const int nb       = ne00 / QK_Q4_1x4x2;
+    const int num_qs_0 = nb * QK_Q4_1x4x2 / 2;
+
+    const uint8_t * qs0 = src0;
+    const ggml_fp16_t * ds0 = (const ggml_fp16_t *)(src0 + num_qs_0);
+    const ggml_fp16_t * ms0 = ds0 + nb * 8;
+
+    const int8_t * qs1 = (const int8_t *)src1;
+    const ggml_fp16_t * ds1 = (const ggml_fp16_t *)(qs1 + nb * QK_Q4_1x4x2);
+    const float * sums1 = (const float *)(ds1 + nb * 8 + nb * 8);
+
+    float sumf = 0.0f;
+
+    for (int i = 0; i < nb; i++) {
+        for (int j = 0; j < 128; j++) {
+            uint8_t q0 = qs0[i * 128 + j];
+
+            int8_t q1_0 = qs1[i * 256 + j];
+            int8_t q1_1 = qs1[i * 256 + j + 128];
+
+            int blk0 = j / 32;
+            int blk1 = j / 32 + 4;
+
+            const float d0_0 = (float)((__fp16 *)ds0)[i * 8 + blk0];
+            const float d1_0 = (float)((__fp16 *)ds1)[i * 8 + blk0];
+
+            const float d0_1 = (float)((__fp16 *)ds0)[i * 8 + blk1];
+            const float d1_1 = (float)((__fp16 *)ds1)[i * 8 + blk1];
+
+            sumf += (q0 & 0x0F) * q1_0 * d0_0 * d1_0;
+            sumf += (q0 >> 4)   * q1_1 * d0_1 * d1_1;
+        }
+
+        for (int b = 0; b < 8; b++) {
+            const float m0   = (float)((__fp16 *)ms0)[i * 8 + b];
+            const float sum1 = sums1[i * 8 + b];
+            sumf += m0 * sum1;
+        }
+    }
+
+    *dst = sumf;
+}
+
+static void vec_dot_q4_1x4x2_q8_1x4x2_2x1(const void * vsrc0, const void * vsrc1, float * dst, int ne00) {
+    // flat q4_1x4x2 row: 128 B quants + 8x fp16 d + 8x fp16 m per 256-element super-block = 160 B
+    const size_t src0_row_size = ne00 / QK_Q4_1x4x2 * (QK_Q4_1x4x2 / 2 + 8 * sizeof(ggml_fp16_t) * 2);
+    const uint8_t * src0_0 = (const uint8_t *)vsrc0;
+    const uint8_t * src0_1 = src0_0 + src0_row_size;
+
+    vec_dot_q4_1x4x2_q8_1x4x2_1x1(src0_0, vsrc1, dst + 0, ne00);
+    vec_dot_q4_1x4x2_q8_1x4x2_1x1(src0_1, vsrc1, dst + 1, ne00);
+}
+
+static void vec_dot_q4_1x4x2_q8_1x4x2_2x2(const void * vsrc0, const void * vsrc1, float * dst, int ne00) {
+    const size_t src0_row_size = ne00 / QK_Q4_1x4x2 * (QK_Q4_1x4x2 / 2 + 8 * sizeof(ggml_fp16_t) * 2);
+    const size_t src1_row_size = q8_1x4x2_row_size(ne00);
+
+    const uint8_t * src0_0 = (const uint8_t *)vsrc0;
+    const uint8_t * src0_1 = src0_0 + src0_row_size;
+
+    const uint8_t * src1_0 = (const uint8_t *)vsrc1;
+    const uint8_t * src1_1 = src1_0 + src1_row_size;
+
+    vec_dot_q4_1x4x2_q8_1x4x2_1x1(src0_0, src1_0, dst + 0, ne00);
+    vec_dot_q4_1x4x2_q8_1x4x2_1x1(src0_1, src1_0, dst + 1, ne00);
+    vec_dot_q4_1x4x2_q8_1x4x2_1x1(src0_0, src1_1, dst + 2, ne00);
+    vec_dot_q4_1x4x2_q8_1x4x2_1x1(src0_1, src1_1, dst + 3, ne00);
+}
+
 static inline size_t q8x4x2_row_size(uint32_t ne) {
     // ensures perfect alignment of quants and full row
     const uint32_t qk = QK_Q8_0x4x2;
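The scalar vec_dot above uses the usual q4_1 x q8_1 decomposition: since a weight element is d0*q0 + m0, each block's contribution splits into d0*d1*sum(q0*q1) plus m0 times the sum of the activation values, which is why quantize_row_f32_q8_1x4x2 stores a float sum per 32-element block. A minimal per-block reference of that identity (assuming the q4_1 nibbles are already unpacked to 0..15):

    #include <stdint.h>

    // One 32-element block pair:
    //   sum_j (d0*q0[j] + m0) * x1[j]  ==  d0*d1 * sum_j q0[j]*q1[j]  +  m0 * sum_j x1[j]
    // where x1[j] ~ d1*q1[j] and sum1 = sum_j x1[j] is stored alongside the q8 data.
    static float dot_block_q4_1_q8_1_ref(float d0, float m0, const uint8_t * q0,
                                         float d1, const int8_t * q1, float sum1) {
        int32_t sumi = 0;
        for (int j = 0; j < 32; j++) {
            sumi += (int32_t) q0[j] * q1[j];
        }
        return d0 * d1 * (float) sumi + m0 * sum1;
    }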
"q4x4x2-f32"; mmctx->vec_dot_1x1 = vec_dot_q4x4x2_q8x4x2_1x1; @@ -3120,8 +3252,15 @@ int op_matmul_id(struct htp_ops_context * octx) { return HTP_STATUS_NO_SUPPORT; } - quant_job_func = quantize_f32_q8x4x2; - src1_row_size = q8x4x2_row_size(ne10); + + if (src0->type == HTP_TYPE_Q4_1) { + quant_job_func = quantize_row_f32_q8_1x4x2; + src1_row_size = q8_1x4x2_row_size(ne10); + } else { + quant_job_func = quantize_f32_q8x4x2; + src1_row_size = q8x4x2_row_size(ne10); + } + const size_t src2_spad_size_per_thread = hex_round_up(matrix_row_counts_size + matrix_row_map_size, 256); htp_mminit_spad(octx, dst_row_size, src0_row_size_padded, src1_row_size, src1_nrows, src2_spad_size_per_thread);