Skip to content
116 changes: 116 additions & 0 deletions bellows/ezsp/fragmentation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
"""Implements APS fragmentation reassembly on the EZSP Host side,
mirroring the logic from fragmentation.c in the EmberZNet stack.
"""

from __future__ import annotations

import asyncio
import logging

# Module-level logger for fragment-reassembly diagnostics.
LOGGER = logging.getLogger(__name__)

# The maximum time (in seconds) we wait for all fragments of a given message.
# If not all fragments arrive within this time, we discard the partial data.
FRAGMENT_TIMEOUT = 10

# store partial data keyed by (sender, aps_sequence, profile_id, cluster_id)
FragmentKey = tuple[int, int, int, int]


class _FragmentEntry:
def __init__(self, fragment_count: int):
self.fragment_count = fragment_count
self.fragments_received = 0
self.fragment_data = {}

def add_fragment(self, index: int, data: bytes) -> None:
if index not in self.fragment_data:
self.fragment_data[index] = data
self.fragments_received += 1

def is_complete(self) -> bool:
return self.fragments_received == self.fragment_count

def assemble(self) -> bytes:
return b"".join(
self.fragment_data[i] for i in sorted(self.fragment_data.keys())
)


class FragmentManager:
    """Tracks in-flight fragmented APS messages and reassembles them.

    Mirrors the reassembly logic of ``fragmentation.c`` in the EmberZNet
    stack: fragments are collected per (sender, APS sequence, profile,
    cluster) key, and partial data is discarded if the message does not
    complete within ``FRAGMENT_TIMEOUT`` seconds of the last fragment.
    """

    def __init__(self):
        # Partial reassembly state, keyed by
        # (sender, aps_sequence, profile_id, cluster_id).
        self._partial: dict[FragmentKey, _FragmentEntry] = {}
        # One pending timeout timer per in-progress reassembly.
        self._cleanup_timers: dict[FragmentKey, asyncio.TimerHandle] = {}

    def handle_incoming_fragment(
        self,
        sender_nwk: int,
        aps_sequence: int,
        profile_id: int,
        cluster_id: int,
        fragment_count: int,
        fragment_index: int,
        payload: bytes,
    ) -> tuple[bool, bytes | None, int, int]:
        """Handle a newly received fragment.

        :param sender_nwk: NWK address or the short ID of the sender.
        :param aps_sequence: The APS sequence from the incoming APS frame.
        :param profile_id: The APS frame's profileId.
        :param cluster_id: The APS frame's clusterId.
        :param fragment_count: The total number of expected message fragments.
        :param fragment_index: The index of the current fragment being processed.
        :param payload: The fragment of data for this message.
        :return: (complete, reassembled_data, fragment_count, fragment_index)
            complete = True if we have all fragments now, else False
            reassembled_data = the final complete payload (bytes) if complete is True
            fragment_count = the total number of fragments holding the complete packet
            fragment_index = the index of the current received fragment
        """

        key: FragmentKey = (sender_nwk, aps_sequence, profile_id, cluster_id)

        # If we have never seen this message, create a reassembly entry.
        if key not in self._partial:
            entry = _FragmentEntry(fragment_count)
            self._partial[key] = entry
        else:
            entry = self._partial[key]

        LOGGER.debug(
            "Received fragment %d/%d from %s (APS seq=%d, cluster=0x%04X)",
            fragment_index + 1,
            fragment_count,
            sender_nwk,
            aps_sequence,
            cluster_id,
        )

        entry.add_fragment(fragment_index, payload)

        # Restart the inactivity timer: cancel the timer armed by a previous
        # fragment so partial data is discarded only FRAGMENT_TIMEOUT seconds
        # after the *most recent* fragment, not after the first one. Without
        # this, the first fragment's timer would fire mid-reassembly and
        # silently drop the partial message.
        old_timer = self._cleanup_timers.pop(key, None)
        if old_timer is not None:
            old_timer.cancel()
        loop = asyncio.get_running_loop()
        self._cleanup_timers[key] = loop.call_later(
            FRAGMENT_TIMEOUT, self.cleanup_partial, key
        )

        if entry.is_complete():
            reassembled = entry.assemble()
            # Drop reassembly state and the now-unneeded timeout timer.
            del self._partial[key]
            timer = self._cleanup_timers.pop(key, None)
            if timer:
                timer.cancel()
            LOGGER.debug(
                "Message reassembly complete. Total length=%d", len(reassembled)
            )
            return (True, reassembled, fragment_count, fragment_index)
        else:
            return (False, None, fragment_count, fragment_index)

    def cleanup_partial(self, key: FragmentKey) -> None:
        """Discard the partial reassembly state for *key* after a timeout."""
        # Called when FRAGMENT_TIMEOUT passes with no new fragments for that key.
        LOGGER.debug(
            "Timeout for partial reassembly of fragmented message, discarding key=%s",
            key,
        )
        self._partial.pop(key, None)
        self._cleanup_timers.pop(key, None)
75 changes: 75 additions & 0 deletions bellows/ezsp/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from bellows.config import CONF_EZSP_POLICIES
from bellows.exception import InvalidCommandError
from bellows.ezsp.fragmentation import FragmentManager
import bellows.types as t

if TYPE_CHECKING:
Expand Down Expand Up @@ -53,6 +54,8 @@ def __init__(self, cb_handler: Callable, gateway: Gateway) -> None:

# Cached by `set_extended_timeout` so subsequent calls are a little faster
self._address_table_size: int | None = None
self._fragment_manager = FragmentManager()
self._fragment_ack_tasks: set[asyncio.Task] = set()

def _ezsp_frame(self, name: str, *args: Any, **kwargs: Any) -> bytes:
"""Serialize the named frame and data."""
Expand Down Expand Up @@ -181,6 +184,52 @@ def __call__(self, data: bytes) -> None:
if data:
LOGGER.debug("Frame contains trailing data: %s", data)

if (
frame_name == "incomingMessageHandler"
and result[1].options & t.EmberApsOption.APS_OPTION_FRAGMENT
):
# Extract received APS frame and sender
aps_frame = result[1]
sender = result[4]

# The fragment count and index are encoded in the groupId field
fragment_count = (aps_frame.groupId >> 8) & 0xFF
fragment_index = aps_frame.groupId & 0xFF

(
complete,
reassembled,
frag_count,
frag_index,
) = self._fragment_manager.handle_incoming_fragment(
sender_nwk=sender,
aps_sequence=aps_frame.sequence,
profile_id=aps_frame.profileId,
cluster_id=aps_frame.clusterId,
fragment_count=fragment_count,
fragment_index=fragment_index,
payload=result[7],
)

ack_task = asyncio.create_task(
self._send_fragment_ack(sender, aps_frame, frag_count, frag_index)
) # APS Ack

self._fragment_ack_tasks.add(ack_task)
ack_task.add_done_callback(lambda t: self._fragment_ack_tasks.discard(t))

if not complete:
# Do not pass partial data up the stack
LOGGER.debug("Fragment reassembly not complete. waiting for more data.")
return

# Replace partial data with fully reassembled data
result[7] = reassembled

LOGGER.debug(
"Reassembled fragmented message. Proceeding with normal handling."
)

if sequence in self._awaiting:
expected_id, schema, future = self._awaiting.pop(sequence)
try:
Expand All @@ -205,6 +254,32 @@ def __call__(self, data: bytes) -> None:
else:
self._handle_callback(frame_name, result)

async def _send_fragment_ack(
    self,
    sender: int,
    incoming_aps: t.EmberApsFrame,
    fragment_count: int,
    fragment_index: int,
) -> t.EmberStatus:
    """Acknowledge one received fragment back to its sender.

    The ack frame mirrors the incoming APS frame with the source and
    destination endpoints swapped; its ``groupId`` carries 0xFF in the
    high byte and the acknowledged fragment index in the low byte.
    """
    reply_frame = t.EmberApsFrame(
        profileId=incoming_aps.profileId,
        clusterId=incoming_aps.clusterId,
        sourceEndpoint=incoming_aps.destinationEndpoint,
        destinationEndpoint=incoming_aps.sourceEndpoint,
        options=incoming_aps.options,
        groupId=(0xFF00 | (fragment_index & 0xFF)),
        sequence=incoming_aps.sequence,
    )

    LOGGER.debug(
        "Sending fragment ack to 0x%04X for fragment index=%d/%d",
        sender,
        fragment_index + 1,
        fragment_count,
    )
    result = await self.sendReply(sender, reply_frame, b"")
    return result[0]

def __getattr__(self, name: str) -> Callable:
if name not in self.COMMANDS:
raise AttributeError(f"{name} not found in COMMANDS")
Expand Down
157 changes: 157 additions & 0 deletions tests/test_ezsp_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,3 +133,160 @@ async def test_parsing_schema_response(prot_hndl_v9):

rsp = await coro
assert rsp == GetTokenDataRsp(status=t.EmberStatus.LIBRARY_NOT_PRESENT)


async def test_send_fragment_ack(prot_hndl, caplog):
    """Verify _send_fragment_ack builds the mirrored ack frame and sends it."""
    sender = 0x1D6F
    fragment_count = 2
    fragment_index = 0

    incoming_aps = t.EmberApsFrame(
        profileId=260,
        clusterId=65281,
        sourceEndpoint=2,
        destinationEndpoint=2,
        options=33088,
        groupId=512,
        sequence=238,
    )
    # Expected ack: same frame fields, groupId = 0xFF00 | fragment index.
    expected_ack_frame = t.EmberApsFrame(
        profileId=260,
        clusterId=65281,
        sourceEndpoint=2,
        destinationEndpoint=2,
        options=33088,
        groupId=((0xFF00) | (fragment_index & 0xFF)),
        sequence=238,
    )

    caplog.set_level(logging.DEBUG)
    with patch.object(prot_hndl, "sendReply", new=AsyncMock()) as mock_send_reply:
        mock_send_reply.return_value = (t.EmberStatus.SUCCESS,)

        status = await prot_hndl._send_fragment_ack(
            sender, incoming_aps, fragment_count, fragment_index
        )

    # Assertions
    assert status == t.EmberStatus.SUCCESS
    mock_send_reply.assert_called_once_with(sender, expected_ack_frame, b"")
    assert (
        "Sending fragment ack to 0x1d6f for fragment index=1/2".lower()
        in caplog.text.lower()
    )


async def test_incoming_fragmented_message_incomplete(prot_hndl, caplog):
    """Test handling of an incomplete fragmented message.

    Feeds a single raw `incomingMessageHandler` frame carrying fragment 0 of 2
    and checks that an ack is sent but nothing is passed up the stack.
    """
    # Raw EZSP callback frame: fragment index 0 of 2, payload b"\xdd".
    packet = b"\x90\x01\x45\x00\x05\x01\x01\xff\x02\x02\x40\x81\x00\x02\xee\xff\xf8\x6f\x1d\xff\xff\x01\xdd"

    # Parse packet manually to extract parameters for assertions
    sender = 0x1D6F
    aps_frame = t.EmberApsFrame(
        profileId=261,  # 0x0105
        clusterId=65281,  # 0xFF01
        sourceEndpoint=2,  # 0x02
        destinationEndpoint=2,  # 0x02
        options=33088,  # 0x8140 (APS_OPTION_FRAGMENT + others)
        groupId=512,  # 0x0200 (fragment_count=2, fragment_index=0)
        sequence=238,  # 0xEE
    )

    with patch.object(prot_hndl, "_send_fragment_ack", new=AsyncMock()) as mock_ack:
        mock_ack.return_value = None

        caplog.set_level(logging.DEBUG)
        prot_hndl(packet)

        # The ack is sent from a background task; wait for it and check the
        # done-callback removed it from the tracking set.
        assert len(prot_hndl._fragment_ack_tasks) == 1
        ack_task = next(iter(prot_hndl._fragment_ack_tasks))
        await asyncio.gather(ack_task)  # Ensure task completes and triggers callback
        assert (
            len(prot_hndl._fragment_ack_tasks) == 0
        ), "Done callback should have removed task"

        # Partial data must not reach the callback handler.
        prot_hndl._handle_callback.assert_not_called()
        assert "Fragment reassembly not complete. waiting for more data." in caplog.text
        mock_ack.assert_called_once_with(sender, aps_frame, 2, 0)


async def test_incoming_fragmented_message_complete(prot_hndl, caplog):
    """Test handling of a complete fragmented message.

    Feeds both fragments of a two-fragment message and checks that each one
    is acked, and that the reassembled payload is delivered to the callback
    handler exactly once, after the final fragment.
    """
    packet1 = (
        b"\x90\x01\x45\x00\x04\x01\x01\xff\x02\x02\x40\x81\x00\x02\xee\xff\xf8\x6f\x1d\xff\xff\x09"
        + b"complete "
    )  # fragment index 0
    packet2 = (
        b"\x90\x01\x45\x00\x04\x01\x01\xff\x02\x02\x40\x81\x01\x02\xee\xff\xf8\x6f\x1d\xff\xff\x07"
        + b"message"
    )  # fragment index 1
    sender = 0x1D6F

    aps_frame_1 = t.EmberApsFrame(
        profileId=260,
        clusterId=65281,
        sourceEndpoint=2,
        destinationEndpoint=2,
        options=33088,  # Includes APS_OPTION_FRAGMENT
        groupId=512,  # 0x0200: fragment_count=2, fragment_index=0
        sequence=238,
    )
    aps_frame_2 = t.EmberApsFrame(
        profileId=260,
        clusterId=65281,
        sourceEndpoint=2,
        destinationEndpoint=2,
        options=33088,
        groupId=513,  # 0x0201: fragment_count=2, fragment_index=1
        sequence=238,
    )
    reassembled = b"complete message"

    with patch.object(prot_hndl, "_send_fragment_ack", new=AsyncMock()) as mock_ack:
        mock_ack.return_value = None
        caplog.set_level(logging.DEBUG)

        # Packet 1: first fragment is acked but not delivered upward.
        prot_hndl(packet1)
        assert len(prot_hndl._fragment_ack_tasks) == 1
        ack_task = next(iter(prot_hndl._fragment_ack_tasks))
        await asyncio.gather(ack_task)  # Ensure task completes and triggers callback
        assert (
            len(prot_hndl._fragment_ack_tasks) == 0
        ), "Done callback should have removed task"

        prot_hndl._handle_callback.assert_not_called()
        assert (
            "Reassembled fragmented message. Proceeding with normal handling."
            not in caplog.text
        )
        mock_ack.assert_called_with(sender, aps_frame_1, 2, 0)

        # Packet 2: final fragment completes reassembly and is delivered.
        prot_hndl(packet2)
        assert len(prot_hndl._fragment_ack_tasks) == 1
        ack_task = next(iter(prot_hndl._fragment_ack_tasks))
        await asyncio.gather(ack_task)  # Ensure task completes and triggers callback
        assert (
            len(prot_hndl._fragment_ack_tasks) == 0
        ), "Done callback should have removed task"

        # The callback receives the final fragment's frame with the payload
        # replaced by the fully reassembled message.
        prot_hndl._handle_callback.assert_called_once_with(
            "incomingMessageHandler",
            [
                t.EmberIncomingMessageType.INCOMING_UNICAST,  # 0x00
                aps_frame_2,  # Parsed APS frame
                255,  # lastHopLqi: 0xFF
                -8,  # lastHopRssi: 0xF8
                sender,  # 0x1D6F
                255,  # bindingIndex: 0xFF
                255,  # addressIndex: 0xFF
                reassembled,  # Reassembled payload
            ],
        )
        assert (
            "Reassembled fragmented message. Proceeding with normal handling."
            in caplog.text
        )
        mock_ack.assert_called_with(sender, aps_frame_2, 2, 1)
Loading
Loading