Skip to content

Commit 25d4cac

Browse files
s-noghabi authored and The tunix Authors committed
update qwen3 scripts to use yaml configs
PiperOrigin-RevId: 915129002
1 parent 45e1b4a commit 25d4cac

7 files changed

Lines changed: 216 additions & 303 deletions

File tree

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
model_config:
16+
model_name: "Qwen3-1.7B-base"
17+
model_id: "Qwen/Qwen3-1.7B-base"
18+
model_source: "huggingface"
19+
use_flash_attention: true
20+
flash_attention_block_size: 256
21+
mesh:
22+
shape: "(2,4)"
23+
axis_names: "('fsdp','tp')"
24+
rng_seed: 42
25+
26+
actor_model_config:
27+
lora_config:
28+
rank: 64
29+
alpha: 64.0
30+
module_path: ".*q_proj|.*k_proj|.*v_proj|.*o_proj|.*gate_proj|.*down_proj|.*up_proj"
31+
mesh:
32+
shape: "(2,4)"
33+
axis_names: "('fsdp','tp')"
34+
35+
reference_model_config:
36+
mesh: null
37+
same_mesh_as: "actor"
38+
39+
rollout_model_config:
40+
mesh: null
41+
same_mesh_as: "actor"
42+
43+
tokenizer_config:
44+
tokenizer_type: "huggingface"
45+
add_bos: false
46+
47+
dataset_name: "gsm8k"
48+
batch_size: 8
49+
num_test_batches: 100
50+
num_train_epochs: 1
51+
52+
rl_training_config:
53+
actor_optimizer_config:
54+
opt_type: "adamw"
55+
peak_value: 3e-6
56+
schedule_type: "warmup_cosine_decay_schedule"
57+
init_value: 0.0
58+
end_value: 0.0
59+
warmup_ratio: 0.1
60+
b1: 0.9
61+
b2: 0.99
62+
weight_decay: 0.1
63+
max_grad_norm: 0.1
64+
eval_every_n_steps: 10
65+
metrics_logging_options:
66+
flush_every_n_steps: 20
67+
checkpointing_options:
68+
save_interval_steps: 500
69+
max_to_keep: 4
70+
profiler_options: {}
71+
72+
rollout_config:
73+
total_generation_steps: 768
74+
max_prompt_length: 256
75+
temperature: 0.9
76+
top_p: 1.0
77+
top_k: 50
78+
79+
rollout_engine: "vanilla"
80+
offload_to_cpu: false
81+
82+
grpo_config:
83+
num_generations: 4
84+
num_iterations: 1
85+
beta: 0.08
86+
epsilon: 0.2
87+
88+
reward_functions:
89+
- "tunix/cli/reward_fn/gsm8k.py"
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
model_config:
16+
rng_seed: 42
17+
model_display: false
18+
remat_config: 3
19+
20+
actor_model_config:
21+
mesh:
22+
shape: "(8,1)"
23+
axis_names: "('fsdp','tp')"
24+
25+
rollout_model_config:
26+
mesh:
27+
shape: "(1,8)"
28+
axis_names: "('fsdp','tp')"
29+
30+
reference_model_config:
31+
mesh: null
32+
same_mesh_as: "actor"
33+
34+
data_source: "huggingface"
35+
dataset_name: "openai/gsm8k:main"
36+
prompt_key: "question"
37+
38+
training_mode: "agentic_grpo"
39+
num_test_batches: 100
40+
reward_functions:
41+
- "tunix/cli/reward_fn/gsm8k.py"
42+
verl_compatible: false
43+
44+
rollout_engine: "vllm"
45+
offload_to_cpu: false
46+
47+
rollout_config:
48+
max_prompt_length: 256
49+
total_generation_steps: 768
50+
max_tokens_to_generate: 768
51+
temperature: 0.9
52+
top_p: 1.0
53+
top_k: 50
54+
return_logprobs: true
55+
56+
vllm_config:
57+
hbm_utilization: 0.4
58+
tpu_backend_type: "jax"
59+
server_mode: true
60+
async_scheduling: true
61+
kwargs:
62+
kv_cache_metrics: true
63+
disable_log_stats: false
64+
enable_prefix_caching: true
65+
66+
chat_parser_config:
67+
type: "qwen"
68+
69+
tokenizer_config:
70+
tokenizer_type: "huggingface"
71+
add_bos: false
72+
add_eos: false
73+
74+
agentic_grpo_config:
75+
num_iterations: 1
76+
beta: 0.08
77+
epsilon: 0.2
78+
system_prompt: "You are given a grade school math problem. Think step by step and respond using <reasoning>...</reasoning> followed by <answer>...</answer> with only the final numeric answer inside <answer>."
79+
max_concurrency: 128
80+
max_response_length: 768
81+
max_turns: 1
82+
context_ratio: 1
83+
84+
rl_training_config:
85+
actor_optimizer_config:
86+
opt_type: "adamw"
87+
learning_rate: 3e-6
88+
schedule_type: "warmup_cosine_decay_schedule"
89+
init_value: 0.0
90+
peak_value: 3e-6
91+
end_value: 0.0
92+
b1: 0.9
93+
b2: 0.99
94+
weight_decay: 0.1
95+
max_grad_norm: 0.1
96+
eval_every_n_steps: 10
97+
checkpointing_options:
98+
save_interval_steps: 250
99+
max_to_keep: 4
100+
metrics_logging_options:
101+
flush_every_n_steps: 20

examples/rl/grpo/gsm8k/run_qwen3.sh

Lines changed: 6 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -31,60 +31,16 @@ echo " Train Fraction: $train_fraction"
3131
echo " Checkpoint Directory: $checkpoint_dir"
3232

3333
python3 -m tunix.cli.grpo_main \
34-
base_config.yaml \
35-
model_config.model_name=${model_name} \
36-
model_config.model_id=Qwen/${model_name} \
37-
model_config.model_source=huggingface \
38-
model_config.use_flash_attention=true \
39-
model_config.flash_attention_block_size=256 \
34+
tunix/cli/base_config.yaml \
35+
override_config_file=examples/rl/grpo/gsm8k/configs/qwen3.yaml \
36+
model_config.model_name="${model_name}" \
37+
model_config.model_id="Qwen/${model_name}" \
4038
model_config.intermediate_ckpt_dir="/tmp/intermediate_ckpt/${model_name}" \
4139
model_config.model_download_path="/tmp/models/${model_name}" \
42-
model_config.mesh.shape="(2,4)" \
43-
model_config.mesh.axis_names="('fsdp','tp')" \
44-
model_config.rng_seed=42 \
45-
actor_model_config.lora_config.rank=64 \
46-
actor_model_config.lora_config.alpha=64.0 \
47-
actor_model_config.lora_config.module_path=".*q_proj|.*k_proj|.*v_proj|.*o_proj|.*gate_proj|.*down_proj|.*up_proj" \
48-
actor_model_config.mesh.shape="(2,4)" \
49-
actor_model_config.mesh.axis_names="('fsdp','tp')" \
50-
reference_model_config.mesh=null \
51-
reference_model_config.same_mesh_as="actor" \
52-
rollout_model_config.mesh=null \
53-
rollout_model_config.same_mesh_as="actor" \
54-
tokenizer_config.tokenizer_path=Qwen/${model_name} \
55-
tokenizer_config.tokenizer_type=huggingface \
56-
tokenizer_config.add_bos=false \
57-
dataset_name="gsm8k" \
40+
tokenizer_config.tokenizer_path="Qwen/${model_name}" \
5841
batch_size=$batch_size \
59-
num_test_batches=100 \
6042
num_train_epochs=$num_train_epochs \
6143
train_fraction=$train_fraction \
62-
rl_training_config.actor_optimizer_config.opt_type="adamw" \
63-
rl_training_config.actor_optimizer_config.peak_value=3e-6 \
64-
rl_training_config.actor_optimizer_config.schedule_type="warmup_cosine_decay_schedule" \
65-
rl_training_config.actor_optimizer_config.init_value=0.0 \
66-
rl_training_config.actor_optimizer_config.end_value=0.0 \
6744
rl_training_config.actor_optimizer_config.warmup_ratio=$warmup_ratio \
68-
rl_training_config.actor_optimizer_config.b1=0.9 \
69-
rl_training_config.actor_optimizer_config.b2=0.99 \
70-
rl_training_config.actor_optimizer_config.weight_decay=0.1 \
71-
rl_training_config.actor_optimizer_config.max_grad_norm=0.1 \
72-
rl_training_config.eval_every_n_steps=10 \
7345
rl_training_config.metrics_logging_options.log_dir="/tmp/tensorboard/${model_name}" \
74-
rl_training_config.metrics_logging_options.flush_every_n_steps=20 \
75-
rl_training_config.checkpoint_root_directory="$checkpoint_dir" \
76-
rl_training_config.checkpointing_options.save_interval_steps=500 \
77-
rl_training_config.checkpointing_options.max_to_keep=4 \
78-
rl_training_config.profiler_options={} \
79-
rollout_config.total_generation_steps=768 \
80-
rollout_config.max_prompt_length=256 \
81-
rollout_config.temperature=0.9 \
82-
rollout_config.top_p=1.0 \
83-
rollout_config.top_k=50 \
84-
rollout_engine="vanilla" \
85-
offload_to_cpu=false \
86-
grpo_config.num_generations=4 \
87-
grpo_config.num_iterations=1 \
88-
grpo_config.beta=0.08 \
89-
grpo_config.epsilon=0.2 \
90-
reward_functions="['tunix/cli/reward_fn/gsm8k.py']"
46+
rl_training_config.checkpoint_root_directory="$checkpoint_dir"

examples/rl/grpo/gsm8k/run_qwen3_8b_disagg.sh

Lines changed: 2 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -72,104 +72,28 @@ vllm_max_num_seqs=$(awk "BEGIN {
7272

7373
python -m tunix.cli.grpo_main \
7474
tunix/cli/base_agentic_config.yaml \
75-
\
76-
`# -- Model ------------------------------------------------------------` \
75+
override_config_file=examples/rl/grpo/gsm8k/configs/qwen3_disagg.yaml \
7776
model_config.model_name="$model_name" \
7877
model_config.model_id="$model_id" \
7978
model_config.model_source="huggingface" \
8079
model_config.model_download_path="/tmp/models/${model_name}" \
81-
model_config.rng_seed=42 \
82-
model_config.model_display=false \
83-
model_config.remat_config=3 \
80+
tokenizer_config.tokenizer_path="$tokenizer_path" \
8481
actor_model_config.mesh.shape="$train_mesh" \
85-
actor_model_config.mesh.axis_names="('fsdp','tp')" \
86-
reference_model_config.mesh=null \
87-
reference_model_config.same_mesh_as="actor" \
8882
rollout_model_config.mesh.shape="$rollout_mesh" \
89-
rollout_model_config.mesh.axis_names="('fsdp','tp')" \
90-
\
91-
`# -- Data -------------------------------------------------------------` \
92-
data_source="huggingface" \
93-
dataset_name="openai/gsm8k:main" \
94-
prompt_key="question" \
95-
\
96-
`# -- Training loop ----------------------------------------------------` \
97-
training_mode="agentic_grpo" \
9883
batch_size="$batch_size" \
9984
num_batches="$num_batches" \
100-
num_test_batches=100 \
10185
num_train_epochs="$num_train_epochs" \
10286
train_fraction="$train_fraction" \
103-
reward_functions=["tunix/cli/reward_fn/gsm8k.py"] \
104-
verl_compatible=false \
105-
\
106-
`# -- Rollout engine (vanilla | vllm | sglang_jax) ---------------------` \
107-
rollout_engine="vllm" \
108-
offload_to_cpu=false \
109-
\
110-
`# -- Rollout config ---------------------------------------------------` \
111-
rollout_config.max_prompt_length=256 \
112-
rollout_config.total_generation_steps=768 \
113-
rollout_config.max_tokens_to_generate=768 \
114-
rollout_config.temperature=0.9 \
115-
rollout_config.top_p=1.0 \
116-
rollout_config.top_k=50 \
117-
rollout_config.return_logprobs=true \
118-
\
119-
`# -- vLLM (used when rollout_engine=vllm) -----------------------------` \
120-
vllm_config.hbm_utilization=0.4 \
121-
vllm_config.tpu_backend_type="jax" \
122-
vllm_config.server_mode=true \
123-
vllm_config.async_scheduling=true \
12487
vllm_config.max_num_seqs="$vllm_max_num_seqs" \
125-
vllm_config.kwargs.kv_cache_metrics=true \
126-
vllm_config.kwargs.disable_log_stats=false \
127-
vllm_config.kwargs.enable_prefix_caching=true \
128-
\
129-
`# -- Tokenizer / chat parsing ----------------------------------------` \
130-
chat_parser_config.type="qwen" \
131-
tokenizer_config.tokenizer_type="huggingface" \
132-
tokenizer_config.tokenizer_path="$tokenizer_path" \
133-
tokenizer_config.add_bos=false \
134-
tokenizer_config.add_eos=false \
135-
\
136-
`# -- GRPO algorithm ---------------------------------------------------` \
13788
agentic_grpo_config.num_generations="$num_generations" \
138-
agentic_grpo_config.num_iterations=1 \
139-
agentic_grpo_config.beta=0.08 \
140-
agentic_grpo_config.epsilon=0.2 \
141-
agentic_grpo_config.system_prompt="You are given a grade school math problem. Think step by step and respond using <reasoning>...</reasoning> followed by <answer>...</answer> with only the final numeric answer inside <answer>." \
142-
agentic_grpo_config.max_concurrency=128 \
143-
agentic_grpo_config.max_response_length=768 \
144-
agentic_grpo_config.max_turns=1 \
145-
agentic_grpo_config.context_ratio=1 \
146-
\
147-
`# -- Optimizer --------------------------------------------------------` \
148-
rl_training_config.actor_optimizer_config.opt_type="adamw" \
149-
rl_training_config.actor_optimizer_config.learning_rate=3e-6 \
150-
rl_training_config.actor_optimizer_config.schedule_type="warmup_cosine_decay_schedule" \
151-
rl_training_config.actor_optimizer_config.init_value=0.0 \
152-
rl_training_config.actor_optimizer_config.peak_value=3e-6 \
153-
rl_training_config.actor_optimizer_config.end_value=0.0 \
15489
rl_training_config.actor_optimizer_config.warmup_ratio="$warmup_ratio" \
15590
rl_training_config.actor_optimizer_config.warmup_steps="$warmup_steps" \
15691
rl_training_config.actor_optimizer_config.decay_steps="$max_steps" \
157-
rl_training_config.actor_optimizer_config.b1=0.9 \
158-
rl_training_config.actor_optimizer_config.b2=0.99 \
159-
rl_training_config.actor_optimizer_config.weight_decay=0.1 \
160-
rl_training_config.actor_optimizer_config.max_grad_norm=0.1 \
161-
\
162-
`# -- RL training ------------------------------------------------------` \
163-
rl_training_config.eval_every_n_steps=10 \
16492
rl_training_config.max_steps="$max_steps" \
16593
rl_training_config.mini_batch_size="$mini_batch_size" \
16694
rl_training_config.train_micro_batch_size="$train_micro_batch_size" \
16795
rl_training_config.rollout_micro_batch_size="$rollout_micro_batch_size" \
16896
rl_training_config.compute_logps_micro_batch_size="$compute_logps_micro_batch_size" \
16997
rl_training_config.checkpoint_root_directory="$checkpoint_dir" \
170-
rl_training_config.checkpointing_options.save_interval_steps=250 \
171-
rl_training_config.checkpointing_options.max_to_keep=4 \
17298
rl_training_config.metrics_logging_options.log_dir="/tmp/tensorboard/gsm8k_qwen3_8b" \
173-
rl_training_config.metrics_logging_options.flush_every_n_steps=20 \
174-
\
17599
"$@"

0 commit comments

Comments (0)