Skip to content
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
498677b
[WIP] Add conversion pattern for qecl.qec ops
joeycarter Apr 24, 2026
e7ef5c6
[WIP] Continue adding QEC cycle conversion pattern
joeycarter Apr 27, 2026
e605612
Merge branch 'main' into joeycarter/qecl-to-qecp-qec-cycle
joeycarter Apr 27, 2026
f937de9
Insert Z-correction ops
joeycarter Apr 27, 2026
76109b9
[WIP] Apply both X and Z corrections
joeycarter Apr 27, 2026
60b7037
Update conversion pattern; it's now complete!
joeycarter Apr 27, 2026
c3760a0
Merge branch 'main' into joeycarter/qecl-to-qecp-qec-cycle
joeycarter Apr 27, 2026
0370bfb
Fix mistake in tensor shape result of qecp.decode_esm_css op
joeycarter Apr 27, 2026
887afd9
Add changelog entry
joeycarter Apr 27, 2026
3334ce8
Merge branch 'main' into joeycarter/qecl-to-qecp-qec-cycle
joeycarter Apr 28, 2026
2ed1413
Clean up
joeycarter Apr 28, 2026
0932f87
Merge branch 'main' into joeycarter/qecl-to-qecp-qec-cycle
joeycarter Apr 28, 2026
f7f0101
Fix np.ndarray type hints that caused pylint error
joeycarter Apr 28, 2026
3f6755d
Merge branch 'main' into joeycarter/qecl-to-qecp-qec-cycle
joeycarter Apr 28, 2026
b5fefb9
Fix the parity_check_matrix_to_tanner_csc function
joeycarter Apr 30, 2026
2737848
Update qecl.qec pattern to use correct Tanner graph
joeycarter Apr 30, 2026
3ed0dcb
Merge branch 'main' into joeycarter/qecl-to-qecp-qec-cycle
joeycarter Apr 30, 2026
2ed3442
Placate CodeFactor
joeycarter Apr 30, 2026
c33e85f
[no-ci] Apply suggestions from code review
joeycarter May 1, 2026
3c2e128
Fix row_idx, col_ptr names in tanner_graph_lib in docs
joeycarter May 1, 2026
225fc01
Replace math.floor -> //
joeycarter May 1, 2026
048fb7a
Mathify more of the tanner lib docstring
joeycarter May 1, 2026
ce04ee4
Clean up wording on stabilizers use to identify errors
joeycarter May 1, 2026
2d2d6be
Merge branch 'main' into joeycarter/qecl-to-qecp-qec-cycle
joeycarter May 1, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -1036,6 +1036,7 @@
[(#2737)](https://github.com/PennyLaneAI/catalyst/pull/2737)
[(#2731)](https://github.com/PennyLaneAI/catalyst/pull/2731)
[(#2735)](https://github.com/PennyLaneAI/catalyst/pull/2735)
[(#2754)](https://github.com/PennyLaneAI/catalyst/pull/2754)

* A number of deprecation warnings have been fixed in the compiler python interface.
[(#2621)](https://github.com/PennyLaneAI/catalyst/pull/2621)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from xdsl.builder import ImplicitBuilder
from xdsl.context import Context
from xdsl.dialects import arith, builtin, func, scf, tensor
from xdsl.dialects.builtin import I1, IndexType, SymbolRefAttr, TensorType, i1
from xdsl.dialects.builtin import I1, IndexType, SymbolRefAttr, TensorType, i1, i32, i64
from xdsl.ir import Block, BlockArgument, OpResult, Region
from xdsl.passes import ModulePass
from xdsl.pattern_rewriter import (
Expand All @@ -49,6 +49,7 @@

from catalyst.python_interface.dialects import qecl, qecp
from catalyst.python_interface.pass_api.compiler_transform import compiler_transform
from catalyst.python_interface.transforms.qecp.tanner_graph_lib import dense_tanner_graph_to_csc
from catalyst.utils.exceptions import CompileError

from .convert_qecl_noise_to_qec_noise import ConvertQECLNoiseOpToQECPNoisePass
Expand Down Expand Up @@ -167,20 +168,56 @@ def match_and_rewrite(self, op: qecl.EncodeOp, rewriter: PatternRewriter):
"for init_state 'zero'"
)

if (k := op.in_codeblock.type.k.value.data) != self.qec_code.k:
in_codeblock = cast(
qecl.LogicalCodeBlockSSAValue | qecp.PhysicalCodeBlockSSAValue, op.in_codeblock
)
Comment thread
joeycarter marked this conversation as resolved.

if (k := in_codeblock.type.k.value.data) != self.qec_code.k:
raise CompileError(
f"Circuit expressed in the qecl dialect with k={k} is not compatible with "
f"lowering to a code with k={self.qec_code.k}"
)

callee = builtin.SymbolRefAttr(self.encode_subroutine.sym_name)
arguments = (op.in_codeblock,)
arguments = (in_codeblock,)
return_types = self.encode_subroutine.function_type.outputs.data
callOp = func.CallOp(callee, arguments, return_types)

rewriter.replace_op(op, callOp)


# MARK: QEC Cycle Op Pattern


@dataclass
class QecCycleOpConversion(RewritePattern):
"""Converts qecl.qec to the equivalent subroutine of qecp gates."""

qec_code: QecCode
qec_cycle_subroutine: func.FuncOp

@op_type_rewrite_pattern
def match_and_rewrite(self, op: qecl.QecCycleOp, rewriter: PatternRewriter):
"""Rewrite pattern for `qecl.qec` ops."""

in_codeblock = cast(
qecl.LogicalCodeBlockSSAValue | qecp.PhysicalCodeBlockSSAValue, op.in_codeblock
)

if (k := in_codeblock.type.k.value.data) != self.qec_code.k:
raise CompileError(
f"Circuit expressed in the qecl dialect with k={k} is not compatible with "
f"lowering to a code with k={self.qec_code.k}"
)

callee = builtin.SymbolRefAttr(self.qec_cycle_subroutine.sym_name)
arguments = (in_codeblock,)
return_types = self.qec_cycle_subroutine.function_type.outputs.data
callOp = func.CallOp(callee, arguments, return_types)

rewriter.replace_op(op, callOp)


# MARK: Measure Op Pattern


Expand Down Expand Up @@ -286,12 +323,17 @@ def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
n=self.qec_code.n, number_errors=self.number_errors
).apply(ctx, op)

# Insert subroutines into the module
module_block = op.regions[0].blocks.first
assert module_block is not None, "Module has no block"

encode_subroutine = self.create_encode_subroutine()
module_block.add_op(encode_subroutine)
tanner_x, tanner_z = self.insert_tanner_graph_ops_into_block(module_block)

# Insert subroutines that implement the QEC protocols
encode_funcop = self.create_encode_subroutine()
module_block.add_op(encode_funcop)

qec_cycle_funcop = self.create_qec_cycle_subroutine(tanner_x=tanner_x, tanner_z=tanner_z)
module_block.add_op(qec_cycle_funcop)

measure_subroutine = self.create_measure_subroutine()
module_block.add_op(measure_subroutine)
Expand All @@ -305,7 +347,10 @@ def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
DeallocationConversion(),
InsertBlockConversion(),
ExtractBlockConversion(),
EncodeOpConversion(qec_code=self.qec_code, encode_subroutine=encode_subroutine),
EncodeOpConversion(qec_code=self.qec_code, encode_subroutine=encode_funcop),
QecCycleOpConversion(
qec_code=self.qec_code, qec_cycle_subroutine=qec_cycle_funcop
),
MeasureOpConversion(
qec_code=self.qec_code,
measure_subroutine=measure_subroutine,
Expand All @@ -314,6 +359,78 @@ def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
)
).rewrite_module(op)

def insert_tanner_graph_ops_into_block(
self, block: Block
) -> tuple[OpResult[qecp.TannerGraphType], OpResult[qecp.TannerGraphType]]:
"""Insert Tanner graph operations into the given block.

The operations are inserted at the beginning of the block.

Returns the X and Z Tanner graph SSA values from the `qecp.assemble_tanner` ops (we assume
a CSS code here and therefore have separate X and Z Tanner graphs).
"""
x_tanner_row_idx_array, x_tanner_col_ptr_array = dense_tanner_graph_to_csc(
self.qec_code.x_tanner
)
Comment thread
joeycarter marked this conversation as resolved.
Outdated
x_tanner_row_idx_const_op = arith.ConstantOp(
builtin.DenseIntOrFPElementsAttr.from_list(
type=builtin.TensorType(i32, shape=x_tanner_row_idx_array.shape),
data=x_tanner_row_idx_array.tolist(),
)
)
x_tanner_col_ptr_const_op = arith.ConstantOp(
builtin.DenseIntOrFPElementsAttr.from_list(
type=builtin.TensorType(i32, shape=x_tanner_col_ptr_array.shape),
data=x_tanner_col_ptr_array.tolist(),
)
)
assemble_x_tanner_op = qecp.AssembleTannerGraphOp(
row_idx=x_tanner_row_idx_const_op,
col_ptr=x_tanner_col_ptr_const_op,
tanner_graph_type=qecp.TannerGraphType(
x_tanner_row_idx_array.shape[0], x_tanner_col_ptr_array.shape[0], i32
),
)

z_tanner_row_idx_array, z_tanner_col_ptr_array = dense_tanner_graph_to_csc(
self.qec_code.z_tanner
)
z_tanner_row_idx_const_op = arith.ConstantOp(
builtin.DenseIntOrFPElementsAttr.from_list(
type=builtin.TensorType(i32, shape=z_tanner_row_idx_array.shape),
data=z_tanner_row_idx_array.tolist(),
)
)
z_tanner_col_ptr_const_op = arith.ConstantOp(
builtin.DenseIntOrFPElementsAttr.from_list(
type=builtin.TensorType(i32, shape=z_tanner_col_ptr_array.shape),
data=z_tanner_col_ptr_array.tolist(),
)
)
assemble_z_tanner_op = qecp.AssembleTannerGraphOp(
row_idx=z_tanner_row_idx_const_op,
col_ptr=z_tanner_col_ptr_const_op,
tanner_graph_type=qecp.TannerGraphType(
z_tanner_row_idx_array.shape[0], z_tanner_col_ptr_array.shape[0], i32
),
)

ops_to_insert = (
x_tanner_row_idx_const_op,
x_tanner_col_ptr_const_op,
assemble_x_tanner_op,
z_tanner_row_idx_const_op,
z_tanner_col_ptr_const_op,
assemble_z_tanner_op,
)

if block.first_op is None:
block.add_ops(ops_to_insert)
else:
block.insert_ops_before(ops_to_insert, block.first_op)

return assemble_x_tanner_op.tanner_graph, assemble_z_tanner_op.tanner_graph

def create_measure_subroutine(self) -> func.FuncOp:
"""Create the subroutine that performs the transversal measurement of a physical codeblock.

Expand Down Expand Up @@ -445,7 +562,7 @@ def check_pattern(
in_aux_qbs: Iterable[qecp.QecPhysicalQubitSSAValue],
in_codeblock: qecp.PhysicalCodeBlockSSAValue,
check_type: CheckType,
) -> tuple[Iterable[qecp.MeasureOp], qecp.PhysicalCodeBlockSSAValue]:
) -> tuple[list[qecp.MeasureOp], qecp.PhysicalCodeBlockSSAValue]:
"""Contains the ops to perform a QEC check on the provided auxiliary qubits and codeblock.
Intended to be called inside `builder.ImplicitBuilder` to add these operations to a block.

Expand Down Expand Up @@ -537,5 +654,180 @@ def cnot_fn(aux_qb, data_qb):

return tanner_graph, cnot_fn

def create_qec_cycle_subroutine(
self, tanner_x: qecp.TannerGraphSSAValue, tanner_z: qecp.TannerGraphSSAValue
) -> func.FuncOp:
"""Create a subroutine that performs a cycle of QEC on an input physical codeblock.

The generated subroutine assumes a CSS QEC code and performs separate X and Z corrections,
as defined by the input X and Z Tanner graphs, `tanner_x` and `tanner_z`. Recall that
X-Tanner graphs define the X stabilizer components of the code, which are used to perform Z
Comment thread
joeycarter marked this conversation as resolved.
Outdated
corrections, and conversely Z-Tanner graphs define the Z stabilizer components of the code,
which are used to perform X corrections.

For each of the X and Z components of the QEC protocol, the subroutine allocates auxiliary
qubits for error-syndrome measurement (ESM) based on the number of rows in the respective
Tanner graph. After obtaining the ESM, it deallocates the auxiliary qubits and feeds the ESM
into a call to the ESM decoder, which returns the indices in the physical codeblock where
the detected error(s) occurred. It then iterates over these codeblock indices, applies the
respective correction, and finally returns the updated physical codeblock SSA value.

Note that this method does not insert the subroutine into the module op. Instead it returns
the built func.FuncOp object that can then be subsequently inserted where desired.
"""

codeblock_type = qecp.PhysicalCodeblockType(self.qec_code.k, self.qec_code.n)
input_types = (codeblock_type,)
output_types = (codeblock_type,)

block = Block(arg_types=input_types)

with ImplicitBuilder(block):
in_codeblock = cast(BlockArgument[qecp.PhysicalCodeblockType], block.args[0])

# Apply X checks pattern for Z corrections
x_out_codeblock = self._qec_cycle_css_pattern(in_codeblock, CheckType.X, tanner_x)

# Apply Z checks pattern for X corrections
z_out_codeblock = self._qec_cycle_css_pattern(x_out_codeblock, CheckType.Z, tanner_z)

# Return the corrected codeblock
func.ReturnOp(z_out_codeblock)

funcOp = func.FuncOp(
name=f"qec_cycle_{self.qec_code.name}",
function_type=(input_types, output_types),
visibility="private",
region=Region([block]),
)

return funcOp

def _qec_cycle_css_pattern(
self,
in_codeblock: qecp.PhysicalCodeBlockSSAValue,
check_type: CheckType,
tanner_graph: qecp.TannerGraphSSAValue,
) -> OpResult[qecp.PhysicalCodeblockType]:
"""Build the operations that perform a single X or Z component of a CSS QEC cycle on the
given `in_codeblock`.

This method is intended to be a helper function to `create_qec_cycle_subroutine()` and to be
called inside a `builder.ImplicitBuilder` context to automatically add these operations to a
block.
"""
# Allocate auxiliary qubits for ESM checks
aux_allocate_ops = (qecp.AllocAuxQubitOp() for row in self.qec_code.x_tanner)
aux_qubits = [
cast(OpResult[qecp.QecPhysicalQubitType], op.results[0]) for op in aux_allocate_ops
]

# Apply gate+measurement pattern for the check
measure_ops, post_check_codeblock = self.check_pattern(
aux_qubits, in_codeblock, check_type=check_type
)

# Checks are done; deallocate the auxiliary qubits
for x_meas_op in measure_ops:
qecp.DeallocAuxQubitOp(x_meas_op.out_qubit)

# Pack measurement results into a tensor for decoding
pack_mres_tensor_op = tensor.FromElementsOp.build(
operands=([meas_op.mres for meas_op in measure_ops],),
result_types=(TensorType(i1, shape=(len(measure_ops),)),),
)

# Decode ESM syndrome
num_correctable_errors = self.qec_code.correctable_errors
decode_esm_op = qecp.DecodeEsmCssOp(
tanner_graph=tanner_graph,
esm=pack_mres_tensor_op.result,
err_idx_type=TensorType(IndexType(), shape=(num_correctable_errors,)),
)

# Apply correction(s)
err_indices = cast(OpResult[TensorType[IndexType]], decode_esm_op.err_idx)

assert err_indices.type == (
expected_type := TensorType(IndexType(), shape=(num_correctable_errors,))
), (
f"Expected result of op '{decode_esm_op}' to have type '{expected_type}', "
f"but got '{err_indices.type}'"
)

# Build a for loop that iterates over each error index
lb_op = arith.ConstantOp.from_int_and_width(0, IndexType())
ub_op = arith.ConstantOp.from_int_and_width(num_correctable_errors, IndexType())
step_op = arith.ConstantOp.from_int_and_width(1, IndexType())

for_body = Block(
[],
arg_types=(builtin.IndexType(), post_check_codeblock.type),
)

for_each_err_idx_op = scf.ForOp(
lb=lb_op,
ub=ub_op,
step=step_op,
iter_args=(post_check_codeblock,),
body=for_body,
)

with ImplicitBuilder(for_each_err_idx_op.body):
indvar = cast(BlockArgument[IndexType], for_each_err_idx_op.body.block.args[0])
codeblock = cast(
BlockArgument[qecp.PhysicalCodeblockType],
for_each_err_idx_op.body.block.args[1],
)

extract_err_idx_op = tensor.ExtractOp(
err_indices, indices=indvar, result_type=IndexType()
)
err_idx = cast(OpResult[IndexType], extract_err_idx_op.result)

# Now we have the error index for this iteration in the for loop. Next we check if its
# value indicates that an error was detected (idx != -1), or if no error was detected
# (idx == -1).
cast_index_op = arith.IndexCastOp(err_idx, target_type=i64)
minus_1_const_op = arith.ConstantOp.from_int_and_width(-1, 64)
apply_corr_cond_op = arith.CmpiOp(cast_index_op.result, minus_1_const_op.result, "ne")

if_apply_corr_op = scf.IfOp(
apply_corr_cond_op.result,
return_types=(codeblock.type,),
true_region=Region(Block()),
false_region=Region(Block()),
)

with ImplicitBuilder(if_apply_corr_op.true_region):
# This branch is for the case where a correctable error was detected
extract_err_qubit_op = qecp.ExtractQubitOp(codeblock=codeblock, idx=err_idx)
err_qubit = extract_err_qubit_op.qubit

match check_type:
case CheckType.X:
corr_qubit_op = qecp.PauliZOp(in_qubit=err_qubit)
case CheckType.Z:
corr_qubit_op = qecp.PauliXOp(in_qubit=err_qubit)
case _:
assert False, f"Unknown CheckType: '{check_type}'"

insert_err_qubit_op = qecp.InsertQubitOp(
in_codeblock=post_check_codeblock, idx=err_idx, qubit=corr_qubit_op.out_qubit
)

scf.YieldOp(insert_err_qubit_op.out_codeblock)

with ImplicitBuilder(if_apply_corr_op.false_region):
# This branch is for the case where no correctable error was detected
scf.YieldOp(codeblock)

out_codeblock = cast(OpResult[qecp.PhysicalCodeblockType], if_apply_corr_op.results[0])

scf.YieldOp(out_codeblock)

# Return updated codeblock SSA value
return out_codeblock


convert_qecl_to_qecp_pass = compiler_transform(ConvertQecLogicalToQecPhysicalPass)
Loading
Loading