Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -761,6 +761,7 @@
[(#2672)](https://github.com/PennyLaneAI/catalyst/pull/2672)
[(#2694)](https://github.com/PennyLaneAI/catalyst/pull/2694)
[(#2717)](https://github.com/PennyLaneAI/catalyst/pull/2717)
[(#2740)](https://github.com/PennyLaneAI/catalyst/pull/2740)
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

self reminder to do measurement and reset


* Removed the `condition` operand from `pbc.ppm` (Pauli Product Measurement) operations.
Conditional PPR decompositions in the `decompose-clifford-ppr` pass now emit the
Expand Down Expand Up @@ -1053,7 +1054,7 @@
[(#2738)](https://github.com/PennyLaneAI/catalyst/pull/2738)
[(#2736)](https://github.com/PennyLaneAI/catalyst/pull/2736)
[(#2715)](https://github.com/PennyLaneAI/catalyst/pull/2715)

* The "Compatibility with PennyLane transforms" section of the
:doc:`Sharp bits and debugging tips <../dev/sharp_bits>` document has been updated to describe
potential oddities that can be encountered when composing PennyLane transforms together.
Expand Down
1 change: 0 additions & 1 deletion frontend/catalyst/from_plxpr/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,4 @@

"""Conversion from plxpr to catalyst jaxpr"""

from catalyst.from_plxpr.control_flow import handle_cond
from catalyst.from_plxpr.from_plxpr import from_plxpr, trace_from_pennylane
204 changes: 0 additions & 204 deletions frontend/catalyst/from_plxpr/control_flow.py

This file was deleted.

92 changes: 92 additions & 0 deletions frontend/catalyst/jax_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@
lower_jaxpr,
)

from pennylane.capture.primitives import cond_prim as pl_cond_prim
from pennylane.capture.primitives import for_loop_prim as pl_for_loop_prim
from pennylane.capture.primitives import jacobian_prim as pl_jac_prim
from pennylane.capture.primitives import jvp_prim as pl_jvp_prim
Expand Down Expand Up @@ -2373,6 +2374,96 @@ def emit_branches(preds, branch_jaxprs, ip):
return head_if_op.results


def _pl_cond_lowering(
jax_ctx: mlir.LoweringRuleContext,
*invals,
jaxpr_branches,
consts_slices,
args_slice,
):
result_types = [mlir.aval_to_ir_types(a)[0] for a in jax_ctx.avals_out]
num_preds = len(jaxpr_branches)
preds = invals[:num_preds]
args = invals[slice(*args_slice)]

# recursively lower if-else chains to nested IfOps
def emit_branches(preds, sub_branches, sub_consts_slices, insertion_point):
# closure vars are invals, args, jax_ctx

# ip is an MLIR InsertionPoint. This allows recursive calls to emit their Operations inside
# the 'else' blocks of preceding IfOps.
with insertion_point:
pred_extracted = TensorExtractOp(ir.IntegerType.get_signless(1), preds[0], []).result
if_op_scf = IfOp(pred_extracted, result_types, hasElse=True)
true_jaxpr = sub_branches[0]
if_block = if_op_scf.then_block

# if block
source_info_util.extend_name_stack("if")
if_ctx = jax_ctx.replace(name_stack=jax_ctx.name_stack.extend("if"))
with ir.InsertionPoint(if_block):
consts = invals[slice(*sub_consts_slices[0])]

new_jaxpr = true_jaxpr.replace(
constvars=(), invars=true_jaxpr.constvars + true_jaxpr.invars
)

# recursively generate the mlir for the if block
out, _ = mlir.jaxpr_subcomp(
if_ctx.module_context,
new_jaxpr,
if_ctx.name_stack,
mlir.TokenSet(),
[],
*consts,
*args,
dim_var_values=jax_ctx.dim_var_values,
const_lowering=jax_ctx.const_lowering,
)

YieldOp(out)

# else block
source_info_util.extend_name_stack("else")
else_ctx = jax_ctx.replace(name_stack=jax_ctx.name_stack.extend("else"))
else_block = if_op_scf.else_block
if len(preds) == 1:
# Base case: reached the otherwise block
else_jaxpr = sub_branches[-1]
consts = invals[slice(*sub_consts_slices[-1])]

new_jaxpr = else_jaxpr.replace(
constvars=(), invars=else_jaxpr.constvars + else_jaxpr.invars
)

with ir.InsertionPoint(else_block):
out, _ = mlir.jaxpr_subcomp(
else_ctx.module_context,
new_jaxpr,
else_ctx.name_stack,
mlir.TokenSet(),
[],
*consts,
*args,
dim_var_values=jax_ctx.dim_var_values,
const_lowering=jax_ctx.const_lowering,
)

YieldOp(out)
else:
with ir.InsertionPoint(else_block) as else_ip:
child_if_op = emit_branches(
preds[1:], sub_branches[1:], sub_consts_slices[1:], else_ip
)
YieldOp(child_if_op.results)
return if_op_scf

head_if_op = emit_branches(
preds, jaxpr_branches, consts_slices, jax_ctx.module_context.ip.current
)
return head_if_op.results


#
# Index Switch
#
Expand Down Expand Up @@ -3100,6 +3191,7 @@ def subroutine_lowering(*args, **kwargs):
(pl_jvp_prim, _capture_jvp_lowering),
(pl_value_and_grad_prim, _capture_value_and_grad_lowering),
(pl_while_loop_prim, _pl_while_loop_lowering),
(pl_cond_prim, _pl_cond_lowering),
(func_p, _func_lowering),
(jvp_p, _jvp_lowering),
(vjp_p, _vjp_lowering),
Expand Down
Loading
Loading