diff --git a/src/GeneralisedFilters.jl b/src/GeneralisedFilters.jl index d53589f..ef339a4 100644 --- a/src/GeneralisedFilters.jl +++ b/src/GeneralisedFilters.jl @@ -128,5 +128,7 @@ include("algorithms/apf.jl") include("algorithms/kalman.jl") include("algorithms/forward.jl") include("algorithms/rbpf.jl") +include("algorithms/ffbs.jl") +include("algorithms/guidedpf.jl") end diff --git a/src/algorithms/apf.jl b/src/algorithms/apf.jl index f4a9df6..bf9e089 100644 --- a/src/algorithms/apf.jl +++ b/src/algorithms/apf.jl @@ -6,7 +6,7 @@ mutable struct AuxiliaryParticleFilter{N,RS<:AbstractConditionalResampler} <: Ab end function AuxiliaryParticleFilter( - N::Integer; threshold::Real=0., resampler::AbstractResampler=Systematic() + N::Integer; threshold::Real=0.0, resampler::AbstractResampler=Systematic() ) conditional_resampler = ESSResampler(threshold, resampler) return AuxiliaryParticleFilter{N,typeof(conditional_resampler)}(conditional_resampler, zeros(N)) @@ -24,7 +24,9 @@ function initialise( initial_states = map(x -> SSMProblems.simulate(rng, model.dyn; kwargs...), 1:N) initial_weights = zeros(T, N) - return update_ref!(ParticleContainer(initial_states, initial_weights), ref_state, filter) + return update_ref!( + ParticleContainer(initial_states, initial_weights), ref_state, filter + ) end function update_weights!( @@ -59,7 +61,9 @@ function predict( states.filtered.log_weights .+= auxiliary_weights filter.aux = auxiliary_weights - states.proposed, states.ancestors = resample(rng, filter.resampler, states.filtered, filter) + states.proposed, states.ancestors = resample( + rng, filter.resampler, states.filtered, filter + ) states.proposed.particles = map( x -> SSMProblems.simulate(rng, model.dyn, step, x; kwargs...), states.proposed.particles, @@ -70,12 +74,12 @@ end function update( model::StateSpaceModel{T}, - filter::AuxiliaryParticleFilter, + filter::AuxiliaryParticleFilter{N}, step::Integer, states::ParticleContainer, observation; kwargs..., -) where {T} +) where {T,N} @debug "step $step" log_increments = map( x -> SSMProblems.logdensity(model.obs, step, x, observation; kwargs...), diff --git a/src/algorithms/bootstrap.jl b/src/algorithms/bootstrap.jl index a508761..11a19fe 100644 --- a/src/algorithms/bootstrap.jl +++ b/src/algorithms/bootstrap.jl @@ -1,5 +1,7 @@ export BootstrapFilter, BF +abstract type AbstractParticleFilter{N} <: AbstractFilter end + struct BootstrapFilter{N,RS<:AbstractResampler} <: AbstractParticleFilter{N} resampler::RS end @@ -8,7 +10,7 @@ function BootstrapFilter( N::Integer; threshold::Real=1.0, resampler::AbstractResampler=Systematic() ) conditional_resampler = ESSResampler(threshold, resampler) - return BootstrapFilter{N, typeof(conditional_resampler)}(conditional_resampler) + return BootstrapFilter{N,typeof(conditional_resampler)}(conditional_resampler) end """Shorthand for `BootstrapFilter`""" @@ -38,7 +40,9 @@ function predict( ref_state::Union{Nothing,AbstractVector{T}}=nothing, kwargs..., ) where {T} - states.proposed, states.ancestors = resample(rng, filter.resampler, states.filtered, filter) + states.proposed, states.ancestors = resample( + rng, filter.resampler, states.filtered, filter + ) states.proposed.particles = map( x -> SSMProblems.simulate(rng, model.dyn, step, x; kwargs...), collect(states.proposed), @@ -49,12 +53,12 @@ end function update( model::StateSpaceModel{T}, - filter::BootstrapFilter, + filter::BootstrapFilter{N}, step::Integer, states::ParticleContainer, observation; kwargs..., -) where {T} +) where {T,N} log_increments = map( x -> 
             SSMProblems.logdensity(model.obs, step, x, observation; kwargs...),
         collect(states.proposed),
@@ -67,12 +71,12 @@ function update(
 end
 
 function reset_weights!(
-    state::ParticleState{T,WT}, idxs, filter::BootstrapFilter
+    state::ParticleState{T,WT}, idxs, ::BootstrapFilter
 ) where {T,WT<:Real}
-    fill!(state.log_weights, -log(WT(length(state.particles))))
+    fill!(state.log_weights, zero(WT))
     return state
 end
 
 function logmarginal(states::ParticleContainer, ::BootstrapFilter)
     return logsumexp(states.filtered.log_weights) - logsumexp(states.proposed.log_weights)
-end
+end
\ No newline at end of file
diff --git a/src/algorithms/ffbs.jl b/src/algorithms/ffbs.jl
new file mode 100644
index 0000000..c02dae5
--- /dev/null
+++ b/src/algorithms/ffbs.jl
@@ -0,0 +1,102 @@
+export FFBS
+
+abstract type AbstractSmoother <: AbstractSampler end
+
+struct FFBS{T<:AbstractParticleFilter}
+    filter::T
+end
+
+"""
+    smooth(rng::AbstractRNG, alg::AbstractSmoother, model::AbstractStateSpaceModel, obs::AbstractVector, M::Integer; callback, kwargs...)
+
+Draw `M` trajectories from the joint smoothing distribution of `model` given observations `obs`.
+"""
+function smooth end
+
+struct WeightedParticleRecorderCallback{T,WT}
+    particles::Array{T}
+    log_weights::Array{WT}
+end
+
+function (callback::WeightedParticleRecorderCallback)(
+    model, filter, step, states, data; kwargs...
+)
+    filtered_states = states.filtered
+    callback.particles[step, :] = filtered_states.particles
+    callback.log_weights[step, :] = filtered_states.log_weights
+    return nothing
+end
+
+function gen_trajectory(
+    rng::Random.AbstractRNG,
+    model::StateSpaceModel,
+    particles::AbstractMatrix{T}, # Need better container
+    log_weights::AbstractMatrix{WT},
+    forward_state,
+    n_timestep::Int;
+    kwargs...,
+) where {T,WT}
+    trajectory = Vector{T}(undef, n_timestep)
+    trajectory[end] = forward_state
+    for step in (n_timestep - 1):-1:1
+        backward_weights = backward(
+            model,
+            step,
+            trajectory[step + 1],
+            particles[step, :],
+            log_weights[step, :];
+            kwargs...,
+        )
+        ancestor = rand(rng, Categorical(softmax(backward_weights)))
+        trajectory[step] = particles[step, ancestor]
+    end
+    return trajectory
+end
+
+function backward(
+    model::StateSpaceModel, step::Integer, state, particles::T, log_weights::WT; kwargs...
+) where {T,WT}
+    transitions = map(particles) do prev_state
+        SSMProblems.logdensity(model.dyn, step, prev_state, state; kwargs...)
+    end
+    return log_weights + transitions
+end
+
+function sample(
+    rng::Random.AbstractRNG,
+    model::StateSpaceModel{T,LDT},
+    alg::FFBS{<:BootstrapFilter{N}},
+    obs::AbstractVector,
+    M::Integer;
+    callback=nothing,
+    kwargs...,
+) where {T,LDT,N}
+    n_timestep = length(obs)
+    recorder = WeightedParticleRecorderCallback(
+        Array{eltype(model.dyn)}(undef, n_timestep, N), Array{T}(undef, n_timestep, N)
+    )
+
+    particles, _ = filter(rng, model, alg.filter, obs; callback=recorder, kwargs...)
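+
+    # `recorder` now holds the entire filtering history: particles and log-weights as
+    # (n_timestep × N) arrays. The backward pass below reweights the step-t particles by
+    # the transition density p(x_{t+1} | x_t) via `backward`, so each sampled trajectory
+    # is a draw from the joint smoothing distribution.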
+
+    # Backward sampling - exact
+    idx_ref = rand(rng, Categorical(weights(particles.filtered)), M)
+    trajectories = Array{eltype(model.dyn)}(undef, n_timestep, M)
+
+    trajectories[end, :] = particles.filtered[idx_ref]
+    for j in 1:M
+        trajectories[:, j] = gen_trajectory(
+            rng,
+            model,
+            recorder.particles,
+            recorder.log_weights,
+            trajectories[end, j],
+            n_timestep,
+        )
+    end
+    return trajectories
+end
diff --git a/src/algorithms/guidedpf.jl b/src/algorithms/guidedpf.jl
new file mode 100644
index 0000000..677538f
--- /dev/null
+++ b/src/algorithms/guidedpf.jl
@@ -0,0 +1,164 @@
+export GuidedFilter, GPF, AbstractProposal
+
+## PROPOSALS ###############################################################################
+"""
+    AbstractProposal
+
+Abstract supertype for proposal distributions. Concrete subtypes should implement
+`SSMProblems.distribution(model, prop, step, state, observation; kwargs...)`.
+"""
+abstract type AbstractProposal end
+
+function SSMProblems.distribution(
+    model::AbstractStateSpaceModel,
+    prop::AbstractProposal,
+    step::Integer,
+    state,
+    observation;
+    kwargs...,
+)
+    return throw(
+        MethodError(
+            SSMProblems.distribution, (model, prop, step, state, observation, kwargs...)
+        ),
+    )
+end
+
+function SSMProblems.simulate(
+    rng::AbstractRNG,
+    model::AbstractStateSpaceModel,
+    prop::AbstractProposal,
+    step::Integer,
+    state,
+    observation;
+    kwargs...,
+)
+    return rand(
+        rng, SSMProblems.distribution(model, prop, step, state, observation; kwargs...)
+    )
+end
+
+function SSMProblems.logdensity(
+    model::AbstractStateSpaceModel,
+    prop::AbstractProposal,
+    step::Integer,
+    prev_state,
+    new_state,
+    observation;
+    kwargs...,
+)
+    return logpdf(
+        SSMProblems.distribution(model, prop, step, prev_state, observation; kwargs...),
+        new_state,
+    )
+end
+
+## GUIDED FILTERING ########################################################################
+
+struct GuidedFilter{N,RS<:AbstractResampler,P<:AbstractProposal} <:
+       AbstractParticleFilter{N}
+    resampler::RS
+    proposal::P
+end
+
+function GuidedFilter(
+    N::Integer, proposal::P; threshold::Real=1.0, resampler::AbstractResampler=Systematic()
+) where {P<:AbstractProposal}
+    conditional_resampler = ESSResampler(threshold, resampler)
+    return GuidedFilter{N,typeof(conditional_resampler),P}(conditional_resampler, proposal)
+end
+
+"""Shorthand for `GuidedFilter`"""
+const GPF = GuidedFilter
+
+function initialise(
+    rng::AbstractRNG,
+    model::StateSpaceModel{T},
+    filter::GuidedFilter{N};
+    ref_state::Union{Nothing,AbstractVector}=nothing,
+    kwargs...,
+) where {N,T}
+    initial_states = map(x -> SSMProblems.simulate(rng, model.dyn; kwargs...), 1:N)
+    initial_weights = zeros(T, N)
+
+    return update_ref!(
+        ParticleContainer(initial_states, initial_weights), ref_state, filter
+    )
+end
+
+function predict(
+    rng::AbstractRNG,
+    model::StateSpaceModel,
+    filter::GuidedFilter,
+    step::Integer,
+    states::ParticleContainer{T},
+    observation;
+    ref_state::Union{Nothing,AbstractVector{T}}=nothing,
+    kwargs...,
+) where {T}
+    states.proposed, states.ancestors = resample(
+        rng, filter.resampler, states.filtered, filter
+    )
+    states.proposed.particles = map(
+        x -> SSMProblems.simulate(
+            rng, model, filter.proposal, step, x, observation; kwargs...
+        ),
+        collect(states.proposed),
+    )
+
+    return update_ref!(states, ref_state, filter, step)
+end
+
+function update(
+    model::StateSpaceModel{T},
+    filter::GuidedFilter{N},
+    step::Integer,
+    states::ParticleContainer,
+    observation;
+    kwargs...,
+) where {T,N}
+    # this is a little messy and may require a deepcopy
+    particle_collection = zip(
+        collect(states.proposed), deepcopy(states.filtered.particles[states.ancestors])
+    )
+
+    log_increments = map(particle_collection) do (new_state, prev_state)
+        log_f = SSMProblems.logdensity(model.dyn, step, prev_state, new_state; kwargs...)
+        log_g = SSMProblems.logdensity(model.obs, step, new_state, observation; kwargs...)
+        log_q = SSMProblems.logdensity(
+            model, filter.proposal, step, prev_state, new_state, observation; kwargs...
+        )
+
+        # incremental importance weight: transition × observation over proposal
+        (log_f + log_g - log_q)
+    end
+
+    states.filtered.log_weights = states.proposed.log_weights + log_increments
+    states.filtered.particles = states.proposed.particles
+
+    return states, logmarginal(states, filter)
+end
+
+function step(
+    rng::AbstractRNG,
+    model::AbstractStateSpaceModel,
+    alg::GuidedFilter,
+    iter::Integer,
+    state,
+    observation;
+    kwargs...,
+)
+    proposed_state = predict(rng, model, alg, iter, state, observation; kwargs...)
+    filtered_state, ll = update(model, alg, iter, proposed_state, observation; kwargs...)
+
+    return filtered_state, ll
+end
+
+function reset_weights!(state::ParticleState{T,WT}, idxs, ::GuidedFilter) where {T,WT<:Real}
+    fill!(state.log_weights, zero(WT))
+    return state
+end
+
+function logmarginal(states::ParticleContainer, ::GuidedFilter)
+    return logsumexp(states.filtered.log_weights) - logsumexp(states.proposed.log_weights)
+end
diff --git a/src/containers.jl b/src/containers.jl
index e31c7f0..27dcc62 100644
--- a/src/containers.jl
+++ b/src/containers.jl
@@ -276,7 +276,7 @@ end
 function (c::AncestorCallback)(model, filter, step, states, data; kwargs...)
     if step == 1
         # this may be incorrect, but it is functional
-        @inbounds c.tree.states[1:(filter.N)] = deepcopy(states.filtered.particles)
+        @inbounds c.tree.states[keys(states.filtered)] = deepcopy(states.filtered.particles)
     end
     # TODO: when using non-stack version, may be more efficient to wait until storage full
     # to prune
@@ -304,7 +304,7 @@ end
 function (c::ResamplerCallback)(model, filter, step, states, data; kwargs...)
     if step != 1
         prune!(c.tree, get_offspring(states.ancestors))
-        insert!(c.tree, collect(1:(filter.N)), states.ancestors)
+        insert!(c.tree, collect(keys(states.filtered)), states.ancestors)
     end
     return nothing
 end
diff --git a/src/resamplers.jl b/src/resamplers.jl
index 60a7b36..7a281dd 100644
--- a/src/resamplers.jl
+++ b/src/resamplers.jl
@@ -35,7 +35,7 @@ function resample(
         deepcopy(states.z_particles[idxs]),
         CUDA.zeros(T, length(states)),
     )
-    
+
     reset_weights!(new_state, idxs, filter)
     return new_state, idxs
 end
@@ -81,7 +81,7 @@ function resample(
     @debug "ESS: $ess"
 
     if cond_resampler.threshold * n ≥ ess
-        return resample(rng, cond_resampler.resampler, state)
+        return resample(rng, cond_resampler.resampler, state, filter)
     else
         return deepcopy(state), collect(1:n)
     end
diff --git a/test/runtests.jl b/test/runtests.jl
index 4d85a0c..6465f5e 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -109,9 +109,9 @@ end
     _, _, data = sample(rng, model, 20)
 
     bf = BF(2^12; threshold=0.8)
-    apf = APF(2^10, threshold=1.)
+    apf = APF(2^10; threshold=1.0)
     bf_state, llbf = GeneralisedFilters.filter(rng, model, bf, data)
-    _, llapf= GeneralisedFilters.filter(rng, model, apf, data)
+    _, llapf = GeneralisedFilters.filter(rng, model, apf, data)
     kf_state, llkf = GeneralisedFilters.filter(rng, model, KF(), data)
 
     xs = bf_state.filtered.particles
@@ -125,6 +125,141 @@
     @test llkf ≈ llapf atol = 2
 end
 
+@testitem "APF filter test" begin
+    using GeneralisedFilters
+    using SSMProblems
+    using StableRNGs
+    using PDMats
+    using LinearAlgebra
+    using Random: randexp
+
+    T = Float32
+    rng = StableRNG(1234)
+    σx², σy² = randexp(rng, T, 2)
+
+    # initial state distribution
+    μ0 = zeros(T, 2)
+    Σ0 = PDMat(T[1 0; 0 1])
+
+    # state transition equation
+    A = T[1 1; 0 1]
+    b = T[0; 0]
+    Q = PDiagMat([σx²; 0])
+
+    # observation equation
+    H = T[1 0]
+    c = T[0;]
+    R = [σy²;;]
+
+    # when working with PDMats, the Kalman filter doesn't play nicely without this
+    function Base.convert(::Type{PDMat{T,MT}}, mat::MT) where {MT<:AbstractMatrix,T<:Real}
+        return PDMat(Symmetric(mat))
+    end
+
+    model = create_homogeneous_linear_gaussian_model(μ0, Σ0, A, b, Q, H, c, R)
+    _, _, data = sample(rng, model, 20)
+
+    apf = APF(2^10; threshold=0.8)
+    _, llapf = GeneralisedFilters.filter(rng, model, apf, data)
+    _, llkf = GeneralisedFilters.filter(rng, model, KF(), data)
+
+    # since this is log valued, we can up the tolerance
+    @test llkf ≈ llapf atol = 2
+end
+
+@testitem "Guided particle filter test" begin
+    using GeneralisedFilters
+    using SSMProblems
+    using StableRNGs
+    using PDMats
+    using LinearAlgebra
+    using LogExpFunctions: softmax
+    using Random: randexp
+    using Distributions
+
+    # this is a pseudo-optimal proposal kernel for linear Gaussian models
+    struct LinearGaussianProposal{T<:Real} <: GeneralisedFilters.AbstractProposal
+        φ::Vector{T}
+    end
+
+    # note that the proposal matrices are recomputed at every step
+    function SSMProblems.distribution(
+        model::AbstractStateSpaceModel,
+        kernel::LinearGaussianProposal,
+        step::Integer,
+        state,
+        observation;
+        kwargs...,
+    )
+        # get model dimensions
+        dx = length(state)
+        dy = length(observation)
+
+        # see (Corenflos et al., 2021) for details
+        A = GeneralisedFilters.calc_A(model.dyn, step; kwargs...)
+        Γ = diagm(dx, dy, kernel.φ[(dx + 1):end])
+        Σ = PDiagMat(kernel.φ[1:dx])
+
+        return MvNormal(inv(Σ) * A * state + inv(Σ) * Γ * observation, Σ)
+    end
+
+    T = Float32
+    rng = StableRNG(1234)
+    σx², σy² = randexp(rng, T, 2)
+
+    # initial state distribution
+    μ0 = zeros(T, 1)
+    Σ0 = PDiagMat(T[1;])
+
+    # state transition equation
+    A = T[0.9;;]
+    b = T[0;]
+    Q = PDiagMat([σx²;])
+
+    # observation equation
+    H = T[1;;]
+    c = T[0;]
+    R = PDiagMat([σy²;])
+
+    # proposal kernel (kind of optimal...)
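+    # φ stacks the proposal variance (first dx entries) and the observation gain (last
+    # dy entries); with all-ones φ the proposal mean is A * x + Γ * y, pulling particles
+    # towards the current observation rather than blindly following the dynamics, which
+    # is what makes this kernel close to optimal here.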
+ φ = ones(T, 2) + proposal = LinearGaussianProposal(φ) + + # when working with PDMats, the Kalman filter doesn't play nicely without this + function Base.convert(::Type{PDMat{T,MT}}, mat::MT) where {MT<:AbstractMatrix,T<:Real} + return PDMat(Symmetric(mat)) + end + + model = create_homogeneous_linear_gaussian_model(μ0, Σ0, A, b, Q, H, c, R) + _, _, data = sample(rng, model, 20) + + # bootstrap filter + bf = BF(2^12; threshold=0.8) + bf_state, llbf = GeneralisedFilters.filter(rng, model, bf, data) + bf_xs = bf_state.filtered.particles + bf_ws = softmax(bf_state.filtered.log_weights) + + # guided particle filter + pf = GPF(2^12, proposal; threshold=0.8) + pf_state, llpf = GeneralisedFilters.filter(rng, model, pf, data) + pf_xs = pf_state.filtered.particles + pf_ws = softmax(pf_state.filtered.log_weights) + + kf_state, llkf = GeneralisedFilters.filter(rng, model, KF(), data) + + # Compare filtered states + @test first(kf_state.μ) ≈ sum(first.(bf_xs) .* bf_ws) rtol = 1e-2 + @test first(kf_state.μ) ≈ sum(first.(pf_xs) .* pf_ws) rtol = 1e-2 + + # since this is log valued, we can up the tolerance + @test llkf ≈ llbf atol = 0.1 + @test llkf ≈ llpf atol = 2 +end + @testitem "Forward algorithm test" begin using GeneralisedFilters using Distributions