From 847c4a8a048e527df97d319645e7d577c5b2938b Mon Sep 17 00:00:00 2001
From: Surbhi Jain
Date: Wed, 29 Apr 2026 18:33:46 +0000
Subject: [PATCH] Add unit tests for evaluate_rl.py

---
 tests/post_training/unit/evaluate_rl_test.py | 93 ++++++++++++++++++++
 1 file changed, 93 insertions(+)

diff --git a/tests/post_training/unit/evaluate_rl_test.py b/tests/post_training/unit/evaluate_rl_test.py
index 205e4a8797..a28b0cda35 100644
--- a/tests/post_training/unit/evaluate_rl_test.py
+++ b/tests/post_training/unit/evaluate_rl_test.py
@@ -14,9 +14,11 @@
 """Unit tests for evaluate_rl.py (CPU-only)."""
 
+import json
 import unittest
 import pytest
 from types import SimpleNamespace
+from unittest import mock
 
 from maxtext.trainers.post_train.rl import evaluate_rl
 
@@ -208,5 +210,96 @@ def test_pass_at_1_all_correct(self):
     self.assertAlmostEqual(has_correct_format, 1.0)
 
 
+class TestEvaluate(unittest.TestCase):
+  """Tests for the main evaluate() function."""
+
+  def setUp(self):
+    self.config = _make_config()
+    self.mock_cluster = mock.Mock()
+    self.dataset = [
+        {
+            "question": ["q1", "q2"],
+            "answer": [json.dumps(["a1"]), json.dumps(["a2"])],
+            "prompts": ["p1", "p2"],
+        }
+    ]
+
+  @pytest.mark.cpu_only
+  @mock.patch("maxtext.trainers.post_train.rl.evaluate_rl.generate_responses")
+  @mock.patch("maxtext.trainers.post_train.rl.evaluate_rl.score_responses")
+  def test_evaluate_pass_mode(self, mock_score, mock_generate):
+    """Test basic evaluation flow in 'pass' mode."""
+    config = _make_config(eval_mode="pass")
+    # Two items in batch: one correct, one incorrect.
+    mock_generate.return_value = [["r1"], ["r2"]]
+    mock_score.side_effect = [
+        (True, True, True),  # First item is correct
+        (False, False, False),  # Second item is incorrect
+    ]
+
+    stats, response_list = evaluate_rl.evaluate(config, self.dataset, self.mock_cluster)
+
+    self.assertEqual(stats, (1, 2, 50.0, 50.0, 50.0))
+    self.assertEqual(response_list, [])
+    self.assertEqual(mock_generate.call_count, 1)
+    self.assertEqual(mock_score.call_count, 2)
+
+  @pytest.mark.cpu_only
+  @mock.patch("maxtext.trainers.post_train.rl.evaluate_rl.generate_responses")
+  @mock.patch("maxtext.trainers.post_train.rl.evaluate_rl.score_responses")
+  def test_evaluate_pass_at_1_mode(self, mock_score, mock_generate):
+    """Test evaluation flow in 'pass_at_1' mode, which returns floats."""
+    config = _make_config(eval_mode="pass_at_1")
+    mock_generate.return_value = [["r1a", "r1b"], ["r2a", "r2b"]]
+    # First item: 1/2 correct (0.5)
+    # Second item: 2/2 correct (1.0)
+    mock_score.side_effect = [
+        (0.5, 0.5, 1.0),
+        (1.0, 1.0, 1.0),
+    ]
+
+    stats, _ = evaluate_rl.evaluate(config, self.dataset, self.mock_cluster)
+
+    # Total correct = 0.5 + 1.0 = 1.5
+    # Total items = 2
+    # Accuracy = (1.5 / 2) * 100 = 75.0
+    self.assertEqual(stats, (1.5, 2, 75.0, 75.0, 100.0))
+
+  @pytest.mark.cpu_only
+  @mock.patch("maxtext.trainers.post_train.rl.evaluate_rl.generate_responses")
+  @mock.patch("maxtext.trainers.post_train.rl.evaluate_rl.score_responses")
+  def test_make_lst_correct(self, mock_score, mock_generate):
+    """Test that make_lst=True, corr_lst=True returns only correct items."""
+    mock_generate.return_value = [["r1"], ["r2"]]
+    mock_score.side_effect = [(True, True, True), (False, False, False)]
+
+    _, response_list = evaluate_rl.evaluate(self.config, self.dataset, self.mock_cluster, make_lst=True, corr_lst=True)
+
+    self.assertEqual(len(response_list), 1)
+    self.assertEqual(response_list[0], ("q1", ["a1"], ["r1"]))
+
+  @pytest.mark.cpu_only
@mock.patch("maxtext.trainers.post_train.rl.evaluate_rl.generate_responses") + @mock.patch("maxtext.trainers.post_train.rl.evaluate_rl.score_responses") + def test_make_lst_incorrect(self, mock_score, mock_generate): + """Test that make_lst=True, corr_lst=False returns only incorrect items.""" + mock_generate.return_value = [["r1"], ["r2"]] + mock_score.side_effect = [(True, True, True), (False, False, False)] + + _, response_list = evaluate_rl.evaluate(self.config, self.dataset, self.mock_cluster, make_lst=True, corr_lst=False) + + self.assertEqual(len(response_list), 1) + self.assertEqual(response_list[0], ("q2", ["a2"], ["r2"])) + + @pytest.mark.cpu_only + @mock.patch("maxtext.trainers.post_train.rl.evaluate_rl.generate_responses") + def test_empty_dataset(self, mock_generate): + """Test that evaluation on an empty dataset returns zero stats.""" + stats, response_list = evaluate_rl.evaluate(self.config, [], self.mock_cluster) + self.assertEqual(stats, (0, 0, 0, 0, 0)) + self.assertEqual(response_list, []) + mock_generate.assert_not_called() + + if __name__ == "__main__": unittest.main()