Skip to content

Commit ea72bab

Browse files
committed
🛸 Fix (#389).
1 parent 28e0128 commit ea72bab

File tree

1 file changed

+2
-5
lines changed

1 file changed

+2
-5
lines changed

tensorflow_tts/trainers/base_trainer.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -213,8 +213,6 @@ def __init__(
213213
self._generator_gradient_accumulator.reset()
214214
self._discriminator_gradient_accumulator.reset()
215215

216-
217-
218216
def init_train_eval_metrics(self, list_metrics_name):
219217
with self._strategy.scope():
220218
super().init_train_eval_metrics(list_metrics_name)
@@ -706,7 +704,6 @@ def __init__(
706704
self._gradient_accumulator = GradientAccumulator()
707705
self._gradient_accumulator.reset()
708706

709-
710707
def init_train_eval_metrics(self, list_metrics_name):
711708
with self._strategy.scope():
712709
super().init_train_eval_metrics(list_metrics_name)
@@ -833,7 +830,7 @@ def _one_step_forward_per_replica(self, batch):
833830
if self.config["gradient_accumulation_steps"] == 1:
834831
gradients, per_replica_losses = self._calculate_gradient_per_batch(batch)
835832
self._optimizer.apply_gradients(
836-
zip(gradients, self._trainable_variables)
833+
zip(gradients, self._trainable_variables), 1.0
837834
)
838835
else:
839836
# gradient acummulation here.
@@ -856,7 +853,7 @@ def _one_step_forward_per_replica(self, batch):
856853

857854
gradients = self._gradient_accumulator.gradients
858855
self._optimizer.apply_gradients(
859-
zip(gradients, self._trainable_variables)
856+
zip(gradients, self._trainable_variables), 1.0
860857
)
861858
self._gradient_accumulator.reset()
862859

0 commit comments

Comments
 (0)