Shared parameters and withgradient #167

Open
@mcabbott

Notice that Tracker does not behave like Zygote when the same array appears more than once in the input (here a and b are both arr, while c is a copy):

julia> arr = [3.0, 4.0];

julia> Tracker.withgradient(nt -> sum(sum(x) for x in nt), (a=arr, b=arr, c=copy(arr)))
(val = 21.0 (tracked), grad = ((a = [2.0, 2.0], b = [2.0, 2.0], c = [1.0, 1.0]),))

julia> Zygote.gradient(nt -> sum(sum(x) for x in nt), (a=arr, b=arr, c=copy(arr)))
((a = Fill(1.0, 2), b = Fill(1.0, 2), c = Fill(1.0, 2)),)

julia> Enzyme.gradient(Reverse, nt -> sum(sum(x) for x in nt), (a=arr, b=arr, c=copy(arr)))
(a = [2.0, 2.0], b = [2.0, 2.0], c = [1.0, 1.0])

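Tracker and Enzyme give every alias of the shared array the accumulated gradient, while Zygote gives each field only its own contribution. I believe this is likely to confuse Optimisers.jl, which is written with the Zygote convention in mind.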

Some possible fixes are:

  • change Tracker to throw an error in such cases (easy)
  • change Tracker to replace the gradient of all but the first duplicate with nothing, which Optimisers.jl will like
  • simplify Optimisers.jl to accept the Tracker (and Enzyme) convention, and add a Zygote compatibility layer somewhere (e.g. make Flux.withgradient handle this)
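
To make the second option concrete, here is a minimal sketch in plain Julia, with no AD package involved. The name drop_duplicate_grads is hypothetical and not part of Tracker or Optimisers.jl, and the sketch assumes a flat NamedTuple of arrays rather than a nested model. It keeps the gradient for the first occurrence of each array, compared by identity, and replaces the gradient of every later alias with nothing:

```julia
# Hypothetical helper (not an existing Tracker or Optimisers.jl function).
# Assumes `params` and `grads` are flat NamedTuples with the same field names.
function drop_duplicate_grads(params::NamedTuple{names}, grads::NamedTuple{names}) where {names}
    seen = IdDict{Any,Nothing}()      # keyed by object identity (===), not by value
    vals = Any[]
    for k in names
        p = getfield(params, k)
        if haskey(seen, p)
            push!(vals, nothing)      # later alias of an array already seen
        else
            seen[p] = nothing
            push!(vals, getfield(grads, k))
        end
    end
    return NamedTuple{names}(Tuple(vals))
end

arr = [3.0, 4.0]
params = (a = arr, b = arr, c = copy(arr))
# Gradient in the Tracker convention, copied from the output quoted above.
grads = (a = [2.0, 2.0], b = [2.0, 2.0], c = [1.0, 1.0])

fixed = drop_duplicate_grads(params, grads)
# fixed == (a = [2.0, 2.0], b = nothing, c = [1.0, 1.0])
```

Because c is a copy, it has a different identity from arr and keeps its own gradient; only b, which is the very same array as a, is set to nothing. A real fix would also need to walk nested structures, which this sketch does not attempt.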
