Skip to content

Commit defff56

Browse files
committed
feat: integrate compute_output_permutation into _optimize_unitary
Extract with `up_to_perm=True` so that basic_optimization runs on a SWAP-free circuit, then prepend the output permutation as SWAP gate objects (1 gate each instead of 3 CNOTs).
1 parent 6d18e28 commit defff56

2 files changed

Lines changed: 129 additions & 26 deletions

File tree

test/test_zxpass.py

Lines changed: 90 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -694,6 +694,75 @@ def test_compute_output_permutation_non_bijective() -> None:
694694
compute_output_permutation(g)
695695

696696

697+
def test_permutation_swaps_in_pipeline() -> None:
698+
"""Test that _optimize_unitary uses SWAP gates (not CNOT decompositions).
699+
700+
After extraction with up_to_perm=True, the output permutation should be
701+
prepended as SWAP gate objects. Each SWAP counts as one gate instead of
702+
three CNOTs, improving the gate-count comparison.
703+
"""
704+
from zxpass.zxpass import _optimize_unitary # pylint: disable=import-outside-toplevel
705+
706+
# Build a circuit that produces a non-trivial output permutation.
707+
c = zx.Circuit(4)
708+
c.add_gate("CNOT", 0, 1)
709+
c.add_gate("CNOT", 1, 2)
710+
c.add_gate("HAD", 0)
711+
c.add_gate("CNOT", 2, 3)
712+
c.add_gate("HAD", 3)
713+
c.add_gate("CNOT", 3, 0)
714+
c.add_gate("CNOT", 0, 2)
715+
c.add_gate("HAD", 1)
716+
c.add_gate("CNOT", 2, 1)
717+
c.add_gate("CNOT", 1, 3)
718+
c.add_gate("HAD", 2)
719+
c.add_gate("CNOT", 3, 2)
720+
721+
optimized = _optimize_unitary(c)
722+
723+
# If the fallback didn't trigger, verify no CNOT-decomposed SWAPs leaked.
724+
if len(optimized.gates) < len(c.gates):
725+
# SWAPs should be at the beginning (prepended for the permutation).
726+
for i, g in enumerate(optimized.gates):
727+
if g.qasm_name != "swap":
728+
# All subsequent gates should be non-SWAP.
729+
rest = optimized.gates[i:]
730+
assert all(r.qasm_name != "swap" for r in rest), (
731+
"SWAP gates should only appear at the beginning"
732+
)
733+
break
734+
735+
736+
def test_permutation_to_swaps_correctness() -> None:
737+
"""Test that _permutation_to_swaps produces a correct SWAP sequence."""
738+
from zxpass.zxpass import _permutation_to_swaps # pylint: disable=import-outside-toplevel
739+
740+
# Identity permutation: no SWAPs needed.
741+
assert not _permutation_to_swaps({0: 0, 1: 1, 2: 2})
742+
743+
# Simple transposition.
744+
swaps = _permutation_to_swaps({0: 1, 1: 0})
745+
assert len(swaps) == 1
746+
assert set(swaps[0]) == {0, 1}
747+
748+
# 3-cycle: needs 2 transpositions.
749+
perm = {0: 1, 1: 2, 2: 0}
750+
swaps = _permutation_to_swaps(perm)
751+
# Verify the SWAPs implement the permutation.
752+
state = list(range(3))
753+
for i, j in swaps:
754+
state[i], state[j] = state[j], state[i]
755+
assert state == [perm[k] for k in range(3)]
756+
757+
# Larger permutation with multiple cycles.
758+
perm = {0: 2, 1: 0, 2: 1, 3: 4, 4: 3}
759+
swaps = _permutation_to_swaps(perm)
760+
state = list(range(5))
761+
for i, j in swaps:
762+
state[i], state[j] = state[j], state[i]
763+
assert state == [perm[k] for k in range(5)]
764+
765+
697766
def test_post_extraction_cleanup() -> None:
698767
"""Test that ``_optimize_unitary`` applies ``basic_optimization`` after extraction.
699768
@@ -738,32 +807,35 @@ def test_post_extraction_cleanup() -> None:
738807

739808

740809
def test_post_extraction_cleanup_equivalence() -> None:
741-
"""Test that the post-extraction cleanup preserves circuit equivalence.
742-
743-
Runs multiple circuits through the full ZXPass pipeline and verifies
744-
statevector equivalence after the basic_optimization post-pass.
745-
"""
746-
circuits = []
747-
748-
# Bell state preparation with extra gates.
810+
"""Test that post-extraction cleanup and permutation integration preserve equivalence."""
749811
qc1 = QuantumCircuit(3)
750812
qc1.h(0)
751813
qc1.cx(0, 1)
752814
qc1.cx(1, 2)
753815
qc1.h(2)
754816
qc1.cx(2, 0)
755-
circuits.append(qc1)
756817

757-
# Multi-CX circuit.
758-
qc2 = QuantumCircuit(4)
759-
for i in range(3):
818+
qc2 = QuantumCircuit(5)
819+
for i in range(5):
820+
qc2.h(i)
821+
for i in range(4):
760822
qc2.cx(i, i + 1)
761-
qc2.h(0)
762-
qc2.cx(3, 0)
763-
qc2.h(1)
764-
circuits.append(qc2)
765-
766-
for qc in circuits:
823+
qc2.cx(4, 0)
824+
for i in range(5):
825+
qc2.rz(0.3 * (i + 1), i)
826+
827+
qc3 = QuantumCircuit(4)
828+
qc3.h(0)
829+
qc3.cx(0, 1)
830+
qc3.h(1)
831+
qc3.cx(1, 2)
832+
qc3.h(2)
833+
qc3.cx(2, 3)
834+
qc3.cx(3, 0)
835+
qc3.cx(2, 1)
836+
qc3.cx(1, 0)
837+
838+
for qc in [qc1, qc2, qc3]:
767839
assert _run_zxpass(qc), f"Equivalence failed for circuit:\n{qc}"
768840

769841

zxpass/zxpass.py

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -177,21 +177,52 @@ def compute_output_permutation(g: Any) -> Dict[int, int]:
177177
return perm
178178

179179

180+
def _permutation_to_swaps(perm: Dict[int, int]) -> List[Tuple[int, int]]:
181+
"""Decompose a permutation into transpositions (SWAP pairs).
182+
183+
Uses cycle decomposition. The returned list, applied left to right,
184+
implements the forward permutation (wire *j* ends up holding the state
185+
of input qubit ``perm[j]``).
186+
"""
187+
n = len(perm)
188+
current = [perm[i] for i in range(n)]
189+
swaps: List[Tuple[int, int]] = []
190+
for i in range(n):
191+
while current[i] != i:
192+
j = current[i]
193+
swaps.append((i, j))
194+
current[i], current[j] = current[j], current[i]
195+
swaps.reverse()
196+
return swaps
197+
198+
180199
def _optimize_unitary(c: zx.Circuit) -> zx.Circuit:
181200
"""Optimise a purely unitary PyZX circuit using full_reduce and extraction.
182201
183-
After extraction, ``basic_optimization`` converts HAD-CZ-HAD sequences to
184-
CNOTs and cancels redundant single-qubit gates. If the result still has at
185-
least as many gates as the original circuit, the original is returned
186-
unchanged to avoid regressions on small circuits with compact multi-qubit
187-
gates (e.g. Toffoli, Fredkin). The comparison counts PyZX gate objects
188-
directly; since ``_recover_dag`` emits one Qiskit op per PyZX gate, this
189-
matches the Qiskit-side ``size()`` that downstream passes see.
202+
Extracts with ``up_to_perm=True`` so that ``basic_optimization`` runs on a
203+
circuit free of SWAP-decomposition clutter. The output permutation is then
204+
prepended as SWAP gates (each counting as one gate rather than three CNOTs),
205+
giving a fairer gate-count comparison against the original circuit. If the
206+
result still has at least as many gates as the original, the original is
207+
returned unchanged to avoid regressions on small circuits with compact
208+
multi-qubit gates (e.g. Toffoli, Fredkin). The comparison counts PyZX gate
209+
objects directly; since ``_recover_dag`` emits one Qiskit op per PyZX gate,
210+
this matches the Qiskit-side ``size()`` that downstream passes see.
190211
"""
191212
g = c.to_graph()
192213
zx.simplify.full_reduce(g)
193-
optimized = zx.extract.extract_circuit(g)
214+
optimized = zx.extract.extract_circuit(g, up_to_perm=True)
215+
perm = compute_output_permutation(g)
194216
optimized = basic_optimization(optimized.to_basic_gates(), do_swaps=False)
217+
# Prepend SWAP gates for the output permutation.
218+
swap_pairs = _permutation_to_swaps(perm)
219+
if swap_pairs:
220+
with_perm = zx.Circuit(c.qubits)
221+
for i, j in swap_pairs:
222+
with_perm.add_gate(SWAP(i, j))
223+
for gate in optimized.gates:
224+
with_perm.add_gate(gate)
225+
optimized = with_perm
195226
# TODO: Consider a two-axis comparison keyed primarily on 2-qubit gate
196227
# count (``twoqubitcount()``), with total gate count as a tiebreaker. The
197228
# 2-qubit count is the dominant hardware cost and is naturally apples-to-

0 commit comments

Comments
 (0)