R1-Style distributed GRPO #2326

Merged: 117 commits, Feb 21, 2025
Changes from 104 commits
Commits
1c29d67
RL starter code
RedTachyon Jan 25, 2025
ed544da
Add gsm8k
RedTachyon Jan 27, 2025
6ca2c38
Notebooks
RedTachyon Jan 27, 2025
7e39b69
Distributed dev progress
RedTachyon Jan 27, 2025
55b1c65
Decent progress on R1 RL
RedTachyon Jan 29, 2025
aa34954
PoC training loop, but not working - code checkpoint
RedTachyon Jan 30, 2025
0cc7795
8B recipe, sorta running?
RedTachyon Jan 31, 2025
b11e742
Merge branch 'main' into grpo
RedTachyon Jan 31, 2025
ce7e8ce
Merge pull request #1 from RedTachyon/grpo
RedTachyon Jan 31, 2025
736b14f
Some updates, some progress
RedTachyon Feb 1, 2025
eb3d5b9
Multi-node training, new reward shaping, success metric
RedTachyon Feb 1, 2025
9df4b38
Synchronize metrics, fix parsing, more memory management
RedTachyon Feb 2, 2025
3d39ffc
Sync rewards and successes between processes
RedTachyon Feb 3, 2025
30fa65c
Add filter kwargs and partition for easier dataset filtering
RedTachyon Feb 3, 2025
e8f19f2
Merge pull request #2 from RedTachyon/grpo
RedTachyon Feb 3, 2025
499c013
Reorganize methods
RedTachyon Feb 4, 2025
3ca4f49
Batched logit to logprob conversion
RedTachyon Feb 5, 2025
300c117
Remove old notes
RedTachyon Feb 6, 2025
3cf169c
Revert config changes
RedTachyon Feb 6, 2025
7d9d37a
More config cleanup
RedTachyon Feb 6, 2025
654eb56
Recipe cleanup
RedTachyon Feb 6, 2025
c75cb7a
Remove unnecessary PPO change
RedTachyon Feb 6, 2025
1f6be85
Cleanup
RedTachyon Feb 6, 2025
cf770d3
Merge branch 'pytorch:main' into main
RedTachyon Feb 6, 2025
1fd86c8
Remove redundant code
RedTachyon Feb 6, 2025
e198825
More redundant code
RedTachyon Feb 6, 2025
474c8dc
Reorganize recipes
RedTachyon Feb 6, 2025
60d7cd9
Some cleanup of RL dataset and GSM8k
RedTachyon Feb 7, 2025
8d018ba
Pre-commit cleanup
RedTachyon Feb 7, 2025
c9a01cb
Remove MATH dataset for now
RedTachyon Feb 7, 2025
c17bfc0
Properly remove MATH dataset
RedTachyon Feb 7, 2025
b312af8
Docstrings
RedTachyon Feb 7, 2025
f7c0929
Reorganize some recipes, add sbatch for SFT
RedTachyon Feb 7, 2025
f6c7e53
Remove unused 8B configs, add another sbatch
RedTachyon Feb 7, 2025
72cb4bd
Recipes cleanup
RedTachyon Feb 7, 2025
a9112b8
GRPO recipe cleanup
RedTachyon Feb 7, 2025
b089ff2
Final MVP bugfixes
RedTachyon Feb 7, 2025
4870213
Remove old unused test
RedTachyon Feb 7, 2025
2dd6f60
Pre-commit
RedTachyon Feb 7, 2025
7b58712
Stop token handling for both single and multi device
RedTachyon Feb 7, 2025
b010ef6
Update recipes/configs/dev/grpo/3B_full_rl.yaml
RedTachyon Feb 8, 2025
603d16a
Remove redundant comment
RedTachyon Feb 8, 2025
d7269e4
Delete recipes/configs/llama3_2/3B_full_rl_single_device_mps.yaml
RedTachyon Feb 8, 2025
9978a44
Fix function arguments in optimizer setup
RedTachyon Feb 8, 2025
d710200
Update recipes/configs/dev/grpo/3B_full_rl.yaml
RedTachyon Feb 8, 2025
352c4fb
Update torchtune/rlhf/loss/grpo.py
RedTachyon Feb 8, 2025
8f60178
PPO -> GRPO
RedTachyon Feb 8, 2025
bb8b97d
RL -> GRPO (recipe)
RedTachyon Feb 8, 2025
44aeb85
Make sure we're not training with float16
RedTachyon Feb 8, 2025
86de099
Remove chunked loss logic
RedTachyon Feb 8, 2025
120b1d2
Save rank and world_size in the recipe
RedTachyon Feb 8, 2025
f3e9b77
Rename R1Trajectory to GRPOTrajectory, other cleanup
RedTachyon Feb 8, 2025
b856695
| None -> Optional
RedTachyon Feb 8, 2025
e61894b
Fix `return_logits==False` edge case
RedTachyon Feb 8, 2025
96dc9c9
Undo an accidental change in generation
RedTachyon Feb 8, 2025
b1cfc3b
Update generate_trajectory docstring
RedTachyon Feb 8, 2025
8a62d1e
Reenable reference network
RedTachyon Feb 10, 2025
7394cc7
Remove optimizer_in_bwd
RedTachyon Feb 10, 2025
24cd238
Remove activation offloading
RedTachyon Feb 10, 2025
67458f2
(docstring) Reward model -> reward function
RedTachyon Feb 10, 2025
d9c37b8
Move reward function to a new file
RedTachyon Feb 10, 2025
18baeae
Remove dead code
RedTachyon Feb 10, 2025
646a2ce
Remove redundant logging
RedTachyon Feb 10, 2025
f641161
Remove question from the reward function
RedTachyon Feb 10, 2025
9f4fb88
Update recipes/dev/grpo_full_finetune_distributed.py
RedTachyon Feb 10, 2025
123883a
_grpo_step -> grpo_step
RedTachyon Feb 10, 2025
d9cf499
Remove breakpoints and comments
RedTachyon Feb 10, 2025
98df9a7
Remove breakpoints and comments
RedTachyon Feb 10, 2025
b068132
Docstring for the reward function
RedTachyon Feb 10, 2025
738dc2d
Handle reference checkpoint separately
RedTachyon Feb 10, 2025
759ff8b
Remove mentions of activation offloading
RedTachyon Feb 10, 2025
239d382
Fix messed up loss
RedTachyon Feb 10, 2025
957671c
Fix messed up loss, barriers to keep things in sync
RedTachyon Feb 10, 2025
357aa43
Delete max_steps_per_epoch and gradient accumulation, simplify inner …
RedTachyon Feb 11, 2025
3c54d01
Pre-commit
RedTachyon Feb 11, 2025
0d77799
Reorganize recipes
RedTachyon Feb 11, 2025
0f8df07
Remove dead settings from the GRPO config
RedTachyon Feb 11, 2025
184aeef
Use DiskLogger in GRPO
RedTachyon Feb 11, 2025
3018eeb
Use DiskLogger in SFT for GRPO
RedTachyon Feb 11, 2025
d29c455
Recipe name in logging
RedTachyon Feb 11, 2025
b9a56b9
Remove redundant logging
RedTachyon Feb 11, 2025
b143b00
Cleaned up official configs
RedTachyon Feb 12, 2025
05cf10d
Merge branch 'main' into main
RedTachyon Feb 12, 2025
be4d96d
Docstrings for GRPO types
RedTachyon Feb 12, 2025
1303695
Merge remote-tracking branch 'arielpublic/main'
RedTachyon Feb 12, 2025
9d5bfa5
Fix checkpointing
RedTachyon Feb 12, 2025
6cbb10a
Update recipes/dev/grpo_full_finetune_distributed.py
RedTachyon Feb 12, 2025
97686e2
Update recipes/dev/grpo_full_finetune_distributed.py
RedTachyon Feb 12, 2025
aeb69cf
Update recipes/dev/grpo_full_finetune_distributed.py
RedTachyon Feb 12, 2025
feeb042
Remove mention of ac
RedTachyon Feb 12, 2025
e7cb937
Pre-commit
RedTachyon Feb 12, 2025
2a75224
Revert generation changes, pre-commit
RedTachyon Feb 12, 2025
5010f1d
Fix RL collate function
RedTachyon Feb 12, 2025
711ee6b
Remove resharding from config
RedTachyon Feb 12, 2025
aad81fb
| -> Optional in rl dataset
RedTachyon Feb 13, 2025
53cd129
| -> Optional in GSM8k dataset
RedTachyon Feb 13, 2025
4dbe231
Merge branch 'pytorch:main' into main
RedTachyon Feb 16, 2025
3a8c6de
Experimental stuff
RedTachyon Feb 17, 2025
9677df1
Additional experimental stuff
RedTachyon Feb 17, 2025
9340f89
Move experimental code to /dev
RedTachyon Feb 17, 2025
53ee98f
Properly move experimental code to /dev
RedTachyon Feb 17, 2025
682ebef
Remove recursive_reshard from the public API
RedTachyon Feb 17, 2025
8dd7546
Separate optimized generation function, small fixes
RedTachyon Feb 17, 2025
0939c94
Undo generation changes in the main function
RedTachyon Feb 17, 2025
81a7765
Fix custom generate
RedTachyon Feb 17, 2025
e41a520
Pre-commit
RedTachyon Feb 17, 2025
45084e0
Fix SFT dataset path
RedTachyon Feb 18, 2025
8283230
Merge branch 'pytorch:main' into main
RedTachyon Feb 19, 2025
987e971
Merge branch 'experiments' into main
RedTachyon Feb 19, 2025
7a03139
Revert "Merge branch 'experiments' into main"
RedTachyon Feb 19, 2025
3fc43d7
Update recipes/configs/dev/3B_full_grpo.yaml
RedTachyon Feb 21, 2025
dde5fd8
Update recipes/configs/dev/3B_sft_for_grpo.yaml
RedTachyon Feb 21, 2025
2fddf9c
Update recipes/configs/dev/3B_full_grpo.yaml
RedTachyon Feb 21, 2025
aead54e
Remove redundant async checkpointing code
RedTachyon Feb 21, 2025
16ad525
Remove some redundant clones
RedTachyon Feb 21, 2025
70886b2
Add a generation comment
RedTachyon Feb 21, 2025
2ba4a97
Pre-commit
RedTachyon Feb 21, 2025
140 changes: 140 additions & 0 deletions recipes/configs/dev/3B_full_grpo.yaml
@@ -0,0 +1,140 @@
# Config for multi-node GRPO in dev/grpo_full_finetune_distributed.py
# using a Llama3.2 3B Base model
#
# This config assumes that you've run the following command before launching
# this run:
# tune download meta-llama/Llama-3.2-3B --output-dir /tmp/Llama-3.2-3B --ignore-patterns "original/consolidated.00.pth"
#
# It can be beneficial to first train the base model with SFT using the dev/3B_sft_for_grpo config.
#
# To launch on 4 devices, run the following command from root:
# tune run --nproc_per_node 4 dev/grpo_full_finetune_distributed --config dev/3B_full_grpo
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
# tune run --nproc_per_node 4 dev/grpo_full_finetune_distributed --config dev/3B_full_grpo checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# This config works best when the model is being fine-tuned on 2+ GPUs.
#
# Furthermore, you can launch it on multiple nodes by going to recipes/dev/ and using
# sbatch multinode_grpo.sbatch

name: grpo_llama3b

output_dir: /tmp/checkpoints/${name}
base_model_path: /tmp/llama3B_gsm8k_sft_part0/epoch_0 # Train from the lightly SFT-tuned checkpoint produced by dev/3B_sft_for_grpo

# Tokenizer
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /tmp/Llama-3.2-3B/original/tokenizer.model
  max_seq_len: null

# Dataset
dataset:
  _component_: torchtune.dev.grpo.gsm8k.gsm8k_dataset
  partition: 1-9/10
seed: null
shuffle: False

# Model Arguments
model:
  _component_: torchtune.models.llama3_2.llama3_2_3b

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: ${base_model_path}
  checkpoint_files: [
    model-00001-of-00002.safetensors,
    model-00002-of-00002.safetensors,
  ]
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA3


ref_checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: ${base_model_path}
  checkpoint_files: [
    model-00001-of-00002.safetensors,
    model-00002-of-00002.safetensors,
  ]
  recipe_checkpoint: null
  output_dir: ${output_dir}/ref # shouldn't be used?
  model_type: LLAMA3


resume_from_checkpoint: False
save_every_n_epochs: 1

# Fine-tuning arguments
batch_size: 1
grpo_samples: 16
forward_batch_size: 1
max_generated_tokens: 512
top_k: null
temperature: 1.0

ppo_epochs: 1

num_steps: 200

clip_grad_norm: 1.0

epochs: 10
optimizer:
  _component_: torch.optim.AdamW
  lr: 1e-5
  fused: True
lr_scheduler:
  _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
  num_warmup_steps: 50
loss:
  _component_: torchtune.dev.grpo.loss.GRPOSimpleLoss
  kl_coeff: 0.01
  epsilon: 0.2

# Training env
device: cuda

# Memory management
enable_activation_checkpointing: True # True reduces memory
compile: False # pytorch compile, set to true for better perf/memory

# Reduced precision
dtype: bf16


# Logging
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True

# Profiler (enabled in this config)
profiler:
  _component_: torchtune.training.setup_torch_profiler
  enabled: True

  # Output directory of trace artifacts
  output_dir: ${output_dir}/profiling_outputs

  # `torch.profiler.ProfilerActivity` types to trace
  cpu: True
  cuda: True

  # Trace options passed to `torch.profiler.profile`
  profile_memory: True
  with_stack: True
  record_shapes: True
  with_flops: False

  # `torch.profiler.schedule` options:
  # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
  wait_steps: 5
  warmup_steps: 3
  active_steps: 2
  num_cycles: 1
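The loss block above wires up torchtune.dev.grpo.loss.GRPOSimpleLoss with kl_coeff: 0.01 and epsilon: 0.2, applied to grpo_samples: 16 completions per prompt. As a rough orientation for what those two knobs control, here is a minimal sketch of a GRPO-style objective: group-normalized advantages, a PPO-style clipped ratio, and a KL penalty against the frozen reference policy. The function and argument names are invented for illustration; this is not torchtune's actual implementation.

import torch


def group_advantages(rewards: torch.Tensor) -> torch.Tensor:
    # rewards: [num_prompts, grpo_samples]; each prompt's group of completions is
    # normalized by its own mean and std (the "group-relative" part of GRPO).
    return (rewards - rewards.mean(dim=1, keepdim=True)) / (
        rewards.std(dim=1, keepdim=True) + 1e-4
    )


def grpo_simple_loss_sketch(
    logprobs: torch.Tensor,      # [B, T] token log-probs under the current policy
    old_logprobs: torch.Tensor,  # [B, T] token log-probs recorded at generation time
    ref_logprobs: torch.Tensor,  # [B, T] token log-probs under the frozen reference model
    advantages: torch.Tensor,    # [B] one group-normalized advantage per completion
    mask: torch.Tensor,          # [B, T] True for generated (response) tokens
    epsilon: float = 0.2,        # clip range; matches `epsilon` in the config
    kl_coeff: float = 0.01,      # KL penalty weight; matches `kl_coeff` in the config
) -> torch.Tensor:
    ratio = torch.exp(logprobs - old_logprobs)
    clipped = torch.clamp(ratio, 1.0 - epsilon, 1.0 + epsilon)
    adv = advantages.unsqueeze(-1)  # broadcast over the token dimension
    policy_loss = -torch.minimum(ratio * adv, clipped * adv)
    # Per-token KL estimate against the reference policy (the "k3" estimator).
    kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1.0
    per_token = policy_loss + kl_coeff * kl
    mask = mask.float()
    return (per_token * mask).sum() / mask.sum().clamp(min=1.0)

With ppo_epochs: 1, as in this config, the scoring policy is essentially the policy that generated the samples, so the ratio stays close to 1 and the clipping mainly matters if more optimization epochs per generation batch are configured.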
109 changes: 109 additions & 0 deletions recipes/configs/dev/3B_sft_for_grpo.yaml
@@ -0,0 +1,109 @@
# Config for multi-device SFT for reasoning in full_finetune_distributed.py
# using a Llama3.2 3B Base model
#
# This config assumes that you've run the following command before launching
# this run:
# tune download meta-llama/Llama-3.2-3B --output-dir /tmp/Llama-3.2-3B --ignore-patterns "original/consolidated.00.pth"
#
# To launch on 4 devices, run the following command from root:
# tune run --nproc_per_node 4 full_finetune_distributed --config dev/3B_sft_for_grpo
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
# tune run --nproc_per_node 4 full_finetune_distributed --config dev/3B_sft_for_grpo checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# This config works best when the model is being fine-tuned on 2+ GPUs.


name: llama3B_gsm8k_sft_part0

output_dir: /tmp/${name}

# Tokenizer
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /tmp/Llama-3.2-3B/original/tokenizer.model
  max_seq_len: null

# Dataset
dataset:
  _component_: torchtune.dev.grpo.data.gsm8k_sft
  partition: 0-0/10
seed: null
shuffle: True

# Model Arguments
model:
  _component_: torchtune.models.llama3_2.llama3_2_3b

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Llama-3.2-3B/
  checkpoint_files: [
    model-00001-of-00002.safetensors,
    model-00002-of-00002.safetensors,
  ]
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA3
resume_from_checkpoint: False

# Fine-tuning arguments
batch_size: 2
epochs: 1

optimizer:
  _component_: torch.optim.AdamW
  lr: 1e-5
  fused: True
loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
clip_grad_norm: null
compile: False # torch.compile the model + loss, True increases speed + decreases memory
optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1
gradient_accumulation_steps: 1 # Use to increase effective batch size

# Training env
device: cuda

# Memory management
enable_activation_checkpointing: True # True reduces memory
enable_activation_offloading: False # True reduces memory

# Reduced precision
dtype: bf16

# Logging
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True


# Profiler (disabled)
profiler:
  _component_: torchtune.training.setup_torch_profiler
  enabled: False

  # Output directory of trace artifacts
  output_dir: ${output_dir}/profiling_outputs

  # `torch.profiler.ProfilerActivity` types to trace
  cpu: True
  cuda: True

  # Trace options passed to `torch.profiler.profile`
  profile_memory: False
  with_stack: False
  record_shapes: True
  with_flops: False

  # `torch.profiler.schedule` options:
  # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
  wait_steps: 5
  warmup_steps: 3
  active_steps: 2
  num_cycles: 1
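The profiler section that closes both configs feeds these fields into torchtune.training.setup_torch_profiler, and the comment above gives the mapping wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat. For context, here is a rough standalone equivalent written directly against the public torch.profiler API, using the values from the SFT config above (the GRPO config additionally enables profile_memory and with_stack). The training-step function is a placeholder; this is not the recipe's own profiling code.

import torch
from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler


def train_step() -> None:
    pass  # stand-in for one optimizer step of the recipe


prof = profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],  # cpu: True, cuda: True
    schedule=schedule(wait=5, warmup=3, active=2, repeat=1),    # wait/warmup/active/num_cycles
    on_trace_ready=tensorboard_trace_handler("./profiling_outputs"),  # output_dir
    profile_memory=False,
    with_stack=False,
    record_shapes=True,
    with_flops=False,
)

with prof:
    for _ in range(5 + 3 + 2):  # one full wait -> warmup -> active cycle
        train_step()
        prof.step()  # advance the profiler schedule by one step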
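One detail that ties the two configs together is the dataset partition field: the SFT stage trains on 0-0/10 (the first tenth of GSM8K) and the GRPO stage on 1-9/10 (the remaining nine tenths), so the RL data does not overlap with the SFT warm-up split. A plausible reading of the start-end/total spec, sketched purely for illustration (the helper name is invented and this is not torchtune's actual parser):

def partition_indices(spec: str, dataset_len: int) -> range:
    # Interpret "start-end/total" as: split the dataset into `total` contiguous
    # shards and keep shards `start` through `end`, inclusive.
    shard_range, total_str = spec.split("/")
    start_str, end_str = shard_range.split("-")
    start, end, total = int(start_str), int(end_str), int(total_str)
    shard_size = dataset_len // total
    return range(start * shard_size, (end + 1) * shard_size)

# For a hypothetical dataset of 1000 examples:
# partition_indices("0-0/10", 1000) -> range(0, 100)     (SFT split)
# partition_indices("1-9/10", 1000) -> range(100, 1000)  (GRPO split)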