Skip to content

Commit 9960fe6

Browse files
Copilotwsmoses
andauthored
Add regression test for Julia 1.12 custom-rule calling convention rewrite (#3196)
* Initial plan * Add callconv regression test * fix --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: Billy Moses <wmoses@google.com>
1 parent 1b49890 commit 9960fe6

2 files changed

Lines changed: 142 additions & 6 deletions

File tree

src/rules/customrules.jl

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -736,22 +736,47 @@ function enzyme_custom_setup_args(
736736
@assert ival !== nothing
737737

738738
for idx = 1:width
739+
local_shadow_root = if roots_ival !== nothing
740+
(width == 1) ? roots_ival : extract_value!(B, roots_ival, idx - 1)
741+
end
742+
739743
if !is_constant_value(gutils, op)
740744
ev =
741745
(width == 1) ? ptr_val : extract_value!(B, ptr_val, idx - 1)
742-
ld = load!(B, iarty, ev, "rules_shadow_load")
743-
metadata(ld)["enzyme_mustcache"] = MDNode(LLVM.Metadata[])
746+
747+
ld = if uncache_arg
748+
if !reverse
749+
ld0 = load!(B, iarty, ev, "rules_shadow_load")
750+
metadata(ld0)["enzyme_mustcache"] = MDNode(LLVM.Metadata[])
751+
if roots_op != nothing
752+
if uncacheable[arg.arg_i + 1] != 0
753+
ld0 = recombine_value!(B, ld0, local_shadow_root)
754+
else
755+
ld0 = nullify_rooted_values!(B, ld0)
756+
end
757+
end
758+
push!(byval_tapes, ld0)
759+
ld0
760+
else
761+
@assert tape isa LLVM.Value
762+
ld0 = extract_value!(B, tape, length(byval_tapes), "shadow_roots_op_extract_v1_")
763+
@assert value_type(ld0) == iarty
764+
push!(byval_tapes, ld0)
765+
ld0
766+
end
767+
else
768+
ld0 = load!(B, iarty, ev, "rules_shadow_load")
769+
metadata(ld0)["enzyme_mustcache"] = MDNode(LLVM.Metadata[])
770+
ld0
771+
end
772+
744773
ival = (width == 1) ? ld : insert_value!(B, ival, ld, idx - 1)
745774
@assert ival !== nothing
746775
else
747776
ival = (width == 1) ? ptr_val : insert_value!(B, ival, ptr_val, idx - 1)
748777
@assert ival !== nothing
749778
end
750779

751-
local_shadow_root = if roots_ival !== nothing
752-
(width == 1) ? roots_ival : extract_value!(B, roots_ival, idx - 1)
753-
end
754-
755780
if shadow_roots !== nothing
756781

757782
for r = 1:n_primal_roots

test/callconv.jl

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using Enzyme, Test
2+
using Enzyme: EnzymeRules
23

34
@noinline function force_stup(A)
45
A11 = A[];
@@ -141,3 +142,113 @@ g_test(p) = sum(Enzyme.gradient(Enzyme.set_runtime_activity(Enzyme.Reverse), inn
141142
@test dp [3.18, 2.68, 2.82]
142143
end
143144

145+
abstract type AbstractDomainCallConv end
146+
147+
mutable struct InnerPlanCallConv
148+
b::Float64
149+
end
150+
151+
struct MyPlanCallConv
152+
a::Int64
153+
b::Float64
154+
c::Int32
155+
d::Int32
156+
plan::InnerPlanCallConv
157+
phases::Vector{Float64}
158+
indices::Tuple{Vector{Int64}, Vector{Int64}}
159+
h::Int64
160+
end
161+
162+
struct MyDomainCallConv <: AbstractDomainCallConv
163+
plan::MyPlanCallConv
164+
end
165+
166+
@noinline forward_plan_callconv(g::AbstractDomainCallConv) = getfield(g, :plan)
167+
EnzymeRules.inactive(::typeof(forward_plan_callconv), args...) = nothing
168+
169+
@noinline getplan_callconv(p::MyPlanCallConv) = getfield(p, :plan)
170+
EnzymeRules.inactive(::typeof(getplan_callconv), args...) = nothing
171+
172+
@noinline getindices_callconv(p::MyPlanCallConv) = getfield(p, :indices)
173+
EnzymeRules.inactive(::typeof(getindices_callconv), args...) = nothing
174+
175+
@noinline function my_nuft_callconv!(out, A, b)
176+
out .= b .* A.b
177+
return nothing
178+
end
179+
180+
function EnzymeRules.augmented_primal(
181+
config::EnzymeRules.RevConfigWidth,
182+
func::Const{typeof(my_nuft_callconv!)},
183+
::Type{<:Const},
184+
out::EnzymeRules.Annotation,
185+
A::EnzymeRules.Annotation,
186+
b::EnzymeRules.Annotation,
187+
)
188+
primal = EnzymeRules.needs_primal(config) ? out.val : nothing
189+
shadow = EnzymeRules.needs_shadow(config) ? out.dval : nothing
190+
func.val(out.val, A.val, b.val)
191+
return EnzymeRules.AugmentedReturn(primal, shadow, nothing)
192+
end
193+
194+
function EnzymeRules.reverse(
195+
config::EnzymeRules.RevConfigWidth,
196+
::Const{typeof(my_nuft_callconv!)},
197+
::Type{RT},
198+
tape,
199+
out::EnzymeRules.Annotation,
200+
A::EnzymeRules.Annotation,
201+
b::EnzymeRules.Annotation,
202+
) where {RT}
203+
b.dval .+= out.dval .* A.val.b
204+
fill!(out.dval, 0)
205+
return (nothing, nothing, nothing)
206+
end
207+
208+
function applyphases_callconv!(vis, phases)
209+
for i in eachindex(vis, phases)
210+
vis[i] *= phases[i]
211+
end
212+
return vis
213+
end
214+
215+
@inline function applyft_callconv(p, img)
216+
vis = similar(img)
217+
plan = getplan_callconv(p)
218+
iminds, visinds = getindices_callconv(p)
219+
for i in eachindex(iminds, visinds)
220+
imind = iminds[i]
221+
visind = visinds[i]
222+
vis_view = @view(vis[visind:visind])
223+
img_view = @view(img[imind:imind])
224+
my_nuft_callconv!(vis_view, plan, img_view)
225+
end
226+
applyphases_callconv!(vis, p.phases)
227+
return vis
228+
end
229+
230+
@noinline function visibilitymap_numeric_callconv(grid::AbstractDomainCallConv, img::Vector{Float64})
231+
return applyft_callconv(forward_plan_callconv(grid), img)
232+
end
233+
234+
@noinline function foo_callconv(grid::AbstractDomainCallConv, img)
235+
return sum(visibilitymap_numeric_callconv(grid, img))
236+
end
237+
238+
@testset "Custom rule calling conv rewrite" begin
239+
inner = InnerPlanCallConv(2.0)
240+
plan = MyPlanCallConv(1, 2.0, 3, 4, inner, [2.0, 3.0, 4.0], ([1, 2, 3], [1, 2, 3]), 7)
241+
grid = MyDomainCallConv(plan)
242+
img = [1.0, 2.0, 3.0]
243+
dimg = zeros(3)
244+
245+
Enzyme.autodiff(
246+
Enzyme.set_runtime_activity(Enzyme.Reverse),
247+
foo_callconv,
248+
Active,
249+
Const(grid),
250+
Duplicated(img, dimg),
251+
)
252+
253+
@test dimg == [4.0, 6.0, 8.0]
254+
end

0 commit comments

Comments
 (0)