Skip to content
107 changes: 40 additions & 67 deletions cirq-core/cirq/ops/dense_pauli_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,35 +247,25 @@ def __truediv__(self, other):

def __mul__(self, other):
concrete_class = type(self)
if isinstance(other, BaseDensePauliString):
if isinstance(other, MutableDensePauliString):
concrete_class = MutableDensePauliString
max_len = max(len(self.pauli_mask), len(other.pauli_mask))
min_len = min(len(self.pauli_mask), len(other.pauli_mask))
new_mask = np.zeros(max_len, dtype=np.uint8)
new_mask[: len(self.pauli_mask)] ^= self.pauli_mask
new_mask[: len(other.pauli_mask)] ^= other.pauli_mask
tweak = _vectorized_pauli_mul_phase(
self.pauli_mask[:min_len], other.pauli_mask[:min_len]
)
return concrete_class(
pauli_mask=new_mask, coefficient=self.coefficient * other.coefficient * tweak
)

if isinstance(other, (sympy.Basic, numbers.Number)):
new_coef = protocols.mul(self.coefficient, other, default=None)
if new_coef is None:
return NotImplemented
return concrete_class(pauli_mask=self.pauli_mask, coefficient=new_coef)

split = _attempt_value_to_pauli_index(other)
if split is not None:
p, i = split
mask = np.copy(self.pauli_mask)
mask[i] ^= p
if (other_dps := _try_interpret_as_dps(other)) is not None:
if isinstance(other_dps, MutableDensePauliString):
concrete_class = MutableDensePauliString
max_len = max(len(self.pauli_mask), len(other_dps.pauli_mask))
min_len = min(len(self.pauli_mask), len(other_dps.pauli_mask))
new_mask = np.zeros(max_len, dtype=np.uint8)
new_mask[: len(self.pauli_mask)] ^= self.pauli_mask
new_mask[: len(other_dps.pauli_mask)] ^= other_dps.pauli_mask
tweak = _vectorized_pauli_mul_phase(
self.pauli_mask[:min_len], other_dps.pauli_mask[:min_len]
)
return concrete_class(
pauli_mask=mask,
coefficient=self.coefficient * _vectorized_pauli_mul_phase(self.pauli_mask[i], p),
pauli_mask=new_mask, coefficient=self.coefficient * other_dps.coefficient * tweak
)

return NotImplemented
Expand All @@ -284,15 +274,8 @@ def __rmul__(self, other):
if isinstance(other, (sympy.Basic, numbers.Number)):
return self.__mul__(other)

split = _attempt_value_to_pauli_index(other)
if split is not None:
p, i = split
mask = np.copy(self.pauli_mask)
mask[i] ^= p
return type(self)(
pauli_mask=mask,
coefficient=self.coefficient * _vectorized_pauli_mul_phase(p, self.pauli_mask[i]),
)
if other := _try_interpret_as_dps(other):
return other.__mul__(self)

return NotImplemented

Expand Down Expand Up @@ -369,18 +352,11 @@ def __repr__(self) -> str:
)

def _commutes_(self, other: Any, *, atol: float = 1e-8) -> bool | NotImplementedType | None:
if isinstance(other, BaseDensePauliString):
n = min(len(self.pauli_mask), len(other.pauli_mask))
phase = _vectorized_pauli_mul_phase(self.pauli_mask[:n], other.pauli_mask[:n])
if (other_dps := _try_interpret_as_dps(other)) is not None:
n = min(len(self.pauli_mask), len(other_dps.pauli_mask))
phase = _vectorized_pauli_mul_phase(self.pauli_mask[:n], other_dps.pauli_mask[:n])
return phase == 1 or phase == -1

# Single qubit Pauli operation.
split = _attempt_value_to_pauli_index(other)
if split is not None:
p1, i = split
p2 = self.pauli_mask[i]
return (p1 or p2) == (p2 or p1)

return NotImplemented

def frozen(self) -> DensePauliString:
Expand Down Expand Up @@ -518,32 +494,25 @@ def __itruediv__(self, other):
return NotImplemented

def __imul__(self, other):
if isinstance(other, BaseDensePauliString):
if len(other) > len(self):
raise ValueError(
"The receiving dense pauli string is smaller than "
"the dense pauli string being multiplied into it.\n"
f"self={repr(self)}\n"
f"other={repr(other)}"
)
self_mask = self.pauli_mask[: len(other.pauli_mask)]
self._coefficient *= _vectorized_pauli_mul_phase(self_mask, other.pauli_mask)
self._coefficient *= other.coefficient
self_mask ^= other.pauli_mask
return self

if isinstance(other, (sympy.Basic, numbers.Number)):
new_coef = protocols.mul(self.coefficient, other, default=None)
if new_coef is None:
return NotImplemented
self._coefficient = new_coef if isinstance(new_coef, sympy.Basic) else complex(new_coef)
return self

split = _attempt_value_to_pauli_index(other)
if split is not None:
p, i = split
self._coefficient *= _vectorized_pauli_mul_phase(self.pauli_mask[i], p)
self.pauli_mask[i] ^= p
if (other_dps := _try_interpret_as_dps(other)) is not None:
if len(other_dps) > len(self):
raise ValueError(
"The receiving dense pauli string is smaller than "
"the dense pauli string being multiplied into it.\n"
f"self={repr(self)}\n"
f"other={repr(other)}"
)
self_mask = self.pauli_mask[: len(other_dps.pauli_mask)]
self._coefficient *= _vectorized_pauli_mul_phase(self_mask, other_dps.pauli_mask)
self._coefficient *= other_dps.coefficient
self_mask ^= other_dps.pauli_mask
return self

return NotImplemented
Expand Down Expand Up @@ -613,23 +582,27 @@ def _as_pauli_mask(val: Iterable[cirq.PAULI_GATE_LIKE] | np.ndarray) -> np.ndarr
return np.array([_pauli_index(v) for v in val], dtype=np.uint8)


def _attempt_value_to_pauli_index(v: cirq.Operation) -> tuple[int, int] | None:
def _try_interpret_as_dps(v: cirq.Operation) -> BaseDensePauliString | None:
if isinstance(v, BaseDensePauliString):
return v

if (ps := pauli_string._try_interpret_as_pauli_string(v)) is None:
return None

if len(ps.qubits) != 1:
return None # pragma: no cover

q = ps.qubits[0]
from cirq import devices

if not isinstance(q, devices.LineQubit):
if not all(isinstance(q, devices.LineQubit) for q in ps.qubits):
raise ValueError(
'Got a Pauli operation, but it was applied to a qubit type '
'other than `cirq.LineQubit` so its dense index is ambiguous.\n'
f'v={repr(v)}.'
)
return pauli_string.PAULI_GATE_LIKE_TO_INDEX_MAP[ps[q]], q.x

pauli_mask = np.zeros(max((q.x + 1 for q in ps.qubits), default=0), dtype=np.uint8)
for q in ps.qubits:
pauli_mask[q.x] = pauli_string.PAULI_GATE_LIKE_TO_INDEX_MAP[ps[q]]

return DensePauliString(pauli_mask)


def _vectorized_pauli_mul_phase(lhs: int | np.ndarray, rhs: int | np.ndarray) -> complex:
Expand Down
17 changes: 17 additions & 0 deletions cirq-core/cirq/ops/dense_pauli_string_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,10 @@ def test_mul() -> None:
with pytest.raises(ValueError, match='other than `cirq.LineQubit'):
_ = f('III') * cirq.X(cirq.NamedQubit('tmp'))

# Parity operations.
assert f('IXYZ') * cirq.XX(*cirq.LineQubit.range(1, 3)) == -1j * f('IIZZ')
assert cirq.XX(*cirq.LineQubit.range(1, 3)) * f('IXYZ') == 1j * f('IIZZ')

# Mixed types.
m = cirq.MutableDensePauliString
assert m('X') * m('Z') == -1j * m('Y')
Expand All @@ -187,6 +191,10 @@ def test_mul() -> None:
assert f('I') * f('III') == f('III')
assert f('X') * f('XXX') == f('IXX')
assert f('XXX') * f('X') == f('IXX')
assert f('X') * cirq.Y(cirq.LineQubit(2)) == f('XIY')
assert f('XY') * cirq.YY(*cirq.LineQubit.range(1, 3)) == f('XIY')
assert cirq.X(cirq.LineQubit(2)) * f('Y') == f('YIX')
assert cirq.XX(*cirq.LineQubit.range(1, 3)) * f('YX') == f('YIX')

with pytest.raises(TypeError):
_ = f('I') * object()
Expand Down Expand Up @@ -235,8 +243,15 @@ def test_imul() -> None:
p *= cirq.X(cirq.LineQubit(1))
assert p == m('IZI')

p *= cirq.ZZ(*cirq.LineQubit.range(1, 3))
assert p == m('IIZ')

with pytest.raises(ValueError, match='smaller than'):
p *= f('XXXXXXXXXXXX')
with pytest.raises(ValueError, match='smaller than'):
p *= cirq.X(cirq.LineQubit(3))
with pytest.raises(ValueError, match='smaller than'):
p *= cirq.XX(*cirq.LineQubit.range(2, 4))
with pytest.raises(TypeError):
p *= object()

Expand Down Expand Up @@ -511,6 +526,8 @@ def test_commutes() -> None:
assert cirq.commutes(f('IIIXII'), cirq.X(cirq.LineQubit(2)) ** 3)
assert not cirq.commutes(f('IIIXII'), cirq.Z(cirq.LineQubit(3)) ** 3)
assert cirq.commutes(f('IIIXII'), cirq.Z(cirq.LineQubit(2)) ** 3)
assert cirq.commutes(f('X'), cirq.Z(cirq.LineQubit(10)))
assert cirq.commutes(cirq.Z(cirq.LineQubit(10)), f('X'))

assert cirq.commutes(f('XX'), "test", default=NotImplemented) is NotImplemented

Expand Down
1 change: 1 addition & 0 deletions cirq-core/cirq/ops/pauli_string_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,7 @@ def test_constructor_flexibility() -> None:
@pytest.mark.parametrize('qubit_pauli_map', _sample_qubit_pauli_maps())
def test_getitem(qubit_pauli_map) -> None:
other = cirq.NamedQubit('other')
pauli_string: cirq.PauliString[cirq.NamedQubit]
pauli_string = cirq.PauliString(qubit_pauli_map=qubit_pauli_map)
for key in qubit_pauli_map:
assert qubit_pauli_map[key] == pauli_string[key]
Expand Down