merge upstream changes (#569)
It seems SimpleFSDP + TP works with PP out of the box, but whole-model
compile doesn't, so `torch.compile` on SimpleFSDP won't take effect with
PP. We still need to figure out the best way to handle that.

---------

Co-authored-by: Wanchao <[email protected]>
Co-authored-by: Less Wright <[email protected]>
Co-authored-by: Will Constable <[email protected]>
Co-authored-by: gnadathur <[email protected]>
Co-authored-by: gnadathur <[email protected]>
Co-authored-by: gnadathur <[email protected]>
Co-authored-by: Driss Guessous <[email protected]>
Co-authored-by: Iris Z <[email protected]>
Co-authored-by: Will Constable <[email protected]>
Co-authored-by: Geeta Chauhan <[email protected]>
Co-authored-by: Andrew Gu <[email protected]>
Co-authored-by: Andrew Gu <[email protected]>
Co-authored-by: wz337 <[email protected]>
Co-authored-by: Soumith Chintala <[email protected]>
Co-authored-by: Mark Saroufim <[email protected]>
Co-authored-by: Chien-Chin Huang <[email protected]>
Co-authored-by: Chien-Chin Huang <[email protected]>
Co-authored-by: liangluofb <[email protected]>
Co-authored-by: Huy Do <[email protected]>
Co-authored-by: Gokul <[email protected]>
Co-authored-by: Pavel Belevich <[email protected]>
Co-authored-by: Ke Wen <[email protected]>
Co-authored-by: Wei (Will) Feng <[email protected]>
Co-authored-by: Howard Huang <[email protected]>
Co-authored-by: Xilun Wu <[email protected]>
Co-authored-by: Sanket Jayant Purandare <[email protected]>
Co-authored-by: Yifu Wang <[email protected]>
Co-authored-by: Vasiliy Kuznetsov <[email protected]>
Co-authored-by: Sanket Jayant Purandare <[email protected]>
Co-authored-by: Hugo <[email protected]>
1 parent ac90c36 commit a09cde3
Showing 11 changed files with 392 additions and 21 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -17,3 +17,6 @@ torchtitan/datasets/**/*.model
*.log
error.json
_remote_module_non_scriptable.py

# torch compile debug related
torch_compile_debug/*
232 changes: 232 additions & 0 deletions benchmark.py
@@ -0,0 +1,232 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
import time
from datetime import timedelta

import torch
from torch.distributed.elastic.multiprocessing.errors import record

from torchbenchmark.util.experiment.instantiator import (
    load_model,
    TorchBenchModelConfig,
)
from torchbenchmark.util.experiment.metrics import get_model_flops
from torchbenchmark.util.input import input_cast

from torchtitan import utils
from torchtitan.checkpoint import TrainState
from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.logging import init_logger, logger
from torchtitan.metrics import build_gpu_memory_monitor
from torchtitan.parallelisms import ParallelDims
from torchtitan.parallelisms.parallelize_llama import torch_spmd_parallelize
from torchtitan.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling


# Enable debug tracing on failure: https://pytorch.org/docs/stable/elastic/errors.html
@record
def main(job_config: JobConfig):
    init_logger()
    logger.info(f"Starting job: {job_config.job.description}")

    # used for colorful printing
    color = utils.Color if job_config.metrics.enable_color_printing else utils.NoColor

    # take control of garbage collection to avoid stragglers
    gc_handler = utils.GarbageCollection(gc_freq=job_config.training.gc_freq)

    # init distributed
    world_size = int(os.environ["WORLD_SIZE"])
    parallel_dims = ParallelDims(
        dp=job_config.training.data_parallel_degree,
        tp=job_config.training.tensor_parallel_degree,
        pp=job_config.experimental.pipeline_parallel_degree,
        world_size=world_size,
        enable_loss_parallel=job_config.training.enable_loss_parallel,
        dp_type=job_config.training.data_parallel_type,
    )
    device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
    torch.cuda.set_device(device)
    utils.init_distributed(job_config)
    # initialize GPU memory monitor and get peak flops for MFU calculation
    gpu_memory_monitor = build_gpu_memory_monitor()
    gpu_peak_flops = utils.get_peak_flops(gpu_memory_monitor.device_name)

    # build meshes
    world_mesh = parallel_dims.build_mesh(device_type="cuda")
    if parallel_dims.dp_enabled:
        dp_mesh = world_mesh["dp"]
        dp_degree, dp_rank = dp_mesh.size(), dp_mesh.get_local_rank()
    else:
        dp_degree, dp_rank = 1, 0

    if parallel_dims.pp_enabled:
        pp_mesh = world_mesh["pp"]

    model_name = job_config.model.name

    # initiate model from torchbench
    config = TorchBenchModelConfig(
        name=model_name,
        test="train",
        device="cuda",
        batch_size=job_config.training.batch_size,
        extra_args=[],
    )
    model_flops = get_model_flops(config)
    benchmark_model = load_model(config)
    model, _ = benchmark_model.get_module()

    # TODO: there seems to be a bug with dtype conversion (e.g. use resnet50)
    # cast input dtype if needed
    param_dtype = TORCH_DTYPE_MAP[job_config.training.mixed_precision_param]
    input_cond = lambda x: x.dtype == torch.float32
    input_action = lambda x: x.to(param_dtype)
    if hasattr(benchmark_model, "example_inputs"):
        benchmark_model.example_inputs = input_cast(
            input_cond, input_action, benchmark_model.example_inputs
        )
    else:
        logger.warning(
            f"{model_name} example inputs haven't been cast to {param_dtype} yet!"
        )

    # log model size
    model_param_count = utils.get_num_params(model)
    logger.info(
        f"{color.blue}Model {model_name} "
        f"{color.red}size: {model_param_count:,} total parameters{color.reset}"
    )

    # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
    model = torch_spmd_parallelize(model, world_mesh, parallel_dims, job_config)

    # update model and optimizer after applying parallelisms
    benchmark_model.set_module(model)
    optimizer = benchmark_model.get_optimizer()
    optimizer.add_param_group({"params": model.parameters()})

    model.train()

    gpu_mem_stats = gpu_memory_monitor.get_peak_stats()
    logger.info(
        f"GPU memory usage for model: "
        f"{gpu_mem_stats.max_reserved_gib:.2f}GiB"
        f"({gpu_mem_stats.max_reserved_pct:.2f}%)"
    )

    train_state = TrainState()

    # variables used to keep info for metrics logging
    losses_since_last_log = []
    gpu_memory_monitor.reset_peak_stats()

    # train loop
    logger.info(
        f"Training starts at step {train_state.step + 1}, "
        f"with local batch size {job_config.training.batch_size}, "
        f"global batch size {job_config.training.batch_size * dp_degree}, "
        f"total steps {job_config.training.steps}"
    )
    with maybe_enable_profiling(
        job_config, global_step=train_state.step
    ) as torch_profiler, maybe_enable_memory_snapshot(
        job_config, global_step=train_state.step
    ) as memory_profiler:
        while train_state.step < job_config.training.steps:
            train_state.step += 1
            gc_handler.run(train_state.step)

            torch.cuda.synchronize()
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)

            # Use time_ns() instead of time(), which may not provide precision
            # better than 1 second: https://docs.python.org/3/library/time.html#time.time
            t0 = time.time_ns()
            start_event.record()

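            # Some TorchBench models expose staged forward/backward/optimizer_step
            # hooks; use them only when the model does not define its own train()
            # step, otherwise fall back to the monolithic train() call below.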
            is_staged = (
                hasattr(benchmark_model, "forward")
                and hasattr(benchmark_model, "backward")
                and hasattr(benchmark_model, "optimizer_step")
            )
            if is_staged and (getattr(benchmark_model, "train", None) is None):
                if optimizer is not None:
                    optimizer.zero_grad()
                loss = benchmark_model.forward()
                benchmark_model.backward(loss)
                if optimizer is not None:
                    benchmark_model.optimizer_step()
            else:
                loss = benchmark_model.train()

            end_event.record()
            torch.cuda.synchronize()
            t1 = time.time_ns()
            time_delta = start_event.elapsed_time(end_event), (t1 - t0) / 1_000_000

            # log metrics
            losses_since_last_log.append(loss)
            if (
                train_state.step == 1
                or train_state.step % job_config.metrics.log_freq == 0
            ):
                losses = [
                    loss.item() if isinstance(loss, torch.Tensor) else loss
                    for loss in losses_since_last_log
                ]
                avg_loss, max_loss = sum(losses) / len(losses), max(losses)
                if parallel_dims.dp_enabled:
                    global_avg_loss, global_max_loss = (
                        utils.dist_mean(avg_loss, dp_mesh),
                        utils.dist_max(max_loss, dp_mesh),
                    )
                else:
                    global_avg_loss, global_max_loss = avg_loss, max_loss

                gpu_mem_stats = gpu_memory_monitor.get_peak_stats()

                logger.info(
                    f"{color.cyan}step: {train_state.step:2} "
                    f"{color.green}loss: {global_avg_loss:7.4f} "
                    f"{color.yellow}memory: {gpu_mem_stats.max_reserved_gib:5.2f}GiB"
                    f"({gpu_mem_stats.max_reserved_pct:.2f}%) "
                    f"{color.blue}GPU time: {time_delta[0]:.3f}ms "
                    f"CPU wall time: {time_delta[1]:.3f}ms{color.reset}"
                )

                losses_since_last_log.clear()
                gpu_memory_monitor.reset_peak_stats()

            # signal the profiler that the next profiling step has started
            if torch_profiler:
                torch_profiler.step()
            if memory_profiler:
                memory_profiler.step()

            # reduce timeout after first train step for faster signal
            # (assuming lazy init and compilation are finished)
            if train_state.step == 1:
                utils.set_pg_timeouts(
                    timeout=timedelta(seconds=job_config.comm.train_timeout_seconds),
                    world_mesh=world_mesh,
                )

    if torch.distributed.get_rank() == 0:
        logger.info("Sleeping 2 seconds for other ranks to complete")
        time.sleep(2)

    logger.info("Training completed")


if __name__ == "__main__":
    config = JobConfig()
    config.parse_args()
    main(config)
    torch.distributed.destroy_process_group()
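The step timing above pairs CUDA events (GPU execution time) with time_ns() (CPU wall time). A minimal standalone sketch of that pattern, assuming a CUDA device is available and using a throwaway matmul as the timed work:

import time

import torch

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

a = torch.randn(4096, 4096, device="cuda")
b = torch.randn(4096, 4096, device="cuda")

t0 = time.time_ns()
start_event.record()
c = a @ b  # stand-in for one training step
end_event.record()
torch.cuda.synchronize()  # wait for the GPU before reading the timers
t1 = time.time_ns()

gpu_ms = start_event.elapsed_time(end_event)  # GPU time in milliseconds
cpu_ms = (t1 - t0) / 1_000_000  # CPU wall time in milliseconds
print(f"GPU: {gpu_ms:.3f}ms, CPU: {cpu_ms:.3f}ms")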
24 changes: 24 additions & 0 deletions run_benchmark_train.sh
@@ -0,0 +1,24 @@
#!/usr/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

set -ex

# use envs as local overrides for convenience
# e.g.
# LOG_RANK=0,1 NGPU=4 ./run_llama_train.sh
NGPU=${NGPU:-"8"}
LOG_RANK=${LOG_RANK:-0}
CONFIG_FILE=${CONFIG_FILE:-"./train_configs/benchmark_model.toml"}

overrides=""
if [ $# -ne 0 ]; then
    overrides="$*"
fi

torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
benchmark.py --job.config_file ${CONFIG_FILE} $overrides
2 changes: 2 additions & 0 deletions run_llama_train.sh
@@ -19,6 +19,8 @@ if [ $# -ne 0 ]; then
    overrides="$*"
fi

# TORCH_TRACE="./outputs/trace" \
TORCH_NCCL_AVOID_RECORD_STREAMS=1 \
torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
train.py --job.config_file ${CONFIG_FILE} $overrides
8 changes: 8 additions & 0 deletions torchtitan/config_manager.py
@@ -241,6 +241,14 @@ def __init__(self):
action="store_true",
help="Whether to apply loss parallel when sequence parallel is enabled",
)

# experimental configs
self.parser.add_argument(
"--experimental.torch_spmd",
default=False,
action="store_true",
help="Whether to use the experimental torch_spmd style parallelism",
)
self.parser.add_argument(
"--experimental.enable_async_tensor_parallel",
default=False,
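For context, a minimal self-contained sketch (not torchtitan's actual JobConfig code) of how a dotted store_true flag like --experimental.torch_spmd can be parsed and then grouped by section:

import argparse
from collections import defaultdict

parser = argparse.ArgumentParser()
parser.add_argument(
    "--experimental.torch_spmd",
    default=False,
    action="store_true",
    help="Whether to use the experimental torch_spmd style parallelism",
)
args = parser.parse_args(["--experimental.torch_spmd"])

# argparse keeps the literal dotted name as the attribute key, so group the
# flat namespace into {"experimental": {"torch_spmd": True}}.
sections = defaultdict(dict)
for key, value in vars(args).items():
    section, option = key.split(".", 1)
    sections[section][option] = value

assert sections["experimental"]["torch_spmd"] is True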
60 changes: 56 additions & 4 deletions torchtitan/parallelisms/parallelize_llama.py
@@ -29,7 +29,55 @@
from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.logging import logger
from torchtitan.parallelisms.parallel_dims import ParallelDims
from torchtitan.parallelisms.utils import check_strided_sharding_enabled


# NOTE(lty): experimental for the PT-D 24 research internship project
def torch_spmd_parallelize(
    model: nn.Module,
    world_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
):
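    # SimpleFSDP-prototype Inductor knobs: judging by their names, these turn on
    # reordering and bucketing of the traced FSDP collectives under torch.compile;
    # they are not part of stock torch._inductor.config.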
    torch._inductor.config.simplefsdp.enable_reorder = True
    torch._inductor.config.simplefsdp.enable_bucket = True

    if parallel_dims.tp_enabled:
        apply_tp(
            model,
            world_mesh["tp"],
            loss_parallel=parallel_dims.loss_parallel_enabled,
            enable_float8=job_config.float8.enable_float8_linear,
            enable_async_tp=job_config.experimental.enable_async_tensor_parallel,
        )

    ac_config = job_config.activation_checkpoint
    if ac_config.mode != "none":
        apply_ac(model, ac_config)
        logger.info(f"Applied {ac_config.mode} activation checkpointing to the model")

    if parallel_dims.dp_enabled:
        from torch_spmd.data_parallel import data_parallel, MixedPrecisionPolicy

        mp_policy = MixedPrecisionPolicy(
            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
        )
        dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh

        model = data_parallel(
            model,
            dp_mesh,
            mode="fully_shard",
            ac_mode=ac_config.mode,
            mp_policy=mp_policy,
        )
        logger.info("Applied Simple FSDP to the model")

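    # Whole-model compile: with SimpleFSDP the data-parallel collectives are traced
    # into the graph, which is what the reorder/bucket options above are meant to
    # optimize (per the SimpleFSDP design); per the commit message this does not yet
    # take effect when PP is enabled.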
    if job_config.training.compile:
        model = torch.compile(model, fullgraph=True)
        logger.info("Compiling with torch.compile")

    return model


def parallelize_llama(
@@ -45,6 +93,9 @@ def parallelize_llama(
    NOTE: The passed-in model preferably should be on meta device. Otherwise,
    the model must fit on GPU or CPU memory.
    """
    # NOTE(lty): experimental for the PT-D 24 research internship project
    if job_config.experimental.torch_spmd:
        return torch_spmd_parallelize(model, world_mesh, parallel_dims, job_config)

    if parallel_dims.tp_enabled:
        if (
@@ -300,11 +351,12 @@ def apply_fsdp(
    mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
    fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}

    # TODO(lty): the check below requires the latest PyTorch nightly; remove for now
    # TODO: remove this check once PyTorch 2.5 is released. We can safely assume
    # that users won't use a nightly build which is older than 20240809 by then.
    if tp_enabled:
        # check if strided sharding is enabled, which is necessary for 2D/3D DCP
        check_strided_sharding_enabled()
    # if tp_enabled:
    #     # check if strided sharding is enabled, which is necessary for 2D/3D DCP
    #     check_strided_sharding_enabled()

    for layer_id, transformer_block in model.layers.items():
        if pp_enabled:
