@@ -247,7 +247,7 @@ def train_step(self, batch):
247247 preds = self ([source , dec_input ])
248248 one_hot = tf .one_hot (dec_target , depth = self .num_classes )
249249 mask = tf .math .logical_not (tf .math .equal (dec_target , 0 ))
250- loss = self .compute_loss (None , one_hot , preds , sample_weight = mask )
250+ loss = model .compute_loss (None , one_hot , preds , sample_weight = mask )
251251 trainable_vars = self .trainable_variables
252252 gradients = tape .gradient (loss , trainable_vars )
253253 self .optimizer .apply_gradients (zip (gradients , trainable_vars ))
@@ -262,7 +262,7 @@ def test_step(self, batch):
262262 preds = self ([source , dec_input ])
263263 one_hot = tf .one_hot (dec_target , depth = self .num_classes )
264264 mask = tf .math .logical_not (tf .math .equal (dec_target , 0 ))
265- loss = self .compute_loss (None , one_hot , preds , sample_weight = mask )
265+ loss = model .compute_loss (None , one_hot , preds , sample_weight = mask )
266266 self .loss_metric .update_state (loss )
267267 return {"loss" : self .loss_metric .result ()}
268268
0 commit comments