Skip to content

Commit 8a74a95

Browse files
authored
refactor(framework) Move utils functions out of serde.py (#5343)
1 parent 3cecfd5 commit 8a74a95

File tree

3 files changed

+133
-104
lines changed

3 files changed

+133
-104
lines changed

framework/py/flwr/common/constant.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,9 @@
131131
HEAD_BODY_DIVIDER = b"\x00"
132132
TYPE_BODY_LEN_DIVIDER = " "
133133

134+
# Constants for serialization
135+
INT64_MAX_VALUE = 9223372036854775807 # (1 << 63) - 1
136+
134137

135138
class MessageType:
136139
"""Message type."""

framework/py/flwr/common/serde.py

Lines changed: 7 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,7 @@
1616

1717

1818
from collections import OrderedDict
19-
from collections.abc import MutableMapping
20-
from typing import Any, TypeVar, cast
21-
22-
from google.protobuf.message import Message as GrpcMessage
19+
from typing import Any, cast
2320

2421
# pylint: disable=E0611
2522
from flwr.proto.clientappio_pb2 import ClientAppOutputCode, ClientAppOutputStatus
@@ -30,14 +27,11 @@
3027
from flwr.proto.message_pb2 import Metadata as ProtoMetadata
3128
from flwr.proto.recorddict_pb2 import Array as ProtoArray
3229
from flwr.proto.recorddict_pb2 import ArrayRecord as ProtoArrayRecord
33-
from flwr.proto.recorddict_pb2 import BoolList, BytesList
3430
from flwr.proto.recorddict_pb2 import ConfigRecord as ProtoConfigRecord
3531
from flwr.proto.recorddict_pb2 import ConfigRecordValue as ProtoConfigRecordValue
36-
from flwr.proto.recorddict_pb2 import DoubleList
3732
from flwr.proto.recorddict_pb2 import MetricRecord as ProtoMetricRecord
3833
from flwr.proto.recorddict_pb2 import MetricRecordValue as ProtoMetricRecordValue
3934
from flwr.proto.recorddict_pb2 import RecordDict as ProtoRecordDict
40-
from flwr.proto.recorddict_pb2 import SintList, StringList, UintList
4135
from flwr.proto.run_pb2 import Run as ProtoRun
4236
from flwr.proto.run_pb2 import RunStatus as ProtoRunStatus
4337
from flwr.proto.transport_pb2 import (
@@ -60,8 +54,9 @@
6054
RecordDict,
6155
typing,
6256
)
57+
from .constant import INT64_MAX_VALUE
6358
from .message import Error, Message, Metadata, make_message
64-
from .record.typeddict import TypedDict
59+
from .serde_utils import record_value_dict_from_proto, record_value_dict_to_proto
6560

6661
# === Parameters message ===
6762

@@ -339,7 +334,6 @@ def metrics_from_proto(proto: Any) -> typing.Metrics:
339334

340335

341336
# === Scalar messages ===
342-
INT64_MAX_VALUE = 9223372036854775807 # (1 << 63) - 1
343337

344338

345339
def scalar_to_proto(scalar: typing.Scalar) -> Scalar:
@@ -377,97 +371,6 @@ def scalar_from_proto(scalar_msg: Scalar) -> typing.Scalar:
377371
# === Record messages ===
378372

379373

380-
_type_to_field: dict[type, str] = {
381-
float: "double",
382-
int: "sint64",
383-
bool: "bool",
384-
str: "string",
385-
bytes: "bytes",
386-
}
387-
_list_type_to_class_and_field: dict[type, tuple[type[GrpcMessage], str]] = {
388-
float: (DoubleList, "double_list"),
389-
int: (SintList, "sint_list"),
390-
bool: (BoolList, "bool_list"),
391-
str: (StringList, "string_list"),
392-
bytes: (BytesList, "bytes_list"),
393-
}
394-
T = TypeVar("T")
395-
396-
397-
def _is_uint64(value: Any) -> bool:
398-
"""Check if a value is uint64."""
399-
return isinstance(value, int) and value > INT64_MAX_VALUE
400-
401-
402-
def _record_value_to_proto(
403-
value: Any, allowed_types: list[type], proto_class: type[T]
404-
) -> T:
405-
"""Serialize `*RecordValue` to ProtoBuf.
406-
407-
Note: `bool` MUST be put in the front of allowd_types if it exists.
408-
"""
409-
arg = {}
410-
for t in allowed_types:
411-
# Single element
412-
# Note: `isinstance(False, int) == True`.
413-
if isinstance(value, t):
414-
fld = _type_to_field[t]
415-
if t is int and _is_uint64(value):
416-
fld = "uint64"
417-
arg[fld] = value
418-
return proto_class(**arg)
419-
# List
420-
if isinstance(value, list) and all(isinstance(item, t) for item in value):
421-
list_class, fld = _list_type_to_class_and_field[t]
422-
# Use UintList if any element is of type `uint64`.
423-
if t is int and any(_is_uint64(v) for v in value):
424-
list_class, fld = UintList, "uint_list"
425-
arg[fld] = list_class(vals=value)
426-
return proto_class(**arg)
427-
# Invalid types
428-
raise TypeError(
429-
f"The type of the following value is not allowed "
430-
f"in '{proto_class.__name__}':\n{value}"
431-
)
432-
433-
434-
def _record_value_from_proto(value_proto: GrpcMessage) -> Any:
435-
"""Deserialize `*RecordValue` from ProtoBuf."""
436-
value_field = cast(str, value_proto.WhichOneof("value"))
437-
if value_field.endswith("list"):
438-
value = list(getattr(value_proto, value_field).vals)
439-
else:
440-
value = getattr(value_proto, value_field)
441-
return value
442-
443-
444-
def _record_value_dict_to_proto(
445-
value_dict: TypedDict[str, Any],
446-
allowed_types: list[type],
447-
value_proto_class: type[T],
448-
) -> dict[str, T]:
449-
"""Serialize the record value dict to ProtoBuf.
450-
451-
Note: `bool` MUST be put in the front of allowd_types if it exists.
452-
"""
453-
# Move bool to the front
454-
if bool in allowed_types and allowed_types[0] != bool:
455-
allowed_types.remove(bool)
456-
allowed_types.insert(0, bool)
457-
458-
def proto(_v: Any) -> T:
459-
return _record_value_to_proto(_v, allowed_types, value_proto_class)
460-
461-
return {k: proto(v) for k, v in value_dict.items()}
462-
463-
464-
def _record_value_dict_from_proto(
465-
value_dict_proto: MutableMapping[str, Any]
466-
) -> dict[str, Any]:
467-
"""Deserialize the record value dict from ProtoBuf."""
468-
return {k: _record_value_from_proto(v) for k, v in value_dict_proto.items()}
469-
470-
471374
def array_to_proto(array: Array) -> ProtoArray:
472375
"""Serialize Array to ProtoBuf."""
473376
return ProtoArray(**vars(array))
@@ -506,7 +409,7 @@ def array_record_from_proto(
506409
def metric_record_to_proto(record: MetricRecord) -> ProtoMetricRecord:
507410
"""Serialize MetricRecord to ProtoBuf."""
508411
return ProtoMetricRecord(
509-
data=_record_value_dict_to_proto(record, [float, int], ProtoMetricRecordValue)
412+
data=record_value_dict_to_proto(record, [float, int], ProtoMetricRecordValue)
510413
)
511414

512415

@@ -515,7 +418,7 @@ def metric_record_from_proto(record_proto: ProtoMetricRecord) -> MetricRecord:
515418
return MetricRecord(
516419
metric_dict=cast(
517420
dict[str, typing.MetricRecordValues],
518-
_record_value_dict_from_proto(record_proto.data),
421+
record_value_dict_from_proto(record_proto.data),
519422
),
520423
keep_input=False,
521424
)
@@ -524,7 +427,7 @@ def metric_record_from_proto(record_proto: ProtoMetricRecord) -> MetricRecord:
524427
def config_record_to_proto(record: ConfigRecord) -> ProtoConfigRecord:
525428
"""Serialize ConfigRecord to ProtoBuf."""
526429
return ProtoConfigRecord(
527-
data=_record_value_dict_to_proto(
430+
data=record_value_dict_to_proto(
528431
record,
529432
[bool, int, float, str, bytes],
530433
ProtoConfigRecordValue,
@@ -537,7 +440,7 @@ def config_record_from_proto(record_proto: ProtoConfigRecord) -> ConfigRecord:
537440
return ConfigRecord(
538441
config_dict=cast(
539442
dict[str, typing.ConfigRecordValues],
540-
_record_value_dict_from_proto(record_proto.data),
443+
record_value_dict_from_proto(record_proto.data),
541444
),
542445
keep_input=False,
543446
)
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Utils for serde."""
16+
17+
from collections.abc import MutableMapping
18+
from typing import Any, TypeVar, cast
19+
20+
from google.protobuf.message import Message as GrpcMessage
21+
22+
# pylint: disable=E0611
23+
from flwr.proto.recorddict_pb2 import (
24+
BoolList,
25+
BytesList,
26+
DoubleList,
27+
SintList,
28+
StringList,
29+
UintList,
30+
)
31+
32+
from .constant import INT64_MAX_VALUE
33+
from .record.typeddict import TypedDict
34+
35+
_type_to_field: dict[type, str] = {
36+
float: "double",
37+
int: "sint64",
38+
bool: "bool",
39+
str: "string",
40+
bytes: "bytes",
41+
}
42+
_list_type_to_class_and_field: dict[type, tuple[type[GrpcMessage], str]] = {
43+
float: (DoubleList, "double_list"),
44+
int: (SintList, "sint_list"),
45+
bool: (BoolList, "bool_list"),
46+
str: (StringList, "string_list"),
47+
bytes: (BytesList, "bytes_list"),
48+
}
49+
T = TypeVar("T")
50+
51+
52+
def _is_uint64(value: Any) -> bool:
53+
"""Check if a value is uint64."""
54+
return isinstance(value, int) and value > INT64_MAX_VALUE
55+
56+
57+
def _record_value_to_proto(
58+
value: Any, allowed_types: list[type], proto_class: type[T]
59+
) -> T:
60+
"""Serialize `*RecordValue` to ProtoBuf.
61+
62+
Note: `bool` MUST be put in the front of allowd_types if it exists.
63+
"""
64+
arg = {}
65+
for t in allowed_types:
66+
# Single element
67+
# Note: `isinstance(False, int) == True`.
68+
if isinstance(value, t):
69+
fld = _type_to_field[t]
70+
if t is int and _is_uint64(value):
71+
fld = "uint64"
72+
arg[fld] = value
73+
return proto_class(**arg)
74+
# List
75+
if isinstance(value, list) and all(isinstance(item, t) for item in value):
76+
list_class, fld = _list_type_to_class_and_field[t]
77+
# Use UintList if any element is of type `uint64`.
78+
if t is int and any(_is_uint64(v) for v in value):
79+
list_class, fld = UintList, "uint_list"
80+
arg[fld] = list_class(vals=value)
81+
return proto_class(**arg)
82+
# Invalid types
83+
raise TypeError(
84+
f"The type of the following value is not allowed "
85+
f"in '{proto_class.__name__}':\n{value}"
86+
)
87+
88+
89+
def _record_value_from_proto(value_proto: GrpcMessage) -> Any:
90+
"""Deserialize `*RecordValue` from ProtoBuf."""
91+
value_field = cast(str, value_proto.WhichOneof("value"))
92+
if value_field.endswith("list"):
93+
value = list(getattr(value_proto, value_field).vals)
94+
else:
95+
value = getattr(value_proto, value_field)
96+
return value
97+
98+
99+
def record_value_dict_to_proto(
100+
value_dict: TypedDict[str, Any],
101+
allowed_types: list[type],
102+
value_proto_class: type[T],
103+
) -> dict[str, T]:
104+
"""Serialize the record value dict to ProtoBuf.
105+
106+
Note: `bool` MUST be put in the front of allowd_types if it exists.
107+
"""
108+
# Move bool to the front
109+
if bool in allowed_types and allowed_types[0] != bool:
110+
allowed_types.remove(bool)
111+
allowed_types.insert(0, bool)
112+
113+
def proto(_v: Any) -> T:
114+
return _record_value_to_proto(_v, allowed_types, value_proto_class)
115+
116+
return {k: proto(v) for k, v in value_dict.items()}
117+
118+
119+
def record_value_dict_from_proto(
120+
value_dict_proto: MutableMapping[str, Any]
121+
) -> dict[str, Any]:
122+
"""Deserialize the record value dict from ProtoBuf."""
123+
return {k: _record_value_from_proto(v) for k, v in value_dict_proto.items()}

0 commit comments

Comments
 (0)