From 6bcfd7cca2d236684c3e2de73d5af7dfad217183 Mon Sep 17 00:00:00 2001 From: FredericWantiez Date: Wed, 16 Apr 2025 21:52:54 +0100 Subject: [PATCH 1/3] Add APF --- .../src/GFTest/models/linear_gaussian.jl | 2 +- .../src/algorithms/particles.jl | 59 ++++++++++++++++++- 2 files changed, 57 insertions(+), 4 deletions(-) diff --git a/GeneralisedFilters/src/GFTest/models/linear_gaussian.jl b/GeneralisedFilters/src/GFTest/models/linear_gaussian.jl index 8c4fff6..c877055 100644 --- a/GeneralisedFilters/src/GFTest/models/linear_gaussian.jl +++ b/GeneralisedFilters/src/GFTest/models/linear_gaussian.jl @@ -61,4 +61,4 @@ function _compute_joint(model, T::Integer) Σ_Z = ((I - P) \ Σ_ϵ) / (I - P)' return μ_Z, Σ_Z -end \ No newline at end of file +end diff --git a/GeneralisedFilters/src/algorithms/particles.jl b/GeneralisedFilters/src/algorithms/particles.jl index 879e157..537ad7b 100644 --- a/GeneralisedFilters/src/algorithms/particles.jl +++ b/GeneralisedFilters/src/algorithms/particles.jl @@ -1,5 +1,7 @@ export BootstrapFilter, BF export ParticleFilter, PF, AbstractProposal +export AuxiliaryParticleFilter, APF +export logeta import SSMProblems: distribution, simulate, logdensity @@ -47,6 +49,8 @@ function SSMProblems.logdensity( ) end +function logeta end + abstract type AbstractParticleFilter <: AbstractFilter end struct ParticleFilter{RS,PT} <: AbstractParticleFilter @@ -76,7 +80,12 @@ function step( kwargs..., ) # capture the marginalized log-likelihood + # TODO: Add a presampling step for the auxiliary particle filter + println(typeof(algo)) + update_weights!(state, model, algo, iter, observation; kwargs...) state = resample(rng, algo.resampler, state; ref_state) + reset_weights!(state, algo) # Reset weights if needed + marginalization_term = logsumexp(state.log_weights) isnothing(callback) || callback(model, algo, iter, state, observation, PostResample; kwargs...) @@ -96,7 +105,7 @@ end function initialise( rng::AbstractRNG, model::StateSpaceModel{T}, - filter::ParticleFilter; + filter::AbstractParticleFilter; ref_state::Union{Nothing,AbstractVector}=nothing, kwargs..., ) where {T} @@ -115,7 +124,7 @@ end function predict( rng::AbstractRNG, model::StateSpaceModel, - filter::ParticleFilter, + filter::AbstractParticleFilter, iter::Integer, state::ParticleDistribution, observation; @@ -150,7 +159,7 @@ end function update( model::StateSpaceModel{T}, - filter::ParticleFilter, + filter::AbstractParticleFilter, iter::Integer, state::ParticleDistribution, observation; @@ -233,3 +242,47 @@ function filter( ) return filter(rng, ssm, algo, observations; ref_state=ref_state, kwargs...) end + +### AuxiliaryParticleFilter +mutable struct AuxiliaryParticleFilter{RS,P,WT} <: AbstractParticleFilter + const N::Int + const resampler::RS + proposal::P + aux::Array{WT} +end + +const APF = AuxiliaryParticleFilter + +# TODO Need to think more about that +function AuxiliaryParticleFilter( + N::Integer, proposal::PT; threshold::Real=1.0, resampler::AbstractResampler=Systematic() +) where {PT<:AbstractProposal} + conditional_resampler = ESSResampler(threshold, resampler) + aux = zeros(Float64, N) + return AuxiliaryParticleFilter{ESSResampler,PT,Float64}(N, conditional_resampler, proposal, aux) +end + +APF(N::Int; kwargs...) = AuxiliaryParticleFilter(N, LatentProposal(); kwargs...) + +update_weights!(state, model, algo, iter, observation; kwargs...) = state + +function update_weights!( + state::ParticleDistribution, + model::StateSpaceModel, + algo::AuxiliaryParticleFilter, + step::Int, + observation; + kwargs..., +) + # TODO: Can we dispatch on model capabilities maybe ? + auxiliary_log_weights = map(enumerate(state.particles)) do (i, particle) + logeta(particle, model, step, observation; kwargs...) + end + algo.aux = auxiliary_log_weights + state.log_weights += auxiliary_log_weights +end + +function reset_weights!(state::ParticleDistribution, algo::AuxiliaryParticleFilter) + state.log_weights = state.log_weights[state.ancestors] - algo.aux[state.ancestors] +end +reset_weights!(state::ParticleDistribution, algo::ParticleFilter) = state # Wonky, construction of ParticleDistribution From 658474dbd32d0083cec035415eed2c1b14a0d07b Mon Sep 17 00:00:00 2001 From: FredericWantiez Date: Thu, 17 Apr 2025 19:08:22 +0100 Subject: [PATCH 2/3] Remove extra type --- .../src/algorithms/particles.jl | 56 +++++++++---------- 1 file changed, 25 insertions(+), 31 deletions(-) diff --git a/GeneralisedFilters/src/algorithms/particles.jl b/GeneralisedFilters/src/algorithms/particles.jl index 537ad7b..c30f74a 100644 --- a/GeneralisedFilters/src/algorithms/particles.jl +++ b/GeneralisedFilters/src/algorithms/particles.jl @@ -53,21 +53,32 @@ function logeta end abstract type AbstractParticleFilter <: AbstractFilter end -struct ParticleFilter{RS,PT} <: AbstractParticleFilter - N::Int - resampler::RS - proposal::PT +mutable struct ParticleFilter{RS,PT,WT} <: AbstractParticleFilter + const N::Int + const resampler::RS + const proposal::PT + aux::Union{Nothing,WT} end const PF = ParticleFilter function ParticleFilter( - N::Integer, proposal::PT; threshold::Real=1.0, resampler::AbstractResampler=Systematic() + N::Integer, + proposal::PT; + threshold::Real=1.0, + resampler::AbstractResampler=Systematic(), + WT::Type=Float64, ) where {PT<:AbstractProposal} conditional_resampler = ESSResampler(threshold, resampler) - return ParticleFilter{ESSResampler,PT}(N, conditional_resampler, proposal) + aux = zeros(WT, N) + return ParticleFilter{ESSResampler,PT,typeof(aux)}( + N, conditional_resampler, proposal, aux + ) end +update_weights!(state, model, algo, iter, observation; kwargs...) = state +reset_weights!(state::ParticleDistribution, algo::ParticleFilter) = state # Wonky ... by construction of ParticleDistribution + function step( rng::AbstractRNG, model::AbstractStateSpaceModel, @@ -81,10 +92,9 @@ function step( ) # capture the marginalized log-likelihood # TODO: Add a presampling step for the auxiliary particle filter - println(typeof(algo)) - update_weights!(state, model, algo, iter, observation; kwargs...) + state = update_weights!(state, model, algo, iter, observation; kwargs...) state = resample(rng, algo.resampler, state; ref_state) - reset_weights!(state, algo) # Reset weights if needed + state = reset_weights!(state, algo) # Reset weights if needed marginalization_term = logsumexp(state.log_weights) isnothing(callback) || @@ -243,28 +253,11 @@ function filter( return filter(rng, ssm, algo, observations; ref_state=ref_state, kwargs...) end -### AuxiliaryParticleFilter -mutable struct AuxiliaryParticleFilter{RS,P,WT} <: AbstractParticleFilter - const N::Int - const resampler::RS - proposal::P - aux::Array{WT} -end - +### Auxiliary particle filter +const AuxiliaryParticleFilter{RS,P,WT} = ParticleFilter{RS,P,Array{WT}} const APF = AuxiliaryParticleFilter -# TODO Need to think more about that -function AuxiliaryParticleFilter( - N::Integer, proposal::PT; threshold::Real=1.0, resampler::AbstractResampler=Systematic() -) where {PT<:AbstractProposal} - conditional_resampler = ESSResampler(threshold, resampler) - aux = zeros(Float64, N) - return AuxiliaryParticleFilter{ESSResampler,PT,Float64}(N, conditional_resampler, proposal, aux) -end - -APF(N::Int; kwargs...) = AuxiliaryParticleFilter(N, LatentProposal(); kwargs...) - -update_weights!(state, model, algo, iter, observation; kwargs...) = state +APF(N::Int; kwargs...) = ParticleFilter(N, LatentProposal(); kwargs...) function update_weights!( state::ParticleDistribution, @@ -280,9 +273,10 @@ function update_weights!( end algo.aux = auxiliary_log_weights state.log_weights += auxiliary_log_weights + return state end function reset_weights!(state::ParticleDistribution, algo::AuxiliaryParticleFilter) - state.log_weights = state.log_weights[state.ancestors] - algo.aux[state.ancestors] + state.log_weights = state.log_weights - algo.aux[state.ancestors] + return state end -reset_weights!(state::ParticleDistribution, algo::ParticleFilter) = state # Wonky, construction of ParticleDistribution From 8c3aaee49293a54d091072250b8b60a1b3b52889 Mon Sep 17 00:00:00 2001 From: FredericWantiez Date: Thu, 17 Apr 2025 19:41:00 +0100 Subject: [PATCH 3/3] Type piracy --- GeneralisedFilters/src/algorithms/particles.jl | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/GeneralisedFilters/src/algorithms/particles.jl b/GeneralisedFilters/src/algorithms/particles.jl index c30f74a..e23eed6 100644 --- a/GeneralisedFilters/src/algorithms/particles.jl +++ b/GeneralisedFilters/src/algorithms/particles.jl @@ -57,7 +57,7 @@ mutable struct ParticleFilter{RS,PT,WT} <: AbstractParticleFilter const N::Int const resampler::RS const proposal::PT - aux::Union{Nothing,WT} + aux::WT end const PF = ParticleFilter @@ -67,10 +67,9 @@ function ParticleFilter( proposal::PT; threshold::Real=1.0, resampler::AbstractResampler=Systematic(), - WT::Type=Float64, -) where {PT<:AbstractProposal} + aux::Union{Nothing,Vector{WT}}=nothing, +) where {PT<:AbstractProposal,WT} conditional_resampler = ESSResampler(threshold, resampler) - aux = zeros(WT, N) return ParticleFilter{ESSResampler,PT,typeof(aux)}( N, conditional_resampler, proposal, aux ) @@ -254,10 +253,12 @@ function filter( end ### Auxiliary particle filter -const AuxiliaryParticleFilter{RS,P,WT} = ParticleFilter{RS,P,Array{WT}} +const AuxiliaryParticleFilter{RS,P,WT} = ParticleFilter{RS,P,Vector{WT}} const APF = AuxiliaryParticleFilter -APF(N::Int; kwargs...) = ParticleFilter(N, LatentProposal(); kwargs...) +function APF(N::Int; kwargs...) + ParticleFilter(N, LatentProposal(); aux=zeros(Float64, N), kwargs...) +end function update_weights!( state::ParticleDistribution,