Skip to content

Commit 30fec08

Browse files
s-noghabi, The tunix Authors
authored and committed
yaml based configs for gemma3 cli scripts
PiperOrigin-RevId: 915005187
1 parent 9a28b3d commit 30fec08

6 files changed

Lines changed: 252 additions & 141 deletions

File tree

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
model_config:
16+
model_name: "gemma-3-12b-it"
17+
model_id: "google/gemma-3-12b-it"
18+
model_source: "gcs"
19+
mesh:
20+
shape: "(2,4)"
21+
axis_names: "('fsdp','tp')"
22+
rng_seed: 42
23+
actor_model_config:
24+
lora_config:
25+
rank: 64
26+
alpha: 64.0
27+
module_path: ".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|.*attn_vec_einsum"
28+
mesh:
29+
shape: "(2,4)"
30+
axis_names: "('fsdp','tp')"
31+
reference_model_config:
32+
mesh: null
33+
same_mesh_as: "actor"
34+
rollout_model_config:
35+
mesh: null
36+
same_mesh_as: "actor"
37+
tokenizer_config:
38+
tokenizer_type: "sentencepiece"
39+
add_bos: False
40+
dataset_name: "gsm8k"
41+
batch_size: 1
42+
num_batches: 3738
43+
num_test_batches: 100
44+
num_train_epochs: 1
45+
rl_training_config:
46+
actor_optimizer_config:
47+
opt_type: "adamw"
48+
peak_value: 3e-6
49+
schedule_type: "warmup_cosine_decay_schedule"
50+
init_value: 0.0
51+
end_value: 0.0
52+
warmup_ratio: 0.1
53+
warmup_steps: 374
54+
decay_steps: 3738
55+
b1: 0.9
56+
b2: 0.99
57+
weight_decay: 0.1
58+
max_grad_norm: 0.1
59+
eval_every_n_steps: 10
60+
max_steps: 3738
61+
metrics_logging_options:
62+
flush_every_n_steps: 20
63+
checkpointing_options:
64+
save_interval_steps: 500
65+
max_to_keep: 4
66+
profiler_options: {}
67+
rollout_config:
68+
total_generation_steps: 768
69+
max_prompt_length: 256
70+
temperature: 0.9
71+
top_p: 1.0
72+
top_k: 50
73+
rollout_engine: "vanilla"
74+
offload_to_cpu: False
75+
grpo_config:
76+
num_generations: 2
77+
num_iterations: 1
78+
beta: 0.08
79+
epsilon: 0.2
80+
reward_functions:
81+
- "tunix/cli/reward_fn/gsm8k.py"
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
model_config:
16+
model_name: "gemma-3-1b-it"
17+
model_id: "google/gemma-3-1b-it"
18+
model_source: "gcs"
19+
mesh:
20+
shape: "(2,4)"
21+
axis_names: "('fsdp','tp')"
22+
rng_seed: 42
23+
actor_model_config:
24+
lora_config:
25+
rank: 64
26+
alpha: 64.0
27+
module_path: ".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|.*attn_vec_einsum"
28+
mesh:
29+
shape: "(2,4)"
30+
axis_names: "('fsdp','tp')"
31+
reference_model_config:
32+
mesh: null
33+
same_mesh_as: "actor"
34+
rollout_model_config:
35+
mesh: null
36+
same_mesh_as: "actor"
37+
tokenizer_config:
38+
tokenizer_type: "sentencepiece"
39+
add_bos: False
40+
dataset_name: "gsm8k"
41+
batch_size: 1
42+
num_batches: 3738
43+
num_test_batches: 100
44+
num_train_epochs: 1
45+
rl_training_config:
46+
actor_optimizer_config:
47+
opt_type: "adamw"
48+
peak_value: 3e-6
49+
schedule_type: "warmup_cosine_decay_schedule"
50+
init_value: 0.0
51+
end_value: 0.0
52+
warmup_ratio: 0.1
53+
warmup_steps: 374
54+
decay_steps: 3738
55+
b1: 0.9
56+
b2: 0.99
57+
weight_decay: 0.1
58+
max_grad_norm: 0.1
59+
eval_every_n_steps: 10
60+
max_steps: 3738
61+
metrics_logging_options:
62+
flush_every_n_steps: 20
63+
checkpointing_options:
64+
save_interval_steps: 500
65+
max_to_keep: 4
66+
profiler_options: {}
67+
rollout_config:
68+
total_generation_steps: 768
69+
max_prompt_length: 256
70+
temperature: 0.9
71+
top_p: 1.0
72+
top_k: 50
73+
rollout_engine: "vanilla"
74+
offload_to_cpu: False
75+
grpo_config:
76+
num_generations: 2
77+
num_iterations: 1
78+
beta: 0.08
79+
epsilon: 0.2
80+
reward_functions:
81+
- "tunix/cli/reward_fn/gsm8k.py"
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
model_config:
16+
model_name: "gemma-3-4b-it"
17+
model_id: "google/gemma-3-4b-it"
18+
model_source: "gcs"
19+
mesh:
20+
shape: "(2,4)"
21+
axis_names: "('fsdp','tp')"
22+
rng_seed: 42
23+
actor_model_config:
24+
lora_config:
25+
rank: 64
26+
alpha: 64.0
27+
module_path: ".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|.*attn_vec_einsum"
28+
mesh:
29+
shape: "(2,4)"
30+
axis_names: "('fsdp','tp')"
31+
reference_model_config:
32+
mesh: null
33+
same_mesh_as: "actor"
34+
rollout_model_config:
35+
mesh: null
36+
same_mesh_as: "actor"
37+
tokenizer_config:
38+
tokenizer_type: "sentencepiece"
39+
add_bos: False
40+
dataset_name: "gsm8k"
41+
batch_size: 1
42+
num_batches: 3738
43+
num_test_batches: 100
44+
num_train_epochs: 1
45+
rl_training_config:
46+
actor_optimizer_config:
47+
opt_type: "adamw"
48+
peak_value: 3e-6
49+
schedule_type: "warmup_cosine_decay_schedule"
50+
init_value: 0.0
51+
end_value: 0.0
52+
warmup_ratio: 0.1
53+
warmup_steps: 374
54+
decay_steps: 3738
55+
b1: 0.9
56+
b2: 0.99
57+
weight_decay: 0.1
58+
max_grad_norm: 0.1
59+
eval_every_n_steps: 10
60+
max_steps: 3738
61+
metrics_logging_options:
62+
flush_every_n_steps: 20
63+
checkpointing_options:
64+
save_interval_steps: 500
65+
max_to_keep: 4
66+
profiler_options: {}
67+
rollout_config:
68+
total_generation_steps: 768
69+
max_prompt_length: 256
70+
temperature: 0.9
71+
top_p: 1.0
72+
top_k: 50
73+
rollout_engine: "vanilla"
74+
offload_to_cpu: False
75+
grpo_config:
76+
num_generations: 2
77+
num_iterations: 1
78+
beta: 0.08
79+
epsilon: 0.2
80+
reward_functions:
81+
- "tunix/cli/reward_fn/gsm8k.py"

examples/rl/grpo/gsm8k/run_gemma3_12b.sh

Lines changed: 3 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -39,62 +39,18 @@ echo "Max steps: $max_steps"
3939
echo "Rounded warmup steps: $warmup_steps"
4040

4141
python3 -m tunix.cli.grpo_main \
42-
base_config.yaml \
43-
model_config.model_name="gemma-3-12b-it" \
44-
model_config.model_id="google/gemma-3-12b-it" \
42+
tunix/cli/base_config.yaml \
43+
override_config_file=examples/rl/grpo/gsm8k/configs/gemma3_12b.yaml \
4544
model_config.model_path="gs://gemma-data/checkpoints/gemma3-12b-it" \
46-
model_config.model_source="gcs" \
4745
model_config.intermediate_ckpt_dir="/tmp/intermediate_ckpt/gemma3_12b" \
4846
model_config.model_download_path="/tmp/models/gemma-3-12b-it" \
49-
model_config.mesh.shape="(2,4)" \
50-
model_config.mesh.axis_names="('fsdp','tp')" \
51-
model_config.rng_seed=42 \
52-
actor_model_config.lora_config.rank=64 \
53-
actor_model_config.lora_config.alpha=64.0 \
54-
actor_model_config.lora_config.module_path=".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|.*attn_vec_einsum" \
55-
actor_model_config.mesh.shape="(2,4)" \
56-
actor_model_config.mesh.axis_names="('fsdp','tp')" \
57-
reference_model_config.mesh=null \
58-
reference_model_config.same_mesh_as="actor" \
59-
rollout_model_config.mesh=null \
60-
rollout_model_config.same_mesh_as="actor" \
6147
tokenizer_config.tokenizer_path="gs://gemma-data/tokenizers/tokenizer_gemma3.model" \
62-
tokenizer_config.tokenizer_type="sentencepiece" \
63-
tokenizer_config.add_bos=false \
64-
dataset_name="gsm8k" \
6548
batch_size=$batch_size \
6649
num_batches=$num_batches \
67-
num_test_batches=100 \
6850
num_train_epochs=$num_train_epochs \
6951
train_fraction=$train_fraction \
70-
rl_training_config.actor_optimizer_config.opt_type="adamw" \
71-
rl_training_config.actor_optimizer_config.peak_value=3e-6 \
72-
rl_training_config.actor_optimizer_config.schedule_type="warmup_cosine_decay_schedule" \
73-
rl_training_config.actor_optimizer_config.init_value=0.0 \
74-
rl_training_config.actor_optimizer_config.end_value=0.0 \
7552
rl_training_config.actor_optimizer_config.warmup_ratio=$warmup_ratio \
7653
rl_training_config.actor_optimizer_config.warmup_steps=$warmup_steps \
7754
rl_training_config.actor_optimizer_config.decay_steps=$max_steps \
78-
rl_training_config.actor_optimizer_config.b1=0.9 \
79-
rl_training_config.actor_optimizer_config.b2=0.99 \
80-
rl_training_config.actor_optimizer_config.weight_decay=0.1 \
81-
rl_training_config.actor_optimizer_config.max_grad_norm=0.1 \
82-
rl_training_config.eval_every_n_steps=10 \
8355
rl_training_config.max_steps=$max_steps \
84-
rl_training_config.metrics_logging_options.log_dir="/tmp/tensorboard/grpo_gemma3_12b" \
85-
rl_training_config.metrics_logging_options.flush_every_n_steps=20 \
86-
rl_training_config.checkpointing_options.save_interval_steps=500 \
87-
rl_training_config.checkpointing_options.max_to_keep=4 \
88-
rl_training_config.profiler_options={} \
89-
rollout_config.total_generation_steps=768 \
90-
rollout_config.max_prompt_length=256 \
91-
rollout_config.temperature=0.9 \
92-
rollout_config.top_p=1.0 \
93-
rollout_config.top_k=50 \
94-
rollout_engine="vanilla" \
95-
offload_to_cpu=false \
96-
grpo_config.num_generations=2 \
97-
grpo_config.num_iterations=1 \
98-
grpo_config.beta=0.08 \
99-
grpo_config.epsilon=0.2 \
100-
reward_functions="['tunix/cli/reward_fn/gsm8k.py']"
56+
rl_training_config.metrics_logging_options.log_dir="/tmp/tensorboard/grpo_gemma3_12b"

examples/rl/grpo/gsm8k/run_gemma3_1b.sh

Lines changed: 3 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -39,62 +39,18 @@ echo "Max steps: $max_steps"
3939
echo "Rounded warmup steps: $warmup_steps"
4040

4141
python3 -m tunix.cli.grpo_main \
42-
base_config.yaml \
43-
model_config.model_name="gemma-3-1b-it" \
44-
model_config.model_id="google/gemma-3-1b-it" \
42+
tunix/cli/base_config.yaml \
43+
override_config_file=examples/rl/grpo/gsm8k/configs/gemma3_1b.yaml \
4544
model_config.model_path="gs://gemma-data/checkpoints/gemma3-1b-it" \
46-
model_config.model_source="gcs" \
4745
model_config.intermediate_ckpt_dir="/tmp/intermediate_ckpt/gemma3_1b" \
4846
model_config.model_download_path="/tmp/models/gemma-3-1b-it" \
49-
model_config.mesh.shape="(2,4)" \
50-
model_config.mesh.axis_names="('fsdp','tp')" \
51-
model_config.rng_seed=42 \
52-
actor_model_config.lora_config.rank=64 \
53-
actor_model_config.lora_config.alpha=64.0 \
54-
actor_model_config.lora_config.module_path=".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|.*attn_vec_einsum" \
55-
actor_model_config.mesh.shape="(2,4)" \
56-
actor_model_config.mesh.axis_names="('fsdp','tp')" \
57-
reference_model_config.mesh=null \
58-
reference_model_config.same_mesh_as="actor" \
59-
rollout_model_config.mesh=null \
60-
rollout_model_config.same_mesh_as="actor" \
6147
tokenizer_config.tokenizer_path="gs://gemma-data/tokenizers/tokenizer_gemma3.model" \
62-
tokenizer_config.tokenizer_type="sentencepiece" \
63-
tokenizer_config.add_bos=false \
64-
dataset_name="gsm8k" \
6548
batch_size=$batch_size \
6649
num_batches=$num_batches \
67-
num_test_batches=100 \
6850
num_train_epochs=$num_train_epochs \
6951
train_fraction=$train_fraction \
70-
rl_training_config.actor_optimizer_config.opt_type="adamw" \
71-
rl_training_config.actor_optimizer_config.peak_value=3e-6 \
72-
rl_training_config.actor_optimizer_config.schedule_type="warmup_cosine_decay_schedule" \
73-
rl_training_config.actor_optimizer_config.init_value=0.0 \
74-
rl_training_config.actor_optimizer_config.end_value=0.0 \
7552
rl_training_config.actor_optimizer_config.warmup_ratio=$warmup_ratio \
7653
rl_training_config.actor_optimizer_config.warmup_steps=$warmup_steps \
7754
rl_training_config.actor_optimizer_config.decay_steps=$max_steps \
78-
rl_training_config.actor_optimizer_config.b1=0.9 \
79-
rl_training_config.actor_optimizer_config.b2=0.99 \
80-
rl_training_config.actor_optimizer_config.weight_decay=0.1 \
81-
rl_training_config.actor_optimizer_config.max_grad_norm=0.1 \
82-
rl_training_config.eval_every_n_steps=10 \
8355
rl_training_config.max_steps=$max_steps \
84-
rl_training_config.metrics_logging_options.log_dir="/tmp/tensorboard/grpo_gemma3_1b" \
85-
rl_training_config.metrics_logging_options.flush_every_n_steps=20 \
86-
rl_training_config.checkpointing_options.save_interval_steps=500 \
87-
rl_training_config.checkpointing_options.max_to_keep=4 \
88-
rl_training_config.profiler_options={} \
89-
rollout_config.total_generation_steps=768 \
90-
rollout_config.max_prompt_length=256 \
91-
rollout_config.temperature=0.9 \
92-
rollout_config.top_p=1.0 \
93-
rollout_config.top_k=50 \
94-
rollout_engine="vanilla" \
95-
offload_to_cpu=false \
96-
grpo_config.num_generations=2 \
97-
grpo_config.num_iterations=1 \
98-
grpo_config.beta=0.08 \
99-
grpo_config.epsilon=0.2 \
100-
reward_functions="['tunix/cli/reward_fn/gsm8k.py']"
56+
rl_training_config.metrics_logging_options.log_dir="/tmp/tensorboard/grpo_gemma3_1b"

0 commit comments

Comments
 (0)