2727from lightning .pytorch .accelerators .cpu import _PSUTIL_AVAILABLE
2828from lightning .pytorch .callbacks .callback import Callback
2929from lightning .pytorch .utilities .exceptions import MisconfigurationException
30+ from lightning .pytorch .utilities .rank_zero import rank_zero_warn
3031from lightning .pytorch .utilities .types import STEP_OUTPUT
3132
3233
@@ -99,6 +100,12 @@ class DeviceStatsMonitor(Callback):
99100 cpu_stats: if ``None``, it will log CPU stats only if the accelerator is CPU.
100101 If ``True``, it will log CPU stats regardless of the accelerator.
101102 If ``False``, it will not log CPU stats regardless of the accelerator.
103+ filter_keys: if ``None``, all stats returned by the accelerator are logged.
104+ If a ``set`` of strings is provided, only the keys present in the set will be logged.
105+ Keys are matched against the base metric names before prefixing (e.g.,
106+ ``"cpu_percent"`` not ``"DeviceStatsMonitor.on_train_batch_end/cpu_percent"``).
107+ A ``rank_zero_warn`` is emitted for any key in ``filter_keys`` not found in the
108+ collected stats, which helps catch typos early.
102109
103110 Raises:
104111 MisconfigurationException:
@@ -110,13 +117,29 @@ class DeviceStatsMonitor(Callback):
110117
111118 from lightning import Trainer
112119 from lightning.pytorch.callbacks import DeviceStatsMonitor
120+
121+ # log all stats (default behaviour)
113122 device_stats = DeviceStatsMonitor()
114123 trainer = Trainer(callbacks=[device_stats])
115124
125+ # log only peak and current allocated GPU memory
126+ device_stats = DeviceStatsMonitor(
127+ filter_keys={"allocated_bytes.all.current", "allocated_bytes.all.peak"}
128+ )
129+ trainer = Trainer(callbacks=[device_stats])
130+
131+ # log CPU stats alongside a subset of GPU memory stats
132+ device_stats = DeviceStatsMonitor(
133+ cpu_stats=True,
134+ filter_keys={"cpu_percent", "allocated_bytes.all.current"},
135+ )
136+ trainer = Trainer(callbacks=[device_stats])
137+
116138 """
117139
118- def __init__ (self , cpu_stats : Optional [bool ] = None ) -> None :
140+ def __init__ (self , cpu_stats : Optional [bool ] = None , filter_keys : Optional [ set [ str ]] = None ) -> None :
119141 self ._cpu_stats = cpu_stats
142+ self ._filter_keys = filter_keys
120143
121144 @override
122145 def setup (
@@ -138,6 +161,20 @@ def setup(
138161 f"`DeviceStatsMonitor` cannot log CPU stats as `psutil` is not installed. { str (_PSUTIL_AVAILABLE )} "
139162 )
140163
164+ if self ._filter_keys is not None :
165+ device_stats = trainer .accelerator .get_device_stats (device )
166+ if self ._cpu_stats and device .type != "cpu" :
167+ from lightning .pytorch .accelerators .cpu import get_cpu_stats
168+
169+ device_stats .update (get_cpu_stats ())
170+
171+ unrecognized = self ._filter_keys - device_stats .keys ()
172+ if unrecognized :
173+ rank_zero_warn (
174+ f"`DeviceStatsMonitor` filter_keys contains keys not found in device stats and will be ignored:"
175+ f" { unrecognized } "
176+ )
177+
141178 def _get_and_log_device_stats (self , trainer : "pl.Trainer" , key : str ) -> None :
142179 if not trainer ._logger_connector .should_update_logs :
143180 return
@@ -155,6 +192,9 @@ def _get_and_log_device_stats(self, trainer: "pl.Trainer", key: str) -> None:
155192
156193 device_stats .update (get_cpu_stats ())
157194
195+ if self ._filter_keys is not None :
196+ device_stats = {k : v for k , v in device_stats .items () if k in self ._filter_keys }
197+
158198 for logger in trainer .loggers :
159199 separator = logger .group_separator
160200 prefixed_device_stats = _prefix_metric_keys (device_stats , f"{ self .__class__ .__qualname__ } .{ key } " , separator )
0 commit comments