|
1 | 1 | """Conditionals.""" |
2 | 2 |
|
3 | | -from probdiffeq.backend import abc, containers, functools, linalg, tree_util |
| 3 | +from probdiffeq.backend import ( |
| 4 | + abc, |
| 5 | + containers, |
| 6 | + control_flow, |
| 7 | + functools, |
| 8 | + linalg, |
| 9 | + tree_util, |
| 10 | +) |
4 | 11 | from probdiffeq.backend import numpy as np |
5 | 12 | from probdiffeq.backend.typing import Any, Array |
6 | 13 | from probdiffeq.impl import _normal |
@@ -447,8 +454,85 @@ def system_matrices_1d(num_derivatives, output_scale): |
447 | 454 | x = np.arange(0, num_derivatives + 1) |
448 | 455 |
|
449 | 456 | A_1d = np.flip(_pascal(x)[0]) # no idea why the [0] is necessary... |
450 | | - Q_1d = np.flip(_hilbert(x)) |
451 | | - return A_1d, output_scale * linalg.cholesky_factor(Q_1d) |
| 457 | + |
| 458 | + # Cholesky factor of flip(hilbert(n)) |
| 459 | + Q_1d = cholesky_hilbert(num_derivatives + 1) |
| 460 | + Q_1d_flipped = np.flip(Q_1d, axis=0) |
| 461 | + Q_1d = linalg.qr_r(Q_1d_flipped.T).T |
| 462 | + return A_1d, output_scale * Q_1d |
| 463 | + |
| 464 | + |
| 465 | +def cholesky_hilbert(n: int, K: int = 0): |
| 466 | + """Compute the Cholesky factor of a Hilbert matrix. |
| 467 | +
|
| 468 | + This routine implements W. Kahan's stable recurrence (see "Hilbert Matrices", |
| 469 | + Math H110 notes) to construct a Cholesky factor. |
| 470 | +
|
| 471 | + Parameters |
| 472 | + ---------- |
| 473 | + n |
| 474 | + Size of the Hilbert matrix (``n x n``). |
| 475 | + K |
| 476 | + Shift parameter. ``K = 0`` gives the classical Hilbert matrix. |
| 477 | + Increasing ``K`` produces related matrices with entries |
| 478 | + ``1 / (i + j + K - 1)``. Default is 0. |
| 479 | +
|
| 480 | + Returns |
| 481 | + ------- |
| 482 | + Lower-triangular Cholesky factor of the Hilbert matrix. |
| 483 | +
|
| 484 | +
|
| 485 | + Notes |
| 486 | + ----- |
| 487 | + - Hilbert matrices are notoriously ill-conditioned; even with float64, |
| 488 | + the factorization loses accuracy for moderately large ``n`` (≈15 or more). |
| 489 | +
|
| 490 | + References |
| 491 | + ---------- |
| 492 | + W. Kahan, *Hilbert Matrices*, |
| 493 | + https://people.eecs.berkeley.edu/~wkahan/MathH110/HilbMats.pdf |
| 494 | + """ |
| 495 | + Kf = np.asarray(K) |
| 496 | + |
| 497 | + odds = np.arange(K + 1, K + 2 * n, step=2) # length n |
| 498 | + dr = np.sqrt(odds) # shape (n,) |
| 499 | + |
| 500 | + f = np.ones((n,)) * (1.0 + Kf) |
| 501 | + |
| 502 | + def f_body(idx, f): |
| 503 | + prev = f[idx - 1] |
| 504 | + idxf = np.asarray(idx) |
| 505 | + val = (((prev / idxf) * (Kf + 2.0 * idxf)) / (Kf + idxf)) * ( |
| 506 | + Kf + 2.0 * idxf + 1.0 |
| 507 | + ) |
| 508 | + return f.at[idx].set(val) |
| 509 | + |
| 510 | + f = control_flow.fori_loop(1, n, f_body, f) |
| 511 | + f = 1.0 / f |
| 512 | + |
| 513 | + U = np.eye(n) |
| 514 | + |
| 515 | + def body_j(j_idx, U): |
| 516 | + # compute column j_idx (0-based) of U using downward recurrence |
| 517 | + g = U[:, j_idx] |
| 518 | + |
| 519 | + def inner_body(k, g): |
| 520 | + # k runs 0..j_idx-1, we want i = j_idx-1-k (descend j-1 .. 0) |
| 521 | + i = j_idx - 1 - k |
| 522 | + denom = np.asarray(j_idx - i) # == k+1 |
| 523 | + factor = Kf + np.asarray(i + 1) + np.asarray(j_idx + 1) |
| 524 | + newval = (g[i + 1] / denom) * factor |
| 525 | + return g.at[i].set(newval) |
| 526 | + |
| 527 | + g = control_flow.fori_loop(0, j_idx, inner_body, g) |
| 528 | + return U.at[:, j_idx].set(g) |
| 529 | + |
| 530 | + U = control_flow.fori_loop(1, n, body_j, U) |
| 531 | + |
| 532 | + # scale columns: U = U .* (dr * f_row) |
| 533 | + U = U * (dr[:, None] * f[None, :]) |
| 534 | + |
| 535 | + return np.tril(U.T) |
452 | 536 |
|
453 | 537 |
|
454 | 538 | def preconditioner_diagonal(dt, *, scales, powers): |
|
0 commit comments