Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Basic Filtering Structure #56

Draft
wants to merge 23 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
440252c
added basic particle methods and filters
charlesknipp Aug 9, 2024
9fd4453
added qualifiers
charlesknipp Aug 12, 2024
3fd90c4
added parameter priors
charlesknipp Aug 12, 2024
884b9e3
Merge branch 'main' into ck/particle-methods
charlesknipp Aug 30, 2024
1def6a1
Merge branch 'main' into ck/particle-methods
charlesknipp Sep 24, 2024
a5a2e05
added adaptive resampling to bootstrap filter (WIP)
charlesknipp Sep 25, 2024
57da3ff
Julia fomatter changes
charlesknipp Sep 25, 2024
dc713b0
Merge branch 'ck/particle-methods' of https://github.com/TuringLang/S…
charlesknipp Sep 25, 2024
b846fa4
changed eltype for <: StateSpaceModel
charlesknipp Sep 26, 2024
4263ae7
updated naming conventions
charlesknipp Sep 26, 2024
5a2aeb4
formatter
charlesknipp Sep 26, 2024
8db658b
fixed adaptive resampling
charlesknipp Sep 27, 2024
15dfa9f
added particle ancestry
charlesknipp Oct 1, 2024
7e3c93d
formatter issues
charlesknipp Oct 1, 2024
f905a41
fixed metropolis and added rejection resampler
charlesknipp Oct 1, 2024
8ac1455
Keep track of free indices using stack
THargreaves Oct 2, 2024
f11a63e
updated particle types and organized directory
charlesknipp Oct 2, 2024
1fa3c93
weakened SSM type parameter assertions
charlesknipp Oct 4, 2024
8cb4338
improved particle state containment and resampling
charlesknipp Oct 4, 2024
73dd433
added hacky sparse ancestry to example
charlesknipp Oct 5, 2024
f71ab32
fixed RNG in rejection resampling
charlesknipp Oct 6, 2024
25cebf4
improved callbacks and resamplers
charlesknipp Oct 6, 2024
c729879
formatting
charlesknipp Oct 6, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions examples/particle-mcmc/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
GaussianDistributions = "43dcc890-d446-5863-8d1a-14597580bb8d"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
SSMProblems = "26aad666-b158-4e64-9d35-0e672562fa48"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
85 changes: 85 additions & 0 deletions examples/particle-mcmc/script.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
using AdvancedMH
using CairoMakie

include("simple-filters.jl")

true_params, simulation_model = let T = Float32
θ = randexp(T, 3)
dyn = LinearGaussianLatentDynamics(T[1 1;0 1], diagm(θ[1:2]))
obs = LinearGaussianObservationProcess(T[0.5 0.5], diagm(θ[3:end]))
θ, StateSpaceModel(dyn, obs)
end

# simulate data
rng = MersenneTwister(1234)
_, _, data = sample(rng, simulation_model, 150)

# consider a default Gamma prior with Float32s
prior_dist = product_distribution(Gamma(1f0), Gamma(1f0), Gamma(1f0))

# test the adaptive resampling procedure
sample(rng, simulation_model, data, BF(512, 0.1); debug=true);


#=
Not crazy about this structure, especially since the RNG is referenced on
the global scope. I think we can make a PMCMC sampler type which includes
the filter algorithm within the sampler definition.

Another issue is that we lose information on the states. Granted, this is
also by design since that would cost a considerable amount of memory, but
is useful nonetheless. This also needs to interface with bundle_samples()
different than ususal, since we have the parameter space and the filtered
states.
=#
function density(θ::Vector{T}) where {T<:Real}
if insupport(prior_dist, θ)
dyn = LinearGaussianLatentDynamics(T[1 1;0 1], diagm(θ[1:2]))
obs = LinearGaussianObservationProcess(T[0.5 0.5], diagm(θ[3:end]))

# _, ll = sample(rng, StateSpaceModel(dyn, obs), data, BF(512))
_, ll = sample(rng, StateSpaceModel(dyn, obs), data, KF())
return ll + logpdf(prior_dist, θ)
else
return -Inf
end
end

# plug it into the DensityModel interface for now
pmmh = RWMH(MvNormal(zeros(Float32, 3), (0.01f0)*I))
model = DensityModel(density)

# works with AdvancedMH out of the box
chains = sample(model, pmmh, 10_000)
burn_in = 1_000

# plot the posteriors
hist_plots = begin
param_post = hcat(getproperty.(chains[burn_in:end], :params)...)
fig = Figure(size = (1200, 400))

for i in 1:3
# plot the posteriors with burn-in
hist(
fig[1, i],
param_post[i, :],
color = :gray,
strokewidth = 1,
normalization = :pdf
)

# plot the true values
vlines!(
fig[1, i],
true_params[i],
color = :red,
linestyle = :dash,
linewidth = 2
)
end

fig
end

# this is useful for SMC algorithms like SMC² or density tempered SMC
acc_ratio = mean(getproperty.(chains, :accepted))
Loading
Loading