Skip to content

Commit 58e96f1

Browse files
Update transformer_asr.md
1 parent 9f8949e commit 58e96f1

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

examples/audio/md/transformer_asr.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ class Transformer(keras.Model):
262262
preds = self([source, dec_input])
263263
one_hot = tf.one_hot(dec_target, depth=self.num_classes)
264264
mask = tf.math.logical_not(tf.math.equal(dec_target, 0))
265-
loss = model.compute_loss(None, one_hot, preds, sample_weight=mask)
265+
loss = self.compute_loss(None, one_hot, preds, sample_weight=mask)
266266
trainable_vars = self.trainable_variables
267267
gradients = tape.gradient(loss, trainable_vars)
268268
self.optimizer.apply_gradients(zip(gradients, trainable_vars))
@@ -277,7 +277,7 @@ class Transformer(keras.Model):
277277
preds = self([source, dec_input])
278278
one_hot = tf.one_hot(dec_target, depth=self.num_classes)
279279
mask = tf.math.logical_not(tf.math.equal(dec_target, 0))
280-
loss = model.compute_loss(None, one_hot, preds, sample_weight=mask)
280+
loss = self.compute_loss(None, one_hot, preds, sample_weight=mask)
281281
self.loss_metric.update_state(loss)
282282
return {"loss": self.loss_metric.result()}
283283

0 commit comments

Comments
 (0)