Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions src/py/flwr/server/driver/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ def set_run(self, run_id: int) -> None:
def run(self) -> Run:
"""Run information."""

@property
@abstractmethod
def message_ids(self) -> Iterable[str]:
"""Message IDs of pushed messages."""

@abstractmethod
def create_message( # pylint: disable=too-many-arguments,R0917
self,
Expand Down Expand Up @@ -108,16 +113,21 @@ def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
"""

@abstractmethod
def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]:
"""Pull messages based on message IDs.
def pull_messages(
self, message_ids: Optional[Iterable[str]] = None
) -> Iterable[Message]:
"""Pull messages from the SuperLink based on message IDs.

This method is used to collect messages from the SuperLink
that correspond to a set of given message IDs.
This method is used to collect messages from the SuperLink that correspond to a
set of given message IDs. If no message IDs are provided, it defaults to the
stored message IDs.

Parameters
----------
message_ids : Iterable[str]
message_ids : Optional[Iterable[str]]
An iterable of message IDs for which reply messages are to be retrieved.
If specified, the method will only pull messages that correspond to these
IDs. If `None`, all messages will be retrieved.

Returns
-------
Expand Down
68 changes: 53 additions & 15 deletions src/py/flwr/server/driver/grpc_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import time
import warnings
from collections.abc import Iterable
from collections.abc import Iterable, Iterator
from logging import DEBUG, WARNING
from typing import Optional, cast

Expand Down Expand Up @@ -56,7 +56,7 @@
"""


class GrpcDriver(Driver):
class GrpcDriver(Driver): # pylint: disable=too-many-instance-attributes
"""`GrpcDriver` provides an interface to the ServerAppIo API.

Parameters
Expand All @@ -81,6 +81,7 @@ def __init__( # pylint: disable=too-many-arguments
self._channel: Optional[grpc.Channel] = None
self.node = Node(node_id=SUPERLINK_NODE_ID)
self._retry_invoker = _make_simple_grpc_retry_invoker()
self._message_ids: set[str] = set()

@property
def _is_connected(self) -> bool:
Expand Down Expand Up @@ -137,6 +138,11 @@ def _stub(self) -> ServerAppIoStub:
self._connect()
return cast(ServerAppIoStub, self._grpc_stub)

@property
def message_ids(self) -> Iterable[str]:
"""Message IDs of pushed messages."""
return self._message_ids.copy()

def _check_message(self, message: Message) -> None:
# Check if the message is valid
if not (
Expand Down Expand Up @@ -221,24 +227,56 @@ def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
"list has `None` for those messages (the order is preserved as passed "
"to `push_messages`). This could be due to a malformed message.",
)
# Store message IDs
self._message_ids.update(res.message_ids)
return list(res.message_ids)

def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]:
def pull_messages(
self, message_ids: Optional[Iterable[str]] = None
) -> Iterable[Message]:
"""Pull messages based on message IDs.

This method is used to collect messages from the SuperLink that correspond to a
set of given message IDs.
set of given message IDs. If no message IDs are provided, it defaults to the
stored message IDs.
"""
# Pull Messages
res: PullResMessagesResponse = self._stub.PullMessages(
PullResMessagesRequest(
message_ids=message_ids,
run_id=cast(Run, self._run).run_id,
)
)
# Convert Message from Protobuf representation
msgs = [message_from_proto(msg_proto) for msg_proto in res.messages_list]
return msgs
# Raise an error if no message IDs are provided and none are stored
if not self._message_ids:
raise ValueError("No message IDs to pull. Call `push_messages` first.")

# Allow an override but default to the stored pending IDs
if message_ids is None:
# If no message_ids are provided, use the stored ones
msg_ids_to_pull = self._message_ids
else:
# Else, keep the IDs (from the given IDs) that are in `self._message_ids`
provided_ids = set(message_ids)
msg_ids_to_pull = provided_ids & self._message_ids
if missing_ids := provided_ids - msg_ids_to_pull:
log(
WARNING,
"Cannot pull messages for the following missing message IDs: %s",
missing_ids,
)

def iter_msg() -> Iterator[Message]:
for msg_id in sorted(msg_ids_to_pull):
# Pull a Message for each message ID
res: PullResMessagesResponse = self._stub.PullMessages(
PullResMessagesRequest(
message_ids=[msg_id],
run_id=cast(Run, self._run).run_id,
)
)
# Yield a message if the response contains it, otherwise continue
if res.messages_list:
# Convert Message from Protobuf representation
msg = message_from_proto(res.messages_list[0])
# Remove the message once pulled
self._message_ids.remove(msg.metadata.reply_to_message)
yield msg

return iter_msg()

def send_and_receive(
self,
Expand All @@ -259,7 +297,7 @@ def send_and_receive(
end_time = time.time() + (timeout if timeout is not None else 0.0)
ret: list[Message] = []
while timeout is None or time.time() < end_time:
res_msgs = self.pull_messages(msg_ids)
res_msgs = list(self.pull_messages(msg_ids))
ret.extend(res_msgs)
msg_ids.difference_update(
{msg.metadata.reply_to_message for msg in res_msgs}
Expand Down
135 changes: 109 additions & 26 deletions src/py/flwr/server/driver/grpc_driver_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@

import time
import unittest
from logging import WARNING
from typing import Iterable
from unittest.mock import Mock, patch

import grpc

from flwr.common import DEFAULT_TTL, RecordSet
from flwr.common import DEFAULT_TTL, Message, RecordSet
from flwr.common.message import Error
from flwr.proto.run_pb2 import ( # pylint: disable=E0611
GetRunRequest,
Expand Down Expand Up @@ -129,36 +131,117 @@ def test_push_messages_invalid(self) -> None:
with self.assertRaises(ValueError):
self.driver.push_messages(msgs)

def test_pull_messages_with_given_message_ids(self) -> None:
"""Test pulling messages with specific message IDs."""
# Prepare
mock_response = Mock()
# A Message must have either content or error set so we prepare
run_id = 12345
def _setup_pull_messages_mocks(self, run_id: int) -> list[str]:
"""Set up common mocks for pull_messages.

This creates two mock responses, one for each message ID, and configures the
stub to return them sequentially.
"""
# Create messages with distinct reply_to values
ok_message = create_res_message(src_node_id=123, dst_node_id=456, run_id=run_id)
ok_message.metadata.reply_to_message = "id2"

error_message = create_res_message(
src_node_id=123, dst_node_id=789, run_id=run_id, error=Error(code=0)
)
error_message.metadata.reply_to_message = "id3"
# The the response from the DriverServicer is in the form of Protbuf Messages
mock_response.messages_list = [ok_message, error_message]
self.mock_stub.PullMessages.return_value = mock_response
msg_ids = ["id1", "id2", "id3"]

# Create separate mock responses for each call
mock_response1 = Mock()
mock_response1.messages_list = [ok_message]
mock_response2 = Mock()
mock_response2.messages_list = [error_message]

# Configure PullMessages to return a different response per call
self.mock_stub.PullMessages.side_effect = [mock_response1, mock_response2]

# Return the message IDs to be used in the tests (the valid ones)
return ["id2", "id3"]

def _assert_pull_messages(
self, expected_ids: list[str], messages: Iterable[Message]
) -> None:
"""Check that PullMessages was called once per expected message ID and that the
returned messages have the correct reply_to values."""
reply_tos = {msg.metadata.reply_to_message for msg in messages}
calls = self.mock_stub.PullMessages.call_args_list
self.assertEqual(len(calls), len(expected_ids))
for call, expected_msg_id in zip(calls, expected_ids):
args, kwargs = call
self.assertEqual(len(args), 1)
self.assertEqual(len(kwargs), 0)
self.assertIsInstance(args[0], PullResMessagesRequest)
# Each call should be made with a single-element list
# containing the expected message id
self.assertEqual(args[0].message_ids, [expected_msg_id])
self.assertSetEqual(reply_tos, {"id2", "id3"})

def test_pull_messages_with_given_message_ids(self) -> None:
"""Test pulling messages with specific message IDs."""
# Prepare
run_id = 12345
msg_ids = self._setup_pull_messages_mocks(run_id)

# Store message IDs in the driver's internal state for testing
self.driver._message_ids.update(msg_ids) # pylint: disable=protected-access

# Execute
msgs = self.driver.pull_messages(msg_ids)
reply_tos = {msg.metadata.reply_to_message for msg in msgs}
args, kwargs = self.mock_stub.PullMessages.call_args
messages = self.driver.pull_messages(msg_ids)

# Assert
self.mock_stub.GetRun.assert_called_once()
self.assertEqual(len(args), 1)
self.assertEqual(len(kwargs), 0)
self.assertIsInstance(args[0], PullResMessagesRequest)
self.assertEqual(args[0].message_ids, msg_ids)
self.assertEqual(reply_tos, {"id2", "id3"})
self._assert_pull_messages(msg_ids, messages)
# Ensure messages are deleted
# pylint: disable=protected-access
self.assertFalse(set(msg_ids).issubset(self.driver._message_ids))

def test_pull_messages_without_given_message_ids(self) -> None:
"""Test pulling messages successful when no message_ids are provided."""
# Prepare
run_id = 12345
msg_ids = self._setup_pull_messages_mocks(run_id)

# Store message IDs in the driver's internal state for testing
self.driver._message_ids.update(msg_ids) # pylint: disable=protected-access

# Execute
messages = self.driver.pull_messages()

# Assert
self._assert_pull_messages(msg_ids, messages)
# Ensure messages are deleted
# pylint: disable=protected-access
self.assertFalse(set(msg_ids).issubset(self.driver._message_ids))

def test_pull_messages_with_invalid_message_ids(self) -> None:
"""Test pulling messages when provided message_ids include values not stored in
self._message_ids."""
# Prepare
run_id = 12345
msg_ids = self._setup_pull_messages_mocks(run_id)

# Store message IDs in the driver's internal state for testing
self.driver._message_ids.update(msg_ids) # pylint: disable=protected-access

provided_msg_ids = {"id2", "id3", "id4", "id5"}
expected_missing = provided_msg_ids - set(msg_ids)

# Execute
with patch("flwr.server.driver.grpc_driver.log") as mock_log:
messages = self.driver.pull_messages(provided_msg_ids)

# Assert
# Only valid IDs are pulled
self._assert_pull_messages(msg_ids, messages)
# Ensure messages are deleted
# pylint: disable=protected-access
self.assertFalse(set(msg_ids).issubset(self.driver._message_ids))
# Warning was logged with the missing IDs
mock_log.assert_called_once()
args, _ = mock_log.call_args
self.assertEqual(args[0], WARNING)
self.assertEqual(
args[1], "Cannot pull messages for the following missing message IDs: %s"
)
self.assertSetEqual(args[2], expected_missing)

def test_send_and_receive_messages_complete(self) -> None:
"""Test send and receive all messages successfully."""
Expand Down Expand Up @@ -191,14 +274,14 @@ def test_send_and_receive_messages_timeout(self) -> None:
sleep_fn = time.sleep
mock_response = Mock(message_ids=["id1"])
self.mock_stub.PushMessages.return_value = mock_response
mock_response = Mock(messages_list=[])
self.mock_stub.PullMessages.return_value = mock_response
msgs = [self.driver.create_message(RecordSet(), "", 0, "", DEFAULT_TTL)]

# Execute
with patch("time.sleep", side_effect=lambda t: sleep_fn(t * 0.01)):
start_time = time.time()
ret_msgs = list(self.driver.send_and_receive(msgs, timeout=0.15))
# Patch pull_messages to always return an empty iterator.
with patch.object(self.driver, "pull_messages", return_value=iter([])):
with patch("time.sleep", side_effect=lambda t: sleep_fn(t * 0.01)):
start_time = time.time()
ret_msgs = list(self.driver.send_and_receive(msgs, timeout=0.15))

# Assert
self.assertLess(time.time() - start_time, 0.2)
Expand Down
Loading
Loading