@@ -40,21 +40,23 @@ class EngineGroupConfig:
4040 """Configuration for a single engine group.
4141
4242 Attributes:
43- role : One of "regular", "prefill", "decode", or "placeholder".
44- "placeholder" reserves GPU slots without creating engines.
43+ worker_type : One of "regular", "prefill", "decode", or "placeholder".
44+ "placeholder" reserves GPU slots without creating engines.
4545 num_gpus: Total number of GPUs for this group.
4646 overrides: Optional dict of SGLang ``ServerArgs`` field overrides.
4747 These are applied on top of the base CLI ``--sglang-*``
4848 arguments in ``_compute_server_args``.
4949 """
5050
51- role : str
51+ worker_type : str
5252 num_gpus : int
5353 overrides : dict = dataclasses .field (default_factory = dict )
5454
5555 def __post_init__ (self ):
56- valid_roles = {"regular" , "prefill" , "decode" , "placeholder" }
57- assert self .role in valid_roles , f"Invalid role '{ self .role } ', must be one of { valid_roles } "
56+ valid_types = {"regular" , "prefill" , "decode" , "placeholder" }
57+ assert (
58+ self .worker_type in valid_types
59+ ), f"Invalid worker_type '{ self .worker_type } ', must be one of { valid_types } "
5860 assert self .num_gpus > 0 , f"num_gpus must be > 0, got { self .num_gpus } "
5961
6062
@@ -65,14 +67,14 @@ class SglangConfig:
6567 Loaded from ``--sglang-config`` YAML file::
6668
6769 engine_groups:
68- - role : prefill
70+ - worker_type : prefill
6971 num_gpus: 4
7072 overrides:
7173 mem_fraction_static: 0.9
7274 chunked_prefill_size: 8192
73- - role : placeholder
75+ - worker_type : placeholder
7476 num_gpus: 2
75- - role : decode
77+ - worker_type : decode
7678 num_gpus: 10
7779 overrides:
7880 mem_fraction_static: 0.7
@@ -104,15 +106,15 @@ def from_prefill_num_servers(args) -> "SglangConfig":
104106 assert decode_gpus > 0 , f"No decode GPUs: total { total_gpus } , prefill { prefill_gpus } "
105107 return SglangConfig (
106108 engine_groups = [
107- EngineGroupConfig (role = "prefill" , num_gpus = prefill_gpus ),
108- EngineGroupConfig (role = "decode" , num_gpus = decode_gpus ),
109+ EngineGroupConfig (worker_type = "prefill" , num_gpus = prefill_gpus ),
110+ EngineGroupConfig (worker_type = "decode" , num_gpus = decode_gpus ),
109111 ]
110112 )
111113
112114 @property
113115 def has_pd_disaggregation (self ) -> bool :
114- """Whether any group uses prefill or decode roles ."""
115- return any (g .role in ("prefill" , "decode" ) for g in self .engine_groups )
116+ """Whether any group uses prefill or decode worker_type ."""
117+ return any (g .worker_type in ("prefill" , "decode" ) for g in self .engine_groups )
116118
117119 @property
118120 def total_num_gpus (self ) -> int :
@@ -133,7 +135,7 @@ class EngineGroup:
133135 all_engines : list
134136 nodes_per_engine : int
135137 num_new_engines : int
136- role : str = "regular" # "regular", "prefill", or "decode"
138+ worker_type : str = "regular" # "regular", "prefill", "decode", or "placeholder"
137139 rank_offset : int = 0 # global rank of the first engine in this group
138140 sglang_overrides : dict = dataclasses .field (default_factory = dict )
139141
@@ -148,9 +150,9 @@ def start_engines(self) -> list:
148150 Returns a list of Ray ObjectRefs for the init calls. The caller
149151 should ``ray.get()`` on them to block until the engines are healthy.
150152
151- Placeholder groups (role ="placeholder") skip engine creation entirely.
153+ Placeholder groups (worker_type ="placeholder") skip engine creation entirely.
152154 """
153- if self .args .debug_train_only or self .role == "placeholder" :
155+ if self .args .debug_train_only or self .worker_type == "placeholder" :
154156 self .num_new_engines = 0
155157 return []
156158
@@ -202,7 +204,7 @@ def start_engines(self) -> list:
202204 ).remote (
203205 self .args ,
204206 rank = global_rank ,
205- worker_type = self .role ,
207+ worker_type = self .worker_type ,
206208 base_gpu_id = base_gpu_id ,
207209 sglang_overrides = self .sglang_overrides ,
208210 )
@@ -224,7 +226,7 @@ def start_engines(self) -> list:
224226 args = self .args ,
225227 num_engines = total_num_engines ,
226228 rollout_engines = rollout_engines ,
227- worker_type = self .role ,
229+ worker_type = self .worker_type ,
228230 )
229231
230232 init_handles = [engine .init .remote (** (addr_and_ports [rank ])) for rank , engine in rollout_engines ]
@@ -252,11 +254,11 @@ class RolloutServer:
252254 Configured via ``--sglang-config`` YAML::
253255
254256 engine_groups:
255- - role : prefill
257+ - worker_type : prefill
256258 num_gpus: 4
257- - role : placeholder
259+ - worker_type : placeholder
258260 num_gpus: 2
259- - role : decode
261+ - worker_type : decode
260262 num_gpus: 10
261263
262264 Placeholder groups reserve GPU slots without creating engines.
@@ -266,56 +268,49 @@ class RolloutServer:
266268 router_ip : str | None = None
267269 router_port : int | None = None
268270
269- @property
270- def active_engine_groups (self ):
271- """Engine groups that have actual engines (excludes placeholder)."""
272- return [g for g in self .engine_groups if g .role != "placeholder" ]
273-
274271 @property
275272 def engines (self ):
276- """All node-0 engines across all active groups."""
277- return [e for g in self .active_engine_groups for e in g .engines ]
273+ """All node-0 engines across all groups (placeholder groups contribute nothing) ."""
274+ return [e for g in self .engine_groups for e in g .engines ]
278275
279276 @property
280277 def all_engines (self ):
281- """All engines (including non-node-0) across all active groups."""
282- return [e for g in self .active_engine_groups for e in g .all_engines ]
278+ """All engines (including non-node-0) across all groups."""
279+ return [e for g in self .engine_groups for e in g .all_engines ]
283280
284281 @property
285282 def num_new_engines (self ):
286- return sum (g .num_new_engines for g in self .active_engine_groups )
283+ return sum (g .num_new_engines for g in self .engine_groups )
287284
288285 @num_new_engines .setter
289286 def num_new_engines (self , value ):
290- for g in self .active_engine_groups :
287+ for g in self .engine_groups :
291288 g .num_new_engines = value
292289
293290 @property
294291 def nodes_per_engine (self ):
295292 """Nodes per engine. Only valid when all groups share the same value."""
296- values = {g .nodes_per_engine for g in self .active_engine_groups }
293+ values = {g .nodes_per_engine for g in self .engine_groups }
297294 assert len (values ) == 1 , f"Heterogeneous nodes_per_engine: { values } "
298295 return values .pop ()
299296
300297 def recover (self ):
301298 """Recover dead engines across all groups, overlapping init."""
302299 # Record dead indices per group before starting.
303- dead_per_group = [
304- [i for i , engine in enumerate (g .all_engines ) if engine is None ] for g in self .active_engine_groups
305- ]
300+ dead_per_group = [[i for i , engine in enumerate (g .all_engines ) if engine is None ] for g in self .engine_groups ]
306301
307302 # Start all groups concurrently.
308303 all_handles = []
309- for g in self .active_engine_groups :
304+ for g in self .engine_groups :
310305 all_handles .extend (g .start_engines ())
311306 if all_handles :
312307 ray .get (all_handles )
313308
314309 # Post-recovery: offload then onload weights for newly created engines.
315310 release_handles = []
316311 new_engines_all = []
317- for g , dead_indices in zip (self .active_engine_groups , dead_per_group , strict = True ):
318- logger .info (f"Recovered { g .num_new_engines } dead rollout engines (role ={ g .role } )" )
312+ for g , dead_indices in zip (self .engine_groups , dead_per_group , strict = True ):
313+ logger .info (f"Recovered { g .num_new_engines } dead rollout engines (worker_type ={ g .worker_type } )" )
319314 assert g .num_new_engines == len (dead_indices ), "num_new_engines does not match dead_indices length"
320315 if g .args .offload_rollout and dead_indices :
321316 new_engines = [g .all_engines [i ] for i in dead_indices ]
@@ -329,16 +324,16 @@ def recover(self):
329324 )
330325
331326 def offload (self ):
332- """Release memory occupation across all active groups (concurrent)."""
327+ """Release memory occupation across all groups (concurrent)."""
333328 handles = []
334- for g in self .active_engine_groups :
329+ for g in self .engine_groups :
335330 handles .extend (g .offload ())
336331 return ray .get (handles ) if handles else []
337332
338333 def onload (self , tags : list [str ] | None = None ):
339- """Resume memory occupation across all active groups (concurrent)."""
334+ """Resume memory occupation across all groups (concurrent)."""
340335 handles = []
341- for g in self .active_engine_groups :
336+ for g in self .engine_groups :
342337 handles .extend (g .onload (tags ))
343338 return ray .get (handles ) if handles else []
344339
@@ -381,7 +376,7 @@ def __init__(self, args, pg):
381376
382377 self ._health_monitors = []
383378 if not self .args .debug_train_only and self .args .use_fault_tolerance :
384- for group in self .server .active_engine_groups :
379+ for group in self .server .engine_groups :
385380 monitor = RolloutHealthMonitor (group , args )
386381 monitor .start ()
387382 self ._health_monitors .append (monitor )
@@ -919,10 +914,10 @@ def start_rollout_server(args, pg) -> RolloutServer:
919914 group = EngineGroup (
920915 args = args ,
921916 pg = pg ,
922- all_engines = [None ] * num_engines if group_cfg .role != "placeholder" else [],
917+ all_engines = [None ] * num_engines if group_cfg .worker_type != "placeholder" else [],
923918 nodes_per_engine = nodes_per_engine ,
924919 num_new_engines = 0 ,
925- role = group_cfg .role ,
920+ worker_type = group_cfg .worker_type ,
926921 rank_offset = rank_offset ,
927922 sglang_overrides = group_cfg .overrides ,
928923 )
@@ -955,7 +950,7 @@ def _resolve_sglang_config(args) -> SglangConfig:
955950 return SglangConfig .from_prefill_num_servers (args )
956951
957952 # Default: single regular group.
958- return SglangConfig (engine_groups = [EngineGroupConfig (role = "regular" , num_gpus = args .rollout_num_gpus )])
953+ return SglangConfig (engine_groups = [EngineGroupConfig (worker_type = "regular" , num_gpus = args .rollout_num_gpus )])
959954
960955
961956def _log_eval_rollout_data (rollout_id , args , data , extra_metrics : dict [str , Any ] | None = None ):
0 commit comments