Skip to content

Commit 4e1d5ae

Browse files
authored
Implement a closed form Cholesky factor of a Hilbert matrix (#817)
* Add an error message for non-vector-valued initial conditions * Improve the visuals of the ODE in the quickstart example * Update the ruff-pre-commit hook * Treat warnings as errors (and ditch Diffrax from tests because it raises warnings) * Allow passing strategies to solvers as keyword arguments * Implement a closed-form Cholesky factor of a Hilbert matrix
1 parent 0b46bad commit 4e1d5ae

3 files changed

Lines changed: 97 additions & 5 deletions

File tree

probdiffeq/backend/control_flow.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@ def scan(step_func, /, init, xs, *, reverse=False, length=None):
3131
return _jax_scan(step_func, init=init, xs=xs, reverse=reverse, length=length)
3232

3333

34+
def fori_loop(lower, upper, step_func, /, init):
35+
return jax.lax.fori_loop(lower, upper, step_func, init)
36+
37+
3438
@contextlib.contextmanager
3539
def context_overwrite_while_loop(func, /):
3640
"""Overwrite the while_loop() function.

probdiffeq/backend/numpy.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@ def reshape(arr, /, new_shape, order="C"):
4848
return jnp.reshape(arr, new_shape, order=order)
4949

5050

51-
def flip(arr, /):
52-
return jnp.flip(arr)
51+
def flip(arr, /, axis=None):
52+
return jnp.flip(arr, axis=axis)
5353

5454

5555
def asarray(x, /):
@@ -214,3 +214,7 @@ def isnan(arr, /):
214214

215215
def linspace(start, stop, *, num=50, endpoint=True):
216216
return jnp.linspace(start, stop, num=num, endpoint=endpoint)
217+
218+
219+
def tril(arr, /):
220+
return jnp.tril(arr)

probdiffeq/impl/_conditional.py

Lines changed: 87 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
11
"""Conditionals."""
22

3-
from probdiffeq.backend import abc, containers, functools, linalg, tree_util
3+
from probdiffeq.backend import (
4+
abc,
5+
containers,
6+
control_flow,
7+
functools,
8+
linalg,
9+
tree_util,
10+
)
411
from probdiffeq.backend import numpy as np
512
from probdiffeq.backend.typing import Any, Array
613
from probdiffeq.impl import _normal
@@ -447,8 +454,85 @@ def system_matrices_1d(num_derivatives, output_scale):
447454
x = np.arange(0, num_derivatives + 1)
448455

449456
A_1d = np.flip(_pascal(x)[0]) # no idea why the [0] is necessary...
450-
Q_1d = np.flip(_hilbert(x))
451-
return A_1d, output_scale * linalg.cholesky_factor(Q_1d)
457+
458+
# Cholesky factor of flip(hilbert(n))
459+
Q_1d = cholesky_hilbert(num_derivatives + 1)
460+
Q_1d_flipped = np.flip(Q_1d, axis=0)
461+
Q_1d = linalg.qr_r(Q_1d_flipped.T).T
462+
return A_1d, output_scale * Q_1d
463+
464+
465+
def cholesky_hilbert(n: int, K: int = 0):
466+
"""Compute the Cholesky factor of a Hilbert matrix.
467+
468+
This routine implements W. Kahan's stable recurrence (see "Hilbert Matrices",
469+
Math H110 notes) to construct a Cholesky factor.
470+
471+
Parameters
472+
----------
473+
n
474+
Size of the Hilbert matrix (``n x n``).
475+
K
476+
Shift parameter. ``K = 0`` gives the classical Hilbert matrix.
477+
Increasing ``K`` produces related matrices with entries
478+
``1 / (i + j + K - 1)``. Default is 0.
479+
480+
Returns
481+
-------
482+
Lower-triangular Cholesky factor of the Hilbert matrix.
483+
484+
485+
Notes
486+
-----
487+
- Hilbert matrices are notoriously ill-conditioned; even with float64,
488+
the factorization loses accuracy for moderately large ``n`` (≈15 or more).
489+
490+
References
491+
----------
492+
W. Kahan, *Hilbert Matrices*,
493+
https://people.eecs.berkeley.edu/~wkahan/MathH110/HilbMats.pdf
494+
"""
495+
Kf = np.asarray(K)
496+
497+
odds = np.arange(K + 1, K + 2 * n, step=2) # length n
498+
dr = np.sqrt(odds) # shape (n,)
499+
500+
f = np.ones((n,)) * (1.0 + Kf)
501+
502+
def f_body(idx, f):
503+
prev = f[idx - 1]
504+
idxf = np.asarray(idx)
505+
val = (((prev / idxf) * (Kf + 2.0 * idxf)) / (Kf + idxf)) * (
506+
Kf + 2.0 * idxf + 1.0
507+
)
508+
return f.at[idx].set(val)
509+
510+
f = control_flow.fori_loop(1, n, f_body, f)
511+
f = 1.0 / f
512+
513+
U = np.eye(n)
514+
515+
def body_j(j_idx, U):
516+
# compute column j_idx (0-based) of U using downward recurrence
517+
g = U[:, j_idx]
518+
519+
def inner_body(k, g):
520+
# k runs 0..j_idx-1, we want i = j_idx-1-k (descend j-1 .. 0)
521+
i = j_idx - 1 - k
522+
denom = np.asarray(j_idx - i) # == k+1
523+
factor = Kf + np.asarray(i + 1) + np.asarray(j_idx + 1)
524+
newval = (g[i + 1] / denom) * factor
525+
return g.at[i].set(newval)
526+
527+
g = control_flow.fori_loop(0, j_idx, inner_body, g)
528+
return U.at[:, j_idx].set(g)
529+
530+
U = control_flow.fori_loop(1, n, body_j, U)
531+
532+
# scale columns: U = U .* (dr * f_row)
533+
U = U * (dr[:, None] * f[None, :])
534+
535+
return np.tril(U.T)
452536

453537

454538
def preconditioner_diagonal(dt, *, scales, powers):

0 commit comments

Comments
 (0)