diff --git a/cirq-core/cirq/ops/wait_gate.py b/cirq-core/cirq/ops/wait_gate.py index e94419ae475..fd5d860102d 100644 --- a/cirq-core/cirq/ops/wait_gate.py +++ b/cirq-core/cirq/ops/wait_gate.py @@ -16,6 +16,8 @@ from typing import AbstractSet, Any, TYPE_CHECKING +import sympy + from cirq import protocols, value from cirq.ops import raw_types @@ -51,8 +53,11 @@ def __init__( ValueError: If the `qid_shape` provided is empty or `num_qubits` contradicts `qid_shape`. """ - self._duration = value.Duration(duration) - if not protocols.is_parameterized(self.duration) and self.duration < 0: + self._duration = ( + duration if isinstance(duration, value.Duration) else value.Duration(duration) + ) + total_picos = self.duration.total_picos() + if not isinstance(total_picos, sympy.Basic) and total_picos < 0: raise ValueError('duration < 0') if qid_shape is None: if num_qubits is None: diff --git a/cirq-core/cirq/value/duration.py b/cirq-core/cirq/value/duration.py index ef5a966604b..1c328d39c1d 100644 --- a/cirq-core/cirq/value/duration.py +++ b/cirq-core/cirq/value/duration.py @@ -84,19 +84,19 @@ def __init__( """ self._time_vals: list[_NUMERIC_INPUT_TYPE] = [0, 0, 0, 0] self._multipliers = [1, 1000, 1000_000, 1000_000_000] - if value is not None and value != 0: + if value is not None: if isinstance(value, datetime.timedelta): # timedelta has microsecond resolution. self._time_vals[2] = int(value / datetime.timedelta(microseconds=1)) elif isinstance(value, Duration): self._time_vals = value._time_vals - else: + elif value != 0: raise TypeError(f'Not a `cirq.DURATION_LIKE`: {repr(value)}.') input_vals = [picos, nanos, micros, millis] self._time_vals = _add_time_vals(self._time_vals, input_vals) def _is_parameterized_(self) -> bool: - return protocols.is_parameterized(self._time_vals) + return any(isinstance(val, sympy.Basic) for val in self._time_vals) def _parameter_names_(self) -> AbstractSet[str]: return protocols.parameter_names(self._time_vals)