Skip to content

Resolve issue of qjit(static_argnums=...) fails when the marked static argument has a default value #1295

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
3 changes: 2 additions & 1 deletion frontend/catalyst/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,7 +642,8 @@ def capture(self, args, **kwargs):
PyTreeDef: PyTree metadata of the function output
Tuple[Any]: the dynamic argument signature
"""
verify_static_argnums(args, self.compile_options.static_argnums)

verify_static_argnums(args, self.original_function, self.compile_options.static_argnums)
static_argnums = self.compile_options.static_argnums
abstracted_axes = self.compile_options.abstracted_axes

Expand Down
12 changes: 9 additions & 3 deletions frontend/catalyst/tracing/type_signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,21 +104,27 @@ def verify_static_argnums_type(static_argnums):
return None


def verify_static_argnums(args, static_argnums):
def verify_static_argnums(args, fn, static_argnums):
"""Verify that static_argnums have correct type and range.

Args:
args (Iterable): arguments to a compiled function
fn (Callable): the quantum or classical function in question
static_argnums (Iterable[int]): indices to verify

Returns:
None
"""
verify_static_argnums_type(static_argnums)

# use inspect to get parameters defined in the function declaration
sig_args = inspect.signature(fn).parameters

# `static_argnums` should be compared to the maximum args that can be passed to a function
arg_limit = max(len(args), len(sig_args))
for argnum in static_argnums:
if argnum < 0 or argnum >= len(args):
msg = f"argnum {argnum} is beyond the valid range of [0, {len(args)})."
if argnum < 0 or argnum >= arg_limit:
msg = f"argnum {argnum} is beyond the valid range of [0, {arg_limit})."
raise CompileError(msg)
return None

Expand Down
13 changes: 13 additions & 0 deletions frontend/test/pytest/test_static_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,19 @@ def f(x: MyClass, y: int, z: MyClass):
assert f(MyClass(5), 2, MyClass(5)) == 12
assert function == f.compiled_function

def test_default_static_arguments(self):
"""Test QJIT with static arguments that have a default value."""

@qjit(static_argnums=[1])
def f(y, x=9):
if x < 10:
return x + y
return 42000

assert f(20) == 29
assert f(20, 3) == 23
assert f(20, 300000) == 42000

def test_mutable_static_arguments(self):
"""Test QJIT with mutable static arguments."""

Expand Down
Loading