From a09cde3502fa3e19a0de72246c507fb102649287 Mon Sep 17 00:00:00 2001
From: tianyu-l <150487191+tianyu-l@users.noreply.github.com>
Date: Sat, 7 Sep 2024 19:55:45 -0700
Subject: [PATCH] merge upstream changes (#569)

SimpleFSDP + TP appears to work with PP out of the box, but whole-model
compile does not, so `torch.compile` on SimpleFSDP currently has no effect
when PP is enabled. We still need to figure out the best way to support
that combination.

---------

Co-authored-by: Wanchao
Co-authored-by: Less Wright
Co-authored-by: Will Constable
Co-authored-by: gnadathur
Co-authored-by: gnadathur
Co-authored-by: gnadathur
Co-authored-by: Driss Guessous <32754868+drisspg@users.noreply.github.com>
Co-authored-by: Iris Z <31293777+wz337@users.noreply.github.com>
Co-authored-by: Will Constable
Co-authored-by: Geeta Chauhan <4461127+chauhang@users.noreply.github.com>
Co-authored-by: Andrew Gu <31054793+awgu@users.noreply.github.com>
Co-authored-by: Andrew Gu
Co-authored-by: wz337
Co-authored-by: Soumith Chintala
Co-authored-by: Mark Saroufim
Co-authored-by: Chien-Chin Huang
Co-authored-by: Chien-Chin Huang
Co-authored-by: liangluofb <82682482+liangluofb@users.noreply.github.com>
Co-authored-by: Huy Do
Co-authored-by: Gokul
Co-authored-by: Pavel Belevich
Co-authored-by: Ke Wen
Co-authored-by: Wei (Will) Feng <134637289+weifengpy@users.noreply.github.com>
Co-authored-by: Howard Huang
Co-authored-by: Xilun Wu <12968408+XilunWu@users.noreply.github.com>
Co-authored-by: Sanket Jayant Purandare
Co-authored-by: Yifu Wang
Co-authored-by: Vasiliy Kuznetsov
Co-authored-by: Sanket Jayant Purandare
Co-authored-by: Hugo <6937752+fduwjj@users.noreply.github.com>
---
 .gitignore                                   |   3 +
 benchmark.py                                 | 232 +++++++++++++++++++
 run_benchmark_train.sh                       |  24 ++
 run_llama_train.sh                           |   2 +
 torchtitan/config_manager.py                 |   8 +
 torchtitan/parallelisms/parallelize_llama.py |  60 ++++-
 train.py                                     |  13 +-
 train_configs/benchmark_model.toml           |  39 ++++
 train_configs/debug_model.toml               |  14 +-
 train_configs/llama3_405b.toml               |   8 +-
 train_configs/llama3_8b.toml                 |  10 +-
 11 files changed, 392 insertions(+), 21 deletions(-)
 create mode 100644 benchmark.py
 create mode 100755 run_benchmark_train.sh
 create mode 100644 train_configs/benchmark_model.toml

diff --git a/.gitignore b/.gitignore
index cf5f06e1..00037067 100644
--- a/.gitignore
+++ b/.gitignore
@@ -17,3 +17,6 @@ torchtitan/datasets/**/*.model
 *.log
 error.json
 _remote_module_non_scriptable.py
+
+# torch compile debug related
+torch_compile_debug/*
diff --git a/benchmark.py b/benchmark.py
new file mode 100644
index 00000000..16a706f9
--- /dev/null
+++ b/benchmark.py
@@ -0,0 +1,232 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
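+# Overview: this standalone driver benchmarks a TorchBench model under the
+# experimental torch_spmd / SimpleFSDP path. It builds a device mesh from the
+# job config, loads the model via TorchBench, applies torch_spmd_parallelize
+# (TP + SimpleFSDP + optional whole-model torch.compile), and runs timed
+# training steps, logging loss, peak GPU memory, and per-step GPU/CPU time.
+# Launch it like train.py, e.g. via run_benchmark_train.sh with
+# train_configs/benchmark_model.toml.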
+
+import os
+import time
+from datetime import timedelta
+
+import torch
+from torch.distributed.elastic.multiprocessing.errors import record
+
+from torchbenchmark.util.experiment.instantiator import (
+    load_model,
+    TorchBenchModelConfig,
+)
+from torchbenchmark.util.experiment.metrics import get_model_flops
+from torchbenchmark.util.input import input_cast
+
+from torchtitan import utils
+from torchtitan.checkpoint import TrainState
+from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
+from torchtitan.logging import init_logger, logger
+from torchtitan.metrics import build_gpu_memory_monitor
+from torchtitan.parallelisms import ParallelDims
+from torchtitan.parallelisms.parallelize_llama import torch_spmd_parallelize
+from torchtitan.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling
+
+
+# Enable debug tracing on failure: https://pytorch.org/docs/stable/elastic/errors.html
+@record
+def main(job_config: JobConfig):
+    init_logger()
+    logger.info(f"Starting job: {job_config.job.description}")
+
+    # used for colorful printing
+    color = utils.Color if job_config.metrics.enable_color_printing else utils.NoColor
+
+    # take control of garbage collection to avoid stragglers
+    gc_handler = utils.GarbageCollection(gc_freq=job_config.training.gc_freq)
+
+    # init distributed
+    world_size = int(os.environ["WORLD_SIZE"])
+    parallel_dims = ParallelDims(
+        dp=job_config.training.data_parallel_degree,
+        tp=job_config.training.tensor_parallel_degree,
+        pp=job_config.experimental.pipeline_parallel_degree,
+        world_size=world_size,
+        enable_loss_parallel=job_config.training.enable_loss_parallel,
+        dp_type=job_config.training.data_parallel_type,
+    )
+    device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
+    torch.cuda.set_device(device)
+    utils.init_distributed(job_config)
+    # initialize GPU memory monitor and get peak flops for MFU calculation
+    gpu_memory_monitor = build_gpu_memory_monitor()
+    gpu_peak_flops = utils.get_peak_flops(gpu_memory_monitor.device_name)
+
+    # build meshes
+    world_mesh = parallel_dims.build_mesh(device_type="cuda")
+    if parallel_dims.dp_enabled:
+        dp_mesh = world_mesh["dp"]
+        dp_degree, dp_rank = dp_mesh.size(), dp_mesh.get_local_rank()
+    else:
+        dp_degree, dp_rank = 1, 0
+
+    if parallel_dims.pp_enabled:
+        pp_mesh = world_mesh["pp"]
+
+    model_name = job_config.model.name
+
+    # instantiate model from torchbench
+    config = TorchBenchModelConfig(
+        name=model_name,
+        test="train",
+        device="cuda",
+        batch_size=job_config.training.batch_size,
+        extra_args=[],
+    )
+    model_flops = get_model_flops(config)
+    benchmark_model = load_model(config)
+    model, _ = benchmark_model.get_module()
+
+    # TODO: there seems to be a bug with dtype conversion (e.g. use resnet50)
+    # cast input dtype if needed
+    param_dtype = TORCH_DTYPE_MAP[job_config.training.mixed_precision_param]
+    input_cond = lambda x: x.dtype == torch.float32
+    input_action = lambda x: x.to(param_dtype)
+    if hasattr(benchmark_model, "example_inputs"):
+        benchmark_model.example_inputs = input_cast(
+            input_cond, input_action, benchmark_model.example_inputs
+        )
+    else:
+        logger.warning(
+            f"{model_name} example inputs haven't been cast to {param_dtype} yet!"
+ ) + + # log model size + model_param_count = utils.get_num_params(model) + logger.info( + f"{color.blue}Model {model_name} " + f"{color.red}size: {model_param_count:,} total parameters{color.reset}" + ) + + # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel + model = torch_spmd_parallelize(model, world_mesh, parallel_dims, job_config) + + # update model and optimizer after applying parallelisms + benchmark_model.set_module(model) + optimizer = benchmark_model.get_optimizer() + optimizer.add_param_group({"params": model.parameters()}) + + model.train() + + gpu_mem_stats = gpu_memory_monitor.get_peak_stats() + logger.info( + f"GPU memory usage for model: " + f"{gpu_mem_stats.max_reserved_gib:.2f}GiB" + f"({gpu_mem_stats.max_reserved_pct:.2f}%)" + ) + + train_state = TrainState() + + # variables used to keep info for metrics logging + losses_since_last_log = [] + gpu_memory_monitor.reset_peak_stats() + + # train loop + logger.info( + f"Training starts at step {train_state.step + 1}, " + f"with local batch size {job_config.training.batch_size}, " + f"global batch size {job_config.training.batch_size * dp_degree}, " + f"total steps {job_config.training.steps}" + ) + with maybe_enable_profiling( + job_config, global_step=train_state.step + ) as torch_profiler, maybe_enable_memory_snapshot( + job_config, global_step=train_state.step + ) as memory_profiler: + while train_state.step < job_config.training.steps: + train_state.step += 1 + gc_handler.run(train_state.step) + + torch.cuda.synchronize() + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + # Collect time_ns() instead of time() which does not provide better precision than 1 + # second according to https://docs.python.org/3/library/time.html#time.time. 
+ t0 = time.time_ns() + start_event.record() + + is_staged = ( + hasattr(benchmark_model, "forward") + and hasattr(benchmark_model, "backward") + and hasattr(benchmark_model, "optimizer_step") + ) + if is_staged and (getattr(benchmark_model, "train", None) is None): + if optimizer is not None: + optimizer.zero_grad() + loss = benchmark_model.forward() + benchmark_model.backward(loss) + if optimizer is not None: + benchmark_model.optimizer_step() + else: + loss = benchmark_model.train() + + end_event.record() + torch.cuda.synchronize() + t1 = time.time_ns() + time_delta = start_event.elapsed_time(end_event), (t1 - t0) / 1_000_000 + + # log metrics + losses_since_last_log.append(loss) + if ( + train_state.step == 1 + or train_state.step % job_config.metrics.log_freq == 0 + ): + losses = [ + loss.item() if isinstance(loss, torch.Tensor) else loss + for loss in losses_since_last_log + ] + avg_loss, max_loss = sum(losses) / len(losses), max(losses) + if parallel_dims.dp_enabled: + global_avg_loss, global_max_loss = ( + utils.dist_mean(avg_loss, dp_mesh), + utils.dist_max(max_loss, dp_mesh), + ) + else: + global_avg_loss, global_max_loss = avg_loss, max_loss + + gpu_mem_stats = gpu_memory_monitor.get_peak_stats() + + logger.info( + f"{color.cyan}step: {train_state.step:2} " + f"{color.green}loss: {global_avg_loss:7.4f} " + f"{color.yellow}memory: {gpu_mem_stats.max_reserved_gib:5.2f}GiB" + f"({gpu_mem_stats.max_reserved_pct:.2f}%) " + f"{color.blue}GPU time: {time_delta[0]:.3f}ms " + f"CPU wall time: {time_delta[1]:.3f}ms{color.reset}" + ) + + losses_since_last_log.clear() + gpu_memory_monitor.reset_peak_stats() + + # signal the profiler that the next profiling step has started + if torch_profiler: + torch_profiler.step() + if memory_profiler: + memory_profiler.step() + + # reduce timeout after first train step for faster signal + # (assuming lazy init and compilation are finished) + if train_state.step == 1: + utils.set_pg_timeouts( + timeout=timedelta(seconds=job_config.comm.train_timeout_seconds), + world_mesh=world_mesh, + ) + + if torch.distributed.get_rank() == 0: + logger.info("Sleeping 2 seconds for other ranks to complete") + time.sleep(2) + + logger.info("Training completed") + + +if __name__ == "__main__": + config = JobConfig() + config.parse_args() + main(config) + torch.distributed.destroy_process_group() diff --git a/run_benchmark_train.sh b/run_benchmark_train.sh new file mode 100755 index 00000000..022e74c5 --- /dev/null +++ b/run_benchmark_train.sh @@ -0,0 +1,24 @@ +#!/usr/bin/bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +set -ex + +# use envs as local overrides for convenience +# e.g. 
+# LOG_RANK=0,1 NGPU=4 ./run_llama_train.sh +NGPU=${NGPU:-"8"} +LOG_RANK=${LOG_RANK:-0} +CONFIG_FILE=${CONFIG_FILE:-"./train_configs/benchmark_model.toml"} + +overrides="" +if [ $# -ne 0 ]; then + overrides="$*" +fi + +torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \ +--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \ +benchmark.py --job.config_file ${CONFIG_FILE} $overrides diff --git a/run_llama_train.sh b/run_llama_train.sh index a4107806..296da519 100755 --- a/run_llama_train.sh +++ b/run_llama_train.sh @@ -19,6 +19,8 @@ if [ $# -ne 0 ]; then overrides="$*" fi +# TORCH_TRACE="./outputs/trace" \ +TORCH_NCCL_AVOID_RECORD_STREAMS=1 \ torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \ --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \ train.py --job.config_file ${CONFIG_FILE} $overrides diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 3ba1d102..d6ead31c 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -241,6 +241,14 @@ def __init__(self): action="store_true", help="Whether to apply loss parallel when sequence parallel is enabled", ) + + # experimental configs + self.parser.add_argument( + "--experimental.torch_spmd", + default=False, + action="store_true", + help="Whether to use the experimental torch_spmd style parallelism", + ) self.parser.add_argument( "--experimental.enable_async_tensor_parallel", default=False, diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py index aa07f25f..ef5f1fc7 100644 --- a/torchtitan/parallelisms/parallelize_llama.py +++ b/torchtitan/parallelisms/parallelize_llama.py @@ -29,7 +29,55 @@ from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP from torchtitan.logging import logger from torchtitan.parallelisms.parallel_dims import ParallelDims -from torchtitan.parallelisms.utils import check_strided_sharding_enabled + + +# NOTE(lty): experimental for the PT-D 24 research internship project +def torch_spmd_parallelize( + model: nn.Module, + world_mesh: DeviceMesh, + parallel_dims: ParallelDims, + job_config: JobConfig, +): + torch._inductor.config.simplefsdp.enable_reorder = True + torch._inductor.config.simplefsdp.enable_bucket = True + + if parallel_dims.tp_enabled: + apply_tp( + model, + world_mesh["tp"], + loss_parallel=parallel_dims.loss_parallel_enabled, + enable_float8=job_config.float8.enable_float8_linear, + enable_async_tp=job_config.experimental.enable_async_tensor_parallel, + ) + + ac_config = job_config.activation_checkpoint + if ac_config.mode != "none": + apply_ac(model, ac_config) + logger.info(f"Applied {ac_config.mode} activation checkpointing to the model") + + if parallel_dims.dp_enabled: + from torch_spmd.data_parallel import data_parallel, MixedPrecisionPolicy + + mp_policy = MixedPrecisionPolicy( + param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param], + reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce], + ) + dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh + + model = data_parallel( + model, + dp_mesh, + mode="fully_shard", + ac_mode=ac_config.mode, + mp_policy=mp_policy, + ) + logger.info("Applied Simple FSDP to the model") + + if job_config.training.compile: + model = torch.compile(model, fullgraph=True) + logger.info("Compiling with torch.compile") + + return model def parallelize_llama( @@ -45,6 +93,9 @@ def parallelize_llama( NOTE: The passed-in model preferably should be on 
meta device. Otherwise, the model must fit on GPU or CPU memory. """ + # NOTE(lty): experimental for the PT-D 24 research internship project + if job_config.experimental.torch_spmd: + return torch_spmd_parallelize(model, world_mesh, parallel_dims, job_config) if parallel_dims.tp_enabled: if ( @@ -300,11 +351,12 @@ def apply_fsdp( mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype) fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy} + # TODO(lty): the check below requires the latest PyTorch nightly; remove for now # TODO: remove this check once PyTorch 2.5 is released. We can safely assume # that users won't use a nightly build which is older than 20240809 by then. - if tp_enabled: - # check if strided sharding is enabled, which is necessary for 2D/3D DCP - check_strided_sharding_enabled() + # if tp_enabled: + # # check if strided sharding is enabled, which is necessary for 2D/3D DCP + # check_strided_sharding_enabled() for layer_id, transformer_block in model.layers.items(): if pp_enabled: diff --git a/train.py b/train.py index ffea00a9..4803aa8f 100644 --- a/train.py +++ b/train.py @@ -12,6 +12,9 @@ import torch from torch.distributed.elastic.multiprocessing.errors import record +# context needed by meta-init with torch_spmd +from torch_spmd.data_parallel import disable_data_parallel + from torchtitan import utils from torchtitan.checkpoint import CheckpointManager, TrainState from torchtitan.config_manager import JobConfig @@ -154,16 +157,20 @@ def loss_fn(pred, labels): # apply SPMD-style PT-D techniques models_parallelize_fns[model_name](m, world_mesh, parallel_dims, job_config) m.to_empty(device="cuda") - m.init_weights() + with disable_data_parallel() if job_config.experimental.torch_spmd else contextlib.nullcontext(): + m.init_weights() m.train() else: # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel - models_parallelize_fns[model_name](model, world_mesh, parallel_dims, job_config) + model = models_parallelize_fns[model_name]( + model, world_mesh, parallel_dims, job_config + ) # move sharded model to CPU/GPU and initialize weights via DTensor init_device = "cpu" if job_config.checkpoint.create_seed_checkpoint else "cuda" model.to_empty(device=init_device) - model.init_weights() + with disable_data_parallel() if job_config.experimental.torch_spmd else contextlib.nullcontext(): + model.init_weights() model.train() model_parts = [model] diff --git a/train_configs/benchmark_model.toml b/train_configs/benchmark_model.toml new file mode 100644 index 00000000..c2f37d04 --- /dev/null +++ b/train_configs/benchmark_model.toml @@ -0,0 +1,39 @@ +# torchtitan Config.toml + +[job] +dump_folder = "./outputs" +description = "torchbenchmark training" + +[profiling] +enable_profiling = false +save_traces_folder = "profile_trace" +profile_freq = 10 +enable_memory_snapshot = false +save_memory_snapshot_folder = "memory_snapshot" + +[metrics] +log_freq = 1 +enable_color_printing = true +enable_tensorboard = false +save_tb_folder = "tb" + +[model] +# name = "resnet50" +name = "hf_GPT2" + +[training] +batch_size = 8 +max_norm = 1.0 # grad norm clipping +steps = 10 +data_parallel_degree = -1 +compile = true +mixed_precision_param = "bfloat16" +mixed_precision_reduce = "bfloat16" +# mixed_precision_param = "float32" +# mixed_precision_reduce = "float32" + +[experimental] +torch_spmd = true + +[activation_checkpoint] +mode = 'none' # ['none', 'selective', 'full'] diff --git a/train_configs/debug_model.toml 
b/train_configs/debug_model.toml index af547214..f401b791 100644 --- a/train_configs/debug_model.toml +++ b/train_configs/debug_model.toml @@ -6,7 +6,7 @@ description = "Llama 3 debug training" use_for_integration_test = true [profiling] -enable_profiling = true +enable_profiling = false save_traces_folder = "profile_trace" profile_freq = 10 enable_memory_snapshot = false @@ -15,7 +15,7 @@ save_memory_snapshot_folder = "memory_snapshot" [metrics] log_freq = 1 enable_color_printing = true -enable_tensorboard = true +enable_tensorboard = false save_tb_folder = "tb" [model] @@ -37,12 +37,14 @@ max_norm = 1.0 # grad norm clipping steps = 10 data_parallel_degree = -1 tensor_parallel_degree = 1 -compile = false +compile = true dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M) [experimental] -pipeline_parallel_degree = 1 +# pipeline_parallel_degree = 2 +# pipeline_parallel_split_points = ["layers.4"] enable_async_tensor_parallel = false +torch_spmd = true [checkpoint] enable_checkpoint = false @@ -54,8 +56,8 @@ export_dtype = "float32" async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] [activation_checkpoint] -mode = 'selective' # ['none', 'selective', 'full'] -selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac based on ops policy +mode = 'none' # ['none', 'selective', 'full'] +selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy [float8] enable_float8_linear = false diff --git a/train_configs/llama3_405b.toml b/train_configs/llama3_405b.toml index 1a83301f..1cceb14b 100644 --- a/train_configs/llama3_405b.toml +++ b/train_configs/llama3_405b.toml @@ -38,7 +38,7 @@ dataset = "c4" [experimental] pipeline_parallel_degree = 1 -enable_async_tensor_parallel = true +enable_async_tensor_parallel = false [checkpoint] enable_checkpoint = false @@ -53,6 +53,6 @@ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] mode = 'full' # ['none', 'selective', 'full'] [float8] -enable_float8_linear = true -enable_fsdp_float8_all_gather = true -precompute_float8_dynamic_scale_for_fsdp = true +enable_float8_linear = false +enable_fsdp_float8_all_gather = false +precompute_float8_dynamic_scale_for_fsdp = false diff --git a/train_configs/llama3_8b.toml b/train_configs/llama3_8b.toml index 3d0c5160..c96278b7 100644 --- a/train_configs/llama3_8b.toml +++ b/train_configs/llama3_8b.toml @@ -6,13 +6,14 @@ dump_folder = "./outputs" description = "Llama 3 8B training" [profiling] -enable_profiling = true +enable_profiling = false save_traces_folder = "profile_trace" profile_freq = 100 [metrics] log_freq = 10 -enable_tensorboard = true +enable_color_printing = false +enable_tensorboard = false save_tb_folder = "tb" [model] @@ -33,11 +34,12 @@ max_norm = 1.0 # grad norm clipping steps = 1000 data_parallel_degree = -1 tensor_parallel_degree = 1 -compile = false +compile = true dataset = "c4" [experimental] pipeline_parallel_degree = 1 +torch_spmd = true [checkpoint] enable_checkpoint = false @@ -49,7 +51,7 @@ export_dtype = "float32" async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] [activation_checkpoint] -mode = 'selective' # ['none', 'selective', 'full'] +mode = 'none' # ['none', 'selective', 'full'] selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy [float8]
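Note: for readers who want to try the experimental path outside torchtitan, the core of the `torch_spmd_parallelize` function added above reduces to the composition sketched below. This is a minimal illustration, assuming the external `torch_spmd` package (providing `data_parallel` and `MixedPrecisionPolicy`, as imported in the patch) is installed and the script is launched with torchrun; it deliberately omits TP, activation checkpointing, and the inductor bucketing/reordering flags set in the patched code, and the toy model here is purely hypothetical.

import os

import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh

# Experimental dependency used by this patch; assumed to be installed.
from torch_spmd.data_parallel import data_parallel, MixedPrecisionPolicy


def apply_simple_fsdp_and_compile(model: nn.Module, dp_mesh) -> nn.Module:
    # Shard parameters/gradients SimpleFSDP-style ("fully_shard" mode) with
    # bf16 compute and bf16 gradient reduction, matching benchmark_model.toml.
    mp_policy = MixedPrecisionPolicy(
        param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16
    )
    model = data_parallel(
        model, dp_mesh, mode="fully_shard", ac_mode="none", mp_policy=mp_policy
    )
    # Whole-model compile; per the note at the top of this patch, this is the
    # step that currently does not take effect once PP splits the model.
    return torch.compile(model, fullgraph=True)


if __name__ == "__main__":
    # Illustrative 1-D data-parallel mesh sized to the torchrun world size.
    world_size = int(os.environ["WORLD_SIZE"])
    dp_mesh = init_device_mesh("cuda", (world_size,), mesh_dim_names=("dp",))

    # Hypothetical toy model; in the patch, the model comes from TorchBench
    # (benchmark.py) or from torchtitan's Llama builder (train.py).
    toy_model = nn.Sequential(
        nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024)
    ).cuda()
    toy_model = apply_simple_fsdp_and_compile(toy_model, dp_mesh)
    # A training step would then proceed as in benchmark.py: forward, backward,
    # optimizer step, with communication handled by the sharded wrapper.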