Thanks for making Nx!
I tried to use value_and_grad on a function that takes two inputs: a vectorized tensor and a non-vectorized tensor.
defmodule Foo do
  import Nx.Defn

  defn f(x, y) do
    x + y
  end

  defn f_and_grad(x, y) do
    value_and_grad(y, fn y -> Foo.f(x, y) end)
  end
end
# ~VEC is one of Nx's sigils, so it has to be imported first
import Nx, only: :sigils

x = ~VEC[0 1] |> Nx.vectorize(:bar)
Foo.f_and_grad(x, 1)
This evaluates to:
{#Nx.Tensor<
   vectorized[bar: 2]
   s64
   EXLA.Backend<host:0, 0.731981912.321781778.128426>
   [1, 2]
 >,
 #Nx.Tensor<
   f32
   EXLA.Backend<host:0, 0.731981912.321781778.128427>
   2.0
 >}
The value is correct and keeps the vectorized axis of the input x, but the gradient surprises me. I would have expected a vectorized tensor with the same :bar axis (size 2) whose entries are all 1; instead, it looks like Nx is summing the two gradients.
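To spell out the arithmetic behind that guess: this is only my reading of the output, assuming the non-vectorized y is treated as one value shared by both entries of the :bar axis. The symbols f_i and x_i below are my own notation for the per-entry function and the entries of x, not anything from Nx.

```latex
% Per-entry function along the :bar axis, with x = (x_0, x_1) = (0, 1) and y = 1
f_i(y) = x_i + y, \qquad i \in \{0, 1\}

% What I expected: one gradient per vectorized entry
\frac{\partial f_0}{\partial y} = 1, \qquad \frac{\partial f_1}{\partial y} = 1

% What the output suggests: the per-entry gradients summed over the axis
\sum_{i=0}^{1} \frac{\partial f_i}{\partial y} = 1 + 1 = 2
```

The sum matches the 2.0 that was returned, while the per-entry form is the vectorized result I was expecting.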
Is this behavior expected? If so, is there any way to make Nx return a vectorized gradient?
Thanks!