Skip to content

Guided Filter Implementation #74

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 36 commits into from
Apr 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
04ecb4e
fixed type stability of linear filter
charlesknipp Jan 8, 2025
152917b
added MLE demonstration
charlesknipp Jan 8, 2025
efe1573
flipped sign of objective function
charlesknipp Jan 8, 2025
e99be16
Merge remote-tracking branch 'origin/main' into ck/maximum-likelihood
charlesknipp Mar 7, 2025
acbc4e1
reorganized and added Mooncake MWE
charlesknipp Mar 7, 2025
25eb4d6
fixed KF type stability in Enzyme
charlesknipp Mar 10, 2025
6d061b0
add MWE for Kalman filtering
charlesknipp Mar 11, 2025
f8f0e59
replaced second order optimizer and added backend testing
charlesknipp Mar 11, 2025
da6a54b
fixed type stability of linear filter
charlesknipp Jan 8, 2025
39d9f32
added MLE demonstration
charlesknipp Jan 8, 2025
30a4be3
flipped sign of objective function
charlesknipp Jan 8, 2025
935df10
reorganized and added Mooncake MWE
charlesknipp Mar 7, 2025
2cc7d63
fixed KF type stability in Enzyme
charlesknipp Mar 10, 2025
5c9149e
add MWE for Kalman filtering
charlesknipp Mar 11, 2025
4936c52
replaced second order optimizer and added backend testing
charlesknipp Mar 11, 2025
e8e6c02
Merge branch 'ck/maximum-likelihood' of https://github.com/TuringLang…
charlesknipp Mar 13, 2025
56c6dc0
fixed Mooncake errors for Bootstrap filter
charlesknipp Mar 13, 2025
7da1732
Add guided filter draft
charlesknipp Mar 17, 2025
262650d
add VSMC replication
charlesknipp Mar 17, 2025
ebc7843
switch proposals for demonstration
charlesknipp Mar 17, 2025
016a336
fix formatting
charlesknipp Mar 17, 2025
6715cb6
add fix for Flux and Mooncake
charlesknipp Mar 18, 2025
2c3be92
add plots and fix formatting
charlesknipp Mar 18, 2025
ada8306
restructure particle filters
charlesknipp Mar 25, 2025
afb3021
update example
charlesknipp Mar 25, 2025
cc92ad8
fixed type signatures for bootstrap filter
charlesknipp Mar 27, 2025
55aaa30
guess who forgot to run the formatter??
charlesknipp Mar 28, 2025
a92f455
update forward algorithm
charlesknipp Apr 3, 2025
238661a
Merge remote-tracking branch 'origin/main' into ck/guided-filter
charlesknipp Apr 3, 2025
265ed38
additional merge fixes
charlesknipp Apr 3, 2025
c90d239
remove redundant files
charlesknipp Apr 3, 2025
5762295
consolidate MLE examples
charlesknipp Apr 3, 2025
1e95b33
suggested changes and house cleaning
charlesknipp Apr 4, 2025
93c86e2
update proposal definition
charlesknipp Apr 4, 2025
243f86a
add unit testing
charlesknipp Apr 4, 2025
0dcb20f
formatter
charlesknipp Apr 4, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions GeneralisedFilters/src/GeneralisedFilters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ function initialise(model, alg; kwargs...)
return initialise(default_rng(), model, alg; kwargs...)
end

function predict(model, alg, step, filtered; kwargs...)
function predict(model, alg, step, filtered, observation; kwargs...)
return predict(default_rng(), model, alg, step, filtered; kwargs...)
end

Expand Down Expand Up @@ -108,7 +108,7 @@ function step(
callback::Union{AbstractCallback,Nothing}=nothing,
kwargs...,
)
state = predict(rng, model, alg, iter, state; kwargs...)
state = predict(rng, model, alg, iter, state, observation; kwargs...)
isnothing(callback) ||
callback(model, alg, iter, state, observation, PostPredict; kwargs...)

Expand All @@ -132,7 +132,7 @@ include("models/discrete.jl")
include("models/hierarchical.jl")

# Filtering/smoothing algorithms
include("algorithms/bootstrap.jl")
include("algorithms/particles.jl")
include("algorithms/kalman.jl")
include("algorithms/forward.jl")
include("algorithms/rbpf.jl")
Expand Down
121 changes: 0 additions & 121 deletions GeneralisedFilters/src/algorithms/bootstrap.jl

This file was deleted.

3 changes: 2 additions & 1 deletion GeneralisedFilters/src/algorithms/forward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ function predict(
model::DiscreteStateSpaceModel{T},
filter::ForwardAlgorithm,
step::Integer,
states::AbstractVector;
states::AbstractVector,
observation;
kwargs...,
) where {T}
P = calc_P(model.dyn, step; kwargs...)
Expand Down
64 changes: 32 additions & 32 deletions GeneralisedFilters/src/algorithms/kalman.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
export KalmanFilter, filter, BatchKalmanFilter
using GaussianDistributions
using CUDA: i32
import LinearAlgebra: hermitianpart

export KalmanFilter, KF, KalmanSmoother, KS

Expand All @@ -18,42 +19,40 @@ end
function predict(
rng::AbstractRNG,
model::LinearGaussianStateSpaceModel,
filter::KalmanFilter,
step::Integer,
filtered::Gaussian;
algo::KalmanFilter,
iter::Integer,
state::Gaussian,
observation=nothing;
kwargs...,
)
μ, Σ = GaussianDistributions.pair(filtered)
A, b, Q = calc_params(model.dyn, step; kwargs...)
μ, Σ = GaussianDistributions.pair(state)
A, b, Q = calc_params(model.dyn, iter; kwargs...)
return Gaussian(A * μ + b, A * Σ * A' + Q)
end

function update(
model::LinearGaussianStateSpaceModel,
filter::KalmanFilter,
step::Integer,
proposed::Gaussian,
obs::AbstractVector;
algo::KalmanFilter,
iter::Integer,
state::Gaussian,
observation::AbstractVector;
kwargs...,
)
μ, Σ = GaussianDistributions.pair(proposed)
H, c, R = calc_params(model.obs, step; kwargs...)
μ, Σ = GaussianDistributions.pair(state)
H, c, R = calc_params(model.obs, iter; kwargs...)

# Update state
m = H * μ + c
y = obs - m
S = H * Σ * H' + R
y = observation - m
S = hermitianpart(H * Σ * H' + R)
K = Σ * H' / S

# HACK: force the covariance to be positive definite
S = (S + S') / 2

filtered = Gaussian(μ + K * y, Σ - K * H * Σ)
state = Gaussian(μ + K * y, Σ - K * H * Σ)

# Compute log-likelihood
ll = logpdf(MvNormal(m, S), obs)
ll = logpdf(MvNormal(m, S), observation)

return filtered, ll
return state, ll
end

struct BatchKalmanFilter <: AbstractBatchFilter
Expand All @@ -74,12 +73,13 @@ function predict(
rng::AbstractRNG,
model::LinearGaussianStateSpaceModel{T},
algo::BatchKalmanFilter,
step::Integer,
state::BatchGaussianDistribution;
iter::Integer,
state::BatchGaussianDistribution,
observation;
kwargs...,
) where {T}
μs, Σs = state.μs, state.Σs
As, bs, Qs = batch_calc_params(model.dyn, step, algo.batch_size; kwargs...)
As, bs, Qs = batch_calc_params(model.dyn, iter, algo.batch_size; kwargs...)
μ̂s = NNlib.batched_vec(As, μs) .+ bs
Σ̂s = NNlib.batched_mul(NNlib.batched_mul(As, Σs), NNlib.batched_transpose(As)) .+ Qs
return BatchGaussianDistribution(μ̂s, Σ̂s)
Expand All @@ -88,17 +88,17 @@ end
function update(
model::LinearGaussianStateSpaceModel{T},
algo::BatchKalmanFilter,
step::Integer,
iter::Integer,
state::BatchGaussianDistribution,
obs;
observation;
kwargs...,
) where {T}
μs, Σs = state.μs, state.Σs
Hs, cs, Rs = batch_calc_params(model.obs, step, algo.batch_size; kwargs...)
D = size(obs, 1)
Hs, cs, Rs = batch_calc_params(model.obs, iter, algo.batch_size; kwargs...)
D = size(observation, 1)

m = NNlib.batched_vec(Hs, μs) .+ cs
y_res = cu(obs) .- m
y_res = cu(observation) .- m
S = NNlib.batched_mul(Hs, NNlib.batched_mul(Σs, NNlib.batched_transpose(Hs))) .+ Rs

ΣH_T = NNlib.batched_mul(Σs, NNlib.batched_transpose(Hs))
Expand Down Expand Up @@ -151,7 +151,7 @@ function (callback::StateCallback)(
algo::KalmanFilter,
iter::Integer,
state,
obs,
observation,
::PostPredictCallback;
kwargs...,
)
Expand All @@ -164,7 +164,7 @@ function (callback::StateCallback)(
algo::KalmanFilter,
iter::Integer,
state,
obs,
observation,
::PostUpdateCallback;
kwargs...,
)
Expand All @@ -175,7 +175,7 @@ end
function smooth(
rng::AbstractRNG,
model::LinearGaussianStateSpaceModel{T},
alg::KalmanSmoother,
algo::KalmanSmoother,
observations::AbstractVector;
t_smooth=1,
callback=nothing,
Expand All @@ -190,7 +190,7 @@ function smooth(
back_state = filtered
for t in (length(observations) - 1):-1:t_smooth
back_state = backward(
rng, model, alg, t, back_state, observations[t]; states_cache=cache, kwargs...
rng, model, algo, t, back_state, observations[t]; states_cache=cache, kwargs...
)
end

Expand All @@ -200,7 +200,7 @@ end
function backward(
rng::AbstractRNG,
model::LinearGaussianStateSpaceModel{T},
alg::KalmanSmoother,
algo::KalmanSmoother,
iter::Integer,
back_state,
obs;
Expand Down
Loading
Loading