From 487164257a1dbae37ba49c1a0ee27dc4933d480e Mon Sep 17 00:00:00 2001 From: nickmcgreivy Date: Sun, 10 Aug 2025 21:06:02 -0400 Subject: [PATCH 1/3] Added loss function from 10.7.5 to Seq2Seq --- d2l/jax.py | 11 +++++++++++ d2l/mxnet.py | 5 +++++ d2l/tensorflow.py | 5 +++++ d2l/torch.py | 5 +++++ 4 files changed, 26 insertions(+) diff --git a/d2l/jax.py b/d2l/jax.py index b1210edc9c..fdde1bb247 100644 --- a/d2l/jax.py +++ b/d2l/jax.py @@ -1199,6 +1199,17 @@ def validation_step(self, params, batch, state): def configure_optimizers(self): # Adam optimizer is used here return optax.adam(learning_rate=self.lr) + + @partial(jax.jit, static_argnums=(0, 5)) + def loss(self, params, X, Y, state, averaged=False): + Y_hat = state.apply_fn({'params': params}, *X, + rngs={'dropout': state.dropout_rng}) + Y_hat = Y_hat.reshape((-1, Y_hat.shape[-1])) + Y = Y.reshape((-1,)) + fn = optax.softmax_cross_entropy_with_integer_labels + l = fn(Y_hat, Y) + mask = (Y.reshape(-1) != self.tgt_pad).astype(jnp.float32) + return (l * mask).sum() / mask.sum(), {} def bleu(pred_seq, label_seq, k): """Compute the BLEU. diff --git a/d2l/mxnet.py b/d2l/mxnet.py index 8e7d7673a3..f51fa72de6 100644 --- a/d2l/mxnet.py +++ b/d2l/mxnet.py @@ -1025,6 +1025,11 @@ def configure_optimizers(self): # Adam optimizer is used here return gluon.Trainer(self.parameters(), 'adam', {'learning_rate': self.lr}) + + def loss(self, Y_hat, Y): + l = super(Seq2Seq, self).loss(Y_hat, Y, averaged=False) + mask = (Y.reshape(-1) != self.tgt_pad).astype(np.float32) + return (l * mask).sum() / mask.sum() def bleu(pred_seq, label_seq, k): """Compute the BLEU. diff --git a/d2l/tensorflow.py b/d2l/tensorflow.py index fd9ca23fda..7e14e59f04 100644 --- a/d2l/tensorflow.py +++ b/d2l/tensorflow.py @@ -979,6 +979,11 @@ def configure_optimizers(self): # Adam optimizer is used here return tf.keras.optimizers.Adam(learning_rate=self.lr) + def loss(self, Y_hat, Y): + l = super(Seq2Seq, self).loss(Y_hat, Y, averaged=False) + mask = tf.cast(tf.reshape(Y, -1) != self.tgt_pad, tf.float32) + return tf.reduce_sum(l * mask) / tf.reduce_sum(mask) + def bleu(pred_seq, label_seq, k): """Compute the BLEU. diff --git a/d2l/torch.py b/d2l/torch.py index 84ce7da901..00ff5c308e 100644 --- a/d2l/torch.py +++ b/d2l/torch.py @@ -1026,6 +1026,11 @@ def configure_optimizers(self): # Adam optimizer is used here return torch.optim.Adam(self.parameters(), lr=self.lr) + def loss(self, Y_hat, Y): + l = super(Seq2Seq, self).loss(Y_hat, Y, averaged=False) + mask = (Y.reshape(-1) != self.tgt_pad).type(torch.float32) + return (l * mask).sum() / mask.sum() + def bleu(pred_seq, label_seq, k): """Compute the BLEU. From aff144173a42742921f528d20a2bc93c13b4f209 Mon Sep 17 00:00:00 2001 From: nickmcgreivy Date: Sun, 10 Aug 2025 21:20:16 -0400 Subject: [PATCH 2/3] Fixed minor bug in machine translation tokenization algorithm. --- chapter_recurrent-modern/machine-translation-and-dataset.md | 2 +- d2l/jax.py | 2 +- d2l/mxnet.py | 2 +- d2l/tensorflow.py | 2 +- d2l/torch.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/chapter_recurrent-modern/machine-translation-and-dataset.md b/chapter_recurrent-modern/machine-translation-and-dataset.md index 8a4f3ae733..d4d73b5861 100644 --- a/chapter_recurrent-modern/machine-translation-and-dataset.md +++ b/chapter_recurrent-modern/machine-translation-and-dataset.md @@ -160,7 +160,7 @@ and `tgt[i]` is that in the target language (French here). 
def _tokenize(self, text, max_examples=None): src, tgt = [], [] for i, line in enumerate(text.split('\n')): - if max_examples and i > max_examples: break + if max_examples and i >= max_examples: break parts = line.split('\t') if len(parts) == 2: # Skip empty tokens diff --git a/d2l/jax.py b/d2l/jax.py index fdde1bb247..c61cdb6a1d 100644 --- a/d2l/jax.py +++ b/d2l/jax.py @@ -1028,7 +1028,7 @@ def _tokenize(self, text, max_examples=None): """Defined in :numref:`sec_machine_translation`""" src, tgt = [], [] for i, line in enumerate(text.split('\n')): - if max_examples and i > max_examples: break + if max_examples and i >= max_examples: break parts = line.split('\t') if len(parts) == 2: # Skip empty tokens diff --git a/d2l/mxnet.py b/d2l/mxnet.py index f51fa72de6..c4c5c50e18 100644 --- a/d2l/mxnet.py +++ b/d2l/mxnet.py @@ -871,7 +871,7 @@ def _tokenize(self, text, max_examples=None): """Defined in :numref:`sec_machine_translation`""" src, tgt = [], [] for i, line in enumerate(text.split('\n')): - if max_examples and i > max_examples: break + if max_examples and i >= max_examples: break parts = line.split('\t') if len(parts) == 2: # Skip empty tokens diff --git a/d2l/tensorflow.py b/d2l/tensorflow.py index 7e14e59f04..bbc314af0f 100644 --- a/d2l/tensorflow.py +++ b/d2l/tensorflow.py @@ -827,7 +827,7 @@ def _tokenize(self, text, max_examples=None): """Defined in :numref:`sec_machine_translation`""" src, tgt = [], [] for i, line in enumerate(text.split('\n')): - if max_examples and i > max_examples: break + if max_examples and i >= max_examples: break parts = line.split('\t') if len(parts) == 2: # Skip empty tokens diff --git a/d2l/torch.py b/d2l/torch.py index 00ff5c308e..1ba81a7200 100644 --- a/d2l/torch.py +++ b/d2l/torch.py @@ -861,7 +861,7 @@ def _tokenize(self, text, max_examples=None): """Defined in :numref:`sec_machine_translation`""" src, tgt = [], [] for i, line in enumerate(text.split('\n')): - if max_examples and i > max_examples: break + if max_examples and i >= max_examples: break parts = line.split('\t') if len(parts) == 2: # Skip empty tokens From 317d8f0d5a49d94e30d0164f5da9976d4ce2910d Mon Sep 17 00:00:00 2001 From: nickmcgreivy Date: Thu, 21 Aug 2025 09:55:46 -0700 Subject: [PATCH 3/3] Fixed minus sign error in 12.6 Momentum --- chapter_optimization/momentum.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/chapter_optimization/momentum.md b/chapter_optimization/momentum.md index 2d4fa240c4..d4fb62d92e 100644 --- a/chapter_optimization/momentum.md +++ b/chapter_optimization/momentum.md @@ -258,11 +258,11 @@ $$h(\mathbf{x}) = \frac{1}{2} \mathbf{x}^\top \mathbf{Q} \mathbf{x} + \mathbf{x} This is a general quadratic function. For positive definite matrices $\mathbf{Q} \succ 0$, i.e., for matrices with positive eigenvalues this has a minimizer at $\mathbf{x}^* = -\mathbf{Q}^{-1} \mathbf{c}$ with minimum value $b - \frac{1}{2} \mathbf{c}^\top \mathbf{Q}^{-1} \mathbf{c}$. Hence we can rewrite $h$ as -$$h(\mathbf{x}) = \frac{1}{2} (\mathbf{x} - \mathbf{Q}^{-1} \mathbf{c})^\top \mathbf{Q} (\mathbf{x} - \mathbf{Q}^{-1} \mathbf{c}) + b - \frac{1}{2} \mathbf{c}^\top \mathbf{Q}^{-1} \mathbf{c}.$$ +$$h(\mathbf{x}) = \frac{1}{2} (\mathbf{x} + \mathbf{Q}^{-1} \mathbf{c})^\top \mathbf{Q} (\mathbf{x} + \mathbf{Q}^{-1} \mathbf{c}) + b - \frac{1}{2} \mathbf{c}^\top \mathbf{Q}^{-1} \mathbf{c}.$$ -The gradient is given by $\partial_{\mathbf{x}} h(\mathbf{x}) = \mathbf{Q} (\mathbf{x} - \mathbf{Q}^{-1} \mathbf{c})$. 
That is, it is given by the distance between $\mathbf{x}$ and the minimizer, multiplied by $\mathbf{Q}$. Consequently also the velocity is a linear combination of terms $\mathbf{Q} (\mathbf{x}_t - \mathbf{Q}^{-1} \mathbf{c})$. +The gradient is given by $\partial_{\mathbf{x}} h(\mathbf{x}) = \mathbf{Q} (\mathbf{x} + \mathbf{Q}^{-1} \mathbf{c})$. That is, it is given by the distance between $\mathbf{x}$ and the minimizer, multiplied by $\mathbf{Q}$. Consequently also the velocity is a linear combination of terms $\mathbf{Q} (\mathbf{x}_t + \mathbf{Q}^{-1} \mathbf{c})$. -Since $\mathbf{Q}$ is positive definite it can be decomposed into its eigensystem via $\mathbf{Q} = \mathbf{O}^\top \boldsymbol{\Lambda} \mathbf{O}$ for an orthogonal (rotation) matrix $\mathbf{O}$ and a diagonal matrix $\boldsymbol{\Lambda}$ of positive eigenvalues. This allows us to perform a change of variables from $\mathbf{x}$ to $\mathbf{z} \stackrel{\textrm{def}}{=} \mathbf{O} (\mathbf{x} - \mathbf{Q}^{-1} \mathbf{c})$ to obtain a much simplified expression: +Since $\mathbf{Q}$ is positive definite it can be decomposed into its eigensystem via $\mathbf{Q} = \mathbf{O}^\top \boldsymbol{\Lambda} \mathbf{O}$ for an orthogonal (rotation) matrix $\mathbf{O}$ and a diagonal matrix $\boldsymbol{\Lambda}$ of positive eigenvalues. This allows us to perform a change of variables from $\mathbf{x}$ to $\mathbf{z} \stackrel{\textrm{def}}{=} \mathbf{O} (\mathbf{x} + \mathbf{Q}^{-1} \mathbf{c})$ to obtain a much simplified expression: $$h(\mathbf{z}) = \frac{1}{2} \mathbf{z}^\top \boldsymbol{\Lambda} \mathbf{z} + b'.$$
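
The masked loss added in PATCH 1/3 has the same structure in all four frameworks: compute the per-token cross-entropy with averaged=False, zero out the positions whose label equals the padding index tgt_pad, and average over the remaining tokens only. Below is a minimal self-contained PyTorch sketch of that computation; the tensor shapes and the tgt_pad=0 value are illustrative assumptions, not values taken from the patch.

import torch
import torch.nn.functional as F

# Illustrative setup: batch of 2 sequences, 4 steps, vocabulary of 5;
# assume padding index 0 for this sketch.
tgt_pad = 0
Y_hat = torch.randn(2, 4, 5)                 # (batch, steps, vocab) logits
Y = torch.tensor([[3, 1, 4, tgt_pad],        # trailing positions are padding
                  [2, 2, tgt_pad, tgt_pad]])

# Per-token loss with no reduction (the patch's averaged=False).
l = F.cross_entropy(Y_hat.reshape(-1, Y_hat.shape[-1]),
                    Y.reshape(-1), reduction='none')

# Zero out padding positions, then average over the real tokens only.
mask = (Y.reshape(-1) != tgt_pad).type(torch.float32)
print((l * mask).sum() / mask.sum())

Dividing by mask.sum() rather than the total token count keeps the average on the same scale regardless of how much padding a batch contains; otherwise sequences padded to a common length would have their loss diluted by easy-to-predict padding tokens.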
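The tokenizer change in PATCH 2/3 fixes an off-by-one: breaking on i > max_examples processes indices 0 through max_examples inclusive, so max_examples + 1 lines slip through, while i >= max_examples stops after exactly max_examples. A standalone check of the boundary; the three tab-separated sample lines are made up for illustration.

# Three source/target pairs, mimicking the tab-separated fra-eng format.
text = 'go .\tva .\nhi .\tsalut .\nrun .\tcours .'

def n_pairs(text, max_examples, fixed):
    count = 0
    for i, line in enumerate(text.split('\n')):
        # Old test: i > max_examples. Fixed test: i >= max_examples.
        if max_examples and (i >= max_examples if fixed else i > max_examples):
            break
        if len(line.split('\t')) == 2:
            count += 1
    return count

print(n_pairs(text, 2, fixed=False))  # 3 -- one pair too many
print(n_pairs(text, 2, fixed=True))   # 2 -- exactly max_examples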
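The sign fix in PATCH 3/3 follows from the minimizer sitting at $\mathbf{x}^* = -\mathbf{Q}^{-1} \mathbf{c}$: completing the square in $h(\mathbf{x}) = \frac{1}{2} \mathbf{x}^\top \mathbf{Q} \mathbf{x} + \mathbf{x}^\top \mathbf{c} + b$ therefore produces $(\mathbf{x} + \mathbf{Q}^{-1} \mathbf{c})$ factors, not $(\mathbf{x} - \mathbf{Q}^{-1} \mathbf{c})$. Here is a quick numerical sanity check of the corrected identity and of the change of variables $\mathbf{z} = \mathbf{O}(\mathbf{x} + \mathbf{Q}^{-1} \mathbf{c})$; the random 2x2 instance is illustrative, not data from the chapter.

import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(2, 2))
Q = A @ A.T + 2 * np.eye(2)      # random symmetric positive definite Q
c, x = rng.normal(size=2), rng.normal(size=2)
b = 1.5

h = 0.5 * x @ Q @ x + x @ c + b
Qinv_c = np.linalg.solve(Q, c)   # Q^{-1} c

# Completed square with the corrected plus signs.
d = x + Qinv_c
h_sq = 0.5 * d @ Q @ d + b - 0.5 * c @ Qinv_c

# Eigendecomposition Q = O^T Lambda O with O orthogonal.
lam, V = np.linalg.eigh(Q)       # Q = V diag(lam) V^T, so take O = V^T
z = V.T @ d                      # z = O (x + Q^{-1} c)
h_z = 0.5 * z @ np.diag(lam) @ z + (b - 0.5 * c @ Qinv_c)

print(np.allclose(h, h_sq), np.allclose(h, h_z))  # True True

With the old minus signs, the completed square would describe a paraboloid centered at $+\mathbf{Q}^{-1} \mathbf{c}$, contradicting the minimizer stated two sentences earlier.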