-
Notifications
You must be signed in to change notification settings - Fork 281
329 lines (295 loc) · 12.5 KB
/
tpu-tests.yml
File metadata and controls
329 lines (295 loc) · 12.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This reusable workflow (workflow_call) installs Tunix with its dependencies and runs the test suites on Cloud TPU v5e-8 runners.
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
name: "TPU Tests"
# Reusable workflow: it only runs when another workflow calls it via `uses:`.
on:
  workflow_call:
    # None of these secrets is mandatory (`required: false` is also the
    # default); steps that need them receive them through step-level `env:`.
    secrets:
      HF_TOKEN:
        description: "HuggingFace token for model downloads"
        required: false
      KAGGLE_USERNAME:
        description: "Kaggle Username"
        required: false
      KAGGLE_KEY:
        description: "Kaggle API Key"
        required: false
concurrency:
  # One group per pull request and one shared group for scheduled runs, so a
  # newer run cancels the older one; every other trigger gets its unique
  # run id as the group and is therefore never de-duplicated.
  # NOTE(review): as a workflow_call workflow, the `github` context here is the
  # caller's — confirm the caller does not declare an identical group.
  group: >-
    ${{
    github.event_name == 'pull_request'
    && format('{0}-pr-{1}', github.workflow, github.event.pull_request.number)
    || github.event_name == 'schedule'
    && format('{0}-schedule', github.workflow)
    || github.run_id
    }}
  cancel-in-progress: true
env:
  # Hugging Face cache directory; the same path is saved/restored by each
  # job's "Cache HF hub" step.
  HF_HOME: '~/.cache/huggingface'
  # Opt in to hf_transfer downloads.
  # NOTE(review): this presumably needs the hf_transfer package installed in
  # every container — confirm it is pulled in by the installed extras.
  HF_HUB_ENABLE_HF_TRANSFER: '1'
jobs:
  # Production dependency set (tunix[prod,test]) on the JAX stable-stack TPU image.
  run_prod:
    runs-on: [linux-x86-ct5lp-224-8tpu]
    environment: testing
    container:
      image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/tpu:latest
      options: --privileged
    env:
      CLOUD_TPU_ACCELERATOR: v5e-8
      JAX_PLATFORMS: tpu
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
        with:
          fetch-depth: 0
      # Cache Hugging Face hub. This runs after checkout on purpose: hashFiles()
      # only sees files in the workspace, so before checkout the key would hash
      # nothing (or a stale checkout left on the runner) instead of this commit.
      - name: Cache HF hub
        uses: actions/cache@v4
        with:
          path: ~/.cache/huggingface
          key: hf-${{ runner.os }}-${{ hashFiles('pyproject.toml', 'requirements*.txt', 'constraints*.txt') }}
          restore-keys: |
            hf-${{ runner.os }}-
      - name: Install tunix dependencies
        run: |
          pip install --upgrade pip
          # Quote the extras so the shell never glob-expands the brackets.
          pip install -e ".[prod,test]" --force-reinstall
      - name: Verify TPU availability
        run: |
          # Fails the step unless JAX reports at least one device and all of
          # them are TPUs.
          python -c "
          import jax
          print(f'JAX version: {jax.__version__}')
          print(f'JAX devices: {jax.devices()}')
          # Check if we have TPU devices specifically
          devices = jax.devices()
          has_tpu = len(devices) > 0 and all(device.platform == 'tpu' for device in devices)
          print(f'TPU available: {has_tpu}')
          if not has_tpu:
            print('ERROR: No TPU devices found! Expected TPU devices but got:', [device.platform for device in devices])
            raise SystemExit(1)
          else:
            print(f'SUCCESS: Found {len(devices)} TPU device(s)')
          "
      - name: Run tunix model tests
        run: |
          # naming_test runs in run_dev_others instead.
          python -m pytest tests/models/ -v --tb=short -m "not cpu_only and not gpu_only" \
            --ignore=tests/models/naming_test.py
      - name: Run tunix generation tests (PASSED only)
        run: |
          # vLLM and sglang-jax tests run in the dev jobs below;
          # tokenizer_adapter_test requires access to a gated repo.
          python -m pytest tests/generate/ -v --tb=short \
            --ignore=tests/generate/vllm_sampler_test.py \
            --ignore=tests/generate/vllm_sampler_qwen_test.py \
            --ignore=tests/generate/vllm_driver_test.py \
            --ignore=tests/generate/tokenizer_adapter_test.py \
            --ignore=tests/generate/sglang_jax_sampler_test.py \
            --ignore=tests/generate/sglang_jax_lora_test.py
      - name: Run tunix SFT tests
        run: |
          python -m pytest tests/sft/ -v --tb=short
      - name: Run tunix distillation tests
        run: |
          python -m pytest tests/distillation/ -v --tb=short
      - name: Run tunix RL tests
        run: |
          # RL common tests that passed
          # b/448133814: test_grpo_with_lora_model fails
          python -m pytest tests/rl/ -v --tb=short -k "not test_grpo_with_lora_model" --ignore=tests/rl/agentic
      - name: Run tunix tests not covered by the above categories
        run: |
          # Catch-all for newly added tests that no other step covers yet.
          # When you add a new folder under tests/, add a dedicated step for it
          # above and a matching --ignore here.
          python -m pytest tests/ -v --tb=short \
            --ignore=tests/perf/ \
            --ignore=tests/model_alignment/ \
            --ignore=tests/models/ \
            --ignore=tests/cli/ \
            --ignore=tests/generate/ \
            --ignore=tests/sft/ \
            --ignore=tests/distillation/ \
            --ignore=tests/rl/ \
            --ignore=tests/smoke_tests/ || code=$?
          # pytest exit code 5 means "no tests collected", the expected outcome
          # while every test folder has its own step.
          if [ "${code:-0}" = "5" ]; then
            echo "No tests collected (expected)."
            exit 0
          else
            exit "${code:-0}"
          fi

  # Dev dependency set (tunix[dev,test]) on the vLLM TPU nightly image. Runs for
  # non-PR events and for PRs whose head repository is this repository.
  run_dev_others:
    if: ${{ github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository }}
    runs-on: [linux-x86-ct5lp-224-8tpu]
    environment: testing
    container:
      image: vllm/vllm-tpu:nightly-20260406-581c4d4-f6983f0
      options: --privileged
    env:
      CLOUD_TPU_ACCELERATOR: v5e-8
      JAX_PLATFORMS: tpu,cpu
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
        with:
          fetch-depth: 0
      # Cache Hugging Face hub (after checkout so hashFiles() sees this commit).
      - name: Cache HF hub
        uses: actions/cache@v4
        with:
          path: ~/.cache/huggingface
          key: hf-${{ runner.os }}-${{ hashFiles('pyproject.toml', 'requirements*.txt', 'constraints*.txt') }}
          restore-keys: |
            hf-${{ runner.os }}-
      - name: Setup Tunix, tpu-inference and dependencies
        run: |
          echo "Current directory:"
          pwd
          pip install --upgrade pip setuptools wheel
          # Install Tunix with dev and test dependencies without overwriting the vLLM dependencies.
          pip install -e ".[dev,test]"
          pip install transformers==4.57.1 --force-reinstall # Issue: https://github.com/google/tunix/pull/795
          # tpu-inference/Numba needs NumPy 2.3 or less.
          pip install numpy==2.3.5 --force-reinstall
      - name: Run tunix model tests
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        run: |
          python3 -m pytest tests/models/naming_test.py -v --tb=short
      - name: GRPO Integration Test
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        run: |
          # Run GRPO demo script with minimal configuration
          python3 scripts/grpo_demo_llama3_qwen2.py \
            --root-dir=/tmp/grpo_test \
            --num-batches=2 \
            --num-test-batches=1 \
            --global-batch-size=2 \
            --train-mini-batch-size=2 \
            --train-micro-batch-size=2 \
            --rollout-engine=vanilla
      - name: Run tunix SFT integration tests
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        run: |
          # Reinstall Tunix with prod dependencies. All steps of this job share
          # one container, so every later step also sees these packages.
          pip install -e ".[prod]" --force-reinstall
          # Loading tfds requires tensorflow.
          pip install tensorflow
          export JAX_PLATFORMS=tpu,cpu
          ./tests/sft/sft_tpu_smoke_test.sh
      - name: Run Smoke tests
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
          KAGGLE_USERNAME: ${{ secrets.KAGGLE_USERNAME }}
          KAGGLE_KEY: ${{ secrets.KAGGLE_KEY }}
        run: |
          echo "Running Smoke tests..."
          python -m pytest tests/smoke_tests/model_creation_test.py -v --tb=short
      - name: Run tunix cli tests
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
          KAGGLE_USERNAME: ${{ secrets.KAGGLE_USERNAME }}
          KAGGLE_KEY: ${{ secrets.KAGGLE_KEY }}
        run: |
          # Config tests that passed
          python -m pytest tests/cli/ -v --tb=short \
            --ignore=tests/cli/utils/model_test.py
      - name: Run model alignment tests
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        run: |
          python -m pip install torch
          # The JAX_PLATFORMS=cpu prefix applies to this one command, and each
          # step runs in a fresh shell, so nothing needs to be unset afterwards.
          JAX_PLATFORMS=cpu python -m pytest tests/model_alignment/ -v --tb=short

  # vLLM sampler/driver tests on the vLLM TPU nightly image.
  run_dev_vllm:
    if: ${{ github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository }}
    runs-on: [linux-x86-ct5lp-224-8tpu]
    environment: testing
    container:
      image: vllm/vllm-tpu:nightly-20260406-581c4d4-f6983f0
      options: --privileged
    env:
      CLOUD_TPU_ACCELERATOR: v5e-8
      JAX_PLATFORMS: tpu,cpu
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
        with:
          fetch-depth: 0
      # Cache Hugging Face hub (after checkout so hashFiles() sees this commit).
      - name: Cache HF hub
        uses: actions/cache@v4
        with:
          path: ~/.cache/huggingface
          key: hf-${{ runner.os }}-${{ hashFiles('pyproject.toml', 'requirements*.txt', 'constraints*.txt') }}
          restore-keys: |
            hf-${{ runner.os }}-
      - name: Setup Tunix, tpu-inference and dependencies
        run: |
          echo "Current directory:"
          pwd
          pip install --upgrade pip setuptools wheel
          # Install Tunix with dev and test dependencies without overwriting the vLLM dependencies.
          pip install -e ".[dev,test]"
          pip install transformers==4.57.1 --force-reinstall # Issue: https://github.com/google/tunix/pull/795
          # tpu-inference/Numba needs NumPy 2.3 or less.
          pip install numpy==2.3.5 --force-reinstall
      - name: Run vllm tests
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        run: |
          unset JAX_PLATFORMS
          pytest tests/generate/vllm_sampler_qwen_test.py -v --tb=short
          pytest tests/generate/vllm_driver_test.py -v --tb=short
          # Run every vllm_sampler_test case in its own pytest process.
          pytest tests/generate/vllm_sampler_test.py --collect-only -q --no-header --no-summary --disable-warnings | grep '::' > test_collections.txt
          while read -r test; do
            # </dev/null: the loop's stdin is test_collections.txt and -s stops
            # pytest from replacing stdin, so without this a test (or a process
            # it spawns) could consume the remaining IDs and skip them silently.
            pytest -s "$test" -v --tb=short < /dev/null
          done < test_collections.txt

  # sglang-jax sampler/LoRA tests, installed from source on the vLLM TPU image.
  run_dev_sglang:
    if: ${{ github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository }}
    runs-on: [linux-x86-ct5lp-224-8tpu]
    environment: testing
    container:
      image: vllm/vllm-tpu:nightly-20260406-581c4d4-f6983f0
      options: --privileged
    env:
      CLOUD_TPU_ACCELERATOR: v5e-8
      JAX_PLATFORMS: tpu,cpu
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
        with:
          fetch-depth: 0
      # Cache Hugging Face hub (after checkout so hashFiles() sees this commit).
      - name: Cache HF hub
        uses: actions/cache@v4
        with:
          path: ~/.cache/huggingface
          key: hf-${{ runner.os }}-${{ hashFiles('pyproject.toml', 'requirements*.txt', 'constraints*.txt') }}
          restore-keys: |
            hf-${{ runner.os }}-
      - name: Setup Tunix, tpu-inference and dependencies
        run: |
          echo "Current directory:"
          pwd
          pip install --upgrade pip setuptools wheel
          # Install Tunix with dev and test dependencies without overwriting the vLLM dependencies.
          pip install -e ".[dev,test]"
          pip install transformers==4.57.1 --force-reinstall # Issue: https://github.com/google/tunix/pull/795
          # tpu-inference/Numba needs NumPy 2.3 or less.
          pip install numpy==2.3.5 --force-reinstall
      - name: Run install sglang-jax && test
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        run: |
          # Unset because sglang-jax has code like jax.local_devices('cpu').
          # TODO(lancewang): Re-enable this test once the bug is fixed.
          unset JAX_PLATFORMS
          pip list | grep -E 'jax|flax|libtpu'
          # Clone sglang-jax next to the tunix checkout, not inside it.
          cd ..
          # NOTE(review): on a reused runner an earlier clone may still exist
          # here; remove it so `git clone` does not fail on an existing path.
          rm -rf sglang-jax
          git clone https://github.com/sgl-project/sglang-jax.git && cd sglang-jax/python && pip install -e . --force-reinstall && cd ../..
          # SGLang Jax removed qwix dependency in https://github.com/sgl-project/sglang-jax/pull/734
          pip install qwix --force-reinstall
          # TODO(b/470113586): Remove this test once the bug is fixed.
          pip install "jax[tpu]==0.8.1" flax==0.12.4 --force-reinstall
          pip list | grep -E 'jax|flax|libtpu'
          # Switch apt sources to Debian bookworm: the vLLM container defaults to
          # bullseye, which causes a segfault for sglang-jax.
          cat >/etc/apt/sources.list <<'EOF'
          deb http://deb.debian.org/debian bookworm main contrib non-free
          deb http://deb.debian.org/debian bookworm-updates main contrib non-free
          deb http://security.debian.org/debian-security bookworm-security main contrib non-free
          EOF
          apt-get update; apt-get install -y less
          # Return to the checkout via $GITHUB_WORKSPACE instead of assuming the
          # repository directory is named "tunix".
          cd "$GITHUB_WORKSPACE" && python tests/generate/sglang_jax_sampler_test.py && python tests/generate/sglang_jax_lora_test.py