Skip to content

Commit 30bd7a5

Browse files
committed
Move abstractmcmc.jl to external_sampler.jl
1 parent 1ad8a40 commit 30bd7a5

5 files changed

Lines changed: 100 additions & 101 deletions

File tree

src/mcmc/Inference.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,6 @@ end
123123
# Default Transition #
124124
######################
125125
# Default
126-
# Extended in contrib/inference/abstractmcmc.jl
127126
getstats(t) = nothing
128127

129128
abstract type AbstractTransition end
@@ -359,7 +358,6 @@ end
359358
# Concrete algorithm implementations. #
360359
#######################################
361360

362-
include("abstractmcmc.jl")
363361
include("ess.jl")
364362
include("hmc.jl")
365363
include("mh.jl")

src/mcmc/abstractmcmc.jl

Lines changed: 0 additions & 97 deletions
This file was deleted.

src/mcmc/external_sampler.jl

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,3 +65,101 @@ function externalsampler(
6565
)
6666
return ExternalSampler(sampler, adtype, Val(unconstrained))
6767
end
68+
69+
struct TuringState{S,M,V,C}
70+
state::S
71+
ldf::DynamicPPL.LogDensityFunction{M,V,C}
72+
end
73+
74+
state_to_turing(f::DynamicPPL.LogDensityFunction, state) = TuringState(state, f)
75+
function transition_to_turing(f::DynamicPPL.LogDensityFunction, transition)
76+
# TODO: We should probably rename this `getparams` since it returns something
77+
# very different from `Turing.Inference.getparams`.
78+
θ = getparams(f.model, transition)
79+
varinfo = DynamicPPL.unflatten(f.varinfo, θ)
80+
return Transition(f.model, varinfo, transition)
81+
end
82+
83+
function varinfo(state::TuringState)
84+
θ = getparams(state.ldf.model, state.state)
85+
# TODO: Do we need to link here first?
86+
return DynamicPPL.unflatten(state.ldf.varinfo, θ)
87+
end
88+
varinfo(state::AbstractVarInfo) = state
89+
90+
# NOTE: Only thing that depends on the underlying sampler.
91+
# Something similar should be part of AbstractMCMC at some point:
92+
# https://github.com/TuringLang/AbstractMCMC.jl/pull/86
93+
getparams(::DynamicPPL.Model, transition::AdvancedHMC.Transition) = transition.z.θ
94+
function getparams(model::DynamicPPL.Model, state::AdvancedHMC.HMCState)
95+
return getparams(model, state.transition)
96+
end
97+
getstats(transition::AdvancedHMC.Transition) = transition.stat
98+
99+
getparams(::DynamicPPL.Model, transition::AdvancedMH.Transition) = transition.params
100+
101+
# TODO: Do we also support `resume`, etc?
102+
function AbstractMCMC.step(
103+
rng::Random.AbstractRNG,
104+
model::DynamicPPL.Model,
105+
sampler_wrapper::Sampler{<:ExternalSampler};
106+
initial_state=nothing,
107+
initial_params=nothing,
108+
kwargs...,
109+
)
110+
alg = sampler_wrapper.alg
111+
sampler = alg.sampler
112+
113+
# Initialise varinfo with initial params and link the varinfo if needed.
114+
varinfo = DynamicPPL.VarInfo(model)
115+
if requires_unconstrained_space(alg)
116+
if initial_params !== nothing
117+
# If we have initial parameters, we need to set the varinfo before linking.
118+
varinfo = DynamicPPL.link(DynamicPPL.unflatten(varinfo, initial_params), model)
119+
# Extract initial parameters in unconstrained space.
120+
initial_params = varinfo[:]
121+
else
122+
varinfo = DynamicPPL.link(varinfo, model)
123+
end
124+
end
125+
126+
# Construct LogDensityFunction
127+
f = DynamicPPL.LogDensityFunction(model, varinfo; adtype=alg.adtype)
128+
129+
# Then just call `AbstractMCMC.step` with the right arguments.
130+
if initial_state === nothing
131+
transition_inner, state_inner = AbstractMCMC.step(
132+
rng, AbstractMCMC.LogDensityModel(f), sampler; initial_params, kwargs...
133+
)
134+
else
135+
transition_inner, state_inner = AbstractMCMC.step(
136+
rng,
137+
AbstractMCMC.LogDensityModel(f),
138+
sampler,
139+
initial_state;
140+
initial_params,
141+
kwargs...,
142+
)
143+
end
144+
# Update the `state`
145+
return transition_to_turing(f, transition_inner), state_to_turing(f, state_inner)
146+
end
147+
148+
function AbstractMCMC.step(
149+
rng::Random.AbstractRNG,
150+
model::DynamicPPL.Model,
151+
sampler_wrapper::Sampler{<:ExternalSampler},
152+
state::TuringState;
153+
kwargs...,
154+
)
155+
sampler = sampler_wrapper.alg.sampler
156+
f = state.ldf
157+
158+
# Then just call `AdvancedHMC.step` with the right arguments.
159+
transition_inner, state_inner = AbstractMCMC.step(
160+
rng, AbstractMCMC.LogDensityModel(f), sampler, state.state; kwargs...
161+
)
162+
163+
# Update the `state`
164+
return transition_to_turing(f, transition_inner), state_to_turing(f, state_inner)
165+
end
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
module AbstractMCMCTests
1+
module ExternalSamplerTests
22

33
using AbstractMCMC: AbstractMCMC
44
using AdvancedMH: AdvancedMH

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ end
5454
@timeit_include("mcmc/hmc.jl")
5555
@timeit_include("mcmc/Inference.jl")
5656
@timeit_include("mcmc/sghmc.jl")
57-
@timeit_include("mcmc/abstractmcmc.jl")
57+
@timeit_include("mcmc/external_sampler.jl")
5858
@timeit_include("mcmc/mh.jl")
5959
@timeit_include("ext/dynamichmc.jl")
6060
@timeit_include("mcmc/repeat_sampler.jl")

0 commit comments

Comments
 (0)