Commit 539539b

Fix SpikeDetection.load_state_dict referencing nonexistent self.running

`SpikeDetection.__init__` stores the running-mean metric on `self.running_mean` (line 57). `SpikeDetection.state_dict` saves it under the key `"running"` via `self.running_mean.state_dict()` (line 156). `SpikeDetection.load_state_dict` then tries to restore it with `self.running.load_state_dict(...)`, but `self.running` is never defined, so any resume of a SpikeDetection callback state crashes with:

    AttributeError: 'SpikeDetection' object has no attribute 'running'

The line right below it restores through `self.running_mean` (`self.running_mean.base_metric.load_state_dict(...)`), which confirms that `self.running` is a typo of `self.running_mean`, not a deliberately different attribute.

Adds a regression test that round-trips state_dict / load_state_dict on a fresh SpikeDetection.
1 parent 35e56ef commit 539539b
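
The failure mode described in the commit message can be reproduced without Lightning or torchmetrics. The following is a minimal sketch, not the library's code: `RunningMeanStub`, `BuggyDetector`, `FixedDetector`, and `try_resume` are hypothetical names invented for illustration, and the stub only mimics the `state_dict` / `load_state_dict` pairing.

```python
from typing import Any


class RunningMeanStub:
    """Hypothetical stand-in for the running-mean metric (not the torchmetrics class)."""

    def __init__(self) -> None:
        self.total = 0.0
        self.count = 0

    def state_dict(self) -> dict[str, Any]:
        return {"total": self.total, "count": self.count}

    def load_state_dict(self, state: dict[str, Any]) -> None:
        self.total = state["total"]
        self.count = state["count"]


class BuggyDetector:
    """Saves under the key "running" but restores through a misspelled attribute."""

    def __init__(self) -> None:
        self.running_mean = RunningMeanStub()

    def state_dict(self) -> dict[str, Any]:
        return {"running": self.running_mean.state_dict()}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        # Bug: ``self.running`` is never assigned in ``__init__``.
        self.running.load_state_dict(state_dict.pop("running"))


class FixedDetector(BuggyDetector):
    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        # Fix: restore through the attribute that ``__init__`` actually defines.
        self.running_mean.load_state_dict(state_dict.pop("running"))


def try_resume(detector_cls: type) -> str:
    """Round-trip state between two fresh instances and report what happened."""
    src, dst = detector_cls(), detector_cls()
    src.running_mean.total, src.running_mean.count = 6.0, 3
    try:
        dst.load_state_dict(src.state_dict())
    except AttributeError as err:
        return f"AttributeError: {err}"
    return f"restored total={dst.running_mean.total}, count={dst.running_mean.count}"


print(try_resume(BuggyDetector))  # reports the missing 'running' attribute
print(try_resume(FixedDetector))  # reports the restored values
```

The save path never touches the misspelled name, so the error only surfaces on resume, which is why a round-trip test is enough to catch it.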

2 files changed

Lines changed: 13 additions & 1 deletion

src/lightning/fabric/utilities/spike.py

Lines changed: 1 addition & 1 deletion
@@ -165,7 +165,7 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None:
         self.rtol = state_dict.pop("rtol")
         self.bad_batches = state_dict.pop("bad_batches")
         self.exclude_batches_path = state_dict.pop("bad_batches_path")
-        self.running.load_state_dict(state_dict.pop("running"))
+        self.running_mean.load_state_dict(state_dict.pop("running"))
         self.running_mean.base_metric.load_state_dict(state_dict.pop("mean"))
 

tests/tests_fabric/utilities/test_spike.py

Lines changed: 12 additions & 0 deletions
@@ -81,3 +81,15 @@ def test_fabric_spike_detection_integration(tmp_path, global_rank_spike, num_dev
         spike_value=spike_value,
         should_raise=should_raise,
     )
+
+
+@pytest.mark.skipif(not _TORCHMETRICS_GREATER_EQUAL_1_0_0, reason="requires torchmetrics>=1.0.0")
+def test_spike_detection_state_dict_roundtrip():
+    # Regression: ``load_state_dict`` previously called
+    # ``self.running.load_state_dict(...)`` but the attribute is named
+    # ``self.running_mean`` (set in ``__init__``), so any resume of a
+    # SpikeDetection callback crashed with
+    # ``AttributeError: 'SpikeDetection' object has no attribute 'running'``.
+    src = SpikeDetection()
+    dst = SpikeDetection()
+    dst.load_state_dict(src.state_dict())
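
The regression test above only checks that the round trip does not raise. A stricter variant could also compare the restored state, with one caveat visible in the first hunk: `load_state_dict` reads its input with `state_dict.pop(...)`, so the dict passed in is consumed. The sketch below illustrates that pattern on a hypothetical stand-in; `DetectorSketch`, `roundtrip_matches`, and the default `rtol` value are assumptions for illustration, not Lightning's implementation.

```python
import copy
from typing import Any


class DetectorSketch:
    """Hypothetical stand-in whose ``load_state_dict`` pops keys, as in the diff."""

    def __init__(self, rtol: float = 2.0) -> None:  # default value is an assumption
        self.rtol = rtol
        self.bad_batches: list[int] = []

    def state_dict(self) -> dict[str, Any]:
        return {"rtol": self.rtol, "bad_batches": list(self.bad_batches)}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        # ``pop`` removes each key from the dict that the caller passed in.
        self.rtol = state_dict.pop("rtol")
        self.bad_batches = state_dict.pop("bad_batches")


def roundtrip_matches(src: DetectorSketch, dst: DetectorSketch) -> bool:
    expected = src.state_dict()
    # Load a deep copy so ``expected`` survives for the comparison.
    dst.load_state_dict(copy.deepcopy(expected))
    return dst.state_dict() == expected


src = DetectorSketch(rtol=1.5)
src.bad_batches.append(7)
dst = DetectorSketch()
print(roundtrip_matches(src, dst))

state = src.state_dict()
DetectorSketch().load_state_dict(state)
print(state)  # the consumed keys are no longer present
```

Comparing against a copy keeps the check independent of whether the loader mutates its argument.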
