Skip to content

Commit 36587fe

Browse files
authored
Outsource Jacobian handling from linearisations (#847)
* Move ssm._linearize content to ssm._conditional because they operate on the same datastructures * Implement Jacobian handling in a dedicated object to clarify stochastic trace estimation * Separate Hutchinson from materialisation and forward/reverse autodiff * Make all JacobianHandler arguments keyword-args for now
1 parent a890b80 commit 36587fe

7 files changed

Lines changed: 645 additions & 502 deletions

File tree

probdiffeq/backend/func.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@ def linearize(func, *args):
2828
return jax.linearize(func, *args)
2929

3030

31+
def vjp(func, *args):
32+
return jax.vjp(func, *args)
33+
34+
3135
def jvp(func, /, primals, tangents):
3236
return jax.jvp(func, primals, tangents)
3337

probdiffeq/backend/typing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Typing module."""
22

33
from collections.abc import Callable, Sequence
4-
from typing import Any, Generic, Protocol, TypeVar
4+
from typing import Any, Generic, Literal, Protocol, TypeVar
55

66
from jax import Array
77
from jax.typing import ArrayLike

0 commit comments

Comments
 (0)