
Commit f5d48a9

jax.lax.stop_gradient should not be useless in the example
The `x` used to compute `u = jax.lax.stop_gradient(y(x))` is the global one, not the lambda's argument, so `stop_gradient` was doing nothing: `u` is a captured constant inside the lambda either way. I have fixed the example so that `stop_gradient` actually has an effect.
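The no-op can be seen directly. A minimal runnable sketch of the old example, assuming `x = jnp.arange(4.0)` as used elsewhere in the chapter:

```python
import jax
from jax import grad, numpy as jnp

x = jnp.arange(4.0)  # assumed input, as earlier in the chapter
y = lambda x: x * x

# Old example: u is built once from the *global* x, so inside the lambda
# below it is just a captured constant -- stop_gradient has no effect.
u = jax.lax.stop_gradient(y(x))
z = lambda x: u * x

g_with = grad(lambda x: z(x).sum())(x)

# Identical result without stop_gradient: u would still be a constant.
u_plain = y(x)
g_without = grad(lambda x: (u_plain * x).sum())(x)

print(bool(jnp.allclose(g_with, g_without)))  # True
```

In both cases the gradient of `sum(u * x)` with respect to the lambda's `x` is simply `u`, so detaching `u` changes nothing.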
1 parent 23d7a5a commit f5d48a9

File tree

1 file changed (+5 −5 lines)


chapter_preliminaries/autograd.md

Lines changed: 5 additions & 5 deletions
````diff
@@ -463,10 +463,10 @@ import jax
 
 y = lambda x: x * x
 # jax.lax primitives are Python wrappers around XLA operations
-u = jax.lax.stop_gradient(y(x))
-z = lambda x: u * x
-
-grad(lambda x: z(x).sum())(x) == y(x)
+(
+    grad(lambda x: (y(x) * x).sum())(x) == 3 * y(x),
+    grad(lambda x: (jax.lax.stop_gradient(y(x)) * x).sum())(x) == y(x),
+)
 ```
 
 Note that while this procedure
@@ -694,4 +694,4 @@ For now, try to remember these basics: (i) attach gradients to those variables w
 
 :begin_tab:`jax`
 [Discussions](https://discuss.d2l.ai/t/17970)
-:end_tab:
+:end_tab:
````
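The corrected example can be checked end to end. A minimal runnable sketch, assuming `x = jnp.arange(4.0)` as earlier in the chapter:

```python
import jax
from jax import grad, numpy as jnp

x = jnp.arange(4.0)  # assumed input, as earlier in the chapter
y = lambda x: x * x

# Without detaching, the chain rule applies through y(x):
# d/dx (x^2 * x) = 3 x^2 = 3 * y(x)
full = grad(lambda x: (y(x) * x).sum())(x)

# Detaching y(x) treats it as a constant u, so d/dx (u * x) = u = y(x)
detached = grad(lambda x: (jax.lax.stop_gradient(y(x)) * x).sum())(x)

print(bool(jnp.allclose(full, 3 * y(x))),
      bool(jnp.allclose(detached, y(x))))  # True True
```

Because `stop_gradient` is now applied *inside* the differentiated lambda, it genuinely blocks the gradient flowing through `y(x)`, which is exactly the behavior the example is meant to illustrate.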
