diff --git a/src/lighteval/metrics/normalizations.py b/src/lighteval/metrics/normalizations.py index ef55681b1..15932dd87 100644 --- a/src/lighteval/metrics/normalizations.py +++ b/src/lighteval/metrics/normalizations.py @@ -20,6 +20,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. +import logging import re import string import sys @@ -31,6 +32,8 @@ from lighteval.utils.imports import Extra, requires from lighteval.utils.language import Language +logger = logging.getLogger(__name__) + # From HELM def helm_normalizer(text: str) -> str: @@ -523,8 +526,14 @@ def normalize_log_probs( normalized_log_probs = [choices_logprob[ix] / len(choice) for ix, choice in enumerate(choices_text)] case LogProbTokenNorm(): assert choices_tokens is not None, "choices_tokens must be provided for token normalization" + n = min(len(choices_logprob), len(choices_tokens)) + if n < len(choices_logprob): + logger.warning( + f"choices_tokens length ({len(choices_tokens)}) is less than choices_logprob length " + f"({len(choices_logprob)}). This may indicate corrupted cache data. Truncating to {n} elements." + ) normalized_log_probs = [ - choices_logprob[ix] / len(choices_tokens[ix]) for ix in range(len(choices_logprob)) + choices_logprob[ix] / len(choices_tokens[ix]) for ix in range(n) ] case LogProbPMINorm(): assert unconditioned_logprob is not None, "unconditioned_logprob must be provided for PMI normalization"