Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/py/flwr/client/grpc_client/connection_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
ttl=DEFAULT_TTL,
message_type="reconnect",
),
content=RecordSet(configs_records={"config": ConfigsRecord({"reason": 0})}),
content=RecordSet({"config": ConfigsRecord({"reason": 0})}),
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,7 @@ def func(configs: dict[str, ConfigsRecordValues]) -> ConfigsRecord:
ttl=DEFAULT_TTL,
message_type=MessageType.TRAIN,
),
content=RecordSet(
configs_records={RECORD_KEY_CONFIGS: ConfigsRecord(configs)}
),
content=RecordSet({RECORD_KEY_CONFIGS: ConfigsRecord(configs)}),
)
out_msg = app(in_msg, ctxt)
return out_msg.content.configs_records[RECORD_KEY_CONFIGS]
Expand All @@ -78,7 +76,7 @@ def _make_ctxt() -> Context:
run_id=234,
node_id=123,
node_config={},
state=RecordSet(configs_records={RECORD_KEY_STATE: cfg}),
state=RecordSet({RECORD_KEY_STATE: cfg}),
run_config={},
)

Expand Down
16 changes: 4 additions & 12 deletions src/py/flwr/common/record/recordset.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,19 +155,11 @@ class RecordSet(TypedDict[str, RecordType]):
:code:`MetricsRecord` and :code:`ParametersRecord`.
"""

def __init__(
self,
parameters_records: dict[str, ParametersRecord] | None = None,
metrics_records: dict[str, MetricsRecord] | None = None,
configs_records: dict[str, ConfigsRecord] | None = None,
) -> None:
def __init__(self, records: dict[str, RecordType] | None = None) -> None:
super().__init__(_check_key, _check_value)
for key, p_record in (parameters_records or {}).items():
self[key] = p_record
for key, m_record in (metrics_records or {}).items():
self[key] = m_record
for key, c_record in (configs_records or {}).items():
self[key] = c_record
if records is not None:
for key, record in records.items():
self[key] = record

@property
def parameters_records(self) -> TypedDict[str, ParametersRecord]:
Expand Down
8 changes: 5 additions & 3 deletions src/py/flwr/common/record/recordset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,9 +422,11 @@ def test_recordset_repr() -> None:
"""Test the string representation of RecordSet."""
# Prepare
rs = RecordSet(
parameters_records={"params": ParametersRecord()},
metrics_records={"metrics": MetricsRecord({"aa": 123})},
configs_records={"configs": ConfigsRecord({"cc": bytes(5)})},
{
"params": ParametersRecord(),
"metrics": MetricsRecord({"aa": 123}),
"configs": ConfigsRecord({"cc": bytes(5)}),
},
)
expected = """RecordSet(
parameters_records={'params': {}},
Expand Down
20 changes: 8 additions & 12 deletions src/py/flwr/common/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,18 +583,14 @@ def recordset_to_proto(recordset: RecordSet) -> ProtoRecordSet:

def recordset_from_proto(recordset_proto: ProtoRecordSet) -> RecordSet:
"""Deserialize RecordSet from ProtoBuf."""
return RecordSet(
parameters_records={
k: parameters_record_from_proto(v)
for k, v in recordset_proto.parameters.items()
},
metrics_records={
k: metrics_record_from_proto(v) for k, v in recordset_proto.metrics.items()
},
configs_records={
k: configs_record_from_proto(v) for k, v in recordset_proto.configs.items()
},
)
ret = RecordSet()
for k, p_record_proto in recordset_proto.parameters.items():
ret[k] = parameters_record_from_proto(p_record_proto)
for k, m_record_proto in recordset_proto.metrics.items():
ret[k] = metrics_record_from_proto(m_record_proto)
for k, c_record_proto in recordset_proto.configs.items():
ret[k] = configs_record_from_proto(c_record_proto)
return ret


# === Message ===
Expand Down
22 changes: 8 additions & 14 deletions src/py/flwr/common/serde_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,20 +247,14 @@ def recordset(
num_configs_records: int,
) -> RecordSet:
"""Create a RecordSet."""
return RecordSet(
parameters_records={
self.get_str(): self.parameters_record()
for _ in range(num_params_records)
},
metrics_records={
self.get_str(): self.metrics_record()
for _ in range(num_metrics_records)
},
configs_records={
self.get_str(): self.configs_record()
for _ in range(num_configs_records)
},
)
ret = RecordSet()
for _ in range(num_params_records):
ret[self.get_str()] = self.parameters_record()
for _ in range(num_metrics_records):
ret[self.get_str()] = self.metrics_record()
for _ in range(num_configs_records):
ret[self.get_str()] = self.configs_record()
return ret

def metadata(self) -> Metadata:
"""Create a Metadata."""
Expand Down
2 changes: 1 addition & 1 deletion src/py/flwr/server/superlink/linkstate/linkstate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1166,7 +1166,7 @@ def create_res_message(
if error:
out_msg = in_msg.create_error_reply(error=error)
else:
out_msg = in_msg.create_reply(content=RecordSet(parameters_records={}))
out_msg = in_msg.create_reply(content=RecordSet())

return message_to_proto(out_msg)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ def setup_stage( # pylint: disable=R0912, R0914, R0915

# Send setup configuration to clients
cfgs_record = ConfigsRecord(sa_params_dict) # type: ignore
content = RecordSet(configs_records={RECORD_KEY_CONFIGS: cfgs_record})
content = RecordSet({RECORD_KEY_CONFIGS: cfgs_record})

def make(nid: int) -> Message:
return driver.create_message(
Expand Down Expand Up @@ -417,7 +417,7 @@ def make(nid: int) -> Message:
{str(nid): state.nid_to_publickeys[nid] for nid in neighbours}
)
cfgs_record[Key.STAGE] = Stage.SHARE_KEYS
content = RecordSet(configs_records={RECORD_KEY_CONFIGS: cfgs_record})
content = RecordSet({RECORD_KEY_CONFIGS: cfgs_record})
return driver.create_message(
content=content,
message_type=MessageType.TRAIN,
Expand Down Expand Up @@ -566,7 +566,7 @@ def make(nid: int) -> Message:
Key.DEAD_NODE_ID_LIST: list(neighbours & dead_nids),
}
cfgs_record = ConfigsRecord(cfgs_dict) # type: ignore
content = RecordSet(configs_records={RECORD_KEY_CONFIGS: cfgs_record})
content = RecordSet({RECORD_KEY_CONFIGS: cfgs_record})
return driver.create_message(
content=content,
message_type=MessageType.TRAIN,
Expand Down