Skip to content

Commit 4aafab4

Browse files
committed
refactor(pyfsd.protocol): line length => buffer size, use queue instead of lock
1 parent e399572 commit 4aafab4

File tree

2 files changed

+138
-130
lines changed

2 files changed

+138
-130
lines changed

src/pyfsd/protocol/__init__.py

+12-13
Original file line numberDiff line numberDiff line change
@@ -23,33 +23,32 @@ class LineReceiver(Protocol, metaclass=ABCMeta):
2323

2424
buffer: bytes = b""
2525
delimiter: bytes = b"\r\n"
26-
max_length: int = 1024*128 # 128kb
26+
buffer_size: int = 1024 * 512 # 512kb
2727

2828
@abstractmethod
2929
def line_received(self, line: bytes) -> None:
3030
"""Called when a line was received."""
3131
raise NotImplementedError
3232

3333
@abstractmethod
34-
def max_length_exceed(self, length: int) -> None:
35-
"""Called when line length exceed max length."""
34+
def buffer_size_exceed(self, length: int) -> None:
35+
"""Called when buffer exceed max size."""
3636
raise NotImplementedError
3737

3838
def data_received(self, data: bytes) -> None:
3939
"""Handle datas and call line_received as soon as we received a line."""
40-
if self.max_length != -1:
40+
if self.buffer_size != -1:
4141
length = len(self.buffer) + len(data)
42-
if length > self.max_length:
43-
self.max_length_exceed(length)
42+
if length > self.buffer_size:
43+
self.buffer_size_exceed(length)
4444

45-
if self.delimiter in data:
46-
*lines, left = data.split(self.delimiter)
47-
lines[0] = self.buffer + lines[0]
48-
self.buffer = left
45+
self.buffer += data
46+
del data
47+
if self.delimiter in self.buffer:
48+
*lines, left = self.buffer.split(self.delimiter)
4949
for line in lines:
5050
self.line_received(line)
51-
else:
52-
self.buffer += data
51+
self.buffer = left
5352

5453

5554
class LineProtocol(LineReceiver):
@@ -67,7 +66,7 @@ def connection_made(self, transport: "Transport") -> None: # type: ignore[overr
6766
self.transport = transport
6867

6968
# ruff: noqa: ARG002
70-
def max_length_exceed(self, length: int) -> None:
69+
def buffer_size_exceed(self, length: int) -> None:
7170
"""Kill when line length exceed max length."""
7271
self.transport.close()
7372

src/pyfsd/protocol/client.py

+126-117
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# ruff: noqa: S101
22
"""PyFSD client protocol."""
33

4-
from asyncio import Lock, create_task
4+
from asyncio import Queue, create_task
55
from asyncio import sleep as asleep
66
from collections.abc import Awaitable
77
from inspect import isawaitable
@@ -34,7 +34,12 @@
3434
break_packet,
3535
make_packet,
3636
)
37-
from pyfsd.define.utils import is_callsign_valid, str_to_float, str_to_int
37+
from pyfsd.define.utils import (
38+
is_callsign_valid,
39+
mustdone_task_keeper,
40+
str_to_float,
41+
str_to_int,
42+
)
3843
from pyfsd.object.client import Client, ClientType
3944

4045
from . import LineProtocol
@@ -142,28 +147,123 @@ class ClientProtocol(LineProtocol):
142147

143148
factory: "ClientFactory"
144149
timeout_killer_task: Optional["Task[None]"]
150+
worker_task: Optional["Task[None]"]
151+
worker_queue: Queue[bytes]
145152
transport: "Transport"
146-
tasks: set["Task"]
147153
client: Optional[Client]
148-
lock = Lock()
149154

150155
def __init__(self, factory: "ClientFactory") -> None:
151156
"""Create a ClientProtocol instance."""
152157
self.factory = factory
153-
self.tasks = set()
154158
self.client = None
155159
self.timeout_killer_task = None
156-
# timeout_killer_task and transport will be initialized in connection_made.
160+
self.worker_task = None
161+
self.worker_queue = Queue()
162+
super().__init__()
163+
# timeout_killer_task and worker_task and transport will be
164+
# initialized in connection_made.
165+
166+
async def handle_line_worker_func(self) -> None:
167+
"""Worker processes line."""
168+
result: "PyFSDHandledEventResult | PluginHandledEventResult" # noqa: UP037
169+
170+
while True:
171+
line = await self.worker_queue.get()
172+
173+
# First try to let plugins to process
174+
plugin_result = await self.factory.plugin_manager.trigger_event_handlers(
175+
"line_received_from_client",
176+
(self, line),
177+
{},
178+
)
179+
if plugin_result is None: # Not handled by plugin
180+
packet_ok, has_result = await self.handle_line(line)
181+
result = cast(
182+
"PyFSDHandledEventResult",
183+
{
184+
"handled_by_plugin": False,
185+
"success": packet_ok and has_result,
186+
"packet": line,
187+
"packet_ok": packet_ok,
188+
"has_result": has_result,
189+
},
190+
)
191+
else:
192+
result = plugin_result
193+
194+
self.factory.plugin_manager.trigger_event_auditers_nonblock(
195+
"line_received_from_client",
196+
(self, line, result),
197+
{},
198+
)
199+
200+
def connection_made(self, transport: "Transport") -> None: # type: ignore[override]
201+
"""Initialize something after the connection is made."""
202+
super().connection_made(transport)
203+
ip = self.transport.get_extra_info("peername")[0]
204+
if ip in self.factory.blacklist:
205+
logger.info("Kicking %s: blacklist", ip)
206+
self.transport.close()
207+
return
208+
209+
self.worker_task = create_task(self.handle_line_worker_func())
210+
self.reset_timeout_killer()
211+
logger.info("New connection from %s.", ip)
212+
self.factory.plugin_manager.trigger_event_auditers_nonblock(
213+
"new_connection_established", (self,), {}
214+
)
215+
216+
def line_received(self, line: bytes) -> None:
217+
"""Handle a line."""
218+
self.reset_timeout_killer()
219+
self.worker_queue.put_nowait(line)
157220

158-
def max_length_exceed(self, length: int) -> None:
159-
"""Called when line length exceed max length."""
160-
logger.info("Kicking %s: max length exceeded", self.get_description())
161-
return super().max_length_exceed(length)
221+
def connection_lost(self, exc: Optional[BaseException] = None) -> None:
222+
"""Handle connection lost."""
223+
if self.timeout_killer_task:
224+
self.timeout_killer_task.cancel()
225+
self.timeout_killer_task = None
226+
if self.worker_task:
227+
self.worker_task.cancel()
228+
self.worker_task = None
162229

163-
def add_task(self, task: "Task") -> None:
164-
"""Store a task's strong reference to keep it away from disappear."""
165-
self.tasks.add(task)
166-
task.add_done_callback(self.tasks.discard)
230+
client = None
231+
if self.client is not None:
232+
self.factory.broadcast(
233+
make_packet(
234+
(
235+
FSDClientCommand.REMOVE_ATC
236+
if self.client.type == "ATC"
237+
else FSDClientCommand.REMOVE_PILOT
238+
)
239+
+ self.client.callsign,
240+
self.client.cid.encode(),
241+
),
242+
from_client=self.client,
243+
)
244+
del self.factory.clients[self.client.callsign]
245+
client = self.client
246+
logger.info(
247+
"%s disconnected%s.",
248+
self.get_description(),
249+
f" due to {exc}" if exc else "",
250+
)
251+
self.client = None
252+
253+
self.factory.plugin_manager.trigger_event_auditers_nonblock(
254+
"client_disconnected",
255+
(self, client),
256+
{},
257+
)
258+
259+
def buffer_size_exceed(self, length: int) -> None:
260+
"""Called when client exceed max buffer size."""
261+
logger.info(
262+
"Kicking %s: buffer size exceeded (%d)",
263+
self.get_description(),
264+
length,
265+
)
266+
return super().buffer_size_exceed(length)
167267

168268
def kill_after_1sec(self) -> None:
169269
"""Kill this client after 1 second by kill_func."""
@@ -172,7 +272,17 @@ async def kill() -> None:
172272
await asleep(1)
173273
self.transport.close()
174274

175-
self.add_task(create_task(kill()))
275+
mustdone_task_keeper.add(create_task(kill()))
276+
277+
def get_description(self) -> str:
278+
"""Get text description of this client."""
279+
if self.client is not None:
280+
return (
281+
cast("str", self.transport.get_extra_info("peername")[0])
282+
+ f" ({self.client.callsign.decode(errors='replace')})"
283+
)
284+
285+
return cast("str", self.transport.get_extra_info("peername")[0])
176286

177287
def reset_timeout_killer(self) -> None:
178288
"""Reset timeout killer."""
@@ -187,21 +297,6 @@ async def timeout_killer() -> None:
187297
self.timeout_killer_task.cancel()
188298
self.timeout_killer_task = create_task(timeout_killer())
189299

190-
def connection_made(self, transport: "Transport") -> None: # type: ignore[override]
191-
"""Initialize something after the connection is made."""
192-
super().connection_made(transport)
193-
ip = self.transport.get_extra_info("peername")[0]
194-
if ip in self.factory.blacklist:
195-
logger.info("Kicking %s: blacklist", ip)
196-
self.transport.close()
197-
return
198-
199-
self.reset_timeout_killer()
200-
logger.info("New connection from %s.", ip)
201-
self.factory.plugin_manager.trigger_event_auditers_nonblock(
202-
"new_connection_established", (self,), {}
203-
)
204-
205300
def send_error(
206301
self, errno: FSDClientError, *, env: bytes = b"", fatal: bool = False
207302
) -> None:
@@ -904,7 +999,7 @@ async def killer() -> None:
904999
await asleep(1)
9051000
transport_to_kill.close()
9061001

907-
self.add_task(create_task(killer()))
1002+
mustdone_task_keeper.add(create_task(killer()))
9081003

9091004
logger.info(
9101005
"Kicking %s: killed by %s",
@@ -914,46 +1009,6 @@ async def killer() -> None:
9141009
kill_it()
9151010
return True, True
9161011

917-
def line_received(self, line: bytes) -> None:
918-
"""Handle a line."""
919-
self.reset_timeout_killer()
920-
921-
async def handle() -> None:
922-
result: "PyFSDHandledEventResult | PluginHandledEventResult" # noqa: UP037
923-
# First try to let plugins to process
924-
plugin_result = await self.factory.plugin_manager.trigger_event_handlers(
925-
"line_received_from_client",
926-
(self, line),
927-
{},
928-
)
929-
if plugin_result is None: # Not handled by plugin
930-
packet_ok, has_result = await self.handle_line(line)
931-
result = cast(
932-
"PyFSDHandledEventResult",
933-
{
934-
"handled_by_plugin": False,
935-
"success": packet_ok and has_result,
936-
"packet": line,
937-
"packet_ok": packet_ok,
938-
"has_result": has_result,
939-
},
940-
)
941-
else:
942-
result = plugin_result
943-
944-
self.factory.plugin_manager.trigger_event_auditers_nonblock(
945-
"line_received_from_client",
946-
(self, line, result),
947-
{},
948-
)
949-
950-
async def do_after_before_done() -> None:
951-
"""Wait last task done then handle this."""
952-
async with self.lock:
953-
await handle()
954-
955-
self.add_task(create_task(do_after_before_done()))
956-
9571012
async def handle_line(
9581013
self,
9591014
byte_line: bytes,
@@ -1054,49 +1109,3 @@ async def handle_line(
10541109
return await self.handle_kill(packet)
10551110
self.send_error(FSDClientError.SYNTAX)
10561111
return False, False
1057-
1058-
def get_description(self) -> str:
1059-
"""Get text description of this client."""
1060-
if self.client is not None:
1061-
return (
1062-
cast("str", self.transport.get_extra_info("peername")[0])
1063-
+ f" ({self.client.callsign.decode(errors='replace')})"
1064-
)
1065-
1066-
return cast("str", self.transport.get_extra_info("peername")[0])
1067-
1068-
def connection_lost(self, exc: Optional[BaseException] = None) -> None:
1069-
"""Handle connection lost."""
1070-
if self.timeout_killer_task:
1071-
self.timeout_killer_task.cancel()
1072-
self.timeout_killer_task = None
1073-
for pending_task in self.tasks:
1074-
pending_task.cancel()
1075-
client = None
1076-
if self.client is not None:
1077-
self.factory.broadcast(
1078-
make_packet(
1079-
(
1080-
FSDClientCommand.REMOVE_ATC
1081-
if self.client.type == "ATC"
1082-
else FSDClientCommand.REMOVE_PILOT
1083-
)
1084-
+ self.client.callsign,
1085-
self.client.cid.encode(),
1086-
),
1087-
from_client=self.client,
1088-
)
1089-
del self.factory.clients[self.client.callsign]
1090-
client = self.client
1091-
logger.info(
1092-
"%s disconnected%s.",
1093-
self.get_description(),
1094-
f" due to {exc}" if exc else "",
1095-
)
1096-
self.client = None
1097-
1098-
self.factory.plugin_manager.trigger_event_auditers_nonblock(
1099-
"client_disconnected",
1100-
(self, client),
1101-
{},
1102-
)

0 commit comments

Comments
 (0)