[feat] Support DAPO #6263

Open · wants to merge 19 commits into base: grpo-latest
4 changes: 2 additions & 2 deletions applications/ColossalChat/coati/distributed/consumer.py
@@ -112,7 +112,7 @@ def loop(self) -> None:
self.buffer = self.buffer[self.dp_size * self.microbatch_size :]
batch = bind_batch(batches)
batch = post_recv(batch)
loss = self.step(i, **batch)
loss = self.step(i, pbar, **batch)
if loss is not None:
pbar.set_postfix({"loss": loss})
i += 1
@@ -181,7 +181,7 @@ def setup(self):
super().setup()
self.model, self.optimizer, *_ = self.booster.boost(self.model, self.optimizer)

def step(self, step_idx: int, **kwargs) -> Optional[float]:
def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
labels = kwargs["input_ids"].clone()
labels[kwargs["attention_mask"] == 0] = -100
kwargs["labels"] = labels
375 changes: 220 additions & 155 deletions applications/ColossalChat/coati/distributed/grpo_consumer.py

Large diffs are not rendered by default.

7 changes: 2 additions & 5 deletions applications/ColossalChat/coati/distributed/launch.py
@@ -40,6 +40,7 @@ def launch_distributed(
inference_model_config: Dict[str, Any],
generate_config: Dict[str, Any],
train_model_config: Dict[str, Any],
grpo_config: Dict[str, Any],
plugin_config: Dict[str, Any],
tokenizer_config: Optional[Dict[str, Any]] = None,
inference_backend: str = "transformers",
@@ -103,11 +104,7 @@ def launch_distributed(
plugin_config=plugin_config,
microbatch_size=train_minibatch_size,
generate_config=generate_config_consumer,
training_config={
"filter_range": [0.05, 9.0],
"lr": 1e-6,
"train_microbatch_size": train_microbatch_size,
},
grpo_config=grpo_config,
num_generations=num_generations,
project_name=project_name,
)
52 changes: 39 additions & 13 deletions applications/ColossalChat/coati/distributed/loss.py
@@ -2,19 +2,27 @@

import torch
import torch.nn as nn
from coati.distributed.utils import masked_mean
from coati.distributed.utils import masked_mean, masked_sum


class PolicyLoss(nn.Module):
"""
Policy Loss for PPO
"""

def __init__(self, clip_eps: float = 0.2, skip_threshold: float = 20.0, beta: float = 0.01) -> None:
def __init__(
self,
clip_eps_low: float = 0.2,
clip_eps_high: float = 0.2,
beta: float = 0.01,
loss_variation: str = "sample_level",
) -> None:
super().__init__()
self.clip_eps = clip_eps
self.skip_threshold = skip_threshold
self.clip_eps_low = clip_eps_low
self.clip_eps_high = clip_eps_high
self.beta = beta
self.loss_variation = loss_variation
assert loss_variation in ["sample_level", "token_level"], f"Unsupported loss variation: {loss_variation}"

def forward(
self,
@@ -25,21 +33,39 @@ def forward(
action_mask: Optional[torch.Tensor] = None,
loss_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
skip = False
if action_mask is None:
ratio = (log_probs - log_probs.detach()).exp()
else:
ratio = ((log_probs - log_probs.detach()) * action_mask).exp()

surr1 = ratio * advantages
surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages
surr2 = ratio.clamp(1 - self.clip_eps_low, 1 + self.clip_eps_high) * advantages
if self.beta == 0:
# skip kl term if kl coefficient is zero
per_token_kl = 0.0
loss = -torch.min(surr1, surr2) + self.beta * per_token_kl

if action_mask is not None:
loss = masked_mean(loss, action_mask)
if self.loss_variation == "sample_level":
if action_mask is not None:
loss = masked_mean(loss, action_mask)
else:
loss = loss.mean(dim=1)
if loss_mask is not None:
loss = loss * loss_mask
loss = loss.mean()
elif self.loss_variation == "token_level":
total_tokens = 0
if action_mask is not None:
loss = masked_sum(loss, action_mask)
total_tokens = action_mask.sum(dim=1)
else:
loss = loss.sum(dim=1)
total_tokens = torch.ones_like(loss, device=loss.device) * log_probs.size(1)
if loss_mask is not None:
loss = loss * loss_mask
total_tokens = total_tokens * loss_mask
loss = loss.sum() / (total_tokens.sum() + 1e-8)
else:
loss = loss.mean(dim=1)
if loss_mask is not None:
loss = loss * loss_mask
loss = loss.mean()
return loss, skip, ratio.max()
raise ValueError(f"Unsupported loss variation: {self.loss_variation}")

return loss, ratio.max()
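
Reviewer note: a standalone sketch of what the reworked loss computes, for sanity-checking the math outside the trainer. The decoupled clip bounds and token-level normalization follow the diff above; the tensor shapes, toy values, and the use of explicit `old_log_probs` in the ratio are assumptions for illustration only.

```python
import torch

def dapo_policy_loss(log_probs, old_log_probs, advantages, action_mask,
                     clip_eps_low=0.2, clip_eps_high=0.28):
    # Clip-Higher: the upper clip bound may exceed the lower one, leaving
    # more room for the ratio to grow on positive advantages.
    ratio = (log_probs - old_log_probs).exp()
    surr1 = ratio * advantages
    surr2 = ratio.clamp(1 - clip_eps_low, 1 + clip_eps_high) * advantages
    per_token_loss = -torch.min(surr1, surr2)
    # Token-level variation: sum over all valid tokens in the batch and
    # divide by the total token count, so long responses keep full weight.
    return (per_token_loss * action_mask).sum() / (action_mask.sum() + 1e-8)

# Toy usage: 2 responses, 4 tokens each, second response half padded.
lp = torch.randn(2, 4)
mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]], dtype=torch.float)
print(dapo_policy_loss(lp, lp.detach(), torch.randn(2, 4), mask))
```

Under `sample_level`, each response is first averaged over its own tokens, so every sample gets equal weight regardless of length; `token_level` instead gives every token equal weight, which is the DAPO choice.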
4 changes: 2 additions & 2 deletions applications/ColossalChat/coati/distributed/producer.py
@@ -124,12 +124,12 @@ def loop(self) -> None:
self.load_state_dict(state_dict)
del state_dict
torch.cuda.empty_cache()
# linear annealing for 1 episode, temperature from initial to 0.7
# linear annealing for 1 episode, temperature from initial to 0.9
if episode <= 0:
ratio = 1 - (len(self.dataloader) - i) / len(self.dataloader)
self.model.generate_config.temperature = (1 - ratio) * self.generate_config[
"temperature"
] + ratio * 0.7
] + ratio * 0.9


@ray.remote
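
Reviewer note: a quick check of the updated annealing schedule, assuming an initial temperature of 1.0 and a dataloader of 100 batches (both values invented for the example).

```python
init_temp, num_batches = 1.0, 100  # assumed for illustration

for i in (0, 49, 99):
    ratio = 1 - (num_batches - i) / num_batches  # 0.0 at the first batch
    temperature = (1 - ratio) * init_temp + ratio * 0.9
    print(i, round(temperature, 3))
# 0  -> 1.0    start of the episode: initial temperature
# 49 -> 0.951  midway
# 99 -> 0.901  end of the episode: approaching the new 0.9 target
```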
31 changes: 22 additions & 9 deletions applications/ColossalChat/coati/distributed/reward/reward_fn.py
@@ -4,13 +4,23 @@


def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
format_score = 1.0
acc_score = 9.0
tokenizer = kwargs["tokenizer"]
soft_over_length_punishment = kwargs["soft_over_length_punishment"]
format_score = 0.0
acc_score = 10.0
reward = torch.tensor(0.0)
format_reward = torch.tensor(0.0)
acc_reward = torch.tensor(0.0)
format_acc = torch.tensor(0.0)
ans_acc = torch.tensor(0.0)
s, e = response_idx[0], response_idx[1]

length_reward = 0.0
if soft_over_length_punishment:
max_length = kwargs.get("max_length", 1024 * 4)
cache_length = kwargs.get("cache_length", 512)
res_length = e.item() - s.item() + 1
if max_length - cache_length < res_length < max_length:
length_reward = ((max_length - cache_length) - res_length) / cache_length * acc_score

if gt_answer is None:
return reward

@@ -22,18 +32,21 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):

# Check format accuracy
if format_valid:
format_reward += format_score
format_acc += 1
reward += format_score

# Check answer accuracy
# Check answer accuracy; the answer is considered correct only if it matches the ground truth and the format is valid
if (
final_answer is not None
format_valid
and final_answer is not None
and gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower()
):
acc_reward += acc_score
ans_acc += 1
reward += acc_score

return torch.tensor([reward, format_reward, acc_reward]).to(input_ids.device)
reward = reward + length_reward

return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)


def gsm8k_reward_fn(input_ids, **kwargs):
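
Reviewer note: to make the soft over-length punishment concrete, here is the arithmetic with the defaults from the hunk above (`max_length=4096`, `cache_length=512`, `acc_score=10.0`); the sample response lengths are made up.

```python
max_length, cache_length, acc_score = 4096, 512, 10.0

for res_length in (3584, 3840, 4095):
    length_reward = 0.0
    # Penalty ramps linearly from 0 toward -acc_score across the last
    # cache_length tokens before max_length.
    if max_length - cache_length < res_length < max_length:
        length_reward = ((max_length - cache_length) - res_length) / cache_length * acc_score
    print(res_length, round(length_reward, 2))
# 3584 ->  0.0    at the buffer boundary, no penalty yet
# 3840 -> -5.0    halfway into the buffer
# 4095 -> -9.98   just under max_length, nearly cancels a correct answer
```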
17 changes: 17 additions & 0 deletions applications/ColossalChat/coati/distributed/utils.py
@@ -113,3 +113,20 @@ def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
mask_sum = mask.sum(dim=dim)
mean = tensor / (mask_sum + 1e-8)
return mean


def masked_sum(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
"""
Compute the masked sum of a tensor along a specified dimension.

Args:
tensor (torch.Tensor): The input tensor.
mask (torch.Tensor): The mask tensor with the same shape as the input tensor.
dim (int, optional): The dimension along which to compute the sum. Default is 1.

Returns:
torch.Tensor: The masked sum tensor.

"""
tensor = tensor * mask
return tensor.sum(dim=dim)
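
Reviewer note: a one-line check of the new helper, imported from the module this diff touches.

```python
import torch
from coati.distributed.utils import masked_sum

tensor = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
mask = torch.tensor([[1.0, 1.0, 0.0, 0.0]])  # last two positions are padding
print(masked_sum(tensor, mask))  # tensor([3.]) -> 1.0 + 2.0
```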
8 changes: 6 additions & 2 deletions applications/ColossalChat/coati/trainer/utils.py
@@ -128,7 +128,7 @@ def all_reduce_mean(tensor: torch.Tensor, plugin: Plugin = None) -> torch.Tensor:
return tensor


def all_reduce_sum(tensor: torch.Tensor) -> torch.Tensor:
def all_reduce_sum(tensor: torch.Tensor, plugin: Plugin = None) -> torch.Tensor:
"""
Performs an all-reduce operation to sum the values of the given tensor across all processes.

@@ -138,5 +138,9 @@ def all_reduce_sum(tensor: torch.Tensor) -> torch.Tensor:
Returns:
torch.Tensor: The reduced tensor with the sum of values across all processes.
"""
dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
# All reduce sum across DP group
if plugin is not None:
dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM, group=plugin.dp_group)
else:
dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
return tensor
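
Reviewer note: the new `plugin` argument scopes the reduction to the data-parallel group; without it, metrics replicated across tensor-parallel ranks get double-counted. A toy illustration with made-up sizes (no live process group needed):

```python
# Suppose 4 ranks arranged as dp_size=2 x tp_size=2, each holding metric=1.0.
dp_size, tp_size, per_rank_metric = 2, 2, 1.0
world_sum = per_rank_metric * dp_size * tp_size  # 4.0: TP replicas double-counted
dp_sum = per_rank_metric * dp_size               # 2.0: one value per DP rank
print(world_sum, dp_sum)
```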
46 changes: 36 additions & 10 deletions applications/ColossalChat/rl_example.py
@@ -60,8 +60,8 @@
ray.init(address="local", namespace="ray-example")

inference_model_config = dict(path=args.model)
train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False)
generate_config = dict(top_k=50, top_p=0.75, temperature=0.9)
train_model_config = dict(path=args.model, use_flash_attention_2=False, use_cache=False)
generate_config = dict(top_k=-1, top_p=1.0, temperature=1.0)

if args.backend == "transformers":
inference_model_config.update(
@@ -83,7 +83,7 @@
inference_model_config.update(dict(gpu_memory_utilization=0.7, enforce_eager=True, enable_chunked_prefill=True))
generate_config.update(
dict(
max_tokens=2048,
max_tokens=4096,
ignore_eos=True,
include_stop_str_in_output=True,
stop=["</answer>"],
@@ -102,6 +102,29 @@
)
)

# Default Settings
# grpo_config = {
# "filter_range": [0.05, 9.0],
# "lr": 1e-6,
# "train_microbatch_size": train_microbatch_size,
# }

# DAPO variant settings
grpo_config = {
"filter_range": [0.01, 0.99], # only filter out all zero batch and all one batch
"lr": 1e-6,
"train_microbatch_size": args.train_microbatch_size,
"clip_eps_low": 0.2,
"clip_eps_high": 0.28,
"skip_threshold": 20.0,
"beta": 0.0, # no KL penalty
"loss_variation": "token_level",
"soft_over_length_punishment": True,
"max_length": 4096,
"cache_length": 512,
"filter_truncated_response": True,
}

launch_distributed(
num_producers=args.num_inferencer,
num_proc_per_producer=1,
@@ -118,14 +141,17 @@
generate_config=generate_config,
num_generations=args.num_generations,
train_model_config=train_model_config,
# plugin_config={}, # for zero
grpo_config=grpo_config,
plugin_config={
"pp_size": 2,
"tp_size": 2,
"microbatch_size": args.train_microbatch_size // 2,
"zero_stage": 0,
"max_norm": 1.0,
}, # for pp
"zero_stage": 2,
}, # for zero
# plugin_config={
# "pp_size": 2,
# "tp_size": 2,
# "microbatch_size": args.train_microbatch_size // 2,
# "zero_stage": 0,
# "max_norm": 1.0,
# }, # for pp
inference_backend=args.backend,
master_addr="localhost",
master_port=29506,
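
Reviewer note: my reading of how the `grpo_config` knobs above map to the DAPO paper's techniques: `clip_eps_low`/`clip_eps_high` implement Clip-Higher, `loss_variation="token_level"` the token-level policy gradient loss, `soft_over_length_punishment` with `max_length`/`cache_length` the overlong reward shaping, `beta=0.0` drops the KL penalty, and `filter_range` approximates dynamic sampling. A hypothetical sketch of that last filter follows; the actual logic lives in grpo_consumer.py, whose diff is not rendered above.

```python
import torch

def keep_prompt_group(accs: torch.Tensor, filter_range=(0.01, 0.99)) -> bool:
    """Keep a prompt only if its sampled responses are neither all wrong
    nor all right; degenerate groups give GRPO a zero advantage signal."""
    mean_acc = accs.float().mean().item()
    return filter_range[0] <= mean_acc <= filter_range[1]

print(keep_prompt_group(torch.tensor([0, 0, 0, 0])))  # False: all zeros
print(keep_prompt_group(torch.tensor([1, 1, 1, 1])))  # False: all ones
print(keep_prompt_group(torch.tensor([1, 0, 1, 1])))  # True: mixed, informative
```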