Skip to content

Commit 6cc025f

Browse files
authored
Remove Taylor-coefficient-scaling in doubling (#838)
1 parent 75b69d9 commit 6cc025f

2 files changed

Lines changed: 7 additions & 5 deletions

File tree

probdiffeq/backend/functools.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,10 @@ def jit(func, /, static_argnums=None, static_argnames=None):
1818
return jax.jit(func, static_argnums=static_argnums, static_argnames=static_argnames)
1919

2020

21-
def jet(func, /, primals, series):
22-
return jax.experimental.jet.jet(func, primals=primals, series=series)
21+
def jet(func, /, primals, series, *, is_tcoeff=False):
22+
return jax.experimental.jet.jet(
23+
func, primals=primals, series=series, factorial_scaled=not is_tcoeff
24+
)
2325

2426

2527
def linearize(func, *args):

probdiffeq/taylor.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -275,9 +275,9 @@ def jet_embedded(*c, degree):
275275
(compared to unnormalised coefficients).
276276
"""
277277
coeffs_emb = [*c] + [zeros] * degree
278-
p, *s = _unnormalise(*coeffs_emb)
279-
p_new, s_new = functools.jet(vf, (p,), (s,))
280-
return _normalise(p_new, *s_new)
278+
p, *s = coeffs_emb
279+
p_new, s_new = functools.jet(vf, (p,), (s,), is_tcoeff=True)
280+
return p_new, *s_new
281281

282282
taylor_coefficients = [u0]
283283
degrees = list(itertools.accumulate(map(lambda s: 2**s, range(num_doublings))))

0 commit comments

Comments
 (0)