Skip to content
56 changes: 16 additions & 40 deletions cirq-core/cirq/ops/dense_pauli_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ def __truediv__(self, other):

def __mul__(self, other):
concrete_class = type(self)
other = _try_interpret_as_dps(other) or other
if isinstance(other, BaseDensePauliString):
if isinstance(other, MutableDensePauliString):
concrete_class = MutableDensePauliString
Expand All @@ -268,31 +269,14 @@ def __mul__(self, other):
return NotImplemented
return concrete_class(pauli_mask=self.pauli_mask, coefficient=new_coef)

split = _attempt_value_to_pauli_index(other)
if split is not None:
p, i = split
mask = np.copy(self.pauli_mask)
mask[i] ^= p
return concrete_class(
pauli_mask=mask,
coefficient=self.coefficient * _vectorized_pauli_mul_phase(self.pauli_mask[i], p),
)

return NotImplemented

def __rmul__(self, other):
if isinstance(other, (sympy.Basic, numbers.Number)):
return self.__mul__(other)

split = _attempt_value_to_pauli_index(other)
if split is not None:
p, i = split
mask = np.copy(self.pauli_mask)
mask[i] ^= p
return type(self)(
pauli_mask=mask,
coefficient=self.coefficient * _vectorized_pauli_mul_phase(p, self.pauli_mask[i]),
)
if other := _try_interpret_as_dps(other):
return other.__mul__(self)

return NotImplemented

Expand Down Expand Up @@ -369,18 +353,12 @@ def __repr__(self) -> str:
)

def _commutes_(self, other: Any, *, atol: float = 1e-8) -> bool | NotImplementedType | None:
other = _try_interpret_as_dps(other)
if isinstance(other, BaseDensePauliString):
n = min(len(self.pauli_mask), len(other.pauli_mask))
phase = _vectorized_pauli_mul_phase(self.pauli_mask[:n], other.pauli_mask[:n])
return phase == 1 or phase == -1

# Single qubit Pauli operation.
split = _attempt_value_to_pauli_index(other)
if split is not None:
p1, i = split
p2 = self.pauli_mask[i]
return (p1 or p2) == (p2 or p1)

return NotImplemented

def frozen(self) -> DensePauliString:
Expand Down Expand Up @@ -518,6 +496,7 @@ def __itruediv__(self, other):
return NotImplemented

def __imul__(self, other):
other = _try_interpret_as_dps(other) or other
if isinstance(other, BaseDensePauliString):
if len(other) > len(self):
raise ValueError(
Expand All @@ -539,13 +518,6 @@ def __imul__(self, other):
self._coefficient = new_coef if isinstance(new_coef, sympy.Basic) else complex(new_coef)
return self

split = _attempt_value_to_pauli_index(other)
if split is not None:
p, i = split
self._coefficient *= _vectorized_pauli_mul_phase(self.pauli_mask[i], p)
self.pauli_mask[i] ^= p
return self

return NotImplemented

def copy(
Expand Down Expand Up @@ -613,23 +585,27 @@ def _as_pauli_mask(val: Iterable[cirq.PAULI_GATE_LIKE] | np.ndarray) -> np.ndarr
return np.array([_pauli_index(v) for v in val], dtype=np.uint8)


def _attempt_value_to_pauli_index(v: cirq.Operation) -> tuple[int, int] | None:
def _try_interpret_as_dps(v: cirq.Operation) -> BaseDensePauliString | None:
if isinstance(v, BaseDensePauliString):
return v

if (ps := pauli_string._try_interpret_as_pauli_string(v)) is None:
return None

if len(ps.qubits) != 1:
return None # pragma: no cover

q = ps.qubits[0]
from cirq import devices

if not isinstance(q, devices.LineQubit):
if not all(isinstance(q, devices.LineQubit) for q in ps.qubits):
raise ValueError(
'Got a Pauli operation, but it was applied to a qubit type '
'other than `cirq.LineQubit` so its dense index is ambiguous.\n'
f'v={repr(v)}.'
)
return pauli_string.PAULI_GATE_LIKE_TO_INDEX_MAP[ps[q]], q.x

pauli_mask = np.zeros(max([q.x + 1 for q in ps.qubits], default=0), dtype=np.uint8)
for q in ps.qubits:
pauli_mask[q.x] = pauli_string.PAULI_GATE_LIKE_TO_INDEX_MAP[ps[q]]

return DensePauliString(pauli_mask)


def _vectorized_pauli_mul_phase(lhs: int | np.ndarray, rhs: int | np.ndarray) -> complex:
Expand Down
17 changes: 17 additions & 0 deletions cirq-core/cirq/ops/dense_pauli_string_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,10 @@ def test_mul() -> None:
with pytest.raises(ValueError, match='other than `cirq.LineQubit'):
_ = f('III') * cirq.X(cirq.NamedQubit('tmp'))

# Parity operations.
assert f('IXYZ') * cirq.XX(*cirq.LineQubit.range(1, 3)) == -1j * f('IIZZ')
assert cirq.XX(*cirq.LineQubit.range(1, 3)) * f('IXYZ') == 1j * f('IIZZ')

# Mixed types.
m = cirq.MutableDensePauliString
assert m('X') * m('Z') == -1j * m('Y')
Expand All @@ -187,6 +191,10 @@ def test_mul() -> None:
assert f('I') * f('III') == f('III')
assert f('X') * f('XXX') == f('IXX')
assert f('XXX') * f('X') == f('IXX')
assert f('X') * cirq.Y(cirq.LineQubit(2)) == f('XIY')
assert f('XY') * cirq.YY(*cirq.LineQubit.range(1, 3)) == f('XIY')
assert cirq.X(cirq.LineQubit(2)) * f('Y') == f('YIX')
assert cirq.XX(*cirq.LineQubit.range(1, 3)) * f('YX') == f('YIX')

with pytest.raises(TypeError):
_ = f('I') * object()
Expand Down Expand Up @@ -235,8 +243,15 @@ def test_imul() -> None:
p *= cirq.X(cirq.LineQubit(1))
assert p == m('IZI')

p *= cirq.ZZ(*cirq.LineQubit.range(1, 3))
assert p == m('IIZ')

with pytest.raises(ValueError, match='smaller than'):
p *= f('XXXXXXXXXXXX')
with pytest.raises(ValueError, match='smaller than'):
p *= cirq.X(cirq.LineQubit(3))
with pytest.raises(ValueError, match='smaller than'):
p *= cirq.XX(*cirq.LineQubit.range(2, 4))
with pytest.raises(TypeError):
p *= object()

Expand Down Expand Up @@ -511,6 +526,8 @@ def test_commutes() -> None:
assert cirq.commutes(f('IIIXII'), cirq.X(cirq.LineQubit(2)) ** 3)
assert not cirq.commutes(f('IIIXII'), cirq.Z(cirq.LineQubit(3)) ** 3)
assert cirq.commutes(f('IIIXII'), cirq.Z(cirq.LineQubit(2)) ** 3)
assert cirq.commutes(f('X'), cirq.Z(cirq.LineQubit(10)))
assert cirq.commutes(cirq.Z(cirq.LineQubit(10)), f('X'))

assert cirq.commutes(f('XX'), "test", default=NotImplemented) is NotImplemented

Expand Down