diff --git a/src/lightning/fabric/utilities/spike.py b/src/lightning/fabric/utilities/spike.py
index cd2e05309e087..05a7e8f1160d3 100644
--- a/src/lightning/fabric/utilities/spike.py
+++ b/src/lightning/fabric/utilities/spike.py
@@ -165,7 +165,7 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None:
         self.rtol = state_dict.pop("rtol")
         self.bad_batches = state_dict.pop("bad_batches")
         self.exclude_batches_path = state_dict.pop("bad_batches_path")
-        self.running.load_state_dict(state_dict.pop("running"))
+        self.running_mean.load_state_dict(state_dict.pop("running"))
         self.running_mean.base_metric.load_state_dict(state_dict.pop("mean"))
 
 
diff --git a/tests/tests_fabric/utilities/test_spike.py b/tests/tests_fabric/utilities/test_spike.py
index e7f0bedb8e9e9..d553ce3a96585 100644
--- a/tests/tests_fabric/utilities/test_spike.py
+++ b/tests/tests_fabric/utilities/test_spike.py
@@ -81,3 +81,20 @@ def test_fabric_spike_detection_integration(tmp_path, global_rank_spike, num_dev
         spike_value=spike_value,
         should_raise=should_raise,
     )
+
+
+@pytest.mark.skipif(not _TORCHMETRICS_GREATER_EQUAL_1_0_0, reason="requires torchmetrics>=1.0.0")
+def test_spike_detection_state_dict_roundtrip():
+    # Regression: ``load_state_dict`` previously called
+    # ``self.running.load_state_dict(...)`` but the attribute is named
+    # ``self.running_mean`` (set in ``__init__``), so any resume of a
+    # SpikeDetection callback crashed with
+    # ``AttributeError: 'SpikeDetection' object has no attribute 'running'``.
+    src = SpikeDetection()
+    # Use a distinctive value so a restore that silently does nothing is caught too.
+    src.rtol = 3.5
+    dst = SpikeDetection()
+    dst.load_state_dict(src.state_dict())
+    assert dst.rtol == src.rtol
+    assert dst.bad_batches == src.bad_batches
+    assert dst.exclude_batches_path == src.exclude_batches_path