Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 87 additions & 15 deletions cirq-core/cirq/transformers/routing/route_circuit_cqc.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,15 @@ class RouteCQC:
the swap that minimises the cost and use it to update our logical to physical
mapping. Repeat from 3.1.

Handling Directed Graphs:

When the device_graph is directed (e.g., edges represent unidirectional CNOT constraints),
the routing logic still operates as if the graph were undirected. This is because SWAP
gates are logically symmetric regardless of underlying gate direction constraints.
After routing completes, any inserted SWAP gates are decomposed into a directional-aware
sequence using the Hadamard trick:
``SWAP = CNOT(ctrl, tgt) - H⊗H - CNOT(ctrl, tgt) - H⊗H - CNOT(ctrl, tgt)``

For example:

>>> import cirq_google as cg
Expand All @@ -100,13 +109,9 @@ def __init__(self, device_graph: nx.Graph):

Args:
device_graph: The connectivity graph of physical qubits.

Raises:
ValueError: if `device_graph` is a directed graph.
Can be directed or undirected.
"""

if nx.is_directed(device_graph):
raise ValueError("Device graph must be undirected.")
self.device_graph = device_graph

def __call__(
Expand Down Expand Up @@ -182,7 +187,7 @@ def route_circuit(
circuit: the input circuit to be transformed.
lookahead_radius: the maximum number of succeeding timesteps the algorithm will
consider for ranking candidate swaps with the cost cost function.
tag_inserted_swaps: whether or not a RoutingSwapTag should be attched to inserted swap
tag_inserted_swaps: whether or not a RoutingSwapTag should be attached to inserted swap
operations.
initial_mapper: an initial mapping strategy (placement) of logical qubits in the
circuit onto physical qubits on the device.
Expand Down Expand Up @@ -223,18 +228,28 @@ def route_circuit(
two_qubit_ops, single_qubit_ops = self._get_one_and_two_qubit_ops_as_timesteps(circuit)

# 4. Do the routing and save the routed circuit as a list of moments.
routed_ops = self._route(
routed_ops, routing_swaps = self._route(
mm,
two_qubit_ops,
single_qubit_ops,
lookahead_radius,
tag_inserted_swaps=tag_inserted_swaps,
)

# 5. Return the routed circuit by packing each inner list of ops as densely as possible and
# preserving outer moment structure. Also return initial map and swap permutation map.
# 5. Replace tagged SWAP gates with directional decompositions if needed.
# This handles directed device graphs by decomposing SWAP into a sequence of CNOTs
# that respect the edge direction constraints.
routed_circuit = circuits.Circuit(circuits.Circuit(m) for m in routed_ops)
if routing_swaps and nx.is_directed(self.device_graph):
final_circuit = self._replace_swaps_with_directional_decomposition(
routed_circuit, routing_swaps
)
else:
final_circuit = routed_circuit

# 6. Return the routed circuit and mappings.
return (
circuits.Circuit(circuits.Circuit(m) for m in routed_ops),
final_circuit,
initial_mapping,
{
initial_mapping[mm.int_to_logical_qid[k]]: mm.int_to_physical_qid[v]
Expand Down Expand Up @@ -285,6 +300,59 @@ def _get_one_and_two_qubit_ops_as_timesteps(
two_qubit_ops = [list(m) for m in two_qubit_circuit]
return two_qubit_ops, single_qubit_ops

def _replace_swaps_with_directional_decomposition(
self, circuit: cirq.AbstractCircuit, routing_swaps: set[cirq.Operation]
) -> cirq.AbstractCircuit:
"""Replaces routing-added SWAP gates with directional decompositions.

For directed device graphs, SWAP gates need to be decomposed into CNOTs that
respect the edge direction. This method uses cirq.map_operations_and_unroll to
find all routing-added SWAP gates and replaces them with the appropriate
decomposition.

For bidirectional edges (or undirected graphs), the SWAP is left unchanged.
For unidirectional edges, the SWAP is decomposed using the Hadamard trick:
SWAP = CNOT(ctrl, tgt) - H⊗H - CNOT(ctrl, tgt) - H⊗H - CNOT(ctrl, tgt)

Args:
circuit: The routed circuit containing SWAP operations.
routing_swaps: Set of routing-added SWAP operations to decompose.

Returns:
Circuit with directional SWAP decompositions where needed.
"""

def map_func(op: cirq.Operation, _: int) -> cirq.OP_TREE:
"""Map function to replace routing-added SWAPs with directional decomposition."""
# Check if this is a routing-added SWAP operation
if op not in routing_swaps:
return op

q1, q2 = op.qubits
has_forward = self.device_graph.has_edge(q1, q2)
has_reverse = self.device_graph.has_edge(q2, q1)

if has_forward ^ has_reverse:
# Unidirectional: decompose SWAP using Hadamard trick
ctrl, tgt = (q1, q2) if has_forward else (q2, q1)
# Preserve the RoutingSwapTag on the decomposed operations
decomposed_ops: list[cirq.Operation] = [
ops.CNOT(ctrl, tgt),
ops.H(ctrl),
ops.H(tgt),
ops.CNOT(ctrl, tgt),
ops.H(ctrl),
ops.H(tgt),
ops.CNOT(ctrl, tgt),
]
# Transfer tags from original SWAP to decomposed operations
return [op_i.with_tags(*op.tags) for op_i in decomposed_ops]

# Bidirectional or no edge check needed at routing level - keep as-is
return op

return transformer_primitives.map_operations_and_unroll(circuit, map_func)

@classmethod
def _route(
cls,
Expand All @@ -293,10 +361,10 @@ def _route(
single_qubit_ops: list[list[cirq.Operation]],
lookahead_radius: int,
tag_inserted_swaps: bool = False,
) -> list[list[cirq.Operation]]:
) -> tuple[list[list[cirq.Operation]], set[cirq.Operation]]:
"""Main routing procedure that inserts necessary swaps on the given timesteps.

The i'th element of the returned list corresponds to the routed operatiosn in the i'th
The i'th element of the returned list corresponds to the routed operations in the i'th
timestep.

Args:
Expand All @@ -306,11 +374,12 @@ def _route(
the paper.
lookahead_radius: the maximum number of times the cost function can be iterated for
convergence.
tag_inserted_swaps: whether or not a RoutingSwapTag should be attched to inserted swap
tag_inserted_swaps: whether or not a RoutingSwapTag should be attached to inserted swap
operations.

Returns:
a list of lists corresponding to timesteps of the routed circuit.
A list of lists corresponding to timesteps of the routed circuit and
a set of inserted SWAP operations.
"""
two_qubit_ops_ints: list[list[QidIntPair]] = [
[
Expand All @@ -336,6 +405,8 @@ def process_executable_two_qubit_ops(timestep: int) -> int:

strats = [cls._choose_single_swap, cls._choose_pair_of_swaps]

inserted_swaps: set[cirq.Operation] = set()

for timestep in range(len(two_qubit_ops)):
# Add single-qubit ops with qubits given by the current mapping.
routed_ops.append([mm.mapped_op(op) for op in single_qubit_ops[timestep]])
Expand All @@ -362,10 +433,11 @@ def process_executable_two_qubit_ops(timestep: int) -> int:
)
if tag_inserted_swaps:
inserted_swap = inserted_swap.with_tags(ops.RoutingSwapTag())
inserted_swaps.add(inserted_swap)
routed_ops[timestep].append(inserted_swap)
mm.apply_swap(*swap)

return routed_ops
return routed_ops, inserted_swaps

@classmethod
def _brute_force_strategy(
Expand Down
148 changes: 140 additions & 8 deletions cirq-core/cirq/transformers/routing/route_circuit_cqc_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from __future__ import annotations

import networkx as nx
import pytest

import cirq
Expand All @@ -22,8 +23,48 @@
def test_directed_device() -> None:
device = cirq.testing.construct_ring_device(10, directed=True)
device_graph = device.metadata.nx_graph
with pytest.raises(ValueError, match="Device graph must be undirected."):
cirq.RouteCQC(device_graph)
# Directed graphs should now be accepted
router = cirq.RouteCQC(device_graph)
# Test that we can route a simple circuit on a directed graph
q = cirq.LineQubit.range(3)
circuit = cirq.Circuit(cirq.CNOT(q[0], q[1]), cirq.CNOT(q[1], q[2]))
hard_coded_mapper = cirq.HardCodedInitialMapper(dict(zip(q, q)))
routed_circuit = router(circuit, initial_mapper=hard_coded_mapper)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The CNOTs in the circuit satisfy edge directions so the routed_circuit is the same. If intentional, please add cirq.testing.assert_same_circuits(routed_circuit, circuit).

device.validate_circuit(routed_circuit)
cirq.testing.assert_same_circuits(routed_circuit, circuit)


This comment was marked as outdated.

@pytest.mark.parametrize("tag_inserted_swaps", [True, False])
def test_directed_device_swap_decomposition(tag_inserted_swaps: bool) -> None:
# Create a simple directed graph: q0 -> q1 (one-way only)
device_graph = nx.DiGraph()
q = cirq.LineQubit.range(2)
device_graph.add_edge(q[0], q[1])

router = cirq.RouteCQC(device_graph)

# A circuit that requires a SWAP to execute (qubits need to be swapped)
circuit = cirq.Circuit(cirq.CNOT(q[1], q[0])) # Reverse direction not available

hard_coded_mapper = cirq.HardCodedInitialMapper({q[0]: q[0], q[1]: q[1]})
routed_circuit = router(
circuit, initial_mapper=hard_coded_mapper, tag_inserted_swaps=tag_inserted_swaps
)

# Expected: Hadamard-based SWAP decomposition followed by CNOT
# SWAP decomposition for unidirectional edge: CNOT-H⊗H-CNOT-H⊗H-CNOT
t = (cirq.RoutingSwapTag(),) if tag_inserted_swaps else ()
expected = cirq.Circuit(
cirq.CNOT(q[0], q[1]).with_tags(*t),
cirq.H(q[0]).with_tags(*t),
cirq.H(q[1]).with_tags(*t),
cirq.CNOT(q[0], q[1]).with_tags(*t),
cirq.H(q[0]).with_tags(*t),
cirq.H(q[1]).with_tags(*t),
cirq.CNOT(q[0], q[1]).with_tags(*t),
cirq.CNOT(q[0], q[1]), # The original CNOT after swap
)
cirq.testing.assert_same_circuits(routed_circuit, expected)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -104,7 +145,7 @@ def test_circuit_with_measurement_gates() -> None:
device_graph = device.metadata.nx_graph
q = cirq.LineQubit.range(3)
circuit = cirq.Circuit(cirq.MeasurementGate(2).on(q[0], q[2]), cirq.MeasurementGate(3).on(*q))
hard_coded_mapper = cirq.HardCodedInitialMapper({q[i]: q[i] for i in range(3)})
hard_coded_mapper = cirq.HardCodedInitialMapper(dict(zip(q, q)))
router = cirq.RouteCQC(device_graph)
routed_circuit = router(circuit, initial_mapper=hard_coded_mapper)
cirq.testing.assert_same_circuits(routed_circuit, circuit)
Expand All @@ -115,7 +156,7 @@ def test_circuit_with_two_qubit_intermediate_measurement_gate() -> None:
device_graph = device.metadata.nx_graph
router = cirq.RouteCQC(device_graph)
qs = cirq.LineQubit.range(2)
hard_coded_mapper = cirq.HardCodedInitialMapper({qs[i]: qs[i] for i in range(2)})
hard_coded_mapper = cirq.HardCodedInitialMapper(dict(zip(qs, qs)))
circuit = cirq.Circuit([cirq.Moment(cirq.measure(qs)), cirq.Moment(cirq.H.on_each(qs))])
routed_circuit = router(
circuit, initial_mapper=hard_coded_mapper, context=cirq.TransformerContext(deep=True)
Expand All @@ -128,7 +169,7 @@ def test_circuit_with_multi_qubit_intermediate_measurement_gate_and_with_default
device_graph = device.metadata.nx_graph
router = cirq.RouteCQC(device_graph)
qs = cirq.LineQubit.range(3)
hard_coded_mapper = cirq.HardCodedInitialMapper({qs[i]: qs[i] for i in range(3)})
hard_coded_mapper = cirq.HardCodedInitialMapper(dict(zip(qs, qs)))
circuit = cirq.Circuit([cirq.Moment(cirq.measure(qs)), cirq.Moment(cirq.H.on_each(qs))])
routed_circuit = router(
circuit, initial_mapper=hard_coded_mapper, context=cirq.TransformerContext(deep=True)
Expand All @@ -142,7 +183,7 @@ def test_circuit_with_multi_qubit_intermediate_measurement_gate_with_custom_key(
device_graph = device.metadata.nx_graph
router = cirq.RouteCQC(device_graph)
qs = cirq.LineQubit.range(3)
hard_coded_mapper = cirq.HardCodedInitialMapper({qs[i]: qs[i] for i in range(3)})
hard_coded_mapper = cirq.HardCodedInitialMapper(dict(zip(qs, qs)))
circuit = cirq.Circuit(
[cirq.Moment(cirq.measure(qs, key="test")), cirq.Moment(cirq.H.on_each(qs))]
)
Expand All @@ -163,7 +204,7 @@ def test_circuit_with_non_unitary_and_global_phase() -> None:
cirq.Moment(cirq.depolarize(0.1, 2).on(q[0], q[2]), cirq.depolarize(0.1).on(q[1])),
]
)
hard_coded_mapper = cirq.HardCodedInitialMapper({q[i]: q[i] for i in range(3)})
hard_coded_mapper = cirq.HardCodedInitialMapper(dict(zip(q, q)))
router = cirq.RouteCQC(device_graph)
routed_circuit = router(circuit, initial_mapper=hard_coded_mapper)
expected = cirq.Circuit(
Expand Down Expand Up @@ -191,7 +232,7 @@ def test_circuit_with_tagged_ops() -> None:
cirq.Moment(cirq.X(q[0]).with_tags("u")),
]
)
hard_coded_mapper = cirq.HardCodedInitialMapper({q[i]: q[i] for i in range(3)})
hard_coded_mapper = cirq.HardCodedInitialMapper(dict(zip(q, q)))
router = cirq.RouteCQC(device_graph)
routed_circuit = router(circuit, initial_mapper=hard_coded_mapper)
expected = cirq.Circuit(
Expand Down Expand Up @@ -237,6 +278,97 @@ def test_empty_circuit() -> None:
)


def test_directed_device_with_tag_inserted_swaps() -> None:
# Use a directed ring device
device = cirq.testing.construct_ring_device(10, directed=True)
device_graph = device.metadata.nx_graph
router = cirq.RouteCQC(device_graph)

q = cirq.LineQubit.range(3)
# Force routing with non-adjacent qubits to trigger swaps
circuit = cirq.Circuit(cirq.CNOT(q[0], q[2]))

hard_coded_mapper = cirq.HardCodedInitialMapper(dict(zip(q, q)))
routed_circuit = router(circuit, initial_mapper=hard_coded_mapper, tag_inserted_swaps=True)

# Verify that operations with RoutingSwapTag exist
tagged_ops = [op for op in routed_circuit.all_operations() if cirq.RoutingSwapTag() in op.tags]
assert tagged_ops

# Verify presence of gates from the decomposed swap
tagged_gates = {op.gate for op in tagged_ops}
assert tagged_gates == {cirq.CNOT, cirq.H}


def test_directed_device_reverse_only_edge() -> None:
# Create a mixed graph with both bidirectional and reverse-only edges
device_graph = nx.DiGraph()
q = cirq.LineQubit.range(4)
# Create a path: 0<->1->2<-3 (1->2 is forward-only, 3->2 is reverse-only)
device_graph.add_edges_from(
[
(q[0], q[1]),
(q[1], q[0]), # bidirectional
(q[1], q[2]), # forward-only
(q[3], q[2]), # reverse-only (no 2->3 edge)
]
)

router = cirq.RouteCQC(device_graph)

# Create a circuit that requires a swap on the reverse-only edge
# Map qubits so we need to swap on edge 3<-2
circuit = cirq.Circuit(cirq.CNOT(q[0], q[1]), cirq.CNOT(q[2], q[3]), cirq.CNOT(q[1], q[0]))

hard_coded_mapper = cirq.HardCodedInitialMapper(dict(zip(q, q)))
routed_circuit, initial_map, swap_map = router.route_circuit(
circuit, initial_mapper=hard_coded_mapper, tag_inserted_swaps=True
)

# Verify that operations with RoutingSwapTag exist
tagged_ops = [op for op in routed_circuit.all_operations() if cirq.RoutingSwapTag() in op.tags]
assert tagged_ops

# Verify presence of gates from the decomposed swap
tagged_gates = {op.gate for op in tagged_ops}
assert tagged_gates == {cirq.CNOT, cirq.H}

# Verify CNOTs on bidirectional edge were preserved and the CNOT on the reverse edge flipped
untagged_ops = [op for op in routed_circuit.all_operations() if not op.tags]
assert untagged_ops == [cirq.CNOT(q[0], q[1]), cirq.CNOT(q[3], q[2]), cirq.CNOT(q[1], q[0])]

# Verify the routed circuit is mathematically equivalent to the original
cirq.testing.assert_circuits_have_same_unitary_given_final_permutation(
Copy link
Collaborator

@pavoljuhas pavoljuhas Feb 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd suggest to move this check to test_directed_device_reverse_only_edge which will give it a bit more check if that test routes a circuit with bidirectional and directed edges.

routed_circuit, circuit, swap_map
)


def test_directed_device_bidirectional_swap_preserved() -> None:
# Create a graph: 0 <-> 1 -> 2
device_graph = nx.DiGraph()
q = cirq.LineQubit.range(3)
device_graph.add_edges_from([(q[0], q[1]), (q[1], q[0]), (q[1], q[2])])

router = cirq.RouteCQC(device_graph)

# Circuit needing routing: CNOT(0, 2). Shortest path is 0->1->2.
# Routing should swap 0 and 1 to bring them adjacent.
circuit = cirq.Circuit(cirq.CNOT(q[0], q[2]))

hard_coded_mapper = cirq.HardCodedInitialMapper(dict(zip(q, q)))
routed_circuit = router(circuit, initial_mapper=hard_coded_mapper, tag_inserted_swaps=True)

# Check for preserved SWAP(0, 1)
# Identify SWAPs that are kept as SWAP gates (bidirectional) and tagged
preserved_swaps = [
op
for op in routed_circuit.all_operations()
if isinstance(op.gate, cirq.SwapPowGate) and cirq.RoutingSwapTag() in op.tags
]

assert preserved_swaps == [cirq.SWAP(q[0], q[1]).with_tags(cirq.RoutingSwapTag())]


def test_repr() -> None:
device = cirq.testing.construct_ring_device(10)
device_graph = device.metadata.nx_graph
Expand Down
Loading