Skip to content

Commit 9781aa5

Browse files
author
kip-cxj
committed
add type hints
1 parent 2846606 commit 9781aa5

File tree

4 files changed: 84 additions and 49 deletions.

checkpoint_engine/distributed/base.py

Lines changed: 27 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,8 @@
77

88
import torch
99
import torch.distributed as torch_dist
10+
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
11+
from vllm_ascend.distributed.device_communicators.pyhccl import PyHcclCommunicator
1012

1113

1214
class Distributed(ABC):
@@ -24,7 +26,7 @@ def init_process_group(
2426
@abstractmethod
2527
def destroy_process_group(
2628
self,
27-
group,
29+
group: torch_dist.ProcessGroup | int | None = None,
2830
):
2931
raise NotImplementedError
3032

@@ -37,7 +39,7 @@ def all_gather_object(
3739
self,
3840
object_list: list[Any],
3941
obj: Any,
40-
group,
42+
group: torch_dist.ProcessGroup | int | None = None,
4143
):
4244
raise NotImplementedError
4345

@@ -46,7 +48,7 @@ def all_reduce(
4648
self,
4749
tensor: torch.Tensor,
4850
op: torch_dist.ReduceOp,
49-
group,
51+
group: torch_dist.ProcessGroup | int | None = None,
5052
):
5153
raise NotImplementedError
5254

@@ -55,14 +57,14 @@ def broadcast(
5557
self,
5658
tensor: torch.Tensor,
5759
src: int,
58-
group,
60+
group: torch_dist.ProcessGroup | int | None = None,
5961
):
6062
raise NotImplementedError
6163

6264
@abstractmethod
6365
def barrier(
6466
self,
65-
group,
67+
group: torch_dist.ProcessGroup | int | None = None,
6668
):
6769
raise NotImplementedError
6870

@@ -81,7 +83,7 @@ def new_group(
8183
_unpickler = pickle.Unpickler
8284

8385

84-
def _object_to_tensor(obj, device):
86+
def _object_to_tensor(obj: Any, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
8587
f = io.BytesIO()
8688
_pickler(f).dump(obj)
8789
byte_storage = torch.ByteStorage._from_buffer(f.getvalue())
@@ -90,13 +92,15 @@ def _object_to_tensor(obj, device):
9092
return byte_tensor, local_size
9193

9294

93-
def _tensor_to_object(tensor, tensor_size):
95+
def _tensor_to_object(tensor: torch.Tensor, tensor_size: int) -> Any:
9496
tensor = tensor.cpu()
9597
buf = tensor.numpy().tobytes()[:tensor_size]
9698
return _unpickler(io.BytesIO(buf)).load()
9799

98100

99-
def _flatten_for_scatter_gather(tensor_list, copy=False):
101+
def _flatten_for_scatter_gather(
102+
tensor_list: list[torch.Tensor], copy: bool = False
103+
) -> torch.Tensor:
100104
if not tensor_list:
101105
raise RuntimeError("Received an empty list.")
102106
t = tensor_list[0]
@@ -109,7 +113,13 @@ def _flatten_for_scatter_gather(tensor_list, copy=False):
109113
return buffer
110114

111115

112-
def _common_all_gather_object(comm, device, world_size, object_list, object):
116+
def _common_all_gather_object(
117+
comm: PyNcclCommunicator | PyHcclCommunicator | Any,
118+
device: torch.device,
119+
world_size: int,
120+
object_list: list[Any],
121+
object: Any,
122+
):
113123
input_tensor, local_size = _object_to_tensor(object, device)
114124
object_sizes_tensor = torch.empty(world_size, dtype=torch.long, device=device)
115125
comm.all_gather(object_sizes_tensor, local_size)
@@ -157,7 +167,7 @@ def init_process_group(
157167
_BACKEND_INSTANCE.init_process_group(host, port, rank, world_size, timeout)
158168

159169

160-
def destroy_process_group(group=None):
170+
def destroy_process_group(group: torch_dist.ProcessGroup | int | None = None):
161171
if _BACKEND_INSTANCE is None:
162172
torch_dist.destroy_process_group(group)
163173
return
@@ -173,7 +183,7 @@ def is_initialized() -> bool:
173183
def all_gather_object(
174184
object_list: list[Any],
175185
obj: Any,
176-
group=None,
186+
group: torch_dist.ProcessGroup | int | None = None,
177187
):
178188
if _BACKEND_INSTANCE is None:
179189
torch_dist.all_gather_object(object_list, obj, group)
@@ -183,8 +193,8 @@ def all_gather_object(
183193

184194
def all_reduce(
185195
tensor: torch.Tensor,
186-
op=torch_dist.ReduceOp.SUM,
187-
group=None,
196+
op: torch_dist.ReduceOp = torch_dist.ReduceOp.SUM,
197+
group: torch_dist.ProcessGroup | int | None = None,
188198
**kwargs,
189199
):
190200
if _BACKEND_INSTANCE is None:
@@ -195,8 +205,8 @@ def all_reduce(
195205

196206
def broadcast(
197207
tensor: torch.Tensor,
198-
src=None,
199-
group=None,
208+
src: int = 0,
209+
group: torch_dist.ProcessGroup | int | None = None,
200210
**kwargs,
201211
):
202212
if _BACKEND_INSTANCE is None:
@@ -205,14 +215,14 @@ def broadcast(
205215
_BACKEND_INSTANCE.broadcast(tensor, src, group)
206216

207217

208-
def barrier(group=None, **kwargs):
218+
def barrier(group: torch_dist.ProcessGroup | int | None = None, **kwargs):
209219
if _BACKEND_INSTANCE is None:
210220
torch_dist.barrier(group, **kwargs)
211221
return
212222
_BACKEND_INSTANCE.barrier(group)
213223

214224

215-
def new_group(ranks: list[int], **kwargs):
225+
def new_group(ranks: list[int], **kwargs) -> torch_dist.ProcessGroup | int | None:
216226
if _BACKEND_INSTANCE is None:
217227
return torch_dist.new_group(ranks, **kwargs)
218228
return _BACKEND_INSTANCE.new_group(ranks)

checkpoint_engine/distributed/hccl.py

Lines changed: 33 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
import ctypes
22
from datetime import timedelta
3-
from typing import Any
3+
from typing import Any, ClassVar
44

55
import torch
66
from torch.distributed import ReduceOp
@@ -22,7 +22,7 @@
2222

2323

2424
class HcclCommConfig(ctypes.Structure):
25-
_fields_ = [
25+
_fields_: ClassVar[list[tuple[str, Any]]] = [
2626
("size", ctypes.c_size_t),
2727
("magic_word", ctypes.c_uint32),
2828
("version", ctypes.c_uint32),
@@ -81,15 +81,29 @@ class HcclCommConfig(ctypes.Structure):
8181
]
8282

8383

84-
def hccl_all_gather(self, send_buf, recv_buf, count, data_type, comm, stream):
84+
def hccl_all_gather(
85+
self, # noqa: ANN001
86+
send_buf: buffer_type,
87+
recv_buf: buffer_type,
88+
count: ctypes.c_uint64,
89+
data_type: hcclDataType_t,
90+
comm: hcclComm_t,
91+
stream: aclrtStream_t,
92+
):
8593
self.HCCL_CHECK(
8694
self._funcs["HcclAllGather"](send_buf, recv_buf, count, data_type, comm, stream)
8795
)
8896

8997

9098
def hccl_create_subcomm_config(
91-
self, comm, ranks_size, c_rank_ids, subcomm_id, subcomm_rank, comm_config
92-
):
99+
self, # noqa: ANN001
100+
comm: hcclComm_t,
101+
ranks_size: ctypes.c_uint32,
102+
c_rank_ids: ctypes.POINTER(ctypes.c_uint32),
103+
subcomm_id: ctypes.c_uint64,
104+
subcomm_rank: ctypes.c_uint64,
105+
comm_config: HcclCommConfig,
106+
) -> hcclComm_t:
93107
subcomm = hcclComm_t()
94108
self.HCCL_CHECK(
95109
self._funcs["HcclCreateSubCommConfig"](
@@ -112,17 +126,19 @@ def hccl_create_subcomm_config(
112126

113127

114128
class PyHcclCommunicatorEx(PyHcclCommunicator):
115-
def __init__(self, group, device):
129+
def __init__(self, group: StatelessProcessGroup, device: torch.device):
116130
super().__init__(group, device)
117131
self.subcomm_id = 1
118132

119-
def destroy_comm(self, comm=None):
133+
def destroy_comm(self, comm: hcclComm_t = None):
120134
if comm:
121135
self.hccl.hcclCommDestroy(comm)
122136
else:
123137
self.hccl.hcclCommDestroy(self.comm)
124138

125-
def all_gather(self, out_tensor: torch.Tensor, in_tensor: torch.Tensor, stream=None):
139+
def all_gather(
140+
self, out_tensor: torch.Tensor, in_tensor: torch.Tensor, stream: torch.npu.Stream = None
141+
) -> torch.Tensor:
126142
if self.disabled:
127143
return
128144
assert in_tensor.device == self.device, (
@@ -141,7 +157,7 @@ def all_gather(self, out_tensor: torch.Tensor, in_tensor: torch.Tensor, stream=N
141157
)
142158
return out_tensor
143159

144-
def create_subcomm(self, ranks):
160+
def create_subcomm(self, ranks: list[int]) -> hcclComm_t:
145161
comm_config = HcclCommConfig(
146162
size=312,
147163
magic_word=0xF0F0F0F0,
@@ -214,7 +230,7 @@ def init_process_group(
214230

215231
def destroy_process_group(
216232
self,
217-
group=None,
233+
group: int | None = None,
218234
):
219235
assert self.initialized, "not initialized"
220236

@@ -232,7 +248,7 @@ def destroy_process_group(
232248
def is_initialized(self) -> bool:
233249
return self.initialized
234250

235-
def all_gather_object(self, object_list: list[Any], obj: Any, group=None):
251+
def all_gather_object(self, object_list: list[Any], obj: Any, group: int | None = None):
236252
assert self.initialized, "not initialized"
237253

238254
if group:
@@ -246,7 +262,9 @@ def all_gather_object(self, object_list: list[Any], obj: Any, group=None):
246262
if group:
247263
self.pyhccl.comm = self.comm
248264

249-
def all_reduce(self, tensor: torch.Tensor, op=ReduceOp.SUM, group=None):
265+
def all_reduce(
266+
self, tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, group: int | None = None
267+
):
250268
assert self.initialized, "not initialized"
251269

252270
if group:
@@ -261,7 +279,7 @@ def all_reduce(self, tensor: torch.Tensor, op=ReduceOp.SUM, group=None):
261279
if group:
262280
self.pyhccl.comm = self.comm
263281

264-
def broadcast(self, tensor: torch.Tensor, src=None, group=None):
282+
def broadcast(self, tensor: torch.Tensor, src: int | None = None, group: int | None = None):
265283
assert self.initialized, "not initialized"
266284

267285
if group:
@@ -280,7 +298,7 @@ def broadcast(self, tensor: torch.Tensor, src=None, group=None):
280298
self.pyhccl.comm = self.comm
281299
self.pyhccl.rank = self.rank
282300

283-
def barrier(self, group=None):
301+
def barrier(self, group: int | None = None):
284302
assert self.initialized, "not initialized"
285303

286304
if group:
@@ -295,7 +313,7 @@ def barrier(self, group=None):
295313
if group:
296314
self.pyhccl.comm = self.comm
297315

298-
def new_group(self, ranks):
316+
def new_group(self, ranks: list[int]) -> int:
299317
assert self.initialized, "not initialized"
300318

301319
# if ranks is None or [], using the world instead

0 commit comments

Comments
 (0)