|
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | 15 | """Main entry point for GRPO training.""" |
| 16 | + |
| 17 | +import dataclasses |
16 | 18 | from absl import app |
17 | 19 | from absl import flags |
18 | 20 | from absl import logging |
@@ -41,16 +43,33 @@ class GrpoPipeline(config.HyperParameters): |
41 | 43 |
|
42 | 44 | def create_rollout_config(self): |
43 | 45 | rollout_config = self.config["rollout_config"] |
44 | | - return base_rollout.RolloutConfig( |
45 | | - max_tokens_to_generate=rollout_config["total_generation_steps"], |
46 | | - max_prompt_length=rollout_config["max_prompt_length"], |
47 | | - kv_cache_size=rollout_config["max_prompt_length"] |
48 | | - + rollout_config["total_generation_steps"] |
49 | | - + 256, |
50 | | - temperature=rollout_config["temperature"], |
51 | | - top_p=rollout_config["top_p"], |
52 | | - top_k=rollout_config["top_k"], |
53 | | - ) |
| 46 | + |
| 47 | + # Get all valid field names from RolloutConfig |
| 48 | + valid_fields = { |
| 49 | + f.name for f in dataclasses.fields(base_rollout.RolloutConfig) |
| 50 | + } |
| 51 | + |
| 52 | + # Filter rollout_config to only include valid keys |
| 53 | + filtered_config = { |
| 54 | + k: v for k, v in rollout_config.items() if k in valid_fields |
| 55 | + } |
| 56 | + |
| 57 | + # Apply explicit recomputed/renamed values |
| 58 | + if "total_generation_steps" in rollout_config: |
| 59 | + filtered_config["max_tokens_to_generate"] = rollout_config[ |
| 60 | + "total_generation_steps" |
| 61 | + ] |
| 62 | + if ( |
| 63 | + "max_prompt_length" in rollout_config |
| 64 | + and "total_generation_steps" in rollout_config |
| 65 | + ): |
| 66 | + filtered_config["kv_cache_size"] = ( |
| 67 | + rollout_config["max_prompt_length"] |
| 68 | + + rollout_config["total_generation_steps"] |
| 69 | + + 256 |
| 70 | + ) |
| 71 | + |
| 72 | + return base_rollout.RolloutConfig(**filtered_config) |
54 | 73 |
|
55 | 74 | def create_role_to_mesh(self): |
56 | 75 | default_mesh = self.create_mesh("actor_model_config") |
|
0 commit comments