[feat] Support DAPO #6263

Merged · 27 commits · Apr 25, 2025
38 changes: 24 additions & 14 deletions applications/ColossalChat/coati/distributed/consumer.py
@@ -16,7 +16,7 @@
 from colossalai.utils import get_current_device

 from .comm import ray_broadcast_tensor_dict
-from .utils import bind_batch, post_recv, unbind_batch
+from .utils import bind_batch, pad_batch, post_recv, unbind_batch


 class BaseConsumer:
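The newly imported `pad_batch` is used in `loop()` further down to equalize sample lengths before they are stacked into a single batch. The following is only a sketch of what such a helper could look like, assuming right-padding and the usual `input_ids` / `attention_mask` keys; it is not the actual implementation in `coati/distributed/utils.py`.

```python
from typing import Dict, List

import torch
import torch.nn.functional as F


def pad_batch_sketch(batches: List[Dict[str, torch.Tensor]], pad_token_id: int = 0) -> List[Dict[str, torch.Tensor]]:
    """Right-pad each sample dict to the longest sequence in the list (illustrative only)."""
    max_len = max(b["input_ids"].shape[-1] for b in batches)
    for b in batches:
        pad = max_len - b["input_ids"].shape[-1]
        if pad > 0:
            # pad the last dimension on the right
            b["input_ids"] = F.pad(b["input_ids"], (0, pad), value=pad_token_id)
            b["attention_mask"] = F.pad(b["attention_mask"], (0, pad), value=0)
    return batches
```

Note that the loop below assigns the result of `pad_batch` but still passes `batches` to `bind_batch`, which suggests the real helper pads the dictionaries in place, as the sketch above does.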
@@ -33,7 +33,7 @@ def __init__(
 batch_size: int,
 model_config: Dict[str, Any],
 plugin_config: Dict[str, Any],
-microbatch_size: int = 1,
+minibatch_size: int = 1,
 save_interval: int = 100,
 save_dir: str = "./model",
 ):
@@ -46,11 +46,11 @@ def __init__
 self.num_update_per_episode = num_update_per_episode
 self.num_recv_per_update = num_recv_per_update
 self.batch_size = batch_size
-self.microbatch_size = microbatch_size
+self.minibatch_size = minibatch_size
 self.save_interval = save_interval
 self.save_dir = save_dir
-assert batch_size % microbatch_size == 0, "batch_size should be divisible by microbatch_size"
-self.num_microbatches = batch_size // microbatch_size
+assert batch_size % minibatch_size == 0, "batch_size should be divisible by minibatch_size"
+self.num_microbatches = batch_size // minibatch_size

 self.model_config = model_config
 self.plugin_config = plugin_config
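To make the renamed sizes concrete, here is a small worked example of the bookkeeping done in the constructor above; the numbers are made up for illustration.

```python
# Illustrative numbers only.
batch_size = 32      # samples consumed per optimizer update
minibatch_size = 8   # samples taken from the buffer for each step() call

assert batch_size % minibatch_size == 0, "batch_size should be divisible by minibatch_size"
num_microbatches = batch_size // minibatch_size  # 4 with these numbers
```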
@@ -67,7 +67,7 @@ def setup(self) -> None:

 plugin_config = dict(tp_size=1, pp_size=1, precision="bf16", zero_stage=2)
 if self.plugin_config.get("pp_size", 1) > 1 and "num_microbatches" not in self.plugin_config:
-plugin_config["microbatch_size"] = self.microbatch_size
+plugin_config["microbatch_size"] = self.minibatch_size
 plugin_config.update(self.plugin_config)
 self.plugin = HybridParallelPlugin(**plugin_config)
 self.booster = Booster(plugin=self.plugin)
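The forwarding above only kicks in when pipeline parallelism is requested and the user has not pinned `num_microbatches` themselves; user-provided settings still override the defaults because `plugin_config.update(self.plugin_config)` runs last. A small sketch of the merge order with an assumed user config:

```python
# Assumed user config, for illustration only.
minibatch_size = 8
user_plugin_config = {"pp_size": 2, "zero_stage": 1}

plugin_config = dict(tp_size=1, pp_size=1, precision="bf16", zero_stage=2)
if user_plugin_config.get("pp_size", 1) > 1 and "num_microbatches" not in user_plugin_config:
    plugin_config["microbatch_size"] = minibatch_size  # pipeline microbatch follows the train minibatch
plugin_config.update(user_plugin_config)  # user settings win: pp_size=2, zero_stage=1

# plugin_config -> {'tp_size': 1, 'pp_size': 2, 'precision': 'bf16', 'zero_stage': 1, 'microbatch_size': 8}
```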
@@ -105,18 +105,26 @@ def loop(self) -> None:
 )
 )
 )
-while len(self.buffer) >= self.dp_size * self.microbatch_size:
+while len(self.buffer) >= self.dp_size * self.minibatch_size:
 batches = self.buffer[
-    self.dp_rank * self.microbatch_size : (self.dp_rank + 1) * self.microbatch_size
+    self.dp_rank * self.minibatch_size : (self.dp_rank + 1) * self.minibatch_size
 ]
-self.buffer = self.buffer[self.dp_size * self.microbatch_size :]
+batch = pad_batch(
+    batches
+)  # when `imbs` is smaller than `tMbs`, samples may differ in size, need to pad before stacking
 batch = bind_batch(batches)
 batch = post_recv(batch)
-loss = self.step(i, **batch)
+loss, num_excessive_prompts = self.step(i, pbar, **batch)
+self.buffer = (
+    self.buffer[
+        (self.dp_rank + 1) * self.minibatch_size
+        - num_excessive_prompts : (self.dp_rank + 1) * self.minibatch_size
+    ]
+    + self.buffer[self.dp_size * self.minibatch_size :]
+)
 if loss is not None:
 pbar.set_postfix({"loss": loss})
 i += 1
+assert len(self.buffer) == 0
 if self.lr_scheduler is not None:
 self.lr_scheduler.step()
 if (step + 1) % self.save_interval == 0 or (step + 1) == self.num_update_per_episode:
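`step()` now also reports how many prompts in the minibatch were excessive (presumably prompts set aside by DAPO-style dynamic sampling), and the slicing above puts exactly that many samples from the tail of this rank's minibatch back at the front of the buffer instead of discarding them. A small illustration of the slicing, with assumed sizes:

```python
# Assumed sizes for illustration: 2 data-parallel ranks, minibatch of 4, viewed from dp_rank 1.
dp_size, minibatch_size, dp_rank = 2, 4, 1
buffer = list(range(10))       # stand-in for 10 buffered sample dicts
num_excessive_prompts = 2      # value assumed to be reported by step()

kept = buffer[(dp_rank + 1) * minibatch_size - num_excessive_prompts : (dp_rank + 1) * minibatch_size]
tail = buffer[dp_size * minibatch_size :]
buffer = kept + tail
print(buffer)  # [6, 7, 8, 9]: the 2 unconsumed samples of rank 1's slice, then the untouched tail
```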
@@ -154,7 +162,9 @@ def __init__(
 batch_size,
 model_config,
 plugin_config,
-microbatch_size=1,
+minibatch_size=1,
+save_interval: int = 100,
+save_dir="./model",
 ):
 super().__init__(
 num_producers,
@@ -168,7 +178,7 @@ def __init__(
 batch_size,
 model_config,
 plugin_config,
-microbatch_size,
+minibatch_size,
 )
 path = model_config.pop("path")
 self.model = AutoModelForCausalLM.from_pretrained(path, **model_config)
@@ -181,7 +191,7 @@ def setup(self):
 super().setup()
 self.model, self.optimizer, *_ = self.booster.boost(self.model, self.optimizer)

-def step(self, step_idx: int, **kwargs) -> Optional[float]:
+def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
 labels = kwargs["input_ids"].clone()
 labels[kwargs["attention_mask"] == 0] = -100
 kwargs["labels"] = labels