diff --git a/GeneralisedFilters/src/GFTest/models/linear_gaussian.jl b/GeneralisedFilters/src/GFTest/models/linear_gaussian.jl index 8c4fff6..c877055 100644 --- a/GeneralisedFilters/src/GFTest/models/linear_gaussian.jl +++ b/GeneralisedFilters/src/GFTest/models/linear_gaussian.jl @@ -61,4 +61,4 @@ function _compute_joint(model, T::Integer) Σ_Z = ((I - P) \ Σ_ϵ) / (I - P)' return μ_Z, Σ_Z -end \ No newline at end of file +end diff --git a/GeneralisedFilters/src/algorithms/particles.jl b/GeneralisedFilters/src/algorithms/particles.jl index 879e157..e23eed6 100644 --- a/GeneralisedFilters/src/algorithms/particles.jl +++ b/GeneralisedFilters/src/algorithms/particles.jl @@ -1,5 +1,7 @@ export BootstrapFilter, BF export ParticleFilter, PF, AbstractProposal +export AuxiliaryParticleFilter, APF +export logeta import SSMProblems: distribution, simulate, logdensity @@ -47,23 +49,35 @@ function SSMProblems.logdensity( ) end +function logeta end + abstract type AbstractParticleFilter <: AbstractFilter end -struct ParticleFilter{RS,PT} <: AbstractParticleFilter - N::Int - resampler::RS - proposal::PT +mutable struct ParticleFilter{RS,PT,WT} <: AbstractParticleFilter + const N::Int + const resampler::RS + const proposal::PT + aux::WT end const PF = ParticleFilter function ParticleFilter( - N::Integer, proposal::PT; threshold::Real=1.0, resampler::AbstractResampler=Systematic() -) where {PT<:AbstractProposal} + N::Integer, + proposal::PT; + threshold::Real=1.0, + resampler::AbstractResampler=Systematic(), + aux::Union{Nothing,Vector{WT}}=nothing, +) where {PT<:AbstractProposal,WT} conditional_resampler = ESSResampler(threshold, resampler) - return ParticleFilter{ESSResampler,PT}(N, conditional_resampler, proposal) + return ParticleFilter{ESSResampler,PT,typeof(aux)}( + N, conditional_resampler, proposal, aux + ) end +update_weights!(state, model, algo, iter, observation; kwargs...) = state +reset_weights!(state::ParticleDistribution, algo::ParticleFilter) = state # Wonky ... by construction of ParticleDistribution + function step( rng::AbstractRNG, model::AbstractStateSpaceModel, @@ -76,7 +90,11 @@ function step( kwargs..., ) # capture the marginalized log-likelihood + # TODO: Add a presampling step for the auxiliary particle filter + state = update_weights!(state, model, algo, iter, observation; kwargs...) state = resample(rng, algo.resampler, state; ref_state) + state = reset_weights!(state, algo) # Reset weights if needed + marginalization_term = logsumexp(state.log_weights) isnothing(callback) || callback(model, algo, iter, state, observation, PostResample; kwargs...) @@ -96,7 +114,7 @@ end function initialise( rng::AbstractRNG, model::StateSpaceModel{T}, - filter::ParticleFilter; + filter::AbstractParticleFilter; ref_state::Union{Nothing,AbstractVector}=nothing, kwargs..., ) where {T} @@ -115,7 +133,7 @@ end function predict( rng::AbstractRNG, model::StateSpaceModel, - filter::ParticleFilter, + filter::AbstractParticleFilter, iter::Integer, state::ParticleDistribution, observation; @@ -150,7 +168,7 @@ end function update( model::StateSpaceModel{T}, - filter::ParticleFilter, + filter::AbstractParticleFilter, iter::Integer, state::ParticleDistribution, observation; @@ -233,3 +251,33 @@ function filter( ) return filter(rng, ssm, algo, observations; ref_state=ref_state, kwargs...) end + +### Auxiliary particle filter +const AuxiliaryParticleFilter{RS,P,WT} = ParticleFilter{RS,P,Vector{WT}} +const APF = AuxiliaryParticleFilter + +function APF(N::Int; kwargs...) + ParticleFilter(N, LatentProposal(); aux=zeros(Float64, N), kwargs...) +end + +function update_weights!( + state::ParticleDistribution, + model::StateSpaceModel, + algo::AuxiliaryParticleFilter, + step::Int, + observation; + kwargs..., +) + # TODO: Can we dispatch on model capabilities maybe ? + auxiliary_log_weights = map(enumerate(state.particles)) do (i, particle) + logeta(particle, model, step, observation; kwargs...) + end + algo.aux = auxiliary_log_weights + state.log_weights += auxiliary_log_weights + return state +end + +function reset_weights!(state::ParticleDistribution, algo::AuxiliaryParticleFilter) + state.log_weights = state.log_weights - algo.aux[state.ancestors] + return state +end