Add (naive) FFBS algo #20
Draft · wants to merge 13 commits into base: fred/auxiliary-particle-filter
1 change: 1 addition & 0 deletions src/GeneralisedFilters.jl
@@ -129,5 +129,6 @@ include("algorithms/kalman.jl")
include("algorithms/forward.jl")
include("algorithms/rbpf.jl")
include("algorithms/ffbs.jl")
include("algorithms/guidedpf.jl")

end
164 changes: 164 additions & 0 deletions src/algorithms/guidedpf.jl
@@ -0,0 +1,164 @@
export GuidedFilter, GPF, AbstractProposal

## PROPOSALS ###############################################################################
"""
AbstractProposal
"""
abstract type AbstractProposal end

function SSMProblems.distribution(
    model::AbstractStateSpaceModel,
    prop::AbstractProposal,
Member Author commented:
I don't think we should change the SSMProblems interface. The proposal should be part of the filter interface, maybe something along the lines of:

abstract type AbstractProposal end

abstract type AbstractParticleFilter{N, P<:AbstractProposal} end 

struct ParticleFilter{N,RS,P} <: AbstractParticleFilter{N,P}
    resampler::RS
    proposal::P
end

# Default to latent dynamics
struct LatentProposal <: AbstractProposal end

const BootstrapFilter{N,RS} = ParticleFilter{N,RS,LatentProposal}
const BF = BootstrapFilter

function propose(
    rng::AbstractRNG,
    prop::LatentProposal,
    model::AbstractStateSpaceModel,
    particles::ParticleContainer,
    step,
    state,
    obs;
    kwargs...
)
    return SSMProblems.simulate(rng, model.dyn, step, state; kwargs...)
end

function logdensity(prop::AbstractProposal, ...)
    return SSMProblems.logdensity(...)
end

And we should probably update the filter/predict functions:

function predict(
    rng::AbstractRNG,
    model::StateSpaceModel,
    filter::BootstrapFilter,
    step::Integer,
    states::ParticleContainer{T},
    observation;
    ref_state::Union{Nothing,AbstractVector{T}}=nothing,
    kwargs...,
) where {T}
    states.proposed, states.ancestors = resample(
        rng, filter.resampler, states.filtered, filter
    )
    states.proposed.particles = map(states.proposed) do state
        propose(rng, filter.proposal, model, states, step, state, observation; kwargs...)
    end

    return update_ref!(states, ref_state, filter, step)
end
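A user-defined proposal would then only need to implement `propose` and `logdensity`. For instance (a purely hypothetical example, not part of this PR; it assumes the state and observation share a dimension, and requires Distributions and LinearAlgebra):

# Hypothetical example only: propose each particle from a Gaussian centred on
# the current observation.
struct ObservationProposal{T<:Real} <: AbstractProposal
    σ::T
end

function propose(
    rng::AbstractRNG,
    prop::ObservationProposal,
    model::AbstractStateSpaceModel,
    particles::ParticleContainer,
    step,
    state,
    obs;
    kwargs...
)
    return rand(rng, MvNormal(obs, prop.σ^2 * I))
end

function logdensity(
    prop::ObservationProposal, model, step, prev_state, new_state, obs; kwargs...
)
    return logpdf(MvNormal(obs, prop.σ^2 * I), new_state)
end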

Member commented:

I 100% agree with the BF integration; I was intentionally working my way up to that, but didn't want to drastically change the interface in the first commit.

And you're totally right about the SSMProblems integration. But it was convenient to recycle the structures.

    step::Integer,
    state,
    observation;
    kwargs...,
)
    return throw(
        MethodError(
            SSMProblems.distribution, (model, prop, step, state, observation, kwargs...)
        ),
    )
end

function SSMProblems.simulate(
    rng::AbstractRNG,
    model::AbstractStateSpaceModel,
    prop::AbstractProposal,
    step::Integer,
    state,
    observation;
    kwargs...,
)
    return rand(
        rng, SSMProblems.distribution(model, prop, step, state, observation; kwargs...)
    )
end

function SSMProblems.logdensity(
    model::AbstractStateSpaceModel,
    prop::AbstractProposal,
    step::Integer,
    prev_state,
    new_state,
    observation;
    kwargs...,
)
    return logpdf(
        SSMProblems.distribution(model, prop, step, prev_state, observation; kwargs...),
        new_state,
    )
end

## GUIDED FILTERING ########################################################################

struct GuidedFilter{N,RS<:AbstractResampler,P<:AbstractProposal} <:
       AbstractParticleFilter{N}
    resampler::RS
    proposal::P
end

function GuidedFilter(
    N::Integer, proposal::P; threshold::Real=1.0, resampler::AbstractResampler=Systematic()
) where {P<:AbstractProposal}
    conditional_resampler = ESSResampler(threshold, resampler)
    return GuidedFilter{N,typeof(conditional_resampler),P}(conditional_resampler, proposal)
end

"""Shorthand for `GuidedFilter`"""
const GPF = GuidedFilter

function initialise(
    rng::AbstractRNG,
    model::StateSpaceModel{T},
    filter::GuidedFilter{N};
    ref_state::Union{Nothing,AbstractVector}=nothing,
    kwargs...,
) where {N,T}
    initial_states = map(x -> SSMProblems.simulate(rng, model.dyn; kwargs...), 1:N)
    initial_weights = zeros(T, N)

    return update_ref!(
        ParticleContainer(initial_states, initial_weights), ref_state, filter
    )
end

function predict(
    rng::AbstractRNG,
    model::StateSpaceModel,
    filter::GuidedFilter,
    step::Integer,
    states::ParticleContainer{T},
    observation;
    ref_state::Union{Nothing,AbstractVector{T}}=nothing,
    kwargs...,
) where {T}
    states.proposed, states.ancestors = resample(
        rng, filter.resampler, states.filtered, filter
    )
    states.proposed.particles = map(
        x -> SSMProblems.simulate(
            rng, model, filter.proposal, step, x, observation; kwargs...
        ),
        collect(states.proposed),
    )

    return update_ref!(states, ref_state, filter, step)
end

function update(
    model::StateSpaceModel{T},
    filter::GuidedFilter{N},
    step::Integer,
    states::ParticleContainer,
    observation;
    kwargs...,
) where {T,N}
    # this is a little messy and may require a deepcopy
    particle_collection = zip(
        collect(states.proposed), deepcopy(states.filtered.particles[states.ancestors])
    )

    # guided weight increment: transition density + likelihood - proposal density
    log_increments = map(particle_collection) do (new_state, prev_state)
        log_f = SSMProblems.logdensity(model.dyn, step, prev_state, new_state; kwargs...)
        log_g = SSMProblems.logdensity(model.obs, step, new_state, observation; kwargs...)
        log_q = SSMProblems.logdensity(
            model, filter.proposal, step, prev_state, new_state, observation; kwargs...
        )

        (log_f + log_g - log_q)
    end

    states.filtered.log_weights = states.proposed.log_weights + log_increments
    states.filtered.particles = states.proposed.particles

    return states, logmarginal(states, filter)
end

function step(
    rng::AbstractRNG,
    model::AbstractStateSpaceModel,
    alg::GuidedFilter,
    iter::Integer,
    state,
    observation;
    kwargs...,
)
    proposed_state = predict(rng, model, alg, iter, state, observation; kwargs...)
    filtered_state, ll = update(model, alg, iter, proposed_state, observation; kwargs...)

    return filtered_state, ll
end

function reset_weights!(state::ParticleState{T,WT}, idxs, ::GuidedFilter) where {T,WT<:Real}
    fill!(state.log_weights, zero(WT))
    return state
end

function logmarginal(states::ParticleContainer, ::GuidedFilter)
    # ratio of normalising constants between the filtered and proposed weights
    return logsumexp(states.filtered.log_weights) - logsumexp(states.proposed.log_weights)
end
93 changes: 91 additions & 2 deletions test/runtests.jl
@@ -109,9 +109,9 @@ end
    _, _, data = sample(rng, model, 20)

    bf = BF(2^12; threshold=0.8)
-   apf = APF(2^10, threshold=1.)
+   apf = APF(2^10; threshold=1.0)
    bf_state, llbf = GeneralisedFilters.filter(rng, model, bf, data)
-   _, llapf= GeneralisedFilters.filter(rng, model, apf, data)
+   _, llapf = GeneralisedFilters.filter(rng, model, apf, data)
    kf_state, llkf = GeneralisedFilters.filter(rng, model, KF(), data)

    xs = bf_state.filtered.particles
@@ -167,6 +167,95 @@ end
    @test llkf ≈ llbf atol = 2
end

@testitem "Guided particle filter test" begin
using GeneralisedFilters
using SSMProblems
using StableRNGs
using PDMats
using LinearAlgebra
using LogExpFunctions: softmax
using Random: randexp
using Distributions

# this is a pseudo optimal proposal kernel for linear Gaussian models
struct LinearGaussianProposal{T<:Real} <: GeneralisedFilters.AbstractProposal
φ::Vector{T}
end

# a lot of computations done at each step
Collaborator commented:
Perhaps more pressing, these computations are completed twice: once for the predict step and then again for update. It's not completely clear to me how to get around that, though.

We could potentially compute the proposal distribution before running the predict/update step and pass it in to each step, along the lines of the sketch below.
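A rough sketch of that idea (hypothetical, not part of this PR; it reuses the names from guidedpf.jl and omits ref_state handling):

# Hypothetical sketch: hoist resampling into `step`, build each particle's
# proposal distribution exactly once, and reuse it for both proposing and
# weighting, so SSMProblems.distribution runs once per particle per step.
function step(
    rng::AbstractRNG,
    model::StateSpaceModel,
    alg::GuidedFilter,
    iter::Integer,
    states::ParticleContainer,
    observation;
    kwargs...,
)
    states.proposed, states.ancestors = resample(
        rng, alg.resampler, states.filtered, alg
    )
    prev_particles = deepcopy(states.filtered.particles[states.ancestors])

    # one proposal distribution per particle, computed a single time
    dists = map(prev_particles) do prev_state
        SSMProblems.distribution(model, alg.proposal, iter, prev_state, observation; kwargs...)
    end

    # draw the new particles and weight them from the same cached distributions
    states.proposed.particles = map(d -> rand(rng, d), dists)
    log_increments = map(eachindex(dists)) do i
        new_state, prev_state = states.proposed.particles[i], prev_particles[i]
        log_f = SSMProblems.logdensity(model.dyn, iter, prev_state, new_state; kwargs...)
        log_g = SSMProblems.logdensity(model.obs, iter, new_state, observation; kwargs...)
        log_f + log_g - logpdf(dists[i], new_state)
    end

    states.filtered.log_weights = states.proposed.log_weights + log_increments
    states.filtered.particles = states.proposed.particles

    return states, logmarginal(states, alg)
end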

    function SSMProblems.distribution(
        model::AbstractStateSpaceModel,
        kernel::LinearGaussianProposal,
        step::Integer,
        state,
        observation;
        kwargs...,
    )
        # get model dimensions
        dx = length(state)
        dy = length(observation)

        # see (Corenflos et al., 2021) for details
        A = GeneralisedFilters.calc_A(model.dyn, step; kwargs...)
        Γ = diagm(dx, dy, kernel.φ[(dx + 1):end])
        Σ = PDiagMat(kernel.φ[1:dx])

        return MvNormal(inv(Σ) * A * state + inv(Σ) * Γ * observation, Σ)
    end

    T = Float32
    rng = StableRNG(1234)
    σx², σy² = randexp(rng, T, 2)

    # initial state distribution
    μ0 = zeros(T, 1)
    Σ0 = PDiagMat(T[1;])

    # state transition equation
    A = T[0.9;;]
    b = T[0;]
    Q = PDiagMat([σx²;])

    # observation equation
    H = T[1;;]
    c = T[0;]
    R = PDiagMat([σy²;])

    # proposal kernel (kind of optimal...)
    φ = ones(T, 2)
    proposal = LinearGaussianProposal(φ)

    # when working with PDMats, the Kalman filter doesn't play nicely without this
    function Base.convert(::Type{PDMat{T,MT}}, mat::MT) where {MT<:AbstractMatrix,T<:Real}
        return PDMat(Symmetric(mat))
    end

    model = create_homogeneous_linear_gaussian_model(μ0, Σ0, A, b, Q, H, c, R)
    _, _, data = sample(rng, model, 20)

    # bootstrap filter
    bf = BF(2^12; threshold=0.8)
    bf_state, llbf = GeneralisedFilters.filter(rng, model, bf, data)
    bf_xs = bf_state.filtered.particles
    bf_ws = softmax(bf_state.filtered.log_weights)

    # guided particle filter
    pf = GPF(2^12, proposal; threshold=0.8)
    pf_state, llpf = GeneralisedFilters.filter(rng, model, pf, data)
    pf_xs = pf_state.filtered.particles
    pf_ws = softmax(pf_state.filtered.log_weights)

    kf_state, llkf = GeneralisedFilters.filter(rng, model, KF(), data)

    # compare filtered states
    @test first(kf_state.μ) ≈ sum(first.(bf_xs) .* bf_ws) rtol = 1e-2
    @test first(kf_state.μ) ≈ sum(first.(pf_xs) .* pf_ws) rtol = 1e-2

    # since this is log valued, we can up the tolerance
    @test llkf ≈ llbf atol = 0.1
    @test llkf ≈ llpf atol = 2
end

@testitem "Forward algorithm test" begin
using GeneralisedFilters
using Distributions