From c933c4dd27c82e0216e069cc1d0203f45c820004 Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Sat, 25 Apr 2026 12:49:15 +0200 Subject: [PATCH 1/3] feat(vendor): add LLM2Vec embedding model - Add LLM2Vec from OneIG vendor source - Includes Llama encoder and bidirectional models - Self-contained, no dependencies on Pruna internals - Licensed under Apache 2.0 --- .../metrics/vendor/NOTICE.oneig_llm2vec | 12 + .../metrics/vendor/oneig_llm2vec/llm2vec.py | 549 ++++++++++++++++++ .../oneig_llm2vec/modeling_llama_encoder.py | 107 ++++ .../models/bidirectional_llama.py | 228 ++++++++ 4 files changed, 896 insertions(+) create mode 100644 src/pruna/evaluation/metrics/vendor/NOTICE.oneig_llm2vec create mode 100644 src/pruna/evaluation/metrics/vendor/oneig_llm2vec/llm2vec.py create mode 100644 src/pruna/evaluation/metrics/vendor/oneig_llm2vec/modeling_llama_encoder.py create mode 100644 src/pruna/evaluation/metrics/vendor/oneig_llm2vec/models/bidirectional_llama.py diff --git a/src/pruna/evaluation/metrics/vendor/NOTICE.oneig_llm2vec b/src/pruna/evaluation/metrics/vendor/NOTICE.oneig_llm2vec new file mode 100644 index 00000000..01654bd4 --- /dev/null +++ b/src/pruna/evaluation/metrics/vendor/NOTICE.oneig_llm2vec @@ -0,0 +1,12 @@ +LLM2Vec (llm2vec package) vendored from OneIG-Benchmark. + +Source: https://github.com/OneIG-Bench/OneIG-Benchmark +Commit: 41b49831e79e6dde5323618c164da1c4cf0f699d +Path: scripts/utils/llm2clip/llm2vec/ + +OneIG-Benchmark is licensed under the Apache License 2.0. +See the project repository for full license text. + +``oneig_llm2vec/modeling_llama_encoder.py`` is derived from +McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp (Hugging Face Hub); +Pruna relaxes the upstream flash-attention-only constraint for CPU use. diff --git a/src/pruna/evaluation/metrics/vendor/oneig_llm2vec/llm2vec.py b/src/pruna/evaluation/metrics/vendor/oneig_llm2vec/llm2vec.py new file mode 100644 index 00000000..102f5b28 --- /dev/null +++ b/src/pruna/evaluation/metrics/vendor/oneig_llm2vec/llm2vec.py @@ -0,0 +1,549 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Vendored from OneIG-Benchmark (commit 41b49831e79e6dde5323618c164da1c4cf0f699d). +# See NOTICE.oneig_llm2vec in parent directory. + +import json +import logging +import pathlib +from typing import Any, Dict, List, Optional, Union + +import numpy as np +import torch +import torch.multiprocessing as mp +from peft import PeftModel +from torch import Tensor, device, nn +from tqdm import trange +from transformers import ( + AutoConfig, + AutoModel, + AutoTokenizer, + LlamaConfig, + PretrainedConfig, +) + +from pruna.evaluation.metrics.vendor.oneig_llm2vec.models.bidirectional_llama import LlamaBiModel + +logger = logging.getLogger(__name__) + + +def batch_to_device(batch, target_device: device | str): + """ + Move tensor values in a batch dict to ``target_device``. + + Parameters + ---------- + batch : dict[str, Any] + Mapping of feature names to tensors or other values; only ``torch.Tensor`` + values are moved. + target_device : torch.device or str + Device to move tensors to. + + Returns + ------- + dict[str, Any] + The same ``batch`` object with tensors updated in place. + """ + for key in batch: + if isinstance(batch[key], Tensor): + batch[key] = batch[key].to(target_device) + return batch + + +class LLM2Vec(nn.Module): + """ + Bidirectional LLM wrapper with configurable pooling for dense embeddings. + + Parameters + ---------- + model : transformers.AutoModel + Encoder model used for hidden states. 
+ tokenizer : transformers.AutoTokenizer + Tokenizer aligned with ``model``. + pooling_mode : str, optional + How to pool token hidden states (e.g. ``mean``, ``eos_token``). + max_length : int, optional + Maximum sequence length for tokenization. + doc_max_length : int, optional + Soft cap used when shortening document segments during encoding. + skip_instruction : bool, optional + If True, restrict attention to embed regions when pooling. + """ + + def __init__( + self, + model: AutoModel, + tokenizer: AutoTokenizer, + pooling_mode: str = "mean", + max_length: int = 512, + doc_max_length: int = 512, + skip_instruction: bool = True, + ): + super().__init__() + self.model = model + self.tokenizer = tokenizer + self.pooling_mode = pooling_mode + self.skip_instruction = skip_instruction + self.max_length = max_length + self.doc_max_length = 512 + self.config = model.config + + @classmethod + def _get_model_class(cls, config_class_name, enable_bidirectional): + if not enable_bidirectional: + return AutoModel + elif config_class_name == "LlamaConfig": + return LlamaBiModel + else: + raise ValueError(f"{config_class_name} is not supported yet with bidirectional models.") + + @classmethod + def from_pretrained( + cls, + base_model_name_or_path, + peft_model_name_or_path=None, + merge_peft=False, + enable_bidirectional=True, + extra_model_name_or_path=None, + **kwargs, + ): + """ + Load tokenizer and encoder weights and return an ``LLM2Vec`` instance. + + Optional PEFT adapters, bidirectional Llama, and extra adapter paths are + supported; keyword arguments are forwarded to Hugging Face + ``from_pretrained`` calls. + + Parameters + ---------- + base_model_name_or_path : str or pathlib.Path + Hub id or local directory for the base model. + peft_model_name_or_path : str or pathlib.Path, optional + Optional PEFT adapter to load on top of the base model. + merge_peft : bool, optional + If True, merge PEFT weights into the base weights after loading. + enable_bidirectional : bool, optional + If True, use bidirectional Llama when the config is ``LlamaConfig``. + extra_model_name_or_path : str, list of str, or None, optional + Additional PEFT checkpoint(s) applied sequentially when set. + **kwargs + Forwarded to Hugging Face ``from_pretrained`` (and related) calls. + + Returns + ------- + LLM2Vec + Configured wrapper around the loaded encoder and tokenizer. 
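+
+        Examples
+        --------
+        A minimal sketch; the checkpoint id is the upstream model referenced in
+        ``NOTICE.oneig_llm2vec`` and is illustrative, and ``peft_model_name_or_path``
+        can be added to stack a fine-tuned adapter:
+
+        >>> l2v = LLM2Vec.from_pretrained(
+        ...     "McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp",
+        ...     pooling_mode="mean",
+        ...     max_length=512,
+        ... )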
+ """ + keys = ["pooling_mode", "max_length", "doc_max_length", "skip_instruction"] + encoder_args = {key: kwargs.pop(key, None) for key in keys if kwargs.get(key) is not None} + + tokenizer = AutoTokenizer.from_pretrained(base_model_name_or_path) + tokenizer.pad_token = tokenizer.eos_token + tokenizer.padding_side = "left" + + config = AutoConfig.from_pretrained(base_model_name_or_path) + config_class_name = config.__class__.__name__ + + model_class = cls._get_model_class(config_class_name, enable_bidirectional=enable_bidirectional) + model = model_class.from_pretrained(base_model_name_or_path, **kwargs) + + base_path = pathlib.Path(base_model_name_or_path) + config_json = base_path / "config.json" + if base_path.is_dir() and config_json.exists(): + with open(config_json, encoding="utf-8") as config_file: + config_dict = json.load(config_file) + config = PretrainedConfig.from_dict(config_dict) + model.config._name_or_path = config._name_or_path + + if hasattr(model, "peft_config"): + model = PeftModel.from_pretrained( + model, + base_model_name_or_path, + ) + model = model.merge_and_unload() + + if peft_model_name_or_path is not None: + model = PeftModel.from_pretrained( + model, + peft_model_name_or_path, + ) + if merge_peft: + model = model.merge_and_unload() + if extra_model_name_or_path is not None: + logger.info(f"Loading extra model from {extra_model_name_or_path}") + if not merge_peft: + model = model.merge_and_unload() + if isinstance(extra_model_name_or_path, str): + model = PeftModel.from_pretrained( + model, + extra_model_name_or_path, + ) + peft_model_name_or_path = extra_model_name_or_path + model = model.merge_and_unload() + elif isinstance(extra_model_name_or_path, list): + for extra_model in extra_model_name_or_path: + model = PeftModel.from_pretrained( + model, + extra_model, + ) + peft_model_name_or_path = extra_model + model = model.merge_and_unload() + else: + raise ValueError("extra_model_name_or_path should be a string or a list of strings.") + config = {} + config_addr = peft_model_name_or_path if peft_model_name_or_path is not None else base_model_name_or_path + llm2vec_config_path = pathlib.Path(config_addr) / "llm2vec_config.json" + if llm2vec_config_path.exists(): + with open(llm2vec_config_path, encoding="utf-8") as config_file: + llm2vec_config = json.load(config_file) + config.update(llm2vec_config) + logger.info(f"LLM2Vec config: {config}") + for key, value in encoder_args.items(): + config[key] = value + + return cls(model=model, tokenizer=tokenizer, **config) + + def prepare_for_tokenization(self, text): + """ + Apply model-specific chat or EOS wrappers so tokenization matches training. + + Parameters + ---------- + text : str + Raw input text before tokenization. + + Returns + ------- + str + Text with any required special tokens or chat template prefixes or suffixes. 
+ """ + if "Llama-3" in self.model.config._name_or_path and "Instruct" in self.model.config._name_or_path: + text = "<|start_header_id|>user<|end_header_id|>\n\n" + text.strip() + "<|eot_id|>" + return text + if self.model.config._name_or_path == "microsoft/Phi-3.5-mini-instruct": + text = "<|user|>\n" + text.strip() + "<|end|>\n" + return text + if self.pooling_mode == "eos_token": + if self.model.config._name_or_path == "meta-llama/Meta-Llama-3-8B": + text = text.strip() + "<|end_of_text|>" + elif isinstance(self.model.config, LlamaConfig): + text = text.strip() + " " + return text + + def tokenize(self, texts): + """ + Tokenize texts with optional embed-region markers for instruction/document split. + + Parameters + ---------- + texts : list of str + Strings that may contain the ``!@#$%^&*()`` delimiter between instruction and document. + + Returns + ------- + dict[str, torch.Tensor] + Tokenizer outputs including ``embed_mask`` when the delimiter is present. + """ + texts_2 = [] + original_texts = [] + for text in texts: + t = text.split("!@#$%^&*()") + texts_2.append(t[1] if len(t) > 1 else "") + original_texts.append("".join(t)) + + original = self.tokenizer( + original_texts, + return_tensors="pt", + padding=True, + truncation=True, + max_length=self.max_length, + ) + embed_mask = None + for t_i, t in enumerate(texts_2): + ids = self.tokenizer( + [t], + return_tensors="pt", + padding=True, + truncation=True, + max_length=self.max_length, + add_special_tokens=False, + ) + if embed_mask is None: + e_m = torch.zeros_like(original["attention_mask"][t_i]) + if len(ids["input_ids"][0]) > 0: + e_m[-len(ids["input_ids"][0]) :] = torch.ones(len(ids["input_ids"][0])) + embed_mask = e_m.unsqueeze(0) + else: + e_m = torch.zeros_like(original["attention_mask"][t_i]) + if len(ids["input_ids"][0]) > 0: + e_m[-len(ids["input_ids"][0]) :] = torch.ones(len(ids["input_ids"][0])) + embed_mask = torch.cat((embed_mask, e_m.unsqueeze(0)), dim=0) + + original["embed_mask"] = embed_mask + return original + + def _skip_instruction(self, sentence_feature): + assert sentence_feature["attention_mask"].shape == sentence_feature["embed_mask"].shape + sentence_feature["attention_mask"] = sentence_feature["embed_mask"] + + def forward(self, sentence_feature: Dict[str, Tensor]): + """ + Run the encoder and return pooled sentence embeddings. + + Parameters + ---------- + sentence_feature : dict[str, torch.Tensor] + Batch of tokenizer outputs; may include ``embed_mask`` for instruction masking. + + Returns + ------- + torch.Tensor + Pooled embeddings with shape ``(batch_size, hidden_size)``. + """ + embed_mask = None + if "embed_mask" in sentence_feature: + embed_mask = sentence_feature.pop("embed_mask") + reps = self.model(**sentence_feature) + if embed_mask is not None: + sentence_feature["embed_mask"] = embed_mask + + return self.get_pooling(sentence_feature, reps.last_hidden_state) + + def get_pooling(self, features, last_hidden_states): + """ + Pool token hidden states according to ``pooling_mode``. + + Parameters + ---------- + features : dict[str, torch.Tensor] + Tokenizer batch (attention mask, optional ``embed_mask``, etc.). + last_hidden_states : torch.Tensor + Sequence hidden states from the encoder, shape ``(batch, seq, hidden)``. + + Returns + ------- + torch.Tensor + Pooled embeddings, shape ``(batch, hidden)``. + """ + assert self.tokenizer.padding_side == "left", "Pooling modes are implemented for padding from left." 
+ if self.skip_instruction: + self._skip_instruction(features) + seq_lengths = features["attention_mask"].sum(dim=-1) + if self.pooling_mode == "mean": + return torch.stack( + [last_hidden_states[i, -length:, :].mean(dim=0) for i, length in enumerate(seq_lengths)], + dim=0, + ) + elif self.pooling_mode == "weighted_mean": + bs, seq_len, _ = last_hidden_states.shape + complete_weights = torch.zeros(bs, seq_len, device=last_hidden_states.device) + for i, seq_l in enumerate(seq_lengths): + if seq_l > 0: + complete_weights[i, -seq_l:] = torch.arange(seq_l) + 1 + complete_weights[i] /= torch.clamp(complete_weights[i].sum(), min=1e-9) + return torch.sum(last_hidden_states * complete_weights.unsqueeze(-1), dim=1) + elif self.pooling_mode == "eos_token" or self.pooling_mode == "last_token": + return last_hidden_states[:, -1] + elif self.pooling_mode == "bos_token": + return last_hidden_states[features["input_ids"] == self.tokenizer.bos_token_id] + else: + raise ValueError(f"{self.pooling_mode} is not implemented yet.") + + def _convert_to_str(self, instruction, text): + tokenized_q = self.tokenizer( + text, + return_tensors="pt", + padding=True, + truncation=True, + max_length=self.max_length, + add_special_tokens=False, + ) + tokenized_q_length = len(tokenized_q["input_ids"][0]) + + while tokenized_q_length > self.doc_max_length: + reduction_ratio = self.doc_max_length / tokenized_q_length + reduced_length = int(len(text.split()) * reduction_ratio) + text = " ".join(text.split()[:reduced_length]) + tokenized_q = self.tokenizer( + text, + return_tensors="pt", + padding=True, + truncation=True, + max_length=self.max_length, + add_special_tokens=False, + ) + tokenized_q_length = len(tokenized_q["input_ids"][0]) + + return f"{instruction.strip()} !@#$%^&*(){text}" if instruction else f"!@#$%^&*(){text}" + + def encode( + self, + sentences: Union[str, List[str]], + batch_size: int = 32, + show_progress_bar: bool = True, + convert_to_numpy: bool = False, + convert_to_tensor: bool = True, + device: Optional[str] = None, + ): + """ + Encode sentences (optionally instruction + document) to embedding tensors. + + Parameters + ---------- + sentences : str, list of str, or nested list + Plain strings, or ``[instruction, document]`` pairs, or batches thereof. + batch_size : int, optional + Micro-batch size during encoding. + show_progress_bar : bool, optional + Ignored; progress is disabled in the implementation. + convert_to_numpy : bool, optional + If True, return a NumPy array instead of a tensor (mutually exclusive with ``convert_to_tensor``). + convert_to_tensor : bool, optional + If True (default), return a ``torch.Tensor`` of dtype float32. + device : str, optional + Device name; defaults to CUDA when available else CPU. + + Returns + ------- + torch.Tensor or numpy.ndarray + Stacked embeddings for all inputs, reordered to the original sentence order. 
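+
+        Examples
+        --------
+        Instruction/document pairs, one pair per input row (sketch; ``l2v`` is a loaded
+        instance and the instruction text is illustrative):
+
+        >>> embeddings = l2v.encode(
+        ...     [["Represent this caption for retrieval:", "a cat on a sofa"]],
+        ...     batch_size=8,
+        ... )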
+ """ + seq: Any = sentences + if isinstance(seq[0], str) and isinstance(seq[-1], int): + seq = [seq] + if isinstance(seq[0], str): + seq = [[""] + [sentence] for sentence in seq] + + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + + concatenated_input_texts = [] + for sentence in seq: + assert isinstance(sentence[0], str) + assert isinstance(sentence[1], str) + concatenated_input_texts.append(self._convert_to_str(sentence[0], sentence[1])) + sentences = concatenated_input_texts + + self.train(mode=False) + + if convert_to_tensor: + convert_to_numpy = False + + length_sorted_idx = np.argsort([-self._text_length(sen) for sen in sentences]) + sentences_sorted = [sentences[idx] for idx in length_sorted_idx] + all_embeddings = [] + + self.to(device) + for start_index in trange( + 0, + len(sentences), + batch_size, + desc="Batches", + disable=True, + ): + sentences_batch = sentences_sorted[start_index : start_index + batch_size] + embeddings = self._encode(sentences_batch, device=device, convert_to_numpy=convert_to_numpy) + all_embeddings.append(embeddings) + + all_embeddings = torch.cat(all_embeddings, dim=0) + all_embeddings = all_embeddings[np.argsort(length_sorted_idx)] + all_embeddings = all_embeddings.to(torch.float32) + return all_embeddings + + def save(self, output_path, merge_before_save=False, save_config=True): + """ + Persist model, tokenizer, and optional ``llm2vec_config.json`` to ``output_path``. + + Parameters + ---------- + output_path : str or pathlib.Path + Directory to write weights and tokenizer files into. + merge_before_save : bool, optional + If True and the inner model is a ``PeftModel``, merge adapters before saving. + save_config : bool, optional + If True, write ``llm2vec_config.json`` with pooling and length settings. 
+ """ + if merge_before_save and isinstance(self.model, PeftModel): + self.model = self.model.merge_and_unload() + if hasattr(self.model, "_hf_peft_config_loaded"): + setattr(self.model, "_hf_peft_config_loaded", False) + + self.model.save_pretrained(output_path) + self.tokenizer.save_pretrained(output_path) + + llm2vec_config = { + "pooling_mode": self.pooling_mode, + "max_length": self.max_length, + "doc_max_length": self.doc_max_length, + "skip_instruction": self.skip_instruction, + } + + if save_config: + pathlib.Path(output_path).mkdir(exist_ok=True, parents=True) + config_out = pathlib.Path(output_path) / "llm2vec_config.json" + with open(config_out, "w", encoding="utf-8") as config_file: + json.dump(llm2vec_config, config_file, indent=4) + + def _encode( + self, + sentences_batch, + device: Optional[str] = None, + convert_to_numpy: bool = False, + multiprocessing=False, + ): + if multiprocessing: + rank = mp.current_process()._identity[0] + if device is None and torch.cuda.is_available(): + device = f"cuda:{rank % torch.cuda.device_count()}" + + use_device = device if device is not None else ("cuda" if torch.cuda.is_available() else "cpu") + self.to(use_device) + features = self.tokenize([self.prepare_for_tokenization(sentence) for sentence in sentences_batch]) + features = batch_to_device(features, use_device) + + with torch.no_grad(): + embeddings = self.forward(features) + return embeddings + + def _text_length(self, text: Union[List[int], List[List[int]]]): + if isinstance(text, str) or (isinstance(text, list) and isinstance(text[0], int)) or len(text) == 0: + return len(text) + if isinstance(text, dict): + return len(next(iter(text.values()))) + elif not hasattr(text, "__len__"): + return 1 + else: + return sum(len(t) if not isinstance(t, int) else 1 for t in text) + + def resize_token_embeddings( + self, + new_num_tokens: Optional[int] = None, + pad_to_multiple_of: Optional[int] = None, + ) -> nn.Embedding: + """ + Resize the underlying model token embedding matrix. + + Parameters + ---------- + new_num_tokens : int, optional + New vocabulary size for the embedding table. + pad_to_multiple_of : int, optional + Pad vocabulary size to a multiple of this value when resizing. + + Returns + ------- + torch.nn.Embedding + The resized embedding module from the wrapped model. + """ + return self.model.resize_token_embeddings(new_num_tokens=new_num_tokens, pad_to_multiple_of=pad_to_multiple_of) + + def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None): + """ + Enable gradient checkpointing on the wrapped model. + + Parameters + ---------- + gradient_checkpointing_kwargs : dict, optional + Keyword arguments forwarded to the underlying ``gradient_checkpointing_enable`` call. + """ + self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs) diff --git a/src/pruna/evaluation/metrics/vendor/oneig_llm2vec/modeling_llama_encoder.py b/src/pruna/evaluation/metrics/vendor/oneig_llm2vec/modeling_llama_encoder.py new file mode 100644 index 00000000..cf9b4df8 --- /dev/null +++ b/src/pruna/evaluation/metrics/vendor/oneig_llm2vec/modeling_llama_encoder.py @@ -0,0 +1,107 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Derived from McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp ``modeling_llama_encoder.py`` +# (Hugging Face Hub). Upstream requires ``flash_attention_2`` only; this copy allows ``eager`` +# and ``sdpa`` so ``oneig_reasoning`` can run on CPU without ``flash_attn``. 
See +# ``NOTICE.oneig_llm2vec`` in the parent ``vendor`` directory. + +import importlib.metadata + +from packaging import version +from torch import nn +from transformers import LlamaConfig, LlamaModel, LlamaPreTrainedModel +from transformers.modeling_layers import GradientCheckpointingLayer +from transformers.models.llama.modeling_llama import ( + LlamaAttention, + LlamaDecoderLayer, + LlamaMLP, + LlamaRMSNorm, + LlamaRotaryEmbedding, +) +from transformers.utils import logging +from transformers.utils.import_utils import _is_package_available + +logger = logging.get_logger(__name__) + + +def is_transformers_attn_greater_or_equal_4_56_2() -> bool: + """ + Check whether the installed ``transformers`` package is at least 4.56.2. + + Returns + ------- + bool + True if ``transformers`` is installed and its version is >= 4.56.2; + False otherwise. + """ + if not _is_package_available("transformers"): + return False + return version.parse(importlib.metadata.version("transformers")) >= version.parse("4.56.2") + + +class ModifiedLlamaAttention(LlamaAttention): + """ + Llama self-attention with ``is_causal`` disabled for encoder-style use. + + Parameters + ---------- + *args, **kwargs + Forwarded to :class:`~transformers.models.llama.modeling_llama.LlamaAttention`. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.is_causal = False + + +class ModifiedLlamaDecoderLayer(LlamaDecoderLayer): + """ + Decoder block using :class:`ModifiedLlamaAttention` for bidirectional encoding. + + Parameters + ---------- + config : LlamaConfig + Model configuration. + layer_idx : int + Index of this decoder layer. + """ + + def __init__(self, config: LlamaConfig, layer_idx: int): + GradientCheckpointingLayer.__init__(self) + self.hidden_size = config.hidden_size + self.self_attn = ModifiedLlamaAttention(config=config, layer_idx=layer_idx) + self.mlp = LlamaMLP(config) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + +class LlamaEncoderModel(LlamaModel): + """ + Bidirectional Llama stack for LLM2Vec-style encoding (eager, SDPA, or flash attention). + + Parameters + ---------- + config : LlamaConfig + Model configuration (requires transformers >= 4.56.2 layout). 
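+
+    Examples
+    --------
+    Sketch; the checkpoint id is illustrative and ``attn_implementation`` may be
+    ``"eager"``, ``"sdpa"``, or ``"flash_attention_2"``:
+
+    >>> model = LlamaEncoderModel.from_pretrained(
+    ...     "meta-llama/Meta-Llama-3-8B", attn_implementation="sdpa"
+    ... )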
+ """ + + def __init__(self, config: LlamaConfig) -> None: + if not is_transformers_attn_greater_or_equal_4_56_2(): + raise ValueError( + "The current implementation of LlamaEncoderModel follows modeling_llama.py " + "of transformers version >= 4.56.2" + ) + LlamaPreTrainedModel.__init__(self, config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [ModifiedLlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + attn_impl = getattr(config, "_attn_implementation", getattr(config, "attn_implementation", "eager")) + self._use_sdpa = attn_impl == "sdpa" + self._use_flash_attention_2 = attn_impl == "flash_attention_2" + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = LlamaRotaryEmbedding(config=config) + self.gradient_checkpointing = False + self.post_init() diff --git a/src/pruna/evaluation/metrics/vendor/oneig_llm2vec/models/bidirectional_llama.py b/src/pruna/evaluation/metrics/vendor/oneig_llm2vec/models/bidirectional_llama.py new file mode 100644 index 00000000..610853ac --- /dev/null +++ b/src/pruna/evaluation/metrics/vendor/oneig_llm2vec/models/bidirectional_llama.py @@ -0,0 +1,228 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Vendored from OneIG-Benchmark (commit 41b49831e79e6dde5323618c164da1c4cf0f699d). + +import importlib.metadata +from typing import cast + +import torch +from packaging import version +from peft import PeftModel +from torch import nn +from transformers import ( + LlamaConfig, + LlamaForCausalLM, + LlamaModel, + LlamaPreTrainedModel, +) +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.models.llama.modeling_llama import ( + LlamaDecoderLayer, + LlamaRMSNorm, + LlamaRotaryEmbedding, +) +from transformers.utils import logging +from transformers.utils.import_utils import _is_package_available + +logger = logging.get_logger(__name__) + + +def is_transformers_attn_greater_or_equal_4_38() -> bool: + """ + Check whether the installed ``transformers`` package is at least 4.38.0. + + Returns + ------- + bool + True if ``transformers`` is installed and its version is >= 4.38.0; + False otherwise. + """ + if not _is_package_available("transformers"): + return False + return version.parse(importlib.metadata.version("transformers")) >= version.parse("4.38.0") + + +def is_transformers_attn_greater_or_equal_4_40() -> bool: + """ + Check whether the installed ``transformers`` package is at least 4.40.0. + + Returns + ------- + bool + True if ``transformers`` is installed and its version is >= 4.40.0; + False otherwise. + """ + if not _is_package_available("transformers"): + return False + return version.parse(importlib.metadata.version("transformers")) >= version.parse("4.40.0") + + +class ModifiedLlamaDecoderLayer(LlamaDecoderLayer): + """ + Decoder layer with non-causal self-attention when supported by the attention module. + + Parameters + ---------- + config : LlamaConfig + Model configuration. + layer_idx : int + Index of this decoder layer. + """ + + def __init__(self, config: LlamaConfig, layer_idx: int): + super().__init__(config, layer_idx) + if hasattr(self.self_attn, "is_causal"): + self.self_attn.is_causal = False + + +class LlamaBiModel(LlamaModel): + """ + Bidirectional Llama backbone for MNTP-style training (transformers >= 4.38). 
+ + Parameters + ---------- + config : LlamaConfig + Model configuration. + """ + + _no_split_modules = ["ModifiedLlamaDecoderLayer"] + + def __init__(self, config: LlamaConfig): + if not is_transformers_attn_greater_or_equal_4_38(): + raise ValueError( + "The current implementation of LlamaBiModel follows modeling_llama.py of transformers version >= 4.38.0" + ) + LlamaPreTrainedModel.__init__(self, config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + + self.layers = nn.ModuleList( + [ModifiedLlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = LlamaRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + self.post_init() + + def _update_causal_mask( + self, + attention_mask, + input_tensor, + cache_position, + past_seen_tokens=None, + output_attentions=False, + ): + attn_impl = getattr(self.config, "_attn_implementation", getattr(self.config, "attn_implementation", "eager")) + if attn_impl == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + + if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"): + target_length = self.config.max_position_embeddings + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else ( + cache_position[-1] + 1 + if not is_transformers_attn_greater_or_equal_4_40() + else past_seen_tokens + sequence_length + 1 + ) + ) + + causal_mask = torch.zeros((sequence_length, target_length), dtype=dtype, device=device) + + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + + if attention_mask is not None: + causal_mask = causal_mask.clone() + if attention_mask.dim() == 2: + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0) + causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype) + elif attention_mask.dim() == 4: + offset = cache_position[0] if attention_mask.shape[-2] < cache_position[0] + sequence_length else 0 + mask_shape = attention_mask.shape + mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype + causal_mask[ + : mask_shape[0], + : mask_shape[1], + offset : mask_shape[2] + offset, + : mask_shape[3], + ] = mask_slice + + attn_impl = getattr(self.config, "_attn_implementation", getattr(self.config, "attn_implementation", "eager")) + if ( + attn_impl == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + causal_mask = AttentionMaskConverter._unmask_unattended( + cast(torch.FloatTensor, causal_mask.to(dtype=torch.float32)), + min_dtype, + ) + + return causal_mask + + +class LlamaBiForMNTP(LlamaForCausalLM): + """ + Causal LM wrapper around :class:`LlamaBiModel` for MNTP with optional PEFT. + + Parameters + ---------- + config : LlamaConfig + Model configuration. 
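+
+    Examples
+    --------
+    Sketch; the checkpoint id is illustrative (any checkpoint with a ``LlamaConfig`` works):
+
+    >>> model = LlamaBiForMNTP.from_pretrained("meta-llama/Meta-Llama-3-8B")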
+ """ + + def __init__(self, config: LlamaConfig): + LlamaPreTrainedModel.__init__(self, config) + self.model = LlamaBiModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.post_init() + + def get_model_for_peft(self) -> LlamaBiModel | PeftModel: + """ + Return the inner model for PEFT wrapping (base or wrapped). + + Returns + ------- + LlamaBiModel or PeftModel + ``self.model``, either a :class:`LlamaBiModel` or a :class:`peft.PeftModel`. + """ + return self.model + + def set_model_for_peft(self, model: PeftModel) -> None: + """ + Replace the inner model with a PEFT-wrapped model. + + Parameters + ---------- + model : PeftModel + PEFT model whose base matches the expected backbone. + """ + self.model = model + + def save_peft_model(self, path: str) -> None: + """ + Save the (possibly PEFT-wrapped) inner model to disk. + + Parameters + ---------- + path : str + Directory path passed to ``save_pretrained`` on the inner model. + """ + self.model.save_pretrained(path) From 21212de87118cda902b76d138228109956f93154 Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Sat, 25 Apr 2026 12:50:58 +0200 Subject: [PATCH 2/3] feat(infrastructure): add VLM base classes and utilities - Add BaseVLM abstract interface - Add LitellmVLM for API-based inference (OpenAI, Anthropic, etc.) - Add TransformersVLM for local Hugging Face models - Add StatefulVLMMeanScoresMetric base class for judge metrics - Add vlm_utils.py with image/batch utilities - Add pyproject.toml dependency pins (peft, litellm) - Add unit tests for infrastructure --- pyproject.toml | 23 +- src/pruna/evaluation/metrics/__init__.py | 49 +- src/pruna/evaluation/metrics/utils.py | 27 +- src/pruna/evaluation/metrics/vlm_base.py | 1118 +++++++++++++++++ src/pruna/evaluation/metrics/vlm_utils.py | 394 ++++++ .../test_vlm_base_infrastructure.py | 684 ++++++++++ 6 files changed, 2265 insertions(+), 30 deletions(-) create mode 100644 src/pruna/evaluation/metrics/vlm_base.py create mode 100644 src/pruna/evaluation/metrics/vlm_utils.py create mode 100644 tests/evaluation/test_vlm_base_infrastructure.py diff --git a/pyproject.toml b/pyproject.toml index d70093bf..41139e8d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,9 +36,6 @@ possibly-missing-attribute = "ignore" missing-argument = "ignore" unused-type-ignore-comment = "ignore" -[tool.bandit] -exclude_dirs = ["tests", "docs"] - [tool.coverage.run] source = ["src/pruna"] @@ -96,14 +93,14 @@ stable-fast-pruna = { index = "pruna_internal", extra = "stable-fast-extraindex" [project] name = "pruna" -version = "0.3.3" +version = "0.3.2" description = "Smash your AI models" authors = [ {name = "Pruna AI", email = "hello@pruna.ai"} ] license = {file = "LICENSE"} readme = "README.md" -requires-python = ">=3.10,<3.14" +requires-python = ">=3.10,<3.13" keywords = ["AI", "machine learning", "model optimization", "pruning"] classifiers = [ "Development Status :: 4 - Beta", @@ -156,6 +153,7 @@ dependencies = [ "peft>=0.18.0,<0.19.0", "trl<=0.21.0", "termcolor==2.3.0", + "realesrgan", ] [project.optional-dependencies] @@ -170,6 +168,10 @@ vllm = [ "vllm>=0.16.0", "ray", ] +evaluation = [ + "outlines>1.2.0,<2.0.0", + "litellm>=1.0.0", +] stable-fast = [ "xformers>=0.0.30", "stable-fast-pruna>=1.0.8,<1.0.9", @@ -194,18 +196,12 @@ awq = [ "llmcompressor>=0.9", "torch>=2.9.0" ] -upscale = [ - "realesrgan", -] full = [ "pruna[stable-fast]", ] vbench = [ "vbench-pruna; sys_platform != 'darwin'", ] -rapidata = [ - 
"rapidata>=3.0.0" -] dev = [ "wget", "python-dotenv", @@ -232,15 +228,12 @@ dev = [ "types-PyYAML", "logbar", "pytest-xdist>=3.8.0", + "pruna[evaluation]", ] cpu = [] lmharness = [ "lm-eval>=0.4.0" ] -evaluation = [ - "pruna[rapidata]", - "pruna[lmharness]" -] # Intel extension is tightly coupled with the torch version intel = [ diff --git a/src/pruna/evaluation/metrics/__init__.py b/src/pruna/evaluation/metrics/__init__.py index bf7414c3..6d43473f 100644 --- a/src/pruna/evaluation/metrics/__init__.py +++ b/src/pruna/evaluation/metrics/__init__.py @@ -17,15 +17,41 @@ from pruna.evaluation.metrics.aesthetic_laion import AestheticLAION from pruna.evaluation.metrics.metric_cmmd import CMMD from pruna.evaluation.metrics.metric_dino_score import DinoScore -from pruna.evaluation.metrics.metric_elapsed_time import LatencyMetric, ThroughputMetric, TotalTimeMetric -from pruna.evaluation.metrics.metric_energy import CO2EmissionsMetric, EnergyConsumedMetric +from pruna.evaluation.metrics.metric_elapsed_time import ( + LatencyMetric, + ThroughputMetric, + TotalTimeMetric, +) +from pruna.evaluation.metrics.metric_energy import ( + CO2EmissionsMetric, + EnergyConsumedMetric, +) from pruna.evaluation.metrics.metric_evalharness import LMEvalMetric -from pruna.evaluation.metrics.metric_memory import DiskMemoryMetric, InferenceMemoryMetric, TrainingMemoryMetric -from pruna.evaluation.metrics.metric_model_architecture import TotalMACsMetric, TotalParamsMetric +from pruna.evaluation.metrics.metric_img_edit_score import ImageEditScoreMetric +from pruna.evaluation.metrics.metric_memory import ( + DiskMemoryMetric, + InferenceMemoryMetric, + TrainingMemoryMetric, +) +from pruna.evaluation.metrics.metric_model_architecture import ( + TotalMACsMetric, + TotalParamsMetric, +) +from pruna.evaluation.metrics.metric_oneig_alignment import OneIGAlignmentMetric from pruna.evaluation.metrics.metric_pairwise_clip import PairwiseClipScore -from pruna.evaluation.metrics.metric_rapiddata import RapidataMetric as RapidataMetric +from pruna.evaluation.metrics.metric_qa_accuracy import QAAccuracyMetric from pruna.evaluation.metrics.metric_sharpness import SharpnessMetric +from pruna.evaluation.metrics.metric_text_score import OneIGTextScoreMetric, TextScoreMetric from pruna.evaluation.metrics.metric_torch import TorchMetricWrapper +from pruna.evaluation.metrics.metric_vie_score import VieScoreMetric +from pruna.evaluation.metrics.metric_vqa import VQAMetric +from pruna.evaluation.metrics.vlm_base import ( + BaseVLM, + LitellmVLM, + StatefulVLMMeanScoresMetric, + TransformersVLM, + get_vlm, +) __all__ = [ "MetricRegistry", @@ -46,5 +72,16 @@ "SharpnessMetric", "AestheticLAION", "LMEvalMetric", - "RapidataMetric", + "VQAMetric", + "ImageEditScoreMetric", + "QAAccuracyMetric", + "OneIGAlignmentMetric", + "TextScoreMetric", + "OneIGTextScoreMetric", + "VieScoreMetric", + "BaseVLM", + "LitellmVLM", + "StatefulVLMMeanScoresMetric", + "TransformersVLM", + "get_vlm", ] diff --git a/src/pruna/evaluation/metrics/utils.py b/src/pruna/evaluation/metrics/utils.py index 29342701..a3cdb4a5 100644 --- a/src/pruna/evaluation/metrics/utils.py +++ b/src/pruna/evaluation/metrics/utils.py @@ -56,13 +56,17 @@ def metric_data_processor( This function determines the order and selection of inputs to be passed to various metrics. 
The function supports different input arrangements through the 'call_type' configuration: - - 'x_y': Uses input data (x) and model outputs - - 'gt_y': Uses ground truth (gt) and model outputs - - 'y_x': Uses model outputs and input data (x) - - 'y_gt': Uses model outputs and ground truth (gt) - - 'pairwise_gt_y': Uses cached base model outputs (gt) and smashed model outputs (y). - - 'pairwise_y_gt': Uses smashed model outputs (y) and cached base model outputs (gt). - The evaluation agent is expected to pass the cached base model outputs as gt. + + - 'y_gt': Model's output first, then ground truth. Returns [outputs, gt]. + - 'gt_y': Ground truth first, then model's output. Returns [gt, outputs]. + - 'y_x': Model's output first, then input data. Returns [outputs, x]. + Used by CLIPScore, VQA, ImageEditScore, VIEScore. + - 'x_y': Input data first, then model's output. Returns [x, outputs]. + - 'x_gt': Input data first, then ground truth. Returns [x, gt]. + - 'gt_x': Ground truth first, then input data. Returns [gt, x]. + - 'pairwise_y_gt': Base model's output first, then subsequent model's output. + - 'pairwise_gt_y': Subsequent model's output first, then base model's output. + - 'y': Only the output is used; the metric has an internal dataset. Returns [outputs]. Parameters ---------- @@ -85,7 +89,8 @@ def metric_data_processor( Raises ------ ValueError - If the specified call_type is not one of: 'x_y', 'gt_y', 'y_x', 'y_gt', 'pairwise'. + If the specified call_type is not one of: 'y_gt', 'gt_y', 'y_x', 'x_y', + 'x_gt', 'gt_x', 'pairwise_y_gt', 'pairwise_gt_y', 'y'. Examples -------- @@ -106,11 +111,15 @@ def metric_data_processor( return [outputs, x] elif call_type == "y_gt": return [outputs, gt] + elif call_type == "x_gt": + return [x, gt] + elif call_type == "gt_x": + return [gt, x] elif call_type == "pairwise_gt_y": return [gt, outputs] elif call_type == "pairwise_y_gt": return [outputs, gt] - elif call_type == "y": # IQA metrics that have an internal dataset + elif call_type == "y": return [outputs] else: raise ValueError(f"Invalid call type: {call_type}") diff --git a/src/pruna/evaluation/metrics/vlm_base.py b/src/pruna/evaluation/metrics/vlm_base.py new file mode 100644 index 00000000..2de7e164 --- /dev/null +++ b/src/pruna/evaluation/metrics/vlm_base.py @@ -0,0 +1,1118 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +VLM (Vision-Language Model) base classes for metrics. + +Implementations +--------------- +- **LitellmVLM** — API inference via ``litellm`` (many providers behind one client). +- **TransformersVLM** — local Hugging Face models on device. + +Why LiteLLM for the default API path +-------------------------------------- +Judge-style metrics need a capable vision-language model. Loading large VLMs locally is +expensive; routing through ``litellm`` keeps the default path lightweight and matches common +API-judge setups without bundling a full local VLM in every metric run. 
+ +API keys and environment +------------------------ +For ``vlm_type="litellm"``, the key passed to the provider is resolved in this order: + +1. The ``api_key`` argument on the metric or :func:`get_vlm` +2. ``LITELLM_API_KEY`` +3. ``OPENAI_API_KEY`` + +Routes such as ``openai/gpt-4o`` use the OpenAI-compatible key. Other providers follow +LiteLLM’s environment conventions (for example ``ANTHROPIC_API_KEY`` for ``anthropic/...``). +The same ``OPENAI_API_KEY`` you use for other OpenAI-hosted judges (for example in pbench) +applies here. + +For a short user-facing summary of key order, hosted vs local, and a minimal ``transformers`` +example, see :doc:`Evaluate a model ` (Vision-language judge +metrics). + +Choosing local vs API +--------------------- +Metrics in :data:`VLM_METRIC_REGISTRY_NAMES` take ``vlm_type`` and ``model_name``: + +- **API** (``vlm_type="litellm"``, default) — use a vision-capable route (e.g. + :data:`DEFAULT_LITELLM_MODEL`). +- **Local** (``vlm_type="transformers"``) — e.g. SmolVLM for offline or CI. + +The ``oneig_reasoning`` metric is separate: it runs the LLM2CLIP stack locally; see +``pruna.evaluation.metrics.metric_oneig_reasoning``. + +Structured outputs +------------------ +- LitellmVLM: pydantic ``response_format`` where applicable. +- TransformersVLM: Outlines 1.x constrained decoding via ``outlines.Generator`` and + ``outlines.models.transformers.from_transformers`` (single- and multi-image ``Chat`` inputs). + +Usage examples +---------------- +Minimal LiteLLM and local ``transformers`` construction is shown under :func:`get_vlm` +(``Examples`` section). **Registry metrics** (``vqa``, ``qa_accuracy``, +``img_edit_score``, OCR/text metrics, ``oneig_alignment``, ``vie_score``, …) take the same +``vlm_type``, ``model_name``, ``api_key``, and ``vlm_kwargs`` pattern; see +:class:`StatefulVLMMeanScoresMetric` below and each metric class docstring. + +For VIEScore-style **text--image editing** metrics that pass two PIL images per prompt (source +then edited), call :meth:`LitellmVLM.generate_with_image_lists` or +:meth:`TransformersVLM.generate_with_image_lists` with ``image_lists[i]`` aligned to +``prompts[i]``. + +Metric-level (hosted vs local) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +``Task(request=["vqa"], ...)`` supplies ``model_name="openai/gpt-4o"`` for VLM registry names. +Override the backend by constructing a metric instance: + +.. code-block:: python + + import torch + + from pruna.evaluation.metrics import VQAMetric + + hosted = VQAMetric(vlm_type="litellm", model_name="openai/gpt-4o") + local = VQAMetric( + vlm_type="transformers", + model_name="HuggingFaceTB/SmolVLM-256M-Instruct", + device="cpu", + vlm_kwargs={"model_load_kwargs": {"torch_dtype": torch.float32}}, + ) + +``QAAccuracyMetric`` and other classes that call :func:`get_vlm` directly use the same +arguments (substitute the class name). 
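+
+Two-image calls (VIEScore-style)
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+A sketch of the ``generate_with_image_lists`` call described under "Usage examples"; the
+paths, prompt, and hosted route are illustrative, and ``image_lists[i]`` holds the source
+then the edited image for ``prompts[i]``:
+
+.. code-block:: python
+
+    from PIL import Image
+
+    from pruna.evaluation.metrics.vlm_base import get_vlm
+
+    judge = get_vlm(vlm_type="litellm", model_name="openai/gpt-4o")
+    source, edited = Image.open("source.png"), Image.open("edited.png")
+    responses = judge.generate_with_image_lists(
+        image_lists=[[source, edited]],
+        prompts=["Rate from 0 to 10 how well the edit follows: 'add a red hat'."],
+    )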
+""" + +from __future__ import annotations + +import base64 +import io +import math +import os +from abc import ABC, abstractmethod +from typing import Any, List, Literal, Optional, Type, TypeVar, Union + +import numpy as np +import torch +from PIL import Image +from pydantic import BaseModel + +from pruna.evaluation.metrics.metric_stateful import StatefulMetric +from pruna.evaluation.metrics.result import MetricResult +from pruna.evaluation.metrics.utils import get_call_type_for_single_metric +from pruna.logging.logger import pruna_logger + +T = TypeVar("T", bound=BaseModel) + +DEFAULT_LITELLM_MODEL: str = "openai/gpt-4o" + +VLM_METRIC_REGISTRY_NAMES: frozenset[str] = frozenset( + ( + "vqa", + "qa_accuracy", + "img_edit_score", + "text_score", + "ocr_levenshtein", + "ocr_text_score", + "oneig_text_score", + "oneig_alignment", + "vie_score", + ) +) + + +def get_vlm( + vlm: Optional[BaseVLM] = None, + vlm_type: Literal["litellm", "transformers"] = "litellm", + *, + model_name: Optional[str] = None, + device: Optional[str | torch.device] = None, + api_key: Optional[str] = None, + structured_output: bool = True, + **vlm_kwargs: Any, +) -> BaseVLM: + """ + Create or return a VLM instance. + + Parameters + ---------- + vlm : BaseVLM | None + If provided, returned as-is. Otherwise a VLM is created. + vlm_type : {"litellm", "transformers"} + Backend when creating a VLM. + model_name : str | None + Model name for litellm (e.g. ``openai/gpt-4o``) or HuggingFace ``from_pretrained`` id. + **Required** when ``vlm`` is not provided. Ignored when ``vlm`` is provided. + device : str | torch.device | None + Device for transformers VLM. + api_key : str | None + API key for litellm. + structured_output : bool + When True, litellm uses pydantic ``response_format`` from the metric; for + ``transformers``, enables outlines-based constrained decoding when a string + format is passed to ``generate``/``score``. + **vlm_kwargs : Any + Same dict as ``vlm_kwargs`` on VLM metrics: forwarded to the backend chosen by + ``vlm_type``. For ``"litellm"``, kwargs go to ``LitellmVLM`` (e.g. provider-specific + options). For ``"transformers"``, use ``model_load_kwargs`` for + ``AutoModelForImageTextToText.from_pretrained``; any other keys are passed to + ``TransformersVLM`` after ``model_load_kwargs`` is popped. + + Returns + ------- + BaseVLM + The VLM instance. + + Notes + ----- + When ``vlm_type`` is ``"litellm"`` and ``api_key`` is omitted, the key is taken from + ``LITELLM_API_KEY`` or ``OPENAI_API_KEY``. See the module docstring above. User manual: + :doc:`Evaluate a model ` (Vision-language judge metrics). + + Examples + -------- + Hosted (``litellm``) and local Hugging Face (``transformers``). API key for ``hosted`` from + ``OPENAI_API_KEY`` or ``LITELLM_API_KEY`` if ``api_key`` is omitted. + + .. code-block:: python + + import torch + + from pruna.evaluation.metrics.vlm_base import get_vlm + + hosted = get_vlm(vlm_type="litellm", model_name="openai/gpt-4o") + local = get_vlm( + vlm_type="transformers", + model_name="HuggingFaceTB/SmolVLM-256M-Instruct", + device="cpu", + model_load_kwargs={"torch_dtype": torch.float32}, + ) + + Another LiteLLM provider route (set the env var that route expects, e.g. + ``ANTHROPIC_API_KEY`` for ``anthropic/...``): + + .. 
code-block:: python + + from pruna.evaluation.metrics.vlm_base import get_vlm + + other_provider = get_vlm( + vlm_type="litellm", model_name="anthropic/claude-3-5-sonnet-20241022" + ) + """ + if vlm is not None: + return vlm + if not model_name: + raise ValueError( + "get_vlm requires model_name when vlm is not provided " + '(pass model_name explicitly, e.g. model_name="openai/gpt-4o").' + ) + if vlm_type == "litellm": + return LitellmVLM(model_name=model_name, api_key=api_key, **vlm_kwargs) + model_load_kwargs = vlm_kwargs.pop("model_load_kwargs", {}) + return TransformersVLM( + model_name=model_name, + device=device, + use_outlines=structured_output, + model_load_kwargs=model_load_kwargs, + **vlm_kwargs, + ) + + +class BaseVLM(ABC): + """Base class for Vision-Language Models.""" + + @abstractmethod + def generate( + self, + images: List[Image.Image], + prompts: List[str], + response_format: Optional[Union[Type[BaseModel], Literal["integer"], Literal["yes_no"], Literal["json"]]] = None, + **kwargs: Any, + ) -> List[str]: + """ + Generate responses for images and prompts. + + Parameters + ---------- + images : List[Image.Image] + List of PIL Images. + prompts : List[str] + List of text prompts. + response_format : Type[BaseModel] | str | None + Optional pydantic model (litellm) or format string: "integer", "yes_no", "json" (transformers/outlines). + **kwargs : Any + Additional arguments passed to the implementation. + + Returns + ------- + List[str] + Generated responses. + """ + pass + + @abstractmethod + def score( + self, + images: List[Image.Image], + questions: List[str], + answers: List[str], + use_probability: bool = False, + response_format: Optional[Union[Type[BaseModel], Literal["integer"], Literal["yes_no"], Literal["json"]]] = None, + **kwargs: Any, + ) -> List[float]: + """ + Score how well answers match images for given questions. + + Parameters + ---------- + images : List[Image.Image] + List of PIL Images. + questions : List[str] + List of questions. + answers : List[str] + List of expected answers. + use_probability : bool, optional + If True and supported, return P(expected answer) instead of binary 0/1. + response_format : Type[BaseModel] | str | None, optional + Structured output format. When set, uses generate() with this format and + extracts the answer field for comparison instead of raw string matching. + **kwargs : Any + Additional arguments passed to the implementation. + + Returns + ------- + List[float] + Scores for each image-question pair (0-1, or probability when use_probability). + """ + pass + + +class LitellmVLM(BaseVLM): + """ + VLM using litellm for API-based inference. + + Supports many providers (OpenAI, Anthropic, Azure, and others) through a single client. + + Parameters + ---------- + model_name : str + Model name (e.g. ``openai/gpt-4o`` for litellm). Passed from :func:`get_vlm`. + api_key : str | None, optional + API key for the provider. If omitted, uses ``LITELLM_API_KEY`` then ``OPENAI_API_KEY``. + **kwargs : Any + Additional arguments passed to litellm. + + Notes + ----- + LiteLLM is the default API backend so metric runs can use a hosted VLM judge without + downloading large local checkpoints. Provider-specific environment variables are described + in the LiteLLM documentation; OpenAI-compatible routes typically use ``OPENAI_API_KEY``. + User manual: :doc:`Evaluate a model ` (Vision-language + judge metrics). 
+ + Examples + -------- + OpenAI-compatible route (pass ``api_key`` explicitly, or rely on ``OPENAI_API_KEY`` / + ``LITELLM_API_KEY`` when omitted): + + >>> from pruna.evaluation.metrics.vlm_base import LitellmVLM + >>> hosted = LitellmVLM(model_name="openai/gpt-4o", api_key="sk-placeholder") + >>> hosted.api_key == "sk-placeholder" + True + + Same naming as :func:`get_vlm` examples: ``other_provider`` for non-OpenAI LiteLLM routes + (set ``ANTHROPIC_API_KEY``, etc.): + + .. code-block:: python + + from pruna.evaluation.metrics.vlm_base import LitellmVLM + + other_provider = LitellmVLM(model_name="anthropic/claude-3-5-sonnet-20241022") + """ + + def __init__( + self, + model_name: str, + api_key: Optional[str] = None, + **kwargs: Any, + ) -> None: + self.model_name = model_name + self.api_key = api_key or os.getenv("LITELLM_API_KEY") or os.getenv("OPENAI_API_KEY") + self.extra_kwargs = kwargs + try: + import litellm + + litellm.drop_params = True + self._litellm = litellm + except ImportError: + pruna_logger.error("litellm not installed. Install with: pip install litellm") + raise + + def _litellm_chat_completion( + self, + content: list[dict[str, Any]], + response_format: Optional[Union[Type[BaseModel], Literal["integer"], Literal["yes_no"], Literal["json"]]] = None, + **kwargs: Any, + ) -> str: + completion_kwargs: dict[str, Any] = { + "model": self.model_name, + "messages": [{"role": "user", "content": content}], + "api_key": self.api_key, + **self.extra_kwargs, + **kwargs, + } + if response_format is not None and isinstance(response_format, type): + completion_kwargs["response_format"] = response_format + response = self._litellm.completion(**completion_kwargs) + content_result = response.choices[0].message.content + use_pydantic = ( + response_format is not None + and isinstance(response_format, type) + and isinstance(content_result, response_format) + ) + if use_pydantic: + return content_result.model_dump_json() + return str(content_result) if content_result is not None else "" + + def generate( + self, + images: List[Image.Image], + prompts: List[str], + response_format: Optional[Union[Type[BaseModel], Literal["integer"], Literal["yes_no"], Literal["json"]]] = None, + **kwargs: Any, + ) -> List[str]: + """ + Generate responses for images and prompts. + + Parameters + ---------- + images : List[Image.Image] + List of PIL Images. + prompts : List[str] + List of text prompts. + response_format : Type[BaseModel] | str | None + Optional pydantic model for structured output (litellm uses BaseModel). + **kwargs : Any + Additional arguments passed to litellm completion. + + Returns + ------- + List[str] + Generated responses. + """ + results = [] + for image, prompt in zip(images, prompts): + try: + content = [ + {"type": "text", "text": prompt}, + {"type": "image_url", "image_url": {"url": self._image_to_data_url(image)}}, + ] + results.append(self._litellm_chat_completion(content, response_format, **kwargs)) + except Exception as e: + pruna_logger.error(f"Litellm generation failed: {e}") + results.append("") + return results + + def generate_with_image_lists( + self, + image_lists: List[List[Image.Image]], + prompts: List[str], + response_format: Optional[Union[Type[BaseModel], Literal["integer"], Literal["yes_no"], Literal["json"]]] = None, + **kwargs: Any, + ) -> List[str]: + """ + Generate one response per (``image_list``, ``prompt``) pair. + + Each ``image_list`` contains one or more PIL images (e.g. source and edited for + VIEScore ``tie``). 
Message content is built as text first, then each image as + ``image_url``, matching common OpenAI-style multi-image chat layouts. + + Parameters + ---------- + image_lists : list[list[PIL.Image.Image]] + One list of images per prompt (same length as ``prompts``). + prompts : list[str] + User text for each row. + response_format : optional + Same as :meth:`generate`. + **kwargs : Any + Forwarded to litellm ``completion``. + + Returns + ------- + list[str] + One string (or JSON string for pydantic) per row. + """ + if len(image_lists) != len(prompts): + raise ValueError("image_lists and prompts must have the same length.") + results: List[str] = [] + for imgs, prompt in zip(image_lists, prompts): + try: + content: list[dict[str, Any]] = [{"type": "text", "text": prompt}] + for im in imgs: + content.append({"type": "image_url", "image_url": {"url": self._image_to_data_url(im)}}) + results.append(self._litellm_chat_completion(content, response_format, **kwargs)) + except Exception as e: + pruna_logger.error(f"Litellm multi-image generation failed: {e}") + results.append("") + return results + + def score( + self, + images: List[Image.Image], + questions: List[str], + answers: List[str], + use_probability: bool = False, + response_format: Optional[Union[Type[BaseModel], Literal["integer"], Literal["yes_no"], Literal["json"]]] = None, + **kwargs: Any, + ) -> List[float]: + """ + Score how well answers match images for given questions. + + When use_probability=True, requests logprobs from the API and returns P(expected). + When response_format is set, uses structured generation and extracts the answer field. + Falls back to binary 0/1 if logprobs not available. + + Parameters + ---------- + images : List[Image.Image] + List of PIL Images. + questions : List[str] + List of questions. + answers : List[str] + List of expected answers. + use_probability : bool, optional + If True, return P(expected) from logprobs when available. Default is False. + response_format : Type[BaseModel] | str | None, optional + Structured output format for answer extraction. + **kwargs : Any + Additional arguments passed to litellm completion. + + Returns + ------- + List[float] + Scores for each image-question pair (0-1, or probability when use_probability). + """ + from pruna.evaluation.metrics.vlm_utils import get_answer_from_response + + scores = [] + for image, question, answer in zip(images, questions, answers): + prompt = f"{question} Please answer yes or no." + if use_probability: + score = self._score_with_logprobs(image, prompt, answer, **kwargs) + elif response_format is not None: + raw = self.generate([image], [prompt], response_format=response_format, **kwargs)[0] + response_answer = get_answer_from_response(raw) + score = 1.0 if answer.lower() in response_answer.lower() else 0.0 + else: + response = self.generate([image], [prompt], **kwargs)[0].lower() + score = 1.0 if answer.lower() in response else 0.0 + scores.append(score) + return scores + + def _score_with_logprobs(self, image: Image.Image, prompt: str, expected: str, **kwargs: Any) -> float: + """ + Get P(expected) from logprobs when available. + + Parameters + ---------- + image : Image.Image + PIL Image to score. + prompt : str + Question prompt. + expected : str + Expected answer (e.g., "Yes"). + **kwargs : Any + Additional arguments passed to litellm completion. + + Returns + ------- + float + Probability of expected answer (0-1), or binary 0/1 on fallback. 
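+
+        Notes
+        -----
+        When top logprobs are available, the score is the normalised yes/no mass of the
+        first generated token: ``p_yes / (p_yes + p_no)`` for ``expected="Yes"`` (for
+        example ``p_yes=0.6, p_no=0.2`` gives ``0.75``); otherwise it falls back to a
+        binary substring match on the generated text.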
+ """ + content = [ + {"type": "text", "text": prompt}, + {"type": "image_url", "image_url": {"url": self._image_to_data_url(image)}}, + ] + completion_kwargs = { + "model": self.model_name, + "messages": [{"role": "user", "content": content}], + "api_key": self.api_key, + "logprobs": True, + "top_logprobs": 5, + **self.extra_kwargs, + **kwargs, + } + try: + response = self._litellm.completion(**completion_kwargs) + choice = response.choices[0] + logprobs = getattr(choice, "logprobs", None) or getattr(choice.message, "logprobs", None) + if logprobs and hasattr(logprobs, "content"): + # Match token if it starts with the yes/no word and the remainder is non-alphabetic + # (e.g. "yes." or "no," match, but "yesterday" or "not" do not). + def _word_matches(token_str: str, word: str) -> bool: + return token_str.startswith(word) and ( + len(token_str) == len(word) or not token_str[len(word)].isalpha() + ) + + yes_words = ("yes", " yes") + no_words = ("no", " no") + p_yes = 0.0 + p_no = 0.0 + for tok in logprobs.content or []: + top = getattr(tok, "top_logprobs", None) or [] + for t in top: + token_str = (getattr(t, "token", "") or "").lower() + lp = float(getattr(t, "logprob", -1e9) or -1e9) + prob = math.exp(lp) + if any(_word_matches(token_str, w) for w in yes_words): + p_yes += prob + elif any(_word_matches(token_str, w) for w in no_words): + p_no += prob + break # Only process the first output token's top_logprobs + eps = 1e-12 + denom = p_yes + p_no + if denom > eps: + ans = expected.strip().lower() + if ans == "yes": + return float(min(1.0, p_yes / denom)) + if ans == "no": + return float(min(1.0, p_no / denom)) + content_str = (choice.message.content or "").lower() + if expected.lower() in content_str: + return 1.0 + return 0.0 + except Exception: + response = self.generate([image], [prompt], **kwargs)[0].lower() + return 1.0 if expected.lower() in response else 0.0 + + def _image_to_data_url(self, image: Image.Image) -> str: + buffer = io.BytesIO() + image.save(buffer, format="PNG") + buffer.seek(0) + b64 = base64.b64encode(buffer.read()).decode("utf-8") + return f"data:image/png;base64,{b64}" + + +class TransformersVLM(BaseVLM): + """ + VLM using HuggingFace Transformers for local inference. + + Supports models like BLIP, LLaVA, SmolVLM, etc. + + Parameters + ---------- + model_name : str, optional + HuggingFace model name. Default is "Salesforce/blip2-opt-2.7b". + device : str | torch.device | None, optional + Device for inference. Auto-detected if None. + use_outlines : bool, optional + Whether to use outlines for constrained decoding when the caller passes a string + ``response_format``. Usually set from ``structured_output`` via :func:`get_vlm`. + model_load_kwargs : dict, optional + Kwargs passed to from_pretrained (e.g. dtype, attn_implementation). + **kwargs : Any + Additional arguments passed to model.generate. + + Notes + ----- + Prefer :func:`get_vlm` from metrics so ``structured_output`` and ``vlm_kwargs`` match + registry metrics. User manual: :doc:`Evaluate a model ` + (Vision-language judge metrics). + + Examples + -------- + Local judge only (same ``local`` pattern as :func:`get_vlm`; prefer :func:`get_vlm` from + metrics so ``structured_output`` and ``vlm_kwargs`` stay aligned): + + .. 
code-block:: python + + import torch + + from pruna.evaluation.metrics.vlm_base import TransformersVLM + + local = TransformersVLM( + model_name="HuggingFaceTB/SmolVLM-256M-Instruct", + device="cpu", + model_load_kwargs={"torch_dtype": torch.float32}, + ) + """ + + def __init__( + self, + model_name: str = "Salesforce/blip2-opt-2.7b", + device: Optional[str | torch.device] = None, + use_outlines: bool = False, + model_load_kwargs: Optional[dict] = None, + **kwargs: Any, + ) -> None: + self.model_name = model_name + self.use_outlines = use_outlines + self.model_load_kwargs = model_load_kwargs or {} + if device is None: + if torch.cuda.is_available(): + self.device = torch.device("cuda") + else: + self.device = torch.device("cpu") + else: + self.device = torch.device(device) + self.extra_kwargs = kwargs + self._model = None + self._processor = None + self._yes_no_prefix_ids: Optional[tuple[list[int], list[int]]] = None + self._outlines_wrapped_model: Any = None + + def _load_model(self) -> None: + if self._model is not None: + return + try: + from transformers import AutoModelForImageTextToText, AutoProcessor + except ImportError: + pruna_logger.error("transformers not installed. Install with: pip install transformers") + raise + pruna_logger.info(f"Loading VLM model: {self.model_name}") + self._processor = AutoProcessor.from_pretrained(self.model_name) + self._model = AutoModelForImageTextToText.from_pretrained(self.model_name, **self.model_load_kwargs) + device = self.device + self._model.to(device) # type: ignore[invalid-argument-type] + self._model.eval() + + def _get_outlines_wrapped_model(self) -> Any: + """Lazily wrap HF model + processor for Outlines 1.x steerable generation.""" + if self._outlines_wrapped_model is None: + from outlines.models.transformers import from_transformers + + assert self._processor is not None, "_processor must be loaded before wrapping with outlines" + self._outlines_wrapped_model = from_transformers(self._model, self._processor) + return self._outlines_wrapped_model + + def _pil_for_outlines(self, image: Image.Image) -> Any: + """Wrap a PIL image for ``outlines.inputs.Image`` (requires a concrete ``format``).""" + from outlines.inputs import Image as OutlinesImage + + buf = io.BytesIO() + image.convert("RGB").save(buf, format="PNG") + buf.seek(0) + pil = Image.open(buf) + return OutlinesImage(pil) + + def _chat_user_with_images(self, images: List[Image.Image], prompt: str) -> Any: + """Build an ``outlines.inputs.Chat`` with one or more images then text (HF multimodal dicts).""" + from outlines.inputs import Chat + + parts: list[dict[str, Any]] = [] + for im in images: + parts.append({"type": "image", "image": self._pil_for_outlines(im)}) + parts.append({"type": "text", "text": prompt}) + return Chat([{"role": "user", "content": parts}]) + + def _outlines_output_term(self, response_format: Any) -> Any: + """ + Map metric ``response_format`` to an Outlines output type, or None for unconstrained decode. + + Returns + ------- + Any + A term accepted by :class:`outlines.generator.Generator`, or None. 
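+
+        Notes
+        -----
+        The mapping applied by this helper (Outlines 1.x terms)::
+
+            "integer"            -> regex(r"\d+")
+            "yes_no"             -> regex(r"(Yes|No)")
+            BaseModel subclass   -> json_schema(<model class>)
+            anything else        -> None (unconstrained decoding)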
+ """ + from outlines.types import json_schema, regex + + if isinstance(response_format, str): + if response_format == "integer": + return regex(r"\d+") + if response_format == "yes_no": + return regex(r"(Yes|No)") + return None + if isinstance(response_format, type): + try: + if issubclass(response_format, BaseModel): + return json_schema(response_format) + except TypeError: + return None + return None + + def _generate_steered(self, chats: List[Any], output_term: Any, max_new_tokens: int) -> List[str]: + """Run Outlines :class:`~outlines.generator.Generator` on prepared chat inputs.""" + from outlines import Generator + + om = self._get_outlines_wrapped_model() + results: List[str] = [] + with torch.compiler.set_stance("force_eager"): + gen = Generator(om, output_type=output_term) + for chat in chats: + try: + out = gen(chat, max_new_tokens=max_new_tokens) + results.append(out if isinstance(out, str) else str(out)) + except Exception as e: + pruna_logger.warning(f"Outlines generation failed: {e}, using empty string") + results.append("") + return results + + def generate( + self, + images: List[Image.Image], + prompts: List[str], + response_format: Optional[Union[Type[BaseModel], Literal["integer"], Literal["yes_no"], Literal["json"]]] = None, + **kwargs: Any, + ) -> List[str]: + """ + Generate responses using local VLM. + + Parameters + ---------- + images : List[Image.Image] + List of PIL Images. + prompts : List[str] + List of text prompts. + response_format : Type[BaseModel] | str | None + When ``use_outlines`` is True: string ``integer`` / ``yes_no``, or a Pydantic model + class for JSON-schema constrained decoding; otherwise unconstrained ``model.generate``. + **kwargs : Any + Additional arguments passed to model generate. + + Returns + ------- + List[str] + Generated responses. + """ + self._load_model() + max_new_tokens = kwargs.get("max_new_tokens", 128) + term = self._outlines_output_term(response_format) if self.use_outlines else None + if term is not None: + chats = [self._chat_user_with_images([image], prompt) for image, prompt in zip(images, prompts)] + return self._generate_steered(chats, term, max_new_tokens) + return self._generate_standard(images, prompts, max_new_tokens) + + def generate_with_image_lists( + self, + image_lists: List[List[Image.Image]], + prompts: List[str], + response_format: Optional[Union[Type[BaseModel], Literal["integer"], Literal["yes_no"], Literal["json"]]] = None, + **kwargs: Any, + ) -> List[str]: + """ + Generate with multiple PIL images per prompt (e.g. VIEScore source + edited). + + Uses the chat template path with several ``image`` parts then text. When + ``use_outlines`` is True and ``response_format`` maps to an Outlines output type + (string ``integer`` / ``yes_no`` or a Pydantic model class), uses the same + Outlines 1.x steerable path as :meth:`generate` via ``outlines.inputs.Chat``. + Otherwise uses unconstrained ``model.generate``. + + Parameters + ---------- + image_lists : list[list[PIL.Image.Image]] + One list of images per prompt. + prompts : list[str] + Prompts aligned with ``image_lists``. + response_format : optional + Same conventions as :meth:`generate` for structured decoding when outlines is enabled. + **kwargs : Any + Passed through (e.g. ``max_new_tokens``). + + Returns + ------- + list[str] + Decoded strings per row. 
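+
+        Examples
+        --------
+        Minimal sketch; the model name matches the test suite and the image
+        paths are placeholders::
+
+            from PIL import Image
+
+            vlm = TransformersVLM(
+                model_name="HuggingFaceTB/SmolVLM-256M-Instruct", device="cpu"
+            )
+            out = vlm.generate_with_image_lists(
+                [[Image.open("source.png"), Image.open("edited.png")]],
+                ["Describe how the second image differs from the first."],
+            )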
+ """ + if len(image_lists) != len(prompts): + raise ValueError("image_lists and prompts must have the same length.") + max_new_tokens = kwargs.get("max_new_tokens", 128) + self._load_model() + term = self._outlines_output_term(response_format) if self.use_outlines else None + if term is not None: + chats = [self._chat_user_with_images(imgs, prompt) for imgs, prompt in zip(image_lists, prompts)] + return self._generate_steered(chats, term, max_new_tokens) + results: List[str] = [] + with torch.inference_mode(): + for imgs, prompt in zip(image_lists, prompts): + inputs = self._prepare_inputs_multi(imgs, prompt) + input_len = inputs["input_ids"].shape[1] + output = self._model.generate(**inputs, max_new_tokens=max_new_tokens, **self.extra_kwargs) + response = self._decode_output(output[0][input_len:]) + results.append(response) + return results + + def _prepare_inputs(self, image: Image.Image, prompt: str) -> dict: + """Prepare model inputs, supporting both BLIP-style and chat-template processors.""" + try: + inputs = self._processor(images=[image], text=prompt, return_tensors="pt") + except (ValueError, TypeError): + conversation = [ + {"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": prompt}]} + ] + inputs = self._processor.apply_chat_template( + conversation, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt" + ) + return {k: v.to(self.device) for k, v in inputs.items()} + + def _prepare_inputs_multi(self, images: List[Image.Image], prompt: str) -> dict: + """Chat-template inputs with multiple images then text (VIEScore ``tie``-style).""" + parts: list[dict[str, Any]] = [] + for im in images: + parts.append({"type": "image", "image": im}) + parts.append({"type": "text", "text": prompt}) + conversation = [{"role": "user", "content": parts}] + inputs = self._processor.apply_chat_template( + conversation, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt" + ) + return {k: v.to(self.device) for k, v in inputs.items()} + + def _decode_output(self, output_ids: torch.Tensor) -> str: + """Decode model output to text.""" + if hasattr(self._processor, "batch_decode"): + return self._processor.batch_decode([output_ids], skip_special_tokens=True)[0] + return self._processor.decode(output_ids, skip_special_tokens=True) + + def _generate_standard( + self, + images: List[Image.Image], + prompts: List[str], + max_new_tokens: int, + ) -> List[str]: + """Standard generation without outlines.""" + results = [] + with torch.inference_mode(): + for image, prompt in zip(images, prompts): + inputs = self._prepare_inputs(image, prompt) + input_len = inputs["input_ids"].shape[1] + output = self._model.generate(**inputs, max_new_tokens=max_new_tokens, **self.extra_kwargs) + # Decode only the newly generated tokens to avoid re-including the prompt text. + response = self._decode_output(output[0][input_len:]) + results.append(response) + return results + + def _get_tokenizer(self) -> Any: + """Return the HF tokenizer used for yes/no prefix ids and decoding.""" + self._load_model() + proc = self._processor + tok = getattr(proc, "tokenizer", None) or getattr(proc, "text_tokenizer", None) + if tok is None: + raise ValueError( + "Transformers VLM probability scoring requires a tokenizer on the processor; " + "pass use_probability=False for binary scoring." 
+ ) + return tok + + def _score_yes_no_probability(self, image: Image.Image, question: str, answer: str) -> float: + """Soft VQAScore-style score from next-token softmax over yes/no prefix token ids.""" + from pruna.evaluation.metrics.vlm_utils import yes_no_first_token_id_groups + + self._load_model() + prompt = f"{question} Please answer yes or no." + inputs = self._prepare_inputs(image, prompt) + if self._yes_no_prefix_ids is None: + self._yes_no_prefix_ids = yes_no_first_token_id_groups(self._get_tokenizer()) + yes_ids, no_ids = self._yes_no_prefix_ids + if not yes_ids or not no_ids: + pruna_logger.warning("Empty yes/no prefix token ids; install a tokenizer with standard Yes/No encodings.") + return 0.0 + with torch.inference_mode(): + out = self._model(**inputs) + if not hasattr(out, "logits") or out.logits is None: + raise RuntimeError("Model forward did not return logits; cannot compute P(Yes).") + logits = out.logits[0, -1, :].float() + probs = torch.softmax(logits, dim=-1) + device = probs.device + p_yes = probs[torch.tensor(yes_ids, device=device, dtype=torch.long)].sum() + p_no = probs[torch.tensor(no_ids, device=device, dtype=torch.long)].sum() + denom = p_yes + p_no + ans = answer.strip().lower() + eps = 1e-12 + if float(denom.item()) < eps: + if ans == "yes": + return float(p_yes.clamp(0.0, 1.0).item()) + if ans == "no": + return float(p_no.clamp(0.0, 1.0).item()) + return 0.0 + if ans == "yes": + return float((p_yes / (denom + eps)).item()) + if ans == "no": + return float((p_no / (denom + eps)).item()) + return float((p_yes / (denom + eps)).item()) + + def score( + self, + images: List[Image.Image], + questions: List[str], + answers: List[str], + use_probability: bool = False, + response_format: Optional[Union[Type[BaseModel], Literal["integer"], Literal["yes_no"], Literal["json"]]] = None, + **kwargs: Any, + ) -> List[float]: + """ + Score how well answers match images for given questions. + + When ``use_probability`` is True, computes a VQAScore-style score from the next-token + distribution at the last context position (softmax mass on yes/no prefix token ids, + normalized over their union). Otherwise uses generation and binary substring matching. + + Parameters + ---------- + images : List[Image.Image] + List of PIL Images. + questions : List[str] + List of questions. + answers : List[str] + List of expected answers. + use_probability : bool, optional + If True, return a soft score from logits (no ``generate`` call). If False, binary. + response_format : Type[BaseModel] | str | None, optional + Structured output format for answer extraction (only when ``use_probability`` is False). + **kwargs : Any + Additional arguments passed to ``generate`` when ``use_probability`` is False. + + Returns + ------- + List[float] + Scores for each image-question pair in ``[0, 1]``. + """ + from pruna.evaluation.metrics.vlm_utils import get_answer_from_response + + scores = [] + for image, question, answer in zip(images, questions, answers): + if use_probability: + scores.append(self._score_yes_no_probability(image, question, answer)) + continue + prompt = f"{question} Please answer yes or no." 
+ responses = self.generate([image], [prompt], response_format=response_format, **kwargs) + raw = responses[0] if responses else "" + response_answer = get_answer_from_response(raw) if response_format is not None else raw.lower() + score = 1.0 if answer.lower() in response_answer.lower() else 0.0 + scores.append(score) + return scores + + +def auxiliary_dicts_from_gt(gt: Any, batch_size: int) -> list[dict[str, Any]]: + """ + Map batch ``gt`` to per-row auxiliary dicts when using ``prompt_with_auxiliaries_collate``. + + For ``y_x`` metrics, :func:`~pruna.evaluation.metrics.utils.metric_data_processor` does not + include ``gt`` in its output; pass the batch ``gt`` argument here so fields such as + ``source_image_bytes`` are visible to editing metrics. + + Parameters + ---------- + gt : Any + Second element of the dataloader batch: typically a ``list[dict]`` of aux columns. + batch_size : int + Number of samples in the batch. + + Returns + ------- + list[dict[str, Any]] + One dict per row; empty dicts when ``gt`` is not a list of dicts (e.g. tensor placeholders + in tests). + """ + if batch_size <= 0: + return [] + if isinstance(gt, (list, tuple)) and gt and isinstance(gt[0], dict): + out: list[dict[str, Any]] = [] + for i in range(batch_size): + row = gt[i] if i < len(gt) else {} + out.append(row if isinstance(row, dict) else {}) + return out + return [{} for _ in range(batch_size)] + + +def prompts_from_y_x_inputs(inputs: Any, batch_len: int) -> list[str]: + """ + Extract per-row prompts from :func:`~pruna.evaluation.metrics.utils.metric_data_processor` output. + + Parameters + ---------- + inputs : Any + Return value of ``metric_data_processor`` for ``y_x`` call types. + batch_len : int + Number of samples in the batch. + + Returns + ------- + list[str] + Prompt list from ``inputs[1]`` when present; otherwise ``batch_len`` empty strings. + """ + if len(inputs) > 1 and isinstance(inputs[1], list): + return inputs[1] + return [""] * batch_len + + +class StatefulVLMMeanScoresMetric(StatefulMetric): + """ + Base for VLM metrics that accumulate ``scores`` and report the batch mean in :meth:`compute`. + + Subclasses set ``default_call_type`` and ``metric_name``, then call :meth:`_init_vlm_scores` + from ``__init__`` after any metric-specific attributes (e.g. ``use_probability``). + + Parameters + ---------- + device : str | torch.device | None + Device forwarded to :class:`~pruna.evaluation.metrics.metric_stateful.StatefulMetric`. + **kwargs : Any + Additional keyword arguments forwarded to the parent class. + + Notes + ----- + User guide: :doc:`Evaluate a model ` (Vision-language judge metrics). + Registry metrics (``VQAMetric``, ``VieScoreMetric``, …) pass ``vlm_type`` and ``model_name`` + into :meth:`_init_vlm_scores`; see :func:`get_vlm`. + + For **auxiliary image bytes** (editing benchmarks), use + :func:`~pruna.evaluation.metrics.vlm_utils.pil_rgb_from_aux_image_bytes` and + :data:`~pruna.evaluation.metrics.vlm_utils.VLM_AUX_IMAGE_BYTES_KEY_ORDER`. 
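+
+    Examples
+    --------
+    Minimal subclass sketch (names are illustrative, not a shipped metric; it
+    assumes the parent accepts ``device`` and that ``update`` appends one float
+    per sample to ``self.scores``)::
+
+        class MyJudgeMetric(StatefulVLMMeanScoresMetric):
+            default_call_type = "y_x"
+            metric_name = "my_judge"
+
+            def __init__(self, device=None, **kwargs):
+                super().__init__(device=device, **kwargs)
+                self._init_vlm_scores(
+                    vlm=None,
+                    vlm_type="litellm",
+                    model_name="openai/gpt-4o",
+                    vlm_kwargs=None,
+                    structured_output=True,
+                    device=device,
+                    api_key=None,
+                    call_type="y_x",
+                )
+
+            # update(...) would append per-sample floats to self.scores.
+
+            def compute(self):
+                return self.compute_mean_of_scores()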
+ """ + + scores: list[float] + default_call_type: str = "y_x" + higher_is_better: bool = True + metric_name: str = "" + + def _init_vlm_scores( + self, + *, + vlm: Optional[BaseVLM], + vlm_type: Literal["litellm", "transformers"], + model_name: Optional[str], + vlm_kwargs: Optional[dict[str, Any]], + structured_output: bool, + device: Optional[str | torch.device], + api_key: Optional[str], + call_type: str, + ) -> None: + """Attach ``self.vlm``, ``self.call_type``, and the ``scores`` state.""" + self.vlm = get_vlm( + vlm=vlm, + vlm_type=vlm_type, + model_name=model_name, + device=device, + api_key=api_key, + structured_output=structured_output, + **(vlm_kwargs or {}), + ) + self.call_type = get_call_type_for_single_metric(call_type, self.default_call_type) + self.add_state("scores", []) + self.higher_is_better = type(self).higher_is_better + + def compute_mean_of_scores(self) -> MetricResult: + """ + Return the mean of accumulated ``scores``, or ``0.0`` when empty. + + Returns + ------- + MetricResult + Aggregated result for this metric. + """ + if not self.scores: + return MetricResult(self.metric_name, self.__dict__, 0.0) + return MetricResult(self.metric_name, self.__dict__, float(np.mean(self.scores))) diff --git a/src/pruna/evaluation/metrics/vlm_utils.py b/src/pruna/evaluation/metrics/vlm_utils.py new file mode 100644 index 00000000..7e4f53e9 --- /dev/null +++ b/src/pruna/evaluation/metrics/vlm_utils.py @@ -0,0 +1,394 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared utilities and Pydantic models for VLM metrics.""" + +from __future__ import annotations + +import json +import math +import re +from io import BytesIO +from typing import Any, List, Sequence + +import torch +from PIL import Image +from pydantic import BaseModel, Field + +VLM_AUX_IMAGE_BYTES_KEY_ORDER: tuple[str, ...] = ( + "source_image_bytes", + "input_image_bytes", + "reference_image_bytes", + "image_bytes", +) + + +def _tensor_to_pil(tensor: torch.Tensor) -> Image.Image: + if tensor.ndim == 4: + tensor = tensor[0] + if tensor.max() > 1: + tensor = tensor / 255.0 + np_img = (tensor.cpu().numpy() * 255).astype("uint8") + return Image.fromarray(np_img.transpose(1, 2, 0)) + + +def _process_images(images: torch.Tensor) -> List[Any]: + return [_tensor_to_pil(img) if isinstance(img, torch.Tensor) else img for img in images] + + +def pil_rgb_from_aux_image_bytes( + aux: dict[str, Any], + *, + min_bytes_in_value_scan: int = 0, +) -> Image.Image | None: + """ + Decode a source / reference RGB image from auxiliary dict bytes. + + Tries :data:`VLM_AUX_IMAGE_BYTES_KEY_ORDER` first, then scans ``aux.values()`` for raw + byte blobs. The value scan skips blobs shorter than ``min_bytes_in_value_scan`` (use + ``100`` to match editing metrics that avoid tiny false positives). + + Parameters + ---------- + aux : dict[str, Any] + Per-sample auxiliary dict (e.g. from ``prompt_with_auxiliaries_collate``). 
+ min_bytes_in_value_scan : int, optional + Minimum length for blobs discovered only in the generic ``aux.values()`` pass. + Named keys use any non-empty ``bytes`` / ``bytearray``. Use ``0`` when any + non-trivial blob in ``aux.values()`` should be tried (e.g. tests building preds from aux). + + Returns + ------- + PIL.Image.Image | None + RGB image if decoding succeeds; ``None`` if nothing decodable was found. + """ + for key in VLM_AUX_IMAGE_BYTES_KEY_ORDER: + raw = aux.get(key) + if isinstance(raw, (bytes, bytearray)) and raw: + try: + return Image.open(BytesIO(raw)).convert("RGB") + except Exception: + continue + for v in aux.values(): + if isinstance(v, (bytes, bytearray)) and v and len(v) >= min_bytes_in_value_scan: + try: + return Image.open(BytesIO(v)).convert("RGB") + except Exception: + continue + return None + + +def yes_no_first_token_id_groups(tokenizer: Any) -> tuple[list[int], list[int]]: + """ + Collect first subword token ids that start a yes/no answer for next-token softmax scoring. + + Used by :class:`~pruna.evaluation.metrics.vlm_base.TransformersVLM` for VQAScore-style + P(Yes): sum softmax mass on these ids, normalized against yes+no for a stable [0, 1] score. + + Parameters + ---------- + tokenizer : Any + Hugging Face ``PreTrainedTokenizer`` (or compatible ``encode``). + + Returns + ------- + list[int] + Distinct token ids for yes-leaning first tokens (overlap with no-ids removed). + list[int] + Distinct token ids for no-leaning first tokens (overlap with yes-ids removed). + """ + yes_prefixes = ( + "Yes", + " Yes", + " yes", + "yes", + "\nYes", + "\n Yes", + "Yes,", + " Yes,", + ) + no_prefixes = ( + "No", + " No", + " no", + "no", + "\nNo", + "\n No", + "No,", + " No,", + ) + yes_ids: set[int] = set() + no_ids: set[int] = set() + for s in yes_prefixes: + ids = tokenizer.encode(s, add_special_tokens=False) + if ids: + yes_ids.add(ids[0]) + for s in no_prefixes: + ids = tokenizer.encode(s, add_special_tokens=False) + if ids: + no_ids.add(ids[0]) + overlap = yes_ids & no_ids + yes_only = sorted(yes_ids - overlap) + no_only = sorted(no_ids - overlap) + return yes_only, no_only + + +class VQAnswer(BaseModel): + """ + Structured output for VQA questions (Yes/No or open-ended). + + Parameters + ---------- + answer : str + Answer to the question. Typically "Yes" or "No" for alignment metrics, + but can be any string for open-ended questions. + """ + + answer: str = Field(description="Answer to the question") + + +class FloatOutput(BaseModel): + """ + Structured output for numeric scoring (img_edit_score, VieScoreMetric). + + Parameters + ---------- + score : float + Score from 0 to 10. + """ + + score: float = Field(ge=0, le=10, description="Score from 0 to 10") + + +class VIEScoreJsonOutput(BaseModel): + """ + Structured output matching VIEScore JSON (text-to-image / editing evaluation). + + Parameters + ---------- + score : list[float] + One or more sub-scores on a 0--10 scale (e.g. two criteria for editing). + reasoning : str + Short evaluator reasoning. 
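+
+    Examples
+    --------
+    JSON shape this model parses (values illustrative; assumes pydantic v2,
+    consistent with ``model_dump_json`` usage elsewhere in these metrics)::
+
+        VIEScoreJsonOutput.model_validate_json('{"score": [8.0, 8.0], "reasoning": "ok"}')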
+ """ + + score: list[float] = Field(description="Sub-scores on 0-10 scale") + reasoning: str = Field(default="", description="Brief reasoning") + + +def _json_dict_from_response_fragment(text: str) -> dict | None: + """Parse a leading JSON object from a string response, or return None.""" + stripped = (text or "").strip() + if not stripped.startswith("{"): + return None + try: + data = json.loads(stripped) + except (json.JSONDecodeError, TypeError): + return None + return data if isinstance(data, dict) else None + + +class TextOutput(BaseModel): + """ + Structured output for text extraction (text_score). + + Parameters + ---------- + text : str + Extracted text from the image, or 'No text recognized' if empty. + """ + + text: str = Field(description="Extracted text from the image, or 'No text recognized' if empty") + + +def get_answer_from_response(response: str | BaseModel | dict) -> str: + """ + Extract answer string from a VLM score() response (VQAnswer, dict, or raw string). + + Parameters + ---------- + response : str | BaseModel | dict + Raw response from vlm.generate() or vlm.score(). + + Returns + ------- + str + Extracted answer string, or empty string. + """ + if response is None: + return "" + if isinstance(response, VQAnswer): + return response.answer + if isinstance(response, dict): + return response.get("answer", "") + raw = str(response).strip() + parsed = _json_dict_from_response_fragment(raw) + if parsed is not None: + return str(parsed.get("answer", raw)) + return raw + + +def get_text_from_response(response: str | BaseModel | dict) -> str: + """ + Extract text from a VLM generate() response (str, pydantic, or dict). + + Parameters + ---------- + response : str | BaseModel | dict + Raw response from vlm.generate(). + + Returns + ------- + str + Extracted text, or empty string. + """ + if response is None: + return "" + if isinstance(response, TextOutput): + text = response.text + elif isinstance(response, dict): + text = response.get("text", "") + else: + text = (response or "").strip() + parsed = _json_dict_from_response_fragment(text) + if parsed is not None: + text = str(parsed.get("text", text)) + for phrase in ("No text recognized", "no text recognized", "No text"): + text = text.replace(phrase, "").strip() + return (text or "").strip() + + +def get_score_from_response(response: str | BaseModel | dict) -> float: + """ + Extract numeric score (0-10) from a VLM generate() response. + + Handles: + + * ``FloatOutput`` instances (local / parsed Pydantic). + * ``dict`` with a ``"score"`` key. + * JSON **strings** (e.g. LitellmVLM returns ``model_dump_json()`` for structured output). + * Plain text with a number (first decimal or integer matched). + + Parameters + ---------- + response : str | BaseModel | dict + Raw response from vlm.generate(). + + Returns + ------- + float + Score in [0, 1] (normalized from 0-10). Always non-negative. 
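+
+    Examples
+    --------
+    Mirrors the unit-test expectations for this helper::
+
+        get_score_from_response(FloatOutput(score=8.0))    # 0.8
+        get_score_from_response({"score": 5.0})            # 0.5
+        get_score_from_response('{"score": 7.5}')          # 0.75
+        get_score_from_response("Score: 7.5 out of 10")    # 0.75
+        get_score_from_response("")                        # 0.0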
+ """ + if response is None: + return 0.0 + if isinstance(response, FloatOutput): + return max(0.0, min(float(response.score), 10.0)) / 10.0 + if isinstance(response, dict): + return max(0.0, min(float(response.get("score", 0)), 10.0)) / 10.0 + text = str(response or "").strip() + parsed = _json_dict_from_response_fragment(text) + if parsed is not None and "score" in parsed: + try: + return max(0.0, min(float(parsed["score"]), 10.0)) / 10.0 + except (TypeError, ValueError): + pass + match = re.search(r"\d+(?:\.\d+)?", text) + if match: + return min(float(match.group(0)), 10.0) / 10.0 + return 0.0 + + +def viescore_min_scores_0_10(response: str | BaseModel | dict) -> list[float]: + """ + Parse VIEScore-style JSON with a ``score`` list of values in ``[0, 10]``. + + Parameters + ---------- + response : str | BaseModel | dict + Model output (pydantic ``VIEScoreJsonOutput``, dict, or JSON string). + + Returns + ------- + list[float] + Sub-scores; empty if parsing fails. + """ + if response is None: + return [] + if isinstance(response, VIEScoreJsonOutput): + return [float(x) for x in response.score] + if isinstance(response, dict): + raw = response.get("score", []) + if isinstance(raw, (list, tuple)): + return [float(x) for x in raw] + return [] + text = str(response or "").strip() + parsed = _json_dict_from_response_fragment(text) + if parsed is not None and "score" in parsed: + try: + raw = parsed["score"] + if isinstance(raw, (list, tuple)): + return [float(x) for x in raw] + return [float(raw)] + except (TypeError, ValueError): + return [] + return [] + + +def pad_viescore_subscores_to_two(values: list[float]) -> list[float]: + """ + Pad or truncate VIEScore sub-score lists to length two for :func:`viescore_tie_overall_unit`. + + Parameters + ---------- + values : list[float] + Parsed sub-scores from :func:`viescore_min_scores_0_10`. + + Returns + ------- + list[float] + Exactly two values in ``[0, 10]``, padding with ``0.0`` when fewer than two are present. + """ + if len(values) >= 2: + return values[:2] + if not values: + return [0.0, 0.0] + return values + [0.0] * (2 - len(values)) + + +def viescore_tie_overall_unit(sc_scores: Sequence[float], pq_scores: Sequence[float]) -> float: + """ + Overall VIEScore for text-image editing (``tie`` task): ``sqrt(min(SC)*min(PQ))/10`` in ``[0, 1]``. + + Matches the reference ``math.sqrt(SC_score * PQ_score)`` on a 0--10 scale with + ``SC_score = min(...)``, ``PQ_score = min(...)`` (`VIEScore`_). + + .. _VIEScore: https://github.com/TIGER-AI-Lab/VIEScore + + Parameters + ---------- + sc_scores : Sequence[float] + Semantic / instruction sub-scores on 0--10 (e.g. editing success and over-editing). + pq_scores : Sequence[float] + Perceptual sub-scores on 0--10 (e.g. naturalness and artifacts). + + Returns + ------- + float + Overall score in ``[0, 1]`` (higher is better). 
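+
+    Examples
+    --------
+    With two sub-scores per group (same arithmetic as the unit tests)::
+
+        viescore_tie_overall_unit([8.0, 8.0], [8.0, 8.0])  # sqrt(8 * 8) / 10 = 0.8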
+ """ + if not sc_scores or not pq_scores: + return 0.0 + sc = min(float(x) for x in sc_scores) + pq = min(float(x) for x in pq_scores) + return math.sqrt(sc * pq) / 10.0 diff --git a/tests/evaluation/test_vlm_base_infrastructure.py b/tests/evaluation/test_vlm_base_infrastructure.py new file mode 100644 index 00000000..a4eaa139 --- /dev/null +++ b/tests/evaluation/test_vlm_base_infrastructure.py @@ -0,0 +1,684 @@ +"""Tests for VLM metrics (VQA, ImageEditScore, QAAccuracy, TextScore, VieScore) and vlm_utils helpers.""" + +from unittest.mock import MagicMock, patch + +import pytest +import torch + +from pruna.evaluation.metrics.metric_img_edit_score import ImageEditScoreMetric +from pruna.evaluation.metrics.metric_oneig_alignment import OneIGAlignmentMetric +from pruna.evaluation.metrics.metric_qa_accuracy import QAAccuracyMetric +from pruna.evaluation.metrics.metric_text_score import OneIGTextScoreMetric, TextScoreMetric +from pruna.evaluation.metrics.metric_vie_score import VieScoreMetric +from pruna.evaluation.metrics.metric_vqa import VQAMetric +from pruna.evaluation.metrics.result import MetricResult +from pruna.evaluation.metrics.vlm_base import BaseVLM, get_vlm +from pruna.evaluation.metrics.vlm_utils import ( + FloatOutput, + VLM_AUX_IMAGE_BYTES_KEY_ORDER, + get_score_from_response, + yes_no_first_token_id_groups, +) + +from ._vlm_batch_snapshot_helpers import ( + BenchmarkVlmBatchOutcome, + pred_tensor_from_auxiliaries, + safe_json_for_snapshot, + vlm_benchmark_batch_to_json_record, +) + +SMOL_VLM = "HuggingFaceTB/SmolVLM-256M-Instruct" + +_ALL_VLM = ( + VQAMetric, + ImageEditScoreMetric, + QAAccuracyMetric, + OneIGAlignmentMetric, + TextScoreMetric, + OneIGTextScoreMetric, + VieScoreMetric, +) + +_SLOW_SMOL_SUBSET = ( + VQAMetric, + OneIGAlignmentMetric, + ImageEditScoreMetric, + VieScoreMetric, +) + + +@pytest.mark.parametrize( + ("raw", "expected"), + [ + (FloatOutput(score=8.0), 0.8), + ({"score": 5.0}, 0.5), + ('{"score": 7.5}', 0.75), + ('{"score": 10}', 1.0), + ("8", 0.8), + ("Score: 7.5 out of 10", 0.75), + ("", 0.0), + ], +) +def test_get_score_from_response(raw: object, expected: float) -> None: + """``get_score_from_response`` maps pydantic, dict, JSON, and text to ``[0, 1]``.""" + assert get_score_from_response(raw) == pytest.approx(expected) + + +def _dummy_image(batch: int = 1, size: int = 224) -> torch.Tensor: + return torch.rand(batch, 3, size, size) + + +def _update_metric(metric: object, prompts: list, images: torch.Tensor) -> None: + if isinstance(metric, OneIGAlignmentMetric): + metric.update( + prompts, + [ + { + "questions": {"1": "Is there a cat?", "2": "Is it sleeping?"}, + "dependencies": {"1": [0], "2": [1]}, + } + ], + images, + ) + elif isinstance(metric, QAAccuracyMetric): + metric.update( + prompts, + [{"questions": {"1": "Is there a cat?"}}], + images, + ) + elif isinstance(metric, (TextScoreMetric, OneIGTextScoreMetric)): + metric.update(prompts, ["cat"], images) + else: + metric.update(prompts, images, images) + + +@pytest.mark.cpu +@pytest.mark.slow +@pytest.mark.parametrize("metric_cls", _SLOW_SMOL_SUBSET) +def test_vlm_metrics_transformers_smolvlm(metric_cls: type) -> None: + """Smoke-test a subset with local SmolVLM (full matrix covered by litellm mock).""" + metric = metric_cls( + vlm_type="transformers", + model_name=SMOL_VLM, + device="cpu", + structured_output=True, + ) + images = _dummy_image(batch=1) + prompts = ["a cat"] + _update_metric(metric, prompts, images) + result = metric.compute() + assert result.name == metric.metric_name + 
assert isinstance(result.result, float) + if metric.higher_is_better: + assert 0.0 <= result.result <= 1.0 + else: + assert result.result >= 0.0 + + +@pytest.mark.cpu +@pytest.mark.parametrize("metric_cls", _ALL_VLM) +def test_vlm_metrics_litellm_mocked(metric_cls: type) -> None: + """Each VLM metric runs end-to-end with mocked litellm.""" + pytest.importorskip("litellm") + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + if metric_cls in (VQAMetric, QAAccuracyMetric, OneIGAlignmentMetric): + mock_response.choices[0].message.content = '{"answer": "Yes"}' + else: + mock_response.choices[0].message.content = '{"score": 8}' + + with patch("litellm.completion") as mock_completion: + mock_completion.return_value = mock_response + + metric = metric_cls( + vlm_type="litellm", + model_name="gpt-4o", + device="cpu", + structured_output=True, + ) + images = _dummy_image(batch=1) + prompts = ["a cat"] + _update_metric(metric, prompts, images) + result = metric.compute() + + assert result.name == metric.metric_name + assert isinstance(result.result, float) + assert mock_completion.called + + +@pytest.mark.cpu +def test_vlm_metrics_empty_compute_returns_zero() -> None: + """No updates → compute is 0.0 (same for all stateful VLM metrics).""" + metric = VQAMetric( + vlm_type="transformers", + model_name=SMOL_VLM, + device="cpu", + structured_output=True, + ) + assert metric.compute().result == 0.0 + + +@pytest.mark.cpu +def test_vlm_metrics_custom_vlm() -> None: + """Custom VLM passed to VQAMetric is used instead of the default litellm backend.""" + mock_vlm = MagicMock(spec=BaseVLM) + mock_vlm.generate.return_value = ["Yes"] + mock_vlm.score.return_value = [1.0] + + metric = VQAMetric(vlm=mock_vlm, vlm_type="litellm", device="cpu", structured_output=True) + images = _dummy_image(batch=1) + prompts = ["a cat"] + metric.update(prompts, images, images) + assert metric.compute().result == 1.0 + mock_vlm.score.assert_called() + + +@pytest.mark.cpu +def test_get_vlm_returns_custom() -> None: + """get_vlm returns the provided VLM instance unchanged.""" + custom = MagicMock(spec=BaseVLM) + out = get_vlm(vlm=custom, vlm_type="litellm", model_name="gpt-4o") + assert out is custom + + +@pytest.mark.cpu +def test_yes_no_first_token_id_groups_disjoint() -> None: + """Prefix token ids for Yes vs No should not overlap (avoids double-counting).""" + pytest.importorskip("transformers") + from transformers import AutoTokenizer + + tok = AutoTokenizer.from_pretrained("gpt2") + yes_ids, no_ids = yes_no_first_token_id_groups(tok) + assert yes_ids and no_ids + assert not (set(yes_ids) & set(no_ids)) + + +@pytest.mark.cpu +def test_get_vlm_requires_model_name_without_vlm() -> None: + """get_vlm raises ValueError when no model_name is given and no vlm is provided.""" + with pytest.raises(ValueError, match="model_name"): + get_vlm(vlm=None, vlm_type="litellm") + + +@pytest.mark.cpu +@pytest.mark.parametrize( + "metric_cls, expected_name, expected_result", + [ + (TextScoreMetric, "text_score", 1.0), + (OneIGTextScoreMetric, "oneig_text_score", 1.0), + ], +) +def test_text_metrics_list_str_gt(metric_cls: type, expected_name: str, expected_result: float) -> None: + """Text metrics accept plain string ground-truth and return the expected score.""" + mock_vlm = MagicMock(spec=BaseVLM) + mock_vlm.generate.return_value = ["hello world"] + + metric = metric_cls(vlm=mock_vlm, vlm_type="litellm", device="cpu") + images = _dummy_image(batch=1) + metric.update(["a prompt"], ["hello world"], images) + result = 
metric.compute() + + assert result.result == expected_result + assert result.name == expected_name + mock_vlm.generate.assert_called_once() + + +@pytest.mark.cpu +def test_text_score_result_in_zero_one_range() -> None: + """TextScoreMetric must return a normalized score in [0, 1], not raw edit distance.""" + mock_vlm = MagicMock(spec=BaseVLM) + # VLM OCR returns something very different from ground truth (high edit distance) + mock_vlm.generate.return_value = ["completely wrong text abcdefghijklmnop"] + + metric = TextScoreMetric(vlm=mock_vlm, device="cpu") + images = _dummy_image(batch=1) + metric.update(["prompt"], ["hello"], images) + result = metric.compute() + + assert 0.0 <= result.result <= 1.0, f"TextScoreMetric must return [0,1], got {result.result}" + assert result.result < 0.5, f"Very different strings should score below 0.5, got {result.result}" + + +@pytest.mark.cpu +def test_text_score_perfect_match_is_one() -> None: + """TextScoreMetric: identical OCR and GT -> score 1.0.""" + mock_vlm = MagicMock(spec=BaseVLM) + mock_vlm.generate.return_value = ["hello world"] + + metric = TextScoreMetric(vlm=mock_vlm, device="cpu") + images = _dummy_image(batch=1) + metric.update(["prompt"], ["hello world"], images) + result = metric.compute() + + assert result.result == 1.0, f"Perfect match should give 1.0, got {result.result}" + assert result.higher_is_better is True + + +@pytest.mark.cpu +def test_text_score_registry_aliases() -> None: + """Registry aliases ocr_levenshtein and ocr_text_score resolve to the correct metric classes.""" + from pruna.evaluation.metrics.registry import MetricRegistry + + lev = MetricRegistry.get_metric("ocr_levenshtein", device="cpu", model_name="openai/gpt-4o") + comp = MetricRegistry.get_metric("ocr_text_score", device="cpu", model_name="openai/gpt-4o") + assert type(lev).__name__ == "TextScoreMetric" + assert type(comp).__name__ == "OneIGTextScoreMetric" + assert lev.metric_name == "text_score" + assert comp.metric_name == "oneig_text_score" + + +@pytest.mark.cpu +def test_oneig_text_score_utils_golden_composite() -> None: + """oneig_mean_text_score returns expected component values for a known input.""" + from pruna.evaluation.metrics.metric_text_score_utils import oneig_mean_text_score + + ed, cr, wac, composite = oneig_mean_text_score( + edit_distances=[10.0], + completion_ratios=[0.0], + match_counts=[2], + gt_totals=[4], + language_mode="EN", + ) + assert ed == 10.0 + assert cr == 0.0 + assert wac == 0.5 + assert composite == pytest.approx(0.95) + + _, _, _, zh = oneig_mean_text_score( + edit_distances=[30.0], + completion_ratios=[0.0], + match_counts=[0], + gt_totals=[1], + language_mode="ZH", + ) + assert zh == pytest.approx(0.4) + + +@pytest.mark.cpu +def test_qa_accuracy_all_or_nothing_partial_fail() -> None: + """all_or_nothing: if any question scores 0, the image score is 0.0 (not a partial mean).""" + mock_vlm = MagicMock(spec=BaseVLM) + # First question Yes (1.0), second question No (0.0) → mean=0.5, all_or_nothing=0.0 + mock_vlm.score.return_value = [1.0, 0.0] + + metric = QAAccuracyMetric(vlm=mock_vlm, device="cpu", aggregation="all_or_nothing") + metric.update( + ["a prompt"], + [{"questions": {"1": "Is there a cat?", "2": "Is it blue?"}}], + _dummy_image(batch=1), + ) + result = metric.compute() + assert result.result == 0.0, f"Expected 0.0 for all_or_nothing with one No, got {result.result}" + + +@pytest.mark.cpu +def test_qa_accuracy_all_or_nothing_all_yes() -> None: + """all_or_nothing: all Yes → score 1.0.""" + mock_vlm = 
MagicMock(spec=BaseVLM) + mock_vlm.score.return_value = [1.0, 1.0] + + metric = QAAccuracyMetric(vlm=mock_vlm, device="cpu", aggregation="all_or_nothing") + metric.update( + ["a prompt"], + [{"questions": {"1": "Is there a cat?", "2": "Is it blue?"}}], + _dummy_image(batch=1), + ) + result = metric.compute() + assert result.result == 1.0, f"Expected 1.0 for all_or_nothing with all Yes, got {result.result}" + + +@pytest.mark.cpu +def test_qa_accuracy_invalid_aggregation_raises() -> None: + """qa_accuracy rejects aggregation values other than mean / all_or_nothing.""" + mock_vlm = MagicMock(spec=BaseVLM) + with pytest.raises(ValueError, match="aggregation"): + QAAccuracyMetric(vlm=mock_vlm, device="cpu", aggregation="median") + + +@pytest.mark.cpu +def test_vie_score_tie_uses_source_from_gt_and_two_image_sc() -> None: + """With ``source_image_bytes`` in gt, VieScore calls two-image SC then PQ on the edited image.""" + from io import BytesIO + + from PIL import Image + + buf = BytesIO() + Image.new("RGB", (8, 8), color=(0, 0, 200)).save(buf, format="PNG") + src_bytes = buf.getvalue() + + mock_vlm = MagicMock() + mock_vlm.generate_with_image_lists.return_value = ['{"score": [8.0, 8.0], "reasoning": "ok"}'] + mock_vlm.generate.return_value = ['{"score": [9.0, 9.0], "reasoning": "ok"}'] + + metric = VieScoreMetric(vlm=mock_vlm, device="cpu", structured_output=True) + pred = _dummy_image(batch=1) + metric.update( + ["make the sky purple"], + [{"source_image_bytes": src_bytes}], + pred, + ) + result = metric.compute() + + assert mock_vlm.generate_with_image_lists.called + assert mock_vlm.generate.called + assert 0.0 <= result.result <= 1.0 + + +@pytest.mark.cpu +def test_vie_score_uses_get_score_from_response() -> None: + """VieScoreMetric ``t2i`` path parses JSON ``score`` lists via ``viescore_min_scores_0_10``.""" + mock_vlm = MagicMock(spec=BaseVLM) + # LitellmVLM returns model_dump_json() for structured outputs → JSON string (two SC + two PQ sub-scores) + mock_vlm.generate.return_value = ['{"score": [8.0, 8.0], "reasoning": ""}'] + + metric = VieScoreMetric(vlm=mock_vlm, device="cpu", structured_output=True) + metric.update(["a cat on a sofa"], _dummy_image(batch=1), _dummy_image(batch=1)) + result = metric.compute() + + # min(SC)=8, min(PQ)=8 → sqrt(8 * 8) / 10 = 0.8 + assert abs(result.result - 0.8) < 0.01, f"Expected ~0.8, got {result.result}" + + +@pytest.mark.cpu +def test_img_edit_score_negative_response_clamped() -> None: + """img_edit_score must be non-negative even when the VLM generates a negative JSON score. + + Regression for: Outlines constrained decoding can emit {"score": -10} despite the + FloatOutput JSON schema specifying minimum=0, because Outlines does not enforce numeric + bounds during token sampling. The fix is max(0.0, ...) in get_score_from_response. 
+ """ + mock_vlm = MagicMock(spec=BaseVLM) + # Simulate Outlines generating a negative value (the bug scenario) + mock_vlm.generate.return_value = ['{"score": -10.0}'] + + metric = ImageEditScoreMetric(vlm=mock_vlm, device="cpu", structured_output=True) + metric.update(["replace the boot with a mug"], torch.zeros(1), _dummy_image(batch=1)) + result = metric.compute() + + assert result.result >= 0.0, f"img_edit_score must be >= 0, got {result.result}" + + +@pytest.mark.cpu +def test_qa_accuracy_all_or_nothing_ambiguous_score() -> None: + """all_or_nothing: score exactly 0.5 (ambiguous) is treated as No → result 0.0.""" + mock_vlm = MagicMock(spec=BaseVLM) + mock_vlm.score.return_value = [0.5] + + metric = QAAccuracyMetric(vlm=mock_vlm, device="cpu", aggregation="all_or_nothing") + metric.update( + ["a prompt"], + [{"questions": {"1": "Is there a cat?"}}], + _dummy_image(batch=1), + ) + result = metric.compute() + assert result.result == 0.0, f"Score 0.5 should be treated as No (ambiguous), got {result.result}" + + +@pytest.mark.cpu +@pytest.mark.slow +def test_yes_no_token_ids_smolvlm_nonempty() -> None: + """SmolVLM tokenizer must yield non-empty disjoint yes/no prefix ids for VQAScore scoring.""" + pytest.importorskip("transformers") + from transformers import AutoTokenizer + + tok = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolVLM-256M-Instruct") + yes_ids, no_ids = yes_no_first_token_id_groups(tok) + assert len(yes_ids) > 0, "SmolVLM tokenizer has no 'Yes'-prefix token ids" + assert len(no_ids) > 0, "SmolVLM tokenizer has no 'No'-prefix token ids" + assert not (set(yes_ids) & set(no_ids)), "yes_ids and no_ids must be disjoint" + + +@pytest.mark.cpu +def test_img_edit_score_uses_prompt_from_x() -> None: + """img_edit_score must score the edited image against the instruction from x, not gt.""" + mock_vlm = MagicMock(spec=BaseVLM) + mock_vlm.generate.return_value = ['{"score": 9}'] + + metric = ImageEditScoreMetric(vlm=mock_vlm, device="cpu") + pred = _dummy_image(batch=1) + metric.update( + ["replace the cat with a dog"], # x = instruction + pred, # gt = unused for y_x + pred, # outputs = edited image + ) + result = metric.compute() + + call_args = mock_vlm.generate.call_args + prompt_sent = call_args[0][1][0] # second positional arg = prompts list, first item + assert "replace the cat with a dog" in prompt_sent, f"Instruction not in VLM prompt. Got: {prompt_sent}" + assert abs(result.result - 0.9) < 0.01, f"Expected ~0.9, got {result.result}" + + +@pytest.mark.cpu +def test_vie_score_geditbench_gap_documented() -> None: + """VieScoreMetric infers text--image editing from ``source_image_bytes`` in aux (no ``task_type``). + + This test fails if a ``task_type`` parameter is added to ``__init__`` without updating + GEditBench integration tests and benchmark copy accordingly. + """ + import inspect + + sig = inspect.signature(VieScoreMetric.__init__) + assert "task_type" not in sig.parameters, ( + "VieScoreMetric now has task_type — update GEditBench docs and e2e tests, then remove this sentinel." 
+ ) + + +@pytest.mark.cpu +def test_litellm_logprob_aggregation_sums_all_yes_tokens() -> None: + """LitellmVLM logprob scoring must sum all yes-prefix token probs, not return the first.""" + pytest.importorskip("litellm") + import math + from unittest.mock import MagicMock, patch + + import numpy as np + from PIL import Image + + from pruna.evaluation.metrics.vlm_base import LitellmVLM + + # Simulate top_logprobs for first output token: + # "Yes" → logprob=-2.303 (p≈0.10), " yes" → logprob=-2.996 (p≈0.05) → total p_yes≈0.15 + # "No" → logprob=-1.609 (p≈0.20), " no" → logprob=-2.303 (p≈0.10) → total p_no≈0.30 + # normalized: p_yes/(p_yes+p_no) ≈ 0.15/0.45 ≈ 0.333 + def make_top_logprob(token, logprob): + t = MagicMock() + t.token = token + t.logprob = logprob + return t + + first_tok = MagicMock() + first_tok.top_logprobs = [ + make_top_logprob("Yes", math.log(0.10)), + make_top_logprob(" yes", math.log(0.05)), + make_top_logprob("No", math.log(0.20)), + make_top_logprob(" no", math.log(0.10)), + make_top_logprob("maybe", math.log(0.55)), + ] + + mock_logprobs = MagicMock() + mock_logprobs.content = [first_tok] + + mock_choice = MagicMock() + mock_choice.logprobs = mock_logprobs + mock_choice.message.content = "Yes" + + mock_response = MagicMock() + mock_response.choices = [mock_choice] + + with patch("litellm.completion", return_value=mock_response): + vlm = LitellmVLM(model_name="openai/gpt-4o") + img = Image.fromarray(np.zeros((32, 32, 3), dtype="uint8")) + score = vlm._score_with_logprobs(img, "Is there a cat?", "Yes") + + # Should be ~0.333 (p_yes=0.15 / (p_yes+p_no)=0.45), not just 0.10 (first match) + assert 0.28 < score < 0.40, f"Expected ~0.333 (sum-normalized), got {score}" + + +@pytest.mark.cpu +@pytest.mark.slow +def test_vqa_probability_score_normalized() -> None: + """P(Yes) from TransformersVLM.score use_probability=True is in [0, 1].""" + pytest.importorskip("transformers") + import numpy as np + from PIL import Image + + from pruna.evaluation.metrics.vlm_base import TransformersVLM + + vlm = TransformersVLM( + model_name="HuggingFaceTB/SmolVLM-256M-Instruct", + device="cpu", + use_outlines=False, + ) + img = Image.fromarray(np.zeros((32, 32, 3), dtype="uint8")) + scores = vlm.score([img], ["Is there a cat?"], ["Yes"], use_probability=True) + assert len(scores) == 1 + assert 0.0 <= scores[0] <= 1.0, f"P(Yes) must be in [0, 1], got {scores[0]}" + + +# --------------------------------------------------------------------------- +# vlm_benchmark_batch_to_json_record serialization tests +# --------------------------------------------------------------------------- + + +def test_vlm_benchmark_batch_to_json_record_serializes_batch() -> None: + """Record includes prompts, pred shape, and metric fields.""" + mr = MetricResult(name="qa_accuracy", params={}, result=0.25, higher_is_better=True) + outcome = BenchmarkVlmBatchOutcome( + result=mr, + prompts=["prompt"], + auxiliaries=[{"path": "/tmp/x.png"}], + pred=torch.zeros(1, 3, 8, 8), + ) + rec = vlm_benchmark_batch_to_json_record( + outcome, + benchmark_key="GenEval", + benchmark_name="GenEval", + metric_name="qa_accuracy", + vlm_type="transformers", + model_name="m", + device="cpu", + ) + assert rec["inputs"]["prompts"] == ["prompt"] + assert rec["pred"]["shape"] == [1, 3, 8, 8] + assert rec["metric_result"]["result"] == 0.25 + + +def test_safe_json_handles_bytes_without_expanding() -> None: + """Bytes values in aux (e.g. 
source_image_bytes) are summarized, not expanded to str repr.""" + result = safe_json_for_snapshot({"source_image_bytes": b"\xff\xd8\xff" * 1000, "name": "test"}) + assert result["source_image_bytes"] == {"bytes_len": 3000} + assert result["name"] == "test" + + +def test_vlm_benchmark_batch_to_json_record_preserves_null_question_slots() -> None: + """Padded ``None`` question labels stay JSON null, not the string ``"None"``.""" + mr = MetricResult(name="oneig_alignment", params={}, result=1.0, higher_is_better=True) + outcome = BenchmarkVlmBatchOutcome( + result=mr, + prompts=["p"], + auxiliaries=[{"questions": {"1": "Are there boys?", "21": None}, "subset": "Anime_Stylization"}], + pred=torch.zeros(1, 3, 8, 8), + ) + rec = vlm_benchmark_batch_to_json_record( + outcome, + benchmark_key="OneIGAnimeStylization", + benchmark_name="OneIG Anime Stylization", + metric_name="oneig_alignment", + vlm_type="transformers", + model_name="m", + device="cpu", + ) + qs = rec["inputs"]["auxiliary_0"]["questions"] + assert qs["1"] == "Are there boys?" + assert qs["21"] is None + + +# --------------------------------------------------------------------------- +# pred_tensor_from_auxiliaries (test helper, wraps pil_rgb_from_aux_image_bytes) tests +# --------------------------------------------------------------------------- + + +def _make_jpeg_bytes(h: int = 32, w: int = 32) -> bytes: + """Return a tiny JPEG-encoded RGB image as bytes (test helper).""" + import io + + import numpy as np + from PIL import Image + + arr = (np.random.rand(h, w, 3) * 255).astype("uint8") + buf = io.BytesIO() + Image.fromarray(arr).save(buf, format="JPEG") + return buf.getvalue() + + +@pytest.mark.cpu +def test_pred_from_auxiliaries_uses_source_image_bytes() -> None: + """pred_tensor_from_auxiliaries decodes source_image_bytes into a float tensor in [0, 1].""" + src_bytes = _make_jpeg_bytes() + aux = [{"source_image_bytes": src_bytes, "category": "background_change"}] + pred = pred_tensor_from_auxiliaries(aux, size=64) + + assert pred.shape == (1, 3, 64, 64), f"Expected (1,3,64,64), got {pred.shape}" + assert pred.min() >= 0.0 and pred.max() <= 1.0, "Pixel values must be in [0, 1]" + + +@pytest.mark.cpu +def test_pred_from_auxiliaries_falls_back_to_noise_without_source_image() -> None: + """pred_tensor_from_auxiliaries returns random noise when no source_image_bytes is present.""" + aux = [{"category": "single_object"}] + pred = pred_tensor_from_auxiliaries(aux, size=32) + assert pred.shape == (1, 3, 32, 32) + assert pred.min() >= 0.0 and pred.max() <= 1.0 + + +@pytest.mark.cpu +def test_pred_from_auxiliaries_mixed_batch() -> None: + """Batch with one source image and one missing falls back per-item.""" + src_bytes = _make_jpeg_bytes() + aux = [ + {"source_image_bytes": src_bytes, "category": "color_alter"}, + {"category": "style_change"}, # no source image + ] + pred = pred_tensor_from_auxiliaries(aux, size=32) + assert pred.shape == (2, 3, 32, 32) + assert pred.min() >= 0.0 and pred.max() <= 1.0 + + +@pytest.mark.cpu +def test_pred_from_auxiliaries_generic_bytes_scan() -> None: + """pred_tensor_from_auxiliaries discovers image bytes under an unknown field name (generic scan).""" + src_bytes = _make_jpeg_bytes() + aux = [{"my_custom_image_bytes": src_bytes, "category": "motion_change"}] + pred = pred_tensor_from_auxiliaries(aux, size=32) + assert pred.shape == (1, 3, 32, 32) + assert pred.min() >= 0.0 and pred.max() <= 1.0 + + +@pytest.mark.cpu +def test_pred_from_auxiliaries_known_names_take_priority() -> None: + """Known field 
names are resolved before the generic bytes scan.""" + src_bytes_known = _make_jpeg_bytes(16, 16) + src_bytes_unknown = _make_jpeg_bytes(32, 32) + first_known = VLM_AUX_IMAGE_BYTES_KEY_ORDER[0] + aux = [{"other_bytes": src_bytes_unknown, first_known: src_bytes_known}] + pred = pred_tensor_from_auxiliaries(aux, size=16) + # Should use the known key (16x16 image → 16x16 crop); generic scan would pick 32x32 + assert pred.shape == (1, 3, 16, 16) + + +@pytest.mark.cpu +def test_pred_from_auxiliaries_require_source_image_raises_when_missing() -> None: + """require_source_image=True raises ValueError instead of silently returning noise.""" + aux = [{"category": "replace"}] # no image bytes + with pytest.raises(ValueError, match="require_source_image=True"): + pred_tensor_from_auxiliaries(aux, size=32, require_source_image=True) + + +@pytest.mark.cpu +def test_pred_from_auxiliaries_require_source_image_succeeds_when_present() -> None: + """require_source_image=True succeeds and decodes bytes when source_image_bytes is present.""" + src_bytes = _make_jpeg_bytes() + aux = [{"source_image_bytes": src_bytes, "category": "replace"}] + pred = pred_tensor_from_auxiliaries(aux, size=32, require_source_image=True) + assert pred.shape == (1, 3, 32, 32) + assert pred.min() >= 0.0 and pred.max() <= 1.0 From cba597b3d77850558dddae845a9efa545c428277 Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Sat, 25 Apr 2026 12:51:30 +0200 Subject: [PATCH 3/3] feat(vision-metrics): add vision-based VLM metrics - Add VieScoreMetric for vision-instruction-execution alignment - Add ImageEditScoreMetric for image editing evaluation - Add VQAMetric for visual question-answering - Register all vision metrics in registry - Add benchmark configs for vision-based evaluation - Include unit and integration tests with mocked VLM --- src/pruna/evaluation/benchmarks.py | 145 ++-- .../metrics/metric_img_edit_score.py | 219 ++++++ .../evaluation/metrics/metric_vie_score.py | 363 ++++++++++ src/pruna/evaluation/metrics/metric_vqa.py | 158 ++++ src/pruna/evaluation/metrics/registry.py | 14 +- tests/evaluation/test_vision_metrics.py | 684 ++++++++++++++++++ 6 files changed, 1529 insertions(+), 54 deletions(-) create mode 100644 src/pruna/evaluation/metrics/metric_img_edit_score.py create mode 100644 src/pruna/evaluation/metrics/metric_vie_score.py create mode 100644 src/pruna/evaluation/metrics/metric_vqa.py create mode 100644 tests/evaluation/test_vision_metrics.py diff --git a/src/pruna/evaluation/benchmarks.py b/src/pruna/evaluation/benchmarks.py index e52ae463..be240c16 100644 --- a/src/pruna/evaluation/benchmarks.py +++ b/src/pruna/evaluation/benchmarks.py @@ -12,12 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + +import builtins from dataclasses import dataclass, field from pruna.data import base_datasets from pruna.data.utils import get_literal_values_from_param from pruna.evaluation.metrics import MetricRegistry +TASK_TYPE_TEXT_IMAGE = "text_image" +TASK_TYPE_TEXT_PLUS_IMAGE_IMAGE = "text+image_image" + @dataclass class Benchmark: @@ -31,9 +37,13 @@ class Benchmark: description : str Description of what the benchmark evaluates. metrics : list[str] - List of metric names used for evaluation. + Metric names from ``MetricRegistry`` that the ``reference`` paper + explicitly names for that benchmark (not speculative proxies). 
Entries + with no matching registered name stay empty; pass metrics explicitly to + ``Task`` when running other evaluations. task_type : str - Type of task the benchmark evaluates (e.g., 'text_to_image'). + Modality-style label: ``text_image`` (text → image), ``text+image_image`` (text + source + image → image), or ``text_to_video``, ``image_classification``, ``text_generation``. reference : str | None URL to the canonical paper (e.g., arXiv) for this benchmark. """ @@ -62,24 +72,17 @@ class BenchmarkRegistry: """ Registry for benchmarks. - Metrics per benchmark are set to those explicitly used in the reference - paper (see reference URL). All entries verified from paper evaluation - sections (ar5iv/HTML or PDF) as of verification pass: + Each entry's ``metrics`` lists only ``MetricRegistry`` names that have a + **directly named** counterpart in the ``reference`` paper (e.g. CLIPScore → + ``clip_score``, VQAScore → ``vqa``, Fréchet inception distance → ``fid``). + If the paper cites a method with no registered metric (HPS v2, Mask2Former, + mPLUG-large adjudication, …), the list is empty. + + See ``.mine/benchmark-paper-alignment/01-arxiv-literature-vs-pruna-metrics.md`` + for paper-by-paper notes and Pruna implementation gaps. - - Parti Prompts (2206.10789 §5.2, §5.4): human side-by-side only on P222. - - DrawBench (2205.11487 §4.3): human raters only; COCO uses FID + CLIP. - - GenAI Bench (2406.13743): VQAScore only (web/PWC; ar5iv failed). - - VBench (2311.17982): 16 dimension-specific methods; no single Pruna metric. - - COCO (2205.11487 §4.1): FID and CLIP score for fidelity and alignment. - - ImageNet (1409.0575 §4): top-1/top-5 classification accuracy. - - WikiText (1609.07843 §5): perplexity on validation/test. - - GenEval (2310.11513 §3.2): Mask2Former + CLIP color pipeline, binary score. - - HPS (2306.09341): HPS v2 scoring model (CLIP fine-tuned on HPD v2). - - ImgEdit (2505.20275 §4.2): GPT-4o 1–5 ratings and ImgEdit-Judge. - - Long Text Bench (2507.22058 §4): Text Accuracy (OCR, Qwen2.5-VL-7B). - - GEditBench (2504.17761 §4.2): VIEScore (SQ, PQ, O via GPT-4.1/Qwen2.5-VL). - - OneIG (2506.07977 §4.1): per-dimension metrics (semantic alignment, ED, etc.). - - DPG (2403.05135): DSG-style graph score, mPLUG-large adjudicator. + OneIG is split into six subset benchmarks (plus full ``OneIG``); see + ``.mine/benchmark-paper-alignment/02-oneig-subset-metrics-verification.md`` for §4.1 mapping. """ _registry: dict[str, Benchmark] = {} @@ -88,9 +91,7 @@ class BenchmarkRegistry: def _register(cls, benchmark: Benchmark) -> None: missing = [m for m in benchmark.metrics if not MetricRegistry.has_metric(m)] if missing: - raise ValueError( - f"Benchmark '{benchmark.name}' references metrics not in MetricRegistry: {missing}." - ) + raise ValueError(f"Benchmark '{benchmark.name}' references metrics not in MetricRegistry: {missing}.") if benchmark.lookup_key not in base_datasets: available = ", ".join(base_datasets.keys()) raise ValueError( @@ -125,14 +126,14 @@ def get(cls, name: str) -> Benchmark: return cls._registry[key] @classmethod - def list(cls, task_type: str | None = None) -> list[str]: + def list(cls, task_type: str | None = None) -> builtins.list[str]: """ List available benchmark names. Parameters ---------- task_type : str | None - Filter by task type (e.g., 'text_to_image', 'text_to_video'). + Filter by task type (e.g., ``text_image``, ``text_to_video``). If None, returns all benchmarks. 
Returns @@ -154,7 +155,7 @@ def list(cls, task_type: str | None = None) -> list[str]: "perspectives, and symbol rendering from basic to complex compositions." ), metrics=[], # Paper uses human evaluation only; pass explicit metrics if needed - task_type="text_to_image", + task_type=TASK_TYPE_TEXT_IMAGE, reference="https://arxiv.org/abs/2206.10789", ), Benchmark( @@ -164,7 +165,7 @@ def list(cls, task_type: str | None = None) -> list[str]: "Enables side-by-side comparison on sample quality and image-text alignment with human raters." ), metrics=[], # Paper uses human evaluation only; pass explicit metrics if needed - task_type="text_to_image", + task_type=TASK_TYPE_TEXT_IMAGE, reference="https://arxiv.org/abs/2205.11487", ), Benchmark( @@ -174,8 +175,8 @@ def list(cls, task_type: str | None = None) -> list[str]: "Covers basic skills (scene, attributes, spatial relationships) to advanced reasoning " "(counting, comparison, logic/negation) with over 24k human ratings." ), - metrics=[], # Paper uses VQAScore only; not in Pruna - task_type="text_to_image", + metrics=["vqa", "clip_score"], # VQAScore + CLIPScore both named (arXiv:2406.13743) + task_type=TASK_TYPE_TEXT_IMAGE, reference="https://arxiv.org/abs/2406.13743", ), Benchmark( @@ -195,8 +196,8 @@ def list(cls, task_type: str | None = None) -> list[str]: "MS-COCO for text-to-image evaluation (Imagen, 2205.11487). Paper reports " "FID for fidelity and CLIP score for image-text alignment." ), - metrics=["fid", "clip_score"], # §4.1: FID + CLIP score - task_type="text_to_image", + metrics=["fid", "clip_score"], + task_type=TASK_TYPE_TEXT_IMAGE, reference="https://arxiv.org/abs/2205.11487", ), Benchmark( @@ -223,11 +224,13 @@ def list(cls, task_type: str | None = None) -> list[str]: name="GenEval", description=( "Compositional text-to-image benchmark with 6 categories: single object, two object, " - "counting, colors, position, color attributes. Evaluates fine-grained alignment " - "between prompts and generated images via VQA-style questions." + "counting, colors, position, color attributes. Uses atomic yes/no questions per prompt; " + "``Task.from_benchmark`` wires ``qa_accuracy`` with strict per-image aggregation " + "(all questions must pass) plus ``clip_score``. For holistic VQAScore-style scoring " + "use GenAI Bench with ``vqa``." ), - metrics=["clip_score"], # §3.2: Mask2Former; not in Pruna - task_type="text_to_image", + metrics=["qa_accuracy", "clip_score"], + task_type=TASK_TYPE_TEXT_IMAGE, reference="https://arxiv.org/abs/2310.11513", ), Benchmark( @@ -237,7 +240,7 @@ def list(cls, task_type: str | None = None) -> list[str]: "Covers anime, concept-art, paintings, and photo styles with human preference data." ), metrics=[], # Paper uses HPS scoring model; not in Pruna - task_type="text_to_image", + task_type=TASK_TYPE_TEXT_IMAGE, reference="https://arxiv.org/abs/2306.09341", ), Benchmark( @@ -246,18 +249,20 @@ def list(cls, task_type: str | None = None) -> list[str]: "Image editing benchmark with 8 edit types: replace, add, remove, adjust, extract, " "style, background, compose. Evaluates instruction-following for inpainting and editing." ), - metrics=[], # Paper uses GPT-4o/ImgEdit-Judge; not in Pruna - task_type="text_to_image", + metrics=["img_edit_score"], # Paper: GPT-4o rubric scores, FakeShield; no matching MetricRegistry name + task_type=TASK_TYPE_TEXT_PLUS_IMAGE_IMAGE, reference="https://arxiv.org/abs/2505.20275", ), Benchmark( name="Long Text Bench", description=( - "Text-to-image benchmark for long, detailed prompts. 
Evaluates model ability to " - "handle complex multi-clause descriptions and maintain coherence across long instructions." + "Text rendering benchmark evaluating whether T2I models correctly render specific text strings " + "specified in prompts. Provides ``text_content`` ground truth for OCR comparison via ``text_score`` " + "(normalized character accuracy in [0, 1]; higher is better). " + "Not to be confused with text-to-image alignment for long descriptive prompts." ), - metrics=[], # Paper uses text_score/TIT-Score; not in Pruna - task_type="text_to_image", + metrics=["text_score"], + task_type=TASK_TYPE_TEXT_IMAGE, reference="https://arxiv.org/abs/2507.22058", ), Benchmark( @@ -265,21 +270,55 @@ def list(cls, task_type: str | None = None) -> list[str]: description=( "General image editing benchmark with 11 task types: background change, color alter, " "material alter, motion change, style change, subject add/remove/replace, text change, " - "tone transfer, and human retouching." + "tone transfer, and human retouching. " + "Evaluated with VIEScore in text--image editing (``tie``) mode when source image bytes " + "are available in batch aux (semantic + perceptual sub-scores, overall as geometric mean " + "on the 0--10 scale; see ``vie_score`` metric)." ), - metrics=[], # Paper uses VIEScore; not in Pruna - task_type="text_to_image", + metrics=["vie_score"], # VIEScore named in GEdit-Bench section + task_type=TASK_TYPE_TEXT_PLUS_IMAGE_IMAGE, reference="https://arxiv.org/abs/2504.17761", ), Benchmark( - name="OneIG", - description=( - "Omni-dimensional benchmark for text-to-image evaluation. Six dataset categories " - "(Anime_Stylization, General_Object, Knowledge_Reasoning, Multilingualism, Portrait, " - "Text_Rendering) plus fine-grained style classes. Includes alignment questions." - ), - metrics=[], # Paper uses dimension-specific metrics; not in Pruna - task_type="text_to_image", + name="OneIG Anime Stylization", + description="OneIG subset: anime and stylized imagery.", + metrics=["oneig_alignment"], + task_type=TASK_TYPE_TEXT_IMAGE, + reference="https://arxiv.org/abs/2506.07977", + ), + Benchmark( + name="OneIG General Object", + description="OneIG subset: everyday objects and scenes.", + metrics=["oneig_alignment"], + task_type=TASK_TYPE_TEXT_IMAGE, + reference="https://arxiv.org/abs/2506.07977", + ), + Benchmark( + name="OneIG Knowledge Reasoning", + description="OneIG subset: knowledge- and reasoning-heavy prompts.", + metrics=["oneig_reasoning"], + task_type=TASK_TYPE_TEXT_IMAGE, + reference="https://arxiv.org/abs/2506.07977", + ), + Benchmark( + name="OneIG Multilingualism", + description="OneIG subset: multilingual prompts (incl. Chinese splits).", + metrics=["oneig_alignment"], + task_type=TASK_TYPE_TEXT_IMAGE, + reference="https://arxiv.org/abs/2506.07977", + ), + Benchmark( + name="OneIG Portrait", + description="OneIG subset: people and portraits.", + metrics=["oneig_alignment"], + task_type=TASK_TYPE_TEXT_IMAGE, + reference="https://arxiv.org/abs/2506.07977", + ), + Benchmark( + name="OneIG Text Rendering", + description="OneIG subset: text and graphics painted into the image.", + metrics=["oneig_text_score"], + task_type=TASK_TYPE_TEXT_IMAGE, reference="https://arxiv.org/abs/2506.07977", ), Benchmark( @@ -289,7 +328,7 @@ def list(cls, task_type: str | None = None) -> list[str]: "global, and other descriptive aspects with natural-language questions for alignment." 
), metrics=[], # Paper uses custom evaluation; not in Pruna - task_type="text_to_image", + task_type=TASK_TYPE_TEXT_IMAGE, reference="https://arxiv.org/abs/2403.05135", ), ]: diff --git a/src/pruna/evaluation/metrics/metric_img_edit_score.py b/src/pruna/evaluation/metrics/metric_img_edit_score.py new file mode 100644 index 00000000..4c1832af --- /dev/null +++ b/src/pruna/evaluation/metrics/metric_img_edit_score.py @@ -0,0 +1,219 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Image Edit Score metric. + +VLM-based instruction-following score for image editing. Evaluates how well an edited image +follows the given editing instruction on a 0-10 scale. Related work: EditScore (arXiv:2509.23909), +ADIEE (ICCV 2025). + +When the ``ImgEdit`` benchmark provides a per-sample ``judge_prompt`` and +``source_image_bytes`` in the auxiliaries, the metric mirrors the ImgEdit paper +evaluation protocol: the judge_prompt rubric (three 1-5 criterion scores) is +filled with the editing instruction, both source and edited images are shown to +the VLM, and the minimum of the three criterion scores is normalised to [0, 1] by +dividing by 5 (consistent with VIEScore methodology: the weakest criterion governs). +Without these auxiliaries the metric falls back to a single-image generic 0-10 prompt. +""" + +from __future__ import annotations + +from typing import Any, Literal + +import torch + +from pruna.evaluation.metrics.registry import MetricRegistry +from pruna.evaluation.metrics.result import MetricResult +from pruna.evaluation.metrics.utils import ( + SINGLE, + metric_data_processor, +) +from pruna.evaluation.metrics.vlm_base import ( + BaseVLM, + StatefulVLMMeanScoresMetric, + auxiliary_dicts_from_gt, + prompts_from_y_x_inputs, +) +from pruna.evaluation.metrics.vlm_utils import ( + FloatOutput, + VIEScoreJsonOutput, + _process_images, + get_score_from_response, + pil_rgb_from_aux_image_bytes, + viescore_min_scores_0_10, +) + +_FALLBACK_QUESTION = ( + 'On a scale of 0 to 10, how well does this edited image follow the instruction "{prompt}"? ' + "0 = instruction not followed at all, 10 = perfectly executed. Reply with a single number." +) + +_JUDGE_JSON_SUFFIX = ( + '\n\nProvide your three criterion scores as JSON: {"score": [score1, score2, score3]} ' + "where each score is a number from 1 to 5." +) + + +@MetricRegistry.register("img_edit_score") +class ImageEditScoreMetric(StatefulVLMMeanScoresMetric): + """ + Image Edit Score metric. + + VLM-based instruction-following score for image editing. Evaluates how well an edited image + follows the given editing instruction. Higher scores indicate better editing quality. + + When auxiliaries contain ``judge_prompt`` and ``source_image_bytes`` (as provided + by the ImgEdit benchmark), the metric passes **both** the source (before) and edited + (after) images to the VLM together with the dataset-specific rubric. This matches + the ImgEdit paper's evaluation protocol. 
Without these fields, it falls back to a + single-image generic question. + + Related work: EditScore (arXiv:2509.23909), ADIEE (ICCV 2025). + + Parameters + ---------- + *args : Any + Additional positional arguments. + vlm : BaseVLM | None, optional + Custom VLM instance. If provided, vlm_type and model_name are ignored. + vlm_type : {"litellm", "transformers"}, optional + VLM backend. Default is "litellm". + model_name : str | None, optional + Litellm model id or HuggingFace checkpoint id. **Required** when ``vlm`` is not + provided (e.g. ``openai/gpt-4o``). + vlm_kwargs : dict, optional + Forwarded by ``get_vlm`` to ``LitellmVLM`` or ``TransformersVLM``. For local models, + set ``model_load_kwargs`` for ``from_pretrained``; for litellm, pass extra API options. + structured_output : bool, optional + Use structured generation (litellm pydantic; transformers outlines when applicable). + Default is True. + device : str | torch.device | None, optional + Device for transformers VLM. + api_key : str | None, optional + API key for litellm. + call_type : str, optional + Call type for the metric. + **kwargs : Any + Additional arguments. + + Examples + -------- + Same ``hosted`` / ``local`` pattern as :func:`~pruna.evaluation.metrics.vlm_base.get_vlm`: + + .. code-block:: python + + import torch + + from pruna.evaluation.metrics import ImageEditScoreMetric + + hosted = ImageEditScoreMetric(vlm_type="litellm", model_name="openai/gpt-4o") + local = ImageEditScoreMetric( + vlm_type="transformers", + model_name="HuggingFaceTB/SmolVLM-256M-Instruct", + device="cpu", + vlm_kwargs={"model_load_kwargs": {"torch_dtype": torch.float32}}, + ) + """ + + scores: list[float] + default_call_type: str = "y_x" + higher_is_better: bool = True + metric_name: str = "img_edit_score" + + def __init__( + self, + *args, + vlm: BaseVLM | None = None, + vlm_type: Literal["litellm", "transformers"] = "litellm", + model_name: str | None = None, + vlm_kwargs: dict | None = None, + structured_output: bool = True, + device: str | torch.device | None = None, + api_key: str | None = None, + call_type: str = SINGLE, + **kwargs: Any, + ) -> None: + super().__init__(device=device) + self.response_format = FloatOutput if structured_output else None + + self._init_vlm_scores( + vlm=vlm, + vlm_type=vlm_type, + model_name=model_name, + vlm_kwargs=vlm_kwargs, + structured_output=structured_output, + device=device, + api_key=api_key, + call_type=call_type, + ) + + def update(self, x: list[Any] | torch.Tensor, gt: torch.Tensor, outputs: torch.Tensor) -> None: + """ + Update the metric with new batch data. + + When ``gt`` auxiliaries contain ``judge_prompt`` and ``source_image_bytes``, the + metric uses the dataset rubric and a before/after two-image comparison. Otherwise + it falls back to a single-image generic question. + + Parameters + ---------- + x : List[Any] | torch.Tensor + The input data (editing instructions / prompts). + gt : torch.Tensor + Auxiliaries per sample (may contain ``judge_prompt`` and ``source_image_bytes``). + outputs : torch.Tensor + The output (edited) images. 
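+
+        Notes
+        -----
+        On the rubric path only the weakest of the three 1-5 criterion scores counts.
+        A minimal sketch of the normalisation, with illustrative values::
+
+            raw = [4, 5, 2]                             # criterion scores parsed from the judge
+            score = max(0.0, min(1.0, min(raw) / 5.0))  # -> 0.4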
+ """ + inputs = metric_data_processor(x, gt, outputs, self.call_type) + images = _process_images(inputs[0]) + prompts = prompts_from_y_x_inputs(inputs, len(images)) + aux_list = auxiliary_dicts_from_gt(gt, len(images)) + + for i, image in enumerate(images): + prompt = prompts[i] if i < len(prompts) else "" + aux_row = aux_list[i] + + judge_prompt = aux_row.get("judge_prompt", "") or "" + source_image = pil_rgb_from_aux_image_bytes(aux_row, min_bytes_in_value_scan=100) + + if judge_prompt and source_image is not None: + filled = judge_prompt.replace("", prompt).strip() + question = filled + _JUDGE_JSON_SUFFIX + try: + responses = self.vlm.generate_with_image_lists( + [[source_image, image]], [question], response_format=VIEScoreJsonOutput + ) + raw = viescore_min_scores_0_10(responses[0]) + if raw: + score = max(0.0, min(1.0, float(min(raw)) / 5.0)) + self.scores.append(score) + continue + except (NotImplementedError, AttributeError): + pass + + question = _FALLBACK_QUESTION.format(prompt=prompt) + responses = self.vlm.generate([image], [question], response_format=self.response_format) + self.scores.append(get_score_from_response(responses[0])) + + def compute(self) -> MetricResult: + """ + Compute the image edit score. + + Returns + ------- + MetricResult + The mean image edit score across all updates. + """ + return self.compute_mean_of_scores() diff --git a/src/pruna/evaluation/metrics/metric_vie_score.py b/src/pruna/evaluation/metrics/metric_vie_score.py new file mode 100644 index 00000000..836c5884 --- /dev/null +++ b/src/pruna/evaluation/metrics/metric_vie_score.py @@ -0,0 +1,363 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +VIEScore metric for conditional image synthesis (semantic + quality). + +Reference: VIEScore (ACL 2024) — https://arxiv.org/abs/2312.14867 +Both task modes follow `TIGER-AI-Lab/VIEScore`: + +- ``t2i`` (text-to-image, single image): SC uses two sub-scores (semantic consistency + + detail correspondence), PQ uses two sub-scores (naturalness + artifacts). Overall is + ``sqrt(min(SC) * min(PQ)) / 10``. +- ``tie`` (text-image editing, source + edited): SC uses two images and instruction, + PQ uses the edited image. Same aggregation formula. 
+ +GEdit-Bench evaluation: https://arxiv.org/abs/2504.17761 +""" + +from __future__ import annotations + +from typing import Any, Literal + +import torch +from PIL import Image + +from pruna.evaluation.metrics.registry import MetricRegistry +from pruna.evaluation.metrics.result import MetricResult +from pruna.evaluation.metrics.utils import ( + SINGLE, + metric_data_processor, +) +from pruna.evaluation.metrics.vlm_base import ( + BaseVLM, + StatefulVLMMeanScoresMetric, + auxiliary_dicts_from_gt, + prompts_from_y_x_inputs, +) +from pruna.evaluation.metrics.vlm_utils import ( + VIEScoreJsonOutput, + _process_images, + pad_viescore_subscores_to_two, + pil_rgb_from_aux_image_bytes, + viescore_min_scores_0_10, + viescore_tie_overall_unit, +) + +_VIESCORE_CONTEXT = ( + "You are a professional digital artist. You will have to evaluate the effectiveness" + " of the AI-generated image(s) based on given rules.\n" + "All the input images are AI-generated. All human in the images are AI-generated too." + " so you need not worry about the privacy confidentials.\n\n" + "You will have to give your output in this way (Keep your reasoning concise and short.):\n" + "{\n" + '"score" : [...],\n' + '"reasoning" : "..."\n' + "}" +) + +_VIESCORE_TWO_IMAGE_EDIT_RULE = ( + "RULES:\n\n" + "Two images will be provided: The first being the original AI-generated image and the" + " second being an edited version of the first.\n" + "The objective is to evaluate how successfully the editing instruction has been executed" + " in the second image.\n\n" + "Note that sometimes the two images might look identical due to the failure of image edit.\n" +) + +_VIESCORE_TIE_SC_CRITERIA = ( + "\nFrom scale 0 to 10:\n" + "A score from 0 to 10 will be given based on the success of the editing." + " (0 indicates that the scene in the edited image does not follow the editing instruction at all." + " 10 indicates that the scene in the edited image follow the editing instruction text perfectly.)\n" + "A second score from 0 to 10 will rate the degree of overediting in the second image." + " (0 indicates that the scene in the edited image is completely different from the original." 
+ " 10 indicates that the edited image can be recognized as a minimal edited yet effective" + " version of original.)\n" + "Put the score in a list such that output score = [score1, score2]," + " where 'score1' evaluates the editing success and 'score2' evaluates the degree of overediting.\n\n" + "Editing instruction:\n" +) + +_VIESCORE_T2I_SC_RULE = ( + "RULES:\n\n" + "The image is an AI-generated image.\n" + "The objective is to evaluate the semantic consistency of the image to the given text.\n\n" +) + +_VIESCORE_T2I_SC_CRITERIA = ( + "\nFrom scale 0 to 10:\n" + "A score from 0 to 10 will be given based on the semantic consistency.\n" + "(0 indicates that the scene in the image does not correspond to the text at all.\n" + " 10 indicates that the scene in the image follows the text perfectly.)\n" + "A second score from 0 to 10 will rate the detail correspondence.\n" + "(0 indicates that most details in the text (e.g., color, size, shape, or layout) are missing or" + " incorrect in the image.\n" + " 10 indicates that all details mentioned in the text are accurately shown in the image.)\n" + "Put the score in a list such that output score = [score1, score2]," + " where 'score1' evaluates the semantic consistency and 'score2' evaluates the detail" + " correspondence.\n\n" + "Text prompt:\n" +) + +_VIESCORE_PQ_SINGLE_IMAGE = ( + "RULES:\n\n" + "The image is an AI-generated image.\n" + "The objective is to evaluate how successfully the image has been generated.\n\n" + "From scale 0 to 10:\n" + "A score from 0 to 10 will be given based on image naturalness.\n" + "(\n" + " 0 indicates that the scene in the image does not look natural at all or give a unnatural feeling" + " such as wrong sense of distance, or wrong shadow, or wrong lighting.\n" + " 10 indicates that the image looks natural.\n" + ")\n" + "A second score from 0 to 10 will rate the image artifacts.\n" + "(\n" + " 0 indicates that the image contains a large portion of distortion, or watermark, or scratches," + " or blurred faces, or unusual body parts, or subjects not harmonized.\n" + " 10 indicates the image has no artifacts.\n" + ")\n" + "Put the score in a list such that output score = [naturalness, artifacts]\n" +) + + +def _build_viescore_tie_sc_prompt(instruction: str) -> str: + """Build the VIEScore ``tie`` semantic-criteria prompt (source + edited images). + + Args: + instruction: Editing instruction embedded in the prompt. + + Returns: + ------- + Full prompt aligned with TIGER-AI-Lab/VIEScore ``tie`` SC. + """ + return "\n".join( + [ + _VIESCORE_CONTEXT, + _VIESCORE_TWO_IMAGE_EDIT_RULE, + _VIESCORE_TIE_SC_CRITERIA.strip(), + instruction.strip(), + ] + ) + + +def _build_viescore_t2i_sc_prompt(prompt: str) -> str: + """Build the VIEScore ``t2i`` semantic-consistency prompt for one generated image. + + Args: + prompt: Text prompt used to generate the image. + + Returns: + ------- + Full prompt aligned with TIGER-AI-Lab/VIEScore ``t2i`` SC. + """ + return "\n".join( + [ + _VIESCORE_CONTEXT, + _VIESCORE_T2I_SC_RULE.strip(), + _VIESCORE_T2I_SC_CRITERIA.strip(), + prompt.strip(), + ] + ) + + +def _build_viescore_pq_prompt() -> str: + """Build the VIEScore perceptual-quality prompt for one image (SC or edited).""" + return "\n".join([_VIESCORE_CONTEXT, _VIESCORE_PQ_SINGLE_IMAGE]) + + +@MetricRegistry.register("vie_score") +class VieScoreMetric(StatefulVLMMeanScoresMetric): + """ + VIEScore: semantic + perceptual quality with geometric-mean overall. 
+ + **Text-to-image (one generated image):** uses the VIEScore ``t2i`` SC prompt (semantic + consistency + detail correspondence, 0--10 each) and the shared PQ prompt (naturalness + + artifacts, 0--10 each). Overall is ``sqrt(min(SC) * min(PQ)) / 10`` in ``[0, 1]``. + + **Text--image editing (source + edited available):** matches the VIEScore ``tie`` setup + used in GEdit-Bench: semantic criteria use **two** images (source then edited) and the + editing instruction; perceptual criteria use the **edited** image only. Overall is + ``sqrt(min(SC) * min(PQ)) / 10`` in ``[0, 1]``, with ``min`` taken over the sub-scores in + each JSON ``score`` list, consistent with `VIEScore`_. + + .. _VIEScore: https://github.com/TIGER-AI-Lab/VIEScore + + Parameters + ---------- + *args : Any + Additional positional arguments. + vlm : BaseVLM | None, optional + Custom VLM instance. If provided, vlm_type and model_name are ignored. + vlm_type : {"litellm", "transformers"}, optional + VLM backend. Default is "litellm". + model_name : str | None, optional + Litellm model id or HuggingFace checkpoint id. **Required** when ``vlm`` is not + provided (e.g. ``openai/gpt-4o``). + vlm_kwargs : dict, optional + Forwarded by ``get_vlm`` to ``LitellmVLM`` or ``TransformersVLM``. For local models, + set ``model_load_kwargs`` for ``from_pretrained``; for litellm, pass extra API options. + structured_output : bool, optional + Use structured generation (litellm pydantic; transformers may use plain generation for + multi-image). Default is True. + device : str | torch.device | None, optional + Device for transformers VLM. + api_key : str | None, optional + API key for litellm. + call_type : str, optional + Call type for the metric. + **kwargs : Any + Additional arguments. + + References + ---------- + VIEScore: Towards Explainable Metrics for Conditional Image Synthesis Evaluation (ACL 2024) + https://arxiv.org/abs/2312.14867 + https://github.com/TIGER-AI-Lab/VIEScore + + GEdit-Bench (image editing evaluation) + https://arxiv.org/abs/2504.17761 + + Examples + -------- + Same ``hosted`` / ``local`` pattern as :func:`~pruna.evaluation.metrics.vlm_base.get_vlm``. + Multi-image ``tie`` paths call ``generate_with_image_lists`` on ``self.vlm`` internally. + + .. 
code-block:: python + + import torch + + from pruna.evaluation.metrics import VieScoreMetric + + hosted = VieScoreMetric(vlm_type="litellm", model_name="openai/gpt-4o") + local = VieScoreMetric( + vlm_type="transformers", + model_name="HuggingFaceTB/SmolVLM-256M-Instruct", + device="cpu", + vlm_kwargs={"model_load_kwargs": {"torch_dtype": torch.float32}}, + ) + """ + + scores: list[float] + default_call_type: str = "y_x" + higher_is_better: bool = True + metric_name: str = "vie_score" + + def __init__( + self, + *args, + vlm: BaseVLM | None = None, + vlm_type: Literal["litellm", "transformers"] = "litellm", + model_name: str | None = None, + vlm_kwargs: dict | None = None, + structured_output: bool = True, + device: str | torch.device | None = None, + api_key: str | None = None, + call_type: str = SINGLE, + **kwargs: Any, + ) -> None: + super().__init__(device=device) + self.structured_output = structured_output + self.response_format = VIEScoreJsonOutput if structured_output else None + + self._init_vlm_scores( + vlm=vlm, + vlm_type=vlm_type, + model_name=model_name, + vlm_kwargs=vlm_kwargs, + structured_output=structured_output, + device=device, + api_key=api_key, + call_type=call_type, + ) + + def _score_single_image_t2i(self, image: Image.Image, prompt: str) -> float: + """VIEScore ``t2i``: single-image SC (semantic + detail) and PQ (naturalness + artifacts). + + Matches the VIEScore paper's t2i evaluation: two SC sub-scores on 0--10 and two PQ + sub-scores on 0--10, aggregated as ``sqrt(min(SC) * min(PQ)) / 10``. + """ + sc_prompt = _build_viescore_t2i_sc_prompt(prompt) + pq_prompt = _build_viescore_pq_prompt() + + rf = self.response_format if self.structured_output else None + + sc_raw = self.vlm.generate([image], [sc_prompt], response_format=rf)[0] + pq_raw = self.vlm.generate([image], [pq_prompt], response_format=rf)[0] + + sc_list = pad_viescore_subscores_to_two(viescore_min_scores_0_10(sc_raw)) + pq_list = pad_viescore_subscores_to_two(viescore_min_scores_0_10(pq_raw)) + return viescore_tie_overall_unit(sc_list, pq_list) + + def _score_tie_gedit(self, source: Image.Image, edited: Image.Image, instruction: str) -> float: + """VIEScore ``tie``: two-image SC, single-image PQ, overall geometric mean on 0--10 mins.""" + sc_prompt = _build_viescore_tie_sc_prompt(instruction) + pq_prompt = _build_viescore_pq_prompt() + + rf = self.response_format if self.structured_output else None + + if hasattr(self.vlm, "generate_with_image_lists"): + sc_raw = self.vlm.generate_with_image_lists( + [[source, edited]], + [sc_prompt], + response_format=rf, + )[0] + else: + raise RuntimeError("VLM backend must implement generate_with_image_lists for editing parity.") + + pq_raw = self.vlm.generate([edited], [pq_prompt], response_format=rf)[0] + + sc_list = pad_viescore_subscores_to_two(viescore_min_scores_0_10(sc_raw)) + pq_list = pad_viescore_subscores_to_two(viescore_min_scores_0_10(pq_raw)) + return viescore_tie_overall_unit(sc_list, pq_list) + + def update(self, x: list[Any] | torch.Tensor, gt: Any, outputs: torch.Tensor) -> None: + """ + Update the metric with new batch data. + + Parameters + ---------- + x : List[Any] | torch.Tensor + The input data (prompts). + gt : Any + Per-sample auxiliary dicts (``prompt_with_auxiliaries_collate``), or tensor placeholders + when aux is unused. + outputs : torch.Tensor + The output images. 
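+
+        Notes
+        -----
+        Mode is chosen per sample: if ``pil_rgb_from_aux_image_bytes`` recovers a source
+        image from the sample's auxiliaries, the sample is scored in ``tie`` mode
+        (source + edited); otherwise it falls back to single-image ``t2i`` scoring.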
+ """ + inputs = metric_data_processor(x, gt, outputs, self.call_type) + images = _process_images(inputs[0]) + prompts = prompts_from_y_x_inputs(inputs, len(images)) + aux_list = auxiliary_dicts_from_gt(gt, len(images)) + + for i, image in enumerate(images): + prompt = prompts[i] if i < len(prompts) else "" + aux = aux_list[i] + source = pil_rgb_from_aux_image_bytes(aux, min_bytes_in_value_scan=100) + + if source is not None: + self.scores.append(self._score_tie_gedit(source, image, prompt)) + else: + self.scores.append(self._score_single_image_t2i(image, prompt)) + + def compute(self) -> MetricResult: + """ + Compute the VIEScore metric. + + Returns + ------- + MetricResult + The mean VIEScore across all updates. + """ + return self.compute_mean_of_scores() diff --git a/src/pruna/evaluation/metrics/metric_vqa.py b/src/pruna/evaluation/metrics/metric_vqa.py new file mode 100644 index 00000000..75dfb325 --- /dev/null +++ b/src/pruna/evaluation/metrics/metric_vqa.py @@ -0,0 +1,158 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +VQA (Visual Question Answering) metric. + +Reference: VQAScore - Evaluating Text-to-Visual Generation with Image-to-Text Generation +https://arxiv.org/abs/2404.01291 + +Note: VQAScore uses P(Yes) (probability of "Yes" answer) for ranking. With litellm, +use_probability=True (default) requests logprobs for soft scores when the provider supports it. +Set use_probability=False for binary 0/1. With ``transformers``, ``use_probability=True`` +uses next-token softmax mass on yes/no prefix tokens (VQAScore-style); ``False`` uses +generation plus binary matching. + +For API keys, LiteLLM vs local ``transformers``, and hosted vs local construction, see +:doc:`Evaluate a model ` (Vision-language judge metrics) and +:func:`~pruna.evaluation.metrics.vlm_base.get_vlm`. +""" + +from __future__ import annotations + +from typing import Any, Literal + +import torch + +from pruna.evaluation.metrics.registry import MetricRegistry +from pruna.evaluation.metrics.result import MetricResult +from pruna.evaluation.metrics.utils import SINGLE, metric_data_processor +from pruna.evaluation.metrics.vlm_base import BaseVLM, StatefulVLMMeanScoresMetric, prompts_from_y_x_inputs +from pruna.evaluation.metrics.vlm_utils import VQAnswer, _process_images + + +@MetricRegistry.register("vqa") +class VQAMetric(StatefulVLMMeanScoresMetric): + """ + VQA (Visual Question Answering) metric. + + Uses VLM to answer "Does this image show '{prompt}'?" and scores alignment. + Higher scores indicate better image-text alignment. + + VQAScore (arXiv:2404.01291) uses P(Yes) for ranking. Default ``use_probability=True`` + with litellm requests logprobs for soft scores when supported. + + Parameters + ---------- + vlm : BaseVLM | None, optional + Custom VLM instance. If provided, ``vlm_type`` and ``model_name`` are ignored. + vlm_type : {"litellm", "transformers"}, optional + VLM backend to use. Default is "litellm". 
+ model_name : str | None, optional + Litellm model id or HuggingFace checkpoint id. **Required** when ``vlm`` is not + provided (e.g. ``openai/gpt-4o``). + vlm_kwargs : dict, optional + Forwarded by ``get_vlm`` to ``LitellmVLM`` or ``TransformersVLM``. For local models, + set ``model_load_kwargs`` for ``from_pretrained``; for litellm, pass extra API options. + structured_output : bool, optional + Use structured generation for stable outputs (litellm pydantic; transformers outlines + when a string format is used). Default is True. + device : str | torch.device | None, optional + Device for transformers VLM. + api_key : str | None, optional + API key for litellm. + call_type : str, optional + Call type for the metric. + use_probability : bool, optional + If True, use P(Yes) when backend supports logprobs (litellm). Otherwise binary 0/1. + Default is True for paper alignment. + **kwargs : Any + Additional arguments. + + Notes + ----- + For strict binary scoring without logprobs, pass ``use_probability=False``. Hosted vs + local setup: :doc:`Evaluate a model ` (Vision-language judge metrics). + """ + + scores: list[float] + default_call_type: str = "y_x" + higher_is_better: bool = True + metric_name: str = "vqa" + + def __init__( + self, + vlm: BaseVLM | None = None, + vlm_type: Literal["litellm", "transformers"] = "litellm", + model_name: str | None = None, + vlm_kwargs: dict | None = None, + structured_output: bool = True, + device: str | torch.device | None = None, + api_key: str | None = None, + call_type: str = SINGLE, + use_probability: bool = True, + **kwargs: Any, + ) -> None: + super().__init__(device=device) + self.use_probability = use_probability + self.response_format = VQAnswer if structured_output else None + self._init_vlm_scores( + vlm=vlm, + vlm_type=vlm_type, + model_name=model_name, + vlm_kwargs=vlm_kwargs, + structured_output=structured_output, + device=device, + api_key=api_key, + call_type=call_type, + ) + + def update(self, x: list[Any] | torch.Tensor, gt: torch.Tensor, outputs: torch.Tensor) -> None: + """ + Update the metric with new batch data. + + Parameters + ---------- + x : list[Any] | torch.Tensor + The input data (prompts). + gt : torch.Tensor + The ground truth (unused; present for call-type compatibility). + outputs : torch.Tensor + The output images. + """ + inputs = metric_data_processor(x, gt, outputs, self.call_type) + images = _process_images(inputs[0]) + prompts = prompts_from_y_x_inputs(inputs, len(images)) + for i, image in enumerate(images): + prompt = prompts[i] if i < len(prompts) else "" + question = f'Does this image show "{prompt}"?' + score = self.vlm.score( + [image], + [question], + ["Yes"], + response_format=self.response_format, + use_probability=self.use_probability, + )[0] + self.scores.append(score) + + def compute(self) -> MetricResult: + """ + Compute the VQA score. + + Returns + ------- + MetricResult + The mean VQA score across all updates. + """ + return self.compute_mean_of_scores() diff --git a/src/pruna/evaluation/metrics/registry.py b/src/pruna/evaluation/metrics/registry.py index 5efd721a..650f8a76 100644 --- a/src/pruna/evaluation/metrics/registry.py +++ b/src/pruna/evaluation/metrics/registry.py @@ -14,6 +14,7 @@ from __future__ import annotations +import importlib from functools import partial from inspect import isclass from typing import Any, Callable, Dict, Iterable, List @@ -29,9 +30,17 @@ class MetricRegistry: Registry for metrics. The registry is a dictionary that maps metric names to metric classes. 
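+
+    A typical lookup, shown as a minimal illustrative sketch (which names resolve
+    depends on the metrics registered in a given build; ``model_name`` is only an
+    example value)::
+
+        from pruna.evaluation.metrics.registry import MetricRegistry
+
+        if MetricRegistry.has_metric("vqa"):
+            metric = MetricRegistry.get_metric("vqa", model_name="openai/gpt-4o")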
+ + Notes + ----- + ``_lazy_metrics`` lists names that :meth:`has_metric` treats as registered before the + implementing module is loaded. The ``oneig_reasoning`` metric imports the LLM2CLIP-related + stack (vendored helpers, heavy optional dependencies); it is imported only when + :meth:`get_metric` is called with that name so other code paths avoid that cost. """ _registry: Dict[str, Callable[..., Any]] = {} + _lazy_metrics: frozenset[str] = frozenset({"oneig_reasoning"}) @classmethod def register(cls, name: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]: @@ -104,7 +113,7 @@ def has_metric(cls, name: str) -> bool: bool True if the metric is registered, False otherwise. """ - return name in cls._registry + return name in cls._registry or name in cls._lazy_metrics @classmethod def get_metric(cls, name: str, **kwargs) -> BaseMetric | StatefulMetric: @@ -122,6 +131,9 @@ def get_metric(cls, name: str, **kwargs) -> BaseMetric | StatefulMetric: ------- The metric instance. """ + if name in cls._lazy_metrics and name not in cls._registry: + importlib.import_module("pruna.evaluation.metrics.metric_oneig_reasoning") + if name not in cls._registry: raise ValueError(f"Metric '{name}' is not registered.") diff --git a/tests/evaluation/test_vision_metrics.py b/tests/evaluation/test_vision_metrics.py new file mode 100644 index 00000000..a4eaa139 --- /dev/null +++ b/tests/evaluation/test_vision_metrics.py @@ -0,0 +1,684 @@ +"""Tests for VLM metrics (VQA, ImageEditScore, QAAccuracy, TextScore, VieScore) and vlm_utils helpers.""" + +from unittest.mock import MagicMock, patch + +import pytest +import torch + +from pruna.evaluation.metrics.metric_img_edit_score import ImageEditScoreMetric +from pruna.evaluation.metrics.metric_oneig_alignment import OneIGAlignmentMetric +from pruna.evaluation.metrics.metric_qa_accuracy import QAAccuracyMetric +from pruna.evaluation.metrics.metric_text_score import OneIGTextScoreMetric, TextScoreMetric +from pruna.evaluation.metrics.metric_vie_score import VieScoreMetric +from pruna.evaluation.metrics.metric_vqa import VQAMetric +from pruna.evaluation.metrics.result import MetricResult +from pruna.evaluation.metrics.vlm_base import BaseVLM, get_vlm +from pruna.evaluation.metrics.vlm_utils import ( + FloatOutput, + VLM_AUX_IMAGE_BYTES_KEY_ORDER, + get_score_from_response, + yes_no_first_token_id_groups, +) + +from ._vlm_batch_snapshot_helpers import ( + BenchmarkVlmBatchOutcome, + pred_tensor_from_auxiliaries, + safe_json_for_snapshot, + vlm_benchmark_batch_to_json_record, +) + +SMOL_VLM = "HuggingFaceTB/SmolVLM-256M-Instruct" + +_ALL_VLM = ( + VQAMetric, + ImageEditScoreMetric, + QAAccuracyMetric, + OneIGAlignmentMetric, + TextScoreMetric, + OneIGTextScoreMetric, + VieScoreMetric, +) + +_SLOW_SMOL_SUBSET = ( + VQAMetric, + OneIGAlignmentMetric, + ImageEditScoreMetric, + VieScoreMetric, +) + + +@pytest.mark.parametrize( + ("raw", "expected"), + [ + (FloatOutput(score=8.0), 0.8), + ({"score": 5.0}, 0.5), + ('{"score": 7.5}', 0.75), + ('{"score": 10}', 1.0), + ("8", 0.8), + ("Score: 7.5 out of 10", 0.75), + ("", 0.0), + ], +) +def test_get_score_from_response(raw: object, expected: float) -> None: + """``get_score_from_response`` maps pydantic, dict, JSON, and text to ``[0, 1]``.""" + assert get_score_from_response(raw) == pytest.approx(expected) + + +def _dummy_image(batch: int = 1, size: int = 224) -> torch.Tensor: + return torch.rand(batch, 3, size, size) + + +def _update_metric(metric: object, prompts: list, images: torch.Tensor) -> None: + if isinstance(metric, 
OneIGAlignmentMetric): + metric.update( + prompts, + [ + { + "questions": {"1": "Is there a cat?", "2": "Is it sleeping?"}, + "dependencies": {"1": [0], "2": [1]}, + } + ], + images, + ) + elif isinstance(metric, QAAccuracyMetric): + metric.update( + prompts, + [{"questions": {"1": "Is there a cat?"}}], + images, + ) + elif isinstance(metric, (TextScoreMetric, OneIGTextScoreMetric)): + metric.update(prompts, ["cat"], images) + else: + metric.update(prompts, images, images) + + +@pytest.mark.cpu +@pytest.mark.slow +@pytest.mark.parametrize("metric_cls", _SLOW_SMOL_SUBSET) +def test_vlm_metrics_transformers_smolvlm(metric_cls: type) -> None: + """Smoke-test a subset with local SmolVLM (full matrix covered by litellm mock).""" + metric = metric_cls( + vlm_type="transformers", + model_name=SMOL_VLM, + device="cpu", + structured_output=True, + ) + images = _dummy_image(batch=1) + prompts = ["a cat"] + _update_metric(metric, prompts, images) + result = metric.compute() + assert result.name == metric.metric_name + assert isinstance(result.result, float) + if metric.higher_is_better: + assert 0.0 <= result.result <= 1.0 + else: + assert result.result >= 0.0 + + +@pytest.mark.cpu +@pytest.mark.parametrize("metric_cls", _ALL_VLM) +def test_vlm_metrics_litellm_mocked(metric_cls: type) -> None: + """Each VLM metric runs end-to-end with mocked litellm.""" + pytest.importorskip("litellm") + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + if metric_cls in (VQAMetric, QAAccuracyMetric, OneIGAlignmentMetric): + mock_response.choices[0].message.content = '{"answer": "Yes"}' + else: + mock_response.choices[0].message.content = '{"score": 8}' + + with patch("litellm.completion") as mock_completion: + mock_completion.return_value = mock_response + + metric = metric_cls( + vlm_type="litellm", + model_name="gpt-4o", + device="cpu", + structured_output=True, + ) + images = _dummy_image(batch=1) + prompts = ["a cat"] + _update_metric(metric, prompts, images) + result = metric.compute() + + assert result.name == metric.metric_name + assert isinstance(result.result, float) + assert mock_completion.called + + +@pytest.mark.cpu +def test_vlm_metrics_empty_compute_returns_zero() -> None: + """No updates → compute is 0.0 (same for all stateful VLM metrics).""" + metric = VQAMetric( + vlm_type="transformers", + model_name=SMOL_VLM, + device="cpu", + structured_output=True, + ) + assert metric.compute().result == 0.0 + + +@pytest.mark.cpu +def test_vlm_metrics_custom_vlm() -> None: + """Custom VLM passed to VQAMetric is used instead of the default litellm backend.""" + mock_vlm = MagicMock(spec=BaseVLM) + mock_vlm.generate.return_value = ["Yes"] + mock_vlm.score.return_value = [1.0] + + metric = VQAMetric(vlm=mock_vlm, vlm_type="litellm", device="cpu", structured_output=True) + images = _dummy_image(batch=1) + prompts = ["a cat"] + metric.update(prompts, images, images) + assert metric.compute().result == 1.0 + mock_vlm.score.assert_called() + + +@pytest.mark.cpu +def test_get_vlm_returns_custom() -> None: + """get_vlm returns the provided VLM instance unchanged.""" + custom = MagicMock(spec=BaseVLM) + out = get_vlm(vlm=custom, vlm_type="litellm", model_name="gpt-4o") + assert out is custom + + +@pytest.mark.cpu +def test_yes_no_first_token_id_groups_disjoint() -> None: + """Prefix token ids for Yes vs No should not overlap (avoids double-counting).""" + pytest.importorskip("transformers") + from transformers import AutoTokenizer + + tok = AutoTokenizer.from_pretrained("gpt2") + yes_ids, no_ids = 
yes_no_first_token_id_groups(tok) + assert yes_ids and no_ids + assert not (set(yes_ids) & set(no_ids)) + + +@pytest.mark.cpu +def test_get_vlm_requires_model_name_without_vlm() -> None: + """get_vlm raises ValueError when no model_name is given and no vlm is provided.""" + with pytest.raises(ValueError, match="model_name"): + get_vlm(vlm=None, vlm_type="litellm") + + +@pytest.mark.cpu +@pytest.mark.parametrize( + "metric_cls, expected_name, expected_result", + [ + (TextScoreMetric, "text_score", 1.0), + (OneIGTextScoreMetric, "oneig_text_score", 1.0), + ], +) +def test_text_metrics_list_str_gt(metric_cls: type, expected_name: str, expected_result: float) -> None: + """Text metrics accept plain string ground-truth and return the expected score.""" + mock_vlm = MagicMock(spec=BaseVLM) + mock_vlm.generate.return_value = ["hello world"] + + metric = metric_cls(vlm=mock_vlm, vlm_type="litellm", device="cpu") + images = _dummy_image(batch=1) + metric.update(["a prompt"], ["hello world"], images) + result = metric.compute() + + assert result.result == expected_result + assert result.name == expected_name + mock_vlm.generate.assert_called_once() + + +@pytest.mark.cpu +def test_text_score_result_in_zero_one_range() -> None: + """TextScoreMetric must return a normalized score in [0, 1], not raw edit distance.""" + mock_vlm = MagicMock(spec=BaseVLM) + # VLM OCR returns something very different from ground truth (high edit distance) + mock_vlm.generate.return_value = ["completely wrong text abcdefghijklmnop"] + + metric = TextScoreMetric(vlm=mock_vlm, device="cpu") + images = _dummy_image(batch=1) + metric.update(["prompt"], ["hello"], images) + result = metric.compute() + + assert 0.0 <= result.result <= 1.0, f"TextScoreMetric must return [0,1], got {result.result}" + assert result.result < 0.5, f"Very different strings should score below 0.5, got {result.result}" + + +@pytest.mark.cpu +def test_text_score_perfect_match_is_one() -> None: + """TextScoreMetric: identical OCR and GT -> score 1.0.""" + mock_vlm = MagicMock(spec=BaseVLM) + mock_vlm.generate.return_value = ["hello world"] + + metric = TextScoreMetric(vlm=mock_vlm, device="cpu") + images = _dummy_image(batch=1) + metric.update(["prompt"], ["hello world"], images) + result = metric.compute() + + assert result.result == 1.0, f"Perfect match should give 1.0, got {result.result}" + assert result.higher_is_better is True + + +@pytest.mark.cpu +def test_text_score_registry_aliases() -> None: + """Registry aliases ocr_levenshtein and ocr_text_score resolve to the correct metric classes.""" + from pruna.evaluation.metrics.registry import MetricRegistry + + lev = MetricRegistry.get_metric("ocr_levenshtein", device="cpu", model_name="openai/gpt-4o") + comp = MetricRegistry.get_metric("ocr_text_score", device="cpu", model_name="openai/gpt-4o") + assert type(lev).__name__ == "TextScoreMetric" + assert type(comp).__name__ == "OneIGTextScoreMetric" + assert lev.metric_name == "text_score" + assert comp.metric_name == "oneig_text_score" + + +@pytest.mark.cpu +def test_oneig_text_score_utils_golden_composite() -> None: + """oneig_mean_text_score returns expected component values for a known input.""" + from pruna.evaluation.metrics.metric_text_score_utils import oneig_mean_text_score + + ed, cr, wac, composite = oneig_mean_text_score( + edit_distances=[10.0], + completion_ratios=[0.0], + match_counts=[2], + gt_totals=[4], + language_mode="EN", + ) + assert ed == 10.0 + assert cr == 0.0 + assert wac == 0.5 + assert composite == pytest.approx(0.95) + + _, 
_, _, zh = oneig_mean_text_score( + edit_distances=[30.0], + completion_ratios=[0.0], + match_counts=[0], + gt_totals=[1], + language_mode="ZH", + ) + assert zh == pytest.approx(0.4) + + +@pytest.mark.cpu +def test_qa_accuracy_all_or_nothing_partial_fail() -> None: + """all_or_nothing: if any question scores 0, the image score is 0.0 (not a partial mean).""" + mock_vlm = MagicMock(spec=BaseVLM) + # First question Yes (1.0), second question No (0.0) → mean=0.5, all_or_nothing=0.0 + mock_vlm.score.return_value = [1.0, 0.0] + + metric = QAAccuracyMetric(vlm=mock_vlm, device="cpu", aggregation="all_or_nothing") + metric.update( + ["a prompt"], + [{"questions": {"1": "Is there a cat?", "2": "Is it blue?"}}], + _dummy_image(batch=1), + ) + result = metric.compute() + assert result.result == 0.0, f"Expected 0.0 for all_or_nothing with one No, got {result.result}" + + +@pytest.mark.cpu +def test_qa_accuracy_all_or_nothing_all_yes() -> None: + """all_or_nothing: all Yes → score 1.0.""" + mock_vlm = MagicMock(spec=BaseVLM) + mock_vlm.score.return_value = [1.0, 1.0] + + metric = QAAccuracyMetric(vlm=mock_vlm, device="cpu", aggregation="all_or_nothing") + metric.update( + ["a prompt"], + [{"questions": {"1": "Is there a cat?", "2": "Is it blue?"}}], + _dummy_image(batch=1), + ) + result = metric.compute() + assert result.result == 1.0, f"Expected 1.0 for all_or_nothing with all Yes, got {result.result}" + + +@pytest.mark.cpu +def test_qa_accuracy_invalid_aggregation_raises() -> None: + """qa_accuracy rejects aggregation values other than mean / all_or_nothing.""" + mock_vlm = MagicMock(spec=BaseVLM) + with pytest.raises(ValueError, match="aggregation"): + QAAccuracyMetric(vlm=mock_vlm, device="cpu", aggregation="median") + + +@pytest.mark.cpu +def test_vie_score_tie_uses_source_from_gt_and_two_image_sc() -> None: + """With ``source_image_bytes`` in gt, VieScore calls two-image SC then PQ on the edited image.""" + from io import BytesIO + + from PIL import Image + + buf = BytesIO() + Image.new("RGB", (8, 8), color=(0, 0, 200)).save(buf, format="PNG") + src_bytes = buf.getvalue() + + mock_vlm = MagicMock() + mock_vlm.generate_with_image_lists.return_value = ['{"score": [8.0, 8.0], "reasoning": "ok"}'] + mock_vlm.generate.return_value = ['{"score": [9.0, 9.0], "reasoning": "ok"}'] + + metric = VieScoreMetric(vlm=mock_vlm, device="cpu", structured_output=True) + pred = _dummy_image(batch=1) + metric.update( + ["make the sky purple"], + [{"source_image_bytes": src_bytes}], + pred, + ) + result = metric.compute() + + assert mock_vlm.generate_with_image_lists.called + assert mock_vlm.generate.called + assert 0.0 <= result.result <= 1.0 + + +@pytest.mark.cpu +def test_vie_score_uses_get_score_from_response() -> None: + """VieScoreMetric ``t2i`` path parses JSON ``score`` lists via ``viescore_min_scores_0_10``.""" + mock_vlm = MagicMock(spec=BaseVLM) + # LitellmVLM returns model_dump_json() for structured outputs → JSON string (two SC + two PQ sub-scores) + mock_vlm.generate.return_value = ['{"score": [8.0, 8.0], "reasoning": ""}'] + + metric = VieScoreMetric(vlm=mock_vlm, device="cpu", structured_output=True) + metric.update(["a cat on a sofa"], _dummy_image(batch=1), _dummy_image(batch=1)) + result = metric.compute() + + # min(SC)=8, min(PQ)=8 → sqrt(8 * 8) / 10 = 0.8 + assert abs(result.result - 0.8) < 0.01, f"Expected ~0.8, got {result.result}" + + +@pytest.mark.cpu +def test_img_edit_score_negative_response_clamped() -> None: + """img_edit_score must be non-negative even when the VLM generates a negative 
JSON score. + + Regression for: Outlines constrained decoding can emit {"score": -10} despite the + FloatOutput JSON schema specifying minimum=0, because Outlines does not enforce numeric + bounds during token sampling. The fix is max(0.0, ...) in get_score_from_response. + """ + mock_vlm = MagicMock(spec=BaseVLM) + # Simulate Outlines generating a negative value (the bug scenario) + mock_vlm.generate.return_value = ['{"score": -10.0}'] + + metric = ImageEditScoreMetric(vlm=mock_vlm, device="cpu", structured_output=True) + metric.update(["replace the boot with a mug"], torch.zeros(1), _dummy_image(batch=1)) + result = metric.compute() + + assert result.result >= 0.0, f"img_edit_score must be >= 0, got {result.result}" + + +@pytest.mark.cpu +def test_qa_accuracy_all_or_nothing_ambiguous_score() -> None: + """all_or_nothing: score exactly 0.5 (ambiguous) is treated as No → result 0.0.""" + mock_vlm = MagicMock(spec=BaseVLM) + mock_vlm.score.return_value = [0.5] + + metric = QAAccuracyMetric(vlm=mock_vlm, device="cpu", aggregation="all_or_nothing") + metric.update( + ["a prompt"], + [{"questions": {"1": "Is there a cat?"}}], + _dummy_image(batch=1), + ) + result = metric.compute() + assert result.result == 0.0, f"Score 0.5 should be treated as No (ambiguous), got {result.result}" + + +@pytest.mark.cpu +@pytest.mark.slow +def test_yes_no_token_ids_smolvlm_nonempty() -> None: + """SmolVLM tokenizer must yield non-empty disjoint yes/no prefix ids for VQAScore scoring.""" + pytest.importorskip("transformers") + from transformers import AutoTokenizer + + tok = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolVLM-256M-Instruct") + yes_ids, no_ids = yes_no_first_token_id_groups(tok) + assert len(yes_ids) > 0, "SmolVLM tokenizer has no 'Yes'-prefix token ids" + assert len(no_ids) > 0, "SmolVLM tokenizer has no 'No'-prefix token ids" + assert not (set(yes_ids) & set(no_ids)), "yes_ids and no_ids must be disjoint" + + +@pytest.mark.cpu +def test_img_edit_score_uses_prompt_from_x() -> None: + """img_edit_score must score the edited image against the instruction from x, not gt.""" + mock_vlm = MagicMock(spec=BaseVLM) + mock_vlm.generate.return_value = ['{"score": 9}'] + + metric = ImageEditScoreMetric(vlm=mock_vlm, device="cpu") + pred = _dummy_image(batch=1) + metric.update( + ["replace the cat with a dog"], # x = instruction + pred, # gt = unused for y_x + pred, # outputs = edited image + ) + result = metric.compute() + + call_args = mock_vlm.generate.call_args + prompt_sent = call_args[0][1][0] # second positional arg = prompts list, first item + assert "replace the cat with a dog" in prompt_sent, f"Instruction not in VLM prompt. Got: {prompt_sent}" + assert abs(result.result - 0.9) < 0.01, f"Expected ~0.9, got {result.result}" + + +@pytest.mark.cpu +def test_vie_score_geditbench_gap_documented() -> None: + """VieScoreMetric infers text--image editing from ``source_image_bytes`` in aux (no ``task_type``). + + This test fails if a ``task_type`` parameter is added to ``__init__`` without updating + GEditBench integration tests and benchmark copy accordingly. + """ + import inspect + + sig = inspect.signature(VieScoreMetric.__init__) + assert "task_type" not in sig.parameters, ( + "VieScoreMetric now has task_type — update GEditBench docs and e2e tests, then remove this sentinel." 
+ ) + + +@pytest.mark.cpu +def test_litellm_logprob_aggregation_sums_all_yes_tokens() -> None: + """LitellmVLM logprob scoring must sum all yes-prefix token probs, not return the first.""" + pytest.importorskip("litellm") + import math + from unittest.mock import MagicMock, patch + + import numpy as np + from PIL import Image + + from pruna.evaluation.metrics.vlm_base import LitellmVLM + + # Simulate top_logprobs for first output token: + # "Yes" → logprob=-2.303 (p≈0.10), " yes" → logprob=-2.996 (p≈0.05) → total p_yes≈0.15 + # "No" → logprob=-1.609 (p≈0.20), " no" → logprob=-2.303 (p≈0.10) → total p_no≈0.30 + # normalized: p_yes/(p_yes+p_no) ≈ 0.15/0.45 ≈ 0.333 + def make_top_logprob(token, logprob): + t = MagicMock() + t.token = token + t.logprob = logprob + return t + + first_tok = MagicMock() + first_tok.top_logprobs = [ + make_top_logprob("Yes", math.log(0.10)), + make_top_logprob(" yes", math.log(0.05)), + make_top_logprob("No", math.log(0.20)), + make_top_logprob(" no", math.log(0.10)), + make_top_logprob("maybe", math.log(0.55)), + ] + + mock_logprobs = MagicMock() + mock_logprobs.content = [first_tok] + + mock_choice = MagicMock() + mock_choice.logprobs = mock_logprobs + mock_choice.message.content = "Yes" + + mock_response = MagicMock() + mock_response.choices = [mock_choice] + + with patch("litellm.completion", return_value=mock_response): + vlm = LitellmVLM(model_name="openai/gpt-4o") + img = Image.fromarray(np.zeros((32, 32, 3), dtype="uint8")) + score = vlm._score_with_logprobs(img, "Is there a cat?", "Yes") + + # Should be ~0.333 (p_yes=0.15 / (p_yes+p_no)=0.45), not just 0.10 (first match) + assert 0.28 < score < 0.40, f"Expected ~0.333 (sum-normalized), got {score}" + + +@pytest.mark.cpu +@pytest.mark.slow +def test_vqa_probability_score_normalized() -> None: + """P(Yes) from TransformersVLM.score use_probability=True is in [0, 1].""" + pytest.importorskip("transformers") + import numpy as np + from PIL import Image + + from pruna.evaluation.metrics.vlm_base import TransformersVLM + + vlm = TransformersVLM( + model_name="HuggingFaceTB/SmolVLM-256M-Instruct", + device="cpu", + use_outlines=False, + ) + img = Image.fromarray(np.zeros((32, 32, 3), dtype="uint8")) + scores = vlm.score([img], ["Is there a cat?"], ["Yes"], use_probability=True) + assert len(scores) == 1 + assert 0.0 <= scores[0] <= 1.0, f"P(Yes) must be in [0, 1], got {scores[0]}" + + +# --------------------------------------------------------------------------- +# vlm_benchmark_batch_to_json_record serialization tests +# --------------------------------------------------------------------------- + + +def test_vlm_benchmark_batch_to_json_record_serializes_batch() -> None: + """Record includes prompts, pred shape, and metric fields.""" + mr = MetricResult(name="qa_accuracy", params={}, result=0.25, higher_is_better=True) + outcome = BenchmarkVlmBatchOutcome( + result=mr, + prompts=["prompt"], + auxiliaries=[{"path": "/tmp/x.png"}], + pred=torch.zeros(1, 3, 8, 8), + ) + rec = vlm_benchmark_batch_to_json_record( + outcome, + benchmark_key="GenEval", + benchmark_name="GenEval", + metric_name="qa_accuracy", + vlm_type="transformers", + model_name="m", + device="cpu", + ) + assert rec["inputs"]["prompts"] == ["prompt"] + assert rec["pred"]["shape"] == [1, 3, 8, 8] + assert rec["metric_result"]["result"] == 0.25 + + +def test_safe_json_handles_bytes_without_expanding() -> None: + """Bytes values in aux (e.g. 
source_image_bytes) are summarized, not expanded to str repr.""" + result = safe_json_for_snapshot({"source_image_bytes": b"\xff\xd8\xff" * 1000, "name": "test"}) + assert result["source_image_bytes"] == {"bytes_len": 3000} + assert result["name"] == "test" + + +def test_vlm_benchmark_batch_to_json_record_preserves_null_question_slots() -> None: + """Padded ``None`` question labels stay JSON null, not the string ``"None"``.""" + mr = MetricResult(name="oneig_alignment", params={}, result=1.0, higher_is_better=True) + outcome = BenchmarkVlmBatchOutcome( + result=mr, + prompts=["p"], + auxiliaries=[{"questions": {"1": "Are there boys?", "21": None}, "subset": "Anime_Stylization"}], + pred=torch.zeros(1, 3, 8, 8), + ) + rec = vlm_benchmark_batch_to_json_record( + outcome, + benchmark_key="OneIGAnimeStylization", + benchmark_name="OneIG Anime Stylization", + metric_name="oneig_alignment", + vlm_type="transformers", + model_name="m", + device="cpu", + ) + qs = rec["inputs"]["auxiliary_0"]["questions"] + assert qs["1"] == "Are there boys?" + assert qs["21"] is None + + +# --------------------------------------------------------------------------- +# pred_tensor_from_auxiliaries (test helper, wraps pil_rgb_from_aux_image_bytes) tests +# --------------------------------------------------------------------------- + + +def _make_jpeg_bytes(h: int = 32, w: int = 32) -> bytes: + """Return a tiny JPEG-encoded RGB image as bytes (test helper).""" + import io + + import numpy as np + from PIL import Image + + arr = (np.random.rand(h, w, 3) * 255).astype("uint8") + buf = io.BytesIO() + Image.fromarray(arr).save(buf, format="JPEG") + return buf.getvalue() + + +@pytest.mark.cpu +def test_pred_from_auxiliaries_uses_source_image_bytes() -> None: + """pred_tensor_from_auxiliaries decodes source_image_bytes into a float tensor in [0, 1].""" + src_bytes = _make_jpeg_bytes() + aux = [{"source_image_bytes": src_bytes, "category": "background_change"}] + pred = pred_tensor_from_auxiliaries(aux, size=64) + + assert pred.shape == (1, 3, 64, 64), f"Expected (1,3,64,64), got {pred.shape}" + assert pred.min() >= 0.0 and pred.max() <= 1.0, "Pixel values must be in [0, 1]" + + +@pytest.mark.cpu +def test_pred_from_auxiliaries_falls_back_to_noise_without_source_image() -> None: + """pred_tensor_from_auxiliaries returns random noise when no source_image_bytes is present.""" + aux = [{"category": "single_object"}] + pred = pred_tensor_from_auxiliaries(aux, size=32) + assert pred.shape == (1, 3, 32, 32) + assert pred.min() >= 0.0 and pred.max() <= 1.0 + + +@pytest.mark.cpu +def test_pred_from_auxiliaries_mixed_batch() -> None: + """Batch with one source image and one missing falls back per-item.""" + src_bytes = _make_jpeg_bytes() + aux = [ + {"source_image_bytes": src_bytes, "category": "color_alter"}, + {"category": "style_change"}, # no source image + ] + pred = pred_tensor_from_auxiliaries(aux, size=32) + assert pred.shape == (2, 3, 32, 32) + assert pred.min() >= 0.0 and pred.max() <= 1.0 + + +@pytest.mark.cpu +def test_pred_from_auxiliaries_generic_bytes_scan() -> None: + """pred_tensor_from_auxiliaries discovers image bytes under an unknown field name (generic scan).""" + src_bytes = _make_jpeg_bytes() + aux = [{"my_custom_image_bytes": src_bytes, "category": "motion_change"}] + pred = pred_tensor_from_auxiliaries(aux, size=32) + assert pred.shape == (1, 3, 32, 32) + assert pred.min() >= 0.0 and pred.max() <= 1.0 + + +@pytest.mark.cpu +def test_pred_from_auxiliaries_known_names_take_priority() -> None: + """Known field 
names are resolved before the generic bytes scan.""" + src_bytes_known = _make_jpeg_bytes(16, 16) + src_bytes_unknown = _make_jpeg_bytes(32, 32) + first_known = VLM_AUX_IMAGE_BYTES_KEY_ORDER[0] + aux = [{"other_bytes": src_bytes_unknown, first_known: src_bytes_known}] + pred = pred_tensor_from_auxiliaries(aux, size=16) + # Should use the known key (16x16 image → 16x16 crop); generic scan would pick 32x32 + assert pred.shape == (1, 3, 16, 16) + + +@pytest.mark.cpu +def test_pred_from_auxiliaries_require_source_image_raises_when_missing() -> None: + """require_source_image=True raises ValueError instead of silently returning noise.""" + aux = [{"category": "replace"}] # no image bytes + with pytest.raises(ValueError, match="require_source_image=True"): + pred_tensor_from_auxiliaries(aux, size=32, require_source_image=True) + + +@pytest.mark.cpu +def test_pred_from_auxiliaries_require_source_image_succeeds_when_present() -> None: + """require_source_image=True succeeds and decodes bytes when source_image_bytes is present.""" + src_bytes = _make_jpeg_bytes() + aux = [{"source_image_bytes": src_bytes, "category": "replace"}] + pred = pred_tensor_from_auxiliaries(aux, size=32, require_source_image=True) + assert pred.shape == (1, 3, 32, 32) + assert pred.min() >= 0.0 and pred.max() <= 1.0
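+
+
+# Illustrative addition (not part of the referenced papers): exercises the documented
+# ``use_probability=False`` binary path of VQAMetric with a fully mocked VLM, so only
+# the metric-side plumbing is checked.
+@pytest.mark.cpu
+def test_vqa_binary_mode_forwards_flag_to_vlm() -> None:
+    """use_probability=False is forwarded to BaseVLM.score and the mocked answer is averaged."""
+    mock_vlm = MagicMock(spec=BaseVLM)
+    mock_vlm.score.return_value = [0.0]  # mocked binary "No"
+
+    metric = VQAMetric(vlm=mock_vlm, device="cpu", use_probability=False)
+    images = _dummy_image(batch=1)
+    metric.update(["a cat"], images, images)
+
+    assert metric.compute().result == 0.0
+    mock_vlm.score.assert_called_once()
+    assert mock_vlm.score.call_args.kwargs.get("use_probability") is False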