Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ DimensionalData = "0.29 - 0.29.24, ^0.29.26"
Distributions = "0.25"
DocStringExtensions = "0.8, 0.9"
Dynesty = "0.4"
Enzyme = "0.13 - 0.13.104, ^0.13.109"
Enzyme = "0.13 - 0.13.104, 0.13.109 - 0.13.111"
EnzymeCore = "0.8"
FillArrays = "1"
HypercubeTransform = "^0.4.11"
Expand Down
9 changes: 3 additions & 6 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,6 @@
Pages = ["api.md"]
```

```@docs
Comrade.Comrade
```

## Model Definitions


Expand Down Expand Up @@ -88,7 +84,7 @@ Comrade.AbstractSkyModel
Comrade.SkyModel
Comrade.FixedSkyModel
Comrade.MultiSkyModel
Comrade.idealvisibilities
Comrade.idealmaps
Comrade.skymodel(::Comrade.AbstractVLBIPosterior, ::Any)
```

Expand Down Expand Up @@ -133,13 +129,13 @@ Comrade.dataproducts
Comrade.skymodel
Comrade.instrumentmodel(::Comrade.AbstractVLBIPosterior)
Comrade.instrumentmodel(::Comrade.AbstractVLBIPosterior, ::Any)
Comrade.forward_model
Comrade.prior_sample
Comrade.likelihood
Comrade.VLBIPosterior
Comrade.simulate_observation
Comrade.residuals
Comrade.TransformedVLBIPosterior
Comrade.ImgNormalData
HypercubeTransform.transform(::Comrade.TransformedVLBIPosterior, ::Any)
HypercubeTransform.inverse(::Comrade.TransformedVLBIPosterior, ::Any)
HypercubeTransform.ascube(::Comrade.VLBIPosterior)
Expand Down Expand Up @@ -181,6 +177,7 @@ Comrade.rmap
```@docs
Comrade.build_datum
Comrade.ObservedSkyModel
Comrade.forward_model_map
```

### eht-imaging interface (Internal)
Expand Down
15 changes: 14 additions & 1 deletion docs/src/base_api.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,10 @@ ComradeBase.intensitymap_numeric!

### Image Domain
```@docs
ComradeBase.imagepixels
ComradeBase.AbstractDualDomain
ComradeBase.RectiGrid
ComradeBase.UnstructuredDomain
ComradeBase.imagepixels
ComradeBase.dims
ComradeBase.named_dims
ComradeBase.axisdims
Expand All @@ -89,8 +90,20 @@ ComradeBase.baseimage
ComradeBase.centroid
ComradeBase.second_moment
ComradeBase.stokes
ComradeBase.dualmap
ComradeBase.DualMap
```

### Time and Frequency Domain
```@docs
ComradeBase.DomainParams
ComradeBase.paramtype
ComradeBase.getparam
ComradeBase.@unpack_params
ComradeBase.build_param
```


## Internal Methods not part of public API
```@docs
ComradeBase._visibilitymap
Expand Down
2 changes: 2 additions & 0 deletions examples/advanced/FitPS/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
Comrade = "99d987ce-9a1e-4df8-bc0b-1ea019aa547b"
ComradeBase = "6d8c423b-a35f-4ef1-850c-862fe21f82c4"
DisplayAs = "0b91fe84-8a4c-11e9-3e1d-67c38462b6d6"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
Expand All @@ -18,6 +19,7 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
VLBIImagePriors = "b1ba175b-8447-452c-b961-7db2d6f7a029"
VLBILikelihoods = "90db92cd-0007-4c0a-8e51-dbf0782ce592"
VLBISkyModels = "d6343c73-7174-4e0f-bb64-562643efbeca"

[compat]
CairoMakie = "0.15"
Expand Down
14 changes: 7 additions & 7 deletions examples/advanced/FitPS/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ using NonuniformFFTs

# For reproducibility we use a stable random number genreator
using StableRNGs
rng = StableRNG(123)
rng = StableRNG(11)


# ## Load the Data
Expand Down Expand Up @@ -68,7 +68,7 @@ function sky(θ, metadata)
(; fb, c, ρs, σimg) = θ
(; mimg, pl) = metadata
## Apply the GMRF fluctuations to the image
x = genfield(StationaryRandomField(MarkovPS(ρs .^ 2), pl), c)
x = genfield(StationaryRandomField(MarkovPS(ρs), pl), c)
x .= σimg .* x
fbn = fb / length(mimg)
mb = mimg .* (1 - fb) .+ fbn
Expand All @@ -92,7 +92,7 @@ grid = imagepixels(fovx, fovy, nx, ny, μas2rad(150.0), -μas2rad(150.0))
# image. For this work we will use a symmetric Gaussian with a FWHM equal to the approximate
# beamsize of the array. This models the fact that we expect the AGN core to be compact.
fwhmfac = 2 * sqrt(2 * log(2))
mpr = modify(Gaussian(), Stretch(beamsize(dlcamp) / 4 / fwhmfac))
mpr = modify(TBlob(3.0), Stretch(beamsize(dlcamp) / 4 / fwhmfac))
imgpr = intensitymap(mpr, grid)
# To momdel the power spectrum we also need to construct our execution plan for the given grid.
# This will be used to construct the actual correlated realization of the RF given some initial
Expand All @@ -110,7 +110,7 @@ cprior = std_dist(pl)
# allows for a wide range of power spectra. Additionally, we truncate the expansion at order 3
# for simplicity in this tutorial.
using Distributions
ρs = ntuple(Returns(Uniform(0.1, 2 * max(size(grid)...))), 3)
ρs = ntuple(Returns(Uniform(0.01, max(size(grid)...))), 3)

# Putting everything together the total prior is then our image prior, a prior on the
# standard deviation of the MRF, and a prior on the fractional flux of the Gaussian component.
Expand Down Expand Up @@ -141,9 +141,9 @@ post = VLBIPosterior(skym, dlcamp, dcphase)
# functionality a user first needs to import `Optimization.jl` and the optimizer of choice.
# In this tutorial we will use the Adam optimizer.
# We also need to import Enzyme to allow for automatic differentiation.
using Optimization, OptimizationOptimisers
using Optimization, OptimizationLBFGSB
# tpost = asflat(post)
xopt, sol = comrade_opt(post, Adam(); maxiters = 5000)
xopt, sol = comrade_opt(post, LBFGSB(); initial_params = prior_sample(rng, post), maxiters = 5000)

using CairoMakie
using DisplayAs #hide
Expand Down Expand Up @@ -197,7 +197,7 @@ k = range(1 / size(grid)[1], π / 2, length = 512)
fig = Figure()
ax = Axis(fig[1, 1], xscale = log10, yscale = log10)
for i in 501:10:length(chain)
lines!(ax, k, VLBIImagePriors.ampspectrum.(Ref(MarkovPS(chain.sky.ρs[i] .^ 2)), tuple.(k, 0)))
lines!(ax, k, VLBIImagePriors.ampspectrum.(Ref(MarkovPS(chain.sky.ρs[i])), tuple.(k, 0)))
end
fig

Expand Down
2 changes: 2 additions & 0 deletions examples/intermediate/ClosureImaging/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
Comrade = "99d987ce-9a1e-4df8-bc0b-1ea019aa547b"
ComradeBase = "6d8c423b-a35f-4ef1-850c-862fe21f82c4"
DisplayAs = "0b91fe84-8a4c-11e9-3e1d-67c38462b6d6"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
Expand All @@ -16,6 +17,7 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
VLBIImagePriors = "b1ba175b-8447-452c-b961-7db2d6f7a029"
VLBILikelihoods = "90db92cd-0007-4c0a-8e51-dbf0782ce592"
VLBISkyModels = "d6343c73-7174-4e0f-bb64-562643efbeca"

[compat]
CairoMakie = "0.15"
Expand Down
3 changes: 1 addition & 2 deletions examples/intermediate/ClosureImaging/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,9 @@ function sky(θ, metadata)
rast = apply_fluctuations(CenteredLR(), mimg, σimg .* c.params)
m = ContinuousImage(((1 - fg)) .* rast, BSplinePulse{3}())
## Force the image centroid to be at the origin
x0, y0 = centroid(m)
## Add a large-scale gaussian to deal with the over-resolved mas flux
g = modify(Gaussian(), Stretch(μas2rad(250.0), μas2rad(250.0)), Renormalize(fg))
return shifted(m, -x0, -y0) + g
return m + g
end


Expand Down
2 changes: 2 additions & 0 deletions examples/intermediate/PolarizedImaging/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
Comrade = "99d987ce-9a1e-4df8-bc0b-1ea019aa547b"
ComradeBase = "6d8c423b-a35f-4ef1-850c-862fe21f82c4"
DisplayAs = "0b91fe84-8a4c-11e9-3e1d-67c38462b6d6"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
Expand All @@ -14,3 +15,4 @@ Pyehtim = "3d61700d-6e5b-419a-8e22-9c066cf00468"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
VLBIImagePriors = "b1ba175b-8447-452c-b961-7db2d6f7a029"
VLBISkyModels = "d6343c73-7174-4e0f-bb64-562643efbeca"
6 changes: 3 additions & 3 deletions examples/intermediate/PolarizedImaging/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -172,9 +172,9 @@ function sky(θ, metadata)
end

pmap .= ftot .* pmap ./ ft
x0, y0 = centroid(pmap)
m = ContinuousImage(pmap, BSplinePulse{3}())
return shifted(m, -x0, -y0)
x, y = centroid(pmap)
return shifted(m, -x, -y)
end


Expand Down Expand Up @@ -384,7 +384,7 @@ fig |> DisplayAs.PNG |> DisplayAs.Text
# other imaging examples. For example
# ```julia
# using AdvancedHMC
# chain = sample(rng, post, NUTS(0.8), 10_000, n_adapts=5000, progress=true, initial_params=xopt)
# chain = sample(rng, post, NUTS(0.8), 2_000, n_adapts = 1000, progress = true, initial_params = xopt)
# ```


Expand Down
10 changes: 5 additions & 5 deletions examples/intermediate/StokesIImaging/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,10 @@ function sky(θ, metadata)
pimg = parent(rast)
@. pimg = (ftot * (1 - fg)) * pimg
m = ContinuousImage(rast, BSplinePulse{3}())
x0, y0 = centroid(m)
## Add a large-scale gaussian to deal with the over-resolved mas flux
g = modify(Gaussian(), Stretch(μas2rad(500.0), μas2rad(500.0)), Renormalize(ftot * fg))
return shifted(m, -x0, -y0) + g
x, y = centroid(m)
return shifted(m, -x, -y) + g
end


Expand All @@ -90,7 +90,7 @@ grid = imagepixels(fovx, fovy, npix, npix)
using VLBIImagePriors
using Distributions
fwhmfac = 2 * sqrt(2 * log(2))
mpr = modify(Gaussian(), Stretch(μas2rad(50.0) ./ fwhmfac))
mpr = modify(Gaussian(), Stretch(μas2rad(60.0) ./ fwhmfac))
mimg = intensitymap(mpr, grid)


Expand Down Expand Up @@ -219,7 +219,7 @@ plotcaltable(abs.(intopt)) |> DisplayAs.PNG |> DisplayAs.Text
# run.
#-
using AdvancedHMC
chain = sample(rng, post, NUTS(0.8), 1_000; n_adapts = 500, initial_params = xopt)
chain = sample(rng, post, NUTS(0.8), 700; n_adapts = 500, initial_params = xopt)
#-
# !!! note
# The above sampler will store the samples in memory, i.e. RAM. For large models this
Expand All @@ -231,7 +231,7 @@ chain = sample(rng, post, NUTS(0.8), 1_000; n_adapts = 500, initial_params = xop


# Now we prune the adaptation phase
chain = chain[501:end]
chain = chain[(begin + 500):end]

#-
# !!! warning
Expand Down
4 changes: 0 additions & 4 deletions src/Comrade.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
"""
Comrade
Composable Modeling of Radio Emission
"""
module Comrade

using AbstractMCMC
Expand Down
15 changes: 15 additions & 0 deletions src/observations/dataproducts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,21 @@ export extract_table,
ClosurePhases, LogClosureAmplitudes,
VisibilityAmplitudes, Visibilities, Coherencies


# Traits that decide the domain of the data. We include this to prevent additional computation if
# e.g., on visibility data is considered.
abstract type ComradeDataType end
struct NoData <: ComradeDataType end
struct DualData <: ComradeDataType end
struct VisData <: ComradeDataType end
struct ImgData <: ComradeDataType end

datatype(::Type, ::Type{<:Nothing}) = VisData()
datatype(::Type{<:Nothing}, ::Type) = ImgData()
datatype(::Type{<:Nothing}, ::Type{<:Nothing}) = throw(ArgumentError("No data in uv plane or image plane is provided"))
datatype(::Type, ::Type) = DualData()


abstract type VLBIDataProducts{K} end

keywords(d::VLBIDataProducts) = d.keywords
Expand Down
58 changes: 25 additions & 33 deletions src/posterior/abstract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ Default methods include:
abstract type AbstractVLBIPosterior end
@inline DensityInterface.DensityKind(::AbstractVLBIPosterior) = DensityInterface.IsDensity()


"""
logprior(d::AbstractVLBIPosterior, θ)

Expand Down Expand Up @@ -54,35 +55,6 @@ instrumentmodel(d::AbstractVLBIPosterior) = getfield(d, :instrumentmodel)
HypercubeTransform.dimension(d::AbstractVLBIPosterior) = length(d.prior)
EnzymeRules.inactive(::typeof(instrumentmodel), args...) = nothing

# @noinline logprior_ref(d, x) = logprior(d, x[])

# function ChainRulesCore.rrule(::typeof(logprior), d::AbstractVLBIPosterior, x)
# p = logprior(d, x)
# # We need this
# px = ProjectTo(x)
# function _logprior_pullback(Δ)
# # @info "HERE"
# xr = Ref(x)
# dxr = Ref(ntzero(x))
# autodiff(Reverse, logprior_ref, Active, Const(d), Duplicated(xr, dxr))
# return NoTangent(), NoTangent(), (_perturb(Δ, dxr[]))
# end
# return p, _logprior_pullback
# end

# function _perturb(Δ, x::Union{NamedTuple, Tuple})
# return map(x->_perturb(Δ, x), x)
# end

# function _perturb(Δ, x)
# return Δ*x
# end

# function _perturb(Δ, x::AbstractArray)
# x .= Δ*x
# return x
# end


function DensityInterface.logdensityof(post::AbstractVLBIPosterior, x)
pr = logprior(post, x)
Expand Down Expand Up @@ -118,9 +90,13 @@ Computes the forward model visibilities of the posterior `d` with parameters `θ
Note these are the complex visiblities or the coherency matrices, not the actual
data products observed.
"""
@inline function forward_model_map(D, d::AbstractVLBIPosterior, θ)
img, vis = idealmaps(D, skymodel(d), θ)
return img, apply_instrument(vis, instrumentmodel(d), θ)
end

@inline function forward_model(d::AbstractVLBIPosterior, θ)
vis = idealvisibilities(skymodel(d), θ)
return apply_instrument(vis, instrumentmodel(d), θ)
return forward_model_map(datatype(typeof(d)), d, θ)
end

"""
Expand All @@ -129,9 +105,10 @@ end
Computes the log-likelihood of the posterior `d` with parameters `θ`.
"""
@inline function loglikelihood(d::AbstractVLBIPosterior, θ)
vis = forward_model(d, θ)
img, vis = forward_model(d, θ)
# Convert because of conventions
return logdensityofvis(d.lklhds, vis)
lis = d.lklhdsimg
return logdensityofvis(d.lklhds, vis) + logdensityofimg(lis, img)
end

"""
Expand Down Expand Up @@ -172,6 +149,21 @@ end
return sum(ls)
end

@inline function logdensityofimg(lklhds, img::IntensityMap)
fl = Base.Fix2(logdensityof, img)
ls = map(fl, lklhds)
return sum(ls)
end

## There is no image data so just return 0
@inline function logdensityofimg(lklhds::Tuple{}, img::Number)
return zero(img)
end

@inline function logdensityofimg(lklhds::Tuple{}, img::AbstractArray{T}) where {T}
return zero(T)
end


include("likelihood.jl")
include("vlbiposterior.jl")
Expand Down
Loading
Loading