Skip to content

Commit 6a98e33

Browse files
committed
adapt the latest Megatron‑LM to support the torch_dist checkpoint format, and update the async checkpoint patch logic
1 parent 97394b3 commit 6a98e33

File tree

2 files changed

+37
-12
lines changed

2 files changed

+37
-12
lines changed

primus/backends/megatron/core/dist_checkpointing/strategies/filesystem_async.py

Lines changed: 35 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,15 @@
44
# See LICENSE for license information.
55
###############################################################################
66

7+
import ctypes
8+
79
import torch
810
from megatron.core.dist_checkpointing.strategies.filesystem_async import (
911
FileSystemWriterAsync,
1012
)
1113

1214
from primus.modules.module_utils import log_rank_0, warning_rank_0
13-
14-
15+
1516
class PrimusFileSystemWriterAsync(FileSystemWriterAsync):
1617
def __init__(self, *args, **kwargs):
1718
super().__init__(*args, **kwargs)
@@ -21,19 +22,41 @@ def preload_tensors(*args, **kwargs):
2122
# (limou)
2223
# change argument non_blocking to False on HIP platform
2324
# the tensors will be stored in pinned memory if non_blocking=True
24-
# currently on the ROCm platform
25+
# on the ROCm platform (hip_runtime_version < 7.1)
2526
# forking a subprocess afterward with pinned_memory=True will trigger a segmentation fault
2627
if torch.version.hip:
27-
log_rank_0("HIP env detected, change argument non_blocking in FileSystemWriterAsync to False")
28-
if "non_blocking" in kwargs:
29-
kwargs["non_blocking"] = False
30-
elif len(args) > 0 and type(args[-1]) == type(True):
31-
# TODO (limou)
32-
# non_blocking may NOT always be the last argument in the future
33-
args = args[:-1] + (False,)
34-
else:
35-
warning_rank_0("found argument non_blocking failed")
28+
major, minor = PrimusFileSystemWriterAsync.get_hip_runtime_version()
29+
log_rank_0(f"hip runtime version : {major}.{minor}")
30+
if major < 7 or (major == 7 and minor < 1):
31+
log_rank_0("HIP env detected, change argument non_blocking in FileSystemWriterAsync to False")
32+
if "non_blocking" in kwargs:
33+
kwargs["non_blocking"] = False
34+
elif len(args) > 0 and type(args[-1]) == type(True):
35+
# TODO (limou)
36+
# non_blocking may NOT always be the last argument in the future
37+
args = args[:-1] + (False,)
38+
else:
39+
warning_rank_0("found argument non_blocking failed")
3640

3741
return super(PrimusFileSystemWriterAsync, PrimusFileSystemWriterAsync).preload_tensors(
3842
*args, **kwargs
3943
)
44+
45+
# unlike torch.version.hip
46+
# hipRuntimeGetVersion() returns the version of the HIP runtime in use, rather than the build-time version
47+
@staticmethod
def get_hip_runtime_version(library_name="libamdhip64.so"):
    """Return the HIP runtime version as a ``(major, minor)`` tuple.

    The version is obtained by loading the HIP shared library with ``ctypes``
    and calling ``hipRuntimeGetVersion()``. It therefore describes the HIP
    runtime in use, whereas ``torch.version.hip`` describes the build-time
    version.

    Args:
        library_name: Name or path of the HIP runtime shared library to load.
            Defaults to ``"libamdhip64.so"``.

    Returns:
        ``(major, minor)`` as ints. ``(-1, -1)`` is returned when the library
        cannot be loaded, when it does not export ``hipRuntimeGetVersion``,
        or when the call reports a non-zero error code.
    """
    # The runtime reports the version packed into one integer: the major part
    # is the value divided by 10**7 and the minor part is the next two
    # decimal digits (value // 10**5, modulo 100).
    major_unit = 10_000_000
    minor_unit = 100_000
    unknown = (-1, -1)

    # Only loading the library and resolving the symbol can raise, so the
    # try block is limited to those two steps.
    try:
        libhip = ctypes.CDLL(library_name)
        hip_runtime_get_version = libhip.hipRuntimeGetVersion
    except (OSError, AttributeError) as exc:
        print(f"failed to load hipRuntimeGetVersion from {library_name}: {exc}")
        return unknown

    hip_runtime_get_version.argtypes = [ctypes.POINTER(ctypes.c_int)]
    hip_runtime_get_version.restype = ctypes.c_int

    version = ctypes.c_int()
    error_code = hip_runtime_get_version(ctypes.byref(version))
    if error_code != 0:
        print(f"hipRuntimeGetVersion failed with error code {error_code}")
        return unknown

    # Equivalent to (value // 10**7, (value // 10**5) % 100).
    major, remainder = divmod(version.value, major_unit)
    return (major, remainder // minor_unit)

primus/configs/modules/megatron/trainer_base.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,8 @@ non_persistent_ckpt_type: null # 'global', 'local', 'in_memory', null
124124
non_persistent_global_ckpt_dir: null # str
125125
non_persistent_local_ckpt_dir: null # str
126126
non_persistent_local_ckpt_algo: "fully_parallel" # 'fully_parallel', 'atomic'
127+
dist_ckpt_save_pre_mcore_014: null
128+
dist_ckpt_optim_fully_reshardable: null
127129

128130
pretrained_checkpoint: null
129131
ckpt_step: null

0 commit comments

Comments
 (0)