Skip to content

Commit af42cf8

Browse files
authored
[algorithm][generator] change overlong filtering to use stop reasons over checking eos token (NovaSky-AI#1319)
Previously, overlong filtering was doing this:
```python
[
    [0] * len(mask) if not response or response[-1] != eos_token_id else mask
    for mask, response in zip(loss_masks, response_ids)
]
```
This was flaky, since a model can end its response with a token other than `tokenizer.eos_token_id`. This is the case for `moonlight_16b_a3b`, which ends with `<|im_end|>` even though its `tokenizer.eos_token_id` is separately set to `[EOS]`. It is more reliable to check the stop reason reported by the inference engine instead (`stop_reason != "stop"`). This overlaps slightly with `zero_reward_on_non_stop`, but the behavior differs: it zeroes out the loss mask rather than the reward (which is the environment's responsibility).
<!-- devin-review-badge-begin --> --- <a href="https://app.devin.ai/review/novasky-ai/skyrl/pull/1319" target="_blank"> <picture> <source media="(prefers-color-scheme: dark)" srcset="https://static.devin.ai/assets/gh-open-in-devin-review-dark.svg?v=1"> <img src="https://static.devin.ai/assets/gh-open-in-devin-review-light.svg?v=1" alt="Open with Devin"> </picture> </a> <!-- devin-review-badge-end -->
1 parent cf33e86 commit af42cf8

3 files changed

Lines changed: 52 additions & 64 deletions

File tree

skyrl/train/generators/skyrl_gym_generator.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -650,7 +650,8 @@ async def generate_batched(
650650
rollout_metrics = get_rollout_metrics(responses, rewards, env_metrics, env_classes)
651651

652652
if self.generator_cfg.apply_overlong_filtering:
653-
loss_masks = apply_overlong_filtering(loss_masks, responses, self.tokenizer.eos_token_id)
653+
# set loss mask to 0 if the stop reason is not "stop"
654+
loss_masks = apply_overlong_filtering(loss_masks, stop_reasons)
654655

655656
generator_output: GeneratorOutput = {
656657
"prompt_token_ids": prompt_token_ids,
@@ -767,7 +768,8 @@ async def generate(self, input_batch: GeneratorInput, disable_tqdm: bool = False
767768
rewards = self._zero_reward_if_not_stop(rewards, stop_reasons)
768769

769770
if self.generator_cfg.apply_overlong_filtering:
770-
loss_masks = apply_overlong_filtering(loss_masks, responses, self.tokenizer.eos_token_id)
771+
# set loss mask to 0 if the stop reason is not "stop"
772+
loss_masks = apply_overlong_filtering(loss_masks, stop_reasons)
771773

772774
generator_output: GeneratorOutput = {
773775
"prompt_token_ids": prompt_token_ids,

skyrl/train/generators/utils.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -276,20 +276,26 @@ def concatenate_generator_outputs(generator_outputs: List[GeneratorOutput]) -> G
276276

277277
def apply_overlong_filtering(
278278
loss_masks: List[List[int]],
279-
response_ids: List[List[int]],
280-
eos_token_id: int,
279+
stop_reasons: List[str],
281280
) -> List[List[int]]:
282281
"""
283282
Implements DAPO Overlong Filtering: zero-out every token's mask whenever
284-
the response does not end with the eos token id (i.e. truncated).
283+
the response was truncated (i.e. did not end with a stop token).
284+
285+
Uses stop_reasons from the inference engine rather than checking for a
286+
specific eos token id, making this model/tokenizer agnostic.
287+
288+
Args:
289+
loss_masks: Per-trajectory token loss masks.
290+
stop_reasons: Per-trajectory stop reasons from the inference engine
291+
(e.g. "stop" for normal completion, "length" for truncation).
285292
286293
Returns:
287-
- The loss masks with tokens zeroed out for truncated responses
294+
The loss masks with tokens zeroed out for truncated responses.
288295
"""
289-
assert len(loss_masks) == len(response_ids), "loss_masks and response_ids must have the same length"
296+
assert len(loss_masks) == len(stop_reasons), "loss_masks and stop_reasons must have the same length"
290297
return [
291-
[0] * len(mask) if not response or response[-1] != eos_token_id else mask
292-
for mask, response in zip(loss_masks, response_ids)
298+
[0] * len(mask) if stop_reason != "stop" else mask[:] for mask, stop_reason in zip(loss_masks, stop_reasons)
293299
]
294300

295301

tests/train/generators/test_utils.py

Lines changed: 35 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -36,115 +36,95 @@ def qwen3_acc_thinking_template():
3636

3737

3838
@pytest.mark.parametrize(
39-
"loss_masks,response_ids,eos_token_id,expected_masks",
39+
"loss_masks,stop_reasons,expected_masks",
4040
[
41-
# Test case 1: All responses end with eos token - masks should remain unchanged
41+
# Test case 1: All responses completed normally - masks should remain unchanged
4242
(
4343
[[1, 1, 0, 1], [0, 1, 1, 1], [1, 0, 1]],
44-
[[1, 2, 3, 4], [5, 6, 7, 4], [8, 9, 4]], # All end with eos_token_id=4
45-
4,
44+
["stop", "stop", "stop"],
4645
[[1, 1, 0, 1], [0, 1, 1, 1], [1, 0, 1]],
4746
),
48-
# Test case 2: No responses end with eos token - all masks should be zeroed
47+
# Test case 2: All responses truncated - all masks should be zeroed
4948
(
5049
[[1, 1, 0, 1], [0, 1, 1, 1], [1, 0, 1]],
51-
[[1, 2, 3, 5], [5, 6, 7, 8], [8, 9, 10]], # None end with eos_token_id=4
52-
4,
50+
["length", "length", "length"],
5351
[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0]],
5452
),
55-
# Test case 3: Mixed responses - only non-eos ending masks should be zeroed
53+
# Test case 3: Mixed - only truncated masks should be zeroed
5654
(
5755
[[1, 1, 0, 1], [0, 1, 1, 1], [1, 0, 1, 0, 1]],
58-
[[1, 2, 3, 4], [5, 6, 7, 8], [8, 9, 10, 11, 4]], # First and third end with eos_token_id=4
59-
4,
56+
["stop", "length", "stop"],
6057
[[1, 1, 0, 1], [0, 0, 0, 0], [1, 0, 1, 0, 1]],
6158
),
62-
# Test case 4: Empty responses should be zeroed
59+
# Test case 4: Various non-"stop" reasons should all be zeroed
6360
(
6461
[[1, 1], [1, 0, 1], [0, 1, 1, 1]],
65-
[[], [1, 2, 3], [4, 5, 6, 7]], # Empty, no eos, no eos (eos_token_id=4)
66-
4,
62+
["length", "abort", "cancelled"],
6763
[[0, 0], [0, 0, 0], [0, 0, 0, 0]],
6864
),
6965
# Test case 5: Empty lists
70-
([], [], 4, []),
71-
# Test case 6: Different eos token id
72-
(
73-
[[1, 1], [1, 0, 1], [0, 1, 1, 1]],
74-
[[1, 2], [3, 4, 99], [5, 6, 7, 99]], # Second and third end with eos_token_id=99
75-
99,
76-
[[0, 0], [1, 0, 1], [0, 1, 1, 1]],
77-
),
66+
([], [], []),
7867
],
7968
)
80-
def test_apply_overlong_filtering(loss_masks, response_ids, eos_token_id, expected_masks):
69+
def test_apply_overlong_filtering(loss_masks, stop_reasons, expected_masks):
8170
"""
8271
Test the apply_overlong_filtering function which implements DAPO Overlong Filtering.
8372
84-
This function should zero-out every token's mask whenever the response does not end
85-
with the eos token id (i.e. truncated), while leaving other masks unchanged.
73+
This function should zero-out every token's mask whenever the stop reason is not "stop"
74+
(i.e. the response was truncated), while leaving other masks unchanged.
8675
"""
87-
result = apply_overlong_filtering(loss_masks, response_ids, eos_token_id)
76+
result = apply_overlong_filtering(loss_masks, stop_reasons)
8877

8978
assert result == expected_masks, f"Expected {expected_masks}, but got {result}"
9079

91-
# Verify that the original inputs are not modified (immutability check)
9280
assert len(result) == len(loss_masks), "Result should have same length as input"
9381

94-
# Check that each individual mask is processed correctly
95-
for i, (original_mask, response, expected_mask) in enumerate(zip(loss_masks, response_ids, expected_masks)):
96-
if len(response) == 0 or response[-1] != eos_token_id:
97-
# Should be all zeros with same length as original
82+
for i, (original_mask, stop_reason, expected_mask) in enumerate(zip(loss_masks, stop_reasons, expected_masks)):
83+
if stop_reason != "stop":
9884
assert result[i] == [0] * len(original_mask), f"Mask {i} should be all zeros for truncated response"
9985
else:
100-
# Should be unchanged
101-
assert result[i] == original_mask, f"Mask {i} should be unchanged for response ending with eos token"
86+
assert result[i] == original_mask, f"Mask {i} should be unchanged for completed response"
10287

10388

10489
def test_apply_overlong_filtering_immutability():
10590
"""
10691
Test that apply_overlong_filtering doesn't modify the original input lists.
10792
"""
10893
original_loss_masks = [[1, 1, 0, 1], [0, 1, 1]]
109-
original_response_ids = [[1, 2, 3, 4], [5, 6, 7]] # First ends with eos=4, second doesn't
110-
eos_token_id = 4
94+
original_stop_reasons = ["stop", "length"]
11195

112-
# Create copies to compare against later
113-
loss_masks_copy = [mask[:] for mask in original_loss_masks] # Deep copy of lists
114-
response_ids_copy = [response[:] for response in original_response_ids] # Deep copy of lists
96+
loss_masks_copy = [mask[:] for mask in original_loss_masks]
97+
stop_reasons_copy = original_stop_reasons[:]
11598

116-
result = apply_overlong_filtering(original_loss_masks, original_response_ids, eos_token_id)
99+
result = apply_overlong_filtering(original_loss_masks, original_stop_reasons)
117100

118-
# Verify original inputs are unchanged
119101
assert original_loss_masks == loss_masks_copy, "Original loss_masks should not be modified"
120-
assert original_response_ids == response_ids_copy, "Original response_ids should not be modified"
102+
assert original_stop_reasons == stop_reasons_copy, "Original stop_reasons should not be modified"
121103

122-
# Verify result is correct
123-
expected = [[1, 1, 0, 1], [0, 0, 0]] # Second mask zeroed due to not ending with eos
104+
expected = [[1, 1, 0, 1], [0, 0, 0]] # Second mask zeroed due to truncation
124105
assert result == expected, f"Expected {expected}, got {result}"
125106

126107

127108
@pytest.mark.parametrize(
128-
"loss_masks,response_ids",
109+
"loss_masks,stop_reasons",
129110
[
130-
# Test case 1: More loss_masks than response_ids
131-
([[1, 1], [0, 1]], [[1, 2]]),
132-
# Test case 2: More response_ids than loss_masks
133-
([[1, 1]], [[1, 2], [3, 4]]),
134-
# Test case 3: Empty loss_masks but non-empty response_ids
135-
([], [[1, 2]]),
136-
# Test case 4: Non-empty loss_masks but empty response_ids
111+
# Test case 1: More loss_masks than stop_reasons
112+
([[1, 1], [0, 1]], ["stop"]),
113+
# Test case 2: More stop_reasons than loss_masks
114+
([[1, 1]], ["stop", "length"]),
115+
# Test case 3: Empty loss_masks but non-empty stop_reasons
116+
([], ["stop"]),
117+
# Test case 4: Non-empty loss_masks but empty stop_reasons
137118
([[1, 0]], []),
138119
],
139120
)
140-
def test_apply_overlong_filtering_length_mismatch_assertion(loss_masks, response_ids):
121+
def test_apply_overlong_filtering_length_mismatch_assertion(loss_masks, stop_reasons):
141122
"""
142-
Test that apply_overlong_filtering raises AssertionError when loss_masks and response_ids
123+
Test that apply_overlong_filtering raises AssertionError when loss_masks and stop_reasons
143124
have different lengths.
144125
"""
145-
eos_token_id = 4
146-
with pytest.raises(AssertionError, match="loss_masks and response_ids must have the same length"):
147-
apply_overlong_filtering(loss_masks, response_ids, eos_token_id)
126+
with pytest.raises(AssertionError, match="loss_masks and stop_reasons must have the same length"):
127+
apply_overlong_filtering(loss_masks, stop_reasons)
148128

149129

150130
dummy_chat_template = (

0 commit comments

Comments
 (0)