|
| 1 | +# Copyright (c) 2025 lightning-hydra-boilerplate |
| 2 | +# Licensed under the MIT License. |
| 3 | + |
| 4 | +"""Callback for logging metrics during validation and test using PyTorch Lightning.""" |
| 5 | + |
| 6 | +import torch |
| 7 | +from lightning.pytorch import Callback, LightningModule, Trainer |
| 8 | +from lightning.pytorch.utilities.types import STEP_OUTPUT |
| 9 | + |
| 10 | + |
| 11 | +class MetricEvaluator(Callback): |
| 12 | + """Logs custom metrics during validation and test epochs. |
| 13 | +
|
| 14 | + Args: |
| 15 | + metrics (Dict[str, List[torch.nn.Module]]): A dictionary mapping each stage |
| 16 | + ("validation", "test") to a list of metric instances. |
| 17 | +
|
| 18 | + Example: |
| 19 | + { |
| 20 | + "validation": [Accuracy(...), F1Score(...)], |
| 21 | + "test": [Accuracy(...), F1Score(...)] |
| 22 | + } |
| 23 | + """ |
| 24 | + |
| 25 | + def __init__(self, metrics: dict[str, list[torch.nn.Module]]) -> None: |
| 26 | + self.metrics = metrics |
| 27 | + |
| 28 | + def _reset(self, stage: str) -> None: |
| 29 | + """Reset all metrics for a given stage.""" |
| 30 | + for metric in self.metrics.get(stage, []): |
| 31 | + metric.reset() |
| 32 | + |
| 33 | + def _update(self, stage: str, preds: torch.Tensor, targets: torch.Tensor) -> None: |
| 34 | + """Update all metrics for a given stage with predictions and targets.""" |
| 35 | + for metric in self.metrics.get(stage, []): |
| 36 | + metric.to(preds.device) |
| 37 | + metric.update(preds, targets) |
| 38 | + |
| 39 | + def _log(self, stage: str, pl_module: LightningModule) -> None: |
| 40 | + """Compute and log all metrics for a given stage.""" |
| 41 | + # prefix = "val" if stage == "validation" else "test" |
| 42 | + # for metric in self.metrics.get(stage, []): |
| 43 | + # name = metric.__class__.__name__.lower() |
| 44 | + # value = metric.compute() |
| 45 | + # pl_module.log(f"{prefix}_{name}", value, prog_bar=True) |
| 46 | + |
| 47 | + def on_validation_epoch_start(self, _trainer: Trainer, _pl_module: LightningModule) -> None: |
| 48 | + """Reset metrics at the start of validation epoch.""" |
| 49 | + self._reset("validation") |
| 50 | + |
| 51 | + def on_validation_batch_end( |
| 52 | + self, |
| 53 | + _trainer: Trainer, |
| 54 | + pl_module: LightningModule, |
| 55 | + _outputs: STEP_OUTPUT, |
| 56 | + batch: dict, |
| 57 | + _batch_idx: int, |
| 58 | + _dataloader_idx: int = 0, |
| 59 | + ) -> None: |
| 60 | + """Update validation metrics with each batch.""" |
| 61 | + x, y = batch[0], batch[1] |
| 62 | + #preds = pl_module(x).argmax(dim=1) |
| 63 | + #self._update("validation", preds, y) |
| 64 | + |
| 65 | + def on_validation_epoch_end(self, _trainer: Trainer, pl_module: LightningModule) -> None: |
| 66 | + """Compute and log validation metrics at end of epoch.""" |
| 67 | + self._log("validation", pl_module) |
| 68 | + |
| 69 | + def on_test_epoch_start(self, _trainer: Trainer, _pl_module: LightningModule) -> None: |
| 70 | + """Reset metrics at the start of test epoch.""" |
| 71 | + self._reset("test") |
| 72 | + |
| 73 | + def on_test_batch_end( |
| 74 | + self, |
| 75 | + _trainer: Trainer, |
| 76 | + pl_module: LightningModule, |
| 77 | + _outputs: STEP_OUTPUT, |
| 78 | + batch: dict, |
| 79 | + _batch_idx: int, |
| 80 | + _dataloader_idx: int = 0, |
| 81 | + ) -> None: |
| 82 | + """Update test metrics with each batch.""" |
| 83 | + # x, y = batch["image"], batch["label"] |
| 84 | + # preds = pl_module(x).argmax(dim=1) |
| 85 | + # self._update("test", preds, y) |
| 86 | + |
| 87 | + def on_test_epoch_end(self, _trainer: Trainer, pl_module: LightningModule) -> None: |
| 88 | + """Compute and log test metrics at end of epoch.""" |
| 89 | + self._log("test", pl_module) |
0 commit comments