@@ -661,12 +661,6 @@ def eval_batch(
661661 assert isinstance (loss_tensor , paddle .Tensor ), (
662662 "Currently, loss_fn should obtain Paddle.Tensor dtype"
663663 )
664- with paddle .amp .auto_cast (enable = False ):
665- if (
666- self .accumulate_steps > 1
667- and not self ._delay_scale_loss
668- ):
669- loss_tensor = loss_tensor / self .accumulate_steps
670664 if self .total_loss is None :
671665 self .total_loss = []
672666 # when self.total_loss length is less than idx, append a new tensor
@@ -691,17 +685,14 @@ def eval_batch(
691685 return_micro_batch_loss = False
692686 for idx in range (len (self ._layers ._loss_fn )):
693687 self .total_loss [idx ] = paddle .to_tensor (self .total_loss [idx ])
694- if not return_micro_batch_loss :
695- # TODO(shenliang03): it will use mean/sum to calculate loss
696- tmp = paddle .zeros_like (self .total_loss [idx ][0 ])
697- for loss in self .total_loss [idx ]:
698- tmp += loss .detach ()
699- if not self ._delay_scale_loss :
700- losses .append (tmp )
701- else :
702- losses .append (tmp / self .accumulate_steps )
703- else :
704- losses .append (self .total_loss [idx ].detach ())
688+ # if not return_micro_batch_loss:
689+ # TODO(shenliang03): it will use mean/sum to calculate loss
690+ tmp = paddle .zeros_like (self .total_loss [idx ][0 ])
691+ for loss in self .total_loss [idx ]:
692+ tmp += loss .detach ()
693+ losses .append (tmp / self .accumulate_steps )
694+ # else:
695+ # losses.append(self.total_loss[idx].detach())
705696 res = losses [0 ] if len (losses ) == 1 else losses
706697 else :
707698 res = output_list
0 commit comments