@@ -22,18 +22,18 @@ def get_memory_mb():
 except BaseException:
     get_memory_mb = None
 
-try:
-    import torch
+import torch
 
+try:
     if torch.cuda.is_available():
         from utils.conf import get_alloc_memory_all_devices
 
-        def get_memory_gpu_mb():
+        def get_memory_gpu_mb(avail_devices=None):
             """
-            Get the memory usage of all GPUs in MB.
+            Get the memory usage of the selected GPUs in MB.
             """
 
-            return [d / 1024 / 1024 for d in get_alloc_memory_all_devices()]
+            return [d / 1024 / 1024 for d in get_alloc_memory_all_devices(avail_devices=avail_devices)]
     else:
         get_memory_gpu_mb = None
 except BaseException:
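As an illustration (outside the patch), the updated helper can now be asked for a subset of devices. This sketch assumes `get_alloc_memory_all_devices` reports allocated bytes per requested device, as suggested by the call above:

import torch

if torch.cuda.is_available():
    all_mb = get_memory_gpu_mb()                   # every visible device
    only_0 = get_memory_gpu_mb(avail_devices=[0])  # just cuda:0
    print(f"GPU 0 allocation: {only_0[0]:.2f} MB")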
@@ -43,6 +43,54 @@ def get_memory_gpu_mb():
 from utils.loggers import Logger
 
 
+def _parse_device_ids(device):
+    """
+    Normalize a device specification to a list of CUDA ids.
+    """
+    if device is None:
+        return None
+
+    if isinstance(device, torch.device):
+        if device.type != 'cuda':
+            return None
+        if device.index is None:
+            return list(range(torch.cuda.device_count()))
+        if 0 <= device.index < torch.cuda.device_count():
+            return [device.index]
+        logging.warning(f"Requested device index {device.index} is out of range.")
+        return None
+
+    if isinstance(device, str):
+        if 'cuda' not in device:
+            return None
+        parts = [p for p in device.split(',') if p.strip() != '']
+        if len(parts) == 0:
+            return list(range(torch.cuda.device_count()))
+        ids = []
+        for p in parts:
+            try:
+                ids.append(int(p.split(':')[-1]))
+            except ValueError:
+                logging.warning(f"Could not parse device id from `{p}`, skipping.")
+        ids = [i for i in ids if 0 <= i < torch.cuda.device_count()]
+        if len(ids) == 0:
+            logging.warning("No valid CUDA device ids parsed, falling back to all visible devices.")
+            return list(range(torch.cuda.device_count()))
+        return ids
+
+    if isinstance(device, (list, tuple)):
+        ids = []
+        for d in device:
+            if isinstance(d, int):
+                ids.append(d)
+            elif isinstance(d, torch.device) and d.type == 'cuda' and d.index is not None:
+                ids.append(d.index)
+        ids = [i for i in ids if 0 <= i < torch.cuda.device_count()]
+        return ids or None
+
+    return None
+
+
 class track_system_stats:
     """
     A context manager that tracks the memory usage of the system.
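To make the normalization rules concrete, here is a hedged sketch of the values `_parse_device_ids` would return; the exact output depends on `torch.cuda.device_count()` (two visible GPUs are assumed here):

_parse_device_ids(None)                         # None -> defer to defaults
_parse_device_ids('cpu')                        # None -> nothing to monitor
_parse_device_ids('cuda')                       # all visible ids (after a parse warning)
_parse_device_ids('cuda:1')                     # [1]
_parse_device_ids('cuda:0,cuda:1')              # [0, 1]
_parse_device_ids(torch.device('cuda', 0))      # [0]
_parse_device_ids([0, torch.device('cuda:1')])  # [0, 1]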
@@ -59,9 +107,10 @@ class track_system_stats:
 
         cpu_res, gpu_res = t.cpu_res, t.gpu_res
 
-    Args:
-        logger (Logger): external logger.
-        disabled (bool): If True, the context manager will not track the memory usage.
+    Args:
+        logger (Logger): external logger.
+        device: Device (or list of devices) to monitor. Defaults to all visible CUDA devices.
+        disabled (bool): If True, the context manager will not track the memory usage.
     """
 
     def get_stats(self):
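A usage sketch of the new argument, mirroring the docstring example above (the sampled values are hypothetical):

with track_system_stats(device='cuda:0') as t:
    for i in range(100):
        ...   # training step
        t()   # sample memory usage

cpu_res, gpu_res = t.cpu_res, t.gpu_res  # gpu_res is now a dict, e.g. {0: 812.5}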
@@ -77,14 +126,16 @@ def get_stats(self):
 
         gpu_res = None
         if get_memory_gpu_mb is not None:
-            gpu_res = get_memory_gpu_mb()
+            gpu_res = get_memory_gpu_mb(self.gpu_ids)
+            gpu_res = self._zip_gpu_res(gpu_res)
 
         return cpu_res, gpu_res
 
-    def __init__(self, logger: Logger = None, disabled=False):
+    def __init__(self, logger: Logger = None, device=None, disabled=False):
         self.logger = logger
         self.disabled = disabled
         self._it = 0
+        self.gpu_ids = _parse_device_ids(device) if torch.cuda.is_available() else None
 
     def __enter__(self):
         if self.disabled:
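For reference, a hypothetical return shape of `get_stats` after this change, assuming two monitored GPUs:

cpu_res, gpu_res = tracker.get_stats()
# cpu_res -> 1523.42             (process memory in MB, or None if unavailable)
# gpu_res -> {0: 812.5, 1: 64.0} (MB keyed by device id, or None without CUDA)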
@@ -93,9 +144,6 @@ def __enter__(self):
         if self.initial_cpu_res is None and self.initial_gpu_res is None:
             self.disabled = True
         else:
-            if self.initial_gpu_res is not None:
-                self.initial_gpu_res = {g: g_res for g, g_res in enumerate(self.initial_gpu_res)}
-
             self.avg_gpu_res = self.initial_gpu_res
             self.avg_cpu_res = self.initial_cpu_res
 
@@ -130,7 +178,7 @@ def update_stats(self, cpu_res, gpu_res):
 
         Args:
             cpu_res (float): The memory usage of the CPU.
-            gpu_res (list): The memory usage of the GPUs.
+            gpu_res (dict): The memory usage of the GPUs, keyed by device id.
         """
         if self.disabled:
             return
@@ -143,9 +191,8 @@ def update_stats(self, cpu_res, gpu_res):
         self.max_cpu_res = max(self.max_cpu_res, cpu_res)
 
         if self.initial_gpu_res is not None:
-            self.avg_gpu_res = {g: (g_res + alpha * (g_res - self.avg_gpu_res[g])) for g, g_res in enumerate(gpu_res)}
-            self.max_gpu_res = {g: max(self.max_gpu_res[g], g_res) for g, g_res in enumerate(gpu_res)}
-            gpu_res = {g: g_res for g, g_res in enumerate(gpu_res)}
+            self.avg_gpu_res = {g: self.avg_gpu_res[g] + alpha * (g_res - self.avg_gpu_res[g]) for g, g_res in gpu_res.items()}
+            self.max_gpu_res = {g: max(self.max_gpu_res[g], g_res) for g, g_res in gpu_res.items()}
 
         if self.logger is not None:
             self.logger.log_system_stats(cpu_res, gpu_res)
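The update above is the standard incremental mean, avg_t = avg_{t-1} + alpha * (x_t - avg_{t-1}). A tiny sketch, assuming alpha is a 1/t step size (the step size itself is set outside this hunk):

samples = [100.0, 200.0, 300.0]
avg = samples[0]
for t, x in enumerate(samples[1:], start=2):
    alpha = 1 / t  # assumed 1/t step size
    avg = avg + alpha * (x - avg)
print(avg)  # 200.0 == mean(samples)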
@@ -166,8 +213,21 @@ def print_stats(self):
         logging.info(f"\tMax CPU memory usage: {self.max_cpu_res:.2f} MB")
 
         if gpu_res is not None:
-            for gpu_id, g_res in enumerate(gpu_res):
+            for gpu_id, g_res in gpu_res.items():
                 logging.info(f"\tInitial GPU {gpu_id} memory usage: {self.initial_gpu_res[gpu_id]:.2f} MB")
                 logging.info(f"\tAverage GPU {gpu_id} memory usage: {self.avg_gpu_res[gpu_id]:.2f} MB")
                 logging.info(f"\tFinal GPU {gpu_id} memory usage: {g_res:.2f} MB")
                 logging.info(f"\tMax GPU {gpu_id} memory usage: {self.max_gpu_res[gpu_id]:.2f} MB")
+
+    def _zip_gpu_res(self, gpu_res):
+        """
+        Zip a list of GPU stats to a dict keyed by the selected GPU ids.
+        """
+        if gpu_res is None:
+            return None
+
+        keys = self.gpu_ids if self.gpu_ids is not None else list(range(len(gpu_res)))
+        if len(keys) != len(gpu_res):
+            logging.warning("Mismatch between provided GPU ids and measured GPUs. Falling back to enumeration.")
+            keys = list(range(len(gpu_res)))
+        return {g: g_res for g, g_res in zip(keys, gpu_res)}
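Finally, a hedged illustration of `_zip_gpu_res`, assuming the measured list is ordered like the requested ids:

# With self.gpu_ids == [0, 2] and measurements [512.0, 1024.0] MB:
#   self._zip_gpu_res([512.0, 1024.0]) -> {0: 512.0, 2: 1024.0}
# On a length mismatch it warns and enumerates instead:
#   self.gpu_ids == [0, 2], gpu_res == [512.0] -> {0: 512.0}
# And None passes through:
#   self._zip_gpu_res(None) -> None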