-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Expand file tree
/
Copy pathdriver.py
More file actions
171 lines (143 loc) · 5.97 KB
/
driver.py
File metadata and controls
171 lines (143 loc) · 5.97 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Driver (abstract base class)."""
from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import Optional
from flwr.common import Message, RecordSet
from flwr.common.typing import Run
class Driver(ABC):
    """Abstract base Driver class for the ServerAppIo API."""

    @abstractmethod
    def set_run(self, run_id: int) -> None:
        """Request a run from the SuperLink with a given `run_id`.

        If a Run with the specified `run_id` exists, a local Run
        object will be created. It enables further functionality
        in the driver, such as sending `Messages`.

        Parameters
        ----------
        run_id : int
            The `run_id` of the Run this Driver object operates in.
        """

    @property
    @abstractmethod
    def run(self) -> Run:
        """Run information.

        Returns
        -------
        run : Run
            The Run this Driver object operates in.
        """

    @property
    @abstractmethod
    def message_ids(self) -> Iterable[str]:
        """Message IDs of pushed messages.

        Returns
        -------
        message_ids : Iterable[str]
            An iterable of IDs of the messages that were pushed. These are the
            stored message IDs that `pull_messages` falls back to when it is
            called without explicit message IDs.
        """

    @abstractmethod
    def create_message(  # pylint: disable=too-many-arguments,R0917
        self,
        content: RecordSet,
        message_type: str,
        dst_node_id: int,
        group_id: str,
        ttl: Optional[float] = None,
    ) -> Message:
        """Create a new message with specified parameters.

        This method constructs a new `Message` with given content and metadata.
        The `run_id` and `src_node_id` will be set automatically.

        Parameters
        ----------
        content : RecordSet
            The content for the new message. This holds records that are to be sent
            to the destination node.
        message_type : str
            The type of the message, defining the action to be executed on
            the receiving end.
        dst_node_id : int
            The ID of the destination node to which the message is being sent.
        group_id : str
            The ID of the group to which this message is associated. In some settings,
            this is used as the FL round.
        ttl : Optional[float] (default: None)
            Time-to-live for the round trip of this message, i.e., the time from sending
            this message to receiving a reply. It specifies in seconds the duration for
            which the message and its potential reply are considered valid. If unset,
            the default TTL (i.e., `common.DEFAULT_TTL`) will be used.

        Returns
        -------
        message : Message
            A new `Message` instance with the specified content and metadata.
        """

    @abstractmethod
    def get_node_ids(self) -> Iterable[int]:
        """Get node IDs.

        Returns
        -------
        node_ids : Iterable[int]
            An iterable of node IDs.
        """

    @abstractmethod
    def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
        """Push messages to specified node IDs.

        This method takes an iterable of messages and sends each message
        to the node specified in `dst_node_id`.

        Parameters
        ----------
        messages : Iterable[Message]
            An iterable of messages to be sent.

        Returns
        -------
        message_ids : Iterable[str]
            An iterable of IDs for the messages that were sent, which can be used
            to pull replies.
        """

    @abstractmethod
    def pull_messages(
        self, message_ids: Optional[Iterable[str]] = None
    ) -> Iterable[Message]:
        """Pull messages from the SuperLink based on message IDs.

        This method is used to collect messages from the SuperLink that correspond to a
        set of given message IDs. If no message IDs are provided, it defaults to the
        stored message IDs.

        Parameters
        ----------
        message_ids : Optional[Iterable[str]] (default: None)
            An iterable of message IDs for which reply messages are to be retrieved.
            If specified, the method will only pull messages that correspond to these
            IDs. If `None`, the stored message IDs are used instead.

        Returns
        -------
        messages : Iterable[Message]
            An iterable of messages received.
        """

    @abstractmethod
    def send_and_receive(
        self,
        messages: Iterable[Message],
        *,
        timeout: Optional[float] = None,
    ) -> Iterable[Message]:
        """Push messages to specified node IDs and pull the reply messages.

        This method sends a list of messages to their destination node IDs and then
        waits for the replies. It continues to pull replies until either all
        replies are received or the specified timeout duration is exceeded.

        Parameters
        ----------
        messages : Iterable[Message]
            An iterable of messages to be sent.
        timeout : Optional[float] (default: None)
            The timeout duration in seconds. If specified, the method will wait for
            replies for this duration. If `None`, there is no time limit and the method
            will wait until replies for all messages are received.

        Returns
        -------
        replies : Iterable[Message]
            An iterable of reply messages received from the SuperLink.

        Notes
        -----
        This method uses `push_messages` to send the messages and `pull_messages`
        to collect the replies. If `timeout` is set, the method may not return
        replies for all sent messages. A message remains valid until its TTL,
        which is not affected by `timeout`.
        """