Skip to content

Commit bd2e22f

Browse files
Split trajectory
1 parent 2cb12ef commit bd2e22f

File tree

1 file changed

+40
-22
lines changed

1 file changed

+40
-22
lines changed

Diff for: src/algorithms/ffbs.jl

+40-22
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,43 @@ function (callback::WeightedParticleRecorderCallback)(
2525
return nothing
2626
end
2727

28+
function gen_trajectory(
29+
rng::Random.AbstractRNG,
30+
model::StateSpaceModel,
31+
particles::AbstractMatrix{T}, # Need better container
32+
log_weights::AbstractMatrix{WT},
33+
forward_state,
34+
n_timestep::Int;
35+
kwargs...
36+
) where {T,WT}
37+
trajectory = Vector{T}(undef, n_timestep)
38+
trajectory[end] = forward_state
39+
for step in (n_timestep - 1):-1:1
40+
backward_weights = backward(
41+
model,
42+
step,
43+
trajectory[step + 1],
44+
particles[step, :],
45+
log_weights[step, :];
46+
kwargs...,
47+
)
48+
ancestor = rand(rng, Categorical(softmax(backward_weights)))
49+
trajectory[step] = particles[step, ancestor]
50+
end
51+
return trajectory
52+
end
53+
54+
55+
function backward(
56+
model::StateSpaceModel, step::Integer, state, particles::T, log_weights::WT; kwargs...
57+
) where {T,WT}
58+
transitions = map(particles) do prev_state
59+
SSMProblems.logdensity(model.dyn, step, prev_state, state; kwargs...)
60+
end
61+
return log_weights + transitions
62+
end
63+
64+
2865
function sample(
2966
rng::Random.AbstractRNG,
3067
model::StateSpaceModel{T,LDT},
@@ -34,6 +71,7 @@ function sample(
3471
callback=nothing,
3572
kwargs...,
3673
) where {T,LDT,N}
74+
3775
n_timestep = length(obs)
3876
recorder = WeightedParticleRecorderCallback(
3977
Array{eltype(model.dyn)}(undef, n_timestep, N), Array{T}(undef, n_timestep, N)
@@ -46,28 +84,8 @@ function sample(
4684
trajectories = Array{eltype(model.dyn)}(undef, n_timestep, M)
4785

4886
trajectories[end, :] = particles.filtered[idx_ref]
49-
for step in (n_timestep - 1):-1:1
50-
for j in 1:M
51-
backward_weights = backward(
52-
model::StateSpaceModel,
53-
step,
54-
trajectories[step + 1],
55-
recorder.particles[step, :],
56-
recorder.log_weights[step, :];
57-
kwargs...,
58-
)
59-
ancestor = rand(rng, Categorical(softmax(backward_weights)))
60-
trajectories[step, j] = recorder.particles[step, ancestor]
61-
end
87+
for j in 1:M
88+
trajectories[:, j] = gen_trajectory(rng, model, recorder.particles, recorder.log_weights, trajectories[end, j], n_timestep)
6289
end
6390
return trajectories
6491
end
65-
66-
function backward(
67-
model::StateSpaceModel, step::Integer, state, particles::T, log_weights::WT; kwargs...
68-
) where {T,WT}
69-
transitions = map(
70-
x -> SSMProblems.logdensity(model.dyn, step, x, state; kwargs...), particles
71-
)
72-
return log_weights + transitions
73-
end

0 commit comments

Comments
 (0)