|
14 | 14 |
|
15 | 15 | """Unit tests for evaluate_rl.py (CPU-only).""" |
16 | 16 |
|
| 17 | +import json |
17 | 18 | import unittest |
18 | 19 | import pytest |
19 | 20 | from types import SimpleNamespace |
| 21 | +from unittest import mock |
20 | 22 |
|
21 | 23 | from maxtext.trainers.post_train.rl import evaluate_rl |
22 | 24 |
|
@@ -208,5 +210,95 @@ def test_pass_at_1_all_correct(self): |
208 | 210 | self.assertAlmostEqual(has_correct_format, 1.0) |
209 | 211 |
|
210 | 212 |
|
class TestEvaluate(unittest.TestCase):
  """Tests for the main evaluate() function."""

  def setUp(self):
    # BUGFIX: several tests below (test_make_lst_correct,
    # test_make_lst_incorrect, test_empty_dataset) reference self.config,
    # which was never initialized — they crashed with AttributeError before
    # reaching any assertion. Initialize a default "pass"-mode config here.
    self.config = _make_config(eval_mode="pass")
    self.mock_cluster = mock.Mock()
    # One batch of two items; answers are JSON-encoded lists, matching the
    # dataset format evaluate() expects.
    self.dataset = [
        {
            "question": ["q1", "q2"],
            "answer": [json.dumps(["a1"]), json.dumps(["a2"])],
            "prompts": ["p1", "p2"],
        }
    ]

  @pytest.mark.cpu_only
  @mock.patch("maxtext.trainers.post_train.rl.evaluate_rl.generate_responses")
  @mock.patch("maxtext.trainers.post_train.rl.evaluate_rl.score_responses")
  def test_evaluate_pass_mode(self, mock_score, mock_generate):
    """Test basic evaluation flow in 'pass' mode."""
    config = _make_config(eval_mode="pass")
    # Two items in batch: one correct, one incorrect.
    mock_generate.return_value = [["r1"], ["r2"]]
    mock_score.side_effect = [
        (True, True, True),  # First item is correct
        (False, False, False),  # Second item is incorrect
    ]

    stats, response_list = evaluate_rl.evaluate(config, self.dataset, self.mock_cluster)

    self.assertEqual(stats, (1, 2, 50.0, 50.0, 50.0))
    self.assertEqual(response_list, [])
    self.assertEqual(mock_generate.call_count, 1)
    self.assertEqual(mock_score.call_count, 2)

  @pytest.mark.cpu_only
  @mock.patch("maxtext.trainers.post_train.rl.evaluate_rl.generate_responses")
  @mock.patch("maxtext.trainers.post_train.rl.evaluate_rl.score_responses")
  def test_evaluate_pass_at_1_mode(self, mock_score, mock_generate):
    """Test evaluation flow in 'pass_at_1' mode, which returns floats."""
    config = _make_config(eval_mode="pass_at_1")
    mock_generate.return_value = [["r1a", "r1b"], ["r2a", "r2b"]]
    # First item: 1/2 correct (0.5)
    # Second item: 2/2 correct (1.0)
    mock_score.side_effect = [
        (0.5, 0.5, 1.0),
        (1.0, 1.0, 1.0),
    ]

    stats, _ = evaluate_rl.evaluate(config, self.dataset, self.mock_cluster)

    # Total correct = 0.5 + 1.0 = 1.5
    # Total items = 2
    # Accuracy = (1.5 / 2) * 100 = 75.0
    self.assertEqual(stats, (1.5, 2, 75.0, 75.0, 100.0))

  @pytest.mark.cpu_only
  @mock.patch("maxtext.trainers.post_train.rl.evaluate_rl.generate_responses")
  @mock.patch("maxtext.trainers.post_train.rl.evaluate_rl.score_responses")
  def test_make_lst_correct(self, mock_score, mock_generate):
    """Test that make_lst=True, corr_lst=True returns only correct items."""
    mock_generate.return_value = [["r1"], ["r2"]]
    mock_score.side_effect = [(True, True, True), (False, False, False)]

    _, response_list = evaluate_rl.evaluate(self.config, self.dataset, self.mock_cluster, make_lst=True, corr_lst=True)

    self.assertEqual(len(response_list), 1)
    self.assertEqual(response_list[0], ("q1", ["a1"], ["r1"]))

  @pytest.mark.cpu_only
  @mock.patch("maxtext.trainers.post_train.rl.evaluate_rl.generate_responses")
  @mock.patch("maxtext.trainers.post_train.rl.evaluate_rl.score_responses")
  def test_make_lst_incorrect(self, mock_score, mock_generate):
    """Test that make_lst=True, corr_lst=False returns only incorrect items."""
    mock_generate.return_value = [["r1"], ["r2"]]
    mock_score.side_effect = [(True, True, True), (False, False, False)]

    _, response_list = evaluate_rl.evaluate(self.config, self.dataset, self.mock_cluster, make_lst=True, corr_lst=False)

    self.assertEqual(len(response_list), 1)
    self.assertEqual(response_list[0], ("q2", ["a2"], ["r2"]))

  @pytest.mark.cpu_only
  @mock.patch("maxtext.trainers.post_train.rl.evaluate_rl.generate_responses")
  def test_empty_dataset(self, mock_generate):
    """Test that evaluation on an empty dataset returns zero stats."""
    stats, response_list = evaluate_rl.evaluate(self.config, [], self.mock_cluster)
    self.assertEqual(stats, (0, 0, 0, 0, 0))
    self.assertEqual(response_list, [])
    mock_generate.assert_not_called()
| 302 | + |
# Allow running this test module directly (e.g. `python <file>`) in addition
# to discovery via pytest.
if __name__ == "__main__":
  unittest.main()
0 commit comments