@@ -242,6 +242,9 @@ def optim_step(param) -> None:
242
242
p .register_post_accumulate_grad_hook (optim_step )
243
243
244
244
245
# Number of bytes in one gibibyte (GiB); used to convert raw byte counters
# returned by the torch memory APIs into GiB for logging.
_BYTES_IN_GIB = 1024 ** 3
def get_memory_stats(device: torch.device, reset_stats: bool = True) -> dict:
    """
    Computes a memory summary for the passed in device. If ``reset_stats`` is ``True``, this will
    also reset the device's peak memory tracking. This is useful to get data around relative use of
    peak memory (e.g. peak memory during model init, during forward, etc) and optimize memory for
    individual sections of training.

    Args:
        device (torch.device): Device to get memory summary for. Supports CUDA and MPS devices.
        reset_stats (bool): Whether to reset the device's peak memory tracking. Ignored on MPS
            devices, which do not expose peak-memory counters to reset.

    Returns:
        Dict[str, float]: A dictionary containing the peak memory active, peak memory allocated,
        and peak memory reserved. This dict is useful for logging memory stats. Note that
        ``peak_memory_reserved`` is omitted for MPS devices.

    Raises:
        ValueError: If the passed-in device is CPU.
    """
    if device.type == "cpu":
        raise ValueError("Logging memory stats is not supported on CPU devices")

    if device.type == "mps":
        # NOTE(review): torch.mps only exposes *current* allocation counters, not true
        # peaks, so these values are instantaneous snapshots despite the "peak_*" keys
        # (kept for caller compatibility). MPS also has no reserved-memory counter,
        # hence "peak_memory_reserved" is absent from the returned dict.
        memory_stats = {
            "peak_memory_active": torch.mps.current_allocated_memory() / _BYTES_IN_GIB,
            "peak_memory_alloc": torch.mps.driver_allocated_memory() / _BYTES_IN_GIB,
        }
    else:
        torch_device = get_torch_device_namespace()
        peak_memory_active = (
            torch_device.memory_stats().get("active_bytes.all.peak", 0) / _BYTES_IN_GIB
        )
        peak_memory_alloc = torch_device.max_memory_allocated(device) / _BYTES_IN_GIB
        peak_memory_reserved = torch_device.max_memory_reserved(device) / _BYTES_IN_GIB
        memory_stats = {
            "peak_memory_active": peak_memory_active,
            "peak_memory_alloc": peak_memory_alloc,
            "peak_memory_reserved": peak_memory_reserved,
        }
        # Reset must stay inside this branch: ``torch_device`` is never bound on the
        # MPS path (so hoisting this out would raise NameError there), and MPS exposes
        # no peak-memory statistics to reset in the first place.
        if reset_stats:
            torch_device.reset_peak_memory_stats(device)

    return memory_stats
@@ -288,19 +299,20 @@ def log_memory_stats(
288
299
) -> None :
289
300
"""
290
301
Logs a dict containing memory stats to the logger. ``stats`` should contain the fields
291
- ``peak_memory_active``, ``peak_memory_alloc``, and ``peak_memory_reserved`` as
302
+ ``peak_memory_active``, ``peak_memory_alloc``, and ``peak_memory_reserved`` (optional) as
292
303
returned by :func:`torchtune.training.get_memory_stats`.
293
304
294
305
Args:
295
306
stats (Dict[str, float]): A dictionary containing the peak memory active, peak memory
296
- allocated, and peak memory reserved stats.
307
+ allocated, and peak memory reserved (optional) stats.
297
308
message (str): An optional message to prepend to the log output.
298
309
Defaults to "Memory stats after model init:"
299
310
"""
300
311
device_support = get_device_support ()
301
312
_log .info (
302
- f"{ message } "
303
- f"\n \t { device_support .device_name } peak memory allocation: { stats ['peak_memory_alloc' ]:.2f} GiB"
304
- f"\n \t { device_support .device_name } peak memory reserved: { stats ['peak_memory_reserved' ]:.2f} GiB"
305
- f"\n \t { device_support .device_name } peak memory active: { stats ['peak_memory_active' ]:.2f} GiB"
313
+ f"{ message } \n "
314
+ + "\n " .join (
315
+ f"\t { device_support .device_name } { key .replace ('_' , ' ' )} : { value :.2f} GiB"
316
+ for key , value in stats .items ()
317
+ )
306
318
)
0 commit comments