Skip to content

Commit 0cdd61b

Browse files
author
kip-cxj
committed
fix pre-commit
1 parent 305886f commit 0cdd61b

File tree

6 files changed

+30
-79
lines changed

6 files changed

+30
-79
lines changed
Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,23 @@
11
from .base import (
22
Distributed,
3-
init_process_group,
4-
destroy_process_group,
5-
is_initialized,
63
all_gather_object,
74
all_reduce,
8-
broadcast,
95
barrier,
6+
broadcast,
7+
destroy_process_group,
8+
init_process_group,
9+
is_initialized,
1010
new_group,
1111
)
1212

1313
__all__ = [
1414
"Distributed",
15-
"init_process_group",
16-
"destroy_process_group",
17-
"is_initialized",
1815
"all_gather_object",
1916
"all_reduce",
20-
"broadcast",
2117
"barrier",
18+
"broadcast",
19+
"destroy_process_group",
20+
"init_process_group",
21+
"is_initialized",
2222
"new_group",
2323
]

checkpoint_engine/distributed/base.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
from abc import ABC, abstractmethod
1+
import importlib
22
import io
33
import pickle
4+
from abc import ABC, abstractmethod
45
from datetime import timedelta
56
from typing import Any, List
6-
import importlib
77

88
import torch
99
import torch.distributed as torch_dist
@@ -169,6 +169,7 @@ def is_initialized() -> bool:
169169
return torch_dist.is_initialized()
170170
return _BACKEND_INSTANCE.is_initialized()
171171

172+
172173
def all_gather_object(
173174
object_list: list[Any],
174175
obj: Any,

checkpoint_engine/distributed/hccl.py

Lines changed: 6 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,10 @@
1414
hcclComm_t,
1515
hcclDataType_t,
1616
hcclDataTypeEnum,
17-
hcclRedOp_t,
18-
hcclRedOpTypeEnum,
1917
hcclResult_t,
20-
hcclUniqueId,
2118
)
2219
from vllm_ascend.utils import current_stream
20+
2321
from checkpoint_engine.distributed.base import Distributed, _common_all_gather_object
2422

2523

@@ -214,7 +212,6 @@ def init_process_group(
214212
self.comm = self.pyhccl.comm
215213
self.initialized = True
216214

217-
218215
def destroy_process_group(
219216
self,
220217
group=None,
@@ -236,13 +233,7 @@ def destroy_process_group(
236233
def is_initialized(self) -> bool:
237234
return self.initialized
238235

239-
240-
def all_gather_object(
241-
self,
242-
object_list: list[Any],
243-
obj: Any,
244-
group=None
245-
):
236+
def all_gather_object(self, object_list: list[Any], obj: Any, group=None):
246237
assert self.initialized, "not initialized"
247238

248239
if group:
@@ -257,12 +248,7 @@ def all_gather_object(
257248
self.pyhccl.comm = self.comm
258249

259250

260-
def all_reduce(
261-
self,
262-
tensor: torch.Tensor,
263-
op=ReduceOp.SUM,
264-
group=None
265-
):
251+
def all_reduce(self, tensor: torch.Tensor, op=ReduceOp.SUM, group=None):
266252
assert self.initialized, "not initialized"
267253

268254
if group:
@@ -278,12 +264,7 @@ def all_reduce(
278264
self.pyhccl.comm = self.comm
279265

280266

281-
def broadcast(
282-
self,
283-
tensor: torch.Tensor,
284-
src=None,
285-
group=None
286-
):
267+
def broadcast(self, tensor: torch.Tensor, src=None, group=None):
287268
assert self.initialized, "not initialized"
288269

289270
if group:
@@ -303,10 +284,7 @@ def broadcast(
303284
self.pyhccl.rank = self.rank
304285

305286

306-
def barrier(
307-
self,
308-
group=None
309-
):
287+
def barrier(self, group=None):
310288
assert self.initialized, "not initialized"
311289

312290
if group:
@@ -322,10 +300,7 @@ def barrier(
322300
self.pyhccl.comm = self.comm
323301

324302

325-
def new_group(
326-
self,
327-
ranks
328-
):
303+
def new_group(self, ranks):
329304
assert self.initialized, "not initialized"
330305

331306
# if ranks is None or [], using the world instead

checkpoint_engine/distributed/nccl.py

Lines changed: 7 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
11
import ctypes
22
from datetime import timedelta
3-
from typing import Any, List, Optional
3+
from typing import Any
44

55
import torch
66
from torch.distributed import ReduceOp
77
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
88
from vllm.distributed.device_communicators.pynccl_wrapper import (
99
Function,
1010
NCCLLibrary,
11-
buffer_type,
1211
ncclComm_t,
1312
ncclResult_t,
1413
)
1514
from vllm.distributed.utils import StatelessProcessGroup
1615
from vllm.utils import current_stream
16+
1717
from checkpoint_engine.distributed.base import Distributed, _common_all_gather_object
1818

1919

@@ -132,7 +132,6 @@ def init_process_group(
132132
self.comm = self.pynccl.comm
133133
self.initialized = True
134134

135-
136135
def destroy_process_group(
137136
self,
138137
group=None,
@@ -155,12 +154,7 @@ def is_initialized(self) -> bool:
155154
return self.initialized
156155

157156

158-
def all_gather_object(
159-
self,
160-
object_list: list[Any],
161-
obj: Any,
162-
group=None
163-
):
157+
def all_gather_object(self, object_list: list[Any], obj: Any, group=None):
164158
assert self.initialized, "not initialized"
165159

166160
if group:
@@ -175,12 +169,7 @@ def all_gather_object(
175169
self.pynccl.comm = self.comm
176170

177171

178-
def all_reduce(
179-
self,
180-
tensor: torch.Tensor,
181-
op=ReduceOp.SUM,
182-
group=None
183-
):
172+
def all_reduce(self, tensor: torch.Tensor, op=ReduceOp.SUM, group=None):
184173
assert self.initialized, "not initialized"
185174

186175
if group:
@@ -196,12 +185,7 @@ def all_reduce(
196185
self.pynccl.comm = self.comm
197186

198187

199-
def broadcast(
200-
self,
201-
tensor: torch.Tensor,
202-
src=None,
203-
group=None
204-
):
188+
def broadcast(self, tensor: torch.Tensor, src=None, group=None):
205189
assert self.initialized, "not initialized"
206190

207191
if group:
@@ -221,10 +205,7 @@ def broadcast(
221205
self.pynccl.rank = self.rank
222206

223207

224-
def barrier(
225-
self,
226-
group=None
227-
):
208+
def barrier(self, group=None):
228209
assert self.initialized, "not initialized"
229210

230211
if group:
@@ -240,10 +221,7 @@ def barrier(
240221
self.pynccl.comm = self.comm
241222

242223

243-
def new_group(
244-
self,
245-
ranks
246-
):
224+
def new_group(self, ranks):
247225
assert self.initialized, "not initialized"
248226

249227
# ranks is None or []

checkpoint_engine/ps.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -519,9 +519,7 @@ def init_process_group(
519519
)
520520
logger.info(f"[rank{self._rank}] init process group successfully.")
521521

522-
def store_based_barrier(
523-
self, store, timeout: timedelta = timedelta(minutes=5)
524-
) -> None:
522+
def store_based_barrier(self, store, timeout: timedelta = timedelta(minutes=5)) -> None:
525523
"""
526524
Perform a store-based barrier synchronization across all ranks.
527525
@@ -571,11 +569,10 @@ def update(
571569
try:
572570
master_addr = os.getenv("MASTER_ADDR") or master_addr
573571
assert master_addr, "master_addr is required"
574-
if self._auto_pg:
575-
if not dist.is_initialized():
576-
self.init_process_group(
577-
timeout=timeout, master_addr=master_addr, master_port=master_port
578-
)
572+
if self._auto_pg and not dist.is_initialized():
573+
self.init_process_group(
574+
timeout=timeout, master_addr=master_addr, master_port=master_port
575+
)
579576
# if ranks is None or [], it will use fully broadcast to update to all ranks
580577
ranks_group = dist.new_group(ranks) if ranks else None
581578
self._update_per_bucket(checkpoint_name, req_func, ranks_group, ranks)

examples/update.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@
1313
from loguru import logger
1414
from safetensors import safe_open
1515

16+
import checkpoint_engine.distributed as dist
1617
from checkpoint_engine import request_inference_to_update
1718
from checkpoint_engine.ps import ParameterServer
18-
import checkpoint_engine.distributed as dist
1919

2020

2121
@contextmanager

0 commit comments

Comments (0)