Skip to content

Commit 9fc6ee7

Browse files
sizhit2The tunix Authors
authored andcommitted
Add RL smoke tests to Tunix TPU CI pipeline
PiperOrigin-RevId: 915142830
1 parent e67d0ef commit 9fc6ee7

3 files changed

Lines changed: 41 additions & 1 deletion

File tree

.github/workflows/tpu-tests.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,12 @@ jobs:
231231
232232
export JAX_PLATFORMS=tpu,cpu
233233
./tests/sft/sft_tpu_smoke_test.sh
234+
- name: Run tunix RL smoke tests
235+
env:
236+
HF_TOKEN: ${{ secrets.HF_TOKEN }}
237+
run: |
238+
export JAX_PLATFORMS=tpu,cpu
239+
./tests/rl/rl_tpu_smoke_test.sh
234240
- name: Run Smoke tests
235241
env:
236242
HF_TOKEN: ${{ secrets.HF_TOKEN }}

examples/rl/grpo/gsm8k/run_qwen3_simplereward.sh

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,4 +97,5 @@ python3 -m tunix.cli.grpo_main \
9797
grpo_config.num_iterations=1 \
9898
grpo_config.beta=0.08 \
9999
grpo_config.epsilon=0.2 \
100-
reward_functions="['tunix/cli/reward_fn/simple_math.py']"
100+
reward_functions="['tunix/cli/reward_fn/simple_math.py']" \
101+
"$@"

tests/rl/rl_tpu_smoke_test.sh

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
#!/bin/bash
2+
# Copyright 2026 Google LLC
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# https://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
set -ex
17+
18+
echo "=== Running Non-Agentic RL Smoke Test (Vanilla Rollout) ==="
19+
model_name="Qwen2.5-0.5B" \
20+
num_batches=1 \
21+
bash examples/rl/grpo/gsm8k/run_qwen3_simplereward.sh \
22+
rollout_config.total_generation_steps=64
23+
24+
echo "=== Running Agentic RL Smoke Test (vLLM Rollout) ==="
25+
model_name="Qwen2.5-0.5B" \
26+
model_id="Qwen/Qwen2.5-0.5B" \
27+
num_batches=1 \
28+
total_tpus=8 \
29+
train_mesh="(4,1)" \
30+
rollout_mesh="(1,4)" \
31+
checkpoint_dir="/tmp/rl_smoke_agentic_ckpts" \
32+
bash examples/rl/grpo/gsm8k/run_qwen3_8b_disagg.sh \
33+
rollout_config.total_generation_steps=64

0 commit comments

Comments
 (0)