Skip to content

Commit c718ff8

Browse files
authored
fix linearise for JacobianLinearOperator with jac=bwd and use linear_transpose in mv (#191)
* add failing jacrev custom vjp linearise tests * fix broken JacobianLinearOperator linearise * optimise mv * return early rather than skip jacrev complex tests * symmetric_tag in self.tags -> is_symmetric(self) * use _Unwrap and keep more examples from old docs * skip jacrev for test_tangent_as_matrix
1 parent a61917e commit c718ff8

File tree

5 files changed

+137
-24
lines changed

5 files changed

+137
-24
lines changed

lineax/_operator.py

Lines changed: 43 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ def as_matrix(self):
268268
return self.matrix
269269

270270
def transpose(self):
271-
if symmetric_tag in self.tags:
271+
if is_symmetric(self):
272272
return self
273273
return MatrixLinearOperator(self.matrix.T, transpose_tags(self.tags))
274274

@@ -447,7 +447,7 @@ def concat_in(struct, subpytree):
447447
return jnp.concatenate(matrix, axis=0)
448448

449449
def transpose(self):
450-
if symmetric_tag in self.tags:
450+
if is_symmetric(self):
451451
return self
452452

453453
def _transpose(struct, subtree):
@@ -544,15 +544,21 @@ class JacobianLinearOperator(AbstractLinearOperator):
544544
`MatrixLinearOperator(jax.jacfwd(fn)(x))`.
545545
546546
The Jacobian is not materialised; matrix-vector products, which are in fact
547-
Jacobian-vector products, are computed using autodifferentiation, specifically
548-
`jax.jvp`. Thus, `JacobianLinearOperator(fn, x).mv(v)` is equivalent to
549-
`jax.jvp(fn, (x,), (v,))`.
550-
551-
See also [`lineax.linearise`][], which caches the primal computation, i.e.
552-
it returns `_, lin = jax.linearize(fn, x); FunctionLinearOperator(lin, ...)`
547+
Jacobian-vector products, are computed using autodifferentiation. By default
548+
(or with `jac="fwd"`), `JacobianLinearOperator(fn, x).mv(v)` is equivalent to
549+
`jax.jvp(fn, (x,), (v,))`. For `jac="bwd"`, `jax.vjp` is combined with
550+
`jax.linear_transpose`, which works even with functions
551+
that only define a custom VJP (via `jax.custom_vjp`) and don't support
552+
forward-mode differentiation.
553553
554554
See also [`lineax.materialise`][], which materialises the whole Jacobian in
555555
memory.
556+
557+
!!! tip
558+
559+
For repeated `mv()` calls, consider using [`lineax.linearise`][] to cache
560+
the primal computation, e.g. for `jac="fwd"/None` it returns
561+
`_, lin = jax.linearize(fn, x); FunctionLinearOperator(lin, ...)`
556562
"""
557563

558564
fn: Callable[
@@ -619,10 +625,18 @@ def mv(self, vector):
619625
if self.jac == "fwd" or self.jac is None:
620626
_, out = jax.jvp(fn, (self.x,), (vector,))
621627
elif self.jac == "bwd":
622-
jac = jax.jacrev(fn)(self.x)
623-
out = PyTreeLinearOperator(jac, output_structure=self.out_structure()).mv(
624-
vector
625-
)
628+
# Use VJP + linear_transpose instead of materializing full Jacobian.
629+
# This works even for custom_vjp functions that don't have JVP rules.
630+
_, vjp_fn = jax.vjp(fn, self.x)
631+
if is_symmetric(self):
632+
# For symmetric operators, J = J.T, so vjp directly gives J @ v
633+
(out,) = vjp_fn(vector)
634+
else:
635+
# For non-symmetric, transpose the VJP to get J @ v from J.T @ v
636+
transpose_vjp = jax.linear_transpose(
637+
lambda g: vjp_fn(g)[0], self.out_structure()
638+
)
639+
(out,) = transpose_vjp(vector)
626640
else:
627641
raise ValueError("`jac` should be either `'fwd'`, `'bwd'`, or `None`.")
628642
return out
@@ -631,7 +645,7 @@ def as_matrix(self):
631645
return materialise(self).as_matrix()
632646

633647
def transpose(self):
634-
if symmetric_tag in self.tags:
648+
if is_symmetric(self):
635649
return self
636650
fn = _NoAuxOut(_NoAuxIn(self.fn, self.args))
637651
# Works because vjpfn is a PyTree
@@ -698,7 +712,7 @@ def as_matrix(self):
698712
return materialise(self).as_matrix()
699713

700714
def transpose(self):
701-
if symmetric_tag in self.tags:
715+
if is_symmetric(self):
702716
return self
703717
transpose_fn = jax.linear_transpose(self.fn, self.in_structure())
704718

@@ -1239,8 +1253,21 @@ def _(operator):
12391253
@linearise.register(JacobianLinearOperator)
12401254
def _(operator):
12411255
fn = _NoAuxIn(operator.fn, operator.args)
1242-
(_, aux), lin = jax.linearize(fn, operator.x)
1243-
lin = _NoAuxOut(lin)
1256+
if operator.jac == "bwd":
1257+
# For backward mode, use VJP + linear_transpose.
1258+
# This works even with custom_vjp functions that don't support forward-mode AD.
1259+
_, vjp_fn, aux = jax.vjp(fn, operator.x, has_aux=True)
1260+
if is_symmetric(operator):
1261+
# For symmetric: J = J.T, so vjp directly gives J @ v
1262+
lin = _Unwrap(vjp_fn())
1263+
else:
1264+
# Transpose the VJP to get J @ v from J.T @ v
1265+
lin = _Unwrap(
1266+
jax.linear_transpose(lambda g: vjp_fn(g)[0], operator.out_structure())
1267+
)
1268+
else: # "fwd" or None
1269+
(_, aux), lin = jax.linearize(fn, operator.x)
1270+
lin = _NoAuxOut(lin)
12441271
out = FunctionLinearOperator(lin, operator.in_structure(), operator.tags)
12451272
return AuxLinearOperator(out, aux)
12461273

tests/helpers.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,56 @@ def make_jac_operator(getkey, matrix, tags):
209209
return lx.JacobianLinearOperator(fn, x, None, tags)
210210

211211

212+
@_operators_append
213+
def make_jacfwd_operator(getkey, matrix, tags):
214+
out_size, in_size = matrix.shape
215+
x = jr.normal(getkey(), (in_size,), dtype=matrix.dtype)
216+
a = jr.normal(getkey(), (out_size,), dtype=matrix.dtype)
217+
b = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype)
218+
c = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype)
219+
fn_tmp = lambda x, _: a + b @ x + c @ x**2.0
220+
jac = jax.jacfwd(fn_tmp, holomorphic=jnp.iscomplexobj(x))(x, None)
221+
diff = matrix - jac
222+
fn = lambda x, _: a + (b + diff) @ x + c @ x**2
223+
return lx.JacobianLinearOperator(fn, x, None, tags, jac="fwd")
224+
225+
226+
@_operators_append
227+
def make_jacrev_operator(getkey, matrix, tags):
228+
"""JacobianLinearOperator with jac='bwd' using a custom_vjp function.
229+
230+
This uses custom_vjp so that forward-mode autodiff is NOT available,
231+
which tests that jac='bwd' works correctly without relying on JVP.
232+
"""
233+
out_size, in_size = matrix.shape
234+
x = jr.normal(getkey(), (in_size,), dtype=matrix.dtype)
235+
a = jr.normal(getkey(), (out_size,), dtype=matrix.dtype)
236+
b = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype)
237+
c = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype)
238+
fn_tmp = lambda x, _: a + b @ x + c @ x**2.0
239+
jac = jax.jacfwd(fn_tmp, holomorphic=jnp.iscomplexobj(x))(x, None)
240+
diff = matrix - jac
241+
242+
# Use custom_vjp to define a function that only has reverse-mode autodiff
243+
@jax.custom_vjp
244+
def custom_fn(x):
245+
return a + (b + diff) @ x + c @ x**2
246+
247+
def custom_fn_fwd(x):
248+
return custom_fn(x), x
249+
250+
def custom_fn_bwd(x, g):
251+
# Jacobian is: (b + diff) + 2 * c * x
252+
# VJP is: g @ J = g @ ((b + diff) + 2 * c * x)
253+
# So J.T @ g =
254+
return ((b + diff).T @ g + 2 * (c.T @ g) * x,)
255+
256+
custom_fn.defvjp(custom_fn_fwd, custom_fn_bwd)
257+
258+
fn = lambda x, _: custom_fn(x)
259+
return lx.JacobianLinearOperator(fn, x, None, tags, jac="bwd")
260+
261+
212262
@_operators_append
213263
def make_trivial_diagonal_operator(getkey, matrix, tags):
214264
assert tags == lx.diagonal_tag

tests/test_adjoint.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from .helpers import (
99
make_identity_operator,
10+
make_jacrev_operator,
1011
make_operators,
1112
make_tridiagonal_operator,
1213
make_trivial_diagonal_operator,
@@ -33,6 +34,9 @@ def test_adjoint(make_operator, dtype, getkey):
3334
tags = ()
3435
in_size = 5
3536
out_size = 3
37+
if make_operator is make_jacrev_operator and dtype is jnp.complex128:
38+
# JacobianLinearOperator does not support complex dtypes when jac="bwd"
39+
return
3640
operator = make_operator(getkey, matrix, tags)
3741
v1, v2 = (
3842
jr.normal(getkey(), (in_size,), dtype=dtype),

tests/test_operator.py

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
from .helpers import (
2525
make_identity_operator,
26+
make_jacrev_operator,
2627
make_operators,
2728
make_tridiagonal_operator,
2829
make_trivial_diagonal_operator,
@@ -45,6 +46,9 @@ def test_ops(make_operator, getkey, dtype):
4546
else:
4647
matrix = jr.normal(getkey(), (3, 3), dtype=dtype)
4748
tags = ()
49+
if make_operator is make_jacrev_operator and dtype is jnp.complex128:
50+
# JacobianLinearOperator does not support complex dtypes when jac="bwd"
51+
return
4852
matrix1 = make_operator(getkey, matrix, tags)
4953
matrix2 = lx.MatrixLinearOperator(jr.normal(getkey(), (3, 3), dtype=dtype))
5054
scalar = jr.normal(getkey(), (), dtype=dtype)
@@ -137,9 +141,22 @@ def _assert_except_diag(cond_fun, operators, flip_cond):
137141

138142
@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128))
139143
def test_linearise(dtype, getkey):
140-
operators = _setup(getkey, jr.normal(getkey(), (3, 3), dtype=dtype))
144+
matrix = jr.normal(getkey(), (3, 3), dtype=dtype)
145+
operators = list(_setup(getkey, matrix))
146+
vec = jr.normal(getkey(), (3,), dtype=dtype)
141147
for operator in operators:
142-
lx.linearise(operator)
148+
# Skip jacrev operators with complex dtype (jacrev doesn't support complex)
149+
if (
150+
isinstance(operator, lx.JacobianLinearOperator)
151+
and operator.jac == "bwd"
152+
and dtype is jnp.complex128
153+
):
154+
continue
155+
linearised = lx.linearise(operator)
156+
# Actually evaluate the linearised operator to ensure it works
157+
result = linearised.mv(vec)
158+
expected = operator.mv(vec)
159+
assert tree_allclose(result, expected)
143160

144161

145162
@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128))
@@ -283,7 +300,12 @@ def test_is_tridiagonal(dtype, getkey):
283300
@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128))
284301
def test_tangent_as_matrix(dtype, getkey):
285302
def _list_setup(matrix):
286-
return list(_setup(getkey, matrix))
303+
# Exclude jacrev operator: jac="bwd" uses custom_vjp which doesn't support JVP
304+
return [
305+
op
306+
for op in _setup(getkey, matrix)
307+
if not (isinstance(op, lx.JacobianLinearOperator) and op.jac == "bwd")
308+
]
287309

288310
matrix = jr.normal(getkey(), (3, 3), dtype=dtype)
289311
t_matrix = jr.normal(getkey(), (3, 3), dtype=dtype)
@@ -421,25 +443,31 @@ def test_zero_pytree_as_matrix(dtype):
421443

422444

423445
def test_jacrev_operator():
446+
# Test that custom_vjp is respected. The custom backward multiplies by 3
447+
# instead of the true derivative (which would be 2).
448+
# This tests that lineax uses the custom_vjp, not the true derivative.
424449
@jax.custom_vjp
425450
def f(x, _):
426-
return dict(foo=x["bar"] + 2)
451+
return dict(foo=x["bar"] * 2) # forward: multiply by 2
427452

428453
def f_fwd(x, _):
429454
return f(x, None), None
430455

431456
def f_bwd(_, g):
432-
return dict(bar=g["foo"] + 5), None
457+
# Custom backward: multiply by 3 (not the true derivative 2)
458+
# This must be linear in g for linear_transpose to work correctly.
459+
return dict(bar=g["foo"] * 3), None
433460

434461
f.defvjp(f_fwd, f_bwd)
435462

436463
x = dict(bar=jnp.arange(2.0))
437464
rev_op = lx.JacobianLinearOperator(f, x, jac="bwd")
438-
as_matrix = jnp.array([[6.0, 5.0], [5.0, 6.0]])
465+
# Jacobian is 3*I (from custom backward, not 2*I from true derivative)
466+
as_matrix = jnp.array([[3.0, 0.0], [0.0, 3.0]])
439467
assert tree_allclose(rev_op.as_matrix(), as_matrix)
440468

441-
y = dict(bar=jnp.arange(2.0) + 1)
442-
true_out = dict(foo=jnp.array([16.0, 17.0]))
469+
y = dict(bar=jnp.arange(2.0) + 1) # y = [1, 2]
470+
true_out = dict(foo=jnp.array([3.0, 6.0])) # 3*I @ [1, 2] = [3, 6]
443471
for op in (rev_op, lx.materialise(rev_op)):
444472
out = op.mv(y)
445473
assert tree_allclose(out, true_out)

tests/test_well_posed.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
from .helpers import (
2222
construct_matrix,
23+
make_jacrev_operator,
2324
ops,
2425
params,
2526
solvers,
@@ -31,6 +32,9 @@
3132
@pytest.mark.parametrize("ops", ops)
3233
@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128))
3334
def test_small_wellposed(make_operator, solver, tags, ops, getkey, dtype):
35+
if make_operator is make_jacrev_operator and dtype is jnp.complex128:
36+
# JacobianLinearOperator does not support complex dtypes when jac="bwd"
37+
return
3438
if jax.config.jax_enable_x64: # pyright: ignore
3539
tol = 1e-10
3640
else:

0 commit comments

Comments
 (0)