1- from probdiffeq .backend import abc , containers , func , linalg , np , random , tree
1+ from probdiffeq .backend import abc , func , linalg , np , random , structs , tree
22from probdiffeq .backend .typing import Any , Callable
33from probdiffeq .impl import _conditional , _normal
44from probdiffeq .util import cholesky_util
@@ -14,7 +14,7 @@ def linearize(self, fun, rv, state: None, *, damp: float):
1414 raise NotImplementedError
1515
1616
17- @containers .dataclass
17+ @structs .dataclass
1818class DenseTs0 (Linearization ):
1919 ode_order : int
2020 ode_shape : tuple
@@ -42,7 +42,7 @@ def a1(m):
4242 return cond , None
4343
4444
45- @containers .dataclass
45+ @structs .dataclass
4646class DenseTs1 (Linearization ):
4747 ode_order : int
4848 ode_shape : tuple
@@ -76,7 +76,7 @@ def constraint(m):
7676 return cond , None
7777
7878
79- @containers .dataclass
79+ @structs .dataclass
8080class DenseSlr0 (Linearization ):
8181 cubature_rule : Any
8282 ode_shape : tuple
@@ -149,7 +149,7 @@ def slr0(self, fn, x):
149149 return _normal .Normal (fx_mean , cov_sqrtm .T )
150150
151151
152- @containers .dataclass
152+ @structs .dataclass
153153class DenseSlr1 (Linearization ):
154154 cubature_rule : Any
155155 ode_shape : tuple
@@ -224,7 +224,7 @@ def slr1(self, fn, x):
224224 return linop_cond , rv_cond
225225
226226
227- @containers .dataclass
227+ @structs .dataclass
228228class IsotropicTs0 (Linearization ):
229229 ode_order : int
230230 unravel : Callable
@@ -253,7 +253,7 @@ def a1(m):
253253 return cond , None
254254
255255
256- @containers .dataclass
256+ @structs .dataclass
257257class IsotropicTs1 (Linearization ):
258258 ode_order : int
259259 unravel : Callable
@@ -305,7 +305,7 @@ def select_0(s):
305305 return cond , key
306306
307307
308- @containers .dataclass
308+ @structs .dataclass
309309class BlockDiagTs0 (Linearization ):
310310 ode_order : int
311311 unravel : Callable
@@ -336,7 +336,7 @@ def a1(s):
336336 return cond , None
337337
338338
339- @containers .dataclass
339+ @structs .dataclass
340340class BlockDiagTs1 (Linearization ):
341341 ode_order : int
342342 unravel : Callable
0 commit comments