Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions cirq-core/cirq/circuits/circuit_operation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1253,22 +1253,27 @@ def test_repeat_until_protocols() -> None:
# Ensure the _repeat_until has been mapped, the measurement has been mapped to the same key,
# and the control keys of the subcircuit is empty (because the control key of the condition is
# bound to the measurement).
assert scoped._mapped_repeat_until is not None
assert scoped._mapped_repeat_until.keys == (cirq.MeasurementKey('a', ('0',)),)
assert cirq.measurement_key_objs(scoped) == {cirq.MeasurementKey('a', ('0',))}
assert not cirq.control_keys(scoped)
mapped = cirq.with_measurement_key_mapping(scoped, {'a': 'b'})
assert mapped._mapped_repeat_until is not None
assert mapped._mapped_repeat_until.keys == (cirq.MeasurementKey('b', ('0',)),)
assert cirq.measurement_key_objs(mapped) == {cirq.MeasurementKey('b', ('0',))}
assert not cirq.control_keys(mapped)
prefixed = cirq.with_key_path_prefix(mapped, ('1',))
assert prefixed._mapped_repeat_until is not None
assert prefixed._mapped_repeat_until.keys == (cirq.MeasurementKey('b', ('1', '0')),)
assert cirq.measurement_key_objs(prefixed) == {cirq.MeasurementKey('b', ('1', '0'))}
assert not cirq.control_keys(prefixed)
setpath = cirq.with_key_path(prefixed, ('2',))
assert setpath._mapped_repeat_until is not None
assert setpath._mapped_repeat_until.keys == (cirq.MeasurementKey('b', ('2',)),)
assert cirq.measurement_key_objs(setpath) == {cirq.MeasurementKey('b', ('2',))}
assert not cirq.control_keys(setpath)
resolved = cirq.resolve_parameters(setpath, {'p': 1})
assert resolved._mapped_repeat_until is not None
assert resolved._mapped_repeat_until.keys == (cirq.MeasurementKey('b', ('2',)),)
assert cirq.measurement_key_objs(resolved) == {cirq.MeasurementKey('b', ('2',))}
assert not cirq.control_keys(resolved)
Expand Down
10 changes: 3 additions & 7 deletions cirq-core/cirq/ops/classically_controlled_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from __future__ import annotations

from typing import AbstractSet, Any, cast, Mapping, Sequence, TYPE_CHECKING
from typing import AbstractSet, Any, Mapping, Sequence, TYPE_CHECKING

import sympy

Expand Down Expand Up @@ -207,17 +207,13 @@ def _with_measurement_key_mapping_(
conditions = [protocols.with_measurement_key_mapping(c, key_map) for c in self._conditions]
sub_operation = protocols.with_measurement_key_mapping(self._sub_operation, key_map)
sub_operation = self._sub_operation if sub_operation is NotImplemented else sub_operation
return cast(
ClassicallyControlledOperation, sub_operation.with_classical_controls(*conditions)
)
return sub_operation.with_classical_controls(*conditions)

def _with_key_path_prefix_(self, prefix: tuple[str, ...]) -> ClassicallyControlledOperation:
conditions = [protocols.with_key_path_prefix(c, prefix) for c in self._conditions]
sub_operation = protocols.with_key_path_prefix(self._sub_operation, prefix)
sub_operation = self._sub_operation if sub_operation is NotImplemented else sub_operation
return cast(
ClassicallyControlledOperation, sub_operation.with_classical_controls(*conditions)
)
return sub_operation.with_classical_controls(*conditions)

def _with_rescoped_keys_(
self, path: tuple[str, ...], bindable_keys: frozenset[cirq.MeasurementKey]
Expand Down
13 changes: 7 additions & 6 deletions cirq-core/cirq/ops/measure_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@

from __future__ import annotations

from typing import Callable, Iterable, overload, TYPE_CHECKING
from typing import Callable, cast, Iterable, overload, TYPE_CHECKING

import numpy as np

from cirq import protocols
from cirq.ops import pauli_string, raw_types
from cirq.ops import gate_operation, pauli_string, raw_types
from cirq.ops.measurement_gate import MeasurementGate
from cirq.ops.pauli_measurement_gate import PauliMeasurementGate

Expand Down Expand Up @@ -96,7 +96,7 @@ def measure(
key: str | cirq.MeasurementKey | None = None,
invert_mask: tuple[bool, ...] = (),
confusion_map: dict[tuple[int, ...], np.ndarray] | None = None,
) -> raw_types.Operation:
) -> gate_operation.GateOperation:
pass


Expand All @@ -107,7 +107,7 @@ def measure(
key: str | cirq.MeasurementKey | None = None,
invert_mask: tuple[bool, ...] = (),
confusion_map: dict[tuple[int, ...], np.ndarray] | None = None,
) -> raw_types.Operation:
) -> gate_operation.GateOperation:
pass


Expand All @@ -116,7 +116,7 @@ def measure(
key: str | cirq.MeasurementKey | None = None,
invert_mask: tuple[bool, ...] = (),
confusion_map: dict[tuple[int, ...], np.ndarray] | None = None,
) -> raw_types.Operation:
) -> gate_operation.GateOperation:
"""Returns a single MeasurementGate applied to all the given qubits.

The qubits are measured in the computational basis. This can also be
Expand Down Expand Up @@ -161,7 +161,8 @@ def measure(
if key is None:
key = _default_measurement_key(targets)
qid_shape = protocols.qid_shape(targets)
return MeasurementGate(len(targets), key, invert_mask, qid_shape, confusion_map).on(*targets)
gate = MeasurementGate(len(targets), key, invert_mask, qid_shape, confusion_map)
return cast(gate_operation.GateOperation, gate.on(*targets))


M = measure
Expand Down
4 changes: 2 additions & 2 deletions cirq-core/cirq/ops/pauli_string_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2101,7 +2101,7 @@ def test_pauli_ops_identity_gate_operation(gate1: cirq.Pauli, gate2: cirq.Pauli)
assert np.array_equal(subtraction.matrix(), unitary1 - unitary2)


def test_pauli_gate_multiplication_with_power():
def test_pauli_gate_multiplication_with_power() -> None:
q = cirq.LineQubit(0)

# Test all Pauli gates (X, Y, Z)
Expand All @@ -2124,7 +2124,7 @@ def test_pauli_gate_multiplication_with_power():
assert gate**5 * gate**0 == gate**5


def test_try_interpret_as_pauli_string():
def test_try_interpret_as_pauli_string() -> None:
from cirq.ops.pauli_string import _try_interpret_as_pauli_string

q = cirq.LineQubit(0)
Expand Down
21 changes: 19 additions & 2 deletions cirq-core/cirq/ops/raw_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
Hashable,
Iterable,
Mapping,
overload,
Sequence,
TYPE_CHECKING,
)
Expand Down Expand Up @@ -687,9 +688,17 @@ def classical_controls(self) -> frozenset[cirq.Condition]:
"""The classical controls gating this operation."""
return frozenset()

@overload
def with_classical_controls(self) -> cirq.Operation:
pass

@overload
def with_classical_controls(
self, *conditions: str | cirq.MeasurementKey | cirq.Condition | sympy.Expr
) -> cirq.Operation:
) -> cirq.ClassicallyControlledOperation:
pass

def with_classical_controls(self, *conditions):
"""Returns a classically controlled version of this operation.

An operation that is classically controlled is executed iff all
Expand Down Expand Up @@ -957,9 +966,17 @@ def without_classical_controls(self) -> cirq.Operation:
new_sub_operation = self.sub_operation.without_classical_controls()
return self if new_sub_operation is self.sub_operation else new_sub_operation

@overload
def with_classical_controls(self) -> cirq.Operation:
pass

@overload
def with_classical_controls(
self, *conditions: str | cirq.MeasurementKey | cirq.Condition | sympy.Expr
) -> cirq.Operation:
) -> cirq.ClassicallyControlledOperation:
pass

def with_classical_controls(self, *conditions):
if not conditions:
return self
return self.sub_operation.with_classical_controls(*conditions)
Expand Down
20 changes: 11 additions & 9 deletions cirq-core/cirq/protocols/has_stabilizer_effect_protocol_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

from __future__ import annotations

from typing import Any

import numpy as np

import cirq
Expand Down Expand Up @@ -50,40 +52,40 @@ def __init__(self, q: cirq.Qid = cirq.LineQubit(0)):
self.q = q

@property
def qubits(self):
def qubits(self) -> tuple[cirq.Qid, ...]:
return (self.q,)

def with_qubits(self, *new_qubits): # pragma: no cover
def with_qubits(self, *new_qubits) -> cirq.Operation: # pragma: no cover
return self


class NoOp(EmptyOp):
@property
def gate(self):
def gate(self) -> Any:
return No()


class NoOp1(EmptyOp):
@property
def gate(self):
def gate(self) -> Any:
return No1()


class NoOp2(EmptyOp):
@property
def gate(self):
def gate(self) -> Any:
return No2()


class NoOp3(EmptyOp):
@property
def gate(self):
def gate(self) -> Any:
return No3()


class YesOp(EmptyOp):
@property
def gate(self):
def gate(self) -> Any:
return Yes()


Expand All @@ -95,8 +97,8 @@ def _unitary_(self):
return self.unitary

@property
def qubits(self):
return cirq.LineQubit.range(self.unitary.shape[0].bit_length() - 1)
def qubits(self) -> tuple[cirq.Qid, ...]:
return tuple(cirq.LineQubit.range(self.unitary.shape[0].bit_length() - 1))


class GateDecomposes(cirq.Gate):
Expand Down
6 changes: 3 additions & 3 deletions cirq-core/cirq/protocols/has_unitary_protocol_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,8 +218,8 @@ class EmptyOp(cirq.Operation):
"""A trivial operation that will be recognized as `_apply_unitary_`-able."""

@property
def qubits(self):
def qubits(self) -> tuple[cirq.Qid, ...]:
return ()

def with_qubits(self, *new_qubits): # pragma: no cover
return self
def with_qubits(self, *new_qubits) -> cirq.Operation:
raise NotImplementedError()
6 changes: 3 additions & 3 deletions cirq-core/cirq/protocols/json_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self._memo: dict[Any, dict] = {}

def default(self, o):
def default(self, o) -> dict[str, Any] | list[Any] | float | bool:
# Object with custom method?
if hasattr(o, '_json_dict_'):
json_dict = _json_dict_with_cirq_type(o)
Expand Down Expand Up @@ -519,7 +519,7 @@ def read_json(
*,
json_text: str | None = None,
resolvers: Sequence[JsonResolver] | None = None,
):
) -> Any:
"""Read a JSON file that optionally contains cirq objects.

Args:
Expand Down Expand Up @@ -605,7 +605,7 @@ def read_json_gzip(
*,
gzip_raw: bytes | None = None,
resolvers: Sequence[JsonResolver] | None = None,
):
) -> Any:
"""Read a gzipped JSON file that optionally contains cirq objects.

Args:
Expand Down
Loading