@@ -213,8 +213,6 @@ def __init__(
213
213
self ._generator_gradient_accumulator .reset ()
214
214
self ._discriminator_gradient_accumulator .reset ()
215
215
216
-
217
-
218
216
def init_train_eval_metrics (self , list_metrics_name ):
219
217
with self ._strategy .scope ():
220
218
super ().init_train_eval_metrics (list_metrics_name )
@@ -706,7 +704,6 @@ def __init__(
706
704
self ._gradient_accumulator = GradientAccumulator ()
707
705
self ._gradient_accumulator .reset ()
708
706
709
-
710
707
def init_train_eval_metrics (self , list_metrics_name ):
711
708
with self ._strategy .scope ():
712
709
super ().init_train_eval_metrics (list_metrics_name )
@@ -833,7 +830,7 @@ def _one_step_forward_per_replica(self, batch):
833
830
if self .config ["gradient_accumulation_steps" ] == 1 :
834
831
gradients , per_replica_losses = self ._calculate_gradient_per_batch (batch )
835
832
self ._optimizer .apply_gradients (
836
- zip (gradients , self ._trainable_variables )
833
+ zip (gradients , self ._trainable_variables ), 1.0
837
834
)
838
835
else :
839
836
# gradient acummulation here.
@@ -856,7 +853,7 @@ def _one_step_forward_per_replica(self, batch):
856
853
857
854
gradients = self ._gradient_accumulator .gradients
858
855
self ._optimizer .apply_gradients (
859
- zip (gradients , self ._trainable_variables )
856
+ zip (gradients , self ._trainable_variables ), 1.0
860
857
)
861
858
self ._gradient_accumulator .reset ()
862
859
0 commit comments