44# See LICENSE for license information.
55###############################################################################
66
7+ import ctypes
8+
79import torch
810from megatron .core .dist_checkpointing .strategies .filesystem_async import (
911 FileSystemWriterAsync ,
1012)
1113
1214from primus .modules .module_utils import log_rank_0 , warning_rank_0
13-
14-
15+
1516class PrimusFileSystemWriterAsync (FileSystemWriterAsync ):
1617 def __init__ (self , * args , ** kwargs ):
1718 super ().__init__ (* args , ** kwargs )
@@ -21,19 +22,41 @@ def preload_tensors(*args, **kwargs):
2122 # (limou)
2223 # change argument non_blocking to False on HIP platform
2324 # the tensors will be stored in pinned memory if non_blocking=True
24- # currently on the ROCm platform
25+ # on the ROCm platform (hip_runtime_version < 7.1)
2526 # forking a subprocess afterward with pinned_memory=True will trigger segmentation fault
2627 if torch .version .hip :
27- log_rank_0 ("HIP env detected, change argument non_blocking in FileSystemWriterAsync to False" )
28- if "non_blocking" in kwargs :
29- kwargs ["non_blocking" ] = False
30- elif len (args ) > 0 and type (args [- 1 ]) == type (True ):
31- # TODO (limou)
32- # non_blocking may NOT always be the last argument in the future
33- args = args [:- 1 ] + (False ,)
34- else :
35- warning_rank_0 ("found argument non_blocking failed" )
28+ major , minor = PrimusFileSystemWriterAsync .get_hip_runtime_version ()
29+ log_rank_0 (f"hip runtime version : { major } .{ minor } " )
30+ if major < 7 or (major == 7 and minor < 1 ):
31+ log_rank_0 ("HIP env detected, change argument non_blocking in FileSystemWriterAsync to False" )
32+ if "non_blocking" in kwargs :
33+ kwargs ["non_blocking" ] = False
34+ elif len (args ) > 0 and type (args [- 1 ]) == type (True ):
35+ # TODO (limou)
36+ # non_blocking may NOT always be the last argument in the future
37+ args = args [:- 1 ] + (False ,)
38+ else :
39+ warning_rank_0 ("found argument non_blocking failed" )
3640
3741 return super (PrimusFileSystemWriterAsync , PrimusFileSystemWriterAsync ).preload_tensors (
3842 * args , ** kwargs
3943 )
44+
45+ # unlike torch.version.hip
46+ # hipRuntimeGetVersion() can return the HIP runtime version instead of build-time
47+ @staticmethod
48+ def get_hip_runtime_version ():
49+ try :
50+ libhip = ctypes .CDLL ("libamdhip64.so" )
51+ hipRuntimeGetVersion = libhip .hipRuntimeGetVersion
52+ hipRuntimeGetVersion .argtypes = [ctypes .POINTER (ctypes .c_int )]
53+ hipRuntimeGetVersion .restype = ctypes .c_int
54+ version = ctypes .c_int ()
55+ error_code = hipRuntimeGetVersion (ctypes .byref (version ))
56+ if error_code != 0 :
57+ return (- 1 , - 1 )
58+ # (major_version, minor_version)
59+ return (version .value // 10000000 , (version .value // 100000 )% 100 )
60+ except Exception as e :
61+ print (e )
62+ return (- 1 , - 1 )
0 commit comments