Skip to content

Commit 114847f

Browse files
committed
Early validation of jac argument of `JacobianLinearOperator (patrick-kidger#167)
* Early validation of `jac` argument of `JacobianLinearOperator * address nits
1 parent 56c3962 commit 114847f

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

lineax/_operator.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -588,6 +588,11 @@ def __init__(
588588
`jax.jacrev`. Otherwise, if not specified it will be chosen
589589
by default according to input and output shape.
590590
"""
591+
if jac not in [None, "fwd", "bwd"]:
592+
raise ValueError(
593+
"`jac` argument of `JacobianLinearOperator` should be either "
594+
"`'fwd'`, `'bwd'`, or `None`."
595+
)
591596
if not _has_aux:
592597
fn = NoneAux(fn)
593598
# Flush out any closed-over values, so that we can safely pass `self`

0 commit comments

Comments
 (0)