@@ -73,8 +73,8 @@ class GrpoPipeline(config.HyperParameters):
7373 plus ``max_turns``, ``context_ratio``, ``per_turn_timeout_secs``.
7474 * role-specific ``*_model_config.mesh``: any role with an explicit mesh gets
7575 its own device slice; omitted meshes share the actor mesh by default.
76- * role-specific ``same_mesh_as ``: optional mesh sharing like
77- ``reference_model_config.same_mesh_as: actor`` .
76+ * role-specific ``colocate_with``: share another role's device set while
77+ still allowing a different mesh shape on that same device set.
7878 * ``sglang_jax_config`` / ``vllm_config``: engine-specific rollout params.
7979 * ``chat_parser_config.type``: ``"default"`` or ``"qwen"``.
8080 * ``agent_class_path`` / ``env_class_path``: dotted Python paths to load
@@ -116,21 +116,19 @@ def _resolve_split_role(self, role_name: str) -> rl_cluster_lib.Role:
116116 )
117117 return self ._SPLIT_ROLE_ALIASES [normalized ]
118118
119- def _get_same_mesh_as_map (
119+ def _get_colocate_with_map (
120120 self ,
121121 ) -> dict [rl_cluster_lib .Role , rl_cluster_lib .Role ]:
122- same_mesh_as = {}
122+ colocate_with = {}
123123 for role , model_key in self ._ROLE_TO_MODEL_KEY .items ():
124124 model_cfg = self .config .get (model_key , {}) or {}
125- target_name = model_cfg .get ("same_mesh_as " )
125+ target_name = model_cfg .get ("colocate_with " )
126126 if target_name is None :
127127 continue
128- target_role = self ._resolve_split_role (str (target_name ))
129128 if role == rl_cluster_lib .Role .ACTOR :
130- raise ValueError ("Actor must own its mesh." )
131- same_mesh_as [role ] = target_role
132-
133- return same_mesh_as
129+ raise ValueError ("Actor must own its device set." )
130+ colocate_with [role ] = self ._resolve_split_role (str (target_name ))
131+ return colocate_with
134132
135133 def _is_role_active (self , role : rl_cluster_lib .Role ) -> bool :
136134 if role in (
@@ -145,10 +143,10 @@ def _is_role_active(self, role: rl_cluster_lib.Role) -> bool:
145143 def _resolve_mesh_owners (
146144 self ,
147145 ) -> dict [rl_cluster_lib .Role , rl_cluster_lib .Role ]:
148- same_mesh_as = self ._get_same_mesh_as_map ()
146+ colocate_with = self ._get_colocate_with_map ()
149147 base_owners = {}
150148 for role , model_key in self ._ROLE_TO_MODEL_KEY .items ():
151- if not self ._is_role_active (role ) and role not in same_mesh_as :
149+ if not self ._is_role_active (role ):
152150 continue
153151 has_mesh = bool (self .config .get (model_key , {}).get ("mesh" ))
154152 base_owners [role ] = (
@@ -162,35 +160,28 @@ def resolve_owner(
162160 seen : set [rl_cluster_lib .Role ],
163161 ) -> rl_cluster_lib .Role :
164162 if role in seen :
165- raise ValueError ("same_mesh_as contains a cycle." )
166- if role not in same_mesh_as :
163+ raise ValueError ("colocate_with contains a cycle." )
164+ if role not in colocate_with :
167165 return base_owners [role ]
168166 seen .add (role )
169- target_role = same_mesh_as [role ]
167+ target_role = colocate_with [role ]
170168 if target_role not in base_owners :
171169 raise ValueError (
172170 f"Role { target_role .value !r} is not active in this config."
173171 )
174172 return resolve_owner (target_role , seen )
175173
176174 role_to_owner = {}
177- for role , model_key in self ._ROLE_TO_MODEL_KEY .items ():
178- if role not in base_owners :
179- continue
180- has_mesh = bool (self .config .get (model_key , {}).get ("mesh" ))
181- if role in same_mesh_as :
182- if has_mesh :
183- raise ValueError (
184- f"{ model_key } .mesh is specified, so it must own a separate mesh "
185- "and cannot also use same_mesh_as."
186- )
187- else :
188- role_to_owner [role ] = resolve_owner (role , set ())
189- continue
175+ for role in base_owners :
190176 role_to_owner [role ] = resolve_owner (role , set ())
191177 return role_to_owner
192178
193- def _create_role_to_mesh (self ):
179+ def create_role_to_mesh (self ):
180+ """Build role→mesh mapping.
181+
182+ Roles with ``*.mesh`` own a device slice; the rest share the actor mesh.
183+ ``colocate_with`` overrides both: the role runs on its target's device set.
184+ """
194185 devices = list (jax .devices ())
195186 role_to_owner = self ._resolve_mesh_owners ()
196187 owner_order = []
@@ -235,16 +226,18 @@ def _create_role_to_mesh(self):
235226 for owner in owner_order
236227 },
237228 )
238- return {role : owner_to_mesh [owner ] for role , owner in role_to_owner .items ()}
239-
240- def create_role_to_mesh (self ):
241- """Build role→mesh mapping.
229+ role_to_mesh = {}
230+ for role , owner in role_to_owner .items ():
231+ model_key = self ._ROLE_TO_MODEL_KEY [role ]
232+ has_mesh = bool (self .config .get (model_key , {}).get ("mesh" ))
233+ if role == owner or not has_mesh :
234+ role_to_mesh [role ] = owner_to_mesh [owner ]
235+ else :
236+ role_to_mesh [role ] = self .create_mesh (
237+ model_key , devices = owner_to_device_slice [owner ]
238+ )
239+ return role_to_mesh
242240
243- Any role with an explicit ``*.mesh`` config gets a dedicated device slice.
244- Roles without a mesh share the actor mesh by default, or can point at
245- another role via ``same_mesh_as``.
246- """
247- return self ._create_role_to_mesh ()
248241
249242 # ------------------------------------------------------------------
250243 # Rollout config
0 commit comments