Skip to content

Commit 3fc8545

Browse files
jomitchellnvJonathan MitchellJonathan Mitchell
authored
Adds context parallelism to FSDP2 (#1358)
### Description Creates a new training script called `train_fsdp2_cp.py` where we add CP to FSDP2 #### Usage You can run this script in the same way that you execute `train_ddp_cp.py` ```python torchrun --nproc_per_node=8 train_fsdp2_cp.py cp_size=<CP_SIZE> ``` For equivalence see <img width="5056" height="2656" alt="W B Chart 12_1_2025, 2_56_52 PM" src="https://github.com/user-attachments/assets/ef1513e0-0d3e-4fb9-a8b0-7bed8a47e86c" /> ### Type of changes <!-- Mark the relevant option with an [x] --> - [ ] Bug fix (non-breaking change which fixes an issue) - [ ] New feature (non-breaking change which adds functionality) - [ ] Refactor - [ ] Documentation update - [ ] Other (please describe): ### CI Pipeline Configuration Configure CI behavior by applying the relevant labels. By default, only basic unit tests are run. - [ciflow:skip](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/main/contributing/contributing.md#ciflow:skip) - Skip all CI tests for this PR - [ciflow:notebooks](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/main/contributing/contributing.md#ciflow:notebooks) - Run Jupyter notebooks execution tests for bionemo2 - [ciflow:slow](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/main/contributing/contributing.md#ciflow:slow) - Run slow single GPU integration tests marked as @pytest.mark.slow for bionemo2 - [ciflow:all](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/main/contributing/contributing.md#ciflow:all) - Run all tests (unit tests, slow tests, and notebooks) for bionemo2. This label can be used to enforce running tests for all bionemo2. - [ciflow:all-recipes](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/main/contributing/contributing.md#ciflow:all-recipes) - Run tests for all recipes (under bionemo-recipes). This label can be used to enforce running tests for all recipes. 
Unit tests marked as `@pytest.mark.multi_gpu` or `@pytest.mark.distributed` are not run in the PR pipeline. For more details, see [CONTRIBUTING](CONTRIBUTING.md) > [!NOTE] > By default, only basic unit tests are run. Add appropriate labels to enable an additional test coverage. #### Authorizing CI Runs We use [copy-pr-bot](https://docs.gha-runners.nvidia.com/apps/copy-pr-bot/#automation) to manage authorization of CI runs on NVIDIA's compute resources. - If a pull request is opened by a trusted user and contains only trusted changes, the pull request's code will automatically be copied to a pull-request/ prefixed branch in the source repository (e.g. pull-request/123) - If a pull request is opened by an untrusted user or contains untrusted changes, an NVIDIA org member must leave an `/ok to test` comment on the pull request to trigger CI. This will need to be done for each new commit. ### Pre-submit Checklist <!--- Ensure all items are completed before submitting --> - [ ] I have tested these changes locally - [ ] I have updated the documentation accordingly - [ ] I have added/updated tests as needed - [X] All existing tests pass successfully --------- Signed-off-by: Jonathan Mitchell <jomitchell@ipp1-1334.ipp1a1.colossus.nvidia.com> Signed-off-by: Jonathan Mitchell <jomitchell@nvidia.com> Signed-off-by: Jonathan Mitchell <jomitchell@ipp1-1428.ipp1a1.colossus.nvidia.com> Co-authored-by: Jonathan Mitchell <jomitchell@ipp1-1334.ipp1a1.colossus.nvidia.com> Co-authored-by: Jonathan Mitchell <jomitchell@ipp1-1428.ipp1a1.colossus.nvidia.com>
1 parent d3d0734 commit 3fc8545

File tree

2 files changed

+256
-0
lines changed

2 files changed

+256
-0
lines changed

bionemo-recipes/recipes/esm2_native_te/tests/test_train_two_gpu.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,3 +145,21 @@ def test_multi_gpu_train_te_ddp_cp(tmp_path, recipe_path):
145145
],
146146
recipe_path,
147147
)
148+
149+
150+
@requires_multi_gpu
@requires_datacenter_hardware
def test_multi_gpu_train_te_fsdp2_cp(tmp_path, recipe_path):
    """Smoke-test FSDP2 + context parallelism: 4 training steps on 2 GPUs with cp_size=2."""
    # Run 'torchrun train_fsdp2_cp.py' as a subprocess; the L0_sanity_cp config
    # keeps the run small enough for a CI sanity check.
    run_train_cmd(
        [
            "torchrun",
            "--nproc_per_node=2",
            "train_fsdp2_cp.py",
            "--config-name",
            "L0_sanity_cp",
            "num_train_steps=4",
            "cp_size=2",
        ],
        recipe_path,
    )
Lines changed: 238 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,238 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: LicenseRef-Apache2
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import logging
17+
from contextlib import nullcontext
18+
from pathlib import Path
19+
20+
import hydra
21+
import torch
22+
import transformer_engine.pytorch
23+
from omegaconf import DictConfig, OmegaConf
24+
from torch.distributed.device_mesh import init_device_mesh
25+
from torch.distributed.fsdp import fully_shard
26+
from torch.optim import AdamW
27+
from transformer_engine.common.recipe import Format
28+
from transformers import AutoConfig, AutoModelForMaskedLM
29+
30+
# This import seems to be needed with meta device init and AutoModel.from_config
31+
from transformers.models.esm.modeling_esm import EsmForMaskedLM # noqa: F401
32+
33+
from checkpoint import load_checkpoint_fsdp2, save_checkpoint_fsdp2, save_final_model_fsdp2, should_save_checkpoint
34+
from dataset import create_cp_dataloader
35+
from distributed_config import DistributedConfig
36+
from perf_logger import PerfLogger
37+
from scheduler import get_linear_schedule_with_warmup
38+
39+
40+
# Module-level logger; INFO level so per-rank training progress is visible by default.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
42+
43+
44+
@hydra.main(config_path="hydra_config", config_name="L0_sanity_cp", version_base="1.2")
def main(args: DictConfig) -> float | None:  # noqa: C901
    """Train ESM-2 with TE layers using fsdp2.

    Args:
        args: Hydra config (defaults to ``hydra_config/L0_sanity_cp``) controlling
            cp_size, fp8 settings, dataset, checkpointing, and training length.

    Returns:
        float: The loss value for the final batch.
    """
    # Initialize the distributed configuration, including creating the distributed process group.
    dist_config = DistributedConfig()
    logger.info("Initializing distributed training: %s", dist_config)
    device = torch.device(f"cuda:{dist_config.local_rank}")
    torch.distributed.init_process_group(backend="nccl", device_id=device)
    torch.cuda.set_device(dist_config.local_rank)

    # Validate that world_size is divisible by cp_size — the 2D (dp, cp) mesh below
    # requires dp_size * cp_size == world_size exactly.
    if dist_config.world_size % args.cp_size != 0:
        raise ValueError(
            f"world_size ({dist_config.world_size}) must be divisible by cp_size ({args.cp_size}). "
            f"Set cp_size to a divisor of world_size."
        )

    # Calculate DP size (number of data parallel replicas).
    dp_size = dist_config.world_size // args.cp_size

    # Create a device mesh for DP and CP.
    # The mesh is organized as [DP_dimension, CP_dimension] where:
    # - dp dimension: number of data parallel replicas (world_size // cp_size)
    # - cp dimension: context parallel size
    # Total ranks = cp_size * dp_size = world_size
    device_mesh = init_device_mesh(
        "cuda",
        mesh_shape=(dp_size, args.cp_size),
        mesh_dim_names=("dp", "cp"),
    )

    # NOTE(review): this branch flattens the (dp, cp) mesh only when
    # dp_size * cp_size <= 1, i.e. world_size == 1 — but the original comment
    # ("flattened group must have at least 2 ranks to enable Context Parallelism")
    # and the usual FSDP2+CP pattern suggest flattening when world_size > 1.
    # The condition may be inverted relative to the intent; confirm which mesh
    # fully_shard is meant to receive in the multi-rank case.
    if dp_size * args.cp_size <= 1:
        cp_dp_mesh = device_mesh["dp", "cp"]._flatten(mesh_dim_name="dp_shard_cp")
    else:
        cp_dp_mesh = device_mesh

    logger.info(
        f"Creating device mesh: world_size={dist_config.world_size}, dp_size={dp_size}, cp_size={args.cp_size}"
    )

    # Process group and local rank along the context-parallel dimension; used for
    # TE's set_context_parallel_group and for sharding sequences in the dataloader.
    cp_group = device_mesh["cp"].get_group()
    cp_rank = device_mesh.get_local_rank("cp")

    # Create an FP8 recipe -- this is only used if FP8 is enabled in the config.
    fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)(
        fp8_format=Format[args.fp8_config.fp8_format], **args.fp8_config.fp8_recipe_kwargs
    )

    # Create an empty ESM-2 model with a masked language model head, e.g. "nvidia/esm2_t6_8M_UR50D".
    config = AutoConfig.from_pretrained(
        args.model_tag, trust_remote_code=True, token_dropout=False, dtype=torch.bfloat16
    )
    # If we're using sequence packing with TE layers, we need to pass the `attn_input_format` argument.
    if args.use_sequence_packing:
        config.attn_input_format = "thd"

    # Optionally use transformer engine to initialize only fp8 versions of weights by setting
    # `fp8_config.fp8_model_init_kwargs.enabled` to `True`, as opposed to using the default where both bfloat16 and fp8
    # versions of weights are kept. Meta-device init defers weight allocation until to_empty() below.
    with (
        torch.device("meta") if args.use_meta_device else nullcontext(),
        transformer_engine.pytorch.fp8_model_init(recipe=fp8_recipe, **args.fp8_config.fp8_model_init_kwargs),
    ):
        model = AutoModelForMaskedLM.from_config(config, trust_remote_code=True)

    logger.info("Initialized Model:\n%s", model)

    # We call the transformer stack "layers" in our TE models, but it's called "layer" in the original ESM-2 models.
    transformer_stack = model.esm.encoder.layers if hasattr(model.esm.encoder, "layers") else model.esm.encoder.layer
    # Shard each transformer layer individually (per-layer FSDP units), then the root model.
    # fully_shard receives cp_dp_mesh (see the mesh-selection branch above).
    for layer in transformer_stack:
        fully_shard(layer, mesh=cp_dp_mesh)
        # Set CP group for layer if CP is enabled, so TE attention splits the
        # sequence across cp ranks; a dedicated CUDA stream is used for CP comms.
        if args.cp_size > 1:
            logger.debug(f"Rank {dist_config.rank}: Setting CP group for layer {layer}")
            layer.set_context_parallel_group(
                cp_group, torch.distributed.get_process_group_ranks(cp_group), torch.cuda.Stream()
            )
    fully_shard(model, mesh=cp_dp_mesh)

    # Create optimizer. Convert OmegaConf to regular dict to avoid serialization issues (BIONEMO-2873).
    optimizer = AdamW(model.parameters(), **OmegaConf.to_container(args.adamw_kwargs, resolve=True))  # type: ignore
    scheduler = get_linear_schedule_with_warmup(optimizer, **args.lr_scheduler_kwargs)

    # Materialize meta-device weights on the GPU and re-run each module's
    # parameter initialization (meta init leaves weights uninitialized).
    if args.use_meta_device:
        model.to_empty(device=device)
        for module in model.modules():
            if hasattr(module, "reset_parameters"):
                module.reset_parameters()

    # Context Parallelism requires THD Sequence Packing.
    assert args.use_sequence_packing, "Context Parallelism requires THD Sequence Packing."

    # CP-aware dataloader: each cp rank receives its slice of every packed sequence.
    train_dataloader, dataset_or_sampler = create_cp_dataloader(
        dist_config,
        cp_world_size=torch.distributed.get_world_size(group=cp_group),
        cp_group=cp_group,
        cp_rank=cp_rank,
        **args.dataset,
    )

    if args.use_torch_compile:
        # If we're using torch.compile, we need to do this before loading the checkpoint to ensure key consistency.
        model = torch.compile(model)

    # If we're resuming from a checkpoint, load it and set the start step. Otherwise, start from step 0.
    ckpt_path = Path(args.checkpoint.ckpt_dir) / "train_fsdp2" if args.checkpoint.ckpt_dir else None
    if args.checkpoint.resume_from_checkpoint and ckpt_path:
        model, optimizer, scheduler, train_dataloader, start_step, epoch = load_checkpoint_fsdp2(
            model=model,
            optimizer=optimizer,
            scheduler=scheduler,
            ckpt_path=ckpt_path,
            dist_config=dist_config,
            dataloader=train_dataloader,
        )
    else:
        start_step = 0
        epoch = 0

    perf_logger = PerfLogger(dist_config, args)

    # Training loop: outer while counts optimizer steps across epochs; the inner
    # for-loop re-enters the dataloader after each exhaustion (epoch boundary).
    step = start_step
    while step < args.num_train_steps:
        for batch in train_dataloader:
            batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}  # noqa: PLW2901

            # Forward pass with mixed precision (fp8 only if enabled in config).
            with transformer_engine.pytorch.fp8_autocast(enabled=args.fp8_config.enabled, fp8_recipe=fp8_recipe):
                outputs = model(**batch)

            # Backward pass.
            loss = outputs.loss
            loss.backward()

            # Compute and clip gradient norms.
            total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0).item()

            # Step optimizer.
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            perf_logger.log_step(
                step=step,
                batch=batch,
                outputs=outputs,
                grad_norm=total_norm,
                lr=optimizer.param_groups[0]["lr"],
            )

            # Periodic distributed checkpoint; the dataloader state is saved only
            # when the stateful-dataloader option is enabled.
            if ckpt_path and should_save_checkpoint(step, args.checkpoint.save_every_n_steps):
                save_checkpoint_fsdp2(
                    model=model,
                    optimizer=optimizer,
                    scheduler=scheduler,
                    ckpt_path=ckpt_path,
                    step=step,
                    epoch=epoch,
                    dist_config=dist_config,
                    dataloader=train_dataloader if args.dataset.use_stateful_dataloader else None,
                )

            step += 1
            if step >= args.num_train_steps:
                break

        # Dataloader exhausted, incrementing epoch so shuffling reseeds per epoch.
        epoch += 1
        dataset_or_sampler.set_epoch(epoch)

    # Save final model to a .safetensors file.
    if args.checkpoint.save_final_model and ckpt_path:
        save_final_model_fsdp2(
            model=model,
            save_directory=ckpt_path / "final_model",
            dist_config=dist_config,
        )

    # Clean up distributed training
    perf_logger.finish()
    torch.distributed.destroy_process_group()

    return perf_logger.min_loss
235+
236+
237+
# Script entry point; hydra parses CLI overrides (e.g. cp_size=2) before calling main().
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)