Skip to content

Commit fd58573

Browse files
committed
Handle mixed types for inner properly
1 parent a0cf90f commit fd58573

4 files changed

Lines changed: 31 additions & 12 deletions

File tree

ext/VectorInterfaceEnzymeExt.jl

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,16 @@ function EnzymeRules.forward(
215215
end
216216
end
217217

218+
function project_add!(C, A, α)
219+
TC = Base.promote_op(+, scalartype(A), scalartype(α))
220+
return if !(TC <: Real) && scalartype(C) <: Real
221+
add!(C, real(add!(zerovector(C, TC), A, α)))
222+
else
223+
add!(C, A, α)
224+
end
225+
end
226+
227+
218228
function EnzymeRules.augmented_primal(
219229
config::EnzymeRules.RevConfigWidth{1},
220230
func::Const{typeof(inner)},
@@ -241,8 +251,8 @@ function EnzymeRules.reverse(
241251
)
242252
ΔS = dret.val
243253
Aval, Bval = cache
244-
!isa(A, Const) && add!(A.dval, Bval, conj(ΔS))
245-
!isa(B, Const) && add!(B.dval, Aval, ΔS)
254+
!isa(A, Const) && project_add!(A.dval, Bval, conj(ΔS))
255+
!isa(B, Const) && project_add!(B.dval, Aval, ΔS)
246256
return (nothing, nothing)
247257
end
248258

ext/VectorInterfaceMooncakeExt.jl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,15 @@ end
140140
# inner
141141
# -----
142142

143+
function project_add!(C, A, α)
144+
TC = Base.promote_op(+, scalartype(A), scalartype(α))
145+
return if !(TC <: Real) && scalartype(C) <: Real
146+
add!(C, real(add!(zerovector(C, TC), A, α)))
147+
else
148+
add!(C, A, α)
149+
end
150+
end
151+
143152
@is_primitive DefaultCtx Tuple{typeof(inner), AbstractArray, AbstractArray}
144153

145154
function Mooncake.rrule!!(::CoDual{typeof(inner)}, A_ΔA::CoDual, B_ΔB::CoDual)
@@ -151,8 +160,8 @@ function Mooncake.rrule!!(::CoDual{typeof(inner)}, A_ΔA::CoDual, B_ΔB::CoDual)
151160
s = inner(A, B)
152161

153162
function inner_pullback(Δs)
154-
add!(ΔA, B, conj(Δs))
155-
add!(ΔB, A, Δs)
163+
project_add!(ΔA, B, conj(Δs))
164+
project_add!(ΔB, A, Δs)
156165
return NoRData(), NoRData(), NoRData()
157166
end
158167

test/enzyme.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,13 +104,13 @@ end
104104
end
105105
end
106106

107-
@testset "inner ($T)" for T in eltypes
107+
@testset "inner ($Tx, $Ty)" for Tx in eltypes, Ty in eltypes
108108
n = 12
109-
atol = rtol = n * precision(T)
109+
atol = rtol = n * max(precision(Tx), precision(Ty))
110110

111111
# Vector
112-
x = randn(T, n)
113-
y = randn(T, n)
112+
x = randn(Tx, n)
113+
y = randn(Ty, n)
114114
for RT in (Const, Active)
115115
test_reverse(inner, RT, (x, Duplicated), (y, Duplicated); atol, rtol)
116116
end

test/mooncake.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,13 +73,13 @@ end
7373
Mooncake.TestUtils.test_rule(rng, add!!, my, mx, α, β; atol, rtol, is_primitive = false)
7474
end
7575

76-
@testset "inner pullbacks ($T)" for T in eltypes
76+
@testset "inner pullbacks ($Tx, $Ty)" for Tx in eltypes, Ty in eltypes
7777
n = 12
78-
atol = rtol = n * precision(T)
78+
atol = rtol = n * max(precision(Tx), precision(Ty))
7979

8080
# Vector
81-
x = randn(T, n)
82-
y = randn(T, n)
81+
x = randn(Tx, n)
82+
y = randn(Ty, n)
8383
Mooncake.TestUtils.test_rule(rng, inner, x, y; atol, rtol, is_primitive = false)
8484

8585
# MinimalMVec

0 commit comments

Comments
 (0)