diff --git a/src/models/hierarchical.jl b/src/models/hierarchical.jl index 60c6f32..dd1e4a6 100644 --- a/src/models/hierarchical.jl +++ b/src/models/hierarchical.jl @@ -1,4 +1,5 @@ -import SSMProblems: LatentDynamics, ObservationProcess, simulate +import SSMProblems: + LatentDynamics, ObservationProcess, simulate, batch_simulate, batch_sample export HierarchicalSSM struct HierarchicalSSM{T<:Real,OD<:LatentDynamics{T},M<:StateSpaceModel{T}} <: @@ -47,3 +48,35 @@ function AbstractMCMC.sample( return x0, z0, xs, zs, ys end + +function SSMProblems.batch_sample( + rng::AbstractRNG, model::HierarchicalSSM, T::Integer, N::Integer; kwargs... +) + outer_dyn, inner_model = model.outer_dyn, model.inner_model + inner_dyn, obs = inner_model.dyn, inner_model.obs + # Batched types are not known at compile time + x0s = batch_simulate(rng, outer_dyn, N; kwargs...) + z0s = batch_simulate(rng, inner_dyn, N; new_outer=x0s, kwargs...) + xss = Vector{typeof(x0s)}(undef, T) + zss = Vector{typeof(z0s)}(undef, T) + xss[1] = batch_simulate(rng, outer_dyn, 1, x0s; kwargs...) + zss[1] = batch_simulate(rng, inner_dyn, 1, z0s; new_outer=xss[1], kwargs...) + y1s = batch_simulate(rng, obs, 1, zss[1]; new_outer=xss[1], kwargs...) + yss = Vector{typeof(y1s)}(undef, T) + + for t in 2:T + xss[t] = batch_simulate(rng, outer_dyn, t, xss[t - 1]; kwargs...) + zss[t] = batch_simulate( + rng, + inner_dyn, + t, + zss[t - 1]; + prev_outer=xss[t - 1], + new_outer=xss[t], + kwargs..., + ) + yss[t] = batch_simulate(rng, obs, t, zss[t]; new_outer=xss[t], kwargs...) 
+ end + + return x0s, z0s, xss, zss, yss +end diff --git a/src/models/linear_gaussian.jl b/src/models/linear_gaussian.jl index 20756f2..2fede2a 100644 --- a/src/models/linear_gaussian.jl +++ b/src/models/linear_gaussian.jl @@ -71,6 +71,72 @@ function SSMProblems.distribution( return MvNormal(H * state + c, R) end +########################## +#### BATCH SIMULATION #### +########################## + +function cholesky_from_lu(Ls_lu::CuArray{T}, D, N) where {T} + # Compute Cholesky factor from L (scale L/U to have equal diagonals) + # This is not a particularly optimised approach since it will be replaced later + diags = CUDA.zeros(T, D, D, N) + for d in 1:D + diags[d, d, :] .= Ls_lu[d, d, :] + end + Ls = CUDA.zeros(T, D, D, N) + for d in 1:D + Ls[d, d, :] .= 1.0 + end + for i in 1:D + for j in 1:(i - 1) + Ls[i, j, :] .= Ls_lu[i, j, :] + end + end + return NNlib.batched_mul(Ls, sqrt.(diags)) +end + +function SSMProblems.batch_simulate( + ::AbstractRNG, dyn::LinearGaussianLatentDynamics{T}, N::Integer; kwargs... +) where {T} + μ0s, Σ0s = batch_calc_initial(dyn, N; kwargs...) + D = size(μ0s, 1) + # Compute Cholesky factor using LU decomposition (no pivoting needed for Cholesky) + # TODO: replace this with MAGMA's batched Cholesky + Ls_lu = CUDA.CUBLAS.getrf_batched(Σ0s, false) + Ls = cholesky_from_lu(Ls_lu, D, N) + return μ0s .+ NNlib.batched_vec(Ls, CUDA.randn(T, D, N)) + end + +function SSMProblems.batch_simulate( + ::AbstractRNG, + dyn::LinearGaussianLatentDynamics{T}, + step::Integer, + prev_state; + kwargs..., +) where {T} + N = size(prev_state, 2) + As, bs, Qs = batch_calc_params(dyn, step, N; kwargs...) 
+ D = size(prev_state, 1) + # Compute Cholesky factor using LU decomposition (no pivoting needed for Cholesky) + Ls_lu = CUDA.CUBLAS.getrf_batched(Qs, false) + Ls = cholesky_from_lu(Ls_lu, D, size(prev_state, 2)) + return ( + (NNlib.batched_vec(As, prev_state) .+ bs) + + NNlib.batched_vec(Ls, CUDA.randn(T, D, N)) + ) +end + +function SSMProblems.batch_simulate( + ::AbstractRNG, obs::LinearGaussianObservationProcess{T}, step::Integer, state; kwargs... +) where {T} + N = size(state, 2) + Hs, cs, Rs = batch_calc_params(obs, step, N; kwargs...) + D = size(Hs, 1) + # Compute Cholesky factor using LU decomposition (no pivoting needed for Cholesky) + Ls_lu = CUDA.CUBLAS.getrf_batched(Rs, false) + Ls = cholesky_from_lu(Ls_lu, D, N) + return (NNlib.batched_vec(Hs, state) .+ cs) + NNlib.batched_vec(Ls, CUDA.randn(T, D, N)) +end + ########################################### #### HOMOGENEOUS LINEAR GAUSSIAN MODEL #### ###########################################