
Commit 4871642

Added loss function from 10.7.5 to Seq2Seq
1 parent 23d7a5a commit 4871642

File tree: d2l/jax.py · d2l/mxnet.py · d2l/tensorflow.py · d2l/torch.py

4 files changed: 26 additions, 0 deletions


d2l/jax.py

Lines changed: 11 additions & 0 deletions
@@ -1199,6 +1199,17 @@ def validation_step(self, params, batch, state):
     def configure_optimizers(self):
         # Adam optimizer is used here
         return optax.adam(learning_rate=self.lr)
+
+    @partial(jax.jit, static_argnums=(0, 5))
+    def loss(self, params, X, Y, state, averaged=False):
+        Y_hat = state.apply_fn({'params': params}, *X,
+                               rngs={'dropout': state.dropout_rng})
+        Y_hat = Y_hat.reshape((-1, Y_hat.shape[-1]))
+        Y = Y.reshape((-1,))
+        fn = optax.softmax_cross_entropy_with_integer_labels
+        l = fn(Y_hat, Y)
+        mask = (Y.reshape(-1) != self.tgt_pad).astype(jnp.float32)
+        return (l * mask).sum() / mask.sum(), {}

 def bleu(pred_seq, label_seq, k):
     """Compute the BLEU.

d2l/mxnet.py

Lines changed: 5 additions & 0 deletions
@@ -1025,6 +1025,11 @@ def configure_optimizers(self):
         # Adam optimizer is used here
         return gluon.Trainer(self.parameters(), 'adam',
                              {'learning_rate': self.lr})
+
+    def loss(self, Y_hat, Y):
+        l = super(Seq2Seq, self).loss(Y_hat, Y, averaged=False)
+        mask = (Y.reshape(-1) != self.tgt_pad).astype(np.float32)
+        return (l * mask).sum() / mask.sum()

 def bleu(pred_seq, label_seq, k):
     """Compute the BLEU.

d2l/tensorflow.py

Lines changed: 5 additions & 0 deletions
@@ -979,6 +979,11 @@ def configure_optimizers(self):
         # Adam optimizer is used here
         return tf.keras.optimizers.Adam(learning_rate=self.lr)

+    def loss(self, Y_hat, Y):
+        l = super(Seq2Seq, self).loss(Y_hat, Y, averaged=False)
+        mask = tf.cast(tf.reshape(Y, -1) != self.tgt_pad, tf.float32)
+        return tf.reduce_sum(l * mask) / tf.reduce_sum(mask)
+
 def bleu(pred_seq, label_seq, k):
     """Compute the BLEU.
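
As in the other backends, the unreduced per-token loss is reweighted by a padding mask before averaging. A minimal standalone sketch with hypothetical toy values (tgt_pad = 0 is an assumption, and the Keras sparse cross-entropy stands in for the base class's loss):

    import tensorflow as tf

    tgt_pad = 0  # assumed padding token id

    Y_hat = tf.constant([[2.0, 0.5, 0.1],
                         [0.2, 1.5, 0.3],
                         [1.0, 1.0, 1.0]])
    Y = tf.constant([1, 2, 0])  # last position is padding

    # One cross-entropy value per position, from unnormalized logits
    l = tf.keras.losses.sparse_categorical_crossentropy(Y, Y_hat, from_logits=True)
    # Zero out padded positions, then average over real tokens only
    mask = tf.cast(tf.reshape(Y, -1) != tgt_pad, tf.float32)
    print(tf.reduce_sum(l * mask) / tf.reduce_sum(mask))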

d2l/torch.py

Lines changed: 5 additions & 0 deletions
@@ -1026,6 +1026,11 @@ def configure_optimizers(self):
         # Adam optimizer is used here
         return torch.optim.Adam(self.parameters(), lr=self.lr)

+    def loss(self, Y_hat, Y):
+        l = super(Seq2Seq, self).loss(Y_hat, Y, averaged=False)
+        mask = (Y.reshape(-1) != self.tgt_pad).type(torch.float32)
+        return (l * mask).sum() / mask.sum()
+
 def bleu(pred_seq, label_seq, k):
     """Compute the BLEU.
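
The PyTorch version follows the same pattern. A minimal sketch with hypothetical toy values (tgt_pad = 0 assumed), plus the built-in alternative: F.cross_entropy's ignore_index yields the same mean over non-padding tokens without an explicit mask:

    import torch
    import torch.nn.functional as F

    tgt_pad = 0  # assumed padding token id

    Y_hat = torch.tensor([[2.0, 0.5, 0.1],
                          [0.2, 1.5, 0.3],
                          [1.0, 1.0, 1.0]])
    Y = torch.tensor([1, 2, 0])  # last position is padding

    # Per-token loss, kept unreduced so the mask can zero out padding
    l = F.cross_entropy(Y_hat, Y, reduction='none')
    mask = (Y.reshape(-1) != tgt_pad).type(torch.float32)
    print((l * mask).sum() / mask.sum())

    # Built-in equivalent: mean over non-ignored targets
    print(F.cross_entropy(Y_hat, Y, ignore_index=tgt_pad))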
