@@ -76,7 +76,13 @@ def __init__(
7676 self .weight_decay = weight_decay
7777 self .betas = betas
7878
79- def _step (self , batch : tuple [Tensor , Tensor , Tensor ], batch_idx : int ) -> Tensor :
79+ def _log_metric (self , name : str , value : Tensor ) -> None :
80+ """Log metric at epoch level with progress bar display."""
81+ self .log (name , value , on_epoch = True , on_step = False , prog_bar = True )
82+
83+ def _step (
84+ self , batch : tuple [Tensor , Tensor , Tensor ], batch_idx : int , stage : str
85+ ) -> Tensor :
8086 """Defines a single (mini-batch) step in the training/test loop.
8187
8288 Parameters
@@ -85,6 +91,8 @@ def _step(self, batch: tuple[Tensor, Tensor, Tensor], batch_idx: int) -> Tensor:
8591 A batch of data containing aptamer sequences, protein sequences, and labels.
8692 batch_idx: int
8793 Index of the batch.
94+ stage: str
95+ The stage of the step, either "train" or "test".
8896
8997 Returns
9098 -------
@@ -100,7 +108,10 @@ def _step(self, batch: tuple[Tensor, Tensor, Tensor], batch_idx: int) -> Tensor:
100108 y_pred = (y_hat > 0.5 ).float ()
101109 accuracy = (y_pred == y .float ()).float ().mean ()
102110
103- return loss , accuracy
111+ self ._log_metric (f"{ stage } _loss" , loss )
112+ self ._log_metric (f"{ stage } _accuracy" , accuracy )
113+
114+ return loss
104115
105116 def training_step (
106117 self , batch : tuple [Tensor , Tensor , Tensor ], batch_idx : int
@@ -119,14 +130,7 @@ def training_step(
119130 Tensor
120131 The computed loss for the batch.
121132 """
122- loss , accuracy = self ._step (batch , batch_idx )
123-
124- self .log ("train_loss" , loss , on_epoch = True , on_step = False , prog_bar = True )
125- self .log (
126- "train_accuracy" , accuracy , on_epoch = True , on_step = False , prog_bar = True
127- )
128-
129- return loss
133+ return self ._step (batch , batch_idx , "train" )
130134
131135 def test_step (self , batch : tuple [Tensor , Tensor , Tensor ], batch_idx : int ) -> Tensor :
132136 """Defines a single (mini-batch) step in the test loop.
@@ -143,12 +147,7 @@ def test_step(self, batch: tuple[Tensor, Tensor, Tensor], batch_idx: int) -> Ten
143147 Tensor
144148 The computed loss for the batch.
145149 """
146- loss , accuracy = self ._step (batch , batch_idx )
147-
148- self .log ("test_loss" , loss , on_epoch = True , on_step = False , prog_bar = True )
149- self .log ("test_accuracy" , accuracy , on_epoch = True , on_step = False , prog_bar = True )
150-
151- return loss
150+ return self ._step (batch , batch_idx , "test" )
152151
153152 def configure_optimizers (self ) -> torch .optim .Optimizer :
154153 """Defines the optimizer to be used during training."""
@@ -256,7 +255,9 @@ def __init__(
256255 self .weight_mlm = weight_mlm
257256 self .weight_ssp = weight_ssp
258257
259- def _step (self , batch : tuple [Tensor , Tensor , Tensor ], batch_idx : int ) -> Tensor :
258+ def _step (
259+ self , batch : tuple [Tensor , Tensor , Tensor ], batch_idx : int , stage : str
260+ ) -> Tensor :
260261 """Defines a single (mini-batch) step in the training/test loop.
261262
262263 The loss function is a weighted sum of the masked language modeling (MLM)
@@ -268,6 +269,8 @@ def _step(self, batch: tuple[Tensor, Tensor, Tensor], batch_idx: int) -> Tensor:
268269 A batch of data containing aptamer sequences, protein sequences, and labels.
269270 batch_idx: int
270271 Index of the batch.
272+ stage: str
273+ The stage of the step, either "train" or "test".
271274
272275 Returns
273276 -------
@@ -284,46 +287,8 @@ def _step(self, batch: tuple[Tensor, Tensor, Tensor], batch_idx: int) -> Tensor:
284287 loss_ssp = F .cross_entropy (y_ssp_hat .transpose (1 , 2 ), y_ssp .long ())
285288 loss = self .weight_mlm * loss_mlm + self .weight_ssp * loss_ssp
286289
287- return loss
288-
289- def training_step (
290- self , batch : tuple [Tensor , Tensor , Tensor ], batch_idx : int
291- ) -> Tensor :
292- """Defines a single (mini-batch) step in the training loop.
290+ self ._log_metric (f"{ stage } _loss" , loss )
293291
294- Parameters
295- ----------
296- batch: tuple[Tensor, Tensor, Tensor]
297- A batch of data containing aptamer sequences, protein sequences, and labels.
298- batch_idx: int
299- Index of the batch.
300-
301- Returns
302- -------
303- Tensor
304- The computed loss for the batch.
305- """
306- loss = self ._step (batch , batch_idx )
307- self .log ("train_loss" , loss , on_epoch = True , on_step = False , prog_bar = True )
308- return loss
309-
310- def test_step (self , batch : tuple [Tensor , Tensor , Tensor ], batch_idx : int ) -> Tensor :
311- """Defines a single (mini-batch) step in the test loop.
312-
313- Parameters
314- ----------
315- batch: tuple[Tensor, Tensor, Tensor]
316- A batch of data containing aptamer sequences, protein sequences, and labels.
317- batch_idx: int
318- Index of the batch.
319-
320- Returns
321- -------
322- Tensor
323- The computed loss for the batch.
324- """
325- loss = self ._step (batch , batch_idx )
326- self .log ("test_loss" , loss , on_epoch = True , on_step = False , prog_bar = True )
327292 return loss
328293
329294 def configure_optimizers (self ) -> torch .optim .Optimizer :
0 commit comments