@@ -406,17 +406,69 @@ def test_compute_reward_loss_identical_sequences():
406406
407407 chosen_batch = SimpleNamespace (
408408 input_ids = input_ids ,
409- rewards = torch .randn (1 , seq_len ),
409+ rewards = torch .randn (1 , seq_len , requires_grad = True ),
410410 )
411411 rejected_batch = SimpleNamespace (
412412 input_ids = input_ids .clone (),
413- rewards = torch .randn (1 , seq_len ),
413+ rewards = torch .randn (1 , seq_len , requires_grad = True ),
414414 )
415415 loss = GPT2RewardModel .compute_reward_loss (
416416 chosen_batch , rejected_batch , pad_token_id = pad_token_id
417417 )
418418 assert loss .shape == torch .Size ([])
419419 assert loss .item () == 0.0
420+ loss .backward ()
421+ torch .testing .assert_close (
422+ chosen_batch .rewards .grad , torch .zeros_like (chosen_batch .rewards )
423+ )
424+ torch .testing .assert_close (
425+ rejected_batch .rewards .grad , torch .zeros_like (rejected_batch .rewards )
426+ )
427+
428+
def test_compute_reward_loss_normalizes_by_non_identical_sequences():
    """Check that only pairs with differing token ids shape the reward loss.

    The batch holds two chosen/rejected pairs. Pair 0 has identical token
    ids on both sides; pair 1 differs at position 2. The test expects the
    loss to equal the mean of ``-logsigmoid(chosen - rejected)`` taken over
    positions 2..3 of pair 1 alone, and expects pair 0 to receive exactly
    zero gradient on both its chosen and rejected rewards.
    """
    pad = 50256

    # Pair 0: same ids on both sides. Pair 1: ids differ at index 2 (9 vs 3).
    # Every row ends in a single pad token.
    ids_chosen = torch.tensor([[1, 2, 3, 4, pad], [1, 2, 9, 4, pad]])
    ids_rejected = torch.tensor([[1, 2, 3, 4, pad], [1, 2, 3, 4, pad]])

    # Pair 0 carries a large reward gap (+10 vs -10) -- presumably so that
    # any leakage of the identical pair into the loss would show up clearly.
    r_chosen = torch.tensor(
        [[0.0, 0.0, 10.0, 10.0, 0.0], [0.0, 0.0, 2.0, 2.0, 0.0]],
        requires_grad=True,
    )
    r_rejected = torch.tensor(
        [[0.0, 0.0, -10.0, -10.0, 0.0], [0.0, 0.0, 1.0, 1.0, 0.0]],
        requires_grad=True,
    )

    loss = GPT2RewardModel.compute_reward_loss(
        SimpleNamespace(input_ids=ids_chosen, rewards=r_chosen),
        SimpleNamespace(input_ids=ids_rejected, rewards=r_rejected),
        pad_token_id=pad,
    )

    # Reference value: pair 1 only, positions 2 and 3 (the slice 2:4).
    window = slice(2, 4)
    reward_gap = r_chosen[1, window] - r_rejected[1, window]
    per_token_loss = -F.logsigmoid(reward_gap)
    torch.testing.assert_close(loss, per_token_loss.mean())

    loss.backward()

    # The identical pair (row 0) must not receive any gradient.
    for rewards in (r_chosen, r_rejected):
        torch.testing.assert_close(rewards.grad[0], torch.zeros_like(rewards[0]))
420472
421473
422474@pytest .mark .skipif (
0 commit comments