Skip to content

Commit 4041673

Browse files
committed
fixes for MH
1 parent 4cb1326 commit 4041673

3 files changed

Lines changed: 27 additions & 114 deletions

File tree

src/mcmc/mh.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,8 @@ function DynamicPPL.init(
125125
return DynamicPPL.UntransformedValue(rand(rng, dist))
126126
end
127127

128-
function MH(vn_proposal_pairs...)
128+
function MH(pair1, pairs...)
129+
vn_proposal_pairs = (pair1, pairs...)
129130
return MH(
130131
vnt -> begin
131132
proposals = DynamicPPL.VarNamedTuple()

src/mcmc/particle_mcmc.jl

Lines changed: 24 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -334,90 +334,53 @@ function AbstractMCMC.step(
334334
end
335335

336336
"""
337-
get_trace_local_varinfo_maybe(vi::AbstractVarInfo)
337+
get_trace_local_varinfo()
338338
339339
Get the varinfo stored in the 'taped globals' of a `Libtask.TapedTask`, if one exists.
340-
If this function is not called within a TapedTask, return the provided `varinfo`.
341340
"""
342-
function get_trace_local_varinfo_maybe(varinfo::AbstractVarInfo)
343-
trace = try
344-
Libtask.get_taped_globals(Any).other
345-
catch e
346-
e == KeyError(:task_variable) ? nothing : rethrow(e)
347-
end
348-
has_trace = trace !== nothing
349-
return (has_trace ? trace.model.f.varinfo : varinfo)::AbstractVarInfo
341+
function get_trace_local_varinfo()
342+
trace = Libtask.get_taped_globals(Any).other
343+
return trace.model.f.varinfo::AbstractVarInfo
350344
end
351345

352346
"""
353-
get_trace_local_resampled_maybe(fallback_resampled::Bool)
347+
get_trace_local_resampled()
354348
355-
Get the Boolean `resample` value stored in the 'taped globals' of a `Libtask.TapedTask`, if
356-
one exists. If this function is not called within a TapedTask, return the provided
357-
`fallback_resampled` value.
349+
Get the Boolean `resample` value stored in the 'taped globals' of a `Libtask.TapedTask`.
358350
"""
359-
function get_trace_local_resampled_maybe(fallback_resampled::Bool)
360-
trace = try
361-
Libtask.get_taped_globals(Any).other
362-
catch e
363-
e == KeyError(:task_variable) ? nothing : rethrow(e)
364-
end
365-
has_trace = trace !== nothing
366-
return (has_trace ? trace.model.f.resample : fallback_resampled)::Bool
351+
function get_trace_local_resampled()
352+
trace = Libtask.get_taped_globals(Any).other
353+
return trace.model.f.resample::Bool
367354
end
368355

369356
"""
370-
get_trace_local_rng_maybe(rng::Random.AbstractRNG)
357+
get_trace_local_rng()
371358
372-
Get the `Trace` local rng if one exists.
373-
374-
If executed within a `TapedTask`, return the `rng` stored in the "taped globals" of the
375-
task, otherwise return `vi`.
359+
Get the RNG stored in the 'taped globals' of a `Libtask.TapedTask`, if one exists.
376360
"""
377-
function get_trace_local_rng_maybe(rng::Random.AbstractRNG)
378-
return try
379-
Libtask.get_taped_globals(Any).rng
380-
catch e
381-
e == KeyError(:task_variable) ? rng : rethrow(e)
382-
end
361+
function get_trace_local_rng()
362+
return Libtask.get_taped_globals(Any).rng
383363
end
384364

385365
"""
386-
set_trace_local_varinfo_maybe(vi::AbstractVarInfo)
366+
set_trace_local_varinfo(vi::AbstractVarInfo)
387367
388368
Set the `Trace` local varinfo if executing within a `Trace`. Return `nothing`.
389369
390370
If executed within a `TapedTask`, set the `varinfo` stored in the "taped globals" of the
391371
task. Otherwise do nothing.
392372
"""
393-
function set_trace_local_varinfo_maybe(vi::AbstractVarInfo)
394-
# TODO(mhauru) This should be done in a try-catch block, as in the commented out code.
395-
# However, Libtask currently can't handle this block.
396-
trace = #try
397-
Libtask.get_taped_globals(Any).other
398-
# catch e
399-
# e == KeyError(:task_variable) ? nothing : rethrow(e)
400-
# end
401-
if trace !== nothing
402-
# trace isa AdvancedPS.Trace (defined in AdvancedPS src/model.jl)
403-
# trace.model isa AdvancedPS.LibtaskModel (defined in AdvancedPSLibtaskExt)
404-
# trace.model.f isa TracedModel (defined above)
405-
trace.model.f.varinfo = vi
406-
end
407-
return nothing
373+
function set_trace_local_varinfo(vi::AbstractVarInfo)
374+
trace = Libtask.get_taped_globals(Any).other
375+
return trace.model.f.varinfo = vi
408376
end
409377

410378
function DynamicPPL.tilde_assume!!(
411-
ctx::ParticleMCMCContext,
412-
dist::Distribution,
413-
vn::VarName,
414-
template::Any,
415-
vi::AbstractVarInfo,
379+
::ParticleMCMCContext, dist::Distribution, vn::VarName, template::Any, ::AbstractVarInfo
416380
)
417-
vi = get_trace_local_varinfo_maybe(vi)
418-
419-
trng = get_trace_local_rng_maybe(ctx.rng)
420-
resample = get_trace_local_resampled_maybe(true)
381+
vi = get_trace_local_varinfo()
382+
trng = get_trace_local_rng()
383+
resample = get_trace_local_resampled()
421384

422385
dispatch_ctx = if ~haskey(vi, vn) || resample
423386
DynamicPPL.InitContext(trng, DynamicPPL.InitFromPrior())
@@ -426,7 +389,7 @@ function DynamicPPL.tilde_assume!!(
426389
end
427390
x, vi = DynamicPPL.tilde_assume!!(dispatch_ctx, dist, vn, template, vi)
428391

429-
set_trace_local_varinfo_maybe(vi)
392+
set_trace_local_varinfo(vi)
430393
return x, vi
431394
end
432395

@@ -437,9 +400,9 @@ function DynamicPPL.tilde_observe!!(
437400
vn::Union{VarName,Nothing},
438401
vi::AbstractVarInfo,
439402
)
440-
vi = get_trace_local_varinfo_maybe(vi)
403+
vi = get_trace_local_varinfo()
441404
left, vi = DynamicPPL.tilde_observe!!(DefaultContext(), right, left, vn, vi)
442-
set_trace_local_varinfo_maybe(vi)
405+
set_trace_local_varinfo(vi)
443406
return left, vi
444407
end
445408

test/mcmc/mh.jl

Lines changed: 1 addition & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,7 @@ GKernel(variance) = (vnt -> Normal(vnt[@varname(m)], sqrt(variance)))
7070
@varname(mu1) => MH(:mu1 => GKernel(1)),
7171
@varname(mu2) => MH(:mu2 => GKernel(1)),
7272
)
73-
initial_params = InitFromParams((
74-
mu1=1.0, mu2=1.0, z1=0.0, z2=0.0, z3=1.0, z4=1.0
75-
))
73+
initial_params = InitFromParams((mu1=1.0, mu2=1.0, z1=0, z2=0, z3=1, z4=1))
7674
chain = sample(
7775
StableRNG(seed),
7876
MoGtest_default,
@@ -160,55 +158,6 @@ GKernel(variance) = (vnt -> Normal(vnt[@varname(m)], sqrt(variance)))
160158
end
161159
end
162160
end
163-
164-
@testset "MH link/invlink" begin
165-
vi_base = DynamicPPL.VarInfo(gdemo_default)
166-
167-
# Don't link when no proposals are given since we're using priors
168-
# as proposals.
169-
vi = deepcopy(vi_base)
170-
spl = MH()
171-
vi = Turing.Inference.maybe_link!!(vi, spl, spl.proposals, gdemo_default)
172-
@test !DynamicPPL.is_transformed(vi)
173-
174-
# Link if proposal is `AdvancedHM.RandomWalkProposal`
175-
vi = deepcopy(vi_base)
176-
d = length(vi_base[:])
177-
spl = MH(AdvancedMH.RandomWalkProposal(MvNormal(zeros(d), I)))
178-
vi = Turing.Inference.maybe_link!!(vi, spl, spl.proposals, gdemo_default)
179-
@test DynamicPPL.is_transformed(vi)
180-
181-
# Link if ALL proposals are `AdvancedHM.RandomWalkProposal`.
182-
vi = deepcopy(vi_base)
183-
spl = MH(:s => AdvancedMH.RandomWalkProposal(Normal()))
184-
vi = Turing.Inference.maybe_link!!(vi, spl, spl.proposals, gdemo_default)
185-
@test DynamicPPL.is_transformed(vi)
186-
187-
# Don't link if at least one proposal is NOT `RandomWalkProposal`.
188-
# TODO: make it so that only those that are using `RandomWalkProposal`
189-
# are linked! I.e. resolve https://github.com/TuringLang/Turing.jl/issues/1583.
190-
# https://github.com/TuringLang/Turing.jl/pull/1582#issuecomment-817148192
191-
vi = deepcopy(vi_base)
192-
spl = MH(
193-
:m => AdvancedMH.StaticProposal(Normal()),
194-
:s => AdvancedMH.RandomWalkProposal(Normal()),
195-
)
196-
vi = Turing.Inference.maybe_link!!(vi, spl, spl.proposals, gdemo_default)
197-
@test !DynamicPPL.is_transformed(vi)
198-
end
199-
200-
@testset "`filldist` proposal (issue #2180)" begin
201-
@model demo_filldist_issue2180() = x ~ MvNormal(zeros(3), I)
202-
chain = sample(
203-
StableRNG(seed),
204-
demo_filldist_issue2180(),
205-
MH(AdvancedMH.RandomWalkProposal(filldist(Normal(), 3))),
206-
10_000,
207-
)
208-
check_numerical(
209-
chain, [Symbol("x[1]"), Symbol("x[2]"), Symbol("x[3]")], [0, 0, 0]; atol=0.2
210-
)
211-
end
212161
end
213162

214163
end

0 commit comments

Comments
 (0)