9 changes: 8 additions & 1 deletion src/pruna/evaluation/benchmarks.py
@@ -256,7 +256,7 @@ def list(cls, task_type: str | None = None) -> list[str]:
"Text-to-image benchmark for long, detailed prompts. Evaluates model ability to "
"handle complex multi-clause descriptions and maintain coherence across long instructions."
),
- metrics=[], # Paper uses text_score/TIT-Score; not in Pruna
+ metrics=["text_score"],
task_type="text_to_image",
reference="https://arxiv.org/abs/2507.22058",
),
@@ -299,6 +299,13 @@ def list(cls, task_type: str | None = None) -> list[str]:
task_type="text_to_image",
reference="https://arxiv.org/abs/2506.07977",
),
+ Benchmark(
+ name="OneIG Text Rendering",
+ description="OneIG subset: text and graphics rendered into the image.",
+ metrics=["oneig_text_score"],
+ task_type="text_to_image",
+ reference="https://arxiv.org/abs/2506.07977",
+ ),
Benchmark(
name="DPG",
description=(
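A minimal sketch of how the new entry surfaces through the registry (hypothetical: assumes the ``list`` classmethod shown in the hunk header is exposed on ``Benchmark`` and returns benchmark names):

.. code-block:: python

    from pruna.evaluation.benchmarks import Benchmark  # assumed import path

    # The text-to-image listing should now include the new OneIG subset.
    names = Benchmark.list(task_type="text_to_image")
    assert "OneIG Text Rendering" in names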
372 changes: 372 additions & 0 deletions src/pruna/evaluation/metrics/metric_text_score.py
@@ -0,0 +1,372 @@
# Copyright 2025 - Pruna AI GmbH. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Text rendering via OCR: mean Levenshtein (``text_score`` / ``ocr_levenshtein``).

OneIG composite: ``oneig_text_score`` / ``ocr_text_score``.
"""

from __future__ import annotations

from abc import abstractmethod
from typing import Any, Literal

import numpy as np
import torch

from pruna.engine.utils import set_to_best_available_device
from pruna.evaluation.metrics.metric_stateful import StatefulMetric
from pruna.evaluation.metrics.metric_text_score_utils import (
levenshtein,
normalize_text_simple,
oneig_mean_text_score,
oneig_per_sample_contributions,
)
from pruna.evaluation.metrics.registry import MetricRegistry
from pruna.evaluation.metrics.result import MetricResult
from pruna.evaluation.metrics.utils import (
SINGLE,
get_call_type_for_single_metric,
metric_data_processor,
)
from pruna.evaluation.metrics.vlm_base import BaseVLM, get_vlm
from pruna.evaluation.metrics.vlm_utils import TextOutput, _process_images, get_text_from_response

OCR_PROMPT = (
"Extract all text visible in this image. Include logos, stylized fonts, handwritten text, "
"and non-standard typography. Return only the extracted text, exactly as it appears—no preamble, "
"explanation, or markdown. Preserve words, numbers, punctuation, and spacing. "
"IMPORTANT: Do NOT correct spelling errors or typos. If a word is misspelled in the image "
"(e.g. 'Teclhology' instead of 'Technology'), reproduce it exactly as it appears, including the misspelling. "
"If no text is recognized, reply with exactly: No text recognized"
)


class _BaseVLMOCRTextMetric(StatefulMetric):
"""
Shared base class for VLM-based OCR over rendered images, scored against ground truth in ``text_content``.

Subclasses implement how OCR and GT strings are scored and aggregated.

Parameters
----------
*args : Any
Additional positional arguments (unused; registry compatibility).
vlm : BaseVLM | None, optional
Custom VLM instance. If provided, ``vlm_type`` and ``model_name`` are ignored.
vlm_type : {'litellm', 'transformers'}, optional
VLM backend. Default is ``'litellm'``.
model_name : str | None, optional
Litellm model id or HuggingFace checkpoint id. **Required** when ``vlm`` is not
provided (e.g. ``openai/gpt-4o``).
vlm_kwargs : dict, optional
Forwarded by ``get_vlm`` to ``LitellmVLM`` or ``TransformersVLM``. For local models,
set ``model_load_kwargs`` for ``from_pretrained``; for litellm, pass extra API options.
structured_output : bool, optional
Use structured generation (litellm pydantic; transformers outlines when applicable).
Default is True.
device : str | torch.device | None, optional
Device for transformers VLM.
api_key : str | None, optional
API key for litellm.
call_type : str, optional
Call type for the metric.
**kwargs : Any
Additional arguments.

Examples
--------
OCR metrics call ``get_vlm`` directly (not ``StatefulVLMMeanScoresMetric``). Same
``hosted`` / ``local`` pattern as :func:`~pruna.evaluation.metrics.vlm_base.get_vlm`:

.. code-block:: python

import torch

from pruna.evaluation.metrics import TextScoreMetric

hosted = TextScoreMetric(vlm_type="litellm", model_name="openai/gpt-4o")
local = TextScoreMetric(
vlm_type="transformers",
model_name="HuggingFaceTB/SmolVLM-256M-Instruct",
device="cpu",
vlm_kwargs={"model_load_kwargs": {"torch_dtype": torch.float32}},
)

Use ``OneIGTextScoreMetric`` the same way for ``oneig_text_score`` / ``ocr_text_score``.
"""

default_call_type: str = "y_gt"

def __init__(
self,
*args: Any,
vlm: BaseVLM | None = None,
vlm_type: Literal["litellm", "transformers"] = "litellm",
model_name: str | None = None,
vlm_kwargs: dict | None = None,
structured_output: bool = True,
device: str | torch.device | None = None,
api_key: str | None = None,
call_type: str = SINGLE,
**kwargs: Any,
) -> None:
super().__init__(device=device)
self.device = set_to_best_available_device(device)

self.vlm = get_vlm(
vlm=vlm,
vlm_type=vlm_type,
model_name=model_name,
device=device,
api_key=api_key,
structured_output=structured_output,
**(vlm_kwargs or {}),
)
self.response_format = TextOutput if structured_output else None

self.call_type = get_call_type_for_single_metric(call_type, self.default_call_type)
self.higher_is_better = type(self).higher_is_better

@abstractmethod
def _accumulate_sample(self, text_gt: str, ocr_text: str) -> None:
"""Update metric state from one ground-truth / OCR pair."""

@abstractmethod
def _compute_result_value(self) -> float:
"""Return the scalar reported as ``MetricResult.result``."""

def update(self, x: list[Any] | torch.Tensor, gt: list[Any], outputs: torch.Tensor) -> None:
"""
Run OCR on the rendered outputs and score the extracted text against ``text_content`` (or plain-string) auxiliaries.

Parameters
----------
x : list[Any] | torch.Tensor
Batch prompts or metadata.
gt : list of dict or list of str
Auxiliaries with ``'text_content'`` as a string, a list of strings (joined with
newlines), or plain strings per batch item.
outputs : torch.Tensor
Rendered images.
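
Examples
--------
Auxiliary shapes accepted per batch item (illustrative, hypothetical values):

.. code-block:: python

    gt = [
        {"text_content": "OPEN 24 HOURS"},       # plain string
        {"text_content": ["SALE", "50% OFF"]},   # list, joined with newlines
        "EXIT",                                   # bare string
    ]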
"""
inputs = metric_data_processor(x, gt, outputs, self.call_type)
images = _process_images(inputs[0])
auxiliaries = inputs[1] if len(inputs) > 1 and isinstance(inputs[1], (list, tuple)) else [{}] * len(images)
for i, image in enumerate(images):
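# OCR one image per call; ``generate`` takes aligned lists of images and prompts.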
responses = self.vlm.generate([image], [OCR_PROMPT], response_format=self.response_format)
raw = responses[0] if responses else ""
ocr_text = get_text_from_response(raw)
aux = auxiliaries[i] if i < len(auxiliaries) else {}
text_gt = aux.get("text_content") if isinstance(aux, dict) else (aux if isinstance(aux, str) else None)
if isinstance(text_gt, list):
text_gt = "\n".join(str(x) for x in text_gt)
if text_gt is None:
raise ValueError(
f"{self.metric_name} requires 'text_content' in auxiliaries. "
"Use a benchmark that provides it (e.g. LongTextBench, OneIG)."
)
self._accumulate_sample(text_gt, ocr_text)

def compute(self) -> MetricResult:
"""
Aggregate batched contributions into a single metric value.

Returns
-------
MetricResult
Named result with ``higher_is_better`` taken from the class.
"""
value = self._compute_result_value()
return MetricResult(self.metric_name, self.__dict__, float(value))


@MetricRegistry.register("ocr_levenshtein")
@MetricRegistry.register("text_score")
class TextScoreMetric(_BaseVLMOCRTextMetric):
"""
Runs OCR, then reports the mean normalized character accuracy in [0, 1] (higher is better).

Registry: ``ocr_levenshtein`` (descriptive) and ``text_score`` (legacy).

Uses light normalization only (not the full OneIG preprocess). See
:class:`OneIGTextScoreMetric` for the OneIG composite ``ocr_text_score``.
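
For example (illustrative values, assuming ``normalize_text_simple`` leaves these
lowercase strings unchanged): OCR output ``"helo wrld"`` against ground truth
``"hello world"`` has edit distance 2 over 11 characters, giving a per-sample
score of ``1 - 2/11 ≈ 0.82``.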

Parameters
----------
*args : Any
Additional positional arguments (unused; registry compatibility).
vlm : BaseVLM | None, optional
Custom VLM instance. If provided, ``vlm_type`` and ``model_name`` are ignored.
vlm_type : {'litellm', 'transformers'}, optional
VLM backend. Default is ``'litellm'``.
model_name : str | None, optional
Litellm model id or HuggingFace checkpoint id. **Required** when ``vlm`` is not
provided (e.g. ``openai/gpt-4o``).
vlm_kwargs : dict, optional
Forwarded by ``get_vlm`` to ``LitellmVLM`` or ``TransformersVLM``. For local models,
set ``model_load_kwargs`` for ``from_pretrained``; for litellm, pass extra API options.
structured_output : bool, optional
Use structured generation (litellm pydantic; transformers outlines when applicable).
Default is True.
device : str | torch.device | None, optional
Device for transformers VLM.
api_key : str | None, optional
API key for litellm.
call_type : str, optional
Call type for the metric.
**kwargs : Any
Additional keyword arguments forwarded to :class:`_BaseVLMOCRTextMetric`.
"""

scores: list[float]
higher_is_better: bool = True
metric_name: str = "text_score"

def __init__(
self,
*args: Any,
vlm: BaseVLM | None = None,
vlm_type: Literal["litellm", "transformers"] = "litellm",
model_name: str | None = None,
vlm_kwargs: dict[str, Any] | None = None,
structured_output: bool = True,
device: str | torch.device | None = None,
api_key: str | None = None,
call_type: str = SINGLE,
**kwargs: Any,
) -> None:
super().__init__(
*args,
vlm=vlm,
vlm_type=vlm_type,
model_name=model_name,
vlm_kwargs=vlm_kwargs,
structured_output=structured_output,
device=device,
api_key=api_key,
call_type=call_type,
**kwargs,
)
self.add_state("scores", [])

def _accumulate_sample(self, text_gt: str, ocr_text: str) -> None:
norm_gt = normalize_text_simple(text_gt)
norm_ocr = normalize_text_simple(ocr_text)
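# Character-level accuracy: 1 minus the normalized edit distance, clamped to [0, 1].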
ed = levenshtein(norm_ocr, norm_gt)
denom = max(float(len(norm_gt)), 1.0)
self.scores.append(1.0 - min(1.0, ed / denom))

def _compute_result_value(self) -> float:
if not self.scores:
return 0.0
return float(np.mean(self.scores))


@MetricRegistry.register("ocr_text_score")
@MetricRegistry.register("oneig_text_score")
class OneIGTextScoreMetric(_BaseVLMOCRTextMetric):
"""
Runs OCR, then computes the OneIG-style composite text score (higher is better).

Registry: ``ocr_text_score`` (descriptive) and ``oneig_text_score`` (protocol).

Aggregates edit distance, completion rate, and word/char accuracy like
``OneIG-Benchmark/scripts/text/text_score.py``.

Parameters
----------
*args : Any
Additional positional arguments (forwarded to :class:`_BaseVLMOCRTextMetric`).
language_mode : {'EN', 'ZH'}, optional
Selects ``MAX_EDIT_DISTANCE`` (100 vs 50) for the composite.
vlm : BaseVLM | None, optional
Custom VLM instance. If provided, ``vlm_type`` and ``model_name`` are ignored.
vlm_type : {'litellm', 'transformers'}, optional
VLM backend. Default is ``'litellm'``.
model_name : str | None, optional
Litellm model id or HuggingFace checkpoint id. **Required** when ``vlm`` is not
provided (e.g. ``openai/gpt-4o``).
vlm_kwargs : dict, optional
Forwarded by ``get_vlm`` to ``LitellmVLM`` or ``TransformersVLM``. For local models,
set ``model_load_kwargs`` for ``from_pretrained``; for litellm, pass extra API options.
structured_output : bool, optional
Use structured generation (litellm pydantic; transformers outlines when applicable).
Default is True.
device : str | torch.device | None, optional
Device for transformers VLM.
api_key : str | None, optional
API key for litellm.
call_type : str, optional
Call type for the metric.
**kwargs : Any
Additional keyword arguments forwarded to :class:`_BaseVLMOCRTextMetric`.
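
Examples
--------
Same constructor pattern as :class:`TextScoreMetric` (hosted sketch; hypothetical model id):

.. code-block:: python

    metric = OneIGTextScoreMetric(
        vlm_type="litellm",
        model_name="openai/gpt-4o",
        language_mode="EN",
    )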
"""

edit_distances: list[float]
completion_ratios: list[float]
match_counts: list[int]
gt_totals: list[int]

higher_is_better: bool = True
metric_name: str = "oneig_text_score"

def __init__(
self,
*args: Any,
language_mode: Literal["EN", "ZH"] = "EN",
vlm: BaseVLM | None = None,
vlm_type: Literal["litellm", "transformers"] = "litellm",
model_name: str | None = None,
vlm_kwargs: dict[str, Any] | None = None,
structured_output: bool = True,
device: str | torch.device | None = None,
api_key: str | None = None,
call_type: str = SINGLE,
**kwargs: Any,
) -> None:
super().__init__(
*args,
vlm=vlm,
vlm_type=vlm_type,
model_name=model_name,
vlm_kwargs=vlm_kwargs,
structured_output=structured_output,
device=device,
api_key=api_key,
call_type=call_type,
**kwargs,
)
self.language_mode = language_mode
self.add_state("edit_distances", [])
self.add_state("completion_ratios", [])
self.add_state("match_counts", [])
self.add_state("gt_totals", [])

def _accumulate_sample(self, text_gt: str, ocr_text: str) -> None:
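# Per-sample pieces of the OneIG composite: edit distance, completion ratio,
# word/char match count, and ground-truth total.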
ed, cr, mcount, gtot = oneig_per_sample_contributions(text_gt, ocr_text)
self.edit_distances.append(ed)
self.completion_ratios.append(cr)
self.match_counts.append(mcount)
self.gt_totals.append(gtot)

def _compute_result_value(self) -> float:
*_, text_score = oneig_mean_text_score(
self.edit_distances,
self.completion_ratios,
self.match_counts,
self.gt_totals,
self.language_mode,
)
return text_score
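A minimal end-to-end sketch, not part of this PR (hypothetical: assumes ``get_vlm`` returns a provided custom instance unchanged and that ``metric_data_processor`` passes the rendered images through for the default call type):

.. code-block:: python

    import torch

    from pruna.evaluation.metrics.metric_text_score import TextScoreMetric

    class _StubVLM:  # stand-in for a real VLM backend
        def generate(self, images, prompts, response_format=None):
            return ["OPEN 24 HOURS"]  # pretend OCR output

    metric = TextScoreMetric(vlm=_StubVLM())
    metric.update(
        x=["a storefront sign"],
        gt=[{"text_content": "OPEN 24 HOURS"}],
        outputs=torch.rand(1, 3, 64, 64),
    )
    print(metric.compute().result)  # 1.0 for an exact OCR match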