Skip to content

Commit d018135

Browse files
committed
Reenable positional-or-keyword arguments in vector fields
1 parent d1c2b63 commit d018135

2 files changed

Lines changed: 24 additions & 18 deletions

File tree

probdiffeq/backend/ode.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def ivp_lotka_volterra():
2626

2727
# Dictionary to ensure pytree compatibility
2828
@jax.jit
29-
def vf(x, /, *, t): # noqa: ARG001
29+
def vf(x, *, t): # noqa: ARG001
3030
return {"u": f(x["u"], *f_args)}
3131

3232
return vf, ({"u": u0},), (t0, t1)

probdiffeq/probdiffeq.py

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -426,21 +426,31 @@ def _verify_vector_field_signature_and_parse_order(vf) -> int:
426426
sig = inspect.signature(vf)
427427
params = list(sig.parameters.values())
428428

429-
# Collect positional-only state arguments
430-
# TODO: should we allow positional-or-keyword arguments?
431-
state_args = [p for p in params if p.kind in (inspect.Parameter.POSITIONAL_ONLY,)]
429+
POSITIONAL = (
430+
inspect.Parameter.POSITIONAL_ONLY,
431+
inspect.Parameter.POSITIONAL_OR_KEYWORD,
432+
)
433+
KEYWORD = (inspect.Parameter.KEYWORD_ONLY,)
434+
435+
def is_positional(p):
436+
return p.kind in POSITIONAL
437+
438+
def is_keyword(p):
439+
return p.kind in KEYWORD
440+
441+
state_args = [p for p in params if is_positional(p)]
432442

433443
msg = f"""The dynamics' signature is not compatible with the constraint.
434444
435445
More precisely, the dynamics are expected to look like
436446
437-
- f(u, /, *, t),
438-
- f(u, du, /, *, t),
439-
- f(u, du, ddu /, *, t),
447+
- f(u, *, t),
448+
- f(u, du, *, t),
449+
- f(u, du, ddu *, t),
440450
441-
and so on, where the number of **positional-only** arguments
442-
specifies the order of the problem. (Mind the positional-only
443-
and keyword-only arguments in the signatures above.)
451+
and so on, where the number of positional arguments
452+
specifies the order of the problem.
453+
(Mind the keyword-only argument 't' in the signatures above.)
444454
445455
That said, the arguments
446456
@@ -456,15 +466,11 @@ def _verify_vector_field_signature_and_parse_order(vf) -> int:
456466
457467
"""
458468

459-
# Check for keyword-only 't' (and no other keyword-args)
460-
contains_kw_t = any(
461-
p.kind == inspect.Parameter.KEYWORD_ONLY and p.name == "t" for p in params
462-
)
463-
contains_other_kw = any(
464-
p.kind == inspect.Parameter.KEYWORD_ONLY and p.name != "t" for p in params
465-
)
469+
contains_no_positional = len(state_args) == 0
470+
t_is_not_keyword = not any(is_keyword(p) and p.name == "t" for p in params)
471+
contains_keyword_other_than_t = any(is_keyword(p) and p.name != "t" for p in params)
466472

467-
if not state_args or not contains_kw_t or contains_other_kw:
473+
if contains_no_positional or t_is_not_keyword or contains_keyword_other_than_t:
468474
raise TypeError(msg)
469475

470476
return len(state_args)

0 commit comments

Comments
 (0)