Skip to content
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 32 additions & 1 deletion src/py/flwr/common/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@
"""Flower type definitions."""


from collections.abc import Iterator, Sequence
from dataclasses import dataclass
from enum import Enum
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Optional, Union, cast, overload

import numpy as np
import numpy.typing as npt
Expand Down Expand Up @@ -322,3 +323,33 @@ class LogEntry:
actor: Actor
event: Event
status: str


class ReadOnlyList(Sequence[str]):
"""A thin, generic read-only wrapper for a list of strings."""

def __init__(self, data: list[str]) -> None:
# Store a reference to the original mutable list
self._data = data

@overload
def __getitem__(self, index: int) -> str: ...

@overload
def __getitem__(self, index: slice) -> Sequence[str]: ...

def __getitem__(self, index: Union[int, slice]) -> Union[str, "ReadOnlyList"]:
result = self._data[index]
# If the result is a slice, wrap it in a ReadOnlyList.
if isinstance(index, slice):
return ReadOnlyList(cast(list[str], result))
return cast(str, result)

def __len__(self) -> int:
return len(self._data)

def __iter__(self) -> Iterator[str]:
return iter(self._data)

def __repr__(self) -> str:
return f"{self.__class__.__qualname__}({self._data!r})"
22 changes: 16 additions & 6 deletions src/py/flwr/server/driver/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from typing import Optional

from flwr.common import Message, RecordSet
from flwr.common.typing import Run
from flwr.common.typing import ReadOnlyList, Run


class Driver(ABC):
Expand All @@ -45,6 +45,11 @@ def set_run(self, run_id: int) -> None:
def run(self) -> Run:
"""Run information."""

@property
@abstractmethod
def message_ids(self) -> ReadOnlyList:
"""Message IDs of pushed messages."""

@abstractmethod
def create_message( # pylint: disable=too-many-arguments,R0917
self,
Expand Down Expand Up @@ -108,16 +113,21 @@ def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
"""

@abstractmethod
def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]:
"""Pull messages based on message IDs.
def pull_messages(
self, message_ids: Optional[Iterable[str]] = None
) -> Iterable[Message]:
"""Pull messages from the SuperLink based on message IDs.

This method is used to collect messages from the SuperLink
that correspond to a set of given message IDs.
This method is used to collect messages from the SuperLink that correspond to a
set of given message IDs. If no message IDs are provided, it defaults to the
stored message IDs.

Parameters
----------
message_ids : Iterable[str]
message_ids : Optional[Iterable[str]]
An iterable of message IDs for which reply messages are to be retrieved.
If specified, the method will only pull messages that correspond to these
IDs. If `None`, all messages will be retrieved.

Returns
-------
Expand Down
57 changes: 41 additions & 16 deletions src/py/flwr/server/driver/grpc_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import warnings
from collections.abc import Iterable
from logging import DEBUG, WARNING
from typing import Optional, cast
from typing import Generator, Optional, cast

import grpc

Expand All @@ -32,7 +32,7 @@
from flwr.common.logger import log
from flwr.common.retry_invoker import _make_simple_grpc_retry_invoker, _wrap_stub
from flwr.common.serde import message_from_proto, message_to_proto, run_from_proto
from flwr.common.typing import Run
from flwr.common.typing import ReadOnlyList, Run
from flwr.proto.message_pb2 import Message as ProtoMessage # pylint: disable=E0611
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
Expand All @@ -56,7 +56,7 @@
"""


class GrpcDriver(Driver):
class GrpcDriver(Driver): # pylint: disable=too-many-instance-attributes
"""`GrpcDriver` provides an interface to the ServerAppIo API.

Parameters
Expand All @@ -81,6 +81,7 @@ def __init__( # pylint: disable=too-many-arguments
self._channel: Optional[grpc.Channel] = None
self.node = Node(node_id=SUPERLINK_NODE_ID)
self._retry_invoker = _make_simple_grpc_retry_invoker()
self._message_ids: list[str] = []

@property
def _is_connected(self) -> bool:
Expand Down Expand Up @@ -137,6 +138,11 @@ def _stub(self) -> ServerAppIoStub:
self._connect()
return cast(ServerAppIoStub, self._grpc_stub)

@property
def message_ids(self) -> ReadOnlyList:
"""Message IDs of pushed messages."""
return ReadOnlyList(self._message_ids)

def _check_message(self, message: Message) -> None:
# Check if the message is valid
if not (
Expand Down Expand Up @@ -221,24 +227,43 @@ def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
"list has `None` for those messages (the order is preserved as passed "
"to `push_messages`). This could be due to a malformed message.",
)
return list(res.message_ids)
# Store message IDs
msg_ids = list(res.message_ids)
self._message_ids = msg_ids
return msg_ids

def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]:
"""Pull messages based on message IDs.
def pull_messages(
self, message_ids: Optional[Iterable[str]] = None
) -> Iterable[Message]:
"""Pull messages from the SuperLink based on message IDs.

This method is used to collect messages from the SuperLink that correspond to a
set of given message IDs.
set of given message IDs. If no message IDs are provided, it defaults to the
stored message IDs.
"""
# Pull Messages
res: PullResMessagesResponse = self._stub.PullMessages(
PullResMessagesRequest(
message_ids=message_ids,
run_id=cast(Run, self._run).run_id,
)
# Raise an error if no message IDs are provided and none are stored
if not self._message_ids:
raise ValueError("No message IDs to pull. Call `push_messages` first.")

# Allow an override but default to the stored pending IDs
message_ids_to_pull = (
message_ids if message_ids is not None else self._message_ids
)
# Convert Message from Protobuf representation
msgs = [message_from_proto(msg_proto) for msg_proto in res.messages_list]
return msgs

def iter_msg() -> Generator[Message, None, None]:
for msg_id in message_ids_to_pull:
# Pull Messages for each message ID
res: PullResMessagesResponse = self._stub.PullMessages(
PullResMessagesRequest(
message_ids=[msg_id],
run_id=cast(Run, self._run).run_id,
)
)
# Convert Message from Protobuf representation
msg = message_from_proto(res.messages_list[0])
yield msg

return iter_msg()

def send_and_receive(
self,
Expand Down
Loading