99 changes: 0 additions & 99 deletions probdiffeq/impl/_conditional.py
@@ -14,101 +14,6 @@
from probdiffeq.util import cholesky_util


class TransformBackend(abc.ABC):
@abc.abstractmethod
def marginalise(self, rv, transformation, /):
raise NotImplementedError

@abc.abstractmethod
def revert(self, rv, transformation, /):
raise NotImplementedError


class DenseTransform(TransformBackend):
def marginalise(self, rv, transformation, /):
A, b = transformation
cholesky_new = cholesky_util.triu_via_qr((A @ rv.cholesky).T).T
return _normal.Normal(A @ rv.mean + b, cholesky_new)

def revert(self, rv, transformation, /):
A, b = transformation
mean, cholesky = rv.mean, rv.cholesky

# QR-decomposition
# (todo: rename revert_conditional_noisefree to
# revert_transformation_cov_sqrt())
r_obs, (r_cor, gain) = cholesky_util.revert_conditional_noisefree(
R_X_F=(A @ cholesky).T, R_X=cholesky.T
)

# Gather terms and return
m_cor = mean - gain @ (A @ mean + b)
corrected = _normal.Normal(m_cor, r_cor.T)
observed = _normal.Normal(A @ mean + b, r_obs.T)
return observed, Conditional(gain, corrected)


class IsotropicTransform(TransformBackend):
def marginalise(self, rv, transformation, /):
A, b = transformation
mean, cholesky = rv.mean, rv.cholesky
cholesky_new = cholesky_util.triu_via_qr((A @ cholesky).T)
cholesky_squeezed = np.reshape(cholesky_new, ())
return _normal.Normal((A @ mean) + b, cholesky_squeezed)

def revert(self, rv, transformation, /):
A, b = transformation
mean, cholesky = rv.mean, rv.cholesky

# QR-decomposition
# (todo: rename revert_conditional_noisefree
# to revert_transformation_cov_sqrt())
r_obs, (r_cor, gain) = cholesky_util.revert_conditional_noisefree(
R_X_F=(A @ cholesky).T, R_X=cholesky.T
)
cholesky_obs = np.reshape(r_obs, ())
cholesky_cor = r_cor.T

# Gather terms and return
mean_observed = A @ mean + b
m_cor = mean - gain * mean_observed
corrected = _normal.Normal(m_cor, cholesky_cor)
observed = _normal.Normal(mean_observed, cholesky_obs)
return observed, Conditional(gain, corrected)


class BlockDiagTransform(TransformBackend):
def __init__(self, ode_shape):
self.ode_shape = ode_shape

def marginalise(self, rv, transformation, /):
A, b = transformation
mean, cholesky = rv.mean, rv.cholesky

A_cholesky = A @ cholesky
cholesky = functools.vmap(cholesky_util.triu_via_qr)(_transpose(A_cholesky))

mean = functools.vmap(lambda x, y, z: x @ y + z)(A, mean, b)
return _normal.Normal(mean, cholesky)

def revert(self, rv, transformation, /):
A, bias = transformation
cholesky_upper = np.transpose(rv.cholesky, axes=(0, -1, -2))
A_cholesky_upper = _transpose(A @ rv.cholesky)

revert_fun = functools.vmap(cholesky_util.revert_conditional_noisefree)
r_obs, (r_cor, gain) = revert_fun(A_cholesky_upper, cholesky_upper)
cholesky_obs = _transpose(r_obs)
cholesky_cor = _transpose(r_cor)

# Gather terms and return
mean_observed = functools.vmap(lambda x, y, z: x @ y + z)(A, rv.mean, bias)
m_cor = rv.mean - (gain * (mean_observed[..., None]))[..., 0]
corrected = _normal.Normal(m_cor, cholesky_cor)
observed = _normal.Normal(mean_observed, cholesky_obs)
return observed, Conditional(gain, corrected)


class Conditional(containers.NamedTuple):
"""Conditional distributions."""

@@ -550,10 +455,6 @@ def preconditioner_prepare(*, num_derivatives):
return tree_util.Partial(preconditioner_diagonal, scales=scales, powers=powers)


def _hilbert(a):
return 1 / (a[:, None] + a[None, :] + 1)


def _pascal(a, /):
return _batch_gram(_binom)(a[:, None], a[None, :])

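Background on the square-root step in the deleted transform classes: if x has covariance L @ L.T and is pushed through the affine map A x + b, the marginal covariance is (A @ L) @ (A @ L).T, and a Cholesky factor of it can be read off the R factor of a QR decomposition of (A @ L).T. This is the pattern behind cholesky_util.triu_via_qr((A @ rv.cholesky).T).T in the removed DenseTransform.marginalise. A minimal plain-NumPy sketch of the identity (standalone, not probdiffeq API):

import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(2, 3))            # affine map
L = np.tril(rng.normal(size=(3, 3)))   # lower Cholesky factor of Cov(x)

# R factor of a QR decomposition of (A @ L).T; R.T @ R == (A @ L) @ (A @ L).T,
# so R.T is a valid (lower-triangular) square root of the marginal covariance.
_, R = np.linalg.qr((A @ L).T)
assert np.allclose(R.T @ R, A @ L @ L.T @ A.T)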
78 changes: 53 additions & 25 deletions probdiffeq/impl/_linearise.py
@@ -1,24 +1,29 @@
from probdiffeq.backend import abc, functools
from probdiffeq.backend import numpy as np
from probdiffeq.backend.typing import Callable
from probdiffeq.impl import _normal
from probdiffeq.util import cholesky_util


class LinearisationBackend(abc.ABC):
@abc.abstractmethod
def ode_taylor_0th(self, ode_order):
def ode_taylor_0th(self, ode_order: int, damp: float) -> _normal.Normal:
raise NotImplementedError

@abc.abstractmethod
def ode_taylor_1st(self, ode_order):
def ode_taylor_1st(self, ode_order: int, damp: float) -> _normal.Normal:
raise NotImplementedError

@abc.abstractmethod
def ode_statistical_1st(self, cubature_fun): # ode_order > 1 not supported
def ode_statistical_1st(
self, cubature_fun: Callable, damp: float
) -> _normal.Normal:
raise NotImplementedError

@abc.abstractmethod
def ode_statistical_0th(self, cubature_fun): # ode_order > 1 not supported
def ode_statistical_0th(
self, cubature_fun: Callable, damp: float
) -> _normal.Normal:
raise NotImplementedError


@@ -27,8 +32,9 @@ def __init__(self, ode_shape, unravel):
self.ode_shape = ode_shape
self.unravel = unravel

def ode_taylor_0th(self, ode_order):
def linearise_fun_wrapped(fun, mean):
def ode_taylor_0th(self, ode_order, damp: float):
def linearise_fun_wrapped(fun, rv):
mean = rv.mean
a0 = functools.partial(self._select_dy, idx_or_slice=slice(0, ode_order))
a1 = functools.partial(self._select_dy, idx_or_slice=ode_order)

@@ -40,12 +46,15 @@ def linearise_fun_wrapped(fun, mean):
linop = _jac_materialize(
lambda v, _p: self._autobatch_linop(a1)(v), inputs=mean
)
return linop, -fx
cov_lower = damp * np.eye(len(fx))
bias = _normal.Normal(-fx, cov_lower)
return linop, bias

return linearise_fun_wrapped

def ode_taylor_1st(self, ode_order):
def new(fun, mean, /):
def ode_taylor_1st(self, ode_order, damp):
def new(fun, rv, /):
mean = rv.mean
a0 = functools.partial(self._select_dy, idx_or_slice=slice(0, ode_order))
a1 = functools.partial(self._select_dy, idx_or_slice=ode_order)

@@ -62,11 +71,13 @@ def A(x):
return x1 - jvp(x0)

linop = _jac_materialize(lambda v, _p: A(v), inputs=mean)
return linop, -fx
cov_lower = damp * np.eye(len(fx))
bias = _normal.Normal(-fx, cov_lower)
return linop, bias

return new

def ode_statistical_1st(self, cubature_fun):
def ode_statistical_1st(self, cubature_fun, damp: float):
cubature_rule = cubature_fun(input_shape=self.ode_shape)
linearise_fun = functools.partial(self.slr1, cubature_rule=cubature_rule)

@@ -91,14 +102,18 @@ def A(x):
return a1(x) - J @ a0(x)

linop = _jac_materialize(lambda v, _p: A(v), inputs=rv.mean)

mean, cov_lower = noise.mean, noise.cholesky

# Include the damping term. (TODO: use a single qr?)
damping = damp * np.eye(len(cov_lower))
stack = np.concatenate((cov_lower.T, damping.T))
cov_lower = cholesky_util.triu_via_qr(stack).T
bias = _normal.Normal(-mean, cov_lower)
return linop, bias

return new

def ode_statistical_0th(self, cubature_fun):
def ode_statistical_0th(self, cubature_fun, damp: float):
cubature_rule = cubature_fun(input_shape=self.ode_shape)
linearise_fun = functools.partial(self.slr0, cubature_rule=cubature_rule)

@@ -119,6 +134,12 @@ def new(fun, rv, /):
# Gather the variables and return
noise = linearise_fun(fun, linearisation_pt)
mean, cov_lower = noise.mean, noise.cholesky

# Include the damping term. (TODO: use a single qr?)
damping = damp * np.eye(len(cov_lower))
stack = np.concatenate((cov_lower.T, damping.T))
cov_lower = cholesky_util.triu_via_qr(stack).T

bias = _normal.Normal(-mean, cov_lower)
linop = _jac_materialize(lambda v, _p: a1(v), inputs=rv.mean)
return linop, bias
@@ -195,23 +216,26 @@ def slr0(fn, x, *, cubature_rule):


class IsotropicLinearisation(LinearisationBackend):
def ode_taylor_1st(self, ode_order):
def ode_taylor_1st(self, ode_order, damp: float):
raise NotImplementedError

def ode_taylor_0th(self, ode_order):
def linearise_fun_wrapped(fun, mean):
def ode_taylor_0th(self, ode_order, damp: float):
def linearise_fun_wrapped(fun, rv):
mean = rv.mean
fx = self.ts0(fun, mean[:ode_order, ...])
linop = _jac_materialize(
lambda s, _p: s[[ode_order], ...], inputs=mean[:, 0]
)
return linop, -fx
cov_lower = damp * np.eye(1)
bias = _normal.Normal(-fx, cov_lower)
return linop, bias

return linearise_fun_wrapped

def ode_statistical_0th(self, cubature_fun):
def ode_statistical_0th(self, cubature_fun, damp: float):
raise NotImplementedError

def ode_statistical_1st(self, cubature_fun):
def ode_statistical_1st(self, cubature_fun, damp: float):
raise NotImplementedError

@staticmethod
@@ -220,8 +244,9 @@ def ts0(fn, m):


class BlockDiagLinearisation(LinearisationBackend):
def ode_taylor_0th(self, ode_order):
def linearise_fun_wrapped(fun, mean):
def ode_taylor_0th(self, ode_order, damp: float):
def linearise_fun_wrapped(fun, rv):
mean = rv.mean
m0 = mean[:, :ode_order]
fx = self.ts0(fun, m0.T)

@@ -233,17 +258,20 @@ def lo(s):
return _jac_materialize(lambda v, _p: a1(v), inputs=s)

linop = lo(mean)
return linop, -fx[:, None]
d, *_ = linop.shape
cov_lower = damp * np.ones((d, 1, 1))
bias = _normal.Normal(-fx[:, None], cov_lower)
return linop, bias

return linearise_fun_wrapped

def ode_taylor_1st(self, ode_order):
def ode_taylor_1st(self, ode_order, damp: float):
raise NotImplementedError

def ode_statistical_0th(self, cubature_fun):
def ode_statistical_0th(self, cubature_fun, damp: float):
raise NotImplementedError

def ode_statistical_1st(self, cubature_fun):
def ode_statistical_1st(self, cubature_fun, damp: float):
raise NotImplementedError

@staticmethod
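The repeated "Include the damping term" blocks above implement a square-root covariance update: with S the lower Cholesky factor of the linearisation noise and D = damp * I, the damped covariance is S @ S.T + D @ D.T, and a Cholesky factor of that sum falls out of a QR decomposition of the stacked matrix np.concatenate((S.T, D.T)), matching cholesky_util.triu_via_qr(stack).T in the diff. A minimal plain-NumPy sketch of the identity (the names S, D, and damp are illustrative, not probdiffeq API):

import numpy as np

rng = np.random.default_rng(0)
S = np.tril(rng.normal(size=(4, 4)))   # lower Cholesky factor of the noise covariance
damp = 1e-3
D = damp * np.eye(4)                   # damping term

# QR of the stacked, transposed factors: R.T @ R == S @ S.T + D @ D.T,
# so R.T is a lower Cholesky factor of the damped covariance.
_, R = np.linalg.qr(np.concatenate((S.T, D.T)))
assert np.allclose(R.T @ R, S @ S.T + damp**2 * np.eye(4))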
7 changes: 0 additions & 7 deletions probdiffeq/impl/impl.py
@@ -15,7 +15,6 @@ class FactImpl:
stats: _stats.StatsBackend
linearise: _linearise.LinearisationBackend
conditional: _conditional.ConditionalBackend
transform: _conditional.TransformBackend

# To assert a valid tree_equal of solutions, the factorisations
# must be comparable.
@@ -61,11 +60,9 @@ def _select_dense(*, tcoeffs_like) -> FactImpl:
unravel=unravel,
flat_shape=flat.shape,
)
transform = _conditional.DenseTransform()
return FactImpl(
name="dense",
linearise=linearise,
transform=transform,
conditional=conditional,
normal=normal,
prototypes=prototypes,
@@ -88,15 +85,13 @@ def _select_isotropic(*, tcoeffs_like) -> FactImpl:
conditional = _conditional.IsotropicConditional(
ode_shape=ode_shape, num_derivatives=num_derivatives, unravel_tree=unravel_tree
)
transform = _conditional.IsotropicTransform()
return FactImpl(
name="isotropic",
prototypes=prototypes,
normal=normal,
stats=stats,
linearise=linearise,
conditional=conditional,
transform=transform,
)


@@ -115,13 +110,11 @@ def _select_blockdiag(*, tcoeffs_like) -> FactImpl:
conditional = _conditional.BlockDiagConditional(
ode_shape=ode_shape, num_derivatives=num_derivatives, unravel_tree=unravel_tree
)
transform = _conditional.BlockDiagTransform(ode_shape=ode_shape)
return FactImpl(
name="blockdiag",
prototypes=prototypes,
normal=normal,
stats=stats,
linearise=linearise,
conditional=conditional,
transform=transform,
)