-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Add tuples, ndarrays, and complex numbers to cirq_google proto #7226
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
ff66e84
dd0a3f4
adaeca2
8b50d62
851c1b9
e6977f5
d77b78a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Original file line number | Diff line number | Diff line change | ||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -21,6 +21,7 @@ | |||||||||||||
|
||||||||||||||
from cirq.qis import CliffordTableau | ||||||||||||||
from cirq_google.api import v2 | ||||||||||||||
from cirq_google.api.v2 import ndarrays | ||||||||||||||
from cirq_google.ops import InternalGate | ||||||||||||||
|
||||||||||||||
SUPPORTED_SYMPY_OPS = (sympy.Symbol, sympy.Add, sympy.Mul, sympy.Pow) | ||||||||||||||
|
@@ -133,15 +134,22 @@ def arg_to_proto( | |||||||||||||
msg.arg_value.bool_value = bool(value) | ||||||||||||||
elif isinstance(value, FLOAT_TYPES): | ||||||||||||||
msg.arg_value.float_value = float(value) | ||||||||||||||
elif isinstance(value, complex): | ||||||||||||||
msg.arg_value.complex_value.real_value = value.real | ||||||||||||||
msg.arg_value.complex_value.imaginary_value = value.imag | ||||||||||||||
elif isinstance(value, bytes): | ||||||||||||||
msg.arg_value.bytes_value = value | ||||||||||||||
elif isinstance(value, str): | ||||||||||||||
msg.arg_value.string_value = value | ||||||||||||||
elif isinstance(value, (list, tuple, np.ndarray)): | ||||||||||||||
if len(value): | ||||||||||||||
|
||||||||||||||
if isinstance(value[0], str): | ||||||||||||||
if isinstance(value, np.ndarray): | ||||||||||||||
_ndarray_to_proto(value, out=msg) | ||||||||||||||
elif isinstance(value[0], str): | ||||||||||||||
if not all(isinstance(x, str) for x in value): | ||||||||||||||
raise ValueError('Sequences of mixed object types are not supported') | ||||||||||||||
# Not a uniform list, convert to tuple | ||||||||||||||
_tuple_to_proto(value, out=msg.arg_value.tuple_value) | ||||||||||||||
return msg | ||||||||||||||
msg.arg_value.string_values.values.extend(str(x) for x in value) | ||||||||||||||
else: | ||||||||||||||
# This is a numerical field. | ||||||||||||||
|
@@ -162,10 +170,9 @@ def arg_to_proto( | |||||||||||||
break | ||||||||||||||
|
||||||||||||||
if non_numerical is not None: | ||||||||||||||
raise ValueError( | ||||||||||||||
'Mixed Sequences with objects of type ' | ||||||||||||||
f'{type(non_numerical)} are not supported' | ||||||||||||||
) | ||||||||||||||
# Not a uniform list, convert to tuple | ||||||||||||||
_tuple_to_proto(value, out=msg.arg_value.tuple_value) | ||||||||||||||
return msg | ||||||||||||||
field, types_tuple = numerical_fields[cur_index] | ||||||||||||||
field.extend(types_tuple[0](x) for x in value) | ||||||||||||||
elif isinstance(value, tunits.Value): | ||||||||||||||
|
@@ -176,6 +183,67 @@ def arg_to_proto( | |||||||||||||
return msg | ||||||||||||||
|
||||||||||||||
|
||||||||||||||
def _ndarray_to_proto(value: np.ndarray, out: v2.program_pb2.Arg): | ||||||||||||||
ndarray_msg = out.arg_value.ndarray_value | ||||||||||||||
match value.dtype.name: | ||||||||||||||
case 'float64': | ||||||||||||||
|
||||||||||||||
ndarrays.to_float64_array(value, out=ndarray_msg.float64_array) | ||||||||||||||
case 'float32': | ||||||||||||||
ndarrays.to_float32_array(value, out=ndarray_msg.float32_array) | ||||||||||||||
case 'float16': | ||||||||||||||
ndarrays.to_float16_array(value, out=ndarray_msg.float16_array) | ||||||||||||||
case 'int64': | ||||||||||||||
ndarrays.to_int64_array(value, out=ndarray_msg.int64_array) | ||||||||||||||
case 'int32': | ||||||||||||||
ndarrays.to_int32_array(value, out=ndarray_msg.int32_array) | ||||||||||||||
case 'int16': | ||||||||||||||
ndarrays.to_int16_array(value, out=ndarray_msg.int16_array) | ||||||||||||||
case 'int8': | ||||||||||||||
ndarrays.to_int8_array(value, out=ndarray_msg.int8_array) | ||||||||||||||
case 'uint8': | ||||||||||||||
ndarrays.to_uint8_array(value, out=ndarray_msg.uint8_array) | ||||||||||||||
case 'complex128': | ||||||||||||||
ndarrays.to_complex128_array(value, out=ndarray_msg.complex128_array) | ||||||||||||||
case 'complex64': | ||||||||||||||
ndarrays.to_complex64_array(value, out=ndarray_msg.complex64_array) | ||||||||||||||
case 'bool': | ||||||||||||||
ndarrays.to_bitarray(value, out=ndarray_msg.bit_array) | ||||||||||||||
|
||||||||||||||
|
||||||||||||||
def _ndarray_from_proto(msg: v2.program_pb2.ArgValue): | ||||||||||||||
ndarray_msg = msg.ndarray_value | ||||||||||||||
match ndarray_msg.WhichOneof('arr'): | ||||||||||||||
case 'float64_array': | ||||||||||||||
return ndarrays.from_float64_array(ndarray_msg.float64_array) | ||||||||||||||
case 'float32_array': | ||||||||||||||
return ndarrays.from_float32_array(ndarray_msg.float32_array) | ||||||||||||||
case 'float16_array': | ||||||||||||||
return ndarrays.from_float16_array(ndarray_msg.float16_array) | ||||||||||||||
case 'int64_array': | ||||||||||||||
return ndarrays.from_int64_array(ndarray_msg.int64_array) | ||||||||||||||
case 'int32_array': | ||||||||||||||
return ndarrays.from_int32_array(ndarray_msg.int32_array) | ||||||||||||||
case 'int16_array': | ||||||||||||||
return ndarrays.from_int16_array(ndarray_msg.int16_array) | ||||||||||||||
case 'int8_array': | ||||||||||||||
return ndarrays.from_int8_array(ndarray_msg.int8_array) | ||||||||||||||
case 'uint8_array': | ||||||||||||||
return ndarrays.from_uint8_array(ndarray_msg.uint8_array) | ||||||||||||||
case 'complex128_array': | ||||||||||||||
return ndarrays.from_complex128_array(ndarray_msg.complex128_array) | ||||||||||||||
case 'complex64_array': | ||||||||||||||
return ndarrays.from_complex64_array(ndarray_msg.complex64_array) | ||||||||||||||
case 'bit_array': | ||||||||||||||
return ndarrays.from_bitarray(ndarray_msg.bit_array) | ||||||||||||||
|
||||||||||||||
|
||||||||||||||
def _tuple_to_proto(value: Union[list, tuple], out: v2.program_pb2.Tuple): | ||||||||||||||
"""Converts a tuple of mixed values to Arg protos.""" | ||||||||||||||
for arg in value: | ||||||||||||||
new_arg = out.values.add() | ||||||||||||||
arg_to_proto(arg, out=new_arg) | ||||||||||||||
|
||||||||||||||
|
||||||||||||||
def _arg_func_to_proto( | ||||||||||||||
value: ARG_LIKE, msg: Union[v2.program_pb2.Arg, v2.program_pb2.FloatArg] | ||||||||||||||
) -> None: | ||||||||||||||
|
@@ -299,6 +367,17 @@ def arg_from_proto( | |||||||||||||
return tunits.Value.from_proto(arg_value.value_with_unit) | ||||||||||||||
case 'bytes_value': | ||||||||||||||
return bytes(arg_value.bytes_value) | ||||||||||||||
case 'complex_value': | ||||||||||||||
return ( | ||||||||||||||
arg_value.complex_value.real_value | ||||||||||||||
+ 1j * arg_value.complex_value.imaginary_value | ||||||||||||||
|
return ( | |
arg_value.complex_value.real_value | |
+ 1j * arg_value.complex_value.imaginary_value | |
return complex( | |
arg_value.complex_value.real_value, | |
arg_value.complex_value.imaginary_value |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we use
real
andimag
for parity with builtin complex and numpy numeric types?