From e0c377fb6807ea0bbc7aecae23a0544ac064ff42 Mon Sep 17 00:00:00 2001 From: "Yu, Zijun" Date: Thu, 22 Jan 2026 12:20:50 +0800 Subject: [PATCH 01/11] Fix llama-bench -p -n where p<=256 --- ggml/src/ggml-openvino/utils.cpp | 12 +++++------- ggml/src/ggml-openvino/utils.h | 9 +++------ 2 files changed, 8 insertions(+), 13 deletions(-) diff --git a/ggml/src/ggml-openvino/utils.cpp b/ggml/src/ggml-openvino/utils.cpp index f7d62588c87..2d30eef941f 100644 --- a/ggml/src/ggml-openvino/utils.cpp +++ b/ggml/src/ggml-openvino/utils.cpp @@ -768,14 +768,12 @@ graph_key compute_graph_key(ggml_cgraph * cgraph) { graph_key key; key.n_nodes = cgraph->n_nodes; - if (cgraph->n_nodes > 0) { - key.first_node_name = std::string(cgraph->nodes[0]->name); - key.last_node_name = std::string(cgraph->nodes[cgraph->n_nodes - 1]->name); - } else { - key.first_node_name = ""; - key.last_node_name = ""; + for (int i = 0; i < cgraph->n_nodes; ++i) { + const auto * node = cgraph->nodes[i]; + if (node->op == GGML_OP_SET_ROWS && strncmp(node->src[2]->name, "cache_k_l0", 10) == 0) { + key.cache_k_l0 = node->src[2]; + } } - return key; } diff --git a/ggml/src/ggml-openvino/utils.h b/ggml/src/ggml-openvino/utils.h index 47bf2d4ff17..72ef904f741 100644 --- a/ggml/src/ggml-openvino/utils.h +++ b/ggml/src/ggml-openvino/utils.h @@ -8,20 +8,17 @@ struct graph_key { size_t n_nodes; - std::string first_node_name; - std::string last_node_name; + void * cache_k_l0; bool operator==(const graph_key & other) const { - return n_nodes == other.n_nodes && first_node_name == other.first_node_name && - last_node_name == other.last_node_name; + return n_nodes == other.n_nodes && cache_k_l0 == other.cache_k_l0; } }; struct graph_key_hash { size_t operator()(const graph_key & key) const { size_t h = std::hash{}(key.n_nodes); - h ^= std::hash{}(key.first_node_name) + 0x9e3779b9 + (h << 6) + (h >> 2); - h ^= std::hash{}(key.last_node_name) + 0x9e3779b9 + (h << 6) + (h >> 2); + h ^= std::hash{}(key.cache_k_l0) + 0x9e3779b9 + 
(h << 6) + (h >> 2); return h; } }; From ff9bb1ab144343972e22e48d3d070857e9c50713 Mon Sep 17 00:00:00 2001 From: "Yu, Zijun" Date: Thu, 22 Jan 2026 15:52:10 +0800 Subject: [PATCH 02/11] Fix --direct-io 0 --- ggml/src/ggml-openvino/ggml-openvino.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-openvino/ggml-openvino.cpp b/ggml/src/ggml-openvino/ggml-openvino.cpp index de986ea42d6..06bff5a2b77 100644 --- a/ggml/src/ggml-openvino/ggml-openvino.cpp +++ b/ggml/src/ggml-openvino/ggml-openvino.cpp @@ -943,7 +943,7 @@ static bool ggml_backend_openvino_device_supports_op(ggml_backend_dev_t dev, con } static bool ggml_backend_openvino_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { - return ggml_backend_buft_is_openvino(buft) || ggml_backend_buft_is_openvino_host(buft); + return ggml_backend_buft_is_openvino(buft) || ggml_backend_buft_is_host(buft); GGML_UNUSED(dev); } From cd067dcbfedbbcdcd0493a4eebc739a6570cc24f Mon Sep 17 00:00:00 2001 From: "Yu, Zijun" Date: Sat, 24 Jan 2026 17:16:06 +0800 Subject: [PATCH 03/11] Don't put kvcache on GPU in stateful mode --- ggml/src/ggml-openvino/ggml-openvino.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-openvino/ggml-openvino.cpp b/ggml/src/ggml-openvino/ggml-openvino.cpp index 06bff5a2b77..8d6a0dbf335 100644 --- a/ggml/src/ggml-openvino/ggml-openvino.cpp +++ b/ggml/src/ggml-openvino/ggml-openvino.cpp @@ -140,7 +140,7 @@ static enum ggml_status ggml_backend_openvino_buffer_init_tensor(ggml_backend_bu // Put kvcache on device memory for GPU (NPU memory is too small even for kvcache) if (buffer->usage == GGML_BACKEND_BUFFER_USAGE_ANY && strncmp(tensor->name, "cache_", 6) == 0 && !ctx->is_remote && - ggml_openvino_get_device_name() == "GPU") { + ggml_openvino_get_device_name() == "GPU" && !getenv("GGML_OPENVINO_STATEFUL_EXECUTION")) { GGML_ASSERT(ctx->tensor_extras.empty()); auto device = ctx->device; auto size = ctx->size; From 
e480d5bf00ece985bb8e3bc6bdb7bdb32d14f481 Mon Sep 17 00:00:00 2001 From: "Yu, Zijun" Date: Fri, 23 Jan 2026 15:49:01 +0800 Subject: [PATCH 04/11] Remove hardcode names --- ggml/src/ggml-openvino/ggml-decoder.cpp | 63 +++++++++++++------------ ggml/src/ggml-openvino/ggml-decoder.h | 8 ++-- ggml/src/ggml-openvino/utils.cpp | 4 +- 3 files changed, 39 insertions(+), 36 deletions(-) diff --git a/ggml/src/ggml-openvino/ggml-decoder.cpp b/ggml/src/ggml-openvino/ggml-decoder.cpp index b8fe6358c8d..01e2c2ff193 100644 --- a/ggml/src/ggml-openvino/ggml-decoder.cpp +++ b/ggml/src/ggml-openvino/ggml-decoder.cpp @@ -169,9 +169,11 @@ void GgmlOvDecoder::set_input_output(ggml_tensor * node, bool naive) { // TODO: The shape modification for stateful model below is not validated for all supported models yet. More generic solution might be needed // to enable additional cases. Ideally, this could be removed from decoder and done as part of a transformation later. auto stateless_kv_shape = get_graph_input_shape(node, src); - assert(stateless_kv_shape.size() == 4 && stateless_kv_shape[0] == 1 && stateless_kv_shape[1] == 1 - && stateless_kv_shape[2].is_dynamic() && stateless_kv_shape[3] == (m_model_params.n_heads_kv*m_model_params.head_size)); - stateful_kv_shape = {stateless_kv_shape[0], ov::Dimension::dynamic(), m_model_params.n_heads_kv, m_model_params.head_size}; + assert(stateless_kv_shape.size() == 4 && stateless_kv_shape[0] == 1 && + stateless_kv_shape[1] == 1 && stateless_kv_shape[2].is_dynamic() && + stateless_kv_shape[3] == (m_model_params.n_heads_kv * m_model_params.head_size)); + stateful_kv_shape = {stateless_kv_shape[0], ov::Dimension::dynamic(), + m_model_params.n_heads_kv, m_model_params.head_size}; } } } @@ -180,9 +182,8 @@ void GgmlOvDecoder::set_input_output(ggml_tensor * node, bool naive) { } m_inputs[src_name] = src; assert(stateful_kv_shape.rank().is_static()); - ov::PartialShape param_shape = (stateful_kv_shape.rank().get_length() != 0) - ? 
stateful_kv_shape - : get_graph_input_shape(node, src); + ov::PartialShape param_shape = + (stateful_kv_shape.rank().get_length() != 0) ? stateful_kv_shape : get_graph_input_shape(node, src); auto param_node = std::make_shared(get_ov_type(src), param_shape); param_node->set_friendly_name(src_name); param_node->output(0).get_tensor().set_names({src_name}); @@ -197,7 +198,7 @@ void GgmlOvDecoder::set_input_output(ggml_tensor * node, bool naive) { static std::set debug_output_names = {}; // Workaround: the final tensor "result_output" does not have GGML_TENSOR_FLAG_OUTPUT flag set in cgraph if (node->op == GGML_OP_SET_ROWS || node->flags & GGML_TENSOR_FLAG_OUTPUT || - node_output_name.find("output") != std::string::npos || debug_output_names.count(node_output_name)) { + debug_output_names.count(node_output_name)) { if (m_model_outputs.find(node_output_name) == m_model_outputs.end()) { m_model_outputs[node_output_name] = node_output; } @@ -312,6 +313,11 @@ std::pair GgmlOvDecoder::compute_llm_params(ggml_cgr auto * node = cgraph->nodes[i]; std::string name = std::string(node->name); if (node->op == GGML_OP_FLASH_ATTN_EXT) { + model_params.n_heads = node->src[0]->ne[2]; + model_params.n_heads_kv = node->src[1]->ne[2]; + model_params.head_size = node->src[0]->ne[0]; + compute_params.input_len = node->src[0]->ne[1]; + auto * cache_k_perm = node->src[1]; if (cache_k_perm->op == GGML_OP_CPY) { cache_k_perm = cache_k_perm->src[0]; @@ -324,9 +330,8 @@ std::pair GgmlOvDecoder::compute_llm_params(ggml_cgr int layer = extract_layer_from_name(cache_k->name); auto * mask = node->src[3]; std::string mask_name(mask->name); - assert(mask_name.find("self_kq_mask") == 0); - if (std::string(node->src[3]->name).find("swa") != std::string::npos) { + if (mask_name.find("swa") != std::string::npos) { model_params.swa_layers.push_back(layer); model_params.ctx_per_seq_swa = cache_k->ne[1]; } else { @@ -351,25 +356,18 @@ std::pair GgmlOvDecoder::compute_llm_params(ggml_cgr 
compute_params.attention_size_swa = model_params.ctx_per_seq_swa; compute_params.token_len_per_seq = 1; } - - } else if (node->op == GGML_OP_ROPE) { - if (name.find("Qcur-0") == 0 || std::string(node->src[0]->name).find("Qcur-0") == 0) { - model_params.head_size = node->ne[0]; - model_params.n_heads = node->ne[1]; - model_params.rope_params = node->op_params; - auto * inp_pos = node->src[1]; - compute_params.input_len = inp_pos->ne[0]; - } else if (name.find("Kcur-0") == 0 || std::string(node->src[0]->name).find("Kcur-0") == 0) { - model_params.n_heads_kv = node->ne[1]; - } - } else if (node->op == GGML_OP_GET_ROWS && std::string(node->src[1]->name) == "inp_out_ids") { - // for static case, output_len is always 1 except for llama-perplexity - compute_params.output_len = node->src[1]->ne[0]; - if (is_static && compute_params.output_len == 0) { - compute_params.output_len = 1; - } + break; + } + if (node->op == GGML_OP_ROPE) { + model_params.rope_params = node->op_params; } } + auto * output_tensor = cgraph->nodes[cgraph->n_nodes - 1]; + compute_params.output_len = output_tensor->ne[1]; + // for NPU, output_len is always 1 except for llama-perplexity + if (is_static && compute_params.output_len == 0) { + compute_params.output_len = 1; + } model_params.ctx = model_params.ctx_per_seq * model_params.n_seq; model_params.ctx_swa = model_params.ctx_per_seq_swa * model_params.n_seq; return {model_params, compute_params}; @@ -385,14 +383,17 @@ ov::PartialShape GgmlOvDecoder::get_graph_input_shape(const ggml_tensor * op, co auto name = std::string(input->name); ov::PartialShape input_shape; - if (name == "inp_tokens" || name == "inp_pos") { + if ((op->op == GGML_OP_GET_ROWS && op->src[0]->op == GGML_OP_NONE) || op->op == GGML_OP_ROPE) { + // tokens or positions int len = m_is_static ? (m_is_prefill ? 
m_prefill_chunk_size : 1) : -1; input_shape = ov::PartialShape{1, 1, 1, len}; - } else if (name == "inp_out_ids") { + } else if (op->op == GGML_OP_GET_ROWS) { + // output index input_shape = ov::PartialShape{1, 1, 1, m_is_static ? m_compute_params.output_len : -1}; - } else if (name.find("self_kq_mask") == 0) { + } else if (op->op == GGML_OP_CPY || op->op == GGML_OP_FLASH_ATTN_EXT) { + // mask if (m_is_static) { input_shape = ov::PartialShape{1, 1, m_is_prefill ? m_prefill_chunk_size : 1, m_model_params.ctx}; } else if (m_is_stateful) { @@ -401,7 +402,8 @@ ov::PartialShape GgmlOvDecoder::get_graph_input_shape(const ggml_tensor * op, co input_shape = ov::PartialShape{-1, 1, -1, -1}; } - } else if (name.find("cache_") == 0) { + } else if (op && op->op == GGML_OP_SET_ROWS && op->src[2] == input) { + // kvcache input_shape = ov::PartialShape{get_shape(input)}; if (!m_is_static) { // do not fix ctx size to make llama-bench work @@ -409,6 +411,7 @@ ov::PartialShape GgmlOvDecoder::get_graph_input_shape(const ggml_tensor * op, co } } else if (op && op->op == GGML_OP_SET_ROWS && op->src[1] == input) { + // kv update index int len = m_is_static ? (m_is_prefill ? 
m_prefill_chunk_size : 1) : -1; input_shape = ov::PartialShape{1, 1, 1, len}; diff --git a/ggml/src/ggml-openvino/ggml-decoder.h b/ggml/src/ggml-openvino/ggml-decoder.h index 4afec272e1a..c0d18b7512e 100644 --- a/ggml/src/ggml-openvino/ggml-decoder.h +++ b/ggml/src/ggml-openvino/ggml-decoder.h @@ -16,7 +16,7 @@ struct ModelParams { int ctx_swa = -1; int ctx_per_seq = -1; int ctx_per_seq_swa = -1; - int n_seq = -1; + int n_seq = 1; int n_heads = -1; int n_heads_kv = -1; int head_size = -1; @@ -37,14 +37,14 @@ struct ModelParams { }; struct ComputeParams { - int n_seq_active = -1; - int seq_active_start = -1; + int n_seq_active = 1; + int seq_active_start = 0; int attention_size = -1; int attention_size_swa = -1; int input_len = -1; int token_len_per_seq = -1; int past_kv_len = -1; - int output_len = -1; + int output_len = 1; }; class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder { diff --git a/ggml/src/ggml-openvino/utils.cpp b/ggml/src/ggml-openvino/utils.cpp index 2d30eef941f..8c3717472b4 100644 --- a/ggml/src/ggml-openvino/utils.cpp +++ b/ggml/src/ggml-openvino/utils.cpp @@ -614,10 +614,10 @@ ov::Tensor get_ov_output_tensor(std::shared_ptr ggml_decoder, con auto output_type = ggml_decoder->get_ov_type(ggml_tensor); auto output_shape = ggml_decoder->get_shape(ggml_tensor); - if (ggml_decoder->is_static() && result_name == "result_output" && output_shape[2] == 0) { + if (ggml_decoder->is_static() && output_shape[2] == 0) { output_shape[2] = 1; } - if (ggml_decoder->is_stateful() && result_name == "result_output") { + if (ggml_decoder->is_stateful() && ggml_tensor->flags & GGML_TENSOR_FLAG_OUTPUT) { std::vector output_shape_3d; for (size_t i=1; i Date: Fri, 23 Jan 2026 15:49:36 +0800 Subject: [PATCH 05/11] Fix stateful shapes --- .../ggml-openvino/openvino/op/glu_geglu.cpp | 2 +- .../ggml-openvino/openvino/op/glu_swiglu.cpp | 2 +- ggml/src/ggml-openvino/openvino/op/rope.cpp | 22 +++++-------------- ggml/src/ggml-openvino/openvino/utils.cpp | 2 +- 
ggml/src/ggml-openvino/utils.cpp | 2 ++ 5 files changed, 11 insertions(+), 19 deletions(-) diff --git a/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp b/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp index ad5cd3f6ba5..8be9e8deb06 100644 --- a/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +++ b/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp @@ -26,7 +26,7 @@ OutputVector translate_glu_geglu(const NodeContext & context) { src1 = context.get_input(1); } else { auto combined = context.get_input(0); - auto split_axis = ov::op::v0::Constant::create(ov::element::i64, {}, {3}); + auto split_axis = ov::op::v0::Constant::create(ov::element::i64, {}, {-1}); auto split = std::make_shared(combined, split_axis, 2); src0 = split->output(0); src1 = split->output(1); diff --git a/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp b/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp index 2b7f13629f2..6e0b85517e6 100644 --- a/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +++ b/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp @@ -26,7 +26,7 @@ OutputVector translate_glu_swiglu(const NodeContext & context) { src1 = context.get_input(1); } else { auto combined = context.get_input(0); - auto split_axis = ov::op::v0::Constant::create(ov::element::i64, {}, {3}); + auto split_axis = ov::op::v0::Constant::create(ov::element::i64, {}, {-1}); auto split = std::make_shared(combined, split_axis, 2); src0 = split->output(0); src1 = split->output(1); diff --git a/ggml/src/ggml-openvino/openvino/op/rope.cpp b/ggml/src/ggml-openvino/openvino/op/rope.cpp index 01bc46131e1..44e3368217e 100644 --- a/ggml/src/ggml-openvino/openvino/op/rope.cpp +++ b/ggml/src/ggml-openvino/openvino/op/rope.cpp @@ -70,22 +70,16 @@ OutputVector translate_rope(const NodeContext & context) { constexpr int ROPE_TYPE_NORM = 0; if (mode == ROPE_TYPE_NORM) { + auto neg_one = ov::op::v0::Constant::create(ov::element::i64, {1}, {-1}); auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0}); auto one = 
ov::op::v0::Constant::create(ov::element::i64, {1}, {1}); auto two = ov::op::v0::Constant::create(ov::element::i64, {1}, {2}); auto end = ov::op::v0::Constant::create(ov::element::i64, {1}, {output_shape[3]}); Output even_slice; Output odd_slice; - int32_t unsqueeze_dim = 4; - if (context.is_stateful()) { - unsqueeze_dim = 3; - even_slice = std::make_shared(data_node, zero, end, two, two); - odd_slice = std::make_shared(data_node, one, end, two, two); - } else { - auto three = ov::op::v0::Constant::create(ov::element::i64, {1}, {3}); - even_slice = std::make_shared(data_node, zero, end, two, three); - odd_slice = std::make_shared(data_node, one, end, two, three); - } + int32_t unsqueeze_dim = context.is_stateful() ? 3 : 4; + even_slice = std::make_shared(data_node, zero, end, two, neg_one); + odd_slice = std::make_shared(data_node, one, end, two, neg_one); Output first_half = std::make_shared(std::make_shared(even_slice, cos_theta_node), @@ -105,7 +99,7 @@ OutputVector translate_rope(const NodeContext & context) { res = std::make_shared(stack, data_shape, false); } else if (mode == ROPE_TYPE_NEOX) { auto data_split = std::make_shared( - data_node, ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {3}), 2); + data_node, ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {-1}), 2); Output slice_data_node_0 = data_split->outputs()[0]; Output slice_data_node_1 = data_split->outputs()[1]; @@ -117,11 +111,7 @@ OutputVector translate_rope(const NodeContext & context) { std::make_shared(slice_data_node_0, sin_theta_node), std::make_shared(slice_data_node_1, cos_theta_node)); - int32_t concat_dim = 3; - if (context.is_stateful()) { - concat_dim = 2; - } - res = std::make_shared(ov::OutputVector{first_half_node, second_half_node}, concat_dim); + res = std::make_shared(ov::OutputVector{first_half_node, second_half_node}, -1); } return rename_outputs_with_suffix({res}, context.get_name()); diff --git a/ggml/src/ggml-openvino/openvino/utils.cpp 
b/ggml/src/ggml-openvino/openvino/utils.cpp index b7553f99c86..a0215b97b11 100644 --- a/ggml/src/ggml-openvino/openvino/utils.cpp +++ b/ggml/src/ggml-openvino/openvino/utils.cpp @@ -216,7 +216,7 @@ ov::Output process_view_input(const NodeContext & context, int input_i auto begin = ov::op::v0::Constant::create(ov::element::i64, {1}, {split_addr}); auto end = ov::op::v0::Constant::create(ov::element::i64, {1}, {slice_end}); auto stride = ov::op::v0::Constant::create(ov::element::i64, {1}, {1}); - auto axes = ov::op::v0::Constant::create(ov::element::i64, {1}, {3}); + auto axes = ov::op::v0::Constant::create(ov::element::i64, {1}, {context.is_stateful() ? 2 : 3}); auto sliced = std::make_shared(input, begin, end, stride, axes); return sliced; } diff --git a/ggml/src/ggml-openvino/utils.cpp b/ggml/src/ggml-openvino/utils.cpp index 8c3717472b4..edf42cd9854 100644 --- a/ggml/src/ggml-openvino/utils.cpp +++ b/ggml/src/ggml-openvino/utils.cpp @@ -497,6 +497,7 @@ ov::Tensor get_ov_input_tensor(std::shared_ptr ggml_decoder, cons ov::Tensor get_ov_input_tensor_static_decode(std::shared_ptr ggml_decoder, const std::string & param_name) { + // NPU decoding stage const auto * ggml_tensor = ggml_decoder->get_input_ggml_tensor(param_name); const auto * op = ggml_decoder->get_tensor_used_op(ggml_tensor); @@ -540,6 +541,7 @@ ov::Tensor get_ov_input_tensor_static_decode(std::shared_ptr ggml ov::Tensor get_ov_input_tensor_static_prefill(std::shared_ptr ggml_decoder, const std::string & param_name, int chunk_index) { + // NPU prompt processing stage const auto * ggml_tensor = ggml_decoder->get_input_ggml_tensor(param_name); const auto * op = ggml_decoder->get_tensor_used_op(ggml_tensor); From 4a8fd24e32089ba84e57bb92ff3a23ae1a067894 Mon Sep 17 00:00:00 2001 From: Mustafa Cavus Date: Wed, 21 Jan 2026 15:17:11 -0800 Subject: [PATCH 06/11] Simplification for stateful and update output shape processing --- ggml/src/ggml-openvino/ggml-decoder.cpp | 18 ++++----- 
ggml/src/ggml-openvino/ggml-decoder.h | 2 +- .../openvino/translate_session.cpp | 25 ++++++++++++ ggml/src/ggml-openvino/utils.cpp | 39 ++++++++----------- 4 files changed, 52 insertions(+), 32 deletions(-) diff --git a/ggml/src/ggml-openvino/ggml-decoder.cpp b/ggml/src/ggml-openvino/ggml-decoder.cpp index 01e2c2ff193..2f97af0a3ed 100644 --- a/ggml/src/ggml-openvino/ggml-decoder.cpp +++ b/ggml/src/ggml-openvino/ggml-decoder.cpp @@ -56,11 +56,11 @@ GgmlOvDecoder::GgmlOvDecoder(ggml_cgraph * cgraph, m_model_params(model_params), m_compute_params(compute_params) { if (auto * env = getenv("GGML_OPENVINO_PRINT_CGRAPH_TENSOR_ADDRESS"); env && std::string(env) != "0") { - #ifdef _WIN32 - _putenv_s("GGML_OPENVINO_PRINT_CGRAPH_TENSOR_ADDRESS", ""); - #else - unsetenv("GGML_OPENVINO_PRINT_CGRAPH_TENSOR_ADDRESS"); - #endif +#ifdef _WIN32 + _putenv_s("GGML_OPENVINO_PRINT_CGRAPH_TENSOR_ADDRESS", ""); +#else + unsetenv("GGML_OPENVINO_PRINT_CGRAPH_TENSOR_ADDRESS"); +#endif print_tensor_address_map(cgraph); } @@ -106,8 +106,7 @@ GgmlOvDecoder::GgmlOvDecoder(ggml_cgraph * cgraph, std::map(get_ov_type(src_node), ov::Shape(get_shape(src_node))); + auto param_node = std::make_shared(get_ov_type(src_node), get_shape(src_node)); param_node->set_friendly_name(src_name); param_node->output(0).get_tensor().set_names({src_name}); m_model_inputs[src_name] = param_node; @@ -163,7 +162,8 @@ void GgmlOvDecoder::set_input_output(ggml_tensor * node, bool naive) { // GGML_BACKEND_BUFFER_USAGE_ANY are kv caches if (buffer->usage == GGML_BACKEND_BUFFER_USAGE_ANY) { assert(src_name.find("cache_k") == 0 || src_name.find("cache_v") == 0); - if (auto it = std::find(m_model_params.kv_names.begin(), m_model_params.kv_names.end(), src_name); it == m_model_params.kv_names.end()) { + if (auto it = std::find(m_model_params.kv_names.begin(), m_model_params.kv_names.end(), src_name); + it == m_model_params.kv_names.end()) { m_model_params.kv_names.push_back(src_name); if (is_stateful()) { // TODO: The shape 
modification for stateful model below is not validated for all supported models yet. More generic solution might be needed @@ -719,7 +719,7 @@ void print_tensor_address_map(const ggml_cgraph * cgraph) { } } -std::vector GgmlOvDecoder::get_shape(const ggml_tensor * tensor) { +ov::Shape GgmlOvDecoder::get_shape(const ggml_tensor * tensor) { std::vector shape; for (int i = GGML_MAX_DIMS - 1; i >= 0; --i) { shape.push_back(static_cast(tensor->ne[i])); diff --git a/ggml/src/ggml-openvino/ggml-decoder.h b/ggml/src/ggml-openvino/ggml-decoder.h index c0d18b7512e..f69d1878800 100644 --- a/ggml/src/ggml-openvino/ggml-decoder.h +++ b/ggml/src/ggml-openvino/ggml-decoder.h @@ -207,7 +207,7 @@ class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder { bool m_is_prefill = false; int m_prefill_chunk_size = 0; - static std::vector get_shape(const ggml_tensor * tensor); + static ov::Shape get_shape(const ggml_tensor * tensor); static std::vector get_stride(const ggml_tensor * tensor); static ov::element::Type get_ov_type(const ggml_tensor * tensor); static std::string compute_op_type(const ggml_tensor * node); diff --git a/ggml/src/ggml-openvino/openvino/translate_session.cpp b/ggml/src/ggml-openvino/openvino/translate_session.cpp index adb3025d175..b7e7b58531f 100644 --- a/ggml/src/ggml-openvino/openvino/translate_session.cpp +++ b/ggml/src/ggml-openvino/openvino/translate_session.cpp @@ -29,8 +29,10 @@ #include #include #include +#include #include #include +#include namespace ov { namespace frontend { @@ -252,6 +254,29 @@ std::shared_ptr TranslateSession::apply_transformations(std::shared_ptr(); } manager.run_passes(model); + if (ggml_model_decoder->is_stateful()) { + auto output_names = ggml_model_decoder->get_model_output_names(); + std::map model_output_indexes; + for (size_t i=0; iget_output_size(); i++) { + auto output_friendly_name = model->output(i).get_node_shared_ptr()->get_friendly_name(); + auto output_id = model_output_indexes[output_friendly_name]; + auto 
model_output_shape = model->output(i).get_partial_shape(); + auto decoder_output_shape = ggml_model_decoder->get_output_shape(output_id); + if (model_output_shape.rank().is_static() && decoder_output_shape.rank().is_static() + && model_output_shape.rank().get_length() + 1 == decoder_output_shape.rank().get_length() + && decoder_output_shape[0].is_static() && decoder_output_shape[0].get_length() == 1) { + ppp.output(i).postprocess().custom([](const ov::Output& node) { + auto axes = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{1}, {0}); + return std::make_shared(node, axes); + }); + } + } + model = ppp.build(); + } } return model; } diff --git a/ggml/src/ggml-openvino/utils.cpp b/ggml/src/ggml-openvino/utils.cpp index edf42cd9854..0c5a520b251 100644 --- a/ggml/src/ggml-openvino/utils.cpp +++ b/ggml/src/ggml-openvino/utils.cpp @@ -103,10 +103,12 @@ enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, const std::strin ggml_decoder->add_extra_inputs(); infer_request = infer_request_cache[key]; - auto * inp_pos = get_inp_pos_tensor(cgraph); - int32_t * pos_data = (int32_t *) inp_pos->data; - if (pos_data[0] == 0) { - infer_request->reset_state(); + if (stateful) { + const auto * inp_pos = get_inp_pos_tensor(cgraph); + int32_t * pos_data = (int32_t *) inp_pos->data; + if (pos_data[0] == 0) { + infer_request->reset_state(); + } } decoder_end_time = ggml_time_us(); @@ -118,7 +120,8 @@ enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, const std::strin std::shared_ptr model; auto model_weights = GgmlOvDecoder::create_weight_nodes(cgraph); - ggml_decoder = std::make_shared(cgraph, m_params, c_params, model_weights, is_static, stateful); + ggml_decoder = + std::make_shared(cgraph, m_params, c_params, model_weights, is_static, stateful); decoder_end_time = ggml_time_us(); auto input_model = std::make_shared(ggml_decoder); @@ -351,7 +354,9 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph) { } for (size_t i = 0; i < 
ov_output_names.size(); i++) { - auto output_tensor = get_ov_output_tensor(ggml_decoder, ov_output_names[i]); + auto * ggml_tensor = ggml_decoder->get_model_outputs().at(ov_output_names[i]); + ov::Tensor output_tensor(infer_request->get_output_tensor(i).get_element_type(), + infer_request->get_output_tensor(i).get_shape(), ggml_tensor->data); infer_request->set_output_tensor(i, output_tensor); } @@ -378,7 +383,9 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph) { } for (size_t i = 0; i < ov_output_names.size(); i++) { - auto output_tensor = get_ov_output_tensor(ggml_decoder, ov_output_names[i]); + auto * ggml_tensor = ggml_decoder->get_model_outputs().at(ov_output_names[i]); + ov::Tensor output_tensor(infer_request->get_output_tensor(i).get_element_type(), + infer_request->get_output_tensor(i).get_shape(), ggml_tensor->data); infer_request->set_output_tensor(i, output_tensor); } @@ -478,7 +485,7 @@ ov::Tensor convert_ggml_input_to_ov(std::shared_ptr ggml_decoder, // This case is added to make test-backend-ops work input_shape = ggml_decoder->get_shape(ggml_tensor->view_src); } else { - input_shape = ggml_decoder->get_shape(ggml_tensor); + input_shape = ggml_decoder->get_shape(ggml_tensor); } auto input_tensor = ov::Tensor(ggml_decoder->get_ov_type(ggml_tensor), input_shape, input_data); return input_tensor; @@ -616,20 +623,8 @@ ov::Tensor get_ov_output_tensor(std::shared_ptr ggml_decoder, con auto output_type = ggml_decoder->get_ov_type(ggml_tensor); auto output_shape = ggml_decoder->get_shape(ggml_tensor); - if (ggml_decoder->is_static() && output_shape[2] == 0) { - output_shape[2] = 1; - } - if (ggml_decoder->is_stateful() && ggml_tensor->flags & GGML_TENSOR_FLAG_OUTPUT) { - std::vector output_shape_3d; - for (size_t i=1; idata); - return output_tensor; - } else { - ov::Tensor output_tensor(output_type, output_shape, ggml_tensor->data); - return output_tensor; - } + ov::Tensor output_tensor(output_type, output_shape, ggml_tensor->data); + return 
output_tensor; } size_t checksum(const void * data, size_t size) { From 750a04a3de9af231dc65e98aa2f886bebc330a2e Mon Sep 17 00:00:00 2001 From: "Yu, Zijun" Date: Tue, 3 Feb 2026 17:39:21 +0800 Subject: [PATCH 07/11] Remove hardcode names --- ggml/src/ggml-openvino/ggml-decoder.cpp | 33 +++++++++---------- ggml/src/ggml-openvino/ggml-decoder.h | 28 ++++++++++++++++ .../src/ggml-openvino/ggml-openvino-extra.cpp | 3 +- ggml/src/ggml-openvino/ggml-openvino.cpp | 4 +-- ggml/src/ggml-openvino/utils.cpp | 18 +++++----- 5 files changed, 56 insertions(+), 30 deletions(-) diff --git a/ggml/src/ggml-openvino/ggml-decoder.cpp b/ggml/src/ggml-openvino/ggml-decoder.cpp index 2f97af0a3ed..4806b90894b 100644 --- a/ggml/src/ggml-openvino/ggml-decoder.cpp +++ b/ggml/src/ggml-openvino/ggml-decoder.cpp @@ -161,7 +161,6 @@ void GgmlOvDecoder::set_input_output(ggml_tensor * node, bool naive) { ov::PartialShape stateful_kv_shape; // GGML_BACKEND_BUFFER_USAGE_ANY are kv caches if (buffer->usage == GGML_BACKEND_BUFFER_USAGE_ANY) { - assert(src_name.find("cache_k") == 0 || src_name.find("cache_v") == 0); if (auto it = std::find(m_model_params.kv_names.begin(), m_model_params.kv_names.end(), src_name); it == m_model_params.kv_names.end()) { m_model_params.kv_names.push_back(src_name); @@ -242,18 +241,18 @@ int GgmlOvDecoder::compute_op_case(const ggml_tensor * node) const { case GGML_OP_PERMUTE: { if (node->src[0]->op != GGML_OP_VIEW) { op_case = 1; - } else if (ggml_is_contiguous(node->src[0])) { + } else if (node->src[0]->src[0]->op == GGML_OP_NONE) { + // kv cache tensor std::string src_name(node->view_src->name); - if (src_name.find("cache") == std::string::npos) { - op_case = 4; + int layer = extract_layer_from_name(src_name); + if (!is_swa_layer(layer)) { + op_case = 2; } else { - int layer = extract_layer_from_name(src_name); - if (!is_swa_layer(layer)) { - op_case = 2; - } else { - op_case = 3; - } + op_case = 3; } + } else if (node->src[0]->src[0]->op == GGML_OP_ROPE || 
node->src[0]->src[0]->src[0]->op == GGML_OP_ROPE) { + // rope'ed query tensor + op_case = 4; } break; } @@ -383,16 +382,16 @@ ov::PartialShape GgmlOvDecoder::get_graph_input_shape(const ggml_tensor * op, co auto name = std::string(input->name); ov::PartialShape input_shape; - if ((op->op == GGML_OP_GET_ROWS && op->src[0]->op == GGML_OP_NONE) || op->op == GGML_OP_ROPE) { + if (is_inp_tok(input, op) || is_inp_pos(input, op)) { // tokens or positions int len = m_is_static ? (m_is_prefill ? m_prefill_chunk_size : 1) : -1; input_shape = ov::PartialShape{1, 1, 1, len}; - } else if (op->op == GGML_OP_GET_ROWS) { + } else if (is_output_idx(input, op)) { // output index input_shape = ov::PartialShape{1, 1, 1, m_is_static ? m_compute_params.output_len : -1}; - } else if (op->op == GGML_OP_CPY || op->op == GGML_OP_FLASH_ATTN_EXT) { + } else if (is_inp_mask(input, op)) { // mask if (m_is_static) { input_shape = ov::PartialShape{1, 1, m_is_prefill ? m_prefill_chunk_size : 1, m_model_params.ctx}; @@ -402,7 +401,7 @@ ov::PartialShape GgmlOvDecoder::get_graph_input_shape(const ggml_tensor * op, co input_shape = ov::PartialShape{-1, 1, -1, -1}; } - } else if (op && op->op == GGML_OP_SET_ROWS && op->src[2] == input) { + } else if (is_kvcache(input, op)) { // kvcache input_shape = ov::PartialShape{get_shape(input)}; if (!m_is_static) { @@ -410,7 +409,7 @@ ov::PartialShape GgmlOvDecoder::get_graph_input_shape(const ggml_tensor * op, co input_shape[2] = -1; } - } else if (op && op->op == GGML_OP_SET_ROWS && op->src[1] == input) { + } else if (is_kv_idx(input, op)) { // kv update index int len = m_is_static ? (m_is_prefill ? 
m_prefill_chunk_size : 1) : -1; input_shape = ov::PartialShape{1, 1, 1, len}; @@ -490,9 +489,7 @@ const ggml_tensor * GgmlOvDecoder::get_tensor_from_name(const std::string & name std::map GgmlOvDecoder::get_kv_param_res_names() const { std::map kv_param_res_names; for (const auto & name : m_model_params.kv_names) { - if (name.find("cache_k") == 0 || name.find("cache_v") == 0) { - kv_param_res_names[name] = name; - } + kv_param_res_names[name] = name; } return kv_param_res_names; } diff --git a/ggml/src/ggml-openvino/ggml-decoder.h b/ggml/src/ggml-openvino/ggml-decoder.h index f69d1878800..260cc0cedbb 100644 --- a/ggml/src/ggml-openvino/ggml-decoder.h +++ b/ggml/src/ggml-openvino/ggml-decoder.h @@ -213,6 +213,34 @@ class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder { static std::string compute_op_type(const ggml_tensor * node); void add_extra_inputs(); + inline static bool is_inp_tok(const ggml_tensor * tensor, const ggml_tensor * op) { + return op->op == GGML_OP_GET_ROWS && tensor == op->src[1] && op->src[0]->op == GGML_OP_NONE; + } + + inline static bool is_inp_pos(const ggml_tensor * tensor, const ggml_tensor * op) { + return op->op == GGML_OP_ROPE && tensor == op->src[1]; + } + + inline static bool is_inp_emb(const ggml_tensor * tensor, const ggml_tensor * op) { + return tensor->op == GGML_OP_GET_ROWS && op->op == GGML_OP_RMS_NORM; + } + + inline static bool is_inp_mask(const ggml_tensor * tensor, const ggml_tensor * op) { + return op->op == GGML_OP_CPY || (op->op == GGML_OP_FLASH_ATTN_EXT && tensor == op->src[3]); + } + + inline static bool is_kvcache(const ggml_tensor * tensor, const ggml_tensor * op) { + return op->op == GGML_OP_SET_ROWS && op->src[2] == tensor; + } + + inline static bool is_kv_idx(const ggml_tensor * tensor, const ggml_tensor * op) { + return op->op == GGML_OP_SET_ROWS && op->src[1] == tensor; + } + + inline static bool is_output_idx(const ggml_tensor * tensor, const ggml_tensor * op) { + return op->op == GGML_OP_GET_ROWS && tensor 
== op->src[1] && op->src[0]->op != GGML_OP_NONE; + } + private: void set_input_output(ggml_tensor * node, bool naive = false); int compute_op_case(const ggml_tensor * node) const; diff --git a/ggml/src/ggml-openvino/ggml-openvino-extra.cpp b/ggml/src/ggml-openvino/ggml-openvino-extra.cpp index 76871cc4be3..3b4afbbbce8 100644 --- a/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +++ b/ggml/src/ggml-openvino/ggml-openvino-extra.cpp @@ -85,7 +85,8 @@ void ggml_openvino_device_config::init() { // Release the context (queue keeps a reference) clReleaseContext(cl_ctx); } else if (device_name == "NPU") { - remote_context = ov_singleton_core().get_default_context(device_name); + // remote tensor is not used for NPU yet + // remote_context = ov_singleton_core().get_default_context(device_name); } initialized = true; diff --git a/ggml/src/ggml-openvino/ggml-openvino.cpp b/ggml/src/ggml-openvino/ggml-openvino.cpp index 8d6a0dbf335..b2d5234083b 100644 --- a/ggml/src/ggml-openvino/ggml-openvino.cpp +++ b/ggml/src/ggml-openvino/ggml-openvino.cpp @@ -139,8 +139,8 @@ static enum ggml_status ggml_backend_openvino_buffer_init_tensor(ggml_backend_bu ggml_backend_openvino_buffer_context * ctx = (ggml_backend_openvino_buffer_context *) buffer->context; // Put kvcache on device memory for GPU (NPU memory is too small even for kvcache) - if (buffer->usage == GGML_BACKEND_BUFFER_USAGE_ANY && strncmp(tensor->name, "cache_", 6) == 0 && !ctx->is_remote && - ggml_openvino_get_device_name() == "GPU" && !getenv("GGML_OPENVINO_STATEFUL_EXECUTION")) { + if (strncmp(tensor->name, "cache_", 6) == 0 && !ctx->is_remote && ggml_openvino_get_device_name() == "GPU" && + !getenv("GGML_OPENVINO_STATEFUL_EXECUTION")) { GGML_ASSERT(ctx->tensor_extras.empty()); auto device = ctx->device; auto size = ctx->size; diff --git a/ggml/src/ggml-openvino/utils.cpp b/ggml/src/ggml-openvino/utils.cpp index 0c5a520b251..83d3b3afee2 100644 --- a/ggml/src/ggml-openvino/utils.cpp +++ b/ggml/src/ggml-openvino/utils.cpp @@ 
-508,8 +508,8 @@ ov::Tensor get_ov_input_tensor_static_decode(std::shared_ptr ggml const auto * ggml_tensor = ggml_decoder->get_input_ggml_tensor(param_name); const auto * op = ggml_decoder->get_tensor_used_op(ggml_tensor); - if (param_name == "inp_pos" || param_name == "inp_tokens" || - (op->op == GGML_OP_SET_ROWS && op->src[1] == ggml_tensor)) { + if (GgmlOvDecoder::is_inp_tok(ggml_tensor, op) || GgmlOvDecoder::is_inp_pos(ggml_tensor, op) || + GgmlOvDecoder::is_kv_idx(ggml_tensor, op)) { assert(ggml_tensor->ne[0] == 1); ov::Shape input_shape = {1, 1, 1, 1}; ov::Tensor input_tensor(ggml_decoder->get_ov_type(ggml_tensor), input_shape); @@ -523,7 +523,7 @@ ov::Tensor get_ov_input_tensor_static_decode(std::shared_ptr ggml return input_tensor; } - if (param_name == "inp_out_ids") { + if (GgmlOvDecoder::is_output_idx(ggml_tensor, op)) { ov::Shape input_shape = {1, 1, 1, 1}; ov::Tensor input_tensor(ggml_decoder->get_ov_type(ggml_tensor), input_shape); int32_t inp_out_id = *((int32_t *) ggml_tensor->data); @@ -533,7 +533,7 @@ ov::Tensor get_ov_input_tensor_static_decode(std::shared_ptr ggml return input_tensor; } - if (param_name.find("self_kq_mask") == 0) { + if (GgmlOvDecoder::is_inp_mask(ggml_tensor, op)) { size_t context_size = ggml_decoder->get_ctx_size(); std::vector padded_data = pad_input(ggml_tensor, 1, context_size, -INFINITY); ov::Tensor input_tensor(ov::element::f32, ov::Shape{1, 1, 1, context_size}); @@ -557,8 +557,8 @@ ov::Tensor get_ov_input_tensor_static_prefill(std::shared_ptr ggm const size_t chunk_valid_size = std::min(chunk_size, input_len - chunk_index * chunk_size); const size_t chunk_pad_size = chunk_size - chunk_valid_size; - if (param_name == "inp_pos" || param_name == "inp_tokens" || - (op->op == GGML_OP_SET_ROWS && op->src[1] == ggml_tensor)) { + if (GgmlOvDecoder::is_inp_tok(ggml_tensor, op) || GgmlOvDecoder::is_inp_pos(ggml_tensor, op) || + GgmlOvDecoder::is_kv_idx(ggml_tensor, op)) { ov::Shape input_shape = {1, 1, 1, chunk_size}; ov::Tensor 
input_tensor(ggml_decoder->get_ov_type(ggml_tensor), input_shape); // copy the chunk_index-th chunk from ggml_tensor @@ -585,7 +585,7 @@ ov::Tensor get_ov_input_tensor_static_prefill(std::shared_ptr ggm return input_tensor; } - if (param_name == "inp_out_ids") { + if (GgmlOvDecoder::is_output_idx(ggml_tensor, op)) { size_t output_len = ggml_decoder->get_compute_params().output_len; ov::Shape input_shape = {1, 1, 1, output_len}; ov::Tensor input_tensor(ggml_decoder->get_ov_type(ggml_tensor), input_shape); @@ -600,7 +600,7 @@ ov::Tensor get_ov_input_tensor_static_prefill(std::shared_ptr ggm return input_tensor; } - if (param_name.find("self_kq_mask") == 0) { + if (GgmlOvDecoder::is_inp_mask(ggml_tensor, op)) { size_t cols = ggml_tensor->ne[0]; size_t rows = ggml_tensor->ne[1]; float * ggml_data = (float *) ggml_tensor->data + chunk_index * chunk_size * cols; @@ -748,7 +748,7 @@ const ggml_tensor * get_inp_pos_tensor(ggml_cgraph * cgraph) { if (src == nullptr) { break; } - if (std::string(src->name) == "inp_pos") { + if (GgmlOvDecoder::is_inp_pos(src, op)) { return src; } } From 47346d08583fdc96d05e35f2207b6c0bc0c9c44b Mon Sep 17 00:00:00 2001 From: "Yu, Zijun" Date: Wed, 4 Feb 2026 16:58:39 +0800 Subject: [PATCH 08/11] Avoid re-compilation in llama-bench --- ggml/include/ggml-openvino.h | 2 ++ ggml/src/ggml-openvino/ggml-decoder.cpp | 16 ++++++++-- ggml/src/ggml-openvino/ggml-decoder.h | 20 ++++++++----- ggml/src/ggml-openvino/ggml-openvino.cpp | 16 ++++++++++ ggml/src/ggml-openvino/utils.cpp | 38 ++++++++++-------------- ggml/src/ggml-openvino/utils.h | 25 +++++++++++----- 6 files changed, 78 insertions(+), 39 deletions(-) diff --git a/ggml/include/ggml-openvino.h b/ggml/include/ggml-openvino.h index 46c1485f663..b68b55d1e81 100644 --- a/ggml/include/ggml-openvino.h +++ b/ggml/include/ggml-openvino.h @@ -24,6 +24,8 @@ GGML_BACKEND_API bool ggml_backend_buft_is_openvino(ggml_backend_buffer_type_t b GGML_BACKEND_API bool 
ggml_backend_buft_is_openvino_host(ggml_backend_buffer_type_t buft); +GGML_BACKEND_API size_t ggml_backend_openvino_buffer_get_ctx_id(ggml_backend_buffer_t buffer); + // device buffer GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_openvino_buffer_type(int device); diff --git a/ggml/src/ggml-openvino/ggml-decoder.cpp b/ggml/src/ggml-openvino/ggml-decoder.cpp index 4806b90894b..f7052bfc823 100644 --- a/ggml/src/ggml-openvino/ggml-decoder.cpp +++ b/ggml/src/ggml-openvino/ggml-decoder.cpp @@ -79,6 +79,17 @@ GgmlOvDecoder::GgmlOvDecoder(ggml_cgraph * cgraph, add_extra_inputs(); } +void GgmlOvDecoder::update_io(ggml_cgraph * cgraph) { + m_cgraph = cgraph; + m_model_inputs.clear(); + m_model_outputs.clear(); + m_node_info_list.clear(); + for (int node_n = 0; node_n < cgraph->n_nodes; node_n++) { + auto * cur_node = cgraph->nodes[node_n]; + set_input_output(cur_node); + } +} + GgmlOvDecoder::GgmlOvDecoder(ggml_cgraph * cgraph, std::map> & model_weights) { m_cgraph = cgraph; m_model_weights = model_weights; @@ -330,6 +341,7 @@ std::pair GgmlOvDecoder::compute_llm_params(ggml_cgr auto * mask = node->src[3]; std::string mask_name(mask->name); + model_params.kv_buffer_ctx_id = ggml_backend_openvino_buffer_get_ctx_id(cache_k->buffer); if (mask_name.find("swa") != std::string::npos) { model_params.swa_layers.push_back(layer); model_params.ctx_per_seq_swa = cache_k->ne[1]; @@ -358,7 +370,7 @@ std::pair GgmlOvDecoder::compute_llm_params(ggml_cgr break; } if (node->op == GGML_OP_ROPE) { - model_params.rope_params = node->op_params; + memcpy(model_params.rope_params, node->op_params, sizeof(int32_t) * 15); } } auto * output_tensor = cgraph->nodes[cgraph->n_nodes - 1]; @@ -405,7 +417,7 @@ ov::PartialShape GgmlOvDecoder::get_graph_input_shape(const ggml_tensor * op, co // kvcache input_shape = ov::PartialShape{get_shape(input)}; if (!m_is_static) { - // do not fix ctx size to make llama-bench work + // do not fix ctx size to make llama-bench work across test params 
input_shape[2] = -1; } diff --git a/ggml/src/ggml-openvino/ggml-decoder.h b/ggml/src/ggml-openvino/ggml-decoder.h index 260cc0cedbb..c8e3edeaf89 100644 --- a/ggml/src/ggml-openvino/ggml-decoder.h +++ b/ggml/src/ggml-openvino/ggml-decoder.h @@ -5,6 +5,7 @@ #include "openvino/decoder.hpp" #include +#include #include #include #include @@ -20,20 +21,21 @@ struct ModelParams { int n_heads = -1; int n_heads_kv = -1; int head_size = -1; - int32_t * rope_params = nullptr; + int32_t rope_params[15]; std::vector swa_layers; std::vector kv_names; + size_t kv_buffer_ctx_id = 0; - bool operator==(const ModelParams & other) const { - return n_seq == other.n_seq && n_heads == other.n_heads && n_heads_kv == other.n_heads_kv && - head_size == other.head_size && rope_params == other.rope_params && swa_layers == other.swa_layers && - ctx_per_seq == other.ctx_per_seq && ctx_per_seq_swa == other.ctx_per_seq_swa; + bool same_rope_params(const ModelParams & other) const { + return memcmp(rope_params, other.rope_params, sizeof(int32_t) * 15) == 0; } - bool can_reuse_dynamically(const ModelParams & other) const { return *this == other; } + bool can_reuse_dynamically(const ModelParams & other) const { return same_rope_params(other); } - bool can_reuse_statically(const ModelParams & other) const { return *this == other; } + bool can_reuse_statically(const ModelParams & other) const { return same_rope_params(other) && ctx == other.ctx; } + + bool kv_buffer_changed(const ModelParams & other) const { return kv_buffer_ctx_id != other.kv_buffer_ctx_id; } }; struct ComputeParams { @@ -170,7 +172,7 @@ class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder { int get_input_len() const { return m_compute_params.input_len; } - virtual int32_t * get_rope_params() const override { return m_model_params.rope_params; } + virtual int32_t * get_rope_params() const override { return const_cast(m_model_params.rope_params); } virtual std::map get_kv_param_res_names() const override; @@ -213,6 +215,8 @@ 
class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder { static std::string compute_op_type(const ggml_tensor * node); void add_extra_inputs(); + void update_io(ggml_cgraph * cgraph); + inline static bool is_inp_tok(const ggml_tensor * tensor, const ggml_tensor * op) { return op->op == GGML_OP_GET_ROWS && tensor == op->src[1] && op->src[0]->op == GGML_OP_NONE; } diff --git a/ggml/src/ggml-openvino/ggml-openvino.cpp b/ggml/src/ggml-openvino/ggml-openvino.cpp index b2d5234083b..87577dde9c7 100644 --- a/ggml/src/ggml-openvino/ggml-openvino.cpp +++ b/ggml/src/ggml-openvino/ggml-openvino.cpp @@ -8,6 +8,7 @@ #include "ggml-quants.hpp" #include "ggml.h" +#include #include #include #include @@ -53,6 +54,7 @@ struct ggml_backend_openvino_buffer_context { int device; std::string name; + size_t id; // For non-weight buffers (KV cache, compute), we still use contiguous allocation void * data; @@ -71,6 +73,10 @@ struct ggml_backend_openvino_buffer_context { ggml_backend_openvino_buffer_context(int device, size_t size, bool is_remote = false) : device(device), name(std::string(GGML_OPENVINO_NAME) + std::to_string(device)), + id([]() { + static std::atomic next_id{1}; + return next_id.fetch_add(1); + }()), data(nullptr), size(size), is_remote(is_remote) { @@ -107,6 +113,8 @@ struct ggml_backend_openvino_buffer_context { ~ggml_backend_openvino_buffer_context() { // Clean up all tensor extras + GGML_LOG_DEBUG("Deleting OpenVINO buffer context #%zu for device %d, size %zu MB\n", id, device, + size / 1024 / 1024); for (auto & pair : tensor_extras) { delete pair.second; } @@ -587,6 +595,14 @@ bool ggml_backend_buffer_is_openvino(ggml_backend_buffer_t buffer) { return buffer->iface.free_buffer == ggml_backend_openvino_buffer_free_buffer; } +size_t ggml_backend_openvino_buffer_get_ctx_id(ggml_backend_buffer_t buffer) { + if (!ggml_backend_buffer_is_openvino(buffer)) { + return 0; + } + ggml_backend_openvino_buffer_context * ctx = (ggml_backend_openvino_buffer_context *) 
buffer->context; + return ctx->id; +} + bool ggml_backend_buft_is_openvino(ggml_backend_buffer_type_t buft) { return buft->iface.get_name == ggml_backend_openvino_buffer_type_get_name; } diff --git a/ggml/src/ggml-openvino/utils.cpp b/ggml/src/ggml-openvino/utils.cpp index 83d3b3afee2..69cac19019c 100644 --- a/ggml/src/ggml-openvino/utils.cpp +++ b/ggml/src/ggml-openvino/utils.cpp @@ -76,7 +76,7 @@ enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, const std::strin ComputeParams c_params; std::tie(m_params, c_params) = GgmlOvDecoder::compute_llm_params(cgraph, is_static); - const auto key = compute_graph_key(cgraph); + graph_key key(cgraph); bool cache_hit; int64_t decoder_end_time; @@ -90,19 +90,22 @@ enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, const std::strin auto it = decoder_cache.find(key); cache_hit = it != decoder_cache.end(); + ModelParams old_m_params; if (cache_hit) { ggml_decoder = it->second; - cache_hit = ggml_decoder->get_model_params().can_reuse_dynamically(m_params); + old_m_params = ggml_decoder->get_model_params(); + cache_hit = old_m_params.can_reuse_dynamically(m_params); } if (cache_hit) { std::map> model_weights; - ggml_decoder = decoder_cache[key]; ggml_decoder->set_compute_params(c_params); ggml_decoder->set_model_params(m_params); + if (old_m_params.kv_buffer_changed(m_params)) { + ggml_decoder->update_io(cgraph); + } ggml_decoder->add_extra_inputs(); - infer_request = infer_request_cache[key]; - + infer_request = infer_request_cache.at(key); if (stateful) { const auto * inp_pos = get_inp_pos_tensor(cgraph); int32_t * pos_data = (int32_t *) inp_pos->data; @@ -240,7 +243,7 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph) { const auto * inp_pos = get_inp_pos_tensor(cgraph); const auto is_prefill = get_is_prefill(inp_pos); - const auto key = compute_graph_key(cgraph); + graph_key key(cgraph); bool cache_hit; int64_t decoder_end_time; @@ -254,19 +257,23 @@ enum ggml_status 
ov_graph_compute_static(ggml_cgraph * cgraph) { auto it = decoder_cache.find(key); cache_hit = it != decoder_cache.end(); + ModelParams old_m_params; if (cache_hit) { ggml_decoder = it->second; - cache_hit = ggml_decoder->get_model_params().can_reuse_statically(m_params); + old_m_params = ggml_decoder->get_model_params(); + cache_hit = old_m_params.can_reuse_statically(m_params); } if (cache_hit) { std::map> model_weights; - ggml_decoder = decoder_cache[key]; ggml_decoder->m_is_prefill = is_prefill; ggml_decoder->set_model_params(m_params); ggml_decoder->set_compute_params(c_params); + if (old_m_params.kv_buffer_changed(m_params)) { + ggml_decoder->update_io(cgraph); + } ggml_decoder->add_extra_inputs(); - infer_request = is_prefill ? infer_request_cache_prefill[key] : infer_request_cache[key]; + infer_request = is_prefill ? infer_request_cache_prefill.at(key) : infer_request_cache.at(key); decoder_end_time = ggml_time_us(); conversion_end_time = decoder_end_time; @@ -761,17 +768,4 @@ bool get_is_prefill(const ggml_tensor * inp_pos) { return inp_pos->ne[0] > 1; } -graph_key compute_graph_key(ggml_cgraph * cgraph) { - graph_key key; - key.n_nodes = cgraph->n_nodes; - - for (int i = 0; i < cgraph->n_nodes; ++i) { - const auto * node = cgraph->nodes[i]; - if (node->op == GGML_OP_SET_ROWS && strncmp(node->src[2]->name, "cache_k_l0", 10) == 0) { - key.cache_k_l0 = node->src[2]; - } - } - return key; -} - #pragma GCC diagnostic pop diff --git a/ggml/src/ggml-openvino/utils.h b/ggml/src/ggml-openvino/utils.h index 72ef904f741..7c403b7d890 100644 --- a/ggml/src/ggml-openvino/utils.h +++ b/ggml/src/ggml-openvino/utils.h @@ -5,20 +5,33 @@ #include #include #include +#include struct graph_key { - size_t n_nodes; - void * cache_k_l0; + int n_nodes; + std::string first_node_name; + std::string last_node_name; + + graph_key(const ggml_cgraph * cgraph) : n_nodes(cgraph->n_nodes) { + if (n_nodes > 0) { + first_node_name = cgraph->nodes[0]->name; + last_node_name = 
cgraph->nodes[n_nodes - 1]->name; + } + } bool operator==(const graph_key & other) const { - return n_nodes == other.n_nodes && cache_k_l0 == other.cache_k_l0; + return n_nodes == other.n_nodes && first_node_name == other.first_node_name && + last_node_name == other.last_node_name; } }; struct graph_key_hash { size_t operator()(const graph_key & key) const { - size_t h = std::hash{}(key.n_nodes); - h ^= std::hash{}(key.cache_k_l0) + 0x9e3779b9 + (h << 6) + (h >> 2); + size_t h = std::hash{}(key.n_nodes); + if (key.n_nodes > 0) { + h ^= std::hash{}(key.first_node_name) + 0x9e3779b9 + (h << 6) + (h >> 2); + h ^= std::hash{}(key.last_node_name) + 0x9e3779b9 + (h << 6) + (h >> 2); + } return h; } }; @@ -66,8 +79,6 @@ const ggml_tensor * get_inp_pos_tensor(struct ggml_cgraph * cgraph); bool get_is_prefill(const ggml_tensor * inp_pos); -graph_key compute_graph_key(struct ggml_cgraph * cgraph); - ov::Tensor get_ov_input_tensor(std::shared_ptr ggml_decoder, const std::string & param_name); ov::Tensor get_ov_input_tensor_static_decode(std::shared_ptr ggml_decoder, const std::string & param_name); From 907d8322e699c76fa8d2f0c3bfddc8a034874f59 Mon Sep 17 00:00:00 2001 From: "Yu, Zijun" Date: Thu, 5 Feb 2026 11:12:50 +0800 Subject: [PATCH 09/11] Extract zp directly instead of bias --- ggml/src/ggml-openvino/ggml-decoder.cpp | 35 +- .../src/ggml-openvino/ggml-openvino-extra.cpp | 29 +- ggml/src/ggml-openvino/ggml-openvino-extra.h | 21 +- ggml/src/ggml-openvino/ggml-openvino.cpp | 13 +- ggml/src/ggml-openvino/ggml-quants.cpp | 393 +++++++++--------- ggml/src/ggml-openvino/ggml-quants.hpp | 86 ++-- 6 files changed, 297 insertions(+), 280 deletions(-) diff --git a/ggml/src/ggml-openvino/ggml-decoder.cpp b/ggml/src/ggml-openvino/ggml-decoder.cpp index f7052bfc823..d8d71cf25ee 100644 --- a/ggml/src/ggml-openvino/ggml-decoder.cpp +++ b/ggml/src/ggml-openvino/ggml-decoder.cpp @@ -508,10 +508,10 @@ std::map GgmlOvDecoder::get_kv_param_res_names() const std::map> 
GgmlOvDecoder::create_weight_nodes(ggml_cgraph * cgraph) { std::map> model_weights; - static std::mutex weights_mutex; + // static std::mutex weights_mutex; auto * nodes = cgraph->nodes; auto n_nodes = cgraph->n_nodes; - std::for_each(std::execution::par, nodes, nodes + n_nodes, [&](ggml_tensor * node) { + std::for_each(std::execution::seq, nodes, nodes + n_nodes, [&](ggml_tensor * node) { for (int i = 0; i < GGML_MAX_SRC; i++) { auto * src = node->src[i]; if (src == nullptr) { @@ -522,21 +522,26 @@ std::map> GgmlOvDecoder::create_weight_no if (!src->view_src) { ggml_backend_buffer * buffer = src->buffer; if (buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS || ggml_is_quantized(src->type)) { - bool should_create = false; - { - std::lock_guard lock(weights_mutex); - if (model_weights.find(src_name) == model_weights.end()) { - model_weights[src_name] = nullptr; - should_create = true; - } - } - if (should_create) { + // bool should_create = false; + // { + // std::lock_guard lock(weights_mutex); + // if (model_weights.find(src_name) == model_weights.end()) { + // model_weights[src_name] = nullptr; + // should_create = true; + // } + // } + // if (should_create) { + // auto weight_node = create_weight_node(src); + // weight_node->set_friendly_name(src_name); + // { + // std::lock_guard lock(weights_mutex); + // model_weights[src_name] = weight_node; + // } + // } + if (model_weights.find(src_name) == model_weights.end()) { auto weight_node = create_weight_node(src); weight_node->set_friendly_name(src_name); - { - std::lock_guard lock(weights_mutex); - model_weights[src_name] = weight_node; - } + model_weights[src_name] = weight_node; } } } diff --git a/ggml/src/ggml-openvino/ggml-openvino-extra.cpp b/ggml/src/ggml-openvino/ggml-openvino-extra.cpp index 3b4afbbbce8..4584dc38d0e 100644 --- a/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +++ b/ggml/src/ggml-openvino/ggml-openvino-extra.cpp @@ -209,12 +209,12 @@ ggml_openvino_extracted_layout 
ggml_openvino_get_extracted_layout(const ggml_ten layout.is_requant = true; layout.requant_type = requant_type; - // Special case: requant to F16 - just store F16 weights, no scales/biases + // Special case: requant to F16 - just store F16 weights, no scales/zp if (requant_type.value() == ExtraQuantType::F16) { layout.weights_size = n_elements * sizeof(uint16_t); // F16 = 2 bytes layout.total_size = layout.weights_size; layout.weights_offset = 0; - // No scales/biases for F16 + // No scales/zp for F16 return layout; } @@ -255,14 +255,15 @@ ggml_openvino_extracted_layout ggml_openvino_get_extracted_layout(const ggml_ten layout.weights_size = layout.is_u4 ? (n_elements / 2) : n_elements; int64_t n_blocks = n_elements / layout.weights_per_block; layout.scales_size = n_blocks * sizeof(uint16_t); - // For symmetric quantization, we only need one bias value (not one per block) - layout.biases_size = layout.is_symmetric ? sizeof(uint16_t) : n_blocks * sizeof(uint16_t); + // For symmetric quantization, we only need one zp value (not one per block) + // Zero points are stored in U4 or U8 format matching the weight type + size_t n_zp_elements = layout.is_symmetric ? 1 : n_blocks; + layout.zp_size = layout.is_u4 ? 
((n_zp_elements + 1) / 2) : n_zp_elements; layout.weights_offset = 0; layout.scales_offset = ((layout.weights_size + alignment - 1) / alignment) * alignment; - layout.biases_offset = - layout.scales_offset + ((layout.scales_size + alignment - 1) / alignment) * alignment; - layout.total_size = layout.biases_offset + layout.biases_size; + layout.zp_offset = layout.scales_offset + ((layout.scales_size + alignment - 1) / alignment) * alignment; + layout.total_size = layout.zp_offset + layout.zp_size; layout.total_size = std::max(layout.total_size, ggml_nbytes(tensor)); return layout; } @@ -305,17 +306,19 @@ ggml_openvino_extracted_layout ggml_openvino_get_extracted_layout(const ggml_ten // Weights: U4 = n_elements/2 bytes, U8 = n_elements bytes layout.weights_size = layout.is_u4 ? (n_elements / 2) : n_elements; - // Scales and biases: F16 per block + // Scales: F16 per block int64_t n_blocks = n_elements / layout.weights_per_block; layout.scales_size = n_blocks * sizeof(uint16_t); // F16 = 2 bytes - // For symmetric quantization, we only need one bias value (not one per block) - layout.biases_size = layout.is_symmetric ? sizeof(uint16_t) : n_blocks * sizeof(uint16_t); + // Zero points: U4 or U8 matching weight type + // For symmetric quantization, we only need one zp value (not one per block) + size_t n_zp_elements = layout.is_symmetric ? 1 : n_blocks; + layout.zp_size = layout.is_u4 ? 
((n_zp_elements + 1) / 2) : n_zp_elements; - // Layout in buffer: [weights | scales | biases] with alignment + // Layout in buffer: [weights | scales | zp] with alignment layout.weights_offset = 0; layout.scales_offset = ((layout.weights_size + alignment - 1) / alignment) * alignment; - layout.biases_offset = layout.scales_offset + ((layout.scales_size + alignment - 1) / alignment) * alignment; - layout.total_size = layout.biases_offset + layout.biases_size; + layout.zp_offset = layout.scales_offset + ((layout.scales_size + alignment - 1) / alignment) * alignment; + layout.total_size = layout.zp_offset + layout.zp_size; return layout; } diff --git a/ggml/src/ggml-openvino/ggml-openvino-extra.h b/ggml/src/ggml-openvino/ggml-openvino-extra.h index e2c5a8ceeae..726a90abb02 100644 --- a/ggml/src/ggml-openvino/ggml-openvino-extra.h +++ b/ggml/src/ggml-openvino/ggml-openvino-extra.h @@ -110,16 +110,19 @@ struct ggml_openvino_weight_extra : public ggml_openvino_extra_base { : ggml_openvino_extra_base(Type::WEIGHT), constant(std::move(c)) {} }; -// Extra data for quantized weight tensors - stores extracted weights/scales/biases and ov::Constant +// Extra data for quantized weight tensors - stores extracted weights/scales/zp and ov::Constant struct ggml_openvino_quantized_weight_extra : public ggml_openvino_extra_base { ov::Tensor weights; // U4 or U8 extracted weights ov::Tensor scales; // F16 scales - ov::Tensor biases; // F16 biases (zero points) + ov::Tensor zp; // U4 or U8 zero points (same type as weights) std::shared_ptr constant; // Pre-built OpenVINO weight subgraph - ggml_openvino_quantized_weight_extra(ov::Tensor w, ov::Tensor s, ov::Tensor b, std::shared_ptr c) - : ggml_openvino_extra_base(Type::QUANTIZED_WEIGHT), - weights(std::move(w)), scales(std::move(s)), biases(std::move(b)), constant(std::move(c)) {} + ggml_openvino_quantized_weight_extra(ov::Tensor w, ov::Tensor s, ov::Tensor z, std::shared_ptr c) : + ggml_openvino_extra_base(Type::QUANTIZED_WEIGHT), + 
weights(std::move(w)), + scales(std::move(s)), + zp(std::move(z)), + constant(std::move(c)) {} }; // Extra data for KV cache / compute tensors - stores ov::Tensor for infer_request @@ -133,7 +136,7 @@ struct ggml_openvino_tensor_extra : public ggml_openvino_extra_base { // ===================================================== // Extracted Size Calculation for Quantized Tensors // ===================================================== -// For quantized tensors, we need extra space to store extracted weights, scales, and biases. +// For quantized tensors, we need extra space to store extracted weights, scales, and zero points. // Returns the total size needed in the buffer for extracted data. struct ggml_openvino_extracted_layout { @@ -142,10 +145,10 @@ struct ggml_openvino_extracted_layout { size_t weights_size; // Size of weights in bytes size_t scales_offset; // Offset to scales in buffer size_t scales_size; // Size of scales in bytes - size_t biases_offset; // Offset to biases in buffer - size_t biases_size; // Size of biases in bytes + size_t zp_offset; // Offset to zero points in buffer + size_t zp_size; // Size of zero points in bytes (U4 or U8) bool is_u4; // true for U4 weights, false for U8 - int64_t weights_per_block;// weights per scale/bias block + int64_t weights_per_block; // weights per scale/zp block bool is_symmetric; // true for symmetric quantization // Requantization info diff --git a/ggml/src/ggml-openvino/ggml-openvino.cpp b/ggml/src/ggml-openvino/ggml-openvino.cpp index 87577dde9c7..e531a9c0362 100644 --- a/ggml/src/ggml-openvino/ggml-openvino.cpp +++ b/ggml/src/ggml-openvino/ggml-openvino.cpp @@ -259,13 +259,15 @@ static void ggml_backend_openvino_buffer_set_tensor(ggml_backend_buffer_t buffer ov::Shape weight_shape = {static_cast(tensor->ne[1]), static_cast(tensor->ne[0])}; ov::Shape scale_shape = {static_cast(tensor->ne[1]), static_cast(tensor->ne[0] / layout.weights_per_block)}; + // zp shape: scalar for symmetric, per-block for asymmetric 
+ ov::Shape zp_shape = layout.is_symmetric ? ov::Shape{} : scale_shape; ov::Tensor weights(weight_type, weight_shape, buf_base + layout.weights_offset); ov::Tensor scales(ov::element::f16, scale_shape, buf_base + layout.scales_offset); - ov::Tensor biases(ov::element::f16, scale_shape, buf_base + layout.biases_offset); + ov::Tensor zp(weight_type, zp_shape, buf_base + layout.zp_offset); auto * extra = new ggml_openvino_quantized_weight_extra(std::move(weights), std::move(scales), - std::move(biases), constant); + std::move(zp), constant); ctx->tensor_extras[tensor] = extra; tensor->extra = extra; @@ -487,10 +489,9 @@ static size_t ggml_backend_openvino_buffer_type_get_alloc_size(ggml_backend_buff if (ggml_is_quantized(tensor->type) && tensor->ne[2] == 1 && tensor->ne[3] == 1) { ggml_openvino_extracted_layout layout = ggml_openvino_get_extracted_layout(tensor); if (layout.total_size > 0) { - GGML_LOG_DEBUG( - "%s: tensor %s needs %zu bytes (original %zu, extracted: weights=%zu scales=%zu biases=%zu)\n", - __func__, tensor->name, layout.total_size, ggml_nbytes(tensor), layout.weights_size, layout.scales_size, - layout.biases_size); + GGML_LOG_DEBUG("%s: tensor %s needs %zu bytes (original %zu, extracted: weights=%zu scales=%zu zp=%zu)\n", + __func__, tensor->name, layout.total_size, ggml_nbytes(tensor), layout.weights_size, + layout.scales_size, layout.zp_size); return layout.total_size; } } diff --git a/ggml/src/ggml-openvino/ggml-quants.cpp b/ggml/src/ggml-openvino/ggml-quants.cpp index 8946b73a561..2de0494c910 100644 --- a/ggml/src/ggml-openvino/ggml-quants.cpp +++ b/ggml/src/ggml-openvino/ggml-quants.cpp @@ -42,80 +42,97 @@ void unpack_32_4(const uint8_t * data, uint8_t * dst) { } } -// Extracts (weight, scales, biases) from Q4_0 tensors. +// Extracts (weight, scales, zp) from Q4_0 tensors. // Data layout is: |16 bit scale|32 x 4bit weights|. 
void extract_q4_0_data(const ggml_tensor * tensor, ov::Tensor & weights_arr, ov::Tensor & scales_arr, - ov::Tensor & biases_arr) { + ov::Tensor & zp_arr) { const uint64_t bytes_per_block = 18; // 2 bytes scale, 32x0.5 byte weights auto * data = static_cast(tensor->data); auto * weights = static_cast(weights_arr.data()); auto * scales = scales_arr.data::value_type>(); - auto * biases = biases_arr.data::value_type>(); + auto * zp = static_cast(zp_arr.data()); - bool is_scalar_bias = (biases_arr.get_size() == 1); // Symmetric quantization + bool is_scalar_zp = (zp_arr.get_size() == 1); // Symmetric quantization + + // For Q4_0, zero point is always 8 + if (is_scalar_zp) { + zp[0] = 8 | (8 << 4); // Pack two 4-bit values + } ov::parallel_for(scales_arr.get_size(), [&](size_t i) { scales[i] = ov::float16::from_bits(*((uint16_t *) (data + i * bytes_per_block))); - // For symmetric quantization, only write the first bias (all blocks share the same bias relationship) - if (is_scalar_bias) { - if (i == 0) { - biases[0] = ov::float16(-8.f * static_cast(scales[0])); + // For asymmetric quantization, compute per-block zero points + if (!is_scalar_zp) { + // Pack two 4-bit zero points per byte + if (i % 2 == 0) { + zp[i / 2] = 8; // Lower nibble + } else { + zp[i / 2] |= (8 << 4); // Upper nibble } - } else { - biases[i] = ov::float16(-8.f * static_cast(scales[i])); } unpack_32_4(data + i * bytes_per_block + 2, weights + i * 16); }); } -// Extracts (weight, scales, biases) from Q4_1 tensors. -// Data layout is: |16 bit scale|16 bit bias|32 x 4bit weights|. +// Extracts (weight, scales, zp) from Q4_1 tensors. +// Data layout is: |16 bit scale|16 bit min|32 x 4bit weights|. 
void extract_q4_1_data(const ggml_tensor * tensor, ov::Tensor & weights_arr, ov::Tensor & scales_arr, - ov::Tensor & biases_arr) { - const uint64_t bytes_per_block = 20; // 2 bytes scale, 2 bytes bias, 32x0.5 byte weights + ov::Tensor & zp_arr) { + const uint64_t bytes_per_block = 20; // 2 bytes scale, 2 bytes min, 32x0.5 byte weights auto * data = static_cast(tensor->data); auto * weights = static_cast(weights_arr.data()); auto * scales = scales_arr.data::value_type>(); - auto * biases = biases_arr.data::value_type>(); + auto * zp = static_cast(zp_arr.data()); + ov::parallel_for(scales_arr.get_size(), [&](size_t i) { - scales[i] = ov::float16::from_bits(*((uint16_t *) (data + i * bytes_per_block))); - biases[i] = ov::float16::from_bits(*((uint16_t *) (data + i * bytes_per_block + 2))); + float scale = static_cast(ov::float16::from_bits(*((uint16_t *) (data + i * bytes_per_block)))); + float min = static_cast(ov::float16::from_bits(*((uint16_t *) (data + i * bytes_per_block + 2)))); + scales[i] = ov::float16(scale); + // zp = -min / scale (bias = min, so zp = -bias/scale) + uint8_t zp_val = (scale != 0.0f) ? (uint8_t) std::round(-min / scale) : 0; + // Pack two 4-bit zero points per byte + if (i % 2 == 0) { + zp[i / 2] = zp_val & 0x0F; // Lower nibble + } else { + zp[i / 2] |= (zp_val << 4); // Upper nibble + } unpack_32_4(data + i * bytes_per_block + 4, weights + i * 16); }); } -// Extracts (weight, scales, biases) from Q8_0 tensors. +// Extracts (weight, scales, zp) from Q8_0 tensors. // Data layout is: |16 bit scale|32 x 8bit weights|. 
void extract_q8_0_data(const ggml_tensor * tensor, ov::Tensor & weights_arr, ov::Tensor & scales_arr, - ov::Tensor & biases_arr) { + ov::Tensor & zp_arr) { const uint64_t weights_per_block = 32; const uint64_t bytes_per_block = 34; // 2 bytes scale, 32x1 byte weights auto * data = static_cast(tensor->data); auto * weights = static_cast(weights_arr.data()); auto * scales = scales_arr.data::value_type>(); - auto * biases = biases_arr.data::value_type>(); + auto * zp = static_cast(zp_arr.data()); - bool is_scalar_bias = (biases_arr.get_size() == 1); // Symmetric quantization + bool is_scalar_zp = (zp_arr.get_size() == 1); // Symmetric quantization + + // For Q8_0, zero point is always 128 + if (is_scalar_zp) { + zp[0] = 128; + } ov::parallel_for(scales_arr.get_size(), [&](size_t i) { uint8_t * block_data = data + i * bytes_per_block; scales[i] = ov::float16::from_bits(*(uint16_t *) block_data); - // For symmetric quantization, only write the first bias (all blocks share the same bias relationship) - if (is_scalar_bias) { - if (i == 0) { - biases[0] = ov::float16(-128.f * static_cast(scales[0])); - } - } else { - biases[i] = ov::float16(-128.f * static_cast(scales[i])); + // For asymmetric quantization, store per-block zero points + if (!is_scalar_zp) { + zp[i] = 128; } for (size_t j = 0; j < weights_per_block; ++j) { uint8_t x = block_data[j + 2]; // j+2 to skip the scale bytes. 
@@ -147,51 +164,60 @@ void unpack_256_4(const uint8_t * data, uint8_t * dst) { void extract_q4_k_data(const ggml_tensor * tensor, ov::Tensor & weights_arr, ov::Tensor & scales_arr, - ov::Tensor & biases_arr) { + ov::Tensor & zp_arr) { const uint64_t bytes_per_block = 2 + 2 + 12 + 128; const uint64_t n_super_block = tensor->nb[3] / bytes_per_block; auto * data = static_cast(tensor->data); auto * weights = static_cast(weights_arr.data()); auto * scales = scales_arr.data::value_type>(); - auto * biases = biases_arr.data::value_type>(); + auto * zp = static_cast(zp_arr.data()); ov::parallel_for(n_super_block, [&](size_t i) { uint8_t * block_data = data + i * bytes_per_block; // Extract scale factors and offsets float scale_scales = static_cast(ov::float16::from_bits(*((uint16_t *) block_data))); - float scale_biases = static_cast(ov::float16::from_bits(*((uint16_t *) block_data + 1))); + float scale_mins = static_cast(ov::float16::from_bits(*((uint16_t *) block_data + 1))); // Extract qs1 and qs2 uint8_t * qs1 = block_data + 4; - // uint8_t* qs2 = block_data + 16; - - scales[i * 8] = ov::float16(scale_scales * static_cast((*(qs1) & 0b111111))); - scales[i * 8 + 1] = ov::float16(scale_scales * static_cast((*(qs1 + 1) & 0b111111))); - scales[i * 8 + 2] = ov::float16(scale_scales * static_cast((*(qs1 + 2) & 0b111111))); - scales[i * 8 + 3] = ov::float16(scale_scales * static_cast((*(qs1 + 3) & 0b111111))); - scales[i * 8 + 4] = - ov::float16(scale_scales * static_cast((*(qs1 + 8) & 0b00001111) | ((*(qs1) >> 6) << 4))); - scales[i * 8 + 5] = - ov::float16(scale_scales * static_cast((*(qs1 + 9) & 0b00001111) | ((*(qs1 + 1) >> 6) << 4))); - scales[i * 8 + 6] = - ov::float16(scale_scales * static_cast((*(qs1 + 10) & 0b00001111) | ((*(qs1 + 2) >> 6) << 4))); - scales[i * 8 + 7] = - ov::float16(scale_scales * static_cast((*(qs1 + 11) & 0b00001111) | ((*(qs1 + 3) >> 6) << 4))); - - biases[i * 8] = ov::float16(-1.f * scale_biases * static_cast((*(qs1 + 4) & 0b111111))); - 
biases[i * 8 + 1] = ov::float16(-1.f * scale_biases * static_cast((*(qs1 + 5) & 0b111111))); - biases[i * 8 + 2] = ov::float16(-1.f * scale_biases * static_cast((*(qs1 + 6) & 0b111111))); - biases[i * 8 + 3] = ov::float16(-1.f * scale_biases * static_cast((*(qs1 + 7) & 0b111111))); - biases[i * 8 + 4] = - ov::float16(-1.f * scale_biases * static_cast((*(qs1 + 8) >> 4) | ((*(qs1 + 4) >> 6) << 4))); - biases[i * 8 + 5] = - ov::float16(-1.f * scale_biases * static_cast((*(qs1 + 9) >> 4) | ((*(qs1 + 5) >> 6) << 4))); - biases[i * 8 + 6] = - ov::float16(-1.f * scale_biases * static_cast((*(qs1 + 10) >> 4) | ((*(qs1 + 6) >> 6) << 4))); - biases[i * 8 + 7] = - ov::float16(-1.f * scale_biases * static_cast((*(qs1 + 11) >> 4) | ((*(qs1 + 7) >> 6) << 4))); + + // Calculate scales + float scale_vals[8]; + scale_vals[0] = scale_scales * static_cast((*(qs1) & 0b111111)); + scale_vals[1] = scale_scales * static_cast((*(qs1 + 1) & 0b111111)); + scale_vals[2] = scale_scales * static_cast((*(qs1 + 2) & 0b111111)); + scale_vals[3] = scale_scales * static_cast((*(qs1 + 3) & 0b111111)); + scale_vals[4] = scale_scales * static_cast((*(qs1 + 8) & 0b00001111) | ((*(qs1) >> 6) << 4)); + scale_vals[5] = scale_scales * static_cast((*(qs1 + 9) & 0b00001111) | ((*(qs1 + 1) >> 6) << 4)); + scale_vals[6] = scale_scales * static_cast((*(qs1 + 10) & 0b00001111) | ((*(qs1 + 2) >> 6) << 4)); + scale_vals[7] = scale_scales * static_cast((*(qs1 + 11) & 0b00001111) | ((*(qs1 + 3) >> 6) << 4)); + + // Calculate min values (bias = -min) + float min_vals[8]; + min_vals[0] = scale_mins * static_cast((*(qs1 + 4) & 0b111111)); + min_vals[1] = scale_mins * static_cast((*(qs1 + 5) & 0b111111)); + min_vals[2] = scale_mins * static_cast((*(qs1 + 6) & 0b111111)); + min_vals[3] = scale_mins * static_cast((*(qs1 + 7) & 0b111111)); + min_vals[4] = scale_mins * static_cast((*(qs1 + 8) >> 4) | ((*(qs1 + 4) >> 6) << 4)); + min_vals[5] = scale_mins * static_cast((*(qs1 + 9) >> 4) | ((*(qs1 + 5) >> 6) << 4)); + 
min_vals[6] = scale_mins * static_cast((*(qs1 + 10) >> 4) | ((*(qs1 + 6) >> 6) << 4)); + min_vals[7] = scale_mins * static_cast((*(qs1 + 11) >> 4) | ((*(qs1 + 7) >> 6) << 4)); + + // Store scales and compute zero points + for (int j = 0; j < 8; j++) { + scales[i * 8 + j] = ov::float16(scale_vals[j]); + // zp = min / scale (since bias = -min and zp = -bias/scale) + uint8_t zp_val = (scale_vals[j] != 0.0f) ? (uint8_t) std::round(min_vals[j] / scale_vals[j]) : 0; + // Pack two 4-bit zero points per byte + size_t idx = i * 8 + j; + if (idx % 2 == 0) { + zp[idx / 2] = zp_val & 0x0F; + } else { + zp[idx / 2] |= (zp_val << 4); + } + } unpack_256_4(block_data + 16, weights + i * 128); }); } @@ -199,16 +225,21 @@ void extract_q4_k_data(const ggml_tensor * tensor, void extract_q6_k_data(const ggml_tensor * tensor, ov::Tensor & weights_arr, ov::Tensor & scales_arr, - ov::Tensor & biases_arr) { + ov::Tensor & zp_arr) { const uint64_t bytes_per_block = 128 + 64 + 16 + 2; const uint64_t n_super_block = tensor->nb[3] / bytes_per_block; auto * data = static_cast(tensor->data); auto * weights = static_cast(weights_arr.data()); auto * scales = scales_arr.data::value_type>(); - auto * biases = biases_arr.data::value_type>(); + auto * zp = static_cast(zp_arr.data()); + + bool is_scalar_zp = (zp_arr.get_size() == 1); // Symmetric quantization - bool is_scalar_bias = (biases_arr.get_size() == 1); // Symmetric quantization + // For Q6_K, zero point is always 32 + if (is_scalar_zp) { + zp[0] = 32; + } ov::parallel_for(n_super_block, [&](size_t i) { uint8_t * block_data = data + i * bytes_per_block; @@ -219,13 +250,9 @@ void extract_q6_k_data(const ggml_tensor * tensor, for (size_t j = 0; j < 16; j++) { scales[j + i * 16] = ov::float16(scale_factor * static_cast(*((int8_t *) (block_data + 128 + 64 + j)))); - // For symmetric quantization, only write the first bias (all blocks share the same bias relationship) - if (is_scalar_bias) { - if (i == 0 && j == 0) { - biases[0] = ov::float16(-32.f 
* static_cast(scales[0])); - } - } else { - biases[j + i * 16] = ov::float16(-32.f * static_cast(scales[j + i * 16])); + // For asymmetric quantization, store per-block zero points + if (!is_scalar_zp) { + zp[j + i * 16] = 32; } } @@ -258,20 +285,20 @@ static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t * d, uint8 void extract_q5_k_data(const ggml_tensor * tensor, ov::Tensor & weights_arr, ov::Tensor & scales_arr, - ov::Tensor & biases_arr) { + ov::Tensor & zp_arr) { const uint64_t bytes_per_block = 4 + 12 + 32 + 128; const uint64_t n_super_block = tensor->nb[3] / bytes_per_block; auto * data = static_cast(tensor->data); auto * weights = static_cast(weights_arr.data()); auto * scales = scales_arr.data::value_type>(); - auto * biases = biases_arr.data::value_type>(); + auto * zp = static_cast(zp_arr.data()); ov::parallel_for(n_super_block, [&](size_t i) { uint8_t * block_data = data + i * bytes_per_block; const float d = static_cast(ov::float16::from_bits(*((uint16_t *) block_data))); - const float min = static_cast(ov::float16::from_bits(*((uint16_t *) block_data + 1))); + const float min_factor = static_cast(ov::float16::from_bits(*((uint16_t *) block_data + 1))); const uint8_t * scales_data = block_data + 4; // 12 bytes of scales const uint8_t * qh = block_data + 4 + 12; // 32 bytes of high bits @@ -289,17 +316,18 @@ void extract_q5_k_data(const ggml_tensor * tensor, // Get scale and min for first 32 elements get_scale_min_k4(is + 0, scales_data, &sc, &m); const float d1 = d * sc; - const float m1 = min * m; + const float m1 = min_factor * m; // Get scale and min for second 32 elements get_scale_min_k4(is + 1, scales_data, &sc, &m); const float d2 = d * sc; - const float m2 = min * m; + const float m2 = min_factor * m; scales[i * 8 + is] = ov::float16(d1); - biases[i * 8 + is] = ov::float16(-m1); scales[i * 8 + is + 1] = ov::float16(d2); - biases[i * 8 + is + 1] = ov::float16(-m2); + // zp = min / scale (since bias = -min and zp = -bias/scale) + 
zp[i * 8 + is] = (d1 != 0.0f) ? (uint8_t) std::round(m1 / d1) : 0; + zp[i * 8 + is + 1] = (d2 != 0.0f) ? (uint8_t) std::round(m2 / d2) : 0; // Extract weights for first 32 elements (matching deq formula exactly) for (int l = 0; l < 32; ++l) { @@ -321,16 +349,13 @@ void extract_q5_k_data(const ggml_tensor * tensor, // TODO Reorder for make_intX_weights -ov::Output make_int8_weights(ov::Tensor & weight, - ov::Tensor & scales, - ov::Tensor & biases, - size_t group_size) { +ov::Output make_int8_weights(ov::Tensor & weight, ov::Tensor & scales, ov::Tensor & zp, size_t group_size) { ov::Shape orig_shape = weight.get_shape(); - // Expand dimensions for scales and biases + // Expand dimensions for scales and zp auto scale_shape = scales.get_shape(); - auto bias_shape = biases.get_shape(); - bool is_scalar_bias = bias_shape.empty(); // Symmetric quantization + auto zp_shape = zp.get_shape(); + bool is_scalar_zp = zp_shape.empty(); // Symmetric quantization ov::Shape packed_shape = {orig_shape[0], orig_shape[1] / group_size, group_size}; @@ -340,10 +365,10 @@ ov::Output make_int8_weights(ov::Tensor & weight, } else { scale_shape.push_back(1); scales.set_shape(scale_shape); - // For symmetric quantization, biases remain scalar (don't resize) - if (!is_scalar_bias) { - bias_shape = scale_shape; - biases.set_shape(bias_shape); + // For symmetric quantization, zp remains scalar (don't resize) + if (!is_scalar_zp) { + zp_shape.push_back(1); + zp.set_shape(zp_shape); } } @@ -352,26 +377,9 @@ ov::Output make_int8_weights(ov::Tensor & weight, static_cast(weight.data()), nullptr); weights_node->get_rt_info()["__gguf_tensor_holder"] = weight; auto scales_f16 = std::make_shared(scales); - ov::Tensor biases_u8(ov::element::u8, is_scalar_bias ? 
ov::Shape{} : scale_shape); - // Calculate zero point - const ov::float16 * bias_data = biases.data::value_type>(); - const ov::float16 * scale_data = scales.data::value_type>(); - uint8_t * bias_u8_data = biases_u8.data(); - - if (is_scalar_bias) { - // Symmetric quantization: single bias value for all blocks - // For Q8_0, bias = -128 * scale, so zero_point = 128 - bias_u8_data[0] = (uint8_t) std::round(-1.f * static_cast(bias_data[0]) / static_cast(scale_data[0])); - } else { - // Asymmetric quantization: per-block biases - for (size_t i = 0; i < biases_u8.get_size(); ++i) { - bias_u8_data[i] = - (uint8_t) std::round(-1.f * static_cast(bias_data[i]) / static_cast(scale_data[i])); - } - } - - auto zero_point = std::make_shared(biases_u8); + // Zero point is already in U8 format from extraction + auto zero_point = std::make_shared(zp); float zp_value; if (ov::op::util::get_single_value(zero_point, zp_value)) { zero_point = ov::op::v0::Constant::create(zero_point->get_element_type(), {}, {zp_value}); @@ -395,16 +403,13 @@ ov::Output make_int8_weights(ov::Tensor & weight, return std::make_shared(w_zp_s, ov::element::f32); } -ov::Output make_int4_weights(ov::Tensor & weight, - ov::Tensor & scales, - ov::Tensor & biases, - size_t group_size) { +ov::Output make_int4_weights(ov::Tensor & weight, ov::Tensor & scales, ov::Tensor & zp, size_t group_size) { ov::Shape orig_weight_shape = weight.get_shape(); - // Expand dimensions for scales and biases - ov::Shape scale_bias_shape = scales.get_shape(); - auto bias_shape = biases.get_shape(); - bool is_scalar_bias = bias_shape.empty(); // Symmetric quantization + // Expand dimensions for scales and zp + ov::Shape scale_shape = scales.get_shape(); + auto zp_shape = zp.get_shape(); + bool is_scalar_zp = zp_shape.empty(); // Symmetric quantization // Create INT4 weight tensor ov::Shape packed_shape = {orig_weight_shape[0], orig_weight_shape[1] / group_size, group_size}; @@ -413,12 +418,12 @@ ov::Output 
make_int4_weights(ov::Tensor & weight, // Requantized channel-wise case packed_shape.erase(packed_shape.begin() + 1); } else { - scale_bias_shape.push_back(1); - scales.set_shape(scale_bias_shape); - // For symmetric quantization, biases remain scalar (don't resize) - if (!is_scalar_bias) { - bias_shape = scale_bias_shape; - biases.set_shape(bias_shape); + scale_shape.push_back(1); + scales.set_shape(scale_shape); + // For symmetric quantization, zp remains scalar (don't resize) + if (!is_scalar_zp) { + zp_shape.push_back(1); + zp.set_shape(zp_shape); } } @@ -427,29 +432,8 @@ ov::Output make_int4_weights(ov::Tensor & weight, weights_node->get_rt_info()["__gguf_tensor_holder"] = weight; auto weights_f16 = std::make_shared(weights_node, ov::element::f16); - // Pack zero points: two subsequent values into one - const ov::float16 * bias_data = biases.data::value_type>(); - const ov::float16 * scale_data = scales.data::value_type>(); - ov::Tensor zero_point_tensor(ov::element::u4, is_scalar_bias ? 
ov::Shape{} : scale_bias_shape); - uint8_t * zero_point_data = static_cast(zero_point_tensor.data()); - - if (is_scalar_bias) { - // Symmetric quantization: single bias value for all blocks - // For Q4_0, bias = -8 * scale, so zero_point = 8 - uint8_t zp = (uint8_t) std::round(-1.f * static_cast(bias_data[0]) / static_cast(scale_data[0])); - zero_point_data[0] = (zp << 4) | (zp & 0x0F); - } else { - // Asymmetric quantization: per-block biases - for (size_t i = 0; i < zero_point_tensor.get_byte_size(); ++i) { - uint8_t bias1 = - (uint8_t) std::round(-1.f * static_cast(bias_data[i * 2]) / static_cast(scale_data[i * 2])); - uint8_t bias2 = (uint8_t) std::round(-1.f * static_cast(bias_data[i * 2 + 1]) / - static_cast(scale_data[i * 2 + 1])); - zero_point_data[i] = (bias2 << 4) | (bias1 & 0x0F); - } - } - - auto zero_points_node = std::make_shared(zero_point_tensor); + // Zero point is already in U4 format from extraction + auto zero_points_node = std::make_shared(zp); float zp_value; if (ov::op::util::get_single_value(zero_points_node, zp_value)) { zero_points_node = ov::op::v0::Constant::create(zero_points_node->get_element_type(), {}, {zp_value}); @@ -480,7 +464,7 @@ std::shared_ptr extract_quantized_weights(const ggml_tensor * tensor, const void * data, ov::Tensor & weights, ov::Tensor & scales, - ov::Tensor & biases) { + ov::Tensor & zp) { // Create a temporary tensor for extraction functions that read from tensor->data ggml_tensor temp_tensor = *tensor; temp_tensor.data = const_cast(data); @@ -512,22 +496,22 @@ std::shared_ptr extract_quantized_weights(const ggml_tensor * tensor, // Extract quantized data switch (tensor->type) { case GGML_TYPE_Q4_0: - extract_q4_0_data(&temp_tensor, weights, scales, biases); + extract_q4_0_data(&temp_tensor, weights, scales, zp); break; case GGML_TYPE_Q4_1: - extract_q4_1_data(&temp_tensor, weights, scales, biases); + extract_q4_1_data(&temp_tensor, weights, scales, zp); break; case GGML_TYPE_Q4_K: - 
extract_q4_k_data(&temp_tensor, weights, scales, biases); + extract_q4_k_data(&temp_tensor, weights, scales, zp); break; case GGML_TYPE_Q8_0: - extract_q8_0_data(&temp_tensor, weights, scales, biases); + extract_q8_0_data(&temp_tensor, weights, scales, zp); break; case GGML_TYPE_Q6_K: - extract_q6_k_data(&temp_tensor, weights, scales, biases); + extract_q6_k_data(&temp_tensor, weights, scales, zp); break; case GGML_TYPE_Q5_K: - extract_q5_k_data(&temp_tensor, weights, scales, biases); + extract_q5_k_data(&temp_tensor, weights, scales, zp); break; default: throw std::runtime_error("Unsupported quantized type: " + std::string(ggml_type_name(tensor->type))); @@ -536,9 +520,9 @@ std::shared_ptr extract_quantized_weights(const ggml_tensor * tensor, // Create the OpenVINO weight subgraph ov::Output weight_node; if (is_u4) { - weight_node = make_int4_weights(weights, scales, biases, weights_per_block); + weight_node = make_int4_weights(weights, scales, zp, weights_per_block); } else { - weight_node = make_int8_weights(weights, scales, biases, weights_per_block); + weight_node = make_int8_weights(weights, scales, zp, weights_per_block); } auto result = weight_node.get_node_shared_ptr(); @@ -553,7 +537,7 @@ std::shared_ptr requantize_to_buffers(const ggml_tensor * tensor, int64_t block_size, ov::Tensor & weights, ov::Tensor & scales, - ov::Tensor & biases) { + ov::Tensor & zp) { int64_t n_elements = ggml_nelements(tensor); // First dequantize to F32 @@ -572,19 +556,19 @@ std::shared_ptr requantize_to_buffers(const ggml_tensor * tensor, bool is_u4 = (requant_type == ExtraQuantType::Q4_0_C || requant_type == ExtraQuantType::Q4_0_128); if (is_u4) { - quantize_q4_0(weights_f32.data(), weights, scales, biases, n_elements, block_size); + quantize_q4_0(weights_f32.data(), weights, scales, zp, n_elements, block_size); } else if (requant_type == ExtraQuantType::Q8_1_C) { - quantize_q8_1(weights_f32.data(), weights, scales, biases, n_elements, block_size); + 
quantize_q8_1(weights_f32.data(), weights, scales, zp, n_elements, block_size); } else { - quantize_q8_0(weights_f32.data(), weights, scales, biases, n_elements, block_size); + quantize_q8_0(weights_f32.data(), weights, scales, zp, n_elements, block_size); } // Create the OpenVINO weight subgraph ov::Output weight_node; if (is_u4) { - weight_node = make_int4_weights(weights, scales, biases, block_size); + weight_node = make_int4_weights(weights, scales, zp, block_size); } else { - weight_node = make_int8_weights(weights, scales, biases, block_size); + weight_node = make_int8_weights(weights, scales, zp, block_size); } auto result = weight_node.get_node_shared_ptr(); @@ -653,50 +637,52 @@ std::shared_ptr process_weight_tensor(const ggml_tensor * tensor, cons } else { weights = ov::Tensor(ov::element::f16, node_shape); } - ov::Tensor dummy_scales, dummy_biases; // Not used for F16 - result = requantize_to_buffers(tensor, data, ExtraQuantType::F16, 0, weights, dummy_scales, dummy_biases); + ov::Tensor dummy_scales, dummy_zp; // Not used for F16 + result = requantize_to_buffers(tensor, data, ExtraQuantType::F16, 0, weights, dummy_scales, dummy_zp); } else { // Requant to quantized format (Q4_0_128, Q8_0_32, etc.) ov::element::Type weight_type = layout.is_u4 ? ov::element::u4 : ov::element::u8; ov::Shape scale_shape = {node_shape[0], node_shape[1] / layout.weights_per_block}; - // For symmetric quantization, biases are a single value instead of per-block - ov::Shape bias_shape = layout.is_symmetric ? ov::Shape{} : scale_shape; + // For symmetric quantization, zp is a scalar value instead of per-block + // zp uses the same element type as weights (U4 or U8) + ov::Shape zp_shape = layout.is_symmetric ? 
ov::Shape{} : scale_shape; - ov::Tensor weights, scales, biases; + ov::Tensor weights, scales, zp; if (output_base_ptr) { uint8_t * buf_base = static_cast(output_base_ptr); weights = ov::Tensor(weight_type, node_shape, buf_base + layout.weights_offset); scales = ov::Tensor(ov::element::f16, scale_shape, buf_base + layout.scales_offset); - biases = ov::Tensor(ov::element::f16, bias_shape, buf_base + layout.biases_offset); + zp = ov::Tensor(weight_type, zp_shape, buf_base + layout.zp_offset); } else { weights = ov::Tensor(weight_type, node_shape); scales = ov::Tensor(ov::element::f16, scale_shape); - biases = ov::Tensor(ov::element::f16, bias_shape); + zp = ov::Tensor(weight_type, zp_shape); } result = requantize_to_buffers(tensor, data, layout.requant_type.value(), layout.weights_per_block, weights, - scales, biases); + scales, zp); } } else { // Normal extraction path (no requant) ov::element::Type weight_type = layout.is_u4 ? ov::element::u4 : ov::element::u8; ov::Shape scale_shape = {node_shape[0], node_shape[1] / layout.weights_per_block}; - // For symmetric quantization, biases are a single value instead of per-block - ov::Shape bias_shape = layout.is_symmetric ? ov::Shape{} : scale_shape; + // For symmetric quantization, zp is a scalar value instead of per-block + // zp uses the same element type as weights (U4 or U8) + ov::Shape zp_shape = layout.is_symmetric ? 
ov::Shape{} : scale_shape; - ov::Tensor weights, scales, biases; + ov::Tensor weights, scales, zp; if (output_base_ptr) { uint8_t * buf_base = static_cast(output_base_ptr); weights = ov::Tensor(weight_type, node_shape, buf_base + layout.weights_offset); scales = ov::Tensor(ov::element::f16, scale_shape, buf_base + layout.scales_offset); - biases = ov::Tensor(ov::element::f16, bias_shape, buf_base + layout.biases_offset); + zp = ov::Tensor(weight_type, zp_shape, buf_base + layout.zp_offset); } else { weights = ov::Tensor(weight_type, node_shape); scales = ov::Tensor(ov::element::f16, scale_shape); - biases = ov::Tensor(ov::element::f16, bias_shape); + zp = ov::Tensor(weight_type, zp_shape); } - result = extract_quantized_weights(tensor, data, weights, scales, biases); + result = extract_quantized_weights(tensor, data, weights, scales, zp); } return result; @@ -705,7 +691,7 @@ std::shared_ptr process_weight_tensor(const ggml_tensor * tensor, cons void quantize_q4_0(const float * x, ov::Tensor & weights_arr, ov::Tensor & scales_arr, - ov::Tensor & biases_arr, + ov::Tensor & zp_arr, int64_t k, int64_t qk) { assert(k % qk == 0); @@ -713,8 +699,13 @@ void quantize_q4_0(const float * x, auto * weights = static_cast(weights_arr.data()); auto * scales = scales_arr.data::value_type>(); - auto * biases = biases_arr.data::value_type>(); - bool is_scalar_bias = (biases_arr.get_size() == 1); // Symmetric quantization + auto * zp = static_cast(zp_arr.data()); + bool is_scalar_zp = (zp_arr.get_size() == 1); // Symmetric quantization + + // For Q4_0, zero point is always 8 + if (is_scalar_zp) { + zp[0] = 8 | (8 << 4); // Pack two 4-bit values + } for (int i = 0; i < nb; i++) { float amax = 0.0f; // absolute max @@ -732,27 +723,27 @@ void quantize_q4_0(const float * x, if (d == 0) { scales[i] = ov::float16(1.0f); - if (is_scalar_bias) { - if (i == 0) { - biases[0] = ov::float16(-8.0f); + // zp is already set to 8 for symmetric, or set per-block for asymmetric + if (!is_scalar_zp) { 
+ if (i % 2 == 0) { + zp[i / 2] = 8; + } else { + zp[i / 2] |= (8 << 4); } - } else { - biases[i] = ov::float16(-8.0f); } - uint8_t zp = 8; - memset(weights + i * qk / 2, zp | (zp << 4), qk / 2); + memset(weights + i * qk / 2, 8 | (8 << 4), qk / 2); continue; } const float id = 1.0f / d; scales[i] = ov::float16(d); - // For symmetric quantization, only write the first bias (all blocks share the same bias relationship) - if (is_scalar_bias) { - if (i == 0) { - biases[0] = ov::float16(-8.f * d); + // For asymmetric quantization, store per-block zero points + if (!is_scalar_zp) { + if (i % 2 == 0) { + zp[i / 2] = 8; + } else { + zp[i / 2] |= (8 << 4); } - } else { - biases[i] = ov::float16(-8.f * d); } for (int j = 0; j < qk / 2; ++j) { @@ -768,7 +759,7 @@ void quantize_q4_0(const float * x, void quantize_q8_0(const float * x, ov::Tensor & weights_arr, ov::Tensor & scales_arr, - ov::Tensor & biases_arr, + ov::Tensor & zp_arr, int64_t k, int64_t qk) { assert(k % qk == 0); @@ -776,8 +767,13 @@ void quantize_q8_0(const float * x, auto * weights = static_cast(weights_arr.data()); auto * scales = scales_arr.data::value_type>(); - auto * biases = biases_arr.data::value_type>(); - bool is_scalar_bias = (biases_arr.get_size() == 1); // Symmetric quantization + auto * zp = static_cast(zp_arr.data()); + bool is_scalar_zp = (zp_arr.get_size() == 1); // Symmetric quantization + + // For Q8_0, zero point is always 128 + if (is_scalar_zp) { + zp[0] = 128; + } for (int i = 0; i < nb; i++) { float amax = 0.0f; // absolute max @@ -792,13 +788,9 @@ void quantize_q8_0(const float * x, const float d = amax / 127.0f; const float id = d ? 
1.0f / d : 0.0f; scales[i] = ov::float16(d); - // For symmetric quantization, only write the first bias (all blocks share the same bias relationship) - if (is_scalar_bias) { - if (i == 0) { - biases[0] = ov::float16(-128.0f * d); - } - } else { - biases[i] = ov::float16(-128.0f * d); + // For asymmetric quantization, store per-block zero points + if (!is_scalar_zp) { + zp[i] = 128; } for (int j = 0; j < qk; ++j) { @@ -812,7 +804,7 @@ void quantize_q8_0(const float * x, void quantize_q8_1(const float * x, ov::Tensor & weights_arr, ov::Tensor & scales_arr, - ov::Tensor & biases_arr, + ov::Tensor & zp_arr, int64_t k, int64_t qk) { assert(k % qk == 0); @@ -820,7 +812,7 @@ void quantize_q8_1(const float * x, auto * weights = static_cast(weights_arr.data()); auto * scales = scales_arr.data::value_type>(); - auto * biases = biases_arr.data::value_type>(); + auto * zp = static_cast(zp_arr.data()); for (int i = 0; i < nb; i++) { float min = std::numeric_limits::max(); float max = std::numeric_limits::lowest(); @@ -838,7 +830,8 @@ void quantize_q8_1(const float * x, const float d = (max - min) / ((1 << 8) - 1); const float id = d ? 1.0f / d : 0.0f; scales[i] = ov::float16(d); - biases[i] = ov::float16(min); + // zp = -min / scale (Q8_1 is asymmetric) + zp[i] = (d != 0.0f) ? 
(uint8_t) std::round(-min / d) : 0; for (int j = 0; j < qk; ++j) { const float x0 = (x[i * qk + j] - min) * id; diff --git a/ggml/src/ggml-openvino/ggml-quants.hpp b/ggml/src/ggml-openvino/ggml-quants.hpp index a1334e2408d..67396892642 100644 --- a/ggml/src/ggml-openvino/ggml-quants.hpp +++ b/ggml/src/ggml-openvino/ggml-quants.hpp @@ -8,52 +8,52 @@ void unpack_32_4(const uint8_t* data, uint8_t* dst); -void extract_q4_0_data(const ggml_tensor* tensor, - ov::Tensor& weights_arr, - ov::Tensor& scales_arr, - ov::Tensor& biases_arr); +void extract_q4_0_data(const ggml_tensor * tensor, + ov::Tensor & weights_arr, + ov::Tensor & scales_arr, + ov::Tensor & zp_arr); -void extract_q4_1_data(const ggml_tensor* tensor, - ov::Tensor& weights_arr, - ov::Tensor& scales_arr, - ov::Tensor& biases_arr); +void extract_q4_1_data(const ggml_tensor * tensor, + ov::Tensor & weights_arr, + ov::Tensor & scales_arr, + ov::Tensor & zp_arr); -void extract_q8_0_data(const ggml_tensor* tensor, - ov::Tensor& weights_arr, - ov::Tensor& scales_arr, - ov::Tensor& biases_arr); +void extract_q8_0_data(const ggml_tensor * tensor, + ov::Tensor & weights_arr, + ov::Tensor & scales_arr, + ov::Tensor & zp_arr); void unpack_256_4(const uint8_t* data, uint8_t* dst); -void extract_q4_k_data(const ggml_tensor* tensor, - ov::Tensor& weights_arr, - ov::Tensor& scales_arr, - ov::Tensor& biases_arr); +void extract_q4_k_data(const ggml_tensor * tensor, + ov::Tensor & weights_arr, + ov::Tensor & scales_arr, + ov::Tensor & zp_arr); -void extract_q5_k_data(const ggml_tensor* tensor, - ov::Tensor& weights_arr, - ov::Tensor& scales_arr, - ov::Tensor& biases_arr); +void extract_q5_k_data(const ggml_tensor * tensor, + ov::Tensor & weights_arr, + ov::Tensor & scales_arr, + ov::Tensor & zp_arr); -void extract_q6_k_data(const ggml_tensor* tensor, - ov::Tensor& weights_arr, - ov::Tensor& scales_arr, - ov::Tensor& biases_arr); +void extract_q6_k_data(const ggml_tensor * tensor, + ov::Tensor & weights_arr, + ov::Tensor & 
scales_arr, + ov::Tensor & zp_arr); static constexpr size_t GGML_QUANTIZATION_GROUP_SIZE = 32; -ov::Output make_int8_weights(ov::Tensor& weight, - ov::Tensor& scales, - ov::Tensor& biases, +ov::Output make_int8_weights(ov::Tensor & weight, + ov::Tensor & scales, + ov::Tensor & zp, size_t group_size = GGML_QUANTIZATION_GROUP_SIZE); -ov::Output make_int4_weights(ov::Tensor& weight, - ov::Tensor& scales, - ov::Tensor& biases, +ov::Output make_int4_weights(ov::Tensor & weight, + ov::Tensor & scales, + ov::Tensor & zp, size_t group_size = GGML_QUANTIZATION_GROUP_SIZE); // Extract quantized weights from tensor and create weight subgraph -// If weights/scales/biases are provided (non-empty), uses them as output buffers +// If weights/scales/zp are provided (non-empty), uses them as output buffers // Otherwise allocates new ov::Tensors internally // Returns the weight node (make_int4_weights or make_int8_weights result) std::shared_ptr extract_quantized_weights( @@ -61,10 +61,10 @@ std::shared_ptr extract_quantized_weights( const void * data, // Source data pointer (may differ from tensor->data) ov::Tensor & weights, ov::Tensor & scales, - ov::Tensor & biases); + ov::Tensor & zp); // Requantize weights from tensor to target format, writing to provided buffers -// For F16 target, only weights buffer is used (scales/biases ignored) +// For F16 target, only weights buffer is used (scales/zp ignored) // Returns the weight node std::shared_ptr requantize_to_buffers(const ggml_tensor * tensor, const void * data, // Source data pointer @@ -72,7 +72,7 @@ std::shared_ptr requantize_to_buffers(const ggml_tensor * tensor, int64_t block_size, ov::Tensor & weights, ov::Tensor & scales, - ov::Tensor & biases); + ov::Tensor & zp); // Process weight tensor and create an OpenVINO constant node // Handles F16/F32/BF16 and quantized weights, with optional requantization @@ -84,11 +84,23 @@ std::shared_ptr process_weight_tensor( const void * data, // Source data pointer (may differ from 
tensor->data) void * output_base_ptr = nullptr); // Base pointer for output buffers (or nullptr for internal allocation) -void quantize_q4_0(const float* x, ov::Tensor& weights_arr, ov::Tensor& scales_arr, ov::Tensor& biases_arr, int64_t k, +void quantize_q4_0(const float * x, + ov::Tensor & weights_arr, + ov::Tensor & scales_arr, + ov::Tensor & zp_arr, + int64_t k, int64_t qk); -void quantize_q8_1(const float* x, ov::Tensor& weights_arr, ov::Tensor& scales_arr, ov::Tensor& biases_arr, int64_t k, +void quantize_q8_1(const float * x, + ov::Tensor & weights_arr, + ov::Tensor & scales_arr, + ov::Tensor & zp_arr, + int64_t k, int64_t qk); -void quantize_q8_0(const float* x, ov::Tensor& weights_arr, ov::Tensor& scales_arr, ov::Tensor& biases_arr, int64_t k, +void quantize_q8_0(const float * x, + ov::Tensor & weights_arr, + ov::Tensor & scales_arr, + ov::Tensor & zp_arr, + int64_t k, int64_t qk); namespace ov { From 6d71ded5faff3f59ede67a945b89207d4084a97f Mon Sep 17 00:00:00 2001 From: "Yu, Zijun" Date: Fri, 6 Feb 2026 20:09:12 +0800 Subject: [PATCH 10/11] Refactor weight tensor processing --- ggml/src/ggml-openvino/ggml-decoder.cpp | 72 ++++++------ .../src/ggml-openvino/ggml-openvino-extra.cpp | 1 + ggml/src/ggml-openvino/ggml-openvino-extra.h | 41 ++++--- ggml/src/ggml-openvino/ggml-openvino.cpp | 101 +++++++---------- ggml/src/ggml-openvino/ggml-quants.cpp | 106 +++++++----------- ggml/src/ggml-openvino/ggml-quants.hpp | 37 +++++- 6 files changed, 181 insertions(+), 177 deletions(-) diff --git a/ggml/src/ggml-openvino/ggml-decoder.cpp b/ggml/src/ggml-openvino/ggml-decoder.cpp index d8d71cf25ee..da381e4fad7 100644 --- a/ggml/src/ggml-openvino/ggml-decoder.cpp +++ b/ggml/src/ggml-openvino/ggml-decoder.cpp @@ -550,11 +550,6 @@ std::map> GgmlOvDecoder::create_weight_no return model_weights; } -// Static cache for quantized weight nodes (keyed by tensor data pointer) -// This is a fallback for when tensors don't have pre-built constants in extra -static 
std::unordered_map> s_quantized_weight_cache; -static std::mutex s_quantized_weight_cache_mutex; - std::shared_ptr GgmlOvDecoder::create_weight_node(ggml_tensor * tensor) { // Check if we have a pre-built constant from the OpenVINO backend buffer // This is set during ggml_backend_openvino_buffer_set_tensor @@ -569,51 +564,62 @@ std::shared_ptr GgmlOvDecoder::create_weight_node(ggml_tensor * tensor if (extra_base->type == ggml_openvino_extra_base::Type::WEIGHT) { // F16/F32/BF16 weight with shared-memory constant auto * weight_extra = static_cast(tensor->extra); - if (weight_extra->constant) { - GGML_LOG_DEBUG("%s: using pre-built constant for %s\n", __func__, tensor->name); - return weight_extra->constant; + if (weight_extra->weight_node) { + GGML_LOG_DEBUG("%s: using pre-built weight node for %s\n", __func__, tensor->name); + return weight_extra->weight_node; } } else if (extra_base->type == ggml_openvino_extra_base::Type::QUANTIZED_WEIGHT) { // Quantized weight with pre-extracted data auto * quant_extra = static_cast(tensor->extra); - if (quant_extra->constant) { - GGML_LOG_DEBUG("%s: using pre-extracted quantized constant for %s\n", __func__, tensor->name); - return quant_extra->constant; + if (quant_extra->weight_node) { + GGML_LOG_DEBUG("%s: using pre-extracted quantized weight node for %s\n", __func__, tensor->name); + return quant_extra->weight_node; } } } - // Fallback: Check static cache for quantized weights (keyed by data pointer) - // This handles cases where tensors weren't loaded through OpenVINO buffer - if (ggml_is_quantized(tensor->type)) { - std::lock_guard lock(s_quantized_weight_cache_mutex); - auto it = s_quantized_weight_cache.find(tensor->data); - if (it != s_quantized_weight_cache.end()) { - GGML_LOG_DEBUG("%s: using cached quantized constant for %s\n", __func__, tensor->name); - return it->second; - } - } - - GGML_LOG_DEBUG("%s: creating new constant for %s (extra=%p)\n", __func__, tensor->name, tensor->extra); + // Fallback: tensor 
doesn't have a pre-built extra. The buffer type can only be + // openvino_host_buffer_type, which has enough space (get_alloc_size returns + // layout.total_size for quantized 2D tensors) to store extracted data in-place. + // Build the weight node and store it in tensor->extra for future reuse. + GGML_LOG_DEBUG("%s: creating new weight node for %s\n", __func__, tensor->name); - std::set weight_types = {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0, - GGML_TYPE_Q4_1, GGML_TYPE_Q4_K, GGML_TYPE_Q5_K, GGML_TYPE_Q6_K}; + static const std::set weight_types = {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16, + GGML_TYPE_Q8_0, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1, + GGML_TYPE_Q4_K, GGML_TYPE_Q5_K, GGML_TYPE_Q6_K}; if (weight_types.find(tensor->type) == weight_types.end()) { throw std::runtime_error("Unexpected weight tensor type: " + std::string(tensor->name) + " with type " + ggml_type_name(tensor->type)); } - std::shared_ptr result = process_weight_tensor(tensor, tensor->data, nullptr); - result->set_friendly_name(tensor->name); - - // Cache the quantized weight node for future reuse + OvWeight ov_weight; if (ggml_is_quantized(tensor->type)) { - std::lock_guard lock(s_quantized_weight_cache_mutex); - s_quantized_weight_cache[tensor->data] = result; - GGML_LOG_DEBUG("%s: cached quantized constant for %s\n", __func__, tensor->name); + // For quantized weights, copy raw data to a temp buffer first because + // process_weight_tensor reads from data and writes extracted results + // (weights/scales/zp) to output_base_ptr — they would overlap if both + // point to tensor->data. + size_t raw_size = ggml_nbytes(tensor); + std::vector tmp(raw_size); + memcpy(tmp.data(), tensor->data, raw_size); + ov_weight = process_weight_tensor(tensor, tmp.data(), tensor->data); + } else { + // For non-quantized weights (F16/F32/BF16), data is already in tensor->data. + // process_weight_tensor will create an ov::Tensor wrapping tensor->data directly. 
+ ov_weight = process_weight_tensor(tensor, tensor->data, tensor->data); + } + + ov_weight.weight_node->set_friendly_name(tensor->name); + + ggml_openvino_extra_base * extra; + if (ov_weight.is_quantized()) { + extra = new ggml_openvino_quantized_weight_extra(std::move(ov_weight.weights), std::move(ov_weight.scales), + std::move(ov_weight.zp), ov_weight.weight_node); + } else { + extra = new ggml_openvino_weight_extra(std::move(ov_weight.weights), ov_weight.weight_node); } + ggml_openvino_buffer_register_extra(tensor, extra); - return result; + return ov_weight.weight_node; } void GgmlOvDecoder::dump_cgraph(const ggml_cgraph * cgraph, std::string & filename) { diff --git a/ggml/src/ggml-openvino/ggml-openvino-extra.cpp b/ggml/src/ggml-openvino/ggml-openvino-extra.cpp index 4584dc38d0e..39bf7610eb5 100644 --- a/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +++ b/ggml/src/ggml-openvino/ggml-openvino-extra.cpp @@ -319,6 +319,7 @@ ggml_openvino_extracted_layout ggml_openvino_get_extracted_layout(const ggml_ten layout.scales_offset = ((layout.weights_size + alignment - 1) / alignment) * alignment; layout.zp_offset = layout.scales_offset + ((layout.scales_size + alignment - 1) / alignment) * alignment; layout.total_size = layout.zp_offset + layout.zp_size; + layout.total_size = std::max(layout.total_size, ggml_nbytes(tensor)); return layout; } diff --git a/ggml/src/ggml-openvino/ggml-openvino-extra.h b/ggml/src/ggml-openvino/ggml-openvino-extra.h index 726a90abb02..9ce46671548 100644 --- a/ggml/src/ggml-openvino/ggml-openvino-extra.h +++ b/ggml/src/ggml-openvino/ggml-openvino-extra.h @@ -102,27 +102,30 @@ struct ggml_openvino_extra_base { explicit ggml_openvino_extra_base(Type t) : type(t) {} }; -// Extra data for F16/F32/BF16 weight tensors - stores the pre-built ov::Constant node +// Extra data for F16/F32/BF16 weight tensors - stores the pre-built weight node struct ggml_openvino_weight_extra : public ggml_openvino_extra_base { - std::shared_ptr constant; // 
Pre-built OpenVINO Constant node + ov::Tensor weights; // The underlying weight data tensor + std::shared_ptr weight_node; // Pre-built OpenVINO weight node - explicit ggml_openvino_weight_extra(std::shared_ptr c) - : ggml_openvino_extra_base(Type::WEIGHT), constant(std::move(c)) {} + ggml_openvino_weight_extra(ov::Tensor w, std::shared_ptr n) : + ggml_openvino_extra_base(Type::WEIGHT), + weights(std::move(w)), + weight_node(std::move(n)) {} }; -// Extra data for quantized weight tensors - stores extracted weights/scales/zp and ov::Constant +// Extra data for quantized weight tensors - stores extracted weights/scales/zp and weight node struct ggml_openvino_quantized_weight_extra : public ggml_openvino_extra_base { ov::Tensor weights; // U4 or U8 extracted weights ov::Tensor scales; // F16 scales ov::Tensor zp; // U4 or U8 zero points (same type as weights) - std::shared_ptr constant; // Pre-built OpenVINO weight subgraph + std::shared_ptr weight_node; // Pre-built OpenVINO weight subgraph - ggml_openvino_quantized_weight_extra(ov::Tensor w, ov::Tensor s, ov::Tensor z, std::shared_ptr c) : + ggml_openvino_quantized_weight_extra(ov::Tensor w, ov::Tensor s, ov::Tensor z, std::shared_ptr n) : ggml_openvino_extra_base(Type::QUANTIZED_WEIGHT), weights(std::move(w)), scales(std::move(s)), zp(std::move(z)), - constant(std::move(c)) {} + weight_node(std::move(n)) {} }; // Extra data for KV cache / compute tensors - stores ov::Tensor for infer_request @@ -140,19 +143,19 @@ struct ggml_openvino_tensor_extra : public ggml_openvino_extra_base { // Returns the total size needed in the buffer for extracted data. 
struct ggml_openvino_extracted_layout { - size_t total_size; // Total bytes needed - size_t weights_offset; // Offset to weights in buffer - size_t weights_size; // Size of weights in bytes - size_t scales_offset; // Offset to scales in buffer - size_t scales_size; // Size of scales in bytes - size_t zp_offset; // Offset to zero points in buffer - size_t zp_size; // Size of zero points in bytes (U4 or U8) - bool is_u4; // true for U4 weights, false for U8 + size_t total_size = 0; // Total bytes needed + size_t weights_offset = 0; // Offset to weights in buffer + size_t weights_size = 0; // Size of weights in bytes + size_t scales_offset = 0; // Offset to scales in buffer + size_t scales_size = 0; // Size of scales in bytes + size_t zp_offset = 0; // Offset to zero points in buffer + size_t zp_size = 0; // Size of zero points in bytes (U4 or U8) + bool is_u4; // true for U4 weights, false for U8 int64_t weights_per_block; // weights per scale/zp block bool is_symmetric; // true for symmetric quantization // Requantization info - bool is_requant; // true if this tensor needs requantization + bool is_requant = false; // true if this tensor needs requantization std::optional requant_type; // target requant type if is_requant }; @@ -160,3 +163,7 @@ struct ggml_openvino_extracted_layout { ggml_openvino_extracted_layout ggml_openvino_get_extracted_layout(const ggml_tensor * tensor); ggml_openvino_tensor_extra * ggml_openvino_create_tensor_extra(const ggml_tensor * tensor, bool is_remote); + +// Register an extra with the tensor's OpenVINO buffer context for proper lifetime management. +// This sets tensor->extra and tracks the extra in the buffer context for cleanup. 
+void ggml_openvino_buffer_register_extra(ggml_tensor * tensor, ggml_openvino_extra_base * extra); diff --git a/ggml/src/ggml-openvino/ggml-openvino.cpp b/ggml/src/ggml-openvino/ggml-openvino.cpp index e531a9c0362..efd399fe3f2 100644 --- a/ggml/src/ggml-openvino/ggml-openvino.cpp +++ b/ggml/src/ggml-openvino/ggml-openvino.cpp @@ -230,80 +230,45 @@ static void ggml_backend_openvino_buffer_set_tensor(ggml_backend_buffer_t buffer // 2D tensor (typical weight shape) bool is_2d = (tensor->ne[2] == 1 && tensor->ne[3] == 1); - // Check if this is a quantized weight tensor that needs extraction/requantization - ggml_openvino_extracted_layout layout = {}; - if (is_weight_buffer && is_full_tensor_set && is_2d && ggml_is_quantized(tensor->type)) { - layout = ggml_openvino_get_extracted_layout(tensor); - } + if (is_weight_buffer && is_full_tensor_set && is_2d) { + try { + auto result = process_weight_tensor(tensor, data, tensor->data); + result.weight_node->set_friendly_name(tensor->name); - if (layout.total_size > 0) { - // Quantized weight tensor with extraction/requantization - uint8_t * buf_base = (uint8_t *) tensor->data; + const auto & layout = result.layout; + ggml_openvino_extra_base * extra; - try { - std::shared_ptr constant = process_weight_tensor(tensor, data, buf_base); - constant->set_friendly_name(tensor->name); - - // Store in tensor->extra - if (layout.is_requant && layout.requant_type.has_value() && - layout.requant_type.value() == ExtraQuantType::F16) { - // F16 requant case - use weight_extra - auto * extra = new ggml_openvino_weight_extra(constant); - ctx->tensor_extras[tensor] = extra; - tensor->extra = extra; - GGML_LOG_DEBUG("%s: requantized %s to F16\n", __func__, tensor->name); - } else { - // Quantized case - use quantized_weight_extra - // Create tensors with external memory (already filled by process_weight_tensor) - ov::element::Type weight_type = layout.is_u4 ? 
ov::element::u4 : ov::element::u8; - ov::Shape weight_shape = {static_cast(tensor->ne[1]), static_cast(tensor->ne[0])}; - ov::Shape scale_shape = {static_cast(tensor->ne[1]), - static_cast(tensor->ne[0] / layout.weights_per_block)}; - // zp shape: scalar for symmetric, per-block for asymmetric - ov::Shape zp_shape = layout.is_symmetric ? ov::Shape{} : scale_shape; - - ov::Tensor weights(weight_type, weight_shape, buf_base + layout.weights_offset); - ov::Tensor scales(ov::element::f16, scale_shape, buf_base + layout.scales_offset); - ov::Tensor zp(weight_type, zp_shape, buf_base + layout.zp_offset); - - auto * extra = new ggml_openvino_quantized_weight_extra(std::move(weights), std::move(scales), - std::move(zp), constant); - ctx->tensor_extras[tensor] = extra; - tensor->extra = extra; + // Quantized path with extracted weight/scale/zp tensors + if (result.is_quantized()) { + extra = new ggml_openvino_quantized_weight_extra(std::move(result.weights), std::move(result.scales), + std::move(result.zp), result.weight_node); if (layout.is_requant) { GGML_LOG_DEBUG("%s: requantized %s to %s (u%d, block_size=%ld)\n", __func__, tensor->name, - layout.requant_type.value() == ExtraQuantType::Q4_0_128 ? "Q4_0_128" : "Q8_0_32", - layout.is_u4 ? 4 : 8, layout.weights_per_block); + extra_quant_type_name(layout.requant_type.value()), layout.is_u4 ? 4 : 8, + layout.weights_per_block); } else { int64_t n_blocks = ggml_nelements(tensor) / layout.weights_per_block; - GGML_LOG_DEBUG("%s: extracted quantized constant for %s (u%d, %zu weights, %ld blocks)\n", __func__, - tensor->name, layout.is_u4 ? 4 : 8, layout.weights_size, n_blocks); + GGML_LOG_DEBUG("%s: extracted quantized weight node for %s (u%d, %zu weights, %ld blocks)\n", + __func__, tensor->name, layout.is_u4 ? 
4 : 8, layout.weights_size, n_blocks); } - } + } else { + // F16/F32/BF16 weight or F16-requant + extra = new ggml_openvino_weight_extra(std::move(result.weights), result.weight_node); - } catch (const std::exception & e) { - GGML_LOG_ERROR("%s: failed to process quantized data for %s: %s\n", __func__, tensor->name, e.what()); - // Fall back to storing raw data - memcpy((char *) tensor->data + offset, data, size); - } - } else if (is_weight_buffer && is_full_tensor_set && is_2d && - (tensor->type == GGML_TYPE_F32 || tensor->type == GGML_TYPE_F16 || tensor->type == GGML_TYPE_BF16)) { - // F16/F32/BF16 weight tensor - try { - std::shared_ptr constant = process_weight_tensor(tensor, data, tensor->data); - constant->set_friendly_name(tensor->name); + if (layout.total_size > 0) { + GGML_LOG_DEBUG("%s: requantized %s to F16\n", __func__, tensor->name); + } else { + GGML_LOG_DEBUG("%s: created shared-memory weight node for %s\n", __func__, tensor->name); + } + } - // Store in tensor->extra - ggml_openvino_weight_extra * extra = new ggml_openvino_weight_extra(constant); ctx->tensor_extras[tensor] = extra; tensor->extra = extra; - GGML_LOG_DEBUG("%s: created shared-memory constant for %s\n", __func__, tensor->name); - } catch (const std::exception & e) { - GGML_LOG_DEBUG("%s: failed to create shared-memory constant for %s: %s\n", __func__, tensor->name, - e.what()); + GGML_LOG_ERROR("%s: failed to process weight tensor for %s: %s\n", __func__, tensor->name, e.what()); + memcpy((char *) tensor->data + offset, data, size); } } else { // Non-weight tensor (KV cache, activations, etc.) 
- copy data @@ -604,6 +569,22 @@ size_t ggml_backend_openvino_buffer_get_ctx_id(ggml_backend_buffer_t buffer) { return ctx->id; } +void ggml_openvino_buffer_register_extra(ggml_tensor * tensor, ggml_openvino_extra_base * extra) { + GGML_ASSERT(tensor != nullptr); + GGML_ASSERT(tensor->buffer != nullptr); + GGML_ASSERT(ggml_backend_buffer_is_openvino(tensor->buffer)); + + auto * ctx = static_cast(tensor->buffer->context); + + auto it = ctx->tensor_extras.find(tensor); + if (it != ctx->tensor_extras.end()) { + delete it->second; + } + + ctx->tensor_extras[tensor] = extra; + tensor->extra = extra; +} + bool ggml_backend_buft_is_openvino(ggml_backend_buffer_type_t buft) { return buft->iface.get_name == ggml_backend_openvino_buffer_type_get_name; } diff --git a/ggml/src/ggml-openvino/ggml-quants.cpp b/ggml/src/ggml-openvino/ggml-quants.cpp index 2de0494c910..10909cbc1ef 100644 --- a/ggml/src/ggml-openvino/ggml-quants.cpp +++ b/ggml/src/ggml-openvino/ggml-quants.cpp @@ -576,10 +576,12 @@ std::shared_ptr requantize_to_buffers(const ggml_tensor * tensor, return result; } -std::shared_ptr process_weight_tensor(const ggml_tensor * tensor, const void * data, void * output_base_ptr) { +OvWeight process_weight_tensor(const ggml_tensor * tensor, const void * data, void * output_base_ptr) { GGML_ASSERT(tensor != nullptr); GGML_ASSERT(data != nullptr); + OvWeight result; + // Get 2D shape for weights [rows, cols] ov::Shape node_shape = {static_cast(tensor->ne[1]), static_cast(tensor->ne[0])}; @@ -600,18 +602,16 @@ std::shared_ptr process_weight_tensor(const ggml_tensor * tensor, cons OPENVINO_THROW("Unexpected tensor type in F16/F32/BF16 path"); } - if (output_base_ptr) { + if (output_base_ptr && output_base_ptr != data) { // Using external buffer - copy data and create shared-memory constant size_t tensor_bytes = ggml_nbytes(tensor); memcpy(output_base_ptr, data, tensor_bytes); - ov::Tensor ov_tensor(element_type, node_shape, output_base_ptr); - return 
std::make_shared(ov_tensor); + result.weights = ov::Tensor(element_type, node_shape, output_base_ptr); } else { - // Allocate internal buffer - ov::Tensor weights(element_type, node_shape); - memcpy(weights.data(), data, ggml_nelements(tensor) * element_type.size()); - return std::make_shared(weights); + result.weights = ov::Tensor(element_type, node_shape, data); } + result.weight_node = std::make_shared(result.weights); + return result; } // Handle quantized weights @@ -619,70 +619,48 @@ std::shared_ptr process_weight_tensor(const ggml_tensor * tensor, cons OPENVINO_THROW("Unsupported weight tensor type: ", ggml_type_name(tensor->type)); } - auto layout = ggml_openvino_get_extracted_layout(tensor); + result.layout = ggml_openvino_get_extracted_layout(tensor); + const auto & layout = result.layout; if (layout.total_size == 0) { OPENVINO_THROW("Unsupported quantized type: ", ggml_type_name(tensor->type)); } - std::shared_ptr result; - - if (layout.is_requant && layout.requant_type.has_value()) { - // Requantization path - if (layout.requant_type.value() == ExtraQuantType::F16) { - // Requant to F16 - ov::Tensor weights; - if (output_base_ptr) { - weights = ov::Tensor(ov::element::f16, node_shape, - static_cast(output_base_ptr) + layout.weights_offset); - } else { - weights = ov::Tensor(ov::element::f16, node_shape); - } - ov::Tensor dummy_scales, dummy_zp; // Not used for F16 - result = requantize_to_buffers(tensor, data, ExtraQuantType::F16, 0, weights, dummy_scales, dummy_zp); - } else { - // Requant to quantized format (Q4_0_128, Q8_0_32, etc.) - ov::element::Type weight_type = layout.is_u4 ? ov::element::u4 : ov::element::u8; - ov::Shape scale_shape = {node_shape[0], node_shape[1] / layout.weights_per_block}; - // For symmetric quantization, zp is a scalar value instead of per-block - // zp uses the same element type as weights (U4 or U8) - ov::Shape zp_shape = layout.is_symmetric ? 
ov::Shape{} : scale_shape; - - ov::Tensor weights, scales, zp; - if (output_base_ptr) { - uint8_t * buf_base = static_cast(output_base_ptr); - weights = ov::Tensor(weight_type, node_shape, buf_base + layout.weights_offset); - scales = ov::Tensor(ov::element::f16, scale_shape, buf_base + layout.scales_offset); - zp = ov::Tensor(weight_type, zp_shape, buf_base + layout.zp_offset); - } else { - weights = ov::Tensor(weight_type, node_shape); - scales = ov::Tensor(ov::element::f16, scale_shape); - zp = ov::Tensor(weight_type, zp_shape); - } - - result = requantize_to_buffers(tensor, data, layout.requant_type.value(), layout.weights_per_block, weights, - scales, zp); - } - } else { - // Normal extraction path (no requant) - ov::element::Type weight_type = layout.is_u4 ? ov::element::u4 : ov::element::u8; - ov::Shape scale_shape = {node_shape[0], node_shape[1] / layout.weights_per_block}; - // For symmetric quantization, zp is a scalar value instead of per-block - // zp uses the same element type as weights (U4 or U8) - ov::Shape zp_shape = layout.is_symmetric ? 
ov::Shape{} : scale_shape; - - ov::Tensor weights, scales, zp; + // F16 requant path - no separate scales/zp needed in result + if (layout.is_requant && layout.requant_type.has_value() && layout.requant_type.value() == ExtraQuantType::F16) { if (output_base_ptr) { - uint8_t * buf_base = static_cast(output_base_ptr); - weights = ov::Tensor(weight_type, node_shape, buf_base + layout.weights_offset); - scales = ov::Tensor(ov::element::f16, scale_shape, buf_base + layout.scales_offset); - zp = ov::Tensor(weight_type, zp_shape, buf_base + layout.zp_offset); + result.weights = ov::Tensor(ov::element::f16, node_shape, + static_cast(output_base_ptr) + layout.weights_offset); } else { - weights = ov::Tensor(weight_type, node_shape); - scales = ov::Tensor(ov::element::f16, scale_shape); - zp = ov::Tensor(weight_type, zp_shape); + result.weights = ov::Tensor(ov::element::f16, node_shape); } + ov::Tensor dummy_scales, dummy_zp; // Not used for F16 + result.weight_node = + requantize_to_buffers(tensor, data, ExtraQuantType::F16, 0, result.weights, dummy_scales, dummy_zp); + return result; + } - result = extract_quantized_weights(tensor, data, weights, scales, zp); + // Quantized path (normal extraction or quantized requant) + // Create weight/scale/zp tensors - shared between both paths + ov::element::Type weight_type = layout.is_u4 ? ov::element::u4 : ov::element::u8; + ov::Shape scale_shape = {node_shape[0], node_shape[1] / layout.weights_per_block}; + ov::Shape zp_shape = layout.is_symmetric ? 
ov::Shape{} : scale_shape; + + if (output_base_ptr) { + uint8_t * buf_base = static_cast(output_base_ptr); + result.weights = ov::Tensor(weight_type, node_shape, buf_base + layout.weights_offset); + result.scales = ov::Tensor(ov::element::f16, scale_shape, buf_base + layout.scales_offset); + result.zp = ov::Tensor(weight_type, zp_shape, buf_base + layout.zp_offset); + } else { + result.weights = ov::Tensor(weight_type, node_shape); + result.scales = ov::Tensor(ov::element::f16, scale_shape); + result.zp = ov::Tensor(weight_type, zp_shape); + } + + if (layout.is_requant && layout.requant_type.has_value()) { + result.weight_node = requantize_to_buffers(tensor, data, layout.requant_type.value(), layout.weights_per_block, + result.weights, result.scales, result.zp); + } else { + result.weight_node = extract_quantized_weights(tensor, data, result.weights, result.scales, result.zp); } return result; diff --git a/ggml/src/ggml-openvino/ggml-quants.hpp b/ggml/src/ggml-openvino/ggml-quants.hpp index 67396892642..600b9c9f299 100644 --- a/ggml/src/ggml-openvino/ggml-quants.hpp +++ b/ggml/src/ggml-openvino/ggml-quants.hpp @@ -74,12 +74,43 @@ std::shared_ptr requantize_to_buffers(const ggml_tensor * tensor, ov::Tensor & scales, ov::Tensor & zp); -// Process weight tensor and create an OpenVINO constant node +inline const char * extra_quant_type_name(ExtraQuantType t) { + switch (t) { + case ExtraQuantType::F16: + return "F16"; + case ExtraQuantType::Q4_0_C: + return "Q4_0_C"; + case ExtraQuantType::Q4_0_128: + return "Q4_0_128"; + case ExtraQuantType::Q8_0_C: + return "Q8_0_C"; + case ExtraQuantType::Q8_0_32: + return "Q8_0_32"; + case ExtraQuantType::Q8_1_C: + return "Q8_1_C"; + default: + return "unknown"; + } +} + +// Result from process_weight_tensor containing the weight node and tensors. +// For quantized weights, also contains the extracted layout and scale/zp tensors. 
+struct OvWeight { + std::shared_ptr weight_node; + ggml_openvino_extracted_layout layout; // Only meaningful for quantized (layout.total_size > 0) + ov::Tensor weights; + ov::Tensor scales; + ov::Tensor zp; + + bool is_quantized() const { return layout.scales_size > 0; } +}; + +// Process weight tensor and create an OpenVINO weight node // Handles F16/F32/BF16 and quantized weights, with optional requantization // If output_base_ptr is nullptr, allocates internal buffers (for decoder use) // If output_base_ptr is provided, uses pre-allocated buffers at specified offsets (for backend buffer use) -// Returns the weight constant node -std::shared_ptr process_weight_tensor( +// Returns OvWeight with the weight node and optional quantized tensors +OvWeight process_weight_tensor( const ggml_tensor * tensor, const void * data, // Source data pointer (may differ from tensor->data) void * output_base_ptr = nullptr); // Base pointer for output buffers (or nullptr for internal allocation) From f60ee79578426aff93049e60523b2db7262011bc Mon Sep 17 00:00:00 2001 From: Xuejun Zhai Date: Sun, 8 Feb 2026 21:15:41 -0800 Subject: [PATCH 11/11] Add OV fallback to CPU mechanism & verify with OV without ROPE support --- ggml/include/ggml.h | 4 +- ggml/src/ggml-backend.cpp | 1 + ggml/src/ggml-openvino/ggml-decoder.cpp | 240 +++++++++++++++++- ggml/src/ggml-openvino/ggml-decoder.h | 17 +- ggml/src/ggml-openvino/openvino/decoder.hpp | 2 + .../openvino/translate_session.cpp | 46 ++-- ggml/src/ggml-openvino/utils.cpp | 75 +++++- 7 files changed, 356 insertions(+), 29 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index b69583dd3fd..73c29b6243e 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -682,7 +682,9 @@ extern "C" { void * extra; // extra things e.g. 
for ggml-cuda.cu - char padding[8]; + char padding[16]; + // original source tensor (NULL unless set); assigned by the backend scheduler when this tensor is a copy of another, so the original can be traced for in-place operations + struct ggml_tensor * org_src; }; static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor); diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index 1b59924b8cb..745625af700 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -1267,6 +1267,7 @@ void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgra ggml_set_input(tensor_copy); ggml_set_output(tensor_copy); // prevent ggml-alloc from overwriting the tensor } + tensor_copy->org_src = src; tensor_id_copy(src_id, cur_backend_id, c) = tensor_copy; SET_CAUSE(tensor_copy, "4.cpy"); } diff --git a/ggml/src/ggml-openvino/ggml-decoder.cpp b/ggml/src/ggml-openvino/ggml-decoder.cpp index da381e4fad7..6ce8c17db87 100644 --- a/ggml/src/ggml-openvino/ggml-decoder.cpp +++ b/ggml/src/ggml-openvino/ggml-decoder.cpp @@ -71,6 +71,13 @@ GgmlOvDecoder::GgmlOvDecoder(ggml_cgraph * cgraph, set_input_output(cur_node); } + m_is_full_model = has_inp_tokens && has_output; + if (!m_is_full_model) { + compute_cgraph_dynamic_dims(); + add_extra_model_inputs_for_fallback(); + add_extra_model_outputs_for_fallback(); + } + for (int node_n = 0; node_n < cgraph->n_nodes; node_n++) { m_node_info_list[node_n].node_op_case = compute_op_case(m_node_info_list[node_n].node); m_node_info_list[node_n].node_op_type = compute_op_type(m_node_info_list[node_n].node); @@ -164,6 +171,10 @@ void GgmlOvDecoder::set_input_output(ggml_tensor * node, bool naive) { current_node_info.node_inputs[src_name] = src; current_node_info.node_inputs_names.push_back(src_name); + if (is_inp_tok(src, node)) { + has_inp_tokens = true; + } + // Add model inputs if (!naive && !src->view_src) { ggml_backend_buffer * buffer = src->buffer; @@ -206,6 +217,9 @@ void GgmlOvDecoder::set_input_output(ggml_tensor * node, 
bool naive) { if (!naive) { // Model outputs are tensors with GGML_TENSOR_FLAG_OUTPUT flag and kv_caches static std::set debug_output_names = {}; + if (node->flags & GGML_TENSOR_FLAG_OUTPUT) { + has_output = true; + } // Workaround: the final tensor "result_output" does not have GGML_TENSOR_FLAG_OUTPUT flag set in cgraph if (node->op == GGML_OP_SET_ROWS || node->flags & GGML_TENSOR_FLAG_OUTPUT || debug_output_names.count(node_output_name)) { @@ -294,6 +308,9 @@ int GgmlOvDecoder::compute_op_case(const ggml_tensor * node) const { throw std::runtime_error("Unsupported VIEW case"); } op_case = 2; + if (!m_is_full_model && m_model_inputs.find(std::string(src->name)) != m_model_inputs.end()) { + op_case = 0; + } } break; } @@ -390,7 +407,7 @@ void GgmlOvDecoder::validate_cgraph() const { } } -ov::PartialShape GgmlOvDecoder::get_graph_input_shape(const ggml_tensor * op, const ggml_tensor * input) const { +ov::PartialShape GgmlOvDecoder::get_graph_input_shape(const ggml_tensor * op, const ggml_tensor * input, int dynamic_dim_index) const { auto name = std::string(input->name); ov::PartialShape input_shape; @@ -429,6 +446,9 @@ ov::PartialShape GgmlOvDecoder::get_graph_input_shape(const ggml_tensor * op, co } else { input_shape = ov::PartialShape{get_shape(input)}; } + if (dynamic_dim_index != -1) { + input_shape[3-dynamic_dim_index] = -1; + } return input_shape; } @@ -906,3 +926,221 @@ const std::string & GgmlOvDecoder::get_op_type() const { static const std::string unknown_op = "UNKNOWN_GGML_OP"; return unknown_op; } + +/** + * @brief Computes the dynamic dimensions for the computation graph nodes to support fallback mechanisms. + * + * This function traverses the computation graph and determines the dynamic dimensions + * for each node based on its operation type and dependencies. The dynamic dimension + * is stored in the `m_node_dynamic_dims` map, where a value of -1 indicates no dynamic + * dimension. 
Specific operations such as GGML_OP_GET_ROWS, GGML_OP_MUL, GGML_OP_VIEW, + * etc., are handled to compute the dynamic dimension index. + * + * Key behaviors: + * - Nodes with operations like GGML_OP_NONE, GGML_OP_GET_ROWS, GGML_OP_MUL, and others + * are analyzed to determine their dynamic dimensions. + * - Nodes with specific names (e.g., "inp_tokens", "inp_pos", "inp_out_ids") are + * explicitly assigned a dynamic dimension index of 0. + * - For operations like GGML_OP_VIEW and GGML_OP_RESHAPE, the function ensures that + * the dynamic dimension is uniquely determined; otherwise, a warning is printed. + * - Unhandled operations print a message indicating the node name and operation type. + * + * This function is critical for preparing the computation graph for execution, ensuring + * that dynamic dimensions are correctly propagated and resolved. + */ +void GgmlOvDecoder::compute_cgraph_dynamic_dims() { + auto visit_node = [&](auto && self, ggml_tensor * node) -> void { + if (!node) { + return; + } + + if (node->op == GGML_OP_CPY) { + m_node_dynamic_dims[node] = -1; + } + + if (m_node_dynamic_dims.count(node)) { + return; + } + for (int i = 0; i < GGML_MAX_SRC; i++) { + ggml_tensor * src = node->src[i]; + if (src == nullptr) { + continue; + } + if (src->org_src) { + if (is_inp_tok(src->org_src, node) || is_inp_pos(src->org_src, node) || is_output_idx(src->org_src, node)) { + m_node_dynamic_dims[src->org_src] = 0; + m_node_dynamic_dims[src] = m_node_dynamic_dims[src->org_src]; + continue; + } + self(self, src->org_src); + m_node_dynamic_dims[src] = m_node_dynamic_dims[src->org_src]; + } else { + if (is_inp_tok(src, node) || is_inp_pos(src, node) || is_output_idx(src, node)) { + m_node_dynamic_dims[src] = 0; + continue; + } + self(self, src); + } + } + switch (node->op) { + case GGML_OP_NONE: + m_node_dynamic_dims[node] = -1; + // if (std::string(node->name) == "inp_tokens" || std::string(node->name) == "inp_pos" || + // std::string(node->name) == "inp_out_ids") { + 
// m_node_dynamic_dims[node] = 0; + // } + break; + case GGML_OP_GET_ROWS: + m_node_dynamic_dims[node] = -1; + if (m_node_dynamic_dims[node->src[1]] != -1) { + m_node_dynamic_dims[node] = 1; + } + break; + case GGML_OP_MUL: + case GGML_OP_MUL_MAT: + m_node_dynamic_dims[node] = -1; + if (m_node_dynamic_dims[node->src[0]] != -1) { + m_node_dynamic_dims[node] = m_node_dynamic_dims[node->src[0]]; + } + if (m_node_dynamic_dims[node->src[1]] != -1) { + m_node_dynamic_dims[node] = m_node_dynamic_dims[node->src[1]]; + } + break; + case GGML_OP_VIEW: + case GGML_OP_FLASH_ATTN_EXT: + case GGML_OP_PERMUTE: + case GGML_OP_RESHAPE: + m_node_dynamic_dims[node] = -1; + if (m_node_dynamic_dims[node->src[0]] != -1) { + auto dynamic_dim_idx = m_node_dynamic_dims[node->src[0]]; + auto dynamic_dim_value = node->src[0]->ne[dynamic_dim_idx]; + int same_dim_count = 0; + for (int i = 0; i < 4; i++) { + if (node->ne[i] == dynamic_dim_value) { + m_node_dynamic_dims[node] = i; + same_dim_count++; + } + } + if (same_dim_count != 1) { + std::cout << "Cannot determine dynamic dim for node: " << node->name << std::endl; + } + } + break; + case GGML_OP_RMS_NORM: + case GGML_OP_ADD: + case GGML_OP_GLU: + case GGML_OP_ROPE: + case GGML_OP_SCALE: + m_node_dynamic_dims[node] = m_node_dynamic_dims[node->src[0]]; + break; + case GGML_OP_CPY: + case GGML_OP_SET_ROWS: + m_node_dynamic_dims[node] = -1; + break; + default: + std::cout << "Doesn't handle node name: " << node->name << " op: " << ggml_op_name(node->op) << std::endl; + break; + } + }; + + for (int i = 0; i < m_cgraph->n_nodes; i++) { + ggml_tensor * node = m_cgraph->nodes[i]; + visit_node(visit_node, node); + } +} + +/** + * @brief Adds extra model outputs to support fallback mechanisms. + * + * This function ensures that all relevant nodes in the computation graph are included + * as model outputs for fallback scenarios. 
It creates a mapping of tensor data addresses + * to their corresponding nodes, excluding nodes with the GGML_OP_VIEW operation. + * + * Key behaviors: + * - Iterates through all nodes in the computation graph and maps their data addresses + * to the corresponding tensor nodes, skipping nodes with GGML_OP_VIEW. + * - Adds nodes to the `m_model_outputs` map if they are not already present, using + * the tensor's name as the key. + * + * This function is essential for ensuring that fallback mechanisms have access to all + * necessary model outputs, particularly in scenarios where certain outputs are not + * explicitly defined in the original model configuration. + */ +void GgmlOvDecoder::add_extra_model_outputs_for_fallback() { + std::map address_map; + for (int i = 0; i < m_cgraph->n_nodes; i++) { + ggml_tensor * node = m_cgraph->nodes[i]; + + if (node->flags & GGML_TENSOR_FLAG_OUTPUT) { + return; + } + + if (node->op == GGML_OP_VIEW) { + continue; + } + address_map[node->data] = node; + } + + for (const auto & pair : address_map) { + const std::string & name = pair.second->name; + if (m_model_outputs.find(name) == m_model_outputs.end()) { + m_model_outputs[name] = pair.second; + } + } +} + +/** +* @brief Adds extra model inputs to support fallback mechanisms. +* +* This function ensures that all necessary input nodes in the computation graph are +* included as model inputs for fallback scenarios. It iterates through the source nodes +* of each computation graph node and adds them to the `m_model_inputs` map if they meet +* specific criteria. +* +* Key behaviors: +* - Skips source nodes that are already present in `m_model_weights` or `m_model_inputs`. +* - Excludes intermediate nodes that are part of `m_node_info_list`. +* - For eligible source nodes, creates OpenVINO parameter nodes with appropriate types +* and shapes, and assigns them friendly names. +* - Updates the `m_inputs` and `m_model_inputs` maps with the new parameter nodes. 
+* +* This function is critical for ensuring that fallback mechanisms have access to all +* required model inputs, particularly in scenarios where certain inputs are not +* explicitly defined in the original model configuration. +*/ +void GgmlOvDecoder::add_extra_model_inputs_for_fallback() { + for (int i = 0; i < m_cgraph->n_nodes; i++) { + ggml_tensor * node = m_cgraph->nodes[i]; + for (int i = 0; i < GGML_MAX_SRC; i++) { + auto * src = node->src[i]; + if (src == nullptr) { + continue; + } + std::string src_name = std::string(src->name); + if (m_model_weights.find(src_name) != m_model_weights.end()) { + continue; + } + + bool is_intermediate_node = false; + for (const auto & node_info : m_node_info_list) { + if (node_info.node == src) { + is_intermediate_node = true; + break; + } + } + if (is_intermediate_node) { + continue; + } + if (m_model_inputs.find(src_name) != m_model_inputs.end()) { + continue; + } + + m_inputs[src_name] = src; + auto param_node = std::make_shared( + get_ov_type(src), get_graph_input_shape(node, src, m_node_dynamic_dims[src])); + param_node->set_friendly_name(src_name); + param_node->output(0).get_tensor().set_names({src_name}); + m_model_inputs[src_name] = param_node; + } + } +} diff --git a/ggml/src/ggml-openvino/ggml-decoder.h b/ggml/src/ggml-openvino/ggml-decoder.h index c8e3edeaf89..6be7239a9c8 100644 --- a/ggml/src/ggml-openvino/ggml-decoder.h +++ b/ggml/src/ggml-openvino/ggml-decoder.h @@ -180,7 +180,7 @@ class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder { virtual bool is_stateful() const override { return m_is_stateful; } - ov::PartialShape get_graph_input_shape(const ggml_tensor * op, const ggml_tensor * input) const; + ov::PartialShape get_graph_input_shape(const ggml_tensor * op, const ggml_tensor * input, int dynamic_dim_index=-1) const; static void dump_cgraph(const ggml_cgraph * cgraph, std::string & filename); @@ -204,9 +204,12 @@ class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder { void 
set_compute_params(const ComputeParams & compute_params) { m_compute_params = compute_params; } + virtual bool is_full_model() const override {return m_is_full_model; } + bool m_is_static = false; bool m_is_stateful = false; bool m_is_prefill = false; + bool m_is_full_model = true; // label whether the cgraph is split or not int m_prefill_chunk_size = 0; static ov::Shape get_shape(const ggml_tensor * tensor); @@ -249,6 +252,13 @@ class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder { void set_input_output(ggml_tensor * node, bool naive = false); int compute_op_case(const ggml_tensor * node) const; + // @brief Computes the dynamic dimensions for the computation graph nodes to support fallback mechanisms. + void compute_cgraph_dynamic_dims(); + // @brief Adds extra model outputs to support fallback mechanisms. + void add_extra_model_outputs_for_fallback(); + // @brief Adds extra model inputs to support fallback mechanisms. + void add_extra_model_inputs_for_fallback(); + void validate_cgraph() const; ggml_cgraph * m_cgraph = nullptr; @@ -261,6 +271,11 @@ class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder { std::map m_model_outputs; std::vector m_node_info_list; + std::map + m_node_dynamic_dims; // map from ggml_tensor to its dynamic dimension index, -1 means static + bool has_inp_tokens = false; + bool has_output = false; + ModelParams m_model_params; ComputeParams m_compute_params; }; diff --git a/ggml/src/ggml-openvino/openvino/decoder.hpp b/ggml/src/ggml-openvino/openvino/decoder.hpp index 3b8da2be5d2..1fe4ea6c811 100644 --- a/ggml/src/ggml-openvino/openvino/decoder.hpp +++ b/ggml/src/ggml-openvino/openvino/decoder.hpp @@ -53,6 +53,8 @@ class GgmlDecoder : public DecoderBase { virtual int get_op_case(int node_idx) const = 0; + virtual bool is_full_model() const = 0; + virtual const std::map>& get_model_inputs() const = 0; virtual const std::map>& get_model_extra_inputs() const = 0; virtual const std::map>& get_model_weights() const = 0; diff --git
a/ggml/src/ggml-openvino/openvino/translate_session.cpp b/ggml/src/ggml-openvino/openvino/translate_session.cpp index b7e7b58531f..e0ffc7bedb6 100644 --- a/ggml/src/ggml-openvino/openvino/translate_session.cpp +++ b/ggml/src/ggml-openvino/openvino/translate_session.cpp @@ -77,26 +77,25 @@ ov::pass::MakeStateful::ParamResPairs get_kv_param_res_pairs( } void add_sliced_mask(TensorMap & tensor_map, GgmlDecoder & ggml_model_decoder) { - auto token_len_per_seq = tensor_map.at("token_len_per_seq").get_node_shared_ptr(); - auto create_sliced_mask = [&](const std::string & mask_name, const std::string & sliced_name, bool is_static) { - if (tensor_map.find(mask_name) != tensor_map.end()) { + if ((tensor_map.find(mask_name) != tensor_map.end()) && (tensor_map.find("token_len_per_seq") != tensor_map.end())){ + auto token_len_per_seq = tensor_map.at("token_len_per_seq").get_node_shared_ptr(); auto mask = tensor_map.at(mask_name).get_node_shared_ptr(); std::shared_ptr mask_sliced; if (is_static) { mask_sliced = mask; } else if (ggml_model_decoder.is_stateful()) { - auto zero_2d = ov::op::v0::Constant::create(ov::element::i64, {2}, {0,0}); - auto one_2d = ov::op::v0::Constant::create(ov::element::i64, {2}, {1,1}); + auto zero_2d = ov::op::v0::Constant::create(ov::element::i64, {2}, {0, 0}); + auto one_2d = ov::op::v0::Constant::create(ov::element::i64, {2}, {1, 1}); auto zero_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {0}); auto two_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {2}); - auto axes = ov::op::v0::Constant::create(ov::element::i64, {2}, {-2,-1}); + auto axes = ov::op::v0::Constant::create(ov::element::i64, {2}, {-2, -1}); auto inp_pos = tensor_map.at("inp_pos").get_node_shared_ptr(); auto shape_of_inp_pos = std::make_shared(inp_pos); auto gather_inp_pos = std::make_shared(shape_of_inp_pos, two_1d, zero_1d); - auto stop = std::make_shared(ov::OutputVector{token_len_per_seq, gather_inp_pos}, 0); - mask_sliced = - std::make_shared(mask, zero_2d, 
stop, one_2d, axes); + auto stop = + std::make_shared(ov::OutputVector{token_len_per_seq, gather_inp_pos}, 0); + mask_sliced = std::make_shared(mask, zero_2d, stop, one_2d, axes); mask_sliced = std::make_shared(mask_sliced, ov::element::f16); mask_sliced->set_friendly_name(sliced_name); } else { @@ -116,21 +115,24 @@ void add_sliced_mask(TensorMap & tensor_map, GgmlDecoder & ggml_model_decoder) { } void add_rope_sin_cos(TensorMap & tensor_map, GgmlDecoder & ggml_model_decoder) { - int32_t * rope_params = ggml_model_decoder.get_rope_params(); - auto inp_pos = tensor_map.at("inp_pos").get_node_shared_ptr(); - std::shared_ptr rope_freqs_weight; - if (tensor_map.find("rope_freqs_weight") != tensor_map.end()) { - rope_freqs_weight = tensor_map.at("rope_freqs.weight").get_node_shared_ptr(); - } + if ((tensor_map.find("rope_freqs_weight") != tensor_map.end()) && + (tensor_map.find("inp_pos") != tensor_map.end())) { + int32_t * rope_params = ggml_model_decoder.get_rope_params(); + auto inp_pos = tensor_map.at("inp_pos").get_node_shared_ptr(); + std::shared_ptr rope_freqs_weight; + if (tensor_map.find("rope_freqs_weight") != tensor_map.end()) { + rope_freqs_weight = tensor_map.at("rope_freqs.weight").get_node_shared_ptr(); + } - auto sin_cos = make_sin_cos(rope_params, inp_pos, rope_freqs_weight); - auto sin_theta = sin_cos.first; - auto cos_theta = sin_cos.second; + auto sin_cos = make_sin_cos(rope_params, inp_pos, rope_freqs_weight); + auto sin_theta = sin_cos.first; + auto cos_theta = sin_cos.second; - cos_theta.get_node_shared_ptr()->set_friendly_name("rope_cos"); - sin_theta.get_node_shared_ptr()->set_friendly_name("rope_sin"); - tensor_map.insert({"rope_cos", cos_theta}); - tensor_map.insert({"rope_sin", sin_theta}); + cos_theta.get_node_shared_ptr()->set_friendly_name("rope_cos"); + sin_theta.get_node_shared_ptr()->set_friendly_name("rope_sin"); + tensor_map.insert({"rope_cos", cos_theta}); + tensor_map.insert({"rope_sin", sin_theta}); + } } // Create common patterns 
diff --git a/ggml/src/ggml-openvino/utils.cpp b/ggml/src/ggml-openvino/utils.cpp index 69cac19019c..9cb22bad6b9 100644 --- a/ggml/src/ggml-openvino/utils.cpp +++ b/ggml/src/ggml-openvino/utils.cpp @@ -4,6 +4,7 @@ #include "ggml-openvino-extra.h" #include "ggml-openvino/ggml-decoder.h" #include "ggml.h" +#include "ggml-cpu.h" #include "openvino/frontend.hpp" #include "openvino/input_model.hpp" @@ -94,7 +95,9 @@ enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, const std::strin if (cache_hit) { ggml_decoder = it->second; old_m_params = ggml_decoder->get_model_params(); - cache_hit = old_m_params.can_reuse_dynamically(m_params); + if (ggml_decoder->is_full_model()) { + cache_hit = old_m_params.can_reuse_dynamically(m_params); + } } if (cache_hit) { @@ -421,7 +424,7 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph) { } bool is_naive(ggml_cgraph * cgraph) { - constexpr int naive_graph_size_threshold = 20; + constexpr int naive_graph_size_threshold = 0; return cgraph->n_nodes < naive_graph_size_threshold; } @@ -475,7 +478,7 @@ namespace { ov::Tensor convert_ggml_input_to_ov(std::shared_ptr ggml_decoder, const std::string & name) { const auto * ggml_tensor = ggml_decoder->get_input_ggml_tensor(name); - if (ggml_tensor->extra != nullptr) { + if (ggml_tensor->extra != nullptr && ggml_decoder->is_full_model()) { // GGML_LOG_DEBUG("Using ggml_tensor->extra as ov::Tensor for input: %s\n", name.c_str()); auto * extra_base = static_cast(ggml_tensor->extra); if (extra_base->type != ggml_openvino_extra_base::Type::TENSOR) { @@ -488,12 +491,76 @@ ov::Tensor convert_ggml_input_to_ov(std::shared_ptr ggml_decoder, // GGML_LOG_DEBUG("Converting ggml tensor to ov::Tensor for input: %s\n", name.c_str()); auto * input_data = ggml_tensor->data; ov::Shape input_shape; - if (ggml_tensor->op == GGML_OP_VIEW) { + if (0) { // This case is added to make test-backend-ops work input_shape = ggml_decoder->get_shape(ggml_tensor->view_src); } else { input_shape = 
ggml_decoder->get_shape(ggml_tensor); } + + // If the tensor is a result of PERMUTE operation, use ggml_cont to make it contiguous + if (ggml_tensor->op == GGML_OP_PERMUTE && !ggml_decoder->is_full_model()) { + // Create a temporary context for ggml_cont operation + // Need space for: tensor overhead, tensor data, graph structure, and work buffer + size_t mem_size = ggml_tensor_overhead() * 4 + ggml_nbytes(ggml_tensor) * 2 + 1024 * 1024; + struct ggml_init_params params = { + /*.mem_size =*/mem_size, + /*.mem_buffer =*/NULL, + /*.no_alloc =*/false, + }; + struct ggml_context * temp_ctx = ggml_init(params); + if (temp_ctx == NULL) { + throw std::runtime_error("Failed to initialize temporary context for PERMUTE"); + } + + // Create contiguous tensor using ggml_cont + struct ggml_tensor * cont_tensor = ggml_cont(temp_ctx, const_cast(ggml_tensor)); + + // Build a simple graph to compute ggml_cont + struct ggml_cgraph * gf = ggml_new_graph(temp_ctx); + ggml_build_forward_expand(gf, cont_tensor); + ggml_graph_compute_with_ctx(temp_ctx, gf, 1); + + // Create OpenVINO tensor with contiguous data + ov::Tensor input_tensor(ggml_decoder->get_ov_type(ggml_tensor), input_shape); + memcpy(input_tensor.data(), cont_tensor->data, ggml_nbytes(cont_tensor)); + + // Free temporary context + ggml_free(temp_ctx); + + return input_tensor; + } + + // If the tensor is a result of VIEW operation, use ggml_cont to make it contiguous + if (ggml_tensor->op == GGML_OP_VIEW && !ggml_decoder->is_full_model()) { + // if the ggml_tensor shape size is equal to the source tensor shape size, no need to reconstruct the ov input tensor data + if (ggml_nelements(ggml_tensor) == ggml_nelements(ggml_tensor->view_src)) { + auto input_tensor = ov::Tensor(ggml_decoder->get_ov_type(ggml_tensor), input_shape, input_data); + return input_tensor; + } + + // Create OpenVINO input tensor, the data need to reconstructed based on the view tensor shape & stride + // Todo: parallel copy & the copy the whole last dim 
one loop (perf improve) + ov::Tensor input_tensor(ggml_decoder->get_ov_type(ggml_tensor), input_shape); + const auto * src_tensor = ggml_tensor->view_src; + size_t des_index = 0; + for (size_t i0 = 0; i0 < static_cast(ggml_tensor->ne[3]); i0++) { + for (size_t i1 = 0; i1 < static_cast(ggml_tensor->ne[2]); i1++) { + for (size_t i2 = 0; i2 < static_cast(ggml_tensor->ne[1]); i2++) { + for (size_t i3 = 0; i3 < static_cast(ggml_tensor->ne[0]); i3++) { + size_t src_index = ggml_tensor->view_offs + i0 * ggml_tensor->nb[3] + i1 * ggml_tensor->nb[2] + + i2 * ggml_tensor->nb[1] + i3 * ggml_tensor->nb[0]; + + memcpy(static_cast(input_tensor.data()) + des_index, + static_cast(src_tensor->data) + src_index, ggml_tensor->nb[0]); + des_index += ggml_tensor->nb[0]; + } + } + } + } + return input_tensor; + } + auto input_tensor = ov::Tensor(ggml_decoder->get_ov_type(ggml_tensor), input_shape, input_data); return input_tensor; }