-
Notifications
You must be signed in to change notification settings - Fork 3
Add (naive) FFBS algo #20
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: fred/auxiliary-particle-filter
Are you sure you want to change the base?
Changes from 1 commit
440cdc8
2c33c80
5198e3c
9254898
fa84e14
4b948db
556babd
20faf22
9f03987
899c507
51fd883
d4e2a40
756d1f0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,164 @@ | ||
export GuidedFilter, GPF, AbstractProposal | ||
|
||
## PROPOSALS ############################################################################### | ||
""" | ||
AbstractProposal | ||
""" | ||
abstract type AbstractProposal end | ||
|
||
function SSMProblems.distribution( | ||
model::AbstractStateSpaceModel, | ||
prop::AbstractProposal, | ||
step::Integer, | ||
state, | ||
observation; | ||
kwargs..., | ||
) | ||
return throw( | ||
MethodError( | ||
SSMProblems.distribution, (model, prop, step, state, observation, kwargs...) | ||
), | ||
) | ||
end | ||
|
||
function SSMProblems.simulate( | ||
rng::AbstractRNG, | ||
model::AbstractStateSpaceModel, | ||
prop::AbstractProposal, | ||
step::Integer, | ||
state, | ||
observation; | ||
kwargs..., | ||
) | ||
return rand( | ||
rng, SSMProblems.distribution(model, prop, step, state, observation; kwargs...) | ||
) | ||
end | ||
|
||
function SSMProblems.logdensity( | ||
model::AbstractStateSpaceModel, | ||
prop::AbstractProposal, | ||
step::Integer, | ||
prev_state, | ||
new_state, | ||
observation; | ||
kwargs..., | ||
) | ||
return logpdf( | ||
SSMProblems.distribution(model, prop, step, prev_state, observation; kwargs...), | ||
new_state, | ||
) | ||
end | ||
|
||
## GUIDED FILTERING ######################################################################## | ||
|
||
struct GuidedFilter{N,RS<:AbstractResampler,P<:AbstractProposal} <: | ||
AbstractParticleFilter{N} | ||
resampler::RS | ||
proposal::P | ||
end | ||
|
||
function GuidedFilter( | ||
N::Integer, proposal::P; threshold::Real=1.0, resampler::AbstractResampler=Systematic() | ||
) where {P<:AbstractProposal} | ||
conditional_resampler = ESSResampler(threshold, resampler) | ||
return GuidedFilter{N,typeof(conditional_resampler),P}(conditional_resampler, proposal) | ||
end | ||
|
||
"""Shorthand for `GuidedFilter`""" | ||
const GPF = GuidedFilter | ||
|
||
function initialise( | ||
rng::AbstractRNG, | ||
model::StateSpaceModel{T}, | ||
filter::GuidedFilter{N}; | ||
ref_state::Union{Nothing,AbstractVector}=nothing, | ||
kwargs..., | ||
) where {N,T} | ||
initial_states = map(x -> SSMProblems.simulate(rng, model.dyn; kwargs...), 1:N) | ||
initial_weights = zeros(T, N) | ||
|
||
return update_ref!( | ||
ParticleContainer(initial_states, initial_weights), ref_state, filter | ||
) | ||
end | ||
|
||
function predict( | ||
rng::AbstractRNG, | ||
model::StateSpaceModel, | ||
filter::GuidedFilter, | ||
step::Integer, | ||
states::ParticleContainer{T}, | ||
observation; | ||
ref_state::Union{Nothing,AbstractVector{T}}=nothing, | ||
kwargs..., | ||
) where {T} | ||
states.proposed, states.ancestors = resample( | ||
rng, filter.resampler, states.filtered, filter | ||
) | ||
states.proposed.particles = map( | ||
x -> SSMProblems.simulate( | ||
rng, model, filter.proposal, step, x, observation; kwargs... | ||
), | ||
collect(states.proposed), | ||
) | ||
|
||
return update_ref!(states, ref_state, filter, step) | ||
end | ||
|
||
function update( | ||
model::StateSpaceModel{T}, | ||
filter::GuidedFilter{N}, | ||
step::Integer, | ||
states::ParticleContainer, | ||
observation; | ||
kwargs..., | ||
) where {T,N} | ||
# this is a little messy and may require a deepcopy | ||
particle_collection = zip( | ||
collect(states.proposed), deepcopy(states.filtered.particles[states.ancestors]) | ||
) | ||
|
||
log_increments = map(particle_collection) do (new_state, prev_state) | ||
log_f = SSMProblems.logdensity(model.dyn, step, prev_state, new_state; kwargs...) | ||
log_g = SSMProblems.logdensity(model.obs, step, new_state, observation; kwargs...) | ||
log_q = SSMProblems.logdensity( | ||
model, filter.proposal, step, prev_state, new_state, observation; kwargs... | ||
) | ||
|
||
# println(log_f) | ||
|
||
(log_f + log_g - log_q) | ||
end | ||
|
||
# println(logsumexp(log_increments)) | ||
|
||
states.filtered.log_weights = states.proposed.log_weights + log_increments | ||
states.filtered.particles = states.proposed.particles | ||
|
||
return states, logmarginal(states, filter) | ||
end | ||
|
||
function step( | ||
rng::AbstractRNG, | ||
model::AbstractStateSpaceModel, | ||
alg::GuidedFilter, | ||
iter::Integer, | ||
state, | ||
observation; | ||
kwargs..., | ||
) | ||
proposed_state = predict(rng, model, alg, iter, state, observation; kwargs...) | ||
filtered_state, ll = update(model, alg, iter, proposed_state, observation; kwargs...) | ||
|
||
return filtered_state, ll | ||
end | ||
|
||
function reset_weights!(state::ParticleState{T,WT}, idxs, ::GuidedFilter) where {T,WT<:Real} | ||
fill!(state.log_weights, zero(WT)) | ||
return state | ||
end | ||
|
||
function logmarginal(states::ParticleContainer, ::GuidedFilter) | ||
return logsumexp(states.filtered.log_weights) - logsumexp(states.proposed.log_weights) | ||
end |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -109,9 +109,9 @@ end | |
_, _, data = sample(rng, model, 20) | ||
|
||
bf = BF(2^12; threshold=0.8) | ||
apf = APF(2^10, threshold=1.) | ||
apf = APF(2^10; threshold=1.0) | ||
bf_state, llbf = GeneralisedFilters.filter(rng, model, bf, data) | ||
_, llapf= GeneralisedFilters.filter(rng, model, apf, data) | ||
_, llapf = GeneralisedFilters.filter(rng, model, apf, data) | ||
kf_state, llkf = GeneralisedFilters.filter(rng, model, KF(), data) | ||
|
||
xs = bf_state.filtered.particles | ||
|
@@ -167,6 +167,95 @@ end | |
@test llkf ≈ llbf atol = 2 | ||
end | ||
|
||
@testitem "Guided particle filter test" begin | ||
using GeneralisedFilters | ||
using SSMProblems | ||
using StableRNGs | ||
using PDMats | ||
using LinearAlgebra | ||
using LogExpFunctions: softmax | ||
using Random: randexp | ||
using Distributions | ||
|
||
# this is a pseudo optimal proposal kernel for linear Gaussian models | ||
struct LinearGaussianProposal{T<:Real} <: GeneralisedFilters.AbstractProposal | ||
φ::Vector{T} | ||
end | ||
|
||
# a lot of computations done at each step | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Perhaps more pressing, these computations are completed twice, once for the predict step and the again for update. Not completely clear to me how to get around that though. We could potentially compute the proposal distribution before running the predict/update step and pass this in to each step. |
||
function SSMProblems.distribution( | ||
model::AbstractStateSpaceModel, | ||
kernel::LinearGaussianProposal, | ||
step::Integer, | ||
state, | ||
observation; | ||
kwargs..., | ||
) | ||
# get model dimensions | ||
dx = length(state) | ||
dy = length(observation) | ||
|
||
# see (Corenflos et al, 2021) for details | ||
A = GeneralisedFilters.calc_A(model.dyn, step; kwargs...) | ||
Γ = diagm(dx, dy, kernel.φ[(dx + 1):end]) | ||
Σ = PDiagMat(φ[1:dx]) | ||
|
||
return MvNormal(inv(Σ) * A * state + inv(Σ) * Γ * observation, Σ) | ||
end | ||
|
||
T = Float32 | ||
rng = StableRNG(1234) | ||
σx², σy² = randexp(rng, T, 2) | ||
|
||
# initial state distribution | ||
μ0 = zeros(T, 1) | ||
Σ0 = PDiagMat(T[1;]) | ||
|
||
# state transition equation | ||
A = T[0.9;;] | ||
b = T[0;] | ||
Q = PDiagMat([σx²;]) | ||
|
||
# observation equation | ||
H = T[1;;] | ||
c = T[0;] | ||
R = PDiagMat([σy²;]) | ||
|
||
# proposal kernel (kind of optimal...) | ||
φ = ones(T, 2) | ||
proposal = LinearGaussianProposal(φ) | ||
|
||
# when working with PDMats, the Kalman filter doesn't play nicely without this | ||
function Base.convert(::Type{PDMat{T,MT}}, mat::MT) where {MT<:AbstractMatrix,T<:Real} | ||
return PDMat(Symmetric(mat)) | ||
end | ||
|
||
model = create_homogeneous_linear_gaussian_model(μ0, Σ0, A, b, Q, H, c, R) | ||
_, _, data = sample(rng, model, 20) | ||
|
||
# bootstrap filter | ||
bf = BF(2^12; threshold=0.8) | ||
bf_state, llbf = GeneralisedFilters.filter(rng, model, bf, data) | ||
bf_xs = bf_state.filtered.particles | ||
bf_ws = softmax(bf_state.filtered.log_weights) | ||
|
||
# guided particle filter | ||
pf = GPF(2^12, proposal; threshold=0.8) | ||
pf_state, llpf = GeneralisedFilters.filter(rng, model, pf, data) | ||
pf_xs = pf_state.filtered.particles | ||
pf_ws = softmax(pf_state.filtered.log_weights) | ||
|
||
kf_state, llkf = GeneralisedFilters.filter(rng, model, KF(), data) | ||
|
||
# Compare filtered states | ||
@test first(kf_state.μ) ≈ sum(first.(bf_xs) .* bf_ws) rtol = 1e-2 | ||
@test first(kf_state.μ) ≈ sum(first.(pf_xs) .* pf_ws) rtol = 1e-2 | ||
|
||
# since this is log valued, we can up the tolerance | ||
@test llkf ≈ llbf atol = 0.1 | ||
@test llkf ≈ llpf atol = 2 | ||
end | ||
|
||
@testitem "Forward algorithm test" begin | ||
using GeneralisedFilters | ||
using Distributions | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think we should change the
SSMProblems
interface. The proposal should be part of thefilter
interface, maybe something along the lines of:And we should probably update the
filter/predict
functions:There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I 100% agree with the BF integration, I was intentionally working my way up to that, but didn't want to drastically change the interface upon the first commit.
And you're totally right about the SSMProblems integration. But it was convenient to recycle the structures.