@@ -26,13 +26,13 @@ function (callback::WeightedParticleRecorderCallback)(
26
26
end
27
27
28
28
function gen_trajectory (
29
- rng:: Random.AbstractRNG ,
30
- model:: StateSpaceModel ,
29
+ rng:: Random.AbstractRNG ,
30
+ model:: StateSpaceModel ,
31
31
particles:: AbstractMatrix{T} , # Need better container
32
- log_weights:: AbstractMatrix{WT} ,
32
+ log_weights:: AbstractMatrix{WT} ,
33
33
forward_state,
34
34
n_timestep:: Int ;
35
- kwargs...
35
+ kwargs... ,
36
36
) where {T,WT}
37
37
trajectory = Vector {T} (undef, n_timestep)
38
38
trajectory[end ] = forward_state
@@ -51,7 +51,6 @@ function gen_trajectory(
51
51
return trajectory
52
52
end
53
53
54
-
55
54
function backward (
56
55
model:: StateSpaceModel , step:: Integer , state, particles:: T , log_weights:: WT ; kwargs...
57
56
) where {T,WT}
@@ -61,7 +60,6 @@ function backward(
61
60
return log_weights + transitions
62
61
end
63
62
64
-
65
63
function sample (
66
64
rng:: Random.AbstractRNG ,
67
65
model:: StateSpaceModel{T,LDT} ,
@@ -71,7 +69,6 @@ function sample(
71
69
callback= nothing ,
72
70
kwargs... ,
73
71
) where {T,LDT,N}
74
-
75
72
n_timestep = length (obs)
76
73
recorder = WeightedParticleRecorderCallback (
77
74
Array {eltype(model.dyn)} (undef, n_timestep, N), Array {T} (undef, n_timestep, N)
@@ -85,7 +82,14 @@ function sample(
85
82
86
83
trajectories[end , :] = particles. filtered[idx_ref]
87
84
for j in 1 : M
88
- trajectories[:, j] = gen_trajectory (rng, model, recorder. particles, recorder. log_weights, trajectories[end , j], n_timestep)
85
+ trajectories[:, j] = gen_trajectory (
86
+ rng,
87
+ model,
88
+ recorder. particles,
89
+ recorder. log_weights,
90
+ trajectories[end , j],
91
+ n_timestep,
92
+ )
89
93
end
90
94
return trajectories
91
95
end
0 commit comments