The following program produces invalid code after automatic differentiation (AD):
```futhark
entry scan_arr_add [n]
                   (inp: [2][n]f32)
                   (adj: [2][n]f32) : [2][n]f32 =
  let adj =
    vjp (scan (\x y -> [x[0] + y[0], x[1] + y[1]])
              (replicate 2 0))
        (transpose inp)
        (transpose adj)
  in transpose adj
```
The problem is that part of the differentiation rule for `scan` assumes that the operands of the scan operator are never arrays. Here each element being scanned is itself a `[2]f32` array.
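When checking a fix, it may help to compare against a hand-written reference. Because the operator is elementwise addition, the scan is a per-component prefix sum, and the VJP of an inclusive prefix sum is the suffix sum of the adjoint. The sketch below relies on that reasoning. The entry point name `scan_arr_add_ref` and the test values are hypothetical and not part of the original report. The input `inp` is unused because the derivative of addition does not depend on the inputs.

```futhark
-- Hypothetical reference for scan_arr_add: per row of adj,
-- compute the suffix sum (the VJP of a prefix sum).
-- ==
-- entry: scan_arr_add_ref
-- input { [[1f32, 2f32, 3f32], [4f32, 5f32, 6f32]] [[1f32, 1f32, 1f32], [1f32, 2f32, 3f32]] }
-- output { [[3f32, 2f32, 1f32], [6f32, 5f32, 3f32]] }

entry scan_arr_add_ref [n] (inp: [2][n]f32) (adj: [2][n]f32) : [2][n]f32 =
  -- Reverse, prefix-sum, reverse again: position i gets the sum of row[i..n-1].
  map (\row -> reverse (scan (+) 0 (reverse row))) adj
```

Once the AD rule handles array operands, `scan_arr_add` given the same inputs should produce the same output as this reference.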