Skip to content

Commit f4193d8

Browse files
committed
Format and update
1 parent 624db3b commit f4193d8

File tree

9 files changed

+43
-20
lines changed

9 files changed

+43
-20
lines changed

examples/advanced/FitPS/main.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ skym = SkyModel(sky, prior, grid; metadata = skymeta)
128128
# Since we are fitting closures we do not need to include an instrument model, since
129129
# the closure likelihood is approximately independent of gains in the high SNR limit.
130130
using Enzyme
131-
post = VLBIPosterior(skym, dlcamp, dcphase; imgdata = (Comrade.CentroidData((0.0, 0.0), beamsize(dcphase)/10.0, grid), ))
131+
post = VLBIPosterior(skym, dlcamp, dcphase; imgdata = (Comrade.CentroidData((0.0, 0.0), beamsize(dcphase) / 10.0, grid),))
132132

133133
# ## Reconstructing the Image
134134

@@ -143,7 +143,7 @@ post = VLBIPosterior(skym, dlcamp, dcphase; imgdata = (Comrade.CentroidData((0.0
143143
# We also need to import Enzyme to allow for automatic differentiation.
144144
using Optimization, OptimizationOptimisers
145145
# tpost = asflat(post)
146-
xopt, sol = comrade_opt(post, Adam(); maxiters = 5000, initial_params=xopt)
146+
xopt, sol = comrade_opt(post, Adam(); maxiters = 5000, initial_params = xopt)
147147

148148
using CairoMakie
149149
using DisplayAs #hide

examples/intermediate/ClosureImaging/main.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ skym = SkyModel(sky, prior, grid; metadata = skymeta)
126126
# Since we are fitting closures we do not need to include an instrument model, since
127127
# the closure likelihood is approximately independent of gains in the high SNR limit.
128128
using Enzyme
129-
post = VLBIPosterior(skym, dlcamp, dcphase; imgdata = (Comrade.CentroidData((0.0, 0.0), μas2rad(0.1), grid), ))
129+
post = VLBIPosterior(skym, dlcamp, dcphase; imgdata = (Comrade.CentroidData((0.0, 0.0), μas2rad(0.1), grid),))
130130

131131
# ## Reconstructing the Image
132132

examples/intermediate/PolarizedImaging/main.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ intmodel = InstrumentModel(J, intprior)
301301
# Putting it all together, we form our likelihood and posterior objects for optimization and
302302
# sampling, and specifying to use Enzyme.Reverse with runtime activity for AD.
303303
using Enzyme
304-
post = VLBIPosterior(skym, intmodel, dvis; imgdata = (Comrade.CentroidData((0.0, 0.0), μas2rad(0.1), grid), ))
304+
post = VLBIPosterior(skym, intmodel, dvis; imgdata = (Comrade.CentroidData((0.0, 0.0), μas2rad(0.1), grid),))
305305

306306
# ## Reconstructing the Image and Instrument Effects
307307

@@ -383,7 +383,7 @@ fig |> DisplayAs.PNG |> DisplayAs.Text
383383
# other imaging examples. For example
384384
# ```julia
385385
# using AdvancedHMC
386-
chain = sample(rng, post, NUTS(0.8), 2_000, n_adapts=1000, progress=true, initial_params=xopt)
386+
chain = sample(rng, post, NUTS(0.8), 2_000, n_adapts = 1000, progress = true, initial_params = xopt)
387387
# ```
388388

389389

examples/intermediate/StokesIImaging/main.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ intmodel = InstrumentModel(G, intpr)
144144
# gradients of the posterior we also need to load `Enzyme.jl`. Under the hood, Comrade will use
145145
# Enzyme to compute the gradients of the posterior.
146146
using Enzyme
147-
post = VLBIPosterior(skym, intmodel, dvis; imgdata = (Comrade.CentroidData((0.0, 0.0), μas2rad(0.1), grid), ))
147+
post = VLBIPosterior(skym, intmodel, dvis; imgdata = (Comrade.CentroidData((0.0, 0.0), μas2rad(0.1), grid),))
148148

149149
# ## Optimization and Sampling
150150

src/posterior/abstract.jl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ data products observed.
9696
end
9797

9898
@inline function forward_model(d::AbstractVLBIPosterior, θ)
99-
forward_model_map(datatype(typeof(d)), d, θ)
99+
return forward_model_map(datatype(typeof(d)), d, θ)
100100
end
101101

102102
"""
@@ -160,14 +160,11 @@ end
160160
return zero(img)
161161
end
162162

163-
@inline function logdensityofimg(lklhds::Tuple{}, img::AbstractArray{T}) where T
163+
@inline function logdensityofimg(lklhds::Tuple{}, img::AbstractArray{T}) where {T}
164164
return zero(T)
165165
end
166166

167167

168-
169-
170-
171168
include("likelihood.jl")
172169
include("vlbiposterior.jl")
173170
include("transformed.jl")

src/posterior/likelihood.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,12 +114,12 @@ struct CentroidData{M, N, G}
114114
end
115115

116116

117-
struct _Centroid{F,T}
117+
struct _Centroid{F, T}
118118
f::F
119119
Σ::T
120120
end
121121

122-
struct NormalFast{T,S}
122+
struct NormalFast{T, S}
123123
μ::T
124124
Σ::S
125125
end
@@ -137,7 +137,7 @@ end
137137
function makelikelihood(data::CentroidData)
138138
Σ = data.noise .^ 2
139139
meas = data.measurement
140-
f = SVectorcentroid
140+
f = SVector centroid
141141
= ConditionedLikelihood(_Centroid(f, Σ), SVector(meas))
142142
return
143143
end

src/posterior/vlbiposterior.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ admode(post::VLBIPosterior) = post.admode
1717
end
1818

1919

20-
2120
"""
2221
VLBIPosterior(skymodel::SkyModel, instumentmodel::InstrumentModel,
2322
dataproducts::EHTObservationTable...;
@@ -96,7 +95,7 @@ VLBIPosterior(
9695
skymodel::AbstractSkyModel, dataproducts::EHTObservationTable...;
9796
admode = EnzymeCore.set_runtime_activity(EnzymeCore.Reverse), kwargs...
9897
) =
99-
VLBIPosterior(skymodel, IdealInstrumentModel(), dataproducts...;admode, kwargs..., )
98+
VLBIPosterior(skymodel, IdealInstrumentModel(), dataproducts...; admode, kwargs...)
10099

101100
function combine_prior(skyprior, instrumentmodelprior)
102101
return NamedDist((sky = skyprior, instrument = instrumentmodelprior))

src/skymodels/abstract.jl

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,37 @@ end
5959

6060
function idealmaps(::DualData, m::AbstractSkyModel, x)
6161
skym = skymodel(m, x.sky)
62-
dmap = dualmap(skym, domain(m))
63-
return ComradeBase.imgmap(dmap), ComradeBase.vismap(dmap)
62+
# We include this special method to sometime optimize the dualmap computation
63+
# for specific sky models
64+
img, vis = fastdualmap(skym, domain(m))
65+
return img, vis
6466
end
6567

68+
function fastdualmap(skym::AbstractModel, grid::VLBISkyModels.AbstractFourierDualDomain)
69+
dm = dualmap(skym, grid)
70+
return ComradeBase.imgmap(dm), ComradeBase.vismap(dm)
71+
end
72+
73+
function fastdualmap(skym::ContinuousImage, grid::VLBISkyModels.AbstractFourierDualDomain)
74+
VLBISkyModels.checkgrid(axisdims(skym), imgdomain(grid)) ||
75+
throw(DomainError("Image domain does not match skymodel image domain"))
76+
img = VLBISkyModels.make_map(m)
77+
return img, visibilitymap(skym, grid)
78+
end
79+
80+
81+
function fastdualmap(skym::VLBISkyModels.CompositeModel{<:ContinuousImage}, grid::VLBISkyModels.AbstractFourierDualDomain)
82+
VLBISkyModels.checkgrid(axisdims(skym.m1), ComradeBase.imgdomain(grid)) || throw(DomainError("Image domain does not match skymodel image domain"))
83+
img = VLBISkyModels.make_map(skym.m1)
84+
img2 = intensitymap(skym.m2, grid)
85+
img2 .+= img #copy into this one because img will alias otherwise
86+
return img, visibilitymap(skym, grid)
87+
end
88+
89+
fastdualmap(m::VLBISkyModels.CompositeModel{M1, <:ContinuousImage}, grid::VLBISkyModels.AbstractFourierDualDomain) where {M1} =
90+
fastdualmap(swap(m), grid)
91+
92+
6693
function skymodel(m::AbstractSkyModel, x)
6794
return m.f(x, m.metadata)
6895
end

test/Core/models.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,8 @@ end
150150
@test tt.dynamic == Comrade.skymodel(oskym, x.dynamic)
151151
@test tt.static == Comrade.skymodel(oskyf, x.static)
152152

153-
vtot = last(Comrade.idealmaps(oskytot, (; sky = x)))
154-
vdyn = last(Comrade.idealmaps(oskym, (; sky = x.dynamic)))
153+
vtot = last(Comrade.idealmaps(oskytot, (; sky = x)))
154+
vdyn = last(Comrade.idealmaps(oskym, (; sky = x.dynamic)))
155155
vstat = last(Comrade.idealmaps(oskyf, (; sky = x.static)))
156156

157157
@test_opt Comrade.idealmaps(oskytot, (; sky = x))

0 commit comments

Comments
 (0)