Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# See LICENSE for license information.
###############################################################################

import ctypes

import torch
from megatron.core.dist_checkpointing.strategies.filesystem_async import (
FileSystemWriterAsync,
Expand All @@ -21,19 +23,41 @@ def preload_tensors(*args, **kwargs):
# (limou)
# change argument non_blocking to False on HIP platform
# the tensors will be stored in pinned memory if non_blocking=True
# currently on the ROCm platform
# on the ROCm platform (hip_runtime_version < 7.1)
# forking a subprocess afterward with pinned_memory=True will trigger a segmentation fault
if torch.version.hip:
log_rank_0("HIP env detected, change argument non_blocking in FileSystemWriterAsync to False")
if "non_blocking" in kwargs:
kwargs["non_blocking"] = False
elif len(args) > 0 and type(args[-1]) == type(True):
# TODO (limou)
# non_blocking may NOT always be the last argument in the future
args = args[:-1] + (False,)
else:
warning_rank_0("found argument non_blocking failed")
major, minor = PrimusFileSystemWriterAsync.get_hip_runtime_version()
log_rank_0(f"hip runtime version : {major}.{minor}")
if major < 7 or (major == 7 and minor < 1):
log_rank_0("HIP env detected, change argument non_blocking in FileSystemWriterAsync to False")
if "non_blocking" in kwargs:
kwargs["non_blocking"] = False
elif len(args) > 0 and type(args[-1]) == type(True):
# TODO (limou)
# non_blocking may NOT always be the last argument in the future
args = args[:-1] + (False,)
else:
warning_rank_0("found argument non_blocking failed")

return super(PrimusFileSystemWriterAsync, PrimusFileSystemWriterAsync).preload_tensors(
*args, **kwargs
)

# Unlike torch.version.hip (the HIP version PyTorch was built against),
# hipRuntimeGetVersion() reports the version of the HIP runtime that is loaded.
@staticmethod
def get_hip_runtime_version():
    """Return the HIP runtime version as a ``(major, minor)`` tuple.

    Loads ``libamdhip64.so`` through ctypes and calls ``hipRuntimeGetVersion``.
    This is a best-effort probe and never raises.

    Returns:
        tuple[int, int]: ``(major, minor)`` decoded from the integer reported
        by the runtime, or ``(-1, -1)`` when the library cannot be loaded, the
        symbol is missing, or the call returns a non-zero error code. The
        sentinel compares lower than any real version.
    """
    # Decoding assumes the HIP_VERSION layout:
    #   major * 10_000_000 + minor * 100_000 + patch
    major_scale = 10_000_000
    minor_scale = 100_000
    unknown_version = (-1, -1)

    try:
        libhip = ctypes.CDLL("libamdhip64.so")
        hip_runtime_get_version = libhip.hipRuntimeGetVersion
        hip_runtime_get_version.argtypes = [ctypes.POINTER(ctypes.c_int)]
        hip_runtime_get_version.restype = ctypes.c_int

        version = ctypes.c_int()
        error_code = hip_runtime_get_version(ctypes.byref(version))
        if error_code != 0:
            # Previously this path was silent; surface the failure so an
            # unexpected fallback can be diagnosed from the logs.
            warning_rank_0(f"hipRuntimeGetVersion failed with error code {error_code}")
            return unknown_version

        # (major_version, minor_version)
        return (version.value // major_scale, (version.value // minor_scale) % 100)
    except Exception as e:
        # Deliberately broad: a missing library (OSError) or missing symbol
        # (AttributeError) must not abort the caller. Report through the
        # rank-aware logger used elsewhere in this file instead of print().
        warning_rank_0(f"failed to query HIP runtime version: {e}")
        return unknown_version
2 changes: 2 additions & 0 deletions primus/configs/modules/megatron/trainer_base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ non_persistent_ckpt_type: null # 'global', 'local', 'in_memory', null
non_persistent_global_ckpt_dir: null # str
non_persistent_local_ckpt_dir: null # str
non_persistent_local_ckpt_algo: "fully_parallel" # 'fully_parallel', 'atomic'
dist_ckpt_save_pre_mcore_014: null
dist_ckpt_optim_fully_reshardable: null

pretrained_checkpoint: null
ckpt_step: null
Expand Down