Skip to content

Commit 52f5382

Browse files
committed
add basic fragmentation support
1 parent bdb7374 commit 52f5382

File tree

2 files changed

+151
-0
lines changed

2 files changed

+151
-0
lines changed

bellows/ezsp/fragmentation.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
"""
2+
Implements APS fragmentation reassembly on the EZSP Host side,
3+
mirroring the logic from fragmentation.c in the EmberZNet stack.
4+
"""
5+
6+
import asyncio
7+
import logging
8+
from collections import defaultdict
9+
from typing import Optional, Dict, Tuple
10+
11+
LOGGER = logging.getLogger(__name__)
12+
13+
# The maximum time (in seconds) we wait for all fragments of a given message.
14+
# If not all fragments arrive within this time, we discard the partial data.
15+
FRAGMENT_TIMEOUT = 10
16+
17+
# store partial data keyed by (sender, aps_sequence, profile_id, cluster_id)
18+
FragmentKey = Tuple[int, int, int, int]
19+
20+
class _FragmentEntry:
21+
def __init__(self, fragment_count: int):
22+
self.fragment_count = fragment_count
23+
self.fragments_received = 0
24+
self.fragment_data = {}
25+
self.start_time = asyncio.get_event_loop().time()
26+
27+
def add_fragment(self, index: int, data: bytes) -> None:
28+
if index not in self.fragment_data:
29+
self.fragment_data[index] = data
30+
self.fragments_received += 1
31+
32+
def is_complete(self) -> bool:
33+
return self.fragments_received == self.fragment_count
34+
35+
def assemble(self) -> bytes:
36+
return b''.join(self.fragment_data[i] for i in sorted(self.fragment_data.keys()))
37+
38+
class FragmentManager:
39+
def __init__(self):
40+
self._partial: Dict[FragmentKey, _FragmentEntry] = {}
41+
42+
def handle_incoming_fragment(self, sender_nwk: int, aps_sequence: int, profile_id: int, cluster_id: int,
43+
group_id: int, payload: bytes) -> Tuple[bool, Optional[bytes], int, int]:
44+
"""
45+
Handle a newly received fragment. The group_id field
46+
encodes high byte = total fragment count, low byte = current fragment index.
47+
48+
:param sender_nwk: NWK address or the short ID of the sender.
49+
:param aps_sequence: The APS sequence from the incoming APS frame.
50+
:param profile_id: The APS frame's profileId.
51+
:param cluster_id: The APS frame's clusterId.
52+
:param group_id: The APS frame's groupId (used to store fragment # / total).
53+
:param payload: The fragment of data for this message.
54+
:return: (complete, reassembled_data, fragment_count, fragment_index)
55+
complete = True if we have all fragments now, else False
56+
reassembled_data = the final complete payload (bytes) if complete is True
57+
fragment_coutn = the total number of fragments holding the complete packet
58+
fragment_index = the index of the current received fragment
59+
"""
60+
fragment_count = (group_id >> 8) & 0xFF
61+
fragment_index = group_id & 0xFF
62+
63+
key: FragmentKey = (sender_nwk, aps_sequence, profile_id, cluster_id)
64+
65+
# If we have never seen this message, create a reassembly entry.
66+
if key not in self._partial:
67+
entry = _FragmentEntry(fragment_count)
68+
self._partial[key] = entry
69+
else:
70+
entry = self._partial[key]
71+
72+
LOGGER.debug("Received fragment %d/%d from %s (APS seq=%d, cluster=0x%04X)",
73+
fragment_index, fragment_count, sender_nwk, aps_sequence, cluster_id)
74+
75+
entry.add_fragment(fragment_index, payload)
76+
77+
if entry.is_complete():
78+
reassembled = entry.assemble()
79+
del self._partial[key]
80+
LOGGER.debug("Message reassembly complete. Total length=%d", len(reassembled))
81+
return (True, reassembled, fragment_count, fragment_index)
82+
else:
83+
return (False, None, fragment_count, fragment_index)
84+
85+
def cleanup_expired(self) -> None:
86+
87+
now = asyncio.get_event_loop().time()
88+
to_remove = []
89+
for k, entry in self._partial.items():
90+
if now - entry.start_time > FRAGMENT_TIMEOUT:
91+
to_remove.append(k)
92+
for k in to_remove:
93+
del self._partial[k]
94+
LOGGER.debug("Removed stale fragment reassembly for key=%s", k)
95+
96+
# Create a single global manager instance
97+
fragment_manager = FragmentManager()
98+

bellows/ezsp/protocol.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,38 @@ def __call__(self, data: bytes) -> None:
181181
if data:
182182
LOGGER.debug("Frame contains trailing data: %s", data)
183183

184+
if frame_name == "incomingMessageHandler" and result[1].options & 0x8000: # incoming message with APS_OPTION_FRAGMENT raised
185+
from bellows.ezsp.fragmentation import fragment_manager
186+
187+
# Extract received APS frame and sender
188+
aps_frame = result[1]
189+
sender = result[4]
190+
191+
group_id = aps_frame.groupId
192+
profile_id = aps_frame.profileId
193+
cluster_id = aps_frame.clusterId
194+
aps_seq = aps_frame.sequence
195+
196+
complete, reassembled, frag_count, frag_index = fragment_manager.handle_incoming_fragment(
197+
sender_nwk=sender,
198+
aps_sequence=aps_seq,
199+
profile_id=profile_id,
200+
cluster_id=cluster_id,
201+
group_id=group_id,
202+
payload=result[7]
203+
)
204+
asyncio.create_task(self._send_fragment_ack(sender, aps_frame, frag_count, frag_index)) # APS Ack
205+
206+
if not complete:
207+
# Do not pass partial data up the stack
208+
LOGGER.debug("Fragment reassembly not complete. waiting for more data.")
209+
return
210+
else:
211+
# Replace partial data with fully reassembled data
212+
result[7] = reassembled
213+
214+
LOGGER.debug("Reassembled fragmented message. Proceeding with normal handling.")
215+
184216
if sequence in self._awaiting:
185217
expected_id, schema, future = self._awaiting.pop(sequence)
186218
try:
@@ -205,6 +237,27 @@ def __call__(self, data: bytes) -> None:
205237
else:
206238
self._handle_callback(frame_name, result)
207239

240+
async def _send_fragment_ack(self, sender: int, incoming_aps: t.EmberApsFrame, fragment_count: int, fragment_index: int):
241+
242+
ackFrame = t.EmberApsFrame(
243+
profileId=incoming_aps.profileId,
244+
clusterId=incoming_aps.clusterId,
245+
sourceEndpoint=incoming_aps.destinationEndpoint,
246+
destinationEndpoint=incoming_aps.sourceEndpoint,
247+
options=incoming_aps.options,
248+
groupId=((0xFF00) | (fragment_index & 0xFF)),
249+
sequence=incoming_aps.sequence
250+
)
251+
252+
LOGGER.debug("Sending fragment ack to 0x%04X for fragment index=%d/%d", sender, fragment_index, fragment_count)
253+
await self.sendReply(sender, ackFrame, b'')
254+
255+
async def _cleanup_fragments_periodically(self):
256+
from bellows.ezsp.fragmentation import fragment_manager
257+
while True:
258+
await asyncio.sleep(5)
259+
fragment_manager.cleanup_expired()
260+
208261
def __getattr__(self, name: str) -> Callable:
209262
if name not in self.COMMANDS:
210263
raise AttributeError(f"{name} not found in COMMANDS")

0 commit comments

Comments
 (0)