Skip to content

Commit 3af5c86

Browse files
authored
fix: Only transpile natively if explicitly asked (#230)
Also added an `optimization_level` flag to `BraketAwsBackend.run` to give users an easy way to optimize their circuit when running.
1 parent 5a52fde commit 3af5c86

File tree

6 files changed

+71
-42
lines changed

6 files changed

+71
-42
lines changed

README.md

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,9 @@
22

33
[![Latest Version](https://img.shields.io/pypi/v/qiskit-braket-provider.svg)](https://pypi.python.org/pypi/qiskit-braket-provider)
44
[![Supported Python Versions](https://img.shields.io/pypi/pyversions/qiskit-braket-provider.svg)](https://pypi.python.org/pypi/qiskit-braket-provider)
5+
[![Qiskit compatibility](https://img.shields.io/badge/Qiskit%20compatibility-%3E%3D0.34.2-blueviolet?logo=Qiskit)](https://github.com/Qiskit/qiskit/releases)
56
[![Build status](https://github.com/qiskit-community/qiskit-braket-provider/actions/workflows/test_latest_versions.yml/badge.svg?branch=main)](https://github.com/qiskit-community/qiskit-braket-provider/actions/workflows/test_latest_versions.yml)
67

7-
[![Qiskit compatibility](https://img.shields.io/badge/Qiskit%20compatibility-%3E%3D0.34,%20%3C2.0-blueviolet?logo=Qiskit)](https://github.com/Qiskit/qiskit/releases)
8-
9-
108
Qiskit-Braket provider to execute Qiskit programs on AWS quantum computing hardware devices through Amazon Braket.
119

1210
### Table of Contents

qiskit_braket_provider/providers/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,18 @@
66
.. currentmodule:: qiskit_braket_provider.providers
77
88
Provider module contains classes and functions to connect
9-
AWS Braket abstraction to Qiskit architecture.
9+
Amazon Braket abstraction to Qiskit architecture.
1010
1111
Provider classes and functions
1212
==============================
1313
1414
.. autosummary::
1515
:toctree: ../stubs/
1616
17-
AWSBraketBackend
17+
BraketAwsBackend
1818
BraketLocalBackend
19-
AWSBraketProvider
20-
AmazonBraketTask
19+
BraketProvider
20+
BraketQuantumTask
2121
"""
2222

2323
from .adapter import to_braket as to_braket

qiskit_braket_provider/providers/adapter.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -371,11 +371,10 @@ def aws_device_to_target(device: AwsDevice) -> Target:
371371
return _simulator_target(
372372
f"Target for Amazon Braket simulator: {device.name}", device.properties
373373
)
374-
case _:
375-
raise QiskitBraketException(
376-
"Cannot convert to target. "
377-
f"{device.properties.__class__} device capabilities are not supported."
378-
)
374+
raise QiskitBraketException(
375+
"Cannot convert to target. "
376+
f"{device.properties.__class__} device capabilities are not supported."
377+
)
379378

380379

381380
def _simulator_target(description: str, properties: GateModelSimulatorDeviceCapabilities):
@@ -498,8 +497,8 @@ def to_braket(
498497
angle_restrictions: dict[str, dict[int, set[float] | tuple[float, float]]] | None = None,
499498
*,
500499
target: Target | None = None,
501-
braket_qubits: Sequence[int] | None = None,
502-
optimization_level: int | None = 0,
500+
qubit_labels: Sequence[int] | None = None,
501+
optimization_level: int = 0,
503502
) -> Circuit:
504503
"""Return a Braket quantum circuit from a Qiskit quantum circuit.
505504
@@ -518,11 +517,11 @@ def to_braket(
518517
validate numeric parameters. Default: `None`.
519518
target (Target | None): A backend transpiler target. Can only be provided
520519
if basis_gates is `None`. Default: `None`.
521-
braket_qubits (Sequence[int] | None): A list of (not necessarily contiguous) indices of
520+
qubit_labels (Sequence[int] | None): A list of (not necessarily contiguous) indices of
522521
qubits in the underlying Amazon Braket device. If not supplied, then the indices are
523522
assumed to be contiguous.
524-
optimization_level (int | None): The optimization level to pass to `qiskit.transpile`.
525-
Default: None.
523+
optimization_level (int): The optimization level to pass to `qiskit.transpile`.
524+
Default: 0 (no optimization).
526525
527526
Returns:
528527
Circuit: Braket circuit
@@ -554,7 +553,7 @@ def to_braket(
554553
# Handle qiskit to braket conversion
555554
measured_qubits: dict[int, int] = {}
556555
braket_circuit = Circuit()
557-
braket_qubits = braket_qubits or sorted(circuit.find_bit(q).index for q in circuit.qubits)
556+
qubit_labels = qubit_labels or sorted(circuit.find_bit(q).index for q in circuit.qubits)
558557
for circuit_instruction in circuit.data:
559558
operation = circuit_instruction.operation
560559
qubits = circuit_instruction.qubits
@@ -568,7 +567,7 @@ def to_braket(
568567
match gate_name := operation.name:
569568
case "measure":
570569
qubit = qubits[0] # qubit count = 1 for measure
571-
qubit_index = braket_qubits[circuit.find_bit(qubit).index]
570+
qubit_index = qubit_labels[circuit.find_bit(qubit).index]
572571
if qubit_index in measured_qubits.values():
573572
raise ValueError(f"Cannot measure previously measured qubit {qubit_index}")
574573
clbit = circuit.find_bit(circuit_instruction.clbits[0]).index
@@ -581,7 +580,7 @@ def to_braket(
581580
)
582581
case "unitary" | "kraus":
583582
params = _create_free_parameters(operation)
584-
qubit_indices = [braket_qubits[circuit.find_bit(qubit).index] for qubit in qubits][
583+
qubit_indices = [qubit_labels[circuit.find_bit(qubit).index] for qubit in qubits][
585584
::-1
586585
] # reversal for little to big endian notation
587586

@@ -597,7 +596,7 @@ def to_braket(
597596
):
598597
raise ValueError("Negative control is not supported")
599598
# Getting the index from the bit mapping
600-
qubit_indices = [braket_qubits[circuit.find_bit(qubit).index] for qubit in qubits]
599+
qubit_indices = [qubit_labels[circuit.find_bit(qubit).index] for qubit in qubits]
601600
if intersection := set(measured_qubits.values()).intersection(qubit_indices):
602601
raise ValueError(
603602
f"Cannot apply operation {gate_name} to measured qubits {intersection}"
@@ -795,7 +794,7 @@ def _sympy_to_qiskit(expr: Mul | Add | Symbol | Pow) -> ParameterExpression | Pa
795794
return Parameter(expr.name)
796795
case Pow():
797796
return _sympy_to_qiskit(expr.args[0]) ** int(expr.args[1])
798-
case obj if hasattr(obj, "is_number") and obj.is_real:
797+
case obj if getattr(obj, "is_real", False):
799798
return float(obj)
800799
raise TypeError(f"unrecognized parameter type in conversion: {type(expr)}")
801800

qiskit_braket_provider/providers/braket_backend.py

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ def run(self, run_input: QuantumCircuit | list[QuantumCircuit], **options) -> Br
181181
class BraketAwsBackend(BraketBackend[AwsDevice]):
182182
"""BraketAwsBackend."""
183183

184-
def __init__( # pylint: disable=too-many-positional-arguments
184+
def __init__(
185185
self,
186186
arn: str | None = None,
187187
provider=None,
@@ -229,8 +229,10 @@ def __init__( # pylint: disable=too-many-positional-arguments
229229
f"QiskitBraketProvider/{version.__version__}"
230230
)
231231
self._target = aws_device_to_target(device=self._device)
232-
self._braket_qubits = (
233-
sorted(self._device.topology_graph.nodes) if self._device.topology_graph else None
232+
self._qubit_labels = (
233+
tuple(sorted(self._device.topology_graph.nodes))
234+
if self._device.topology_graph
235+
else None
234236
)
235237
self._gateset = self.get_gateset()
236238
self._supports_program_sets = (
@@ -262,12 +264,21 @@ def target(self):
262264
def max_circuits(self):
263265
return None
264266

267+
@property
268+
def qubit_labels(self) -> tuple[int, ...] | None:
269+
"""
270+
tuple[int, ...] | None: The qubit labels of the underlying device, in ascending order.
271+
272+
Unlike the qubits in the target, these labels are not necessarily contiguous.
273+
"""
274+
return self._qubit_labels
275+
265276
@classmethod
266277
def _default_options(cls):
267278
return Options()
268279

269280
def qubit_properties(self, qubit: int | list[int]) -> QubitProperties | list[QubitProperties]:
270-
# TODO: fetch information from device.properties.provider # pylint: disable=fixme
281+
# TODO: fetch information from device.properties.provider
271282
raise NotImplementedError
272283

273284
def queue_depth(self) -> QueueDepthInfo:
@@ -325,9 +336,15 @@ def acquire_channel(self, qubit: int):
325336
def control_channel(self, qubits: Iterable[int]):
326337
raise NotImplementedError(f"Control channel is not supported by {self.name}.")
327338

328-
def run(self, run_input, verbatim: bool = False, native: bool | None = None, **options):
329-
# Defaults to native transpilation if the underlying device is a QPU
330-
native = native if native is not None else self._device.type == AwsDeviceType.QPU
339+
def run(
340+
self,
341+
run_input: QuantumCircuit | list[QuantumCircuit],
342+
verbatim: bool = False,
343+
native: bool = False,
344+
*,
345+
optimization_level: int = 0,
346+
**options,
347+
):
331348
if isinstance(run_input, QuantumCircuit):
332349
circuits = [run_input]
333350
elif isinstance(run_input, list):
@@ -339,17 +356,25 @@ def run(self, run_input, verbatim: bool = False, native: bool | None = None, **o
339356
self._validate_meas_level(options["meas_level"])
340357
del options["meas_level"]
341358

359+
# Always use target for simulator
360+
target, basis_gates = (
361+
(self._target, None)
362+
if native or self._device.type == AwsDeviceType.SIMULATOR
363+
else (None, self._gateset)
364+
)
342365
braket_circuits = (
343-
[to_braket(circ, verbatim=True, braket_qubits=self._braket_qubits) for circ in circuits]
366+
[to_braket(circ, verbatim=True, qubit_labels=self._qubit_labels) for circ in circuits]
344367
if verbatim
345368
else [
346369
to_braket(
347370
circ,
348-
target=self._target,
349-
braket_qubits=self._braket_qubits,
371+
target=target,
372+
basis_gates=basis_gates,
373+
qubit_labels=self._qubit_labels,
350374
angle_restrictions=(
351375
native_angle_restrictions(self._device.properties) if native else None
352376
),
377+
optimization_level=optimization_level,
353378
)
354379
for circ in circuits
355380
]
@@ -376,7 +401,7 @@ def __init_subclass__(cls, **kwargs):
376401
warnings.warn(f"{cls.__name__} is deprecated.", DeprecationWarning, stacklevel=2)
377402
super().__init_subclass__(**kwargs)
378403

379-
def __init__( # pylint: disable=too-many-positional-arguments
404+
def __init__(
380405
self,
381406
device: AwsDevice,
382407
provider=None,

tests/providers/test_braket_backend.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Tests for AWS Braket backends."""
22

33
import unittest
4-
from unittest import TestCase, skip
4+
from unittest import TestCase
55
from unittest.mock import Mock, patch
66

77
import numpy as np
@@ -216,12 +216,15 @@ def test_run_multiple_circuits(self):
216216
circuit.h(0)
217217

218218
backend.run([circuit, circuit], shots=0, meas_level=2)
219+
braket_circuit = Circuit().h(0)
220+
device.run_batch.assert_called_once_with([braket_circuit, braket_circuit], shots=0)
219221

220-
braket_circuit = Circuit().add_verbatim_box(
222+
backend.run([circuit, circuit], shots=0, native=True)
223+
native_circuit = Circuit().add_verbatim_box(
221224
Circuit().rz(0, np.pi / 2).rx(0, np.pi / 2).rz(0, np.pi / 2)
222225
)
223-
device.run_batch.assert_called_with([braket_circuit, braket_circuit], shots=0)
224-
device.run_batch.assert_called_once()
226+
device.run_batch.assert_called_with([native_circuit, native_circuit], shots=0)
227+
self.assertEqual(device.run_batch.call_count, 2)
225228

226229
def test_run_multiple_circuits_program_set(self):
227230
"""Tests run with multiple circuits"""
@@ -241,14 +244,19 @@ def test_run_multiple_circuits_program_set(self):
241244
circuit.h(0)
242245

243246
backend.run([circuit, circuit], shots=5, meas_level=2)
247+
braket_circuit = Circuit().h(0)
248+
device.run.assert_called_once_with(
249+
ProgramSet([braket_circuit, braket_circuit], shots_per_executable=5)
250+
)
244251

245-
braket_circuit = Circuit().add_verbatim_box(
252+
backend.run([circuit, circuit], shots=5, native=True)
253+
native_circuit = Circuit().add_verbatim_box(
246254
Circuit().rz(0, np.pi / 2).rx(0, np.pi / 2).rz(0, np.pi / 2)
247255
)
248256
device.run.assert_called_with(
249-
ProgramSet([braket_circuit, braket_circuit], shots_per_executable=5)
257+
ProgramSet([native_circuit, native_circuit], shots_per_executable=5)
250258
)
251-
device.run.assert_called_once()
259+
self.assertEqual(device.run.call_count, 2)
252260

253261
def test_run_invalid_run_input(self):
254262
"""Tests run with invalid input to run"""
@@ -303,7 +311,6 @@ def test_meas_level_1(self):
303311
with self.assertRaises(exception.QiskitBraketException):
304312
backend.run(circuit, shots=10, meas_level=1)
305313

306-
@skip(reason="qiskit-algorithms doesn't support V2 primitives yet")
307314
def test_vqe(self):
308315
"""Tests VQE."""
309316
local_simulator = BraketLocalBackend(name="default")

tests/providers/test_braket_provider.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ def test_discontinous_qubit_indices_qiskit_transpilation(self, mock_get_devices)
186186
circ.cx(2, 3)
187187

188188
self.assertEqual(
189-
to_braket(circ, target=backend.target, braket_qubits=backend._braket_qubits).qubits,
189+
to_braket(circ, target=backend.target, qubit_labels=backend.qubit_labels).qubits,
190190
{0, 1, 2, 7},
191191
)
192192

0 commit comments

Comments
 (0)