|
1 | 1 | using Enzyme, Test |
| 2 | +using Enzyme: EnzymeRules |
2 | 3 |
|
3 | 4 | @noinline function force_stup(A) |
4 | 5 | A11 = A[]; |
@@ -141,3 +142,113 @@ g_test(p) = sum(Enzyme.gradient(Enzyme.set_runtime_activity(Enzyme.Reverse), inn |
141 | 142 | @test dp ≈ [3.18, 2.68, 2.82] |
142 | 143 | end |
143 | 144 |
|
| 145 | +abstract type AbstractDomainCallConv end |
| 146 | + |
| 147 | +mutable struct InnerPlanCallConv |
| 148 | + b::Float64 |
| 149 | +end |
| 150 | + |
| 151 | +struct MyPlanCallConv |
| 152 | + a::Int64 |
| 153 | + b::Float64 |
| 154 | + c::Int32 |
| 155 | + d::Int32 |
| 156 | + plan::InnerPlanCallConv |
| 157 | + phases::Vector{Float64} |
| 158 | + indices::Tuple{Vector{Int64}, Vector{Int64}} |
| 159 | + h::Int64 |
| 160 | +end |
| 161 | + |
| 162 | +struct MyDomainCallConv <: AbstractDomainCallConv |
| 163 | + plan::MyPlanCallConv |
| 164 | +end |
| 165 | + |
| 166 | +@noinline forward_plan_callconv(g::AbstractDomainCallConv) = getfield(g, :plan) |
| 167 | +EnzymeRules.inactive(::typeof(forward_plan_callconv), args...) = nothing |
| 168 | + |
| 169 | +@noinline getplan_callconv(p::MyPlanCallConv) = getfield(p, :plan) |
| 170 | +EnzymeRules.inactive(::typeof(getplan_callconv), args...) = nothing |
| 171 | + |
| 172 | +@noinline getindices_callconv(p::MyPlanCallConv) = getfield(p, :indices) |
| 173 | +EnzymeRules.inactive(::typeof(getindices_callconv), args...) = nothing |
| 174 | + |
| 175 | +@noinline function my_nuft_callconv!(out, A, b) |
| 176 | + out .= b .* A.b |
| 177 | + return nothing |
| 178 | +end |
| 179 | + |
| 180 | +function EnzymeRules.augmented_primal( |
| 181 | + config::EnzymeRules.RevConfigWidth, |
| 182 | + func::Const{typeof(my_nuft_callconv!)}, |
| 183 | + ::Type{<:Const}, |
| 184 | + out::EnzymeRules.Annotation, |
| 185 | + A::EnzymeRules.Annotation, |
| 186 | + b::EnzymeRules.Annotation, |
| 187 | +) |
| 188 | + primal = EnzymeRules.needs_primal(config) ? out.val : nothing |
| 189 | + shadow = EnzymeRules.needs_shadow(config) ? out.dval : nothing |
| 190 | + func.val(out.val, A.val, b.val) |
| 191 | + return EnzymeRules.AugmentedReturn(primal, shadow, nothing) |
| 192 | +end |
| 193 | + |
| 194 | +function EnzymeRules.reverse( |
| 195 | + config::EnzymeRules.RevConfigWidth, |
| 196 | + ::Const{typeof(my_nuft_callconv!)}, |
| 197 | + ::Type{RT}, |
| 198 | + tape, |
| 199 | + out::EnzymeRules.Annotation, |
| 200 | + A::EnzymeRules.Annotation, |
| 201 | + b::EnzymeRules.Annotation, |
| 202 | +) where {RT} |
| 203 | + b.dval .+= out.dval .* A.val.b |
| 204 | + fill!(out.dval, 0) |
| 205 | + return (nothing, nothing, nothing) |
| 206 | +end |
| 207 | + |
| 208 | +function applyphases_callconv!(vis, phases) |
| 209 | + for i in eachindex(vis, phases) |
| 210 | + vis[i] *= phases[i] |
| 211 | + end |
| 212 | + return vis |
| 213 | +end |
| 214 | + |
| 215 | +@inline function applyft_callconv(p, img) |
| 216 | + vis = similar(img) |
| 217 | + plan = getplan_callconv(p) |
| 218 | + iminds, visinds = getindices_callconv(p) |
| 219 | + for i in eachindex(iminds, visinds) |
| 220 | + imind = iminds[i] |
| 221 | + visind = visinds[i] |
| 222 | + vis_view = @view(vis[visind:visind]) |
| 223 | + img_view = @view(img[imind:imind]) |
| 224 | + my_nuft_callconv!(vis_view, plan, img_view) |
| 225 | + end |
| 226 | + applyphases_callconv!(vis, p.phases) |
| 227 | + return vis |
| 228 | +end |
| 229 | + |
| 230 | +@noinline function visibilitymap_numeric_callconv(grid::AbstractDomainCallConv, img::Vector{Float64}) |
| 231 | + return applyft_callconv(forward_plan_callconv(grid), img) |
| 232 | +end |
| 233 | + |
| 234 | +@noinline function foo_callconv(grid::AbstractDomainCallConv, img) |
| 235 | + return sum(visibilitymap_numeric_callconv(grid, img)) |
| 236 | +end |
| 237 | + |
| 238 | +@testset "Custom rule calling conv rewrite" begin |
| 239 | + inner = InnerPlanCallConv(2.0) |
| 240 | + plan = MyPlanCallConv(1, 2.0, 3, 4, inner, [2.0, 3.0, 4.0], ([1, 2, 3], [1, 2, 3]), 7) |
| 241 | + grid = MyDomainCallConv(plan) |
| 242 | + img = [1.0, 2.0, 3.0] |
| 243 | + dimg = zeros(3) |
| 244 | + |
| 245 | + Enzyme.autodiff( |
| 246 | + Enzyme.set_runtime_activity(Enzyme.Reverse), |
| 247 | + foo_callconv, |
| 248 | + Active, |
| 249 | + Const(grid), |
| 250 | + Duplicated(img, dimg), |
| 251 | + ) |
| 252 | + |
| 253 | + @test dimg == [4.0, 6.0, 8.0] |
| 254 | +end |
0 commit comments