diff --git a/nvitop/api/libnvml.py b/nvitop/api/libnvml.py index 1eddb2e4..1952a2f6 100644 --- a/nvitop/api/libnvml.py +++ b/nvitop/api/libnvml.py @@ -37,13 +37,20 @@ # Python Bindings for the NVIDIA Management Library (NVML) # https://pypi.org/project/nvidia-ml-py -import pynvml as _pynvml -from pynvml import * # noqa: F403 # pylint: disable=wildcard-import,unused-wildcard-import -from pynvml import nvmlDeviceGetPciInfo # appease mypy # noqa: F401 # pylint: disable=unused-import -from nvitop.api.utils import NA, UINT_MAX, ULONGLONG_MAX, NaType +from nvitop.api.utils import NA, UINT_MAX, ULONGLONG_MAX, NaType, is_musa from nvitop.api.utils import colored as __colored +_is_musa = is_musa() + +if not _is_musa: + import pynvml as _pynvml + from pynvml import * # noqa: F403 # pylint: disable=wildcard-import,unused-wildcard-import + from pynvml import nvmlDeviceGetPciInfo # appease mypy # noqa: F401 # pylint: disable=unused-import +else: + import pymtml as _pynvml + from pymtml import * # noqa: F403 # pylint: disable=wildcard-import,unused-wildcard-import + from pymtml import nvmlDeviceGetPciInfo # appease mypy # noqa: F401 # pylint: disable=unused-import if _TYPE_CHECKING: from collections.abc import Callable as _Callable @@ -540,7 +547,10 @@ def nvmlCheckReturn(retval: _Any, types: type | tuple[type, ...] | None = None, # Patch function `nvmlDeviceGet{Compute,Graphics,MPSCompute}RunningProcesses` if not _pynvml_installation_corrupted: # pylint: disable-next=ungrouped-imports - from pynvml import _nvmlGetFunctionPointer, _PrintableStructure, nvmlStructToFriendlyObject + if not _is_musa: + from pynvml import _nvmlGetFunctionPointer, _PrintableStructure, nvmlStructToFriendlyObject + else: + from pymtml import _nvmlGetFunctionPointer, _PrintableStructure, nvmlStructToFriendlyObject def _nvmlLookupFunctionPointer(symbol: str) -> _Any | None: try: @@ -671,7 +681,11 @@ def __nvml_device_get_running_processes( # First call to get the size c_count = _ctypes.c_uint(0) - fn = _nvmlGetFunctionPointer(f'{func}{version_suffix}') + try: + fn = _nvmlGetFunctionPointer(f'{func}{version_suffix}') + except Exception: + return [] + ret = fn(handle, _ctypes.byref(c_count), None) if ret == NVML_SUCCESS: @@ -876,7 +890,11 @@ def nvmlDeviceGetMemoryInfo( # pylint: disable=function-redefined 'function `nvmlDeviceGetMemoryInfo`.', ) - fn = _nvmlGetFunctionPointer(f'nvmlDeviceGetMemoryInfo{version_suffix}') + try: + fn = _nvmlGetFunctionPointer(f'nvmlDeviceGetMemoryInfo{version_suffix}') + except Exception: + return NA + ret = fn(handle, _ctypes.byref(c_memory)) if ret != NVML_SUCCESS: raise NVMLError(ret) @@ -952,7 +970,10 @@ def nvmlDeviceGetTemperature( # pylint: disable=function-redefined c_temp_v1.version = nvmlTemperature_v1 # pylint: disable-next=attribute-defined-outside-init c_temp_v1.sensorType = _ctypes.c_uint(sensor) - fn = _nvmlGetFunctionPointer('nvmlDeviceGetTemperatureV') + try: + fn = _nvmlGetFunctionPointer(f'nvmlDeviceGetTemperatureV{version_suffix}') + except Exception: + return NA ret = fn(handle, _ctypes.byref(c_temp_v1)) if ret != NVML_SUCCESS: raise NVMLError(ret) @@ -960,7 +981,10 @@ def nvmlDeviceGetTemperature( # pylint: disable=function-redefined if version_suffix == '': c_temp = _ctypes.c_uint(0) - fn = _nvmlGetFunctionPointer('nvmlDeviceGetTemperature') + try: + fn = _nvmlGetFunctionPointer('nvmlDeviceGetTemperature') + except Exception: + return NA ret = fn(handle, _ctypes.c_uint(sensor), _ctypes.byref(c_temp)) if ret != NVML_SUCCESS: raise NVMLError(ret) diff --git a/nvitop/api/utils.py b/nvitop/api/utils.py index b876f802..3c3b59d0 100644 --- a/nvitop/api/utils.py +++ b/nvitop/api/utils.py @@ -792,6 +792,16 @@ def cache_deactivate(self: object) -> None: wrapped.cache_deactivate = cache_deactivate # type: ignore[attr-defined] return wrapped # type: ignore[return-value] +def is_musa() -> bool: + """Check if the current Python interpreter is Musa.""" + try: + import pymtml # noqa: F401 + pymtml.nvmlInit() + pymtml.nvmlShutdown() + except Exception: + return False + + return True if __name__ == '__main__': import doctest diff --git a/pyproject.toml b/pyproject.toml index 82c90f41..03f394e9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,7 @@ classifiers = [ dependencies = [ # Sync with nvitop/version.py and requirements.txt "nvidia-ml-py >= 11.450.51, < 13.591.0a0", + "mthreads-ml-py >= 2.2.1", "psutil >= 5.6.6", "colorama >= 0.4.0; platform_system == 'Windows'", "windows-curses >= 2.2.0; platform_system == 'Windows'",