Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
219 changes: 185 additions & 34 deletions renderers/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,15 @@
import asyncio
import json
import logging
import os
import threading
from collections import OrderedDict
from collections.abc import Mapping
from dataclasses import replace
from typing import Any, cast

import httpx
from openai import AsyncOpenAI
from openai import AsyncOpenAI, OpenAIError

from renderers.base import (
Message,
Expand All @@ -31,6 +35,92 @@
)

_request_logger = logging.getLogger("renderers.client")

# --- multimodal memory modes -------------------------------------------------
# Two independent, opt-in modes shrink the multimodal data the rollout path keeps
# and ships. Both default OFF so behaviour is unchanged unless a consumer opts in.
#
# EPHEMERAL (``RENDERERS_MM_EPHEMERAL``) — *stored* data. ``generate`` returns a
# descriptor-only ``multi_modal_data`` (image_grid_thw + mm_hashes +
# mm_placeholders, no ``pixel_values``), so the env worker / trace never retains
# decoded image tensors for the life of a rollout — resident stored mm drops from
# the processed-tensor size to a tiny descriptor (the source image still lives in
# the message). The consumer re-derives pixels downstream (training-sample
# construction) from the message images. Safe on any engine (purely client-side).
#
# HASH-ONLY (``RENDERERS_MM_HASH_CACHE``) — *transported* data. Send each distinct
# image's ``pixel_values`` once; later turns send a descriptor-only (``None``
# kwargs) slot and the engine serves it from its mm-hash cache, so the POST body
# carries only new images. REQUIRES an engine that resolves ``None`` slots from
# cache (the disagg router topology). A plain single-server vLLM 0.22 forces
# ``skip_mm_cache=True`` and crashes on an unresolved ``None`` — hence default OFF.
def _env_on(name: str) -> bool:
return os.environ.get(name, "0").lower() not in ("0", "false", "off", "no", "")


_MM_EPHEMERAL = _env_on("RENDERERS_MM_EPHEMERAL")
_MM_HASH_CACHE_ENABLED = _env_on("RENDERERS_MM_HASH_CACHE")
# Per-(endpoint, model) FIFO bound on remembered sent hashes. Should comfortably
# exceed the engine's mm cache so we don't claim cache hits the engine evicted;
# a stale claim is corrected by the cache-miss fallback, never a wrong result.
_MM_SENT_HASHES_MAX = int(os.environ.get("RENDERERS_MM_SENT_HASHES_MAX", "16384"))
_mm_sent: dict[tuple[str, str], "OrderedDict[str, None]"] = {}
_mm_sent_lock = threading.Lock()


def _mm_seen(key: tuple[str, str], hashes: list[str]) -> set[str]:
"""Subset of ``hashes`` already sent under ``key`` (believed engine-cached)."""
with _mm_sent_lock:
lru = _mm_sent.get(key)
if not lru:
return set()
seen = set()
for h in hashes:
if h in lru:
seen.add(h)
lru.move_to_end(h)
return seen


def _mm_record(key: tuple[str, str], hashes: list[str]) -> None:
"""Record ``hashes`` as sent-in-full under ``key`` (FIFO-bounded)."""
if not hashes:
return
with _mm_sent_lock:
lru = _mm_sent.setdefault(key, OrderedDict())
for h in hashes:
lru[h] = None
lru.move_to_end(h)
while len(lru) > _MM_SENT_HASHES_MAX:
lru.popitem(last=False)


def _mm_forget(key: tuple[str, str], hashes: list[str]) -> None:
"""Drop ``hashes`` from the sent set (after a suspected engine cache miss)."""
with _mm_sent_lock:
lru = _mm_sent.get(key)
if lru:
for h in hashes:
lru.pop(h, None)


def _strip_pixels_for(mm_data: MultiModalData, hashes: set[str]) -> MultiModalData:
"""Return ``mm_data`` with ``pixel_values`` dropped from image items whose
hash is in ``hashes`` (descriptor — ``image_grid_thw`` etc. — retained, so
token alignment and a later re-materialize still work)."""
if not hashes or not mm_data.mm_items:
return mm_data
new_items: dict[str, list[dict[str, Any]]] = {}
for modality, items in mm_data.mm_items.items():
item_hashes = mm_data.mm_hashes.get(modality) or []
rebuilt = []
for i, item in enumerate(items):
if i < len(item_hashes) and item_hashes[i] in hashes:
rebuilt.append({k: v for k, v in item.items() if k != "pixel_values"})
else:
rebuilt.append(item)
new_items[modality] = rebuilt
return replace(mm_data, mm_items=new_items)
ROUTED_EXPERTS_DATA_PREFIX = b'"routed_experts":{"data":"'


Expand Down Expand Up @@ -248,13 +338,6 @@ def _prepare():
"token_ids": prompt_ids,
"sampling_params": sp,
}
features = (
_build_mm_features(renderer, mm_data)
if mm_data and not mm_data.is_empty()
else None
)
if features is not None:
body["features"] = features
if cache_salt is not None:
body["cache_salt"] = cache_salt
if priority is not None:
Expand All @@ -271,13 +354,70 @@ def _prepare():
len(prompt_ids),
sp.get("max_tokens"),
)
post_kwargs: dict[str, Any] = {
"cast_to": httpx.Response,
"body": body,
}
if extra_headers:
post_kwargs["options"] = cast(Any, {"headers": extra_headers})
raw_response = await client.post(endpoint, **post_kwargs)

mm_key = (endpoint, model)
img_hashes: list[str] = (
list(mm_data.mm_hashes.get("image") or [])
if mm_data is not None and not mm_data.is_empty()
else []
)

# Hash-only SEND (opt-in, ``RENDERERS_MM_HASH_CACHE``): strip already-sent
# images to descriptors so the engine serves them from its mm-hash cache and
# only new images carry pixels. Requires an engine that resolves ``None``
# kwargs slots from cache (the disagg router topology); a plain single-server
# vLLM crashes on an unresolved ``None``, so this defaults OFF.
hash_only = False
sent_full = list(img_hashes) # hashes shipped with full pixels this turn
if img_hashes and _MM_HASH_CACHE_ENABLED:
seen = _mm_seen(mm_key, img_hashes)
if seen:
mm_data = _strip_pixels_for(mm_data, seen)
hash_only = True
sent_full = [h for h in img_hashes if h not in seen]

def _post(send_mm: MultiModalData | None):
b = dict(body)
feats = (
_build_mm_features(renderer, send_mm)
if send_mm is not None and not send_mm.is_empty()
else None
)
if feats is not None:
b["features"] = feats
pk: dict[str, Any] = {"cast_to": httpx.Response, "body": b}
if extra_headers:
pk["options"] = cast(Any, {"headers": extra_headers})
return client.post(endpoint, **pk)

try:
raw_response = await _post(mm_data)
miss = hash_only and raw_response.status_code >= 400
except OpenAIError:
if not hash_only:
raise
miss = True
if miss:
# Suspected engine mm-cache miss for a hash-only image: forget the claim,
# re-render the full prompt, and resend every image in full.
_mm_forget(mm_key, img_hashes)
_, _, mm_data, prompt_attr = await _maybe_offload(renderer, _prepare)
raw_response = await _post(mm_data)
sent_full = list(img_hashes)
if _MM_HASH_CACHE_ENABLED and sent_full:
# Remember the images we shipped in full so later turns send them
# hash-only (engine cache hit).
_mm_record(mm_key, sent_full)

# Ephemeral RETURN (opt-in, ``RENDERERS_MM_EPHEMERAL``): never hand decoded
# pixels back to the caller — the trajectory keeps a descriptor only, so the
# env worker / trace never retains image tensors for the life of a rollout.
# Pixels are re-derived downstream from the message images (training-sample
# construction). Independent of the send mode above.
if img_hashes and _MM_EPHEMERAL:
mm_data = _strip_pixels_for(mm_data, set(img_hashes))
if prompt_attr is not None and getattr(prompt_attr, "multi_modal_data", None) is not None:
prompt_attr = replace(prompt_attr, multi_modal_data=mm_data)
data = parse_generate_response(raw_response.content)

choice = (data.get("choices") or [{}])[0]
Expand Down Expand Up @@ -415,32 +555,43 @@ def _build_qwen_vl_features(

image_items = mm_data.mm_items.get("image") or []
if image_items:
# mm_items now ship numpy arrays (the renderer is torch-free);
# convert at this vLLM-glue boundary where torch is already a
# hard dependency.
pixel_values = torch.cat(
[torch.as_tensor(it["pixel_values"]) for it in image_items], dim=0
)
image_grid_thw = torch.cat(
[torch.as_tensor(it["image_grid_thw"]) for it in image_items], dim=0
)
hf_inputs = BatchFeature(
data={"pixel_values": pixel_values, "image_grid_thw": image_grid_thw}
)
config = _create_qwen2vl_field_factory(spatial_merge_size)(hf_inputs)
kwargs_items = MultiModalKwargsItems.from_hf_inputs(hf_inputs, config)
encoded = [encode_mm_kwargs_item(it) for it in kwargs_items["image"]]
# Per-item hash-only/full: items carrying ``pixel_values`` are encoded as
# full payloads; items stripped to a descriptor (no ``pixel_values``) get
# a ``None`` slot, scattered back to their original positions so
# ``kwargs_data`` stays aligned with ``mm_hashes`` / ``mm_placeholders``.
# vLLM serves the ``None`` slots from its mm-hash cache. mm_items now ship
# numpy arrays (the renderer is torch-free); convert at this vLLM-glue
# boundary where torch is already a hard dependency.
encoded: list[Any] = [None] * len(image_items)
full_idx = [
i for i, it in enumerate(image_items) if it.get("pixel_values") is not None
]
if full_idx:
full_items = [image_items[i] for i in full_idx]
pixel_values = torch.cat(
[torch.as_tensor(it["pixel_values"]) for it in full_items], dim=0
)
image_grid_thw = torch.cat(
[torch.as_tensor(it["image_grid_thw"]) for it in full_items], dim=0
)
hf_inputs = BatchFeature(
data={"pixel_values": pixel_values, "image_grid_thw": image_grid_thw}
)
config = _create_qwen2vl_field_factory(spatial_merge_size)(hf_inputs)
kwargs_items = MultiModalKwargsItems.from_hf_inputs(hf_inputs, config)
for slot, item in zip(full_idx, kwargs_items["image"]):
encoded[slot] = encode_mm_kwargs_item(item)
out["kwargs_data"]["image"] = encoded
out["mm_hashes"]["image"] = list(mm_data.mm_hashes.get("image") or [])
out["mm_placeholders"]["image"] = [
{"offset": p.offset, "length": p.length}
for p in mm_data.mm_placeholders.get("image") or []
]

# If kwargs_data is empty across all modalities, drop the key so vLLM
# falls back to the hash-only (cache-hit) path. Otherwise hand it the
# full payload.
if not any(out["kwargs_data"].values()):
# If no full payload was built across any modality, drop kwargs_data so vLLM
# takes the hash-only (cache-hit) path for the whole request. Otherwise hand
# it the payload (with ``None`` slots for the hash-only images).
if not any(any(item is not None for item in items) for items in out["kwargs_data"].values()):
out["kwargs_data"] = None

return out
13 changes: 8 additions & 5 deletions renderers/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,12 @@ class Qwen35RendererConfig(BaseRendererConfig):
running across the entire conversation. Mirrors the chat template's
``add_vision_id`` toggle."""

image_cache_max: int = 256
image_cache_max: int = 32
"""FIFO bound on the per-renderer image processor cache. Renderer-
internal — not a Jinja chat-template kwarg."""
internal — not a Jinja chat-template kwarg. Each entry holds a decoded
``pixel_values`` tensor (tens of MB for a large image), and the pool holds
one cache per renderer, so this caps resident cache memory at roughly
``pool_size * image_cache_max * pixel_values_bytes``."""

_internal_fields = frozenset({"image_cache_max"})

Expand All @@ -166,7 +169,7 @@ class Qwen36RendererConfig(BaseRendererConfig):
add_vision_id: bool = False
"""See :class:`Qwen35RendererConfig.add_vision_id`."""

image_cache_max: int = 256
image_cache_max: int = 32
"""See :class:`Qwen35RendererConfig.image_cache_max`."""

_internal_fields = frozenset({"image_cache_max"})
Expand All @@ -180,7 +183,7 @@ class Qwen3VLRendererConfig(BaseRendererConfig):
add_vision_id: bool = False
"""See :class:`Qwen35RendererConfig.add_vision_id`."""

image_cache_max: int = 256
image_cache_max: int = 32
"""See :class:`Qwen35RendererConfig.image_cache_max`."""

_internal_fields = frozenset({"image_cache_max"})
Expand Down Expand Up @@ -294,7 +297,7 @@ class KimiK25RendererConfig(BaseRendererConfig):
``thinking`` (not ``enable_thinking``) to match the upstream chat
template's native variable name."""

image_cache_max: int = 256
image_cache_max: int = 32
"""See :class:`Qwen35RendererConfig.image_cache_max`."""

_internal_fields = frozenset({"image_cache_max"})
Expand Down
Loading
Loading