From 8e6c9a4ab3f10a22ee5d1d3001ea64d0bda5191a Mon Sep 17 00:00:00 2001 From: Tong Li Date: Sun, 23 Feb 2025 11:02:54 +0800 Subject: [PATCH 01/35] add reward related function --- .../coati/distributed/reward/reward_fn.py | 51 +++++++++++++ .../coati/distributed/reward/reward_utils.py | 76 +++++++++++++++++++ .../distributed/reward/verifiable_reward.py | 47 ++++++++++++ 3 files changed, 174 insertions(+) create mode 100644 applications/ColossalChat/coati/distributed/reward/reward_fn.py create mode 100644 applications/ColossalChat/coati/distributed/reward/reward_utils.py create mode 100644 applications/ColossalChat/coati/distributed/reward/verifiable_reward.py diff --git a/applications/ColossalChat/coati/distributed/reward/reward_fn.py b/applications/ColossalChat/coati/distributed/reward/reward_fn.py new file mode 100644 index 000000000000..f127a1eced20 --- /dev/null +++ b/applications/ColossalChat/coati/distributed/reward/reward_fn.py @@ -0,0 +1,51 @@ +import torch + +from .reward_utils import extract_solution, validate_response_structure + + +def math_reward_fn(input_ids, **kwargs): + # apply varifiable reward + # reward 10 points if the final answer is correct, reward 1 point if format is correct + + gt_answer = kwargs["gt_answer"] + tokenizer = kwargs["tokenizer"] + s, e = kwargs["response_start"], kwargs["response_end"] + reward = torch.tensor(0.0).to(input_ids.device) + if gt_answer is None: + return reward + decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True) + final_answer, processed_str = extract_solution(decoded_final_answer) + + format_valid = validate_response_structure(processed_str, kwargs["tags"]) + if not format_valid: + return reward + else: + reward += 1.0 + if gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower(): + reward = reward + 9.0 + return reward + + +def gsm8k_reward_fn(input_ids, **kwargs): + gt_answer = kwargs["gt_answer"] + tokenizer = kwargs["tokenizer"] + s, e = kwargs["response_start"], kwargs["response_end"] + reward = torch.tensor(0.0).to(input_ids.device) + if gt_answer is None: + return reward + decoded_final_answer = tokenizer.decode(input_ids[s:e], skip_special_tokens=True) + final_answer, processed_str = extract_solution(decoded_final_answer) + is_valid = True + try: + int(final_answer.strip()) + except Exception: + is_valid = False + + format_valid = validate_response_structure(processed_str, kwargs["tags"]) + if not is_valid or not format_valid: + return reward + else: + reward += 1.0 + if gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower(): + reward = reward + 9.0 + return reward diff --git a/applications/ColossalChat/coati/distributed/reward/reward_utils.py b/applications/ColossalChat/coati/distributed/reward/reward_utils.py new file mode 100644 index 000000000000..c1e73d4b9738 --- /dev/null +++ b/applications/ColossalChat/coati/distributed/reward/reward_utils.py @@ -0,0 +1,76 @@ +# Copyright Unakar +# Modified from https://github.com/Unakar/Logic-RL/blob/086373176ac198c97277ff50f4b6e7e1bfe669d3/verl/utils/reward_score/kk.py#L99 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from typing import Dict, Optional, Tuple + + +def validate_response_structure(processed_str: str, tags: Dict = None) -> bool: + """Performs comprehensive validation of response structure. + + Args: + processed_str: Processed response string from the model + + Returns: + Boolean indicating whether all formatting requirements are met + """ + validation_passed = True + # Check required tags + if tags is None: + tags = { + "think_start": {"text": "", "num_occur": 1}, + "think_end": {"text": "", "num_occur": 1}, + "answer_start": {"text": "", "num_occur": 1}, + "answer_end": {"text": "", "num_occur": 1}, + } + positions = {} + for tag_name, tag_info in tags.items(): + tag_str = tag_info["text"] + expected_count = tag_info["num_occur"] + count = processed_str.count(tag_str) + positions[tag_name] = pos = processed_str.find(tag_str) + if count != expected_count: + validation_passed = False + # Verify tag order + if ( + positions["think_start"] > positions["think_end"] + or positions["think_end"] > positions["answer_start"] + or positions["answer_start"] > positions["answer_end"] + ): + validation_passed = False + if len(processed_str) - positions["answer_end"] != len(tags["answer_end"]["text"]): + validation_passed = False + return validation_passed + + +def extract_solution(solution_str: str) -> Tuple[Optional[str], str]: + """Extracts the final answer from the model's response string. + + Args: + solution_str: Raw response string from the language model + + Returns: + Tuple containing (extracted_answer, processed_string) + """ + + # Extract final answer using XML-style tags + answer_pattern = r"(.*?)" + matches = list(re.finditer(answer_pattern, solution_str, re.DOTALL)) + + if not matches: + return None, solution_str + + final_answer = matches[-1].group(1).strip() + return final_answer, solution_str diff --git a/applications/ColossalChat/coati/distributed/reward/verifiable_reward.py b/applications/ColossalChat/coati/distributed/reward/verifiable_reward.py new file mode 100644 index 000000000000..d1700d86f14e --- /dev/null +++ b/applications/ColossalChat/coati/distributed/reward/verifiable_reward.py @@ -0,0 +1,47 @@ +""" +Function-based reward verification module. 
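+
+Wraps one or more reward functions and applies each of them to every sample in a
+batch of generated sequences, summing the per-function rewards.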
+""" + +from typing import Any, Dict, List + +import torch + + +class VerifiableReward: + def __init__(self, reward_fn: List[callable], reward_args: List[Dict[str, Any]]): + self.reward_fn = reward_fn + self.reward_args = reward_args + + def __call__( + self, + input_ids: torch.LongTensor, + attention_mask: torch.LongTensor, + response_start: List[int] = None, + response_end: List[int] = None, + gt_answer: List[str] = None, + ) -> torch.Tensor: + # Get batch size + bs = input_ids.size(0) + # Initialize reward + reward = torch.zeros(bs, device=input_ids.device) + + # Loop through reward functions + for reward_fn in self.reward_fn_list: + # Apply the reward function to the entire batch at once + reward_batch = torch.stack( + [ + reward_fn( + input_ids[i], + attention_mask[i], + response_start=response_start[i], + response_end=response_end[i], + gt_answer=gt_answer[i], + **self.kwargs, + ) + for i in range(bs) + ], + dim=0, + ) + + rewards += reward_batch + return rewards From ffd3878a1ecaba925de4e64eb41e89bf0dfafd70 Mon Sep 17 00:00:00 2001 From: Tong Li Date: Sun, 23 Feb 2025 22:54:26 +0800 Subject: [PATCH 02/35] add simple grpo --- .../ColossalChat/coati/dataset/loader.py | 11 ++ .../coati/distributed/grpo_consumer.py | 150 ++++++++++++++++++ .../coati/distributed/inference_backend.py | 2 + .../ColossalChat/coati/distributed/launch.py | 4 +- .../ColossalChat/coati/distributed/loss.py | 44 +++++ .../coati/distributed/reward/reward_fn.py | 10 +- .../distributed/reward/verifiable_reward.py | 18 +-- .../ColossalChat/coati/distributed/utils.py | 35 ++++ 8 files changed, 253 insertions(+), 21 deletions(-) create mode 100644 applications/ColossalChat/coati/distributed/grpo_consumer.py create mode 100644 applications/ColossalChat/coati/distributed/loss.py diff --git a/applications/ColossalChat/coati/dataset/loader.py b/applications/ColossalChat/coati/dataset/loader.py index fc3a4930c058..04093e7057ae 100755 --- a/applications/ColossalChat/coati/dataset/loader.py +++ b/applications/ColossalChat/coati/dataset/loader.py @@ -356,6 +356,12 @@ def apply_chat_template_and_mask( truncation: bool = True, ignore_idx: int = -100, ) -> Dict[str, torch.Tensor]: + # Format for RL. 
+ gt_answer = None + if "messages" in chat and "gt_answer" in chat: + gt_answer = chat["gt_answer"] + chat = [chat["messages"]] + tokens = [] assistant_mask = [] for i, msg in enumerate(chat): @@ -389,6 +395,11 @@ def apply_chat_template_and_mask( labels = input_ids.clone() labels[~torch.tensor(assistant_mask, dtype=torch.bool)] = ignore_idx + if gt_answer is not None: + gt_answer = tokenizer.encode(gt_answer, padding="max_length", max_length=64, return_tensors="pt") + gt_answer = gt_answer.squeeze(1) + return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels, "gt_answer": gt_answer} + return { "input_ids": input_ids, "attention_mask": attention_mask, diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py new file mode 100644 index 000000000000..79128b89e17a --- /dev/null +++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py @@ -0,0 +1,150 @@ +from contextlib import nullcontext +from typing import Optional + +import ray +import torch +from coati.distributed.consumer import BaseConsumer +from coati.distributed.loss import PolicyLoss +from coati.distributed.reward.reward_fn import math_reward_fn +from coati.distributed.reward.verifiable_reward import VerifiableReward +from coati.distributed.utils import calc_action_log_probs +from transformers import AutoModelForCausalLM, AutoTokenizer + +from colossalai.nn.optimizer import HybridAdam + + +@ray.remote +class GRPOConsumer(BaseConsumer): + def __init__( + self, + num_producers, + num_episodes, + rank, + world_size, + master_addr, + master_port, + num_update_per_episode, + num_recv_per_update, + batch_size, + model_config, + plugin_config, + microbatch_size=1, + ): + super().__init__( + num_producers, + num_episodes, + rank, + world_size, + master_addr, + master_port, + num_update_per_episode, + num_recv_per_update, + batch_size, + model_config, + plugin_config, + microbatch_size, + ) + path = model_config.pop("path") + self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config) + self.policy_model.train() + self.policy_model.gradient_checkpointing_enable() + self.optimizer = HybridAdam(self.policy_model.parameters(), lr=1e-4) + self.accum_loss = torch.zeros(1, device=self.device) + + # Reference model is initialized from policy model. + self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config) + self.reference_model.eval() + + self.tokenizer = AutoTokenizer.from_pretrained(path) + self.pad_token_id = self.tokenizer.pad_token_id + + # Initialize verifiable reward. + response_format_tags = { + "think_start": {"text": "", "num_occur": 1}, + "think_end": {"text": "", "num_occur": 1}, + "answer_start": {"text": "", "num_occur": 1}, + "answer_end": {"text": "", "num_occur": 1}, + } + self.reward_model = VerifiableReward( + reward_fns=[math_reward_fn], tokenizer=self.tokenizer, tags=response_format_tags + ) + + self.policy_loss_fn = PolicyLoss() + + def setup(self): + super().setup() + self.policy_model, self.optimizer, *_ = self.booster.boost(self.policy_model, self.optimizer) + self.reference_model, *_ = self.booster.boost(self.reference_model) + + def step(self, step_idx: int, **kwargs) -> Optional[float]: + """ + Step data from policy model: + [{ + "input_ids": torch.Tensor, + "attention_mask": torch.Tensor, + "action_mask": torch.Tensor, + "action_log_probs": torch.Tensor, + }, + ...] + Format: + [batch_size, prompt_length + response_length] --- ............. 
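+
+        Here "action" refers to the generated response tokens; "action_mask" marks
+        which response positions hold real tokens rather than padding.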
+ """ + labels = kwargs["input_ids"].clone() + labels[kwargs["attention_mask"] == 0] = -100 + kwargs["labels"] = labels + sequences = kwargs["input_ids"] + action_mask = kwargs["action_mask"] + num_action = action_mask.shape[1] + old_action_log_probs = kwargs["action_log_probs"] + assert kwargs.pop("action_mask").shape == kwargs.pop("action_log_probs").shape + + need_update = (step_idx + 1) % self.num_microbatches == 0 + + ctx = nullcontext() if need_update else self.booster.no_sync(self.policy_model, self.optimizer) + with ctx: + policy_model_logits = self.policy_model( + input_ids=kwargs["input_ids"], + attention_mask=kwargs["attention_mask"], + )["logits"] + action_log_probs = calc_action_log_probs(policy_model_logits, sequences, num_action) + + reference_model_logits = self.reference_model( + input_ids=sequences, + attention_mask=kwargs["attention_mask"], + )["logits"] + reference_action_log_probs = calc_action_log_probs(reference_model_logits, sequences, num_action) + + # GRPO advantage calculation + kl = torch.sum(-0.01 * (action_log_probs - reference_action_log_probs) * action_mask, dim=-1) / torch.sum( + action_mask, dim=-1 + ) + + reward = self.reward_model(sequences, gt_answer=kwargs["gt_answer"]) + reward = reward + kl + mean = reward.view(-1, reward.size(0)).mean(dim=1) + std = reward.view(-1, reward.size(0)).std(dim=1) + advantages = (reward - mean) / (std + 1e-4) + # Calculate Loss + loss, skip_update, _ = self.policy_loss_fn( + action_log_probs, + old_action_log_probs, + advantages.unsqueeze(dim=-1).repeat_interleave(action_log_probs.size(-1), dim=-1), + action_mask, + ) + + loss = loss / self.num_microbatches + self.accum_loss.add_(loss.data) + if not skip_update: + self.booster.backward(loss, self.optimizer) + if need_update: + self.optimizer.step() + self.optimizer.zero_grad() + loss_scalar = self.accum_loss.item() + self.accum_loss.zero_() + return loss_scalar + + def state_dict(self): + self.policy_model._force_wait_all_gather() + model = self.policy_model.unwrap() + state_dict = model.state_dict() + return state_dict diff --git a/applications/ColossalChat/coati/distributed/inference_backend.py b/applications/ColossalChat/coati/distributed/inference_backend.py index 95b7d1e80308..210ed503635c 100644 --- a/applications/ColossalChat/coati/distributed/inference_backend.py +++ b/applications/ColossalChat/coati/distributed/inference_backend.py @@ -210,6 +210,8 @@ def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwar "action_log_probs": log_probs, "action_mask": action_mask, } + if "gt_answer" in kwargs: + data["gt_answer"] = kwargs["gt_answer"] data = {k: v.to(get_current_device()) for k, v in data.items()} return data diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py index 438c463002f0..5244cc7d9fed 100644 --- a/applications/ColossalChat/coati/distributed/launch.py +++ b/applications/ColossalChat/coati/distributed/launch.py @@ -2,7 +2,7 @@ import ray -from .consumer import SimpleConsumer +from .grpo_consumer import GRPOConsumer from .producer import SimpleProducer @@ -68,7 +68,7 @@ def launch_distributed( ) procs.append(producer) for i in range(num_consumer_procs): - consumer = SimpleConsumer.options(num_gpus=1).remote( + consumer = GRPOConsumer.options(num_gpus=1).remote( num_producers=num_producers, num_episodes=num_episodes, rank=i, diff --git a/applications/ColossalChat/coati/distributed/loss.py b/applications/ColossalChat/coati/distributed/loss.py new file mode 100644 
index 000000000000..c08acba511cf --- /dev/null +++ b/applications/ColossalChat/coati/distributed/loss.py @@ -0,0 +1,44 @@ +from typing import Optional + +import torch +import torch.nn as nn +from coati.distributed.utils import masked_mean + + +class PolicyLoss(nn.Module): + """ + Policy Loss for PPO + """ + + def __init__(self, clip_eps: float = 0.2, skip_threshold: float = 20.0) -> None: + super().__init__() + self.clip_eps = clip_eps + self.skip_threshold = skip_threshold + + def forward( + self, + log_probs: torch.Tensor, + old_log_probs: torch.Tensor, + advantages: torch.Tensor, + action_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + skip = False + if action_mask is None: + ratio_ = (log_probs - old_log_probs).exp() + else: + ratio_ = ((log_probs - old_log_probs) * action_mask).exp() + + # note that if dropout is disabled (recommanded), ratio will always be 1. + if ratio_.mean() > self.skip_threshold: + skip = True + + ratio = ratio_.clamp(0.0, 10.0) + surr1 = ratio * advantages + surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages + loss = -torch.min(surr1, surr2) + if action_mask is not None: + loss = masked_mean(loss, action_mask) + else: + loss = loss.mean(dim=1) + loss = loss.mean() + return loss, skip, ratio_.max() diff --git a/applications/ColossalChat/coati/distributed/reward/reward_fn.py b/applications/ColossalChat/coati/distributed/reward/reward_fn.py index f127a1eced20..c7b452c54545 100644 --- a/applications/ColossalChat/coati/distributed/reward/reward_fn.py +++ b/applications/ColossalChat/coati/distributed/reward/reward_fn.py @@ -3,17 +3,13 @@ from .reward_utils import extract_solution, validate_response_structure -def math_reward_fn(input_ids, **kwargs): - # apply varifiable reward - # reward 10 points if the final answer is correct, reward 1 point if format is correct - - gt_answer = kwargs["gt_answer"] +def math_reward_fn(input_ids, gt_answer, **kwargs): tokenizer = kwargs["tokenizer"] - s, e = kwargs["response_start"], kwargs["response_end"] reward = torch.tensor(0.0).to(input_ids.device) if gt_answer is None: return reward - decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True) + decoded_final_answer = tokenizer.decode(input_ids, skip_special_tokens=True) + gt_answer = tokenizer.decode(gt_answer.squeeze(0)) final_answer, processed_str = extract_solution(decoded_final_answer) format_valid = validate_response_structure(processed_str, kwargs["tags"]) diff --git a/applications/ColossalChat/coati/distributed/reward/verifiable_reward.py b/applications/ColossalChat/coati/distributed/reward/verifiable_reward.py index d1700d86f14e..fe889a7f4a46 100644 --- a/applications/ColossalChat/coati/distributed/reward/verifiable_reward.py +++ b/applications/ColossalChat/coati/distributed/reward/verifiable_reward.py @@ -8,33 +8,27 @@ class VerifiableReward: - def __init__(self, reward_fn: List[callable], reward_args: List[Dict[str, Any]]): - self.reward_fn = reward_fn - self.reward_args = reward_args + def __init__(self, reward_fns: List[callable], **kwargs: List[Dict[str, Any]]): + self.reward_fns = reward_fns + self.kwargs = kwargs def __call__( self, input_ids: torch.LongTensor, - attention_mask: torch.LongTensor, - response_start: List[int] = None, - response_end: List[int] = None, - gt_answer: List[str] = None, + gt_answer: List[torch.Tensor] = None, ) -> torch.Tensor: # Get batch size bs = input_ids.size(0) # Initialize reward - reward = torch.zeros(bs, device=input_ids.device) + rewards = torch.zeros(bs, 
device=input_ids.device) # Loop through reward functions - for reward_fn in self.reward_fn_list: + for reward_fn in self.reward_fns: # Apply the reward function to the entire batch at once reward_batch = torch.stack( [ reward_fn( input_ids[i], - attention_mask[i], - response_start=response_start[i], - response_end=response_end[i], gt_answer=gt_answer[i], **self.kwargs, ) diff --git a/applications/ColossalChat/coati/distributed/utils.py b/applications/ColossalChat/coati/distributed/utils.py index 533a5ffb22da..98b54815b5b4 100644 --- a/applications/ColossalChat/coati/distributed/utils.py +++ b/applications/ColossalChat/coati/distributed/utils.py @@ -64,3 +64,38 @@ def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.T log_probs = torch.log_softmax(logits, dim=-1) per_label_logps = log_probs.gather(dim=-1, index=labels.unsqueeze(-1)) return per_label_logps.squeeze(-1) + + +def calc_action_log_probs(logits: torch.Tensor, sequences: torch.LongTensor, num_actions: int) -> torch.Tensor: + """Calculate action log probs. + + Args: + output (torch.Tensor): Output tensor of Actor.forward.logits. + sequences (torch.LongTensor): Input sequences. + num_actions (int): Number of actions. + + Returns: + torch.Tensor: Action log probs. + """ + log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:]) + return log_probs[:, -num_actions:] + + +def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor: + """ + Compute the masked mean of a tensor along a specified dimension. + + Args: + tensor (torch.Tensor): The input tensor. + mask (torch.Tensor): The mask tensor with the same shape as the input tensor. + dim (int, optional): The dimension along which to compute the mean. Default is 1. + + Returns: + torch.Tensor: The masked mean tensor. 
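+
+    Example:
+        masked_mean(torch.tensor([[1.0, 2.0, 3.0]]), torch.tensor([[1.0, 1.0, 0.0]]), dim=1)
+        # ~= tensor([1.5]); only the unmasked values 1.0 and 2.0 contribute.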
+ + """ + tensor = tensor * mask + tensor = tensor.sum(dim=dim) + mask_sum = mask.sum(dim=dim) + mean = tensor / (mask_sum + 1e-8) + return mean From f736d747e3b3c2601a15d240db669607e8aacba9 Mon Sep 17 00:00:00 2001 From: Tong Li Date: Tue, 25 Feb 2025 18:12:04 +0800 Subject: [PATCH 03/35] update grpo --- .../coati/distributed/grpo_consumer.py | 70 +++++++++++++------ 1 file changed, 50 insertions(+), 20 deletions(-) diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py index 79128b89e17a..d88df2360df7 100644 --- a/applications/ColossalChat/coati/distributed/grpo_consumer.py +++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py @@ -3,11 +3,13 @@ import ray import torch +import wandb from coati.distributed.consumer import BaseConsumer from coati.distributed.loss import PolicyLoss from coati.distributed.reward.reward_fn import math_reward_fn from coati.distributed.reward.verifiable_reward import VerifiableReward from coati.distributed.utils import calc_action_log_probs +from coati.trainer.utils import all_reduce_mean, is_rank_0 from transformers import AutoModelForCausalLM, AutoTokenizer from colossalai.nn.optimizer import HybridAdam @@ -29,6 +31,8 @@ def __init__( model_config, plugin_config, microbatch_size=1, + num_generations=4, + use_wandb=False, ): super().__init__( num_producers, @@ -50,6 +54,8 @@ def __init__( self.policy_model.gradient_checkpointing_enable() self.optimizer = HybridAdam(self.policy_model.parameters(), lr=1e-4) self.accum_loss = torch.zeros(1, device=self.device) + self.accum_reward = torch.zeros(1, device=self.device) + self.accum_kl = torch.zeros(1, device=self.device) # Reference model is initialized from policy model. self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config) @@ -57,6 +63,7 @@ def __init__( self.tokenizer = AutoTokenizer.from_pretrained(path) self.pad_token_id = self.tokenizer.pad_token_id + self.num_generations = num_generations # Initialize verifiable reward. response_format_tags = { @@ -70,6 +77,8 @@ def __init__( ) self.policy_loss_fn = PolicyLoss() + if is_rank_0(): + self.run = wandb.init(project="Colossal-GRPO-Test4") def setup(self): super().setup() @@ -87,43 +96,52 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]: }, ...] Format: - [batch_size, prompt_length + response_length] --- ............. + [batch_size, num_of_generation, prompt_length + response_length] --- ............. 
""" - labels = kwargs["input_ids"].clone() - labels[kwargs["attention_mask"] == 0] = -100 - kwargs["labels"] = labels - sequences = kwargs["input_ids"] - action_mask = kwargs["action_mask"] + + # Reshape to [batch_size x num_of_generation, prompt_length + response_length] + data = {k: v.view(-1, v.size(-1)) for k, v in kwargs.items()} + action_mask = data["action_mask"] num_action = action_mask.shape[1] - old_action_log_probs = kwargs["action_log_probs"] - assert kwargs.pop("action_mask").shape == kwargs.pop("action_log_probs").shape + old_action_log_probs = data["action_log_probs"] need_update = (step_idx + 1) % self.num_microbatches == 0 ctx = nullcontext() if need_update else self.booster.no_sync(self.policy_model, self.optimizer) with ctx: policy_model_logits = self.policy_model( - input_ids=kwargs["input_ids"], - attention_mask=kwargs["attention_mask"], + input_ids=data["input_ids"], + attention_mask=data["attention_mask"], )["logits"] - action_log_probs = calc_action_log_probs(policy_model_logits, sequences, num_action) + action_log_probs = calc_action_log_probs(policy_model_logits, data["input_ids"], num_action) reference_model_logits = self.reference_model( - input_ids=sequences, - attention_mask=kwargs["attention_mask"], + input_ids=data["input_ids"], + attention_mask=data["attention_mask"], )["logits"] - reference_action_log_probs = calc_action_log_probs(reference_model_logits, sequences, num_action) + reference_action_log_probs = calc_action_log_probs(reference_model_logits, data["input_ids"], num_action) + + # GRPO advantage calculation + kl = torch.sum(-0.1 * (action_log_probs - reference_action_log_probs) * action_mask, dim=-1) / torch.sum( + action_mask, dim=-1 + ) + + reward = self.reward_model(data["input_ids"], gt_answer=data["gt_answer"]) + reward = kl + reward + # [batch_size, num_generations] + group_reward = reward.view(-1, self.num_generations) + + # [batch_size x num_generations] + reward_mean = group_reward.mean(dim=1).repeat_interleave(self.num_generations, dim=0) + reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0) + # [batch_size x num_generations] + advantages = (group_reward.view(-1) - reward_mean) / (reward_std + 1e-4) # GRPO advantage calculation kl = torch.sum(-0.01 * (action_log_probs - reference_action_log_probs) * action_mask, dim=-1) / torch.sum( action_mask, dim=-1 ) - reward = self.reward_model(sequences, gt_answer=kwargs["gt_answer"]) - reward = reward + kl - mean = reward.view(-1, reward.size(0)).mean(dim=1) - std = reward.view(-1, reward.size(0)).std(dim=1) - advantages = (reward - mean) / (std + 1e-4) # Calculate Loss loss, skip_update, _ = self.policy_loss_fn( action_log_probs, @@ -133,14 +151,26 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]: ) loss = loss / self.num_microbatches - self.accum_loss.add_(loss.data) if not skip_update: self.booster.backward(loss, self.optimizer) + loss = all_reduce_mean(loss) + reward = all_reduce_mean(reward.mean()) + kl = all_reduce_mean(kl.mean()) + self.accum_loss.add_(loss.data) + self.accum_reward.add_(reward.data) + self.accum_kl.add_(kl.data) if need_update: self.optimizer.step() self.optimizer.zero_grad() loss_scalar = self.accum_loss.item() + if is_rank_0(): + print("Loss:", self.accum_loss.item(), "Reward:", self.accum_reward.item(), "KL:", self.accum_kl.item()) + self.run.log( + {"loss": self.accum_loss.item(), "reward": self.accum_reward.item(), "kl": self.accum_kl.item()} + ) self.accum_loss.zero_() + self.accum_reward.zero_() + self.accum_kl.zero_() 
return loss_scalar def state_dict(self): From 070907dd7fe9d2dcf32e06b928a54a73b406bd83 Mon Sep 17 00:00:00 2001 From: Tong Li Date: Fri, 28 Feb 2025 10:16:42 +0800 Subject: [PATCH 04/35] polish --- .../ColossalChat/coati/dataset/loader.py | 27 ++++++++----- .../coati/distributed/grpo_consumer.py | 40 ++++++++++++++----- .../coati/distributed/inference_backend.py | 15 ++++++- .../coati/distributed/reward/reward_fn.py | 8 ++-- .../distributed/reward/verifiable_reward.py | 2 + applications/ColossalChat/rl_example.py | 8 +++- 6 files changed, 74 insertions(+), 26 deletions(-) diff --git a/applications/ColossalChat/coati/dataset/loader.py b/applications/ColossalChat/coati/dataset/loader.py index 04093e7057ae..bfee5dce0304 100755 --- a/applications/ColossalChat/coati/dataset/loader.py +++ b/applications/ColossalChat/coati/dataset/loader.py @@ -356,6 +356,14 @@ def apply_chat_template_and_mask( truncation: bool = True, ignore_idx: int = -100, ) -> Dict[str, torch.Tensor]: + + system_prompt = "You are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within and tags, respectively, i.e., reasoning process here answer here . Now the user asks you to solve a math problem that involves reasoning. After thinking, when you finally reach a conclusion, clearly output the final answer without explanation within the tags, your final answer should be a integer without unit, currency mark, thousands separator or other text. i.e., 123 .\n" + + system_element = { + "role": "system", + "content": system_prompt, + } + # Format for RL. gt_answer = None if "messages" in chat and "gt_answer" in chat: @@ -365,7 +373,7 @@ def apply_chat_template_and_mask( tokens = [] assistant_mask = [] for i, msg in enumerate(chat): - msg_tokens = tokenizer.apply_chat_template([msg], tokenize=True) + msg_tokens = tokenizer.apply_chat_template([system_element, msg], tokenize=True, add_generation_prompt=True) # remove unexpected bos token if i > 0 and msg_tokens[0] == tokenizer.bos_token_id: msg_tokens = msg_tokens[1:] @@ -378,14 +386,15 @@ def apply_chat_template_and_mask( if max_length is not None: if padding and len(tokens) < max_length: to_pad = max_length - len(tokens) - if tokenizer.padding_side == "right": - tokens.extend([tokenizer.pad_token_id] * to_pad) - assistant_mask.extend([False] * to_pad) - attention_mask.extend([0] * to_pad) - else: - tokens = [tokenizer.pad_token_id] * to_pad + tokens - assistant_mask = [False] * to_pad + assistant_mask - attention_mask = [0] * to_pad + attention_mask + # Left padding for generation. 
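+            # Decoder-only generation appends new tokens on the right, so prompts are
+            # padded on the left to keep the prompt tokens adjacent to the generated ones.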
+ # if tokenizer.padding_side == "right": + # tokens.extend([tokenizer.pad_token_id] * to_pad) + # assistant_mask.extend([False] * to_pad) + # attention_mask.extend([0] * to_pad) + # else: + tokens = [tokenizer.pad_token_id] * to_pad + tokens + assistant_mask = [False] * to_pad + assistant_mask + attention_mask = [0] * to_pad + attention_mask if truncation and len(tokens) > max_length: tokens = tokens[:max_length] assistant_mask = assistant_mask[:max_length] diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py index d88df2360df7..2f230f5ed574 100644 --- a/applications/ColossalChat/coati/distributed/grpo_consumer.py +++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py @@ -9,7 +9,7 @@ from coati.distributed.reward.reward_fn import math_reward_fn from coati.distributed.reward.verifiable_reward import VerifiableReward from coati.distributed.utils import calc_action_log_probs -from coati.trainer.utils import all_reduce_mean, is_rank_0 +from coati.trainer.utils import all_reduce_mean from transformers import AutoModelForCausalLM, AutoTokenizer from colossalai.nn.optimizer import HybridAdam @@ -77,8 +77,15 @@ def __init__( ) self.policy_loss_fn = PolicyLoss() - if is_rank_0(): - self.run = wandb.init(project="Colossal-GRPO-Test4") + self.global_step = 0 + if self.rank == 0: + self.wandb_run = wandb.init(project="Colossal-GRPO-Test6", sync_tensorboard=True) + # import os + # import time + + # log_dir = self.wandb_run.dir + # # log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime())) + # # self.writer = SummaryWriter(log_dir=log_dir) def setup(self): super().setup() @@ -115,10 +122,11 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]: )["logits"] action_log_probs = calc_action_log_probs(policy_model_logits, data["input_ids"], num_action) - reference_model_logits = self.reference_model( - input_ids=data["input_ids"], - attention_mask=data["attention_mask"], - )["logits"] + with torch.no_grad(): + reference_model_logits = self.reference_model( + input_ids=data["input_ids"], + attention_mask=data["attention_mask"], + )["logits"] reference_action_log_probs = calc_action_log_probs(reference_model_logits, data["input_ids"], num_action) # GRPO advantage calculation @@ -126,7 +134,9 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]: action_mask, dim=-1 ) - reward = self.reward_model(data["input_ids"], gt_answer=data["gt_answer"]) + reward = self.reward_model( + data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"] + ) reward = kl + reward # [batch_size, num_generations] group_reward = reward.view(-1, self.num_generations) @@ -163,11 +173,19 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]: self.optimizer.step() self.optimizer.zero_grad() loss_scalar = self.accum_loss.item() - if is_rank_0(): + if self.rank == 0: print("Loss:", self.accum_loss.item(), "Reward:", self.accum_reward.item(), "KL:", self.accum_kl.item()) - self.run.log( - {"loss": self.accum_loss.item(), "reward": self.accum_reward.item(), "kl": self.accum_kl.item()} + self.wandb_run.log( + { + "train/loss": self.accum_loss.item(), + "train/reward": self.accum_reward.item(), + "train/kl": self.accum_kl.item(), + } ) + # self.writer.add_scalar("train/loss", self.accum_loss.item(), self.global_step) + # self.writer.add_scalar("train/reward", self.accum_reward.item(), self.global_step) + # self.writer.add_scalar("train/kl", self.accum_kl.item(), 
self.global_step) + # self.global_step += 1 self.accum_loss.zero_() self.accum_reward.zero_() self.accum_kl.zero_() diff --git a/applications/ColossalChat/coati/distributed/inference_backend.py b/applications/ColossalChat/coati/distributed/inference_backend.py index 210ed503635c..bc0ae5c36673 100644 --- a/applications/ColossalChat/coati/distributed/inference_backend.py +++ b/applications/ColossalChat/coati/distributed/inference_backend.py @@ -154,6 +154,7 @@ class VLLMInferenceBackend(BaseInferenceBackend): ) FORCE_GENERATE_CONFIG = dict( logprobs=0, + n=4, ) def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer): @@ -166,19 +167,24 @@ def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any] generate_config.update(self.FORCE_GENERATE_CONFIG) self.generate_config = SamplingParams(**generate_config) self.tokenizer = tokenizer + self.num_generations = self.FORCE_GENERATE_CONFIG["n"] @torch.no_grad() def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]: + micro_batch_size = input_ids.size(0) + response_start_idx = input_ids.size(1) outputs = self.llm.generate( prompt_token_ids=input_ids.tolist(), sampling_params=self.generate_config, use_tqdm=False ) out_tokens = [] out_len = [] log_probs = [] + response_idx = [] for out in outputs: for output_i in out.outputs: out_len.append(len(output_i.token_ids)) out_tokens.append(list(output_i.token_ids)) + response_idx.append((response_start_idx, response_start_idx + len(output_i.token_ids) - 1)) assert len(output_i.logprobs) == len(output_i.token_ids) p = [m[t].logprob for m, t in zip(output_i.logprobs, output_i.token_ids)] log_probs.append(p) @@ -195,6 +201,8 @@ def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwar out_tokens = torch.tensor(out_tokens) log_probs = torch.tensor(log_probs) + response_idx = torch.tensor(response_idx) + if attention_mask.size(0) != action_mask.size(0): assert action_mask.size(0) % attention_mask.size(0) == 0 num_returns = action_mask.size(0) // attention_mask.size(0) @@ -209,9 +217,14 @@ def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwar "attention_mask": attention_mask, "action_log_probs": log_probs, "action_mask": action_mask, + "response_idx": response_idx, } + + data = {k: v.view(micro_batch_size, self.num_generations, v.size(-1)) for k, v in data.items()} + if "gt_answer" in kwargs: - data["gt_answer"] = kwargs["gt_answer"] + # repeat gt_answer for each prompt. 
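+            # gt_answer is expected to arrive as [micro_batch_size, 1, answer_len], so
+            # repeating along dim 1 aligns it with the reshaped
+            # [micro_batch_size, num_generations, -1] tensors above.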
+ data["gt_answer"] = kwargs["gt_answer"].repeat_interleave(self.num_generations, dim=1) data = {k: v.to(get_current_device()) for k, v in data.items()} return data diff --git a/applications/ColossalChat/coati/distributed/reward/reward_fn.py b/applications/ColossalChat/coati/distributed/reward/reward_fn.py index c7b452c54545..c92f822f7373 100644 --- a/applications/ColossalChat/coati/distributed/reward/reward_fn.py +++ b/applications/ColossalChat/coati/distributed/reward/reward_fn.py @@ -3,12 +3,14 @@ from .reward_utils import extract_solution, validate_response_structure -def math_reward_fn(input_ids, gt_answer, **kwargs): +def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs): tokenizer = kwargs["tokenizer"] reward = torch.tensor(0.0).to(input_ids.device) + s, e = response_idx[0], response_idx[1] if gt_answer is None: return reward - decoded_final_answer = tokenizer.decode(input_ids, skip_special_tokens=True) + + decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True) gt_answer = tokenizer.decode(gt_answer.squeeze(0)) final_answer, processed_str = extract_solution(decoded_final_answer) @@ -29,7 +31,7 @@ def gsm8k_reward_fn(input_ids, **kwargs): reward = torch.tensor(0.0).to(input_ids.device) if gt_answer is None: return reward - decoded_final_answer = tokenizer.decode(input_ids[s:e], skip_special_tokens=True) + decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True) final_answer, processed_str = extract_solution(decoded_final_answer) is_valid = True try: diff --git a/applications/ColossalChat/coati/distributed/reward/verifiable_reward.py b/applications/ColossalChat/coati/distributed/reward/verifiable_reward.py index fe889a7f4a46..b43ba65c0ab7 100644 --- a/applications/ColossalChat/coati/distributed/reward/verifiable_reward.py +++ b/applications/ColossalChat/coati/distributed/reward/verifiable_reward.py @@ -16,6 +16,7 @@ def __call__( self, input_ids: torch.LongTensor, gt_answer: List[torch.Tensor] = None, + response_idx: List[torch.Tensor] = None, ) -> torch.Tensor: # Get batch size bs = input_ids.size(0) @@ -30,6 +31,7 @@ def __call__( reward_fn( input_ids[i], gt_answer=gt_answer[i], + response_idx=response_idx[i], **self.kwargs, ) for i in range(bs) diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py index a6f82b3bebf9..1b5c18486b35 100644 --- a/applications/ColossalChat/rl_example.py +++ b/applications/ColossalChat/rl_example.py @@ -51,13 +51,17 @@ elif args.backend == "vllm": inference_model_config.update( dict( - gpu_memory_utilization=0.6, + gpu_memory_utilization=0.7, ) ) generate_config.update( dict( - max_tokens=256, + max_tokens=2048, ignore_eos=True, + include_stop_str_in_output=True, + stop=[""], + temperature=0.2, + top_p=0.95, ) ) else: From c15225bc528f341dbd7757e3279439a18e065a5d Mon Sep 17 00:00:00 2001 From: Tong Li Date: Thu, 6 Mar 2025 10:49:44 +0800 Subject: [PATCH 05/35] modify data loader --- applications/ColossalChat/coati/dataset/loader.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/applications/ColossalChat/coati/dataset/loader.py b/applications/ColossalChat/coati/dataset/loader.py index bfee5dce0304..265449326541 100755 --- a/applications/ColossalChat/coati/dataset/loader.py +++ b/applications/ColossalChat/coati/dataset/loader.py @@ -387,11 +387,6 @@ def apply_chat_template_and_mask( if padding and len(tokens) < max_length: to_pad = max_length - len(tokens) # Left padding for generation. 
- # if tokenizer.padding_side == "right": - # tokens.extend([tokenizer.pad_token_id] * to_pad) - # assistant_mask.extend([False] * to_pad) - # attention_mask.extend([0] * to_pad) - # else: tokens = [tokenizer.pad_token_id] * to_pad + tokens assistant_mask = [False] * to_pad + assistant_mask attention_mask = [0] * to_pad + attention_mask @@ -405,7 +400,9 @@ def apply_chat_template_and_mask( labels[~torch.tensor(assistant_mask, dtype=torch.bool)] = ignore_idx if gt_answer is not None: - gt_answer = tokenizer.encode(gt_answer, padding="max_length", max_length=64, return_tensors="pt") + gt_answer = tokenizer.encode( + gt_answer, padding="max_length", truncation=True, max_length=128, return_tensors="pt" + ) gt_answer = gt_answer.squeeze(1) return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels, "gt_answer": gt_answer} From b96d69055e693d8c84d62a00dae896e50d3a7e60 Mon Sep 17 00:00:00 2001 From: Tong Li Date: Thu, 6 Mar 2025 10:51:27 +0800 Subject: [PATCH 06/35] grpo consumer --- .../coati/distributed/grpo_consumer.py | 56 +++++++++---------- 1 file changed, 26 insertions(+), 30 deletions(-) diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py index 2f230f5ed574..49240d8da176 100644 --- a/applications/ColossalChat/coati/distributed/grpo_consumer.py +++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py @@ -52,10 +52,11 @@ def __init__( self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config) self.policy_model.train() self.policy_model.gradient_checkpointing_enable() - self.optimizer = HybridAdam(self.policy_model.parameters(), lr=1e-4) + self.optimizer = HybridAdam(self.policy_model.parameters(), lr=1e-6) self.accum_loss = torch.zeros(1, device=self.device) self.accum_reward = torch.zeros(1, device=self.device) self.accum_kl = torch.zeros(1, device=self.device) + self.accum_count = 0 # Reference model is initialized from policy model. 
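+        # The reference model stays in eval mode, is excluded from the optimizer and
+        # is only queried (without gradients) for the KL penalty term.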
self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config) @@ -79,13 +80,7 @@ def __init__( self.policy_loss_fn = PolicyLoss() self.global_step = 0 if self.rank == 0: - self.wandb_run = wandb.init(project="Colossal-GRPO-Test6", sync_tensorboard=True) - # import os - # import time - - # log_dir = self.wandb_run.dir - # # log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime())) - # # self.writer = SummaryWriter(log_dir=log_dir) + self.wandb_run = wandb.init(project="GRPO-Test", sync_tensorboard=True) def setup(self): super().setup() @@ -129,15 +124,16 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]: )["logits"] reference_action_log_probs = calc_action_log_probs(reference_model_logits, data["input_ids"], num_action) - # GRPO advantage calculation - kl = torch.sum(-0.1 * (action_log_probs - reference_action_log_probs) * action_mask, dim=-1) / torch.sum( - action_mask, dim=-1 + per_token_kl = ( + torch.exp(reference_action_log_probs - action_log_probs) + - (reference_action_log_probs - action_log_probs) + - 1 ) + kl = torch.sum(per_token_kl * action_mask, dim=-1) / torch.sum(action_mask, dim=-1) reward = self.reward_model( data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"] ) - reward = kl + reward # [batch_size, num_generations] group_reward = reward.view(-1, self.num_generations) @@ -145,50 +141,50 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]: reward_mean = group_reward.mean(dim=1).repeat_interleave(self.num_generations, dim=0) reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0) # [batch_size x num_generations] - advantages = (group_reward.view(-1) - reward_mean) / (reward_std + 1e-4) - - # GRPO advantage calculation - kl = torch.sum(-0.01 * (action_log_probs - reference_action_log_probs) * action_mask, dim=-1) / torch.sum( - action_mask, dim=-1 - ) + advantages = (reward - reward_mean) / (reward_std + 1e-4) # Calculate Loss loss, skip_update, _ = self.policy_loss_fn( action_log_probs, old_action_log_probs, advantages.unsqueeze(dim=-1).repeat_interleave(action_log_probs.size(-1), dim=-1), + per_token_kl, action_mask, ) - loss = loss / self.num_microbatches if not skip_update: self.booster.backward(loss, self.optimizer) - loss = all_reduce_mean(loss) - reward = all_reduce_mean(reward.mean()) - kl = all_reduce_mean(kl.mean()) + loss = all_reduce_mean(loss, self.plugin) + reward = all_reduce_mean(reward.mean(), self.plugin) + kl = all_reduce_mean(kl.mean(), self.plugin) self.accum_loss.add_(loss.data) self.accum_reward.add_(reward.data) self.accum_kl.add_(kl.data) + self.accum_count += 1 if need_update: self.optimizer.step() self.optimizer.zero_grad() loss_scalar = self.accum_loss.item() if self.rank == 0: - print("Loss:", self.accum_loss.item(), "Reward:", self.accum_reward.item(), "KL:", self.accum_kl.item()) + print( + "Loss:", + self.accum_loss.item() / self.accum_count, + "Reward:", + self.accum_reward.item() / self.accum_count, + "KL:", + self.accum_kl.item() / self.accum_count, + ) self.wandb_run.log( { - "train/loss": self.accum_loss.item(), - "train/reward": self.accum_reward.item(), - "train/kl": self.accum_kl.item(), + "train/loss": self.accum_loss.item() / self.accum_count, + "train/reward": self.accum_reward.item() / self.accum_count, + "train/kl": self.accum_kl.item() / self.accum_count, } ) - # self.writer.add_scalar("train/loss", self.accum_loss.item(), self.global_step) - # self.writer.add_scalar("train/reward", 
self.accum_reward.item(), self.global_step) - # self.writer.add_scalar("train/kl", self.accum_kl.item(), self.global_step) - # self.global_step += 1 self.accum_loss.zero_() self.accum_reward.zero_() self.accum_kl.zero_() + self.accum_count = 0 return loss_scalar def state_dict(self): From 678f5a9ecac480f5f9dfe19c9aa5668e53b77976 Mon Sep 17 00:00:00 2001 From: Tong Li Date: Thu, 6 Mar 2025 10:53:03 +0800 Subject: [PATCH 07/35] update loss --- applications/ColossalChat/coati/distributed/loss.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/applications/ColossalChat/coati/distributed/loss.py b/applications/ColossalChat/coati/distributed/loss.py index c08acba511cf..222540b9270d 100644 --- a/applications/ColossalChat/coati/distributed/loss.py +++ b/applications/ColossalChat/coati/distributed/loss.py @@ -10,16 +10,18 @@ class PolicyLoss(nn.Module): Policy Loss for PPO """ - def __init__(self, clip_eps: float = 0.2, skip_threshold: float = 20.0) -> None: + def __init__(self, clip_eps: float = 0.2, skip_threshold: float = 20.0, beta: float = 0.01) -> None: super().__init__() self.clip_eps = clip_eps self.skip_threshold = skip_threshold + self.beta = beta def forward( self, log_probs: torch.Tensor, old_log_probs: torch.Tensor, advantages: torch.Tensor, + per_token_kl: torch.Tensor, action_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: skip = False @@ -35,7 +37,8 @@ def forward( ratio = ratio_.clamp(0.0, 10.0) surr1 = ratio * advantages surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages - loss = -torch.min(surr1, surr2) + loss = -torch.min(surr1, surr2) + self.beta * per_token_kl + if action_mask is not None: loss = masked_mean(loss, action_mask) else: From d03cdea949425f02b0f147e8b604f6d770a5cfeb Mon Sep 17 00:00:00 2001 From: Tong Li Date: Thu, 6 Mar 2025 10:53:48 +0800 Subject: [PATCH 08/35] update reward fn --- .../ColossalChat/coati/distributed/reward/reward_fn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/applications/ColossalChat/coati/distributed/reward/reward_fn.py b/applications/ColossalChat/coati/distributed/reward/reward_fn.py index c92f822f7373..9e6d1066e1df 100644 --- a/applications/ColossalChat/coati/distributed/reward/reward_fn.py +++ b/applications/ColossalChat/coati/distributed/reward/reward_fn.py @@ -11,7 +11,7 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs): return reward decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True) - gt_answer = tokenizer.decode(gt_answer.squeeze(0)) + gt_answer = tokenizer.decode(gt_answer.squeeze(0), skip_special_tokens=True) final_answer, processed_str = extract_solution(decoded_final_answer) format_valid = validate_response_structure(processed_str, kwargs["tags"]) @@ -20,7 +20,7 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs): else: reward += 1.0 if gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower(): - reward = reward + 9.0 + reward = reward + 2.0 return reward From 7f2ceac5c3acc75b8366062970eb124ce0a56e2c Mon Sep 17 00:00:00 2001 From: Tong Li Date: Thu, 6 Mar 2025 10:54:23 +0800 Subject: [PATCH 09/35] update example --- applications/ColossalChat/rl_example.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py index 1b5c18486b35..57cab4164f9d 100644 --- a/applications/ColossalChat/rl_example.py +++ b/applications/ColossalChat/rl_example.py @@ -60,7 +60,7 @@ 
ignore_eos=True, include_stop_str_in_output=True, stop=[""], - temperature=0.2, + temperature=0.7, top_p=0.95, ) ) From 812f4b775040c3de0fb40e973190a9cf68abe93c Mon Sep 17 00:00:00 2001 From: Tong Li Date: Thu, 6 Mar 2025 11:44:42 +0800 Subject: [PATCH 10/35] update loader --- applications/ColossalChat/coati/dataset/loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/applications/ColossalChat/coati/dataset/loader.py b/applications/ColossalChat/coati/dataset/loader.py index 265449326541..35b3deced717 100755 --- a/applications/ColossalChat/coati/dataset/loader.py +++ b/applications/ColossalChat/coati/dataset/loader.py @@ -357,7 +357,7 @@ def apply_chat_template_and_mask( ignore_idx: int = -100, ) -> Dict[str, torch.Tensor]: - system_prompt = "You are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within and tags, respectively, i.e., reasoning process here answer here . Now the user asks you to solve a math problem that involves reasoning. After thinking, when you finally reach a conclusion, clearly output the final answer without explanation within the tags, your final answer should be a integer without unit, currency mark, thousands separator or other text. i.e., 123 .\n" + system_prompt = "You are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within and tags, respectively, i.e., reasoning process here answer here . Now the user asks you to solve a math problem that involves reasoning.\n" system_element = { "role": "system", From 0f566cc2d49657fec09c34aa50389f951f55103b Mon Sep 17 00:00:00 2001 From: Tong Li Date: Thu, 6 Mar 2025 14:29:22 +0800 Subject: [PATCH 11/35] add algo selection --- .../ColossalChat/coati/distributed/launch.py | 15 ++++++++++++++- applications/ColossalChat/rl_example.py | 2 ++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py index 5244cc7d9fed..8581ff5865f8 100644 --- a/applications/ColossalChat/coati/distributed/launch.py +++ b/applications/ColossalChat/coati/distributed/launch.py @@ -2,9 +2,15 @@ import ray +from .consumer import SimpleConsumer from .grpo_consumer import GRPOConsumer from .producer import SimpleProducer +ALGO_MAP = { + "Simple": SimpleConsumer, + "GRPO": GRPOConsumer, +} + def get_jsonl_size_fast(path: str) -> int: with open(path) as f: @@ -40,7 +46,14 @@ def launch_distributed( inference_backend: str = "transformers", master_addr: str = "localhost", master_port: int = 29500, + core_algo: str = "GRPO", ): + + if core_algo not in ALGO_MAP: + raise NotImplementedError(f"{core_algo} is not supported yet.") + else: + core_consumer = ALGO_MAP.get(core_algo, SimpleConsumer) + train_dp_size = get_dp_size_fast(num_producers, plugin_config) assert (inference_batch_size * num_producers) % (train_batch_size * train_dp_size) == 0 @@ -68,7 +81,7 @@ def launch_distributed( ) procs.append(producer) for i in range(num_consumer_procs): - consumer = GRPOConsumer.options(num_gpus=1).remote( + consumer = core_consumer.options(num_gpus=1).remote( num_producers=num_producers, num_episodes=num_episodes, rank=i, diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py index 57cab4164f9d..40231582d787 100644 --- 
a/applications/ColossalChat/rl_example.py +++ b/applications/ColossalChat/rl_example.py @@ -15,6 +15,7 @@ parser.add_argument("-tbs", "--train-batch-size", type=int, default=16) parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=2) parser.add_argument("-b", "--backend", type=str, default="transformers") + parser.add_argument("-a", "--algo", type=str, default="GPRO", choices=["Simple, GPRO"]) args = parser.parse_args() ray.init(address="local", namespace="ray-example") @@ -95,4 +96,5 @@ inference_backend=args.backend, master_addr="localhost", master_port=29504, + core_algo=args.algo ) From ab5b6d8432ea1287a21345993b1755e40c07a0c7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 6 Mar 2025 06:30:26 +0000 Subject: [PATCH 12/35] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- applications/ColossalChat/rl_example.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py index 40231582d787..77c567913523 100644 --- a/applications/ColossalChat/rl_example.py +++ b/applications/ColossalChat/rl_example.py @@ -96,5 +96,5 @@ inference_backend=args.backend, master_addr="localhost", master_port=29504, - core_algo=args.algo + core_algo=args.algo, ) From 0cc0c843ede88f9e7ae88988b99237f40950acd1 Mon Sep 17 00:00:00 2001 From: Tong Li Date: Thu, 6 Mar 2025 16:26:14 +0800 Subject: [PATCH 13/35] add save --- applications/ColossalChat/coati/dataset/loader.py | 4 +++- .../ColossalChat/coati/distributed/consumer.py | 14 +++++++++++++- .../coati/distributed/grpo_consumer.py | 4 ++-- 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/applications/ColossalChat/coati/dataset/loader.py b/applications/ColossalChat/coati/dataset/loader.py index 35b3deced717..1a6b04d43924 100755 --- a/applications/ColossalChat/coati/dataset/loader.py +++ b/applications/ColossalChat/coati/dataset/loader.py @@ -357,7 +357,9 @@ def apply_chat_template_and_mask( ignore_idx: int = -100, ) -> Dict[str, torch.Tensor]: - system_prompt = "You are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within and tags, respectively, i.e., reasoning process here answer here . Now the user asks you to solve a math problem that involves reasoning.\n" + system_prompt = "You are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within and tags, respectively, i.e., reasoning process here answer here . Now the user asks you to solve a math problem that involves reasoning. 
After thinking, when you finally reach a conclusion, clearly output the final answer without explanation within the tags, i.e., 123 .\n\n" + + system_element = { "role": "system", diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py index 84a69979fb88..a99198d7ed8f 100644 --- a/applications/ColossalChat/coati/distributed/consumer.py +++ b/applications/ColossalChat/coati/distributed/consumer.py @@ -1,6 +1,6 @@ from contextlib import nullcontext from typing import Any, Dict, Optional - +import os import ray import ray.util.collective as cc import torch @@ -33,6 +33,8 @@ def __init__( model_config: Dict[str, Any], plugin_config: Dict[str, Any], microbatch_size: int = 1, + save_interval: int = 100, + save_dir: str = "./model" ): self.num_producers = num_producers self.num_episodes = num_episodes @@ -44,6 +46,8 @@ def __init__( self.num_recv_per_update = num_recv_per_update self.batch_size = batch_size self.microbatch_size = microbatch_size + self.save_interval = save_interval + self.save_dir = save_dir assert batch_size % microbatch_size == 0, "batch_size should be divisible by microbatch_size" self.num_microbatches = batch_size // microbatch_size @@ -116,6 +120,14 @@ def loop(self) -> None: pbar.set_postfix({"loss": loss}) i += 1 assert len(self.buffer) == 0 + if (step + 1) % self.save_interval == 0: + if self.rank == 0: + print(f"Start saving policy model at step {step + 1}.") + save_path = os.path.join(self.save_dir, f"modeling-step-{step + 1}") + self.booster.save_model(self.policy_model, save_path, shard=True) + if self.rank == 0: + print(f"Saved model checkpoint at step {step + 1} in folder {save_path}") + if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1: print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}") state_dict = self.state_dict() diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py index 49240d8da176..ead0c86e032a 100644 --- a/applications/ColossalChat/coati/distributed/grpo_consumer.py +++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py @@ -32,7 +32,7 @@ def __init__( plugin_config, microbatch_size=1, num_generations=4, - use_wandb=False, + use_wandb=True, ): super().__init__( num_producers, @@ -79,7 +79,7 @@ def __init__( self.policy_loss_fn = PolicyLoss() self.global_step = 0 - if self.rank == 0: + if use_wandb and self.rank == 0: self.wandb_run = wandb.init(project="GRPO-Test", sync_tensorboard=True) def setup(self): From 0590f10fb78f45e9b94c4a928379262b9bfad4e0 Mon Sep 17 00:00:00 2001 From: Tong Li Date: Thu, 6 Mar 2025 16:27:13 +0800 Subject: [PATCH 14/35] update select algo --- applications/ColossalChat/rl_example.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py index 40231582d787..347ffeeae61b 100644 --- a/applications/ColossalChat/rl_example.py +++ b/applications/ColossalChat/rl_example.py @@ -15,7 +15,7 @@ parser.add_argument("-tbs", "--train-batch-size", type=int, default=16) parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=2) parser.add_argument("-b", "--backend", type=str, default="transformers") - parser.add_argument("-a", "--algo", type=str, default="GPRO", choices=["Simple, GPRO"]) + parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple, GPRO"]) args = parser.parse_args() 
ray.init(address="local", namespace="ray-example") From eb6337f07f5b01c181762ebb9e5338005552bca6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 6 Mar 2025 08:29:58 +0000 Subject: [PATCH 15/35] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- applications/ColossalChat/coati/dataset/loader.py | 2 -- applications/ColossalChat/coati/distributed/consumer.py | 5 +++-- applications/ColossalChat/coati/distributed/grpo_consumer.py | 2 +- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/applications/ColossalChat/coati/dataset/loader.py b/applications/ColossalChat/coati/dataset/loader.py index 1a6b04d43924..4518fd71f91b 100755 --- a/applications/ColossalChat/coati/dataset/loader.py +++ b/applications/ColossalChat/coati/dataset/loader.py @@ -359,8 +359,6 @@ def apply_chat_template_and_mask( system_prompt = "You are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within and tags, respectively, i.e., reasoning process here answer here . Now the user asks you to solve a math problem that involves reasoning. After thinking, when you finally reach a conclusion, clearly output the final answer without explanation within the tags, i.e., 123 .\n\n" - - system_element = { "role": "system", "content": system_prompt, diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py index a99198d7ed8f..1e85cccb3c5b 100644 --- a/applications/ColossalChat/coati/distributed/consumer.py +++ b/applications/ColossalChat/coati/distributed/consumer.py @@ -1,6 +1,7 @@ +import os from contextlib import nullcontext from typing import Any, Dict, Optional -import os + import ray import ray.util.collective as cc import torch @@ -34,7 +35,7 @@ def __init__( plugin_config: Dict[str, Any], microbatch_size: int = 1, save_interval: int = 100, - save_dir: str = "./model" + save_dir: str = "./model", ): self.num_producers = num_producers self.num_episodes = num_episodes diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py index ead0c86e032a..15f7e340ebb3 100644 --- a/applications/ColossalChat/coati/distributed/grpo_consumer.py +++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py @@ -79,7 +79,7 @@ def __init__( self.policy_loss_fn = PolicyLoss() self.global_step = 0 - if use_wandb and self.rank == 0: + if use_wandb and self.rank == 0: self.wandb_run = wandb.init(project="GRPO-Test", sync_tensorboard=True) def setup(self): From 9d9d51614e6aa9d047dc5ef4c7f89535f1e071a8 Mon Sep 17 00:00:00 2001 From: Tong Li Date: Mon, 10 Mar 2025 14:12:04 +0800 Subject: [PATCH 16/35] update grpo --- .../coati/distributed/grpo_consumer.py | 28 ++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py index 15f7e340ebb3..ae9f2c400bce 100644 --- a/applications/ColossalChat/coati/distributed/grpo_consumer.py +++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py @@ -56,6 +56,9 @@ def __init__( self.accum_loss = torch.zeros(1, device=self.device) self.accum_reward = torch.zeros(1, device=self.device) self.accum_kl = torch.zeros(1, device=self.device) + self.accum_format_reward = 
torch.zeros(1, device=self.device) + self.accum_acc_reward = torch.zeros(1, device=self.device) + self.accum_advantages = torch.zeros(1, device=self.device) self.accum_count = 0 # Reference model is initialized from policy model. @@ -131,9 +134,14 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]: ) kl = torch.sum(per_token_kl * action_mask, dim=-1) / torch.sum(action_mask, dim=-1) - reward = self.reward_model( + reward_group = self.reward_model( data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"] ) + + reward = torch.tensor([value[0] for value in reward_group]).to(data["input_ids"].device) + format_reward = torch.tensor([value[1] for value in reward_group]).to(data["input_ids"].device) + acc_reward = torch.tensor([value[2] for value in reward_group]).to(data["input_ids"].device) + # [batch_size, num_generations] group_reward = reward.view(-1, self.num_generations) @@ -157,9 +165,16 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]: loss = all_reduce_mean(loss, self.plugin) reward = all_reduce_mean(reward.mean(), self.plugin) kl = all_reduce_mean(kl.mean(), self.plugin) + format_reward = all_reduce_mean(format_reward.mean(), self.plugin) + acc_reward = all_reduce_mean(acc_reward.mean(), self.plugin) + advantages = all_reduce_mean(advantages.mean(), self.plugin) + # Calculate accumulate value. self.accum_loss.add_(loss.data) self.accum_reward.add_(reward.data) self.accum_kl.add_(kl.data) + self.accum_format_reward.add_(format_reward.data) + self.accum_acc_reward.add_(acc_reward.data) + self.accum_advantages.add_(advantages.data) self.accum_count += 1 if need_update: self.optimizer.step() @@ -173,17 +188,28 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]: self.accum_reward.item() / self.accum_count, "KL:", self.accum_kl.item() / self.accum_count, + "Format Reward:", + self.accum_format_reward.item() / self.accum_count, + "Acc Reward:", + self.accum_acc_reward.item() / self.accum_count, + "Advantages:", + self.accum_advantages.item() / self.accum_count, ) self.wandb_run.log( { "train/loss": self.accum_loss.item() / self.accum_count, "train/reward": self.accum_reward.item() / self.accum_count, "train/kl": self.accum_kl.item() / self.accum_count, + "train/format_reward": self.accum_format_reward.item() / self.accum_count, + "train/acc_reward": self.accum_acc_reward.item() / self.accum_count, + "train/advantages": self.accum_advantages.item() / self.accum_count, } ) self.accum_loss.zero_() self.accum_reward.zero_() self.accum_kl.zero_() + self.accum_acc_reward.zero_() + self.accum_format_reward.zero_() self.accum_count = 0 return loss_scalar From 754b16dfbf5e83d7764925dae631b9e65ae67a7a Mon Sep 17 00:00:00 2001 From: Tong Li Date: Mon, 10 Mar 2025 14:18:22 +0800 Subject: [PATCH 17/35] update reward fn --- .../coati/distributed/reward/reward_fn.py | 24 +++++++++++++------ 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/applications/ColossalChat/coati/distributed/reward/reward_fn.py b/applications/ColossalChat/coati/distributed/reward/reward_fn.py index 9e6d1066e1df..da19c7d22458 100644 --- a/applications/ColossalChat/coati/distributed/reward/reward_fn.py +++ b/applications/ColossalChat/coati/distributed/reward/reward_fn.py @@ -5,7 +5,9 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs): tokenizer = kwargs["tokenizer"] - reward = torch.tensor(0.0).to(input_ids.device) + reward = torch.tensor(0.0) + format_reward = torch.tensor(0.0) + acc_reward = torch.tensor(0.0) s, e = response_idx[0], 
response_idx[1] if gt_answer is None: return reward @@ -15,13 +17,21 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs): final_answer, processed_str = extract_solution(decoded_final_answer) format_valid = validate_response_structure(processed_str, kwargs["tags"]) - if not format_valid: - return reward - else: + + # Check format accuracy + if format_valid: + format_reward += 1.0 reward += 1.0 - if gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower(): - reward = reward + 2.0 - return reward + + # Check answer accuracy + if ( + final_answer is not None + and gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower() + ): + acc_reward += 5.0 + reward += 5.0 + + return torch.tensor([reward, format_reward, acc_reward]).to(input_ids.device) def gsm8k_reward_fn(input_ids, **kwargs): From 71a0181fcec3da8768918576b05796be8dddbe0b Mon Sep 17 00:00:00 2001 From: Tong Li Date: Mon, 10 Mar 2025 14:19:10 +0800 Subject: [PATCH 18/35] update reward --- .../ColossalChat/coati/distributed/reward/verifiable_reward.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/applications/ColossalChat/coati/distributed/reward/verifiable_reward.py b/applications/ColossalChat/coati/distributed/reward/verifiable_reward.py index b43ba65c0ab7..ba83f7787586 100644 --- a/applications/ColossalChat/coati/distributed/reward/verifiable_reward.py +++ b/applications/ColossalChat/coati/distributed/reward/verifiable_reward.py @@ -21,7 +21,7 @@ def __call__( # Get batch size bs = input_ids.size(0) # Initialize reward - rewards = torch.zeros(bs, device=input_ids.device) + rewards = torch.zeros((bs, 3), device=input_ids.device) # Loop through reward functions for reward_fn in self.reward_fns: From abca66e69f2b1cc58ab98f0efaaf09164da59fe5 Mon Sep 17 00:00:00 2001 From: Tong Li Date: Tue, 11 Mar 2025 10:17:32 +0800 Subject: [PATCH 19/35] fix reward score --- .../ColossalChat/coati/distributed/reward/reward_fn.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/applications/ColossalChat/coati/distributed/reward/reward_fn.py b/applications/ColossalChat/coati/distributed/reward/reward_fn.py index da19c7d22458..53bc15e25a8c 100644 --- a/applications/ColossalChat/coati/distributed/reward/reward_fn.py +++ b/applications/ColossalChat/coati/distributed/reward/reward_fn.py @@ -4,6 +4,8 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs): + format_score = 1.0 + acc_score = 9.0 tokenizer = kwargs["tokenizer"] reward = torch.tensor(0.0) format_reward = torch.tensor(0.0) @@ -20,16 +22,16 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs): # Check format accuracy if format_valid: - format_reward += 1.0 - reward += 1.0 + format_reward += format_score + reward += format_score # Check answer accuracy if ( final_answer is not None and gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower() ): - acc_reward += 5.0 - reward += 5.0 + acc_reward += acc_score + reward += acc_score return torch.tensor([reward, format_reward, acc_reward]).to(input_ids.device) From 47d64937782b4e2b0d21b84eef07b5e258b82abb Mon Sep 17 00:00:00 2001 From: Tong Li Date: Tue, 11 Mar 2025 13:06:09 +0800 Subject: [PATCH 20/35] add response length --- .../coati/distributed/grpo_consumer.py | 28 +++++++++++++------ 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py 
b/applications/ColossalChat/coati/distributed/grpo_consumer.py index ae9f2c400bce..55dfd09ab244 100644 --- a/applications/ColossalChat/coati/distributed/grpo_consumer.py +++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py @@ -59,6 +59,7 @@ def __init__( self.accum_format_reward = torch.zeros(1, device=self.device) self.accum_acc_reward = torch.zeros(1, device=self.device) self.accum_advantages = torch.zeros(1, device=self.device) + self.accum_response_length = torch.zeros(1, device=self.device) self.accum_count = 0 # Reference model is initialized from policy model. @@ -83,7 +84,7 @@ def __init__( self.policy_loss_fn = PolicyLoss() self.global_step = 0 if use_wandb and self.rank == 0: - self.wandb_run = wandb.init(project="GRPO-Test", sync_tensorboard=True) + self.wandb_run = wandb.init(project="GRPO-V1", sync_tensorboard=True) def setup(self): super().setup() @@ -109,6 +110,7 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]: action_mask = data["action_mask"] num_action = action_mask.shape[1] old_action_log_probs = data["action_log_probs"] + response_length = torch.sum(action_mask, dim=1).to(torch.float32) need_update = (step_idx + 1) % self.num_microbatches == 0 @@ -168,6 +170,7 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]: format_reward = all_reduce_mean(format_reward.mean(), self.plugin) acc_reward = all_reduce_mean(acc_reward.mean(), self.plugin) advantages = all_reduce_mean(advantages.mean(), self.plugin) + response_length = all_reduce_mean(response_length.mean(), self.plugin) # Calculate accumulate value. self.accum_loss.add_(loss.data) self.accum_reward.add_(reward.data) @@ -175,6 +178,7 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]: self.accum_format_reward.add_(format_reward.data) self.accum_acc_reward.add_(acc_reward.data) self.accum_advantages.add_(advantages.data) + self.accum_response_length.add_(response_length.data) self.accum_count += 1 if need_update: self.optimizer.step() @@ -184,32 +188,38 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]: print( "Loss:", self.accum_loss.item() / self.accum_count, - "Reward:", + "\nReward:", self.accum_reward.item() / self.accum_count, - "KL:", - self.accum_kl.item() / self.accum_count, - "Format Reward:", + "\nFormat Reward:", self.accum_format_reward.item() / self.accum_count, - "Acc Reward:", + "\nAcc Reward:", self.accum_acc_reward.item() / self.accum_count, - "Advantages:", + "\nKL:", + self.accum_kl.item() / self.accum_count, + "\nAdvantages:", self.accum_advantages.item() / self.accum_count, + "\nResponse Length:", + self.accum_response_length.item() / self.accum_count, ) self.wandb_run.log( { "train/loss": self.accum_loss.item() / self.accum_count, "train/reward": self.accum_reward.item() / self.accum_count, - "train/kl": self.accum_kl.item() / self.accum_count, "train/format_reward": self.accum_format_reward.item() / self.accum_count, "train/acc_reward": self.accum_acc_reward.item() / self.accum_count, + "train/kl": self.accum_kl.item() / self.accum_count, "train/advantages": self.accum_advantages.item() / self.accum_count, + "train/response_length": self.accum_response_length.item() / self.accum_count, } ) self.accum_loss.zero_() self.accum_reward.zero_() - self.accum_kl.zero_() self.accum_acc_reward.zero_() self.accum_format_reward.zero_() + self.accum_kl.zero_() + self.accum_advantages.zero_() + self.accum_response_length.zero_() + self.accum_count = 0 return loss_scalar From 704866a240fd9dfd10f8ecdd91e812026b0af09e Mon Sep 17 00:00:00 2001 From: Tong Li 
Date: Tue, 11 Mar 2025 16:17:02 +0800 Subject: [PATCH 21/35] detach --- applications/ColossalChat/coati/distributed/loss.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/applications/ColossalChat/coati/distributed/loss.py b/applications/ColossalChat/coati/distributed/loss.py index 222540b9270d..af5776731a25 100644 --- a/applications/ColossalChat/coati/distributed/loss.py +++ b/applications/ColossalChat/coati/distributed/loss.py @@ -26,15 +26,10 @@ def forward( ) -> torch.Tensor: skip = False if action_mask is None: - ratio_ = (log_probs - old_log_probs).exp() + ratio = (log_probs - log_probs.detach()).exp() else: - ratio_ = ((log_probs - old_log_probs) * action_mask).exp() + ratio = ((log_probs - log_probs.detach()) * action_mask).exp() - # note that if dropout is disabled (recommanded), ratio will always be 1. - if ratio_.mean() > self.skip_threshold: - skip = True - - ratio = ratio_.clamp(0.0, 10.0) surr1 = ratio * advantages surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages loss = -torch.min(surr1, surr2) + self.beta * per_token_kl @@ -44,4 +39,4 @@ def forward( else: loss = loss.mean(dim=1) loss = loss.mean() - return loss, skip, ratio_.max() + return loss, skip, ratio.max() From 131eeceb5d93d5f4eef44f1e04a72a5815a3df1a Mon Sep 17 00:00:00 2001 From: Tong Li Date: Thu, 13 Mar 2025 14:52:09 +0800 Subject: [PATCH 22/35] fix tp bug --- colossalai/shardformer/policies/qwen2.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/colossalai/shardformer/policies/qwen2.py b/colossalai/shardformer/policies/qwen2.py index 84d2b2fdbd99..8e150fef1f3d 100644 --- a/colossalai/shardformer/policies/qwen2.py +++ b/colossalai/shardformer/policies/qwen2.py @@ -13,6 +13,7 @@ PaddingEmbedding, RMSNorm, VocabParallelEmbedding1D, + VocabParallelLMHead1D, ) from ..modeling.qwen2 import ( @@ -446,7 +447,16 @@ def module_policy(self): suffix="lm_head", target_module=LinearWithGradAccum, kwargs=dict(fp8_communication=self.shard_config.fp8_communication, use_zbv=use_zbv), - ) + ), + SubModuleReplacementDescription( + suffix="lm_head", + target_module=VocabParallelLMHead1D, + kwargs={ + "gather_output": not self.shard_config.parallel_output, + "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, + "fp8_communication": self.shard_config.fp8_communication, + }, + ), ], method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)}, ) From afddfde2dd2855258d413b1e95c5b9170b841b94 Mon Sep 17 00:00:00 2001 From: Tong Li Date: Thu, 13 Mar 2025 14:55:26 +0800 Subject: [PATCH 23/35] fix consumer --- applications/ColossalChat/coati/distributed/consumer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py index 1e85cccb3c5b..380a2ee1b78a 100644 --- a/applications/ColossalChat/coati/distributed/consumer.py +++ b/applications/ColossalChat/coati/distributed/consumer.py @@ -73,6 +73,8 @@ def setup(self) -> None: ) if self.plugin_config.get("pp_size", 1) > 1 and "num_microbatches" not in self.plugin_config: plugin_config["microbatch_size"] = self.microbatch_size + if self.plugin_config.get("tp_size", 1) > 1: + plugin_config["parallel_output"] = False plugin_config.update(self.plugin_config) self.plugin = HybridParallelPlugin(**plugin_config) self.booster = Booster(plugin=self.plugin) From 4702d5784145b6164fe3dd90b35ba32fcd85b491 Mon Sep 17 00:00:00 2001 From: Tong Li Date: Thu, 13 
Mar 2025 16:49:02 +0800 Subject: [PATCH 24/35] convert to 8 generation --- .../ColossalChat/coati/distributed/inference_backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/applications/ColossalChat/coati/distributed/inference_backend.py b/applications/ColossalChat/coati/distributed/inference_backend.py index bc0ae5c36673..8711d0b8c89e 100644 --- a/applications/ColossalChat/coati/distributed/inference_backend.py +++ b/applications/ColossalChat/coati/distributed/inference_backend.py @@ -154,7 +154,7 @@ class VLLMInferenceBackend(BaseInferenceBackend): ) FORCE_GENERATE_CONFIG = dict( logprobs=0, - n=4, + n=8, ) def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer): From 45ac6c6cb29d7d1cfd9bb31fbbc8a81a2aa8f5d6 Mon Sep 17 00:00:00 2001 From: Tong Li Date: Thu, 13 Mar 2025 16:51:22 +0800 Subject: [PATCH 25/35] print results --- applications/ColossalChat/coati/distributed/producer.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py index 3e4a5277ad2e..a3ae22a7935c 100644 --- a/applications/ColossalChat/coati/distributed/producer.py +++ b/applications/ColossalChat/coati/distributed/producer.py @@ -154,7 +154,11 @@ def __init__( @torch.no_grad() def rollout(self, input_ids, attention_mask, **kwargs): - return self.model.generate(input_ids, attention_mask, **kwargs) + rollouts = self.model.generate(input_ids, attention_mask, **kwargs) + if self.producer_idx == 1: + print("Rollout example:\n", self.tokenizer.decode(rollouts["input_ids"][0][0], skip_special_tokens=True)) + + return rollouts def load_state_dict(self, state_dict): self.model.load_state_dict(state_dict) From 57b49da5e4932327a4576be3901a3cb08c81bae8 Mon Sep 17 00:00:00 2001 From: Tong Li Date: Thu, 13 Mar 2025 16:52:15 +0800 Subject: [PATCH 26/35] setup update --- applications/ColossalChat/rl_example.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py index 3d4b8a575cad..30d56c90b4ad 100644 --- a/applications/ColossalChat/rl_example.py +++ b/applications/ColossalChat/rl_example.py @@ -10,12 +10,12 @@ parser.add_argument("-d", "--dataset", type=str, default="data.jsonl") parser.add_argument("-t", "--num-trainers", type=int, default=2) parser.add_argument("-i", "--num-inferencer", type=int, default=2) - parser.add_argument("-ibs", "--inference-batch-size", type=int, default=64) + parser.add_argument("-ibs", "--inference-batch-size", type=int, default=32) parser.add_argument("-imbs", "--inference-microbatch-size", type=int, default=16) parser.add_argument("-tbs", "--train-batch-size", type=int, default=16) - parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=2) + parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=1) parser.add_argument("-b", "--backend", type=str, default="transformers") - parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple, GPRO"]) + parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple, GRPO"]) args = parser.parse_args() ray.init(address="local", namespace="ray-example") From bc0171d392fcc5d32f93f40791c8733ee1880908 Mon Sep 17 00:00:00 2001 From: YeAnbang Date: Fri, 14 Mar 2025 18:12:35 +0800 Subject: [PATCH 27/35] fix transformers backend --- .gitignore | 1 + .../coati/distributed/inference_backend.py | 
23 ++++++++++++++++++- applications/ColossalChat/rl_example.py | 20 ++++++++-------- 3 files changed, 34 insertions(+), 10 deletions(-) diff --git a/.gitignore b/.gitignore index 8bc74b4c8c2c..16f764c1b1ef 100644 --- a/.gitignore +++ b/.gitignore @@ -163,3 +163,4 @@ coverage.xml # log, test files - ColossalChat applications/ColossalChat/logs applications/ColossalChat/tests/logs +applications/ColossalChat/wandb diff --git a/applications/ColossalChat/coati/distributed/inference_backend.py b/applications/ColossalChat/coati/distributed/inference_backend.py index 8711d0b8c89e..58414b29fd47 100644 --- a/applications/ColossalChat/coati/distributed/inference_backend.py +++ b/applications/ColossalChat/coati/distributed/inference_backend.py @@ -61,12 +61,22 @@ def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any] self.generate_config = generate_config.copy() self.generate_config.update(self.FORCE_GENERATE_CONFIG) self.tokenizer = tokenizer + self.num_generations = 8 @torch.no_grad() def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]: + micro_batch_size = input_ids.size(0) input_ids = input_ids.to(get_current_device()) attention_mask = attention_mask.to(get_current_device()) - out = self.model.generate(input_ids, attention_mask=attention_mask, **kwargs, **self.generate_config) + gt_answer = None + if "gt_answer" in kwargs: + gt_answer = kwargs.pop("gt_answer") + if self.num_generations > 1: + input_ids = input_ids.repeat_interleave(self.num_generations, dim=0) + attention_mask = attention_mask.repeat_interleave(self.num_generations, dim=0) + out = self.model.generate( + input_ids, attention_mask=attention_mask, **kwargs, **self.generate_config, tokenizer=self.tokenizer + ) input_len = input_ids.shape[-1] new_token_ids = out.sequences[:, input_len:] # get log probs @@ -76,10 +86,13 @@ def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwar action_log_probs.append(log_probs_from_logits(logits[:, None, :], new_token_ids[:, i : i + 1])) action_log_probs = torch.cat(action_log_probs, dim=1) # get action mask + response_idx = torch.zeros((new_token_ids.size(0), 2), dtype=torch.int).to(get_current_device()) action_mask = torch.ones_like(new_token_ids, dtype=attention_mask.dtype) if self.tokenizer.eos_token_id is not None: for indices in torch.nonzero(new_token_ids == self.tokenizer.eos_token_id): action_mask[indices[0], indices[1] + 1 :] = 0 + response_idx[:, 0] = input_len + response_idx[:, 1] = input_len + action_mask.sum(dim=1) - 1 if attention_mask.size(0) != action_mask.size(0): assert action_mask.size(0) % attention_mask.size(0) == 0 @@ -91,7 +104,15 @@ def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwar "attention_mask": attention_mask, "action_log_probs": action_log_probs, "action_mask": action_mask, + "response_idx": response_idx, } + + data = {k: v.view(micro_batch_size, self.num_generations, v.size(-1)) for k, v in data.items()} + + if gt_answer is not None: + # repeat gt_answer for each prompt. 
+ data["gt_answer"] = gt_answer.repeat_interleave(self.num_generations, dim=1) + data = {k: v.to(get_current_device()) for k, v in data.items()} return data def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None: diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py index 30d56c90b4ad..1de8b649d5d1 100644 --- a/applications/ColossalChat/rl_example.py +++ b/applications/ColossalChat/rl_example.py @@ -10,9 +10,9 @@ parser.add_argument("-d", "--dataset", type=str, default="data.jsonl") parser.add_argument("-t", "--num-trainers", type=int, default=2) parser.add_argument("-i", "--num-inferencer", type=int, default=2) - parser.add_argument("-ibs", "--inference-batch-size", type=int, default=32) - parser.add_argument("-imbs", "--inference-microbatch-size", type=int, default=16) - parser.add_argument("-tbs", "--train-batch-size", type=int, default=16) + parser.add_argument("-ibs", "--inference-batch-size", type=int, default=64) + parser.add_argument("-imbs", "--inference-microbatch-size", type=int, default=8) + parser.add_argument("-tbs", "--train-batch-size", type=int, default=32) parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=1) parser.add_argument("-b", "--backend", type=str, default="transformers") parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple, GRPO"]) @@ -24,29 +24,31 @@ train_model_config = dict(path=args.model) generate_config = dict( top_k=50, - top_p=0.8, + top_p=0.9, + temperature=1.0, ) if args.backend == "transformers": inference_model_config.update( dict( - attn_implementation="flash_attention_2", + use_flash_attention_2=True, torch_dtype=torch.bfloat16, ) ) train_model_config.update( dict( - attn_implementation="flash_attention_2", + use_flash_attention_2=True, torch_dtype=torch.bfloat16, use_cache=False, ) ) generate_config.update( dict( - max_length=512, + max_length=1024 + 512, do_sample=True, max_new_tokens=None, early_stopping=False, + stop_strings=[""], ) ) elif args.backend == "vllm": @@ -82,12 +84,12 @@ num_producers=args.num_inferencer, num_proc_per_producer=1, num_consumer_procs=args.num_trainers, - num_episodes=1, + num_episodes=10, inference_batch_size=args.inference_batch_size, inference_microbatch_size=args.inference_microbatch_size, train_batch_size=args.train_batch_size, train_microbatch_size=args.train_microbatch_size, - dataset_config={"path": args.dataset, "max_length": 256}, + dataset_config={"path": args.dataset, "max_length": 300}, dataloaders_config={}, inference_model_config=inference_model_config, generate_config=generate_config, From 7795d4c50d727fea5c1a45b665fe30de63166518 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Tue, 18 Mar 2025 17:47:55 +0800 Subject: [PATCH 28/35] [Feature] Support Distributed LogProb for GRPO Training (#6247) * [fix] fix qwen VocabParallelLMHead1D and gather output * fix tp bug * fix consumer * [feat] Support Distributed LogProb for GRPO Training * [fix] fix loss func * [fix] fix log prob plugin * [fix] fix qwen modeling param * [fix] rm comments * [fix] rm hard-code;fix non-dist version * [fix] fix test file param name and benchmark tp gather output=True/False * [fix] rm non-dist version in dist log prob * [fix] fix comments * [fix] fix dis log prob plugin * [fix] fix test case * [fix] fix qwen VocabParallelLMHead1D and gather output * [fix] fix DistLogProb comments * [fix] restore tp size * [fix] fix comments * [fix] fix comment; fix LogSoftmax usage --------- Co-authored-by: Tong Li --- 
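The distributed log-prob this commit adds rests on the identity log p(y | x) = logit_y - logsumexp(logits); the sharded version reproduces it with one all-reduce MAX (numerical stability) and one all-reduce SUM (partition function) over the tensor-parallel group. A single-device reference that the sharded output should match — our helper for illustration, not part of the patch — might look like this:

import torch

def log_prob_reference(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # logits: [B, S, V]; labels: [B, S] -> per-token log-probabilities, shape [B, S].
    logits = logits - logits.max(dim=-1, keepdim=True).values   # the all-reduce MAX step in the TP version
    log_z = logits.exp().sum(dim=-1, keepdim=True).log()        # the all-reduce SUM step in the TP version
    return (logits - log_z).gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)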
.../coati/distributed/consumer.py | 2 - .../coati/distributed/grpo_consumer.py | 8 +- .../ColossalChat/coati/distributed/utils.py | 20 ++- colossalai/shardformer/layer/__init__.py | 4 +- colossalai/shardformer/layer/loss.py | 150 +++++++++++++++++- colossalai/shardformer/modeling/qwen2.py | 1 - colossalai/shardformer/policies/qwen2.py | 8 +- .../test_layer/test_dist_log_prob.py | 52 ++++++ 8 files changed, 233 insertions(+), 12 deletions(-) create mode 100644 tests/test_shardformer/test_layer/test_dist_log_prob.py diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py index 380a2ee1b78a..1e85cccb3c5b 100644 --- a/applications/ColossalChat/coati/distributed/consumer.py +++ b/applications/ColossalChat/coati/distributed/consumer.py @@ -73,8 +73,6 @@ def setup(self) -> None: ) if self.plugin_config.get("pp_size", 1) > 1 and "num_microbatches" not in self.plugin_config: plugin_config["microbatch_size"] = self.microbatch_size - if self.plugin_config.get("tp_size", 1) > 1: - plugin_config["parallel_output"] = False plugin_config.update(self.plugin_config) self.plugin = HybridParallelPlugin(**plugin_config) self.booster = Booster(plugin=self.plugin) diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py index 55dfd09ab244..b1edb89bb0e5 100644 --- a/applications/ColossalChat/coati/distributed/grpo_consumer.py +++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py @@ -120,14 +120,18 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]: input_ids=data["input_ids"], attention_mask=data["attention_mask"], )["logits"] - action_log_probs = calc_action_log_probs(policy_model_logits, data["input_ids"], num_action) + action_log_probs = calc_action_log_probs( + policy_model_logits, data["input_ids"], num_action, self.plugin.shard_config + ) with torch.no_grad(): reference_model_logits = self.reference_model( input_ids=data["input_ids"], attention_mask=data["attention_mask"], )["logits"] - reference_action_log_probs = calc_action_log_probs(reference_model_logits, data["input_ids"], num_action) + reference_action_log_probs = calc_action_log_probs( + reference_model_logits, data["input_ids"], num_action, self.plugin.shard_config + ) per_token_kl = ( torch.exp(reference_action_log_probs - action_log_probs) diff --git a/applications/ColossalChat/coati/distributed/utils.py b/applications/ColossalChat/coati/distributed/utils.py index 98b54815b5b4..919e4434faa6 100644 --- a/applications/ColossalChat/coati/distributed/utils.py +++ b/applications/ColossalChat/coati/distributed/utils.py @@ -2,6 +2,8 @@ import torch +from colossalai.shardformer.layer.loss import dist_log_prob + def unbind_batch(batch: Dict[str, torch.Tensor]) -> List[Dict[str, torch.Tensor]]: batches = [] @@ -66,18 +68,30 @@ def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.T return per_label_logps.squeeze(-1) -def calc_action_log_probs(logits: torch.Tensor, sequences: torch.LongTensor, num_actions: int) -> torch.Tensor: +def calc_action_log_probs( + logits: torch.Tensor, + sequences: torch.LongTensor, + num_actions: int, + shard_config, + vocab_size: int = None, +) -> torch.Tensor: """Calculate action log probs. Args: - output (torch.Tensor): Output tensor of Actor.forward.logits. + logits (torch.Tensor): Output tensor of Actor.forward.logits. sequences (torch.LongTensor): Input sequences. num_actions (int): Number of actions. 
+ shard_config + vocab_size + Returns: torch.Tensor: Action log probs. """ - log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:]) + # labels: torch.Tensor, # [B, S] or [B, S, Vocab_size] + # logits: torch.Tensor, # [B, S, Vocab_size] + log_probs = dist_log_prob(sequences, logits, shard_config, vocab_size, logits.dtype) + log_probs = log_probs.squeeze(-1) return log_probs[:, -num_actions:] diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py index 0bd1b60923e9..a1b80bf56b63 100644 --- a/colossalai/shardformer/layer/__init__.py +++ b/colossalai/shardformer/layer/__init__.py @@ -3,7 +3,7 @@ from .dropout import DropoutForParallelInput, DropoutForReplicatedInput from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D from .linear import Linear1D_Col, Linear1D_Row, LinearWithGradAccum, PaddingLMHead, VocabParallelLMHead1D -from .loss import cross_entropy_1d, dist_cross_entropy +from .loss import cross_entropy_1d, dist_cross_entropy, dist_log_prob, dist_log_prob_1d from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm from .parallel_module import ParallelModule from .qkv_fused_linear import ( @@ -28,6 +28,8 @@ "DropoutForReplicatedInput", "cross_entropy_1d", "dist_cross_entropy", + "dist_log_prob_1d", + "dist_log_prob", "BaseLayerNorm", "LayerNorm", "RMSNorm", diff --git a/colossalai/shardformer/layer/loss.py b/colossalai/shardformer/layer/loss.py index 0e2241af9fc9..51419a38a0ed 100644 --- a/colossalai/shardformer/layer/loss.py +++ b/colossalai/shardformer/layer/loss.py @@ -3,13 +3,21 @@ from torch.autograd import Function from torch.distributed import ProcessGroup from torch.nn import CrossEntropyLoss +from torch.nn.functional import log_softmax from colossalai.shardformer.layer._operation import reduce_forward from colossalai.shardformer.shard import ShardConfig from .utils import is_share_sp_tp -__all__ = ["DistCrossEntropy", "cross_entropy_1d", "dist_cross_entropy"] +__all__ = [ + "DistCrossEntropy", + "cross_entropy_1d", + "dist_cross_entropy", + "DistLogProb", + "dist_log_prob_1d", + "dist_log_prob", +] _IGNORE_IDX = -100 @@ -137,6 +145,98 @@ def backward(ctx, grad_output): return grad_logits, None, None, None, None, None, None +class DistLogProb(Function): + r""" + Overwrite the forward and backward function to calculate the log prob before gather + + Args: + Function (:class:`torch.autograd.Function`): default + """ + + @staticmethod + def forward( + ctx, + vocab_logits: torch.Tensor, + target: torch.Tensor, + process_group: ProcessGroup, + vocab_size: int, + dtype=torch.float32, + ): + + ################## + # Step1:Find the global maximum value of logits + ################## + logits_max = torch.max(vocab_logits, dim=-1)[0] + handle = dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=process_group, async_op=True) + + ################## + # Step2:Find the local mask. local mask will be use to select log_probs value in Step 4. 
+ # For accleration, we overlap Step 2 and Step 3 + ################## + rank = dist.get_rank(group=process_group) + world_size = dist.get_world_size(group=process_group) + if vocab_size is None: + partition_vocab_size = vocab_logits.size()[-1] + global_vocab_size = partition_vocab_size * world_size + else: + global_vocab_size = vocab_size + partition_vocab_size = global_vocab_size // world_size + # down and up threshold for local logits + delta = (global_vocab_size + world_size - 1) // world_size + down_threshold = rank * delta + up_threshold = down_threshold + delta + if up_threshold > global_vocab_size: + up_threshold = global_vocab_size + # mask + mask = (target < down_threshold) | (target >= up_threshold) + masked_target = target.clone() - down_threshold + masked_target[mask] = 0 + masked_target_1d = masked_target.view(-1).contiguous() + handle.wait() + + ################## + # Step3:Calculate global summation exp logits + ################## + vocab_logits = vocab_logits - logits_max.unsqueeze(dim=-1) + exp_logits = torch.exp(vocab_logits) + sum_exp_logits = torch.sum(exp_logits, dim=-1, dtype=torch.float32) # local summation exp logits + dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM, group=process_group) + + ################## + # Step4:Calculate local prob. We first cal log_softmax, then select log probs via local mask + ################## + log_probs = vocab_logits - torch.log(sum_exp_logits.unsqueeze(dim=-1)) # cal log_softmax + log_probs = log_probs.gather(dim=-1, index=masked_target.unsqueeze(-1)) + log_probs[mask.unsqueeze(-1)] = 0 # set masked val to zero + dist.all_reduce(log_probs, op=dist.ReduceOp.SUM, group=process_group) + + ctx.save_for_backward(exp_logits, mask, masked_target_1d, sum_exp_logits) + ctx.dtype = dtype + return log_probs + + @staticmethod + def backward(ctx, grad_output): + exp_logits, mask, masked_target_1d, sum_exp_logits = ctx.saved_tensors + ################## + # Step1:Find the global sofmax value + ################## + softmax_logits = exp_logits / sum_exp_logits.unsqueeze(dim=-1) + + ################## + # Step2:Update softmax value based on local target index + ################## + partion_vocab_size = softmax_logits.shape[-1] + softmax_logits_2d = softmax_logits.view(-1, partion_vocab_size) + update = 1.0 - mask.view(-1).float().to(ctx.dtype) + softmax_logits_2d[torch.arange(0, softmax_logits_2d.shape[0]), masked_target_1d] -= update + + ################## + # Step3:Calculate grad_output, which is the gradient of the loss function with respect to the output of logsoftmax + ################## + grad_logits = -softmax_logits.mul_(grad_output) + return grad_logits, None, None, None, None, None, None + + def cross_entropy_1d( vocab_logits: torch.Tensor, labels: torch.Tensor, @@ -149,6 +249,16 @@ def cross_entropy_1d( return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size, dtype, mode) +def dist_log_prob_1d( + vocab_logits: torch.Tensor, + labels: torch.Tensor, + process_group: ProcessGroup = None, + vocab_size: int = None, + dtype: torch.dtype = None, +) -> torch.Tensor: + return DistLogProb.apply(vocab_logits, labels, process_group, vocab_size, dtype) + + def dist_cross_entropy( labels: torch.Tensor, # [B, S] or [B, S, Vocab_size] logits: torch.Tensor, # [B, S, Vocab_size] @@ -243,3 +353,41 @@ def dist_cross_entropy( loss, num_nonzero = loss[0], loss[1].detach() loss = (loss / num_nonzero).squeeze() return loss + + +def dist_log_prob( + labels: torch.Tensor, # [B, S] or [B, S, Vocab_size] + logits: 
torch.Tensor, # [B, S, Vocab_size] + shard_config: ShardConfig, + vocab_size: int, + dtype: torch.dtype, + seq_dim: int = 1, +) -> torch.Tensor: + """ + Helper to compute log prob for most shardformer models supporting PP, TP. + """ + # Split labels if not gather output + parallel_output = shard_config.parallel_output + is_tp = shard_config.enable_tensor_parallelism + + # TODO:support sp + labels = labels[..., 1:] + logits = logits[..., :-1, :] + labels = labels.contiguous() + logits = logits.contiguous() + assert labels.shape == logits.shape[:-1], f"label shape {labels.shape} does not match logit shape {logits.shape}" + + # Flatten the tokens + if is_tp and parallel_output: + log_prob = dist_log_prob_1d( + logits, + labels, + process_group=shard_config.tensor_parallel_process_group, + vocab_size=vocab_size, + dtype=dtype, + ) + else: + log_prob = log_softmax(logits) + log_prob = log_prob.gather(dim=-1, index=labels.unsqueeze(-1)) + + return log_prob diff --git a/colossalai/shardformer/modeling/qwen2.py b/colossalai/shardformer/modeling/qwen2.py index 569fc4a459c5..71e3557fe214 100644 --- a/colossalai/shardformer/modeling/qwen2.py +++ b/colossalai/shardformer/modeling/qwen2.py @@ -832,7 +832,6 @@ def forward( loss = None if labels is not None: loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, logits.dtype) - if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output diff --git a/colossalai/shardformer/policies/qwen2.py b/colossalai/shardformer/policies/qwen2.py index 8e150fef1f3d..0adcdfdbd553 100644 --- a/colossalai/shardformer/policies/qwen2.py +++ b/colossalai/shardformer/policies/qwen2.py @@ -430,8 +430,12 @@ def module_policy(self): sub_module_replacement=[ SubModuleReplacementDescription( suffix="lm_head", - target_module=Linear1D_Col, - kwargs=dict(fp8_communication=self.shard_config.fp8_communication, use_zbv=use_zbv), + target_module=VocabParallelLMHead1D, + kwargs=dict( + gather_output=not self.shard_config.parallel_output, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), ) ], method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)}, diff --git a/tests/test_shardformer/test_layer/test_dist_log_prob.py b/tests/test_shardformer/test_layer/test_dist_log_prob.py new file mode 100644 index 000000000000..05a6a5d4766f --- /dev/null +++ b/tests/test_shardformer/test_layer/test_dist_log_prob.py @@ -0,0 +1,52 @@ +import pytest +import torch +from coati.distributed.utils import log_probs_from_logits + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer.layer import dist_log_prob_1d +from colossalai.testing import rerun_if_address_is_in_use, spawn + +CONFIG = dict( + parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode="1d")), +) + + +def check_dist_log_prob(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost", backend="nccl") + + # prepare data + pred = torch.randn(2, 4, 8, requires_grad=True).cuda() + labels = torch.randint(8, (2, 4)).cuda() + + logprob = log_probs_from_logits(pred, labels) + + pred.retain_grad() + logprob.mean().backward() + + dist_pred = pred.clone().chunk(world_size, -1)[rank].detach() + dist_pred.requires_grad = True + dist_logprob = dist_log_prob_1d(dist_pred, labels) + + dist_pred.retain_grad() + dist_logprob.squeeze(-1).mean().backward() + + assert torch.allclose( + logprob, 
dist_logprob.squeeze(-1), atol=1e-5 + ), f"dist cross entropy logprob is not equal to orgin logprob\n{logprob}\n{dist_logprob.squeeze(-1)}" + + pred_grad_partial = pred.grad.clone().chunk(world_size, -1)[rank].detach() + assert torch.allclose( + pred_grad_partial, dist_pred.grad + ), f"dist grad is not equal to orgin grad\n{pred.grad}\n{dist_pred.grad}" + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_dist_log_prob(): + spawn(check_dist_log_prob, 2) + + +if __name__ == "__main__": + test_dist_log_prob() From 7ee4452f8c5851a2dea715a920f6b6972c2ff9ee Mon Sep 17 00:00:00 2001 From: YeAnbang Date: Wed, 19 Mar 2025 17:07:20 +0800 Subject: [PATCH 29/35] fix vllm --- .../coati/distributed/grpo_consumer.py | 148 +++++++++++++++++- .../coati/distributed/inference_backend.py | 11 +- .../ColossalChat/coati/distributed/launch.py | 16 +- .../ColossalChat/coati/distributed/loss.py | 3 + applications/ColossalChat/rl_example.py | 18 +-- 5 files changed, 172 insertions(+), 24 deletions(-) diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py index b1edb89bb0e5..785fa820ec8b 100644 --- a/applications/ColossalChat/coati/distributed/grpo_consumer.py +++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py @@ -1,8 +1,11 @@ +import json +import os from contextlib import nullcontext from typing import Optional import ray import torch +import torch.distributed as dist import wandb from coati.distributed.consumer import BaseConsumer from coati.distributed.loss import PolicyLoss @@ -33,6 +36,8 @@ def __init__( microbatch_size=1, num_generations=4, use_wandb=True, + generator_config=None, + filter_range=None, ): super().__init__( num_producers, @@ -69,6 +74,9 @@ def __init__( self.tokenizer = AutoTokenizer.from_pretrained(path) self.pad_token_id = self.tokenizer.pad_token_id self.num_generations = num_generations + self.filter_range = filter_range + if self.filter_range is not None: + assert len(self.filter_range) == 2, "Filter range should have 2 values." # Initialize verifiable reward. 
response_format_tags = { @@ -84,7 +92,11 @@ def __init__( self.policy_loss_fn = PolicyLoss() self.global_step = 0 if use_wandb and self.rank == 0: - self.wandb_run = wandb.init(project="GRPO-V1", sync_tensorboard=True) + if "repetition_penalty" in generator_config: + name = f"{generator_config['backend']}_bs_{self.batch_size*self.world_size}_temp_{generator_config['temperature']:.01f}_rep_penalty_{generator_config['repetition_penalty']:.01f}" + else: + name = f"{generator_config['backend']}_bs_{self.batch_size*self.world_size}_temp_{generator_config['temperature']:.01f}" + self.wandb_run = wandb.init(project="GRPO-V1", sync_tensorboard=True, dir="./wandb", name=name) def setup(self): super().setup() @@ -121,7 +133,7 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]: attention_mask=data["attention_mask"], )["logits"] action_log_probs = calc_action_log_probs( - policy_model_logits, data["input_ids"], num_action, self.plugin.shard_config + policy_model_logits / generator_config["temperature"], data["input_ids"], num_action, self.plugin.shard_config ) with torch.no_grad(): @@ -130,7 +142,7 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]: attention_mask=data["attention_mask"], )["logits"] reference_action_log_probs = calc_action_log_probs( - reference_model_logits, data["input_ids"], num_action, self.plugin.shard_config + reference_model_logits / generator_config["temperature"], data["input_ids"], num_action, self.plugin.shard_config ) per_token_kl = ( @@ -149,7 +161,14 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]: acc_reward = torch.tensor([value[2] for value in reward_group]).to(data["input_ids"].device) # [batch_size, num_generations] + # filter out the reward that is too high (all sample gets full score) or too low (all sample gets 0 score), + loss_mask = ( + None + if self.filter_range is None + else torch.logical_and(reward > self.filter_range[0], reward < self.filter_range[1]) + ) group_reward = reward.view(-1, self.num_generations) + reward_mean = group_reward.mean(dim=1) # [batch_size x num_generations] reward_mean = group_reward.mean(dim=1).repeat_interleave(self.num_generations, dim=0) @@ -164,6 +183,7 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]: advantages.unsqueeze(dim=-1).repeat_interleave(action_log_probs.size(-1), dim=-1), per_token_kl, action_mask, + loss_mask=loss_mask, ) if not skip_update: @@ -232,3 +252,125 @@ def state_dict(self): model = self.policy_model.unwrap() state_dict = model.state_dict() return state_dict + + +@ray.remote +class GRPOEvalConsumer(BaseConsumer): + def __init__( + self, + num_producers, + num_episodes, + rank, + world_size, + master_addr, + master_port, + num_update_per_episode, + num_recv_per_update, + batch_size, + model_config, + plugin_config, + microbatch_size=1, + num_generations=4, + use_wandb=True, + log_dir="./results", + ): + super().__init__( + num_producers, + num_episodes, + rank, + world_size, + master_addr, + master_port, + num_update_per_episode, + num_recv_per_update, + batch_size, + model_config, + plugin_config, + microbatch_size, + ) + path = model_config.pop("path") + self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config) + self.policy_model.train() + self.accum_reward = torch.zeros(1, device=self.device) + self.accum_format_reward = torch.zeros(1, device=self.device) + self.accum_acc_reward = torch.zeros(1, device=self.device) + self.accum_response_length = torch.zeros(1, device=self.device) + self.accum_count = torch.zeros(1, 
device=self.device) + + self.tokenizer = AutoTokenizer.from_pretrained(path) + self.pad_token_id = self.tokenizer.pad_token_id + self.num_generations = num_generations + + # Initialize verifiable reward. + response_format_tags = { + "think_start": {"text": "", "num_occur": 1}, + "think_end": {"text": "", "num_occur": 1}, + "answer_start": {"text": "", "num_occur": 1}, + "answer_end": {"text": "", "num_occur": 1}, + } + self.reward_model = VerifiableReward( + reward_fns=[math_reward_fn], tokenizer=self.tokenizer, tags=response_format_tags + ) + + self.log_dir = log_dir + if not os.path.exists(self.log_dir): + os.makedirs(self.log_dir) + else: + os.system(f"rm -rf {self.log_dir}/*") + + def setup(self): + super().setup() + self.policy_model, _, *_ = self.booster.boost(self.policy_model) + + def step(self, step_idx: int, **kwargs) -> Optional[float]: + rank = dist.get_rank() + data = {k: v.view(-1, v.size(-1)).cpu() for k, v in kwargs.items()} + kwargs["input_ids"].size(0) + reward_group = self.reward_model( + data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"] + ) + reward = [value[0].item() for value in reward_group] + format_reward = [value[1].item() for value in reward_group] + acc_reward = [value[2].item() for value in reward_group] + response_length = [(data["response_idx"][i][1] - data["response_idx"][i][0]).item() for i in range(len(reward))] + + response = self.tokenizer.batch_decode(data["input_ids"], skip_special_tokens=True) + with open(f"{self.log_dir}/eval_results_rank_{rank}.jsonl", "a", encoding="utf8") as f: + for i in range(len(response)): + f.write( + json.dumps( + { + "response": response[i], + "reward": reward[i], + "format_reward": format_reward[i], + "acc_reward": acc_reward[i], + "response_length": response_length[i], + }, + ensure_ascii=False, + ) + + "\n" + ) + + self.accum_reward += sum(reward) + self.accum_format_reward += sum(format_reward) + self.accum_acc_reward += sum(acc_reward) + self.accum_response_length += sum(response_length) + self.accum_count += len(reward) + + # print results + total_count = all_reduce_mean(self.accum_count, self.plugin) + mean_reward = all_reduce_mean(self.accum_reward, self.plugin) / total_count + mean_format_reward = all_reduce_mean(self.accum_format_reward, self.plugin) / total_count + mean_acc_reward = all_reduce_mean(self.accum_acc_reward, self.plugin) / total_count + mean_response_length = all_reduce_mean(self.accum_response_length, self.plugin) / total_count + if rank == 0: + print( + f"Step {step_idx}: Mean Reward: {mean_reward}, Mean Format Reward: {mean_format_reward}, Mean Acc Reward: {mean_acc_reward}, Mean Response Length: {mean_response_length}" + ) + return None + + def state_dict(self): + self.policy_model._force_wait_all_gather() + model = self.policy_model.unwrap() + state_dict = model.state_dict() + return state_dict diff --git a/applications/ColossalChat/coati/distributed/inference_backend.py b/applications/ColossalChat/coati/distributed/inference_backend.py index 58414b29fd47..b4cfffa4dc81 100644 --- a/applications/ColossalChat/coati/distributed/inference_backend.py +++ b/applications/ColossalChat/coati/distributed/inference_backend.py @@ -183,7 +183,7 @@ def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any] raise ImportError("vllm is not installed") model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG) path = model_config.pop("path") - self.llm = LLM(path, **model_config) + self.llm = LLM(model=path, **model_config) generate_config = 
generate_config.copy() generate_config.update(self.FORCE_GENERATE_CONFIG) self.generate_config = SamplingParams(**generate_config) @@ -194,8 +194,15 @@ def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any] def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]: micro_batch_size = input_ids.size(0) response_start_idx = input_ids.size(1) + micro_batch_input_ids = input_ids.tolist() + micro_batch_input_ids_no_padding = [] + for i in range(micro_batch_size): + for j in range(input_ids.size(1)): + if micro_batch_input_ids[i][j] != self.tokenizer.pad_token_id: + micro_batch_input_ids_no_padding.append(micro_batch_input_ids[i][j:]) + break outputs = self.llm.generate( - prompt_token_ids=input_ids.tolist(), sampling_params=self.generate_config, use_tqdm=False + prompt_token_ids=micro_batch_input_ids_no_padding, sampling_params=self.generate_config, use_tqdm=False ) out_tokens = [] out_len = [] diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py index 8581ff5865f8..512e7261fd6c 100644 --- a/applications/ColossalChat/coati/distributed/launch.py +++ b/applications/ColossalChat/coati/distributed/launch.py @@ -1,15 +1,13 @@ +import copy from typing import Any, Dict, Optional import ray from .consumer import SimpleConsumer -from .grpo_consumer import GRPOConsumer +from .grpo_consumer import GRPOConsumer, GRPOEvalConsumer from .producer import SimpleProducer -ALGO_MAP = { - "Simple": SimpleConsumer, - "GRPO": GRPOConsumer, -} +ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "EvalGRPO": GRPOEvalConsumer} def get_jsonl_size_fast(path: str) -> int: @@ -80,6 +78,12 @@ def launch_distributed( backend=inference_backend, ) procs.append(producer) + generate_config_consumer = copy.deepcopy(generate_config) + generate_config_consumer.update( + dict( + backend=inference_backend, + ) + ) for i in range(num_consumer_procs): consumer = core_consumer.options(num_gpus=1).remote( num_producers=num_producers, @@ -94,6 +98,8 @@ def launch_distributed( model_config=train_model_config, plugin_config=plugin_config, microbatch_size=train_microbatch_size, + generate_config=generate_config_consumer, + filter_range=[0.05, 9.0], ) procs.append(consumer) ray.get([p.setup.remote() for p in procs]) diff --git a/applications/ColossalChat/coati/distributed/loss.py b/applications/ColossalChat/coati/distributed/loss.py index af5776731a25..90ad09736281 100644 --- a/applications/ColossalChat/coati/distributed/loss.py +++ b/applications/ColossalChat/coati/distributed/loss.py @@ -23,6 +23,7 @@ def forward( advantages: torch.Tensor, per_token_kl: torch.Tensor, action_mask: Optional[torch.Tensor] = None, + loss_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: skip = False if action_mask is None: @@ -38,5 +39,7 @@ def forward( loss = masked_mean(loss, action_mask) else: loss = loss.mean(dim=1) + if loss_mask is not None: + loss = loss * loss_mask loss = loss.mean() return loss, skip, ratio.max() diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py index 1de8b649d5d1..fc32ece21cec 100644 --- a/applications/ColossalChat/rl_example.py +++ b/applications/ColossalChat/rl_example.py @@ -15,18 +15,14 @@ parser.add_argument("-tbs", "--train-batch-size", type=int, default=32) parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=1) parser.add_argument("-b", "--backend", type=str, default="transformers") - 
parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple, GRPO"]) + parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple", "GRPO", "EvalGRPO"]) args = parser.parse_args() ray.init(address="local", namespace="ray-example") inference_model_config = dict(path=args.model) train_model_config = dict(path=args.model) - generate_config = dict( - top_k=50, - top_p=0.9, - temperature=1.0, - ) + generate_config = dict(top_k=50, top_p=0.9, temperature=0.7) if args.backend == "transformers": inference_model_config.update( @@ -52,19 +48,13 @@ ) ) elif args.backend == "vllm": - inference_model_config.update( - dict( - gpu_memory_utilization=0.7, - ) - ) + inference_model_config.update(dict(gpu_memory_utilization=0.7, enforce_eager=True, enable_chunked_prefill=True)) generate_config.update( dict( max_tokens=2048, ignore_eos=True, include_stop_str_in_output=True, stop=[""], - temperature=0.7, - top_p=0.95, ) ) else: @@ -97,6 +87,6 @@ plugin_config={}, inference_backend=args.backend, master_addr="localhost", - master_port=29504, + master_port=29503, core_algo=args.algo, ) From 0472f44163c232d78506e6b2ebf7932780332951 Mon Sep 17 00:00:00 2001 From: YeAnbang Date: Fri, 21 Mar 2025 10:24:24 +0800 Subject: [PATCH 30/35] fix logprob, add filtering, temperature annealing, lr descent --- .../coati/distributed/consumer.py | 3 ++ .../coati/distributed/grpo_consumer.py | 50 ++++++++++++------- .../coati/distributed/inference_backend.py | 30 ++++++++--- .../ColossalChat/coati/distributed/launch.py | 5 +- .../coati/distributed/producer.py | 9 +++- applications/ColossalChat/rl_example.py | 2 +- colossalai/shardformer/layer/loss.py | 2 +- 7 files changed, 74 insertions(+), 27 deletions(-) diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py index 1e85cccb3c5b..de289738347c 100644 --- a/applications/ColossalChat/coati/distributed/consumer.py +++ b/applications/ColossalChat/coati/distributed/consumer.py @@ -57,6 +57,7 @@ def __init__( assert self.plugin_config.get("pp_size", 1) == 1, "pp_size > 1 is not supported now" self.device = get_current_device() + self.lr_scheduler = None def setup(self) -> None: for i in range(self.num_producers): @@ -121,6 +122,8 @@ def loop(self) -> None: pbar.set_postfix({"loss": loss}) i += 1 assert len(self.buffer) == 0 + if self.lr_scheduler is not None: + self.lr_scheduler.step() if (step + 1) % self.save_interval == 0: if self.rank == 0: print(f"Start saving policy model at step {step + 1}.") diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py index 785fa820ec8b..5a488f5aaf91 100644 --- a/applications/ColossalChat/coati/distributed/grpo_consumer.py +++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py @@ -15,6 +15,7 @@ from coati.trainer.utils import all_reduce_mean from transformers import AutoModelForCausalLM, AutoTokenizer +from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR from colossalai.nn.optimizer import HybridAdam @@ -34,10 +35,10 @@ def __init__( model_config, plugin_config, microbatch_size=1, - num_generations=4, + num_generations=8, use_wandb=True, - generator_config=None, - filter_range=None, + generate_config=None, + training_config={}, ): super().__init__( num_producers, @@ -57,7 +58,7 @@ def __init__( self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config) self.policy_model.train() 
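The filter_range=[0.05, 9.0] wired into the launcher above, together with the "add filtering" change in this commit, amounts to dropping prompt groups whose mean reward falls outside that range (all-correct or all-wrong groups carry no group-relative signal) before the GRPO advantage is formed. A standalone sketch of that computation, with shapes as in the patch and a helper name of our own:

import torch

def grpo_advantages(reward: torch.Tensor, num_generations: int, filter_range=(0.05, 9.0)):
    # reward: [batch_size * num_generations], laid out so that consecutive
    # `num_generations` entries belong to the same prompt.
    group = reward.view(-1, num_generations)
    mean = group.mean(dim=1)
    std = group.std(dim=1)
    # Mask out whole groups whose mean reward is outside the configured range.
    loss_mask = torch.logical_and(mean > filter_range[0], mean < filter_range[1])
    loss_mask = loss_mask.repeat_interleave(num_generations, dim=0)
    mean = mean.repeat_interleave(num_generations, dim=0)
    std = std.repeat_interleave(num_generations, dim=0)
    advantages = (reward - mean) / (std + 1e-4)
    return advantages, loss_mask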
self.policy_model.gradient_checkpointing_enable() - self.optimizer = HybridAdam(self.policy_model.parameters(), lr=1e-6) + self.optimizer = HybridAdam(self.policy_model.parameters(), lr=training_config.get("lr", 1e-6)) self.accum_loss = torch.zeros(1, device=self.device) self.accum_reward = torch.zeros(1, device=self.device) self.accum_kl = torch.zeros(1, device=self.device) @@ -66,6 +67,7 @@ def __init__( self.accum_advantages = torch.zeros(1, device=self.device) self.accum_response_length = torch.zeros(1, device=self.device) self.accum_count = 0 + self.generate_config = generate_config # Reference model is initialized from policy model. self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config) @@ -74,7 +76,7 @@ def __init__( self.tokenizer = AutoTokenizer.from_pretrained(path) self.pad_token_id = self.tokenizer.pad_token_id self.num_generations = num_generations - self.filter_range = filter_range + self.filter_range = training_config.get("filter_range", None) if self.filter_range is not None: assert len(self.filter_range) == 2, "Filter range should have 2 values." @@ -92,15 +94,21 @@ def __init__( self.policy_loss_fn = PolicyLoss() self.global_step = 0 if use_wandb and self.rank == 0: - if "repetition_penalty" in generator_config: - name = f"{generator_config['backend']}_bs_{self.batch_size*self.world_size}_temp_{generator_config['temperature']:.01f}_rep_penalty_{generator_config['repetition_penalty']:.01f}" - else: - name = f"{generator_config['backend']}_bs_{self.batch_size*self.world_size}_temp_{generator_config['temperature']:.01f}" + name = f"{generate_config['backend']}_bs_{self.batch_size*self.world_size}_temp_{generate_config['temperature']:.01f}_top_p_{generate_config['top_p']:.02f}" self.wandb_run = wandb.init(project="GRPO-V1", sync_tensorboard=True, dir="./wandb", name=name) + self.lr_scheduler = CosineAnnealingWarmupLR( + optimizer=self.optimizer, + total_steps=min(self.num_episodes, 4) * self.num_update_per_episode, + warmup_steps=0, + eta_min=0.1 * training_config.get("lr", 1e-6), + ) + def setup(self): super().setup() - self.policy_model, self.optimizer, *_ = self.booster.boost(self.policy_model, self.optimizer) + self.policy_model, self.optimizer, _, _, self.lr_scheduler = self.booster.boost( + self.policy_model, self.optimizer, lr_scheduler=self.lr_scheduler + ) self.reference_model, *_ = self.booster.boost(self.reference_model) def step(self, step_idx: int, **kwargs) -> Optional[float]: @@ -133,7 +141,10 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]: attention_mask=data["attention_mask"], )["logits"] action_log_probs = calc_action_log_probs( - policy_model_logits / generator_config["temperature"], data["input_ids"], num_action, self.plugin.shard_config + policy_model_logits / self.generate_config["temperature"], + data["input_ids"], + num_action, + self.plugin.shard_config, ) with torch.no_grad(): @@ -142,7 +153,10 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]: attention_mask=data["attention_mask"], )["logits"] reference_action_log_probs = calc_action_log_probs( - reference_model_logits / generator_config["temperature"], data["input_ids"], num_action, self.plugin.shard_config + reference_model_logits / self.generate_config["temperature"], + data["input_ids"], + num_action, + self.plugin.shard_config, ) per_token_kl = ( @@ -161,22 +175,24 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]: acc_reward = torch.tensor([value[2] for value in reward_group]).to(data["input_ids"].device) # [batch_size, 
num_generations] + + group_reward = reward.view(-1, self.num_generations) + reward_mean = group_reward.mean(dim=1) # filter out the reward that is too high (all sample gets full score) or too low (all sample gets 0 score), loss_mask = ( None if self.filter_range is None - else torch.logical_and(reward > self.filter_range[0], reward < self.filter_range[1]) + else torch.logical_and( + reward_mean > self.filter_range[0], reward_mean < self.filter_range[1] + ).repeat_interleave(self.num_generations, dim=0) ) - group_reward = reward.view(-1, self.num_generations) - reward_mean = group_reward.mean(dim=1) # [batch_size x num_generations] - reward_mean = group_reward.mean(dim=1).repeat_interleave(self.num_generations, dim=0) + reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0) reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0) # [batch_size x num_generations] advantages = (reward - reward_mean) / (reward_std + 1e-4) - # Calculate Loss loss, skip_update, _ = self.policy_loss_fn( action_log_probs, old_action_log_probs, diff --git a/applications/ColossalChat/coati/distributed/inference_backend.py b/applications/ColossalChat/coati/distributed/inference_backend.py index b4cfffa4dc81..5039d89f5df5 100644 --- a/applications/ColossalChat/coati/distributed/inference_backend.py +++ b/applications/ColossalChat/coati/distributed/inference_backend.py @@ -53,7 +53,13 @@ class TransformersInferenceBackend(BaseInferenceBackend): ) FORCE_GENERATE_CONFIG = dict(output_logits=True, return_dict_in_generate=True) - def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer): + def __init__( + self, + model_config: Dict[str, Any], + generate_config: Dict[str, Any], + tokenizer: PreTrainedTokenizer, + num_generations: int = 8, + ): model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG) model_config.update(self.FORCE_MODEL_CONFIG) path = model_config.pop("path") @@ -61,7 +67,7 @@ def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any] self.generate_config = generate_config.copy() self.generate_config.update(self.FORCE_GENERATE_CONFIG) self.tokenizer = tokenizer - self.num_generations = 8 + self.num_generations = num_generations @torch.no_grad() def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]: @@ -120,7 +126,13 @@ def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None: class SGLangInferenceBackend(BaseInferenceBackend): - def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer): + def __init__( + self, + model_config: Dict[str, Any], + generate_config: Dict[str, Any], + tokenizer: PreTrainedTokenizer, + num_generations: int = 8, + ): if sgl is None: raise ImportError("sglang is not installed") path = model_config.pop("path") @@ -175,10 +187,15 @@ class VLLMInferenceBackend(BaseInferenceBackend): ) FORCE_GENERATE_CONFIG = dict( logprobs=0, - n=8, ) - def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer): + def __init__( + self, + model_config: Dict[str, Any], + generate_config: Dict[str, Any], + tokenizer: PreTrainedTokenizer, + num_generations: int = 8, + ): if LLM is None: raise ImportError("vllm is not installed") model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG) @@ -186,9 +203,10 @@ def __init__(self, model_config: Dict[str, Any], generate_config: 
Dict[str, Any] self.llm = LLM(model=path, **model_config) generate_config = generate_config.copy() generate_config.update(self.FORCE_GENERATE_CONFIG) + generate_config.update({"n": num_generations}) self.generate_config = SamplingParams(**generate_config) self.tokenizer = tokenizer - self.num_generations = self.FORCE_GENERATE_CONFIG["n"] + self.num_generations = num_generations @torch.no_grad() def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]: diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py index 512e7261fd6c..c50db1378e16 100644 --- a/applications/ColossalChat/coati/distributed/launch.py +++ b/applications/ColossalChat/coati/distributed/launch.py @@ -42,6 +42,7 @@ def launch_distributed( plugin_config: Dict[str, Any], tokenizer_config: Optional[Dict[str, Any]] = None, inference_backend: str = "transformers", + num_generations: int = 8, master_addr: str = "localhost", master_port: int = 29500, core_algo: str = "GRPO", @@ -76,6 +77,7 @@ def launch_distributed( tokenizer_config=tokenizer_config, microbatch_size=inference_microbatch_size, backend=inference_backend, + num_generations=num_generations, ) procs.append(producer) generate_config_consumer = copy.deepcopy(generate_config) @@ -99,7 +101,8 @@ def launch_distributed( plugin_config=plugin_config, microbatch_size=train_microbatch_size, generate_config=generate_config_consumer, - filter_range=[0.05, 9.0], + training_config={"filter_range": [0.05, 9.0], "lr": 1e-6}, + num_generations=num_generations, ) procs.append(consumer) ray.get([p.setup.remote() for p in procs]) diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py index a3ae22a7935c..6cc9b33305a6 100644 --- a/applications/ColossalChat/coati/distributed/producer.py +++ b/applications/ColossalChat/coati/distributed/producer.py @@ -117,6 +117,12 @@ def loop(self) -> None: None, self.num_producers, device=self.device, group_name="sync_model" ) self.load_state_dict(state_dict) + # linear annealing for 1 episode, temperature from initial to 0.7 + if episode <= 0: + ratio = 1 - (len(self.dataloader) - i) / len(self.dataloader) + self.model.generate_config.temperature = ( + ratio * self.generate_config["temperature"] + (1 - ratio) * 0.7 + ) @ray.remote @@ -135,6 +141,7 @@ def __init__( tokenizer_config=None, microbatch_size=1, backend="transformers", + num_generations: int = 8, ): super().__init__( producer_idx, @@ -150,7 +157,7 @@ def __init__( microbatch_size, backend, ) - self.model = self.backend_cls(model_config, generate_config, self.tokenizer) + self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations) @torch.no_grad() def rollout(self, input_ids, attention_mask, **kwargs): diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py index fc32ece21cec..a67a10bc5b35 100644 --- a/applications/ColossalChat/rl_example.py +++ b/applications/ColossalChat/rl_example.py @@ -22,7 +22,7 @@ inference_model_config = dict(path=args.model) train_model_config = dict(path=args.model) - generate_config = dict(top_k=50, top_p=0.9, temperature=0.7) + generate_config = dict(top_k=50, top_p=0.75, temperature=0.9) if args.backend == "transformers": inference_model_config.update( diff --git a/colossalai/shardformer/layer/loss.py b/colossalai/shardformer/layer/loss.py index 51419a38a0ed..a9bb76fc7d6b 100644 --- 
a/colossalai/shardformer/layer/loss.py +++ b/colossalai/shardformer/layer/loss.py @@ -387,7 +387,7 @@ def dist_log_prob( dtype=dtype, ) else: - log_prob = log_softmax(logits) + log_prob = log_softmax(logits, dim=-1) log_prob = log_prob.gather(dim=-1, index=labels.unsqueeze(-1)) return log_prob From d8eaf0d4734af11f4e57600509413ef6242966e3 Mon Sep 17 00:00:00 2001 From: YeAnbang Date: Fri, 21 Mar 2025 15:03:10 +0800 Subject: [PATCH 31/35] simplify vllm preprocessing input ids --- .../coati/distributed/inference_backend.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/applications/ColossalChat/coati/distributed/inference_backend.py b/applications/ColossalChat/coati/distributed/inference_backend.py index 5039d89f5df5..17c71c8a85b1 100644 --- a/applications/ColossalChat/coati/distributed/inference_backend.py +++ b/applications/ColossalChat/coati/distributed/inference_backend.py @@ -212,13 +212,11 @@ def __init__( def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]: micro_batch_size = input_ids.size(0) response_start_idx = input_ids.size(1) + first_non_padding_token_idx = (input_ids != self.tokenizer.pad_token_id).int().argmax(dim=1) micro_batch_input_ids = input_ids.tolist() - micro_batch_input_ids_no_padding = [] - for i in range(micro_batch_size): - for j in range(input_ids.size(1)): - if micro_batch_input_ids[i][j] != self.tokenizer.pad_token_id: - micro_batch_input_ids_no_padding.append(micro_batch_input_ids[i][j:]) - break + micro_batch_input_ids_no_padding = [ + micro_batch_input_ids[i][first_non_padding_token_idx[i] :] for i in range(micro_batch_size) + ] outputs = self.llm.generate( prompt_token_ids=micro_batch_input_ids_no_padding, sampling_params=self.generate_config, use_tqdm=False ) From 2aa7385c88a1f271bb49741e56c19989421e3756 Mon Sep 17 00:00:00 2001 From: YeAnbang Date: Fri, 21 Mar 2025 16:12:07 +0800 Subject: [PATCH 32/35] update logging --- .../ColossalChat/coati/distributed/grpo_consumer.py | 11 ++++++----- .../ColossalChat/coati/distributed/producer.py | 3 +++ 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py index 5a488f5aaf91..1c0773f4e9bc 100644 --- a/applications/ColossalChat/coati/distributed/grpo_consumer.py +++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py @@ -133,7 +133,6 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]: response_length = torch.sum(action_mask, dim=1).to(torch.float32) need_update = (step_idx + 1) % self.num_microbatches == 0 - ctx = nullcontext() if need_update else self.booster.no_sync(self.policy_model, self.optimizer) with ctx: policy_model_logits = self.policy_model( @@ -243,13 +242,15 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]: ) self.wandb_run.log( { + "metrics/reward": self.accum_reward.item() / self.accum_count, + "metrics/format_reward": self.accum_format_reward.item() / self.accum_count, + "metrics/acc_reward": self.accum_acc_reward.item() / self.accum_count, + "metrics/response_length": self.accum_response_length.item() / self.accum_count, "train/loss": self.accum_loss.item() / self.accum_count, - "train/reward": self.accum_reward.item() / self.accum_count, - "train/format_reward": self.accum_format_reward.item() / self.accum_count, - "train/acc_reward": self.accum_acc_reward.item() / self.accum_count, "train/kl": self.accum_kl.item() / self.accum_count, 
"train/advantages": self.accum_advantages.item() / self.accum_count, - "train/response_length": self.accum_response_length.item() / self.accum_count, + "train/learning_rate": self.lr_scheduler.get_last_lr()[0], + "rollout/temperature": data["temperature"].cpu().numpy()[0][0], } ) self.accum_loss.zero_() diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py index 6cc9b33305a6..51a1af332f25 100644 --- a/applications/ColossalChat/coati/distributed/producer.py +++ b/applications/ColossalChat/coati/distributed/producer.py @@ -101,6 +101,9 @@ def loop(self) -> None: break outputs = self.rollout(**batch) print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}") + outputs["temperature"] = torch.tensor( + [self.model.generate_config.temperature] * outputs["input_ids"].size(0) + ).to(outputs["input_ids"].device) outputs = pre_send(outputs) ray_broadcast_tensor_dict( outputs, src=0, device=self.device, group_name=f"sync_data_{self.producer_idx}" From 50153005b4c2af4dfecc3b07adc91b20e97d00b0 Mon Sep 17 00:00:00 2001 From: YeAnbang <44796419+YeAnbang@users.noreply.github.com> Date: Fri, 28 Mar 2025 10:24:58 +0800 Subject: [PATCH 33/35] [feat] add microbatch forwarding (#6251) * add microbatch forwarding * fix forward microbatch * fix producer OOM * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * change project name * fix temperature annealing * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * address conversation --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../coati/distributed/consumer.py | 7 +- .../coati/distributed/grpo_consumer.py | 133 +++++++++++------- .../ColossalChat/coati/distributed/launch.py | 9 +- .../coati/distributed/producer.py | 10 +- applications/ColossalChat/rl_example.py | 27 ++-- 5 files changed, 113 insertions(+), 73 deletions(-) diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py index de289738347c..027acc2e7537 100644 --- a/applications/ColossalChat/coati/distributed/consumer.py +++ b/applications/ColossalChat/coati/distributed/consumer.py @@ -66,12 +66,7 @@ def setup(self) -> None: cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name="sync_model") launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0) - plugin_config = dict( - tp_size=1, - pp_size=1, - precision="bf16", - zero_stage=1, - ) + plugin_config = dict(tp_size=1, pp_size=1, precision="bf16", zero_stage=2) if self.plugin_config.get("pp_size", 1) > 1 and "num_microbatches" not in self.plugin_config: plugin_config["microbatch_size"] = self.microbatch_size plugin_config.update(self.plugin_config) diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py index 1c0773f4e9bc..4174f96514b8 100644 --- a/applications/ColossalChat/coati/distributed/grpo_consumer.py +++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py @@ -68,6 +68,7 @@ def __init__( self.accum_response_length = torch.zeros(1, device=self.device) self.accum_count = 0 self.generate_config = generate_config + self.training_config = training_config # Reference model is initialized from policy model. 
self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config) @@ -131,40 +132,16 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]: num_action = action_mask.shape[1] old_action_log_probs = data["action_log_probs"] response_length = torch.sum(action_mask, dim=1).to(torch.float32) + forward_batch_size = self.training_config.get("train_microbatch_size", data["input_ids"].size(0)) need_update = (step_idx + 1) % self.num_microbatches == 0 - ctx = nullcontext() if need_update else self.booster.no_sync(self.policy_model, self.optimizer) + # Gradient must be synchronized if zero2 is enabled. https://github.com/hpcaitech/ColossalAI/blob/44d4053fec005fe0b06b6bc755fdc962463145df/colossalai/booster/plugin/hybrid_parallel_plugin.py#L1500 + ctx = ( + nullcontext() + if need_update or self.booster.plugin.zero_stage == 2 + else self.booster.no_sync(self.policy_model, self.optimizer) + ) with ctx: - policy_model_logits = self.policy_model( - input_ids=data["input_ids"], - attention_mask=data["attention_mask"], - )["logits"] - action_log_probs = calc_action_log_probs( - policy_model_logits / self.generate_config["temperature"], - data["input_ids"], - num_action, - self.plugin.shard_config, - ) - - with torch.no_grad(): - reference_model_logits = self.reference_model( - input_ids=data["input_ids"], - attention_mask=data["attention_mask"], - )["logits"] - reference_action_log_probs = calc_action_log_probs( - reference_model_logits / self.generate_config["temperature"], - data["input_ids"], - num_action, - self.plugin.shard_config, - ) - - per_token_kl = ( - torch.exp(reference_action_log_probs - action_log_probs) - - (reference_action_log_probs - action_log_probs) - - 1 - ) - kl = torch.sum(per_token_kl * action_mask, dim=-1) / torch.sum(action_mask, dim=-1) - reward_group = self.reward_model( data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"] ) @@ -177,6 +154,11 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]: group_reward = reward.view(-1, self.num_generations) reward_mean = group_reward.mean(dim=1) + # [batch_size x num_generations] + reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0) + reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0) + # [batch_size x num_generations] + advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1) # filter out the reward that is too high (all sample gets full score) or too low (all sample gets 0 score), loss_mask = ( None @@ -185,35 +167,82 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]: reward_mean > self.filter_range[0], reward_mean < self.filter_range[1] ).repeat_interleave(self.num_generations, dim=0) ) + mean_kl, mean_loss = [], [] + for forward_micro_batch_start in range(0, data["input_ids"].size(0), forward_batch_size): + input_ids_forward_micro_batch = data["input_ids"][ + forward_micro_batch_start : forward_micro_batch_start + forward_batch_size + ] + attention_mask_forward_micro_batch = data["attention_mask"][ + forward_micro_batch_start : forward_micro_batch_start + forward_batch_size + ] + action_mask_forward_micro_batch = action_mask[ + forward_micro_batch_start : forward_micro_batch_start + forward_batch_size + ] + loss_mask_forward_micro_batch = ( + loss_mask[forward_micro_batch_start : forward_micro_batch_start + forward_batch_size] + if loss_mask is not None + else None + ) + advantages_forward_micro_batch = advantages[ + forward_micro_batch_start : forward_micro_batch_start + 
forward_batch_size + ] + policy_model_logits = self.policy_model( + input_ids=input_ids_forward_micro_batch, + attention_mask=attention_mask_forward_micro_batch, + ).logits + action_log_probs = calc_action_log_probs( + policy_model_logits / self.generate_config["temperature"], + input_ids_forward_micro_batch, + num_action, + self.plugin.shard_config, + ) - # [batch_size x num_generations] - reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0) - reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0) - # [batch_size x num_generations] - advantages = (reward - reward_mean) / (reward_std + 1e-4) - - loss, skip_update, _ = self.policy_loss_fn( - action_log_probs, - old_action_log_probs, - advantages.unsqueeze(dim=-1).repeat_interleave(action_log_probs.size(-1), dim=-1), - per_token_kl, - action_mask, - loss_mask=loss_mask, - ) + with torch.no_grad(): + reference_model_logits = self.reference_model( + input_ids=input_ids_forward_micro_batch, + attention_mask=attention_mask_forward_micro_batch, + ).logits + reference_action_log_probs = calc_action_log_probs( + reference_model_logits / self.generate_config["temperature"], + input_ids_forward_micro_batch, + num_action, + self.plugin.shard_config, + ) + + per_token_kl = ( + torch.exp(reference_action_log_probs - action_log_probs) + - (reference_action_log_probs - action_log_probs) + - 1 + ) + kl = torch.sum(per_token_kl * action_mask_forward_micro_batch, dim=-1) / torch.sum( + action_mask_forward_micro_batch, dim=-1 + ) + + loss, skip_update, _ = self.policy_loss_fn( + action_log_probs, + old_action_log_probs, + advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1), + per_token_kl, + action_mask_forward_micro_batch, + loss_mask=loss_mask_forward_micro_batch, + ) + + if not skip_update: + self.booster.backward(loss, self.optimizer) + loss = all_reduce_mean(loss, self.plugin) + kl = all_reduce_mean(kl.mean(), self.plugin) + # Calculate accumulate value. + mean_kl.append(kl.data) + mean_loss.append(loss.data) - if not skip_update: - self.booster.backward(loss, self.optimizer) - loss = all_reduce_mean(loss, self.plugin) reward = all_reduce_mean(reward.mean(), self.plugin) - kl = all_reduce_mean(kl.mean(), self.plugin) format_reward = all_reduce_mean(format_reward.mean(), self.plugin) acc_reward = all_reduce_mean(acc_reward.mean(), self.plugin) advantages = all_reduce_mean(advantages.mean(), self.plugin) response_length = all_reduce_mean(response_length.mean(), self.plugin) - # Calculate accumulate value. 
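
As a worked illustration of the group-relative advantage and the reward-based prompt filter used in the hunk above, the arithmetic can be reproduced on toy numbers (2 prompts x 4 generations here instead of the default 8; the (0.05, 9.0) bounds mirror the filter_range passed from launch.py):

import torch

num_generations = 4
# rewards for 2 prompts, flattened prompt-major as in the consumer
reward = torch.tensor([1.0, 10.0, 1.0, 0.0, 10.0, 10.0, 10.0, 10.0])

group_reward = reward.view(-1, num_generations)
reward_mean = group_reward.mean(dim=1)   # per-prompt mean reward
reward_std = group_reward.std(dim=1)     # per-prompt spread

# prompts whose mean reward falls outside (0.05, 9.0) carry no learning signal
# (all samples got 0 or all got full marks) and are masked out of the loss
filter_range = (0.05, 9.0)
loss_mask = torch.logical_and(
    reward_mean > filter_range[0], reward_mean < filter_range[1]
).repeat_interleave(num_generations, dim=0)

reward_mean = reward_mean.repeat_interleave(num_generations, dim=0)
reward_std = reward_std.repeat_interleave(num_generations, dim=0)
advantages = (reward - reward_mean) / (reward_std + 1e-4)

print(advantages)  # second prompt (all 10s) gets ~zero advantages ...
print(loss_mask)   # ... and is filtered out of the loss entirely
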
- self.accum_loss.add_(loss.data) + self.accum_loss.add_(sum(mean_loss) / len(mean_loss)) + self.accum_kl.add_(sum(mean_kl) / len(mean_kl)) self.accum_reward.add_(reward.data) - self.accum_kl.add_(kl.data) self.accum_format_reward.add_(format_reward.data) self.accum_acc_reward.add_(acc_reward.data) self.accum_advantages.add_(advantages.data) diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py index c50db1378e16..ba5d3a9d4fd8 100644 --- a/applications/ColossalChat/coati/distributed/launch.py +++ b/applications/ColossalChat/coati/distributed/launch.py @@ -34,6 +34,7 @@ def launch_distributed( inference_microbatch_size: int, train_batch_size: int, train_microbatch_size: int, + train_minibatch_size: int, dataset_config: Dict[str, Any], dataloaders_config: Dict[str, Any], inference_model_config: Dict[str, Any], @@ -99,9 +100,13 @@ def launch_distributed( batch_size=train_batch_size, model_config=train_model_config, plugin_config=plugin_config, - microbatch_size=train_microbatch_size, + microbatch_size=train_minibatch_size, generate_config=generate_config_consumer, - training_config={"filter_range": [0.05, 9.0], "lr": 1e-6}, + training_config={ + "filter_range": [0.05, 9.0], + "lr": 1e-6, + "train_microbatch_size": train_microbatch_size, + }, num_generations=num_generations, ) procs.append(consumer) diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py index 51a1af332f25..2c6a24a36711 100644 --- a/applications/ColossalChat/coati/distributed/producer.py +++ b/applications/ColossalChat/coati/distributed/producer.py @@ -100,6 +100,7 @@ def loop(self) -> None: if i >= num_valid_microbatches: break outputs = self.rollout(**batch) + print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}") outputs["temperature"] = torch.tensor( [self.model.generate_config.temperature] * outputs["input_ids"].size(0) @@ -116,16 +117,19 @@ def loop(self) -> None: print( f"[P{self.producer_idx}] Sync model episode {episode} step {(i + 1) // self.num_microbatches - 1}" ) + state_dict = ray_broadcast_tensor_dict( None, self.num_producers, device=self.device, group_name="sync_model" ) self.load_state_dict(state_dict) + del state_dict + torch.cuda.empty_cache() # linear annealing for 1 episode, temperature from initial to 0.7 if episode <= 0: ratio = 1 - (len(self.dataloader) - i) / len(self.dataloader) - self.model.generate_config.temperature = ( - ratio * self.generate_config["temperature"] + (1 - ratio) * 0.7 - ) + self.model.generate_config.temperature = (1 - ratio) * self.generate_config[ + "temperature" + ] + ratio * 0.7 @ray.remote diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py index a67a10bc5b35..4a4a4c3404e9 100644 --- a/applications/ColossalChat/rl_example.py +++ b/applications/ColossalChat/rl_example.py @@ -10,18 +10,30 @@ parser.add_argument("-d", "--dataset", type=str, default="data.jsonl") parser.add_argument("-t", "--num-trainers", type=int, default=2) parser.add_argument("-i", "--num-inferencer", type=int, default=2) + parser.add_argument("-g", "--num-generations", type=int, default=8) parser.add_argument("-ibs", "--inference-batch-size", type=int, default=64) parser.add_argument("-imbs", "--inference-microbatch-size", type=int, default=8) parser.add_argument("-tbs", "--train-batch-size", type=int, default=32) - parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=1) + 
parser.add_argument("-tMbs", "--train-minibatch-size", type=int, default=1) + parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=2) parser.add_argument("-b", "--backend", type=str, default="transformers") parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple", "GRPO", "EvalGRPO"]) args = parser.parse_args() + assert args.train_minibatch_size > 0, "Train mini batch size must be greater than 0" + assert ( + args.train_minibatch_size * args.num_generations >= args.train_microbatch_size + and args.train_microbatch_size > 0 + ), "Train micro batch size must be greater than 0 less than train mini batch size * num generations" + ray.init(address="local", namespace="ray-example") inference_model_config = dict(path=args.model) - train_model_config = dict(path=args.model) + train_model_config = dict( + path=args.model, + # use_flash_attention_2=True, + # use_cache=False + ) generate_config = dict(top_k=50, top_p=0.75, temperature=0.9) if args.backend == "transformers": @@ -31,13 +43,6 @@ torch_dtype=torch.bfloat16, ) ) - train_model_config.update( - dict( - use_flash_attention_2=True, - torch_dtype=torch.bfloat16, - use_cache=False, - ) - ) generate_config.update( dict( max_length=1024 + 512, @@ -78,15 +83,17 @@ inference_batch_size=args.inference_batch_size, inference_microbatch_size=args.inference_microbatch_size, train_batch_size=args.train_batch_size, + train_minibatch_size=args.train_minibatch_size, train_microbatch_size=args.train_microbatch_size, dataset_config={"path": args.dataset, "max_length": 300}, dataloaders_config={}, inference_model_config=inference_model_config, generate_config=generate_config, + num_generations=args.num_generations, train_model_config=train_model_config, plugin_config={}, inference_backend=args.backend, master_addr="localhost", - master_port=29503, + master_port=29505, core_algo=args.algo, ) From ed43a4be046a3fe7650b061086681e1bd1599b12 Mon Sep 17 00:00:00 2001 From: YeAnbang <44796419+YeAnbang@users.noreply.github.com> Date: Wed, 9 Apr 2025 13:23:24 +0800 Subject: [PATCH 34/35] [Distributed RLHF] Integration of PP (#6257) * update help information * update style * fix * minor fix * support PP training * add pp support * remove unused code * address conversation --------- Co-authored-by: Tong Li --- .gitignore | 1 + .../coati/distributed/consumer.py | 2 - .../coati/distributed/grpo_consumer.py | 306 ++++++++++++------ .../ColossalChat/coati/distributed/launch.py | 2 + applications/ColossalChat/rl_example.py | 63 +++- .../booster/plugin/hybrid_parallel_plugin.py | 6 +- colossalai/shardformer/modeling/qwen2.py | 1 + 7 files changed, 264 insertions(+), 117 deletions(-) diff --git a/.gitignore b/.gitignore index 16f764c1b1ef..533450a7cce1 100644 --- a/.gitignore +++ b/.gitignore @@ -164,3 +164,4 @@ coverage.xml applications/ColossalChat/logs applications/ColossalChat/tests/logs applications/ColossalChat/wandb +applications/ColossalChat/model diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py index 027acc2e7537..c6ae7be2daf9 100644 --- a/applications/ColossalChat/coati/distributed/consumer.py +++ b/applications/ColossalChat/coati/distributed/consumer.py @@ -54,7 +54,6 @@ def __init__( self.model_config = model_config self.plugin_config = plugin_config - assert self.plugin_config.get("pp_size", 1) == 1, "pp_size > 1 is not supported now" self.device = get_current_device() self.lr_scheduler = None @@ -95,7 +94,6 @@ def loop(self) -> None: i = 0 for _ 
in range(self.num_recv_per_update): # receive data from producers - for r in range(self.num_producers): print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}") self.buffer.extend( diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py index 4174f96514b8..a282439cbd33 100644 --- a/applications/ColossalChat/coati/distributed/grpo_consumer.py +++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py @@ -39,6 +39,7 @@ def __init__( use_wandb=True, generate_config=None, training_config={}, + project_name=None, ): super().__init__( num_producers, @@ -69,6 +70,7 @@ def __init__( self.accum_count = 0 self.generate_config = generate_config self.training_config = training_config + self.project_name = project_name # Reference model is initialized from policy model. self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config) @@ -94,9 +96,7 @@ def __init__( self.policy_loss_fn = PolicyLoss() self.global_step = 0 - if use_wandb and self.rank == 0: - name = f"{generate_config['backend']}_bs_{self.batch_size*self.world_size}_temp_{generate_config['temperature']:.01f}_top_p_{generate_config['top_p']:.02f}" - self.wandb_run = wandb.init(project="GRPO-V1", sync_tensorboard=True, dir="./wandb", name=name) + self.use_wandb = use_wandb self.lr_scheduler = CosineAnnealingWarmupLR( optimizer=self.optimizer, @@ -107,10 +107,19 @@ def __init__( def setup(self): super().setup() + if self.use_wandb and ( + (not self.plugin.pp_size > 1 and self.rank == 0) + or (self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage()) + ): + # Initialize wandb. + name = f"{self.generate_config['backend']}_bs_{self.batch_size*self.dp_size}_temp_{self.generate_config['temperature']:.01f}_top_p_{self.generate_config['top_p']:.02f}" + self.wandb_run = wandb.init(project=self.project_name, sync_tensorboard=True, dir="./wandb", name=name) + self.policy_model, self.optimizer, _, _, self.lr_scheduler = self.booster.boost( self.policy_model, self.optimizer, lr_scheduler=self.lr_scheduler ) self.reference_model, *_ = self.booster.boost(self.reference_model) + self.plugin.logger.set_level("ERROR") def step(self, step_idx: int, **kwargs) -> Optional[float]: """ @@ -168,6 +177,7 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]: ).repeat_interleave(self.num_generations, dim=0) ) mean_kl, mean_loss = [], [] + for forward_micro_batch_start in range(0, data["input_ids"].size(0), forward_batch_size): input_ids_forward_micro_batch = data["input_ids"][ forward_micro_batch_start : forward_micro_batch_start + forward_batch_size @@ -186,112 +196,210 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]: advantages_forward_micro_batch = advantages[ forward_micro_batch_start : forward_micro_batch_start + forward_batch_size ] - policy_model_logits = self.policy_model( - input_ids=input_ids_forward_micro_batch, - attention_mask=attention_mask_forward_micro_batch, - ).logits - action_log_probs = calc_action_log_probs( - policy_model_logits / self.generate_config["temperature"], - input_ids_forward_micro_batch, - num_action, - self.plugin.shard_config, - ) - with torch.no_grad(): - reference_model_logits = self.reference_model( + if self.plugin.pp_size > 1: + # Support training with PP. 
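
Before the pipeline branch that follows, it helps to see the shape it relies on: with pipeline parallelism only the last stage holds the logits, so the loss cannot be computed inline and is instead expressed as a criterion(outputs, inputs) callable that the pipeline schedule evaluates where the logits live. A stripped-down sketch of that shape on plain dict-based tensors (a simple masked NLL only; the real _criterion below additionally applies the clipped policy loss, the per-token KL penalty, and the loss mask):

import torch

def make_criterion(temperature: float):
    def _criterion(outputs, inputs):
        # outputs carry the model logits; inputs carry whatever the loss needs
        logits = outputs["logits"] / temperature
        log_probs = torch.log_softmax(logits, dim=-1)
        token_log_probs = log_probs[:, :-1].gather(
            dim=-1, index=inputs["input_ids"][:, 1:].unsqueeze(-1)
        ).squeeze(-1)
        mask = inputs["action_mask"]
        return -(token_log_probs[:, -mask.size(1) :] * mask).sum() / mask.sum()
    return _criterion

# toy shapes only
outputs = {"logits": torch.randn(2, 8, 50)}
inputs = {"input_ids": torch.randint(0, 50, (2, 8)), "action_mask": torch.ones(2, 4)}
loss = make_criterion(0.9)(outputs, inputs)
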
+ + with torch.no_grad(): + reference_model_outputs = self.booster.execute_pipeline( + iter( + [ + { + "input_ids": input_ids_forward_micro_batch, + "attention_mask": attention_mask_forward_micro_batch, + } + ] + ), + self.reference_model, + criterion=lambda outputs, inputs: torch.tensor( + [0.0], device=action_mask.device + ), # dummy criterion + optimizer=None, + return_loss=False, + return_outputs=True, + ) + + if self.booster.plugin.stage_manager.is_last_stage(): + reference_model_logits = reference_model_outputs["outputs"]["logits"] + reference_action_log_probs = calc_action_log_probs( + reference_model_logits / self.generate_config["temperature"], + input_ids_forward_micro_batch, + num_action, + self.plugin.shard_config, + ) + else: + # Dummy reference logprobs for data iterator. + reference_action_log_probs = None + + data_policy_forward = { + "input_ids": input_ids_forward_micro_batch, + "attention_mask": attention_mask_forward_micro_batch, + "action_mask": action_mask_forward_micro_batch, + "reference_action_log_probs": reference_action_log_probs, + "advantages": advantages_forward_micro_batch, + "loss_mask": loss_mask_forward_micro_batch, + "source": self.rank, + } + + kl = [] + + def _criterion(outputs, inputs): + action_logits = outputs.logits + action_log_probs = calc_action_log_probs( + action_logits / self.generate_config["temperature"], + inputs["input_ids"], + num_action, + self.plugin.shard_config, + ) + per_token_kl = ( + torch.exp(inputs["reference_action_log_probs"] - action_log_probs) + - (inputs["reference_action_log_probs"] - action_log_probs) + - 1 + ) + appox_kl = torch.sum(per_token_kl * inputs["action_mask"], dim=-1) / torch.sum( + inputs["action_mask"], dim=-1 + ) + kl.append(appox_kl.mean()) + loss, skip_update, _ = self.policy_loss_fn( + action_log_probs, + action_log_probs, + inputs["advantages"].repeat_interleave(action_log_probs.size(-1), dim=-1), + per_token_kl, + inputs["action_mask"], + loss_mask=inputs["loss_mask"], + ) + return loss + + policy_model_outputs = self.booster.execute_pipeline( + iter([data_policy_forward]), + self.policy_model, + criterion=_criterion, + optimizer=self.optimizer, + return_loss=True, + return_outputs=True, + ) + loss = policy_model_outputs["loss"] + + if self.booster.plugin.stage_manager.is_last_stage(): + if len(kl) > 0: + kl = all_reduce_mean(torch.mean(torch.stack(kl)).to(loss.device), self.plugin) + mean_kl.append(kl) + loss = all_reduce_mean(loss, self.plugin) + mean_loss.append(loss.data) + else: + + policy_model_logits = self.policy_model( input_ids=input_ids_forward_micro_batch, attention_mask=attention_mask_forward_micro_batch, ).logits - reference_action_log_probs = calc_action_log_probs( - reference_model_logits / self.generate_config["temperature"], - input_ids_forward_micro_batch, - num_action, - self.plugin.shard_config, - ) + action_log_probs = calc_action_log_probs( + policy_model_logits / self.generate_config["temperature"], + input_ids_forward_micro_batch, + num_action, + self.plugin.shard_config, + ) - per_token_kl = ( - torch.exp(reference_action_log_probs - action_log_probs) - - (reference_action_log_probs - action_log_probs) - - 1 - ) - kl = torch.sum(per_token_kl * action_mask_forward_micro_batch, dim=-1) / torch.sum( - action_mask_forward_micro_batch, dim=-1 - ) + with torch.no_grad(): + reference_model_logits = self.reference_model( + input_ids=input_ids_forward_micro_batch, + attention_mask=attention_mask_forward_micro_batch, + ).logits + reference_action_log_probs = calc_action_log_probs( + 
reference_model_logits / self.generate_config["temperature"], + input_ids_forward_micro_batch, + num_action, + self.plugin.shard_config, + ) + per_token_kl = ( + torch.exp(reference_action_log_probs - action_log_probs) + - (reference_action_log_probs - action_log_probs) + - 1 + ) + kl = torch.sum(per_token_kl * action_mask_forward_micro_batch, dim=-1) / torch.sum( + action_mask_forward_micro_batch, dim=-1 + ) - loss, skip_update, _ = self.policy_loss_fn( - action_log_probs, - old_action_log_probs, - advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1), - per_token_kl, - action_mask_forward_micro_batch, - loss_mask=loss_mask_forward_micro_batch, - ) + loss, skip_update, _ = self.policy_loss_fn( + action_log_probs, + old_action_log_probs, + advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1), + per_token_kl, + action_mask_forward_micro_batch, + loss_mask=loss_mask_forward_micro_batch, + ) - if not skip_update: - self.booster.backward(loss, self.optimizer) - loss = all_reduce_mean(loss, self.plugin) - kl = all_reduce_mean(kl.mean(), self.plugin) - # Calculate accumulate value. - mean_kl.append(kl.data) - mean_loss.append(loss.data) - - reward = all_reduce_mean(reward.mean(), self.plugin) - format_reward = all_reduce_mean(format_reward.mean(), self.plugin) - acc_reward = all_reduce_mean(acc_reward.mean(), self.plugin) - advantages = all_reduce_mean(advantages.mean(), self.plugin) - response_length = all_reduce_mean(response_length.mean(), self.plugin) - self.accum_loss.add_(sum(mean_loss) / len(mean_loss)) - self.accum_kl.add_(sum(mean_kl) / len(mean_kl)) - self.accum_reward.add_(reward.data) - self.accum_format_reward.add_(format_reward.data) - self.accum_acc_reward.add_(acc_reward.data) - self.accum_advantages.add_(advantages.data) - self.accum_response_length.add_(response_length.data) - self.accum_count += 1 + if not skip_update: + self.booster.backward(loss, self.optimizer) + loss = all_reduce_mean(loss, self.plugin) + kl = all_reduce_mean(kl.mean(), self.plugin) + # Calculate accumulate value. 
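
The per-token KL term computed just above is the k3 estimator, exp(ref - logp) - (ref - logp) - 1, which is non-negative and exactly zero on tokens where the policy matches the reference. A tiny numeric check with made-up log-probabilities:

import torch

action_log_probs = torch.tensor([[-1.0, -2.0, -0.5]])
reference_action_log_probs = torch.tensor([[-1.0, -1.5, -1.0]])

diff = reference_action_log_probs - action_log_probs
per_token_kl = torch.exp(diff) - diff - 1        # elementwise >= 0
action_mask = torch.tensor([[1.0, 1.0, 0.0]])    # last position is padding

kl = torch.sum(per_token_kl * action_mask, dim=-1) / torch.sum(action_mask, dim=-1)
print(per_token_kl)  # first entry is 0: policy and reference agree on that token
print(kl)            # masked mean over response tokens, as in the consumer
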
+ mean_kl.append(kl.data) + mean_loss.append(loss.data) + if not self.plugin.pp_size > 1 or ( + self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() + ): + reward = all_reduce_mean(reward.mean(), self.plugin) + format_reward = all_reduce_mean(format_reward.mean(), self.plugin) + acc_reward = all_reduce_mean(acc_reward.mean(), self.plugin) + advantages = all_reduce_mean(advantages.mean(), self.plugin) + response_length = all_reduce_mean(response_length.mean(), self.plugin) + self.accum_loss.add_(sum(mean_loss) / len(mean_loss)) + self.accum_kl.add_(sum(mean_kl) / len(mean_kl)) + self.accum_reward.add_(reward.data) + self.accum_format_reward.add_(format_reward.data) + self.accum_acc_reward.add_(acc_reward.data) + self.accum_advantages.add_(advantages.data) + self.accum_response_length.add_(response_length.data) + self.accum_count += 1 if need_update: self.optimizer.step() self.optimizer.zero_grad() - loss_scalar = self.accum_loss.item() - if self.rank == 0: - print( - "Loss:", - self.accum_loss.item() / self.accum_count, - "\nReward:", - self.accum_reward.item() / self.accum_count, - "\nFormat Reward:", - self.accum_format_reward.item() / self.accum_count, - "\nAcc Reward:", - self.accum_acc_reward.item() / self.accum_count, - "\nKL:", - self.accum_kl.item() / self.accum_count, - "\nAdvantages:", - self.accum_advantages.item() / self.accum_count, - "\nResponse Length:", - self.accum_response_length.item() / self.accum_count, - ) - self.wandb_run.log( - { - "metrics/reward": self.accum_reward.item() / self.accum_count, - "metrics/format_reward": self.accum_format_reward.item() / self.accum_count, - "metrics/acc_reward": self.accum_acc_reward.item() / self.accum_count, - "metrics/response_length": self.accum_response_length.item() / self.accum_count, - "train/loss": self.accum_loss.item() / self.accum_count, - "train/kl": self.accum_kl.item() / self.accum_count, - "train/advantages": self.accum_advantages.item() / self.accum_count, - "train/learning_rate": self.lr_scheduler.get_last_lr()[0], - "rollout/temperature": data["temperature"].cpu().numpy()[0][0], - } - ) - self.accum_loss.zero_() - self.accum_reward.zero_() - self.accum_acc_reward.zero_() - self.accum_format_reward.zero_() - self.accum_kl.zero_() - self.accum_advantages.zero_() - self.accum_response_length.zero_() - - self.accum_count = 0 - return loss_scalar + if not self.plugin.pp_size > 1 or ( + self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() + ): + loss_scalar = self.accum_loss.item() + if (not self.plugin.pp_size > 1 and self.rank == 0) or ( + self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() + ): + print( + "Loss:", + self.accum_loss.item() / self.accum_count, + "\nReward:", + self.accum_reward.item() / self.accum_count, + "\nFormat Reward:", + self.accum_format_reward.item() / self.accum_count, + "\nAcc Reward:", + self.accum_acc_reward.item() / self.accum_count, + "\nKL:", + self.accum_kl.item() / self.accum_count, + "\nAdvantages:", + self.accum_advantages.item() / self.accum_count, + "\nResponse Length:", + self.accum_response_length.item() / self.accum_count, + ) + self.wandb_run.log( + { + "metrics/reward": self.accum_reward.item() / self.accum_count, + "metrics/format_reward": self.accum_format_reward.item() / self.accum_count, + "metrics/acc_reward": self.accum_acc_reward.item() / self.accum_count, + "metrics/response_length": self.accum_response_length.item() / self.accum_count, + "train/loss": self.accum_loss.item() / self.accum_count, 
+ "train/kl": self.accum_kl.item() / self.accum_count, + "train/advantages": self.accum_advantages.item() / self.accum_count, + "train/learning_rate": self.lr_scheduler.get_last_lr()[0], + "rollout/temperature": data["temperature"].cpu().numpy()[0][0], + } + ) + self.accum_loss.zero_() + self.accum_reward.zero_() + self.accum_acc_reward.zero_() + self.accum_format_reward.zero_() + self.accum_kl.zero_() + self.accum_advantages.zero_() + self.accum_response_length.zero_() + + self.accum_count = 0 + return loss_scalar def state_dict(self): self.policy_model._force_wait_all_gather() diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py index ba5d3a9d4fd8..699d90a8cdff 100644 --- a/applications/ColossalChat/coati/distributed/launch.py +++ b/applications/ColossalChat/coati/distributed/launch.py @@ -47,6 +47,7 @@ def launch_distributed( master_addr: str = "localhost", master_port: int = 29500, core_algo: str = "GRPO", + project_name: Optional[str] = None, ): if core_algo not in ALGO_MAP: @@ -108,6 +109,7 @@ def launch_distributed( "train_microbatch_size": train_microbatch_size, }, num_generations=num_generations, + project_name=project_name, ) procs.append(consumer) ray.get([p.setup.remote() for p in procs]) diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py index 4a4a4c3404e9..f87f12ed23ca 100644 --- a/applications/ColossalChat/rl_example.py +++ b/applications/ColossalChat/rl_example.py @@ -10,13 +10,44 @@ parser.add_argument("-d", "--dataset", type=str, default="data.jsonl") parser.add_argument("-t", "--num-trainers", type=int, default=2) parser.add_argument("-i", "--num-inferencer", type=int, default=2) - parser.add_argument("-g", "--num-generations", type=int, default=8) - parser.add_argument("-ibs", "--inference-batch-size", type=int, default=64) - parser.add_argument("-imbs", "--inference-microbatch-size", type=int, default=8) - parser.add_argument("-tbs", "--train-batch-size", type=int, default=32) - parser.add_argument("-tMbs", "--train-minibatch-size", type=int, default=1) - parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=2) - parser.add_argument("-b", "--backend", type=str, default="transformers") + parser.add_argument("-g", "--num-generations", type=int, default=8, help="Number of generations per prompt.") + parser.add_argument("-p", "--project", type=str, default="GRPO", help="Project name.") + parser.add_argument( + "-ibs", + "--inference-batch-size", + type=int, + default=64, + help="Number of prompts to generate per inference step. It should be divisible by tbs, and the weights on the inference backend will be synced every ibs/tbs training steps of the policy model.", + ) + parser.add_argument( + "-imbs", + "--inference-microbatch-size", + type=int, + default=8, + help="Effective batch size for the inference backend to run generation. Please select based on memory constraint.", + ) + parser.add_argument( + "-tbs", + "--train-batch-size", + type=int, + default=32, + help="Number of unique prompts to update policy model per step per dp group. Gradient is accumulated across tbs * dp_size unique prompts, equivalently tbs * g * dp_size samples", + ) + parser.add_argument( + "-tMbs", + "--train-minibatch-size", + type=int, + default=1, + help="Number of unique prompts in each training batch per dp group. The inference backend must generate tMbs * g * dp_size samples before forwarding. 
Satisfy tMbs * g >= tmbs", + ) + parser.add_argument( + "-tmbs", + "--train-microbatch-size", + type=int, + default=2, + help="Effective batch size per dp group for forwarding and backwarding. Please select based on the availiable memory.", + ) + parser.add_argument("-b", "--backend", type=str, default="transformers", choices=["transformers", "vllm"]) parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple", "GRPO", "EvalGRPO"]) args = parser.parse_args() @@ -29,11 +60,7 @@ ray.init(address="local", namespace="ray-example") inference_model_config = dict(path=args.model) - train_model_config = dict( - path=args.model, - # use_flash_attention_2=True, - # use_cache=False - ) + train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False) generate_config = dict(top_k=50, top_p=0.75, temperature=0.9) if args.backend == "transformers": @@ -91,9 +118,17 @@ generate_config=generate_config, num_generations=args.num_generations, train_model_config=train_model_config, - plugin_config={}, + # plugin_config={}, # for zero + plugin_config={ + "pp_size": 2, + "tp_size": 1, + "microbatch_size": args.train_microbatch_size // 2, + "zero_stage": 0, + "max_norm": 1.0, + }, # for pp inference_backend=args.backend, master_addr="localhost", - master_port=29505, + master_port=29506, core_algo=args.algo, + project_name=args.project, ) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 1684fd702e70..74349091b4d4 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -1411,8 +1411,10 @@ def execute_pipeline( ) # run with gradients accumulation - if model.require_grad_sync == False or ( - isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False + if ( + not torch.is_grad_enabled() + or model.require_grad_sync == False + or (isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False) ): return outputs diff --git a/colossalai/shardformer/modeling/qwen2.py b/colossalai/shardformer/modeling/qwen2.py index 71e3557fe214..27571309e453 100644 --- a/colossalai/shardformer/modeling/qwen2.py +++ b/colossalai/shardformer/modeling/qwen2.py @@ -284,6 +284,7 @@ def qwen2_for_causal_lm_forward( hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, shard_config: ShardConfig = None, + **kwargs, ): r""" Args: From 9467c106904d62bcfc00d2484ce1374f74da5826 Mon Sep 17 00:00:00 2001 From: YeAnbang <44796419+YeAnbang@users.noreply.github.com> Date: Thu, 10 Apr 2025 10:52:18 +0800 Subject: [PATCH 35/35] [hot-fix] Fix memory leakage bug, support TP+PP (#6258) * update help information * update style * fix * minor fix * support PP training * add pp support * remove unused code * address conversation * fix memory leakage support tp+pp * move empty cache * move empty cache --------- Co-authored-by: Tong Li --- .../ColossalChat/coati/distributed/consumer.py | 5 +++++ .../ColossalChat/coati/distributed/grpo_consumer.py | 13 ++++++------- applications/ColossalChat/rl_example.py | 2 +- 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py index c6ae7be2daf9..79beb2a2dba6 100644 --- a/applications/ColossalChat/coati/distributed/consumer.py +++ b/applications/ColossalChat/coati/distributed/consumer.py @@ -72,6 +72,8 @@ def setup(self) 
-> None: self.plugin = HybridParallelPlugin(**plugin_config) self.booster = Booster(plugin=self.plugin) self.dp_rank = dist.get_rank(self.plugin.dp_group) + self.tp_rank = dist.get_rank(self.plugin.tp_group) + self.dp_size = dist.get_world_size(self.plugin.dp_group) self.buffer = [] @@ -127,11 +129,14 @@ def loop(self) -> None: if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1: print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}") + torch.cuda.empty_cache() state_dict = self.state_dict() if self.rank == 0: ray_broadcast_tensor_dict( state_dict, src=self.num_producers, device=self.device, group_name="sync_model" ) + del state_dict + torch.cuda.empty_cache() @ray.remote diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py index a282439cbd33..e23254d1b07a 100644 --- a/applications/ColossalChat/coati/distributed/grpo_consumer.py +++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py @@ -109,7 +109,7 @@ def setup(self): super().setup() if self.use_wandb and ( (not self.plugin.pp_size > 1 and self.rank == 0) - or (self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage()) + or (self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0) ): # Initialize wandb. name = f"{self.generate_config['backend']}_bs_{self.batch_size*self.dp_size}_temp_{self.generate_config['temperature']:.01f}_top_p_{self.generate_config['top_p']:.02f}" @@ -282,10 +282,9 @@ def _criterion(outputs, inputs): if self.booster.plugin.stage_manager.is_last_stage(): if len(kl) > 0: - kl = all_reduce_mean(torch.mean(torch.stack(kl)).to(loss.device), self.plugin) + kl = all_reduce_mean(torch.mean(torch.stack(kl)).to(loss.device), self.plugin).data mean_kl.append(kl) - loss = all_reduce_mean(loss, self.plugin) - mean_loss.append(loss.data) + mean_loss.append(all_reduce_mean(loss, self.plugin).data) else: policy_model_logits = self.policy_model( @@ -336,7 +335,7 @@ def _criterion(outputs, inputs): mean_kl.append(kl.data) mean_loss.append(loss.data) if not self.plugin.pp_size > 1 or ( - self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() + self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0 ): reward = all_reduce_mean(reward.mean(), self.plugin) format_reward = all_reduce_mean(format_reward.mean(), self.plugin) @@ -355,11 +354,11 @@ def _criterion(outputs, inputs): self.optimizer.step() self.optimizer.zero_grad() if not self.plugin.pp_size > 1 or ( - self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() + self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0 ): loss_scalar = self.accum_loss.item() if (not self.plugin.pp_size > 1 and self.rank == 0) or ( - self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() + self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0 ): print( "Loss:", diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py index f87f12ed23ca..6c43ccd1960f 100644 --- a/applications/ColossalChat/rl_example.py +++ b/applications/ColossalChat/rl_example.py @@ -121,7 +121,7 @@ # plugin_config={}, # for zero plugin_config={ "pp_size": 2, - "tp_size": 1, + "tp_size": 2, "microbatch_size": args.train_microbatch_size // 2, "zero_stage": 0, "max_norm": 1.0,
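
Besides the memory cleanup, the hot-fix tightens the reporting condition: without pipeline parallelism only global rank 0 prints and logs to wandb, and with pipeline parallelism only the last stage with tensor-parallel rank 0 does, so TP replicas of the last stage no longer emit duplicate logs. The predicate, pulled out as a stand-alone sketch (argument names are illustrative):

def should_log(pp_size: int, is_last_stage: bool, global_rank: int, tp_rank: int) -> bool:
    # mirrors the gating added in grpo_consumer.py so TP/PP replicas of the
    # last stage do not produce duplicate print/wandb output
    if pp_size > 1:
        return is_last_stage and tp_rank == 0
    return global_rank == 0

assert should_log(pp_size=1, is_last_stage=False, global_rank=0, tp_rank=1)
assert not should_log(pp_size=2, is_last_stage=True, global_rank=0, tp_rank=1)
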