Support vectorize/devectorize inside gradients #1533

Open
@jyc

Description

Thanks for making Nx!

I tried to use value_and_grad on a function that takes two inputs: a vectorized tensor and a non-vectorized tensor.

defmodule Foo do
  import Nx.Defn
  defn f(x, y) do
    x + y
  end

  defn f_and_grad(x, y) do
    value_and_grad(y, fn y -> Foo.f(x, y) end)
  end
end

import Nx, only: :sigils

# Vectorize the rank-1 tensor along a new :bar axis of size 2.
x = Nx.vectorize(~VEC[0 1], :bar)
Foo.f_and_grad(x, 1)

This evaluates to:

{#Nx.Tensor<
   vectorized[bar: 2]
   s64
   EXLA.Backend<host:0, 0.731981912.321781778.128426>
   [1, 2]
 >,
 #Nx.Tensor<
   f32
   EXLA.Backend<host:0, 0.731981912.321781778.128427>
   2.0
 >}

The value is correct and keeps the vectorized axis of the input x, but the gradient surprises me. I would have expected a vectorized tensor with the same :bar axis (size 2) whose entries are all 1; instead, it looks like Nx is summing the two per-entry gradients.
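To spell out the arithmetic behind that guess: this is only a sketch of what I think is happening, and the per-entry notation $f_i$ and $x_i$ is mine, not something Nx exposes. With $x = (0, 1)$ along :bar and a scalar $y$:

```math
f_i(y) = x_i + y, \qquad \frac{\partial f_i}{\partial y} = 1 \quad \text{for } i \in \{1, 2\}

\frac{\partial}{\partial y} \sum_{i=1}^{2} f_i(y) = 1 + 1 = 2
```

The observed 2.0 matches the second line (the gradient of the sum over the vectorized axis), whereas I expected the per-entry values from the first line, i.e. 1 for each entry of :bar.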

Is this behavior expected? If so, is there any way to make Nx return a vectorized gradient?

Thanks!

Metadata

Labels

area:defn (Applies to defn), kind:bug (Something isn't working)
