Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dev-requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
ruff
micropython-esp32-stubs~=1.24.1
6 changes: 3 additions & 3 deletions mqterm/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,17 +118,17 @@ class ListDirJob(Job):
def output(self):
import os

return BytesIO("\n".join(os.listdir(self.args[0])).encode("utf-8"))
return BytesIO("\n".join(sorted(os.listdir(self.args[0]))).encode("utf-8"))


class PutFileJob(SequentialJob):
"""A job to stream a file from another client to the device."""

argc = 2
argc = 1

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.file = open(self.args[1], "wb")
self.file = open(self.args[0], "wb")
self.bytes_written = 0

async def update(self, payload, seq):
Expand Down
96 changes: 45 additions & 51 deletions mqterm/terminal.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,61 @@
import logging

from amqc.properties import CORRELATION_DATA, USER_PROPERTY

from mqterm.jobs import Job


def format_topic(*parts):
"""Create a slash-delimited MQTT topic from a list of strings."""
return "/" + "/".join(part.strip("/") for part in parts if part)


def format_properties(client_id, seq):
"""Format MQTT properties for a message."""
return {
CORRELATION_DATA: client_id.encode("utf-8"),
USER_PROPERTY: {"seq": str(seq)},
}


def parse_client_id(properties):
"""Extract the client ID from MQTT properties."""
client_id = properties.get(CORRELATION_DATA, None)
if not client_id:
raise ValueError("Missing client ID")
return client_id.decode("utf-8")


def parse_seq(properties):
"""Extract the sequence number from MQTT User Properties."""
user_properties = properties.get(USER_PROPERTY, {})
seq = user_properties.get("seq", None)
if not seq:
raise ValueError("Missing sequence information")
try:
seq = int(seq)
except TypeError:
raise ValueError(f"Invalid sequence information: {seq}")
return seq


class MqttTerminal:
PKTLEN = 1400 # data bytes that reasonably fit into a TCP packet
BUFLEN = PKTLEN * 2 # payload size for MQTT messages

PROP_USER = 0x26 # mqtt user properties
PROP_CORR = 0x09 # mqtt correlation data

def __init__(
self, mqtt_client, topic_prefix=None, logger=logging.getLogger("mqterm")
):
self.mqtt_client = mqtt_client
self.topic_prefix = topic_prefix
self.in_topic = self._format_topic(self.topic_prefix, "tty", "in")
self.out_topic = self._format_topic(self.topic_prefix, "tty", "out")
self.err_topic = self._format_topic(self.topic_prefix, "tty", "err")
self.in_topic = format_topic(self.topic_prefix, "tty", "in")
self.out_topic = format_topic(self.topic_prefix, "tty", "out")
self.err_topic = format_topic(self.topic_prefix, "tty", "err")
self.out_buffer = bytearray(self.BUFLEN)
self.out_view = memoryview(self.out_buffer)
self.logger = logger
self.jobs = {}

@staticmethod
def _format_topic(*parts):
"""Create a slash-delimited MQTT topic from a list of strings."""
return "/" + "/".join(part.strip("/") for part in parts if part)

async def connect(self):
"""Start processing messages in the input stream."""
await self.mqtt_client.subscribe(self.in_topic, qos=1)
Expand All @@ -38,8 +66,8 @@ async def disconnect(self):

async def handle_msg(self, _topic, msg, properties={}):
"""Process a single MQTT message and apply to the appropriate job."""
client_id = self._get_client_id(properties)
seq = self._get_seq(properties)
client_id = parse_client_id(properties)
seq = parse_seq(properties)
try:
await self.update_job(client_id=client_id, seq=seq, payload=msg)
except RuntimeError as e: # logged & handled as warning
Expand All @@ -48,21 +76,15 @@ async def handle_msg(self, _topic, msg, properties={}):
self.err_topic,
str(e).encode("utf-8"),
qos=1,
properties={
self.PROP_CORR: client_id.encode("utf-8"),
self.PROP_USER: {"seq": str(seq)},
},
properties=format_properties(client_id, seq),
)
except Exception as e:
self.logger.exception(e)
await self.mqtt_client.publish(
self.err_topic,
str(e).encode("utf-8"),
qos=1,
properties={
self.PROP_CORR: client_id.encode("utf-8"),
self.PROP_USER: {"seq": "-1"},
},
properties=format_properties(client_id, -1),
)
if client_id in self.jobs: # remove job on fatal error
del self.jobs[client_id]
Expand All @@ -87,10 +109,7 @@ async def update_job(self, client_id, seq, payload):
self.out_topic,
b"",
qos=1,
properties={
self.PROP_CORR: client_id.encode("utf-8"),
self.PROP_USER: {"seq": "-1"},
},
properties=format_properties(client_id, -1),
)
del self.jobs[client_id]

Expand All @@ -101,37 +120,12 @@ async def stream_job_output(self, job):
while True:
bytes_read = in_buffer.readinto(self.out_buffer)
if bytes_read > 0:
self.logger.debug(f"Streaming {bytes_read} bytes")
await self.mqtt_client.publish(
self.out_topic,
self.out_view[:bytes_read],
qos=1,
properties={
self.PROP_CORR: job.client_id.encode("utf-8"),
self.PROP_USER: {"seq": str(seq)},
},
properties=format_properties(job.client_id, seq),
)
seq += 1
else:
break

# Client ID: MQTT Correlation Data
# Always bytes; we format it as UTF-8
def _get_client_id(self, properties):
client_id = properties.get(self.PROP_CORR, None)
if not client_id:
raise ValueError("Missing client ID")
return client_id.decode("utf-8")

# Sequence: MQTT User Properties
# List of tuples; we store sequence info as a string in the first one
def _get_seq(self, properties):
user_properties = properties.get(self.PROP_USER, {})
seq = user_properties.get("seq", None)
if not seq:
raise ValueError("Missing sequence information")
try:
seq = int(seq)
except TypeError:
raise ValueError(f"Invalid sequence information: {seq}")
return seq
22 changes: 11 additions & 11 deletions tests/e2e/e2e_file_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from amqc.client import MQTTClient, config

from mqterm.terminal import MqttTerminal
from mqterm.terminal import MqttTerminal, format_properties

# Set up logging; pass LOG_LEVEL=DEBUG if needed for local testing
logger = logging.getLogger()
Expand Down Expand Up @@ -42,21 +42,21 @@
term = MqttTerminal(device_client, topic_prefix="/test")


def create_props(seq: int, client_id: str) -> dict:
"""Create MQTTv5 properties with a seq number and client ID."""
return {
MqttTerminal.PROP_CORR: client_id.encode("utf-8"),
MqttTerminal.PROP_USER: {"seq": str(seq)},
}
# def create_props(seq: int, client_id: str) -> dict:
# """Create MQTTv5 properties with a seq number and client ID."""
# return {
# CORRELATION_DATA: client_id.encode("utf-8"),
# USER_PROPERTY: {"seq": str(seq)},
# }


async def send_file(buffer: BytesIO):
"""Send a file to the terminal."""
# Send the first message that will create the job
seq = 0
props = create_props(seq, "tty0")
props = format_properties("tty0", seq)
await control_client.publish(
"/test/tty/in", "cp test.txt test.txt".encode("utf-8"), properties=props
"/test/tty/in", "cp test.txt".encode("utf-8"), properties=props
)

# Send the file in 4-byte chunks; close when done
Expand All @@ -68,7 +68,7 @@ async def send_file(buffer: BytesIO):
seq += 1
else:
seq = -1
props = create_props(seq, "tty0")
props = format_properties("tty0", seq)
logger.debug(f"Sending chunk {seq} of size {len(chunk)}: {chunk!r}")
await control_client.publish("/test/tty/in", chunk, properties=props)
if seq == -1:
Expand All @@ -90,7 +90,7 @@ async def get_file(buffer: BytesIO):
"""Send a file to the terminal and read it back."""
# Send the request for the file
seq = 0
props = create_props(seq, "tty0")
props = format_properties("tty0", seq)
await control_client.publish(
"/test/tty/in", "cat test.txt".encode(), properties=props
)
Expand Down
118 changes: 118 additions & 0 deletions tests/test_jobs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
"""Test the GetFileJob."""

import asyncio
import os
from unittest import TestCase

from mqterm.jobs import GetFileJob, Job, PlatformInfoJob, PutFileJob, WhoAmIJob


class TestJob(TestCase):
def test_from_cmd(self):
"""Job should parse a command string into a Job object"""
job = Job.from_cmd("cat file.txt")
self.assertEqual(job.cmd, "cat")
self.assertEqual(job.args, ["file.txt"])
self.assertIsInstance(job, GetFileJob)
with self.assertRaises(ValueError):
Job.from_cmd("unknown")

def test_str(self):
"""Job should have a string representation"""
job = Job("cat", args=["file.txt"])
self.assertEqual(str(job), "Job for localhost: cat file.txt")


class TestGetFileJob(TestCase):
def setUp(self):
# Mock the file reading for the test
self.file_content = "abc"
with open("file.txt", "w") as f:
f.write(self.file_content)

def tearDown(self):
# Clean up the file after the test
try:
os.remove("file.txt")
except OSError:
pass

def test_init(self):
with self.assertRaises(ValueError, msg="GetFileJob requires a filename"):
GetFileJob("cat")

def test_read_file(self):
job = GetFileJob("cat", ["file.txt"])
output = job.output().read().decode("utf-8")
self.assertEqual(output, "abc", msg="File content should match expected output")


class TestWhoAmIJob(TestCase):
def test_run(self):
job = WhoAmIJob("whoami", client_id="user@client")
output = job.output().read().decode("utf-8")
self.assertEqual(
output, "user@client", msg="WhoAmIJob should return the client ID"
)


class TestPlatformInfoJob(TestCase):
def test_run(self):
job = PlatformInfoJob("uname")
output = job.output().read().decode("utf-8")
self.assertIn(
"MQTerm v", output, msg="platform info should contain mqterm version"
)
self.assertIn(
"micropython v",
output,
msg="platform info should contain micropython version",
)


class TestListDirJob(TestCase):
def setUp(self):
# Create a temporary directory with some files for testing
self.test_dir = "test_dir"
os.mkdir(self.test_dir)
with open(f"{self.test_dir}/file1.txt", "w") as f:
f.write("content1")
with open(f"{self.test_dir}/file2.txt", "w") as f:
f.write("content2")

def tearDown(self):
# Clean up the test directory after the test
for file in os.listdir(self.test_dir):
os.remove(f"{self.test_dir}/{file}")
os.rmdir(self.test_dir)

def test_run(self):
job = Job.from_cmd(f"ls {self.test_dir}")
output = job.output().read().decode("utf-8").strip()
expected_files = "file1.txt\nfile2.txt"
self.assertEqual(output, expected_files)


class TestPutFileJob(TestCase):
def setUp(self):
self.test_file = "test_file.txt"
self.test_contents = b"test content"

def tearDown(self):
# Clean up the test file after the test
try:
os.remove(self.test_file)
except OSError:
pass

def test_run(self):
"""Should write to file and return bytes written"""
job = PutFileJob(f"put {self.test_file}", [self.test_file])
asyncio.run(job.update(b"test ", seq=1))
asyncio.run(job.update(b"content", seq=2))
asyncio.run(job.update(b"", seq=-1)) # Signal end of file transfer
assert job.ready, "Job should be ready after final update"
output = job.output().read().decode("utf-8")
self.assertEqual(int(output), len(self.test_contents))
with open(self.test_file, "rb") as f:
self.assertEqual(f.read(), self.test_contents)
Loading