Skip to content

Commit 6ad7671

Browse files
author
kip-cxj
committed
Merge remote-tracking branch 'upstream/main' into main
2 parents 8273ecd + 009082d commit 6ad7671

File tree

7 files changed

+85
-58
lines changed

7 files changed

+85
-58
lines changed

checkpoint_engine/distributed/base.py

Lines changed: 6 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -103,23 +103,21 @@ def new_group(
103103
class TorchBackend(Distributed):
104104
def init_process_group(
105105
self,
106-
host: str,
107-
port: int,
108106
rank: int,
109107
world_size: int,
108+
store: torch.distributed.TCPStore,
110109
timeout: timedelta,
111110
**kwargs,
112111
):
113112
backend = kwargs.get("backend", "nccl")
114-
store = torch.distributed.TCPStore(
115-
host, port, world_size, timeout=timeout, is_master=(rank == 0)
116-
)
113+
store_counter = kwargs.get("store_counter", "nccl")
114+
sub_store = torch.distributed.PrefixStore(f"prefix-{store_counter}", store)
117115
torch.distributed.init_process_group(
118116
backend=backend,
119117
world_size=world_size,
120118
rank=rank,
121119
timeout=timeout,
122-
store=store,
120+
store=sub_store,
123121
)
124122

125123
def destroy_process_group(self, group: DistributedProcessGroup | None = None):
@@ -243,14 +241,13 @@ def use_backend(backend: str | None):
243241

244242

245243
def init_process_group(
246-
host: str,
247-
port: int,
248244
rank: int,
249245
world_size: int,
246+
store: torch.distributed.TCPStore,
250247
timeout: timedelta = timedelta(seconds=300),
251248
**kwargs,
252249
):
253-
_BACKEND_INSTANCE.init_process_group(host, port, rank, world_size, timeout, **kwargs)
250+
_BACKEND_INSTANCE.init_process_group(rank, world_size, store, timeout, **kwargs)
254251

255252

256253
def destroy_process_group(group: DistributedProcessGroup | None = None):

checkpoint_engine/distributed/hccl.py

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -230,23 +230,22 @@ def _use_group(self, group: CommGroup | None, src: int | None = None):
230230

231231
def init_process_group(
232232
self,
233-
host: str,
234-
port: int,
235233
rank: int,
236234
world_size: int,
235+
store: torch.distributed.TCPStore,
237236
timeout: timedelta = timedelta(seconds=300),
238237
**kwargs,
239238
):
240239
assert not self.initialized, "already initialized"
241240

242-
self.host = host
243-
self.port = port
241+
self.host = store.host
242+
self.port = store.port + 1
244243
self.rank = rank
245244
self.world_size = world_size
246245
self.device = torch.device("npu", torch.npu.current_device())
247246

248247
self.pg = StatelessProcessGroup.create(
249-
host, port, rank, world_size, store_timeout=int(timeout.total_seconds())
248+
self.host, self.port, rank, world_size, store_timeout=int(timeout.total_seconds())
250249
)
251250
self.pyhccl = PyHcclCommunicatorEx(group=self.pg, device=self.device)
252251
self.comm = self.pyhccl.comm

checkpoint_engine/distributed/nccl.py

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -133,23 +133,22 @@ def _use_group(self, group: CommGroup | None, src: int | None = None):
133133

134134
def init_process_group(
135135
self,
136-
host: str,
137-
port: int,
138136
rank: int,
139137
world_size: int,
138+
store: torch.distributed.TCPStore,
140139
timeout: timedelta = timedelta(seconds=300),
141140
**kwargs,
142141
):
143142
assert not self.initialized, "already initialized"
144143

145-
self.host = host
146-
self.port = port
144+
self.host = store.host
145+
self.port = store.port + 1
147146
self.rank = rank
148147
self.world_size = world_size
149148
self.device = torch.device("cuda", torch.cuda.current_device())
150149

151150
self.pg = StatelessProcessGroup.create(
152-
host, port, rank, world_size, store_timeout=int(timeout.total_seconds())
151+
self.host, self.port, rank, world_size, store_timeout=int(timeout.total_seconds())
153152
)
154153

155154
self.pynccl = PyNcclCommunicatorEx(group=self.pg, device=self.device)

checkpoint_engine/pin_memory.py

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -209,7 +209,9 @@ def _pin(t: torch.Tensor):
209209
torch.cuda.set_device(device_index)
210210
cudart = torch.cuda.cudart()
211211
r = cudart.cudaHostRegister(t.data_ptr(), t.numel() * t.element_size(), 0)
212-
assert r == 0, f"pin memory error, error code: {r}"
212+
if r != 0:
213+
error_msg = cudart.cudaGetErrorString(r)
214+
raise RuntimeError(f"pin memory error, error code: {r}, error message: {error_msg}")
213215

214216
# TODO: should only support /dev/shm? but we found files in disk also work?
215217
size = os.stat(file_path).st_size
@@ -254,6 +256,12 @@ def _pin(t: torch.Tensor):
254256
# Remove the file after successfully loading. This will avoid doubling the memory usage.
255257
# We assume files in /dev/shm/ are temporary files. So it's safe to remove them after loading.
256258
os.remove(file_path)
259+
if not metas:
260+
# TODO: should we still return this buffer?
261+
assert buffer.nbytes == 0, f"buffer nbytes {buffer.nbytes} should be 0"
262+
logger.warning(f"[rank{rank}] no metas found in {file_path}, skip pin memory")
263+
return MemoryBuffer(buffer=buffer, size=buffer.nbytes, metas=[], manually_pinned=False)
264+
257265
_pin(buffer)
258266
logger.info(
259267
f"[rank{rank}] inplace pin memory for file {file_path} finished, size {buffer.nbytes / 1024 / 1024:.2f}MiB"

checkpoint_engine/ps.py

Lines changed: 32 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -176,6 +176,8 @@ def __init__(
176176
auto_pg: bool = True,
177177
gpu_count: int | None = None,
178178
mem_fraction: float | None = None,
179+
master_addr: str | None = None,
180+
master_port: int | None = None,
179181
):
180182
"""
181183
Initialize the parameter server. env RANK, WORLD_SIZE and MASTER_ADDR must be set.
@@ -229,6 +231,17 @@ def __init__(
229231
self._device_uuid = _get_physical_gpu_id(self.device_manager, device_index)
230232
self._rdma_device = None if self._p2p_store is None else self._p2p_store.device
231233

234+
master_addr = master_addr or os.getenv("MASTER_ADDR")
235+
assert master_addr, "master_addr is required"
236+
self._store = torch.distributed.TCPStore(
237+
master_addr,
238+
_get_master_port(master_port),
239+
self._world_size,
240+
timeout=timedelta(minutes=10),
241+
is_master=self._rank == 0,
242+
)
243+
self._store_counter = 0
244+
232245
def _get_memory_pool(self, checkpoint_name: str) -> list[MemoryBuffer]:
233246
if checkpoint_name == self._current_shared_memory_pool_user:
234247
assert self._memory_pool[self.shared_memory_pool_name], (
@@ -392,7 +405,11 @@ def _unpin(t: torch.Tensor):
392405
)
393406
cudart = torch.cuda.cudart()
394407
r = cudart.cudaHostUnregister(t.data_ptr())
395-
assert r == 0, f"unpin memory error, error code: {r}"
408+
if r != 0:
409+
error_msg = cudart.cudaGetErrorString(r)
410+
raise RuntimeError(
411+
f"unpin memory error, error code: {r}, error message: {error_msg}"
412+
)
396413

397414
# if the checkpoint is pinned by cudaHostRegister manually, we need to unpin it manually
398415
try:
@@ -408,7 +425,13 @@ def _unpin(t: torch.Tensor):
408425
del self._memory_pool[checkpoint_name]
409426
# see https://github.com/pytorch/pytorch/blob/31d5c675394705f8a6bc767f80ae14bf4f01246b/torch/csrc/cuda/Module.cpp#L2018
410427
# this works by using torch>=2.5.0
411-
torch._C._host_emptyCache()
428+
if self.device_manager.device_type == "cuda":
429+
torch._C._host_emptyCache()
430+
else:
431+
# torch._C._host_emptyCache() is not supported on NPU, so we call gc.collect() to empty host cache.
432+
import gc
433+
434+
gc.collect()
412435

413436
def gather_metas(self, checkpoint_name: str):
414437
"""
@@ -478,8 +501,6 @@ def gather_metas(self, checkpoint_name: str):
478501
def init_process_group(
479502
self,
480503
*,
481-
master_addr: str | None = None,
482-
master_port: int | None = None,
483504
timeout: timedelta = timedelta(minutes=10),
484505
):
485506
"""
@@ -489,21 +510,18 @@ def init_process_group(
489510
master_port: The specified port of the master node. If not set, will use _get_master_port to get the port.
490511
timeout: The timeout of the process group.
491512
"""
492-
master_addr = master_addr or os.getenv("MASTER_ADDR")
493-
assert master_addr, "master_addr is required"
513+
self._store_counter += 1
494514
dist.init_process_group(
495-
host=master_addr,
496-
port=_get_master_port(master_port),
497515
rank=self._rank,
498516
world_size=self._world_size,
517+
store=self._store,
499518
timeout=timeout,
500519
backend=self.device_manager.backend,
520+
store_counter=self._store_counter,
501521
)
502522
logger.info(f"[rank{self._rank}] init process group successfully.")
503523

504-
def store_based_barrier(
505-
self, store: torch.distributed.TCPStore, timeout: timedelta = timedelta(minutes=5)
506-
) -> None:
524+
def store_based_barrier(self, timeout: timedelta = timedelta(minutes=5)) -> None:
507525
"""
508526
Perform a store-based barrier synchronization across all ranks.
509527
@@ -516,7 +534,7 @@ def store_based_barrier(
516534
"""
517535
torch.distributed.distributed_c10d._store_based_barrier(
518536
rank=self._rank,
519-
store=store,
537+
store=self._store,
520538
group_name="parameter_server_barrier",
521539
rendezvous_count=self._world_size,
522540
timeout=timeout,
@@ -529,8 +547,6 @@ def update(
529547
*,
530548
timeout: timedelta = timedelta(minutes=10),
531549
ranks: list[int] | None = None,
532-
master_addr: str | None = None,
533-
master_port: int | None = None,
534550
) -> None:
535551
"""
536552
Update the checkpoint to inference engine. This function should be called after gather_metas.
@@ -551,25 +567,12 @@ def update(
551567
assert req_func is not None, "req_func is required"
552568
ranks_group = None
553569
try:
554-
master_addr = os.getenv("MASTER_ADDR") or master_addr
555-
assert master_addr, "master_addr is required"
556570
if self._auto_pg and not dist.is_initialized():
557-
self.init_process_group(
558-
timeout=timeout, master_addr=master_addr, master_port=master_port
559-
)
560-
# HACK: MASTER_PORT+2 for barrier store if master_port is not provided, _get_master_port() returns MASTER_PORT+1
561-
# If master_port is provided, use master_port+1 for barrier store
562-
manager_store = torch.distributed.TCPStore(
563-
master_addr,
564-
_get_master_port(master_port) + 1,
565-
self._world_size,
566-
timeout=timeout,
567-
is_master=self._rank == 0,
568-
)
571+
self.init_process_group(timeout=timeout)
569572
# if ranks is None or [], it will use fully broadcast to update to all ranks
570573
ranks_group = dist.new_group(ranks) if ranks else None
571574
self._update_per_bucket(checkpoint_name, req_func, ranks_group, ranks)
572-
self.store_based_barrier(manager_store)
575+
self.store_based_barrier()
573576
except Exception as e:
574577
logger.exception(
575578
f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} error {e}"
@@ -580,7 +583,6 @@ def update(
580583
dist.destroy_process_group(ranks_group)
581584
if self._auto_pg and dist.is_initialized():
582585
dist.destroy_process_group()
583-
del manager_store
584586
self.device_manager.device_module.empty_cache()
585587
logger.info(
586588
f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} done. "

checkpoint_engine/worker.py

Lines changed: 28 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,9 @@
1010
from checkpoint_engine.device_utils import DeviceManager, npu_generate_uuid
1111

1212

13+
_WEIGHTS_TYPE = list[tuple[str, torch.Tensor]]
14+
15+
1316
def _rebuild_ipc(handle: tuple[Callable, tuple], device_id: int | None = None) -> torch.Tensor:
1417
func, args = handle
1518
list_args = list(args)
@@ -29,11 +32,9 @@ class FlattenedTensorMetadata(TypedDict):
2932
offset: int
3033

3134

32-
def _extract_weights(
33-
payload: list[FlattenedTensorMetadata], buffer: torch.Tensor
34-
) -> list[tuple[str, torch.Tensor]]:
35+
def _extract_weights(payload: list[FlattenedTensorMetadata], buffer: torch.Tensor) -> _WEIGHTS_TYPE:
3536
assert buffer is not None
36-
weights: list[tuple[str, torch.Tensor]] = []
37+
weights: _WEIGHTS_TYPE = []
3738
for item in payload:
3839
shape = item["shape"]
3940
if isinstance(shape, list | tuple):
@@ -166,12 +167,31 @@ def update_weights_from_ipc(self, zmq_handles: dict[str, str]):
166167
self.device = torch.device(f"npu:{self.local_rank}")
167168
assert self.device is not None
168169

170+
def _load_weights(weights: _WEIGHTS_TYPE):
171+
# Load main model weights
172+
self.model_runner.model.load_weights(weights)
173+
# Load drafter model weights if MTP/speculative decoding is enabled
174+
if (
175+
getattr(self.model_runner, "drafter", None) is not None
176+
and getattr(self.model_runner.drafter, "model", None) is not None
177+
):
178+
self.model_runner.drafter.model.load_weights(weights=weights)
179+
180+
def _post_hook():
181+
process_weights_after_loading(self.model_runner.model, self.model_config, self.device)
182+
# Also trigger drafter model's post processing if MTP is enabled
183+
if (
184+
getattr(self.model_runner, "drafter", None) is not None
185+
and getattr(self.model_runner.drafter, "model", None) is not None
186+
):
187+
process_weights_after_loading(
188+
self.model_runner.drafter.model, self.model_config, self.device
189+
)
190+
169191
update_weights_from_ipc(
170192
self._zmq_ctx,
171193
zmq_handles[self._device_uuid],
172194
device_id=self.device.index,
173-
run=self.model_runner.model.load_weights,
174-
post_hook=lambda: process_weights_after_loading(
175-
self.model_runner.model, self.model_config, self.device
176-
),
195+
run=_load_weights,
196+
post_hook=_post_hook,
177197
)

tests/test_reuse_pin_memory.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,8 @@ def generate_dummy_checkpoint() -> dict[str, torch.Tensor]:
2323
def test_register_pin_memory():
2424
os.environ["RANK"] = "0"
2525
os.environ["WORLD_SIZE"] = "1"
26+
os.environ["MASTER_ADDR"] = "localhost"
27+
os.environ["MASTER_PORT"] = "25400"
2628
ps = ParameterServer()
2729
checkpoint1 = generate_dummy_checkpoint()
2830
checkpoint_shared1 = generate_dummy_checkpoint()

0 commit comments

Comments
 (0)