Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions cirq-google/cirq_google/api/v2/program.proto
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ syntax = "proto3";
package cirq.google.api.v2;

import "tunits/proto/tunits.proto";
import "cirq_google/api/v2/ndarrays.proto";

option java_package = "com.google.cirq.google.api.v2";
option java_outer_classname = "ProgramProto";
Expand Down Expand Up @@ -422,6 +423,9 @@ message ArgValue {
tunits.Value value_with_unit = 8;
bool bool_value = 9;
bytes bytes_value = 10;
Complex complex_value = 11;
Tuple tuple_value = 12;
NDArray ndarray_value = 13;
}
}

Expand All @@ -442,6 +446,33 @@ message RepeatedBoolean {
repeated bool values = 1;
}

// Representation of a mixed tuple of values
message Tuple {
repeated Arg values = 1;
}

// Representation of a complex number
message Complex {
double real_value = 1;
double imaginary_value = 2;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use real and imag for parity with builtin complex and numpy numeric types?

}

message NDArray {
oneof arr {
Complex128Array complex128_array = 1;
Complex64Array complex64_array = 2;
Float16Array float16_array = 3;
Float32Array float32_array = 4;
Float64Array float64_array = 5;
Int64Array int64_array = 6;
Int32Array int32_array = 7;
Int16Array int16_array = 8;
Int8Array int8_array = 9;
UInt8Array uint8_array = 10;
BitArray bit_array = 11;
}
}

// A function of arguments. This is an s-expression tree representing
// mathematically the function being evaluated.
//
Expand Down
257 changes: 132 additions & 125 deletions cirq-google/cirq_google/api/v2/program_pb2.py

Large diffs are not rendered by default.

115 changes: 112 additions & 3 deletions cirq-google/cirq_google/api/v2/program_pb2.pyi

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

91 changes: 85 additions & 6 deletions cirq-google/cirq_google/serialization/arg_func_langs.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from cirq.qis import CliffordTableau
from cirq_google.api import v2
from cirq_google.api.v2 import ndarrays
from cirq_google.ops import InternalGate

SUPPORTED_SYMPY_OPS = (sympy.Symbol, sympy.Add, sympy.Mul, sympy.Pow)
Expand Down Expand Up @@ -133,15 +134,22 @@ def arg_to_proto(
msg.arg_value.bool_value = bool(value)
elif isinstance(value, FLOAT_TYPES):
msg.arg_value.float_value = float(value)
elif isinstance(value, complex):
msg.arg_value.complex_value.real_value = value.real
msg.arg_value.complex_value.imaginary_value = value.imag
elif isinstance(value, bytes):
msg.arg_value.bytes_value = value
elif isinstance(value, str):
msg.arg_value.string_value = value
elif isinstance(value, (list, tuple, np.ndarray)):
if len(value):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On a round trip this converts (i) uniform-type lists or tuples to lists and (ii) non-uniform lists and tuples to tuples. It also discards empty lists, tuples and arrays, which return as None from round trip.

We can move the array conversion out of the if len(value) block as _ndarray_to_proto and _ndarray_from_proto seem to work with empty arrays.

As for tuples and lists, we could add a sequence_type enum to the ArgValue message which could be (UNSPECIFIED, LIST, TUPLE) and would let us reconstruct either list or tuple from a round trip. It would also allow us to express either an empty list or an empty tuple, in such case ArgValue would be empty except of a (new) sequence_type.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea! I went ahead and did that, and also added set support while in there.

if isinstance(value[0], str):
if isinstance(value, np.ndarray):
_ndarray_to_proto(value, out=msg)
elif isinstance(value[0], str):
if not all(isinstance(x, str) for x in value):
raise ValueError('Sequences of mixed object types are not supported')
# Not a uniform list, convert to tuple
_tuple_to_proto(value, out=msg.arg_value.tuple_value)
return msg
msg.arg_value.string_values.values.extend(str(x) for x in value)
else:
# This is a numerical field.
Expand All @@ -162,10 +170,9 @@ def arg_to_proto(
break

if non_numerical is not None:
raise ValueError(
'Mixed Sequences with objects of type '
f'{type(non_numerical)} are not supported'
)
# Not a uniform list, convert to tuple
_tuple_to_proto(value, out=msg.arg_value.tuple_value)
return msg
field, types_tuple = numerical_fields[cur_index]
field.extend(types_tuple[0](x) for x in value)
elif isinstance(value, tunits.Value):
Expand All @@ -176,6 +183,67 @@ def arg_to_proto(
return msg


def _ndarray_to_proto(value: np.ndarray, out: v2.program_pb2.Arg):
ndarray_msg = out.arg_value.ndarray_value
match value.dtype.name:
case 'float64':
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can match value.dtype rather than the name. This would protect against typo in string or a rename of data type.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

ndarrays.to_float64_array(value, out=ndarray_msg.float64_array)
case 'float32':
ndarrays.to_float32_array(value, out=ndarray_msg.float32_array)
case 'float16':
ndarrays.to_float16_array(value, out=ndarray_msg.float16_array)
case 'int64':
ndarrays.to_int64_array(value, out=ndarray_msg.int64_array)
case 'int32':
ndarrays.to_int32_array(value, out=ndarray_msg.int32_array)
case 'int16':
ndarrays.to_int16_array(value, out=ndarray_msg.int16_array)
case 'int8':
ndarrays.to_int8_array(value, out=ndarray_msg.int8_array)
case 'uint8':
ndarrays.to_uint8_array(value, out=ndarray_msg.uint8_array)
case 'complex128':
ndarrays.to_complex128_array(value, out=ndarray_msg.complex128_array)
case 'complex64':
ndarrays.to_complex64_array(value, out=ndarray_msg.complex64_array)
case 'bool':
ndarrays.to_bitarray(value, out=ndarray_msg.bit_array)


def _ndarray_from_proto(msg: v2.program_pb2.ArgValue):
ndarray_msg = msg.ndarray_value
match ndarray_msg.WhichOneof('arr'):
case 'float64_array':
return ndarrays.from_float64_array(ndarray_msg.float64_array)
case 'float32_array':
return ndarrays.from_float32_array(ndarray_msg.float32_array)
case 'float16_array':
return ndarrays.from_float16_array(ndarray_msg.float16_array)
case 'int64_array':
return ndarrays.from_int64_array(ndarray_msg.int64_array)
case 'int32_array':
return ndarrays.from_int32_array(ndarray_msg.int32_array)
case 'int16_array':
return ndarrays.from_int16_array(ndarray_msg.int16_array)
case 'int8_array':
return ndarrays.from_int8_array(ndarray_msg.int8_array)
case 'uint8_array':
return ndarrays.from_uint8_array(ndarray_msg.uint8_array)
case 'complex128_array':
return ndarrays.from_complex128_array(ndarray_msg.complex128_array)
case 'complex64_array':
return ndarrays.from_complex64_array(ndarray_msg.complex64_array)
case 'bit_array':
return ndarrays.from_bitarray(ndarray_msg.bit_array)


def _tuple_to_proto(value: Union[list, tuple], out: v2.program_pb2.Tuple):
"""Converts a tuple of mixed values to Arg protos."""
for arg in value:
new_arg = out.values.add()
arg_to_proto(arg, out=new_arg)


def _arg_func_to_proto(
value: ARG_LIKE, msg: Union[v2.program_pb2.Arg, v2.program_pb2.FloatArg]
) -> None:
Expand Down Expand Up @@ -299,6 +367,17 @@ def arg_from_proto(
return tunits.Value.from_proto(arg_value.value_with_unit)
case 'bytes_value':
return bytes(arg_value.bytes_value)
case 'complex_value':
return (
arg_value.complex_value.real_value
+ 1j * arg_value.complex_value.imaginary_value
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return (
arg_value.complex_value.real_value
+ 1j * arg_value.complex_value.imaginary_value
return complex(
arg_value.complex_value.real_value,
arg_value.complex_value.imaginary_value

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

)
case 'tuple_value':
return tuple(
arg_from_proto(tuple_proto) for tuple_proto in arg_value.tuple_value.values
)
case 'ndarray_value':
return _ndarray_from_proto(arg_value)
raise ValueError(f'Unrecognized value type: {which_val!r}') # pragma: no cover
case 'symbol':
return sympy.Symbol(arg_proto.symbol)
Expand Down
Loading
Loading