diff --git a/gs/backend/data/data_wrappers/wrappers.py b/gs/backend/data/data_wrappers/wrappers.py index bd7645c7d..08c260121 100644 --- a/gs/backend/data/data_wrappers/wrappers.py +++ b/gs/backend/data/data_wrappers/wrappers.py @@ -99,6 +99,21 @@ class CommandsWrapper(AbstractWrapper[Commands]): model = Commands + def retrieve_floating_commands(self) -> list[Commands]: + """ + Retrieves all commands which do not have a valid entry in + the packet_commands table. + A command which is not valid is considered as any command whose ID + does not match with any command_id in the packet_commands table + """ + packet_commands = PacketCommandsWrapper().get_all() + packet_ids = {packet_command.command_id for packet_command in packet_commands} + + commands = self.get_all() + floating_commands = [fc for fc in commands if fc.id not in packet_ids] + + return floating_commands + class TelemetryWrapper(AbstractWrapper[Telemetry]): """ diff --git a/python_test/conftest.py b/python_test/conftest.py index 08b9c5225..1bf9d8bd4 100644 --- a/python_test/conftest.py +++ b/python_test/conftest.py @@ -49,6 +49,7 @@ def test_get_db_session(monkeypatch, db_session: Session): When testing any database function that requires the `get_db_session()` function, you must add the module path to the list below. """ path_list: list[str] = [ + "gs.backend.data.data_wrappers.abstract_wrapper", "gs.backend.data.data_wrappers.aro_wrapper.aro_request_wrapper", "gs.backend.data.data_wrappers.aro_wrapper.aro_user_data_wrapper", "gs.backend.data.data_wrappers.aro_wrapper.aro_user_auth_token_wrapper", diff --git a/python_test/test_retrieve_floating_commands.py b/python_test/test_retrieve_floating_commands.py new file mode 100644 index 000000000..0dc2e285b --- /dev/null +++ b/python_test/test_retrieve_floating_commands.py @@ -0,0 +1,71 @@ +from datetime import datetime +from uuid import uuid4 + +from gs.backend.data.data_wrappers.wrappers import ( + CommandsWrapper, + CommsSessionWrapper, + MainCommandWrapper, + PacketCommandsWrapper, + PacketWrapper, +) +from gs.backend.data.enums.transactional import MainPacketType + + +def test_retrieve_floating_commands_filters(): + pcw = PacketCommandsWrapper() + cw = CommandsWrapper() + mc = MainCommandWrapper() + pw = PacketWrapper() + csw = CommsSessionWrapper() + + packet_id = uuid4() + cmd_type = mc.create( + dict( + id=1, + name="test", + data_size=1, + total_size=1, + ) + ).id + + comms_session = csw.create({"id": uuid4(), "start_time": datetime.now()}) + packet = pw.create( + dict( + id=packet_id, + session_id=comms_session.id, + raw_data=b"\x00", + type_=MainPacketType.UPLINK, + payload_data=b"\x00", + offset=0, + ) + ) + + cmd_in_packet = cw.create(dict(id=uuid4(), type_=cmd_type)) + cmd_free = cw.create(dict(id=uuid4(), type_=cmd_type)) + cmd_free2 = cw.create(dict(id=uuid4(), type_=cmd_type)) + + pcw.create(dict(packet_id=packet.id, command_id=cmd_in_packet.id)) + + result = cw.retrieve_floating_commands() + for command in result: + assert command.id in [cmd_free.id, cmd_free2.id] + + +def test_retrieve_floating_commands_no_packet_commands(): + cw = CommandsWrapper() + mc = MainCommandWrapper() + cmd_type = mc.create( + dict( + id=2, + name="test", + data_size=1, + total_size=1, + ) + ).id + + cw.create(dict(id=uuid4(), type_=cmd_type)) + cw.create(dict(id=uuid4(), type_=cmd_type)) + + result = cw.retrieve_floating_commands() + expected = cw.get_all() + assert {c.id for c in result} == {c.id for c in expected}