Skip to content

Commit 1c09cb5

Browse files
aybchan, olupton, and gpupuck
authored
Add NGC workflow (#1692)
- [x] Add NGC workflow + EKS maxtext workflow to `main` to avoid adding it to release branch anew --------- Co-authored-by: Olli Lupton <[email protected]> Co-authored-by: Brian Yang <[email protected]>
1 parent fcaa58f commit 1c09cb5

File tree

6 files changed

+376
-44
lines changed

6 files changed

+376
-44
lines changed

.github/actions/gke-xpk/action.yml

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,7 @@ runs:
253253
254254
if [ $? -ne 0 ]; then
255255
echo "The JobSet ${WORKLOAD_NAME} on ${{ inputs.GKE_CLUSTER }} did not complete as expected "
256+
echo "XPK_EXIT_CODE=1" >> ${GITHUB_ENV}
256257
exit 1
257258
fi
258259
@@ -268,11 +269,12 @@ runs:
268269
ALL_EXIT_CODES=$(( ALL_EXIT_CODES + POD_EXIT_CODE ))
269270
done
270271
272+
echo "XPK_EXIT_CODE=${ALL_EXIT_CODES}" >> ${GITHUB_ENV}
271273
if [ ${ALL_EXIT_CODES} -gt 0 ]; then
272274
exit 1
273275
fi
274276
exit 0
275-
277+
276278
- name: Clean up JobSet from cluster
277279
shell: bash -x -u {0}
278280
if: ${{ always() }}
@@ -297,3 +299,38 @@ runs:
297299
if: ${{ always() }}
298300
run: |
299301
sudo rm -rf ${WORKLOAD_NAME}
302+
303+
- name: Generate sitrep
304+
id: sitrep
305+
shell: bash -x -e {0}
306+
if: ${{ always() }}
307+
run: |
308+
source .github/workflows/scripts/to_json.sh
309+
badge_label="${{ matrix.test }}"
310+
311+
summary="${{ inputs.WORKLOAD_NAME_PREFIX }}"
312+
outcome=success
313+
badge_label="${{ inputs.WORKLOAD_NAME_PREFIX }}"
314+
badge_color=brightgreen
315+
316+
if [ "${XPK_EXIT_CODE}" -gt 0 ]; then
317+
badge_color=red
318+
outcome=failed
319+
summary+=": fail"
320+
else
321+
summary+=": pass"
322+
fi
323+
324+
to_json summary \
325+
badge_label \
326+
badge_color \
327+
outcome | \
328+
tee sitrep.json
329+
330+
- name: Upload sitrep to GitHub Actions from runner
331+
if: ${{ always() }}
332+
uses: actions/upload-artifact@v4
333+
with:
334+
name: ${{ inputs.WORKLOAD_NAME_PREFIX }}-sitrep
335+
path: |
336+
sitrep.json

.github/container/git-clone.sh

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,13 @@ if [[ -n "${GIT_REF}" ]]; then
7979
fi
8080
COMMIT_SHA=$(git rev-parse HEAD)
8181
git submodule update --init --recursive
82+
if [[ "${GIT_REPO}" == *"gitlab"* ]]; then
83+
git remote remove origin
84+
if grep -q -r gitlab-ci-token .git; then
85+
grep -r gitlab-ci-token .git | awk -F: '{print $1}' | xargs rm -f
86+
fi
87+
git branch -D main
88+
fi
8289
popd
8390

8491
## update the manifest file

.github/container/pip-finalize.sh

Lines changed: 52 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -4,54 +4,60 @@ set -eoux pipefail
44

55
pushd /opt/pip-tools.d
66

7-
# First pip-compile gathers all reqs, but we care only about VCS installs
8-
# It's possible there are 2nd degree transitive dependencies that are VCS, so
9-
# this is more robust to gather VCS requirements at the cost of pip-compiling
10-
# twice
11-
pip-compile -o requirements.pre $(ls requirements-*.in)
7+
# If requirements-pinned.txt exists, skip compilation
8+
if [[ -f "requirements-pinned.txt" ]]; then
9+
sed -E 's/#sha256=[a-f0-9]+//g' requirements-pinned.txt > requirements.txt
10+
else
11+
# First pip-compile gathers all reqs, but we care only about VCS installs
12+
# It's possible there are 2nd degree transitive dependencies that are VCS, so
13+
# this is more robust to gather VCS requirements at the cost of pip-compiling
14+
# twice
15+
pip-compile -o requirements.pre $(ls requirements-*.in)
1216

13-
IFS=$'\n'
14-
for line in $(cat requirements.pre | egrep '^[^#].+ @ git\+' || true); do
15-
# VCS installs are of the form "PACKAGE @ git+..."
16-
PACKAGE=$(echo "$line" | awk '{print $1}')
17-
ref=$(yq e ".${PACKAGE}.latest_verified_commit" ${MANIFEST_FILE})
18-
if [[ "$line" == *"#subdirectory="* ]]; then
19-
# This is required b/c git-refs/commits cannot come after
20-
# the subdirectory fragment.
21-
# An example of an install that is of this form is:
22-
# 'orbax-checkpoint @ git+https://github.com/google/orbax/#subdirectory=checkpoint'
23-
echo "${line}" | sed "s/#subdirectory=/@${ref}#subdirectory=/"
24-
else
25-
echo "${line}@${ref}"
26-
fi
27-
done | tee requirements.vcs
28-
unset IFS
17+
IFS=$'\n'
18+
for line in $(cat requirements.pre | egrep '^[^#].+ @ git\+' || true); do
19+
# VCS installs are of the form "PACKAGE @ git+..."
20+
PACKAGE=$(echo "$line" | awk '{print $1}')
21+
ref=$(yq e ".${PACKAGE}.latest_verified_commit" ${MANIFEST_FILE})
22+
if [[ "$line" == *"#subdirectory="* ]]; then
23+
# This is required b/c git-refs/commits cannot come after
24+
# the subdirectory fragment.
25+
# An example of an install that is of this form is:
26+
# 'orbax-checkpoint @ git+https://github.com/google/orbax/#subdirectory=checkpoint'
27+
echo "${line}" | sed "s/#subdirectory=/@${ref}#subdirectory=/"
28+
else
29+
echo "${line}@${ref}"
30+
fi
31+
done | tee requirements.vcs
32+
unset IFS
2933

30-
# Second pip-compile includes one more requirements file that pins all vcs installs
31-
# Uses a special env var to let our custom pip impl know to treat the following as
32-
# equivalent:
33-
#
34-
# fiddle @ git+https://github.com/google/fiddle
35-
# fiddle @ git+https://github.com/google/fiddle@cd4497e4c09bdf95dcccaa1e138c2c125d32d39f
36-
#
37-
# JAX_TOOLBOX_VCS_EQUIVALENCY is an environment variable enabling custom logic in pip
38-
# that treats the above as equivalent and prefers the URI with the SHA
39-
JAX_TOOLBOX_VCS_EQUIVALENCY=true pip-compile -o requirements.txt requirements.vcs $(ls requirements-*.in)
34+
# Second pip-compile includes one more requirements file that pins all vcs installs
35+
# Uses a special env var to let our custom pip impl know to treat the following as
36+
# equivalent:
37+
#
38+
# fiddle @ git+https://github.com/google/fiddle
39+
# fiddle @ git+https://github.com/google/fiddle@cd4497e4c09bdf95dcccaa1e138c2c125d32d39f
40+
#
41+
# JAX_TOOLBOX_VCS_EQUIVALENCY is an environment variable enabling custom logic in pip
42+
# that treats the above as equivalent and prefers the URI with the SHA
43+
JAX_TOOLBOX_VCS_EQUIVALENCY=true pip-compile -o requirements.txt requirements.vcs $(ls requirements-*.in)
4044

41-
# If there are unpinned VCS dependencies, error since these should be included in the manifest
42-
unpinned_vcs_dependencies=$(cat requirements.txt | egrep '^[^#].+ @ git\+' | egrep -v '^[^#].+ @ git\+.+@' || true)
43-
if [[ $(echo -n "$unpinned_vcs_dependencies" | wc -l) -gt 0 ]]; then
44-
echo "Unpinned VCS installs found in $(readlink -f requirements.txt):"
45-
echo "$unpinned_vcs_dependencies"
46-
exit 1
47-
fi
45+
# If there are unpinned VCS dependencies, error since these should be included in the manifest
46+
unpinned_vcs_dependencies=$(cat requirements.txt | egrep '^[^#].+ @ git\+' | egrep -v '^[^#].+ @ git\+.+@' || true)
47+
if [[ $(echo -n "$unpinned_vcs_dependencies" | wc -l) -gt 0 ]]; then
48+
echo "Unpinned VCS installs found in $(readlink -f requirements.txt):"
49+
echo "$unpinned_vcs_dependencies"
50+
exit 1
51+
fi
4852

49-
# Replace any tensorflow==X with tensorflow-cpu==X in requirements.txt only on amd64
50-
if [ "$(uname -m)" = "x86_64" ]; then
51-
sed -i 's/^tensorflow==\([0-9.*]\+\)$/tensorflow-cpu==\1/' requirements.txt
52-
else
53-
echo "Skipping TF on $(uname -m)"
53+
# Replace any tensorflow==X with tensorflow-cpu==X in requirements.txt only on amd64
54+
if [[ "$(uname -m)" = "x86_64" ]]; then
55+
sed -i 's/^tensorflow==\([0-9.*]\+\)$/tensorflow-cpu==\1/' requirements.txt
56+
else
57+
echo "Skipping TF on $(uname -m)"
58+
fi
5459
fi
60+
5561
# --no-deps is required since conflicts can still appear during pip-sync
5662
pip-sync --pip-args '--no-deps --src /opt' requirements.txt
5763

@@ -63,3 +69,6 @@ for post_install in $(ls /opt/pip-tools-post-install.d/*); do
6369
"${post_install}"
6470
fi
6571
done
72+
73+
echo "######## Frozen requirements ########"
74+
pip freeze
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
apiVersion: v1
2+
kind: Service
3+
metadata:
4+
name: PLACEHOLDER
5+
spec:
6+
clusterIP: None # clusterIP must be None to create a headless service
7+
selector:
8+
job-name: PLACEHOLDER # must match Job name
9+
---
10+
apiVersion: batch/v1
11+
kind: Job
12+
metadata:
13+
name: PLACEHOLDER
14+
labels:
15+
kueue.x-k8s.io/queue-name: p5-queue
16+
spec:
17+
completions: 2 # number of nodes
18+
parallelism: 2 # number of nodes
19+
completionMode: Indexed
20+
backoffLimitPerIndex: 0 # max failures per index
21+
maxFailedIndexes: 0 # all indices must succeed
22+
template:
23+
spec:
24+
subdomain: PLACEHOLDER # has to match Service name
25+
restartPolicy: Never
26+
imagePullSecrets:
27+
- name: PLACEHOLDER
28+
containers:
29+
- name: maxtext
30+
image: PLACEHOLDER
31+
ports:
32+
- containerPort: 3389
33+
command:
34+
- bash
35+
- -c
36+
# The logging logic: stream stdout/stderr from the 0th process inside this pod,
37+
# record all of the processes' stdout/stderr + the INFO-level NCCL logs to file
38+
- |
39+
export SERVICE_NAME=$0
40+
export JOB_NAME=$1
41+
cat >each-process.sh <<'EOL'
42+
export JAX_COORDINATOR_IP=${JOB_NAME}-0.${SERVICE_NAME}
43+
export JAX_COORDINATOR_PORT=3389
44+
export NNODES=16 # actually #processes == #GPUs
45+
export NODE_RANK=$((JOB_COMPLETION_INDEX*8 + LOCAL_RANK))
46+
export JAX_LOCAL_DEVICE_IDS=$LOCAL_RANK
47+
export NCCL_DEBUG=INFO
48+
export NCCL_DEBUG_FILE=/opt/output/nccl.$NODE_RANK.log
49+
[[ $LOCAL_RANK == 0 ]] && console="/dev/stdout" || console="/dev/null"
50+
nsys-jax \
51+
--capture-range=cudaProfilerApi \
52+
--capture-range-end=stop \
53+
-o /opt/output/profile.$NODE_RANK.zip \
54+
-- \
55+
test-maxtext.sh \
56+
-n 2 \
57+
-b 2 \
58+
--model-name=llama2-7b \
59+
--attn-type=cudnn_flash_te \
60+
--remat-policy=minimal_flash \
61+
--steps=20 \
62+
--fsdp=16 \
63+
-a "scan_layers=false \
64+
max_target_length=4096 \
65+
use_iota_embed=true \
66+
logits_dot_in_fp32=false \
67+
profiler=nsys \
68+
skip_first_n_steps_for_profiler=3 \
69+
profiler_steps=8" \
70+
|& tee /opt/output/output.$NODE_RANK.log >"${console}"
71+
code=$?
72+
# Should run even on failure
73+
cat /opt/output/nccl.$NODE_RANK.log >"${console}"
74+
exit $code
75+
EOL
76+
# TODO: upgrade parallel-launch to return a failure code as soon as any
77+
# of its children do (it already does this eventually, but it could
78+
# be slow)
79+
parallel-launch LOCAL_RANK 8 bash each-process.sh
80+
code=$?
81+
# Should run even on failure
82+
touch /opt/output/.done
83+
exit $code
84+
- PLACEHOLDER
85+
- PLACEHOLDER
86+
resources:
87+
limits:
88+
nvidia.com/gpu: 8
89+
vpc.amazonaws.com/efa: 32
90+
volumeMounts:
91+
- mountPath: /dev/shm
92+
name: shmem
93+
- mountPath: /opt/output
94+
name: output
95+
- name: upload
96+
image: amazon/aws-cli
97+
command:
98+
- bash
99+
- -c
100+
- |
101+
JOB_NAME="$0"
102+
while [[ ! -f /opt/output/.done ]]; do
103+
sleep 1
104+
done
105+
rm /opt/output/.done
106+
aws s3 cp \
107+
--recursive \
108+
/opt/output \
109+
"s3://jax-toolbox-eks-output/${JOB_NAME}/"
110+
- PLACEHOLDER
111+
volumeMounts:
112+
- mountPath: /opt/output
113+
name: output
114+
volumes:
115+
- name: output
116+
emptyDir: {}
117+
- name: shmem
118+
emptyDir:
119+
medium: Memory
120+
sizeLimit: 16Gi

0 commit comments

Comments (0)