Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
set -x

# Colocated GRPO training+generation for Qwen3-0.6B on GSM8K with FSDP.
#
# Prerequisites / usage:
# uv run examples/gsm8k/gsm8k_dataset.py --output_dir $HOME/data/gsm8k
# export WANDB_API_KEY=<your_key_here>
# bash examples/training_backends/megatron/run_fsdp_baseline.sh
#
# Any extra command-line arguments are forwarded unchanged to the entrypoint as
# additional config overrides (they come last on the command line).

DATA_DIR="$HOME/data/gsm8k"
NUM_GPUS=4
LOGGER="wandb" # change to "console" to print to stdout

INFERENCE_BACKEND="vllm"

# NOTE: "$@" must stay quoted. Overrides routinely contain spaces, quotes and
# brackets (e.g. data.train_data="['/my dir/train.parquet']"); an unquoted $@
# would be word-split and glob-expanded by the shell before reaching the trainer.
# The run name is derived from NUM_GPUS so it stays accurate if the GPU count changes.
uv run --isolated --extra "$INFERENCE_BACKEND" -m skyrl_train.entrypoints.main_base \
  data.train_data="['$DATA_DIR/train.parquet']" \
  data.val_data="['$DATA_DIR/validation.parquet']" \
  trainer.algorithm.advantage_estimator="grpo" \
  trainer.policy.model.path="Qwen/Qwen3-0.6B" \
  trainer.placement.colocate_all=true \
  trainer.strategy=fsdp \
  trainer.placement.policy_num_gpus_per_node="$NUM_GPUS" \
  trainer.placement.ref_num_gpus_per_node="$NUM_GPUS" \
  generator.num_inference_engines="$NUM_GPUS" \
  generator.inference_engine_tensor_parallel_size=1 \
  trainer.epochs=20 \
  trainer.eval_batch_size=1024 \
  trainer.eval_before_train=true \
  trainer.eval_interval=5 \
  trainer.update_epochs_per_batch=1 \
  trainer.train_batch_size=128 \
  trainer.policy_mini_batch_size=64 \
  trainer.micro_forward_batch_size_per_gpu=4 \
  trainer.micro_train_batch_size_per_gpu=4 \
  trainer.ckpt_interval=10 \
  trainer.max_prompt_length=512 \
  generator.sampling_params.max_generate_length=1024 \
  trainer.policy.optimizer_config.lr=4.0e-6 \
  trainer.algorithm.use_kl_loss=true \
  generator.backend="$INFERENCE_BACKEND" \
  generator.run_engines_locally=true \
  generator.weight_sync_backend=nccl \
  generator.async_engine=true \
  generator.batched=true \
  environment.env_class=gsm8k \
  generator.n_samples_per_prompt=5 \
  generator.gpu_memory_utilization=0.8 \
  trainer.logger="$LOGGER" \
  trainer.project_name="gsm8k_megatron" \
  trainer.run_name="gsm8k_fsdp1_${NUM_GPUS}gpus" \
  trainer.resume_mode=null \
  trainer.ckpt_path="$HOME/ckpts/gsm8k_fsdp_ckpt" \
  "$@"
62 changes: 62 additions & 0 deletions skyrl-train/examples/training_backends/megatron/run_megatron.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
set -x

# Colocated GRPO training+generation for Qwen3-0.6B on GSM8K with Megatron.
#
# Prerequisites / usage:
# uv run examples/gsm8k/gsm8k_dataset.py --output_dir $HOME/data/gsm8k
# export WANDB_API_KEY=<your_key_here>
# bash examples/training_backends/megatron/run_megatron.sh
#
# Any extra command-line arguments are forwarded unchanged to the entrypoint as
# additional config overrides (they come last on the command line).

DATA_DIR="$HOME/data/gsm8k"
NUM_GPUS=4
LOGGER="wandb" # change to "console" to print to stdout
MODEL_NAME="Qwen/Qwen3-0.6B"

INFERENCE_BACKEND="vllm" # currently only vllm is supported for megatron

# Tensor- and pipeline-parallel sizes, applied to both the policy and the ref model.
# NOTE(review): Megatron typically requires the GPU count to be a multiple of
# TP * PP; with 4 GPUs and TP=2, PP=2 that leaves a data-parallel size of 1 - confirm
# before changing NUM_GPUS.
MEGATRON_TP=2
MEGATRON_PP=2

# NOTE: "$@" must stay quoted. Overrides routinely contain spaces, quotes and
# brackets (e.g. data.train_data="['/my dir/train.parquet']"); an unquoted $@
# would be word-split and glob-expanded by the shell before reaching the trainer.
uv run --isolated --extra "$INFERENCE_BACKEND" --extra mcore -m skyrl_train.entrypoints.main_base \
  data.train_data="['$DATA_DIR/train.parquet']" \
  data.val_data="['$DATA_DIR/validation.parquet']" \
  trainer.algorithm.advantage_estimator="grpo" \
  trainer.policy.model.path="$MODEL_NAME" \
  trainer.placement.colocate_all=true \
  trainer.strategy=megatron \
  trainer.placement.policy_num_gpus_per_node="$NUM_GPUS" \
  trainer.placement.ref_num_gpus_per_node="$NUM_GPUS" \
  generator.num_inference_engines="$NUM_GPUS" \
  generator.inference_engine_tensor_parallel_size=1 \
  megatron_config.policy.tensor_model_parallel_size="$MEGATRON_TP" \
  megatron_config.policy.pipeline_model_parallel_size="$MEGATRON_PP" \
  megatron_config.ref.tensor_model_parallel_size="$MEGATRON_TP" \
  megatron_config.ref.pipeline_model_parallel_size="$MEGATRON_PP" \
  trainer.use_sample_packing=false \
  trainer.epochs=20 \
  trainer.eval_batch_size=1024 \
  trainer.eval_before_train=false \
  trainer.eval_interval=5 \
  trainer.update_epochs_per_batch=1 \
  trainer.train_batch_size=128 \
  trainer.policy_mini_batch_size=64 \
  trainer.micro_forward_batch_size_per_gpu=4 \
  trainer.micro_train_batch_size_per_gpu=4 \
  trainer.ckpt_interval=10 \
  trainer.max_prompt_length=512 \
  generator.sampling_params.max_generate_length=1024 \
  trainer.policy.optimizer_config.lr=1.0e-6 \
  trainer.algorithm.use_kl_loss=true \
  generator.backend="$INFERENCE_BACKEND" \
  generator.run_engines_locally=true \
  generator.weight_sync_backend=nccl \
  generator.async_engine=true \
  generator.batched=true \
  environment.env_class=gsm8k \
  generator.n_samples_per_prompt=5 \
  generator.gpu_memory_utilization=0.6 \
  trainer.logger="$LOGGER" \
  trainer.project_name="gsm8k_megatron" \
  trainer.run_name="gsm8k_megatron_tp${MEGATRON_TP}_pp${MEGATRON_PP}_${MODEL_NAME}" \
  trainer.resume_mode=null \
  trainer.ckpt_path="$HOME/ckpts/gsm8k_megatron_ckpt" \
  "$@"
20 changes: 19 additions & 1 deletion skyrl-train/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ dependencies = [
"debugpy==1.8.0",
"hf_transfer",
"wandb",
"datasets",
"datasets>=3.6.0",
"tensordict",
"jaxtyping",
"skyrl-gym",
Expand All @@ -51,6 +51,11 @@ conflicts = [
{ extra = "vllm" },
{ extra = "flashrl" },
{ extra = "sglang" },
],
[
{ extra = "mcore" },
{ extra = "sglang" },
{ extra = "flashrl" },
]
]

Expand Down Expand Up @@ -101,6 +106,18 @@ sglang = [
"torch==2.7.1",
"torchvision",
]
mcore = [
# Build transformer-engine separately first!
# uv pip install "torch==2.7.1"
# uv pip install "nvidia-cudnn-cu12>=9.3"
# export CUDNN_PATH="$(python -c 'import inspect, nvidia.cudnn as c, os; print(os.path.dirname(inspect.getfile(c)))')"
# export CPATH="$CUDNN_PATH/include:${CPATH:-}"
# export LD_LIBRARY_PATH="$CUDNN_PATH/lib:${LD_LIBRARY_PATH:-}"
# uv pip install --no-build-isolation "transformer_engine[pytorch]==2.5.0" --verbose
"mbridge==0.13.0",
"megatron-core@git+https://github.com/NVIDIA/Megatron-LM.git@core_r0.13.0",
"transformer-engine[pytorch]==2.5.0",
]
flashrl = [
# TODO (sumanthrh): Shift to install from git source
# Currently, this is the wheel built for commit 4b04dfc at: https://github.com/SumanthRH/vllm/tree/flashrl
Expand All @@ -111,6 +128,7 @@ flashrl = [
"torchvision",
]


[[tool.uv.index]]
name = "pytorch-cu128"
url = "https://download.pytorch.org/whl/cu128"
Expand Down
35 changes: 35 additions & 0 deletions skyrl-train/skyrl_train/config/megatron_config/policy.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# @package megatron_config.policy
# Megatron settings for the policy model. Each value can be overridden from the
# command line, e.g. megatron_config.policy.tensor_model_parallel_size=2.
tensor_model_parallel_size: 1
pipeline_model_parallel_size: 1
context_parallel_size: 1
expert_model_parallel_size: 1
expert_tensor_parallel_size: 1

# Settings for the Distributed Data Parallel (DDP) config
ddp_config:
grad_reduce_in_fp32: true
overlap_grad_reduce: false
overlap_param_gather: false
average_in_collective: true

# kwargs to override the HF model config
model_config_kwargs: {}

# kwargs to override the Megatron TransformerConfig object
transformer_config_kwargs:
# Recompute config - used for gradient/activation checkpointing.
# By default, use the recompute methods with minimal performance interference.
# Recompute granularity, choices: ["full", "selective"]
# 'full' checkpoints the entire transformer layer; 'selective' only checkpoints the memory-intensive part of attention
recompute_granularity: null

# Recompute modules, multiple choices: ["core_attn", "moe_act", "layernorm", "mla_up_proj", "mlp", "moe"]
# Please use the modules that actually exist in the model being trained
recompute_modules: ["core_attn"]

# Recompute method, choices: ['uniform', 'block']
# 'uniform' divides the total number of transformer layers and checkpoints the input activation of each chunk
# 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity
recompute_method: null

# Number of transformer layers to recompute: per uniformly divided chunk for 'uniform',
# per pipeline stage for 'block' (see Megatron TransformerConfig.recompute_num_layers)
recompute_num_layers: null
28 changes: 28 additions & 0 deletions skyrl-train/skyrl_train/config/megatron_config/ref.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# @package megatron_config.ref
# Megatron settings for the reference model. Each value can be overridden from the
# command line, e.g. megatron_config.ref.tensor_model_parallel_size=2.
tensor_model_parallel_size: 1
pipeline_model_parallel_size: 1
context_parallel_size: 1
expert_model_parallel_size: 1
expert_tensor_parallel_size: 1


# kwargs to override the HF model config
model_config_kwargs: {}

# additional transformer config like: num_layers_in_first(/last)_pipeline_stage
transformer_config_kwargs:
# Recompute configuration, same as in megatron.training.arguments
# By default, use the recompute methods with minimal performance interference.
# Recompute granularity, choices: ["full", "selective"]
# 'full' checkpoints the entire transformer layer; 'selective' only checkpoints the memory-intensive part of attention
recompute_granularity: null

# Recompute modules, multiple choices: ["core_attn", "moe_act", "layernorm", "mla_up_proj", "mlp", "moe"]
# Please use the modules that actually exist in the model being trained
recompute_modules: ["core_attn"]

# Recompute method, choices: ['uniform', 'block']
# 'uniform' divides the total number of transformer layers and checkpoints the input activation of each chunk
# 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity
recompute_method: null

# Number of transformer layers to recompute: per uniformly divided chunk for 'uniform',
# per pipeline stage for 'block' (see Megatron TransformerConfig.recompute_num_layers)
recompute_num_layers: null
6 changes: 5 additions & 1 deletion skyrl-train/skyrl_train/config/ppo_base_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ defaults:
- _self_
- deepspeed_config@deepspeed_config.train: train
- deepspeed_config@deepspeed_config.eval: eval
- megatron_config@megatron_config.policy: policy
- megatron_config@megatron_config.ref: ref
- skyrl_gym_config: default

data:
Expand Down Expand Up @@ -34,11 +36,12 @@ trainer:
max_grad_norm: 1.0 # gradient clipping
offload_after_step: true # offload optimizer state to cpu after each step. Applicable only when `colocate_all=true`
num_warmup_steps: 0
scheduler: "constant_with_warmup"
scheduler: "constant_with_warmup"
fsdp_config:
cpu_offload: false # offload params + optimizer state to cpu during fwd pass
reshard_after_forward: true # fsdp2 only, [True, False, int between 1 and fsdp_size]
fsdp_size: -1
megatron_config: ${megatron_config.policy}
sequence_parallel_size: 1
# uses torch compile with logits calculation
use_torch_compile: false
Expand All @@ -47,6 +50,7 @@ trainer:
ref:
sequence_parallel_size: 1
deepspeed_config: ${deepspeed_config.eval}
megatron_config: ${megatron_config.ref}
fsdp_config:
cpu_offload: true
reshard_after_forward: true
Expand Down
38 changes: 0 additions & 38 deletions skyrl-train/skyrl_train/distributed/deepspeed_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,44 +201,6 @@ def _ds_init_eval_model(self, model):
model = engine
return model

def all_reduce(self, data, op="mean"):
    """Reduce ``data`` across the default process group.

    Args:
        data: A dict (reduced value-by-value, recursively), a ``torch.Tensor``,
            or a plain Python number.
        op: One of ``"mean"``, ``"max"`` or ``"sum"``.

    Returns:
        A dict with the same keys for dict input, a tensor for tensor input
        (moved back to CPU if it arrived on CPU), or a Python scalar for
        non-tensor input.
    """
    assert op in ("mean", "max", "sum")

    if isinstance(data, dict):
        return {key: self.all_reduce(value, op) for key, value in data.items()}

    was_tensor = isinstance(data, torch.Tensor)
    if not was_tensor:
        # Wrap plain numbers in a 1-element tensor so they can go through the collective.
        data = torch.Tensor([data])

    on_cpu = data.device.type == "cpu"
    if on_cpu:
        # The collective is run on the current CUDA device.
        data = data.to(torch.cuda.current_device())

    if op == "mean":
        # Pre-divide so that the SUM reduction below yields the mean (in-place division).
        data /= self.world_size

    reduce_op = dist.ReduceOp.MAX if op == "max" else dist.ReduceOp.SUM
    dist.all_reduce(data, op=reduce_op)

    if on_cpu:
        data = data.cpu()
    return data if was_tensor else data.item()

def all_gather(self, data):
    """Gather ``data`` from every rank and concatenate the pieces.

    Args:
        data: A dict (gathered value-by-value, recursively), a ``torch.Tensor``,
            or a plain Python number.

    Returns:
        A dict with the same keys for dict input; otherwise the ``torch.cat`` of
        the ``self.world_size`` gathered tensors, returned on CPU when the input
        was a CPU tensor (or a plain number).
    """
    if isinstance(data, dict):
        return {key: self.all_gather(value) for key, value in data.items()}

    if not isinstance(data, torch.Tensor):
        # Wrap plain numbers in a 1-element tensor so they can go through the collective.
        data = torch.Tensor([data])
    on_cpu = data.device.type == "cpu"

    # One receive buffer per rank, placed on the current CUDA device.
    buffers = [torch.zeros_like(data).to(torch.cuda.current_device()) for _ in range(self.world_size)]
    dist.all_gather(buffers, data.to(torch.cuda.current_device()))

    gathered = torch.cat(buffers)
    return gathered.cpu() if on_cpu else gathered

def _unwrap_model(self, model) -> nn.Module:
if isinstance(model, Actor):
return self._unwrap_model(model.model)
Expand Down
17 changes: 9 additions & 8 deletions skyrl-train/skyrl_train/distributed/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,19 @@ class MeshRank:

world_size: int
dp_size: int
pp_size: int

def is_primary_dp_rank(self) -> bool:
"""Check if this rank is the primary DP rank.
def is_collection_dp_rank(self) -> bool:
"""Check if this rank is a DP rank to collect from

This is the rank with (SP=0, TP=0, PP=0)
This is the rank with (SP=0, TP=0, PP=pp_size-1)

Note: double check this for ETP > 1 (but this is not a typically used case)
"""
return self.tp == 0 and self.pp == 0 and self.sp == 0
return self.tp == 0 and self.pp == self.pp_size - 1 and self.sp == 0

def __str__(self) -> str:
return (
f"MeshRank(dp={self.dp}, sp={self.sp}, tp={self.tp}, world_size={self.world_size}, dp_size={self.dp_size})"
)
return f"MeshRank(dp={self.dp}, sp={self.sp}, tp={self.tp}, pp={self.pp}, world_size={self.world_size}, dp_size={self.dp_size}, pp_size={self.pp_size})"

def __repr__(self) -> str:
return self.__str__()
Expand Down Expand Up @@ -256,7 +257,7 @@ def concatenate_outputs_after_mesh_dispatch(
# collect in-order
dp_rank_to_shard = {}
for actor_info, data_batch in zip(actor_infos, data_batches):
if actor_info.rank.is_primary_dp_rank():
if actor_info.rank.is_collection_dp_rank():
dp_rank = actor_info.rank.dp
dp_rank_to_shard[dp_rank] = data_batch
for i in range(actor_infos[0].rank.dp_size):
Expand Down
40 changes: 0 additions & 40 deletions skyrl-train/skyrl_train/distributed/fsdp_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,46 +307,6 @@ def _fsdp_init_eval_model(self, model):

return model

def all_reduce(self, data, op="mean"):
    """Perform all_reduce across all processes.

    Args:
        data: A dict (each value is reduced recursively), a ``torch.Tensor``,
            or a plain Python number.
        op: Reduction to apply; must be ``"mean"``, ``"max"`` or ``"sum"``.

    Returns:
        The reduced value in the same form as the input: a dict, a tensor
        (back on CPU if the input was a CPU tensor), or a Python scalar.
    """
    assert op in ("mean", "max", "sum")

    # Dicts are handled entry by entry; everything below deals with a single value.
    if isinstance(data, dict):
        return {name: self.all_reduce(item, op) for name, item in data.items()}

    input_is_tensor = isinstance(data, torch.Tensor)
    if not input_is_tensor:
        data = torch.Tensor([data])

    started_on_cpu = data.device.type == "cpu"
    if started_on_cpu:
        # Move to the current CUDA device for the collective call.
        data = data.to(torch.cuda.current_device())

    if op == "mean":
        # Divide first (in place); summing the scaled values across ranks gives the mean.
        data /= self.world_size

    if op == "max":
        collective_op = dist.ReduceOp.MAX
    else:
        collective_op = dist.ReduceOp.SUM
    dist.all_reduce(data, op=collective_op)

    if started_on_cpu:
        data = data.cpu()
    if input_is_tensor:
        return data
    return data.item()

def all_gather(self, data):
    """Perform all_gather across all processes.

    Args:
        data: A dict (each value is gathered recursively), a ``torch.Tensor``,
            or a plain Python number.

    Returns:
        For dict input, a dict with the same keys. Otherwise the concatenation
        (``torch.cat``) of the tensors gathered from ``self.world_size`` ranks,
        moved to CPU when the input lived on CPU.
    """
    # Dicts are handled entry by entry; everything below deals with a single value.
    if isinstance(data, dict):
        return {name: self.all_gather(item) for name, item in data.items()}

    if not isinstance(data, torch.Tensor):
        data = torch.Tensor([data])
    started_on_cpu = data.device.type == "cpu"

    # Allocate one receive slot per rank on the current CUDA device.
    slots = [torch.zeros_like(data).to(torch.cuda.current_device()) for _ in range(self.world_size)]
    dist.all_gather(slots, data.to(torch.cuda.current_device()))

    if started_on_cpu:
        return torch.cat(slots).cpu()
    return torch.cat(slots)

def _unwrap_model(self, model) -> nn.Module:
"""Unwrap model from Actor or FSDP"""
# Handle Actor wrapper
Expand Down
Loading