@@ -52,7 +52,7 @@ def __init__(self, **kwargs):
5252 super ().__init__ (** kwargs )
5353 self .beta = beta
5454
55- self .model_wrapped = self .model_wrapped .to (torch .bfloat16 )
55+ # self.model_wrapped = self.model_wrapped.to(torch.bfloat16)
5656
5757 def _generate_and_score_completions (
5858 self , inputs : list [dict [str , Union [torch .Tensor , Any ]]]
@@ -153,7 +153,7 @@ def _generate_and_score_completions(
153153 prompt_ids ,
154154 prompt_mask ,
155155 )
156-
156+ # breakpoint()
157157 # Rollout
158158 with (
159159 profiling_context (self , "transformers.generate" ),
@@ -189,10 +189,9 @@ def _generate_and_score_completions(
189189 use_scheduler = self .args .use_scheduler ,
190190 )
191191 logger .info ("Rollout completed" )
192- if (
193- self .args .torch_empty_cache_steps is not None
194- and self .state .global_step % self .args .torch_empty_cache_steps == 0
195- ):
192+
193+ # Let DeepSpeed manage the CUDA cache itself
194+ if self .accelerator .distributed_type != DistributedType .DEEPSPEED :
196195 torch .cuda .empty_cache ()
197196
198197 # Compute prompt length and extract completion ids
@@ -243,22 +242,12 @@ def _generate_and_score_completions(
243242 example ["student_logprob" ] = logprob
244243
245244 with torch .no_grad ():
246- # If the generation and optimization steps are misaligned—i.e., if generation does not occur at the end of
247- # a full optimizer step (when gradient_accumulation_steps is not a multiple of generate_every)—then the
248- # **samples** may come from an earlier version of the model. In that case, we need to track old_per_token_logps
249- # for importance sampling. If the steps are aligned, importance sampling isn't necessary and we set
250- # old_per_token_logps to None.
251- # This will only run when self._step % generate_every == 0 or self._buffered_inputs is None
252- # generate_every = (
253- # self.args.steps_per_generation * self.num_iterations
254- # ) # generation frequency
245+ # In the diffusion setting we already have per-token log-probs for the rollout trajectory (`sequence_logp`)
246+ # computed in `generate`. We always reuse them as `old_per_token_logps` so that we can explicitly
247+ # measure and correct any on/off-policy mismatch during replay.
255248 old_per_token_logps = sequence_logp .clone ().detach ()
256- # if self.args.gradient_accumulation_steps % generate_every != 0:
257- # old_per_token_logps = sequence_logp.clone().detach()
258- # else:
259- # old_per_token_logps = None
260249
261- # Compute the per-token log probabilities for the reference model
250+ # Compute the per-token log probabilities for the reference model when KL regularization is enabled.
262251 if self .beta != 0.0 :
263252 ref_per_token_logps = old_per_token_logps .clone ().detach ()
264253 else :
@@ -452,16 +441,10 @@ def _compute_loss(self, model, inputs, num_items_in_batch):
452441 inputs ["completion_mask" ],
453442 )
454443 input_ids = torch .cat ([prompt_ids , completion_ids ], dim = 1 )
455- # Replay must mirror the rollout mask: during generation only the prompt tokens
456- # were marked as valid, so keep zeros on the completion portion.
457- prompt_only_mask = torch .ones_like (prompt_ids , dtype = prompt_mask .dtype )
458- attention_mask = torch .cat (
459- [
460- prompt_only_mask ,
461- torch .zeros_like (completion_ids , dtype = prompt_mask .dtype ),
462- ],
463- dim = 1 ,
464- )
444+ # Replay must mirror the rollout mask used during generation:
445+ # prompt padding is preserved, completion tokens are treated as valid (non-padding) tokens.
446+ attention_mask = torch .ones_like (input_ids , dtype = prompt_mask .dtype )
447+ attention_mask [:, :prompt_len ] = prompt_mask
465448 sampling_traj = inputs ["sampling_traj" ]
466449 x0_hist = inputs ["x0_hist" ]
467450 all_advantages = inputs ["advantages" ]
@@ -483,17 +466,13 @@ def _compute_loss(self, model, inputs, num_items_in_batch):
483466
484467 all_traj_len = self .accelerator .gather (
485468 torch .tensor (traj_len , device = input_ids_batch .device )
486- )
469+ )
487470 max_traj_len = all_traj_len .max ().item ()
488471
489472 mask_id = self .args .mask_id
490473 cur_input = input_ids_batch .clone ()
491474 cur_input [:, prompt_len :] = mask_id
492- if (
493- self .args .torch_empty_cache_steps is not None
494- and self .state .global_step % self .args .torch_empty_cache_steps == 0
495- ):
496- torch .cuda .empty_cache ()
475+ # torch.cuda.empty_cache()
497476 for step in tqdm (range (max_traj_len ), desc = "Computing per-token logps" ):
498477 # logger.info(f"Step {step} of {traj_len}")
499478 # running the model in batches per step
@@ -546,14 +525,19 @@ def _compute_loss(self, model, inputs, num_items_in_batch):
546525 cur_logp = torch .zeros_like (
547526 unmasking_prob [batch ], dtype = torch .float32
548527 ).unsqueeze (0 )
528+ EPS = 1e-6
529+ clamped_prob = torch .clamp (unmasking_prob [batch ], min = EPS , max = 1.0 - EPS )
549530 if len (cur_traj [batch ][step ]) > 0 :
531+ # Use log1p for log(1-p) when p is small
550532 cur_logp [:, keep_mask_index_mask ] = torch .log1p (
551- - unmasking_prob [ batch , keep_mask_index_mask ]
533+ - clamped_prob [ keep_mask_index_mask ]
552534 )
535+ # Use log for log(p), now safe due to clamping
553536 cur_logp [:, unmasking_index_mask ] = (
554- torch .log (unmasking_prob [ batch , unmasking_index_mask ])
537+ torch .log (clamped_prob [ unmasking_index_mask ])
555538 + x0_logp [batch , unmasking_index_mask ]
556539 )
540+
557541 if (
558542 torch .isnan (cur_logp ).sum () > 0
559543 or not torch .isfinite (cur_logp ).all ()
@@ -620,17 +604,15 @@ def _compute_loss(self, model, inputs, num_items_in_batch):
620604 # Two-sided clipping
621605 if self .args .delta is not None :
622606 coef_1 = torch .clamp (coef_1 , max = self .args .delta )
623- advantages = torch .where (
624- advantages < self .args .advantage_min_clip ,
625- torch .zeros_like (
626- advantages
627- ), # ignores advantages below a threshold
628- advantages ,
607+ advantages = torch .clamp (
608+ advantages , min = self .args .advantage_min_clip
629609 )
630610
631611 per_token_loss1 = coef_1 * advantages .unsqueeze (1 )
632612 per_token_loss2 = coef_2 * advantages .unsqueeze (1 )
633613 per_token_loss = - torch .min (per_token_loss1 , per_token_loss2 )
614+ # if entropy_mask is not None:
615+ # per_token_loss = per_token_loss * entropy_mask
634616
635617 if self .beta != 0.0 :
636618 per_token_loss = per_token_loss + self .beta * per_token_kl
@@ -640,21 +622,28 @@ def _compute_loss(self, model, inputs, num_items_in_batch):
640622 / per_token_loss .size (0 )
641623 / self .max_completion_length
642624 )
643- loss = loss / self . current_gradient_accumulation_steps
625+
644626 if loss .grad_fn is None :
645627 # this means that no token is unmasked; this can happen because the generated completion rollout is split into smaller batches
646628 # raise ValueError("No gradient found")
647629 loss = logits .exp ().sum () * 0.0 + unmasking_prob .sum () * 0.0
648630 loss_list .append (loss .item ())
649631 # print(f"Loss: {loss}")
650- # Backward pass
651632 if bad_flag or loss .isnan ():
652633 accel_break (bad_process_index )
653- # logger.info(f"[Rank {self.accelerator.process_index}]Loss: {loss}")
654- self .backward (loss , num_items_in_batch )
634+
635+ # Backward pass: accumulate gradients over diffusion steps but only let DeepSpeed
636+ # take an optimizer step on the final (chunk, step) pair.
637+ force_deepspeed_step = False
638+ if self .accelerator .distributed_type == DistributedType .DEEPSPEED :
639+ is_last_chunk = start + batch_size == input_ids .size (0 )
640+ is_last_step = step == max_traj_len - 1
641+ force_deepspeed_step = is_last_chunk and is_last_step
642+ self .backward (loss , num_items_in_batch , force_deepspeed_step = force_deepspeed_step )
655643 return_loss += loss .detach ()
656644
657645 del cur_input
646+ # torch.cuda.empty_cache() # to reduce memory usage but will make things super slow
658647 cur_input = next_input
659648
660649 # Log the metrics
@@ -758,11 +747,31 @@ def compute_loss(
758747 else :
759748 return self ._compute_loss (model , inputs , num_items_in_batch )
760749
761- def backward (self , loss : torch .Tensor , num_items_in_batch ):
750+ def backward (self , loss : torch .Tensor , num_items_in_batch , force_deepspeed_step = False ):
751+ if (force_deepspeed_step and self .accelerator .distributed_type != DistributedType .DEEPSPEED ):
752+ raise ValueError ("force_deepspeed_step should only be true during DeepSpeed runs" )
753+
762754 kwargs = {}
755+
756+
757+ # since we don't want deepspeed to step the optimizer
758+ # every single time we unmask a chunk, we force the gradient sync
759+ # flag to be false here until we're ready to step (after the full rollout)
760+ if self .accelerator .distributed_type == DistributedType .DEEPSPEED :
761+ orig_sync = getattr (self .accelerator , "sync_gradients" , True )
762+ self .accelerator .sync_gradients = force_deepspeed_step
763+ self .accelerator .backward (loss , ** kwargs )
764+ self .accelerator .sync_gradients = orig_sync
765+ return # exit early so the torch.cuda.empty_cache() below never runs while gradients are still accumulating
766+
767+ if (
768+ self .args .torch_empty_cache_steps is not None
769+ and self .state .global_step % self .args .torch_empty_cache_steps == 0
770+ ):
771+ torch .cuda .empty_cache ()
763772
764773 if self .args .n_gpu > 1 :
765- loss = loss .mean () # mean() to average on multi-gpu parallel training
774+ loss = loss .mean () # mean() to average on multi-gpu parallel training (non-deepspeed)
766775
767776 # Finally we need to normalize the loss for reporting if GA loss bug is not fixed during compute loss
768777 if (
@@ -771,9 +780,4 @@ def backward(self, loss: torch.Tensor, num_items_in_batch):
771780 # If the model does not accept loss kwargs, we need to normalize the loss by the number of gradient accumulation steps
772781 loss = loss / self .current_gradient_accumulation_steps
773782
774- # Turning off loss scaling w.r.t. gradient accumulation when DeepSpeed is enabled
775- # https://github.com/huggingface/transformers/pull/35808
776- if self .accelerator .distributed_type == DistributedType .DEEPSPEED :
777- kwargs ["scale_wrt_gas" ] = False
778-
779783 self .accelerator .backward (loss , ** kwargs )
0 commit comments