1
+ # SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
1
4
import math
2
5
import operator
3
6
import os
@@ -86,6 +89,38 @@ def get_gpu_count():
86
89
return pynvml .nvmlDeviceGetCount ()
87
90
88
91
92
def get_gpu_handle(device_index=0):
    """Get GPU handle from device index or UUID.

    Parameters
    ----------
    device_index: int or str
        The index or UUID of the device from which to obtain the handle.
        A string made only of decimal digits (e.g. ``"1"``) is treated as
        an index, any other non-empty string is treated as a UUID.

    Returns
    -------
    handle
        The NVML device handle. If the UUID refers to a MIG instance, the
        handle of its parent device is returned instead.

    Raises
    ------
    ValueError
        If NVML reports an error while looking up the device.

    Examples
    --------
    >>> get_gpu_handle(device_index=0)

    >>> get_gpu_handle(device_index="GPU-9fb42d6f-7d6b-368f-f79c-3c3e784c93f6")
    """
    pynvml.nvmlInit()

    try:
        if device_index and not str(device_index).isnumeric():
            # This means device_index is UUID.
            # This works for both MIG and non-MIG device UUIDs.
            handle = pynvml.nvmlDeviceGetHandleByUUID(str.encode(device_index))
            if pynvml.nvmlDeviceIsMigDeviceHandle(handle):
                # Additionally get parent device handle
                # if the device itself is a MIG instance
                handle = pynvml.nvmlDeviceGetDeviceHandleFromMigDeviceHandle(handle)
        else:
            # A numeric string passes the check above and lands here, but the
            # index lookup needs an integer, so convert it first.
            if isinstance(device_index, str) and device_index.isdecimal():
                device_index = int(device_index)
            handle = pynvml.nvmlDeviceGetHandleByIndex(device_index)
        return handle
    except pynvml.NVMLError as err:
        # Keep the NVML error attached as the explicit cause for debugging.
        raise ValueError(f"Invalid device index: {device_index}") from err
122
+
123
+
89
124
@toolz .memoize
90
125
def get_gpu_count_mig (return_uuids = False ):
91
126
"""Return the number of MIG instances available
@@ -129,7 +164,7 @@ def get_cpu_affinity(device_index=None):
129
164
Parameters
130
165
----------
131
166
device_index: int or str
132
- Index or UUID of the GPU device
167
+ The index or UUID of the device from which to obtain the CPU affinity.
133
168
134
169
Examples
135
170
--------
@@ -148,19 +183,8 @@ def get_cpu_affinity(device_index=None):
148
183
40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59,
149
184
60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79]
150
185
"""
151
- pynvml .nvmlInit ()
152
-
153
186
try :
154
- if device_index and not str (device_index ).isnumeric ():
155
- # This means device_index is UUID.
156
- # This works for both MIG and non-MIG device UUIDs.
157
- handle = pynvml .nvmlDeviceGetHandleByUUID (str .encode (device_index ))
158
- if pynvml .nvmlDeviceIsMigDeviceHandle (handle ):
159
- # Additionally get parent device handle
160
- # if the device itself is a MIG instance
161
- handle = pynvml .nvmlDeviceGetDeviceHandleFromMigDeviceHandle (handle )
162
- else :
163
- handle = pynvml .nvmlDeviceGetHandleByIndex (device_index )
187
+ handle = get_gpu_handle (device_index )
164
188
# Result is a list of 64-bit integers, thus ceil(get_cpu_count() / 64)
165
189
affinity = pynvml .nvmlDeviceGetCpuAffinity (
166
190
handle ,
@@ -182,18 +206,15 @@ def get_n_gpus():
182
206
return get_gpu_count ()
183
207
184
208
185
def get_device_total_memory(device_index=0):
    """Return total memory of CUDA device with index or with device identifier UUID.

    Parameters
    ----------
    device_index: int or str
        The index or UUID of the device from which to obtain the total memory.

    Returns
    -------
    The ``total`` field of the NVML memory info queried for the device handle.
    """
    # Handle resolution (index vs. UUID) is delegated to get_gpu_handle.
    memory_info = pynvml.nvmlDeviceGetMemoryInfo(get_gpu_handle(device_index))
    return memory_info.total
198
219
199
220
@@ -553,26 +574,26 @@ def _align(size, alignment_size):
553
574
return _align (int (device_memory_limit ), alignment_size )
554
575
555
576
556
- def get_gpu_uuid_from_index (device_index = 0 ):
577
+ def get_gpu_uuid (device_index = 0 ):
557
578
"""Get GPU UUID from CUDA device index.
558
579
559
580
Parameters
560
581
----------
561
582
device_index: int or str
562
- The index of the device from which to obtain the UUID. Default: 0 .
583
+ The index or UUID of the device from which to obtain the UUID.
563
584
564
585
Examples
565
586
--------
566
- >>> get_gpu_uuid_from_index ()
587
+ >>> get_gpu_uuid ()
567
588
'GPU-9baca7f5-0f2f-01ac-6b05-8da14d6e9005'
568
589
569
- >>> get_gpu_uuid_from_index (3)
590
+ >>> get_gpu_uuid (3)
570
591
'GPU-9fb42d6f-7d6b-368f-f79c-3c3e784c93f6'
571
- """
572
- import pynvml
573
592
574
- pynvml .nvmlInit ()
575
- handle = pynvml .nvmlDeviceGetHandleByIndex (device_index )
593
+ >>> get_gpu_uuid("GPU-9fb42d6f-7d6b-368f-f79c-3c3e784c93f6")
594
+ 'GPU-9fb42d6f-7d6b-368f-f79c-3c3e784c93f6'
595
+ """
596
+ handle = get_gpu_handle (device_index )
576
597
try :
577
598
return pynvml .nvmlDeviceGetUUID (handle ).decode ("utf-8" )
578
599
except AttributeError :
0 commit comments