-
Notifications
You must be signed in to change notification settings - Fork 239
Expand file tree
/
Copy pathgibbs_setparams.jl
More file actions
85 lines (76 loc) · 2.85 KB
/
Copy pathgibbs_setparams.jl
File metadata and controls
85 lines (76 loc) · 2.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
# Return a new `HMCState` positioned at `params`, with the log-density function and the
# Hamiltonian rebuilt from `gibbs_recompute_ldf_and_params(state.ldf, model, params)`.
# The incoming `state` is not mutated: its phase point is deep-copied before the position
# is overwritten.
function AbstractMCMC.setparams!!(
    model::DynamicPPL.Model, state::HMCState, params::AbstractVector{<:Real}
)
    new_ldf, new_params, _ = gibbs_recompute_ldf_and_params(state.ldf, model, params)
    # The Hamiltonian wraps the log density (and its gradient), so it has to be rebuilt
    # around `new_ldf`; the metric is sized from the new dimension.
    metric = gen_metric(LogDensityProblems.dimension(new_ldf), state)
    lp_func = Base.Fix1(LogDensityProblems.logdensity, new_ldf)
    lp_grad_func = Base.Fix1(LogDensityProblems.logdensity_and_gradient, new_ldf)
    new_hamiltonian = AHMC.Hamiltonian(metric, lp_func, lp_grad_func)
    # Copy so that the caller's `state.z` is left untouched.
    new_z = deepcopy(state.z)
    # The parameter count may differ from the previous state (the metric above is built
    # from the *new* dimension), so grow/shrink the position vector before copying into
    # it; a plain `.=` would throw a `DimensionMismatch` in that case.
    # NOTE(review): the momentum `new_z.r` keeps its old length; presumably it is
    # refreshed by the kernel before use — confirm for partial-refreshment kernels.
    if length(new_z.θ) != length(new_params)
        resize!(new_z.θ, length(new_params))
    end
    new_z.θ .= new_params
    return HMCState(state.i, state.kernel, new_hamiltonian, new_z, state.adaptor, new_ldf)
end
# Build a fresh `TuringESSState` at `params`: the log-density function and parameter
# vector are recomputed for `model`, and the log-likelihood gathered by a
# `LogLikelihoodAccumulator` during that recomputation is stored alongside them.
# The priors are carried over from the old state unchanged.
function AbstractMCMC.setparams!!(
    model::DynamicPPL.Model, state::TuringESSState, params::AbstractVector{<:Real}
)
    loglik_accs = (DynamicPPL.LogLikelihoodAccumulator(),)
    ldf, θ, collected = gibbs_recompute_ldf_and_params(
        state.ldf, model, params, loglik_accs
    )
    loglik = DynamicPPL.getloglikelihood(collected)
    return TuringESSState(ldf, θ, loglik, state.priors)
end
# Reinitialise the varinfo `state` so that it holds the values in `params`, which may be
# either a flat real vector or a `DynamicPPL.VarNamedTuple`. `model` is evaluated through
# `DynamicPPL.init!!` with an `InitFromParams` strategy and `UnlinkAll()`, and only the
# last element of the returned tuple (the updated varinfo) is handed back.
# Both parameter representations share one method because the body is identical for
# them; more specific methods (e.g. for `OnlyAccsVarInfo`) still take precedence.
# NOTE(review): `nothing` as the second `InitFromParams` argument presumably means "no
# fallback strategy for missing values", and `UnlinkAll()` presumably stores values in
# unlinked (untransformed) space — confirm against DynamicPPL.
function AbstractMCMC.setparams!!(
    model::DynamicPPL.Model,
    state::DynamicPPL.AbstractVarInfo,
    params::Union{AbstractVector{<:Real},DynamicPPL.VarNamedTuple},
)
    return last(
        DynamicPPL.init!!(
            model, state, DynamicPPL.InitFromParams(params, nothing), DynamicPPL.UnlinkAll()
        ),
    )
end
# Update a wrapper `TuringState`: recompute the log-density function and parameter vector
# for `model`, forward the new parameters to the wrapped sampler state (viewed through an
# `AbstractMCMC.LogDensityModel` of the new log density), then rewrap everything.
function AbstractMCMC.setparams!!(
    model::DynamicPPL.Model, state::TuringState, params::AbstractVector{<:Real}
)
    ldf, θ, _ = gibbs_recompute_ldf_and_params(state.ldf, model, params)
    density_model = AbstractMCMC.LogDensityModel(ldf)
    inner = AbstractMCMC.setparams!!(density_model, state.state, θ)
    return TuringState(inner, θ, ldf)
end
# An `OnlyAccsVarInfo` is returned as-is: the supplied parameter values are ignored.
# (Presumably this varinfo type stores no parameter values to overwrite — confirm.)
function AbstractMCMC.setparams!!(
    ::DynamicPPL.Model, vi::DynamicPPL.OnlyAccsVarInfo, ::AbstractVector{<:Real}
)
    # Nothing to update: hand back the very same object.
    return vi
end
# Parameters of an HMC state: the phase-point position `state.z.θ` is interpreted through
# `state.ldf` via `DynamicPPL.ParamsWithStats` (without log-probs or `:=` values), and
# only its `params` field is returned.
function AbstractMCMC.getparams(::DynamicPPL.Model, state::HMCState)
    position = state.z.θ
    pws = DynamicPPL.ParamsWithStats(
        position, state.ldf; include_log_probs=false, include_colon_eq=false
    )
    return pws.params
end
# Parameters of an ESS state: the stored vector `state.params` is interpreted through
# `state.ldf` via `DynamicPPL.ParamsWithStats` (without log-probs or `:=` values), and
# only its `params` field is returned.
function AbstractMCMC.getparams(::DynamicPPL.Model, state::TuringESSState)
    θ = state.params
    pws = DynamicPPL.ParamsWithStats(
        θ, state.ldf; include_log_probs=false, include_colon_eq=false
    )
    return pws.params
end
# Parameters of a wrapper `TuringState`: the stored vector `state.params` is interpreted
# through `state.ldf` via `DynamicPPL.ParamsWithStats` (without log-probs or `:=` values),
# and only its `params` field is returned.
function AbstractMCMC.getparams(::DynamicPPL.Model, state::TuringState)
    θ = state.params
    pws = DynamicPPL.ParamsWithStats(
        θ, state.ldf; include_log_probs=false, include_colon_eq=false
    )
    return pws.params
end
# Parameters of a plain varinfo state are whatever `gibbs_get_raw_values` extracts from it.
AbstractMCMC.getparams(::DynamicPPL.Model, vi::DynamicPPL.AbstractVarInfo) =
    gibbs_get_raw_values(vi)