From d21ef679b5fb9d1559976697277afbae3c1aef73 Mon Sep 17 00:00:00 2001 From: Juan David Garcia Date: Tue, 3 Mar 2026 17:03:32 -0500 Subject: [PATCH 1/4] feat: update llama.cpp submodule and bindings for Qwen 3.5 support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Updates the llama.cpp submodule to da348c9df which includes support for the Qwen 3.5 model architecture (hybrid SSM + attention). Changes to Python bindings: 1. llama_cpp.py: Sync llama_context_params struct with upstream C API - flash_attn (bool) → flash_attn_type (enum llama_flash_attn_type) - Add samplers (void*) and n_samplers (size_t) fields - Add LLAMA_FLASH_ATTN_TYPE_* enum constants 2. llama.py: Update flash_attn parameter handling - Map flash_attn=True/False to flash_attn_type=1/0 3. _ctypes_extensions.py: Graceful handling of deprecated symbols - ctypes_function decorator returns stub instead of crashing when a symbol is not found in the shared library Tested with Qwen3.5-0.8B-Q4_K_M.gguf on Apple Silicon (M1 Pro): - Cold start: ~4s (vs ~40s with mlx-vlm) - Inference: ~0.6s per chat completion - Model loads and runs correctly on Metal GPU --- llama_cpp/_ctypes_extensions.py | 21 ++++++++++++++----- llama_cpp/llama.py | 11 ++++++++-- llama_cpp/llama_cpp.py | 36 ++++++++++++++++++++++++--------- vendor/llama.cpp | 2 +- 4 files changed, 52 insertions(+), 18 deletions(-) diff --git a/llama_cpp/_ctypes_extensions.py b/llama_cpp/_ctypes_extensions.py index e88ed387d..2700466fa 100644 --- a/llama_cpp/_ctypes_extensions.py +++ b/llama_cpp/_ctypes_extensions.py @@ -110,11 +110,22 @@ def ctypes_function( ): def decorator(f: F) -> F: if enabled: - func = getattr(lib, name) - func.argtypes = argtypes - func.restype = restype - functools.wraps(f)(func) - return func + try: + func = getattr(lib, name) + func.argtypes = argtypes + func.restype = restype + functools.wraps(f)(func) + return func + except AttributeError: + # Symbol not found in shared library (deprecated/removed) + @functools.wraps(f) + def stub(*args: Any, **kwargs: Any) -> Any: + raise NotImplementedError( + f"Symbol '{name}' not found in shared library. The C API might " + "have been removed or deprecated." + ) + + return stub # type: ignore else: return f diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 71d94ebd8..44005bb6b 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -341,7 +341,11 @@ def __init__( self._logits_all = logits_all if draft_model is None else True self.context_params.embeddings = embedding # TODO: Rename to embeddings self.context_params.offload_kqv = offload_kqv - self.context_params.flash_attn = flash_attn + self.context_params.flash_attn_type = ( + llama_cpp.LLAMA_FLASH_ATTN_TYPE_ENABLED + if flash_attn + else llama_cpp.LLAMA_FLASH_ATTN_TYPE_DISABLED + ) if op_offload is not None: self.context_params.op_offload = op_offload @@ -2096,7 +2100,10 @@ def __getstate__(self): logits_all=self._logits_all, embedding=self.context_params.embeddings, offload_kqv=self.context_params.offload_kqv, - flash_attn=self.context_params.flash_attn, + flash_attn=( + self.context_params.flash_attn_type + == llama_cpp.LLAMA_FLASH_ATTN_TYPE_ENABLED + ), op_offload=self.context_params.op_offload, swa_full=self.context_params.swa_full, # Sampling Params diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py index 711d42a6a..a306c313f 100644 --- a/llama_cpp/llama_cpp.py +++ b/llama_cpp/llama_cpp.py @@ -463,6 +463,16 @@ LLAMA_ATTENTION_TYPE_NON_CAUSAL = 1 +# enum llama_flash_attn_type { +# LLAMA_FLASH_ATTN_TYPE_AUTO = -1, +# LLAMA_FLASH_ATTN_TYPE_DISABLED = 0, +# LLAMA_FLASH_ATTN_TYPE_ENABLED = 1, +# }; +LLAMA_FLASH_ATTN_TYPE_AUTO = -1 +LLAMA_FLASH_ATTN_TYPE_DISABLED = 0 +LLAMA_FLASH_ATTN_TYPE_ENABLED = 1 + + # enum llama_split_mode { # LLAMA_SPLIT_MODE_NONE = 0, // single GPU # LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs @@ -761,6 +771,7 @@ class llama_model_params(ctypes.Structure): # enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type` # enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id # enum llama_attention_type attention_type; // attention type to use for embeddings +# enum llama_flash_attn_type flash_attn_type; // when to enable Flash Attention # // ref: https://github.com/ggml-org/llama.cpp/pull/2054 # float rope_freq_base; // RoPE base frequency, 0 = from model @@ -770,7 +781,7 @@ class llama_model_params(ctypes.Structure): # float yarn_beta_fast; // YaRN low correction dim # float yarn_beta_slow; // YaRN high correction dim # uint32_t yarn_orig_ctx; // YaRN original context size -# float defrag_thold; // defragment the KV cache if holes/size > thold, <= 0 disabled (default) +# float defrag_thold; // [DEPRECATED] defragment the KV cache if holes/size > thold, <= 0 disabled (default) # ggml_backend_sched_eval_callback cb_eval; # void * cb_eval_user_data; @@ -787,15 +798,14 @@ class llama_model_params(ctypes.Structure): # // Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value. # bool embeddings; // if true, extract embeddings (together with logits) # bool offload_kqv; // offload the KQV ops (including the KV cache) to GPU -# bool flash_attn; // use flash attention [EXPERIMENTAL] # bool no_perf; // measure performance timings # bool op_offload; // offload host tensor operations to device -# bool swa_full; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055) -# // NOTE: setting to false when n_seq_max > 1 can cause bad performance in some cases -# // ref: https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573 +# bool swa_full; // use full-size SWA cache # bool kv_unified; // use a unified buffer across the input sequences when computing the attention -# // try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix -# // ref: https://github.com/ggml-org/llama.cpp/pull/14363 + +# // [EXPERIMENTAL] +# struct llama_sampler_seq_config * samplers; +# size_t n_samplers; # }; class llama_context_params(ctypes.Structure): """Parameters for llama_context @@ -810,6 +820,7 @@ class llama_context_params(ctypes.Structure): rope_scaling_type (int): RoPE scaling type, from `enum llama_rope_scaling_type` pooling_type (int): whether to pool (sum) embedding results by sequence id (ignored if no pooling layer) attention_type (int): attention type to use for embeddings + flash_attn_type (int): when to enable Flash Attention, from `enum llama_flash_attn_type` rope_freq_base (float): RoPE base frequency, 0 = from model rope_freq_scale (float): RoPE frequency scaling factor, 0 = from model yarn_ext_factor (float): YaRN extrapolation mix factor, negative = from model @@ -826,11 +837,12 @@ class llama_context_params(ctypes.Structure): abort_callback_data (ctypes.ctypes.c_void_p): data for abort_callback embeddings (bool): if true, extract embeddings (together with logits) offload_kqv (bool): whether to offload the KQV ops (including the KV cache) to GPU - flash_attn (bool): whether to use flash attention no_perf (bool): whether to measure performance timings op_offload (bool): offload host tensor operations to device swa_full (bool): use full-size SWA cache kv_unified (bool): use a unified buffer across the input sequences when computing the attention + samplers (ctypes.c_void_p): backend sampler chain configuration [EXPERIMENTAL] + n_samplers (ctypes.c_size_t): number of backend sampler chains """ if TYPE_CHECKING: @@ -843,6 +855,7 @@ class llama_context_params(ctypes.Structure): rope_scaling_type: int pooling_type: int attention_type: int + flash_attn_type: int rope_freq_base: float rope_freq_scale: float yarn_ext_factor: float @@ -859,11 +872,12 @@ class llama_context_params(ctypes.Structure): abort_callback_data: ctypes.c_void_p embeddings: bool offload_kqv: bool - flash_attn: bool no_perf: bool op_offload: bool swa_full: bool kv_unified: bool + samplers: ctypes.c_void_p + n_samplers: ctypes.c_size_t _fields_ = [ ("n_ctx", ctypes.c_uint32), @@ -875,6 +889,7 @@ class llama_context_params(ctypes.Structure): ("rope_scaling_type", ctypes.c_int), ("pooling_type", ctypes.c_int), ("attention_type", ctypes.c_int), + ("flash_attn_type", ctypes.c_int), ("rope_freq_base", ctypes.c_float), ("rope_freq_scale", ctypes.c_float), ("yarn_ext_factor", ctypes.c_float), @@ -891,11 +906,12 @@ class llama_context_params(ctypes.Structure): ("abort_callback_data", ctypes.c_void_p), ("embeddings", ctypes.c_bool), ("offload_kqv", ctypes.c_bool), - ("flash_attn", ctypes.c_bool), ("no_perf", ctypes.c_bool), ("op_offload", ctypes.c_bool), ("swa_full", ctypes.c_bool), ("kv_unified", ctypes.c_bool), + ("samplers", ctypes.c_void_p), + ("n_samplers", ctypes.c_size_t), ] diff --git a/vendor/llama.cpp b/vendor/llama.cpp index 4227c9be4..da348c9df 160000 --- a/vendor/llama.cpp +++ b/vendor/llama.cpp @@ -1 +1 @@ -Subproject commit 4227c9be4268ac844921b90f31595f81236bd317 +Subproject commit da348c9dfbcfab16584f4640ee53146fdf85a741 From eacc2584fa52f8810ea7903966b8fd38a05538f0 Mon Sep 17 00:00:00 2001 From: r-dh Date: Wed, 4 Mar 2026 20:47:22 +0100 Subject: [PATCH 2/4] fix: set BUILD_NUMBER and LLAMA_INSTALL_VERSION for mtmd build --- CMakeLists.txt | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index 4b06d98b3..0acf1e675 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -46,6 +46,7 @@ function(llama_cpp_python_install_target target) endfunction() if (LLAMA_BUILD) + set(BUILD_NUMBER 0 CACHE STRING "Build number" FORCE) set(BUILD_SHARED_LIBS "On") set(CMAKE_SKIP_BUILD_RPATH FALSE) @@ -154,6 +155,10 @@ if (LLAMA_BUILD) endif() # Building llava + # Set LLAMA_INSTALL_VERSION for mtmd (not inherited from llama.cpp subdirectory scope) + if(NOT DEFINED LLAMA_INSTALL_VERSION) + set(LLAMA_INSTALL_VERSION "0.0.0") + endif() add_subdirectory(vendor/llama.cpp/tools/mtmd) if (WIN32) From 01248477d742e3e9991df1af5f682493bfd62d7d Mon Sep 17 00:00:00 2001 From: r-dh Date: Wed, 4 Mar 2026 20:47:26 +0100 Subject: [PATCH 3/4] fix: return bool from kv_cache_seq_rm for partial removal detection --- llama_cpp/_internals.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llama_cpp/_internals.py b/llama_cpp/_internals.py index b5175a7f2..b440ef7a9 100644 --- a/llama_cpp/_internals.py +++ b/llama_cpp/_internals.py @@ -290,10 +290,10 @@ def kv_cache_clear(self): assert self.memory is not None, "Memory is not initialized" llama_cpp.llama_memory_clear(self.memory, True) - def kv_cache_seq_rm(self, seq_id: int, p0: int, p1: int): + def kv_cache_seq_rm(self, seq_id: int, p0: int, p1: int) -> bool: assert self.memory is not None, "Memory is not initialized" seq_id = seq_id if seq_id >= 0 else 0 - llama_cpp.llama_memory_seq_rm(self.memory, seq_id, p0, p1) + return llama_cpp.llama_memory_seq_rm(self.memory, seq_id, p0, p1) def kv_cache_seq_cp(self, seq_id_src: int, seq_id_dst: int, p0: int, p1: int): assert self.memory is not None, "Memory is not initialized" From 47aedc229966a6bb502ea9fdb6ee48c7c2409c15 Mon Sep 17 00:00:00 2001 From: r-dh Date: Wed, 4 Mar 2026 20:47:31 +0100 Subject: [PATCH 4/4] fix: handle GDN hybrid models that reject partial memory removal --- llama_cpp/llama.py | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 44005bb6b..f76b69d03 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -892,13 +892,23 @@ def generate( else: break if longest_prefix > 0: - reset = False - tokens = tokens[longest_prefix:] - self.n_tokens = longest_prefix - if self.verbose: + # Try to trim the KV cache to prefix length. Hybrid models + # (e.g. GDN) may not support partial removal — in that case we + # fall through to the full reset path below. + if self._ctx.kv_cache_seq_rm(-1, longest_prefix, -1): + reset = False + tokens = tokens[longest_prefix:] + self.n_tokens = longest_prefix + if self.verbose: + print( + f"Llama.generate: {longest_prefix} prefix-match hit, " + f"remaining {len(tokens)} prompt tokens to eval", + file=sys.stderr, + ) + elif self.verbose: print( - f"Llama.generate: {longest_prefix} prefix-match hit, " - f"remaining {len(tokens)} prompt tokens to eval", + f"Llama.generate: {longest_prefix} prefix-match found " + f"but partial kv removal not supported, re-evaluating full prompt", file=sys.stderr, ) @@ -1045,7 +1055,7 @@ def embed( data: Union[List[List[float]], List[List[List[float]]]] = [] def decode_batch(seq_sizes: List[int]): - llama_cpp.llama_kv_self_clear(self._ctx.ctx) + llama_cpp.llama_memory_clear(llama_cpp.llama_get_memory(self._ctx.ctx), True) self._ctx.decode(self._batch) self._batch.reset() @@ -1116,7 +1126,7 @@ def decode_batch(seq_sizes: List[int]): output = data[0] if isinstance(input, str) else data - llama_cpp.llama_kv_self_clear(self._ctx.ctx) + llama_cpp.llama_memory_clear(llama_cpp.llama_get_memory(self._ctx.ctx), True) self.reset() if return_count: