Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/mcmc/Inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,5 +115,7 @@ include("prior.jl")

include("gibbs.jl")
include("gibbs_conditional.jl")
include("gibbs_setparams.jl")
include("gibbs_model_hooks.jl")

end # module
104 changes: 104 additions & 0 deletions src/mcmc/gibbs_model_hooks.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# Disambiguate: (DynamicPPL.Model, AbstractMCMC.Gibbs) is more specific than
# both (AbstractModel, Gibbs) and (DynamicPPL.Model, AbstractSampler).
function AbstractMCMC.step(
rng::Random.AbstractRNG,
model::DynamicPPL.Model,
spl::AbstractMCMC.Gibbs;
initial_params=nothing,
kwargs...,
)
# Turing passes InitFromPrior/InitFromParams as initial_params.
# Treat anything that is not a VarNamedTuple as "no prior values" for condition().
# Still forward initial_params to component samplers so they can initialise.
gv = initial_params isa DynamicPPL.VarNamedTuple ? initial_params : nothing
component_initial_params =
initial_params === nothing ? DynamicPPL.InitFromPrior() : initial_params
sub_states = AbstractMCMC._gibbs_initial_steps(
rng,
model,
spl.varnames,
spl.samplers,
gv;
initial_params=component_initial_params,
kwargs...,
)
global_values = AbstractMCMC._collect_global_values(
model, spl.varnames, spl.samplers, sub_states
)
return AbstractMCMC._build_gibbs_transition(global_values),
AbstractMCMC.GibbsState(global_values, sub_states)
end

function AbstractMCMC.step(
rng::Random.AbstractRNG,
model::DynamicPPL.Model,
spl::AbstractMCMC.Gibbs,
state::AbstractMCMC.GibbsState;
kwargs...,
)
global_values, sub_states = AbstractMCMC._gibbs_sweep(
rng,
model,
spl.varnames,
spl.samplers,
state.sub_states,
state.global_values;
kwargs...,
)
return AbstractMCMC._build_gibbs_transition(global_values),
AbstractMCMC.GibbsState(global_values, sub_states)
end

function AbstractMCMC.condition(
model::DynamicPPL.Model,
target_varnames::AbstractVector{<:VarName},
global_values::DynamicPPL.VarNamedTuple,
)
conditioned_model, _ctx = make_conditional(model, target_varnames, global_values)
return conditioned_model
end

function AbstractMCMC.condition(
model::DynamicPPL.Model, target_varnames::AbstractVector{<:VarName}, ::Nothing
)
return model
end

function AbstractMCMC._init_global_values(
::DynamicPPL.Model, ::AbstractVector{<:VarName}, ::DynamicPPL.Model, sub_state
)
return gibbs_get_raw_values(sub_state)
end

function AbstractMCMC._update_global_values(
::DynamicPPL.Model,
global_values::DynamicPPL.VarNamedTuple,
::AbstractVector{<:VarName},
cond_model::DynamicPPL.Model,
new_params::AbstractVector{<:Real},
)
accs = DynamicPPL.OnlyAccsVarInfo(DynamicPPL.RawValueAccumulator(false))
_, accs = DynamicPPL.init!!(
cond_model,
accs,
DynamicPPL.InitFromParams(new_params, nothing),
DynamicPPL.UnlinkAll(),
)
return merge(global_values, DynamicPPL.get_raw_values(accs))
end

# VarNamedTuple overload: MH's getparams returns a VarNamedTuple directly;
# merge it straight into global_values without the encode/decode roundtrip.
function AbstractMCMC._update_global_values(
::DynamicPPL.Model,
global_values::DynamicPPL.VarNamedTuple,
::AbstractVector{<:VarName},
::DynamicPPL.Model,
new_params::DynamicPPL.VarNamedTuple,
)
return merge(global_values, new_params)
end

function AbstractMCMC._build_gibbs_transition(global_values::DynamicPPL.VarNamedTuple)
return global_values
end
85 changes: 85 additions & 0 deletions src/mcmc/gibbs_setparams.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
function AbstractMCMC.setparams!!(
model::DynamicPPL.Model, state::HMCState, params::AbstractVector{<:Real}
)
new_ldf, new_params, _ = gibbs_recompute_ldf_and_params(state.ldf, model, params)
metric = gen_metric(LogDensityProblems.dimension(new_ldf), state)
lp_func = Base.Fix1(LogDensityProblems.logdensity, new_ldf)
lp_grad_func = Base.Fix1(LogDensityProblems.logdensity_and_gradient, new_ldf)
new_hamiltonian = AHMC.Hamiltonian(metric, lp_func, lp_grad_func)
new_z = deepcopy(state.z)
new_z.θ .= new_params
return HMCState(state.i, state.kernel, new_hamiltonian, new_z, state.adaptor, new_ldf)
end

function AbstractMCMC.setparams!!(
model::DynamicPPL.Model, state::TuringESSState, params::AbstractVector{<:Real}
)
new_ldf, new_params, accs = gibbs_recompute_ldf_and_params(
state.ldf, model, params, (DynamicPPL.LogLikelihoodAccumulator(),)
)
return TuringESSState(
new_ldf, new_params, DynamicPPL.getloglikelihood(accs), state.priors
)
end

function AbstractMCMC.setparams!!(
model::DynamicPPL.Model,
state::DynamicPPL.AbstractVarInfo,
params::AbstractVector{<:Real},
)
return last(
DynamicPPL.init!!(
model, state, DynamicPPL.InitFromParams(params, nothing), DynamicPPL.UnlinkAll()
),
)
end

function AbstractMCMC.setparams!!(
model::DynamicPPL.Model,
state::DynamicPPL.AbstractVarInfo,
params::DynamicPPL.VarNamedTuple,
)
return last(
DynamicPPL.init!!(
model, state, DynamicPPL.InitFromParams(params, nothing), DynamicPPL.UnlinkAll()
),
)
end

function AbstractMCMC.setparams!!(
model::DynamicPPL.Model, state::TuringState, params::AbstractVector{<:Real}
)
new_ldf, new_params, _ = gibbs_recompute_ldf_and_params(state.ldf, model, params)
new_inner_state = AbstractMCMC.setparams!!(
AbstractMCMC.LogDensityModel(new_ldf), state.state, new_params
)
return TuringState(new_inner_state, new_params, new_ldf)
end

function AbstractMCMC.setparams!!(
::DynamicPPL.Model, state::DynamicPPL.OnlyAccsVarInfo, ::AbstractVector{<:Real}
)
return state
end

function AbstractMCMC.getparams(::DynamicPPL.Model, state::HMCState)
return DynamicPPL.ParamsWithStats(
state.z.θ, state.ldf; include_log_probs=false, include_colon_eq=false
).params
end

function AbstractMCMC.getparams(::DynamicPPL.Model, state::TuringESSState)
return DynamicPPL.ParamsWithStats(
state.params, state.ldf; include_log_probs=false, include_colon_eq=false
).params
end

function AbstractMCMC.getparams(::DynamicPPL.Model, state::TuringState)
return DynamicPPL.ParamsWithStats(
state.params, state.ldf; include_log_probs=false, include_colon_eq=false
).params
end

function AbstractMCMC.getparams(::DynamicPPL.Model, state::DynamicPPL.AbstractVarInfo)
return gibbs_get_raw_values(state)
end