Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
455 changes: 455 additions & 0 deletions examples/deepscaler/benchmark_sequence_packing.py

Large diffs are not rendered by default.

166 changes: 152 additions & 14 deletions tests/rl/common_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,13 +117,150 @@ def test_get_per_token_logps(self):
atol=1e-03,
)

def test_compute_per_token_logps(self):
def test_process_ids_raises_value_error(self):
  """Packed inputs without explicit positions must be rejected."""
  prompts = jnp.array([[1, 2], [3, 4]])
  completions = jnp.array([[5, 6], [7, 8]])
  # Two packed segments per row, but no segment_positions supplied.
  packing_ids = jnp.array([[1, 1, 2, 2], [1, 1, 2, 2]])
  expected_message = (
      "segment_positions must be explicitly provided for packed sequences."
  )
  with self.assertRaisesRegex(ValueError, expected_message):
    common.process_ids(
        prompts,
        completions,
        pad_id=0,
        eos_id=-1,
        segment_ids=packing_ids,
        segment_positions=None,
    )

@parameterized.named_parameters(
dict(
testcase_name="normal",
prompt_tokens=np.array([[1, 2, 3, 4], [0, 0, 1, 2], [0, 1, 2, 3]]),
completion_tokens=np.array(
[[10, 11, -1, 12], [10, 11, 12, 13], [10, 11, 12, -1]]
),
segment_ids=None,
segment_positions=None,
expected_logps=np.array([
[-5.876301, -8.700251, -5.046069, -5.788748],
[-6.071025, -7.5328417, -5.9712567, -4.653783],
[-6.039485, -8.264197, -6.2771187, -4.767109],
]),
),
dict(
testcase_name="seq-packed-single-item",
prompt_tokens=np.zeros((3, 0), dtype=np.int32),
completion_tokens=np.array([
[1, 2, 3, 4, 10, 11, -1, 12],
[0, 0, 1, 2, 10, 11, 12, 13],
[0, 1, 2, 3, 10, 11, 12, -1],
]),
segment_ids=np.ones((3, 8), dtype=np.int32),
segment_positions=np.tile(np.arange(8), (3, 1)),
expected_logps=np.array([
[
0.0,
-7.3199797,
-6.8320303,
-5.6091313,
-5.876301,
-8.700251,
-5.0460696,
-5.788748,
],
[
0.0,
-6.4536085,
-5.5156517,
-7.103587,
-6.0710244,
-7.5328417,
-5.971257,
-4.653783,
],
[
0.0,
-5.789238,
-7.7057056,
-6.7916627,
-6.0394855,
-8.264197,
-6.2771187,
-4.7671094,
],
]),
),
dict(
testcase_name="seq-packed-multi-item",
prompt_tokens=np.zeros((2, 0), dtype=np.int32),
completion_tokens=np.array([
[1, 2, 3, 4, 10, 11, -1, 12, 0, 0, 1, 2, 10, 11, 12, 13],
[0, 1, 2, 3, 10, 11, 12, -1, 0, 0, 0, 0, 0, 0, 0, 0],
]),
segment_ids=np.array([
[1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2],
[1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
]),
segment_positions=np.array([
[0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7],
[0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7],
]),
# NOTE: Expected logprobs diverge from single-item values because
# floating-point reductions in XLA compound differently when changing
# batch size from 3 to 2 and sequence length from 8 to 16.
expected_logps=np.array([
[
0.0,
-7.255163,
-6.413455,
-5.682157,
-5.83097,
-8.132578,
-4.8891325,
-5.7902822,
-6.452383,
-6.524351,
-5.778284,
-7.255163,
-6.245493,
-8.132578,
-6.025977,
-4.6675467,
],
[
0.0,
-4.070095,
-7.792082,
-6.3780885,
-6.312748,
-6.536421,
-6.0986547,
-5.62961,
-5.558264,
-6.595858,
-6.595858,
-6.595858,
-6.595858,
-6.595858,
-6.595858,
-6.595858,
],
]),
),
)
def test_compute_per_token_logps(
self,
prompt_tokens,
completion_tokens,
segment_ids,
segment_positions,
expected_logps,
):
model = tc.ToyTransformer(config=tc.ModelConfig(), rngs=nnx.Rngs(0))
prompt_tokens = jnp.array([[1, 2, 3, 4], [0, 0, 1, 2], [0, 1, 2, 3]])
completion_tokens = jnp.array(
[[10, 11, -1, 12], [10, 11, 12, 13], [10, 11, 12, -1]]
)
graphdef, state = nnx.split(model)

per_token_logps = common.compute_per_token_logps(
graphdef,
state,
Expand All @@ -132,17 +269,14 @@ def test_compute_per_token_logps(self):
pad_id=0,
eos_id=-1,
return_logits=False,
segment_ids=segment_ids,
segment_positions=segment_positions,
)

np.testing.assert_allclose(
per_token_logps,
np.array([
[-5.876301, -8.700251, -5.046069, -5.788748],
[-6.071025, -7.5328417, -5.9712567, -4.653783],
[-6.039485, -8.264197, -6.2771187, -4.767109],
]),
atol=1e-1,
rtol=1e-2,
per_token_logps, expected_logps, atol=1e-5, rtol=1e-5
)

_, logits = common.compute_per_token_logps(
graphdef,
state,
Expand All @@ -151,8 +285,12 @@ def test_compute_per_token_logps(self):
pad_id=0,
eos_id=-1,
return_logits=True,
segment_ids=segment_ids,
segment_positions=segment_positions,
)
np.testing.assert_equal(
logits.shape, (expected_logps.shape[0], expected_logps.shape[1], 256)
)
np.testing.assert_equal(logits.shape, (3, 4, 256))

def test_np_make_completion_mask(self):
completion_ids = np.array(
Expand Down
2 changes: 2 additions & 0 deletions tests/rl/grpo/dapo_learner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ def create_train_example(self):
example.advantages = self.advantages
example.ref_per_token_logps = self.ref_per_token_logps
example.old_per_token_logps = self.old_per_token_logps
example.segment_ids = None
example.segment_positions = None
return example

def test_diff_loss(self):
Expand Down
2 changes: 2 additions & 0 deletions tests/rl/grpo/drgrpo_learner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ def create_train_example(self):
example.advantages = self.advantages
example.ref_per_token_logps = self.ref_per_token_logps
example.old_per_token_logps = self.old_per_token_logps
example.segment_ids = None
example.segment_positions = None
return example

def test_create_config(self):
Expand Down
74 changes: 74 additions & 0 deletions tests/rl/grpo/grpo_learner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def setup(kwargs: Optional[Dict[str, Any]] = None):
gradient_accumulation_steps=kwargs.get(
'gradient_accumulation_steps', None
),
max_seq_token_per_tpu=kwargs.get('max_seq_token_per_tpu', None),
),
rollout_config=base_rollout.RolloutConfig(
max_tokens_to_generate=10,
Expand Down Expand Up @@ -632,6 +633,79 @@ def test_grpo_with_lora_model(self):
)
jax.tree.map_with_path(tc.assert_equal, original_base_params, base_params)

@parameterized.named_parameters(
    dict(
        testcase_name='single_sequence',
        max_token_len=266,  # exactly 256 (max_prompt_length) + 10 (max_tokens_to_generate)
    ),
    dict(
        testcase_name='single_sequence_with_padding',
        max_token_len=300,  # fits 1 sequence, pads to 300
    ),
    dict(
        testcase_name='multiple_sequences',
        max_token_len=532,  # exactly (256+10) * 2
    ),
    dict(
        testcase_name='large_budget',
        max_token_len=1000,  # fits multiple sequences, pads to 1000
    ),
)
def test_sequence_packing(self, max_token_len):
  """Packed and unpacked training should yield nearly identical params."""
  # --- Baseline run: sequence packing disabled (the config default). ---
  cluster_unpacked, model_unpacked, initial_params = setup(
      {'eval_every_n_steps': 2}
  )
  learner_unpacked = grpo_lib.GRPOLearner(
      rl_cluster=cluster_unpacked,
      reward_fns=reward_1,
      algo_config=grpo_lib.GRPOConfig(num_generations=2, num_iterations=1),
  )
  learner_unpacked.train(
      _dummy_dataset(MySource(repeat=4), batch_size=2), None
  )
  params_unpacked = nnx.state(model_unpacked, nnx.Param)

  # --- Packed run: same configuration plus a per-TPU token budget. ---
  cluster_packed, model_packed, _ = setup({
      'eval_every_n_steps': 2,
      'max_seq_token_per_tpu': max_token_len,
  })
  learner_packed = grpo_lib.GRPOLearner(
      rl_cluster=cluster_packed,
      reward_fns=reward_1,
      algo_config=grpo_lib.GRPOConfig(num_generations=2, num_iterations=1),
  )
  learner_packed.train(
      _dummy_dataset(MySource(repeat=4), batch_size=2), None
  )
  params_packed = nnx.state(model_packed, nnx.Param)

  # Training must have moved the packed model away from its initial state.
  jax.tree.map_with_path(tc.assert_not_equal, initial_params, params_packed)

  # The two runs should agree within a loose tolerance.
  # TODO(noghabi): Reduce the tolerance. Currently, the toy model does not use
  # the segment IDs in the attention mask, which causes numerical
  # inaccuracies.
  jax.tree.map_with_path(
      lambda path, a, b: tc.assert_close(path, a, b, atol=5e-2, rtol=1e-1),
      params_unpacked,
      params_packed,
  )

  # Both learners must have consumed the same number of examples.
  self.assertEqual(learner_unpacked._iter_steps, learner_packed._iter_steps)

def test_exception_from_data_preparation(self):
class _TrainerWithException(grpo_lib.GRPOLearner):
@override
Expand Down
Loading
Loading