From 92361ef5f32db8ce261b104a7be751ce39551bff Mon Sep 17 00:00:00 2001 From: Mika Senghaas Date: Thu, 11 Jun 2026 03:30:26 +0000 Subject: [PATCH] perf(mm): ephemeral pixel return + smaller image cache to bound rollout RAM MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The renderer-client rollout path keeps and ships the full processed ``pixel_values`` (tens of MB per large image) for every image on every turn, so resident multimodal memory grows with turns x concurrency — at 256 concurrent 1024^2 rollouts the trace/env-worker alone retains ~86 GB. Two opt-in, default-off modes shrink this; behaviour is unchanged unless a consumer sets the env flag: - ``RENDERERS_MM_EPHEMERAL`` (stored data): ``generate`` returns a descriptor-only ``multi_modal_data`` (image_grid_thw + mm_hashes + mm_placeholders, no ``pixel_values``), so the trajectory never retains decoded tensors. Stored mm becomes O(1) per image (a descriptor) instead of O(image-pixels) — 86 GB -> 5.8 GB at 256 concurrent 1024^2 rollouts, and the per-rollout slope drops ~20x. The consumer re-derives pixels downstream from the message images. Purely client-side; safe on any engine incl. vLLM 0.22. - ``RENDERERS_MM_HASH_CACHE`` (transported data): send each image's pixels once, then descriptor-only (``None`` kwargs) so the engine serves it from its mm-hash cache. ``_build_qwen_vl_features`` is now descriptor-aware (per-item ``None`` slots aligned to ``mm_hashes``); a sent-hash memory drives it with a cache-miss fallback. REQUIRES an engine that resolves ``None`` from cache (the disagg router topology) — a plain single-server vLLM 0.22 forces ``skip_mm_cache=True`` and crashes on an unresolved ``None``, so this stays OFF until the engine supports it. 
Also lower ``image_cache_max`` default 256 -> 32: each entry holds a decoded pixel tensor and the pool holds one cache per renderer, so 256 capped resident cache memory at ~pool_size x 256 x pixel_bytes (tens of GB for large images). Co-Authored-By: Claude Opus 4.8 (1M context) --- renderers/client.py | 219 ++++++++++++++++++++++++++++++++++++------- renderers/configs.py | 13 ++- tests/test_client.py | 128 +++++++++++++++++++++++++ 3 files changed, 321 insertions(+), 39 deletions(-) diff --git a/renderers/client.py b/renderers/client.py index 0c63c0e..95ce624 100644 --- a/renderers/client.py +++ b/renderers/client.py @@ -14,11 +14,15 @@ import asyncio import json import logging +import os +import threading +from collections import OrderedDict from collections.abc import Mapping +from dataclasses import replace from typing import Any, cast import httpx -from openai import AsyncOpenAI +from openai import AsyncOpenAI, OpenAIError from renderers.base import ( Message, @@ -31,6 +35,92 @@ ) _request_logger = logging.getLogger("renderers.client") + +# --- multimodal memory modes ------------------------------------------------- +# Two independent, opt-in modes shrink the multimodal data the rollout path keeps +# and ships. Both default OFF so behaviour is unchanged unless a consumer opts in. +# +# EPHEMERAL (``RENDERERS_MM_EPHEMERAL``) — *stored* data. ``generate`` returns a +# descriptor-only ``multi_modal_data`` (image_grid_thw + mm_hashes + +# mm_placeholders, no ``pixel_values``), so the env worker / trace never retains +# decoded image tensors for the life of a rollout — resident stored mm drops from +# the processed-tensor size to a tiny descriptor (the source image still lives in +# the message). The consumer re-derives pixels downstream (training-sample +# construction) from the message images. Safe on any engine (purely client-side). +# +# HASH-ONLY (``RENDERERS_MM_HASH_CACHE``) — *transported* data. 
Send each distinct +# image's ``pixel_values`` once; later turns send a descriptor-only (``None`` +# kwargs) slot and the engine serves it from its mm-hash cache, so the POST body +# carries only new images. REQUIRES an engine that resolves ``None`` slots from +# cache (the disagg router topology). A plain single-server vLLM 0.22 forces +# ``skip_mm_cache=True`` and crashes on an unresolved ``None`` — hence default OFF. +def _env_on(name: str) -> bool: + return os.environ.get(name, "0").lower() not in ("0", "false", "off", "no", "") + + +_MM_EPHEMERAL = _env_on("RENDERERS_MM_EPHEMERAL") +_MM_HASH_CACHE_ENABLED = _env_on("RENDERERS_MM_HASH_CACHE") +# Per-(endpoint, model) LRU bound on remembered sent hashes (a lookup hit or a +# re-record refreshes recency; the least recently used hash is evicted first). Should comfortably +# exceed the engine's mm cache so we don't claim cache hits the engine evicted; +# a stale claim is corrected by the cache-miss fallback, never a wrong result. +_MM_SENT_HASHES_MAX = int(os.environ.get("RENDERERS_MM_SENT_HASHES_MAX", "16384")) +_mm_sent: dict[tuple[str, str], "OrderedDict[str, None]"] = {} +_mm_sent_lock = threading.Lock() + + +def _mm_seen(key: tuple[str, str], hashes: list[str]) -> set[str]: + """Subset of ``hashes`` already sent under ``key`` (believed engine-cached).""" + with _mm_sent_lock: + lru = _mm_sent.get(key) + if not lru: + return set() + seen = set() + for h in hashes: + if h in lru: + seen.add(h) + lru.move_to_end(h) + return seen + + +def _mm_record(key: tuple[str, str], hashes: list[str]) -> None: + """Record ``hashes`` as sent-in-full under ``key`` (LRU-bounded).""" + if not hashes: + return + with _mm_sent_lock: + lru = _mm_sent.setdefault(key, OrderedDict()) + for h in hashes: + lru[h] = None + lru.move_to_end(h) + while len(lru) > _MM_SENT_HASHES_MAX: + lru.popitem(last=False) + + +def _mm_forget(key: tuple[str, str], hashes: list[str]) -> None: + """Drop ``hashes`` from the sent set (after a suspected engine cache miss).""" + with _mm_sent_lock: + lru = _mm_sent.get(key) + if lru: + for h in hashes: +
lru.pop(h, None) + + +def _strip_pixels_for(mm_data: MultiModalData, hashes: set[str]) -> MultiModalData: + """Return ``mm_data`` with ``pixel_values`` dropped from image items whose + hash is in ``hashes`` (descriptor — ``image_grid_thw`` etc. — retained, so + token alignment and a later re-materialize still work).""" + if not hashes or not mm_data.mm_items: + return mm_data + new_items: dict[str, list[dict[str, Any]]] = {} + for modality, items in mm_data.mm_items.items(): + item_hashes = mm_data.mm_hashes.get(modality) or [] + rebuilt = [] + for i, item in enumerate(items): + if i < len(item_hashes) and item_hashes[i] in hashes: + rebuilt.append({k: v for k, v in item.items() if k != "pixel_values"}) + else: + rebuilt.append(item) + new_items[modality] = rebuilt + return replace(mm_data, mm_items=new_items) ROUTED_EXPERTS_DATA_PREFIX = b'"routed_experts":{"data":"' @@ -248,13 +338,6 @@ def _prepare(): "token_ids": prompt_ids, "sampling_params": sp, } - features = ( - _build_mm_features(renderer, mm_data) - if mm_data and not mm_data.is_empty() - else None - ) - if features is not None: - body["features"] = features if cache_salt is not None: body["cache_salt"] = cache_salt if priority is not None: @@ -271,13 +354,70 @@ def _prepare(): len(prompt_ids), sp.get("max_tokens"), ) - post_kwargs: dict[str, Any] = { - "cast_to": httpx.Response, - "body": body, - } - if extra_headers: - post_kwargs["options"] = cast(Any, {"headers": extra_headers}) - raw_response = await client.post(endpoint, **post_kwargs) + + mm_key = (endpoint, model) + img_hashes: list[str] = ( + list(mm_data.mm_hashes.get("image") or []) + if mm_data is not None and not mm_data.is_empty() + else [] + ) + + # Hash-only SEND (opt-in, ``RENDERERS_MM_HASH_CACHE``): strip already-sent + # images to descriptors so the engine serves them from its mm-hash cache and + # only new images carry pixels. 
Requires an engine that resolves ``None`` + # kwargs slots from cache (the disagg router topology); a plain single-server + # vLLM crashes on an unresolved ``None``, so this defaults OFF. + hash_only = False + sent_full = list(img_hashes) # hashes shipped with full pixels this turn + if img_hashes and _MM_HASH_CACHE_ENABLED: + seen = _mm_seen(mm_key, img_hashes) + if seen: + mm_data = _strip_pixels_for(mm_data, seen) + hash_only = True + sent_full = [h for h in img_hashes if h not in seen] + + def _post(send_mm: MultiModalData | None): + b = dict(body) + feats = ( + _build_mm_features(renderer, send_mm) + if send_mm is not None and not send_mm.is_empty() + else None + ) + if feats is not None: + b["features"] = feats + pk: dict[str, Any] = {"cast_to": httpx.Response, "body": b} + if extra_headers: + pk["options"] = cast(Any, {"headers": extra_headers}) + return client.post(endpoint, **pk) + + try: + raw_response = await _post(mm_data) + miss = hash_only and raw_response.status_code >= 400 + except OpenAIError: + if not hash_only: + raise + miss = True + if miss: + # Suspected engine mm-cache miss for a hash-only image: forget the claim, + # re-render the full prompt, and resend every image in full. + _mm_forget(mm_key, img_hashes) + _, _, mm_data, prompt_attr = await _maybe_offload(renderer, _prepare) + raw_response = await _post(mm_data) + sent_full = list(img_hashes) + if _MM_HASH_CACHE_ENABLED and sent_full: + # Remember the images we shipped in full so later turns send them + # hash-only (engine cache hit). + _mm_record(mm_key, sent_full) + + # Ephemeral RETURN (opt-in, ``RENDERERS_MM_EPHEMERAL``): never hand decoded + # pixels back to the caller — the trajectory keeps a descriptor only, so the + # env worker / trace never retains image tensors for the life of a rollout. + # Pixels are re-derived downstream from the message images (training-sample + # construction). Independent of the send mode above. 
+ if img_hashes and _MM_EPHEMERAL: + mm_data = _strip_pixels_for(mm_data, set(img_hashes)) + if prompt_attr is not None and getattr(prompt_attr, "multi_modal_data", None) is not None: + prompt_attr = replace(prompt_attr, multi_modal_data=mm_data) data = parse_generate_response(raw_response.content) choice = (data.get("choices") or [{}])[0] @@ -415,21 +555,32 @@ def _build_qwen_vl_features( image_items = mm_data.mm_items.get("image") or [] if image_items: - # mm_items now ship numpy arrays (the renderer is torch-free); - # convert at this vLLM-glue boundary where torch is already a - # hard dependency. - pixel_values = torch.cat( - [torch.as_tensor(it["pixel_values"]) for it in image_items], dim=0 - ) - image_grid_thw = torch.cat( - [torch.as_tensor(it["image_grid_thw"]) for it in image_items], dim=0 - ) - hf_inputs = BatchFeature( - data={"pixel_values": pixel_values, "image_grid_thw": image_grid_thw} - ) - config = _create_qwen2vl_field_factory(spatial_merge_size)(hf_inputs) - kwargs_items = MultiModalKwargsItems.from_hf_inputs(hf_inputs, config) - encoded = [encode_mm_kwargs_item(it) for it in kwargs_items["image"]] + # Per-item hash-only/full: items carrying ``pixel_values`` are encoded as + # full payloads; items stripped to a descriptor (no ``pixel_values``) get + # a ``None`` slot, scattered back to their original positions so + # ``kwargs_data`` stays aligned with ``mm_hashes`` / ``mm_placeholders``. + # vLLM serves the ``None`` slots from its mm-hash cache. mm_items now ship + # numpy arrays (the renderer is torch-free); convert at this vLLM-glue + # boundary where torch is already a hard dependency. 
+ encoded: list[Any] = [None] * len(image_items) + full_idx = [ + i for i, it in enumerate(image_items) if it.get("pixel_values") is not None + ] + if full_idx: + full_items = [image_items[i] for i in full_idx] + pixel_values = torch.cat( + [torch.as_tensor(it["pixel_values"]) for it in full_items], dim=0 + ) + image_grid_thw = torch.cat( + [torch.as_tensor(it["image_grid_thw"]) for it in full_items], dim=0 + ) + hf_inputs = BatchFeature( + data={"pixel_values": pixel_values, "image_grid_thw": image_grid_thw} + ) + config = _create_qwen2vl_field_factory(spatial_merge_size)(hf_inputs) + kwargs_items = MultiModalKwargsItems.from_hf_inputs(hf_inputs, config) + for slot, item in zip(full_idx, kwargs_items["image"]): + encoded[slot] = encode_mm_kwargs_item(item) out["kwargs_data"]["image"] = encoded out["mm_hashes"]["image"] = list(mm_data.mm_hashes.get("image") or []) out["mm_placeholders"]["image"] = [ @@ -437,10 +588,10 @@ def _build_qwen_vl_features( for p in mm_data.mm_placeholders.get("image") or [] ] - # If kwargs_data is empty across all modalities, drop the key so vLLM - # falls back to the hash-only (cache-hit) path. Otherwise hand it the - # full payload. - if not any(out["kwargs_data"].values()): + # If no full payload was built across any modality, drop kwargs_data so vLLM + # takes the hash-only (cache-hit) path for the whole request. Otherwise hand + # it the payload (with ``None`` slots for the hash-only images). + if not any(any(item is not None for item in items) for items in out["kwargs_data"].values()): out["kwargs_data"] = None return out diff --git a/renderers/configs.py b/renderers/configs.py index 2c18a17..14f410f 100644 --- a/renderers/configs.py +++ b/renderers/configs.py @@ -148,9 +148,12 @@ class Qwen35RendererConfig(BaseRendererConfig): running across the entire conversation. 
Mirrors the chat template's ``add_vision_id`` toggle.""" - image_cache_max: int = 256 + image_cache_max: int = 32 """FIFO bound on the per-renderer image processor cache. Renderer- - internal — not a Jinja chat-template kwarg.""" + internal — not a Jinja chat-template kwarg. Each entry holds a decoded + ``pixel_values`` tensor (tens of MB for a large image), and the pool holds + one cache per renderer, so this caps resident cache memory at roughly + ``pool_size * image_cache_max * pixel_values_bytes``.""" _internal_fields = frozenset({"image_cache_max"}) @@ -166,7 +169,7 @@ class Qwen36RendererConfig(BaseRendererConfig): add_vision_id: bool = False """See :class:`Qwen35RendererConfig.add_vision_id`.""" - image_cache_max: int = 256 + image_cache_max: int = 32 """See :class:`Qwen35RendererConfig.image_cache_max`.""" _internal_fields = frozenset({"image_cache_max"}) @@ -180,7 +183,7 @@ class Qwen3VLRendererConfig(BaseRendererConfig): add_vision_id: bool = False """See :class:`Qwen35RendererConfig.add_vision_id`.""" - image_cache_max: int = 256 + image_cache_max: int = 32 """See :class:`Qwen35RendererConfig.image_cache_max`.""" _internal_fields = frozenset({"image_cache_max"}) @@ -294,7 +297,7 @@ class KimiK25RendererConfig(BaseRendererConfig): ``thinking`` (not ``enable_thinking``) to match the upstream chat template's native variable name.""" - image_cache_max: int = 256 + image_cache_max: int = 32 """See :class:`Qwen35RendererConfig.image_cache_max`.""" _internal_fields = frozenset({"image_cache_max"}) diff --git a/tests/test_client.py b/tests/test_client.py index 1cc1000..d38c1a4 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -14,6 +14,18 @@ from renderers.client import generate +@pytest.fixture(autouse=True) +def _reset_mm_hash_cache(): + """Isolate the module-level hash-only send cache between tests (its whole job + is to persist sent hashes across generate() calls, which would otherwise leak + between tests).""" + import renderers.client as 
_client + + _client._mm_sent.clear() + yield + _client._mm_sent.clear() + + class _FakeRenderer: supports_tools = True @@ -371,6 +383,122 @@ def test_generate_serializes_multimodal_features_for_qwen_vl_family( assert isinstance(item, str) and len(item) > 0 +def test_generate_sends_pixels_once_then_hash_only(monkeypatch): + """With hash-only send enabled: first send of an image carries full + ``pixel_values``; a later send of the same hash carries a descriptor-only + (``None``) slot so the engine serves it from its mm-hash cache. Mixed prompts + keep per-item alignment.""" + pytest.importorskip("torch") + pytest.importorskip("vllm", reason="vllm needed for features serialization") + monkeypatch.setattr("renderers.client._MM_HASH_CACHE_ENABLED", True) + + import torch as _torch + from renderers.base import MultiModalData, PlaceholderRange, load_tokenizer + from renderers.qwen3_vl import Qwen3VLRenderer + + renderer = Qwen3VLRenderer(load_tokenizer("Qwen/Qwen3-VL-4B-Instruct")) + + def _img(): + return { + "pixel_values": _torch.zeros(4, 8, dtype=_torch.float32), + "image_grid_thw": _torch.tensor([[1, 2, 2]], dtype=_torch.int64), + } + + def _mm(hashes): + return MultiModalData( + mm_hashes={"image": list(hashes)}, + mm_placeholders={ + "image": [PlaceholderRange(offset=5 + i, length=1) for i in range(len(hashes))] + }, + mm_items={"image": [_img() for _ in hashes]}, + ) + + def _features_for(hashes): + client = _FakeClient() + asyncio.run( + generate( + client=client, + renderer=renderer, + messages=[], + model="qwen3-vl", + prompt_ids=list(range(20)), + multi_modal_data=_mm(hashes), + sampling_params={"max_tokens": 4}, + ) + ) + return client.calls[0]["body"]["features"] + + # First turn: image "aaa" sent in full. 
+ f1 = _features_for(["aaa"]) + assert f1["kwargs_data"] is not None + assert f1["kwargs_data"]["image"][0] is not None + + # Second turn: same image -> hash-only (kwargs_data dropped), descriptors kept + # so the engine can look it up in its mm-hash cache. + f2 = _features_for(["aaa"]) + assert f2["kwargs_data"] is None + assert f2["mm_hashes"] == {"image": ["aaa"]} + assert f2["mm_placeholders"]["image"] == [{"offset": 5, "length": 1}] + + # Mixed turn: prior "aaa" hash-only, new "bbb" full -> [None, encoded str], + # aligned with mm_hashes. + f3 = _features_for(["aaa", "bbb"]) + slots = f3["kwargs_data"]["image"] + assert slots[0] is None + assert isinstance(slots[1], str) and len(slots[1]) > 0 + assert f3["mm_hashes"] == {"image": ["aaa", "bbb"]} + + +def test_generate_ephemeral_returns_descriptor_only(monkeypatch): + """With ephemeral enabled, the engine still receives full ``pixel_values`` but + ``generate`` returns ``multi_modal_data`` stripped to a descriptor (grid + + hashes + placeholders, no ``pixel_values``), so the trajectory retains no + decoded image tensors.""" + pytest.importorskip("torch") + pytest.importorskip("vllm", reason="vllm needed for features serialization") + monkeypatch.setattr("renderers.client._MM_EPHEMERAL", True) + + import torch as _torch + from renderers.base import MultiModalData, PlaceholderRange, load_tokenizer + from renderers.qwen3_vl import Qwen3VLRenderer + + renderer = Qwen3VLRenderer(load_tokenizer("Qwen/Qwen3-VL-4B-Instruct")) + mm = MultiModalData( + mm_hashes={"image": ["aaa"]}, + mm_placeholders={"image": [PlaceholderRange(offset=5, length=1)]}, + mm_items={ + "image": [ + { + "pixel_values": _torch.zeros(4, 8, dtype=_torch.float32), + "image_grid_thw": _torch.tensor([[1, 2, 2]], dtype=_torch.int64), + } + ] + }, + ) + client = _FakeClient() + result = asyncio.run( + generate( + client=client, + renderer=renderer, + messages=[], + model="qwen3-vl", + prompt_ids=list(range(20)), + multi_modal_data=mm, +
sampling_params={"max_tokens": 4}, + ) + ) + + # Engine still got full pixels (send is unaffected by ephemeral). + sent = client.calls[0]["body"]["features"] + assert sent["kwargs_data"]["image"][0] is not None + # Returned mm is descriptor-only: grid kept, pixel_values dropped. + out = result["multi_modal_data"] + item = out.mm_items["image"][0] + assert "pixel_values" not in item + assert "image_grid_thw" in item + assert out.mm_hashes == {"image": ["aaa"]} + + # --------------------------------------------------------------------------- # Prompt overflow handling. # ---------------------------------------------------------------------------