|
37 | 37 |
|
38 | 38 | # Python Bindings for the NVIDIA Management Library (NVML) |
39 | 39 | # https://pypi.org/project/nvidia-ml-py |
40 | | -import pynvml as _pynvml |
41 | | -from pynvml import * # noqa: F403 # pylint: disable=wildcard-import,unused-wildcard-import |
42 | | -from pynvml import nvmlDeviceGetPciInfo # appease mypy # noqa: F401 # pylint: disable=unused-import |
43 | 40 |
|
44 | | -from nvitop.api.utils import NA, UINT_MAX, ULONGLONG_MAX, NaType |
| 41 | +from nvitop.api.utils import NA, UINT_MAX, ULONGLONG_MAX, NaType, is_musa |
45 | 42 | from nvitop.api.utils import colored as __colored |
46 | 43 |
|
| 44 | +_is_musa = is_musa() |
| 45 | + |
| 46 | +if not _is_musa: |
| 47 | + import pynvml as _pynvml |
| 48 | + from pynvml import * # noqa: F403 # pylint: disable=wildcard-import,unused-wildcard-import |
| 49 | + from pynvml import nvmlDeviceGetPciInfo # appease mypy # noqa: F401 # pylint: disable=unused-import |
| 50 | +else: |
| 51 | + import pymtml as _pynvml |
| 52 | + from pymtml import * # noqa: F403 # pylint: disable=wildcard-import,unused-wildcard-import |
| 53 | + from pymtml import nvmlDeviceGetPciInfo # appease mypy # noqa: F401 # pylint: disable=unused-import |
47 | 54 |
|
48 | 55 | if _TYPE_CHECKING: |
49 | 56 | from collections.abc import Callable as _Callable |
@@ -540,7 +547,10 @@ def nvmlCheckReturn(retval: _Any, types: type | tuple[type, ...] | None = None, |
540 | 547 | # Patch function `nvmlDeviceGet{Compute,Graphics,MPSCompute}RunningProcesses` |
541 | 548 | if not _pynvml_installation_corrupted: |
542 | 549 | # pylint: disable-next=ungrouped-imports |
543 | | - from pynvml import _nvmlGetFunctionPointer, _PrintableStructure, nvmlStructToFriendlyObject |
| 550 | + if not _is_musa: |
| 551 | + from pynvml import _nvmlGetFunctionPointer, _PrintableStructure, nvmlStructToFriendlyObject |
| 552 | + else: |
| 553 | + from pymtml import _nvmlGetFunctionPointer, _PrintableStructure, nvmlStructToFriendlyObject |
544 | 554 |
|
545 | 555 | def _nvmlLookupFunctionPointer(symbol: str) -> _Any | None: |
546 | 556 | try: |
@@ -671,7 +681,11 @@ def __nvml_device_get_running_processes( |
671 | 681 |
|
672 | 682 | # First call to get the size |
673 | 683 | c_count = _ctypes.c_uint(0) |
674 | | - fn = _nvmlGetFunctionPointer(f'{func}{version_suffix}') |
| 684 | + try: |
| 685 | + fn = _nvmlGetFunctionPointer(f'{func}{version_suffix}') |
| 686 | + except Exception: |
| 687 | + return [] |
| 688 | + |
675 | 689 | ret = fn(handle, _ctypes.byref(c_count), None) |
676 | 690 |
|
677 | 691 | if ret == NVML_SUCCESS: |
@@ -876,7 +890,11 @@ def nvmlDeviceGetMemoryInfo( # pylint: disable=function-redefined |
876 | 890 | 'function `nvmlDeviceGetMemoryInfo`.', |
877 | 891 | ) |
878 | 892 |
|
879 | | - fn = _nvmlGetFunctionPointer(f'nvmlDeviceGetMemoryInfo{version_suffix}') |
| 893 | + try: |
| 894 | + fn = _nvmlGetFunctionPointer(f'nvmlDeviceGetMemoryInfo{version_suffix}') |
| 895 | + except Exception: |
| 896 | + return NA |
| 897 | + |
880 | 898 | ret = fn(handle, _ctypes.byref(c_memory)) |
881 | 899 | if ret != NVML_SUCCESS: |
882 | 900 | raise NVMLError(ret) |
@@ -952,15 +970,21 @@ def nvmlDeviceGetTemperature( # pylint: disable=function-redefined |
952 | 970 | c_temp_v1.version = nvmlTemperature_v1 |
953 | 971 | # pylint: disable-next=attribute-defined-outside-init |
954 | 972 | c_temp_v1.sensorType = _ctypes.c_uint(sensor) |
955 | | - fn = _nvmlGetFunctionPointer('nvmlDeviceGetTemperatureV') |
| 973 | + try: |
| 974 | + fn = _nvmlGetFunctionPointer(f'nvmlDeviceGetTemperatureV{version_suffix}') |
| 975 | + except Exception: |
| 976 | + return NA |
956 | 977 | ret = fn(handle, _ctypes.byref(c_temp_v1)) |
957 | 978 | if ret != NVML_SUCCESS: |
958 | 979 | raise NVMLError(ret) |
959 | 980 | return int(c_temp_v1.temperature) |
960 | 981 |
|
961 | 982 | if version_suffix == '': |
962 | 983 | c_temp = _ctypes.c_uint(0) |
963 | | - fn = _nvmlGetFunctionPointer('nvmlDeviceGetTemperature') |
| 984 | + try: |
| 985 | + fn = _nvmlGetFunctionPointer('nvmlDeviceGetTemperature') |
| 986 | + except Exception: |
| 987 | + return NA |
964 | 988 | ret = fn(handle, _ctypes.c_uint(sensor), _ctypes.byref(c_temp)) |
965 | 989 | if ret != NVML_SUCCESS: |
966 | 990 | raise NVMLError(ret) |
|
0 commit comments