Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 17 additions & 28 deletions cirq-core/cirq/circuits/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -988,18 +988,12 @@ def qid_shape(
qids = ops.QubitOrder.as_qubit_order(qubit_order).order_for(self.all_qubits())
return protocols.qid_shape(qids)

def all_measurement_key_objs(self) -> frozenset[cirq.MeasurementKey]:
return frozenset(
key for op in self.all_operations() for key in protocols.measurement_key_objs(op)
)

def _measurement_key_objs_(self) -> frozenset[cirq.MeasurementKey]:
"""Returns the set of all measurement keys in this circuit.
@property
def measurement_keys(self) -> frozenset[cirq.MeasurementKey]:
return frozenset(key for m in self.moments for key in m.measurement_keys)

Returns: frozenset of `cirq.MeasurementKey` objects that are
in this circuit.
"""
return self.all_measurement_key_objs()
def all_measurement_key_objs(self) -> frozenset[cirq.MeasurementKey]:
return self.measurement_keys

def all_measurement_key_names(self) -> frozenset[str]:
"""Returns the set of all measurement key names in this circuit.
Expand All @@ -1011,9 +1005,6 @@ def all_measurement_key_names(self) -> frozenset[str]:
key for op in self.all_operations() for key in protocols.measurement_key_names(op)
)

def _measurement_key_names_(self) -> frozenset[str]:
return self.all_measurement_key_names()

def _with_measurement_key_mapping_(self, key_map: Mapping[str, str]):
return self._from_moments(
(protocols.with_measurement_key_mapping(moment, key_map) for moment in self.moments),
Expand All @@ -1038,7 +1029,7 @@ def _with_rescoped_keys_(
for moment in self.moments:
new_moment = protocols.with_rescoped_keys(moment, path, bindable_keys)
moments.append(new_moment)
bindable_keys |= protocols.measurement_key_objs(new_moment)
bindable_keys |= new_moment.measurement_keys
return self._from_moments(moments, tags=self.tags)

def _qid_shape_(self) -> tuple[int, ...]:
Expand Down Expand Up @@ -1279,10 +1270,7 @@ def to_text_diagram_drawer(
"""
qubits = ops.QubitOrder.as_qubit_order(qubit_order).order_for(self.all_qubits())
cbits = tuple(
sorted(
set(key for op in self.all_operations() for key in protocols.control_keys(op)),
key=str,
)
sorted(set(key for op in self.all_operations() for key in op.control_keys), key=str)
)
labels = qubits + cbits
label_map = {labels[i]: i for i in range(len(labels))}
Expand Down Expand Up @@ -1665,14 +1653,15 @@ def factorize(self) -> Iterable[Self]:
for qubits in qubit_factors
)

def _control_keys_(self) -> frozenset[cirq.MeasurementKey]:
@property
def control_keys(self) -> frozenset[cirq.MeasurementKey]:
measures: set[cirq.MeasurementKey] = set()
controls: set[cirq.MeasurementKey] = set()
for op in self.all_operations():
# Only require keys that haven't already been measured earlier
controls.update(k for k in protocols.control_keys(op) if k not in measures)
controls.update(op.control_keys - measures)
# Record any measurement keys produced by this op
measures.update(protocols.measurement_key_objs(op))
measures.update(op.measurement_keys)
return frozenset(controls)


Expand Down Expand Up @@ -2139,19 +2128,19 @@ def earliest_available_moment(
end_moment_index = len(self.moments)
last_available = end_moment_index
k = end_moment_index
op_control_keys = protocols.control_keys(op)
op_measurement_keys = protocols.measurement_key_objs(op)
op_control_keys = op.control_keys
op_measurement_keys = op.measurement_keys
op_qubits = op.qubits
while k > 0:
k -= 1
moment = self._moments[k]
if moment.operates_on(op_qubits):
return last_available
moment_measurement_keys = moment._measurement_key_objs_()
moment_measurement_keys = moment.measurement_keys
if (
not op_measurement_keys.isdisjoint(moment_measurement_keys)
or not op_control_keys.isdisjoint(moment_measurement_keys)
or not moment._control_keys_().isdisjoint(op_measurement_keys)
or not moment.control_keys.isdisjoint(op_measurement_keys)
):
return last_available
if self._can_add_op_at(k, op):
Expand Down Expand Up @@ -2970,8 +2959,8 @@ def get_earliest_accommodating_moment_index(
The integer index of the earliest moment that can accommodate the given moment or operation.
"""
mop_qubits = moment_or_operation.qubits
mop_mkeys = protocols.measurement_key_objs(moment_or_operation)
mop_ckeys = protocols.control_keys(moment_or_operation)
mop_mkeys = moment_or_operation.measurement_keys
mop_ckeys = moment_or_operation.control_keys

if isinstance(moment_or_operation, Moment):
# For consistency with `Circuit.append`, moments always get placed at the end of a circuit.
Expand Down
29 changes: 8 additions & 21 deletions cirq-core/cirq/circuits/circuit_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,9 +197,7 @@ def __init__(
if mapped_repeat_until:
if self._use_repetition_ids or self._repetitions != 1:
raise ValueError('Cannot use repetitions with repeat_until')
if protocols.measurement_key_objs(self._mapped_single_loop()).isdisjoint(
mapped_repeat_until.keys
):
if self._mapped_single_loop().measurement_keys.isdisjoint(mapped_repeat_until.keys):
raise ValueError('Infinite loop: condition is not modified in subcircuit.')

@property
Expand Down Expand Up @@ -309,8 +307,8 @@ def _ensure_deterministic_loop_count(self):
raise ValueError('Cannot unroll circuit due to nondeterministic repetitions')

@cached_property
def _measurement_key_objs(self) -> frozenset[cirq.MeasurementKey]:
circuit_keys = protocols.measurement_key_objs(self.circuit)
def measurement_keys(self) -> frozenset[cirq.MeasurementKey]:
circuit_keys = self.circuit.measurement_keys
if circuit_keys and self.use_repetition_ids:
self._ensure_deterministic_loop_count()
if self.repetition_ids is not None:
Expand All @@ -327,27 +325,18 @@ def _measurement_key_objs(self) -> frozenset[cirq.MeasurementKey]:
for key in circuit_keys
)

def _measurement_key_objs_(self) -> frozenset[cirq.MeasurementKey]:
return self._measurement_key_objs

def _measurement_key_names_(self) -> frozenset[str]:
return frozenset(str(key) for key in self._measurement_key_objs_())

@cached_property
def _control_keys(self) -> frozenset[cirq.MeasurementKey]:
def control_keys(self) -> frozenset[cirq.MeasurementKey]:
keys = (
frozenset()
if not protocols.control_keys(self.circuit)
else protocols.control_keys(self._mapped_single_loop())
if not self.circuit.control_keys
else self._mapped_single_loop().control_keys
)
mapped_repeat_until = self._mapped_repeat_until
if mapped_repeat_until is not None:
keys |= frozenset(mapped_repeat_until.keys) - self._measurement_key_objs_()
keys |= frozenset(mapped_repeat_until.keys) - self.measurement_keys
return keys

def _control_keys_(self) -> frozenset[cirq.MeasurementKey]:
return self._control_keys

def _is_parameterized_(self) -> bool:
return any(self._parameter_names_generator())

Expand Down Expand Up @@ -395,9 +384,7 @@ def _mapped_repeat_until(self) -> cirq.Condition | None:
repeat_until, self.param_resolver, recursive=False
)
return protocols.with_rescoped_keys(
repeat_until,
self.parent_path,
bindable_keys=self._extern_keys | self._measurement_key_objs,
repeat_until, self.parent_path, bindable_keys=self._extern_keys | self.measurement_keys
)

def mapped_circuit(self, deep: bool = False) -> cirq.Circuit:
Expand Down
2 changes: 2 additions & 0 deletions cirq-core/cirq/circuits/circuit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4914,6 +4914,7 @@ def test_create_speed() -> None:
c = cirq.Circuit(ops)
duration = time.perf_counter() - t
assert len(c) == moments
print(duration)
assert duration < 4


Expand All @@ -4936,6 +4937,7 @@ def test_append_speed() -> None:
c.append(xs[q])
duration = time.perf_counter() - t
assert len(c) == moments
print(duration)
assert duration < 5


Expand Down
39 changes: 13 additions & 26 deletions cirq-core/cirq/circuits/moment.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,6 @@ def __init__(
raise ValueError(f'Overlapping operations: {self.operations}')
self._qubit_to_op[q] = op

self._measurement_key_objs: frozenset[cirq.MeasurementKey] | None = None
self._control_keys: frozenset[cirq.MeasurementKey] | None = None
self._tags = tags

@classmethod
Expand Down Expand Up @@ -234,10 +232,8 @@ def with_operation(self, operation: cirq.Operation) -> cirq.Moment:
m._sorted_operations = None
m._qubit_to_op = {**self._qubit_to_op, **{q: operation for q in operation.qubits}}

m._measurement_key_objs = self._measurement_key_objs_().union(
protocols.measurement_key_objs(operation)
)
m._control_keys = self._control_keys_().union(protocols.control_keys(operation))
m.__setattr__('measurement_keys', self.measurement_keys | operation.measurement_keys)
m.__setattr__('control_keys', self.control_keys | operation.control_keys)

return m

Expand Down Expand Up @@ -272,11 +268,12 @@ def with_operations(self, *contents: cirq.OP_TREE) -> cirq.Moment:

m._operations = self._operations + flattened_contents
m._sorted_operations = None
m._measurement_key_objs = self._measurement_key_objs_().union(
set(itertools.chain(*(protocols.measurement_key_objs(op) for op in flattened_contents)))
m.__setattr__(
'measurement_keys',
self.measurement_keys.union(*(op.measurement_keys for op in flattened_contents)),
)
m._control_keys = self._control_keys_().union(
set(itertools.chain(*(protocols.control_keys(op) for op in flattened_contents)))
m.__setattr__(
'control_keys', self.control_keys.union(*(op.control_keys for op in flattened_contents))
)

return m
Expand Down Expand Up @@ -335,23 +332,13 @@ def _with_measurement_key_mapping_(self, key_map: Mapping[str, str]):
for op in self.operations
)

@_compat.cached_method()
def _measurement_key_names_(self) -> frozenset[str]:
return frozenset(str(key) for key in self._measurement_key_objs_())

def _measurement_key_objs_(self) -> frozenset[cirq.MeasurementKey]:
if self._measurement_key_objs is None:
self._measurement_key_objs = frozenset(
key for op in self.operations for key in protocols.measurement_key_objs(op)
)
return self._measurement_key_objs
@cached_property
def measurement_keys(self) -> frozenset[cirq.MeasurementKey]:
return frozenset(key for op in self.operations for key in op.measurement_keys)

def _control_keys_(self) -> frozenset[cirq.MeasurementKey]:
if self._control_keys is None:
self._control_keys = frozenset(
k for op in self.operations for k in protocols.control_keys(op)
)
return self._control_keys
@cached_property
def control_keys(self) -> frozenset[cirq.MeasurementKey]:
return frozenset(key for op in self.operations for key in op.control_keys)

def _sorted_operations_(self) -> tuple[cirq.Operation, ...]:
if self._sorted_operations is None:
Expand Down
12 changes: 6 additions & 6 deletions cirq-core/cirq/ops/classically_controlled_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from __future__ import annotations

from functools import cached_property
from typing import AbstractSet, Any, Mapping, Sequence, TYPE_CHECKING

import sympy
Expand Down Expand Up @@ -89,7 +90,7 @@ def __init__(
ValueError: If an unsupported gate is being classically
controlled.
"""
if protocols.measurement_key_objs(sub_operation):
if sub_operation.measurement_keys:
raise ValueError(
f'Cannot conditionally run operations with measurements: {sub_operation}'
)
Expand Down Expand Up @@ -222,11 +223,10 @@ def _with_rescoped_keys_(
sub_operation = protocols.with_rescoped_keys(self._sub_operation, path, bindable_keys)
return sub_operation.with_classical_controls(*conds)

def _control_keys_(self) -> frozenset[cirq.MeasurementKey]:
local_keys: frozenset[cirq.MeasurementKey] = frozenset(
k for condition in self._conditions for k in condition.keys
)
return local_keys.union(protocols.control_keys(self._sub_operation))
@cached_property
def control_keys(self) -> frozenset[cirq.MeasurementKey]:
local_keys = frozenset(k for condition in self._conditions for k in condition.keys)
return local_keys | self._sub_operation.control_keys

def _qasm_(self, args: cirq.QasmArgs) -> str | None:
args.validate_version('2.0', '3.0')
Expand Down
4 changes: 1 addition & 3 deletions cirq-core/cirq/ops/eigen_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import fractions
import math
import numbers
from functools import cached_property
from types import NotImplementedType
from typing import AbstractSet, Any, cast, Iterable, NamedTuple, TYPE_CHECKING

Expand Down Expand Up @@ -372,9 +373,6 @@ def _equal_up_to_global_phase_(self, other, atol):
def _json_dict_(self) -> dict[str, Any]:
return protocols.obj_to_dict_helper(self, ['exponent', 'global_shift'])

def _measurement_key_objs_(self):
return frozenset()


def _lcm(vals: Iterable[int]) -> int:
t = 1
Expand Down
29 changes: 5 additions & 24 deletions cirq-core/cirq/ops/gate_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import re
import warnings
from functools import cached_property
from types import NotImplementedType
from typing import (
AbstractSet,
Expand Down Expand Up @@ -233,30 +234,6 @@ def _is_measurement_(self) -> bool | None:
# Let the protocol handle the fallback.
return NotImplemented

def _measurement_key_name_(self) -> str | None:
getter = getattr(self.gate, '_measurement_key_name_', None)
if getter is not None:
return getter()
return NotImplemented

def _measurement_key_names_(self) -> frozenset[str] | NotImplementedType | None:
getter = getattr(self.gate, '_measurement_key_names_', None)
if getter is not None:
return getter()
return NotImplemented

def _measurement_key_obj_(self) -> cirq.MeasurementKey | None:
getter = getattr(self.gate, '_measurement_key_obj_', None)
if getter is not None:
return getter()
return NotImplemented

def _measurement_key_objs_(self) -> frozenset[cirq.MeasurementKey] | NotImplementedType | None:
getter = getattr(self.gate, '_measurement_key_objs_', None)
if getter is not None:
return getter()
return NotImplemented

def _act_on_(self, sim_state: cirq.SimulationStateBase):
getter = getattr(self.gate, '_act_on_', None)
if getter is not None:
Expand Down Expand Up @@ -380,5 +357,9 @@ def controlled_by(
control_qid_shape=tuple(q.dimension for q in qubits),
).on(*(qubits + self._qubits))

@cached_property
def measurement_keys(self) -> frozenset[cirq.MeasurementKey]:
return self.gate.measurement_keys


TV = TypeVar('TV', bound=raw_types.Gate)
13 changes: 5 additions & 8 deletions cirq-core/cirq/ops/kraus_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

from functools import cached_property
from typing import Any, Iterable, Mapping, TYPE_CHECKING

import numpy as np
Expand Down Expand Up @@ -78,15 +79,11 @@ def num_qubits(self) -> int:
def _kraus_(self):
return self._kraus_ops

def _measurement_key_name_(self) -> str:
@cached_property
def measurement_keys(self) -> frozenset[cirq.MeasurementKey]:
if self._key is None:
return NotImplemented
return str(self._key)

def _measurement_key_obj_(self) -> cirq.MeasurementKey:
if self._key is None:
return NotImplemented
return self._key
return frozenset()
return frozenset([self._key])

def _with_measurement_key_mapping_(self, key_map: Mapping[str, str]):
if self._key is None:
Expand Down
Loading
Loading