diff --git a/src/lightning/pytorch/callbacks/device_stats_monitor.py b/src/lightning/pytorch/callbacks/device_stats_monitor.py index 873c4c05f5aed..57d1618c30bd9 100644 --- a/src/lightning/pytorch/callbacks/device_stats_monitor.py +++ b/src/lightning/pytorch/callbacks/device_stats_monitor.py @@ -27,6 +27,7 @@ from lightning.pytorch.accelerators.cpu import _PSUTIL_AVAILABLE from lightning.pytorch.callbacks.callback import Callback from lightning.pytorch.utilities.exceptions import MisconfigurationException +from lightning.pytorch.utilities.rank_zero import rank_zero_warn from lightning.pytorch.utilities.types import STEP_OUTPUT @@ -99,6 +100,12 @@ class DeviceStatsMonitor(Callback): cpu_stats: if ``None``, it will log CPU stats only if the accelerator is CPU. If ``True``, it will log CPU stats regardless of the accelerator. If ``False``, it will not log CPU stats regardless of the accelerator. + filter_keys: if ``None``, all collected stats (accelerator stats, plus CPU stats when enabled) are logged. + If a ``set`` of strings is provided, only the keys present in the set will be logged. + Keys are matched against the base metric names before prefixing (e.g., + ``"cpu_percent"``, not ``"DeviceStatsMonitor.on_train_batch_end/cpu_percent"``). + A warning is emitted via ``rank_zero_warn`` during ``setup`` for any key in ``filter_keys`` not found + in the collected stats, which helps catch typos early. 
Raises: MisconfigurationException: @@ -110,13 +117,29 @@ class DeviceStatsMonitor(Callback): from lightning import Trainer from lightning.pytorch.callbacks import DeviceStatsMonitor + + # log all stats (default behaviour) device_stats = DeviceStatsMonitor() trainer = Trainer(callbacks=[device_stats]) + # log only peak and current allocated GPU memory + device_stats = DeviceStatsMonitor( + filter_keys={"allocated_bytes.all.current", "allocated_bytes.all.peak"} + ) + trainer = Trainer(callbacks=[device_stats]) + + # log CPU stats alongside a subset of GPU memory stats + device_stats = DeviceStatsMonitor( + cpu_stats=True, + filter_keys={"cpu_percent", "allocated_bytes.all.current"}, + ) + trainer = Trainer(callbacks=[device_stats]) + """ - def __init__(self, cpu_stats: Optional[bool] = None) -> None: + def __init__(self, cpu_stats: Optional[bool] = None, filter_keys: Optional[set[str]] = None) -> None: self._cpu_stats = cpu_stats + self._filter_keys = filter_keys @override def setup( @@ -138,6 +161,20 @@ def setup( f"`DeviceStatsMonitor` cannot log CPU stats as `psutil` is not installed. 
{str(_PSUTIL_AVAILABLE)} " ) + if self._filter_keys is not None: + device_stats = trainer.accelerator.get_device_stats(device) + if self._cpu_stats and device.type != "cpu": + from lightning.pytorch.accelerators.cpu import get_cpu_stats + + device_stats.update(get_cpu_stats()) + + unrecognized = self._filter_keys - device_stats.keys() + if unrecognized: + rank_zero_warn( + f"`DeviceStatsMonitor` filter_keys contains keys not found in device stats and will be ignored:" + f" {unrecognized}" + ) + def _get_and_log_device_stats(self, trainer: "pl.Trainer", key: str) -> None: if not trainer._logger_connector.should_update_logs: return @@ -155,6 +192,9 @@ def _get_and_log_device_stats(self, trainer: "pl.Trainer", key: str) -> None: device_stats.update(get_cpu_stats()) + if self._filter_keys is not None: + device_stats = {k: v for k, v in device_stats.items() if k in self._filter_keys} + for logger in trainer.loggers: separator = logger.group_separator prefixed_device_stats = _prefix_metric_keys(device_stats, f"{self.__class__.__qualname__}.{key}", separator) diff --git a/tests/tests_pytorch/callbacks/test_device_stats_monitor.py b/tests/tests_pytorch/callbacks/test_device_stats_monitor.py index 290a0921cb06d..d485ab76d198e 100644 --- a/tests/tests_pytorch/callbacks/test_device_stats_monitor.py +++ b/tests/tests_pytorch/callbacks/test_device_stats_monitor.py @@ -217,3 +217,67 @@ def test_device_stats_monitor_logs_for_different_stages(tmp_path): test = any(test_stage_results) assert test, "testing stage logs not found" + + +@RunIf(psutil=True) +@pytest.mark.parametrize( + ("filter_keys", "expected_present", "expected_absent"), + [ + ( + {_CPU_VM_PERCENT, _CPU_PERCENT}, + [_CPU_VM_PERCENT, _CPU_PERCENT], + [_CPU_SWAP_PERCENT], + ), + ( + {_CPU_PERCENT}, + [_CPU_PERCENT], + [_CPU_VM_PERCENT, _CPU_SWAP_PERCENT], + ), + ], +) +def test_device_stats_monitor_filter_keys(tmp_path, filter_keys, expected_present, expected_absent): + """Test that filter_keys logs only the specified 
keys and omits the rest.""" + model = BoringModel() + + class AssertFilterLogger(CSVLogger): + def log_metrics(self, metrics: dict[str, float], step: Optional[int] = None) -> None: + for key in expected_present: + assert any(key in k for k in metrics), f"Expected key {key!r} not found in metrics" + for key in expected_absent: + assert not any(key in k for k in metrics), f"Unexpected key {key!r} found in metrics" + + device_stats = DeviceStatsMonitor(cpu_stats=True, filter_keys=filter_keys) + trainer = Trainer( + default_root_dir=tmp_path, + max_epochs=1, + limit_train_batches=2, + limit_val_batches=0, + log_every_n_steps=1, + callbacks=device_stats, + logger=AssertFilterLogger(tmp_path), + enable_checkpointing=False, + enable_progress_bar=False, + accelerator="cpu", + ) + trainer.fit(model) + + +@RunIf(psutil=True) +def test_device_stats_monitor_filter_keys_unrecognized_warns(tmp_path): + """Test that filter_keys emits a warning for keys not present in device stats.""" + model = BoringModel() + device_stats = DeviceStatsMonitor(cpu_stats=True, filter_keys={"nonexistent_key_xyz"}) + trainer = Trainer( + default_root_dir=tmp_path, + max_epochs=1, + limit_train_batches=1, + limit_val_batches=0, + log_every_n_steps=1, + callbacks=device_stats, + logger=CSVLogger(tmp_path), + enable_checkpointing=False, + enable_progress_bar=False, + accelerator="cpu", + ) + with pytest.warns(UserWarning, match="filter_keys contains keys not found"): + trainer.fit(model)