Skip to content

Commit 3afa38a

Browse files
musabgultekin authored and joecummings committed
MPS memory usage support (pytorch#2406)
1 parent c7c4270 commit 3afa38a

File tree

1 file changed

+34
-22
lines changed

1 file changed

+34
-22
lines changed

torchtune/training/memory.py

+34-22
Original file line number | Diff line number | Diff line change
# Bytes in one gibibyte; used to report memory figures in GiB.
_BYTES_IN_GIB = 1024**3


def get_memory_stats(device: torch.device, reset_stats: bool = True) -> dict:
    """
    Computes a memory summary for the passed in device. If ``reset_stats`` is ``True``, this will
    also reset the device's peak memory tracking (CUDA-like devices only). This is useful to get
    data around relative use of peak memory (e.g. peak memory during model init, during forward,
    etc) and optimize memory for individual sections of training.

    Args:
        device (torch.device): Device to get memory summary for. Supports CUDA and MPS devices.
        reset_stats (bool): Whether to reset the device's peak memory tracking. Ignored on MPS,
            which exposes no peak-memory reset API.

    Returns:
        Dict[str, float]: A dictionary containing the peak memory active and peak memory
            allocated, plus peak memory reserved on non-MPS devices (the MPS backend does not
            report reserved memory). This dict is useful for logging memory stats.

    Raises:
        ValueError: If the passed-in device is CPU.
    """
    if device.type == "cpu":
        raise ValueError("Logging memory stats is not supported on CPU devices")

    if device.type == "mps":
        # NOTE(review): torch.mps reports *current* (not peak) allocations; the keys keep the
        # ``peak_`` prefix so downstream logging treats CUDA and MPS stats uniformly.
        peak_memory_active = torch.mps.current_allocated_memory() / _BYTES_IN_GIB
        peak_memory_alloc = torch.mps.driver_allocated_memory() / _BYTES_IN_GIB
        memory_stats = {
            "peak_memory_active": peak_memory_active,
            "peak_memory_alloc": peak_memory_alloc,
        }
    else:
        torch_device = get_torch_device_namespace()
        peak_memory_active = (
            torch_device.memory_stats().get("active_bytes.all.peak", 0) / _BYTES_IN_GIB
        )
        peak_memory_alloc = torch_device.max_memory_allocated(device) / _BYTES_IN_GIB
        peak_memory_reserved = torch_device.max_memory_reserved(device) / _BYTES_IN_GIB
        memory_stats = {
            "peak_memory_active": peak_memory_active,
            "peak_memory_alloc": peak_memory_alloc,
            "peak_memory_reserved": peak_memory_reserved,
        }
        # Reset is only meaningful (and only available) on devices with peak tracking.
        if reset_stats:
            torch_device.reset_peak_memory_stats(device)
    return memory_stats
281292

282293

@@ -288,19 +299,20 @@ def log_memory_stats(
288299
) -> None:
289300
"""
290301
Logs a dict containing memory stats to the logger. ``stats`` should contain the fields
291-
``peak_memory_active``, ``peak_memory_alloc``, and ``peak_memory_reserved`` as
302+
``peak_memory_active``, ``peak_memory_alloc``, and ``peak_memory_reserved`` (optional) as
292303
returned by :func:`torchtune.training.get_memory_stats`.
293304
294305
Args:
295306
stats (Dict[str, float]): A dictionary containing the peak memory active, peak memory
296-
allocated, and peak memory reserved stats.
307+
allocated, and peak memory reserved (optional) stats.
297308
message (str): An optional message to prepend to the log output.
298309
Defaults to "Memory stats after model init:"
299310
"""
300311
device_support = get_device_support()
301312
_log.info(
302-
f"{message}"
303-
f"\n\t{device_support.device_name} peak memory allocation: {stats['peak_memory_alloc']:.2f} GiB"
304-
f"\n\t{device_support.device_name} peak memory reserved: {stats['peak_memory_reserved']:.2f} GiB"
305-
f"\n\t{device_support.device_name} peak memory active: {stats['peak_memory_active']:.2f} GiB"
313+
f"{message}\n"
314+
+ "\n".join(
315+
f"\t{device_support.device_name} {key.replace('_', ' ')}: {value:.2f} GiB"
316+
for key, value in stats.items()
317+
)
306318
)

0 commit comments

Comments (0)