@@ -65,3 +65,101 @@ function externalsampler(
6565)
6666 return ExternalSampler (sampler, adtype, Val (unconstrained))
6767end
68+
69+ struct TuringState{S,M,V,C}
70+ state:: S
71+ ldf:: DynamicPPL.LogDensityFunction{M,V,C}
72+ end
73+
74+ state_to_turing (f:: DynamicPPL.LogDensityFunction , state) = TuringState (state, f)
75+ function transition_to_turing (f:: DynamicPPL.LogDensityFunction , transition)
76+ # TODO : We should probably rename this `getparams` since it returns something
77+ # very different from `Turing.Inference.getparams`.
78+ θ = getparams (f. model, transition)
79+ varinfo = DynamicPPL. unflatten (f. varinfo, θ)
80+ return Transition (f. model, varinfo, transition)
81+ end
82+
83+ function varinfo (state:: TuringState )
84+ θ = getparams (state. ldf. model, state. state)
85+ # TODO : Do we need to link here first?
86+ return DynamicPPL. unflatten (state. ldf. varinfo, θ)
87+ end
88+ varinfo (state:: AbstractVarInfo ) = state
89+
90+ # NOTE: Only thing that depends on the underlying sampler.
91+ # Something similar should be part of AbstractMCMC at some point:
92+ # https://github.com/TuringLang/AbstractMCMC.jl/pull/86
93+ getparams (:: DynamicPPL.Model , transition:: AdvancedHMC.Transition ) = transition. z. θ
94+ function getparams (model:: DynamicPPL.Model , state:: AdvancedHMC.HMCState )
95+ return getparams (model, state. transition)
96+ end
97+ getstats (transition:: AdvancedHMC.Transition ) = transition. stat
98+
99+ getparams (:: DynamicPPL.Model , transition:: AdvancedMH.Transition ) = transition. params
100+
101+ # TODO : Do we also support `resume`, etc?
102+ function AbstractMCMC. step (
103+ rng:: Random.AbstractRNG ,
104+ model:: DynamicPPL.Model ,
105+ sampler_wrapper:: Sampler{<:ExternalSampler} ;
106+ initial_state= nothing ,
107+ initial_params= nothing ,
108+ kwargs... ,
109+ )
110+ alg = sampler_wrapper. alg
111+ sampler = alg. sampler
112+
113+ # Initialise varinfo with initial params and link the varinfo if needed.
114+ varinfo = DynamicPPL. VarInfo (model)
115+ if requires_unconstrained_space (alg)
116+ if initial_params != = nothing
117+ # If we have initial parameters, we need to set the varinfo before linking.
118+ varinfo = DynamicPPL. link (DynamicPPL. unflatten (varinfo, initial_params), model)
119+ # Extract initial parameters in unconstrained space.
120+ initial_params = varinfo[:]
121+ else
122+ varinfo = DynamicPPL. link (varinfo, model)
123+ end
124+ end
125+
126+ # Construct LogDensityFunction
127+ f = DynamicPPL. LogDensityFunction (model, varinfo; adtype= alg. adtype)
128+
129+ # Then just call `AbstractMCMC.step` with the right arguments.
130+ if initial_state === nothing
131+ transition_inner, state_inner = AbstractMCMC. step (
132+ rng, AbstractMCMC. LogDensityModel (f), sampler; initial_params, kwargs...
133+ )
134+ else
135+ transition_inner, state_inner = AbstractMCMC. step (
136+ rng,
137+ AbstractMCMC. LogDensityModel (f),
138+ sampler,
139+ initial_state;
140+ initial_params,
141+ kwargs... ,
142+ )
143+ end
144+ # Update the `state`
145+ return transition_to_turing (f, transition_inner), state_to_turing (f, state_inner)
146+ end
147+
148+ function AbstractMCMC. step (
149+ rng:: Random.AbstractRNG ,
150+ model:: DynamicPPL.Model ,
151+ sampler_wrapper:: Sampler{<:ExternalSampler} ,
152+ state:: TuringState ;
153+ kwargs... ,
154+ )
155+ sampler = sampler_wrapper. alg. sampler
156+ f = state. ldf
157+
158+ # Then just call `AdvancedHMC.step` with the right arguments.
159+ transition_inner, state_inner = AbstractMCMC. step (
160+ rng, AbstractMCMC. LogDensityModel (f), sampler, state. state; kwargs...
161+ )
162+
163+ # Update the `state`
164+ return transition_to_turing (f, transition_inner), state_to_turing (f, state_inner)
165+ end
0 commit comments