Skip to content

Commit 9cc94eb

Browse files
committed
update to most recent SSMProblems interface
1 parent 8b4b558 commit 9cc94eb

File tree

11 files changed

+232
-257
lines changed

11 files changed

+232
-257
lines changed

.gitignore

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
*.jl.cov
22
*.jl.*.cov
33
*.jl.mem
4-
/Manifest.toml
4+
Manifest.toml
55
/test/Manifest.toml

Project.toml

-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ Random = "<0.0.1, 1"
2626
Random123 = "1.3"
2727
Requires = "1.0"
2828
StatsFuns = "0.9, 1"
29-
SSMProblems = "0.1"
3029
julia = "1.7"
3130

3231
[extras]

examples/gaussian-process/script.jl

+57-48
Original file line numberDiff line numberDiff line change
@@ -8,74 +8,83 @@ using Distributions
88
using Libtask
99
using SSMProblems
1010

11-
Parameters = @NamedTuple begin
12-
a::Float64
13-
q::Float64
14-
kernel
11+
# Gaussian process encoded transition dynamics
12+
mutable struct GaussianProcessDynamics{T<:Real} <: SSMProblems.LatentDynamics{T,T}
13+
proc::AbstractGPs.AbstractGP
14+
q::T
15+
function GaussianProcessDynamics(q::T, kernel::KT) where {T<:Real,KT<:Kernel}
16+
return new{T}(GP(ZeroMean{T}(), kernel), q)
17+
end
1518
end
1619

17-
mutable struct GPSSM <: SSMProblems.AbstractStateSpaceModel
18-
X::Vector{Float64}
19-
observations::Vector{Float64}
20-
θ::Parameters
21-
22-
GPSSM(params::Parameters) = new(Vector{Float64}(), params)
23-
GPSSM(y::Vector{Float64}, params::Parameters) = new(Vector{Float64}(), y, params)
20+
function SSMProblems.distribution(dyn::GaussianProcessDynamics{T}) where {T<:Real}
21+
return Normal(zero(T), dyn.q)
2422
end
2523

26-
seed = 1
27-
T = 100
28-
Nₚ = 20
29-
Nₛ = 250
30-
a = 0.9
31-
q = 0.5
32-
33-
params = Parameters((a, q, SqExponentialKernel()))
24+
# TODO: broken...
25+
function SSMProblems.simulate(
26+
rng::AbstractRNG, dyn::GaussianProcessDynamics, step::Int, state
27+
)
28+
dyn.proc = posterior(dyn.proc(step:step), [state])
29+
μ, σ = mean_and_cov(dyn.proc, [step])
30+
return rand(rng, Normal(μ[1], sqrt(σ[1])))
31+
end
3432

35-
f::Parameters, x, t) = Normal.a * x, θ.q)
36-
h::Parameters) = Normal(0, θ.q)
37-
g::Parameters, x, t) = Normal(0, exp(0.5 * x)^2)
33+
function SSMProblems.logdensity(dyn::GaussianProcessDynamics, step::Int, state, prev_state)
34+
μ, σ = mean_and_cov(dyn.proc, [step])
35+
return logpdf(Normal(μ, sqrt(σ)), state)
36+
end
3837

39-
rng = Random.MersenneTwister(seed)
38+
# Linear Gaussian dynamics used for simulation
39+
struct LinearGaussianDynamics{T<:Real} <: SSMProblems.LatentDynamics{T,T}
40+
a::T
41+
q::T
42+
end
4043

41-
x = zeros(T)
42-
y = similar(x)
43-
x[1] = rand(rng, h(params))
44-
for t in 1:T
45-
if t < T
46-
x[t + 1] = rand(rng, f(params, x[t], t))
47-
end
48-
y[t] = rand(rng, g(params, x[t], t))
44+
function SSMProblems.distribution(dyn::LinearGaussianDynamics{T}) where {T<:Real}
45+
return Normal(zero(T), dyn.q)
4946
end
5047

51-
function gp_update(model::GPSSM, state, step)
52-
gp = GP(model.θ.kernel)
53-
prior = gp(1:(step - 1))
54-
post = posterior(prior, model.X[1:(step - 1)])
55-
μ, σ = mean_and_cov(post, [step])
56-
return Normal(μ[1], σ[1])
48+
function SSMProblems.distribution(dyn::LinearGaussianDynamics, ::Int, state)
49+
return Normal(dyn.a * state, dyn.q)
5750
end
5851

59-
SSMProblems.transition!!(rng::AbstractRNG, model::GPSSM) = rand(rng, h(model.θ))
60-
function SSMProblems.transition!!(rng::AbstractRNG, model::GPSSM, state, step)
61-
return rand(rng, gp_update(model, state, step))
52+
# Observation process used in both variants of the model
53+
struct StochasticVolatility{T<:Real} <: SSMProblems.ObservationProcess{T,T} end
54+
55+
function SSMProblems.distribution(::StochasticVolatility{T}, ::Int, state) where {T<:Real}
56+
return Normal(zero(T), exp((1 / 2) * state))
6257
end
6358

64-
function SSMProblems.emission_logdensity(model::GPSSM, state, step)
65-
return logpdf(g(model.θ, state, step), model.observations[step])
59+
# Baseline model (for simulation)
60+
function LinearGaussianStochasticVolatilityModel(a::T, q::T) where {T<:Real}
61+
dyn = LinearGaussianDynamics(a, q)
62+
obs = StochasticVolatility{T}()
63+
return SSMProblems.StateSpaceModel(dyn, obs)
6664
end
67-
function SSMProblems.transition_logdensity(model::GPSSM, prev_state, current_state, step)
68-
return logpdf(gp_update(model, prev_state, step), current_state)
65+
66+
# Gaussian process model (for sampling)
67+
function GaussianProcessStateSpaceModel(q::T, kernel::KT) where {T<:Real,KT<:Kernel}
68+
dyn = GaussianProcessDynamics(q, kernel)
69+
obs = StochasticVolatility{T}()
70+
return SSMProblems.StateSpaceModel(dyn, obs)
6971
end
7072

71-
AdvancedPS.isdone(::GPSSM, step) = step > T
73+
# Everything is now ready to simulate some data.
74+
rng = Random.MersenneTwister(1234)
75+
true_model = LinearGaussianStochasticVolatilityModel(0.9, 0.5)
76+
_, x, y = sample(rng, true_model, 100);
7277

73-
model = GPSSM(y, params)
74-
pg = AdvancedPS.PGAS(Nₚ)
75-
chains = sample(rng, model, pg, Nₛ)
78+
# Create the model and run the sampler
79+
gpssm = GaussianProcessStateSpaceModel(0.5, SqExponentialKernel())
80+
model = gpssm(y)
81+
pg = AdvancedPS.PGAS(20)
82+
chains = sample(rng, model, pg, 250)
83+
#md nothing #hide
7684

7785
particles = hcat([chain.trajectory.model.X for chain in chains]...)
7886
mean_trajectory = mean(particles; dims=2);
87+
#md nothing #hide
7988

8089
scatter(particles; label=false, opacity=0.01, color=:black, xlabel="t", ylabel="state")
8190
plot!(x; color=:darkorange, label="Original Trajectory")

examples/gaussian-ssm/script.jl

+31-57
Original file line numberDiff line numberDiff line change
@@ -28,81 +28,55 @@ using SSMProblems
2828
# as well as the initial distribution $f_0(x) = \mathcal{N}(0, q^2/(1-a^2))$.
2929

3030
# To use `AdvancedPS` we first need to define a model type that subtypes `AdvancedPS.AbstractStateSpaceModel`.
31-
Parameters = @NamedTuple begin
32-
a::Float64
33-
q::Float64
34-
r::Float64
31+
mutable struct Parameters{T<:Real}
32+
a::T
33+
q::T
34+
r::T
3535
end
3636

37-
mutable struct LinearSSM <: SSMProblems.AbstractStateSpaceModel
38-
X::Vector{Float64}
39-
observations::Vector{Float64}
40-
θ::Parameters
41-
LinearSSM::Parameters) = new(Vector{Float64}(), θ)
42-
LinearSSM(y::Vector, θ::Parameters) = new(Vector{Float64}(), y, θ)
37+
struct LinearGaussianDynamics{T<:Real} <: SSMProblems.LatentDynamics{T,T}
38+
a::T
39+
q::T
4340
end
4441

45-
# and the densities defined above.
46-
f::Parameters, state, t) = Normal.a * state, θ.q) # Transition density
47-
g::Parameters, state, t) = Normal(state, θ.r) # Observation density
48-
f₀::Parameters) = Normal(0, θ.q^2 / (1 - θ.a^2)) # Initial state density
49-
#md nothing #hide
42+
function SSMProblems.distribution(dyn::LinearGaussianDynamics{T}; kwargs...) where {T<:Real}
43+
return Normal(zero(T), sqrt(dyn.q^2 / (1 - dyn.a^2)))
44+
end
5045

51-
# We also need to specify the dynamics of the system through the transition equations:
52-
# - `AdvancedPS.initialization`: the initial state density
53-
# - `AdvancedPS.transition`: the state transition density
54-
# - `AdvancedPS.observation`: the observation score given the observed data
55-
# - `AdvancedPS.isdone`: signals the end of the execution for the model
56-
SSMProblems.transition!!(rng::AbstractRNG, model::LinearSSM) = rand(rng, f₀(model.θ))
57-
function SSMProblems.transition!!(
58-
rng::AbstractRNG, model::LinearSSM, state::Float64, step::Int
59-
)
60-
return rand(rng, f(model.θ, state, step))
46+
function SSMProblems.distribution(dyn::LinearGaussianDynamics, step::Int, state; kwargs...)
47+
return Normal(dyn.a * state, dyn.q)
6148
end
6249

63-
function SSMProblems.emission_logdensity(modeL::LinearSSM, state::Float64, step::Int)
64-
return logpdf(g(model.θ, state, step), model.observations[step])
50+
struct LinearGaussianObservation{T<:Real} <: SSMProblems.ObservationProcess{T,T}
51+
r::T
6552
end
66-
function SSMProblems.transition_logdensity(
67-
model::LinearSSM, prev_state, current_state, step
53+
54+
function SSMProblems.distribution(
55+
obs::LinearGaussianObservation, step::Int, state; kwargs...
6856
)
69-
return logpdf(f(model.θ, prev_state, step), current_state)
57+
return Normal(state, obs.r)
7058
end
7159

72-
# We need to think seriously about how the data is handled
73-
AdvancedPS.isdone(::LinearSSM, step) = step > Tₘ
60+
function LinearGaussianStateSpaceModel::Parameters)
61+
dyn = LinearGaussianDynamics.a, θ.q)
62+
obs = LinearGaussianObservation.r)
63+
return SSMProblems.StateSpaceModel(dyn, obs)
64+
end
7465

7566
# Everything is now ready to simulate some data.
76-
a = 0.9 # Scale
77-
q = 0.32 # State variance
78-
r = 1 # Observation variance
79-
Tₘ = 200 # Number of observation
80-
Nₚ = 20 # Number of particles
81-
Nₛ = 500 # Number of samples
82-
seed = 1 # Reproduce everything
83-
84-
θ₀ = Parameters((a, q, r))
85-
rng = Random.MersenneTwister(seed)
86-
87-
x = zeros(Tₘ)
88-
y = zeros(Tₘ)
89-
x[1] = rand(rng, f₀(θ₀))
90-
for t in 1:Tₘ
91-
if t < Tₘ
92-
x[t + 1] = rand(rng, f(θ₀, x[t], t))
93-
end
94-
y[t] = rand(rng, g(θ₀, x[t], t))
95-
end
67+
rng = Random.MersenneTwister(1234)
68+
θ = Parameters(0.9, 0.32, 1.0)
69+
true_model = LinearGaussianStateSpaceModel(θ)
70+
_, x, y = sample(rng, true_model, 200);
9671

9772
# Here are the latent and obseravation timeseries
9873
plot(x; label="x", xlabel="t")
9974
plot!(y; seriestype=:scatter, label="y", xlabel="t", mc=:red, ms=2, ma=0.5)
10075

10176
# `AdvancedPS` subscribes to the `AbstractMCMC` API. To sample we just need to define a Particle Gibbs kernel
10277
# and a model interface.
103-
model = LinearSSM(y, θ₀)
104-
pgas = AdvancedPS.PGAS(Nₚ)
105-
chains = sample(rng, model, pgas, Nₛ; progress=false);
78+
pgas = AdvancedPS.PGAS(20)
79+
chains = sample(rng, true_model(y), pgas, 500; progress=false);
10680
#md nothing #hide
10781

10882
#
@@ -118,7 +92,7 @@ plot!(mean_trajectory; color=:dodgerblue, label="Mean trajectory", opacity=0.9)
11892
# We used a particle gibbs kernel with the ancestor updating step which should help with the particle
11993
# degeneracy problem and improve the mixing.
12094
# We can compute the update rate of $x_t$ vs $t$ defined as the proportion of times $t$ where $x_t$ gets updated:
121-
update_rate = sum(abs.(diff(particles; dims=2)) .> 0; dims=2) / Nₛ
95+
update_rate = sum(abs.(diff(particles; dims=2)) .> 0; dims=2) / length(chains)
12296
#md nothing #hide
12397

12498
# and compare it to the theoretical value of $1 - 1/Nₚ$.
@@ -130,4 +104,4 @@ plot(
130104
xlabel="Iteration",
131105
ylabel="Update rate",
132106
)
133-
hline!([1 - 1 / Nₚ]; label="N: $(Nₚ)")
107+
hline!([1 - 1 / length(chains)]; label="N: $(length(chains))")

0 commit comments

Comments
 (0)