Skip to content

Commit fa01a40

Browse files
committed
format
1 parent 06e3675 commit fa01a40

File tree

2 files changed

+73
-51
lines changed

2 files changed

+73
-51
lines changed

src/MarkovChainMonteCarlo.jl

Lines changed: 60 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ export EmulatorPosteriorModel,
2929
infMALASampling,
3030
infHMCSampling,
3131
infmMALASampling,
32-
infmHMCSampling,
32+
infmHMCSampling,
3333
MCMCWrapper,
3434
accept_ratio,
3535
optimize_stepsize,
@@ -166,7 +166,8 @@ AdvancedMH.logratio_proposal_density(
166166
candidate,
167167
) = AdvancedMH.logratio_proposal_density(sampler.proposal, transition_prev.params, candidate)
168168

169-
MetropolisHastingsSampler(::MALASampling, prior::ParameterDistribution) = MetropolisAdjustedLangevin(_get_proposal(prior))
169+
MetropolisHastingsSampler(::MALASampling, prior::ParameterDistribution) =
170+
MetropolisAdjustedLangevin(_get_proposal(prior))
170171
"""
171172
$(DocStringExtensions.TYPEDEF)
172173
@@ -185,7 +186,8 @@ AdvancedMH.logratio_proposal_density(
185186
candidate,
186187
) = AdvancedMH.logratio_proposal_density(sampler.proposal, transition_prev.params, candidate)
187188

188-
MetropolisHastingsSampler(::BarkerSampling, prior::ParameterDistribution) = BarkerMetropolisHastings(_get_proposal(prior))
189+
MetropolisHastingsSampler(::BarkerSampling, prior::ParameterDistribution) =
190+
BarkerMetropolisHastings(_get_proposal(prior))
189191
"""
190192
$(DocStringExtensions.TYPEDEF)
191193
@@ -223,7 +225,8 @@ AdvancedMH.logratio_proposal_density(
223225
candidate,
224226
) = AdvancedMH.logratio_proposal_density(sampler.proposal, transition_prev.params, candidate)
225227

226-
MetropolisHastingsSampler(::infMALASampling, prior::ParameterDistribution) = infMALAMetropolisHastings(_get_proposal(prior))
228+
MetropolisHastingsSampler(::infMALASampling, prior::ParameterDistribution) =
229+
infMALAMetropolisHastings(_get_proposal(prior))
227230
"""
228231
$(DocStringExtensions.TYPEDEF)
229232
@@ -242,7 +245,8 @@ AdvancedMH.logratio_proposal_density(
242245
candidate,
243246
) = AdvancedMH.logratio_proposal_density(sampler.proposal, transition_prev.params, candidate)
244247

245-
MetropolisHastingsSampler(::infHMCSampling, prior::ParameterDistribution) = infHMCMetropolisHastings(_get_proposal(prior))
248+
MetropolisHastingsSampler(::infHMCSampling, prior::ParameterDistribution) =
249+
infHMCMetropolisHastings(_get_proposal(prior))
246250
"""
247251
$(DocStringExtensions.TYPEDEF)
248252
@@ -261,7 +265,8 @@ AdvancedMH.logratio_proposal_density(
261265
candidate,
262266
) = AdvancedMH.logratio_proposal_density(sampler.proposal, transition_prev.params, candidate)
263267

264-
MetropolisHastingsSampler(::infmMALASampling, prior::ParameterDistribution) = infmMALAMetropolisHastings(_get_proposal(prior))
268+
MetropolisHastingsSampler(::infmMALASampling, prior::ParameterDistribution) =
269+
infmMALAMetropolisHastings(_get_proposal(prior))
265270
"""
266271
$(DocStringExtensions.TYPEDEF)
267272
@@ -280,7 +285,8 @@ AdvancedMH.logratio_proposal_density(
280285
candidate,
281286
) = AdvancedMH.logratio_proposal_density(sampler.proposal, transition_prev.params, candidate)
282287

283-
MetropolisHastingsSampler(::infmHMCSampling, prior::ParameterDistribution) = infmHMCMetropolisHastings(_get_proposal(prior))
288+
MetropolisHastingsSampler(::infmHMCSampling, prior::ParameterDistribution) =
289+
infmHMCMetropolisHastings(_get_proposal(prior))
284290

285291
# ------------------------------------------------------------------------------------------
286292
# Use emulated model in sampler
@@ -396,10 +402,10 @@ function AdvancedMH.propose(
396402
current_state::MCMCState;
397403
stepsize::FT = 1.0,
398404
) where {FT <: AbstractFloat}
399-
# Compute the gradient of the log-density at the current state
400-
log_gradient = ForwardDiff.gradient(x -> AdvancedMH.logdensity(model, x), current_state.params)
401-
proposed_state = current_state.params .+ (stepsize^2 / 2) .* log_gradient .+ stepsize * rand(rng, sampler.proposal)
402-
return proposed_state
405+
# Compute the gradient of the log-density at the current state
406+
log_gradient = ForwardDiff.gradient(x -> AdvancedMH.logdensity(model, x), current_state.params)
407+
proposed_state = current_state.params .+ (stepsize^2 / 2) .* log_gradient .+ stepsize * rand(rng, sampler.proposal)
408+
return proposed_state
403409
end
404410

405411
# method extending AdvancedMH.propose() for the Barker proposal
@@ -410,13 +416,13 @@ function AdvancedMH.propose(
410416
current_state::MCMCState;
411417
stepsize::FT = 1.0,
412418
) where {FT <: AbstractFloat}
413-
# Livingstone and Zanella (2022)
419+
# Livingstone and Zanella (2022)
414420
# Compute the gradient of the log-density at the current state
415421
log_gradient = ForwardDiff.gradient(x -> AdvancedMH.logdensity(model, x), current_state.params)
416422
n = length(current_state.params)
417423
u = rand(rng, n)
418424
xi = rand(rng, sampler.proposal)
419-
b = u .< 1 ./ (1 .+ exp.(- log_gradient .* xi))
425+
b = u .< 1 ./ (1 .+ exp.(-log_gradient .* xi))
420426
return current_state.params .+ b .* xi
421427
end
422428

@@ -441,9 +447,16 @@ function AdvancedMH.propose(
441447
log_gradient = log_grad_proposed_state
442448
proposed_state .+= sqrt(stepsize) .* proposed_aux - (stepsize / 2) .* log_grad_proposed_state
443449
log_grad_proposed_state = ForwardDiff.gradient(x -> AdvancedMH.logdensity(model, x), proposed_state)
444-
proposed_aux .+= - (sqrt(stepsize) / 2) .* log_gradient .- (sqrt(stepsize) / 2) .* log_grad_proposed_state
450+
proposed_aux .+= -(sqrt(stepsize) / 2) .* log_gradient .- (sqrt(stepsize) / 2) .* log_grad_proposed_state
445451
end
446-
println("L: ", L, " stepsize: ", round(stepsize, digits = 6), " proposed_state: ", round.(proposed_state, digits = 5))
452+
println(
453+
"L: ",
454+
L,
455+
" stepsize: ",
456+
round(stepsize, digits = 6),
457+
" proposed_state: ",
458+
round.(proposed_state, digits = 5),
459+
)
447460
return proposed_state
448461
end
449462

@@ -455,11 +468,13 @@ function AdvancedMH.propose(
455468
current_state::MCMCState;
456469
stepsize::FT = 1.0,
457470
) where {FT <: AbstractFloat}
458-
# Compute the gradient of the log-density at the current state
459-
ρ = (1 - stepsize / 4) / (1 + stepsize / 4)
460-
log_gradient = ForwardDiff.gradient(x -> AdvancedMH.logdensity(model, x), current_state.params)
461-
proposed_state = ρ * current_state.params .- sqrt(1 - ρ^2) * (sqrt(stepsize) / 2) .* log_gradient .+ sqrt(1 - ρ^2) * rand(rng, sampler.proposal)
462-
return proposed_state
471+
# Compute the gradient of the log-density at the current state
472+
ρ = (1 - stepsize / 4) / (1 + stepsize / 4)
473+
log_gradient = ForwardDiff.gradient(x -> AdvancedMH.logdensity(model, x), current_state.params)
474+
proposed_state =
475+
ρ * current_state.params .- sqrt(1 - ρ^2) * (sqrt(stepsize) / 2) .* log_gradient .+
476+
sqrt(1 - ρ^2) * rand(rng, sampler.proposal)
477+
return proposed_state
463478
end
464479

465480
# method extending AdvancedMH.propose() for the ∞-HMC proposal
@@ -483,9 +498,16 @@ function AdvancedMH.propose(
483498
log_gradient = log_grad_proposed_state
484499
proposed_state .+= sqrt(stepsize) .* proposed_aux - (stepsize / 2) .* log_grad_proposed_state
485500
log_grad_proposed_state = ForwardDiff.gradient(x -> AdvancedMH.logdensity(model, x), proposed_state)
486-
proposed_aux .+= - (sqrt(stepsize) / 2) .* log_gradient .- (sqrt(stepsize) / 2) .* log_grad_proposed_state
501+
proposed_aux .+= -(sqrt(stepsize) / 2) .* log_gradient .- (sqrt(stepsize) / 2) .* log_grad_proposed_state
487502
end
488-
println("L: ", L, " stepsize: ", round(stepsize, digits = 8), " proposed_state: ", round.(proposed_state, digits = 5))
503+
println(
504+
"L: ",
505+
L,
506+
" stepsize: ",
507+
round(stepsize, digits = 8),
508+
" proposed_state: ",
509+
round.(proposed_state, digits = 5),
510+
)
489511
return proposed_state
490512
end
491513

@@ -497,16 +519,16 @@ function AdvancedMH.propose(
497519
current_state::MCMCState;
498520
stepsize::FT = 1.0,
499521
) where {FT <: AbstractFloat}
500-
# Compute the gradient of the log-density at the current state
501-
ρ = (1 - stepsize / 4) / (1 + stepsize / 4)
502-
log_gradient = ForwardDiff.gradient(x -> AdvancedMH.logdensity(model, x), current_state.params)
503-
hessian = Symmetric(ForwardDiff.hessian(x -> AdvancedMH.logdensity(model, x), current_state.params))
504-
K = Symmetric(inv(- hessian))
505-
C_inv = I(size(K, 1))
506-
xi = cholesky(K, check=false).L * randn(size(K, 1))
507-
# xi = rand(rng, MvNormal(zeros(size(K, 1)), K))# or cholesky(K_u).L * randn(size(K_u, 1))
508-
nu = xi .- (stepsize / 2) .* K * ((C_inv + hessian) * current_state.params .+ log_gradient)
509-
return ρ * current_state.params .+ sqrt(1 - ρ^2) * nu
522+
# Compute the gradient of the log-density at the current state
523+
ρ = (1 - stepsize / 4) / (1 + stepsize / 4)
524+
log_gradient = ForwardDiff.gradient(x -> AdvancedMH.logdensity(model, x), current_state.params)
525+
hessian = Symmetric(ForwardDiff.hessian(x -> AdvancedMH.logdensity(model, x), current_state.params))
526+
K = Symmetric(inv(-hessian))
527+
C_inv = I(size(K, 1))
528+
xi = cholesky(K, check = false).L * randn(size(K, 1))
529+
# xi = rand(rng, MvNormal(zeros(size(K, 1)), K))# or cholesky(K_u).L * randn(size(K_u, 1))
530+
nu = xi .- (stepsize / 2) .* K * ((C_inv + hessian) * current_state.params .+ log_gradient)
531+
return ρ * current_state.params .+ sqrt(1 - ρ^2) * nu
510532
end
511533

512534
# method extending AdvancedMH.propose() for the ∞-mHMC proposal
@@ -526,13 +548,13 @@ function AdvancedMH.propose(
526548
proposed_state = proposed_state_init
527549
log_grad_proposed_state = ForwardDiff.gradient(x -> AdvancedMH.logdensity(model, x), current_state.params)
528550

529-
for t in 1:L-1
551+
for t in 1:(L - 1)
530552
println("Iteration t = ", t)
531553
println("Before update, proposed_state: ", proposed_state)
532554
log_gradient = log_grad_proposed_state
533555
proposed_state .+= stepsize .* proposed_aux - (stepsize^2 / 2) .* log_grad_proposed_state
534556
log_grad_proposed_state = ForwardDiff.gradient(x -> AdvancedMH.logdensity(model, x), proposed_state)
535-
proposed_aux .+= - (stepsize / 2) .* log_gradient .- (stepsize / 2) .* log_grad_proposed_state
557+
proposed_aux .+= -(stepsize / 2) .* log_gradient .- (stepsize / 2) .* log_grad_proposed_state
536558
println("After update, proposed_state: ", proposed_state)
537559
println("proposed_aux: ", proposed_aux)
538560
end
@@ -939,18 +961,18 @@ $(DocStringExtensions.TYPEDSIGNATURES)
939961
Computes the expected squared jump distance of the chain.
940962
"""
941963
function esjd(chain::MCMCChains.Chains)
942-
samples = chain.value[:,:,1] # N_samples x N_params x n_chains
964+
samples = chain.value[:, :, 1] # N_samples x N_params x n_chains
943965
n_samples, n_params = size(samples)
944966
esjd = zeros(Float64, n_params)
945967
for i in 2:n_samples
946-
esjd .+= (samples[i, :] .- samples[i - 1, :]).^ 2 ./ n_samples
968+
esjd .+= (samples[i, :] .- samples[i - 1, :]) .^ 2 ./ n_samples
947969
end
948970
return esjd
949-
971+
950972
end
951973

952974

953975

954976

955-
977+
956978
end # module MarkovChainMonteCarlo

test/MarkovChainMonteCarlo/runtests.jl

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ function mcmc_test_template(
122122

123123
# Now begin the actual MCMC, sample is multiply exported so we qualify
124124
chain = MCMC.sample(rng, mcmc, 100_000; stepsize = new_step, discard_initial = 1000)
125-
125+
126126
posterior_distribution = get_posterior(mcmc, chain)
127127
#post_mean = mean(posterior, dims=1)[1]
128128
posterior_mean = mean(posterior_distribution)
@@ -158,8 +158,8 @@ end
158158

159159
@testset "Sine GP & RW Metropolis" begin
160160
em_1 = test_gp_1(y, σ2_y, iopairs)
161-
new_step, posterior_mean_1,chain1 = mcmc_test_template(prior, σ2_y, em_1; mcmc_params...)
162-
esjd1=esjd(chain1)
161+
new_step, posterior_mean_1, chain1 = mcmc_test_template(prior, σ2_y, em_1; mcmc_params...)
162+
esjd1 = esjd(chain1)
163163
@info "ESJD = $esjd1"
164164
@test isapprox(new_step, 0.5; atol = 0.5)
165165
# difference between mean_1 and ground truth comes from MCMC convergence and GP sampling
@@ -172,12 +172,12 @@ end
172172
_, posterior_mean_2, chain2 = mcmc_test_template(prior, σ2_y, em_2; mcmc_params...)
173173
# difference between mean_1 and mean_2 only from MCMC convergence
174174
@test isapprox(posterior_mean_2, posterior_mean_1; atol = 0.1)
175-
# test diagnostic functions on the chain
176-
esjd2=esjd(chain2)
175+
# test diagnostic functions on the chain
176+
esjd2 = esjd(chain2)
177177
@info "ESJD = $esjd2"
178178
# approx [0.04190683285347798, 0.1685296224916364, 0.4129400000002722]
179-
@test all(isapprox.(esjd1,esjd2, rtol=0.1))
180-
179+
@test all(isapprox.(esjd1, esjd2, rtol = 0.1))
180+
181181
end
182182

183183
@testset "Sine GP & pCN" begin
@@ -190,8 +190,8 @@ end
190190
)
191191

192192
em_1 = test_gp_1(y, σ2_y, iopairs)
193-
new_step, posterior_mean_1,chain1 = mcmc_test_template(prior, σ2_y, em_1; mcmc_params...)
194-
esjd1=esjd(chain1)
193+
new_step, posterior_mean_1, chain1 = mcmc_test_template(prior, σ2_y, em_1; mcmc_params...)
194+
esjd1 = esjd(chain1)
195195
@info "ESJD = $esjd1"
196196
@test isapprox(new_step, 0.75; atol = 0.6)
197197
# difference between mean_1 and ground truth comes from MCMC convergence and GP sampling
@@ -201,15 +201,15 @@ end
201201
norm_factor = 10.0
202202
norm_factor = fill(norm_factor, size(y[:, 1])) # must be size of output dim
203203
em_2 = test_gp_2(y, σ2_y, iopairs; norm_factor = norm_factor)
204-
_, posterior_mean_2,chain2 = mcmc_test_template(prior, σ2_y, em_2; mcmc_params...)
204+
_, posterior_mean_2, chain2 = mcmc_test_template(prior, σ2_y, em_2; mcmc_params...)
205205
# difference between mean_1 and mean_2 only from MCMC convergence
206206
@test isapprox(posterior_mean_2, posterior_mean_1; atol = 0.1)
207207

208-
esjd2=esjd(chain2)
208+
esjd2 = esjd(chain2)
209209
@info "ESJD = $esjd2"
210210
# approx [0.03470825350663073, 0.161606734823579, 0.38970000000024896]
211211

212-
@test all(isapprox.(esjd1,esjd2, rtol=0.1))
213-
212+
@test all(isapprox.(esjd1, esjd2, rtol = 0.1))
213+
214214
end
215215
end

0 commit comments

Comments
 (0)