 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+# RL Configuration
+# This config consolidates common parameters for RL training across different model sizes.
+
 base_config: "base.yml"
 
-logical_axis_rules: [
-  ['prefill_activation_length', ['data']],
-  ['prefill_activation_norm_length', ['data']],
-  ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
-  ['activation_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']],
-  ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']],
-  ['activation_heads', ['tensor', 'tensor_transpose', 'sequence', 'tensor_sequence']],
-  ['activation_kv_heads', ['tensor', 'tensor_transpose', 'sequence', 'tensor_sequence']],
-  ['activation_length', ['context_autoregressive', 'sequence']],
-  ['activation_length', ['context_autoregressive']],
-  ['activation_q_length', ['context_autoregressive']],
-  ['activation_kv_length', ['context_autoregressive']],
-  ['activation_norm_length', ['tensor_sequence', 'sequence']],
-  ['activation_embed', ['tensor_transpose']],
-  ['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']],
-  ['activation_kv', ['tensor', 'tensor_transpose', 'tensor_sequence']],
-  ['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
-  ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context_autoregressive']],
-  ['activation_kv_head_dim', ['tensor', 'tensor_transpose', 'tensor_sequence']],
-  ['activation_vocab', ['tensor', 'tensor_transpose', 'sequence', 'tensor_sequence']],
-  ['activation_vocab', ['tensor', 'tensor_transpose']],
-  ['activation_vocab', 'tensor_sequence'],
-  ['activation_vocab', ['sequence', 'context_autoregressive']],
-  ['activation_stage', 'stage'],
-  ['activation_exp', ['expert', 'context_autoregressive']],
-  ['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context_autoregressive']],
-  ['decode_length', []],
-  ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
-  ['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive', 'context_autoregressive']],
-  ['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
-  ['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
-  ['kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
-  ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'expert']],
-  ['embed', ['fsdp', 'sequence', 'tensor_transpose', 'expert']],
-  ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'expert']],
-  ['embed', ['fsdp', 'sequence', 'expert']],
-  ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive', 'tensor_transpose']],
-  ['embed_no_exp', ['fsdp', 'sequence', 'context_autoregressive', 'tensor_transpose']],
-  ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive']],
-  ['embed_no_exp', ['fsdp', 'sequence', 'context_autoregressive']],
-  ['norm', ['tensor', 'tensor_transpose', 'tensor_sequence']],
-  ['layers', 'stage'],
-  ['kv', []],
-  ['kv_head_dim', []],
-  ['cache_batch_prefill', []],
-  ['cache_batch', ['context_autoregressive']],
-  ['cache_heads', ['autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence']],
-  ['cache_heads', ['autoregressive', 'tensor', 'tensor_sequence']],
-  ['cache_kv', []],
-  ['cache_sequence', ['context_autoregressive']],
-  ['cache_scale_sequence', ['context_autoregressive']],
-  ['exp', ['expert', 'context_autoregressive']],
-  ['paged_kv_heads', []],
-  ['num_pages', ['tensor']],
-  ['tokens_per_page', []],
-  ['paged_kv_head_dim_size', []],
-]
-# Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details
-data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']]
-
-return_log_prob: True
+# ====== Hardware ======
+trainer_devices_fraction: 0.5
+sampler_devices_fraction: 0.5
+chips_per_vm: 4  # depends on hardware; for v5p this is 4
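+# Illustrative arithmetic (not enforced anywhere in this config): on a v5p
+# slice with 8 VMs there are 8 * 4 = 32 chips, so the 0.5/0.5 fractions above
+# give 16 chips to the trainer and 16 to the sampler.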
+
+# ====== Reproducibility ======
+data_shuffle_seed: 42
+
+# ====== GRPO ======
+
+# The number of responses the policy generates for each prompt within a single
+# training step. This corresponds to `G` in Algorithm 1 of the GRPO paper, and
+# is where the "group" in GRPO comes from.
+num_generations: 2
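+# Sketch of how the group is used (illustrative, not MaxText code): given
+# rewards r_1..r_G for the G responses to one prompt, each response is scored
+# by its group-relative advantage
+#   A_i = (r_i - mean(r)) / (std(r) + eps)
+# so no learned value function ("critic") is needed.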
+
+# === Other GRPO configs ===
+# The number of optimization iterations per batch (𝜇 in Algorithm 1).
+num_iterations: 1
+
+# The coefficient for the KL divergence penalty (𝛽) in the GRPO loss function.
+# Keep this high enough; otherwise the KL divergence can grow unchecked.
+grpo_beta: 0.08
+# Epsilon value for clipping (𝜀 in the GRPO loss). As in PPO, it keeps policy
+# updates stable.
+grpo_epsilon: 0.2
+loss_algo: 'grpo'  # grpo or gspo-token
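+# For reference, a schematic of the per-token clipped objective these two
+# knobs enter (following the GRPO paper; notation only, not code):
+#   J = min(rho * A, clip(rho, 1 - grpo_epsilon, 1 + grpo_epsilon) * A) - grpo_beta * KL(pi || pi_ref)
+# where rho is the new/old policy probability ratio and A is the advantage.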
+
+# ====== Models ======
+# Model and tokenizer configuration for MaxText.
+# Override these via CLI:
+#   model_name, tokenizer_path, load_parameters_path
+#
+# Model-specific overrides (examples):
+# For Llama3.1-8B:
+#   model_name: llama3.1-8b
+#   HF tokenizer_path: meta-llama/Llama-3.1-8B-Instruct
+#
+# For Llama3.1-70B with Pathways:
+#   model_name: llama3.1-70b
+#   HF tokenizer_path: meta-llama/Llama-3.1-70B-Instruct
+
+# ====== MaxText configs ======
+weight_dtype: 'bfloat16'
+attention: 'dot_product'
+remat_policy: 'custom'
+decoder_layer_input: 'offload'
+query_proj: 'offload'
+key_proj: 'offload'
+value_proj: 'offload'
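+# Our reading of MaxText's custom remat policy: the tensors marked 'offload'
+# above are moved to host memory during the forward pass rather than kept in
+# HBM or rematerialized, freeing HBM for the sampler (see hbm_utilization_vllm
+# below).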
+checkpoint_storage_use_ocdbt: False  # for Pathways
+checkpoint_storage_use_zarr3: False  # for Pathways
+use_pathways: True
+
+# ====== Debugging ======
+debug:
+  rl: True
+
+# ====== Training ======
+batch_size: 1
+# Increase `batch_size` and `MAX_STEPS` for better results.
+# num_batches: 3738
+num_batches: 4  # 200
+# Keep `num_test_batches` low so that evaluation runs quickly. It can be
+# increased to a maximum of 330 (if batch size is 4).
+num_test_batches: 5  # 200
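+# (The 330 above comes from the GSM8K test split of 1319 examples:
+# 1319 / 4 ≈ 330 batches.)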
+train_fraction: 1.0
+
+eval_interval: 10  # ignored if train_fraction = 1.0
+
+num_epochs: 1  # can potentially train for more epochs
+
+learning_rate: 3e-6
+adam_b1: 0.9  # Exponential decay rate to track the first moment of past gradients.
+adam_b2: 0.99  # Exponential decay rate to track the second moment of past gradients.
+gradient_clipping_threshold: 0.1
+
+# ====== Evaluation ======
+eval_sampling_strategy: "greedy"  # can be "greedy", "standard", or "liberal"
+generation_configs:
+  greedy:
+    eval_temperature: 0.01
+    eval_top_k: 1
+    eval_top_p: 1.0
+  standard:
+    eval_temperature: 0.7
+    eval_top_k: 50
+    eval_top_p: 0.95
+  liberal:
+    eval_temperature: 0.85
+    eval_top_k: 2000
+    eval_top_p: 1.0
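+# "greedy" (eval_top_k = 1) is effectively argmax decoding, making evals
+# reproducible; "standard" and "liberal" trade that determinism for diversity
+# via higher temperature and wider top-k / top-p.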
+
+num_eval_passes: 1  # Number of generation passes during evaluation
+eval_corr_lst: False  # If True, only include correct responses in the list during evaluation
+eval_make_lst: False  # If True, return a list of (question, answer, responses) during evaluation
+
+# ====== Inference ======
+# === Generation during GRPO training ===
+# Max lengths for prompt and completion
+max_prefill_predict_length: 256
+max_target_length: 1024
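+# max_target_length counts prompt plus completion, so each prompt gets up to
+# 256 tokens and each completion up to 1024 - 256 = 768 tokens.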
+kv_cache_buffer: 256
+hbm_utilization_vllm: 0.72
+swap_space_vllm_gb: 2
+# Generation configuration during training.
+# Important to keep a high-ish temperature for varied, diverse responses during
+# training.
+decode_sampling_temperature: 0.9
+decode_sampling_top_k: 50
+decode_sampling_nucleus_p: 1.0
+
+# ====== Checkpoint Configuration ======
+enable_checkpointing: True
+async_checkpointing: False
+checkpoint_period: 50
+max_num_checkpoints_to_keep: 10
+
+# ====== Reward ======
+reward_exact_format_match: 3.0
+reward_white_space_format_match: 1.5
+reward_partial_format_match: 0.5
+reward_ratio_guess_to_answer_high: 0.5
+reward_ratio_guess_to_answer_low: 0.25
+penalty_incorrect_format: -0.5
+penalty_incorrect_answer: -1.0
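+# Illustrative scoring, assuming rewards and penalties are simply summed (the
+# actual combination lives in the reward function, so treat this as a sketch):
+# an exactly formatted response with a wrong answer nets 3.0 - 1.0 = 2.0,
+# while a malformed response with a wrong answer nets -0.5 - 1.0 = -1.5.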
+
+# ====== Special tokens/templates for GSM8K reasoning ======
+reasoning_start_token: '<reasoning>'
+reasoning_end_token: '</reasoning>'
+solution_start_token: '<answer>'
+solution_end_token: '</answer>'
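+# With these tokens, a fully formatted response looks like (illustrative):
+#   <reasoning>step-by-step working...</reasoning><answer>72</answer>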
+chat_template_path: 'src/MaxText/examples/chat_templates/gsm8k_rl.json'
+skip_jax_distributed_system: True
+
+# TODO(@mazumdera): fix this
+# Dataset Configuration
+dataset_name: 'gsm8k'
+train_split: 'train'
+eval_split: 'test'
+tokenizer_type: 'huggingface'