@@ -38,14 +38,14 @@ dvis = extract_table(obs, ComplexVisibilities())
3838
3939
4040function sky(θ, metadata)
41- c = θ
41+ c = θ. c
4242 (;grid, cache) = metadata
4343 img = IntensityMap(c, grid)
4444 m = ContinuousImage(img, cache)
4545 return m
4646end
4747
48- npix = 12
48+ npix = 48
4949fovx = μas2rad(80.0 )
5050fovy = μas2rad(80.0 )
5151
@@ -56,24 +56,58 @@ cache = create_cache(NFFTAlg(dvis), buffer, DeltaPulse())
5656
5757
5858using VLBIImagePriors
59- prior = ImageUniform(npix, npix)
6059
6160skymetadata = (;grid, cache)
62- # instrumentmetadata = (;gcache, gcachep)
63- lklhd = RadioLikelihood(sky, dvis; skymeta= skymetadata)
61+
62+ function instrument(θ, metadata)
63+ (; lgamp, gphase) = θ
64+ (; gcache, gcachep) = metadata
65+ # # Now form our instrument model
66+ gvis = exp.(lgamp)
67+ gphase = exp.(1im .* gphase)
68+ jgamp = jonesStokes(gvis, gcache)
69+ jgphase = jonesStokes(gphase, gcachep)
70+ return JonesModel(jgamp* jgphase)
71+ end
72+
73+ gcache = jonescache(dvis, ScanSeg())
74+ gcachep = jonescache(dvis, ScanSeg(); autoref= SEFDReference((complex(1.0 ))))
75+
76+ using VLBIImagePriors
77+ # Now we can form our metadata we need to fully define our model.
78+ metadata = (;gcache, gcachep)
79+
80+ using Distributions
81+ using DistributionsAD
82+ distamp = station_tuple(dvis, Normal(0.0 , 0.1 ); LM = Normal(1.0 ))
83+
84+ distphase = station_tuple(dvis, DiagonalVonMises(0.0 , inv(π^ 2 )))
85+
86+
87+ prior = NamedDist(
88+ c = ImageUniform(npix, npix),
89+ lgamp = CalPrior(distamp, gcache),
90+ gphase = CalPrior(distphase, gcachep),
91+ )
92+
93+
94+
95+ lklhd = RadioLikelihood(sky, instrument, dvis; skymeta= skymetadata, instrumentmeta= metadata)
6496post = Posterior(lklhd, prior)
6597
6698tpost = asflat(post)
6799ndim = dimension(tpost)
68100
69101using Enzyme
102+ using Zygote
70103Enzyme. API. runtimeActivity!(true )
71104# Enzyme.API.printall!(false)
72105x0 = prior_sample(tpost)
73106dx0 = zero(x0)
74107lt= logdensityof(tpost)
75108ℓ = logdensityof(tpost, x0)
76- dtpost = deepcopy(tpost )
109+ gz, = Zygote . gradient(lt, x0 )
77110using BenchmarkTools
78- @time autodiff(Reverse, logdensityof, Const(tpost), Duplicated(x0, fill!(dx0, 0.0 )))
79- @btime autodiff(Reverse, logdensityof, $ (Duplicated(tpost, dtpost)), Duplicated($ x0, fill!($ dx0, 0.0 )))
111+ autodiff(Reverse, Const(lt), Active, Duplicated(x0, fill!(dx0, 0.0 )))
112+ @benchmark autodiff(Reverse, logdensityof, $ (Const(tpost)), Duplicated($ x0, fill!($ dx0, 0.0 )))
113+ @benchmark Zygote. gradient($ lt, $ x0)
0 commit comments