Skip to content

Commit 06ebcbf

Browse files
authored
feat(adapter/nemo): add EventLoggingCallback for lifecycle monitoring (#34)
Sample Log: https://gist.github.com/kkkapu/09eb0d8522dd6ee992c7833cb630714c
1 parent 22d8691 commit 06ebcbf

File tree

2 files changed

+245
-0
lines changed

2 files changed

+245
-0
lines changed
Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Any, Dict
16+
17+
import lightning.pytorch as pl
18+
import torch
19+
from lightning.pytorch import callbacks as pl_callbacks
20+
from lightning.pytorch.utilities.types import STEP_OUTPUT
21+
from typing_extensions import override
22+
23+
from ml_flashpoint.core.mlf_logging import get_logger
24+
25+
_LOGGER = get_logger(__name__)
26+
27+
28+
class EventLoggingCallback(pl_callbacks.Callback):
    """Log an info-level marker for every key PyTorch Lightning lifecycle hook.

    Each overridden hook emits a single ``event=<hook_name>`` line through the
    module logger, making it easy to trace the trainer's execution flow across
    training, validation, testing, checkpointing, and optimizer steps.
    """

    def _log_event(self, hook_name: str) -> None:
        # Single choke point for all hooks so the log format stays uniform.
        _LOGGER.info(f"event={hook_name}")

    @override
    def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Record that training has begun."""
        self._log_event("on_train_start")

    @override
    def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Record that training has finished."""
        self._log_event("on_train_end")

    @override
    def on_train_batch_start(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int
    ) -> None:
        """Record the start of a training batch."""
        self._log_event("on_train_batch_start")

    @override
    def on_train_batch_end(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int
    ) -> None:
        """Record the end of a training batch."""
        self._log_event("on_train_batch_end")

    @override
    def on_validation_batch_start(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Record the start of a validation batch."""
        self._log_event("on_validation_batch_start")

    @override
    def on_validation_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        outputs: STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Record the end of a validation batch."""
        self._log_event("on_validation_batch_end")

    @override
    def on_test_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Record the start of a test epoch."""
        self._log_event("on_test_epoch_start")

    @override
    def on_test_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Record the end of a test epoch."""
        self._log_event("on_test_epoch_end")

    @override
    def on_test_batch_start(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Record the start of a test batch."""
        self._log_event("on_test_batch_start")

    @override
    def on_test_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        outputs: STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Record the end of a test batch."""
        self._log_event("on_test_batch_end")

    @override
    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Record that callback state is being restored from a checkpoint.

        NOTE: unlike the other hooks, this one does not receive the trainer or
        the LightningModule.
        """
        self._log_event("load_state_dict")

    @override
    def on_save_checkpoint(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
    ) -> None:
        """Record that a checkpoint is being saved."""
        self._log_event("on_save_checkpoint")

    @override
    def on_load_checkpoint(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
    ) -> None:
        """Record that a model checkpoint is being loaded."""
        self._log_event("on_load_checkpoint")

    @override
    def on_before_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", loss: torch.Tensor) -> None:
        """Record the point just before ``loss.backward()``."""
        self._log_event("on_before_backward")

    @override
    def on_after_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Record the point after ``loss.backward()`` and before the optimizer step."""
        self._log_event("on_after_backward")

    @override
    def on_before_optimizer_step(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", optimizer: torch.optim.Optimizer
    ) -> None:
        """Record the point just before ``optimizer.step()``."""
        self._log_event("on_before_optimizer_step")

    @override
    def on_before_zero_grad(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", optimizer: torch.optim.Optimizer
    ) -> None:
        """Record the point just before ``optimizer.zero_grad()``."""
        self._log_event("on_before_zero_grad")
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from unittest.mock import MagicMock
16+
17+
import lightning.pytorch as pl
18+
import pytest
19+
import torch
20+
21+
from ml_flashpoint.adapter.nemo.event_logging_callback import EventLoggingCallback
22+
23+
# Exhaustive list of all hooks implemented in EventLoggingCallback.
# Format: (method_name, extra_kwargs, expected_event_string).
# The expected event string always matches the hook name, so only the extra
# keyword arguments each hook requires need to be spelled out per entry.
_HOOK_EXTRA_KWARGS = {
    "on_train_start": {},
    "on_train_end": {},
    "on_train_batch_start": {"batch": None, "batch_idx": 0},
    "on_train_batch_end": {"outputs": None, "batch": None, "batch_idx": 0},
    "on_validation_batch_start": {"batch": None, "batch_idx": 0, "dataloader_idx": 0},
    "on_validation_batch_end": {"outputs": None, "batch": None, "batch_idx": 0, "dataloader_idx": 0},
    "on_test_epoch_start": {},
    "on_test_epoch_end": {},
    "on_test_batch_start": {"batch": None, "batch_idx": 0, "dataloader_idx": 0},
    "on_test_batch_end": {"outputs": None, "batch": None, "batch_idx": 0, "dataloader_idx": 0},
    "load_state_dict": {"state_dict": {}},
    "on_save_checkpoint": {"checkpoint": {}},
    "on_load_checkpoint": {"checkpoint": {}},
    "on_before_backward": {"loss": torch.tensor(0.0)},
    "on_after_backward": {},
    "on_before_optimizer_step": {"optimizer": MagicMock(spec=torch.optim.Optimizer)},
    "on_before_zero_grad": {"optimizer": MagicMock(spec=torch.optim.Optimizer)},
}
HOOKS_TO_TEST = [(name, kwargs, name) for name, kwargs in _HOOK_EXTRA_KWARGS.items()]
48+
49+
50+
def test_is_subtype_of_pytorch_lightning_callback():
    """The callback must subclass Lightning's Callback so a Trainer accepts it."""
    base_callback_cls = pl.callbacks.Callback
    assert issubclass(EventLoggingCallback, base_callback_cls)
53+
54+
55+
@pytest.mark.parametrize("hook_name, kwargs, expected_event", HOOKS_TO_TEST)
def test_event_logging_hooks_log_correctly(mocker, hook_name, kwargs, expected_event):
    """Every lifecycle hook must emit exactly one ``event=<name>`` log line."""
    # Given: a callback whose module logger has been replaced by a mock.
    mock_logger = mocker.patch("ml_flashpoint.adapter.nemo.event_logging_callback._LOGGER")
    callback = EventLoggingCallback()

    # Mock Trainer and LightningModule as required by the PyTorch Lightning API.
    trainer = mocker.MagicMock(spec=pl.Trainer)
    pl_module = mocker.MagicMock(spec=pl.LightningModule)

    # When: invoke the hook under test, fetched dynamically by name.
    # load_state_dict is the one hook whose signature does not start with
    # (trainer, pl_module).
    hook = getattr(callback, hook_name)
    if hook_name == "load_state_dict":
        hook(**kwargs)
    else:
        hook(trainer=trainer, pl_module=pl_module, **kwargs)

    # Then: exactly one log call, formatted as _log_event emits it.
    mock_logger.info.assert_called_once_with(f"event={expected_event}")

0 commit comments

Comments
 (0)