diff --git a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/ppo_trainer.py b/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/ppo_trainer.py
index 2bca82f20..ed783adc2 100644
--- a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/ppo_trainer.py
+++ b/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/ppo_trainer.py
@@ -43,6 +43,11 @@ def gather_log_probs(logits, labels):
     return log_probs_labels.squeeze(-1)
 
 
+def assert_empty_partition_cache(model, tag):
+    avail_or_inflight_params = [p.ds_id for p in model.parameters() if hasattr(p, 'ds_id') and p.ds_status != ZeroParamStatus.NOT_AVAILABLE]
+    assert len(avail_or_inflight_params) == 0, \
+        f'{tag} empty_partition_cache failed to evict all params: remaining = {avail_or_inflight_params}'
+
 class DeepSpeedPPOTrainer():
 
     def __init__(self, rlhf_engine, args):
@@ -68,12 +73,14 @@ def __init__(self, rlhf_engine, args):
 
     def _generate_sequence(self, prompts):
         max_min_length = self.max_answer_seq_len + prompts.shape[1]
-
         with torch.no_grad():
             seq = self.actor_model.module.generate(prompts,
                                                    max_length=max_min_length,
                                                    min_length=max_min_length)
 
+        self.actor_model.empty_partition_cache()
+        assert_empty_partition_cache(self.actor_model, 'actor_model after generate')
+
         # Filter out seq with no answers (or very short ones). This happens when users directly use the pre-training ckpt without supervised finetuning
         # NOTE: this will cause each GPU to have a different number of examples
         batch_size = seq.shape[0]
@@ -171,6 +178,9 @@ def train_rlhf(self, inputs):
                                         action_mask[:, start:])
         self.actor_model.backward(actor_loss)
         self.actor_model.step()
+        self.actor_model.empty_partition_cache()
+        assert_empty_partition_cache(self.actor_model, 'actor_model after rlhf step')
+
         value = self.critic_model.forward_value(**batch,
                                                 return_value_only=True,
                                                 use_cache=False)[:, :-1]
@@ -179,6 +189,8 @@ def train_rlhf(self, inputs):
                                           returns, action_mask[:, start:])
         self.critic_model.backward(critic_loss)
         self.critic_model.step()
+        self.critic_model.empty_partition_cache()
+        assert_empty_partition_cache(self.critic_model, 'critic_model after rlhf step')
 
         return actor_loss, critic_loss
 
@@ -267,5 +279,7 @@ def train_unsupervised(self, inputs, unsup_coef):
         loss = outputs.loss
         self.actor_model.backward(unsup_coef * loss)
         self.actor_model.step()
+        self.actor_model.empty_partition_cache()
+        assert_empty_partition_cache(self.actor_model, 'actor_model after unsuper_step')
 
         return loss
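
Note: the new assert_empty_partition_cache() helper references ZeroParamStatus, which no hunk shown here imports (the import may live in a hunk not included above); the patched file needs it in scope via deepspeed.runtime.zero.partition_parameters. Below is a minimal standalone sketch of the pattern this diff adds (evict the ZeRO-3 partition cache after each phase, then verify every parameter was actually released), with the import included and comments added. The `engine` name in the usage comments is hypothetical, standing in for any DeepSpeedEngine initialized with ZeRO stage 3; only under stage 3 do parameters carry the ds_id/ds_status attributes this check relies on.

    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

    def assert_empty_partition_cache(model, tag):
        # ZeRO-3 tags each partitioned parameter with ds_id/ds_status.
        # After empty_partition_cache(), every parameter should be back in the
        # NOT_AVAILABLE state; anything AVAILABLE or INFLIGHT is still holding
        # (or still fetching) gathered full weights.
        avail_or_inflight_params = [
            p.ds_id for p in model.parameters()
            if hasattr(p, 'ds_id') and p.ds_status != ZeroParamStatus.NOT_AVAILABLE
        ]
        assert len(avail_or_inflight_params) == 0, \
            f'{tag} empty_partition_cache failed to evict all params: remaining = {avail_or_inflight_params}'

    # Typical usage after a generate or forward/backward/step on a ZeRO-3
    # engine (hypothetical `engine` returned by deepspeed.initialize):
    #   engine.empty_partition_cache()   # release gathered ZeRO-3 params
    #   assert_empty_partition_cache(engine, 'engine after step')

Running the assert immediately after empty_partition_cache(), as each hunk above does, localizes a leak to the phase that caused it (generate, RLHF step, or unsupervised step) rather than surfacing later as an out-of-memory error.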