-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Expand file tree
/
Copy pathgrpc_driver.py
More file actions
317 lines (282 loc) · 11.8 KB
/
grpc_driver.py
File metadata and controls
317 lines (282 loc) · 11.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Flower gRPC Driver."""
import time
import warnings
from collections.abc import Iterable, Iterator
from logging import DEBUG, WARNING
from typing import Optional, cast
import grpc
from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
from flwr.common.constant import (
SERVERAPPIO_API_DEFAULT_CLIENT_ADDRESS,
SUPERLINK_NODE_ID,
)
from flwr.common.grpc import create_channel, on_channel_state_change
from flwr.common.logger import log
from flwr.common.retry_invoker import _make_simple_grpc_retry_invoker, _wrap_stub
from flwr.common.serde import message_from_proto, message_to_proto, run_from_proto
from flwr.common.typing import Run
from flwr.proto.message_pb2 import Message as ProtoMessage # pylint: disable=E0611
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
from flwr.proto.serverappio_pb2 import ( # pylint: disable=E0611
GetNodesRequest,
GetNodesResponse,
PullResMessagesRequest,
PullResMessagesResponse,
PushInsMessagesRequest,
PushInsMessagesResponse,
)
from flwr.proto.serverappio_pb2_grpc import ServerAppIoStub # pylint: disable=E0611
from .driver import Driver
# Error text describing use of the driver before a connection has been
# established.
# NOTE(review): the wording refers to `connect()` and `GrpcDriverStub`, whereas
# the class defined in this module is `GrpcDriver` with a private `_connect()`
# that is invoked lazily; the constant is also not referenced in the part of the
# module visible here -- confirm whether the text (or the constant) is still
# current before relying on it.
ERROR_MESSAGE_DRIVER_NOT_CONNECTED = """
[flwr-serverapp] Error: Not connected.
Call `connect()` on the `GrpcDriverStub` instance before calling any of the other
`GrpcDriverStub` methods.
"""
class GrpcDriver(Driver):  # pylint: disable=too-many-instance-attributes
    """`GrpcDriver` provides an interface to the ServerAppIo API.

    The gRPC channel is opened lazily the first time the stub is needed and
    is released again by `close()`.

    Parameters
    ----------
    serverappio_service_address : str
        (default: `SERVERAPPIO_API_DEFAULT_CLIENT_ADDRESS`)
        The address (URL, IPv6, IPv4) of the SuperLink ServerAppIo API service.
    root_certificates : Optional[bytes] (default: None)
        The PEM-encoded root certificates as a byte string.
        If provided, a secure connection using the certificates will be
        established to an SSL-enabled Flower server.
    """

    def __init__(  # pylint: disable=too-many-arguments
        self,
        serverappio_service_address: str = SERVERAPPIO_API_DEFAULT_CLIENT_ADDRESS,
        root_certificates: Optional[bytes] = None,
    ) -> None:
        self._addr = serverappio_service_address
        self._cert = root_certificates
        # Populated by `set_run`; most methods assume it has been called.
        self._run: Optional[Run] = None
        self._grpc_stub: Optional[ServerAppIoStub] = None
        self._channel: Optional[grpc.Channel] = None
        self.node = Node(node_id=SUPERLINK_NODE_ID)
        self._retry_invoker = _make_simple_grpc_retry_invoker()
        # IDs of pushed messages whose replies have not been pulled yet.
        self._message_ids: set[str] = set()

    @property
    def _is_connected(self) -> bool:
        """Check if connected to the ServerAppIo API server."""
        return self._channel is not None

    def _connect(self) -> None:
        """Connect to the ServerAppIo API.

        This will not call GetRun.
        """
        if self._is_connected:
            log(WARNING, "Already connected")
            return
        # Without root certificates the channel is created as insecure.
        self._channel = create_channel(
            server_address=self._addr,
            insecure=(self._cert is None),
            root_certificates=self._cert,
        )
        self._channel.subscribe(on_channel_state_change)
        self._grpc_stub = ServerAppIoStub(self._channel)
        _wrap_stub(self._grpc_stub, self._retry_invoker)
        log(DEBUG, "[flwr-serverapp] Connected to %s", self._addr)

    def _disconnect(self) -> None:
        """Disconnect from the ServerAppIo API."""
        if not self._is_connected:
            log(DEBUG, "Already disconnected")
            return
        channel = cast(grpc.Channel, self._channel)
        # Reset the state first so the driver reads as disconnected even if
        # closing the channel raises.
        self._channel = None
        self._grpc_stub = None
        channel.close()
        log(DEBUG, "[flwr-serverapp] Disconnected")

    def set_run(self, run_id: int) -> None:
        """Set the run.

        Parameters
        ----------
        run_id : int
            ID of the run to fetch via `GetRun` and bind to this driver.

        Raises
        ------
        RuntimeError
            If the response does not contain a run for `run_id`.
        """
        # Get the run info
        req = GetRunRequest(run_id=run_id)
        res: GetRunResponse = self._stub.GetRun(req)
        if not res.HasField("run"):
            raise RuntimeError(f"Cannot find the run with ID: {run_id}")
        self._run = run_from_proto(res.run)

    @property
    def run(self) -> Run:
        """Run information.

        A new `Run` built from the stored one is returned, so rebinding
        attributes on the result does not affect the driver. `set_run` must
        have been called before.
        """
        return Run(**vars(self._run))

    @property
    def _stub(self) -> ServerAppIoStub:
        """ServerAppIo stub (connects on first use)."""
        if not self._is_connected:
            self._connect()
        return cast(ServerAppIoStub, self._grpc_stub)

    @property
    def message_ids(self) -> Iterable[str]:
        """Message IDs of pushed messages (returned as a copy)."""
        return self._message_ids.copy()

    def _check_message(self, message: Message) -> None:
        """Validate a message before it is pushed.

        Raises
        ------
        ValueError
            If the message does not belong to the current run, does not
            originate from this driver's node, already carries a message ID
            or a reply reference, or has a non-positive TTL.
        """
        metadata = message.metadata
        if not (
            # Assume self._run being initialized
            metadata.run_id == cast(Run, self._run).run_id
            and metadata.src_node_id == self.node.node_id
            and metadata.message_id == ""
            and metadata.reply_to_message == ""
            and metadata.ttl > 0
        ):
            raise ValueError(f"Invalid message: {message}")

    def create_message(  # pylint: disable=too-many-arguments,R0917
        self,
        content: RecordSet,
        message_type: str,
        dst_node_id: int,
        group_id: str,
        ttl: Optional[float] = None,
    ) -> Message:
        """Create a new message with specified parameters.

        This method constructs a new `Message` with given content and metadata.
        The `run_id` and `src_node_id` will be set automatically.

        Parameters
        ----------
        content : RecordSet
            The content of the new message.
        message_type : str
            The type of the message, stored in its metadata.
        dst_node_id : int
            The ID of the destination node.
        group_id : str
            The group ID stored in the message metadata.
        ttl : Optional[float] (default: None)
            Time-to-live of the message. If `None`, `DEFAULT_TTL` is used.

        Returns
        -------
        Message
            A new `Message` with the given content and the metadata above.
        """
        if ttl:
            warnings.warn(
                "A custom TTL was set, but note that the SuperLink does not enforce "
                "the TTL yet. The SuperLink will start enforcing the TTL in a future "
                "version of Flower.",
                stacklevel=2,
            )
        ttl_ = DEFAULT_TTL if ttl is None else ttl
        metadata = Metadata(
            run_id=cast(Run, self._run).run_id,
            message_id="",  # Will be set by the server
            src_node_id=self.node.node_id,
            dst_node_id=dst_node_id,
            reply_to_message="",
            group_id=group_id,
            ttl=ttl_,
            message_type=message_type,
        )
        return Message(metadata=metadata, content=content)

    def get_node_ids(self) -> Iterable[int]:
        """Get node IDs."""
        # Ask the ServerAppIo API for the nodes of the current run
        res: GetNodesResponse = self._stub.GetNodes(
            GetNodesRequest(run_id=cast(Run, self._run).run_id)
        )
        return [node.node_id for node in res.nodes]

    def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
        """Push messages to specified node IDs.

        This method takes an iterable of messages and sends each message
        to the node specified in `dst_node_id`.

        Returns
        -------
        Iterable[str]
            The message IDs from the response, unchanged. Falsy entries mark
            messages that were not accepted; only the non-empty IDs are
            remembered for later pulling.

        Raises
        ------
        ValueError
            If any message fails validation (nothing is sent in that case).
        """
        # Validate and convert every message before sending anything
        message_proto_list: list[ProtoMessage] = []
        for msg in messages:
            self._check_message(msg)
            message_proto_list.append(message_to_proto(msg))
        # Push all messages in a single request
        res: PushInsMessagesResponse = self._stub.PushMessages(
            PushInsMessagesRequest(
                messages_list=message_proto_list,
                run_id=cast(Run, self._run).run_id,
            )
        )
        pushed_ids = [msg_id for msg_id in res.message_ids if msg_id]
        if len(pushed_ids) != len(message_proto_list):
            log(
                WARNING,
                "Not all messages could be pushed to the SuperLink. The returned "
                "list has `None` for those messages (the order is preserved as passed "
                "to `push_messages`). This could be due to a malformed message.",
            )
        # Store only real IDs: a placeholder for a rejected message can never
        # receive a reply and would otherwise stay pending forever.
        self._message_ids.update(pushed_ids)
        return list(res.message_ids)

    def pull_messages(
        self, message_ids: Optional[Iterable[str]] = None
    ) -> Iterable[Message]:
        """Pull messages based on message IDs.

        This method is used to collect messages from the SuperLink that correspond to a
        set of given message IDs. If no message IDs are provided, it defaults to the
        stored message IDs.

        Parameters
        ----------
        message_ids : Optional[Iterable[str]] (default: None)
            IDs of previously pushed messages to pull replies for. IDs that
            are not stored on this driver are skipped with a warning.

        Returns
        -------
        Iterable[Message]
            A lazy iterator: one request is issued per message ID while it is
            consumed, and each ID is forgotten once its reply was yielded.

        Raises
        ------
        ValueError
            If no message IDs are stored at the time of the call.
        """
        # Only stored IDs can be pulled, so without any there is nothing to do,
        # regardless of whether `message_ids` was given
        if not self._message_ids:
            raise ValueError("No message IDs to pull. Call `push_messages` first.")
        # Allow an override but default to the stored pending IDs
        if message_ids is None:
            # If no message_ids are provided, use the stored ones
            msg_ids_to_pull = self._message_ids
        else:
            # Else, keep the IDs (from the given IDs) that are in `self._message_ids`
            provided_ids = set(message_ids)
            msg_ids_to_pull = provided_ids & self._message_ids
            if missing_ids := provided_ids - msg_ids_to_pull:
                log(
                    WARNING,
                    "Cannot pull messages for the following missing message IDs: %s",
                    missing_ids,
                )

        def iter_msg() -> Iterator[Message]:
            # `sorted` takes a snapshot, so the stored set may shrink safely
            # while iterating
            for msg_id in sorted(msg_ids_to_pull):
                # Pull a Message for each message ID
                res: PullResMessagesResponse = self._stub.PullMessages(
                    PullResMessagesRequest(
                        message_ids=[msg_id],
                        run_id=cast(Run, self._run).run_id,
                    )
                )
                # Skip IDs whose reply is not available yet
                if not res.messages_list:
                    continue
                # Convert Message from Protobuf representation
                msg = message_from_proto(res.messages_list[0])
                # Forget the ID once its reply was pulled; `discard` keeps the
                # iterator alive if the ID was already removed elsewhere
                self._message_ids.discard(msg.metadata.reply_to_message)
                yield msg

        return iter_msg()

    def send_and_receive(
        self,
        messages: Iterable[Message],
        *,
        timeout: Optional[float] = None,
    ) -> Iterable[Message]:
        """Push messages to specified node IDs and pull the reply messages.

        This method sends a list of messages to their destination node IDs and then
        waits for the replies. It continues to pull replies until either all replies are
        received or the specified timeout duration is exceeded.

        Parameters
        ----------
        messages : Iterable[Message]
            The messages to push.
        timeout : Optional[float] (default: None)
            Maximum number of seconds to keep pulling. If `None`, pulling
            continues until every accepted message has been answered.

        Returns
        -------
        Iterable[Message]
            The replies received so far (possibly fewer than pushed messages
            when the timeout is hit).
        """
        # Push messages; only IDs of accepted messages can ever be answered
        msg_ids = {msg_id for msg_id in self.push_messages(messages) if msg_id}
        ret: list[Message] = []
        if not msg_ids:
            # Nothing was accepted, so there is nothing to wait for
            return ret
        # Pull messages
        end_time = time.time() + (timeout if timeout is not None else 0.0)
        while timeout is None or time.time() < end_time:
            res_msgs = list(self.pull_messages(msg_ids))
            ret.extend(res_msgs)
            msg_ids.difference_update(
                {msg.metadata.reply_to_message for msg in res_msgs}
            )
            if not msg_ids:
                break
            # Wait before polling again
            time.sleep(3)
        return ret

    def close(self) -> None:
        """Disconnect from the SuperLink if connected."""
        # Nothing to do when no channel was ever opened
        if not self._is_connected:
            return
        # Disconnect
        self._disconnect()