Skip to content

Commit e1d4836

Browse files
author
The tunix Authors
committed
Merge pull request #1140 from precur-ai:fix_cli_rollout_config
PiperOrigin-RevId: 874916940
2 parents 7a1ae80 + e543f21 commit e1d4836

File tree

1 file changed

+29
-10
lines changed

1 file changed

+29
-10
lines changed

tunix/cli/grpo_main.py

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
# limitations under the License.
1414

1515
"""Main entry point for GRPO training."""
16+
17+
import dataclasses
1618
from absl import app
1719
from absl import flags
1820
from absl import logging
@@ -41,16 +43,33 @@ class GrpoPipeline(config.HyperParameters):
4143

4244
def create_rollout_config(self):
4345
rollout_config = self.config["rollout_config"]
44-
return base_rollout.RolloutConfig(
45-
max_tokens_to_generate=rollout_config["total_generation_steps"],
46-
max_prompt_length=rollout_config["max_prompt_length"],
47-
kv_cache_size=rollout_config["max_prompt_length"]
48-
+ rollout_config["total_generation_steps"]
49-
+ 256,
50-
temperature=rollout_config["temperature"],
51-
top_p=rollout_config["top_p"],
52-
top_k=rollout_config["top_k"],
53-
)
46+
47+
# Get all valid field names from RolloutConfig
48+
valid_fields = {
49+
f.name for f in dataclasses.fields(base_rollout.RolloutConfig)
50+
}
51+
52+
# Filter rollout_config to only include valid keys
53+
filtered_config = {
54+
k: v for k, v in rollout_config.items() if k in valid_fields
55+
}
56+
57+
# Apply explicit recomputed/renamed values
58+
if "total_generation_steps" in rollout_config:
59+
filtered_config["max_tokens_to_generate"] = rollout_config[
60+
"total_generation_steps"
61+
]
62+
if (
63+
"max_prompt_length" in rollout_config
64+
and "total_generation_steps" in rollout_config
65+
):
66+
filtered_config["kv_cache_size"] = (
67+
rollout_config["max_prompt_length"]
68+
+ rollout_config["total_generation_steps"]
69+
+ 256
70+
)
71+
72+
return base_rollout.RolloutConfig(**filtered_config)
5473

5574
def create_role_to_mesh(self):
5675
default_mesh = self.create_mesh("actor_model_config")

0 commit comments

Comments
 (0)