-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathapf.jl
120 lines (103 loc) · 3.69 KB
/
apf.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
export AuxiliaryParticleFilter, APF
"""
    AuxiliaryParticleFilter{N,RS} <: AbstractParticleFilter{N}

Auxiliary particle filter (APF).

`N` is a type parameter; the constructor and `initialise` in this file use it as the
number of particles.

# Fields
- `resampler::RS`: the conditional resampler (`RS <: AbstractConditionalResampler`)
  handed to `resample` during `predict`.
- `aux::Vector`: auxiliary log-weights. `predict` overwrites this field on every call
  (which is why the struct is mutable) and `reset_weights!` subtracts the entries
  selected by the resampling indices.
"""
mutable struct AuxiliaryParticleFilter{
    N,RS<:AbstractConditionalResampler
} <: AbstractParticleFilter{N}
    # Conditional resampling scheme used in `predict`.
    resampler::RS
    # Auxiliary weights from the most recent `predict` call.
    aux::Vector
end
"""
    AuxiliaryParticleFilter(N::Integer; threshold=0.0, resampler=Systematic())

Construct an auxiliary particle filter with type parameter `N`.

# Keywords
- `threshold::Real=0.0`: forwarded as the first argument of `ESSResampler`
  (presumably the effective-sample-size threshold that triggers resampling; confirm
  against `ESSResampler`).
- `resampler::AbstractResampler=Systematic()`: the base resampler wrapped by
  `ESSResampler`.

# Returns
An `AuxiliaryParticleFilter{N,RS}` whose `aux` field is `zeros(N)`, where `RS` is the
concrete type of the wrapping `ESSResampler`.
"""
function AuxiliaryParticleFilter(
    N::Integer; threshold::Real=0.0, resampler::AbstractResampler=Systematic()
)
    wrapped = ESSResampler(threshold, resampler)
    RS = typeof(wrapped)
    return AuxiliaryParticleFilter{N,RS}(wrapped, zeros(N))
end
# Short alias for `AuxiliaryParticleFilter`; exported alongside the full name.
const APF = AuxiliaryParticleFilter
"""
    initialise(rng, model, filter::AuxiliaryParticleFilter; ref_state=nothing, kwargs...)
    initialise(rng, model, filter::AuxiliaryParticleFilter, ref_state; kwargs...)

Draw the initial particle set for an auxiliary particle filter.

One state per particle (`N` of them, taken from the filter's type parameter) is drawn
with `SSMProblems.simulate(rng, model.dyn; kwargs...)`, and all log-weights start at
`zero(T)`. The resulting `ParticleContainer` is passed through `update_ref!` together
with `ref_state` and `filter`, and that call's result is returned.

# Arguments
- `rng::AbstractRNG`: random number generator used for simulation.
- `model::StateSpaceModel{T}`: the model; only `model.dyn` is read here.
- `filter::AuxiliaryParticleFilter{N}`: supplies the particle count `N`.

# Keywords
- `ref_state::Union{Nothing,AbstractVector}=nothing`: forwarded to `update_ref!`.
  It may also be given as a fourth positional argument.
- `kwargs...`: forwarded to `SSMProblems.simulate`.
"""
function initialise(
    rng::AbstractRNG,
    model::StateSpaceModel{T},
    filter::AuxiliaryParticleFilter{N};
    ref_state::Union{Nothing,AbstractVector}=nothing,
    kwargs...,
) where {N,T}
    # The index is irrelevant: each call draws an independent initial state.
    initial_states = map(_ -> SSMProblems.simulate(rng, model.dyn; kwargs...), 1:N)
    initial_weights = zeros(T, N)
    return update_ref!(
        ParticleContainer(initial_states, initial_weights), ref_state, filter
    )
end

# Positional `ref_state` form, kept so callers that pass it positionally still work.
function initialise(
    rng::AbstractRNG,
    model::StateSpaceModel{T},
    filter::AuxiliaryParticleFilter{N},
    ref_state::Union{Nothing,AbstractVector};
    kwargs...,
) where {N,T}
    return initialise(rng, model, filter; ref_state=ref_state, kwargs...)
end
"""
    update_weights!(rng, filter, model, step, states, observation; kwargs...)

Add the weights returned by `eta(rng, model, step, states, observation)` to
`states.log_weights`.

The field is rebound to a freshly computed sum (not updated element-wise in place), and
that new weight vector, not `states`, is returned.

`filter` and `kwargs` are accepted but not used by this method.
NOTE(review): `eta` is not defined in the visible part of this file; confirm where it
comes from before relying on this function.
"""
function update_weights!(
    rng::AbstractRNG, filter, model, step, states, observation; kwargs...
)
    increments = eta(rng, model, step, states, observation)
    updated = states.log_weights + increments
    states.log_weights = updated
    return updated
end
"""
    predict(rng, model, filter::AuxiliaryParticleFilter, step, states, observation;
            ref_state=nothing, kwargs...)

Auxiliary-weighting, resampling and propagation stage of the auxiliary particle filter.

In order, this method:
1. takes, for every filtered particle, the mean of
   `SSMProblems.distribution(model.dyn, step, particle; kwargs...)`;
2. evaluates `SSMProblems.logdensity(model.obs, step, ·, observation; kwargs...)` at
   those means to obtain the auxiliary log-weights;
3. adds them in place to `states.filtered.log_weights` and stores them in `filter.aux`;
4. calls `resample(rng, filter.resampler, states.filtered, filter)` and writes the two
   results to `states.proposed` and `states.ancestors`;
5. replaces every proposed particle by a draw from
   `SSMProblems.simulate(rng, model.dyn, step, particle; kwargs...)`;
6. returns `update_ref!(states, ref_state, filter, step)`.

# Keywords
- `ref_state::Union{Nothing,AbstractVector{T}}=nothing`: forwarded to `update_ref!`.
- `kwargs...`: forwarded to the `SSMProblems` calls listed above.
"""
function predict(
    rng::AbstractRNG,
    model::StateSpaceModel,
    filter::AuxiliaryParticleFilter,
    step::Integer,
    states::ParticleContainer{T},
    observation;
    ref_state::Union{Nothing,AbstractVector{T}}=nothing,
    kwargs...,
) where {T}
    # Proof of concept: the predictive likelihood is approximated in the simplest way,
    # by scoring the observation at the mean of each particle's transition distribution.
    # Ideally this would be delegated to something like `update_weights!(filter, ...)`.
    lookahead = map(states.filtered.particles) do particle
        mean(SSMProblems.distribution(model.dyn, step, particle; kwargs...))
    end
    aux_log_weights = map(lookahead) do point
        SSMProblems.logdensity(model.obs, step, point, observation; kwargs...)
    end

    # Fold the auxiliary weights into the filtered weights before resampling, and keep
    # them on the filter so `reset_weights!` can subtract them afterwards.
    states.filtered.log_weights .+= aux_log_weights
    filter.aux = aux_log_weights

    resampled, ancestors = resample(rng, filter.resampler, states.filtered, filter)
    states.proposed = resampled
    states.ancestors = ancestors

    # Propagate each resampled particle through the latent dynamics.
    states.proposed.particles = map(states.proposed.particles) do particle
        SSMProblems.simulate(rng, model.dyn, step, particle; kwargs...)
    end

    return update_ref!(states, ref_state, filter, step)
end
"""
    update(model, filter::AuxiliaryParticleFilter, step, states, observation; kwargs...)

Reweight the proposed particles by the observation log-density.

For every particle in `states.proposed`, the increment
`SSMProblems.logdensity(model.obs, step, particle, observation; kwargs...)` is computed.
`states.filtered.log_weights` is set to the proposed log-weights plus these increments,
and `states.filtered.particles` is set to the proposed particles (the same object, not
a copy).

# Returns
A tuple `(states, ll)` where `ll = logmarginal(states, filter)`.
"""
function update(
    model::StateSpaceModel{T},
    filter::AuxiliaryParticleFilter{N},
    step::Integer,
    states::ParticleContainer,
    observation;
    kwargs...,
) where {T,N}
    @debug "step $step"

    proposal = states.proposed
    log_increments = map(collect(proposal.particles)) do particle
        SSMProblems.logdensity(model.obs, step, particle, observation; kwargs...)
    end

    states.filtered.log_weights = proposal.log_weights + log_increments
    states.filtered.particles = proposal.particles

    return states, logmarginal(states, filter)
end
"""
    step(rng, model, alg::AuxiliaryParticleFilter, iter, state, observation; kwargs...)

Run one full filtering iteration: `predict` followed by `update`.

All keyword arguments are forwarded unchanged to both stages.

# Returns
A tuple `(filtered_state, ll)` as produced by `update`.
"""
function step(
    rng::AbstractRNG,
    model::AbstractStateSpaceModel,
    alg::AuxiliaryParticleFilter,
    iter::Integer,
    state,
    observation;
    kwargs...,
)
    # Stage 1: auxiliary weighting, resampling and propagation.
    predicted = predict(rng, model, alg, iter, state, observation; kwargs...)
    # Stage 2: reweight against the observation.
    posterior, log_evidence = update(model, alg, iter, predicted, observation; kwargs...)
    return posterior, log_evidence
end
"""
    reset_weights!(state, idxs, filter::AuxiliaryParticleFilter)

Reset the log-weights of `state` after resampling with indices `idxs`.

The new log-weights are the old ones selected by `idxs`, minus the auxiliary weights
`filter.aux` selected by the same indices. `state.log_weights` is rebound to the new
vector and `state` is returned.
"""
function reset_weights!(
    state::ParticleState{T,WT}, idxs, filter::AuxiliaryParticleFilter
) where {T,WT<:Real}
    # From Chopin: An Introduction to Sequential Monte Carlo, section 10.3.3
    selected = state.log_weights[idxs]
    correction = filter.aux[idxs]
    state.log_weights = selected - correction
    return state
end
"""
    logmarginal(states::ParticleContainer, ::AuxiliaryParticleFilter)

Return the log-marginal increment for the current step, computed as
`logsumexp` of the filtered log-weights minus `logsumexp` of the proposed log-weights.
"""
function logmarginal(states::ParticleContainer, ::AuxiliaryParticleFilter)
    log_norm_filtered = logsumexp(states.filtered.log_weights)
    log_norm_proposed = logsumexp(states.proposed.log_weights)
    return log_norm_filtered - log_norm_proposed
end