From f5d48a9e3e6ac6973a5025204eeec136a34521c4 Mon Sep 17 00:00:00 2001
From: Nikita Sokolov
Date: Fri, 14 Feb 2025 15:44:46 +0100
Subject: [PATCH] jax.lax.stop_gradient should not be useless in the example

The `x` in `y(x)` (used to define `u`) is different from the `x` in the
lambdas, so `stop_gradient` was doing nothing. I have fixed the example
so that it actually has an effect.
---
 chapter_preliminaries/autograd.md | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/chapter_preliminaries/autograd.md b/chapter_preliminaries/autograd.md
index b2b1a3cf3a..f10e352a2a 100644
--- a/chapter_preliminaries/autograd.md
+++ b/chapter_preliminaries/autograd.md
@@ -463,10 +463,10 @@ import jax
 
 y = lambda x: x * x
 # jax.lax primitives are Python wrappers around XLA operations
-u = jax.lax.stop_gradient(y(x))
-z = lambda x: u * x
-
-grad(lambda x: z(x).sum())(x) == y(x)
+(
+    grad(lambda x: (y(x) * x).sum())(x) == 3 * y(x),
+    grad(lambda x: (jax.lax.stop_gradient(y(x)) * x).sum())(x) == y(x),
+)
 ```
 
 Note that while this procedure
@@ -694,4 +694,4 @@ For now, try to remember these basics: (i) attach gradients to those variables w
 
 :begin_tab:`jax`
 [Discussions](https://discuss.d2l.ai/t/17970)
-:end_tab:
\ No newline at end of file
+:end_tab:
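
For context, a minimal self-contained sketch of the behavior the patched example demonstrates, assuming `x = jnp.arange(4.0)` as set up earlier in that section of autograd.md:

```python
import jax
import jax.numpy as jnp
from jax import grad

x = jnp.arange(4.0)
y = lambda x: x * x

# Without stop_gradient, y(x) * x == x**3, so the gradient is
# 3 * x**2 == 3 * y(x).
print(grad(lambda x: (y(x) * x).sum())(x) == 3 * y(x))

# With stop_gradient, y(x) is treated as a constant during
# differentiation, so the gradient of stop_gradient(y(x)) * x
# with respect to x is just that constant, i.e. y(x).
print(grad(lambda x: (jax.lax.stop_gradient(y(x)) * x).sum())(x) == y(x))
```

Both comparisons print `[ True  True  True  True]`, which is the point of the fix: the detached value now actually participates in the traced computation, so `stop_gradient` changes the gradient.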