Skip to content

Commit 65702c5

Browse files
[BugFix] Normalize reward loss over valid pairs (#3886)
1 parent 92a3f7f commit 65702c5

2 files changed

Lines changed: 62 additions & 5 deletions

File tree

test/test_rlhf.py

Lines changed: 54 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -406,17 +406,69 @@ def test_compute_reward_loss_identical_sequences():
406406

407407
chosen_batch = SimpleNamespace(
408408
input_ids=input_ids,
409-
rewards=torch.randn(1, seq_len),
409+
rewards=torch.randn(1, seq_len, requires_grad=True),
410410
)
411411
rejected_batch = SimpleNamespace(
412412
input_ids=input_ids.clone(),
413-
rewards=torch.randn(1, seq_len),
413+
rewards=torch.randn(1, seq_len, requires_grad=True),
414414
)
415415
loss = GPT2RewardModel.compute_reward_loss(
416416
chosen_batch, rejected_batch, pad_token_id=pad_token_id
417417
)
418418
assert loss.shape == torch.Size([])
419419
assert loss.item() == 0.0
420+
loss.backward()
421+
torch.testing.assert_close(
422+
chosen_batch.rewards.grad, torch.zeros_like(chosen_batch.rewards)
423+
)
424+
torch.testing.assert_close(
425+
rejected_batch.rewards.grad, torch.zeros_like(rejected_batch.rewards)
426+
)
427+
428+
429+
def test_compute_reward_loss_normalizes_by_non_identical_sequences():
430+
pad_token_id = 50256
431+
chosen_ids = torch.tensor(
432+
[
433+
[1, 2, 3, 4, pad_token_id],
434+
[1, 2, 9, 4, pad_token_id],
435+
]
436+
)
437+
rejected_ids = torch.tensor(
438+
[
439+
[1, 2, 3, 4, pad_token_id],
440+
[1, 2, 3, 4, pad_token_id],
441+
]
442+
)
443+
chosen_rewards = torch.tensor(
444+
[
445+
[0.0, 0.0, 10.0, 10.0, 0.0],
446+
[0.0, 0.0, 2.0, 2.0, 0.0],
447+
],
448+
requires_grad=True,
449+
)
450+
rejected_rewards = torch.tensor(
451+
[
452+
[0.0, 0.0, -10.0, -10.0, 0.0],
453+
[0.0, 0.0, 1.0, 1.0, 0.0],
454+
],
455+
requires_grad=True,
456+
)
457+
chosen_batch = SimpleNamespace(input_ids=chosen_ids, rewards=chosen_rewards)
458+
rejected_batch = SimpleNamespace(input_ids=rejected_ids, rewards=rejected_rewards)
459+
460+
loss = GPT2RewardModel.compute_reward_loss(
461+
chosen_batch, rejected_batch, pad_token_id=pad_token_id
462+
)
463+
expected_loss = -F.logsigmoid(chosen_rewards[1, 2:4] - rejected_rewards[1, 2:4])
464+
torch.testing.assert_close(loss, expected_loss.mean())
465+
loss.backward()
466+
torch.testing.assert_close(
467+
chosen_rewards.grad[0], torch.zeros_like(chosen_rewards[0])
468+
)
469+
torch.testing.assert_close(
470+
rejected_rewards.grad[0], torch.zeros_like(rejected_rewards[0])
471+
)
420472

421473

422474
@pytest.mark.skipif(

torchrl/modules/models/llm.py

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -122,7 +122,8 @@ def compute_reward_loss(chosen_batch, rejected_batch, pad_token_id=50256):
122122
rejected_rewards = rejected_batch.rewards
123123

124124
bs = chosen_rewards.shape[0]
125-
loss = torch.tensor(0.0, device=chosen_rewards.device)
125+
loss = None
126+
valid_count = 0
126127

127128
# TODO: this loop can likely be made more efficient
128129
for i in range(bs):
@@ -144,8 +145,12 @@ def compute_reward_loss(chosen_batch, rejected_batch, pad_token_id=50256):
144145
c_truncated_reward = chosen_rewards[i][divergence_ind:end_ind]
145146
r_truncated_reward = rejected_rewards[i][divergence_ind:end_ind]
146147

147-
loss += -F.logsigmoid(c_truncated_reward - r_truncated_reward).mean()
148-
return loss / bs
148+
sample_loss = -F.logsigmoid(c_truncated_reward - r_truncated_reward).mean()
149+
loss = sample_loss if loss is None else loss + sample_loss
150+
valid_count += 1
151+
if loss is None:
152+
return chosen_rewards.sum() * 0.0 + rejected_rewards.sum() * 0.0
153+
return loss / valid_count
149154

150155
@classmethod
151156
def from_pretrained(cls, path):

0 commit comments

Comments
 (0)