R1-Style distributed GRPO #2326
@@ -0,0 +1,139 @@
# Config for multi-node GRPO in dev/grpo_full_finetune_distributed.py
# using a Llama 3.2 3B Base model
#
# This config assumes that you've run the following command before launching
# this run:
#   tune download meta-llama/Llama-3.2-3B --output-dir /tmp/Llama-3.2-3B --ignore-patterns "original/consolidated.00.pth"
#
# It can be beneficial to first train the base model with SFT using the 3B_sft recipe.
> **Review:** You can do something similar to what's in our knowledge distillation configs and just explicitly add the command here for convenience.
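One way to follow that suggestion, reusing the launch command from the 3B_sft config later in this diff, would be to extend the comment block like this:

```yaml
# It can be beneficial to first train the base model with SFT using the 3B_sft recipe:
#   tune run --nproc_per_node 4 full_finetune_distributed --config dev/grpo/3B_sft
```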
#
# To launch on 4 devices, run the following command from root:
#   tune run --nproc_per_node 4 dev/grpo_full_finetune_distributed --config dev/grpo/3B_full_rl
> **Review:** Should be this based on how it's defined in the recipe registry, I think.
>
> Suggested change
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
#   tune run --nproc_per_node 4 dev/grpo_full_finetune_distributed --config dev/grpo/3B_full_rl checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# This config works best when the model is being fine-tuned on 2+ GPUs.
#
# Furthermore, you can launch it on multiple nodes by going to recipes/dev/ and using
#   sbatch multinode_grpo.sbatch

base_dir: /tmp
> **Review:** How do you feel about deleting `base_dir` and just keeping `output_dir` // `base_model_path`, so we can remove one extra param?
>
> **Author:** My personal preference is actually keeping `base_dir` or something similar, but for the "official" codebase it probably makes sense to keep it more uniform. My reasoning: putting everything in /tmp isn't a terrible assumption when you know nothing about the user's filesystem, but most of the time I don't want to put everything in /tmp. It's much easier to change this in one place than to track down every /tmp instance in the config.
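To illustrate the author's point: because `output_dir` and `base_model_path` interpolate `${base_dir}`, a single command-line override relocates every path at once (the directory below is made up):

```yaml
#   tune run --nproc_per_node 4 dev/grpo_full_finetune_distributed \
#     --config dev/grpo/3B_full_rl base_dir=/data/my_runs
```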

name: grpo_sft_start

output_dir: ${base_dir}/checkpoints/${name}
base_model_path: ${base_dir}/Llama-3.2-3B  # Use this to train from the base model
# base_model_path: ${base_dir}/llama3B_gsm8k_sft_part0/epoch_0  # Use this to train from the slightly trained SFT model

# Tokenizer
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: ${base_model_path}/original/tokenizer.model
  max_seq_len: null

# Dataset
dataset:
  _component_: torchtune.datasets.gsm8k_dataset
  partition: 1-9/10  # shards 1-9 of 10; shard 0 is held out for SFT (see 3B_sft below)
  # packed: False  # True increases speed
seed: null
shuffle: False

# Model Arguments
model:
  _component_: torchtune.models.llama3_2.llama3_2_3b

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: ${base_model_path}
  checkpoint_files: [
    model-00001-of-00002.safetensors,  # Add ft- if starting from SFT model
    model-00002-of-00002.safetensors,
  ]
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA3
resume_from_checkpoint: False
save_every_n_epochs: 1

# Fine-tuning arguments
batch_size: 1
grpo_samples: 16
forward_batch_size: 1
max_generated_tokens: 512
top_k: null
temperature: 1.0

fsdp_reshard_after_forward: True
> **Review:** Why was this necessary?
>
> **Author:** Probably not necessary. I initially set it to False, which sped up generation by a decent margin (around 25% in the data-generation phase), but it then caused some irregular memory issues. I didn't want to deal with memory profiling just yet, so I set it back to True to see whether the pipeline even works. In any case, this is likely to be an impactful parameter when optimizing the recipe for performance, and I'm guessing it isn't commonly tuned in other recipes, since here we do an unusually large amount of generation.
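For anyone reproducing that experiment, the flag can be flipped from the command line rather than by editing the config (a sketch, using the same launch command as above):

```yaml
#   tune run --nproc_per_node 4 dev/grpo_full_finetune_distributed \
#     --config dev/grpo/3B_full_rl fsdp_reshard_after_forward=False
```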

ppo_epochs: 1
ppo_batch_size: ${grpo_samples}  # single-step for the "simple" loss
> **Review:** Need to check where this is used, but I wonder if 'ppo' is a good name here. I also find it a bit confusing, without context, what the difference is between `ppo_epochs` and `epochs`, or between `batch_size` and `ppo_batch_size`.
>
> **Author:** It probably makes sense to rename it, but this will be an awkward area of weird overlaps in naming and active research. As an algorithm, GRPO is basically the same as PPO, but with a different advantage estimation method (at least that's how it's defined in the R1 paper). The general idea in PPO is that we can take multiple gradient steps with a single batch of data. The implementation that is currently running simplifies a few elements of the PPO/GRPO loss, which also requires taking only a single gradient step per data batch. This is what works in TRL, and what I validated to work here in torchtune. It's very likely that some other variations will also work, but I've yet to run those experiments.
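To make the author's point concrete, here is a minimal, hypothetical sketch of the simplified single-step objective; it is not the actual `torchtune.rlhf.loss.GRPOSimpleLoss` code. It assumes per-token log-probs from the current and reference policies, rewards for a group of `grpo_samples` completions per prompt, a k3-style KL estimator (the choice used in TRL), and a small std epsilon of 1e-4; the `kl_coeff` and `epsilon` defaults mirror this config.

```python
import torch

def grpo_advantages(rewards: torch.Tensor) -> torch.Tensor:
    # Group-relative advantage estimate: normalize rewards across the
    # group of `grpo_samples` completions generated for the same prompt.
    return (rewards - rewards.mean()) / (rewards.std() + 1e-4)

def grpo_simple_loss(logprobs, ref_logprobs, advantages,
                     kl_coeff=0.04, epsilon=0.2):
    # Single-step regime: the "old" policy equals the current policy, so
    # exp(logp - logp.detach()) is identically 1 in value, but its gradient
    # is the standard policy-gradient term.
    ratio = torch.exp(logprobs - logprobs.detach())
    clipped = torch.clamp(ratio, 1.0 - epsilon, 1.0 + epsilon)
    policy_loss = -torch.minimum(ratio * advantages, clipped * advantages)
    # k3 estimator of KL(pi || pi_ref), penalizing drift from the reference.
    kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1.0
    return (policy_loss + kl_coeff * kl).mean()
```

With one gradient step per batch the clipping never activates, which is exactly why the single-step simplification is safe; taking multiple `ppo_epochs` over the same batch is where the ratio and clipping start to matter.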

num_steps: 10000

clip_grad_norm: 1.0

epochs: 10
optimizer:
  _component_: torch.optim.AdamW
  lr: 1e-5
  fused: True
lr_scheduler:
  _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
  num_warmup_steps: 50
loss:
  _component_: torchtune.rlhf.loss.GRPOSimpleLoss
  kl_coeff: 0.04
  epsilon: 0.2
max_steps_per_epoch: null
gradient_accumulation_steps: 1  # Use to increase effective batch size

# Training env
device: cuda

# Memory management
enable_activation_checkpointing: True  # True reduces memory
enable_activation_offloading: True  # True reduces memory
compile: False  # pytorch compile, set to true for better perf/memory
optimizer_in_bwd: False  # True saves memory. Requires gradient_accumulation_steps=1
> **Review:** Can also just consider removing this from the recipe for simplicity, unless you find that it's necessary.

# Reduced precision
dtype: bf16
custom_sharded_layers: ['decoder.tok_embeddings']

# Logging
metric_logger:
  _component_: torchtune.training.metric_logging.WandBLogger
> **Review:** At least prior to landing, this should just use DiskLogger as in our other configs (not everyone necessarily has WandB set up).
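The swap the reviewer is suggesting would mirror the logger block used in the single-device config at the end of this diff, e.g.:

```yaml
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}/logs
```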
  project: grpo_gsm8k_torchtune
  log_dir: ${output_dir}/logs
  name: ${name}
log_every_n_steps: 1
log_peak_memory_stats: True

# Profiler (enabled)
profiler:
  _component_: torchtune.training.setup_torch_profiler
  enabled: True

  # Output directory of trace artifacts
  output_dir: ${output_dir}/profiling_outputs

  # `torch.profiler.ProfilerActivity` types to trace
  cpu: True
  cuda: True

  # Trace options passed to `torch.profiler.profile`
  profile_memory: True
  with_stack: True
  record_shapes: True
  with_flops: False

  # `torch.profiler.schedule` options:
  # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
  wait_steps: 5
  warmup_steps: 3
  active_steps: 2
  num_cycles: 1
@@ -0,0 +1,112 @@
# Config for multi-device SFT for reasoning in full_finetune_distributed.py
# using a Llama 3.2 3B Base model
#
# This config assumes that you've run the following command before launching
# this run:
#   tune download meta-llama/Llama-3.2-3B --output-dir /tmp/Llama-3.2-3B --ignore-patterns "original/consolidated.00.pth"
#
# To launch on 4 devices, run the following command from root:
#   tune run --nproc_per_node 4 full_finetune_distributed --config dev/grpo/3B_sft
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
#   tune run --nproc_per_node 4 full_finetune_distributed --config dev/grpo/3B_sft checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# This config works best when the model is being fine-tuned on 2+ GPUs.

base_dir: /tmp
name: llama3B_gsm8k_sft_part0

output_dir: ${base_dir}/${name}

# Tokenizer
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: ${base_dir}/Llama-3.2-3B/original/tokenizer.model
  max_seq_len: null

# Dataset
dataset:
  _component_: torchtune.datasets.gsm8k_sft
  partition: 0-0/10  # shard 0 of 10; the GRPO config trains on shards 1-9
seed: null
shuffle: True

# Model Arguments
model:
  _component_: torchtune.models.llama3_2.llama3_2_3b

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: ${base_dir}/Llama-3.2-3B/
  checkpoint_files: [
    model-00001-of-00002.safetensors,
    model-00002-of-00002.safetensors,
  ]
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA3
resume_from_checkpoint: False

# Fine-tuning arguments
batch_size: 2
epochs: 1

optimizer:
  _component_: torch.optim.AdamW
  lr: 1e-5
  fused: True
loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
clip_grad_norm: null
compile: False  # torch.compile the model + loss, True increases speed + decreases memory
optimizer_in_bwd: False  # True saves memory. Requires gradient_accumulation_steps=1
gradient_accumulation_steps: 1  # Use to increase effective batch size

# Training env
device: cuda

# Memory management
enable_activation_checkpointing: True  # True reduces memory
enable_activation_offloading: False  # True reduces memory

# Reduced precision
dtype: bf16

# Logging
metric_logger:
  _component_: torchtune.training.metric_logging.WandBLogger
  project: gsm8k_sft
  log_dir: ${output_dir}/logs
  name: ${name}
log_every_n_steps: 1
log_peak_memory_stats: True

# Profiler (disabled)
profiler:
  _component_: torchtune.training.setup_torch_profiler
  enabled: False

  # Output directory of trace artifacts
  output_dir: ${output_dir}/profiling_outputs

  # `torch.profiler.ProfilerActivity` types to trace
  cpu: True
  cuda: True

  # Trace options passed to `torch.profiler.profile`
  profile_memory: False
  with_stack: False
  record_shapes: True
  with_flops: False

  # `torch.profiler.schedule` options:
  # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
  wait_steps: 5
  warmup_steps: 3
  active_steps: 2
  num_cycles: 1
@@ -0,0 +1,110 @@
# Config for single device full finetuning in full_finetune_single_device.py
# using a Llama 3.2 3B Base model
#
# This config assumes that you've run the following command before launching
# this run:
#   tune download meta-llama/Llama-3.2-3B --output-dir /Users/ariel/checkpoint/tmp/Llama-3.2-3B --ignore-patterns "original/consolidated.00.pth"
#
# The default config uses an optimizer from bitsandbytes. If you do not have it installed,
# you can install it with
#   pip install bitsandbytes
#
# To launch on a single device, run the following command from root:
#   tune run full_finetune_single_device --config llama3_2/3B_full_single_device
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
#   tune run full_finetune_single_device --config llama3_2/3B_full_single_device checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# This config works only for training on a single device.

# Tokenizer
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /Users/ariel/checkpoint/tmp/Llama-3.2-3B/original/tokenizer.model
  max_seq_len: null

# Dataset
dataset:
  # _component_: torchtune.datasets.math_dataset
  _component_: torchtune.datasets.gsm8k_dataset
seed: null
shuffle: True

# Model Arguments
model:
  _component_: torchtune.models.llama3_2.llama3_2_3b

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /Users/ariel/checkpoint/tmp/Llama-3.2-3B/
  checkpoint_files: [
    model-00001-of-00002.safetensors,
    model-00002-of-00002.safetensors,
  ]
  recipe_checkpoint: null
  output_dir: /Users/ariel/checkpoint/tmp/Llama-3.2-3B/
  model_type: LLAMA3_2
resume_from_checkpoint: False

# Fine-tuning arguments
batch_size: 4
max_generated_tokens: 100
top_k: null
temperature: 0.7

epochs: 1
optimizer:
  _component_: bitsandbytes.optim.PagedAdamW8bit
  lr: 2e-5
loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
gradient_accumulation_steps: 1  # Use to increase effective batch size
optimizer_in_bwd: True  # True saves memory. Requires gradient_accumulation_steps=1
compile: False  # pytorch compile, set to true for better perf/memory

# Training environment
device: mps

# Memory management
enable_activation_checkpointing: True  # True reduces memory
enable_activation_offloading: False  # True reduces memory

# Reduced precision
dtype: bf16

output_dir: /Users/ariel/checkpoint/tmp/full-llama3.2-finetune

# Logging
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}
log_every_n_steps: 1
log_peak_memory_stats: False  # for mac

# Profiler (disabled)
profiler:
  _component_: torchtune.training.setup_torch_profiler
  enabled: False

  # Output directory of trace artifacts
  output_dir: ${output_dir}/profiling_outputs

  # `torch.profiler.ProfilerActivity` types to trace
  cpu: True
  cuda: True

  # Trace options passed to `torch.profiler.profile`
  profile_memory: False
  with_stack: False
  record_shapes: True
  with_flops: False

  # `torch.profiler.schedule` options:
  # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
  wait_steps: 5
  warmup_steps: 3
  active_steps: 2
  num_cycles: 1
> **Review:** I would possibly rename this config to something like `3B_full_grpo.yaml`. (Also related to my other comments: I think the other configs you've added may not be necessary. In that case you can just put this in `recipes/configs/dev/3B_full_grpo.yaml` without needing a dedicated GRPO directory.) Separately, looking at our dev recipes/configs, I realize the structure there is a bit inconsistent, so sorry for any confusion that may have caused.