@@ -28,7 +28,8 @@ def is_torch_npu_available() -> bool:
2828 try :
2929 if hasattr (torch , "npu" ) and callable (getattr (torch .npu , "is_available" , None )):
3030 return torch .npu .is_available ()
31- return False
31+ else :
32+ return False
3233 except ImportError :
3334 return False
3435
@@ -37,7 +38,7 @@ def __init__(self):
3738 self .device_type = self ._detect_device_type ()
3839 self ._setup_device_module ()
3940
40- def _detect_device_type (self ):
41+ def _detect_device_type (self ) -> str :
4142 if is_torch_npu_available ():
4243 return "npu"
4344 elif torch .cuda .is_available ():
@@ -50,7 +51,7 @@ def _setup_device_module(self):
5051 elif self .device_type == "cuda" :
5152 self .device_module = torch .cuda
5253
53- def get_backend (self ):
54+ def get_backend (self ) -> str :
5455 if self .device_type == "npu" :
5556 return "hccl"
5657 elif self .device_type == "cuda" :
@@ -624,7 +625,7 @@ def _get_master_port(master_port: int | None = None) -> int:
624625
625626
626627class P2PStore :
627- def __init__ (self , device_manager ):
628+ def __init__ (self , device_manager : DeviceManager ):
628629 from mooncake .engine import TransferEngine
629630
630631 self .rank = int (os .getenv ("RANK" ))
0 commit comments