From f186b2f540c4418dfc27e0a26fba8d49b0104484 Mon Sep 17 00:00:00 2001
From: Will Constable
Date: Fri, 31 May 2024 13:55:38 -0700
Subject: [PATCH] Update (base update)

[ghstack-poisoned]
---
 .ci/docker/requirements.txt                   |  2 +-
 ...riodic.yaml => integration_test_4gpu.yaml} |  5 +-
 .github/workflows/unit_test_4gpu.yaml         | 36 ----------
 .gitignore                                    |  5 ++
 README.md                                     |  2 +-
 test_runner.py                                | 65 ++++++++++++-------
 torchtitan/config_manager.py                  |  6 ++
 torchtitan/metrics.py                         | 15 +++--
 torchtitan/parallelisms/parallelize_llama.py  | 46 ++++++-------
 train_configs/debug_model.toml                |  3 +-
 10 files changed, 92 insertions(+), 93 deletions(-)
 rename .github/workflows/{integration_test_periodic.yaml => integration_test_4gpu.yaml} (94%)
 delete mode 100644 .github/workflows/unit_test_4gpu.yaml
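
Usage sketch for the new --test flag (an illustrative note, not part of the
original commit message; the ./outputs path is an arbitrary example). A single
test flavor can now be run locally by its test_name:

    python ./test_runner.py ./outputs --test pp_1f1b

The default --test all keeps the old run-everything behavior, and flavors
marked requires_seed_checkpoint=True get a seed checkpoint created first via
./create_seed_checkpoint.sh, pointed at the same per-test dump folder.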
diff --git a/.ci/docker/requirements.txt b/.ci/docker/requirements.txt
index bb21293b..520eb8d4 100644
--- a/.ci/docker/requirements.txt
+++ b/.ci/docker/requirements.txt
@@ -1,4 +1,4 @@
-torch >= 2.2.0.dev
+torch >= 2.3.0
 datasets >= 2.19.0
 tomli >= 1.1.0 ; python_version < "3.11"
 tensorboard
diff --git a/.github/workflows/integration_test_periodic.yaml b/.github/workflows/integration_test_4gpu.yaml
similarity index 94%
rename from .github/workflows/integration_test_periodic.yaml
rename to .github/workflows/integration_test_4gpu.yaml
index 87c9005a..4cf29da4 100644
--- a/.github/workflows/integration_test_periodic.yaml
+++ b/.github/workflows/integration_test_4gpu.yaml
@@ -1,6 +1,9 @@
-name: GPU Integration Test
+name: 4 GPU Integration Test
 
 on:
+  push:
+    branches: [ main ]
+  pull_request:
   schedule:
     # Runs hourly
     - cron: '0 * * * *'
diff --git a/.github/workflows/unit_test_4gpu.yaml b/.github/workflows/unit_test_4gpu.yaml
deleted file mode 100644
index 6f052868..00000000
--- a/.github/workflows/unit_test_4gpu.yaml
+++ /dev/null
@@ -1,36 +0,0 @@
-name: 4 GPU Unit Test
-
-on:
-  push:
-    branches: [ main ]
-  pull_request:
-
-concurrency:
-  group: unit-test${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_number || github.ref }}
-  cancel-in-progress: true
-
-jobs:
-  build-test:
-    uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
-    with:
-      runner: linux.g5.12xlarge.nvidia.gpu
-      gpu-arch-type: cuda
-      gpu-arch-version: "12.1"
-      # This image is faster to clone than the default, but it lacks CC needed by triton
-      # (1m25s vs 2m37s).
-      docker-image: torchtitan-ubuntu-20.04-clang12
-      repository: pytorch/torchtitan
-      upload-artifact: outputs
-      script: |
-        set -eux
-
-        # The generic Linux job chooses to use base env, not the one setup by the image
-        CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]")
-        conda activate "${CONDA_ENV}"
-
-        pip config --user set global.progress_bar off
-
-        python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
-        python -m pip install --pre torchdata --index-url https://download.pytorch.org/whl/nightly/
-        mkdir artifacts-to-be-uploaded
-        python ./test_runner.py artifacts-to-be-uploaded
diff --git a/.gitignore b/.gitignore
index 4f9856f4..cf5f06e1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -12,3 +12,8 @@ out
 wandb
 
 torchtitan/datasets/**/*.model
+
+# temp files
+*.log
+error.json
+_remote_module_non_scriptable.py
diff --git a/README.md b/README.md
index 1c32a9d8..d47f6c67 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,4 @@
-[![GPU Integration Test](https://github.com/pytorch/torchtitan/actions/workflows/unit_test_4gpu.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/unit_test_4gpu.yaml)
+[![4 GPU Integration Test](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_4gpu.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_4gpu.yaml)
 
 # torchtitan
 
diff --git a/test_runner.py b/test_runner.py
index 834fc080..31f53034 100755
--- a/test_runner.py
+++ b/test_runner.py
@@ -29,11 +29,12 @@ class OverrideDefinitions:
 
     override_args: Sequence[Sequence[str]] = tuple(tuple(" "))
     test_descr: str = "default"
+    test_name: str = "default"
     requires_seed_checkpoint: bool = False
     ngpu: int = 4
 
 
-def build_test_list(args):
+def build_test_list():
     """
     key is the config file name and value is a list of OverrideDefinitions
     that is used to generate variations of integration tests based on the
@@ -45,7 +46,6 @@ def build_test_list(args):
             [
                 [
                     "--checkpoint.enable_checkpoint",
-                    f"--job.dump_folder {args.output_dir}/pp_1f1b/",
                     "--experimental.pipeline_parallel_degree 2",
                     "--experimental.pipeline_parallel_split_points layers.1",
                     "--experimental.pipeline_parallel_schedule 1f1b",
@@ -53,6 +53,7 @@ def build_test_list(args):
                 ],
             ],
             "PP 1D test 1f1b",
+            "pp_1f1b",
             requires_seed_checkpoint=True,
             ngpu=2,
         ),
@@ -60,7 +61,6 @@ def build_test_list(args):
             [
                 [
                     "--checkpoint.enable_checkpoint",
-                    f"--job.dump_folder {args.output_dir}/pp_gpipe/",
                     "--experimental.pipeline_parallel_degree 2",
                     "--experimental.pipeline_parallel_split_points layers.1",
                     "--experimental.pipeline_parallel_schedule gpipe",
@@ -68,6 +68,7 @@ def build_test_list(args):
                 ],
             ],
             "PP 1D test gpipe",
+            "pp_gpipe",
             requires_seed_checkpoint=True,
             ngpu=2,
         ),
@@ -75,7 +76,6 @@ def build_test_list(args):
             [
                 [
                     "--checkpoint.enable_checkpoint",
-                    f"--job.dump_folder {args.output_dir}/pp_dp_1f1b/",
                     "--experimental.pipeline_parallel_degree 2",
                     "--experimental.pipeline_parallel_split_points layers.1",
                     "--experimental.pipeline_parallel_schedule 1f1b",
@@ -83,13 +83,13 @@ def build_test_list(args):
                 ],
             ],
             "PP+DP 1f1b 2D test",
+            "pp_dp_1f1b",
             requires_seed_checkpoint=True,
         ),
         OverrideDefinitions(
             [
                 [
                     "--checkpoint.enable_checkpoint",
-                    f"--job.dump_folder {args.output_dir}/pp_dp_gpipe/",
                     "--experimental.pipeline_parallel_degree 2",
                     "--experimental.pipeline_parallel_split_points layers.1",
                     "--experimental.pipeline_parallel_schedule gpipe",
@@ -97,13 +97,13 @@ def build_test_list(args):
                 ],
             ],
             "PP+DP gpipe 2D test",
"pp_dp_gpipe", requires_seed_checkpoint=True, ), OverrideDefinitions( [ [ "--checkpoint.enable_checkpoint", - f"--job.dump_folder {args.output_dir}/pp_tp/", "--experimental.pipeline_parallel_degree 2", "--experimental.pipeline_parallel_split_points layers.1", "--training.tensor_parallel_degree 2", @@ -111,77 +111,89 @@ def build_test_list(args): ], ], "PP+TP 2D test", + "pp_tp", requires_seed_checkpoint=True, ), OverrideDefinitions( [ [ - f"--job.dump_folder {args.output_dir}/default/", + "--checkpoint.enable_checkpoint", + "--experimental.pipeline_parallel_degree 2", + "--experimental.pipeline_parallel_split_points layers.1", + "--model.norm_type rmsnorm", # fused_rmsnorm not yet compatible with tracer ], ], - "Default", + "PP tracer frontend test", + "pp_tracer", + requires_seed_checkpoint=True, + ), + OverrideDefinitions( + [ + [], + ], + "default", + "default", ), OverrideDefinitions( [ [ "--training.compile --model.norm_type=rmsnorm", - f"--job.dump_folder {args.output_dir}/1d_compile/", ], ], "1D compile", + "1d_compile", ), OverrideDefinitions( [ [ "--training.compile --training.tensor_parallel_degree 2 --model.norm_type=rmsnorm", - f"--job.dump_folder {args.output_dir}/2d_compile/", ], ], "2D compile", + "2d_compile", ), OverrideDefinitions( [ [ "--training.tensor_parallel_degree 2 --model.norm_type=rmsnorm", - f"--job.dump_folder {args.output_dir}/eager_2d/", ], ], "Eager mode 2DParallel", + "eager_2d", ), OverrideDefinitions( [ [ "--checkpoint.enable_checkpoint", - f"--job.dump_folder {args.output_dir}/full_checkpoint/", ], [ "--checkpoint.enable_checkpoint", - f"--job.dump_folder {args.output_dir}/full_checkpoint/", "--training.steps 20", ], ], "Checkpoint Integration Test - Save Load Full Checkpoint", + "full_checkpoint", ), OverrideDefinitions( [ [ "--checkpoint.enable_checkpoint", - f"--job.dump_folder {args.output_dir}/model_weights_only_fp32/", "--checkpoint.model_weights_only", ], ], "Checkpoint Integration Test - Save Model Weights Only fp32", + "model_weights_only_fp32", ), OverrideDefinitions( [ [ "--checkpoint.enable_checkpoint", - f"--job.dump_folder {args.output_dir}/model_weights_only_bf16/", "--checkpoint.model_weights_only", "--checkpoint.export_dtype bfloat16", ], ], "Checkpoint Integration Test - Save Model Weights Only bf16", + "model_weights_only_bf16", ), ] return integration_tests_flavors @@ -197,11 +209,15 @@ def _run_cmd(cmd): ) -def run_test(test_flavor: OverrideDefinitions, full_path: str): +def run_test(test_flavor: OverrideDefinitions, full_path: str, output_dir: str): # run_test supports sequence of tests. 
     for override_arg in test_flavor.override_args:
+        test_name = test_flavor.test_name
+        dump_folder_arg = f"--job.dump_folder {output_dir}/{test_name}"
 
         cmd = f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK=0,1,2,3 ./run_llama_train.sh"
+        cmd += " " + dump_folder_arg
+
         if override_arg:
             cmd += " " + " ".join(override_arg)
         logger.info(
@@ -209,13 +225,6 @@ def run_test(test_flavor: OverrideDefinitions, full_path: str):
         )
 
         if test_flavor.requires_seed_checkpoint:
-            dump_folder_arg = None
-            for arg in override_arg:
-                if "--job.dump_folder" in arg:
-                    dump_folder_arg = arg
-            assert (
-                dump_folder_arg is not None
-            ), "Can't use seed checkpoint if folder is not specified"
             logger.info("Creating seed checkpoint")
             result = _run_cmd(
                 f"CONFIG_FILE={full_path} ./create_seed_checkpoint.sh {dump_folder_arg}"
@@ -231,7 +240,7 @@ def run_test(test_flavor: OverrideDefinitions, full_path: str):
 
 
 def run_tests(args):
-    integration_tests_flavors = build_test_list(args)
+    integration_tests_flavors = build_test_list()
     for config_file in os.listdir(args.config_dir):
         if config_file.endswith(".toml"):
             full_path = os.path.join(args.config_dir, config_file)
@@ -242,13 +251,19 @@ def run_tests(args):
                 )
                 if is_integration_test:
                     for test_flavor in integration_tests_flavors[config_file]:
-                        run_test(test_flavor, full_path)
+                        if args.test == "all" or test_flavor.test_name == args.test:
+                            run_test(test_flavor, full_path, args.output_dir)
 
 
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("output_dir")
     parser.add_argument("--config_dir", default="./train_configs")
+    parser.add_argument(
+        "--test",
+        default="all",
+        help="test to run, acceptable values: `test_name` in `build_test_list` (default: all)",
+    )
     args = parser.parse_args()
 
     if not os.path.exists(args.output_dir):
diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index da80b425..6a730dcb 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -125,6 +125,12 @@ def __init__(self):
             default="tb",
             help="Folder to dump TensorBoard states",
         )
+        self.parser.add_argument(
+            "--metrics.rank_0_only",
+            default=True,
+            action="store_true",
+            help="Whether to save TensorBoard metrics only for rank 0 or for all ranks",
+        )
 
         # model configs
         self.parser.add_argument(
diff --git a/torchtitan/metrics.py b/torchtitan/metrics.py
index 90108976..b9b9cabd 100644
--- a/torchtitan/metrics.py
+++ b/torchtitan/metrics.py
@@ -113,16 +113,21 @@ def close(self):
 
 def build_metric_logger(config: JobConfig, tag: Optional[str] = None):
     dump_dir = config.job.dump_folder
-    save_tb_folder = config.metrics.save_tb_folder
-    # since we don't have run id yet, use current minute as identifier
+    tb_config = config.metrics
+    save_tb_folder = tb_config.save_tb_folder
+    # since we don't have run id, use current minute as the identifier
     datetime_str = datetime.now().strftime("%Y%m%d-%H%M")
     log_dir = os.path.join(dump_dir, save_tb_folder, datetime_str)
 
-    enable_tb = config.metrics.enable_tensorboard
+    enable_tb = tb_config.enable_tensorboard
     if enable_tb:
         logger.info(
             f"Metrics logging active. Tensorboard logs will be saved at {log_dir}"
         )
+        if tb_config.rank_0_only:
+            enable_tb = torch.distributed.get_rank() == 0
+        else:
+            rank_str = f"rank_{torch.distributed.get_rank()}"
+            log_dir = os.path.join(log_dir, rank_str)
 
-    rank_str = f"rank_{torch.distributed.get_rank()}"
-    return MetricLogger(os.path.join(log_dir, rank_str), tag, enable_tb)
+    return MetricLogger(log_dir, tag, enable_tb)
diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index 425d3abe..3617eb23 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -18,10 +18,11 @@
     checkpoint_wrapper as ptd_checkpoint_wrapper,
     CheckpointImpl,
 )
-from torch.distributed.pipelining import pipeline, SplitPoint
-from torch.distributed.pipelining.PipelineStage import (
-    _PipelineStage,
+from torch.distributed.pipelining import (
     ManualPipelineStage,
+    pipeline,
+    PipelineStage,
+    SplitPoint,
 )
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
@@ -159,6 +160,14 @@ def _llama_trace_input(job_config, model_config, device="meta"):
     return (tokens,)
 
 
+def _mixed_precision_dtype(
+    job_config: JobConfig, parallel_dims, default: torch.dtype = torch.float32
+) -> torch.dtype:
+    """Get the mixed precision dtype if fsdp is enabled, otherwise return the default"""
+    mp_arg = job_config.training.mixed_precision_param
+    return TORCH_DTYPE_MAP[mp_arg] if parallel_dims.dp_enabled else default
+
+
 def pipeline_llama_manual(
     model, world_mesh, parallel_dims, job_config: JobConfig, device, model_config: Dict
 ):
@@ -204,8 +213,7 @@ def pipeline_llama_manual(
     # TODO(whc) once ManualPipelineStage supports lazy shape inference, we can leave model on meta device longer and
     # get rid of the input shape hardcoded here. For now, it should not be a big deal since we only materialize the
     # layers of the model that map to this stage, not the whole model.
-    mp_arg = job_config.training.mixed_precision_param
-    mp_dtype = TORCH_DTYPE_MAP[mp_arg] if parallel_dims.dp_enabled else torch.float32
+    mp_dtype = _mixed_precision_dtype(job_config, parallel_dims)
     batch_size = job_config.training.batch_size
     local_seq_len = int(job_config.training.seq_len // parallel_dims.tp)
     layers_io_shape = (batch_size, local_seq_len, model_config.dim)
     output_layer_shape = (
         batch_size,
         job_config.training.seq_len,
         model_config.vocab_size,
     )
     if pp_rank == 0:
         # first layer
-        input = torch.randint(
-            model_config.vocab_size,
-            size=(batch_size, job_config.training.seq_len),
-            dtype=torch.int64,
-            device=device,
-        )
+        (input,) = _llama_trace_input(job_config, model_config, device=device)
     else:
         # later layers (assume all start w/ a transformer layer)
         input = torch.rand(layers_io_shape, dtype=mp_dtype, device=device)
@@ -257,21 +260,21 @@ def pipeline_llama_tracer(
             "fused_rmsnorm not yet compatible with Pipeline Tracer (strides error). Please use layernorm or rmsnorm."
         )
 
-    # TODO(whc) maybe we can just fix this by feeding bf16 into the tracer for its input shapes?
-    raise NotImplementedError(
-        "pipeline tracer doesn't work with fsdp mixed precision currently. "
-        "To work around, edit fsdp mixed precision config to use fp32."
-    )
+    if _mixed_precision_dtype(job_config, parallel_dims) == torch.bfloat16:
+        raise NotImplementedError(
+            "pipeline tracer doesn't work with fsdp mixed precision currently. "
+            "To work around, edit fsdp mixed precision config to use fp32."
+        )
+
     pp_mesh = world_mesh["pp"]
     pp_rank = pp_mesh.get_local_rank()
-    stage_idx = pp_mesh.get_local_rank()
+    stage_idx = pp_rank
     layers_per_rank = len(model.layers) // parallel_dims.pp
     split_spec = {
         f"layers.{i * layers_per_rank}": SplitPoint.BEGINNING
         for i in range(1, parallel_dims.pp)
     }
-
     # Create a pipeline representation from the model
     pipe = pipeline(
         model,
         job_config.experimental.pipeline_parallel_microbatches or parallel_dims.pp,
         example_args=_llama_trace_input(job_config, model_config),
         split_spec=split_spec,
     )
     model = pipe.get_stage_module(stage_idx)
-    stage = _PipelineStage(
-        stage_module=model,
-        stage_index=pp_rank,
-        pipe_info=pipe.pipe_info,
+    stage = PipelineStage(
+        pipe,
+        stage_index=stage_idx,
         device=device,
         group=pp_mesh.get_group(),
     )
diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml
index 009348b5..5d7e9987 100644
--- a/train_configs/debug_model.toml
+++ b/train_configs/debug_model.toml
@@ -3,7 +3,6 @@
 [job]
 dump_folder = "./outputs"
 description = "Llama 3 debug training"
-# TODO: turn this back on once ci have tokenizer
 use_for_integration_test = true
 
 [profiling]
@@ -50,7 +49,7 @@ interval_type = "steps"
 interval = 5
 model_weights_only = false
 export_dtype = "float32"
-async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
+async_mode = "disabled"  # ["disabled", "async", "async_with_pinned_mem"]
 
 [activation_checkpoint]
 mode = 'selective'  # ['none', 'selective', 'full']
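
A possible TOML counterpart of the new metrics.rank_0_only knob (hypothetical
snippet; the [metrics] section and key names are inferred from the
"--metrics.*" argument convention in config_manager.py):

    [metrics]
    enable_tensorboard = true
    save_tb_folder = "tb"
    rank_0_only = true

With rank_0_only = true, build_metric_logger() enables the TensorBoard writer
only on rank 0, logging directly under the dated log_dir; with
rank_0_only = false, every rank writes into its own rank_{N} subfolder, which
matches the previous behavior. Note that the CLI flag is registered with
action="store_true" and default=True, so turning it off is a job-config (TOML)
change rather than a command-line one.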