diff --git a/chapter_preliminaries/autograd.md b/chapter_preliminaries/autograd.md
index b2b1a3cf3..f10e352a2 100644
--- a/chapter_preliminaries/autograd.md
+++ b/chapter_preliminaries/autograd.md
@@ -463,10 +463,10 @@ import jax
 
 y = lambda x: x * x
 # jax.lax primitives are Python wrappers around XLA operations
-u = jax.lax.stop_gradient(y(x))
-z = lambda x: u * x
-
-grad(lambda x: z(x).sum())(x) == y(x)
+(
+    grad(lambda x: (y(x) * x).sum())(x) == 3 * y(x),
+    grad(lambda x: (jax.lax.stop_gradient(y(x)) * x).sum())(x) == y(x),
+)
 ```
 
 Note that while this procedure
@@ -694,4 +694,4 @@ For now, try to remember these basics: (i) attach gradients to those variables w
 
 :begin_tab:`jax`
 [Discussions](https://discuss.d2l.ai/t/17970)
-:end_tab:
\ No newline at end of file
+:end_tab:
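
A minimal standalone sketch of what the revised snippet evaluates, outside the diff context. It assumes `x = jnp.arange(4.0)` as defined earlier in the chapter (an assumption here); the expected values in the comments follow from the calculus, not from running the book's notebook.

```python
import jax
import jax.numpy as jnp
from jax import grad

# Assumed input, matching the chapter's earlier definition of x.
x = jnp.arange(4.0)

y = lambda x: x * x

# Without stop_gradient, y(x) * x = x**3, so the gradient is 3 * x**2 = 3 * y(x).
print(grad(lambda x: (y(x) * x).sum())(x))  # [ 0.  3. 12. 27.]

# With stop_gradient, y(x) is treated as a constant u, so the gradient of
# u * x with respect to x is just u = y(x).
print(grad(lambda x: (jax.lax.stop_gradient(y(x)) * x).sum())(x))  # [0. 1. 4. 9.]
```

The change replaces the intermediate names `u` and `z` with a single tuple that prints both comparisons side by side, which makes the effect of `jax.lax.stop_gradient` visible in one cell.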