Skip to content

Commit b71677c

Browse files
authored
Merge branch 'main' into ops-qp
2 parents 9f27ac2 + 08cd163 commit b71677c

18 files changed

Lines changed: 686 additions & 219 deletions

doc/releases/changelog-dev.md

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,21 @@ The following classes have been ported over:
175175

176176
<h3>Improvements 🛠</h3>
177177

178-
* Replaced the O(n²) incremental ``@=`` operator chaining in ``qp.pauli.string_to_pauli_word`` and ``qp.pauli.binary_to_pauli`` with a single ``qp.prod(*tuple_of_ops)`` call, collecting operators via generator expressions. These operators are now much faster for large Pauli strings.
178+
* With program capture, in `for_loop` and `while_loop`, constant closure variables with dynamic shapes
179+
can now be combined with explicit inputs with dynamic shapes when they have matching shapes.
180+
[(#9275)](https://github.com/PennyLaneAI/pennylane/pull/9275)
181+
182+
* Added another decomposition to `MultiControlledX` with two control wires and at least one zeroed
183+
work wire that has been passed explicitly. It decomposes into a pair of `TemporaryAND` and a
184+
`CNOT`.
185+
[(#9291)](https://github.com/PennyLaneAI/pennylane/pull/9291)
186+
187+
* Operations using ``FermiWord`` are now much faster due to various performance improvements to the class.
188+
[(#9283)](https://github.com/PennyLaneAI/pennylane/pull/9283)
189+
190+
* Replaced the O(n²) incremental ``@=`` operator chaining in ``qp.pauli.string_to_pauli_word`` and
191+
``qp.pauli.binary_to_pauli`` with a single ``qp.prod(*tuple_of_ops)`` call, collecting operators via
192+
generator expressions. These functions are now much faster for large Pauli strings.
179193
[(#9271)](https://github.com/PennyLaneAI/pennylane/pull/9271)
180194

181195
* Operations using ``PauliSentence`` are now much faster due to additional memoization in ``PauliWord.__hash__``
@@ -789,6 +803,9 @@ The following classes have been ported over:
789803

790804
<h3>Internal changes ⚙️</h3>
791805

806+
* During program capture, `qml.for_loop` with negative step sizes is now handled immediately at capture time.
807+
[(#9299)](https://github.com/PennyLaneAI/pennylane/pull/9299)
808+
792809
* With program capture, arrays with dynamic shapes in `qml.for_loop` and `qml.while_loop` can now be combined
793810
after the loop.
794811
[(#9245)](https://github.com/PennyLaneAI/pennylane/pull/9245)
@@ -975,6 +992,10 @@ The following classes have been ported over:
975992

976993
<h3>Bug fixes 🐛</h3>
977994

995+
* Global phases are now supported in `from_qasm3` so that QASM including the `gphase` instruction
996+
can be interpreted.
997+
[(#9247)](https://github.com/PennyLaneAI/pennylane/pull/9247)
998+
978999
* Fixes an issue with Catalyst and `qml.for_loop` and `qml.while_loop`, where it was defaulting
9791000
to `allow_array_resizing=True` instead of `allow_array_resizing=False`.
9801001
[(#9251)](https://github.com/PennyLaneAI/pennylane/pull/9251)
@@ -1094,6 +1115,10 @@ The following classes have been ported over:
10941115
produced for a specific wire configuration.
10951116
[(#9270)](https://github.com/PennyLaneAI/pennylane/pull/9270)
10961117

1118+
* Fixes a bug where the `DecompositionGraph` underestimates the minimum number of work wires required to solve for a particular operator
1119+
when it has decomposition rules with a lower work wire budget that are unreachable from the provided gate set.
1120+
[(#9298)](https://github.com/PennyLaneAI/pennylane/pull/9298)
1121+
10971122
<h3>Contributors ✍️</h3>
10981123

10991124
This release contains contributions from (in alphabetical order):

pennylane/_version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,4 @@
1616
Version number (major.minor.patch[-label])
1717
"""
1818

19-
__version__ = "0.45.0-dev77"
19+
__version__ = "0.45.0-dev80"

pennylane/control_flow/_loop_abstract_axes.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,44 @@
2929
AbstractShapeLocation = namedtuple("AbstractShapeLocation", ("arg_idx", "shape_idx"))
3030

3131

32+
def promote_consts_to_inputs(f):
33+
"""This function extracts any closure variables with dynamic shapes from f.__closure__
34+
and promotes them to being normal arguments. This produces a new function that
35+
takes the original args and the new consts as explicit inputs. It also returns
36+
the extracted consts.
37+
"""
38+
indices = []
39+
consts = []
40+
41+
if getattr(f, "__closure__", None) is not None:
42+
for ind, cell in enumerate(f.__closure__):
43+
val = cell.cell_contents
44+
if hasattr(val, "shape") and not all(isinstance(s, int) for s in val.shape):
45+
indices.append(ind)
46+
consts.append(val)
47+
48+
def new_f(args, new_consts):
49+
"""A version of f where the consts with dynamic shapes have been promoted to inputs."""
50+
51+
# even deepcopy does not actually copy the closure for a function
52+
# we don't have a way to produce a new function with independent closure
53+
# so therefore we just need to make sure to clean up after ourselves after
54+
# in-place modifying the closure contents.
55+
56+
try:
57+
for ind, c in zip(indices, new_consts, strict=True):
58+
f.__closure__[ind].cell_contents = c
59+
60+
f_results = f(*args)
61+
finally:
62+
for ind, c in zip(indices, consts, strict=True):
63+
f.__closure__[ind].cell_contents = c
64+
65+
return f_results, new_consts
66+
67+
return new_f, consts
68+
69+
3270
def add_abstract_shapes(f, shape_locations: list[list[AbstractShapeLocation]]): # pragma: no cover
3371
"""Add the abstract shapes at the specified locations to the output of f.
3472

pennylane/control_flow/for_loop.py

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import warnings
1919
from typing import Literal
2020

21-
from pennylane import capture
21+
from pennylane import capture, math
2222
from pennylane.capture import FlatFn, enabled
2323
from pennylane.capture.dynamic_shapes import register_custom_staging_rule
2424
from pennylane.compiler.compiler import AvailableCompilers, active_compiler
@@ -29,13 +29,33 @@
2929
get_dummy_arg,
3030
handle_jaxpr_error,
3131
loop_determine_abstracted_axes,
32+
promote_consts_to_inputs,
3233
validate_no_resizing_returns,
3334
)
3435

3536
logger = logging.getLogger(__name__)
3637
logger.addHandler(logging.NullHandler())
3738

3839

40+
def _reverse_iterator(f, start, step):
41+
"""Produces a new f with positive steps."""
42+
43+
def new_f(*args):
44+
new_i = start + step * args[0]
45+
inputs = args[1:]
46+
return f(new_i, *inputs)
47+
48+
return new_f
49+
50+
51+
def _is_reverse_iteration(start, stop, step):
52+
# without the int() call, when we have a jnp array with a single int
53+
# in it (jnp.array(1)), performing a comparison will produce a tracer
54+
if not math.is_abstract(step):
55+
return int(step) < 0
56+
return not math.is_abstract(start) and not math.is_abstract(stop) and int(stop) < int(start)
57+
58+
3959
def for_loop(
4060
start, stop=None, step=1, *, allow_array_resizing: Literal["auto", True, False] = "auto"
4161
):
@@ -397,8 +417,10 @@ def _get_jaxpr(self, init_state, allow_array_resizing):
397417

398418
import jax # pylint: disable=import-outside-toplevel
399419

420+
f_consts_extracted, dynamic_consts = promote_consts_to_inputs(self.body_fn)
421+
400422
# need in_tree to include index so flat_fn will repack args correctly
401-
flat_args, in_tree = jax.tree_util.tree_flatten((0, *init_state))
423+
flat_args, in_tree = jax.tree_util.tree_flatten(((0, *init_state), dynamic_consts))
402424

403425
# slice out the index so shape_locations indexes from non-index args/ results
404426
flat_args = flat_args[1:]
@@ -407,7 +429,7 @@ def _get_jaxpr(self, init_state, allow_array_resizing):
407429
tuple(flat_args), allow_array_resizing=tmp_array_resizing
408430
)
409431

410-
flat_fn = FlatFn(self.body_fn, in_tree=in_tree)
432+
flat_fn = FlatFn(f_consts_extracted, in_tree=in_tree)
411433

412434
if abstracted_axes: # pragma: no cover
413435
new_body_fn = add_abstract_shapes(flat_fn, shape_locations)
@@ -417,6 +439,10 @@ def _get_jaxpr(self, init_state, allow_array_resizing):
417439
new_body_fn = flat_fn
418440
dummy_init_state = flat_args
419441

442+
if _is_reverse_iteration(self.start, self.stop, self.step):
443+
# MLIR does not support reverse iteration of for loops
444+
new_body_fn = _reverse_iterator(new_body_fn, self.start, self.step)
445+
420446
try:
421447
jaxpr_body_fn = jax.make_jaxpr(new_body_fn, abstracted_axes=abstracted_axes)(
422448
0, *dummy_init_state
@@ -469,10 +495,17 @@ def _call_capture_enabled(self, *init_state):
469495
abstract_shapes_slice = slice(consts_slice.stop, consts_slice.stop + len(abstract_shapes))
470496
args_slice = slice(abstract_shapes_slice.stop, None)
471497

498+
if _is_reverse_iteration(self.start, self.stop, self.step):
499+
# mlir does not support reverse iteration of for loops
500+
num_iterations = math.ceil((self.stop - self.start) / self.step).astype(int)
501+
start, stop, step = 0, num_iterations, 1
502+
else:
503+
start, stop, step = self.start, self.stop, self.step
504+
472505
results = for_loop_prim.bind(
473-
self.start,
474-
self.stop,
475-
self.step,
506+
start,
507+
stop,
508+
step,
476509
*jaxpr_body_fn.consts,
477510
*abstract_shapes,
478511
*flat_args,
@@ -483,7 +516,8 @@ def _call_capture_enabled(self, *init_state):
483516
)
484517

485518
results = results[-out_tree.num_leaves :]
486-
return jax.tree_util.tree_unflatten(out_tree, results)
519+
# [0] to slice out the consts extracted by promote_consts_to_inputs
520+
return jax.tree_util.tree_unflatten(out_tree, results)[0]
487521

488522
def __call__(self, *init_state):
489523

pennylane/control_flow/while_loop.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
get_dummy_arg,
2828
handle_jaxpr_error,
2929
loop_determine_abstracted_axes,
30+
promote_consts_to_inputs,
3031
validate_no_resizing_returns,
3132
)
3233

@@ -43,6 +44,14 @@ def _new_cond_fn(*args, **kwargs):
4344
return _new_cond_fn
4445

4546

47+
def _body_consts_extracted_cond(cond_fn):
48+
# pylint: disable=unused-argument
49+
def new_cond_fn(args, body_consts):
50+
return cond_fn(*args)
51+
52+
return new_cond_fn
53+
54+
4655
def while_loop(cond_fn, allow_array_resizing: Literal["auto", True, False] = "auto"):
4756
"""A :func:`~.qjit` compatible while-loop for PennyLane programs. When
4857
used without :func:`~.qjit` or program capture, this function will fall back to a standard
@@ -317,14 +326,16 @@ def _call_capture_disabled(self, *init_state):
317326
def _get_jaxprs(self, init_state, allow_array_resizing):
318327
import jax # pylint: disable=import-outside-toplevel
319328

320-
flat_args, in_tree = jax.tree_util.tree_flatten(init_state)
329+
body_consts_extracted, dynamic_consts = promote_consts_to_inputs(self.body_fn)
330+
331+
flat_args, in_tree = jax.tree_util.tree_flatten((init_state, dynamic_consts))
321332
tmp_array_resizing = False if allow_array_resizing == "auto" else allow_array_resizing
322333
abstracted_axes, abstract_shapes, shape_locations = loop_determine_abstracted_axes(
323334
tuple(flat_args), allow_array_resizing=tmp_array_resizing
324335
)
325336

326-
flat_body_fn = FlatFn(self.body_fn, in_tree=in_tree)
327-
flat_cond_fn = FlatFn(self.cond_fn, in_tree=in_tree)
337+
flat_body_fn = FlatFn(body_consts_extracted, in_tree=in_tree)
338+
flat_cond_fn = FlatFn(_body_consts_extracted_cond(self.cond_fn), in_tree=in_tree)
328339
bool_cond_fn = _to_bool_cond_fn(flat_cond_fn)
329340

330341
if abstracted_axes: # pragma: no cover
@@ -379,7 +390,8 @@ def _call_capture_enabled(self, *init_state):
379390
)
380391

381392
results = results[-out_tree.num_leaves :]
382-
return jax.tree_util.tree_unflatten(out_tree, results)
393+
# [0] to slice out the consts extracted by promote_consts_to_inputs
394+
return jax.tree_util.tree_unflatten(out_tree, results)[0]
383395

384396
def __call__(self, *init_state):
385397

pennylane/decomposition/decomposition_graph.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,9 @@ class _OperatorNode:
9797
min_work_wires: int = 0
9898
"""The minimum number of additional work wires required to decompose this operator."""
9999

100+
reachable: bool = False
101+
"""Whether the operator node can be reached from the gate set."""
102+
100103
def __hash__(self) -> int:
101104
# If the decomposition of an operator does not depend on the availability of work wires
102105
# at all, we don't need to have multiple nodes representing the same operator with
@@ -130,6 +133,7 @@ class _DecompositionNode:
130133
num_work_wire_not_available: int
131134
work_wire_dependent: bool = False
132135
min_work_wires: int = 0
136+
reachable: bool = True
133137

134138
def __post_init__(self):
135139
self.min_work_wires = self.min_work_wires or self.work_wire_spec.total
@@ -309,6 +313,9 @@ def _add_op_node(self, op: CompressedResourceOp, num_used_work_wires: int) -> in
309313
if op_node in self._all_op_indices:
310314
return self._all_op_indices[op_node]
311315

316+
if op in self._gate_set_weights:
317+
op_node = replace(op_node, reachable=True)
318+
312319
op_node_idx = self._graph.add_node(op_node)
313320
self._all_op_indices[op_node] = op_node_idx
314321
self._op_to_op_nodes[op].add(op_node)
@@ -333,23 +340,32 @@ def _add_op_node(self, op: CompressedResourceOp, num_used_work_wires: int) -> in
333340
self._graph.add_edge(self._start, op_node_idx, math.inf)
334341
return op_node_idx
335342

343+
op_reachable = False
336344
work_wire_dependent = known_work_wire_dependent
337345
min_work_wires = -1 # use -1 to represent undetermined work wire requirement
338346
for decomposition in rules:
339347
d_node = self._add_decomp(decomposition, op_node, op_node_idx, num_used_work_wires)
348+
if not d_node or not d_node.reachable:
349+
continue
350+
# If a decomposition is reachable, the operator is also reachable
351+
op_reachable = True
340352
# If any of the operator's decompositions depend on work wires, this operator
341353
# should also depend on work wires.
342-
if d_node and d_node.work_wire_dependent:
354+
if d_node.work_wire_dependent:
343355
work_wire_dependent = True
344-
if d_node and (min_work_wires == -1 or d_node.min_work_wires < min_work_wires):
356+
if min_work_wires == -1 or d_node.min_work_wires < min_work_wires:
345357
min_work_wires = d_node.min_work_wires
346358

359+
if op_reachable:
360+
op_node = replace(op_node, reachable=True)
361+
self._replace_node(op_node_idx, op_node)
362+
347363
# If we found that this operator depends on work wires, but it's currently recorded
348364
# as independent of work wires, we must replace every record of this operator node
349365
# with a new node with `work_wire_dependent` set to `True`.
350366
if not known_work_wire_dependent and work_wire_dependent:
351-
new_op_node = replace(op_node, work_wire_dependent=True, min_work_wires=min_work_wires)
352-
self._replace_node(op_node_idx, new_op_node)
367+
op_node = replace(op_node, work_wire_dependent=True, min_work_wires=min_work_wires)
368+
self._replace_node(op_node_idx, op_node)
353369
# Also record that this operator type depends on work wires, so in the future
354370
# when we encounter other instances of the same operator type, we correctly
355371
# identify it as work-wire dependent.
@@ -400,6 +416,8 @@ def _add_decomp(
400416
# decomposition is also dependent on work wires, even if it itself does not use
401417
# any work wires.
402418
op_node = self._graph[op_node_idx]
419+
if not op_node.reachable:
420+
d_node.reachable = False
403421
if op_node.work_wire_dependent:
404422
d_node.work_wire_dependent = True
405423
max_op_min_work_wires = max(op_node.min_work_wires, max_op_min_work_wires)

0 commit comments

Comments
 (0)