diff --git a/cirq-core/cirq/transformers/routing/route_circuit_cqc.py b/cirq-core/cirq/transformers/routing/route_circuit_cqc.py index 3c4032b5ea3..2c78b59de46 100644 --- a/cirq-core/cirq/transformers/routing/route_circuit_cqc.py +++ b/cirq-core/cirq/transformers/routing/route_circuit_cqc.py @@ -81,6 +81,15 @@ class RouteCQC: the swap that minimises the cost and use it to update our logical to physical mapping. Repeat from 3.1. + Handling Directed Graphs: + + When the device_graph is directed (e.g., edges represent unidirectional CNOT constraints), + the routing logic still operates as if the graph were undirected. This is because SWAP + gates are logically symmetric regardless of underlying gate direction constraints. + After routing completes, any inserted SWAP gates are decomposed into a directional-aware + sequence using the Hadamard trick: + ``SWAP = CNOT(ctrl, tgt) - H⊗H - CNOT(ctrl, tgt) - H⊗H - CNOT(ctrl, tgt)`` + For example: >>> import cirq_google as cg @@ -100,13 +109,9 @@ def __init__(self, device_graph: nx.Graph): Args: device_graph: The connectivity graph of physical qubits. - - Raises: - ValueError: if `device_graph` is a directed graph. + Can be directed or undirected. """ - if nx.is_directed(device_graph): - raise ValueError("Device graph must be undirected.") self.device_graph = device_graph def __call__( @@ -182,7 +187,7 @@ def route_circuit( circuit: the input circuit to be transformed. lookahead_radius: the maximum number of succeeding timesteps the algorithm will consider for ranking candidate swaps with the cost cost function. - tag_inserted_swaps: whether or not a RoutingSwapTag should be attched to inserted swap + tag_inserted_swaps: whether or not a RoutingSwapTag should be attached to inserted swap operations. initial_mapper: an initial mapping strategy (placement) of logical qubits in the circuit onto physical qubits on the device. @@ -223,7 +228,7 @@ def route_circuit( two_qubit_ops, single_qubit_ops = self._get_one_and_two_qubit_ops_as_timesteps(circuit) # 4. Do the routing and save the routed circuit as a list of moments. - routed_ops = self._route( + routed_ops, routing_swaps = self._route( mm, two_qubit_ops, single_qubit_ops, @@ -231,10 +236,20 @@ def route_circuit( tag_inserted_swaps=tag_inserted_swaps, ) - # 5. Return the routed circuit by packing each inner list of ops as densely as possible and - # preserving outer moment structure. Also return initial map and swap permutation map. + # 5. Replace tagged SWAP gates with directional decompositions if needed. + # This handles directed device graphs by decomposing SWAP into a sequence of CNOTs + # that respect the edge direction constraints. + routed_circuit = circuits.Circuit(circuits.Circuit(m) for m in routed_ops) + if routing_swaps and nx.is_directed(self.device_graph): + final_circuit = self._replace_swaps_with_directional_decomposition( + routed_circuit, routing_swaps + ) + else: + final_circuit = routed_circuit + + # 6. Return the routed circuit and mappings. return ( - circuits.Circuit(circuits.Circuit(m) for m in routed_ops), + final_circuit, initial_mapping, { initial_mapping[mm.int_to_logical_qid[k]]: mm.int_to_physical_qid[v] @@ -285,6 +300,59 @@ def _get_one_and_two_qubit_ops_as_timesteps( two_qubit_ops = [list(m) for m in two_qubit_circuit] return two_qubit_ops, single_qubit_ops + def _replace_swaps_with_directional_decomposition( + self, circuit: cirq.AbstractCircuit, routing_swaps: set[cirq.Operation] + ) -> cirq.AbstractCircuit: + """Replaces routing-added SWAP gates with directional decompositions. + + For directed device graphs, SWAP gates need to be decomposed into CNOTs that + respect the edge direction. This method uses cirq.map_operations_and_unroll to + find all routing-added SWAP gates and replaces them with the appropriate + decomposition. + + For bidirectional edges (or undirected graphs), the SWAP is left unchanged. + For unidirectional edges, the SWAP is decomposed using the Hadamard trick: + SWAP = CNOT(ctrl, tgt) - H⊗H - CNOT(ctrl, tgt) - H⊗H - CNOT(ctrl, tgt) + + Args: + circuit: The routed circuit containing SWAP operations. + routing_swaps: Set of routing-added SWAP operations to decompose. + + Returns: + Circuit with directional SWAP decompositions where needed. + """ + + def map_func(op: cirq.Operation, _: int) -> cirq.OP_TREE: + """Map function to replace routing-added SWAPs with directional decomposition.""" + # Check if this is a routing-added SWAP operation + if op not in routing_swaps: + return op + + q1, q2 = op.qubits + has_forward = self.device_graph.has_edge(q1, q2) + has_reverse = self.device_graph.has_edge(q2, q1) + + if has_forward ^ has_reverse: + # Unidirectional: decompose SWAP using Hadamard trick + ctrl, tgt = (q1, q2) if has_forward else (q2, q1) + # Preserve the RoutingSwapTag on the decomposed operations + decomposed_ops: list[cirq.Operation] = [ + ops.CNOT(ctrl, tgt), + ops.H(ctrl), + ops.H(tgt), + ops.CNOT(ctrl, tgt), + ops.H(ctrl), + ops.H(tgt), + ops.CNOT(ctrl, tgt), + ] + # Transfer tags from original SWAP to decomposed operations + return [op_i.with_tags(*op.tags) for op_i in decomposed_ops] + + # Bidirectional or no edge check needed at routing level - keep as-is + return op + + return transformer_primitives.map_operations_and_unroll(circuit, map_func) + @classmethod def _route( cls, @@ -293,10 +361,10 @@ def _route( single_qubit_ops: list[list[cirq.Operation]], lookahead_radius: int, tag_inserted_swaps: bool = False, - ) -> list[list[cirq.Operation]]: + ) -> tuple[list[list[cirq.Operation]], set[cirq.Operation]]: """Main routing procedure that inserts necessary swaps on the given timesteps. - The i'th element of the returned list corresponds to the routed operatiosn in the i'th + The i'th element of the returned list corresponds to the routed operations in the i'th timestep. Args: @@ -306,11 +374,12 @@ def _route( the paper. lookahead_radius: the maximum number of times the cost function can be iterated for convergence. - tag_inserted_swaps: whether or not a RoutingSwapTag should be attched to inserted swap + tag_inserted_swaps: whether or not a RoutingSwapTag should be attached to inserted swap operations. Returns: - a list of lists corresponding to timesteps of the routed circuit. + A list of lists corresponding to timesteps of the routed circuit and + a set of inserted SWAP operations. """ two_qubit_ops_ints: list[list[QidIntPair]] = [ [ @@ -336,6 +405,8 @@ def process_executable_two_qubit_ops(timestep: int) -> int: strats = [cls._choose_single_swap, cls._choose_pair_of_swaps] + inserted_swaps: set[cirq.Operation] = set() + for timestep in range(len(two_qubit_ops)): # Add single-qubit ops with qubits given by the current mapping. routed_ops.append([mm.mapped_op(op) for op in single_qubit_ops[timestep]]) @@ -362,10 +433,11 @@ def process_executable_two_qubit_ops(timestep: int) -> int: ) if tag_inserted_swaps: inserted_swap = inserted_swap.with_tags(ops.RoutingSwapTag()) + inserted_swaps.add(inserted_swap) routed_ops[timestep].append(inserted_swap) mm.apply_swap(*swap) - return routed_ops + return routed_ops, inserted_swaps @classmethod def _brute_force_strategy( diff --git a/cirq-core/cirq/transformers/routing/route_circuit_cqc_test.py b/cirq-core/cirq/transformers/routing/route_circuit_cqc_test.py index 530f9c6344a..fac3f35c63a 100644 --- a/cirq-core/cirq/transformers/routing/route_circuit_cqc_test.py +++ b/cirq-core/cirq/transformers/routing/route_circuit_cqc_test.py @@ -14,6 +14,7 @@ from __future__ import annotations +import networkx as nx import pytest import cirq @@ -22,8 +23,48 @@ def test_directed_device() -> None: device = cirq.testing.construct_ring_device(10, directed=True) device_graph = device.metadata.nx_graph - with pytest.raises(ValueError, match="Device graph must be undirected."): - cirq.RouteCQC(device_graph) + # Directed graphs should now be accepted + router = cirq.RouteCQC(device_graph) + # Test that we can route a simple circuit on a directed graph + q = cirq.LineQubit.range(3) + circuit = cirq.Circuit(cirq.CNOT(q[0], q[1]), cirq.CNOT(q[1], q[2])) + hard_coded_mapper = cirq.HardCodedInitialMapper(dict(zip(q, q))) + routed_circuit = router(circuit, initial_mapper=hard_coded_mapper) + device.validate_circuit(routed_circuit) + cirq.testing.assert_same_circuits(routed_circuit, circuit) + + +@pytest.mark.parametrize("tag_inserted_swaps", [True, False]) +def test_directed_device_swap_decomposition(tag_inserted_swaps: bool) -> None: + # Create a simple directed graph: q0 -> q1 (one-way only) + device_graph = nx.DiGraph() + q = cirq.LineQubit.range(2) + device_graph.add_edge(q[0], q[1]) + + router = cirq.RouteCQC(device_graph) + + # A circuit that requires a SWAP to execute (qubits need to be swapped) + circuit = cirq.Circuit(cirq.CNOT(q[1], q[0])) # Reverse direction not available + + hard_coded_mapper = cirq.HardCodedInitialMapper({q[0]: q[0], q[1]: q[1]}) + routed_circuit = router( + circuit, initial_mapper=hard_coded_mapper, tag_inserted_swaps=tag_inserted_swaps + ) + + # Expected: Hadamard-based SWAP decomposition followed by CNOT + # SWAP decomposition for unidirectional edge: CNOT-H⊗H-CNOT-H⊗H-CNOT + t = (cirq.RoutingSwapTag(),) if tag_inserted_swaps else () + expected = cirq.Circuit( + cirq.CNOT(q[0], q[1]).with_tags(*t), + cirq.H(q[0]).with_tags(*t), + cirq.H(q[1]).with_tags(*t), + cirq.CNOT(q[0], q[1]).with_tags(*t), + cirq.H(q[0]).with_tags(*t), + cirq.H(q[1]).with_tags(*t), + cirq.CNOT(q[0], q[1]).with_tags(*t), + cirq.CNOT(q[0], q[1]), # The original CNOT after swap + ) + cirq.testing.assert_same_circuits(routed_circuit, expected) @pytest.mark.parametrize( @@ -104,7 +145,7 @@ def test_circuit_with_measurement_gates() -> None: device_graph = device.metadata.nx_graph q = cirq.LineQubit.range(3) circuit = cirq.Circuit(cirq.MeasurementGate(2).on(q[0], q[2]), cirq.MeasurementGate(3).on(*q)) - hard_coded_mapper = cirq.HardCodedInitialMapper({q[i]: q[i] for i in range(3)}) + hard_coded_mapper = cirq.HardCodedInitialMapper(dict(zip(q, q))) router = cirq.RouteCQC(device_graph) routed_circuit = router(circuit, initial_mapper=hard_coded_mapper) cirq.testing.assert_same_circuits(routed_circuit, circuit) @@ -115,7 +156,7 @@ def test_circuit_with_two_qubit_intermediate_measurement_gate() -> None: device_graph = device.metadata.nx_graph router = cirq.RouteCQC(device_graph) qs = cirq.LineQubit.range(2) - hard_coded_mapper = cirq.HardCodedInitialMapper({qs[i]: qs[i] for i in range(2)}) + hard_coded_mapper = cirq.HardCodedInitialMapper(dict(zip(qs, qs))) circuit = cirq.Circuit([cirq.Moment(cirq.measure(qs)), cirq.Moment(cirq.H.on_each(qs))]) routed_circuit = router( circuit, initial_mapper=hard_coded_mapper, context=cirq.TransformerContext(deep=True) @@ -128,7 +169,7 @@ def test_circuit_with_multi_qubit_intermediate_measurement_gate_and_with_default device_graph = device.metadata.nx_graph router = cirq.RouteCQC(device_graph) qs = cirq.LineQubit.range(3) - hard_coded_mapper = cirq.HardCodedInitialMapper({qs[i]: qs[i] for i in range(3)}) + hard_coded_mapper = cirq.HardCodedInitialMapper(dict(zip(qs, qs))) circuit = cirq.Circuit([cirq.Moment(cirq.measure(qs)), cirq.Moment(cirq.H.on_each(qs))]) routed_circuit = router( circuit, initial_mapper=hard_coded_mapper, context=cirq.TransformerContext(deep=True) @@ -142,7 +183,7 @@ def test_circuit_with_multi_qubit_intermediate_measurement_gate_with_custom_key( device_graph = device.metadata.nx_graph router = cirq.RouteCQC(device_graph) qs = cirq.LineQubit.range(3) - hard_coded_mapper = cirq.HardCodedInitialMapper({qs[i]: qs[i] for i in range(3)}) + hard_coded_mapper = cirq.HardCodedInitialMapper(dict(zip(qs, qs))) circuit = cirq.Circuit( [cirq.Moment(cirq.measure(qs, key="test")), cirq.Moment(cirq.H.on_each(qs))] ) @@ -163,7 +204,7 @@ def test_circuit_with_non_unitary_and_global_phase() -> None: cirq.Moment(cirq.depolarize(0.1, 2).on(q[0], q[2]), cirq.depolarize(0.1).on(q[1])), ] ) - hard_coded_mapper = cirq.HardCodedInitialMapper({q[i]: q[i] for i in range(3)}) + hard_coded_mapper = cirq.HardCodedInitialMapper(dict(zip(q, q))) router = cirq.RouteCQC(device_graph) routed_circuit = router(circuit, initial_mapper=hard_coded_mapper) expected = cirq.Circuit( @@ -191,7 +232,7 @@ def test_circuit_with_tagged_ops() -> None: cirq.Moment(cirq.X(q[0]).with_tags("u")), ] ) - hard_coded_mapper = cirq.HardCodedInitialMapper({q[i]: q[i] for i in range(3)}) + hard_coded_mapper = cirq.HardCodedInitialMapper(dict(zip(q, q))) router = cirq.RouteCQC(device_graph) routed_circuit = router(circuit, initial_mapper=hard_coded_mapper) expected = cirq.Circuit( @@ -237,6 +278,97 @@ def test_empty_circuit() -> None: ) +def test_directed_device_with_tag_inserted_swaps() -> None: + # Use a directed ring device + device = cirq.testing.construct_ring_device(10, directed=True) + device_graph = device.metadata.nx_graph + router = cirq.RouteCQC(device_graph) + + q = cirq.LineQubit.range(3) + # Force routing with non-adjacent qubits to trigger swaps + circuit = cirq.Circuit(cirq.CNOT(q[0], q[2])) + + hard_coded_mapper = cirq.HardCodedInitialMapper(dict(zip(q, q))) + routed_circuit = router(circuit, initial_mapper=hard_coded_mapper, tag_inserted_swaps=True) + + # Verify that operations with RoutingSwapTag exist + tagged_ops = [op for op in routed_circuit.all_operations() if cirq.RoutingSwapTag() in op.tags] + assert tagged_ops + + # Verify presence of gates from the decomposed swap + tagged_gates = {op.gate for op in tagged_ops} + assert tagged_gates == {cirq.CNOT, cirq.H} + + +def test_directed_device_reverse_only_edge() -> None: + # Create a mixed graph with both bidirectional and reverse-only edges + device_graph = nx.DiGraph() + q = cirq.LineQubit.range(4) + # Create a path: 0<->1->2<-3 (1->2 is forward-only, 3->2 is reverse-only) + device_graph.add_edges_from( + [ + (q[0], q[1]), + (q[1], q[0]), # bidirectional + (q[1], q[2]), # forward-only + (q[3], q[2]), # reverse-only (no 2->3 edge) + ] + ) + + router = cirq.RouteCQC(device_graph) + + # Create a circuit that requires a swap on the reverse-only edge + # Map qubits so we need to swap on edge 3<-2 + circuit = cirq.Circuit(cirq.CNOT(q[0], q[1]), cirq.CNOT(q[2], q[3]), cirq.CNOT(q[1], q[0])) + + hard_coded_mapper = cirq.HardCodedInitialMapper(dict(zip(q, q))) + routed_circuit, initial_map, swap_map = router.route_circuit( + circuit, initial_mapper=hard_coded_mapper, tag_inserted_swaps=True + ) + + # Verify that operations with RoutingSwapTag exist + tagged_ops = [op for op in routed_circuit.all_operations() if cirq.RoutingSwapTag() in op.tags] + assert tagged_ops + + # Verify presence of gates from the decomposed swap + tagged_gates = {op.gate for op in tagged_ops} + assert tagged_gates == {cirq.CNOT, cirq.H} + + # Verify CNOTs on bidirectional edge were preserved and the CNOT on the reverse edge flipped + untagged_ops = [op for op in routed_circuit.all_operations() if not op.tags] + assert untagged_ops == [cirq.CNOT(q[0], q[1]), cirq.CNOT(q[3], q[2]), cirq.CNOT(q[1], q[0])] + + # Verify the routed circuit is mathematically equivalent to the original + cirq.testing.assert_circuits_have_same_unitary_given_final_permutation( + routed_circuit, circuit, swap_map + ) + + +def test_directed_device_bidirectional_swap_preserved() -> None: + # Create a graph: 0 <-> 1 -> 2 + device_graph = nx.DiGraph() + q = cirq.LineQubit.range(3) + device_graph.add_edges_from([(q[0], q[1]), (q[1], q[0]), (q[1], q[2])]) + + router = cirq.RouteCQC(device_graph) + + # Circuit needing routing: CNOT(0, 2). Shortest path is 0->1->2. + # Routing should swap 0 and 1 to bring them adjacent. + circuit = cirq.Circuit(cirq.CNOT(q[0], q[2])) + + hard_coded_mapper = cirq.HardCodedInitialMapper(dict(zip(q, q))) + routed_circuit = router(circuit, initial_mapper=hard_coded_mapper, tag_inserted_swaps=True) + + # Check for preserved SWAP(0, 1) + # Identify SWAPs that are kept as SWAP gates (bidirectional) and tagged + preserved_swaps = [ + op + for op in routed_circuit.all_operations() + if isinstance(op.gate, cirq.SwapPowGate) and cirq.RoutingSwapTag() in op.tags + ] + + assert preserved_swaps == [cirq.SWAP(q[0], q[1]).with_tags(cirq.RoutingSwapTag())] + + def test_repr() -> None: device = cirq.testing.construct_ring_device(10) device_graph = device.metadata.nx_graph