diff --git a/cirq-core/cirq/ops/common_gates.py b/cirq-core/cirq/ops/common_gates.py index 8658c7001d4..f2823d2a3bd 100644 --- a/cirq-core/cirq/ops/common_gates.py +++ b/cirq-core/cirq/ops/common_gates.py @@ -222,25 +222,17 @@ def controlled( A `cirq.ControlledGate` (or `cirq.CXPowGate` if possible) representing `self` controlled by the given control values and qubits. """ - if control_values and not isinstance(control_values, cv.AbstractControlValues): - control_values = cv.ProductOfSums( - tuple( - (val,) if isinstance(val, int) else tuple(sorted(val)) for val in control_values - ) - ) result = super().controlled(num_controls, control_values, control_qid_shape) if ( self._global_shift == 0 and isinstance(result, controlled_gate.ControlledGate) and isinstance(result.control_values, cv.ProductOfSums) - and result.control_values[-1] == (1,) - and result.control_qid_shape[-1] == 2 + and result.control_values.is_trivial ): - return cirq.CXPowGate( - exponent=self._exponent, global_shift=self._global_shift - ).controlled( - result.num_controls() - 1, result.control_values[:-1], result.control_qid_shape[:-1] - ) + if result.control_qid_shape == (2,): + return cirq.CXPowGate(exponent=self._exponent) + if result.control_qid_shape == (2, 2): + return cirq.CCXPowGate(exponent=self._exponent) return result def _pauli_expansion_(self) -> value.LinearDict[str]: @@ -694,25 +686,17 @@ def controlled( A `cirq.ControlledGate` (or `cirq.CZPowGate` if possible) representing `self` controlled by the given control values and qubits. """ - if control_values and not isinstance(control_values, cv.AbstractControlValues): - control_values = cv.ProductOfSums( - tuple( - (val,) if isinstance(val, int) else tuple(sorted(val)) for val in control_values - ) - ) result = super().controlled(num_controls, control_values, control_qid_shape) if ( self._global_shift == 0 and isinstance(result, controlled_gate.ControlledGate) and isinstance(result.control_values, cv.ProductOfSums) - and result.control_values[-1] == (1,) - and result.control_qid_shape[-1] == 2 + and result.control_values.is_trivial ): - return cirq.CZPowGate( - exponent=self._exponent, global_shift=self._global_shift - ).controlled( - result.num_controls() - 1, result.control_values[:-1], result.control_qid_shape[:-1] - ) + if result.control_qid_shape == (2,): + return cirq.CZPowGate(exponent=self._exponent) + if result.control_qid_shape == (2, 2): + return cirq.CCZPowGate(exponent=self._exponent) return result def _qid_shape_(self) -> tuple[int, ...]: @@ -1138,26 +1122,14 @@ def controlled( A `cirq.ControlledGate` (or `cirq.CCZPowGate` if possible) representing `self` controlled by the given control values and qubits. """ - if control_values and not isinstance(control_values, cv.AbstractControlValues): - control_values = cv.ProductOfSums( - tuple( - (val,) if isinstance(val, int) else tuple(sorted(val)) for val in control_values - ) - ) result = super().controlled(num_controls, control_values, control_qid_shape) - if ( - self._global_shift == 0 - and isinstance(result, controlled_gate.ControlledGate) - and isinstance(result.control_values, cv.ProductOfSums) - and result.control_values[-1] == (1,) - and result.control_qid_shape[-1] == 2 - ): - return cirq.CCZPowGate( - exponent=self._exponent, global_shift=self._global_shift - ).controlled( - result.num_controls() - 1, result.control_values[:-1], result.control_qid_shape[:-1] - ) - return result + if self._global_shift != 0 or not isinstance(result, controlled_gate.ControlledGate): + return result + return ZPowGate(exponent=self.exponent).controlled( + num_controls=result.num_controls() + 1, + control_values=result.control_values & cv.ProductOfSums([1]), + control_qid_shape=result.control_qid_shape + (2,), + ) def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo: return protocols.CircuitDiagramInfo( @@ -1340,26 +1312,14 @@ def controlled( A `cirq.ControlledGate` (or `cirq.CCXPowGate` if possible) representing `self` controlled by the given control values and qubits. """ - if control_values and not isinstance(control_values, cv.AbstractControlValues): - control_values = cv.ProductOfSums( - tuple( - (val,) if isinstance(val, int) else tuple(sorted(val)) for val in control_values - ) - ) result = super().controlled(num_controls, control_values, control_qid_shape) - if ( - self._global_shift == 0 - and isinstance(result, controlled_gate.ControlledGate) - and isinstance(result.control_values, cv.ProductOfSums) - and result.control_values[-1] == (1,) - and result.control_qid_shape[-1] == 2 - ): - return cirq.CCXPowGate( - exponent=self._exponent, global_shift=self._global_shift - ).controlled( - result.num_controls() - 1, result.control_values[:-1], result.control_qid_shape[:-1] - ) - return result + if self._global_shift != 0 or not isinstance(result, controlled_gate.ControlledGate): + return result + return XPowGate(exponent=self.exponent).controlled( + num_controls=result.num_controls() + 1, + control_values=result.control_values & cv.ProductOfSums([1]), + control_qid_shape=result.control_qid_shape + (2,), + ) def _qasm_(self, args: cirq.QasmArgs, qubits: tuple[cirq.Qid, ...]) -> str | None: if self._exponent != 1: diff --git a/cirq-core/cirq/ops/common_gates_test.py b/cirq-core/cirq/ops/common_gates_test.py index f60ff16a657..d615ddb29b1 100644 --- a/cirq-core/cirq/ops/common_gates_test.py +++ b/cirq-core/cirq/ops/common_gates_test.py @@ -109,19 +109,19 @@ def test_z_init(): @pytest.mark.parametrize( - 'input_gate, specialized_output', + 'input_gate, specialized_output, base_gate', [ - (cirq.Z, cirq.CZ), - (cirq.CZ, cirq.CCZ), - (cirq.X, cirq.CX), - (cirq.CX, cirq.CCX), - (cirq.ZPowGate(exponent=0.5), cirq.CZPowGate(exponent=0.5)), - (cirq.CZPowGate(exponent=0.5), cirq.CCZPowGate(exponent=0.5)), - (cirq.XPowGate(exponent=0.5), cirq.CXPowGate(exponent=0.5)), - (cirq.CXPowGate(exponent=0.5), cirq.CCXPowGate(exponent=0.5)), + (cirq.Z, cirq.CZ, cirq.Z), + (cirq.CZ, cirq.CCZ, cirq.Z), + (cirq.X, cirq.CX, cirq.X), + (cirq.CX, cirq.CCX, cirq.X), + (cirq.ZPowGate(exponent=0.5), cirq.CZPowGate(exponent=0.5), cirq.S), + (cirq.CZPowGate(exponent=0.5), cirq.CCZPowGate(exponent=0.5), cirq.S), + (cirq.XPowGate(exponent=0.5), cirq.CXPowGate(exponent=0.5), cirq.XPowGate(exponent=0.5)), + (cirq.CXPowGate(exponent=0.5), cirq.CCXPowGate(exponent=0.5), cirq.XPowGate(exponent=0.5)), ], ) -def test_specialized_control(input_gate, specialized_output): +def test_specialized_control(input_gate, specialized_output, base_gate): # Single qubit control on the input gate gives the specialized output assert input_gate.controlled() == specialized_output assert input_gate.controlled(num_controls=1) == specialized_output @@ -151,20 +151,24 @@ def test_specialized_control(input_gate, specialized_output): ) # When a control_value 1 qubit is not acting first, results in a regular - # ControlledGate on the input gate instance. + # ControlledGate on the base gate instance, with any extra control layer + # of the input gate being absorbed into the ControlledGate. + absorbed = 0 if base_gate == input_gate else 1 + absorbed_values = ((1,),) * absorbed + absorbed_shape = (2,) * absorbed assert input_gate.controlled(num_controls=1, control_qid_shape=(3,)) == cirq.ControlledGate( - input_gate, num_controls=1, control_qid_shape=(3,) + base_gate, num_controls=1 + absorbed, control_qid_shape=(3,) + absorbed_shape ) assert input_gate.controlled(control_values=((0,), (1,), (0,))) == cirq.ControlledGate( - input_gate, num_controls=3, control_values=((0,), (1,), (0,)) + base_gate, num_controls=3 + absorbed, control_values=((0,), (1,), (0,)) + absorbed_values ) assert input_gate.controlled(control_qid_shape=(3, 2, 3)) == cirq.ControlledGate( - input_gate, num_controls=3, control_qid_shape=(3, 2, 3) + base_gate, num_controls=3 + absorbed, control_qid_shape=(3, 2, 3) + absorbed_shape ) assert input_gate.controlled(control_qid_shape=(3,)).controlled( control_qid_shape=(2,) ).controlled(control_qid_shape=(4,)) != cirq.ControlledGate( - input_gate, num_controls=3, control_qid_shape=(3, 2, 4) + base_gate, num_controls=3 + absorbed, control_qid_shape=(3, 2, 4) + absorbed_shape ) diff --git a/cirq-core/cirq/ops/controlled_gate.py b/cirq-core/cirq/ops/controlled_gate.py index c07b134b983..915c0805c46 100644 --- a/cirq-core/cirq/ops/controlled_gate.py +++ b/cirq-core/cirq/ops/controlled_gate.py @@ -151,12 +151,7 @@ def _decompose_with_context_( ) # Prefer the subgate controlled version if available if self != controlled_sub_gate: - # Prevent 2-cycle from appearing in the recursive decomposition - # TODO: Remove after #7241 is resolved - if not isinstance(controlled_sub_gate, ControlledGate) or not isinstance( - controlled_sub_gate.sub_gate, common_gates.CZPowGate - ): - return controlled_sub_gate.on(*qubits) + return controlled_sub_gate.on(*qubits) if ( protocols.has_unitary(self.sub_gate) and protocols.num_qubits(self.sub_gate) == 1