Skip to content

Commit 2849907

Browse files
authored
feat: add filter_keys​ to log only specified device stats (#21707)
* feat: add filter_keys​ to log only specified device stats * update
1 parent 1120456 commit 2849907

2 files changed

Lines changed: 105 additions & 1 deletion

File tree

src/lightning/pytorch/callbacks/device_stats_monitor.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from lightning.pytorch.accelerators.cpu import _PSUTIL_AVAILABLE
2828
from lightning.pytorch.callbacks.callback import Callback
2929
from lightning.pytorch.utilities.exceptions import MisconfigurationException
30+
from lightning.pytorch.utilities.rank_zero import rank_zero_warn
3031
from lightning.pytorch.utilities.types import STEP_OUTPUT
3132

3233

@@ -99,6 +100,12 @@ class DeviceStatsMonitor(Callback):
99100
cpu_stats: if ``None``, it will log CPU stats only if the accelerator is CPU.
100101
If ``True``, it will log CPU stats regardless of the accelerator.
101102
If ``False``, it will not log CPU stats regardless of the accelerator.
103+
filter_keys: if ``None``, all stats returned by the accelerator are logged.
104+
If a ``set`` of strings is provided, only the keys present in the set will be logged.
105+
Keys are matched against the base metric names before prefixing (e.g.,
106+
``"cpu_percent"`` not ``"DeviceStatsMonitor.on_train_batch_end/cpu_percent"``).
107+
A ``rank_zero_warn`` is emitted for any key in ``filter_keys`` not found in the
108+
collected stats, which helps catch typos early.
102109
103110
Raises:
104111
MisconfigurationException:
@@ -110,13 +117,29 @@ class DeviceStatsMonitor(Callback):
110117
111118
from lightning import Trainer
112119
from lightning.pytorch.callbacks import DeviceStatsMonitor
120+
121+
# log all stats (default behaviour)
113122
device_stats = DeviceStatsMonitor()
114123
trainer = Trainer(callbacks=[device_stats])
115124
125+
# log only peak and current allocated GPU memory
126+
device_stats = DeviceStatsMonitor(
127+
filter_keys={"allocated_bytes.all.current", "allocated_bytes.all.peak"}
128+
)
129+
trainer = Trainer(callbacks=[device_stats])
130+
131+
# log CPU stats alongside a subset of GPU memory stats
132+
device_stats = DeviceStatsMonitor(
133+
cpu_stats=True,
134+
filter_keys={"cpu_percent", "allocated_bytes.all.current"},
135+
)
136+
trainer = Trainer(callbacks=[device_stats])
137+
116138
"""
117139

118-
def __init__(self, cpu_stats: Optional[bool] = None) -> None:
140+
def __init__(self, cpu_stats: Optional[bool] = None, filter_keys: Optional[set[str]] = None) -> None:
119141
self._cpu_stats = cpu_stats
142+
self._filter_keys = filter_keys
120143

121144
@override
122145
def setup(
@@ -138,6 +161,20 @@ def setup(
138161
f"`DeviceStatsMonitor` cannot log CPU stats as `psutil` is not installed. {str(_PSUTIL_AVAILABLE)} "
139162
)
140163

164+
if self._filter_keys is not None:
165+
device_stats = trainer.accelerator.get_device_stats(device)
166+
if self._cpu_stats and device.type != "cpu":
167+
from lightning.pytorch.accelerators.cpu import get_cpu_stats
168+
169+
device_stats.update(get_cpu_stats())
170+
171+
unrecognized = self._filter_keys - device_stats.keys()
172+
if unrecognized:
173+
rank_zero_warn(
174+
f"`DeviceStatsMonitor` filter_keys contains keys not found in device stats and will be ignored:"
175+
f" {unrecognized}"
176+
)
177+
141178
def _get_and_log_device_stats(self, trainer: "pl.Trainer", key: str) -> None:
142179
if not trainer._logger_connector.should_update_logs:
143180
return
@@ -155,6 +192,9 @@ def _get_and_log_device_stats(self, trainer: "pl.Trainer", key: str) -> None:
155192

156193
device_stats.update(get_cpu_stats())
157194

195+
if self._filter_keys is not None:
196+
device_stats = {k: v for k, v in device_stats.items() if k in self._filter_keys}
197+
158198
for logger in trainer.loggers:
159199
separator = logger.group_separator
160200
prefixed_device_stats = _prefix_metric_keys(device_stats, f"{self.__class__.__qualname__}.{key}", separator)

tests/tests_pytorch/callbacks/test_device_stats_monitor.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,3 +217,67 @@ def test_device_stats_monitor_logs_for_different_stages(tmp_path):
217217
test = any(test_stage_results)
218218

219219
assert test, "testing stage logs not found"
220+
221+
222+
@RunIf(psutil=True)
223+
@pytest.mark.parametrize(
224+
("filter_keys", "expected_present", "expected_absent"),
225+
[
226+
(
227+
{_CPU_VM_PERCENT, _CPU_PERCENT},
228+
[_CPU_VM_PERCENT, _CPU_PERCENT],
229+
[_CPU_SWAP_PERCENT],
230+
),
231+
(
232+
{_CPU_PERCENT},
233+
[_CPU_PERCENT],
234+
[_CPU_VM_PERCENT, _CPU_SWAP_PERCENT],
235+
),
236+
],
237+
)
238+
def test_device_stats_monitor_filter_keys(tmp_path, filter_keys, expected_present, expected_absent):
239+
"""Test that filter_keys logs only the specified keys and omits the rest."""
240+
model = BoringModel()
241+
242+
class AssertFilterLogger(CSVLogger):
243+
def log_metrics(self, metrics: dict[str, float], step: Optional[int] = None) -> None:
244+
for key in expected_present:
245+
assert any(key in k for k in metrics), f"Expected key {key!r} not found in metrics"
246+
for key in expected_absent:
247+
assert not any(key in k for k in metrics), f"Unexpected key {key!r} found in metrics"
248+
249+
device_stats = DeviceStatsMonitor(cpu_stats=True, filter_keys=filter_keys)
250+
trainer = Trainer(
251+
default_root_dir=tmp_path,
252+
max_epochs=1,
253+
limit_train_batches=2,
254+
limit_val_batches=0,
255+
log_every_n_steps=1,
256+
callbacks=device_stats,
257+
logger=AssertFilterLogger(tmp_path),
258+
enable_checkpointing=False,
259+
enable_progress_bar=False,
260+
accelerator="cpu",
261+
)
262+
trainer.fit(model)
263+
264+
265+
@RunIf(psutil=True)
266+
def test_device_stats_monitor_filter_keys_unrecognized_warns(tmp_path):
267+
"""Test that filter_keys emits a warning for keys not present in device stats."""
268+
model = BoringModel()
269+
device_stats = DeviceStatsMonitor(cpu_stats=True, filter_keys={"nonexistent_key_xyz"})
270+
trainer = Trainer(
271+
default_root_dir=tmp_path,
272+
max_epochs=1,
273+
limit_train_batches=1,
274+
limit_val_batches=0,
275+
log_every_n_steps=1,
276+
callbacks=device_stats,
277+
logger=CSVLogger(tmp_path),
278+
enable_checkpointing=False,
279+
enable_progress_bar=False,
280+
accelerator="cpu",
281+
)
282+
with pytest.warns(UserWarning, match="filter_keys contains keys not found"):
283+
trainer.fit(model)

0 commit comments

Comments
 (0)