Skip to content

Commit 105d5ac

Browse files
committed
Delete dead code
1 parent a3f21f2 commit 105d5ac

2 files changed

Lines changed: 4 additions & 13 deletions

File tree

probdiffeq/impl/_normal.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def from_tcoeffs(self, tcoeffs: list, damp: float = 0.0):
6060
powers = 1 / np.arange(1, len(tcoeffs) + 1)
6161
c_sqrtm0_corrected = linalg.diagonal_matrix(damp**powers)
6262

63-
leaves, structure = tree_util.tree_flatten(tcoeffs)
63+
leaves, _ = tree_util.tree_flatten(tcoeffs)
6464
m0_corrected = np.stack(leaves)
6565
return Normal(m0_corrected, c_sqrtm0_corrected)
6666

@@ -82,7 +82,7 @@ def from_tcoeffs(self, tcoeffs: list, damp: float = 0.0):
8282
cholesky = linalg.diagonal_matrix(damp**powers)
8383
cholesky = np.ones((*self.ode_shape, 1, 1)) * cholesky[None, ...]
8484

85-
leaves, structure = tree_util.tree_flatten(tcoeffs)
85+
leaves, _ = tree_util.tree_flatten(tcoeffs)
8686
mean = np.stack(leaves).T
8787
return Normal(mean, cholesky)
8888

probdiffeq/ivpsolvers.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -441,24 +441,22 @@ class _Correction:
441441
linearize: Callable
442442
vector_field: Callable
443443

444-
can_handle_higher_order: bool
445-
446444
def init(self, x, /):
447445
"""Initialise the state from the solution."""
448446
y = self.ssm.prototypes.observed()
449447
return x, y
450448

451449
def correct(self, rv, /, t):
452450
"""Perform the correction step."""
453-
f_wrapped = self._parametrize_vector_field(t=t)
451+
f_wrapped = functools.partial(self.vector_field, t=t)
454452
cond = self.linearize(f_wrapped, rv)
455453
observed, reverted = self.ssm.conditional.revert(rv, cond)
456454
corrected = reverted.noise
457455
return corrected, observed
458456

459457
def estimate_error(self, rv, /, t):
460458
"""Estimate the error."""
461-
f_wrapped = self._parametrize_vector_field(t=t)
459+
f_wrapped = functools.partial(self.vector_field, t=t)
462460
cond = self.linearize(f_wrapped, rv)
463461
observed = self.ssm.conditional.marginalise(rv, cond)
464462

@@ -469,9 +467,6 @@ def estimate_error(self, rv, /, t):
469467
error_estimate = output_scale * error_estimate_unscaled
470468
return error_estimate, observed
471469

472-
def _parametrize_vector_field(self, *, t):
473-
return functools.partial(self.vector_field, t=t)
474-
475470

476471
def correction_ts0(vector_field, *, ssm, ode_order=1, damp: float = 0.0) -> _Correction:
477472
"""Zeroth-order Taylor linearisation."""
@@ -482,7 +477,6 @@ def correction_ts0(vector_field, *, ssm, ode_order=1, damp: float = 0.0) -> _Cor
482477
ode_order=ode_order,
483478
ssm=ssm,
484479
linearize=linearize,
485-
can_handle_higher_order=True,
486480
)
487481

488482

@@ -495,7 +489,6 @@ def correction_ts1(vector_field, *, ssm, ode_order=1, damp: float = 0.0) -> _Cor
495489
ode_order=ode_order,
496490
ssm=ssm,
497491
linearize=linearize,
498-
can_handle_higher_order=True,
499492
)
500493

501494

@@ -510,7 +503,6 @@ def correction_slr0(
510503
ode_order=1,
511504
linearize=linearize,
512505
name="SLR0",
513-
can_handle_higher_order=False, # TODO: implement this
514506
)
515507

516508

@@ -525,7 +517,6 @@ def correction_slr1(
525517
ode_order=1,
526518
linearize=linearize,
527519
name="SLR1",
528-
can_handle_higher_order=False, # TODO: implement this
529520
)
530521

531522

0 commit comments

Comments
 (0)