Skip to content

Enzyme is dropping gradients when using a custom rule and views



I've managed to get into a situation where a custom rule seems to give incorrect results when used within a loop. Here is an MWE:

# Problem sizes: `vl[i]` is the number of output rows for block `i`, `rg[i]` the
# index range that block occupies, and `tot` the combined length of all blocks.
Nx = 4
vl = [2, 2]
rg = [1:2, 3:4]
tot = sum(vl)

# One CartesianIndex key per block, the materialised index vectors, and a
# dictionary mapping each key to an all-ones complex matrix of size (vl[i], Nx^2).
iminds = vec(map(CartesianIndex, 1:2))
visinds = map(collect, rg)
Bs = Dict(imind => ones(ComplexF64, len, Nx * Nx) for (imind, len) in zip(iminds, vl))

"""
    _mul!(out, A, b)

Compute the product `A * b` in place, writing the result into `out`, and return
`nothing`.

This is a thin wrapper around `mul!`; having a dedicated function gives the custom
Enzyme rules something to dispatch on via `typeof(_mul!)`.
"""
function _mul!(out, A, b)
    mul!(out, A, b)
    return nothing
end

# Forward (augmented-primal) half of the custom reverse-mode rule for `_mul!`.
#
# Applies only to width-1 reverse configurations with a `Const` return, a
# `Duplicated` output buffer, a `Const` matrix `A` and a `Duplicated` input `b`.
# It runs the primal product on the value buffers and returns an
# `AugmentedReturn` whose three fields are all `nothing`, i.e. nothing is
# stored for the reverse pass.
function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1}, ::Const{typeof(_mul!)}, ::Type{<:Const}, out::Duplicated,
                                      A::Const{<:Matrix}, b::Duplicated)
    _mul!(out.val, A.val, b.val)
    return EnzymeRules.AugmentedReturn(nothing, nothing, nothing)
end

# Reverse half of the custom rule for `_mul!`.
#
# NOTE(review): the signature in the report was truncated (the function
# annotation and the trailing `b::Duplicated)` were missing even though the
# body uses `b`). It is reconstructed here to mirror `augmented_primal`;
# confirm against the original script.
#
# The pullback accumulates the real part of `A' * d(out)` into the shadow of
# `b`, then zeroes the shadow of `out`. One `nothing` is returned per argument
# (`out`, `A`, `b`).
@noinline function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1}, ::Const{typeof(_mul!)},
                                       ::Type{<:Const}, tape, out::Duplicated, A::Const{<:Matrix},
                                       b::Duplicated)
    b.dval .+= real.(A.val' * out.dval)
    out.dval .= 0
    return (nothing, nothing, nothing)
end

# Objective used in the reproducer.
#
# Allocates a complex output vector of length `tot`, then for each index `i`
# writes `Bs[iminds[i]] * vec(x[:, :, iminds[i]])` into the slice
# `out[visinds[i]]` through views (no copies of `out` or `x` are made).
# Returns the sum of squared magnitudes of `out`.
#
# The statement shape is kept exactly as reported, since this function is the
# minimal reproducer for the gradient issue.
@inline function f(Bs, visinds, iminds, tot, x)
    out = similar(x, Complex{eltype(x)}, tot)
    for i in eachindex(iminds, visinds)
        imind = iminds[i]
        visind = visinds[i]
        _mul!(@view(out[visind]), Bs[imind], reshape(@view(x[:, :, imind]), :))
    end
    return sum(abs2, out)
end

# Input array with two Nx-by-Nx slices, and a zero-initialised shadow of the same shape.
x = ones(Nx, Nx, 2)
dx = zero(x)

# Plain (non-differentiated) evaluation first.
f(Bs, visinds, iminds, tot, x)
# Reverse-mode differentiation with runtime activity enabled; every argument
# except `x` is marked `Const`, and `dx` is reset to zero before being used as
# the shadow of `x`.
autodiff(set_runtime_activity(Reverse), f, Active, Const(Bs), Const(visinds), Const(iminds), Const(tot), Duplicated(x, fill!(dx, 0)))
@show dx
# Reported output (the first slice is all zeros, the second all 64):
# dx = [0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0;;; 64.0 64.0 64.0 64.0; 64.0 64.0 64.0 64.0; 64.0 64.0 64.0 64.0; 64.0 64.0 64.0 64.0]

So the gradients for `x[:, :, 1]` are zeroed rather than also being filled with 64. Funnily enough, if I change `vl` so that the two blocks have different lengths, e.g.,

# Variant of the setup in which the two blocks have different lengths (2 and 3).
vl = [2, 3]
tot = sum(vl)
rg = [1:2, 3:5]
Nx = 4

and keep everything else identical, I get a `DimensionMismatch` error:

ERROR: DimensionMismatch: matrix A has dimensions (16,2), vector B has length 3
  [1] _generic_matvecmul!
    @ ~/.julia/juliaup/julia-1.10.5+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:697
  [2] generic_matvecmul!
    @ ~/.julia/juliaup/julia-1.10.5+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:687 [inlined]
  [3] mul!
    @ ~/.julia/juliaup/julia-1.10.5+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:66 [inlined]
  [4] mul!
    @ ~/.julia/juliaup/julia-1.10.5+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:237 [inlined]
  [5] *
    @ ~/.julia/juliaup/julia-1.10.5+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:57
  [6] reverse
    @ ~/Research/Enzyme/dft.jl:32 [inlined]
  [7] f
    @ ~/Research/Enzyme/dft.jl:42 [inlined]
  [8] diffejulia_f_7370wrap
    @ ~/Research/Enzyme/dft.jl:0
  [9] macro expansion
    @ ~/.julia/packages/Enzyme/0PGOL/src/compiler.jl:7045 [inlined]
 [10] enzyme_call
    @ ~/.julia/packages/Enzyme/0PGOL/src/compiler.jl:6648 [inlined]
 [11] CombinedAdjointThunk
    @ ~/.julia/packages/Enzyme/0PGOL/src/compiler.jl:6525 [inlined]
 [12] autodiff
    @ ~/.julia/packages/Enzyme/0PGOL/src/Enzyme.jl:316 [inlined]
 [13] autodiff(::ReverseMode{…}, ::typeof(f), ::Type{…}, ::Const{…}, ::Const{…}, ::Const{…}, ::Const{…}, ::Duplicated{…})
    @ Enzyme ~/.julia/packages/Enzyme/0PGOL/src/Enzyme.jl:328
 [14] top-level scope
    @ ~/Research/Enzyme/dft.jl:51

So it looks like the shadow passed to the rule is incorrect.

Note that when acting on a single array, the rule gives the correct result, e.g.,

# Single-array counterpart of the looped objective: one `_mul!` call into a
# freshly allocated complex vector of length `size(Bs, 1)`, with no views into
# a shared output buffer. Returns the sum of squared magnitudes of the product.
function fsimple(Bs, x)
    out = similar(Bs, Complex{eltype(x)}, size(Bs, 1))
    _mul!(out, Bs, reshape(x, :))
    return sum(abs2, out)
end

# Single-matrix check: a 2x16 all-ones complex matrix, a 4x4 all-ones input,
# and a zero-initialised shadow for the input.
B = ones(ComplexF64, 2, 16)
xx = ones(4,4)
dxx = zero(xx)
# Same reverse-mode call pattern as in the looped case, with `B` marked `Const`.
autodiff(set_runtime_activity(Reverse), fsimple, Active, Const(B), Duplicated(xx, dxx))
@show dxx
# Reported output (every entry is 64):
# dxx = [64.0 64.0 64.0 64.0; 64.0 64.0 64.0 64.0; 64.0 64.0 64.0 64.0; 64.0 64.0 64.0 64.0]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment



No one assigned


    No labels
    No labels


    No type


    No projects


    No milestone


    None yet


    No branches or pull requests

    Issue actions