11import ctypes
22from datetime import timedelta
3- from typing import Any
3+ from typing import Any , ClassVar
44
55import torch
66from torch .distributed import ReduceOp
2222
2323
2424class HcclCommConfig (ctypes .Structure ):
25- _fields_ = [
25+ _fields_ : ClassVar [ list [ tuple [ str , Any ]]] = [
2626 ("size" , ctypes .c_size_t ),
2727 ("magic_word" , ctypes .c_uint32 ),
2828 ("version" , ctypes .c_uint32 ),
@@ -81,15 +81,29 @@ class HcclCommConfig(ctypes.Structure):
8181]
8282
8383
def hccl_all_gather(
    self,  # noqa: ANN001
    send_buf: buffer_type,
    recv_buf: buffer_type,
    count: ctypes.c_uint64,
    data_type: hcclDataType_t,
    comm: hcclComm_t,
    stream: aclrtStream_t,
):
    """Gather ``send_buf`` from every rank of *comm* into ``recv_buf``.

    Thin wrapper over the native ``HcclAllGather`` entry point: the raw
    status code is routed through ``self.HCCL_CHECK`` so a non-zero
    result raises instead of being silently dropped.
    """
    status = self._funcs["HcclAllGather"](
        send_buf, recv_buf, count, data_type, comm, stream
    )
    self.HCCL_CHECK(status)
8896
8997
9098def hccl_create_subcomm_config (
91- self , comm , ranks_size , c_rank_ids , subcomm_id , subcomm_rank , comm_config
92- ):
99+ self , # noqa: ANN001
100+ comm : hcclComm_t ,
101+ ranks_size : ctypes .c_uint32 ,
102+ c_rank_ids : ctypes .POINTER (ctypes .c_uint32 ),
103+ subcomm_id : ctypes .c_uint64 ,
104+ subcomm_rank : ctypes .c_uint64 ,
105+ comm_config : HcclCommConfig ,
106+ ) -> hcclComm_t :
93107 subcomm = hcclComm_t ()
94108 self .HCCL_CHECK (
95109 self ._funcs ["HcclCreateSubCommConfig" ](
@@ -112,17 +126,19 @@ def hccl_create_subcomm_config(
112126
113127
114128class PyHcclCommunicatorEx (PyHcclCommunicator ):
115- def __init__ (self , group , device ):
129+ def __init__ (self , group : StatelessProcessGroup , device : torch . device ):
116130 super ().__init__ (group , device )
117131 self .subcomm_id = 1
118132
119- def destroy_comm (self , comm = None ):
133+ def destroy_comm (self , comm : hcclComm_t = None ):
120134 if comm :
121135 self .hccl .hcclCommDestroy (comm )
122136 else :
123137 self .hccl .hcclCommDestroy (self .comm )
124138
125- def all_gather (self , out_tensor : torch .Tensor , in_tensor : torch .Tensor , stream = None ):
139+ def all_gather (
140+ self , out_tensor : torch .Tensor , in_tensor : torch .Tensor , stream : torch .npu .Stream = None
141+ ) -> torch .Tensor :
126142 if self .disabled :
127143 return
128144 assert in_tensor .device == self .device , (
@@ -141,7 +157,7 @@ def all_gather(self, out_tensor: torch.Tensor, in_tensor: torch.Tensor, stream=N
141157 )
142158 return out_tensor
143159
144- def create_subcomm (self , ranks ) :
160+ def create_subcomm (self , ranks : list [ int ]) -> hcclComm_t :
145161 comm_config = HcclCommConfig (
146162 size = 312 ,
147163 magic_word = 0xF0F0F0F0 ,
@@ -214,7 +230,7 @@ def init_process_group(
214230
215231 def destroy_process_group (
216232 self ,
217- group = None ,
233+ group : int | None = None ,
218234 ):
219235 assert self .initialized , "not initialized"
220236
@@ -232,7 +248,7 @@ def destroy_process_group(
232248 def is_initialized (self ) -> bool :
233249 return self .initialized
234250
235- def all_gather_object (self , object_list : list [Any ], obj : Any , group = None ):
251+ def all_gather_object (self , object_list : list [Any ], obj : Any , group : int | None = None ):
236252 assert self .initialized , "not initialized"
237253
238254 if group :
@@ -246,7 +262,9 @@ def all_gather_object(self, object_list: list[Any], obj: Any, group=None):
246262 if group :
247263 self .pyhccl .comm = self .comm
248264
249- def all_reduce (self , tensor : torch .Tensor , op = ReduceOp .SUM , group = None ):
265+ def all_reduce (
266+ self , tensor : torch .Tensor , op : ReduceOp = ReduceOp .SUM , group : int | None = None
267+ ):
250268 assert self .initialized , "not initialized"
251269
252270 if group :
@@ -261,7 +279,7 @@ def all_reduce(self, tensor: torch.Tensor, op=ReduceOp.SUM, group=None):
261279 if group :
262280 self .pyhccl .comm = self .comm
263281
264- def broadcast (self , tensor : torch .Tensor , src = None , group = None ):
282+ def broadcast (self , tensor : torch .Tensor , src : int | None = None , group : int | None = None ):
265283 assert self .initialized , "not initialized"
266284
267285 if group :
@@ -280,7 +298,7 @@ def broadcast(self, tensor: torch.Tensor, src=None, group=None):
280298 self .pyhccl .comm = self .comm
281299 self .pyhccl .rank = self .rank
282300
283- def barrier (self , group = None ):
301+ def barrier (self , group : int | None = None ):
284302 assert self .initialized , "not initialized"
285303
286304 if group :
@@ -295,7 +313,7 @@ def barrier(self, group=None):
295313 if group :
296314 self .pyhccl .comm = self .comm
297315
298- def new_group (self , ranks ) :
316+ def new_group (self , ranks : list [ int ]) -> int :
299317 assert self .initialized , "not initialized"
300318
301319 # if ranks is None or [], using the world instead
0 commit comments