From 729456957a708529d8c09b3697ad85b8824e2b12 Mon Sep 17 00:00:00 2001 From: Tong Li Date: Mon, 31 Mar 2025 11:35:23 +0800 Subject: [PATCH 01/17] update help information --- applications/ColossalChat/rl_example.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py index 4a4a4c3404e9..5e7af5c192d9 100644 --- a/applications/ColossalChat/rl_example.py +++ b/applications/ColossalChat/rl_example.py @@ -10,13 +10,13 @@ parser.add_argument("-d", "--dataset", type=str, default="data.jsonl") parser.add_argument("-t", "--num-trainers", type=int, default=2) parser.add_argument("-i", "--num-inferencer", type=int, default=2) - parser.add_argument("-g", "--num-generations", type=int, default=8) - parser.add_argument("-ibs", "--inference-batch-size", type=int, default=64) - parser.add_argument("-imbs", "--inference-microbatch-size", type=int, default=8) - parser.add_argument("-tbs", "--train-batch-size", type=int, default=32) - parser.add_argument("-tMbs", "--train-minibatch-size", type=int, default=1) - parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=2) - parser.add_argument("-b", "--backend", type=str, default="transformers") + parser.add_argument("-g", "--num-generations", type=int, default=8, help="Number of generations per prompt.") + parser.add_argument("-ibs", "--inference-batch-size", type=int, default=64, help="Number of prompts to generate per step.") + parser.add_argument("-imbs", "--inference-microbatch-size", type=int, default=8, help="Number of prompts to send from the producer to the consumer.") + parser.add_argument("-tbs", "--train-batch-size", type=int, default=32, help="Number of prompts to update policy model.") + parser.add_argument("-tMbs", "--train-minibatch-size", type=int, default=1, help="Number of prompts per device. 
Number of samples = tMbs * num of generation per prompt.") + parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=2, help="Number of samples per device.") + parser.add_argument("-b", "--backend", type=str, default="transformers", choices=["transformers, vllm"]) parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple", "GRPO", "EvalGRPO"]) args = parser.parse_args() From 27d0108c9f8aab90288ef0d841ad27fb032fd6bf Mon Sep 17 00:00:00 2001 From: Tong Li Date: Mon, 31 Mar 2025 11:44:10 +0800 Subject: [PATCH 02/17] update style --- applications/ColossalChat/rl_example.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py index 5e7af5c192d9..7ff9bd20d2d4 100644 --- a/applications/ColossalChat/rl_example.py +++ b/applications/ColossalChat/rl_example.py @@ -11,10 +11,26 @@ parser.add_argument("-t", "--num-trainers", type=int, default=2) parser.add_argument("-i", "--num-inferencer", type=int, default=2) parser.add_argument("-g", "--num-generations", type=int, default=8, help="Number of generations per prompt.") - parser.add_argument("-ibs", "--inference-batch-size", type=int, default=64, help="Number of prompts to generate per step.") - parser.add_argument("-imbs", "--inference-microbatch-size", type=int, default=8, help="Number of prompts to send from the producer to the consumer.") - parser.add_argument("-tbs", "--train-batch-size", type=int, default=32, help="Number of prompts to update policy model.") - parser.add_argument("-tMbs", "--train-minibatch-size", type=int, default=1, help="Number of prompts per device. Number of samples = tMbs * num of generation per prompt.") + parser.add_argument( + "-ibs", "--inference-batch-size", type=int, default=64, help="Number of prompts to generate per step." + ) + parser.add_argument( + "-imbs", + "--inference-microbatch-size", + type=int, + default=8, + help="Number of prompts to send from the producer to the consumer.", + ) + parser.add_argument( + "-tbs", "--train-batch-size", type=int, default=32, help="Number of prompts to update policy model." + ) + parser.add_argument( + "-tMbs", + "--train-minibatch-size", + type=int, + default=1, + help="Number of prompts per device. Number of samples = tMbs * num of generation per prompt.", + ) parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=2, help="Number of samples per device.") parser.add_argument("-b", "--backend", type=str, default="transformers", choices=["transformers, vllm"]) parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple", "GRPO", "EvalGRPO"]) From 9d6ede98dcc9470f7db55af03060d5f72c4d29c4 Mon Sep 17 00:00:00 2001 From: Tong Li Date: Mon, 31 Mar 2025 13:14:41 +0800 Subject: [PATCH 03/17] fix --- applications/ColossalChat/rl_example.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py index 7ff9bd20d2d4..bb7b848267e7 100644 --- a/applications/ColossalChat/rl_example.py +++ b/applications/ColossalChat/rl_example.py @@ -32,7 +32,7 @@ help="Number of prompts per device. 
Number of samples = tMbs * num of generation per prompt.", ) parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=2, help="Number of samples per device.") - parser.add_argument("-b", "--backend", type=str, default="transformers", choices=["transformers, vllm"]) + parser.add_argument("-b", "--backend", type=str, default="transformers", choices=["transformers","vllm"]) parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple", "GRPO", "EvalGRPO"]) args = parser.parse_args() From 6604654b2f3a37bea2a1dc314498f498aa1b6dde Mon Sep 17 00:00:00 2001 From: Tong Li Date: Mon, 31 Mar 2025 13:15:18 +0800 Subject: [PATCH 04/17] minor fix --- applications/ColossalChat/rl_example.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py index bb7b848267e7..bb719a13c405 100644 --- a/applications/ColossalChat/rl_example.py +++ b/applications/ColossalChat/rl_example.py @@ -32,7 +32,7 @@ help="Number of prompts per device. Number of samples = tMbs * num of generation per prompt.", ) parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=2, help="Number of samples per device.") - parser.add_argument("-b", "--backend", type=str, default="transformers", choices=["transformers","vllm"]) + parser.add_argument("-b", "--backend", type=str, default="transformers", choices=["transformers", "vllm"]) parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple", "GRPO", "EvalGRPO"]) args = parser.parse_args() From d961a5f7251f377856fc2def74dc1de98269fe49 Mon Sep 17 00:00:00 2001 From: Tong Li Date: Tue, 1 Apr 2025 11:24:09 +0800 Subject: [PATCH 05/17] support PP training --- .../coati/distributed/consumer.py | 1 - .../coati/distributed/grpo_consumer.py | 166 +++++++++++------- applications/ColossalChat/rl_example.py | 16 +- .../booster/plugin/hybrid_parallel_plugin.py | 6 +- colossalai/shardformer/modeling/qwen2.py | 1 + 5 files changed, 121 insertions(+), 69 deletions(-) diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py index 027acc2e7537..4e1cd1f3179a 100644 --- a/applications/ColossalChat/coati/distributed/consumer.py +++ b/applications/ColossalChat/coati/distributed/consumer.py @@ -54,7 +54,6 @@ def __init__( self.model_config = model_config self.plugin_config = plugin_config - assert self.plugin_config.get("pp_size", 1) == 1, "pp_size > 1 is not supported now" self.device = get_current_device() self.lr_scheduler = None diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py index 4174f96514b8..d05709febf52 100644 --- a/applications/ColossalChat/coati/distributed/grpo_consumer.py +++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py @@ -96,7 +96,7 @@ def __init__( self.global_step = 0 if use_wandb and self.rank == 0: name = f"{generate_config['backend']}_bs_{self.batch_size*self.world_size}_temp_{generate_config['temperature']:.01f}_top_p_{generate_config['top_p']:.02f}" - self.wandb_run = wandb.init(project="GRPO-V1", sync_tensorboard=True, dir="./wandb", name=name) + self.wandb_run = wandb.init(project="GRPO-V1-PP", sync_tensorboard=True, dir="./wandb", name=name) self.lr_scheduler = CosineAnnealingWarmupLR( optimizer=self.optimizer, @@ -168,72 +168,120 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]: ).repeat_interleave(self.num_generations, dim=0) ) mean_kl, 
mean_loss = [], [] - for forward_micro_batch_start in range(0, data["input_ids"].size(0), forward_batch_size): - input_ids_forward_micro_batch = data["input_ids"][ - forward_micro_batch_start : forward_micro_batch_start + forward_batch_size - ] - attention_mask_forward_micro_batch = data["attention_mask"][ - forward_micro_batch_start : forward_micro_batch_start + forward_batch_size - ] - action_mask_forward_micro_batch = action_mask[ - forward_micro_batch_start : forward_micro_batch_start + forward_batch_size - ] - loss_mask_forward_micro_batch = ( - loss_mask[forward_micro_batch_start : forward_micro_batch_start + forward_batch_size] - if loss_mask is not None - else None - ) - advantages_forward_micro_batch = advantages[ - forward_micro_batch_start : forward_micro_batch_start + forward_batch_size - ] - policy_model_logits = self.policy_model( - input_ids=input_ids_forward_micro_batch, - attention_mask=attention_mask_forward_micro_batch, - ).logits - action_log_probs = calc_action_log_probs( - policy_model_logits / self.generate_config["temperature"], - input_ids_forward_micro_batch, - num_action, - self.plugin.shard_config, - ) + if self.plugin.pp_size > 1: + # Support training with PP. + data_iter = iter([data]) with torch.no_grad(): - reference_model_logits = self.reference_model( + reference_model_outputs = self.booster.execute_pipeline( + data_iter, + self.reference_model, + criterion=lambda outputs, inputs: outputs.logits.mean(), # dummy criterion + optimizer=None, + return_loss=False, + return_outputs=True, + ) + + if self.booster.plugin.stage_manager.is_last_stage(): + reference_model_logits = reference_model_outputs["outputs"]["logits"] + reference_action_log_probs = calc_action_log_probs( + reference_model_logits / self.generate_config["temperature"], + data["input_ids"], + num_action, + self.plugin.shard_config, + ) + else: + # Dummy reference logprobs for data iterator. 
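
For readers unfamiliar with the helper used throughout this hunk: calc_action_log_probs is imported from coati.distributed.utils and is not part of this patch. A minimal sketch of what such a helper typically computes is shown below; it ignores the temperature scaling applied at the call sites and the tensor-parallel vocabulary sharding handled via shard_config, and the function name is illustrative only.

    import torch

    def action_log_probs_sketch(logits, input_ids, num_actions):
        # logits: [B, S, V] from the causal LM, input_ids: [B, S].
        # Token t is predicted by the logits at position t - 1, hence the one-token shift.
        log_probs = torch.log_softmax(logits[:, :-1, :], dim=-1)
        per_token = log_probs.gather(-1, input_ids[:, 1:].unsqueeze(-1)).squeeze(-1)  # [B, S-1]
        # Keep only the generated (action) tokens at the end of the sequence.
        return per_token[:, -num_actions:]
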
+ reference_action_log_probs = torch.zeros( + (old_action_log_probs.size(0), old_action_log_probs.size(1)) + ) + + data["reference_action_log_probs"] = reference_action_log_probs + + data_iter = iter([data]) + + def _criterion(outputs, inputs): + pass + + outputs = self.booster.execute_pipeline( + data_iter, + self.policy_model, + criterion=_criterion, + optimizer=self.optimizer, + return_loss=True, + ) + loss = outputs["loss"] + + if self.booster.plugin.stage_manager.is_last_stage(): + loss = all_reduce_mean(loss, self.plugin) + mean_loss.append(loss.data) + else: + for forward_micro_batch_start in range(0, data["input_ids"].size(0), forward_batch_size): + input_ids_forward_micro_batch = data["input_ids"][ + forward_micro_batch_start : forward_micro_batch_start + forward_batch_size + ] + attention_mask_forward_micro_batch = data["attention_mask"][ + forward_micro_batch_start : forward_micro_batch_start + forward_batch_size + ] + action_mask_forward_micro_batch = action_mask[ + forward_micro_batch_start : forward_micro_batch_start + forward_batch_size + ] + loss_mask_forward_micro_batch = ( + loss_mask[forward_micro_batch_start : forward_micro_batch_start + forward_batch_size] + if loss_mask is not None + else None + ) + advantages_forward_micro_batch = advantages[ + forward_micro_batch_start : forward_micro_batch_start + forward_batch_size + ] + policy_model_logits = self.policy_model( input_ids=input_ids_forward_micro_batch, attention_mask=attention_mask_forward_micro_batch, ).logits - reference_action_log_probs = calc_action_log_probs( - reference_model_logits / self.generate_config["temperature"], - input_ids_forward_micro_batch, - num_action, - self.plugin.shard_config, - ) + action_log_probs = calc_action_log_probs( + policy_model_logits / self.generate_config["temperature"], + input_ids_forward_micro_batch, + num_action, + self.plugin.shard_config, + ) - per_token_kl = ( - torch.exp(reference_action_log_probs - action_log_probs) - - (reference_action_log_probs - action_log_probs) - - 1 - ) - kl = torch.sum(per_token_kl * action_mask_forward_micro_batch, dim=-1) / torch.sum( - action_mask_forward_micro_batch, dim=-1 - ) + with torch.no_grad(): + reference_model_logits = self.reference_model( + input_ids=input_ids_forward_micro_batch, + attention_mask=attention_mask_forward_micro_batch, + ).logits + reference_action_log_probs = calc_action_log_probs( + reference_model_logits / self.generate_config["temperature"], + input_ids_forward_micro_batch, + num_action, + self.plugin.shard_config, + ) - loss, skip_update, _ = self.policy_loss_fn( - action_log_probs, - old_action_log_probs, - advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1), - per_token_kl, - action_mask_forward_micro_batch, - loss_mask=loss_mask_forward_micro_batch, - ) + per_token_kl = ( + torch.exp(reference_action_log_probs - action_log_probs) + - (reference_action_log_probs - action_log_probs) + - 1 + ) + kl = torch.sum(per_token_kl * action_mask_forward_micro_batch, dim=-1) / torch.sum( + action_mask_forward_micro_batch, dim=-1 + ) + + loss, skip_update, _ = self.policy_loss_fn( + action_log_probs, + old_action_log_probs, + advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1), + per_token_kl, + action_mask_forward_micro_batch, + loss_mask=loss_mask_forward_micro_batch, + ) - if not skip_update: - self.booster.backward(loss, self.optimizer) - loss = all_reduce_mean(loss, self.plugin) - kl = all_reduce_mean(kl.mean(), self.plugin) - # Calculate accumulate 
value. - mean_kl.append(kl.data) - mean_loss.append(loss.data) + if not skip_update: + self.booster.backward(loss, self.optimizer) + loss = all_reduce_mean(loss, self.plugin) + kl = all_reduce_mean(kl.mean(), self.plugin) + # Calculate accumulate value. + mean_kl.append(kl.data) + mean_loss.append(loss.data) reward = all_reduce_mean(reward.mean(), self.plugin) format_reward = all_reduce_mean(format_reward.mean(), self.plugin) diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py index bb719a13c405..2b6faaa4ab90 100644 --- a/applications/ColossalChat/rl_example.py +++ b/applications/ColossalChat/rl_example.py @@ -31,7 +31,13 @@ default=1, help="Number of prompts per device. Number of samples = tMbs * num of generation per prompt.", ) - parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=2, help="Number of samples per device.") + parser.add_argument( + "-tmbs", + "--train-microbatch-size", + type=int, + default=2, + help="Number of samples per device. PP micro batchsize when PP is activated.", + ) parser.add_argument("-b", "--backend", type=str, default="transformers", choices=["transformers", "vllm"]) parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple", "GRPO", "EvalGRPO"]) args = parser.parse_args() @@ -45,11 +51,7 @@ ray.init(address="local", namespace="ray-example") inference_model_config = dict(path=args.model) - train_model_config = dict( - path=args.model, - # use_flash_attention_2=True, - # use_cache=False - ) + train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False) generate_config = dict(top_k=50, top_p=0.75, temperature=0.9) if args.backend == "transformers": @@ -107,7 +109,7 @@ generate_config=generate_config, num_generations=args.num_generations, train_model_config=train_model_config, - plugin_config={}, + plugin_config={"pp_size": 2, "tp_size": 1, "microbatch_size": 2, "zero_stage": 0}, inference_backend=args.backend, master_addr="localhost", master_port=29505, diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 1684fd702e70..74349091b4d4 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -1411,8 +1411,10 @@ def execute_pipeline( ) # run with gradients accumulation - if model.require_grad_sync == False or ( - isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False + if ( + not torch.is_grad_enabled() + or model.require_grad_sync == False + or (isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False) ): return outputs diff --git a/colossalai/shardformer/modeling/qwen2.py b/colossalai/shardformer/modeling/qwen2.py index 71e3557fe214..27571309e453 100644 --- a/colossalai/shardformer/modeling/qwen2.py +++ b/colossalai/shardformer/modeling/qwen2.py @@ -284,6 +284,7 @@ def qwen2_for_causal_lm_forward( hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, shard_config: ShardConfig = None, + **kwargs, ): r""" Args: From 09a3173a4920ee64292b9c638611adaa3c2d427f Mon Sep 17 00:00:00 2001 From: YeAnbang Date: Fri, 4 Apr 2025 10:05:16 +0800 Subject: [PATCH 06/17] add pp support --- .gitignore | 1 + .../coati/distributed/consumer.py | 1 - .../coati/distributed/grpo_consumer.py | 311 +++++++++++------- applications/ColossalChat/rl_example.py | 9 +- 4 files changed, 200 insertions(+), 122 deletions(-) diff --git 
a/.gitignore b/.gitignore index 16f764c1b1ef..533450a7cce1 100644 --- a/.gitignore +++ b/.gitignore @@ -164,3 +164,4 @@ coverage.xml applications/ColossalChat/logs applications/ColossalChat/tests/logs applications/ColossalChat/wandb +applications/ColossalChat/model diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py index 4e1cd1f3179a..c6ae7be2daf9 100644 --- a/applications/ColossalChat/coati/distributed/consumer.py +++ b/applications/ColossalChat/coati/distributed/consumer.py @@ -94,7 +94,6 @@ def loop(self) -> None: i = 0 for _ in range(self.num_recv_per_update): # receive data from producers - for r in range(self.num_producers): print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}") self.buffer.extend( diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py index d05709febf52..fbc06edc2aa1 100644 --- a/applications/ColossalChat/coati/distributed/grpo_consumer.py +++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py @@ -94,9 +94,7 @@ def __init__( self.policy_loss_fn = PolicyLoss() self.global_step = 0 - if use_wandb and self.rank == 0: - name = f"{generate_config['backend']}_bs_{self.batch_size*self.world_size}_temp_{generate_config['temperature']:.01f}_top_p_{generate_config['top_p']:.02f}" - self.wandb_run = wandb.init(project="GRPO-V1-PP", sync_tensorboard=True, dir="./wandb", name=name) + self.use_wandb = use_wandb self.lr_scheduler = CosineAnnealingWarmupLR( optimizer=self.optimizer, @@ -107,10 +105,19 @@ def __init__( def setup(self): super().setup() + if self.use_wandb and ( + (not self.plugin.pp_size > 1 and self.rank == 0) + or (self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage()) + ): + # Initialize wandb. + name = f"{self.generate_config['backend']}_bs_{self.batch_size*self.dp_size}_temp_{self.generate_config['temperature']:.01f}_top_p_{self.generate_config['top_p']:.02f}" + self.wandb_run = wandb.init(project="GRPO-V1-PP", sync_tensorboard=True, dir="./wandb", name=name) + self.policy_model, self.optimizer, _, _, self.lr_scheduler = self.booster.boost( self.policy_model, self.optimizer, lr_scheduler=self.lr_scheduler ) self.reference_model, *_ = self.booster.boost(self.reference_model) + self.plugin.logger.set_level("ERROR") def step(self, step_idx: int, **kwargs) -> Optional[float]: """ @@ -168,72 +175,130 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]: ).repeat_interleave(self.num_generations, dim=0) ) mean_kl, mean_loss = [], [] - if self.plugin.pp_size > 1: - # Support training with PP. - data_iter = iter([data]) - - with torch.no_grad(): - reference_model_outputs = self.booster.execute_pipeline( - data_iter, - self.reference_model, - criterion=lambda outputs, inputs: outputs.logits.mean(), # dummy criterion - optimizer=None, - return_loss=False, - return_outputs=True, - ) - - if self.booster.plugin.stage_manager.is_last_stage(): - reference_model_logits = reference_model_outputs["outputs"]["logits"] - reference_action_log_probs = calc_action_log_probs( - reference_model_logits / self.generate_config["temperature"], - data["input_ids"], - num_action, - self.plugin.shard_config, - ) - else: - # Dummy reference logprobs for data iterator. 
- reference_action_log_probs = torch.zeros( - (old_action_log_probs.size(0), old_action_log_probs.size(1)) - ) - - data["reference_action_log_probs"] = reference_action_log_probs - data_iter = iter([data]) + for forward_micro_batch_start in range(0, data["input_ids"].size(0), forward_batch_size): + input_ids_forward_micro_batch = data["input_ids"][ + forward_micro_batch_start : forward_micro_batch_start + forward_batch_size + ] + attention_mask_forward_micro_batch = data["attention_mask"][ + forward_micro_batch_start : forward_micro_batch_start + forward_batch_size + ] + action_mask_forward_micro_batch = action_mask[ + forward_micro_batch_start : forward_micro_batch_start + forward_batch_size + ] + loss_mask_forward_micro_batch = ( + loss_mask[forward_micro_batch_start : forward_micro_batch_start + forward_batch_size] + if loss_mask is not None + else None + ) + advantages_forward_micro_batch = advantages[ + forward_micro_batch_start : forward_micro_batch_start + forward_batch_size + ] - def _criterion(outputs, inputs): - pass + if self.plugin.pp_size > 1: + # Support training with PP. - outputs = self.booster.execute_pipeline( - data_iter, - self.policy_model, - criterion=_criterion, - optimizer=self.optimizer, - return_loss=True, - ) - loss = outputs["loss"] + with torch.no_grad(): + reference_model_outputs = self.booster.execute_pipeline( + iter( + [ + { + "input_ids": input_ids_forward_micro_batch, + "attention_mask": attention_mask_forward_micro_batch, + } + ] + ), + self.reference_model, + criterion=lambda outputs, inputs: torch.tensor( + [0.0], device=action_mask.device + ), # dummy criterion + optimizer=None, + return_loss=False, + return_outputs=True, + ) + + if self.booster.plugin.stage_manager.is_last_stage(): + reference_model_logits = reference_model_outputs["outputs"]["logits"] + reference_action_log_probs = calc_action_log_probs( + reference_model_logits / self.generate_config["temperature"], + input_ids_forward_micro_batch, + num_action, + self.plugin.shard_config, + ) + else: + # Dummy reference logprobs for data iterator. 
+ reference_action_log_probs = None + + data_policy_forward = { + "input_ids": input_ids_forward_micro_batch, + "attention_mask": attention_mask_forward_micro_batch, + "action_mask": action_mask_forward_micro_batch, + "reference_action_log_probs": reference_action_log_probs, + "advantages": advantages_forward_micro_batch, + "loss_mask": loss_mask_forward_micro_batch, + "source": self.rank, + } - if self.booster.plugin.stage_manager.is_last_stage(): - loss = all_reduce_mean(loss, self.plugin) - mean_loss.append(loss.data) - else: - for forward_micro_batch_start in range(0, data["input_ids"].size(0), forward_batch_size): - input_ids_forward_micro_batch = data["input_ids"][ - forward_micro_batch_start : forward_micro_batch_start + forward_batch_size - ] - attention_mask_forward_micro_batch = data["attention_mask"][ - forward_micro_batch_start : forward_micro_batch_start + forward_batch_size - ] - action_mask_forward_micro_batch = action_mask[ - forward_micro_batch_start : forward_micro_batch_start + forward_batch_size - ] - loss_mask_forward_micro_batch = ( - loss_mask[forward_micro_batch_start : forward_micro_batch_start + forward_batch_size] - if loss_mask is not None - else None + def _criterion(outputs, inputs): + action_logits = outputs.logits + action_log_probs = calc_action_log_probs( + action_logits / self.generate_config["temperature"], + inputs["input_ids"], + num_action, + self.plugin.shard_config, + ) + per_token_kl = ( + torch.exp(inputs["reference_action_log_probs"] - action_log_probs) + - (inputs["reference_action_log_probs"] - action_log_probs) + - 1 + ) + decode_tokens_100 = self.tokenizer.batch_decode( + input_ids_forward_micro_batch[:, -num_action:], + skip_special_tokens=False, + ) + loss, skip_update, _ = self.policy_loss_fn( + action_log_probs, + action_log_probs, + inputs["advantages"].repeat_interleave(action_log_probs.size(-1), dim=-1), + per_token_kl, + inputs["action_mask"], + loss_mask=inputs["loss_mask"], + ) + return loss + + policy_model_outputs = self.booster.execute_pipeline( + iter([data_policy_forward]), + self.policy_model, + criterion=_criterion, + optimizer=self.optimizer, + return_loss=True, + return_outputs=True, ) - advantages_forward_micro_batch = advantages[ - forward_micro_batch_start : forward_micro_batch_start + forward_batch_size - ] + loss = policy_model_outputs["loss"] + + if self.booster.plugin.stage_manager.is_last_stage(): + # calculate kl + action_logits = policy_model_outputs["outputs"]["logits"] + action_log_probs = calc_action_log_probs( + action_logits / self.generate_config["temperature"], + input_ids_forward_micro_batch, + num_action, + self.plugin.shard_config, + ) + per_token_kl = ( + torch.exp(reference_action_log_probs - action_log_probs) + - (reference_action_log_probs - action_log_probs) + - 1 + ) + kl = torch.sum(per_token_kl * action_mask_forward_micro_batch, dim=-1) / torch.sum( + action_mask_forward_micro_batch, dim=-1 + ) + kl = all_reduce_mean(kl.mean(), self.plugin) + loss = all_reduce_mean(loss, self.plugin) + mean_loss.append(loss.data) + mean_kl.append(kl) + else: + policy_model_logits = self.policy_model( input_ids=input_ids_forward_micro_batch, attention_mask=attention_mask_forward_micro_batch, @@ -256,7 +321,6 @@ def _criterion(outputs, inputs): num_action, self.plugin.shard_config, ) - per_token_kl = ( torch.exp(reference_action_log_probs - action_log_probs) - (reference_action_log_probs - action_log_probs) @@ -282,64 +346,71 @@ def _criterion(outputs, inputs): # Calculate accumulate value. 
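
The per_token_kl expression used in both the pipeline-parallel and the plain branch above is the low-variance estimator of KL(policy || reference) used by GRPO: with delta = ref_logp - logp it evaluates exp(delta) - delta - 1, which is non-negative for every token. A self-contained sketch, assuming logp, ref_logp and action_mask all have shape [batch, num_actions]:

    import torch

    def masked_approx_kl(logp, ref_logp, action_mask):
        delta = ref_logp - logp
        per_token_kl = delta.exp() - delta - 1.0          # >= 0 elementwise
        # Average over response tokens only, mirroring the action_mask reduction above.
        return (per_token_kl * action_mask).sum(dim=-1) / action_mask.sum(dim=-1)
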
mean_kl.append(kl.data) mean_loss.append(loss.data) - - reward = all_reduce_mean(reward.mean(), self.plugin) - format_reward = all_reduce_mean(format_reward.mean(), self.plugin) - acc_reward = all_reduce_mean(acc_reward.mean(), self.plugin) - advantages = all_reduce_mean(advantages.mean(), self.plugin) - response_length = all_reduce_mean(response_length.mean(), self.plugin) - self.accum_loss.add_(sum(mean_loss) / len(mean_loss)) - self.accum_kl.add_(sum(mean_kl) / len(mean_kl)) - self.accum_reward.add_(reward.data) - self.accum_format_reward.add_(format_reward.data) - self.accum_acc_reward.add_(acc_reward.data) - self.accum_advantages.add_(advantages.data) - self.accum_response_length.add_(response_length.data) - self.accum_count += 1 + if not self.plugin.pp_size > 1 or ( + self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() + ): + reward = all_reduce_mean(reward.mean(), self.plugin) + format_reward = all_reduce_mean(format_reward.mean(), self.plugin) + acc_reward = all_reduce_mean(acc_reward.mean(), self.plugin) + advantages = all_reduce_mean(advantages.mean(), self.plugin) + response_length = all_reduce_mean(response_length.mean(), self.plugin) + self.accum_loss.add_(sum(mean_loss) / len(mean_loss)) + self.accum_kl.add_(sum(mean_kl) / len(mean_kl)) + self.accum_reward.add_(reward.data) + self.accum_format_reward.add_(format_reward.data) + self.accum_acc_reward.add_(acc_reward.data) + self.accum_advantages.add_(advantages.data) + self.accum_response_length.add_(response_length.data) + self.accum_count += 1 if need_update: self.optimizer.step() self.optimizer.zero_grad() - loss_scalar = self.accum_loss.item() - if self.rank == 0: - print( - "Loss:", - self.accum_loss.item() / self.accum_count, - "\nReward:", - self.accum_reward.item() / self.accum_count, - "\nFormat Reward:", - self.accum_format_reward.item() / self.accum_count, - "\nAcc Reward:", - self.accum_acc_reward.item() / self.accum_count, - "\nKL:", - self.accum_kl.item() / self.accum_count, - "\nAdvantages:", - self.accum_advantages.item() / self.accum_count, - "\nResponse Length:", - self.accum_response_length.item() / self.accum_count, - ) - self.wandb_run.log( - { - "metrics/reward": self.accum_reward.item() / self.accum_count, - "metrics/format_reward": self.accum_format_reward.item() / self.accum_count, - "metrics/acc_reward": self.accum_acc_reward.item() / self.accum_count, - "metrics/response_length": self.accum_response_length.item() / self.accum_count, - "train/loss": self.accum_loss.item() / self.accum_count, - "train/kl": self.accum_kl.item() / self.accum_count, - "train/advantages": self.accum_advantages.item() / self.accum_count, - "train/learning_rate": self.lr_scheduler.get_last_lr()[0], - "rollout/temperature": data["temperature"].cpu().numpy()[0][0], - } - ) - self.accum_loss.zero_() - self.accum_reward.zero_() - self.accum_acc_reward.zero_() - self.accum_format_reward.zero_() - self.accum_kl.zero_() - self.accum_advantages.zero_() - self.accum_response_length.zero_() - - self.accum_count = 0 - return loss_scalar + if not self.plugin.pp_size > 1 or ( + self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() + ): + loss_scalar = self.accum_loss.item() + if (not self.plugin.pp_size > 1 and self.rank == 0) or ( + self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() + ): + print( + "Loss:", + self.accum_loss.item() / self.accum_count, + "\nReward:", + self.accum_reward.item() / self.accum_count, + "\nFormat Reward:", + 
self.accum_format_reward.item() / self.accum_count, + "\nAcc Reward:", + self.accum_acc_reward.item() / self.accum_count, + "\nKL:", + self.accum_kl.item() / self.accum_count, + "\nAdvantages:", + self.accum_advantages.item() / self.accum_count, + "\nResponse Length:", + self.accum_response_length.item() / self.accum_count, + ) + self.wandb_run.log( + { + "metrics/reward": self.accum_reward.item() / self.accum_count, + "metrics/format_reward": self.accum_format_reward.item() / self.accum_count, + "metrics/acc_reward": self.accum_acc_reward.item() / self.accum_count, + "metrics/response_length": self.accum_response_length.item() / self.accum_count, + "train/loss": self.accum_loss.item() / self.accum_count, + "train/kl": self.accum_kl.item() / self.accum_count, + "train/advantages": self.accum_advantages.item() / self.accum_count, + "train/learning_rate": self.lr_scheduler.get_last_lr()[0], + "rollout/temperature": data["temperature"].cpu().numpy()[0][0], + } + ) + self.accum_loss.zero_() + self.accum_reward.zero_() + self.accum_acc_reward.zero_() + self.accum_format_reward.zero_() + self.accum_kl.zero_() + self.accum_advantages.zero_() + self.accum_response_length.zero_() + + self.accum_count = 0 + return loss_scalar def state_dict(self): self.policy_model._force_wait_all_gather() diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py index 2b6faaa4ab90..bf7a657e56c4 100644 --- a/applications/ColossalChat/rl_example.py +++ b/applications/ColossalChat/rl_example.py @@ -109,7 +109,14 @@ generate_config=generate_config, num_generations=args.num_generations, train_model_config=train_model_config, - plugin_config={"pp_size": 2, "tp_size": 1, "microbatch_size": 2, "zero_stage": 0}, + # plugin_config={}, # for zero + plugin_config={ + "pp_size": 2, + "tp_size": 1, + "microbatch_size": args.train_microbatch_size // 2, + "zero_stage": 0, + "max_norm": 1.0, + }, # for pp inference_backend=args.backend, master_addr="localhost", master_port=29505, From 061d8cb3b6485787e7c2f75868342b3a243c44e1 Mon Sep 17 00:00:00 2001 From: YeAnbang Date: Fri, 4 Apr 2025 10:11:11 +0800 Subject: [PATCH 07/17] remove unused code --- .../ColossalChat/coati/distributed/grpo_consumer.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py index fbc06edc2aa1..f4174261ad2c 100644 --- a/applications/ColossalChat/coati/distributed/grpo_consumer.py +++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py @@ -252,10 +252,6 @@ def _criterion(outputs, inputs): - (inputs["reference_action_log_probs"] - action_log_probs) - 1 ) - decode_tokens_100 = self.tokenizer.batch_decode( - input_ids_forward_micro_batch[:, -num_action:], - skip_special_tokens=False, - ) loss, skip_update, _ = self.policy_loss_fn( action_log_probs, action_log_probs, @@ -277,7 +273,7 @@ def _criterion(outputs, inputs): loss = policy_model_outputs["loss"] if self.booster.plugin.stage_manager.is_last_stage(): - # calculate kl + # calculate kl, as we cannot do this inside callback, kl needs be calculate again action_logits = policy_model_outputs["outputs"]["logits"] action_log_probs = calc_action_log_probs( action_logits / self.generate_config["temperature"], From a40d82f6293113cc73ebac7c2d14f49caf5e41a4 Mon Sep 17 00:00:00 2001 From: YeAnbang Date: Wed, 9 Apr 2025 12:53:40 +0800 Subject: [PATCH 08/17] address conversation --- .../coati/distributed/grpo_consumer.py | 31 
+++++++------------ .../ColossalChat/coati/distributed/launch.py | 2 ++ applications/ColossalChat/rl_example.py | 22 +++++++++---- 3 files changed, 30 insertions(+), 25 deletions(-) diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py index f4174261ad2c..a282439cbd33 100644 --- a/applications/ColossalChat/coati/distributed/grpo_consumer.py +++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py @@ -39,6 +39,7 @@ def __init__( use_wandb=True, generate_config=None, training_config={}, + project_name=None, ): super().__init__( num_producers, @@ -69,6 +70,7 @@ def __init__( self.accum_count = 0 self.generate_config = generate_config self.training_config = training_config + self.project_name = project_name # Reference model is initialized from policy model. self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config) @@ -111,7 +113,7 @@ def setup(self): ): # Initialize wandb. name = f"{self.generate_config['backend']}_bs_{self.batch_size*self.dp_size}_temp_{self.generate_config['temperature']:.01f}_top_p_{self.generate_config['top_p']:.02f}" - self.wandb_run = wandb.init(project="GRPO-V1-PP", sync_tensorboard=True, dir="./wandb", name=name) + self.wandb_run = wandb.init(project=self.project_name, sync_tensorboard=True, dir="./wandb", name=name) self.policy_model, self.optimizer, _, _, self.lr_scheduler = self.booster.boost( self.policy_model, self.optimizer, lr_scheduler=self.lr_scheduler @@ -239,6 +241,8 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]: "source": self.rank, } + kl = [] + def _criterion(outputs, inputs): action_logits = outputs.logits action_log_probs = calc_action_log_probs( @@ -252,6 +256,10 @@ def _criterion(outputs, inputs): - (inputs["reference_action_log_probs"] - action_log_probs) - 1 ) + appox_kl = torch.sum(per_token_kl * inputs["action_mask"], dim=-1) / torch.sum( + inputs["action_mask"], dim=-1 + ) + kl.append(appox_kl.mean()) loss, skip_update, _ = self.policy_loss_fn( action_log_probs, action_log_probs, @@ -273,26 +281,11 @@ def _criterion(outputs, inputs): loss = policy_model_outputs["loss"] if self.booster.plugin.stage_manager.is_last_stage(): - # calculate kl, as we cannot do this inside callback, kl needs be calculate again - action_logits = policy_model_outputs["outputs"]["logits"] - action_log_probs = calc_action_log_probs( - action_logits / self.generate_config["temperature"], - input_ids_forward_micro_batch, - num_action, - self.plugin.shard_config, - ) - per_token_kl = ( - torch.exp(reference_action_log_probs - action_log_probs) - - (reference_action_log_probs - action_log_probs) - - 1 - ) - kl = torch.sum(per_token_kl * action_mask_forward_micro_batch, dim=-1) / torch.sum( - action_mask_forward_micro_batch, dim=-1 - ) - kl = all_reduce_mean(kl.mean(), self.plugin) + if len(kl) > 0: + kl = all_reduce_mean(torch.mean(torch.stack(kl)).to(loss.device), self.plugin) + mean_kl.append(kl) loss = all_reduce_mean(loss, self.plugin) mean_loss.append(loss.data) - mean_kl.append(kl) else: policy_model_logits = self.policy_model( diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py index ba5d3a9d4fd8..699d90a8cdff 100644 --- a/applications/ColossalChat/coati/distributed/launch.py +++ b/applications/ColossalChat/coati/distributed/launch.py @@ -47,6 +47,7 @@ def launch_distributed( master_addr: str = "localhost", master_port: int = 29500, core_algo: str = "GRPO", + 
project_name: Optional[str] = None, ): if core_algo not in ALGO_MAP: @@ -108,6 +109,7 @@ def launch_distributed( "train_microbatch_size": train_microbatch_size, }, num_generations=num_generations, + project_name=project_name, ) procs.append(consumer) ray.get([p.setup.remote() for p in procs]) diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py index bf7a657e56c4..f87f12ed23ca 100644 --- a/applications/ColossalChat/rl_example.py +++ b/applications/ColossalChat/rl_example.py @@ -11,32 +11,41 @@ parser.add_argument("-t", "--num-trainers", type=int, default=2) parser.add_argument("-i", "--num-inferencer", type=int, default=2) parser.add_argument("-g", "--num-generations", type=int, default=8, help="Number of generations per prompt.") + parser.add_argument("-p", "--project", type=str, default="GRPO", help="Project name.") parser.add_argument( - "-ibs", "--inference-batch-size", type=int, default=64, help="Number of prompts to generate per step." + "-ibs", + "--inference-batch-size", + type=int, + default=64, + help="Number of prompts to generate per inference step. It should be divisible by tbs, and the weights on the inference backend will be synced every ibs/tbs training steps of the policy model.", ) parser.add_argument( "-imbs", "--inference-microbatch-size", type=int, default=8, - help="Number of prompts to send from the producer to the consumer.", + help="Effective batch size for the inference backend to run generation. Please select based on memory constraint.", ) parser.add_argument( - "-tbs", "--train-batch-size", type=int, default=32, help="Number of prompts to update policy model." + "-tbs", + "--train-batch-size", + type=int, + default=32, + help="Number of unique prompts to update policy model per step per dp group. Gradient is accumulated across tbs * dp_size unique prompts, equivalently tbs * g * dp_size samples", ) parser.add_argument( "-tMbs", "--train-minibatch-size", type=int, default=1, - help="Number of prompts per device. Number of samples = tMbs * num of generation per prompt.", + help="Number of unique prompts in each training batch per dp group. The inference backend must generate tMbs * g * dp_size samples before forwarding. Satisfy tMbs * g >= tmbs", ) parser.add_argument( "-tmbs", "--train-microbatch-size", type=int, default=2, - help="Number of samples per device. PP micro batchsize when PP is activated.", + help="Effective batch size per dp group for forwarding and backwarding. 
Please select based on the availiable memory.", ) parser.add_argument("-b", "--backend", type=str, default="transformers", choices=["transformers", "vllm"]) parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple", "GRPO", "EvalGRPO"]) @@ -119,6 +128,7 @@ }, # for pp inference_backend=args.backend, master_addr="localhost", - master_port=29505, + master_port=29506, core_algo=args.algo, + project_name=args.project, ) From 1ea3b72c2299e7c6cf82a0565fbe8ee37a9340b3 Mon Sep 17 00:00:00 2001 From: YeAnbang Date: Wed, 9 Apr 2025 17:11:55 +0800 Subject: [PATCH 09/17] fix memory leakage support tp+pp --- .../ColossalChat/coati/distributed/consumer.py | 4 ++++ .../coati/distributed/grpo_consumer.py | 14 +++++++------- applications/ColossalChat/rl_example.py | 2 +- 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py index c6ae7be2daf9..6372e2c3a898 100644 --- a/applications/ColossalChat/coati/distributed/consumer.py +++ b/applications/ColossalChat/coati/distributed/consumer.py @@ -72,6 +72,8 @@ def setup(self) -> None: self.plugin = HybridParallelPlugin(**plugin_config) self.booster = Booster(plugin=self.plugin) self.dp_rank = dist.get_rank(self.plugin.dp_group) + self.tp_rank = dist.get_rank(self.plugin.tp_group) + self.dp_size = dist.get_world_size(self.plugin.dp_group) self.buffer = [] @@ -132,6 +134,8 @@ def loop(self) -> None: ray_broadcast_tensor_dict( state_dict, src=self.num_producers, device=self.device, group_name="sync_model" ) + del state_dict + torch.cuda.empty_cache() @ray.remote diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py index a282439cbd33..68a01528e2d1 100644 --- a/applications/ColossalChat/coati/distributed/grpo_consumer.py +++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py @@ -109,7 +109,7 @@ def setup(self): super().setup() if self.use_wandb and ( (not self.plugin.pp_size > 1 and self.rank == 0) - or (self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage()) + or (self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0) ): # Initialize wandb. 
name = f"{self.generate_config['backend']}_bs_{self.batch_size*self.dp_size}_temp_{self.generate_config['temperature']:.01f}_top_p_{self.generate_config['top_p']:.02f}" @@ -282,10 +282,10 @@ def _criterion(outputs, inputs): if self.booster.plugin.stage_manager.is_last_stage(): if len(kl) > 0: - kl = all_reduce_mean(torch.mean(torch.stack(kl)).to(loss.device), self.plugin) + kl = all_reduce_mean(torch.mean(torch.stack(kl)).to(loss.device), self.plugin).data mean_kl.append(kl) - loss = all_reduce_mean(loss, self.plugin) - mean_loss.append(loss.data) + mean_loss.append(all_reduce_mean(loss, self.plugin).data) + torch.cuda.empty_cache() else: policy_model_logits = self.policy_model( @@ -336,7 +336,7 @@ def _criterion(outputs, inputs): mean_kl.append(kl.data) mean_loss.append(loss.data) if not self.plugin.pp_size > 1 or ( - self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() + self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0 ): reward = all_reduce_mean(reward.mean(), self.plugin) format_reward = all_reduce_mean(format_reward.mean(), self.plugin) @@ -355,11 +355,11 @@ def _criterion(outputs, inputs): self.optimizer.step() self.optimizer.zero_grad() if not self.plugin.pp_size > 1 or ( - self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() + self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0 ): loss_scalar = self.accum_loss.item() if (not self.plugin.pp_size > 1 and self.rank == 0) or ( - self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() + self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0 ): print( "Loss:", diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py index f87f12ed23ca..6c43ccd1960f 100644 --- a/applications/ColossalChat/rl_example.py +++ b/applications/ColossalChat/rl_example.py @@ -121,7 +121,7 @@ # plugin_config={}, # for zero plugin_config={ "pp_size": 2, - "tp_size": 1, + "tp_size": 2, "microbatch_size": args.train_microbatch_size // 2, "zero_stage": 0, "max_norm": 1.0, From f7e532511ce766bb1f5661c397fb468c3bad5523 Mon Sep 17 00:00:00 2001 From: YeAnbang Date: Thu, 10 Apr 2025 10:22:28 +0800 Subject: [PATCH 10/17] move empty cache --- applications/ColossalChat/coati/distributed/grpo_consumer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py index 68a01528e2d1..e23254d1b07a 100644 --- a/applications/ColossalChat/coati/distributed/grpo_consumer.py +++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py @@ -285,7 +285,6 @@ def _criterion(outputs, inputs): kl = all_reduce_mean(torch.mean(torch.stack(kl)).to(loss.device), self.plugin).data mean_kl.append(kl) mean_loss.append(all_reduce_mean(loss, self.plugin).data) - torch.cuda.empty_cache() else: policy_model_logits = self.policy_model( From 1723a0286023132437001773865f1410e9f5e4a0 Mon Sep 17 00:00:00 2001 From: YeAnbang Date: Thu, 10 Apr 2025 10:22:43 +0800 Subject: [PATCH 11/17] move empty cache --- applications/ColossalChat/coati/distributed/consumer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py index 6372e2c3a898..79beb2a2dba6 100644 --- a/applications/ColossalChat/coati/distributed/consumer.py +++ 
b/applications/ColossalChat/coati/distributed/consumer.py @@ -129,6 +129,7 @@ def loop(self) -> None: if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1: print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}") + torch.cuda.empty_cache() state_dict = self.state_dict() if self.rank == 0: ray_broadcast_tensor_dict( From 6e71e2a3cecf87993fc5dcdf409a155883f33afb Mon Sep 17 00:00:00 2001 From: YeAnbang Date: Tue, 15 Apr 2025 18:28:35 +0800 Subject: [PATCH 12/17] add DAPO support --- .../coati/distributed/grpo_consumer.py | 312 +++++++++++------- .../ColossalChat/coati/distributed/launch.py | 7 +- .../ColossalChat/coati/distributed/loss.py | 49 ++- .../coati/distributed/producer.py | 4 +- .../coati/distributed/reward/reward_fn.py | 21 +- .../distributed/reward/verifiable_reward.py | 2 + .../ColossalChat/coati/distributed/utils.py | 17 + .../ColossalChat/coati/trainer/utils.py | 22 +- applications/ColossalChat/rl_example.py | 44 ++- 9 files changed, 322 insertions(+), 156 deletions(-) diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py index e23254d1b07a..5cd97b20a26a 100644 --- a/applications/ColossalChat/coati/distributed/grpo_consumer.py +++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py @@ -12,7 +12,7 @@ from coati.distributed.reward.reward_fn import math_reward_fn from coati.distributed.reward.verifiable_reward import VerifiableReward from coati.distributed.utils import calc_action_log_probs -from coati.trainer.utils import all_reduce_mean +from coati.trainer.utils import all_reduce_mean, all_reduce_sum from transformers import AutoModelForCausalLM, AutoTokenizer from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR @@ -38,7 +38,7 @@ def __init__( num_generations=8, use_wandb=True, generate_config=None, - training_config={}, + grpo_config={}, project_name=None, ): super().__init__( @@ -59,7 +59,7 @@ def __init__( self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config) self.policy_model.train() self.policy_model.gradient_checkpointing_enable() - self.optimizer = HybridAdam(self.policy_model.parameters(), lr=training_config.get("lr", 1e-6)) + self.optimizer = HybridAdam(self.policy_model.parameters(), lr=grpo_config.get("lr", 1e-6)) self.accum_loss = torch.zeros(1, device=self.device) self.accum_reward = torch.zeros(1, device=self.device) self.accum_kl = torch.zeros(1, device=self.device) @@ -69,8 +69,9 @@ def __init__( self.accum_response_length = torch.zeros(1, device=self.device) self.accum_count = 0 self.generate_config = generate_config - self.training_config = training_config + self.grpo_config = grpo_config self.project_name = project_name + self.effective_sample_count = 0 # Reference model is initialized from policy model. self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config) @@ -79,10 +80,21 @@ def __init__( self.tokenizer = AutoTokenizer.from_pretrained(path) self.pad_token_id = self.tokenizer.pad_token_id self.num_generations = num_generations - self.filter_range = training_config.get("filter_range", None) + self.filter_range = grpo_config.get("filter_range", None) if self.filter_range is not None: assert len(self.filter_range) == 2, "Filter range should have 2 values." 
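
The two sample filters configured in this constructor (filter_range, carried over from earlier patches, and the new filter_truncated_response) are applied to loss_mask inside step() further down in this same patch. Condensed into one stand-alone sketch; the (0.05, 9.0) range and the 512-token limit are illustrative values taken from earlier patches, not fixed defaults of this one, and the real code uses a separate reward tensor without the length penalty for the lower bound.

    import torch

    def build_loss_mask(group_reward_mean, action_mask, filter_range=(0.05, 9.0), max_new_tokens=512):
        # group_reward_mean: per-sample copy of the mean reward of its prompt group.
        # Drop groups whose mean reward falls outside the range: all-wrong (or all-right)
        # groups yield near-zero advantages and contribute no learning signal.
        keep = (group_reward_mean > filter_range[0]) & (group_reward_mean < filter_range[1])
        # Drop responses that ran into the generation limit; their rewards are unreliable.
        if action_mask.size(1) == max_new_tokens:
            keep = keep & (~action_mask[:, -1].bool())
        return keep
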
+ self.filter_truncated_response = grpo_config.get("filter_truncated_response", False) + if self.filter_truncated_response: + self.max_length = 0 + if "max_tokens" in self.generate_config: + self.max_length = self.generate_config["max_tokens"] + elif "max_new_tokens" in self.generate_config: + self.max_length = self.generate_config["max_new_tokens"] + else: + raise ValueError( + "either max_tokens (vllm) or max_new_tokens (transformers) must be set in generate_config." + ) # Initialize verifiable reward. response_format_tags = { "think_start": {"text": "", "num_occur": 1}, @@ -90,11 +102,20 @@ def __init__( "answer_start": {"text": "", "num_occur": 1}, "answer_end": {"text": "", "num_occur": 1}, } + reward_model_kwargs = { + k: v for k, v in grpo_config.items() if k in ["soft_over_length_punishment", "max_length", "cache_length"] + } self.reward_model = VerifiableReward( - reward_fns=[math_reward_fn], tokenizer=self.tokenizer, tags=response_format_tags + reward_fns=[math_reward_fn], tokenizer=self.tokenizer, tags=response_format_tags, **reward_model_kwargs ) - self.policy_loss_fn = PolicyLoss() + self.policy_loss_fn = PolicyLoss( + clip_eps_low=grpo_config.get("clip_eps_low", 0.2), + clip_eps_high=grpo_config.get("clip_eps_high", 0.2), + skip_threshold=grpo_config.get("skip_threshold", 20.0), + beta=grpo_config.get("beta", 0.01), + loss_variation=grpo_config.get("loss_variation", "sample_level"), + ) self.global_step = 0 self.use_wandb = use_wandb @@ -102,7 +123,7 @@ def __init__( optimizer=self.optimizer, total_steps=min(self.num_episodes, 4) * self.num_update_per_episode, warmup_steps=0, - eta_min=0.1 * training_config.get("lr", 1e-6), + eta_min=0.1 * grpo_config.get("lr", 1e-6), ) def setup(self): @@ -141,9 +162,65 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]: num_action = action_mask.shape[1] old_action_log_probs = data["action_log_probs"] response_length = torch.sum(action_mask, dim=1).to(torch.float32) - forward_batch_size = self.training_config.get("train_microbatch_size", data["input_ids"].size(0)) + forward_batch_size = self.grpo_config.get("train_microbatch_size", data["input_ids"].size(0)) + + reward_group = self.reward_model( + int(step_idx / self.num_microbatches), + data["input_ids"], + gt_answer=data["gt_answer"], + response_idx=data["response_idx"], + ) + + reward = torch.tensor([value[0] for value in reward_group]).to(data["input_ids"].device) + format_reward = torch.tensor([value[1] for value in reward_group]).to(data["input_ids"].device) + acc_reward = torch.tensor([value[2] for value in reward_group]).to(data["input_ids"].device) + + # [batch_size, num_generations] + + group_reward = reward.view(-1, self.num_generations) + reward_mean = group_reward.mean(dim=1) + # [batch_size x num_generations] + reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0) + + reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0) + # [batch_size x num_generations] + advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1) + # filter out the reward that is too high (all sample gets full score) or too low (all sample gets 0 score), + reward_mean_no_length_penalty = ( + (format_reward + acc_reward) + .view(-1, self.num_generations) + .mean(dim=1) + .repeat_interleave(self.num_generations, dim=0) + ) + loss_mask = ( + torch.ones(action_mask.size(0), device=action_mask.device).bool() + if self.filter_range is None + else torch.logical_and( + reward_mean_no_length_penalty > self.filter_range[0], reward_mean < 
self.filter_range[1] + ) + ) + # filter out overlength samples + if self.filter_truncated_response and action_mask.size(1) == self.max_length: + loss_mask = torch.logical_and( + loss_mask, + action_mask[:, -1] == False, + ) + # for i in range(loss_mask.size(0)): + # if loss_mask[i] == False: + # print(data["input_ids"].size(), data["input_ids"][i], action_mask[i], "mean reward", reward_mean_no_length_penalty.size(), reward_mean_no_length_penalty[i]) + + effective_samples = all_reduce_sum(torch.sum(loss_mask), self.plugin) + self.effective_sample_count += effective_samples.item() + + mean_kl, mean_loss = [], [] + + # update gradient only if at least 0.7*batch_size*num_generation valid samples are collected in case a lot of samples are invalid and got filtered out. + # balance between efficiency and accuracy + need_update = self.effective_sample_count >= self.batch_size * self.dp_size * self.num_generations * 0.75 + if need_update: + print(f"***** Update gradient based on {self.effective_sample_count} valid samples *****") + self.effective_sample_count = 0 - need_update = (step_idx + 1) % self.num_microbatches == 0 # Gradient must be synchronized if zero2 is enabled. https://github.com/hpcaitech/ColossalAI/blob/44d4053fec005fe0b06b6bc755fdc962463145df/colossalai/booster/plugin/hybrid_parallel_plugin.py#L1500 ctx = ( nullcontext() @@ -151,32 +228,6 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]: else self.booster.no_sync(self.policy_model, self.optimizer) ) with ctx: - reward_group = self.reward_model( - data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"] - ) - - reward = torch.tensor([value[0] for value in reward_group]).to(data["input_ids"].device) - format_reward = torch.tensor([value[1] for value in reward_group]).to(data["input_ids"].device) - acc_reward = torch.tensor([value[2] for value in reward_group]).to(data["input_ids"].device) - - # [batch_size, num_generations] - - group_reward = reward.view(-1, self.num_generations) - reward_mean = group_reward.mean(dim=1) - # [batch_size x num_generations] - reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0) - reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0) - # [batch_size x num_generations] - advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1) - # filter out the reward that is too high (all sample gets full score) or too low (all sample gets 0 score), - loss_mask = ( - None - if self.filter_range is None - else torch.logical_and( - reward_mean > self.filter_range[0], reward_mean < self.filter_range[1] - ).repeat_interleave(self.num_generations, dim=0) - ) - mean_kl, mean_loss = [], [] for forward_micro_batch_start in range(0, data["input_ids"].size(0), forward_batch_size): input_ids_forward_micro_batch = data["input_ids"][ @@ -199,47 +250,50 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]: if self.plugin.pp_size > 1: # Support training with PP. 
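
The advantages computed a few lines above are the group-relative baseline that gives GRPO its name: each prompt is sampled num_generations times and every sample's reward is standardized against its own group before being broadcast over the response tokens. As a stand-alone sketch of that computation:

    import torch

    def group_relative_advantages(rewards, num_generations, eps=1e-4):
        # rewards: flat tensor of shape [batch_size * num_generations].
        group = rewards.view(-1, num_generations)
        mean = group.mean(dim=1, keepdim=True)
        std = group.std(dim=1, keepdim=True)
        advantages = (group - mean) / (std + eps)     # standardize within each prompt group
        return advantages.view(-1, 1)                 # one scalar advantage per sample
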
- - with torch.no_grad(): - reference_model_outputs = self.booster.execute_pipeline( - iter( - [ - { - "input_ids": input_ids_forward_micro_batch, - "attention_mask": attention_mask_forward_micro_batch, - } - ] - ), - self.reference_model, - criterion=lambda outputs, inputs: torch.tensor( - [0.0], device=action_mask.device - ), # dummy criterion - optimizer=None, - return_loss=False, - return_outputs=True, - ) - - if self.booster.plugin.stage_manager.is_last_stage(): - reference_model_logits = reference_model_outputs["outputs"]["logits"] - reference_action_log_probs = calc_action_log_probs( - reference_model_logits / self.generate_config["temperature"], - input_ids_forward_micro_batch, - num_action, - self.plugin.shard_config, - ) + if self.policy_loss_fn.beta > 0: + with torch.no_grad(): + reference_model_outputs = self.booster.execute_pipeline( + iter( + [ + { + "input_ids": input_ids_forward_micro_batch, + "attention_mask": attention_mask_forward_micro_batch, + } + ] + ), + self.reference_model, + criterion=lambda outputs, inputs: torch.tensor( + [0.0], device=action_mask.device + ), # dummy criterion + optimizer=None, + return_loss=False, + return_outputs=True, + ) + + if self.booster.plugin.stage_manager.is_last_stage(): + reference_model_logits = reference_model_outputs["outputs"]["logits"] + reference_action_log_probs = calc_action_log_probs( + reference_model_logits / self.generate_config["temperature"], + input_ids_forward_micro_batch, + num_action, + self.plugin.shard_config, + ) + else: + # Dummy reference logprobs for data iterator. + reference_action_log_probs = None else: - # Dummy reference logprobs for data iterator. reference_action_log_probs = None data_policy_forward = { "input_ids": input_ids_forward_micro_batch, "attention_mask": attention_mask_forward_micro_batch, "action_mask": action_mask_forward_micro_batch, - "reference_action_log_probs": reference_action_log_probs, "advantages": advantages_forward_micro_batch, "loss_mask": loss_mask_forward_micro_batch, "source": self.rank, } + if reference_action_log_probs is not None: + data_policy_forward["reference_action_log_probs"] = reference_action_log_probs kl = [] @@ -251,15 +305,20 @@ def _criterion(outputs, inputs): num_action, self.plugin.shard_config, ) - per_token_kl = ( - torch.exp(inputs["reference_action_log_probs"] - action_log_probs) - - (inputs["reference_action_log_probs"] - action_log_probs) - - 1 - ) - appox_kl = torch.sum(per_token_kl * inputs["action_mask"], dim=-1) / torch.sum( - inputs["action_mask"], dim=-1 - ) - kl.append(appox_kl.mean()) + if "reference_action_log_probs" in inputs: + per_token_kl = ( + torch.exp(inputs["reference_action_log_probs"] - action_log_probs) + - (inputs["reference_action_log_probs"] - action_log_probs) + - 1 + ) + appox_kl = torch.sum(per_token_kl * inputs["action_mask"], dim=-1) / torch.sum( + inputs["action_mask"], dim=-1 + ) + kl.append(appox_kl.mean()) + else: + per_token_kl = 0.0 + kl.append(0.0) + loss, skip_update, _ = self.policy_loss_fn( action_log_probs, action_log_probs, @@ -298,25 +357,29 @@ def _criterion(outputs, inputs): self.plugin.shard_config, ) - with torch.no_grad(): - reference_model_logits = self.reference_model( - input_ids=input_ids_forward_micro_batch, - attention_mask=attention_mask_forward_micro_batch, - ).logits - reference_action_log_probs = calc_action_log_probs( - reference_model_logits / self.generate_config["temperature"], - input_ids_forward_micro_batch, - num_action, - self.plugin.shard_config, - ) - per_token_kl = ( - 
torch.exp(reference_action_log_probs - action_log_probs) - - (reference_action_log_probs - action_log_probs) - - 1 - ) - kl = torch.sum(per_token_kl * action_mask_forward_micro_batch, dim=-1) / torch.sum( - action_mask_forward_micro_batch, dim=-1 - ) + if self.policy_loss_fn.beta > 0: + with torch.no_grad(): + reference_model_logits = self.reference_model( + input_ids=input_ids_forward_micro_batch, + attention_mask=attention_mask_forward_micro_batch, + ).logits + reference_action_log_probs = calc_action_log_probs( + reference_model_logits / self.generate_config["temperature"], + input_ids_forward_micro_batch, + num_action, + self.plugin.shard_config, + ) + per_token_kl = ( + torch.exp(reference_action_log_probs - action_log_probs) + - (reference_action_log_probs - action_log_probs) + - 1 + ) + kl = torch.sum(per_token_kl * action_mask_forward_micro_batch, dim=-1) / torch.sum( + action_mask_forward_micro_batch, dim=-1 + ) + else: + per_token_kl = 0.0 + kl = None loss, skip_update, _ = self.policy_loss_fn( action_log_probs, @@ -330,9 +393,10 @@ def _criterion(outputs, inputs): if not skip_update: self.booster.backward(loss, self.optimizer) loss = all_reduce_mean(loss, self.plugin) - kl = all_reduce_mean(kl.mean(), self.plugin) # Calculate accumulate value. - mean_kl.append(kl.data) + if kl is not None: + kl = all_reduce_mean(kl.mean(), self.plugin) + mean_kl.append(kl.data) mean_loss.append(loss.data) if not self.plugin.pp_size > 1 or ( self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0 @@ -343,7 +407,8 @@ def _criterion(outputs, inputs): advantages = all_reduce_mean(advantages.mean(), self.plugin) response_length = all_reduce_mean(response_length.mean(), self.plugin) self.accum_loss.add_(sum(mean_loss) / len(mean_loss)) - self.accum_kl.add_(sum(mean_kl) / len(mean_kl)) + if self.policy_loss_fn.beta > 0: + self.accum_kl.add_(sum(mean_kl) / len(mean_kl)) self.accum_reward.add_(reward.data) self.accum_format_reward.add_(format_reward.data) self.accum_acc_reward.add_(acc_reward.data) @@ -360,35 +425,32 @@ def _criterion(outputs, inputs): if (not self.plugin.pp_size > 1 and self.rank == 0) or ( self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0 ): - print( - "Loss:", - self.accum_loss.item() / self.accum_count, - "\nReward:", - self.accum_reward.item() / self.accum_count, - "\nFormat Reward:", - self.accum_format_reward.item() / self.accum_count, - "\nAcc Reward:", - self.accum_acc_reward.item() / self.accum_count, - "\nKL:", - self.accum_kl.item() / self.accum_count, - "\nAdvantages:", - self.accum_advantages.item() / self.accum_count, - "\nResponse Length:", - self.accum_response_length.item() / self.accum_count, - ) - self.wandb_run.log( - { - "metrics/reward": self.accum_reward.item() / self.accum_count, - "metrics/format_reward": self.accum_format_reward.item() / self.accum_count, - "metrics/acc_reward": self.accum_acc_reward.item() / self.accum_count, - "metrics/response_length": self.accum_response_length.item() / self.accum_count, - "train/loss": self.accum_loss.item() / self.accum_count, - "train/kl": self.accum_kl.item() / self.accum_count, - "train/advantages": self.accum_advantages.item() / self.accum_count, - "train/learning_rate": self.lr_scheduler.get_last_lr()[0], - "rollout/temperature": data["temperature"].cpu().numpy()[0][0], - } + to_log_msg = ( + f"Loss: {self.accum_loss.item() / self.accum_count:.4f} \ + Reward: {self.accum_reward.item() / self.accum_count:.4f} \ + Format 
Reward: {self.accum_format_reward.item() / self.accum_count:.4f} \ + Acc Reward: {self.accum_acc_reward.item() / self.accum_count:.4f} \ + Advantages: {self.accum_advantages.item() / self.accum_count:.4f} \ + Response Length: {self.accum_response_length.item() / self.accum_count:.4f}" + + f" KL: {self.accum_kl.item() / self.accum_count:.4f}" + if self.policy_loss_fn.beta > 0 + else "" ) + print(to_log_msg) + metrics = { + "metrics/reward": self.accum_reward.item() / self.accum_count, + "metrics/format_reward": self.accum_format_reward.item() / self.accum_count, + "metrics/acc_reward": self.accum_acc_reward.item() / self.accum_count, + "metrics/response_length": self.accum_response_length.item() / self.accum_count, + "train/loss": self.accum_loss.item() / self.accum_count, + "train/advantages": self.accum_advantages.item() / self.accum_count, + "train/learning_rate": self.lr_scheduler.get_last_lr()[0], + "rollout/temperature": data["temperature"].cpu().numpy()[0][0], + } + if self.policy_loss_fn.beta > 0: + metrics["train/kl"] = self.accum_kl.item() / self.accum_count + + self.wandb_run.log(metrics) self.accum_loss.zero_() self.accum_reward.zero_() self.accum_acc_reward.zero_() diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py index 699d90a8cdff..8936752d2a79 100644 --- a/applications/ColossalChat/coati/distributed/launch.py +++ b/applications/ColossalChat/coati/distributed/launch.py @@ -40,6 +40,7 @@ def launch_distributed( inference_model_config: Dict[str, Any], generate_config: Dict[str, Any], train_model_config: Dict[str, Any], + grpo_config: Dict[str, Any], plugin_config: Dict[str, Any], tokenizer_config: Optional[Dict[str, Any]] = None, inference_backend: str = "transformers", @@ -103,11 +104,7 @@ def launch_distributed( plugin_config=plugin_config, microbatch_size=train_minibatch_size, generate_config=generate_config_consumer, - training_config={ - "filter_range": [0.05, 9.0], - "lr": 1e-6, - "train_microbatch_size": train_microbatch_size, - }, + grpo_config=grpo_config, num_generations=num_generations, project_name=project_name, ) diff --git a/applications/ColossalChat/coati/distributed/loss.py b/applications/ColossalChat/coati/distributed/loss.py index 90ad09736281..bdbe64a2a749 100644 --- a/applications/ColossalChat/coati/distributed/loss.py +++ b/applications/ColossalChat/coati/distributed/loss.py @@ -2,7 +2,7 @@ import torch import torch.nn as nn -from coati.distributed.utils import masked_mean +from coati.distributed.utils import masked_mean, masked_sum class PolicyLoss(nn.Module): @@ -10,11 +10,21 @@ class PolicyLoss(nn.Module): Policy Loss for PPO """ - def __init__(self, clip_eps: float = 0.2, skip_threshold: float = 20.0, beta: float = 0.01) -> None: + def __init__( + self, + clip_eps_low: float = 0.2, + clip_eps_high: float = 0.2, + skip_threshold: float = 20.0, + beta: float = 0.01, + loss_variation: str = "sample_level", + ) -> None: super().__init__() - self.clip_eps = clip_eps + self.clip_eps_low = clip_eps_low + self.clip_eps_high = clip_eps_high self.skip_threshold = skip_threshold self.beta = beta + self.loss_variation = loss_variation + assert loss_variation in ["sample_level", "token_level"], f"Unsupported loss variation: {loss_variation}" def forward( self, @@ -32,14 +42,31 @@ def forward( ratio = ((log_probs - log_probs.detach()) * action_mask).exp() surr1 = ratio * advantages - surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages + surr2 = ratio.clamp(1 - 
self.clip_eps_low, 1 + self.clip_eps_high) * advantages + if self.beta <= 0: + # skip kl term if kl coefficient is zero + per_token_kl = 0.0 loss = -torch.min(surr1, surr2) + self.beta * per_token_kl - if action_mask is not None: - loss = masked_mean(loss, action_mask) - else: - loss = loss.mean(dim=1) - if loss_mask is not None: - loss = loss * loss_mask - loss = loss.mean() + if self.loss_variation == "sample_level": + if action_mask is not None: + loss = masked_mean(loss, action_mask) + else: + loss = loss.mean(dim=1) + if loss_mask is not None: + loss = loss * loss_mask + loss = loss.mean() + elif self.loss_variation == "token_level": + total_tokens = 0 + if action_mask is not None: + loss = masked_sum(loss, action_mask) + total_tokens = action_mask.sum(dim=1) + else: + loss = loss.sum(dim=1) + total_tokens = torch.ones_like(loss, device=loss.device) * log_probs.size(1) + if loss_mask is not None: + loss = loss * loss_mask + total_tokens = total_tokens * loss_mask + loss = loss.sum() / (total_tokens.sum() + 1e-8) + return loss, skip, ratio.max() diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py index 2c6a24a36711..c5681b9b58d3 100644 --- a/applications/ColossalChat/coati/distributed/producer.py +++ b/applications/ColossalChat/coati/distributed/producer.py @@ -124,12 +124,12 @@ def loop(self) -> None: self.load_state_dict(state_dict) del state_dict torch.cuda.empty_cache() - # linear annealing for 1 episode, temperature from initial to 0.7 + # linear annealing for 1 episode, temperature from initial to 0.9 if episode <= 0: ratio = 1 - (len(self.dataloader) - i) / len(self.dataloader) self.model.generate_config.temperature = (1 - ratio) * self.generate_config[ "temperature" - ] + ratio * 0.7 + ] + ratio * 0.9 @ray.remote diff --git a/applications/ColossalChat/coati/distributed/reward/reward_fn.py b/applications/ColossalChat/coati/distributed/reward/reward_fn.py index 53bc15e25a8c..a0f92d8c4ba8 100644 --- a/applications/ColossalChat/coati/distributed/reward/reward_fn.py +++ b/applications/ColossalChat/coati/distributed/reward/reward_fn.py @@ -3,14 +3,29 @@ from .reward_utils import extract_solution, validate_response_structure -def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs): +def math_reward_fn(step, input_ids, gt_answer, response_idx, **kwargs): + tokenizer = kwargs["tokenizer"] + soft_over_length_punishment = kwargs["soft_over_length_punishment"] format_score = 1.0 acc_score = 9.0 - tokenizer = kwargs["tokenizer"] + if step > 30: + format_score = 0.0 + acc_score = 10.0 reward = torch.tensor(0.0) format_reward = torch.tensor(0.0) acc_reward = torch.tensor(0.0) s, e = response_idx[0], response_idx[1] + + length_reward = 0.0 + if soft_over_length_punishment: + max_length = kwargs.get("max_length", 1024 * 4) + cache_length = kwargs.get("cache_length", 512) + res_length = e.item() - s.item() + 1 + if res_length >= max_length: + length_reward = -1.0 * 2 + elif res_length > max_length - cache_length: + length_reward = ((max_length - cache_length) - res_length) / cache_length * 2 + if gt_answer is None: return reward @@ -33,6 +48,8 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs): acc_reward += acc_score reward += acc_score + reward = reward + length_reward + return torch.tensor([reward, format_reward, acc_reward]).to(input_ids.device) diff --git a/applications/ColossalChat/coati/distributed/reward/verifiable_reward.py 
b/applications/ColossalChat/coati/distributed/reward/verifiable_reward.py index ba83f7787586..01d8f1663502 100644 --- a/applications/ColossalChat/coati/distributed/reward/verifiable_reward.py +++ b/applications/ColossalChat/coati/distributed/reward/verifiable_reward.py @@ -14,6 +14,7 @@ def __init__(self, reward_fns: List[callable], **kwargs: List[Dict[str, Any]]): def __call__( self, + step: int, input_ids: torch.LongTensor, gt_answer: List[torch.Tensor] = None, response_idx: List[torch.Tensor] = None, @@ -29,6 +30,7 @@ def __call__( reward_batch = torch.stack( [ reward_fn( + step, input_ids[i], gt_answer=gt_answer[i], response_idx=response_idx[i], diff --git a/applications/ColossalChat/coati/distributed/utils.py b/applications/ColossalChat/coati/distributed/utils.py index 919e4434faa6..5f7879669d08 100644 --- a/applications/ColossalChat/coati/distributed/utils.py +++ b/applications/ColossalChat/coati/distributed/utils.py @@ -113,3 +113,20 @@ def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch mask_sum = mask.sum(dim=dim) mean = tensor / (mask_sum + 1e-8) return mean + + +def masked_sum(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor: + """ + Compute the masked sum of a tensor along a specified dimension. + + Args: + tensor (torch.Tensor): The input tensor. + mask (torch.Tensor): The mask tensor with the same shape as the input tensor. + dim (int, optional): The dimension along which to compute the sum. Default is 1. + + Returns: + torch.Tensor: The masked sum tensor. + + """ + tensor = tensor * mask + return tensor.sum(dim=dim) diff --git a/applications/ColossalChat/coati/trainer/utils.py b/applications/ColossalChat/coati/trainer/utils.py index 22a5f492ebd2..45dfab5880c1 100755 --- a/applications/ColossalChat/coati/trainer/utils.py +++ b/applications/ColossalChat/coati/trainer/utils.py @@ -128,7 +128,21 @@ def all_reduce_mean(tensor: torch.Tensor, plugin: Plugin = None) -> torch.Tensor return tensor -def all_reduce_sum(tensor: torch.Tensor) -> torch.Tensor: +# def all_reduce_sum(tensor: torch.Tensor, ) -> torch.Tensor: +# """ +# Performs an all-reduce operation to sum the values of the given tensor across all processes. + +# Args: +# tensor (torch.Tensor): The input tensor to be reduced. + +# Returns: +# torch.Tensor: The reduced tensor with the sum of values across all processes. +# """ +# dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM) +# return tensor + + +def all_reduce_sum(tensor: torch.Tensor, plugin: Plugin = None) -> torch.Tensor: """ Performs an all-reduce operation to sum the values of the given tensor across all processes. @@ -138,5 +152,9 @@ def all_reduce_sum(tensor: torch.Tensor) -> torch.Tensor: Returns: torch.Tensor: The reduced tensor with the sum of values across all processes. 
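# Usage sketch for the masked_sum helper added above, next to a simplified masked_mean
# (both re-stated here for illustration, not copied verbatim): sample-level GRPO averages
# the loss per response first, while token-level GRPO divides one big sum by the total
# number of response tokens in the batch.
import torch

def masked_mean_demo(t, mask, dim=1):
    return (t * mask).sum(dim=dim) / (mask.sum(dim=dim) + 1e-8)

def masked_sum_demo(t, mask, dim=1):
    return (t * mask).sum(dim=dim)

per_token_loss = torch.randn(4, 16)                 # [batch, response_length]
action_mask = (torch.rand(4, 16) > 0.3).float()

sample_level = masked_mean_demo(per_token_loss, action_mask).mean()
token_level = masked_sum_demo(per_token_loss, action_mask).sum() / (action_mask.sum() + 1e-8)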
""" - dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM) + # All reduce sum across DP group + if plugin is not None: + dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM, group=plugin.dp_group) + else: + dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM) return tensor diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py index 6c43ccd1960f..d4befa20ed59 100644 --- a/applications/ColossalChat/rl_example.py +++ b/applications/ColossalChat/rl_example.py @@ -60,8 +60,8 @@ ray.init(address="local", namespace="ray-example") inference_model_config = dict(path=args.model) - train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False) - generate_config = dict(top_k=50, top_p=0.75, temperature=0.9) + train_model_config = dict(path=args.model, use_flash_attention_2=False, use_cache=False) + generate_config = dict(top_k=-1, top_p=1.0, temperature=1.0) if args.backend == "transformers": inference_model_config.update( @@ -102,6 +102,29 @@ ) ) + # Default Settings + # grpo_config = { + # "filter_range": [0.05, 9.0], + # "lr": 1e-6, + # "train_microbatch_size": train_microbatch_size, + # } + + # DAPO variant settings + grpo_config = { + "filter_range": [0.05, 9.0], + "lr": 1e-6, + "train_microbatch_size": args.train_microbatch_size, + "clip_eps_low": 0.2, + "clip_eps_high": 0.28, + "skip_threshold": 20.0, + "beta": 0.0, # no KL penalty + "loss_variation": "token_level", + "soft_over_length_punishment": True, + "max_length": 1024 * 2, + "cache_length": 256, + "filter_truncated_response": True, + } + launch_distributed( num_producers=args.num_inferencer, num_proc_per_producer=1, @@ -118,14 +141,17 @@ generate_config=generate_config, num_generations=args.num_generations, train_model_config=train_model_config, - # plugin_config={}, # for zero + grpo_config=grpo_config, plugin_config={ - "pp_size": 2, - "tp_size": 2, - "microbatch_size": args.train_microbatch_size // 2, - "zero_stage": 0, - "max_norm": 1.0, - }, # for pp + "zero_stage": 2, + }, # for zero + # plugin_config={ + # "pp_size": 2, + # "tp_size": 2, + # "microbatch_size": args.train_microbatch_size // 2, + # "zero_stage": 0, + # "max_norm": 1.0, + # }, # for pp inference_backend=args.backend, master_addr="localhost", master_port=29506, From 447ab74fb4bb52af72f7eba59f6ae741c55c4184 Mon Sep 17 00:00:00 2001 From: YeAnbang Date: Wed, 16 Apr 2025 09:42:45 +0800 Subject: [PATCH 13/17] remove format reward --- .../coati/distributed/reward/reward_fn.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/applications/ColossalChat/coati/distributed/reward/reward_fn.py b/applications/ColossalChat/coati/distributed/reward/reward_fn.py index a0f92d8c4ba8..1260645c910d 100644 --- a/applications/ColossalChat/coati/distributed/reward/reward_fn.py +++ b/applications/ColossalChat/coati/distributed/reward/reward_fn.py @@ -3,14 +3,11 @@ from .reward_utils import extract_solution, validate_response_structure -def math_reward_fn(step, input_ids, gt_answer, response_idx, **kwargs): +def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs): tokenizer = kwargs["tokenizer"] soft_over_length_punishment = kwargs["soft_over_length_punishment"] - format_score = 1.0 - acc_score = 9.0 - if step > 30: - format_score = 0.0 - acc_score = 10.0 + format_score = 0.0 + acc_score = 10.0 reward = torch.tensor(0.0) format_reward = torch.tensor(0.0) acc_reward = torch.tensor(0.0) @@ -21,10 +18,8 @@ def math_reward_fn(step, input_ids, gt_answer, response_idx, **kwargs): max_length 
= kwargs.get("max_length", 1024 * 4) cache_length = kwargs.get("cache_length", 512) res_length = e.item() - s.item() + 1 - if res_length >= max_length: - length_reward = -1.0 * 2 - elif res_length > max_length - cache_length: - length_reward = ((max_length - cache_length) - res_length) / cache_length * 2 + if max_length - cache_length < res_length < max_length: + length_reward = ((max_length - cache_length) - res_length) / cache_length * acc_score if gt_answer is None: return reward From cc4faa73008fd6e7401eb1dc72a9e89e7ad2ef6f Mon Sep 17 00:00:00 2001 From: YeAnbang Date: Wed, 16 Apr 2025 14:10:43 +0800 Subject: [PATCH 14/17] fix filtering, still buggy --- .../coati/distributed/grpo_consumer.py | 112 +++++++++--------- .../ColossalChat/coati/distributed/loss.py | 9 +- .../coati/distributed/reward/reward_fn.py | 10 +- .../distributed/reward/verifiable_reward.py | 2 - .../ColossalChat/coati/trainer/utils.py | 14 --- applications/ColossalChat/rl_example.py | 5 +- 6 files changed, 67 insertions(+), 85 deletions(-) diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py index 5cd97b20a26a..dee4e648ba8d 100644 --- a/applications/ColossalChat/coati/distributed/grpo_consumer.py +++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py @@ -63,8 +63,8 @@ def __init__( self.accum_loss = torch.zeros(1, device=self.device) self.accum_reward = torch.zeros(1, device=self.device) self.accum_kl = torch.zeros(1, device=self.device) - self.accum_format_reward = torch.zeros(1, device=self.device) - self.accum_acc_reward = torch.zeros(1, device=self.device) + self.accum_format_acc = torch.zeros(1, device=self.device) + self.accum_ans_acc = torch.zeros(1, device=self.device) self.accum_advantages = torch.zeros(1, device=self.device) self.accum_response_length = torch.zeros(1, device=self.device) self.accum_count = 0 @@ -72,10 +72,19 @@ def __init__( self.grpo_config = grpo_config self.project_name = project_name self.effective_sample_count = 0 + self.total_sample_count = 0 + + self.policy_loss_fn = PolicyLoss( + clip_eps_low=grpo_config.get("clip_eps_low", 0.2), + clip_eps_high=grpo_config.get("clip_eps_high", 0.2), + beta=grpo_config.get("beta", 0.01), + loss_variation=grpo_config.get("loss_variation", "sample_level"), + ) # Reference model is initialized from policy model. 
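# Standalone sketch of the soft over-length punishment wired into math_reward_fn in this
# patch series (defaults shown follow the configs used here and are otherwise assumptions):
# responses that end inside the last cache_length tokens before max_length receive a
# linearly growing negative reward, scaled by acc_score in the later revision.
def soft_length_penalty(res_length, max_length=4096, cache_length=512, acc_score=10.0):
    if max_length - cache_length < res_length < max_length:
        return ((max_length - cache_length) - res_length) / cache_length * acc_score
    return 0.0

# soft_length_penalty(3584) -> 0.0, soft_length_penalty(3840) -> -5.0, soft_length_penalty(4095) -> ~-9.98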
- self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config) - self.reference_model.eval() + if self.policy_loss_fn.beta > 0: + self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config) + self.reference_model.eval() self.tokenizer = AutoTokenizer.from_pretrained(path) self.pad_token_id = self.tokenizer.pad_token_id @@ -108,14 +117,6 @@ def __init__( self.reward_model = VerifiableReward( reward_fns=[math_reward_fn], tokenizer=self.tokenizer, tags=response_format_tags, **reward_model_kwargs ) - - self.policy_loss_fn = PolicyLoss( - clip_eps_low=grpo_config.get("clip_eps_low", 0.2), - clip_eps_high=grpo_config.get("clip_eps_high", 0.2), - skip_threshold=grpo_config.get("skip_threshold", 20.0), - beta=grpo_config.get("beta", 0.01), - loss_variation=grpo_config.get("loss_variation", "sample_level"), - ) self.global_step = 0 self.use_wandb = use_wandb @@ -139,7 +140,8 @@ def setup(self): self.policy_model, self.optimizer, _, _, self.lr_scheduler = self.booster.boost( self.policy_model, self.optimizer, lr_scheduler=self.lr_scheduler ) - self.reference_model, *_ = self.booster.boost(self.reference_model) + if self.policy_loss_fn.beta > 0: + self.reference_model, *_ = self.booster.boost(self.reference_model) self.plugin.logger.set_level("ERROR") def step(self, step_idx: int, **kwargs) -> Optional[float]: @@ -165,15 +167,14 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]: forward_batch_size = self.grpo_config.get("train_microbatch_size", data["input_ids"].size(0)) reward_group = self.reward_model( - int(step_idx / self.num_microbatches), data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"], ) reward = torch.tensor([value[0] for value in reward_group]).to(data["input_ids"].device) - format_reward = torch.tensor([value[1] for value in reward_group]).to(data["input_ids"].device) - acc_reward = torch.tensor([value[2] for value in reward_group]).to(data["input_ids"].device) + format_acc = torch.tensor([value[1] for value in reward_group]).to(data["input_ids"].device) + ans_acc = torch.tensor([value[2] for value in reward_group]).to(data["input_ids"].device) # [batch_size, num_generations] @@ -186,18 +187,13 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]: # [batch_size x num_generations] advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1) # filter out the reward that is too high (all sample gets full score) or too low (all sample gets 0 score), - reward_mean_no_length_penalty = ( - (format_reward + acc_reward) - .view(-1, self.num_generations) - .mean(dim=1) - .repeat_interleave(self.num_generations, dim=0) + group_ans_acc = ( + ans_acc.view(-1, self.num_generations).mean(dim=1).repeat_interleave(self.num_generations, dim=0) ) loss_mask = ( torch.ones(action_mask.size(0), device=action_mask.device).bool() if self.filter_range is None - else torch.logical_and( - reward_mean_no_length_penalty > self.filter_range[0], reward_mean < self.filter_range[1] - ) + else torch.logical_and(group_ans_acc > self.filter_range[0], group_ans_acc < self.filter_range[1]) ) # filter out overlength samples if self.filter_truncated_response and action_mask.size(1) == self.max_length: @@ -205,21 +201,23 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]: loss_mask, action_mask[:, -1] == False, ) - # for i in range(loss_mask.size(0)): - # if loss_mask[i] == False: - # print(data["input_ids"].size(), data["input_ids"][i], action_mask[i], "mean reward", reward_mean_no_length_penalty.size(), 
reward_mean_no_length_penalty[i]) effective_samples = all_reduce_sum(torch.sum(loss_mask), self.plugin) + total_samples = all_reduce_sum(torch.sum(torch.ones_like(loss_mask, device=loss_mask.device)), self.plugin) self.effective_sample_count += effective_samples.item() + self.total_sample_count += total_samples.item() + print( + loss_mask, + self.effective_sample_count, + self.total_sample_count, + self.batch_size * self.dp_size * self.num_generations * 0.75, + ) mean_kl, mean_loss = [], [] # update gradient only if at least 0.7*batch_size*num_generation valid samples are collected in case a lot of samples are invalid and got filtered out. # balance between efficiency and accuracy need_update = self.effective_sample_count >= self.batch_size * self.dp_size * self.num_generations * 0.75 - if need_update: - print(f"***** Update gradient based on {self.effective_sample_count} valid samples *****") - self.effective_sample_count = 0 # Gradient must be synchronized if zero2 is enabled. https://github.com/hpcaitech/ColossalAI/blob/44d4053fec005fe0b06b6bc755fdc962463145df/colossalai/booster/plugin/hybrid_parallel_plugin.py#L1500 ctx = ( @@ -319,7 +317,7 @@ def _criterion(outputs, inputs): per_token_kl = 0.0 kl.append(0.0) - loss, skip_update, _ = self.policy_loss_fn( + loss, _ = self.policy_loss_fn( action_log_probs, action_log_probs, inputs["advantages"].repeat_interleave(action_log_probs.size(-1), dim=-1), @@ -381,7 +379,7 @@ def _criterion(outputs, inputs): per_token_kl = 0.0 kl = None - loss, skip_update, _ = self.policy_loss_fn( + loss, _ = self.policy_loss_fn( action_log_probs, old_action_log_probs, advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1), @@ -390,8 +388,7 @@ def _criterion(outputs, inputs): loss_mask=loss_mask_forward_micro_batch, ) - if not skip_update: - self.booster.backward(loss, self.optimizer) + self.booster.backward(loss, self.optimizer) loss = all_reduce_mean(loss, self.plugin) # Calculate accumulate value. 
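# Sketch of the per-token approximate KL used in this consumer (the "k3" estimator,
# exp(d) - d - 1 with d = ref_logprob - logprob); tensor names and the 1e-8 guard are
# illustrative. When the KL coefficient beta is 0, as in the DAPO settings, this term
# and the reference model forward pass are skipped entirely.
import torch

def approx_kl(reference_action_log_probs, action_log_probs, action_mask):
    d = reference_action_log_probs - action_log_probs
    per_token_kl = torch.exp(d) - d - 1                      # non-negative, low-variance KL estimate
    return (per_token_kl * action_mask).sum(dim=-1) / (action_mask.sum(dim=-1) + 1e-8)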
if kl is not None: @@ -402,22 +399,25 @@ def _criterion(outputs, inputs): self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0 ): reward = all_reduce_mean(reward.mean(), self.plugin) - format_reward = all_reduce_mean(format_reward.mean(), self.plugin) - acc_reward = all_reduce_mean(acc_reward.mean(), self.plugin) + format_acc = all_reduce_mean(format_acc.mean(), self.plugin) + ans_acc = all_reduce_mean(ans_acc.mean(), self.plugin) advantages = all_reduce_mean(advantages.mean(), self.plugin) response_length = all_reduce_mean(response_length.mean(), self.plugin) self.accum_loss.add_(sum(mean_loss) / len(mean_loss)) if self.policy_loss_fn.beta > 0: self.accum_kl.add_(sum(mean_kl) / len(mean_kl)) self.accum_reward.add_(reward.data) - self.accum_format_reward.add_(format_reward.data) - self.accum_acc_reward.add_(acc_reward.data) + self.accum_format_acc.add_(format_acc.data) + self.accum_ans_acc.add_(ans_acc.data) self.accum_advantages.add_(advantages.data) self.accum_response_length.add_(response_length.data) self.accum_count += 1 if need_update: self.optimizer.step() self.optimizer.zero_grad() + sample_utilization = self.effective_sample_count / self.total_sample_count + self.effective_sample_count = 0 + self.total_sample_count = 0 if not self.plugin.pp_size > 1 or ( self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0 ): @@ -428,8 +428,8 @@ def _criterion(outputs, inputs): to_log_msg = ( f"Loss: {self.accum_loss.item() / self.accum_count:.4f} \ Reward: {self.accum_reward.item() / self.accum_count:.4f} \ - Format Reward: {self.accum_format_reward.item() / self.accum_count:.4f} \ - Acc Reward: {self.accum_acc_reward.item() / self.accum_count:.4f} \ + Format Reward: {self.accum_format_acc.item() / self.accum_count:.4f} \ + Acc Reward: {self.accum_ans_acc.item() / self.accum_count:.4f} \ Advantages: {self.accum_advantages.item() / self.accum_count:.4f} \ Response Length: {self.accum_response_length.item() / self.accum_count:.4f}" + f" KL: {self.accum_kl.item() / self.accum_count:.4f}" @@ -439,12 +439,13 @@ def _criterion(outputs, inputs): print(to_log_msg) metrics = { "metrics/reward": self.accum_reward.item() / self.accum_count, - "metrics/format_reward": self.accum_format_reward.item() / self.accum_count, - "metrics/acc_reward": self.accum_acc_reward.item() / self.accum_count, + "metrics/format_acc": self.accum_format_acc.item() / self.accum_count, + "metrics/ans_acc": self.accum_ans_acc.item() / self.accum_count, "metrics/response_length": self.accum_response_length.item() / self.accum_count, "train/loss": self.accum_loss.item() / self.accum_count, "train/advantages": self.accum_advantages.item() / self.accum_count, "train/learning_rate": self.lr_scheduler.get_last_lr()[0], + "train/sample_utilization": sample_utilization, "rollout/temperature": data["temperature"].cpu().numpy()[0][0], } if self.policy_loss_fn.beta > 0: @@ -453,12 +454,11 @@ def _criterion(outputs, inputs): self.wandb_run.log(metrics) self.accum_loss.zero_() self.accum_reward.zero_() - self.accum_acc_reward.zero_() - self.accum_format_reward.zero_() + self.accum_ans_acc.zero_() + self.accum_format_acc.zero_() self.accum_kl.zero_() self.accum_advantages.zero_() self.accum_response_length.zero_() - self.accum_count = 0 return loss_scalar @@ -507,8 +507,8 @@ def __init__( self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config) self.policy_model.train() self.accum_reward = torch.zeros(1, device=self.device) - 
self.accum_format_reward = torch.zeros(1, device=self.device) - self.accum_acc_reward = torch.zeros(1, device=self.device) + self.accum_format_acc = torch.zeros(1, device=self.device) + self.accum_ans_acc = torch.zeros(1, device=self.device) self.accum_response_length = torch.zeros(1, device=self.device) self.accum_count = torch.zeros(1, device=self.device) @@ -545,8 +545,8 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]: data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"] ) reward = [value[0].item() for value in reward_group] - format_reward = [value[1].item() for value in reward_group] - acc_reward = [value[2].item() for value in reward_group] + format_acc = [value[1].item() for value in reward_group] + ans_acc = [value[2].item() for value in reward_group] response_length = [(data["response_idx"][i][1] - data["response_idx"][i][0]).item() for i in range(len(reward))] response = self.tokenizer.batch_decode(data["input_ids"], skip_special_tokens=True) @@ -557,8 +557,8 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]: { "response": response[i], "reward": reward[i], - "format_reward": format_reward[i], - "acc_reward": acc_reward[i], + "format_acc": format_acc[i], + "ans_acc": ans_acc[i], "response_length": response_length[i], }, ensure_ascii=False, @@ -567,20 +567,20 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]: ) self.accum_reward += sum(reward) - self.accum_format_reward += sum(format_reward) - self.accum_acc_reward += sum(acc_reward) + self.accum_format_acc += sum(format_acc) + self.accum_ans_acc += sum(ans_acc) self.accum_response_length += sum(response_length) self.accum_count += len(reward) # print results total_count = all_reduce_mean(self.accum_count, self.plugin) mean_reward = all_reduce_mean(self.accum_reward, self.plugin) / total_count - mean_format_reward = all_reduce_mean(self.accum_format_reward, self.plugin) / total_count - mean_acc_reward = all_reduce_mean(self.accum_acc_reward, self.plugin) / total_count + mean_format_acc = all_reduce_mean(self.accum_format_acc, self.plugin) / total_count + mean_ans_acc = all_reduce_mean(self.accum_ans_acc, self.plugin) / total_count mean_response_length = all_reduce_mean(self.accum_response_length, self.plugin) / total_count if rank == 0: print( - f"Step {step_idx}: Mean Reward: {mean_reward}, Mean Format Reward: {mean_format_reward}, Mean Acc Reward: {mean_acc_reward}, Mean Response Length: {mean_response_length}" + f"Step {step_idx}: Mean Reward: {mean_reward}, Mean Format Reward: {mean_format_acc}, Mean Acc Reward: {mean_ans_acc}, Mean Response Length: {mean_response_length}" ) return None diff --git a/applications/ColossalChat/coati/distributed/loss.py b/applications/ColossalChat/coati/distributed/loss.py index bdbe64a2a749..d00335db0912 100644 --- a/applications/ColossalChat/coati/distributed/loss.py +++ b/applications/ColossalChat/coati/distributed/loss.py @@ -14,14 +14,12 @@ def __init__( self, clip_eps_low: float = 0.2, clip_eps_high: float = 0.2, - skip_threshold: float = 20.0, beta: float = 0.01, loss_variation: str = "sample_level", ) -> None: super().__init__() self.clip_eps_low = clip_eps_low self.clip_eps_high = clip_eps_high - self.skip_threshold = skip_threshold self.beta = beta self.loss_variation = loss_variation assert loss_variation in ["sample_level", "token_level"], f"Unsupported loss variation: {loss_variation}" @@ -35,7 +33,6 @@ def forward( action_mask: Optional[torch.Tensor] = None, loss_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: 
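# Sketch of the asymmetric "clip higher" surrogate configured by clip_eps_low / clip_eps_high
# (the DAPO variant settings in this series use 0.2 / 0.28); names and the missing reduction
# are illustrative, the real PolicyLoss still masks and reduces sample- or token-level.
import torch

def clipped_surrogate(log_probs, old_log_probs, advantages, eps_low=0.2, eps_high=0.28):
    ratio = (log_probs - old_log_probs).exp()
    surr1 = ratio * advantages
    surr2 = ratio.clamp(1 - eps_low, 1 + eps_high) * advantages
    return -torch.min(surr1, surr2)                          # per-token policy loss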
- skip = False if action_mask is None: ratio = (log_probs - log_probs.detach()).exp() else: @@ -43,7 +40,7 @@ def forward( surr1 = ratio * advantages surr2 = ratio.clamp(1 - self.clip_eps_low, 1 + self.clip_eps_high) * advantages - if self.beta <= 0: + if self.beta == 0: # skip kl term if kl coefficient is zero per_token_kl = 0.0 loss = -torch.min(surr1, surr2) + self.beta * per_token_kl @@ -68,5 +65,7 @@ def forward( loss = loss * loss_mask total_tokens = total_tokens * loss_mask loss = loss.sum() / (total_tokens.sum() + 1e-8) + else: + raise ValueError(f"Unsupported loss variation: {self.loss_variation}") - return loss, skip, ratio.max() + return loss, ratio.max() diff --git a/applications/ColossalChat/coati/distributed/reward/reward_fn.py b/applications/ColossalChat/coati/distributed/reward/reward_fn.py index 1260645c910d..b1ac02fcd5ca 100644 --- a/applications/ColossalChat/coati/distributed/reward/reward_fn.py +++ b/applications/ColossalChat/coati/distributed/reward/reward_fn.py @@ -9,8 +9,8 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs): format_score = 0.0 acc_score = 10.0 reward = torch.tensor(0.0) - format_reward = torch.tensor(0.0) - acc_reward = torch.tensor(0.0) + format_acc = torch.tensor(0.0) + ans_acc = torch.tensor(0.0) s, e = response_idx[0], response_idx[1] length_reward = 0.0 @@ -32,7 +32,7 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs): # Check format accuracy if format_valid: - format_reward += format_score + format_acc += 1 reward += format_score # Check answer accuracy @@ -40,12 +40,12 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs): final_answer is not None and gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower() ): - acc_reward += acc_score + ans_acc += 1 reward += acc_score reward = reward + length_reward - return torch.tensor([reward, format_reward, acc_reward]).to(input_ids.device) + return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device) def gsm8k_reward_fn(input_ids, **kwargs): diff --git a/applications/ColossalChat/coati/distributed/reward/verifiable_reward.py b/applications/ColossalChat/coati/distributed/reward/verifiable_reward.py index 01d8f1663502..ba83f7787586 100644 --- a/applications/ColossalChat/coati/distributed/reward/verifiable_reward.py +++ b/applications/ColossalChat/coati/distributed/reward/verifiable_reward.py @@ -14,7 +14,6 @@ def __init__(self, reward_fns: List[callable], **kwargs: List[Dict[str, Any]]): def __call__( self, - step: int, input_ids: torch.LongTensor, gt_answer: List[torch.Tensor] = None, response_idx: List[torch.Tensor] = None, @@ -30,7 +29,6 @@ def __call__( reward_batch = torch.stack( [ reward_fn( - step, input_ids[i], gt_answer=gt_answer[i], response_idx=response_idx[i], diff --git a/applications/ColossalChat/coati/trainer/utils.py b/applications/ColossalChat/coati/trainer/utils.py index 45dfab5880c1..5153ce3adad5 100755 --- a/applications/ColossalChat/coati/trainer/utils.py +++ b/applications/ColossalChat/coati/trainer/utils.py @@ -128,20 +128,6 @@ def all_reduce_mean(tensor: torch.Tensor, plugin: Plugin = None) -> torch.Tensor return tensor -# def all_reduce_sum(tensor: torch.Tensor, ) -> torch.Tensor: -# """ -# Performs an all-reduce operation to sum the values of the given tensor across all processes. - -# Args: -# tensor (torch.Tensor): The input tensor to be reduced. - -# Returns: -# torch.Tensor: The reduced tensor with the sum of values across all processes. 
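# Sketch of the verifiable math reward after the accuracy refactor above: format and answer
# correctness are tracked as 0/1 accuracies, the answer only counts when the format is also
# valid (a later commit in this series), and the soft length penalty is added on top.
# Scores follow the defaults in this patch and are otherwise illustrative.
import torch

def math_reward(format_valid, answer_matches, acc_score=10.0, format_score=0.0, length_reward=0.0):
    format_acc = 1.0 if format_valid else 0.0
    ans_acc = 1.0 if (format_valid and answer_matches) else 0.0
    reward = format_acc * format_score + ans_acc * acc_score + length_reward
    return torch.tensor([reward, format_acc, ans_acc])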
-# """ -# dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM) -# return tensor - - def all_reduce_sum(tensor: torch.Tensor, plugin: Plugin = None) -> torch.Tensor: """ Performs an all-reduce operation to sum the values of the given tensor across all processes. diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py index 0dc84045abc4..ccbb5b29781d 100644 --- a/applications/ColossalChat/rl_example.py +++ b/applications/ColossalChat/rl_example.py @@ -111,7 +111,7 @@ # DAPO variant settings grpo_config = { - "filter_range": [0.05, 9.0], + "filter_range": [0.01, 0.99], # only filter out all zero batch and all one batch "lr": 1e-6, "train_microbatch_size": args.train_microbatch_size, "clip_eps_low": 0.2, @@ -144,8 +144,7 @@ grpo_config=grpo_config, plugin_config={ "zero_stage": 2, - }, - # for zero + }, # for zero # plugin_config={ # "pp_size": 2, # "tp_size": 2, From 94743161326d2ede61326950d59efc8ed2119d0f Mon Sep 17 00:00:00 2001 From: YeAnbang Date: Wed, 16 Apr 2025 15:59:25 +0800 Subject: [PATCH 15/17] small fix --- .../coati/distributed/grpo_consumer.py | 26 ++++++++----------- .../coati/distributed/reward/reward_fn.py | 5 ++-- 2 files changed, 14 insertions(+), 17 deletions(-) diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py index dee4e648ba8d..5886cc7fee18 100644 --- a/applications/ColossalChat/coati/distributed/grpo_consumer.py +++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py @@ -206,12 +206,6 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]: total_samples = all_reduce_sum(torch.sum(torch.ones_like(loss_mask, device=loss_mask.device)), self.plugin) self.effective_sample_count += effective_samples.item() self.total_sample_count += total_samples.item() - print( - loss_mask, - self.effective_sample_count, - self.total_sample_count, - self.batch_size * self.dp_size * self.num_generations * 0.75, - ) mean_kl, mean_loss = [], [] @@ -426,17 +420,19 @@ def _criterion(outputs, inputs): self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0 ): to_log_msg = ( - f"Loss: {self.accum_loss.item() / self.accum_count:.4f} \ - Reward: {self.accum_reward.item() / self.accum_count:.4f} \ - Format Reward: {self.accum_format_acc.item() / self.accum_count:.4f} \ - Acc Reward: {self.accum_ans_acc.item() / self.accum_count:.4f} \ - Advantages: {self.accum_advantages.item() / self.accum_count:.4f} \ - Response Length: {self.accum_response_length.item() / self.accum_count:.4f}" - + f" KL: {self.accum_kl.item() / self.accum_count:.4f}" + [ + f"Loss: {self.accum_loss.item() / self.accum_count:.4f}", + f"Reward: {self.accum_reward.item() / self.accum_count:.4f}", + f"ormat Reward: {self.accum_format_acc.item() / self.accum_count:.4f}", + f"Acc Reward: {self.accum_ans_acc.item() / self.accum_count:.4f}", + f"Advantages: {self.accum_advantages.item() / self.accum_count:.4f}", + f"Response Length: {self.accum_response_length.item() / self.accum_count:.4f}", + ] + + [f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] if self.policy_loss_fn.beta > 0 - else "" + else [] ) - print(to_log_msg) + print("\n".join(to_log_msg)) metrics = { "metrics/reward": self.accum_reward.item() / self.accum_count, "metrics/format_acc": self.accum_format_acc.item() / self.accum_count, diff --git a/applications/ColossalChat/coati/distributed/reward/reward_fn.py b/applications/ColossalChat/coati/distributed/reward/reward_fn.py index 
b1ac02fcd5ca..3cf7a1af3cd8 100644 --- a/applications/ColossalChat/coati/distributed/reward/reward_fn.py +++ b/applications/ColossalChat/coati/distributed/reward/reward_fn.py @@ -35,9 +35,10 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs): format_acc += 1 reward += format_score - # Check answer accuracy + # Check answer accuracy, answer is considered correct if the answer is correct and the format is valid if ( - final_answer is not None + format_valid + and final_answer is not None and gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower() ): ans_acc += 1 From c0c0da2f2609b89c3da8d9e10cbab762cee2d1b0 Mon Sep 17 00:00:00 2001 From: YeAnbang Date: Tue, 15 Apr 2025 18:28:35 +0800 Subject: [PATCH 16/17] add DAPO support --- applications/ColossalChat/coati/distributed/consumer.py | 4 ++-- .../ColossalChat/coati/distributed/grpo_consumer.py | 6 ++++-- applications/ColossalChat/rl_example.py | 6 +++--- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py index 79beb2a2dba6..1dcfde4d6944 100644 --- a/applications/ColossalChat/coati/distributed/consumer.py +++ b/applications/ColossalChat/coati/distributed/consumer.py @@ -112,7 +112,7 @@ def loop(self) -> None: self.buffer = self.buffer[self.dp_size * self.microbatch_size :] batch = bind_batch(batches) batch = post_recv(batch) - loss = self.step(i, **batch) + loss = self.step(i, pbar, **batch) if loss is not None: pbar.set_postfix({"loss": loss}) i += 1 @@ -181,7 +181,7 @@ def setup(self): super().setup() self.model, self.optimizer, *_ = self.booster.boost(self.model, self.optimizer) - def step(self, step_idx: int, **kwargs) -> Optional[float]: + def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]: labels = kwargs["input_ids"].clone() labels[kwargs["attention_mask"] == 0] = -100 kwargs["labels"] = labels diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py index 5886cc7fee18..c0cb7663bfec 100644 --- a/applications/ColossalChat/coati/distributed/grpo_consumer.py +++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py @@ -1,7 +1,7 @@ import json import os from contextlib import nullcontext -from typing import Optional +from typing import Optional, Any import ray import torch @@ -144,7 +144,7 @@ def setup(self): self.reference_model, *_ = self.booster.boost(self.reference_model) self.plugin.logger.set_level("ERROR") - def step(self, step_idx: int, **kwargs) -> Optional[float]: + def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]: """ Step data from policy model: [{ @@ -212,6 +212,7 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]: # update gradient only if at least 0.7*batch_size*num_generation valid samples are collected in case a lot of samples are invalid and got filtered out. # balance between efficiency and accuracy need_update = self.effective_sample_count >= self.batch_size * self.dp_size * self.num_generations * 0.75 + pbar.set_postfix({"Step": self.global_step + 1, "Status": f"Collecting: {self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations * 0.75}"}) # Gradient must be synchronized if zero2 is enabled. 
https://github.com/hpcaitech/ColossalAI/blob/44d4053fec005fe0b06b6bc755fdc962463145df/colossalai/booster/plugin/hybrid_parallel_plugin.py#L1500 ctx = ( @@ -409,6 +410,7 @@ def _criterion(outputs, inputs): if need_update: self.optimizer.step() self.optimizer.zero_grad() + self.global_step += 1 sample_utilization = self.effective_sample_count / self.total_sample_count self.effective_sample_count = 0 self.total_sample_count = 0 diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py index ccbb5b29781d..bb60a4b1a28b 100644 --- a/applications/ColossalChat/rl_example.py +++ b/applications/ColossalChat/rl_example.py @@ -83,7 +83,7 @@ inference_model_config.update(dict(gpu_memory_utilization=0.7, enforce_eager=True, enable_chunked_prefill=True)) generate_config.update( dict( - max_tokens=2048, + max_tokens=4096, ignore_eos=True, include_stop_str_in_output=True, stop=[""], @@ -120,8 +120,8 @@ "beta": 0.0, # no KL penalty "loss_variation": "token_level", "soft_over_length_punishment": True, - "max_length": 1024 * 2, - "cache_length": 256, + "max_length": 4096, + "cache_length": 512, "filter_truncated_response": True, } From e397327de01b675c49c3e95711b77579de040a62 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 16 Apr 2025 10:00:10 +0000 Subject: [PATCH 17/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../ColossalChat/coati/distributed/grpo_consumer.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py index c0cb7663bfec..8acf41807df1 100644 --- a/applications/ColossalChat/coati/distributed/grpo_consumer.py +++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py @@ -1,7 +1,7 @@ import json import os from contextlib import nullcontext -from typing import Optional, Any +from typing import Any, Optional import ray import torch @@ -212,7 +212,12 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]: # update gradient only if at least 0.7*batch_size*num_generation valid samples are collected in case a lot of samples are invalid and got filtered out. # balance between efficiency and accuracy need_update = self.effective_sample_count >= self.batch_size * self.dp_size * self.num_generations * 0.75 - pbar.set_postfix({"Step": self.global_step + 1, "Status": f"Collecting: {self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations * 0.75}"}) + pbar.set_postfix( + { + "Step": self.global_step + 1, + "Status": f"Collecting: {self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations * 0.75}", + } + ) # Gradient must be synchronized if zero2 is enabled. https://github.com/hpcaitech/ColossalAI/blob/44d4053fec005fe0b06b6bc755fdc962463145df/colossalai/booster/plugin/hybrid_parallel_plugin.py#L1500 ctx = (