Skip to content

Commit 9f8949e

Browse files
Fix incorrect reference to model in Transformer class (keras-team#2078)
fix: correct model reference to self in Transformer class Changed model.compute_loss() to self.compute_loss() in train_step and test_step methods to properly reference the class instance.
1 parent 300124b commit 9f8949e

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

examples/audio/transformer_asr.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ def train_step(self, batch):
247247
preds = self([source, dec_input])
248248
one_hot = tf.one_hot(dec_target, depth=self.num_classes)
249249
mask = tf.math.logical_not(tf.math.equal(dec_target, 0))
250-
loss = model.compute_loss(None, one_hot, preds, sample_weight=mask)
250+
loss = self.compute_loss(None, one_hot, preds, sample_weight=mask)
251251
trainable_vars = self.trainable_variables
252252
gradients = tape.gradient(loss, trainable_vars)
253253
self.optimizer.apply_gradients(zip(gradients, trainable_vars))
@@ -262,7 +262,7 @@ def test_step(self, batch):
262262
preds = self([source, dec_input])
263263
one_hot = tf.one_hot(dec_target, depth=self.num_classes)
264264
mask = tf.math.logical_not(tf.math.equal(dec_target, 0))
265-
loss = model.compute_loss(None, one_hot, preds, sample_weight=mask)
265+
loss = self.compute_loss(None, one_hot, preds, sample_weight=mask)
266266
self.loss_metric.update_state(loss)
267267
return {"loss": self.loss_metric.result()}
268268

0 commit comments

Comments
 (0)