@@ -63,6 +63,7 @@ def connect_rollout_engines(
6363 rollout_engines : Sequence [ActorHandle ],
6464 rollout_engine_lock : ActorHandle ,
6565 engine_gpu_counts : Sequence [int ] | None = None ,
66+ engine_gpu_offsets : Sequence [int ] | None = None ,
6667 ) -> None :
6768 """
6869 Split colocated/distributed engines. Global source rank (DP=TP=PP=0) creates NCCL
@@ -72,15 +73,20 @@ def connect_rollout_engines(
7273
7374 if engine_gpu_counts is None :
7475 engine_gpu_counts = [self .args .rollout_num_gpus_per_engine ] * len (rollout_engines )
75-
76- # Compute colocated engine count from cumulative GPU budget.
76+ if engine_gpu_offsets is None :
77+ # Fallback: assume engines are densely packed (no placeholder gaps).
78+ engine_gpu_offsets = []
79+ offset = 0
80+ for c in engine_gpu_counts :
81+ engine_gpu_offsets .append (offset )
82+ offset += c
83+
84+ # Compute colocated engine count: engines whose GPUs fall within actor GPU range.
7785 total_actor_gpus = self .args .actor_num_nodes * self .args .actor_num_gpus_per_node
7886 colocate_engine_nums = 0
79- gpu_sum = 0
80- for c in engine_gpu_counts :
81- if gpu_sum + c > total_actor_gpus :
87+ for gpu_offset , gpu_count in zip (engine_gpu_offsets , engine_gpu_counts , strict = True ):
88+ if gpu_offset + gpu_count > total_actor_gpus :
8289 break
83- gpu_sum += c
8490 colocate_engine_nums += 1
8591
8692 self .use_distribute = len (rollout_engines ) > colocate_engine_nums
@@ -108,25 +114,24 @@ def connect_rollout_engines(
108114 engine_gpu_counts = distributed_gpu_counts ,
109115 )
110116
111- # Cumulative rank offsets for (potentially) non-uniform colocated groups.
117+ colocate_gpu_offsets = engine_gpu_offsets [: colocate_engine_nums ]
112118 colocate_gpu_counts = engine_gpu_counts [:colocate_engine_nums ]
113- cumulative = [0 ]
114- for c in colocate_gpu_counts :
115- cumulative .append (cumulative [- 1 ] + c )
116119
117120 # Create IPC Gloo gather groups (only on first call; partitioning is
118121 # fixed across reconnects).
119122 if self ._ipc_gather_group is None :
120123 for i in range (colocate_engine_nums ):
121- group_ranks = list (range (cumulative [i ], cumulative [ i + 1 ]))
124+ group_ranks = list (range (colocate_gpu_offsets [i ], colocate_gpu_offsets [ i ] + colocate_gpu_counts [ i ]))
122125 new_group = dist .new_group (ranks = group_ranks , backend = "gloo" )
123126 if dist .get_rank () in group_ranks :
124127 self ._ipc_gather_group = new_group
125- self ._ipc_gather_src = cumulative [i ]
128+ self ._ipc_gather_src = colocate_gpu_offsets [i ]
126129
127130 # Map training ranks to colocated engine actors.
128131 for i , engine in enumerate (self .rollout_engines ):
129- if cumulative [i ] <= dist .get_rank () < cumulative [i + 1 ]:
132+ start = colocate_gpu_offsets [i ]
133+ end = start + colocate_gpu_counts [i ]
134+ if start <= dist .get_rank () < end :
130135 self ._ipc_engine = engine
131136
132137 @torch .no_grad ()
0 commit comments