Skip to content

Update APF to new interface #79

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion GeneralisedFilters/src/GFTest/models/linear_gaussian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,4 +61,4 @@ function _compute_joint(model, T::Integer)
Σ_Z = ((I - P) \ Σ_ϵ) / (I - P)'

return μ_Z, Σ_Z
end
end
68 changes: 58 additions & 10 deletions GeneralisedFilters/src/algorithms/particles.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
export BootstrapFilter, BF
export ParticleFilter, PF, AbstractProposal
export AuxiliaryParticleFilter, APF
export logeta

import SSMProblems: distribution, simulate, logdensity

Expand Down Expand Up @@ -47,23 +49,35 @@ function SSMProblems.logdensity(
)
end

function logeta end

abstract type AbstractParticleFilter <: AbstractFilter end

struct ParticleFilter{RS,PT} <: AbstractParticleFilter
N::Int
resampler::RS
proposal::PT
mutable struct ParticleFilter{RS,PT,WT} <: AbstractParticleFilter
const N::Int
const resampler::RS
const proposal::PT
aux::WT
end

const PF = ParticleFilter

function ParticleFilter(
N::Integer, proposal::PT; threshold::Real=1.0, resampler::AbstractResampler=Systematic()
) where {PT<:AbstractProposal}
N::Integer,
proposal::PT;
threshold::Real=1.0,
resampler::AbstractResampler=Systematic(),
aux::Union{Nothing,Vector{WT}}=nothing,
) where {PT<:AbstractProposal,WT}
conditional_resampler = ESSResampler(threshold, resampler)
return ParticleFilter{ESSResampler,PT}(N, conditional_resampler, proposal)
return ParticleFilter{ESSResampler,PT,typeof(aux)}(
N, conditional_resampler, proposal, aux
)
end

update_weights!(state, model, algo, iter, observation; kwargs...) = state
reset_weights!(state::ParticleDistribution, algo::ParticleFilter) = state # Wonky ... by construction of ParticleDistribution

function step(
rng::AbstractRNG,
model::AbstractStateSpaceModel,
Expand All @@ -76,7 +90,11 @@ function step(
kwargs...,
)
# capture the marginalized log-likelihood
# TODO: Add a presampling step for the auxiliary particle filter
state = update_weights!(state, model, algo, iter, observation; kwargs...)
state = resample(rng, algo.resampler, state; ref_state)
state = reset_weights!(state, algo) # Reset weights if needed

marginalization_term = logsumexp(state.log_weights)
isnothing(callback) ||
callback(model, algo, iter, state, observation, PostResample; kwargs...)
Expand All @@ -96,7 +114,7 @@ end
function initialise(
rng::AbstractRNG,
model::StateSpaceModel{T},
filter::ParticleFilter;
filter::AbstractParticleFilter;
ref_state::Union{Nothing,AbstractVector}=nothing,
kwargs...,
) where {T}
Expand All @@ -115,7 +133,7 @@ end
function predict(
rng::AbstractRNG,
model::StateSpaceModel,
filter::ParticleFilter,
filter::AbstractParticleFilter,
iter::Integer,
state::ParticleDistribution,
observation;
Expand Down Expand Up @@ -150,7 +168,7 @@ end

function update(
model::StateSpaceModel{T},
filter::ParticleFilter,
filter::AbstractParticleFilter,
iter::Integer,
state::ParticleDistribution,
observation;
Expand Down Expand Up @@ -233,3 +251,33 @@ function filter(
)
return filter(rng, ssm, algo, observations; ref_state=ref_state, kwargs...)
end

### Auxiliary particle filter
const AuxiliaryParticleFilter{RS,P,WT} = ParticleFilter{RS,P,Vector{WT}}
const APF = AuxiliaryParticleFilter

function APF(N::Int; kwargs...)
ParticleFilter(N, LatentProposal(); aux=zeros(Float64, N), kwargs...)
end

function update_weights!(
state::ParticleDistribution,
model::StateSpaceModel,
algo::AuxiliaryParticleFilter,
step::Int,
observation;
kwargs...,
)
# TODO: Can we dispatch on model capabilities maybe ?
auxiliary_log_weights = map(enumerate(state.particles)) do (i, particle)
logeta(particle, model, step, observation; kwargs...)
end
algo.aux = auxiliary_log_weights
state.log_weights += auxiliary_log_weights
return state
end

function reset_weights!(state::ParticleDistribution, algo::AuxiliaryParticleFilter)
state.log_weights = state.log_weights - algo.aux[state.ancestors]
return state
end
Loading