93 changes: 93 additions & 0 deletions tests/post_training/unit/evaluate_rl_test.py
@@ -14,9 +14,11 @@

"""Unit tests for evaluate_rl.py (CPU-only)."""

import json
import unittest
import pytest
from types import SimpleNamespace
from unittest import mock

from maxtext.trainers.post_train.rl import evaluate_rl

@@ -208,5 +210,96 @@ def test_pass_at_1_all_correct(self):
self.assertAlmostEqual(has_correct_format, 1.0)


class TestEvaluate(unittest.TestCase):
"""Tests for the main evaluate() function."""

def setUp(self):
self.config = _make_config()
self.mock_cluster = mock.Mock()
self.dataset = [
{
"question": ["q1", "q2"],
"answer": [json.dumps(["a1"]), json.dumps(["a2"])],
"prompts": ["p1", "p2"],
}
]

@pytest.mark.cpu_only
@mock.patch("maxtext.trainers.post_train.rl.evaluate_rl.generate_responses")
@mock.patch("maxtext.trainers.post_train.rl.evaluate_rl.score_responses")
def test_evaluate_pass_mode(self, mock_score, mock_generate):
"""Test basic evaluation flow in 'pass' mode."""
config = _make_config(eval_mode="pass")
# Two items in batch: one correct, one incorrect.
mock_generate.return_value = [["r1"], ["r2"]]
mock_score.side_effect = [
(True, True, True), # First item is correct
(False, False, False), # Second item is incorrect
]

stats, response_list = evaluate_rl.evaluate(config, self.dataset, self.mock_cluster)

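# Inferred stats layout: (total correct, total items, then three rate percentages).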
self.assertEqual(stats, (1, 2, 50.0, 50.0, 50.0))
self.assertEqual(response_list, [])
self.assertEqual(mock_generate.call_count, 1)
self.assertEqual(mock_score.call_count, 2)

@pytest.mark.cpu_only
@mock.patch("maxtext.trainers.post_train.rl.evaluate_rl.generate_responses")
@mock.patch("maxtext.trainers.post_train.rl.evaluate_rl.score_responses")
def test_evaluate_pass_at_1_mode(self, mock_score, mock_generate):
"""Test evaluation flow in 'pass_at_1' mode, which returns floats."""
config = _make_config(eval_mode="pass_at_1")
mock_generate.return_value = [["r1a", "r1b"], ["r2a", "r2b"]]
# First item: 1/2 correct (0.5)
# Second item: 2/2 correct (1.0)
mock_score.side_effect = [
(0.5, 0.5, 1.0),
(1.0, 1.0, 1.0),
]

stats, _ = evaluate_rl.evaluate(config, self.dataset, self.mock_cluster)

# Total correct = 0.5 + 1.0 = 1.5
# Total items = 2
# Accuracy = (1.5 / 2) * 100 = 75.0
self.assertEqual(stats, (1.5, 2, 75.0, 75.0, 100.0))

@pytest.mark.cpu_only
@mock.patch("maxtext.trainers.post_train.rl.evaluate_rl.generate_responses")
@mock.patch("maxtext.trainers.post_train.rl.evaluate_rl.score_responses")
def test_make_lst_correct(self, mock_score, mock_generate):
"""Test that make_lst=True, corr_lst=True returns only correct items."""
mock_generate.return_value = [["r1"], ["r2"]]
mock_score.side_effect = [(True, True, True), (False, False, False)]

_, response_list = evaluate_rl.evaluate(self.config, self.dataset, self.mock_cluster, make_lst=True, corr_lst=True)

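# Each retained entry pairs the question with its decoded answer list and responses.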
self.assertEqual(len(response_list), 1)
self.assertEqual(response_list[0], ("q1", ["a1"], ["r1"]))

@pytest.mark.cpu_only
@mock.patch("maxtext.trainers.post_train.rl.evaluate_rl.generate_responses")
@mock.patch("maxtext.trainers.post_train.rl.evaluate_rl.score_responses")
def test_make_lst_incorrect(self, mock_score, mock_generate):
"""Test that make_lst=True, corr_lst=False returns only incorrect items."""
mock_generate.return_value = [["r1"], ["r2"]]
mock_score.side_effect = [(True, True, True), (False, False, False)]

_, response_list = evaluate_rl.evaluate(self.config, self.dataset, self.mock_cluster, make_lst=True, corr_lst=False)

self.assertEqual(len(response_list), 1)
self.assertEqual(response_list[0], ("q2", ["a2"], ["r2"]))

@pytest.mark.cpu_only
@mock.patch("maxtext.trainers.post_train.rl.evaluate_rl.generate_responses")
def test_empty_dataset(self, mock_generate):
"""Test that evaluation on an empty dataset returns zero stats."""
stats, response_list = evaluate_rl.evaluate(self.config, [], self.mock_cluster)
self.assertEqual(stats, (0, 0, 0, 0, 0))
self.assertEqual(response_list, [])
mock_generate.assert_not_called()


if __name__ == "__main__":
unittest.main()
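
For reference, the assertions above pin down a fairly complete contract for evaluate_rl.evaluate(). Below is a minimal sketch reconstructed from those expectations only, not the module's actual implementation; generate_responses and score_responses do exist on the module (the tests patch them), but the arguments passed to them here, and internal names such as totals and pcts, are assumptions.

import json

from maxtext.trainers.post_train.rl.evaluate_rl import generate_responses, score_responses


def evaluate(config, dataset, cluster, make_lst=False, corr_lst=True):
  """Sketch of the contract implied by the tests; not the real implementation."""
  totals = [0.0, 0.0, 0.0]  # Running sums of the three score components.
  num_items = 0
  response_list = []
  for batch in dataset:
    # Called once per batch; yields one list of responses per item
    # (the single-batch tests assert mock_generate.call_count == 1).
    responses = generate_responses(config, batch, cluster)
    for i, item_responses in enumerate(responses):
      # Called once per item; returns a 3-tuple of bools in "pass" mode
      # or floats in "pass_at_1" mode.
      scores = score_responses(config, batch["answer"][i], item_responses)
      for j, score in enumerate(scores):
        totals[j] += float(score)
      num_items += 1
      # corr_lst selects correct (True) or incorrect (False) items; the
      # JSON-encoded answer is decoded before being returned.
      if make_lst and bool(scores[0]) == corr_lst:
        response_list.append((batch["question"][i], json.loads(batch["answer"][i]), item_responses))
  if num_items == 0:
    return (0, 0, 0, 0, 0), []
  pcts = [100.0 * total / num_items for total in totals]
  return (totals[0], num_items, pcts[0], pcts[1], pcts[2]), response_list

This sketch reproduces every expectation in the tests: (1, 2, 50.0, 50.0, 50.0) for one correct item out of two in "pass" mode, (1.5, 2, 75.0, 75.0, 100.0) for the fractional "pass_at_1" scores, the filtered (question, answers, responses) tuples under make_lst, and all-zero stats with no generation call on an empty dataset.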