@@ -78,13 +78,20 @@ def __init__(
7878 self .session_id = f"{ hostname } :{ rpc_port } "
7979 self .hostname = hostname
8080
81- self .buf = torch .zeros (self .bucket_size , dtype = torch .uint8 , device = self .device )
82- assert self .engine .register_memory (self .buf .data_ptr (), self .bucket_size ) == 0 , "register_memory failed"
81+ self .buf = torch .empty (2 * self .bucket_size , dtype = torch .uint8 , device = self .device )
82+ self .magic_buf = torch .empty (4 * 1024 , dtype = torch .uint8 , device = self .device )
83+ ret = self .engine .batch_register_memory (
84+ [self .buf .data_ptr (), self .magic_buf .data_ptr ()],
85+ [2 * self .bucket_size , 4 * 1024 ],
86+ )
87+ assert ret == 0 , f"batch_register_memory failed ret={ ret } "
88+ logger .info (f"__init__ session_id={ self .session_id } " )
8389
def prepare(self) -> dict[str, Any]:
    """Log the transfer buffer location and return this host's address and a port.

    Nothing is allocated here: the transfer buffer is created and registered
    with the engine once, in ``__init__``.

    Returns:
        A dict with ``"addr"`` (``self.hostname``) and ``"port"`` (the first
        element of the pair returned by ``get_free_port(self.hostname)``).
    """
    # Tag the record with this method's name; it previously said "__init__",
    # which made the two log lines indistinguishable when grepping.
    logger.info(f"prepare ptr={self.buf.data_ptr():#x} len={2 * self.bucket_size}")
    port, _ = get_free_port(self.hostname)
    return {"addr": self.hostname, "port": port}
9097
@@ -106,6 +113,7 @@ def init_process_group(self, rank: int, world_size: int, metadata: dict[str, Any
106113 self .rank = rank
107114 self .world_size = world_size
108115 if rank < 0 :
116+ logger .info (f"init_process_group rank={ rank } " )
109117 return
110118
111119 self .store = StatelessProcessGroup .create (
@@ -115,55 +123,74 @@ def init_process_group(self, rank: int, world_size: int, metadata: dict[str, Any
115123 world_size = world_size ,
116124 )
117125
118- if self .is_master :
119- buffer_info = {
120- "session_id" : self .session_id ,
121- "ptr" : self .buf .data_ptr (),
122- "len" : self .bucket_size ,
123- }
124- self .store .broadcast_obj (obj = buffer_info , src = 0 )
125- else :
126- self .buffer_info = self .store .broadcast_obj (obj = None , src = 0 )
126+ info = {
127+ "session_id" : self .session_id ,
128+ "ptr" : self .buf .data_ptr (),
129+ }
127130
131+ info_list = self .store .all_gather_obj (info )
132+ self .buffer_info = None if rank == 0 else info_list [rank - 1 ]
133+
134+ logger .info (
135+ f"init_process_group rank={ rank } world_size={ world_size } buffer_info={ self .buffer_info } "
136+ )
128137
def finalize(self):
    """Drop the process-group store and release cached device memory.

    The buffers held in ``self.buf`` / ``self.magic_buf`` are not touched by
    this method.
    """
    # Release the store first so the collector below can reclaim it.
    self.store = None

    device_module = get_torch_device()
    device_module.empty_cache()
    gc.collect()

    logger.info(f"finalize rank={self.rank}")
134144
async def wait_for_complete(self, buf: torch.Tensor, timeout: "float | None" = None):
    """Wait until the first four bytes of ``buf`` equal the completion magic.

    The loop yields to the event loop between polls (``asyncio.sleep(0)``) so
    other coroutines keep running while this one waits.

    Args:
        buf: uint8 tensor whose first four bytes are polled.
        timeout: optional upper bound in seconds. ``None`` (the default)
            waits indefinitely, which is the historical behaviour.

    Raises:
        TimeoutError: if ``timeout`` is given and the magic has not appeared
            within that many seconds.
    """
    magic = torch.tensor([0xab, 0xdc, 0xef, 0x88], dtype=torch.uint8, device=self.device)
    # A monotonic clock keeps the deadline immune to wall-clock adjustments.
    deadline = None if timeout is None else time.monotonic() + timeout

    while not torch.equal(buf[:4], magic):
        if deadline is not None and time.monotonic() >= deadline:
            raise TimeoutError(f"completion magic not observed within {timeout}s")
        await asyncio.sleep(0)
142151
143152 @torch .no_grad ()
144153 async def send_weights (self , weights : Generator [tuple [str , torch .Tensor ], None , None ]):
145154 """Send weights using Mooncake TransferEngine"""
155+ if self .rank < 0 :
156+ for name , weight in weights :
157+ pass
158+ logger .info (f"send_weights rank={ self .rank } " )
159+ return
160+
161+ total_bytes = 0
146162 start_time = time .time ()
147163 bucket_meta : dict [str , TensorMeta ] = {}
148164 offset = 0
165+ should_wait = False
166+ bufs = [self .buf [:self .bucket_size ], self .buf [self .bucket_size :]]
167+ idx = 0
168+ current = bufs [idx ]
149169
150170 for name , weight in weights :
151- if self .rank != 0 :
152- continue
153171 weight = weight .to (self .rollout_dtype )
154172
155173 if offset + weight .nbytes > self .bucket_size :
156- get_torch_device ().synchronize
174+ total_bytes += offset
175+ get_torch_device ().synchronize ()
157176 info = {
158177 "bucket_meta" : bucket_meta ,
178+ "ptr" : current .data_ptr (),
159179 "len" : offset ,
160180 "is_last" : False ,
161181 }
162- self .store .broadcast_obj (obj = info , src = 0 )
163- await self .wait_for_complete ()
182+ # send to rank 1
183+ self .store .send_obj (info , 1 )
184+
185+ idx ^= 1
186+ current = bufs [idx ]
164187 bucket_meta = {}
165188 offset = 0
166189
190+ if should_wait :
191+ await self .wait_for_complete (current )
192+ should_wait = True
193+
167194 assert offset + weight .nbytes <= self .bucket_size , (
168195 f"Weight { name } ({ weight .shape } , { weight .dtype } ) is too large to fit in the bucket."
169196 )
@@ -174,53 +201,78 @@ async def send_weights(self, weights: Generator[tuple[str, torch.Tensor], None,
174201 "dtype" : weight .dtype ,
175202 "offset" : offset ,
176203 }
177- self . buf [offset : offset + weight .nbytes ].copy_ (weight .view (- 1 ).view (torch .uint8 ), non_blocking = True )
204+ current [offset : offset + weight .nbytes ].copy_ (weight .view (- 1 ).view (torch .uint8 ), non_blocking = True )
178205 offset += weight .nbytes
179206
180- if self .rank != 0 :
181- return
182-
183207 get_torch_device ().synchronize ()
184208 info = {
185209 "bucket_meta" : bucket_meta ,
210+ "ptr" : current .data_ptr (),
186211 "len" : offset ,
187212 "is_last" : True ,
188213 }
189- self .store .broadcast_obj (obj = info , src = 0 )
190- await self .wait_for_complete ()
191- logger .info (f"send weights done, time cost: { time .time () - start_time :.2f} s" )
214+ self .store .send_obj (info , 1 )
215+ await self .wait_for_complete (current )
216+
217+ time_cost = time .time () - start_time
218+ bandwidth = total_bytes / time_cost / (1024 * 1024 * 1024 )
219+ logger .info (
220+ f"Rank { self .rank } send weights done, "
221+ f"total bytes: { total_bytes } time cost: { time_cost :.2f} s bandwidth: { bandwidth :.2f} GB/s"
222+ )
192223
@torch.no_grad()
async def receive_weights(self) -> AsyncGenerator[tuple[str, torch.Tensor], None]:
    """Receive weights using Mooncake TransferEngine.

    One bucket is handled per loop iteration:

    1. receive the bucket description from ``rank - 1`` through the store;
    2. read ``info["len"]`` bytes from the address in ``info["ptr"]`` into
       one half of ``self.buf``;
    3. if a ``rank + 1`` exists, forward the description with ``"ptr"``
       rewritten to the local copy;
    4. yield every tensor listed in ``info["bucket_meta"]``;
    5. write the 4-byte magic to the source address so the previous rank's
       ``wait_for_complete`` can return.

    The two halves of ``self.buf`` alternate between iterations. From the
    third bucket on, a rank that forwards data first waits for the magic on
    the half it is about to overwrite.

    Yields:
        ``(name, tensor)`` pairs. Each tensor is a view into ``self.buf``
        (no copy), so it is only valid until that half is reused.

    NOTE(review): the acknowledgement in step 5 lands on the first four bytes
    of the previous rank's bucket. On a forwarding rank that is memory a
    still-live yielded view may cover - confirm consumers copy the data out
    before the next rank can acknowledge.
    """
    # NOTE(review): start_time / total_bytes are not consumed in the visible
    # part of this function; kept in case code past this view reads them.
    start_time = time.time()
    total_bytes = 0
    bufs = [self.buf[: self.bucket_size], self.buf[self.bucket_size :]]
    idx = 0
    current = bufs[idx]
    has_next_rank = self.rank < self.world_size - 1

    # Fill the magic buffer IN PLACE. Rebinding self.magic_buf to a fresh
    # tensor would drop the allocation registered with the engine in __init__
    # and make transfer_sync_write below source from unregistered memory.
    magic = torch.tensor([0xab, 0xdc, 0xef, 0x88], dtype=torch.uint8, device=self.device)
    self.magic_buf[:4].copy_(magic)
    # Make sure the bytes have landed before the engine reads them.
    get_torch_device().synchronize()

    while True:
        # 1 receive info from previous rank
        info = self.store.recv_obj(self.rank - 1)
        if idx >= 2 and has_next_rank:
            # The next rank must be done with this half before it is overwritten.
            await self.wait_for_complete(current)

        ptr = info["ptr"]
        ret = self.engine.transfer_sync_read(
            self.buffer_info["session_id"],
            current.data_ptr(),
            ptr,
            info["len"],
        )
        assert ret == 0, f"transfer_sync_read failed {ret}"
        total_bytes += info["len"]

        # 2 send info to next rank, now pointing at the local copy
        info["ptr"] = current.data_ptr()
        if has_next_rank:
            self.store.send_obj(info, self.rank + 1)

        # 3 yield tensor from current buffer
        for name, meta in info["bucket_meta"].items():
            dtype, shape = meta["dtype"], meta["shape"]
            size = dtype.itemsize * shape.numel()
            tensor = current[meta["offset"] : meta["offset"] + size].view(dtype=dtype).view(shape)
            yield name, tensor

        # 4 write magic data to previous rank (at the address the bucket was read from)
        ret = self.engine.transfer_sync_write(
            self.buffer_info["session_id"],
            self.magic_buf.data_ptr(),
            ptr,
            4,
        )
        assert ret == 0, f"transfer_sync_write failed {ret}"

        # 5 swap buffer
        idx += 1
        current = bufs[idx % 2]
        get_torch_device().synchronize()

        if info["is_last"]:
            break
226278
0 commit comments