Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 50 additions & 14 deletions lineax/_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,15 +543,20 @@ class JacobianLinearOperator(AbstractLinearOperator):
`MatrixLinearOperator(jax.jacfwd(fn)(x))`.

The Jacobian is not materialised; matrix-vector products, which are in fact
Jacobian-vector products, are computed using autodifferentiation, specifically
`jax.jvp`. Thus, `JacobianLinearOperator(fn, x).mv(v)` is equivalent to
`jax.jvp(fn, (x,), (v,))`.

See also [`lineax.linearise`][], which caches the primal computation, i.e.
it returns `_, lin = jax.linearize(fn, x); FunctionLinearOperator(lin, ...)`
Jacobian-vector products, are computed using autodifferentiation. By default
(or with `jac="fwd"`), this uses `jax.jvp`. With `jac="bwd"`, this uses
`jax.vjp` combined with `jax.linear_transpose`, which works even with functions
that only define a custom VJP (via `jax.custom_vjp`) and don't support
forward-mode differentiation.

See also [`lineax.materialise`][], which materialises the whole Jacobian in
memory.

!!! tip

For repeated `mv()` calls, consider using [`lineax.linearise`][] to cache
the primal computation. This is especially beneficial with `jac="bwd"`
as the primal computation affects the entire backward pass.
"""

fn: Callable[
Expand Down Expand Up @@ -618,10 +623,18 @@ def mv(self, vector):
if self.jac == "fwd" or self.jac is None:
_, out = jax.jvp(fn, (self.x,), (vector,))
elif self.jac == "bwd":
jac = jax.jacrev(fn)(self.x)
out = PyTreeLinearOperator(jac, output_structure=self.out_structure()).mv(
vector
)
# Use VJP + linear_transpose instead of materializing full Jacobian.
# This works even for custom_vjp functions that don't have JVP rules.
_, vjp_fn = jax.vjp(fn, self.x)
if symmetric_tag in self.tags:
# For symmetric operators, J = J.T, so vjp directly gives J @ v
(out,) = vjp_fn(vector)
else:
# For non-symmetric, transpose the VJP to get J @ v from J.T @ v
transpose_vjp = jax.linear_transpose(
lambda g: vjp_fn(g)[0], self.out_structure()
)
(out,) = transpose_vjp(vector)
else:
raise ValueError("`jac` should be either `'fwd'`, `'bwd'`, or `None`.")
return out
Expand Down Expand Up @@ -1238,10 +1251,33 @@ def _(operator):
@linearise.register(JacobianLinearOperator)
def _(operator):
fn = _NoAuxIn(operator.fn, operator.args)
(_, aux), lin = jax.linearize(fn, operator.x)
lin = _NoAuxOut(lin)
out = FunctionLinearOperator(lin, operator.in_structure(), operator.tags)
return AuxLinearOperator(out, aux)
if operator.jac == "bwd":
# For backward mode, use VJP + linear_transpose.
# This works with custom_vjp functions that don't support forward-mode.
_, vjp_fn, aux = jax.vjp(fn, operator.x, has_aux=True)
if symmetric_tag in operator.tags:
# For symmetric: J = J.T, so vjp directly gives J @ v
out = FunctionLinearOperator(
_Unwrap(vjp_fn), operator.in_structure(), operator.tags
)
else:
# Transpose the VJP to get J @ v from J.T @ v
transpose_vjp = jax.linear_transpose(
lambda g: vjp_fn(g)[0], operator.out_structure()
)

def mv_fn(v):
(out,) = transpose_vjp(v)
return out

out = FunctionLinearOperator(mv_fn, operator.in_structure(), operator.tags)
return AuxLinearOperator(out, aux)
else:
# Original implementation for fwd/None
(_, aux), lin = jax.linearize(fn, operator.x)
lin = _NoAuxOut(lin)
out = FunctionLinearOperator(lin, operator.in_structure(), operator.tags)
return AuxLinearOperator(out, aux)


# materialise
Expand Down
50 changes: 50 additions & 0 deletions tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,56 @@ def make_jac_operator(getkey, matrix, tags):
return lx.JacobianLinearOperator(fn, x, None, tags)


@_operators_append
def make_jacfwd_operator(getkey, matrix, tags):
    """Build a `JacobianLinearOperator` with `jac="fwd"` whose Jacobian at the
    sampled point equals `matrix` exactly.

    A random quadratic function is constructed, its Jacobian at `x` is
    measured with `jax.jacfwd`, and the linear coefficient is then shifted by
    the discrepancy so the final function's Jacobian matches `matrix`.
    """
    out_size, in_size = matrix.shape
    x = jr.normal(getkey(), (in_size,), dtype=matrix.dtype)
    offset = jr.normal(getkey(), (out_size,), dtype=matrix.dtype)
    linear = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype)
    quadratic = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype)

    def candidate(y, _):
        return offset + linear @ y + quadratic @ y**2.0

    # `holomorphic` is required for jacfwd when `x` is complex.
    measured = jax.jacfwd(candidate, holomorphic=jnp.iscomplexobj(x))(x, None)
    correction = matrix - measured

    def fn(y, _):
        # Shifting the linear term by `correction` makes the Jacobian at `x`
        # equal `matrix` while keeping the function genuinely nonlinear.
        return offset + (linear + correction) @ y + quadratic @ y**2

    return lx.JacobianLinearOperator(fn, x, None, tags, jac="fwd")


@_operators_append
def make_jacrev_operator(getkey, matrix, tags):
    """JacobianLinearOperator with jac='bwd' using a custom_vjp function.

    This uses custom_vjp so that forward-mode autodiff is NOT available,
    which tests that jac='bwd' works correctly without relying on JVP.

    The function is constructed so that its Jacobian at `x` is exactly
    `matrix`: a random quadratic is sampled, its Jacobian at `x` is measured,
    and the linear term is shifted by the discrepancy.
    """
    out_size, in_size = matrix.shape
    x = jr.normal(getkey(), (in_size,), dtype=matrix.dtype)
    a = jr.normal(getkey(), (out_size,), dtype=matrix.dtype)
    b = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype)
    c = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype)
    fn_tmp = lambda x, _: a + b @ x + c @ x**2.0
    # `holomorphic` is required for jacfwd when `x` is complex.
    jac = jax.jacfwd(fn_tmp, holomorphic=jnp.iscomplexobj(x))(x, None)
    diff = matrix - jac

    # Use custom_vjp to define a function that only has reverse-mode autodiff
    @jax.custom_vjp
    def custom_fn(x):
        return a + (b + diff) @ x + c @ x**2

    def custom_fn_fwd(x):
        # Residual saved for the backward pass is the primal point `x`.
        return custom_fn(x), x

    def custom_fn_bwd(x, g):
        # The Jacobian of custom_fn is
        #   J[i, j] = (b + diff)[i, j] + 2 * c[i, j] * x[j],
        # i.e. J = (b + diff) + 2 * c * x with `x` broadcast along columns.
        # The VJP must return the cotangent J.T @ g, which is
        #   (b + diff).T @ g + 2 * x * (c.T @ g).
        return ((b + diff).T @ g + 2 * (c.T @ g) * x,)

    custom_fn.defvjp(custom_fn_fwd, custom_fn_bwd)

    # Adapt to the (x, args) signature expected by JacobianLinearOperator.
    fn = lambda x, _: custom_fn(x)
    return lx.JacobianLinearOperator(fn, x, None, tags, jac="bwd")


@_operators_append
def make_trivial_diagonal_operator(getkey, matrix, tags):
assert tags == lx.diagonal_tag
Expand Down
5 changes: 5 additions & 0 deletions tests/test_adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from .helpers import (
make_identity_operator,
make_jacrev_operator,
make_operators,
make_tridiagonal_operator,
make_trivial_diagonal_operator,
Expand All @@ -33,6 +34,10 @@ def test_adjoint(make_operator, dtype, getkey):
tags = ()
in_size = 5
out_size = 3
if make_operator is make_jacrev_operator and dtype is jnp.complex128:
pytest.skip(
'JacobianLinearOperator does not support complex dtypes when jac="bwd"'
)
operator = make_operator(getkey, matrix, tags)
v1, v2 = (
jr.normal(getkey(), (in_size,), dtype=dtype),
Expand Down
38 changes: 31 additions & 7 deletions tests/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from .helpers import (
make_identity_operator,
make_jacrev_operator,
make_operators,
make_tridiagonal_operator,
make_trivial_diagonal_operator,
Expand All @@ -45,6 +46,10 @@ def test_ops(make_operator, getkey, dtype):
else:
matrix = jr.normal(getkey(), (3, 3), dtype=dtype)
tags = ()
if make_operator is make_jacrev_operator and dtype is jnp.complex128:
pytest.skip(
'JacobianLinearOperator does not support complex dtypes when jac="bwd"'
)
matrix1 = make_operator(getkey, matrix, tags)
matrix2 = lx.MatrixLinearOperator(jr.normal(getkey(), (3, 3), dtype=dtype))
scalar = jr.normal(getkey(), (), dtype=dtype)
Expand Down Expand Up @@ -137,9 +142,22 @@ def _assert_except_diag(cond_fun, operators, flip_cond):

@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128))
def test_linearise(dtype, getkey):
operators = _setup(getkey, jr.normal(getkey(), (3, 3), dtype=dtype))
matrix = jr.normal(getkey(), (3, 3), dtype=dtype)
operators = list(_setup(getkey, matrix))
vec = jr.normal(getkey(), (3,), dtype=dtype)
for operator in operators:
lx.linearise(operator)
# Skip jacrev operators with complex dtype (jacrev doesn't support complex)
if (
isinstance(operator, lx.JacobianLinearOperator)
and operator.jac == "bwd"
and dtype is jnp.complex128
):
continue
linearised = lx.linearise(operator)
# Actually evaluate the linearised operator to ensure it works
result = linearised.mv(vec)
expected = operator.mv(vec)
assert tree_allclose(result, expected)


@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128))
Expand Down Expand Up @@ -421,25 +439,31 @@ def test_zero_pytree_as_matrix(dtype):


def test_jacrev_operator():
# Test that custom_vjp is respected. The custom backward multiplies by 3
# instead of the true derivative (which would be 2).
# This tests that lineax uses the custom_vjp, not the true derivative.
@jax.custom_vjp
def f(x, _):
return dict(foo=x["bar"] + 2)
return dict(foo=x["bar"] * 2) # forward: multiply by 2

def f_fwd(x, _):
return f(x, None), None

def f_bwd(_, g):
return dict(bar=g["foo"] + 5), None
# Custom backward: multiply by 3 (not the true derivative 2)
# This must be linear in g for linear_transpose to work correctly.
return dict(bar=g["foo"] * 3), None

f.defvjp(f_fwd, f_bwd)

x = dict(bar=jnp.arange(2.0))
rev_op = lx.JacobianLinearOperator(f, x, jac="bwd")
as_matrix = jnp.array([[6.0, 5.0], [5.0, 6.0]])
# Jacobian is 3*I (from custom backward, not 2*I from true derivative)
as_matrix = jnp.array([[3.0, 0.0], [0.0, 3.0]])
assert tree_allclose(rev_op.as_matrix(), as_matrix)

y = dict(bar=jnp.arange(2.0) + 1)
true_out = dict(foo=jnp.array([16.0, 17.0]))
y = dict(bar=jnp.arange(2.0) + 1) # y = [1, 2]
true_out = dict(foo=jnp.array([3.0, 6.0])) # 3*I @ [1, 2] = [3, 6]
for op in (rev_op, lx.materialise(rev_op)):
out = op.mv(y)
assert tree_allclose(out, true_out)
Expand Down
5 changes: 5 additions & 0 deletions tests/test_well_posed.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from .helpers import (
construct_matrix,
make_jacrev_operator,
ops,
params,
solvers,
Expand All @@ -31,6 +32,10 @@
@pytest.mark.parametrize("ops", ops)
@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128))
def test_small_wellposed(make_operator, solver, tags, ops, getkey, dtype):
if make_operator is make_jacrev_operator and dtype is jnp.complex128:
pytest.skip(
'JacobianLinearOperator does not support complex dtypes when jac="bwd"'
)
if jax.config.jax_enable_x64: # pyright: ignore
tol = 1e-10
else:
Expand Down
Loading