File tree Expand file tree Collapse file tree 4 files changed +74
-0
lines changed
Expand file tree Collapse file tree 4 files changed +74
-0
lines changed Original file line number Diff line number Diff line change @@ -1034,6 +1034,17 @@ class Device:
10341034 total = system.get_num_devices()
10351035 return tuple (cls (device_id) for device_id in range (total))
10361036
def to_system_device(self) -> 'cuda.core.system.Device':
    """
    Return the :class:`cuda.core.system.Device` matching this device.

    A :class:`cuda.core.Device` is used for CUDA access, whereas a
    :class:`cuda.core.system.Device` is used for NVIDIA Management Library
    (NVML) access. The two are mapped to one another by their UUID: the
    system device is constructed from this device's ``uuid``.

    Returns
    -------
    cuda.core.system.Device
        The system device constructed from this device's UUID.
    """
    # Function-scope import (presumably to avoid an import cycle between
    # the two packages -- confirm before moving it to module scope).
    from cuda.core.system import Device as _NvmlDevice

    device_uuid = self.uuid
    return _NvmlDevice(uuid=device_uuid)
1047+
10371048 @property
10381049 def device_id(self ) -> int:
10391050 """Return device ordinal."""
Original file line number Diff line number Diff line change @@ -722,6 +722,36 @@ cdef class Device:
722722 pci_bus_id = pci_bus_id.decode(" ascii" )
723723 self ._handle = nvml.device_get_handle_by_pci_bus_id_v2(pci_bus_id)
724724
def to_cuda_device(self) -> "cuda.core.Device":
    """
    Get the corresponding :class:`cuda.core.Device` (which is used for CUDA
    access) for this :class:`cuda.core.system.Device` (which is used for
    NVIDIA Management Library (NVML) access).

    The devices are mapped to one another by their UUID.

    Returns
    -------
    cuda.core.Device
        The corresponding CUDA device.

    Raises
    ------
    RuntimeError
        If no CUDA device with a matching UUID is found.
    """
    from cuda.core import Device as CudaDevice

    # CUDA does not have an API to get a device by its UUID, so we just
    # search all the devices for one with a matching UUID.

    # NVML UUIDs have a `GPU-` or `MIG-` prefix that is stripped before
    # comparing against the CUDA UUID. Only strip when one of those prefixes
    # is actually present: blindly dropping the first four characters would
    # corrupt a UUID that carries no prefix and make it impossible to match.
    # Possibly we should only do this matching when it has a `GPU-` prefix;
    # if a matching CUDA device can't be found, we raise below anyway.
    uuid = self.uuid
    if uuid.startswith(("GPU-", "MIG-")):
        uuid = uuid[4:]

    for cuda_device in CudaDevice.get_all_devices():
        if cuda_device.uuid == uuid:
            return cuda_device

    raise RuntimeError("No corresponding CUDA device found for this NVML device.")
754+
725755 @classmethod
726756 def get_device_count (cls ) -> int:
727757 """
Original file line number Diff line number Diff line change @@ -33,6 +33,23 @@ def test_device_count():
3333 assert system .Device .get_device_count () == system .get_num_devices ()
3434
3535
def test_to_cuda_device():
    """Check that each system (NVML) device maps to a CUDA device with the same UUID and PCI bus ID."""
    from cuda.core import Device as CudaDevice

    for nvml_dev in system.Device.get_all_devices():
        mapped = nvml_dev.to_cuda_device()
        assert isinstance(mapped, CudaDevice)

        # The first four characters of the NVML UUID are dropped before
        # comparing it with the CUDA UUID.
        cuda_uuid = mapped.uuid
        assert cuda_uuid == nvml_dev.uuid[4:]

        # Technically, this check only works for PCI devices; whether
        # non-PCI devices need to be supported is an open question.

        # CUDA only returns a 2-byte PCI bus ID domain, whereas NVML returns
        # a 4-byte domain, so the leading four characters of the NVML bus ID
        # are skipped.
        cuda_bus_id = mapped.pci_bus_id
        assert cuda_bus_id == nvml_dev.pci_info.bus_id[4:]
51+
52+
3653def test_device_architecture ():
3754 for device in system .Device .get_all_devices ():
3855 device_arch = device .architecture
Original file line number Diff line number Diff line change @@ -25,6 +25,22 @@ def cuda_version():
2525 return _py_major_ver , _driver_ver
2626
2727
def test_to_system_device(deinit_cuda):
    """Check that a CUDA device maps to a system (NVML) device with the same UUID and PCI bus ID."""
    from cuda.core.system import Device as SystemDevice

    cuda_dev = Device()
    nvml_dev = cuda_dev.to_system_device()
    assert isinstance(nvml_dev, SystemDevice)

    # The first four characters of the NVML UUID are dropped before
    # comparing it with the CUDA UUID.
    nvml_uuid_tail = nvml_dev.uuid[4:]
    assert nvml_uuid_tail == cuda_dev.uuid

    # Technically, this check only works for PCI devices; whether non-PCI
    # devices need to be supported is an open question.

    # CUDA only returns a 2-byte PCI bus ID domain, whereas NVML returns a
    # 4-byte domain, so the leading four characters of the NVML bus ID are
    # skipped.
    cuda_bus_id = cuda_dev.pci_bus_id
    assert cuda_bus_id == nvml_dev.pci_info.bus_id[4:]
42+
43+
2844def test_device_set_current (deinit_cuda ):
2945 device = Device ()
3046 device .set_current ()
You can’t perform that action at this time.
0 commit comments