Commit 58e35e0

RedTachyon, felipemello1, ebsmothers, and SalmanMohammadi authored and committed

R1-Style distributed GRPO (pytorch#2326)

Co-authored-by: Felipe Mello <[email protected]>
Co-authored-by: ebsmothers <[email protected]>
Co-authored-by: salman <[email protected]>
1 parent e3744e3 commit 58e35e0

21 files changed: +2425 -2 lines changed

recipes/configs/dev/3B_full_grpo.yaml

+140
@@ -0,0 +1,140 @@
# Config for multi-node GRPO in dev/grpo_full_finetune_distributed.py
# using a Llama3.2 3B Base model
#
# This config assumes that you've run the following command before launching
# this run:
#   tune download meta-llama/Llama-3.2-3B --output-dir /tmp/Llama-3.2-3B --ignore-patterns "original/consolidated.00.pth"
#
# It can be beneficial to first train the base model with SFT using the 3B_sft recipe.
#
# To launch on 4 devices, run the following command from root:
#   tune run --nproc_per_node 4 dev/grpo_full_finetune_distributed --config dev/3B_full_grpo
#
# You can add specific overrides through the command line. For example,
# to override the checkpointer directory while launching training
# you can run:
#   tune run --nproc_per_node 4 dev/grpo_full_finetune_distributed --config dev/3B_full_grpo checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# This config works best when the model is being fine-tuned on 2+ GPUs.
#
# Furthermore, you can launch it on multiple nodes by going to recipes/dev/ and using
#   sbatch multinode_grpo.sbatch

name: grpo_llama3b

output_dir: /tmp/checkpoints/${name}
base_model_path: /tmp/llama3B_gsm8k_sft_part0/epoch_0  # Use this to train from the slightly trained SFT model

# Tokenizer
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /tmp/Llama-3.2-3B/original/tokenizer.model
  max_seq_len: null

# Dataset
dataset:
  _component_: torchtune.dev.grpo.gsm8k.gsm8k_dataset
  partition: 1-9/10
seed: null
shuffle: False

# Model Arguments
model:
  _component_: torchtune.models.llama3_2.llama3_2_3b

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: ${base_model_path}
  checkpoint_files: [
    model-00001-of-00002.safetensors,
    model-00002-of-00002.safetensors,
  ]
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA3

ref_checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: ${base_model_path}
  checkpoint_files: [
    model-00001-of-00002.safetensors,
    model-00002-of-00002.safetensors,
  ]
  recipe_checkpoint: null
  output_dir: ${output_dir}/ref  # shouldn't be used?
  model_type: LLAMA3

resume_from_checkpoint: False
save_every_n_epochs: 1

# Fine-tuning arguments
batch_size: 1
grpo_samples: 16
forward_batch_size: 1
max_generated_tokens: 512
top_k: null
temperature: 1.0

ppo_epochs: 1

num_steps: 200

clip_grad_norm: 1.0

epochs: 10
optimizer:
  _component_: torch.optim.AdamW
  lr: 1e-5
  fused: True
lr_scheduler:
  _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
  num_warmup_steps: 50
loss:
  _component_: torchtune.dev.grpo.loss.GRPOSimpleLoss
  kl_coeff: 0.01
  epsilon: 0.2

# Training env
device: cuda

# Memory management
enable_activation_checkpointing: True  # True reduces memory
compile: False  # torch.compile, set to True for better perf/memory

# Reduced precision
dtype: bf16

# Logging
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True

# Profiler
profiler:
  _component_: torchtune.training.setup_torch_profiler
  enabled: True

  # Output directory of trace artifacts
  output_dir: ${output_dir}/profiling_outputs

  # `torch.profiler.ProfilerActivity` types to trace
  cpu: True
  cuda: True

  # Trace options passed to `torch.profiler.profile`
  profile_memory: True
  with_stack: True
  record_shapes: True
  with_flops: False

  # `torch.profiler.schedule` options:
  # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
  wait_steps: 5
  warmup_steps: 3
  active_steps: 2
  num_cycles: 1
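
The loss block above points at torchtune.dev.grpo.loss.GRPOSimpleLoss with kl_coeff: 0.01 and epsilon: 0.2, and grpo_samples: 16 is the group size that advantages are normalized over. For orientation, here is a minimal Python sketch of the kind of group-relative, clipped objective those knobs control; the sequence-level shapes, the 1e-4 stabilizer, and the k3-style KL estimate against the frozen reference policy (the model loaded by ref_checkpointer) are illustrative assumptions, not torchtune's actual implementation.

import torch


def grpo_loss(
    logprobs: torch.Tensor,      # [groups, samples] log-prob of each completion under the current policy
    old_logprobs: torch.Tensor,  # [groups, samples] log-prob under the policy that generated the samples
    ref_logprobs: torch.Tensor,  # [groups, samples] log-prob under the frozen reference policy
    rewards: torch.Tensor,       # [groups, samples] scalar reward per completion
    epsilon: float = 0.2,        # clip range, as in the config above
    kl_coeff: float = 0.01,      # KL penalty weight, as in the config above
) -> torch.Tensor:
    # Group-relative advantages: normalize rewards within each group of grpo_samples completions.
    advantages = (rewards - rewards.mean(dim=1, keepdim=True)) / (
        rewards.std(dim=1, keepdim=True) + 1e-4
    )

    # PPO-style clipped surrogate on the policy ratio.
    ratio = torch.exp(logprobs - old_logprobs)
    clipped = torch.clamp(ratio, 1 - epsilon, 1 + epsilon)
    policy_loss = -torch.minimum(ratio * advantages, clipped * advantages)

    # Low-variance (k3) KL estimate between the current and reference policies.
    log_ratio_ref = ref_logprobs - logprobs
    kl = torch.exp(log_ratio_ref) - log_ratio_ref - 1

    return (policy_loss + kl_coeff * kl).mean()


if __name__ == "__main__":
    groups, samples = 2, 16  # two prompts, grpo_samples=16 completions each
    lp = torch.randn(groups, samples)
    loss = grpo_loss(lp, lp.clone(), torch.randn(groups, samples), torch.randn(groups, samples))
    print(loss.item())

In the recipe itself, the rewards would come from checking the generated GSM8K answers and the reference log-probabilities from the ref_checkpointer model; both are stubbed out with random tensors here.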
+109
@@ -0,0 +1,109 @@
# Config for multi-device SFT for reasoning in full_finetune_distributed.py
# using a Llama3.2 3B Base model
#
# This config assumes that you've run the following command before launching
# this run:
#   tune download meta-llama/Llama-3.2-3B --output-dir /tmp/Llama-3.2-3B --ignore-patterns "original/consolidated.00.pth"
#
# To launch on 4 devices, run the following command from root:
#   tune run --nproc_per_node 4 full_finetune_distributed --config dev/3B_grpo_sft
#
# You can add specific overrides through the command line. For example,
# to override the checkpointer directory while launching training
# you can run:
#   tune run --nproc_per_node 4 full_finetune_distributed --config dev/3B_grpo_sft checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# This config works best when the model is being fine-tuned on 2+ GPUs.

name: llama3B_gsm8k_sft_part0

output_dir: /tmp/${name}

# Tokenizer
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /tmp/Llama-3.2-3B/original/tokenizer.model
  max_seq_len: null

# Dataset
dataset:
  _component_: torchtune.dev.grpo.gsm8k.gsm8k_sft
  partition: 0-0/10
seed: null
shuffle: True

# Model Arguments
model:
  _component_: torchtune.models.llama3_2.llama3_2_3b

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Llama-3.2-3B/
  checkpoint_files: [
    model-00001-of-00002.safetensors,
    model-00002-of-00002.safetensors,
  ]
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA3
resume_from_checkpoint: False

# Fine-tuning arguments
batch_size: 2
epochs: 1

optimizer:
  _component_: torch.optim.AdamW
  lr: 1e-5
  fused: True
loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
clip_grad_norm: null
compile: False  # torch.compile the model + loss, True increases speed + decreases memory
optimizer_in_bwd: False  # True saves memory. Requires gradient_accumulation_steps=1
gradient_accumulation_steps: 1  # Use to increase effective batch size

# Training env
device: cuda

# Memory management
enable_activation_checkpointing: True  # True reduces memory
enable_activation_offloading: False  # True reduces memory

# Reduced precision
dtype: bf16

# Logging
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True

# Profiler (disabled)
profiler:
  _component_: torchtune.training.setup_torch_profiler
  enabled: False

  # Output directory of trace artifacts
  output_dir: ${output_dir}/profiling_outputs

  # `torch.profiler.ProfilerActivity` types to trace
  cpu: True
  cuda: True

  # Trace options passed to `torch.profiler.profile`
  profile_memory: False
  with_stack: False
  record_shapes: True
  with_flops: False

  # `torch.profiler.schedule` options:
  # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
  wait_steps: 5
  warmup_steps: 3
  active_steps: 2
  num_cycles: 1
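
Both configs carve up the same GSM8K data with a partition string: 0-0/10 for this SFT warm-up and 1-9/10 for the GRPO stage, so the two stages see disjoint slices; the GRPO config then starts from the epoch_0 checkpoint of this llama3B_gsm8k_sft_part0 run via base_model_path. The snippet below is a hedged Python sketch of one plausible reading of the "a-b/n" syntax (keep contiguous shards a through b out of n equal shards); it illustrates the idea only and is not the actual torchtune.dev.grpo.gsm8k partitioning code.

# Assumption for illustration: "a-b/n" means "keep contiguous shards a..b
# (inclusive) out of n equal shards of the dataset". Not the real builder logic.
def partition_indices(num_examples: int, partition: str) -> range:
    bounds, n_shards = partition.split("/")
    start, stop = (int(x) for x in bounds.split("-"))
    shard_size = num_examples // int(n_shards)
    return range(start * shard_size, (stop + 1) * shard_size)


if __name__ == "__main__":
    sft_split = partition_indices(1000, "0-0/10")   # first 100 examples (SFT warm-up)
    grpo_split = partition_indices(1000, "1-9/10")  # remaining 900 examples (GRPO)
    assert set(sft_split).isdisjoint(grpo_split)
    print(len(sft_split), len(grpo_split))  # 100 900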
