Skip to content

Commit f5eb0eb

Browse files
committed
Rename backend.containers into backend.structs to make code more compact
1 parent ed94830 commit f5eb0eb

10 files changed

Lines changed: 32 additions & 32 deletions

File tree

probdiffeq/impl/_conditional.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
"""LatentConds."""
22

3-
from probdiffeq.backend import abc, containers, func, linalg, np, tree
3+
from probdiffeq.backend import abc, func, linalg, np, structs, tree
44
from probdiffeq.backend.typing import Array
55
from probdiffeq.impl import _normal, _stats
66
from probdiffeq.util import cholesky_util
77

88

99
@tree.register_dataclass
10-
@containers.dataclass
10+
@structs.dataclass
1111
class LatentCond:
1212
"""Conditional distributions in latent space."""
1313

probdiffeq/impl/_linearize.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from probdiffeq.backend import abc, containers, func, linalg, np, random, tree
1+
from probdiffeq.backend import abc, func, linalg, np, random, structs, tree
22
from probdiffeq.backend.typing import Any, Callable
33
from probdiffeq.impl import _conditional, _normal
44
from probdiffeq.util import cholesky_util
@@ -14,7 +14,7 @@ def linearize(self, fun, rv, state: None, *, damp: float):
1414
raise NotImplementedError
1515

1616

17-
@containers.dataclass
17+
@structs.dataclass
1818
class DenseTs0(Linearization):
1919
ode_order: int
2020
ode_shape: tuple
@@ -42,7 +42,7 @@ def a1(m):
4242
return cond, None
4343

4444

45-
@containers.dataclass
45+
@structs.dataclass
4646
class DenseTs1(Linearization):
4747
ode_order: int
4848
ode_shape: tuple
@@ -76,7 +76,7 @@ def constraint(m):
7676
return cond, None
7777

7878

79-
@containers.dataclass
79+
@structs.dataclass
8080
class DenseSlr0(Linearization):
8181
cubature_rule: Any
8282
ode_shape: tuple
@@ -149,7 +149,7 @@ def slr0(self, fn, x):
149149
return _normal.Normal(fx_mean, cov_sqrtm.T)
150150

151151

152-
@containers.dataclass
152+
@structs.dataclass
153153
class DenseSlr1(Linearization):
154154
cubature_rule: Any
155155
ode_shape: tuple
@@ -224,7 +224,7 @@ def slr1(self, fn, x):
224224
return linop_cond, rv_cond
225225

226226

227-
@containers.dataclass
227+
@structs.dataclass
228228
class IsotropicTs0(Linearization):
229229
ode_order: int
230230
unravel: Callable
@@ -253,7 +253,7 @@ def a1(m):
253253
return cond, None
254254

255255

256-
@containers.dataclass
256+
@structs.dataclass
257257
class IsotropicTs1(Linearization):
258258
ode_order: int
259259
unravel: Callable
@@ -305,7 +305,7 @@ def select_0(s):
305305
return cond, key
306306

307307

308-
@containers.dataclass
308+
@structs.dataclass
309309
class BlockDiagTs0(Linearization):
310310
ode_order: int
311311
unravel: Callable
@@ -336,7 +336,7 @@ def a1(s):
336336
return cond, None
337337

338338

339-
@containers.dataclass
339+
@structs.dataclass
340340
class BlockDiagTs1(Linearization):
341341
ode_order: int
342342
unravel: Callable

probdiffeq/impl/_normal.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
from probdiffeq.backend import abc, containers, linalg, np, tree
1+
from probdiffeq.backend import abc, linalg, np, structs, tree
22
from probdiffeq.backend.typing import Generic, Sequence, TypeVar
33

44
T = TypeVar("T")
55

66

77
@tree.register_dataclass
8-
@containers.dataclass
8+
@structs.dataclass
99
class Normal(Generic[T]):
1010
mean: T
1111
cholesky: T

probdiffeq/impl/impl.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
"""State-space model implementations."""
22

3-
from probdiffeq.backend import containers, func, tree
3+
from probdiffeq.backend import func, structs, tree
44
from probdiffeq.backend.typing import Callable
55
from probdiffeq.impl import _conditional, _linearize, _normal, _prototypes, _stats
66

77

8-
@containers.dataclass
8+
@structs.dataclass
99
class FactImpl:
1010
"""Implementation of factorized state-space models."""
1111

probdiffeq/ivpsolve.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
See the tutorials for example use cases.
88
"""
99

10-
from probdiffeq.backend import containers, flow, func, linalg, np, tree, warnings
10+
from probdiffeq.backend import flow, func, linalg, np, structs, tree, warnings
1111
from probdiffeq.backend.typing import Any, Array, Callable, Generic, Protocol, TypeVar
1212

1313
T_contra = TypeVar("T_contra", contravariant=True)
@@ -225,7 +225,7 @@ def advance(sol_and_state: tuple, t_next) -> tuple[tuple, Any]:
225225
"""
226226

227227
@tree.register_dataclass
228-
@containers.dataclass
228+
@structs.dataclass
229229
class AdvanceState:
230230
do_continue: bool
231231
solution: Any
@@ -332,7 +332,7 @@ def dt0_adaptive(vf, initial_values, /, t0, *, error_contraction_rate, rtol, ato
332332

333333

334334
@tree.register_dataclass
335-
@containers.dataclass
335+
@structs.dataclass
336336
class TimeStepState(Generic[T]):
337337
"""A state variable type for adaptive time-stepping."""
338338

@@ -359,7 +359,7 @@ class TimeStepState(Generic[T]):
359359

360360

361361
@tree.register_dataclass
362-
@containers.dataclass
362+
@structs.dataclass
363363
class _RejectionLoopState:
364364
"""State for a single rejection loop.
365365

probdiffeq/probdiffeq.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
See the tutorials for example use cases.
44
"""
55

6-
from probdiffeq.backend import containers, flow, func, linalg, np, random, special, tree
6+
from probdiffeq.backend import flow, func, linalg, np, random, special, structs, tree
77
from probdiffeq.backend.typing import (
88
Any,
99
Array,
@@ -39,7 +39,7 @@ def __call__(self, *ys: T, t: ArrayLike) -> T: ...
3939

4040

4141
@tree.register_dataclass
42-
@containers.dataclass
42+
@structs.dataclass
4343
class CubaturePositiveWeights:
4444
"""A datastructure for cubature rules that have positive weights.
4545
@@ -220,7 +220,7 @@ def constraint_ode_slr1(*, ssm, cubature_fun=cubature_third_order_spherical):
220220

221221

222222
@tree.register_dataclass
223-
@containers.dataclass
223+
@structs.dataclass
224224
class TaylorCoeffTarget(Generic[C, T]):
225225
"""A probabilistic description of Taylor coefficients.
226226
@@ -240,7 +240,7 @@ class TaylorCoeffTarget(Generic[C, T]):
240240

241241

242242
@tree.register_dataclass
243-
@containers.dataclass
243+
@structs.dataclass
244244
class MarkovSequence(Generic[T]):
245245
"""A datastructure for Markov sequences as batches of joint distributions.
246246
@@ -492,7 +492,7 @@ def log_marginal_likelihood(self, u, /, *, standard_deviation, posterior: T):
492492

493493

494494
@tree.register_dataclass
495-
@containers.dataclass
495+
@structs.dataclass
496496
class ProbabilisticSolution(Generic[C, T]):
497497
"""A datastructure for probabilistic solutions of differential equations."""
498498

@@ -807,7 +807,7 @@ def prior_wiener_integrated_discrete(
807807

808808

809809
@tree.register_dataclass
810-
@containers.dataclass
810+
@structs.dataclass
811811
class _InterpRes(Generic[T]):
812812
"""A datastructure to store interpolation results.
813813

probdiffeq/util/filter_util.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
Mostly **discrete** filtering and smoothing.
44
"""
55

6-
from probdiffeq.backend import containers, flow, tree
6+
from probdiffeq.backend import flow, structs, tree
77
from probdiffeq.backend.typing import Any
88

99

@@ -40,7 +40,7 @@ def fixedpointsmoother_precon(*, ssm):
4040
"""Construct a discrete, preconditioned fixedpoint-smoother."""
4141

4242
@tree.register_dataclass
43-
@containers.dataclass
43+
@structs.dataclass
4444
class _FPState:
4545
rv: Any
4646
conditional: Any
@@ -78,7 +78,7 @@ def kalmanfilter_with_marginal_likelihood(*, ssm):
7878
"""Construct a Kalman-filter-implementation of computing the marginal likelihood."""
7979

8080
@tree.register_dataclass
81-
@containers.dataclass
81+
@structs.dataclass
8282
class _KFState:
8383
rv: Any
8484
num_data_points: float

tests/test_ivpsolve/test_fixed_grid_matches_adaptive_grid.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Compare solve_fixed_grid to solve_adaptive_save_every_step."""
22

33
from probdiffeq import ivpsolve, probdiffeq, taylor
4-
from probdiffeq.backend import containers, func, np, ode, testing, tree
4+
from probdiffeq.backend import func, np, ode, structs, testing, tree
55
from probdiffeq.backend.typing import Array
66
from probdiffeq.util import test_util
77

@@ -10,7 +10,7 @@
1010
def test_fixed_grid_result_matches_adaptive_grid_result_when_reusing_grid(fact):
1111
vf, u0, (t0, t1) = ode.ivp_lotka_volterra()
1212

13-
class Taylor(containers.NamedTuple):
13+
class Taylor(structs.NamedTuple):
1414
state: Array
1515
velocity: Array
1616
acceleration: Array

tests/test_ivpsolve/test_solution_api_respects_pytree_structure.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
"""Tests for interaction with the solution API."""
22

33
from probdiffeq import ivpsolve, probdiffeq, taylor
4-
from probdiffeq.backend import containers, func, np, ode, testing
4+
from probdiffeq.backend import func, np, ode, structs, testing
55
from probdiffeq.backend.typing import Array
66

77

8-
class Taylor(containers.NamedTuple):
8+
class Taylor(structs.NamedTuple):
99
"""A non-standard Taylor-coefficient data structure."""
1010

1111
state: Array

0 commit comments

Comments
 (0)