
R1-Style distributed GRPO #2326

Merged: 117 commits, Feb 21, 2025
Changes from 40 commits

Commits (117):
1c29d67
RL starter code
RedTachyon Jan 25, 2025
ed544da
Add gsm8k
RedTachyon Jan 27, 2025
6ca2c38
Notebooks
RedTachyon Jan 27, 2025
7e39b69
Distributed dev progress
RedTachyon Jan 27, 2025
55b1c65
Decent progress on R1 RL
RedTachyon Jan 29, 2025
aa34954
PoC training loop, but not working - code checkpoint
RedTachyon Jan 30, 2025
0cc7795
8B recipe, sorta running?
RedTachyon Jan 31, 2025
b11e742
Merge branch 'main' into grpo
RedTachyon Jan 31, 2025
ce7e8ce
Merge pull request #1 from RedTachyon/grpo
RedTachyon Jan 31, 2025
736b14f
Some updates, some progress
RedTachyon Feb 1, 2025
eb3d5b9
Multi-node training, new reward shaping, success metric
RedTachyon Feb 1, 2025
9df4b38
Synchronize metrics, fix parsing, more memory management
RedTachyon Feb 2, 2025
3d39ffc
Sync rewards and successes between processes
RedTachyon Feb 3, 2025
30fa65c
Add filter kwargs and partition for easier dataset filtering
RedTachyon Feb 3, 2025
e8f19f2
Merge pull request #2 from RedTachyon/grpo
RedTachyon Feb 3, 2025
499c013
Reorganize methods
RedTachyon Feb 4, 2025
3ca4f49
Batched logit to logprob conversion
RedTachyon Feb 5, 2025
300c117
Remove old notes
RedTachyon Feb 6, 2025
3cf169c
Revert config changes
RedTachyon Feb 6, 2025
7d9d37a
More config cleanup
RedTachyon Feb 6, 2025
654eb56
Recipe cleanup
RedTachyon Feb 6, 2025
c75cb7a
Remove unnecessary PPO change
RedTachyon Feb 6, 2025
1f6be85
Cleanup
RedTachyon Feb 6, 2025
cf770d3
Merge branch 'pytorch:main' into main
RedTachyon Feb 6, 2025
1fd86c8
Remove redundant code
RedTachyon Feb 6, 2025
e198825
More redundant code
RedTachyon Feb 6, 2025
474c8dc
Reorganize recipes
RedTachyon Feb 6, 2025
60d7cd9
Some cleanup of RL dataset and GSM8k
RedTachyon Feb 7, 2025
8d018ba
Pre-commit cleanup
RedTachyon Feb 7, 2025
c9a01cb
Remove MATH dataset for now
RedTachyon Feb 7, 2025
c17bfc0
Properly remove MATH dataset
RedTachyon Feb 7, 2025
b312af8
Docstrings
RedTachyon Feb 7, 2025
f7c0929
Reorganize some recipes, add sbatch for SFT
RedTachyon Feb 7, 2025
f6c7e53
Remove unused 8B configs, add another sbatch
RedTachyon Feb 7, 2025
72cb4bd
Recipes leanup
RedTachyon Feb 7, 2025
a9112b8
GRPO recipe cleanup
RedTachyon Feb 7, 2025
b089ff2
Final MVP bugfixes
RedTachyon Feb 7, 2025
4870213
Remove old unused test
RedTachyon Feb 7, 2025
2dd6f60
Pre-commit
RedTachyon Feb 7, 2025
7b58712
Stop token handling for both single and multi device
RedTachyon Feb 7, 2025
b010ef6
Update recipes/configs/dev/grpo/3B_full_rl.yaml
RedTachyon Feb 8, 2025
603d16a
Remove redundant comment
RedTachyon Feb 8, 2025
d7269e4
Delete recipes/configs/llama3_2/3B_full_rl_single_device_mps.yaml
RedTachyon Feb 8, 2025
9978a44
Fix function arguments in optimizer setup
RedTachyon Feb 8, 2025
d710200
Update recipes/configs/dev/grpo/3B_full_rl.yaml
RedTachyon Feb 8, 2025
352c4fb
Update torchtune/rlhf/loss/grpo.py
RedTachyon Feb 8, 2025
8f60178
PPO -> GRPO
RedTachyon Feb 8, 2025
bb8b97d
RL -> GRPO (recipe
RedTachyon Feb 8, 2025
44aeb85
Make sure we're not training with float16
RedTachyon Feb 8, 2025
86de099
Remove chunked loss logic
RedTachyon Feb 8, 2025
120b1d2
Save rank and world_size in the recipe
RedTachyon Feb 8, 2025
f3e9b77
Rename R1Trajectory to GRPOTrajectory, other cleanup
RedTachyon Feb 8, 2025
b856695
| None -> Optional
RedTachyon Feb 8, 2025
e61894b
Fix `return_logits==False` edge case
RedTachyon Feb 8, 2025
96dc9c9
Undo an accidental change in generation
RedTachyon Feb 8, 2025
b1cfc3b
Update generate_trajectory docstring
RedTachyon Feb 8, 2025
8a62d1e
Reenable reference network
RedTachyon Feb 10, 2025
7394cc7
Remove optimizer_in_bwd
RedTachyon Feb 10, 2025
24cd238
Remove activation offloading
RedTachyon Feb 10, 2025
67458f2
(docstring) Reward model -> reward function
RedTachyon Feb 10, 2025
d9c37b8
Move reward function to a new file
RedTachyon Feb 10, 2025
18baeae
Remove dead code
RedTachyon Feb 10, 2025
646a2ce
Remove redundant logging
RedTachyon Feb 10, 2025
f641161
Remove question from the reward function
RedTachyon Feb 10, 2025
9f4fb88
Update recipes/dev/grpo_full_finetune_distributed.py
RedTachyon Feb 10, 2025
123883a
_grpo_step -> grpo_step
RedTachyon Feb 10, 2025
d9cf499
Remove breakpoints and comments
RedTachyon Feb 10, 2025
98df9a7
Remove breakpoints and comments
RedTachyon Feb 10, 2025
b068132
Docstring for the reward function
RedTachyon Feb 10, 2025
738dc2d
Handle reference checkpoint separately
RedTachyon Feb 10, 2025
759ff8b
Remove mentions of activation offloading
RedTachyon Feb 10, 2025
239d382
Fix messed up loss
RedTachyon Feb 10, 2025
957671c
Fix messed up loss, barriers to keep things in sync
RedTachyon Feb 10, 2025
357aa43
Delete max_steps_per_epoch and gradient accumulation, simplify inner …
RedTachyon Feb 11, 2025
3c54d01
Pre-commit
RedTachyon Feb 11, 2025
0d77799
Reorganize recipes
RedTachyon Feb 11, 2025
0f8df07
Remove dead settings from the GRPO config
RedTachyon Feb 11, 2025
184aeef
Use DiskLogger in GRPO
RedTachyon Feb 11, 2025
3018eeb
Use DiskLogger in SFT for GRPO
RedTachyon Feb 11, 2025
d29c455
Recipe name in logging
RedTachyon Feb 11, 2025
b9a56b9
Remove redundant logging
RedTachyon Feb 11, 2025
b143b00
Cleaned up official configs
RedTachyon Feb 12, 2025
05cf10d
Merge branch 'main' into main
RedTachyon Feb 12, 2025
be4d96d
Docstrings for GRPO types
RedTachyon Feb 12, 2025
1303695
Merge remote-tracking branch 'arielpublic/main'
RedTachyon Feb 12, 2025
9d5bfa5
Fix checkpointing
RedTachyon Feb 12, 2025
6cbb10a
Update recipes/dev/grpo_full_finetune_distributed.py
RedTachyon Feb 12, 2025
97686e2
Update recipes/dev/grpo_full_finetune_distributed.py
RedTachyon Feb 12, 2025
aeb69cf
Update recipes/dev/grpo_full_finetune_distributed.py
RedTachyon Feb 12, 2025
feeb042
Remove mention of ac
RedTachyon Feb 12, 2025
e7cb937
Pre-commit
RedTachyon Feb 12, 2025
2a75224
Revert generation changes, pre-commit
RedTachyon Feb 12, 2025
5010f1d
Fix RL collate function
RedTachyon Feb 12, 2025
711ee6b
Remove resharding from config
RedTachyon Feb 12, 2025
aad81fb
| -> Optional in rl dataset
RedTachyon Feb 13, 2025
53cd129
| -> Optional in GSM8k dataset
RedTachyon Feb 13, 2025
4dbe231
Merge branch 'pytorch:main' into main
RedTachyon Feb 16, 2025
3a8c6de
Experimental stuff
RedTachyon Feb 17, 2025
9677df1
Additional experimental stuff
RedTachyon Feb 17, 2025
9340f89
Move experimental code to /dev
RedTachyon Feb 17, 2025
53ee98f
Properly move experimental code to /dev
RedTachyon Feb 17, 2025
682ebef
Remove recursive_reshard from the public API
RedTachyon Feb 17, 2025
8dd7546
Separate optimized generation function, small fixes
RedTachyon Feb 17, 2025
0939c94
Undo generation changes in the main function
RedTachyon Feb 17, 2025
81a7765
Fix custom generate
RedTachyon Feb 17, 2025
e41a520
Pre-commit
RedTachyon Feb 17, 2025
45084e0
Fix SFT dataset path
RedTachyon Feb 18, 2025
8283230
Merge branch 'pytorch:main' into main
RedTachyon Feb 19, 2025
987e971
Merge branch 'experiments' into main
RedTachyon Feb 19, 2025
7a03139
Revert "Merge branch 'experiments' into main"
RedTachyon Feb 19, 2025
3fc43d7
Update recipes/configs/dev/3B_full_grpo.yaml
RedTachyon Feb 21, 2025
dde5fd8
Update recipes/configs/dev/3B_sft_for_grpo.yaml
RedTachyon Feb 21, 2025
2fddf9c
Update recipes/configs/dev/3B_full_grpo.yaml
RedTachyon Feb 21, 2025
aead54e
Remove redundant async checkpointing code
RedTachyon Feb 21, 2025
16ad525
Remove some redundant clones
RedTachyon Feb 21, 2025
70886b2
Add a generation comment
RedTachyon Feb 21, 2025
2ba4a97
Pre-commit
RedTachyon Feb 21, 2025
139 changes: 139 additions & 0 deletions recipes/configs/dev/grpo/3B_full_rl.yaml
@@ -0,0 +1,139 @@
# Config for multi-node GRPO in dev/grpo_full_finetune_distributed.py
Contributor:

I would possibly rename this config to something like 3B_full_grpo.yaml. (Also related to my other comments: I think the other configs you've added may not be necessary. In that case you can just put this in recipes/configs/dev/3B_full_grpo.yaml without needing a dedicated GRPO directory.) Separately, looking at our dev recipes/configs, I realize the structure there is a bit inconsistent, so sorry for any confusion that may have caused.

# using a Llama3.2 3B Base model
#
# This config assumes that you've run the following command before launching
# this run:
# tune download meta-llama/Meta-Llama-3.2-3B --output-dir /tmp/Meta-Llama-3.2-3B-Instruct --ignore-patterns "original/consolidated.00.pth"
#
# It can be beneficial to first train the base model with SFT using the 3B_sft recipe.
Contributor:

You can do something similar to what's in our knowledge distillation configs and just explicitly add the command here for convenience. Ref

#
# To launch on 4 devices, run the following command from root:
# tune run --nproc_per_node 4 dev/grpo_full_finetune_distributed --config dev/grpo/3B_full_rl
Contributor:

This should be the following, based on how it's defined in the recipe registry, I think:

Suggested change:
- # tune run --nproc_per_node 4 dev/grpo_full_finetune_distributed --config dev/grpo/3B_full_rl
+ # tune run --nproc_per_node 4 dev/grpo_full_finetune_distributed --config dev/3B_full_rl

#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
# tune run --nproc_per_node 4 dev/grpo_full_finetune_distributed --config dev/grpo/3B_full_rl checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# This config works best when the model is being fine-tuned on 2+ GPUs.
#
# Furthermore, you can launch it on multiple nodes by going to recipes/dev/ and using
# sbatch multinode_grpo.sbatch

base_dir: /tmp
Contributor:

How do you feel about deleting base_dir and just keeping output_dir / base_model_path, so we can remove one extra param?

Contributor Author:

So my personal preference is actually keeping base_dir or something similar, but for the "official" codebase it probably makes sense to keep it more uniform.

My reasoning is that putting everything in /tmp/ isn't a terrible default when you have no information about the user's filesystem, but at the same time, most of the time I don't want to put everything in /tmp. So it's much easier for me to change this in one place instead of tracking down every /tmp instance in the config.
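
To make the "change it in one place" point concrete: torchtune configs are resolved with OmegaConf, so the ${base_dir} interpolations below resolve lazily. A minimal sketch (values copied from this config; the /data/me override path is purely illustrative):

```python
# Minimal sketch of how the ${base_dir} interpolations in this config resolve.
from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "base_dir": "/tmp",
    "name": "grpo_sft_start",
    "output_dir": "${base_dir}/checkpoints/${name}",
    "base_model_path": "${base_dir}/Llama-3.2-3B",
})

print(cfg.output_dir)       # /tmp/checkpoints/grpo_sft_start

# Overriding base_dir in one place (e.g. via a CLI override) updates every derived path:
cfg.base_dir = "/data/me"   # illustrative path
print(cfg.base_model_path)  # /data/me/Llama-3.2-3B
```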

name: grpo_sft_start

output_dir: ${base_dir}/checkpoints/${name}
base_model_path: ${base_dir}/Llama-3.2-3B # Use this to train from the base model
#base_model_path: ${base_dir}/llama3B_gsm8k_sft_part0/epoch_0 # Use this to train from the slightly trained SFT model

# Tokenizer
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: ${base_model_path}/original/tokenizer.model
  max_seq_len: null

# Dataset
dataset:
  _component_: torchtune.datasets.gsm8k_dataset
  partition: 1-9/10
  # packed: False # True increases speed
seed: null
shuffle: False

# Model Arguments
model:
  _component_: torchtune.models.llama3_2.llama3_2_3b

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: ${base_model_path}
  checkpoint_files: [
    model-00001-of-00002.safetensors, # Add ft- if starting from SFT model
    model-00002-of-00002.safetensors,
  ]
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA3
resume_from_checkpoint: False
save_every_n_epochs: 1

# Fine-tuning arguments
batch_size: 1
grpo_samples: 16
forward_batch_size: 1
max_generated_tokens: 512
top_k: null
temperature: 1.0

fsdp_reshard_after_forward: True
Contributor:

why was this necessary?

Contributor Author:

Probably not necessary. I think my process was that I initially set it to False, which sped up generation (by a decent margin too, around 25% in the data generation phase), but then it caused some irregular memory issues. I didn't want to deal with memory profiling just yet, so I set it back to True to check that the pipeline even works.

In any case, this will likely be an impactful parameter to check when optimizing the recipe for performance, and I'm guessing it's not very commonly used in other recipes, since here we're doing an unusually large amount of generation.
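
For context on what this knob trades off, here is a rough sketch of how reshard_after_forward is typically applied with PyTorch's FSDP2 fully_shard API. This is illustrative only, not the exact torchtune wiring; the shard_model helper and the per-layer loop are assumptions.

```python
# Rough sketch (assumed helper, not torchtune's actual sharding code).
from torch.distributed._composable.fsdp import fully_shard


def shard_model(model, reshard_after_forward: bool = True):
    # reshard_after_forward=True frees each layer's unsharded parameters right
    # after its forward pass: lower peak memory, but the parameters must be
    # re-all-gathered on every generation step. False keeps them gathered,
    # which speeds up the long generation phase at the cost of memory.
    for layer in model.layers:  # assumes a TransformerDecoder-style .layers attribute
        fully_shard(layer, reshard_after_forward=reshard_after_forward)
    fully_shard(model, reshard_after_forward=reshard_after_forward)
    return model
```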


ppo_epochs: 1
ppo_batch_size: ${grpo_samples} # single-step for the "simple" loss
Contributor:

Need to check where this is used, but I wonder if 'ppo' is a good name here. I also find it a bit confusing, without context, what the difference is between ppo_epochs and epochs, or between batch_size and ppo_batch_size.

Contributor Author:

It probably makes sense to rename it, but this is an awkward area of weird overlaps between naming and active research.

As an algorithm, GRPO is basically the same as PPO, but with a different advantage estimation method. Or at least that's how it's defined in the R1 paper.
However, the implementation that's validated to run right now doesn't follow that math super closely, and instead follows the interpretation introduced in TRL, which has its own justification for that interpretation based on DeepSeek's paper.

The general idea in PPO is that we can take multiple gradient steps with a single batch of data: ppo_epochs indicates how many times we go through the batch, and ppo_batch_size is the minibatch size within it (what we actually process at once).

The implementation that is currently running simplifies a few elements of the PPO/GRPO loss, which also means taking only a single gradient step per data batch. This is what works in TRL, and what I validated to work here in torchtune. It's very likely that some other variations will also work, but I've yet to run those experiments. See the sketch below.
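
To make the description above concrete, here is a minimal sketch of a single-step, TRL-style GRPO objective: rewards are normalized within each group of grpo_samples completions to form advantages, and because only one gradient step is taken per batch, the importance ratio is identically 1 in value while still carrying the policy gradient. This is an illustration, not the actual torchtune.rlhf.loss.GRPOSimpleLoss; shapes are assumed and padding masks are omitted.

```python
# Illustrative sketch only; see torchtune.rlhf.loss.GRPOSimpleLoss for the real implementation.
import torch


def grpo_simple_loss(pi_logprobs, ref_logprobs, rewards, kl_coeff=0.04):
    # rewards:                   (batch, grpo_samples)       one scalar per completion
    # pi_logprobs/ref_logprobs:  (batch, grpo_samples, seq)  per-token log-probabilities

    # Group-relative advantages: normalize rewards within each group of samples.
    advantages = (rewards - rewards.mean(dim=-1, keepdim=True)) / (
        rewards.std(dim=-1, keepdim=True) + 1e-4
    )
    advantages = advantages.unsqueeze(-1)  # broadcast over the token dimension

    # Single gradient step per batch => the policy equals the sampling policy, so the
    # importance ratio is 1 in value but still carries the gradient (the TRL-style trick).
    ratio = torch.exp(pi_logprobs - pi_logprobs.detach())
    policy_loss = -(ratio * advantages).mean()

    # k3-style KL estimate against the frozen reference policy.
    kl = torch.exp(ref_logprobs - pi_logprobs) - (ref_logprobs - pi_logprobs) - 1.0
    return policy_loss + kl_coeff * kl.mean()
```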


num_steps: 10000

clip_grad_norm: 1.0

epochs: 10
optimizer:
  _component_: torch.optim.AdamW
  lr: 1e-5
  fused: True
lr_scheduler:
  _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
  num_warmup_steps: 50
loss:
  _component_: torchtune.rlhf.loss.GRPOSimpleLoss
  kl_coeff: 0.04
  epsilon: 0.2
max_steps_per_epoch: null
gradient_accumulation_steps: 1 # Use to increase virtual batch size

# Training env
device: cuda

# Memory management
enable_activation_checkpointing: True # True reduces memory
enable_activation_offloading: True # True reduces memory
compile: False # pytorch compile, set to true for better perf/memory
optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1
Contributor:

Can also consider just removing this from the recipe for simplicity, unless you find that it's necessary.


# Reduced precision
dtype: bf16
custom_sharded_layers: ['decoder.tok_embeddings']


# Logging
metric_logger:
  _component_: torchtune.training.metric_logging.WandBLogger
Contributor:

At least prior to landing, this should just use DiskLogger as in our other configs (not everyone necessarily has WandB set up).

  project: grpo_gsm8k_torchtune
  log_dir: ${output_dir}/logs
  name: ${name}
log_every_n_steps: 1
log_peak_memory_stats: True

# Profiler (disabled)
profiler:
  _component_: torchtune.training.setup_torch_profiler
  enabled: True

  #Output directory of trace artifacts
  output_dir: ${output_dir}/profiling_outputs

  #`torch.profiler.ProfilerActivity` types to trace
  cpu: True
  cuda: True

  #trace options passed to `torch.profiler.profile`
  profile_memory: True
  with_stack: True
  record_shapes: True
  with_flops: False

  # `torch.profiler.schedule` options:
  # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
  wait_steps: 5
  warmup_steps: 3
  active_steps: 2
  num_cycles: 1
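
As a side note on the profiler section above: the wait_steps / warmup_steps / active_steps / num_cycles keys map onto torch.profiler.schedule exactly as the inline comment states. A minimal hand-rolled sketch of that mapping (torchtune's setup_torch_profiler does this wiring internally; the training-loop stub is a placeholder):

```python
# Sketch of the schedule mapping used by the profiler config above.
from torch.profiler import ProfilerActivity, profile, schedule

prof_schedule = schedule(
    wait=5,     # wait_steps: idle for the first 5 steps
    warmup=3,   # warmup_steps: profiler running, results discarded
    active=2,   # active_steps: steps that are actually recorded
    repeat=1,   # num_cycles: number of wait/warmup/active cycles
)

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],  # cpu / cuda
    schedule=prof_schedule,
    profile_memory=True,
    with_stack=True,
    record_shapes=True,
    with_flops=False,
) as prof:
    for step in range(10):
        pass         # placeholder for one training step
        prof.step()  # advance the profiler schedule
```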
112 changes: 112 additions & 0 deletions recipes/configs/dev/grpo/3B_sft.yaml
@@ -0,0 +1,112 @@
# Config for multi-device SFT for reasoning in full_finetune_distributed.py
# using a Llama3.2 3B Base model
#
# This config assumes that you've run the following command before launching
# this run:
# tune download meta-llama/Meta-Llama-3.2-3B --output-dir /tmp/Meta-Llama-3.2-3B-Instruct --ignore-patterns "original/consolidated.00.pth"
#
# To launch on 4 devices, run the following command from root:
# tune run --nproc_per_node 4 full_finetune_distributed --config dev/grpo/3B_sft
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
# tune run --nproc_per_node 4 full_finetune_distributed --config dev/grpo/3B_sft checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# This config works best when the model is being fine-tuned on 2+ GPUs.


base_dir: /tmp
name: llama3B_gsm8k_sft_part0

output_dir: ${base_dir}/${name}

# Tokenizer
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: ${base_dir}/Llama-3.2-3B/original/tokenizer.model
  max_seq_len: null

# Dataset
dataset:
  _component_: torchtune.datasets.gsm8k_sft
  partition: 0-0/10
seed: null
shuffle: True

# Model Arguments
model:
  _component_: torchtune.models.llama3_2.llama3_2_3b

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: ${base_dir}/Llama-3.2-3B/
  checkpoint_files: [
    model-00001-of-00002.safetensors,
    model-00002-of-00002.safetensors,
  ]
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA3
resume_from_checkpoint: False

# Fine-tuning arguments
batch_size: 2
epochs: 1

optimizer:
  _component_: torch.optim.AdamW
  lr: 1e-5
  fused: True
loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
clip_grad_norm: null
compile: False # torch.compile the model + loss, True increases speed + decreases memory
optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1
gradient_accumulation_steps: 1 # Use to increase effective batch size

# Training env
device: cuda

# Memory management
enable_activation_checkpointing: True # True reduces memory
enable_activation_offloading: False # True reduces memory

# Reduced precision
dtype: bf16

# Logging
metric_logger:
  _component_: torchtune.training.metric_logging.WandBLogger
  project: gsm8k_sft
  log_dir: ${output_dir}/logs
  name: ${name}
log_every_n_steps: 1
log_peak_memory_stats: True


# Profiler (disabled)
profiler:
  _component_: torchtune.training.setup_torch_profiler
  enabled: False

  #Output directory of trace artifacts
  output_dir: ${output_dir}/profiling_outputs

  #`torch.profiler.ProfilerActivity` types to trace
  cpu: True
  cuda: True

  #trace options passed to `torch.profiler.profile`
  profile_memory: False
  with_stack: False
  record_shapes: True
  with_flops: False

  # `torch.profiler.schedule` options:
  # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
  wait_steps: 5
  warmup_steps: 3
  active_steps: 2
  num_cycles: 1
110 changes: 110 additions & 0 deletions recipes/configs/llama3_2/3B_full_rl_single_device_mps.yaml
@@ -0,0 +1,110 @@
# Config for single device full finetuning in full_finetune_single_device.py
# using a Llama3.2 3B Instruct model
#
# This config assumes that you've run the following command before launching
# this run:
# tune download meta-llama/Llama-3.2-3B --output-dir /Users/ariel/checkpoint/tmp/Llama-3.2-3B --ignore-patterns "original/consolidated.00.pth"
#
# The default config uses an optimizer from bitsandbytes. If you do not have it installed,
# you can install it with
# pip install bitsandbytes
#
# To launch on a single device, run the following command from root:
# tune run full_finetune_single_device --config llama3_2/3B_full_single_device
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
# tune run full_finetune_single_device --config llama3_2/3B_full_single_device checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# This config works only for training on single device.


# Tokenizer
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /Users/ariel/checkpoint/tmp/Llama-3.2-3B/original/tokenizer.model
  max_seq_len: null

# Dataset
dataset:
  # _component_: torchtune.datasets.math_dataset
  _component_: torchtune.datasets.gsm8k_dataset
seed: null
shuffle: True

# Model Arguments
model:
  _component_: torchtune.models.llama3_2.llama3_2_3b

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /Users/ariel/checkpoint/tmp/Llama-3.2-3B/
  checkpoint_files: [
    model-00001-of-00002.safetensors,
    model-00002-of-00002.safetensors,
  ]
  recipe_checkpoint: null
  output_dir: /Users/ariel/checkpoint/tmp/Llama-3.2-3B/
  model_type: LLAMA3_2
resume_from_checkpoint: False

# Fine-tuning arguments
batch_size: 4
max_generated_tokens: 100
top_k: null
temperature: 0.7

epochs: 1
optimizer:
  _component_: bitsandbytes.optim.PagedAdamW8bit
  lr: 2e-5
loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
gradient_accumulation_steps: 1 # Use to increase virtual batch size
optimizer_in_bwd: True # True saves memory. Requires gradient_accumulation_steps=1
compile: False # pytorch compile, set to true for better perf/memory

# Training environment
device: mps

# Memory management
enable_activation_checkpointing: True # True reduces memory
enable_activation_offloading: False # True reduces memory

# Reduced precision
dtype: bf16

# Logging
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}
output_dir: /Users/ariel/checkpoint/tmp/full-llama3.2-finetune
log_every_n_steps: 1
log_peak_memory_stats: False # for mac

# Profiler (disabled)
profiler:
  _component_: torchtune.training.setup_torch_profiler
  enabled: False

  #Output directory of trace artifacts
  output_dir: ${output_dir}/profiling_outputs

  #`torch.profiler.ProfilerActivity` types to trace
  cpu: True
  cuda: True

  #trace options passed to `torch.profiler.profile`
  profile_memory: False
  with_stack: False
  record_shapes: True
  with_flops: False

  # `torch.profiler.schedule` options:
  # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
  wait_steps: 5
  warmup_steps: 3
  active_steps: 2
  num_cycles: 1