Skip to content

Commit cd2626f

Browse files
committed
Add function to get device handle, fix MIG handle
1 parent fca083f commit cd2626f

File tree

3 files changed

+63
-36
lines changed

3 files changed

+63
-36
lines changed

dask_cuda/tests/test_dask_cuda_worker.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION & AFFILIATES.
2+
# SPDX-License-Identifier: Apache-2.0
3+
14
from __future__ import absolute_import, division, print_function
25

36
import os
@@ -16,7 +19,7 @@
1619
get_cluster_configuration,
1720
get_device_total_memory,
1821
get_gpu_count_mig,
19-
get_gpu_uuid_from_index,
22+
get_gpu_uuid,
2023
get_n_gpus,
2124
wait_workers,
2225
)
@@ -409,7 +412,7 @@ def get_visible_devices():
409412

410413

411414
def test_cuda_visible_devices_uuid(loop): # noqa: F811
412-
gpu_uuid = get_gpu_uuid_from_index(0)
415+
gpu_uuid = get_gpu_uuid(0)
413416

414417
with patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": gpu_uuid}):
415418
with popen(["dask", "scheduler", "--port", "9359", "--no-dashboard"]):

dask_cuda/tests/test_local_cuda_cluster.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION & AFFILIATES.
2+
# SPDX-License-Identifier: Apache-2.0
3+
14
import asyncio
25
import os
36
import pkgutil
@@ -16,7 +19,7 @@
1619
get_cluster_configuration,
1720
get_device_total_memory,
1821
get_gpu_count_mig,
19-
get_gpu_uuid_from_index,
22+
get_gpu_uuid,
2023
print_cluster_config,
2124
)
2225
from dask_cuda.utils_test import MockWorker
@@ -419,7 +422,7 @@ def get_visible_devices():
419422

420423
@gen_test(timeout=20)
421424
async def test_gpu_uuid():
422-
gpu_uuid = get_gpu_uuid_from_index(0)
425+
gpu_uuid = get_gpu_uuid(0)
423426

424427
async with LocalCUDACluster(
425428
CUDA_VISIBLE_DEVICES=gpu_uuid,

dask_cuda/utils.py

+53-32
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION & AFFILIATES.
2+
# SPDX-License-Identifier: Apache-2.0
3+
14
import math
25
import operator
36
import os
@@ -86,6 +89,38 @@ def get_gpu_count():
8689
return pynvml.nvmlDeviceGetCount()
8790

8891

92+
def get_gpu_handle(device_index=0):
93+
"""Get GPU handle from device index or UUID.
94+
95+
Parameters
96+
----------
97+
device_index: int or str
98+
The index or UUID of the device from which to obtain the handle.
99+
100+
Examples
101+
--------
102+
>>> get_gpu_handle(device_index=0)
103+
104+
>>> get_gpu_handle(device_index="GPU-9fb42d6f-7d6b-368f-f79c-3c3e784c93f6")
105+
"""
106+
pynvml.nvmlInit()
107+
108+
try:
109+
if device_index and not str(device_index).isnumeric():
110+
# This means device_index is UUID.
111+
# This works for both MIG and non-MIG device UUIDs.
112+
handle = pynvml.nvmlDeviceGetHandleByUUID(str.encode(device_index))
113+
if pynvml.nvmlDeviceIsMigDeviceHandle(handle):
114+
# Additionally get parent device handle
115+
# if the device itself is a MIG instance
116+
handle = pynvml.nvmlDeviceGetDeviceHandleFromMigDeviceHandle(handle)
117+
else:
118+
handle = pynvml.nvmlDeviceGetHandleByIndex(device_index)
119+
return handle
120+
except pynvml.NVMLError:
121+
raise ValueError(f"Invalid device index: {device_index}")
122+
123+
89124
@toolz.memoize
90125
def get_gpu_count_mig(return_uuids=False):
91126
"""Return the number of MIG instances available
@@ -129,7 +164,7 @@ def get_cpu_affinity(device_index=None):
129164
Parameters
130165
----------
131166
device_index: int or str
132-
Index or UUID of the GPU device
167+
The index or UUID of the device from which to obtain the CPU affinity.
133168
134169
Examples
135170
--------
@@ -148,19 +183,8 @@ def get_cpu_affinity(device_index=None):
148183
40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59,
149184
60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79]
150185
"""
151-
pynvml.nvmlInit()
152-
153186
try:
154-
if device_index and not str(device_index).isnumeric():
155-
# This means device_index is UUID.
156-
# This works for both MIG and non-MIG device UUIDs.
157-
handle = pynvml.nvmlDeviceGetHandleByUUID(str.encode(device_index))
158-
if pynvml.nvmlDeviceIsMigDeviceHandle(handle):
159-
# Additionally get parent device handle
160-
# if the device itself is a MIG instance
161-
handle = pynvml.nvmlDeviceGetDeviceHandleFromMigDeviceHandle(handle)
162-
else:
163-
handle = pynvml.nvmlDeviceGetHandleByIndex(device_index)
187+
handle = get_gpu_handle(device_index)
164188
# Result is a list of 64-bit integers, thus ceil(get_cpu_count() / 64)
165189
affinity = pynvml.nvmlDeviceGetCpuAffinity(
166190
handle,
@@ -182,18 +206,15 @@ def get_n_gpus():
182206
return get_gpu_count()
183207

184208

185-
def get_device_total_memory(index=0):
186-
"""
187-
Return total memory of CUDA device with index or with device identifier UUID
188-
"""
189-
pynvml.nvmlInit()
209+
def get_device_total_memory(device_index=0):
210+
"""Return total memory of CUDA device with index or with device identifier UUID.
190211
191-
if index and not str(index).isnumeric():
192-
# This means index is UUID. This works for both MIG and non-MIG device UUIDs.
193-
handle = pynvml.nvmlDeviceGetHandleByUUID(str.encode(str(index)))
194-
else:
195-
# This is a device index
196-
handle = pynvml.nvmlDeviceGetHandleByIndex(index)
212+
Parameters
213+
----------
214+
device_index: int or str
215+
The index or UUID of the device from which to obtain the CPU affinity.
216+
"""
217+
handle = get_gpu_handle(device_index)
197218
return pynvml.nvmlDeviceGetMemoryInfo(handle).total
198219

199220

@@ -553,26 +574,26 @@ def _align(size, alignment_size):
553574
return _align(int(device_memory_limit), alignment_size)
554575

555576

556-
def get_gpu_uuid_from_index(device_index=0):
577+
def get_gpu_uuid(device_index=0):
557578
"""Get GPU UUID from CUDA device index.
558579
559580
Parameters
560581
----------
561582
device_index: int or str
562-
The index of the device from which to obtain the UUID. Default: 0.
583+
The index or UUID of the device from which to obtain the UUID.
563584
564585
Examples
565586
--------
566-
>>> get_gpu_uuid_from_index()
587+
>>> get_gpu_uuid()
567588
'GPU-9baca7f5-0f2f-01ac-6b05-8da14d6e9005'
568589
569-
>>> get_gpu_uuid_from_index(3)
590+
>>> get_gpu_uuid(3)
570591
'GPU-9fb42d6f-7d6b-368f-f79c-3c3e784c93f6'
571-
"""
572-
import pynvml
573592
574-
pynvml.nvmlInit()
575-
handle = pynvml.nvmlDeviceGetHandleByIndex(device_index)
593+
>>> get_gpu_uuid("GPU-9fb42d6f-7d6b-368f-f79c-3c3e784c93f6")
594+
'GPU-9fb42d6f-7d6b-368f-f79c-3c3e784c93f6'
595+
"""
596+
handle = get_gpu_handle(device_index)
576597
try:
577598
return pynvml.nvmlDeviceGetUUID(handle).decode("utf-8")
578599
except AttributeError:

0 commit comments

Comments
 (0)