Greetings community! Below is the code I wrote, but the output isn't correct:
The correct output can be obtained using the following code:
The computed result:
I think both results are correct, in that JAX is computing the requested gradient for the function you defined. Note that if you change your second output to `result = jnp.sum(jac_f_val, axis=0)`, then you get the same result as the first. This makes sense because summation and differentiation commute: the gradient of the sum will be the same as the sum of the gradients (but will not be the same as a single term from the sum).
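Here's a minimal sketch of that point, since the original code isn't shown above; the function `f` and the input `x` are made-up placeholders, but the comparison between the gradient of the sum and the summed Jacobian rows is the same idea:

```python
import jax
import jax.numpy as jnp

def f(x):
    # arbitrary vector-valued function R^3 -> R^3 (illustrative only)
    return jnp.sin(x) * x**2

x = jnp.array([0.5, 1.0, 2.0])

# Gradient of the summed output: d/dx of sum_i f_i(x)
grad_of_sum = jax.grad(lambda x: jnp.sum(f(x)))(x)

# Jacobian of f, then sum over the output axis
jac_f_val = jax.jacobian(f)(x)
sum_of_grads = jnp.sum(jac_f_val, axis=0)

print(jnp.allclose(grad_of_sum, sum_of_grads))  # True: the two agree
```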
I'm not sure there's any trick beyond making sure you're writing code which matches the mathematical expression you have in mind! Hope that helps!
Yes - that makes sense!