Skip to content

Commit 353c2f9

Browse files
committed
Use templating info in MH
1 parent 4041673 commit 353c2f9

1 file changed

Lines changed: 17 additions & 12 deletions

File tree

src/mcmc/mh.jl

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,10 @@ spl = MH(
6464
6565
You can also use a callable to define a proposal that is conditional on the current values.
6666
The callable must accept a single argument, which is a `DynamicPPL.VarNamedTuple` that holds
67-
all the current values of the parameters. You can obtain the value of a specific parameter
68-
by indexing into this `VarNamedTuple` using a `VarName` (note that symbol indexing is not
69-
supported). The callable must then return a `Distribution` from which to draw the proposal.
67+
all the values of the parameters from the previous step. You can obtain the value of a
68+
specific parameter by indexing into this `VarNamedTuple` using a `VarName` (note that symbol
69+
indexing is not supported). The callable must then return a `Distribution` from which to
70+
draw the proposal.
7071
7172
!!! note
7273
In general, there is no way for Turing to reliably detect whether a proposal is meant to
@@ -125,9 +126,14 @@ function DynamicPPL.init(
125126
return DynamicPPL.UntransformedValue(rand(rng, dist))
126127
end
127128

128-
function MH(pair1, pairs...)
129+
const SymOrVNPair = Pair{<:Union{Symbol,VarName},<:Any}
130+
131+
function MH(pair1::SymOrVNPair, pairs::Vararg{SymOrVNPair})
129132
vn_proposal_pairs = (pair1, pairs...)
130133
return MH(
134+
# It is assumed that `vnt` is a VarNamedTuple that has all the variables' values
135+
# already set. NOTE: It doesn't store `AbstractTransformedValue`s, but the actual
136+
# raw values.
131137
vnt -> begin
132138
proposals = DynamicPPL.VarNamedTuple()
133139
for pair in vn_proposal_pairs
@@ -143,16 +149,15 @@ function MH(pair1, pairs...)
143149
)
144150
end
145151
# Check whether the proposal is a Distribution.
146-
# TODO(penelopeysm): We don't have templating information here. :(
147-
# We *could* steal the templating info from the `vnt` argument. It would
148-
# be quite type unstable since you would be mixing and matching
149-
# distributions with values. Need to check if it's worth it.
150-
if proposal isa Distribution
151-
proposals = BangBang.setindex!!(proposals, proposal, vn)
152+
proposal_dist = if proposal isa Distribution
153+
proposal
152154
else
153-
# Assume it's a callable.
154-
proposals = BangBang.setindex!!(proposals, proposal(vnt), vn)
155+
# It's a callable that takes `vnt` and returns a distribution.
156+
proposal(vnt)
155157
end
158+
proposals = DynamicPPL.templated_setindex!!(
159+
proposals, proposal_dist, vn, vnt.data[AbstractPPL.getsym(vn)]
160+
)
156161
end
157162
return InitFromProposals(proposals, Dict{VarName,Distribution}())
158163
end,

0 commit comments

Comments
 (0)