
Commit e3ddb1a

Merge pull request AI-Hypercomputer#2530 from AI-Hypercomputer:universal_grpo
PiperOrigin-RevId: 828688487
2 parents cb136bc + 1c627c1 commit e3ddb1a


17 files changed: +1251 additions, −4416 deletions

dependencies/scripts/docker_upload_runner.sh

Lines changed: 1 addition & 1 deletion
@@ -88,7 +88,7 @@ if ! docker image inspect "${LOCAL_IMAGE_NAME}" &> /dev/null; then
 exit 1
 fi

-docker build --build-arg BASEIMAGE=${LOCAL_IMAGE_NAME} \
+docker build --no-cache --build-arg BASEIMAGE=${LOCAL_IMAGE_NAME} \
 -f "$MAXTEXT_REPO_ROOT"'/dependencies/dockerfiles/maxtext_runner.Dockerfile' \
 -t ${LOCAL_IMAGE_NAME_RUNNER} .

docs/tutorials/grpo.md

Lines changed: 10 additions & 2 deletions
@@ -58,9 +58,17 @@ We use the scheduler code from vLLM, and the model runner code from `tpu_commons`

 ## Run GRPO

-Finally, run the script
+Finally, run the command

-`python ~/maxtext/src/MaxText/examples/grpo_llama3_1_8b_demo.py`
+```
+python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml \
+--model_name=llama3.1-8b \
+--tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \
+--load_parameters_path=gs://path/to/checkpoint/0/items \
+--run_name=$WORKLOAD \
+--base_output_directory=$OUTPUT_PATH \
+--hf_access_token=$HF_TOKEN
+```

 The overview of the demo script is as follows:

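The command above resolves its settings in layers: `rl.yml` declares `base_config: "base.yml"`, and the CLI flags override values from both files. A minimal sketch of that kind of layered resolution, assuming a simplified standalone loader (not MaxText's actual config loader; file paths and flag handling here are illustrative):

```python
# Illustrative only: a simplified view of layered config resolution
# (base.yml -> rl.yml -> CLI flags). MaxText's real loader differs in
# details; paths and flag parsing here are assumptions.
import sys
import yaml  # assumes PyYAML is available


def load_layered_config(config_path, cli_args):
    """Merge a YAML config onto its declared base_config, then apply CLI overrides."""
    with open(config_path) as f:
        cfg = yaml.safe_load(f) or {}
    merged = {}
    base_path = cfg.pop("base_config", None)
    if base_path:  # e.g. rl.yml declares base_config: "base.yml"
        with open(base_path) as f:
            merged.update(yaml.safe_load(f) or {})
    merged.update(cfg)  # values in rl.yml win over base.yml
    for arg in cli_args:  # e.g. --model_name=llama3.1-8b
        if arg.startswith("--") and "=" in arg:
            key, value = arg[2:].split("=", 1)
            merged[key] = value  # CLI flags win over both files
    return merged


if __name__ == "__main__":
    config = load_layered_config(sys.argv[1], sys.argv[2:])
    print(config.get("model_name"), config.get("learning_rate"))
```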
docs/tutorials/grpo_with_pathways.md

Lines changed: 9 additions & 2 deletions
@@ -53,13 +53,20 @@ bash docker_upload_runner.sh CLOUD_IMAGE_NAME=path/to/gcr.io

 ### Submit your jobs

-Please use a pathways enabled cluster, and you can submit the script `maxtext/src/MaxText/examples/grpo_llama3_1_70b_demo_pw.py` via XPK
+Please use a pathways enabled [XPK](https://github.com/AI-Hypercomputer/xpk) cluster, and you can submit the `train_rl.py` script via [XPK](https://github.com/AI-Hypercomputer/xpk)
 ```
 xpk workload create-pathways --workload $WORKLOAD \
 --docker-image path/to/gcr.io:latest --cluster $TPU_CLUSTER \
 --tpu-type=$TPU_TYPE --num-slices=1 --zone=$ZONE \
 --project=$PROJECT_ID --priority=high \
---command "HF_TOKEN=$HF_TOKEN TF_CPP_MIN_LOG_LEVEL=0 JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE='1' python src/MaxText/examples/grpo_llama3_1_70b_demo_pw.py"
+--command "HF_TOKEN=$HF_TOKEN TF_CPP_MIN_LOG_LEVEL=0 JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE='1' # Llama3.1-70B-Instruct
+python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml \
+--model_name=llama3.1-70b \
+--tokenizer_path=meta-llama/Llama-3.1-70B-Instruct \
+--load_parameters_path=gs://path/to/checkpoint/0/items \
+--run_name=$WORKLOAD \
+--base_output_directory=$OUTPUT_PATH \
+--hf_access_token=$HF_TOKEN"
 ```

 The overview of the demo script `~/maxtext/src/MaxText/examples/grpo_llama3_1_70b_demo_pw.py` is as follows:

src/MaxText/configs/base.yml

Lines changed: 4 additions & 0 deletions
@@ -18,6 +18,9 @@ run_name: ""

 model_name: "default" # override config settings to match a specific model. other than the override, nothing should use this!
 override_model_config: False # When set to true allows overriding model parameters via CLI for the purpose of debugging/testing.
+debug:
+  rl: False # RL-specific debugging
+
 normalization_layer_epsilon: 1.e-05 # epsilon value for rmsnorm, layernorm.

 ################################## CHECKPOINTING ##################################
@@ -47,6 +50,7 @@ enable_checkpointing: True
 save_checkpoint_on_completion: True
 async_checkpointing: True
 checkpoint_period: 10_000
+max_num_checkpoints_to_keep: None
 # enables one replica to read the ckpt then broadcast to the rest
 enable_single_replica_ckpt_restoring: False

src/MaxText/configs/rl.yml

Lines changed: 146 additions & 61 deletions
@@ -12,66 +12,151 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+# RL Configuration
+# This config consolidates common parameters for RL training across different model sizes
+
 base_config: "base.yml"

-logical_axis_rules: [
-  ['prefill_activation_length', ['data']],
-  ['prefill_activation_norm_length', ['data']],
-  ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
-  ['activation_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']],
-  ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']],
-  ['activation_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']],
-  ['activation_kv_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']],
-  ['activation_length', ['context_autoregressive', 'sequence']],
-  ['activation_length', ['context_autoregressive']],
-  ['activation_q_length', ['context_autoregressive']],
-  ['activation_kv_length', ['context_autoregressive']],
-  ['activation_norm_length', ['tensor_sequence', 'sequence']],
-  ['activation_embed', ['tensor_transpose']],
-  ['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']],
-  ['activation_kv', ['tensor', 'tensor_transpose', 'tensor_sequence']],
-  ['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
-  ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context_autoregressive']],
-  ['activation_kv_head_dim', ['tensor', 'tensor_transpose', 'tensor_sequence']],
-  ['activation_vocab', ['tensor', 'tensor_transpose', 'sequence', 'tensor_sequence']],
-  ['activation_vocab', ['tensor', 'tensor_transpose']],
-  ['activation_vocab', 'tensor_sequence'],
-  ['activation_vocab', ['sequence', 'context_autoregressive']],
-  ['activation_stage', 'stage'],
-  ['activation_exp', ['expert', 'context_autoregressive']],
-  ['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context_autoregressive']],
-  ['decode_length', []],
-  ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
-  ['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive','context_autoregressive']],
-  ['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
-  ['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
-  ['kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
-  ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'expert']],
-  ['embed', ['fsdp', 'sequence', 'tensor_transpose', 'expert']],
-  ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'expert']],
-  ['embed', ['fsdp', 'sequence', 'expert']],
-  ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive', 'tensor_transpose']],
-  ['embed_no_exp', ['fsdp', 'sequence', 'context_autoregressive', 'tensor_transpose']],
-  ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive']],
-  ['embed_no_exp', ['fsdp', 'sequence', 'context_autoregressive']],
-  ['norm', ['tensor', 'tensor_transpose', 'tensor_sequence']],
-  ['layers', 'stage'],
-  ['kv', []],
-  ['kv_head_dim', []],
-  ['cache_batch_prefill', []],
-  ['cache_batch', ['context_autoregressive']],
-  ['cache_heads', ['autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence']],
-  ['cache_heads', ['autoregressive', 'tensor', 'tensor_sequence']],
-  ['cache_kv', []],
-  ['cache_sequence', ['context_autoregressive']],
-  ['cache_scale_sequence', ['context_autoregressive']],
-  ['exp', ['expert', 'context_autoregressive']],
-  ['paged_kv_heads', []],
-  ['num_pages', ['tensor']],
-  ['tokens_per_page', []],
-  ['paged_kv_head_dim_size', []],
-]
-# Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details
-data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']]
-
-return_log_prob: True
+# ====== Hardware =====
+trainer_devices_fraction: 0.5
+sampler_devices_fraction: 0.5
+chips_per_vm: 4 # depends on hardware, for v5p this is 4
+
+# ====== Reproducibility ======
+data_shuffle_seed: 42
+
+# ====== GRPO ======
+
+# The number of times the policy generates multiple responses for a given prompt
+# within a single training step. This corresponds to `G` in Algorithm 1 in the
+# paper. The "group" in GRPO comes from here.
+num_generations: 2
+
+# === other GRPO configs ===
+# The number of iterations per batch (𝜇 in GRPO algo 1).
+num_iterations: 1
+
+# The coefficient for the KL divergence penalty (𝛽) in the GRPO loss function.
+# Important to keep a high enough value for this, otherwise, the KL divergence
+# can increase unchecked.
+grpo_beta: 0.08
+# Epsilon value for clipping (𝜀 in GRPO loss in paper). Similar to PPO, for
+# stable updates.
+grpo_epsilon: 0.2
+loss_algo: 'grpo' # grpo or gspo-token
+
+
+# ====== Models ======
+# for MaxText
+# Model and Tokenizer Configuration
+# Override these via CLI:
+# model_name, tokenizer_path, load_parameters_path
+# Model-Specific Overrides (examples)
+# For Llama3.1-8B:
+# model_name: llama3.1-8b
+# HF tokenizer_path: meta-llama/Llama-3.1-8B-Instruct
+#
+# For Llama3.1-70B with Pathways:
+# model_name: llama3.1-70b
+# HF tokenizer_path: meta-llama/Llama-3.1-70B-Instruct
+
+# ====== MaxText configs ======
+weight_dtype: 'bfloat16'
+attention: 'dot_product'
+remat_policy: 'custom'
+decoder_layer_input: 'offload'
+query_proj: 'offload'
+key_proj: 'offload'
+value_proj: 'offload'
+checkpoint_storage_use_ocdbt: False # For Pathways
+checkpoint_storage_use_zarr3: False # For Pathways
+use_pathways: True
+
+# ====== Debugging ======
+debug:
+  rl: True
+
+# ====== Training ======
+batch_size: 1
+# Increase `batch_size` and `MAX_STEPS` for better results.
+# num_batches: 3738
+num_batches: 4 # 200
+# Keep `num_test_batches` low so that evaluation runs quickly. It can be
+# increased to a max. of 330 (if batch size is 4).
+num_test_batches: 5 # 200
+train_fraction: 1.0
+
+eval_interval: 10 # this doesn't matter if `TRAIN_FRACTION = 1.0`.
+
+num_epochs: 1 # can potentially train for more epochs
+
+learning_rate: 3e-6
+adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients.
+adam_b2: 0.99 # Exponential decay rate to track the second moment of past gradients.
+gradient_clipping_threshold: 0.1
+
+# ====== Evaluation ======
+eval_sampling_strategy: "greedy" # can be "greedy", "standard", or "liberal"
+generation_configs:
+  greedy:
+    eval_temperature: 0.01
+    eval_top_k: 1
+    eval_top_p: 1.0
+  standard:
+    eval_temperature: 0.7
+    eval_top_k: 50
+    eval_top_p: 0.95
+  liberal:
+    eval_temperature: 0.85
+    eval_top_k: 2000
+    eval_top_p: 1.0
+
+num_eval_passes: 1 # Number of generation passes during evaluation
+eval_corr_lst: False # If True, only include correct responses in the list during evaluation
+eval_make_lst: False # If True, return a list of (question, answer, responses) during evaluation
+
+# ====== Inference ======
+# === Generation during GRPO training ===
+# max Lengths for prompt and completion
+max_prefill_predict_length: 256
+max_target_length: 1024
+kv_cache_buffer: 256
+hbm_utilization_vllm: 0.72
+swap_space_vllm_gb: 2
+# Generation Configuration During Training
+# Important to keep a high-ish temperature for varied, diverse responses during
+# training.
+decode_sampling_temperature: 0.9
+decode_sampling_top_k: 50
+decode_sampling_nucleus_p: 1.0
+
+# ====== Checkpoint Configuration ======
+enable_checkpointing: True
+async_checkpointing: False
+checkpoint_period: 50
+max_num_checkpoints_to_keep: 10
+
+# ====== Reward ======
+
+reward_exact_format_match: 3.0
+reward_white_space_format_match: 1.5
+reward_partial_format_match: 0.5
+reward_ratio_guess_to_answer_high: 0.5
+reward_ratio_guess_to_answer_low: 0.25
+penalty_incorrect_format: -0.5
+penalty_incorrect_answer: -1.0
+
+# ====== Special tokens/templates for GSM8K reasoning ======
+reasoning_start_token: '<reasoning>'
+reasoning_end_token: '</reasoning>'
+solution_start_token: '<answer>'
+solution_end_token: '</answer>'
+chat_template_path: 'src/MaxText/examples/chat_templates/gsm8k_rl.json'
+skip_jax_distributed_system: True
+
+# # TODO(@mazumdera): fix this
+# Dataset Configuration
+dataset_name: 'gsm8k'
+train_split: 'train'
+eval_split: 'test'
+tokenizer_type: 'huggingface'

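The GRPO block above wires `num_generations`, `grpo_epsilon`, and `grpo_beta` into a clipped policy-ratio objective with a KL penalty against a frozen reference policy. A minimal JAX sketch of that per-token loss, assuming per-token log-probabilities and scalar rewards are already computed (an illustration of the objective, not MaxText's `train_rl` code; the function signature and argument names are assumptions):

```python
# Illustrative GRPO-style loss on precomputed per-token log-probs.
# Shapes: logps and mask are [batch * num_generations, seq_len],
# rewards is [batch * num_generations]. Not MaxText's actual code.
import jax.numpy as jnp


def grpo_loss(policy_logps, old_logps, ref_logps, rewards, mask,
              num_generations=2, grpo_epsilon=0.2, grpo_beta=0.08):
    # Group-relative advantage: normalize rewards within each group of
    # `num_generations` responses to the same prompt (the "G" in GRPO).
    grouped = rewards.reshape(-1, num_generations)
    advantages = (grouped - grouped.mean(axis=-1, keepdims=True)) / (
        grouped.std(axis=-1, keepdims=True) + 1e-6)
    advantages = advantages.reshape(-1, 1)  # broadcast over tokens

    # PPO-style clipped ratio between current and sampling-time policy.
    ratio = jnp.exp(policy_logps - old_logps)
    clipped = jnp.clip(ratio, 1.0 - grpo_epsilon, 1.0 + grpo_epsilon)
    pg_loss = -jnp.minimum(ratio * advantages, clipped * advantages)

    # k3 estimator of KL(policy || reference), weighted by grpo_beta.
    log_ratio_ref = ref_logps - policy_logps
    kl = jnp.exp(log_ratio_ref) - log_ratio_ref - 1.0

    per_token = pg_loss + grpo_beta * kl
    return (per_token * mask).sum() / jnp.maximum(mask.sum(), 1.0)
```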
src/MaxText/configs/rl_mt_jt.yml

Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
+# Copyright 2023–2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+base_config: "base.yml"
+
+logical_axis_rules: [
+  ['prefill_activation_length', ['data']],
+  ['prefill_activation_norm_length', ['data']],
+  ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
+  ['activation_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']],
+  ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']],
+  ['activation_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']],
+  ['activation_kv_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']],
+  ['activation_length', ['context_autoregressive', 'sequence']],
+  ['activation_length', ['context_autoregressive']],
+  ['activation_q_length', ['context_autoregressive']],
+  ['activation_kv_length', ['context_autoregressive']],
+  ['activation_norm_length', ['tensor_sequence', 'sequence']],
+  ['activation_embed', ['tensor_transpose']],
+  ['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']],
+  ['activation_kv', ['tensor', 'tensor_transpose', 'tensor_sequence']],
+  ['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
+  ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context_autoregressive']],
+  ['activation_kv_head_dim', ['tensor', 'tensor_transpose', 'tensor_sequence']],
+  ['activation_vocab', ['tensor', 'tensor_transpose', 'sequence', 'tensor_sequence']],
+  ['activation_vocab', ['tensor', 'tensor_transpose']],
+  ['activation_vocab', 'tensor_sequence'],
+  ['activation_vocab', ['sequence', 'context_autoregressive']],
+  ['activation_stage', 'stage'],
+  ['activation_exp', ['expert', 'context_autoregressive']],
+  ['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context_autoregressive']],
+  ['decode_length', []],
+  ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
+  ['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive','context_autoregressive']],
+  ['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
+  ['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
+  ['kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
+  ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'expert']],
+  ['embed', ['fsdp', 'sequence', 'tensor_transpose', 'expert']],
+  ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'expert']],
+  ['embed', ['fsdp', 'sequence', 'expert']],
+  ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive', 'tensor_transpose']],
+  ['embed_no_exp', ['fsdp', 'sequence', 'context_autoregressive', 'tensor_transpose']],
+  ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive']],
+  ['embed_no_exp', ['fsdp', 'sequence', 'context_autoregressive']],
+  ['norm', ['tensor', 'tensor_transpose', 'tensor_sequence']],
+  ['layers', 'stage'],
+  ['kv', []],
+  ['kv_head_dim', []],
+  ['cache_batch_prefill', []],
+  ['cache_batch', ['context_autoregressive']],
+  ['cache_heads', ['autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence']],
+  ['cache_heads', ['autoregressive', 'tensor', 'tensor_sequence']],
+  ['cache_kv', []],
+  ['cache_sequence', ['context_autoregressive']],
+  ['cache_scale_sequence', ['context_autoregressive']],
+  ['exp', ['expert', 'context_autoregressive']],
+  ['paged_kv_heads', []],
+  ['num_pages', ['tensor']],
+  ['tokens_per_page', []],
+  ['paged_kv_head_dim_size', []],
+]
+# Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details
+data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']]
+
+return_log_prob: True

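Each `logical_axis_rules` entry above maps a logical tensor axis (e.g. `activation_batch`, `embed`) to candidate mesh axes, tried in order, and `data_sharding` lists the mesh axes themselves. A small sketch of how such rules can resolve logical axes into a `jax.sharding.PartitionSpec`, assuming a toy two-axis mesh and a three-rule subset (illustrative only, not MaxText's resolver):

```python
# Illustrative resolver for logical_axis_rules-style sharding, not MaxText's code.
# The mesh shape and the three rules below are assumptions for the example.
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

RULES = {
    "activation_batch": ["data", "fsdp", "fsdp_transpose", "expert"],
    "activation_embed": ["tensor_transpose"],
    "embed": ["fsdp", "sequence", "expert"],
}


def to_partition_spec(logical_axes, mesh):
    """Map each logical axis to the first candidate mesh axis that exists in the
    mesh and is not already taken (a mesh axis may shard only one dimension)."""
    resolved, used = [], set()
    for name in logical_axes:
        choice = next(
            (c for c in RULES.get(name, []) if c in mesh.axis_names and c not in used),
            None,  # no match -> this dimension stays replicated
        )
        if choice is not None:
            used.add(choice)
        resolved.append(choice)
    return PartitionSpec(*resolved)


if __name__ == "__main__":
    devices = np.array(jax.devices()).reshape(-1, 1)
    mesh = Mesh(devices, axis_names=("fsdp", "tensor"))
    spec = to_partition_spec(("activation_batch", "embed"), mesh)
    print(spec, NamedSharding(mesh, spec))  # e.g. PartitionSpec('fsdp', None)
```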