@@ -426,21 +426,31 @@ def _verify_vector_field_signature_and_parse_order(vf) -> int:
426426 sig = inspect .signature (vf )
427427 params = list (sig .parameters .values ())
428428
429- # Collect positional-only state arguments
430- # TODO: should we allow positional-or-keyword arguments?
431- state_args = [p for p in params if p .kind in (inspect .Parameter .POSITIONAL_ONLY ,)]
429+ POSITIONAL = (
430+ inspect .Parameter .POSITIONAL_ONLY ,
431+ inspect .Parameter .POSITIONAL_OR_KEYWORD ,
432+ )
433+ KEYWORD = (inspect .Parameter .KEYWORD_ONLY ,)
434+
435+ def is_positional (p ):
436+ return p .kind in POSITIONAL
437+
438+ def is_keyword (p ):
439+ return p .kind in KEYWORD
440+
441+ state_args = [p for p in params if is_positional (p )]
432442
433443 msg = f"""The dynamics' signature is not compatible with the constraint.
434444
435445 More precisely, the dynamics are expected to look like
436446
437- - f(u, /, *, t),
438- - f(u, du, /, *, t),
439- - f(u, du, ddu /, *, t),
447+ - f(u, *, t),
448+ - f(u, du, *, t),
449+ - f(u, du, ddu *, t),
440450
441- and so on, where the number of ** positional-only** arguments
442- specifies the order of the problem. (Mind the positional-only
443- and keyword-only arguments in the signatures above.)
451+ and so on, where the number of positional arguments
452+ specifies the order of the problem.
453+ (Mind the keyword-only argument 't' in the signatures above.)
444454
445455 That said, the arguments
446456
@@ -456,15 +466,11 @@ def _verify_vector_field_signature_and_parse_order(vf) -> int:
456466
457467 """
458468
459- # Check for keyword-only 't' (and no other keyword-args)
460- contains_kw_t = any (
461- p .kind == inspect .Parameter .KEYWORD_ONLY and p .name == "t" for p in params
462- )
463- contains_other_kw = any (
464- p .kind == inspect .Parameter .KEYWORD_ONLY and p .name != "t" for p in params
465- )
469+ contains_no_positional = len (state_args ) == 0
470+ t_is_not_keyword = not any (is_keyword (p ) and p .name == "t" for p in params )
471+ contains_keyword_other_than_t = any (is_keyword (p ) and p .name != "t" for p in params )
466472
467- if not state_args or not contains_kw_t or contains_other_kw :
473+ if contains_no_positional or t_is_not_keyword or contains_keyword_other_than_t :
468474 raise TypeError (msg )
469475
470476 return len (state_args )
0 commit comments