Skip to content

Commit d04fdf8

Browse files
committed
pynvml compatibility
Seems like some version pynvml switched from returning bytes to str (I can pin down the exact version if needed). This updates our test and type annotations to accomodate either.
1 parent 94222c0 commit d04fdf8

File tree

2 files changed

+9
-7
lines changed

2 files changed

+9
-7
lines changed

distributed/diagnostics/nvml.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ class NVMLState(IntEnum):
3232

3333

3434
class CudaDeviceInfo(NamedTuple):
35-
uuid: bytes | None = None
35+
# Older versions of pynvml returned bytes, newer versions return str.
36+
uuid: str | bytes | None = None
3637
device_index: int | None = None
3738
mig_index: int | None = None
3839

@@ -278,13 +279,13 @@ def get_device_index_and_uuid(device):
278279
Examples
279280
--------
280281
>>> get_device_index_and_uuid(0) # doctest: +SKIP
281-
{'device-index': 0, 'uuid': b'GPU-e1006a74-5836-264f-5c26-53d19d212dfe'}
282+
{'device-index': 0, 'uuid': 'GPU-e1006a74-5836-264f-5c26-53d19d212dfe'}
282283
283284
>>> get_device_index_and_uuid('GPU-e1006a74-5836-264f-5c26-53d19d212dfe') # doctest: +SKIP
284-
{'device-index': 0, 'uuid': b'GPU-e1006a74-5836-264f-5c26-53d19d212dfe'}
285+
{'device-index': 0, 'uuid': 'GPU-e1006a74-5836-264f-5c26-53d19d212dfe'}
285286
286287
>>> get_device_index_and_uuid('MIG-7feb6df5-eccf-5faa-ab00-9a441867e237') # doctest: +SKIP
287-
{'device-index': 0, 'uuid': b'MIG-7feb6df5-eccf-5faa-ab00-9a441867e237'}
288+
{'device-index': 0, 'uuid': 'MIG-7feb6df5-eccf-5faa-ab00-9a441867e237'}
288289
"""
289290
init_once()
290291
try:

distributed/diagnostics/tests/test_nvml.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
pynvml = pytest.importorskip("pynvml")
1212

1313
import dask
14+
from dask.utils import ensure_unicode
1415

1516
from distributed.diagnostics import nvml
1617
from distributed.utils_test import gen_cluster
@@ -66,7 +67,7 @@ def run_has_cuda_context(queue):
6667
assert (
6768
ctx.has_context
6869
and ctx.device_info.device_index == 0
69-
and isinstance(ctx.device_info.uuid, bytes)
70+
and isinstance(ctx.device_info.uuid, str)
7071
)
7172

7273
queue.put(None)
@@ -127,7 +128,7 @@ def test_visible_devices_uuid():
127128
assert info.uuid
128129

129130
with mock.patch.dict(
130-
os.environ, {"CUDA_VISIBLE_DEVICES": info.uuid.decode("utf-8")}
131+
os.environ, {"CUDA_VISIBLE_DEVICES": ensure_unicode(info.uuid)}
131132
):
132133
h = nvml._pynvml_handles()
133134
h_expected = pynvml.nvmlDeviceGetHandleByIndex(0)
@@ -147,7 +148,7 @@ def test_visible_devices_uuid_2(index):
147148
assert info.uuid
148149

149150
with mock.patch.dict(
150-
os.environ, {"CUDA_VISIBLE_DEVICES": info.uuid.decode("utf-8")}
151+
os.environ, {"CUDA_VISIBLE_DEVICES": ensure_unicode(info.uuid)}
151152
):
152153
h = nvml._pynvml_handles()
153154
h_expected = pynvml.nvmlDeviceGetHandleByIndex(index)

0 commit comments

Comments
 (0)