88from prompting .rewards .exact_match import (
99 INCORRECT_PENALTY ,
1010 MIN_SMOOTH_PENALTY_SCALE ,
11+ NO_EOS_PENALTY ,
1112 VERIFICATION_THRESH_SIM ,
12- VERIFICATION_THRESH_CONTAINS ,
1313 LogitsRewardModel ,
14- NO_EOS_PENALTY ,
1514)
1615from prompting .rewards .reward import BatchRewardOutput
1716from prompting .tasks .base_task import BaseTextTask
@@ -110,14 +109,8 @@ async def test_correct_completion(model_manager, task):
110109
111110 with (
112111 patch ("prompting.rewards.exact_match.MIN_VERIFY_TOKENS" , 2 ),
113- patch (
114- "prompting.rewards.exact_match.LogitsRewardModel.verify_logit_similarity" ,
115- return_value = 1
116- ),
117- patch (
118- "prompting.rewards.exact_match.LogitsRewardModel.verify_logit_contains" ,
119- return_value = 1
120- ),
112+ patch ("prompting.rewards.exact_match.LogitsRewardModel.verify_logit_similarity" , return_value = 1 ),
113+ patch ("prompting.rewards.exact_match.LogitsRewardModel.verify_logit_contains" , return_value = 1 ),
121114 ):
122115 reward_model = LogitsRewardModel ()
123116 result = await reward_model .reward (
@@ -156,10 +149,7 @@ def mock_verify_sim(original_logits, verification_logits):
156149 with (
157150 patch ("prompting.rewards.exact_match.MIN_VERIFY_TOKENS" , 2 ),
158151 patch ("prompting.rewards.exact_match.LogitsRewardModel.verify_logit_similarity" , side_effect = mock_verify_sim ),
159- patch (
160- "prompting.rewards.exact_match.LogitsRewardModel.verify_logit_contains" ,
161- return_value = 1
162- ),
152+ patch ("prompting.rewards.exact_match.LogitsRewardModel.verify_logit_contains" , return_value = 1 ),
163153 ):
164154 reward_model = LogitsRewardModel ()
165155 result = await reward_model .reward (
@@ -260,7 +250,7 @@ def test_smooth_reward_scale():
260250 (0.3 , 0.3 , 0.0 ),
261251 # At max boundary.
262252 (1.0 , 0.3 , 1.0 ),
263- ]
253+ ],
264254)
265255def test_rescale_various_cases (value , min_value , expected ):
266256 assert LogitsRewardModel .rescale (value , min_value = min_value ) == pytest .approx (expected )
@@ -272,14 +262,14 @@ def test_rescale_various_cases(value, min_value, expected):
272262 # All valid.
273263 ([[0.1 , 1.0 ], [5.0 , 0.1 ], [6.5 ]], 0.55 ),
274264 # Mixed values.
275- ([[ - 1.0 , 0.5 ], [2.0 , 0.1 ]], 1.05 ),
265+ ([[- 1.0 , 0.5 ], [2.0 , 0.1 ]], 1.05 ),
276266 # All negative.
277267 ([[- 3.0 , - 0.1 ], [- 2.5 ]], 1e-6 ),
278268 # Empty lists.
279269 ([[], []], 1e-6 ),
280270 # Zeros included.
281271 ([[0.0 , - 1.0 ], [0.0 ]], 0.0 ),
282- ]
272+ ],
283273)
284274def test_fastest_timing_various_cases (values , expected ):
285275 assert LogitsRewardModel .fastest_timing (values ) == pytest .approx (expected )
0 commit comments