From d91e15c1e475ab69dc7620f88c277244f8bab36d Mon Sep 17 00:00:00 2001 From: Pavol Juhas Date: Wed, 20 Aug 2025 12:45:12 -0700 Subject: [PATCH 1/4] Add return-type to public functions, mostly tests part 4 No change in the effective code. A batch of ~50 files. Modified files pass ruff check --select=ANN201 Partially implements #4393 --- cirq-core/cirq/ops/measure_util.py | 13 +- cirq-core/cirq/ops/pauli_string_test.py | 4 +- .../has_stabilizer_effect_protocol_test.py | 20 +- .../protocols/has_unitary_protocol_test.py | 6 +- .../cirq/protocols/json_serialization.py | 6 +- .../cirq/protocols/json_serialization_test.py | 62 +++--- .../cirq/protocols/kraus_protocol_test.py | 10 +- .../protocols/measurement_key_protocol.py | 20 +- cirq-core/cirq/protocols/mixture_protocol.py | 2 +- .../cirq/protocols/mixture_protocol_test.py | 14 +- cirq-core/cirq/protocols/mul_protocol_test.py | 8 +- cirq-core/cirq/protocols/phase_protocol.py | 2 +- cirq-core/cirq/protocols/pow_protocol_test.py | 10 +- .../cirq/protocols/resolve_parameters.py | 2 +- .../cirq/protocols/unitary_protocol_test.py | 50 +++-- cirq-core/cirq/qis/clifford_tableau.py | 28 +-- cirq-core/cirq/qis/clifford_tableau_test.py | 34 ++-- cirq-core/cirq/qis/entropy.py | 2 +- cirq-core/cirq/qis/entropy_test.py | 2 +- cirq-core/cirq/qis/states_test.py | 108 +++++------ .../cirq/sim/classical_simulator_test.py | 84 ++++++--- .../cirq/sim/clifford/clifford_simulator.py | 10 +- .../sim/clifford/clifford_simulator_test.py | 99 +++++----- .../sim/clifford/stabilizer_state_ch_form.py | 18 +- .../sim/density_matrix_simulation_state.py | 12 +- .../cirq/sim/density_matrix_simulator.py | 2 +- .../cirq/sim/density_matrix_simulator_test.py | 178 +++++++++--------- .../cirq/sim/density_matrix_utils_test.py | 2 +- cirq-core/cirq/sim/mux_test.py | 52 ++--- .../cirq/sim/simulation_product_state_test.py | 14 +- cirq-core/cirq/sim/simulation_state.py | 8 +- cirq-core/cirq/sim/simulation_state_base.py | 2 +- cirq-core/cirq/sim/simulation_state_test.py | 10 +- cirq-core/cirq/sim/simulator.py | 4 +- cirq-core/cirq/sim/simulator_base_test.py | 84 +++++---- cirq-core/cirq/sim/simulator_test.py | 74 ++++---- cirq-core/cirq/sim/sparse_simulator.py | 2 +- cirq-core/cirq/sim/sparse_simulator_test.py | 174 +++++++++-------- cirq-core/cirq/sim/state_vector.py | 2 +- .../cirq/sim/state_vector_simulation_state.py | 10 +- .../cirq/sim/state_vector_simulator_test.py | 18 +- cirq-core/cirq/sim/state_vector_test.py | 74 ++++---- cirq-core/cirq/study/result_test.py | 40 ++-- cirq-core/cirq/study/sweepable_test.py | 40 ++-- cirq-core/cirq/study/sweeps_test.py | 86 ++++----- .../cirq/testing/circuit_compare_test.py | 30 +-- cirq-core/cirq/testing/consistent_channels.py | 4 +- .../testing/consistent_controlled_gate_op.py | 2 +- .../cirq/testing/consistent_decomposition.py | 6 +- cirq-core/cirq/testing/consistent_phase_by.py | 2 +- cirq-core/cirq/testing/consistent_qasm.py | 4 +- .../cirq/testing/consistent_qasm_test.py | 6 +- 52 files changed, 824 insertions(+), 732 deletions(-) diff --git a/cirq-core/cirq/ops/measure_util.py b/cirq-core/cirq/ops/measure_util.py index 2e31f77853c..5b02164479a 100644 --- a/cirq-core/cirq/ops/measure_util.py +++ b/cirq-core/cirq/ops/measure_util.py @@ -14,12 +14,12 @@ from __future__ import annotations -from typing import Callable, Iterable, overload, TYPE_CHECKING +from typing import Callable, cast, Iterable, overload, TYPE_CHECKING import numpy as np from cirq import protocols -from cirq.ops import pauli_string, raw_types +from cirq.ops import gate_operation, pauli_string, raw_types from cirq.ops.measurement_gate import MeasurementGate from cirq.ops.pauli_measurement_gate import PauliMeasurementGate @@ -96,7 +96,7 @@ def measure( key: str | cirq.MeasurementKey | None = None, invert_mask: tuple[bool, ...] = (), confusion_map: dict[tuple[int, ...], np.ndarray] | None = None, -) -> raw_types.Operation: +) -> gate_operation.GateOperation: pass @@ -107,7 +107,7 @@ def measure( key: str | cirq.MeasurementKey | None = None, invert_mask: tuple[bool, ...] = (), confusion_map: dict[tuple[int, ...], np.ndarray] | None = None, -) -> raw_types.Operation: +) -> gate_operation.GateOperation: pass @@ -116,7 +116,7 @@ def measure( key: str | cirq.MeasurementKey | None = None, invert_mask: tuple[bool, ...] = (), confusion_map: dict[tuple[int, ...], np.ndarray] | None = None, -) -> raw_types.Operation: +) -> gate_operation.GateOperation: """Returns a single MeasurementGate applied to all the given qubits. The qubits are measured in the computational basis. This can also be @@ -161,7 +161,8 @@ def measure( if key is None: key = _default_measurement_key(targets) qid_shape = protocols.qid_shape(targets) - return MeasurementGate(len(targets), key, invert_mask, qid_shape, confusion_map).on(*targets) + gate = MeasurementGate(len(targets), key, invert_mask, qid_shape, confusion_map) + return cast(gate_operation.GateOperation, gate.on(*targets)) M = measure diff --git a/cirq-core/cirq/ops/pauli_string_test.py b/cirq-core/cirq/ops/pauli_string_test.py index 91ccf2c1f92..e85a7efa977 100644 --- a/cirq-core/cirq/ops/pauli_string_test.py +++ b/cirq-core/cirq/ops/pauli_string_test.py @@ -2099,7 +2099,7 @@ def test_pauli_ops_identity_gate_operation(gate1: cirq.Pauli, gate2: cirq.Pauli) assert np.array_equal(subtraction.matrix(), unitary1 - unitary2) -def test_pauli_gate_multiplication_with_power(): +def test_pauli_gate_multiplication_with_power() -> None: q = cirq.LineQubit(0) # Test all Pauli gates (X, Y, Z) @@ -2122,7 +2122,7 @@ def test_pauli_gate_multiplication_with_power(): assert gate**5 * gate**0 == gate**5 -def test_try_interpret_as_pauli_string(): +def test_try_interpret_as_pauli_string() -> None: from cirq.ops.pauli_string import _try_interpret_as_pauli_string q = cirq.LineQubit(0) diff --git a/cirq-core/cirq/protocols/has_stabilizer_effect_protocol_test.py b/cirq-core/cirq/protocols/has_stabilizer_effect_protocol_test.py index 9f7043ac114..82c18426fff 100644 --- a/cirq-core/cirq/protocols/has_stabilizer_effect_protocol_test.py +++ b/cirq-core/cirq/protocols/has_stabilizer_effect_protocol_test.py @@ -14,6 +14,8 @@ from __future__ import annotations +from typing import Any + import numpy as np import cirq @@ -50,40 +52,40 @@ def __init__(self, q: cirq.Qid = cirq.LineQubit(0)): self.q = q @property - def qubits(self): + def qubits(self) -> tuple[cirq.Qid, ...]: return (self.q,) - def with_qubits(self, *new_qubits): # pragma: no cover + def with_qubits(self, *new_qubits) -> cirq.Operation: # pragma: no cover return self class NoOp(EmptyOp): @property - def gate(self): + def gate(self) -> Any: return No() class NoOp1(EmptyOp): @property - def gate(self): + def gate(self) -> Any: return No1() class NoOp2(EmptyOp): @property - def gate(self): + def gate(self) -> Any: return No2() class NoOp3(EmptyOp): @property - def gate(self): + def gate(self) -> Any: return No3() class YesOp(EmptyOp): @property - def gate(self): + def gate(self) -> Any: return Yes() @@ -95,8 +97,8 @@ def _unitary_(self): return self.unitary @property - def qubits(self): - return cirq.LineQubit.range(self.unitary.shape[0].bit_length() - 1) + def qubits(self) -> tuple[cirq.Qid, ...]: + return tuple(cirq.LineQubit.range(self.unitary.shape[0].bit_length() - 1)) class GateDecomposes(cirq.Gate): diff --git a/cirq-core/cirq/protocols/has_unitary_protocol_test.py b/cirq-core/cirq/protocols/has_unitary_protocol_test.py index b29023edbe3..a124ab4b178 100644 --- a/cirq-core/cirq/protocols/has_unitary_protocol_test.py +++ b/cirq-core/cirq/protocols/has_unitary_protocol_test.py @@ -218,8 +218,8 @@ class EmptyOp(cirq.Operation): """A trivial operation that will be recognized as `_apply_unitary_`-able.""" @property - def qubits(self): + def qubits(self) -> tuple[cirq.Qid, ...]: return () - def with_qubits(self, *new_qubits): # pragma: no cover - return self + def with_qubits(self, *new_qubits) -> cirq.Operation: + raise NotImplementedError() diff --git a/cirq-core/cirq/protocols/json_serialization.py b/cirq-core/cirq/protocols/json_serialization.py index b1013d0b40e..9bc643b3291 100644 --- a/cirq-core/cirq/protocols/json_serialization.py +++ b/cirq-core/cirq/protocols/json_serialization.py @@ -217,7 +217,7 @@ def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self._memo: dict[Any, dict] = {} - def default(self, o): + def default(self, o) -> Any: # Object with custom method? if hasattr(o, '_json_dict_'): json_dict = _json_dict_with_cirq_type(o) @@ -519,7 +519,7 @@ def read_json( *, json_text: str | None = None, resolvers: Sequence[JsonResolver] | None = None, -): +) -> Any: """Read a JSON file that optionally contains cirq objects. Args: @@ -605,7 +605,7 @@ def read_json_gzip( *, gzip_raw: bytes | None = None, resolvers: Sequence[JsonResolver] | None = None, -): +) -> Any: """Read a gzipped JSON file that optionally contains cirq objects. Args: diff --git a/cirq-core/cirq/protocols/json_serialization_test.py b/cirq-core/cirq/protocols/json_serialization_test.py index 8e1a3cae7d0..35686375ab9 100644 --- a/cirq-core/cirq/protocols/json_serialization_test.py +++ b/cirq-core/cirq/protocols/json_serialization_test.py @@ -72,7 +72,7 @@ def _get_testspecs_for_modules() -> list[ModuleJsonTestSpec]: MODULE_TEST_SPECS = _get_testspecs_for_modules() -def test_deprecated_cirq_type_in_json_dict(): +def test_deprecated_cirq_type_in_json_dict() -> None: class HasOldJsonDict: # Required for testing serialization of non-cirq objects. __module__ = 'test.noncirq.namespace' @@ -99,7 +99,7 @@ def custom_resolver(name): assert_json_roundtrip_works(HasOldJsonDict(), resolvers=test_resolvers) -def test_line_qubit_roundtrip(): +def test_line_qubit_roundtrip() -> None: q1 = cirq.LineQubit(12) assert_json_roundtrip_works( q1, @@ -110,7 +110,7 @@ def test_line_qubit_roundtrip(): ) -def test_gridqubit_roundtrip(): +def test_gridqubit_roundtrip() -> None: q = cirq.GridQubit(15, 18) assert_json_roundtrip_works( q, @@ -122,7 +122,7 @@ def test_gridqubit_roundtrip(): ) -def test_op_roundtrip(): +def test_op_roundtrip() -> None: q = cirq.LineQubit(5) op1 = cirq.rx(0.123).on(q) assert_json_roundtrip_works( @@ -143,7 +143,7 @@ def test_op_roundtrip(): ) -def test_op_roundtrip_filename(tmpdir): +def test_op_roundtrip_filename(tmpdir) -> None: filename = f'{tmpdir}/op.json' q = cirq.LineQubit(5) op1 = cirq.rx(0.123).on(q) @@ -159,7 +159,7 @@ def test_op_roundtrip_filename(tmpdir): assert op1 == op3 -def test_op_roundtrip_file_obj(tmpdir): +def test_op_roundtrip_file_obj(tmpdir) -> None: filename = f'{tmpdir}/op.json' q = cirq.LineQubit(5) op1 = cirq.rx(0.123).on(q) @@ -179,7 +179,7 @@ def test_op_roundtrip_file_obj(tmpdir): assert op1 == op3 -def test_fail_to_resolve(): +def test_fail_to_resolve() -> None: buffer = io.StringIO() buffer.write( """ @@ -207,7 +207,7 @@ def test_fail_to_resolve(): # deprecation error in testing. It is cleaner to just turn it off than to assert # deprecation for each submodule. @mock.patch.dict(os.environ, clear='CIRQ_TESTING') -def test_shouldnt_be_serialized_no_superfluous(mod_spec: ModuleJsonTestSpec): +def test_shouldnt_be_serialized_no_superfluous(mod_spec: ModuleJsonTestSpec) -> None: # everything in the list should be ignored for a reason names = set(mod_spec.get_all_names()) missing_names = set(mod_spec.should_not_be_serialized).difference(names) @@ -223,7 +223,7 @@ def test_shouldnt_be_serialized_no_superfluous(mod_spec: ModuleJsonTestSpec): # deprecation error in testing. It is cleaner to just turn it off than to assert # deprecation for each submodule. @mock.patch.dict(os.environ, clear='CIRQ_TESTING') -def test_not_yet_serializable_no_superfluous(mod_spec: ModuleJsonTestSpec): +def test_not_yet_serializable_no_superfluous(mod_spec: ModuleJsonTestSpec) -> None: # everything in the list should be ignored for a reason names = set(mod_spec.get_all_names()) missing_names = set(mod_spec.not_yet_serializable).difference(names) @@ -233,7 +233,7 @@ def test_not_yet_serializable_no_superfluous(mod_spec: ModuleJsonTestSpec): @pytest.mark.parametrize('mod_spec', MODULE_TEST_SPECS, ids=repr) -def test_mutually_exclusive_not_serialize_lists(mod_spec: ModuleJsonTestSpec): +def test_mutually_exclusive_not_serialize_lists(mod_spec: ModuleJsonTestSpec) -> None: common = set(mod_spec.should_not_be_serialized) & set(mod_spec.not_yet_serializable) assert len(common) == 0, ( f"Defined in both {mod_spec.name} 'Not yet serializable' " @@ -242,7 +242,7 @@ def test_mutually_exclusive_not_serialize_lists(mod_spec: ModuleJsonTestSpec): @pytest.mark.parametrize('mod_spec', MODULE_TEST_SPECS, ids=repr) -def test_resolver_cache_vs_should_not_serialize(mod_spec: ModuleJsonTestSpec): +def test_resolver_cache_vs_should_not_serialize(mod_spec: ModuleJsonTestSpec) -> None: resolver_cache_types = set([n for (n, _) in mod_spec.get_resolver_cache_types()]) common = set(mod_spec.should_not_be_serialized) & resolver_cache_types @@ -254,7 +254,7 @@ def test_resolver_cache_vs_should_not_serialize(mod_spec: ModuleJsonTestSpec): @pytest.mark.parametrize('mod_spec', MODULE_TEST_SPECS, ids=repr) -def test_resolver_cache_vs_not_yet_serializable(mod_spec: ModuleJsonTestSpec): +def test_resolver_cache_vs_not_yet_serializable(mod_spec: ModuleJsonTestSpec) -> None: resolver_cache_types = set([n for (n, _) in mod_spec.get_resolver_cache_types()]) common = set(mod_spec.not_yet_serializable) & resolver_cache_types @@ -267,14 +267,14 @@ def test_resolver_cache_vs_not_yet_serializable(mod_spec: ModuleJsonTestSpec): ) -def test_builtins(): +def test_builtins() -> None: assert_json_roundtrip_works(True) assert_json_roundtrip_works(1) assert_json_roundtrip_works(1 + 2j) assert_json_roundtrip_works({'test': [123, 5.5], 'key2': 'asdf', '3': None, '0.0': []}) -def test_numpy(): +def test_numpy() -> None: x = np.ones(1)[0] assert_json_roundtrip_works(np.bool_(True)) @@ -295,7 +295,7 @@ def test_numpy(): assert_json_roundtrip_works(np.arange(3)) -def test_pandas(): +def test_pandas() -> None: assert_json_roundtrip_works( pd.DataFrame(data=[[1, 2, 3], [4, 5, 6]], columns=['x', 'y', 'z'], index=[2, 5]) ) @@ -320,7 +320,7 @@ def test_pandas(): ) -def test_sympy(): +def test_sympy() -> None: # Raw values. assert_json_roundtrip_works(sympy.Symbol('theta')) assert_json_roundtrip_works(sympy.Integer(5)) @@ -388,7 +388,7 @@ def _from_json_dict_(cls, name, data_list, data_tuple, data_dict, **kwargs): return cls(name, data_list, tuple(data_tuple), data_dict) -def test_serializable_by_key(): +def test_serializable_by_key() -> None: def custom_resolver(name): if name == 'SBKImpl': return SBKImpl @@ -443,12 +443,12 @@ def _list_public_classes_for_tested_modules(): @pytest.mark.parametrize('mod_spec,cirq_obj_name,cls', _list_public_classes_for_tested_modules()) -def test_json_test_data_coverage(mod_spec: ModuleJsonTestSpec, cirq_obj_name: str, cls): +def test_json_test_data_coverage(mod_spec: ModuleJsonTestSpec, cirq_obj_name: str, cls) -> None: if cirq_obj_name in mod_spec.tested_elsewhere: pytest.skip("Tested elsewhere.") if cirq_obj_name in mod_spec.not_yet_serializable: - return pytest.xfail(reason="Not serializable (yet)") + pytest.xfail(reason="Not serializable (yet)") test_data_path = mod_spec.test_data_path rel_path = test_data_path.relative_to(REPO_ROOT) @@ -529,12 +529,12 @@ def _from_json_dict_(cls, test_type, **kwargs): @pytest.mark.parametrize('mod_spec,cirq_obj_name,cls', _list_public_classes_for_tested_modules()) -def test_type_serialization(mod_spec: ModuleJsonTestSpec, cirq_obj_name: str, cls): +def test_type_serialization(mod_spec: ModuleJsonTestSpec, cirq_obj_name: str, cls) -> None: if cirq_obj_name in mod_spec.tested_elsewhere: pytest.skip("Tested elsewhere.") if cirq_obj_name in mod_spec.not_yet_serializable: - return pytest.xfail(reason="Not serializable (yet)") + pytest.xfail(reason="Not serializable (yet)") if cls is None: pytest.skip(f'No serialization for None-mapped type: {cirq_obj_name}') # pragma: no cover @@ -556,7 +556,7 @@ def custom_resolver(name): assert_json_roundtrip_works(sto, resolvers=test_resolvers) -def test_invalid_type_deserialize(): +def test_invalid_type_deserialize() -> None: def custom_resolver(name): if name == 'SerializableTypeObject': return SerializableTypeObject @@ -571,7 +571,7 @@ def custom_resolver(name): _ = cirq.read_json(json_text=factory_json, resolvers=test_resolvers) -def test_to_from_strings(): +def test_to_from_strings() -> None: x_json_text = """{ "cirq_type": "_PauliX", "exponent": 1.0, @@ -584,7 +584,7 @@ def test_to_from_strings(): cirq.read_json(io.StringIO(), json_text=x_json_text) -def test_to_from_json_gzip(): +def test_to_from_json_gzip() -> None: a, b = cirq.LineQubit.range(2) test_circuit = cirq.Circuit(cirq.H(a), cirq.CX(a, b)) gzip_data = cirq.to_json_gzip(test_circuit) @@ -631,7 +631,7 @@ def assert_repr_and_json_test_data_agree( json_path: pathlib.Path, inward_only: bool, deprecation_deadline: str | None, -): +) -> None: if not repr_path.exists() and not json_path.exists(): return @@ -701,7 +701,7 @@ def assert_repr_and_json_test_data_agree( 'mod_spec, abs_path', [(m, abs_path) for m in MODULE_TEST_SPECS for abs_path in m.all_test_data_keys()], ) -def test_json_and_repr_data(mod_spec: ModuleJsonTestSpec, abs_path: str): +def test_json_and_repr_data(mod_spec: ModuleJsonTestSpec, abs_path: str) -> None: assert_repr_and_json_test_data_agree( mod_spec=mod_spec, repr_path=pathlib.Path(f'{abs_path}.repr'), @@ -718,7 +718,7 @@ def test_json_and_repr_data(mod_spec: ModuleJsonTestSpec, abs_path: str): ) -def test_pathlib_paths(tmpdir): +def test_pathlib_paths(tmpdir) -> None: path = pathlib.Path(tmpdir) / 'op.json' cirq.to_json(cirq.X, path) assert cirq.read_json(path) == cirq.X @@ -746,7 +746,7 @@ def custom_resolver(name): assert_json_roundtrip_works(my_dc, resolvers=[custom_resolver, *cirq.DEFAULT_RESOLVERS]) -def test_numpy_values(): +def test_numpy_values() -> None: assert ( cirq.to_json({'value': np.array(1)}) == """{ @@ -755,7 +755,7 @@ def test_numpy_values(): ) -def test_basic_time_assertions(): +def test_basic_time_assertions() -> None: naive_dt = datetime.datetime.now() utc_dt = naive_dt.astimezone(datetime.timezone.utc) assert naive_dt.timestamp() == utc_dt.timestamp() @@ -768,7 +768,7 @@ def test_basic_time_assertions(): assert naive_dt == re_naive, 'works, as long as you called fromtimestamp from the same timezone' -def test_datetime(): +def test_datetime() -> None: naive_dt = datetime.datetime.now() re_naive_dt = cirq.read_json(json_text=cirq.to_json(naive_dt)) @@ -793,7 +793,7 @@ class _TestAttrsClas: x: int -def test_attrs_json_dict(): +def test_attrs_json_dict() -> None: obj = _TestAttrsClas('test', x=123) js = json_serialization.attrs_json_dict(obj) assert js == {'name': 'test', 'x': 123} diff --git a/cirq-core/cirq/protocols/kraus_protocol_test.py b/cirq-core/cirq/protocols/kraus_protocol_test.py index 9aea7704e41..cd1492cf00c 100644 --- a/cirq-core/cirq/protocols/kraus_protocol_test.py +++ b/cirq-core/cirq/protocols/kraus_protocol_test.py @@ -42,7 +42,7 @@ class NoMethod: assert not cirq.has_kraus(NoMethod()) -def assert_not_implemented(val): +def assert_not_implemented(val) -> None: with pytest.raises(TypeError, match='returned NotImplemented'): _ = cirq.kraus(val) @@ -174,7 +174,7 @@ def test_has_kraus_when_decomposed(decomposed_cls) -> None: assert not cirq.has_kraus(op, allow_decompose=False) -def test_strat_kraus_from_apply_channel_returns_none(): +def test_strat_kraus_from_apply_channel_returns_none() -> None: # Remove _kraus_ and _apply_channel_ methods class NoApplyChannelReset(cirq.ResetChannel): def _kraus_(self): @@ -228,7 +228,7 @@ def _apply_channel_(self, args: cirq.ApplyChannelArgs): np.testing.assert_allclose(actual_super, expected_super, atol=1e-8) -def test_reset_channel_kraus_apply_channel_consistency(): +def test_reset_channel_kraus_apply_channel_consistency() -> None: Reset = cirq.ResetChannel # Original gate gate = Reset() @@ -236,7 +236,7 @@ def test_reset_channel_kraus_apply_channel_consistency(): cirq.testing.assert_consistent_channel(gate) # Remove _kraus_ method - class NoKrausReset(Reset): + class NoKrausReset(cirq.ResetChannel): def _kraus_(self): return NotImplemented @@ -245,7 +245,7 @@ def _kraus_(self): np.testing.assert_allclose(cirq.kraus(gate), cirq.kraus(gate_no_kraus), atol=1e-8) -def test_kraus_channel_with_has_unitary(): +def test_kraus_channel_with_has_unitary() -> None: """CZSWAP has no unitary dunder method but has_unitary returns True.""" op = cirq.CZSWAP.on(cirq.q(1), cirq.q(2)) channels = cirq.kraus(op) diff --git a/cirq-core/cirq/protocols/measurement_key_protocol.py b/cirq-core/cirq/protocols/measurement_key_protocol.py index 07b4ff0b538..97b72e3b837 100644 --- a/cirq-core/cirq/protocols/measurement_key_protocol.py +++ b/cirq-core/cirq/protocols/measurement_key_protocol.py @@ -17,7 +17,7 @@ from __future__ import annotations from types import NotImplementedType -from typing import Any, Mapping, Protocol, TYPE_CHECKING +from typing import Any, Mapping, Protocol, TYPE_CHECKING, TypeVar from cirq import value from cirq._doc import doc_private @@ -25,6 +25,8 @@ if TYPE_CHECKING: import cirq +TDefault = TypeVar('TDefault') + # This is a special indicator value used by the inverse method to determine # whether or not the caller provided a 'default' argument. RaiseTypeErrorIfNotProvided: Any = ([],) @@ -104,7 +106,9 @@ def _with_measurement_key_mapping_(self, key_map: Mapping[str, str]): """ -def measurement_key_obj(val: Any, default: Any = RaiseTypeErrorIfNotProvided): +def measurement_key_obj( + val: Any, default: TDefault = RaiseTypeErrorIfNotProvided +) -> cirq.MeasurementKey | TDefault: """Get the single measurement key object for the given value. Args: @@ -138,7 +142,9 @@ def measurement_key_obj(val: Any, default: Any = RaiseTypeErrorIfNotProvided): raise TypeError(f"Object of type '{type(val)}' had no measurement keys.") -def measurement_key_name(val: Any, default: Any = RaiseTypeErrorIfNotProvided): +def measurement_key_name( + val: Any, default: TDefault = RaiseTypeErrorIfNotProvided +) -> str | TDefault: """Get the single measurement key for the given value. Args: @@ -278,7 +284,7 @@ def is_measurement(val: Any) -> bool: return keys is not NotImplemented and bool(keys) -def with_measurement_key_mapping(val: Any, key_map: Mapping[str, str]): +def with_measurement_key_mapping(val: Any, key_map: Mapping[str, str]) -> Any: """Remaps the target's measurement keys according to the provided key_map. This method can be used to reassign measurement keys at runtime, or to @@ -288,7 +294,7 @@ def with_measurement_key_mapping(val: Any, key_map: Mapping[str, str]): return NotImplemented if getter is None else getter(key_map) -def with_key_path(val: Any, path: tuple[str, ...]): +def with_key_path(val: Any, path: tuple[str, ...]) -> Any: """Adds the path to the target's measurement keys. The path usually refers to an identifier or a list of identifiers from a subcircuit that @@ -299,7 +305,7 @@ def with_key_path(val: Any, path: tuple[str, ...]): return NotImplemented if getter is None else getter(path) -def with_key_path_prefix(val: Any, prefix: tuple[str, ...]): +def with_key_path_prefix(val: Any, prefix: tuple[str, ...]) -> Any: """Prefixes the path to the target's measurement keys. The path usually refers to an identifier or a list of identifiers from a subcircuit that @@ -316,7 +322,7 @@ def with_key_path_prefix(val: Any, prefix: tuple[str, ...]): def with_rescoped_keys( val: Any, path: tuple[str, ...], bindable_keys: frozenset[cirq.MeasurementKey] | None = None -): +) -> Any: """Rescopes any measurement and control keys to the provided path, given the existing keys. The path usually refers to an identifier or a list of identifiers from a subcircuit that diff --git a/cirq-core/cirq/protocols/mixture_protocol.py b/cirq-core/cirq/protocols/mixture_protocol.py index ce52cdcb2e9..2d51349f1e2 100644 --- a/cirq-core/cirq/protocols/mixture_protocol.py +++ b/cirq-core/cirq/protocols/mixture_protocol.py @@ -148,7 +148,7 @@ def has_mixture(val: Any, *, allow_decompose: bool = True) -> bool: return mixture(val, None) is not None -def validate_mixture(supports_mixture: SupportsMixture): +def validate_mixture(supports_mixture: SupportsMixture) -> None: """Validates that the mixture's tuple are valid probabilities.""" mixture_tuple = mixture(supports_mixture, None) if mixture_tuple is None: diff --git a/cirq-core/cirq/protocols/mixture_protocol_test.py b/cirq-core/cirq/protocols/mixture_protocol_test.py index ac6abbf7dfe..04693cee8bb 100644 --- a/cirq-core/cirq/protocols/mixture_protocol_test.py +++ b/cirq-core/cirq/protocols/mixture_protocol_test.py @@ -93,7 +93,7 @@ def _mixture_(self): (ReturnsMixtureOfReturnsUnitary(), ((0.4, np.eye(2)), (0.6, np.eye(2)))), ), ) -def test_objects_with_mixture(val, mixture): +def test_objects_with_mixture(val, mixture) -> None: expected_keys, expected_values = zip(*mixture) keys, values = zip(*cirq.mixture(val)) np.testing.assert_almost_equal(keys, expected_keys) @@ -107,7 +107,7 @@ def test_objects_with_mixture(val, mixture): @pytest.mark.parametrize( 'val', (NoMethod(), ReturnsNotImplemented(), ReturnsNotImplementedUnitary()) ) -def test_objects_with_no_mixture(val): +def test_objects_with_no_mixture(val) -> None: with pytest.raises(TypeError, match="mixture"): _ = cirq.mixture(val) assert cirq.mixture(val, None) is None @@ -116,7 +116,7 @@ def test_objects_with_no_mixture(val): assert cirq.mixture(val, default) == default -def test_has_mixture(): +def test_has_mixture() -> None: assert cirq.has_mixture(ReturnsValidTuple()) assert not cirq.has_mixture(ReturnsNotImplemented()) assert cirq.has_mixture(ReturnsMixtureButNoHasMixture()) @@ -124,7 +124,7 @@ def test_has_mixture(): assert not cirq.has_mixture(ReturnsNotImplementedUnitary()) -def test_valid_mixture(): +def test_valid_mixture() -> None: cirq.validate_mixture(ReturnsValidTuple()) @@ -136,11 +136,11 @@ def test_valid_mixture(): (ReturnsGreaterThanUnityProbability(), 'greater than 1'), ), ) -def test_invalid_mixture(val, message): +def test_invalid_mixture(val, message) -> None: with pytest.raises(ValueError, match=message): cirq.validate_mixture(val) -def test_missing_mixture(): +def test_missing_mixture() -> None: with pytest.raises(TypeError, match='_mixture_'): - cirq.validate_mixture(NoMethod) + cirq.validate_mixture(NoMethod) # type: ignore[arg-type] diff --git a/cirq-core/cirq/protocols/mul_protocol_test.py b/cirq-core/cirq/protocols/mul_protocol_test.py index e92ada6a322..43fb2551bd9 100644 --- a/cirq-core/cirq/protocols/mul_protocol_test.py +++ b/cirq-core/cirq/protocols/mul_protocol_test.py @@ -52,7 +52,7 @@ def __rmul__(self, other): return 8 -def test_equivalent_to_builtin_mul(): +def test_equivalent_to_builtin_mul() -> None: test_vals = [ 0, 1, @@ -76,14 +76,14 @@ def test_equivalent_to_builtin_mul(): c = cirq.mul(a, b, default=None) if c is None: with pytest.raises(TypeError): - _ = a * b + _ = a * b # type: ignore[operator] with pytest.raises(TypeError): _ = cirq.mul(a, b) else: - assert c == a * b + assert c == a * b # type: ignore[operator] -def test_symbol_special_case(): +def test_symbol_special_case() -> None: x = sympy.Symbol('x') assert cirq.mul(x, 1.0) is x assert cirq.mul(1.0, x) is x diff --git a/cirq-core/cirq/protocols/phase_protocol.py b/cirq-core/cirq/protocols/phase_protocol.py index 9d3e1213fe8..bdfa5b9f8ab 100644 --- a/cirq-core/cirq/protocols/phase_protocol.py +++ b/cirq-core/cirq/protocols/phase_protocol.py @@ -49,7 +49,7 @@ def _phase_by_(self, phase_turns: float, qubit_index: int): def phase_by( val: Any, phase_turns: float, qubit_index: int, default: TDefault = RaiseTypeErrorIfNotProvided -): +) -> Any: """Returns a phased version of the effect. For example, an X gate phased by 90 degrees would be a Y gate. diff --git a/cirq-core/cirq/protocols/pow_protocol_test.py b/cirq-core/cirq/protocols/pow_protocol_test.py index 29cbbb65d20..9895d83bac3 100644 --- a/cirq-core/cirq/protocols/pow_protocol_test.py +++ b/cirq-core/cirq/protocols/pow_protocol_test.py @@ -34,7 +34,7 @@ def __pow__(self, exponent) -> int: @pytest.mark.parametrize('val', (NoMethod(), 'text', object(), ReturnsNotImplemented())) -def test_powerless(val): +def test_powerless(val) -> None: assert cirq.pow(val, 5, None) is None assert cirq.pow(val, 2, NotImplemented) is NotImplemented @@ -42,13 +42,13 @@ def test_powerless(val): assert cirq.pow(val, 1, None) is None -def test_pow_error(): +def test_pow_error() -> None: with pytest.raises(TypeError, match="returned NotImplemented"): - _ = cirq.pow(ReturnsNotImplemented(), 3) + _ = cirq.pow(ReturnsNotImplemented(), 3) # type: ignore[call-overload] with pytest.raises(TypeError, match="no __pow__ method"): - _ = cirq.pow(NoMethod(), 3) + _ = cirq.pow(NoMethod(), 3) # type: ignore[call-overload] @pytest.mark.parametrize('val,exponent,out', ((ReturnsExponent(), 2, 2), (1, 2, 1), (2, 3, 8))) -def test_pow_with_result(val, exponent, out): +def test_pow_with_result(val, exponent, out) -> None: assert cirq.pow(val, exponent) == cirq.pow(val, exponent, default=None) == val**exponent == out diff --git a/cirq-core/cirq/protocols/resolve_parameters.py b/cirq-core/cirq/protocols/resolve_parameters.py index 803f74a484e..b27fbe7f273 100644 --- a/cirq-core/cirq/protocols/resolve_parameters.py +++ b/cirq-core/cirq/protocols/resolve_parameters.py @@ -196,6 +196,6 @@ def resolve_parameters( return val -def resolve_parameters_once(val: Any, param_resolver: cirq.ParamResolverOrSimilarType): +def resolve_parameters_once(val: T, param_resolver: cirq.ParamResolverOrSimilarType) -> T: """Performs a single parameter resolution step using the param resolver.""" return resolve_parameters(val, param_resolver, False) diff --git a/cirq-core/cirq/protocols/unitary_protocol_test.py b/cirq-core/cirq/protocols/unitary_protocol_test.py index 778a8f0894f..9c9392a0b08 100644 --- a/cirq-core/cirq/protocols/unitary_protocol_test.py +++ b/cirq-core/cirq/protocols/unitary_protocol_test.py @@ -56,7 +56,7 @@ def _has_unitary_(self): def _unitary_(self): return NotImplemented - def num_qubits(self): + def num_qubits(self) -> int: return 1 @@ -64,7 +64,7 @@ class ReturnsMatrix(cirq.Gate): def _unitary_(self) -> np.ndarray: return m1 - def num_qubits(self): + def num_qubits(self) -> int: return 1 # pragma: no cover @@ -80,7 +80,7 @@ def _unitary_(self) -> np.ndarray | None: return None return m1 - def num_qubits(self): + def num_qubits(self) -> int: return 1 @@ -88,7 +88,7 @@ class DecomposableGate(cirq.Gate): def __init__(self, unitary_value): self.unitary_value = unitary_value - def num_qubits(self): + def num_qubits(self) -> int: return 1 def _decompose_(self, qubits): @@ -154,7 +154,7 @@ def _decompose_(self): yield cirq.X(cirq.LineQubit(3)) -def test_unitary(): +def test_unitary() -> None: with pytest.raises(TypeError, match='unitary effect'): _ = cirq.unitary(NoMethod()) with pytest.raises(TypeError, match='unitary effect'): @@ -186,7 +186,7 @@ def test_unitary(): assert cirq.unitary(FullyImplemented(False), default=None) is None -def test_has_unitary(): +def test_has_unitary() -> None: assert not cirq.has_unitary(NoMethod()) assert not cirq.has_unitary(ReturnsNotImplemented()) assert cirq.has_unitary(ReturnsMatrix()) @@ -215,7 +215,7 @@ def _test_gate_that_allocates_qubits(gate): @pytest.mark.parametrize('ancilla_bitsize', [1, 4]) def test_decompose_gate_that_allocates_clean_qubits( theta: float, phase_state: int, target_bitsize: int, ancilla_bitsize: int -): +) -> None: gate = testing.PhaseUsingCleanAncilla(theta, phase_state, target_bitsize, ancilla_bitsize) _test_gate_that_allocates_qubits(gate) @@ -225,26 +225,38 @@ def test_decompose_gate_that_allocates_clean_qubits( @pytest.mark.parametrize('ancilla_bitsize', [1, 4]) def test_decompose_gate_that_allocates_dirty_qubits( phase_state: int, target_bitsize: int, ancilla_bitsize: int -): +) -> None: gate = testing.PhaseUsingDirtyAncilla(phase_state, target_bitsize, ancilla_bitsize) _test_gate_that_allocates_qubits(gate) -def test_decompose_and_get_unitary(): +def test_decompose_and_get_unitary() -> None: from cirq.protocols.unitary_protocol import _strat_unitary_from_decompose - np.testing.assert_allclose(_strat_unitary_from_decompose(DecomposableOperation((a,), True)), m1) np.testing.assert_allclose( - _strat_unitary_from_decompose(DecomposableOperation((a, b), True)), m2 + _strat_unitary_from_decompose(DecomposableOperation((a,), True)), m1 # type: ignore[arg-type] + ) + np.testing.assert_allclose( + _strat_unitary_from_decompose(DecomposableOperation((a, b), True)), m2 # type: ignore[arg-type] + ) + np.testing.assert_allclose( + _strat_unitary_from_decompose(DecomposableOrder((a, b, c))), m3 # type: ignore[arg-type] + ) + np.testing.assert_allclose( + _strat_unitary_from_decompose(ExampleOperation((a,))), np.eye(2) # type: ignore[arg-type] + ) + np.testing.assert_allclose( + _strat_unitary_from_decompose(ExampleOperation((a, b))), np.eye(4) # type: ignore[arg-type] + ) + np.testing.assert_allclose( + _strat_unitary_from_decompose(ExampleComposite()), np.eye(1) # type: ignore[arg-type] + ) + np.testing.assert_allclose( + _strat_unitary_from_decompose(OtherComposite()), m2 # type: ignore[arg-type] ) - np.testing.assert_allclose(_strat_unitary_from_decompose(DecomposableOrder((a, b, c))), m3) - np.testing.assert_allclose(_strat_unitary_from_decompose(ExampleOperation((a,))), np.eye(2)) - np.testing.assert_allclose(_strat_unitary_from_decompose(ExampleOperation((a, b))), np.eye(4)) - np.testing.assert_allclose(_strat_unitary_from_decompose(ExampleComposite()), np.eye(1)) - np.testing.assert_allclose(_strat_unitary_from_decompose(OtherComposite()), m2) -def test_decomposed_has_unitary(): +def test_decomposed_has_unitary() -> None: # Gates assert cirq.has_unitary(DecomposableGate(True)) assert not cirq.has_unitary(DecomposableGate(False)) @@ -263,7 +275,7 @@ def test_decomposed_has_unitary(): assert cirq.has_unitary(OtherComposite()) -def test_decomposed_unitary(): +def test_decomposed_unitary() -> None: # Gates np.testing.assert_allclose(cirq.unitary(DecomposableGate(True)), m1) @@ -283,7 +295,7 @@ def test_decomposed_unitary(): np.testing.assert_allclose(cirq.unitary(OtherComposite()), m2) -def test_unitary_from_apply_unitary(): +def test_unitary_from_apply_unitary() -> None: class ApplyGate(cirq.Gate): def num_qubits(self): return 1 diff --git a/cirq-core/cirq/qis/clifford_tableau.py b/cirq-core/cirq/qis/clifford_tableau.py index 4f1d2f0b19c..f9b66d95f25 100644 --- a/cirq-core/cirq/qis/clifford_tableau.py +++ b/cirq-core/cirq/qis/clifford_tableau.py @@ -38,7 +38,7 @@ class StabilizerState( """ @abc.abstractmethod - def apply_x(self, axis: int, exponent: float = 1, global_shift: float = 0): + def apply_x(self, axis: int, exponent: float = 1, global_shift: float = 0) -> None: """Apply an X operation to the state. Args: @@ -51,7 +51,7 @@ def apply_x(self, axis: int, exponent: float = 1, global_shift: float = 0): """ @abc.abstractmethod - def apply_y(self, axis: int, exponent: float = 1, global_shift: float = 0): + def apply_y(self, axis: int, exponent: float = 1, global_shift: float = 0) -> None: """Apply an Y operation to the state. Args: @@ -64,7 +64,7 @@ def apply_y(self, axis: int, exponent: float = 1, global_shift: float = 0): """ @abc.abstractmethod - def apply_z(self, axis: int, exponent: float = 1, global_shift: float = 0): + def apply_z(self, axis: int, exponent: float = 1, global_shift: float = 0) -> None: """Apply a Z operation to the state. Args: @@ -77,7 +77,7 @@ def apply_z(self, axis: int, exponent: float = 1, global_shift: float = 0): """ @abc.abstractmethod - def apply_h(self, axis: int, exponent: float = 1, global_shift: float = 0): + def apply_h(self, axis: int, exponent: float = 1, global_shift: float = 0) -> None: """Apply an H operation to the state. Args: @@ -92,7 +92,7 @@ def apply_h(self, axis: int, exponent: float = 1, global_shift: float = 0): @abc.abstractmethod def apply_cz( self, control_axis: int, target_axis: int, exponent: float = 1, global_shift: float = 0 - ): + ) -> None: """Apply a CZ operation to the state. Args: @@ -108,7 +108,7 @@ def apply_cz( @abc.abstractmethod def apply_cx( self, control_axis: int, target_axis: int, exponent: float = 1, global_shift: float = 0 - ): + ) -> None: """Apply a CX operation to the state. Args: @@ -122,7 +122,7 @@ def apply_cx( """ @abc.abstractmethod - def apply_global_phase(self, coefficient: linear_dict.Scalar): + def apply_global_phase(self, coefficient: linear_dict.Scalar) -> None: """Apply a global phase to the state. Args: @@ -561,7 +561,7 @@ def _measure(self, q, prng: np.random.RandomState) -> int: return int(self.rs[p]) - def apply_x(self, axis: int, exponent: float = 1, global_shift: float = 0): + def apply_x(self, axis: int, exponent: float = 1, global_shift: float = 0) -> None: if exponent % 2 == 0: return if exponent % 0.5 != 0.0: @@ -576,7 +576,7 @@ def apply_x(self, axis: int, exponent: float = 1, global_shift: float = 0): self.rs[:] ^= self.xs[:, axis] & self.zs[:, axis] self.xs[:, axis] ^= self.zs[:, axis] - def apply_y(self, axis: int, exponent: float = 1, global_shift: float = 0): + def apply_y(self, axis: int, exponent: float = 1, global_shift: float = 0) -> None: if exponent % 2 == 0: return if exponent % 0.5 != 0.0: @@ -597,7 +597,7 @@ def apply_y(self, axis: int, exponent: float = 1, global_shift: float = 0): self.xs[:, axis].copy(), ) - def apply_z(self, axis: int, exponent: float = 1, global_shift: float = 0): + def apply_z(self, axis: int, exponent: float = 1, global_shift: float = 0) -> None: if exponent % 2 == 0: return if exponent % 0.5 != 0.0: @@ -612,7 +612,7 @@ def apply_z(self, axis: int, exponent: float = 1, global_shift: float = 0): self.rs[:] ^= self.xs[:, axis] & (~self.zs[:, axis]) self.zs[:, axis] ^= self.xs[:, axis] - def apply_h(self, axis: int, exponent: float = 1, global_shift: float = 0): + def apply_h(self, axis: int, exponent: float = 1, global_shift: float = 0) -> None: if exponent % 2 == 0: return if exponent % 1 != 0: @@ -622,7 +622,7 @@ def apply_h(self, axis: int, exponent: float = 1, global_shift: float = 0): def apply_cz( self, control_axis: int, target_axis: int, exponent: float = 1, global_shift: float = 0 - ): + ) -> None: if exponent % 2 == 0: return if exponent % 1 != 0: @@ -647,7 +647,7 @@ def apply_cz( def apply_cx( self, control_axis: int, target_axis: int, exponent: float = 1, global_shift: float = 0 - ): + ) -> None: if exponent % 2 == 0: return if exponent % 1 != 0: @@ -660,7 +660,7 @@ def apply_cx( self.xs[:, target_axis] ^= self.xs[:, control_axis] self.zs[:, control_axis] ^= self.zs[:, target_axis] - def apply_global_phase(self, coefficient: linear_dict.Scalar): + def apply_global_phase(self, coefficient: linear_dict.Scalar) -> None: pass def measure( diff --git a/cirq-core/cirq/qis/clifford_tableau_test.py b/cirq-core/cirq/qis/clifford_tableau_test.py index b89293b545e..144d73f82db 100644 --- a/cirq-core/cirq/qis/clifford_tableau_test.py +++ b/cirq-core/cirq/qis/clifford_tableau_test.py @@ -47,7 +47,7 @@ def _CNOT(table, q1, q2): @pytest.mark.parametrize('num_qubits', range(1, 4)) -def test_tableau_initial_state_string(num_qubits): +def test_tableau_initial_state_string(num_qubits) -> None: for i in range(2**num_qubits): t = cirq.CliffordTableau(initial_state=i, num_qubits=num_qubits) splitted_represent_string = str(t).split('\n') @@ -58,7 +58,7 @@ def test_tableau_initial_state_string(num_qubits): assert splitted_represent_string[n] == expected_string -def test_tableau_invalid_initial_state(): +def test_tableau_invalid_initial_state() -> None: with pytest.raises(ValueError, match="2*num_qubits columns and of type bool."): cirq.CliffordTableau(1, rs=np.zeros(1, dtype=bool)) @@ -73,7 +73,7 @@ def test_tableau_invalid_initial_state(): cirq.CliffordTableau(1, zs=np.zeros(1, dtype=bool)) -def test_stabilizers(): +def test_stabilizers() -> None: # Note: the stabilizers are not unique for one state. We just use the one # produced by the tableau algorithm. # 1. Final state is |1>: Stabalized by -Z. @@ -101,7 +101,7 @@ def test_stabilizers(): assert stabilizers[1] == cirq.DensePauliString('IX', coefficient=1) -def test_destabilizers(): +def test_destabilizers() -> None: # Note: Like stablizers, the destabilizers are not unique for one state, too. # We just use the one produced by the tableau algorithm. # Under the clifford tableau algorithm, there are several properties that the @@ -135,7 +135,7 @@ def test_destabilizers(): assert destabilizers[1] == cirq.DensePauliString('IZ', coefficient=1) -def test_measurement(): +def test_measurement() -> None: repetitions = 500 prng = np.random.RandomState(seed=123456) @@ -195,7 +195,7 @@ def test_measurement(): assert sum(np.asarray(res) == 3) >= (repetitions / 4 * 0.9) -def test_validate_tableau(): +def test_validate_tableau() -> None: num_qubits = 4 for i in range(2**num_qubits): t = cirq.CliffordTableau(initial_state=i, num_qubits=num_qubits) @@ -218,7 +218,7 @@ def test_validate_tableau(): assert not t._validate() -def test_rowsum(): +def test_rowsum() -> None: # Note: rowsum should not apply on two rows that anti-commute each other. t = cirq.CliffordTableau(num_qubits=2) # XI * IX ==> XX @@ -246,7 +246,7 @@ def test_rowsum(): assert t.stabilizers()[1] == cirq.DensePauliString('YX', coefficient=1) -def test_json_dict(): +def test_json_dict() -> None: t = cirq.CliffordTableau._from_json_dict_(n=1, rs=[0, 0], xs=[[1], [0]], zs=[[0], [1]]) assert t.destabilizers()[0] == cirq.DensePauliString('X', coefficient=1) assert t.stabilizers()[0] == cirq.DensePauliString('Z', coefficient=1) @@ -266,7 +266,7 @@ def test_json_dict(): assert json_dict[k] == v -def test_str(): +def test_str() -> None: t = cirq.CliffordTableau(num_qubits=2) splitted_represent_string = str(t).split('\n') assert len(splitted_represent_string) == 2 @@ -288,11 +288,11 @@ def test_str(): assert splitted_represent_string[1] == '+ I Y ' -def test_repr(): +def test_repr() -> None: cirq.testing.assert_equivalent_repr(cirq.CliffordTableau(num_qubits=1)) -def test_str_full(): +def test_str_full() -> None: t = cirq.CliffordTableau(num_qubits=1) expected_str = r"""stable | destable -------+---------- @@ -317,7 +317,7 @@ def test_str_full(): assert t._str_full_() == expected_str -def test_copy(): +def test_copy() -> None: t = cirq.CliffordTableau(num_qubits=3, initial_state=3) new_t = t.copy() @@ -340,7 +340,7 @@ def _three_identical_table(num_qubits): return t1, t2, t3 -def test_tableau_then(): +def test_tableau_then() -> None: t1, t2, expected_t = _three_identical_table(1) assert expected_t == t1.then(t2) @@ -427,7 +427,7 @@ def random_circuit(num_ops, num_qubits, seed=12345): assert expected_t == t1.then(t2) -def test_tableau_matmul(): +def test_tableau_matmul() -> None: t1, t2, expected_t = _three_identical_table(1) _ = [_H(t, 0) for t in (t1, expected_t)] _ = [_H(t, 0) for t in (t2, expected_t)] @@ -443,17 +443,17 @@ def test_tableau_matmul(): t1 @ 21 -def test_tableau_then_with_bad_input(): +def test_tableau_then_with_bad_input() -> None: t1 = cirq.CliffordTableau(1) t2 = cirq.CliffordTableau(2) with pytest.raises(ValueError, match="Mismatched number of qubits of two tableaux: 1 vs 2."): t1.then(t2) with pytest.raises(TypeError): - t1.then(cirq.X) + t1.then(cirq.X) # type: ignore[arg-type] -def test_inverse(): +def test_inverse() -> None: t = cirq.CliffordTableau(num_qubits=1) assert t.inverse() == t diff --git a/cirq-core/cirq/qis/entropy.py b/cirq-core/cirq/qis/entropy.py index acb2c5cc0d7..982dbb055ef 100644 --- a/cirq-core/cirq/qis/entropy.py +++ b/cirq-core/cirq/qis/entropy.py @@ -93,7 +93,7 @@ def _compute_bitstrings_contribution_to_purity(bitstrings: npt.NDArray[np.int8]) def process_renyi_entropy_from_bitstrings( measured_bitstrings: npt.NDArray[np.int8], - subsystem: tuple[int] | None = None, + subsystem: tuple[int, ...] | None = None, pool: ThreadPoolExecutor | None = None, ) -> float: """Compute the Rényi entropy of an array of bitstrings. diff --git a/cirq-core/cirq/qis/entropy_test.py b/cirq-core/cirq/qis/entropy_test.py index c41dcaa1ced..5310c2754a0 100644 --- a/cirq-core/cirq/qis/entropy_test.py +++ b/cirq-core/cirq/qis/entropy_test.py @@ -23,7 +23,7 @@ @pytest.mark.parametrize('pool', [None, ThreadPoolExecutor(max_workers=1)]) -def test_process_renyi_entropy_from_bitstrings(pool): +def test_process_renyi_entropy_from_bitstrings(pool) -> None: bitstrings = np.array( [ [[0, 1, 1, 0], [1, 0, 0, 0], [0, 0, 0, 1], [0, 1, 1, 1]], diff --git a/cirq-core/cirq/qis/states_test.py b/cirq-core/cirq/qis/states_test.py index 777f0487a93..abd994f288d 100644 --- a/cirq-core/cirq/qis/states_test.py +++ b/cirq-core/cirq/qis/states_test.py @@ -21,15 +21,15 @@ import cirq.testing -def assert_dirac_notation_numpy(vec, expected, decimals=2): +def assert_dirac_notation_numpy(vec, expected, decimals=2) -> None: assert cirq.dirac_notation(np.array(vec), decimals=decimals) == expected -def assert_dirac_notation_python(vec, expected, decimals=2): +def assert_dirac_notation_python(vec, expected, decimals=2) -> None: assert cirq.dirac_notation(vec, decimals=decimals) == expected -def assert_valid_density_matrix(matrix, num_qubits=None, qid_shape=None): +def assert_valid_density_matrix(matrix, num_qubits=None, qid_shape=None) -> None: if qid_shape is None and num_qubits is None: num_qubits = 1 np.testing.assert_almost_equal( @@ -40,7 +40,7 @@ def assert_valid_density_matrix(matrix, num_qubits=None, qid_shape=None): ) -def test_quantum_state(): +def test_quantum_state() -> None: state_vector_1 = cirq.one_hot(shape=(4,), dtype=np.complex128) state_tensor_1 = np.reshape(state_vector_1, (2, 2)) density_matrix_1 = np.outer(state_vector_1, np.conj(state_vector_1)) @@ -73,7 +73,7 @@ def test_quantum_state(): np.testing.assert_array_equal(state.state_vector_or_density_matrix(), density_matrix_1) -def test_quantum_state_quantum_state(): +def test_quantum_state_quantum_state() -> None: state_vector_1 = cirq.one_hot(shape=(4,), dtype=np.complex128) quantum_state = cirq.QuantumState(state_vector_1) @@ -96,7 +96,7 @@ def test_quantum_state_quantum_state(): state = cirq.quantum_state(quantum_state, qid_shape=(4,)) -def test_quantum_state_computational_basis_state(): +def test_quantum_state_computational_basis_state() -> None: state = cirq.quantum_state(7, qid_shape=(3, 4)) np.testing.assert_allclose(state.data, cirq.one_hot(index=7, shape=(12,), dtype=np.complex64)) assert state.qid_shape == (3, 4) @@ -125,7 +125,7 @@ def test_quantum_state_computational_basis_state(): _ = cirq.quantum_state((0, 0, 1, 1), qid_shape=(1, 1, 2, 2)) -def test_quantum_state_state_vector_state_tensor(): +def test_quantum_state_state_vector_state_tensor() -> None: state_vector_1 = cirq.one_hot(shape=(4,), dtype=np.complex128) state_tensor_1 = np.reshape(state_vector_1, (2, 2)) @@ -146,7 +146,7 @@ def test_quantum_state_state_vector_state_tensor(): _ = cirq.quantum_state(state_tensor_1, qid_shape=(2, 3)) -def test_quantum_state_density_matrix(): +def test_quantum_state_density_matrix() -> None: density_matrix_1 = np.eye(4, dtype=np.complex64) / 4 state = cirq.quantum_state(density_matrix_1, qid_shape=(4,), copy=True) @@ -159,7 +159,7 @@ def test_quantum_state_density_matrix(): _ = cirq.quantum_state(density_matrix_1, qid_shape=(8,)) -def test_quantum_state_product_state(): +def test_quantum_state_product_state() -> None: q0, q1, q2 = cirq.LineQubit.range(3) product_state_1 = cirq.KET_PLUS(q0) * cirq.KET_PLUS(q1) * cirq.KET_ONE(q2) @@ -172,7 +172,7 @@ def test_quantum_state_product_state(): _ = cirq.quantum_state(product_state_1, qid_shape=(2, 2)) -def test_density_matrix(): +def test_density_matrix() -> None: density_matrix_1 = np.eye(4, dtype=np.complex64) / 4 state_vector_1 = cirq.one_hot(shape=(4,), dtype=np.complex64) @@ -185,7 +185,7 @@ def test_density_matrix(): _ = cirq.density_matrix(state_vector_1) -def test_infer_qid_shape(): +def test_infer_qid_shape() -> None: computational_basis_state_1 = [0, 0, 0, 1] computational_basis_state_2 = [0, 1, 2, 3] computational_basis_state_3 = [0, 1, 2, 4] @@ -261,7 +261,7 @@ def test_infer_qid_shape(): @pytest.mark.parametrize('global_phase', (1, 1j, np.exp(1j))) -def test_bloch_vector_zero_state(global_phase): +def test_bloch_vector_zero_state(global_phase) -> None: zero_state = global_phase * np.array([1, 0]) bloch = cirq.bloch_vector_from_state_vector(zero_state, 0) @@ -270,7 +270,7 @@ def test_bloch_vector_zero_state(global_phase): @pytest.mark.parametrize('global_phase', (1, 1j, np.exp(1j))) -def test_bloch_vector_one_state(global_phase): +def test_bloch_vector_one_state(global_phase) -> None: one_state = global_phase * np.array([0, 1]) bloch = cirq.bloch_vector_from_state_vector(one_state, 0) @@ -279,7 +279,7 @@ def test_bloch_vector_one_state(global_phase): @pytest.mark.parametrize('global_phase', (1, 1j, np.exp(1j))) -def test_bloch_vector_plus_state(global_phase): +def test_bloch_vector_plus_state(global_phase) -> None: sqrt = np.sqrt(0.5) plus_state = global_phase * np.array([sqrt, sqrt]) @@ -289,7 +289,7 @@ def test_bloch_vector_plus_state(global_phase): @pytest.mark.parametrize('global_phase', (1, 1j, np.exp(1j))) -def test_bloch_vector_minus_state(global_phase): +def test_bloch_vector_minus_state(global_phase) -> None: sqrt = np.sqrt(0.5) minus_state = np.array([-1.0j * sqrt, 1.0j * sqrt]) bloch = cirq.bloch_vector_from_state_vector(minus_state, 0) @@ -299,7 +299,7 @@ def test_bloch_vector_minus_state(global_phase): @pytest.mark.parametrize('global_phase', (1, 1j, np.exp(1j))) -def test_bloch_vector_iplus_state(global_phase): +def test_bloch_vector_iplus_state(global_phase) -> None: sqrt = np.sqrt(0.5) iplus_state = global_phase * np.array([sqrt, 1j * sqrt]) @@ -309,7 +309,7 @@ def test_bloch_vector_iplus_state(global_phase): @pytest.mark.parametrize('global_phase', (1, 1j, np.exp(1j))) -def test_bloch_vector_iminus_state(global_phase): +def test_bloch_vector_iminus_state(global_phase) -> None: sqrt = np.sqrt(0.5) iminus_state = global_phase * np.array([sqrt, -1j * sqrt]) @@ -318,7 +318,7 @@ def test_bloch_vector_iminus_state(global_phase): np.testing.assert_array_almost_equal(bloch, desired_simple) -def test_bloch_vector_simple_th_zero(): +def test_bloch_vector_simple_th_zero() -> None: sqrt = np.sqrt(0.5) # State TH|0>. th_state = np.array([sqrt, 0.5 + 0.5j]) @@ -328,7 +328,7 @@ def test_bloch_vector_simple_th_zero(): np.testing.assert_array_almost_equal(bloch, desired_simple) -def test_bloch_vector_equal_sqrt3(): +def test_bloch_vector_equal_sqrt3() -> None: sqrt3 = 1 / np.sqrt(3) test_state = np.array([0.888074, 0.325058 + 0.325058j]) bloch = cirq.bloch_vector_from_state_vector(test_state, 0) @@ -337,7 +337,7 @@ def test_bloch_vector_equal_sqrt3(): np.testing.assert_array_almost_equal(bloch, desired_simple) -def test_bloch_vector_multi_pure(): +def test_bloch_vector_multi_pure() -> None: plus_plus_state = np.array([0.5, 0.5, 0.5, 0.5]) bloch_0 = cirq.bloch_vector_from_state_vector(plus_plus_state, 0) @@ -348,7 +348,7 @@ def test_bloch_vector_multi_pure(): np.testing.assert_array_almost_equal(bloch_0, desired_simple) -def test_bloch_vector_multi_mixed(): +def test_bloch_vector_multi_mixed() -> None: sqrt = np.sqrt(0.5) # Bell state 1/sqrt(2)(|00>+|11>) phi_plus = np.array([sqrt, 0.0, 0.0, sqrt]) @@ -371,7 +371,7 @@ def test_bloch_vector_multi_mixed(): np.testing.assert_array_almost_equal(true_mixed_1, bloch_mixed_1) -def test_bloch_vector_multi_big(): +def test_bloch_vector_multi_big() -> None: five_qubit_plus_state = np.array([0.1767767] * 32) desired_simple = np.array([1, 0, 0]) for qubit in range(5): @@ -379,7 +379,7 @@ def test_bloch_vector_multi_big(): np.testing.assert_array_almost_equal(bloch_i, desired_simple) -def test_bloch_vector_invalid(): +def test_bloch_vector_invalid() -> None: with pytest.raises(ValueError): _ = cirq.bloch_vector_from_state_vector(np.array([0.5, 0.5, 0.5]), 0) with pytest.raises(IndexError): @@ -388,7 +388,7 @@ def test_bloch_vector_invalid(): _ = cirq.bloch_vector_from_state_vector(np.array([0.5, 0.5, 0.5, 0.5]), 2) -def test_density_matrix_from_state_vector(): +def test_density_matrix_from_state_vector() -> None: test_state = np.array( [ 0.0 - 0.35355339j, @@ -428,7 +428,7 @@ def test_density_matrix_from_state_vector(): np.testing.assert_array_almost_equal(rho_zero, true_two) -def test_density_matrix_invalid(): +def test_density_matrix_invalid() -> None: bad_state = np.array([0.5, 0.5, 0.5]) good_state = np.array([0.5, 0.5, 0.5, 0.5]) with pytest.raises(ValueError): @@ -441,7 +441,7 @@ def test_density_matrix_invalid(): _ = cirq.density_matrix_from_state_vector(good_state, [-1]) -def test_dirac_notation(): +def test_dirac_notation() -> None: sqrt = np.sqrt(0.5) exp_pi_2 = 0.5 + 0.5j assert_dirac_notation_numpy([0, 0], "0") @@ -458,7 +458,7 @@ def test_dirac_notation(): assert_dirac_notation_python([0.71j, 0.71j], "0.71j|0⟩ + 0.71j|1⟩") -def test_dirac_notation_partial_state(): +def test_dirac_notation_partial_state() -> None: sqrt = np.sqrt(0.5) exp_pi_2 = 0.5 + 0.5j assert_dirac_notation_numpy([1, 0], "|0⟩") @@ -471,20 +471,20 @@ def test_dirac_notation_partial_state(): assert_dirac_notation_python([0, 0, 0, 1], "|11⟩") -def test_dirac_notation_precision(): +def test_dirac_notation_precision() -> None: sqrt = np.sqrt(0.5) assert_dirac_notation_numpy([sqrt, sqrt], "0.7|0⟩ + 0.7|1⟩", decimals=1) assert_dirac_notation_python([sqrt, sqrt], "0.707|0⟩ + 0.707|1⟩", decimals=3) -def test_dirac_notation_invalid(): +def test_dirac_notation_invalid() -> None: with pytest.raises(ValueError, match='state_vector has incorrect size'): - _ = cirq.dirac_notation([0.0, 0.0, 1.0]) + _ = cirq.dirac_notation(np.array([0.0, 0.0, 1.0])) with pytest.raises(ValueError, match='state_vector has incorrect size'): - _ = cirq.dirac_notation([1.0, 1.0], qid_shape=(3,)) + _ = cirq.dirac_notation(np.array([1.0, 1.0]), qid_shape=(3,)) -def test_to_valid_state_vector(): +def test_to_valid_state_vector() -> None: with pytest.raises(ValueError, match='Computational basis state is out of range'): cirq.to_valid_state_vector(2, 1) np.testing.assert_almost_equal( @@ -515,13 +515,13 @@ def test_to_valid_state_vector(): assert v[0] == 1 -def test_to_valid_state_vector_creates_new_copy(): +def test_to_valid_state_vector_creates_new_copy() -> None: state = np.array([1.0, 0.0, 0.0, 0.0], dtype=np.complex64) out = cirq.to_valid_state_vector(state, 2) assert out is not state -def test_invalid_to_valid_state_vector(): +def test_invalid_to_valid_state_vector() -> None: with pytest.raises(ValueError, match="Please specify"): _ = cirq.to_valid_state_vector(np.array([1])) @@ -532,9 +532,9 @@ def test_invalid_to_valid_state_vector(): with pytest.raises(ValueError): _ = cirq.to_valid_state_vector(5, 2) with pytest.raises(ValueError, match='Invalid quantum state'): - _ = cirq.to_valid_state_vector('0000', 2) + _ = cirq.to_valid_state_vector('0000', 2) # type: ignore[arg-type] with pytest.raises(ValueError, match='Invalid quantum state'): - _ = cirq.to_valid_state_vector('not an int', 2) + _ = cirq.to_valid_state_vector('not an int', 2) # type: ignore[arg-type] with pytest.raises(ValueError, match=r'num_qubits != len\(qid_shape\)'): _ = cirq.to_valid_state_vector(0, 5, qid_shape=(1, 2, 3)) @@ -552,7 +552,7 @@ def test_invalid_to_valid_state_vector(): _ = cirq.to_valid_state_vector(np.array([1, 0], dtype=np.int64), qid_shape=(2, 1)) -def test_validate_normalized_state(): +def test_validate_normalized_state() -> None: cirq.validate_normalized_state_vector(cirq.testing.random_superposition(2), qid_shape=(2,)) cirq.validate_normalized_state_vector( np.array([0.5, 0.5, 0.5, 0.5], dtype=np.complex64), qid_shape=(2, 2) @@ -571,7 +571,7 @@ def test_validate_normalized_state(): ) -def test_validate_density_matrix(): +def test_validate_density_matrix() -> None: cirq.validate_density_matrix(cirq.testing.random_density_matrix(2), qid_shape=(2,)) with pytest.raises(ValueError, match='dtype'): cirq.to_valid_density_matrix( @@ -589,7 +589,7 @@ def test_validate_density_matrix(): ) -def test_to_valid_density_matrix_from_density_matrix(): +def test_to_valid_density_matrix_from_density_matrix() -> None: assert_valid_density_matrix(np.array([[1, 0], [0, 0]])) assert_valid_density_matrix(np.array([[0.5, 0], [0, 0.5]])) assert_valid_density_matrix(np.array([[0.5, 0.5], [0.5, 0.5]])) @@ -615,7 +615,7 @@ def test_to_valid_density_matrix_from_density_matrix(): assert_valid_density_matrix(np.diag([0.2, 0.8, 0, 0]), qid_shape=(4,)) -def test_to_valid_density_matrix_from_density_matrix_tensor(): +def test_to_valid_density_matrix_from_density_matrix_tensor() -> None: np.testing.assert_almost_equal( cirq.to_valid_density_matrix( cirq.one_hot(shape=(2, 2, 2, 2, 2, 2), dtype=np.complex64), num_qubits=3 @@ -630,19 +630,19 @@ def test_to_valid_density_matrix_from_density_matrix_tensor(): ) -def test_to_valid_density_matrix_not_square(): +def test_to_valid_density_matrix_not_square() -> None: with pytest.raises(ValueError, match='shape'): cirq.to_valid_density_matrix(np.array([[1], [0]]), num_qubits=1) -def test_to_valid_density_matrix_size_mismatch_num_qubits(): +def test_to_valid_density_matrix_size_mismatch_num_qubits() -> None: with pytest.raises(ValueError, match='shape'): cirq.to_valid_density_matrix(np.array([[[1, 0], [0, 0]], [[0, 0], [0, 0]]]), num_qubits=2) with pytest.raises(ValueError, match='shape'): cirq.to_valid_density_matrix(np.eye(4) / 4.0, num_qubits=1) -def test_to_valid_density_matrix_not_hermitian(): +def test_to_valid_density_matrix_not_hermitian() -> None: with pytest.raises(ValueError, match='hermitian'): cirq.to_valid_density_matrix(np.array([[0.5, 0.5j], [0.5, 0.5j]]), num_qubits=1) with pytest.raises(ValueError, match='hermitian'): @@ -654,7 +654,7 @@ def test_to_valid_density_matrix_not_hermitian(): ) -def test_to_valid_density_matrix_mismatched_qid_shape(): +def test_to_valid_density_matrix_mismatched_qid_shape() -> None: with pytest.raises(ValueError, match=r'num_qubits != len\(qid_shape\)'): cirq.to_valid_density_matrix(np.eye(4) / 4, num_qubits=1, qid_shape=(2, 2)) with pytest.raises(ValueError, match=r'num_qubits != len\(qid_shape\)'): @@ -663,28 +663,28 @@ def test_to_valid_density_matrix_mismatched_qid_shape(): cirq.to_valid_density_matrix(np.eye(4) / 4) -def test_to_valid_density_matrix_not_unit_trace(): +def test_to_valid_density_matrix_not_unit_trace() -> None: with pytest.raises(ValueError, match='trace 1'): cirq.to_valid_density_matrix(np.array([[1, 0], [0, -0.1]]), num_qubits=1) with pytest.raises(ValueError, match='trace 1'): cirq.to_valid_density_matrix(np.zeros([2, 2]), num_qubits=1) -def test_to_valid_density_matrix_not_positive_semidefinite(): +def test_to_valid_density_matrix_not_positive_semidefinite() -> None: with pytest.raises(ValueError, match='positive semidefinite'): cirq.to_valid_density_matrix( np.array([[0.6, 0.5], [0.5, 0.4]], dtype=np.complex64), num_qubits=1 ) -def test_to_valid_density_matrix_wrong_dtype(): +def test_to_valid_density_matrix_wrong_dtype() -> None: with pytest.raises(ValueError, match='dtype'): cirq.to_valid_density_matrix( np.array([[1, 0], [0, 0]], dtype=np.complex64), num_qubits=1, dtype=np.complex128 ) -def test_to_valid_density_matrix_from_state_vector(): +def test_to_valid_density_matrix_from_state_vector() -> None: np.testing.assert_almost_equal( cirq.to_valid_density_matrix( density_matrix_rep=np.array([1, 0], dtype=np.complex64), num_qubits=1 @@ -713,7 +713,7 @@ def test_to_valid_density_matrix_from_state_vector(): ) -def test_to_valid_density_matrix_from_state_vector_tensor(): +def test_to_valid_density_matrix_from_state_vector_tensor() -> None: np.testing.assert_almost_equal( cirq.to_valid_density_matrix( density_matrix_rep=np.array(np.full((2, 2), 0.5), dtype=np.complex64), num_qubits=2 @@ -722,12 +722,12 @@ def test_to_valid_density_matrix_from_state_vector_tensor(): ) -def test_to_valid_density_matrix_from_state_invalid_state(): +def test_to_valid_density_matrix_from_state_invalid_state() -> None: with pytest.raises(ValueError, match="Invalid quantum state"): cirq.to_valid_density_matrix(np.array([1, 0, 0]), num_qubits=2) -def test_to_valid_density_matrix_from_computational_basis(): +def test_to_valid_density_matrix_from_computational_basis() -> None: np.testing.assert_almost_equal( cirq.to_valid_density_matrix(density_matrix_rep=0, num_qubits=1), np.array([[1, 0], [0, 0]]) ) @@ -743,12 +743,12 @@ def test_to_valid_density_matrix_from_computational_basis(): ) -def test_to_valid_density_matrix_from_state_invalid_computational_basis(): +def test_to_valid_density_matrix_from_state_invalid_computational_basis() -> None: with pytest.raises(ValueError, match="out of range"): cirq.to_valid_density_matrix(-1, num_qubits=2) -def test_one_hot(): +def test_one_hot() -> None: result = cirq.one_hot(shape=4, dtype=np.int32) assert result.dtype == np.int32 np.testing.assert_array_equal(result, [1, 0, 0, 0]) @@ -766,7 +766,7 @@ def test_one_hot(): ) -def test_eye_tensor(): +def test_eye_tensor() -> None: assert np.all(cirq.eye_tensor((), dtype=int) == np.array(1)) assert np.all(cirq.eye_tensor((1,), dtype=int) == np.array([[1]])) assert np.all(cirq.eye_tensor((2,), dtype=int) == np.array([[1, 0], [0, 1]])) # yapf: disable diff --git a/cirq-core/cirq/sim/classical_simulator_test.py b/cirq-core/cirq/sim/classical_simulator_test.py index d50438a0439..f62986114f9 100644 --- a/cirq-core/cirq/sim/classical_simulator_test.py +++ b/cirq-core/cirq/sim/classical_simulator_test.py @@ -23,7 +23,7 @@ import cirq -def test_x_gate(): +def test_x_gate() -> None: q0, q1 = cirq.LineQubit.range(2) circuit = cirq.Circuit() circuit.append(cirq.X(q0)) @@ -31,30 +31,33 @@ def test_x_gate(): circuit.append(cirq.X(q1)) circuit.append(cirq.measure((q0, q1), key='key')) expected_results = {'key': np.array([[[1, 0]]], dtype=np.uint8)} + sim: cirq.ClassicalStateSimulator sim = cirq.ClassicalStateSimulator() results = sim.run(circuit, param_resolver=None, repetitions=1).records np.testing.assert_equal(results, expected_results) -def test_CNOT(): +def test_CNOT() -> None: q0, q1 = cirq.LineQubit.range(2) circuit = cirq.Circuit() circuit.append(cirq.X(q0)) circuit.append(cirq.CNOT(q0, q1)) circuit.append(cirq.measure(q1, key='key')) expected_results = {'key': np.array([[[1]]], dtype=np.uint8)} + sim: cirq.ClassicalStateSimulator sim = cirq.ClassicalStateSimulator() results = sim.run(circuit, param_resolver=None, repetitions=1).records np.testing.assert_equal(results, expected_results) -def test_Swap(): +def test_Swap() -> None: q0, q1 = cirq.LineQubit.range(2) circuit = cirq.Circuit() circuit.append(cirq.X(q0)) circuit.append(cirq.SWAP(q0, q1)) circuit.append(cirq.measure((q0, q1), key='key')) expected_results = {'key': np.array([[[0, 1]]], dtype=np.uint8)} + sim: cirq.ClassicalStateSimulator sim = cirq.ClassicalStateSimulator() results = sim.run(circuit, param_resolver=None, repetitions=1).records np.testing.assert_equal(results, expected_results) @@ -67,10 +70,11 @@ def test_Swap(): for n in np.random.randint(3, 8, size=10) ], ) -def test_qubit_permutation_gate(n, perm, state): +def test_qubit_permutation_gate(n, perm, state) -> None: qubits = cirq.LineQubit.range(n) perm_gate = cirq.QubitPermutationGate(perm) circuit = cirq.Circuit(perm_gate(*qubits), cirq.measure(*qubits, key='key')) + sim: cirq.ClassicalStateSimulator sim = cirq.ClassicalStateSimulator() result = sim.simulate(circuit, initial_state=state) expected = [0] * n @@ -79,7 +83,7 @@ def test_qubit_permutation_gate(n, perm, state): np.testing.assert_equal(result.measurements['key'], expected) -def test_CCNOT(): +def test_CCNOT() -> None: q0, q1, q2 = cirq.LineQubit.range(3) circuit = cirq.Circuit() circuit.append(cirq.CCNOT(q0, q1, q2)) @@ -97,13 +101,14 @@ def test_CCNOT(): expected_results = { 'key': np.array([[[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 1]]], dtype=np.uint8) } + sim: cirq.ClassicalStateSimulator sim = cirq.ClassicalStateSimulator() results = sim.run(circuit, param_resolver=None, repetitions=1).records np.testing.assert_equal(results, expected_results) @pytest.mark.parametrize(['initial_state'], [(list(x),) for x in product([0, 1], repeat=4)]) -def test_CCCX(initial_state): +def test_CCCX(initial_state) -> None: CCCX = cirq.CCNOT.controlled() qubits = cirq.LineQubit.range(4) @@ -114,13 +119,14 @@ def test_CCCX(initial_state): final_state = initial_state.copy() final_state[-1] ^= all(final_state[:-1]) + sim: cirq.ClassicalStateSimulator sim = cirq.ClassicalStateSimulator() results = sim.simulate(circuit, initial_state=initial_state).measurements['key'] np.testing.assert_equal(results, final_state) @pytest.mark.parametrize(['initial_state'], [(list(x),) for x in product([0, 1], repeat=3)]) -def test_CSWAP(initial_state): +def test_CSWAP(initial_state) -> None: CSWAP = cirq.SWAP.controlled() qubits = cirq.LineQubit.range(3) circuit = cirq.Circuit() @@ -134,46 +140,50 @@ def test_CSWAP(initial_state): b, c = c, b final_state = [a, b, c] + sim: cirq.ClassicalStateSimulator sim = cirq.ClassicalStateSimulator() results = sim.simulate(circuit, initial_state=initial_state).measurements['key'] np.testing.assert_equal(results, final_state) -def test_measurement_gate(): +def test_measurement_gate() -> None: q0, q1 = cirq.LineQubit.range(2) circuit = cirq.Circuit() circuit.append(cirq.measure((q0, q1), key='key')) expected_results = {'key': np.array([[[0, 0]]], dtype=np.uint8)} + sim: cirq.ClassicalStateSimulator sim = cirq.ClassicalStateSimulator() results = sim.run(circuit, param_resolver=None, repetitions=1).records np.testing.assert_equal(results, expected_results) -def test_qubit_order(): +def test_qubit_order() -> None: q0, q1 = cirq.LineQubit.range(2) circuit = cirq.Circuit() circuit.append(cirq.CNOT(q0, q1)) circuit.append(cirq.X(q0)) circuit.append(cirq.measure((q0, q1), key='key')) expected_results = {'key': np.array([[[1, 0]]], dtype=np.uint8)} + sim: cirq.ClassicalStateSimulator sim = cirq.ClassicalStateSimulator() results = sim.run(circuit, param_resolver=None, repetitions=1).records np.testing.assert_equal(results, expected_results) -def test_same_key_instances(): +def test_same_key_instances() -> None: q0, q1 = cirq.LineQubit.range(2) circuit = cirq.Circuit() circuit.append(cirq.measure((q0, q1), key='key')) circuit.append(cirq.X(q0)) circuit.append(cirq.measure((q0, q1), key='key')) expected_results = {'key': np.array([[[0, 0], [1, 0]]], dtype=np.uint8)} + sim: cirq.ClassicalStateSimulator sim = cirq.ClassicalStateSimulator() results = sim.run(circuit, param_resolver=None, repetitions=1).records np.testing.assert_equal(results, expected_results) -def test_same_key_instances_order(): +def test_same_key_instances_order() -> None: q0, q1 = cirq.LineQubit.range(2) circuit = cirq.Circuit() circuit.append(cirq.X(q0)) @@ -181,12 +191,13 @@ def test_same_key_instances_order(): circuit.append(cirq.X(q0)) circuit.append(cirq.measure((q1, q0), key='key')) expected_results = {'key': np.array([[[1, 0], [0, 0]]], dtype=np.uint8)} + sim: cirq.ClassicalStateSimulator sim = cirq.ClassicalStateSimulator() results = sim.run(circuit, param_resolver=None, repetitions=1).records np.testing.assert_equal(results, expected_results) -def test_repetitions(): +def test_repetitions() -> None: q0 = cirq.LineQubit.range(1) circuit = cirq.Circuit() circuit.append(cirq.measure(q0, key='key')) @@ -195,12 +206,13 @@ def test_repetitions(): [[[0]], [[0]], [[0]], [[0]], [[0]], [[0]], [[0]], [[0]], [[0]], [[0]]], dtype=np.uint8 ) } + sim: cirq.ClassicalStateSimulator sim = cirq.ClassicalStateSimulator() results = sim.run(circuit, param_resolver=None, repetitions=10).records np.testing.assert_equal(results, expected_results) -def test_multiple_gates(): +def test_multiple_gates() -> None: q0, q1 = cirq.LineQubit.range(2) circuit = cirq.Circuit() circuit.append(cirq.X(q0)) @@ -210,12 +222,13 @@ def test_multiple_gates(): circuit.append(cirq.X(q1)) circuit.append(cirq.measure((q0, q1), key='key')) expected_results = {'key': np.array([[[1, 0]]], dtype=np.uint8)} + sim: cirq.ClassicalStateSimulator sim = cirq.ClassicalStateSimulator() results = sim.run(circuit, param_resolver=None, repetitions=1).records np.testing.assert_equal(results, expected_results) -def test_multiple_gates_order(): +def test_multiple_gates_order() -> None: q0, q1 = cirq.LineQubit.range(2) circuit = cirq.Circuit() circuit.append(cirq.X(q0)) @@ -223,20 +236,22 @@ def test_multiple_gates_order(): circuit.append(cirq.CNOT(q1, q0)) circuit.append(cirq.measure((q0, q1), key='key')) expected_results = {'key': np.array([[[0, 1]]], dtype=np.uint8)} + sim: cirq.ClassicalStateSimulator sim = cirq.ClassicalStateSimulator() results = sim.run(circuit, param_resolver=None, repetitions=1).records np.testing.assert_equal(results, expected_results) -def test_tuple_initial_state(): +def test_tuple_initial_state() -> None: q0, q1, q2 = cirq.LineQubit.range(3) circuit = cirq.Circuit(cirq.X(q0), cirq.measure(q0, q1, q2, key='key')) + sim: cirq.ClassicalStateSimulator sim = cirq.ClassicalStateSimulator() result = sim.simulate(circuit, initial_state=(0, 1, 0)) np.testing.assert_equal(result.measurements['key'], [1, 1, 0]) -def test_param_resolver(): +def test_param_resolver() -> None: gate = cirq.CNOT ** sympy.Symbol('t') q0, q1 = cirq.LineQubit.range(2) circuit = cirq.Circuit() @@ -244,6 +259,7 @@ def test_param_resolver(): circuit.append(gate(q0, q1)) circuit.append(cirq.measure((q1), key='key')) resolver = cirq.ParamResolver({'t': 0}) + sim: cirq.ClassicalStateSimulator sim = cirq.ClassicalStateSimulator() results_with_parameter_zero = sim.run(circuit, param_resolver=resolver, repetitions=1).records resolver = cirq.ParamResolver({'t': 1}) @@ -252,33 +268,37 @@ def test_param_resolver(): np.testing.assert_equal(results_with_parameter_one, {'key': np.array([[[1]]], dtype=np.uint8)}) -def test_unknown_gates(): +def test_unknown_gates() -> None: gate = cirq.Y q = cirq.LineQubit(0) circuit = cirq.Circuit(gate(q), cirq.measure((q), key='key')) + sim: cirq.ClassicalStateSimulator sim = cirq.ClassicalStateSimulator() with pytest.raises(ValueError): _ = sim.run(circuit).records -def test_incompatible_measurements(): +def test_incompatible_measurements() -> None: qs = cirq.LineQubit.range(2) c = cirq.Circuit(cirq.measure(qs, key='key'), cirq.measure(qs[0], key='key')) + sim: cirq.ClassicalStateSimulator sim = cirq.ClassicalStateSimulator() with pytest.raises(ValueError): _ = sim.run(c) -def test_compatible_measurement(): +def test_compatible_measurement() -> None: qs = cirq.LineQubit.range(2) c = cirq.Circuit(cirq.measure(qs, key='key'), cirq.X.on_each(qs), cirq.measure(qs, key='key')) + sim: cirq.ClassicalStateSimulator sim = cirq.ClassicalStateSimulator() res = sim.run(c, repetitions=3).records np.testing.assert_equal(res['key'], np.array([[[0, 0], [1, 1]]] * 3, dtype=np.uint8)) -def test_simulate_sweeps_param_resolver(): +def test_simulate_sweeps_param_resolver() -> None: q0, q1 = cirq.LineQubit.range(2) + simulator: cirq.ClassicalStateSimulator simulator = cirq.ClassicalStateSimulator() for b0 in [0, 1]: for b1 in [0, 1]: @@ -295,18 +315,22 @@ def test_simulate_sweeps_param_resolver(): assert results[1].params == params[1] -def test_create_partial_simulation_state_from_int_with_no_qubits(): +def test_create_partial_simulation_state_from_int_with_no_qubits() -> None: + sim: cirq.ClassicalStateSimulator sim = cirq.ClassicalStateSimulator() initial_state = 5 qs = None classical_data = cirq.value.ClassicalDataDictionaryStore() with pytest.raises(ValueError): sim._create_partial_simulation_state( - initial_state=initial_state, qubits=qs, classical_data=classical_data + initial_state=initial_state, + qubits=qs, # type: ignore[arg-type] + classical_data=classical_data, ) -def test_create_partial_simulation_state_from_invalid_state(): +def test_create_partial_simulation_state_from_invalid_state() -> None: + sim: cirq.ClassicalStateSimulator sim = cirq.ClassicalStateSimulator() initial_state = None qs = cirq.LineQubit.range(2) @@ -317,7 +341,8 @@ def test_create_partial_simulation_state_from_invalid_state(): ) -def test_create_partial_simulation_state_from_int(): +def test_create_partial_simulation_state_from_int() -> None: + sim: cirq.ClassicalStateSimulator sim = cirq.ClassicalStateSimulator() initial_state = 15 qs = cirq.LineQubit.range(4) @@ -329,7 +354,8 @@ def test_create_partial_simulation_state_from_int(): assert result == expected_result -def test_create_valid_partial_simulation_state_from_list(): +def test_create_valid_partial_simulation_state_from_list() -> None: + sim: cirq.ClassicalStateSimulator sim = cirq.ClassicalStateSimulator() initial_state = [1, 1, 1, 1] qs = cirq.LineQubit.range(4) @@ -341,7 +367,8 @@ def test_create_valid_partial_simulation_state_from_list(): assert result == expected_result -def test_create_valid_partial_simulation_state_from_np(): +def test_create_valid_partial_simulation_state_from_np() -> None: + sim: cirq.ClassicalStateSimulator sim = cirq.ClassicalStateSimulator() initial_state = np.array([1, 1]) qs = cirq.LineQubit.range(2) @@ -355,10 +382,11 @@ def test_create_valid_partial_simulation_state_from_np(): np.testing.assert_equal(result, expected_result) -def test_create_invalid_partial_simulation_state_from_np(): +def test_create_invalid_partial_simulation_state_from_np() -> None: initial_state = np.array([[1, 1], [1, 1]]) qs = cirq.LineQubit.range(2) classical_data = cirq.value.ClassicalDataDictionaryStore() + sim: cirq.ClassicalStateSimulator sim = cirq.ClassicalStateSimulator() with pytest.raises(ValueError): @@ -367,7 +395,7 @@ def test_create_invalid_partial_simulation_state_from_np(): ) -def test_noise_model(): +def test_noise_model() -> None: noise_model = cirq.NoiseModel.from_noise_model_like(cirq.depolarize(p=0.01)) with pytest.raises(ValueError): cirq.ClassicalStateSimulator(noise=noise_model) diff --git a/cirq-core/cirq/sim/clifford/clifford_simulator.py b/cirq-core/cirq/sim/clifford/clifford_simulator.py index 841fa282a92..0e7d62f06c3 100644 --- a/cirq-core/cirq/sim/clifford/clifford_simulator.py +++ b/cirq-core/cirq/sim/clifford/clifford_simulator.py @@ -161,7 +161,7 @@ def __init__( sim_state: The qubit:SimulationState lookup for this step. """ super().__init__(sim_state) - self._clifford_state = None + self._clifford_state: CliffordState | None = None def __str__(self) -> str: def bitstring(vals): @@ -183,7 +183,7 @@ def _repr_pretty_(self, p, cycle): p.text("cirq.CliffordSimulatorStateResult(...)" if cycle else self.__str__()) @property - def state(self): + def state(self) -> CliffordState: if self._clifford_state is None: clifford_state = CliffordState(self._qubit_mapping) clifford_state.ch_form = self._merged_sim_state.state.copy() @@ -240,10 +240,10 @@ def __str__(self) -> str: def to_numpy(self) -> np.ndarray: return self.ch_form.to_state_vector() - def state_vector(self): + def state_vector(self) -> np.ndarray: return self.ch_form.state_vector() - def apply_unitary(self, op: cirq.Operation): + def apply_unitary(self, op: cirq.Operation) -> None: ch_form_args = clifford.StabilizerChFormSimulationState( prng=np.random.RandomState(), qubits=self.qubit_map.keys(), initial_state=self.ch_form ) @@ -259,7 +259,7 @@ def apply_measurement( measurements: dict[str, list[int]], prng: np.random.RandomState, collapse_state_vector=True, - ): + ) -> None: if not isinstance(op.gate, cirq.MeasurementGate): raise TypeError( f'apply_measurement only supports cirq.MeasurementGate operations. ' diff --git a/cirq-core/cirq/sim/clifford/clifford_simulator_test.py b/cirq-core/cirq/sim/clifford/clifford_simulator_test.py index ff26c1d2a1a..8a09947667d 100644 --- a/cirq-core/cirq/sim/clifford/clifford_simulator_test.py +++ b/cirq-core/cirq/sim/clifford/clifford_simulator_test.py @@ -15,6 +15,7 @@ from __future__ import annotations import itertools +import random import numpy as np import pytest @@ -24,7 +25,7 @@ import cirq.testing -def test_simulate_no_circuit(): +def test_simulate_no_circuit() -> None: q0, q1 = cirq.LineQubit.range(2) simulator = cirq.CliffordSimulator() circuit = cirq.Circuit() @@ -33,7 +34,7 @@ def test_simulate_no_circuit(): assert len(result.measurements) == 0 -def test_run_no_repetitions(): +def test_run_no_repetitions() -> None: q0 = cirq.LineQubit(0) simulator = cirq.CliffordSimulator() circuit = cirq.Circuit(cirq.H(q0), cirq.measure(q0)) @@ -41,7 +42,7 @@ def test_run_no_repetitions(): assert sum(result.measurements['q(0)']) == 0 -def test_run_hadamard(): +def test_run_hadamard() -> None: q0 = cirq.LineQubit(0) simulator = cirq.CliffordSimulator() circuit = cirq.Circuit(cirq.H(q0), cirq.measure(q0)) @@ -50,7 +51,7 @@ def test_run_hadamard(): assert sum(result.measurements['q(0)'])[0] > 20 -def test_run_GHZ(): +def test_run_GHZ() -> None: (q0, q1) = (cirq.LineQubit(0), cirq.LineQubit(1)) simulator = cirq.CliffordSimulator() circuit = cirq.Circuit(cirq.H(q0), cirq.H(q1), cirq.measure(q0)) @@ -59,7 +60,7 @@ def test_run_GHZ(): assert sum(result.measurements['q(0)'])[0] > 20 -def test_run_correlations(): +def test_run_correlations() -> None: q0, q1 = cirq.LineQubit.range(2) simulator = cirq.CliffordSimulator() circuit = cirq.Circuit(cirq.H(q0), cirq.CNOT(q0, q1), cirq.measure(q0, q1)) @@ -69,7 +70,7 @@ def test_run_correlations(): assert bits[0] == bits[1] -def test_run_parameters_not_resolved(): +def test_run_parameters_not_resolved() -> None: a = cirq.LineQubit(0) simulator = cirq.CliffordSimulator() circuit = cirq.Circuit(cirq.XPowGate(exponent=sympy.Symbol('a'))(a), cirq.measure(a)) @@ -77,7 +78,7 @@ def test_run_parameters_not_resolved(): _ = simulator.run_sweep(circuit, cirq.ParamResolver({})) -def test_simulate_parameters_not_resolved(): +def test_simulate_parameters_not_resolved() -> None: a = cirq.LineQubit(0) simulator = cirq.CliffordSimulator() circuit = cirq.Circuit(cirq.XPowGate(exponent=sympy.Symbol('a'))(a), cirq.measure(a)) @@ -85,7 +86,7 @@ def test_simulate_parameters_not_resolved(): _ = simulator.simulate_sweep(circuit, cirq.ParamResolver({})) -def test_simulate(): +def test_simulate() -> None: q0, q1 = cirq.LineQubit.range(2) simulator = cirq.CliffordSimulator() circuit = cirq.Circuit(cirq.H(q0), cirq.H(q1)) @@ -94,7 +95,7 @@ def test_simulate(): assert len(result.measurements) == 0 -def test_simulate_initial_state(): +def test_simulate_initial_state() -> None: q0, q1 = cirq.LineQubit.range(2) simulator = cirq.CliffordSimulator() for b0 in [0, 1]: @@ -119,7 +120,7 @@ def test_simulate_initial_state(): ) -def test_simulation_state(): +def test_simulation_state() -> None: q0, q1 = cirq.LineQubit.range(2) simulator = cirq.CliffordSimulator() for b0 in [0, 1]: @@ -140,7 +141,7 @@ def test_simulation_state(): ) -def test_simulate_qubit_order(): +def test_simulate_qubit_order() -> None: q0, q1 = cirq.LineQubit.range(2) simulator = cirq.CliffordSimulator() for b0 in [0, 1]: @@ -160,7 +161,7 @@ def test_simulate_qubit_order(): ) -def test_run_measure_multiple_qubits(): +def test_run_measure_multiple_qubits() -> None: q0, q1 = cirq.LineQubit.range(2) simulator = cirq.CliffordSimulator() for b0 in [0, 1]: @@ -175,7 +176,7 @@ def test_run_measure_multiple_qubits(): np.testing.assert_equal(result.measurements, {'q(0),q(1)': [[b0, b1]] * 3}) -def test_simulate_moment_steps(): +def test_simulate_moment_steps() -> None: q0, q1 = cirq.LineQubit.range(2) circuit = cirq.Circuit(cirq.H(q0), cirq.H(q1), cirq.H(q0), cirq.H(q1)) simulator = cirq.CliffordSimulator() @@ -186,7 +187,7 @@ def test_simulate_moment_steps(): np.testing.assert_almost_equal(step.state.to_numpy(), np.array([1, 0, 0, 0])) -def test_simulate_moment_steps_sample(): +def test_simulate_moment_steps_sample() -> None: q0, q1 = cirq.LineQubit.range(2) circuit = cirq.Circuit(cirq.H(q0), cirq.CNOT(q0, q1)) simulator = cirq.CliffordSimulator() @@ -206,7 +207,7 @@ def test_simulate_moment_steps_sample(): @pytest.mark.parametrize('split', [True, False]) -def test_simulate_moment_steps_intermediate_measurement(split): +def test_simulate_moment_steps_intermediate_measurement(split) -> None: q0 = cirq.LineQubit(0) circuit = cirq.Circuit(cirq.H(q0), cirq.measure(q0), cirq.H(q0)) simulator = cirq.CliffordSimulator(split_untangled_states=split) @@ -221,7 +222,7 @@ def test_simulate_moment_steps_intermediate_measurement(split): np.testing.assert_almost_equal(step.state.to_numpy(), expected) -def test_clifford_state_initial_state(): +def test_clifford_state_initial_state() -> None: q0 = cirq.LineQubit(0) with pytest.raises(ValueError, match='Out of range'): _ = cirq.CliffordState(qubit_map={q0: 0}, initial_state=2) @@ -231,7 +232,7 @@ def test_clifford_state_initial_state(): assert state.copy() == state -def test_clifford_trial_result_repr(): +def test_clifford_trial_result_repr() -> None: q0 = cirq.LineQubit(0) final_simulator_state = cirq.StabilizerChFormSimulationState(qubits=[q0]) assert ( @@ -251,7 +252,7 @@ def test_clifford_trial_result_repr(): ) -def test_clifford_trial_result_str(): +def test_clifford_trial_result_str() -> None: q0 = cirq.LineQubit(0) final_simulator_state = cirq.StabilizerChFormSimulationState(qubits=[q0]) assert ( @@ -267,7 +268,7 @@ def test_clifford_trial_result_str(): ) -def test_clifford_trial_result_repr_pretty(): +def test_clifford_trial_result_repr_pretty() -> None: q0 = cirq.LineQubit(0) final_simulator_state = cirq.StabilizerChFormSimulationState(qubits=[q0]) result = cirq.CliffordTrialResult( @@ -280,7 +281,7 @@ def test_clifford_trial_result_repr_pretty(): cirq.testing.assert_repr_pretty(result, "cirq.CliffordTrialResult(...)", cycle=True) -def test_clifford_step_result_str(): +def test_clifford_step_result_str() -> None: q0 = cirq.LineQubit(0) result = next( cirq.CliffordSimulator().simulate_moment_steps(cirq.Circuit(cirq.measure(q0, key='m'))) @@ -288,7 +289,7 @@ def test_clifford_step_result_str(): assert str(result) == "m=0\n|0⟩" -def test_clifford_step_result_repr_pretty(): +def test_clifford_step_result_repr_pretty() -> None: q0 = cirq.LineQubit(0) result = next( cirq.CliffordSimulator().simulate_moment_steps(cirq.Circuit(cirq.measure(q0, key='m'))) @@ -297,40 +298,40 @@ def test_clifford_step_result_repr_pretty(): cirq.testing.assert_repr_pretty(result, "cirq.CliffordSimulatorStateResult(...)", cycle=True) -def test_clifford_step_result_no_measurements_str(): +def test_clifford_step_result_no_measurements_str() -> None: q0 = cirq.LineQubit(0) result = next(cirq.CliffordSimulator().simulate_moment_steps(cirq.Circuit(cirq.I(q0)))) assert str(result) == "|0⟩" -def test_clifford_state_str(): +def test_clifford_state_str() -> None: (q0, q1) = (cirq.LineQubit(0), cirq.LineQubit(1)) state = cirq.CliffordState(qubit_map={q0: 0, q1: 1}) assert str(state) == "|00⟩" -def test_clifford_state_state_vector(): +def test_clifford_state_state_vector() -> None: (q0, q1) = (cirq.LineQubit(0), cirq.LineQubit(1)) state = cirq.CliffordState(qubit_map={q0: 0, q1: 1}) np.testing.assert_equal(state.state_vector(), [1.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j]) -def test_stabilizerStateChForm_H(): +def test_stabilizerStateChForm_H() -> None: (q0, q1) = (cirq.LineQubit(0), cirq.LineQubit(1)) state = cirq.CliffordState(qubit_map={q0: 0, q1: 1}) with pytest.raises(ValueError, match="|y> is equal to |z>"): state.ch_form._H_decompose(0, 1, 1, 0) -def test_clifford_stabilizerStateChForm_repr(): +def test_clifford_stabilizerStateChForm_repr() -> None: (q0, q1) = (cirq.LineQubit(0), cirq.LineQubit(1)) state = cirq.CliffordState(qubit_map={q0: 0, q1: 1}) assert repr(state) == 'StabilizerStateChForm(num_qubits=2)' -def test_clifford_circuit_SHSYSHS(): +def test_clifford_circuit_SHSYSHS() -> None: q0 = cirq.LineQubit(0) circuit = cirq.Circuit( cirq.S(q0), @@ -355,7 +356,7 @@ def test_clifford_circuit_SHSYSHS(): @pytest.mark.parametrize('split', [True, False]) -def test_clifford_circuit(split): +def test_clifford_circuit(split) -> None: (q0, q1) = (cirq.LineQubit(0), cirq.LineQubit(1)) circuit = cirq.Circuit() @@ -363,15 +364,15 @@ def test_clifford_circuit(split): x = np.random.randint(7) if x == 0: - circuit.append(cirq.X(np.random.choice((q0, q1)))) + circuit.append(cirq.X(random.choice((q0, q1)))) elif x == 1: - circuit.append(cirq.Z(np.random.choice((q0, q1)))) + circuit.append(cirq.Z(random.choice((q0, q1)))) elif x == 2: - circuit.append(cirq.Y(np.random.choice((q0, q1)))) + circuit.append(cirq.Y(random.choice((q0, q1)))) elif x == 3: - circuit.append(cirq.S(np.random.choice((q0, q1)))) + circuit.append(cirq.S(random.choice((q0, q1)))) elif x == 4: - circuit.append(cirq.H(np.random.choice((q0, q1)))) + circuit.append(cirq.H(random.choice((q0, q1)))) elif x == 5: circuit.append(cirq.CNOT(q0, q1)) elif x == 6: @@ -388,7 +389,7 @@ def test_clifford_circuit(split): @pytest.mark.parametrize("qubits", [cirq.LineQubit.range(2), cirq.LineQubit.range(4)]) @pytest.mark.parametrize('split', [True, False]) -def test_clifford_circuit_2(qubits, split): +def test_clifford_circuit_2(qubits, split) -> None: circuit = cirq.Circuit() np.random.seed(2) @@ -419,7 +420,7 @@ def test_clifford_circuit_2(qubits, split): @pytest.mark.parametrize('split', [True, False]) -def test_clifford_circuit_3(split): +def test_clifford_circuit_3(split) -> None: # This test tests the simulator on arbitrary 1-qubit Clifford gates. (q0, q1) = (cirq.LineQubit(0), cirq.LineQubit(1)) circuit = cirq.Circuit() @@ -435,7 +436,7 @@ def random_clifford_gate(): if np.random.randint(5) == 0: circuit.append(cirq.CNOT(q0, q1)) else: - circuit.append(random_clifford_gate()(np.random.choice((q0, q1)))) + circuit.append(random_clifford_gate()(random.choice((q0, q1)))) clifford_simulator = cirq.CliffordSimulator(split_untangled_states=split) state_vector_simulator = cirq.Simulator() @@ -447,7 +448,7 @@ def random_clifford_gate(): ) -def test_non_clifford_circuit(): +def test_non_clifford_circuit() -> None: q0 = cirq.LineQubit(0) circuit = cirq.Circuit() circuit.append(cirq.T(q0)) @@ -455,7 +456,7 @@ def test_non_clifford_circuit(): cirq.CliffordSimulator().simulate(circuit) -def test_swap(): +def test_swap() -> None: a, b = cirq.LineQubit.range(2) circuit = cirq.Circuit( cirq.X(a), @@ -472,7 +473,7 @@ def test_swap(): cirq.CliffordSimulator().simulate((cirq.Circuit(cirq.SWAP(a, b) ** 3.5))) -def test_sample_seed(): +def test_sample_seed() -> None: q = cirq.NamedQubit('q') circuit = cirq.Circuit(cirq.H(q), cirq.measure(q)) simulator = cirq.CliffordSimulator(seed=1234) @@ -482,7 +483,7 @@ def test_sample_seed(): assert result_string == '11010001111100100000' -def test_is_supported_operation(): +def test_is_supported_operation() -> None: class MultiQubitOp(cirq.Operation): """Multi-qubit operation with unitary. @@ -514,7 +515,7 @@ def _unitary_(self): assert not cirq.CliffordSimulator.is_supported_operation(MultiQubitOp()) -def test_simulate_pauli_string(): +def test_simulate_pauli_string() -> None: q = cirq.NamedQubit('q') circuit = cirq.Circuit([cirq.PauliString({q: 'X'}), cirq.PauliString({q: 'Z'})]) simulator = cirq.CliffordSimulator() @@ -524,7 +525,7 @@ def test_simulate_pauli_string(): assert np.allclose(result, [0, -1]) -def test_simulate_global_phase_operation(): +def test_simulate_global_phase_operation() -> None: q1, q2 = cirq.LineQubit.range(2) circuit = cirq.Circuit([cirq.I(q1), cirq.I(q2), cirq.global_phase_operation(-1j)]) simulator = cirq.CliffordSimulator() @@ -534,7 +535,7 @@ def test_simulate_global_phase_operation(): assert np.allclose(result, [-1j, 0, 0, 0]) -def test_json_roundtrip(): +def test_json_roundtrip() -> None: (q0, q1, q2) = (cirq.LineQubit(0), cirq.LineQubit(1), cirq.LineQubit(2)) state = cirq.CliffordState(qubit_map={q0: 0, q1: 1, q2: 2}) @@ -557,19 +558,19 @@ def test_json_roundtrip(): assert np.allclose(state.ch_form.state_vector(), state_roundtrip.ch_form.state_vector()) -def test_invalid_apply_measurement(): +def test_invalid_apply_measurement() -> None: q0 = cirq.LineQubit(0) state = cirq.CliffordState(qubit_map={q0: 0}) - measurements = {} + measurements: dict[str, list[int]] = {} with pytest.raises(TypeError, match='only supports cirq.MeasurementGate'): state.apply_measurement(cirq.H(q0), measurements, np.random.RandomState()) assert measurements == {} -def test_valid_apply_measurement(): +def test_valid_apply_measurement() -> None: q0 = cirq.LineQubit(0) state = cirq.CliffordState(qubit_map={q0: 0}, initial_state=1) - measurements = {} + measurements: dict[str, list[int]] = {} state.apply_measurement( cirq.measure(q0), measurements, np.random.RandomState(), collapse_state_vector=False ) @@ -579,7 +580,7 @@ def test_valid_apply_measurement(): @pytest.mark.parametrize('split', [True, False]) -def test_reset(split): +def test_reset(split) -> None: q = cirq.LineQubit(0) c = cirq.Circuit(cirq.X(q), cirq.reset(q), cirq.measure(q, key="out")) sim = cirq.CliffordSimulator(split_untangled_states=split) @@ -590,7 +591,7 @@ def test_reset(split): assert sim.sample(c)["out"][0] == 0 -def test_state_copy(): +def test_state_copy() -> None: sim = cirq.CliffordSimulator() q = cirq.LineQubit(0) diff --git a/cirq-core/cirq/sim/clifford/stabilizer_state_ch_form.py b/cirq-core/cirq/sim/clifford/stabilizer_state_ch_form.py index 4dd54e207d2..5b7599bd865 100644 --- a/cirq-core/cirq/sim/clifford/stabilizer_state_ch_form.py +++ b/cirq-core/cirq/sim/clifford/stabilizer_state_ch_form.py @@ -150,7 +150,7 @@ def _CNOT_right(self, q, r): self.F[:, r] ^= self.F[:, q] self.M[:, q] ^= self.M[:, r] - def update_sum(self, t, u, delta=0, alpha=0): + def update_sum(self, t, u, delta=0, alpha=0) -> None: """Implements the transformation (Proposition 4 in Bravyi et al) ``i^alpha U_H (|t> + i^delta |u>) = omega W_C W_H |s'>`` @@ -255,7 +255,7 @@ def _measure(self, q, prng: np.random.RandomState) -> int: self.project_Z(q, x_i) return x_i - def project_Z(self, q, z): + def project_Z(self, q, z) -> None: """Applies a Z projector on the q'th qubit. Returns: a normalized state with Z_q |psi> = z |psi> @@ -295,7 +295,7 @@ def reindex(self, axes: Sequence[int]) -> cirq.StabilizerStateChForm: copy.omega = self.omega return copy - def apply_x(self, axis: int, exponent: float = 1, global_shift: float = 0): + def apply_x(self, axis: int, exponent: float = 1, global_shift: float = 0) -> None: if exponent % 2 != 0: if exponent % 0.5 != 0.0: raise ValueError('X exponent must be half integer') # pragma: no cover @@ -304,7 +304,7 @@ def apply_x(self, axis: int, exponent: float = 1, global_shift: float = 0): self.apply_h(axis) self.omega *= _phase(exponent, global_shift) - def apply_y(self, axis: int, exponent: float = 1, global_shift: float = 0): + def apply_y(self, axis: int, exponent: float = 1, global_shift: float = 0) -> None: if exponent % 0.5 != 0.0: raise ValueError('Y exponent must be half integer') # pragma: no cover shift = _phase(exponent, global_shift) @@ -325,7 +325,7 @@ def apply_y(self, axis: int, exponent: float = 1, global_shift: float = 0): self.apply_z(axis) self.omega *= shift * (1 - 1j) / (2**0.5) - def apply_z(self, axis: int, exponent: float = 1, global_shift: float = 0): + def apply_z(self, axis: int, exponent: float = 1, global_shift: float = 0) -> None: if exponent % 2 != 0: if exponent % 0.5 != 0.0: raise ValueError('Z exponent must be half integer') # pragma: no cover @@ -337,7 +337,7 @@ def apply_z(self, axis: int, exponent: float = 1, global_shift: float = 0): self.gamma[axis] = (self.gamma[axis] - 1) % 4 self.omega *= _phase(exponent, global_shift) - def apply_h(self, axis: int, exponent: float = 1, global_shift: float = 0): + def apply_h(self, axis: int, exponent: float = 1, global_shift: float = 0) -> None: if exponent % 2 != 0: if exponent % 1 != 0: raise ValueError('H exponent must be integer') # pragma: no cover @@ -357,7 +357,7 @@ def apply_h(self, axis: int, exponent: float = 1, global_shift: float = 0): def apply_cz( self, control_axis: int, target_axis: int, exponent: float = 1, global_shift: float = 0 - ): + ) -> None: if exponent % 2 != 0: if exponent % 1 != 0: raise ValueError('CZ exponent must be integer') # pragma: no cover @@ -369,7 +369,7 @@ def apply_cz( def apply_cx( self, control_axis: int, target_axis: int, exponent: float = 1, global_shift: float = 0 - ): + ) -> None: if exponent % 2 != 0: if exponent % 1 != 0: raise ValueError('CX exponent must be integer') # pragma: no cover @@ -385,7 +385,7 @@ def apply_cx( self.M[control_axis, :] ^= self.M[target_axis, :] self.omega *= _phase(exponent, global_shift) - def apply_global_phase(self, coefficient: value.Scalar): + def apply_global_phase(self, coefficient: value.Scalar) -> None: self.omega *= coefficient def measure( diff --git a/cirq-core/cirq/sim/density_matrix_simulation_state.py b/cirq-core/cirq/sim/density_matrix_simulation_state.py index b75c377270e..e782a728fd2 100644 --- a/cirq-core/cirq/sim/density_matrix_simulation_state.py +++ b/cirq-core/cirq/sim/density_matrix_simulation_state.py @@ -16,7 +16,7 @@ from __future__ import annotations -from typing import Any, Callable, Sequence, TYPE_CHECKING +from typing import Any, Callable, Self, Sequence, TYPE_CHECKING import numpy as np @@ -284,7 +284,7 @@ def __init__( ) super().__init__(state=state, prng=prng, qubits=qubits, classical_data=classical_data) - def add_qubits(self, qubits: Sequence[cirq.Qid]): + def add_qubits(self, qubits: Sequence[cirq.Qid]) -> Self: ret = super().add_qubits(qubits) return ( self.kronecker_product(type(self)(qubits=qubits), inplace=True) @@ -292,7 +292,7 @@ def add_qubits(self, qubits: Sequence[cirq.Qid]): else ret ) - def remove_qubits(self, qubits: Sequence[cirq.Qid]): + def remove_qubits(self, qubits: Sequence[cirq.Qid]) -> Self: ret = super().remove_qubits(qubits) if ret is not NotImplemented: return ret @@ -332,15 +332,15 @@ def __repr__(self) -> str: ) @property - def target_tensor(self): + def target_tensor(self) -> np.ndarray: return self._state._density_matrix @property - def available_buffer(self): + def available_buffer(self) -> list[np.ndarray]: return self._state._buffer @property - def qid_shape(self): + def qid_shape(self) -> tuple[int, ...]: return self._state._qid_shape diff --git a/cirq-core/cirq/sim/density_matrix_simulator.py b/cirq-core/cirq/sim/density_matrix_simulator.py index 60e47af2068..434ea78f50f 100644 --- a/cirq-core/cirq/sim/density_matrix_simulator.py +++ b/cirq-core/cirq/sim/density_matrix_simulator.py @@ -258,7 +258,7 @@ def __init__( self._dtype = dtype self._density_matrix: np.ndarray | None = None - def density_matrix(self, copy=True): + def density_matrix(self, copy=True) -> np.ndarray: """Returns the density matrix at this step in the simulation. The density matrix that is stored in this result is returned in the diff --git a/cirq-core/cirq/sim/density_matrix_simulator_test.py b/cirq-core/cirq/sim/density_matrix_simulator_test.py index 4c648da1db8..d10193d895b 100644 --- a/cirq-core/cirq/sim/density_matrix_simulator_test.py +++ b/cirq-core/cirq/sim/density_matrix_simulator_test.py @@ -48,14 +48,14 @@ def _decompose_(self, qubits): return [chan.on(q) for chan, q in zip(self.channels, qubits)] -def test_invalid_dtype(): +def test_invalid_dtype() -> None: with pytest.raises(ValueError, match='complex'): - cirq.DensityMatrixSimulator(dtype=np.int32) + cirq.DensityMatrixSimulator(dtype=np.int32) # type: ignore[arg-type] @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_no_measurements(dtype: type[np.complexfloating], split: bool): +def test_run_no_measurements(dtype: type[np.complexfloating], split: bool) -> None: q0, q1 = cirq.LineQubit.range(2) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) @@ -66,7 +66,7 @@ def test_run_no_measurements(dtype: type[np.complexfloating], split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_no_results(dtype: type[np.complexfloating], split: bool): +def test_run_no_results(dtype: type[np.complexfloating], split: bool) -> None: q0, q1 = cirq.LineQubit.range(2) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) @@ -77,7 +77,7 @@ def test_run_no_results(dtype: type[np.complexfloating], split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_empty_circuit(dtype: type[np.complexfloating], split: bool): +def test_run_empty_circuit(dtype: type[np.complexfloating], split: bool) -> None: simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) with pytest.raises(ValueError, match="no measurements"): simulator.run(cirq.Circuit()) @@ -85,7 +85,7 @@ def test_run_empty_circuit(dtype: type[np.complexfloating], split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_bit_flips(dtype: type[np.complexfloating], split: bool): +def test_run_bit_flips(dtype: type[np.complexfloating], split: bool) -> None: q0, q1 = cirq.LineQubit.range(2) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) for b0 in [0, 1]: @@ -99,7 +99,7 @@ def test_run_bit_flips(dtype: type[np.complexfloating], split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_bit_flips_with_dephasing(dtype: type[np.complexfloating], split: bool): +def test_run_bit_flips_with_dephasing(dtype: type[np.complexfloating], split: bool) -> None: q0, q1 = cirq.LineQubit.range(2) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) for b0 in [0, 1]: @@ -113,7 +113,7 @@ def test_run_bit_flips_with_dephasing(dtype: type[np.complexfloating], split: bo @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_qudit_increments(dtype: type[np.complexfloating], split: bool): +def test_run_qudit_increments(dtype: type[np.complexfloating], split: bool) -> None: q0, q1 = cirq.LineQid.for_qid_shape((3, 4)) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) for b0 in [0, 1, 2]: @@ -132,7 +132,7 @@ def test_run_qudit_increments(dtype: type[np.complexfloating], split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_not_channel_op(dtype: type[np.complexfloating], split: bool): +def test_run_not_channel_op(dtype: type[np.complexfloating], split: bool) -> None: class BadOp(cirq.Operation): def __init__(self, qubits): self._qubits = qubits @@ -153,7 +153,7 @@ def with_qubits(self, *new_qubits): # pragma: no cover @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_mixture(dtype: type[np.complexfloating], split: bool): +def test_run_mixture(dtype: type[np.complexfloating], split: bool) -> None: q0, q1 = cirq.LineQubit.range(2) circuit = cirq.Circuit(cirq.bit_flip(0.5)(q0), cirq.measure(q0), cirq.measure(q1)) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) @@ -167,7 +167,7 @@ def test_run_mixture(dtype: type[np.complexfloating], split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_qudit_mixture(dtype: type[np.complexfloating], split: bool): +def test_run_qudit_mixture(dtype: type[np.complexfloating], split: bool) -> None: q0, q1 = cirq.LineQid.for_qid_shape((3, 2)) mixture = _TestMixture( [ @@ -188,7 +188,7 @@ def test_run_qudit_mixture(dtype: type[np.complexfloating], split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_channel(dtype: type[np.complexfloating], split: bool): +def test_run_channel(dtype: type[np.complexfloating], split: bool) -> None: q0, q1 = cirq.LineQubit.range(2) circuit = cirq.Circuit( cirq.X(q0), cirq.amplitude_damp(0.5)(q0), cirq.measure(q0), cirq.measure(q1) @@ -205,7 +205,7 @@ def test_run_channel(dtype: type[np.complexfloating], split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_decomposable_channel(dtype: type[np.complexfloating], split: bool): +def test_run_decomposable_channel(dtype: type[np.complexfloating], split: bool) -> None: q0, q1 = cirq.LineQubit.range(2) circuit = cirq.Circuit( @@ -226,7 +226,7 @@ def test_run_decomposable_channel(dtype: type[np.complexfloating], split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_qudit_channel(dtype: type[np.complexfloating], split: bool): +def test_run_qudit_channel(dtype: type[np.complexfloating], split: bool) -> None: class TestChannel(cirq.Gate): def _qid_shape_(self): return (3,) @@ -258,7 +258,7 @@ def _kraus_(self): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_measure_at_end_no_repetitions(dtype: type[np.complexfloating], split: bool): +def test_run_measure_at_end_no_repetitions(dtype: type[np.complexfloating], split: bool) -> None: q0, q1 = cirq.LineQubit.range(2) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) with mock.patch.object(simulator, '_core_iterator', wraps=simulator._core_iterator) as mock_sim: @@ -277,7 +277,7 @@ def test_run_measure_at_end_no_repetitions(dtype: type[np.complexfloating], spli @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_repetitions_measure_at_end(dtype: type[np.complexfloating], split: bool): +def test_run_repetitions_measure_at_end(dtype: type[np.complexfloating], split: bool) -> None: q0, q1 = cirq.LineQubit.range(2) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) with mock.patch.object(simulator, '_core_iterator', wraps=simulator._core_iterator) as mock_sim: @@ -296,7 +296,9 @@ def test_run_repetitions_measure_at_end(dtype: type[np.complexfloating], split: @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_qudits_repetitions_measure_at_end(dtype: type[np.complexfloating], split: bool): +def test_run_qudits_repetitions_measure_at_end( + dtype: type[np.complexfloating], split: bool +) -> None: q0, q1 = cirq.LineQid.for_qid_shape((2, 3)) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) with mock.patch.object(simulator, '_core_iterator', wraps=simulator._core_iterator) as mock_sim: @@ -318,7 +320,9 @@ def test_run_qudits_repetitions_measure_at_end(dtype: type[np.complexfloating], @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_measurement_not_terminal_no_repetitions(dtype: type[np.complexfloating], split: bool): +def test_run_measurement_not_terminal_no_repetitions( + dtype: type[np.complexfloating], split: bool +) -> None: q0, q1 = cirq.LineQubit.range(2) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) with mock.patch.object(simulator, '_core_iterator', wraps=simulator._core_iterator) as mock_sim: @@ -342,7 +346,9 @@ def test_run_measurement_not_terminal_no_repetitions(dtype: type[np.complexfloat @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_repetitions_measurement_not_terminal(dtype: type[np.complexfloating], split: bool): +def test_run_repetitions_measurement_not_terminal( + dtype: type[np.complexfloating], split: bool +) -> None: q0, q1 = cirq.LineQubit.range(2) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) with mock.patch.object(simulator, '_core_iterator', wraps=simulator._core_iterator) as mock_sim: @@ -368,7 +374,7 @@ def test_run_repetitions_measurement_not_terminal(dtype: type[np.complexfloating @pytest.mark.parametrize('split', [True, False]) def test_run_qudits_repetitions_measurement_not_terminal( dtype: type[np.complexfloating], split: bool -): +) -> None: q0, q1 = cirq.LineQid.for_qid_shape((2, 3)) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) with mock.patch.object(simulator, '_core_iterator', wraps=simulator._core_iterator) as mock_sim: @@ -392,7 +398,7 @@ def test_run_qudits_repetitions_measurement_not_terminal( @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_param_resolver(dtype: type[np.complexfloating], split: bool): +def test_run_param_resolver(dtype: type[np.complexfloating], split: bool) -> None: q0, q1 = cirq.LineQubit.range(2) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) for b0 in [0, 1]: @@ -411,7 +417,7 @@ def test_run_param_resolver(dtype: type[np.complexfloating], split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_correlations(dtype: type[np.complexfloating], split: bool): +def test_run_correlations(dtype: type[np.complexfloating], split: bool) -> None: q0, q1 = cirq.LineQubit.range(2) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) circuit = cirq.Circuit(cirq.H(q0), cirq.CNOT(q0, q1), cirq.measure(q0, q1)) @@ -423,7 +429,7 @@ def test_run_correlations(dtype: type[np.complexfloating], split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_measure_multiple_qubits(dtype: type[np.complexfloating], split: bool): +def test_run_measure_multiple_qubits(dtype: type[np.complexfloating], split: bool) -> None: q0, q1 = cirq.LineQubit.range(2) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) for b0 in [0, 1]: @@ -435,7 +441,7 @@ def test_run_measure_multiple_qubits(dtype: type[np.complexfloating], split: boo @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_measure_multiple_qudits(dtype: type[np.complexfloating], split: bool): +def test_run_measure_multiple_qudits(dtype: type[np.complexfloating], split: bool) -> None: q0, q1 = cirq.LineQid.for_qid_shape((2, 3)) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) for b0 in [0, 1]: @@ -449,7 +455,7 @@ def test_run_measure_multiple_qudits(dtype: type[np.complexfloating], split: boo @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_sweeps_param_resolvers(dtype: type[np.complexfloating], split: bool): +def test_run_sweeps_param_resolvers(dtype: type[np.complexfloating], split: bool) -> None: q0, q1 = cirq.LineQubit.range(2) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) for b0 in [0, 1]: @@ -475,7 +481,7 @@ def test_run_sweeps_param_resolvers(dtype: type[np.complexfloating], split: bool @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate_no_circuit(dtype: type[np.complexfloating], split: bool): +def test_simulate_no_circuit(dtype: type[np.complexfloating], split: bool) -> None: q0, q1 = cirq.LineQubit.range(2) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) circuit = cirq.Circuit() @@ -488,7 +494,7 @@ def test_simulate_no_circuit(dtype: type[np.complexfloating], split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate(dtype: type[np.complexfloating], split: bool): +def test_simulate(dtype: type[np.complexfloating], split: bool) -> None: q0, q1 = cirq.LineQubit.range(2) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) circuit = cirq.Circuit(cirq.H(q0), cirq.H(q1)) @@ -499,7 +505,7 @@ def test_simulate(dtype: type[np.complexfloating], split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate_qudits(dtype: type[np.complexfloating], split: bool): +def test_simulate_qudits(dtype: type[np.complexfloating], split: bool) -> None: q0, q1 = cirq.LineQid.for_qid_shape((2, 3)) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) circuit = cirq.Circuit(cirq.H(q0), cirq.XPowGate(dimension=3)(q1) ** 2) @@ -514,7 +520,7 @@ def test_simulate_qudits(dtype: type[np.complexfloating], split: bool): @pytest.mark.parametrize('split', [True, False]) def test_reset_one_qubit_does_not_affect_partial_trace_of_other_qubits( dtype: type[np.complexfloating], split: bool -): +) -> None: q0, q1 = cirq.LineQubit.range(2) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) circuit = cirq.Circuit(cirq.H(q0), cirq.CX(q0, q1), cirq.reset(q0)) @@ -532,7 +538,9 @@ def test_reset_one_qubit_does_not_affect_partial_trace_of_other_qubits( [cirq.testing.random_circuit(cirq.LineQubit.range(4), 5, 0.9) for _ in range(20)], ), ) -def test_simulate_compare_to_state_vector_simulator(dtype: type[np.complexfloating], circuit): +def test_simulate_compare_to_state_vector_simulator( + dtype: type[np.complexfloating], circuit +) -> None: qubits = cirq.LineQubit.range(4) pure_result = ( cirq.Simulator(dtype=dtype).simulate(circuit, qubit_order=qubits).density_matrix_of() @@ -548,7 +556,7 @@ def test_simulate_compare_to_state_vector_simulator(dtype: type[np.complexfloati @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate_bit_flips(dtype: type[np.complexfloating], split: bool): +def test_simulate_bit_flips(dtype: type[np.complexfloating], split: bool) -> None: q0, q1 = cirq.LineQubit.range(2) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) for b0 in [0, 1]: @@ -565,7 +573,7 @@ def test_simulate_bit_flips(dtype: type[np.complexfloating], split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate_qudit_increments(dtype: type[np.complexfloating], split: bool): +def test_simulate_qudit_increments(dtype: type[np.complexfloating], split: bool) -> None: q0, q1 = cirq.LineQid.for_qid_shape((2, 3)) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) for b0 in [0, 1]: @@ -593,7 +601,7 @@ def test_simulate_initial_state( dtype: type[np.complexfloating], split: bool, initial_state: int | cirq.DensityMatrixSimulationState, -): +) -> None: q0, q1 = cirq.LineQubit.range(2) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) for b0 in [0, 1]: @@ -607,7 +615,7 @@ def test_simulate_initial_state( @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulation_state(dtype: type[np.complexfloating], split: bool): +def test_simulation_state(dtype: type[np.complexfloating], split: bool) -> None: q0, q1 = cirq.LineQubit.range(2) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) for b0 in [0, 1]: @@ -620,7 +628,7 @@ def test_simulation_state(dtype: type[np.complexfloating], split: bool): np.testing.assert_equal(result.final_density_matrix, expected_density_matrix) -def test_simulate_tps_initial_state(): +def test_simulate_tps_initial_state() -> None: q0, q1 = cirq.LineQubit.range(2) simulator = cirq.DensityMatrixSimulator() for b0 in [0, 1]: @@ -634,7 +642,7 @@ def test_simulate_tps_initial_state(): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate_initial_qudit_state(dtype: type[np.complexfloating], split: bool): +def test_simulate_initial_qudit_state(dtype: type[np.complexfloating], split: bool) -> None: q0, q1 = cirq.LineQid.for_qid_shape((3, 4)) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) for b0 in [0, 1, 2]: @@ -654,7 +662,7 @@ def test_simulate_initial_qudit_state(dtype: type[np.complexfloating], split: bo @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate_qubit_order(dtype: type[np.complexfloating], split: bool): +def test_simulate_qubit_order(dtype: type[np.complexfloating], split: bool) -> None: q0, q1 = cirq.LineQubit.range(2) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) for b0 in [0, 1]: @@ -668,7 +676,7 @@ def test_simulate_qubit_order(dtype: type[np.complexfloating], split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate_param_resolver(dtype: type[np.complexfloating], split: bool): +def test_simulate_param_resolver(dtype: type[np.complexfloating], split: bool) -> None: q0, q1 = cirq.LineQubit.range(2) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) for b0 in [0, 1]: @@ -687,7 +695,7 @@ def test_simulate_param_resolver(dtype: type[np.complexfloating], split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate_measure_multiple_qubits(dtype: type[np.complexfloating], split: bool): +def test_simulate_measure_multiple_qubits(dtype: type[np.complexfloating], split: bool) -> None: q0, q1 = cirq.LineQubit.range(2) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) for b0 in [0, 1]: @@ -699,7 +707,7 @@ def test_simulate_measure_multiple_qubits(dtype: type[np.complexfloating], split @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate_measure_multiple_qudits(dtype: type[np.complexfloating], split: bool): +def test_simulate_measure_multiple_qudits(dtype: type[np.complexfloating], split: bool) -> None: q0, q1 = cirq.LineQid.for_qid_shape((2, 3)) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) for b0 in [0, 1]: @@ -713,7 +721,7 @@ def test_simulate_measure_multiple_qudits(dtype: type[np.complexfloating], split @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate_sweeps_param_resolver(dtype: type[np.complexfloating], split: bool): +def test_simulate_sweeps_param_resolver(dtype: type[np.complexfloating], split: bool) -> None: q0, q1 = cirq.LineQubit.range(2) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) for b0 in [0, 1]: @@ -740,7 +748,7 @@ def test_simulate_sweeps_param_resolver(dtype: type[np.complexfloating], split: @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate_moment_steps(dtype: type[np.complexfloating], split: bool): +def test_simulate_moment_steps(dtype: type[np.complexfloating], split: bool) -> None: q0, q1 = cirq.LineQubit.range(2) circuit = cirq.Circuit(cirq.H(q0), cirq.H(q1), cirq.H(q0), cirq.H(q1)) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) @@ -754,7 +762,7 @@ def test_simulate_moment_steps(dtype: type[np.complexfloating], split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate_moment_steps_qudits(dtype: type[np.complexfloating], split: bool): +def test_simulate_moment_steps_qudits(dtype: type[np.complexfloating], split: bool) -> None: q0, q1 = cirq.LineQid.for_qid_shape((2, 3)) circuit = cirq.Circuit( cirq.XPowGate(dimension=2)(q0), @@ -775,7 +783,7 @@ def test_simulate_moment_steps_qudits(dtype: type[np.complexfloating], split: bo @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate_moment_steps_empty_circuit(dtype: type[np.complexfloating], split: bool): +def test_simulate_moment_steps_empty_circuit(dtype: type[np.complexfloating], split: bool) -> None: circuit = cirq.Circuit() simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) for step in simulator.simulate_moment_steps(circuit): @@ -786,7 +794,7 @@ def test_simulate_moment_steps_empty_circuit(dtype: type[np.complexfloating], sp @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate_moment_steps_sample(dtype: type[np.complexfloating], split: bool): +def test_simulate_moment_steps_sample(dtype: type[np.complexfloating], split: bool) -> None: q0, q1 = cirq.LineQubit.range(2) circuit = cirq.Circuit(cirq.H(q0), cirq.CNOT(q0, q1)) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) @@ -807,7 +815,7 @@ def test_simulate_moment_steps_sample(dtype: type[np.complexfloating], split: bo @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate_moment_steps_sample_qudits(dtype: type[np.complexfloating], split: bool): +def test_simulate_moment_steps_sample_qudits(dtype: type[np.complexfloating], split: bool) -> None: class TestGate(cirq.Gate): """Swaps the 2nd qid |0> and |2> states when the 1st is |1>.""" @@ -838,7 +846,7 @@ def _apply_unitary_(self, args: cirq.ApplyUnitaryArgs): @pytest.mark.parametrize('split', [True, False]) def test_simulate_moment_steps_intermediate_measurement( dtype: type[np.complexfloating], split: bool -): +) -> None: q0 = cirq.LineQubit(0) circuit = cirq.Circuit(cirq.H(q0), cirq.measure(q0), cirq.H(q0)) simulator = cirq.DensityMatrixSimulator(dtype=dtype, split_untangled_states=split) @@ -854,7 +862,7 @@ def test_simulate_moment_steps_intermediate_measurement( @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) -def test_simulate_expectation_values(dtype): +def test_simulate_expectation_values(dtype) -> None: # Compare with test_expectation_from_state_vector_two_qubit_states # in file: cirq/ops/linear_combinations_test.py q0, q1 = cirq.LineQubit.range(2) @@ -878,7 +886,7 @@ def test_simulate_expectation_values(dtype): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) -def test_simulate_noisy_expectation_values(dtype): +def test_simulate_noisy_expectation_values(dtype) -> None: q0 = cirq.LineQubit(0) psums = [cirq.Z(q0), cirq.X(q0)] c1 = cirq.Circuit(cirq.X(q0), cirq.amplitude_damp(gamma=0.1).on(q0)) @@ -896,7 +904,7 @@ def test_simulate_noisy_expectation_values(dtype): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) -def test_simulate_expectation_values_terminal_measure(dtype): +def test_simulate_expectation_values_terminal_measure(dtype) -> None: q0 = cirq.LineQubit(0) circuit = cirq.Circuit(cirq.H(q0), cirq.measure(q0)) obs = cirq.Z(q0) @@ -933,7 +941,7 @@ def test_simulate_expectation_values_terminal_measure(dtype): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) -def test_simulate_expectation_values_qubit_order(dtype): +def test_simulate_expectation_values_qubit_order(dtype) -> None: q0, q1, q2 = cirq.LineQubit.range(3) circuit = cirq.Circuit(cirq.H(q0), cirq.H(q1), cirq.X(q2)) obs = cirq.X(q0) + cirq.X(q1) - cirq.Z(q2) @@ -947,7 +955,7 @@ def test_simulate_expectation_values_qubit_order(dtype): assert cirq.approx_eq(result_flipped[0], 3, atol=1e-6) -def test_density_matrix_step_result_repr(): +def test_density_matrix_step_result_repr() -> None: q0 = cirq.LineQubit(0) assert ( repr( @@ -967,7 +975,7 @@ def test_density_matrix_step_result_repr(): ) -def test_density_matrix_trial_result_eq(): +def test_density_matrix_trial_result_eq() -> None: q0 = cirq.LineQubit(0) final_simulator_state = cirq.DensityMatrixSimulationState( initial_state=np.ones((2, 2)) * 0.5, qubits=[q0] @@ -1001,7 +1009,9 @@ def test_density_matrix_trial_result_eq(): ) -def test_density_matrix_trial_result_qid_shape(): +def test_density_matrix_trial_result_qid_shape() -> None: + q0: cirq.Qid + q1: cirq.Qid q0, q1 = cirq.LineQubit.range(2) final_simulator_state = cirq.DensityMatrixSimulationState( initial_state=np.ones((4, 4)) / 4, qubits=[q0, q1] @@ -1026,7 +1036,7 @@ def test_density_matrix_trial_result_qid_shape(): ) == (3, 4) -def test_density_matrix_trial_result_repr(): +def test_density_matrix_trial_result_repr() -> None: q0 = cirq.LineQubit(0) dtype = np.complex64 final_simulator_state = cirq.DensityMatrixSimulationState( @@ -1060,17 +1070,17 @@ def __init__(self, q): self.q = q # pragma: no cover @property - def qubits(self): + def qubits(self) -> tuple[cirq.Qid, ...]: return (self.q,) # pragma: no cover - def with_qubits(self, *new_qubits): + def with_qubits(self, *new_qubits) -> cirq.Operation: return XAsOp(new_qubits[0]) # pragma: no cover def _kraus_(self): return cirq.kraus(cirq.X) # pragma: no cover -def test_works_on_operation(): +def test_works_on_operation() -> None: class XAsOp(cirq.Operation): def __init__(self, q): self.q = q @@ -1090,7 +1100,7 @@ def _kraus_(self): np.testing.assert_allclose(s.simulate(c).final_density_matrix, np.diag([0, 1]), atol=1e-8) -def test_works_on_pauli_string_phasor(): +def test_works_on_pauli_string_phasor() -> None: a, b = cirq.LineQubit.range(2) c = cirq.Circuit(np.exp(0.5j * np.pi * cirq.X(a) * cirq.X(b))) sim = cirq.DensityMatrixSimulator() @@ -1098,7 +1108,7 @@ def test_works_on_pauli_string_phasor(): np.testing.assert_allclose(result.reshape(4, 4), np.diag([0, 0, 0, 1]), atol=1e-8) -def test_works_on_pauli_string(): +def test_works_on_pauli_string() -> None: a, b = cirq.LineQubit.range(2) c = cirq.Circuit(cirq.X(a) * cirq.X(b)) sim = cirq.DensityMatrixSimulator() @@ -1106,7 +1116,7 @@ def test_works_on_pauli_string(): np.testing.assert_allclose(result.reshape(4, 4), np.diag([0, 0, 0, 1]), atol=1e-8) -def test_density_matrix_trial_result_str(): +def test_density_matrix_trial_result_str() -> None: q0 = cirq.LineQubit(0) dtype = np.complex64 final_simulator_state = cirq.DensityMatrixSimulationState( @@ -1130,7 +1140,7 @@ def test_density_matrix_trial_result_str(): ) -def test_density_matrix_trial_result_repr_pretty(): +def test_density_matrix_trial_result_repr_pretty() -> None: q0 = cirq.LineQubit(0) dtype = np.complex64 final_simulator_state = cirq.DensityMatrixSimulationState( @@ -1158,7 +1168,7 @@ def test_density_matrix_trial_result_repr_pretty(): cirq.testing.assert_repr_pretty(result, "cirq.DensityMatrixTrialResult(...)", cycle=True) -def test_run_sweep_parameters_not_resolved(): +def test_run_sweep_parameters_not_resolved() -> None: a = cirq.LineQubit(0) simulator = cirq.DensityMatrixSimulator() circuit = cirq.Circuit(cirq.XPowGate(exponent=sympy.Symbol('a'))(a), cirq.measure(a)) @@ -1166,7 +1176,7 @@ def test_run_sweep_parameters_not_resolved(): _ = simulator.run_sweep(circuit, cirq.ParamResolver({})) -def test_simulate_sweep_parameters_not_resolved(): +def test_simulate_sweep_parameters_not_resolved() -> None: a = cirq.LineQubit(0) simulator = cirq.DensityMatrixSimulator() circuit = cirq.Circuit(cirq.XPowGate(exponent=sympy.Symbol('a'))(a), cirq.measure(a)) @@ -1174,7 +1184,7 @@ def test_simulate_sweep_parameters_not_resolved(): _ = simulator.simulate_sweep(circuit, cirq.ParamResolver({})) -def test_random_seed(): +def test_random_seed() -> None: a = cirq.NamedQubit('a') circuit = cirq.Circuit(cirq.X(a) ** 0.5, cirq.measure(a)) @@ -1193,7 +1203,7 @@ def test_random_seed(): ) -def test_random_seed_does_not_modify_global_state_terminal_measurements(): +def test_random_seed_does_not_modify_global_state_terminal_measurements() -> None: a = cirq.NamedQubit('a') circuit = cirq.Circuit(cirq.X(a) ** 0.5, cirq.measure(a)) @@ -1208,7 +1218,7 @@ def test_random_seed_does_not_modify_global_state_terminal_measurements(): assert result1 == result2 -def test_random_seed_does_not_modify_global_state_non_terminal_measurements(): +def test_random_seed_does_not_modify_global_state_non_terminal_measurements() -> None: a = cirq.NamedQubit('a') circuit = cirq.Circuit( cirq.X(a) ** 0.5, cirq.measure(a, key='a0'), cirq.X(a) ** 0.5, cirq.measure(a, key='a1') @@ -1225,7 +1235,7 @@ def test_random_seed_does_not_modify_global_state_non_terminal_measurements(): assert result1 == result2 -def test_random_seed_terminal_measurements_deterministic(): +def test_random_seed_terminal_measurements_deterministic() -> None: a = cirq.NamedQubit('a') circuit = cirq.Circuit(cirq.X(a) ** 0.5, cirq.measure(a, key='a')) sim = cirq.DensityMatrixSimulator(seed=1234) @@ -1303,7 +1313,7 @@ def test_random_seed_terminal_measurements_deterministic(): ) -def test_random_seed_non_terminal_measurements_deterministic(): +def test_random_seed_non_terminal_measurements_deterministic() -> None: a = cirq.NamedQubit('a') circuit = cirq.Circuit( cirq.X(a) ** 0.5, cirq.measure(a, key='a'), cirq.X(a) ** 0.5, cirq.measure(a, key='b') @@ -1382,7 +1392,7 @@ def test_random_seed_non_terminal_measurements_deterministic(): ) -def test_simulate_with_invert_mask(): +def test_simulate_with_invert_mask() -> None: q0, q1, q2, q3, q4 = cirq.LineQid.for_qid_shape((2, 3, 3, 3, 4)) c = cirq.Circuit( cirq.XPowGate(dimension=2)(q0), @@ -1394,7 +1404,7 @@ def test_simulate_with_invert_mask(): assert np.all(cirq.DensityMatrixSimulator().run(c).measurements['a'] == [[0, 1, 0, 2, 3]]) -def test_simulate_noise_with_terminal_measurements(): +def test_simulate_noise_with_terminal_measurements() -> None: q = cirq.LineQubit(0) circuit1 = cirq.Circuit(cirq.measure(q)) circuit2 = circuit1 + cirq.I(q) @@ -1406,7 +1416,7 @@ def test_simulate_noise_with_terminal_measurements(): assert result1 == result2 -def test_simulate_noise_with_subcircuit_measurements(): +def test_simulate_noise_with_subcircuit_measurements() -> None: q = cirq.LineQubit(0) circuit1 = cirq.Circuit(cirq.measure(q)) circuit2 = cirq.Circuit(cirq.CircuitOperation(cirq.Circuit(cirq.measure(q)).freeze())) @@ -1418,7 +1428,7 @@ def test_simulate_noise_with_subcircuit_measurements(): assert result1 == result2 -def test_nonmeasuring_subcircuits_do_not_cause_sweep_repeat(): +def test_nonmeasuring_subcircuits_do_not_cause_sweep_repeat() -> None: q = cirq.LineQubit(0) circuit = cirq.Circuit( cirq.CircuitOperation(cirq.Circuit(cirq.H(q)).freeze()), cirq.measure(q, key='x') @@ -1429,7 +1439,7 @@ def test_nonmeasuring_subcircuits_do_not_cause_sweep_repeat(): assert mock_sim.call_count == 2 -def test_measuring_subcircuits_cause_sweep_repeat(): +def test_measuring_subcircuits_cause_sweep_repeat() -> None: q = cirq.LineQubit(0) circuit = cirq.Circuit( cirq.CircuitOperation(cirq.Circuit(cirq.measure(q)).freeze()), cirq.measure(q, key='x') @@ -1440,7 +1450,7 @@ def test_measuring_subcircuits_cause_sweep_repeat(): assert mock_sim.call_count == 11 -def test_density_matrix_copy(): +def test_density_matrix_copy() -> None: sim = cirq.DensityMatrixSimulator(split_untangled_states=False) q = cirq.LineQubit(0) @@ -1465,7 +1475,7 @@ def test_density_matrix_copy(): assert all(not np.shares_memory(x, y) for x, y in itertools.combinations(matrices, 2)) -def test_final_density_matrix_is_not_last_object(): +def test_final_density_matrix_is_not_last_object() -> None: sim = cirq.DensityMatrixSimulator() q = cirq.LineQubit(0) @@ -1477,7 +1487,7 @@ def test_final_density_matrix_is_not_last_object(): np.testing.assert_equal(result.final_density_matrix, initial_state) -def test_density_matrices_same_with_or_without_split_untangled_states(): +def test_density_matrices_same_with_or_without_split_untangled_states() -> None: sim = cirq.DensityMatrixSimulator(split_untangled_states=False) q0, q1 = cirq.LineQubit.range(2) circuit = cirq.Circuit(cirq.H(q0), cirq.CX.on(q0, q1), cirq.reset(q1)) @@ -1487,7 +1497,7 @@ def test_density_matrices_same_with_or_without_split_untangled_states(): assert np.allclose(result1, result2) -def test_large_untangled_okay(): +def test_large_untangled_okay() -> None: circuit = cirq.Circuit() for i in range(59): for _ in range(9): @@ -1499,9 +1509,9 @@ def test_large_untangled_okay(): _ = cirq.DensityMatrixSimulator(split_untangled_states=False).simulate(circuit) # Validate a simulation run - result = cirq.DensityMatrixSimulator().simulate(circuit) - assert set(result._final_simulator_state.qubits) == set(cirq.LineQubit.range(59)) - # _ = result.final_density_matrix hangs (as expected) + sim_result = cirq.DensityMatrixSimulator().simulate(circuit) + assert set(sim_result._final_simulator_state.qubits) == set(cirq.LineQubit.range(59)) + # _ = sim_result.final_density_matrix hangs (as expected) # Validate a trial run and sampling result = cirq.DensityMatrixSimulator().run(circuit, repetitions=1000) @@ -1510,7 +1520,7 @@ def test_large_untangled_okay(): assert (result.measurements['q(0)'] == np.full(1000, 1)).all() -def test_separated_states_str_does_not_merge(): +def test_separated_states_str_does_not_merge() -> None: q0, q1 = cirq.LineQubit.range(2) circuit = cirq.Circuit(cirq.measure(q0), cirq.measure(q1), cirq.X(q0)) @@ -1535,7 +1545,7 @@ def test_separated_states_str_does_not_merge(): ) -def test_unseparated_states_str(): +def test_unseparated_states_str() -> None: q0, q1 = cirq.LineQubit.range(2) circuit = cirq.Circuit(cirq.measure(q0), cirq.measure(q1), cirq.X(q0)) @@ -1553,7 +1563,7 @@ def test_unseparated_states_str(): ) -def test_sweep_unparameterized_prefix_not_repeated_even_non_unitaries(): +def test_sweep_unparameterized_prefix_not_repeated_even_non_unitaries() -> None: q = cirq.LineQubit(0) class NonUnitaryOp(cirq.Operation): diff --git a/cirq-core/cirq/sim/density_matrix_utils_test.py b/cirq-core/cirq/sim/density_matrix_utils_test.py index 0d53c3d40c6..506917e5854 100644 --- a/cirq-core/cirq/sim/density_matrix_utils_test.py +++ b/cirq-core/cirq/sim/density_matrix_utils_test.py @@ -216,7 +216,7 @@ def test_measure_density_matrix_partial_indices_all_orders() -> None: assert bits == [bool(1 & (x >> (2 - p))) for p in perm] -def matrix_000_plus_010(): +def matrix_000_plus_010() -> np.ndarray: state = np.zeros(8, dtype=np.complex64) state[0] = 1 / np.sqrt(2) state[2] = 1j / np.sqrt(2) diff --git a/cirq-core/cirq/sim/mux_test.py b/cirq-core/cirq/sim/mux_test.py index fd5c3781d7f..b4df88dfa8c 100644 --- a/cirq-core/cirq/sim/mux_test.py +++ b/cirq-core/cirq/sim/mux_test.py @@ -26,7 +26,7 @@ import cirq.testing -def test_sample(): +def test_sample() -> None: q = cirq.NamedQubit('q') with pytest.raises(ValueError, match="no measurements"): @@ -49,7 +49,7 @@ def test_sample(): assert results.histogram(key=q) == collections.Counter({0: 1}) -def test_sample_seed_unitary(): +def test_sample_seed_unitary() -> None: q = cirq.NamedQubit('q') circuit = cirq.Circuit(cirq.X(q) ** 0.2, cirq.measure(q)) result = cirq.sample(circuit, repetitions=10, seed=1234) @@ -60,7 +60,7 @@ def test_sample_seed_unitary(): ) -def test_sample_seed_non_unitary(): +def test_sample_seed_non_unitary() -> None: q = cirq.NamedQubit('q') circuit = cirq.Circuit(cirq.depolarize(0.5).on(q), cirq.measure(q)) result = cirq.sample(circuit, repetitions=10, seed=1234) @@ -70,7 +70,7 @@ def test_sample_seed_non_unitary(): ) -def test_sample_sweep(): +def test_sample_sweep() -> None: q = cirq.NamedQubit('q') c = cirq.Circuit(cirq.X(q), cirq.Y(q) ** sympy.Symbol('t'), cirq.measure(q)) @@ -102,7 +102,7 @@ def test_sample_sweep(): assert results[1].histogram(key=q) == collections.Counter({0: 3}) -def test_sample_sweep_seed(): +def test_sample_sweep_seed() -> None: q = cirq.NamedQubit('q') circuit = cirq.Circuit(cirq.X(q) ** sympy.Symbol('t'), cirq.measure(q)) @@ -124,7 +124,7 @@ def test_sample_sweep_seed(): assert np.all(results[2].measurements['q'] == [[True], [False]]) -def test_final_state_vector_different_program_types(): +def test_final_state_vector_different_program_types() -> None: a, b = cirq.LineQubit.range(2) np.testing.assert_allclose(cirq.final_state_vector(cirq.X), [0, 1], atol=1e-8) @@ -140,7 +140,7 @@ def test_final_state_vector_different_program_types(): ) -def test_final_state_vector_initial_state(): +def test_final_state_vector_initial_state() -> None: np.testing.assert_allclose(cirq.final_state_vector(cirq.X, initial_state=0), [0, 1], atol=1e-8) np.testing.assert_allclose(cirq.final_state_vector(cirq.X, initial_state=1), [1, 0], atol=1e-8) @@ -152,7 +152,7 @@ def test_final_state_vector_initial_state(): ) -def test_final_state_vector_dtype_insensitive_to_initial_state(): +def test_final_state_vector_dtype_insensitive_to_initial_state() -> None: assert cirq.final_state_vector(cirq.X).dtype == np.complex64 assert cirq.final_state_vector(cirq.X, initial_state=0).dtype == np.complex64 @@ -181,7 +181,7 @@ def test_final_state_vector_dtype_insensitive_to_initial_state(): ) -def test_final_state_vector_param_resolver(): +def test_final_state_vector_param_resolver() -> None: s = sympy.Symbol('s') with pytest.raises(ValueError, match='not unitary'): @@ -192,7 +192,7 @@ def test_final_state_vector_param_resolver(): ) -def test_final_state_vector_qubit_order(): +def test_final_state_vector_qubit_order() -> None: a, b = cirq.LineQubit.range(2) np.testing.assert_allclose( @@ -206,7 +206,7 @@ def test_final_state_vector_qubit_order(): ) -def test_final_state_vector_ignore_terminal_measurement(): +def test_final_state_vector_ignore_terminal_measurement() -> None: a, b = cirq.LineQubit.range(2) np.testing.assert_allclose( @@ -226,7 +226,7 @@ def test_final_state_vector_ignore_terminal_measurement(): @pytest.mark.parametrize('repetitions', (0, 1, 100)) -def test_repetitions(repetitions): +def test_repetitions(repetitions) -> None: a = cirq.LineQubit(0) c = cirq.Circuit(cirq.H(a), cirq.measure(a, key='m')) r = cirq.sample(c, repetitions=repetitions) @@ -235,7 +235,7 @@ def test_repetitions(repetitions): assert np.issubdtype(samples.dtype, np.integer) -def test_final_density_matrix_different_program_types(): +def test_final_density_matrix_different_program_types() -> None: a, b = cirq.LineQubit.range(2) np.testing.assert_allclose(cirq.final_density_matrix(cirq.X), [[0, 0], [0, 1]], atol=1e-8) @@ -244,12 +244,12 @@ def test_final_density_matrix_different_program_types(): np.testing.assert_allclose( cirq.final_density_matrix(cirq.Circuit(ops)), - [[0.5, 0, 0, 0.5], [0, 0, 0, 0], [0, 0, 0, 0], [0.5, 0, 0, 0.5]], + np.asarray([[0.5, 0, 0, 0.5], [0, 0, 0, 0], [0, 0, 0, 0], [0.5, 0, 0, 0.5]]), atol=1e-8, ) -def test_final_density_matrix_initial_state(): +def test_final_density_matrix_initial_state() -> None: np.testing.assert_allclose( cirq.final_density_matrix(cirq.X, initial_state=0), [[0, 0], [0, 1]], atol=1e-8 ) @@ -265,7 +265,7 @@ def test_final_density_matrix_initial_state(): ) -def test_final_density_matrix_dtype_insensitive_to_initial_state(): +def test_final_density_matrix_dtype_insensitive_to_initial_state() -> None: assert cirq.final_density_matrix(cirq.X).dtype == np.complex64 assert cirq.final_density_matrix(cirq.X, initial_state=0).dtype == np.complex64 @@ -296,7 +296,7 @@ def test_final_density_matrix_dtype_insensitive_to_initial_state(): ) -def test_final_density_matrix_param_resolver(): +def test_final_density_matrix_param_resolver() -> None: s = sympy.Symbol('s') with pytest.raises(ValueError, match='not specified in parameter sweep'): @@ -308,17 +308,17 @@ def test_final_density_matrix_param_resolver(): ) -def test_final_density_matrix_qubit_order(): +def test_final_density_matrix_qubit_order() -> None: a, b = cirq.LineQubit.range(2) np.testing.assert_allclose( cirq.final_density_matrix([cirq.X(a), cirq.X(b) ** 0.5], qubit_order=[a, b]), - [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0.5, 0.5j], [0, 0, -0.5j, 0.5]], + np.asarray([[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0.5, 0.5j], [0, 0, -0.5j, 0.5]]), ) np.testing.assert_allclose( cirq.final_density_matrix([cirq.X(a), cirq.X(b) ** 0.5], qubit_order=[b, a]), - [[0, 0, 0, 0], [0, 0.5, 0, 0.5j], [0, 0, 0, 0], [0, -0.5j, 0, 0.5]], + np.asarray([[0, 0, 0, 0], [0, 0.5, 0, 0.5j], [0, 0, 0, 0], [0, -0.5j, 0, 0.5]]), ) np.testing.assert_allclose( @@ -331,7 +331,7 @@ def test_final_density_matrix_qubit_order(): ) -def test_final_density_matrix_seed_with_dephasing(): +def test_final_density_matrix_seed_with_dephasing() -> None: a = cirq.LineQubit(0) np.testing.assert_allclose( cirq.final_density_matrix([cirq.X(a) ** 0.5, cirq.measure(a)], seed=123), @@ -345,7 +345,7 @@ def test_final_density_matrix_seed_with_dephasing(): ) -def test_final_density_matrix_seed_with_collapsing(): +def test_final_density_matrix_seed_with_collapsing() -> None: a = cirq.LineQubit(0) np.testing.assert_allclose( cirq.final_density_matrix( @@ -363,7 +363,7 @@ def test_final_density_matrix_seed_with_collapsing(): ) -def test_final_density_matrix_noise(): +def test_final_density_matrix_noise() -> None: a = cirq.LineQubit(0) np.testing.assert_allclose( cirq.final_density_matrix([cirq.H(a), cirq.Z(a), cirq.H(a), cirq.measure(a)]), @@ -380,7 +380,7 @@ def test_final_density_matrix_noise(): ) -def test_final_density_matrix_classical_control(): +def test_final_density_matrix_classical_control() -> None: q0, q1 = cirq.LineQubit.range(2) circuit = cirq.Circuit( cirq.H(q0), @@ -394,7 +394,7 @@ def test_final_density_matrix_classical_control(): ) -def test_ps_initial_state_wfn(): +def test_ps_initial_state_wfn() -> None: q0, q1 = cirq.LineQubit.range(2) s00 = cirq.KET_ZERO(q0) * cirq.KET_ZERO(q1) sp0 = cirq.KET_PLUS(q0) * cirq.KET_ZERO(q1) @@ -410,7 +410,7 @@ def test_ps_initial_state_wfn(): ) -def test_ps_initial_state_dmat(): +def test_ps_initial_state_dmat() -> None: q0, q1 = cirq.LineQubit.range(2) s00 = cirq.KET_ZERO(q0) * cirq.KET_ZERO(q1) sp0 = cirq.KET_PLUS(q0) * cirq.KET_ZERO(q1) diff --git a/cirq-core/cirq/sim/simulation_product_state_test.py b/cirq-core/cirq/sim/simulation_product_state_test.py index f2336b700dc..fe60b384b17 100644 --- a/cirq-core/cirq/sim/simulation_product_state_test.py +++ b/cirq-core/cirq/sim/simulation_product_state_test.py @@ -14,29 +14,29 @@ from __future__ import annotations -from typing import Any, Sequence +from typing import Any, Self, Sequence import cirq class EmptyQuantumState(cirq.QuantumStateRepresentation): - def copy(self, deep_copy_buffers=True): + def copy(self, deep_copy_buffers=True) -> Self: return self - def measure(self, axes, seed=None): + def measure(self, axes, seed=None) -> list[int]: return [0] * len(axes) @property - def supports_factor(self): + def supports_factor(self) -> bool: return True - def kron(self, other): + def kron(self, other) -> Self: return self - def factor(self, axes, *, validate=True, atol=1e-07): + def factor(self, axes, *, validate=True, atol=1e-07) -> tuple[Self, Self]: return self, self - def reindex(self, axes): + def reindex(self, axes) -> Self: return self diff --git a/cirq-core/cirq/sim/simulation_state.py b/cirq-core/cirq/sim/simulation_state.py index 594987110c2..b0148d6151c 100644 --- a/cirq-core/cirq/sim/simulation_state.py +++ b/cirq-core/cirq/sim/simulation_state.py @@ -73,7 +73,7 @@ def measure( key: str, invert_mask: Sequence[bool], confusion_map: dict[tuple[int, ...], np.ndarray], - ): + ) -> None: """Measures the qubits and records to `log_of_measurement_results`. Any bitmasks will be applied to the measurement record. @@ -209,7 +209,7 @@ def factor( return extracted, remainder @property - def allows_factoring(self): + def allows_factoring(self) -> bool: """Subclasses that allow factorization should override this.""" return self._state.supports_factor if self._state is not None else False @@ -238,7 +238,7 @@ def transpose_to_qubit_order(self, qubits: Sequence[cirq.Qid], *, inplace=False) def qubits(self) -> tuple[cirq.Qid, ...]: return self._qubits - def swap(self, q1: cirq.Qid, q2: cirq.Qid, *, inplace=False): + def swap(self, q1: cirq.Qid, q2: cirq.Qid, *, inplace=False) -> Self: """Swaps two qubits. This only affects the index, and does not modify the underlying @@ -269,7 +269,7 @@ def swap(self, q1: cirq.Qid, q2: cirq.Qid, *, inplace=False): args._set_qubits(qubits) return args - def rename(self, q1: cirq.Qid, q2: cirq.Qid, *, inplace=False): + def rename(self, q1: cirq.Qid, q2: cirq.Qid, *, inplace=False) -> Self: """Renames `q1` to `q2`. Args: diff --git a/cirq-core/cirq/sim/simulation_state_base.py b/cirq-core/cirq/sim/simulation_state_base.py index 8acf5cbf3f5..9507d9f84cf 100644 --- a/cirq-core/cirq/sim/simulation_state_base.py +++ b/cirq-core/cirq/sim/simulation_state_base.py @@ -81,7 +81,7 @@ def _act_on_fallback_( Returns: True if the fallback applies, else NotImplemented.""" - def apply_operation(self, op: cirq.Operation): + def apply_operation(self, op: cirq.Operation) -> None: protocols.act_on(op, self) @abc.abstractmethod diff --git a/cirq-core/cirq/sim/simulation_state_test.py b/cirq-core/cirq/sim/simulation_state_test.py index 3c270199a1b..92a6e3050fe 100644 --- a/cirq-core/cirq/sim/simulation_state_test.py +++ b/cirq-core/cirq/sim/simulation_state_test.py @@ -25,13 +25,13 @@ class ExampleQuantumState(cirq.QuantumStateRepresentation): - def copy(self, deep_copy_buffers=True): - pass + def copy(self, deep_copy_buffers=True) -> ExampleQuantumState: + raise NotImplementedError() - def measure(self, axes, seed=None): + def measure(self, axes, seed=None) -> list[int]: return [5, 3] - def reindex(self, axes): + def reindex(self, axes) -> ExampleQuantumState: return self @@ -44,7 +44,7 @@ def _act_on_fallback_( ) -> bool: return True - def add_qubits(self, qubits): + def add_qubits(self, qubits) -> ExampleSimulationState: super().add_qubits(qubits) return self diff --git a/cirq-core/cirq/sim/simulator.py b/cirq-core/cirq/sim/simulator.py index e5b5f3934dd..56b54b9290b 100644 --- a/cirq-core/cirq/sim/simulator.py +++ b/cirq-core/cirq/sim/simulator.py @@ -217,7 +217,7 @@ def compute_amplitudes_sweep_iter( def sample_from_amplitudes( self, circuit: cirq.AbstractCircuit, - param_resolver: cirq.ParamResolver, + param_resolver: cirq.ParamResolverOrSimilarType, seed: cirq.RANDOM_STATE_OR_SEED_LIKE, repetitions: int = 1, qubit_order: cirq.QubitOrderOrList = ops.QubitOrder.DEFAULT, @@ -938,7 +938,7 @@ def _qubit_map_to_shape(qubit_map: Mapping[cirq.Qid, int]) -> tuple[int, ...]: return tuple(qid_shape) -def check_all_resolved(circuit): +def check_all_resolved(circuit) -> None: """Raises if the circuit contains unresolved symbols.""" if protocols.is_parameterized(circuit): unresolved = [op for moment in circuit for op in moment if protocols.is_parameterized(op)] diff --git a/cirq-core/cirq/sim/simulator_base_test.py b/cirq-core/cirq/sim/simulator_base_test.py index 14acda0ae06..dee6f8beaf9 100644 --- a/cirq-core/cirq/sim/simulator_base_test.py +++ b/cirq-core/cirq/sim/simulator_base_test.py @@ -73,25 +73,25 @@ def _act_on_fallback_( return True @property - def data(self): + def data(self) -> Any: return self._state.data @property - def gate_count(self): + def gate_count(self) -> int: return self._state.gate_count @property - def measurement_count(self): + def measurement_count(self) -> int: return self._state.measurement_count @property - def copy_count(self): + def copy_count(self) -> int: return self._state.copy_count class SplittableCountingSimulationState(CountingSimulationState): @property - def allows_factoring(self): + def allows_factoring(self) -> bool: return True @@ -167,59 +167,64 @@ def _create_partial_simulation_state( class TestOp(cirq.Operation): - def with_qubits(self, *new_qubits): - pass + def with_qubits(self, *new_qubits) -> cirq.Operation: + raise NotImplementedError() @property - def qubits(self): - return [q0] + def qubits(self) -> tuple[cirq.Qid, ...]: + return (q0,) -def test_simulate_empty_circuit(): +def test_simulate_empty_circuit() -> None: sim = CountingSimulator() r = sim.simulate(cirq.Circuit()) + assert isinstance(r._final_simulator_state, CountingSimulationState) assert r._final_simulator_state.gate_count == 0 assert r._final_simulator_state.measurement_count == 0 assert r._final_simulator_state.copy_count == 0 -def test_simulate_one_gate_circuit(): +def test_simulate_one_gate_circuit() -> None: sim = CountingSimulator() r = sim.simulate(cirq.Circuit(cirq.X(q0))) + assert isinstance(r._final_simulator_state, CountingSimulationState) assert r._final_simulator_state.gate_count == 1 assert r._final_simulator_state.copy_count == 0 -def test_simulate_one_measurement_circuit(): +def test_simulate_one_measurement_circuit() -> None: sim = CountingSimulator() r = sim.simulate(cirq.Circuit(cirq.measure(q0))) + assert isinstance(r._final_simulator_state, CountingSimulationState) assert r._final_simulator_state.gate_count == 0 assert r._final_simulator_state.measurement_count == 1 assert r._final_simulator_state.copy_count == 0 -def test_empty_circuit_simulation_has_moment(): +def test_empty_circuit_simulation_has_moment() -> None: sim = CountingSimulator() steps = list(sim.simulate_moment_steps(cirq.Circuit())) assert len(steps) == 1 -def test_noise_applied(): +def test_noise_applied() -> None: sim = CountingSimulator(noise=cirq.X) r = sim.simulate(cirq.Circuit(cirq.X(q0))) + assert isinstance(r._final_simulator_state, CountingSimulationState) assert r._final_simulator_state.gate_count == 2 assert r._final_simulator_state.copy_count == 0 -def test_noise_applied_measurement_gate(): +def test_noise_applied_measurement_gate() -> None: sim = CountingSimulator(noise=cirq.X) r = sim.simulate(cirq.Circuit(cirq.measure(q0))) + assert isinstance(r._final_simulator_state, CountingSimulationState) assert r._final_simulator_state.gate_count == 1 assert r._final_simulator_state.measurement_count == 1 assert r._final_simulator_state.copy_count == 0 -def test_parameterized_copies_all_but_last(): +def test_parameterized_copies_all_but_last() -> None: sim = CountingSimulator() n = 4 rs = sim.simulate_sweep( @@ -227,12 +232,13 @@ def test_parameterized_copies_all_but_last(): ) for i in range(n): r = rs[i] + assert isinstance(r._final_simulator_state, CountingSimulationState) assert r._final_simulator_state.gate_count == 1 assert r._final_simulator_state.measurement_count == 0 assert r._final_simulator_state.copy_count == 0 if i == n - 1 else 1 -def test_cannot_act(): +def test_cannot_act() -> None: class BadOp(TestOp): def _act_on_(self, sim_state): raise TypeError() @@ -242,25 +248,25 @@ def _act_on_(self, sim_state): sim.simulate(cirq.Circuit(BadOp())) -def test_run_one_gate_circuit(): +def test_run_one_gate_circuit() -> None: sim = CountingSimulator() r = sim.run(cirq.Circuit(cirq.X(q0), cirq.measure(q0)), repetitions=2) assert np.allclose(r.measurements['q(0)'], [[1], [1]]) -def test_run_one_gate_circuit_noise(): +def test_run_one_gate_circuit_noise() -> None: sim = CountingSimulator(noise=cirq.X) r = sim.run(cirq.Circuit(cirq.X(q0), cirq.measure(q0)), repetitions=2) assert np.allclose(r.measurements['q(0)'], [[2], [2]]) -def test_run_non_unitary_circuit(): +def test_run_non_unitary_circuit() -> None: sim = CountingSimulator() r = sim.run(cirq.Circuit(cirq.phase_damp(1).on(q0), cirq.measure(q0)), repetitions=2) assert np.allclose(r.measurements['q(0)'], [[1], [1]]) -def test_run_non_unitary_circuit_non_unitary_state(): +def test_run_non_unitary_circuit_non_unitary_state() -> None: class DensityCountingSimulator(CountingSimulator): def _can_be_in_run_prefix(self, val): return not cirq.is_measurement(val) @@ -270,15 +276,16 @@ def _can_be_in_run_prefix(self, val): assert np.allclose(r.measurements['q(0)'], [[1], [1]]) -def test_run_non_terminal_measurement(): +def test_run_non_terminal_measurement() -> None: sim = CountingSimulator() r = sim.run(cirq.Circuit(cirq.X(q0), cirq.measure(q0), cirq.X(q0)), repetitions=2) assert np.allclose(r.measurements['q(0)'], [[1], [1]]) -def test_integer_initial_state_is_split(): +def test_integer_initial_state_is_split() -> None: sim = SplittableCountingSimulator() state = sim._create_simulation_state(2, (q0, q1)) + assert isinstance(state, cirq.SimulationProductState) assert len(set(state.values())) == 3 assert state[q0] is not state[q1] assert state[q0].data == 1 @@ -286,7 +293,7 @@ def test_integer_initial_state_is_split(): assert state[None].data == 0 -def test_integer_initial_state_is_not_split_if_disabled(): +def test_integer_initial_state_is_not_split_if_disabled() -> None: sim = SplittableCountingSimulator(split_untangled_states=False) state = sim._create_simulation_state(2, (q0, q1)) assert isinstance(state, SplittableCountingSimulationState) @@ -294,7 +301,7 @@ def test_integer_initial_state_is_not_split_if_disabled(): assert state.data == 2 -def test_integer_initial_state_is_not_split_if_impossible(): +def test_integer_initial_state_is_not_split_if_impossible() -> None: sim = CountingSimulator() state = sim._create_simulation_state(2, (q0, q1)) assert isinstance(state, CountingSimulationState) @@ -303,18 +310,20 @@ def test_integer_initial_state_is_not_split_if_impossible(): assert state.data == 2 -def test_non_integer_initial_state_is_not_split(): +def test_non_integer_initial_state_is_not_split() -> None: sim = SplittableCountingSimulator() state = sim._create_simulation_state(entangled_state_repr, (q0, q1)) + assert isinstance(state, cirq.SimulationProductState) assert len(set(state.values())) == 2 assert (state[q0].data == entangled_state_repr).all() assert state[q1] is state[q0] assert state[None].data == 0 -def test_entanglement_causes_join(): +def test_entanglement_causes_join() -> None: sim = SplittableCountingSimulator() state = sim._create_simulation_state(2, (q0, q1)) + assert isinstance(state, cirq.SimulationProductState) assert len(set(state.values())) == 3 state.apply_operation(cirq.CNOT(q0, q1)) assert len(set(state.values())) == 2 @@ -322,9 +331,10 @@ def test_entanglement_causes_join(): assert state[None] is not state[q0] -def test_measurement_causes_split(): +def test_measurement_causes_split() -> None: sim = SplittableCountingSimulator() state = sim._create_simulation_state(entangled_state_repr, (q0, q1)) + assert isinstance(state, cirq.SimulationProductState) assert len(set(state.values())) == 2 state.apply_operation(cirq.measure(q0)) assert len(set(state.values())) == 3 @@ -332,7 +342,7 @@ def test_measurement_causes_split(): assert state[q0] is not state[None] -def test_measurement_does_not_split_if_disabled(): +def test_measurement_does_not_split_if_disabled() -> None: sim = SplittableCountingSimulator(split_untangled_states=False) state = sim._create_simulation_state(2, (q0, q1)) assert isinstance(state, SplittableCountingSimulationState) @@ -341,7 +351,7 @@ def test_measurement_does_not_split_if_disabled(): assert state[q0] is state[q1] -def test_measurement_does_not_split_if_impossible(): +def test_measurement_does_not_split_if_impossible() -> None: sim = CountingSimulator() state = sim._create_simulation_state(2, (q0, q1)) assert isinstance(state, CountingSimulationState) @@ -352,7 +362,7 @@ def test_measurement_does_not_split_if_impossible(): assert state[q0] is state[q1] -def test_reorder_succeeds(): +def test_reorder_succeeds() -> None: sim = SplittableCountingSimulator() state = sim._create_simulation_state(entangled_state_repr, (q0, q1)) reordered = state[q0].transpose_to_qubit_order([q1, q0]) @@ -360,7 +370,7 @@ def test_reorder_succeeds(): @pytest.mark.parametrize('split', [True, False]) -def test_sim_state_instance_unchanged_during_normal_sim(split: bool): +def test_sim_state_instance_unchanged_during_normal_sim(split: bool) -> None: sim = SplittableCountingSimulator(split_untangled_states=split) state = sim._create_simulation_state(0, (q0, q1)) circuit = cirq.Circuit(cirq.H(q0), cirq.CNOT(q0, q1), cirq.reset(q1)) @@ -369,7 +379,7 @@ def test_sim_state_instance_unchanged_during_normal_sim(split: bool): assert (step._merged_sim_state is not state) == split -def test_measurements_retained_in_step_results(): +def test_measurements_retained_in_step_results() -> None: sim = SplittableCountingSimulator() circuit = cirq.Circuit( cirq.measure(q0, key='a'), cirq.measure(q0, key='b'), cirq.measure(q0, key='c') @@ -381,7 +391,7 @@ def test_measurements_retained_in_step_results(): assert not any(iterator) -def test_sweep_unparameterized_prefix_not_repeated_iff_unitary(): +def test_sweep_unparameterized_prefix_not_repeated_iff_unitary() -> None: q = cirq.LineQubit(0) class TestOp(cirq.Operation): @@ -410,6 +420,8 @@ def _has_unitary_(self): op2 = TestOp(has_unitary=True) circuit = cirq.Circuit(op1, cirq.XPowGate(exponent=sympy.Symbol('a'))(q), op2) rs = simulator.simulate_sweep(program=circuit, params=params) + assert isinstance(rs[0]._final_simulator_state, CountingSimulationState) + assert isinstance(rs[1]._final_simulator_state, CountingSimulationState) assert rs[0]._final_simulator_state.copy_count == 1 assert rs[1]._final_simulator_state.copy_count == 0 assert op1.count == 1 @@ -419,13 +431,15 @@ def _has_unitary_(self): op2 = TestOp(has_unitary=False) circuit = cirq.Circuit(op1, cirq.XPowGate(exponent=sympy.Symbol('a'))(q), op2) rs = simulator.simulate_sweep(program=circuit, params=params) + assert isinstance(rs[0]._final_simulator_state, CountingSimulationState) + assert isinstance(rs[1]._final_simulator_state, CountingSimulationState) assert rs[0]._final_simulator_state.copy_count == 1 assert rs[1]._final_simulator_state.copy_count == 0 assert op1.count == 2 assert op2.count == 2 -def test_inhomogeneous_measurement_count_padding(): +def test_inhomogeneous_measurement_count_padding() -> None: q = cirq.LineQubit(0) key = cirq.MeasurementKey('m') sim = cirq.Simulator() diff --git a/cirq-core/cirq/sim/simulator_test.py b/cirq-core/cirq/sim/simulator_test.py index e456d8fa8de..60478f1aa67 100644 --- a/cirq-core/cirq/sim/simulator_test.py +++ b/cirq-core/cirq/sim/simulator_test.py @@ -56,13 +56,13 @@ def __init__(self, *, ones_qubits=None, final_state=None): def _simulator_state(self): return self._final_state # pragma: no cover - def state_vector(self): + def state_vector(self) -> None: pass def __setstate__(self, state): pass - def sample(self, qubits, repetitions=1, seed=None): + def sample(self, qubits, repetitions=1, seed=None) -> np.ndarray: return np.array([[qubit in self._ones_qubits for qubit in qubits]] * repetitions) @@ -99,7 +99,7 @@ def _create_simulator_trial_result( ) -def test_run_simulator_run(): +def test_run_simulator_run() -> None: expected_records = {'a': np.array([[[1]]])} simulator = FakeSimulatesSamples(expected_records) circuit = cirq.Circuit(cirq.measure(cirq.LineQubit(0), key='k')) @@ -110,7 +110,7 @@ def test_run_simulator_run(): ) -def test_run_simulator_sweeps(): +def test_run_simulator_sweeps() -> None: expected_records = {'a': np.array([[[1]]])} simulator = FakeSimulatesSamples(expected_records) circuit = cirq.Circuit(cirq.measure(cirq.LineQubit(0), key='k')) @@ -127,8 +127,8 @@ def test_run_simulator_sweeps(): @mock.patch.multiple( SimulatesIntermediateStateImpl, __abstractmethods__=set(), simulate_moment_steps=mock.Mock() ) -def test_intermediate_simulator(): - simulator = SimulatesIntermediateStateImpl() +def test_intermediate_simulator() -> None: + simulator: SimulatesIntermediateStateImpl = SimulatesIntermediateStateImpl() final_simulator_state = np.array([1, 0, 0, 0]) @@ -141,7 +141,7 @@ def steps(*args, **kwargs): result._simulator_state.return_value = final_simulator_state yield result - simulator.simulate_moment_steps.side_effect = steps + simulator.simulate_moment_steps.side_effect = steps # type: ignore[attr-defined] circuit = mock.Mock(cirq.Circuit) param_resolver = cirq.ParamResolver({}) qubit_order = mock.Mock(cirq.QubitOrder) @@ -158,8 +158,8 @@ def steps(*args, **kwargs): @mock.patch.multiple( SimulatesIntermediateStateImpl, __abstractmethods__=set(), simulate_moment_steps=mock.Mock() ) -def test_intermediate_sweeps(): - simulator = SimulatesIntermediateStateImpl() +def test_intermediate_sweeps() -> None: + simulator: SimulatesIntermediateStateImpl = SimulatesIntermediateStateImpl() final_state = np.array([1, 0, 0, 0]) @@ -169,7 +169,7 @@ def steps(*args, **kwargs): result._simulator_state.return_value = final_state yield result - simulator.simulate_moment_steps.side_effect = steps + simulator.simulate_moment_steps.side_effect = steps # type: ignore[attr-defined] circuit = mock.Mock(cirq.Circuit) param_resolvers = [cirq.ParamResolver({}), cirq.ParamResolver({})] qubit_order = mock.Mock(cirq.QubitOrder) @@ -192,7 +192,7 @@ def steps(*args, **kwargs): assert results == expected_results -def test_step_sample_measurement_ops(): +def test_step_sample_measurement_ops() -> None: q0, q1, q2 = cirq.LineQubit.range(3) measurement_ops = [cirq.measure(q0, q1), cirq.measure(q2)] step_result = FakeStepResult(ones_qubits=[q1]) @@ -201,7 +201,7 @@ def test_step_sample_measurement_ops(): np.testing.assert_equal(measurements, {'q(0),q(1)': [[False, True]], 'q(2)': [[False]]}) -def test_step_sample_measurement_ops_repetitions(): +def test_step_sample_measurement_ops_repetitions() -> None: q0, q1, q2 = cirq.LineQubit.range(3) measurement_ops = [cirq.measure(q0, q1), cirq.measure(q2)] step_result = FakeStepResult(ones_qubits=[q1]) @@ -210,7 +210,7 @@ def test_step_sample_measurement_ops_repetitions(): np.testing.assert_equal(measurements, {'q(0),q(1)': [[False, True]] * 3, 'q(2)': [[False]] * 3}) -def test_step_sample_measurement_ops_invert_mask(): +def test_step_sample_measurement_ops_invert_mask() -> None: q0, q1, q2 = cirq.LineQubit.range(3) measurement_ops = [ cirq.measure(q0, q1, invert_mask=(True,)), @@ -222,9 +222,11 @@ def test_step_sample_measurement_ops_invert_mask(): np.testing.assert_equal(measurements, {'q(0),q(1)': [[True, True]], 'q(2)': [[False]]}) -def test_step_sample_measurement_ops_confusion_map(): +def test_step_sample_measurement_ops_confusion_map() -> None: q0, q1, q2 = cirq.LineQubit.range(3) + cmap_01: dict[tuple[int, ...], np.ndarray] cmap_01 = {(0, 1): np.array([[0, 1, 0, 0], [0, 0, 0, 1], [1, 0, 0, 0], [0, 0, 1, 0]])} + cmap_2: dict[tuple[int, ...], np.ndarray] cmap_2 = {(0,): np.array([[0, 1], [1, 0]])} measurement_ops = [ cirq.measure(q0, q1, confusion_map=cmap_01), @@ -236,21 +238,21 @@ def test_step_sample_measurement_ops_confusion_map(): np.testing.assert_equal(measurements, {'q(0),q(1)': [[False, True]], 'q(2)': [[False]]}) -def test_step_sample_measurement_ops_no_measurements(): +def test_step_sample_measurement_ops_no_measurements() -> None: step_result = FakeStepResult(ones_qubits=[]) measurements = step_result.sample_measurement_ops([]) assert measurements == {} -def test_step_sample_measurement_ops_not_measurement(): +def test_step_sample_measurement_ops_not_measurement() -> None: q0 = cirq.LineQubit(0) step_result = FakeStepResult(ones_qubits=[q0]) with pytest.raises(ValueError, match='MeasurementGate'): step_result.sample_measurement_ops([cirq.X(q0)]) -def test_step_sample_measurement_ops_repeated_qubit(): +def test_step_sample_measurement_ops_repeated_qubit() -> None: q0, q1, q2 = cirq.LineQubit.range(3) step_result = FakeStepResult(ones_qubits=[q0]) with pytest.raises(ValueError, match=r'Measurement key q\(0\) repeated'): @@ -259,7 +261,7 @@ def test_step_sample_measurement_ops_repeated_qubit(): ) -def test_simulation_trial_result_equality(): +def test_simulation_trial_result_equality() -> None: eq = cirq.testing.EqualsTester() eq.add_equality_group( cirq.SimulationTrialResult( @@ -290,7 +292,7 @@ def test_simulation_trial_result_equality(): ) -def test_simulation_trial_result_repr(): +def test_simulation_trial_result_repr() -> None: assert repr( cirq.SimulationTrialResult( params=cirq.ParamResolver({'s': 1}), @@ -305,7 +307,7 @@ def test_simulation_trial_result_repr(): ) -def test_simulation_trial_result_str(): +def test_simulation_trial_result_str() -> None: assert ( str( cirq.SimulationTrialResult( @@ -349,7 +351,7 @@ def test_simulation_trial_result_str(): ) -def test_pretty_print(): +def test_pretty_print() -> None: result = cirq.SimulationTrialResult(cirq.ParamResolver(), {}, np.array([1])) # Test Jupyter console output from @@ -371,7 +373,7 @@ def text(self, to_print): @duet.sync -async def test_async_sample(): +async def test_async_sample() -> None: m = {'mock': np.array([[[0]], [[1]]])} simulator = FakeSimulatesSamples(m) @@ -381,16 +383,16 @@ async def test_async_sample(): np.testing.assert_equal(result.records, m) -def test_simulation_trial_result_qubit_map(): +def test_simulation_trial_result_qubit_map() -> None: q = cirq.LineQubit.range(2) result = cirq.Simulator().simulate(cirq.Circuit([cirq.CZ(q[0], q[1])])) assert result.qubit_map == {q[0]: 0, q[1]: 1} - result = cirq.DensityMatrixSimulator().simulate(cirq.Circuit([cirq.CZ(q[0], q[1])])) - assert result.qubit_map == {q[0]: 0, q[1]: 1} + result2 = cirq.DensityMatrixSimulator().simulate(cirq.Circuit([cirq.CZ(q[0], q[1])])) + assert result2.qubit_map == {q[0]: 0, q[1]: 1} -def test_sample_repeated_measurement_keys(): +def test_sample_repeated_measurement_keys() -> None: q = cirq.LineQubit.range(2) circuit = cirq.Circuit() circuit.append( @@ -408,7 +410,7 @@ def test_sample_repeated_measurement_keys(): assert len(result.records['b'][0]) == 2 -def test_classical_controls_go_to_suffix_if_corresponding_measurement_does(): +def test_classical_controls_go_to_suffix_if_corresponding_measurement_does() -> None: subcircuit = cirq.CircuitOperation(cirq.FrozenCircuit()).with_classical_controls('a') m = cirq.measure(cirq.LineQubit(0), key='a') circuit = cirq.Circuit(m, subcircuit) @@ -419,7 +421,7 @@ def test_classical_controls_go_to_suffix_if_corresponding_measurement_does(): assert suffix == circuit -def test_simulate_with_invert_mask(): +def test_simulate_with_invert_mask() -> None: q0, q1, q2, q3, q4 = cirq.LineQid.for_qid_shape((2, 3, 3, 3, 4)) c = cirq.Circuit( cirq.XPowGate(dimension=2)(q0), @@ -431,7 +433,7 @@ def test_simulate_with_invert_mask(): assert np.all(cirq.Simulator().run(c).measurements['a'] == [[0, 1, 0, 2, 3]]) -def test_monte_carlo_on_unknown_channel(): +def test_monte_carlo_on_unknown_channel() -> None: class Reset11To00(cirq.Gate): def num_qubits(self) -> int: return 2 @@ -451,8 +453,10 @@ def _kraus_(self): ) -def test_iter_definitions(): - mock_trial_result = SimulationTrialResult(params={}, measurements={}, final_simulator_state=[]) +def test_iter_definitions() -> None: + mock_trial_result: SimulationTrialResult = SimulationTrialResult( + params=cirq.ParamResolver(), measurements={}, final_simulator_state=[] + ) class FakeNonIterSimulatorImpl( SimulatesAmplitudes, SimulatesExpectationValues, SimulatesFinalState @@ -495,7 +499,7 @@ def simulate_sweep( q0 = cirq.LineQubit(0) circuit = cirq.Circuit(cirq.X(q0)) bitstrings = [0b0] - params = {} + params: cirq.ParamMappingType = {} assert non_iter_sim.compute_amplitudes_sweep(circuit, bitstrings, params) == [[1.0]] amp_iter = non_iter_sim.compute_amplitudes_sweep_iter(circuit, bitstrings, params) assert next(amp_iter) == [1.0] @@ -510,7 +514,7 @@ def simulate_sweep( assert next(state_iter) == mock_trial_result -def test_missing_iter_definitions(): +def test_missing_iter_definitions() -> None: class FakeMissingIterSimulatorImpl( SimulatesAmplitudes, SimulatesExpectationValues, SimulatesFinalState ): @@ -520,7 +524,7 @@ class FakeMissingIterSimulatorImpl( q0 = cirq.LineQubit(0) circuit = cirq.Circuit(cirq.X(q0)) bitstrings = [0b0] - params = {} + params: cirq.ParamMappingType = {} with pytest.raises(RecursionError): missing_iter_sim.compute_amplitudes_sweep(circuit, bitstrings, params) with pytest.raises(RecursionError): @@ -541,7 +545,7 @@ class FakeMissingIterSimulatorImpl( next(state_iter) -def test_trial_result_initializer(): +def test_trial_result_initializer() -> None: resolver = cirq.ParamResolver() state = 3 x = SimulationTrialResult(resolver, {}, state) diff --git a/cirq-core/cirq/sim/sparse_simulator.py b/cirq-core/cirq/sim/sparse_simulator.py index 50c194dce91..43e933cd244 100644 --- a/cirq-core/cirq/sim/sparse_simulator.py +++ b/cirq-core/cirq/sim/sparse_simulator.py @@ -239,7 +239,7 @@ def __init__( self._dtype = dtype self._state_vector: np.ndarray | None = None - def state_vector(self, copy: bool = False): + def state_vector(self, copy: bool = False) -> np.ndarray: """Return the state vector at this point in the computation. The state is returned in the computational basis with these basis diff --git a/cirq-core/cirq/sim/sparse_simulator_test.py b/cirq-core/cirq/sim/sparse_simulator_test.py index 86d8bc685ad..3fd53704104 100644 --- a/cirq-core/cirq/sim/sparse_simulator_test.py +++ b/cirq-core/cirq/sim/sparse_simulator_test.py @@ -25,14 +25,14 @@ import cirq -def test_invalid_dtype(): +def test_invalid_dtype() -> None: with pytest.raises(ValueError, match='complex'): - cirq.Simulator(dtype=np.int32) + cirq.Simulator(dtype=np.int32) # type: ignore[arg-type] @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_no_measurements(dtype: type[np.complexfloating], split: bool): +def test_run_no_measurements(dtype: type[np.complexfloating], split: bool) -> None: q0, q1 = cirq.LineQubit.range(2) simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split) @@ -43,7 +43,7 @@ def test_run_no_measurements(dtype: type[np.complexfloating], split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_no_results(dtype: type[np.complexfloating], split: bool): +def test_run_no_results(dtype: type[np.complexfloating], split: bool) -> None: q0, q1 = cirq.LineQubit.range(2) simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split) @@ -54,7 +54,7 @@ def test_run_no_results(dtype: type[np.complexfloating], split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_empty_circuit(dtype: type[np.complexfloating], split: bool): +def test_run_empty_circuit(dtype: type[np.complexfloating], split: bool) -> None: simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split) with pytest.raises(ValueError, match="no measurements"): simulator.run(cirq.Circuit()) @@ -62,7 +62,7 @@ def test_run_empty_circuit(dtype: type[np.complexfloating], split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_reset(dtype: type[np.complexfloating], split: bool): +def test_run_reset(dtype: type[np.complexfloating], split: bool) -> None: q0, q1 = cirq.LineQid.for_qid_shape((2, 3)) simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split) circuit = cirq.Circuit( @@ -82,7 +82,7 @@ def test_run_reset(dtype: type[np.complexfloating], split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_bit_flips(dtype: type[np.complexfloating], split: bool): +def test_run_bit_flips(dtype: type[np.complexfloating], split: bool) -> None: q0, q1 = cirq.LineQubit.range(2) simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split) for b0 in [0, 1]: @@ -96,7 +96,7 @@ def test_run_bit_flips(dtype: type[np.complexfloating], split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_measure_at_end_no_repetitions(dtype: type[np.complexfloating], split: bool): +def test_run_measure_at_end_no_repetitions(dtype: type[np.complexfloating], split: bool) -> None: q0, q1 = cirq.LineQubit.range(2) simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split) with mock.patch.object(simulator, '_core_iterator', wraps=simulator._core_iterator) as mock_sim: @@ -113,7 +113,7 @@ def test_run_measure_at_end_no_repetitions(dtype: type[np.complexfloating], spli assert mock_sim.call_count == 0 -def test_run_repetitions_terminal_measurement_stochastic(): +def test_run_repetitions_terminal_measurement_stochastic() -> None: q = cirq.LineQubit(0) c = cirq.Circuit(cirq.H(q), cirq.measure(q, key='q')) results = cirq.Simulator().run(c, repetitions=10000) @@ -122,7 +122,7 @@ def test_run_repetitions_terminal_measurement_stochastic(): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_repetitions_measure_at_end(dtype: type[np.complexfloating], split: bool): +def test_run_repetitions_measure_at_end(dtype: type[np.complexfloating], split: bool) -> None: q0, q1 = cirq.LineQubit.range(2) simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split) with mock.patch.object(simulator, '_core_iterator', wraps=simulator._core_iterator) as mock_sim: @@ -142,7 +142,7 @@ def test_run_repetitions_measure_at_end(dtype: type[np.complexfloating], split: @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_invert_mask_measure_not_terminal(dtype: type[np.complexfloating], split: bool): +def test_run_invert_mask_measure_not_terminal(dtype: type[np.complexfloating], split: bool) -> None: q0, q1 = cirq.LineQubit.range(2) simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split) with mock.patch.object(simulator, '_core_iterator', wraps=simulator._core_iterator) as mock_sim: @@ -163,7 +163,9 @@ def test_run_invert_mask_measure_not_terminal(dtype: type[np.complexfloating], s @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_partial_invert_mask_measure_not_terminal(dtype: type[np.complexfloating], split: bool): +def test_run_partial_invert_mask_measure_not_terminal( + dtype: type[np.complexfloating], split: bool +) -> None: q0, q1 = cirq.LineQubit.range(2) simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split) with mock.patch.object(simulator, '_core_iterator', wraps=simulator._core_iterator) as mock_sim: @@ -184,7 +186,9 @@ def test_run_partial_invert_mask_measure_not_terminal(dtype: type[np.complexfloa @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_measurement_not_terminal_no_repetitions(dtype: type[np.complexfloating], split: bool): +def test_run_measurement_not_terminal_no_repetitions( + dtype: type[np.complexfloating], split: bool +) -> None: q0, q1 = cirq.LineQubit.range(2) simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split) with mock.patch.object(simulator, '_core_iterator', wraps=simulator._core_iterator) as mock_sim: @@ -208,7 +212,9 @@ def test_run_measurement_not_terminal_no_repetitions(dtype: type[np.complexfloat @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_repetitions_measurement_not_terminal(dtype: type[np.complexfloating], split: bool): +def test_run_repetitions_measurement_not_terminal( + dtype: type[np.complexfloating], split: bool +) -> None: q0, q1 = cirq.LineQubit.range(2) simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split) with mock.patch.object(simulator, '_core_iterator', wraps=simulator._core_iterator) as mock_sim: @@ -233,7 +239,7 @@ def test_run_repetitions_measurement_not_terminal(dtype: type[np.complexfloating @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_param_resolver(dtype: type[np.complexfloating], split: bool): +def test_run_param_resolver(dtype: type[np.complexfloating], split: bool) -> None: q0, q1 = cirq.LineQubit.range(2) simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split) for b0 in [0, 1]: @@ -252,7 +258,7 @@ def test_run_param_resolver(dtype: type[np.complexfloating], split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_mixture(dtype: type[np.complexfloating], split: bool): +def test_run_mixture(dtype: type[np.complexfloating], split: bool) -> None: q0 = cirq.LineQubit(0) simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split) circuit = cirq.Circuit(cirq.bit_flip(0.5)(q0), cirq.measure(q0)) @@ -262,7 +268,7 @@ def test_run_mixture(dtype: type[np.complexfloating], split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_mixture_with_gates(dtype: type[np.complexfloating], split: bool): +def test_run_mixture_with_gates(dtype: type[np.complexfloating], split: bool) -> None: q0 = cirq.LineQubit(0) simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split, seed=23) circuit = cirq.Circuit(cirq.H(q0), cirq.phase_flip(0.5)(q0), cirq.H(q0), cirq.measure(q0)) @@ -273,7 +279,7 @@ def test_run_mixture_with_gates(dtype: type[np.complexfloating], split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_correlations(dtype: type[np.complexfloating], split: bool): +def test_run_correlations(dtype: type[np.complexfloating], split: bool) -> None: q0, q1 = cirq.LineQubit.range(2) simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split) circuit = cirq.Circuit(cirq.H(q0), cirq.CNOT(q0, q1), cirq.measure(q0, q1)) @@ -285,7 +291,7 @@ def test_run_correlations(dtype: type[np.complexfloating], split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_measure_multiple_qubits(dtype: type[np.complexfloating], split: bool): +def test_run_measure_multiple_qubits(dtype: type[np.complexfloating], split: bool) -> None: q0, q1 = cirq.LineQubit.range(2) simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split) for b0 in [0, 1]: @@ -297,7 +303,7 @@ def test_run_measure_multiple_qubits(dtype: type[np.complexfloating], split: boo @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_run_sweeps_param_resolvers(dtype: type[np.complexfloating], split: bool): +def test_run_sweeps_param_resolvers(dtype: type[np.complexfloating], split: bool) -> None: q0, q1 = cirq.LineQubit.range(2) simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split) for b0 in [0, 1]: @@ -323,7 +329,7 @@ def test_run_sweeps_param_resolvers(dtype: type[np.complexfloating], split: bool @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate_random_unitary(dtype: type[np.complexfloating], split: bool): +def test_simulate_random_unitary(dtype: type[np.complexfloating], split: bool) -> None: q0, q1 = cirq.LineQubit.range(2) simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split) for _ in range(10): @@ -341,7 +347,7 @@ def test_simulate_random_unitary(dtype: type[np.complexfloating], split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate_no_circuit(dtype: type[np.complexfloating], split: bool): +def test_simulate_no_circuit(dtype: type[np.complexfloating], split: bool) -> None: q0, q1 = cirq.LineQubit.range(2) simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split) circuit = cirq.Circuit() @@ -352,7 +358,7 @@ def test_simulate_no_circuit(dtype: type[np.complexfloating], split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate(dtype: type[np.complexfloating], split: bool): +def test_simulate(dtype: type[np.complexfloating], split: bool) -> None: q0, q1 = cirq.LineQubit.range(2) simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split) circuit = cirq.Circuit(cirq.H(q0), cirq.H(q1)) @@ -374,7 +380,7 @@ def _mixture_(self): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate_qudits(dtype: type[np.complexfloating], split: bool): +def test_simulate_qudits(dtype: type[np.complexfloating], split: bool) -> None: q0, q1 = cirq.LineQid.for_qid_shape((3, 4)) simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split) circuit = cirq.Circuit(cirq.XPowGate(dimension=3)(q0), cirq.XPowGate(dimension=4)(q1) ** 3) @@ -387,7 +393,7 @@ def test_simulate_qudits(dtype: type[np.complexfloating], split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate_mixtures(dtype: type[np.complexfloating], split: bool): +def test_simulate_mixtures(dtype: type[np.complexfloating], split: bool) -> None: q0 = cirq.LineQubit(0) simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split) circuit = cirq.Circuit(cirq.bit_flip(0.5)(q0), cirq.measure(q0)) @@ -405,7 +411,7 @@ def test_simulate_mixtures(dtype: type[np.complexfloating], split: bool): @pytest.mark.parametrize( 'dtype, split', itertools.product([np.complex64, np.complex128], [True, False]) ) -def test_simulate_qudit_mixtures(dtype: type[np.complexfloating], split: bool): +def test_simulate_qudit_mixtures(dtype: type[np.complexfloating], split: bool) -> None: q0 = cirq.LineQid(0, 3) simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split) mixture = _TestMixture( @@ -431,7 +437,7 @@ def test_simulate_qudit_mixtures(dtype: type[np.complexfloating], split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate_bit_flips(dtype: type[np.complexfloating], split: bool): +def test_simulate_bit_flips(dtype: type[np.complexfloating], split: bool) -> None: q0, q1 = cirq.LineQubit.range(2) simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split) for b0 in [0, 1]: @@ -456,7 +462,7 @@ def test_simulate_initial_state( dtype: type[np.complexfloating], split: bool, initial_state: int | cirq.StateVectorSimulationState, -): +) -> None: q0, q1 = cirq.LineQubit.range(2) simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split) for b0 in [0, 1]: @@ -470,7 +476,7 @@ def test_simulate_initial_state( @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulation_state(dtype: type[np.complexfloating], split: bool): +def test_simulation_state(dtype: type[np.complexfloating], split: bool) -> None: q0, q1 = cirq.LineQubit.range(2) simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split) for b0 in [0, 1]: @@ -485,7 +491,7 @@ def test_simulation_state(dtype: type[np.complexfloating], split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate_qubit_order(dtype: type[np.complexfloating], split: bool): +def test_simulate_qubit_order(dtype: type[np.complexfloating], split: bool) -> None: q0, q1 = cirq.LineQubit.range(2) simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split) for b0 in [0, 1]: @@ -499,7 +505,7 @@ def test_simulate_qubit_order(dtype: type[np.complexfloating], split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate_param_resolver(dtype: type[np.complexfloating], split: bool): +def test_simulate_param_resolver(dtype: type[np.complexfloating], split: bool) -> None: q0, q1 = cirq.LineQubit.range(2) simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split) for b0 in [0, 1]: @@ -518,7 +524,7 @@ def test_simulate_param_resolver(dtype: type[np.complexfloating], split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate_measure_multiple_qubits(dtype: type[np.complexfloating], split: bool): +def test_simulate_measure_multiple_qubits(dtype: type[np.complexfloating], split: bool) -> None: q0, q1 = cirq.LineQubit.range(2) simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split) for b0 in [0, 1]: @@ -530,7 +536,7 @@ def test_simulate_measure_multiple_qubits(dtype: type[np.complexfloating], split @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate_sweeps_param_resolver(dtype: type[np.complexfloating], split: bool): +def test_simulate_sweeps_param_resolver(dtype: type[np.complexfloating], split: bool) -> None: q0, q1 = cirq.LineQubit.range(2) simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split) for b0 in [0, 1]: @@ -557,7 +563,7 @@ def test_simulate_sweeps_param_resolver(dtype: type[np.complexfloating], split: @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate_moment_steps(dtype: type[np.complexfloating], split: bool): +def test_simulate_moment_steps(dtype: type[np.complexfloating], split: bool) -> None: q0, q1 = cirq.LineQubit.range(2) circuit = cirq.Circuit(cirq.H(q0), cirq.H(q1), cirq.H(q0), cirq.H(q1)) simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split) @@ -570,7 +576,7 @@ def test_simulate_moment_steps(dtype: type[np.complexfloating], split: bool): @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate_moment_steps_empty_circuit(dtype: type[np.complexfloating], split: bool): +def test_simulate_moment_steps_empty_circuit(dtype: type[np.complexfloating], split: bool) -> None: circuit = cirq.Circuit() simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split) for step in simulator.simulate_moment_steps(circuit): @@ -581,7 +587,7 @@ def test_simulate_moment_steps_empty_circuit(dtype: type[np.complexfloating], sp @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate_moment_steps_sample(dtype: type[np.complexfloating], split: bool): +def test_simulate_moment_steps_sample(dtype: type[np.complexfloating], split: bool) -> None: q0, q1 = cirq.LineQubit.range(2) circuit = cirq.Circuit(cirq.H(q0), cirq.CNOT(q0, q1)) simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split) @@ -604,7 +610,7 @@ def test_simulate_moment_steps_sample(dtype: type[np.complexfloating], split: bo @pytest.mark.parametrize('split', [True, False]) def test_simulate_moment_steps_intermediate_measurement( dtype: type[np.complexfloating], split: bool -): +) -> None: q0 = cirq.LineQubit(0) circuit = cirq.Circuit(cirq.H(q0), cirq.measure(q0), cirq.H(q0)) simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split) @@ -621,7 +627,7 @@ def test_simulate_moment_steps_intermediate_measurement( @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate_expectation_values(dtype: type[np.complexfloating], split: bool): +def test_simulate_expectation_values(dtype: type[np.complexfloating], split: bool) -> None: # Compare with test_expectation_from_state_vector_two_qubit_states # in file: cirq/ops/linear_combinations_test.py q0, q1 = cirq.LineQubit.range(2) @@ -646,7 +652,9 @@ def test_simulate_expectation_values(dtype: type[np.complexfloating], split: boo @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate_expectation_values_terminal_measure(dtype: type[np.complexfloating], split: bool): +def test_simulate_expectation_values_terminal_measure( + dtype: type[np.complexfloating], split: bool +) -> None: q0 = cirq.LineQubit(0) circuit = cirq.Circuit(cirq.H(q0), cirq.measure(q0)) obs = cirq.Z(q0) @@ -684,7 +692,9 @@ def test_simulate_expectation_values_terminal_measure(dtype: type[np.complexfloa @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('split', [True, False]) -def test_simulate_expectation_values_qubit_order(dtype: type[np.complexfloating], split: bool): +def test_simulate_expectation_values_qubit_order( + dtype: type[np.complexfloating], split: bool +) -> None: q0, q1, q2 = cirq.LineQubit.range(3) circuit = cirq.Circuit(cirq.H(q0), cirq.H(q1), cirq.X(q2)) obs = cirq.X(q0) + cirq.X(q1) - cirq.Z(q2) @@ -698,7 +708,7 @@ def test_simulate_expectation_values_qubit_order(dtype: type[np.complexfloating] assert cirq.approx_eq(result_flipped[0], 3, atol=1e-6) -def test_invalid_run_no_unitary(): +def test_invalid_run_no_unitary() -> None: class NoUnitary(cirq.testing.SingleQubitGate): pass @@ -710,7 +720,7 @@ class NoUnitary(cirq.testing.SingleQubitGate): simulator.run(circuit) -def test_allocates_new_state(): +def test_allocates_new_state() -> None: class NoUnitary(cirq.testing.SingleQubitGate): def _has_unitary_(self): return True @@ -728,7 +738,7 @@ def _apply_unitary_(self, args: cirq.ApplyUnitaryArgs): assert not initial_state is result.state_vector() -def test_does_not_modify_initial_state(): +def test_does_not_modify_initial_state() -> None: q0 = cirq.LineQubit(0) simulator = cirq.Simulator() @@ -753,7 +763,7 @@ def _apply_unitary_(self, args: cirq.ApplyUnitaryArgs): ) -def test_simulator_step_state_mixin(): +def test_simulator_step_state_mixin() -> None: qubits = cirq.LineQubit.range(2) args = cirq.StateVectorSimulationState( available_buffer=np.array([0, 1, 0, 0]).reshape((2, 2)), @@ -771,7 +781,7 @@ def test_simulator_step_state_mixin(): assert result.dirac_notation() == '|01⟩' -def test_sparse_simulator_repr(): +def test_sparse_simulator_repr() -> None: qubits = cirq.LineQubit.range(2) args = cirq.StateVectorSimulationState( available_buffer=np.array([0, 1, 0, 0]).reshape((2, 2)), @@ -795,7 +805,7 @@ def _decompose_(self, qubits): return cirq.H.on_each(*qubits) -def test_simulates_composite(): +def test_simulates_composite() -> None: c = cirq.Circuit(MultiHTestGate().on(*cirq.LineQubit.range(2))) expected = np.array([0.5] * 4) np.testing.assert_allclose( @@ -804,7 +814,7 @@ def test_simulates_composite(): np.testing.assert_allclose(cirq.Simulator().simulate(c).state_vector(), expected) -def test_simulate_measurement_inversions(): +def test_simulate_measurement_inversions() -> None: q = cirq.NamedQubit('q') c = cirq.Circuit(cirq.measure(q, key='q', invert_mask=(True,))) @@ -814,7 +824,7 @@ def test_simulate_measurement_inversions(): assert cirq.Simulator().simulate(c).measurements == {'q': np.array([False])} -def test_works_on_pauli_string_phasor(): +def test_works_on_pauli_string_phasor() -> None: a, b = cirq.LineQubit.range(2) c = cirq.Circuit(np.exp(0.5j * np.pi * cirq.X(a) * cirq.X(b))) sim = cirq.Simulator() @@ -822,7 +832,7 @@ def test_works_on_pauli_string_phasor(): np.testing.assert_allclose(result.reshape(4), np.array([0, 0, 0, 1j]), atol=1e-8) -def test_works_on_pauli_string(): +def test_works_on_pauli_string() -> None: a, b = cirq.LineQubit.range(2) c = cirq.Circuit(cirq.X(a) * cirq.X(b)) sim = cirq.Simulator() @@ -830,7 +840,7 @@ def test_works_on_pauli_string(): np.testing.assert_allclose(result.reshape(4), np.array([0, 0, 0, 1]), atol=1e-8) -def test_measure_at_end_invert_mask(): +def test_measure_at_end_invert_mask() -> None: simulator = cirq.Simulator() a = cirq.NamedQubit('a') circuit = cirq.Circuit(cirq.measure(a, key='a', invert_mask=(True,))) @@ -838,7 +848,7 @@ def test_measure_at_end_invert_mask(): np.testing.assert_equal(result.measurements['a'], np.array([[1]] * 4)) -def test_measure_at_end_invert_mask_multiple_qubits(): +def test_measure_at_end_invert_mask_multiple_qubits() -> None: simulator = cirq.Simulator() a, b, c = cirq.LineQubit.range(3) circuit = cirq.Circuit( @@ -850,7 +860,7 @@ def test_measure_at_end_invert_mask_multiple_qubits(): np.testing.assert_equal(result.measurements['bc'], np.array([[0, 1]] * 4)) -def test_measure_at_end_invert_mask_partial(): +def test_measure_at_end_invert_mask_partial() -> None: simulator = cirq.Simulator() a, _, c = cirq.LineQubit.range(3) circuit = cirq.Circuit(cirq.measure(a, c, key='ac', invert_mask=(True,))) @@ -858,7 +868,7 @@ def test_measure_at_end_invert_mask_partial(): np.testing.assert_equal(result.measurements['ac'], np.array([[1, 0]] * 4)) -def test_qudit_invert_mask(): +def test_qudit_invert_mask() -> None: q0, q1, q2, q3, q4 = cirq.LineQid.for_qid_shape((2, 3, 3, 3, 4)) c = cirq.Circuit( cirq.XPowGate(dimension=2)(q0), @@ -870,7 +880,7 @@ def test_qudit_invert_mask(): assert np.all(cirq.Simulator().run(c).measurements['a'] == [[0, 1, 0, 2, 3]]) -def test_compute_amplitudes(): +def test_compute_amplitudes() -> None: a, b = cirq.LineQubit.range(2) c = cirq.Circuit(cirq.X(a), cirq.H(a), cirq.H(b)) sim = cirq.Simulator() @@ -885,16 +895,16 @@ def test_compute_amplitudes(): np.testing.assert_allclose(np.array(result), np.array([-0.5, 0.5, -0.5])) -def test_compute_amplitudes_bad_input(): +def test_compute_amplitudes_bad_input() -> None: a, b = cirq.LineQubit.range(2) c = cirq.Circuit(cirq.X(a), cirq.H(a), cirq.H(b)) sim = cirq.Simulator() with pytest.raises(ValueError, match='1-dimensional'): - _ = sim.compute_amplitudes(c, np.array([[0, 0]])) + _ = sim.compute_amplitudes(c, np.array([[0, 0]])) # type: ignore[arg-type] -def test_sample_from_amplitudes(): +def test_sample_from_amplitudes() -> None: q0, q1 = cirq.LineQubit.range(2) circuit = cirq.Circuit(cirq.H(q0), cirq.CNOT(q0, q1), cirq.X(q1)) sim = cirq.Simulator(seed=1) @@ -905,7 +915,7 @@ def test_sample_from_amplitudes(): assert 3 not in result -def test_sample_from_amplitudes_teleport(): +def test_sample_from_amplitudes_teleport() -> None: q0, q1, q2 = cirq.LineQubit.range(3) # Initialize q0 to some state, teleport it to q2, then clean up. circuit = cirq.Circuit( @@ -936,7 +946,7 @@ def test_sample_from_amplitudes_teleport(): assert result_c[1] < 20 -def test_sample_from_amplitudes_nonunitary_fails(): +def test_sample_from_amplitudes_nonunitary_fails() -> None: q0, q1 = cirq.LineQubit.range(2) sim = cirq.Simulator(seed=1) @@ -951,7 +961,7 @@ def test_sample_from_amplitudes_nonunitary_fails(): _ = sim.sample_from_amplitudes(circuit2, {}, sim._prng) -def test_run_sweep_parameters_not_resolved(): +def test_run_sweep_parameters_not_resolved() -> None: a = cirq.LineQubit(0) simulator = cirq.Simulator() circuit = cirq.Circuit(cirq.XPowGate(exponent=sympy.Symbol('a'))(a), cirq.measure(a)) @@ -959,7 +969,7 @@ def test_run_sweep_parameters_not_resolved(): _ = simulator.run_sweep(circuit, cirq.ParamResolver({})) -def test_simulate_sweep_parameters_not_resolved(): +def test_simulate_sweep_parameters_not_resolved() -> None: a = cirq.LineQubit(0) simulator = cirq.Simulator() circuit = cirq.Circuit(cirq.XPowGate(exponent=sympy.Symbol('a'))(a), cirq.measure(a)) @@ -967,7 +977,7 @@ def test_simulate_sweep_parameters_not_resolved(): _ = simulator.simulate_sweep(circuit, cirq.ParamResolver({})) -def test_random_seed(): +def test_random_seed() -> None: a = cirq.NamedQubit('a') circuit = cirq.Circuit(cirq.X(a) ** 0.5, cirq.measure(a)) @@ -986,7 +996,7 @@ def test_random_seed(): ) -def test_random_seed_does_not_modify_global_state_terminal_measurements(): +def test_random_seed_does_not_modify_global_state_terminal_measurements() -> None: a = cirq.NamedQubit('a') circuit = cirq.Circuit(cirq.X(a) ** 0.5, cirq.measure(a)) @@ -1001,7 +1011,7 @@ def test_random_seed_does_not_modify_global_state_terminal_measurements(): assert result1 == result2 -def test_random_seed_does_not_modify_global_state_non_terminal_measurements(): +def test_random_seed_does_not_modify_global_state_non_terminal_measurements() -> None: a = cirq.NamedQubit('a') circuit = cirq.Circuit( cirq.X(a) ** 0.5, cirq.measure(a, key='a0'), cirq.X(a) ** 0.5, cirq.measure(a, key='a1') @@ -1018,7 +1028,7 @@ def test_random_seed_does_not_modify_global_state_non_terminal_measurements(): assert result1 == result2 -def test_random_seed_does_not_modify_global_state_mixture(): +def test_random_seed_does_not_modify_global_state_mixture() -> None: a = cirq.NamedQubit('a') circuit = cirq.Circuit(cirq.depolarize(0.5).on(a), cirq.measure(a)) @@ -1033,7 +1043,7 @@ def test_random_seed_does_not_modify_global_state_mixture(): assert result1 == result2 -def test_random_seed_terminal_measurements_deterministic(): +def test_random_seed_terminal_measurements_deterministic() -> None: a = cirq.NamedQubit('a') circuit = cirq.Circuit(cirq.X(a) ** 0.5, cirq.measure(a, key='a')) sim = cirq.Simulator(seed=1234) @@ -1111,7 +1121,7 @@ def test_random_seed_terminal_measurements_deterministic(): ) -def test_random_seed_non_terminal_measurements_deterministic(): +def test_random_seed_non_terminal_measurements_deterministic() -> None: a = cirq.NamedQubit('a') circuit = cirq.Circuit( cirq.X(a) ** 0.5, cirq.measure(a, key='a'), cirq.X(a) ** 0.5, cirq.measure(a, key='b') @@ -1190,7 +1200,7 @@ def test_random_seed_non_terminal_measurements_deterministic(): ) -def test_random_seed_mixture_deterministic(): +def test_random_seed_mixture_deterministic() -> None: a = cirq.NamedQubit('a') circuit = cirq.Circuit( cirq.depolarize(0.9).on(a), @@ -1239,7 +1249,7 @@ def test_random_seed_mixture_deterministic(): ) -def test_entangled_reset_does_not_break_randomness(): +def test_entangled_reset_does_not_break_randomness() -> None: """Test for bad assumptions on caching the wave function on general channels. A previous version of cirq made the mistake of assuming that it was okay to @@ -1258,7 +1268,7 @@ def test_entangled_reset_does_not_break_randomness(): assert 10 <= counts[1] <= 90 -def test_overlapping_measurements_at_end(): +def test_overlapping_measurements_at_end() -> None: a, b = cirq.LineQubit.range(2) circuit = cirq.Circuit( cirq.H(a), @@ -1282,7 +1292,7 @@ def test_overlapping_measurements_at_end(): assert 10 <= counts[1] <= 90 -def test_separated_measurements(): +def test_separated_measurements() -> None: a, b = cirq.LineQubit.range(2) c = cirq.Circuit( [ @@ -1299,7 +1309,7 @@ def test_separated_measurements(): np.testing.assert_array_equal(sample['zero'].values, [0] * 10) -def test_state_vector_copy(): +def test_state_vector_copy() -> None: sim = cirq.Simulator(split_untangled_states=False) class InplaceGate(cirq.testing.SingleQubitGate): @@ -1329,7 +1339,7 @@ def _apply_unitary_(self, args): assert any(not np.array_equal(x, y) for x, y in zip(vectors, copy_of_vectors)) -def test_final_state_vector_is_not_last_object(): +def test_final_state_vector_is_not_last_object() -> None: sim = cirq.Simulator() q = cirq.LineQubit(0) @@ -1341,7 +1351,7 @@ def test_final_state_vector_is_not_last_object(): np.testing.assert_equal(result.state_vector(), initial_state) -def test_deterministic_gate_noise(): +def test_deterministic_gate_noise() -> None: q = cirq.LineQubit(0) circuit = cirq.Circuit(cirq.I(q), cirq.measure(q)) @@ -1359,7 +1369,7 @@ def test_deterministic_gate_noise(): assert result1 != result3 -def test_nondeterministic_mixture_noise(): +def test_nondeterministic_mixture_noise() -> None: q = cirq.LineQubit(0) circuit = cirq.Circuit(cirq.I(q), cirq.measure(q)) @@ -1370,12 +1380,12 @@ def test_nondeterministic_mixture_noise(): assert result1 != result2 -def test_pure_state_creation(): +def test_pure_state_creation() -> None: sim = cirq.Simulator() qids = cirq.LineQubit.range(3) shape = cirq.qid_shape(qids) args = sim._create_simulation_state(1, qids) - values = list(args.values()) + values = list(args.values()) # type: ignore[attr-defined] arg = ( values[0] .kronecker_product(values[1]) @@ -1386,7 +1396,7 @@ def test_pure_state_creation(): np.testing.assert_allclose(arg.target_tensor, expected.reshape(shape)) -def test_noise_model(): +def test_noise_model() -> None: q = cirq.LineQubit(0) circuit = cirq.Circuit(cirq.H(q), cirq.measure(q)) @@ -1397,7 +1407,7 @@ def test_noise_model(): assert 20 <= sum(result.measurements['q(0)'])[0] < 80 -def test_separated_states_str_does_not_merge(): +def test_separated_states_str_does_not_merge() -> None: q0, q1 = cirq.LineQubit.range(2) circuit = cirq.Circuit( cirq.measure(q0), cirq.measure(q1), cirq.H(q0), cirq.global_phase_operation(0 + 1j) @@ -1419,7 +1429,7 @@ def test_separated_states_str_does_not_merge(): ) -def test_separable_non_dirac_str(): +def test_separable_non_dirac_str() -> None: circuit = cirq.Circuit() for i in range(4): circuit.append(cirq.H(cirq.LineQubit(i))) @@ -1429,7 +1439,7 @@ def test_separable_non_dirac_str(): assert '+0.j' in str(result) -def test_unseparated_states_str(): +def test_unseparated_states_str() -> None: q0, q1 = cirq.LineQubit.range(2) circuit = cirq.Circuit( cirq.measure(q0), cirq.measure(q1), cirq.H(q0), cirq.global_phase_operation(0 + 1j) @@ -1446,7 +1456,7 @@ def test_unseparated_states_str(): @pytest.mark.parametrize('split', [True, False]) -def test_measurement_preserves_phase(split: bool): +def test_measurement_preserves_phase(split: bool) -> None: c1, c2, t = cirq.LineQubit.range(3) circuit = cirq.Circuit( cirq.H(t), diff --git a/cirq-core/cirq/sim/state_vector.py b/cirq-core/cirq/sim/state_vector.py index a9b2f55089e..05096eca667 100644 --- a/cirq-core/cirq/sim/state_vector.py +++ b/cirq-core/cirq/sim/state_vector.py @@ -105,7 +105,7 @@ def dirac_notation(self, decimals: int = 2) -> str: and non-zero floats of the specified accuracy.""" return qis.dirac_notation(self.state_vector(), decimals, qid_shape=self._qid_shape) - def density_matrix_of(self, qubits: list[cirq.Qid] | None = None) -> np.ndarray: + def density_matrix_of(self, qubits: Sequence[cirq.Qid] | None = None) -> np.ndarray: r"""Returns the density matrix of the state. Calculate the density matrix for the system on the qubits provided. diff --git a/cirq-core/cirq/sim/state_vector_simulation_state.py b/cirq-core/cirq/sim/state_vector_simulation_state.py index 51f4a5e3d7f..220d8abf2b2 100644 --- a/cirq-core/cirq/sim/state_vector_simulation_state.py +++ b/cirq-core/cirq/sim/state_vector_simulation_state.py @@ -16,7 +16,7 @@ from __future__ import annotations -from typing import Any, Callable, Sequence, TYPE_CHECKING +from typing import Any, Callable, Self, Sequence, TYPE_CHECKING import numpy as np @@ -356,7 +356,7 @@ def __init__( ) super().__init__(state=state, prng=prng, qubits=qubits, classical_data=classical_data) - def add_qubits(self, qubits: Sequence[cirq.Qid]): + def add_qubits(self, qubits: Sequence[cirq.Qid]) -> Self: ret = super().add_qubits(qubits) return ( self.kronecker_product(type(self)(qubits=qubits), inplace=True) @@ -364,7 +364,7 @@ def add_qubits(self, qubits: Sequence[cirq.Qid]): else ret ) - def remove_qubits(self, qubits: Sequence[cirq.Qid]): + def remove_qubits(self, qubits: Sequence[cirq.Qid]) -> Self: ret = super().remove_qubits(qubits) if ret is not NotImplemented: return ret @@ -406,11 +406,11 @@ def __repr__(self) -> str: ) @property - def target_tensor(self): + def target_tensor(self) -> np.ndarray: return self._state._state_vector @property - def available_buffer(self): + def available_buffer(self) -> np.ndarray: return self._state._buffer diff --git a/cirq-core/cirq/sim/state_vector_simulator_test.py b/cirq-core/cirq/sim/state_vector_simulator_test.py index 99e6de306c2..91498ca4641 100644 --- a/cirq-core/cirq/sim/state_vector_simulator_test.py +++ b/cirq-core/cirq/sim/state_vector_simulator_test.py @@ -20,7 +20,7 @@ import cirq.testing -def test_state_vector_trial_result_repr(): +def test_state_vector_trial_result_repr() -> None: q0 = cirq.NamedQubit('a') final_simulator_state = cirq.StateVectorSimulationState( available_buffer=np.array([0, 1], dtype=np.complex64), @@ -47,7 +47,7 @@ def test_state_vector_trial_result_repr(): assert eval(expected_repr) == trial_result -def test_state_vector_trial_result_equality(): +def test_state_vector_trial_result_equality() -> None: eq = cirq.testing.EqualsTester() final_simulator_state = cirq.StateVectorSimulationState(initial_state=np.array([])) eq.add_equality_group( @@ -86,7 +86,7 @@ def test_state_vector_trial_result_equality(): ) -def test_state_vector_trial_result_state_mixin(): +def test_state_vector_trial_result_state_mixin() -> None: qubits = cirq.LineQubit.range(2) final_simulator_state = cirq.StateVectorSimulationState( qubits=qubits, initial_state=np.array([0, 1, 0, 0]) @@ -103,7 +103,7 @@ def test_state_vector_trial_result_state_mixin(): assert result.dirac_notation() == '|01⟩' -def test_state_vector_trial_result_qid_shape(): +def test_state_vector_trial_result_qid_shape() -> None: final_simulator_state = cirq.StateVectorSimulationState( qubits=[cirq.NamedQubit('a')], initial_state=np.array([0, 1]) ) @@ -125,7 +125,7 @@ def test_state_vector_trial_result_qid_shape(): assert cirq.qid_shape(trial_result) == (3, 2) -def test_state_vector_trial_state_vector_is_copy(): +def test_state_vector_trial_state_vector_is_copy() -> None: final_state_vector = np.array([0, 1], dtype=np.complex64) qubit_map = {cirq.NamedQubit('a'): 0} final_simulator_state = cirq.StateVectorSimulationState( @@ -137,7 +137,7 @@ def test_state_vector_trial_state_vector_is_copy(): assert trial_result.state_vector(copy=True) is not final_simulator_state.target_tensor -def test_state_vector_trial_result_no_qubits(): +def test_state_vector_trial_result_no_qubits() -> None: initial_state_vector = np.array([1], dtype=np.complex64) initial_state = initial_state_vector.reshape((2,) * 0) # reshape as tensor for 0 qubits final_simulator_state = cirq.StateVectorSimulationState(qubits=[], initial_state=initial_state) @@ -149,7 +149,7 @@ def test_state_vector_trial_result_no_qubits(): assert np.array_equal(state_vector, initial_state_vector) -def test_str_big(): +def test_str_big() -> None: qs = cirq.LineQubit.range(10) final_simulator_state = cirq.StateVectorSimulationState( prng=np.random.RandomState(0), @@ -161,7 +161,7 @@ def test_str_big(): assert 'output vector: [0.03125+0.j 0.03125+0.j 0.03125+0.j ..' in str(result) -def test_str_qudit(): +def test_str_qudit() -> None: qutrit = cirq.LineQid(0, dimension=3) final_simulator_state = cirq.StateVectorSimulationState( prng=np.random.RandomState(0), @@ -183,7 +183,7 @@ def test_str_qudit(): assert "|1⟩" in str(result) -def test_pretty_print(): +def test_pretty_print() -> None: final_simulator_state = cirq.StateVectorSimulationState( available_buffer=np.array([1]), prng=np.random.RandomState(0), diff --git a/cirq-core/cirq/sim/state_vector_test.py b/cirq-core/cirq/sim/state_vector_test.py index d6cfc204f1b..c7a9119b868 100644 --- a/cirq-core/cirq/sim/state_vector_test.py +++ b/cirq-core/cirq/sim/state_vector_test.py @@ -35,7 +35,7 @@ def use_np_transpose(request) -> Iterator[bool]: yield value -def test_state_mixin(): +def test_state_mixin() -> None: class TestClass(cirq.StateVectorMixin): def state_vector(self, copy: bool | None = None) -> np.ndarray: return np.array([0, 0, 1, 0]) @@ -60,7 +60,7 @@ def state_vector(self, copy: bool | None = None) -> np.ndarray: _ = TestClass({qubits[0]: -1, qubits[1]: 1}) -def test_sample_state_big_endian(): +def test_sample_state_big_endian() -> None: results = [] for x in range(8): state = cirq.to_valid_state_vector(x, 3) @@ -71,7 +71,7 @@ def test_sample_state_big_endian(): np.testing.assert_equal(result, expected) -def test_sample_state_partial_indices(): +def test_sample_state_partial_indices() -> None: for index in range(3): for x in range(8): state = cirq.to_valid_state_vector(x, 3) @@ -80,14 +80,14 @@ def test_sample_state_partial_indices(): ) -def test_sample_state_partial_indices_oder(): +def test_sample_state_partial_indices_oder() -> None: for x in range(8): state = cirq.to_valid_state_vector(x, 3) expected = [[bool(1 & (x >> 0)), bool(1 & (x >> 1))]] np.testing.assert_equal(cirq.sample_state_vector(state, [2, 1]), expected) -def test_sample_state_partial_indices_all_orders(): +def test_sample_state_partial_indices_all_orders() -> None: for perm in itertools.permutations([0, 1, 2]): for x in range(8): state = cirq.to_valid_state_vector(x, 3) @@ -95,7 +95,7 @@ def test_sample_state_partial_indices_all_orders(): np.testing.assert_equal(cirq.sample_state_vector(state, perm), expected) -def test_sample_state(): +def test_sample_state() -> None: state = np.zeros(8, dtype=np.complex64) state[0] = 1 / np.sqrt(2) state[2] = 1 / np.sqrt(2) @@ -110,12 +110,12 @@ def test_sample_state(): np.testing.assert_equal(cirq.sample_state_vector(state, [0]), [[False]]) -def test_sample_empty_state(): +def test_sample_empty_state() -> None: state = np.array([1.0]) np.testing.assert_almost_equal(cirq.sample_state_vector(state, []), np.zeros(shape=(1, 0))) -def test_sample_no_repetitions(): +def test_sample_no_repetitions() -> None: state = cirq.to_valid_state_vector(0, 3) np.testing.assert_almost_equal( cirq.sample_state_vector(state, [1], repetitions=0), np.zeros(shape=(0, 1)) @@ -125,7 +125,7 @@ def test_sample_no_repetitions(): ) -def test_sample_state_repetitions(): +def test_sample_state_repetitions() -> None: for perm in itertools.permutations([0, 1, 2]): for x in range(8): state = cirq.to_valid_state_vector(x, 3) @@ -135,7 +135,7 @@ def test_sample_state_repetitions(): np.testing.assert_equal(result, expected) -def test_sample_state_seed(): +def test_sample_state_seed() -> None: state = np.ones(2) / np.sqrt(2) samples = cirq.sample_state_vector(state, [0], repetitions=10, seed=1234) @@ -151,20 +151,20 @@ def test_sample_state_seed(): ) -def test_sample_state_negative_repetitions(): +def test_sample_state_negative_repetitions() -> None: state = cirq.to_valid_state_vector(0, 3) with pytest.raises(ValueError, match='-1'): cirq.sample_state_vector(state, [1], repetitions=-1) -def test_sample_state_not_power_of_two(): +def test_sample_state_not_power_of_two() -> None: with pytest.raises(ValueError, match='3'): cirq.sample_state_vector(np.array([1, 0, 0]), [1]) with pytest.raises(ValueError, match='5'): cirq.sample_state_vector(np.array([0, 1, 0, 0, 0]), [1]) -def test_sample_state_index_out_of_range(): +def test_sample_state_index_out_of_range() -> None: state = cirq.to_valid_state_vector(0, 3) with pytest.raises(IndexError, match='-2'): cirq.sample_state_vector(state, [-2]) @@ -172,12 +172,12 @@ def test_sample_state_index_out_of_range(): cirq.sample_state_vector(state, [3]) -def test_sample_no_indices(): +def test_sample_no_indices() -> None: state = cirq.to_valid_state_vector(0, 3) np.testing.assert_almost_equal(cirq.sample_state_vector(state, []), np.zeros(shape=(1, 0))) -def test_sample_no_indices_repetitions(): +def test_sample_no_indices_repetitions() -> None: state = cirq.to_valid_state_vector(0, 3) np.testing.assert_almost_equal( cirq.sample_state_vector(state, [], repetitions=2), np.zeros(shape=(2, 0)) @@ -185,7 +185,7 @@ def test_sample_no_indices_repetitions(): @pytest.mark.parametrize('use_np_transpose', [False, True], indirect=True) -def test_measure_state_computational_basis(use_np_transpose: bool): +def test_measure_state_computational_basis(use_np_transpose: bool) -> None: # verify patching of can_numpy_support_shape in the use_np_transpose fixture assert linalg.can_numpy_support_shape([1]) is use_np_transpose results = [] @@ -199,7 +199,7 @@ def test_measure_state_computational_basis(use_np_transpose: bool): @pytest.mark.parametrize('use_np_transpose', [False, True], indirect=True) -def test_measure_state_reshape(use_np_transpose: bool): +def test_measure_state_reshape(use_np_transpose: bool) -> None: results = [] for x in range(8): initial_state = np.reshape(cirq.to_valid_state_vector(x, 3), [2] * 3) @@ -211,7 +211,7 @@ def test_measure_state_reshape(use_np_transpose: bool): @pytest.mark.parametrize('use_np_transpose', [False, True], indirect=True) -def test_measure_state_partial_indices(use_np_transpose: bool): +def test_measure_state_partial_indices(use_np_transpose: bool) -> None: for index in range(3): for x in range(8): initial_state = cirq.to_valid_state_vector(x, 3) @@ -221,7 +221,7 @@ def test_measure_state_partial_indices(use_np_transpose: bool): @pytest.mark.parametrize('use_np_transpose', [False, True], indirect=True) -def test_measure_state_partial_indices_order(use_np_transpose: bool): +def test_measure_state_partial_indices_order(use_np_transpose: bool) -> None: for x in range(8): initial_state = cirq.to_valid_state_vector(x, 3) bits, state = cirq.measure_state_vector(initial_state, [2, 1]) @@ -230,7 +230,7 @@ def test_measure_state_partial_indices_order(use_np_transpose: bool): @pytest.mark.parametrize('use_np_transpose', [False, True], indirect=True) -def test_measure_state_partial_indices_all_orders(use_np_transpose: bool): +def test_measure_state_partial_indices_all_orders(use_np_transpose: bool) -> None: for perm in itertools.permutations([0, 1, 2]): for x in range(8): initial_state = cirq.to_valid_state_vector(x, 3) @@ -240,7 +240,7 @@ def test_measure_state_partial_indices_all_orders(use_np_transpose: bool): @pytest.mark.parametrize('use_np_transpose', [False, True], indirect=True) -def test_measure_state_collapse(use_np_transpose: bool): +def test_measure_state_collapse(use_np_transpose: bool) -> None: initial_state = np.zeros(8, dtype=np.complex64) initial_state[0] = 1 / np.sqrt(2) initial_state[2] = 1 / np.sqrt(2) @@ -264,7 +264,7 @@ def test_measure_state_collapse(use_np_transpose: bool): @pytest.mark.parametrize('use_np_transpose', [False, True], indirect=True) -def test_measure_state_seed(use_np_transpose: bool): +def test_measure_state_seed(use_np_transpose: bool) -> None: n = 10 initial_state = np.ones(2**n) / 2 ** (n / 2) @@ -284,7 +284,7 @@ def test_measure_state_seed(use_np_transpose: bool): @pytest.mark.parametrize('use_np_transpose', [False, True], indirect=True) -def test_measure_state_out_is_state(use_np_transpose: bool): +def test_measure_state_out_is_state(use_np_transpose: bool) -> None: initial_state = np.zeros(8, dtype=np.complex64) initial_state[0] = 1 / np.sqrt(2) initial_state[2] = 1 / np.sqrt(2) @@ -296,7 +296,7 @@ def test_measure_state_out_is_state(use_np_transpose: bool): @pytest.mark.parametrize('use_np_transpose', [False, True], indirect=True) -def test_measure_state_out_is_not_state(use_np_transpose: bool): +def test_measure_state_out_is_not_state(use_np_transpose: bool) -> None: initial_state = np.zeros(8, dtype=np.complex64) initial_state[0] = 1 / np.sqrt(2) initial_state[2] = 1 / np.sqrt(2) @@ -307,7 +307,7 @@ def test_measure_state_out_is_not_state(use_np_transpose: bool): @pytest.mark.parametrize('use_np_transpose', [False, True], indirect=True) -def test_measure_state_not_power_of_two(use_np_transpose: bool): +def test_measure_state_not_power_of_two(use_np_transpose: bool) -> None: with pytest.raises(ValueError, match='3'): _, _ = cirq.measure_state_vector(np.array([1, 0, 0]), [1]) with pytest.raises(ValueError, match='5'): @@ -315,7 +315,7 @@ def test_measure_state_not_power_of_two(use_np_transpose: bool): @pytest.mark.parametrize('use_np_transpose', [False, True], indirect=True) -def test_measure_state_index_out_of_range(use_np_transpose: bool): +def test_measure_state_index_out_of_range(use_np_transpose: bool) -> None: state = cirq.to_valid_state_vector(0, 3) with pytest.raises(IndexError, match='-2'): cirq.measure_state_vector(state, [-2]) @@ -324,7 +324,7 @@ def test_measure_state_index_out_of_range(use_np_transpose: bool): @pytest.mark.parametrize('use_np_transpose', [False, True], indirect=True) -def test_measure_state_no_indices(use_np_transpose: bool): +def test_measure_state_no_indices(use_np_transpose: bool) -> None: initial_state = cirq.to_valid_state_vector(0, 3) bits, state = cirq.measure_state_vector(initial_state, []) assert [] == bits @@ -332,7 +332,7 @@ def test_measure_state_no_indices(use_np_transpose: bool): @pytest.mark.parametrize('use_np_transpose', [False, True], indirect=True) -def test_measure_state_no_indices_out_is_state(use_np_transpose: bool): +def test_measure_state_no_indices_out_is_state(use_np_transpose: bool) -> None: initial_state = cirq.to_valid_state_vector(0, 3) bits, state = cirq.measure_state_vector(initial_state, [], out=initial_state) assert [] == bits @@ -341,7 +341,7 @@ def test_measure_state_no_indices_out_is_state(use_np_transpose: bool): @pytest.mark.parametrize('use_np_transpose', [False, True], indirect=True) -def test_measure_state_no_indices_out_is_not_state(use_np_transpose: bool): +def test_measure_state_no_indices_out_is_not_state(use_np_transpose: bool) -> None: initial_state = cirq.to_valid_state_vector(0, 3) out = np.zeros_like(initial_state) bits, state = cirq.measure_state_vector(initial_state, [], out=out) @@ -352,7 +352,7 @@ def test_measure_state_no_indices_out_is_not_state(use_np_transpose: bool): @pytest.mark.parametrize('use_np_transpose', [False, True], indirect=True) -def test_measure_state_empty_state(use_np_transpose: bool): +def test_measure_state_empty_state(use_np_transpose: bool) -> None: initial_state = np.array([1.0]) bits, state = cirq.measure_state_vector(initial_state, []) assert [] == bits @@ -364,12 +364,12 @@ def state_vector(self, copy: bool | None = None) -> np.ndarray: return np.array([0, 1, 0, 0]) -def test_step_result_pretty_state(): +def test_step_result_pretty_state() -> None: step_result = BasicStateVector() assert step_result.dirac_notation() == '|01⟩' -def test_step_result_density_matrix(): +def test_step_result_density_matrix() -> None: q0, q1 = cirq.LineQubit.range(2) step_result = BasicStateVector({q0: 0, q1: 1}) @@ -385,7 +385,7 @@ def test_step_result_density_matrix(): np.testing.assert_array_almost_equal(single_rho, step_result.density_matrix_of([q1])) -def test_step_result_density_matrix_invalid(): +def test_step_result_density_matrix_invalid() -> None: q0, q1 = cirq.LineQubit.range(2) step_result = BasicStateVector({q0: 0}) @@ -393,12 +393,12 @@ def test_step_result_density_matrix_invalid(): with pytest.raises(KeyError): step_result.density_matrix_of([q1]) with pytest.raises(KeyError): - step_result.density_matrix_of('junk') + step_result.density_matrix_of('junk') # type: ignore[arg-type] with pytest.raises(TypeError): - step_result.density_matrix_of(0) + step_result.density_matrix_of(0) # type: ignore[arg-type] -def test_step_result_bloch_vector(): +def test_step_result_bloch_vector() -> None: q0, q1 = cirq.LineQubit.range(2) step_result = BasicStateVector({q0: 0, q1: 1}) bloch1 = np.array([0, 0, -1]) @@ -407,7 +407,7 @@ def test_step_result_bloch_vector(): np.testing.assert_array_almost_equal(bloch0, step_result.bloch_vector_of(q0)) -def test_factor_validation(): +def test_factor_validation() -> None: args = cirq.Simulator()._create_simulation_state(0, qubits=cirq.LineQubit.range(2)) args.apply_operation(cirq.H(cirq.LineQubit(0)) ** 0.7) t = args.create_merged_state().target_tensor diff --git a/cirq-core/cirq/study/result_test.py b/cirq-core/cirq/study/result_test.py index 9cb99452227..2c71bdfdf3f 100644 --- a/cirq-core/cirq/study/result_test.py +++ b/cirq-core/cirq/study/result_test.py @@ -25,12 +25,12 @@ from cirq.study.result import _pack_digits -def test_result_init(): +def test_result_init() -> None: assert cirq.ResultDict(params=cirq.ParamResolver({}), measurements=None).repetitions == 0 assert cirq.ResultDict(params=cirq.ParamResolver({}), measurements={}).repetitions == 0 -def test_default_repetitions(): +def test_default_repetitions() -> None: class MyResult(cirq.Result): def __init__(self, records): self._records = records @@ -55,7 +55,7 @@ def data(self): assert MyResult({'a': np.zeros((5, 2, 3))}).repetitions == 5 -def test_repr(): +def test_repr() -> None: v = cirq.ResultDict( params=cirq.ParamResolver({'a': 2}), measurements={'xy': np.array([[1, 0], [0, 1]])} ) @@ -68,7 +68,7 @@ def test_repr(): cirq.testing.assert_equivalent_repr(v) -def test_construct_from_measurements(): +def test_construct_from_measurements() -> None: r = cirq.ResultDict( params=None, measurements={'a': np.array([[0, 0], [1, 1]]), 'b': np.array([[0, 0, 0], [1, 1, 1]])}, @@ -79,7 +79,7 @@ def test_construct_from_measurements(): assert np.all(r.records['b'] == np.array([[[0, 0, 0]], [[1, 1, 1]]])) -def test_construct_from_repeated_measurements(): +def test_construct_from_repeated_measurements() -> None: r = cirq.ResultDict( params=None, records={ @@ -104,13 +104,13 @@ def test_construct_from_repeated_measurements(): assert r2.repetitions == 2 -def test_empty_measurements(): +def test_empty_measurements() -> None: assert cirq.ResultDict(params=None).repetitions == 0 assert cirq.ResultDict(params=None, measurements={}).repetitions == 0 assert cirq.ResultDict(params=None, records={}).repetitions == 0 -def test_str(): +def test_str() -> None: result = cirq.ResultDict( params=cirq.ParamResolver({}), measurements={ @@ -136,7 +136,7 @@ def test_str(): assert str(result) == 'c=1, 0\nc=0, 1' -def test_df(): +def test_df() -> None: result = cirq.ResultDict( params=cirq.ParamResolver({}), measurements={ @@ -156,7 +156,7 @@ def test_df(): assert df.c.value_counts().to_dict() == {0: 3, 1: 2} -def test_df_large(): +def test_df_large() -> None: result = cirq.ResultDict( params=cirq.ParamResolver({}), measurements={ @@ -171,7 +171,7 @@ def test_df_large(): assert result.data['d'].dtype == object -def test_histogram(): +def test_histogram() -> None: result = cirq.ResultDict( params=cirq.ParamResolver({}), measurements={ @@ -188,7 +188,7 @@ def test_histogram(): assert result.histogram(key='c') == collections.Counter({0: 3, 1: 2}) -def test_multi_measurement_histogram(): +def test_multi_measurement_histogram() -> None: result = cirq.ResultDict( params=cirq.ParamResolver({}), measurements={ @@ -222,7 +222,7 @@ def test_multi_measurement_histogram(): ) -def test_result_equality(): +def test_result_equality() -> None: et = cirq.testing.EqualsTester() et.add_equality_group( cirq.ResultDict(params=cirq.ParamResolver({}), measurements={'a': np.array([[0]] * 5)}), @@ -239,7 +239,7 @@ def test_result_equality(): ) -def test_result_addition_valid(): +def test_result_addition_valid() -> None: a = cirq.ResultDict( params=cirq.ParamResolver({'ax': 1}), measurements={ @@ -278,7 +278,7 @@ def test_result_addition_valid(): ) -def test_result_addition_invalid(): +def test_result_addition_invalid() -> None: a = cirq.ResultDict( params=cirq.ParamResolver({'ax': 1}), measurements={ @@ -324,10 +324,10 @@ def test_result_addition_invalid(): with pytest.raises(ValueError, match='different measurement shapes'): _ = a + e with pytest.raises(TypeError): - _ = a + 'junk' + _ = a + 'junk' # type: ignore[operator] -def test_qubit_keys_for_histogram(): +def test_qubit_keys_for_histogram() -> None: a, b, c = cirq.LineQubit.range(3) circuit = cirq.Circuit(cirq.measure(a, b), cirq.X(c), cirq.measure(c)) results = cirq.Simulator().run(program=circuit, repetitions=100) @@ -339,7 +339,7 @@ def test_qubit_keys_for_histogram(): assert results.histogram(key=[c]) == collections.Counter({1: 100}) -def test_text_diagram_jupyter(): +def test_text_diagram_jupyter() -> None: result = cirq.ResultDict( params=cirq.ParamResolver({}), measurements={ @@ -397,12 +397,12 @@ def test_json_bit_packing_and_dtype(use_records: bool) -> None: np.testing.assert_allclose(len(bits_json), len(digits_json) / 8, rtol=0.02) -def test_json_bit_packing_error(): +def test_json_bit_packing_error() -> None: with pytest.raises(ValueError): _pack_digits(np.ones(10), pack_bits='hi mom') -def test_json_bit_packing_force(): +def test_json_bit_packing_force() -> None: assert _pack_digits(np.ones(10, dtype=int), pack_bits='force') == _pack_digits( np.ones(10), pack_bits='auto' ) @@ -418,7 +418,7 @@ def test_json_bit_packing_force(): ) -def test_json_unpack_compat(): +def test_json_unpack_compat() -> None: """Test reading old json with serialized measurements array.""" old_json = """ { diff --git a/cirq-core/cirq/study/sweepable_test.py b/cirq-core/cirq/study/sweepable_test.py index 60b1ea1b033..4da76150702 100644 --- a/cirq-core/cirq/study/sweepable_test.py +++ b/cirq-core/cirq/study/sweepable_test.py @@ -24,66 +24,66 @@ import cirq -def test_to_resolvers_none(): +def test_to_resolvers_none() -> None: assert list(cirq.to_resolvers(None)) == [cirq.ParamResolver({})] -def test_to_resolvers_single(): +def test_to_resolvers_single() -> None: resolver = cirq.ParamResolver({}) assert list(cirq.to_resolvers(resolver)) == [resolver] assert list(cirq.to_resolvers({})) == [resolver] -def test_to_resolvers_sweep(): +def test_to_resolvers_sweep() -> None: sweep = cirq.Linspace('a', 0, 1, 10) assert list(cirq.to_resolvers(sweep)) == list(sweep) -def test_to_resolvers_iterable(): +def test_to_resolvers_iterable() -> None: resolvers = [cirq.ParamResolver({'a': 2}), cirq.ParamResolver({'a': 1})] assert list(cirq.to_resolvers(resolvers)) == resolvers assert list(cirq.to_resolvers([{'a': 2}, {'a': 1}])) == resolvers -def test_to_resolvers_iterable_sweeps(): +def test_to_resolvers_iterable_sweeps() -> None: sweeps = [cirq.Linspace('a', 0, 1, 10), cirq.Linspace('b', 0, 1, 10)] assert list(cirq.to_resolvers(sweeps)) == list(itertools.chain(*sweeps)) -def test_to_resolvers_bad(): +def test_to_resolvers_bad() -> None: with pytest.raises(TypeError, match='Unrecognized sweepable'): for _ in cirq.study.to_resolvers('nope'): pass -def test_to_sweeps_none(): +def test_to_sweeps_none() -> None: assert cirq.study.to_sweeps(None) == [cirq.UnitSweep] -def test_to_sweeps_single(): +def test_to_sweeps_single() -> None: resolver = cirq.ParamResolver({}) assert cirq.study.to_sweeps(resolver) == [cirq.UnitSweep] assert cirq.study.to_sweeps({}) == [cirq.UnitSweep] -def test_to_sweeps_sweep(): +def test_to_sweeps_sweep() -> None: sweep = cirq.Linspace('a', 0, 1, 10) assert cirq.study.to_sweeps(sweep) == [sweep] -def test_to_sweeps_iterable(): +def test_to_sweeps_iterable() -> None: resolvers = [cirq.ParamResolver({'a': 2}), cirq.ParamResolver({'a': 1})] sweeps = [cirq.study.Zip(cirq.Points('a', [2])), cirq.study.Zip(cirq.Points('a', [1]))] assert cirq.study.to_sweeps(resolvers) == sweeps assert cirq.study.to_sweeps([{'a': 2}, {'a': 1}]) == sweeps -def test_to_sweeps_iterable_sweeps(): +def test_to_sweeps_iterable_sweeps() -> None: sweeps = [cirq.Linspace('a', 0, 1, 10), cirq.Linspace('b', 0, 1, 10)] assert cirq.study.to_sweeps(sweeps) == sweeps -def test_to_sweeps_dictionary_of_list(): +def test_to_sweeps_dictionary_of_list() -> None: with pytest.warns(DeprecationWarning, match='dict_to_product_sweep'): assert cirq.study.to_sweeps({'t': [0, 2, 3]}) == cirq.study.to_sweeps( [{'t': 0}, {'t': 2}, {'t': 3}] @@ -98,12 +98,12 @@ def test_to_sweeps_dictionary_of_list(): ) -def test_to_sweeps_invalid(): +def test_to_sweeps_invalid() -> None: with pytest.raises(TypeError, match='Unrecognized sweepable'): cirq.study.to_sweeps('nope') -def test_to_sweep_sweep(): +def test_to_sweep_sweep() -> None: sweep = cirq.Linspace('a', 0, 1, 10) assert cirq.to_sweep(sweep) is sweep @@ -117,7 +117,7 @@ def test_to_sweep_sweep(): lambda: cirq.ParamResolver({sympy.Symbol('a'): 1}), ], ) -def test_to_sweep_single_resolver(r_gen): +def test_to_sweep_single_resolver(r_gen) -> None: sweep = cirq.to_sweep(r_gen()) assert isinstance(sweep, cirq.Sweep) assert list(sweep) == [cirq.ParamResolver({'a': 1})] @@ -141,18 +141,18 @@ def test_to_sweep_single_resolver(r_gen): lambda: {object(): r for r in [{'a': 1}, {'a': 1.5}]}.values(), ], ) -def test_to_sweep_resolver_list(r_list_gen): +def test_to_sweep_resolver_list(r_list_gen) -> None: sweep = cirq.to_sweep(r_list_gen()) assert isinstance(sweep, cirq.Sweep) assert list(sweep) == [cirq.ParamResolver({'a': 1}), cirq.ParamResolver({'a': 1.5})] -def test_to_sweep_type_error(): +def test_to_sweep_type_error() -> None: with pytest.raises(TypeError, match='Unexpected sweep'): - cirq.to_sweep(5) + cirq.to_sweep(5) # type: ignore[arg-type] -def test_to_sweeps_with_param_dict_appends_metadata(): +def test_to_sweeps_with_param_dict_appends_metadata() -> None: params = {'a': 1, 'b': 2, 'c': 3} unit_map = {'a': 'ns', 'b': 'ns'} @@ -167,7 +167,7 @@ def test_to_sweeps_with_param_dict_appends_metadata(): ] -def test_to_sweeps_with_param_list_appends_metadata(): +def test_to_sweeps_with_param_list_appends_metadata() -> None: resolvers = [cirq.ParamResolver({'a': 2}), cirq.ParamResolver({'a': 1})] unit_map = {'a': 'ns'} diff --git a/cirq-core/cirq/study/sweeps_test.py b/cirq-core/cirq/study/sweeps_test.py index 5acb089b16e..5b38720a4d9 100644 --- a/cirq-core/cirq/study/sweeps_test.py +++ b/cirq-core/cirq/study/sweeps_test.py @@ -20,27 +20,27 @@ import cirq -def test_product_duplicate_keys(): +def test_product_duplicate_keys() -> None: with pytest.raises(ValueError): _ = cirq.Linspace('a', 0, 9, 10) * cirq.Linspace('a', 0, 10, 11) -def test_zip_duplicate_keys(): +def test_zip_duplicate_keys() -> None: with pytest.raises(ValueError): _ = cirq.Linspace('a', 0, 9, 10) + cirq.Linspace('a', 0, 10, 11) -def test_product_wrong_type(): +def test_product_wrong_type() -> None: with pytest.raises(TypeError): - _ = cirq.Linspace('a', 0, 9, 10) * 2 + _ = cirq.Linspace('a', 0, 9, 10) * 2 # type: ignore[operator] -def test_zip_wrong_type(): +def test_zip_wrong_type() -> None: with pytest.raises(TypeError): - _ = cirq.Linspace('a', 0, 9, 10) + 2 + _ = cirq.Linspace('a', 0, 9, 10) + 2 # type: ignore[operator] -def test_linspace(): +def test_linspace() -> None: sweep = cirq.Linspace('a', 0.34, 9.16, 7) assert len(sweep) == 7 params = list(sweep.param_tuples()) @@ -49,7 +49,7 @@ def test_linspace(): assert params[-1] == (('a', 9.16),) -def test_linspace_one_point(): +def test_linspace_one_point() -> None: sweep = cirq.Linspace('a', 0.34, 9.16, 1) assert len(sweep) == 1 params = list(sweep.param_tuples()) @@ -57,7 +57,7 @@ def test_linspace_one_point(): assert params[0] == (('a', 0.34),) -def test_linspace_sympy_symbol(): +def test_linspace_sympy_symbol() -> None: a = sympy.Symbol('a') sweep = cirq.Linspace(a, 0.34, 9.16, 7) assert len(sweep) == 7 @@ -67,14 +67,14 @@ def test_linspace_sympy_symbol(): assert params[-1] == (('a', 9.16),) -def test_points(): +def test_points() -> None: sweep = cirq.Points('a', [1, 2, 3, 4]) assert len(sweep) == 4 params = list(sweep) assert len(params) == 4 -def test_zip(): +def test_zip() -> None: sweep = cirq.Points('a', [1, 2, 3]) + cirq.Points('b', [4, 5, 6, 7]) assert len(sweep) == 3 assert _values(sweep, 'a') == [1, 2, 3] @@ -86,7 +86,7 @@ def test_zip(): ] -def test_zip_longest(): +def test_zip_longest() -> None: sweep = cirq.ZipLongest(cirq.Points('a', [1, 2, 3]), cirq.Points('b', [4, 5, 6, 7])) assert tuple(sweep.param_tuples()) == ( (('a', 1), ('b', 4)), @@ -104,7 +104,7 @@ def test_zip_longest(): ) -def test_zip_longest_compatibility(): +def test_zip_longest_compatibility() -> None: sweep = cirq.Zip(cirq.Points('a', [1, 2, 3]), cirq.Points('b', [4, 5, 6])) sweep_longest = cirq.ZipLongest(cirq.Points('a', [1, 2, 3]), cirq.Points('b', [4, 5, 6])) assert tuple(sweep.param_tuples()) == tuple(sweep_longest.param_tuples()) @@ -118,7 +118,7 @@ def test_zip_longest_compatibility(): assert tuple(sweep.param_tuples()) == tuple(sweep_longest.param_tuples()) -def test_empty_zip(): +def test_empty_zip() -> None: assert len(cirq.Zip()) == 0 assert len(cirq.ZipLongest()) == 0 assert str(cirq.Zip()) == 'Zip()' @@ -126,7 +126,7 @@ def test_empty_zip(): _ = cirq.ZipLongest(cirq.Points('e', []), cirq.Points('a', [1, 2, 3])) -def test_zip_eq(): +def test_zip_eq() -> None: et = cirq.testing.EqualsTester() point_sweep1 = cirq.Points('a', [1, 2, 3]) point_sweep2 = cirq.Points('b', [4, 5, 6, 7]) @@ -146,7 +146,7 @@ def test_zip_eq(): et.add_equality_group(cirq.Zip(point_sweep1, point_sweep2)) -def test_product(): +def test_product() -> None: sweep = cirq.Points('a', [1, 2, 3]) * cirq.Points('b', [4, 5, 6, 7]) assert len(sweep) == 12 assert _values(sweep, 'a') == [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3] @@ -172,7 +172,7 @@ def test_product(): assert list(map(list, sweep.param_tuples())) == [[(str(i), 0) for i in range(1025)]] -def test_nested_product_zip(): +def test_nested_product_zip() -> None: sweep = cirq.Product( cirq.Product(cirq.Points('a', [0]), cirq.Points('b', [0])), cirq.Zip(cirq.Points('c', [0, 1]), cirq.Points('d', [0, 1])), @@ -183,7 +183,7 @@ def test_nested_product_zip(): ] -def test_zip_addition(): +def test_zip_addition() -> None: zip_sweep = cirq.Zip(cirq.Points('a', [1, 2]), cirq.Points('b', [3, 4])) zip_sweep2 = cirq.Points('c', [5, 6]) + zip_sweep assert len(zip_sweep2) == 2 @@ -192,17 +192,17 @@ def test_zip_addition(): assert _values(zip_sweep2, 'c') == [5, 6] -def test_empty_product(): +def test_empty_product() -> None: sweep = cirq.Product() assert len(sweep) == len(list(sweep)) == 1 assert str(sweep) == 'Product()' assert list(map(list, sweep.param_tuples())) == [[]] -def test_slice_access_error(): +def test_slice_access_error() -> None: sweep = cirq.Points('a', [1, 2, 3]) with pytest.raises(TypeError, match=''): - _ = sweep['junk'] + _ = sweep['junk'] # type: ignore[call-overload] with pytest.raises(IndexError): _ = sweep[4] @@ -211,7 +211,7 @@ def test_slice_access_error(): _ = sweep[-4] -def test_slice_sweep(): +def test_slice_sweep() -> None: sweep = cirq.Points('a', [1, 2, 3]) * cirq.Points('b', [4, 5, 6, 7]) first_two = sweep[:2] @@ -239,7 +239,7 @@ def test_slice_sweep(): assert len(list(single_sweep.param_tuples())) == 1 -def test_access_sweep(): +def test_access_sweep() -> None: sweep = cirq.Points('a', [1, 2, 3]) * cirq.Points('b', [4, 5, 6, 7]) first_elem = sweep[-12] @@ -259,7 +259,7 @@ def test_access_sweep(): lambda: ({sympy.Symbol('a'): a, 'b': a + 1} for a in (0, 0.5, 1, -10)), ], ) -def test_list_sweep(r_list_factory): +def test_list_sweep(r_list_factory) -> None: sweep = cirq.ListSweep(r_list_factory()) assert sweep.keys == ['a', 'b'] assert len(sweep) == 4 @@ -270,13 +270,13 @@ def test_list_sweep(r_list_factory): assert params[3] == (('a', -10), ('b', -9)) -def test_list_sweep_empty(): +def test_list_sweep_empty() -> None: assert cirq.ListSweep([]).keys == [] -def test_list_sweep_type_error(): +def test_list_sweep_type_error() -> None: with pytest.raises(TypeError, match='Not a ParamResolver'): - _ = cirq.ListSweep([cirq.ParamResolver(), 'bad']) + _ = cirq.ListSweep([cirq.ParamResolver(), 'bad']) # type: ignore[list-item] def _values(sweep, key): @@ -284,7 +284,7 @@ def _values(sweep, key): return [resolver.value_of(p) for resolver in sweep] -def test_equality(): +def test_equality() -> None: et = cirq.testing.EqualsTester() et.add_equality_group(cirq.UnitSweep, cirq.UnitSweep) @@ -320,7 +320,7 @@ def test_equality(): et.make_equality_group(lambda: cirq.ListSweep([{'x': 1}, {'x': -1}])) -def test_repr(): +def test_repr() -> None: cirq.testing.assert_equivalent_repr( cirq.study.sweeps.Product(cirq.UnitSweep), setup_code='import cirq\nfrom collections import OrderedDict', @@ -343,7 +343,7 @@ def test_repr(): ) -def test_zip_product_str(): +def test_zip_product_str() -> None: assert ( str(cirq.UnitSweep + cirq.UnitSweep + cirq.UnitSweep) == 'cirq.UnitSweep + cirq.UnitSweep + cirq.UnitSweep' @@ -362,7 +362,7 @@ def test_zip_product_str(): ) -def test_list_sweep_str(): +def test_list_sweep_str() -> None: assert ( str(cirq.UnitSweep) == '''Sweep: @@ -411,7 +411,7 @@ def test_list_sweep_str(): ) -def test_dict_to_product_sweep(): +def test_dict_to_product_sweep() -> None: assert cirq.dict_to_product_sweep({'t': [0, 2, 3]}) == ( cirq.Product(cirq.Points('t', [0, 2, 3])) ) @@ -421,7 +421,7 @@ def test_dict_to_product_sweep(): ) -def test_dict_to_zip_sweep(): +def test_dict_to_zip_sweep() -> None: assert cirq.dict_to_zip_sweep({'t': [0, 2, 3]}) == (cirq.Zip(cirq.Points('t', [0, 2, 3]))) assert cirq.dict_to_zip_sweep({'t': [0, 1], 's': [2, 3], 'r': 4}) == ( @@ -429,7 +429,7 @@ def test_dict_to_zip_sweep(): ) -def test_concat_linspace(): +def test_concat_linspace() -> None: sweep1 = cirq.Linspace('a', 0.34, 9.16, 4) sweep2 = cirq.Linspace('a', 10, 20, 4) concat_sweep = cirq.Concat(sweep1, sweep2) @@ -444,7 +444,7 @@ def test_concat_linspace(): assert params[7] == (('a', 20.0),) -def test_concat_points(): +def test_concat_points() -> None: sweep1 = cirq.Points('a', [1, 2]) sweep2 = cirq.Points('a', [3, 4, 5]) concat_sweep = cirq.Concat(sweep1, sweep2) @@ -456,7 +456,7 @@ def test_concat_points(): assert _values(concat_sweep, 'a') == [1, 2, 3, 4, 5] -def test_concat_many_points(): +def test_concat_many_points() -> None: sweep1 = cirq.Points('a', [1, 2]) sweep2 = cirq.Points('a', [3, 4, 5]) sweep3 = cirq.Points('a', [6, 7, 8]) @@ -468,7 +468,7 @@ def test_concat_many_points(): assert _values(concat_sweep, 'a') == [1, 2, 3, 4, 5, 6, 7, 8] -def test_concat_mixed(): +def test_concat_mixed() -> None: sweep1 = cirq.Linspace('a', 0, 1, 3) sweep2 = cirq.Points('a', [2, 3]) concat_sweep = cirq.Concat(sweep1, sweep2) @@ -477,7 +477,7 @@ def test_concat_mixed(): assert _values(concat_sweep, 'a') == [0.0, 0.5, 1.0, 2, 3] -def test_concat_inconsistent_keys(): +def test_concat_inconsistent_keys() -> None: sweep1 = cirq.Linspace('a', 0, 1, 3) sweep2 = cirq.Points('b', [2, 3]) @@ -485,7 +485,7 @@ def test_concat_inconsistent_keys(): cirq.Concat(sweep1, sweep2) -def test_concat_sympy_symbol(): +def test_concat_sympy_symbol() -> None: a = sympy.Symbol('a') sweep1 = cirq.Linspace(a, 0, 1, 3) sweep2 = cirq.Points(a, [2, 3]) @@ -495,7 +495,7 @@ def test_concat_sympy_symbol(): assert _values(concat_sweep, 'a') == [0.0, 0.5, 1.0, 2, 3] -def test_concat_repr_and_str(): +def test_concat_repr_and_str() -> None: sweep1 = cirq.Linspace('a', 0, 1, 3) sweep2 = cirq.Points('a', [2, 3]) concat_sweep = cirq.Concat(sweep1, sweep2) @@ -509,7 +509,7 @@ def test_concat_repr_and_str(): assert str(concat_sweep) == expected_str -def test_concat_large_sweep(): +def test_concat_large_sweep() -> None: sweep1 = cirq.Points('a', list(range(101))) sweep2 = cirq.Points('a', list(range(101, 202))) concat_sweep = cirq.Concat(sweep1, sweep2) @@ -518,7 +518,7 @@ def test_concat_large_sweep(): assert _values(concat_sweep, 'a') == list(range(101)) + list(range(101, 202)) -def test_concat_different_keys_raises(): +def test_concat_different_keys_raises() -> None: sweep1 = cirq.Linspace('a', 0, 1, 3) sweep2 = cirq.Points('b', [2, 3]) @@ -526,6 +526,6 @@ def test_concat_different_keys_raises(): _ = cirq.Concat(sweep1, sweep2) -def test_concat_empty_sweep_raises(): +def test_concat_empty_sweep_raises() -> None: with pytest.raises(ValueError, match="Concat requires at least one sweep."): _ = cirq.Concat() diff --git a/cirq-core/cirq/testing/circuit_compare_test.py b/cirq-core/cirq/testing/circuit_compare_test.py index 42d7fe15257..e0f53371a9c 100644 --- a/cirq-core/cirq/testing/circuit_compare_test.py +++ b/cirq-core/cirq/testing/circuit_compare_test.py @@ -21,7 +21,7 @@ from cirq.testing.circuit_compare import _assert_apply_unitary_works_when_axes_transposed -def test_sensitive_to_phase(): +def test_sensitive_to_phase() -> None: q = cirq.NamedQubit('q') cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent( @@ -38,7 +38,7 @@ def test_sensitive_to_phase(): ) -def test_sensitive_to_measurement_but_not_measured_phase(): +def test_sensitive_to_measurement_but_not_measured_phase() -> None: q = cirq.NamedQubit('q') with pytest.raises(AssertionError): @@ -80,7 +80,7 @@ def test_sensitive_to_measurement_but_not_measured_phase(): ) -def test_sensitive_to_measurement_toggle(): +def test_sensitive_to_measurement_toggle() -> None: q = cirq.NamedQubit('q') with pytest.raises(AssertionError): @@ -103,7 +103,7 @@ def test_sensitive_to_measurement_toggle(): ) -def test_measuring_qubits(): +def test_measuring_qubits() -> None: a, b = cirq.LineQubit.range(2) with pytest.raises(AssertionError): @@ -126,7 +126,7 @@ def test_measuring_qubits(): @pytest.mark.parametrize( 'circuit', [cirq.testing.random_circuit(cirq.LineQubit.range(2), 4, 0.5) for _ in range(5)] ) -def test_random_same_matrix(circuit): +def test_random_same_matrix(circuit) -> None: a, b = cirq.LineQubit.range(2) same = cirq.Circuit( cirq.MatrixGate(circuit.unitary(qubits_that_should_be_present=[a, b])).on(a, b) @@ -140,7 +140,7 @@ def test_random_same_matrix(circuit): cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(mutable_circuit, same) -def test_correct_qubit_ordering(): +def test_correct_qubit_ordering() -> None: a, b = cirq.LineQubit.range(2) cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent( cirq.Circuit(cirq.Z(a), cirq.Z(b), cirq.measure(b)), @@ -154,7 +154,7 @@ def test_correct_qubit_ordering(): ) -def test_known_old_failure(): +def test_known_old_failure() -> None: a, b = cirq.LineQubit.range(2) cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent( actual=cirq.Circuit( @@ -170,7 +170,7 @@ def test_known_old_failure(): ) -def test_assert_same_circuits(): +def test_assert_same_circuits() -> None: a, b = cirq.LineQubit.range(2) cirq.testing.assert_same_circuits(cirq.Circuit(cirq.H(a)), cirq.Circuit(cirq.H(a))) @@ -191,7 +191,7 @@ def test_assert_same_circuits(): ) -def test_assert_circuits_have_same_unitary_given_final_permutation(): +def test_assert_circuits_have_same_unitary_given_final_permutation() -> None: q = cirq.LineQubit.range(5) expected = cirq.Circuit([cirq.Moment(cirq.CNOT(q[2], q[1]), cirq.CNOT(q[3], q[0]))]) actual = cirq.Circuit( @@ -202,6 +202,7 @@ def test_assert_circuits_have_same_unitary_given_final_permutation(): cirq.Moment(cirq.CNOT(q[3], q[2])), ] ) + qubit_map: dict[cirq.Qid, cirq.Qid] qubit_map = {q[0]: q[2], q[2]: q[1], q[1]: q[0]} cirq.testing.assert_circuits_have_same_unitary_given_final_permutation( actual, expected, qubit_map @@ -213,6 +214,7 @@ def test_assert_circuits_have_same_unitary_given_final_permutation(): actual, expected, qubit_map=qubit_map ) + bad_qubit_map: dict[cirq.Qid, cirq.Qid] bad_qubit_map = {q[0]: q[2], q[2]: q[4], q[4]: q[0]} with pytest.raises(ValueError, match="'qubit_map' must be a mapping"): cirq.testing.assert_circuits_have_same_unitary_given_final_permutation( @@ -220,7 +222,7 @@ def test_assert_circuits_have_same_unitary_given_final_permutation(): ) -def test_assert_has_diagram(): +def test_assert_has_diagram() -> None: a, b = cirq.LineQubit.range(2) circuit = cirq.Circuit(cirq.CNOT(a, b)) cirq.testing.assert_has_diagram( @@ -263,7 +265,7 @@ def test_assert_has_diagram(): assert expected_error in ex_info.value.args[0] -def test_assert_has_consistent_apply_channel(): +def test_assert_has_consistent_apply_channel() -> None: class Correct: def _apply_channel_(self, args: cirq.ApplyChannelArgs): args.target_tensor[...] = 0 @@ -329,7 +331,7 @@ def _num_qubits_(self): cirq.testing.assert_has_consistent_apply_channel(NoApply()) -def test_assert_has_consistent_apply_unitary(): +def test_assert_has_consistent_apply_unitary() -> None: class IdentityReturningUnalteredWorkspace: def _apply_unitary_(self, args: cirq.ApplyUnitaryArgs) -> np.ndarray: return args.available_buffer @@ -461,7 +463,7 @@ class UnknownCountEffect: cirq.testing.assert_has_consistent_apply_unitary(cirq.X.on(cirq.NamedQubit('q'))) -def test_assert_has_consistent_qid_shape(): +def test_assert_has_consistent_qid_shape() -> None: class ConsistentGate(cirq.Gate): def _num_qubits_(self): return 4 @@ -560,7 +562,7 @@ class NoProtocol: cirq.testing.assert_has_consistent_qid_shape(NoProtocol()) -def test_assert_apply_unitary_works_when_axes_transposed_failure(): +def test_assert_apply_unitary_works_when_axes_transposed_failure() -> None: class BadOp: def _apply_unitary_(self, args: cirq.ApplyUnitaryArgs): # Get a more convenient view of the data. diff --git a/cirq-core/cirq/testing/consistent_channels.py b/cirq-core/cirq/testing/consistent_channels.py index d9ff04421c4..c025a1f80a3 100644 --- a/cirq-core/cirq/testing/consistent_channels.py +++ b/cirq-core/cirq/testing/consistent_channels.py @@ -21,7 +21,7 @@ import cirq -def assert_consistent_channel(gate: Any, rtol: float = 1e-5, atol: float = 1e-8): +def assert_consistent_channel(gate: Any, rtol: float = 1e-5, atol: float = 1e-8) -> None: """Asserts that a given gate has Kraus operators and that they are properly normalized.""" assert cirq.has_kraus(gate), f"Given gate {gate!r} does not return True for cirq.has_kraus." kraus_ops = cirq.kraus(gate) @@ -31,7 +31,7 @@ def assert_consistent_channel(gate: Any, rtol: float = 1e-5, atol: float = 1e-8) ) -def assert_consistent_mixture(gate: Any, rtol: float = 1e-5, atol: float = 1e-8): +def assert_consistent_mixture(gate: Any, rtol: float = 1e-5, atol: float = 1e-8) -> None: """Asserts that a given gate is a mixture and the mixture probabilities sum to one.""" assert cirq.has_mixture(gate), f"Give gate {gate!r} does not return for cirq.has_mixture." mixture = cirq.mixture(gate) diff --git a/cirq-core/cirq/testing/consistent_controlled_gate_op.py b/cirq-core/cirq/testing/consistent_controlled_gate_op.py index 80cbf0d344d..6aed7365156 100644 --- a/cirq-core/cirq/testing/consistent_controlled_gate_op.py +++ b/cirq-core/cirq/testing/consistent_controlled_gate_op.py @@ -38,7 +38,7 @@ def assert_controlled_and_controlled_by_identical( _assert_gate_consistent(gate, num_control, control_value) -def assert_controlled_unitary_consistent(gate: ops.Gate): +def assert_controlled_unitary_consistent(gate: ops.Gate) -> None: """Checks that unitary of ControlledGate(gate) is consistent with gate.controlled().""" u_orig = protocols.unitary(ops.ControlledGate(gate)) diff --git a/cirq-core/cirq/testing/consistent_decomposition.py b/cirq-core/cirq/testing/consistent_decomposition.py index 5d437fbe899..0725ed7c2b3 100644 --- a/cirq-core/cirq/testing/consistent_decomposition.py +++ b/cirq-core/cirq/testing/consistent_decomposition.py @@ -22,7 +22,9 @@ from cirq.testing import lin_alg_utils -def assert_decompose_is_consistent_with_unitary(val: Any, ignoring_global_phase: bool = False): +def assert_decompose_is_consistent_with_unitary( + val: Any, ignoring_global_phase: bool = False +) -> None: """Uses `val._unitary_` to check `val._phase_by_`'s behavior.""" __tracebackhide__ = True @@ -76,7 +78,7 @@ def _known_gate_with_no_decomposition(val: Any): return False -def assert_decompose_ends_at_default_gateset(val: Any, ignore_known_gates: bool = True): +def assert_decompose_ends_at_default_gateset(val: Any, ignore_known_gates: bool = True) -> None: """Asserts that cirq.decompose(val) ends at default cirq gateset or a known gate.""" args = () if isinstance(val, ops.Operation) else (tuple(devices.LineQid.for_gate(val)),) dec_once = protocols.decompose_once(val, [val(*args[0]) if args else val], *args) diff --git a/cirq-core/cirq/testing/consistent_phase_by.py b/cirq-core/cirq/testing/consistent_phase_by.py index 9f55a7b48a3..819969bc058 100644 --- a/cirq-core/cirq/testing/consistent_phase_by.py +++ b/cirq-core/cirq/testing/consistent_phase_by.py @@ -23,7 +23,7 @@ from cirq.testing import lin_alg_utils -def assert_phase_by_is_consistent_with_unitary(val: Any): +def assert_phase_by_is_consistent_with_unitary(val: Any) -> None: """Uses `val._unitary_` to check `val._phase_by_`'s behavior.""" original = protocols.unitary(val, None) diff --git a/cirq-core/cirq/testing/consistent_qasm.py b/cirq-core/cirq/testing/consistent_qasm.py index a8e90a45742..73f43c56e60 100644 --- a/cirq-core/cirq/testing/consistent_qasm.py +++ b/cirq-core/cirq/testing/consistent_qasm.py @@ -23,7 +23,7 @@ from cirq.testing import lin_alg_utils -def assert_qasm_is_consistent_with_unitary(val: Any): +def assert_qasm_is_consistent_with_unitary(val: Any) -> None: """Uses `val._unitary_` to check `val._qasm_`'s behavior.""" # Only test if qiskit is installed. @@ -106,7 +106,7 @@ def assert_qasm_is_consistent_with_unitary(val: Any): ) -def assert_qiskit_parsed_qasm_consistent_with_unitary(qasm, unitary): # pragma: no cover +def assert_qiskit_parsed_qasm_consistent_with_unitary(qasm, unitary) -> None: # pragma: no cover try: # We don't want to require qiskit as a dependency but # if Qiskit is installed, test QASM output against it. diff --git a/cirq-core/cirq/testing/consistent_qasm_test.py b/cirq-core/cirq/testing/consistent_qasm_test.py index f7607f1e925..1caf5f2229e 100644 --- a/cirq-core/cirq/testing/consistent_qasm_test.py +++ b/cirq-core/cirq/testing/consistent_qasm_test.py @@ -32,10 +32,10 @@ def _unitary_(self): return self.unitary @property - def qubits(self): - return cirq.LineQubit.range(self.unitary.shape[0].bit_length() - 1) + def qubits(self) -> tuple[cirq.Qid, ...]: + return tuple(cirq.LineQubit.range(self.unitary.shape[0].bit_length() - 1)) - def with_qubits(self, *new_qubits): + def with_qubits(self, *new_qubits) -> Fixed: raise NotImplementedError() def _qasm_(self, args: cirq.QasmArgs): From 88f4ca293f937fd869222a3d7583fbb6be279cf4 Mon Sep 17 00:00:00 2001 From: Pavol Juhas Date: Thu, 21 Aug 2025 16:32:10 -0700 Subject: [PATCH 2/4] Fix mismatched argument type in record_channel_measurement call --- cirq-core/cirq/sim/state_vector_simulation_state.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cirq-core/cirq/sim/state_vector_simulation_state.py b/cirq-core/cirq/sim/state_vector_simulation_state.py index 220d8abf2b2..1d155aeba36 100644 --- a/cirq-core/cirq/sim/state_vector_simulation_state.py +++ b/cirq-core/cirq/sim/state_vector_simulation_state.py @@ -429,7 +429,7 @@ def _strat_act_on_state_vector_from_mixture( if index is None: return NotImplemented if protocols.is_measurement(action): - key = protocols.measurement_key_name(action) + key = protocols.measurement_key_obj(action) args._classical_data.record_channel_measurement(key, index) return True @@ -441,6 +441,6 @@ def _strat_act_on_state_vector_from_channel( if index is None: return NotImplemented if protocols.is_measurement(action): - key = protocols.measurement_key_name(action) + key = protocols.measurement_key_obj(action) args._classical_data.record_channel_measurement(key, index) return True From 96e4cbea908be11c8deca74999e2facc1c8024c9 Mon Sep 17 00:00:00 2001 From: Pavol Juhas Date: Mon, 15 Sep 2025 23:03:29 -0700 Subject: [PATCH 3/4] Tighten return type annotation for CirqEncoder.default Specify types that can be actually returned rather than Any. --- cirq-core/cirq/protocols/json_serialization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cirq-core/cirq/protocols/json_serialization.py b/cirq-core/cirq/protocols/json_serialization.py index 9bc643b3291..e92febf6d48 100644 --- a/cirq-core/cirq/protocols/json_serialization.py +++ b/cirq-core/cirq/protocols/json_serialization.py @@ -217,7 +217,7 @@ def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self._memo: dict[Any, dict] = {} - def default(self, o) -> Any: + def default(self, o) -> dict[str, Any] | list[Any] | float | bool: # Object with custom method? if hasattr(o, '_json_dict_'): json_dict = _json_dict_with_cirq_type(o) From 0709e487acd46bd3460b09abf48d00c2ff61c3cc Mon Sep 17 00:00:00 2001 From: Pavol Juhas Date: Mon, 15 Sep 2025 23:14:53 -0700 Subject: [PATCH 4/4] Improve typing of protocol functions with optional default argument Annotate with `@overload` for the return type is known when the `default` is not provided. --- .../cirq/circuits/circuit_operation_test.py | 5 +++ .../ops/classically_controlled_operation.py | 10 ++--- cirq-core/cirq/ops/raw_types.py | 21 +++++++++- .../protocols/measurement_key_protocol.py | 41 +++++++++++++------ cirq-core/cirq/protocols/phase_protocol.py | 17 ++++++-- cirq-core/cirq/transformers/eject_z.py | 1 + 6 files changed, 70 insertions(+), 25 deletions(-) diff --git a/cirq-core/cirq/circuits/circuit_operation_test.py b/cirq-core/cirq/circuits/circuit_operation_test.py index c5604be050a..596b473f62a 100644 --- a/cirq-core/cirq/circuits/circuit_operation_test.py +++ b/cirq-core/cirq/circuits/circuit_operation_test.py @@ -1253,22 +1253,27 @@ def test_repeat_until_protocols() -> None: # Ensure the _repeat_until has been mapped, the measurement has been mapped to the same key, # and the control keys of the subcircuit is empty (because the control key of the condition is # bound to the measurement). + assert scoped._mapped_repeat_until is not None assert scoped._mapped_repeat_until.keys == (cirq.MeasurementKey('a', ('0',)),) assert cirq.measurement_key_objs(scoped) == {cirq.MeasurementKey('a', ('0',))} assert not cirq.control_keys(scoped) mapped = cirq.with_measurement_key_mapping(scoped, {'a': 'b'}) + assert mapped._mapped_repeat_until is not None assert mapped._mapped_repeat_until.keys == (cirq.MeasurementKey('b', ('0',)),) assert cirq.measurement_key_objs(mapped) == {cirq.MeasurementKey('b', ('0',))} assert not cirq.control_keys(mapped) prefixed = cirq.with_key_path_prefix(mapped, ('1',)) + assert prefixed._mapped_repeat_until is not None assert prefixed._mapped_repeat_until.keys == (cirq.MeasurementKey('b', ('1', '0')),) assert cirq.measurement_key_objs(prefixed) == {cirq.MeasurementKey('b', ('1', '0'))} assert not cirq.control_keys(prefixed) setpath = cirq.with_key_path(prefixed, ('2',)) + assert setpath._mapped_repeat_until is not None assert setpath._mapped_repeat_until.keys == (cirq.MeasurementKey('b', ('2',)),) assert cirq.measurement_key_objs(setpath) == {cirq.MeasurementKey('b', ('2',))} assert not cirq.control_keys(setpath) resolved = cirq.resolve_parameters(setpath, {'p': 1}) + assert resolved._mapped_repeat_until is not None assert resolved._mapped_repeat_until.keys == (cirq.MeasurementKey('b', ('2',)),) assert cirq.measurement_key_objs(resolved) == {cirq.MeasurementKey('b', ('2',))} assert not cirq.control_keys(resolved) diff --git a/cirq-core/cirq/ops/classically_controlled_operation.py b/cirq-core/cirq/ops/classically_controlled_operation.py index a202036af47..bbb80823e0d 100644 --- a/cirq-core/cirq/ops/classically_controlled_operation.py +++ b/cirq-core/cirq/ops/classically_controlled_operation.py @@ -14,7 +14,7 @@ from __future__ import annotations -from typing import AbstractSet, Any, cast, Mapping, Sequence, TYPE_CHECKING +from typing import AbstractSet, Any, Mapping, Sequence, TYPE_CHECKING import sympy @@ -207,17 +207,13 @@ def _with_measurement_key_mapping_( conditions = [protocols.with_measurement_key_mapping(c, key_map) for c in self._conditions] sub_operation = protocols.with_measurement_key_mapping(self._sub_operation, key_map) sub_operation = self._sub_operation if sub_operation is NotImplemented else sub_operation - return cast( - ClassicallyControlledOperation, sub_operation.with_classical_controls(*conditions) - ) + return sub_operation.with_classical_controls(*conditions) def _with_key_path_prefix_(self, prefix: tuple[str, ...]) -> ClassicallyControlledOperation: conditions = [protocols.with_key_path_prefix(c, prefix) for c in self._conditions] sub_operation = protocols.with_key_path_prefix(self._sub_operation, prefix) sub_operation = self._sub_operation if sub_operation is NotImplemented else sub_operation - return cast( - ClassicallyControlledOperation, sub_operation.with_classical_controls(*conditions) - ) + return sub_operation.with_classical_controls(*conditions) def _with_rescoped_keys_( self, path: tuple[str, ...], bindable_keys: frozenset[cirq.MeasurementKey] diff --git a/cirq-core/cirq/ops/raw_types.py b/cirq-core/cirq/ops/raw_types.py index 7e8bf6dac1c..1043eb2f81f 100644 --- a/cirq-core/cirq/ops/raw_types.py +++ b/cirq-core/cirq/ops/raw_types.py @@ -28,6 +28,7 @@ Hashable, Iterable, Mapping, + overload, Sequence, TYPE_CHECKING, ) @@ -687,9 +688,17 @@ def classical_controls(self) -> frozenset[cirq.Condition]: """The classical controls gating this operation.""" return frozenset() + @overload + def with_classical_controls(self) -> cirq.Operation: + pass + + @overload def with_classical_controls( self, *conditions: str | cirq.MeasurementKey | cirq.Condition | sympy.Expr - ) -> cirq.Operation: + ) -> cirq.ClassicallyControlledOperation: + pass + + def with_classical_controls(self, *conditions): """Returns a classically controlled version of this operation. An operation that is classically controlled is executed iff all @@ -957,9 +966,17 @@ def without_classical_controls(self) -> cirq.Operation: new_sub_operation = self.sub_operation.without_classical_controls() return self if new_sub_operation is self.sub_operation else new_sub_operation + @overload + def with_classical_controls(self) -> cirq.Operation: + pass + + @overload def with_classical_controls( self, *conditions: str | cirq.MeasurementKey | cirq.Condition | sympy.Expr - ) -> cirq.Operation: + ) -> cirq.ClassicallyControlledOperation: + pass + + def with_classical_controls(self, *conditions): if not conditions: return self return self.sub_operation.with_classical_controls(*conditions) diff --git a/cirq-core/cirq/protocols/measurement_key_protocol.py b/cirq-core/cirq/protocols/measurement_key_protocol.py index 97b72e3b837..6346e577540 100644 --- a/cirq-core/cirq/protocols/measurement_key_protocol.py +++ b/cirq-core/cirq/protocols/measurement_key_protocol.py @@ -17,7 +17,7 @@ from __future__ import annotations from types import NotImplementedType -from typing import Any, Mapping, Protocol, TYPE_CHECKING, TypeVar +from typing import Any, Mapping, overload, Protocol, TYPE_CHECKING, TypeVar from cirq import value from cirq._doc import doc_private @@ -25,6 +25,7 @@ if TYPE_CHECKING: import cirq +T = TypeVar('T') TDefault = TypeVar('TDefault') # This is a special indicator value used by the inverse method to determine @@ -106,9 +107,17 @@ def _with_measurement_key_mapping_(self, key_map: Mapping[str, str]): """ -def measurement_key_obj( - val: Any, default: TDefault = RaiseTypeErrorIfNotProvided -) -> cirq.MeasurementKey | TDefault: +@overload +def measurement_key_obj(val: Any) -> cirq.MeasurementKey: + pass + + +@overload +def measurement_key_obj(val: Any, default: TDefault) -> cirq.MeasurementKey | TDefault: + pass + + +def measurement_key_obj(val, default=RaiseTypeErrorIfNotProvided): """Get the single measurement key object for the given value. Args: @@ -142,9 +151,17 @@ def measurement_key_obj( raise TypeError(f"Object of type '{type(val)}' had no measurement keys.") -def measurement_key_name( - val: Any, default: TDefault = RaiseTypeErrorIfNotProvided -) -> str | TDefault: +@overload +def measurement_key_name(val: Any) -> str: + pass + + +@overload +def measurement_key_name(val: Any, default: TDefault) -> str | TDefault: + pass + + +def measurement_key_name(val, default=RaiseTypeErrorIfNotProvided): """Get the single measurement key for the given value. Args: @@ -284,7 +301,7 @@ def is_measurement(val: Any) -> bool: return keys is not NotImplemented and bool(keys) -def with_measurement_key_mapping(val: Any, key_map: Mapping[str, str]) -> Any: +def with_measurement_key_mapping(val: T, key_map: Mapping[str, str]) -> T: """Remaps the target's measurement keys according to the provided key_map. This method can be used to reassign measurement keys at runtime, or to @@ -294,7 +311,7 @@ def with_measurement_key_mapping(val: Any, key_map: Mapping[str, str]) -> Any: return NotImplemented if getter is None else getter(key_map) -def with_key_path(val: Any, path: tuple[str, ...]) -> Any: +def with_key_path(val: T, path: tuple[str, ...]) -> T: """Adds the path to the target's measurement keys. The path usually refers to an identifier or a list of identifiers from a subcircuit that @@ -305,7 +322,7 @@ def with_key_path(val: Any, path: tuple[str, ...]) -> Any: return NotImplemented if getter is None else getter(path) -def with_key_path_prefix(val: Any, prefix: tuple[str, ...]) -> Any: +def with_key_path_prefix(val: T, prefix: tuple[str, ...]) -> T: """Prefixes the path to the target's measurement keys. The path usually refers to an identifier or a list of identifiers from a subcircuit that @@ -321,8 +338,8 @@ def with_key_path_prefix(val: Any, prefix: tuple[str, ...]) -> Any: def with_rescoped_keys( - val: Any, path: tuple[str, ...], bindable_keys: frozenset[cirq.MeasurementKey] | None = None -) -> Any: + val: T, path: tuple[str, ...], bindable_keys: frozenset[cirq.MeasurementKey] | None = None +) -> T: """Rescopes any measurement and control keys to the provided path, given the existing keys. The path usually refers to an identifier or a list of identifiers from a subcircuit that diff --git a/cirq-core/cirq/protocols/phase_protocol.py b/cirq-core/cirq/protocols/phase_protocol.py index bdfa5b9f8ab..f001a553e7c 100644 --- a/cirq-core/cirq/protocols/phase_protocol.py +++ b/cirq-core/cirq/protocols/phase_protocol.py @@ -14,7 +14,7 @@ from __future__ import annotations -from typing import Any, Protocol, TypeVar +from typing import Any, overload, Protocol, TypeVar # This is a special value to indicate that a type error should be returned. # This is used within phase_by to raise an error if no underlying @@ -24,6 +24,7 @@ RaiseTypeErrorIfNotProvided: Any = ([],) TDefault = TypeVar('TDefault') +T = TypeVar('T') class SupportsPhase(Protocol): @@ -47,9 +48,17 @@ def _phase_by_(self, phase_turns: float, qubit_index: int): """ -def phase_by( - val: Any, phase_turns: float, qubit_index: int, default: TDefault = RaiseTypeErrorIfNotProvided -) -> Any: +@overload +def phase_by(val: T, phase_turns: float, qubit_index: int) -> T: + pass + + +@overload +def phase_by(val: T, phase_turns: float, qubit_index: int, default: TDefault) -> T | TDefault: + pass + + +def phase_by(val, phase_turns, qubit_index, default=RaiseTypeErrorIfNotProvided): """Returns a phased version of the effect. For example, an X gate phased by 90 degrees would be a Y gate. diff --git a/cirq-core/cirq/transformers/eject_z.py b/cirq-core/cirq/transformers/eject_z.py index 6b65ec28f6c..03bb02dd790 100644 --- a/cirq-core/cirq/transformers/eject_z.py +++ b/cirq-core/cirq/transformers/eject_z.py @@ -119,6 +119,7 @@ def map_func(op: cirq.Operation, moment_index: int) -> cirq.OP_TREE: return [] # Try to move the tracked phases over the operation via protocols.phase_by(op) + phased_op: cirq.Operation | None phased_op = op for i, p in enumerate([qubit_phase[q] for q in op.qubits]): if not single_qubit_decompositions.is_negligible_turn(p, atol):