4 changes: 2 additions & 2 deletions .github/container/Dockerfile.base
@@ -1,8 +1,8 @@
# syntax=docker/dockerfile:1-labs
ARG BASE_IMAGE=nvidia/cuda:12.5.0-devel-ubuntu22.04
ARG BASE_IMAGE=nvidia/cuda:12.6.1-devel-ubuntu22.04
ARG GIT_USER_NAME="JAX Toolbox"
ARG [email protected]
ARG CLANG_VERSION=17
ARG CLANG_VERSION=18

###############################################################################
## Obtain GCP's NCCL TCPx plugin
1 change: 1 addition & 0 deletions .github/container/Dockerfile.jax
@@ -62,6 +62,7 @@ pip install ninja && rm -rf ~/.cache/pip
# TransformerEngine now needs JAX at build time
git-clone.sh ${URLREF_TRANSFORMER_ENGINE} ${SRC_PATH_TRANSFORMER_ENGINE}
pushd ${SRC_PATH_TRANSFORMER_ENGINE}
export NVTE_BUILD_THREADS_PER_JOB=8
python setup.py bdist_wheel && rm -rf build
ls "${SRC_PATH_TRANSFORMER_ENGINE}/dist"
EOF
3 changes: 3 additions & 0 deletions .github/container/build-jax.sh
@@ -316,6 +316,9 @@ pip --disable-pip-version-check install -e ${BUILD_PATH_JAXLIB}/jaxlib -e ${BUIL
# jaxlib 0.4.32.dev20240808 /opt/jaxlibs/jaxlib
pip list | grep jax

# Ensure directories are readable by all for non-root users
chmod 755 $BUILD_PATH_JAXLIB/*

## Cleanup

pushd $SRC_PATH_JAX
5 changes: 2 additions & 3 deletions .github/container/test-maxtext.sh
@@ -223,7 +223,7 @@ export NVTE_FUSED_ATTN=${ENABLE_FUSED_ATTN}
export XLA_PYTHON_CLIENT_MEM_FRACTION=${MEM_FRACTION}
export CUDA_DEVICE_MAX_CONNECTIONS=1

export BASE_XLA_FLAGS=${BASE_XLA_FLAGS:---xla_gpu_enable_latency_hiding_scheduler=true
export BASE_XLA_FLAGS=${BASE_XLA_FLAGS:---xla_gpu_enable_latency_hiding_scheduler=true
--xla_gpu_enable_triton_gemm=false
--xla_gpu_graph_level=0
--xla_gpu_all_reduce_combine_threshold_bytes=1073741824
@@ -232,8 +232,7 @@ export BASE_XLA_FLAGS=${BASE_XLA_FLAGS:---xla_gpu_enable_latency_hiding_schedule
--xla_gpu_enable_pipelined_all_gather=true
--xla_gpu_enable_pipelined_reduce_scatter=true
--xla_gpu_enable_pipelined_all_reduce=true
--xla_gpu_enable_while_loop_double_buffering=true
--xla_gpu_enable_triton_softmax_fusion=false
--xla_gpu_enable_while_loop_double_buffering=true
--xla_gpu_enable_all_gather_combine_by_dim=false
--xla_gpu_enable_reduce_scatter_combine_by_dim=false
--xla_disable_hlo_passes=rematerialization}
49 changes: 30 additions & 19 deletions .github/container/test-pax.sh
@@ -15,7 +15,8 @@ usage() {
echo " -a, --additional-args Additional fiddle args to pass to paxml/main.py"
echo " -b, --batch-per-gpu Batch size per GPU, defaults to 4."
echo " --dtype Batch size, defaults to bfloat16."
echo " --enable-te If set, will run with env var ENABLE_TE=1."
echo " --enable-te If set, will run with env var ENABLE_TE=1."
echo " --enable-cudnn-fa If set, will use cudnn fa."
echo " --enable-dropout If set, will set DROPOUT_PROB to 0.1."
echo " --disable-fused-attn Whether disable TE fused attention."
echo " --model-type One of 126M, 5B, LLaMA70BProxy. Defaults to 126M"
@@ -26,13 +27,13 @@ usage() {
echo " --data-parallel Data parallelism to use. Defaults to 1."
echo " --fsdp Fully-sharded data parallelism to use. Defaults to 1."
echo " --tensor-parallel Tensor parallelism to use. Defaults to 1."
echo " --pipeline-parallel Pipeline parallelism to use. Defaults to 1 for no pipelining."
echo " --pipeline-parallel Pipeline parallelism to use. Defaults to 1 for no pipelining."
echo " -n, --nodes Number of nodes."
echo " -h, --help Print usage."
exit $1
}

args=$(getopt -o a:b:s:o:n:h --long additional-args:,batch-per-gpu:,dtype:,enable-te,enable-dropout,disable-fused-attn,model-type:,evaluate,steps:,help,multiprocess,output:,data-parallel:,fsdp:,tensor-parallel:,pipeline-parallel:,nodes: -- "$@")
args=$(getopt -o a:b:s:o:n:h --long additional-args:,batch-per-gpu:,dtype:,enable-te,enable-cudnn-fa,enable-dropout,disable-fused-attn,model-type:,evaluate,steps:,help,multiprocess,output:,data-parallel:,fsdp:,tensor-parallel:,pipeline-parallel:,nodes: -- "$@")
if [[ $? -ne 0 ]]; then
exit $1
fi
@@ -50,6 +51,7 @@ TP=1
PP=1
NODES=1
ENABLE_TE=0
ENABLE_CUDNN_FA=0
MODEL_TYPE=126M
NVTE_FUSED_ATTN=1
DROPOUT=0
@@ -75,6 +77,10 @@ while [ : ]; do
ENABLE_TE=1
shift 1
;;
--enable-cudnn-fa)
ENABLE_CUDNN_FA=1
shift 1
;;
--enable-dropout)
DROPOUT='0.1'
shift 1
@@ -128,7 +134,7 @@ while [ : ]; do
;;
--)
shift;
break
break
;;
*)
echo "UNKNOWN OPTION $1"
@@ -149,6 +155,7 @@ print_var NGPUS
print_var OUTPUT
print_var MULTIPROCESS
print_var ENABLE_TE
print_var ENABLE_CUDNN_FA
print_var NVTE_FUSED_ATTN
print_var EVALUATE
print_var DROPOUT
@@ -196,10 +203,10 @@ if dcn_factor > 1:
if dp % dcn_factor == 0:
dcn_dp = dcn_factor
dp = int(dp / dcn_factor)
elif fsdp % dcn_factor == 0:
elif fsdp % dcn_factor == 0:
dcn_fsdp = dcn_factor
fsdp = int(fsdp / dcn_factor)
elif pp % dcn_factor == 0:
elif pp % dcn_factor == 0:
dcn_pp = dcn_factor
pp = int(pp / dcn_factor)

@@ -209,12 +216,12 @@ class GPT126MPP(TransformerLmSpmdPipelineAdam):
USE_REPEATED_LAYER = False
ICI_MESH_SHAPE = [64,1,1]
MAX_STEPS = 600000

MAX_SEQ_LEN = 2048
VOCAB_SIZE = 50304
PACKED_INPUT = True
PERCORE_BATCH_SIZE = 4

NUM_LAYERS = 12
NUM_HEADS = 12
MODEL_DIMS = 768
@@ -223,14 +230,14 @@ class GPT126MPP(TransformerLmSpmdPipelineAdam):

TRAINABLE_POSITION_EMB = True
TRAINABLE_PE_MAX_SEQ_LEN = MAX_SEQ_LEN

USE_BIAS = True
LAYERNORM_EPSILON = 1e-5
ATTEN_LOGIT_CAP = -1.0
INIT_STD = 0.023
SOFTMAX_INIT_STD = 0.023
ACTIVATION_CLS = layers.GELU

## optimizer-related
ADAM_BETA1 = 0.9
ADAM_BETA2 = 0.95
@@ -255,15 +262,15 @@ class GPT126MPP(TransformerLmSpmdPipelineAdam):
## disable eval to avoid including eval
## in steps/sec calculation
EVAL_INTERVAL_STEPS = 100000

def task(self):
task_p = super().task()
task_p = configure_gpt3_task(self, task_p)

task_p.train.num_train_steps = self.MAX_STEPS

model_p = task_p.model

### compute layernorm reductions in fp32. Needed for stable training on GPUs
stacked_p = model_p.lm_tpl.stacked_transformer_tpl
if stacked_p.cls == layers.PipelinedTransformer:
@@ -274,13 +281,13 @@ class GPT126MPP(TransformerLmSpmdPipelineAdam):
transformer_layer_p.ln_tpl.reductions_in_fp32 = True
transformer_layer_p.tr_fflayer_tpl.ln_tpl.reductions_in_fp32 = True
task_p.model.lm_tpl.final_ln_tpl.reductions_in_fp32 = True

model_p.params_init = WeightInit.Gaussian(self.INIT_STD)
softmax_init = WeightInit.Gaussian(self.SOFTMAX_INIT_STD)
model_p.lm_tpl.softmax_tpl.params_init = softmax_init

model_p.apply_eval_sample_weights = True

## set input, residual, attention dropout to DROPOUT_PROB, remaining dropout to 0
stacked_p.dropout_prob = 0.0
stacked_p.input_dropout_prob = self.DROPOUT_PROB
@@ -316,14 +323,14 @@ class LLaMA70BSyntheticSmall(BaseLLaMA, SyntheticDataset):
if pp > 1:
@experiment_registry.register
class Synthetic126MCI(GPT126MPP, SyntheticDataset):

ICI_MESH_SHAPE = [pp, dp, fsdp, tp]
DCN_MESH_SHAPE = [dcn_pp, dcn_dp, dcn_fsdp, 1]
MICROBATCH_SIZE = 2
NUM_STAGES = pp
PERCORE_BATCH_SIZE = percore_batch_size
FPROP_DTYPE = dtype

def task(self):
task_p = super().task()
task_p.train.always_use_train_for_model_init=False
@@ -333,7 +340,7 @@ if pp > 1:
else:
@experiment_registry.register
class Synthetic126MCI(Synthetic126M):

ICI_MESH_SHAPE = [dp, fsdp, tp]
DCN_MESH_SHAPE = [dcn_dp, dcn_fsdp, 1]
PERCORE_BATCH_SIZE = percore_batch_size
@@ -343,7 +350,7 @@ else:

## disable eval
EVAL_INTERVAL_STEPS = 100000

def task(self):
task_p = super().task()

@@ -374,6 +381,10 @@ export ENABLE_TE=$ENABLE_TE
export NVTE_FUSED_ATTN=$NVTE_FUSED_ATTN
export VOCAB_PATH=${VOCAB_PATH:-gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model}

if [[ ${ENABLE_CUDNN_FA} -ne 0 ]]; then
ADDITIONAL_ARGS="${ADDITIONAL_ARGS} --fdl.USE_CUDNN_FLASH_ATTENTION=True"
fi

if [[ ${MODEL_TYPE} == "126M" ]]; then
CONFIG=ci_configs.Synthetic126MCI
elif [[ ${MODEL_TYPE} == "5B" ]]; then
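For reference, a hypothetical invocation exercising the new `--enable-cudnn-fa` switch (which appends `--fdl.USE_CUDNN_FLASH_ATTENTION=True` to the fiddle args) might look like the sketch below; the model type, step count, and output path are illustrative, not taken from the CI configuration:

```bash
# Illustrative only: values are not taken from the CI configuration.
bash .github/container/test-pax.sh \
    --model-type 126M \
    --enable-te \
    --enable-cudnn-fa \
    --steps 100 \
    --output /tmp/pax-cudnn-fa-run
```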
6 changes: 6 additions & 0 deletions .github/workflows/_ci.yaml
@@ -11,6 +11,11 @@ on:
description: 'Build date in YYYY-MM-DD format'
required: false
default: NOT SPECIFIED
CUDA_IMAGE:
type: string
description: CUDA image to use as base, e.g. nvidia/cuda:X.Y.Z-devel-ubuntu22.04
default: 'latest'
required: false
MANIFEST_ARTIFACT_NAME:
type: string
description: 'Artifact name in current run w/ manifest/patches. Leaving empty uses manifest/patches in current branch'
@@ -37,6 +42,7 @@ jobs:
uses: ./.github/workflows/_build_base.yaml
with:
ARCHITECTURE: ${{ inputs.ARCHITECTURE }}
BASE_IMAGE: ${{ inputs.CUDA_IMAGE }}
BUILD_DATE: ${{ inputs.BUILD_DATE }}
MANIFEST_ARTIFACT_NAME: ${{ inputs.MANIFEST_ARTIFACT_NAME }}
secrets: inherit
4 changes: 2 additions & 2 deletions .github/workflows/baselines/test_maxtext_metrics.py
@@ -19,7 +19,7 @@
def test_loss(baseline_filename):
baseline_filepath = os.path.join(baselines_dir, baseline_filename)
test_config = baseline_filename.split(".")[0]
event_file = os.path.join(results_dir, test_config, "logdir/tensorboard/events*")
event_file = os.path.join(results_dir, test_config, "logdir/tensorboard/logdir/events*")
event_file = glob.glob(event_file)[0]
with open(baseline_filepath, "r") as baseline_file:
end_step = json.load(baseline_file)["end_step"]
@@ -31,7 +31,7 @@ def test_loss(baseline_filename):
def test_step_time(baseline_filename):
baseline_filepath = os.path.join(baselines_dir, baseline_filename)
test_config = baseline_filename.split(".")[0]
event_file = os.path.join(results_dir, test_config, "logdir/tensorboard/events*")
event_file = os.path.join(results_dir, test_config, "logdir/tensorboard/logdir/events*")
event_file = glob.glob(event_file)[0]
with open(baseline_filepath, "r") as baseline_file:
step_time_avg_expected = json.load(baseline_file)["step_time_avg"]
19 changes: 19 additions & 0 deletions .github/workflows/ci.yaml
@@ -28,6 +28,11 @@ on:
description: "(used if BUMP_MANIFEST=true) If true: attempt to PR/merge manifest branch"
default: false
required: false
CUDA_IMAGE:
type: string
description: CUDA image to use as base, e.g. nvidia/cuda:X.Y.Z-devel-ubuntu22.04
default: 'latest'
required: false
SOURCE_OVERRIDES:
type: string
description: |
@@ -60,6 +65,7 @@ jobs:
MANIFEST_ARTIFACT_NAME: ${{ steps.manifest-branch.outputs.MANIFEST_ARTIFACT_NAME }}
MANIFEST_BRANCH: ${{ steps.manifest-branch.outputs.MANIFEST_BRANCH }}
MERGE_BUMPED_MANIFEST: ${{ steps.manifest-branch.outputs.MERGE_BUMBED_MANIFEST }}
CUDA_IMAGE: ${{ steps.cuda-image.outputs.CUDA_IMAGE }}
steps:
- name: Cancel workflow run if the trigger is a draft PR
id: cancel-if-draft
@@ -114,6 +120,17 @@ jobs:
exit 1
fi

- name: Determine CUDA image to use
id: cuda-image
shell: bash -x -e {0}
run: |
if [[ "${{ github.event_name }}" == "workflow_dispatch" ]]; then
CUDA_IMAGE="${{ inputs.CUDA_IMAGE }}"
else
CUDA_IMAGE="latest"
fi
echo "CUDA_IMAGE=${CUDA_IMAGE}" >> $GITHUB_OUTPUT

bump-manifest:
needs: metadata
runs-on: ubuntu-22.04
@@ -177,6 +194,7 @@ jobs:
with:
ARCHITECTURE: amd64
BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }}
CUDA_IMAGE: ${{ needs.metadata.outputs.CUDA_IMAGE }}
MANIFEST_ARTIFACT_NAME: ${{ needs.metadata.outputs.MANIFEST_ARTIFACT_NAME }}
SOURCE_URLREFS: ${{ needs.bump-manifest.outputs.SOURCE_URLREFS }}
secrets: inherit
@@ -187,6 +205,7 @@ jobs:
with:
ARCHITECTURE: arm64
BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }}
CUDA_IMAGE: ${{ needs.metadata.outputs.CUDA_IMAGE }}
MANIFEST_ARTIFACT_NAME: ${{ needs.metadata.outputs.MANIFEST_ARTIFACT_NAME }}
SOURCE_URLREFS: ${{ needs.bump-manifest.outputs.SOURCE_URLREFS }}
secrets: inherit
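Since the metadata job only honors `CUDA_IMAGE` on `workflow_dispatch`, a manual run could supply it explicitly. A hedged sketch using the GitHub CLI follows; the ref and image tag are assumptions, not values mandated by the workflow:

```bash
# Hypothetical manual dispatch; adjust the ref and image tag as needed.
gh workflow run ci.yaml \
  --ref main \
  -f CUDA_IMAGE=nvidia/cuda:12.6.1-devel-ubuntu22.04
```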
2 changes: 2 additions & 0 deletions README.md
@@ -300,6 +300,8 @@ The [JAX image](https://github.com/NVIDIA/JAX-Toolbox/pkgs/container/jax) is emb

There are various other XLA flags users can set to improve performance. For a detailed explanation of these flags, please refer to the [GPU performance](./rosetta/docs/GPU_performance.md) doc. XLA flags can be tuned per workflow. For example, each script in [contrib/gpu/scripts_gpu](https://github.com/google/paxml/tree/main/paxml/contrib/gpu/scripts_gpu) sets its own [XLA flags](https://github.com/google/paxml/blob/93fbc8010dca95af59ab615c366d912136b7429c/paxml/contrib/gpu/scripts_gpu/benchmark_gpt_multinode.sh#L30-L33).

For a list of previously used XLA flags that are no longer needed, please also refer to the [GPU performance](./rosetta/docs/GPU_performance.md#previously-used-xla-flags) page.
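As a minimal, hedged illustration of per-workflow tuning (the flag values here are examples, not a recommendation for any particular model), the flags are simply exported before launching the training script:

```bash
# Example only; choose flags per workload as described in the GPU performance doc.
export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true \
                  --xla_gpu_enable_triton_gemm=false"
```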

## Profiling JAX programs on GPU
See [this page](./docs/profiling.md) for more information about how to profile JAX programs on GPU.

6 changes: 5 additions & 1 deletion rosetta/docs/GPU_performance.md
@@ -128,6 +128,10 @@ Fine-grain control to improve performance by initializing a NCCL communicator to
- --xla_gpu_enable_cudnn_fmha=false (enables XLA pattern matcher to detect multi-headed attention pattern in JAX)
- --xla_disable_hlo_passes=<> (turns off specific HLO passes; can be used for debugging)

## Previously used XLA Flags


The following flags were previously used but are no longer required.
- --xla_gpu_enable_async_reduce_scatter, --xla_gpu_enable_async_all_reduce, --xla_gpu_enable_async_all_gather ; Turned on by default, no longer needed
- --xla_gpu_enable_highest_priority_async_stream ; Turned on by default
- --xla_gpu_enable_triton_softmax_fusion ; Deprecated, no longer used
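For orientation, a baseline flag set that omits all of the flags above (mirroring the values used by this repository's test scripts; tune per workload) could look like:

```bash
# Sketch of a current baseline without any of the deprecated flags; tune per workload.
export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true \
                  --xla_gpu_enable_triton_gemm=false \
                  --xla_gpu_enable_while_loop_double_buffering=true \
                  --xla_disable_hlo_passes=rematerialization"
```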

13 changes: 5 additions & 8 deletions rosetta/docs/NATIVE_FP8.md
@@ -111,22 +111,19 @@ Enabling this feature is effortless. Users only need to include the option `--fd
In addition to the suggested XLA flags mentioned in [this section](https://github.com/NVIDIA/JAX-Toolbox/blob/main/rosetta/rosetta/projects/pax/README.md#xla-flags), we also recommend setting these following XLA flags. The execution script should look like:
```bash
export XLA_FLAGS=" \
--xla_gpu_enable_reduction_epilogue_fusion=false \
--xla_gpu_enable_triton_gemm=false \
--xla_gpu_enable_cudnn_fmha=false \
--xla_gpu_enable_cudnn_layer_norm=true \
--xla_gpu_enable_cublaslt=true \
--xla_gpu_enable_latency_hiding_scheduler=true \
--xla_gpu_all_reduce_combine_threshold_bytes=51200 "
--xla_gpu_enable_pipelined_all_reduce=false \
--xla_gpu_enable_pipelined_all_gather=false \
--xla_gpu_enable_pipelined_reduce_scatter=false \
"
export ENABLE_TE=0
python -m paxml.main \
...
--fdl.USE_FP8=True \
...
```

Please ensure you include the first two flags, `--xla_gpu_enable_reduction_epilogue_fusion=false` and `--xla_gpu_enable_triton_gemm=false`, as they are essential for enabling the FP8 functionality. The additional flags primarily focus on performance enhancement and should also prove beneficial for non-FP8 executions.

Please note that disabling Triton GEMM and the pipelined collectives is essential for enabling FP8 functionality and achieving good performance.
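Reduced to only those essential pieces (the remaining flags above are performance-oriented), a minimal sketch would be:

```bash
# Minimal FP8 sketch: disable Triton GEMM and the pipelined collectives, keep TE off.
export XLA_FLAGS="--xla_gpu_enable_triton_gemm=false \
                  --xla_gpu_enable_pipelined_all_reduce=false \
                  --xla_gpu_enable_pipelined_all_gather=false \
                  --xla_gpu_enable_pipelined_reduce_scatter=false"
export ENABLE_TE=0
```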

## Transformer Engine vs Native FP8 Support
Native XLA-FP8 specifically targets matrix multiplication operations. In contrast, the Transformer Engine focuses on enhancing the overall performance of the entire transformer layer. This encompasses not only the FP8 matrix multiplication but also attention mechanisms, layer normalizations, and other components.
1 change: 0 additions & 1 deletion rosetta/docs/PGLE.md
@@ -70,7 +70,6 @@ export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true
--xla_gpu_enable_pipelined_reduce_scatter=true
--xla_gpu_enable_pipelined_all_reduce=true
--xla_gpu_enable_while_loop_double_buffering=true
--xla_gpu_enable_triton_softmax_fusion=false
--xla_gpu_enable_all_gather_combine_by_dim=false
--xla_gpu_enable_reduce_scatter_combine_by_dim=false
--xla_disable_hlo_passes=rematerialization