@@ -39,62 +39,18 @@ echo "Max steps: $max_steps"
3939echo " Rounded warmup steps: $warmup_steps "
4040
4141python3 -m tunix.cli.grpo_main \
42- base_config.yaml \
43- model_config.model_name=" gemma-3-12b-it" \
44- model_config.model_id=" google/gemma-3-12b-it" \
42+ tunix/cli/base_config.yaml \
43+ override_config_file=examples/rl/grpo/gsm8k/configs/gemma3_12b.yaml \
4544 model_config.model_path=" gs://gemma-data/checkpoints/gemma3-12b-it" \
46- model_config.model_source=" gcs" \
4745 model_config.intermediate_ckpt_dir=" /tmp/intermediate_ckpt/gemma3_12b" \
4846 model_config.model_download_path=" /tmp/models/gemma-3-12b-it" \
49- model_config.mesh.shape=" (2,4)" \
50- model_config.mesh.axis_names=" ('fsdp','tp')" \
51- model_config.rng_seed=42 \
52- actor_model_config.lora_config.rank=64 \
53- actor_model_config.lora_config.alpha=64.0 \
54- actor_model_config.lora_config.module_path=" .*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|.*attn_vec_einsum" \
55- actor_model_config.mesh.shape=" (2,4)" \
56- actor_model_config.mesh.axis_names=" ('fsdp','tp')" \
57- reference_model_config.mesh=null \
58- reference_model_config.same_mesh_as=" actor" \
59- rollout_model_config.mesh=null \
60- rollout_model_config.same_mesh_as=" actor" \
6147 tokenizer_config.tokenizer_path=" gs://gemma-data/tokenizers/tokenizer_gemma3.model" \
62- tokenizer_config.tokenizer_type=" sentencepiece" \
63- tokenizer_config.add_bos=false \
64- dataset_name=" gsm8k" \
6548 batch_size=$batch_size \
6649 num_batches=$num_batches \
67- num_test_batches=100 \
6850 num_train_epochs=$num_train_epochs \
6951 train_fraction=$train_fraction \
70- rl_training_config.actor_optimizer_config.opt_type=" adamw" \
71- rl_training_config.actor_optimizer_config.peak_value=3e-6 \
72- rl_training_config.actor_optimizer_config.schedule_type=" warmup_cosine_decay_schedule" \
73- rl_training_config.actor_optimizer_config.init_value=0.0 \
74- rl_training_config.actor_optimizer_config.end_value=0.0 \
7552 rl_training_config.actor_optimizer_config.warmup_ratio=$warmup_ratio \
7653 rl_training_config.actor_optimizer_config.warmup_steps=$warmup_steps \
7754 rl_training_config.actor_optimizer_config.decay_steps=$max_steps \
78- rl_training_config.actor_optimizer_config.b1=0.9 \
79- rl_training_config.actor_optimizer_config.b2=0.99 \
80- rl_training_config.actor_optimizer_config.weight_decay=0.1 \
81- rl_training_config.actor_optimizer_config.max_grad_norm=0.1 \
82- rl_training_config.eval_every_n_steps=10 \
8355 rl_training_config.max_steps=$max_steps \
84- rl_training_config.metrics_logging_options.log_dir=" /tmp/tensorboard/grpo_gemma3_12b" \
85- rl_training_config.metrics_logging_options.flush_every_n_steps=20 \
86- rl_training_config.checkpointing_options.save_interval_steps=500 \
87- rl_training_config.checkpointing_options.max_to_keep=4 \
88- rl_training_config.profiler_options={} \
89- rollout_config.total_generation_steps=768 \
90- rollout_config.max_prompt_length=256 \
91- rollout_config.temperature=0.9 \
92- rollout_config.top_p=1.0 \
93- rollout_config.top_k=50 \
94- rollout_engine=" vanilla" \
95- offload_to_cpu=false \
96- grpo_config.num_generations=2 \
97- grpo_config.num_iterations=1 \
98- grpo_config.beta=0.08 \
99- grpo_config.epsilon=0.2 \
100- reward_functions=" ['tunix/cli/reward_fn/gsm8k.py']"
56+ rl_training_config.metrics_logging_options.log_dir=" /tmp/tensorboard/grpo_gemma3_12b"
0 commit comments