Add support for FSDP2 #3317

Open · wants to merge 5 commits into main
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
@@ -231,7 +231,7 @@ jobs:
      - name: Install dependencies
        run: |
          source .venv/bin/activate
-         uv pip install accelerate==0.34.0
+         uv pip install accelerate==1.6.0
          uv pip install datasets==3.0.0
          uv pip install transformers==4.46.0
          uv pip install ".[dev]"
28 changes: 28 additions & 0 deletions examples/accelerate_configs/fsdp1.yaml
@@ -0,0 +1,28 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
enable_cpu_affinity: false
fsdp_config:
  fsdp_activation_checkpointing: false
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: true
  fsdp_offload_params: false
  fsdp_reshard_after_forward: FULL_SHARD
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: true
  fsdp_version: 1
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
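
For orientation, the same FSDP1 setup can also be built programmatically. A minimal sketch, not part of this diff, assuming accelerate>=1.6.0 where `FullyShardedDataParallelPlugin` accepts `fsdp_version` and `reshard_after_forward` (field names may differ in other releases; values mirror the YAML keys above):

```python
# Sketch only: programmatic equivalent of fsdp1.yaml under accelerate>=1.6.0.
from accelerate import Accelerator, FullyShardedDataParallelPlugin

fsdp_plugin = FullyShardedDataParallelPlugin(
    fsdp_version=1,
    reshard_after_forward="FULL_SHARD",         # FSDP1 keeps the sharding-strategy names
    auto_wrap_policy="transformer_based_wrap",
    backward_prefetch="BACKWARD_PRE",
    forward_prefetch=True,
    state_dict_type="FULL_STATE_DICT",
    cpu_ram_efficient_loading=True,
    sync_module_states=True,
    use_orig_params=True,
)
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)  # same effect as launching with fsdp1.yaml
```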
24 changes: 24 additions & 0 deletions examples/accelerate_configs/fsdp2.yaml
@@ -0,0 +1,24 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
enable_cpu_affinity: false
fsdp_config:
  fsdp_activation_checkpointing: false
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_cpu_ram_efficient_loading: true
  fsdp_offload_params: false
  fsdp_reshard_after_forward: true
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_version: 2
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
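
The FSDP2 config is deliberately smaller: with `fsdp_version: 2`, `fsdp_reshard_after_forward` becomes a plain boolean and the FSDP1-only knobs (`fsdp_use_orig_params`, `fsdp_sync_module_states`, `fsdp_backward_prefetch`, `fsdp_forward_prefetch`) drop out. A hedged sketch of the programmatic equivalent, under the same accelerate>=1.6.0 assumption as above:

```python
# Sketch only: programmatic equivalent of fsdp2.yaml under accelerate>=1.6.0.
from accelerate import Accelerator, FullyShardedDataParallelPlugin

fsdp_plugin = FullyShardedDataParallelPlugin(
    fsdp_version=2,
    reshard_after_forward=True,                 # a bool under FSDP2, not a strategy name
    auto_wrap_policy="transformer_based_wrap",
    state_dict_type="SHARDED_STATE_DICT",
    cpu_ram_efficient_loading=True,
)
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)
```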
25 changes: 0 additions & 25 deletions examples/accelerate_configs/fsdp_qlora.yaml

This file was deleted.

2 changes: 1 addition & 1 deletion setup.py
@@ -72,7 +72,7 @@
__version__ = "0.17.0.dev0" # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)

REQUIRED_PKGS = [
-   "accelerate>=0.34.0",
+   "accelerate>=1.6.0",
    "datasets>=3.0.0",
    "rich",  # rich shouldn't be a required package for trl, we should remove it from here
    "transformers>=4.46.0",
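Since the floor moves from 0.34.0 to 1.6.0, a runtime guard can catch stale environments early. A hedged sketch, not part of this diff:

```python
# Sketch only: fail fast when the installed accelerate predates the new floor.
from importlib.metadata import version as installed_version

from packaging.version import Version

if Version(installed_version("accelerate")) < Version("1.6.0"):
    raise RuntimeError("FSDP2 support requires accelerate>=1.6.0")
```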
2 changes: 1 addition & 1 deletion trl/models/utils.py
@@ -266,7 +266,7 @@ def prepare_fsdp(model, accelerator):
    accelerator.state.fsdp_plugin.set_auto_wrap_policy(model)
    fsdp_plugin = accelerator.state.fsdp_plugin
    kwargs = {
-       "sharding_strategy": fsdp_plugin.sharding_strategy,
+       "sharding_strategy": fsdp_plugin.sharding_strategy or fsdp_plugin.reshard_after_forward,
        "cpu_offload": fsdp_plugin.cpu_offload,
        "auto_wrap_policy": fsdp_plugin.auto_wrap_policy,
        "mixed_precision": fsdp_plugin.mixed_precision_policy,
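
The one-line change above reads as a compatibility shim: in accelerate 1.6 the plugin's `sharding_strategy` field gives way to `reshard_after_forward`, which for FSDP1 carries the same enum value, so whichever field is populated wins. An illustrative helper mirroring that assumed resolution order (not in the PR):

```python
from torch.distributed.fsdp import ShardingStrategy


def resolve_sharding_strategy(fsdp_plugin) -> ShardingStrategy:
    """Illustrative helper, not in the PR: mirror the `or` fallback above.

    Configs written for older accelerate releases still populate the deprecated
    `sharding_strategy`; newer FSDP1 configs carry the same enum in
    `reshard_after_forward`. Whichever is set wins, and torch's FSDP1 wrapper
    receives it under the `sharding_strategy` kwarg.
    """
    return fsdp_plugin.sharding_strategy or fsdp_plugin.reshard_after_forward
```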