Skip to content

Commit faf1dd0

Browse files
author
kip-cxj
committed
add dist.use_backend
1 parent 5ab13a1 commit faf1dd0

File tree

3 files changed

+22
-25
lines changed

3 files changed

+22
-25
lines changed

checkpoint_engine/distributed/base.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -223,32 +223,32 @@ def _common_all_gather_object(
223223
object_list[i] = _tensor_to_object(tensor, tensor_size)
224224

225225

226+
def use_backend(backend: str | None):
227+
global _BACKEND_INSTANCE
228+
229+
if not backend:
230+
return
231+
232+
mapping = {
233+
"vllm_nccl": ".nccl.DistributedNccl",
234+
"vllm_hccl": ".hccl.DistributedHccl",
235+
}
236+
if backend not in mapping:
237+
raise ValueError(f"Unsupported custom backend: {backend}")
238+
239+
module_path, class_name = mapping[backend].rsplit(".", 1)
240+
module = importlib.import_module(module_path, "checkpoint_engine.distributed")
241+
backend_class = getattr(module, class_name)
242+
_BACKEND_INSTANCE = backend_class()
243+
244+
226245
def init_process_group(
227246
host: str,
228247
port: int,
229248
rank: int,
230249
world_size: int,
231-
custom_dist: bool,
232-
backend: str,
233250
timeout: timedelta = timedelta(seconds=300),
234251
):
235-
global _BACKEND_INSTANCE
236-
237-
if not custom_dist:
238-
_BACKEND_INSTANCE = TorchBackend(backend_type=backend)
239-
else:
240-
mapping = {
241-
"nccl": ".nccl.DistributedNccl",
242-
"hccl": ".hccl.DistributedHccl",
243-
}
244-
if backend not in mapping:
245-
raise ValueError(f"Unsupported custom backend: {backend}")
246-
247-
module_path, class_name = mapping[backend].rsplit(".", 1)
248-
module = importlib.import_module(module_path, "checkpoint_engine.distributed")
249-
backend_class = getattr(module, class_name)
250-
_BACKEND_INSTANCE = backend_class()
251-
252252
_BACKEND_INSTANCE.init_process_group(host, port, rank, world_size, timeout)
253253

254254

checkpoint_engine/ps.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,6 @@ def __init__(
176176
auto_pg: bool = True,
177177
gpu_count: int | None = None,
178178
mem_fraction: float | None = None,
179-
custom_dist: bool = False,
180179
):
181180
"""
182181
Initialize the parameter server. env RANK, WORLD_SIZE and MASTER_ADDR must be set.
@@ -197,7 +196,6 @@ def __init__(
197196
self._local_rdma_devices: dict[str, set[int]] = defaultdict(set)
198197
self._remote_rdma_devices: dict[str, set[int]] = defaultdict(set)
199198
self._mem_fraction = mem_fraction or float(os.getenv("PS_MEM_FRACTION", "0.9"))
200-
self._custom_dist = custom_dist
201199

202200
assert self._rank is not None and self._rank >= 0, self._rank
203201
assert self._world_size and self._world_size > 0, self._world_size
@@ -498,8 +496,6 @@ def init_process_group(
498496
port=_get_master_port(master_port),
499497
rank=self._rank,
500498
world_size=self._world_size,
501-
custom_dist=self._custom_dist,
502-
backend=self.device_manager.backend,
503499
timeout=timeout,
504500
)
505501
logger.info(f"[rank{self._rank}] init process group successfully.")

examples/update.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,13 +159,14 @@ def join(
159159
parser.add_argument("--checkpoint-name", type=str, default="my-checkpoint-iter-0")
160160
parser.add_argument("--update-method", type=str, default="broadcast")
161161
parser.add_argument("--uds", type=str, default=None)
162-
parser.add_argument("--custom-dist", action="store_true")
162+
parser.add_argument("--custom-dist", type=str, default=None)
163163
args = parser.parse_args()
164164
rank = int(os.getenv("RANK"))
165165
world_size = int(os.getenv("WORLD_SIZE"))
166166

167167
req_func = req_inference(args.endpoint, args.inference_parallel_size, args.uds)
168-
ps = ParameterServer(auto_pg=True, custom_dist=args.custom_dist)
168+
dist.use_backend(args.custom_dist)
169+
ps = ParameterServer(auto_pg=True)
169170
if args.load_metas_file:
170171
join(
171172
ps,

0 commit comments

Comments
 (0)