Skip to content

Commit b246dd4

Browse files
Reduce logging duplication in AptaTransLightning (#241)
fix #224 Added _log_metric helper to avoid repeating the same self.log() params in both _step methods. --------- Co-authored-by: tarun111111 <tarunpuri2544@gmail.com>
1 parent 1335756 commit b246dd4

1 file changed

Lines changed: 21 additions & 56 deletions

File tree

pyaptamer/aptatrans/_model_lightning.py

Lines changed: 21 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,13 @@ def __init__(
7676
self.weight_decay = weight_decay
7777
self.betas = betas
7878

79-
def _step(self, batch: tuple[Tensor, Tensor, Tensor], batch_idx: int) -> Tensor:
79+
def _log_metric(self, name: str, value: Tensor) -> None:
80+
"""Log metric at epoch level with progress bar display."""
81+
self.log(name, value, on_epoch=True, on_step=False, prog_bar=True)
82+
83+
def _step(
84+
self, batch: tuple[Tensor, Tensor, Tensor], batch_idx: int, stage: str
85+
) -> Tensor:
8086
"""Defines a single (mini-batch) step in the training/test loop.
8187
8288
Parameters
@@ -85,6 +91,8 @@ def _step(self, batch: tuple[Tensor, Tensor, Tensor], batch_idx: int) -> Tensor:
8591
A batch of data containing aptamer sequences, protein sequences, and labels.
8692
batch_idx: int
8793
Index of the batch.
94+
stage: str
95+
The stage of the step, either "train" or "test".
8896
8997
Returns
9098
-------
@@ -100,7 +108,10 @@ def _step(self, batch: tuple[Tensor, Tensor, Tensor], batch_idx: int) -> Tensor:
100108
y_pred = (y_hat > 0.5).float()
101109
accuracy = (y_pred == y.float()).float().mean()
102110

103-
return loss, accuracy
111+
self._log_metric(f"{stage}_loss", loss)
112+
self._log_metric(f"{stage}_accuracy", accuracy)
113+
114+
return loss
104115

105116
def training_step(
106117
self, batch: tuple[Tensor, Tensor, Tensor], batch_idx: int
@@ -119,14 +130,7 @@ def training_step(
119130
Tensor
120131
The computed loss for the batch.
121132
"""
122-
loss, accuracy = self._step(batch, batch_idx)
123-
124-
self.log("train_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
125-
self.log(
126-
"train_accuracy", accuracy, on_epoch=True, on_step=False, prog_bar=True
127-
)
128-
129-
return loss
133+
return self._step(batch, batch_idx, "train")
130134

131135
def test_step(self, batch: tuple[Tensor, Tensor, Tensor], batch_idx: int) -> Tensor:
132136
"""Defines a single (mini-batch) step in the test loop.
@@ -143,12 +147,7 @@ def test_step(self, batch: tuple[Tensor, Tensor, Tensor], batch_idx: int) -> Ten
143147
Tensor
144148
The computed loss for the batch.
145149
"""
146-
loss, accuracy = self._step(batch, batch_idx)
147-
148-
self.log("test_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
149-
self.log("test_accuracy", accuracy, on_epoch=True, on_step=False, prog_bar=True)
150-
151-
return loss
150+
return self._step(batch, batch_idx, "test")
152151

153152
def configure_optimizers(self) -> torch.optim.Optimizer:
154153
"""Defines the optimizer to be used during training."""
@@ -256,7 +255,9 @@ def __init__(
256255
self.weight_mlm = weight_mlm
257256
self.weight_ssp = weight_ssp
258257

259-
def _step(self, batch: tuple[Tensor, Tensor, Tensor], batch_idx: int) -> Tensor:
258+
def _step(
259+
self, batch: tuple[Tensor, Tensor, Tensor], batch_idx: int, stage: str
260+
) -> Tensor:
260261
"""Defines a single (mini-batch) step in the training/test loop.
261262
262263
The loss function is a weighted sum of the masked language modeling (MLM)
@@ -268,6 +269,8 @@ def _step(self, batch: tuple[Tensor, Tensor, Tensor], batch_idx: int) -> Tensor:
268269
A batch of data containing aptamer sequences, protein sequences, and labels.
269270
batch_idx: int
270271
Index of the batch.
272+
stage: str
273+
The stage of the step, either "train" or "test".
271274
272275
Returns
273276
-------
@@ -284,46 +287,8 @@ def _step(self, batch: tuple[Tensor, Tensor, Tensor], batch_idx: int) -> Tensor:
284287
loss_ssp = F.cross_entropy(y_ssp_hat.transpose(1, 2), y_ssp.long())
285288
loss = self.weight_mlm * loss_mlm + self.weight_ssp * loss_ssp
286289

287-
return loss
288-
289-
def training_step(
290-
self, batch: tuple[Tensor, Tensor, Tensor], batch_idx: int
291-
) -> Tensor:
292-
"""Defines a single (mini-batch) step in the training loop.
290+
self._log_metric(f"{stage}_loss", loss)
293291

294-
Parameters
295-
----------
296-
batch: tuple[Tensor, Tensor, Tensor]
297-
A batch of data containing aptamer sequences, protein sequences, and labels.
298-
batch_idx: int
299-
Index of the batch.
300-
301-
Returns
302-
-------
303-
Tensor
304-
The computed loss for the batch.
305-
"""
306-
loss = self._step(batch, batch_idx)
307-
self.log("train_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
308-
return loss
309-
310-
def test_step(self, batch: tuple[Tensor, Tensor, Tensor], batch_idx: int) -> Tensor:
311-
"""Defines a single (mini-batch) step in the test loop.
312-
313-
Parameters
314-
----------
315-
batch: tuple[Tensor, Tensor, Tensor]
316-
A batch of data containing aptamer sequences, protein sequences, and labels.
317-
batch_idx: int
318-
Index of the batch.
319-
320-
Returns
321-
-------
322-
Tensor
323-
The computed loss for the batch.
324-
"""
325-
loss = self._step(batch, batch_idx)
326-
self.log("test_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
327292
return loss
328293

329294
def configure_optimizers(self) -> torch.optim.Optimizer:

0 commit comments

Comments
 (0)