Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "VLBISkyModels"
uuid = "d6343c73-7174-4e0f-bb64-562643efbeca"
version = "0.6.20"
version = "0.6.21"
authors = ["Paul Tiede <ptiede91@gmail.com> and contributors"]

[deps]
Expand Down Expand Up @@ -36,18 +36,21 @@ StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
FINUFFT = "d8beea63-0952-562e-9c6a-8e8ef7364055"
Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
NonuniformFFTs = "cd96f58b-6017-4a02-bb9e-f4d81626177f"
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"


[extensions]
VLBISkyModelsFINUFFT = ["FINUFFT"]
VLBISkyModelsMakieExt = ["Makie", "DimensionalData"]
VLBISkyModelsNonuniformFFTs = ["NonuniformFFTs"]
VLBISkyModelsReactantExt = ["Reactant"]

[compat]
AbstractFFTs = "1"
Accessors = "0.1"
ArgCheck = "2"
ChainRulesCore = "1"
ComradeBase = "^0.9.6"
ComradeBase = "^0.9.8"
DelimitedFiles = "1"
DimensionalData = "0.29 - 0.29.24, ^0.29.26"
DocStringExtensions = "0.6,0.7,0.8,0.9"
Expand All @@ -66,6 +69,7 @@ NonuniformFFTs = "0.9"
PaddedViews = "0.5"
PolarizedTypes = "^0.1.1"
Printf = "1.8"
Reactant = "0.2"
RecipesBase = "1"
Reexport = "1"
Serialization = "1.8"
Expand All @@ -78,7 +82,8 @@ julia = "1.9"
FINUFFT = "d8beea63-0952-562e-9c6a-8e8ef7364055"
Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
NonuniformFFTs = "cd96f58b-6017-4a02-bb9e-f4d81626177f"
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "FINUFFT", "Makie", "NonuniformFFTs"]
test = ["Test", "FINUFFT", "Makie", "NonuniformFFTs", "Reactant"]
2 changes: 1 addition & 1 deletion ext/VLBISkyModelsFINUFFT.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ EnzymeRules.inactive_type(::Type{<:FINUFFT.finufft_plan}) = true

function VLBISkyModels._jlnuft!(out, A::FINUFFTPlan, b::AbstractArray{<:Real})
bc = getcache(A)
bc .= b
copyto!(bc, b)
FINUFFT.finufft_exec!(A.forward, bc, out)
return nothing
end
Expand Down
312 changes: 312 additions & 0 deletions ext/VLBISkyModelsReactantExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,312 @@
module VLBISkyModelsReactantExt

using VLBISkyModels
using Reactant
using NFFT
using NFFT: AbstractNFFTs
using AbstractFFTs
using VLBISkyModels: ReactantAlg
using LinearAlgebra



# Need a better way to get this
const AFTR = Base.get_extension(Reactant, :ReactantAbstractFFTsExt)

struct ReactantNFFTPlan{T, D, K <: AbstractArray, arrTc, vecI, vecII, FP, BP, INV, SM} <:
AbstractNFFTPlan{T, D, 1}
N::NTuple{D, Int}
NOut::NTuple{1, Int}
J::Int
k::K
Ñ::NTuple{D, Int}
dims::UnitRange{Int}
forwardFFT::FP
backwardFFT::BP
tmpVec::arrTc
tmpVecHat::arrTc
deconvolveIdx::vecI
windowLinInterp::vecII
windowHatInvLUT::INV
B::SM
end


function VLBISkyModels.plan_nuft_spatial(
::ReactantAlg,
imgdomain::ComradeBase.AbstractRectiGrid,
visdomain::ComradeBase.UnstructuredDomain,
)
visp = domainpoints(visdomain)
uv2 = similar(visp.U, (2, length(visdomain)))
dpx = pixelsizes(imgdomain)
dx = dpx.X
dy = dpx.Y
rm = ComradeBase.rotmat(imgdomain)'
# Here we flip the sign because the NFFT uses the -2pi convention
uv2[1, :] .= -VLBISkyModels._rotatex.(visp.U, visp.V, Ref(rm)) .* dx
uv2[2, :] .= -VLBISkyModels._rotatey.(visp.U, visp.V, Ref(rm)) .* dy
return ReactantNFFTPlan(uv2, size(imgdomain))
end

function VLBISkyModels.make_phases(
::ReactantAlg,
imgdomain::ComradeBase.AbstractRectiGrid,
visdomain::ComradeBase.UnstructuredDomain,
)
return Reactant.to_rarray(VLBISkyModels.make_phases(NFFTAlg(), imgdomain, visdomain))
end

function VLBISkyModels._jlnuft!(out, A::ReactantNFFTPlan, inp::Reactant.AnyTracedRArray)
LinearAlgebra.mul!(out, A, inp)
return nothing
end


Base.adjoint(p::ReactantNFFTPlan) = p


function AbstractNFFTs.plan_nfft(
arr::Type{<:Reactant.AnyTracedRArray},
k::AbstractMatrix,
N::NTuple{D, Int},
rest...;
kargs...,
) where {D}
p = ReactantNFFTPlan(arr, k, N; kargs...)
return p
end

function ReactantNFFTPlan(
k::AbstractArray{T}, N::NTuple{D, Int}; fftflags = nothing, kwargs...
) where {T, D}


dims = 1:D
CT = complex(T)
params, N, NOut, J, Ñ, dims_ = NFFT.initParams(k, N, dims; kwargs...)

FP0 = plan_fft!(zeros(ComplexF64, 2,2))
BP0 = plan_bfft!(zeros(ComplexF64, 2,2))
FP = AFTR.reactant_fftplan(AFTR.reactant_fftplan_type(typeof(FP0)), FP0)
BP = AFTR.reactant_fftplan(AFTR.reactant_fftplan_type(typeof(BP0)), BP0)

params.storeDeconvolutionIdx = true # GPU_NFFT only works this way
params.precompute = NFFT.FULL # GPU_NFFT only works this way

windowLinInterp, windowPolyInterp, windowHatInvLUT, deconvolveIdx, B = NFFT.precomputation(
k, N[dims_], Ñ[dims_], params
)

U = params.storeDeconvolutionIdx ? N : ntuple(d -> 0, Val(D))

tmpVec = Reactant.to_rarray(zeros(CT, Ñ))
tmpVecHat = Reactant.to_rarray(zeros(CT, U))
deconvIdx = Reactant.to_rarray(Int.(deconvolveIdx))
winHatInvLUT = Reactant.to_rarray(complex(windowHatInvLUT[1]))
B_ = (Reactant.to_rarray(complex.(Array(B))))

return ReactantNFFTPlan{
T,
D,
typeof(k),
typeof(tmpVec),
typeof(deconvIdx),
typeof(windowLinInterp),
typeof(FP),
typeof(BP),
typeof(winHatInvLUT),
typeof(B_),
}(
N,
NOut,
J,
k,
Ñ,
dims_,
FP,
BP,
tmpVec,
tmpVecHat,
deconvIdx,
windowLinInterp,
winHatInvLUT,
B_,
)
end

AbstractNFFTs.size_in(p::ReactantNFFTPlan) = p.N
AbstractNFFTs.size_out(p::ReactantNFFTPlan) = p.NOut

function AbstractNFFTs.convolve!(
p::ReactantNFFTPlan{T, D}, g::Reactant.AnyTracedRArray, fHat::Reactant.AnyTracedRArray
) where {D, T}
mul!(fHat, transpose(p.B), vec(g))
return nothing
end

function AbstractNFFTs.convolve_transpose!(
p::ReactantNFFTPlan{T, D}, fHat::Reactant.AnyTracedRArray, g::Reactant.AnyTracedRArray
) where {D, T}
mul!(vec(g), p.B, fHat)
return nothing
end

function Base.:*(p::ReactantNFFTPlan{T}, f::Reactant.AnyTracedRArray; kargs...) where {T}
fHat = similar(f, complex(T), size_out(p))
mul!(fHat, p, f; kargs...)
return fHat
end

function AbstractNFFTs.deconvolve!(
p::ReactantNFFTPlan{T, D}, f::AbstractArray, g::AbstractArray
) where {D, T}
tmp = f .* reshape(p.windowHatInvLUT, size(f))
@allowscalar g[p.deconvolveIdx] = reshape(tmp, :)
return nothing
end

""" in-place NFFT on the GPU"""
function LinearAlgebra.mul!(
fHat::Reactant.AnyTracedRArray,
p::ReactantNFFTPlan{T, D},
f::Reactant.AnyTracedRArray;
verbose = false,
timing::Union{Nothing, TimingStats} = nothing,
) where {T, D}
NFFT.consistencyCheck(p, f, fHat)

fill!(p.tmpVec, zero(Complex{T}))
t1 = @elapsed @inbounds deconvolve!(p, f, p.tmpVec)
fHat .= p.tmpVec[1:length(fHat)]
p.forwardFFT * p.tmpVec
return t3 = @elapsed @inbounds NFFT.convolve!(p, p.tmpVec, fHat)
end

function NFFT.nfft(k::AbstractMatrix, f::Reactant.AnyTracedRArray, args...; kwargs...)
p = ReactantNFFTPlan(typeof(f), k, size(f))
return p * f
end

function NFFT.initParams(
k::AbstractMatrix{T},
N::NTuple{D, Int},
dims::Union{Integer, UnitRange{Int64}} = 1:D;
kargs...,
) where {D, T}
# convert dims to a unit range
dims_ = (typeof(dims) <: Integer) ? (dims:dims) : dims

params = NFFTParams{T, D}(; kargs...)
m, σ, reltol = accuracyParams(; kargs...)
params.m = m
params.σ = σ
params.reltol = reltol

# Taken from NFFT3
m2K = [1, 3, 7, 9, 14, 17, 20, 23, 24]
K = m2K[min(m + 1, length(m2K))]
params.LUTSize = 2^(K) * (m) # ensure that LUTSize is dividable by (m)

if length(dims_) != size(k, 1)
throw(ArgumentError("Nodes x have dimension $(size(k, 1)) != $(length(dims_))"))
end

doTrafo = ntuple(d -> d ∈ dims_, Val(D))

Ñ = ntuple(
d ->
doTrafo[d] ? (ceil(Int, params.σ * N[d]) ÷ 2) * 2 : # ensure that n is an even integer
N[d], Val(D)
)

params.σ = Ñ[dims_[1]] / N[dims_[1]]

#params.blockSize = ntuple(d-> Ñ[d] , D) # just one block
if haskey(kargs, :blockSize)
params.blockSize = kargs[:blockSize]
else
params.blockSize = ntuple(d -> NFFT._blockSize(Ñ, d), Val(D))
end

J = size(k, 2)

# calculate output size
NOut = Int[]
Mtaken = false
ntuple(Val(D)) do d
if !doTrafo[d]
return N[d]
elseif !Mtaken
return J
Mtaken = true
end
end
for d in 1:D
if !doTrafo[d]
push!(NOut, N[d])
elseif !Mtaken
push!(NOut, J)
Mtaken = true
end
end
# Sort nodes in lexicographic way
if params.sortNodes
k .= sortslices(k; dims = 2)
end
return params, N, Tuple(NOut), J, Ñ, dims_
end

function NFFT.precomputation(k::AbstractVecOrMat, N::NTuple{D, Int}, Ñ, params) where {D}
m = params.m
σ = params.σ
window = params.window
LUTSize = params.LUTSize
precompute = params.precompute

win, win_hat = getWindow(window) # highly type instable. But what should be do
J = size(k, 2)

windowHatInvLUT_ = Vector{Vector{T}}(undef, D)
precomputeWindowHatInvLUT(windowHatInvLUT_, win_hat, N, Ñ, m, σ, T)

if params.storeDeconvolutionIdx
windowHatInvLUT = Vector{Vector{T}}(undef, 1)
windowHatInvLUT[1], deconvolveIdx = precompWindowHatInvLUT(
params, N, Ñ, windowHatInvLUT_
)
else
windowHatInvLUT = windowHatInvLUT_
deconvolveIdx = Array{Int64, 1}(undef, 0)
end

if precompute == LINEAR
windowLinInterp = precomputeLinInterp(win, m, σ, LUTSize, T)
windowPolyInterp = Matrix{T}(undef, 0, 0)
B = sparse([], [], T[])
elseif precompute == POLYNOMIAL
windowLinInterp = Vector{T}(undef, 0)
windowPolyInterp = precomputePolyInterp(win, m, σ, T)
B = sparse([], [], T[])
elseif precompute == FULL
windowLinInterp = Vector{T}(undef, 0)
windowPolyInterp = Matrix{T}(undef, 0, 0)
B = precomputeB(win, k, N, Ñ, m, J, σ, LUTSize, T)
#windowLinInterp = precomputeLinInterp(win, windowLinInterp, Ñ, m, σ, LUTSize, T) # These versions are for debugging
#B = precomputeB(windowLinInterp, k, N, Ñ, m, J, σ, LUTSize, T)
elseif precompute == TENSOR
windowLinInterp = Vector{T}(undef, 0)
windowPolyInterp = Matrix{T}(undef, 0, 0)
B = sparse([], [], T[])
else
windowLinInterp = Vector{T}(undef, 0)
windowPolyInterp = Matrix{T}(undef, 0, 0)
B = sparse([], [], T[])
error("precompute = $precompute not supported by NFFT.jl!")
end

return (windowLinInterp, windowPolyInterp, windowHatInvLUT, deconvolveIdx, B)
end


end
4 changes: 2 additions & 2 deletions src/fourierdomain/nuft/nfft_alg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ end
@inline function _nuft!(out::AbstractArray, A, b::AbstractArray)
tmp = similar(out)
_jlnuft!(tmp, A, b)
out .= tmp
copyto!(out, tmp)
return nothing
end

Expand Down Expand Up @@ -183,7 +183,7 @@ function EnzymeRules.reverse(
for (db, dout) in zip(dbs, douts)
# TODO open PR on NFFT so we can do this in place.
_jlnuft_adjointadd!(db, A.val, dout)
dout .= 0
fill!(dout, 0)
end
return (nothing, nothing, nothing)
end
1 change: 1 addition & 0 deletions src/fourierdomain/nuft/nfft_reactant.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
struct ReactantAlg <: NUFT end
2 changes: 2 additions & 0 deletions src/fourierdomain/nuft/nuft.jl
Original file line number Diff line number Diff line change
Expand Up @@ -206,3 +206,5 @@ include(joinpath(@__DIR__, "dft_alg.jl"))
include(joinpath(@__DIR__, "finufft.jl"))

include(joinpath(@__DIR__, "nonuniformffts.jl"))

include(joinpath(@__DIR__, "nfft_reactant.jl"))
Loading
Loading