Commit a09cde3

Authored by tianyu-l, wanchaol, lessw2020, wconstab, and gnadathur
merge upstream changes (#569)

It seems SimpleFSDP + TP works with PP out of the box, but whole-model compile doesn't, so `torch.compile` on SimpleFSDP won't take effect with PP. We still need to figure out the best way to do it.

---------

Co-authored-by: Wanchao <[email protected]>
Co-authored-by: Less Wright <[email protected]>
Co-authored-by: Will Constable <[email protected]>
Co-authored-by: gnadathur <[email protected]>
Co-authored-by: Driss Guessous <[email protected]>
Co-authored-by: Iris Z <[email protected]>
Co-authored-by: Geeta Chauhan <[email protected]>
Co-authored-by: Andrew Gu <[email protected]>
Co-authored-by: wz337 <[email protected]>
Co-authored-by: Soumith Chintala <[email protected]>
Co-authored-by: Mark Saroufim <[email protected]>
Co-authored-by: Chien-Chin Huang <[email protected]>
Co-authored-by: liangluofb <[email protected]>
Co-authored-by: Huy Do <[email protected]>
Co-authored-by: Gokul <[email protected]>
Co-authored-by: Pavel Belevich <[email protected]>
Co-authored-by: Ke Wen <[email protected]>
Co-authored-by: Wei (Will) Feng <[email protected]>
Co-authored-by: Howard Huang <[email protected]>
Co-authored-by: Xilun Wu <[email protected]>
Co-authored-by: Sanket Jayant Purandare <[email protected]>
Co-authored-by: Yifu Wang <[email protected]>
Co-authored-by: Vasiliy Kuznetsov <[email protected]>
Co-authored-by: Hugo <[email protected]>
1 parent ac90c36 commit a09cde3
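
The limitation noted in the commit message concerns where torch.compile is applied once the model has been split into pipeline stages. As a rough illustration only (the per-stage compile strategy and the `stage_modules` list are assumptions, not part of this commit), one possible direction is to compile each pipeline stage's submodule separately, in contrast to the whole-model `torch.compile(model, fullgraph=True)` used by `torch_spmd_parallelize` below:

import torch
import torch.nn as nn

def compile_pipeline_stages(stage_modules: list[nn.Module]) -> list[nn.Module]:
    # Hypothetical helper: `stage_modules` would be the per-stage submodules
    # produced by whatever PP splitting step is in use; each stage is compiled
    # on its own rather than compiling the un-split model as a whole.
    return [torch.compile(m) for m in stage_modules]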

File tree

11 files changed: +392 -21 lines


.gitignore

Lines changed: 3 additions & 0 deletions
@@ -17,3 +17,6 @@ torchtitan/datasets/**/*.model
 *.log
 error.json
 _remote_module_non_scriptable.py
+
+# torch compile debug related
+torch_compile_debug/*

benchmark.py

Lines changed: 232 additions & 0 deletions
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
import time
from datetime import timedelta

import torch
from torch.distributed.elastic.multiprocessing.errors import record

from torchbenchmark.util.experiment.instantiator import (
    load_model,
    TorchBenchModelConfig,
)
from torchbenchmark.util.experiment.metrics import get_model_flops
from torchbenchmark.util.input import input_cast

from torchtitan import utils
from torchtitan.checkpoint import TrainState
from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.logging import init_logger, logger
from torchtitan.metrics import build_gpu_memory_monitor
from torchtitan.parallelisms import ParallelDims
from torchtitan.parallelisms.parallelize_llama import torch_spmd_parallelize
from torchtitan.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling


# Enable debug tracing on failure: https://pytorch.org/docs/stable/elastic/errors.html
@record
def main(job_config: JobConfig):
    init_logger()
    logger.info(f"Starting job: {job_config.job.description}")

    # used for colorful printing
    color = utils.Color if job_config.metrics.enable_color_printing else utils.NoColor

    # take control of garbage collection to avoid stragglers
    gc_handler = utils.GarbageCollection(gc_freq=job_config.training.gc_freq)

    # init distributed
    world_size = int(os.environ["WORLD_SIZE"])
    parallel_dims = ParallelDims(
        dp=job_config.training.data_parallel_degree,
        tp=job_config.training.tensor_parallel_degree,
        pp=job_config.experimental.pipeline_parallel_degree,
        world_size=world_size,
        enable_loss_parallel=job_config.training.enable_loss_parallel,
        dp_type=job_config.training.data_parallel_type,
    )
    device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
    torch.cuda.set_device(device)
    utils.init_distributed(job_config)
    # initialize GPU memory monitor and get peak flops for MFU calculation
    gpu_memory_monitor = build_gpu_memory_monitor()
    gpu_peak_flops = utils.get_peak_flops(gpu_memory_monitor.device_name)

    # build meshes
    world_mesh = parallel_dims.build_mesh(device_type="cuda")
    if parallel_dims.dp_enabled:
        dp_mesh = world_mesh["dp"]
        dp_degree, dp_rank = dp_mesh.size(), dp_mesh.get_local_rank()
    else:
        dp_degree, dp_rank = 1, 0

    if parallel_dims.pp_enabled:
        pp_mesh = world_mesh["pp"]

    model_name = job_config.model.name

    # initiate model from torchbench
    config = TorchBenchModelConfig(
        name=model_name,
        test="train",
        device="cuda",
        batch_size=job_config.training.batch_size,
        extra_args=[],
    )
    model_flops = get_model_flops(config)
    benchmark_model = load_model(config)
    model, _ = benchmark_model.get_module()

    # TODO: there seems to be a bug with dtype conversion (e.g. use resnet50)
    # cast input dtype if needed
    param_dtype = TORCH_DTYPE_MAP[job_config.training.mixed_precision_param]
    input_cond = lambda x: x.dtype == torch.float32
    input_action = lambda x: x.to(param_dtype)
    if hasattr(benchmark_model, "example_inputs"):
        benchmark_model.example_inputs = input_cast(
            input_cond, input_action, benchmark_model.example_inputs
        )
    else:
        logger.warning(
            f"{model_name} example inputs haven't been cast to {param_dtype} yet!"
        )

    # log model size
    model_param_count = utils.get_num_params(model)
    logger.info(
        f"{color.blue}Model {model_name} "
        f"{color.red}size: {model_param_count:,} total parameters{color.reset}"
    )

    # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
    model = torch_spmd_parallelize(model, world_mesh, parallel_dims, job_config)

    # update model and optimizer after applying parallelisms
    benchmark_model.set_module(model)
    optimizer = benchmark_model.get_optimizer()
    optimizer.add_param_group({"params": model.parameters()})

    model.train()

    gpu_mem_stats = gpu_memory_monitor.get_peak_stats()
    logger.info(
        f"GPU memory usage for model: "
        f"{gpu_mem_stats.max_reserved_gib:.2f}GiB"
        f"({gpu_mem_stats.max_reserved_pct:.2f}%)"
    )

    train_state = TrainState()

    # variables used to keep info for metrics logging
    losses_since_last_log = []
    gpu_memory_monitor.reset_peak_stats()

    # train loop
    logger.info(
        f"Training starts at step {train_state.step + 1}, "
        f"with local batch size {job_config.training.batch_size}, "
        f"global batch size {job_config.training.batch_size * dp_degree}, "
        f"total steps {job_config.training.steps}"
    )
    with maybe_enable_profiling(
        job_config, global_step=train_state.step
    ) as torch_profiler, maybe_enable_memory_snapshot(
        job_config, global_step=train_state.step
    ) as memory_profiler:
        while train_state.step < job_config.training.steps:
            train_state.step += 1
            gc_handler.run(train_state.step)

            torch.cuda.synchronize()
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)

            # Use time.time_ns() instead of time.time(), which may not provide
            # better than 1-second precision on some systems; see
            # https://docs.python.org/3/library/time.html#time.time.
            t0 = time.time_ns()
            start_event.record()

            is_staged = (
                hasattr(benchmark_model, "forward")
                and hasattr(benchmark_model, "backward")
                and hasattr(benchmark_model, "optimizer_step")
            )
            if is_staged and (getattr(benchmark_model, "train", None) is None):
                if optimizer is not None:
                    optimizer.zero_grad()
                loss = benchmark_model.forward()
                benchmark_model.backward(loss)
                if optimizer is not None:
                    benchmark_model.optimizer_step()
            else:
                loss = benchmark_model.train()

            end_event.record()
            torch.cuda.synchronize()
            t1 = time.time_ns()
            time_delta = start_event.elapsed_time(end_event), (t1 - t0) / 1_000_000

            # log metrics
            losses_since_last_log.append(loss)
            if (
                train_state.step == 1
                or train_state.step % job_config.metrics.log_freq == 0
            ):
                losses = [
                    loss.item() if isinstance(loss, torch.Tensor) else loss
                    for loss in losses_since_last_log
                ]
                avg_loss, max_loss = sum(losses) / len(losses), max(losses)
                if parallel_dims.dp_enabled:
                    global_avg_loss, global_max_loss = (
                        utils.dist_mean(avg_loss, dp_mesh),
                        utils.dist_max(max_loss, dp_mesh),
                    )
                else:
                    global_avg_loss, global_max_loss = avg_loss, max_loss

                gpu_mem_stats = gpu_memory_monitor.get_peak_stats()

                logger.info(
                    f"{color.cyan}step: {train_state.step:2} "
                    f"{color.green}loss: {global_avg_loss:7.4f} "
                    f"{color.yellow}memory: {gpu_mem_stats.max_reserved_gib:5.2f}GiB"
                    f"({gpu_mem_stats.max_reserved_pct:.2f}%) "
                    f"{color.blue}GPU time: {time_delta[0]:.3f}ms "
                    f"CPU wall time: {time_delta[1]:.3f}ms{color.reset}"
                )

                losses_since_last_log.clear()
                gpu_memory_monitor.reset_peak_stats()

            # signal the profiler that the next profiling step has started
            if torch_profiler:
                torch_profiler.step()
            if memory_profiler:
                memory_profiler.step()

            # reduce timeout after first train step for faster signal
            # (assuming lazy init and compilation are finished)
            if train_state.step == 1:
                utils.set_pg_timeouts(
                    timeout=timedelta(seconds=job_config.comm.train_timeout_seconds),
                    world_mesh=world_mesh,
                )

    if torch.distributed.get_rank() == 0:
        logger.info("Sleeping 2 seconds for other ranks to complete")
        time.sleep(2)

    logger.info("Training completed")


if __name__ == "__main__":
    config = JobConfig()
    config.parse_args()
    main(config)
    torch.distributed.destroy_process_group()
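
benchmark.py gathers `gpu_peak_flops` and `model_flops` "for MFU calculation" but, as committed, only logs GPU and CPU step times. A minimal sketch of how those values could be combined with the measured per-step GPU time to report MFU (the helper name and its placement are hypothetical, not part of the script):

def estimate_mfu(model_flops: float, gpu_time_ms: float, gpu_peak_flops: float) -> float:
    # achieved FLOP/s for the step divided by the device's peak FLOP/s, in percent
    achieved_flops_per_sec = model_flops / (gpu_time_ms / 1e3)
    return 100 * achieved_flops_per_sec / gpu_peak_flops

# e.g. inside the logging branch above (assumed placement):
#   mfu = estimate_mfu(model_flops, time_delta[0], gpu_peak_flops)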

run_benchmark_train.sh

Lines changed: 24 additions & 0 deletions
#!/usr/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

set -ex

# use envs as local overrides for convenience
# e.g.
# LOG_RANK=0,1 NGPU=4 ./run_benchmark_train.sh
NGPU=${NGPU:-"8"}
LOG_RANK=${LOG_RANK:-0}
CONFIG_FILE=${CONFIG_FILE:-"./train_configs/benchmark_model.toml"}

overrides=""
if [ $# -ne 0 ]; then
    overrides="$*"
fi

torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
benchmark.py --job.config_file ${CONFIG_FILE} $overrides

run_llama_train.sh

Lines changed: 2 additions & 0 deletions
@@ -19,6 +19,8 @@ if [ $# -ne 0 ]; then
     overrides="$*"
 fi
 
+# TORCH_TRACE="./outputs/trace" \
+TORCH_NCCL_AVOID_RECORD_STREAMS=1 \
 torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
 --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
 train.py --job.config_file ${CONFIG_FILE} $overrides

torchtitan/config_manager.py

Lines changed: 8 additions & 0 deletions
@@ -241,6 +241,14 @@ def __init__(self):
             action="store_true",
             help="Whether to apply loss parallel when sequence parallel is enabled",
         )
+
+        # experimental configs
+        self.parser.add_argument(
+            "--experimental.torch_spmd",
+            default=False,
+            action="store_true",
+            help="Whether to use the experimental torch_spmd style parallelism",
+        )
         self.parser.add_argument(
             "--experimental.enable_async_tensor_parallel",
             default=False,
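
For context, a minimal sketch (nothing here beyond what this diff and benchmark.py above already show) of how the new flag flows from the command line into the parallelism dispatch:

from torchtitan.config_manager import JobConfig

# e.g. invoked as: torchrun ... train.py --experimental.torch_spmd
job_config = JobConfig()
job_config.parse_args()

if job_config.experimental.torch_spmd:
    # parallelize_llama() short-circuits into torch_spmd_parallelize(); see the next file
    ...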

torchtitan/parallelisms/parallelize_llama.py

Lines changed: 56 additions & 4 deletions
@@ -29,7 +29,55 @@
 from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.logging import logger
 from torchtitan.parallelisms.parallel_dims import ParallelDims
-from torchtitan.parallelisms.utils import check_strided_sharding_enabled
+
+
+# NOTE(lty): experimental for the PT-D 24 research internship project
+def torch_spmd_parallelize(
+    model: nn.Module,
+    world_mesh: DeviceMesh,
+    parallel_dims: ParallelDims,
+    job_config: JobConfig,
+):
+    torch._inductor.config.simplefsdp.enable_reorder = True
+    torch._inductor.config.simplefsdp.enable_bucket = True
+
+    if parallel_dims.tp_enabled:
+        apply_tp(
+            model,
+            world_mesh["tp"],
+            loss_parallel=parallel_dims.loss_parallel_enabled,
+            enable_float8=job_config.float8.enable_float8_linear,
+            enable_async_tp=job_config.experimental.enable_async_tensor_parallel,
+        )
+
+    ac_config = job_config.activation_checkpoint
+    if ac_config.mode != "none":
+        apply_ac(model, ac_config)
+        logger.info(f"Applied {ac_config.mode} activation checkpointing to the model")
+
+    if parallel_dims.dp_enabled:
+        from torch_spmd.data_parallel import data_parallel, MixedPrecisionPolicy
+
+        mp_policy = MixedPrecisionPolicy(
+            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
+            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
+        )
+        dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
+
+        model = data_parallel(
+            model,
+            dp_mesh,
+            mode="fully_shard",
+            ac_mode=ac_config.mode,
+            mp_policy=mp_policy,
+        )
+        logger.info("Applied Simple FSDP to the model")
+
+    if job_config.training.compile:
+        model = torch.compile(model, fullgraph=True)
+        logger.info("Compiling with torch.compile")
+
+    return model
 
 
 def parallelize_llama(
@@ -45,6 +93,9 @@ def parallelize_llama(
     NOTE: The passed-in model preferably should be on meta device. Otherwise,
     the model must fit on GPU or CPU memory.
     """
+    # NOTE(lty): experimental for the PT-D 24 research internship project
+    if job_config.experimental.torch_spmd:
+        return torch_spmd_parallelize(model, world_mesh, parallel_dims, job_config)
 
     if parallel_dims.tp_enabled:
         if (
@@ -300,11 +351,12 @@ def apply_fsdp(
     mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
     fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
 
+    # TODO(lty): the check below requires the latest PyTorch nightly; remove for now
     # TODO: remove this check once PyTorch 2.5 is released. We can safely assume
     # that users won't use a nightly build which is older than 20240809 by then.
-    if tp_enabled:
-        # check if strided sharding is enabled, which is necessary for 2D/3D DCP
-        check_strided_sharding_enabled()
+    # if tp_enabled:
+    #     # check if strided sharding is enabled, which is necessary for 2D/3D DCP
+    #     check_strided_sharding_enabled()
 
     for layer_id, transformer_block in model.layers.items():
         if pp_enabled:
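
For readers who want to try the SimpleFSDP path outside of torchtitan, the sketch below repeats the same sequence of calls that `torch_spmd_parallelize` makes, on a toy module. The mesh size, the toy model, and the dtypes are illustrative assumptions; it requires the experimental `torch_spmd` package and an initialized distributed environment (e.g. launched with torchrun) to actually run:

import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch_spmd.data_parallel import data_parallel, MixedPrecisionPolicy

# 1-D data-parallel mesh over 8 ranks (assumed world size)
dp_mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("dp",))

toy_model = nn.Sequential(
    nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 1024)
).cuda()

# SimpleFSDP-style fully-sharded data parallelism, mirroring torch_spmd_parallelize
toy_model = data_parallel(
    toy_model,
    dp_mesh,
    mode="fully_shard",
    ac_mode="none",
    mp_policy=MixedPrecisionPolicy(
        param_dtype=torch.bfloat16, reduce_dtype=torch.float32
    ),
)

# whole-model compile; per the commit message, this does not yet take effect with PP
toy_model = torch.compile(toy_model, fullgraph=True)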
