From 440cdc8279790639b084ced2488dd99b172c15b9 Mon Sep 17 00:00:00 2001 From: FredericWantiez Date: Sun, 20 Oct 2024 11:47:32 +0100 Subject: [PATCH 01/13] Add auxiliary particle filter --- src/algorithms/bootstrap.jl | 11 +++++++++++ src/resamplers.jl | 5 ++--- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/src/algorithms/bootstrap.jl b/src/algorithms/bootstrap.jl index a508761..da67da5 100644 --- a/src/algorithms/bootstrap.jl +++ b/src/algorithms/bootstrap.jl @@ -76,3 +76,14 @@ end function logmarginal(states::ParticleContainer, ::BootstrapFilter) return logsumexp(states.filtered.log_weights) - logsumexp(states.proposed.log_weights) end + +function reset_weights!( + state::ParticleState{T,WT}, idxs, filter::BootstrapFilter +) where {T,WT<:Real} + fill!(state.log_weights, -log(WT(length(state.particles)))) + return state +end + +function logmarginal(states::ParticleContainer, ::BootstrapFilter) + return logsumexp(states.filtered.log_weights) - logsumexp(states.proposed.log_weights) +end diff --git a/src/resamplers.jl b/src/resamplers.jl index 60a7b36..85e40fc 100644 --- a/src/resamplers.jl +++ b/src/resamplers.jl @@ -35,8 +35,7 @@ function resample( deepcopy(states.z_particles[idxs]), CUDA.zeros(T, length(states)), ) - - return new_state, idxs + return reset_weights!(state, idxs, filter) end ## CONDITIONAL RESAMPLING ################################################################## @@ -81,7 +80,7 @@ function resample( @debug "ESS: $ess" if cond_resampler.threshold * n ≥ ess - return resample(rng, cond_resampler.resampler, state) + return resample(rng, cond_resampler.resampler, state, filter) else return deepcopy(state), collect(1:n) end From 2c33c80230fd08f79fda6fcd4e0c53eaa3e9a045 Mon Sep 17 00:00:00 2001 From: FredericWantiez Date: Wed, 23 Oct 2024 20:45:31 +0100 Subject: [PATCH 02/13] Mean transition --- src/GeneralisedFilters.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/GeneralisedFilters.jl b/src/GeneralisedFilters.jl index d53589f..0437ce3 100644 --- a/src/GeneralisedFilters.jl +++ b/src/GeneralisedFilters.jl @@ -75,6 +75,7 @@ function filter( callback=nothing, kwargs..., ) + println("1") states = initialise(rng, model, alg; kwargs...) log_evidence = zero(eltype(model)) @@ -95,6 +96,7 @@ function filter( observations::AbstractVector; kwargs..., ) + println("2") return filter(default_rng(), model, alg, observations; kwargs...) end From 5198e3cd6fa5edc32323f4387761b4c55080325e Mon Sep 17 00:00:00 2001 From: FredericWantiez Date: Wed, 23 Oct 2024 21:03:18 +0100 Subject: [PATCH 03/13] Merge conflict --- src/GeneralisedFilters.jl | 2 -- src/resamplers.jl | 3 ++- test/runtests.jl | 42 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 44 insertions(+), 3 deletions(-) diff --git a/src/GeneralisedFilters.jl b/src/GeneralisedFilters.jl index 0437ce3..d53589f 100644 --- a/src/GeneralisedFilters.jl +++ b/src/GeneralisedFilters.jl @@ -75,7 +75,6 @@ function filter( callback=nothing, kwargs..., ) - println("1") states = initialise(rng, model, alg; kwargs...) log_evidence = zero(eltype(model)) @@ -96,7 +95,6 @@ function filter( observations::AbstractVector; kwargs..., ) - println("2") return filter(default_rng(), model, alg, observations; kwargs...) 
end diff --git a/src/resamplers.jl b/src/resamplers.jl index 85e40fc..7a281dd 100644 --- a/src/resamplers.jl +++ b/src/resamplers.jl @@ -35,7 +35,8 @@ function resample( deepcopy(states.z_particles[idxs]), CUDA.zeros(T, length(states)), ) - return reset_weights!(state, idxs, filter) + reset_weights!(new_state, idxs, filter) + return new_state, idxs end ## CONDITIONAL RESAMPLING ################################################################## diff --git a/test/runtests.jl b/test/runtests.jl index 4d85a0c..882c1e8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -125,6 +125,48 @@ end @test llkf ≈ llapf atol = 2 end +@testitem "APF filter test" begin + using GeneralisedFilters + using SSMProblems + using StableRNGs + using PDMats + using LinearAlgebra + using Random: randexp + + T = Float32 + rng = StableRNG(1234) + σx², σy² = randexp(rng, T, 2) + + # initial state distribution + μ0 = zeros(T, 2) + Σ0 = PDMat(T[1 0; 0 1]) + + # state transition equation + A = T[1 1; 0 1] + b = T[0; 0] + Q = PDiagMat([σx²; 0]) + + # observation equation + H = T[1 0] + c = T[0;] + R = [σy²;;] + + # when working with PDMats, the Kalman filter doesn't play nicely without this + function Base.convert(::Type{PDMat{T,MT}}, mat::MT) where {MT<:AbstractMatrix,T<:Real} + return PDMat(Symmetric(mat)) + end + + model = create_homogeneous_linear_gaussian_model(μ0, Σ0, A, b, Q, H, c, R) + _, _, data = sample(rng, model, 20) + + bf = APF(2^10; threshold=0.8) + _, llbf = GeneralisedFilters.filter(rng, model, bf, data) + _, llkf = GeneralisedFilters.filter(rng, model, KF(), data) + + # since this is log valued, we can up the tolerance + @test llkf ≈ llbf atol = 2 +end + @testitem "Forward algorithm test" begin using GeneralisedFilters using Distributions From 92548985ed8d122d1214cb6ccbcdf8cf7d380ee1 Mon Sep 17 00:00:00 2001 From: FredericWantiez Date: Wed, 23 Oct 2024 21:51:18 +0100 Subject: [PATCH 04/13] test --- .envrc | 2 ++ scratch.jl | 54 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 56 insertions(+) create mode 100644 .envrc create mode 100644 scratch.jl diff --git a/.envrc b/.envrc new file mode 100644 index 0000000..b945361 --- /dev/null +++ b/.envrc @@ -0,0 +1,2 @@ +layout julia +use julia 1.10 \ No newline at end of file diff --git a/scratch.jl b/scratch.jl new file mode 100644 index 0000000..0728d4f --- /dev/null +++ b/scratch.jl @@ -0,0 +1,54 @@ +using GeneralisedFilters +using SSMProblems +using PDMats +using LinearAlgebra +using Random: randexp +using Random +using CairoMakie +using Statistics + +T = Float32 +rng = MersenneTwister(1) +σx², σy² = randexp(rng, T, 2) + +# initial state distribution +μ0 = zeros(T, 2) +Σ0 = PDMat(T[1 0; 0 1]) + +# state transition equation +A = T[1 1; 0 1] +b = T[0; 0] +Q = PDiagMat([σx²; T(1e-6)]) + +# observation equation +H = T[1 0] +c = T[0;] +R = [σy²;;] + +# when working with PDMats, the Kalman filter doesn't play nicely without this +function Base.convert(::Type{PDMat{T,MT}}, mat::MT) where {MT<:AbstractMatrix,T<:Real} + return PDMat(Symmetric(mat)) +end + +model = create_homogeneous_linear_gaussian_model(μ0, Σ0, A, b, Q, H, c, R) +_, latent, data = sample(rng, model, 100) + +bf = BF(2^10; threshold=0.8) +smoother = FFBS(bf) + +M = 1_000 +trajectories = GeneralisedFilters.sample(rng, model, smoother, data, M) +x1 = first.(trajectories) +m = vec(mean(x1, dims=2)) +stdev = vec(std(x1, dims=2)) + +figure = Figure() +pos = figure[1, 1] +lines(pos, m - 2 * stdev, color="black") +lines!(pos, m + 2 * stdev, color="black") +lines!(m, color="red") 
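+# overlay the true simulated latent path on top of the smoothed mean and ±2σ band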
+lines!(pos, first.(latent), label="True latent trajectory", color="blue") +# for i in 1:M +# lines!(pos, x1[:, i], color=:black, alpha=.01) +# end +figure \ No newline at end of file From fa84e14a9f9a1e62d78ffd3cf1af022ae1cafdbb Mon Sep 17 00:00:00 2001 From: FredericWantiez Date: Sat, 26 Oct 2024 17:13:54 +0100 Subject: [PATCH 05/13] Add FFBS draft, change filter interface --- src/GeneralisedFilters.jl | 1 + src/algorithms/apf.jl | 10 ++++-- src/algorithms/bootstrap.jl | 18 ++++++----- src/algorithms/ffbs.jl | 61 +++++++++++++++++++++++++++++++++++++ 4 files changed, 80 insertions(+), 10 deletions(-) create mode 100644 src/algorithms/ffbs.jl diff --git a/src/GeneralisedFilters.jl b/src/GeneralisedFilters.jl index d53589f..fad4f92 100644 --- a/src/GeneralisedFilters.jl +++ b/src/GeneralisedFilters.jl @@ -128,5 +128,6 @@ include("algorithms/apf.jl") include("algorithms/kalman.jl") include("algorithms/forward.jl") include("algorithms/rbpf.jl") +include("algorithms/ffbs.jl") end diff --git a/src/algorithms/apf.jl b/src/algorithms/apf.jl index f4a9df6..17489da 100644 --- a/src/algorithms/apf.jl +++ b/src/algorithms/apf.jl @@ -6,7 +6,7 @@ mutable struct AuxiliaryParticleFilter{N,RS<:AbstractConditionalResampler} <: Ab end function AuxiliaryParticleFilter( - N::Integer; threshold::Real=0., resampler::AbstractResampler=Systematic() + N::Integer; threshold::Real=0.0, resampler::AbstractResampler=Systematic() ) conditional_resampler = ESSResampler(threshold, resampler) return AuxiliaryParticleFilter{N,typeof(conditional_resampler)}(conditional_resampler, zeros(N)) @@ -24,7 +24,9 @@ function initialise( initial_states = map(x -> SSMProblems.simulate(rng, model.dyn; kwargs...), 1:N) initial_weights = zeros(T, N) - return update_ref!(ParticleContainer(initial_states, initial_weights), ref_state, filter) + return update_ref!( + ParticleContainer(initial_states, initial_weights), ref_state, filter + ) end function update_weights!( @@ -59,7 +61,9 @@ function predict( states.filtered.log_weights .+= auxiliary_weights filter.aux = auxiliary_weights - states.proposed, states.ancestors = resample(rng, filter.resampler, states.filtered, filter) + states.proposed, states.ancestors = resample( + rng, filter.resampler, states.filtered, filter + ) states.proposed.particles = map( x -> SSMProblems.simulate(rng, model.dyn, step, x; kwargs...), states.proposed.particles, diff --git a/src/algorithms/bootstrap.jl b/src/algorithms/bootstrap.jl index da67da5..08ba5ba 100644 --- a/src/algorithms/bootstrap.jl +++ b/src/algorithms/bootstrap.jl @@ -1,5 +1,7 @@ export BootstrapFilter, BF +abstract type AbstractParticleFilter{N} <: AbstractFilter end + struct BootstrapFilter{N,RS<:AbstractResampler} <: AbstractParticleFilter{N} resampler::RS end @@ -8,7 +10,7 @@ function BootstrapFilter( N::Integer; threshold::Real=1.0, resampler::AbstractResampler=Systematic() ) conditional_resampler = ESSResampler(threshold, resampler) - return BootstrapFilter{N, typeof(conditional_resampler)}(conditional_resampler) + return BootstrapFilter{N,typeof(conditional_resampler)}(conditional_resampler) end """Shorthand for `BootstrapFilter`""" @@ -38,7 +40,9 @@ function predict( ref_state::Union{Nothing,AbstractVector{T}}=nothing, kwargs..., ) where {T} - states.proposed, states.ancestors = resample(rng, filter.resampler, states.filtered, filter) + states.proposed, states.ancestors = resample( + rng, filter.resampler, states.filtered, filter + ) states.proposed.particles = map( x -> SSMProblems.simulate(rng, model.dyn, step, x; kwargs...), 
collect(states.proposed), @@ -49,12 +53,12 @@ end function update( model::StateSpaceModel{T}, - filter::BootstrapFilter, + filter::BootstrapFilter{N}, step::Integer, states::ParticleContainer, observation; kwargs..., -) where {T} +) where {T,N} log_increments = map( x -> SSMProblems.logdensity(model.obs, step, x, observation; kwargs...), collect(states.proposed), @@ -78,9 +82,9 @@ function logmarginal(states::ParticleContainer, ::BootstrapFilter) end function reset_weights!( - state::ParticleState{T,WT}, idxs, filter::BootstrapFilter -) where {T,WT<:Real} - fill!(state.log_weights, -log(WT(length(state.particles)))) + state::ParticleState{T,WT}, idxs, filter::BootstrapFilter{N} +) where {T,WT<:Real,N} + fill!(state.log_weights, -log(WT(N))) return state end diff --git a/src/algorithms/ffbs.jl b/src/algorithms/ffbs.jl new file mode 100644 index 0000000..6007f2a --- /dev/null +++ b/src/algorithms/ffbs.jl @@ -0,0 +1,61 @@ +export FFBS + +abstract type AbstractSmoother <: AbstractSampler end + +struct FFBS{T<:AbstractParticleFilter} + filter::T +end + +""" + smooth(rng::AbstractRNG, alg::AbstractSmooterh, model::AbstractStateSpaceModel, obs::AbstractVector, M::Integer; callback, kwargs...) +""" +function smooth end + +struct WeightedParticleRecorderCallback{T,WT} + particles::Array{T} + log_weights::Array{WT} +end + +function (callback::WeightedParticleRecorderCallback)( + model, filter, step, states, data; kwargs... +) + filtered_states = states.filtered + callback.particles[step, :] = filtered_states.particles + callback.log_weights[step, :] = filtered_states.log_weights + return nothing +end + +function smooth( + rng::Random.AbstractRNG, + model::StateSpaceModel{T,LDT}, + alg::FFBS{<:BootstrapFilter{N}}, + obs::AbstractVector, + M::Integer; + callback=nothing, + kwargs..., +) where {T,LDT,N} + n_timestep = length(obs) + recorder = WeightedParticleRecorderCallback( + Array{eltype(model.dyn)}(undef, n_timestep, N), Array{T}(undef, n_timestep, N) + ) + + particles, _ = filter(rng, model, alg.filter, obs; callback=recorder, kwargs...) 
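+    # Backward pass: draw M terminal states from the final filtered weights, then
+    # recursively sample an ancestor for each trajectory using backward-smoothing
+    # weights built from the recorded filtered particles, their log-weights and the
+    # model's transition density.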
+ idx_ref = rand(rng, Categorical(weights(particles.filtered)), M) + trajectories = Array{eltype(model.dyn)}(undef, n_timestep, M) + + forward_state = particles.filtered[idx_ref] + trajectories[end, :] = forward_state + for step in (n_timestep - 1):-1:1 + for j in 1:M + transitions = map( + x -> + SSMProblems.logdensity(model.dyn, step, forward_state[j], x; kwargs...), + recorder.particles[step, :], + ) + backward_weights = recorder.log_weights[step, :] + transitions + ancestor = rand(rng, Categorical(softmax(backward_weights))) + trajectories[step, j] = recorder.particles[step, ancestor] + end + end + return trajectories +end From 4b948db6652e01310b5eb5ee5b3366b13112164e Mon Sep 17 00:00:00 2001 From: FredericWantiez Date: Sat, 26 Oct 2024 17:22:38 +0100 Subject: [PATCH 06/13] Fix backward weights: --- src/algorithms/ffbs.jl | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/algorithms/ffbs.jl b/src/algorithms/ffbs.jl index 6007f2a..de7b790 100644 --- a/src/algorithms/ffbs.jl +++ b/src/algorithms/ffbs.jl @@ -43,13 +43,12 @@ function smooth( idx_ref = rand(rng, Categorical(weights(particles.filtered)), M) trajectories = Array{eltype(model.dyn)}(undef, n_timestep, M) - forward_state = particles.filtered[idx_ref] - trajectories[end, :] = forward_state + trajectories[end, :] = particles.filtered[idx_ref] for step in (n_timestep - 1):-1:1 for j in 1:M transitions = map( x -> - SSMProblems.logdensity(model.dyn, step, forward_state[j], x; kwargs...), + SSMProblems.logdensity(model.dyn, step, x, trajectories[step+1]; kwargs...), recorder.particles[step, :], ) backward_weights = recorder.log_weights[step, :] + transitions From 556babdfce3b031953b5d989b258a1e061f580e5 Mon Sep 17 00:00:00 2001 From: FredericWantiez Date: Sat, 26 Oct 2024 18:25:13 +0100 Subject: [PATCH 07/13] Weigghts: --- src/algorithms/apf.jl | 4 ++-- src/algorithms/bootstrap.jl | 2 +- src/algorithms/ffbs.jl | 23 ++++++++++++++++++----- 3 files changed, 21 insertions(+), 8 deletions(-) diff --git a/src/algorithms/apf.jl b/src/algorithms/apf.jl index 17489da..bf9e089 100644 --- a/src/algorithms/apf.jl +++ b/src/algorithms/apf.jl @@ -74,12 +74,12 @@ end function update( model::StateSpaceModel{T}, - filter::AuxiliaryParticleFilter, + filter::AuxiliaryParticleFilter{N}, step::Integer, states::ParticleContainer, observation; kwargs..., -) where {T} +) where {T,N} @debug "step $step" log_increments = map( x -> SSMProblems.logdensity(model.obs, step, x, observation; kwargs...), diff --git a/src/algorithms/bootstrap.jl b/src/algorithms/bootstrap.jl index 08ba5ba..471edc7 100644 --- a/src/algorithms/bootstrap.jl +++ b/src/algorithms/bootstrap.jl @@ -82,7 +82,7 @@ function logmarginal(states::ParticleContainer, ::BootstrapFilter) end function reset_weights!( - state::ParticleState{T,WT}, idxs, filter::BootstrapFilter{N} + state::ParticleState{T,WT}, idxs, ::BootstrapFilter{N} ) where {T,WT<:Real,N} fill!(state.log_weights, -log(WT(N))) return state diff --git a/src/algorithms/ffbs.jl b/src/algorithms/ffbs.jl index de7b790..b960770 100644 --- a/src/algorithms/ffbs.jl +++ b/src/algorithms/ffbs.jl @@ -25,7 +25,7 @@ function (callback::WeightedParticleRecorderCallback)( return nothing end -function smooth( +function sample( rng::Random.AbstractRNG, model::StateSpaceModel{T,LDT}, alg::FFBS{<:BootstrapFilter{N}}, @@ -40,21 +40,34 @@ function smooth( ) particles, _ = filter(rng, model, alg.filter, obs; callback=recorder, kwargs...) 
+ + # Backward sampling - exact idx_ref = rand(rng, Categorical(weights(particles.filtered)), M) trajectories = Array{eltype(model.dyn)}(undef, n_timestep, M) trajectories[end, :] = particles.filtered[idx_ref] for step in (n_timestep - 1):-1:1 for j in 1:M - transitions = map( - x -> - SSMProblems.logdensity(model.dyn, step, x, trajectories[step+1]; kwargs...), + backward_weights = backward( + model::StateSpaceModel, + step, + trajectories[step + 1], recorder.particles[step, :], + recorder.log_weights[step, :]; + kwargs..., ) - backward_weights = recorder.log_weights[step, :] + transitions ancestor = rand(rng, Categorical(softmax(backward_weights))) trajectories[step, j] = recorder.particles[step, ancestor] end end return trajectories end + +function backward( + model::StateSpaceModel, step::Integer, state, particles::T, log_weights::WT; kwargs... +) where {T,WT} + transitions = map( + x -> SSMProblems.logdensity(model.dyn, step, x, state; kwargs...), particles + ) + return log_weights + transitions +end From 20faf22ae8686837841f9a0560bde124b96df0d0 Mon Sep 17 00:00:00 2001 From: FredericWantiez Date: Sun, 27 Oct 2024 11:26:11 +0000 Subject: [PATCH 08/13] Split trajectory --- src/algorithms/ffbs.jl | 62 +++++++++++++++++++++++++++--------------- 1 file changed, 40 insertions(+), 22 deletions(-) diff --git a/src/algorithms/ffbs.jl b/src/algorithms/ffbs.jl index b960770..711b165 100644 --- a/src/algorithms/ffbs.jl +++ b/src/algorithms/ffbs.jl @@ -25,6 +25,43 @@ function (callback::WeightedParticleRecorderCallback)( return nothing end +function gen_trajectory( + rng::Random.AbstractRNG, + model::StateSpaceModel, + particles::AbstractMatrix{T}, # Need better container + log_weights::AbstractMatrix{WT}, + forward_state, + n_timestep::Int; + kwargs... +) where {T,WT} + trajectory = Vector{T}(undef, n_timestep) + trajectory[end] = forward_state + for step in (n_timestep - 1):-1:1 + backward_weights = backward( + model, + step, + trajectory[step + 1], + particles[step, :], + log_weights[step, :]; + kwargs..., + ) + ancestor = rand(rng, Categorical(softmax(backward_weights))) + trajectory[step] = particles[step, ancestor] + end + return trajectory +end + + +function backward( + model::StateSpaceModel, step::Integer, state, particles::T, log_weights::WT; kwargs... +) where {T,WT} + transitions = map(particles) do prev_state + SSMProblems.logdensity(model.dyn, step, prev_state, state; kwargs...) 
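+        # log transition density from each recorded filtered particle (prev_state)
+        # to the trajectory state already fixed at the next time step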
+ end + return log_weights + transitions +end + + function sample( rng::Random.AbstractRNG, model::StateSpaceModel{T,LDT}, @@ -34,6 +71,7 @@ function sample( callback=nothing, kwargs..., ) where {T,LDT,N} + n_timestep = length(obs) recorder = WeightedParticleRecorderCallback( Array{eltype(model.dyn)}(undef, n_timestep, N), Array{T}(undef, n_timestep, N) @@ -46,28 +84,8 @@ function sample( trajectories = Array{eltype(model.dyn)}(undef, n_timestep, M) trajectories[end, :] = particles.filtered[idx_ref] - for step in (n_timestep - 1):-1:1 - for j in 1:M - backward_weights = backward( - model::StateSpaceModel, - step, - trajectories[step + 1], - recorder.particles[step, :], - recorder.log_weights[step, :]; - kwargs..., - ) - ancestor = rand(rng, Categorical(softmax(backward_weights))) - trajectories[step, j] = recorder.particles[step, ancestor] - end + for j in 1:M + trajectories[:, j] = gen_trajectory(rng, model, recorder.particles, recorder.log_weights, trajectories[end, j], n_timestep) end return trajectories end - -function backward( - model::StateSpaceModel, step::Integer, state, particles::T, log_weights::WT; kwargs... -) where {T,WT} - transitions = map( - x -> SSMProblems.logdensity(model.dyn, step, x, state; kwargs...), particles - ) - return log_weights + transitions -end From 9f03987390138fba8a22fbfdc8ccb6c150ea260d Mon Sep 17 00:00:00 2001 From: FredericWantiez Date: Sun, 27 Oct 2024 11:31:33 +0000 Subject: [PATCH 09/13] format --- src/algorithms/ffbs.jl | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/src/algorithms/ffbs.jl b/src/algorithms/ffbs.jl index 711b165..c02dae5 100644 --- a/src/algorithms/ffbs.jl +++ b/src/algorithms/ffbs.jl @@ -26,13 +26,13 @@ function (callback::WeightedParticleRecorderCallback)( end function gen_trajectory( - rng::Random.AbstractRNG, - model::StateSpaceModel, + rng::Random.AbstractRNG, + model::StateSpaceModel, particles::AbstractMatrix{T}, # Need better container - log_weights::AbstractMatrix{WT}, + log_weights::AbstractMatrix{WT}, forward_state, n_timestep::Int; - kwargs... + kwargs..., ) where {T,WT} trajectory = Vector{T}(undef, n_timestep) trajectory[end] = forward_state @@ -51,7 +51,6 @@ function gen_trajectory( return trajectory end - function backward( model::StateSpaceModel, step::Integer, state, particles::T, log_weights::WT; kwargs... 
) where {T,WT} @@ -61,7 +60,6 @@ function backward( return log_weights + transitions end - function sample( rng::Random.AbstractRNG, model::StateSpaceModel{T,LDT}, @@ -71,7 +69,6 @@ function sample( callback=nothing, kwargs..., ) where {T,LDT,N} - n_timestep = length(obs) recorder = WeightedParticleRecorderCallback( Array{eltype(model.dyn)}(undef, n_timestep, N), Array{T}(undef, n_timestep, N) @@ -85,7 +82,14 @@ function sample( trajectories[end, :] = particles.filtered[idx_ref] for j in 1:M - trajectories[:, j] = gen_trajectory(rng, model, recorder.particles, recorder.log_weights, trajectories[end, j], n_timestep) + trajectories[:, j] = gen_trajectory( + rng, + model, + recorder.particles, + recorder.log_weights, + trajectories[end, j], + n_timestep, + ) end return trajectories end From 899c5075c15c38dc66b47277775488d092af51f7 Mon Sep 17 00:00:00 2001 From: FredericWantiez Date: Sun, 27 Oct 2024 12:40:59 +0000 Subject: [PATCH 10/13] Clean up --- .envrc | 2 -- scratch.jl | 54 ------------------------------------------------------ 2 files changed, 56 deletions(-) delete mode 100644 .envrc delete mode 100644 scratch.jl diff --git a/.envrc b/.envrc deleted file mode 100644 index b945361..0000000 --- a/.envrc +++ /dev/null @@ -1,2 +0,0 @@ -layout julia -use julia 1.10 \ No newline at end of file diff --git a/scratch.jl b/scratch.jl deleted file mode 100644 index 0728d4f..0000000 --- a/scratch.jl +++ /dev/null @@ -1,54 +0,0 @@ -using GeneralisedFilters -using SSMProblems -using PDMats -using LinearAlgebra -using Random: randexp -using Random -using CairoMakie -using Statistics - -T = Float32 -rng = MersenneTwister(1) -σx², σy² = randexp(rng, T, 2) - -# initial state distribution -μ0 = zeros(T, 2) -Σ0 = PDMat(T[1 0; 0 1]) - -# state transition equation -A = T[1 1; 0 1] -b = T[0; 0] -Q = PDiagMat([σx²; T(1e-6)]) - -# observation equation -H = T[1 0] -c = T[0;] -R = [σy²;;] - -# when working with PDMats, the Kalman filter doesn't play nicely without this -function Base.convert(::Type{PDMat{T,MT}}, mat::MT) where {MT<:AbstractMatrix,T<:Real} - return PDMat(Symmetric(mat)) -end - -model = create_homogeneous_linear_gaussian_model(μ0, Σ0, A, b, Q, H, c, R) -_, latent, data = sample(rng, model, 100) - -bf = BF(2^10; threshold=0.8) -smoother = FFBS(bf) - -M = 1_000 -trajectories = GeneralisedFilters.sample(rng, model, smoother, data, M) -x1 = first.(trajectories) -m = vec(mean(x1, dims=2)) -stdev = vec(std(x1, dims=2)) - -figure = Figure() -pos = figure[1, 1] -lines(pos, m - 2 * stdev, color="black") -lines!(pos, m + 2 * stdev, color="black") -lines!(m, color="red") -lines!(pos, first.(latent), label="True latent trajectory", color="blue") -# for i in 1:M -# lines!(pos, x1[:, i], color=:black, alpha=.01) -# end -figure \ No newline at end of file From 51fd883b19f2975674f7bd249bf13b4d025baae6 Mon Sep 17 00:00:00 2001 From: FredericWantiez Date: Sun, 27 Oct 2024 12:49:47 +0000 Subject: [PATCH 11/13] Ambiguous --- src/algorithms/bootstrap.jl | 17 +++-------------- 1 file changed, 3 insertions(+), 14 deletions(-) diff --git a/src/algorithms/bootstrap.jl b/src/algorithms/bootstrap.jl index 471edc7..11a19fe 100644 --- a/src/algorithms/bootstrap.jl +++ b/src/algorithms/bootstrap.jl @@ -71,23 +71,12 @@ function update( end function reset_weights!( - state::ParticleState{T,WT}, idxs, filter::BootstrapFilter + state::ParticleState{T,WT}, idxs, ::BootstrapFilter ) where {T,WT<:Real} - fill!(state.log_weights, -log(WT(length(state.particles)))) + fill!(state.log_weights, zero(WT)) return state end function 
logmarginal(states::ParticleContainer, ::BootstrapFilter) return logsumexp(states.filtered.log_weights) - logsumexp(states.proposed.log_weights) -end - -function reset_weights!( - state::ParticleState{T,WT}, idxs, ::BootstrapFilter{N} -) where {T,WT<:Real,N} - fill!(state.log_weights, -log(WT(N))) - return state -end - -function logmarginal(states::ParticleContainer, ::BootstrapFilter) - return logsumexp(states.filtered.log_weights) - logsumexp(states.proposed.log_weights) -end +end \ No newline at end of file From d4e2a406260a858fec0aceb53a24f9c0f1c7a548 Mon Sep 17 00:00:00 2001 From: Charles Knipp Date: Fri, 8 Nov 2024 10:57:35 -0500 Subject: [PATCH 12/13] fixed ancestry callbacks --- src/containers.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/containers.jl b/src/containers.jl index e31c7f0..27dcc62 100644 --- a/src/containers.jl +++ b/src/containers.jl @@ -276,7 +276,7 @@ end function (c::AncestorCallback)(model, filter, step, states, data; kwargs...) if step == 1 # this may be incorrect, but it is functional - @inbounds c.tree.states[1:(filter.N)] = deepcopy(states.filtered.particles) + @inbounds c.tree.states[keys(states.filtered)] = deepcopy(states.filtered.particles) end # TODO: when using non-stack version, may be more efficient to wait until storage full # to prune @@ -304,7 +304,7 @@ end function (c::ResamplerCallback)(model, filter, step, states, data; kwargs...) if step != 1 prune!(c.tree, get_offspring(states.ancestors)) - insert!(c.tree, collect(1:(filter.N)), states.ancestors) + insert!(c.tree, collect(keys(states.filtered)), states.ancestors) end return nothing end From 756d1f0dbf3b70fe674460c13229d9a678be3dff Mon Sep 17 00:00:00 2001 From: Charles Knipp Date: Fri, 8 Nov 2024 12:16:38 -0500 Subject: [PATCH 13/13] add guided particle filer --- src/GeneralisedFilters.jl | 1 + src/algorithms/guidedpf.jl | 164 +++++++++++++++++++++++++++++++++++++ test/runtests.jl | 93 ++++++++++++++++++++- 3 files changed, 256 insertions(+), 2 deletions(-) create mode 100644 src/algorithms/guidedpf.jl diff --git a/src/GeneralisedFilters.jl b/src/GeneralisedFilters.jl index fad4f92..ef339a4 100644 --- a/src/GeneralisedFilters.jl +++ b/src/GeneralisedFilters.jl @@ -129,5 +129,6 @@ include("algorithms/kalman.jl") include("algorithms/forward.jl") include("algorithms/rbpf.jl") include("algorithms/ffbs.jl") +include("algorithms/guidedpf.jl") end diff --git a/src/algorithms/guidedpf.jl b/src/algorithms/guidedpf.jl new file mode 100644 index 0000000..677538f --- /dev/null +++ b/src/algorithms/guidedpf.jl @@ -0,0 +1,164 @@ +export GuidedFilter, GPF, AbstractProposal + +## PROPOSALS ############################################################################### +""" + AbstractProposal +""" +abstract type AbstractProposal end + +function SSMProblems.distribution( + model::AbstractStateSpaceModel, + prop::AbstractProposal, + step::Integer, + state, + observation; + kwargs..., +) + return throw( + MethodError( + SSMProblems.distribution, (model, prop, step, state, observation, kwargs...) + ), + ) +end + +function SSMProblems.simulate( + rng::AbstractRNG, + model::AbstractStateSpaceModel, + prop::AbstractProposal, + step::Integer, + state, + observation; + kwargs..., +) + return rand( + rng, SSMProblems.distribution(model, prop, step, state, observation; kwargs...) 
+ ) +end + +function SSMProblems.logdensity( + model::AbstractStateSpaceModel, + prop::AbstractProposal, + step::Integer, + prev_state, + new_state, + observation; + kwargs..., +) + return logpdf( + SSMProblems.distribution(model, prop, step, prev_state, observation; kwargs...), + new_state, + ) +end + +## GUIDED FILTERING ######################################################################## + +struct GuidedFilter{N,RS<:AbstractResampler,P<:AbstractProposal} <: + AbstractParticleFilter{N} + resampler::RS + proposal::P +end + +function GuidedFilter( + N::Integer, proposal::P; threshold::Real=1.0, resampler::AbstractResampler=Systematic() +) where {P<:AbstractProposal} + conditional_resampler = ESSResampler(threshold, resampler) + return GuidedFilter{N,typeof(conditional_resampler),P}(conditional_resampler, proposal) +end + +"""Shorthand for `GuidedFilter`""" +const GPF = GuidedFilter + +function initialise( + rng::AbstractRNG, + model::StateSpaceModel{T}, + filter::GuidedFilter{N}; + ref_state::Union{Nothing,AbstractVector}=nothing, + kwargs..., +) where {N,T} + initial_states = map(x -> SSMProblems.simulate(rng, model.dyn; kwargs...), 1:N) + initial_weights = zeros(T, N) + + return update_ref!( + ParticleContainer(initial_states, initial_weights), ref_state, filter + ) +end + +function predict( + rng::AbstractRNG, + model::StateSpaceModel, + filter::GuidedFilter, + step::Integer, + states::ParticleContainer{T}, + observation; + ref_state::Union{Nothing,AbstractVector{T}}=nothing, + kwargs..., +) where {T} + states.proposed, states.ancestors = resample( + rng, filter.resampler, states.filtered, filter + ) + states.proposed.particles = map( + x -> SSMProblems.simulate( + rng, model, filter.proposal, step, x, observation; kwargs... + ), + collect(states.proposed), + ) + + return update_ref!(states, ref_state, filter, step) +end + +function update( + model::StateSpaceModel{T}, + filter::GuidedFilter{N}, + step::Integer, + states::ParticleContainer, + observation; + kwargs..., +) where {T,N} + # this is a little messy and may require a deepcopy + particle_collection = zip( + collect(states.proposed), deepcopy(states.filtered.particles[states.ancestors]) + ) + + log_increments = map(particle_collection) do (new_state, prev_state) + log_f = SSMProblems.logdensity(model.dyn, step, prev_state, new_state; kwargs...) + log_g = SSMProblems.logdensity(model.obs, step, new_state, observation; kwargs...) + log_q = SSMProblems.logdensity( + model, filter.proposal, step, prev_state, new_state, observation; kwargs... + ) + + # println(log_f) + + (log_f + log_g - log_q) + end + + # println(logsumexp(log_increments)) + + states.filtered.log_weights = states.proposed.log_weights + log_increments + states.filtered.particles = states.proposed.particles + + return states, logmarginal(states, filter) +end + +function step( + rng::AbstractRNG, + model::AbstractStateSpaceModel, + alg::GuidedFilter, + iter::Integer, + state, + observation; + kwargs..., +) + proposed_state = predict(rng, model, alg, iter, state, observation; kwargs...) + filtered_state, ll = update(model, alg, iter, proposed_state, observation; kwargs...) 
+ + return filtered_state, ll +end + +function reset_weights!(state::ParticleState{T,WT}, idxs, ::GuidedFilter) where {T,WT<:Real} + fill!(state.log_weights, zero(WT)) + return state +end + +function logmarginal(states::ParticleContainer, ::GuidedFilter) + return logsumexp(states.filtered.log_weights) - logsumexp(states.proposed.log_weights) +end diff --git a/test/runtests.jl b/test/runtests.jl index 882c1e8..6465f5e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -109,9 +109,9 @@ end _, _, data = sample(rng, model, 20) bf = BF(2^12; threshold=0.8) - apf = APF(2^10, threshold=1.) + apf = APF(2^10; threshold=1.0) bf_state, llbf = GeneralisedFilters.filter(rng, model, bf, data) - _, llapf= GeneralisedFilters.filter(rng, model, apf, data) + _, llapf = GeneralisedFilters.filter(rng, model, apf, data) kf_state, llkf = GeneralisedFilters.filter(rng, model, KF(), data) xs = bf_state.filtered.particles @@ -167,6 +167,95 @@ end @test llkf ≈ llbf atol = 2 end +@testitem "Guided particle filter test" begin + using GeneralisedFilters + using SSMProblems + using StableRNGs + using PDMats + using LinearAlgebra + using LogExpFunctions: softmax + using Random: randexp + using Distributions + + # this is a pseudo optimal proposal kernel for linear Gaussian models + struct LinearGaussianProposal{T<:Real} <: GeneralisedFilters.AbstractProposal + φ::Vector{T} + end + + # a lot of computations done at each step + function SSMProblems.distribution( + model::AbstractStateSpaceModel, + kernel::LinearGaussianProposal, + step::Integer, + state, + observation; + kwargs..., + ) + # get model dimensions + dx = length(state) + dy = length(observation) + + # see (Corenflos et al, 2021) for details + A = GeneralisedFilters.calc_A(model.dyn, step; kwargs...) + Γ = diagm(dx, dy, kernel.φ[(dx + 1):end]) + Σ = PDiagMat(φ[1:dx]) + + return MvNormal(inv(Σ) * A * state + inv(Σ) * Γ * observation, Σ) + end + + T = Float32 + rng = StableRNG(1234) + σx², σy² = randexp(rng, T, 2) + + # initial state distribution + μ0 = zeros(T, 1) + Σ0 = PDiagMat(T[1;]) + + # state transition equation + A = T[0.9;;] + b = T[0;] + Q = PDiagMat([σx²;]) + + # observation equation + H = T[1;;] + c = T[0;] + R = PDiagMat([σy²;]) + + # proposal kernel (kind of optimal...) + φ = ones(T, 2) + proposal = LinearGaussianProposal(φ) + + # when working with PDMats, the Kalman filter doesn't play nicely without this + function Base.convert(::Type{PDMat{T,MT}}, mat::MT) where {MT<:AbstractMatrix,T<:Real} + return PDMat(Symmetric(mat)) + end + + model = create_homogeneous_linear_gaussian_model(μ0, Σ0, A, b, Q, H, c, R) + _, _, data = sample(rng, model, 20) + + # bootstrap filter + bf = BF(2^12; threshold=0.8) + bf_state, llbf = GeneralisedFilters.filter(rng, model, bf, data) + bf_xs = bf_state.filtered.particles + bf_ws = softmax(bf_state.filtered.log_weights) + + # guided particle filter + pf = GPF(2^12, proposal; threshold=0.8) + pf_state, llpf = GeneralisedFilters.filter(rng, model, pf, data) + pf_xs = pf_state.filtered.particles + pf_ws = softmax(pf_state.filtered.log_weights) + + kf_state, llkf = GeneralisedFilters.filter(rng, model, KF(), data) + + # Compare filtered states + @test first(kf_state.μ) ≈ sum(first.(bf_xs) .* bf_ws) rtol = 1e-2 + @test first(kf_state.μ) ≈ sum(first.(pf_xs) .* pf_ws) rtol = 1e-2 + + # since this is log valued, we can up the tolerance + @test llkf ≈ llbf atol = 0.1 + @test llkf ≈ llpf atol = 2 +end + @testitem "Forward algorithm test" begin using GeneralisedFilters using Distributions