2222 secrets :
2323 HF_TOKEN :
2424 description : ' HuggingFace token for model downloads'
25+ KAGGLE_USERNAME :
26+ description : ' Kaggle Username'
27+ KAGGLE_KEY :
28+ description : ' Kaggle API Key'
2529
2630concurrency :
2731 # Dedup pull requests (canceling previous runs of the same workflow for same PR), and scheduled runs but nothing else
3337 HF_HUB_ENABLE_HF_TRANSFER : " 1"
3438
3539jobs :
36- run_prod :
37- runs-on : [linux-x86-ct5lp-224-8tpu]
38- environment : testing
39- container :
40- image : us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/tpu:latest
41- options : --privileged
42- env :
43- CLOUD_TPU_ACCELERATOR : v5e-8
44- JAX_PLATFORMS : tpu
45- steps :
46-
47- # Cache Hugging Face hub
48- - name : Cache HF hub
49- uses : actions/cache@v4
50- with :
51- path : ~/.cache/huggingface
52- key : hf-${{ runner.os }}-${{ hashFiles('pyproject.toml', 'requirements*.txt', 'constraints*.txt') }}
53- restore-keys : |
54- hf-${{ runner.os }}-
55-
56- - name : Checkout code
57- uses : actions/checkout@v4
58- with :
59- fetch-depth : 0
60-
61- - name : Install tunix dependencies
62- run : |
63- pip install --upgrade pip
64- pip install -e .[prod] --force-reinstall
65- pip install pytest pytest-xdist
66-
67- - name : Verify TPU availability
68- run : |
69- python -c "
70- import jax
71- print(f'JAX version: {jax.__version__}')
72- print(f'JAX devices: {jax.devices()}')
73-
74- # Check if we have TPU devices specifically
75- devices = jax.devices()
76- has_tpu = len(devices) > 0 and all(device.platform == 'tpu' for device in devices)
77- print(f'TPU available: {has_tpu}')
78-
79- if not has_tpu:
80- print('ERROR: No TPU devices found! Expected TPU devices but got:', [device.platform for device in devices])
81- exit(1)
82- else:
83- print(f'SUCCESS: Found {len(devices)} TPU device(s)')
84- "
85-
86- - name : Run tunix model tests
87- run : |
88- python -m pytest tests/models/ -v --tb=short -m "not cpu_only and not gpu_only"
89-
90- - name : Run tunix generation tests (PASSED only)
91- run : |
92- # tokenizer_adapter_test requires access to gated repo
93- python -m pytest tests/generate/ -v --tb=short \
94- --ignore=tests/generate/vllm_sampler_test.py \
95- --ignore=tests/generate/vllm_driver_test.py \
96- --ignore=tests/generate/tokenizer_adapter_test.py \
97- --ignore=tests/generate/sglang_jax_sampler_test.py
98-
99- - name : Run tunix SFT tests
100- run : |
101- python -m pytest tests/sft/ -v --tb=short
102-
103- - name : Run tunix distillation tests
104- run : |
105- python -m pytest tests/distillation/ -v --tb=short
106-
107- - name : Run tunix RL tests
108- run : |
109- # RL common tests that passed
110- # b/448133814: test_grpo_with_lora_model fails
111- python -m pytest tests/rl/ -v --tb=short -k "not test_grpo_with_lora_model" --ignore=tests/rl/experimental/agentic
112-
113- - name : Run tunix tests not covered by the above categories
114- run : |
115- # This category is to catch tests added but not covered by CI yet. Whenever you add new folders under tests/, please add a new category above and skip those tests here.
116- python -m pytest tests/ -v --tb=short --ignore=tests/perf/ --ignore=tests/model_alignment/ --ignore=tests/models/ --ignore=tests/cli/ --ignore=tests/generate/ --ignore=tests/sft/ --ignore=tests/distillation/ --ignore=tests/rl/ || code=$?
117- if [ "${code:-0}" = "5" ]; then
118- echo "No tests collected (expected)."
119- exit 0
120- else
121- exit "${code:-0}"
122- fi
123-
40+ # run_prod:
41+ # runs-on: [linux-x86-ct5lp-224-8tpu]
42+ # environment: testing
43+ # container:
44+ # image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/tpu:latest
45+ # options: --privileged
46+ # env:
47+ # CLOUD_TPU_ACCELERATOR: v5e-8
48+ # JAX_PLATFORMS: tpu
49+ # steps:
50+
51+ # # Cache Hugging Face hub
52+ # - name: Cache HF hub
53+ # uses: actions/cache@v4
54+ # with:
55+ # path: ~/.cache/huggingface
56+ # key: hf-${{ runner.os }}-${{ hashFiles('pyproject.toml', 'requirements*.txt', 'constraints*.txt') }}
57+ # restore-keys: |
58+ # hf-${{ runner.os }}-
59+
60+ # - name: Checkout code
61+ # uses: actions/checkout@v4
62+ # with:
63+ # fetch-depth: 0
64+
65+ # - name: Install tunix dependencies
66+ # run: |
67+ # pip install --upgrade pip
68+ # pip install -e .[prod] --force-reinstall
69+ # pip install pytest pytest-xdist
70+
71+ # - name: Verify TPU availability
72+ # run: |
73+ # python -c "
74+ # import jax
75+ # print(f'JAX version: {jax.__version__}')
76+ # print(f'JAX devices: {jax.devices()}')
77+
78+ # # Check if we have TPU devices specifically
79+ # devices = jax.devices()
80+ # has_tpu = len(devices) > 0 and all(device.platform == 'tpu' for device in devices)
81+ # print(f'TPU available: {has_tpu}')
82+
83+ # if not has_tpu:
84+ # print('ERROR: No TPU devices found! Expected TPU devices but got:', [device.platform for device in devices])
85+ # exit(1)
86+ # else:
87+ # print(f'SUCCESS: Found {len(devices)} TPU device(s)')
88+ # "
89+
90+ # - name: Run tunix model tests
91+ # run: |
92+ # python -m pytest tests/models/ -v --tb=short -m "not cpu_only and not gpu_only"
93+
94+ # - name: Run tunix generation tests (PASSED only)
95+ # run: |
96+ # # tokenizer_adapter_test requires access to gated repo
97+ # python -m pytest tests/generate/ -v --tb=short \
98+ # --ignore=tests/generate/vllm_sampler_test.py \
99+ # --ignore=tests/generate/vllm_driver_test.py \
100+ # --ignore=tests/generate/tokenizer_adapter_test.py \
101+ # --ignore=tests/generate/sglang_jax_sampler_test.py
102+
103+ # - name: Run tunix SFT tests
104+ # run: |
105+ # python -m pytest tests/sft/ -v --tb=short
106+
107+ # - name: Run tunix distillation tests
108+ # run: |
109+ # python -m pytest tests/distillation/ -v --tb=short
110+
111+ # - name: Run tunix RL tests
112+ # run: |
113+ # # RL common tests that passed
114+ # # b/448133814: test_grpo_with_lora_model fails
115+ # python -m pytest tests/rl/ -v --tb=short -k "not test_grpo_with_lora_model" --ignore=tests/rl/experimental/agentic
116+
117+ # - name: Run tunix tests not covered by the above categories
118+ # run: |
119+ # # This category is to catch tests added but not covered by CI yet. Whenever you add new folders under tests/, please add a new category above and skip those tests here.
120+ # python -m pytest tests/ -v --tb=short --ignore=tests/perf/ --ignore=tests/model_alignment/ --ignore=tests/models/ --ignore=tests/cli/ --ignore=tests/generate/ --ignore=tests/sft/ --ignore=tests/distillation/ --ignore=tests/rl/ --ignore=tests/smoke_tests / || code=$?
121+ # if [ "${code:-0}" = "5" ]; then
122+ # echo "No tests collected (expected)."
123+ # exit 0
124+ # else
125+ # exit "${code:-0}"
126+ # fi
127+ #
124128 run_dev :
125129 if : ${{ github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository }}
126130 runs-on : [linux-x86-ct5lp-224-8tpu]
@@ -131,6 +135,10 @@ jobs:
131135 env :
132136 CLOUD_TPU_ACCELERATOR : v5e-8
133137 JAX_PLATFORMS : tpu,cpu
138+ secrets :
139+ HF_TOKEN : ${{ secrets.HF_TOKEN }}
140+ KAGGLE_USERNAME : ${{ secrets.KAGGLE_USERNAME }}
141+ KAGGLE_KEY : ${{ secrets.KAGGLE_KEY }}
134142 steps :
135143 # Cache Hugging Face hub
136144 - name : Cache HF hub
@@ -156,84 +164,98 @@ jobs:
156164 pip install -e .[dev]
157165 pip install transformers==4.57.1 --force-reinstall # Issue: https://github.com/google/tunix/pull/795
158166
159- - name : GRPO Integration Test
160- env :
161- HF_TOKEN : ${{ secrets.HF_TOKEN }}
162- run : |
163-
164- # Download GSM8K dataset
165- mkdir -p /tmp/grpo_test/rl/grpo/data
166- python3 -c "
167- from datasets import load_dataset
168- import json
169-
170- # Download and save GSM8K train split
171- dataset = load_dataset('openai/gsm8k', 'main', split='train')
172- train_data = [{'question': item['question'], 'answer': item['answer']} for item in dataset]
173- with open('/tmp/grpo_test/rl/grpo/data/gsm8k_train.json', 'w') as f:
174- json.dump(train_data, f)
175-
176- # Download and save GSM8K test split
177- dataset = load_dataset('openai/gsm8k', 'main', split='test')
178- test_data = [{'question': item['question'], 'answer': item['answer']} for item in dataset]
179- with open('/tmp/grpo_test/rl/grpo/data/gsm8k_test.json', 'w') as f:
180- json.dump(test_data, f)
181-
182- print('GSM8K dataset downloaded successfully')
183- "
184-
185- # TODO(lancewang): Re-enable this test once the segfault is fixed.
186- # Run GRPO demo script with minimal configuration
187- # python3 scripts/grpo_demo_llama3_qwen2.py \
188- # --root-dir=/tmp/grpo_test \
189- # --model-version=Qwen/Qwen2.5-0.5B-Instruct \
190- # --num-batches=1 \
191- # --num-test-batches=1 \
192- # --rollout-engine=vanilla
193- - name : Run vllm tests
194- env :
195- HF_TOKEN : ${{ secrets.HF_TOKEN }}
196- run : |
197- unset JAX_PLATFORMS
198- pytest tests/generate/vllm_driver_test.py -v --tb=short
199- pytest tests/generate/vllm_sampler_test.py --collect-only -q --no-header --no-summary --disable-warnings | grep '::' > test_collections.txt
200- while read -r test; do
201- pytest -s "$test" -v --tb=short
202- done < test_collections.txt
203-
204- - name : Run install sglang-jax && test
205- env :
206- HF_TOKEN : ${{ secrets.HF_TOKEN }}
207- run : |
208- ## because sglang-jax has codes like jax.local_devices('cpu')
209- # TODO(lancewang): Re-enable this test once the bug is fixed.
210- unset JAX_PLATFORMS
211- pip list | egrep 'jax|flax|libtpu'
212- cd ..
213- git clone https://github.com/sgl-project/sglang-jax.git && cd sglang-jax/python && pip install -e . --force-reinstall && cd ../..
214- pip list | egrep 'jax|flax|libtpu'
215-
216- # Install bookworm, vllm container defaults to bullseye causes segfault for sglang-jax.
217- cat >/etc/apt/sources.list <<'EOF'
218- deb http://deb.debian.org/debian bookworm main contrib non-free
219- deb http://deb.debian.org/debian bookworm-updates main contrib non-free
220- deb http://security.debian.org/debian-security bookworm-security main contrib non-free
221- EOF
222- apt-get update; apt-get install -y less
223-
224- cd tunix && python tests/generate/sglang_jax_sampler_test.py
225- - name : Run tunix SFT integration tests
167+ # - name: GRPO Integration Test
168+ # env:
169+ # HF_TOKEN: ${{ secrets.HF_TOKEN }}
170+ # run: |
171+
172+ # # Download GSM8K dataset
173+ # mkdir -p /tmp/grpo_test/rl/grpo/data
174+ # python3 -c "
175+ # from datasets import load_dataset
176+ # import json
177+
178+ # # Download and save GSM8K train split
179+ # dataset = load_dataset('openai/gsm8k', 'main', split='train')
180+ # train_data = [{'question': item['question'], 'answer': item['answer']} for item in dataset]
181+ # with open('/tmp/grpo_test/rl/grpo/data/gsm8k_train.json', 'w') as f:
182+ # json.dump(train_data, f)
183+
184+ # # Download and save GSM8K test split
185+ # dataset = load_dataset('openai/gsm8k', 'main', split='test')
186+ # test_data = [{'question': item['question'], 'answer': item['answer']} for item in dataset]
187+ # with open('/tmp/grpo_test/rl/grpo/data/gsm8k_test.json', 'w') as f:
188+ # json.dump(test_data, f)
189+
190+ # print('GSM8K dataset downloaded successfully')
191+ # "
192+
193+ # # TODO(lancewang): Re-enable this test once the segfault is fixed.
194+ # # Run GRPO demo script with minimal configuration
195+ # # python3 scripts/grpo_demo_llama3_qwen2.py \
196+ # # --root-dir=/tmp/grpo_test \
197+ # # --model-version=Qwen/Qwen2.5-0.5B-Instruct \
198+ # # --num-batches=1 \
199+ # # --num-test-batches=1 \
200+ # # --rollout-engine=vanilla
201+ # - name: Run vllm tests
202+ # env:
203+ # HF_TOKEN: ${{ secrets.HF_TOKEN }}
204+ # run: |
205+ # unset JAX_PLATFORMS
206+ # pytest tests/generate/vllm_driver_test.py -v --tb=short
207+ # pytest tests/generate/vllm_sampler_test.py --collect-only -q --no-header --no-summary --disable-warnings | grep '::' > test_collections.txt
208+ # while read -r test; do
209+ # pytest -s "$test" -v --tb=short
210+ # done < test_collections.txt
211+
212+ # - name: Run install sglang-jax && test
213+ # env:
214+ # HF_TOKEN: ${{ secrets.HF_TOKEN }}
215+ # run: |
216+ # ## because sglang-jax has codes like jax.local_devices('cpu')
217+ # # TODO(lancewang): Re-enable this test once the bug is fixed.
218+ # unset JAX_PLATFORMS
219+ # pip list | egrep 'jax|flax|libtpu'
220+ # cd ..
221+ # git clone https://github.com/sgl-project/sglang-jax.git && cd sglang-jax/python && pip install -e . --force-reinstall && cd ../..
222+ # pip list | egrep 'jax|flax|libtpu'
223+
224+ # # Install bookworm, vllm container defaults to bullseye causes segfault for sglang-jax.
225+ # cat >/etc/apt/sources.list <<'EOF'
226+ # deb http://deb.debian.org/debian bookworm main contrib non-free
227+ # deb http://deb.debian.org/debian bookworm-updates main contrib non-free
228+ # deb http://security.debian.org/debian-security bookworm-security main contrib non-free
229+ # EOF
230+ # apt-get update; apt-get install -y less
231+
232+ # cd tunix && python tests/generate/sglang_jax_sampler_test.py
233+ # - name: Run tunix SFT integration tests
234+ # env:
235+ # HF_TOKEN: ${{ secrets.HF_TOKEN }}
236+ # run: |
237+ # # Reinstall Tunix with prod dependencies
238+ # pip install -e .[prod] --force-reinstall
239+
240+ # # Loading tfds requires tensorflow.
241+ # pip install tensorflow
242+
243+ # export JAX_PLATFORMS=tpu,cpu
244+ # ./tests/sft/sft_tpu_smoke_test.sh
245+ - name : Run Smoke tests
226246 env :
227247 HF_TOKEN : ${{ secrets.HF_TOKEN }}
248+ KAGGLE_USERNAME : ${{ secrets.KAGGLE_USERNAME }}
249+ KAGGLE_KEY : ${{ secrets.KAGGLE_KEY }}
228250 run : |
229- # Reinstall Tunix with prod dependencies
230- pip install -e .[prod] --force-reinstall
231-
232- # Loading tfds requires tensorflow.
233- pip install tensorflow
234-
235- export JAX_PLATFORMS=tpu,cpu
236- ./ tests/sft/sft_tpu_smoke_test.sh
251+ echo "Running Smoke tests..."
252+ # Debugging: Check if env vars are set (don't print values)
253+ if [ -n "$KAGGLE_USERNAME" ]; then echo "KAGGLE_USERNAME is set"; else echo "KAGGLE_USERNAME is NOT set"; fi
254+ if [ -n "$KAGGLE_KEY" ]; then echo "KAGGLE_KEY is set"; else echo "KAGGLE_KEY is NOT set"; fi
255+ echo HF_TOKEN: ${HF_TOKEN}
256+ echo KAGGLE_USERNAME: ${KAGGLE_USERNAME}
257+ echo KAGGLE_KEY: ${KAGGLE_KEY}
258+ python -m pytest tests/smoke_tests/model_creation_test.py -v --tb=short
237259 - name : Run tunix cli tests
238260 env :
239261 HF_TOKEN : ${{ secrets.HF_TOKEN }}
0 commit comments