 import os
 import pickle
 import random
-import socket
 import threading
 import time
 from collections import defaultdict
 from collections.abc import Callable
 from datetime import timedelta
-from functools import lru_cache
 from typing import TYPE_CHECKING, Annotated, Any, BinaryIO, NamedTuple
 
 import httpx
 from safetensors.torch import safe_open
 from torch.multiprocessing.reductions import reduce_tensor
 
+from checkpoint_engine.device_utils import DeviceManager, get_ip, npu_generate_uuid
+
 
 if TYPE_CHECKING:
     from typing import TypeVar
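
The new `checkpoint_engine.device_utils` module itself is not part of this diff. As a rough sketch of what the imported `DeviceManager` could look like, the attribute names `device_type`, `device_module`, and `backend` are inferred from how the rest of the diff uses them; the torch_npu/HCCL detection logic is an assumption:

```python
# Hypothetical sketch of DeviceManager, inferred from its usage in this diff.
import torch


class DeviceManager:
    """Dispatches device-specific torch APIs for CUDA GPUs and Ascend NPUs."""

    def __init__(self) -> None:
        # Assumption: torch_npu, when installed, registers the torch.npu module.
        if hasattr(torch, "npu") and torch.npu.is_available():
            self.device_type = "npu"      # used as the `device=` string below
            self.device_module = torch.npu
            self.backend = "hccl"         # Ascend collective backend
        else:
            self.device_type = "cuda"
            self.device_module = torch.cuda
            self.backend = "nccl"
```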
@@ -254,28 +254,16 @@ def _concat_tp_weights(
     return torch.cat([w for w in tp_weights], dim=tp_concat_dim)
 
 
-def _get_physical_gpu_id(device_index: int | None = None) -> str:
+def _get_physical_gpu_id(device_manager: DeviceManager, device_index: int | None = None) -> str:
     try:
-        return f"GPU-{torch.cuda.get_device_properties(device_index).uuid!s}"
+        if device_manager.device_type == "npu":
+            return f"NPU-{npu_generate_uuid()}"
+        else:
+            return f"GPU-{device_manager.device_module.get_device_properties(device_index).uuid!s}"
     except AssertionError as e:
         raise ValueError(f"fail to get physical gpu id {device_index}") from e
 
 
-@lru_cache(maxsize=1)
-def _get_ip() -> str:
-    try:
-        # try to get ip from network interface
-        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
-            s.connect(("8.8.8.8", 80))
-            return s.getsockname()[0]
-    except Exception as e:  # noqa: BLE001
-        # fallback to get ip from hostname
-        logger.warning(
-            f"fail to get ip from network interface, fallback to get ip from hostname: {e}"
-        )
-        return socket.gethostbyname(socket.gethostname())
-
-
 def _ibv_get_device_list() -> list[str]:
     lib = ctypes.CDLL("libibverbs.so.1")
     lib.ibv_get_device_list.argtypes = [ctypes.POINTER(ctypes.c_int)]  # int *num_devices
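
The deleted `_get_ip` is replaced by `get_ip` imported from `checkpoint_engine.device_utils`. Assuming the helper was moved rather than rewritten, its relocated form would be the deleted body under the new name:

```python
# Presumed relocation of the deleted _get_ip into checkpoint_engine.device_utils.
# The body is the removed code verbatim; only the name and module change.
import socket
from functools import lru_cache

from loguru import logger  # assumption: same logger object the original file uses


@lru_cache(maxsize=1)
def get_ip() -> str:
    try:
        # try to get ip from a routable network interface; a UDP connect sends
        # no packets, it only selects the local source address
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
            s.connect(("8.8.8.8", 80))
            return s.getsockname()[0]
    except Exception as e:  # noqa: BLE001
        # fallback to get ip from hostname
        logger.warning(
            f"fail to get ip from network interface, fallback to get ip from hostname: {e}"
        )
        return socket.gethostbyname(socket.gethostname())
```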
@@ -677,14 +665,14 @@ def _get_bcast_rank_map(world_size: int, ranks: list[int] | None) -> dict[int, int]:
 
 
 class P2PStore:
-    def __init__(self):
+    def __init__(self, device_manager: DeviceManager):
         from mooncake.engine import TransferEngine
 
         self.rank = int(os.getenv("RANK"))
-        gpu_count = torch.cuda.device_count()
+        gpu_count = device_manager.device_module.device_count()
         local_rank = self.rank % gpu_count
         self.device = _get_my_rdma_device(local_rank, gpu_count, _get_rdma_devices())
-        self.ip = _get_ip()
+        self.ip = get_ip()
 
         # we will start at most 8 ps processes, so we use 8 retries to avoid port conflicts in extreme cases
         retry_count = 8
@@ -761,7 +749,8 @@ def __init__(
         """
         self._rank = rank or int(os.environ.get("RANK", None))
         self._world_size = world_size or int(os.environ.get("WORLD_SIZE", None))
-        self._gpu_count = gpu_count or torch.cuda.device_count()
+        self.device_manager = DeviceManager()
+        self._gpu_count = gpu_count or self.device_manager.device_module.device_count()
         self._local_rank = self._rank % self._gpu_count
         self._auto_pg = auto_pg
         self._all_hosts = []
@@ -775,7 +764,7 @@ def __init__(
         assert (
             self._gpu_count is not None
             and self._gpu_count > 0
-            and self._gpu_count <= torch.cuda.device_count()
+            and self._gpu_count <= self.device_manager.device_module.device_count()
         ), self._gpu_count
         assert (
             self._mem_fraction is not None and self._mem_fraction > 0 and self._mem_fraction <= 1
@@ -788,14 +777,14 @@ def __init__(
         # dict key is owner_rank, value is a bucket metas list in owner_rank
         self._current_global_parameter_metas: dict[int, MemoryBufferMetaList] = {}
         try:
-            self._p2p_store = P2PStore()
+            self._p2p_store = P2PStore(self.device_manager)
         except ImportError as e:
             logger.warning(f"[rank{self._rank}] fail to initialize p2p store due to {e}")
             self._p2p_store = None
 
         device_index = self._local_rank
-        torch.cuda.set_device(device_index)
-        self._device_uuid = _get_physical_gpu_id(device_index)
+        self.device_manager.device_module.set_device(device_index)
+        self._device_uuid = _get_physical_gpu_id(self.device_manager, device_index)
         self._rdma_device = None if self._p2p_store is None else self._p2p_store.device
 
     def _logger_rank0(self, msg: str):
@@ -885,7 +874,7 @@ def gather_metas(self, checkpoint_name: str):
                 for x in self._memory_pool.get(checkpoint_name, [])
             ],
             p2p_store_addr=None if self._p2p_store is None else self._p2p_store.addr,
-            host_ip=_get_ip(),
+            host_ip=get_ip(),
             device_uuid=self._device_uuid,
             rdma_device=self._rdma_device or "",
         )
@@ -948,7 +937,7 @@ def init_process_group(
             is_master=self._rank == 0,
         )
         dist.init_process_group(
-            backend="nccl",
+            backend=self.device_manager.backend,
             world_size=self._world_size,
             rank=self._rank,
             timeout=timeout,
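
Taking the backend from the `DeviceManager` keeps this call site identical across device families. A minimal sketch of the effect, assuming the `DeviceManager` sketched earlier and the usual rendezvous environment variables:

```python
# Sketch: one init path for both backends, assuming DeviceManager as above.
import os

import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")

dm = DeviceManager()
# On a CUDA host this selects "nccl"; on an Ascend host it would select "hccl".
dist.init_process_group(backend=dm.backend, world_size=1, rank=0)
dist.destroy_process_group()
```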
@@ -994,12 +983,12 @@ def update(
             if self._auto_pg:
                 dist.destroy_process_group()
 
-            torch.cuda.empty_cache()
+            self.device_manager.device_module.empty_cache()
 
             logger.info(
                 f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} done. "
-                f"Current CUDA allocated {torch.cuda.memory_allocated() / 1024 / 1024} MB, "
-                f"reserved {torch.cuda.memory_reserved() / 1024 / 1024} MB."
+                f"Current CUDA allocated {self.device_manager.device_module.memory_allocated() / 1024 / 1024} MB, "
+                f"reserved {self.device_manager.device_module.memory_reserved() / 1024 / 1024} MB."
             )
         except Exception as e:
             logger.exception(
@@ -1023,13 +1012,15 @@ def _detect_bucket_size(self, *, disable_h2d_buffer: bool = False) -> tuple[int,
         tensor = torch.tensor(
             [
                 # proportion of current cuda free memory bytes
-                int(float(torch.cuda.mem_get_info()[0]) * self._mem_fraction),
+                int(
+                    float(self.device_manager.device_module.mem_get_info()[0]) * self._mem_fraction
+                ),
                 # we use negative value to reuse allreduce min operation
                 # for getting the max value of zmq_addr_counter in all ranks
                 -self._zmq_addr_counter,
             ],
             dtype=torch.int64,
-            device="cuda",
+            device=self.device_manager.device_type,
         )
         dist.all_reduce(tensor, op=dist.ReduceOp.MIN)
         tensor = tensor.cpu()
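
The negative-value trick above folds a max into the single MIN all-reduce: the minimum of `-x_i` over ranks equals the negated maximum of `x_i`. A self-contained illustration with plain tensors, no process group needed (the values are made up):

```python
# Worked example of the negate-to-max trick used in _detect_bucket_size.
import torch

free_bytes = [900, 750, 800]   # per-rank free-memory budgets -> want the MIN
zmq_counters = [3, 7, 5]       # per-rank counters            -> want the MAX

# Each rank contributes [free_bytes, -zmq_counter]; an elementwise MIN across
# ranks (what dist.all_reduce(..., ReduceOp.MIN) computes) then yields both.
stacked = torch.tensor([[f, -c] for f, c in zip(free_bytes, zmq_counters)])
reduced = stacked.min(dim=0).values
bucket_limit = int(reduced[0])   # 750: the true minimum
max_counter = int(-reduced[1])   # 7: negate back to recover the maximum
```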
@@ -1092,7 +1083,7 @@ def _copy_to_buffer(
         assert offset == bucket.size, f"offset {offset} != bucket_size {bucket.size}"
         if owner_rank is not None:
             self._p2p_store.batch_transfer_sync_read(target_addr, buf_ptrs, remote_ptrs, lens)
-        torch.cuda.synchronize()
+        self.device_manager.device_module.synchronize()
 
     def init_process_group_for_ranks(
         self,
@@ -1199,7 +1190,7 @@ def _update_per_bucket(
         h2d_buffer: torch.Tensor | None = (
             None
             if disable_h2d_buffer
-            else torch.empty(bucket_size, dtype=torch.uint8, device="cuda")
+            else torch.empty(bucket_size, dtype=torch.uint8, device=self.device_manager.device_type)
         )
         # p2p store need to register h2d_buffer to let other ranks read
         if ranks:
@@ -1212,7 +1203,9 @@ def _update_per_bucket(
                 continue
             receiver_rank_buckets.append((owner_rank, bucket))
 
-        buffer = torch.empty(bucket_size * 2, dtype=torch.uint8, device="cuda")
+        buffer = torch.empty(
+            bucket_size * 2, dtype=torch.uint8, device=self.device_manager.device_type
+        )
         handle = reduce_tensor(buffer)
 
         buckets_by_receiver_rank: dict[int, list[H2DBucket]] = defaultdict(list)
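
`reduce_tensor` (from `torch.multiprocessing.reductions`) returns the `(rebuild_fn, args)` pair that multiprocessing would normally produce when pickling a device tensor, so the double-sized broadcast buffer can be reopened zero-copy by another process on the same device. A standalone sketch of that round trip, showing standard PyTorch behavior rather than code from this PR:

```python
# Sketch: sharing a device tensor across processes via an IPC handle.
import pickle

import torch
from torch.multiprocessing.reductions import reduce_tensor

buffer = torch.empty(1024, dtype=torch.uint8, device="cuda")
rebuild_fn, args = reduce_tensor(buffer)    # IPC handle, not a data copy
payload = pickle.dumps((rebuild_fn, args))  # e.g. shipped to a peer over ZMQ

# In the receiving process (same machine, same device):
rebuild_fn, args = pickle.loads(payload)
shared = rebuild_fn(*args)                  # maps the same device memory
```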
@@ -1245,8 +1238,8 @@ def _update_per_bucket(
                 continue
             bucket = _buckets[i]
             alloc, reserved = (
-                torch.cuda.memory_allocated() / 1024 / 1024,
-                torch.cuda.memory_reserved() / 1024 / 1024,
+                self.device_manager.device_module.memory_allocated() / 1024 / 1024,
+                self.device_manager.device_module.memory_reserved() / 1024 / 1024,
             )
             self._logger_rank0(
                 f"[rank{self._rank}] begin to update bucket {gidx + 1}/{len(buckets)} receiver_rank {receiver_rank} in checkpoint {checkpoint_name}, bucket_size: {bucket.size / 1024 / 1024:.2f} MiB, length: {len(bucket.items)}. "
@@ -1276,7 +1269,7 @@ def _update_per_bucket(
         if ranks and h2d_buffer is not None:
             self._p2p_store.unregister_named_tensors([h2d_buffer_name])
 
-        torch.cuda.empty_cache()
+        self.device_manager.device_module.empty_cache()
 
 
 def _init_api(ps: ParameterServer) -> Any:
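
End to end, callers of `ParameterServer` do not change: the constructor now instantiates `DeviceManager()` itself, and everything downstream keys off it. A hedged usage sketch, with constructor and method names as they appear in this file and illustrative environment values:

```python
# Sketch: the same entry point now works on GPU or NPU hosts.
import os

os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")

ps = ParameterServer(auto_pg=True)   # DeviceManager() inside picks cuda/npu
ps.gather_metas("my-checkpoint")     # host_ip/device_uuid are device-agnostic
```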