Skip to content

Add function to get device handle and fix MIG handle #1476

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: branch-25.06
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions dask_cuda/tests/test_dask_cuda_worker.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0

from __future__ import absolute_import, division, print_function

import os
Expand All @@ -16,7 +19,7 @@
get_cluster_configuration,
get_device_total_memory,
get_gpu_count_mig,
get_gpu_uuid_from_index,
get_gpu_uuid,
get_n_gpus,
wait_workers,
)
Expand Down Expand Up @@ -409,7 +412,7 @@ def get_visible_devices():


def test_cuda_visible_devices_uuid(loop): # noqa: F811
gpu_uuid = get_gpu_uuid_from_index(0)
gpu_uuid = get_gpu_uuid(0)

with patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": gpu_uuid}):
with popen(["dask", "scheduler", "--port", "9359", "--no-dashboard"]):
Expand Down
7 changes: 5 additions & 2 deletions dask_cuda/tests/test_local_cuda_cluster.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0

import asyncio
import os
import pkgutil
Expand All @@ -16,7 +19,7 @@
get_cluster_configuration,
get_device_total_memory,
get_gpu_count_mig,
get_gpu_uuid_from_index,
get_gpu_uuid,
print_cluster_config,
)
from dask_cuda.utils_test import MockWorker
Expand Down Expand Up @@ -419,7 +422,7 @@ def get_visible_devices():

@gen_test(timeout=20)
async def test_gpu_uuid():
gpu_uuid = get_gpu_uuid_from_index(0)
gpu_uuid = get_gpu_uuid(0)

async with LocalCUDACluster(
CUDA_VISIBLE_DEVICES=gpu_uuid,
Expand Down
85 changes: 53 additions & 32 deletions dask_cuda/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0

import math
import operator
import os
Expand Down Expand Up @@ -86,6 +89,38 @@ def get_gpu_count():
return pynvml.nvmlDeviceGetCount()


def get_gpu_handle(device_index=0):
"""Get GPU handle from device index or UUID.

Parameters
----------
device_index: int or str
The index or UUID of the device from which to obtain the handle.

Examples
--------
>>> get_gpu_handle(device_index=0)

>>> get_gpu_handle(device_index="GPU-9fb42d6f-7d6b-368f-f79c-3c3e784c93f6")
"""
pynvml.nvmlInit()

try:
if device_index and not str(device_index).isnumeric():
# This means device_index is UUID.
# This works for both MIG and non-MIG device UUIDs.
handle = pynvml.nvmlDeviceGetHandleByUUID(str.encode(device_index))
if pynvml.nvmlDeviceIsMigDeviceHandle(handle):
# Additionally get parent device handle
# if the device itself is a MIG instance
handle = pynvml.nvmlDeviceGetDeviceHandleFromMigDeviceHandle(handle)
else:
handle = pynvml.nvmlDeviceGetHandleByIndex(device_index)
return handle
except pynvml.NVMLError:
raise ValueError(f"Invalid device index: {device_index}")


@toolz.memoize
def get_gpu_count_mig(return_uuids=False):
"""Return the number of MIG instances available
Expand Down Expand Up @@ -129,7 +164,7 @@ def get_cpu_affinity(device_index=None):
Parameters
----------
device_index: int or str
Index or UUID of the GPU device
The index or UUID of the device from which to obtain the CPU affinity.

Examples
--------
Expand All @@ -148,19 +183,8 @@ def get_cpu_affinity(device_index=None):
40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59,
60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79]
"""
pynvml.nvmlInit()

try:
if device_index and not str(device_index).isnumeric():
# This means device_index is UUID.
# This works for both MIG and non-MIG device UUIDs.
handle = pynvml.nvmlDeviceGetHandleByUUID(str.encode(device_index))
if pynvml.nvmlDeviceIsMigDeviceHandle(handle):
# Additionally get parent device handle
# if the device itself is a MIG instance
handle = pynvml.nvmlDeviceGetDeviceHandleFromMigDeviceHandle(handle)
else:
handle = pynvml.nvmlDeviceGetHandleByIndex(device_index)
handle = get_gpu_handle(device_index)
# Result is a list of 64-bit integers, thus ceil(get_cpu_count() / 64)
affinity = pynvml.nvmlDeviceGetCpuAffinity(
handle,
Expand All @@ -182,18 +206,15 @@ def get_n_gpus():
return get_gpu_count()


def get_device_total_memory(index=0):
"""
Return total memory of CUDA device with index or with device identifier UUID
"""
pynvml.nvmlInit()
def get_device_total_memory(device_index=0):
"""Return total memory of CUDA device with index or with device identifier UUID.

if index and not str(index).isnumeric():
# This means index is UUID. This works for both MIG and non-MIG device UUIDs.
handle = pynvml.nvmlDeviceGetHandleByUUID(str.encode(str(index)))
else:
# This is a device index
handle = pynvml.nvmlDeviceGetHandleByIndex(index)
Parameters
----------
device_index: int or str
The index or UUID of the device from which to obtain the CPU affinity.
"""
handle = get_gpu_handle(device_index)
return pynvml.nvmlDeviceGetMemoryInfo(handle).total


Expand Down Expand Up @@ -553,26 +574,26 @@ def _align(size, alignment_size):
return _align(int(device_memory_limit), alignment_size)


def get_gpu_uuid_from_index(device_index=0):
def get_gpu_uuid(device_index=0):
"""Get GPU UUID from CUDA device index.

Parameters
----------
device_index: int or str
The index of the device from which to obtain the UUID. Default: 0.
The index or UUID of the device from which to obtain the UUID.

Examples
--------
>>> get_gpu_uuid_from_index()
>>> get_gpu_uuid()
'GPU-9baca7f5-0f2f-01ac-6b05-8da14d6e9005'

>>> get_gpu_uuid_from_index(3)
>>> get_gpu_uuid(3)
'GPU-9fb42d6f-7d6b-368f-f79c-3c3e784c93f6'
"""
import pynvml

pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(device_index)
>>> get_gpu_uuid("GPU-9fb42d6f-7d6b-368f-f79c-3c3e784c93f6")
'GPU-9fb42d6f-7d6b-368f-f79c-3c3e784c93f6'
"""
handle = get_gpu_handle(device_index)
try:
return pynvml.nvmlDeviceGetUUID(handle).decode("utf-8")
except AttributeError:
Expand Down
Loading