@@ -31,16 +31,7 @@ concurrency:
3131 cancel-in-progress : true
3232
3333jobs :
34- prelim :
35- runs-on : ["linux-x86-n2-32"]
36- steps :
37- - name : Test gsutil installation
38- run : which gsutil >/dev/null 2>&1 || { echo >&2 "gsutil is required but not installed. Aborting"; exit 24;}
39- - name : Cleanup old docker images
40- run : docker system prune --all --force
41-
4234 tpu_unit_tests :
43- needs : prelim
4435 runs-on : [linux-x86-ct5lp-224-8tpu]
4536 container :
4637 image : python:3.12-slim
@@ -124,48 +115,6 @@ jobs:
124115 print('Llama3 params loaded successfully')
125116 "
126117
127- tpu_integration_tests :
128- needs : prelim
129- runs-on : [self-hosted, linux-x86-ct5lp-224-8tpu]
130- container :
131- image : python:3.12-slim
132- options : --privileged --cpus=2 --memory=4Gi
133- env :
134- TPU_ACCELERATOR_TYPE : " "
135- JAX_PLATFORMS : " cpu"
136- CUDA_VISIBLE_DEVICES : " "
137- TF_CPP_MIN_LOG_LEVEL : " 3"
138- steps :
139- - name : Checkout code
140- uses : actions/checkout@v4
141- with :
142- fetch-depth : 0
143-
144- - name : Set up Python
145- uses : actions/setup-python@v4
146- with :
147- python-version : ' 3.12'
148-
149- - name : Install system dependencies
150- run : |
151- sudo apt-get update
152- sudo apt-get install -y git curl
153-
154- - name : Set up JAX for TPU
155- run : |
156- pip install --upgrade pip
157- pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
158-
159- - name : Install tunix dependencies
160- run : |
161- pip install -e .
162- pip install pytest pytest-xdist
163-
164- - name : Run integration tests
165- run : |
166- # Run more comprehensive tests that might take longer
167- python -m pytest tests/ -v --tb=short -m "integration_test" --timeout=300
168-
169118 notify_failure :
170119 name : Notify failed build
171120 needs : [tpu_unit_tests, tpu_integration_tests]
0 commit comments