diff --git a/pyproject.toml b/pyproject.toml index dc42053e..545426ea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,9 +36,6 @@ possibly-missing-attribute = "ignore" missing-argument = "ignore" unused-type-ignore-comment = "ignore" -[tool.bandit] -exclude_dirs = ["tests", "docs"] - [tool.coverage.run] source = ["src/pruna"] @@ -97,14 +94,14 @@ stable-fast-pruna = { index = "pruna_internal", extra = "stable-fast-extraindex" [project] name = "pruna" -version = "0.3.3" +version = "0.3.2" description = "Smash your AI models" authors = [ {name = "Pruna AI", email = "hello@pruna.ai"} ] license = {file = "LICENSE"} readme = "README.md" -requires-python = ">=3.10,<3.14" +requires-python = ">=3.10,<3.13" keywords = ["AI", "machine learning", "model optimization", "pruning"] classifiers = [ "Development Status :: 4 - Beta", @@ -157,6 +154,7 @@ dependencies = [ "peft>=0.18.0,<0.19.0", "trl<=0.21.0", "termcolor==2.3.0", + "realesrgan", ] [project.optional-dependencies] @@ -171,6 +169,10 @@ vllm = [ "vllm>=0.16.0", "ray", ] +evaluation = [ + "outlines>1.2.0,<2.0.0", + "litellm>=1.0.0", +] stable-fast = [ "xformers>=0.0.30", "stable-fast-pruna>=1.0.8,<1.0.9", @@ -195,18 +197,12 @@ awq = [ "llmcompressor>=0.9", "torch>=2.9.0" ] -upscale = [ - "realesrgan", -] full = [ "pruna[stable-fast]", ] vbench = [ "vbench-pruna; sys_platform != 'darwin'", ] -rapidata = [ - "rapidata>=3.0.0" -] dev = [ "wget", "python-dotenv", @@ -233,15 +229,12 @@ dev = [ "types-PyYAML", "logbar", "pytest-xdist>=3.8.0", + "pruna[evaluation]", ] cpu = [] lmharness = [ "lm-eval>=0.4.0" ] -evaluation = [ - "pruna[rapidata]", - "pruna[lmharness]" -] # Intel extension is tightly coupled with the torch version intel = [ diff --git a/src/pruna/evaluation/metrics/__init__.py b/src/pruna/evaluation/metrics/__init__.py index bf7414c3..d5cba6b8 100644 --- a/src/pruna/evaluation/metrics/__init__.py +++ b/src/pruna/evaluation/metrics/__init__.py @@ -26,6 +26,13 @@ from pruna.evaluation.metrics.metric_rapiddata import RapidataMetric as RapidataMetric from pruna.evaluation.metrics.metric_sharpness import SharpnessMetric from pruna.evaluation.metrics.metric_torch import TorchMetricWrapper +from pruna.evaluation.metrics.vlm_base import ( + BaseVLM, + LitellmVLM, + StatefulVLMMeanScoresMetric, + TransformersVLM, + get_vlm, +) __all__ = [ "MetricRegistry", @@ -47,4 +54,9 @@ "AestheticLAION", "LMEvalMetric", "RapidataMetric", + "BaseVLM", + "LitellmVLM", + "StatefulVLMMeanScoresMetric", + "TransformersVLM", + "get_vlm", ] diff --git a/src/pruna/evaluation/metrics/utils.py b/src/pruna/evaluation/metrics/utils.py index 29342701..a3cdb4a5 100644 --- a/src/pruna/evaluation/metrics/utils.py +++ b/src/pruna/evaluation/metrics/utils.py @@ -56,13 +56,17 @@ def metric_data_processor( This function determines the order and selection of inputs to be passed to various metrics. The function supports different input arrangements through the 'call_type' configuration: - - 'x_y': Uses input data (x) and model outputs - - 'gt_y': Uses ground truth (gt) and model outputs - - 'y_x': Uses model outputs and input data (x) - - 'y_gt': Uses model outputs and ground truth (gt) - - 'pairwise_gt_y': Uses cached base model outputs (gt) and smashed model outputs (y). - - 'pairwise_y_gt': Uses smashed model outputs (y) and cached base model outputs (gt). - The evaluation agent is expected to pass the cached base model outputs as gt. + + - 'y_gt': Model's output first, then ground truth. Returns [outputs, gt]. + - 'gt_y': Ground truth first, then model's output. 
Returns [gt, outputs]. + - 'y_x': Model's output first, then input data. Returns [outputs, x]. + Used by CLIPScore, VQA, ImageEditScore, VIEScore. + - 'x_y': Input data first, then model's output. Returns [x, outputs]. + - 'x_gt': Input data first, then ground truth. Returns [x, gt]. + - 'gt_x': Ground truth first, then input data. Returns [gt, x]. + - 'pairwise_y_gt': Base model's output first, then subsequent model's output. + - 'pairwise_gt_y': Subsequent model's output first, then base model's output. + - 'y': Only the output is used; the metric has an internal dataset. Returns [outputs]. Parameters ---------- @@ -85,7 +89,8 @@ def metric_data_processor( Raises ------ ValueError - If the specified call_type is not one of: 'x_y', 'gt_y', 'y_x', 'y_gt', 'pairwise'. + If the specified call_type is not one of: 'y_gt', 'gt_y', 'y_x', 'x_y', + 'x_gt', 'gt_x', 'pairwise_y_gt', 'pairwise_gt_y', 'y'. Examples -------- @@ -106,11 +111,15 @@ def metric_data_processor( return [outputs, x] elif call_type == "y_gt": return [outputs, gt] + elif call_type == "x_gt": + return [x, gt] + elif call_type == "gt_x": + return [gt, x] elif call_type == "pairwise_gt_y": return [gt, outputs] elif call_type == "pairwise_y_gt": return [outputs, gt] - elif call_type == "y": # IQA metrics that have an internal dataset + elif call_type == "y": return [outputs] else: raise ValueError(f"Invalid call type: {call_type}") diff --git a/src/pruna/evaluation/metrics/vlm_base.py b/src/pruna/evaluation/metrics/vlm_base.py new file mode 100644 index 00000000..2de7e164 --- /dev/null +++ b/src/pruna/evaluation/metrics/vlm_base.py @@ -0,0 +1,1118 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +VLM (Vision-Language Model) base classes for metrics. + +Implementations +--------------- +- **LitellmVLM** — API inference via ``litellm`` (many providers behind one client). +- **TransformersVLM** — local Hugging Face models on device. + +Why LiteLLM for the default API path +-------------------------------------- +Judge-style metrics need a capable vision-language model. Loading large VLMs locally is +expensive; routing through ``litellm`` keeps the default path lightweight and matches common +API-judge setups without bundling a full local VLM in every metric run. + +API keys and environment +------------------------ +For ``vlm_type="litellm"``, the key passed to the provider is resolved in this order: + +1. The ``api_key`` argument on the metric or :func:`get_vlm` +2. ``LITELLM_API_KEY`` +3. ``OPENAI_API_KEY`` + +Routes such as ``openai/gpt-4o`` use the OpenAI-compatible key. Other providers follow +LiteLLM’s environment conventions (for example ``ANTHROPIC_API_KEY`` for ``anthropic/...``). +The same ``OPENAI_API_KEY`` you use for other OpenAI-hosted judges (for example in pbench) +applies here. + +For a short user-facing summary of key order, hosted vs local, and a minimal ``transformers`` +example, see :doc:`Evaluate a model ` (Vision-language judge +metrics). 
+ +Choosing local vs API +--------------------- +Metrics in :data:`VLM_METRIC_REGISTRY_NAMES` take ``vlm_type`` and ``model_name``: + +- **API** (``vlm_type="litellm"``, default) — use a vision-capable route (e.g. + :data:`DEFAULT_LITELLM_MODEL`). +- **Local** (``vlm_type="transformers"``) — e.g. SmolVLM for offline or CI. + +The ``oneig_reasoning`` metric is separate: it runs the LLM2CLIP stack locally; see +``pruna.evaluation.metrics.metric_oneig_reasoning``. + +Structured outputs +------------------ +- LitellmVLM: pydantic ``response_format`` where applicable. +- TransformersVLM: Outlines 1.x constrained decoding via ``outlines.Generator`` and + ``outlines.models.transformers.from_transformers`` (single- and multi-image ``Chat`` inputs). + +Usage examples +---------------- +Minimal LiteLLM and local ``transformers`` construction is shown under :func:`get_vlm` +(``Examples`` section). **Registry metrics** (``vqa``, ``qa_accuracy``, +``img_edit_score``, OCR/text metrics, ``oneig_alignment``, ``vie_score``, …) take the same +``vlm_type``, ``model_name``, ``api_key``, and ``vlm_kwargs`` pattern; see +:class:`StatefulVLMMeanScoresMetric` below and each metric class docstring. + +For VIEScore-style **text--image editing** metrics that pass two PIL images per prompt (source +then edited), call :meth:`LitellmVLM.generate_with_image_lists` or +:meth:`TransformersVLM.generate_with_image_lists` with ``image_lists[i]`` aligned to +``prompts[i]``. + +Metric-level (hosted vs local) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +``Task(request=["vqa"], ...)`` supplies ``model_name="openai/gpt-4o"`` for VLM registry names. +Override the backend by constructing a metric instance: + +.. code-block:: python + + import torch + + from pruna.evaluation.metrics import VQAMetric + + hosted = VQAMetric(vlm_type="litellm", model_name="openai/gpt-4o") + local = VQAMetric( + vlm_type="transformers", + model_name="HuggingFaceTB/SmolVLM-256M-Instruct", + device="cpu", + vlm_kwargs={"model_load_kwargs": {"torch_dtype": torch.float32}}, + ) + +``QAAccuracyMetric`` and other classes that call :func:`get_vlm` directly use the same +arguments (substitute the class name). +""" + +from __future__ import annotations + +import base64 +import io +import math +import os +from abc import ABC, abstractmethod +from typing import Any, List, Literal, Optional, Type, TypeVar, Union + +import numpy as np +import torch +from PIL import Image +from pydantic import BaseModel + +from pruna.evaluation.metrics.metric_stateful import StatefulMetric +from pruna.evaluation.metrics.result import MetricResult +from pruna.evaluation.metrics.utils import get_call_type_for_single_metric +from pruna.logging.logger import pruna_logger + +T = TypeVar("T", bound=BaseModel) + +DEFAULT_LITELLM_MODEL: str = "openai/gpt-4o" + +VLM_METRIC_REGISTRY_NAMES: frozenset[str] = frozenset( + ( + "vqa", + "qa_accuracy", + "img_edit_score", + "text_score", + "ocr_levenshtein", + "ocr_text_score", + "oneig_text_score", + "oneig_alignment", + "vie_score", + ) +) + + +def get_vlm( + vlm: Optional[BaseVLM] = None, + vlm_type: Literal["litellm", "transformers"] = "litellm", + *, + model_name: Optional[str] = None, + device: Optional[str | torch.device] = None, + api_key: Optional[str] = None, + structured_output: bool = True, + **vlm_kwargs: Any, +) -> BaseVLM: + """ + Create or return a VLM instance. + + Parameters + ---------- + vlm : BaseVLM | None + If provided, returned as-is. Otherwise a VLM is created. + vlm_type : {"litellm", "transformers"} + Backend when creating a VLM. 
+ model_name : str | None + Model name for litellm (e.g. ``openai/gpt-4o``) or HuggingFace ``from_pretrained`` id. + **Required** when ``vlm`` is not provided. Ignored when ``vlm`` is provided. + device : str | torch.device | None + Device for transformers VLM. + api_key : str | None + API key for litellm. + structured_output : bool + When True, litellm uses pydantic ``response_format`` from the metric; for + ``transformers``, enables outlines-based constrained decoding when a string + format is passed to ``generate``/``score``. + **vlm_kwargs : Any + Same dict as ``vlm_kwargs`` on VLM metrics: forwarded to the backend chosen by + ``vlm_type``. For ``"litellm"``, kwargs go to ``LitellmVLM`` (e.g. provider-specific + options). For ``"transformers"``, use ``model_load_kwargs`` for + ``AutoModelForImageTextToText.from_pretrained``; any other keys are passed to + ``TransformersVLM`` after ``model_load_kwargs`` is popped. + + Returns + ------- + BaseVLM + The VLM instance. + + Notes + ----- + When ``vlm_type`` is ``"litellm"`` and ``api_key`` is omitted, the key is taken from + ``LITELLM_API_KEY`` or ``OPENAI_API_KEY``. See the module docstring above. User manual: + :doc:`Evaluate a model ` (Vision-language judge metrics). + + Examples + -------- + Hosted (``litellm``) and local Hugging Face (``transformers``). API key for ``hosted`` from + ``OPENAI_API_KEY`` or ``LITELLM_API_KEY`` if ``api_key`` is omitted. + + .. code-block:: python + + import torch + + from pruna.evaluation.metrics.vlm_base import get_vlm + + hosted = get_vlm(vlm_type="litellm", model_name="openai/gpt-4o") + local = get_vlm( + vlm_type="transformers", + model_name="HuggingFaceTB/SmolVLM-256M-Instruct", + device="cpu", + model_load_kwargs={"torch_dtype": torch.float32}, + ) + + Another LiteLLM provider route (set the env var that route expects, e.g. + ``ANTHROPIC_API_KEY`` for ``anthropic/...``): + + .. code-block:: python + + from pruna.evaluation.metrics.vlm_base import get_vlm + + other_provider = get_vlm( + vlm_type="litellm", model_name="anthropic/claude-3-5-sonnet-20241022" + ) + """ + if vlm is not None: + return vlm + if not model_name: + raise ValueError( + "get_vlm requires model_name when vlm is not provided " + '(pass model_name explicitly, e.g. model_name="openai/gpt-4o").' + ) + if vlm_type == "litellm": + return LitellmVLM(model_name=model_name, api_key=api_key, **vlm_kwargs) + model_load_kwargs = vlm_kwargs.pop("model_load_kwargs", {}) + return TransformersVLM( + model_name=model_name, + device=device, + use_outlines=structured_output, + model_load_kwargs=model_load_kwargs, + **vlm_kwargs, + ) + + +class BaseVLM(ABC): + """Base class for Vision-Language Models.""" + + @abstractmethod + def generate( + self, + images: List[Image.Image], + prompts: List[str], + response_format: Optional[Union[Type[BaseModel], Literal["integer"], Literal["yes_no"], Literal["json"]]] = None, + **kwargs: Any, + ) -> List[str]: + """ + Generate responses for images and prompts. + + Parameters + ---------- + images : List[Image.Image] + List of PIL Images. + prompts : List[str] + List of text prompts. + response_format : Type[BaseModel] | str | None + Optional pydantic model (litellm) or format string: "integer", "yes_no", "json" (transformers/outlines). + **kwargs : Any + Additional arguments passed to the implementation. + + Returns + ------- + List[str] + Generated responses. 
+ """ + pass + + @abstractmethod + def score( + self, + images: List[Image.Image], + questions: List[str], + answers: List[str], + use_probability: bool = False, + response_format: Optional[Union[Type[BaseModel], Literal["integer"], Literal["yes_no"], Literal["json"]]] = None, + **kwargs: Any, + ) -> List[float]: + """ + Score how well answers match images for given questions. + + Parameters + ---------- + images : List[Image.Image] + List of PIL Images. + questions : List[str] + List of questions. + answers : List[str] + List of expected answers. + use_probability : bool, optional + If True and supported, return P(expected answer) instead of binary 0/1. + response_format : Type[BaseModel] | str | None, optional + Structured output format. When set, uses generate() with this format and + extracts the answer field for comparison instead of raw string matching. + **kwargs : Any + Additional arguments passed to the implementation. + + Returns + ------- + List[float] + Scores for each image-question pair (0-1, or probability when use_probability). + """ + pass + + +class LitellmVLM(BaseVLM): + """ + VLM using litellm for API-based inference. + + Supports many providers (OpenAI, Anthropic, Azure, and others) through a single client. + + Parameters + ---------- + model_name : str + Model name (e.g. ``openai/gpt-4o`` for litellm). Passed from :func:`get_vlm`. + api_key : str | None, optional + API key for the provider. If omitted, uses ``LITELLM_API_KEY`` then ``OPENAI_API_KEY``. + **kwargs : Any + Additional arguments passed to litellm. + + Notes + ----- + LiteLLM is the default API backend so metric runs can use a hosted VLM judge without + downloading large local checkpoints. Provider-specific environment variables are described + in the LiteLLM documentation; OpenAI-compatible routes typically use ``OPENAI_API_KEY``. + User manual: :doc:`Evaluate a model ` (Vision-language + judge metrics). + + Examples + -------- + OpenAI-compatible route (pass ``api_key`` explicitly, or rely on ``OPENAI_API_KEY`` / + ``LITELLM_API_KEY`` when omitted): + + >>> from pruna.evaluation.metrics.vlm_base import LitellmVLM + >>> hosted = LitellmVLM(model_name="openai/gpt-4o", api_key="sk-placeholder") + >>> hosted.api_key == "sk-placeholder" + True + + Same naming as :func:`get_vlm` examples: ``other_provider`` for non-OpenAI LiteLLM routes + (set ``ANTHROPIC_API_KEY``, etc.): + + .. code-block:: python + + from pruna.evaluation.metrics.vlm_base import LitellmVLM + + other_provider = LitellmVLM(model_name="anthropic/claude-3-5-sonnet-20241022") + """ + + def __init__( + self, + model_name: str, + api_key: Optional[str] = None, + **kwargs: Any, + ) -> None: + self.model_name = model_name + self.api_key = api_key or os.getenv("LITELLM_API_KEY") or os.getenv("OPENAI_API_KEY") + self.extra_kwargs = kwargs + try: + import litellm + + litellm.drop_params = True + self._litellm = litellm + except ImportError: + pruna_logger.error("litellm not installed. 
Install with: pip install litellm") + raise + + def _litellm_chat_completion( + self, + content: list[dict[str, Any]], + response_format: Optional[Union[Type[BaseModel], Literal["integer"], Literal["yes_no"], Literal["json"]]] = None, + **kwargs: Any, + ) -> str: + completion_kwargs: dict[str, Any] = { + "model": self.model_name, + "messages": [{"role": "user", "content": content}], + "api_key": self.api_key, + **self.extra_kwargs, + **kwargs, + } + if response_format is not None and isinstance(response_format, type): + completion_kwargs["response_format"] = response_format + response = self._litellm.completion(**completion_kwargs) + content_result = response.choices[0].message.content + use_pydantic = ( + response_format is not None + and isinstance(response_format, type) + and isinstance(content_result, response_format) + ) + if use_pydantic: + return content_result.model_dump_json() + return str(content_result) if content_result is not None else "" + + def generate( + self, + images: List[Image.Image], + prompts: List[str], + response_format: Optional[Union[Type[BaseModel], Literal["integer"], Literal["yes_no"], Literal["json"]]] = None, + **kwargs: Any, + ) -> List[str]: + """ + Generate responses for images and prompts. + + Parameters + ---------- + images : List[Image.Image] + List of PIL Images. + prompts : List[str] + List of text prompts. + response_format : Type[BaseModel] | str | None + Optional pydantic model for structured output (litellm uses BaseModel). + **kwargs : Any + Additional arguments passed to litellm completion. + + Returns + ------- + List[str] + Generated responses. + """ + results = [] + for image, prompt in zip(images, prompts): + try: + content = [ + {"type": "text", "text": prompt}, + {"type": "image_url", "image_url": {"url": self._image_to_data_url(image)}}, + ] + results.append(self._litellm_chat_completion(content, response_format, **kwargs)) + except Exception as e: + pruna_logger.error(f"Litellm generation failed: {e}") + results.append("") + return results + + def generate_with_image_lists( + self, + image_lists: List[List[Image.Image]], + prompts: List[str], + response_format: Optional[Union[Type[BaseModel], Literal["integer"], Literal["yes_no"], Literal["json"]]] = None, + **kwargs: Any, + ) -> List[str]: + """ + Generate one response per (``image_list``, ``prompt``) pair. + + Each ``image_list`` contains one or more PIL images (e.g. source and edited for + VIEScore ``tie``). Message content is built as text first, then each image as + ``image_url``, matching common OpenAI-style multi-image chat layouts. + + Parameters + ---------- + image_lists : list[list[PIL.Image.Image]] + One list of images per prompt (same length as ``prompts``). + prompts : list[str] + User text for each row. + response_format : optional + Same as :meth:`generate`. + **kwargs : Any + Forwarded to litellm ``completion``. + + Returns + ------- + list[str] + One string (or JSON string for pydantic) per row. 
+ """ + if len(image_lists) != len(prompts): + raise ValueError("image_lists and prompts must have the same length.") + results: List[str] = [] + for imgs, prompt in zip(image_lists, prompts): + try: + content: list[dict[str, Any]] = [{"type": "text", "text": prompt}] + for im in imgs: + content.append({"type": "image_url", "image_url": {"url": self._image_to_data_url(im)}}) + results.append(self._litellm_chat_completion(content, response_format, **kwargs)) + except Exception as e: + pruna_logger.error(f"Litellm multi-image generation failed: {e}") + results.append("") + return results + + def score( + self, + images: List[Image.Image], + questions: List[str], + answers: List[str], + use_probability: bool = False, + response_format: Optional[Union[Type[BaseModel], Literal["integer"], Literal["yes_no"], Literal["json"]]] = None, + **kwargs: Any, + ) -> List[float]: + """ + Score how well answers match images for given questions. + + When use_probability=True, requests logprobs from the API and returns P(expected). + When response_format is set, uses structured generation and extracts the answer field. + Falls back to binary 0/1 if logprobs not available. + + Parameters + ---------- + images : List[Image.Image] + List of PIL Images. + questions : List[str] + List of questions. + answers : List[str] + List of expected answers. + use_probability : bool, optional + If True, return P(expected) from logprobs when available. Default is False. + response_format : Type[BaseModel] | str | None, optional + Structured output format for answer extraction. + **kwargs : Any + Additional arguments passed to litellm completion. + + Returns + ------- + List[float] + Scores for each image-question pair (0-1, or probability when use_probability). + """ + from pruna.evaluation.metrics.vlm_utils import get_answer_from_response + + scores = [] + for image, question, answer in zip(images, questions, answers): + prompt = f"{question} Please answer yes or no." + if use_probability: + score = self._score_with_logprobs(image, prompt, answer, **kwargs) + elif response_format is not None: + raw = self.generate([image], [prompt], response_format=response_format, **kwargs)[0] + response_answer = get_answer_from_response(raw) + score = 1.0 if answer.lower() in response_answer.lower() else 0.0 + else: + response = self.generate([image], [prompt], **kwargs)[0].lower() + score = 1.0 if answer.lower() in response else 0.0 + scores.append(score) + return scores + + def _score_with_logprobs(self, image: Image.Image, prompt: str, expected: str, **kwargs: Any) -> float: + """ + Get P(expected) from logprobs when available. + + Parameters + ---------- + image : Image.Image + PIL Image to score. + prompt : str + Question prompt. + expected : str + Expected answer (e.g., "Yes"). + **kwargs : Any + Additional arguments passed to litellm completion. + + Returns + ------- + float + Probability of expected answer (0-1), or binary 0/1 on fallback. 
+ """ + content = [ + {"type": "text", "text": prompt}, + {"type": "image_url", "image_url": {"url": self._image_to_data_url(image)}}, + ] + completion_kwargs = { + "model": self.model_name, + "messages": [{"role": "user", "content": content}], + "api_key": self.api_key, + "logprobs": True, + "top_logprobs": 5, + **self.extra_kwargs, + **kwargs, + } + try: + response = self._litellm.completion(**completion_kwargs) + choice = response.choices[0] + logprobs = getattr(choice, "logprobs", None) or getattr(choice.message, "logprobs", None) + if logprobs and hasattr(logprobs, "content"): + # Match token if it starts with the yes/no word and the remainder is non-alphabetic + # (e.g. "yes." or "no," match, but "yesterday" or "not" do not). + def _word_matches(token_str: str, word: str) -> bool: + return token_str.startswith(word) and ( + len(token_str) == len(word) or not token_str[len(word)].isalpha() + ) + + yes_words = ("yes", " yes") + no_words = ("no", " no") + p_yes = 0.0 + p_no = 0.0 + for tok in logprobs.content or []: + top = getattr(tok, "top_logprobs", None) or [] + for t in top: + token_str = (getattr(t, "token", "") or "").lower() + lp = float(getattr(t, "logprob", -1e9) or -1e9) + prob = math.exp(lp) + if any(_word_matches(token_str, w) for w in yes_words): + p_yes += prob + elif any(_word_matches(token_str, w) for w in no_words): + p_no += prob + break # Only process the first output token's top_logprobs + eps = 1e-12 + denom = p_yes + p_no + if denom > eps: + ans = expected.strip().lower() + if ans == "yes": + return float(min(1.0, p_yes / denom)) + if ans == "no": + return float(min(1.0, p_no / denom)) + content_str = (choice.message.content or "").lower() + if expected.lower() in content_str: + return 1.0 + return 0.0 + except Exception: + response = self.generate([image], [prompt], **kwargs)[0].lower() + return 1.0 if expected.lower() in response else 0.0 + + def _image_to_data_url(self, image: Image.Image) -> str: + buffer = io.BytesIO() + image.save(buffer, format="PNG") + buffer.seek(0) + b64 = base64.b64encode(buffer.read()).decode("utf-8") + return f"data:image/png;base64,{b64}" + + +class TransformersVLM(BaseVLM): + """ + VLM using HuggingFace Transformers for local inference. + + Supports models like BLIP, LLaVA, SmolVLM, etc. + + Parameters + ---------- + model_name : str, optional + HuggingFace model name. Default is "Salesforce/blip2-opt-2.7b". + device : str | torch.device | None, optional + Device for inference. Auto-detected if None. + use_outlines : bool, optional + Whether to use outlines for constrained decoding when the caller passes a string + ``response_format``. Usually set from ``structured_output`` via :func:`get_vlm`. + model_load_kwargs : dict, optional + Kwargs passed to from_pretrained (e.g. dtype, attn_implementation). + **kwargs : Any + Additional arguments passed to model.generate. + + Notes + ----- + Prefer :func:`get_vlm` from metrics so ``structured_output`` and ``vlm_kwargs`` match + registry metrics. User manual: :doc:`Evaluate a model ` + (Vision-language judge metrics). + + Examples + -------- + Local judge only (same ``local`` pattern as :func:`get_vlm`; prefer :func:`get_vlm` from + metrics so ``structured_output`` and ``vlm_kwargs`` stay aligned): + + .. 
code-block:: python + + import torch + + from pruna.evaluation.metrics.vlm_base import TransformersVLM + + local = TransformersVLM( + model_name="HuggingFaceTB/SmolVLM-256M-Instruct", + device="cpu", + model_load_kwargs={"torch_dtype": torch.float32}, + ) + """ + + def __init__( + self, + model_name: str = "Salesforce/blip2-opt-2.7b", + device: Optional[str | torch.device] = None, + use_outlines: bool = False, + model_load_kwargs: Optional[dict] = None, + **kwargs: Any, + ) -> None: + self.model_name = model_name + self.use_outlines = use_outlines + self.model_load_kwargs = model_load_kwargs or {} + if device is None: + if torch.cuda.is_available(): + self.device = torch.device("cuda") + else: + self.device = torch.device("cpu") + else: + self.device = torch.device(device) + self.extra_kwargs = kwargs + self._model = None + self._processor = None + self._yes_no_prefix_ids: Optional[tuple[list[int], list[int]]] = None + self._outlines_wrapped_model: Any = None + + def _load_model(self) -> None: + if self._model is not None: + return + try: + from transformers import AutoModelForImageTextToText, AutoProcessor + except ImportError: + pruna_logger.error("transformers not installed. Install with: pip install transformers") + raise + pruna_logger.info(f"Loading VLM model: {self.model_name}") + self._processor = AutoProcessor.from_pretrained(self.model_name) + self._model = AutoModelForImageTextToText.from_pretrained(self.model_name, **self.model_load_kwargs) + device = self.device + self._model.to(device) # type: ignore[invalid-argument-type] + self._model.eval() + + def _get_outlines_wrapped_model(self) -> Any: + """Lazily wrap HF model + processor for Outlines 1.x steerable generation.""" + if self._outlines_wrapped_model is None: + from outlines.models.transformers import from_transformers + + assert self._processor is not None, "_processor must be loaded before wrapping with outlines" + self._outlines_wrapped_model = from_transformers(self._model, self._processor) + return self._outlines_wrapped_model + + def _pil_for_outlines(self, image: Image.Image) -> Any: + """Wrap a PIL image for ``outlines.inputs.Image`` (requires a concrete ``format``).""" + from outlines.inputs import Image as OutlinesImage + + buf = io.BytesIO() + image.convert("RGB").save(buf, format="PNG") + buf.seek(0) + pil = Image.open(buf) + return OutlinesImage(pil) + + def _chat_user_with_images(self, images: List[Image.Image], prompt: str) -> Any: + """Build an ``outlines.inputs.Chat`` with one or more images then text (HF multimodal dicts).""" + from outlines.inputs import Chat + + parts: list[dict[str, Any]] = [] + for im in images: + parts.append({"type": "image", "image": self._pil_for_outlines(im)}) + parts.append({"type": "text", "text": prompt}) + return Chat([{"role": "user", "content": parts}]) + + def _outlines_output_term(self, response_format: Any) -> Any: + """ + Map metric ``response_format`` to an Outlines output type, or None for unconstrained decode. + + Returns + ------- + Any + A term accepted by :class:`outlines.generator.Generator`, or None. 
+ """ + from outlines.types import json_schema, regex + + if isinstance(response_format, str): + if response_format == "integer": + return regex(r"\d+") + if response_format == "yes_no": + return regex(r"(Yes|No)") + return None + if isinstance(response_format, type): + try: + if issubclass(response_format, BaseModel): + return json_schema(response_format) + except TypeError: + return None + return None + + def _generate_steered(self, chats: List[Any], output_term: Any, max_new_tokens: int) -> List[str]: + """Run Outlines :class:`~outlines.generator.Generator` on prepared chat inputs.""" + from outlines import Generator + + om = self._get_outlines_wrapped_model() + results: List[str] = [] + with torch.compiler.set_stance("force_eager"): + gen = Generator(om, output_type=output_term) + for chat in chats: + try: + out = gen(chat, max_new_tokens=max_new_tokens) + results.append(out if isinstance(out, str) else str(out)) + except Exception as e: + pruna_logger.warning(f"Outlines generation failed: {e}, using empty string") + results.append("") + return results + + def generate( + self, + images: List[Image.Image], + prompts: List[str], + response_format: Optional[Union[Type[BaseModel], Literal["integer"], Literal["yes_no"], Literal["json"]]] = None, + **kwargs: Any, + ) -> List[str]: + """ + Generate responses using local VLM. + + Parameters + ---------- + images : List[Image.Image] + List of PIL Images. + prompts : List[str] + List of text prompts. + response_format : Type[BaseModel] | str | None + When ``use_outlines`` is True: string ``integer`` / ``yes_no``, or a Pydantic model + class for JSON-schema constrained decoding; otherwise unconstrained ``model.generate``. + **kwargs : Any + Additional arguments passed to model generate. + + Returns + ------- + List[str] + Generated responses. + """ + self._load_model() + max_new_tokens = kwargs.get("max_new_tokens", 128) + term = self._outlines_output_term(response_format) if self.use_outlines else None + if term is not None: + chats = [self._chat_user_with_images([image], prompt) for image, prompt in zip(images, prompts)] + return self._generate_steered(chats, term, max_new_tokens) + return self._generate_standard(images, prompts, max_new_tokens) + + def generate_with_image_lists( + self, + image_lists: List[List[Image.Image]], + prompts: List[str], + response_format: Optional[Union[Type[BaseModel], Literal["integer"], Literal["yes_no"], Literal["json"]]] = None, + **kwargs: Any, + ) -> List[str]: + """ + Generate with multiple PIL images per prompt (e.g. VIEScore source + edited). + + Uses the chat template path with several ``image`` parts then text. When + ``use_outlines`` is True and ``response_format`` maps to an Outlines output type + (string ``integer`` / ``yes_no`` or a Pydantic model class), uses the same + Outlines 1.x steerable path as :meth:`generate` via ``outlines.inputs.Chat``. + Otherwise uses unconstrained ``model.generate``. + + Parameters + ---------- + image_lists : list[list[PIL.Image.Image]] + One list of images per prompt. + prompts : list[str] + Prompts aligned with ``image_lists``. + response_format : optional + Same conventions as :meth:`generate` for structured decoding when outlines is enabled. + **kwargs : Any + Passed through (e.g. ``max_new_tokens``). + + Returns + ------- + list[str] + Decoded strings per row. 
+ """ + if len(image_lists) != len(prompts): + raise ValueError("image_lists and prompts must have the same length.") + max_new_tokens = kwargs.get("max_new_tokens", 128) + self._load_model() + term = self._outlines_output_term(response_format) if self.use_outlines else None + if term is not None: + chats = [self._chat_user_with_images(imgs, prompt) for imgs, prompt in zip(image_lists, prompts)] + return self._generate_steered(chats, term, max_new_tokens) + results: List[str] = [] + with torch.inference_mode(): + for imgs, prompt in zip(image_lists, prompts): + inputs = self._prepare_inputs_multi(imgs, prompt) + input_len = inputs["input_ids"].shape[1] + output = self._model.generate(**inputs, max_new_tokens=max_new_tokens, **self.extra_kwargs) + response = self._decode_output(output[0][input_len:]) + results.append(response) + return results + + def _prepare_inputs(self, image: Image.Image, prompt: str) -> dict: + """Prepare model inputs, supporting both BLIP-style and chat-template processors.""" + try: + inputs = self._processor(images=[image], text=prompt, return_tensors="pt") + except (ValueError, TypeError): + conversation = [ + {"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": prompt}]} + ] + inputs = self._processor.apply_chat_template( + conversation, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt" + ) + return {k: v.to(self.device) for k, v in inputs.items()} + + def _prepare_inputs_multi(self, images: List[Image.Image], prompt: str) -> dict: + """Chat-template inputs with multiple images then text (VIEScore ``tie``-style).""" + parts: list[dict[str, Any]] = [] + for im in images: + parts.append({"type": "image", "image": im}) + parts.append({"type": "text", "text": prompt}) + conversation = [{"role": "user", "content": parts}] + inputs = self._processor.apply_chat_template( + conversation, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt" + ) + return {k: v.to(self.device) for k, v in inputs.items()} + + def _decode_output(self, output_ids: torch.Tensor) -> str: + """Decode model output to text.""" + if hasattr(self._processor, "batch_decode"): + return self._processor.batch_decode([output_ids], skip_special_tokens=True)[0] + return self._processor.decode(output_ids, skip_special_tokens=True) + + def _generate_standard( + self, + images: List[Image.Image], + prompts: List[str], + max_new_tokens: int, + ) -> List[str]: + """Standard generation without outlines.""" + results = [] + with torch.inference_mode(): + for image, prompt in zip(images, prompts): + inputs = self._prepare_inputs(image, prompt) + input_len = inputs["input_ids"].shape[1] + output = self._model.generate(**inputs, max_new_tokens=max_new_tokens, **self.extra_kwargs) + # Decode only the newly generated tokens to avoid re-including the prompt text. + response = self._decode_output(output[0][input_len:]) + results.append(response) + return results + + def _get_tokenizer(self) -> Any: + """Return the HF tokenizer used for yes/no prefix ids and decoding.""" + self._load_model() + proc = self._processor + tok = getattr(proc, "tokenizer", None) or getattr(proc, "text_tokenizer", None) + if tok is None: + raise ValueError( + "Transformers VLM probability scoring requires a tokenizer on the processor; " + "pass use_probability=False for binary scoring." 
+ ) + return tok + + def _score_yes_no_probability(self, image: Image.Image, question: str, answer: str) -> float: + """Soft VQAScore-style score from next-token softmax over yes/no prefix token ids.""" + from pruna.evaluation.metrics.vlm_utils import yes_no_first_token_id_groups + + self._load_model() + prompt = f"{question} Please answer yes or no." + inputs = self._prepare_inputs(image, prompt) + if self._yes_no_prefix_ids is None: + self._yes_no_prefix_ids = yes_no_first_token_id_groups(self._get_tokenizer()) + yes_ids, no_ids = self._yes_no_prefix_ids + if not yes_ids or not no_ids: + pruna_logger.warning("Empty yes/no prefix token ids; install a tokenizer with standard Yes/No encodings.") + return 0.0 + with torch.inference_mode(): + out = self._model(**inputs) + if not hasattr(out, "logits") or out.logits is None: + raise RuntimeError("Model forward did not return logits; cannot compute P(Yes).") + logits = out.logits[0, -1, :].float() + probs = torch.softmax(logits, dim=-1) + device = probs.device + p_yes = probs[torch.tensor(yes_ids, device=device, dtype=torch.long)].sum() + p_no = probs[torch.tensor(no_ids, device=device, dtype=torch.long)].sum() + denom = p_yes + p_no + ans = answer.strip().lower() + eps = 1e-12 + if float(denom.item()) < eps: + if ans == "yes": + return float(p_yes.clamp(0.0, 1.0).item()) + if ans == "no": + return float(p_no.clamp(0.0, 1.0).item()) + return 0.0 + if ans == "yes": + return float((p_yes / (denom + eps)).item()) + if ans == "no": + return float((p_no / (denom + eps)).item()) + return float((p_yes / (denom + eps)).item()) + + def score( + self, + images: List[Image.Image], + questions: List[str], + answers: List[str], + use_probability: bool = False, + response_format: Optional[Union[Type[BaseModel], Literal["integer"], Literal["yes_no"], Literal["json"]]] = None, + **kwargs: Any, + ) -> List[float]: + """ + Score how well answers match images for given questions. + + When ``use_probability`` is True, computes a VQAScore-style score from the next-token + distribution at the last context position (softmax mass on yes/no prefix token ids, + normalized over their union). Otherwise uses generation and binary substring matching. + + Parameters + ---------- + images : List[Image.Image] + List of PIL Images. + questions : List[str] + List of questions. + answers : List[str] + List of expected answers. + use_probability : bool, optional + If True, return a soft score from logits (no ``generate`` call). If False, binary. + response_format : Type[BaseModel] | str | None, optional + Structured output format for answer extraction (only when ``use_probability`` is False). + **kwargs : Any + Additional arguments passed to ``generate`` when ``use_probability`` is False. + + Returns + ------- + List[float] + Scores for each image-question pair in ``[0, 1]``. + """ + from pruna.evaluation.metrics.vlm_utils import get_answer_from_response + + scores = [] + for image, question, answer in zip(images, questions, answers): + if use_probability: + scores.append(self._score_yes_no_probability(image, question, answer)) + continue + prompt = f"{question} Please answer yes or no." 
+ responses = self.generate([image], [prompt], response_format=response_format, **kwargs) + raw = responses[0] if responses else "" + response_answer = get_answer_from_response(raw) if response_format is not None else raw.lower() + score = 1.0 if answer.lower() in response_answer.lower() else 0.0 + scores.append(score) + return scores + + +def auxiliary_dicts_from_gt(gt: Any, batch_size: int) -> list[dict[str, Any]]: + """ + Map batch ``gt`` to per-row auxiliary dicts when using ``prompt_with_auxiliaries_collate``. + + For ``y_x`` metrics, :func:`~pruna.evaluation.metrics.utils.metric_data_processor` does not + include ``gt`` in its output; pass the batch ``gt`` argument here so fields such as + ``source_image_bytes`` are visible to editing metrics. + + Parameters + ---------- + gt : Any + Second element of the dataloader batch: typically a ``list[dict]`` of aux columns. + batch_size : int + Number of samples in the batch. + + Returns + ------- + list[dict[str, Any]] + One dict per row; empty dicts when ``gt`` is not a list of dicts (e.g. tensor placeholders + in tests). + """ + if batch_size <= 0: + return [] + if isinstance(gt, (list, tuple)) and gt and isinstance(gt[0], dict): + out: list[dict[str, Any]] = [] + for i in range(batch_size): + row = gt[i] if i < len(gt) else {} + out.append(row if isinstance(row, dict) else {}) + return out + return [{} for _ in range(batch_size)] + + +def prompts_from_y_x_inputs(inputs: Any, batch_len: int) -> list[str]: + """ + Extract per-row prompts from :func:`~pruna.evaluation.metrics.utils.metric_data_processor` output. + + Parameters + ---------- + inputs : Any + Return value of ``metric_data_processor`` for ``y_x`` call types. + batch_len : int + Number of samples in the batch. + + Returns + ------- + list[str] + Prompt list from ``inputs[1]`` when present; otherwise ``batch_len`` empty strings. + """ + if len(inputs) > 1 and isinstance(inputs[1], list): + return inputs[1] + return [""] * batch_len + + +class StatefulVLMMeanScoresMetric(StatefulMetric): + """ + Base for VLM metrics that accumulate ``scores`` and report the batch mean in :meth:`compute`. + + Subclasses set ``default_call_type`` and ``metric_name``, then call :meth:`_init_vlm_scores` + from ``__init__`` after any metric-specific attributes (e.g. ``use_probability``). + + Parameters + ---------- + device : str | torch.device | None + Device forwarded to :class:`~pruna.evaluation.metrics.metric_stateful.StatefulMetric`. + **kwargs : Any + Additional keyword arguments forwarded to the parent class. + + Notes + ----- + User guide: :doc:`Evaluate a model ` (Vision-language judge metrics). + Registry metrics (``VQAMetric``, ``VieScoreMetric``, …) pass ``vlm_type`` and ``model_name`` + into :meth:`_init_vlm_scores`; see :func:`get_vlm`. + + For **auxiliary image bytes** (editing benchmarks), use + :func:`~pruna.evaluation.metrics.vlm_utils.pil_rgb_from_aux_image_bytes` and + :data:`~pruna.evaluation.metrics.vlm_utils.VLM_AUX_IMAGE_BYTES_KEY_ORDER`. 
+ """ + + scores: list[float] + default_call_type: str = "y_x" + higher_is_better: bool = True + metric_name: str = "" + + def _init_vlm_scores( + self, + *, + vlm: Optional[BaseVLM], + vlm_type: Literal["litellm", "transformers"], + model_name: Optional[str], + vlm_kwargs: Optional[dict[str, Any]], + structured_output: bool, + device: Optional[str | torch.device], + api_key: Optional[str], + call_type: str, + ) -> None: + """Attach ``self.vlm``, ``self.call_type``, and the ``scores`` state.""" + self.vlm = get_vlm( + vlm=vlm, + vlm_type=vlm_type, + model_name=model_name, + device=device, + api_key=api_key, + structured_output=structured_output, + **(vlm_kwargs or {}), + ) + self.call_type = get_call_type_for_single_metric(call_type, self.default_call_type) + self.add_state("scores", []) + self.higher_is_better = type(self).higher_is_better + + def compute_mean_of_scores(self) -> MetricResult: + """ + Return the mean of accumulated ``scores``, or ``0.0`` when empty. + + Returns + ------- + MetricResult + Aggregated result for this metric. + """ + if not self.scores: + return MetricResult(self.metric_name, self.__dict__, 0.0) + return MetricResult(self.metric_name, self.__dict__, float(np.mean(self.scores))) diff --git a/src/pruna/evaluation/metrics/vlm_utils.py b/src/pruna/evaluation/metrics/vlm_utils.py new file mode 100644 index 00000000..7e4f53e9 --- /dev/null +++ b/src/pruna/evaluation/metrics/vlm_utils.py @@ -0,0 +1,394 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared utilities and Pydantic models for VLM metrics.""" + +from __future__ import annotations + +import json +import math +import re +from io import BytesIO +from typing import Any, List, Sequence + +import torch +from PIL import Image +from pydantic import BaseModel, Field + +VLM_AUX_IMAGE_BYTES_KEY_ORDER: tuple[str, ...] = ( + "source_image_bytes", + "input_image_bytes", + "reference_image_bytes", + "image_bytes", +) + + +def _tensor_to_pil(tensor: torch.Tensor) -> Image.Image: + if tensor.ndim == 4: + tensor = tensor[0] + if tensor.max() > 1: + tensor = tensor / 255.0 + np_img = (tensor.cpu().numpy() * 255).astype("uint8") + return Image.fromarray(np_img.transpose(1, 2, 0)) + + +def _process_images(images: torch.Tensor) -> List[Any]: + return [_tensor_to_pil(img) if isinstance(img, torch.Tensor) else img for img in images] + + +def pil_rgb_from_aux_image_bytes( + aux: dict[str, Any], + *, + min_bytes_in_value_scan: int = 0, +) -> Image.Image | None: + """ + Decode a source / reference RGB image from auxiliary dict bytes. + + Tries :data:`VLM_AUX_IMAGE_BYTES_KEY_ORDER` first, then scans ``aux.values()`` for raw + byte blobs. The value scan skips blobs shorter than ``min_bytes_in_value_scan`` (use + ``100`` to match editing metrics that avoid tiny false positives). + + Parameters + ---------- + aux : dict[str, Any] + Per-sample auxiliary dict (e.g. from ``prompt_with_auxiliaries_collate``). 
+ min_bytes_in_value_scan : int, optional + Minimum length for blobs discovered only in the generic ``aux.values()`` pass. + Named keys use any non-empty ``bytes`` / ``bytearray``. Use ``0`` when any + non-trivial blob in ``aux.values()`` should be tried (e.g. tests building preds from aux). + + Returns + ------- + PIL.Image.Image | None + RGB image if decoding succeeds; ``None`` if nothing decodable was found. + """ + for key in VLM_AUX_IMAGE_BYTES_KEY_ORDER: + raw = aux.get(key) + if isinstance(raw, (bytes, bytearray)) and raw: + try: + return Image.open(BytesIO(raw)).convert("RGB") + except Exception: + continue + for v in aux.values(): + if isinstance(v, (bytes, bytearray)) and v and len(v) >= min_bytes_in_value_scan: + try: + return Image.open(BytesIO(v)).convert("RGB") + except Exception: + continue + return None + + +def yes_no_first_token_id_groups(tokenizer: Any) -> tuple[list[int], list[int]]: + """ + Collect first subword token ids that start a yes/no answer for next-token softmax scoring. + + Used by :class:`~pruna.evaluation.metrics.vlm_base.TransformersVLM` for VQAScore-style + P(Yes): sum softmax mass on these ids, normalized against yes+no for a stable [0, 1] score. + + Parameters + ---------- + tokenizer : Any + Hugging Face ``PreTrainedTokenizer`` (or compatible ``encode``). + + Returns + ------- + list[int] + Distinct token ids for yes-leaning first tokens (overlap with no-ids removed). + list[int] + Distinct token ids for no-leaning first tokens (overlap with yes-ids removed). + """ + yes_prefixes = ( + "Yes", + " Yes", + " yes", + "yes", + "\nYes", + "\n Yes", + "Yes,", + " Yes,", + ) + no_prefixes = ( + "No", + " No", + " no", + "no", + "\nNo", + "\n No", + "No,", + " No,", + ) + yes_ids: set[int] = set() + no_ids: set[int] = set() + for s in yes_prefixes: + ids = tokenizer.encode(s, add_special_tokens=False) + if ids: + yes_ids.add(ids[0]) + for s in no_prefixes: + ids = tokenizer.encode(s, add_special_tokens=False) + if ids: + no_ids.add(ids[0]) + overlap = yes_ids & no_ids + yes_only = sorted(yes_ids - overlap) + no_only = sorted(no_ids - overlap) + return yes_only, no_only + + +class VQAnswer(BaseModel): + """ + Structured output for VQA questions (Yes/No or open-ended). + + Parameters + ---------- + answer : str + Answer to the question. Typically "Yes" or "No" for alignment metrics, + but can be any string for open-ended questions. + """ + + answer: str = Field(description="Answer to the question") + + +class FloatOutput(BaseModel): + """ + Structured output for numeric scoring (img_edit_score, VieScoreMetric). + + Parameters + ---------- + score : float + Score from 0 to 10. + """ + + score: float = Field(ge=0, le=10, description="Score from 0 to 10") + + +class VIEScoreJsonOutput(BaseModel): + """ + Structured output matching VIEScore JSON (text-to-image / editing evaluation). + + Parameters + ---------- + score : list[float] + One or more sub-scores on a 0--10 scale (e.g. two criteria for editing). + reasoning : str + Short evaluator reasoning. 
+ """ + + score: list[float] = Field(description="Sub-scores on 0-10 scale") + reasoning: str = Field(default="", description="Brief reasoning") + + +def _json_dict_from_response_fragment(text: str) -> dict | None: + """Parse a leading JSON object from a string response, or return None.""" + stripped = (text or "").strip() + if not stripped.startswith("{"): + return None + try: + data = json.loads(stripped) + except (json.JSONDecodeError, TypeError): + return None + return data if isinstance(data, dict) else None + + +class TextOutput(BaseModel): + """ + Structured output for text extraction (text_score). + + Parameters + ---------- + text : str + Extracted text from the image, or 'No text recognized' if empty. + """ + + text: str = Field(description="Extracted text from the image, or 'No text recognized' if empty") + + +def get_answer_from_response(response: str | BaseModel | dict) -> str: + """ + Extract answer string from a VLM score() response (VQAnswer, dict, or raw string). + + Parameters + ---------- + response : str | BaseModel | dict + Raw response from vlm.generate() or vlm.score(). + + Returns + ------- + str + Extracted answer string, or empty string. + """ + if response is None: + return "" + if isinstance(response, VQAnswer): + return response.answer + if isinstance(response, dict): + return response.get("answer", "") + raw = str(response).strip() + parsed = _json_dict_from_response_fragment(raw) + if parsed is not None: + return str(parsed.get("answer", raw)) + return raw + + +def get_text_from_response(response: str | BaseModel | dict) -> str: + """ + Extract text from a VLM generate() response (str, pydantic, or dict). + + Parameters + ---------- + response : str | BaseModel | dict + Raw response from vlm.generate(). + + Returns + ------- + str + Extracted text, or empty string. + """ + if response is None: + return "" + if isinstance(response, TextOutput): + text = response.text + elif isinstance(response, dict): + text = response.get("text", "") + else: + text = (response or "").strip() + parsed = _json_dict_from_response_fragment(text) + if parsed is not None: + text = str(parsed.get("text", text)) + for phrase in ("No text recognized", "no text recognized", "No text"): + text = text.replace(phrase, "").strip() + return (text or "").strip() + + +def get_score_from_response(response: str | BaseModel | dict) -> float: + """ + Extract numeric score (0-10) from a VLM generate() response. + + Handles: + + * ``FloatOutput`` instances (local / parsed Pydantic). + * ``dict`` with a ``"score"`` key. + * JSON **strings** (e.g. LitellmVLM returns ``model_dump_json()`` for structured output). + * Plain text with a number (first decimal or integer matched). + + Parameters + ---------- + response : str | BaseModel | dict + Raw response from vlm.generate(). + + Returns + ------- + float + Score in [0, 1] (normalized from 0-10). Always non-negative. 
+ """ + if response is None: + return 0.0 + if isinstance(response, FloatOutput): + return max(0.0, min(float(response.score), 10.0)) / 10.0 + if isinstance(response, dict): + return max(0.0, min(float(response.get("score", 0)), 10.0)) / 10.0 + text = str(response or "").strip() + parsed = _json_dict_from_response_fragment(text) + if parsed is not None and "score" in parsed: + try: + return max(0.0, min(float(parsed["score"]), 10.0)) / 10.0 + except (TypeError, ValueError): + pass + match = re.search(r"\d+(?:\.\d+)?", text) + if match: + return min(float(match.group(0)), 10.0) / 10.0 + return 0.0 + + +def viescore_min_scores_0_10(response: str | BaseModel | dict) -> list[float]: + """ + Parse VIEScore-style JSON with a ``score`` list of values in ``[0, 10]``. + + Parameters + ---------- + response : str | BaseModel | dict + Model output (pydantic ``VIEScoreJsonOutput``, dict, or JSON string). + + Returns + ------- + list[float] + Sub-scores; empty if parsing fails. + """ + if response is None: + return [] + if isinstance(response, VIEScoreJsonOutput): + return [float(x) for x in response.score] + if isinstance(response, dict): + raw = response.get("score", []) + if isinstance(raw, (list, tuple)): + return [float(x) for x in raw] + return [] + text = str(response or "").strip() + parsed = _json_dict_from_response_fragment(text) + if parsed is not None and "score" in parsed: + try: + raw = parsed["score"] + if isinstance(raw, (list, tuple)): + return [float(x) for x in raw] + return [float(raw)] + except (TypeError, ValueError): + return [] + return [] + + +def pad_viescore_subscores_to_two(values: list[float]) -> list[float]: + """ + Pad or truncate VIEScore sub-score lists to length two for :func:`viescore_tie_overall_unit`. + + Parameters + ---------- + values : list[float] + Parsed sub-scores from :func:`viescore_min_scores_0_10`. + + Returns + ------- + list[float] + Exactly two values in ``[0, 10]``, padding with ``0.0`` when fewer than two are present. + """ + if len(values) >= 2: + return values[:2] + if not values: + return [0.0, 0.0] + return values + [0.0] * (2 - len(values)) + + +def viescore_tie_overall_unit(sc_scores: Sequence[float], pq_scores: Sequence[float]) -> float: + """ + Overall VIEScore for text-image editing (``tie`` task): ``sqrt(min(SC)*min(PQ))/10`` in ``[0, 1]``. + + Matches the reference ``math.sqrt(SC_score * PQ_score)`` on a 0--10 scale with + ``SC_score = min(...)``, ``PQ_score = min(...)`` (`VIEScore`_). + + .. _VIEScore: https://github.com/TIGER-AI-Lab/VIEScore + + Parameters + ---------- + sc_scores : Sequence[float] + Semantic / instruction sub-scores on 0--10 (e.g. editing success and over-editing). + pq_scores : Sequence[float] + Perceptual sub-scores on 0--10 (e.g. naturalness and artifacts). + + Returns + ------- + float + Overall score in ``[0, 1]`` (higher is better). 
+ """ + if not sc_scores or not pq_scores: + return 0.0 + sc = min(float(x) for x in sc_scores) + pq = min(float(x) for x in pq_scores) + return math.sqrt(sc * pq) / 10.0 diff --git a/tests/evaluation/test_vlm_base_infrastructure.py b/tests/evaluation/test_vlm_base_infrastructure.py new file mode 100644 index 00000000..a4eaa139 --- /dev/null +++ b/tests/evaluation/test_vlm_base_infrastructure.py @@ -0,0 +1,684 @@ +"""Tests for VLM metrics (VQA, ImageEditScore, QAAccuracy, TextScore, VieScore) and vlm_utils helpers.""" + +from unittest.mock import MagicMock, patch + +import pytest +import torch + +from pruna.evaluation.metrics.metric_img_edit_score import ImageEditScoreMetric +from pruna.evaluation.metrics.metric_oneig_alignment import OneIGAlignmentMetric +from pruna.evaluation.metrics.metric_qa_accuracy import QAAccuracyMetric +from pruna.evaluation.metrics.metric_text_score import OneIGTextScoreMetric, TextScoreMetric +from pruna.evaluation.metrics.metric_vie_score import VieScoreMetric +from pruna.evaluation.metrics.metric_vqa import VQAMetric +from pruna.evaluation.metrics.result import MetricResult +from pruna.evaluation.metrics.vlm_base import BaseVLM, get_vlm +from pruna.evaluation.metrics.vlm_utils import ( + FloatOutput, + VLM_AUX_IMAGE_BYTES_KEY_ORDER, + get_score_from_response, + yes_no_first_token_id_groups, +) + +from ._vlm_batch_snapshot_helpers import ( + BenchmarkVlmBatchOutcome, + pred_tensor_from_auxiliaries, + safe_json_for_snapshot, + vlm_benchmark_batch_to_json_record, +) + +SMOL_VLM = "HuggingFaceTB/SmolVLM-256M-Instruct" + +_ALL_VLM = ( + VQAMetric, + ImageEditScoreMetric, + QAAccuracyMetric, + OneIGAlignmentMetric, + TextScoreMetric, + OneIGTextScoreMetric, + VieScoreMetric, +) + +_SLOW_SMOL_SUBSET = ( + VQAMetric, + OneIGAlignmentMetric, + ImageEditScoreMetric, + VieScoreMetric, +) + + +@pytest.mark.parametrize( + ("raw", "expected"), + [ + (FloatOutput(score=8.0), 0.8), + ({"score": 5.0}, 0.5), + ('{"score": 7.5}', 0.75), + ('{"score": 10}', 1.0), + ("8", 0.8), + ("Score: 7.5 out of 10", 0.75), + ("", 0.0), + ], +) +def test_get_score_from_response(raw: object, expected: float) -> None: + """``get_score_from_response`` maps pydantic, dict, JSON, and text to ``[0, 1]``.""" + assert get_score_from_response(raw) == pytest.approx(expected) + + +def _dummy_image(batch: int = 1, size: int = 224) -> torch.Tensor: + return torch.rand(batch, 3, size, size) + + +def _update_metric(metric: object, prompts: list, images: torch.Tensor) -> None: + if isinstance(metric, OneIGAlignmentMetric): + metric.update( + prompts, + [ + { + "questions": {"1": "Is there a cat?", "2": "Is it sleeping?"}, + "dependencies": {"1": [0], "2": [1]}, + } + ], + images, + ) + elif isinstance(metric, QAAccuracyMetric): + metric.update( + prompts, + [{"questions": {"1": "Is there a cat?"}}], + images, + ) + elif isinstance(metric, (TextScoreMetric, OneIGTextScoreMetric)): + metric.update(prompts, ["cat"], images) + else: + metric.update(prompts, images, images) + + +@pytest.mark.cpu +@pytest.mark.slow +@pytest.mark.parametrize("metric_cls", _SLOW_SMOL_SUBSET) +def test_vlm_metrics_transformers_smolvlm(metric_cls: type) -> None: + """Smoke-test a subset with local SmolVLM (full matrix covered by litellm mock).""" + metric = metric_cls( + vlm_type="transformers", + model_name=SMOL_VLM, + device="cpu", + structured_output=True, + ) + images = _dummy_image(batch=1) + prompts = ["a cat"] + _update_metric(metric, prompts, images) + result = metric.compute() + assert result.name == metric.metric_name + 
assert isinstance(result.result, float) + if metric.higher_is_better: + assert 0.0 <= result.result <= 1.0 + else: + assert result.result >= 0.0 + + +@pytest.mark.cpu +@pytest.mark.parametrize("metric_cls", _ALL_VLM) +def test_vlm_metrics_litellm_mocked(metric_cls: type) -> None: + """Each VLM metric runs end-to-end with mocked litellm.""" + pytest.importorskip("litellm") + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + if metric_cls in (VQAMetric, QAAccuracyMetric, OneIGAlignmentMetric): + mock_response.choices[0].message.content = '{"answer": "Yes"}' + else: + mock_response.choices[0].message.content = '{"score": 8}' + + with patch("litellm.completion") as mock_completion: + mock_completion.return_value = mock_response + + metric = metric_cls( + vlm_type="litellm", + model_name="gpt-4o", + device="cpu", + structured_output=True, + ) + images = _dummy_image(batch=1) + prompts = ["a cat"] + _update_metric(metric, prompts, images) + result = metric.compute() + + assert result.name == metric.metric_name + assert isinstance(result.result, float) + assert mock_completion.called + + +@pytest.mark.cpu +def test_vlm_metrics_empty_compute_returns_zero() -> None: + """No updates → compute is 0.0 (same for all stateful VLM metrics).""" + metric = VQAMetric( + vlm_type="transformers", + model_name=SMOL_VLM, + device="cpu", + structured_output=True, + ) + assert metric.compute().result == 0.0 + + +@pytest.mark.cpu +def test_vlm_metrics_custom_vlm() -> None: + """Custom VLM passed to VQAMetric is used instead of the default litellm backend.""" + mock_vlm = MagicMock(spec=BaseVLM) + mock_vlm.generate.return_value = ["Yes"] + mock_vlm.score.return_value = [1.0] + + metric = VQAMetric(vlm=mock_vlm, vlm_type="litellm", device="cpu", structured_output=True) + images = _dummy_image(batch=1) + prompts = ["a cat"] + metric.update(prompts, images, images) + assert metric.compute().result == 1.0 + mock_vlm.score.assert_called() + + +@pytest.mark.cpu +def test_get_vlm_returns_custom() -> None: + """get_vlm returns the provided VLM instance unchanged.""" + custom = MagicMock(spec=BaseVLM) + out = get_vlm(vlm=custom, vlm_type="litellm", model_name="gpt-4o") + assert out is custom + + +@pytest.mark.cpu +def test_yes_no_first_token_id_groups_disjoint() -> None: + """Prefix token ids for Yes vs No should not overlap (avoids double-counting).""" + pytest.importorskip("transformers") + from transformers import AutoTokenizer + + tok = AutoTokenizer.from_pretrained("gpt2") + yes_ids, no_ids = yes_no_first_token_id_groups(tok) + assert yes_ids and no_ids + assert not (set(yes_ids) & set(no_ids)) + + +@pytest.mark.cpu +def test_get_vlm_requires_model_name_without_vlm() -> None: + """get_vlm raises ValueError when no model_name is given and no vlm is provided.""" + with pytest.raises(ValueError, match="model_name"): + get_vlm(vlm=None, vlm_type="litellm") + + +@pytest.mark.cpu +@pytest.mark.parametrize( + "metric_cls, expected_name, expected_result", + [ + (TextScoreMetric, "text_score", 1.0), + (OneIGTextScoreMetric, "oneig_text_score", 1.0), + ], +) +def test_text_metrics_list_str_gt(metric_cls: type, expected_name: str, expected_result: float) -> None: + """Text metrics accept plain string ground-truth and return the expected score.""" + mock_vlm = MagicMock(spec=BaseVLM) + mock_vlm.generate.return_value = ["hello world"] + + metric = metric_cls(vlm=mock_vlm, vlm_type="litellm", device="cpu") + images = _dummy_image(batch=1) + metric.update(["a prompt"], ["hello world"], images) + result = 
metric.compute() + + assert result.result == expected_result + assert result.name == expected_name + mock_vlm.generate.assert_called_once() + + +@pytest.mark.cpu +def test_text_score_result_in_zero_one_range() -> None: + """TextScoreMetric must return a normalized score in [0, 1], not raw edit distance.""" + mock_vlm = MagicMock(spec=BaseVLM) + # VLM OCR returns something very different from ground truth (high edit distance) + mock_vlm.generate.return_value = ["completely wrong text abcdefghijklmnop"] + + metric = TextScoreMetric(vlm=mock_vlm, device="cpu") + images = _dummy_image(batch=1) + metric.update(["prompt"], ["hello"], images) + result = metric.compute() + + assert 0.0 <= result.result <= 1.0, f"TextScoreMetric must return [0,1], got {result.result}" + assert result.result < 0.5, f"Very different strings should score below 0.5, got {result.result}" + + +@pytest.mark.cpu +def test_text_score_perfect_match_is_one() -> None: + """TextScoreMetric: identical OCR and GT -> score 1.0.""" + mock_vlm = MagicMock(spec=BaseVLM) + mock_vlm.generate.return_value = ["hello world"] + + metric = TextScoreMetric(vlm=mock_vlm, device="cpu") + images = _dummy_image(batch=1) + metric.update(["prompt"], ["hello world"], images) + result = metric.compute() + + assert result.result == 1.0, f"Perfect match should give 1.0, got {result.result}" + assert result.higher_is_better is True + + +@pytest.mark.cpu +def test_text_score_registry_aliases() -> None: + """Registry aliases ocr_levenshtein and ocr_text_score resolve to the correct metric classes.""" + from pruna.evaluation.metrics.registry import MetricRegistry + + lev = MetricRegistry.get_metric("ocr_levenshtein", device="cpu", model_name="openai/gpt-4o") + comp = MetricRegistry.get_metric("ocr_text_score", device="cpu", model_name="openai/gpt-4o") + assert type(lev).__name__ == "TextScoreMetric" + assert type(comp).__name__ == "OneIGTextScoreMetric" + assert lev.metric_name == "text_score" + assert comp.metric_name == "oneig_text_score" + + +@pytest.mark.cpu +def test_oneig_text_score_utils_golden_composite() -> None: + """oneig_mean_text_score returns expected component values for a known input.""" + from pruna.evaluation.metrics.metric_text_score_utils import oneig_mean_text_score + + ed, cr, wac, composite = oneig_mean_text_score( + edit_distances=[10.0], + completion_ratios=[0.0], + match_counts=[2], + gt_totals=[4], + language_mode="EN", + ) + assert ed == 10.0 + assert cr == 0.0 + assert wac == 0.5 + assert composite == pytest.approx(0.95) + + _, _, _, zh = oneig_mean_text_score( + edit_distances=[30.0], + completion_ratios=[0.0], + match_counts=[0], + gt_totals=[1], + language_mode="ZH", + ) + assert zh == pytest.approx(0.4) + + +@pytest.mark.cpu +def test_qa_accuracy_all_or_nothing_partial_fail() -> None: + """all_or_nothing: if any question scores 0, the image score is 0.0 (not a partial mean).""" + mock_vlm = MagicMock(spec=BaseVLM) + # First question Yes (1.0), second question No (0.0) → mean=0.5, all_or_nothing=0.0 + mock_vlm.score.return_value = [1.0, 0.0] + + metric = QAAccuracyMetric(vlm=mock_vlm, device="cpu", aggregation="all_or_nothing") + metric.update( + ["a prompt"], + [{"questions": {"1": "Is there a cat?", "2": "Is it blue?"}}], + _dummy_image(batch=1), + ) + result = metric.compute() + assert result.result == 0.0, f"Expected 0.0 for all_or_nothing with one No, got {result.result}" + + +@pytest.mark.cpu +def test_qa_accuracy_all_or_nothing_all_yes() -> None: + """all_or_nothing: all Yes → score 1.0.""" + mock_vlm = 
MagicMock(spec=BaseVLM) + mock_vlm.score.return_value = [1.0, 1.0] + + metric = QAAccuracyMetric(vlm=mock_vlm, device="cpu", aggregation="all_or_nothing") + metric.update( + ["a prompt"], + [{"questions": {"1": "Is there a cat?", "2": "Is it blue?"}}], + _dummy_image(batch=1), + ) + result = metric.compute() + assert result.result == 1.0, f"Expected 1.0 for all_or_nothing with all Yes, got {result.result}" + + +@pytest.mark.cpu +def test_qa_accuracy_invalid_aggregation_raises() -> None: + """qa_accuracy rejects aggregation values other than mean / all_or_nothing.""" + mock_vlm = MagicMock(spec=BaseVLM) + with pytest.raises(ValueError, match="aggregation"): + QAAccuracyMetric(vlm=mock_vlm, device="cpu", aggregation="median") + + +@pytest.mark.cpu +def test_vie_score_tie_uses_source_from_gt_and_two_image_sc() -> None: + """With ``source_image_bytes`` in gt, VieScore calls two-image SC then PQ on the edited image.""" + from io import BytesIO + + from PIL import Image + + buf = BytesIO() + Image.new("RGB", (8, 8), color=(0, 0, 200)).save(buf, format="PNG") + src_bytes = buf.getvalue() + + mock_vlm = MagicMock() + mock_vlm.generate_with_image_lists.return_value = ['{"score": [8.0, 8.0], "reasoning": "ok"}'] + mock_vlm.generate.return_value = ['{"score": [9.0, 9.0], "reasoning": "ok"}'] + + metric = VieScoreMetric(vlm=mock_vlm, device="cpu", structured_output=True) + pred = _dummy_image(batch=1) + metric.update( + ["make the sky purple"], + [{"source_image_bytes": src_bytes}], + pred, + ) + result = metric.compute() + + assert mock_vlm.generate_with_image_lists.called + assert mock_vlm.generate.called + assert 0.0 <= result.result <= 1.0 + + +@pytest.mark.cpu +def test_vie_score_uses_get_score_from_response() -> None: + """VieScoreMetric ``t2i`` path parses JSON ``score`` lists via ``viescore_min_scores_0_10``.""" + mock_vlm = MagicMock(spec=BaseVLM) + # LitellmVLM returns model_dump_json() for structured outputs → JSON string (two SC + two PQ sub-scores) + mock_vlm.generate.return_value = ['{"score": [8.0, 8.0], "reasoning": ""}'] + + metric = VieScoreMetric(vlm=mock_vlm, device="cpu", structured_output=True) + metric.update(["a cat on a sofa"], _dummy_image(batch=1), _dummy_image(batch=1)) + result = metric.compute() + + # min(SC)=8, min(PQ)=8 → sqrt(8 * 8) / 10 = 0.8 + assert abs(result.result - 0.8) < 0.01, f"Expected ~0.8, got {result.result}" + + +@pytest.mark.cpu +def test_img_edit_score_negative_response_clamped() -> None: + """img_edit_score must be non-negative even when the VLM generates a negative JSON score. + + Regression for: Outlines constrained decoding can emit {"score": -10} despite the + FloatOutput JSON schema specifying minimum=0, because Outlines does not enforce numeric + bounds during token sampling. The fix is max(0.0, ...) in get_score_from_response. 
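+
+    Worked example of the clamp (illustrative arithmetic mirroring get_score_from_response):
+    {"score": -10} -> max(0.0, min(-10.0, 10.0)) / 10.0 = 0.0, instead of the unclamped -1.0.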
+ """ + mock_vlm = MagicMock(spec=BaseVLM) + # Simulate Outlines generating a negative value (the bug scenario) + mock_vlm.generate.return_value = ['{"score": -10.0}'] + + metric = ImageEditScoreMetric(vlm=mock_vlm, device="cpu", structured_output=True) + metric.update(["replace the boot with a mug"], torch.zeros(1), _dummy_image(batch=1)) + result = metric.compute() + + assert result.result >= 0.0, f"img_edit_score must be >= 0, got {result.result}" + + +@pytest.mark.cpu +def test_qa_accuracy_all_or_nothing_ambiguous_score() -> None: + """all_or_nothing: score exactly 0.5 (ambiguous) is treated as No → result 0.0.""" + mock_vlm = MagicMock(spec=BaseVLM) + mock_vlm.score.return_value = [0.5] + + metric = QAAccuracyMetric(vlm=mock_vlm, device="cpu", aggregation="all_or_nothing") + metric.update( + ["a prompt"], + [{"questions": {"1": "Is there a cat?"}}], + _dummy_image(batch=1), + ) + result = metric.compute() + assert result.result == 0.0, f"Score 0.5 should be treated as No (ambiguous), got {result.result}" + + +@pytest.mark.cpu +@pytest.mark.slow +def test_yes_no_token_ids_smolvlm_nonempty() -> None: + """SmolVLM tokenizer must yield non-empty disjoint yes/no prefix ids for VQAScore scoring.""" + pytest.importorskip("transformers") + from transformers import AutoTokenizer + + tok = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolVLM-256M-Instruct") + yes_ids, no_ids = yes_no_first_token_id_groups(tok) + assert len(yes_ids) > 0, "SmolVLM tokenizer has no 'Yes'-prefix token ids" + assert len(no_ids) > 0, "SmolVLM tokenizer has no 'No'-prefix token ids" + assert not (set(yes_ids) & set(no_ids)), "yes_ids and no_ids must be disjoint" + + +@pytest.mark.cpu +def test_img_edit_score_uses_prompt_from_x() -> None: + """img_edit_score must score the edited image against the instruction from x, not gt.""" + mock_vlm = MagicMock(spec=BaseVLM) + mock_vlm.generate.return_value = ['{"score": 9}'] + + metric = ImageEditScoreMetric(vlm=mock_vlm, device="cpu") + pred = _dummy_image(batch=1) + metric.update( + ["replace the cat with a dog"], # x = instruction + pred, # gt = unused for y_x + pred, # outputs = edited image + ) + result = metric.compute() + + call_args = mock_vlm.generate.call_args + prompt_sent = call_args[0][1][0] # second positional arg = prompts list, first item + assert "replace the cat with a dog" in prompt_sent, f"Instruction not in VLM prompt. Got: {prompt_sent}" + assert abs(result.result - 0.9) < 0.01, f"Expected ~0.9, got {result.result}" + + +@pytest.mark.cpu +def test_vie_score_geditbench_gap_documented() -> None: + """VieScoreMetric infers text--image editing from ``source_image_bytes`` in aux (no ``task_type``). + + This test fails if a ``task_type`` parameter is added to ``__init__`` without updating + GEditBench integration tests and benchmark copy accordingly. + """ + import inspect + + sig = inspect.signature(VieScoreMetric.__init__) + assert "task_type" not in sig.parameters, ( + "VieScoreMetric now has task_type — update GEditBench docs and e2e tests, then remove this sentinel." 
+ ) + + +@pytest.mark.cpu +def test_litellm_logprob_aggregation_sums_all_yes_tokens() -> None: + """LitellmVLM logprob scoring must sum all yes-prefix token probs, not return the first.""" + pytest.importorskip("litellm") + import math + from unittest.mock import MagicMock, patch + + import numpy as np + from PIL import Image + + from pruna.evaluation.metrics.vlm_base import LitellmVLM + + # Simulate top_logprobs for first output token: + # "Yes" → logprob=-2.303 (p≈0.10), " yes" → logprob=-2.996 (p≈0.05) → total p_yes≈0.15 + # "No" → logprob=-1.609 (p≈0.20), " no" → logprob=-2.303 (p≈0.10) → total p_no≈0.30 + # normalized: p_yes/(p_yes+p_no) ≈ 0.15/0.45 ≈ 0.333 + def make_top_logprob(token, logprob): + t = MagicMock() + t.token = token + t.logprob = logprob + return t + + first_tok = MagicMock() + first_tok.top_logprobs = [ + make_top_logprob("Yes", math.log(0.10)), + make_top_logprob(" yes", math.log(0.05)), + make_top_logprob("No", math.log(0.20)), + make_top_logprob(" no", math.log(0.10)), + make_top_logprob("maybe", math.log(0.55)), + ] + + mock_logprobs = MagicMock() + mock_logprobs.content = [first_tok] + + mock_choice = MagicMock() + mock_choice.logprobs = mock_logprobs + mock_choice.message.content = "Yes" + + mock_response = MagicMock() + mock_response.choices = [mock_choice] + + with patch("litellm.completion", return_value=mock_response): + vlm = LitellmVLM(model_name="openai/gpt-4o") + img = Image.fromarray(np.zeros((32, 32, 3), dtype="uint8")) + score = vlm._score_with_logprobs(img, "Is there a cat?", "Yes") + + # Should be ~0.333 (p_yes=0.15 / (p_yes+p_no)=0.45), not just 0.10 (first match) + assert 0.28 < score < 0.40, f"Expected ~0.333 (sum-normalized), got {score}" + + +@pytest.mark.cpu +@pytest.mark.slow +def test_vqa_probability_score_normalized() -> None: + """P(Yes) from TransformersVLM.score use_probability=True is in [0, 1].""" + pytest.importorskip("transformers") + import numpy as np + from PIL import Image + + from pruna.evaluation.metrics.vlm_base import TransformersVLM + + vlm = TransformersVLM( + model_name="HuggingFaceTB/SmolVLM-256M-Instruct", + device="cpu", + use_outlines=False, + ) + img = Image.fromarray(np.zeros((32, 32, 3), dtype="uint8")) + scores = vlm.score([img], ["Is there a cat?"], ["Yes"], use_probability=True) + assert len(scores) == 1 + assert 0.0 <= scores[0] <= 1.0, f"P(Yes) must be in [0, 1], got {scores[0]}" + + +# --------------------------------------------------------------------------- +# vlm_benchmark_batch_to_json_record serialization tests +# --------------------------------------------------------------------------- + + +def test_vlm_benchmark_batch_to_json_record_serializes_batch() -> None: + """Record includes prompts, pred shape, and metric fields.""" + mr = MetricResult(name="qa_accuracy", params={}, result=0.25, higher_is_better=True) + outcome = BenchmarkVlmBatchOutcome( + result=mr, + prompts=["prompt"], + auxiliaries=[{"path": "/tmp/x.png"}], + pred=torch.zeros(1, 3, 8, 8), + ) + rec = vlm_benchmark_batch_to_json_record( + outcome, + benchmark_key="GenEval", + benchmark_name="GenEval", + metric_name="qa_accuracy", + vlm_type="transformers", + model_name="m", + device="cpu", + ) + assert rec["inputs"]["prompts"] == ["prompt"] + assert rec["pred"]["shape"] == [1, 3, 8, 8] + assert rec["metric_result"]["result"] == 0.25 + + +def test_safe_json_handles_bytes_without_expanding() -> None: + """Bytes values in aux (e.g. 
source_image_bytes) are summarized, not expanded to str repr.""" + result = safe_json_for_snapshot({"source_image_bytes": b"\xff\xd8\xff" * 1000, "name": "test"}) + assert result["source_image_bytes"] == {"bytes_len": 3000} + assert result["name"] == "test" + + +def test_vlm_benchmark_batch_to_json_record_preserves_null_question_slots() -> None: + """Padded ``None`` question labels stay JSON null, not the string ``"None"``.""" + mr = MetricResult(name="oneig_alignment", params={}, result=1.0, higher_is_better=True) + outcome = BenchmarkVlmBatchOutcome( + result=mr, + prompts=["p"], + auxiliaries=[{"questions": {"1": "Are there boys?", "21": None}, "subset": "Anime_Stylization"}], + pred=torch.zeros(1, 3, 8, 8), + ) + rec = vlm_benchmark_batch_to_json_record( + outcome, + benchmark_key="OneIGAnimeStylization", + benchmark_name="OneIG Anime Stylization", + metric_name="oneig_alignment", + vlm_type="transformers", + model_name="m", + device="cpu", + ) + qs = rec["inputs"]["auxiliary_0"]["questions"] + assert qs["1"] == "Are there boys?" + assert qs["21"] is None + + +# --------------------------------------------------------------------------- +# pred_tensor_from_auxiliaries (test helper, wraps pil_rgb_from_aux_image_bytes) tests +# --------------------------------------------------------------------------- + + +def _make_jpeg_bytes(h: int = 32, w: int = 32) -> bytes: + """Return a tiny JPEG-encoded RGB image as bytes (test helper).""" + import io + + import numpy as np + from PIL import Image + + arr = (np.random.rand(h, w, 3) * 255).astype("uint8") + buf = io.BytesIO() + Image.fromarray(arr).save(buf, format="JPEG") + return buf.getvalue() + + +@pytest.mark.cpu +def test_pred_from_auxiliaries_uses_source_image_bytes() -> None: + """pred_tensor_from_auxiliaries decodes source_image_bytes into a float tensor in [0, 1].""" + src_bytes = _make_jpeg_bytes() + aux = [{"source_image_bytes": src_bytes, "category": "background_change"}] + pred = pred_tensor_from_auxiliaries(aux, size=64) + + assert pred.shape == (1, 3, 64, 64), f"Expected (1,3,64,64), got {pred.shape}" + assert pred.min() >= 0.0 and pred.max() <= 1.0, "Pixel values must be in [0, 1]" + + +@pytest.mark.cpu +def test_pred_from_auxiliaries_falls_back_to_noise_without_source_image() -> None: + """pred_tensor_from_auxiliaries returns random noise when no source_image_bytes is present.""" + aux = [{"category": "single_object"}] + pred = pred_tensor_from_auxiliaries(aux, size=32) + assert pred.shape == (1, 3, 32, 32) + assert pred.min() >= 0.0 and pred.max() <= 1.0 + + +@pytest.mark.cpu +def test_pred_from_auxiliaries_mixed_batch() -> None: + """Batch with one source image and one missing falls back per-item.""" + src_bytes = _make_jpeg_bytes() + aux = [ + {"source_image_bytes": src_bytes, "category": "color_alter"}, + {"category": "style_change"}, # no source image + ] + pred = pred_tensor_from_auxiliaries(aux, size=32) + assert pred.shape == (2, 3, 32, 32) + assert pred.min() >= 0.0 and pred.max() <= 1.0 + + +@pytest.mark.cpu +def test_pred_from_auxiliaries_generic_bytes_scan() -> None: + """pred_tensor_from_auxiliaries discovers image bytes under an unknown field name (generic scan).""" + src_bytes = _make_jpeg_bytes() + aux = [{"my_custom_image_bytes": src_bytes, "category": "motion_change"}] + pred = pred_tensor_from_auxiliaries(aux, size=32) + assert pred.shape == (1, 3, 32, 32) + assert pred.min() >= 0.0 and pred.max() <= 1.0 + + +@pytest.mark.cpu +def test_pred_from_auxiliaries_known_names_take_priority() -> None: + """Known field 
names are resolved before the generic bytes scan.""" + src_bytes_known = _make_jpeg_bytes(16, 16) + src_bytes_unknown = _make_jpeg_bytes(32, 32) + first_known = VLM_AUX_IMAGE_BYTES_KEY_ORDER[0] + aux = [{"other_bytes": src_bytes_unknown, first_known: src_bytes_known}] + pred = pred_tensor_from_auxiliaries(aux, size=16) + # Should use the known key (16x16 image → 16x16 crop); generic scan would pick 32x32 + assert pred.shape == (1, 3, 16, 16) + + +@pytest.mark.cpu +def test_pred_from_auxiliaries_require_source_image_raises_when_missing() -> None: + """require_source_image=True raises ValueError instead of silently returning noise.""" + aux = [{"category": "replace"}] # no image bytes + with pytest.raises(ValueError, match="require_source_image=True"): + pred_tensor_from_auxiliaries(aux, size=32, require_source_image=True) + + +@pytest.mark.cpu +def test_pred_from_auxiliaries_require_source_image_succeeds_when_present() -> None: + """require_source_image=True succeeds and decodes bytes when source_image_bytes is present.""" + src_bytes = _make_jpeg_bytes() + aux = [{"source_image_bytes": src_bytes, "category": "replace"}] + pred = pred_tensor_from_auxiliaries(aux, size=32, require_source_image=True) + assert pred.shape == (1, 3, 32, 32) + assert pred.min() >= 0.0 and pred.max() <= 1.0
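+
+
+@pytest.mark.cpu
+def test_viescore_pure_helpers_sketch() -> None:
+    """Illustrative sketch of the pure VIEScore helpers on hand-picked values.
+
+    Assumes ``viescore_min_scores_0_10``, ``pad_viescore_subscores_to_two`` and
+    ``viescore_tie_overall_unit`` are importable from ``pruna.evaluation.metrics.vlm_utils``
+    alongside ``get_score_from_response``; adjust the import if they live elsewhere.
+    """
+    from pruna.evaluation.metrics.vlm_utils import (
+        pad_viescore_subscores_to_two,
+        viescore_min_scores_0_10,
+        viescore_tie_overall_unit,
+    )
+
+    # A JSON string with two sub-scores parses to a plain float list
+    sc = pad_viescore_subscores_to_two(viescore_min_scores_0_10('{"score": [8.0, 9.0]}'))
+    assert sc == [8.0, 9.0]
+
+    # A single sub-score is padded with 0.0, which then zeroes the overall min/sqrt score
+    pq = pad_viescore_subscores_to_two(viescore_min_scores_0_10('{"score": [7.0]}'))
+    assert pq == [7.0, 0.0]
+    assert viescore_tie_overall_unit(sc, pq) == 0.0
+
+    # Full sub-scores: sqrt(min(SC) * min(PQ)) / 10 = sqrt(8 * 7) / 10 ≈ 0.748
+    assert viescore_tie_overall_unit([8.0, 9.0], [7.0, 10.0]) == pytest.approx(0.7483, abs=1e-3)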