Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 16 additions & 19 deletions cirq-core/cirq/circuits/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -1038,7 +1038,7 @@ def _with_rescoped_keys_(
for moment in self.moments:
new_moment = protocols.with_rescoped_keys(moment, path, bindable_keys)
moments.append(new_moment)
bindable_keys |= protocols.measurement_key_objs(new_moment)
bindable_keys |= new_moment.measurement_keys
return self._from_moments(moments, tags=self.tags)

def _qid_shape_(self) -> tuple[int, ...]:
Expand Down Expand Up @@ -1279,10 +1279,7 @@ def to_text_diagram_drawer(
"""
qubits = ops.QubitOrder.as_qubit_order(qubit_order).order_for(self.all_qubits())
cbits = tuple(
sorted(
set(key for op in self.all_operations() for key in protocols.control_keys(op)),
key=str,
)
sorted(set(key for op in self.all_operations() for key in op.control_keys), key=str)
)
labels = qubits + cbits
label_map = {labels[i]: i for i in range(len(labels))}
Expand Down Expand Up @@ -1665,14 +1662,19 @@ def factorize(self) -> Iterable[Self]:
for qubits in qubit_factors
)

def _control_keys_(self) -> frozenset[cirq.MeasurementKey]:
@property
def measurement_keys(self) -> frozenset[cirq.MeasurementKey]:
return frozenset().union(*(m.measurement_keys for m in self.moments))

@property
def control_keys(self) -> frozenset[cirq.MeasurementKey]:
measures: set[cirq.MeasurementKey] = set()
controls: set[cirq.MeasurementKey] = set()
for op in self.all_operations():
# Only require keys that haven't already been measured earlier
controls.update(k for k in protocols.control_keys(op) if k not in measures)
controls.update(op.control_keys - measures)
# Record any measurement keys produced by this op
measures.update(protocols.measurement_key_objs(op))
measures.update(op.measurement_keys)
return frozenset(controls)


Expand Down Expand Up @@ -2139,19 +2141,14 @@ def earliest_available_moment(
end_moment_index = len(self.moments)
last_available = end_moment_index
k = end_moment_index
op_control_keys = protocols.control_keys(op)
op_measurement_keys = protocols.measurement_key_objs(op)
op_qubits = op.qubits
while k > 0:
k -= 1
moment = self._moments[k]
if moment.operates_on(op_qubits):
return last_available
moment_measurement_keys = moment._measurement_key_objs_()
if (
not op_measurement_keys.isdisjoint(moment_measurement_keys)
or not op_control_keys.isdisjoint(moment_measurement_keys)
or not moment._control_keys_().isdisjoint(op_measurement_keys)
moment.operates_on(op.qubits)
or not op.measurement_keys.isdisjoint(moment.measurement_keys)
or not op.control_keys.isdisjoint(moment.measurement_keys)
or not moment.control_keys.isdisjoint(op.measurement_keys)
):
return last_available
if self._can_add_op_at(k, op):
Expand Down Expand Up @@ -2970,8 +2967,8 @@ def get_earliest_accommodating_moment_index(
The integer index of the earliest moment that can accommodate the given moment or operation.
"""
mop_qubits = moment_or_operation.qubits
mop_mkeys = protocols.measurement_key_objs(moment_or_operation)
mop_ckeys = protocols.control_keys(moment_or_operation)
mop_mkeys = moment_or_operation.measurement_keys
mop_ckeys = moment_or_operation.control_keys

if isinstance(moment_or_operation, Moment):
# For consistency with `Circuit.append`, moments always get placed at the end of a circuit.
Expand Down
29 changes: 8 additions & 21 deletions cirq-core/cirq/circuits/circuit_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,9 +197,7 @@ def __init__(
if mapped_repeat_until:
if self._use_repetition_ids or self._repetitions != 1:
raise ValueError('Cannot use repetitions with repeat_until')
if protocols.measurement_key_objs(self._mapped_single_loop()).isdisjoint(
mapped_repeat_until.keys
):
if self._mapped_single_loop().measurement_keys.isdisjoint(mapped_repeat_until.keys):
raise ValueError('Infinite loop: condition is not modified in subcircuit.')

@property
Expand Down Expand Up @@ -309,8 +307,8 @@ def _ensure_deterministic_loop_count(self):
raise ValueError('Cannot unroll circuit due to nondeterministic repetitions')

@cached_property
def _measurement_key_objs(self) -> frozenset[cirq.MeasurementKey]:
circuit_keys = protocols.measurement_key_objs(self.circuit)
def measurement_keys(self) -> frozenset[cirq.MeasurementKey]:
circuit_keys = self.circuit.measurement_keys
if circuit_keys and self.use_repetition_ids:
self._ensure_deterministic_loop_count()
if self.repetition_ids is not None:
Expand All @@ -327,27 +325,18 @@ def _measurement_key_objs(self) -> frozenset[cirq.MeasurementKey]:
for key in circuit_keys
)

def _measurement_key_objs_(self) -> frozenset[cirq.MeasurementKey]:
return self._measurement_key_objs

def _measurement_key_names_(self) -> frozenset[str]:
return frozenset(str(key) for key in self._measurement_key_objs_())

@cached_property
def _control_keys(self) -> frozenset[cirq.MeasurementKey]:
def control_keys(self) -> frozenset[cirq.MeasurementKey]:
keys = (
frozenset()
if not protocols.control_keys(self.circuit)
else protocols.control_keys(self._mapped_single_loop())
if not self.circuit.control_keys
else self._mapped_single_loop().control_keys
)
mapped_repeat_until = self._mapped_repeat_until
if mapped_repeat_until is not None:
keys |= frozenset(mapped_repeat_until.keys) - self._measurement_key_objs_()
keys |= frozenset(mapped_repeat_until.keys) - self.measurement_keys
return keys

def _control_keys_(self) -> frozenset[cirq.MeasurementKey]:
return self._control_keys

def _is_parameterized_(self) -> bool:
return any(self._parameter_names_generator())

Expand Down Expand Up @@ -395,9 +384,7 @@ def _mapped_repeat_until(self) -> cirq.Condition | None:
repeat_until, self.param_resolver, recursive=False
)
return protocols.with_rescoped_keys(
repeat_until,
self.parent_path,
bindable_keys=self._extern_keys | self._measurement_key_objs,
repeat_until, self.parent_path, bindable_keys=self._extern_keys | self.measurement_keys
)

def mapped_circuit(self, deep: bool = False) -> cirq.Circuit:
Expand Down
2 changes: 2 additions & 0 deletions cirq-core/cirq/circuits/circuit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4914,6 +4914,7 @@ def test_create_speed() -> None:
c = cirq.Circuit(ops)
duration = time.perf_counter() - t
assert len(c) == moments
print(duration)
assert duration < 4


Expand All @@ -4936,6 +4937,7 @@ def test_append_speed() -> None:
c.append(xs[q])
duration = time.perf_counter() - t
assert len(c) == moments
print(duration)
assert duration < 5


Expand Down
12 changes: 8 additions & 4 deletions cirq-core/cirq/circuits/frozen_circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,10 +155,6 @@ def all_measurement_key_objs(self) -> frozenset[cirq.MeasurementKey]:
def _measurement_key_objs_(self) -> frozenset[cirq.MeasurementKey]:
return self.all_measurement_key_objs()

@_compat.cached_method
def _control_keys_(self) -> frozenset[cirq.MeasurementKey]:
return super()._control_keys_()

@_compat.cached_method
def are_all_measurements_terminal(self) -> bool:
return super().are_all_measurements_terminal()
Expand Down Expand Up @@ -223,3 +219,11 @@ def to_op(self) -> cirq.CircuitOperation:
from cirq.circuits import CircuitOperation

return CircuitOperation(self)

@cached_property
def measurement_keys(self) -> frozenset[cirq.MeasurementKey]:
return super().measurement_keys

@cached_property
def control_keys(self) -> frozenset[cirq.MeasurementKey]:
return super().control_keys
39 changes: 13 additions & 26 deletions cirq-core/cirq/circuits/moment.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,6 @@ def __init__(
raise ValueError(f'Overlapping operations: {self.operations}')
self._qubit_to_op[q] = op

self._measurement_key_objs: frozenset[cirq.MeasurementKey] | None = None
self._control_keys: frozenset[cirq.MeasurementKey] | None = None
self._tags = tags

@classmethod
Expand Down Expand Up @@ -234,10 +232,8 @@ def with_operation(self, operation: cirq.Operation) -> cirq.Moment:
m._sorted_operations = None
m._qubit_to_op = {**self._qubit_to_op, **{q: operation for q in operation.qubits}}

m._measurement_key_objs = self._measurement_key_objs_().union(
protocols.measurement_key_objs(operation)
)
m._control_keys = self._control_keys_().union(protocols.control_keys(operation))
m.__setattr__('measurement_keys', self.measurement_keys | operation.measurement_keys)
m.__setattr__('control_keys', self.control_keys | operation.control_keys)

return m

Expand Down Expand Up @@ -272,11 +268,12 @@ def with_operations(self, *contents: cirq.OP_TREE) -> cirq.Moment:

m._operations = self._operations + flattened_contents
m._sorted_operations = None
m._measurement_key_objs = self._measurement_key_objs_().union(
set(itertools.chain(*(protocols.measurement_key_objs(op) for op in flattened_contents)))
m.__setattr__(
'measurement_keys',
self.measurement_keys.union(*(op.measurement_keys for op in flattened_contents)),
)
m._control_keys = self._control_keys_().union(
set(itertools.chain(*(protocols.control_keys(op) for op in flattened_contents)))
m.__setattr__(
'control_keys', self.control_keys.union(*(op.control_keys for op in flattened_contents))
)

return m
Expand Down Expand Up @@ -335,23 +332,13 @@ def _with_measurement_key_mapping_(self, key_map: Mapping[str, str]):
for op in self.operations
)

@_compat.cached_method()
def _measurement_key_names_(self) -> frozenset[str]:
return frozenset(str(key) for key in self._measurement_key_objs_())

def _measurement_key_objs_(self) -> frozenset[cirq.MeasurementKey]:
if self._measurement_key_objs is None:
self._measurement_key_objs = frozenset(
key for op in self.operations for key in protocols.measurement_key_objs(op)
)
return self._measurement_key_objs
@cached_property
def measurement_keys(self) -> frozenset[cirq.MeasurementKey]:
return frozenset().union(*(op.measurement_keys for op in self.operations))

def _control_keys_(self) -> frozenset[cirq.MeasurementKey]:
if self._control_keys is None:
self._control_keys = frozenset(
k for op in self.operations for k in protocols.control_keys(op)
)
return self._control_keys
@cached_property
def control_keys(self) -> frozenset[cirq.MeasurementKey]:
return frozenset().union(*(op.control_keys for op in self.operations))

def _sorted_operations_(self) -> tuple[cirq.Operation, ...]:
if self._sorted_operations is None:
Expand Down
19 changes: 8 additions & 11 deletions cirq-core/cirq/circuits/moment_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,19 +413,16 @@ def test_measurement_keys() -> None:
def test_measurement_key_objs_caching() -> None:
q0, q1, q2, q3 = cirq.LineQubit.range(4)
m = cirq.Moment(cirq.measure(q0, key='foo'))
assert m._measurement_key_objs is None
assert m.measurement_keys == {cirq.MeasurementKey(name='foo')}
key_objs = cirq.measurement_key_objs(m)
assert m._measurement_key_objs == key_objs
assert m.measurement_keys == key_objs

# Make sure it gets updated when adding an operation.
m = m.with_operation(cirq.measure(q1, key='bar'))
assert m._measurement_key_objs == {
cirq.MeasurementKey(name='bar'),
cirq.MeasurementKey(name='foo'),
}
assert m.measurement_keys == {cirq.MeasurementKey(name='bar'), cirq.MeasurementKey(name='foo')}
# Or multiple operations.
m = m.with_operations(cirq.measure(q2, key='doh'), cirq.measure(q3, key='baz'))
assert m._measurement_key_objs == {
assert m.measurement_keys == {
cirq.MeasurementKey(name='bar'),
cirq.MeasurementKey(name='foo'),
cirq.MeasurementKey(name='doh'),
Expand All @@ -436,18 +433,18 @@ def test_measurement_key_objs_caching() -> None:
def test_control_keys_caching() -> None:
q0, q1, q2, q3 = cirq.LineQubit.range(4)
m = cirq.Moment(cirq.X(q0).with_classical_controls('foo'))
assert m._control_keys is None
assert m.control_keys == {cirq.MeasurementKey(name='foo')}
keys = cirq.control_keys(m)
assert m._control_keys == keys
assert m.control_keys == keys

# Make sure it gets updated when adding an operation.
m = m.with_operation(cirq.X(q1).with_classical_controls('bar'))
assert m._control_keys == {cirq.MeasurementKey(name='bar'), cirq.MeasurementKey(name='foo')}
assert m.control_keys == {cirq.MeasurementKey(name='bar'), cirq.MeasurementKey(name='foo')}
# Or multiple operations.
m = m.with_operations(
cirq.X(q2).with_classical_controls('doh'), cirq.X(q3).with_classical_controls('baz')
)
assert m._control_keys == {
assert m.control_keys == {
cirq.MeasurementKey(name='bar'),
cirq.MeasurementKey(name='foo'),
cirq.MeasurementKey(name='doh'),
Expand Down
12 changes: 6 additions & 6 deletions cirq-core/cirq/ops/classically_controlled_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from __future__ import annotations

from functools import cached_property
from typing import AbstractSet, Any, Mapping, Sequence, TYPE_CHECKING

import sympy
Expand Down Expand Up @@ -89,7 +90,7 @@ def __init__(
ValueError: If an unsupported gate is being classically
controlled.
"""
if protocols.measurement_key_objs(sub_operation):
if sub_operation.measurement_keys:
raise ValueError(
f'Cannot conditionally run operations with measurements: {sub_operation}'
)
Expand Down Expand Up @@ -222,11 +223,10 @@ def _with_rescoped_keys_(
sub_operation = protocols.with_rescoped_keys(self._sub_operation, path, bindable_keys)
return sub_operation.with_classical_controls(*conds)

def _control_keys_(self) -> frozenset[cirq.MeasurementKey]:
local_keys: frozenset[cirq.MeasurementKey] = frozenset(
k for condition in self._conditions for k in condition.keys
)
return local_keys.union(protocols.control_keys(self._sub_operation))
@cached_property
def control_keys(self) -> frozenset[cirq.MeasurementKey]:
local_keys = frozenset(k for condition in self._conditions for k in condition.keys)
return local_keys | self._sub_operation.control_keys

def _qasm_(self, args: cirq.QasmArgs) -> str | None:
args.validate_version('2.0', '3.0')
Expand Down
5 changes: 5 additions & 0 deletions cirq-core/cirq/ops/gate_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import re
import warnings
from functools import cached_property
from types import NotImplementedType
from typing import (
AbstractSet,
Expand Down Expand Up @@ -380,5 +381,9 @@ def controlled_by(
control_qid_shape=tuple(q.dimension for q in qubits),
).on(*(qubits + self._qubits))

@cached_property
def measurement_keys(self) -> frozenset[cirq.MeasurementKey]:
return self.gate.measurement_keys


TV = TypeVar('TV', bound=raw_types.Gate)
13 changes: 5 additions & 8 deletions cirq-core/cirq/ops/kraus_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

from functools import cached_property
from typing import Any, Iterable, Mapping, TYPE_CHECKING

import numpy as np
Expand Down Expand Up @@ -78,15 +79,11 @@ def num_qubits(self) -> int:
def _kraus_(self):
return self._kraus_ops

def _measurement_key_name_(self) -> str:
@cached_property
def measurement_keys(self) -> frozenset[cirq.MeasurementKey]:
if self._key is None:
return NotImplemented
return str(self._key)

def _measurement_key_obj_(self) -> cirq.MeasurementKey:
if self._key is None:
return NotImplemented
return self._key
return frozenset()
return frozenset([self._key])

def _with_measurement_key_mapping_(self, key_map: Mapping[str, str]):
if self._key is None:
Expand Down
Loading