Skip to content

Commit 52d02f1

Browse files
committed
fix update weights with placeholder
1 parent ead3923 commit 52d02f1

File tree

8 files changed

+60
-26
lines changed

8 files changed

+60
-26
lines changed

.github/workflows/pr-test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ jobs:
4848
strategy:
4949
fail-fast: false
5050
matrix:
51-
info: [{"num_gpus": 4, "test_file": "test_qwen2.5_0.5B_gsm8k_async_short.py"}, {"num_gpus": 4, "test_file": "test_qwen2.5_0.5B_gsm8k_short.py"}]
51+
info: [{"num_gpus": 4, "test_file": "test_qwen2.5_0.5B_gsm8k_async_short.py"}, {"num_gpus": 4, "test_file": "test_qwen2.5_0.5B_gsm8k_short.py"}, {"num_gpus": 4, "test_file": "test_qwen2.5_0.5B_sglang_config.py"}]
5252
defaults:
5353
run:
5454
working-directory: ${{ github.workspace }}

.github/workflows/pr-test.yml.j2

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
'tests': [
55
{'test_file': 'test_qwen2.5_0.5B_gsm8k_async_short.py', 'num_gpus': 4},
66
{'test_file': 'test_qwen2.5_0.5B_gsm8k_short.py', 'num_gpus': 4},
7+
{'test_file': 'test_qwen2.5_0.5B_sglang_config.py', 'num_gpus': 4},
78
],
89
},
910
'e2e-test-fsdp': {

slime/backends/fsdp_utils/actor.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -731,14 +731,15 @@ def update_weights(self) -> None: # type: ignore[override]
731731
if self.args.debug_train_only or self.args.debug_rollout_only:
732732
return
733733

734-
rollout_engines, rollout_engine_lock, num_new_engines, engine_gpu_counts = ray.get(
734+
rollout_engines, rollout_engine_lock, num_new_engines, engine_gpu_counts, engine_gpu_offsets = ray.get(
735735
self.rollout_manager.get_rollout_engines_and_lock.remote()
736736
)
737737
if num_new_engines > 0:
738738
self.weight_updater.connect_rollout_engines(
739739
rollout_engines,
740740
rollout_engine_lock,
741741
engine_gpu_counts=engine_gpu_counts,
742+
engine_gpu_offsets=engine_gpu_offsets,
742743
)
743744
dist.barrier(group=get_gloo_group())
744745
if dist.get_rank() == 0:

slime/backends/fsdp_utils/update_weight_utils.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def connect_rollout_engines(
4141
rollout_engines: Sequence[ActorHandle],
4242
rollout_engine_lock: ActorHandle | None,
4343
engine_gpu_counts: Sequence[int] | None = None,
44+
engine_gpu_offsets: Sequence[int] | None = None,
4445
) -> None:
4546
pass
4647

@@ -94,6 +95,7 @@ def connect_rollout_engines(
9495
rollout_engines: Sequence[ActorHandle],
9596
rollout_engine_lock: ActorHandle | None,
9697
engine_gpu_counts: Sequence[int] | None = None,
98+
engine_gpu_offsets: Sequence[int] | None = None,
9799
) -> None:
98100
"""Attach rollout engines and create per-engine IPC (Gloo) groups.
99101
@@ -104,15 +106,17 @@ def connect_rollout_engines(
104106

105107
if engine_gpu_counts is None:
106108
engine_gpu_counts = [self.args.rollout_num_gpus_per_engine] * len(rollout_engines)
107-
108-
# Cumulative rank offsets for (potentially) non-uniform engine groups.
109-
cumulative = [0]
110-
for c in engine_gpu_counts:
111-
cumulative.append(cumulative[-1] + c)
109+
if engine_gpu_offsets is None:
110+
# Fallback: assume engines are densely packed (no placeholder gaps).
111+
engine_gpu_offsets = []
112+
offset = 0
113+
for c in engine_gpu_counts:
114+
engine_gpu_offsets.append(offset)
115+
offset += c
112116

113117
for i, engine in enumerate(self.rollout_engines):
114-
start_rank = cumulative[i]
115-
end_rank = cumulative[i + 1]
118+
start_rank = engine_gpu_offsets[i]
119+
end_rank = start_rank + engine_gpu_counts[i]
116120
group_ranks = list(range(start_rank, end_rank))
117121
new_group = dist.new_group(
118122
ranks=group_ranks,
@@ -191,6 +195,7 @@ def connect_rollout_engines(
191195
rollout_engines: Sequence[ActorHandle],
192196
rollout_engine_lock: ActorHandle | None,
193197
engine_gpu_counts: Sequence[int] | None = None,
198+
engine_gpu_offsets: Sequence[int] | None = None,
194199
) -> None:
195200
"""On rank 0, initialize a temporary NCCL group for parameter broadcast."""
196201
self.rollout_engines = rollout_engines

slime/backends/megatron_utils/actor.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -539,7 +539,7 @@ def update_weights(self) -> None:
539539
ray.get(self.rollout_manager.recover_rollout_engines.remote())
540540
dist.barrier(group=get_gloo_group())
541541

542-
rollout_engines, rollout_engine_lock, num_new_engines, engine_gpu_counts = ray.get(
542+
rollout_engines, rollout_engine_lock, num_new_engines, engine_gpu_counts, engine_gpu_offsets = ray.get(
543543
self.rollout_manager.get_rollout_engines_and_lock.remote()
544544
)
545545

@@ -551,6 +551,7 @@ def update_weights(self) -> None:
551551
rollout_engines,
552552
rollout_engine_lock,
553553
engine_gpu_counts=engine_gpu_counts,
554+
engine_gpu_offsets=engine_gpu_offsets,
554555
)
555556
dist.barrier(group=get_gloo_group())
556557
if dist.get_rank() == 0:

slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ def connect_rollout_engines(
4747
rollout_engines: Sequence[ActorHandle],
4848
rollout_engine_lock: ActorHandle,
4949
engine_gpu_counts: Sequence[int] | None = None,
50+
engine_gpu_offsets: Sequence[int] | None = None,
5051
) -> None:
5152
"""
5253
Create NCCL "slime-pp_{pp_rank}" if PP source (DP=TP=0). Lock prevents concurrent broadcasts.

slime/backends/megatron_utils/update_weight/update_weight_from_tensor.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ def connect_rollout_engines(
6363
rollout_engines: Sequence[ActorHandle],
6464
rollout_engine_lock: ActorHandle,
6565
engine_gpu_counts: Sequence[int] | None = None,
66+
engine_gpu_offsets: Sequence[int] | None = None,
6667
) -> None:
6768
"""
6869
Split colocated/distributed engines. Global source rank (DP=TP=PP=0) creates NCCL
@@ -72,15 +73,20 @@ def connect_rollout_engines(
7273

7374
if engine_gpu_counts is None:
7475
engine_gpu_counts = [self.args.rollout_num_gpus_per_engine] * len(rollout_engines)
75-
76-
# Compute colocated engine count from cumulative GPU budget.
76+
if engine_gpu_offsets is None:
77+
# Fallback: assume engines are densely packed (no placeholder gaps).
78+
engine_gpu_offsets = []
79+
offset = 0
80+
for c in engine_gpu_counts:
81+
engine_gpu_offsets.append(offset)
82+
offset += c
83+
84+
# Compute colocated engine count: engines whose GPUs fall within actor GPU range.
7785
total_actor_gpus = self.args.actor_num_nodes * self.args.actor_num_gpus_per_node
7886
colocate_engine_nums = 0
79-
gpu_sum = 0
80-
for c in engine_gpu_counts:
81-
if gpu_sum + c > total_actor_gpus:
87+
for gpu_offset, gpu_count in zip(engine_gpu_offsets, engine_gpu_counts, strict=True):
88+
if gpu_offset + gpu_count > total_actor_gpus:
8289
break
83-
gpu_sum += c
8490
colocate_engine_nums += 1
8591

8692
self.use_distribute = len(rollout_engines) > colocate_engine_nums
@@ -108,25 +114,24 @@ def connect_rollout_engines(
108114
engine_gpu_counts=distributed_gpu_counts,
109115
)
110116

111-
# Cumulative rank offsets for (potentially) non-uniform colocated groups.
117+
colocate_gpu_offsets = engine_gpu_offsets[:colocate_engine_nums]
112118
colocate_gpu_counts = engine_gpu_counts[:colocate_engine_nums]
113-
cumulative = [0]
114-
for c in colocate_gpu_counts:
115-
cumulative.append(cumulative[-1] + c)
116119

117120
# Create IPC Gloo gather groups (only on first call; partitioning is
118121
# fixed across reconnects).
119122
if self._ipc_gather_group is None:
120123
for i in range(colocate_engine_nums):
121-
group_ranks = list(range(cumulative[i], cumulative[i + 1]))
124+
group_ranks = list(range(colocate_gpu_offsets[i], colocate_gpu_offsets[i] + colocate_gpu_counts[i]))
122125
new_group = dist.new_group(ranks=group_ranks, backend="gloo")
123126
if dist.get_rank() in group_ranks:
124127
self._ipc_gather_group = new_group
125-
self._ipc_gather_src = cumulative[i]
128+
self._ipc_gather_src = colocate_gpu_offsets[i]
126129

127130
# Map training ranks to colocated engine actors.
128131
for i, engine in enumerate(self.rollout_engines):
129-
if cumulative[i] <= dist.get_rank() < cumulative[i + 1]:
132+
start = colocate_gpu_offsets[i]
133+
end = start + colocate_gpu_counts[i]
134+
if start <= dist.get_rank() < end:
130135
self._ipc_engine = engine
131136

132137
@torch.no_grad()

slime/ray/rollout.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,18 @@ def engine_gpu_counts(self) -> list[int]:
362362
"""Per-engine GPU count for all node-0 engines, parallel to ``engines``."""
363363
return [g.num_gpus_per_engine for g in self.engine_groups for _ in g.engines]
364364

365+
@property
366+
def engine_gpu_offsets(self) -> list[int]:
367+
"""Per-engine GPU offset for all node-0 engines, parallel to ``engines``.
368+
369+
Accounts for placeholder groups that occupy GPU slots without creating engines.
370+
"""
371+
offsets = []
372+
for g in self.engine_groups:
373+
for j in range(len(g.engines)):
374+
offsets.append(g.gpu_offset + j * g.num_gpus_per_engine)
375+
return offsets
376+
365377
@property
366378
def nodes_per_engine(self):
367379
"""Nodes per engine. Only valid when all active groups share the same value."""
@@ -505,8 +517,9 @@ def get_rollout_engines_and_lock(self, model_name: str | None = None):
505517
srv = self._get_server(model_name)
506518
engines = srv.engines if srv else []
507519
gpu_counts = srv.engine_gpu_counts if srv else []
520+
gpu_offsets = srv.engine_gpu_offsets if srv else []
508521
num_new = srv.num_new_engines if srv else 0
509-
return engines, self.rollout_engine_lock, num_new, gpu_counts
522+
return engines, self.rollout_engine_lock, num_new, gpu_counts, gpu_offsets
510523

511524
def get_num_rollout_per_epoch(self):
512525
assert self.args.rollout_global_dataset
@@ -566,10 +579,17 @@ def recover_rollout_engines(self, model_name: str | None = None):
566579
if self.rollout_id == -1 or srv is None:
567580
engines = srv.engines if srv else []
568581
gpu_counts = srv.engine_gpu_counts if srv else []
569-
return engines, self.rollout_engine_lock, (srv.num_new_engines if srv else 0), gpu_counts
582+
gpu_offsets = srv.engine_gpu_offsets if srv else []
583+
return engines, self.rollout_engine_lock, (srv.num_new_engines if srv else 0), gpu_counts, gpu_offsets
570584

571585
srv.recover()
572-
return srv.engines, self.rollout_engine_lock, srv.num_new_engines, srv.engine_gpu_counts
586+
return (
587+
srv.engines,
588+
self.rollout_engine_lock,
589+
srv.num_new_engines,
590+
srv.engine_gpu_counts,
591+
srv.engine_gpu_offsets,
592+
)
573593

574594
def clear_num_new_engines(self, model_name: str | None = None):
575595
# when fault tolerance is not enabled, we need to manually clear num_new_engines after update_weights

0 commit comments

Comments (0)