Commit baa3ce5

Author: yexin (committed)

implement ring algo

1 parent 3f62182 · commit baa3ce5

File tree: 1 file changed (+87, -35 lines)

verl/checkpoint_engine/mooncake_checkpoint_engine.py

Lines changed: 87 additions & 35 deletions
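
The commit replaces the earlier rank-0 broadcast with a ring pipeline: rank 0 fills a bucket and hands its metadata to rank 1; each receiving rank RDMA-reads the bucket out of its predecessor's registered buffer, forwards the metadata (rewritten to point at its own buffer) to its successor, and acknowledges the predecessor by writing a 4-byte magic marker back into the source buffer. A minimal sketch of that data flow, using plain byte strings and an assumed NUM_RANKS constant instead of Mooncake-registered GPU memory:

# Toy model of the ring: rank 0 is the source, every other rank copies the
# bucket from its predecessor (the real code does this with transfer_sync_read
# and then acks with a magic write). NUM_RANKS is purely illustrative.
NUM_RANKS = 4

def ring_broadcast(payload: bytes, num_ranks: int = NUM_RANKS) -> list[bytes]:
    buffers: list[bytes] = [b""] * num_ranks
    buffers[0] = payload                   # rank 0 fills its bucket
    for rank in range(1, num_ranks):
        buffers[rank] = buffers[rank - 1]  # rank i reads from rank i - 1
    return buffers

assert all(b == b"weights" for b in ring_broadcast(b"weights"))
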
@@ -78,13 +78,20 @@ def __init__(
         self.session_id = f"{hostname}:{rpc_port}"
         self.hostname = hostname
 
-        self.buf = torch.zeros(self.bucket_size, dtype=torch.uint8, device=self.device)
-        assert self.engine.register_memory(self.buf.data_ptr(), self.bucket_size) == 0, "register_memory failed"
+        self.buf = torch.empty(2 * self.bucket_size, dtype=torch.uint8, device=self.device)
+        self.magic_buf = torch.empty(4 * 1024, dtype=torch.uint8, device=self.device)
+        ret = self.engine.batch_register_memory(
+            [self.buf.data_ptr(), self.magic_buf.data_ptr()],
+            [2 * self.bucket_size, 4 * 1024],
+        )
+        assert ret == 0, f"batch_register_memory failed ret={ret}"
+        logger.info(f"__init__ session_id={self.session_id}")
 
     def prepare(self) -> dict[str, Any]:
         """Prepare send and recv buckets"""
         # self.buf = torch.zeros(self.bucket_size, dtype=torch.uint8, device=self.device)
         # self.engine.register_memory(self.buf.data_ptr(), self.bucket_size)
+        logger.info(f"__init__ ptr={self.buf.data_ptr():#x} len={2 * self.bucket_size}")
         port, _ = get_free_port(self.hostname)
         return {"addr": self.hostname, "port": port}
 
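__init__ now makes one allocation covering two bucket-sized halves plus a small magic buffer, and registers both with a single batch_register_memory call; the two halves later serve as a ping-pong pair. A standalone sketch of the split, assuming a toy bucket_size and CPU tensors in place of the engine-registered device memory:

import torch

bucket_size = 1024  # assumed toy value; the real size comes from the engine config

# One allocation, registered once, then viewed as two ping-pong halves.
buf = torch.empty(2 * bucket_size, dtype=torch.uint8)
magic_buf = torch.empty(4 * 1024, dtype=torch.uint8)

bufs = [buf[:bucket_size], buf[bucket_size:]]
assert bufs[0].data_ptr() == buf.data_ptr()
assert bufs[1].data_ptr() == buf.data_ptr() + bucket_size  # uint8: 1 byte per element
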
@@ -106,6 +113,7 @@ def init_process_group(self, rank: int, world_size: int, metadata: dict[str, Any
         self.rank = rank
         self.world_size = world_size
         if rank < 0:
+            logger.info(f"init_process_group rank={rank}")
             return
 
         self.store = StatelessProcessGroup.create(
@@ -115,55 +123,74 @@
             world_size=world_size,
         )
 
-        if self.is_master:
-            buffer_info = {
-                "session_id": self.session_id,
-                "ptr": self.buf.data_ptr(),
-                "len": self.bucket_size,
-            }
-            self.store.broadcast_obj(obj=buffer_info, src=0)
-        else:
-            self.buffer_info = self.store.broadcast_obj(obj=None, src=0)
+        info = {
+            "session_id": self.session_id,
+            "ptr": self.buf.data_ptr(),
+        }
 
+        info_list = self.store.all_gather_obj(info)
+        self.buffer_info = None if rank == 0 else info_list[rank - 1]
+
+        logger.info(
+            f"init_process_group rank={rank} world_size={world_size} buffer_info={self.buffer_info}"
+        )
 
     def finalize(self):
         """Cleanup communication and deregister memory"""
         self.store = None
         get_torch_device().empty_cache()
         gc.collect()
+        logger.info(f"finalize rank={self.rank}")
 
-    async def wait_for_complete(self):
+    async def wait_for_complete(self, buf: torch.Tensor):
         magic = torch.tensor([0xab, 0xdc, 0xef, 0x88], dtype=torch.uint8, device=self.device)
-        target = magic.repeat(self.world_size - 1)
         while True:
-            if torch.equal(self.buf[4:4 * self.world_size], target):
+            if torch.equal(buf[:4], magic):
                 break
             await asyncio.sleep(0)
 
     @torch.no_grad()
     async def send_weights(self, weights: Generator[tuple[str, torch.Tensor], None, None]):
         """Send weights using Mooncake TransferEngine"""
+        if self.rank < 0:
+            for name, weight in weights:
+                pass
+            logger.info(f"send_weights rank={self.rank}")
+            return
+
+        total_bytes = 0
         start_time = time.time()
         bucket_meta: dict[str, TensorMeta] = {}
         offset = 0
+        should_wait = False
+        bufs = [self.buf[:self.bucket_size], self.buf[self.bucket_size:]]
+        idx = 0
+        current = bufs[idx]
 
         for name, weight in weights:
-            if self.rank != 0:
-                continue
             weight = weight.to(self.rollout_dtype)
 
             if offset + weight.nbytes > self.bucket_size:
-                get_torch_device().synchronize
+                total_bytes += offset
+                get_torch_device().synchronize()
                 info = {
                     "bucket_meta": bucket_meta,
+                    "ptr": current.data_ptr(),
                     "len": offset,
                     "is_last": False,
                 }
-                self.store.broadcast_obj(obj=info, src=0)
-                await self.wait_for_complete()
+                # send to rank 1
+                self.store.send_obj(info, 1)
+
+                idx ^= 1
+                current = bufs[idx]
                 bucket_meta = {}
                 offset = 0
 
+                if should_wait:
+                    await self.wait_for_complete(current)
+                should_wait = True
+
             assert offset + weight.nbytes <= self.bucket_size, (
                 f"Weight {name}({weight.shape}, {weight.dtype}) is too large to fit in the bucket."
             )
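
The send loop above alternates between the two halves: once a bucket is full its metadata goes to rank 1, the writer flips to the other half, and before refilling a half it waits for that half's magic marker (skipped on the very first flip via should_wait). A sketch of just that ping-pong control flow, with hypothetical send_bucket and wait_for_ack callables standing in for send_obj and wait_for_complete:

from typing import Callable, Iterable

def pingpong_send(
    buckets: Iterable[bytes],
    send_bucket: Callable[[int, bytes], None],  # hypothetical: ship a filled half to rank 1
    wait_for_ack: Callable[[int], None],        # hypothetical: block until a half has been consumed
) -> None:
    idx = 0               # which of the two halves is being filled
    should_wait = False   # nothing to wait for before the first reuse
    for bucket in buckets:
        send_bucket(idx, bucket)  # hand the filled half to the next rank
        idx ^= 1                  # flip to the other half
        if should_wait:
            wait_for_ack(idx)     # the half we are about to reuse must be acked first
        should_wait = True

# Example: three buckets alternate halves 0, 1, 0 and only reused halves are awaited.
sent, acks = [], []
pingpong_send(
    [b"b0", b"b1", b"b2"],
    send_bucket=lambda i, b: sent.append(i),
    wait_for_ack=lambda i: acks.append(i),
)
assert sent == [0, 1, 0] and acks == [0, 1]
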
@@ -174,53 +201,78 @@ async def send_weights(weights: Generator[tuple[str, torch.Tensor], None,
                 "dtype": weight.dtype,
                 "offset": offset,
             }
-            self.buf[offset : offset + weight.nbytes].copy_(weight.view(-1).view(torch.uint8), non_blocking=True)
+            current[offset : offset + weight.nbytes].copy_(weight.view(-1).view(torch.uint8), non_blocking=True)
             offset += weight.nbytes
 
-        if self.rank != 0:
-            return
-
         get_torch_device().synchronize()
         info = {
             "bucket_meta": bucket_meta,
+            "ptr": current.data_ptr(),
             "len": offset,
             "is_last": True,
         }
-        self.store.broadcast_obj(obj=info, src=0)
-        await self.wait_for_complete()
-        logger.info(f"send weights done, time cost: {time.time() - start_time:.2f}s")
+        self.store.send_obj(info, 1)
+        await self.wait_for_complete(current)
+
+        time_cost = time.time() - start_time
+        bandwidth = total_bytes / time_cost / (1024 * 1024 * 1024)
+        logger.info(
+            f"Rank {self.rank} send weights done, "
+            f"total bytes: {total_bytes} time cost: {time_cost:.2f}s bandwidth: {bandwidth:.2f} GB/s"
+        )
 
     @torch.no_grad()
     async def receive_weights(self) -> AsyncGenerator[tuple[str, torch.Tensor], None]:
         """Receive weights using Mooncake TransferEngine"""
         start_time = time.time()
         total_bytes = 0
+        bufs = [self.buf[:self.bucket_size], self.buf[self.bucket_size:]]
+        idx = 0
+        current = bufs[idx]
+        self.magic_buf = torch.tensor([0xab, 0xdc, 0xef, 0x88], dtype=torch.uint8, device=self.device)
+
         while True:
-            info = self.store.broadcast_obj(obj=None, src=0)
+            # 1 receive info from previous rank
+            info = self.store.recv_obj(self.rank - 1)
+            if idx >= 2 and self.rank < self.world_size - 1:
+                await self.wait_for_complete(current)
+
+            ptr = info["ptr"]
             ret = self.engine.transfer_sync_read(
                 self.buffer_info["session_id"],
-                self.buf.data_ptr(),
-                self.buffer_info["ptr"],
+                current.data_ptr(),
+                ptr,
                 info["len"],
             )
             assert ret == 0, f"transfer_sync_read failed {ret}"
             total_bytes += info["len"]
+
+            # 2 send info to next rank
+            info["ptr"] = current.data_ptr()
+            if self.rank < self.world_size - 1:
+                self.store.send_obj(info, self.rank + 1)
+
+            # 3 yield tensor from current buffer
             for name, meta in info["bucket_meta"].items():
                 dtype, shape = meta["dtype"], meta["shape"]
                 size = dtype.itemsize * shape.numel()
-                tensor = self.buf[meta["offset"] : meta["offset"] + size].view(dtype=dtype).view(shape)
+                tensor = current[meta["offset"] : meta["offset"] + size].view(dtype=dtype).view(shape)
                 yield name, tensor
 
-            self.buf[:4] = torch.tensor([0xab, 0xdc, 0xef, 0x88], dtype=torch.uint8, device=self.device)
-
-            offset = self.buffer_info["ptr"] + self.rank * 4
+            # 4 write magic data to previous rank
             ret = self.engine.transfer_sync_write(
                 self.buffer_info["session_id"],
-                self.buf.data_ptr(),
-                offset,
+                self.magic_buf.data_ptr(),
+                ptr,
                 4,
             )
             assert ret == 0, f"transfer_sync_write failed {ret}"
+
+            # 5 swap buffer
+            idx += 1
+            current = bufs[idx % 2]
+            get_torch_device().synchronize()
+
             if info["is_last"]:
                 break
 
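On the receive side the numbered comments spell out the per-bucket order: take the metadata from rank - 1, RDMA-read the bucket into the current half, forward the metadata (now pointing at this rank's buffer) to rank + 1, yield the tensors, write the magic marker back to rank - 1, then swap halves. A compact sketch of that loop skeleton, with hypothetical recv_meta, read_bucket, forward_meta, emit, ack_prev and wait_ready callables in place of the store and TransferEngine calls:

def receive_loop(rank, world_size, recv_meta, read_bucket, forward_meta, emit, ack_prev, wait_ready):
    idx = 0
    while True:
        # 1 receive info from previous rank; from the third bucket on, make sure
        #   the half we are about to overwrite was already consumed downstream
        info = recv_meta(rank - 1)
        if idx >= 2 and rank < world_size - 1:
            wait_ready(idx % 2)
        read_bucket(idx % 2, info)        # RDMA-read into the current half

        # 2 send info to next rank (rewritten to point at our buffer)
        if rank < world_size - 1:
            forward_meta(rank + 1, info)

        # 3 yield tensors out of the current half
        emit(idx % 2, info)

        # 4 write magic data back to the previous rank as the ack
        ack_prev(info)

        # 5 swap buffer
        idx += 1
        if info.get("is_last"):
            break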