Skip to content

Commit 1a6df92

Browse files
github-actions[bot]mehrdad2merick-xanadurauletorrescsengthai
authored
Daily rc sync to main (#1647)
Automatic sync from the release candidate to main during a feature freeze. --------- Co-authored-by: Mehrdad Malek <[email protected]> Co-authored-by: erick-xanadu <[email protected]> Co-authored-by: Raul Torres <[email protected]> Co-authored-by: Sengthai Heng <[email protected]> Co-authored-by: David Ittah <[email protected]> Co-authored-by: Paul <[email protected]> Co-authored-by: Joey Carter <[email protected]> Co-authored-by: GitHub Actions Bot <> Co-authored-by: Mehrdad Malekmohammadi <[email protected]>
1 parent 6a6a94f commit 1a6df92

File tree

6 files changed

+139
-24
lines changed

6 files changed

+139
-24
lines changed

doc/releases/changelog-0.11.0.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,7 @@
223223

224224
* Added `qjit.mlir_opt` property to access the optimized MLIR representation of a compiled function.
225225
[(#1579)](https://github.com/PennyLaneAI/catalyst/pull/1579)
226+
[(#1637)](https://github.com/PennyLaneAI/catalyst/pull/1637)
226227

227228
* Improve error message for ZNE.
228229
[(#1603)](https://github.com/PennyLaneAI/catalyst/pull/1603)
@@ -268,6 +269,9 @@
268269
pytrees inside a loop with autograph causes falling back to python.
269270
[(#1601)](https://github.com/PennyLaneAI/catalyst/pull/1601)
270271

272+
* Closure variables are now supported with `grad` and `value_and_grad`.
273+
[(#1613)](https://github.com/PennyLaneAI/catalyst/pull/1613)
274+
271275
<h3>Internal changes ⚙️</h3>
272276

273277
* Updated the call signature for the PLXPR `qnode_prim` primitive.

frontend/catalyst/api_extensions/differentiation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -667,7 +667,7 @@ def __call__(self, *args, **kwargs):
667667

668668
# It always returns list as required by catalyst control-flows
669669
results = value_and_grad_p.bind(
670-
*input_data_flat, jaxpr=jaxpr, fn=fn, grad_params=grad_params
670+
*input_data_flat, *jaxpr.consts, jaxpr=jaxpr, fn=fn, grad_params=grad_params
671671
)
672672

673673
# value_and_grad returns two results: the values and the gradients,
@@ -686,7 +686,7 @@ def __call__(self, *args, **kwargs):
686686

687687
# It always returns list as required by catalyst control-flows
688688
results = grad_p.bind(
689-
*input_data_flat, jaxpr=jaxpr, fn=fn, grad_params=grad_params
689+
*input_data_flat, *jaxpr.consts, jaxpr=jaxpr, fn=fn, grad_params=grad_params
690690
)
691691

692692
# grad returns only the gradients,

frontend/catalyst/compiler.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -290,8 +290,11 @@ def _catalyst(*args, stdin=None):
290290
catalyst *args
291291
"""
292292
cmd = _get_catalyst_cli_cmd(*args, stdin=stdin)
293-
result = subprocess.run(cmd, input=stdin, check=True, capture_output=True, text=True)
294-
return result.stdout
293+
try:
294+
result = subprocess.run(cmd, input=stdin, check=True, capture_output=True, text=True)
295+
return result.stdout
296+
except subprocess.CalledProcessError as e:
297+
raise CompileError(f"catalyst failed with error code {e.returncode}: {e.stderr}") from e
295298

296299

297300
def _quantum_opt(*args, stdin=None):

frontend/catalyst/jax_primitives.py

Lines changed: 49 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -500,6 +500,27 @@ def _grad_lowering(ctx, *args, jaxpr, fn, grad_params):
500500
argnums: argument indices which define over which arguments to
501501
differentiate.
502502
"""
503+
consts = []
504+
offset = len(args) - len(jaxpr.consts)
505+
for i, jax_array_or_tracer in enumerate(jaxpr.consts):
506+
if not isinstance(
507+
jax_array_or_tracer, jax._src.interpreters.partial_eval.DynamicJaxprTracer
508+
):
509+
# ``ir.DenseElementsAttr.get()`` constructs a dense elements attribute from an array of
510+
# element values. This doesn't support ``jaxlib.xla_extension.Array``, so we have to
511+
# cast such constants to numpy array types.
512+
const = jax_array_or_tracer
513+
const_type = shape_dtype_to_ir_type(const.shape, const.dtype)
514+
nparray = np.asarray(const)
515+
attr = ir.DenseElementsAttr.get(nparray, type=const_type)
516+
constval = StableHLOConstantOp(attr).results
517+
consts.append(constval)
518+
else:
519+
# There are some cases where this value cannot be converted into
520+
# a jax.numpy.array.
521+
# in that case we get it from the arguments.
522+
consts.append(args[offset + i])
523+
503524
method, h, argnums = grad_params.method, grad_params.h, grad_params.expanded_argnums
504525
mlir_ctx = ctx.module_context.context
505526
finiteDiffParam = None
@@ -516,18 +537,9 @@ def _grad_lowering(ctx, *args, jaxpr, fn, grad_params):
516537
output_types = list(map(mlir.aval_to_ir_types, ctx.avals_out))
517538
flat_output_types = util.flatten(output_types)
518539

519-
# ``ir.DenseElementsAttr.get()`` constructs a dense elements attribute from an array of
520-
# element values. This doesn't support ``jaxlib.xla_extension.Array``, so we have to cast
521-
# such constants to numpy array types.
522-
523-
constants = []
524-
for const in jaxpr.consts:
525-
const_type = shape_dtype_to_ir_type(const.shape, const.dtype)
526-
nparray = np.asarray(const)
527-
attr = ir.DenseElementsAttr.get(nparray, type=const_type)
528-
constantVals = StableHLOConstantOp(attr).results
529-
constants.append(constantVals)
530-
args_and_consts = constants + list(args)
540+
len_args = len(args)
541+
index = len_args - len(consts)
542+
args_and_consts = consts + list(args[:index])
531543

532544
return GradOp(
533545
flat_output_types,
@@ -563,21 +575,38 @@ def _value_and_grad_lowering(ctx, *args, jaxpr, fn, grad_params):
563575
Returns:
564576
MLIR results
565577
"""
566-
args = list(args)
578+
consts = []
579+
offset = len(args) - len(jaxpr.consts)
580+
for i, jax_array_or_tracer in enumerate(jaxpr.consts):
581+
if not isinstance(
582+
jax_array_or_tracer, jax._src.interpreters.partial_eval.DynamicJaxprTracer
583+
):
584+
# ``ir.DenseElementsAttr.get()`` constructs a dense elements attribute from an array of
585+
# element values. This doesn't support ``jaxlib.xla_extension.Array``, so we have to
586+
# cast such constants to numpy array types.
587+
const = jax_array_or_tracer
588+
const_type = shape_dtype_to_ir_type(const.shape, const.dtype)
589+
nparray = np.asarray(const)
590+
attr = ir.DenseElementsAttr.get(nparray, type=const_type)
591+
constval = StableHLOConstantOp(attr).results
592+
consts.append(constval)
593+
else:
594+
# There are some cases where this value cannot be converted into
595+
# a jax.numpy.array.
596+
# in that case we get it from the arguments.
597+
consts.append(args[offset + i])
598+
599+
len_args = len(args)
600+
index = len_args - len(consts)
601+
args = list(args[0:index])
567602
method, h, argnums = grad_params.method, grad_params.h, grad_params.expanded_argnums
568603
mlir_ctx = ctx.module_context.context
569604
new_argnums = np.array([len(jaxpr.consts) + num for num in argnums])
570605

571606
output_types = list(map(mlir.aval_to_ir_types, ctx.avals_out))
572607
flat_output_types = util.flatten(output_types)
573608

574-
constants = []
575-
for const in jaxpr.consts:
576-
const_type = shape_dtype_to_ir_type(const.shape, const.dtype)
577-
nparray = np.asarray(const)
578-
attr = ir.DenseElementsAttr.get(nparray, type=const_type)
579-
constantVals = StableHLOConstantOp(attr).results
580-
constants.append(constantVals)
609+
constants = consts
581610

582611
consts_and_args = constants + args
583612
func_call_jaxpr = get_call_jaxpr(jaxpr)

frontend/test/pytest/test_debug.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -558,6 +558,12 @@ def test_no_options_to_mlir_opt(self):
558558
).strip()
559559
assert expected in observed
560560

561+
def test_catalyst_error(self):
562+
mlir = """This is invalid MLIR"""
563+
msg = "custom op 'This'"
564+
with pytest.raises(CompileError, match=msg):
565+
to_mlir_opt(stdin=mlir)
566+
561567

562568
if __name__ == "__main__":
563569
pytest.main(["-x", __file__])

frontend/test/pytest/test_gradient.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2201,5 +2201,78 @@ def circuit(x: float):
22012201
return qml.expval(qml.PauliZ(wires=0))
22022202

22032203

2204+
def test_closure_variable_grad():
2205+
"""Test that grad can take closure variables"""
2206+
2207+
@qml.qjit
2208+
def workflow_closure(x, y):
2209+
2210+
dev = qml.device("lightning.qubit", wires=1)
2211+
2212+
@qml.qnode(dev)
2213+
def circuit(x):
2214+
qml.RX(jnp.pi * x, wires=0)
2215+
qml.RX(jnp.pi * y, wires=0)
2216+
return qml.expval(qml.PauliY(0))
2217+
2218+
g = grad(circuit)
2219+
return g(x)
2220+
2221+
@qml.qjit
2222+
def workflow_no_closure(x, y):
2223+
2224+
dev = qml.device("lightning.qubit", wires=1)
2225+
2226+
@qml.qnode(dev)
2227+
def circuit(x, y):
2228+
qml.RX(jnp.pi * x, wires=0)
2229+
qml.RX(jnp.pi * y, wires=0)
2230+
return qml.expval(qml.PauliY(0))
2231+
2232+
g = grad(circuit)
2233+
return g(x, y)
2234+
2235+
expected = workflow_no_closure(1.0, 0.25)
2236+
observed = workflow_closure(1.0, 0.25)
2237+
assert np.allclose(expected, observed)
2238+
2239+
2240+
def test_closure_variable_value_and_grad():
2241+
"""Test that value and grad can take closure variables"""
2242+
2243+
@qml.qjit
2244+
def workflow_closure(x, y):
2245+
2246+
dev = qml.device("lightning.qubit", wires=1)
2247+
2248+
@qml.qnode(dev)
2249+
def circuit(x):
2250+
qml.RX(jnp.pi * x, wires=0)
2251+
qml.RX(jnp.pi * y, wires=0)
2252+
return qml.expval(qml.PauliY(0))
2253+
2254+
g = value_and_grad(circuit)
2255+
return g(x)
2256+
2257+
@qml.qjit
2258+
def workflow_no_closure(x, y):
2259+
2260+
dev = qml.device("lightning.qubit", wires=1)
2261+
2262+
@qml.qnode(dev)
2263+
def circuit(x, y):
2264+
qml.RX(jnp.pi * x, wires=0)
2265+
qml.RX(jnp.pi * y, wires=0)
2266+
return qml.expval(qml.PauliY(0))
2267+
2268+
g = value_and_grad(circuit)
2269+
return g(x, y)
2270+
2271+
x, y = 1.0, 0.25
2272+
expected = workflow_no_closure(x, y)
2273+
observed = workflow_closure(x, y)
2274+
assert np.allclose(expected, observed)
2275+
2276+
22042277
if __name__ == "__main__":
22052278
pytest.main(["-x", __file__])

0 commit comments

Comments
 (0)