Skip to content

Commit 924c970

Browse files
committed
add Forward diff and selection of gradient-based Samplers
1 parent d5a079b commit 924c970

File tree

2 files changed

+310
-0
lines changed

2 files changed

+310
-0
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ Conda = "8f4d0f93-b110-5947-807f-2305c1781a2d"
1010
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1111
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
1212
EnsembleKalmanProcesses = "aa8a2aa5-91d8-4396-bcef-d4f2ec43552d"
13+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1314
GaussianProcesses = "891a1506-143c-57d2-908e-e1f8e92e6de9"
1415
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1516
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
@@ -31,6 +32,7 @@ Conda = "1.7"
3132
Distributions = "0.24, 0.25"
3233
DocStringExtensions = "0.8, 0.9"
3334
EnsembleKalmanProcesses = "2"
35+
ForwardDiff = "0.10.38"
3436
GaussianProcesses = "0.12"
3537
MCMCChains = "4.14, 5, 6"
3638
Printf = "1"

src/MarkovChainMonteCarlo.jl

Lines changed: 308 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ using LinearAlgebra
1111
using Printf
1212
using Random
1313
using Statistics
14+
using ForwardDiff
1415

1516
using MCMCChains
1617
import AbstractMCMC: sample # Reexport sample()
@@ -22,6 +23,13 @@ export EmulatorPosteriorModel,
2223
MCMCProtocol,
2324
RWMHSampling,
2425
pCNMHSampling,
26+
MALASampling,
27+
BarkerSampling,
28+
HMCSampling,
29+
infMALASampling,
30+
infHMCSampling,
31+
infmMALASampling,
32+
infmHMCSampling,
2533
MCMCWrapper,
2634
accept_ratio,
2735
optimize_stepsize,
@@ -139,6 +147,140 @@ AdvancedMH.logratio_proposal_density(
139147

140148
MetropolisHastingsSampler(::pCNMHSampling, prior::ParameterDistribution) = pCNMetropolisHastings(_get_proposal(prior))
141149

"""
$(DocStringExtensions.TYPEDEF)

[`MCMCProtocol`](@ref) selecting Metropolis-Hastings sampling in which proposals for new
parameters are generated by the Metropolis-adjusted Langevin algorithm (MALA).
"""
struct MALASampling <: MCMCProtocol end

# Sampler type passed to AdvancedMH; holds the proposal distribution.
struct MetropolisAdjustedLangevin{D} <: AdvancedMH.MHSampler
    proposal::D
end

# AdvancedMH needs this method for each new sampler type. It forwards to the method
# defined for the wrapped proposal, evaluated at the previous parameters and the candidate.
# NOTE(review): the matching `propose` method uses gradient information, so the proposal
# is presumably not symmetric -- confirm this delegation yields the correct ratio.
function AdvancedMH.logratio_proposal_density(
    sampler::MetropolisAdjustedLangevin,
    transition_prev::AdvancedMH.AbstractTransition,
    candidate,
)
    previous_params = transition_prev.params
    return AdvancedMH.logratio_proposal_density(sampler.proposal, previous_params, candidate)
end

# Build the MALA sampler from the proposal derived from `prior`.
function MetropolisHastingsSampler(::MALASampling, prior::ParameterDistribution)
    return MetropolisAdjustedLangevin(_get_proposal(prior))
end
"""
$(DocStringExtensions.TYPEDEF)

[`MCMCProtocol`](@ref) selecting Metropolis-Hastings sampling in which proposals for new
parameters are generated by the Barker proposal.
"""
struct BarkerSampling <: MCMCProtocol end

# Sampler type passed to AdvancedMH; holds the proposal distribution.
struct BarkerMetropolisHastings{D} <: AdvancedMH.MHSampler
    proposal::D
end

# AdvancedMH needs this method for each new sampler type. It forwards to the method
# defined for the wrapped proposal, evaluated at the previous parameters and the candidate.
# NOTE(review): the matching `propose` method uses gradient information, so the proposal
# is presumably not symmetric -- confirm this delegation yields the correct ratio.
function AdvancedMH.logratio_proposal_density(
    sampler::BarkerMetropolisHastings,
    transition_prev::AdvancedMH.AbstractTransition,
    candidate,
)
    previous_params = transition_prev.params
    return AdvancedMH.logratio_proposal_density(sampler.proposal, previous_params, candidate)
end

# Build the Barker sampler from the proposal derived from `prior`.
function MetropolisHastingsSampler(::BarkerSampling, prior::ParameterDistribution)
    return BarkerMetropolisHastings(_get_proposal(prior))
end
"""
$(DocStringExtensions.TYPEDEF)

[`MCMCProtocol`](@ref) selecting Metropolis-Hastings sampling in which proposals for new
parameters are generated by the Hamiltonian Monte Carlo (HMC) proposal.
"""
struct HMCSampling <: MCMCProtocol end

# Sampler type passed to AdvancedMH; holds the proposal distribution.
struct HMCMetropolisHastings{D} <: AdvancedMH.MHSampler
    proposal::D
end

# AdvancedMH needs this method for each new sampler type. It forwards to the method
# defined for the wrapped proposal, evaluated at the previous parameters and the candidate.
# NOTE(review): the matching `propose` method uses gradient information, so the proposal
# is presumably not symmetric -- confirm this delegation yields the correct ratio.
function AdvancedMH.logratio_proposal_density(
    sampler::HMCMetropolisHastings,
    transition_prev::AdvancedMH.AbstractTransition,
    candidate,
)
    previous_params = transition_prev.params
    return AdvancedMH.logratio_proposal_density(sampler.proposal, previous_params, candidate)
end

# Build the HMC sampler from the proposal derived from `prior`.
function MetropolisHastingsSampler(::HMCSampling, prior::ParameterDistribution)
    return HMCMetropolisHastings(_get_proposal(prior))
end
"""
$(DocStringExtensions.TYPEDEF)

[`MCMCProtocol`](@ref) selecting Metropolis-Hastings sampling in which proposals for new
parameters are generated by the infinite-dimensional MALA proposal.
"""
struct infMALASampling <: MCMCProtocol end

# Sampler type passed to AdvancedMH; holds the proposal distribution.
struct infMALAMetropolisHastings{D} <: AdvancedMH.MHSampler
    proposal::D
end

# AdvancedMH needs this method for each new sampler type. It forwards to the method
# defined for the wrapped proposal, evaluated at the previous parameters and the candidate.
# NOTE(review): the matching `propose` method uses gradient information, so the proposal
# is presumably not symmetric -- confirm this delegation yields the correct ratio.
function AdvancedMH.logratio_proposal_density(
    sampler::infMALAMetropolisHastings,
    transition_prev::AdvancedMH.AbstractTransition,
    candidate,
)
    previous_params = transition_prev.params
    return AdvancedMH.logratio_proposal_density(sampler.proposal, previous_params, candidate)
end

# Build the infinite-dimensional MALA sampler from the proposal derived from `prior`.
function MetropolisHastingsSampler(::infMALASampling, prior::ParameterDistribution)
    return infMALAMetropolisHastings(_get_proposal(prior))
end
"""
$(DocStringExtensions.TYPEDEF)

[`MCMCProtocol`](@ref) selecting Metropolis-Hastings sampling in which proposals for new
parameters are generated by the infinite-dimensional HMC proposal.
"""
struct infHMCSampling <: MCMCProtocol end

# Sampler type passed to AdvancedMH; holds the proposal distribution.
struct infHMCMetropolisHastings{D} <: AdvancedMH.MHSampler
    proposal::D
end

# AdvancedMH needs this method for each new sampler type. It forwards to the method
# defined for the wrapped proposal, evaluated at the previous parameters and the candidate.
# NOTE(review): the matching `propose` method uses gradient information, so the proposal
# is presumably not symmetric -- confirm this delegation yields the correct ratio.
function AdvancedMH.logratio_proposal_density(
    sampler::infHMCMetropolisHastings,
    transition_prev::AdvancedMH.AbstractTransition,
    candidate,
)
    previous_params = transition_prev.params
    return AdvancedMH.logratio_proposal_density(sampler.proposal, previous_params, candidate)
end

# Build the infinite-dimensional HMC sampler from the proposal derived from `prior`.
function MetropolisHastingsSampler(::infHMCSampling, prior::ParameterDistribution)
    return infHMCMetropolisHastings(_get_proposal(prior))
end
"""
$(DocStringExtensions.TYPEDEF)

[`MCMCProtocol`](@ref) selecting Metropolis-Hastings sampling in which proposals for new
parameters are generated by the infinite-dimensional mMALA proposal.
"""
struct infmMALASampling <: MCMCProtocol end

# Sampler type passed to AdvancedMH; holds the proposal distribution.
struct infmMALAMetropolisHastings{D} <: AdvancedMH.MHSampler
    proposal::D
end

# AdvancedMH needs this method for each new sampler type. It forwards to the method
# defined for the wrapped proposal, evaluated at the previous parameters and the candidate.
# NOTE(review): the matching `propose` method uses gradient and Hessian information, so
# the proposal is presumably not symmetric -- confirm this delegation yields the correct
# ratio.
function AdvancedMH.logratio_proposal_density(
    sampler::infmMALAMetropolisHastings,
    transition_prev::AdvancedMH.AbstractTransition,
    candidate,
)
    previous_params = transition_prev.params
    return AdvancedMH.logratio_proposal_density(sampler.proposal, previous_params, candidate)
end

# Build the infinite-dimensional mMALA sampler from the proposal derived from `prior`.
function MetropolisHastingsSampler(::infmMALASampling, prior::ParameterDistribution)
    return infmMALAMetropolisHastings(_get_proposal(prior))
end
"""
$(DocStringExtensions.TYPEDEF)

[`MCMCProtocol`](@ref) selecting Metropolis-Hastings sampling in which proposals for new
parameters are generated by the infinite-dimensional mHMC proposal.
"""
struct infmHMCSampling <: MCMCProtocol end

# Sampler type passed to AdvancedMH; holds the proposal distribution.
struct infmHMCMetropolisHastings{D} <: AdvancedMH.MHSampler
    proposal::D
end

# AdvancedMH needs this method for each new sampler type. It forwards to the method
# defined for the wrapped proposal, evaluated at the previous parameters and the candidate.
# NOTE(review): the matching `propose` method uses gradient information, so the proposal
# is presumably not symmetric -- confirm this delegation yields the correct ratio.
function AdvancedMH.logratio_proposal_density(
    sampler::infmHMCMetropolisHastings,
    transition_prev::AdvancedMH.AbstractTransition,
    candidate,
)
    previous_params = transition_prev.params
    return AdvancedMH.logratio_proposal_density(sampler.proposal, previous_params, candidate)
end

# Build the infinite-dimensional mHMC sampler from the proposal derived from `prior`.
function MetropolisHastingsSampler(::infmHMCSampling, prior::ParameterDistribution)
    return infmHMCMetropolisHastings(_get_proposal(prior))
end
283+
142284
# ------------------------------------------------------------------------------------------
143285
# Use emulated model in sampler
144286

@@ -245,6 +387,158 @@ function AdvancedMH.propose(
245387
return ρ * current_state.params .+ sqrt(1 - ρ^2) * rand(rng, sampler.proposal)
246388
end
247389

# Extend AdvancedMH.propose() for the Metropolis-adjusted Langevin algorithm.
#
# The candidate is the current parameter vector, plus (stepsize^2 / 2) times the gradient
# of the model log-density at the current parameters, plus `stepsize` times a draw from
# `sampler.proposal`.
function AdvancedMH.propose(
    rng::Random.AbstractRNG,
    sampler::MetropolisAdjustedLangevin,
    model::AdvancedMH.DensityModel,
    current_state::MCMCState;
    stepsize::FT = 1.0,
) where {FT <: AbstractFloat}
    log_density(x) = AdvancedMH.logdensity(model, x)
    # Gradient-driven drift, evaluated at the current state via forward-mode AD.
    drift = (stepsize^2 / 2) .* ForwardDiff.gradient(log_density, current_state.params)
    # Random perturbation scaled by the step size.
    noise = stepsize * rand(rng, sampler.proposal)
    return current_state.params .+ drift .+ noise
end
403+
# Extend AdvancedMH.propose() for the Barker proposal (Livingstone and Zanella, 2022).
#
# An increment `xi` is drawn from `sampler.proposal` and scaled by `stepsize`. For each
# component, the increment is kept with probability 1 / (1 + exp(-gradient * xi)) and
# reflected (sign flipped) otherwise, where `gradient` is the gradient of the model
# log-density at the current parameters. Returns the current parameters plus the signed
# increment.
function AdvancedMH.propose(
    rng::Random.AbstractRNG,
    sampler::BarkerMetropolisHastings,
    model::AdvancedMH.DensityModel,
    current_state::MCMCState;
    stepsize::FT = 1.0,
) where {FT <: AbstractFloat}
    # Gradient of the log-density at the current state.
    log_gradient = ForwardDiff.gradient(x -> AdvancedMH.logdensity(model, x), current_state.params)
    n = length(current_state.params)
    u = rand(rng, n)
    # Scale the increment by the step size so that `stepsize` actually controls the
    # proposal (with the default of 1.0 the increment is the raw draw).
    xi = stepsize .* rand(rng, sampler.proposal)
    keep_probability = 1 ./ (1 .+ exp.(-log_gradient .* xi))
    # b is +1 (keep the increment) or -1 (reflect it). A 0/1 mask would leave the
    # component unchanged instead of reflecting it, which is not the Barker proposal.
    b = ifelse.(u .< keep_probability, 1, -1)
    return current_state.params .+ b .* xi
end
421+
# Extend AdvancedMH.propose() for the HMC proposal.
#
# Draws an auxiliary variable from `sampler.proposal`, then applies `L` coupled updates of
# the state and the auxiliary variable, driven by the gradient of the model log-density
# and using sqrt(stepsize) as the integration step. Returns the final state; the current
# state is left untouched.
function AdvancedMH.propose(
    rng::Random.AbstractRNG,
    sampler::HMCMetropolisHastings,
    model::AdvancedMH.DensityModel,
    current_state::MCMCState;
    stepsize::FT = 1.0,
) where {FT <: AbstractFloat}
    L = 10 # fixed number of integration steps per proposal
    log_density(x) = AdvancedMH.logdensity(model, x)
    step = sqrt(stepsize)

    proposed_aux = rand(rng, sampler.proposal)
    # Copy before updating in place: binding `current_state.params` directly would alias
    # it, and the `.+=` below would then overwrite the chain's current state.
    proposed_state = copy(current_state.params)
    log_grad_proposed_state = ForwardDiff.gradient(log_density, proposed_state)

    # NOTE(review): the log-density gradient is subtracted here, whereas the MALA proposal
    # in this file adds it -- confirm the intended sign convention.
    for _ in 1:L
        log_gradient = log_grad_proposed_state
        proposed_state .+= step .* proposed_aux .- (stepsize / 2) .* log_gradient
        log_grad_proposed_state = ForwardDiff.gradient(log_density, proposed_state)
        proposed_aux .-= (step / 2) .* log_gradient .+ (step / 2) .* log_grad_proposed_state
    end
    # Diagnostics go through the logging system instead of printing on every proposal.
    @debug "HMC proposal" L stepsize proposed_state
    return proposed_state
end
448+
# Extend AdvancedMH.propose() for the infinite-dimensional MALA proposal.
#
# With ρ = (1 - stepsize / 4) / (1 + stepsize / 4), the candidate is ρ times the current
# parameters, minus sqrt(1 - ρ^2) * sqrt(stepsize) / 2 times the gradient of the model
# log-density, plus sqrt(1 - ρ^2) times a draw from `sampler.proposal`.
function AdvancedMH.propose(
    rng::Random.AbstractRNG,
    sampler::infMALAMetropolisHastings,
    model::AdvancedMH.DensityModel,
    current_state::MCMCState;
    stepsize::FT = 1.0,
) where {FT <: AbstractFloat}
    contraction = (1 - stepsize / 4) / (1 + stepsize / 4)
    noise_weight = sqrt(1 - contraction^2)
    log_density(x) = AdvancedMH.logdensity(model, x)
    # Gradient of the log-density at the current state via forward-mode AD.
    gradient_now = ForwardDiff.gradient(log_density, current_state.params)
    drift = (noise_weight * (sqrt(stepsize) / 2)) .* gradient_now
    noise = noise_weight * rand(rng, sampler.proposal)
    return contraction * current_state.params .- drift .+ noise
end
463+
# Extend AdvancedMH.propose() for the infinite-dimensional HMC proposal.
#
# Draws an auxiliary variable from `sampler.proposal`, then applies `L` coupled updates of
# the state and the auxiliary variable, driven by the gradient of the model log-density
# and using sqrt(stepsize) as the integration step. Returns the final state; the current
# state is left untouched.
function AdvancedMH.propose(
    rng::Random.AbstractRNG,
    sampler::infHMCMetropolisHastings,
    model::AdvancedMH.DensityModel,
    current_state::MCMCState;
    stepsize::FT = 1.0,
) where {FT <: AbstractFloat}
    L = 30 # fixed number of integration steps per proposal
    log_density(x) = AdvancedMH.logdensity(model, x)
    step = sqrt(stepsize)

    proposed_aux = rand(rng, sampler.proposal)
    # Copy before updating in place: binding `current_state.params` directly would alias
    # it, and the `.+=` below would then overwrite the chain's current state.
    proposed_state = copy(current_state.params)
    log_grad_proposed_state = ForwardDiff.gradient(log_density, proposed_state)

    # NOTE(review): the log-density gradient is subtracted here, whereas the MALA proposal
    # in this file adds it -- confirm the intended sign convention.
    for _ in 1:L
        log_gradient = log_grad_proposed_state
        proposed_state .+= step .* proposed_aux .- (stepsize / 2) .* log_gradient
        log_grad_proposed_state = ForwardDiff.gradient(log_density, proposed_state)
        proposed_aux .-= (step / 2) .* log_gradient .+ (step / 2) .* log_grad_proposed_state
    end
    # Diagnostics go through the logging system instead of printing on every proposal.
    @debug "infHMC proposal" L stepsize proposed_state
    return proposed_state
end
490+
# Extend AdvancedMH.propose() for the infinite-dimensional mMALA proposal.
#
# Uses the gradient and Hessian of the model log-density at the current parameters. The
# inverse of the negative Hessian, `K`, acts as a local covariance: a Gaussian draw with
# covariance `K` is shifted by a drift term and combined with the current parameters
# through the contraction factor ρ = (1 - stepsize / 4) / (1 + stepsize / 4).
function AdvancedMH.propose(
    rng::Random.AbstractRNG,
    sampler::infmMALAMetropolisHastings,
    model::AdvancedMH.DensityModel,
    current_state::MCMCState;
    stepsize::FT = 1.0,
) where {FT <: AbstractFloat}
    ρ = (1 - stepsize / 4) / (1 + stepsize / 4)
    log_density(x) = AdvancedMH.logdensity(model, x)
    params = current_state.params

    log_gradient = ForwardDiff.gradient(log_density, params)
    hessian = Symmetric(ForwardDiff.hessian(log_density, params))
    K = Symmetric(inv(-hessian))
    n = size(K, 1)
    C_inv = I(n)

    # Draw the Gaussian increment with the supplied `rng` (not the global one) so that
    # seeded runs are reproducible.
    # `check = false` suppresses the exception when `K` is not positive definite.
    # NOTE(review): in that case the factor is not a valid Cholesky factor -- confirm
    # how a non-positive-definite `K` should be handled.
    xi = cholesky(K, check = false).L * randn(rng, n)
    nu = xi .- (stepsize / 2) .* K * ((C_inv + hessian) * params .+ log_gradient)
    return ρ * params .+ sqrt(1 - ρ^2) * nu
end
510+
# Extend AdvancedMH.propose() for the infinite-dimensional mHMC proposal.
#
# Draws an auxiliary variable from `sampler.proposal`, then applies coupled updates of the
# state and the auxiliary variable, driven by the gradient of the model log-density and
# using `stepsize` as the integration step. Returns the final state; the current state is
# left untouched.
function AdvancedMH.propose(
    rng::Random.AbstractRNG,
    sampler::infmHMCMetropolisHastings,
    model::AdvancedMH.DensityModel,
    current_state::MCMCState;
    stepsize::FT = 1.0,
) where {FT <: AbstractFloat}
    L = 4
    log_density(x) = AdvancedMH.logdensity(model, x)

    proposed_aux = rand(rng, sampler.proposal)
    # Copy before updating in place: binding `current_state.params` directly would alias
    # it, and the `.+=` below would then overwrite the chain's current state.
    proposed_state = copy(current_state.params)
    log_grad_proposed_state = ForwardDiff.gradient(log_density, proposed_state)

    # NOTE(review): this loop performs L - 1 updates, while the other HMC-type proposals
    # in this file perform L -- confirm which is intended. The log-density gradient is
    # also subtracted here, whereas the MALA proposal adds it.
    for t in 1:(L - 1)
        # Per-iteration diagnostics go through the logging system instead of stdout.
        @debug "infmHMC before update" t proposed_state
        log_gradient = log_grad_proposed_state
        proposed_state .+= stepsize .* proposed_aux .- (stepsize^2 / 2) .* log_gradient
        log_grad_proposed_state = ForwardDiff.gradient(log_density, proposed_state)
        proposed_aux .-= (stepsize / 2) .* log_gradient .+ (stepsize / 2) .* log_grad_proposed_state
        @debug "infmHMC after update" t proposed_state proposed_aux
    end
    return proposed_state
end
540+
541+
248542
# Copy a MCMCState and set accepted = false
249543
reject_transition(t::MCMCState) = MCMCState(t.params, t.log_density, false)
250544

@@ -408,6 +702,20 @@ decorrelation) that was applied in the Emulator. It creates and wraps an instanc
408702
fixed stepsize.
409703
- [`pCNMHSampling`](@ref): Metropolis-Hastings sampling using the preconditioned
410704
Crank-Nicolson algorithm, which has a well-behaved small-stepsize limit.
705+
- [`MALASampling`](@ref): Metropolis-Hastings sampling using the Metropolis
706+
-adjusted Langevin algorithm, which exploits the gradient information of the target.
707+
- [`BarkerSampling`](@ref): Metropolis-Hastings sampling using the Barker
708+
proposal, which is robust to the choice of step-size parameters.
709+
- [`HMCSampling`](@ref): Metropolis-Hastings sampling using the Hamiltonian
710+
Monte Carlo algorithm, which is a momentum-added gradient-based MCMC.
711+
- [`infMALASampling`](@ref): Metropolis-Hastings sampling using the infinite dimensional
712+
MALA, which exploits the gradient information and has a well-behaved small-stepsize limit.
713+
- [`infHMCSampling`](@ref): Metropolis-Hastings sampling using the infinite dimensional
714+
HMC, which is a momentum-augmented gradient-based method and has a well-behaved small-stepsize limit.
715+
- [`infmMALASampling`](@ref): Metropolis-Hastings sampling using the ∞-mMALA,
716+
which is geometry-informed and has a well-behaved small-stepsize limit.
717+
- [`infmHMCSampling`](@ref): Metropolis-Hastings sampling using the ∞-mHMC,
718+
which is geometry-informed and has a well-behaved small-stepsize limit.
411719
412720
- `obs_sample`: A single sample from the observations. Can, e.g., be picked from an
413721
Observation struct using `get_obs_sample`.

0 commit comments

Comments
 (0)