Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 41 additions & 1 deletion src/lightning/pytorch/callbacks/device_stats_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from lightning.pytorch.accelerators.cpu import _PSUTIL_AVAILABLE
from lightning.pytorch.callbacks.callback import Callback
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.rank_zero import rank_zero_warn
from lightning.pytorch.utilities.types import STEP_OUTPUT


Expand Down Expand Up @@ -99,6 +100,12 @@ class DeviceStatsMonitor(Callback):
cpu_stats: if ``None``, it will log CPU stats only if the accelerator is CPU.
If ``True``, it will log CPU stats regardless of the accelerator.
If ``False``, it will not log CPU stats regardless of the accelerator.
filter_keys: if ``None``, all stats returned by the accelerator are logged.
If a ``set`` of strings is provided, only the keys present in the set will be logged.
Keys are matched against the base metric names before prefixing (e.g.,
``"cpu_percent"`` not ``"DeviceStatsMonitor.on_train_batch_end/cpu_percent"``).
A ``rank_zero_warn`` is emitted for any key in ``filter_keys`` not found in the
collected stats, which helps catch typos early.

Raises:
MisconfigurationException:
Expand All @@ -110,13 +117,29 @@ class DeviceStatsMonitor(Callback):

from lightning import Trainer
from lightning.pytorch.callbacks import DeviceStatsMonitor

# log all stats (default behaviour)
device_stats = DeviceStatsMonitor()
trainer = Trainer(callbacks=[device_stats])

# log only peak and current allocated GPU memory
device_stats = DeviceStatsMonitor(
filter_keys={"allocated_bytes.all.current", "allocated_bytes.all.peak"}
)
trainer = Trainer(callbacks=[device_stats])

# log CPU stats alongside a subset of GPU memory stats
device_stats = DeviceStatsMonitor(
cpu_stats=True,
filter_keys={"cpu_percent", "allocated_bytes.all.current"},
)
trainer = Trainer(callbacks=[device_stats])

"""

def __init__(self, cpu_stats: Optional[bool] = None) -> None:
def __init__(self, cpu_stats: Optional[bool] = None, filter_keys: Optional[set[str]] = None) -> None:
    """Store the CPU-stats switch and the optional whitelist of metric names to log."""
    self._cpu_stats = cpu_stats
    # Keep a private copy: the caller's set is mutable, and changing it after
    # construction would silently alter which keys are validated and logged.
    # ``None`` is preserved because it means "no filtering".
    self._filter_keys = None if filter_keys is None else set(filter_keys)

@override
def setup(
Expand All @@ -138,6 +161,20 @@ def setup(
f"`DeviceStatsMonitor` cannot log CPU stats as `psutil` is not installed. {str(_PSUTIL_AVAILABLE)} "
)

if self._filter_keys is not None:
device_stats = trainer.accelerator.get_device_stats(device)
if self._cpu_stats and device.type != "cpu":
from lightning.pytorch.accelerators.cpu import get_cpu_stats

device_stats.update(get_cpu_stats())

unrecognized = self._filter_keys - device_stats.keys()
if unrecognized:
rank_zero_warn(
f"`DeviceStatsMonitor` filter_keys contains keys not found in device stats and will be ignored:"
f" {unrecognized}"
)
Comment thread
deependujha marked this conversation as resolved.

def _get_and_log_device_stats(self, trainer: "pl.Trainer", key: str) -> None:
if not trainer._logger_connector.should_update_logs:
return
Expand All @@ -155,6 +192,9 @@ def _get_and_log_device_stats(self, trainer: "pl.Trainer", key: str) -> None:

device_stats.update(get_cpu_stats())

if self._filter_keys is not None:
device_stats = {k: v for k, v in device_stats.items() if k in self._filter_keys}

for logger in trainer.loggers:
separator = logger.group_separator
prefixed_device_stats = _prefix_metric_keys(device_stats, f"{self.__class__.__qualname__}.{key}", separator)
Expand Down
64 changes: 64 additions & 0 deletions tests/tests_pytorch/callbacks/test_device_stats_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,3 +217,67 @@ def test_device_stats_monitor_logs_for_different_stages(tmp_path):
test = any(test_stage_results)

assert test, "testing stage logs not found"


@RunIf(psutil=True)
@pytest.mark.parametrize(
    ("filter_keys", "expected_present", "expected_absent"),
    [
        (
            {_CPU_VM_PERCENT, _CPU_PERCENT},
            [_CPU_VM_PERCENT, _CPU_PERCENT],
            [_CPU_SWAP_PERCENT],
        ),
        (
            {_CPU_PERCENT},
            [_CPU_PERCENT],
            [_CPU_VM_PERCENT, _CPU_SWAP_PERCENT],
        ),
    ],
)
def test_device_stats_monitor_filter_keys(tmp_path, filter_keys, expected_present, expected_absent):
    """Test that filter_keys logs only the specified keys and omits the rest.

    The per-call checks live inside the logger, so the test also verifies that the logger was actually invoked;
    otherwise it would pass vacuously if nothing were ever logged.

    """
    model = BoringModel()
    # Snapshot of every metrics dict handed to the logger during the run.
    logged_calls: list[dict[str, float]] = []

    class AssertFilterLogger(CSVLogger):
        def log_metrics(self, metrics: dict[str, float], step: Optional[int] = None) -> None:
            logged_calls.append(dict(metrics))
            # Logged names carry a "<callback>.<hook>" prefix, hence substring matching.
            for key in expected_present:
                assert any(key in k for k in metrics), f"Expected key {key!r} not found in metrics"
            for key in expected_absent:
                assert not any(key in k for k in metrics), f"Unexpected key {key!r} found in metrics"

    device_stats = DeviceStatsMonitor(cpu_stats=True, filter_keys=filter_keys)
    trainer = Trainer(
        default_root_dir=tmp_path,
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=0,
        log_every_n_steps=1,
        callbacks=device_stats,
        logger=AssertFilterLogger(tmp_path),
        enable_checkpointing=False,
        enable_progress_bar=False,
        accelerator="cpu",
    )
    trainer.fit(model)

    assert logged_calls, "log_metrics was never called, so the filter was not exercised"


@RunIf(psutil=True)
def test_device_stats_monitor_filter_keys_unrecognized_warns(tmp_path):
    """A filter key that matches none of the collected device stats must produce a warning during ``fit``."""
    model = BoringModel()
    # The only requested key is deliberately bogus, so it cannot match any stat.
    monitor = DeviceStatsMonitor(cpu_stats=True, filter_keys={"nonexistent_key_xyz"})
    trainer_kwargs = {
        "default_root_dir": tmp_path,
        "max_epochs": 1,
        "limit_train_batches": 1,
        "limit_val_batches": 0,
        "log_every_n_steps": 1,
        "callbacks": monitor,
        "logger": CSVLogger(tmp_path),
        "enable_checkpointing": False,
        "enable_progress_bar": False,
        "accelerator": "cpu",
    }
    trainer = Trainer(**trainer_kwargs)

    with pytest.warns(UserWarning, match="filter_keys contains keys not found"):
        trainer.fit(model)
Comment thread
deependujha marked this conversation as resolved.
Loading