Skip to content

Commit f89a315

Browse files
committed
Run pre-commit hook
1 parent c337674 commit f89a315

File tree

2 files changed

+10
-25
lines changed

2 files changed

+10
-25
lines changed

prompting/rewards/exact_match.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,6 @@ async def reward(
141141
if score_contains_mean < VERIFICATION_THRESH_CONTAINS:
142142
raise ValueError(f"Logits contains mean score is below threshold: {score_contains_mean:.2f}")
143143

144-
145144
timing_verified.append(timings)
146145
smooth_reward = self.smooth_timings_reward(timings)
147146
# Min-max scale logits reward, e.g. from [0.95, 1.0] to [0.0, 1.0].
@@ -219,9 +218,7 @@ def smooth_timings_reward(timings_uid: list[float], min_reward: float = MIN_SMOO
219218

220219
@staticmethod
221220
def verify_logit_contains(
222-
candidate_token: str,
223-
candidate_logits: dict[str, float],
224-
gt_logits: dict[str, float]
221+
candidate_token: str, candidate_logits: dict[str, float], gt_logits: dict[str, float]
225222
) -> float:
226223
"""Verify if the selected token and logprobs are present in the verification output."""
227224
if candidate_token not in candidate_logits.keys():
@@ -234,9 +231,7 @@ def verify_logit_contains(
234231

235232
@staticmethod
236233
def verify_logit_similarity(
237-
original_logits: dict[str, float],
238-
verification_logits: dict[str, float],
239-
fill_value: float = -100.0
234+
original_logits: dict[str, float], verification_logits: dict[str, float], fill_value: float = -100.0
240235
) -> float:
241236
all_tokens = sorted(set(original_logits) | set(verification_logits))
242237
orig_vec = np.array([original_logits.get(t, fill_value) for t in all_tokens], dtype=np.float64)
@@ -252,4 +247,4 @@ def softmax(x: np.ndarray) -> np.ndarray:
252247

253248
orig_unit = orig_prob / np.linalg.norm(orig_prob)
254249
verif_unit = verif_prob / np.linalg.norm(verif_prob)
255-
return float(np.dot(orig_unit, verif_unit))
250+
return float(np.dot(orig_unit, verif_unit))

tests/prompting/rewards/test_exact_match.py

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,9 @@
88
from prompting.rewards.exact_match import (
99
INCORRECT_PENALTY,
1010
MIN_SMOOTH_PENALTY_SCALE,
11+
NO_EOS_PENALTY,
1112
VERIFICATION_THRESH_SIM,
12-
VERIFICATION_THRESH_CONTAINS,
1313
LogitsRewardModel,
14-
NO_EOS_PENALTY,
1514
)
1615
from prompting.rewards.reward import BatchRewardOutput
1716
from prompting.tasks.base_task import BaseTextTask
@@ -110,14 +109,8 @@ async def test_correct_completion(model_manager, task):
110109

111110
with (
112111
patch("prompting.rewards.exact_match.MIN_VERIFY_TOKENS", 2),
113-
patch(
114-
"prompting.rewards.exact_match.LogitsRewardModel.verify_logit_similarity",
115-
return_value=1
116-
),
117-
patch(
118-
"prompting.rewards.exact_match.LogitsRewardModel.verify_logit_contains",
119-
return_value=1
120-
),
112+
patch("prompting.rewards.exact_match.LogitsRewardModel.verify_logit_similarity", return_value=1),
113+
patch("prompting.rewards.exact_match.LogitsRewardModel.verify_logit_contains", return_value=1),
121114
):
122115
reward_model = LogitsRewardModel()
123116
result = await reward_model.reward(
@@ -156,10 +149,7 @@ def mock_verify_sim(original_logits, verification_logits):
156149
with (
157150
patch("prompting.rewards.exact_match.MIN_VERIFY_TOKENS", 2),
158151
patch("prompting.rewards.exact_match.LogitsRewardModel.verify_logit_similarity", side_effect=mock_verify_sim),
159-
patch(
160-
"prompting.rewards.exact_match.LogitsRewardModel.verify_logit_contains",
161-
return_value=1
162-
),
152+
patch("prompting.rewards.exact_match.LogitsRewardModel.verify_logit_contains", return_value=1),
163153
):
164154
reward_model = LogitsRewardModel()
165155
result = await reward_model.reward(
@@ -260,7 +250,7 @@ def test_smooth_reward_scale():
260250
(0.3, 0.3, 0.0),
261251
# At max boundary.
262252
(1.0, 0.3, 1.0),
263-
]
253+
],
264254
)
265255
def test_rescale_various_cases(value, min_value, expected):
266256
assert LogitsRewardModel.rescale(value, min_value=min_value) == pytest.approx(expected)
@@ -272,14 +262,14 @@ def test_rescale_various_cases(value, min_value, expected):
272262
# All valid.
273263
([[0.1, 1.0], [5.0, 0.1], [6.5]], 0.55),
274264
# Mixed values.
275-
([[ -1.0, 0.5], [2.0, 0.1]], 1.05),
265+
([[-1.0, 0.5], [2.0, 0.1]], 1.05),
276266
# All negative.
277267
([[-3.0, -0.1], [-2.5]], 1e-6),
278268
# Empty lists.
279269
([[], []], 1e-6),
280270
# Zeros included.
281271
([[0.0, -1.0], [0.0]], 0.0),
282-
]
272+
],
283273
)
284274
def test_fastest_timing_various_cases(values, expected):
285275
assert LogitsRewardModel.fastest_timing(values) == pytest.approx(expected)

0 commit comments

Comments
 (0)