Skip to content

Commit 901966f

Browse files
committed
Add unit tests for evaluate_rl.py
1 parent 060915d commit 901966f

1 file changed

Lines changed: 92 additions & 0 deletions

File tree

tests/post_training/unit/evaluate_rl_test.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,11 @@
1414

1515
"""Unit tests for evaluate_rl.py (CPU-only)."""
1616

17+
import json
1718
import unittest
1819
import pytest
1920
from types import SimpleNamespace
21+
from unittest import mock
2022

2123
from maxtext.trainers.post_train.rl import evaluate_rl
2224

@@ -208,5 +210,95 @@ def test_pass_at_1_all_correct(self):
208210
self.assertAlmostEqual(has_correct_format, 1.0)
209211

210212

213+
class TestEvaluate(unittest.TestCase):
  """Tests for the main evaluate() function.

  `generate_responses` and `score_responses` are patched out so these tests
  run CPU-only and only exercise the aggregation logic in `evaluate()`.
  """

  def setUp(self):
    self.mock_cluster = mock.Mock()
    # Shared default config in 'pass' mode. Several tests below reference
    # `self.config` directly, so it must be initialized here — otherwise
    # they raise AttributeError before exercising evaluate() at all.
    self.config = _make_config(eval_mode="pass")
    # One batch with two items; answers are JSON-encoded lists, matching
    # the dataset format evaluate() expects.
    self.dataset = [
        {
            "question": ["q1", "q2"],
            "answer": [json.dumps(["a1"]), json.dumps(["a2"])],
            "prompts": ["p1", "p2"],
        }
    ]

  # NOTE: mock.patch decorators apply bottom-up, so the score_responses
  # patch (closest to the function) is the first mock argument.
  @pytest.mark.cpu_only
  @mock.patch("maxtext.trainers.post_train.rl.evaluate_rl.generate_responses")
  @mock.patch("maxtext.trainers.post_train.rl.evaluate_rl.score_responses")
  def test_evaluate_pass_mode(self, mock_score, mock_generate):
    """Test basic evaluation flow in 'pass' mode."""
    config = _make_config(eval_mode="pass")
    # Two items in batch: one correct, one incorrect.
    mock_generate.return_value = [["r1"], ["r2"]]
    mock_score.side_effect = [
        (True, True, True),  # First item is correct
        (False, False, False),  # Second item is incorrect
    ]

    stats, response_list = evaluate_rl.evaluate(config, self.dataset, self.mock_cluster)

    # 1 of 2 correct -> 50% on all three percentage stats.
    self.assertEqual(stats, (1, 2, 50.0, 50.0, 50.0))
    self.assertEqual(response_list, [])
    self.assertEqual(mock_generate.call_count, 1)
    self.assertEqual(mock_score.call_count, 2)

  @pytest.mark.cpu_only
  @mock.patch("maxtext.trainers.post_train.rl.evaluate_rl.generate_responses")
  @mock.patch("maxtext.trainers.post_train.rl.evaluate_rl.score_responses")
  def test_evaluate_pass_at_1_mode(self, mock_score, mock_generate):
    """Test evaluation flow in 'pass_at_1' mode, which returns floats."""
    config = _make_config(eval_mode="pass_at_1")
    mock_generate.return_value = [["r1a", "r1b"], ["r2a", "r2b"]]
    # First item: 1/2 correct (0.5)
    # Second item: 2/2 correct (1.0)
    mock_score.side_effect = [
        (0.5, 0.5, 1.0),
        (1.0, 1.0, 1.0),
    ]

    stats, _ = evaluate_rl.evaluate(config, self.dataset, self.mock_cluster)

    # Total correct = 0.5 + 1.0 = 1.5
    # Total items = 2
    # Accuracy = (1.5 / 2) * 100 = 75.0
    self.assertEqual(stats, (1.5, 2, 75.0, 75.0, 100.0))

  @pytest.mark.cpu_only
  @mock.patch("maxtext.trainers.post_train.rl.evaluate_rl.generate_responses")
  @mock.patch("maxtext.trainers.post_train.rl.evaluate_rl.score_responses")
  def test_make_lst_correct(self, mock_score, mock_generate):
    """Test that make_lst=True, corr_lst=True returns only correct items."""
    mock_generate.return_value = [["r1"], ["r2"]]
    mock_score.side_effect = [(True, True, True), (False, False, False)]

    _, response_list = evaluate_rl.evaluate(self.config, self.dataset, self.mock_cluster, make_lst=True, corr_lst=True)

    # Only the first (correct) item is collected.
    self.assertEqual(len(response_list), 1)
    self.assertEqual(response_list[0], ("q1", ["a1"], ["r1"]))

  @pytest.mark.cpu_only
  @mock.patch("maxtext.trainers.post_train.rl.evaluate_rl.generate_responses")
  @mock.patch("maxtext.trainers.post_train.rl.evaluate_rl.score_responses")
  def test_make_lst_incorrect(self, mock_score, mock_generate):
    """Test that make_lst=True, corr_lst=False returns only incorrect items."""
    mock_generate.return_value = [["r1"], ["r2"]]
    mock_score.side_effect = [(True, True, True), (False, False, False)]

    _, response_list = evaluate_rl.evaluate(self.config, self.dataset, self.mock_cluster, make_lst=True, corr_lst=False)

    # Only the second (incorrect) item is collected.
    self.assertEqual(len(response_list), 1)
    self.assertEqual(response_list[0], ("q2", ["a2"], ["r2"]))

  @pytest.mark.cpu_only
  @mock.patch("maxtext.trainers.post_train.rl.evaluate_rl.generate_responses")
  def test_empty_dataset(self, mock_generate):
    """Test that evaluation on an empty dataset returns zero stats."""
    stats, response_list = evaluate_rl.evaluate(self.config, [], self.mock_cluster)
    self.assertEqual(stats, (0, 0, 0, 0, 0))
    self.assertEqual(response_list, [])
    # No batches means the generation path must never be entered.
    mock_generate.assert_not_called()
if __name__ == "__main__":
  # Allow running this test module directly as a script.
  unittest.main()

0 commit comments

Comments
 (0)