[feat] GRPO with distributed implementation #6230
Merged (37 commits, Apr 21, 2025)
Commits
8e6c9a4
add reward related function
TongLi3701 Feb 23, 2025
ffd3878
add simple grpo
TongLi3701 Feb 23, 2025
f736d74
update grpo
TongLi3701 Feb 25, 2025
070907d
polish
TongLi3701 Feb 28, 2025
c15225b
modify data loader
Mar 6, 2025
b96d690
grpo consumer
Mar 6, 2025
678f5a9
update loss
Mar 6, 2025
d03cdea
update reward fn
Mar 6, 2025
7f2ceac
update example
Mar 6, 2025
812f4b7
update loader
Mar 6, 2025
0f566cc
add algo selection
Mar 6, 2025
ab5b6d8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 6, 2025
0cc0c84
add save
Mar 6, 2025
0590f10
update select algo
Mar 6, 2025
22cc155
Merge branch 'grpo-latest' of github.com:hpcaitech/ColossalAI into gr…
Mar 6, 2025
eb6337f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 6, 2025
9d9d516
update grpo
Mar 10, 2025
754b16d
update reward fn
Mar 10, 2025
71a0181
update reward
Mar 10, 2025
abca66e
fix reward score
Mar 11, 2025
47d6493
add response length
Mar 11, 2025
704866a
detach
Mar 11, 2025
131eece
fix tp bug
Mar 13, 2025
afddfde
fix consumer
Mar 13, 2025
4702d57
convert to 8 generation
Mar 13, 2025
45ac6c6
print results
Mar 13, 2025
57b49da
setup update
Mar 13, 2025
bc0171d
fix transformers backend
YeAnbang Mar 14, 2025
7795d4c
[Feature] Support Distributed LogProb for GRPO Training (#6247)
duanjunwen Mar 18, 2025
7ee4452
fix vllm
YeAnbang Mar 19, 2025
0472f44
fix logprob, add filtering, temperature annealing, lr descent
YeAnbang Mar 21, 2025
d8eaf0d
simplify vllm preprocessing input ids
YeAnbang Mar 21, 2025
2aa7385
update logging
YeAnbang Mar 21, 2025
489f215
Merge pull request #6250 from hpcaitech/grpo-latest-dev
YeAnbang Mar 21, 2025
5015300
[feat] add microbatch forwarding (#6251)
YeAnbang Mar 28, 2025
ed43a4b
[Distributed RLHF] Integration of PP (#6257)
YeAnbang Apr 9, 2025
9467c10
[hot-fix] Fix memory leakage bug, support TP+PP (#6258)
YeAnbang Apr 10, 2025
Changes from all commits
2 changes: 2 additions & 0 deletions .gitignore
@@ -163,3 +163,5 @@ coverage.xml
# log, test files - ColossalChat
applications/ColossalChat/logs
applications/ColossalChat/tests/logs
applications/ColossalChat/wandb
applications/ColossalChat/model
35 changes: 26 additions & 9 deletions applications/ColossalChat/coati/dataset/loader.py
@@ -356,10 +356,24 @@ def apply_chat_template_and_mask(
truncation: bool = True,
ignore_idx: int = -100,
) -> Dict[str, torch.Tensor]:

system_prompt = "You are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and<answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>. Now the user asks you to solve a math problem that involves reasoning. After thinking, when you finally reach a conclusion, clearly output the final answer without explanation within the <answer> </answer> tags, i.e., <answer> 123 </answer>.\n\n"

system_element = {
"role": "system",
"content": system_prompt,
}

# Format for RL.
gt_answer = None
if "messages" in chat and "gt_answer" in chat:
gt_answer = chat["gt_answer"]
chat = [chat["messages"]]

tokens = []
assistant_mask = []
for i, msg in enumerate(chat):
msg_tokens = tokenizer.apply_chat_template([msg], tokenize=True)
msg_tokens = tokenizer.apply_chat_template([system_element, msg], tokenize=True, add_generation_prompt=True)
# remove unexpected bos token
if i > 0 and msg_tokens[0] == tokenizer.bos_token_id:
msg_tokens = msg_tokens[1:]
@@ -372,14 +386,10 @@ def apply_chat_template_and_mask(
if max_length is not None:
if padding and len(tokens) < max_length:
to_pad = max_length - len(tokens)
if tokenizer.padding_side == "right":
tokens.extend([tokenizer.pad_token_id] * to_pad)
assistant_mask.extend([False] * to_pad)
attention_mask.extend([0] * to_pad)
else:
tokens = [tokenizer.pad_token_id] * to_pad + tokens
assistant_mask = [False] * to_pad + assistant_mask
attention_mask = [0] * to_pad + attention_mask
# Left padding for generation.
tokens = [tokenizer.pad_token_id] * to_pad + tokens
assistant_mask = [False] * to_pad + assistant_mask
attention_mask = [0] * to_pad + attention_mask
if truncation and len(tokens) > max_length:
tokens = tokens[:max_length]
assistant_mask = assistant_mask[:max_length]
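Note: the padding branch now always left-pads, because these sequences are prompts for autoregressive generation and the prompt must sit flush against the tokens the model generates next. A small standalone illustration of the resulting layout (pad id 0 is only an example value):

pad_token_id, to_pad = 0, 3
tokens = [11, 22, 33]                            # prompt token ids
tokens = [pad_token_id] * to_pad + tokens        # [0, 0, 0, 11, 22, 33]
assistant_mask = [False] * to_pad + [False] * 3  # padding never contributes to labels
attention_mask = [0] * to_pad + [1] * 3          # [0, 0, 0, 1, 1, 1]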
@@ -389,6 +399,13 @@ def apply_chat_template_and_mask(
labels = input_ids.clone()
labels[~torch.tensor(assistant_mask, dtype=torch.bool)] = ignore_idx

if gt_answer is not None:
gt_answer = tokenizer.encode(
gt_answer, padding="max_length", truncation=True, max_length=128, return_tensors="pt"
)
gt_answer = gt_answer.squeeze(1)
return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels, "gt_answer": gt_answer}

return {
"input_ids": input_ids,
"attention_mask": attention_mask,
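Taken together, the loader now prefixes a reasoning system prompt, left-pads for generation, and returns the tokenized ground-truth answer alongside the prompt. A minimal usage sketch (assuming the ColossalChat package is importable as coati, and using an arbitrary chat-template tokenizer as a stand-in):

from transformers import AutoTokenizer

from coati.dataset.loader import apply_chat_template_and_mask

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B-Instruct")  # example model only
sample = {
    "messages": {"role": "user", "content": "What is 13 * 17?"},  # a single user turn
    "gt_answer": "221",
}
out = apply_chat_template_and_mask(tokenizer=tokenizer, chat=sample, max_length=512)
# Left-padded, system-prompt-prefixed prompt plus the encoded reference answer.
print(out["input_ids"].shape, out["gt_answer"].shape)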
30 changes: 22 additions & 8 deletions applications/ColossalChat/coati/distributed/consumer.py
@@ -1,3 +1,4 @@
import os
from contextlib import nullcontext
from typing import Any, Dict, Optional

@@ -33,6 +34,8 @@ def __init__(
model_config: Dict[str, Any],
plugin_config: Dict[str, Any],
microbatch_size: int = 1,
save_interval: int = 100,
save_dir: str = "./model",
):
self.num_producers = num_producers
self.num_episodes = num_episodes
@@ -44,14 +47,16 @@ def __init__(
self.num_recv_per_update = num_recv_per_update
self.batch_size = batch_size
self.microbatch_size = microbatch_size
self.save_interval = save_interval
self.save_dir = save_dir
assert batch_size % microbatch_size == 0, "batch_size should be divisible by microbatch_size"
self.num_microbatches = batch_size // microbatch_size

self.model_config = model_config
self.plugin_config = plugin_config
assert self.plugin_config.get("pp_size", 1) == 1, "pp_size > 1 is not supported now"

self.device = get_current_device()
self.lr_scheduler = None

def setup(self) -> None:
for i in range(self.num_producers):
@@ -60,18 +65,15 @@ def setup(self) -> None:
cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name="sync_model")
launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0)

plugin_config = dict(
tp_size=1,
pp_size=1,
precision="bf16",
zero_stage=1,
)
plugin_config = dict(tp_size=1, pp_size=1, precision="bf16", zero_stage=2)
if self.plugin_config.get("pp_size", 1) > 1 and "num_microbatches" not in self.plugin_config:
plugin_config["microbatch_size"] = self.microbatch_size
plugin_config.update(self.plugin_config)
self.plugin = HybridParallelPlugin(**plugin_config)
self.booster = Booster(plugin=self.plugin)
self.dp_rank = dist.get_rank(self.plugin.dp_group)
self.tp_rank = dist.get_rank(self.plugin.tp_group)

self.dp_size = dist.get_world_size(self.plugin.dp_group)

self.buffer = []
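The setup path now defaults to ZeRO stage 2 and merges the caller's plugin_config on top, falling back to the consumer's microbatch size when pipeline parallelism is requested without an explicit schedule. A plain-dict illustration of the precedence (the caller values are hypothetical):

defaults = dict(tp_size=1, pp_size=1, precision="bf16", zero_stage=2)
caller_cfg = {"tp_size": 2, "pp_size": 2, "zero_stage": 1}  # hypothetical plugin_config
if caller_cfg.get("pp_size", 1) > 1 and "num_microbatches" not in caller_cfg:
    defaults["microbatch_size"] = 1  # stands in for self.microbatch_size
defaults.update(caller_cfg)  # caller-provided values win over the defaults
print(defaults)
# {'tp_size': 2, 'pp_size': 2, 'precision': 'bf16', 'zero_stage': 1, 'microbatch_size': 1}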
@@ -94,7 +96,6 @@ def loop(self) -> None:
i = 0
for _ in range(self.num_recv_per_update):
# receive data from producers

for r in range(self.num_producers):
print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
self.buffer.extend(
@@ -116,13 +117,26 @@ def loop(self) -> None:
pbar.set_postfix({"loss": loss})
i += 1
assert len(self.buffer) == 0
if self.lr_scheduler is not None:
self.lr_scheduler.step()
if (step + 1) % self.save_interval == 0:
if self.rank == 0:
print(f"Start saving policy model at step {step + 1}.")
save_path = os.path.join(self.save_dir, f"modeling-step-{step + 1}")
self.booster.save_model(self.policy_model, save_path, shard=True)
if self.rank == 0:
print(f"Saved model checkpoint at step {step + 1} in folder {save_path}")

if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1:
print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
torch.cuda.empty_cache()
state_dict = self.state_dict()
if self.rank == 0:
ray_broadcast_tensor_dict(
state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
)
del state_dict
torch.cuda.empty_cache()


@ray.remote
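The diff is truncated at the @ray.remote decorator above; the GRPO loss itself is implemented elsewhere in this PR. As background for what the reward commits ("convert to 8 generation", "fix reward score") feed into, here is a generic sketch of the group-relative advantage that GRPO normalizes within each prompt's group of sampled responses (standard formulation, not the exact ColossalChat code):

import torch

def group_relative_advantages(rewards: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # rewards: (num_prompts, group_size), e.g. 8 sampled responses per prompt
    mean = rewards.mean(dim=1, keepdim=True)
    std = rewards.std(dim=1, keepdim=True)
    return (rewards - mean) / (std + eps)

rewards = torch.tensor([[1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0]])  # one prompt, group of 8
print(group_relative_advantages(rewards))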