-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Support routing for directed device graphs #7810
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
48cccfb
8f7eca9
55c56f3
4f5ea0f
b0057ae
fe74356
3e4a9e7
632e00a
05570c7
b7fef09
0895d1d
dd3f24d
01705d9
e133470
87094f2
239552a
2f19844
114d316
2acecbd
dc784ff
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,6 +14,7 @@ | |
|
|
||
| from __future__ import annotations | ||
|
|
||
| import networkx as nx | ||
| import pytest | ||
|
|
||
| import cirq | ||
|
|
@@ -22,8 +23,48 @@ | |
| def test_directed_device() -> None: | ||
| device = cirq.testing.construct_ring_device(10, directed=True) | ||
| device_graph = device.metadata.nx_graph | ||
| with pytest.raises(ValueError, match="Device graph must be undirected."): | ||
| cirq.RouteCQC(device_graph) | ||
| # Directed graphs should now be accepted | ||
| router = cirq.RouteCQC(device_graph) | ||
| # Test that we can route a simple circuit on a directed graph | ||
| q = cirq.LineQubit.range(3) | ||
| circuit = cirq.Circuit(cirq.CNOT(q[0], q[1]), cirq.CNOT(q[1], q[2])) | ||
| hard_coded_mapper = cirq.HardCodedInitialMapper(dict(zip(q, q))) | ||
| routed_circuit = router(circuit, initial_mapper=hard_coded_mapper) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The CNOTs in the |
||
| device.validate_circuit(routed_circuit) | ||
| cirq.testing.assert_same_circuits(routed_circuit, circuit) | ||
|
|
||
|
|
||
This comment was marked as outdated.
Sorry, something went wrong. |
||
| @pytest.mark.parametrize("tag_inserted_swaps", [True, False]) | ||
| def test_directed_device_swap_decomposition(tag_inserted_swaps: bool) -> None: | ||
| # Create a simple directed graph: q0 -> q1 (one-way only) | ||
| device_graph = nx.DiGraph() | ||
| q = cirq.LineQubit.range(2) | ||
| device_graph.add_edge(q[0], q[1]) | ||
|
|
||
| router = cirq.RouteCQC(device_graph) | ||
|
|
||
| # A circuit that requires a SWAP to execute (qubits need to be swapped) | ||
| circuit = cirq.Circuit(cirq.CNOT(q[1], q[0])) # Reverse direction not available | ||
|
|
||
| hard_coded_mapper = cirq.HardCodedInitialMapper({q[0]: q[0], q[1]: q[1]}) | ||
| routed_circuit = router( | ||
| circuit, initial_mapper=hard_coded_mapper, tag_inserted_swaps=tag_inserted_swaps | ||
| ) | ||
|
|
||
| # Expected: Hadamard-based SWAP decomposition followed by CNOT | ||
| # SWAP decomposition for unidirectional edge: CNOT-H⊗H-CNOT-H⊗H-CNOT | ||
| t = (cirq.RoutingSwapTag(),) if tag_inserted_swaps else () | ||
| expected = cirq.Circuit( | ||
| cirq.CNOT(q[0], q[1]).with_tags(*t), | ||
| cirq.H(q[0]).with_tags(*t), | ||
| cirq.H(q[1]).with_tags(*t), | ||
| cirq.CNOT(q[0], q[1]).with_tags(*t), | ||
| cirq.H(q[0]).with_tags(*t), | ||
| cirq.H(q[1]).with_tags(*t), | ||
| cirq.CNOT(q[0], q[1]).with_tags(*t), | ||
| cirq.CNOT(q[0], q[1]), # The original CNOT after swap | ||
| ) | ||
| cirq.testing.assert_same_circuits(routed_circuit, expected) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize( | ||
|
|
@@ -104,7 +145,7 @@ def test_circuit_with_measurement_gates() -> None: | |
| device_graph = device.metadata.nx_graph | ||
| q = cirq.LineQubit.range(3) | ||
| circuit = cirq.Circuit(cirq.MeasurementGate(2).on(q[0], q[2]), cirq.MeasurementGate(3).on(*q)) | ||
| hard_coded_mapper = cirq.HardCodedInitialMapper({q[i]: q[i] for i in range(3)}) | ||
| hard_coded_mapper = cirq.HardCodedInitialMapper(dict(zip(q, q))) | ||
| router = cirq.RouteCQC(device_graph) | ||
| routed_circuit = router(circuit, initial_mapper=hard_coded_mapper) | ||
| cirq.testing.assert_same_circuits(routed_circuit, circuit) | ||
|
|
@@ -115,7 +156,7 @@ def test_circuit_with_two_qubit_intermediate_measurement_gate() -> None: | |
| device_graph = device.metadata.nx_graph | ||
| router = cirq.RouteCQC(device_graph) | ||
| qs = cirq.LineQubit.range(2) | ||
| hard_coded_mapper = cirq.HardCodedInitialMapper({qs[i]: qs[i] for i in range(2)}) | ||
| hard_coded_mapper = cirq.HardCodedInitialMapper(dict(zip(qs, qs))) | ||
| circuit = cirq.Circuit([cirq.Moment(cirq.measure(qs)), cirq.Moment(cirq.H.on_each(qs))]) | ||
| routed_circuit = router( | ||
| circuit, initial_mapper=hard_coded_mapper, context=cirq.TransformerContext(deep=True) | ||
|
|
@@ -128,7 +169,7 @@ def test_circuit_with_multi_qubit_intermediate_measurement_gate_and_with_default | |
| device_graph = device.metadata.nx_graph | ||
| router = cirq.RouteCQC(device_graph) | ||
| qs = cirq.LineQubit.range(3) | ||
| hard_coded_mapper = cirq.HardCodedInitialMapper({qs[i]: qs[i] for i in range(3)}) | ||
| hard_coded_mapper = cirq.HardCodedInitialMapper(dict(zip(qs, qs))) | ||
| circuit = cirq.Circuit([cirq.Moment(cirq.measure(qs)), cirq.Moment(cirq.H.on_each(qs))]) | ||
| routed_circuit = router( | ||
| circuit, initial_mapper=hard_coded_mapper, context=cirq.TransformerContext(deep=True) | ||
|
|
@@ -142,7 +183,7 @@ def test_circuit_with_multi_qubit_intermediate_measurement_gate_with_custom_key( | |
| device_graph = device.metadata.nx_graph | ||
| router = cirq.RouteCQC(device_graph) | ||
| qs = cirq.LineQubit.range(3) | ||
| hard_coded_mapper = cirq.HardCodedInitialMapper({qs[i]: qs[i] for i in range(3)}) | ||
| hard_coded_mapper = cirq.HardCodedInitialMapper(dict(zip(qs, qs))) | ||
| circuit = cirq.Circuit( | ||
| [cirq.Moment(cirq.measure(qs, key="test")), cirq.Moment(cirq.H.on_each(qs))] | ||
| ) | ||
|
|
@@ -163,7 +204,7 @@ def test_circuit_with_non_unitary_and_global_phase() -> None: | |
| cirq.Moment(cirq.depolarize(0.1, 2).on(q[0], q[2]), cirq.depolarize(0.1).on(q[1])), | ||
| ] | ||
| ) | ||
| hard_coded_mapper = cirq.HardCodedInitialMapper({q[i]: q[i] for i in range(3)}) | ||
| hard_coded_mapper = cirq.HardCodedInitialMapper(dict(zip(q, q))) | ||
| router = cirq.RouteCQC(device_graph) | ||
| routed_circuit = router(circuit, initial_mapper=hard_coded_mapper) | ||
| expected = cirq.Circuit( | ||
|
|
@@ -191,7 +232,7 @@ def test_circuit_with_tagged_ops() -> None: | |
| cirq.Moment(cirq.X(q[0]).with_tags("u")), | ||
| ] | ||
| ) | ||
| hard_coded_mapper = cirq.HardCodedInitialMapper({q[i]: q[i] for i in range(3)}) | ||
| hard_coded_mapper = cirq.HardCodedInitialMapper(dict(zip(q, q))) | ||
| router = cirq.RouteCQC(device_graph) | ||
| routed_circuit = router(circuit, initial_mapper=hard_coded_mapper) | ||
| expected = cirq.Circuit( | ||
|
|
@@ -237,6 +278,97 @@ def test_empty_circuit() -> None: | |
| ) | ||
|
|
||
|
|
||
| def test_directed_device_with_tag_inserted_swaps() -> None: | ||
| # Use a directed ring device | ||
| device = cirq.testing.construct_ring_device(10, directed=True) | ||
| device_graph = device.metadata.nx_graph | ||
| router = cirq.RouteCQC(device_graph) | ||
|
|
||
| q = cirq.LineQubit.range(3) | ||
| # Force routing with non-adjacent qubits to trigger swaps | ||
| circuit = cirq.Circuit(cirq.CNOT(q[0], q[2])) | ||
|
|
||
| hard_coded_mapper = cirq.HardCodedInitialMapper(dict(zip(q, q))) | ||
| routed_circuit = router(circuit, initial_mapper=hard_coded_mapper, tag_inserted_swaps=True) | ||
|
|
||
| # Verify that operations with RoutingSwapTag exist | ||
| tagged_ops = [op for op in routed_circuit.all_operations() if cirq.RoutingSwapTag() in op.tags] | ||
| assert tagged_ops | ||
|
|
||
| # Verify presence of gates from the decomposed swap | ||
| tagged_gates = {op.gate for op in tagged_ops} | ||
| assert tagged_gates == {cirq.CNOT, cirq.H} | ||
|
|
||
|
|
||
| def test_directed_device_reverse_only_edge() -> None: | ||
| # Create a mixed graph with both bidirectional and reverse-only edges | ||
| device_graph = nx.DiGraph() | ||
| q = cirq.LineQubit.range(4) | ||
| # Create a path: 0<->1->2<-3 (1->2 is forward-only, 3->2 is reverse-only) | ||
| device_graph.add_edges_from( | ||
| [ | ||
| (q[0], q[1]), | ||
| (q[1], q[0]), # bidirectional | ||
| (q[1], q[2]), # forward-only | ||
| (q[3], q[2]), # reverse-only (no 2->3 edge) | ||
| ] | ||
| ) | ||
|
|
||
| router = cirq.RouteCQC(device_graph) | ||
|
|
||
| # Create a circuit that requires a swap on the reverse-only edge | ||
| # Map qubits so we need to swap on edge 3<-2 | ||
| circuit = cirq.Circuit(cirq.CNOT(q[0], q[1]), cirq.CNOT(q[2], q[3]), cirq.CNOT(q[1], q[0])) | ||
|
|
||
| hard_coded_mapper = cirq.HardCodedInitialMapper(dict(zip(q, q))) | ||
| routed_circuit, initial_map, swap_map = router.route_circuit( | ||
| circuit, initial_mapper=hard_coded_mapper, tag_inserted_swaps=True | ||
| ) | ||
|
|
||
| # Verify that operations with RoutingSwapTag exist | ||
| tagged_ops = [op for op in routed_circuit.all_operations() if cirq.RoutingSwapTag() in op.tags] | ||
| assert tagged_ops | ||
|
|
||
| # Verify presence of gates from the decomposed swap | ||
| tagged_gates = {op.gate for op in tagged_ops} | ||
| assert tagged_gates == {cirq.CNOT, cirq.H} | ||
|
|
||
| # Verify CNOTs on bidirectional edge were preserved and the CNOT on the reverse edge flipped | ||
| untagged_ops = [op for op in routed_circuit.all_operations() if not op.tags] | ||
| assert untagged_ops == [cirq.CNOT(q[0], q[1]), cirq.CNOT(q[3], q[2]), cirq.CNOT(q[1], q[0])] | ||
|
|
||
| # Verify the routed circuit is mathematically equivalent to the original | ||
| cirq.testing.assert_circuits_have_same_unitary_given_final_permutation( | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd suggest to move this check to |
||
| routed_circuit, circuit, swap_map | ||
| ) | ||
|
|
||
|
|
||
| def test_directed_device_bidirectional_swap_preserved() -> None: | ||
| # Create a graph: 0 <-> 1 -> 2 | ||
| device_graph = nx.DiGraph() | ||
| q = cirq.LineQubit.range(3) | ||
| device_graph.add_edges_from([(q[0], q[1]), (q[1], q[0]), (q[1], q[2])]) | ||
|
|
||
| router = cirq.RouteCQC(device_graph) | ||
|
|
||
| # Circuit needing routing: CNOT(0, 2). Shortest path is 0->1->2. | ||
| # Routing should swap 0 and 1 to bring them adjacent. | ||
| circuit = cirq.Circuit(cirq.CNOT(q[0], q[2])) | ||
|
|
||
| hard_coded_mapper = cirq.HardCodedInitialMapper(dict(zip(q, q))) | ||
| routed_circuit = router(circuit, initial_mapper=hard_coded_mapper, tag_inserted_swaps=True) | ||
|
|
||
| # Check for preserved SWAP(0, 1) | ||
| # Identify SWAPs that are kept as SWAP gates (bidirectional) and tagged | ||
| preserved_swaps = [ | ||
| op | ||
| for op in routed_circuit.all_operations() | ||
| if isinstance(op.gate, cirq.SwapPowGate) and cirq.RoutingSwapTag() in op.tags | ||
| ] | ||
|
|
||
| assert preserved_swaps == [cirq.SWAP(q[0], q[1]).with_tags(cirq.RoutingSwapTag())] | ||
|
|
||
|
|
||
| def test_repr() -> None: | ||
| device = cirq.testing.construct_ring_device(10) | ||
| device_graph = device.metadata.nx_graph | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.