Skip to content

Commit 23aefa3

Browse files
committed
cuda.core.system: Add conveniences to convert device types
1 parent ce333b6 commit 23aefa3

File tree

4 files changed

+74
-0
lines changed

4 files changed

+74
-0
lines changed

cuda_core/cuda/core/_device.pyx

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1034,6 +1034,17 @@ class Device:
10341034
total = system.get_num_devices()
10351035
return tuple(cls(device_id) for device_id in range(total))
10361036

1037+
def to_system_device(self) -> 'cuda.core.system.Device':
    """
    Return the :class:`cuda.core.system.Device` corresponding to this
    :class:`cuda.core.Device`.

    :class:`cuda.core.Device` is used for CUDA access, whereas
    :class:`cuda.core.system.Device` is used for NVIDIA Management Library
    (NVML) access. The two are mapped to one another by their UUID.

    Returns
    -------
    cuda.core.system.Device
        The system device constructed from this device's UUID.
    """
    # Function-scope import, aliased so it cannot be confused with the
    # CUDA ``Device`` class this method belongs to.
    from cuda.core.system import Device as NvmlDevice

    matching_device = NvmlDevice(uuid=self.uuid)
    return matching_device
1047+
10371048
@property
10381049
def device_id(self) -> int:
10391050
"""Return device ordinal."""

cuda_core/cuda/core/system/_device.pyx

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -722,6 +722,36 @@ cdef class Device:
722722
pci_bus_id = pci_bus_id.decode("ascii")
723723
self._handle = nvml.device_get_handle_by_pci_bus_id_v2(pci_bus_id)
724724

725+
def to_cuda_device(self) -> "cuda.core.Device":
    """
    Return the :class:`cuda.core.Device` corresponding to this
    :class:`cuda.core.system.Device`.

    :class:`cuda.core.system.Device` is used for NVIDIA Management Library
    (NVML) access, whereas :class:`cuda.core.Device` is used for CUDA
    access. The two are mapped to one another by their UUID.

    Returns
    -------
    cuda.core.Device
        The corresponding CUDA device.

    Raises
    ------
    RuntimeError
        If none of the CUDA devices has a UUID matching this device.
    """
    from cuda.core import Device as CudaDevice

    # NVML UUIDs carry a `GPU-` or `MIG-` prefix. Both are four characters,
    # so the prefix is dropped unconditionally. Arguably the match should
    # only be attempted for `GPU-` UUIDs, but a UUID that cannot be matched
    # ends in the explicit error below either way.
    bare_uuid = self.uuid[4:]

    # CUDA offers no lookup-by-UUID API, so scan every CUDA device and take
    # the first one whose UUID matches.
    candidates = (
        candidate
        for candidate in CudaDevice.get_all_devices()
        if candidate.uuid == bare_uuid
    )
    found = next(candidates, None)
    if found is None:
        raise RuntimeError("No corresponding CUDA device found for this NVML device.")
    return found
754+
725755
@classmethod
726756
def get_device_count(cls) -> int:
727757
"""

cuda_core/tests/system/test_system_device.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,23 @@ def test_device_count():
3333
assert system.Device.get_device_count() == system.get_num_devices()
3434

3535

36+
def test_to_cuda_device():
    """Every NVML device converts to a CUDA device with the same UUID and PCI bus ID."""
    from cuda.core import Device as CudaDevice

    # Length of the `GPU-` / `MIG-` prefix that NVML puts in front of UUIDs.
    uuid_prefix_len = 4
    # CUDA only returns a 2-byte PCI bus ID domain, whereas NVML returns a
    # 4-byte domain, i.e. four extra leading hex characters to drop.
    extra_domain_chars = 4

    for nvml_device in system.Device.get_all_devices():
        converted = nvml_device.to_cuda_device()

        assert isinstance(converted, CudaDevice)
        assert converted.uuid == nvml_device.uuid[uuid_prefix_len:]

        # Technically, this check will only work with PCI devices, but are
        # there non-PCI devices we need to support?
        assert converted.pci_bus_id == nvml_device.pci_info.bus_id[extra_domain_chars:]
51+
52+
3653
def test_device_architecture():
3754
for device in system.Device.get_all_devices():
3855
device_arch = device.architecture

cuda_core/tests/test_device.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,22 @@ def cuda_version():
2525
return _py_major_ver, _driver_ver
2626

2727

28+
def test_to_system_device(deinit_cuda):
    """The default CUDA device converts to an NVML device with the same UUID and PCI bus ID."""
    from cuda.core.system import Device as SystemDevice

    # Length of the `GPU-` / `MIG-` prefix that NVML puts in front of UUIDs.
    uuid_prefix_len = 4
    # CUDA only returns a 2-byte PCI bus ID domain, whereas NVML returns a
    # 4-byte domain, i.e. four extra leading hex characters to drop.
    extra_domain_chars = 4

    cuda_device = Device()
    nvml_device = cuda_device.to_system_device()
    assert isinstance(nvml_device, SystemDevice)
    assert nvml_device.uuid[uuid_prefix_len:] == cuda_device.uuid

    # Technically, this check will only work with PCI devices, but are there
    # non-PCI devices we need to support?
    assert cuda_device.pci_bus_id == nvml_device.pci_info.bus_id[extra_domain_chars:]
42+
43+
2844
def test_device_set_current(deinit_cuda):
2945
device = Device()
3046
device.set_current()

0 commit comments

Comments
 (0)