Skip to content
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 0 additions & 12 deletions cirq-core/cirq/ops/common_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,12 +326,6 @@ def _json_dict_(self) -> Dict[str, Any]:
d['dimension'] = self.dimension
return d

def _value_equality_values_(self):
return (*super()._value_equality_values_(), self._dimension)

def _value_equality_approximate_values_(self):
return (*super()._value_equality_approximate_values_(), self._dimension)


class Rx(XPowGate):
r"""A gate with matrix $e^{-i X t/2}$ that rotates around the X axis of the Bloch sphere by $t$.
Expand Down Expand Up @@ -862,12 +856,6 @@ def _json_dict_(self) -> Dict[str, Any]:
d['dimension'] = self.dimension
return d

def _value_equality_values_(self):
return (*super()._value_equality_values_(), self._dimension)

def _value_equality_approximate_values_(self):
return (*super()._value_equality_approximate_values_(), self._dimension)


class Rz(ZPowGate):
r"""A gate with matrix $e^{-i Z t/2}$ that rotates around the Z axis of the Bloch sphere by $t$.
Expand Down
5 changes: 3 additions & 2 deletions cirq-core/cirq/ops/common_gates_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,11 +245,12 @@ def test_rot_gates_eq():
eq.add_equality_group(cirq.YPowGate(), cirq.YPowGate(exponent=1), cirq.Y)
eq.add_equality_group(cirq.ZPowGate(), cirq.ZPowGate(exponent=1), cirq.Z)
eq.add_equality_group(
cirq.ZPowGate(exponent=1, global_shift=-0.5), cirq.ZPowGate(exponent=5, global_shift=-0.5)
cirq.ZPowGate(exponent=1, global_shift=-0.5),
cirq.ZPowGate(exponent=5, global_shift=-0.5),
cirq.ZPowGate(exponent=5, global_shift=-0.1),
)
eq.add_equality_group(cirq.ZPowGate(exponent=3, global_shift=-0.5))
eq.add_equality_group(cirq.ZPowGate(exponent=1, global_shift=-0.1))
eq.add_equality_group(cirq.ZPowGate(exponent=5, global_shift=-0.1))
eq.add_equality_group(
cirq.CNotPowGate(), cirq.CXPowGate(), cirq.CNotPowGate(exponent=1), cirq.CNOT
)
Expand Down
43 changes: 6 additions & 37 deletions cirq-core/cirq/ops/eigen_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
import sympy

from cirq import protocols, value
from cirq.linalg import tolerance
from cirq.ops import raw_types

if TYPE_CHECKING:
Expand Down Expand Up @@ -122,7 +121,6 @@ def __init__(
exponent = exponent.real
self._exponent = exponent
self._global_shift = global_shift
self._canonical_exponent_cached = None

@property
def exponent(self) -> value.TParamVal:
Expand Down Expand Up @@ -305,30 +303,12 @@ def __pow__(self, exponent: Union[float, sympy.Symbol]) -> 'EigenGate':
return NotImplemented # pragma: no cover
return self._with_exponent(exponent=new_exponent)

@property
def _canonical_exponent(self):
if self._canonical_exponent_cached is None:
period = self._period()
if not period:
self._canonical_exponent_cached = self._exponent
elif protocols.is_parameterized(self._exponent):
self._canonical_exponent_cached = self._exponent
if isinstance(self._exponent, sympy.Number):
self._canonical_exponent_cached = float(self._exponent)
else:
self._canonical_exponent_cached = self._exponent % period
return self._canonical_exponent_cached

def _value_equality_values_(self):
return self._canonical_exponent, self._global_shift

def _value_equality_approximate_values_(self):
period = self._period()
if not period or protocols.is_parameterized(self._exponent):
exponent = self._exponent
else:
exponent = value.PeriodicValue(self._exponent, period)
return exponent, self._global_shift
"""The phases by which we multiply the eigencomponents."""
symbolic = lambda x: isinstance(x, sympy.Expr) and x.free_symbols
f = lambda x: x if symbolic(x) else float(x)
shifts = (f(self._exponent) * f(self._global_shift + e) for e in self._eigen_shifts())
return tuple(s if symbolic(s) else value.PeriodicValue(f(s), 2) for s in shifts)

def _trace_distance_bound_(self) -> Optional[float]:
if protocols.is_parameterized(self._exponent):
Expand Down Expand Up @@ -378,20 +358,9 @@ def _equal_up_to_global_phase_(self, other, atol):
return False
self_without_phase = self._with_exponent(self.exponent)
self_without_phase._global_shift = 0
self_without_exp_or_phase = self_without_phase._with_exponent(0)
self_without_exp_or_phase._global_shift = 0
other_without_phase = other._with_exponent(other.exponent)
other_without_phase._global_shift = 0
other_without_exp_or_phase = other_without_phase._with_exponent(0)
other_without_exp_or_phase._global_shift = 0
if not protocols.approx_eq(
self_without_exp_or_phase, other_without_exp_or_phase, atol=atol
):
return False

period = self_without_phase._period()
exponents_diff = exponents[0] - exponents[1]
return tolerance.near_zero_mod(exponents_diff, period, atol=atol)
return protocols.approx_eq(self_without_phase, other_without_phase, atol=atol)

def _json_dict_(self) -> Dict[str, Any]:
return protocols.obj_to_dict_helper(self, ['exponent', 'global_shift'])
Expand Down
70 changes: 60 additions & 10 deletions cirq-core/cirq/ops/eigen_gate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from typing import List, Tuple

import numpy as np
Expand Down Expand Up @@ -50,7 +49,7 @@ def _eigen_components(self) -> List[Tuple[float, np.ndarray]]:
]


class ZGateDef(cirq.EigenGate, cirq.testing.TwoQubitGate):
class ZGateDef(cirq.EigenGate, cirq.testing.SingleQubitGate):
@property
def exponent(self):
return self._exponent
Expand Down Expand Up @@ -97,7 +96,6 @@ def test_eq():
eq.make_equality_group(lambda: CExpZinGate(quarter_turns=0.1))
eq.add_equality_group(CExpZinGate(0), CExpZinGate(4), CExpZinGate(-4))

# Equates by canonicalized period.
eq.add_equality_group(CExpZinGate(1.5), CExpZinGate(41.5))
eq.add_equality_group(CExpZinGate(3.5), CExpZinGate(-0.5))

Expand All @@ -109,6 +107,64 @@ def test_eq():
eq.add_equality_group(ZGateDef(exponent=0.5, global_shift=0.5))
eq.add_equality_group(ZGateDef(exponent=1.0, global_shift=0.5))

# All variants of (0,0) == (0*a,0*a) == (0, 2) == (2, 2)
a, b = sympy.symbols('a, b')
eq.add_equality_group(
WeightedZPowGate(0),
WeightedZPowGate(0) ** 1.1,
WeightedZPowGate(0) ** a,
(WeightedZPowGate(0) ** a) ** 1.2,
WeightedZPowGate(0) ** (a + 1.3),
WeightedZPowGate(0) ** b,
WeightedZPowGate(1) ** 2,
WeightedZPowGate(0, global_shift=1) ** 2,
WeightedZPowGate(1, global_shift=1) ** 2,
WeightedZPowGate(2),
WeightedZPowGate(0, global_shift=2),
WeightedZPowGate(2, global_shift=2),
)
# WeightedZPowGate(2) is identity, but non-integer exponent would make it different, similar to
# how we treat (X**2)**0.5==X. So these are in their own equality group. (0, 2*a)
eq.add_equality_group(
WeightedZPowGate(2) ** a,
(WeightedZPowGate(1) ** 2) ** a,
(WeightedZPowGate(1) ** a) ** 2,
WeightedZPowGate(1) ** (a * 2),
WeightedZPowGate(1) ** (a + a),
)
# Similarly, these are identity without the exponent, but global_shift affects both phases
# instead of just the one, so will have a different effect from the above depending on the
# exponent. (2*a, 0)
eq.add_equality_group(
WeightedZPowGate(0, global_shift=2) ** a,
(WeightedZPowGate(0, global_shift=1) ** 2) ** a,
(WeightedZPowGate(0, global_shift=1) ** a) ** 2,
WeightedZPowGate(0, global_shift=1) ** (a * 2),
WeightedZPowGate(0, global_shift=1) ** (a + a),
)
# Symbolic exponents that cancel (0, 1) == (0, a/a)
eq.add_equality_group(
WeightedZPowGate(1),
WeightedZPowGate(a) ** (1 / a),
WeightedZPowGate(b) ** (1 / b),
WeightedZPowGate(1 / a) ** a,
WeightedZPowGate(1 / b) ** b,
)
# Symbol in one phase and constant off by period in another (0, a) == (2, a)
eq.add_equality_group(
WeightedZPowGate(a),
WeightedZPowGate(a - 2, global_shift=2),
WeightedZPowGate(1 - 2 / a, global_shift=2 / a) ** a,
)
# Different symbol, different equality group (0, b)
eq.add_equality_group(WeightedZPowGate(b))
# Various number types
eq.add_equality_group(
WeightedZPowGate(np.int64(3), global_shift=sympy.Number(5)) ** 7.0,
WeightedZPowGate(sympy.Number(3), global_shift=5.0) ** np.int64(7),
WeightedZPowGate(3.0, global_shift=np.int64(5)) ** sympy.Number(7),
)


def test_approx_eq():
assert cirq.approx_eq(CExpZinGate(1.5), CExpZinGate(1.5), atol=0.1)
Expand All @@ -118,8 +174,7 @@ def test_approx_eq():
assert cirq.approx_eq(ZGateDef(exponent=1.5), ZGateDef(exponent=1.5), atol=0.1)
assert not cirq.approx_eq(CExpZinGate(1.5), ZGateDef(exponent=1.5), atol=0.1)
with pytest.raises(
TypeError,
match=re.escape("unsupported operand type(s) for -: 'Symbol' and 'PeriodicValue'"),
TypeError, match="unsupported operand type\\(s\\) for -: '.*' and 'PeriodicValue'"
):
cirq.approx_eq(ZGateDef(exponent=1.5), ZGateDef(exponent=sympy.Symbol('a')), atol=0.1)
assert cirq.approx_eq(CExpZinGate(sympy.Symbol('a')), CExpZinGate(sympy.Symbol('a')), atol=0.1)
Expand Down Expand Up @@ -333,11 +388,6 @@ def __init__(self, weight, **kwargs):
self.weight = weight
super().__init__(**kwargs)

def _value_equality_values_(self):
return self.weight, self._canonical_exponent, self._global_shift

_value_equality_approximate_values_ = _value_equality_values_

def _eigen_components(self) -> List[Tuple[float, np.ndarray]]:
return [(0, np.diag([1, 0])), (self.weight, np.diag([0, 1]))]

Expand Down
1 change: 1 addition & 0 deletions cirq-core/cirq/ops/parity_gates_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def test_xx_eq():
cirq.XXPowGate(),
cirq.XXPowGate(exponent=1, global_shift=0),
cirq.XXPowGate(exponent=3, global_shift=0),
cirq.XXPowGate(global_shift=100000),
)
eq.add_equality_group(cirq.XX**0.5, cirq.XX**2.5, cirq.XX**4.5)
eq.add_equality_group(cirq.XX**0.25, cirq.XX**2.25, cirq.XX**-1.75)
Expand Down
5 changes: 0 additions & 5 deletions cirq-core/cirq/ops/pauli_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,6 @@ def on(self, *qubits: 'cirq.Qid') -> 'SingleQubitPauliStringGateOperation':

return pauli_string.SingleQubitPauliStringGateOperation(self, qubits[0])

@property
def _canonical_exponent(self):
"""Overrides EigenGate._canonical_exponent in subclasses."""
return 1


class _PauliX(Pauli, common_gates.XPowGate):
def __init__(self):
Expand Down
8 changes: 7 additions & 1 deletion cirq-core/cirq/ops/pauli_interaction_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,13 @@ def _num_qubits_(self) -> int:
return 2

def _value_equality_values_(self):
return (self.pauli0, self.invert0, self.pauli1, self.invert1, self._canonical_exponent)
return (
self.pauli0,
self.invert0,
self.pauli1,
self.invert1,
value.PeriodicValue(self.exponent, 2),
)

def qubit_index_to_equivalence_group_key(self, index: int) -> int:
if self.pauli0 == self.pauli1 and self.invert0 == self.invert1:
Expand Down
Loading