@@ -25,6 +25,43 @@ function (callback::WeightedParticleRecorderCallback)(
25
25
return nothing
26
26
end
27
27
28
+ function gen_trajectory (
29
+ rng:: Random.AbstractRNG ,
30
+ model:: StateSpaceModel ,
31
+ particles:: AbstractMatrix{T} , # Need better container
32
+ log_weights:: AbstractMatrix{WT} ,
33
+ forward_state,
34
+ n_timestep:: Int ;
35
+ kwargs...
36
+ ) where {T,WT}
37
+ trajectory = Vector {T} (undef, n_timestep)
38
+ trajectory[end ] = forward_state
39
+ for step in (n_timestep - 1 ): - 1 : 1
40
+ backward_weights = backward (
41
+ model,
42
+ step,
43
+ trajectory[step + 1 ],
44
+ particles[step, :],
45
+ log_weights[step, :];
46
+ kwargs... ,
47
+ )
48
+ ancestor = rand (rng, Categorical (softmax (backward_weights)))
49
+ trajectory[step] = particles[step, ancestor]
50
+ end
51
+ return trajectory
52
+ end
53
+
54
+
55
+ function backward (
56
+ model:: StateSpaceModel , step:: Integer , state, particles:: T , log_weights:: WT ; kwargs...
57
+ ) where {T,WT}
58
+ transitions = map (particles) do prev_state
59
+ SSMProblems. logdensity (model. dyn, step, prev_state, state; kwargs... )
60
+ end
61
+ return log_weights + transitions
62
+ end
63
+
64
+
28
65
function sample (
29
66
rng:: Random.AbstractRNG ,
30
67
model:: StateSpaceModel{T,LDT} ,
@@ -34,6 +71,7 @@ function sample(
34
71
callback= nothing ,
35
72
kwargs... ,
36
73
) where {T,LDT,N}
74
+
37
75
n_timestep = length (obs)
38
76
recorder = WeightedParticleRecorderCallback (
39
77
Array {eltype(model.dyn)} (undef, n_timestep, N), Array {T} (undef, n_timestep, N)
@@ -46,28 +84,8 @@ function sample(
46
84
trajectories = Array {eltype(model.dyn)} (undef, n_timestep, M)
47
85
48
86
trajectories[end , :] = particles. filtered[idx_ref]
49
- for step in (n_timestep - 1 ): - 1 : 1
50
- for j in 1 : M
51
- backward_weights = backward (
52
- model:: StateSpaceModel ,
53
- step,
54
- trajectories[step + 1 ],
55
- recorder. particles[step, :],
56
- recorder. log_weights[step, :];
57
- kwargs... ,
58
- )
59
- ancestor = rand (rng, Categorical (softmax (backward_weights)))
60
- trajectories[step, j] = recorder. particles[step, ancestor]
61
- end
87
+ for j in 1 : M
88
+ trajectories[:, j] = gen_trajectory (rng, model, recorder. particles, recorder. log_weights, trajectories[end , j], n_timestep)
62
89
end
63
90
return trajectories
64
91
end
65
-
66
- function backward (
67
- model:: StateSpaceModel , step:: Integer , state, particles:: T , log_weights:: WT ; kwargs...
68
- ) where {T,WT}
69
- transitions = map (
70
- x -> SSMProblems. logdensity (model. dyn, step, x, state; kwargs... ), particles
71
- )
72
- return log_weights + transitions
73
- end
0 commit comments