Skip to content

Commit 04ae89b

Browse files
committed
code cleanup and improved partial fragment removal
1 parent ea38115 commit 04ae89b

File tree

2 files changed

+92
-50
lines changed

2 files changed

+92
-50
lines changed

bellows/ezsp/fragmentation.py

Lines changed: 50 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
1-
"""
2-
Implements APS fragmentation reassembly on the EZSP Host side,
1+
"""Implements APS fragmentation reassembly on the EZSP Host side,
32
mirroring the logic from fragmentation.c in the EmberZNet stack.
43
"""
54

65
import asyncio
76
import logging
8-
from collections import defaultdict
9-
from typing import Optional, Dict, Tuple
7+
from typing import Dict, Optional, Tuple
108

119
LOGGER = logging.getLogger(__name__)
1210

@@ -17,13 +15,14 @@
1715
# store partial data keyed by (sender, aps_sequence, profile_id, cluster_id)
1816
FragmentKey = Tuple[int, int, int, int]
1917

18+
2019
class _FragmentEntry:
2120
def __init__(self, fragment_count: int):
2221
self.fragment_count = fragment_count
2322
self.fragments_received = 0
2423
self.fragment_data = {}
2524
self.start_time = asyncio.get_event_loop().time()
26-
25+
2726
def add_fragment(self, index: int, data: bytes) -> None:
2827
if index not in self.fragment_data:
2928
self.fragment_data[index] = data
@@ -33,32 +32,41 @@ def is_complete(self) -> bool:
3332
return self.fragments_received == self.fragment_count
3433

3534
def assemble(self) -> bytes:
36-
return b''.join(self.fragment_data[i] for i in sorted(self.fragment_data.keys()))
35+
return b"".join(
36+
self.fragment_data[i] for i in sorted(self.fragment_data.keys())
37+
)
38+
3739

3840
class FragmentManager:
3941
def __init__(self):
4042
self._partial: Dict[FragmentKey, _FragmentEntry] = {}
41-
42-
def handle_incoming_fragment(self, sender_nwk: int, aps_sequence: int, profile_id: int, cluster_id: int,
43-
group_id: int, payload: bytes) -> Tuple[bool, Optional[bytes], int, int]:
44-
"""
45-
Handle a newly received fragment. The group_id field
46-
encodes high byte = total fragment count, low byte = current fragment index.
43+
self._cleanup_timers: Dict[FragmentKey, asyncio.TimerHandle] = {}
44+
45+
def handle_incoming_fragment(
46+
self,
47+
sender_nwk: int,
48+
aps_sequence: int,
49+
profile_id: int,
50+
cluster_id: int,
51+
fragment_count: int,
52+
fragment_index: int,
53+
payload: bytes,
54+
) -> Tuple[bool, Optional[bytes], int, int]:
55+
"""Handle a newly received fragment.
4756
4857
:param sender_nwk: NWK address or the short ID of the sender.
4958
:param aps_sequence: The APS sequence from the incoming APS frame.
5059
:param profile_id: The APS frame's profileId.
5160
:param cluster_id: The APS frame's clusterId.
52-
:param group_id: The APS frame's groupId (used to store fragment # / total).
61+
:param fragment_count: The total number of expected message fragments.
62+
:param fragment_index: The index of the current fragment being processed.
5363
:param payload: The fragment of data for this message.
5464
:return: (complete, reassembled_data, fragment_count, fragment_index)
5565
complete = True if we have all fragments now, else False
5666
reassembled_data = the final complete payload (bytes) if complete is True
5767
fragment_coutn = the total number of fragments holding the complete packet
5868
fragment_index = the index of the current received fragment
5969
"""
60-
fragment_count = (group_id >> 8) & 0xFF
61-
fragment_index = group_id & 0xFF
6270

6371
key: FragmentKey = (sender_nwk, aps_sequence, profile_id, cluster_id)
6472

@@ -69,30 +77,44 @@ def handle_incoming_fragment(self, sender_nwk: int, aps_sequence: int, profile_i
6977
else:
7078
entry = self._partial[key]
7179

72-
LOGGER.debug("Received fragment %d/%d from %s (APS seq=%d, cluster=0x%04X)",
73-
fragment_index, fragment_count, sender_nwk, aps_sequence, cluster_id)
80+
LOGGER.debug(
81+
"Received fragment %d/%d from %s (APS seq=%d, cluster=0x%04X)",
82+
fragment_index + 1,
83+
fragment_count,
84+
sender_nwk,
85+
aps_sequence,
86+
cluster_id,
87+
)
7488

7589
entry.add_fragment(fragment_index, payload)
7690

91+
loop = asyncio.get_running_loop()
92+
self._cleanup_timers[key] = loop.call_later(
93+
FRAGMENT_TIMEOUT, self.cleanup_partial, key
94+
)
95+
7796
if entry.is_complete():
7897
reassembled = entry.assemble()
7998
del self._partial[key]
80-
LOGGER.debug("Message reassembly complete. Total length=%d", len(reassembled))
99+
timer = self._cleanup_timers.pop(key, None)
100+
if timer:
101+
timer.cancel()
102+
LOGGER.debug(
103+
"Message reassembly complete. Total length=%d", len(reassembled)
104+
)
81105
return (True, reassembled, fragment_count, fragment_index)
82106
else:
83107
return (False, None, fragment_count, fragment_index)
84108

85-
def cleanup_expired(self) -> None:
109+
def cleanup_partial(self, key: FragmentKey):
110+
# Called when FRAGMENT_TIMEOUT passes with no new fragments for that key.
111+
LOGGER.debug(
112+
"Timeout for partial reassembly of fragmented message, discarding key=%s",
113+
key,
114+
)
115+
self._partial.pop(key, None)
116+
self._cleanup_timers.pop(key, None)
86117

87-
now = asyncio.get_event_loop().time()
88-
to_remove = []
89-
for k, entry in self._partial.items():
90-
if now - entry.start_time > FRAGMENT_TIMEOUT:
91-
to_remove.append(k)
92-
for k in to_remove:
93-
del self._partial[k]
94-
LOGGER.debug("Removed stale fragment reassembly for key=%s", k)
95118

96119
# Create a single global manager instance
97120
fragment_manager = FragmentManager()
98-

bellows/ezsp/protocol.py

Lines changed: 42 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@ def __init__(self, cb_handler: Callable, gateway: Gateway) -> None:
5353

5454
# Cached by `set_extended_timeout` so subsequent calls are a little faster
5555
self._address_table_size: int | None = None
56-
self._cleanup_fragments_periodically()
5756

5857
def _ezsp_frame(self, name: str, *args: Any, **kwargs: Any) -> bytes:
5958
"""Serialize the named frame and data."""
@@ -182,27 +181,41 @@ def __call__(self, data: bytes) -> None:
182181
if data:
183182
LOGGER.debug("Frame contains trailing data: %s", data)
184183

185-
if frame_name == "incomingMessageHandler" and result[1].options & 0x8000: # incoming message with APS_OPTION_FRAGMENT raised
184+
if (
185+
frame_name == "incomingMessageHandler" and result[1].options & 0x8000
186+
): # incoming message with APS_OPTION_FRAGMENT raised
186187
from bellows.ezsp.fragmentation import fragment_manager
187188

188189
# Extract received APS frame and sender
189190
aps_frame = result[1]
190-
sender = result[4]
191+
sender = result[4]
191192

192193
group_id = aps_frame.groupId
193194
profile_id = aps_frame.profileId
194195
cluster_id = aps_frame.clusterId
195196
aps_seq = aps_frame.sequence
196-
197-
complete, reassembled, frag_count, frag_index = fragment_manager.handle_incoming_fragment(
197+
198+
fragment_count = (group_id >> 8) & 0xFF
199+
fragment_index = group_id & 0xFF
200+
201+
(
202+
complete,
203+
reassembled,
204+
frag_count,
205+
frag_index,
206+
) = fragment_manager.handle_incoming_fragment(
198207
sender_nwk=sender,
199208
aps_sequence=aps_seq,
200209
profile_id=profile_id,
201210
cluster_id=cluster_id,
202-
group_id=group_id,
203-
payload=result[7]
211+
fragment_count=fragment_count,
212+
fragment_index=fragment_index,
213+
payload=result[7],
204214
)
205-
asyncio.create_task(self._send_fragment_ack(sender, aps_frame, frag_count, frag_index)) # APS Ack
215+
ack_task = asyncio.create_task(
216+
self._send_fragment_ack(sender, aps_frame, frag_count, frag_index)
217+
) # APS Ack
218+
ack_task.add_done_callback(self._ack_tasks.remove)
206219

207220
if not complete:
208221
# Do not pass partial data up the stack
@@ -211,8 +224,10 @@ def __call__(self, data: bytes) -> None:
211224
else:
212225
# Replace partial data with fully reassembled data
213226
result[7] = reassembled
214-
215-
LOGGER.debug("Reassembled fragmented message. Proceeding with normal handling.")
227+
228+
LOGGER.debug(
229+
"Reassembled fragmented message. Proceeding with normal handling."
230+
)
216231

217232
if sequence in self._awaiting:
218233
expected_id, schema, future = self._awaiting.pop(sequence)
@@ -238,27 +253,32 @@ def __call__(self, data: bytes) -> None:
238253
else:
239254
self._handle_callback(frame_name, result)
240255

241-
async def _send_fragment_ack(self, sender: int, incoming_aps: t.EmberApsFrame, fragment_count: int, fragment_index: int):
242-
256+
async def _send_fragment_ack(
257+
self,
258+
sender: int,
259+
incoming_aps: t.EmberApsFrame,
260+
fragment_count: int,
261+
fragment_index: int,
262+
) -> t.EmberStatus:
243263
ackFrame = t.EmberApsFrame(
244264
profileId=incoming_aps.profileId,
245265
clusterId=incoming_aps.clusterId,
246266
sourceEndpoint=incoming_aps.destinationEndpoint,
247267
destinationEndpoint=incoming_aps.sourceEndpoint,
248268
options=incoming_aps.options,
249269
groupId=((0xFF00) | (fragment_index & 0xFF)),
250-
sequence=incoming_aps.sequence
270+
sequence=incoming_aps.sequence,
271+
)
272+
273+
LOGGER.debug(
274+
"Sending fragment ack to 0x%04X for fragment index=%d/%d",
275+
sender,
276+
fragment_index + 1,
277+
fragment_count,
251278
)
279+
status = await self.sendReply(sender, ackFrame, b"")
280+
return status
252281

253-
LOGGER.debug("Sending fragment ack to 0x%04X for fragment index=%d/%d", sender, fragment_index, fragment_count)
254-
await self.sendReply(sender, ackFrame, b'')
255-
256-
async def _cleanup_fragments_periodically(self):
257-
from bellows.ezsp.fragmentation import fragment_manager
258-
while True:
259-
await asyncio.sleep(5)
260-
fragment_manager.cleanup_expired()
261-
262282
def __getattr__(self, name: str) -> Callable:
263283
if name not in self.COMMANDS:
264284
raise AttributeError(f"{name} not found in COMMANDS")

0 commit comments

Comments
 (0)