Skip to content

Commit ee14f61

Browse files
committed
updated compat for MCMC (and indirectly ForwardDiff/GP), along with compatible interface
1 parent 5fdb164 commit ee14f61

File tree

2 files changed

+10
-6
lines changed

2 files changed

+10
-6
lines changed

Project.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,17 +31,17 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
3131

3232
[compat]
3333
AbstractGPs = "0.5.21"
34-
AbstractMCMC = "3.3, 4, 5"
35-
AdvancedMH = "0.6, 0.7, 0.8"
34+
AbstractMCMC = "5"
35+
AdvancedMH = "0.8"
3636
ChunkSplitters = "3.1.2"
3737
Conda = "1.7"
3838
Distributions = "0.24, 0.25"
3939
DocStringExtensions = "0.8, 0.9"
4040
EnsembleKalmanProcesses = "2"
41-
ForwardDiff = "0.10.38, 1"
41+
ForwardDiff = "1"
4242
GaussianProcesses = "0.12"
4343
KernelFunctions = "0.10.64"
44-
MCMCChains = "4.14, 5, 6, 7"
44+
MCMCChains = "7"
4545
Printf = "1"
4646
ProgressBars = "1"
4747
PyCall = "1.93"

src/MarkovChainMonteCarlo.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,10 @@ function AbstractMCMC.step(
379379
new_params = AdvancedMH.propose(rng, sampler, model, current_state; stepsize = stepsize)
380380
# Calculate the log acceptance probability and the log density of the candidate.
381381
new_log_density = AdvancedMH.logdensity(model, new_params)
382+
383+
# Just to initialize: if you pass state, it just reads the old log_density (initialized as false). in this case, compute it by passing the actual parameter values
384+
current_log_density = isa(AdvancedMH.logdensity(model, current_state), Bool) ? AdvancedMH.logdensity(model, current_state.params) : AdvancedMH.logdensity(model, current_state)
385+
382386
log_α =
383387
new_log_density - AdvancedMH.logdensity(model, current_state) +
384388
AdvancedMH.logratio_proposal_density(sampler, current_state, new_params)
@@ -573,7 +577,7 @@ function MCMCWrapper(
573577
end
574578

575579
sample_kwargs = (; # set defaults here
576-
:init_params => deepcopy(init_params),
580+
:initial_params => deepcopy(init_params),
577581
:param_names => param_names,
578582
:discard_initial => burnin,
579583
:chain_type => MCMCChains.Chains,
@@ -641,7 +645,7 @@ end
641645

642646
function _find_mcmc_step_log(mcmc::MCMCWrapper)
643647
str_ = @sprintf "%d starting params:" 0
644-
for p in zip(mcmc.sample_kwargs.param_names, mcmc.sample_kwargs.init_params)
648+
for p in zip(mcmc.sample_kwargs.param_names, mcmc.sample_kwargs.initial_params)
645649
str_ *= @sprintf " %s: %.3g" p[1] p[2]
646650
end
647651
println(str_)

0 commit comments

Comments
 (0)