Skip to content

Commit d58ebab

Browse files
committed
Evaluate contractions inside Zero in diff
1 parent dfba2d2 commit d58ebab

1 file changed

Lines changed: 5 additions & 1 deletion

File tree

src/forward.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,11 @@ function diff(arg::Tensor, wrt::Tensor)
1515
return evaluate(D) # evaluate to get rid of the constant factor
1616
end
1717

18-
indices = union(arg.indices, [flip(i) for i wrt.indices])
18+
# TODO: What is the canonical way?
19+
indices = copy(arg.indices)
20+
foreach(i -> push!(indices, flip(i)), wrt.indices)
21+
22+
indices = eliminate_indices(indices)
1923

2024
return Zero(unique(indices)...)
2125
end

0 commit comments

Comments
 (0)