Skip to content

Implement batch forward simulation #28

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 34 additions & 1 deletion src/models/hierarchical.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import SSMProblems: LatentDynamics, ObservationProcess, simulate
import SSMProblems:
LatentDynamics, ObservationProcess, simulate, batch_simulate, batch_sample
export HierarchicalSSM

struct HierarchicalSSM{T<:Real,OD<:LatentDynamics{T},M<:StateSpaceModel{T}} <:
Expand Down Expand Up @@ -47,3 +48,35 @@ function AbstractMCMC.sample(

return x0, z0, xs, zs, ys
end

"""
    SSMProblems.batch_sample(rng, model::HierarchicalSSM, T, N; kwargs...)

Forward-simulate a batch of trajectories from a hierarchical state space model.

The outer dynamics are simulated first at every step; the inner dynamics and the
observation process are then simulated conditionally on the outer state, which is passed
to them through the `prev_outer` / `new_outer` keyword arguments.

# Arguments
- `rng`: random number generator forwarded to every `batch_simulate` call.
- `model`: hierarchical model providing `outer_dyn` and `inner_model` (`dyn` and `obs`).
- `T`: number of time steps to simulate. The first step is simulated unconditionally, so
  `T` must be at least 1.
- `N`: batch size forwarded to the initial-state `batch_simulate` calls.
- `kwargs...`: forwarded unchanged to every `batch_simulate` call.

# Returns
A tuple `(x0s, z0s, xss, zss, yss)`: the batched initial outer and inner states, followed
by length-`T` vectors holding the batched outer states, inner states and observations for
steps `1:T`.
"""
function SSMProblems.batch_sample(
    rng::AbstractRNG, model::HierarchicalSSM, T::Integer, N::Integer; kwargs...
)
    outer_dyn, inner_model = model.outer_dyn, model.inner_model
    inner_dyn, obs = inner_model.dyn, inner_model.obs

    # Batched types are not known at compile time, so the initial draws (and the first
    # observation) are simulated up front and their runtime types used for the containers.
    x0s = batch_simulate(rng, outer_dyn, N; kwargs...)
    z0s = batch_simulate(rng, inner_dyn, N; new_outer=x0s, kwargs...)
    xss = Vector{typeof(x0s)}(undef, T)
    zss = Vector{typeof(z0s)}(undef, T)

    xss[1] = batch_simulate(rng, outer_dyn, 1, x0s; kwargs...)
    # Pass the previous outer state at step 1 as well, matching the calls made for t >= 2.
    zss[1] = batch_simulate(
        rng, inner_dyn, 1, z0s; prev_outer=x0s, new_outer=xss[1], kwargs...
    )
    y1s = batch_simulate(rng, obs, 1, zss[1]; new_outer=xss[1], kwargs...)
    yss = Vector{typeof(y1s)}(undef, T)
    # Store the first observation; without this the first slot is left uninitialised.
    yss[1] = y1s

    for t in 2:T
        xss[t] = batch_simulate(rng, outer_dyn, t, xss[t - 1]; kwargs...)
        zss[t] = batch_simulate(
            rng,
            inner_dyn,
            t,
            zss[t - 1];
            prev_outer=xss[t - 1],
            new_outer=xss[t],
            kwargs...,
        )
        yss[t] = batch_simulate(rng, obs, t, zss[t]; new_outer=xss[t], kwargs...)
    end

    return x0s, z0s, xss, zss, yss
end
66 changes: 66 additions & 0 deletions src/models/linear_gaussian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,72 @@ function SSMProblems.distribution(
return MvNormal(H * state + c, R)
end

##########################
#### BATCH SIMULATION ####
##########################

"""
    cholesky_from_lu(Ls_lu::CuArray{T}, D, N)

Build a batch of lower-triangular factors from the array `Ls_lu`, indexed as `[i, j, n]`
with `i, j in 1:D`.

For every batch element the strictly lower triangle of `Ls_lu` is copied into a matrix
with a unit diagonal, and the diagonal of `Ls_lu` is copied into a diagonal matrix. The
result is the batched product of the unit-lower matrices with the element-wise square
root of the diagonal matrices.

NOTE(review): this yields a Cholesky factor only if `Ls_lu` holds packed, unpivoted LU
factors of symmetric positive definite matrices — confirm at the call sites.

This is not a particularly optimised approach since it is expected to be replaced later.
"""
function cholesky_from_lu(Ls_lu::CuArray{T}, D, N) where {T}
    pivots = CUDA.zeros(T, D, D, N)
    unit_lower = CUDA.zeros(T, D, D, N)
    for i in 1:D
        # Diagonal entry of the packed factors, and the implicit unit diagonal
        pivots[i, i, :] .= Ls_lu[i, i, :]
        unit_lower[i, i, :] .= one(T)
        # Entries strictly below the diagonal in row i
        for j in 1:(i - 1)
            unit_lower[i, j, :] .= Ls_lu[i, j, :]
        end
    end
    # Scale the columns of the unit-lower factor by the square roots of the pivots
    return NNlib.batched_mul(unit_lower, sqrt.(pivots))
end

"""
    SSMProblems.batch_simulate(rng, dyn::LinearGaussianLatentDynamics{T}, N; kwargs...)

Draw a batch of `N` initial states for linear Gaussian latent dynamics.

The batched initial means and covariances come from `batch_calc_initial`. The returned
sample is the mean plus a covariance factor applied to noise drawn with `CUDA.randn`.

The `rng` argument is accepted for interface compatibility but is not used; the noise
comes from `CUDA.randn` directly.
"""
function SSMProblems.batch_simulate(
    ::AbstractRNG, dyn::LinearGaussianLatentDynamics{T}, N::Integer; kwargs...
) where {T}
    μ0s, Σ0s = batch_calc_initial(dyn, N; kwargs...)
    D = size(μ0s, 1)
    # Compute Cholesky factor using LU decomposition (no pivoting needed for Cholesky)
    # TODO: replace this with MAGMA's batched Cholesky
    lu_packed = CUDA.CUBLAS.getrf_batched(Σ0s, false)
    chol_factors = cholesky_from_lu(lu_packed, D, N)
    noise = CUDA.randn(T, D, N)
    return μ0s .+ NNlib.batched_vec(chol_factors, noise)
end

"""
    SSMProblems.batch_simulate(
        rng, dyn::LinearGaussianLatentDynamics{T}, step, prev_state; kwargs...
    )

Propagate a batch of states one step through linear Gaussian latent dynamics.

The batch size is taken from the second dimension of `prev_state` and the state dimension
from its first. The batched parameters `As`, `bs`, `Qs` come from `batch_calc_params`. The
result is `As * prev_state + bs` (batched) plus a factor of `Qs` applied to noise drawn
with `CUDA.randn`.

The `rng` argument is accepted for interface compatibility but is not used.
"""
function SSMProblems.batch_simulate(
    ::AbstractRNG,
    dyn::LinearGaussianLatentDynamics{T},
    step::Integer,
    prev_state;
    kwargs...,
) where {T}
    N = size(prev_state, 2)
    As, bs, Qs = batch_calc_params(dyn, step, N; kwargs...)
    D = size(prev_state, 1)
    # Compute Cholesky factor using LU decomposition (no pivoting needed for Cholesky)
    lu_packed = CUDA.CUBLAS.getrf_batched(Qs, false)
    chol_factors = cholesky_from_lu(lu_packed, D, N)
    # Deterministic part of the transition, then the additive Gaussian noise term
    predicted = NNlib.batched_vec(As, prev_state) .+ bs
    perturbation = NNlib.batched_vec(chol_factors, CUDA.randn(T, D, N))
    return predicted + perturbation
end

"""
    SSMProblems.batch_simulate(
        rng, obs::LinearGaussianObservationProcess{T}, step, state; kwargs...
    )

Simulate a batch of observations from a linear Gaussian observation process.

The batch size is taken from the second dimension of `state`. The batched parameters
`Hs`, `cs`, `Rs` come from `batch_calc_params`, and the output dimension is the first
dimension of `Hs`. The result is `Hs * state + cs` (batched) plus a factor of `Rs` applied
to noise drawn with `CUDA.randn`.

The `rng` argument is accepted for interface compatibility but is not used.
"""
function SSMProblems.batch_simulate(
    ::AbstractRNG, obs::LinearGaussianObservationProcess{T}, step::Integer, state; kwargs...
) where {T}
    N = size(state, 2)
    Hs, cs, Rs = batch_calc_params(obs, step, N; kwargs...)
    D = size(Hs, 1)
    # Compute Cholesky factor using LU decomposition (no pivoting needed for Cholesky)
    lu_packed = CUDA.CUBLAS.getrf_batched(Rs, false)
    chol_factors = cholesky_from_lu(lu_packed, D, N)
    # Deterministic part of the observation, then the additive Gaussian noise term
    expected = NNlib.batched_vec(Hs, state) .+ cs
    perturbation = NNlib.batched_vec(chol_factors, CUDA.randn(T, D, N))
    return expected + perturbation
end

###########################################
#### HOMOGENEOUS LINEAR GAUSSIAN MODEL ####
###########################################
Expand Down