diff --git a/src/pruna/evaluation/benchmarks.py b/src/pruna/evaluation/benchmarks.py
index 40ed999c..6ddf0001 100644
--- a/src/pruna/evaluation/benchmarks.py
+++ b/src/pruna/evaluation/benchmarks.py
@@ -18,6 +18,9 @@
 from pruna.data.utils import get_literal_values_from_param
 from pruna.evaluation.metrics import MetricRegistry
 
+TASK_TYPE_TEXT_IMAGE = "text_to_image"
+TASK_TYPE_TEXT_PLUS_IMAGE_IMAGE = "text+image_image"
+
 
 @dataclass
 class Benchmark:
@@ -31,9 +34,11 @@ class Benchmark:
     description : str
         Description of what the benchmark evaluates.
     metrics : list[str]
-        List of metric names used for evaluation.
+        Metric names from ``MetricRegistry`` that the ``reference`` paper
+        explicitly names for that benchmark.
     task_type : str
-        Type of task the benchmark evaluates (e.g., 'text_to_image').
+        Type of task the benchmark evaluates (e.g., ``text_to_image``,
+        ``text+image_image``, ``text_to_video``).
     reference : str | None
         URL to the canonical paper (e.g., arXiv) for this benchmark.
     """
@@ -62,24 +67,11 @@ class BenchmarkRegistry:
     """
     Registry for benchmarks.
 
-    Metrics per benchmark are set to those explicitly used in the reference
-    paper (see reference URL). All entries verified from paper evaluation
-    sections (ar5iv/HTML or PDF) as of verification pass:
-
-    - Parti Prompts (2206.10789 §5.2, §5.4): human side-by-side only on P222.
-    - DrawBench (2205.11487 §4.3): human raters only; COCO uses FID + CLIP.
-    - GenAI Bench (2406.13743): VQAScore only (web/PWC; ar5iv failed).
-    - VBench (2311.17982): 16 dimension-specific methods; no single Pruna metric.
-    - COCO (2205.11487 §4.1): FID and CLIP score for fidelity and alignment.
-    - ImageNet (1409.0575 §4): top-1/top-5 classification accuracy.
-    - WikiText (1609.07843 §5): perplexity on validation/test.
-    - GenEval (2310.11513 §3.2): Mask2Former + CLIP color pipeline, binary score.
-    - HPS (2306.09341): HPS v2 scoring model (CLIP fine-tuned on HPD v2).
-    - ImgEdit (2505.20275 §4.2): GPT-4o 1–5 ratings and ImgEdit-Judge.
-    - Long Text Bench (2507.22058 §4): Text Accuracy (OCR, Qwen2.5-VL-7B).
-    - GEditBench (2504.17761 §4.2): VIEScore (SQ, PQ, O via GPT-4.1/Qwen2.5-VL).
-    - OneIG (2506.07977 §4.1): per-dimension metrics (semantic alignment, ED, etc.).
-    - DPG (2403.05135): DSG-style graph score, mPLUG-large adjudicator.
+    Each entry's ``metrics`` lists only ``MetricRegistry`` names that have a
+    directly named counterpart in the benchmark reference paper (e.g.
+    CLIPScore -> ``clip_score``, VQAScore -> ``vqa``, FID -> ``fid``).
+    If the paper uses a method with no matching registered metric, the list is
+    kept empty and callers should pass explicit metrics to ``Task``.
     """
 
     _registry: dict[str, Benchmark] = {}
@@ -154,7 +146,7 @@ def list(cls, task_type: str | None = None) -> list[str]:
             "perspectives, and symbol rendering from basic to complex compositions."
         ),
         metrics=[],  # Paper uses human evaluation only; pass explicit metrics if needed
-        task_type="text_to_image",
+        task_type=TASK_TYPE_TEXT_IMAGE,
         reference="https://arxiv.org/abs/2206.10789",
     ),
     Benchmark(
@@ -164,7 +156,7 @@ def list(cls, task_type: str | None = None) -> list[str]:
             "Enables side-by-side comparison on sample quality and image-text alignment with human raters."
         ),
         metrics=[],  # Paper uses human evaluation only; pass explicit metrics if needed
-        task_type="text_to_image",
+        task_type=TASK_TYPE_TEXT_IMAGE,
         reference="https://arxiv.org/abs/2205.11487",
     ),
     Benchmark(
@@ -174,8 +166,8 @@ def list(cls, task_type: str | None = None) -> list[str]:
             "Covers basic skills (scene, attributes, spatial relationships) to advanced reasoning "
             "(counting, comparison, logic/negation) with over 24k human ratings."
         ),
-        metrics=["vqa", "clip_score"],
-        task_type="text_to_image",
+        metrics=["vqa", "clip_score"],  # VQAScore + CLIPScore both named (arXiv:2406.13743)
+        task_type=TASK_TYPE_TEXT_IMAGE,
         reference="https://arxiv.org/abs/2406.13743",
     ),
     Benchmark(
@@ -195,8 +187,8 @@ def list(cls, task_type: str | None = None) -> list[str]:
             "MS-COCO for text-to-image evaluation (Imagen, 2205.11487). Paper reports "
            "FID for fidelity and CLIP score for image-text alignment."
         ),
-        metrics=["fid", "clip_score"],  # §4.1: FID + CLIP score
-        task_type="text_to_image",
+        metrics=["fid", "clip_score"],
+        task_type=TASK_TYPE_TEXT_IMAGE,
         reference="https://arxiv.org/abs/2205.11487",
     ),
     Benchmark(
@@ -223,11 +215,13 @@ def list(cls, task_type: str | None = None) -> list[str]:
         name="GenEval",
         description=(
             "Compositional text-to-image benchmark with 6 categories: single object, two object, "
-            "counting, colors, position, color attributes. Evaluates fine-grained alignment "
-            "between prompts and generated images via VQA-style questions."
+            "counting, colors, position, color attributes. Uses atomic yes/no questions per prompt; "
+            "``Task.from_benchmark`` wires ``qa_accuracy`` with strict per-image aggregation "
+            "(all questions must pass) plus ``clip_score``. For holistic VQAScore-style scoring "
+            "use GenAI Bench with ``vqa``."
         ),
-        metrics=["qa_accuracy", "clip_score"],  # strict QA + CLIP score
-        task_type="text_to_image",
+        metrics=["qa_accuracy", "clip_score"],
+        task_type=TASK_TYPE_TEXT_IMAGE,
         reference="https://arxiv.org/abs/2310.11513",
     ),
     Benchmark(
@@ -237,7 +231,7 @@ def list(cls, task_type: str | None = None) -> list[str]:
             "Covers anime, concept-art, paintings, and photo styles with human preference data."
         ),
         metrics=[],  # Paper uses HPS scoring model; not in Pruna
-        task_type="text_to_image",
+        task_type=TASK_TYPE_TEXT_IMAGE,
         reference="https://arxiv.org/abs/2306.09341",
     ),
     Benchmark(
@@ -246,18 +240,19 @@ def list(cls, task_type: str | None = None) -> list[str]:
             "Image editing benchmark with 8 edit types: replace, add, remove, adjust, extract, "
             "style, background, compose. Evaluates instruction-following for inpainting and editing."
         ),
-        metrics=[],  # Paper uses GPT-4o/ImgEdit-Judge; not in Pruna
-        task_type="text_to_image",
+        metrics=["img_edit_score"],  # Paper uses GPT-4o rubric scores + ImgEdit-Judge.
+        task_type=TASK_TYPE_TEXT_PLUS_IMAGE_IMAGE,
         reference="https://arxiv.org/abs/2505.20275",
     ),
     Benchmark(
         name="Long Text Bench",
         description=(
-            "Text-to-image benchmark for long, detailed prompts. Evaluates model ability to "
-            "handle complex multi-clause descriptions and maintain coherence across long instructions."
+            "Text rendering benchmark evaluating whether T2I models correctly render specific text strings "
+            "provided in prompts. Uses ``text_score`` (normalized character accuracy in [0, 1]). "
+            "This is OCR correctness, not long-prompt semantic alignment."
         ),
         metrics=["text_score"],
-        task_type="text_to_image",
+        task_type=TASK_TYPE_TEXT_IMAGE,
         reference="https://arxiv.org/abs/2507.22058",
     ),
     Benchmark(
@@ -265,52 +260,53 @@ def list(cls, task_type: str | None = None) -> list[str]:
         description=(
             "General image editing benchmark with 11 task types: background change, color alter, "
             "material alter, motion change, style change, subject add/remove/replace, text change, "
-            "tone transfer, and human retouching."
+            "tone transfer, and human retouching. "
+            "Evaluated with VIEScore in text-image-edit (TIE) mode when source image bytes are available."
         ),
-        metrics=["vie_score"],
-        task_type="text+image_image",
+        metrics=["vie_score"],  # VIEScore is explicitly named in GEditBench.
+        task_type=TASK_TYPE_TEXT_PLUS_IMAGE_IMAGE,
         reference="https://arxiv.org/abs/2504.17761",
     ),
     Benchmark(
         name="OneIG Anime Stylization",
         description="OneIG subset: anime and stylized imagery.",
         metrics=["oneig_alignment"],
-        task_type="text_to_image",
+        task_type=TASK_TYPE_TEXT_IMAGE,
         reference="https://arxiv.org/abs/2506.07977",
     ),
     Benchmark(
         name="OneIG General Object",
         description="OneIG subset: everyday objects and scenes.",
         metrics=["oneig_alignment"],
-        task_type="text_to_image",
+        task_type=TASK_TYPE_TEXT_IMAGE,
         reference="https://arxiv.org/abs/2506.07977",
     ),
     Benchmark(
         name="OneIG Knowledge Reasoning",
         description="OneIG subset: knowledge- and reasoning-heavy prompts.",
         metrics=["oneig_reasoning"],
-        task_type="text_to_image",
+        task_type=TASK_TYPE_TEXT_IMAGE,
         reference="https://arxiv.org/abs/2506.07977",
     ),
     Benchmark(
         name="OneIG Multilingualism",
         description="OneIG subset: multilingual prompts (incl. Chinese splits).",
         metrics=["oneig_alignment"],
-        task_type="text_to_image",
+        task_type=TASK_TYPE_TEXT_IMAGE,
         reference="https://arxiv.org/abs/2506.07977",
     ),
     Benchmark(
         name="OneIG Portrait",
         description="OneIG subset: people and portraits.",
         metrics=["oneig_alignment"],
-        task_type="text_to_image",
+        task_type=TASK_TYPE_TEXT_IMAGE,
         reference="https://arxiv.org/abs/2506.07977",
     ),
     Benchmark(
         name="OneIG Text Rendering",
         description="OneIG subset: text and graphics painted into the image.",
         metrics=["oneig_text_score"],
-        task_type="text_to_image",
+        task_type=TASK_TYPE_TEXT_IMAGE,
         reference="https://arxiv.org/abs/2506.07977",
     ),
     Benchmark(
@@ -320,7 +316,7 @@ def list(cls, task_type: str | None = None) -> list[str]:
             "global, and other descriptive aspects with natural-language questions for alignment."
         ),
         metrics=[],  # Paper uses custom evaluation; not in Pruna
-        task_type="text_to_image",
+        task_type=TASK_TYPE_TEXT_IMAGE,
         reference="https://arxiv.org/abs/2403.05135",
     ),
 ]:
diff --git a/src/pruna/evaluation/metrics/metric_img_edit_score.py b/src/pruna/evaluation/metrics/metric_img_edit_score.py
new file mode 100644
index 00000000..4c1832af
--- /dev/null
+++ b/src/pruna/evaluation/metrics/metric_img_edit_score.py
@@ -0,0 +1,219 @@
+# Copyright 2025 - Pruna AI GmbH. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Image Edit Score metric.
+
+VLM-based instruction-following score for image editing. Evaluates how well an edited image
+follows the given editing instruction on a 0-10 scale. Related work: EditScore (arXiv:2509.23909),
+ADIEE (ICCV 2025).
+
+When the ``ImgEdit`` benchmark provides a per-sample ``judge_prompt`` and
+``source_image_bytes`` in the auxiliaries, the metric mirrors the ImgEdit paper
+evaluation protocol: the judge_prompt rubric (three 1-5 criterion scores) is
+filled with the editing instruction, both source and edited images are shown to
+the VLM, and the minimum of the three criterion scores is normalised to [0, 1] by
+dividing by 5 (consistent with VIEScore methodology: the weakest criterion governs).
+Without these auxiliaries the metric falls back to a single-image generic 0-10 prompt.
+"""
+
+from __future__ import annotations
+
+from typing import Any, Literal
+
+import torch
+
+from pruna.evaluation.metrics.registry import MetricRegistry
+from pruna.evaluation.metrics.result import MetricResult
+from pruna.evaluation.metrics.utils import (
+    SINGLE,
+    metric_data_processor,
+)
+from pruna.evaluation.metrics.vlm_base import (
+    BaseVLM,
+    StatefulVLMMeanScoresMetric,
+    auxiliary_dicts_from_gt,
+    prompts_from_y_x_inputs,
+)
+from pruna.evaluation.metrics.vlm_utils import (
+    FloatOutput,
+    VIEScoreJsonOutput,
+    _process_images,
+    get_score_from_response,
+    pil_rgb_from_aux_image_bytes,
+    viescore_min_scores_0_10,
+)
+
+_FALLBACK_QUESTION = (
+    'On a scale of 0 to 10, how well does this edited image follow the instruction "{prompt}"? '
+    "0 = instruction not followed at all, 10 = perfectly executed. Reply with a single number."
+)
+
+_JUDGE_JSON_SUFFIX = (
+    '\n\nProvide your three criterion scores as JSON: {"score": [score1, score2, score3]} '
+    "where each score is a number from 1 to 5."
+)
+
+
+@MetricRegistry.register("img_edit_score")
+class ImageEditScoreMetric(StatefulVLMMeanScoresMetric):
+    """
+    Image Edit Score metric.
+
+    VLM-based instruction-following score for image editing. Evaluates how well an edited image
+    follows the given editing instruction. Higher scores indicate better editing quality.
+
+    When auxiliaries contain ``judge_prompt`` and ``source_image_bytes`` (as provided
+    by the ImgEdit benchmark), the metric passes **both** the source (before) and edited
+    (after) images to the VLM together with the dataset-specific rubric. This matches
+    the ImgEdit paper's evaluation protocol. Without these fields, it falls back to a
+    single-image generic question.
+
+    Related work: EditScore (arXiv:2509.23909), ADIEE (ICCV 2025).
+
+    Parameters
+    ----------
+    *args : Any
+        Additional positional arguments.
+    vlm : BaseVLM | None, optional
+        Custom VLM instance. If provided, vlm_type and model_name are ignored.
+    vlm_type : {"litellm", "transformers"}, optional
+        VLM backend. Default is "litellm".
+    model_name : str | None, optional
+        Litellm model id or HuggingFace checkpoint id. **Required** when ``vlm`` is not
+        provided (e.g. ``openai/gpt-4o``).
+    vlm_kwargs : dict, optional
+        Forwarded by ``get_vlm`` to ``LitellmVLM`` or ``TransformersVLM``. For local models,
+        set ``model_load_kwargs`` for ``from_pretrained``; for litellm, pass extra API options.
+    structured_output : bool, optional
+        Use structured generation (litellm pydantic; transformers outlines when applicable).
+        Default is True.
+    device : str | torch.device | None, optional
+        Device for transformers VLM.
+    api_key : str | None, optional
+        API key for litellm.
+    call_type : str, optional
+        Call type for the metric.
+    **kwargs : Any
+        Additional arguments.
+
+    Examples
+    --------
+    Same ``hosted`` / ``local`` pattern as :func:`~pruna.evaluation.metrics.vlm_base.get_vlm`:
+
+    .. code-block:: python
+
+        import torch
+
+        from pruna.evaluation.metrics import ImageEditScoreMetric
+
+        hosted = ImageEditScoreMetric(vlm_type="litellm", model_name="openai/gpt-4o")
+        local = ImageEditScoreMetric(
+            vlm_type="transformers",
+            model_name="HuggingFaceTB/SmolVLM-256M-Instruct",
+            device="cpu",
+            vlm_kwargs={"model_load_kwargs": {"torch_dtype": torch.float32}},
+        )
+    """
+
+    scores: list[float]
+    default_call_type: str = "y_x"
+    higher_is_better: bool = True
+    metric_name: str = "img_edit_score"
+
+    def __init__(
+        self,
+        *args,
+        vlm: BaseVLM | None = None,
+        vlm_type: Literal["litellm", "transformers"] = "litellm",
+        model_name: str | None = None,
+        vlm_kwargs: dict | None = None,
+        structured_output: bool = True,
+        device: str | torch.device | None = None,
+        api_key: str | None = None,
+        call_type: str = SINGLE,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(device=device)
+        self.response_format = FloatOutput if structured_output else None
+
+        self._init_vlm_scores(
+            vlm=vlm,
+            vlm_type=vlm_type,
+            model_name=model_name,
+            vlm_kwargs=vlm_kwargs,
+            structured_output=structured_output,
+            device=device,
+            api_key=api_key,
+            call_type=call_type,
+        )
+
+    def update(self, x: list[Any] | torch.Tensor, gt: torch.Tensor, outputs: torch.Tensor) -> None:
+        """
+        Update the metric with new batch data.
+
+        When ``gt`` auxiliaries contain ``judge_prompt`` and ``source_image_bytes``, the
+        metric uses the dataset rubric and a before/after two-image comparison. Otherwise
+        it falls back to a single-image generic question.
+
+        Parameters
+        ----------
+        x : list[Any] | torch.Tensor
+            The input data (editing instructions / prompts).
+        gt : torch.Tensor
+            Auxiliaries per sample (may contain ``judge_prompt`` and ``source_image_bytes``).
+        outputs : torch.Tensor
+            The output (edited) images.
+        """
+        inputs = metric_data_processor(x, gt, outputs, self.call_type)
+        images = _process_images(inputs[0])
+        prompts = prompts_from_y_x_inputs(inputs, len(images))
+        aux_list = auxiliary_dicts_from_gt(gt, len(images))
+
+        for i, image in enumerate(images):
+            prompt = prompts[i] if i < len(prompts) else ""
+            aux_row = aux_list[i]
+
+            judge_prompt = aux_row.get("judge_prompt", "") or ""
+            source_image = pil_rgb_from_aux_image_bytes(aux_row, min_bytes_in_value_scan=100)
+
+            if judge_prompt and source_image is not None:
+                # Fill the rubric's instruction slot. The "<instruction>" placeholder token
+                # is assumed here; calling replace with an empty string would interleave the
+                # prompt between every character of the rubric.
+                filled = judge_prompt.replace("<instruction>", prompt).strip()
+                question = filled + _JUDGE_JSON_SUFFIX
+                try:
+                    responses = self.vlm.generate_with_image_lists(
+                        [[source_image, image]], [question], response_format=VIEScoreJsonOutput
+                    )
+                    raw = viescore_min_scores_0_10(responses[0])
+                    if raw:
+                        # Weakest criterion governs (VIEScore convention): take the minimum
+                        # of the three 1-5 rubric scores, rescale to [0, 1], and clamp.
+                        score = max(0.0, min(1.0, float(min(raw)) / 5.0))
+                        self.scores.append(score)
+                        continue
+                except (NotImplementedError, AttributeError):
+                    # Backend cannot take multiple images per call; use the fallback below.
+                    pass
+
+            # Generic single-image fallback when the rubric or source image is unavailable.
+            question = _FALLBACK_QUESTION.format(prompt=prompt)
+            responses = self.vlm.generate([image], [question], response_format=self.response_format)
+            self.scores.append(get_score_from_response(responses[0]))
+
+    def compute(self) -> MetricResult:
+        """
+        Compute the image edit score.
+
+        Returns
+        -------
+        MetricResult
+            The mean image edit score across all updates.
+ """ + return self.compute_mean_of_scores() diff --git a/tests/evaluation/test_vision_metrics.py b/tests/evaluation/test_vision_metrics.py index e7284eee..f41df366 100644 --- a/tests/evaluation/test_vision_metrics.py +++ b/tests/evaluation/test_vision_metrics.py @@ -5,6 +5,7 @@ import pytest import torch +from pruna.evaluation.metrics.metric_img_edit_score import ImageEditScoreMetric from pruna.evaluation.metrics.metric_vie_score import VieScoreMetric from pruna.evaluation.metrics.metric_vqa import VQAMetric from pruna.evaluation.metrics.vlm_base import BaseVLM @@ -42,3 +43,16 @@ def test_vie_score_uses_json_score_lists() -> None: result = metric.compute() assert abs(result.result - 0.8) < 0.01 + + +@pytest.mark.cpu +def test_img_edit_score_negative_response_clamped() -> None: + """Image edit score never goes below zero for malformed negative outputs.""" + mock_vlm = MagicMock(spec=BaseVLM) + mock_vlm.generate.return_value = ['{"score": -10.0}'] + + metric = ImageEditScoreMetric(vlm=mock_vlm, device="cpu", structured_output=True) + metric.update(["replace the boot with a mug"], torch.zeros(1), _dummy_image()) + result = metric.compute() + + assert result.result == 0.0