Skip to content

Commit fcfedb2

Browse files
committed
Add forward and reverse rules and tests for Enzyme
1 parent 252809f commit fcfedb2

4 files changed

Lines changed: 438 additions & 1 deletion

File tree

Project.toml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,20 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
88

99
[weakdeps]
1010
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
11+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
1112
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
1213

1314
[extensions]
1415
VectorInterfaceChainRulesCoreExt = "ChainRulesCore"
16+
VectorInterfaceEnzymeExt = "Enzyme"
1517
VectorInterfaceMooncakeExt = "Mooncake"
1618

1719
[compat]
1820
Aqua = "0.6, 0.7, 0.8"
1921
ChainRulesCore = "1"
2022
ChainRulesTestUtils = "1"
23+
Enzyme = "0.13.131"
24+
EnzymeTestUtils = "0.2.6"
2125
LinearAlgebra = "1"
2226
Mooncake = "0.5"
2327
Random = "1"
@@ -29,10 +33,12 @@ julia = "1"
2933
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
3034
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3135
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
36+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
37+
EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a"
3238
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
3339
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
3440
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3541
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
3642

3743
[targets]
38-
test = ["Test", "TestExtras", "Aqua", "ChainRulesTestUtils", "ChainRulesCore", "Mooncake", "Random"]
44+
test = ["Test", "TestExtras", "Aqua", "ChainRulesTestUtils", "ChainRulesCore", "Mooncake", "Enzyme", "EnzymeTestUtils", "Random"]

ext/VectorInterfaceEnzymeExt.jl

Lines changed: 284 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,284 @@
1+
module VectorInterfaceEnzymeExt
2+
3+
using VectorInterface
4+
using Enzyme
5+
using Enzyme.EnzymeCore
6+
using Enzyme.EnzymeCore: EnzymeRules
7+
8+
"""
9+
project_scalar(x::Number, dx::Number)
10+
11+
Project a computed tangent `dx` onto the correct tangent type for `x`.
12+
For example, we might compute a complex `dx` but only require the real part.
13+
"""
14+
project_scalar(x::Number, dx::Number) = oftype(x, dx)
15+
project_scalar(x::Real, dx::Complex) = project_scalar(x, real(dx))
16+
17+
function EnzymeRules.augmented_primal(
18+
config::EnzymeRules.RevConfigWidth{1},
19+
func::Const{typeof(scale!)},
20+
::Type{RT},
21+
C::Annotation,
22+
α::Annotation{<:Number},
23+
) where {RT}
24+
dret = !isa(C, Const) ? C.dval : nothing
25+
cacheα = EnzymeRules.overwritten(config)[3] ? copy.val) : α.val
26+
cache = (cacheα, copy(C.val)) # is this better than just unscaling?
27+
ret = scale!(C.val, α.val)
28+
shadow = EnzymeRules.needs_shadow(config) ? dret : nothing
29+
primal = EnzymeRules.needs_primal(config) ? ret : nothing
30+
return EnzymeRules.AugmentedReturn(primal, shadow, cache)
31+
end
32+
33+
function EnzymeRules.reverse(
34+
config::EnzymeRules.RevConfigWidth{1},
35+
func::Const{typeof(scale!)},
36+
::Type{RT},
37+
cache,
38+
C::Annotation,
39+
α::Annotation{<:Number},
40+
) where {RT}
41+
αval, Cval = cache
42+
Δα = if !isa(α, Const) && !isa(C, Const)
43+
project_scalar.val, inner(Cval, C.dval))
44+
elseif !isa(α, Const)
45+
zero.val)
46+
else
47+
nothing
48+
end
49+
scale!(C.dval, conj(αval))
50+
return (nothing, Δα)
51+
end
52+
53+
function EnzymeRules.forward(
54+
config::EnzymeRules.FwdConfigWidth{1},
55+
func::Const{typeof(scale!)},
56+
::Type{RT},
57+
C::Annotation,
58+
α::Annotation{<:Number},
59+
) where {RT}
60+
if !isa(α, Const) && !isa(C, Const)
61+
add!(C.dval, C.val, α.dval, α.val)
62+
elseif !isa(C, Const)
63+
scale!(C.dval, α.val)
64+
end
65+
scale!(C.val, α.val)
66+
if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
67+
return C
68+
elseif EnzymeRules.needs_primal(config)
69+
return C.val
70+
elseif EnzymeRules.needs_shadow(config)
71+
return C.dval
72+
else
73+
return nothing
74+
end
75+
end
76+
77+
function EnzymeRules.augmented_primal(
78+
config::EnzymeRules.RevConfigWidth{1},
79+
func::Const{typeof(scale!)},
80+
::Type{RT},
81+
C::Annotation,
82+
A::Annotation,
83+
α::Annotation{<:Number},
84+
) where {RT}
85+
cacheA = EnzymeRules.overwritten(config)[3] ? copy(A.val) : A.val
86+
cacheα = EnzymeRules.overwritten(config)[4] ? copy.val) : α.val
87+
cache = (cacheA, cacheα)
88+
ret = scale!(C.val, A.val, α.val)
89+
dret = !isa(C, Const) ? C.dval : nothing
90+
shadow = EnzymeRules.needs_shadow(config) ? dret : nothing
91+
primal = EnzymeRules.needs_primal(config) ? ret : nothing
92+
return EnzymeRules.AugmentedReturn(primal, shadow, cache)
93+
end
94+
95+
function EnzymeRules.reverse(
96+
config::EnzymeRules.RevConfigWidth{1},
97+
func::Const{typeof(scale!)},
98+
::Type{RT},
99+
cache,
100+
C::Annotation,
101+
A::Annotation,
102+
α::Annotation{<:Number},
103+
) where {RT}
104+
Aval, αval = cache
105+
!isa(A, Const) && !isa(C, Const) && add!(A.dval, C.dval, conj(αval))
106+
Δα = if !isa(α, Const) && !isa(C, Const)
107+
project_scalar.val, inner(Aval, C.dval))
108+
elseif !isa(α, Const)
109+
zero.val)
110+
else
111+
nothing
112+
end
113+
zerovector!(C.dval)
114+
return (nothing, nothing, Δα)
115+
end
116+
117+
function EnzymeRules.forward(
118+
config::EnzymeRules.FwdConfigWidth{1},
119+
func::Const{typeof(scale!)},
120+
::Type{RT},
121+
C::Annotation,
122+
A::Annotation,
123+
α::Annotation{<:Number},
124+
) where {RT}
125+
scale!(C.val, A.val, α.val)
126+
!isa(C, Const) && !isa(A, Const) && scale!(C.dval, A.dval, α.val)
127+
!isa(α, Const) && !isa(C, Const) && add!(C.dval, A.val, α.dval, One())
128+
if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
129+
return C
130+
elseif EnzymeRules.needs_primal(config)
131+
return C.val
132+
elseif EnzymeRules.needs_shadow(config)
133+
return C.dval
134+
else
135+
return nothing
136+
end
137+
end
138+
139+
function EnzymeRules.augmented_primal(
140+
config::EnzymeRules.RevConfigWidth{1},
141+
func::Const{typeof(add!)},
142+
::Type{RT},
143+
C::Annotation,
144+
A::Annotation,
145+
α::Annotation{<:Number},
146+
β::Annotation{<:Number},
147+
) where {RT}
148+
dret = !isa(C, Const) ? C.dval : nothing
149+
# only need copy of A if α is not constant
150+
cacheA = !isa(α, Const) && EnzymeRules.overwritten(config)[3] ? copy(A.val) : A.val
151+
cacheα = EnzymeRules.overwritten(config)[4] ? copy.val) : α.val
152+
cacheβ = EnzymeRules.overwritten(config)[5] ? copy.val) : β.val
153+
# only need copy of C if β is not constant
154+
cacheC = !isa(β, Const) ? copy(C.val) : C.val
155+
cache = (cacheA, cacheα, cacheβ, cacheC)
156+
ret = add!(C.val, A.val, α.val, β.val)
157+
shadow = EnzymeRules.needs_shadow(config) ? dret : nothing
158+
primal = EnzymeRules.needs_primal(config) ? ret : nothing
159+
return EnzymeRules.AugmentedReturn(primal, shadow, cache)
160+
end
161+
162+
function EnzymeRules.reverse(
163+
config::EnzymeRules.RevConfigWidth{1},
164+
func::Const{typeof(add!)},
165+
::Type{RT},
166+
cache,
167+
C::Annotation,
168+
A::Annotation,
169+
α::Annotation{<:Number},
170+
β::Annotation{<:Number},
171+
) where {RT}
172+
Aval, αval, βval, Cval = cache
173+
Δα = if !isa(α, Const) && !isa(C, Const)
174+
project_scalar.val, inner(Aval, C.dval))
175+
elseif !isa(α, Const)
176+
zero.val)
177+
else
178+
nothing
179+
end
180+
Δβ = if !isa(β, Const) && !isa(C, Const)
181+
project_scalar.val, inner(Cval, C.dval))
182+
elseif !isa(β, Const)
183+
zero.val)
184+
else
185+
nothing
186+
end
187+
!isa(A, Const) && !isa(C, Const) && add!(A.dval, C.dval, conj(αval))
188+
!isa(C, Const) && scale!(C.dval, conj(βval))
189+
return (nothing, nothing, Δα, Δβ)
190+
end
191+
192+
function EnzymeRules.forward(
193+
config::EnzymeRules.FwdConfigWidth{1},
194+
func::Const{typeof(add!)},
195+
::Type{RT},
196+
C::Annotation,
197+
A::Annotation,
198+
α::Annotation{<:Number},
199+
β::Annotation{<:Number},
200+
) where {RT}
201+
!isa(C, Const) && !isa(A, Const) && add!(C.dval, A.dval, α.val, β.val)
202+
!isa(C, Const) && !isa(α, Const) && add!(C.dval, A.val, α.dval, One())
203+
!isa(C, Const) && !isa(β, Const) && add!(C.dval, C.val, β.dval, One())
204+
add!(C.val, A.val, α.val, β.val)
205+
if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
206+
return C
207+
elseif EnzymeRules.needs_primal(config)
208+
return C.val
209+
elseif EnzymeRules.needs_shadow(config)
210+
return C.dval
211+
else
212+
return nothing
213+
end
214+
end
215+
216+
function EnzymeRules.augmented_primal(
217+
config::EnzymeRules.RevConfigWidth{1},
218+
func::Const{typeof(inner)},
219+
::Type{RT},
220+
A::Annotation,
221+
B::Annotation,
222+
) where {RT}
223+
cacheA = EnzymeRules.overwritten(config)[2] ? copy(A.val) : A.val
224+
cacheB = EnzymeRules.overwritten(config)[3] ? copy(B.val) : B.val
225+
cache = (cacheA, cacheB)
226+
ret = inner(A.val, B.val)
227+
shadow = EnzymeRules.needs_shadow(config) ? zero(ret) : nothing
228+
primal = EnzymeRules.needs_primal(config) ? ret : nothing
229+
return EnzymeRules.AugmentedReturn(primal, shadow, cache)
230+
end
231+
232+
function EnzymeRules.reverse(
233+
config::EnzymeRules.RevConfigWidth{1},
234+
func::Const{typeof(inner)},
235+
dret::Active,
236+
cache,
237+
A::Annotation,
238+
B::Annotation,
239+
)
240+
ΔS = dret.val
241+
Aval, Bval = cache
242+
!isa(A, Const) && add!(A.dval, Bval, conj(ΔS))
243+
!isa(B, Const) && add!(B.dval, Aval, ΔS)
244+
return (nothing, nothing)
245+
end
246+
247+
function EnzymeRules.reverse(
248+
config::EnzymeRules.RevConfigWidth{1},
249+
func::Const{typeof(inner)},
250+
RT::Type{<:Const},
251+
cache,
252+
A::Annotation,
253+
B::Annotation,
254+
)
255+
return (nothing, nothing)
256+
end
257+
258+
function EnzymeRules.forward(
259+
config::EnzymeRules.FwdConfigWidth{1},
260+
func::Const{typeof(inner)},
261+
::Type{RT},
262+
A::Annotation,
263+
B::Annotation,
264+
) where {RT}
265+
ret = inner(A.val, B.val)
266+
if EnzymeRules.needs_shadow(config) # only compute this if actually needed
267+
dret = zero(ret)
268+
!isa(A, Const) && (dret += inner(A.dval, B.val))
269+
!isa(B, Const) && (dret += inner(A.val, B.dval))
270+
else
271+
dret = nothing
272+
end
273+
if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
274+
return Duplicated(ret, dret)
275+
elseif EnzymeRules.needs_primal(config)
276+
return ret
277+
elseif EnzymeRules.needs_shadow(config)
278+
return dret
279+
else
280+
return nothing
281+
end
282+
end
283+
284+
end

0 commit comments

Comments
 (0)