Skip to content

Enzyme is dropping gradients when using custom rule and views

Open

Description

I've run into a situation where a custom rule appears to give incorrect results when it is used inside a loop. Here is an MWE:

# Problem sizes: `vl[i]` is the row count of the i-th matrix in `Bs`, and `rg[i]`
# is the index range that block occupies in an output of length `tot`.
Nx = 4
vl = [2, 2]
rg = [1:2, 3:4]
tot = sum(vl)

# Keys into `Bs` (and into the third dimension of the input), one per block.
iminds = reshape(CartesianIndex.(1:2), :)
# Materialize each range as an index vector.
visinds = map(collect, rg)
# One all-ones complex matrix of size (vl[k], Nx^2) per block.
Bs = Dict(iminds[k] => ones(ComplexF64, vl[k], Nx * Nx) for k in eachindex(vl))

"""
    _mul!(out, A, b)

Store the product `A * b` into `out` via `mul!` and return `nothing`.
"""
_mul!(out, A, b) = (mul!(out, A, b); nothing)


# Custom reverse-mode rule for `_mul!(out, A, b)` with a constant matrix `A` and
# duplicated (shadowed) `out` and `b`.
#
# Augmented primal: run the primal in place on `out.val`. The reverse pass reads
# `A.val`, so a copy of it is placed on the tape whenever Enzyme reports that `A`
# may be overwritten between the forward and reverse passes; otherwise the tape
# is `nothing`.
# NOTE(review): this follows the documented `overwritten(config)` caching pattern;
# it is not confirmed to address every case of stale arguments in the reverse pass.
function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1}, ::Const{typeof(_mul!)}, ::Type{<:Const}, out::Duplicated,
                                      A::Const{<:Matrix}, b::Duplicated)
    _mul!(out.val, A.val, b.val)
    # `overwritten(config)` has one entry per argument *including the function
    # itself*, so index 3 refers to `A` (1 = `_mul!`, 2 = `out`, 3 = `A`, 4 = `b`).
    tape = EnzymeRules.overwritten(config)[3] ? copy(A.val) : nothing
    return EnzymeRules.AugmentedReturn(nothing, nothing, tape)
end

# Reverse pass: accumulate `real(A' * d_out)` into the shadow of `b`, then reset the
# shadow of `out`. Returns one `nothing` per argument (`out`, `A`, `b`), since none
# of them is `Active`.
@noinline function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
                                       ::Const{typeof(_mul!)},
                                       ::Type{<:Const}, tape, out::Duplicated, A::Const{<:Matrix},
                                       b::Duplicated)

    # Use the cached copy of `A` when the forward pass had to save one.
    Aval = EnzymeRules.overwritten(config)[3] ? tape : A.val
    # Only the real part is accumulated — assumes `b` is real-valued; TODO confirm
    # for complex `b`.
    b.dval .+= real.(Aval' * out.dval)
    # The primal `mul!` overwrites `out`, so its incoming shadow is cleared here.
    out.dval .= 0
    return (nothing, nothing, nothing)
end

# Apply each matrix in `Bs` to the matching slice of `x`, writing the result into
# the matching index block of a complex buffer of length `tot`, and return the sum
# of squared magnitudes of that buffer.
@inline function f(Bs, visinds, iminds, tot, x)
    buf = similar(x, Complex{eltype(x)}, tot)
    for k in eachindex(iminds, visinds)
        key = iminds[k]
        dest = view(buf, visinds[k])
        B = Bs[key]
        # Flatten the (:, :, key) slice of `x` without copying.
        src = reshape(view(x, :, :, key), :)
        _mul!(dest, B, src)
    end
    return sum(abs2, buf)
end

# Primal input and a shadow buffer of the same shape, initialized to zero.
x = ones(Nx, Nx, 2)
dx = zero(x)

# Evaluate the primal once, then differentiate in reverse mode with runtime
# activity enabled; the gradient is accumulated into `dx`.
f(Bs, visinds, iminds, tot, x)
fill!(dx, 0)
autodiff(
    set_runtime_activity(Reverse), f, Active,
    Const(Bs), Const(visinds), Const(iminds), Const(tot), Duplicated(x, dx),
)
@show dx
# dx = [0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0;;; 64.0 64.0 64.0 64.0; 64.0 64.0 64.0 64.0; 64.0 64.0 64.0 64.0; 64.0 64.0 64.0 64.0]

So the gradients for x[:,:,1] are zeroed rather than also being filled with 64. Funnily enough, if I change vl so that the two block lengths differ, e.g.,

# Same setup, but with unequal block lengths (2 and 3).
Nx = 4
vl = [2, 3]
rg = [1:2, 3:5]
tot = sum(vl)

and keep everything else identical, I get a `DimensionMismatch` error:

ERROR: DimensionMismatch: matrix A has dimensions (16,2), vector B has length 3
  [1] _generic_matvecmul!
    @ ~/.julia/juliaup/julia-1.10.5+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:697
  [2] generic_matvecmul!
    @ ~/.julia/juliaup/julia-1.10.5+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:687 [inlined]
  [3] mul!
    @ ~/.julia/juliaup/julia-1.10.5+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:66 [inlined]
  [4] mul!
    @ ~/.julia/juliaup/julia-1.10.5+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:237 [inlined]
  [5] *
    @ ~/.julia/juliaup/julia-1.10.5+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:57
  [6] reverse
    @ ~/Research/Enzyme/dft.jl:32 [inlined]
  [7] f
    @ ~/Research/Enzyme/dft.jl:42 [inlined]
  [8] diffejulia_f_7370wrap
    @ ~/Research/Enzyme/dft.jl:0
  [9] macro expansion
    @ ~/.julia/packages/Enzyme/0PGOL/src/compiler.jl:7045 [inlined]
 [10] enzyme_call
    @ ~/.julia/packages/Enzyme/0PGOL/src/compiler.jl:6648 [inlined]
 [11] CombinedAdjointThunk
    @ ~/.julia/packages/Enzyme/0PGOL/src/compiler.jl:6525 [inlined]
 [12] autodiff
    @ ~/.julia/packages/Enzyme/0PGOL/src/Enzyme.jl:316 [inlined]
 [13] autodiff(::ReverseMode{…}, ::typeof(f), ::Type{…}, ::Const{…}, ::Const{…}, ::Const{…}, ::Const{…}, ::Duplicated{…})
    @ Enzyme ~/.julia/packages/Enzyme/0PGOL/src/Enzyme.jl:328
 [14] top-level scope
    @ ~/Research/Enzyme/dft.jl:51

so it looks like the shadow is incorrect.

Note that when acting on a single array the rule is correct, e.g.,

# Single-array variant: one `_mul!` of `Bs` against the flattened `x` into a fresh
# complex buffer with `size(Bs, 1)` entries, reduced to the sum of squared magnitudes.
function fsimple(Bs, x)
    T = Complex{eltype(x)}
    out = similar(Bs, T, size(Bs, 1))
    flat = reshape(x, :)
    _mul!(out, Bs, flat)
    return sum(abs2, out)
end

# All-ones complex matrix, real input, and a zeroed shadow for the input.
B = fill(one(ComplexF64), 2, 16)
xx = ones(4, 4)
dxx = zeros(eltype(xx), size(xx))
autodiff(
    set_runtime_activity(Reverse), fsimple, Active,
    Const(B), Duplicated(xx, dxx),
)
@show dxx
# dxx = [64.0 64.0 64.0 64.0; 64.0 64.0 64.0 64.0; 64.0 64.0 64.0 64.0; 64.0 64.0 64.0 64.0]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions