Skip to content

Commit 6eadc14

Browse files
committed
update
1 parent 11daccc commit 6eadc14

File tree

4 files changed

+48
-13
lines changed

4 files changed

+48
-13
lines changed

examples/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
3232
TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9"
3333
VLBIImagePriors = "b1ba175b-8447-452c-b961-7db2d6f7a029"
3434
VLBILikelihoods = "90db92cd-0007-4c0a-8e51-dbf0782ce592"
35+
VLBISkyModels = "d6343c73-7174-4e0f-bb64-562643efbeca"
3536
WGLMakie = "276b4fcb-3e11-5398-bf8b-a0c2d153d008"
3637
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
3738

playground/enzyme_dft_vis.jl

Lines changed: 42 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,14 @@ dvis = extract_table(obs, ComplexVisibilities())
3838

3939

4040
function sky(θ, metadata)
41-
c = θ
41+
c = θ.c
4242
(;grid, cache) = metadata
4343
img = IntensityMap(c, grid)
4444
m = ContinuousImage(img, cache)
4545
return m
4646
end
4747

48-
npix = 12
48+
npix = 48
4949
fovx = μas2rad(80.0)
5050
fovy = μas2rad(80.0)
5151

@@ -56,24 +56,58 @@ cache = create_cache(NFFTAlg(dvis), buffer, DeltaPulse())
5656

5757

5858
using VLBIImagePriors
59-
prior = ImageUniform(npix, npix)
6059

6160
skymetadata = (;grid, cache)
62-
# instrumentmetadata = (;gcache, gcachep)
63-
lklhd = RadioLikelihood(sky, dvis; skymeta=skymetadata)
61+
62+
function instrument(θ, metadata)
63+
(; lgamp, gphase) = θ
64+
(; gcache, gcachep) = metadata
65+
## Now form our instrument model
66+
gvis = exp.(lgamp)
67+
gphase = exp.(1im.*gphase)
68+
jgamp = jonesStokes(gvis, gcache)
69+
jgphase = jonesStokes(gphase, gcachep)
70+
return JonesModel(jgamp*jgphase)
71+
end
72+
73+
gcache = jonescache(dvis, ScanSeg())
74+
gcachep = jonescache(dvis, ScanSeg(); autoref=SEFDReference((complex(1.0))))
75+
76+
using VLBIImagePriors
77+
# Now we can form our metadata we need to fully define our model.
78+
metadata = (;gcache, gcachep)
79+
80+
using Distributions
81+
using DistributionsAD
82+
distamp = station_tuple(dvis, Normal(0.0, 0.1); LM = Normal(1.0))
83+
84+
distphase = station_tuple(dvis, DiagonalVonMises(0.0, inv(π^2)))
85+
86+
87+
prior = NamedDist(
88+
c = ImageUniform(npix, npix),
89+
lgamp = CalPrior(distamp, gcache),
90+
gphase = CalPrior(distphase, gcachep),
91+
)
92+
93+
94+
95+
lklhd = RadioLikelihood(sky, instrument, dvis; skymeta=skymetadata, instrumentmeta=metadata)
6496
post = Posterior(lklhd, prior)
6597

6698
tpost = asflat(post)
6799
ndim = dimension(tpost)
68100

69101
using Enzyme
102+
using Zygote
70103
Enzyme.API.runtimeActivity!(true)
71104
# Enzyme.API.printall!(false)
72105
x0 = prior_sample(tpost)
73106
dx0 = zero(x0)
74107
lt=logdensityof(tpost)
75108
= logdensityof(tpost, x0)
76-
dtpost = deepcopy(tpost)
109+
gz, = Zygote.gradient(lt, x0)
77110
using BenchmarkTools
78-
@time autodiff(Reverse, logdensityof, Const(tpost), Duplicated(x0, fill!(dx0, 0.0)))
79-
@btime autodiff(Reverse, logdensityof, $(Duplicated(tpost, dtpost)), Duplicated($x0, fill!($dx0, 0.0)))
111+
autodiff(Reverse, Const(lt), Active, Duplicated(x0, fill!(dx0, 0.0)))
112+
@benchmark autodiff(Reverse, logdensityof, $(Const(tpost)), Duplicated($x0, fill!($dx0, 0.0)))
113+
@benchmark Zygote.gradient($lt, $x0)

playground/enzyme_geom_vis.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ using LinearAlgebra
2121
using StableRNGs
2222
rng = StableRNG(42)
2323

24-
24+
Enzyme.Compiler.bitcode_replacement!(false)
2525

2626
# ## Load the Data
2727

@@ -106,7 +106,7 @@ prior = NamedDist(
106106

107107
# Putting it all together we form our likelihood and posterior objects for optimization and
108108
# sampling.
109-
lklhd = RadioLikelihood(sky, instrument, dvis; instrumentmeta = metadata)
109+
lklhd = RadioLikelihood(sky, instrument, dvis; instrumentmeta=metadata)
110110
post = Posterior(lklhd, prior)
111111

112112
# ## Reconstructing the Image and Instrument Effects
@@ -122,10 +122,10 @@ ndim = dimension(tpost)
122122
# inference packages use this interface as well.
123123
using Zygote
124124
using Enzyme
125-
Enzyme.API.runtimeActivity!(true)
125+
# Enzyme.API.runtimeActivity!(true)
126126

127127
x0 = randn(rng, ndim)
128128
= logdensityof(tpost)
129129
gz, = Zygote.gradient(ℓ, x0)
130130
dx0 = zero(x0)
131-
autodiff(Reverse, logdensityof, Duplicated(tpost, deepcopy(tpost)), Duplicated(x0, dx0))
131+
autodiff(Reverse, Const(ℓ), Active, Duplicated(x0, dx0))

src/calibration/jones.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
export JonesCache, TrackSeg, ScanSeg, FixedSeg, IntegSeg, jonesG, jonesD, jonesT,
2-
ResponseCache, JonesModel, jonescache, station_tuple, jonesmap
2+
ResponseCache, JonesModel, jonescache, station_tuple
33

44
"""
55
$(TYPEDEF)

0 commit comments

Comments
 (0)