-
Notifications
You must be signed in to change notification settings - Fork 128
Expand file tree
/
Copy pathtrain_ddp_cp.py
More file actions
236 lines (196 loc) · 9.11 KB
/
train_ddp_cp.py
File metadata and controls
236 lines (196 loc) · 9.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from pathlib import Path
import hydra
import torch
from omegaconf import DictConfig, OmegaConf
from torch.distributed.device_mesh import init_device_mesh
from torch.optim import AdamW
from transformer_engine.common.recipe import Format
from checkpoint import load_checkpoint_ddp, save_checkpoint_ddp, save_final_model_ddp, should_save_checkpoint
from dataset import create_cp_dataloader
from distributed_config import DistributedConfig
from modeling_esm_te import NVEsmConfig, NVEsmForMaskedLM
from perf_logger import PerfLogger
from quantization import resolve_layer_precision
from scheduler import get_linear_schedule_with_warmup
# Module-level logger, named after this module, used for all log output in this script.
logger = logging.getLogger(__name__)
@hydra.main(config_path="hydra_config", config_name="L0_sanity_cp", version_base="1.2")
def main(args: DictConfig) -> float | None:
    """Train ESM-2 with TE layers using DDP combined with context parallelism (CP).

    The function initializes the NCCL process group, builds a 2D ``(ddp, cp)`` device mesh,
    constructs the model / optimizer / scheduler, optionally resumes from a checkpoint, runs the
    training loop for ``args.num_train_steps`` steps, optionally saves the final model, and then
    tears down the process group.

    Args:
        args: Hydra configuration for the run. Fields read here include ``cp_size``,
            ``fp8_config``, ``fp4_config``, ``fp8_layers``, ``fp4_layers``,
            ``use_fp32_master_weights``, ``use_sequence_packing``, ``use_torch_compile``,
            ``config_name_or_path``, ``config_kwargs``, ``adamw_kwargs``,
            ``lr_scheduler_kwargs``, ``dataset``, ``checkpoint`` and ``num_train_steps``.

    Returns:
        float | None: The value of ``perf_logger.min_loss`` once training has finished.

    Raises:
        ValueError: If ``world_size`` is not divisible by ``cp_size``, if
            ``use_fp32_master_weights`` is enabled, or if ``use_sequence_packing`` is disabled.
    """
    # Initialize the distributed configuration, including creating the distributed process group.
    dist_config = DistributedConfig()
    logger.info("Initializing distributed training: %s", dist_config)
    device = torch.device(f"cuda:{dist_config.local_rank}")
    torch.distributed.init_process_group(backend="nccl", device_id=device)
    torch.cuda.set_device(dist_config.local_rank)

    # Validate that world_size is divisible by cp_size
    if dist_config.world_size % args.cp_size != 0:
        raise ValueError(
            f"world_size ({dist_config.world_size}) must be divisible by cp_size ({args.cp_size}). "
            f"Set cp_size to a divisor of world_size."
        )

    # Calculate DDP size (number of data parallel replicas)
    ddp_size = dist_config.world_size // args.cp_size

    logger.info(
        "Creating device mesh: world_size=%s, ddp_size=%s, cp_size=%s",
        dist_config.world_size,
        ddp_size,
        args.cp_size,
    )

    # Create a device mesh for DDP and CP.
    # The mesh is organized as [DDP_dimension, CP_dimension] where:
    # - DDP dimension: number of data parallel replicas (world_size // cp_size)
    # - CP dimension: context parallel size
    # Total ranks = ddp_size * cp_size = world_size
    device_mesh = init_device_mesh(
        "cuda",
        mesh_shape=(ddp_size, args.cp_size),
        mesh_dim_names=("ddp", "cp"),
    )

    # Create quantization recipes -- these are only used if FP8/FP4 is enabled in the config.
    fp8_recipe = None
    fp4_recipe = None
    if args.fp8_config.enabled:
        fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)(
            fp8_format=Format[args.fp8_config.fp8_format], **args.fp8_config.fp8_recipe_kwargs
        )
    if args.fp4_config.enabled:
        fp4_recipe = hydra.utils.get_class(args.fp4_config.fp4_recipe)(
            fp4_format=Format[args.fp4_config.fp4_format], **args.fp4_config.fp4_recipe_kwargs
        )

    if args.use_fp32_master_weights:
        raise ValueError("FP32 master weights are not supported with DDP+CP. Use train_fsdp2_cp.py instead.")

    # Context Parallelism requires THD Sequence Packing. Raise (rather than assert) so the check
    # survives `python -O`, and do it before the model is built so a bad config fails fast.
    if not args.use_sequence_packing:
        raise ValueError("Context Parallelism requires THD Sequence Packing.")

    # Create an empty ESM-2 model with a masked language model head, e.g. "nvidia/esm2_t6_8M_UR50D".
    # Note: token_dropout is set to False because it's not compatible with context parallelism.
    config = NVEsmConfig.from_pretrained(
        args.config_name_or_path, token_dropout=False, dtype=torch.bfloat16, **args.config_kwargs
    )
    num_layers = config.num_hidden_layers

    # Resolve layer-wise quantization assignments and store on config.
    layer_precision = resolve_layer_precision(
        num_layers=num_layers,
        fp8_enabled=args.fp8_config.enabled,
        fp4_enabled=args.fp4_config.enabled,
        fp8_layers=OmegaConf.to_container(args.fp8_layers, resolve=True) if args.fp8_layers is not None else None,
        fp4_layers=OmegaConf.to_container(args.fp4_layers, resolve=True) if args.fp4_layers is not None else None,
    )
    config.layer_precision = layer_precision

    # If we're using sequence packing with TE layers, we need to pass the `attn_input_format` argument.
    if args.use_sequence_packing:
        config.attn_input_format = "thd"

    # Create the model -- recipes and quantized_model_init are handled internally via get_autocast_context().
    model = NVEsmForMaskedLM(config, fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe)
    logger.info("Initialized Model:\n%s", model)

    # Create optimizer.
    optimizer = AdamW(model.parameters(), **args.adamw_kwargs)
    scheduler = get_linear_schedule_with_warmup(optimizer, **args.lr_scheduler_kwargs)

    model = model.to(device=device)

    # DDP runs over the flattened (ddp, cp) mesh, i.e. a single group built from both mesh dims.
    dp_cp_group = device_mesh[("ddp", "cp")]._flatten("dp_cp").get_group()
    model = torch.nn.parallel.DistributedDataParallel(
        model,
        device_ids=[dist_config.local_rank],
        output_device=dist_config.local_rank,
        process_group=dp_cp_group,
    )

    if args.cp_size > 1:
        # The CP group and its ranks are the same for every layer, so look them up once.
        cp_group = device_mesh["cp"].get_group()
        cp_ranks = torch.distributed.get_process_group_ranks(cp_group)
        for i, transformer_layer in enumerate(model.module.esm.encoder.layers):
            logger.debug("Rank %s: Setting CP group for layer %s", dist_config.rank, i)
            # Each layer is handed its own freshly created CUDA stream.
            transformer_layer.set_context_parallel_group(
                cp_group,
                cp_ranks,
                torch.cuda.Stream(),
            )

    train_dataloader, dataset_or_sampler = create_cp_dataloader(
        dist_config,
        cp_mesh=device_mesh["cp"],
        **args.dataset,
    )

    if args.use_torch_compile:
        # If we're using torch.compile, we need to do this before loading the checkpoint to ensure key consistency.
        model = torch.compile(model)

    # If we're resuming from a checkpoint, load it and set the start step. Otherwise, start from step 0.
    ckpt_path = Path(args.checkpoint.ckpt_dir) / "train_ddp" if args.checkpoint.ckpt_dir else None
    if args.checkpoint.resume_from_checkpoint and ckpt_path:
        model, optimizer, scheduler, train_dataloader, start_step, epoch = load_checkpoint_ddp(
            model=model,
            optimizer=optimizer,
            scheduler=scheduler,
            ckpt_path=ckpt_path,
            dist_config=dist_config,
            dataloader=train_dataloader,
        )
    else:
        start_step = 0
        epoch = 0

    perf_logger = PerfLogger(dist_config, args)

    # Training loop
    step = start_step
    while step < args.num_train_steps:
        for batch in train_dataloader:
            batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}  # noqa PLW2901

            # Forward pass.
            outputs = model(**batch)

            # Backward pass.
            loss = outputs.loss
            loss.backward()

            # Compute and clip gradient norms.
            total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0).item()

            # Step optimizer.
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            perf_logger.log_step(
                step=step,
                batch=batch,
                outputs=outputs,
                grad_norm=total_norm,
                lr=optimizer.param_groups[0]["lr"],
            )

            if ckpt_path and should_save_checkpoint(step, args.checkpoint.save_every_n_steps):
                save_checkpoint_ddp(
                    model=model,
                    optimizer=optimizer,
                    scheduler=scheduler,
                    ckpt_path=ckpt_path,
                    step=step,
                    epoch=epoch,
                    dist_config=dist_config,
                    dataloader=train_dataloader if args.dataset.use_stateful_dataloader else None,
                    max_checkpoints=args.checkpoint.max_checkpoints,
                )

            step += 1
            if step >= args.num_train_steps:
                break

        # Dataloader exhausted, incrementing epoch
        epoch += 1
        if dataset_or_sampler is not None:  # The dataset only exists on rank 0
            dataset_or_sampler.set_epoch(epoch)

    # Save final model to a .safetensors file.
    if args.checkpoint.save_final_model and ckpt_path:
        save_final_model_ddp(
            model=model,
            save_directory=ckpt_path / "final_model",
            dist_config=dist_config,
        )

    # Clean up distributed training
    perf_logger.finish()
    torch.distributed.destroy_process_group()

    return perf_logger.min_loss
# Script entry point: run the Hydra-decorated training function when executed directly.
if __name__ == "__main__":
    main()