Commit 8fbacde

placeholder for models on eks
1 parent 0de66b0 commit 8fbacde

3 files changed: +200 −0 lines changed

3 files changed

+200
-0
lines changed
.github/eks-workflow-files/axlearn/axlearn-1B-model.yml

Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
apiVersion: batch/v1
kind: Job
metadata:
  name: PLACEHOLDER
  labels:
    kueue.x-k8s.io/queue-name: p5-queue
spec:
  completions: 1
  parallelism: 1
  template:
    spec:
      restartPolicy: Never
      containers:
        - name: axlearn
          image: PLACEHOLDER
          command:
            - bash
            - -xo
            - pipefail
            - -c
            - |
              BASEDIR="/opt/axlearn"
              CONFIG="fuji-1B-v3-flash-single-host"
              HLO_DUMP=0
              POSTFIX=""

              AR_THRESHOLD=1073741824
              AG_THRESHOLD=8589934592
              RS_THRESHOLD=8589934592
              XLA_BASE_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true
                --xla_gpu_enable_triton_gemm=false
                --xla_gpu_enable_highest_priority_async_stream=true
                --xla_gpu_all_gather_combine_threshold_bytes=${AG_THRESHOLD}
                --xla_gpu_reduce_scatter_combine_threshold_bytes=${RS_THRESHOLD}
                --xla_gpu_enable_pipelined_all_gather=true
                --xla_gpu_enable_pipelined_reduce_scatter=true
                --xla_gpu_enable_nccl_comm_splitting=false"

              export XLA_PYTHON_CLIENT_PREALLOCATE=false
              export TF_GPU_ALLOCATOR=cuda_malloc_async
              export XLA_FLAGS="${XLA_BASE_FLAGS}"

              export NCCL_BUFFSIZE=8388608
              export NCCL_P2P_NET_CHUNKSIZE=524288
              export NCCL_LAUNCH_MODE=GROUP
              export NCCL_DEBUG=INFO

              LOG_DIR=${BASEDIR}/logs
              TRAINER_DIR=${LOG_DIR}/${CONFIG}${POSTFIX}-eks/trainer-dir
              mkdir -p ${TRAINER_DIR}

              cat << EOF > tf_gpu_fix.py
              import tensorflow as tf
              tf.config.set_visible_devices([], 'GPU')
              import runpy
              runpy.run_module('axlearn.common.launch_trainer_main', run_name='__main__')
              EOF

              python3 tf_gpu_fix.py \
                --module=text.gpt.c4_trainer \
                --config=${CONFIG} \
                --trainer_dir=${TRAINER_DIR} \
                --data_dir=gs://axlearn-public/tensorflow_datasets \
                --jax_backend=gpu
          resources:
            limits:
              nvidia.com/gpu: 8
          volumeMounts:
            - name: output
              mountPath: /opt/output
      imagePullSecrets:
        - name: PLACEHOLDER
      volumes:
        - name: output
          emptyDir: {}
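
The job name, container image, and pull-secret PLACEHOLDER fields above are substituted by the CI workflow added later in this commit. Purely as a sketch (hypothetical file and job names, not part of the commit), the patched manifest could also be submitted by hand:

    # Assumes the PLACEHOLDER fields were already replaced with real values and the
    # manifest is saved locally as axlearn-1B-model.yml (hypothetical filename).
    JOB_NAME=axlearn-fuji-1b-manual   # must match metadata.name in the patched manifest
    kubectl apply -f axlearn-1B-model.yml
    kubectl wait --for=condition=complete --timeout=4h job/${JOB_NAME}
    kubectl logs job/${JOB_NAME}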
Lines changed: 78 additions & 0 deletions
@@ -0,0 +1,78 @@
apiVersion: batch/v1
kind: Job
metadata:
  name: PLACEHOLDER
  labels:
    kueue.x-k8s.io/queue-name: p5-queue
spec:
  completions: 1
  parallelism: 1
  template:
    spec:
      restartPolicy: Never
      containers:
        - name: axlearn
          image: PLACEHOLDER
          command:
            - bash
            - -xo
            - pipefail
            - -c
            - |
              BASEDIR="/opt/axlearn"
              CONFIG="fuji-3B-v3-flash-single-host"
              HLO_DUMP=0
              POSTFIX=""

              AR_THRESHOLD=1073741824
              AG_THRESHOLD=8589934592
              RS_THRESHOLD=8589934592
              XLA_BASE_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true
                --xla_gpu_enable_triton_gemm=false
                --xla_gpu_enable_highest_priority_async_stream=true
                --xla_gpu_all_gather_combine_threshold_bytes=${AG_THRESHOLD}
                --xla_gpu_reduce_scatter_combine_threshold_bytes=${RS_THRESHOLD}
                --xla_gpu_enable_pipelined_all_gather=true
                --xla_gpu_enable_pipelined_reduce_scatter=true
                --xla_gpu_enable_nccl_comm_splitting=false"

              export XLA_PYTHON_CLIENT_PREALLOCATE=false
              export TF_GPU_ALLOCATOR=cuda_malloc_async
              export XLA_FLAGS="${XLA_BASE_FLAGS}"

              export NCCL_BUFFSIZE=8388608
              export NCCL_P2P_NET_CHUNKSIZE=524288
              export NCCL_LAUNCH_MODE=GROUP
              export NCCL_DEBUG=INFO

              LOG_DIR=${BASEDIR}/logs
              TRAINER_DIR=${LOG_DIR}/${CONFIG}${POSTFIX}-eks/trainer-dir
              mkdir -p ${TRAINER_DIR}

              echo "Executing TF"
              cat << EOF > tf_fix_gpu.py
              import tensorflow as tf
              tf.config.set_visible_devices([], 'GPU')
              import runpy
              runpy.run_module('axlearn.common.launch_trainer_main', run_name='__main__')
              EOF

              python3 tf_fix_gpu.py \
                --module=text.gpt.c4_trainer \
                --config=${CONFIG} \
                --trainer_dir=${TRAINER_DIR} \
                --data_dir=gs://axlearn-public/tensorflow_datasets \
                --jax_backend=gpu
          resources:
            limits:
              nvidia.com/gpu: 8
          volumeMounts:
            - name: output
              mountPath: /opt/output
      imagePullSecrets:
        - name: PLACEHOLDER
      volumes:
        - name: output
          emptyDir: {}

.github/workflows/_ci.yaml

Lines changed: 45 additions & 0 deletions
@@ -759,4 +759,49 @@ jobs:
            sitrep.json
            "badge-axlearn-test"
            summary.txt

  test-axlearn-fuji-1B:
    needs: build-axlearn
    if: inputs.ARCHITECTURE == 'amd64'
    runs-on: eks
    env:
      AXLEARN_DOCKER_IMAGE: ${{ needs.build-axlearn.outputs.DOCKER_TAG_FINAL }}
      JOB_NAME: axlearn-fuji-1B-${{ github.run_id }}
      TOKEN_NAME: axlearn-fuji-1B-${{ github.run_id }}-token
    steps:
      - name: Check out the repository
        uses: actions/checkout@v4
      - name: GHCR Login
        uses: ./.github/actions/ghcr-login
        with:
          docker-username: ${{ github.repository_owner }}
          docker-password: ${{ secrets.GITHUB_TOKEN }}
          token-name: ${{ env.TOKEN_NAME }}
      - name: Configure axlearn test job
        run: |
          yq -i ea '
            select(di == 0).metadata.name = strenv(JOB_NAME)
            | select(di == 0).spec.template.spec.containers[0].image = strenv(AXLEARN_DOCKER_IMAGE)
            | select(di == 0).spec.template.spec.imagePullSecrets[].name = strenv(TOKEN_NAME)' \
            .github/eks-workflow-files/axlearn/axlearn-1B-model.yml
          git diff .github/eks-workflow-files/axlearn/axlearn-1B-model.yml

      - name: Submit & wait for axlearn test job
        uses: ./.github/actions/submit-k8s-job
        with:
          job-config-file: ".github/eks-workflow-files/axlearn/axlearn-1B-model.yml"
          job-name: ${{ env.JOB_NAME }}

      - name: Delete axlearn test job
        uses: ./.github/actions/delete-k8s-job
        if: ${{ always() }}
        with:
          job-name: ${{ env.JOB_NAME }}

      - name: Delete GitHub Container Registry token
        uses: ./.github/actions/delete-ghcr-token
        if: ${{ always() }}
        with:
          token-name: ${{ env.TOKEN_NAME }}
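
The select(di == 0) prefix in the yq expression above restricts each edit to the first YAML document of the manifest, and strenv(...) reads the values from the step's environment. A quick local check of the fields the workflow rewrites, assuming yq v4 is installed and the command is run from the repository root, might be:

    # Print the job name and container image that the "Configure axlearn test job" step patches.
    yq ea 'select(di == 0) | [.metadata.name, .spec.template.spec.containers[0].image]' \
      .github/eks-workflow-files/axlearn/axlearn-1B-model.yml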
