From 7d2f8d245d374724d6f7353980dd6269d7845176 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Sat, 1 Mar 2025 14:46:25 +0000 Subject: [PATCH] update constructor --- .../client/grpc_client/connection_test.py | 2 +- .../secure_aggregation/secaggplus_mod_test.py | 6 ++--- src/py/flwr/common/record/recordset.py | 16 ++++---------- src/py/flwr/common/record/recordset_test.py | 8 ++++--- src/py/flwr/common/serde.py | 20 +++++++---------- src/py/flwr/common/serde_test.py | 22 +++++++------------ .../superlink/linkstate/linkstate_test.py | 2 +- .../secure_aggregation/secaggplus_workflow.py | 6 ++--- 8 files changed, 32 insertions(+), 50 deletions(-) diff --git a/src/py/flwr/client/grpc_client/connection_test.py b/src/py/flwr/client/grpc_client/connection_test.py index 13bd2c6af8e7..9d6ecbe55ae8 100644 --- a/src/py/flwr/client/grpc_client/connection_test.py +++ b/src/py/flwr/client/grpc_client/connection_test.py @@ -69,7 +69,7 @@ ttl=DEFAULT_TTL, message_type="reconnect", ), - content=RecordSet(configs_records={"config": ConfigsRecord({"reason": 0})}), + content=RecordSet({"config": ConfigsRecord({"reason": 0})}), ) diff --git a/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py index c12beaa77800..7c754602bcf4 100644 --- a/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py +++ b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py @@ -62,9 +62,7 @@ def func(configs: dict[str, ConfigsRecordValues]) -> ConfigsRecord: ttl=DEFAULT_TTL, message_type=MessageType.TRAIN, ), - content=RecordSet( - configs_records={RECORD_KEY_CONFIGS: ConfigsRecord(configs)} - ), + content=RecordSet({RECORD_KEY_CONFIGS: ConfigsRecord(configs)}), ) out_msg = app(in_msg, ctxt) return out_msg.content.configs_records[RECORD_KEY_CONFIGS] @@ -78,7 +76,7 @@ def _make_ctxt() -> Context: run_id=234, node_id=123, node_config={}, - state=RecordSet(configs_records={RECORD_KEY_STATE: cfg}), + state=RecordSet({RECORD_KEY_STATE: cfg}), run_config={}, ) diff --git a/src/py/flwr/common/record/recordset.py b/src/py/flwr/common/record/recordset.py index 4dc3861fc41b..da39fbf43506 100644 --- a/src/py/flwr/common/record/recordset.py +++ b/src/py/flwr/common/record/recordset.py @@ -155,19 +155,11 @@ class RecordSet(TypedDict[str, RecordType]): :code:`MetricsRecord` and :code:`ParametersRecord`. """ - def __init__( - self, - parameters_records: dict[str, ParametersRecord] | None = None, - metrics_records: dict[str, MetricsRecord] | None = None, - configs_records: dict[str, ConfigsRecord] | None = None, - ) -> None: + def __init__(self, records: dict[str, RecordType] | None = None) -> None: super().__init__(_check_key, _check_value) - for key, p_record in (parameters_records or {}).items(): - self[key] = p_record - for key, m_record in (metrics_records or {}).items(): - self[key] = m_record - for key, c_record in (configs_records or {}).items(): - self[key] = c_record + if records is not None: + for key, record in records.items(): + self[key] = record @property def parameters_records(self) -> TypedDict[str, ParametersRecord]: diff --git a/src/py/flwr/common/record/recordset_test.py b/src/py/flwr/common/record/recordset_test.py index 561dd2dd6eeb..6dbc1172a4f1 100644 --- a/src/py/flwr/common/record/recordset_test.py +++ b/src/py/flwr/common/record/recordset_test.py @@ -422,9 +422,11 @@ def test_recordset_repr() -> None: """Test the string representation of RecordSet.""" # Prepare rs = RecordSet( - parameters_records={"params": ParametersRecord()}, - metrics_records={"metrics": MetricsRecord({"aa": 123})}, - configs_records={"configs": ConfigsRecord({"cc": bytes(5)})}, + { + "params": ParametersRecord(), + "metrics": MetricsRecord({"aa": 123}), + "configs": ConfigsRecord({"cc": bytes(5)}), + }, ) expected = """RecordSet( parameters_records={'params': {}}, diff --git a/src/py/flwr/common/serde.py b/src/py/flwr/common/serde.py index 17c929b47a82..3b66b67db3a4 100644 --- a/src/py/flwr/common/serde.py +++ b/src/py/flwr/common/serde.py @@ -583,18 +583,14 @@ def recordset_to_proto(recordset: RecordSet) -> ProtoRecordSet: def recordset_from_proto(recordset_proto: ProtoRecordSet) -> RecordSet: """Deserialize RecordSet from ProtoBuf.""" - return RecordSet( - parameters_records={ - k: parameters_record_from_proto(v) - for k, v in recordset_proto.parameters.items() - }, - metrics_records={ - k: metrics_record_from_proto(v) for k, v in recordset_proto.metrics.items() - }, - configs_records={ - k: configs_record_from_proto(v) for k, v in recordset_proto.configs.items() - }, - ) + ret = RecordSet() + for k, p_record_proto in recordset_proto.parameters.items(): + ret[k] = parameters_record_from_proto(p_record_proto) + for k, m_record_proto in recordset_proto.metrics.items(): + ret[k] = metrics_record_from_proto(m_record_proto) + for k, c_record_proto in recordset_proto.configs.items(): + ret[k] = configs_record_from_proto(c_record_proto) + return ret # === Message === diff --git a/src/py/flwr/common/serde_test.py b/src/py/flwr/common/serde_test.py index d4292757509b..bed1f2fe759a 100644 --- a/src/py/flwr/common/serde_test.py +++ b/src/py/flwr/common/serde_test.py @@ -247,20 +247,14 @@ def recordset( num_configs_records: int, ) -> RecordSet: """Create a RecordSet.""" - return RecordSet( - parameters_records={ - self.get_str(): self.parameters_record() - for _ in range(num_params_records) - }, - metrics_records={ - self.get_str(): self.metrics_record() - for _ in range(num_metrics_records) - }, - configs_records={ - self.get_str(): self.configs_record() - for _ in range(num_configs_records) - }, - ) + ret = RecordSet() + for _ in range(num_params_records): + ret[self.get_str()] = self.parameters_record() + for _ in range(num_metrics_records): + ret[self.get_str()] = self.metrics_record() + for _ in range(num_configs_records): + ret[self.get_str()] = self.configs_record() + return ret def metadata(self) -> Metadata: """Create a Metadata.""" diff --git a/src/py/flwr/server/superlink/linkstate/linkstate_test.py b/src/py/flwr/server/superlink/linkstate/linkstate_test.py index 371e507e662e..b5cc2447d702 100644 --- a/src/py/flwr/server/superlink/linkstate/linkstate_test.py +++ b/src/py/flwr/server/superlink/linkstate/linkstate_test.py @@ -1166,7 +1166,7 @@ def create_res_message( if error: out_msg = in_msg.create_error_reply(error=error) else: - out_msg = in_msg.create_reply(content=RecordSet(parameters_records={})) + out_msg = in_msg.create_reply(content=RecordSet()) return message_to_proto(out_msg) diff --git a/src/py/flwr/server/workflow/secure_aggregation/secaggplus_workflow.py b/src/py/flwr/server/workflow/secure_aggregation/secaggplus_workflow.py index d84a5496dfe1..998dcd19df65 100644 --- a/src/py/flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +++ b/src/py/flwr/server/workflow/secure_aggregation/secaggplus_workflow.py @@ -367,7 +367,7 @@ def setup_stage( # pylint: disable=R0912, R0914, R0915 # Send setup configuration to clients cfgs_record = ConfigsRecord(sa_params_dict) # type: ignore - content = RecordSet(configs_records={RECORD_KEY_CONFIGS: cfgs_record}) + content = RecordSet({RECORD_KEY_CONFIGS: cfgs_record}) def make(nid: int) -> Message: return driver.create_message( @@ -417,7 +417,7 @@ def make(nid: int) -> Message: {str(nid): state.nid_to_publickeys[nid] for nid in neighbours} ) cfgs_record[Key.STAGE] = Stage.SHARE_KEYS - content = RecordSet(configs_records={RECORD_KEY_CONFIGS: cfgs_record}) + content = RecordSet({RECORD_KEY_CONFIGS: cfgs_record}) return driver.create_message( content=content, message_type=MessageType.TRAIN, @@ -566,7 +566,7 @@ def make(nid: int) -> Message: Key.DEAD_NODE_ID_LIST: list(neighbours & dead_nids), } cfgs_record = ConfigsRecord(cfgs_dict) # type: ignore - content = RecordSet(configs_records={RECORD_KEY_CONFIGS: cfgs_record}) + content = RecordSet({RECORD_KEY_CONFIGS: cfgs_record}) return driver.create_message( content=content, message_type=MessageType.TRAIN,