|
| 1 | +module VLBISkyModelsReactantExt |
| 2 | + |
| 3 | +using VLBISkyModels |
| 4 | +using AbstractFFTs |
| 5 | +using Reactant |
| 6 | +using NFFT |
| 7 | +using NFFT: AbstractNFFTs |
| 8 | +using VLBISkyModels: ReactantAlg |
| 9 | +using LinearAlgebra |
| 10 | + |
| 11 | + |
"""
    ReactantNFFTPlan{T, D, K, arrTc, vecI, vecII, FP, BP, INV, SM}

NFFT plan (`AbstractNFFTPlan{T, D, 1}`) whose work buffers and precomputed
operators are held in the containers produced by the outer constructor
`ReactantNFFTPlan(k, N; ...)`.

# Fields
- `N`: input size, returned by `size_in`.
- `NOut`: output size (a 1-tuple), returned by `size_out`.
- `J`, `Ñ`, `dims`: values returned by `NFFT.initParams` and stored unchanged.
- `k`: the node array the plan was built from.
- `forwardFFT`, `backwardFFT`: in-place forward / backward FFT plan objects.
- `tmpVec`: complex work buffer of size `Ñ`.
- `tmpVecHat`: second complex work buffer.
- `deconvolveIdx`: integer indices into `tmpVec` written by `deconvolve!`.
- `windowLinInterp`: stored as returned by `NFFT.precomputation`.
- `windowHatInvLUT`: complex factors multiplied into the input by `deconvolve!`.
- `B`: complex matrix applied by `convolve!` / `convolve_transpose!`.
"""
struct ReactantNFFTPlan{
        T, D, K <: AbstractArray, arrTc, vecI, vecII, FP, BP, INV, SM,
    } <: AbstractNFFTPlan{T, D, 1}
    N::NTuple{D, Int}
    NOut::NTuple{1, Int}
    J::Int
    k::K
    Ñ::NTuple{D, Int}
    dims::UnitRange{Int}
    forwardFFT::FP
    backwardFFT::BP
    tmpVec::arrTc
    tmpVecHat::arrTc
    deconvolveIdx::vecI
    windowLinInterp::vecII
    windowHatInvLUT::INV
    B::SM
end
| 29 | + |
| 30 | + |
"""
    VLBISkyModels.plan_nuft_spatial(::ReactantAlg, imgdomain, visdomain)

Build a `ReactantNFFTPlan` for the image grid `imgdomain` and the points of
`visdomain`.

The `(U, V)` coordinates of `visdomain` are rotated with the adjoint of
`ComradeBase.rotmat(imgdomain)`, scaled by the pixel sizes of `imgdomain`, and
negated before being handed to the plan constructor as a `2 × length(visdomain)`
node matrix.
"""
function VLBISkyModels.plan_nuft_spatial(
        ::ReactantAlg,
        imgdomain::ComradeBase.AbstractRectiGrid,
        visdomain::ComradeBase.UnstructuredDomain,
    )
    points = domainpoints(visdomain)
    nodes = similar(points.U, (2, length(visdomain)))
    pixel = pixelsizes(imgdomain)
    rot = ComradeBase.rotmat(imgdomain)'
    # Here we flip the sign because the NFFT uses the -2pi convention
    nodes[1, :] .= -VLBISkyModels._rotatex.(points.U, points.V, Ref(rot)) .* pixel.X
    nodes[2, :] .= -VLBISkyModels._rotatey.(points.U, points.V, Ref(rot)) .* pixel.Y
    return ReactantNFFTPlan(nodes, size(imgdomain))
end
| 47 | + |
"""
    VLBISkyModels.make_phases(::ReactantAlg, imgdomain, visdomain)

Compute the phases with the `NFFTAlg()` method of `VLBISkyModels.make_phases`
and convert the result with `Reactant.to_rarray`.
"""
function VLBISkyModels.make_phases(
        ::ReactantAlg,
        imgdomain::ComradeBase.AbstractRectiGrid,
        visdomain::ComradeBase.UnstructuredDomain,
    )
    phases = VLBISkyModels.make_phases(NFFTAlg(), imgdomain, visdomain)
    return Reactant.to_rarray(phases)
end
| 55 | + |
"""
    VLBISkyModels._jlnuft!(out, A::ReactantNFFTPlan, inp)

Apply the plan `A` to the traced array `inp`, writing the result into `out`
through `mul!`. Always returns `nothing`.
"""
function VLBISkyModels._jlnuft!(out, A::ReactantNFFTPlan, inp::Reactant.AnyTracedRArray)
    mul!(out, A, inp)
    return nothing
end
| 60 | + |
| 61 | + |
# The adjoint of a `ReactantNFFTPlan` is defined as the plan itself.
# NOTE(review): this does not build a distinct adjoint operator, so `p' * x`
# dispatches to the same forward `*`/`mul!` methods as `p * x` — confirm this
# is intended.
function Base.adjoint(p::ReactantNFFTPlan)
    return p
end
| 63 | + |
| 64 | + |
"""
    AbstractNFFTs.plan_nfft(arr::Type{<:Reactant.AnyTracedRArray}, k, N, rest...; kargs...)

Create a `ReactantNFFTPlan` for the nodes `k` and input size `N`.

`arr` is used only to select this method and `rest` is ignored. All keyword
arguments are forwarded to `ReactantNFFTPlan(k, N; kargs...)`.
"""
function AbstractNFFTs.plan_nfft(
        arr::Type{<:Reactant.AnyTracedRArray},
        k::AbstractMatrix,
        N::NTuple{D, Int},
        rest...;
        kargs...,
    ) where {D}
    # The only outer constructor takes `(k, N; ...)`; passing `arr` as a leading
    # positional argument would raise a `MethodError`.
    return ReactantNFFTPlan(k, N; kargs...)
end
| 75 | + |
"""
    ReactantNFFTPlan(k::AbstractArray{T}, N::NTuple{D, Int}; fftflags = nothing, kwargs...)

Construct a plan for the nodes `k` and input size `N`.

`kwargs` are forwarded to `NFFT.initParams`. The parameters are then forced to
`storeDeconvolutionIdx = true` and `precompute = NFFT.FULL` before
`NFFT.precomputation` runs. Work buffers, the deconvolution indices, the inverse
window look-up table and the convolution matrix (converted to a dense complex
array) are wrapped with `Reactant.to_rarray`.

NOTE(review): `fftflags` is accepted but never used in this constructor.
"""
function ReactantNFFTPlan(
        k::AbstractArray{T}, N::NTuple{D, Int}; fftflags = nothing, kwargs...
    ) where {T, D}
    CT = complex(T)
    params, N, NOut, J, Ñ, dims_ = NFFT.initParams(k, N, 1:D; kwargs...)

    # Get the correct type
    # NOTE(review): the plans are created from a fixed 2×2 `ComplexF64` template
    # regardless of `T` and `D` — confirm this matches the buffers for other
    # element types / dimensions.
    fwdPlan = @jit plan_fft!(zeros(ComplexF64, 2, 2))
    bwdPlan = @jit plan_bfft!(zeros(ComplexF64, 2, 2))

    params.storeDeconvolutionIdx = true # GPU_NFFT only works this way
    params.precompute = NFFT.FULL # GPU_NFFT only works this way

    windowLinInterp, _, hatInvLUT, rawIdx, Bpre = NFFT.precomputation(
        k, N[dims_], Ñ[dims_], params
    )

    hatSize = params.storeDeconvolutionIdx ? N : ntuple(d -> 0, Val(D))

    tmpVec = Reactant.to_rarray(zeros(CT, Ñ))
    tmpVecHat = Reactant.to_rarray(zeros(CT, hatSize))
    deconvIdx = Reactant.to_rarray(Int.(rawIdx))
    winHatInvLUT = Reactant.to_rarray(complex(hatInvLUT[1]))
    B_ = Reactant.to_rarray(complex.(Array(Bpre)))

    # `T` does not appear in any field type, so all parameters are spelled out.
    PlanType = ReactantNFFTPlan{
        T,
        D,
        typeof(k),
        typeof(tmpVec),
        typeof(deconvIdx),
        typeof(windowLinInterp),
        typeof(fwdPlan),
        typeof(bwdPlan),
        typeof(winHatInvLUT),
        typeof(B_),
    }
    return PlanType(
        N, NOut, J, k, Ñ, dims_,
        fwdPlan, bwdPlan,
        tmpVec, tmpVecHat,
        deconvIdx, windowLinInterp, winHatInvLUT, B_,
    )
end
| 132 | + |
# Input size of the plan: the stored `N` tuple.
function AbstractNFFTs.size_in(p::ReactantNFFTPlan)
    return p.N
end
# Output size of the plan: the stored `NOut` 1-tuple.
function AbstractNFFTs.size_out(p::ReactantNFFTPlan)
    return p.NOut
end
| 135 | + |
"""
    AbstractNFFTs.convolve!(p::ReactantNFFTPlan, g, fHat)

Overwrite `fHat` with `transpose(p.B) * vec(g)`. Returns `nothing`.
"""
function AbstractNFFTs.convolve!(
        p::ReactantNFFTPlan{T, D}, g::Reactant.AnyTracedRArray, fHat::Reactant.AnyTracedRArray
    ) where {D, T}
    Bt = transpose(p.B)
    gflat = vec(g)
    mul!(fHat, Bt, gflat)
    return nothing
end
| 142 | + |
"""
    AbstractNFFTs.convolve_transpose!(p::ReactantNFFTPlan, fHat, g)

Overwrite `vec(g)` with `p.B * fHat`. Returns `nothing`.
"""
function AbstractNFFTs.convolve_transpose!(
        p::ReactantNFFTPlan{T, D}, fHat::Reactant.AnyTracedRArray, g::Reactant.AnyTracedRArray
    ) where {D, T}
    gflat = vec(g)
    mul!(gflat, p.B, fHat)
    return nothing
end
| 149 | + |
"""
    *(p::ReactantNFFTPlan{T}, f; kargs...)

Allocate an output of element type `complex(T)` and size `size_out(p)` with
`similar`, fill it via `mul!(out, p, f; kargs...)`, and return it.
"""
function Base.:*(p::ReactantNFFTPlan{T}, f::Reactant.AnyTracedRArray; kargs...) where {T}
    out = similar(f, complex(T), size_out(p))
    mul!(out, p, f; kargs...)
    return out
end
| 155 | + |
"""
    AbstractNFFTs.deconvolve!(p::ReactantNFFTPlan, f, g)

Multiply `f` element-wise by `p.windowHatInvLUT` (reshaped to `size(f)`) and
write the flattened product into `g` at the positions `p.deconvolveIdx`.
Returns `nothing`.
"""
function AbstractNFFTs.deconvolve!(
        p::ReactantNFFTPlan{T, D}, f::AbstractArray, g::AbstractArray
    ) where {D, T}
    weights = reshape(p.windowHatInvLUT, size(f))
    scaled = f .* weights
    @allowscalar g[p.deconvolveIdx] = reshape(scaled, :)
    return nothing
end
| 163 | + |
"""
    mul!(fHat, p::ReactantNFFTPlan, f; verbose = false, timing = nothing)

Apply the plan `p` to `f` and store the result in `fHat`.

Steps: validate the sizes with `NFFT.consistencyCheck`, zero `p.tmpVec`,
scatter the deconvolved input into it (`deconvolve!`), apply `p.forwardFFT` to
`p.tmpVec`, and finally write `fHat` with `convolve!`.

`verbose` and `timing` are accepted for signature compatibility and are not
used. Returns `fHat`.
"""
function LinearAlgebra.mul!(
        fHat::Reactant.AnyTracedRArray,
        p::ReactantNFFTPlan{T, D},
        f::Reactant.AnyTracedRArray;
        verbose = false,
        timing::Union{Nothing, TimingStats} = nothing,
    ) where {T, D}
    NFFT.consistencyCheck(p, f, fHat)

    fill!(p.tmpVec, zero(Complex{T}))
    deconvolve!(p, f, p.tmpVec)
    # The plan is applied for its effect on `p.tmpVec`; the product itself is
    # not used.
    p.forwardFFT * p.tmpVec
    # `convolve!` overwrites all of `fHat`, so no intermediate copy from
    # `p.tmpVec` into `fHat` is made (such a slice would also go out of bounds
    # whenever `length(fHat) > length(p.tmpVec)`).
    NFFT.convolve!(p, p.tmpVec, fHat)
    return fHat
end
| 180 | + |
"""
    NFFT.nfft(k::AbstractMatrix, f::Reactant.AnyTracedRArray, args...; kwargs...)

Build a `ReactantNFFTPlan` for the nodes `k` and `size(f)`, apply it to `f`,
and return the result. `kwargs` are forwarded to the plan constructor; extra
positional `args` are ignored.
"""
function NFFT.nfft(k::AbstractMatrix, f::Reactant.AnyTracedRArray, args...; kwargs...)
    # The plan constructor takes `(k, N; ...)`; there is no method accepting the
    # array type as a leading positional argument.
    p = ReactantNFFTPlan(k, size(f); kwargs...)
    return p * f
end
| 185 | + |
| 186 | + |
| 187 | +end |
0 commit comments