Skip to content

Commit bac171c

Browse files
committed
refactor: rename role to worker_type to align with sglang naming
1 parent 647f0e1 commit bac171c

File tree

3 files changed

+75
-81
lines changed

3 files changed

+75
-81
lines changed

slime/backends/sglang_utils/arguments.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,27 @@ def new_add_argument_wrapper(*name_or_flags, **kwargs):
117117
ServerArgs.add_cli_args(parser)
118118
parser.add_argument = old_add_argument
119119

120+
# PD disaggregation / multi-group config
121+
parser.add_argument(
122+
"--prefill-num-servers",
123+
type=int,
124+
default=None,
125+
help="Number of prefill servers for disaggregation.",
126+
)
127+
parser.add_argument(
128+
"--sglang-config",
129+
type=str,
130+
default=None,
131+
help=(
132+
"Path to a YAML config for SGLang engine deployment. "
133+
"Defines engine_groups with worker_type (regular/prefill/decode/placeholder), "
134+
"num_gpus per group, and optional per-group 'overrides' dict of "
135+
"ServerArgs field names that override the base --sglang-* CLI args. "
136+
"Placeholder groups reserve GPU slots without creating engines. "
137+
"Mutually exclusive with --prefill-num-servers."
138+
),
139+
)
140+
120141
return parser
121142

122143

@@ -141,6 +162,19 @@ def validate_args(args):
141162
if getattr(args, "sglang_router_ip", None):
142163
args.sglang_router_ip = _wrap_ipv6(args.sglang_router_ip)
143164

165+
# Mutual-exclusion checks for PD disaggregation / sglang-config.
166+
assert not (
167+
getattr(args, "prefill_num_servers", None) is not None and args.rollout_external
168+
), "prefill_num_servers cannot be set when rollout_external is set."
169+
170+
assert not (
171+
getattr(args, "sglang_config", None) is not None and args.rollout_external
172+
), "sglang_config cannot be set when rollout_external is set."
173+
174+
assert not (
175+
getattr(args, "sglang_config", None) is not None and getattr(args, "prefill_num_servers", None) is not None
176+
), "sglang_config and prefill_num_servers are mutually exclusive. Use engine_groups in the YAML config instead."
177+
144178

145179
def sglang_parse_args():
146180
"""

slime/ray/rollout.py

Lines changed: 41 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -40,21 +40,23 @@ class EngineGroupConfig:
4040
"""Configuration for a single engine group.
4141
4242
Attributes:
43-
role: One of "regular", "prefill", "decode", or "placeholder".
44-
"placeholder" reserves GPU slots without creating engines.
43+
worker_type: One of "regular", "prefill", "decode", or "placeholder".
44+
"placeholder" reserves GPU slots without creating engines.
4545
num_gpus: Total number of GPUs for this group.
4646
overrides: Optional dict of SGLang ``ServerArgs`` field overrides.
4747
These are applied on top of the base CLI ``--sglang-*``
4848
arguments in ``_compute_server_args``.
4949
"""
5050

51-
role: str
51+
worker_type: str
5252
num_gpus: int
5353
overrides: dict = dataclasses.field(default_factory=dict)
5454

5555
def __post_init__(self):
56-
valid_roles = {"regular", "prefill", "decode", "placeholder"}
57-
assert self.role in valid_roles, f"Invalid role '{self.role}', must be one of {valid_roles}"
56+
valid_types = {"regular", "prefill", "decode", "placeholder"}
57+
assert (
58+
self.worker_type in valid_types
59+
), f"Invalid worker_type '{self.worker_type}', must be one of {valid_types}"
5860
assert self.num_gpus > 0, f"num_gpus must be > 0, got {self.num_gpus}"
5961

6062

@@ -65,14 +67,14 @@ class SglangConfig:
6567
Loaded from ``--sglang-config`` YAML file::
6668
6769
engine_groups:
68-
- role: prefill
70+
- worker_type: prefill
6971
num_gpus: 4
7072
overrides:
7173
mem_fraction_static: 0.9
7274
chunked_prefill_size: 8192
73-
- role: placeholder
75+
- worker_type: placeholder
7476
num_gpus: 2
75-
- role: decode
77+
- worker_type: decode
7678
num_gpus: 10
7779
overrides:
7880
mem_fraction_static: 0.7
@@ -104,15 +106,15 @@ def from_prefill_num_servers(args) -> "SglangConfig":
104106
assert decode_gpus > 0, f"No decode GPUs: total {total_gpus}, prefill {prefill_gpus}"
105107
return SglangConfig(
106108
engine_groups=[
107-
EngineGroupConfig(role="prefill", num_gpus=prefill_gpus),
108-
EngineGroupConfig(role="decode", num_gpus=decode_gpus),
109+
EngineGroupConfig(worker_type="prefill", num_gpus=prefill_gpus),
110+
EngineGroupConfig(worker_type="decode", num_gpus=decode_gpus),
109111
]
110112
)
111113

112114
@property
113115
def has_pd_disaggregation(self) -> bool:
114-
"""Whether any group uses prefill or decode roles."""
115-
return any(g.role in ("prefill", "decode") for g in self.engine_groups)
116+
"""Whether any group uses prefill or decode worker_type."""
117+
return any(g.worker_type in ("prefill", "decode") for g in self.engine_groups)
116118

117119
@property
118120
def total_num_gpus(self) -> int:
@@ -133,7 +135,7 @@ class EngineGroup:
133135
all_engines: list
134136
nodes_per_engine: int
135137
num_new_engines: int
136-
role: str = "regular" # "regular", "prefill", or "decode"
138+
worker_type: str = "regular" # "regular", "prefill", "decode", or "placeholder"
137139
rank_offset: int = 0 # global rank of the first engine in this group
138140
sglang_overrides: dict = dataclasses.field(default_factory=dict)
139141

@@ -148,9 +150,9 @@ def start_engines(self) -> list:
148150
Returns a list of Ray ObjectRefs for the init calls. The caller
149151
should ``ray.get()`` on them to block until the engines are healthy.
150152
151-
Placeholder groups (role="placeholder") skip engine creation entirely.
153+
Placeholder groups (worker_type="placeholder") skip engine creation entirely.
152154
"""
153-
if self.args.debug_train_only or self.role == "placeholder":
155+
if self.args.debug_train_only or self.worker_type == "placeholder":
154156
self.num_new_engines = 0
155157
return []
156158

@@ -202,7 +204,7 @@ def start_engines(self) -> list:
202204
).remote(
203205
self.args,
204206
rank=global_rank,
205-
worker_type=self.role,
207+
worker_type=self.worker_type,
206208
base_gpu_id=base_gpu_id,
207209
sglang_overrides=self.sglang_overrides,
208210
)
@@ -224,7 +226,7 @@ def start_engines(self) -> list:
224226
args=self.args,
225227
num_engines=total_num_engines,
226228
rollout_engines=rollout_engines,
227-
worker_type=self.role,
229+
worker_type=self.worker_type,
228230
)
229231

230232
init_handles = [engine.init.remote(**(addr_and_ports[rank])) for rank, engine in rollout_engines]
@@ -252,11 +254,11 @@ class RolloutServer:
252254
Configured via ``--sglang-config`` YAML::
253255
254256
engine_groups:
255-
- role: prefill
257+
- worker_type: prefill
256258
num_gpus: 4
257-
- role: placeholder
259+
- worker_type: placeholder
258260
num_gpus: 2
259-
- role: decode
261+
- worker_type: decode
260262
num_gpus: 10
261263
262264
Placeholder groups reserve GPU slots without creating engines.
@@ -266,56 +268,49 @@ class RolloutServer:
266268
router_ip: str | None = None
267269
router_port: int | None = None
268270

269-
@property
270-
def active_engine_groups(self):
271-
"""Engine groups that have actual engines (excludes placeholder)."""
272-
return [g for g in self.engine_groups if g.role != "placeholder"]
273-
274271
@property
275272
def engines(self):
276-
"""All node-0 engines across all active groups."""
277-
return [e for g in self.active_engine_groups for e in g.engines]
273+
"""All node-0 engines across all groups (placeholder groups contribute nothing)."""
274+
return [e for g in self.engine_groups for e in g.engines]
278275

279276
@property
280277
def all_engines(self):
281-
"""All engines (including non-node-0) across all active groups."""
282-
return [e for g in self.active_engine_groups for e in g.all_engines]
278+
"""All engines (including non-node-0) across all groups."""
279+
return [e for g in self.engine_groups for e in g.all_engines]
283280

284281
@property
285282
def num_new_engines(self):
286-
return sum(g.num_new_engines for g in self.active_engine_groups)
283+
return sum(g.num_new_engines for g in self.engine_groups)
287284

288285
@num_new_engines.setter
289286
def num_new_engines(self, value):
290-
for g in self.active_engine_groups:
287+
for g in self.engine_groups:
291288
g.num_new_engines = value
292289

293290
@property
294291
def nodes_per_engine(self):
295292
"""Nodes per engine. Only valid when all groups (including placeholder groups) share the same value."""
296-
values = {g.nodes_per_engine for g in self.active_engine_groups}
293+
values = {g.nodes_per_engine for g in self.engine_groups}
297294
assert len(values) == 1, f"Heterogeneous nodes_per_engine: {values}"
298295
return values.pop()
299296

300297
def recover(self):
301298
"""Recover dead engines across all groups, overlapping init."""
302299
# Record dead indices per group before starting.
303-
dead_per_group = [
304-
[i for i, engine in enumerate(g.all_engines) if engine is None] for g in self.active_engine_groups
305-
]
300+
dead_per_group = [[i for i, engine in enumerate(g.all_engines) if engine is None] for g in self.engine_groups]
306301

307302
# Start all groups concurrently.
308303
all_handles = []
309-
for g in self.active_engine_groups:
304+
for g in self.engine_groups:
310305
all_handles.extend(g.start_engines())
311306
if all_handles:
312307
ray.get(all_handles)
313308

314309
# Post-recovery: offload then onload weights for newly created engines.
315310
release_handles = []
316311
new_engines_all = []
317-
for g, dead_indices in zip(self.active_engine_groups, dead_per_group, strict=True):
318-
logger.info(f"Recovered {g.num_new_engines} dead rollout engines (role={g.role})")
312+
for g, dead_indices in zip(self.engine_groups, dead_per_group, strict=True):
313+
logger.info(f"Recovered {g.num_new_engines} dead rollout engines (worker_type={g.worker_type})")
319314
assert g.num_new_engines == len(dead_indices), "num_new_engines does not match dead_indices length"
320315
if g.args.offload_rollout and dead_indices:
321316
new_engines = [g.all_engines[i] for i in dead_indices]
@@ -329,16 +324,16 @@ def recover(self):
329324
)
330325

331326
def offload(self):
332-
"""Release memory occupation across all active groups (concurrent)."""
327+
"""Release memory occupation across all groups (concurrent)."""
333328
handles = []
334-
for g in self.active_engine_groups:
329+
for g in self.engine_groups:
335330
handles.extend(g.offload())
336331
return ray.get(handles) if handles else []
337332

338333
def onload(self, tags: list[str] | None = None):
339-
"""Resume memory occupation across all active groups (concurrent)."""
334+
"""Resume memory occupation across all groups (concurrent)."""
340335
handles = []
341-
for g in self.active_engine_groups:
336+
for g in self.engine_groups:
342337
handles.extend(g.onload(tags))
343338
return ray.get(handles) if handles else []
344339

@@ -381,7 +376,7 @@ def __init__(self, args, pg):
381376

382377
self._health_monitors = []
383378
if not self.args.debug_train_only and self.args.use_fault_tolerance:
384-
for group in self.server.active_engine_groups:
379+
for group in self.server.engine_groups:
385380
monitor = RolloutHealthMonitor(group, args)
386381
monitor.start()
387382
self._health_monitors.append(monitor)
@@ -919,10 +914,10 @@ def start_rollout_server(args, pg) -> RolloutServer:
919914
group = EngineGroup(
920915
args=args,
921916
pg=pg,
922-
all_engines=[None] * num_engines if group_cfg.role != "placeholder" else [],
917+
all_engines=[None] * num_engines if group_cfg.worker_type != "placeholder" else [],
923918
nodes_per_engine=nodes_per_engine,
924919
num_new_engines=0,
925-
role=group_cfg.role,
920+
worker_type=group_cfg.worker_type,
926921
rank_offset=rank_offset,
927922
sglang_overrides=group_cfg.overrides,
928923
)
@@ -955,7 +950,7 @@ def _resolve_sglang_config(args) -> SglangConfig:
955950
return SglangConfig.from_prefill_num_servers(args)
956951

957952
# Default: single regular group.
958-
return SglangConfig(engine_groups=[EngineGroupConfig(role="regular", num_gpus=args.rollout_num_gpus)])
953+
return SglangConfig(engine_groups=[EngineGroupConfig(worker_type="regular", num_gpus=args.rollout_num_gpus)])
959954

960955

961956
def _log_eval_rollout_data(rollout_id, args, data, extra_metrics: dict[str, Any] | None = None):

slime/utils/arguments.py

Lines changed: 0 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1361,28 +1361,6 @@ def add_mtp_training_arguments(parser):
13611361

13621362
return parser
13631363

1364-
def add_prefill_decode_disaggregation_arguments(parser):
1365-
parser.add_argument(
1366-
"--prefill-num-servers",
1367-
type=int,
1368-
default=None,
1369-
help="Number of prefill servers for disaggregation.",
1370-
)
1371-
parser.add_argument(
1372-
"--sglang-config",
1373-
type=str,
1374-
default=None,
1375-
help=(
1376-
"Path to a YAML config for SGLang engine deployment. "
1377-
"Defines engine_groups with roles (regular/prefill/decode/placeholder), "
1378-
"num_gpus per group, and optional per-group 'overrides' dict of "
1379-
"ServerArgs field names that override the base --sglang-* CLI args. "
1380-
"Placeholder groups reserve GPU slots without creating engines. "
1381-
"Mutually exclusive with --prefill-num-servers."
1382-
),
1383-
)
1384-
return parser
1385-
13861364
def add_ci_arguments(parser):
13871365
parser.add_argument(
13881366
"--ci-test",
@@ -1424,7 +1402,6 @@ def add_ci_arguments(parser):
14241402
parser = add_reward_model_arguments(parser)
14251403
parser = add_rollout_buffer_arguments(parser)
14261404
parser = add_mtp_training_arguments(parser)
1427-
parser = add_prefill_decode_disaggregation_arguments(parser)
14281405
parser = add_ci_arguments(parser)
14291406
parser = add_custom_megatron_plugins_arguments(parser)
14301407
reset_arg(
@@ -1796,18 +1773,6 @@ def slime_validate_args(args):
17961773
args.rollout_max_prompt_len <= args.rollout_max_context_len - 1
17971774
), f"args.rollout_max_prompt_len ({args.rollout_max_prompt_len}) must be smaller than args.rollout_max_context_len ({args.rollout_max_context_len}) so that there is at least one generated token to compute loss."
17981775

1799-
assert not (
1800-
args.prefill_num_servers is not None and args.rollout_external
1801-
), "prefill_num_servers cannot be set when rollout_external is set."
1802-
1803-
assert not (
1804-
getattr(args, "sglang_config", None) is not None and args.rollout_external
1805-
), "sglang_config cannot be set when rollout_external is set."
1806-
1807-
assert not (
1808-
getattr(args, "sglang_config", None) is not None and args.prefill_num_servers is not None
1809-
), "sglang_config and prefill_num_servers are mutually exclusive. Use engine_groups in the YAML config instead."
1810-
18111776
if args.qkv_format == "bshd":
18121777
assert args.train_backend == "megatron", "bshd format is only supported for megatron backend."
18131778
assert (

0 commit comments

Comments
 (0)