Skip to content

Commit b008322

Browse files
authored
Merge pull request #116 from EHTJulia/ptiede-copyto
Switch to copyto! for better performance with Enzyme
2 parents 79367b9 + 09ff6ac commit b008322

File tree

14 files changed

+239
-16
lines changed

14 files changed

+239
-16
lines changed

Project.toml

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "VLBISkyModels"
22
uuid = "d6343c73-7174-4e0f-bb64-562643efbeca"
3-
version = "0.6.20"
3+
version = "0.6.21"
44
authors = ["Paul Tiede <ptiede91@gmail.com> and contributors"]
55

66
[deps]
@@ -36,18 +36,21 @@ StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
3636
FINUFFT = "d8beea63-0952-562e-9c6a-8e8ef7364055"
3737
Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
3838
NonuniformFFTs = "cd96f58b-6017-4a02-bb9e-f4d81626177f"
39+
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
40+
3941

4042
[extensions]
4143
VLBISkyModelsFINUFFT = ["FINUFFT"]
4244
VLBISkyModelsMakieExt = ["Makie", "DimensionalData"]
4345
VLBISkyModelsNonuniformFFTs = ["NonuniformFFTs"]
46+
VLBISkyModelsReactantExt = ["Reactant"]
4447

4548
[compat]
4649
AbstractFFTs = "1"
4750
Accessors = "0.1"
4851
ArgCheck = "2"
4952
ChainRulesCore = "1"
50-
ComradeBase = "^0.9.6"
53+
ComradeBase = "^0.9.8"
5154
DelimitedFiles = "1"
5255
DimensionalData = "0.29 - 0.29.24, ^0.29.26"
5356
DocStringExtensions = "0.6,0.7,0.8,0.9"
@@ -66,6 +69,7 @@ NonuniformFFTs = "0.9"
6669
PaddedViews = "0.5"
6770
PolarizedTypes = "^0.1.1"
6871
Printf = "1.8"
72+
Reactant = "0.2"
6973
RecipesBase = "1"
7074
Reexport = "1"
7175
Serialization = "1.8"
@@ -78,7 +82,8 @@ julia = "1.9"
7882
FINUFFT = "d8beea63-0952-562e-9c6a-8e8ef7364055"
7983
Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
8084
NonuniformFFTs = "cd96f58b-6017-4a02-bb9e-f4d81626177f"
85+
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
8186
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
8287

8388
[targets]
84-
test = ["Test", "FINUFFT", "Makie", "NonuniformFFTs"]
89+
test = ["Test", "FINUFFT", "Makie", "NonuniformFFTs", "Reactant"]

docs/src/base_api.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ ComradeBase.phasecenter
8989
ComradeBase.executor
9090
ComradeBase.Serial
9191
ComradeBase.ThreadsEx
92+
ComradeBase.ReactantEx
9293
ComradeBase.header
9394
ComradeBase.NoHeader
9495
ComradeBase.MinimalHeader

ext/VLBISkyModelsFINUFFT.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ EnzymeRules.inactive_type(::Type{<:FINUFFT.finufft_plan}) = true
7070

7171
function VLBISkyModels._jlnuft!(out, A::FINUFFTPlan, b::AbstractArray{<:Real})
7272
bc = getcache(A)
73-
bc .= b
73+
copyto!(bc, b)
7474
FINUFFT.finufft_exec!(A.forward, bc, out)
7575
return nothing
7676
end

ext/VLBISkyModelsReactantExt.jl

Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
module VLBISkyModelsReactantExt
2+
3+
using VLBISkyModels
4+
using AbstractFFTs
5+
using Reactant
6+
using NFFT
7+
using NFFT: AbstractNFFTs
8+
using VLBISkyModels: ReactantAlg
9+
using LinearAlgebra
10+
11+
12+
struct ReactantNFFTPlan{T, D, K <: AbstractArray, arrTc, vecI, vecII, FP, BP, INV, SM} <:
13+
AbstractNFFTPlan{T, D, 1}
14+
N::NTuple{D, Int}
15+
NOut::NTuple{1, Int}
16+
J::Int
17+
k::K
18+
::NTuple{D, Int}
19+
dims::UnitRange{Int}
20+
forwardFFT::FP
21+
backwardFFT::BP
22+
tmpVec::arrTc
23+
tmpVecHat::arrTc
24+
deconvolveIdx::vecI
25+
windowLinInterp::vecII
26+
windowHatInvLUT::INV
27+
B::SM
28+
end
29+
30+
31+
function VLBISkyModels.plan_nuft_spatial(
32+
::ReactantAlg,
33+
imgdomain::ComradeBase.AbstractRectiGrid,
34+
visdomain::ComradeBase.UnstructuredDomain,
35+
)
36+
visp = domainpoints(visdomain)
37+
uv2 = similar(visp.U, (2, length(visdomain)))
38+
dpx = pixelsizes(imgdomain)
39+
dx = dpx.X
40+
dy = dpx.Y
41+
rm = ComradeBase.rotmat(imgdomain)'
42+
# Here we flip the sign because the NFFT uses the -2pi convention
43+
uv2[1, :] .= -VLBISkyModels._rotatex.(visp.U, visp.V, Ref(rm)) .* dx
44+
uv2[2, :] .= -VLBISkyModels._rotatey.(visp.U, visp.V, Ref(rm)) .* dy
45+
return ReactantNFFTPlan(uv2, size(imgdomain))
46+
end
47+
48+
function VLBISkyModels.make_phases(
49+
::ReactantAlg,
50+
imgdomain::ComradeBase.AbstractRectiGrid,
51+
visdomain::ComradeBase.UnstructuredDomain,
52+
)
53+
return Reactant.to_rarray(VLBISkyModels.make_phases(NFFTAlg(), imgdomain, visdomain))
54+
end
55+
56+
function VLBISkyModels._jlnuft!(out, A::ReactantNFFTPlan, inp::Reactant.AnyTracedRArray)
57+
LinearAlgebra.mul!(out, A, inp)
58+
return nothing
59+
end
60+
61+
62+
Base.adjoint(p::ReactantNFFTPlan) = p
63+
64+
65+
function AbstractNFFTs.plan_nfft(
66+
arr::Type{<:Reactant.AnyTracedRArray},
67+
k::AbstractMatrix,
68+
N::NTuple{D, Int},
69+
rest...;
70+
kargs...,
71+
) where {D}
72+
p = ReactantNFFTPlan(arr, k, N; kargs...)
73+
return p
74+
end
75+
76+
function ReactantNFFTPlan(
77+
k::AbstractArray{T}, N::NTuple{D, Int}; fftflags = nothing, kwargs...
78+
) where {T, D}
79+
80+
81+
dims = 1:D
82+
CT = complex(T)
83+
params, N, NOut, J, Ñ, dims_ = NFFT.initParams(k, N, dims; kwargs...)
84+
85+
# Get the correct type
86+
FP = @jit plan_fft!(zeros(ComplexF64, 2, 2))
87+
BP = @jit plan_bfft!(zeros(ComplexF64, 2, 2))
88+
89+
params.storeDeconvolutionIdx = true # GPU_NFFT only works this way
90+
params.precompute = NFFT.FULL # GPU_NFFT only works this way
91+
92+
windowLinInterp, windowPolyInterp, windowHatInvLUT, deconvolveIdx, B = NFFT.precomputation(
93+
k, N[dims_], Ñ[dims_], params
94+
)
95+
96+
U = params.storeDeconvolutionIdx ? N : ntuple(d -> 0, Val(D))
97+
98+
tmpVec = Reactant.to_rarray(zeros(CT, Ñ))
99+
tmpVecHat = Reactant.to_rarray(zeros(CT, U))
100+
deconvIdx = Reactant.to_rarray(Int.(deconvolveIdx))
101+
winHatInvLUT = Reactant.to_rarray(complex(windowHatInvLUT[1]))
102+
B_ = (Reactant.to_rarray(complex.(Array(B))))
103+
104+
return ReactantNFFTPlan{
105+
T,
106+
D,
107+
typeof(k),
108+
typeof(tmpVec),
109+
typeof(deconvIdx),
110+
typeof(windowLinInterp),
111+
typeof(FP),
112+
typeof(BP),
113+
typeof(winHatInvLUT),
114+
typeof(B_),
115+
}(
116+
N,
117+
NOut,
118+
J,
119+
k,
120+
Ñ,
121+
dims_,
122+
FP,
123+
BP,
124+
tmpVec,
125+
tmpVecHat,
126+
deconvIdx,
127+
windowLinInterp,
128+
winHatInvLUT,
129+
B_,
130+
)
131+
end
132+
133+
AbstractNFFTs.size_in(p::ReactantNFFTPlan) = p.N
134+
AbstractNFFTs.size_out(p::ReactantNFFTPlan) = p.NOut
135+
136+
function AbstractNFFTs.convolve!(
137+
p::ReactantNFFTPlan{T, D}, g::Reactant.AnyTracedRArray, fHat::Reactant.AnyTracedRArray
138+
) where {D, T}
139+
mul!(fHat, transpose(p.B), vec(g))
140+
return nothing
141+
end
142+
143+
function AbstractNFFTs.convolve_transpose!(
144+
p::ReactantNFFTPlan{T, D}, fHat::Reactant.AnyTracedRArray, g::Reactant.AnyTracedRArray
145+
) where {D, T}
146+
mul!(vec(g), p.B, fHat)
147+
return nothing
148+
end
149+
150+
function Base.:*(p::ReactantNFFTPlan{T}, f::Reactant.AnyTracedRArray; kargs...) where {T}
151+
fHat = similar(f, complex(T), size_out(p))
152+
mul!(fHat, p, f; kargs...)
153+
return fHat
154+
end
155+
156+
function AbstractNFFTs.deconvolve!(
157+
p::ReactantNFFTPlan{T, D}, f::AbstractArray, g::AbstractArray
158+
) where {D, T}
159+
tmp = f .* reshape(p.windowHatInvLUT, size(f))
160+
@allowscalar g[p.deconvolveIdx] = reshape(tmp, :)
161+
return nothing
162+
end
163+
164+
""" in-place NFFT on the GPU"""
165+
function LinearAlgebra.mul!(
166+
fHat::Reactant.AnyTracedRArray,
167+
p::ReactantNFFTPlan{T, D},
168+
f::Reactant.AnyTracedRArray;
169+
verbose = false,
170+
timing::Union{Nothing, TimingStats} = nothing,
171+
) where {T, D}
172+
NFFT.consistencyCheck(p, f, fHat)
173+
174+
fill!(p.tmpVec, zero(Complex{T}))
175+
t1 = @elapsed @inbounds deconvolve!(p, f, p.tmpVec)
176+
fHat .= p.tmpVec[1:length(fHat)]
177+
p.forwardFFT * p.tmpVec
178+
return t3 = @elapsed @inbounds NFFT.convolve!(p, p.tmpVec, fHat)
179+
end
180+
181+
function NFFT.nfft(k::AbstractMatrix, f::Reactant.AnyTracedRArray, args...; kwargs...)
182+
p = ReactantNFFTPlan(typeof(f), k, size(f))
183+
return p * f
184+
end
185+
186+
187+
end

src/fourierdomain/nuft/nfft_alg.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ end
8383
@inline function _nuft!(out::AbstractArray, A, b::AbstractArray)
8484
tmp = similar(out)
8585
_jlnuft!(tmp, A, b)
86-
out .= tmp
86+
copyto!(out, tmp)
8787
return nothing
8888
end
8989

@@ -183,7 +183,7 @@ function EnzymeRules.reverse(
183183
for (db, dout) in zip(dbs, douts)
184184
# TODO open PR on NFFT so we can do this in place.
185185
_jlnuft_adjointadd!(db, A.val, dout)
186-
dout .= 0
186+
fill!(dout, 0)
187187
end
188188
return (nothing, nothing, nothing)
189189
end
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
struct ReactantAlg <: NUFT end

src/fourierdomain/nuft/nuft.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,3 +206,5 @@ include(joinpath(@__DIR__, "dft_alg.jl"))
206206
include(joinpath(@__DIR__, "finufft.jl"))
207207

208208
include(joinpath(@__DIR__, "nonuniformffts.jl"))
209+
210+
include(joinpath(@__DIR__, "nfft_reactant.jl"))

src/models/combinators.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -144,9 +144,9 @@ end
144144
function intensitymap_numeric!(sim::IntensityMap, m::AddModel)
145145
csim = copy(sim)
146146
intensitymap!(csim, m.m1)
147-
sim .= csim
147+
copyto!(sim, csim)
148148
intensitymap!(csim, m.m2)
149-
sim .= sim .+ csim
149+
sim .+= csim
150150
return nothing
151151
end
152152

@@ -296,7 +296,7 @@ end
296296
) where {M1, M2}
297297
cvis = similar(vis)
298298
visibilitymap!(cvis, m.m1)
299-
vis .= cvis
299+
copyto!(vis, cvis)
300300
visibilitymap!(cvis, m.m2)
301301
vis .*= cvis
302302
return nothing

src/models/continuous_image.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -194,9 +194,10 @@ function applypulse!(vis, pulse, gfour::AbstractFourierDualDomain)
194194
# through the broadcast
195195
pvis = parent(vis)
196196
dp = domainpoints(guv)
197-
for i in eachindex(pvis, dp)
198-
pvis[i] *= visibility_point(mp, dp[i])
199-
end
197+
pvis .*= visibility_point.(Ref(mp), dp)
198+
# for i in eachindex(pvis, dp)
199+
# pvis[i] *= visibility_point(mp, dp[i])
200+
# end
200201
# pvis .*= visibility_point.(Ref(mp), dp)
201202
return vis
202203
end

src/models/geometric_models.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@ end
6161
return exp(-2 * T(π)^2 * (u^2 + v^2)) + zero(T)im
6262
end
6363

64-
6564
"""
6665
$(TYPEDEF)
6766

0 commit comments

Comments
 (0)