
Commit ea967c3

Merge pull request #72 from lxr2/device-aware-stats
Hi @lxr2, thank you. I've had some reports about this in the past but never had the time to properly look into it. Besides your fix to the stats tracking utility, I've added a small fix that my colleagues suggested some time ago: removing the garbage collector call in `conf.py`. It was added to try to keep memory usage under control, but it also causes severe performance degradation.
2 parents f82adca + 2a5de6a commit ea967c3
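
For context on the performance point: `gc.collect()` forces a full garbage-collection pass on every call, which is typically much slower than the memory query it sits next to, and the stats helper runs every training iteration. A minimal, self-contained timing sketch (not part of this commit; the object count and timings are illustrative):

import gc
import timeit

# Give the collector real work to do by keeping many small objects alive.
junk = [{"i": i} for i in range(200_000)]

def query_only():
    # Stand-in for a cheap memory query such as torch.cuda.memory_allocated().
    return len(junk)

def query_with_gc():
    gc.collect()  # full collection on every stats call, as in the old conf.py
    return len(junk)

print("query only:", timeit.timeit(query_only, number=100))
print("query + gc:", timeit.timeit(query_with_gc, number=100))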

File tree (4 files changed, +82 -22 lines):
  utils/conf.py
  utils/deprecated/continual_training.py
  utils/stats.py
  utils/training.py

utils/conf.py

Lines changed: 2 additions & 2 deletions
@@ -80,8 +80,8 @@ def get_alloc_memory_all_devices(avail_devices=None, return_all=False) -> Union[
             gpu_memory_nvidiasmi.append(-1)
 
         del _
-        gc.collect()
-        torch.cuda.empty_cache()
+        # gc.collect()
+        # torch.cuda.empty_cache()
 
     if return_all:
         return gpu_memory_reserved, gpu_memory_allocated, gpu_memory_nvidiasmi
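
A hedged usage sketch of the touched function, based only on what this diff shows: the signature `get_alloc_memory_all_devices(avail_devices=None, return_all=False)`, the three per-device lists returned when `return_all=True`, and the fact that `utils/stats.py` (below) divides the default return values by 1024² to get MB, so they are presumably bytes. The device list `[0]` is just an example.

import torch
from utils.conf import get_alloc_memory_all_devices

if torch.cuda.is_available():
    # Query only GPU 0 instead of every visible device.
    alloc = get_alloc_memory_all_devices(avail_devices=[0])
    print([v / 1024 / 1024 for v in alloc])  # MB, mirroring get_memory_gpu_mb

    # With return_all=True the hunk above returns three per-device lists.
    reserved, allocated, nvidiasmi = get_alloc_memory_all_devices(avail_devices=[0], return_all=True)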

utils/deprecated/continual_training.py

Lines changed: 1 addition & 1 deletion
@@ -71,7 +71,7 @@ def train(args: Namespace):
     model.net.to(model.device)
     torch.cuda.empty_cache()
 
-    with track_system_stats(logger) as system_tracker:
+    with track_system_stats(logger, device=args.device) as system_tracker:
         epoch, i = 0, 0
         model.net.train()

utils/stats.py

Lines changed: 78 additions & 18 deletions
@@ -22,18 +22,18 @@ def get_memory_mb():
 except BaseException:
     get_memory_mb = None
 
-try:
-    import torch
+import torch
 
+try:
     if torch.cuda.is_available():
         from utils.conf import get_alloc_memory_all_devices
 
-        def get_memory_gpu_mb():
+        def get_memory_gpu_mb(avail_devices=None):
             """
-            Get the memory usage of all GPUs in MB.
+            Get the memory usage of the selected GPUs in MB.
             """
 
-            return [d / 1024 / 1024 for d in get_alloc_memory_all_devices()]
+            return [d / 1024 / 1024 for d in get_alloc_memory_all_devices(avail_devices=avail_devices)]
     else:
         get_memory_gpu_mb = None
 except BaseException:
@@ -43,6 +43,54 @@ def get_memory_gpu_mb():
 from utils.loggers import Logger
 
 
+def _parse_device_ids(device):
+    """
+    Normalize a device specification to a list of CUDA ids.
+    """
+    if device is None:
+        return None
+
+    if isinstance(device, torch.device):
+        if device.type != 'cuda':
+            return None
+        if device.index is None:
+            return list(range(torch.cuda.device_count()))
+        if 0 <= device.index < torch.cuda.device_count():
+            return [device.index]
+        logging.warning(f"Requested device index {device.index} is out of range.")
+        return None
+
+    if isinstance(device, str):
+        if 'cuda' not in device:
+            return None
+        parts = [p for p in device.split(',') if p.strip() != '']
+        if len(parts) == 0:
+            return list(range(torch.cuda.device_count()))
+        ids = []
+        for p in parts:
+            try:
+                ids.append(int(p.split(':')[-1]))
+            except ValueError:
+                logging.warning(f"Could not parse device id from `{p}`, skipping.")
+        ids = [i for i in ids if 0 <= i < torch.cuda.device_count()]
+        if len(ids) == 0:
+            logging.warning("No valid CUDA device ids parsed, falling back to all visible devices.")
+            return list(range(torch.cuda.device_count()))
+        return ids
+
+    if isinstance(device, (list, tuple)):
+        ids = []
+        for d in device:
+            if isinstance(d, int):
+                ids.append(d)
+            elif isinstance(d, torch.device) and d.type == 'cuda' and d.index is not None:
+                ids.append(d.index)
+        ids = [i for i in ids if 0 <= i < torch.cuda.device_count()]
+        return ids or None
+
+    return None
+
+
 class track_system_stats:
     """
     A context manager that tracks the memory usage of the system.
@@ -59,9 +107,10 @@ class track_system_stats:
 
             cpu_res, gpu_res = t.cpu_res, t.gpu_res
 
-    Args:
-        logger (Logger): external logger.
-        disabled (bool): If True, the context manager will not track the memory usage.
+    Args:
+        logger (Logger): external logger.
+        device: Device (or list of devices) to monitor. Defaults to all visible CUDA devices.
+        disabled (bool): If True, the context manager will not track the memory usage.
     """
 
     def get_stats(self):
@@ -77,14 +126,16 @@ def get_stats(self):
 
         gpu_res = None
         if get_memory_gpu_mb is not None:
-            gpu_res = get_memory_gpu_mb()
+            gpu_res = get_memory_gpu_mb(self.gpu_ids)
+            gpu_res = self._zip_gpu_res(gpu_res)
 
         return cpu_res, gpu_res
 
-    def __init__(self, logger: Logger = None, disabled=False):
+    def __init__(self, logger: Logger = None, device=None, disabled=False):
         self.logger = logger
         self.disabled = disabled
         self._it = 0
+        self.gpu_ids = _parse_device_ids(device) if torch.cuda.is_available() else None
 
     def __enter__(self):
         if self.disabled:
@@ -93,9 +144,6 @@ def __enter__(self):
         if self.initial_cpu_res is None and self.initial_gpu_res is None:
            self.disabled = True
         else:
-            if self.initial_gpu_res is not None:
-                self.initial_gpu_res = {g: g_res for g, g_res in enumerate(self.initial_gpu_res)}
-
             self.avg_gpu_res = self.initial_gpu_res
             self.avg_cpu_res = self.initial_cpu_res

@@ -130,7 +178,7 @@ def update_stats(self, cpu_res, gpu_res):
 
         Args:
             cpu_res (float): The memory usage of the CPU.
-            gpu_res (list): The memory usage of the GPUs.
+            gpu_res (dict): The memory usage of the GPUs keyed by device id.
         """
         if self.disabled:
             return
@@ -143,9 +191,8 @@ def update_stats(self, cpu_res, gpu_res):
         self.max_cpu_res = max(self.max_cpu_res, cpu_res)
 
         if self.initial_gpu_res is not None:
-            self.avg_gpu_res = {g: (g_res + alpha * (g_res - self.avg_gpu_res[g])) for g, g_res in enumerate(gpu_res)}
-            self.max_gpu_res = {g: max(self.max_gpu_res[g], g_res) for g, g_res in enumerate(gpu_res)}
-            gpu_res = {g: g_res for g, g_res in enumerate(gpu_res)}
+            self.avg_gpu_res = {g: (g_res + alpha * (g_res - self.avg_gpu_res[g])) for g, g_res in gpu_res.items()}
+            self.max_gpu_res = {g: max(self.max_gpu_res[g], g_res) for g, g_res in gpu_res.items()}
 
         if self.logger is not None:
             self.logger.log_system_stats(cpu_res, gpu_res)
@@ -166,8 +213,21 @@ def print_stats(self):
         logging.info(f"\tMax CPU memory usage: {self.max_cpu_res:.2f} MB")
 
         if gpu_res is not None:
-            for gpu_id, g_res in enumerate(gpu_res):
+            for gpu_id, g_res in gpu_res.items():
                 logging.info(f"\tInitial GPU {gpu_id} memory usage: {self.initial_gpu_res[gpu_id]:.2f} MB")
                 logging.info(f"\tAverage GPU {gpu_id} memory usage: {self.avg_gpu_res[gpu_id]:.2f} MB")
                 logging.info(f"\tFinal GPU {gpu_id} memory usage: {g_res:.2f} MB")
                 logging.info(f"\tMax GPU {gpu_id} memory usage: {self.max_gpu_res[gpu_id]:.2f} MB")
+
+    def _zip_gpu_res(self, gpu_res):
+        """
+        Zip a list of GPU stats to a dict keyed by the selected GPU ids.
+        """
+        if gpu_res is None:
+            return None
+
+        keys = self.gpu_ids if self.gpu_ids is not None else list(range(len(gpu_res)))
+        if len(keys) != len(gpu_res):
+            logging.warning("Mismatch between provided GPU ids and measured GPUs. Falling back to enumeration.")
+            keys = list(range(len(gpu_res)))
+        return {g: g_res for g, g_res in zip(keys, gpu_res)}
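
Illustrative behaviour of the `_parse_device_ids` helper added above, assuming a machine with two visible CUDA devices (the expected values in the comments follow directly from the branches in the diff):

import torch
from utils.stats import _parse_device_ids

print(_parse_device_ids(None))                         # None (downstream helpers then poll all devices)
print(_parse_device_ids('cpu'))                        # None (not a CUDA spec)
print(_parse_device_ids('cuda'))                       # [0, 1]  falls back to all visible GPUs (after a parse warning)
print(_parse_device_ids('cuda:1'))                     # [1]
print(_parse_device_ids(torch.device('cuda:0')))       # [0]
print(_parse_device_ids([0, torch.device('cuda:1')]))  # [0, 1]
print(_parse_device_ids('cuda:7'))                     # [0, 1]  invalid id logs a warning and falls back to all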

utils/training.py

Lines changed: 1 addition & 1 deletion
@@ -156,7 +156,7 @@ def train(model: ContinualModel, dataset: ContinualDataset,
     model.net.to(model.device)
     torch.cuda.empty_cache()
 
-    with track_system_stats(logger) as system_tracker:
+    with track_system_stats(logger, device=args.device) as system_tracker:
         results, results_mask_classes = [], []
 
         if args.eval_future:
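
A minimal sketch of how the device-aware tracker is meant to be used, mirroring the `track_system_stats(logger, device=args.device)` call in the two training loops above. Only methods that appear in this diff (`get_stats`, `update_stats`, `print_stats`) are used; `'cuda:0'` and the iteration count are placeholders, and passing `logger=None` simply skips external logging.

from utils.stats import track_system_stats

# Track only the GPU the model runs on instead of every visible device.
with track_system_stats(logger=None, device='cuda:0') as system_tracker:
    for it in range(10):
        ...  # forward / backward / optimizer step would go here
        cpu_res, gpu_res = system_tracker.get_stats()  # gpu_res is a dict {device_id: MB}
        system_tracker.update_stats(cpu_res, gpu_res)
    system_tracker.print_stats()  # initial / average / final / max memory per selected GPU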
