Notice that Tracker does not behave like Zygote here:
```julia
julia> arr = [3.0, 4.0];

julia> Tracker.withgradient(nt -> sum(sum(x) for x in nt), (a=arr, b=arr, c=copy(arr)))
(val = 21.0 (tracked), grad = ((a = [2.0, 2.0], b = [2.0, 2.0], c = [1.0, 1.0]),))

julia> Zygote.gradient(nt -> sum(sum(x) for x in nt), (a=arr, b=arr, c=copy(arr)))
((a = Fill(1.0, 2), b = Fill(1.0, 2), c = Fill(1.0, 2)),)

julia> Enzyme.gradient(Reverse, nt -> sum(sum(x) for x in nt), (a=arr, b=arr, c=copy(arr)))
(a = [2.0, 2.0], b = [2.0, 2.0], c = [1.0, 1.0])
```
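Tracker (like Enzyme) reports the accumulated gradient `[2.0, 2.0]` in both `a` and `b`, which share one array, whereas Zygote reports each field's own contribution. I believe this is likely to confuse Optimisers.jl, which is written with the Zygote convention in mind.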
Some possible fixes are:
- Change Tracker to throw an error in such cases. This is easy.
- Change Tracker to replace all but the first duplicate with `nothing`. Optimisers.jl will like that.
- Simplify Optimisers.jl to accept the Tracker (and Enzyme) convention, and add a Zygote compatibility layer somewhere (e.g. make `Flux.withgradient` handle this).
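
To make the second option concrete, here is a rough sketch of the post-processing step, written against plain Base Julia rather than Tracker's internals. The name `dedup_grads` is hypothetical and not part of Tracker.jl or Optimisers.jl. It assumes a flat `NamedTuple` of arrays, and treats two parameters as duplicates when they are the same object (`===`).

```julia
# Hypothetical helper: keep the gradient for the first occurrence of each
# array and replace the gradient of every later duplicate with `nothing`.
# Only handles a flat NamedTuple; nested structures would need a recursive walk.
function dedup_grads(params::NamedTuple{K}, grads::NamedTuple{K}) where {K}
    seen = IdDict{Any,Bool}()          # keyed by object identity, not by value
    vals = map(K) do k
        p = getfield(params, k)
        if haskey(seen, p)
            nothing                    # later duplicate: drop its gradient
        else
            seen[p] = true             # first occurrence: remember it
            getfield(grads, k)
        end
    end
    return NamedTuple{K}(vals)
end

arr = [3.0, 4.0]
params = (a=arr, b=arr, c=copy(arr))
grads = (a=[2.0, 2.0], b=[2.0, 2.0], c=[1.0, 1.0])  # values from the Tracker output above

dedup_grads(params, grads)
# (a = [2.0, 2.0], b = nothing, c = [1.0, 1.0])
```

Here `c` keeps its gradient because `copy(arr)` is a distinct object, even though its values equal those of `arr`.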