@@ -8,74 +8,83 @@ using Distributions
8
8
using Libtask
9
9
using SSMProblems
10
10
11
- Parameters = @NamedTuple begin
12
- a:: Float64
13
- q:: Float64
14
- kernel
11
+ # Gaussian process encoded transition dynamics
12
+ mutable struct GaussianProcessDynamics{T<: Real } <: SSMProblems.LatentDynamics{T,T}
13
+ proc:: AbstractGPs.AbstractGP
14
+ q:: T
15
+ function GaussianProcessDynamics (q:: T , kernel:: KT ) where {T<: Real ,KT<: Kernel }
16
+ return new {T} (GP (ZeroMean {T} (), kernel), q)
17
+ end
15
18
end
16
19
17
- mutable struct GPSSM <: SSMProblems.AbstractStateSpaceModel
18
- X:: Vector{Float64}
19
- observations:: Vector{Float64}
20
- θ:: Parameters
21
-
22
- GPSSM (params:: Parameters ) = new (Vector {Float64} (), params)
23
- GPSSM (y:: Vector{Float64} , params:: Parameters ) = new (Vector {Float64} (), y, params)
20
+ function SSMProblems. distribution (dyn:: GaussianProcessDynamics{T} ) where {T<: Real }
21
+ return Normal (zero (T), dyn. q)
24
22
end
25
23
26
- seed = 1
27
- T = 100
28
- Nₚ = 20
29
- Nₛ = 250
30
- a = 0.9
31
- q = 0.5
32
-
33
- params = Parameters ((a, q, SqExponentialKernel ()))
24
+ # TODO : broken...
25
+ function SSMProblems . simulate (
26
+ rng :: AbstractRNG , dyn :: GaussianProcessDynamics , step :: Int , state
27
+ )
28
+ dyn . proc = posterior (dyn . proc (step : step), [state])
29
+ μ, σ = mean_and_cov (dyn . proc, [step])
30
+ return rand (rng, Normal (μ[ 1 ], sqrt (σ[ 1 ])))
31
+ end
34
32
35
- f (θ:: Parameters , x, t) = Normal (θ. a * x, θ. q)
36
- h (θ:: Parameters ) = Normal (0 , θ. q)
37
- g (θ:: Parameters , x, t) = Normal (0 , exp (0.5 * x)^ 2 )
33
+ function SSMProblems. logdensity (dyn:: GaussianProcessDynamics , step:: Int , state, prev_state)
34
+ μ, σ = mean_and_cov (dyn. proc, [step])
35
+ return logpdf (Normal (μ, sqrt (σ)), state)
36
+ end
38
37
39
- rng = Random. MersenneTwister (seed)
38
+ # Linear Gaussian dynamics used for simulation
39
+ struct LinearGaussianDynamics{T<: Real } <: SSMProblems.LatentDynamics{T,T}
40
+ a:: T
41
+ q:: T
42
+ end
40
43
41
- x = zeros (T)
42
- y = similar (x)
43
- x[1 ] = rand (rng, h (params))
44
- for t in 1 : T
45
- if t < T
46
- x[t + 1 ] = rand (rng, f (params, x[t], t))
47
- end
48
- y[t] = rand (rng, g (params, x[t], t))
44
+ function SSMProblems. distribution (dyn:: LinearGaussianDynamics{T} ) where {T<: Real }
45
+ return Normal (zero (T), dyn. q)
49
46
end
50
47
51
- function gp_update (model:: GPSSM , state, step)
52
- gp = GP (model. θ. kernel)
53
- prior = gp (1 : (step - 1 ))
54
- post = posterior (prior, model. X[1 : (step - 1 )])
55
- μ, σ = mean_and_cov (post, [step])
56
- return Normal (μ[1 ], σ[1 ])
48
+ function SSMProblems. distribution (dyn:: LinearGaussianDynamics , :: Int , state)
49
+ return Normal (dyn. a * state, dyn. q)
57
50
end
58
51
59
- SSMProblems. transition!! (rng:: AbstractRNG , model:: GPSSM ) = rand (rng, h (model. θ))
60
- function SSMProblems. transition!! (rng:: AbstractRNG , model:: GPSSM , state, step)
61
- return rand (rng, gp_update (model, state, step))
52
+ # Observation process used in both variants of the model
53
+ struct StochasticVolatility{T<: Real } <: SSMProblems.ObservationProcess{T,T} end
54
+
55
+ function SSMProblems. distribution (:: StochasticVolatility{T} , :: Int , state) where {T<: Real }
56
+ return Normal (zero (T), exp ((1 / 2 ) * state))
62
57
end
63
58
64
- function SSMProblems. emission_logdensity (model:: GPSSM , state, step)
65
- return logpdf (g (model. θ, state, step), model. observations[step])
59
+ # Baseline model (for simulation)
60
+ function LinearGaussianStochasticVolatilityModel (a:: T , q:: T ) where {T<: Real }
61
+ dyn = LinearGaussianDynamics (a, q)
62
+ obs = StochasticVolatility {T} ()
63
+ return SSMProblems. StateSpaceModel (dyn, obs)
66
64
end
67
- function SSMProblems. transition_logdensity (model:: GPSSM , prev_state, current_state, step)
68
- return logpdf (gp_update (model, prev_state, step), current_state)
65
+
66
+ # Gaussian process model (for sampling)
67
+ function GaussianProcessStateSpaceModel (q:: T , kernel:: KT ) where {T<: Real ,KT<: Kernel }
68
+ dyn = GaussianProcessDynamics (q, kernel)
69
+ obs = StochasticVolatility {T} ()
70
+ return SSMProblems. StateSpaceModel (dyn, obs)
69
71
end
70
72
71
- AdvancedPS. isdone (:: GPSSM , step) = step > T
73
+ # Everything is now ready to simulate some data.
74
+ rng = Random. MersenneTwister (1234 )
75
+ true_model = LinearGaussianStochasticVolatilityModel (0.9 , 0.5 )
76
+ _, x, y = sample (rng, true_model, 100 );
72
77
73
- model = GPSSM (y, params)
74
- pg = AdvancedPS. PGAS (Nₚ)
75
- chains = sample (rng, model, pg, Nₛ)
78
+ # Create the model and run the sampler
79
+ gpssm = GaussianProcessStateSpaceModel (0.5 , SqExponentialKernel ())
80
+ model = gpssm (y)
81
+ pg = AdvancedPS. PGAS (20 )
82
+ chains = sample (rng, model, pg, 250 )
83
+ # md nothing #hide
76
84
77
85
particles = hcat ([chain. trajectory. model. X for chain in chains]. .. )
78
86
mean_trajectory = mean (particles; dims= 2 );
87
+ # md nothing #hide
79
88
80
89
scatter (particles; label= false , opacity= 0.01 , color= :black , xlabel= " t" , ylabel= " state" )
81
90
plot! (x; color= :darkorange , label= " Original Trajectory" )
0 commit comments