@@ -64,9 +64,10 @@ spl = MH(
6464
6565You can also use a callable to define a proposal that is conditional on the current values.
6666The callable must accept a single argument, which is a `DynamicPPL.VarNamedTuple` that holds
67- all the current values of the parameters. You can obtain the value of a specific parameter
68- by indexing into this `VarNamedTuple` using a `VarName` (note that symbol indexing is not
69- supported). The callable must then return a `Distribution` from which to draw the proposal.
67+ all the values of the parameters from the previous step. You can obtain the value of a
68+ specific parameter by indexing into this `VarNamedTuple` using a `VarName` (note that symbol
69+ indexing is not supported). The callable must then return a `Distribution` from which to
70+ draw the proposal.
7071
7172!!! note
7273 In general, there is no way for Turing to reliably detect whether a proposal is meant to
@@ -125,9 +126,14 @@ function DynamicPPL.init(
125126 return DynamicPPL. UntransformedValue (rand (rng, dist))
126127end
127128
128- function MH (pair1, pairs... )
129+ const SymOrVNPair = Pair{<: Union{Symbol,VarName} ,<: Any }
130+
131+ function MH (pair1:: SymOrVNPair , pairs:: Vararg{SymOrVNPair} )
129132 vn_proposal_pairs = (pair1, pairs... )
130133 return MH (
134+ # It is assumed that `vnt` is a VarNamedTuple that has all the variables' values
135+ # already set. NOTE: It doesn't store `AbstractTransformedValue`s, but the actual
136+ # raw values.
131137 vnt -> begin
132138 proposals = DynamicPPL. VarNamedTuple ()
133139 for pair in vn_proposal_pairs
@@ -143,16 +149,15 @@ function MH(pair1, pairs...)
143149 )
144150 end
145151 # Check whether the proposal is a Distribution.
146- # TODO (penelopeysm): We don't have templating information here. :(
147- # We *could* steal the templating info from the `vnt` argument. It would
148- # be quite type unstable since you would be mixing and matching
149- # distributions with values. Need to check if it's worth it.
150- if proposal isa Distribution
151- proposals = BangBang. setindex!! (proposals, proposal, vn)
152+ proposal_dist = if proposal isa Distribution
153+ proposal
152154 else
153- # Assume it 's a callable.
154- proposals = BangBang . setindex!! (proposals, proposal (vnt), vn )
155+ # It 's a callable that takes `vnt` and returns a distribution .
156+ proposal (vnt)
155157 end
158+ proposals = DynamicPPL. templated_setindex!! (
159+ proposals, proposal_dist, vn, vnt. data[AbstractPPL. getsym (vn)]
160+ )
156161 end
157162 return InitFromProposals (proposals, Dict {VarName,Distribution} ())
158163 end ,
0 commit comments