diff --git a/.gitignore b/.gitignore index 8bc74b4c8c2c..533450a7cce1 100644 --- a/.gitignore +++ b/.gitignore @@ -163,3 +163,5 @@ coverage.xml # log, test files - ColossalChat applications/ColossalChat/logs applications/ColossalChat/tests/logs +applications/ColossalChat/wandb +applications/ColossalChat/model diff --git a/applications/ColossalChat/coati/dataset/loader.py b/applications/ColossalChat/coati/dataset/loader.py index fc3a4930c058..4518fd71f91b 100755 --- a/applications/ColossalChat/coati/dataset/loader.py +++ b/applications/ColossalChat/coati/dataset/loader.py @@ -356,10 +356,24 @@ def apply_chat_template_and_mask( truncation: bool = True, ignore_idx: int = -100, ) -> Dict[str, torch.Tensor]: + + system_prompt = "You are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>. Now the user asks you to solve a math problem that involves reasoning. After thinking, when you finally reach a conclusion, clearly output the final answer without explanation within the <answer> </answer> tags, i.e., <answer> 123 </answer>.\n\n" + + system_element = { + "role": "system", + "content": system_prompt, + } + + # Format for RL. + gt_answer = None + if "messages" in chat and "gt_answer" in chat: + gt_answer = chat["gt_answer"] + chat = [chat["messages"]] + + tokens = [] assistant_mask = [] for i, msg in enumerate(chat): - msg_tokens = tokenizer.apply_chat_template([msg], tokenize=True) + msg_tokens = tokenizer.apply_chat_template([system_element, msg], tokenize=True, add_generation_prompt=True) # remove unexpected bos token if i > 0 and msg_tokens[0] == tokenizer.bos_token_id: msg_tokens = msg_tokens[1:] @@ -372,14 +386,10 @@ def apply_chat_template_and_mask( if max_length is not None: if padding and len(tokens) < max_length: to_pad = max_length - len(tokens) - if tokenizer.padding_side == "right": - tokens.extend([tokenizer.pad_token_id] * to_pad) - assistant_mask.extend([False] * to_pad) - attention_mask.extend([0] * to_pad) - else: - tokens = [tokenizer.pad_token_id] * to_pad + tokens - assistant_mask = [False] * to_pad + assistant_mask - attention_mask = [0] * to_pad + attention_mask + # Left padding for generation.
+ tokens = [tokenizer.pad_token_id] * to_pad + tokens + assistant_mask = [False] * to_pad + assistant_mask + attention_mask = [0] * to_pad + attention_mask if truncation and len(tokens) > max_length: tokens = tokens[:max_length] assistant_mask = assistant_mask[:max_length] @@ -389,6 +399,13 @@ def apply_chat_template_and_mask( labels = input_ids.clone() labels[~torch.tensor(assistant_mask, dtype=torch.bool)] = ignore_idx + if gt_answer is not None: + gt_answer = tokenizer.encode( + gt_answer, padding="max_length", truncation=True, max_length=128, return_tensors="pt" + ) + gt_answer = gt_answer.squeeze(1) + return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels, "gt_answer": gt_answer} + return { "input_ids": input_ids, "attention_mask": attention_mask, diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py index 84a69979fb88..79beb2a2dba6 100644 --- a/applications/ColossalChat/coati/distributed/consumer.py +++ b/applications/ColossalChat/coati/distributed/consumer.py @@ -1,3 +1,4 @@ +import os from contextlib import nullcontext from typing import Any, Dict, Optional @@ -33,6 +34,8 @@ def __init__( model_config: Dict[str, Any], plugin_config: Dict[str, Any], microbatch_size: int = 1, + save_interval: int = 100, + save_dir: str = "./model", ): self.num_producers = num_producers self.num_episodes = num_episodes @@ -44,14 +47,16 @@ def __init__( self.num_recv_per_update = num_recv_per_update self.batch_size = batch_size self.microbatch_size = microbatch_size + self.save_interval = save_interval + self.save_dir = save_dir assert batch_size % microbatch_size == 0, "batch_size should be divisible by microbatch_size" self.num_microbatches = batch_size // microbatch_size self.model_config = model_config self.plugin_config = plugin_config - assert self.plugin_config.get("pp_size", 1) == 1, "pp_size > 1 is not supported now" self.device = get_current_device() + self.lr_scheduler = None def setup(self) -> None: for i in range(self.num_producers): @@ -60,18 +65,15 @@ def setup(self) -> None: cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name="sync_model") launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0) - plugin_config = dict( - tp_size=1, - pp_size=1, - precision="bf16", - zero_stage=1, - ) + plugin_config = dict(tp_size=1, pp_size=1, precision="bf16", zero_stage=2) if self.plugin_config.get("pp_size", 1) > 1 and "num_microbatches" not in self.plugin_config: plugin_config["microbatch_size"] = self.microbatch_size plugin_config.update(self.plugin_config) self.plugin = HybridParallelPlugin(**plugin_config) self.booster = Booster(plugin=self.plugin) self.dp_rank = dist.get_rank(self.plugin.dp_group) + self.tp_rank = dist.get_rank(self.plugin.tp_group) + self.dp_size = dist.get_world_size(self.plugin.dp_group) self.buffer = [] @@ -94,7 +96,6 @@ def loop(self) -> None: i = 0 for _ in range(self.num_recv_per_update): # receive data from producers - for r in range(self.num_producers): print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}") self.buffer.extend( @@ -116,13 +117,26 @@ def loop(self) -> None: pbar.set_postfix({"loss": loss}) i += 1 assert len(self.buffer) == 0 + if self.lr_scheduler is not None: + self.lr_scheduler.step() + if (step + 1) % self.save_interval == 0: + if self.rank == 0: + print(f"Start saving policy model at step {step + 1}.") + save_path = os.path.join(self.save_dir, f"modeling-step-{step + 1}") 
+ self.booster.save_model(self.policy_model, save_path, shard=True) + if self.rank == 0: + print(f"Saved model checkpoint at step {step + 1} in folder {save_path}") + if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1: print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}") + torch.cuda.empty_cache() state_dict = self.state_dict() if self.rank == 0: ray_broadcast_tensor_dict( state_dict, src=self.num_producers, device=self.device, group_name="sync_model" ) + del state_dict + torch.cuda.empty_cache() @ray.remote diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py new file mode 100644 index 000000000000..e23254d1b07a --- /dev/null +++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py @@ -0,0 +1,529 @@ +import json +import os +from contextlib import nullcontext +from typing import Optional + +import ray +import torch +import torch.distributed as dist +import wandb +from coati.distributed.consumer import BaseConsumer +from coati.distributed.loss import PolicyLoss +from coati.distributed.reward.reward_fn import math_reward_fn +from coati.distributed.reward.verifiable_reward import VerifiableReward +from coati.distributed.utils import calc_action_log_probs +from coati.trainer.utils import all_reduce_mean +from transformers import AutoModelForCausalLM, AutoTokenizer + +from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR +from colossalai.nn.optimizer import HybridAdam + + +@ray.remote +class GRPOConsumer(BaseConsumer): + def __init__( + self, + num_producers, + num_episodes, + rank, + world_size, + master_addr, + master_port, + num_update_per_episode, + num_recv_per_update, + batch_size, + model_config, + plugin_config, + microbatch_size=1, + num_generations=8, + use_wandb=True, + generate_config=None, + training_config={}, + project_name=None, + ): + super().__init__( + num_producers, + num_episodes, + rank, + world_size, + master_addr, + master_port, + num_update_per_episode, + num_recv_per_update, + batch_size, + model_config, + plugin_config, + microbatch_size, + ) + path = model_config.pop("path") + self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config) + self.policy_model.train() + self.policy_model.gradient_checkpointing_enable() + self.optimizer = HybridAdam(self.policy_model.parameters(), lr=training_config.get("lr", 1e-6)) + self.accum_loss = torch.zeros(1, device=self.device) + self.accum_reward = torch.zeros(1, device=self.device) + self.accum_kl = torch.zeros(1, device=self.device) + self.accum_format_reward = torch.zeros(1, device=self.device) + self.accum_acc_reward = torch.zeros(1, device=self.device) + self.accum_advantages = torch.zeros(1, device=self.device) + self.accum_response_length = torch.zeros(1, device=self.device) + self.accum_count = 0 + self.generate_config = generate_config + self.training_config = training_config + self.project_name = project_name + + # Reference model is initialized from policy model. + self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config) + self.reference_model.eval() + + self.tokenizer = AutoTokenizer.from_pretrained(path) + self.pad_token_id = self.tokenizer.pad_token_id + self.num_generations = num_generations + self.filter_range = training_config.get("filter_range", None) + if self.filter_range is not None: + assert len(self.filter_range) == 2, "Filter range should have 2 values." + + # Initialize verifiable reward. 
+ response_format_tags = { + "think_start": {"text": "", "num_occur": 1}, + "think_end": {"text": "", "num_occur": 1}, + "answer_start": {"text": "", "num_occur": 1}, + "answer_end": {"text": "", "num_occur": 1}, + } + self.reward_model = VerifiableReward( + reward_fns=[math_reward_fn], tokenizer=self.tokenizer, tags=response_format_tags + ) + + self.policy_loss_fn = PolicyLoss() + self.global_step = 0 + self.use_wandb = use_wandb + + self.lr_scheduler = CosineAnnealingWarmupLR( + optimizer=self.optimizer, + total_steps=min(self.num_episodes, 4) * self.num_update_per_episode, + warmup_steps=0, + eta_min=0.1 * training_config.get("lr", 1e-6), + ) + + def setup(self): + super().setup() + if self.use_wandb and ( + (not self.plugin.pp_size > 1 and self.rank == 0) + or (self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0) + ): + # Initialize wandb. + name = f"{self.generate_config['backend']}_bs_{self.batch_size*self.dp_size}_temp_{self.generate_config['temperature']:.01f}_top_p_{self.generate_config['top_p']:.02f}" + self.wandb_run = wandb.init(project=self.project_name, sync_tensorboard=True, dir="./wandb", name=name) + + self.policy_model, self.optimizer, _, _, self.lr_scheduler = self.booster.boost( + self.policy_model, self.optimizer, lr_scheduler=self.lr_scheduler + ) + self.reference_model, *_ = self.booster.boost(self.reference_model) + self.plugin.logger.set_level("ERROR") + + def step(self, step_idx: int, **kwargs) -> Optional[float]: + """ + Step data from policy model: + [{ + "input_ids": torch.Tensor, + "attention_mask": torch.Tensor, + "action_mask": torch.Tensor, + "action_log_probs": torch.Tensor, + }, + ...] + Format: + [batch_size, num_of_generation, prompt_length + response_length] --- ............. + """ + + # Reshape to [batch_size x num_of_generation, prompt_length + response_length] + data = {k: v.view(-1, v.size(-1)) for k, v in kwargs.items()} + action_mask = data["action_mask"] + num_action = action_mask.shape[1] + old_action_log_probs = data["action_log_probs"] + response_length = torch.sum(action_mask, dim=1).to(torch.float32) + forward_batch_size = self.training_config.get("train_microbatch_size", data["input_ids"].size(0)) + + need_update = (step_idx + 1) % self.num_microbatches == 0 + # Gradient must be synchronized if zero2 is enabled. 
https://github.com/hpcaitech/ColossalAI/blob/44d4053fec005fe0b06b6bc755fdc962463145df/colossalai/booster/plugin/hybrid_parallel_plugin.py#L1500 + ctx = ( + nullcontext() + if need_update or self.booster.plugin.zero_stage == 2 + else self.booster.no_sync(self.policy_model, self.optimizer) + ) + with ctx: + reward_group = self.reward_model( + data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"] + ) + + reward = torch.tensor([value[0] for value in reward_group]).to(data["input_ids"].device) + format_reward = torch.tensor([value[1] for value in reward_group]).to(data["input_ids"].device) + acc_reward = torch.tensor([value[2] for value in reward_group]).to(data["input_ids"].device) + + # [batch_size, num_generations] + + group_reward = reward.view(-1, self.num_generations) + reward_mean = group_reward.mean(dim=1) + # [batch_size x num_generations] + reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0) + reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0) + # [batch_size x num_generations] + advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1) + # filter out the reward that is too high (all sample gets full score) or too low (all sample gets 0 score), + loss_mask = ( + None + if self.filter_range is None + else torch.logical_and( + reward_mean > self.filter_range[0], reward_mean < self.filter_range[1] + ).repeat_interleave(self.num_generations, dim=0) + ) + mean_kl, mean_loss = [], [] + + for forward_micro_batch_start in range(0, data["input_ids"].size(0), forward_batch_size): + input_ids_forward_micro_batch = data["input_ids"][ + forward_micro_batch_start : forward_micro_batch_start + forward_batch_size + ] + attention_mask_forward_micro_batch = data["attention_mask"][ + forward_micro_batch_start : forward_micro_batch_start + forward_batch_size + ] + action_mask_forward_micro_batch = action_mask[ + forward_micro_batch_start : forward_micro_batch_start + forward_batch_size + ] + loss_mask_forward_micro_batch = ( + loss_mask[forward_micro_batch_start : forward_micro_batch_start + forward_batch_size] + if loss_mask is not None + else None + ) + advantages_forward_micro_batch = advantages[ + forward_micro_batch_start : forward_micro_batch_start + forward_batch_size + ] + + if self.plugin.pp_size > 1: + # Support training with PP. + + with torch.no_grad(): + reference_model_outputs = self.booster.execute_pipeline( + iter( + [ + { + "input_ids": input_ids_forward_micro_batch, + "attention_mask": attention_mask_forward_micro_batch, + } + ] + ), + self.reference_model, + criterion=lambda outputs, inputs: torch.tensor( + [0.0], device=action_mask.device + ), # dummy criterion + optimizer=None, + return_loss=False, + return_outputs=True, + ) + + if self.booster.plugin.stage_manager.is_last_stage(): + reference_model_logits = reference_model_outputs["outputs"]["logits"] + reference_action_log_probs = calc_action_log_probs( + reference_model_logits / self.generate_config["temperature"], + input_ids_forward_micro_batch, + num_action, + self.plugin.shard_config, + ) + else: + # Dummy reference logprobs for data iterator. 
+ reference_action_log_probs = None + + data_policy_forward = { + "input_ids": input_ids_forward_micro_batch, + "attention_mask": attention_mask_forward_micro_batch, + "action_mask": action_mask_forward_micro_batch, + "reference_action_log_probs": reference_action_log_probs, + "advantages": advantages_forward_micro_batch, + "loss_mask": loss_mask_forward_micro_batch, + "source": self.rank, + } + + kl = [] + + def _criterion(outputs, inputs): + action_logits = outputs.logits + action_log_probs = calc_action_log_probs( + action_logits / self.generate_config["temperature"], + inputs["input_ids"], + num_action, + self.plugin.shard_config, + ) + per_token_kl = ( + torch.exp(inputs["reference_action_log_probs"] - action_log_probs) + - (inputs["reference_action_log_probs"] - action_log_probs) + - 1 + ) + appox_kl = torch.sum(per_token_kl * inputs["action_mask"], dim=-1) / torch.sum( + inputs["action_mask"], dim=-1 + ) + kl.append(appox_kl.mean()) + loss, skip_update, _ = self.policy_loss_fn( + action_log_probs, + action_log_probs, + inputs["advantages"].repeat_interleave(action_log_probs.size(-1), dim=-1), + per_token_kl, + inputs["action_mask"], + loss_mask=inputs["loss_mask"], + ) + return loss + + policy_model_outputs = self.booster.execute_pipeline( + iter([data_policy_forward]), + self.policy_model, + criterion=_criterion, + optimizer=self.optimizer, + return_loss=True, + return_outputs=True, + ) + loss = policy_model_outputs["loss"] + + if self.booster.plugin.stage_manager.is_last_stage(): + if len(kl) > 0: + kl = all_reduce_mean(torch.mean(torch.stack(kl)).to(loss.device), self.plugin).data + mean_kl.append(kl) + mean_loss.append(all_reduce_mean(loss, self.plugin).data) + else: + + policy_model_logits = self.policy_model( + input_ids=input_ids_forward_micro_batch, + attention_mask=attention_mask_forward_micro_batch, + ).logits + action_log_probs = calc_action_log_probs( + policy_model_logits / self.generate_config["temperature"], + input_ids_forward_micro_batch, + num_action, + self.plugin.shard_config, + ) + + with torch.no_grad(): + reference_model_logits = self.reference_model( + input_ids=input_ids_forward_micro_batch, + attention_mask=attention_mask_forward_micro_batch, + ).logits + reference_action_log_probs = calc_action_log_probs( + reference_model_logits / self.generate_config["temperature"], + input_ids_forward_micro_batch, + num_action, + self.plugin.shard_config, + ) + per_token_kl = ( + torch.exp(reference_action_log_probs - action_log_probs) + - (reference_action_log_probs - action_log_probs) + - 1 + ) + kl = torch.sum(per_token_kl * action_mask_forward_micro_batch, dim=-1) / torch.sum( + action_mask_forward_micro_batch, dim=-1 + ) + + loss, skip_update, _ = self.policy_loss_fn( + action_log_probs, + old_action_log_probs, + advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1), + per_token_kl, + action_mask_forward_micro_batch, + loss_mask=loss_mask_forward_micro_batch, + ) + + if not skip_update: + self.booster.backward(loss, self.optimizer) + loss = all_reduce_mean(loss, self.plugin) + kl = all_reduce_mean(kl.mean(), self.plugin) + # Calculate accumulate value. 
+ mean_kl.append(kl.data) + mean_loss.append(loss.data) + if not self.plugin.pp_size > 1 or ( + self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0 + ): + reward = all_reduce_mean(reward.mean(), self.plugin) + format_reward = all_reduce_mean(format_reward.mean(), self.plugin) + acc_reward = all_reduce_mean(acc_reward.mean(), self.plugin) + advantages = all_reduce_mean(advantages.mean(), self.plugin) + response_length = all_reduce_mean(response_length.mean(), self.plugin) + self.accum_loss.add_(sum(mean_loss) / len(mean_loss)) + self.accum_kl.add_(sum(mean_kl) / len(mean_kl)) + self.accum_reward.add_(reward.data) + self.accum_format_reward.add_(format_reward.data) + self.accum_acc_reward.add_(acc_reward.data) + self.accum_advantages.add_(advantages.data) + self.accum_response_length.add_(response_length.data) + self.accum_count += 1 + if need_update: + self.optimizer.step() + self.optimizer.zero_grad() + if not self.plugin.pp_size > 1 or ( + self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0 + ): + loss_scalar = self.accum_loss.item() + if (not self.plugin.pp_size > 1 and self.rank == 0) or ( + self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0 + ): + print( + "Loss:", + self.accum_loss.item() / self.accum_count, + "\nReward:", + self.accum_reward.item() / self.accum_count, + "\nFormat Reward:", + self.accum_format_reward.item() / self.accum_count, + "\nAcc Reward:", + self.accum_acc_reward.item() / self.accum_count, + "\nKL:", + self.accum_kl.item() / self.accum_count, + "\nAdvantages:", + self.accum_advantages.item() / self.accum_count, + "\nResponse Length:", + self.accum_response_length.item() / self.accum_count, + ) + self.wandb_run.log( + { + "metrics/reward": self.accum_reward.item() / self.accum_count, + "metrics/format_reward": self.accum_format_reward.item() / self.accum_count, + "metrics/acc_reward": self.accum_acc_reward.item() / self.accum_count, + "metrics/response_length": self.accum_response_length.item() / self.accum_count, + "train/loss": self.accum_loss.item() / self.accum_count, + "train/kl": self.accum_kl.item() / self.accum_count, + "train/advantages": self.accum_advantages.item() / self.accum_count, + "train/learning_rate": self.lr_scheduler.get_last_lr()[0], + "rollout/temperature": data["temperature"].cpu().numpy()[0][0], + } + ) + self.accum_loss.zero_() + self.accum_reward.zero_() + self.accum_acc_reward.zero_() + self.accum_format_reward.zero_() + self.accum_kl.zero_() + self.accum_advantages.zero_() + self.accum_response_length.zero_() + + self.accum_count = 0 + return loss_scalar + + def state_dict(self): + self.policy_model._force_wait_all_gather() + model = self.policy_model.unwrap() + state_dict = model.state_dict() + return state_dict + + +@ray.remote +class GRPOEvalConsumer(BaseConsumer): + def __init__( + self, + num_producers, + num_episodes, + rank, + world_size, + master_addr, + master_port, + num_update_per_episode, + num_recv_per_update, + batch_size, + model_config, + plugin_config, + microbatch_size=1, + num_generations=4, + use_wandb=True, + log_dir="./results", + ): + super().__init__( + num_producers, + num_episodes, + rank, + world_size, + master_addr, + master_port, + num_update_per_episode, + num_recv_per_update, + batch_size, + model_config, + plugin_config, + microbatch_size, + ) + path = model_config.pop("path") + self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config) + 
self.policy_model.train() + self.accum_reward = torch.zeros(1, device=self.device) + self.accum_format_reward = torch.zeros(1, device=self.device) + self.accum_acc_reward = torch.zeros(1, device=self.device) + self.accum_response_length = torch.zeros(1, device=self.device) + self.accum_count = torch.zeros(1, device=self.device) + + self.tokenizer = AutoTokenizer.from_pretrained(path) + self.pad_token_id = self.tokenizer.pad_token_id + self.num_generations = num_generations + + # Initialize verifiable reward. + response_format_tags = { + "think_start": {"text": "", "num_occur": 1}, + "think_end": {"text": "", "num_occur": 1}, + "answer_start": {"text": "", "num_occur": 1}, + "answer_end": {"text": "", "num_occur": 1}, + } + self.reward_model = VerifiableReward( + reward_fns=[math_reward_fn], tokenizer=self.tokenizer, tags=response_format_tags + ) + + self.log_dir = log_dir + if not os.path.exists(self.log_dir): + os.makedirs(self.log_dir) + else: + os.system(f"rm -rf {self.log_dir}/*") + + def setup(self): + super().setup() + self.policy_model, _, *_ = self.booster.boost(self.policy_model) + + def step(self, step_idx: int, **kwargs) -> Optional[float]: + rank = dist.get_rank() + data = {k: v.view(-1, v.size(-1)).cpu() for k, v in kwargs.items()} + kwargs["input_ids"].size(0) + reward_group = self.reward_model( + data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"] + ) + reward = [value[0].item() for value in reward_group] + format_reward = [value[1].item() for value in reward_group] + acc_reward = [value[2].item() for value in reward_group] + response_length = [(data["response_idx"][i][1] - data["response_idx"][i][0]).item() for i in range(len(reward))] + + response = self.tokenizer.batch_decode(data["input_ids"], skip_special_tokens=True) + with open(f"{self.log_dir}/eval_results_rank_{rank}.jsonl", "a", encoding="utf8") as f: + for i in range(len(response)): + f.write( + json.dumps( + { + "response": response[i], + "reward": reward[i], + "format_reward": format_reward[i], + "acc_reward": acc_reward[i], + "response_length": response_length[i], + }, + ensure_ascii=False, + ) + + "\n" + ) + + self.accum_reward += sum(reward) + self.accum_format_reward += sum(format_reward) + self.accum_acc_reward += sum(acc_reward) + self.accum_response_length += sum(response_length) + self.accum_count += len(reward) + + # print results + total_count = all_reduce_mean(self.accum_count, self.plugin) + mean_reward = all_reduce_mean(self.accum_reward, self.plugin) / total_count + mean_format_reward = all_reduce_mean(self.accum_format_reward, self.plugin) / total_count + mean_acc_reward = all_reduce_mean(self.accum_acc_reward, self.plugin) / total_count + mean_response_length = all_reduce_mean(self.accum_response_length, self.plugin) / total_count + if rank == 0: + print( + f"Step {step_idx}: Mean Reward: {mean_reward}, Mean Format Reward: {mean_format_reward}, Mean Acc Reward: {mean_acc_reward}, Mean Response Length: {mean_response_length}" + ) + return None + + def state_dict(self): + self.policy_model._force_wait_all_gather() + model = self.policy_model.unwrap() + state_dict = model.state_dict() + return state_dict diff --git a/applications/ColossalChat/coati/distributed/inference_backend.py b/applications/ColossalChat/coati/distributed/inference_backend.py index 95b7d1e80308..17c71c8a85b1 100644 --- a/applications/ColossalChat/coati/distributed/inference_backend.py +++ b/applications/ColossalChat/coati/distributed/inference_backend.py @@ -53,7 +53,13 @@ class 
TransformersInferenceBackend(BaseInferenceBackend): ) FORCE_GENERATE_CONFIG = dict(output_logits=True, return_dict_in_generate=True) - def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer): + def __init__( + self, + model_config: Dict[str, Any], + generate_config: Dict[str, Any], + tokenizer: PreTrainedTokenizer, + num_generations: int = 8, + ): model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG) model_config.update(self.FORCE_MODEL_CONFIG) path = model_config.pop("path") @@ -61,12 +67,22 @@ def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any] self.generate_config = generate_config.copy() self.generate_config.update(self.FORCE_GENERATE_CONFIG) self.tokenizer = tokenizer + self.num_generations = num_generations @torch.no_grad() def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]: + micro_batch_size = input_ids.size(0) input_ids = input_ids.to(get_current_device()) attention_mask = attention_mask.to(get_current_device()) - out = self.model.generate(input_ids, attention_mask=attention_mask, **kwargs, **self.generate_config) + gt_answer = None + if "gt_answer" in kwargs: + gt_answer = kwargs.pop("gt_answer") + if self.num_generations > 1: + input_ids = input_ids.repeat_interleave(self.num_generations, dim=0) + attention_mask = attention_mask.repeat_interleave(self.num_generations, dim=0) + out = self.model.generate( + input_ids, attention_mask=attention_mask, **kwargs, **self.generate_config, tokenizer=self.tokenizer + ) input_len = input_ids.shape[-1] new_token_ids = out.sequences[:, input_len:] # get log probs @@ -76,10 +92,13 @@ def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwar action_log_probs.append(log_probs_from_logits(logits[:, None, :], new_token_ids[:, i : i + 1])) action_log_probs = torch.cat(action_log_probs, dim=1) # get action mask + response_idx = torch.zeros((new_token_ids.size(0), 2), dtype=torch.int).to(get_current_device()) action_mask = torch.ones_like(new_token_ids, dtype=attention_mask.dtype) if self.tokenizer.eos_token_id is not None: for indices in torch.nonzero(new_token_ids == self.tokenizer.eos_token_id): action_mask[indices[0], indices[1] + 1 :] = 0 + response_idx[:, 0] = input_len + response_idx[:, 1] = input_len + action_mask.sum(dim=1) - 1 if attention_mask.size(0) != action_mask.size(0): assert action_mask.size(0) % attention_mask.size(0) == 0 @@ -91,7 +110,15 @@ def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwar "attention_mask": attention_mask, "action_log_probs": action_log_probs, "action_mask": action_mask, + "response_idx": response_idx, } + + data = {k: v.view(micro_batch_size, self.num_generations, v.size(-1)) for k, v in data.items()} + + if gt_answer is not None: + # repeat gt_answer for each prompt. 
+ data["gt_answer"] = gt_answer.repeat_interleave(self.num_generations, dim=1) + data = {k: v.to(get_current_device()) for k, v in data.items()} return data def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None: @@ -99,7 +126,13 @@ def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None: class SGLangInferenceBackend(BaseInferenceBackend): - def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer): + def __init__( + self, + model_config: Dict[str, Any], + generate_config: Dict[str, Any], + tokenizer: PreTrainedTokenizer, + num_generations: int = 8, + ): if sgl is None: raise ImportError("sglang is not installed") path = model_config.pop("path") @@ -156,29 +189,46 @@ class VLLMInferenceBackend(BaseInferenceBackend): logprobs=0, ) - def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer): + def __init__( + self, + model_config: Dict[str, Any], + generate_config: Dict[str, Any], + tokenizer: PreTrainedTokenizer, + num_generations: int = 8, + ): if LLM is None: raise ImportError("vllm is not installed") model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG) path = model_config.pop("path") - self.llm = LLM(path, **model_config) + self.llm = LLM(model=path, **model_config) generate_config = generate_config.copy() generate_config.update(self.FORCE_GENERATE_CONFIG) + generate_config.update({"n": num_generations}) self.generate_config = SamplingParams(**generate_config) self.tokenizer = tokenizer + self.num_generations = num_generations @torch.no_grad() def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]: + micro_batch_size = input_ids.size(0) + response_start_idx = input_ids.size(1) + first_non_padding_token_idx = (input_ids != self.tokenizer.pad_token_id).int().argmax(dim=1) + micro_batch_input_ids = input_ids.tolist() + micro_batch_input_ids_no_padding = [ + micro_batch_input_ids[i][first_non_padding_token_idx[i] :] for i in range(micro_batch_size) + ] outputs = self.llm.generate( - prompt_token_ids=input_ids.tolist(), sampling_params=self.generate_config, use_tqdm=False + prompt_token_ids=micro_batch_input_ids_no_padding, sampling_params=self.generate_config, use_tqdm=False ) out_tokens = [] out_len = [] log_probs = [] + response_idx = [] for out in outputs: for output_i in out.outputs: out_len.append(len(output_i.token_ids)) out_tokens.append(list(output_i.token_ids)) + response_idx.append((response_start_idx, response_start_idx + len(output_i.token_ids) - 1)) assert len(output_i.logprobs) == len(output_i.token_ids) p = [m[t].logprob for m, t in zip(output_i.logprobs, output_i.token_ids)] log_probs.append(p) @@ -195,6 +245,8 @@ def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwar out_tokens = torch.tensor(out_tokens) log_probs = torch.tensor(log_probs) + response_idx = torch.tensor(response_idx) + if attention_mask.size(0) != action_mask.size(0): assert action_mask.size(0) % attention_mask.size(0) == 0 num_returns = action_mask.size(0) // attention_mask.size(0) @@ -209,7 +261,14 @@ def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwar "attention_mask": attention_mask, "action_log_probs": log_probs, "action_mask": action_mask, + "response_idx": response_idx, } + + data = {k: v.view(micro_batch_size, self.num_generations, v.size(-1)) for k, v in data.items()} + + if "gt_answer" in kwargs: + # repeat gt_answer for each 
prompt. + data["gt_answer"] = kwargs["gt_answer"].repeat_interleave(self.num_generations, dim=1) data = {k: v.to(get_current_device()) for k, v in data.items()} return data diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py index 438c463002f0..699d90a8cdff 100644 --- a/applications/ColossalChat/coati/distributed/launch.py +++ b/applications/ColossalChat/coati/distributed/launch.py @@ -1,10 +1,14 @@ +import copy from typing import Any, Dict, Optional import ray from .consumer import SimpleConsumer +from .grpo_consumer import GRPOConsumer, GRPOEvalConsumer from .producer import SimpleProducer +ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "EvalGRPO": GRPOEvalConsumer} + def get_jsonl_size_fast(path: str) -> int: with open(path) as f: @@ -30,6 +34,7 @@ def launch_distributed( inference_microbatch_size: int, train_batch_size: int, train_microbatch_size: int, + train_minibatch_size: int, dataset_config: Dict[str, Any], dataloaders_config: Dict[str, Any], inference_model_config: Dict[str, Any], @@ -38,9 +43,18 @@ def launch_distributed( plugin_config: Dict[str, Any], tokenizer_config: Optional[Dict[str, Any]] = None, inference_backend: str = "transformers", + num_generations: int = 8, master_addr: str = "localhost", master_port: int = 29500, + core_algo: str = "GRPO", + project_name: Optional[str] = None, ): + + if core_algo not in ALGO_MAP: + raise NotImplementedError(f"{core_algo} is not supported yet.") + else: + core_consumer = ALGO_MAP.get(core_algo, SimpleConsumer) + train_dp_size = get_dp_size_fast(num_producers, plugin_config) assert (inference_batch_size * num_producers) % (train_batch_size * train_dp_size) == 0 @@ -65,10 +79,17 @@ def launch_distributed( tokenizer_config=tokenizer_config, microbatch_size=inference_microbatch_size, backend=inference_backend, + num_generations=num_generations, ) procs.append(producer) + generate_config_consumer = copy.deepcopy(generate_config) + generate_config_consumer.update( + dict( + backend=inference_backend, + ) + ) for i in range(num_consumer_procs): - consumer = SimpleConsumer.options(num_gpus=1).remote( + consumer = core_consumer.options(num_gpus=1).remote( num_producers=num_producers, num_episodes=num_episodes, rank=i, @@ -80,7 +101,15 @@ def launch_distributed( batch_size=train_batch_size, model_config=train_model_config, plugin_config=plugin_config, - microbatch_size=train_microbatch_size, + microbatch_size=train_minibatch_size, + generate_config=generate_config_consumer, + training_config={ + "filter_range": [0.05, 9.0], + "lr": 1e-6, + "train_microbatch_size": train_microbatch_size, + }, + num_generations=num_generations, + project_name=project_name, ) procs.append(consumer) ray.get([p.setup.remote() for p in procs]) diff --git a/applications/ColossalChat/coati/distributed/loss.py b/applications/ColossalChat/coati/distributed/loss.py new file mode 100644 index 000000000000..90ad09736281 --- /dev/null +++ b/applications/ColossalChat/coati/distributed/loss.py @@ -0,0 +1,45 @@ +from typing import Optional + +import torch +import torch.nn as nn +from coati.distributed.utils import masked_mean + + +class PolicyLoss(nn.Module): + """ + Policy Loss for PPO + """ + + def __init__(self, clip_eps: float = 0.2, skip_threshold: float = 20.0, beta: float = 0.01) -> None: + super().__init__() + self.clip_eps = clip_eps + self.skip_threshold = skip_threshold + self.beta = beta + + def forward( + self, + log_probs: torch.Tensor, + old_log_probs: torch.Tensor, + 
advantages: torch.Tensor, + per_token_kl: torch.Tensor, + action_mask: Optional[torch.Tensor] = None, + loss_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + skip = False + if action_mask is None: + ratio = (log_probs - log_probs.detach()).exp() + else: + ratio = ((log_probs - log_probs.detach()) * action_mask).exp() + + surr1 = ratio * advantages + surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages + loss = -torch.min(surr1, surr2) + self.beta * per_token_kl + + if action_mask is not None: + loss = masked_mean(loss, action_mask) + else: + loss = loss.mean(dim=1) + if loss_mask is not None: + loss = loss * loss_mask + loss = loss.mean() + return loss, skip, ratio.max() diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py index 3e4a5277ad2e..2c6a24a36711 100644 --- a/applications/ColossalChat/coati/distributed/producer.py +++ b/applications/ColossalChat/coati/distributed/producer.py @@ -100,7 +100,11 @@ def loop(self) -> None: if i >= num_valid_microbatches: break outputs = self.rollout(**batch) + print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}") + outputs["temperature"] = torch.tensor( + [self.model.generate_config.temperature] * outputs["input_ids"].size(0) + ).to(outputs["input_ids"].device) outputs = pre_send(outputs) ray_broadcast_tensor_dict( outputs, src=0, device=self.device, group_name=f"sync_data_{self.producer_idx}" @@ -113,10 +117,19 @@ def loop(self) -> None: print( f"[P{self.producer_idx}] Sync model episode {episode} step {(i + 1) // self.num_microbatches - 1}" ) + state_dict = ray_broadcast_tensor_dict( None, self.num_producers, device=self.device, group_name="sync_model" ) self.load_state_dict(state_dict) + del state_dict + torch.cuda.empty_cache() + # linear annealing for 1 episode, temperature from initial to 0.7 + if episode <= 0: + ratio = 1 - (len(self.dataloader) - i) / len(self.dataloader) + self.model.generate_config.temperature = (1 - ratio) * self.generate_config[ + "temperature" + ] + ratio * 0.7 @ray.remote @@ -135,6 +148,7 @@ def __init__( tokenizer_config=None, microbatch_size=1, backend="transformers", + num_generations: int = 8, ): super().__init__( producer_idx, @@ -150,11 +164,15 @@ def __init__( microbatch_size, backend, ) - self.model = self.backend_cls(model_config, generate_config, self.tokenizer) + self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations) @torch.no_grad() def rollout(self, input_ids, attention_mask, **kwargs): - return self.model.generate(input_ids, attention_mask, **kwargs) + rollouts = self.model.generate(input_ids, attention_mask, **kwargs) + if self.producer_idx == 1: + print("Rollout example:\n", self.tokenizer.decode(rollouts["input_ids"][0][0], skip_special_tokens=True)) + + return rollouts def load_state_dict(self, state_dict): self.model.load_state_dict(state_dict) diff --git a/applications/ColossalChat/coati/distributed/reward/reward_fn.py b/applications/ColossalChat/coati/distributed/reward/reward_fn.py new file mode 100644 index 000000000000..53bc15e25a8c --- /dev/null +++ b/applications/ColossalChat/coati/distributed/reward/reward_fn.py @@ -0,0 +1,61 @@ +import torch + +from .reward_utils import extract_solution, validate_response_structure + + +def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs): + format_score = 1.0 + acc_score = 9.0 + tokenizer = kwargs["tokenizer"] + reward = torch.tensor(0.0) + format_reward = torch.tensor(0.0) 
+ acc_reward = torch.tensor(0.0) + s, e = response_idx[0], response_idx[1] + if gt_answer is None: + return reward + + decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True) + gt_answer = tokenizer.decode(gt_answer.squeeze(0), skip_special_tokens=True) + final_answer, processed_str = extract_solution(decoded_final_answer) + + format_valid = validate_response_structure(processed_str, kwargs["tags"]) + + # Check format accuracy + if format_valid: + format_reward += format_score + reward += format_score + + # Check answer accuracy + if ( + final_answer is not None + and gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower() + ): + acc_reward += acc_score + reward += acc_score + + return torch.tensor([reward, format_reward, acc_reward]).to(input_ids.device) + + +def gsm8k_reward_fn(input_ids, **kwargs): + gt_answer = kwargs["gt_answer"] + tokenizer = kwargs["tokenizer"] + s, e = kwargs["response_start"], kwargs["response_end"] + reward = torch.tensor(0.0).to(input_ids.device) + if gt_answer is None: + return reward + decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True) + final_answer, processed_str = extract_solution(decoded_final_answer) + is_valid = True + try: + int(final_answer.strip()) + except Exception: + is_valid = False + + format_valid = validate_response_structure(processed_str, kwargs["tags"]) + if not is_valid or not format_valid: + return reward + else: + reward += 1.0 + if gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower(): + reward = reward + 9.0 + return reward diff --git a/applications/ColossalChat/coati/distributed/reward/reward_utils.py b/applications/ColossalChat/coati/distributed/reward/reward_utils.py new file mode 100644 index 000000000000..c1e73d4b9738 --- /dev/null +++ b/applications/ColossalChat/coati/distributed/reward/reward_utils.py @@ -0,0 +1,76 @@ +# Copyright Unakar +# Modified from https://github.com/Unakar/Logic-RL/blob/086373176ac198c97277ff50f4b6e7e1bfe669d3/verl/utils/reward_score/kk.py#L99 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from typing import Dict, Optional, Tuple + + +def validate_response_structure(processed_str: str, tags: Dict = None) -> bool: + """Performs comprehensive validation of response structure. 
+ + Args: + processed_str: Processed response string from the model + + Returns: + Boolean indicating whether all formatting requirements are met + """ + validation_passed = True + # Check required tags + if tags is None: + tags = { + "think_start": {"text": "<think>", "num_occur": 1}, + "think_end": {"text": "</think>", "num_occur": 1}, + "answer_start": {"text": "<answer>", "num_occur": 1}, + "answer_end": {"text": "</answer>", "num_occur": 1}, + } + positions = {} + for tag_name, tag_info in tags.items(): + tag_str = tag_info["text"] + expected_count = tag_info["num_occur"] + count = processed_str.count(tag_str) + positions[tag_name] = pos = processed_str.find(tag_str) + if count != expected_count: + validation_passed = False + # Verify tag order + if ( + positions["think_start"] > positions["think_end"] + or positions["think_end"] > positions["answer_start"] + or positions["answer_start"] > positions["answer_end"] + ): + validation_passed = False + if len(processed_str) - positions["answer_end"] != len(tags["answer_end"]["text"]): + validation_passed = False + return validation_passed + + + def extract_solution(solution_str: str) -> Tuple[Optional[str], str]: + """Extracts the final answer from the model's response string. + + Args: + solution_str: Raw response string from the language model + + Returns: + Tuple containing (extracted_answer, processed_string) + """ + + # Extract final answer using XML-style tags + answer_pattern = r"<answer>(.*?)</answer>" + matches = list(re.finditer(answer_pattern, solution_str, re.DOTALL)) + + if not matches: + return None, solution_str + + final_answer = matches[-1].group(1).strip() + return final_answer, solution_str diff --git a/applications/ColossalChat/coati/distributed/reward/verifiable_reward.py b/applications/ColossalChat/coati/distributed/reward/verifiable_reward.py new file mode 100644 index 000000000000..ba83f7787586 --- /dev/null +++ b/applications/ColossalChat/coati/distributed/reward/verifiable_reward.py @@ -0,0 +1,43 @@ +""" +Function-based reward verification module.
+""" + +from typing import Any, Dict, List + +import torch + + +class VerifiableReward: + def __init__(self, reward_fns: List[callable], **kwargs: List[Dict[str, Any]]): + self.reward_fns = reward_fns + self.kwargs = kwargs + + def __call__( + self, + input_ids: torch.LongTensor, + gt_answer: List[torch.Tensor] = None, + response_idx: List[torch.Tensor] = None, + ) -> torch.Tensor: + # Get batch size + bs = input_ids.size(0) + # Initialize reward + rewards = torch.zeros((bs, 3), device=input_ids.device) + + # Loop through reward functions + for reward_fn in self.reward_fns: + # Apply the reward function to the entire batch at once + reward_batch = torch.stack( + [ + reward_fn( + input_ids[i], + gt_answer=gt_answer[i], + response_idx=response_idx[i], + **self.kwargs, + ) + for i in range(bs) + ], + dim=0, + ) + + rewards += reward_batch + return rewards diff --git a/applications/ColossalChat/coati/distributed/utils.py b/applications/ColossalChat/coati/distributed/utils.py index 533a5ffb22da..919e4434faa6 100644 --- a/applications/ColossalChat/coati/distributed/utils.py +++ b/applications/ColossalChat/coati/distributed/utils.py @@ -2,6 +2,8 @@ import torch +from colossalai.shardformer.layer.loss import dist_log_prob + def unbind_batch(batch: Dict[str, torch.Tensor]) -> List[Dict[str, torch.Tensor]]: batches = [] @@ -64,3 +66,50 @@ def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.T log_probs = torch.log_softmax(logits, dim=-1) per_label_logps = log_probs.gather(dim=-1, index=labels.unsqueeze(-1)) return per_label_logps.squeeze(-1) + + +def calc_action_log_probs( + logits: torch.Tensor, + sequences: torch.LongTensor, + num_actions: int, + shard_config, + vocab_size: int = None, +) -> torch.Tensor: + """Calculate action log probs. + + Args: + logits (torch.Tensor): Output tensor of Actor.forward.logits. + sequences (torch.LongTensor): Input sequences. + num_actions (int): Number of actions. + shard_config + vocab_size + + + Returns: + torch.Tensor: Action log probs. + """ + # labels: torch.Tensor, # [B, S] or [B, S, Vocab_size] + # logits: torch.Tensor, # [B, S, Vocab_size] + log_probs = dist_log_prob(sequences, logits, shard_config, vocab_size, logits.dtype) + log_probs = log_probs.squeeze(-1) + return log_probs[:, -num_actions:] + + +def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor: + """ + Compute the masked mean of a tensor along a specified dimension. + + Args: + tensor (torch.Tensor): The input tensor. + mask (torch.Tensor): The mask tensor with the same shape as the input tensor. + dim (int, optional): The dimension along which to compute the mean. Default is 1. + + Returns: + torch.Tensor: The masked mean tensor. 
+ + """ + tensor = tensor * mask + tensor = tensor.sum(dim=dim) + mask_sum = mask.sum(dim=dim) + mean = tensor / (mask_sum + 1e-8) + return mean diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py index a6f82b3bebf9..6c43ccd1960f 100644 --- a/applications/ColossalChat/rl_example.py +++ b/applications/ColossalChat/rl_example.py @@ -10,54 +10,83 @@ parser.add_argument("-d", "--dataset", type=str, default="data.jsonl") parser.add_argument("-t", "--num-trainers", type=int, default=2) parser.add_argument("-i", "--num-inferencer", type=int, default=2) - parser.add_argument("-ibs", "--inference-batch-size", type=int, default=64) - parser.add_argument("-imbs", "--inference-microbatch-size", type=int, default=16) - parser.add_argument("-tbs", "--train-batch-size", type=int, default=16) - parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=2) - parser.add_argument("-b", "--backend", type=str, default="transformers") + parser.add_argument("-g", "--num-generations", type=int, default=8, help="Number of generations per prompt.") + parser.add_argument("-p", "--project", type=str, default="GRPO", help="Project name.") + parser.add_argument( + "-ibs", + "--inference-batch-size", + type=int, + default=64, + help="Number of prompts to generate per inference step. It should be divisible by tbs, and the weights on the inference backend will be synced every ibs/tbs training steps of the policy model.", + ) + parser.add_argument( + "-imbs", + "--inference-microbatch-size", + type=int, + default=8, + help="Effective batch size for the inference backend to run generation. Please select based on memory constraint.", + ) + parser.add_argument( + "-tbs", + "--train-batch-size", + type=int, + default=32, + help="Number of unique prompts to update policy model per step per dp group. Gradient is accumulated across tbs * dp_size unique prompts, equivalently tbs * g * dp_size samples", + ) + parser.add_argument( + "-tMbs", + "--train-minibatch-size", + type=int, + default=1, + help="Number of unique prompts in each training batch per dp group. The inference backend must generate tMbs * g * dp_size samples before forwarding. Satisfy tMbs * g >= tmbs", + ) + parser.add_argument( + "-tmbs", + "--train-microbatch-size", + type=int, + default=2, + help="Effective batch size per dp group for forwarding and backwarding. 
Please select based on the availiable memory.", + ) + parser.add_argument("-b", "--backend", type=str, default="transformers", choices=["transformers", "vllm"]) + parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple", "GRPO", "EvalGRPO"]) args = parser.parse_args() + assert args.train_minibatch_size > 0, "Train mini batch size must be greater than 0" + assert ( + args.train_minibatch_size * args.num_generations >= args.train_microbatch_size + and args.train_microbatch_size > 0 + ), "Train micro batch size must be greater than 0 less than train mini batch size * num generations" + ray.init(address="local", namespace="ray-example") inference_model_config = dict(path=args.model) - train_model_config = dict(path=args.model) - generate_config = dict( - top_k=50, - top_p=0.8, - ) + train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False) + generate_config = dict(top_k=50, top_p=0.75, temperature=0.9) if args.backend == "transformers": inference_model_config.update( dict( - attn_implementation="flash_attention_2", + use_flash_attention_2=True, torch_dtype=torch.bfloat16, ) ) - train_model_config.update( - dict( - attn_implementation="flash_attention_2", - torch_dtype=torch.bfloat16, - use_cache=False, - ) - ) generate_config.update( dict( - max_length=512, + max_length=1024 + 512, do_sample=True, max_new_tokens=None, early_stopping=False, + stop_strings=[""], ) ) elif args.backend == "vllm": - inference_model_config.update( - dict( - gpu_memory_utilization=0.6, - ) - ) + inference_model_config.update(dict(gpu_memory_utilization=0.7, enforce_eager=True, enable_chunked_prefill=True)) generate_config.update( dict( - max_tokens=256, + max_tokens=2048, ignore_eos=True, + include_stop_str_in_output=True, + stop=[""], ) ) else: @@ -77,18 +106,29 @@ num_producers=args.num_inferencer, num_proc_per_producer=1, num_consumer_procs=args.num_trainers, - num_episodes=1, + num_episodes=10, inference_batch_size=args.inference_batch_size, inference_microbatch_size=args.inference_microbatch_size, train_batch_size=args.train_batch_size, + train_minibatch_size=args.train_minibatch_size, train_microbatch_size=args.train_microbatch_size, - dataset_config={"path": args.dataset, "max_length": 256}, + dataset_config={"path": args.dataset, "max_length": 300}, dataloaders_config={}, inference_model_config=inference_model_config, generate_config=generate_config, + num_generations=args.num_generations, train_model_config=train_model_config, - plugin_config={}, + # plugin_config={}, # for zero + plugin_config={ + "pp_size": 2, + "tp_size": 2, + "microbatch_size": args.train_microbatch_size // 2, + "zero_stage": 0, + "max_norm": 1.0, + }, # for pp inference_backend=args.backend, master_addr="localhost", - master_port=29504, + master_port=29506, + core_algo=args.algo, + project_name=args.project, ) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 1684fd702e70..74349091b4d4 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -1411,8 +1411,10 @@ def execute_pipeline( ) # run with gradients accumulation - if model.require_grad_sync == False or ( - isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False + if ( + not torch.is_grad_enabled() + or model.require_grad_sync == False + or (isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False) ): return outputs 
diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py index 0bd1b60923e9..a1b80bf56b63 100644 --- a/colossalai/shardformer/layer/__init__.py +++ b/colossalai/shardformer/layer/__init__.py @@ -3,7 +3,7 @@ from .dropout import DropoutForParallelInput, DropoutForReplicatedInput from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D from .linear import Linear1D_Col, Linear1D_Row, LinearWithGradAccum, PaddingLMHead, VocabParallelLMHead1D -from .loss import cross_entropy_1d, dist_cross_entropy +from .loss import cross_entropy_1d, dist_cross_entropy, dist_log_prob, dist_log_prob_1d from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm from .parallel_module import ParallelModule from .qkv_fused_linear import ( @@ -28,6 +28,8 @@ "DropoutForReplicatedInput", "cross_entropy_1d", "dist_cross_entropy", + "dist_log_prob_1d", + "dist_log_prob", "BaseLayerNorm", "LayerNorm", "RMSNorm", diff --git a/colossalai/shardformer/layer/loss.py b/colossalai/shardformer/layer/loss.py index 0e2241af9fc9..a9bb76fc7d6b 100644 --- a/colossalai/shardformer/layer/loss.py +++ b/colossalai/shardformer/layer/loss.py @@ -3,13 +3,21 @@ from torch.autograd import Function from torch.distributed import ProcessGroup from torch.nn import CrossEntropyLoss +from torch.nn.functional import log_softmax from colossalai.shardformer.layer._operation import reduce_forward from colossalai.shardformer.shard import ShardConfig from .utils import is_share_sp_tp -__all__ = ["DistCrossEntropy", "cross_entropy_1d", "dist_cross_entropy"] +__all__ = [ + "DistCrossEntropy", + "cross_entropy_1d", + "dist_cross_entropy", + "DistLogProb", + "dist_log_prob_1d", + "dist_log_prob", +] _IGNORE_IDX = -100 @@ -137,6 +145,98 @@ def backward(ctx, grad_output): return grad_logits, None, None, None, None, None, None +class DistLogProb(Function): + r""" + Overwrite the forward and backward function to calculate the log prob before gather + + Args: + Function (:class:`torch.autograd.Function`): default + """ + + @staticmethod + def forward( + ctx, + vocab_logits: torch.Tensor, + target: torch.Tensor, + process_group: ProcessGroup, + vocab_size: int, + dtype=torch.float32, + ): + + ################## + # Step1:Find the global maximum value of logits + ################## + logits_max = torch.max(vocab_logits, dim=-1)[0] + handle = dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=process_group, async_op=True) + + ################## + # Step2:Find the local mask. local mask will be use to select log_probs value in Step 4. 
+ # For accleration, we overlap Step 2 and Step 3 + ################## + rank = dist.get_rank(group=process_group) + world_size = dist.get_world_size(group=process_group) + if vocab_size is None: + partition_vocab_size = vocab_logits.size()[-1] + global_vocab_size = partition_vocab_size * world_size + else: + global_vocab_size = vocab_size + partition_vocab_size = global_vocab_size // world_size + # down and up threshold for local logits + delta = (global_vocab_size + world_size - 1) // world_size + down_threshold = rank * delta + up_threshold = down_threshold + delta + if up_threshold > global_vocab_size: + up_threshold = global_vocab_size + # mask + mask = (target < down_threshold) | (target >= up_threshold) + masked_target = target.clone() - down_threshold + masked_target[mask] = 0 + masked_target_1d = masked_target.view(-1).contiguous() + handle.wait() + + ################## + # Step3:Calculate global summation exp logits + ################## + vocab_logits = vocab_logits - logits_max.unsqueeze(dim=-1) + exp_logits = torch.exp(vocab_logits) + sum_exp_logits = torch.sum(exp_logits, dim=-1, dtype=torch.float32) # local summation exp logits + dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM, group=process_group) + + ################## + # Step4:Calculate local prob. We first cal log_softmax, then select log probs via local mask + ################## + log_probs = vocab_logits - torch.log(sum_exp_logits.unsqueeze(dim=-1)) # cal log_softmax + log_probs = log_probs.gather(dim=-1, index=masked_target.unsqueeze(-1)) + log_probs[mask.unsqueeze(-1)] = 0 # set masked val to zero + dist.all_reduce(log_probs, op=dist.ReduceOp.SUM, group=process_group) + + ctx.save_for_backward(exp_logits, mask, masked_target_1d, sum_exp_logits) + ctx.dtype = dtype + return log_probs + + @staticmethod + def backward(ctx, grad_output): + exp_logits, mask, masked_target_1d, sum_exp_logits = ctx.saved_tensors + ################## + # Step1:Find the global sofmax value + ################## + softmax_logits = exp_logits / sum_exp_logits.unsqueeze(dim=-1) + + ################## + # Step2:Update softmax value based on local target index + ################## + partion_vocab_size = softmax_logits.shape[-1] + softmax_logits_2d = softmax_logits.view(-1, partion_vocab_size) + update = 1.0 - mask.view(-1).float().to(ctx.dtype) + softmax_logits_2d[torch.arange(0, softmax_logits_2d.shape[0]), masked_target_1d] -= update + + ################## + # Step3:Calculate grad_output, which is the gradient of the loss function with respect to the output of logsoftmax + ################## + grad_logits = -softmax_logits.mul_(grad_output) + return grad_logits, None, None, None, None, None, None + + def cross_entropy_1d( vocab_logits: torch.Tensor, labels: torch.Tensor, @@ -149,6 +249,16 @@ def cross_entropy_1d( return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size, dtype, mode) +def dist_log_prob_1d( + vocab_logits: torch.Tensor, + labels: torch.Tensor, + process_group: ProcessGroup = None, + vocab_size: int = None, + dtype: torch.dtype = None, +) -> torch.Tensor: + return DistLogProb.apply(vocab_logits, labels, process_group, vocab_size, dtype) + + def dist_cross_entropy( labels: torch.Tensor, # [B, S] or [B, S, Vocab_size] logits: torch.Tensor, # [B, S, Vocab_size] @@ -243,3 +353,41 @@ def dist_cross_entropy( loss, num_nonzero = loss[0], loss[1].detach() loss = (loss / num_nonzero).squeeze() return loss + + +def dist_log_prob( + labels: torch.Tensor, # [B, S] or [B, S, Vocab_size] + logits: 
torch.Tensor, # [B, S, Vocab_size] + shard_config: ShardConfig, + vocab_size: int, + dtype: torch.dtype, + seq_dim: int = 1, +) -> torch.Tensor: + """ + Helper to compute log prob for most shardformer models supporting PP, TP. + """ + # Split labels if not gather output + parallel_output = shard_config.parallel_output + is_tp = shard_config.enable_tensor_parallelism + + # TODO:support sp + labels = labels[..., 1:] + logits = logits[..., :-1, :] + labels = labels.contiguous() + logits = logits.contiguous() + assert labels.shape == logits.shape[:-1], f"label shape {labels.shape} does not match logit shape {logits.shape}" + + # Flatten the tokens + if is_tp and parallel_output: + log_prob = dist_log_prob_1d( + logits, + labels, + process_group=shard_config.tensor_parallel_process_group, + vocab_size=vocab_size, + dtype=dtype, + ) + else: + log_prob = log_softmax(logits, dim=-1) + log_prob = log_prob.gather(dim=-1, index=labels.unsqueeze(-1)) + + return log_prob diff --git a/colossalai/shardformer/modeling/qwen2.py b/colossalai/shardformer/modeling/qwen2.py index 569fc4a459c5..27571309e453 100644 --- a/colossalai/shardformer/modeling/qwen2.py +++ b/colossalai/shardformer/modeling/qwen2.py @@ -284,6 +284,7 @@ def qwen2_for_causal_lm_forward( hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, shard_config: ShardConfig = None, + **kwargs, ): r""" Args: @@ -832,7 +833,6 @@ def forward( loss = None if labels is not None: loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, logits.dtype) - if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output diff --git a/colossalai/shardformer/policies/qwen2.py b/colossalai/shardformer/policies/qwen2.py index 84d2b2fdbd99..0adcdfdbd553 100644 --- a/colossalai/shardformer/policies/qwen2.py +++ b/colossalai/shardformer/policies/qwen2.py @@ -13,6 +13,7 @@ PaddingEmbedding, RMSNorm, VocabParallelEmbedding1D, + VocabParallelLMHead1D, ) from ..modeling.qwen2 import ( @@ -429,8 +430,12 @@ def module_policy(self): sub_module_replacement=[ SubModuleReplacementDescription( suffix="lm_head", - target_module=Linear1D_Col, - kwargs=dict(fp8_communication=self.shard_config.fp8_communication, use_zbv=use_zbv), + target_module=VocabParallelLMHead1D, + kwargs=dict( + gather_output=not self.shard_config.parallel_output, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), ) ], method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)}, @@ -446,7 +451,16 @@ def module_policy(self): suffix="lm_head", target_module=LinearWithGradAccum, kwargs=dict(fp8_communication=self.shard_config.fp8_communication, use_zbv=use_zbv), - ) + ), + SubModuleReplacementDescription( + suffix="lm_head", + target_module=VocabParallelLMHead1D, + kwargs={ + "gather_output": not self.shard_config.parallel_output, + "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, + "fp8_communication": self.shard_config.fp8_communication, + }, + ), ], method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)}, ) diff --git a/tests/test_shardformer/test_layer/test_dist_log_prob.py b/tests/test_shardformer/test_layer/test_dist_log_prob.py new file mode 100644 index 000000000000..05a6a5d4766f --- /dev/null +++ b/tests/test_shardformer/test_layer/test_dist_log_prob.py @@ -0,0 +1,52 @@ +import pytest +import torch +from coati.distributed.utils import log_probs_from_logits + +import 
colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer.layer import dist_log_prob_1d +from colossalai.testing import rerun_if_address_is_in_use, spawn + +CONFIG = dict( + parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode="1d")), +) + + +def check_dist_log_prob(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost", backend="nccl") + + # prepare data + pred = torch.randn(2, 4, 8, requires_grad=True).cuda() + labels = torch.randint(8, (2, 4)).cuda() + + logprob = log_probs_from_logits(pred, labels) + + pred.retain_grad() + logprob.mean().backward() + + dist_pred = pred.clone().chunk(world_size, -1)[rank].detach() + dist_pred.requires_grad = True + dist_logprob = dist_log_prob_1d(dist_pred, labels) + + dist_pred.retain_grad() + dist_logprob.squeeze(-1).mean().backward() + + assert torch.allclose( + logprob, dist_logprob.squeeze(-1), atol=1e-5 + ), f"dist cross entropy logprob is not equal to orgin logprob\n{logprob}\n{dist_logprob.squeeze(-1)}" + + pred_grad_partial = pred.grad.clone().chunk(world_size, -1)[rank].detach() + assert torch.allclose( + pred_grad_partial, dist_pred.grad + ), f"dist grad is not equal to orgin grad\n{pred.grad}\n{dist_pred.grad}" + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_dist_log_prob(): + spawn(check_dist_log_prob, 2) + + +if __name__ == "__main__": + test_dist_log_prob()