diff --git a/src/pruna/evaluation/benchmarks.py b/src/pruna/evaluation/benchmarks.py index e3f58164..34b76d5a 100644 --- a/src/pruna/evaluation/benchmarks.py +++ b/src/pruna/evaluation/benchmarks.py @@ -256,7 +256,7 @@ def list(cls, task_type: str | None = None) -> list[str]: "Text-to-image benchmark for long, detailed prompts. Evaluates model ability to " "handle complex multi-clause descriptions and maintain coherence across long instructions." ), - metrics=[], # Paper uses text_score/TIT-Score; not in Pruna + metrics=["text_score"], task_type="text_to_image", reference="https://arxiv.org/abs/2507.22058", ), @@ -299,6 +299,13 @@ def list(cls, task_type: str | None = None) -> list[str]: task_type="text_to_image", reference="https://arxiv.org/abs/2506.07977", ), + Benchmark( + name="OneIG Text Rendering", + description="OneIG subset: text and graphics painted into the image.", + metrics=["oneig_text_score"], + task_type="text_to_image", + reference="https://arxiv.org/abs/2506.07977", + ), Benchmark( name="DPG", description=( diff --git a/src/pruna/evaluation/metrics/metric_text_score.py b/src/pruna/evaluation/metrics/metric_text_score.py new file mode 100644 index 00000000..b56a6957 --- /dev/null +++ b/src/pruna/evaluation/metrics/metric_text_score.py @@ -0,0 +1,372 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Text rendering via OCR: mean Levenshtein (``text_score`` / ``ocr_levenshtein``). + +OneIG composite: ``oneig_text_score`` / ``ocr_text_score``. +""" + +from __future__ import annotations + +from abc import abstractmethod +from typing import Any, Literal + +import numpy as np +import torch + +from pruna.engine.utils import set_to_best_available_device +from pruna.evaluation.metrics.metric_stateful import StatefulMetric +from pruna.evaluation.metrics.metric_text_score_utils import ( + levenshtein, + normalize_text_simple, + oneig_mean_text_score, + oneig_per_sample_contributions, +) +from pruna.evaluation.metrics.registry import MetricRegistry +from pruna.evaluation.metrics.result import MetricResult +from pruna.evaluation.metrics.utils import ( + SINGLE, + get_call_type_for_single_metric, + metric_data_processor, +) +from pruna.evaluation.metrics.vlm_base import BaseVLM, get_vlm +from pruna.evaluation.metrics.vlm_utils import TextOutput, _process_images, get_text_from_response + +OCR_PROMPT = ( + "Extract all text visible in this image. Include logos, stylized fonts, handwritten text, " + "and non-standard typography. Return only the extracted text, exactly as it appears—no preamble, " + "explanation, or markdown. Preserve words, numbers, punctuation, and spacing. " + "IMPORTANT: Do NOT correct spelling errors or typos. If a word is misspelled in the image " + "(e.g. 'Teclhology' instead of 'Technology'), reproduce it exactly as it appears, including the misspelling. 
" + "If no text is recognized, reply with exactly: No text recognized" +) + + +class _BaseVLMOCRTextMetric(StatefulMetric): + """ + Shared VLM OCR over rendered images with ground truth in ``text_content``. + + Subclasses implement how OCR and GT strings are scored and aggregated. + + Parameters + ---------- + *args : Any + Additional positional arguments (unused; registry compatibility). + vlm : BaseVLM | None, optional + Custom VLM instance. If provided, ``vlm_type`` and ``model_name`` are ignored. + vlm_type : {'litellm', 'transformers'}, optional + VLM backend. Default is ``'litellm'``. + model_name : str | None, optional + Litellm model id or HuggingFace checkpoint id. **Required** when ``vlm`` is not + provided (e.g. ``openai/gpt-4o``). + vlm_kwargs : dict, optional + Forwarded by ``get_vlm`` to ``LitellmVLM`` or ``TransformersVLM``. For local models, + set ``model_load_kwargs`` for ``from_pretrained``; for litellm, pass extra API options. + structured_output : bool, optional + Use structured generation (litellm pydantic; transformers outlines when applicable). + Default is True. + device : str | torch.device | None, optional + Device for transformers VLM. + api_key : str | None, optional + API key for litellm. + call_type : str, optional + Call type for the metric. + **kwargs : Any + Additional arguments. + + Examples + -------- + OCR metrics call ``get_vlm`` directly (not ``StatefulVLMMeanScoresMetric``). Same + ``hosted`` / ``local`` pattern as :func:`~pruna.evaluation.metrics.vlm_base.get_vlm`: + + .. code-block:: python + + import torch + + from pruna.evaluation.metrics import TextScoreMetric + + hosted = TextScoreMetric(vlm_type="litellm", model_name="openai/gpt-4o") + local = TextScoreMetric( + vlm_type="transformers", + model_name="HuggingFaceTB/SmolVLM-256M-Instruct", + device="cpu", + vlm_kwargs={"model_load_kwargs": {"torch_dtype": torch.float32}}, + ) + + Use ``OneIGTextScoreMetric`` the same way for ``oneig_text_score`` / ``ocr_text_score``. + """ + + default_call_type: str = "y_gt" + + def __init__( + self, + *args: Any, + vlm: BaseVLM | None = None, + vlm_type: Literal["litellm", "transformers"] = "litellm", + model_name: str | None = None, + vlm_kwargs: dict | None = None, + structured_output: bool = True, + device: str | torch.device | None = None, + api_key: str | None = None, + call_type: str = SINGLE, + **kwargs: Any, + ) -> None: + super().__init__(device=device) + self.device = set_to_best_available_device(device) + + self.vlm = get_vlm( + vlm=vlm, + vlm_type=vlm_type, + model_name=model_name, + device=device, + api_key=api_key, + structured_output=structured_output, + **(vlm_kwargs or {}), + ) + self.response_format = TextOutput if structured_output else None + + self.call_type = get_call_type_for_single_metric(call_type, self.default_call_type) + self.higher_is_better = type(self).higher_is_better + + @abstractmethod + def _accumulate_sample(self, text_gt: str, ocr_text: str) -> None: + """Update metric state from one ground-truth / OCR pair.""" + + @abstractmethod + def _compute_result_value(self) -> float: + """Return the scalar reported as ``MetricResult.result``.""" + + def update(self, x: list[Any] | torch.Tensor, gt: list[str], outputs: torch.Tensor) -> None: + """ + Run OCR on outputs and score against ``text_content`` (or string list) auxiliaries. + + Parameters + ---------- + x : List[Any] | torch.Tensor + Batch prompts or metadata. 
+ gt : list of dict or list of str + Auxiliaries with ``'text_content'`` as a string, a list of strings (joined with + newlines), or plain strings per batch item. + outputs : torch.Tensor + Rendered images. + """ + inputs = metric_data_processor(x, gt, outputs, self.call_type) + images = _process_images(inputs[0]) + auxiliaries = inputs[1] if len(inputs) > 1 and isinstance(inputs[1], (list, tuple)) else [{}] * len(images) + for i, image in enumerate(images): + responses = self.vlm.generate([image], [OCR_PROMPT], response_format=self.response_format) + raw = responses[0] if responses else "" + ocr_text = get_text_from_response(raw) + aux = auxiliaries[i] if i < len(auxiliaries) else {} + text_gt = aux.get("text_content") if isinstance(aux, dict) else (aux if isinstance(aux, str) else None) + if isinstance(text_gt, list): + text_gt = "\n".join(str(x) for x in text_gt) + if text_gt is None: + raise ValueError( + f"{self.metric_name} requires 'text_content' in auxiliaries. " + "Use a benchmark that provides it (e.g. LongTextBench, OneIG)." + ) + self._accumulate_sample(text_gt, ocr_text) + + def compute(self) -> MetricResult: + """ + Aggregate batched contributions into a single metric value. + + Returns + ------- + MetricResult + Named result with ``higher_is_better`` taken from the class. + """ + value = self._compute_result_value() + return MetricResult(self.metric_name, self.__dict__, float(value)) + + +@MetricRegistry.register("ocr_levenshtein") +@MetricRegistry.register("text_score") +class TextScoreMetric(_BaseVLMOCRTextMetric): + """ + OCR then mean normalized character accuracy in [0, 1] (higher is better). + + Registry: ``ocr_levenshtein`` (descriptive) and ``text_score`` (legacy). + + Uses light normalization only (not the full OneIG preprocess). See + :class:`OneIGTextScoreMetric` for the OneIG composite ``ocr_text_score``. + + Parameters + ---------- + *args : Any + Additional positional arguments (unused; registry compatibility). + vlm : BaseVLM | None, optional + Custom VLM instance. If provided, ``vlm_type`` and ``model_name`` are ignored. + vlm_type : {'litellm', 'transformers'}, optional + VLM backend. Default is ``'litellm'``. + model_name : str | None, optional + Litellm model id or HuggingFace checkpoint id. **Required** when ``vlm`` is not + provided (e.g. ``openai/gpt-4o``). + vlm_kwargs : dict, optional + Forwarded by ``get_vlm`` to ``LitellmVLM`` or ``TransformersVLM``. For local models, + set ``model_load_kwargs`` for ``from_pretrained``; for litellm, pass extra API options. + structured_output : bool, optional + Use structured generation (litellm pydantic; transformers outlines when applicable). + Default is True. + device : str | torch.device | None, optional + Device for transformers VLM. + api_key : str | None, optional + API key for litellm. + call_type : str, optional + Call type for the metric. + **kwargs : Any + Additional keyword arguments forwarded to :class:`_BaseVLMOCRTextMetric`. 
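+
+    Examples
+    --------
+    Construction mirrors the ``hosted`` / ``local`` pattern shown on the base class;
+    the model name below is illustrative, not a default:
+
+    .. code-block:: python
+
+        from pruna.evaluation.metrics import TextScoreMetric
+
+        # Hosted VLM backend via litellm; reports a mean score in [0, 1].
+        metric = TextScoreMetric(vlm_type="litellm", model_name="openai/gpt-4o")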
+ """ + + scores: list[float] + higher_is_better: bool = True + metric_name: str = "text_score" + + def __init__( + self, + *args: Any, + vlm: BaseVLM | None = None, + vlm_type: Literal["litellm", "transformers"] = "litellm", + model_name: str | None = None, + vlm_kwargs: dict[str, Any] | None = None, + structured_output: bool = True, + device: str | torch.device | None = None, + api_key: str | None = None, + call_type: str = SINGLE, + **kwargs: Any, + ) -> None: + super().__init__( + *args, + vlm=vlm, + vlm_type=vlm_type, + model_name=model_name, + vlm_kwargs=vlm_kwargs, + structured_output=structured_output, + device=device, + api_key=api_key, + call_type=call_type, + **kwargs, + ) + self.add_state("scores", []) + + def _accumulate_sample(self, text_gt: str, ocr_text: str) -> None: + norm_gt = normalize_text_simple(text_gt) + norm_ocr = normalize_text_simple(ocr_text) + ed = levenshtein(norm_ocr, norm_gt) + denom = max(float(len(norm_gt)), 1.0) + self.scores.append(1.0 - min(1.0, ed / denom)) + + def _compute_result_value(self) -> float: + if not self.scores: + return 0.0 + return float(np.mean(self.scores)) + + +@MetricRegistry.register("ocr_text_score") +@MetricRegistry.register("oneig_text_score") +class OneIGTextScoreMetric(_BaseVLMOCRTextMetric): + """ + OCR then OneIG-style composite text score (higher is better). + + Registry: ``ocr_text_score`` (descriptive) and ``oneig_text_score`` (protocol). + + Aggregates edit distance, completion rate, and word/char accuracy like + ``OneIG-Benchmark/scripts/text/text_score.py``. + + Parameters + ---------- + *args : Any + Additional positional arguments (forwarded to :class:`_BaseVLMOCRTextMetric`). + language_mode : {'EN', 'ZH'}, optional + Selects ``MAX_EDIT_DISTANCE`` (100 vs 50) for the composite. + vlm : BaseVLM | None, optional + Custom VLM instance. If provided, ``vlm_type`` and ``model_name`` are ignored. + vlm_type : {'litellm', 'transformers'}, optional + VLM backend. Default is ``'litellm'``. + model_name : str | None, optional + Litellm model id or HuggingFace checkpoint id. **Required** when ``vlm`` is not + provided (e.g. ``openai/gpt-4o``). + vlm_kwargs : dict, optional + Forwarded by ``get_vlm`` to ``LitellmVLM`` or ``TransformersVLM``. For local models, + set ``model_load_kwargs`` for ``from_pretrained``; for litellm, pass extra API options. + structured_output : bool, optional + Use structured generation (litellm pydantic; transformers outlines when applicable). + Default is True. + device : str | torch.device | None, optional + Device for transformers VLM. + api_key : str | None, optional + API key for litellm. + call_type : str, optional + Call type for the metric. + **kwargs : Any + Additional keyword arguments forwarded to :class:`_BaseVLMOCRTextMetric`. 
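+
+    Examples
+    --------
+    Same construction as :class:`TextScoreMetric`, with ``language_mode`` selecting the
+    ``MAX_EDIT_DISTANCE`` cap (import path assumed to mirror ``TextScoreMetric``):
+
+    .. code-block:: python
+
+        from pruna.evaluation.metrics import OneIGTextScoreMetric
+
+        # OneIG composite over edit distance, completion rate, and word accuracy.
+        metric = OneIGTextScoreMetric(
+            vlm_type="litellm",
+            model_name="openai/gpt-4o",
+            language_mode="EN",
+        )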
+ """ + + edit_distances: list[float] + completion_ratios: list[float] + match_counts: list[int] + gt_totals: list[int] + + higher_is_better: bool = True + metric_name: str = "oneig_text_score" + + def __init__( + self, + *args: Any, + language_mode: Literal["EN", "ZH"] = "EN", + vlm: BaseVLM | None = None, + vlm_type: Literal["litellm", "transformers"] = "litellm", + model_name: str | None = None, + vlm_kwargs: dict[str, Any] | None = None, + structured_output: bool = True, + device: str | torch.device | None = None, + api_key: str | None = None, + call_type: str = SINGLE, + **kwargs: Any, + ) -> None: + super().__init__( + *args, + vlm=vlm, + vlm_type=vlm_type, + model_name=model_name, + vlm_kwargs=vlm_kwargs, + structured_output=structured_output, + device=device, + api_key=api_key, + call_type=call_type, + **kwargs, + ) + self.language_mode = language_mode + self.add_state("edit_distances", []) + self.add_state("completion_ratios", []) + self.add_state("match_counts", []) + self.add_state("gt_totals", []) + + def _accumulate_sample(self, text_gt: str, ocr_text: str) -> None: + ed, cr, mcount, gtot = oneig_per_sample_contributions(text_gt, ocr_text) + self.edit_distances.append(ed) + self.completion_ratios.append(cr) + self.match_counts.append(mcount) + self.gt_totals.append(gtot) + + def _compute_result_value(self) -> float: + *_, text_score = oneig_mean_text_score( + self.edit_distances, + self.completion_ratios, + self.match_counts, + self.gt_totals, + self.language_mode, + ) + return text_score diff --git a/src/pruna/evaluation/metrics/metric_text_score_utils.py b/src/pruna/evaluation/metrics/metric_text_score_utils.py new file mode 100644 index 00000000..8aa7d850 --- /dev/null +++ b/src/pruna/evaluation/metrics/metric_text_score_utils.py @@ -0,0 +1,274 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helpers for text rendering metrics (simple Levenshtein vs OneIG-style composite). + +OneIG-style preprocessing and aggregation follow +`OneIG-Benchmark/scripts/text/text_utils.py` and `text_score.py` (Apache-2.0). +""" + +from __future__ import annotations + +import re +from collections import Counter +from typing import Literal + +_OCR_HALLUCINATION_KEYWORDS = ("addCriterion", "No text recognized.", "No text recognized") + + +def normalize_text_simple(s: str) -> str: + """ + Normalize text for the legacy ``text_score`` metric (light cleanup + spacing). + + Parameters + ---------- + s : str + Raw string. + + Returns + ------- + str + Normalized string. + """ + cleaned = re.sub( + r"[^\u4e00-\u9fa5a-zA-Z0-9\sàâäéèêëîïôöùûüçÀÂÄÉÈÊËÎÏÔÖÙÛÜÇ]", + "", + s or "", + ) + return re.sub(r"\s+", " ", cleaned).strip() + + +def levenshtein(s1: str, s2: str) -> float: + """ + Symmetric Levenshtein edit distance. + + Parameters + ---------- + s1 : str + First string. + s2 : str + Second string. + + Returns + ------- + float + Edit distance. 
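+
+    Examples
+    --------
+    Three single-character edits turn ``"kitten"`` into ``"sitting"``:
+
+    >>> levenshtein("kitten", "sitting")
+    3.0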
+ """ + if len(s1) < len(s2): + return levenshtein(s2, s1) + prev = list(range(len(s2) + 1)) + for i, c1 in enumerate(s1): + curr = [i + 1] + for j, c2 in enumerate(s2): + curr.append(min(prev[j] + (c1 != c2), prev[j + 1] + 1, curr[-1] + 1)) + prev = curr + return float(prev[-1]) + + +def contains_chinese(text: str) -> bool: + """ + Return True if ``text`` contains CJK unified ideographs. + + Parameters + ---------- + text : str + Input text. + + Returns + ------- + bool + Whether Chinese characters are present. + """ + return bool(re.search(r"[\u4e00-\u9fff]", text)) + + +def preprocess_string_oneig(s: str) -> str: + """ + OneIG ``preprocess_string``: charset filter, Chinese vs whitespace normalization. + + Parameters + ---------- + s : str + Raw string. + + Returns + ------- + str + Preprocessed string (ground truth or OCR). + """ + raw = s or "" + cleaned = re.sub( + r"[^\u4e00-\u9fa5a-zA-Z0-9\sàâäéèêëîïôöùûüçÀÂÄÉÈÊËÎÏÔÖÙÛÜÇ]", + "", + raw, + ) + if contains_chinese(cleaned): + pattern = re.compile( + r"[\u4e00-\u9fa5a-zA-Z0-9àâäéèêëîïôöùûüçÀÂÄÉÈÊËÎÏÔÖÙÛÜÇ]", + ) + return "".join(pattern.findall(raw)).strip() + return re.sub(r"\s+", " ", cleaned).strip() + + +def clean_oneig_ocr_hallucinations(text: str) -> str: + """ + Remove known OCR boilerplate substrings (OneIG ``clean_and_remove_hallucinations``). + + Parameters + ---------- + text : str + Raw OCR output. + + Returns + ------- + str + Cleaned OCR text. + """ + out = text or "" + for keyword in _OCR_HALLUCINATION_KEYWORDS: + out = out.replace(keyword, "").replace(f"\n{keyword}", "").replace(f"{keyword}\n", "") + return out + + +def calculate_char_match_ratio( + text_gt: str, + ocr_str: str, +) -> tuple[int, float, int]: + """ + OneIG overlap stats: character multiset for ZH, word multiset for EN. + + Parameters + ---------- + text_gt : str + Preprocessed ground truth. + ocr_str : str + Preprocessed OCR. + + Returns + ------- + total_match_count : int + Overlap count used in WAC numerator aggregation. + ratio : float + Per-sample ratio (mean of ratios is not used in the official aggregate). + gt_total : int + Denominator term: ``sum(gt_counter.values())`` for WAC aggregation. + """ + if contains_chinese(text_gt): + gt_counter: Counter[str] = Counter(text_gt) + ocr_counter: Counter[str] = Counter(ocr_str) + total_match_count = int(sum((gt_counter & ocr_counter).values())) + ratio = total_match_count / len(text_gt) if len(text_gt) > 0 else 0.0 + return total_match_count, ratio, int(sum(gt_counter.values())) + + words_gt = text_gt.split() + words_ocr = ocr_str.split() + gt_counter = Counter(words_gt) + ocr_counter = Counter(words_ocr) + total_match_count = int(sum((gt_counter & ocr_counter).values())) + total_gt_count = len(words_gt) + ratio = total_match_count / total_gt_count if total_gt_count > 0 else 0.0 + return total_match_count, ratio, int(sum(gt_counter.values())) + + +def max_edit_distance_for_language(language_mode: Literal["EN", "ZH"]) -> int: + """ + OneIG ``MAX_EDIT_DISTANCE`` (100 for English, 50 for Chinese benchmark split). + + Parameters + ---------- + language_mode : {'EN', 'ZH'} + Benchmark language mode. + + Returns + ------- + int + Cap used in the composite text score. + """ + return 50 if language_mode == "ZH" else 100 + + +def oneig_per_sample_contributions(text_gt: str, ocr_raw: str) -> tuple[float, float, int, int]: + """ + Per-sample terms for OneIG aggregation (ED, CR, WAC numerator/denominator parts). + + Parameters + ---------- + text_gt : str + Ground-truth text (dataset field). 
+ ocr_raw : str + Raw OCR string from the VLM. + + Returns + ------- + edit_distance : float + Levenshtein distance after OneIG preprocess. + completion_ratio : float + 1.0 if distance is zero, else 0.0. + match_count : int + Overlap count for WAC. + gt_total : int + Ground-truth token count term for WAC denominator. + """ + ocr_clean = clean_oneig_ocr_hallucinations(ocr_raw) + gt_pre = preprocess_string_oneig(text_gt) + ocr_pre = preprocess_string_oneig(ocr_clean) + ed = levenshtein(ocr_pre, gt_pre) + cr = 1.0 if ed == 0.0 else 0.0 + match_count, _, gt_total = calculate_char_match_ratio(gt_pre, ocr_pre) + return ed, cr, match_count, gt_total + + +def oneig_mean_text_score( + edit_distances: list[float], + completion_ratios: list[float], + match_counts: list[int], + gt_totals: list[int], + language_mode: Literal["EN", "ZH"], +) -> tuple[float, float, float, float]: + """ + Aggregate OneIG ED, CR, WAC and composite text score (higher is better). + + Parameters + ---------- + edit_distances : list of float + Per-sample edit distances. + completion_ratios : list of float + Per-sample completion indicators. + match_counts : list of int + Per-sample WAC numerators. + gt_totals : list of int + Per-sample WAC denominator terms. + language_mode : {'EN', 'ZH'} + Selects ``MAX_EDIT_DISTANCE``. + + Returns + ------- + ed_mean : float + Mean edit distance. + cr_mean : float + Mean completion ratio. + wac : float + Micro-averaged WAC: ``sum(match_counts) / sum(gt_totals)``. + text_score : float + Composite: ``1 - min(MAX_ED, ED) * (1 - CR) * (1 - WAC) / MAX_ED``. + """ + cap = float(max_edit_distance_for_language(language_mode)) + if not edit_distances: + return 0.0, 0.0, 0.0, 0.0 + ed_mean = float(sum(edit_distances) / len(edit_distances)) + cr_mean = float(sum(completion_ratios) / len(completion_ratios)) + denom = float(sum(gt_totals)) + wac = float(sum(match_counts) / denom) if denom > 0.0 else 0.0 + text_score = 1.0 - min(cap, ed_mean) * (1.0 - cr_mean) * (1.0 - wac) / cap + return ed_mean, cr_mean, wac, text_score
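# Illustrative sanity check of the OneIG aggregation helpers: an exact OCR match should
# yield zero edit distance, full completion and word overlap, and a composite score of 1.0.
# The import path matches the one used by metric_text_score.py above.
from pruna.evaluation.metrics.metric_text_score_utils import (
    oneig_mean_text_score,
    oneig_per_sample_contributions,
)

ed, cr, match_count, gt_total = oneig_per_sample_contributions("OPEN 24 HOURS", "OPEN 24 HOURS")
assert (ed, cr, match_count, gt_total) == (0.0, 1.0, 3, 3)

*_, text_score = oneig_mean_text_score([ed], [cr], [match_count], [gt_total], "EN")
assert text_score == 1.0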