Skip to content
Open
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 38 additions & 1 deletion .github/actions/gke-xpk/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ runs:

# $? holds the status of the command immediately preceding this check (not
# visible in this hunk). On failure, record a non-zero XPK_EXIT_CODE in
# GITHUB_ENV so the later sitrep step can report it, then fail this step.
if [ $? -ne 0 ]; then
echo "The JobSet ${WORKLOAD_NAME} on ${{ inputs.GKE_CLUSTER }} did not complete as expected "
echo "XPK_EXIT_CODE=1" >> ${GITHUB_ENV}
exit 1
fi

Expand All @@ -262,11 +263,12 @@ runs:
ALL_EXIT_CODES=$(( ALL_EXIT_CODES + POD_EXIT_CODE ))
done

# Publish the aggregated per-pod exit code for the sitrep step, then fail
# this step if and only if at least one pod exited non-zero.
echo "XPK_EXIT_CODE=${ALL_EXIT_CODES}" >> ${GITHUB_ENV}
if [ ${ALL_EXIT_CODES} -gt 0 ]; then
exit 1
fi
exit 0

- name: Clean up JobSet from cluster
shell: bash -x -u {0}
if: ${{ always() }}
Expand All @@ -291,3 +293,38 @@ runs:
if: ${{ always() }}
run: |
sudo rm -rf ${WORKLOAD_NAME}

- name: Generate sitrep
id: sitrep
shell: bash -x -e {0}
if: ${{ always() }}
run: |
# Build a one-line status report (sitrep) for this workload and serialise it
# to JSON for upload as a workflow artifact.
source .github/workflows/scripts/to_json.sh
# NOTE(review): this assignment is dead -- badge_label is unconditionally
# overwritten below with inputs.WORKLOAD_NAME_PREFIX. Confirm which of the
# two values (matrix.test vs WORKLOAD_NAME_PREFIX) is the intended label.
badge_label="${{ matrix.test }}"

summary="${{ inputs.WORKLOAD_NAME_PREFIX }}"
outcome=success
badge_label="${{ inputs.WORKLOAD_NAME_PREFIX }}"
badge_color=brightgreen

# XPK_EXIT_CODE is written to GITHUB_ENV by the earlier workload steps.
# NOTE(review): if those steps aborted before writing it, this expands to an
# empty string and `[ "" -gt 0 ]` is an error (this step runs under -e);
# consider defaulting with ${XPK_EXIT_CODE:-1}.
if [ "${XPK_EXIT_CODE}" -gt 0 ]; then
badge_color=red
outcome=failed
summary+=": fail"
else
summary+=": pass"
fi

# to_json (sourced above) serialises the named shell variables into a JSON
# object; tee both logs it and writes sitrep.json for the upload step.
to_json summary \
badge_label \
badge_color \
outcome | \
tee sitrep.json

- name: Upload sitrep to GitHub Actions from runner
if: ${{ always() }}
uses: actions/upload-artifact@v4
with:
name: ${{ inputs.WORKLOAD_NAME_PREFIX }}-sitrep
path: |
sitrep.json
1 change: 1 addition & 0 deletions .github/container/Dockerfile.jax
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ RUN mkdir -p /builder/extra-targets && \
--src-path-xla ${SRC_PATH_XLA} \
--sm all \
--clean \
--release \
${EXTRA_BUILD_JAX_ARGS}

## Transformer engine: check out source and build wheel
Expand Down
13 changes: 12 additions & 1 deletion .github/container/build-jax.sh
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ INSTALL=1
SRC_PATH_JAX="/opt/jax"
SRC_PATH_XLA="/opt/xla"

args=$(getopt -o h --long bazel-cache:,bazel-cache-namespace:,build-param:,build-path-jaxlib:,clean,cpu-arch:,debug,extra-targets:,extra-target-dest:,no-clean,clean-only,help,install,no-install,src-path-jax:,src-path-xla:,sm: -- "$@")
# Short options: -h (help), -r (release). Note: getopt short-option strings
# are a plain character list -- the previous "h,r" also registered ',' as a
# (bogus) option character; "hr" is the correct form.
args=$(getopt -o hr --long bazel-cache:,bazel-cache-namespace:,build-param:,build-path-jaxlib:,clean,release,cpu-arch:,debug,extra-targets:,extra-target-dest:,no-clean,clean-only,help,install,no-install,src-path-jax:,src-path-xla:,sm: -- "$@")
if [[ $? -ne 0 ]]; then
exit 1
fi
Expand Down Expand Up @@ -135,6 +135,10 @@ while [ : ]; do
EXTRA_TARGET_DEST="$2"
shift 2
;;
-r | --release)
IS_RELEASE=1
shift 1
;;
-h | --help)
usage 1
;;
Expand Down Expand Up @@ -225,6 +229,7 @@ print_var INSTALL
print_var PYTHON_VERSION
print_var SRC_PATH_JAX
print_var SRC_PATH_XLA
print_var IS_RELEASE

echo "=================================================="

Expand Down Expand Up @@ -268,6 +273,12 @@ for component in jaxlib "jax-cuda${CUDA_MAJOR_VERSION}-pjrt" "jax-cuda${CUDA_MAJ
# version, so nvidia-*-cu12 wheels disappear from the lock file
sed -i "s|^${component}.*$|${component} @ file://${BUILD_PATH_JAXLIB}/${component//-/_}|" build/requirements.in
done

# Release builds: loosen jax's exact jaxlib pin so the locally built jaxlib
# wheel (whose version need not equal the source tree's _jax_version)
# satisfies jax's install requirement.
# ${IS_RELEASE:-0} guards against the variable being unset when --release
# was not passed. The previously computed (and unused) jaxlib_version was
# dropped; the sed below substitutes a fixed floor instead.
if [[ "${IS_RELEASE:-0}" == "1" ]]; then
    # Use ${SRC_PATH_JAX} rather than a hard-coded /opt/jax so that a
    # non-default --src-path-jax is honoured.
    sed -i "s| f'jaxlib >={_minimum_jaxlib_version}, <={_jax_version}',| f'jaxlib>=0.5.0',|" "${SRC_PATH_JAX}/setup.py"
fi

# Bazel args to avoid cache invalidation
BAZEL_ARGS=(
--config=cuda_libraries_from_stubs
Expand Down
21 changes: 19 additions & 2 deletions .github/container/build-te.sh
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,12 @@ if [[ "$SM" == "all" ]]; then
SM_LIST=$(default_compute_capabilities)
elif [[ "$SM" == "local" ]]; then
SM_LIST=$("${SCRIPT_DIR}/local_cuda_arch")
# Abort when local GPU-architecture detection produced nothing: compiling
# with "--sm local" cannot work on a machine without visible GPUs. Dump
# nvidia-smi output (best effort) first, to aid debugging.
if [[ "${SM_LIST}" == "" ]]; then
  echo "Could not determine the local GPU architecture."
  echo "You should pass --sm when compiling on a machine without GPUs."
  nvidia-smi || true
  exit 1
fi
else
SM_LIST=${SM}
fi
Expand Down Expand Up @@ -131,8 +137,19 @@ export NVTE_FRAMEWORK=jax
export XLA_HOME=${SRC_PATH_XLA}

pushd ${SRC_PATH_TE}
# Install required packages that were removed in https://github.com/NVIDIA/TransformerEngine/pull/1852
pip install "pybind11[global]"
# Install some build dependencies, but avoid installing everything
# (jax, torch, ...) because we do not want to pull in a released version of
# JAX, or the wheel-based installation of CUDA. Note that when we build TE as
# part of building the JAX containers, JAX and XLA are not yet installed.
python - << EOF
import subprocess, sys, tomllib  # tomllib: stdlib TOML parser, Python >= 3.11
with open("pyproject.toml", "rb") as ifile:
    data = tomllib.load(ifile)
# check=True: abort the build if pip fails, instead of silently continuing
# with missing build dependencies (subprocess.run ignores the child's exit
# status by default).
subprocess.run(
    [sys.executable, "-m", "pip", "install"]
    + [r for r in data["build-system"]["requires"]
       if r.startswith(("nvidia-mathdx", "pybind11"))],
    check=True)
EOF

# The wheel filename includes the TE commit; if this has changed since the last
# incremental build then we would end up with multiple wheels.
Expand Down
7 changes: 7 additions & 0 deletions .github/container/git-clone.sh
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,13 @@ pushd ${DESTINATION}
git checkout ${GIT_REF}
COMMIT_SHA=$(git rev-parse HEAD)
git submodule update --init --recursive
# For GitLab sources, scrub CI credentials from the clone: drop the remote
# (its URL embeds the token) and delete any file under .git that still
# contains the gitlab-ci-token string (e.g. FETCH_HEAD).
if [[ "${GIT_REPO}" == *"gitlab"* ]]; then
    git remote remove origin
    if grep -q -r gitlab-ci-token .git; then
        # `grep -l` emits matching file names directly; the previous
        # `grep -r | awk -F: '{print $1}'` mis-parsed paths containing ':'
        # and "Binary file ... matches" lines.
        grep -r -l gitlab-ci-token .git | xargs -r rm -f --
    fi
    # NOTE(review): this fails (and aborts under `set -e`) when `main` is
    # the currently checked-out branch or does not exist -- confirm the
    # intended behaviour for those cases.
    git branch -D main
fi
popd

## update the manifest file
Expand Down
10 changes: 9 additions & 1 deletion .github/container/test-jax.sh
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,15 @@ fi

readarray -t GPU_MEMORIES < <(nvidia-smi --query-gpu=memory.total --format=csv,noheader)
NGPUS="${#GPU_MEMORIES[@]}"
GPU_MEMORIES_MIB=("${GPU_MEMORIES[@]/ MiB/}")
if [[ " ${GPU_MEMORIES[*]} " =~ [[:space:]]\[N/A\][[:space:]] ]]; then
  # On iGPU devices, nvidia-smi reports [N/A] GPU memory; use the system
  # memory size instead to estimate what each GPU can use. Truncate to an
  # integer inside awk: MemTotal is in kiB and need not be divisible by
  # 1024, and bash's $(( ... )) errors out on a fractional value such as
  # "31995.5" that the previous `print $2 / 1024` could produce.
  SYSTEM_MEMORY_MIB=$(awk '/MemTotal/ {print int($2 / 1024)}' /proc/meminfo)
  declare -a GPU_MEMORIES_MIB=()
  for (( i = 0; i < NGPUS; i++ )); do
    GPU_MEMORIES_MIB+=("$(( SYSTEM_MEMORY_MIB / NGPUS ))")
  done
else
  # Normal GPUs: strip the " MiB" suffix from each nvidia-smi value.
  GPU_MEMORIES_MIB=("${GPU_MEMORIES[@]/ MiB/}")
fi

FLAGS=()

Expand Down
120 changes: 120 additions & 0 deletions .github/eks-workflow-files/maxtext-job.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
# Headless Service + indexed Job pair for running a 2-node MaxText benchmark
# on EKS. Every "PLACEHOLDER" is substituted by the calling workflow before
# the manifest is applied. (Indentation restored to conventional k8s layout.)
apiVersion: v1
kind: Service
metadata:
  name: PLACEHOLDER
spec:
  clusterIP: None # clusterIP must be None to create a headless service
  selector:
    job-name: PLACEHOLDER # must match Job name
---
apiVersion: batch/v1
kind: Job
metadata:
  name: PLACEHOLDER
  labels:
    kueue.x-k8s.io/queue-name: p5-queue
spec:
  completions: 2 # number of nodes
  parallelism: 2 # number of nodes
  completionMode: Indexed
  backoffLimitPerIndex: 0 # max failures per index
  maxFailedIndexes: 0 # all indices must succeed
  template:
    spec:
      subdomain: PLACEHOLDER # has to match Service name
      restartPolicy: Never
      imagePullSecrets:
        - name: PLACEHOLDER
      containers:
        - name: maxtext
          image: PLACEHOLDER
          ports:
            - containerPort: 3389
          command:
            - bash
            - -c
            # The logging logic: stream stdout/stderr from the 0th process inside this pod,
            # record all of the processes' stdout/stderr + the INFO-level NCCL logs to file
            - |
              export SERVICE_NAME=$0
              export JOB_NAME=$1
              cat >each-process.sh <<'EOL'
              export JAX_COORDINATOR_IP=${JOB_NAME}-0.${SERVICE_NAME}
              export JAX_COORDINATOR_PORT=3389
              export NNODES=16 # actually #processes == #GPUs
              export NODE_RANK=$((JOB_COMPLETION_INDEX*8 + LOCAL_RANK))
              export JAX_LOCAL_DEVICE_IDS=$LOCAL_RANK
              export NCCL_DEBUG=INFO
              export NCCL_DEBUG_FILE=/opt/output/nccl.$NODE_RANK.log
              [[ $LOCAL_RANK == 0 ]] && console="/dev/stdout" || console="/dev/null"
              nsys-jax \
                --capture-range=cudaProfilerApi \
                --capture-range-end=stop \
                -o /opt/output/profile.$NODE_RANK.zip \
                -- \
                test-maxtext.sh \
                -n 2 \
                -b 2 \
                --model-name=llama2-7b \
                --attn-type=cudnn_flash_te \
                --remat-policy=minimal_flash \
                --steps=20 \
                --fsdp=16 \
                -a "scan_layers=false \
                    max_target_length=4096 \
                    use_iota_embed=true \
                    logits_dot_in_fp32=false \
                    profiler=nsys \
                    skip_first_n_steps_for_profiler=3 \
                    profiler_steps=8" \
                |& tee /opt/output/output.$NODE_RANK.log >"${console}"
              code=$?
              # Should run even on failure
              cat /opt/output/nccl.$NODE_RANK.log >"${console}"
              exit $code
              EOL
              # TODO: upgrade parallel-launch to return a failure code as soon as any
              # of its children do (it already does this eventually, but it could
              # be slow)
              parallel-launch LOCAL_RANK 8 bash each-process.sh
              code=$?
              # Should run even on failure
              touch /opt/output/.done
              exit $code
            - PLACEHOLDER
            - PLACEHOLDER
          resources:
            limits:
              nvidia.com/gpu: 8
              vpc.amazonaws.com/efa: 32
          volumeMounts:
            - mountPath: /dev/shm
              name: shmem
            - mountPath: /opt/output
              name: output
        # Sidecar: waits for the training container to drop the .done marker,
        # then uploads everything in /opt/output to S3.
        - name: upload
          image: amazon/aws-cli
          command:
            - bash
            - -c
            - |
              JOB_NAME="$0"
              while [[ ! -f /opt/output/.done ]]; do
                sleep 1
              done
              rm /opt/output/.done
              aws s3 cp \
                --recursive \
                /opt/output \
                "s3://jax-toolbox-eks-output/${JOB_NAME}/"
            - PLACEHOLDER
          volumeMounts:
            - mountPath: /opt/output
              name: output
      volumes:
        - name: output
          emptyDir: {}
        - name: shmem
          emptyDir:
            medium: Memory
            sizeLimit: 16Gi
10 changes: 5 additions & 5 deletions .github/workflows/_test_maxtext_gke_xpk.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@ jobs:
steps:
- uses: actions/checkout@v4

- name: Login to GitHub Container Registry
- name: Login to nvcr.io Container Registry
uses: docker/login-action@v3
with:
registry: ghcr.io
username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }}
registry: nvcr.io
username: $oauthtoken
password: ${{ secrets.NVCR_TOKEN }}

- name: K8s GHCR store and delete token
- name: K8s store and delete token
id: store-token
uses: ./.github/actions/store-delete-k8s-ghcr

Expand Down
Loading
Loading