Skip to content

Commit ec2e772

Browse files
authored
feat: Allow passing PassManager (#231)
1 parent 3af5c86 commit ec2e772

File tree

5 files changed

+114
-36
lines changed

5 files changed

+114
-36
lines changed

qiskit_braket_provider/providers/adapter.py

Lines changed: 45 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from qiskit.circuit import Instruction as QiskitInstruction
1313
from qiskit.circuit.library import get_standard_gate_name_mapping
1414
from qiskit.circuit.parametervector import ParameterVectorElement
15-
from qiskit.transpiler import Target
15+
from qiskit.transpiler import PassManager, Target
1616
from qiskit_ionq import add_equivalences, ionq_gates
1717
from sympy import Add, Mul, Pow, Symbol
1818

@@ -499,6 +499,8 @@ def to_braket(
499499
target: Target | None = None,
500500
qubit_labels: Sequence[int] | None = None,
501501
optimization_level: int = 0,
502+
callback: Callable | None = None,
503+
pass_manager: PassManager | None = None,
502504
) -> Circuit:
503505
"""Return a Braket quantum circuit from a Qiskit quantum circuit.
504506
@@ -522,6 +524,10 @@ def to_braket(
522524
assumed to be contiguous.
523525
optimization_level (int): The optimization level to pass to `qiskit.transpile`.
524526
Default: 0 (no optimization).
527+
callback (Callable | None): A callback function that will be called after each transpiler
528+
pass execution. Default: `None`.
529+
pass_manager (PassManager): `PassManager` to transpile the circuit; will raise an error if
530+
used in conjunction with a target, basis gates, or connectivity. Default: `None`.
525531
526532
Returns:
527533
Circuit: Braket circuit
@@ -530,30 +536,41 @@ def to_braket(
530536
circuit = to_qiskit(circuit)
531537
if not isinstance(circuit, QuantumCircuit):
532538
raise TypeError(f"Expected a QuantumCircuit, got {type(circuit)} instead.")
533-
if (basis_gates or connectivity) and target:
534-
raise ValueError("Basis gates and connectivity cannot be specified alongside target.")
535-
536-
# If basis_gates is not None, then target remains empty
537-
target = target if basis_gates or target else _create_default_target(circuit)
538-
needs_transpilation = (
539-
target
540-
or connectivity
541-
or (basis_gates and not {gate.name for gate, _, _ in circuit.data}.issubset(basis_gates))
542-
)
543-
if not verbatim and needs_transpilation:
544-
circuit = transpile(
545-
circuit,
546-
basis_gates=basis_gates,
547-
coupling_map=connectivity,
548-
optimization_level=optimization_level,
549-
target=target,
539+
loose_constraints = basis_gates or connectivity
540+
if pass_manager and (target or loose_constraints):
541+
raise ValueError(
542+
"Cannot specify target, basis gates, or connectivity alongside pass manager"
550543
)
544+
if loose_constraints and target:
545+
raise ValueError("Cannot specify basis gates or connectivity alongside target.")
546+
547+
if pass_manager:
548+
circuit = pass_manager.run(circuit, callback=callback)
549+
elif not verbatim:
550+
# If basis_gates is not None, then target remains empty
551+
target = target if basis_gates or target else _default_target(circuit)
552+
if (
553+
target
554+
or connectivity
555+
or (
556+
basis_gates and not {gate.name for gate, _, _ in circuit.data}.issubset(basis_gates)
557+
)
558+
):
559+
circuit = transpile(
560+
circuit,
561+
basis_gates=basis_gates,
562+
coupling_map=connectivity,
563+
optimization_level=optimization_level,
564+
target=target,
565+
callback=callback,
566+
)
567+
551568
# Verify that ParameterVector would not collide with scalar variables after renaming.
552569
_validate_name_conflicts(circuit.parameters)
553570
# Handle qiskit to braket conversion
554571
measured_qubits: dict[int, int] = {}
555572
braket_circuit = Circuit()
556-
qubit_labels = qubit_labels or sorted(circuit.find_bit(q).index for q in circuit.qubits)
573+
qubit_labels = qubit_labels or _default_qubit_labels(circuit)
557574
for circuit_instruction in circuit.data:
558575
operation = circuit_instruction.operation
559576
qubits = circuit_instruction.qubits
@@ -630,7 +647,7 @@ def to_braket(
630647

631648
# QPU targets will have qubits/pairs specified for each instruction;
632649
# Targets whose values consist solely of {None: None} are either simulator or default targets
633-
if verbatim or (target and any(v != {None: None} for v in target.values())):
650+
if verbatim or (target and any(v != {None: None} for v in target.values())) or pass_manager:
634651
braket_circuit = Circuit(braket_circuit.result_types).add_verbatim_box(
635652
Circuit(braket_circuit.instructions)
636653
)
@@ -641,15 +658,20 @@ def to_braket(
641658
return braket_circuit
642659

643660

644-
def _create_default_target(circuit: QuantumCircuit) -> Target:
661+
def _default_target(circuit: QuantumCircuit) -> Target:
645662
target = Target(num_qubits=circuit.num_qubits)
646663
for braket_name, instruction in _BRAKET_GATE_NAME_TO_QISKIT_GATE.items():
647-
if (name := braket_name.lower()) in _BRAKET_TO_QISKIT_NAMES:
648-
target.add_instruction(instruction, name=_BRAKET_TO_QISKIT_NAMES[name])
664+
if name := _BRAKET_TO_QISKIT_NAMES.get(braket_name.lower()):
665+
target.add_instruction(instruction, name=name)
649666
target.add_instruction(Measure())
650667
return target
651668

652669

670+
def _default_qubit_labels(circuit: QuantumCircuit) -> tuple[int, ...]:
671+
bits = sorted(circuit.find_bit(q).index for q in circuit.qubits)
672+
return tuple(range(max(bits) + 1)) if bits else tuple()
673+
674+
653675
def _create_free_parameters(operation):
654676
params = operation.params if hasattr(operation, "params") else []
655677
for i, param in enumerate(params):

qiskit_braket_provider/providers/braket_backend.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,12 @@
55
import logging
66
import warnings
77
from abc import ABC
8-
from collections.abc import Iterable
8+
from collections.abc import Callable, Iterable
99
from typing import Generic, TypeVar
1010

1111
from qiskit import QuantumCircuit
1212
from qiskit.providers import BackendV2, Options, QubitProperties
13+
from qiskit.transpiler import PassManager, Target
1314

1415
from braket.aws import AwsDevice, AwsDeviceType, AwsQuantumTask
1516
from braket.aws.queue_information import QueueDepthInfo
@@ -343,6 +344,8 @@ def run(
343344
native: bool = False,
344345
*,
345346
optimization_level: int = 0,
347+
callback: Callable | None = None,
348+
pass_manager: PassManager | None = None,
346349
**options,
347350
):
348351
if isinstance(run_input, QuantumCircuit):
@@ -357,13 +360,12 @@ def run(
357360
del options["meas_level"]
358361

359362
# Always use target for simulator
360-
target, basis_gates = (
361-
(self._target, None)
362-
if native or self._device.type == AwsDeviceType.SIMULATOR
363-
else (None, self._gateset)
364-
)
363+
target, basis_gates = self._target_and_basis_gates(native, pass_manager)
365364
braket_circuits = (
366-
[to_braket(circ, verbatim=True, qubit_labels=self._qubit_labels) for circ in circuits]
365+
[
366+
to_braket(circ, verbatim=True, qubit_labels=self._qubit_labels, callback=callback)
367+
for circ in circuits
368+
]
367369
if verbatim
368370
else [
369371
to_braket(
@@ -375,6 +377,8 @@ def run(
375377
native_angle_restrictions(self._device.properties) if native else None
376378
),
377379
optimization_level=optimization_level,
380+
callback=callback,
381+
pass_manager=pass_manager,
378382
)
379383
for circ in circuits
380384
]
@@ -386,6 +390,16 @@ def run(
386390
else self._run_batch(braket_circuits, shots, **options)
387391
)
388392

393+
def _target_and_basis_gates(
394+
self, native: bool, pass_manager: PassManager
395+
) -> tuple[Target | None, set[str] | None]:
396+
if pass_manager:
397+
return None, None
398+
if native or self._device.type == AwsDeviceType.SIMULATOR:
399+
# Always use target for simulator
400+
return self._target, None
401+
return None, self._gateset
402+
389403
def _run_batch(self, braket_circuits: list[Circuit], shots: int, **options):
390404
batch_task = self._device.run_batch(braket_circuits, shots=shots, **options)
391405
tasks: list[AwsQuantumTask] = batch_task.tasks

tests/providers/test_adapter.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import numpy as np
77
import pytest
8-
from qiskit import ClassicalRegister, QuantumCircuit, QuantumRegister
8+
from qiskit import ClassicalRegister, QuantumCircuit, QuantumRegister, generate_preset_pass_manager
99
from qiskit.circuit import Instruction as QiskitInstruction
1010
from qiskit.circuit import Parameter, ParameterVector
1111
from qiskit.circuit.library import GlobalPhaseGate, PauliEvolutionGate
@@ -274,9 +274,28 @@ def test_target_with_loose_constraints(self):
274274
target.add_instruction(qiskit_gates.HGate())
275275

276276
with pytest.raises(ValueError):
277-
to_braket(circuit, basis_gates={"h"}, target=target)
277+
to_braket(circuit, target=target, basis_gates={"h"})
278278
with pytest.raises(ValueError):
279-
to_braket(circuit, connectivity=[[0, 1], [1, 2]], target=target)
279+
to_braket(circuit, target=target, connectivity=[[0, 1], [1, 2]])
280+
281+
def test_pass_manager_with_other_arguments(self):
282+
"""
283+
Tests that to_braket raises a ValueError if pass_manager is supplied
284+
with target or loose constraints.
285+
"""
286+
circuit = QuantumCircuit(1, 1)
287+
circuit.h(0)
288+
289+
target = Target()
290+
target.add_instruction(qiskit_gates.HGate())
291+
pass_manager = generate_preset_pass_manager(2, target=target)
292+
293+
with pytest.raises(ValueError):
294+
to_braket(circuit, pass_manager=pass_manager, target=target)
295+
with pytest.raises(ValueError):
296+
to_braket(circuit, pass_manager=pass_manager, basis_gates={"h"})
297+
with pytest.raises(ValueError):
298+
to_braket(circuit, pass_manager=pass_manager, connectivity=[[0, 1], [1, 2]])
280299

281300
def test_convert_parametric_qiskit_to_braket_circuit(self):
282301
"""Tests to_braket works with parametric circuits."""

tests/providers/test_braket_backend.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import numpy as np
88
from botocore import errorfactory
99
from networkx import DiGraph, complete_graph, from_dict_of_lists, relabel_nodes
10-
from qiskit import QuantumCircuit, transpile
10+
from qiskit import QuantumCircuit, generate_preset_pass_manager, transpile
1111
from qiskit.circuit import Instruction as QiskitInstruction
1212
from qiskit.circuit.library import TwoLocal
1313
from qiskit.circuit.random import random_circuit
@@ -258,6 +258,29 @@ def test_run_multiple_circuits_program_set(self):
258258
)
259259
self.assertEqual(device.run.call_count, 2)
260260

261+
def test_run_with_pass_manager(self):
262+
"""Tests run with pass_manager"""
263+
device = Mock()
264+
device.properties = RIGETTI_MOCK_GATE_MODEL_QPU_CAPABILITIES
265+
device.type = "QPU"
266+
device.topology_graph = None
267+
backend = BraketAwsBackend(device=device)
268+
mock_task_1 = Mock(spec=LocalQuantumTask)
269+
mock_task_1.id = "0"
270+
mock_task_2 = Mock(spec=LocalQuantumTask)
271+
mock_task_2.id = "1"
272+
mock_batch = Mock(spec=AwsQuantumTaskBatch)
273+
mock_batch.tasks = [mock_task_1, mock_task_2]
274+
backend._device.run_batch = Mock(return_value=mock_batch)
275+
circuit = QuantumCircuit(1)
276+
circuit.h(0)
277+
278+
backend.run(circuit, shots=0, pass_manager=generate_preset_pass_manager(2, backend))
279+
native_circuit = Circuit().add_verbatim_box(
280+
Circuit().rz(0, np.pi / 2).rx(0, np.pi / 2).rz(0, np.pi / 2)
281+
)
282+
device.run_batch.assert_called_once_with([native_circuit], shots=0)
283+
261284
def test_run_invalid_run_input(self):
262285
"""Tests run with invalid input to run"""
263286
device = Mock()

tests/providers/test_braket_instructions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from braket.experimental_capabilities import EnableExperimentalCapability
99
from qiskit_braket_provider import to_braket
10-
from qiskit_braket_provider.providers.adapter import _create_default_target
10+
from qiskit_braket_provider.providers.adapter import _default_target
1111
from qiskit_braket_provider.providers.braket_instructions import CCPRx, MeasureFF
1212

1313

@@ -83,7 +83,7 @@ def test_circuit_with_measureff_ccprx(self):
8383
CCPRx(0.5, 0.7, 0), qubits=(Qubit(QuantumRegister(1, "q"), 0),)
8484
)
8585

86-
target = _create_default_target(circuit)
86+
target = _default_target(circuit)
8787
target.add_instruction(
8888
CCPRx(Parameter("angle_1"), Parameter("angle_2"), Parameter("feedback_key"))
8989
)

0 commit comments

Comments
 (0)