diff --git a/cirq-core/cirq/circuits/circuit.py b/cirq-core/cirq/circuits/circuit.py index a4eb0227ee1..6a51d5294a4 100644 --- a/cirq-core/cirq/circuits/circuit.py +++ b/cirq-core/cirq/circuits/circuit.py @@ -988,18 +988,12 @@ def qid_shape( qids = ops.QubitOrder.as_qubit_order(qubit_order).order_for(self.all_qubits()) return protocols.qid_shape(qids) - def all_measurement_key_objs(self) -> frozenset[cirq.MeasurementKey]: - return frozenset( - key for op in self.all_operations() for key in protocols.measurement_key_objs(op) - ) - - def _measurement_key_objs_(self) -> frozenset[cirq.MeasurementKey]: - """Returns the set of all measurement keys in this circuit. + @property + def measurement_keys(self) -> frozenset[cirq.MeasurementKey]: + return frozenset(key for m in self.moments for key in m.measurement_keys) - Returns: frozenset of `cirq.MeasurementKey` objects that are - in this circuit. - """ - return self.all_measurement_key_objs() + def all_measurement_key_objs(self) -> frozenset[cirq.MeasurementKey]: + return self.measurement_keys def all_measurement_key_names(self) -> frozenset[str]: """Returns the set of all measurement key names in this circuit. @@ -1011,9 +1005,6 @@ def all_measurement_key_names(self) -> frozenset[str]: key for op in self.all_operations() for key in protocols.measurement_key_names(op) ) - def _measurement_key_names_(self) -> frozenset[str]: - return self.all_measurement_key_names() - def _with_measurement_key_mapping_(self, key_map: Mapping[str, str]): return self._from_moments( (protocols.with_measurement_key_mapping(moment, key_map) for moment in self.moments), @@ -1038,7 +1029,7 @@ def _with_rescoped_keys_( for moment in self.moments: new_moment = protocols.with_rescoped_keys(moment, path, bindable_keys) moments.append(new_moment) - bindable_keys |= protocols.measurement_key_objs(new_moment) + bindable_keys |= new_moment.measurement_keys return self._from_moments(moments, tags=self.tags) def _qid_shape_(self) -> tuple[int, ...]: @@ -1279,10 +1270,7 @@ def to_text_diagram_drawer( """ qubits = ops.QubitOrder.as_qubit_order(qubit_order).order_for(self.all_qubits()) cbits = tuple( - sorted( - set(key for op in self.all_operations() for key in protocols.control_keys(op)), - key=str, - ) + sorted(set(key for op in self.all_operations() for key in op.control_keys), key=str) ) labels = qubits + cbits label_map = {labels[i]: i for i in range(len(labels))} @@ -1665,14 +1653,15 @@ def factorize(self) -> Iterable[Self]: for qubits in qubit_factors ) - def _control_keys_(self) -> frozenset[cirq.MeasurementKey]: + @property + def control_keys(self) -> frozenset[cirq.MeasurementKey]: measures: set[cirq.MeasurementKey] = set() controls: set[cirq.MeasurementKey] = set() for op in self.all_operations(): # Only require keys that haven't already been measured earlier - controls.update(k for k in protocols.control_keys(op) if k not in measures) + controls.update(op.control_keys - measures) # Record any measurement keys produced by this op - measures.update(protocols.measurement_key_objs(op)) + measures.update(op.measurement_keys) return frozenset(controls) @@ -2139,19 +2128,19 @@ def earliest_available_moment( end_moment_index = len(self.moments) last_available = end_moment_index k = end_moment_index - op_control_keys = protocols.control_keys(op) - op_measurement_keys = protocols.measurement_key_objs(op) + op_control_keys = op.control_keys + op_measurement_keys = op.measurement_keys op_qubits = op.qubits while k > 0: k -= 1 moment = self._moments[k] if moment.operates_on(op_qubits): return last_available - moment_measurement_keys = moment._measurement_key_objs_() + moment_measurement_keys = moment.measurement_keys if ( not op_measurement_keys.isdisjoint(moment_measurement_keys) or not op_control_keys.isdisjoint(moment_measurement_keys) - or not moment._control_keys_().isdisjoint(op_measurement_keys) + or not moment.control_keys.isdisjoint(op_measurement_keys) ): return last_available if self._can_add_op_at(k, op): @@ -2970,8 +2959,8 @@ def get_earliest_accommodating_moment_index( The integer index of the earliest moment that can accommodate the given moment or operation. """ mop_qubits = moment_or_operation.qubits - mop_mkeys = protocols.measurement_key_objs(moment_or_operation) - mop_ckeys = protocols.control_keys(moment_or_operation) + mop_mkeys = moment_or_operation.measurement_keys + mop_ckeys = moment_or_operation.control_keys if isinstance(moment_or_operation, Moment): # For consistency with `Circuit.append`, moments always get placed at the end of a circuit. diff --git a/cirq-core/cirq/circuits/circuit_operation.py b/cirq-core/cirq/circuits/circuit_operation.py index b7632768975..ba3bcd358b0 100644 --- a/cirq-core/cirq/circuits/circuit_operation.py +++ b/cirq-core/cirq/circuits/circuit_operation.py @@ -197,9 +197,7 @@ def __init__( if mapped_repeat_until: if self._use_repetition_ids or self._repetitions != 1: raise ValueError('Cannot use repetitions with repeat_until') - if protocols.measurement_key_objs(self._mapped_single_loop()).isdisjoint( - mapped_repeat_until.keys - ): + if self._mapped_single_loop().measurement_keys.isdisjoint(mapped_repeat_until.keys): raise ValueError('Infinite loop: condition is not modified in subcircuit.') @property @@ -309,8 +307,8 @@ def _ensure_deterministic_loop_count(self): raise ValueError('Cannot unroll circuit due to nondeterministic repetitions') @cached_property - def _measurement_key_objs(self) -> frozenset[cirq.MeasurementKey]: - circuit_keys = protocols.measurement_key_objs(self.circuit) + def measurement_keys(self) -> frozenset[cirq.MeasurementKey]: + circuit_keys = self.circuit.measurement_keys if circuit_keys and self.use_repetition_ids: self._ensure_deterministic_loop_count() if self.repetition_ids is not None: @@ -327,27 +325,18 @@ def _measurement_key_objs(self) -> frozenset[cirq.MeasurementKey]: for key in circuit_keys ) - def _measurement_key_objs_(self) -> frozenset[cirq.MeasurementKey]: - return self._measurement_key_objs - - def _measurement_key_names_(self) -> frozenset[str]: - return frozenset(str(key) for key in self._measurement_key_objs_()) - @cached_property - def _control_keys(self) -> frozenset[cirq.MeasurementKey]: + def control_keys(self) -> frozenset[cirq.MeasurementKey]: keys = ( frozenset() - if not protocols.control_keys(self.circuit) - else protocols.control_keys(self._mapped_single_loop()) + if not self.circuit.control_keys + else self._mapped_single_loop().control_keys ) mapped_repeat_until = self._mapped_repeat_until if mapped_repeat_until is not None: - keys |= frozenset(mapped_repeat_until.keys) - self._measurement_key_objs_() + keys |= frozenset(mapped_repeat_until.keys) - self.measurement_keys return keys - def _control_keys_(self) -> frozenset[cirq.MeasurementKey]: - return self._control_keys - def _is_parameterized_(self) -> bool: return any(self._parameter_names_generator()) @@ -395,9 +384,7 @@ def _mapped_repeat_until(self) -> cirq.Condition | None: repeat_until, self.param_resolver, recursive=False ) return protocols.with_rescoped_keys( - repeat_until, - self.parent_path, - bindable_keys=self._extern_keys | self._measurement_key_objs, + repeat_until, self.parent_path, bindable_keys=self._extern_keys | self.measurement_keys ) def mapped_circuit(self, deep: bool = False) -> cirq.Circuit: diff --git a/cirq-core/cirq/circuits/circuit_test.py b/cirq-core/cirq/circuits/circuit_test.py index cd8f7b00c70..be964a5ab75 100644 --- a/cirq-core/cirq/circuits/circuit_test.py +++ b/cirq-core/cirq/circuits/circuit_test.py @@ -4914,6 +4914,7 @@ def test_create_speed() -> None: c = cirq.Circuit(ops) duration = time.perf_counter() - t assert len(c) == moments + print(duration) assert duration < 4 @@ -4936,6 +4937,7 @@ def test_append_speed() -> None: c.append(xs[q]) duration = time.perf_counter() - t assert len(c) == moments + print(duration) assert duration < 5 diff --git a/cirq-core/cirq/circuits/moment.py b/cirq-core/cirq/circuits/moment.py index 414e2fab0f2..5ac37aab827 100644 --- a/cirq-core/cirq/circuits/moment.py +++ b/cirq-core/cirq/circuits/moment.py @@ -120,8 +120,6 @@ def __init__( raise ValueError(f'Overlapping operations: {self.operations}') self._qubit_to_op[q] = op - self._measurement_key_objs: frozenset[cirq.MeasurementKey] | None = None - self._control_keys: frozenset[cirq.MeasurementKey] | None = None self._tags = tags @classmethod @@ -234,10 +232,8 @@ def with_operation(self, operation: cirq.Operation) -> cirq.Moment: m._sorted_operations = None m._qubit_to_op = {**self._qubit_to_op, **{q: operation for q in operation.qubits}} - m._measurement_key_objs = self._measurement_key_objs_().union( - protocols.measurement_key_objs(operation) - ) - m._control_keys = self._control_keys_().union(protocols.control_keys(operation)) + m.__setattr__('measurement_keys', self.measurement_keys | operation.measurement_keys) + m.__setattr__('control_keys', self.control_keys | operation.control_keys) return m @@ -272,11 +268,12 @@ def with_operations(self, *contents: cirq.OP_TREE) -> cirq.Moment: m._operations = self._operations + flattened_contents m._sorted_operations = None - m._measurement_key_objs = self._measurement_key_objs_().union( - set(itertools.chain(*(protocols.measurement_key_objs(op) for op in flattened_contents))) + m.__setattr__( + 'measurement_keys', + self.measurement_keys.union(*(op.measurement_keys for op in flattened_contents)), ) - m._control_keys = self._control_keys_().union( - set(itertools.chain(*(protocols.control_keys(op) for op in flattened_contents))) + m.__setattr__( + 'control_keys', self.control_keys.union(*(op.control_keys for op in flattened_contents)) ) return m @@ -335,23 +332,13 @@ def _with_measurement_key_mapping_(self, key_map: Mapping[str, str]): for op in self.operations ) - @_compat.cached_method() - def _measurement_key_names_(self) -> frozenset[str]: - return frozenset(str(key) for key in self._measurement_key_objs_()) - - def _measurement_key_objs_(self) -> frozenset[cirq.MeasurementKey]: - if self._measurement_key_objs is None: - self._measurement_key_objs = frozenset( - key for op in self.operations for key in protocols.measurement_key_objs(op) - ) - return self._measurement_key_objs + @cached_property + def measurement_keys(self) -> frozenset[cirq.MeasurementKey]: + return frozenset(key for op in self.operations for key in op.measurement_keys) - def _control_keys_(self) -> frozenset[cirq.MeasurementKey]: - if self._control_keys is None: - self._control_keys = frozenset( - k for op in self.operations for k in protocols.control_keys(op) - ) - return self._control_keys + @cached_property + def control_keys(self) -> frozenset[cirq.MeasurementKey]: + return frozenset(key for op in self.operations for key in op.control_keys) def _sorted_operations_(self) -> tuple[cirq.Operation, ...]: if self._sorted_operations is None: diff --git a/cirq-core/cirq/ops/classically_controlled_operation.py b/cirq-core/cirq/ops/classically_controlled_operation.py index bbb80823e0d..79d7f95e046 100644 --- a/cirq-core/cirq/ops/classically_controlled_operation.py +++ b/cirq-core/cirq/ops/classically_controlled_operation.py @@ -14,6 +14,7 @@ from __future__ import annotations +from functools import cached_property from typing import AbstractSet, Any, Mapping, Sequence, TYPE_CHECKING import sympy @@ -89,7 +90,7 @@ def __init__( ValueError: If an unsupported gate is being classically controlled. """ - if protocols.measurement_key_objs(sub_operation): + if sub_operation.measurement_keys: raise ValueError( f'Cannot conditionally run operations with measurements: {sub_operation}' ) @@ -222,11 +223,10 @@ def _with_rescoped_keys_( sub_operation = protocols.with_rescoped_keys(self._sub_operation, path, bindable_keys) return sub_operation.with_classical_controls(*conds) - def _control_keys_(self) -> frozenset[cirq.MeasurementKey]: - local_keys: frozenset[cirq.MeasurementKey] = frozenset( - k for condition in self._conditions for k in condition.keys - ) - return local_keys.union(protocols.control_keys(self._sub_operation)) + @cached_property + def control_keys(self) -> frozenset[cirq.MeasurementKey]: + local_keys = frozenset(k for condition in self._conditions for k in condition.keys) + return local_keys | self._sub_operation.control_keys def _qasm_(self, args: cirq.QasmArgs) -> str | None: args.validate_version('2.0', '3.0') diff --git a/cirq-core/cirq/ops/eigen_gate.py b/cirq-core/cirq/ops/eigen_gate.py index dd206ce38da..ab3637e32cd 100644 --- a/cirq-core/cirq/ops/eigen_gate.py +++ b/cirq-core/cirq/ops/eigen_gate.py @@ -18,6 +18,7 @@ import fractions import math import numbers +from functools import cached_property from types import NotImplementedType from typing import AbstractSet, Any, cast, Iterable, NamedTuple, TYPE_CHECKING @@ -372,9 +373,6 @@ def _equal_up_to_global_phase_(self, other, atol): def _json_dict_(self) -> dict[str, Any]: return protocols.obj_to_dict_helper(self, ['exponent', 'global_shift']) - def _measurement_key_objs_(self): - return frozenset() - def _lcm(vals: Iterable[int]) -> int: t = 1 diff --git a/cirq-core/cirq/ops/gate_operation.py b/cirq-core/cirq/ops/gate_operation.py index f565b13d2c8..53406be81b1 100644 --- a/cirq-core/cirq/ops/gate_operation.py +++ b/cirq-core/cirq/ops/gate_operation.py @@ -18,6 +18,7 @@ import re import warnings +from functools import cached_property from types import NotImplementedType from typing import ( AbstractSet, @@ -233,30 +234,6 @@ def _is_measurement_(self) -> bool | None: # Let the protocol handle the fallback. return NotImplemented - def _measurement_key_name_(self) -> str | None: - getter = getattr(self.gate, '_measurement_key_name_', None) - if getter is not None: - return getter() - return NotImplemented - - def _measurement_key_names_(self) -> frozenset[str] | NotImplementedType | None: - getter = getattr(self.gate, '_measurement_key_names_', None) - if getter is not None: - return getter() - return NotImplemented - - def _measurement_key_obj_(self) -> cirq.MeasurementKey | None: - getter = getattr(self.gate, '_measurement_key_obj_', None) - if getter is not None: - return getter() - return NotImplemented - - def _measurement_key_objs_(self) -> frozenset[cirq.MeasurementKey] | NotImplementedType | None: - getter = getattr(self.gate, '_measurement_key_objs_', None) - if getter is not None: - return getter() - return NotImplemented - def _act_on_(self, sim_state: cirq.SimulationStateBase): getter = getattr(self.gate, '_act_on_', None) if getter is not None: @@ -380,5 +357,9 @@ def controlled_by( control_qid_shape=tuple(q.dimension for q in qubits), ).on(*(qubits + self._qubits)) + @cached_property + def measurement_keys(self) -> frozenset[cirq.MeasurementKey]: + return self.gate.measurement_keys + TV = TypeVar('TV', bound=raw_types.Gate) diff --git a/cirq-core/cirq/ops/kraus_channel.py b/cirq-core/cirq/ops/kraus_channel.py index 275c7dd30f0..d15d5bf1f20 100644 --- a/cirq-core/cirq/ops/kraus_channel.py +++ b/cirq-core/cirq/ops/kraus_channel.py @@ -2,6 +2,7 @@ from __future__ import annotations +from functools import cached_property from typing import Any, Iterable, Mapping, TYPE_CHECKING import numpy as np @@ -78,15 +79,11 @@ def num_qubits(self) -> int: def _kraus_(self): return self._kraus_ops - def _measurement_key_name_(self) -> str: + @cached_property + def measurement_keys(self) -> frozenset[cirq.MeasurementKey]: if self._key is None: - return NotImplemented - return str(self._key) - - def _measurement_key_obj_(self) -> cirq.MeasurementKey: - if self._key is None: - return NotImplemented - return self._key + return frozenset() + return frozenset([self._key]) def _with_measurement_key_mapping_(self, key_map: Mapping[str, str]): if self._key is None: diff --git a/cirq-core/cirq/ops/measurement_gate.py b/cirq-core/cirq/ops/measurement_gate.py index fbc400587fa..416330ad337 100644 --- a/cirq-core/cirq/ops/measurement_gate.py +++ b/cirq-core/cirq/ops/measurement_gate.py @@ -14,6 +14,7 @@ from __future__ import annotations +from functools import cached_property from typing import Any, Iterable, Mapping, Sequence, TYPE_CHECKING import numpy as np @@ -170,11 +171,9 @@ def full_invert_mask(self) -> tuple[bool, ...]: def _is_measurement_(self) -> bool: return True - def _measurement_key_name_(self) -> str: - return self.key - - def _measurement_key_obj_(self) -> cirq.MeasurementKey: - return self.mkey + @cached_property + def measurement_keys(self) -> frozenset[cirq.MeasurementKey]: + return frozenset([self.mkey]) def _kraus_(self): size = np.prod(self._qid_shape, dtype=np.int64) diff --git a/cirq-core/cirq/ops/mixed_unitary_channel.py b/cirq-core/cirq/ops/mixed_unitary_channel.py index 0f1124c500c..b49810be974 100644 --- a/cirq-core/cirq/ops/mixed_unitary_channel.py +++ b/cirq-core/cirq/ops/mixed_unitary_channel.py @@ -2,6 +2,7 @@ from __future__ import annotations +from functools import cached_property from typing import Any, Iterable, Mapping, TYPE_CHECKING import numpy as np @@ -83,15 +84,11 @@ def num_qubits(self) -> int: def _mixture_(self): return self._mixture - def _measurement_key_name_(self) -> str: + @cached_property + def measurement_keys(self) -> frozenset[cirq.MeasurementKey]: if self._key is None: - return NotImplemented - return str(self._key) - - def _measurement_key_obj_(self) -> cirq.MeasurementKey: - if self._key is None: - return NotImplemented - return self._key + return frozenset() + return frozenset([self._key]) def _with_measurement_key_mapping_(self, key_map: Mapping[str, str]): if self._key is None: diff --git a/cirq-core/cirq/ops/pauli_measurement_gate.py b/cirq-core/cirq/ops/pauli_measurement_gate.py index 72a2ba23493..82283777d20 100644 --- a/cirq-core/cirq/ops/pauli_measurement_gate.py +++ b/cirq-core/cirq/ops/pauli_measurement_gate.py @@ -14,6 +14,7 @@ from __future__ import annotations +from functools import cached_property from typing import Any, cast, Iterable, Iterator, Mapping, Sequence, TYPE_CHECKING from cirq import protocols, value @@ -123,11 +124,9 @@ def with_observable( def _is_measurement_(self) -> bool: return True - def _measurement_key_name_(self) -> str: - return self.key - - def _measurement_key_obj_(self) -> cirq.MeasurementKey: - return self.mkey + @cached_property + def measurement_keys(self) -> frozenset[cirq.MeasurementKey]: + return frozenset([self.mkey]) def observable(self) -> cirq.DensePauliString: """Pauli observable which should be measured by the gate.""" diff --git a/cirq-core/cirq/ops/raw_types.py b/cirq-core/cirq/ops/raw_types.py index 1043eb2f81f..f17c59d564d 100644 --- a/cirq-core/cirq/ops/raw_types.py +++ b/cirq-core/cirq/ops/raw_types.py @@ -18,6 +18,7 @@ import abc import functools +from functools import cached_property from types import NotImplementedType from typing import ( AbstractSet, @@ -455,6 +456,10 @@ def _qid_shape_(self) -> tuple[int, ...]: """ raise NotImplementedError + @cached_property + def measurement_keys(self) -> frozenset[cirq.MeasurementKey]: + return protocols.measurement_key_objs(self, _skip_property_check=True) + def _equal_up_to_global_phase_( self, other: Any, atol: float = 1e-8 ) -> NotImplementedType | bool: @@ -749,6 +754,14 @@ def without_classical_controls(self) -> cirq.Operation: """ return self + @cached_property + def measurement_keys(self) -> frozenset[cirq.MeasurementKey]: + return protocols.measurement_key_objs(self, _skip_property_check=True) + + @cached_property + def control_keys(self) -> frozenset[cirq.MeasurementKey]: + return protocols.control_keys(self, _skip_property_check=True) + @value.value_equality class TaggedOperation(Operation): @@ -881,13 +894,9 @@ def _has_kraus_(self) -> bool: def _kraus_(self) -> tuple[np.ndarray, ...] | NotImplementedType: return protocols.kraus(self.sub_operation, NotImplemented) - @cached_method - def _measurement_key_names_(self) -> frozenset[str]: - return protocols.measurement_key_names(self.sub_operation) - - @cached_method - def _measurement_key_objs_(self) -> frozenset[cirq.MeasurementKey]: - return protocols.measurement_key_objs(self.sub_operation) + @cached_property + def measurement_keys(self) -> frozenset[cirq.MeasurementKey]: + return self.sub_operation.measurement_keys @cached_method def _is_measurement_(self) -> bool: @@ -981,8 +990,9 @@ def with_classical_controls(self, *conditions): return self return self.sub_operation.with_classical_controls(*conditions) - def _control_keys_(self) -> frozenset[cirq.MeasurementKey]: - return protocols.control_keys(self.sub_operation) + @cached_property + def control_keys(self) -> frozenset[cirq.MeasurementKey]: + return self.sub_operation.control_keys @value.value_equality @@ -1099,10 +1109,10 @@ def _operations_commutes_impl( False: `ops1` and `ops2` do not commute. NotImplemented: The commutativity cannot be determined here. """ - ops1_keys = frozenset(k for op in ops1 for k in protocols.measurement_key_objs(op)) - ops2_keys = frozenset(k for op in ops2 for k in protocols.measurement_key_objs(op)) - ops1_control_keys = frozenset(k for op in ops1 for k in protocols.control_keys(op)) - ops2_control_keys = frozenset(k for op in ops2 for k in protocols.control_keys(op)) + ops1_keys = frozenset(k for op in ops1 for k in op.measurement_keys) + ops2_keys = frozenset(k for op in ops2 for k in op.measurement_keys) + ops1_control_keys = frozenset(k for op in ops1 for k in op.control_keys) + ops2_control_keys = frozenset(k for op in ops2 for k in op.control_keys) if ( not ops1_keys.isdisjoint(ops2_keys) or not ops1_control_keys.isdisjoint(ops2_keys) diff --git a/cirq-core/cirq/protocols/control_key_protocol.py b/cirq-core/cirq/protocols/control_key_protocol.py index a02af1508e9..045b64e9529 100644 --- a/cirq-core/cirq/protocols/control_key_protocol.py +++ b/cirq-core/cirq/protocols/control_key_protocol.py @@ -44,7 +44,7 @@ def _control_keys_(self) -> frozenset[cirq.MeasurementKey] | NotImplementedType """ -def control_keys(val: Any) -> frozenset[cirq.MeasurementKey]: +def control_keys(val: Any, _skip_property_check=False) -> frozenset[cirq.MeasurementKey]: """Gets the keys that the value is classically controlled by. Args: @@ -61,6 +61,11 @@ def control_keys(val: Any) -> frozenset[cirq.MeasurementKey]: the subcircuit are still required externally and thus appear in the result. """ + if not _skip_property_check: + attr = getattr(val, 'control_keys', None) + if attr is not None: + return attr + getter = getattr(val, '_control_keys_', None) result = NotImplemented if getter is None else getter() if result is not NotImplemented and result is not None: diff --git a/cirq-core/cirq/protocols/measurement_key_protocol.py b/cirq-core/cirq/protocols/measurement_key_protocol.py index 6346e577540..7e832f647e2 100644 --- a/cirq-core/cirq/protocols/measurement_key_protocol.py +++ b/cirq-core/cirq/protocols/measurement_key_protocol.py @@ -196,10 +196,14 @@ def measurement_key_name(val, default=RaiseTypeErrorIfNotProvided): def _measurement_key_objs_from_magic_methods( - val: Any, + val: Any, _skip_property_check=False ) -> frozenset[cirq.MeasurementKey] | NotImplementedType | None: """Uses the measurement key related magic methods to get the `MeasurementKey`s for this object.""" + if not _skip_property_check: + attr = getattr(val, 'measurement_keys', None) + if attr is not None: + return attr getter = getattr(val, '_measurement_key_objs_', None) result = NotImplemented if getter is None else getter() @@ -231,7 +235,7 @@ def _measurement_key_names_from_magic_methods( return result -def measurement_key_objs(val: Any) -> frozenset[cirq.MeasurementKey]: +def measurement_key_objs(val: Any, _skip_property_check=False) -> frozenset[cirq.MeasurementKey]: """Gets the measurement key objects of measurements within the given value. Args: @@ -241,7 +245,7 @@ def measurement_key_objs(val: Any) -> frozenset[cirq.MeasurementKey]: The measurement key objects of the value. If the value has no measurement, the result is the empty set. """ - result = _measurement_key_objs_from_magic_methods(val) + result = _measurement_key_objs_from_magic_methods(val, _skip_property_check) if result is not NotImplemented and result is not None: return result key_strings = _measurement_key_names_from_magic_methods(val) diff --git a/cirq-core/cirq/transformers/stratify.py b/cirq-core/cirq/transformers/stratify.py index 23d40702d7b..05068822c7e 100644 --- a/cirq-core/cirq/transformers/stratify.py +++ b/cirq-core/cirq/transformers/stratify.py @@ -167,9 +167,9 @@ def _stratify_circuit( # Update qubit, measurement key, and control key moments. for qubit in op.qubits: qubit_time_index[qubit] = time_index - for key in protocols.measurement_key_objs(op): + for key in op.measurement_keys: measurement_time_index[key] = time_index - for key in protocols.control_keys(op): + for key in op.control_keys: control_time_index[key] = time_index return circuits.Circuit(circuits.Moment(moment) for moment in new_moments if moment) diff --git a/cirq-core/cirq/transformers/synchronize_terminal_measurements.py b/cirq-core/cirq/transformers/synchronize_terminal_measurements.py index 3c4abc39918..783dd161765 100644 --- a/cirq-core/cirq/transformers/synchronize_terminal_measurements.py +++ b/cirq-core/cirq/transformers/synchronize_terminal_measurements.py @@ -50,11 +50,11 @@ def find_terminal_measurements(circuit: cirq.AbstractCircuit) -> list[tuple[int, op is not None and open_qubits.issuperset(op.qubits) and protocols.is_measurement(op) - and not (seen_control_keys & protocols.measurement_key_objs(op)) + and not (seen_control_keys & op.measurement_keys) ): terminal_measurements.add((i, op)) open_qubits -= moment.qubits - seen_control_keys |= protocols.control_keys(moment) + seen_control_keys |= moment.control_keys if not open_qubits: break return list(terminal_measurements) diff --git a/cirq-core/cirq/transformers/transformer_primitives.py b/cirq-core/cirq/transformers/transformer_primitives.py index 14396764011..10bfbdab48e 100644 --- a/cirq-core/cirq/transformers/transformer_primitives.py +++ b/cirq-core/cirq/transformers/transformer_primitives.py @@ -316,9 +316,9 @@ def add_op_to_moment(self, moment_index: int, op: cirq.Operation) -> None: self.qubit_indexes[q].append(moment_index) else: bisect.insort(self.qubit_indexes[q], moment_index) - for mkey in protocols.measurement_key_objs(op): + for mkey in op.measurement_keys: bisect.insort(self.mkey_indexes[mkey], moment_index) - for ckey in protocols.control_keys(op): + for ckey in op.control_keys: bisect.insort(self.ckey_indexes[ckey], moment_index) def remove_op_from_moment(self, moment_index: int, op: cirq.Operation) -> None: @@ -328,9 +328,9 @@ def remove_op_from_moment(self, moment_index: int, op: cirq.Operation) -> None: self.qubit_indexes[q].pop() else: self.qubit_indexes[q].remove(moment_index) - for mkey in protocols.measurement_key_objs(op): + for mkey in op.measurement_keys: self.mkey_indexes[mkey].remove(moment_index) - for ckey in protocols.control_keys(op): + for ckey in op.control_keys: self.ckey_indexes[ckey].remove(moment_index) def get_mergeable_ops( @@ -338,10 +338,8 @@ def get_mergeable_ops( ) -> tuple[int, list[cirq.Operation]]: # Find the index of previous moment which can be merged with `op`. idx = max([self.qubit_indexes[q][-1] for q in op_qs], default=-1) - idx = max([idx] + [self.mkey_indexes[ckey][-1] for ckey in protocols.control_keys(op)]) - idx = max( - [idx] + [self.ckey_indexes[mkey][-1] for mkey in protocols.measurement_key_objs(op)] - ) + idx = max([idx] + [self.mkey_indexes[ckey][-1] for ckey in op.control_keys]) + idx = max([idx] + [self.ckey_indexes[mkey][-1] for mkey in op.measurement_keys]) # Return the set of overlapping ops in moment with index `idx`. if idx == -1: return idx, []