diff --git a/probdiffeq/backend/functools.py b/probdiffeq/backend/functools.py index ef6bccfc5..b74ae7048 100644 --- a/probdiffeq/backend/functools.py +++ b/probdiffeq/backend/functools.py @@ -18,8 +18,10 @@ def jit(func, /, static_argnums=None, static_argnames=None): return jax.jit(func, static_argnums=static_argnums, static_argnames=static_argnames) -def jet(func, /, primals, series): - return jax.experimental.jet.jet(func, primals=primals, series=series) +def jet(func, /, primals, series, *, is_tcoeff=False): + return jax.experimental.jet.jet( + func, primals=primals, series=series, factorial_scaled=not is_tcoeff + ) def linearize(func, *args): diff --git a/probdiffeq/taylor.py b/probdiffeq/taylor.py index fdd9fed73..0b4524592 100644 --- a/probdiffeq/taylor.py +++ b/probdiffeq/taylor.py @@ -275,9 +275,9 @@ def jet_embedded(*c, degree): (compared to unnormalised coefficients). """ coeffs_emb = [*c] + [zeros] * degree - p, *s = _unnormalise(*coeffs_emb) - p_new, s_new = functools.jet(vf, (p,), (s,)) - return _normalise(p_new, *s_new) + p, *s = coeffs_emb + p_new, s_new = functools.jet(vf, (p,), (s,), is_tcoeff=True) + return p_new, *s_new taylor_coefficients = [u0] degrees = list(itertools.accumulate(map(lambda s: 2**s, range(num_doublings))))