Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions cirq-core/cirq/ops/common_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -1089,15 +1089,12 @@ def __init__(self, p: float) -> None:
ValueError: if p is not a valid probability.
"""
self._p = value.validate_probability(p, 'p')
self._delegate = AsymmetricDepolarizingChannel(p, 0.0, 0.0)

def _num_qubits_(self) -> int:
return 1

def _mixture_(self) -> Sequence[Tuple[float, np.ndarray]]:
mixture = self._delegate._mixture_()
# just return identity and x term
return (mixture[0], mixture[1])
def _mixture_(self) -> Sequence[Tuple[float, Any]]:
return ((1 - self._p, identity.I), (self._p, pauli_gates.X))

def _has_mixture_(self) -> bool:
return True
Expand Down
19 changes: 13 additions & 6 deletions cirq-core/cirq/protocols/apply_mixture_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from cirq._doc import doc_private
from cirq.protocols import qid_shape_protocol
from cirq.protocols.apply_unitary_protocol import apply_unitary, ApplyUnitaryArgs
from cirq.protocols.mixture_protocol import mixture

# This is a special indicator value used by the apply_mixture method
# to determine whether or not the caller provided a 'default' argument. It must
Expand Down Expand Up @@ -260,9 +259,9 @@ def err_str(buf_num_str):
return result

# Fallback to using the object's `_mixture_` matrices. (STEP C)
prob_mix = mixture(val, None)
if prob_mix is not None:
return _mixture_strat(prob_mix, args, is_density_matrix)
result = _apply_mixture_from_mixture_strat(val, args, is_density_matrix)
if result is not None:
return result

# Don't know how to apply mixture. Fallback to specified default behavior.
# (STEP D)
Expand Down Expand Up @@ -359,11 +358,19 @@ def _apply_unitary_from_matrix_strat(
return args.target_tensor


def _mixture_strat(val: Any, args: 'ApplyMixtureArgs', is_density_matrix: bool) -> np.ndarray:
def _apply_mixture_from_mixture_strat(
val: Any, args: 'ApplyMixtureArgs', is_density_matrix: bool
) -> Optional[np.ndarray]:
"""Attempt to use unitary matrices in _mixture_ and return the result."""
method = getattr(val, '_mixture_', None)
if method is None:
return None
prob_mix = method()
if prob_mix is NotImplemented or prob_mix is None:
return None
args.out_buffer[:] = 0
np.copyto(dst=args.auxiliary_buffer1, src=args.target_tensor)
for prob, op in val:
for prob, op in prob_mix:
np.copyto(dst=args.target_tensor, src=args.auxiliary_buffer1)
right_result = _apply_unitary_strat(op, args, is_density_matrix)
if right_result is None:
Expand Down
11 changes: 11 additions & 0 deletions cirq-core/cirq/protocols/apply_mixture_protocol_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,17 @@ class NoProtocols:
assert_apply_mixture_returns(NoProtocols(), rho, left_axes=[1], right_axes=[1])


def test_apply_mixture_mixture_returns_not_implemented():
class NoMixture:
def _mixture_(self):
return NotImplemented

rho = np.ones((2, 2, 2, 2), dtype=np.complex128)

with pytest.raises(TypeError, match='has no'):
assert_apply_mixture_returns(NoMixture(), rho, left_axes=[1], right_axes=[1])


def test_apply_mixture_no_protocols_implemented_default():
class NoProtocols:
pass
Expand Down
5 changes: 4 additions & 1 deletion cirq-core/cirq/protocols/kraus_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from cirq._doc import doc_private
from cirq.protocols.decompose_protocol import _try_decompose_into_operations_and_qubits
from cirq.protocols.mixture_protocol import has_mixture
from cirq.protocols.unitary_protocol import unitary

# This is a special indicator value used by the channel method to determine
# whether or not the caller provided a 'default' argument. It must be of type
Expand Down Expand Up @@ -145,7 +146,9 @@ def kraus(
mixture_getter = getattr(val, '_mixture_', None)
mixture_result = NotImplemented if mixture_getter is None else mixture_getter()
if mixture_result is not NotImplemented and mixture_result is not None:
return tuple(np.sqrt(p) * u for p, u in mixture_result)
return tuple(
np.sqrt(p) * (u if isinstance(u, np.ndarray) else unitary(u)) for p, u in mixture_result
)

unitary_getter = getattr(val, '_unitary_', None)
unitary_result = NotImplemented if unitary_getter is None else unitary_getter()
Expand Down
7 changes: 4 additions & 3 deletions cirq-core/cirq/protocols/mixture_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from cirq._doc import doc_private
from cirq.protocols.decompose_protocol import _try_decompose_into_operations_and_qubits
from cirq.protocols.has_unitary_protocol import has_unitary
from cirq.protocols.unitary_protocol import unitary

# This is a special indicator value used by the inverse method to determine
# whether or not the caller provided a 'default' argument.
Expand Down Expand Up @@ -84,14 +85,14 @@ def mixture(
with that probability in the mixture. The probabilities will sum to 1.0.
Raises:
TypeError: If `val` has no `_mixture_` or `_unitary_` mehod, or if it
TypeError: If `val` has no `_mixture_` or `_unitary_` method, or if it
does and this method returned `NotImplemented`.
"""

mixture_getter = getattr(val, '_mixture_', None)
result = NotImplemented if mixture_getter is None else mixture_getter()
if result is not NotImplemented:
return result
if result is not NotImplemented and result is not None:
return tuple((p, u if isinstance(u, np.ndarray) else unitary(u)) for p, u in result)

unitary_getter = getattr(val, '_unitary_', None)
result = NotImplemented if unitary_getter is None else unitary_getter()
Expand Down
31 changes: 20 additions & 11 deletions cirq-core/cirq/protocols/mixture_protocol_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@

import cirq

a = np.array([1])
b = np.array([1j])


class NoMethod:
pass
Expand All @@ -32,35 +35,35 @@ def _has_mixture_(self):

class ReturnsValidTuple(cirq.SupportsMixture):
def _mixture_(self):
return ((0.4, 'a'), (0.6, 'b'))
return ((0.4, a), (0.6, b))

def _has_mixture_(self):
return True


class ReturnsNonnormalizedTuple:
def _mixture_(self):
return ((0.4, 'a'), (0.4, 'b'))
return ((0.4, a), (0.4, b))


class ReturnsNegativeProbability:
def _mixture_(self):
return ((0.4, 'a'), (-0.4, 'b'))
return ((0.4, a), (-0.4, b))


class ReturnsGreaterThanUnityProbability:
def _mixture_(self):
return ((1.2, 'a'), (0.4, 'b'))
return ((1.2, a), (0.4, b))


class ReturnsMixtureButNoHasMixture:
def _mixture_(self):
return ((0.4, 'a'), (0.6, 'b'))
return ((0.4, a), (0.6, b))


class ReturnsUnitary:
def _unitary_(self):
return np.ones((2, 2))
return np.eye(2)

def _has_unitary_(self):
return True
Expand All @@ -74,12 +77,18 @@ def _has_unitary_(self):
return NotImplemented


class ReturnsMixtureOfReturnsUnitary:
def _mixture_(self):
return ((0.4, ReturnsUnitary()), (0.6, ReturnsUnitary()))


@pytest.mark.parametrize(
'val,mixture',
(
(ReturnsValidTuple(), ((0.4, 'a'), (0.6, 'b'))),
(ReturnsNonnormalizedTuple(), ((0.4, 'a'), (0.4, 'b'))),
(ReturnsUnitary(), ((1.0, np.ones((2, 2))),)),
(ReturnsValidTuple(), ((0.4, a), (0.6, b))),
(ReturnsNonnormalizedTuple(), ((0.4, a), (0.4, b))),
(ReturnsUnitary(), ((1.0, np.eye(2)),)),
(ReturnsMixtureOfReturnsUnitary(), ((0.4, np.eye(2)), (0.6, np.eye(2)))),
),
)
def test_objects_with_mixture(val, mixture):
Expand All @@ -88,7 +97,7 @@ def test_objects_with_mixture(val, mixture):
np.testing.assert_almost_equal(keys, expected_keys)
np.testing.assert_equal(values, expected_values)

keys, values = zip(*cirq.mixture(val, ((0.3, 'a'), (0.7, 'b'))))
keys, values = zip(*cirq.mixture(val, ((0.3, a), (0.7, b))))
np.testing.assert_almost_equal(keys, expected_keys)
np.testing.assert_equal(values, expected_values)

Expand All @@ -101,7 +110,7 @@ def test_objects_with_no_mixture(val):
_ = cirq.mixture(val)
assert cirq.mixture(val, None) is None
assert cirq.mixture(val, NotImplemented) is NotImplemented
default = ((0.4, 'a'), (0.6, 'b'))
default = ((0.4, a), (0.6, b))
assert cirq.mixture(val, default) == default


Expand Down