From 0a11bffc4b96da62188d4e71f154366876a6cb66 Mon Sep 17 00:00:00 2001 From: Chong Shen Ng Date: Tue, 11 Mar 2025 16:14:56 +0000 Subject: [PATCH 01/15] Init --- src/py/flwr/server/driver/driver.py | 17 +++++++++++------ src/py/flwr/server/driver/grpc_driver.py | 15 +++++++++++---- 2 files changed, 22 insertions(+), 10 deletions(-) diff --git a/src/py/flwr/server/driver/driver.py b/src/py/flwr/server/driver/driver.py index 4c015751d3da..cc3776af94fa 100644 --- a/src/py/flwr/server/driver/driver.py +++ b/src/py/flwr/server/driver/driver.py @@ -89,7 +89,7 @@ def get_node_ids(self) -> Iterable[int]: """Get node IDs.""" @abstractmethod - def push_messages(self, messages: Iterable[Message]) -> Iterable[str]: + def push_messages(self, messages: Iterable[Message]) -> None: """Push messages to specified node IDs. This method takes an iterable of messages and sends each message @@ -108,16 +108,21 @@ def push_messages(self, messages: Iterable[Message]) -> Iterable[str]: """ @abstractmethod - def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]: - """Pull messages based on message IDs. + def pull_messages( + self, message_ids: Optional[Iterable[str]] = None + ) -> Iterable[Message]: + """Pull messages from the SuperLink. - This method is used to collect messages from the SuperLink - that correspond to a set of given message IDs. + This method is used to collect all available messages from the SuperLink. + If provided, it will only pull messages that correspond to a set of given + message IDs. Parameters ---------- - message_ids : Iterable[str] + message_ids : Optional[Iterable[str]] An iterable of message IDs for which reply messages are to be retrieved. + If specified, the method will only pull messages that correspond to these + IDs. If `None`, all messages will be retrieved. Returns ------- diff --git a/src/py/flwr/server/driver/grpc_driver.py b/src/py/flwr/server/driver/grpc_driver.py index 9a5157691775..26c184252009 100644 --- a/src/py/flwr/server/driver/grpc_driver.py +++ b/src/py/flwr/server/driver/grpc_driver.py @@ -81,6 +81,7 @@ def __init__( # pylint: disable=too-many-arguments self._channel: Optional[grpc.Channel] = None self.node = Node(node_id=SUPERLINK_NODE_ID) self._retry_invoker = _make_simple_grpc_retry_invoker() + self._pending_messages: set[str] = set() @property def _is_connected(self) -> bool: @@ -191,7 +192,7 @@ def get_node_ids(self) -> Iterable[int]: ) return [node.node_id for node in res.nodes] - def push_messages(self, messages: Iterable[Message]) -> Iterable[str]: + def push_messages(self, messages: Iterable[Message]) -> None: """Push messages to specified node IDs. This method takes an iterable of messages and sends each message @@ -221,18 +222,24 @@ def push_messages(self, messages: Iterable[Message]) -> Iterable[str]: "list has `None` for those messages (the order is preserved as passed " "to `push_messages`). This could be due to a malformed message.", ) - return list(res.message_ids) + self._pending_messages.update(res.message_ids) - def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]: + def pull_messages( + self, message_ids: Optional[Iterable[str]] = None + ) -> Iterable[Message]: """Pull messages based on message IDs. This method is used to collect messages from the SuperLink that correspond to a set of given message IDs. """ + # Allow an override but default to the stored pending IDs. + message_ids_to_pull = ( + message_ids if message_ids is not None else self._pending_messages + ) # Pull Messages res: PullResMessagesResponse = self._stub.PullMessages( PullResMessagesRequest( - message_ids=message_ids, + message_ids=message_ids_to_pull, run_id=cast(Run, self._run).run_id, ) ) From 2e9a5d127b8f23075c6a14a03347a96c543f88f0 Mon Sep 17 00:00:00 2001 From: Chong Shen Ng Date: Wed, 12 Mar 2025 13:12:33 +0000 Subject: [PATCH 02/15] Add ReadOnlyList type, update pull_messages --- src/py/flwr/common/typing.py | 33 +++++++++++++- src/py/flwr/server/driver/driver.py | 17 ++++--- src/py/flwr/server/driver/grpc_driver.py | 58 ++++++++++++++++-------- 3 files changed, 81 insertions(+), 27 deletions(-) diff --git a/src/py/flwr/common/typing.py b/src/py/flwr/common/typing.py index 5ec4c1086e1c..a006fae511ff 100644 --- a/src/py/flwr/common/typing.py +++ b/src/py/flwr/common/typing.py @@ -15,9 +15,10 @@ """Flower type definitions.""" +from collections.abc import Iterator, Sequence from dataclasses import dataclass from enum import Enum -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Optional, Union, cast, overload import numpy as np import numpy.typing as npt @@ -322,3 +323,33 @@ class LogEntry: actor: Actor event: Event status: str + + +class ReadOnlyList(Sequence[str]): + """A thin, generic read-only wrapper for a list of strings.""" + + def __init__(self, data: list[str]) -> None: + # Store a reference to the original mutable list + self._data = data + + @overload + def __getitem__(self, index: int) -> str: ... + + @overload + def __getitem__(self, index: slice) -> Sequence[str]: ... + + def __getitem__(self, index: Union[int, slice]) -> Union[str, "ReadOnlyList"]: + result = self._data[index] + # If the result is a slice, wrap it in a ReadOnlyList. + if isinstance(index, slice): + return ReadOnlyList(cast(list[str], result)) + return cast(str, result) + + def __len__(self) -> int: + return len(self._data) + + def __iter__(self) -> Iterator[str]: + return iter(self._data) + + def __repr__(self) -> str: + return f"{self.__class__.__qualname__}({self._data!r})" diff --git a/src/py/flwr/server/driver/driver.py b/src/py/flwr/server/driver/driver.py index cc3776af94fa..a987aaac7fa6 100644 --- a/src/py/flwr/server/driver/driver.py +++ b/src/py/flwr/server/driver/driver.py @@ -20,7 +20,7 @@ from typing import Optional from flwr.common import Message, RecordSet -from flwr.common.typing import Run +from flwr.common.typing import ReadOnlyList, Run class Driver(ABC): @@ -45,6 +45,11 @@ def set_run(self, run_id: int) -> None: def run(self) -> Run: """Run information.""" + @property + @abstractmethod + def message_ids(self) -> ReadOnlyList: + """Message IDs of pushed messages.""" + @abstractmethod def create_message( # pylint: disable=too-many-arguments,R0917 self, @@ -89,7 +94,7 @@ def get_node_ids(self) -> Iterable[int]: """Get node IDs.""" @abstractmethod - def push_messages(self, messages: Iterable[Message]) -> None: + def push_messages(self, messages: Iterable[Message]) -> Iterable[str]: """Push messages to specified node IDs. This method takes an iterable of messages and sends each message @@ -111,11 +116,11 @@ def push_messages(self, messages: Iterable[Message]) -> None: def pull_messages( self, message_ids: Optional[Iterable[str]] = None ) -> Iterable[Message]: - """Pull messages from the SuperLink. + """Pull messages from the SuperLink based on message IDs. - This method is used to collect all available messages from the SuperLink. - If provided, it will only pull messages that correspond to a set of given - message IDs. + This method is used to collect messages from the SuperLink that correspond to a + set of given message IDs. If no message IDs are provided, it defaults to the + stored message IDs. Parameters ---------- diff --git a/src/py/flwr/server/driver/grpc_driver.py b/src/py/flwr/server/driver/grpc_driver.py index 26c184252009..66f5f444a13d 100644 --- a/src/py/flwr/server/driver/grpc_driver.py +++ b/src/py/flwr/server/driver/grpc_driver.py @@ -19,7 +19,7 @@ import warnings from collections.abc import Iterable from logging import DEBUG, WARNING -from typing import Optional, cast +from typing import Generator, Optional, cast import grpc @@ -32,7 +32,7 @@ from flwr.common.logger import log from flwr.common.retry_invoker import _make_simple_grpc_retry_invoker, _wrap_stub from flwr.common.serde import message_from_proto, message_to_proto, run_from_proto -from flwr.common.typing import Run +from flwr.common.typing import ReadOnlyList, Run from flwr.proto.message_pb2 import Message as ProtoMessage # pylint: disable=E0611 from flwr.proto.node_pb2 import Node # pylint: disable=E0611 from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611 @@ -56,7 +56,7 @@ """ -class GrpcDriver(Driver): +class GrpcDriver(Driver): # pylint: disable=too-many-instance-attributes """`GrpcDriver` provides an interface to the ServerAppIo API. Parameters @@ -81,7 +81,7 @@ def __init__( # pylint: disable=too-many-arguments self._channel: Optional[grpc.Channel] = None self.node = Node(node_id=SUPERLINK_NODE_ID) self._retry_invoker = _make_simple_grpc_retry_invoker() - self._pending_messages: set[str] = set() + self._message_ids: list[str] = [] @property def _is_connected(self) -> bool: @@ -138,6 +138,11 @@ def _stub(self) -> ServerAppIoStub: self._connect() return cast(ServerAppIoStub, self._grpc_stub) + @property + def message_ids(self) -> ReadOnlyList: + """Message IDs of pushed messages.""" + return ReadOnlyList(self._message_ids) + def _check_message(self, message: Message) -> None: # Check if the message is valid if not ( @@ -192,7 +197,7 @@ def get_node_ids(self) -> Iterable[int]: ) return [node.node_id for node in res.nodes] - def push_messages(self, messages: Iterable[Message]) -> None: + def push_messages(self, messages: Iterable[Message]) -> Iterable[str]: """Push messages to specified node IDs. This method takes an iterable of messages and sends each message @@ -222,30 +227,43 @@ def push_messages(self, messages: Iterable[Message]) -> None: "list has `None` for those messages (the order is preserved as passed " "to `push_messages`). This could be due to a malformed message.", ) - self._pending_messages.update(res.message_ids) + # Store message IDs + msg_ids = list(res.message_ids) + self._message_ids = msg_ids + return msg_ids def pull_messages( self, message_ids: Optional[Iterable[str]] = None ) -> Iterable[Message]: - """Pull messages based on message IDs. + """Pull messages from the SuperLink based on message IDs. This method is used to collect messages from the SuperLink that correspond to a - set of given message IDs. + set of given message IDs. If no message IDs are provided, it defaults to the + stored message IDs. """ - # Allow an override but default to the stored pending IDs. + # Raise an error if no message IDs are provided and none are stored + if not self._message_ids: + raise ValueError("No message IDs to pull. Call `push_messages` first.") + + # Allow an override but default to the stored pending IDs message_ids_to_pull = ( - message_ids if message_ids is not None else self._pending_messages + message_ids if message_ids is not None else self._message_ids ) - # Pull Messages - res: PullResMessagesResponse = self._stub.PullMessages( - PullResMessagesRequest( - message_ids=message_ids_to_pull, - run_id=cast(Run, self._run).run_id, - ) - ) - # Convert Message from Protobuf representation - msgs = [message_from_proto(msg_proto) for msg_proto in res.messages_list] - return msgs + + def iter_msg() -> Generator[Message, None, None]: + for msg_id in message_ids_to_pull: + # Pull Messages for each message ID + res: PullResMessagesResponse = self._stub.PullMessages( + PullResMessagesRequest( + message_ids=[msg_id], + run_id=cast(Run, self._run).run_id, + ) + ) + # Convert Message from Protobuf representation + msg = message_from_proto(res.messages_list[0]) + yield msg + + return iter_msg() def send_and_receive( self, From 0df1f43de195eebbf70a30c4ff9ac145ada3b92b Mon Sep 17 00:00:00 2001 From: Chong Shen Ng Date: Wed, 12 Mar 2025 14:17:09 +0000 Subject: [PATCH 03/15] Simplify --- src/py/flwr/common/typing.py | 33 +----------------------- src/py/flwr/server/driver/driver.py | 4 +-- src/py/flwr/server/driver/grpc_driver.py | 11 ++++---- 3 files changed, 8 insertions(+), 40 deletions(-) diff --git a/src/py/flwr/common/typing.py b/src/py/flwr/common/typing.py index a006fae511ff..5ec4c1086e1c 100644 --- a/src/py/flwr/common/typing.py +++ b/src/py/flwr/common/typing.py @@ -15,10 +15,9 @@ """Flower type definitions.""" -from collections.abc import Iterator, Sequence from dataclasses import dataclass from enum import Enum -from typing import Any, Callable, Optional, Union, cast, overload +from typing import Any, Callable, Optional, Union import numpy as np import numpy.typing as npt @@ -323,33 +322,3 @@ class LogEntry: actor: Actor event: Event status: str - - -class ReadOnlyList(Sequence[str]): - """A thin, generic read-only wrapper for a list of strings.""" - - def __init__(self, data: list[str]) -> None: - # Store a reference to the original mutable list - self._data = data - - @overload - def __getitem__(self, index: int) -> str: ... - - @overload - def __getitem__(self, index: slice) -> Sequence[str]: ... - - def __getitem__(self, index: Union[int, slice]) -> Union[str, "ReadOnlyList"]: - result = self._data[index] - # If the result is a slice, wrap it in a ReadOnlyList. - if isinstance(index, slice): - return ReadOnlyList(cast(list[str], result)) - return cast(str, result) - - def __len__(self) -> int: - return len(self._data) - - def __iter__(self) -> Iterator[str]: - return iter(self._data) - - def __repr__(self) -> str: - return f"{self.__class__.__qualname__}({self._data!r})" diff --git a/src/py/flwr/server/driver/driver.py b/src/py/flwr/server/driver/driver.py index a987aaac7fa6..e737e4b06ba7 100644 --- a/src/py/flwr/server/driver/driver.py +++ b/src/py/flwr/server/driver/driver.py @@ -20,7 +20,7 @@ from typing import Optional from flwr.common import Message, RecordSet -from flwr.common.typing import ReadOnlyList, Run +from flwr.common.typing import Run class Driver(ABC): @@ -47,7 +47,7 @@ def run(self) -> Run: @property @abstractmethod - def message_ids(self) -> ReadOnlyList: + def message_ids(self) -> list[str]: """Message IDs of pushed messages.""" @abstractmethod diff --git a/src/py/flwr/server/driver/grpc_driver.py b/src/py/flwr/server/driver/grpc_driver.py index 66f5f444a13d..293a975508ed 100644 --- a/src/py/flwr/server/driver/grpc_driver.py +++ b/src/py/flwr/server/driver/grpc_driver.py @@ -32,7 +32,7 @@ from flwr.common.logger import log from flwr.common.retry_invoker import _make_simple_grpc_retry_invoker, _wrap_stub from flwr.common.serde import message_from_proto, message_to_proto, run_from_proto -from flwr.common.typing import ReadOnlyList, Run +from flwr.common.typing import Run from flwr.proto.message_pb2 import Message as ProtoMessage # pylint: disable=E0611 from flwr.proto.node_pb2 import Node # pylint: disable=E0611 from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611 @@ -139,9 +139,9 @@ def _stub(self) -> ServerAppIoStub: return cast(ServerAppIoStub, self._grpc_stub) @property - def message_ids(self) -> ReadOnlyList: + def message_ids(self) -> list[str]: """Message IDs of pushed messages.""" - return ReadOnlyList(self._message_ids) + return self._message_ids.copy() def _check_message(self, message: Message) -> None: # Check if the message is valid @@ -228,9 +228,8 @@ def push_messages(self, messages: Iterable[Message]) -> Iterable[str]: "to `push_messages`). This could be due to a malformed message.", ) # Store message IDs - msg_ids = list(res.message_ids) - self._message_ids = msg_ids - return msg_ids + self._message_ids.extend(res.message_ids) + return list(res.message_ids) def pull_messages( self, message_ids: Optional[Iterable[str]] = None From 0be12982a47e6db72005ab554e341a12b79f74c0 Mon Sep 17 00:00:00 2001 From: Chong Shen Ng Date: Wed, 12 Mar 2025 14:22:15 +0000 Subject: [PATCH 04/15] Use Iterator instead of Generator --- src/py/flwr/server/driver/grpc_driver.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/py/flwr/server/driver/grpc_driver.py b/src/py/flwr/server/driver/grpc_driver.py index 293a975508ed..0e9912b52540 100644 --- a/src/py/flwr/server/driver/grpc_driver.py +++ b/src/py/flwr/server/driver/grpc_driver.py @@ -19,7 +19,7 @@ import warnings from collections.abc import Iterable from logging import DEBUG, WARNING -from typing import Generator, Optional, cast +from typing import Iterator, Optional, cast import grpc @@ -249,7 +249,7 @@ def pull_messages( message_ids if message_ids is not None else self._message_ids ) - def iter_msg() -> Generator[Message, None, None]: + def iter_msg() -> Iterator[Message]: for msg_id in message_ids_to_pull: # Pull Messages for each message ID res: PullResMessagesResponse = self._stub.PullMessages( From 3949c7c97845e8eb473fd29167115b5358679670 Mon Sep 17 00:00:00 2001 From: Chong Shen Ng Date: Wed, 12 Mar 2025 15:13:46 +0000 Subject: [PATCH 05/15] Update test --- src/py/flwr/server/driver/grpc_driver.py | 2 +- src/py/flwr/server/driver/grpc_driver_test.py | 38 +++++++++++++------ src/py/flwr/server/driver/inmemory_driver.py | 27 ++++++++++--- 3 files changed, 49 insertions(+), 18 deletions(-) diff --git a/src/py/flwr/server/driver/grpc_driver.py b/src/py/flwr/server/driver/grpc_driver.py index 0e9912b52540..18c52aa8d0b4 100644 --- a/src/py/flwr/server/driver/grpc_driver.py +++ b/src/py/flwr/server/driver/grpc_driver.py @@ -283,7 +283,7 @@ def send_and_receive( end_time = time.time() + (timeout if timeout is not None else 0.0) ret: list[Message] = [] while timeout is None or time.time() < end_time: - res_msgs = self.pull_messages(msg_ids) + res_msgs = list(self.pull_messages(msg_ids)) ret.extend(res_msgs) msg_ids.difference_update( {msg.metadata.reply_to_message for msg in res_msgs} diff --git a/src/py/flwr/server/driver/grpc_driver_test.py b/src/py/flwr/server/driver/grpc_driver_test.py index d32c6a0e7c58..82e1738d6daa 100644 --- a/src/py/flwr/server/driver/grpc_driver_test.py +++ b/src/py/flwr/server/driver/grpc_driver_test.py @@ -132,7 +132,6 @@ def test_push_messages_invalid(self) -> None: def test_pull_messages_with_given_message_ids(self) -> None: """Test pulling messages with specific message IDs.""" # Prepare - mock_response = Mock() # A Message must have either content or error set so we prepare run_id = 12345 ok_message = create_res_message(src_node_id=123, dst_node_id=456, run_id=run_id) @@ -142,22 +141,37 @@ def test_pull_messages_with_given_message_ids(self) -> None: src_node_id=123, dst_node_id=789, run_id=run_id, error=Error(code=0) ) error_message.metadata.reply_to_message = "id3" - # The the response from the DriverServicer is in the form of Protbuf Messages - mock_response.messages_list = [ok_message, error_message] - self.mock_stub.PullMessages.return_value = mock_response - msg_ids = ["id1", "id2", "id3"] + + # Create separate mock responses for each call + mock_response1 = Mock() + mock_response1.messages_list = [ok_message] + mock_response2 = Mock() + mock_response2.messages_list = [error_message] + + # Configure PullMessages to return a different response per call + self.mock_stub.PullMessages.side_effect = [mock_response1, mock_response2] + + msg_ids = ["id2", "id3"] + # Set driver message_ids for the test + self.driver._message_ids = msg_ids # pylint: disable=protected-access # Execute - msgs = self.driver.pull_messages(msg_ids) + msgs = list(self.driver.pull_messages(msg_ids)) reply_tos = {msg.metadata.reply_to_message for msg in msgs} - args, kwargs = self.mock_stub.PullMessages.call_args # Assert - self.mock_stub.GetRun.assert_called_once() - self.assertEqual(len(args), 1) - self.assertEqual(len(kwargs), 0) - self.assertIsInstance(args[0], PullResMessagesRequest) - self.assertEqual(args[0].message_ids, msg_ids) + # PullMessages was called twice (once per message id) + calls = self.mock_stub.PullMessages.call_args_list + self.assertEqual(len(calls), len(msg_ids)) + for call, expected_msg_id in zip(calls, msg_ids): + args, kwargs = call + self.assertEqual(len(args), 1) + self.assertEqual(len(kwargs), 0) + self.assertIsInstance(args[0], PullResMessagesRequest) + # Each call should be made with a single-element list containing + # the current message id + self.assertEqual(args[0].message_ids, [expected_msg_id]) + self.assertEqual(reply_tos, {"id2", "id3"}) def test_send_and_receive_messages_complete(self) -> None: diff --git a/src/py/flwr/server/driver/inmemory_driver.py b/src/py/flwr/server/driver/inmemory_driver.py index 4b50badb4c11..08019613523b 100644 --- a/src/py/flwr/server/driver/inmemory_driver.py +++ b/src/py/flwr/server/driver/inmemory_driver.py @@ -50,6 +50,7 @@ def __init__( self.state = state_factory.state() self.pull_interval = pull_interval self.node = Node(node_id=SUPERLINK_NODE_ID) + self._message_ids: list[str] = [] def _check_message(self, message: Message) -> None: # Check if the message is valid @@ -75,6 +76,11 @@ def run(self) -> Run: """Run ID.""" return Run(**vars(cast(Run, self._run))) + @property + def message_ids(self) -> list[str]: + """Message IDs of pushed messages.""" + return self._message_ids.copy() + def create_message( # pylint: disable=too-many-arguments,R0917 self, content: RecordSet, @@ -127,16 +133,27 @@ def push_messages(self, messages: Iterable[Message]) -> Iterable[str]: msg_id = self.state.store_message_ins(msg) if msg_id: msg_ids.append(str(msg_id)) - + # Store message IDs + self._message_ids.extend(msg_ids) return msg_ids - def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]: - """Pull messages based on message IDs. + def pull_messages( + self, message_ids: Optional[Iterable[str]] = None + ) -> Iterable[Message]: + """Pull messages from the SuperLink based on message IDs. This method is used to collect messages from the SuperLink that correspond to a - set of given message IDs. + set of given message IDs. If no message IDs are provided, it defaults to the + stored message IDs. """ - msg_ids = {UUID(msg_id) for msg_id in message_ids} + # Raise an error if no message IDs are provided and none are stored + if not self._message_ids: + raise ValueError("No message IDs to pull. Call `push_messages` first.") + msg_ids = ( + {UUID(msg_id) for msg_id in message_ids} + if message_ids is not None + else {UUID(msg_id) for msg_id in self._message_ids} + ) # Pull Messages message_res_list = self.state.get_message_res(message_ids=msg_ids) # Get IDs of Messages these replies are for From ae021dc5c4cd01f479c78a109b033bb6cfefa739 Mon Sep 17 00:00:00 2001 From: Chong Shen Ng Date: Wed, 12 Mar 2025 16:33:29 +0000 Subject: [PATCH 06/15] Update tests --- src/py/flwr/server/driver/grpc_driver.py | 22 +++- src/py/flwr/server/driver/grpc_driver_test.py | 121 ++++++++++++++---- 2 files changed, 114 insertions(+), 29 deletions(-) diff --git a/src/py/flwr/server/driver/grpc_driver.py b/src/py/flwr/server/driver/grpc_driver.py index 18c52aa8d0b4..e3c3eacc4515 100644 --- a/src/py/flwr/server/driver/grpc_driver.py +++ b/src/py/flwr/server/driver/grpc_driver.py @@ -245,9 +245,25 @@ def pull_messages( raise ValueError("No message IDs to pull. Call `push_messages` first.") # Allow an override but default to the stored pending IDs - message_ids_to_pull = ( - message_ids if message_ids is not None else self._message_ids - ) + # Prepare the set of message IDs to pull: + # - If no message_ids are provided, use the stored ones. + # - Otherwise, filter the provided list to only those in self._message_ids. + if message_ids is None: + message_ids_to_pull = self._message_ids + else: + provided_ids = list(message_ids) # Preserve order + message_ids_to_pull = [ + msg_id for msg_id in provided_ids if msg_id in self._message_ids + ] + missing_ids = [ + msg_id for msg_id in provided_ids if msg_id not in self._message_ids + ] + if missing_ids: + log( + WARNING, + "Cannot pull messages for the following missing message IDs: %s", + missing_ids, + ) def iter_msg() -> Iterator[Message]: for msg_id in message_ids_to_pull: diff --git a/src/py/flwr/server/driver/grpc_driver_test.py b/src/py/flwr/server/driver/grpc_driver_test.py index 82e1738d6daa..e156d85f76f3 100644 --- a/src/py/flwr/server/driver/grpc_driver_test.py +++ b/src/py/flwr/server/driver/grpc_driver_test.py @@ -17,11 +17,12 @@ import time import unittest +from logging import WARNING from unittest.mock import Mock, patch import grpc -from flwr.common import DEFAULT_TTL, RecordSet +from flwr.common import DEFAULT_TTL, Message, RecordSet from flwr.common.message import Error from flwr.proto.run_pb2 import ( # pylint: disable=E0611 GetRunRequest, @@ -129,14 +130,13 @@ def test_push_messages_invalid(self) -> None: with self.assertRaises(ValueError): self.driver.push_messages(msgs) - def test_pull_messages_with_given_message_ids(self) -> None: - """Test pulling messages with specific message IDs.""" - # Prepare - # A Message must have either content or error set so we prepare - run_id = 12345 + def _setup_pull_messages_mocks(self, run_id: int) -> list[str]: + """Set up common mocks for pull_messages. This creates two mock responses, + one for each message ID, and configures the stub to return them sequentially. + """ + # Create messages with distinct reply_to values ok_message = create_res_message(src_node_id=123, dst_node_id=456, run_id=run_id) ok_message.metadata.reply_to_message = "id2" - error_message = create_res_message( src_node_id=123, dst_node_id=789, run_id=run_id, error=Error(code=0) ) @@ -151,29 +151,98 @@ def test_pull_messages_with_given_message_ids(self) -> None: # Configure PullMessages to return a different response per call self.mock_stub.PullMessages.side_effect = [mock_response1, mock_response2] - msg_ids = ["id2", "id3"] - # Set driver message_ids for the test - self.driver._message_ids = msg_ids # pylint: disable=protected-access + # Return the message IDs to be used in the tests (the valid ones) + return ["id2", "id3"] - # Execute - msgs = list(self.driver.pull_messages(msg_ids)) - reply_tos = {msg.metadata.reply_to_message for msg in msgs} - - # Assert - # PullMessages was called twice (once per message id) + def _assert_pull_messages( + self, expected_ids: list[str], messages: list[Message] + ) -> None: + """Check that PullMessages was called once per expected message ID and + that the returned messages have the correct reply_to values. + """ + reply_tos = {msg.metadata.reply_to_message for msg in messages} calls = self.mock_stub.PullMessages.call_args_list - self.assertEqual(len(calls), len(msg_ids)) - for call, expected_msg_id in zip(calls, msg_ids): + self.assertEqual(len(calls), len(expected_ids)) + for call, expected_msg_id in zip(calls, expected_ids): args, kwargs = call self.assertEqual(len(args), 1) self.assertEqual(len(kwargs), 0) self.assertIsInstance(args[0], PullResMessagesRequest) - # Each call should be made with a single-element list containing - # the current message id + # Each call should be made with a single-element list + # containing the expected message id self.assertEqual(args[0].message_ids, [expected_msg_id]) - self.assertEqual(reply_tos, {"id2", "id3"}) + def test_pull_messages_with_given_message_ids(self) -> None: + """Test pulling messages with specific message IDs.""" + # Prepare + run_id = 12345 + msg_ids = self._setup_pull_messages_mocks(run_id) + + # Store message IDs in the driver's internal state for testing + self.driver._message_ids.extend(msg_ids) # pylint: disable=protected-access + + # Execute + messages = list(self.driver.pull_messages(msg_ids)) + + # Assert + self._assert_pull_messages(msg_ids, messages) + + def test_pull_messages_without_given_message_ids(self) -> None: + """Test pulling messages successful when no message_ids are provided.""" + # Prepare + run_id = 12345 + msg_ids = self._setup_pull_messages_mocks(run_id) + + # Store message IDs in the driver's internal state for testing + self.driver._message_ids.extend(msg_ids) # pylint: disable=protected-access + + # Execute + messages = list(self.driver.pull_messages()) + + # Assert + self._assert_pull_messages(msg_ids, messages) + + def test_pull_messages_with_invalid_message_ids(self) -> None: + """Test pulling messages when provided message_ids include values not + stored in self._message_ids.""" + # Prepare + run_id = 12345 + valid_msg_ids = self._setup_pull_messages_mocks( + run_id + ) # returns ["id2", "id3"] + # Store message IDs in the driver's internal state for testing + self.driver._message_ids.extend( # pylint: disable=protected-access + valid_msg_ids + ) + provided_msg_ids = [ + "id2", + "id3", + "id4", + "id5", + ] # "id4" and "id5" are not stored. + expected_missing = [ + msg_id for msg_id in provided_msg_ids if msg_id not in valid_msg_ids + ] + + # Patch the log function to capture the warning. + with patch("flwr.server.driver.grpc_driver.log") as mock_log: + # Execute + messages = list(self.driver.pull_messages(provided_msg_ids)) + + # Assert + # Only valid IDs are pulled + self._assert_pull_messages(valid_msg_ids, messages) + # Warning was logged with the missing IDs + mock_log.assert_called_once() + args, _ = mock_log.call_args + log_level = args[0] + logged_missing_ids = args[2] + # Verify that the log level is WARNING and the missing IDs appear in the + # log message + self.assertEqual(log_level, WARNING) + self.assertEqual(logged_missing_ids, expected_missing) + def test_send_and_receive_messages_complete(self) -> None: """Test send and receive all messages successfully.""" # Prepare @@ -205,14 +274,14 @@ def test_send_and_receive_messages_timeout(self) -> None: sleep_fn = time.sleep mock_response = Mock(message_ids=["id1"]) self.mock_stub.PushMessages.return_value = mock_response - mock_response = Mock(messages_list=[]) - self.mock_stub.PullMessages.return_value = mock_response msgs = [self.driver.create_message(RecordSet(), "", 0, "", DEFAULT_TTL)] # Execute - with patch("time.sleep", side_effect=lambda t: sleep_fn(t * 0.01)): - start_time = time.time() - ret_msgs = list(self.driver.send_and_receive(msgs, timeout=0.15)) + # Patch pull_messages to always return an empty iterator. + with patch.object(self.driver, "pull_messages", return_value=iter([])): + with patch("time.sleep", side_effect=lambda t: sleep_fn(t * 0.01)): + start_time = time.time() + ret_msgs = list(self.driver.send_and_receive(msgs, timeout=0.15)) # Assert self.assertLess(time.time() - start_time, 0.2) From b5d90859397e32a180437ce439dd16a8d47e81ea Mon Sep 17 00:00:00 2001 From: Chong Shen Ng Date: Wed, 12 Mar 2025 20:43:37 +0000 Subject: [PATCH 07/15] Fix docformatter --- src/py/flwr/server/driver/grpc_driver_test.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/py/flwr/server/driver/grpc_driver_test.py b/src/py/flwr/server/driver/grpc_driver_test.py index e156d85f76f3..63e4691fafd4 100644 --- a/src/py/flwr/server/driver/grpc_driver_test.py +++ b/src/py/flwr/server/driver/grpc_driver_test.py @@ -131,8 +131,10 @@ def test_push_messages_invalid(self) -> None: self.driver.push_messages(msgs) def _setup_pull_messages_mocks(self, run_id: int) -> list[str]: - """Set up common mocks for pull_messages. This creates two mock responses, - one for each message ID, and configures the stub to return them sequentially. + """Set up common mocks for pull_messages. + + This creates two mock responses, one for each message ID, and configures the + stub to return them sequentially. """ # Create messages with distinct reply_to values ok_message = create_res_message(src_node_id=123, dst_node_id=456, run_id=run_id) @@ -157,9 +159,8 @@ def _setup_pull_messages_mocks(self, run_id: int) -> list[str]: def _assert_pull_messages( self, expected_ids: list[str], messages: list[Message] ) -> None: - """Check that PullMessages was called once per expected message ID and - that the returned messages have the correct reply_to values. - """ + """Check that PullMessages was called once per expected message ID and that the + returned messages have the correct reply_to values.""" reply_tos = {msg.metadata.reply_to_message for msg in messages} calls = self.mock_stub.PullMessages.call_args_list self.assertEqual(len(calls), len(expected_ids)) @@ -204,8 +205,8 @@ def test_pull_messages_without_given_message_ids(self) -> None: self._assert_pull_messages(msg_ids, messages) def test_pull_messages_with_invalid_message_ids(self) -> None: - """Test pulling messages when provided message_ids include values not - stored in self._message_ids.""" + """Test pulling messages when provided message_ids include values not stored in + self._message_ids.""" # Prepare run_id = 12345 valid_msg_ids = self._setup_pull_messages_mocks( From cd75e47c0d945c53723edc24590548fbe42ad8bd Mon Sep 17 00:00:00 2001 From: Chong Shen Ng Date: Wed, 12 Mar 2025 21:07:06 +0000 Subject: [PATCH 08/15] Switch to set, update test --- src/py/flwr/server/driver/grpc_driver.py | 15 ++++------ src/py/flwr/server/driver/grpc_driver_test.py | 29 ++++++------------- 2 files changed, 14 insertions(+), 30 deletions(-) diff --git a/src/py/flwr/server/driver/grpc_driver.py b/src/py/flwr/server/driver/grpc_driver.py index e3c3eacc4515..43890ec4df2c 100644 --- a/src/py/flwr/server/driver/grpc_driver.py +++ b/src/py/flwr/server/driver/grpc_driver.py @@ -245,19 +245,14 @@ def pull_messages( raise ValueError("No message IDs to pull. Call `push_messages` first.") # Allow an override but default to the stored pending IDs - # Prepare the set of message IDs to pull: - # - If no message_ids are provided, use the stored ones. - # - Otherwise, filter the provided list to only those in self._message_ids. if message_ids is None: + # If no message_ids are provided, use the stored ones message_ids_to_pull = self._message_ids else: - provided_ids = list(message_ids) # Preserve order - message_ids_to_pull = [ - msg_id for msg_id in provided_ids if msg_id in self._message_ids - ] - missing_ids = [ - msg_id for msg_id in provided_ids if msg_id not in self._message_ids - ] + # Otherwise, filter the provided list to only those in self._message_ids + provided_ids = set(message_ids) + message_ids_to_pull = sorted(provided_ids & set(self._message_ids)) + missing_ids = sorted(provided_ids - set(self._message_ids)) if missing_ids: log( WARNING, diff --git a/src/py/flwr/server/driver/grpc_driver_test.py b/src/py/flwr/server/driver/grpc_driver_test.py index 63e4691fafd4..7df58af9c528 100644 --- a/src/py/flwr/server/driver/grpc_driver_test.py +++ b/src/py/flwr/server/driver/grpc_driver_test.py @@ -209,38 +209,27 @@ def test_pull_messages_with_invalid_message_ids(self) -> None: self._message_ids.""" # Prepare run_id = 12345 - valid_msg_ids = self._setup_pull_messages_mocks( - run_id - ) # returns ["id2", "id3"] + msg_ids = self._setup_pull_messages_mocks(run_id) + # Store message IDs in the driver's internal state for testing - self.driver._message_ids.extend( # pylint: disable=protected-access - valid_msg_ids - ) - provided_msg_ids = [ - "id2", - "id3", - "id4", - "id5", - ] # "id4" and "id5" are not stored. - expected_missing = [ - msg_id for msg_id in provided_msg_ids if msg_id not in valid_msg_ids - ] + self.driver._message_ids.extend(msg_ids) # pylint: disable=protected-access + + provided_msg_ids = {"id2", "id3", "id4", "id5"} + expected_missing = sorted(provided_msg_ids - set(msg_ids)) - # Patch the log function to capture the warning. + # Execute with patch("flwr.server.driver.grpc_driver.log") as mock_log: - # Execute messages = list(self.driver.pull_messages(provided_msg_ids)) # Assert # Only valid IDs are pulled - self._assert_pull_messages(valid_msg_ids, messages) + self._assert_pull_messages(msg_ids, messages) # Warning was logged with the missing IDs mock_log.assert_called_once() args, _ = mock_log.call_args log_level = args[0] logged_missing_ids = args[2] - # Verify that the log level is WARNING and the missing IDs appear in the - # log message + # Log level is WARNING and the missing IDs appear in the log message self.assertEqual(log_level, WARNING) self.assertEqual(logged_missing_ids, expected_missing) From 49b419f3a83c505d4dd5ba7c41ff63c4fb3675ff Mon Sep 17 00:00:00 2001 From: Chong Shen Ng Date: Wed, 12 Mar 2025 21:08:21 +0000 Subject: [PATCH 09/15] Fix Iterator import --- src/py/flwr/server/driver/grpc_driver.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/py/flwr/server/driver/grpc_driver.py b/src/py/flwr/server/driver/grpc_driver.py index 43890ec4df2c..f58674287652 100644 --- a/src/py/flwr/server/driver/grpc_driver.py +++ b/src/py/flwr/server/driver/grpc_driver.py @@ -17,9 +17,9 @@ import time import warnings -from collections.abc import Iterable +from collections.abc import Iterable, Iterator from logging import DEBUG, WARNING -from typing import Iterator, Optional, cast +from typing import Optional, cast import grpc From ed61c7887006796bba2a442178f41ef6b6d6d603 Mon Sep 17 00:00:00 2001 From: Chong Shen Ng Date: Wed, 12 Mar 2025 21:58:39 +0000 Subject: [PATCH 10/15] Update inmemory_driver_test --- src/py/flwr/server/driver/grpc_driver_test.py | 10 ++-- src/py/flwr/server/driver/inmemory_driver.py | 34 +++++++++--- .../server/driver/inmemory_driver_test.py | 54 +++++++++++++++++++ 3 files changed, 85 insertions(+), 13 deletions(-) diff --git a/src/py/flwr/server/driver/grpc_driver_test.py b/src/py/flwr/server/driver/grpc_driver_test.py index 7df58af9c528..d4c46c73adb9 100644 --- a/src/py/flwr/server/driver/grpc_driver_test.py +++ b/src/py/flwr/server/driver/grpc_driver_test.py @@ -227,11 +227,11 @@ def test_pull_messages_with_invalid_message_ids(self) -> None: # Warning was logged with the missing IDs mock_log.assert_called_once() args, _ = mock_log.call_args - log_level = args[0] - logged_missing_ids = args[2] - # Log level is WARNING and the missing IDs appear in the log message - self.assertEqual(log_level, WARNING) - self.assertEqual(logged_missing_ids, expected_missing) + self.assertEqual(args[0], WARNING) + self.assertEqual( + args[1], "Cannot pull messages for the following missing message IDs: %s" + ) + self.assertEqual(args[2], expected_missing) def test_send_and_receive_messages_complete(self) -> None: """Test send and receive all messages successfully.""" diff --git a/src/py/flwr/server/driver/inmemory_driver.py b/src/py/flwr/server/driver/inmemory_driver.py index 08019613523b..4b001da52dcb 100644 --- a/src/py/flwr/server/driver/inmemory_driver.py +++ b/src/py/flwr/server/driver/inmemory_driver.py @@ -17,12 +17,14 @@ import time import warnings -from collections.abc import Iterable +from collections.abc import Iterable, Iterator +from logging import WARNING from typing import Optional, cast from uuid import UUID from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet from flwr.common.constant import SUPERLINK_NODE_ID +from flwr.common.logger import log from flwr.common.typing import Run from flwr.proto.node_pb2 import Node # pylint: disable=E0611 from flwr.server.superlink.linkstate import LinkStateFactory @@ -149,11 +151,24 @@ def pull_messages( # Raise an error if no message IDs are provided and none are stored if not self._message_ids: raise ValueError("No message IDs to pull. Call `push_messages` first.") - msg_ids = ( - {UUID(msg_id) for msg_id in message_ids} - if message_ids is not None - else {UUID(msg_id) for msg_id in self._message_ids} - ) + + # Allow an override but default to the stored pending IDs + if message_ids is None: + # If no message_ids are provided, use the stored ones + message_ids_to_pull = self._message_ids + else: + # Otherwise, filter the provided list to only those in self._message_ids + provided_ids = set(message_ids) + message_ids_to_pull = sorted(provided_ids & set(self._message_ids)) + missing_ids = sorted(provided_ids - set(self._message_ids)) + if missing_ids: + log( + WARNING, + "Cannot pull messages for the following missing message IDs: %s", + missing_ids, + ) + + msg_ids = {UUID(msg_id) for msg_id in message_ids_to_pull} # Pull Messages message_res_list = self.state.get_message_res(message_ids=msg_ids) # Get IDs of Messages these replies are for @@ -163,7 +178,10 @@ def pull_messages( # Delete self.state.delete_messages(message_ins_ids=message_ins_ids_to_delete) - return message_res_list + def iter_msg() -> Iterator[Message]: + return iter(message_res_list) + + return iter_msg() def send_and_receive( self, @@ -184,7 +202,7 @@ def send_and_receive( end_time = time.time() + (timeout if timeout is not None else 0.0) ret: list[Message] = [] while timeout is None or time.time() < end_time: - res_msgs = self.pull_messages(msg_ids) + res_msgs = list(self.pull_messages(msg_ids)) ret.extend(res_msgs) msg_ids.difference_update( {msg.metadata.reply_to_message for msg in res_msgs} diff --git a/src/py/flwr/server/driver/inmemory_driver_test.py b/src/py/flwr/server/driver/inmemory_driver_test.py index 4fa5cf7d3b62..c52a694c84d7 100644 --- a/src/py/flwr/server/driver/inmemory_driver_test.py +++ b/src/py/flwr/server/driver/inmemory_driver_test.py @@ -18,6 +18,7 @@ import time import unittest from collections.abc import Iterable +from logging import WARNING from unittest.mock import MagicMock, patch from uuid import UUID, uuid4 @@ -157,6 +158,7 @@ def test_pull_messages_with_given_message_ids(self) -> None: msg_ids = [str(uuid4()) for _ in range(2)] message_res_list = create_message_replies_for_specific_ids(msg_ids) self.state.get_message_res.return_value = message_res_list + self.driver._message_ids.extend(msg_ids) # pylint: disable=protected-access # Execute pulled_msgs = list(self.driver.pull_messages(msg_ids)) @@ -170,6 +172,58 @@ def test_pull_messages_with_given_message_ids(self) -> None: message_ins_ids={UUID(m_id) for m_id in msg_ids} ) + def test_pull_messages_without_given_message_ids(self) -> None: + """Test pulling messages successful when no message_ids are provided.""" + # Prepare + msg_ids = [str(uuid4()) for _ in range(2)] + message_res_list = create_message_replies_for_specific_ids(msg_ids) + self.state.get_message_res.return_value = message_res_list + self.driver._message_ids.extend(msg_ids) # pylint: disable=protected-access + + # Execute + pulled_msgs = list(self.driver.pull_messages()) + reply_tos = [msg.metadata.reply_to_message for msg in pulled_msgs] + + # Assert + self.assertEqual(len(pulled_msgs), 2) + self.assertEqual(reply_tos, msg_ids) + # Ensure messages are deleted + self.state.delete_messages.assert_called_once_with( + message_ins_ids={UUID(m_id) for m_id in msg_ids} + ) + + def test_pull_messages_with_invalid_message_ids(self) -> None: + """Test pulling messages when provided message_ids include values not stored in + self._message_ids.""" + # Prepare + provided_msg_ids = sorted([str(uuid4()) for _ in range(5)]) + msg_ids = provided_msg_ids[:2] + message_res_list = create_message_replies_for_specific_ids(msg_ids) + self.state.get_message_res.return_value = message_res_list + self.driver._message_ids.extend(msg_ids) # pylint: disable=protected-access + expected_missing = sorted(set(provided_msg_ids) - set(msg_ids)) + + # Execute + with patch("flwr.server.driver.inmemory_driver.log") as mock_log: + pulled_msgs = list(self.driver.pull_messages(provided_msg_ids)) + reply_tos = [msg.metadata.reply_to_message for msg in pulled_msgs] + + # Assert + self.assertEqual(len(pulled_msgs), 2) + self.assertEqual(reply_tos, msg_ids) + # Ensure messages are deleted + self.state.delete_messages.assert_called_once_with( + message_ins_ids={UUID(m_id) for m_id in msg_ids} + ) + # Warning was logged with the missing IDs + mock_log.assert_called_once() + args, _ = mock_log.call_args + self.assertEqual(args[0], WARNING) + self.assertEqual( + args[1], "Cannot pull messages for the following missing message IDs: %s" + ) + self.assertEqual(args[2], expected_missing) + def test_send_and_receive_messages_complete(self) -> None: """Test send and receive all messages successfully.""" # Prepare From d21a8c1ef9240c36c5b61c48c41b2154a464006d Mon Sep 17 00:00:00 2001 From: Chong Shen Ng Date: Thu, 13 Mar 2025 13:59:20 +0000 Subject: [PATCH 11/15] Fix empty pull message response --- src/py/flwr/server/driver/grpc_driver.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/py/flwr/server/driver/grpc_driver.py b/src/py/flwr/server/driver/grpc_driver.py index f58674287652..eb8ee6217e11 100644 --- a/src/py/flwr/server/driver/grpc_driver.py +++ b/src/py/flwr/server/driver/grpc_driver.py @@ -269,9 +269,13 @@ def iter_msg() -> Iterator[Message]: run_id=cast(Run, self._run).run_id, ) ) - # Convert Message from Protobuf representation - msg = message_from_proto(res.messages_list[0]) - yield msg + # Yield a message if the response contains it, otherwise continue + if res.messages_list: + # Convert Message from Protobuf representation + msg = message_from_proto(res.messages_list[0]) + yield msg + else: + continue return iter_msg() From fd928845f18c199218c426ba4cb829a33d310da2 Mon Sep 17 00:00:00 2001 From: Chong Shen Ng Date: Thu, 13 Mar 2025 14:09:11 +0000 Subject: [PATCH 12/15] Refactor to improve readability --- src/py/flwr/server/driver/grpc_driver.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/py/flwr/server/driver/grpc_driver.py b/src/py/flwr/server/driver/grpc_driver.py index eb8ee6217e11..629ae208c9e0 100644 --- a/src/py/flwr/server/driver/grpc_driver.py +++ b/src/py/flwr/server/driver/grpc_driver.py @@ -234,7 +234,7 @@ def push_messages(self, messages: Iterable[Message]) -> Iterable[str]: def pull_messages( self, message_ids: Optional[Iterable[str]] = None ) -> Iterable[Message]: - """Pull messages from the SuperLink based on message IDs. + """Pull messages based on message IDs. This method is used to collect messages from the SuperLink that correspond to a set of given message IDs. If no message IDs are provided, it defaults to the @@ -247,12 +247,13 @@ def pull_messages( # Allow an override but default to the stored pending IDs if message_ids is None: # If no message_ids are provided, use the stored ones - message_ids_to_pull = self._message_ids + msg_ids_to_pull = self._message_ids else: - # Otherwise, filter the provided list to only those in self._message_ids + # Else, keep the IDs (from the given IDs) that are in `self._message_ids` provided_ids = set(message_ids) - message_ids_to_pull = sorted(provided_ids & set(self._message_ids)) - missing_ids = sorted(provided_ids - set(self._message_ids)) + stored_ids = set(self._message_ids) + msg_ids_to_pull = sorted(provided_ids & stored_ids) + missing_ids = sorted(provided_ids - stored_ids) if missing_ids: log( WARNING, @@ -261,8 +262,8 @@ def pull_messages( ) def iter_msg() -> Iterator[Message]: - for msg_id in message_ids_to_pull: - # Pull Messages for each message ID + for msg_id in msg_ids_to_pull: + # Pull a Message for each message ID res: PullResMessagesResponse = self._stub.PullMessages( PullResMessagesRequest( message_ids=[msg_id], From 8b79ea29dbdb41150f52fbbfb65558c0f0f04c5e Mon Sep 17 00:00:00 2001 From: Chong Shen Ng Date: Thu, 13 Mar 2025 15:24:30 +0000 Subject: [PATCH 13/15] Update --- src/py/flwr/server/driver/driver.py | 2 +- src/py/flwr/server/driver/grpc_driver.py | 22 +++++++++---------- src/py/flwr/server/driver/grpc_driver_test.py | 15 ++++++++----- src/py/flwr/server/driver/inmemory_driver.py | 17 +++++++------- .../server/driver/inmemory_driver_test.py | 10 ++++----- 5 files changed, 34 insertions(+), 32 deletions(-) diff --git a/src/py/flwr/server/driver/driver.py b/src/py/flwr/server/driver/driver.py index e737e4b06ba7..a064901ac1f9 100644 --- a/src/py/flwr/server/driver/driver.py +++ b/src/py/flwr/server/driver/driver.py @@ -47,7 +47,7 @@ def run(self) -> Run: @property @abstractmethod - def message_ids(self) -> list[str]: + def message_ids(self) -> Iterable[str]: """Message IDs of pushed messages.""" @abstractmethod diff --git a/src/py/flwr/server/driver/grpc_driver.py b/src/py/flwr/server/driver/grpc_driver.py index 629ae208c9e0..46d131035ef5 100644 --- a/src/py/flwr/server/driver/grpc_driver.py +++ b/src/py/flwr/server/driver/grpc_driver.py @@ -81,7 +81,7 @@ def __init__( # pylint: disable=too-many-arguments self._channel: Optional[grpc.Channel] = None self.node = Node(node_id=SUPERLINK_NODE_ID) self._retry_invoker = _make_simple_grpc_retry_invoker() - self._message_ids: list[str] = [] + self._message_ids: set[str] = set() @property def _is_connected(self) -> bool: @@ -139,7 +139,7 @@ def _stub(self) -> ServerAppIoStub: return cast(ServerAppIoStub, self._grpc_stub) @property - def message_ids(self) -> list[str]: + def message_ids(self) -> Iterable[str]: """Message IDs of pushed messages.""" return self._message_ids.copy() @@ -228,7 +228,7 @@ def push_messages(self, messages: Iterable[Message]) -> Iterable[str]: "to `push_messages`). This could be due to a malformed message.", ) # Store message IDs - self._message_ids.extend(res.message_ids) + self._message_ids.update(res.message_ids) return list(res.message_ids) def pull_messages( @@ -251,10 +251,8 @@ def pull_messages( else: # Else, keep the IDs (from the given IDs) that are in `self._message_ids` provided_ids = set(message_ids) - stored_ids = set(self._message_ids) - msg_ids_to_pull = sorted(provided_ids & stored_ids) - missing_ids = sorted(provided_ids - stored_ids) - if missing_ids: + msg_ids_to_pull = provided_ids & self._message_ids + if missing_ids := provided_ids - msg_ids_to_pull: log( WARNING, "Cannot pull messages for the following missing message IDs: %s", @@ -262,7 +260,7 @@ def pull_messages( ) def iter_msg() -> Iterator[Message]: - for msg_id in msg_ids_to_pull: + for msg_id in sorted(msg_ids_to_pull): # Pull a Message for each message ID res: PullResMessagesResponse = self._stub.PullMessages( PullResMessagesRequest( @@ -274,9 +272,11 @@ def iter_msg() -> Iterator[Message]: if res.messages_list: # Convert Message from Protobuf representation msg = message_from_proto(res.messages_list[0]) + # Remove the message once pulled + self._message_ids.remove(msg.metadata.reply_to_message) yield msg - else: - continue + # else: + # continue return iter_msg() @@ -299,7 +299,7 @@ def send_and_receive( end_time = time.time() + (timeout if timeout is not None else 0.0) ret: list[Message] = [] while timeout is None or time.time() < end_time: - res_msgs = list(self.pull_messages(msg_ids)) + res_msgs = self.pull_messages(msg_ids) ret.extend(res_msgs) msg_ids.difference_update( {msg.metadata.reply_to_message for msg in res_msgs} diff --git a/src/py/flwr/server/driver/grpc_driver_test.py b/src/py/flwr/server/driver/grpc_driver_test.py index d4c46c73adb9..9ff058a21089 100644 --- a/src/py/flwr/server/driver/grpc_driver_test.py +++ b/src/py/flwr/server/driver/grpc_driver_test.py @@ -181,7 +181,7 @@ def test_pull_messages_with_given_message_ids(self) -> None: msg_ids = self._setup_pull_messages_mocks(run_id) # Store message IDs in the driver's internal state for testing - self.driver._message_ids.extend(msg_ids) # pylint: disable=protected-access + self.driver._message_ids.update(msg_ids) # pylint: disable=protected-access # Execute messages = list(self.driver.pull_messages(msg_ids)) @@ -196,7 +196,7 @@ def test_pull_messages_without_given_message_ids(self) -> None: msg_ids = self._setup_pull_messages_mocks(run_id) # Store message IDs in the driver's internal state for testing - self.driver._message_ids.extend(msg_ids) # pylint: disable=protected-access + self.driver._message_ids.update(msg_ids) # pylint: disable=protected-access # Execute messages = list(self.driver.pull_messages()) @@ -212,14 +212,14 @@ def test_pull_messages_with_invalid_message_ids(self) -> None: msg_ids = self._setup_pull_messages_mocks(run_id) # Store message IDs in the driver's internal state for testing - self.driver._message_ids.extend(msg_ids) # pylint: disable=protected-access + self.driver._message_ids.update(msg_ids) # pylint: disable=protected-access provided_msg_ids = {"id2", "id3", "id4", "id5"} - expected_missing = sorted(provided_msg_ids - set(msg_ids)) + expected_missing = provided_msg_ids - set(msg_ids) # Execute with patch("flwr.server.driver.grpc_driver.log") as mock_log: - messages = list(self.driver.pull_messages(provided_msg_ids)) + messages = self.driver.pull_messages(provided_msg_ids) # Assert # Only valid IDs are pulled @@ -231,7 +231,7 @@ def test_pull_messages_with_invalid_message_ids(self) -> None: self.assertEqual( args[1], "Cannot pull messages for the following missing message IDs: %s" ) - self.assertEqual(args[2], expected_missing) + self.assertSetEqual(args[2], expected_missing) def test_send_and_receive_messages_complete(self) -> None: """Test send and receive all messages successfully.""" @@ -251,6 +251,9 @@ def test_send_and_receive_messages_complete(self) -> None: self.mock_stub.PullMessages.return_value = mock_response msgs = [self.driver.create_message(RecordSet(), "", 0, "", DEFAULT_TTL)] + # Store message IDs in the driver's internal state for testing + self.driver._message_ids.update(["id1"]) # pylint: disable=protected-access + # Execute ret_msgs = list(self.driver.send_and_receive(msgs)) diff --git a/src/py/flwr/server/driver/inmemory_driver.py b/src/py/flwr/server/driver/inmemory_driver.py index 4b001da52dcb..454ceb4184e8 100644 --- a/src/py/flwr/server/driver/inmemory_driver.py +++ b/src/py/flwr/server/driver/inmemory_driver.py @@ -52,7 +52,7 @@ def __init__( self.state = state_factory.state() self.pull_interval = pull_interval self.node = Node(node_id=SUPERLINK_NODE_ID) - self._message_ids: list[str] = [] + self._message_ids: set[str] = set() def _check_message(self, message: Message) -> None: # Check if the message is valid @@ -79,7 +79,7 @@ def run(self) -> Run: return Run(**vars(cast(Run, self._run))) @property - def message_ids(self) -> list[str]: + def message_ids(self) -> Iterable[str]: """Message IDs of pushed messages.""" return self._message_ids.copy() @@ -136,7 +136,7 @@ def push_messages(self, messages: Iterable[Message]) -> Iterable[str]: if msg_id: msg_ids.append(str(msg_id)) # Store message IDs - self._message_ids.extend(msg_ids) + self._message_ids.update(msg_ids) return msg_ids def pull_messages( @@ -155,20 +155,19 @@ def pull_messages( # Allow an override but default to the stored pending IDs if message_ids is None: # If no message_ids are provided, use the stored ones - message_ids_to_pull = self._message_ids + msg_ids_to_pull = self._message_ids else: # Otherwise, filter the provided list to only those in self._message_ids provided_ids = set(message_ids) - message_ids_to_pull = sorted(provided_ids & set(self._message_ids)) - missing_ids = sorted(provided_ids - set(self._message_ids)) - if missing_ids: + msg_ids_to_pull = provided_ids & self._message_ids + if missing_ids := provided_ids - msg_ids_to_pull: log( WARNING, "Cannot pull messages for the following missing message IDs: %s", missing_ids, ) - msg_ids = {UUID(msg_id) for msg_id in message_ids_to_pull} + msg_ids = {UUID(msg_id) for msg_id in msg_ids_to_pull} # Pull Messages message_res_list = self.state.get_message_res(message_ids=msg_ids) # Get IDs of Messages these replies are for @@ -202,7 +201,7 @@ def send_and_receive( end_time = time.time() + (timeout if timeout is not None else 0.0) ret: list[Message] = [] while timeout is None or time.time() < end_time: - res_msgs = list(self.pull_messages(msg_ids)) + res_msgs = self.pull_messages(msg_ids) ret.extend(res_msgs) msg_ids.difference_update( {msg.metadata.reply_to_message for msg in res_msgs} diff --git a/src/py/flwr/server/driver/inmemory_driver_test.py b/src/py/flwr/server/driver/inmemory_driver_test.py index c52a694c84d7..d1d6cac79770 100644 --- a/src/py/flwr/server/driver/inmemory_driver_test.py +++ b/src/py/flwr/server/driver/inmemory_driver_test.py @@ -158,7 +158,7 @@ def test_pull_messages_with_given_message_ids(self) -> None: msg_ids = [str(uuid4()) for _ in range(2)] message_res_list = create_message_replies_for_specific_ids(msg_ids) self.state.get_message_res.return_value = message_res_list - self.driver._message_ids.extend(msg_ids) # pylint: disable=protected-access + self.driver._message_ids.update(msg_ids) # pylint: disable=protected-access # Execute pulled_msgs = list(self.driver.pull_messages(msg_ids)) @@ -178,7 +178,7 @@ def test_pull_messages_without_given_message_ids(self) -> None: msg_ids = [str(uuid4()) for _ in range(2)] message_res_list = create_message_replies_for_specific_ids(msg_ids) self.state.get_message_res.return_value = message_res_list - self.driver._message_ids.extend(msg_ids) # pylint: disable=protected-access + self.driver._message_ids.update(msg_ids) # pylint: disable=protected-access # Execute pulled_msgs = list(self.driver.pull_messages()) @@ -200,8 +200,8 @@ def test_pull_messages_with_invalid_message_ids(self) -> None: msg_ids = provided_msg_ids[:2] message_res_list = create_message_replies_for_specific_ids(msg_ids) self.state.get_message_res.return_value = message_res_list - self.driver._message_ids.extend(msg_ids) # pylint: disable=protected-access - expected_missing = sorted(set(provided_msg_ids) - set(msg_ids)) + self.driver._message_ids.update(msg_ids) # pylint: disable=protected-access + expected_missing = set(provided_msg_ids) - set(msg_ids) # Execute with patch("flwr.server.driver.inmemory_driver.log") as mock_log: @@ -222,7 +222,7 @@ def test_pull_messages_with_invalid_message_ids(self) -> None: self.assertEqual( args[1], "Cannot pull messages for the following missing message IDs: %s" ) - self.assertEqual(args[2], expected_missing) + self.assertSetEqual(args[2], expected_missing) def test_send_and_receive_messages_complete(self) -> None: """Test send and receive all messages successfully.""" From 886bb53628c5be696c5324ed15112a19ab4a7530 Mon Sep 17 00:00:00 2001 From: Chong Shen Ng Date: Thu, 13 Mar 2025 16:06:37 +0000 Subject: [PATCH 14/15] Address comments --- src/py/flwr/server/driver/grpc_driver.py | 4 +--- src/py/flwr/server/driver/grpc_driver_test.py | 21 ++++++++++++------- src/py/flwr/server/driver/inmemory_driver.py | 2 +- 3 files changed, 16 insertions(+), 11 deletions(-) diff --git a/src/py/flwr/server/driver/grpc_driver.py b/src/py/flwr/server/driver/grpc_driver.py index 46d131035ef5..5f1407987f11 100644 --- a/src/py/flwr/server/driver/grpc_driver.py +++ b/src/py/flwr/server/driver/grpc_driver.py @@ -275,8 +275,6 @@ def iter_msg() -> Iterator[Message]: # Remove the message once pulled self._message_ids.remove(msg.metadata.reply_to_message) yield msg - # else: - # continue return iter_msg() @@ -299,7 +297,7 @@ def send_and_receive( end_time = time.time() + (timeout if timeout is not None else 0.0) ret: list[Message] = [] while timeout is None or time.time() < end_time: - res_msgs = self.pull_messages(msg_ids) + res_msgs = list(self.pull_messages(msg_ids)) ret.extend(res_msgs) msg_ids.difference_update( {msg.metadata.reply_to_message for msg in res_msgs} diff --git a/src/py/flwr/server/driver/grpc_driver_test.py b/src/py/flwr/server/driver/grpc_driver_test.py index 9ff058a21089..11774d6e1220 100644 --- a/src/py/flwr/server/driver/grpc_driver_test.py +++ b/src/py/flwr/server/driver/grpc_driver_test.py @@ -18,6 +18,7 @@ import time import unittest from logging import WARNING +from typing import Iterable from unittest.mock import Mock, patch import grpc @@ -157,7 +158,7 @@ def _setup_pull_messages_mocks(self, run_id: int) -> list[str]: return ["id2", "id3"] def _assert_pull_messages( - self, expected_ids: list[str], messages: list[Message] + self, expected_ids: list[str], messages: Iterable[Message] ) -> None: """Check that PullMessages was called once per expected message ID and that the returned messages have the correct reply_to values.""" @@ -172,7 +173,7 @@ def _assert_pull_messages( # Each call should be made with a single-element list # containing the expected message id self.assertEqual(args[0].message_ids, [expected_msg_id]) - self.assertEqual(reply_tos, {"id2", "id3"}) + self.assertSetEqual(reply_tos, {"id2", "id3"}) def test_pull_messages_with_given_message_ids(self) -> None: """Test pulling messages with specific message IDs.""" @@ -184,10 +185,13 @@ def test_pull_messages_with_given_message_ids(self) -> None: self.driver._message_ids.update(msg_ids) # pylint: disable=protected-access # Execute - messages = list(self.driver.pull_messages(msg_ids)) + messages = self.driver.pull_messages(msg_ids) # Assert self._assert_pull_messages(msg_ids, messages) + # Ensure messages are deleted + # pylint: disable=protected-access + self.assertFalse(set(msg_ids).issubset(self.driver._message_ids)) def test_pull_messages_without_given_message_ids(self) -> None: """Test pulling messages successful when no message_ids are provided.""" @@ -199,10 +203,13 @@ def test_pull_messages_without_given_message_ids(self) -> None: self.driver._message_ids.update(msg_ids) # pylint: disable=protected-access # Execute - messages = list(self.driver.pull_messages()) + messages = self.driver.pull_messages() # Assert self._assert_pull_messages(msg_ids, messages) + # Ensure messages are deleted + # pylint: disable=protected-access + self.assertFalse(set(msg_ids).issubset(self.driver._message_ids)) def test_pull_messages_with_invalid_message_ids(self) -> None: """Test pulling messages when provided message_ids include values not stored in @@ -224,6 +231,9 @@ def test_pull_messages_with_invalid_message_ids(self) -> None: # Assert # Only valid IDs are pulled self._assert_pull_messages(msg_ids, messages) + # Ensure messages are deleted + # pylint: disable=protected-access + self.assertFalse(set(msg_ids).issubset(self.driver._message_ids)) # Warning was logged with the missing IDs mock_log.assert_called_once() args, _ = mock_log.call_args @@ -251,9 +261,6 @@ def test_send_and_receive_messages_complete(self) -> None: self.mock_stub.PullMessages.return_value = mock_response msgs = [self.driver.create_message(RecordSet(), "", 0, "", DEFAULT_TTL)] - # Store message IDs in the driver's internal state for testing - self.driver._message_ids.update(["id1"]) # pylint: disable=protected-access - # Execute ret_msgs = list(self.driver.send_and_receive(msgs)) diff --git a/src/py/flwr/server/driver/inmemory_driver.py b/src/py/flwr/server/driver/inmemory_driver.py index 454ceb4184e8..65e2b787d2e3 100644 --- a/src/py/flwr/server/driver/inmemory_driver.py +++ b/src/py/flwr/server/driver/inmemory_driver.py @@ -201,7 +201,7 @@ def send_and_receive( end_time = time.time() + (timeout if timeout is not None else 0.0) ret: list[Message] = [] while timeout is None or time.time() < end_time: - res_msgs = self.pull_messages(msg_ids) + res_msgs = list(self.pull_messages(msg_ids)) ret.extend(res_msgs) msg_ids.difference_update( {msg.metadata.reply_to_message for msg in res_msgs} From 1d54a99223c52c907bf9acc2d0e1efa09a874c0f Mon Sep 17 00:00:00 2001 From: Chong Shen Ng Date: Thu, 13 Mar 2025 16:15:02 +0000 Subject: [PATCH 15/15] Revert iterator to list for inmemory driver --- src/py/flwr/server/driver/inmemory_driver.py | 9 +++------ src/py/flwr/server/driver/inmemory_driver_test.py | 6 +++--- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/src/py/flwr/server/driver/inmemory_driver.py b/src/py/flwr/server/driver/inmemory_driver.py index 65e2b787d2e3..e57632b8ade4 100644 --- a/src/py/flwr/server/driver/inmemory_driver.py +++ b/src/py/flwr/server/driver/inmemory_driver.py @@ -17,7 +17,7 @@ import time import warnings -from collections.abc import Iterable, Iterator +from collections.abc import Iterable from logging import WARNING from typing import Optional, cast from uuid import UUID @@ -177,10 +177,7 @@ def pull_messages( # Delete self.state.delete_messages(message_ins_ids=message_ins_ids_to_delete) - def iter_msg() -> Iterator[Message]: - return iter(message_res_list) - - return iter_msg() + return message_res_list def send_and_receive( self, @@ -201,7 +198,7 @@ def send_and_receive( end_time = time.time() + (timeout if timeout is not None else 0.0) ret: list[Message] = [] while timeout is None or time.time() < end_time: - res_msgs = list(self.pull_messages(msg_ids)) + res_msgs = self.pull_messages(msg_ids) ret.extend(res_msgs) msg_ids.difference_update( {msg.metadata.reply_to_message for msg in res_msgs} diff --git a/src/py/flwr/server/driver/inmemory_driver_test.py b/src/py/flwr/server/driver/inmemory_driver_test.py index d1d6cac79770..8f0e9f93069a 100644 --- a/src/py/flwr/server/driver/inmemory_driver_test.py +++ b/src/py/flwr/server/driver/inmemory_driver_test.py @@ -161,7 +161,7 @@ def test_pull_messages_with_given_message_ids(self) -> None: self.driver._message_ids.update(msg_ids) # pylint: disable=protected-access # Execute - pulled_msgs = list(self.driver.pull_messages(msg_ids)) + pulled_msgs = self.driver.pull_messages(msg_ids) reply_tos = [msg.metadata.reply_to_message for msg in pulled_msgs] # Assert @@ -181,7 +181,7 @@ def test_pull_messages_without_given_message_ids(self) -> None: self.driver._message_ids.update(msg_ids) # pylint: disable=protected-access # Execute - pulled_msgs = list(self.driver.pull_messages()) + pulled_msgs = self.driver.pull_messages() reply_tos = [msg.metadata.reply_to_message for msg in pulled_msgs] # Assert @@ -205,7 +205,7 @@ def test_pull_messages_with_invalid_message_ids(self) -> None: # Execute with patch("flwr.server.driver.inmemory_driver.log") as mock_log: - pulled_msgs = list(self.driver.pull_messages(provided_msg_ids)) + pulled_msgs = self.driver.pull_messages(provided_msg_ids) reply_tos = [msg.metadata.reply_to_message for msg in pulled_msgs] # Assert