Skip to content

Commit 01c3200

Browse files
authored
Merge pull request #414 from ptiede/ptiede-multisky
Add ability to have multiple grids for a sky model
2 parents 7b4ee54 + 4c8d8e2 commit 01c3200

File tree

8 files changed

+221
-106
lines changed

8 files changed

+221
-106
lines changed

docs/src/api.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ Comrade.Coherencies
8787
Comrade.AbstractSkyModel
8888
Comrade.SkyModel
8989
Comrade.FixedSkyModel
90+
Comrade.MultiSkyModel
9091
Comrade.idealvisibilities
9192
Comrade.skymodel(::Comrade.AbstractVLBIPosterior, ::Any)
9293
```

src/Comrade.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ import ComradeBase: flux, radialextent, intensitymap, intensitymap!,
6565
amplitudemap
6666
include("observations/observations.jl")
6767
include("instrument/instrument.jl")
68-
include("skymodels/models.jl")
68+
include("skymodels/abstract.jl")
6969
include("posterior/abstract.jl")
7070
include("inference/inference.jl")
7171
include("visualizations/visualizations.jl")

src/posterior/vlbiposterior.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
struct VLBIPosterior{D, T, P, MS <: ObservedSkyModel, MI <: AbstractInstrumentModel, ADMode <: Union{Nothing, EnzymeCore.Mode}} <: AbstractVLBIPosterior
1+
struct VLBIPosterior{D, T, P, MS <: AbstractSkyModel, MI <: AbstractInstrumentModel, ADMode <: Union{Nothing, EnzymeCore.Mode}} <: AbstractVLBIPosterior
22
data::D
33
lklhds::T
44
prior::P

src/skymodels/abstract.jl

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
"""
2+
AbstractSkyModel
3+
4+
The abstract type for Comrade Sky Models. For a concrete implementation see [`SkyModel`](@ref).
5+
6+
Any subtype must implement the following methods
7+
- `ObservedSkyModel(m::AbstractSkyModel, array::AbstractArrayConfiguration)`: Constructs an observed sky model
8+
given the sky model `m` and the array configuration `array`. This method is used to compute the visibilities
9+
and image of the sky model.
10+
11+
The following methods have default implementations:
12+
- `idealvisibilities(m::AbstractSkyModel, x)`: Computes the ideal visibilities of the sky model `m`
13+
given the model parameters `x`.
14+
- `skymodel(m::AbstractSkyModel, x)`: Returns the sky model image given the model parameters `x`.
15+
- `domain(m::AbstractSkyModel)`: Returns the domain of the sky model `m`.
16+
- `set_array(m::AbstractSkyModel, array::AbstractArrayConfiguration)`: Sets the array configuration
17+
for the sky model `m` and returns the observed sky model and prior.
18+
- `set_prior(m::AbstractSkyModel, array::AbstractArrayConfiguration)`: Sets the prior for the sky model
19+
`m` given the array configuration `array`. This is used to set the priors for the model parameters.
20+
21+
"""
22+
abstract type AbstractSkyModel end
23+
24+
function Base.show(io::IO, mime::MIME"text/plain", m::AbstractSkyModel)
25+
T = typeof(m)
26+
ST = split(split(" $T", '{')[1], ".")[end]
27+
printstyled(io, ST; bold = true, color = :blue)
28+
println(io)
29+
println(io, " with map: $(skymodel(m))")
30+
# GT = typeof(domain(m))
31+
# SGT = split("$GT", '{')[1]
32+
print(io, " on grid: \n")
33+
show(io, mime, domain(m))
34+
return print(io, "\n )\n")
35+
end
36+
37+
skymodel(m::AbstractSkyModel) = getfield(m, :f)
38+
39+
40+
function set_array(m::AbstractSkyModel, array::AbstractArrayConfiguration)
41+
return ObservedSkyModel(m, array), set_prior(m, array)
42+
end
43+
44+
function domain(m::AbstractSkyModel; kwargs...)
45+
return getfield(m, :grid)
46+
end
47+
48+
"""
49+
idealvisibilities(m::AbstractSkyModel, x)
50+
51+
Computes the ideal non-corrupted visibilities of the sky model `m` given the model parameters `x`.
52+
"""
53+
function idealvisibilities(m::AbstractSkyModel, x)
54+
skym = skymodel(m, x.sky)
55+
return visibilitymap(skym, domain(m))
56+
end
57+
58+
function skymodel(m::AbstractSkyModel, x)
59+
return m.f(x, m.metadata)
60+
end
61+
62+
function set_prior(m::AbstractSkyModel, array::AbstractArrayConfiguration)
63+
return getfield(m, :prior)
64+
end
65+
66+
struct ObservedSkyModel{F, G <: VLBISkyModels.AbstractDomain, M} <: AbstractSkyModel
67+
f::F
68+
grid::G
69+
metadata::M
70+
end
71+
72+
73+
include("models.jl")
74+
include("fixed.jl")
75+
include("multi.jl")

src/skymodels/fixed.jl

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
"""
2+
FixedSkyModel(m::AbstractModel, grid::AbstractRectiGrid; algorithm = NFFTAlg())
3+
4+
Construct a sky model that has no free parameters. This is useful for models where the
5+
image structure is known apriori but the instrument model is unknown.
6+
7+
# Arguments
8+
9+
- `m` : The fixed sky model.
10+
- `grid` : The domain on which the model is defined. This defines the field of view and resolution
11+
of the model. Note that if `f` produces a analytic model then this field of view isn't used
12+
directly in the computation of the visibilities.
13+
14+
# Optional Arguments
15+
- `algorithm` : The Fourier transform algorithm used to compute the visibilities. The default is
16+
`NFFTAlg()` which uses a non-uniform fast Fourier transform. Other options can be found by using
17+
`subtypes(VLBISkyModels.FourierTransform)`
18+
"""
19+
Base.@kwdef struct FixedSkyModel{M <: AbstractModel, G, A <: FourierTransform} <: AbstractSkyModel
20+
model::M
21+
grid::G
22+
algorithm::A = NFFTAlg()
23+
end
24+
25+
skymodel(m::FixedSkyModel) = m.model
26+
27+
function FixedSkyModel(m::AbstractModel, grid::AbstractRectiGrid; algorithm = NFFTAlg())
28+
return FixedSkyModel(m, grid, algorithm)
29+
end
30+
31+
function ObservedSkyModel(m::FixedSkyModel, arr::AbstractArrayConfiguration)
32+
gfour = FourierDualDomain(m.grid, arr, m.algorithm)
33+
img = intensitymap(m.model, gfour)
34+
vis = visibilitymap(m.model, gfour)
35+
return ObservedSkyModel(m, gfour, (; img, vis))
36+
end
37+
38+
function set_prior(::FixedSkyModel, ::AbstractArrayConfiguration)
39+
return NamedTuple()
40+
end
41+
42+
function idealvisibilities(m::ObservedSkyModel{<:FixedSkyModel}, x)
43+
return m.metadata.vis
44+
end
45+
46+
function skymodel(m::ObservedSkyModel{<:FixedSkyModel}, x)
47+
return m.f.model
48+
end

src/skymodels/models.jl

Lines changed: 0 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,6 @@
11
export SkyModel, FixedSkyModel
22

33

4-
"""
5-
AbstractSkyModel
6-
7-
The abstract type for Comrade Sky Models. For a concrete implementation see [`SkyModel`](@ref).
8-
9-
Any subtype must implement the following methods
10-
11-
- `set_array(m::AbstractSkyModel, array::AbstractArrayConfiguration)`: Sets the array configuration
12-
for the sky model `m` and returns the observed sky model and prior.
13-
14-
The following methods have default implementations:
15-
- `idealvisibilities(m::AbstractSkyModel, x)`: Computes the ideal visibilities of the sky model `m`
16-
given the model parameters `x`.
17-
- `skymodel(m::AbstractSkyModel, x)`: Returns the sky model image given the model parameters `x`.
18-
- `domain(m::AbstractSkyModel)`: Returns the domain of the sky model `m`.
19-
"""
20-
abstract type AbstractSkyModel end
21-
22-
234
struct SkyModel{F, P, G <: AbstractDomain, A <: FourierTransform, M} <: AbstractSkyModel
245
f::F
256
prior::P
@@ -28,16 +9,6 @@ struct SkyModel{F, P, G <: AbstractDomain, A <: FourierTransform, M} <: Abstract
289
metadata::M
2910
end
3011

31-
function Base.show(io::IO, mime::MIME"text/plain", m::AbstractSkyModel)
32-
T = typeof(m)
33-
ST = split(split(" $T", '{')[1], ".")[end]
34-
printstyled(io, ST; bold = true, color = :blue)
35-
println(io)
36-
println(io, " with map: $(m.f)")
37-
GT = typeof(m.grid)
38-
SGT = split("$GT", '{')[1]
39-
return print(io, " on grid: $SGT")
40-
end
4112

4213
"""
4314
SkyModel(f, prior, grid::AbstractRectiGrid; algorithm = NFFTAlg(), metadata=nothing)
@@ -72,15 +43,6 @@ function VLBISkyModels.FourierDualDomain(grid::AbstractRectiGrid, array::Abstrac
7243
return FourierDualDomain(grid, domain(array; executor), alg)
7344
end
7445

75-
struct ObservedSkyModel{F, G <: VLBISkyModels.AbstractDomain, M} <: AbstractSkyModel
76-
f::F
77-
grid::G
78-
metadata::M
79-
end
80-
81-
function domain(m::AbstractSkyModel; kwargs...)
82-
return getfield(m, :grid)
83-
end
8446

8547
# If we are using a analytic model then we don't need to plan the FT and we
8648
# can save some memory by not storing the plans.
@@ -110,69 +72,3 @@ function ObservedSkyModel(m::SkyModel, arr::AbstractArrayConfiguration)
11072
end
11173
return ObservedSkyModel(m.f, g, m.metadata)
11274
end
113-
114-
115-
function set_array(m::AbstractSkyModel, array::AbstractArrayConfiguration)
116-
return ObservedSkyModel(m, array), m.prior
117-
end
118-
119-
"""
120-
idealvisibilities(m::AbstractSkyModel, x)
121-
122-
Computes the ideal non-corrupted visibilities of the sky model `m` given the model parameters `x`.
123-
"""
124-
function idealvisibilities(m::AbstractSkyModel, x)
125-
skym = skymodel(m, x.sky)
126-
return visibilitymap(skym, domain(m))
127-
end
128-
129-
function skymodel(m::AbstractSkyModel, x)
130-
return m.f(x, m.metadata)
131-
end
132-
133-
"""
134-
FixedSkyModel(m::AbstractModel, grid::AbstractRectiGrid; algorithm = NFFTAlg())
135-
136-
Construct a sky model that has no free parameters. This is useful for models where the
137-
image structure is known apriori but the instrument model is unknown.
138-
139-
# Arguments
140-
141-
- `m` : The fixed sky model.
142-
- `grid` : The domain on which the model is defined. This defines the field of view and resolution
143-
of the model. Note that if `f` produces a analytic model then this field of view isn't used
144-
directly in the computation of the visibilities.
145-
146-
# Optional Arguments
147-
- `algorithm` : The Fourier transform algorithm used to compute the visibilities. The default is
148-
`NFFTAlg()` which uses a non-uniform fast Fourier transform. Other options can be found by using
149-
`subtypes(VLBISkyModels.FourierTransform)`
150-
"""
151-
Base.@kwdef struct FixedSkyModel{M <: AbstractModel, G, A <: FourierTransform} <: AbstractSkyModel
152-
model::M
153-
grid::G
154-
algorithm::A = NFFTAlg()
155-
end
156-
157-
function FixedSkyModel(m::AbstractModel, grid::AbstractRectiGrid; algorithm = NFFTAlg())
158-
return FixedSkyModel(m, grid, algorithm)
159-
end
160-
161-
function ObservedSkyModel(m::FixedSkyModel, arr::AbstractArrayConfiguration)
162-
gfour = FourierDualDomain(m.grid, arr, m.algorithm)
163-
img = intensitymap(m.model, gfour)
164-
vis = visibilitymap(m.model, gfour)
165-
return ObservedSkyModel(m, gfour, (; img, vis))
166-
end
167-
168-
function set_array(m::FixedSkyModel, array::AbstractArrayConfiguration)
169-
return ObservedSkyModel(m, array), NamedTuple()
170-
end
171-
172-
function idealvisibilities(m::ObservedSkyModel{<:FixedSkyModel}, x)
173-
return m.metadata.vis
174-
end
175-
176-
function skymodel(m::ObservedSkyModel{<:FixedSkyModel}, x)
177-
return m.f.model
178-
end

src/skymodels/multi.jl

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
export MultiSkyModel
2+
3+
"""
4+
MultiSkyModel(skymodels::NamedTuple)
5+
6+
Create a sky model that is a collection of multiple components. This is useful when
7+
you, e.g., want to decompose the sky into multiple grids, such as for wide-field imaging
8+
where the is a core component and a very far away component.
9+
10+
# Arguments
11+
- `skymodels` : A named tuple of sky models, where each model is a [`SkyModel`](@ref).
12+
13+
# Example
14+
```julia
15+
julia> s1 = SkyModel(...)
16+
julia> s2 = SkyModel(...)
17+
julia> stot = MultiSkyModel((core = m1, far = m2))
18+
julia> skymodel(stot, x)
19+
(core = ..., far = ...)
20+
```
21+
"""
22+
struct MultiSkyModel{N, T} <: AbstractSkyModel
23+
skymodels::NamedTuple{N, T}
24+
end
25+
26+
function ObservedSkyModel(m::MultiSkyModel, arr::AbstractArrayConfiguration)
27+
skymodels = map(m.skymodels) do sm
28+
ObservedSkyModel(sm, arr)
29+
end
30+
return MultiSkyModel(skymodels)
31+
end
32+
33+
function set_prior(m::MultiSkyModel, array::AbstractArrayConfiguration)
34+
prs = map(m.skymodels) do sm
35+
set_prior(sm, array)
36+
end
37+
return prs
38+
end
39+
40+
function idealvisibilities(m::MultiSkyModel{N}, x) where {N}
41+
sm = m.skymodels
42+
vis = map(N) do n
43+
Base.@_inline_meta
44+
@inline idealvisibilities(getproperty(sm, n), (; sky = getproperty(x.sky, n)))
45+
end
46+
return reduce(+, vis)
47+
end
48+
49+
function domain(m::MultiSkyModel)
50+
return map(m.skymodels) do sm
51+
domain(sm)
52+
end
53+
end
54+
55+
function skymodel(m::MultiSkyModel{N}, x) where {N}
56+
sm = m.skymodels
57+
skyms = map(N) do n
58+
Base.@_inline_meta
59+
skymodel(getproperty(sm, n), getproperty(x, n))
60+
end
61+
return NamedTuple{N}(skyms)
62+
end
63+
64+
function skymodel(m::MultiSkyModel)
65+
return map(m.skymodels) do sm
66+
skymodel(sm)
67+
end
68+
end

test/Core/models.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ using Tables
1313
using Plots
1414
import TransformVariables as TV
1515
using VLBIImagePriors
16+
using JET
1617

1718
ntequal(x::NamedTuple{N}, y::NamedTuple{N}) where {N} = map(_ntequal, (x), (y))
1819
ntequal(x, y) = false
@@ -128,6 +129,32 @@ end
128129
@test Comrade.skymodel(oskym, x) == m
129130
@test Comrade.idealvisibilities(oskym, (; sky = x)) Comrade.idealvisibilities(oskyf, (; sky = x))
130131
end
132+
133+
@testset "MultiSkyModel" begin
134+
_, vis, amp, lcamp, cphase = load_data()
135+
skytot = MultiSkyModel((dynamic = skym, static = skyf))
136+
oskytot, ptot = Comrade.set_array(skytot, arrayconfig(vis))
137+
138+
show(IOBuffer(), MIME"text/plain"(), skytot)
139+
140+
141+
x = rand(Comrade.NamedDist(ptot))
142+
143+
oskym, = Comrade.set_array(skym, arrayconfig(vis))
144+
oskyf, = Comrade.set_array(skyf, arrayconfig(vis))
145+
146+
tt = Comrade.skymodel(oskytot, x)
147+
@test tt.dynamic == Comrade.skymodel(oskym, x.dynamic)
148+
@test tt.static == Comrade.skymodel(oskyf, x.static)
149+
150+
vtot = Comrade.idealvisibilities(oskytot, (; sky = x))
151+
vdyn = Comrade.idealvisibilities(oskym, (; sky = x.dynamic))
152+
vstat = Comrade.idealvisibilities(oskyf, (; sky = x.static))
153+
154+
@test_opt Comrade.idealvisibilities(oskytot, (; sky = x))
155+
156+
@test vtot vdyn + vstat
157+
end
131158
end
132159

133160
@testset "GMRF" begin

0 commit comments

Comments
 (0)