@@ -694,6 +694,75 @@ def test_compute_output_permutation_non_bijective() -> None:
694694 compute_output_permutation (g )
695695
696696
697+ def test_permutation_swaps_in_pipeline () -> None :
698+ """Test that _optimize_unitary uses SWAP gates (not CNOT decompositions).
699+
700+ After extraction with up_to_perm=True, the output permutation should be
701+ prepended as SWAP gate objects. Each SWAP counts as one gate instead of
702+ three CNOTs, improving the gate-count comparison.
703+ """
704+ from zxpass .zxpass import _optimize_unitary # pylint: disable=import-outside-toplevel
705+
706+ # Build a circuit that produces a non-trivial output permutation.
707+ c = zx .Circuit (4 )
708+ c .add_gate ("CNOT" , 0 , 1 )
709+ c .add_gate ("CNOT" , 1 , 2 )
710+ c .add_gate ("HAD" , 0 )
711+ c .add_gate ("CNOT" , 2 , 3 )
712+ c .add_gate ("HAD" , 3 )
713+ c .add_gate ("CNOT" , 3 , 0 )
714+ c .add_gate ("CNOT" , 0 , 2 )
715+ c .add_gate ("HAD" , 1 )
716+ c .add_gate ("CNOT" , 2 , 1 )
717+ c .add_gate ("CNOT" , 1 , 3 )
718+ c .add_gate ("HAD" , 2 )
719+ c .add_gate ("CNOT" , 3 , 2 )
720+
721+ optimized = _optimize_unitary (c )
722+
723+ # If the fallback didn't trigger, verify no CNOT-decomposed SWAPs leaked.
724+ if len (optimized .gates ) < len (c .gates ):
725+ # SWAPs should be at the beginning (prepended for the permutation).
726+ for i , g in enumerate (optimized .gates ):
727+ if g .qasm_name != "swap" :
728+ # All subsequent gates should be non-SWAP.
729+ rest = optimized .gates [i :]
730+ assert all (r .qasm_name != "swap" for r in rest ), (
731+ "SWAP gates should only appear at the beginning"
732+ )
733+ break
734+
735+
736+ def test_permutation_to_swaps_correctness () -> None :
737+ """Test that _permutation_to_swaps produces a correct SWAP sequence."""
738+ from zxpass .zxpass import _permutation_to_swaps # pylint: disable=import-outside-toplevel
739+
740+ # Identity permutation: no SWAPs needed.
741+ assert not _permutation_to_swaps ({0 : 0 , 1 : 1 , 2 : 2 })
742+
743+ # Simple transposition.
744+ swaps = _permutation_to_swaps ({0 : 1 , 1 : 0 })
745+ assert len (swaps ) == 1
746+ assert set (swaps [0 ]) == {0 , 1 }
747+
748+ # 3-cycle: needs 2 transpositions.
749+ perm = {0 : 1 , 1 : 2 , 2 : 0 }
750+ swaps = _permutation_to_swaps (perm )
751+ # Verify the SWAPs implement the permutation.
752+ state = list (range (3 ))
753+ for i , j in swaps :
754+ state [i ], state [j ] = state [j ], state [i ]
755+ assert state == [perm [k ] for k in range (3 )]
756+
757+ # Larger permutation with multiple cycles.
758+ perm = {0 : 2 , 1 : 0 , 2 : 1 , 3 : 4 , 4 : 3 }
759+ swaps = _permutation_to_swaps (perm )
760+ state = list (range (5 ))
761+ for i , j in swaps :
762+ state [i ], state [j ] = state [j ], state [i ]
763+ assert state == [perm [k ] for k in range (5 )]
764+
765+
697766def test_post_extraction_cleanup () -> None :
698767 """Test that ``_optimize_unitary`` applies ``basic_optimization`` after extraction.
699768
@@ -738,32 +807,35 @@ def test_post_extraction_cleanup() -> None:
738807
739808
740809def test_post_extraction_cleanup_equivalence () -> None :
741- """Test that the post-extraction cleanup preserves circuit equivalence.
742-
743- Runs multiple circuits through the full ZXPass pipeline and verifies
744- statevector equivalence after the basic_optimization post-pass.
745- """
746- circuits = []
747-
748- # Bell state preparation with extra gates.
810+ """Test that post-extraction cleanup and permutation integration preserve equivalence."""
749811 qc1 = QuantumCircuit (3 )
750812 qc1 .h (0 )
751813 qc1 .cx (0 , 1 )
752814 qc1 .cx (1 , 2 )
753815 qc1 .h (2 )
754816 qc1 .cx (2 , 0 )
755- circuits .append (qc1 )
756817
757- # Multi-CX circuit.
758- qc2 = QuantumCircuit (4 )
759- for i in range (3 ):
818+ qc2 = QuantumCircuit (5 )
819+ for i in range (5 ):
820+ qc2 .h (i )
821+ for i in range (4 ):
760822 qc2 .cx (i , i + 1 )
761- qc2 .h (0 )
762- qc2 .cx (3 , 0 )
763- qc2 .h (1 )
764- circuits .append (qc2 )
765-
766- for qc in circuits :
823+ qc2 .cx (4 , 0 )
824+ for i in range (5 ):
825+ qc2 .rz (0.3 * (i + 1 ), i )
826+
827+ qc3 = QuantumCircuit (4 )
828+ qc3 .h (0 )
829+ qc3 .cx (0 , 1 )
830+ qc3 .h (1 )
831+ qc3 .cx (1 , 2 )
832+ qc3 .h (2 )
833+ qc3 .cx (2 , 3 )
834+ qc3 .cx (3 , 0 )
835+ qc3 .cx (2 , 1 )
836+ qc3 .cx (1 , 0 )
837+
838+ for qc in [qc1 , qc2 , qc3 ]:
767839 assert _run_zxpass (qc ), f"Equivalence failed for circuit:\n { qc } "
768840
769841
0 commit comments