Skip to content

Commit a1c3f09

Browse files
committed
support mthreads-ml-py
1 parent 31792dd commit a1c3f09

File tree

3 files changed

+44
-9
lines changed

3 files changed

+44
-9
lines changed

nvitop/api/libnvml.py

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,20 @@
3737

3838
# Python Bindings for the NVIDIA Management Library (NVML)
3939
# https://pypi.org/project/nvidia-ml-py
40-
import pynvml as _pynvml
41-
from pynvml import * # noqa: F403 # pylint: disable=wildcard-import,unused-wildcard-import
42-
from pynvml import nvmlDeviceGetPciInfo # appease mypy # noqa: F401 # pylint: disable=unused-import
4340

44-
from nvitop.api.utils import NA, UINT_MAX, ULONGLONG_MAX, NaType
41+
from nvitop.api.utils import NA, UINT_MAX, ULONGLONG_MAX, NaType, is_musa
4542
from nvitop.api.utils import colored as __colored
4643

44+
_is_musa = is_musa()
45+
46+
if not _is_musa:
47+
import pynvml as _pynvml
48+
from pynvml import * # noqa: F403 # pylint: disable=wildcard-import,unused-wildcard-import
49+
from pynvml import nvmlDeviceGetPciInfo # appease mypy # noqa: F401 # pylint: disable=unused-import
50+
else:
51+
import pymtml as _pynvml
52+
from pymtml import * # noqa: F403 # pylint: disable=wildcard-import,unused-wildcard-import
53+
from pymtml import nvmlDeviceGetPciInfo # appease mypy # noqa: F401 # pylint: disable=unused-import
4754

4855
if _TYPE_CHECKING:
4956
from collections.abc import Callable as _Callable
@@ -540,7 +547,10 @@ def nvmlCheckReturn(retval: _Any, types: type | tuple[type, ...] | None = None,
540547
# Patch function `nvmlDeviceGet{Compute,Graphics,MPSCompute}RunningProcesses`
541548
if not _pynvml_installation_corrupted:
542549
# pylint: disable-next=ungrouped-imports
543-
from pynvml import _nvmlGetFunctionPointer, _PrintableStructure, nvmlStructToFriendlyObject
550+
if not _is_musa:
551+
from pynvml import _nvmlGetFunctionPointer, _PrintableStructure, nvmlStructToFriendlyObject
552+
else:
553+
from pymtml import _nvmlGetFunctionPointer, _PrintableStructure, nvmlStructToFriendlyObject
544554

545555
def _nvmlLookupFunctionPointer(symbol: str) -> _Any | None:
546556
try:
@@ -671,7 +681,11 @@ def __nvml_device_get_running_processes(
671681

672682
# First call to get the size
673683
c_count = _ctypes.c_uint(0)
674-
fn = _nvmlGetFunctionPointer(f'{func}{version_suffix}')
684+
try:
685+
fn = _nvmlGetFunctionPointer(f'{func}{version_suffix}')
686+
except Exception:
687+
return []
688+
675689
ret = fn(handle, _ctypes.byref(c_count), None)
676690

677691
if ret == NVML_SUCCESS:
@@ -876,7 +890,11 @@ def nvmlDeviceGetMemoryInfo( # pylint: disable=function-redefined
876890
'function `nvmlDeviceGetMemoryInfo`.',
877891
)
878892

879-
fn = _nvmlGetFunctionPointer(f'nvmlDeviceGetMemoryInfo{version_suffix}')
893+
try:
894+
fn = _nvmlGetFunctionPointer(f'nvmlDeviceGetMemoryInfo{version_suffix}')
895+
except Exception:
896+
return NA
897+
880898
ret = fn(handle, _ctypes.byref(c_memory))
881899
if ret != NVML_SUCCESS:
882900
raise NVMLError(ret)
@@ -952,15 +970,21 @@ def nvmlDeviceGetTemperature( # pylint: disable=function-redefined
952970
c_temp_v1.version = nvmlTemperature_v1
953971
# pylint: disable-next=attribute-defined-outside-init
954972
c_temp_v1.sensorType = _ctypes.c_uint(sensor)
955-
fn = _nvmlGetFunctionPointer('nvmlDeviceGetTemperatureV')
973+
try:
974+
fn = _nvmlGetFunctionPointer(f'nvmlDeviceGetTemperatureV{version_suffix}')
975+
except Exception:
976+
return NA
956977
ret = fn(handle, _ctypes.byref(c_temp_v1))
957978
if ret != NVML_SUCCESS:
958979
raise NVMLError(ret)
959980
return int(c_temp_v1.temperature)
960981

961982
if version_suffix == '':
962983
c_temp = _ctypes.c_uint(0)
963-
fn = _nvmlGetFunctionPointer('nvmlDeviceGetTemperature')
984+
try:
985+
fn = _nvmlGetFunctionPointer('nvmlDeviceGetTemperature')
986+
except Exception:
987+
return NA
964988
ret = fn(handle, _ctypes.c_uint(sensor), _ctypes.byref(c_temp))
965989
if ret != NVML_SUCCESS:
966990
raise NVMLError(ret)

nvitop/api/utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -792,6 +792,16 @@ def cache_deactivate(self: object) -> None:
792792
wrapped.cache_deactivate = cache_deactivate # type: ignore[attr-defined]
793793
return wrapped # type: ignore[return-value]
794794

795+
def is_musa() -> bool:
    """Check whether the MUSA management library bindings (``pymtml``) are usable.

    The check imports ``pymtml`` and performs a single ``nvmlInit()`` /
    ``nvmlShutdown()`` round trip through it. Note that this is a property of the
    installed packages and the driver stack, not of the Python interpreter itself.

    Returns:
        :data:`True` if the import and both calls complete without raising;
        :data:`False` if any of these steps raises an exception.
    """
    try:
        # Imported lazily so that a missing ``pymtml`` is reported as "not MUSA"
        # rather than breaking the import of this module.
        import pymtml

        pymtml.nvmlInit()
        pymtml.nvmlShutdown()
    except Exception:  # noqa: BLE001 # pylint: disable=broad-except
        # Deliberately broad: this is a best-effort probe, so a failed import and
        # any error raised by the init/shutdown calls are all treated the same
        # way, namely "MUSA is not available".
        return False

    return True
795805

796806
if __name__ == '__main__':
797807
import doctest

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ classifiers = [
4848
dependencies = [
4949
# Sync with nvitop/version.py and requirements.txt
5050
"nvidia-ml-py >= 11.450.51, < 13.591.0a0",
51+
"mthreads-ml-py >= 2.2.1",
5152
"psutil >= 5.6.6",
5253
"colorama >= 0.4.0; platform_system == 'Windows'",
5354
"windows-curses >= 2.2.0; platform_system == 'Windows'",

0 commit comments

Comments
 (0)