From e0bffc733d9bebf0820251bbc693914243480b7d Mon Sep 17 00:00:00 2001 From: Harsh Singh Date: Wed, 20 May 2026 14:31:37 +0530 Subject: [PATCH] Add setparamsgit add src/mcmc/gibbs_setparams.jl src/mcmc/gibbs_model_hooks.jl src/mcmc/Inference.jl and condition hooks for AbstractMCMC Gibbs --- src/mcmc/Inference.jl | 2 + src/mcmc/gibbs_model_hooks.jl | 104 ++++++++++++++++++++++++++++++++++ src/mcmc/gibbs_setparams.jl | 85 +++++++++++++++++++++++++++ 3 files changed, 191 insertions(+) create mode 100644 src/mcmc/gibbs_model_hooks.jl create mode 100644 src/mcmc/gibbs_setparams.jl diff --git a/src/mcmc/Inference.jl b/src/mcmc/Inference.jl index b9aa08e8f..8539e6af4 100644 --- a/src/mcmc/Inference.jl +++ b/src/mcmc/Inference.jl @@ -115,5 +115,7 @@ include("prior.jl") include("gibbs.jl") include("gibbs_conditional.jl") +include("gibbs_setparams.jl") +include("gibbs_model_hooks.jl") end # module diff --git a/src/mcmc/gibbs_model_hooks.jl b/src/mcmc/gibbs_model_hooks.jl new file mode 100644 index 000000000..ced55aa26 --- /dev/null +++ b/src/mcmc/gibbs_model_hooks.jl @@ -0,0 +1,104 @@ +# Disambiguate: (DynamicPPL.Model, AbstractMCMC.Gibbs) is more specific than +# both (AbstractModel, Gibbs) and (DynamicPPL.Model, AbstractSampler). +function AbstractMCMC.step( + rng::Random.AbstractRNG, + model::DynamicPPL.Model, + spl::AbstractMCMC.Gibbs; + initial_params=nothing, + kwargs..., +) + # Turing passes InitFromPrior/InitFromParams as initial_params. + # Treat anything that is not a VarNamedTuple as "no prior values" for condition(). + # Still forward initial_params to component samplers so they can initialise. + gv = initial_params isa DynamicPPL.VarNamedTuple ? initial_params : nothing + component_initial_params = + initial_params === nothing ? DynamicPPL.InitFromPrior() : initial_params + sub_states = AbstractMCMC._gibbs_initial_steps( + rng, + model, + spl.varnames, + spl.samplers, + gv; + initial_params=component_initial_params, + kwargs..., + ) + global_values = AbstractMCMC._collect_global_values( + model, spl.varnames, spl.samplers, sub_states + ) + return AbstractMCMC._build_gibbs_transition(global_values), + AbstractMCMC.GibbsState(global_values, sub_states) +end + +function AbstractMCMC.step( + rng::Random.AbstractRNG, + model::DynamicPPL.Model, + spl::AbstractMCMC.Gibbs, + state::AbstractMCMC.GibbsState; + kwargs..., +) + global_values, sub_states = AbstractMCMC._gibbs_sweep( + rng, + model, + spl.varnames, + spl.samplers, + state.sub_states, + state.global_values; + kwargs..., + ) + return AbstractMCMC._build_gibbs_transition(global_values), + AbstractMCMC.GibbsState(global_values, sub_states) +end + +function AbstractMCMC.condition( + model::DynamicPPL.Model, + target_varnames::AbstractVector{<:VarName}, + global_values::DynamicPPL.VarNamedTuple, +) + conditioned_model, _ctx = make_conditional(model, target_varnames, global_values) + return conditioned_model +end + +function AbstractMCMC.condition( + model::DynamicPPL.Model, target_varnames::AbstractVector{<:VarName}, ::Nothing +) + return model +end + +function AbstractMCMC._init_global_values( + ::DynamicPPL.Model, ::AbstractVector{<:VarName}, ::DynamicPPL.Model, sub_state +) + return gibbs_get_raw_values(sub_state) +end + +function AbstractMCMC._update_global_values( + ::DynamicPPL.Model, + global_values::DynamicPPL.VarNamedTuple, + ::AbstractVector{<:VarName}, + cond_model::DynamicPPL.Model, + new_params::AbstractVector{<:Real}, +) + accs = DynamicPPL.OnlyAccsVarInfo(DynamicPPL.RawValueAccumulator(false)) + _, accs = DynamicPPL.init!!( + cond_model, + accs, + DynamicPPL.InitFromParams(new_params, nothing), + DynamicPPL.UnlinkAll(), + ) + return merge(global_values, DynamicPPL.get_raw_values(accs)) +end + +# VarNamedTuple overload: MH's getparams returns a VarNamedTuple directly; +# merge it straight into global_values without the encode/decode roundtrip. +function AbstractMCMC._update_global_values( + ::DynamicPPL.Model, + global_values::DynamicPPL.VarNamedTuple, + ::AbstractVector{<:VarName}, + ::DynamicPPL.Model, + new_params::DynamicPPL.VarNamedTuple, +) + return merge(global_values, new_params) +end + +function AbstractMCMC._build_gibbs_transition(global_values::DynamicPPL.VarNamedTuple) + return global_values +end diff --git a/src/mcmc/gibbs_setparams.jl b/src/mcmc/gibbs_setparams.jl new file mode 100644 index 000000000..ae966c71a --- /dev/null +++ b/src/mcmc/gibbs_setparams.jl @@ -0,0 +1,85 @@ +function AbstractMCMC.setparams!!( + model::DynamicPPL.Model, state::HMCState, params::AbstractVector{<:Real} +) + new_ldf, new_params, _ = gibbs_recompute_ldf_and_params(state.ldf, model, params) + metric = gen_metric(LogDensityProblems.dimension(new_ldf), state) + lp_func = Base.Fix1(LogDensityProblems.logdensity, new_ldf) + lp_grad_func = Base.Fix1(LogDensityProblems.logdensity_and_gradient, new_ldf) + new_hamiltonian = AHMC.Hamiltonian(metric, lp_func, lp_grad_func) + new_z = deepcopy(state.z) + new_z.θ .= new_params + return HMCState(state.i, state.kernel, new_hamiltonian, new_z, state.adaptor, new_ldf) +end + +function AbstractMCMC.setparams!!( + model::DynamicPPL.Model, state::TuringESSState, params::AbstractVector{<:Real} +) + new_ldf, new_params, accs = gibbs_recompute_ldf_and_params( + state.ldf, model, params, (DynamicPPL.LogLikelihoodAccumulator(),) + ) + return TuringESSState( + new_ldf, new_params, DynamicPPL.getloglikelihood(accs), state.priors + ) +end + +function AbstractMCMC.setparams!!( + model::DynamicPPL.Model, + state::DynamicPPL.AbstractVarInfo, + params::AbstractVector{<:Real}, +) + return last( + DynamicPPL.init!!( + model, state, DynamicPPL.InitFromParams(params, nothing), DynamicPPL.UnlinkAll() + ), + ) +end + +function AbstractMCMC.setparams!!( + model::DynamicPPL.Model, + state::DynamicPPL.AbstractVarInfo, + params::DynamicPPL.VarNamedTuple, +) + return last( + DynamicPPL.init!!( + model, state, DynamicPPL.InitFromParams(params, nothing), DynamicPPL.UnlinkAll() + ), + ) +end + +function AbstractMCMC.setparams!!( + model::DynamicPPL.Model, state::TuringState, params::AbstractVector{<:Real} +) + new_ldf, new_params, _ = gibbs_recompute_ldf_and_params(state.ldf, model, params) + new_inner_state = AbstractMCMC.setparams!!( + AbstractMCMC.LogDensityModel(new_ldf), state.state, new_params + ) + return TuringState(new_inner_state, new_params, new_ldf) +end + +function AbstractMCMC.setparams!!( + ::DynamicPPL.Model, state::DynamicPPL.OnlyAccsVarInfo, ::AbstractVector{<:Real} +) + return state +end + +function AbstractMCMC.getparams(::DynamicPPL.Model, state::HMCState) + return DynamicPPL.ParamsWithStats( + state.z.θ, state.ldf; include_log_probs=false, include_colon_eq=false + ).params +end + +function AbstractMCMC.getparams(::DynamicPPL.Model, state::TuringESSState) + return DynamicPPL.ParamsWithStats( + state.params, state.ldf; include_log_probs=false, include_colon_eq=false + ).params +end + +function AbstractMCMC.getparams(::DynamicPPL.Model, state::TuringState) + return DynamicPPL.ParamsWithStats( + state.params, state.ldf; include_log_probs=false, include_colon_eq=false + ).params +end + +function AbstractMCMC.getparams(::DynamicPPL.Model, state::DynamicPPL.AbstractVarInfo) + return gibbs_get_raw_values(state) +end