1616
1717
1818from collections import OrderedDict
19- from collections .abc import MutableMapping
20- from typing import Any , TypeVar , cast
21-
22- from google .protobuf .message import Message as GrpcMessage
19+ from typing import Any , cast
2320
2421# pylint: disable=E0611
2522from flwr .proto .clientappio_pb2 import ClientAppOutputCode , ClientAppOutputStatus
3027from flwr .proto .message_pb2 import Metadata as ProtoMetadata
3128from flwr .proto .recorddict_pb2 import Array as ProtoArray
3229from flwr .proto .recorddict_pb2 import ArrayRecord as ProtoArrayRecord
33- from flwr .proto .recorddict_pb2 import BoolList , BytesList
3430from flwr .proto .recorddict_pb2 import ConfigRecord as ProtoConfigRecord
3531from flwr .proto .recorddict_pb2 import ConfigRecordValue as ProtoConfigRecordValue
36- from flwr .proto .recorddict_pb2 import DoubleList
3732from flwr .proto .recorddict_pb2 import MetricRecord as ProtoMetricRecord
3833from flwr .proto .recorddict_pb2 import MetricRecordValue as ProtoMetricRecordValue
3934from flwr .proto .recorddict_pb2 import RecordDict as ProtoRecordDict
40- from flwr .proto .recorddict_pb2 import SintList , StringList , UintList
4135from flwr .proto .run_pb2 import Run as ProtoRun
4236from flwr .proto .run_pb2 import RunStatus as ProtoRunStatus
4337from flwr .proto .transport_pb2 import (
6054 RecordDict ,
6155 typing ,
6256)
57+ from .constant import INT64_MAX_VALUE
6358from .message import Error , Message , Metadata , make_message
64- from .record . typeddict import TypedDict
59+ from .serde_utils import record_value_dict_from_proto , record_value_dict_to_proto
6560
6661# === Parameters message ===
6762
@@ -339,7 +334,6 @@ def metrics_from_proto(proto: Any) -> typing.Metrics:
339334
340335
341336# === Scalar messages ===
342- INT64_MAX_VALUE = 9223372036854775807 # (1 << 63) - 1
343337
344338
345339def scalar_to_proto (scalar : typing .Scalar ) -> Scalar :
@@ -377,97 +371,6 @@ def scalar_from_proto(scalar_msg: Scalar) -> typing.Scalar:
377371# === Record messages ===
378372
379373
380- _type_to_field : dict [type , str ] = {
381- float : "double" ,
382- int : "sint64" ,
383- bool : "bool" ,
384- str : "string" ,
385- bytes : "bytes" ,
386- }
387- _list_type_to_class_and_field : dict [type , tuple [type [GrpcMessage ], str ]] = {
388- float : (DoubleList , "double_list" ),
389- int : (SintList , "sint_list" ),
390- bool : (BoolList , "bool_list" ),
391- str : (StringList , "string_list" ),
392- bytes : (BytesList , "bytes_list" ),
393- }
394- T = TypeVar ("T" )
395-
396-
397- def _is_uint64 (value : Any ) -> bool :
398- """Check if a value is uint64."""
399- return isinstance (value , int ) and value > INT64_MAX_VALUE
400-
401-
402- def _record_value_to_proto (
403- value : Any , allowed_types : list [type ], proto_class : type [T ]
404- ) -> T :
405- """Serialize `*RecordValue` to ProtoBuf.
406-
407- Note: `bool` MUST be put in the front of allowd_types if it exists.
408- """
409- arg = {}
410- for t in allowed_types :
411- # Single element
412- # Note: `isinstance(False, int) == True`.
413- if isinstance (value , t ):
414- fld = _type_to_field [t ]
415- if t is int and _is_uint64 (value ):
416- fld = "uint64"
417- arg [fld ] = value
418- return proto_class (** arg )
419- # List
420- if isinstance (value , list ) and all (isinstance (item , t ) for item in value ):
421- list_class , fld = _list_type_to_class_and_field [t ]
422- # Use UintList if any element is of type `uint64`.
423- if t is int and any (_is_uint64 (v ) for v in value ):
424- list_class , fld = UintList , "uint_list"
425- arg [fld ] = list_class (vals = value )
426- return proto_class (** arg )
427- # Invalid types
428- raise TypeError (
429- f"The type of the following value is not allowed "
430- f"in '{ proto_class .__name__ } ':\n { value } "
431- )
432-
433-
434- def _record_value_from_proto (value_proto : GrpcMessage ) -> Any :
435- """Deserialize `*RecordValue` from ProtoBuf."""
436- value_field = cast (str , value_proto .WhichOneof ("value" ))
437- if value_field .endswith ("list" ):
438- value = list (getattr (value_proto , value_field ).vals )
439- else :
440- value = getattr (value_proto , value_field )
441- return value
442-
443-
444- def _record_value_dict_to_proto (
445- value_dict : TypedDict [str , Any ],
446- allowed_types : list [type ],
447- value_proto_class : type [T ],
448- ) -> dict [str , T ]:
449- """Serialize the record value dict to ProtoBuf.
450-
451- Note: `bool` MUST be put in the front of allowd_types if it exists.
452- """
453- # Move bool to the front
454- if bool in allowed_types and allowed_types [0 ] != bool :
455- allowed_types .remove (bool )
456- allowed_types .insert (0 , bool )
457-
458- def proto (_v : Any ) -> T :
459- return _record_value_to_proto (_v , allowed_types , value_proto_class )
460-
461- return {k : proto (v ) for k , v in value_dict .items ()}
462-
463-
464- def _record_value_dict_from_proto (
465- value_dict_proto : MutableMapping [str , Any ]
466- ) -> dict [str , Any ]:
467- """Deserialize the record value dict from ProtoBuf."""
468- return {k : _record_value_from_proto (v ) for k , v in value_dict_proto .items ()}
469-
470-
471374def array_to_proto (array : Array ) -> ProtoArray :
472375 """Serialize Array to ProtoBuf."""
473376 return ProtoArray (** vars (array ))
@@ -506,7 +409,7 @@ def array_record_from_proto(
506409def metric_record_to_proto (record : MetricRecord ) -> ProtoMetricRecord :
507410 """Serialize MetricRecord to ProtoBuf."""
508411 return ProtoMetricRecord (
509- data = _record_value_dict_to_proto (record , [float , int ], ProtoMetricRecordValue )
412+ data = record_value_dict_to_proto (record , [float , int ], ProtoMetricRecordValue )
510413 )
511414
512415
@@ -515,7 +418,7 @@ def metric_record_from_proto(record_proto: ProtoMetricRecord) -> MetricRecord:
515418 return MetricRecord (
516419 metric_dict = cast (
517420 dict [str , typing .MetricRecordValues ],
518- _record_value_dict_from_proto (record_proto .data ),
421+ record_value_dict_from_proto (record_proto .data ),
519422 ),
520423 keep_input = False ,
521424 )
@@ -524,7 +427,7 @@ def metric_record_from_proto(record_proto: ProtoMetricRecord) -> MetricRecord:
524427def config_record_to_proto (record : ConfigRecord ) -> ProtoConfigRecord :
525428 """Serialize ConfigRecord to ProtoBuf."""
526429 return ProtoConfigRecord (
527- data = _record_value_dict_to_proto (
430+ data = record_value_dict_to_proto (
528431 record ,
529432 [bool , int , float , str , bytes ],
530433 ProtoConfigRecordValue ,
@@ -537,7 +440,7 @@ def config_record_from_proto(record_proto: ProtoConfigRecord) -> ConfigRecord:
537440 return ConfigRecord (
538441 config_dict = cast (
539442 dict [str , typing .ConfigRecordValues ],
540- _record_value_dict_from_proto (record_proto .data ),
443+ record_value_dict_from_proto (record_proto .data ),
541444 ),
542445 keep_input = False ,
543446 )
0 commit comments