From ecacd5b614d2df6d589b8af1e2aad6a570d3511d Mon Sep 17 00:00:00 2001 From: Md Fahim Faysal Khan Date: Thu, 5 Sep 2024 11:57:27 -0700 Subject: [PATCH 01/11] remove deprecated XLA flag (#1010) 1. `xla_gpu_enable_triton_gemm` is still needed. 2. Removed some other deprecated XLA flags: `xla_gpu_enable_triton_softmax_fusion`. 3. Also removed some XLA flags that are now turned on by default, e.g. `xla_gpu_enable_async_all_gather`. --- .github/container/test-maxtext.sh | 5 ++--- README.md | 2 ++ rosetta/docs/GPU_performance.md | 6 +++++- rosetta/docs/NATIVE_FP8.md | 13 +++++-------- rosetta/docs/PGLE.md | 1 - rosetta/rosetta/projects/maxtext/README.md | 8 ++------ .../projects/maxtext/scripts/example_slurm.sub | 8 +------- rosetta/rosetta/projects/pax/README.md | 8 ++++---- 8 files changed, 21 insertions(+), 30 deletions(-) diff --git a/.github/container/test-maxtext.sh b/.github/container/test-maxtext.sh index 164fa5912..0dc26c8c1 100755 --- a/.github/container/test-maxtext.sh +++ b/.github/container/test-maxtext.sh @@ -223,7 +223,7 @@ export NVTE_FUSED_ATTN=${ENABLE_FUSED_ATTN} export XLA_PYTHON_CLIENT_MEM_FRACTION=${MEM_FRACTION} export CUDA_DEVICE_MAX_CONNECTIONS=1 -export BASE_XLA_FLAGS=${BASE_XLA_FLAGS:---xla_gpu_enable_latency_hiding_scheduler=true +export BASE_XLA_FLAGS=${BASE_XLA_FLAGS:---xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_triton_gemm=false --xla_gpu_graph_level=0 --xla_gpu_all_reduce_combine_threshold_bytes=1073741824 @@ -232,8 +232,7 @@ export BASE_XLA_FLAGS=${BASE_XLA_FLAGS:---xla_gpu_enable_latency_hiding_schedule --xla_gpu_enable_pipelined_all_gather=true --xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_pipelined_all_reduce=true - --xla_gpu_enable_while_loop_double_buffering=true - --xla_gpu_enable_triton_softmax_fusion=false + --xla_gpu_enable_while_loop_double_buffering=true --xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_enable_reduce_scatter_combine_by_dim=false --xla_disable_hlo_passes=rematerialization} diff --git a/README.md b/README.md index 66d9b2a4e..1764c5f00 100644 --- a/README.md +++ b/README.md @@ -300,6 +300,8 @@ The [JAX image](https://github.com/NVIDIA/JAX-Toolbox/pkgs/container/jax) is emb There are various other XLA flags users can set to improve performance. For a detailed explanation of these flags, please refer to the [GPU performance](./rosetta/docs/GPU_performance.md) doc. XLA flags can be tuned per workflow. For example, each script in [contrib/gpu/scripts_gpu](https://github.com/google/paxml/tree/main/paxml/contrib/gpu/scripts_gpu) sets its own [XLA flags](https://github.com/google/paxml/blob/93fbc8010dca95af59ab615c366d912136b7429c/paxml/contrib/gpu/scripts_gpu/benchmark_gpt_multinode.sh#L30-L33). +For a list of previously used XLA flags that are no longer needed, please also refer to the [GPU performance](./rosetta/docs/GPU_performance.md#previously-used-xla-flags) page. + ## Profiling JAX programs on GPU See [this page](./docs/profiling.md) for more information about how to profile JAX programs on GPU.
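A quick way to check whether one of these flags has been removed outright from a given jaxlib build (rather than merely deprecated) is to pass it to a trivial program via `XLA_FLAGS`: XLA usually still accepts deprecated flags as no-ops, but it typically aborts at backend initialization when it sees a flag it no longer knows. A minimal sketch, using one of the removed flags purely as an example:

```bash
# Sketch only: probe whether the installed XLA still recognizes a flag.
# A removed flag typically triggers a fatal "unknown flag" error at startup;
# a deprecated-but-present flag is normally accepted as a no-op.
XLA_FLAGS="--xla_gpu_enable_triton_softmax_fusion=false" \
  python -c "import jax; print(jax.numpy.zeros(1))"
```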
diff --git a/rosetta/docs/GPU_performance.md b/rosetta/docs/GPU_performance.md index c5456e3c4..fabbc6963 100644 --- a/rosetta/docs/GPU_performance.md +++ b/rosetta/docs/GPU_performance.md @@ -128,6 +128,10 @@ Fine-grain control to improve performance by initializing a NCCL communicator to - --xla_gpu_enable_cudnn_fmha=false (enables XLA pattern matcher to detect multi-headed attention pattern in JAX) - --xla_disable_hlo_passes=<> (turns off specific HLO passes; can be used for debugging) +## Previously used XLA Flags - +The following flags were previously used but are no longer required. +- --xla_gpu_enable_async_reduce_scatter, --xla_gpu_enable_async_all_reduce, --xla_gpu_enable_async_all_gather ; Turned on by default, no longer needed +- --xla_gpu_enable_highest_priority_async_stream ; Turned on by default +- --xla_gpu_enable_triton_softmax_fusion ; Deprecated, no longer used diff --git a/rosetta/docs/NATIVE_FP8.md b/rosetta/docs/NATIVE_FP8.md index dd3aa1bae..069b06fdd 100644 --- a/rosetta/docs/NATIVE_FP8.md +++ b/rosetta/docs/NATIVE_FP8.md @@ -111,13 +111,11 @@ Enabling this feature is effortless. Users only need to include the option `--fd In addition to the suggested XLA flags mentioned in [this section](https://github.com/NVIDIA/JAX-Toolbox/blob/main/rosetta/rosetta/projects/pax/README.md#xla-flags), we also recommend setting these following XLA flags. The execution script should look like: ```bash export XLA_FLAGS=" \ - --xla_gpu_enable_reduction_epilogue_fusion=false \ --xla_gpu_enable_triton_gemm=false \ - --xla_gpu_enable_cudnn_fmha=false \ - --xla_gpu_enable_cudnn_layer_norm=true \ - --xla_gpu_enable_cublaslt=true \ - --xla_gpu_enable_latency_hiding_scheduler=true \ - --xla_gpu_all_reduce_combine_threshold_bytes=51200 " + --xla_gpu_enable_pipelined_all_reduce=false \ + --xla_gpu_enable_pipelined_all_gather=false \ + --xla_gpu_enable_pipelined_reduce_scatter=false \ +" export ENABLE_TE=0 python -m paxml.main \ ... @@ -125,8 +123,7 @@ python -m paxml.main \ ... ``` -Please ensure you include the first two flags, `--xla_gpu_enable_reduction_epilogue_fusion=false` and `--xla_gpu_enable_triton_gemm=false`, as they are essential for enabling the FP8 functionality. The additional flags primarily focus on performance enhancement and should also prove beneficial for non-FP8 executions. - +Please note that disabling the Triton GEMM and the pipelined collectives is essential for FP8 functionality and performance. ## Transformer Engine vs Native FP8 Support Native XLA-FP8 specifically targets matrix multiplication operations. In contrast, the Transformer Engine focuses on enhancing the overall performance of the entire transformer layer. This encompasses not only the FP8 matrix multiplication but also attention mechanisms, layer normalizations, and other components.
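To confirm that FP8 GEMMs are actually being emitted under these flags, one rough check is to dump the optimized HLO and search it for FP8 operand types or FP8 cuBLASLt custom calls. This is a sketch only: the dump directory is arbitrary, and the exact dump file names and custom-call target strings can differ between XLA versions.

```bash
# Sketch only: dump post-optimization HLO and look for FP8 GEMMs.
export XLA_FLAGS="${XLA_FLAGS} --xla_dump_to=/tmp/hlo_dump"
python -m paxml.main ...   # same arguments as in the snippet above
grep -l f8 /tmp/hlo_dump/*after_optimizations*.txt
```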
diff --git a/rosetta/docs/PGLE.md b/rosetta/docs/PGLE.md index 02e5f5294..2425ddffe 100644 --- a/rosetta/docs/PGLE.md +++ b/rosetta/docs/PGLE.md @@ -70,7 +70,6 @@ export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_pipelined_all_reduce=true --xla_gpu_enable_while_loop_double_buffering=true ---xla_gpu_enable_triton_softmax_fusion=false --xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_enable_reduce_scatter_combine_by_dim=false --xla_disable_hlo_passes=rematerialization diff --git a/rosetta/rosetta/projects/maxtext/README.md b/rosetta/rosetta/projects/maxtext/README.md index fde5a9125..2320a7ed9 100644 --- a/rosetta/rosetta/projects/maxtext/README.md +++ b/rosetta/rosetta/projects/maxtext/README.md @@ -67,12 +67,9 @@ In order to obtain the best performance, please set the appropriate XLA flags. W The [GPU Performance document](../../../docs/GPU_performance.md) provides a detailed description of the XLA flags that can be set to optimize performance. These are the recommended XLA flags to get good performance for MaxText. ``` -XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true - --xla_gpu_enable_async_all_gather=true - --xla_gpu_enable_async_reduce_scatter=true +XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_triton_gemm=false - --xla_gpu_graph_level=0 - --xla_gpu_enable_async_all_reduce=true + --xla_gpu_graph_level=0 --xla_gpu_all_reduce_combine_threshold_bytes=1073741824 --xla_gpu_all_gather_combine_threshold_bytes=1073741824 --xla_gpu_reduce_scatter_combine_threshold_bytes=134217728 @@ -80,7 +77,6 @@ XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_pipelined_all_reduce=true --xla_gpu_enable_while_loop_double_buffering=true - --xla_gpu_enable_triton_softmax_fusion=false --xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_enable_reduce_scatter_combine_by_dim=false --xla_disable_hlo_passes=rematerialization" diff --git a/rosetta/rosetta/projects/maxtext/scripts/example_slurm.sub b/rosetta/rosetta/projects/maxtext/scripts/example_slurm.sub index e96eaa781..0ca3fd802 100644 --- a/rosetta/rosetta/projects/maxtext/scripts/example_slurm.sub +++ b/rosetta/rosetta/projects/maxtext/scripts/example_slurm.sub @@ -53,11 +53,8 @@ export NCCL_IB_SL=1 # Set XLA Flags export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true - --xla_gpu_enable_async_all_gather=true - --xla_gpu_enable_async_reduce_scatter=true --xla_gpu_enable_triton_gemm=false --xla_gpu_graph_level=0 - --xla_gpu_enable_async_all_reduce=true --xla_gpu_all_reduce_combine_threshold_bytes=1073741824 --xla_gpu_all_gather_combine_threshold_bytes=1073741824 --xla_gpu_reduce_scatter_combine_threshold_bytes=134217728 @@ -65,12 +62,9 @@ export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_pipelined_all_reduce=true --xla_gpu_enable_while_loop_double_buffering=true - --xla_gpu_enable_triton_softmax_fusion=false --xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_enable_reduce_scatter_combine_by_dim=false - --xla_disable_hlo_passes=rematerialization - --xla_gpu_enable_custom_fusions=false - --xla_gpu_enable_address_computation_fusion=false" + --xla_disable_hlo_passes=rematerialization" # Make directories that may not exist mkdir -p $BASE_WORKSPACE_DIR diff --git a/rosetta/rosetta/projects/pax/README.md b/rosetta/rosetta/projects/pax/README.md index 6ac4dc150..d1829b847 
100644 --- a/rosetta/rosetta/projects/pax/README.md +++ b/rosetta/rosetta/projects/pax/README.md @@ -138,10 +138,10 @@ The [GPU Performance document](../../../docs/GPU_performance.md) provides a deta For the the 126M model, we recommend setting `--xla_gpu_all_reduce_combine_threshold_bytes=33554432`, which is different from the value recommended in `paxml/contrib/gpu/scripts_gpu/run_pile_multinode.sh`. To overwrite the default XLA flags set in the script, set the `BASE_XLA_FLAGS` environment variable prior to running `run_pile_multinode` as follows: ``` -BASE_XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_triton_gemm=false - --xla_gpu_enable_async_all_gather=true --xla_gpu_enable_async_reduce_scatter=true - --xla_gpu_enable_triton_softmax_fusion=false --xla_gpu_all_reduce_combine_threshold_bytes=33554432 - --xla_gpu_graph_level=0 --xla_gpu_enable_async_all_reduce=true" bash run_pile_multinode.sh ... +BASE_XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true + --xla_gpu_enable_triton_gemm=false + --xla_gpu_all_reduce_combine_threshold_bytes=33554432 + --xla_gpu_graph_level=0" bash run_pile_multinode.sh ... ``` # Configs From 44b4dfee401a03c0cf3bebfec8700d9b61eb231f Mon Sep 17 00:00:00 2001 From: Md Fahim Faysal Khan Date: Thu, 5 Sep 2024 21:47:28 -0700 Subject: [PATCH 02/11] fix tensorboard events dir path (#1032) Fixed the tensorboard dir path after a recent change in MaxText software: https://github.com/google/maxtext/pull/863 --- .github/workflows/baselines/test_maxtext_metrics.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/baselines/test_maxtext_metrics.py b/.github/workflows/baselines/test_maxtext_metrics.py index bd180ecfe..a130c86c6 100644 --- a/.github/workflows/baselines/test_maxtext_metrics.py +++ b/.github/workflows/baselines/test_maxtext_metrics.py @@ -19,7 +19,7 @@ def test_loss(baseline_filename): baseline_filepath = os.path.join(baselines_dir, baseline_filename) test_config = baseline_filename.split(".")[0] - event_file = os.path.join(results_dir, test_config, "logdir/tensorboard/events*") + event_file = os.path.join(results_dir, test_config, "logdir/tensorboard/logdir/events*") event_file = glob.glob(event_file)[0] with open(baseline_filepath, "r") as baseline_file: end_step = json.load(baseline_file)["end_step"] @@ -31,7 +31,7 @@ def test_loss(baseline_filename): def test_step_time(baseline_filename): baseline_filepath = os.path.join(baselines_dir, baseline_filename) test_config = baseline_filename.split(".")[0] - event_file = os.path.join(results_dir, test_config, "logdir/tensorboard/events*") + event_file = os.path.join(results_dir, test_config, "logdir/tensorboard/logdir/events*") event_file = glob.glob(event_file)[0] with open(baseline_filepath, "r") as baseline_file: step_time_avg_expected = json.load(baseline_file)["step_time_avg"] From f808df5883c23ab0fe9e9310384e0875386a1d8b Mon Sep 17 00:00:00 2001 From: Terry Kong Date: Fri, 6 Sep 2024 15:06:18 -0700 Subject: [PATCH 03/11] Makes jaxlib wheel dirs readable for non-root users (#1023) Example as of 8-28-2024 ``` $ docker run --entrypoint='' --rm -it ghcr.io/nvidia/jax:pax-2024-08-28 ls -lah /opt/jaxlibs total 20K drwxr-xr-x 1 root root 4.0K Aug 28 09:43 . drwxr-xr-x 1 root root 4.0K Aug 28 10:04 .. 
drwx------ 1 root root 4.0K Aug 28 09:43 jax_gpu_pjrt drwx------ 1 root root 4.0K Aug 28 09:43 jax_gpu_plugin drwx------ 1 root root 4.0K Aug 28 09:43 jaxlib ``` Signed-off-by: Terry Kong --- .github/container/build-jax.sh | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/container/build-jax.sh b/.github/container/build-jax.sh index fa4c055b8..8ff65ca99 100755 --- a/.github/container/build-jax.sh +++ b/.github/container/build-jax.sh @@ -316,6 +316,9 @@ pip --disable-pip-version-check install -e ${BUILD_PATH_JAXLIB}/jaxlib -e ${BUIL # jaxlib 0.4.32.dev20240808 /opt/jaxlibs/jaxlib pip list | grep jax +# Ensure directories are readable by all for non-root users +chmod 755 $BUILD_PATH_JAXLIB/* + ## Cleanup pushd $SRC_PATH_JAX From f116054dbed654cbb280764457fa9a78fe003b51 Mon Sep 17 00:00:00 2001 From: Vladislav Date: Mon, 9 Sep 2024 10:24:07 -0700 Subject: [PATCH 04/11] TE multithread build (#1009) --- .github/container/Dockerfile.jax | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/container/Dockerfile.jax b/.github/container/Dockerfile.jax index c85bee347..726656a7a 100644 --- a/.github/container/Dockerfile.jax +++ b/.github/container/Dockerfile.jax @@ -62,6 +62,7 @@ pip install ninja && rm -rf ~/.cache/pip # TransformerEngine now needs JAX at build time git-clone.sh ${URLREF_TRANSFORMER_ENGINE} ${SRC_PATH_TRANSFORMER_ENGINE} pushd ${SRC_PATH_TRANSFORMER_ENGINE} +export NVTE_BUILD_THREADS_PER_JOB=8 python setup.py bdist_wheel && rm -rf build ls "${SRC_PATH_TRANSFORMER_ENGINE}/dist" EOF From 056a3b0db2e34c497f7984e54bb504d9b33efe58 Mon Sep 17 00:00:00 2001 From: Shanbin Ke Date: Wed, 18 Sep 2024 15:56:13 -0700 Subject: [PATCH 05/11] Add an option to test-pax.sh to enable XLA cuDNN flash attention (#1045) Provide an option to run XLA cuDNN flash attention as an alternative to TE cuDNN flash attention. --- .github/container/test-pax.sh | 49 +++++++++++++++++++++-------------- 1 file changed, 30 insertions(+), 19 deletions(-) diff --git a/.github/container/test-pax.sh b/.github/container/test-pax.sh index 2b33f53f7..46ce6ae73 100755 --- a/.github/container/test-pax.sh +++ b/.github/container/test-pax.sh @@ -15,7 +15,8 @@ usage() { echo " -a, --additional-args Additional fiddle args to pass to paxml/main.py" echo " -b, --batch-per-gpu Batch size per GPU, defaults to 4." echo " --dtype Batch size, defaults to bfloat16." - echo " --enable-te If set, will run with env var ENABLE_TE=1." + echo " --enable-te If set, will run with env var ENABLE_TE=1." + echo " --enable-cudnn-fa If set, will use cudnn fa." echo " --enable-dropout If set, will set DROPOUT_PROB to 0.1." echo " --disable-fused-attn Whether disable TE fused attention." echo " --model-type One of 126M, 5B, LLaMA70BProxy. Defaults to 126M" @@ -26,13 +27,13 @@ usage() { echo " --data-parallel Data parallelism to use. Defaults to 1." echo " --fsdp Fully-sharded data parallelism to use. Defaults to 1." echo " --tensor-parallel Tensor parallelism to use. Defaults to 1." - echo " --pipeline-parallel Pipeline parallelism to use. Defaults to 1 for no pipelining." + echo " --pipeline-parallel Pipeline parallelism to use. Defaults to 1 for no pipelining." echo " -n, --nodes Number of nodes." echo " -h, --help Print usage." 
exit $1 } -args=$(getopt -o a:b:s:o:n:h --long additional-args:,batch-per-gpu:,dtype:,enable-te,enable-dropout,disable-fused-attn,model-type:,evaluate,steps:,help,multiprocess,output:,data-parallel:,fsdp:,tensor-parallel:,pipeline-parallel:,nodes: -- "$@") +args=$(getopt -o a:b:s:o:n:h --long additional-args:,batch-per-gpu:,dtype:,enable-te,enable-cudnn-fa,enable-dropout,disable-fused-attn,model-type:,evaluate,steps:,help,multiprocess,output:,data-parallel:,fsdp:,tensor-parallel:,pipeline-parallel:,nodes: -- "$@") if [[ $? -ne 0 ]]; then exit $1 fi @@ -50,6 +51,7 @@ TP=1 PP=1 NODES=1 ENABLE_TE=0 +ENABLE_CUDNN_FA=0 MODEL_TYPE=126M NVTE_FUSED_ATTN=1 DROPOUT=0 @@ -75,6 +77,10 @@ while [ : ]; do ENABLE_TE=1 shift 1 ;; + --enable-cudnn-fa) + ENABLE_CUDNN_FA=1 + shift 1 + ;; --enable-dropout) DROPOUT='0.1' shift 1 @@ -128,7 +134,7 @@ while [ : ]; do ;; --) shift; - break + break ;; *) echo "UNKNOWN OPTION $1" @@ -149,6 +155,7 @@ print_var NGPUS print_var OUTPUT print_var MULTIPROCESS print_var ENABLE_TE +print_var ENABLE_CUDNN_FA print_var NVTE_FUSED_ATTN print_var EVALUATE print_var DROPOUT @@ -196,10 +203,10 @@ if dcn_factor > 1: if dp % dcn_factor == 0: dcn_dp = dcn_factor dp = int(dp / dcn_factor) - elif fsdp % dcn_factor == 0: + elif fsdp % dcn_factor == 0: dcn_fsdp = dcn_factor fsdp = int(fsdp / dcn_factor) - elif pp % dcn_factor == 0: + elif pp % dcn_factor == 0: dcn_pp = dcn_factor pp = int(pp / dcn_factor) @@ -209,12 +216,12 @@ class GPT126MPP(TransformerLmSpmdPipelineAdam): USE_REPEATED_LAYER = False ICI_MESH_SHAPE = [64,1,1] MAX_STEPS = 600000 - + MAX_SEQ_LEN = 2048 VOCAB_SIZE = 50304 PACKED_INPUT = True PERCORE_BATCH_SIZE = 4 - + NUM_LAYERS = 12 NUM_HEADS = 12 MODEL_DIMS = 768 @@ -223,14 +230,14 @@ class GPT126MPP(TransformerLmSpmdPipelineAdam): TRAINABLE_POSITION_EMB = True TRAINABLE_PE_MAX_SEQ_LEN = MAX_SEQ_LEN - + USE_BIAS = True LAYERNORM_EPSILON = 1e-5 ATTEN_LOGIT_CAP = -1.0 INIT_STD = 0.023 SOFTMAX_INIT_STD = 0.023 ACTIVATION_CLS = layers.GELU - + ## optimizer-related ADAM_BETA1 = 0.9 ADAM_BETA2 = 0.95 @@ -255,7 +262,7 @@ class GPT126MPP(TransformerLmSpmdPipelineAdam): ## disable eval to avoid including eval ## in steps/sec calculation EVAL_INTERVAL_STEPS = 100000 - + def task(self): task_p = super().task() task_p = configure_gpt3_task(self, task_p) @@ -263,7 +270,7 @@ class GPT126MPP(TransformerLmSpmdPipelineAdam): task_p.train.num_train_steps = self.MAX_STEPS model_p = task_p.model - + ### compute layernorm reductions in fp32. 
Needed for stable training on GPUs stacked_p = model_p.lm_tpl.stacked_transformer_tpl if stacked_p.cls == layers.PipelinedTransformer: @@ -274,13 +281,13 @@ class GPT126MPP(TransformerLmSpmdPipelineAdam): transformer_layer_p.ln_tpl.reductions_in_fp32 = True transformer_layer_p.tr_fflayer_tpl.ln_tpl.reductions_in_fp32 = True task_p.model.lm_tpl.final_ln_tpl.reductions_in_fp32 = True - + model_p.params_init = WeightInit.Gaussian(self.INIT_STD) softmax_init = WeightInit.Gaussian(self.SOFTMAX_INIT_STD) model_p.lm_tpl.softmax_tpl.params_init = softmax_init - + model_p.apply_eval_sample_weights = True - + ## set input, residual, attention dropout to DROPOUT_PROB, remaining dropout to 0 stacked_p.dropout_prob = 0.0 stacked_p.input_dropout_prob = self.DROPOUT_PROB @@ -316,14 +323,14 @@ class LLaMA70BSyntheticSmall(BaseLLaMA, SyntheticDataset): if pp > 1: @experiment_registry.register class Synthetic126MCI(GPT126MPP, SyntheticDataset): - + ICI_MESH_SHAPE = [pp, dp, fsdp, tp] DCN_MESH_SHAPE = [dcn_pp, dcn_dp, dcn_fsdp, 1] MICROBATCH_SIZE = 2 NUM_STAGES = pp PERCORE_BATCH_SIZE = percore_batch_size FRPOP_DTYPE = dtype - + def task(self): task_p = super().task() task_p.train.always_use_train_for_model_init=False @@ -333,7 +340,7 @@ if pp > 1: else: @experiment_registry.register class Synthetic126MCI(Synthetic126M): - + ICI_MESH_SHAPE = [dp, fsdp, tp] DCN_MESH_SHAPE = [dcn_dp, dcn_fsdp, 1] PERCORE_BATCH_SIZE = percore_batch_size @@ -343,7 +350,7 @@ else: ## disable eval EVAL_INTERVAL_STEPS = 100000 - + def task(self): task_p = super().task() @@ -374,6 +381,10 @@ export ENABLE_TE=$ENABLE_TE export NVTE_FUSED_ATTN=$NVTE_FUSED_ATTN export VOCAB_PATH=${VOCAB_PATH:-gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model} +if [[ ${ENABLE_CUDNN_FA} -ne 0 ]]; then + ADDITIONAL_ARGS="${ADDITIONAL_ARGS} --fdl.USE_CUDNN_FLASH_ATTENTION=True" +fi + if [[ ${MODEL_TYPE} == "126M" ]]; then CONFIG=ci_configs.Synthetic126MCI elif [[ ${MODEL_TYPE} == "5B" ]]; then From 57919e03ea63a4be4252ac850dae0e11e28b2a40 Mon Sep 17 00:00:00 2001 From: Vladislav Date: Tue, 24 Sep 2024 21:50:55 -0700 Subject: [PATCH 06/11] Bump CUDA to 12.6.1 (#1050) --- .github/container/Dockerfile.base | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/container/Dockerfile.base b/.github/container/Dockerfile.base index 023576cb5..c7cafa503 100644 --- a/.github/container/Dockerfile.base +++ b/.github/container/Dockerfile.base @@ -1,5 +1,5 @@ # syntax=docker/dockerfile:1-labs -ARG BASE_IMAGE=nvidia/cuda:12.5.0-devel-ubuntu22.04 +ARG BASE_IMAGE=nvidia/cuda:12.6.1-devel-ubuntu22.04 ARG GIT_USER_NAME="JAX Toolbox" ARG GIT_USER_EMAIL=jax@nvidia.com ARG CLANG_VERSION=17 From 3a2e8c8f71e167f7448594ca0d02da95ace9ba38 Mon Sep 17 00:00:00 2001 From: Vladislav Date: Wed, 25 Sep 2024 00:57:57 -0700 Subject: [PATCH 07/11] Bump clang to 18 (#1060) Forced by this change in JAX build system: https://github.com/jax-ml/jax/pull/23787 --- .github/container/Dockerfile.base | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/container/Dockerfile.base b/.github/container/Dockerfile.base index c7cafa503..9f0851897 100644 --- a/.github/container/Dockerfile.base +++ b/.github/container/Dockerfile.base @@ -2,7 +2,7 @@ ARG BASE_IMAGE=nvidia/cuda:12.6.1-devel-ubuntu22.04 ARG GIT_USER_NAME="JAX Toolbox" ARG GIT_USER_EMAIL=jax@nvidia.com -ARG CLANG_VERSION=17 +ARG CLANG_VERSION=18 ############################################################################### ## Obtain GCP's NCCL TCPx plugin From 
ccededf8c97989f69d9027220885bfb8e4db6f76 Mon Sep 17 00:00:00 2001 From: "Yu-Hang \"Maxin\" Tang" Date: Wed, 25 Sep 2024 09:34:02 -0700 Subject: [PATCH 08/11] Add CI argument for user-defined CUDA base image (#1013) Co-authored-by: Olli Lupton --- .github/workflows/_ci.yaml | 6 ++++++ .github/workflows/ci.yaml | 19 +++++++++++++++++++ 2 files changed, 25 insertions(+) diff --git a/.github/workflows/_ci.yaml b/.github/workflows/_ci.yaml index fc04b83ab..426764323 100644 --- a/.github/workflows/_ci.yaml +++ b/.github/workflows/_ci.yaml @@ -11,6 +11,11 @@ on: description: 'Build date in YYYY-MM-DD format' required: false default: NOT SPECIFIED + CUDA_IMAGE: + type: string + description: CUDA image to use as base, e.g. nvidia/cuda:X.Y.Z-devel-ubuntu22.04 + default: 'latest' + required: false MANIFEST_ARTIFACT_NAME: type: string description: 'Artifact name in current run w/ manifest/patches. Leaving empty uses manifest/patches in current branch' @@ -37,6 +42,7 @@ jobs: uses: ./.github/workflows/_build_base.yaml with: ARCHITECTURE: ${{ inputs.ARCHITECTURE }} + BASE_IMAGE: ${{ inputs.CUDA_IMAGE }} BUILD_DATE: ${{ inputs.BUILD_DATE }} MANIFEST_ARTIFACT_NAME: ${{ inputs.MANIFEST_ARTIFACT_NAME }} secrets: inherit diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 70aeff5ff..0c3c8bdb0 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -28,6 +28,11 @@ on: description: "(used if BUMP_MANIFEST=true) If true: attempt to PR/merge manifest branch" default: false required: false + CUDA_IMAGE: + type: string + description: CUDA image to use as base, e.g. nvidia/cuda:X.Y.Z-devel-ubuntu22.04 + default: 'latest' + required: false SOURCE_OVERRIDES: type: string description: | @@ -60,6 +65,7 @@ jobs: MANIFEST_ARTIFACT_NAME: ${{ steps.manifest-branch.outputs.MANIFEST_ARTIFACT_NAME }} MANIFEST_BRANCH: ${{ steps.manifest-branch.outputs.MANIFEST_BRANCH }} MERGE_BUMPED_MANIFEST: ${{ steps.manifest-branch.outputs.MERGE_BUMBED_MANIFEST }} + CUDA_IMAGE: ${{ steps.cuda-image.outputs.CUDA_IMAGE }} steps: - name: Cancel workflow run if the trigger is a draft PR id: cancel-if-draft @@ -114,6 +120,17 @@ jobs: exit 1 fi + - name: Determine CUDA image to use + id: cuda-image + shell: bash -x -e {0} + run: | + if [[ "${{ github.event_name }}" == "workflow_dispatch" ]]; then + CUDA_IMAGE="${{ inputs.CUDA_IMAGE }}" + else + CUDA_IMAGE="latest" + fi + echo "CUDA_IMAGE=${CUDA_IMAGE}" >> $GITHUB_OUTPUT + bump-manifest: needs: metadata runs-on: ubuntu-22.04 @@ -177,6 +194,7 @@ jobs: with: ARCHITECTURE: amd64 BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }} + CUDA_IMAGE: ${{ needs.metadata.outputs.CUDA_IMAGE }} MANIFEST_ARTIFACT_NAME: ${{ needs.metadata.outputs.MANIFEST_ARTIFACT_NAME }} SOURCE_URLREFS: ${{ needs.bump-manifest.outputs.SOURCE_URLREFS }} secrets: inherit @@ -187,6 +205,7 @@ jobs: with: ARCHITECTURE: arm64 BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }} + CUDA_IMAGE: ${{ needs.metadata.outputs.CUDA_IMAGE }} MANIFEST_ARTIFACT_NAME: ${{ needs.metadata.outputs.MANIFEST_ARTIFACT_NAME }} SOURCE_URLREFS: ${{ needs.bump-manifest.outputs.SOURCE_URLREFS }} secrets: inherit From 1a3febb377a5a0d7ff3344b3d7844ae6f56dacc7 Mon Sep 17 00:00:00 2001 From: Terry Kong Date: Thu, 26 Sep 2024 16:05:35 -0700 Subject: [PATCH 09/11] Model XLA Flags (#1052) Moves XLA flags from model CI into their own files that can be sourced. Each file can be sourced and will print what it sets. 
Some files source other files, which was intentional to avoid introducing symlinks into the repo, since symlinks can sometimes cause platform issues (like on Windows). --------- Signed-off-by: Terry Kong --- .../maxtext/xla_flags/llama2-7b-1N8G.env | 24 ++++++++++++++++++ .../rosetta/projects/pax/xla_flags/common.env | 13 ++++++++++ .../projects/pax/xla_flags/glam-126m64e.env | 3 +++ .../projects/pax/xla_flags/glam-64b64e.env | 3 +++ .../projects/pax/xla_flags/gpt-126m.env | 14 +++++++++++ .../projects/pax/xla_flags/gpt-175b.env | 3 +++ .../rosetta/projects/pax/xla_flags/gpt-5b.env | 3 +++ .../projects/pax/xla_flags/grok-proxy.env | 25 +++++++++++++++++++ .../projects/pax/xla_flags/llama-70b.env | 3 +++ .../projects/pax/xla_flags/llama-7b-lora.env | 4 +++ .../projects/pax/xla_flags/llama-7b.env | 4 +++ rosetta/rosetta/projects/t5x/xla_flags/t5.env | 4 +++ .../vit/xla_flags/vit-base-highgbs.env | 4 +++ .../projects/vit/xla_flags/vit-base.env | 4 +++ 14 files changed, 111 insertions(+) create mode 100644 rosetta/rosetta/projects/maxtext/xla_flags/llama2-7b-1N8G.env create mode 100644 rosetta/rosetta/projects/pax/xla_flags/common.env create mode 100644 rosetta/rosetta/projects/pax/xla_flags/glam-126m64e.env create mode 100644 rosetta/rosetta/projects/pax/xla_flags/glam-64b64e.env create mode 100644 rosetta/rosetta/projects/pax/xla_flags/gpt-126m.env create mode 100644 rosetta/rosetta/projects/pax/xla_flags/gpt-175b.env create mode 100644 rosetta/rosetta/projects/pax/xla_flags/gpt-5b.env create mode 100644 rosetta/rosetta/projects/pax/xla_flags/grok-proxy.env create mode 100644 rosetta/rosetta/projects/pax/xla_flags/llama-70b.env create mode 100644 rosetta/rosetta/projects/pax/xla_flags/llama-7b-lora.env create mode 100644 rosetta/rosetta/projects/pax/xla_flags/llama-7b.env create mode 100644 rosetta/rosetta/projects/t5x/xla_flags/t5.env create mode 100644 rosetta/rosetta/projects/vit/xla_flags/vit-base-highgbs.env create mode 100644 rosetta/rosetta/projects/vit/xla_flags/vit-base.env diff --git a/rosetta/rosetta/projects/maxtext/xla_flags/llama2-7b-1N8G.env b/rosetta/rosetta/projects/maxtext/xla_flags/llama2-7b-1N8G.env new file mode 100644 index 000000000..d999f5b5e --- /dev/null +++ b/rosetta/rosetta/projects/maxtext/xla_flags/llama2-7b-1N8G.env @@ -0,0 +1,24 @@ +set -x +NUM_NODES=1 +NUM_GPUS=8 +THRESHOLD_BYTES=1073741824 +export XLA_FLAGS="\ + --xla_gpu_enable_latency_hiding_scheduler=true \ + --xla_gpu_enable_triton_gemm=false \ + --xla_gpu_graph_level=0 \ + --xla_gpu_enable_highest_priority_async_stream=true \ + --xla_gpu_all_reduce_combine_threshold_bytes=${THRESHOLD_BYTES} \ + --xla_gpu_all_gather_combine_threshold_bytes=$((THRESHOLD_BYTES/(NUM_NODES*NUM_GPUS))) \ + --xla_gpu_reduce_scatter_combine_threshold_bytes=$((THRESHOLD_BYTES/(NUM_NODES*NUM_GPUS*2))) \ + --xla_gpu_enable_pipelined_all_gather=true \ + --xla_gpu_enable_pipelined_reduce_scatter=true \ + --xla_gpu_enable_pipelined_all_reduce=true \ + --xla_gpu_enable_while_loop_double_buffering=true \ + --xla_gpu_enable_triton_softmax_fusion=false \ + --xla_gpu_enable_all_gather_combine_by_dim=false \ + --xla_gpu_enable_reduce_scatter_combine_by_dim=false \ + --xla_disable_hlo_passes=rematerialization \ + " +export XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 +unset NUM_NODES NUM_GPUS THRESHOLD_BYTES +set +x diff --git a/rosetta/rosetta/projects/pax/xla_flags/common.env b/rosetta/rosetta/projects/pax/xla_flags/common.env new file mode 100644 index 000000000..26c819143 --- /dev/null +++ b/rosetta/rosetta/projects/pax/xla_flags/common.env @@ -0,0 +1,13 
@@ +set -x +THRESHOLD_BYTES=51200 +export XLA_FLAGS="\ + --xla_gpu_enable_latency_hiding_scheduler=true \ + --xla_allow_excess_precision \ + --xla_gpu_enable_highest_priority_async_stream=true \ + --xla_gpu_enable_triton_softmax_fusion=false \ + --xla_gpu_all_reduce_combine_threshold_bytes=${THRESHOLD_BYTES} \ + --xla_gpu_graph_level=0 \ + " +export XLA_PYTHON_CLIENT_MEM_FRACTION=0.8 +unset THRESHOLD_BYTES +set +x diff --git a/rosetta/rosetta/projects/pax/xla_flags/glam-126m64e.env b/rosetta/rosetta/projects/pax/xla_flags/glam-126m64e.env new file mode 100644 index 000000000..8b0237170 --- /dev/null +++ b/rosetta/rosetta/projects/pax/xla_flags/glam-126m64e.env @@ -0,0 +1,3 @@ +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +source $SCRIPT_DIR/common.env +unset SCRIPT_DIR diff --git a/rosetta/rosetta/projects/pax/xla_flags/glam-64b64e.env b/rosetta/rosetta/projects/pax/xla_flags/glam-64b64e.env new file mode 100644 index 000000000..8b0237170 --- /dev/null +++ b/rosetta/rosetta/projects/pax/xla_flags/glam-64b64e.env @@ -0,0 +1,3 @@ +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +source $SCRIPT_DIR/common.env +unset SCRIPT_DIR diff --git a/rosetta/rosetta/projects/pax/xla_flags/gpt-126m.env b/rosetta/rosetta/projects/pax/xla_flags/gpt-126m.env new file mode 100644 index 000000000..e5b97b466 --- /dev/null +++ b/rosetta/rosetta/projects/pax/xla_flags/gpt-126m.env @@ -0,0 +1,14 @@ +set -x +THRESHOLD_BYTES=33554432 +export XLA_FLAGS="\ + --xla_gpu_enable_latency_hiding_scheduler=true \ + --xla_allow_excess_precision \ + --xla_gpu_enable_highest_priority_async_stream=true \ + --xla_gpu_enable_triton_softmax_fusion=false \ + --xla_gpu_all_reduce_combine_threshold_bytes=${THRESHOLD_BYTES} \ + --xla_gpu_graph_level=0 \ + --xla_gpu_enable_cudnn_fmha=false \ + " +export XLA_PYTHON_CLIENT_MEM_FRACTION=0.8 +unset THRESHOLD_BYTES +set +x diff --git a/rosetta/rosetta/projects/pax/xla_flags/gpt-175b.env b/rosetta/rosetta/projects/pax/xla_flags/gpt-175b.env new file mode 100644 index 000000000..8b0237170 --- /dev/null +++ b/rosetta/rosetta/projects/pax/xla_flags/gpt-175b.env @@ -0,0 +1,3 @@ +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +source $SCRIPT_DIR/common.env +unset SCRIPT_DIR diff --git a/rosetta/rosetta/projects/pax/xla_flags/gpt-5b.env b/rosetta/rosetta/projects/pax/xla_flags/gpt-5b.env new file mode 100644 index 000000000..8b0237170 --- /dev/null +++ b/rosetta/rosetta/projects/pax/xla_flags/gpt-5b.env @@ -0,0 +1,3 @@ +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +source $SCRIPT_DIR/common.env +unset SCRIPT_DIR diff --git a/rosetta/rosetta/projects/pax/xla_flags/grok-proxy.env b/rosetta/rosetta/projects/pax/xla_flags/grok-proxy.env new file mode 100644 index 000000000..e48b76dcf --- /dev/null +++ b/rosetta/rosetta/projects/pax/xla_flags/grok-proxy.env @@ -0,0 +1,25 @@ +set -x +ALL_REDUCE_THRESHOLD_BYTES=3221225472 +ALL_GATHER_THRESHOLD_BYTES=3221225472 +REDUCE_SCATTER_THRESHOLD_BYTES=402653184 +export XLA_FLAGS="\ + --xla_gpu_enable_latency_hiding_scheduler=true \ + --xla_allow_excess_precision \ + --xla_gpu_enable_highest_priority_async_stream=true \ + --xla_gpu_enable_triton_softmax_fusion=false \ + --xla_gpu_all_reduce_combine_threshold_bytes=${ALL_REDUCE_THRESHOLD_BYTES} \ + --xla_gpu_graph_level=0 \ + --xla_gpu_all_gather_combine_threshold_bytes=${ALL_GATHER_THRESHOLD_BYTES} \ + 
--xla_gpu_reduce_scatter_combine_threshold_bytes=${REDUCE_SCATTER_THRESHOLD_BYTES} \ + --xla_gpu_enable_pipelined_all_gather=true \ + --xla_gpu_enable_pipelined_reduce_scatter=true \ + --xla_gpu_enable_pipelined_all_reduce=true \ + --xla_gpu_enable_while_loop_double_buffering=true \ + --xla_gpu_enable_all_gather_combine_by_dim=false \ + --xla_gpu_enable_reduce_scatter_combine_by_dim=false \ + --xla_disable_hlo_passes=rematerialization \ + --xla_gpu_enable_custom_fusions=true + " +export XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 +unset ALL_REDUCE_THRESHOLD_BYTES ALL_GATHER_THRESHOLD_BYTES REDUCE_SCATTER_THRESHOLD_BYTES +set +x diff --git a/rosetta/rosetta/projects/pax/xla_flags/llama-70b.env b/rosetta/rosetta/projects/pax/xla_flags/llama-70b.env new file mode 100644 index 000000000..8b0237170 --- /dev/null +++ b/rosetta/rosetta/projects/pax/xla_flags/llama-70b.env @@ -0,0 +1,3 @@ +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +source $SCRIPT_DIR/common.env +unset SCRIPT_DIR diff --git a/rosetta/rosetta/projects/pax/xla_flags/llama-7b-lora.env b/rosetta/rosetta/projects/pax/xla_flags/llama-7b-lora.env new file mode 100644 index 000000000..d1568e92c --- /dev/null +++ b/rosetta/rosetta/projects/pax/xla_flags/llama-7b-lora.env @@ -0,0 +1,4 @@ +set -x +echo "$0 uses default XLA_FLAGS='${XLA_FLAGS:-}'" +export XLA_PYTHON_CLIENT_MEM_FRACTION=0.85 +set +x diff --git a/rosetta/rosetta/projects/pax/xla_flags/llama-7b.env b/rosetta/rosetta/projects/pax/xla_flags/llama-7b.env new file mode 100644 index 000000000..bd4ae50d5 --- /dev/null +++ b/rosetta/rosetta/projects/pax/xla_flags/llama-7b.env @@ -0,0 +1,4 @@ +set -x +echo "$0 uses default XLA_FLAGS='${XLA_FLAGS:-}'" +export XLA_PYTHON_CLIENT_MEM_FRACTION=0.8 +set +x diff --git a/rosetta/rosetta/projects/t5x/xla_flags/t5.env b/rosetta/rosetta/projects/t5x/xla_flags/t5.env new file mode 100644 index 000000000..bd4ae50d5 --- /dev/null +++ b/rosetta/rosetta/projects/t5x/xla_flags/t5.env @@ -0,0 +1,4 @@ +set -x +echo "$0 uses default XLA_FLAGS='${XLA_FLAGS:-}'" +export XLA_PYTHON_CLIENT_MEM_FRACTION=0.8 +set +x diff --git a/rosetta/rosetta/projects/vit/xla_flags/vit-base-highgbs.env b/rosetta/rosetta/projects/vit/xla_flags/vit-base-highgbs.env new file mode 100644 index 000000000..45140ed88 --- /dev/null +++ b/rosetta/rosetta/projects/vit/xla_flags/vit-base-highgbs.env @@ -0,0 +1,4 @@ +set -x +echo "$0 uses default XLA_FLAGS='${XLA_FLAGS:-}'" +export XLA_PYTHON_CLIENT_MEM_FRACTION=0.75 +set +x diff --git a/rosetta/rosetta/projects/vit/xla_flags/vit-base.env b/rosetta/rosetta/projects/vit/xla_flags/vit-base.env new file mode 100644 index 000000000..882c9e9e8 --- /dev/null +++ b/rosetta/rosetta/projects/vit/xla_flags/vit-base.env @@ -0,0 +1,4 @@ +set -x +echo "$0 uses default XLA_FLAGS='${XLA_FLAGS:-}'" +export XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 +set +x From 3638a661e0aec7ff42e6ebac6dfeb7aa458de770 Mon Sep 17 00:00:00 2001 From: Vladislav Date: Fri, 27 Sep 2024 10:45:23 -0700 Subject: [PATCH 10/11] Add pathwaysutils for MaxText to manifest file (#1065) The latest MaxText uses `pathwayutils`, which is added as a dependency. Need to add it to our manifest.yaml file to resolve reference issue during final installation. 
--- .github/container/manifest.yaml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.github/container/manifest.yaml b/.github/container/manifest.yaml index 60ef1a001..e9d30a3bc 100644 --- a/.github/container/manifest.yaml +++ b/.github/container/manifest.yaml @@ -177,3 +177,8 @@ orbax-checkpoint: tracking_ref: main latest_verified_commit: 16c2d409e365576284dbaf190ac002b24c1f927f mode: pip-vcs +pathwaysutils: + url: https://github.com/google/pathways-utils.git + tracking_ref: main + latest_verified_commit: 359776d454940ffaa337c36d1df16308d44a95a9 + mode: pip-vcs From ef3fd66e7d3421659d21aa51907fa50e209a34a6 Mon Sep 17 00:00:00 2001 From: Olli Lupton Date: Tue, 1 Oct 2024 22:34:19 +0200 Subject: [PATCH 11/11] nsys-jax post-processing: treat host-device copies as 1-device collectives (#1073) This adds logic to treat `dynamic[-update]-slice` operations that have a source/destination operand in the host memory space as being communication operations, labelling them as single-device "collectives". The goal is to improve support for analysing profiles of execution including offloading to host memory. Also fix using nsys 2024.6 by applying the same patch as 2024.5 that adds the thread ID. --- .github/container/install-nsight.sh | 11 +-- .../python/jax_nsys/jax_nsys/analysis.py | 54 +++++++++-- .../python/jax_nsys/jax_nsys/data_loaders.py | 19 +++- .../python/jax_nsys/jax_nsys/protobuf.py | 90 +++++++++++++++---- 4 files changed, 138 insertions(+), 36 deletions(-) diff --git a/.github/container/install-nsight.sh b/.github/container/install-nsight.sh index 73aee4163..f3e4e0715 100755 --- a/.github/container/install-nsight.sh +++ b/.github/container/install-nsight.sh @@ -17,11 +17,12 @@ apt-get clean rm -rf /var/lib/apt/lists/* -NSYS202451=/opt/nvidia/nsight-systems-cli/2024.5.1 -if [[ -d "${NSYS202451}" ]]; then - # * can match at least sbsa-armv8 and x86 - (cd ${NSYS202451}/target-linux-*/python/packages && git apply < /opt/nvidia/nsys-2024.5-tid-export.patch) -fi +for NSYS in /opt/nvidia/nsight-systems-cli/2024.5.1 /opt/nvidia/nsight-systems-cli/2024.6.1; do + if [[ -d "${NSYS}" ]]; then + # * can match at least sbsa-armv8 and x86 + (cd ${NSYS}/target-linux-*/python/packages && git apply < /opt/nvidia/nsys-2024.5-tid-export.patch) + fi +done # Install extra dependencies needed for `nsys recipe ...` commands. These are # used by the nsys-jax wrapper script. 
diff --git a/.github/container/jax_nsys/python/jax_nsys/jax_nsys/analysis.py b/.github/container/jax_nsys/python/jax_nsys/jax_nsys/analysis.py index 9e3aaee4f..4e72a33fb 100644 --- a/.github/container/jax_nsys/python/jax_nsys/jax_nsys/analysis.py +++ b/.github/container/jax_nsys/python/jax_nsys/jax_nsys/analysis.py @@ -6,7 +6,7 @@ import pathlib from typing import Any -from .protobuf import HloProto, xla_module_metadata +from .protobuf import HloProto, _host_memory_space, xla_module_metadata from .utils import make_child_mask, ProfilerData pd.options.mode.copy_on_write = True @@ -38,6 +38,11 @@ def align_profiler_data_timestamps( # Determine which collective size will be used for the alignment num_profiled_devices = len(comm_df.index.get_level_values("Device").unique()) max_collective_size = comm_df["CollectiveSize"].max() + if max_collective_size == 1: + print( + f"WARNING: cannot align {num_profiled_devices} devices because max collective size is 1" + ) + return frames, {} assert ( num_profiled_devices == max_collective_size ), f"Aligning {num_profiled_devices} using collectives of size {max_collective_size} is not implemented" @@ -193,13 +198,51 @@ def _get_message_size( "all-to-all", "collective-broadcast", "collective-permute-start", + "dynamic-slice", + "dynamic-update-slice", "reduce-scatter", } ), f"{instruction}: message size calculation for {comm_inst.opcode} has not yet been validated" + + def _byte_size(inst) -> int: + size_bits = math.prod( + inst.shape.dimensions, + start=element_type_width(inst.shape.element_type), + ) + size_bytes, rem = divmod(size_bits, 8) + assert rem == 0 + return size_bytes + if comm_inst.opcode == "collective-permute-start": # See https://openxla.org/xla/operation_semantics#collectivepermute, which # generates pair-wise send+recv between devices collective_size = 2 + elif comm_inst.opcode in {"dynamic-slice", "dynamic-update-slice"}: + # Label host-device transfers orchestrated by dynamic[-update]-slice as single + # device collectives. 
+ collective_size = 1 + if comm_inst.opcode == "dynamic-update-slice": + # For dynamic-update-slice the second operand is the one being copied + _, src_inst = module_proto.find_instruction_by_id(comm_inst.operand_ids[1]) + transfer_size = _byte_size(src_inst.proto()) + else: + # For dynamic-slice the return type size is the transfer size + assert comm_inst.opcode == "dynamic-slice" + _, src_inst = module_proto.find_instruction_by_id(comm_inst.operand_ids[0]) + transfer_size = _byte_size(comm_inst) + dest_on_host = _host_memory_space(comm_inst) + src_on_host = _host_memory_space(src_inst.proto()) + assert src_on_host != dest_on_host, ( + 'dynamic[-update]-slice is only considered is only "communication" if it ' + "represents a host-device transfer" + ) + return ( + transfer_size, + "device-to-host" if dest_on_host else "host-to-device", + 1, # collective size + 1.0, # bw_correction + 1.0, # bus_correction + ) else: # replica_groups is something like {{0,1},{4,5},{2,3},{6,7}}, if there are 8 # devices that are doing pair-wise collectives @@ -220,17 +263,12 @@ def _get_message_size( total_msg_size = 0 for operand_id in comm_inst.operand_ids: _, operand = module_proto.find_instruction_by_id(operand_id) - msg_size_bits = math.prod( - operand.proto().shape.dimensions, - start=element_type_width(operand.proto().shape.element_type), - ) + msg_size_bytes = _byte_size(operand.proto()) if comm_inst.opcode == "reduce-scatter": # NCCL's convention is that the message size of a reduce-scatter is the size of output buffer: # https://github.com/NVIDIA/nccl/blob/ab2b89c4c339bd7f816fbc114a4b05d386b66290/src/collectives.cc#L122 - msg_size_bits, rem = divmod(msg_size_bits, collective_size) + msg_size_bytes, rem = divmod(msg_size_bytes, collective_size) assert rem == 0 - msg_size_bytes, rem = divmod(msg_size_bits, 8) - assert rem == 0 total_msg_size += msg_size_bytes collective = comm_inst.opcode.removesuffix("-start") diff --git a/.github/container/jax_nsys/python/jax_nsys/jax_nsys/data_loaders.py b/.github/container/jax_nsys/python/jax_nsys/jax_nsys/data_loaders.py index 6c25cb2ee..d6e4464bd 100644 --- a/.github/container/jax_nsys/python/jax_nsys/jax_nsys/data_loaders.py +++ b/.github/container/jax_nsys/python/jax_nsys/jax_nsys/data_loaders.py @@ -103,6 +103,9 @@ def is_communication(row): return _calculate_overlap(thunk_df) +compile_prefix = "XlaCompile:#module=" + + def _load_nvtx_gpu_proj_trace_single( prefix: pathlib.Path, file: pathlib.Path, @@ -305,10 +308,21 @@ def _load_nvtx_gpu_proj_trace_single( unique_pid_tid_pairs = module_df.loc[:, ("PID", "TID")].drop_duplicates() if len(unique_pid_tid_pairs) == 1: main_pid_tid_candidates.add(tuple(unique_pid_tid_pairs.iloc[0])) + # If the profile only includes N>1 modules, we may still be able to identify the + # main thread as the one responsible for XlaCompile ranges projected onto the GPU + # timeline + compile_ranges = df.loc[~all_thunks, "Name"].str.startswith( + tsl_prefix + compile_prefix + ) + compile_range_ids = compile_ranges[compile_ranges].index + unique_pid_tid_pairs = df.loc[compile_range_ids, ("PID", "TID")].drop_duplicates() + if len(unique_pid_tid_pairs) == 1: + main_pid_tid_candidates.add(tuple(unique_pid_tid_pairs.iloc[0])) assert len(main_pid_tid_candidates) < 2 if len(main_pid_tid_candidates) == 1: # Possibly not correct if len(device_by_pid_tid) > 1 assert len(device_by_pid_tid) > 0 + # Associate the main thread with the 0th device in device_by_pid_tid main_thread_df = device_by_pid_tid.iloc[:1] main_thread_df.index = 
pd.MultiIndex.from_tuples( main_pid_tid_candidates, names=["PID", "TID"] @@ -425,16 +439,13 @@ def _load_nvtx_gpu_proj_trace( return output -compile_prefix = "TSL:XlaCompile:#module=" - - def _splice_parallel_ranges(compile_df: pd.DataFrame) -> pd.DataFrame: # When parallel compilation is enabled, we end up with worker threads that # emit NVTX ranges but which are not accounted for in the RangeStack tree. # Splice these in under the relevant XlaCompile ranges in the RangeStack tree and # drop everything else. retain_mask = pd.Series(False, index=compile_df.index) - compile_mask = compile_df["Name"].str.startswith(compile_prefix) + compile_mask = compile_df["Name"].str.startswith("TSL:" + compile_prefix) for compile_range in compile_df[compile_mask].itertuples(): # Identify the slice of `compile_df` that overlaps in time with this XlaCompile # range diff --git a/.github/container/jax_nsys/python/jax_nsys/jax_nsys/protobuf.py b/.github/container/jax_nsys/python/jax_nsys/jax_nsys/protobuf.py index ef74165fd..4feae6038 100644 --- a/.github/container/jax_nsys/python/jax_nsys/jax_nsys/protobuf.py +++ b/.github/container/jax_nsys/python/jax_nsys/jax_nsys/protobuf.py @@ -1,10 +1,13 @@ -from collections import defaultdict import functools import lzma import pathlib import typing +def _host_memory_space(inst): + return inst.shape.layout.memory_space == 5 + + class StackFrame(typing.NamedTuple): column: int file: str @@ -25,6 +28,35 @@ def __init__(self, wrapped_hlo_proto, proto): # proto representing the actual collective, which will be different if the # async launch is handled by an async-start op # TODO: can any of copy-start, custom-call, recv, send represent communication? + # This also aims to identify, and (for now) flag as communication, kernels that + # implement device-to-host and host-to-device copies for memory offloading. + # For example, a device-to-host offload might look like + # computation { + # ... + # ROOT r1 = bf16[2,8,128,2048]{3,2,1,0:S(5)} dynamic-update-slice(...) + # } + # async_computation { + # ... + # ROOT r2 = bf16[2,8,128,2048]{3,2,1,0:S(5)} fusion(...), calls=computation + # } + # start = (...) async-start(...), calls=async_computation + # where the :S(5) annotation shows that a buffer is in host memory. + # A host-to-device load might look like + # computation { + # param_0 = bf16[2,8,128,2048]{3,2,1,0:S(5)} parameter(0) + # ... + # ROOT r1 = bf16[2,8,128,2048]{3,2,1,0} dynamic-slice(param_0, ...) + # } + # async_computation { + # param_0 = bf16[2,8,128,2048]{3,2,1,0:S(5)} parameter(0) + # ... + # ROOT r2 = bf16[2,8,128,2048]{3,2,1,0} fusion(param_0, ...), calls=computation + # } + # start = (...) async-start(...), calls=async_computation + # where the :S(5) memory space annotation is in a parameter instead of in the + # return value. + # For now, handling host-device kernels as single-device "collective" + # communication should be sufficient. 
self._comm_proto = None comm_opcodes = { "all-gather", @@ -39,25 +71,50 @@ def __init__(self, wrapped_hlo_proto, proto): "all-reduce-start", "collective-permute-start", } + + def _is_offloading_instruction(inst): + host_dest = _host_memory_space(inst) + + def _host_operand(i): + _, op = wrapped_hlo_proto.find_instruction_by_id(inst.operand_ids[i]) + return _host_memory_space(op.proto()) + + if inst.opcode == "dynamic-slice" and host_dest != _host_operand(0): + return True + elif ( + inst.opcode == "dynamic-update-slice" + and host_dest == _host_operand(0) + and host_dest != _host_operand(1) + ): + return True + return False + if self._proto.opcode in comm_opcodes | comm_start_opcodes: self._comm_proto = self._proto - elif self._proto.opcode == "async-start": + elif self._proto.opcode in {"async-start", "fusion"}: + # fusion example: + # computation { + # param_0 = f32[...]{...:S(5)} parameter(0) + # ... + # ROOT dus = f32[...]{...:S(5)} dynamic-update-slice(param_0, ...) + # } + # inst = f32[256,128,128]{2,1,0:S(5)} fusion(...), calls=computation # This might be thinly wrapping an opcode in `comm_opcodes` - other_opcodes = defaultdict(int) - for called_id in self._proto.called_computation_ids: - for called_inst in wrapped_hlo_proto.find_computation( - called_id - ).instructions: - if called_inst.opcode in comm_opcodes: + def _visit_computation(computation_id): + computation = wrapped_hlo_proto.find_computation(computation_id) + for called_inst in computation.instructions: + for called_id in called_inst.called_computation_ids: + _visit_computation(called_id) + if called_inst.opcode in comm_opcodes or _is_offloading_instruction( + called_inst + ): assert ( self._comm_proto is None ), f"Found {called_inst.opcode} child having already found {self._comm_proto.opcode}" self._comm_proto = called_inst - else: - other_opcodes[called_inst.opcode] += 1 - assert ( - other_opcodes.keys() == {"parameter"} - ), f"async-start op {self._proto.name} wrapped too many opcode types ({dict(other_opcodes)}) in addition to {self._comm_proto}" + + for called_id in self._proto.called_computation_ids: + _visit_computation(called_id) def communication_proto(self): return self._comm_proto @@ -68,12 +125,7 @@ def is_communication(self) -> bool: a little more complicated than you might hope, because async communications are not handled uniformly. """ - if self._comm_proto is None: - return False - assert ( - self._comm_proto.channel_id != 0 - ), f"Got channel_id={self._comm_proto.channel_id} for {self._comm_proto.name}" - return True + return self._comm_proto is not None def proto(self): """