-
Notifications
You must be signed in to change notification settings - Fork 290
Expand file tree
/
Copy pathdrgrpo_learner_test.py
More file actions
145 lines (126 loc) · 4.92 KB
/
drgrpo_learner_test.py
File metadata and controls
145 lines (126 loc) · 4.92 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import mock
from absl.testing import absltest
from absl.testing import parameterized
from flax import nnx
import jax
import jax.numpy as jnp
import numpy as np
from tunix.rl import function_registry as fr
from tunix.rl.grpo import drgrpo_learner as drgrpo_lib
from tunix.rl.grpo import grpo_learner as grpo_lib
from tunix.tests import test_common as tc
# Disable JAX's `jax_threefry_partitionable` flag for every test in this
# module. NOTE(review): presumably done so PRNG-dependent values (e.g. the
# `nnx.Rngs(0)` ToyTransformer init in test_drgrpo_loss_fn) match the
# non-partitionable threefry behaviour — confirm.
jax.config.update("jax_threefry_partitionable", False)
class DrGRPOlearnerTest(parameterized.TestCase):
  """Tests for the Dr. GRPO config, advantage estimator and policy loss."""

  def setUp(self):
    super().setUp()
    self.pad_id = 0
    self.eos_id = 1
    # Common data shapes.
    self.batch_size = 2
    self.seq_len = 4
    self.prompt_ids = jnp.zeros(
        (self.batch_size, self.seq_len), dtype=jnp.int32
    )
    self.completion_ids = jnp.ones(
        (self.batch_size, self.seq_len), dtype=jnp.int32
    )
    # The mask marks 3 completion tokens in row 0 and 2 in row 1.
    self.completion_mask = jnp.array(
        [[1, 1, 1, 0], [1, 1, 0, 0]], dtype=jnp.float32
    )
    # One advantage value per sequence in the batch (batch_size == 2).
    self.advantages = jnp.array([0.5, -0.2], dtype=jnp.float32)
    # Constant per-token log-probs, same shape as completion_ids.
    self.ref_per_token_logps = (
        jnp.ones_like(self.completion_ids, dtype=jnp.float32) * -0.2
    )
    self.old_per_token_logps = (
        jnp.ones_like(self.completion_ids, dtype=jnp.float32) * -0.15
    )

  def create_train_example(self):
    """Returns a MagicMock exposing the setUp fixtures as a train example.

    `segment_ids` and `segment_positions` are set to None. NOTE(review):
    presumably this means "no sequence packing" to the loss function —
    confirm against the loss implementation.
    """
    example = mock.MagicMock()
    example.prompt_ids = self.prompt_ids
    example.completion_ids = self.completion_ids
    example.completion_mask = self.completion_mask
    example.advantages = self.advantages
    example.ref_per_token_logps = self.ref_per_token_logps
    example.old_per_token_logps = self.old_per_token_logps
    example.segment_ids = None
    example.segment_positions = None
    return example

  def test_create_config(self):
    """Checks the Dr. GRPO-specific defaults and pass-through overrides."""
    drgrpo_config = drgrpo_lib.DrGRPOConfig(
        epsilon=0.1, num_generations=5, num_iterations=3, beta=0.123
    )
    self.assertEqual(drgrpo_config.algo_variant, "drgrpo")
    self.assertEqual(drgrpo_config.advantage_estimator, "drgrpo")
    self.assertEqual(drgrpo_config.loss_agg_mode, "sequence-mean-token-scale")
    self.assertEqual(drgrpo_config.num_generations, 5)
    self.assertEqual(drgrpo_config.num_iterations, 3)
    self.assertEqual(drgrpo_config.epsilon, 0.1)
    self.assertEqual(drgrpo_config.beta, 0.123)

  def test_drgrpo_advantage_estimator(self):
    """Dr. GRPO advantages equal GRPO advantages times the group std (+1e-6)."""
    drgrpo_config = drgrpo_lib.DrGRPOConfig()
    grpo_config = grpo_lib.GRPOConfig()
    grpo_advantage_estimator = fr.get_advantage_estimator(
        grpo_config.advantage_estimator
    )
    drgrpo_advantage_estimator = fr.get_advantage_estimator(
        drgrpo_config.advantage_estimator
    )
    # Batch size 3 with group size 2.
    num_generations = 2
    rewards = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
    grpo_advantages = grpo_advantage_estimator(
        rewards=rewards.ravel(), num_generations=num_generations
    )
    drgrpo_advantages = drgrpo_advantage_estimator(
        rewards=rewards.ravel(), num_generations=num_generations
    )
    # Dr. GRPO advantages are not scaled by the standard deviation.
    # Every group above differs by exactly 1.0, so all groups share the same
    # std, and one factor undoes the GRPO scaling for the whole batch.
    std_factor = jnp.array([1.0, 2.0]).std(ddof=1) + 1e-6
    # float32 divide-then-multiply round-trips are not bit-exact; use an
    # explicit tolerance instead of the default rtol=1e-7.
    np.testing.assert_allclose(
        grpo_advantages * std_factor, drgrpo_advantages, rtol=1e-6
    )

  def test_drgrpo_loss_fn(self):
    """The registered Dr. GRPO policy loss runs and returns a finite loss."""
    drgrpo_config = drgrpo_lib.DrGRPOConfig()
    drgrpo_loss_fn_impl = fr.default_registry.get(
        "policy_loss_fn", drgrpo_config.policy_loss_fn
    )
    # Build a single train example from the setUp fixtures.
    train_example = self.create_train_example()
    vocab = tc.MockVocab()
    model = tc.ToyTransformer(
        config=tc.ModelConfig(vocab_size=vocab.GetPieceSize()),
        rngs=nnx.Rngs(0),
    )
    # Call the Dr. GRPO loss function.
    drgrpo_loss, drgrpo_aux = drgrpo_loss_fn_impl(
        model, train_example, drgrpo_config, self.pad_id, self.eos_id
    )
    self.assertIn("kl", drgrpo_aux)
    self.assertTrue(jnp.isfinite(drgrpo_loss).all())

  def test_compute_advantages(self):
    """Checks compute_advantages on two groups of three rewards."""
    rewards = jnp.array(
        [[0.57450044, 0.09968603, 0.7419659, 0.8941783, 0.59656656, 0.45325184]]
    )
    advantages = drgrpo_lib.compute_advantages(rewards, num_generations=3)
    # Each expected value is the reward minus the mean of its group of 3
    # (group means ~0.472051 and ~0.647999), with no std normalization.
    expected_array = jnp.array([
        [0.10245, -0.372365, 0.269915, 0.246179, -0.051432, -0.194747],
    ])
    np.testing.assert_allclose(advantages, expected_array, rtol=1e-5, atol=1e-5)
if __name__ == "__main__":
  # Run the test cases defined in this module via absl's test runner when the
  # file is executed directly.
  absltest.main()