|
321 | 321 | " preds = self([source, dec_input])\n", |
322 | 322 | " one_hot = tf.one_hot(dec_target, depth=self.num_classes)\n", |
323 | 323 | " mask = tf.math.logical_not(tf.math.equal(dec_target, 0))\n", |
324 | | - " loss = model.compute_loss(None, one_hot, preds, sample_weight=mask)\n", |
| 324 | + " loss = self.compute_loss(None, one_hot, preds, sample_weight=mask)\n", |
325 | 325 | " trainable_vars = self.trainable_variables\n", |
326 | 326 | " gradients = tape.gradient(loss, trainable_vars)\n", |
327 | 327 | " self.optimizer.apply_gradients(zip(gradients, trainable_vars))\n", |
|
336 | 336 | " preds = self([source, dec_input])\n", |
337 | 337 | " one_hot = tf.one_hot(dec_target, depth=self.num_classes)\n", |
338 | 338 | " mask = tf.math.logical_not(tf.math.equal(dec_target, 0))\n", |
339 | | - " loss = model.compute_loss(None, one_hot, preds, sample_weight=mask)\n", |
| 339 | + " loss = self.compute_loss(None, one_hot, preds, sample_weight=mask)\n", |
340 | 340 | " self.loss_metric.update_state(loss)\n", |
341 | 341 | " return {\"loss\": self.loss_metric.result()}\n", |
342 | 342 | "\n", |
|
0 commit comments