diff --git a/docs/user_manual/configure.rst b/docs/user_manual/configure.rst index 4bfb8a67..f1cbf9cd 100644 --- a/docs/user_manual/configure.rst +++ b/docs/user_manual/configure.rst @@ -253,7 +253,7 @@ Underneath you can find the list of all the available datasets. - ``text: str`` * - Image Generation - `LAION256 `_, `OpenImage `_, `COCO `_, `DrawBench `_, `PartiPrompts `_, `GenAIBench `_ - - ``image_generation_collate``, ``prompt_collate`` + - ``image_generation_collate``, ``prompt_with_auxiliaries_collate`` - ``text: str``, ``image: Optional[PIL.Image.Image]`` * - Image Classification - `ImageNet `_, `MNIST `_, `CIFAR10 `_ diff --git a/docs/user_manual/evaluate.rst b/docs/user_manual/evaluate.rst index 35f5ef0a..eb0508a0 100644 --- a/docs/user_manual/evaluate.rst +++ b/docs/user_manual/evaluate.rst @@ -100,6 +100,48 @@ Evaluation Components The |pruna| package provides a variety of evaluation metrics to assess your models. In this section, we'll introduce the evaluation metrics you can use. +.. _vlm_judge_metrics: + +Vision-language judge metrics +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Some **quality** metrics (for example ``vqa``, ``qa_accuracy``, ``alignment_score``, OCR-based text scores, ``vie_score``) use a **vision-language model** as a judge. By default they call a hosted API via ``litellm`` (``vlm_type="litellm"``); you can load a local Hugging Face model with ``vlm_type="transformers"``. When you pass string metric names to ``Task``, judge metrics default to the hosted ``openai/gpt-4o`` route; construct the metric explicitly to choose a different judge. + +**API keys (hosted judges).** Pruna passes ``api_key`` to LiteLLM with this precedence: the metric's ``api_key`` argument (if set), then ``LITELLM_API_KEY``, then ``OPENAI_API_KEY``. This matches common usage for OpenAI routes: LiteLLM documents ``OPENAI_API_KEY`` for OpenAI, and ``LITELLM_API_KEY`` is an extra environment variable Pruna checks so you can supply a key without setting ``OPENAI_API_KEY``. If all three are unset, LiteLLM can still pick up provider-specific variables (for example ``ANTHROPIC_API_KEY``), as described in LiteLLM's "Setting API Keys" and provider docs. If you use a non-OpenAI route while ``OPENAI_API_KEY`` is set for other tools, pass ``api_key`` explicitly so Pruna does not forward the wrong key; see the example at the end of this section. Credentials for Replicate or other image-only backends are separate and are not used by these metrics. + +**Hosted ``litellm``.** Set ``OPENAI_API_KEY`` or ``LITELLM_API_KEY`` (or pass ``api_key=...``) and use a vision-capable LiteLLM route as ``model_name``: + +.. code-block:: python + + from pruna.evaluation.metrics import VQAMetric + + hosted = VQAMetric(vlm_type="litellm", model_name="openai/gpt-4o") + +The same pattern works with ``get_vlm`` in ``pruna.evaluation.metrics.vlm_base``: + +.. code-block:: python + + from pruna.evaluation.metrics.vlm_base import get_vlm + + vlm = get_vlm(vlm_type="litellm", model_name="openai/gpt-4o") + +**Local ``transformers``.** Pass a Hugging Face model id as ``model_name``, set ``device``, and forward ``from_pretrained`` options through ``vlm_kwargs`` with ``model_load_kwargs`` (the same pattern works for any registry metric class): + +.. code-block:: python + + import torch + + from pruna.evaluation.metrics import VQAMetric + + local = VQAMetric( + vlm_type="transformers", + model_name="HuggingFaceTB/SmolVLM-256M-Instruct", + device="cpu", + vlm_kwargs={"model_load_kwargs": {"torch_dtype": torch.float32}}, + ) + +Use ``Task(request=[hosted, ...], ...)`` or ``Task(request=[local, ...], ...)`` (or pass the metric instance wherever metrics are configured). 
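For a non-OpenAI judge route, a minimal sketch (the Anthropic route name is illustrative; any vision-capable LiteLLM route follows the same pattern):

.. code-block:: python

    import os

    from pruna.evaluation.metrics import VQAMetric

    # The explicit argument takes precedence over LITELLM_API_KEY and OPENAI_API_KEY,
    # so a key meant for another tool is never forwarded to the wrong provider.
    judge = VQAMetric(
        vlm_type="litellm",
        model_name="anthropic/claude-3-5-sonnet-20241022",  # illustrative vision route
        api_key=os.environ["ANTHROPIC_API_KEY"],
    )
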
Full constructor patterns and ``get_vlm`` helpers are documented in ``pruna.evaluation.metrics.vlm_base`` and each metric’s docstring. + EvaluationAgent Initialization ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/src/pruna/data/__init__.py b/src/pruna/data/__init__.py index fd14a496..1f0ed5f6 100644 --- a/src/pruna/data/__init__.py +++ b/src/pruna/data/__init__.py @@ -34,7 +34,13 @@ setup_hps_dataset, setup_imgedit_dataset, setup_long_text_bench_dataset, + setup_oneig_anime_stylization_dataset, setup_oneig_dataset, + setup_oneig_general_object_dataset, + setup_oneig_knowledge_reasoning_dataset, + setup_oneig_multilingualism_dataset, + setup_oneig_portrait_dataset, + setup_oneig_text_rendering_dataset, setup_parti_prompts_dataset, ) from pruna.data.datasets.question_answering import setup_polyglot_dataset @@ -103,19 +109,33 @@ "image_classification_collate", {"img_size": 224}, ), - "DrawBench": (setup_drawbench_dataset, "prompt_collate", {}), + "DrawBench": (setup_drawbench_dataset, "prompt_with_auxiliaries_collate", {}), "PartiPrompts": ( setup_parti_prompts_dataset, "prompt_with_auxiliaries_collate", {}, ), - "GenAIBench": (setup_genai_bench_dataset, "prompt_collate", {}), + "GenAIBench": (setup_genai_bench_dataset, "prompt_with_auxiliaries_collate", {}), "GenEval": (setup_geneval_dataset, "prompt_with_auxiliaries_collate", {}), "HPS": (setup_hps_dataset, "prompt_with_auxiliaries_collate", {}), "ImgEdit": (setup_imgedit_dataset, "prompt_with_auxiliaries_collate", {}), "LongTextBench": (setup_long_text_bench_dataset, "prompt_with_auxiliaries_collate", {}), "GEditBench": (setup_gedit_dataset, "prompt_with_auxiliaries_collate", {}), "OneIG": (setup_oneig_dataset, "prompt_with_auxiliaries_collate", {}), + "OneIGAnimeStylization": ( + setup_oneig_anime_stylization_dataset, + "prompt_with_auxiliaries_collate", + {}, + ), + "OneIGGeneralObject": (setup_oneig_general_object_dataset, "prompt_with_auxiliaries_collate", {}), + "OneIGKnowledgeReasoning": ( + setup_oneig_knowledge_reasoning_dataset, + "prompt_with_auxiliaries_collate", + {}, + ), + "OneIGMultilingualism": (setup_oneig_multilingualism_dataset, "prompt_with_auxiliaries_collate", {}), + "OneIGPortrait": (setup_oneig_portrait_dataset, "prompt_with_auxiliaries_collate", {}), + "OneIGTextRendering": (setup_oneig_text_rendering_dataset, "prompt_with_auxiliaries_collate", {}), "DPG": (setup_dpg_dataset, "prompt_with_auxiliaries_collate", {}), "TinyIMDB": (setup_tiny_imdb_dataset, "text_generation_collate", {}), "VBench": (setup_vbench_dataset, "prompt_with_auxiliaries_collate", {}), diff --git a/src/pruna/data/datasets/prompt.py b/src/pruna/data/datasets/prompt.py index 7764d23b..b18d03d4 100644 --- a/src/pruna/data/datasets/prompt.py +++ b/src/pruna/data/datasets/prompt.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
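A consumption sketch (not part of this patch) for the per-category OneIG registry keys added to ``base_datasets`` above; all of them use ``prompt_with_auxiliaries_collate``, so test batches unpack as ``(prompts, auxiliaries)``:

.. code-block:: python

    from pruna.data.pruna_datamodule import PrunaDataModule

    # Any of the new keys works: OneIGAnimeStylization, OneIGGeneralObject,
    # OneIGKnowledgeReasoning, OneIGMultilingualism, OneIGPortrait, OneIGTextRendering.
    dm = PrunaDataModule.from_string("OneIGPortrait", dataloader_args={"batch_size": 2})
    dm.limit_datasets(4)
    prompts, auxiliaries = next(iter(dm.test_dataloader()))  # prompts: list[str]
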
+from pathlib import Path from typing import Literal, Tuple, get_args from datasets import Dataset, load_dataset @@ -123,21 +124,95 @@ DPGCategory = Literal["entity", "attribute", "relation", "global", "other"] -def _to_oneig_record(row: dict, questions_by_key: dict[str, dict]) -> dict: - """Convert OneIG row to unified record format.""" +def _warn_ignored_benchmark_seed(seed: int | None, *, dataset: str) -> None: + if seed is not None: + pruna_logger.warning( + "%s: `seed` is ignored for this test-only benchmark; sampling does not shuffle the test split.", + dataset, + ) + + +def _oneig_alignment_language_zh(row: dict) -> bool: + """Return True when the official Q_D file for this row should use the ``*_zh`` graphs.""" + if row.get("category", "") == "Multilingualism": + return True + lang = row.get("language") or row.get("lang") + return isinstance(lang, str) and lang.lower() in {"zh", "zh-cn", "zh_cn", "chinese", "cn"} + + +def _oneig_qd_prefix(row: dict) -> str: + """Map dataset ``category`` (+ language) to Q_D JSON stem (e.g. ``object``, ``anime_zh``).""" + row_category = row.get("category", "") + use_zh = _oneig_alignment_language_zh(row) + if row_category == "Multilingualism": + return "multilingualism_zh" + base = _CATEGORY_TO_QD.get(row_category, "") + if not base: + return "" + return f"{base}_zh" if use_zh else base + + +def _to_oneig_record( + row: dict, + questions_by_key: dict[str, dict], + reasoning_gt_en: dict[str, str], + reasoning_gt_zh: dict[str, str], + reasoning_language: str = "EN", +) -> dict: + """Convert OneIG row to unified record format. + + Parameters + ---------- + row : dict + Raw Hugging Face row (``category``, ``id``, ``class``). EN configs use ``prompt_en``; the + ``OneIG-Bench-ZH`` **Multilingualism** split uses ``prompt_cn`` instead of ``prompt_en``. + questions_by_key : dict[str, dict] + Merged Q_D index keyed as ``{qd_stem}_{prompt_id}`` (see ``_fetch_oneig_alignment``). + reasoning_gt_en : dict[str, str] + Official ``gt_answer.json`` keyed by prompt id (e.g. ``"000"``). + reasoning_gt_zh : dict[str, str] + Official ``gt_answer_zh.json`` keyed by prompt id. + reasoning_language : str, optional + Which reasoning GT to use: ``"EN"`` or ``"ZH"``. Default is ``"EN"``. + + Returns + ------- + dict + Unified record including ``questions``, ``dependencies``, and ``reasoning_gt_answer`` when + applicable (Knowledge_Reasoning only). 
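A behavior sketch (not part of this patch) for ``_oneig_qd_prefix`` above, assuming ``_CATEGORY_TO_QD`` maps ``Anime_Stylization`` to ``anime`` (the map itself lives further down in this module):

.. code-block:: python

    # Multilingualism rows always use the dedicated ZH graph file.
    assert _oneig_qd_prefix({"category": "Multilingualism"}) == "multilingualism_zh"
    # Other categories switch to the *_zh stem only when the row language looks Chinese.
    assert _oneig_qd_prefix({"category": "Anime_Stylization", "language": "zh"}) == "anime_zh"
    assert _oneig_qd_prefix({"category": "Anime_Stylization"}) == "anime"
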
+ """ row_category = row.get("category", "") row_class = row.get("class", "None") or "None" - qd_name = _CATEGORY_TO_QD.get(row_category, "") - lookup_key = f"{qd_name}_{row.get('id', '')}" if qd_name else "" + prompt_id = str(row.get("id", "")) + qd_prefix = _oneig_qd_prefix(row) + lookup_key = f"{qd_prefix}_{prompt_id}" if qd_prefix else "" q_info = questions_by_key.get(lookup_key, {}) + text = row.get("prompt") or row.get("prompt_en") or row.get("prompt_cn") or "" + reasoning_gt_answer: str | None = None + if row_category == "Knowledge_Reasoning": + if reasoning_language.upper() == "ZH": + reasoning_gt_answer = reasoning_gt_zh.get(prompt_id) + else: + reasoning_gt_answer = reasoning_gt_en.get(prompt_id) + is_text_rendering = row_category in ("Text_Rendering", "Text Rendering") + if is_text_rendering and text: + import re as _re + + quoted = _re.findall(r'"([^"]+)"', text) + text_content: str | None = " ".join(quoted) if quoted else (row_class if row_class != "None" else None) + else: + text_content = row_class if row_class != "None" else None + questions = {k: v for k, v in q_info.get("questions", {}).items() if v is not None} + dependencies = {k: v for k, v in q_info.get("dependencies", {}).items() if v is not None} return { - "text": row.get("prompt_en", row.get("prompt", "")), - "subset": "Text_Rendering" if row_category in ("Text_Rendering", "Text Rendering") else row_category, - "text_content": row_class if row_class != "None" else None, + "text": text, + "subset": "Text_Rendering" if is_text_rendering else row_category, + "text_content": text_content, "category": row_category, "class": row_class, - "questions": q_info.get("questions", {}), - "dependencies": q_info.get("dependencies", {}), + "questions": questions, + "dependencies": dependencies, + "reasoning_gt_answer": reasoning_gt_answer, } @@ -159,7 +234,7 @@ def setup_drawbench_dataset() -> Tuple[Dataset, Dataset, Dataset]: def setup_parti_prompts_dataset( - seed: int, + seed: int | None = None, fraction: float = 1.0, train_sample_size: int | None = None, test_sample_size: int | None = None, @@ -172,8 +247,8 @@ def setup_parti_prompts_dataset( Parameters ---------- - seed : int - The seed to use. + seed : int | None, optional + Ignored; test order is deterministic. If not None, a warning is logged. fraction : float The fraction of the dataset to use. train_sample_size : int | None @@ -188,6 +263,7 @@ def setup_parti_prompts_dataset( Tuple[Dataset, Dataset, Dataset] The Parti Prompts dataset (dummy train, dummy val, test). """ + _warn_ignored_benchmark_seed(seed, dataset="PartiPrompts") ds = load_dataset("nateraw/parti-prompts")["train"] # type: ignore[index] if category is not None: @@ -226,7 +302,7 @@ def _generate_geneval_question(entry: dict) -> list[str]: def setup_geneval_dataset( - seed: int, + seed: int | None = None, fraction: float = 1.0, train_sample_size: int | None = None, test_sample_size: int | None = None, @@ -239,8 +315,8 @@ def setup_geneval_dataset( Parameters ---------- - seed : int - The seed to use. + seed : int | None, optional + Ignored; test order is deterministic. If not None, a warning is logged. fraction : float The fraction of the dataset to use. train_sample_size : int | None @@ -255,6 +331,7 @@ def setup_geneval_dataset( Tuple[Dataset, Dataset, Dataset] The GenEval dataset (dummy train, dummy val, test). 
""" + _warn_ignored_benchmark_seed(seed, dataset="GenEval") import json import requests @@ -286,7 +363,7 @@ def setup_geneval_dataset( def setup_hps_dataset( - seed: int, + seed: int | None = None, fraction: float = 1.0, train_sample_size: int | None = None, test_sample_size: int | None = None, @@ -299,8 +376,8 @@ def setup_hps_dataset( Parameters ---------- - seed : int - The seed to use. + seed : int | None, optional + Ignored; test order is deterministic. If not None, a warning is logged. fraction : float The fraction of the dataset to use. train_sample_size : int | None @@ -315,6 +392,7 @@ def setup_hps_dataset( Tuple[Dataset, Dataset, Dataset] The HPD dataset (dummy train, dummy val, test). """ + _warn_ignored_benchmark_seed(seed, dataset="HPS") import json from huggingface_hub import hf_hub_download @@ -338,7 +416,7 @@ def setup_hps_dataset( def setup_long_text_bench_dataset( - seed: int, + seed: int | None = None, fraction: float = 1.0, train_sample_size: int | None = None, test_sample_size: int | None = None, @@ -350,8 +428,8 @@ def setup_long_text_bench_dataset( Parameters ---------- - seed : int - The seed to use. + seed : int | None, optional + Ignored; test order is deterministic. If not None, a warning is logged. fraction : float The fraction of the dataset to use. train_sample_size : int | None @@ -364,6 +442,7 @@ def setup_long_text_bench_dataset( Tuple[Dataset, Dataset, Dataset] The Long Text Bench dataset (dummy train, dummy val, test). """ + _warn_ignored_benchmark_seed(seed, dataset="LongTextBench") ds = load_dataset("X-Omni/LongText-Bench")["train"] # type: ignore[index] ds = ds.rename_column("text", "text_content") ds = ds.rename_column("prompt", "text") @@ -389,8 +468,74 @@ def setup_genai_bench_dataset() -> Tuple[Dataset, Dataset, Dataset]: return ds.select([0]), ds.select([0]), ds +def ensure_imgedit_benchmark_images_extracted() -> Path: + """ + Download ``Benchmark.tar`` (if needed), extract it, and return the ``singleturn`` folder. + + Returns + ------- + Path + ``Benchmark/singleturn`` directory whose files are addressed by ``image_id``. + + Raises + ------ + RuntimeError + If the archive cannot be downloaded or extracted. + """ + import tarfile + + from huggingface_hub import hf_hub_download + + tar_path = Path(hf_hub_download(repo_id="sysuyy/ImgEdit", filename="Benchmark.tar", repo_type="dataset")) + extract_dir = tar_path.parent / "imgedit_singleturn" + candidate = extract_dir / "Benchmark" / "singleturn" + if not candidate.is_dir() or not any(candidate.iterdir()): + extract_dir.mkdir(parents=True, exist_ok=True) + with tarfile.open(tar_path, "r") as tar: + tar.extractall(path=extract_dir) + if not candidate.is_dir() or not any(candidate.iterdir()): + raise RuntimeError(f"ImgEdit: failed to extract Benchmark.tar to {candidate}") + return candidate + + +def load_imgedit_source_image_bytes(image_id: str, *, image_folder: Path | None = None) -> bytes: + """ + Read one ImgEdit source image as JPEG bytes (RGB). + + Parameters + ---------- + image_id : str + Path relative to the singleturn folder (from the official ``basic_edit.json`` ``id``). + image_folder : Path | None, optional + ``Benchmark/singleturn`` directory; when ``None``, calls + :func:`ensure_imgedit_benchmark_images_extracted`. + + Returns + ------- + bytes + JPEG-encoded bytes for the source image. + + Raises + ------ + FileNotFoundError + If ``image_id`` does not exist under ``image_folder``. + Exception + If PIL cannot open or convert the image. 
+ """ + from io import BytesIO + + from PIL import Image + + folder = image_folder if image_folder is not None else ensure_imgedit_benchmark_images_extracted() + img_path = folder / image_id + pil = Image.open(img_path).convert("RGB") + buf = BytesIO() + pil.save(buf, format="JPEG") + return buf.getvalue() + + def setup_imgedit_dataset( - seed: int, + seed: int | None = None, fraction: float = 1.0, train_sample_size: int | None = None, test_sample_size: int | None = None, @@ -403,8 +548,8 @@ def setup_imgedit_dataset( Parameters ---------- - seed : int - The seed to use. + seed : int | None, optional + Ignored; test order is deterministic. If not None, a warning is logged. fraction : float The fraction of the dataset to use. train_sample_size : int | None @@ -420,6 +565,7 @@ def setup_imgedit_dataset( Tuple[Dataset, Dataset, Dataset] The ImgEdit dataset (dummy train, dummy val, test). """ + _warn_ignored_benchmark_seed(seed, dataset="ImgEdit") import json import requests @@ -433,6 +579,8 @@ def setup_imgedit_dataset( instructions: dict = json.loads(response_instructions.text) judge_prompts: dict = json.loads(response_judge_prompts.text) + image_folder = ensure_imgedit_benchmark_images_extracted() + categories = [category] if category is not None and not isinstance(category, list) else category records = [] for _, instruction in instructions.items(): @@ -441,14 +589,17 @@ def setup_imgedit_dataset( if categories is not None and edit_type not in categories: continue - records.append( - { - "text": instruction.get("prompt", ""), - "category": edit_type, - "image_id": instruction.get("id", ""), - "judge_prompt": judge_prompts.get(edit_type, ""), - } - ) + image_id = instruction.get("id", "") + record: dict = { + "text": instruction.get("prompt", ""), + "category": edit_type, + "image_id": image_id, + "judge_prompt": judge_prompts.get(edit_type, ""), + } + src = load_imgedit_source_image_bytes(image_id, image_folder=image_folder) + if src is not None: + record["source_image_bytes"] = src + records.append(record) ds = Dataset.from_list(records) ds = stratify_dataset(ds, sample_size=test_sample_size, fraction=fraction) @@ -466,18 +617,47 @@ def setup_imgedit_dataset( "General_Object": "object", } -_ONEIG_ALIGNMENT_BASE = "https://raw.githubusercontent.com/OneIG-Bench/OneIG-Benchmark/41b49831e79e6dde5323618c164da1c4cf0f699d/scripts/alignment/Q_D" +_ONEIG_BENCHMARK_REF = "41b49831e79e6dde5323618c164da1c4cf0f699d" +_ONEIG_RAW_BASE = f"https://raw.githubusercontent.com/OneIG-Bench/OneIG-Benchmark/{_ONEIG_BENCHMARK_REF}" +_ONEIG_ALIGNMENT_QD_URL = f"{_ONEIG_RAW_BASE}/scripts/alignment/Q_D" +_ONEIG_REASONING_GT_URL_EN = f"{_ONEIG_RAW_BASE}/scripts/reasoning/gt_answer.json" +_ONEIG_REASONING_GT_URL_ZH = f"{_ONEIG_RAW_BASE}/scripts/reasoning/gt_answer_zh.json" + +_ONEIG_QD_JSON_STEMS: tuple[str, ...] = ( + "anime", + "human", + "object", + "anime_zh", + "human_zh", + "object_zh", + "multilingualism_zh", +) def _fetch_oneig_alignment() -> dict[str, dict]: - """Fetch alignment questions from per-category Q_D files (InferBench-style).""" + """Load OneIG question/dependency graphs from the official repo (HTTP, no on-disk cache). + + Fetches every ``scripts/alignment/Q_D/*.json`` file used by upstream ``alignment_score.py`` (EN + ZH), + including ``multilingualism_zh.json``. Keys in the returned map are ``{stem}_{prompt_id}`` matching + upstream file stems (e.g. ``object_012``, ``multilingualism_zh_000``). 
+ + Returns + ------- + dict[str, dict] + ``prompt_id``-level ``questions`` and ``dependencies`` dicts (parsed from JSON strings when needed). + + Raises + ------ + requests.HTTPError + If any asset URL is missing or the response is not successful. + """ import json import requests questions_by_key: dict[str, dict] = {} - for qd_name in ("anime", "human", "object"): - url = f"{_ONEIG_ALIGNMENT_BASE}/{qd_name}.json" + for stem in _ONEIG_QD_JSON_STEMS: + url = f"{_ONEIG_ALIGNMENT_QD_URL}/{stem}.json" resp = requests.get(url, timeout=30) resp.raise_for_status() data = json.loads(resp.text) @@ -488,16 +668,55 @@ def _fetch_oneig_alignment() -> dict[str, dict]: q = json.loads(q) if isinstance(d, str): d = json.loads(d) - questions_by_key[f"{qd_name}_{prompt_id}"] = {"questions": q, "dependencies": d} + questions_by_key[f"{stem}_{prompt_id}"] = {"questions": q, "dependencies": d} return questions_by_key +def _fetch_oneig_reasoning_gt() -> tuple[dict[str, str], dict[str, str]]: + """Load official knowledge-reasoning reference answers (HTTP, no on-disk cache). + + Mirrors ``scripts/reasoning/gt_answer.json`` and ``gt_answer_zh.json`` from the same pinned commit as Q_D. + Keys are prompt ids (``str``), values are answer strings; downstream metrics may slice filenames to the + first three characters like ``reasoning_score.py``. + + Returns + ------- + tuple[dict[str, str], dict[str, str]] + ``(en_by_id, zh_by_id)``. + + Raises + ------ + requests.HTTPError + If any asset URL is missing or the response is not successful. + """ + import json + + import requests + + def _load(url: str) -> dict[str, str]: + resp = requests.get(url, timeout=60) + resp.raise_for_status() + raw = json.loads(resp.text) + return {str(k): str(v) for k, v in raw.items()} + + return _load(_ONEIG_REASONING_GT_URL_EN), _load(_ONEIG_REASONING_GT_URL_ZH) + + +def _oneig_needs_zh_multilingualism_hub(category: OneIGCategory | list[OneIGCategory] | None) -> bool: + """Whether ``OneIG-Bench-ZH`` must be loaded for ``Multilingualism`` rows.""" + if category is None: + return True + categories = [category] if not isinstance(category, list) else category + return "Multilingualism" in categories + + def setup_oneig_dataset( - seed: int, + seed: int | None = None, fraction: float = 1.0, train_sample_size: int | None = None, test_sample_size: int | None = None, category: OneIGCategory | list[OneIGCategory] | None = None, + reasoning_language: str = "EN", ) -> Tuple[Dataset, Dataset, Dataset]: """ Setup the OneIG benchmark dataset. @@ -506,8 +725,8 @@ def setup_oneig_dataset( Parameters ---------- - seed : int - The seed to use. + seed : int | None, optional + Ignored; test order is deterministic. If not None, a warning is logged. fraction : float The fraction of the dataset to use. train_sample_size : int | None @@ -517,16 +736,43 @@ def setup_oneig_dataset( category : OneIGCategory | list[OneIGCategory] | None Filter by dataset category (Anime_Stylization, Portrait, etc.) or class (fauvism, watercolor, etc.). If None, returns all subsets. + reasoning_language : str, optional + Which reasoning GT to use for Knowledge_Reasoning rows: ``"EN"`` or ``"ZH"``. Default is ``"EN"``. Returns ------- Tuple[Dataset, Dataset, Dataset] - The OneIG dataset (dummy train, dummy val, test). + The OneIG dataset (dummy train, dummy val, test). 
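The paired reference answers from ``_fetch_oneig_reasoning_gt`` above are keyed by prompt id; a sketch using the docstring's example id:

.. code-block:: python

    en_by_id, zh_by_id = _fetch_oneig_reasoning_gt()
    print(en_by_id.get("000"))  # official EN reference answer
    print(zh_by_id.get("000"))  # official ZH reference answer
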
Rows include ``questions`` and + ``dependencies`` from official Q_D JSON (EN + ZH stems, including ``multilingualism_zh``), + plus ``reasoning_gt_answer`` for ``Knowledge_Reasoning`` (language chosen by ``reasoning_language``). + Rows cover EN categories from ``OneIG-Bench`` plus ``Multilingualism`` from ``OneIG-Bench-ZH``. + Assets are downloaded over HTTP on each call (pinned commit ``_ONEIG_BENCHMARK_REF``); there is + no local disk cache. + + Notes + ----- + Non-multilingual prompts are loaded from the Hub config ``OneIG-Bench``; **Multilingualism** rows + are taken only from ``OneIG-Bench-ZH`` (they use ``prompt_cn``). The ZH config is fetched only when + the requested ``category`` is ``None`` (full suite) or explicitly includes ``Multilingualism``. + Q_D / reasoning JSON URLs are defined next to ``_fetch_oneig_alignment`` and + ``_fetch_oneig_reasoning_gt``. """ + _warn_ignored_benchmark_seed(seed, dataset="OneIG") questions_by_key = _fetch_oneig_alignment() - - ds_raw = load_dataset("OneIG-Bench/OneIG-Bench", "OneIG-Bench")["train"] # type: ignore[index] - records = [_to_oneig_record(dict(row), questions_by_key) for row in ds_raw] + reasoning_gt_en, reasoning_gt_zh = _fetch_oneig_reasoning_gt() + + ds_en = load_dataset("OneIG-Bench/OneIG-Bench", "OneIG-Bench")["train"] # type: ignore[index] + records = [ + _to_oneig_record(dict(row), questions_by_key, reasoning_gt_en, reasoning_gt_zh, reasoning_language) + for row in ds_en + ] + if _oneig_needs_zh_multilingualism_hub(category): + ds_zh = load_dataset("OneIG-Bench/OneIG-Bench", "OneIG-Bench-ZH")["train"] # type: ignore[index] + ds_zh_ml = ds_zh.filter(lambda r: r["category"] == "Multilingualism") + records.extend( + _to_oneig_record(dict(row), questions_by_key, reasoning_gt_en, reasoning_gt_zh, reasoning_language) + for row in ds_zh_ml + ) ds = Dataset.from_list(records) if category is not None: @@ -544,8 +790,252 @@ def setup_oneig_dataset( return ds.select([0]), ds.select([0]), ds +# functools.partial is not used for these wrappers: get_literal_values_from_param would unwrap +# partial objects back to setup_oneig_dataset and expose every OneIGCategory instead of one. + + +def setup_oneig_anime_stylization_dataset( + seed: int | None = None, + fraction: float = 1.0, + train_sample_size: int | None = None, + test_sample_size: int | None = None, + reasoning_language: str = "EN", +) -> Tuple[Dataset, Dataset, Dataset]: + """ + Load OneIG-Bench with ``category`` fixed to ``Anime_Stylization``. + + License: Apache 2.0 + + Parameters + ---------- + seed : int | None, optional + Ignored; see :func:`setup_oneig_dataset`. + fraction : float + Fraction of the subset to use. + train_sample_size : int | None + Unused; train/val are dummy. + test_sample_size : int | None + Test sample size cap for the subset. + reasoning_language : str + Passed to :func:`setup_oneig_dataset`. + + Returns + ------- + Tuple[Dataset, Dataset, Dataset] + Dummy train, dummy val, and test split for the Anime_Stylization subset. + """ + return setup_oneig_dataset( + seed=seed, + fraction=fraction, + train_sample_size=train_sample_size, + test_sample_size=test_sample_size, + category="Anime_Stylization", + reasoning_language=reasoning_language, + ) + + +def setup_oneig_general_object_dataset( + seed: int | None = None, + fraction: float = 1.0, + train_sample_size: int | None = None, + test_sample_size: int | None = None, + reasoning_language: str = "EN", +) -> Tuple[Dataset, Dataset, Dataset]: + """ + Load OneIG-Bench with ``category`` fixed to ``General_Object``. 
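A call sketch (not part of this patch) for the fixed-category wrappers defined here; the sample size is illustrative, and the category filter is applied inside ``setup_oneig_dataset``:

.. code-block:: python

    _, _, test = setup_oneig_anime_stylization_dataset(test_sample_size=8)
    assert all(row["category"] == "Anime_Stylization" for row in test)
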
+ + License: Apache 2.0 + + Parameters + ---------- + seed : int | None, optional + Ignored; see :func:`setup_oneig_dataset`. + fraction : float + Fraction of the subset to use. + train_sample_size : int | None + Unused; train/val are dummy. + test_sample_size : int | None + Test sample size cap for the subset. + reasoning_language : str + Passed to :func:`setup_oneig_dataset`. + + Returns + ------- + Tuple[Dataset, Dataset, Dataset] + Dummy train, dummy val, and test split for the General_Object subset. + """ + return setup_oneig_dataset( + seed=seed, + fraction=fraction, + train_sample_size=train_sample_size, + test_sample_size=test_sample_size, + category="General_Object", + reasoning_language=reasoning_language, + ) + + +def setup_oneig_knowledge_reasoning_dataset( + seed: int | None = None, + fraction: float = 1.0, + train_sample_size: int | None = None, + test_sample_size: int | None = None, + reasoning_language: str = "EN", +) -> Tuple[Dataset, Dataset, Dataset]: + """ + Load OneIG-Bench with ``category`` fixed to ``Knowledge_Reasoning``. + + License: Apache 2.0 + + Parameters + ---------- + seed : int | None, optional + Ignored; see :func:`setup_oneig_dataset`. + fraction : float + Fraction of the subset to use. + train_sample_size : int | None + Unused; train/val are dummy. + test_sample_size : int | None + Test sample size cap for the subset. + reasoning_language : str + Passed to :func:`setup_oneig_dataset`. + + Returns + ------- + Tuple[Dataset, Dataset, Dataset] + Dummy train, dummy val, and test split for the Knowledge_Reasoning subset. + """ + return setup_oneig_dataset( + seed=seed, + fraction=fraction, + train_sample_size=train_sample_size, + test_sample_size=test_sample_size, + category="Knowledge_Reasoning", + reasoning_language=reasoning_language, + ) + + +def setup_oneig_multilingualism_dataset( + seed: int | None = None, + fraction: float = 1.0, + train_sample_size: int | None = None, + test_sample_size: int | None = None, + reasoning_language: str = "EN", +) -> Tuple[Dataset, Dataset, Dataset]: + """ + Load OneIG-Bench with ``category`` fixed to ``Multilingualism``. + + License: Apache 2.0 + + Parameters + ---------- + seed : int | None, optional + Ignored; see :func:`setup_oneig_dataset`. + fraction : float + Fraction of the subset to use. + train_sample_size : int | None + Unused; train/val are dummy. + test_sample_size : int | None + Test sample size cap for the subset. + reasoning_language : str + Passed to :func:`setup_oneig_dataset`. + + Returns + ------- + Tuple[Dataset, Dataset, Dataset] + Dummy train, dummy val, and test split for the Multilingualism subset. + """ + return setup_oneig_dataset( + seed=seed, + fraction=fraction, + train_sample_size=train_sample_size, + test_sample_size=test_sample_size, + category="Multilingualism", + reasoning_language=reasoning_language, + ) + + +def setup_oneig_portrait_dataset( + seed: int | None = None, + fraction: float = 1.0, + train_sample_size: int | None = None, + test_sample_size: int | None = None, + reasoning_language: str = "EN", +) -> Tuple[Dataset, Dataset, Dataset]: + """ + Load OneIG-Bench with ``category`` fixed to ``Portrait``. + + License: Apache 2.0 + + Parameters + ---------- + seed : int | None, optional + Ignored; see :func:`setup_oneig_dataset`. + fraction : float + Fraction of the subset to use. + train_sample_size : int | None + Unused; train/val are dummy. + test_sample_size : int | None + Test sample size cap for the subset. + reasoning_language : str + Passed to :func:`setup_oneig_dataset`. 
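For the Knowledge_Reasoning wrapper, ``reasoning_language`` selects which reference answers populate ``reasoning_gt_answer``; a sketch with an illustrative sample size:

.. code-block:: python

    _, _, test = setup_oneig_knowledge_reasoning_dataset(reasoning_language="ZH", test_sample_size=4)
    print(test[0]["reasoning_gt_answer"])  # ZH reference answer (None if the id has no entry)
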
+ + Returns + ------- + Tuple[Dataset, Dataset, Dataset] + Dummy train, dummy val, and test split for the Portrait subset. + """ + return setup_oneig_dataset( + seed=seed, + fraction=fraction, + train_sample_size=train_sample_size, + test_sample_size=test_sample_size, + category="Portrait", + reasoning_language=reasoning_language, + ) + + +def setup_oneig_text_rendering_dataset( + seed: int | None = None, + fraction: float = 1.0, + train_sample_size: int | None = None, + test_sample_size: int | None = None, + reasoning_language: str = "EN", +) -> Tuple[Dataset, Dataset, Dataset]: + """ + Load OneIG-Bench with ``category`` fixed to ``Text_Rendering``. + + License: Apache 2.0 + + Parameters + ---------- + seed : int | None, optional + Ignored; see :func:`setup_oneig_dataset`. + fraction : float + Fraction of the subset to use. + train_sample_size : int | None + Unused; train/val are dummy. + test_sample_size : int | None + Test sample size cap for the subset. + reasoning_language : str + Passed to :func:`setup_oneig_dataset`. + + Returns + ------- + Tuple[Dataset, Dataset, Dataset] + Dummy train, dummy val, and test split for the Text_Rendering subset. + """ + return setup_oneig_dataset( + seed=seed, + fraction=fraction, + train_sample_size=train_sample_size, + test_sample_size=test_sample_size, + category="Text_Rendering", + reasoning_language=reasoning_language, + ) + + def setup_gedit_dataset( - seed: int, + seed: int | None = None, fraction: float = 1.0, train_sample_size: int | None = None, test_sample_size: int | None = None, @@ -558,8 +1048,8 @@ def setup_gedit_dataset( Parameters ---------- - seed : int - The seed to use. + seed : int | None, optional + Ignored; test order is deterministic. If not None, a warning is logged. fraction : float The fraction of the dataset to use. train_sample_size : int | None @@ -576,6 +1066,7 @@ def setup_gedit_dataset( Tuple[Dataset, Dataset, Dataset] The GEditBench dataset (dummy train, dummy val, test). """ + _warn_ignored_benchmark_seed(seed, dataset="GEditBench") task_type_map = { "subject_add": "subject-add", "subject_remove": "subject-remove", @@ -595,12 +1086,18 @@ def setup_gedit_dataset( for row in ds: task_type = row.get("task_type", "") category_name = task_type_to_category.get(task_type, task_type) - records.append( - { - "text": row.get("instruction", ""), - "category": category_name, - } - ) + record: dict = { + "text": row.get("instruction", ""), + "category": category_name, + } + src = row.get("input_image_raw") + if src is not None: + from io import BytesIO + + buf = BytesIO() + src.save(buf, format="JPEG") + record["source_image_bytes"] = buf.getvalue() + records.append(record) ds = Dataset.from_list(records) ds = stratify_dataset(ds, sample_size=test_sample_size, fraction=fraction) @@ -613,7 +1110,7 @@ def setup_gedit_dataset( def setup_dpg_dataset( - seed: int, + seed: int | None = None, fraction: float = 1.0, train_sample_size: int | None = None, test_sample_size: int | None = None, @@ -626,8 +1123,8 @@ def setup_dpg_dataset( Parameters ---------- - seed : int - The seed to use. + seed : int | None, optional + Ignored; test order is deterministic. If not None, a warning is logged. fraction : float The fraction of the dataset to use. train_sample_size : int | None @@ -642,6 +1139,7 @@ def setup_dpg_dataset( Tuple[Dataset, Dataset, Dataset] The DPG dataset (dummy train, dummy val, test). 
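GEditBench rows (like ImgEdit rows earlier in this file) now carry JPEG-encoded ``source_image_bytes``; a consumer-side decoding sketch (row access is illustrative):

.. code-block:: python

    from io import BytesIO

    from PIL import Image

    _, _, test = setup_gedit_dataset(test_sample_size=2)
    row = test[0]
    if "source_image_bytes" in row:
        source = Image.open(BytesIO(row["source_image_bytes"]))  # round-trips the RGB JPEG
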
""" + _warn_ignored_benchmark_seed(seed, dataset="DPG") import csv import io from collections import defaultdict diff --git a/src/pruna/data/pruna_datamodule.py b/src/pruna/data/pruna_datamodule.py index 6d1eaadd..03003127 100644 --- a/src/pruna/data/pruna_datamodule.py +++ b/src/pruna/data/pruna_datamodule.py @@ -135,7 +135,7 @@ def from_string( tokenizer: AutoTokenizer | None = None, collate_fn_args: dict = dict(), dataloader_args: dict = dict(), - seed: int = 42, + seed: int | None = None, category: str | list[str] | None = None, fraction: float = 1.0, train_sample_size: int | None = None, @@ -154,8 +154,10 @@ def from_string( Any additional arguments for the collate function. dataloader_args : dict Any additional arguments for the dataloader. - seed : int - The seed to use. + seed : int | None, optional + Passed to dataset setup when the loader uses shuffled sampling. + If None, setups that require a seed default to 42; test-only benchmarks + omit seed so ordering stays deterministic without warnings. category : str | list[str] | None The category of the dataset. fraction : float @@ -177,7 +179,12 @@ def from_string( collate_fn_args = default_collate_fn_args if "seed" in inspect.signature(setup_fn).parameters: - setup_fn = partial(setup_fn, seed=seed) + seed_param = inspect.signature(setup_fn).parameters["seed"] + has_default = seed_param.default is not inspect.Parameter.empty + if seed is not None: + setup_fn = partial(setup_fn, seed=seed) + elif not has_default: + setup_fn = partial(setup_fn, seed=42) if "category" in inspect.signature(setup_fn).parameters: setup_fn = partial(setup_fn, category=category) diff --git a/src/pruna/evaluation/evaluation_agent.py b/src/pruna/evaluation/evaluation_agent.py index 674bf962..3e20e4a5 100644 --- a/src/pruna/evaluation/evaluation_agent.py +++ b/src/pruna/evaluation/evaluation_agent.py @@ -26,10 +26,9 @@ from pruna.data.utils import move_batch_to_device from pruna.engine.pruna_model import PrunaModel from pruna.engine.utils import get_device, move_to_device, safe_memory_cleanup, set_to_best_available_device -from pruna.evaluation.metrics.context_mixin import EvaluationContextMixin from pruna.evaluation.metrics.metric_base import BaseMetric from pruna.evaluation.metrics.metric_stateful import StatefulMetric -from pruna.evaluation.metrics.result import MetricResult, MetricResultProtocol +from pruna.evaluation.metrics.result import MetricResult from pruna.evaluation.metrics.utils import ensure_device_consistency, get_device_map, group_metrics_by_inheritance from pruna.evaluation.task import Task from pruna.logging.logger import pruna_logger @@ -72,8 +71,8 @@ def __init__( raise ValueError("When not using 'task' parameter, both 'request' and 'datamodule' must be provided.") self.task = Task(request=request, datamodule=datamodule, device=device) - self.first_model_results: List[MetricResultProtocol] = [] - self.subsequent_model_results: List[MetricResultProtocol] = [] + self.first_model_results: List[MetricResult] = [] + self.subsequent_model_results: List[MetricResult] = [] self.device = set_to_best_available_device(self.task.device) self.cache: List[Tensor] = [] self.evaluation_for_first_model: bool = True @@ -113,8 +112,8 @@ def from_benchmark( Examples -------- - >>> agent = EvaluationAgent.from_benchmark("Parti Prompts", model) - >>> agent = EvaluationAgent.from_benchmark("HPS", model, category="anime", fraction=0.1) + >>> agent = EvaluationAgent.from_benchmark("Parti Prompts") + >>> agent = EvaluationAgent.from_benchmark("HPS", category="anime", 
fraction=0.1) """ task = Task.from_benchmark( benchmark_name, @@ -125,20 +124,18 @@ def from_benchmark( ) return cls(task=task) - def evaluate(self, model: Any, model_name: str | None = None) -> List[MetricResultProtocol]: + def evaluate(self, model: Any) -> List[MetricResult]: """ Evaluate models using different metric types. Parameters ---------- - model : Any + model : PrunaModel The model to evaluate. - model_name : str | None, optional - The name of the model to evaluate. Required for rapidata benchmark submission. Returns ------- - List[MetricResultProtocol] + List[MetricResult] The results of the model. """ results = [] @@ -149,10 +146,6 @@ def evaluate(self, model: Any, model_name: str | None = None) -> List[MetricResu pairwise_metrics = self.task.get_pairwise_stateful_metrics() stateless_metrics = self.task.get_stateless_metrics() - for metric in single_stateful_metrics: - if isinstance(metric, EvaluationContextMixin): - metric.current_context = model_name - # Update and compute stateful metrics. pruna_logger.info("Evaluating stateful metrics.") with torch.no_grad(): @@ -285,7 +278,7 @@ def update_stateful_metrics( def compute_stateful_metrics( self, single_stateful_metrics: List[StatefulMetric], pairwise_metrics: List[StatefulMetric] - ) -> List[MetricResultProtocol]: + ) -> List[MetricResult]: """ Compute stateful metrics. @@ -303,20 +296,16 @@ def compute_stateful_metrics( """ results = [] for stateful_metric in single_stateful_metrics: - result = stateful_metric.compute() - if result is not None: - results.append(result) + results.append(stateful_metric.compute()) stateful_metric.reset() if not self.evaluation_for_first_model and self.task.is_pairwise_evaluation(): for pairwise_metric in pairwise_metrics: - result = pairwise_metric.compute() - if result is not None: - results.append(result) + results.append(pairwise_metric.compute()) pairwise_metric.reset() return results - def compute_stateless_metrics(self, model: PrunaModel, stateless_metrics: List[Any]) -> List[MetricResultProtocol]: + def compute_stateless_metrics(self, model: PrunaModel, stateless_metrics: List[Any]) -> List[MetricResult]: """ Compute stateless metrics. diff --git a/src/pruna/evaluation/metrics/metric_elapsed_time.py b/src/pruna/evaluation/metrics/metric_elapsed_time.py index c3689446..ccfc413c 100644 --- a/src/pruna/evaluation/metrics/metric_elapsed_time.py +++ b/src/pruna/evaluation/metrics/metric_elapsed_time.py @@ -198,9 +198,11 @@ def compute(self, model: PrunaModel, dataloader: DataLoader) -> Dict[str, Any] | # Measurement list_elapsed_times = [] with tqdm(total=self.n_iterations, desc="Measuring inference time", unit="iter") as pbar: + def measure_with_progress(m, x): list_elapsed_times.append(self._time_inference(m, x)) pbar.update(1) + self._measure(model, dataloader, self.n_iterations, measure_with_progress) total_elapsed_time = sum(list_elapsed_times) diff --git a/src/pruna/evaluation/metrics/metric_torch.py b/src/pruna/evaluation/metrics/metric_torch.py index 4d329d86..ea2365fa 100644 --- a/src/pruna/evaluation/metrics/metric_torch.py +++ b/src/pruna/evaluation/metrics/metric_torch.py @@ -50,6 +50,26 @@ ) from pruna.logging.logger import pruna_logger +_PRUNA_TASK_ROUTING_KWARGS: tuple[str, ...] = ( + "vlm_type", + "model_name", + "structured_output", + "vlm_kwargs", + "api_key", +) + + +def _strip_task_routing_kwargs(kwargs: dict[str, Any]) -> None: + """ + Drop kwargs :class:`~pruna.evaluation.task.Task` passes when building mixed metric lists. 
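With the ``EvaluationAgent`` changes above and the string-metric routing added to ``task.py`` below, a usage sketch assuming ``smashed_model`` is a ``PrunaModel`` and ``vqa`` is among ``VLM_METRIC_REGISTRY_NAMES``:

.. code-block:: python

    from pruna.data.pruna_datamodule import PrunaDataModule
    from pruna.evaluation.evaluation_agent import EvaluationAgent
    from pruna.evaluation.task import Task

    dm = PrunaDataModule.from_string("LAION256")
    task = Task(request=["vqa", "clip_score"], datamodule=dm, device="cpu")  # "vqa" defaults to openai/gpt-4o
    agent = EvaluationAgent(task=task)
    results = agent.evaluate(smashed_model)  # List[MetricResult]; model_name was removed
    for r in results:
        print(r.name, r.result)
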
+ + Torchmetrics classes often end with ``**kwargs`` and would otherwise accept bogus keys + until a lower layer raises. Stripping here keeps :class:`TorchMetricWrapper` the single + choke point between Pruna routing and torchmetrics constructors. + """ + for key in _PRUNA_TASK_ROUTING_KWARGS: + kwargs.pop(key, None) + def default_update(metric: Metric, *args, **kwargs) -> None: """ @@ -124,9 +144,7 @@ def arniqa_update(metric: ARNIQA, preds: Any) -> None: def ssim_update( - metric: StructuralSimilarityIndexMeasure | MultiScaleStructuralSimilarityIndexMeasure, - preds: Any, - target: Any + metric: StructuralSimilarityIndexMeasure | MultiScaleStructuralSimilarityIndexMeasure, preds: Any, target: Any ) -> None: """ Update handler for SSIM or MS-SSIM metric. @@ -246,6 +264,7 @@ def __new__(cls, metric_name: str, call_type: str = "", **kwargs) -> StatefulMet if metric_name == "clip_score" and call_type.startswith(PAIRWISE): from pruna.evaluation.metrics.metric_pairwise_clip import PairwiseClipScore + _strip_task_routing_kwargs(kwargs) return PairwiseClipScore(**kwargs) return super().__new__(cls) @@ -259,6 +278,7 @@ def __init__(self, metric_name: str, call_type: str = "", **kwargs) -> None: If the metric name is not supported. """ self.metric_name = metric_name + _strip_task_routing_kwargs(kwargs) super().__init__(kwargs.pop("device", None)) try: self.metric = TorchMetrics[metric_name](**kwargs) diff --git a/src/pruna/evaluation/task.py b/src/pruna/evaluation/task.py index 0ae4ba8a..3e0866b5 100644 --- a/src/pruna/evaluation/task.py +++ b/src/pruna/evaluation/task.py @@ -27,6 +27,7 @@ from pruna.evaluation.metrics.metric_torch import TorchMetricWrapper from pruna.evaluation.metrics.registry import MetricRegistry from pruna.evaluation.metrics.utils import get_hyperparameters +from pruna.evaluation.metrics.vlm_base import VLM_METRIC_REGISTRY_NAMES from pruna.logging.logger import pruna_logger AVAILABLE_REQUESTS = ("image_generation_quality", "text_generation_quality") @@ -102,6 +103,20 @@ def from_benchmark( dataloader_args=dataloader_args or {}, **kwargs, ) + if benchmark.lookup_key == "GenEval": + return cls( + request=[ + MetricRegistry.get_metric( + "qa_accuracy", + aggregation="all_or_nothing", + model_name="openai/gpt-4o", + ), + MetricRegistry.get_metric("clip_score"), + ], + datamodule=datamodule, + device=device, + low_memory=low_memory, + ) return cls( request=benchmark.metrics, datamodule=datamodule, @@ -295,9 +310,16 @@ def _process_metric_names( for metric_name in request: metric_name = cast(str, metric_name) new_requests.append(cast(str, metric_name)) - return MetricRegistry.get_metrics( - names=new_requests, inference_device=inference_device, stateful_metric_device=stateful_metric_device - ) + out: List[BaseMetric | StatefulMetric] = [] + for name in new_requests: + kwargs: dict[str, Any] = { + "inference_device": inference_device, + "stateful_metric_device": stateful_metric_device, + } + if name in VLM_METRIC_REGISTRY_NAMES: + kwargs["model_name"] = "openai/gpt-4o" + out.append(MetricRegistry.get_metric(name, **kwargs)) + return out def _get_lm_eval_task_metrics(task_name: str): diff --git a/tests/data/test_datamodule.py b/tests/data/test_datamodule.py index 103cadfb..329a44f4 100644 --- a/tests/data/test_datamodule.py +++ b/tests/data/test_datamodule.py @@ -1,3 +1,4 @@ +import importlib.util from typing import Any, Callable import pytest @@ -59,12 +60,18 @@ def _assert_at_least_one_sample(datamodule: PrunaDataModule) -> None: pytest.param("GenAIBench", dict(), 
marks=pytest.mark.slow), pytest.param("TinyIMDB", dict(tokenizer=bert_tokenizer), marks=pytest.mark.slow), pytest.param("VBench", dict(), marks=pytest.mark.slow), - pytest.param("GenEval", dict(), marks=pytest.mark.slow), pytest.param("HPS", dict(), marks=pytest.mark.slow), pytest.param("ImgEdit", dict(), marks=pytest.mark.slow), pytest.param("LongTextBench", dict(), marks=pytest.mark.slow), pytest.param("GEditBench", dict(), marks=pytest.mark.slow), pytest.param("OneIG", dict(), marks=pytest.mark.slow), + pytest.param("OneIGAnimeStylization", dict(), marks=pytest.mark.slow), + pytest.param("OneIGGeneralObject", dict(), marks=pytest.mark.slow), + pytest.param("OneIGKnowledgeReasoning", dict(), marks=pytest.mark.slow), + pytest.param("OneIGMultilingualism", dict(), marks=pytest.mark.slow), + pytest.param("OneIGPortrait", dict(), marks=pytest.mark.slow), + pytest.param("OneIGTextRendering", dict(), marks=pytest.mark.slow), + pytest.param("GenEval", dict(), marks=pytest.mark.slow), pytest.param("DPG", dict(), marks=pytest.mark.slow), ], ) @@ -104,30 +111,32 @@ def test_dm_from_dataset(setup_fn: Callable, collate_fn: str, collate_fn_args: d iterate_dataloaders(datamodule) -def _benchmarks_with_category() -> list[tuple[str, str]]: - """Benchmarks that have a category param: (dataset_name, category) for every category.""" +def _benchmark_category_smoke() -> list[tuple[str, str]]: + """One (dataset, category) per benchmark that exposes a ``category`` parameter.""" result = [] - for name in base_datasets: + for name in sorted(base_datasets): + if name == "VBench" and importlib.util.find_spec("vbench") is None: + continue setup_fn = base_datasets[name][0] literal_values = get_literal_values_from_param(setup_fn, "category") if literal_values: - for cat in literal_values: - result.append((name, cat)) + result.append((name, sorted(literal_values)[0])) return result @pytest.mark.cpu @pytest.mark.slow -@pytest.mark.parametrize("dataset_name, category", _benchmarks_with_category()) +@pytest.mark.parametrize("dataset_name, category", _benchmark_category_smoke()) def test_benchmark_category_filter(dataset_name: str, category: str) -> None: - """Test dataset loading with each category filter; dataset has at least one sample.""" + """Category filter loads and batches match the chosen category (one category per dataset).""" dm = PrunaDataModule.from_string(dataset_name, category=category, dataloader_args={"batch_size": 4}) _assert_at_least_one_sample(dm) dm.limit_datasets(10) batch = next(iter(dm.test_dataloader())) prompts, auxiliaries = batch - assert len(prompts) == 4 + # Some categories have fewer than 4 samples; assert at least one rather than exactly four. 
+ assert 1 <= len(prompts) <= 4 assert all(isinstance(p, str) for p in prompts) def _category_in_aux(aux: dict, cat: str) -> bool: @@ -143,20 +152,17 @@ def _category_in_aux(aux: dict, cat: str) -> bool: @pytest.mark.cpu @pytest.mark.slow -@pytest.mark.parametrize( - "dataset_name, required_aux_key", - [ +def test_prompt_benchmark_auxiliaries() -> None: + """Prompt-based benchmarks expose expected aux keys.""" + for dataset_name, required_aux_key in ( ("LongTextBench", "text_content"), ("OneIG", "text_content"), - ], -) -def test_prompt_benchmark_auxiliaries(dataset_name: str, required_aux_key: str) -> None: - """Test prompt-based benchmarks load with expected auxiliaries.""" - dm = PrunaDataModule.from_string(dataset_name, dataloader_args={"batch_size": 4}) - dm.limit_datasets(10) - batch = next(iter(dm.test_dataloader())) - prompts, auxiliaries = batch - - assert len(prompts) == 4 - assert all(isinstance(p, str) for p in prompts) - assert all(required_aux_key in aux for aux in auxiliaries) + ): + dm = PrunaDataModule.from_string(dataset_name, dataloader_args={"batch_size": 4}) + dm.limit_datasets(10) + batch = next(iter(dm.test_dataloader())) + prompts, auxiliaries = batch + + assert len(prompts) == 4 + assert all(isinstance(p, str) for p in prompts) + assert all(required_aux_key in aux for aux in auxiliaries) diff --git a/tests/evaluation/_vlm_batch_snapshot_helpers.py b/tests/evaluation/_vlm_batch_snapshot_helpers.py new file mode 100644 index 00000000..f10eb3c8 --- /dev/null +++ b/tests/evaluation/_vlm_batch_snapshot_helpers.py @@ -0,0 +1,140 @@ +"""Test-only helpers: placeholder ``pred`` tensors from aux + JSON snapshot records.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +import torch + +from pruna.evaluation.metrics.result import MetricResult +from pruna.evaluation.metrics.vlm_utils import pil_rgb_from_aux_image_bytes + + +def pred_tensor_from_auxiliaries( + auxiliaries: list[Any], + size: int = 224, + *, + require_source_image: bool = False, +) -> torch.Tensor: + """Build a float pred batch from aux dicts (tests only; uses :func:`pil_rgb_from_aux_image_bytes`).""" + from torchvision.transforms import CenterCrop, Compose, Resize, ToTensor + + transform = Compose([Resize(size), CenterCrop(size), ToTensor()]) + tensors = [] + for aux in auxiliaries: + if not isinstance(aux, dict): + if require_source_image: + raise ValueError("require_source_image=True but auxiliary entry is not a dict.") + tensors.append(torch.rand(3, size, size)) + continue + + pil = pil_rgb_from_aux_image_bytes(aux, min_bytes_in_value_scan=0) + if pil is not None: + try: + tensors.append(transform(pil)) + continue + except Exception: + pass + + if require_source_image: + raise ValueError(f"require_source_image=True but no decodable image bytes found (keys: {list(aux.keys())}).") + tensors.append(torch.rand(3, size, size)) + return torch.stack(tensors) + + +@dataclass(frozen=True) +class BenchmarkVlmBatchOutcome: + """Minimal batch + metric result for snapshot tests.""" + + result: MetricResult + prompts: list[Any] + auxiliaries: list[Any] + pred: torch.Tensor + + +def _short(obj: Any, max_len: int = 400) -> Any: + if isinstance(obj, str) and len(obj) > max_len: + return obj[:max_len] + "…" + return obj + + +def _question_value_for_record(qt: Any, max_len: int = 200) -> Any: + if qt is None: + return None + if isinstance(qt, str): + return _short(qt, max_len) + return _short(str(qt), max_len) + + +def _aux_for_record(aux: dict[str, Any]) -> dict[str, Any]: 
+ out: dict[str, Any] = {} + for k, v in aux.items(): + if k == "questions" and isinstance(v, dict): + out[k] = {qk: _question_value_for_record(qt, 200) for qk, qt in list(v.items())[:24]} + if len(v) > 24: + out["_truncated_questions"] = len(v) - 24 + else: + out[k] = _short(v) if isinstance(v, str) else v + return out + + +def safe_json_for_snapshot(obj: Any) -> Any: + """Recursively JSON-safe view (bytes → length, tensors → shape).""" + if obj is None or isinstance(obj, (bool, int, float, str)): + return obj + if isinstance(obj, bytes): + return {"bytes_len": len(obj)} + if isinstance(obj, dict): + return {str(k): safe_json_for_snapshot(v) for k, v in obj.items()} + if isinstance(obj, (list, tuple)): + return [safe_json_for_snapshot(x) for x in obj] + if isinstance(obj, torch.Tensor): + return {"tensor_shape": list(obj.shape), "dtype": str(obj.dtype)} + return str(obj) + + +def _metric_result_record(mr: MetricResult) -> dict[str, Any]: + return { + "name": mr.name, + "result": float(mr.result), + "higher_is_better": mr.higher_is_better, + "metric_units": mr.metric_units, + } + + +def vlm_benchmark_batch_to_json_record( + outcome: BenchmarkVlmBatchOutcome, + *, + benchmark_key: str, + benchmark_name: str, + metric_name: str, + vlm_type: str, + model_name: str, + device: str, + pred_note: str | None = "random noise placeholder", +) -> dict[str, Any]: + """Build a JSON-serializable snapshot (used only in tests).""" + a0 = outcome.auxiliaries[0] if outcome.auxiliaries and isinstance(outcome.auxiliaries[0], dict) else {} + pred_payload: dict[str, Any] = { + "shape": list(outcome.pred.shape), + "dtype": str(outcome.pred.dtype), + } + if pred_note is not None: + pred_payload["note"] = pred_note + record: dict[str, Any] = { + "benchmark_lookup_key": benchmark_key, + "benchmark_name": benchmark_name, + "metric_name": metric_name, + "dataset_name": benchmark_key, + "vlm_type": vlm_type, + "model_name": model_name, + "device": device, + "inputs": { + "prompts": [_short(p, 500) for p in outcome.prompts], + "auxiliary_0": _aux_for_record(a0) if isinstance(a0, dict) else safe_json_for_snapshot(a0), + }, + "pred": pred_payload, + "metric_result": _metric_result_record(outcome.result), + } + return safe_json_for_snapshot(record) diff --git a/tests/evaluation/test_task.py b/tests/evaluation/test_task.py index 67d3aff0..efc8cfa7 100644 --- a/tests/evaluation/test_task.py +++ b/tests/evaluation/test_task.py @@ -1,5 +1,5 @@ import pytest -from unittest.mock import patch +from unittest.mock import MagicMock, patch from transformers import AutoTokenizer from pruna.evaluation.task import Task from pruna.data.pruna_datamodule import PrunaDataModule @@ -36,10 +36,10 @@ def make_mock_metric(metric_class): with patch.object(TorchMetrics, '_member_map_', {**TorchMetrics._member_map_, **mock_metrics}): yield -@pytest.mark.parametrize("metric_name", MetricRegistry()._registry) +@pytest.mark.parametrize("metric_name", sorted(MetricRegistry._registry)) def test_metric_initialization_from_metric_name(metric_name): datamodule = PrunaDataModule.from_string("LAION256") - Task(request=[metric_name], datamodule=datamodule) + Task(request=[metric_name], datamodule=datamodule, device="cpu") @device_parametrized @@ -124,3 +124,16 @@ def test_task_invalid_named_request(): """Test that an invalid named request raises a ValueError.""" with pytest.raises(ValueError, match="not found"): Task(request="nonexistent_quality", datamodule=PrunaDataModule.from_string("LAION256"), device="cpu") + + +@pytest.mark.cpu 
+@patch("pruna.evaluation.task.PrunaDataModule.from_string") +def test_geneval_from_benchmark_uses_qa_accuracy_all_or_nothing(mock_from_string: MagicMock) -> None: + """GenEval uses strict per-image QA aggregation and CLIP.""" + mock_dm = MagicMock() + mock_dm.test_dataloader.return_value = iter([]) + mock_from_string.return_value = mock_dm + task = Task.from_benchmark("GenEval", dataloader_args={"batch_size": 1}) + qa = next(m for m in task.metrics if getattr(m, "metric_name", None) == "qa_accuracy") + assert qa.aggregation == "all_or_nothing" + assert any(getattr(m, "metric_name", None) == "clip_score" for m in task.metrics) diff --git a/tests/evaluation/test_vlm_e2e.py b/tests/evaluation/test_vlm_e2e.py new file mode 100644 index 00000000..a4eaa139 --- /dev/null +++ b/tests/evaluation/test_vlm_e2e.py @@ -0,0 +1,684 @@ +"""Tests for VLM metrics (VQA, ImageEditScore, QAAccuracy, TextScore, VieScore) and vlm_utils helpers.""" + +from unittest.mock import MagicMock, patch + +import pytest +import torch + +from pruna.evaluation.metrics.metric_img_edit_score import ImageEditScoreMetric +from pruna.evaluation.metrics.metric_oneig_alignment import OneIGAlignmentMetric +from pruna.evaluation.metrics.metric_qa_accuracy import QAAccuracyMetric +from pruna.evaluation.metrics.metric_text_score import OneIGTextScoreMetric, TextScoreMetric +from pruna.evaluation.metrics.metric_vie_score import VieScoreMetric +from pruna.evaluation.metrics.metric_vqa import VQAMetric +from pruna.evaluation.metrics.result import MetricResult +from pruna.evaluation.metrics.vlm_base import BaseVLM, get_vlm +from pruna.evaluation.metrics.vlm_utils import ( + FloatOutput, + VLM_AUX_IMAGE_BYTES_KEY_ORDER, + get_score_from_response, + yes_no_first_token_id_groups, +) + +from ._vlm_batch_snapshot_helpers import ( + BenchmarkVlmBatchOutcome, + pred_tensor_from_auxiliaries, + safe_json_for_snapshot, + vlm_benchmark_batch_to_json_record, +) + +SMOL_VLM = "HuggingFaceTB/SmolVLM-256M-Instruct" + +_ALL_VLM = ( + VQAMetric, + ImageEditScoreMetric, + QAAccuracyMetric, + OneIGAlignmentMetric, + TextScoreMetric, + OneIGTextScoreMetric, + VieScoreMetric, +) + +_SLOW_SMOL_SUBSET = ( + VQAMetric, + OneIGAlignmentMetric, + ImageEditScoreMetric, + VieScoreMetric, +) + + +@pytest.mark.parametrize( + ("raw", "expected"), + [ + (FloatOutput(score=8.0), 0.8), + ({"score": 5.0}, 0.5), + ('{"score": 7.5}', 0.75), + ('{"score": 10}', 1.0), + ("8", 0.8), + ("Score: 7.5 out of 10", 0.75), + ("", 0.0), + ], +) +def test_get_score_from_response(raw: object, expected: float) -> None: + """``get_score_from_response`` maps pydantic, dict, JSON, and text to ``[0, 1]``.""" + assert get_score_from_response(raw) == pytest.approx(expected) + + +def _dummy_image(batch: int = 1, size: int = 224) -> torch.Tensor: + return torch.rand(batch, 3, size, size) + + +def _update_metric(metric: object, prompts: list, images: torch.Tensor) -> None: + if isinstance(metric, OneIGAlignmentMetric): + metric.update( + prompts, + [ + { + "questions": {"1": "Is there a cat?", "2": "Is it sleeping?"}, + "dependencies": {"1": [0], "2": [1]}, + } + ], + images, + ) + elif isinstance(metric, QAAccuracyMetric): + metric.update( + prompts, + [{"questions": {"1": "Is there a cat?"}}], + images, + ) + elif isinstance(metric, (TextScoreMetric, OneIGTextScoreMetric)): + metric.update(prompts, ["cat"], images) + else: + metric.update(prompts, images, images) + + +@pytest.mark.cpu +@pytest.mark.slow +@pytest.mark.parametrize("metric_cls", _SLOW_SMOL_SUBSET) +def 
test_vlm_metrics_transformers_smolvlm(metric_cls: type) -> None: + """Smoke-test a subset with local SmolVLM (full matrix covered by litellm mock).""" + metric = metric_cls( + vlm_type="transformers", + model_name=SMOL_VLM, + device="cpu", + structured_output=True, + ) + images = _dummy_image(batch=1) + prompts = ["a cat"] + _update_metric(metric, prompts, images) + result = metric.compute() + assert result.name == metric.metric_name + assert isinstance(result.result, float) + if metric.higher_is_better: + assert 0.0 <= result.result <= 1.0 + else: + assert result.result >= 0.0 + + +@pytest.mark.cpu +@pytest.mark.parametrize("metric_cls", _ALL_VLM) +def test_vlm_metrics_litellm_mocked(metric_cls: type) -> None: + """Each VLM metric runs end-to-end with mocked litellm.""" + pytest.importorskip("litellm") + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + if metric_cls in (VQAMetric, QAAccuracyMetric, OneIGAlignmentMetric): + mock_response.choices[0].message.content = '{"answer": "Yes"}' + else: + mock_response.choices[0].message.content = '{"score": 8}' + + with patch("litellm.completion") as mock_completion: + mock_completion.return_value = mock_response + + metric = metric_cls( + vlm_type="litellm", + model_name="gpt-4o", + device="cpu", + structured_output=True, + ) + images = _dummy_image(batch=1) + prompts = ["a cat"] + _update_metric(metric, prompts, images) + result = metric.compute() + + assert result.name == metric.metric_name + assert isinstance(result.result, float) + assert mock_completion.called + + +@pytest.mark.cpu +def test_vlm_metrics_empty_compute_returns_zero() -> None: + """No updates → compute is 0.0 (same for all stateful VLM metrics).""" + metric = VQAMetric( + vlm_type="transformers", + model_name=SMOL_VLM, + device="cpu", + structured_output=True, + ) + assert metric.compute().result == 0.0 + + +@pytest.mark.cpu +def test_vlm_metrics_custom_vlm() -> None: + """Custom VLM passed to VQAMetric is used instead of the default litellm backend.""" + mock_vlm = MagicMock(spec=BaseVLM) + mock_vlm.generate.return_value = ["Yes"] + mock_vlm.score.return_value = [1.0] + + metric = VQAMetric(vlm=mock_vlm, vlm_type="litellm", device="cpu", structured_output=True) + images = _dummy_image(batch=1) + prompts = ["a cat"] + metric.update(prompts, images, images) + assert metric.compute().result == 1.0 + mock_vlm.score.assert_called() + + +@pytest.mark.cpu +def test_get_vlm_returns_custom() -> None: + """get_vlm returns the provided VLM instance unchanged.""" + custom = MagicMock(spec=BaseVLM) + out = get_vlm(vlm=custom, vlm_type="litellm", model_name="gpt-4o") + assert out is custom + + +@pytest.mark.cpu +def test_yes_no_first_token_id_groups_disjoint() -> None: + """Prefix token ids for Yes vs No should not overlap (avoids double-counting).""" + pytest.importorskip("transformers") + from transformers import AutoTokenizer + + tok = AutoTokenizer.from_pretrained("gpt2") + yes_ids, no_ids = yes_no_first_token_id_groups(tok) + assert yes_ids and no_ids + assert not (set(yes_ids) & set(no_ids)) + + +@pytest.mark.cpu +def test_get_vlm_requires_model_name_without_vlm() -> None: + """get_vlm raises ValueError when no model_name is given and no vlm is provided.""" + with pytest.raises(ValueError, match="model_name"): + get_vlm(vlm=None, vlm_type="litellm") + + +@pytest.mark.cpu +@pytest.mark.parametrize( + "metric_cls, expected_name, expected_result", + [ + (TextScoreMetric, "text_score", 1.0), + (OneIGTextScoreMetric, "oneig_text_score", 1.0), + ], +) +def 
test_text_metrics_list_str_gt(metric_cls: type, expected_name: str, expected_result: float) -> None: + """Text metrics accept plain string ground-truth and return the expected score.""" + mock_vlm = MagicMock(spec=BaseVLM) + mock_vlm.generate.return_value = ["hello world"] + + metric = metric_cls(vlm=mock_vlm, vlm_type="litellm", device="cpu") + images = _dummy_image(batch=1) + metric.update(["a prompt"], ["hello world"], images) + result = metric.compute() + + assert result.result == expected_result + assert result.name == expected_name + mock_vlm.generate.assert_called_once() + + +@pytest.mark.cpu +def test_text_score_result_in_zero_one_range() -> None: + """TextScoreMetric must return a normalized score in [0, 1], not raw edit distance.""" + mock_vlm = MagicMock(spec=BaseVLM) + # VLM OCR returns something very different from ground truth (high edit distance) + mock_vlm.generate.return_value = ["completely wrong text abcdefghijklmnop"] + + metric = TextScoreMetric(vlm=mock_vlm, device="cpu") + images = _dummy_image(batch=1) + metric.update(["prompt"], ["hello"], images) + result = metric.compute() + + assert 0.0 <= result.result <= 1.0, f"TextScoreMetric must return [0,1], got {result.result}" + assert result.result < 0.5, f"Very different strings should score below 0.5, got {result.result}" + + +@pytest.mark.cpu +def test_text_score_perfect_match_is_one() -> None: + """TextScoreMetric: identical OCR and GT -> score 1.0.""" + mock_vlm = MagicMock(spec=BaseVLM) + mock_vlm.generate.return_value = ["hello world"] + + metric = TextScoreMetric(vlm=mock_vlm, device="cpu") + images = _dummy_image(batch=1) + metric.update(["prompt"], ["hello world"], images) + result = metric.compute() + + assert result.result == 1.0, f"Perfect match should give 1.0, got {result.result}" + assert result.higher_is_better is True + + +@pytest.mark.cpu +def test_text_score_registry_aliases() -> None: + """Registry aliases ocr_levenshtein and ocr_text_score resolve to the correct metric classes.""" + from pruna.evaluation.metrics.registry import MetricRegistry + + lev = MetricRegistry.get_metric("ocr_levenshtein", device="cpu", model_name="openai/gpt-4o") + comp = MetricRegistry.get_metric("ocr_text_score", device="cpu", model_name="openai/gpt-4o") + assert type(lev).__name__ == "TextScoreMetric" + assert type(comp).__name__ == "OneIGTextScoreMetric" + assert lev.metric_name == "text_score" + assert comp.metric_name == "oneig_text_score" + + +@pytest.mark.cpu +def test_oneig_text_score_utils_golden_composite() -> None: + """oneig_mean_text_score returns expected component values for a known input.""" + from pruna.evaluation.metrics.metric_text_score_utils import oneig_mean_text_score + + ed, cr, wac, composite = oneig_mean_text_score( + edit_distances=[10.0], + completion_ratios=[0.0], + match_counts=[2], + gt_totals=[4], + language_mode="EN", + ) + assert ed == 10.0 + assert cr == 0.0 + assert wac == 0.5 + assert composite == pytest.approx(0.95) + + _, _, _, zh = oneig_mean_text_score( + edit_distances=[30.0], + completion_ratios=[0.0], + match_counts=[0], + gt_totals=[1], + language_mode="ZH", + ) + assert zh == pytest.approx(0.4) + + +@pytest.mark.cpu +def test_qa_accuracy_all_or_nothing_partial_fail() -> None: + """all_or_nothing: if any question scores 0, the image score is 0.0 (not a partial mean).""" + mock_vlm = MagicMock(spec=BaseVLM) + # First question Yes (1.0), second question No (0.0) → mean=0.5, all_or_nothing=0.0 + mock_vlm.score.return_value = [1.0, 0.0] + + metric = QAAccuracyMetric(vlm=mock_vlm, 
device="cpu", aggregation="all_or_nothing") + metric.update( + ["a prompt"], + [{"questions": {"1": "Is there a cat?", "2": "Is it blue?"}}], + _dummy_image(batch=1), + ) + result = metric.compute() + assert result.result == 0.0, f"Expected 0.0 for all_or_nothing with one No, got {result.result}" + + +@pytest.mark.cpu +def test_qa_accuracy_all_or_nothing_all_yes() -> None: + """all_or_nothing: all Yes → score 1.0.""" + mock_vlm = MagicMock(spec=BaseVLM) + mock_vlm.score.return_value = [1.0, 1.0] + + metric = QAAccuracyMetric(vlm=mock_vlm, device="cpu", aggregation="all_or_nothing") + metric.update( + ["a prompt"], + [{"questions": {"1": "Is there a cat?", "2": "Is it blue?"}}], + _dummy_image(batch=1), + ) + result = metric.compute() + assert result.result == 1.0, f"Expected 1.0 for all_or_nothing with all Yes, got {result.result}" + + +@pytest.mark.cpu +def test_qa_accuracy_invalid_aggregation_raises() -> None: + """qa_accuracy rejects aggregation values other than mean / all_or_nothing.""" + mock_vlm = MagicMock(spec=BaseVLM) + with pytest.raises(ValueError, match="aggregation"): + QAAccuracyMetric(vlm=mock_vlm, device="cpu", aggregation="median") + + +@pytest.mark.cpu +def test_vie_score_tie_uses_source_from_gt_and_two_image_sc() -> None: + """With ``source_image_bytes`` in gt, VieScore calls two-image SC then PQ on the edited image.""" + from io import BytesIO + + from PIL import Image + + buf = BytesIO() + Image.new("RGB", (8, 8), color=(0, 0, 200)).save(buf, format="PNG") + src_bytes = buf.getvalue() + + mock_vlm = MagicMock() + mock_vlm.generate_with_image_lists.return_value = ['{"score": [8.0, 8.0], "reasoning": "ok"}'] + mock_vlm.generate.return_value = ['{"score": [9.0, 9.0], "reasoning": "ok"}'] + + metric = VieScoreMetric(vlm=mock_vlm, device="cpu", structured_output=True) + pred = _dummy_image(batch=1) + metric.update( + ["make the sky purple"], + [{"source_image_bytes": src_bytes}], + pred, + ) + result = metric.compute() + + assert mock_vlm.generate_with_image_lists.called + assert mock_vlm.generate.called + assert 0.0 <= result.result <= 1.0 + + +@pytest.mark.cpu +def test_vie_score_uses_get_score_from_response() -> None: + """VieScoreMetric ``t2i`` path parses JSON ``score`` lists via ``viescore_min_scores_0_10``.""" + mock_vlm = MagicMock(spec=BaseVLM) + # LitellmVLM returns model_dump_json() for structured outputs → JSON string (two SC + two PQ sub-scores) + mock_vlm.generate.return_value = ['{"score": [8.0, 8.0], "reasoning": ""}'] + + metric = VieScoreMetric(vlm=mock_vlm, device="cpu", structured_output=True) + metric.update(["a cat on a sofa"], _dummy_image(batch=1), _dummy_image(batch=1)) + result = metric.compute() + + # min(SC)=8, min(PQ)=8 → sqrt(8 * 8) / 10 = 0.8 + assert abs(result.result - 0.8) < 0.01, f"Expected ~0.8, got {result.result}" + + +@pytest.mark.cpu +def test_img_edit_score_negative_response_clamped() -> None: + """img_edit_score must be non-negative even when the VLM generates a negative JSON score. + + Regression for: Outlines constrained decoding can emit {"score": -10} despite the + FloatOutput JSON schema specifying minimum=0, because Outlines does not enforce numeric + bounds during token sampling. The fix is max(0.0, ...) in get_score_from_response. 
+ """ + mock_vlm = MagicMock(spec=BaseVLM) + # Simulate Outlines generating a negative value (the bug scenario) + mock_vlm.generate.return_value = ['{"score": -10.0}'] + + metric = ImageEditScoreMetric(vlm=mock_vlm, device="cpu", structured_output=True) + metric.update(["replace the boot with a mug"], torch.zeros(1), _dummy_image(batch=1)) + result = metric.compute() + + assert result.result >= 0.0, f"img_edit_score must be >= 0, got {result.result}" + + +@pytest.mark.cpu +def test_qa_accuracy_all_or_nothing_ambiguous_score() -> None: + """all_or_nothing: score exactly 0.5 (ambiguous) is treated as No → result 0.0.""" + mock_vlm = MagicMock(spec=BaseVLM) + mock_vlm.score.return_value = [0.5] + + metric = QAAccuracyMetric(vlm=mock_vlm, device="cpu", aggregation="all_or_nothing") + metric.update( + ["a prompt"], + [{"questions": {"1": "Is there a cat?"}}], + _dummy_image(batch=1), + ) + result = metric.compute() + assert result.result == 0.0, f"Score 0.5 should be treated as No (ambiguous), got {result.result}" + + +@pytest.mark.cpu +@pytest.mark.slow +def test_yes_no_token_ids_smolvlm_nonempty() -> None: + """SmolVLM tokenizer must yield non-empty disjoint yes/no prefix ids for VQAScore scoring.""" + pytest.importorskip("transformers") + from transformers import AutoTokenizer + + tok = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolVLM-256M-Instruct") + yes_ids, no_ids = yes_no_first_token_id_groups(tok) + assert len(yes_ids) > 0, "SmolVLM tokenizer has no 'Yes'-prefix token ids" + assert len(no_ids) > 0, "SmolVLM tokenizer has no 'No'-prefix token ids" + assert not (set(yes_ids) & set(no_ids)), "yes_ids and no_ids must be disjoint" + + +@pytest.mark.cpu +def test_img_edit_score_uses_prompt_from_x() -> None: + """img_edit_score must score the edited image against the instruction from x, not gt.""" + mock_vlm = MagicMock(spec=BaseVLM) + mock_vlm.generate.return_value = ['{"score": 9}'] + + metric = ImageEditScoreMetric(vlm=mock_vlm, device="cpu") + pred = _dummy_image(batch=1) + metric.update( + ["replace the cat with a dog"], # x = instruction + pred, # gt = unused for y_x + pred, # outputs = edited image + ) + result = metric.compute() + + call_args = mock_vlm.generate.call_args + prompt_sent = call_args[0][1][0] # second positional arg = prompts list, first item + assert "replace the cat with a dog" in prompt_sent, f"Instruction not in VLM prompt. Got: {prompt_sent}" + assert abs(result.result - 0.9) < 0.01, f"Expected ~0.9, got {result.result}" + + +@pytest.mark.cpu +def test_vie_score_geditbench_gap_documented() -> None: + """VieScoreMetric infers text--image editing from ``source_image_bytes`` in aux (no ``task_type``). + + This test fails if a ``task_type`` parameter is added to ``__init__`` without updating + GEditBench integration tests and benchmark copy accordingly. + """ + import inspect + + sig = inspect.signature(VieScoreMetric.__init__) + assert "task_type" not in sig.parameters, ( + "VieScoreMetric now has task_type — update GEditBench docs and e2e tests, then remove this sentinel." 
+    )
+
+
+@pytest.mark.cpu
+def test_litellm_logprob_aggregation_sums_all_yes_tokens() -> None:
+    """LitellmVLM logprob scoring must sum all yes-prefix token probs, not return the first."""
+    pytest.importorskip("litellm")
+    import math
+
+    import numpy as np
+    from PIL import Image
+
+    from pruna.evaluation.metrics.vlm_base import LitellmVLM
+
+    # Simulate top_logprobs for first output token:
+    #   "Yes" → logprob=-2.303 (p≈0.10), " yes" → logprob=-2.996 (p≈0.05) → total p_yes≈0.15
+    #   "No" → logprob=-1.609 (p≈0.20), " no" → logprob=-2.303 (p≈0.10) → total p_no≈0.30
+    #   normalized: p_yes/(p_yes+p_no) ≈ 0.15/0.45 ≈ 0.333
+    def make_top_logprob(token, logprob):
+        t = MagicMock()
+        t.token = token
+        t.logprob = logprob
+        return t
+
+    first_tok = MagicMock()
+    first_tok.top_logprobs = [
+        make_top_logprob("Yes", math.log(0.10)),
+        make_top_logprob(" yes", math.log(0.05)),
+        make_top_logprob("No", math.log(0.20)),
+        make_top_logprob(" no", math.log(0.10)),
+        make_top_logprob("maybe", math.log(0.55)),
+    ]
+
+    mock_logprobs = MagicMock()
+    mock_logprobs.content = [first_tok]
+
+    mock_choice = MagicMock()
+    mock_choice.logprobs = mock_logprobs
+    mock_choice.message.content = "Yes"
+
+    mock_response = MagicMock()
+    mock_response.choices = [mock_choice]
+
+    with patch("litellm.completion", return_value=mock_response):
+        vlm = LitellmVLM(model_name="openai/gpt-4o")
+        img = Image.fromarray(np.zeros((32, 32, 3), dtype="uint8"))
+        score = vlm._score_with_logprobs(img, "Is there a cat?", "Yes")
+
+    # Should be ~0.333 (p_yes=0.15 / (p_yes+p_no)=0.45), not just 0.10 (first match)
+    assert 0.28 < score < 0.40, f"Expected ~0.333 (sum-normalized), got {score}"
+
+
+@pytest.mark.cpu
+@pytest.mark.slow
+def test_vqa_probability_score_normalized() -> None:
+    """P(Yes) from TransformersVLM.score use_probability=True is in [0, 1]."""
+    pytest.importorskip("transformers")
+    import numpy as np
+    from PIL import Image
+
+    from pruna.evaluation.metrics.vlm_base import TransformersVLM
+
+    vlm = TransformersVLM(
+        model_name="HuggingFaceTB/SmolVLM-256M-Instruct",
+        device="cpu",
+        use_outlines=False,
+    )
+    img = Image.fromarray(np.zeros((32, 32, 3), dtype="uint8"))
+    scores = vlm.score([img], ["Is there a cat?"], ["Yes"], use_probability=True)
+    assert len(scores) == 1
+    assert 0.0 <= scores[0] <= 1.0, f"P(Yes) must be in [0, 1], got {scores[0]}"
+
+
+# ---------------------------------------------------------------------------
+# vlm_benchmark_batch_to_json_record serialization tests
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.cpu
+def test_vlm_benchmark_batch_to_json_record_serializes_batch() -> None:
+    """Record includes prompts, pred shape, and metric fields."""
+    mr = MetricResult(name="qa_accuracy", params={}, result=0.25, higher_is_better=True)
+    outcome = BenchmarkVlmBatchOutcome(
+        result=mr,
+        prompts=["prompt"],
+        auxiliaries=[{"path": "/tmp/x.png"}],
+        pred=torch.zeros(1, 3, 8, 8),
+    )
+    rec = vlm_benchmark_batch_to_json_record(
+        outcome,
+        benchmark_key="GenEval",
+        benchmark_name="GenEval",
+        metric_name="qa_accuracy",
+        vlm_type="transformers",
+        model_name="m",
+        device="cpu",
+    )
+    assert rec["inputs"]["prompts"] == ["prompt"]
+    assert rec["pred"]["shape"] == [1, 3, 8, 8]
+    assert rec["metric_result"]["result"] == 0.25
+
+
+@pytest.mark.cpu
+def test_safe_json_handles_bytes_without_expanding() -> None:
+    """Bytes values in aux (e.g. source_image_bytes) are summarized, not expanded to str repr."""
+    result = safe_json_for_snapshot({"source_image_bytes": b"\xff\xd8\xff" * 1000, "name": "test"})
+    assert result["source_image_bytes"] == {"bytes_len": 3000}
+    assert result["name"] == "test"
+
+
+@pytest.mark.cpu
+def test_vlm_benchmark_batch_to_json_record_preserves_null_question_slots() -> None:
+    """Padded ``None`` question labels stay JSON null, not the string ``"None"``."""
+    mr = MetricResult(name="oneig_alignment", params={}, result=1.0, higher_is_better=True)
+    outcome = BenchmarkVlmBatchOutcome(
+        result=mr,
+        prompts=["p"],
+        auxiliaries=[{"questions": {"1": "Are there boys?", "21": None}, "subset": "Anime_Stylization"}],
+        pred=torch.zeros(1, 3, 8, 8),
+    )
+    rec = vlm_benchmark_batch_to_json_record(
+        outcome,
+        benchmark_key="OneIGAnimeStylization",
+        benchmark_name="OneIG Anime Stylization",
+        metric_name="oneig_alignment",
+        vlm_type="transformers",
+        model_name="m",
+        device="cpu",
+    )
+    qs = rec["inputs"]["auxiliary_0"]["questions"]
+    assert qs["1"] == "Are there boys?"
+    assert qs["21"] is None
+
+
+# ---------------------------------------------------------------------------
+# pred_tensor_from_auxiliaries (test helper, wraps pil_rgb_from_aux_image_bytes) tests
+# ---------------------------------------------------------------------------
+
+
+def _make_jpeg_bytes(h: int = 32, w: int = 32) -> bytes:
+    """Return a tiny JPEG-encoded RGB image as bytes (test helper)."""
+    import io
+
+    import numpy as np
+    from PIL import Image
+
+    arr = (np.random.rand(h, w, 3) * 255).astype("uint8")
+    buf = io.BytesIO()
+    Image.fromarray(arr).save(buf, format="JPEG")
+    return buf.getvalue()
+
+
+@pytest.mark.cpu
+def test_pred_from_auxiliaries_uses_source_image_bytes() -> None:
+    """pred_tensor_from_auxiliaries decodes source_image_bytes into a float tensor in [0, 1]."""
+    src_bytes = _make_jpeg_bytes()
+    aux = [{"source_image_bytes": src_bytes, "category": "background_change"}]
+    pred = pred_tensor_from_auxiliaries(aux, size=64)
+
+    assert pred.shape == (1, 3, 64, 64), f"Expected (1,3,64,64), got {pred.shape}"
+    assert pred.min() >= 0.0 and pred.max() <= 1.0, "Pixel values must be in [0, 1]"
+
+
+@pytest.mark.cpu
+def test_pred_from_auxiliaries_falls_back_to_noise_without_source_image() -> None:
+    """pred_tensor_from_auxiliaries returns random noise when no source_image_bytes is present."""
+    aux = [{"category": "single_object"}]
+    pred = pred_tensor_from_auxiliaries(aux, size=32)
+    assert pred.shape == (1, 3, 32, 32)
+    assert pred.min() >= 0.0 and pred.max() <= 1.0
+
+
+@pytest.mark.cpu
+def test_pred_from_auxiliaries_mixed_batch() -> None:
+    """Batch with one source image and one missing falls back per-item."""
+    src_bytes = _make_jpeg_bytes()
+    aux = [
+        {"source_image_bytes": src_bytes, "category": "color_alter"},
+        {"category": "style_change"},  # no source image
+    ]
+    pred = pred_tensor_from_auxiliaries(aux, size=32)
+    assert pred.shape == (2, 3, 32, 32)
+    assert pred.min() >= 0.0 and pred.max() <= 1.0
+
+
+@pytest.mark.cpu
+def test_pred_from_auxiliaries_generic_bytes_scan() -> None:
+    """pred_tensor_from_auxiliaries discovers image bytes under an unknown field name (generic scan)."""
+    src_bytes = _make_jpeg_bytes()
+    aux = [{"my_custom_image_bytes": src_bytes, "category": "motion_change"}]
+    pred = pred_tensor_from_auxiliaries(aux, size=32)
+    assert pred.shape == (1, 3, 32, 32)
+    assert pred.min() >= 0.0 and pred.max() <= 1.0
+
+
+@pytest.mark.cpu
+def test_pred_from_auxiliaries_known_names_take_priority() -> None:
+    """Known field names are resolved before the generic bytes scan."""
+    src_bytes_known = _make_jpeg_bytes(16, 16)
+    src_bytes_unknown = _make_jpeg_bytes(32, 32)
+    first_known = VLM_AUX_IMAGE_BYTES_KEY_ORDER[0]
+    aux = [{"other_bytes": src_bytes_unknown, first_known: src_bytes_known}]
+    pred = pred_tensor_from_auxiliaries(aux, size=16)
+    # Should use the known key (16x16 image → 16x16 crop); generic scan would pick 32x32
+    assert pred.shape == (1, 3, 16, 16)
+
+
+@pytest.mark.cpu
+def test_pred_from_auxiliaries_require_source_image_raises_when_missing() -> None:
+    """require_source_image=True raises ValueError instead of silently returning noise."""
+    aux = [{"category": "replace"}]  # no image bytes
+    with pytest.raises(ValueError, match="require_source_image=True"):
+        pred_tensor_from_auxiliaries(aux, size=32, require_source_image=True)
+
+
+@pytest.mark.cpu
+def test_pred_from_auxiliaries_require_source_image_succeeds_when_present() -> None:
+    """require_source_image=True succeeds and decodes bytes when source_image_bytes is present."""
+    src_bytes = _make_jpeg_bytes()
+    aux = [{"source_image_bytes": src_bytes, "category": "replace"}]
+    pred = pred_tensor_from_auxiliaries(aux, size=32, require_source_image=True)
+    assert pred.shape == (1, 3, 32, 32)
+    assert pred.min() >= 0.0 and pred.max() <= 1.0
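+
+
+# The two sketches below pin down helper behavior that is directly visible in
+# _vlm_batch_snapshot_helpers: the torch.Tensor branch of safe_json_for_snapshot
+# and the optional "note" key controlled by pred_note. Expected values are read
+# off that implementation; the float32 default for torch.zeros is an assumption
+# about the test environment's torch configuration.
+@pytest.mark.cpu
+def test_safe_json_handles_tensors_as_shape_summary() -> None:
+    """Tensors are summarized as shape/dtype, mirroring the bytes summarization above."""
+    # Assumes torch.zeros defaults to float32, so str(dtype) == "torch.float32".
+    result = safe_json_for_snapshot({"pred": torch.zeros(2, 3), "name": "test"})
+    assert result["pred"] == {"tensor_shape": [2, 3], "dtype": "torch.float32"}
+    assert result["name"] == "test"
+
+
+@pytest.mark.cpu
+def test_vlm_benchmark_batch_to_json_record_omits_pred_note_when_none() -> None:
+    """pred_note=None leaves the optional "note" key out of the pred payload."""
+    mr = MetricResult(name="vqa", params={}, result=0.5, higher_is_better=True)
+    outcome = BenchmarkVlmBatchOutcome(
+        result=mr,
+        prompts=["p"],
+        auxiliaries=[{}],
+        pred=torch.zeros(1, 3, 8, 8),
+    )
+    rec = vlm_benchmark_batch_to_json_record(
+        outcome,
+        benchmark_key="GenEval",
+        benchmark_name="GenEval",
+        metric_name="vqa",
+        vlm_type="transformers",
+        model_name="m",
+        device="cpu",
+        pred_note=None,
+    )
+    assert "note" not in rec["pred"]
+    assert rec["pred"]["shape"] == [1, 3, 8, 8]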