Skip to content

Commit b0c6ca0

Browse files
author
kip-cxj
committed
add dist.use_backend
1 parent 5ab13a1 commit b0c6ca0

File tree

6 files changed

+65
-59
lines changed

6 files changed

+65
-59
lines changed

checkpoint_engine/distributed/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
init_process_group,
99
is_initialized,
1010
new_group,
11+
use_backend,
1112
)
1213

1314

@@ -21,4 +22,5 @@
2122
"init_process_group",
2223
"is_initialized",
2324
"new_group",
25+
"use_backend",
2426
]

checkpoint_engine/distributed/base.py

Lines changed: 28 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,15 @@ def all_gather(self, *args: Any, **kwargs: Any) -> torch.Tensor: ...
1616
class CommGroup:
1717
def __init__(self, comm_handle: int, ranks: list[int]):
1818
self._comm = comm_handle
19-
self.ranks = ranks
19+
self._ranks = ranks
2020

2121
@property
2222
def handle(self) -> int:
2323
return self._comm
2424

2525
@property
2626
def ranks(self) -> list[int]:
27-
return self.ranks
27+
return self._ranks
2828

2929

3030
DistributedProcessGroup = torch_dist.ProcessGroup | CommGroup
@@ -39,6 +39,7 @@ def init_process_group(
3939
rank: int,
4040
world_size: int,
4141
timeout: timedelta,
42+
**kwargs,
4243
):
4344
raise NotImplementedError
4445

@@ -100,22 +101,21 @@ def new_group(
100101

101102

102103
class TorchBackend(Distributed):
103-
def __init__(self, backend_type: str):
104-
self.backend_type = backend_type
105-
106104
def init_process_group(
107105
self,
108106
host: str,
109107
port: int,
110108
rank: int,
111109
world_size: int,
112110
timeout: timedelta,
111+
**kwargs,
113112
):
113+
backend = kwargs.get("backend", "nccl")
114114
store = torch.distributed.TCPStore(
115115
host, port, world_size, timeout=timeout, is_master=(rank == 0)
116116
)
117117
torch.distributed.init_process_group(
118-
backend=self.backend_type,
118+
backend=backend,
119119
world_size=world_size,
120120
rank=rank,
121121
timeout=timeout,
@@ -159,7 +159,7 @@ def new_group(self, ranks: list[int], **kwargs) -> DistributedProcessGroup | Non
159159

160160

161161
# specific device instance
162-
_BACKEND_INSTANCE: Distributed = TorchBackend(backend_type="nccl")
162+
_BACKEND_INSTANCE: Distributed = TorchBackend()
163163

164164
_pickler = pickle.Pickler
165165
_unpickler = pickle.Unpickler
@@ -223,33 +223,34 @@ def _common_all_gather_object(
223223
object_list[i] = _tensor_to_object(tensor, tensor_size)
224224

225225

226+
def use_backend(backend: str | None):
227+
global _BACKEND_INSTANCE
228+
229+
if not backend:
230+
return
231+
232+
mapping = {
233+
"vllm_nccl": ".nccl.DistributedNccl",
234+
"vllm_hccl": ".hccl.DistributedHccl",
235+
}
236+
if backend not in mapping:
237+
raise ValueError(f"Unsupported custom backend: {backend}")
238+
239+
module_path, class_name = mapping[backend].rsplit(".", 1)
240+
module = importlib.import_module(module_path, "checkpoint_engine.distributed")
241+
backend_class = getattr(module, class_name)
242+
_BACKEND_INSTANCE = backend_class()
243+
244+
226245
def init_process_group(
227246
host: str,
228247
port: int,
229248
rank: int,
230249
world_size: int,
231-
custom_dist: bool,
232-
backend: str,
233250
timeout: timedelta = timedelta(seconds=300),
251+
**kwargs,
234252
):
235-
global _BACKEND_INSTANCE
236-
237-
if not custom_dist:
238-
_BACKEND_INSTANCE = TorchBackend(backend_type=backend)
239-
else:
240-
mapping = {
241-
"nccl": ".nccl.DistributedNccl",
242-
"hccl": ".hccl.DistributedHccl",
243-
}
244-
if backend not in mapping:
245-
raise ValueError(f"Unsupported custom backend: {backend}")
246-
247-
module_path, class_name = mapping[backend].rsplit(".", 1)
248-
module = importlib.import_module(module_path, "checkpoint_engine.distributed")
249-
backend_class = getattr(module, class_name)
250-
_BACKEND_INSTANCE = backend_class()
251-
252-
_BACKEND_INSTANCE.init_process_group(host, port, rank, world_size, timeout)
253+
_BACKEND_INSTANCE.init_process_group(host, port, rank, world_size, timeout, **kwargs)
253254

254255

255256
def destroy_process_group(group: DistributedProcessGroup | None = None):

checkpoint_engine/distributed/hccl.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -208,25 +208,25 @@ def __init__(self):
208208

209209
@contextmanager
210210
def _use_group(self, group: CommGroup | None, src: int | None = None):
211+
active_src = src
211212
if group:
212-
assert group.handle() in self.sub_groups, "invalid sub_group"
213-
newcomm = ctypes.c_void_p(group.handle())
214-
self.pynccl.comm = newcomm
215-
active_src = src
213+
assert group.handle in self.sub_groups, "invalid sub_group"
214+
newcomm = ctypes.c_void_p(group.handle)
215+
self.pyhccl.comm = newcomm
216216

217217
if src is not None:
218-
assert src in group.ranks(), "src rank not in group"
218+
assert src in group.ranks, "src rank not in group"
219219
# convert src rank id in default world to newcomm
220-
active_src = group.ranks().index(src)
221-
self.pynccl.rank = group.ranks().index(self.rank)
220+
active_src = group.ranks.index(src)
221+
self.pyhccl.rank = group.ranks.index(self.rank)
222222

223223
try:
224224
yield active_src
225225
finally:
226226
if group:
227-
self.pynccl.comm = self.comm
227+
self.pyhccl.comm = self.comm
228228
if src is not None:
229-
self.pynccl.rank = self.rank
229+
self.pyhccl.rank = self.rank
230230

231231
def init_process_group(
232232
self,
@@ -235,6 +235,7 @@ def init_process_group(
235235
rank: int,
236236
world_size: int,
237237
timeout: timedelta = timedelta(seconds=300),
238+
**kwargs,
238239
):
239240
assert not self.initialized, "already initialized"
240241

@@ -257,10 +258,10 @@ def destroy_process_group(
257258
):
258259
assert self.initialized, "not initialized"
259260

260-
if group in self.sub_groups:
261-
subcomm = ctypes.c_void_p(group)
261+
if group and group.handle in self.sub_groups:
262+
subcomm = ctypes.c_void_p(group.handle)
262263
self.pyhccl.destroy_comm(subcomm)
263-
del self.sub_groups[group]
264+
del self.sub_groups[group.handle]
264265
return
265266

266267
self.pyhccl.destroy_comm()
@@ -297,7 +298,7 @@ def broadcast(
297298
):
298299
assert self.initialized, "not initialized"
299300

300-
with self._use_group(group) as local_rank:
301+
with self._use_group(group, src) as local_rank:
301302
self.pyhccl.broadcast(tensor, local_rank)
302303
current_stream().synchronize()
303304

@@ -318,8 +319,11 @@ def new_group(self, ranks: list[int], **kwargs) -> CommGroup:
318319
else:
319320
ranks.sort()
320321

321-
newcomm = self.pynccl.create_newcomm(ranks)
322-
if newcomm:
323-
group = CommGroup(newcomm.value, ranks)
324-
self.sub_groups[newcomm.value] = group
322+
if self.rank not in ranks:
323+
return
324+
325+
subcomm = self.pyhccl.create_subcomm(ranks)
326+
if subcomm:
327+
group = CommGroup(subcomm.value, ranks)
328+
self.sub_groups[subcomm.value] = group
325329
return group

checkpoint_engine/distributed/nccl.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -111,17 +111,17 @@ def __init__(self):
111111

112112
@contextmanager
113113
def _use_group(self, group: CommGroup | None, src: int | None = None):
114+
active_src = src
114115
if group:
115-
assert group.handle() in self.sub_groups, "invalid sub_group"
116-
newcomm = ctypes.c_void_p(group.handle())
116+
assert group.handle in self.sub_groups, "invalid sub_group"
117+
newcomm = ctypes.c_void_p(group.handle)
117118
self.pynccl.comm = newcomm
118-
active_src = src
119119

120120
if src is not None:
121-
assert src in group.ranks(), "src rank not in group"
121+
assert src in group.ranks, "src rank not in group"
122122
# convert src rank id in default world to newcomm
123-
active_src = group.ranks().index(src)
124-
self.pynccl.rank = group.ranks().index(self.rank)
123+
active_src = group.ranks.index(src)
124+
self.pynccl.rank = group.ranks.index(self.rank)
125125

126126
try:
127127
yield active_src
@@ -138,6 +138,7 @@ def init_process_group(
138138
rank: int,
139139
world_size: int,
140140
timeout: timedelta = timedelta(seconds=300),
141+
**kwargs,
141142
):
142143
assert not self.initialized, "already initialized"
143144

@@ -161,10 +162,10 @@ def destroy_process_group(
161162
):
162163
assert self.initialized, "not initialized"
163164

164-
if group.handle() in self.sub_groups:
165-
newcomm = ctypes.c_void_p(group.handle())
165+
if group and group.handle in self.sub_groups:
166+
newcomm = ctypes.c_void_p(group.handle)
166167
self.pynccl.destroy_comm(newcomm)
167-
del self.sub_groups[group.handle()]
168+
del self.sub_groups[group.handle]
168169
return
169170

170171
self.pynccl.destroy_comm()

checkpoint_engine/ps.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,6 @@ def __init__(
176176
auto_pg: bool = True,
177177
gpu_count: int | None = None,
178178
mem_fraction: float | None = None,
179-
custom_dist: bool = False,
180179
):
181180
"""
182181
Initialize the parameter server. env RANK, WORLD_SIZE and MASTER_ADDR must be set.
@@ -197,7 +196,6 @@ def __init__(
197196
self._local_rdma_devices: dict[str, set[int]] = defaultdict(set)
198197
self._remote_rdma_devices: dict[str, set[int]] = defaultdict(set)
199198
self._mem_fraction = mem_fraction or float(os.getenv("PS_MEM_FRACTION", "0.9"))
200-
self._custom_dist = custom_dist
201199

202200
assert self._rank is not None and self._rank >= 0, self._rank
203201
assert self._world_size and self._world_size > 0, self._world_size
@@ -498,9 +496,8 @@ def init_process_group(
498496
port=_get_master_port(master_port),
499497
rank=self._rank,
500498
world_size=self._world_size,
501-
custom_dist=self._custom_dist,
502-
backend=self.device_manager.backend,
503499
timeout=timeout,
500+
backend=self.device_manager.backend,
504501
)
505502
logger.info(f"[rank{self._rank}] init process group successfully.")
506503

examples/update.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,13 +159,14 @@ def join(
159159
parser.add_argument("--checkpoint-name", type=str, default="my-checkpoint-iter-0")
160160
parser.add_argument("--update-method", type=str, default="broadcast")
161161
parser.add_argument("--uds", type=str, default=None)
162-
parser.add_argument("--custom-dist", action="store_true")
162+
parser.add_argument("--custom-dist", type=str, default=None)
163163
args = parser.parse_args()
164164
rank = int(os.getenv("RANK"))
165165
world_size = int(os.getenv("WORLD_SIZE"))
166166

167167
req_func = req_inference(args.endpoint, args.inference_parallel_size, args.uds)
168-
ps = ParameterServer(auto_pg=True, custom_dist=args.custom_dist)
168+
dist.use_backend(args.custom_dist)
169+
ps = ParameterServer(auto_pg=True)
169170
if args.load_metas_file:
170171
join(
171172
ps,

0 commit comments

Comments (0)