@@ -36,115 +36,95 @@ def qwen3_acc_thinking_template():
3636
3737
3838@pytest .mark .parametrize (
39- "loss_masks,response_ids,eos_token_id ,expected_masks" ,
39+ "loss_masks,stop_reasons ,expected_masks" ,
4040 [
41- # Test case 1: All responses end with eos token - masks should remain unchanged
41+ # Test case 1: All responses completed normally - masks should remain unchanged
4242 (
4343 [[1 , 1 , 0 , 1 ], [0 , 1 , 1 , 1 ], [1 , 0 , 1 ]],
44- [[1 , 2 , 3 , 4 ], [5 , 6 , 7 , 4 ], [8 , 9 , 4 ]], # All end with eos_token_id=4
45- 4 ,
44+ ["stop" , "stop" , "stop" ],
4645 [[1 , 1 , 0 , 1 ], [0 , 1 , 1 , 1 ], [1 , 0 , 1 ]],
4746 ),
48- # Test case 2: No responses end with eos token - all masks should be zeroed
47+ # Test case 2: All responses truncated - all masks should be zeroed
4948 (
5049 [[1 , 1 , 0 , 1 ], [0 , 1 , 1 , 1 ], [1 , 0 , 1 ]],
51- [[1 , 2 , 3 , 5 ], [5 , 6 , 7 , 8 ], [8 , 9 , 10 ]], # None end with eos_token_id=4
52- 4 ,
50+ ["length" , "length" , "length" ],
5351 [[0 , 0 , 0 , 0 ], [0 , 0 , 0 , 0 ], [0 , 0 , 0 ]],
5452 ),
55- # Test case 3: Mixed responses - only non-eos ending masks should be zeroed
53+ # Test case 3: Mixed - only truncated masks should be zeroed
5654 (
5755 [[1 , 1 , 0 , 1 ], [0 , 1 , 1 , 1 ], [1 , 0 , 1 , 0 , 1 ]],
58- [[1 , 2 , 3 , 4 ], [5 , 6 , 7 , 8 ], [8 , 9 , 10 , 11 , 4 ]], # First and third end with eos_token_id=4
59- 4 ,
56+ ["stop" , "length" , "stop" ],
6057 [[1 , 1 , 0 , 1 ], [0 , 0 , 0 , 0 ], [1 , 0 , 1 , 0 , 1 ]],
6158 ),
62- # Test case 4: Empty responses should be zeroed
59+ # Test case 4: Various non-"stop" reasons should all be zeroed
6360 (
6461 [[1 , 1 ], [1 , 0 , 1 ], [0 , 1 , 1 , 1 ]],
65- [[], [1 , 2 , 3 ], [4 , 5 , 6 , 7 ]], # Empty, no eos, no eos (eos_token_id=4)
66- 4 ,
62+ ["length" , "abort" , "cancelled" ],
6763 [[0 , 0 ], [0 , 0 , 0 ], [0 , 0 , 0 , 0 ]],
6864 ),
6965 # Test case 5: Empty lists
70- ([], [], 4 , []),
71- # Test case 6: Different eos token id
72- (
73- [[1 , 1 ], [1 , 0 , 1 ], [0 , 1 , 1 , 1 ]],
74- [[1 , 2 ], [3 , 4 , 99 ], [5 , 6 , 7 , 99 ]], # Second and third end with eos_token_id=99
75- 99 ,
76- [[0 , 0 ], [1 , 0 , 1 ], [0 , 1 , 1 , 1 ]],
77- ),
66+ ([], [], []),
7867 ],
7968)
80- def test_apply_overlong_filtering (loss_masks , response_ids , eos_token_id , expected_masks ):
69+ def test_apply_overlong_filtering (loss_masks , stop_reasons , expected_masks ):
8170 """
8271 Test the apply_overlong_filtering function which implements DAPO Overlong Filtering.
8372
84- This function should zero-out every token's mask whenever the response does not end
85- with the eos token id (i.e. truncated), while leaving other masks unchanged.
73+ This function should zero-out every token's mask whenever the stop reason is not "stop"
74+ (i.e. the response was truncated), while leaving other masks unchanged.
8675 """
87- result = apply_overlong_filtering (loss_masks , response_ids , eos_token_id )
76+ result = apply_overlong_filtering (loss_masks , stop_reasons )
8877
8978 assert result == expected_masks , f"Expected { expected_masks } , but got { result } "
9079
91- # Verify that the original inputs are not modified (immutability check)
9280 assert len (result ) == len (loss_masks ), "Result should have same length as input"
9381
94- # Check that each individual mask is processed correctly
95- for i , (original_mask , response , expected_mask ) in enumerate (zip (loss_masks , response_ids , expected_masks )):
96- if len (response ) == 0 or response [- 1 ] != eos_token_id :
97- # Should be all zeros with same length as original
82+ for i , (original_mask , stop_reason , expected_mask ) in enumerate (zip (loss_masks , stop_reasons , expected_masks )):
83+ if stop_reason != "stop" :
9884 assert result [i ] == [0 ] * len (original_mask ), f"Mask { i } should be all zeros for truncated response"
9985 else :
100- # Should be unchanged
101- assert result [i ] == original_mask , f"Mask { i } should be unchanged for response ending with eos token"
86+ assert result [i ] == original_mask , f"Mask { i } should be unchanged for completed response"
10287
10388
10489def test_apply_overlong_filtering_immutability ():
10590 """
10691 Test that apply_overlong_filtering doesn't modify the original input lists.
10792 """
10893 original_loss_masks = [[1 , 1 , 0 , 1 ], [0 , 1 , 1 ]]
109- original_response_ids = [[1 , 2 , 3 , 4 ], [5 , 6 , 7 ]] # First ends with eos=4, second doesn't
110- eos_token_id = 4
94+ original_stop_reasons = ["stop" , "length" ]
11195
112- # Create copies to compare against later
113- loss_masks_copy = [mask [:] for mask in original_loss_masks ] # Deep copy of lists
114- response_ids_copy = [response [:] for response in original_response_ids ] # Deep copy of lists
96+ loss_masks_copy = [mask [:] for mask in original_loss_masks ]
97+ stop_reasons_copy = original_stop_reasons [:]
11598
116- result = apply_overlong_filtering (original_loss_masks , original_response_ids , eos_token_id )
99+ result = apply_overlong_filtering (original_loss_masks , original_stop_reasons )
117100
118- # Verify original inputs are unchanged
119101 assert original_loss_masks == loss_masks_copy , "Original loss_masks should not be modified"
120- assert original_response_ids == response_ids_copy , "Original response_ids should not be modified"
102+ assert original_stop_reasons == stop_reasons_copy , "Original stop_reasons should not be modified"
121103
122- # Verify result is correct
123- expected = [[1 , 1 , 0 , 1 ], [0 , 0 , 0 ]] # Second mask zeroed due to not ending with eos
104+ expected = [[1 , 1 , 0 , 1 ], [0 , 0 , 0 ]] # Second mask zeroed due to truncation
124105 assert result == expected , f"Expected { expected } , got { result } "
125106
126107
127108@pytest .mark .parametrize (
128- "loss_masks,response_ids " ,
109+ "loss_masks,stop_reasons " ,
129110 [
130- # Test case 1: More loss_masks than response_ids
131- ([[1 , 1 ], [0 , 1 ]], [[ 1 , 2 ] ]),
132- # Test case 2: More response_ids than loss_masks
133- ([[1 , 1 ]], [[ 1 , 2 ], [ 3 , 4 ] ]),
134- # Test case 3: Empty loss_masks but non-empty response_ids
135- ([], [[ 1 , 2 ] ]),
136- # Test case 4: Non-empty loss_masks but empty response_ids
111+ # Test case 1: More loss_masks than stop_reasons
112+ ([[1 , 1 ], [0 , 1 ]], ["stop" ]),
113+ # Test case 2: More stop_reasons than loss_masks
114+ ([[1 , 1 ]], ["stop" , "length" ]),
115+ # Test case 3: Empty loss_masks but non-empty stop_reasons
116+ ([], ["stop" ]),
117+ # Test case 4: Non-empty loss_masks but empty stop_reasons
137118 ([[1 , 0 ]], []),
138119 ],
139120)
140- def test_apply_overlong_filtering_length_mismatch_assertion (loss_masks , response_ids ):
121+ def test_apply_overlong_filtering_length_mismatch_assertion (loss_masks , stop_reasons ):
141122 """
142- Test that apply_overlong_filtering raises AssertionError when loss_masks and response_ids
123+ Test that apply_overlong_filtering raises AssertionError when loss_masks and stop_reasons
143124 have different lengths.
144125 """
145- eos_token_id = 4
146- with pytest .raises (AssertionError , match = "loss_masks and response_ids must have the same length" ):
147- apply_overlong_filtering (loss_masks , response_ids , eos_token_id )
126+ with pytest .raises (AssertionError , match = "loss_masks and stop_reasons must have the same length" ):
127+ apply_overlong_filtering (loss_masks , stop_reasons )
148128
149129
150130dummy_chat_template = (
0 commit comments