-
Notifications
You must be signed in to change notification settings - Fork 82
Description
Enzyme throws an error in the following case, whereas the same code works with ForwardDiff. The error message is lengthy, so I have pasted a truncated error stack below. In case the error is caused by the in-place (iip) matrix-vector product (MVP), I have included a commented-out version that uses a vector (SVector) of vectors (MVectors); in that case dnka2! should be called instead of dnka!. A separate dispatch for calculating the gradient of the root-finding step (the find_c_ function) could be written, but I don't think that is the issue here.
MWE
using LinearAlgebra
using StaticArrays
using Enzyme
using DifferentiationInterface
"""
    myModel(m, h, ρ, vp)

Mutable container holding four array-valued model fields. Each field has its
own array type parameter, so the four arrays may be of different concrete types.

Positional constructor order is `(m, h, ρ, vp)`.

# Fields
- `m`: read by `dltar` as the shear-wave velocity array (`vs = model.m`).
- `h`: read by `dltar` as the per-layer depth/thickness (`dpth = h[m]`).
- `ρ`: read by `dltar` as the density array.
- `vp`: read by `dltar` as the P-wave velocity array.
"""
mutable struct myModel{
    T1<:AbstractArray,
    T2<:AbstractArray,
    T3<:AbstractArray,
    T4<:AbstractArray,
}
    m::T1
    h::T2
    ρ::T3
    vp::T4
end
"""
    dnka!(C, wvno2, gam, gammk, rho, a0, cpcq, cpy, cpz, cqw, cqx, xy, xz, wy, wz)

Overwrite the 5×5 array `C` (indexed as `C[i, j]`) with Dunkin's matrix for one
layer, built from the scalar factors `wvno2`, `gam`, `gammk`, `rho` and the
eigenfunction products `a0`, `cpcq`, `cpy`, `cpz`, `cqw`, `cqx`, `xy`, `xz`,
`wy`, `wz`.

All 25 entries are written; nothing is read from `C` that was not written by
this call. Returns `nothing`.
"""
function dnka!(C, wvno2, gam, gammk, rho, a0, cpcq, cpy, cpz, cqw, cqx, xy, xz, wy, wz)
    # Scalar factors shared by several entries.
    g1 = gam - 1
    g_plus_g1 = gam + g1
    g_gk = gam * gammk
    g_g1 = gam * g1
    g1_sq = g1 * g1
    a0_less_cpcq = a0 - cpcq
    row3_scale = -2 * wvno2
    rho_sq = rho * rho

    # Entries computed directly from the inputs.
    C[1, 1] = cpcq - 2 * g_g1 * a0_less_cpcq - g_gk * xz - wvno2 * g1_sq * wy
    C[1, 2] = (wvno2 * cpy - cqx) / rho
    C[1, 3] = -(g_plus_g1 * a0_less_cpcq + gammk * xz + wvno2 * g1 * wy) / rho
    C[1, 4] = (cpz - wvno2 * cqw) / rho
    C[1, 5] = -(2 * wvno2 * a0_less_cpcq + xz + wvno2 * wvno2 * wy) / rho_sq
    C[2, 1] = (g_gk * cpz - g1_sq * cqw) * rho
    C[2, 2] = cpcq
    C[2, 3] = gammk * cpz - g1 * cqw
    C[2, 4] = -wz
    C[4, 1] = (g1_sq * cpy - g_gk * cqx) * rho
    C[4, 2] = -xy
    C[4, 3] = g1 * cpy - gammk * cqx
    C[5, 1] = -(2 * g_gk * g1_sq * a0_less_cpcq + g_gk * g_gk * xz + g1_sq * g1_sq * wy) * rho_sq
    C[5, 3] = -(gammk * g1 * g_plus_g1 * a0_less_cpcq + gam * gammk * gammk * xz + g1 * g1_sq * wy) * rho

    # Entries that duplicate one of the entries above.
    C[2, 5] = C[1, 4]
    C[4, 4] = C[2, 2]
    C[4, 5] = C[1, 2]
    C[5, 2] = C[4, 1]
    C[5, 4] = C[2, 1]
    C[5, 5] = C[1, 1]

    # Row 3 is derived from values already stored in C.
    C[3, 1] = row3_scale * C[5, 3]
    C[3, 2] = row3_scale * C[4, 3]
    C[3, 3] = a0 + 2 * (cpcq - C[1, 1])
    C[3, 4] = row3_scale * C[2, 3]
    C[3, 5] = row3_scale * C[1, 3]
    return nothing
end
"""
    dnka2!(C, wvno2, gam, gammk, rho, a0, cpcq, cpy, cpz, cqw, cqx, xy, xz, wy, wz)

Nested-container variant of `dnka!`: `C` is indexed as `C[i][j]`, i.e. a
collection of five mutable rows of five entries each. The same formulas are
used and all 25 entries are overwritten. Returns `nothing`.
"""
function dnka2!(C, wvno2, gam, gammk, rho, a0, cpcq, cpy, cpz, cqw, cqx, xy, xz, wy, wz)
    # Scalar factors shared by several entries.
    g1 = gam - 1
    g_plus_g1 = gam + g1
    g_gk = gam * gammk
    g_g1 = gam * g1
    g1_sq = g1 * g1
    a0_less_cpcq = a0 - cpcq
    row3_scale = -2 * wvno2
    rho_sq = rho * rho

    # Bind each row once; writes below go through these references.
    r1 = C[1]
    r2 = C[2]
    r3 = C[3]
    r4 = C[4]
    r5 = C[5]

    # Entries computed directly from the inputs.
    r1[1] = cpcq - 2 * g_g1 * a0_less_cpcq - g_gk * xz - wvno2 * g1_sq * wy
    r1[2] = (wvno2 * cpy - cqx) / rho
    r1[3] = -(g_plus_g1 * a0_less_cpcq + gammk * xz + wvno2 * g1 * wy) / rho
    r1[4] = (cpz - wvno2 * cqw) / rho
    r1[5] = -(2 * wvno2 * a0_less_cpcq + xz + wvno2 * wvno2 * wy) / rho_sq
    r2[1] = (g_gk * cpz - g1_sq * cqw) * rho
    r2[2] = cpcq
    r2[3] = gammk * cpz - g1 * cqw
    r2[4] = -wz
    r4[1] = (g1_sq * cpy - g_gk * cqx) * rho
    r4[2] = -xy
    r4[3] = g1 * cpy - gammk * cqx
    r5[1] = -(2 * g_gk * g1_sq * a0_less_cpcq + g_gk * g_gk * xz + g1_sq * g1_sq * wy) * rho_sq
    r5[3] = -(gammk * g1 * g_plus_g1 * a0_less_cpcq + gam * gammk * gammk * xz + g1 * g1_sq * wy) * rho

    # Entries that duplicate one of the entries above.
    r2[5] = r1[4]
    r4[4] = r2[2]
    r4[5] = r1[2]
    r5[2] = r4[1]
    r5[4] = r2[1]
    r5[5] = r1[1]

    # Row 3 is derived from values already stored in the other rows.
    r3[1] = row3_scale * r5[3]
    r3[2] = row3_scale * r4[3]
    r3[3] = a0 + 2 * (cpcq - r1[1])
    r3[4] = row3_scale * r2[3]
    r3[5] = row3_scale * r1[3]
    return nothing
end
"""
    var(p, q, ra, rb, wvno, xka, xkb, dpth)

Evaluate the P- and S-wave eigenfunction terms for one layer and form the
products used to build the compound (Dunkin) matrix.

For the P terms the branch is chosen by comparing `wvno` with `xka`
(oscillatory `sin`/`cos` when `wvno < xka`, the degenerate case when they are
equal, and exponentially scaled hyperbolic terms when `wvno > xka`); the S
terms are handled the same way using `q`, `rb` and `xkb`.

# Returns
The 12-tuple `(w, cosp, a0, cpcq, cpy, cpz, cqw, cqx, xy, xz, wy, wz)`, where
`a0 = exp(-(pex + sex))` and `pex`/`sex` are `p`/`q` in the `wvno > xka` /
`wvno > xkb` branches and zero otherwise.

If `wvno` compares false against `xka` or `xkb` in all three tests (e.g. NaN
input), the branch-local variables are never assigned and an `UndefVarError`
is raised when they are first used.
"""
function var(p, q, ra, rb, wvno, xka, xkb, dpth)
    # P-wave eigenfunction terms.
    pex = zero(p)
    cosp = zero(p)
    sinp = zero(p)
    if wvno < xka
        sinp = sin(p)
        w = sinp / ra
        x = -ra * sinp
        cosp = cos(p)
    elseif wvno == xka
        cosp = zero(p) + 1
        w = dpth
        x = zero(ra)
    elseif wvno > xka
        # cosh/sinh with the factor exp(p) pulled out; pex records the exponent.
        pex = p
        fac = exp(-2p)
        cosp = (1 + fac) * oftype(fac, 0.5)
        sinp = (1 - fac) * oftype(fac, 0.5)
        w = sinp / ra
        x = ra * sinp
    end

    # S-wave eigenfunction terms
    # (checking whether c > vs, c = vs or c < vs).
    sex = zero(q)
    if wvno < xkb
        sinq = sin(q)
        y = sinq / rb
        z = -rb * sinq
        cosq = cos(q)
    elseif wvno == xkb
        cosq = zero(q) + 1
        y = dpth
        z = zero(ra)
    elseif wvno > xkb
        # Same scaling as the P branch, with exponent recorded in sex.
        sex = q
        fac = exp(-2q)
        cosq = (1 + fac) * oftype(fac, 0.5)
        sinq = (1 - fac) * oftype(fac, 0.5)
        y = sinq / rb
        z = rb * sinq
    end

    # Form eigenfunction products for use with compound matrices.
    exa = pex + sex
    a0 = exp(-exa)
    cpcq = cosp * cosq
    cpy = cosp * y
    cpz = cosp * z
    cqw = cosq * w
    cqx = cosq * x
    xy = x * y
    xz = x * z
    wy = w * y
    wz = w * z
    # The former trailing rescaling of cosq, y and z by exp(sex - pex) was
    # removed: none of those values is returned, so it was dead work.
    return w, cosp, a0, cpcq, cpy, cpz, cqw, cqx, xy, xz, wy, wz
end
"""
    dltar(k, ω, model::myModel, e, ee, C)

Evaluate the scalar whose sign changes `get_c!` brackets and `find_c_` bisects,
for wavenumber `k` and angular frequency `ω`.

The five-element row `e` is first filled for the bottom half-space (the last
entry of each model array) and is then propagated upward through the remaining
layers: for every layer the eigenfunction products from `var` are turned into
Dunkin's matrix by `dnka!` (written into `C`), `ee = e * C` is formed with
`mul!`, and `e` is replaced by `ee` divided by its largest absolute entry.

`e`, `ee` and `C` are work buffers that are completely overwritten.
`ω` values below `1f-4` are replaced by `1f-4`. Returns `e[1]`.
"""
function dltar(k, ω, model::myModel, e, ee, C)
    p_vel = model.vp
    s_vel = model.m
    dens = model.ρ
    thick = model.h

    # Guard against a vanishing frequency.
    if ω < 1f-4
        ω = 1f-4 + zero(ω)
    end

    # E matrix for the bottom half-space.
    k_p = ω / p_vel[end]
    k_s = ω / s_vel[end]
    nu_p = sqrt(abs(k^2 - k_p^2))
    nu_s = sqrt(abs(k^2 - k_s^2))
    g_k = 2 * (s_vel[end] / ω)^2
    g = g_k * k * k
    g_m1 = g - 1
    dens_hs = dens[end]
    e[1] = dens_hs^2 * (g_m1 * g_m1 - g * g_k * nu_p * nu_s)
    e[2] = -dens_hs * nu_p
    e[3] = dens_hs * (g_m1 - g_k * nu_p * nu_s)
    e[4] = dens_hs * nu_s
    e[5] = k^2 - nu_p * nu_s

    # Matrix multiplication from the bottom layer upward.
    for layer in (length(s_vel) - 1):-1:1
        k_p = ω / p_vel[layer]
        k_s = ω / s_vel[layer]
        g_k = 2 * (s_vel[layer] / ω)^2
        g = g_k * k * k
        nu_p = sqrt(abs(k^2 - k_p^2))
        nu_s = sqrt(abs(k^2 - k_s^2))
        d = thick[layer]

        # Eigenfunction products; the first two returned values are not needed.
        _, _, a0, cpcq, cpy, cpz, cqw, cqx, xy, xz, wy, wz = var(
            nu_p * d, nu_s * d, nu_p, nu_s, k, k_p, k_s, d
        )

        # Dunkin's matrix for this layer (use dnka2! if C is a vector of vectors).
        dnka!(
            C, k * k, g, g_k, dens[layer], a0, cpcq, cpy, cpz, cqw, cqx, xy, xz, wy, wz
        )

        mul!(ee, e, C)

        # Rescale by the largest magnitude before moving to the next layer.
        scale = maximum(abs, ee)
        @. e = ee / scale
    end
    return e[1]
end
"""
    get_c!(resp_, t, m, mode, dc)

For every entry `t[i]`, locate a root in `c` of `dltar(ω / c, ω, m, ...)` with
`ω = 2π / t[i]` and store it in `resp_[i]`.

The search starts at 0.8 times the smallest entry of `m.m` and steps upward by
`dc` until the function value changes sign relative to the value at the start
point, or until the upper limit of 10 is passed (a warning is logged in that
case). The bracket `[c - 2dc, c]` around the sign change is handed to
`find_c_`. This is repeated `mode + 1` times, each pass restarting `dc` above
the previous root, and `resp_[i]` is overwritten on every pass, so it ends up
holding the root from the last pass.

Returns `nothing`.
"""
function get_c!(resp_, t, m, mode, dc)
    T = eltype(m.m)
    vmin, _ = extrema(m.m)
    c_max = 10 # upper limit of the scan; should not really go this far

    # Work buffers reused by every dltar evaluation (could be preallocated by the caller).
    e = MMatrix{1,5}(zeros(T, 1, 5))
    ee = MMatrix{1,5}(zeros(T, 1, 5))
    C = MMatrix{5,5}(zeros(T, 5, 5))

    residual(c, p) = dltar(p / c, p, m, e, ee, C)

    for i in eachindex(t) # iterations are independent of each other
        ω = 2π / t[i]
        c_start = vmin * 0.8
        c_scan = c_start
        for _ in 1:(mode + 1)
            f_start = residual(c_start, ω)
            # Step upward until the sign differs from the sign at c_start.
            while c_scan <= c_max
                if sign(residual(c_scan, ω)) * sign(f_start) < 0
                    break
                end
                c_scan += dc
                if c_scan > c_max
                    @warn "search space exceeded! t = $(t[i])"
                    break
                end
            end
            root = find_c_(c_scan - 2dc, c_scan, residual, ω)
            c_start = root + dc
            resp_[i] = root
        end
    end
    return nothing
end
# function find_c(prob, c1, c2, ω)
# prob_new = remake(prob; tspan = (c1, c2), p = ω)
# sol = solve(prob_new)
# return sol.u
# end
"""
    find_c_(c1, c2, f, ω)

Bisection search for a root of `c -> f(c, ω)` inside the bracket `[c1, c2]`.

`f(c1, ω)` and `f(c2, ω)` are expected to have opposite signs. At most 30
halvings are performed; the search stops early when the absolute function
value at the midpoint drops below `1f-9` or when the bracket becomes narrower
than `1f-9`. The last midpoint is returned.

Two defects of the earlier version are fixed here: the midpoint is now
recomputed on every iteration (previously it was computed once, so the bracket
never shrank and the initial midpoint was always returned), and convergence is
tested on `abs(f3)` (previously any negative value counted as converged).
"""
function find_c_(c1, c2, f, ω)
    f1 = f(c1, ω)
    c3 = (c1 + c2) / 2
    for _ in 1:30
        c3 = (c1 + c2) / 2
        f3 = f(c3, ω)
        if abs(f3) < 1f-9
            break
        elseif f3 * f1 < 0
            # Sign change lies in the lower half.
            c2 = c3
        else
            # Sign change lies in the upper half.
            c1 = c3
            f1 = f3
        end
        if abs(c2 - c1) < 1f-9
            c3 = (c1 + c2) / 2
            break
        end
    end
    return c3
end
# Homogeneous test model. `h_` has one entry fewer than the other arrays
# because `dltar` treats the last entry of each array as the bottom half-space.
h_ = fill(10.0, 8)
vp_ = fill(4.5 * 1.72, 9)
vs_ = fill(4.5, 9)
density_ = fill(2.0, 9)

# NOTE: the positional field order of `myModel` is (m, h, ρ, vp), so density
# must be passed third and the P velocity last. Passing (vs, h, vp, density)
# silently swaps the two.
my_m = myModel(vs_, h_, density_, vp_)
t_ = exp10.(range(0, 2, length = 100))
c_ = zero(t_)
get_c!(c_, t_, my_m, 0, 0.001)
c_

"""
    get_c_wrapper2!(c, m, t, h, vp, density, mode, dc)

Array-argument front end for `get_c!`: builds a `myModel` from the individual
arrays and writes the result into `c`. Returns `nothing`.
"""
function get_c_wrapper2!(c, m, t, h, vp, density, mode, dc)
    # Field order of myModel is (m, h, ρ, vp).
    M = myModel(m, h, density, vp)
    get_c!(c, t, M, mode, dc)
    return nothing
end

m0 = myModel(vs_, h_, density_, vp_)
cvec = zero(c_)
ad_backend = AutoEnzyme(; mode=Enzyme.set_runtime_activity(Enzyme.Reverse))
# ad_backend = AutoForwardDiff()
mode_ = 0
dc_ = 0.001
prep_j = prepare_jacobian(
    get_c_wrapper2!, cvec, ad_backend,
    m0.m, Constant(t_), Constant(m0.h), Constant(m0.vp), Constant(m0.ρ),
    Constant(mode_), Constant(dc_))
# One row per period, one column per entry of the differentiated array `m0.m`.
jac_ = zeros(length(cvec), length(m0.m));
DifferentiationInterface.jacobian!(get_c_wrapper2!, cvec, jac_, prep_j, ad_backend,
    m0.m, Constant(t_), Constant(m0.h), Constant(m0.vp), Constant(m0.ρ),
    Constant(mode_), Constant(dc_))
#
Error Stack
``` Enzyme cannot deduce type Current scope: ; Function Attrs: mustprogress willreturn define internal fastcc void @preprocess_julia_get_c__2981({} addrspace(10)* noundef nonnull align 16 dereferenceable(40) "enzyme_type"="{[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Float@double, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer, [-1,24]:Integer, [-1,25]:Integer, [-1,26]:Integer, [-1,27]:Integer, [-1,28]:Integer, [-1,29]:Integer, [-1,30]:Integer, [-1,31]:Integer, [-1,32]:Integer, [-1,33]:Integer, [-1,34]:Integer, [-1,35]:Integer, [-1,36]:Integer, [-1,37]:Integer, [-1,38]:Integer, [-1,39]:Integer}" "enzymejl_parmtype"="4728174032" "enzymejl_parmtype_ref"="2" %0, {} addrspace(10)* noundef nonnull align 16 dereferenceable(40) "enzyme_type"="{[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Float@double, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer, [-1,24]:Integer, [-1,25]:Integer, [-1,26]:Integer, [-1,27]:Integer, [-1,28]:Integer, [-1,29]:Integer, [-1,30]:Integer, [-1,31]:Integer, [-1,32]:Integer, [-1,33]:Integer, [-1,34]:Integer, [-1,35]:Integer, [-1,36]:Integer, [-1,37]:Integer, [-1,38]:Integer, [-1,39]:Integer}" "enzymejl_parmtype"="4728174032" "enzymejl_parmtype_ref"="2" %1, {} addrspace(10)* noundef nonnull align 8 dereferenceable(32) "enzyme_type"="{[-1]:Pointer, [-1,-1]:Pointer, [-1,-1,0]:Pointer, [-1,-1,0,-1]:Float@double, [-1,-1,8]:Integer, [-1,-1,9]:Integer, [-1,-1,10]:Integer, [-1,-1,11]:Integer, [-1,-1,12]:Integer, [-1,-1,13]:Integer, [-1,-1,14]:Integer, [-1,-1,15]:Integer, [-1,-1,16]:Integer, [-1,-1,17]:Integer, [-1,-1,18]:Integer, 
[-1,-1,19]:Integer, [-1,-1,20]:Integer, [-1,-1,21]:Integer, [-1,-1,22]:Integer, [-1,-1,23]:Integer, [-1,-1,24]:Integer, [-1,-1,25]:Integer, [-1,-1,26]:Integer, [-1,-1,27]:Integer, [-1,-1,28]:Integer, [-1,-1,29]:Integer, [-1,-1,30]:Integer, [-1,-1,31]:Integer, [-1,-1,32]:Integer, [-1,-1,33]:Integer, [-1,-1,34]:Integer, [-1,-1,35]:Integer, [-1,-1,36]:Integer, [-1,-1,37]:Integer, [-1,-1,38]:Integer, [-1,-1,39]:Integer}" "enzymejl_parmtype"="10972312848" "enzymejl_parmtype_ref"="2" %2, i64 signext "enzyme_inactive" "enzyme_type"="{[-1]:Integer}" "enzymejl_parmtype"="4756506320" "enzymejl_parmtype_ref"="0" %3, double "enzyme_type"="{[-1]:Float@double}" "enzymejl_parmtype"="4756505520" "enzymejl_parmtype_ref"="0" %4) unnamed_addr #182 !dbg !7930 { top: %phic = alloca {} addrspace(10)*, align 8 %phic1 = alloca {} addrspace(10)*, align 8 %phic284 = alloca i64, align 8 %phic285 = alloca double, align 8 %phic287 = alloca [4 x {} addrspace(10)*], align 8 %phic289 = alloca i64, align 8 %phic291 = alloca double, align 8 %phic293 = alloca i64, align 8 %phic295 = alloca double, align 8 %phic297 = alloca double, align 8 %phic10 = alloca {} addrspace(10)*, align 8 %phic299 = alloca i8, align 1 %phic301 = alloca i8, align 1 %phic13 = alloca {} addrspace(10)*, align 8 %phic14 = alloca {} addrspace(10)*, align 8 %phic15 = alloca {} addrspace(10)*, align 8A BUNCH OF STUFF
.pre534 = extractvalue [4 x {} addrspace(10)] %phic287.0.phic287.0.phic287.0.phic287.0.phic287.0.phic287.0.phic287.0.phic287.0.phic287.0., 1, !dbg !391, !enzyme_type !328: {[-1]:Pointer, [-1,-1]:Float@double}, intvals: {}
%.pre535 = extractvalue [4 x {} addrspace(10)] %phic287.0.phic287.0.phic287.0.phic287.0.phic287.0.phic287.0.phic287.0.phic287.0.phic287.0., 2, !dbg !391, !enzyme_type !328: {[-1]:Pointer, [-1,-1]:Float@double}, intvals: {}
%.pre536 = extractvalue [4 x {} addrspace(10)] %phic287.0.phic287.0.phic287.0.phic287.0.phic287.0.phic287.0.phic287.0.phic287.0.phic287.0., 3, !dbg !391, !enzyme_type !328: {[-1]:Pointer, [-1,-1]:Float@double}, intvals: {}
%141 = fmul double %value_phi215, 2.000000e+00, !dbg !392: {[-1]:Float@double}, intvals: {}
%142 = fsub double %value_phi220, %141, !dbg !395: {[-1]:Float@double}, intvals: {}
%144 = add i64 %value_phi219, -1, !dbg !397: {[-1]:Integer}, intvals: {0,}
%inbounds224 = icmp ult i64 %144, %arraylen223, !dbg !397: {[-1]:Integer}, intvals: {}
%.not437 = icmp eq i64 %iv.next, %arraylen173, !dbg !399: {[-1]:Integer}, intvals: {}
%146 = add nuw nsw i64 %iv.next, 1, !dbg !400: {[-1]:Integer}, intvals: {2,}
%148 = add i64 %value_phi180, 1, !dbg !403: {[-1]:Integer}, intvals: {}
%149 = icmp ugt i64 %value_phi180, 9223372036854775806, !dbg !406: {[-1]:Integer}, intvals: {}
%150 = fmul double %value_phi186, 8.000000e-01, !dbg !410: {[-1]:Float@double}, intvals: {}
%153 = fdiv double 0x401921FB54442D18, %arrayref191, !dbg !415: {[-1]:Float@double}, intvals: {}
%155 = fadd double %value_phi215, %143, !dbg !416: {[-1]:Float@double}, intvals: {}
%.not436 = icmp eq i64 %iv.next6, %148, !dbg !418: {[-1]:Integer}, intvals: {}
%158 = add nuw i64 %iv.next6, 1, !dbg !419: {[-1]:Integer}, intvals: {2,}
%160 = icmp eq {} addrspace(10) %131, addrspacecast ({}* inttoptr (i64 4824648176 to {}) to {} addrspace(10)), !dbg !379: {[-1]:Integer}, intvals: {}
%161 = add i64 %value_phi204, -1, !dbg !421: {[-1]:Integer}, intvals: {0,}
%inbounds248 = icmp ult i64 %161, %arraylen247, !dbg !421: {[-1]:Integer}, intvals: {}
%iv.next8 = add nuw nsw i64 %iv7, 1, !dbg !330: {[-1]:Integer}, intvals: {1,}
%iv.next6 = add nuw nsw i64 %iv5, 1, !dbg !322: {[-1]:Integer}, intvals: {1,}
%iv.next = add nuw nsw i64 %iv, 1, !dbg !320: {[-1]:Integer}, intvals: {1,}
Cannot deduce type of phi %value_phi238 = phi double [ %value_phi186, %idxend ], [ %value_phi218, %L379.loopexit ]{} sz: 8
Caused by:
Stacktrace:
[1] ==
@ ./promotion.jl:521
[2] iterate
@ ./range.jl:901
[3] get_c!
@ ./REPL[40]:60
within MethodInstance for get_c!(::Vector{Float64}, ::Vector{Float64}, ::myModel{Vector{Float64}, Vector{Float64}, Vector{Float64}, Vector{Float64}}, ::Int64, ::Float64)
Stacktrace:
[1] iterate
@ ./range.jl:901 [inlined]
[2] fill!
@ ./array.jl:396 [inlined]
[3] zeros
@ ./array.jl:637 [inlined]
[4] zeros
@ ./array.jl:632 [inlined]
[5] get_c!
@ ./REPL[40]:9
[6] myModel
@ ./REPL[35]:2 [inlined]
[7] get_c_wrapper2!
@ ./REPL[53]:2 [inlined]
[8] diffe15julia_get_c_wrapper2__2978wrap
@ ./REPL[53]:0
[9] macro expansion
@ ~/.julia/packages/Enzyme/iosr4/src/compiler.jl:5691 [inlined]
[10] enzyme_call
@ ~/.julia/packages/Enzyme/iosr4/src/compiler.jl:5225 [inlined]
[11] CombinedAdjointThunk
@ ~/.julia/packages/Enzyme/iosr4/src/compiler.jl:5100 [inlined]
[12] autodiff
@ ~/.julia/packages/Enzyme/iosr4/src/Enzyme.jl:517 [inlined]
[13] autodiff
@ ~/.julia/packages/Enzyme/iosr4/src/Enzyme.jl:538 [inlined]
[14] value_and_pullback!(::typeof(get_c_wrapper2!), ::Vector{…}, ::NTuple{…}, ::DifferentiationInterfaceEnzymeExt.EnzymeReverseTwoArgPullbackPrep{…}, ::AutoEnzyme{…}, ::Vector{…}, ::NTuple{…}, ::Constant{…}, ::Constant{…}, ::Constant{…}, ::Constant{…}, ::Constant{…}, ::Constant{…})
@ DifferentiationInterfaceEnzymeExt ~/.julia/packages/DifferentiationInterface/zJHX8/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl:166
[15] pullback!
@ ~/.julia/packages/DifferentiationInterface/zJHX8/src/first_order/pullback.jl:557 [inlined]
[16] _jacobian_aux!(::Tuple{…}, ::Matrix{…}, ::DifferentiationInterface.PullbackJacobianPrep{…}, ::AutoEnzyme{…}, ::Vector{…}, ::Constant{…}, ::Constant{…}, ::Constant{…}, ::Constant{…}, ::Constant{…}, ::Constant{…})
@ DifferentiationInterface ~/.julia/packages/DifferentiationInterface/zJHX8/src/first_order/jacobian.jl:504
[17] jacobian!(::typeof(get_c_wrapper2!), ::Vector{…}, ::Matrix{…}, ::DifferentiationInterface.PullbackJacobianPrep{…}, ::AutoEnzyme{…}, ::Vector{…}, ::Constant{…}, ::Constant{…}, ::Constant{…}, ::Constant{…}, ::Constant{…}, ::Constant{…})
@ DifferentiationInterface ~/.julia/packages/DifferentiationInterface/zJHX8/src/first_order/jacobian.jl:316
[18] top-level scope
@ REPL[65]:1
Some type information was truncated. Use show(err) to see complete types.