diff --git a/.ruff.toml b/.ruff.toml new file mode 100644 index 00000000..8dd063be --- /dev/null +++ b/.ruff.toml @@ -0,0 +1,27 @@ +target-version = "py310" +exclude = ["*_pb2.py"] + +line-length = 100 +[format] +indent-style = "space" +quote-style = "double" + +[lint] +select = [ + "E", # pycodestyle + "F", # Pyflakes + "UP", # pyupgrade + "B", # flake8-bugbear + "SIM", # flake8-simplify + "I", # isort +] +ignore = [ + "B905", # new 'strict' argument for zip in python3.10 + "E501", # ignore line length errors for now + "E701", # multiple statements on one line, will be fixed by format + "E722", # do not use bare except + "E741", # ambiguous variable name (seems to trigger on l, I, and such) + "F841", # unused local variable + "SIM105", # contextlib.suppress is 3x slower than try-except + "SIM115", # use context manager to open files (only works in some places) +] diff --git a/cnc/protocol/__init__.py b/cnc/protocol/__init__.py index 8043b0f2..20976d5e 100644 --- a/cnc/protocol/__init__.py +++ b/cnc/protocol/__init__.py @@ -1,3 +1,3 @@ # SPDX-FileCopyrightText: 2023 Carnegie Mellon University - Satyalab # -# SPDX-License-Identifier: GPL-2.0-only \ No newline at end of file +# SPDX-License-Identifier: GPL-2.0-only diff --git a/cnc/server/main.py b/cnc/server/main.py index f36acec9..6900c05b 100755 --- a/cnc/server/main.py +++ b/cnc/server/main.py @@ -5,14 +5,15 @@ # # SPDX-License-Identifier: GPL-2.0-only -from gabriel_server.network_engine import server_runner -import logging import argparse +import logging + +from gabriel_server.network_engine import server_runner DEFAULT_PORT = 9099 DEFAULT_NUM_TOKENS = 2 INPUT_QUEUE_MAXSIZE = 60 -SOURCE = 'cnc' +SOURCE = "cnc" logging.basicConfig(level=logging.INFO) @@ -20,25 +21,26 @@ def main(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument( "-t", "--tokens", type=int, default=DEFAULT_NUM_TOKENS, help="number of tokens" ) - - parser.add_argument( - "-p", "--port", type=int, default=DEFAULT_PORT, help="Set port number" - ) + + parser.add_argument("-p", "--port", type=int, default=DEFAULT_PORT, help="Set port number") parser.add_argument( "-q", "--queue", type=int, default=INPUT_QUEUE_MAXSIZE, help="Max input queue size" ) - + args, _ = parser.parse_known_args() - server_runner.run(websocket_port=args.port, zmq_address='tcp://*:5555', num_tokens=args.tokens, - input_queue_maxsize=args.queue) + server_runner.run( + websocket_port=args.port, + zmq_address="tcp://*:5555", + num_tokens=args.tokens, + input_queue_maxsize=args.queue, + ) + if __name__ == "__main__": main() diff --git a/cnc/server/swarm_controller.py b/cnc/server/swarm_controller.py index 165119b8..285db711 100755 --- a/cnc/server/swarm_controller.py +++ b/cnc/server/swarm_controller.py @@ -5,46 +5,46 @@ # # SPDX-License-Identifier: GPL-2.0-only +import argparse import json +import logging import os import subprocess import sys -import logging +from urllib.parse import urlparse from zipfile import ZipFile -from google.protobuf.message import DecodeError -from google.protobuf import text_format + +import redis import requests -from cnc_protocol import cnc_pb2 -import argparse import zmq import zmq.asyncio -import redis -from urllib.parse import urlparse - +from cnc_protocol import cnc_pb2 +from google.protobuf import text_format +from google.protobuf.message import DecodeError logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) handler = logging.StreamHandler(sys.stdout) handler.setLevel(logging.DEBUG) -formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') +formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s") handler.setFormatter(formatter) logger.addHandler(handler) # Set up the paths and variables for the compiler -compiler_path = '/compiler' -output_path = '/compiler/out/flightplan_' -platform_path = '/compiler/python/project' +compiler_path = "/compiler" +output_path = "/compiler/out/flightplan_" +platform_path = "/compiler/python/project" def download_script(script_url): try: # Get the ZIP file name from the URL - filename = script_url.rsplit(sep='/')[-1] - logger.info(f'Writing {filename} to disk...') - + filename = script_url.rsplit(sep="/")[-1] + logger.info(f"Writing {filename} to disk...") + # Download the ZIP file r = requests.get(script_url, stream=True) - with open(filename, mode='wb') as f: + with open(filename, mode="wb") as f: for chunk in r.iter_content(chunk_size=8192): f.write(chunk) @@ -53,177 +53,209 @@ def download_script(script_url): kml_file = None # Extract all contents of the ZIP file and remember .dsl and .kml filenames - with ZipFile(filename, 'r') as z: + with ZipFile(filename, "r") as z: z.extractall(path=compiler_path) for file_name in z.namelist(): - if file_name.endswith('.dsl'): + if file_name.endswith(".dsl"): dsl_file = file_name - elif file_name.endswith('.kml'): + elif file_name.endswith(".kml"): kml_file = file_name # Log or return the results logger.info(f"Extracted DSL files: {dsl_file}") logger.info(f"Extracted KML files: {kml_file}") - + return dsl_file, kml_file except Exception as e: logger.error(f"Error during download or extraction: {e}") - + + def compile_mission(dsl_file, kml_file, drone_list, alt, compiler_file): # Construct the full paths for the DSL and KML files dsl_file_path = os.path.join(compiler_path, dsl_file) kml_file_path = os.path.join(compiler_path, kml_file) jar_path = os.path.join(compiler_path, compiler_file) altitude = str(alt) - + # Define the command and arguments command = [ "java", - "-jar", jar_path, - "-d", drone_list, - "-s", dsl_file_path, - "-k", kml_file_path, - "-o", output_path, - "-p", platform_path, - "-a", altitude + "-jar", + jar_path, + "-d", + drone_list, + "-s", + dsl_file_path, + "-k", + kml_file_path, + "-o", + output_path, + "-p", + platform_path, + "-a", + altitude, ] - + # Run the command logger.info(f"Running command: {' '.join(command)}") result = subprocess.run(command, check=True, capture_output=True, text=True) - + # Log the output logger.info(f"Compilation output: {result.stdout}") - + # Output the results logger.info("Compilation successful.") - + def send_to_drone(msg, base_url, drone_list, cmd_front_cmdr_sock, redis): try: - logger.info(f"Sending request to drone...") - # Send the command to each drone + logger.info("Sending request to drone...") + # Send the command to each drone for drone_id in drone_list: # check if the cmd is a mission - if (base_url): + if base_url: # reconstruct the script url with the correct compiler output path msg.cmd.script_url = f"{base_url}{output_path}{drone_id}.ms" logger.info(f"script url: {msg.cmd.script_url}") - + # send the command to the drone - cmd_front_cmdr_sock.send_multipart([drone_id.encode('utf-8'), msg.SerializeToString()]) - logger.info(f'Delivered request to drone {drone_id}:\n {text_format.MessageToString(msg)}') - + cmd_front_cmdr_sock.send_multipart([drone_id.encode("utf-8"), msg.SerializeToString()]) + logger.info( + f"Delivered request to drone {drone_id}:\n {text_format.MessageToString(msg)}" + ) + # store the record in redis key = redis.xadd( - f"commands", - {"commander": msg.commander_id, "drone": drone_id, "value": text_format.MessageToString(msg),} + "commands", + { + "commander": msg.commander_id, + "drone": drone_id, + "value": text_format.MessageToString(msg), + }, ) logger.debug(f"Updated redis under stream commands at key {key}") except Exception as e: logger.error(f"Error sending request to drone: {e}") - + def listen_cmdrs(cmdr_sock, cmd_front_cmdr_sock, redis, alt, compiler_file): while True: - # Listen for incoming requests from cmdr req = cmdr_sock.recv() try: msg = cnc_pb2.Extras() msg.ParseFromString(req) - logger.info(f'Request received:\n{text_format.MessageToString(msg)}') + logger.info(f"Request received:\n{text_format.MessageToString(msg)}") except DecodeError: - cmdr_sock.send(b'Error decoding protobuf. Did you send a cnc_pb2?') - logger.info(f'Error decoding protobuf. Did you send a cnc_pb2?') + cmdr_sock.send(b"Error decoding protobuf. Did you send a cnc_pb2?") + logger.info("Error decoding protobuf. Did you send a cnc_pb2?") continue - + # get the drone list try: drone_list_json = msg.cmd.for_drone_id drone_list = json.loads(drone_list_json) logger.info(f"drone list: {drone_list}") except json.JSONDecodeError: - cmdr_sock.send(b'Error decoding drone list. Did you send a JSON list?') - logger.info(f'Error decoding drone list. Did you send a JSON list?') + cmdr_sock.send(b"Error decoding drone list. Did you send a JSON list?") + logger.info("Error decoding drone list. Did you send a JSON list?") continue - + # Check if the command contains a mission and compile it if true base_url = None - if (msg.cmd.script_url): + if msg.cmd.script_url: # download the script script_url = msg.cmd.script_url logger.info(f"script url: {script_url}") dsl, kml = download_script(script_url) - + # compile the mission drone_list_revised = "&".join(drone_list) logger.info(f"drone list revised: {drone_list_revised}") compile_mission(dsl, kml, drone_list_revised, alt, compiler_file) - + # get the base url parsed_url = urlparse(script_url) base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" - # send the command to the drone send_to_drone(msg, base_url, drone_list, cmd_front_cmdr_sock, redis) - - - cmdr_sock.send(b'ACK') - logger.info('Sent ACK to commander') + cmdr_sock.send(b"ACK") + logger.info("Sent ACK to commander") + def main(): parser = argparse.ArgumentParser() - parser.add_argument('-d', '--droneport', type=int, default=5003, help='Specify port to listen for drone requests [default: 5003]') - parser.add_argument('-c', '--cmdrport', type=int, default=6001, help='Specify port to listen for commander requests [default: 6001]') parser.add_argument( - "-r", "--redis", type=int, default=6379, help="Set port number for redis connection [default: 6379]" + "-d", + "--droneport", + type=int, + default=5003, + help="Specify port to listen for drone requests [default: 5003]", + ) + parser.add_argument( + "-c", + "--cmdrport", + type=int, + default=6001, + help="Specify port to listen for commander requests [default: 6001]", ) parser.add_argument( - "-a", "--auth", default="", help="Share key for redis user." + "-r", + "--redis", + type=int, + default=6379, + help="Set port number for redis connection [default: 6379]", ) + parser.add_argument("-a", "--auth", default="", help="Share key for redis user.") parser.add_argument( "--altitude", type=int, default=15, help="base altitude for the drones mission" ) parser.add_argument( - "--compiler_file", default='compile-1.5-full.jar', help="compiler file name" - ) + "--compiler_file", default="compile-1.5-full.jar", help="compiler file name" + ) args = parser.parse_args() - + # Set the altitude alt = args.altitude logger.info(f"Starting control plane with altitude {alt}...") - + compiler_file = args.compiler_file logger.info(f"Using compiler file: {compiler_file}") - + # Connect to redis - r = redis.Redis(host='redis', port=args.redis, username='steeleagle', password=f'{args.auth}',decode_responses=True) + r = redis.Redis( + host="redis", + port=args.redis, + username="steeleagle", + password=f"{args.auth}", + decode_responses=True, + ) logger.info(f"Connected to redis on port {args.redis}...") # Set up the commander socket ctx = zmq.Context() cmdr_sock = ctx.socket(zmq.REP) - cmdr_sock.bind(f'tcp://*:{args.cmdrport}') - logger.info(f'Listening on tcp://*:{args.cmdrport} for commander requests...') + cmdr_sock.bind(f"tcp://*:{args.cmdrport}") + logger.info(f"Listening on tcp://*:{args.cmdrport} for commander requests...") # Set up the drone socket async_ctx = zmq.asyncio.Context() cmd_front_cmdr_sock = async_ctx.socket(zmq.ROUTER) cmd_front_cmdr_sock.setsockopt(zmq.ROUTER_HANDOVER, 1) - cmd_front_cmdr_sock.bind(f'tcp://*:{args.droneport}') - logger.info(f'Listening on tcp://*:{args.droneport} for drone requests...') - + cmd_front_cmdr_sock.bind(f"tcp://*:{args.droneport}") + logger.info(f"Listening on tcp://*:{args.droneport} for drone requests...") + # Listen for incoming requests from cmdr try: listen_cmdrs(cmdr_sock, cmd_front_cmdr_sock, r, alt, compiler_file) except KeyboardInterrupt: - logger.info('Shutting down...') + logger.info("Shutting down...") cmdr_sock.close() cmd_front_cmdr_sock.close() -if __name__ == '__main__': - main() \ No newline at end of file + +if __name__ == "__main__": + main() diff --git a/cnc/server/telemetry.py b/cnc/server/telemetry.py index ec7eb1a0..b134172c 100755 --- a/cnc/server/telemetry.py +++ b/cnc/server/telemetry.py @@ -5,39 +5,39 @@ # # SPDX-License-Identifier: GPL-2.0-only +import argparse +import logging + from gabriel_server.network_engine import engine_runner from telemetry_engine import TelemetryEngine -import logging -import argparse -SOURCE = 'telemetry' +SOURCE = "telemetry" logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) + def main(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument( - "-p", "--port", type=int, default=9099, help="Set port number" - ) + parser.add_argument("-p", "--port", type=int, default=9099, help="Set port number") parser.add_argument( - "-g", "--gabriel", default="tcp://gabriel-server:5555", help="Gabriel server endpoint." + "-g", "--gabriel", default="tcp://gabriel-server:5555", help="Gabriel server endpoint." ) parser.add_argument( - "-r", "--redis", type=int, default=6379, help="Set port number for redis connection [default: 6379]" + "-r", + "--redis", + type=int, + default=6379, + help="Set port number for redis connection [default: 6379]", ) - parser.add_argument( - "-a", "--auth", default="", help="Share key for redis user." - ) + parser.add_argument("-a", "--auth", default="", help="Share key for redis user.") parser.add_argument( - "-l", "--publish", action='store_true', help="Publish incoming images via redis" + "-l", "--publish", action="store_true", help="Publish incoming images via redis" ) args, _ = parser.parse_known_args() @@ -46,7 +46,13 @@ def engine_setup(): return TelemetryEngine(args) logger.info("Starting telemetry cognitive engine..") - engine_runner.run(engine=engine_setup(), source_name=SOURCE, server_address=args.gabriel, all_responses_required=True) + engine_runner.run( + engine=engine_setup(), + source_name=SOURCE, + server_address=args.gabriel, + all_responses_required=True, + ) + if __name__ == "__main__": main() diff --git a/cnc/server/telemetry_engine.py b/cnc/server/telemetry_engine.py index 5fa01f83..7ef743e6 100644 --- a/cnc/server/telemetry_engine.py +++ b/cnc/server/telemetry_engine.py @@ -5,46 +5,63 @@ # # SPDX-License-Identifier: GPL-2.0-only -import time import datetime import logging -from gabriel_server import cognitive_engine -from gabriel_protocol import gabriel_pb2 -from cnc_protocol import cnc_pb2 -import redis import os -from PIL import Image +import time + import cv2 import numpy as np +import redis +from cnc_protocol import cnc_pb2 +from gabriel_protocol import gabriel_pb2 +from gabriel_server import cognitive_engine +from PIL import Image logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) + class TelemetryEngine(cognitive_engine.Engine): ENGINE_NAME = "telemetry" def __init__(self, args): logger.info("Telemetry engine intializing...") - self.r = redis.Redis(host='redis', port=args.redis, username='steeleagle', password=f'{args.auth}',decode_responses=True) + self.r = redis.Redis( + host="redis", + port=args.redis, + username="steeleagle", + password=f"{args.auth}", + decode_responses=True, + ) self.r.ping() logger.info(f"Connected to redis on port {args.redis}...") - self.storage_path = os.getcwd()+"/images/" + self.storage_path = os.getcwd() + "/images/" try: - os.makedirs(self.storage_path+"/raw") + os.makedirs(self.storage_path + "/raw") except FileExistsError: logger.info("Images directory already exists.") - logger.info("Storing detection images at {}".format(self.storage_path)) + logger.info(f"Storing detection images at {self.storage_path}") self.current_path = None self.publish = args.publish def updateDroneStatus(self, extras): key = self.r.xadd( - f"telemetry.{extras.drone_id}", - {"latitude": extras.location.latitude, "longitude": extras.location.longitude, "altitude": extras.location.altitude, - "rssi": extras.status.rssi, "battery": extras.status.battery, "mag": extras.status.mag, "bearing": int(extras.status.bearing)}, - ) - logger.debug(f"Updated status of {extras.drone_id} in redis under stream telemetry at key {key}") + f"telemetry.{extras.drone_id}", + { + "latitude": extras.location.latitude, + "longitude": extras.location.longitude, + "altitude": extras.location.altitude, + "rssi": extras.status.rssi, + "battery": extras.status.battery, + "mag": extras.status.mag, + "bearing": int(extras.status.bearing), + }, + ) + logger.debug( + f"Updated status of {extras.drone_id} in redis under stream telemetry at key {key}" + ) def handle(self, input_frame): extras = cognitive_engine.unpack_extras(cnc_pb2.Extras, input_frame) @@ -54,9 +71,9 @@ def handle(self, input_frame): result_wrapper.result_producer_name.value = self.ENGINE_NAME result = None - + if input_frame.payload_type == gabriel_pb2.PayloadType.TEXT: - if extras.drone_id is not "": + if extras.drone_id != "": if extras.registering: logger.info(f"Drone [{extras.drone_id}] connected.") if not os.path.exists(f"{self.storage_path}/raw/{extras.drone_id}"): @@ -69,30 +86,32 @@ def handle(self, input_frame): result = gabriel_pb2.ResultWrapper.Result() result.payload_type = gabriel_pb2.PayloadType.TEXT - result.payload = "Telemetry updated.".encode(encoding="utf-8") + result.payload = b"Telemetry updated." self.updateDroneStatus(extras) elif input_frame.payload_type == gabriel_pb2.PayloadType.IMAGE: image_np = np.fromstring(input_frame.payloads[0], dtype=np.uint8) - #have redis publish the latest image + # have redis publish the latest image if self.publish: logger.info(f"Publishing image to redis under imagery.{extras.drone_id} topic.") - self.r.publish(f'imagery.{extras.drone_id}', input_frame.payloads[0]) - #store images in the shared volume + self.r.publish(f"imagery.{extras.drone_id}", input_frame.payloads[0]) + # store images in the shared volume try: img = cv2.imdecode(image_np, cv2.IMREAD_COLOR) img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) img = Image.fromarray(img) img.save(f"{self.current_path}/{str(int(time.time() * 1000))}.jpg", format="JPEG") img.save(f"{self.storage_path}/raw/{extras.drone_id}/temp.jpg", format="JPEG") - os.rename(f"{self.storage_path}/raw/{extras.drone_id}/temp.jpg", f"{self.storage_path}/raw/{extras.drone_id}/latest.jpg") + os.rename( + f"{self.storage_path}/raw/{extras.drone_id}/temp.jpg", + f"{self.storage_path}/raw/{extras.drone_id}/latest.jpg", + ) logger.info(f"Updated {self.storage_path}/raw/{extras.drone_id}/latest.jpg") except Exception as e: logger.error(f"Exception trying to store imagery: {e}") - + # only append the result if it has a payload # e.g. in the elif block where we received an image from the streaming thread, we don't add a payload if result is not None: result_wrapper.results.append(result) return result_wrapper - diff --git a/cnc/server/test/compiler_test.py b/cnc/server/test/compiler_test.py index 89dfc02f..f7bd581d 100644 --- a/cnc/server/test/compiler_test.py +++ b/cnc/server/test/compiler_test.py @@ -1,9 +1,11 @@ import os import subprocess -compiler_path = '/compiler' -output_path = '/compiler/out/flightplan_' -platform_path = '/compiler/python' +compiler_path = "/compiler" +output_path = "/compiler/out/flightplan_" +platform_path = "/compiler/python" + + def compile_mission(dsl_file, kml_file, drone_list): # Construct the full paths for the DSL and KML files dsl_file_path = os.path.join(compiler_path, dsl_file) @@ -13,28 +15,34 @@ def compile_mission(dsl_file, kml_file, drone_list): # Define the command and arguments command = [ "java", - "-jar", jar_path, - "-d", drone_list, - "-s", dsl_file_path, - "-k", kml_file_path, - "-o", output_path, - "-p", platform_path + "-jar", + jar_path, + "-d", + drone_list, + "-s", + dsl_file_path, + "-k", + kml_file_path, + "-o", + output_path, + "-p", + platform_path, ] - + try: # Run the command - result = subprocess.run(command, check=True, capture_output=True, text=True) + result = subprocess.run(command, check=True, capture_output=True, text=True) print("output: ", result) # Output the results print("Compilation successful.") - + except subprocess.CalledProcessError as e: print("Error output:", e.stderr) -if __name__ == '__main__': - kml_file = 'tst.kml' - dsl_file = 'tst.dsl' - drone_list = 'ant&mamba' - compile_mission(dsl_file, kml_file, drone_list) \ No newline at end of file +if __name__ == "__main__": + kml_file = "tst.kml" + dsl_file = "tst.dsl" + drone_list = "ant&mamba" + compile_mission(dsl_file, kml_file, drone_list) diff --git a/cnc/server/test/streaming_test_socket.py b/cnc/server/test/streaming_test_socket.py index f5265b18..981658b4 100755 --- a/cnc/server/test/streaming_test_socket.py +++ b/cnc/server/test/streaming_test_socket.py @@ -2,64 +2,67 @@ # # SPDX-License-Identifier: GPL-2.0-only +import argparse import socket +import time + import cv2 import numpy as np -import time import zmq -import argparse -import sys -HOST='' -PORT=8485 -LOC = {'latitude': 40.4136589, 'longitude': -79.9495332, 'altitude': 10} +HOST = "" +PORT = 8485 +LOC = {"latitude": 40.4136589, "longitude": -79.9495332, "altitude": 10} + # Required for sending image over zmq def send_array(sock, A, flags=0, copy=True, track=False): """send a numpy array with metadata""" global LOC - md = dict( - dtype = str(A.dtype), - shape = A.shape, - location = LOC, - model = 'robomaster' - ) - sock.send_json(md, flags|zmq.SNDMORE) + md = dict(dtype=str(A.dtype), shape=A.shape, location=LOC, model="robomaster") + sock.send_json(md, flags | zmq.SNDMORE) return sock.send(A, flags, copy=copy, track=track) def recv_from_nonblocking(fd: socket, size: int) -> bytes: - buf, received = [], 0 - while received < size: - buf.append(fd.recv(size - received)) - received += len(buf[-1]) - return b"".join(buf) + buf, received = [], 0 + while received < size: + buf.append(fd.recv(size - received)) + received += len(buf[-1]) + return b"".join(buf) -def _main(): - parser = argparse.ArgumentParser(prog='test_socket', - description='Receives image frames from Android streaming test app.') - parser.add_argument('-p', '--port', default=8485, - help='Specify port to listen on [default: 8485]') - parser.add_argument('-zp', '--zmq_port', default=5555, - help='Specify zmq port to publish to [default: 5555]') - parser.add_argument('-s', '--store', action='store_true', - help='Store images locally on disk [default: False]') - parser.add_argument('-z', '--zmq', action='store_true', - help='Send images over zmq to OpenScout (assumes OpenScout is listening locally) [default: False]') - parser.add_argument('-d', '--direct_send', action='store_true') +def _main(): + parser = argparse.ArgumentParser( + prog="test_socket", description="Receives image frames from Android streaming test app." + ) + parser.add_argument( + "-p", "--port", default=8485, help="Specify port to listen on [default: 8485]" + ) + parser.add_argument( + "-zp", "--zmq_port", default=5555, help="Specify zmq port to publish to [default: 5555]" + ) + parser.add_argument( + "-s", "--store", action="store_true", help="Store images locally on disk [default: False]" + ) + parser.add_argument( + "-z", + "--zmq", + action="store_true", + help="Send images over zmq to OpenScout (assumes OpenScout is listening locally) [default: False]", + ) + parser.add_argument("-d", "--direct_send", action="store_true") args = parser.parse_args() - - s=socket.socket(socket.AF_INET,socket.SOCK_STREAM) + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - s.bind((HOST,args.port)) - print(f'Socket bound to port: {args.port}') + s.bind((HOST, args.port)) + print(f"Socket bound to port: {args.port}") s.listen(10) - print('Socket now listening...') + print("Socket now listening...") - conn,addr=s.accept() + conn, addr = s.accept() data = b"" if args.zmq: context = zmq.Context() @@ -68,10 +71,12 @@ def _main(): try: with conn: - if(args.zmq): - print(f"Publishing images for OpenScout client's ZmqAdapter on port {args.zmq_port}..") + if args.zmq: + print( + f"Publishing images for OpenScout client's ZmqAdapter on port {args.zmq_port}.." + ) zmq_socket = context.socket(zmq.PUB) - zmq_socket.bind(f'tcp://*:{args.zmq_port}') + zmq_socket.bind(f"tcp://*:{args.zmq_port}") print(f"Client connected {addr}") frames_received = 0 @@ -82,42 +87,40 @@ def _main(): start = time.time() header = recv_from_nonblocking(conn, 4) size = int.from_bytes(header, "big") - #print(f"About to receive an image of {size} bytes...") + # print(f"About to receive an image of {size} bytes...") data = recv_from_nonblocking(conn, size) frames_received += 1 - #print(f"Frames Received: {frames_received} Data Length: {len(data)}") + # print(f"Frames Received: {frames_received} Data Length: {len(data)}") now = time.time() if now - lastprint > 5: - print( - "avg fps: {0:.2f}".format( - (frames_received - lastcount) / 5) - ) + print(f"avg fps: {(frames_received - lastcount) / 5:.2f}") print() lastcount = frames_received lastprint = now if not args.direct_send: frame = cv2.imdecode(np.fromstring(data, np.uint8), cv2.IMREAD_COLOR) - #cv2frame = cv2.cvtColor(frame.as_ndarray(), cv2.COLOR_YUV2BGR_I420) + # cv2frame = cv2.cvtColor(frame.as_ndarray(), cv2.COLOR_YUV2BGR_I420) else: packets = codec.parse(data) for packet in packets: frames = codec.decode(packet) frame = frames[-1] - if(args.zmq): + if args.zmq: print(f"Publishing frame {frames_received} to OpenScout client...") send_array(zmq_socket, frame) - if(args.store): - cv2.imwrite(f'{frames_received}.jpg', frame) - cv2.imshow('server',frame) + if args.store: + cv2.imwrite(f"{frames_received}.jpg", frame) + cv2.imshow("server", frame) cv2.waitKey(1) except KeyboardInterrupt: s.close() - print('Socket closed') + print("Socket closed") + if __name__ == "__main__": _main() diff --git a/cnc/streamlit/commander.py b/cnc/streamlit/commander.py index 68543c16..3006fe5d 100644 --- a/cnc/streamlit/commander.py +++ b/cnc/streamlit/commander.py @@ -2,22 +2,17 @@ # # SPDX-License-Identifier: GPL-2.0-only -import streamlit as st -import pandas as pd -from streamlit_autorefresh import st_autorefresh +import datetime +import os + import folium -from folium.plugins import Draw -from streamlit_folium import st_folium -from st_keypressed import st_keypressed -import streamlit_antd_components as sac import redis -import os -from cnc_protocol import cnc_pb2 +import streamlit as st import zmq -import time -import numpy as np -import cv2 -import datetime +from cnc_protocol import cnc_pb2 +from st_keypressed import st_keypressed +from streamlit_autorefresh import st_autorefresh +from streamlit_folium import st_folium st.set_page_config( page_title="Commander", @@ -46,20 +41,31 @@ if "map_server" not in st.session_state: st.session_state.map_server = "Google Hybrid" if "red" not in st.session_state: - st.session_state.red = redis.Redis(host=st.secrets.redis, port=st.secrets.redis_port, username=st.secrets.redis_user, password=st.secrets.redis_pw,decode_responses=True) + st.session_state.red = redis.Redis( + host=st.secrets.redis, + port=st.secrets.redis_port, + username=st.secrets.redis_user, + password=st.secrets.redis_pw, + decode_responses=True, + ) try: st.session_state.red.ping() except redis.ConnectionError: st.error("Cannot connect to Redis!") if "subscriber" not in st.session_state: - red2 = redis.Redis(host=st.secrets.redis, port=st.secrets.redis_port, username=st.secrets.redis_user, password=st.secrets.redis_pw) + red2 = redis.Redis( + host=st.secrets.redis, + port=st.secrets.redis_port, + username=st.secrets.redis_user, + password=st.secrets.redis_pw, + ) st.session_state.subscriber = red2.pubsub(ignore_subscribe_messages=True) - st.session_state.subscriber.psubscribe('imagery.*') + st.session_state.subscriber.psubscribe("imagery.*") if "zmq" not in st.session_state: ctx = zmq.Context() st.session_state.zmq = ctx.socket(zmq.REQ) - st.session_state.zmq.connect(f'tcp://{st.secrets.zmq}:{st.secrets.zmq_port}') + st.session_state.zmq.connect(f"tcp://{st.secrets.zmq}:{st.secrets.zmq_port}") if "last_image" not in st.session_state: st.session_state.last_image = None @@ -83,7 +89,7 @@ "mag": int, "sats": int, } -l=[] +l = [] for k in st.session_state.red.keys("telemetry.*"): l.append(k.split(".")[-1]) st.session_state.list = l @@ -91,9 +97,15 @@ st.session_state.selected_drone = l[0] if st.session_state.selected_drone is not None: - results = st.session_state.red.xrevrange(f"telemetry.{st.session_state.selected_drone}", "+", "-", 1) + results = st.session_state.red.xrevrange( + f"telemetry.{st.session_state.selected_drone}", "+", "-", 1 + ) telemetry = results[0][1] - telemetry["last_update"] = datetime.datetime.strftime(datetime.datetime.fromtimestamp(int(results[0][0].split("-")[0])/1000), "%d-%b-%Y %H:%M:%S") + telemetry["last_update"] = datetime.datetime.strftime( + datetime.datetime.fromtimestamp(int(results[0][0].split("-")[0]) / 1000), + "%d-%b-%Y %H:%M:%S", + ) + def run_flightscript(): if st.session_state.script_file is None: @@ -129,9 +141,7 @@ def enable_manual(): req.cmd.for_drone_id = st.session_state.selected_drone st.session_state.zmq.send(req.SerializeToString()) rep = st.session_state.zmq.recv() - st.toast( - f"Assuming manual control of {st.session_state.selected_drone}! Kill signal sent." - ) + st.toast(f"Assuming manual control of {st.session_state.selected_drone}! Kill signal sent.") def rth(): @@ -172,7 +182,7 @@ def rth(): help="Upload a flight script.", type=["ms"], ) - # st.divider() + # st.divider() st.button( key="manual_button", label=":joystick: Manual Control", @@ -193,8 +203,8 @@ def rth(): ) if st.session_state.manual_control: - st.subheader(f":blue[Manual Control Enabled]") - #st.subheader(":red[Manual Speed Controls]", divider="gray") + st.subheader(":blue[Manual Control Enabled]") + # st.subheader(":red[Manual Speed Controls]", divider="gray") st.sidebar.slider( key="pitch_speed", label="Drone Pitch (forward/backward)", @@ -228,9 +238,9 @@ def rth(): step=5, ) elif st.session_state.rth_sent: - st.subheader(f":orange[Return to Home Initiated]") + st.subheader(":orange[Return to Home Initiated]") elif st.session_state.script_file is not None: - st.subheader(f":violet[Autonomous Mode Enabled]") + st.subheader(":violet[Autonomous Mode Enabled]") with c2: @@ -252,22 +262,20 @@ def rth(): attr="Google", ) fg = folium.FeatureGroup(name="markers") - #Draw(export=True).add_to(m) + # Draw(export=True).add_to(m) if st.session_state.selected_drone is not None: plane = folium.Icon( icon="plane", color="red", prefix="glyphicon", - angle=int( - telemetry['bearing'] - ), + angle=int(telemetry["bearing"]), ) fg.add_child( folium.Marker( location=[ - telemetry['latitude'], - telemetry['longitude'], + telemetry["latitude"], + telemetry["longitude"], ], tooltip=st.session_state.selected_drone, icon=plane, @@ -310,13 +318,11 @@ def rth(): status_cols2 = st.columns(2) status_cols2[0].metric( label="Magnetometer", - value=MAG_STATE[ - int(telemetry['mag']) - ], + value=MAG_STATE[int(telemetry["mag"])], ) st.metric( - label="Last Update", - value=f"{telemetry['last_update']}", + label="Last Update", + value=f"{telemetry['last_update']}", ) # status_cols2[1].metric( # label="Status", @@ -325,7 +331,7 @@ def rth(): # st.session_state.prev_sats = int( # st.session_state.list.loc[st.session_state.selected_drone].sats # ) - st.session_state.prev_alt = float(telemetry['altitude']) + st.session_state.prev_alt = float(telemetry["altitude"]) with c3: tab1, tab2, tab3 = st.tabs(["Live", "Obstacle Avoidance", "Object Detection"]) @@ -337,17 +343,15 @@ def rth(): # img = cv2.imdecode(image_np, cv2.IMREAD_COLOR) # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # tab1.image(img) - #tab1.image(f"http://{st.secrets.webrtc}/api/frame.jpeg?src=file2&a={time.time()}", use_column_width="auto") + # tab1.image(f"http://{st.secrets.webrtc}/api/frame.jpeg?src=file2&a={time.time()}", use_column_width="auto") tab1.image(f"../server/steeleagle-vol/raw/{st.session_state.selected_drone}/latest.jpg") - #tab1.image(f"http://{st.secrets.webserver}/raw/{st.session_state.selected_drone}/latest.jpg?a={time.time()}") + # tab1.image(f"http://{st.secrets.webserver}/raw/{st.session_state.selected_drone}/latest.jpg?a={time.time()}") tab2.image(f"http://{st.secrets.webserver}/moa/latest.jpg", use_column_width="auto") - tab3.image( - f"http://{st.secrets.webserver}/detected/latest.jpg", use_column_width="auto" - ) + tab3.image(f"http://{st.secrets.webserver}/detected/latest.jpg", use_column_width="auto") - #st.write(f":keyboard: {st.session_state.key_pressed}") + # st.write(f":keyboard: {st.session_state.key_pressed}") st.session_state.key_pressed = st_keypressed() - if st.session_state.manual_control and st.session_state.selected_drone != None: + if st.session_state.manual_control and st.session_state.selected_drone is not None: req = cnc_pb2.Extras() req.commander_id = os.uname()[1] req.cmd.for_drone_id = st.session_state.selected_drone @@ -376,7 +380,7 @@ def rth(): yaw = 1 * st.session_state.yaw_speed elif st.session_state.key_pressed == "j": yaw = -1 * st.session_state.yaw_speed - #st.toast(f"PCMD(pitch = {pitch}, roll = {roll}, yaw = {yaw}, gaz = {gaz})") + # st.toast(f"PCMD(pitch = {pitch}, roll = {roll}, yaw = {yaw}, gaz = {gaz})") req.cmd.pcmd.yaw = yaw req.cmd.pcmd.pitch = pitch req.cmd.pcmd.roll = roll diff --git a/cnc/streamlit/commander_helper.py b/cnc/streamlit/commander_helper.py index 4b77a4bf..9b527614 100644 --- a/cnc/streamlit/commander_helper.py +++ b/cnc/streamlit/commander_helper.py @@ -5,15 +5,14 @@ # # SPDX-License-Identifier: GPL-2.0-only -import numpy as np -from gabriel_protocol import gabriel_pb2 -from gabriel_client.websocket_client import ProducerWrapper, WebsocketClient +import argparse import logging +import os + import zmq from cnc_protocol import cnc_pb2 -import argparse -import os -import asyncio +from gabriel_client.websocket_client import ProducerWrapper, WebsocketClient +from gabriel_protocol import gabriel_pb2 logger = logging.getLogger(__name__) @@ -49,7 +48,7 @@ async def producer(): extras = cnc_pb2.Extras() extras.commander_id = self.commander_id - if command != None: + if command is not None: extras.cmd.for_drone_id = command["drone_id"] if "kill" in command: extras.cmd.halt = True diff --git a/cnc/streamlit/overview.py b/cnc/streamlit/overview.py index 3f8797de..899b355e 100644 --- a/cnc/streamlit/overview.py +++ b/cnc/streamlit/overview.py @@ -2,17 +2,26 @@ # # SPDX-License-Identifier: GPL-2.0-only +import json import os import time -import json from zipfile import ZipFile -from cnc_protocol import cnc_pb2 + import folium import streamlit as st -from streamlit_folium import st_folium +from cnc_protocol import cnc_pb2 from folium.plugins import MiniMap -from util import stream_to_dataframe, connect_redis, connect_zmq, get_drones, menu, COLORS, authenticated from st_keypressed import st_keypressed +from streamlit_folium import st_folium +from util import ( + COLORS, + authenticated, + connect_redis, + connect_zmq, + get_drones, + menu, + stream_to_dataframe, +) if "map_server" not in st.session_state: st.session_state.map_server = "Google Hybrid" @@ -25,7 +34,7 @@ if "script_file" not in st.session_state: st.session_state.script_file = None if "inactivity_time" not in st.session_state: - st.session_state.inactivity_time = 1 #min + st.session_state.inactivity_time = 1 # min if "trail_length" not in st.session_state: st.session_state.trail_length = 500 if "armed" not in st.session_state: @@ -48,10 +57,10 @@ page_icon=":military_helmet:", layout="wide", menu_items={ - 'Get help': 'https://cmusatyalab.github.io/steeleagle/', - 'Report a bug': "https://github.com/cmusatyalab/steeleagle/issues", - 'About': "SteelEagle - Automated drone flights for visual inspection tasks\n https://github.com/cmusatyalab/steeleagle" - } + "Get help": "https://cmusatyalab.github.io/steeleagle/", + "Report a bug": "https://github.com/cmusatyalab/steeleagle/issues", + "About": "SteelEagle - Automated drone flights for visual inspection tasks\n https://github.com/cmusatyalab/steeleagle", + }, ) if "zmq" not in st.session_state: @@ -62,14 +71,13 @@ red = connect_redis() + def change_center(): if st.session_state.tracking_selection is not None: df = stream_to_dataframe( - red.xrevrange( - f"telemetry.{st.session_state.tracking_selection}", "+", "-", 1 - ) + red.xrevrange(f"telemetry.{st.session_state.tracking_selection}", "+", "-", 1) ) - for index, row in df.iterrows(): + for _index, row in df.iterrows(): st.session_state.center = [row["latitude"], row["longitude"]] @@ -79,7 +87,7 @@ def run_flightscript(): else: filename = f"{time.time_ns()}.ms" path = f"{st.secrets.scripts_path}/{filename}" - with ZipFile(path, 'w') as z: + with ZipFile(path, "w") as z: for file in st.session_state.script_file: z.writestr(file.name, file.read()) @@ -94,6 +102,7 @@ def run_flightscript(): icon="\u2601", ) + def enable_manual(): req = cnc_pb2.Extras() req.cmd.halt = True @@ -101,9 +110,8 @@ def enable_manual(): req.cmd.for_drone_id = json.dumps([d for d in st.session_state.selected_drones]) st.session_state.zmq.send(req.SerializeToString()) rep = st.session_state.zmq.recv() - st.toast( - f"Telling drone {req.cmd.for_drone_id} to halt! Kill signal sent." - ) + st.toast(f"Telling drone {req.cmd.for_drone_id} to halt! Kill signal sent.") + def rth(): req = cnc_pb2.Extras() @@ -115,6 +123,7 @@ def rth(): rep = st.session_state.zmq.recv() st.toast(f"Instructed {req.cmd.for_drone_id} to return to home!") + @st.fragment(run_every=f"{1/st.session_state.imagery_framerate}s") def update_imagery(): drone_list = [] @@ -123,8 +132,8 @@ def update_imagery(): hsv_header = "**:traffic_light: HSV Filtering**" for k in red.keys("telemetry.*"): df = stream_to_dataframe(red.xrevrange(f"{k}", "+", "-", st.session_state.trail_length)) - last_update = (int(df.index[0].split("-")[0])/1000) - if time.time() - last_update < st.session_state.inactivity_time * 60: # minutes -> seconds + last_update = int(df.index[0].split("-")[0]) / 1000 + if time.time() - last_update < st.session_state.inactivity_time * 60: # minutes -> seconds drone_name = k.split(".")[-1] drone_list.append(drone_name) drone_list.append(detected_header) @@ -132,18 +141,30 @@ def update_imagery(): drone_list.append(hsv_header) tabs = st.tabs(drone_list) - i = 0 - for d in drone_list: + for i, d in enumerate(drone_list): with tabs[i]: if d == detected_header: - st.image(f"http://{st.secrets.webserver}/detected/latest.jpg?a={time.time()}", use_container_width=True) + st.image( + f"http://{st.secrets.webserver}/detected/latest.jpg?a={time.time()}", + use_container_width=True, + ) elif d == avoidance_header: - st.image(f"http://{st.secrets.webserver}/moa/latest.jpg?a={time.time()}", use_container_width=True) + st.image( + f"http://{st.secrets.webserver}/moa/latest.jpg?a={time.time()}", + use_container_width=True, + ) elif d == hsv_header: - st.image(f"http://{st.secrets.webserver}/detected/hsv.jpg?a={time.time()}", use_container_width=True) + st.image( + f"http://{st.secrets.webserver}/detected/hsv.jpg?a={time.time()}", + use_container_width=True, + ) else: - st.image(f"http://{st.secrets.webserver}/raw/{d}/latest.jpg?a={time.time()}", use_container_width=True) - i += 1 + st.image( + f"http://{st.secrets.webserver}/raw/{d}/latest.jpg?a={time.time()}", + use_container_width=True, + ) + + @st.fragment(run_every="1s") def draw_map(): m = folium.Map( @@ -161,20 +182,19 @@ def draw_map(): marker_color = 0 for k in red.keys("telemetry.*"): df = stream_to_dataframe(red.xrevrange(f"{k}", "+", "-", st.session_state.trail_length)) - last_update = (int(df.index[0].split("-")[0])/1000) - if time.time() - last_update < st.session_state.inactivity_time * 60: # minutes -> seconds + last_update = int(df.index[0].split("-")[0]) / 1000 + if time.time() - last_update < st.session_state.inactivity_time * 60: # minutes -> seconds coords = [] - i = 0 drone_name = k.split(".")[-1] - for index, row in df.iterrows(): + for i, (_index, row) in enumerate(df.iterrows()): if i % 10 == 0: coords.append([row["latitude"], row["longitude"]]) if i == 0: text = folium.DivIcon( - icon_size="null", #set the size to null so that it expands to the length of the string inside in the div + icon_size="null", # set the size to null so that it expands to the length of the string inside in the div icon_anchor=(-20, 30), html=f'
{drone_name} ({int(row["battery"])}%) [{row["altitude"]:.2f}m]', - #TODO: concatenate current task to html once it is sent i.e. PatrolTask
+ # TODO: concatenate current task to html once it is sent i.e. PatrolTask ) plane = folium.Icon( icon="plane", @@ -203,8 +223,6 @@ def draw_map(): ) ) - i += 1 - ls = folium.PolyLine(locations=coords, color=COLORS[marker_color]) ls.add_to(tracks) marker_color += 1 @@ -217,9 +235,10 @@ def draw_map(): layer_control=lc, returned_objects=[], center=st.session_state.center, - height=500 + height=500, ) + menu() options_expander = st.expander(" **:gray-background[:wrench: Toolbar]**", expanded=True) @@ -227,10 +246,7 @@ def draw_map(): map_options = ["Google Sat", "Google Hybrid"] tiles_col = st.columns(5) tiles_col[0].selectbox( - key="map_server", - label=":world_map: **:blue[Tile Server]**", - options=map_options, - index=0 + key="map_server", label=":world_map: **:blue[Tile Server]**", options=map_options, index=0 ) tiles_col[1].selectbox( @@ -241,8 +257,13 @@ def draw_map(): placeholder="Select a drone to track...", ) - - tiles_col[2].number_input(":heartbeat: **:red[Active Threshold (min)]**", step=1, min_value=1, key="inactivity_time", max_value=600000) + tiles_col[2].number_input( + ":heartbeat: **:red[Active Threshold (min)]**", + step=1, + min_value=1, + key="inactivity_time", + max_value=600000, + ) if st.session_state.map_server == "Google Sat": tileset = "https://mt0.google.com/vt/lyrs=s&hl=en&x={x}&y={y}&z={z}&s=Ga" @@ -253,28 +274,47 @@ def draw_map(): name=st.session_state.map_server, tiles=tileset, attr="Google", max_zoom=20 ) - tiles_col[3].number_input(":straight_ruler: **:gray[Trail Length]**", step=500, min_value=500, max_value=2500, key="trail_length") - mode = "**:green-background[:joystick: Manual Control Enabled (armed)]**" if st.session_state.armed else "**:red-background[:joystick: Manual Control Disabled (disarmed)]**" - tiles_col[4].number_input(key = "imagery_framerate", label=":camera: **:orange[Imagery FPS]**", min_value=1, max_value=10, step=1, value=2, format="%0d") + tiles_col[3].number_input( + ":straight_ruler: **:gray[Trail Length]**", + step=500, + min_value=500, + max_value=2500, + key="trail_length", + ) + mode = ( + "**:green-background[:joystick: Manual Control Enabled (armed)]**" + if st.session_state.armed + else "**:red-background[:joystick: Manual Control Disabled (disarmed)]**" + ) + tiles_col[4].number_input( + key="imagery_framerate", + label=":camera: **:orange[Imagery FPS]**", + min_value=1, + max_value=10, + step=1, + value=2, + format="%0d", + ) col1, col2 = st.columns([0.6, 0.4]) with col1: update_imagery() with col2: - st.caption("**:blue-background[:globe_with_meridians: Flight Tracking]**") - draw_map() + st.caption("**:blue-background[:globe_with_meridians: Flight Tracking]**") + draw_map() with st.sidebar: drone_list = get_drones() if len(drone_list) > 0: - st.pills(label=":helicopter: **:orange[Swarm Control]** :helicopter:", + st.pills( + label=":helicopter: **:orange[Swarm Control]** :helicopter:", options=drone_list.keys(), default=drone_list.keys(), format_func=lambda option: drone_list[option], selection_mode="multi", - key="selected_drones" - ) + key="selected_drones", + ) else: st.caption("No active drones.") @@ -287,8 +327,8 @@ def draw_map(): label="**:violet[Upload Autonomous Mission Script]**", help="Upload a flight script.", type=["kml", "dsl"], - label_visibility='visible', - accept_multiple_files=True + label_visibility="visible", + accept_multiple_files=True, ) st.button( key="autonomous_button", @@ -317,19 +357,59 @@ def draw_map(): if st.session_state.armed and len(st.session_state.selected_drones) > 0: c1, c2 = st.columns(spec=2, gap="small") - c1.number_input(key="pitch_speed", label="Pitch %", min_value=0, max_value=100, value=50, step=5, format="%d") - c2.number_input(key = "thrust_speed", label="Thrust %", min_value=0, max_value=100, step=5, value=50, format="%d") + c1.number_input( + key="pitch_speed", + label="Pitch %", + min_value=0, + max_value=100, + value=50, + step=5, + format="%d", + ) + c2.number_input( + key="thrust_speed", + label="Thrust %", + min_value=0, + max_value=100, + step=5, + value=50, + format="%d", + ) c3, c4 = st.columns(spec=2, gap="small") - c3.number_input(key = "yaw_speed", label="Yaw %", min_value=0, max_value=100, step=5, value=50, format="%d") - c4.number_input(key = "roll_speed", label="Roll %", min_value=0, max_value=100, step=5, value=50, format="%d") + c3.number_input( + key="yaw_speed", + label="Yaw %", + min_value=0, + max_value=100, + step=5, + value=50, + format="%d", + ) + c4.number_input( + key="roll_speed", + label="Roll %", + min_value=0, + max_value=100, + step=5, + value=50, + format="%d", + ) c5, c6 = st.columns(spec=2, gap="small") - c5.number_input(key = "gimbal_speed", label="Gimbal Pitch %", min_value=0, max_value=100, step=5, value=50, format="%d") + c5.number_input( + key="gimbal_speed", + label="Gimbal Pitch %", + min_value=0, + max_value=100, + step=5, + value=50, + format="%d", + ) key_pressed = st_keypressed() req = cnc_pb2.Extras() req.commander_id = os.uname()[1] req.cmd.for_drone_id = json.dumps([d for d in st.session_state.selected_drones]) - #req.cmd.manual = True + # req.cmd.manual = True st.caption(f"keypressed={key_pressed}") if key_pressed == "t": req.cmd.takeoff = True @@ -359,7 +439,9 @@ def draw_map(): gimbal_pitch = 1 * st.session_state.gimbal_speed elif key_pressed == "f": gimbal_pitch = -1 * st.session_state.gimbal_speed - st.caption(f"(pitch = {pitch}, roll = {roll}, yaw = {yaw}, thrust = {thrust}, gimbal = {gimbal_pitch})") + st.caption( + f"(pitch = {pitch}, roll = {roll}, yaw = {yaw}, thrust = {thrust}, gimbal = {gimbal_pitch})" + ) req.cmd.pcmd.yaw = yaw req.cmd.pcmd.pitch = pitch req.cmd.pcmd.roll = roll diff --git a/cnc/streamlit/pages/control.py b/cnc/streamlit/pages/control.py index b38bde76..3d8e95d4 100644 --- a/cnc/streamlit/pages/control.py +++ b/cnc/streamlit/pages/control.py @@ -2,28 +2,37 @@ # # SPDX-License-Identifier: GPL-2.0-only -import streamlit as st import asyncio import json -from zipfile import ZipFile -from st_keypressed import st_keypressed import os -from cnc_protocol import cnc_pb2 import time +from zipfile import ZipFile + import folium -from streamlit_folium import st_folium +import streamlit as st +from cnc_protocol import cnc_pb2 from folium.plugins import MiniMap -from util import stream_to_dataframe, get_drones, connect_redis, connect_zmq, menu, connect_redis_publisher, COLORS, authenticated +from st_keypressed import st_keypressed +from streamlit_folium import st_folium +from util import ( + COLORS, + authenticated, + connect_redis, + connect_redis_publisher, + connect_zmq, + menu, + stream_to_dataframe, +) st.set_page_config( page_title="Commander", page_icon=":helicopter:", layout="wide", menu_items={ - 'Get help': 'https://cmusatyalab.github.io/steeleagle/', - 'Report a bug': "https://github.com/cmusatyalab/steeleagle/issues", - 'About': "SteelEagle - Automated drone flights for visual inspection tasks\n https://github.com/cmusatyalab/steeleagle" - } + "Get help": "https://cmusatyalab.github.io/steeleagle/", + "Report a bug": "https://github.com/cmusatyalab/steeleagle/issues", + "About": "SteelEagle - Automated drone flights for visual inspection tasks\n https://github.com/cmusatyalab/steeleagle", + }, ) if "armed" not in st.session_state: @@ -73,14 +82,32 @@ if not authenticated(): st.stop() # Do not continue if not authenticated -async def update(live, avoidance, detection, hsv, status,): + +async def update( + live, + avoidance, + detection, + hsv, + status, +): try: while True: - live.image(f"http://{st.secrets.webserver}/raw/{st.session_state.selected_drone}/latest.jpg?a={time.time()}", use_column_width="auto") - avoidance.image(f"http://{st.secrets.webserver}/moa/latest.jpg?a={time.time()}", use_column_width="auto") - detection.image(f"http://{st.secrets.webserver}/detected/latest.jpg?a={time.time()}", use_column_width="auto") - hsv.image(f"http://{st.secrets.webserver}/detected/hsv.jpg?a={time.time()}", use_column_width="auto") - + live.image( + f"http://{st.secrets.webserver}/raw/{st.session_state.selected_drone}/latest.jpg?a={time.time()}", + use_column_width="auto", + ) + avoidance.image( + f"http://{st.secrets.webserver}/moa/latest.jpg?a={time.time()}", + use_column_width="auto", + ) + detection.image( + f"http://{st.secrets.webserver}/detected/latest.jpg?a={time.time()}", + use_column_width="auto", + ) + hsv.image( + f"http://{st.secrets.webserver}/detected/hsv.jpg?a={time.time()}", + use_column_width="auto", + ) # message = st.session_state.subscriber.get_message() # if message is not None: @@ -107,34 +134,57 @@ async def update(live, avoidance, detection, hsv, status,): min_value=0, max_value=100, ), - "mag": st.column_config.CheckboxColumn(label="Mag", width="small"), + "mag": st.column_config.CheckboxColumn(label="Mag", width="small"), "bearing": st.column_config.NumberColumn( "Heading", format="%d°", - ), + ), } - order = ("altitude", "bearing", "battery", "mag", "rssi",) + order = ( + "altitude", + "bearing", + "battery", + "mag", + "rssi", + ) - st.session_state.telemetry = stream_to_dataframe(st.session_state.redis.xrevrange(f"telemetry.{st.session_state.selected_drone}", "+", "-", 1)) - st.session_state.telemetry["latitude"] = st.session_state.telemetry["latitude"].clip(-90, 90) - st.session_state.telemetry["longitude"] = st.session_state.telemetry["longitude"].clip(-180, 180) - st.session_state.telemetry["mag"] = st.session_state.telemetry["mag"].transform(lambda x: x == 0) - status.dataframe(st.session_state.telemetry, hide_index=False, use_container_width=True, column_order=order, column_config=columns) - #map_container.map(data=st.session_state.telemetry, use_container_width=True, zoom=16, size=1) + st.session_state.telemetry = stream_to_dataframe( + st.session_state.redis.xrevrange( + f"telemetry.{st.session_state.selected_drone}", "+", "-", 1 + ) + ) + st.session_state.telemetry["latitude"] = st.session_state.telemetry["latitude"].clip( + -90, 90 + ) + st.session_state.telemetry["longitude"] = st.session_state.telemetry["longitude"].clip( + -180, 180 + ) + st.session_state.telemetry["mag"] = st.session_state.telemetry["mag"].transform( + lambda x: x == 0 + ) + status.dataframe( + st.session_state.telemetry, + hide_index=False, + use_container_width=True, + column_order=order, + column_config=columns, + ) + # map_container.map(data=st.session_state.telemetry, use_container_width=True, zoom=16, size=1) - await asyncio.sleep(1/st.session_state.imagery_framerate) + await asyncio.sleep(1 / st.session_state.imagery_framerate) except asyncio.CancelledError: st.write("Update coroutine canceled.") + def run_flightscript(): if len(st.session_state.script_file) == 0: st.toast("You haven't uploaded a script yet!", icon="🚨") else: filename = f"{time.time_ns()}.ms" path = f"{st.secrets.scripts_path}/{filename}" - with ZipFile(path, 'w') as z: + with ZipFile(path, "w") as z: for file in st.session_state.script_file: z.writestr(file.name, file.read()) @@ -149,6 +199,7 @@ def run_flightscript(): icon="\u2601", ) + def enable_manual(): st.session_state.rth_sent = False st.session_state.autonomous = False @@ -159,9 +210,8 @@ def enable_manual(): req.cmd.for_drone_id = json.dumps([st.session_state.selected_drone]) st.session_state.zmq.send(req.SerializeToString()) rep = st.session_state.zmq.recv() - st.toast( - f"Assuming manual control of {st.session_state.selected_drone}! Kill signal sent." - ) + st.toast(f"Assuming manual control of {st.session_state.selected_drone}! Kill signal sent.") + def rth(): st.session_state.manual_control = False @@ -176,12 +226,11 @@ def rth(): rep = st.session_state.zmq.recv() st.toast(f"Instructed {st.session_state.selected_drone} to return to home!") + @st.fragment(run_every="1s") def draw_map(): tileset = "https://mt0.google.com/vt/lyrs=y&hl=en&x={x}&y={y}&z={z}&s=Ga" - tiles = folium.TileLayer( - name="map_tileserver", tiles=tileset, attr="Google", max_zoom=20 - ) + tiles = folium.TileLayer(name="map_tileserver", tiles=tileset, attr="Google", max_zoom=20) m = folium.Map( location=[40.415428612484924, -79.95028831875038], @@ -193,12 +242,16 @@ def draw_map(): fg = folium.FeatureGroup(name="Current Location") marker_color = 0 - df = stream_to_dataframe(st.session_state.redis .xrevrange(f"telemetry.{st.session_state.selected_drone}", "+", "-", 1)) - last_update = (int(df.index[0].split("-")[0])/1000) + df = stream_to_dataframe( + st.session_state.redis.xrevrange( + f"telemetry.{st.session_state.selected_drone}", "+", "-", 1 + ) + ) + last_update = int(df.index[0].split("-")[0]) / 1000 i = 0 - for index, row in df.iterrows(): + for _index, row in df.iterrows(): text = folium.DivIcon( - icon_size="null", #set the size to null so that it expands to the length of the string inside in the div + icon_size="null", # set the size to null so that it expands to the length of the string inside in the div icon_anchor=(-20, 30), html=f'
{st.session_state.selected_drone}
', ) @@ -209,7 +262,7 @@ def draw_map(): angle=int(row["bearing"]), ) html = f'' - st.session_state.center =[row['latitude'], row['longitude']] + st.session_state.center = [row["latitude"], row["longitude"]] fg.add_child( folium.Marker( location=[ @@ -239,9 +292,10 @@ def draw_map(): feature_group_to_add=fg, returned_objects=[], center=st.session_state.center, - height=500 + height=500, ) + menu(with_control=False) with st.sidebar: @@ -257,8 +311,8 @@ def draw_map(): label="**:violet[Upload Autonomous Mission Script]**", help="Upload a flight script.", type=["kml", "dsl"], - label_visibility='visible', - accept_multiple_files=True + label_visibility="visible", + accept_multiple_files=True, ) st.button( key="autonomous_button", @@ -268,7 +322,7 @@ def draw_map(): use_container_width=True, on_click=run_flightscript, ) - # st.divider() + # st.divider() st.button( key="manual_button", label=":joystick: Manual Control", @@ -288,35 +342,85 @@ def draw_map(): on_click=rth, ) if st.session_state.manual_control: - #st.subheader(f":blue[Manual Control Enabled]") + # st.subheader(f":blue[Manual Control Enabled]") mode = ":green[Manual (armed)]" if st.session_state.armed else ":red[Manual (disarmed)]" st.checkbox(key="armed", label="Arm Drone?") c1, c2 = st.columns(spec=2, gap="small") - c1.number_input(key="pitch_speed", label="Pitch %", min_value=0, max_value=100, value=50, step=5, format="%d") - c2.number_input(key = "thrust_speed", label="Thrust %", min_value=0, max_value=100, step=5, value=50, format="%d") + c1.number_input( + key="pitch_speed", + label="Pitch %", + min_value=0, + max_value=100, + value=50, + step=5, + format="%d", + ) + c2.number_input( + key="thrust_speed", + label="Thrust %", + min_value=0, + max_value=100, + step=5, + value=50, + format="%d", + ) c3, c4 = st.columns(spec=2, gap="small") - c3.number_input(key = "yaw_speed", label="Yaw %", min_value=0, max_value=100, step=5, value=50, format="%d") - c4.number_input(key = "roll_speed", label="Roll %", min_value=0, max_value=100, step=5, value=50, format="%d") + c3.number_input( + key="yaw_speed", + label="Yaw %", + min_value=0, + max_value=100, + step=5, + value=50, + format="%d", + ) + c4.number_input( + key="roll_speed", + label="Roll %", + min_value=0, + max_value=100, + step=5, + value=50, + format="%d", + ) c5, c6 = st.columns(spec=2, gap="small") - c5.number_input(key = "gimbal_speed", label="Gimbal Pitch %", min_value=0, max_value=100, step=5, value=50, format="%d") - c6.number_input(key = "imagery_framerate", label="Imagery Framerate", min_value=1, max_value=30, step=1, value=2, format="%0d") + c5.number_input( + key="gimbal_speed", + label="Gimbal Pitch %", + min_value=0, + max_value=100, + step=5, + value=50, + format="%d", + ) + c6.number_input( + key="imagery_framerate", + label="Imagery Framerate", + min_value=1, + max_value=30, + step=1, + value=2, + format="%0d", + ) elif st.session_state.rth_sent: - mode = f":orange[Return to Home Initiated]" + mode = ":orange[Return to Home Initiated]" elif st.session_state.script_file is not None: - mode = f":violet[Autonomous Mode Enabled]" + mode = ":violet[Autonomous Mode Enabled]" status_container, imagery_container = st.columns(spec=[2, 3], gap="large") with status_container: draw_map() - #st.session_state.subscriber.punsubscribe() - #st.session_state.subscriber.psubscribe(f'imagery.{st.session_state.selected_drone}') - - st.subheader(f":blue[{st.session_state.selected_drone}] Status - {mode}" - if st.session_state.selected_drone is not None else ":red[No Drone Connected]", - divider="gray", - ) + # st.session_state.subscriber.punsubscribe() + # st.session_state.subscriber.psubscribe(f'imagery.{st.session_state.selected_drone}') + + st.subheader( + f":blue[{st.session_state.selected_drone}] Status - {mode}" + if st.session_state.selected_drone is not None + else ":red[No Drone Connected]", + divider="gray", + ) status_container = st.empty() with imagery_container: @@ -333,11 +437,15 @@ def draw_map(): st.markdown(":traffic_light: **HSV Filtering**") st.session_state.key_pressed = st_keypressed() -if st.session_state.armed and st.session_state.manual_control and st.session_state.selected_drone is not None: +if ( + st.session_state.armed + and st.session_state.manual_control + and st.session_state.selected_drone is not None +): req = cnc_pb2.Extras() req.commander_id = os.uname()[1] req.cmd.for_drone_id = json.dumps([st.session_state.selected_drone]) - #req.cmd.manual = True + # req.cmd.manual = True if st.session_state.key_pressed == "t": req.cmd.takeoff = True st.info(f"Instructed {st.session_state.selected_drone} to takeoff.") @@ -366,7 +474,7 @@ def draw_map(): gimbal_pitch = 1 * st.session_state.gimbal_speed elif st.session_state.key_pressed == "f": gimbal_pitch = -1 * st.session_state.gimbal_speed - #st.toast(f"PCMD(pitch = {pitch}, roll = {roll}, yaw = {yaw}, thrust = {thrust})") + # st.toast(f"PCMD(pitch = {pitch}, roll = {roll}, yaw = {yaw}, thrust = {thrust})") req.cmd.pcmd.yaw = yaw req.cmd.pcmd.pitch = pitch req.cmd.pcmd.roll = roll @@ -376,4 +484,12 @@ def draw_map(): st.session_state.zmq.send(req.SerializeToString()) rep = st.session_state.zmq.recv() -asyncio.run(update(livefeed_container, avoidance, detection, hsv, status_container,)) +asyncio.run( + update( + livefeed_container, + avoidance, + detection, + hsv, + status_container, + ) +) diff --git a/cnc/streamlit/pages/plan.py b/cnc/streamlit/pages/plan.py index b191a0d8..333921de 100644 --- a/cnc/streamlit/pages/plan.py +++ b/cnc/streamlit/pages/plan.py @@ -3,17 +3,16 @@ # SPDX-License-Identifier: GPL-2.0-only import streamlit as st -from streamlit_ace import st_ace -from util import menu, authenticated import streamlit.components.v1 as components -from streamlit_oauth import OAuth2Component from google.auth.transport.requests import Request from google.oauth2.credentials import Credentials - from googleapiclient.discovery import build from googleapiclient.errors import HttpError +from streamlit_ace import st_ace +from streamlit_oauth import OAuth2Component +from util import authenticated, menu -sample="""Task { +sample = """Task { Detect patrol_route { way_points: , gimbal_pitch: -30.0, @@ -50,28 +49,27 @@ page_icon=":military_helmet:", layout="wide", menu_items={ - 'Get help': 'https://cmusatyalab.github.io/steeleagle/', - 'Report a bug': "https://github.com/cmusatyalab/steeleagle/issues", - 'About': "SteelEagle - Automated drone flights for visual inspection tasks\n https://github.com/cmusatyalab/steeleagle" - } + "Get help": "https://cmusatyalab.github.io/steeleagle/", + "Report a bug": "https://github.com/cmusatyalab/steeleagle/issues", + "About": "SteelEagle - Automated drone flights for visual inspection tasks\n https://github.com/cmusatyalab/steeleagle", + }, ) if not authenticated(): st.stop() # Do not continue if not authenticated - def fetch_mymaps(): creds = Credentials( - st.session_state.auth['token']['access_token'], - refresh_token=st.session_state.auth['token']['refresh_token'], + st.session_state.auth["token"]["access_token"], + refresh_token=st.session_state.auth["token"]["refresh_token"], token_uri=st.secrets.oauth.token_endpoint, client_id=st.secrets.oauth.client_id, - client_secret=st.secrets.oauth.client_secret) + client_secret=st.secrets.oauth.client_secret, + ) - if not creds or not creds.valid: - if creds and creds.expired and creds.refresh_token: - creds.refresh(Request()) + if creds and not creds.valid and creds.expired and creds.refresh_token: + creds.refresh(Request()) try: service = build("drive", "v3", credentials=creds) @@ -88,7 +86,7 @@ def fetch_mymaps(): items = results.get("files", []) for file in items: - st.session_state.file_list[file['id']] = file['name'] + st.session_state.file_list[file["id"]] = file["name"] except HttpError as error: st.error(f"An error occurred fetching files from Google Drive: {error}") @@ -98,12 +96,14 @@ def fetch_mymaps(): with c1: if "auth" not in st.session_state: # create a button to start the OAuth2 flow - oauth2 = OAuth2Component(st.secrets.oauth.client_id, - st.secrets.oauth.client_secret, - st.secrets.oauth.auth_endpoint, - st.secrets.oauth.token_endpoint, - st.secrets.oauth.token_endpoint, - st.secrets.oauth.revoke_endpoint) + oauth2 = OAuth2Component( + st.secrets.oauth.client_id, + st.secrets.oauth.client_secret, + st.secrets.oauth.auth_endpoint, + st.secrets.oauth.token_endpoint, + st.secrets.oauth.token_endpoint, + st.secrets.oauth.revoke_endpoint, + ) result = oauth2.authorize_button( name="Log in to Google Drive", icon="https://www.google.com/favicon.ico", @@ -112,7 +112,7 @@ def fetch_mymaps(): key="google", extras_params={"prompt": "consent", "access_type": "offline"}, use_container_width=True, - pkce='S256', + pkce="S256", ) if result: @@ -121,18 +121,32 @@ def fetch_mymaps(): else: fetch_mymaps() - st.selectbox(label=":world_map: **:blue[Load Map]**", options=st.session_state.file_list.keys(), format_func=lambda option: st.session_state.file_list[option], key="map_id", placeholder="Select a map to load from MyMaps...") - components.iframe(f"https://www.google.com/maps/d/u/0/embed?mid={st.session_state.map_id}", height=600, scrolling=True) - st.link_button(label="Download KML", type="primary", url=f"https://www.google.com/maps/d/kml?mid={st.session_state.map_id}&forcekml=1") - st.link_button(label="Edit in MyMaps", type="primary", url=f"https://www.google.com/maps/d/u/0/edit?mid={st.session_state.map_id}") + st.selectbox( + label=":world_map: **:blue[Load Map]**", + options=st.session_state.file_list.keys(), + format_func=lambda option: st.session_state.file_list[option], + key="map_id", + placeholder="Select a map to load from MyMaps...", + ) + components.iframe( + f"https://www.google.com/maps/d/u/0/embed?mid={st.session_state.map_id}", + height=600, + scrolling=True, + ) + st.link_button( + label="Download KML", + type="primary", + url=f"https://www.google.com/maps/d/kml?mid={st.session_state.map_id}&forcekml=1", + ) + st.link_button( + label="Edit in MyMaps", + type="primary", + url=f"https://www.google.com/maps/d/u/0/edit?mid={st.session_state.map_id}", + ) with c2: st.subheader(":clipboard: **:green[Edit Mission Script]**", divider="gray") dsl = st_ace(height=600, value=sample, language="yaml") st.download_button( - label="Download Mission File", - data=dsl, - file_name="mission.dsl", - type="primary" + label="Download Mission File", data=dsl, file_name="mission.dsl", type="primary" ) - diff --git a/cnc/streamlit/util.py b/cnc/streamlit/util.py index 51f6a229..e8129873 100644 --- a/cnc/streamlit/util.py +++ b/cnc/streamlit/util.py @@ -2,13 +2,14 @@ # # SPDX-License-Identifier: GPL-2.0-only -import streamlit as st -import redis -import pandas as pd +import hmac import json -import zmq import time -import hmac + +import pandas as pd +import redis +import streamlit as st +import zmq DATA_TYPES = { "latitude": "float", @@ -45,7 +46,8 @@ if "control_pressed" not in st.session_state: st.session_state.control_pressed = False if "inactivity_time" not in st.session_state: - st.session_state.inactivity_time = 1 #min + st.session_state.inactivity_time = 1 # min + def authenticated(): """Returns `True` if the user had the correct password.""" @@ -63,14 +65,18 @@ def password_entered(): return True # Show input for password. - a,b,c = st.columns(3) + a, b, c = st.columns(3) b.text_input( - "Password", type="password", on_change=password_entered, key="password", + "Password", + type="password", + on_change=password_entered, + key="password", ) if "password_correct" in st.session_state: b.error("Authentication failed.", icon=":material/block:") return False + @st.cache_resource def connect_redis(): red = redis.Redis( @@ -82,6 +88,7 @@ def connect_redis(): ) return red + @st.cache_resource def connect_redis_publisher(): red = redis.Redis( @@ -93,11 +100,12 @@ def connect_redis_publisher(): subscriber = red.pubsub(ignore_subscribe_messages=True) return subscriber + @st.cache_resource def connect_zmq(): ctx = zmq.Context() z = ctx.socket(zmq.REQ) - z.connect(f'tcp://{st.secrets.zmq}:{st.secrets.zmq_port}') + z.connect(f"tcp://{st.secrets.zmq}:{st.secrets.zmq_port}") return z @@ -106,23 +114,27 @@ def get_drones(): red = connect_redis() for k in red.keys("telemetry.*"): latest_entry = red.xrevrange(f"{k}", "+", "-", 1) - last_update = (int(latest_entry[0][0].split("-")[0])/1000) - if time.time() - last_update < st.session_state.inactivity_time * 60: # minutes -> seconds - l[f"{k.split('.')[-1]}"] = f"**{k.split('.')[-1]}** " #TODO: add :material/abc:[drone model] once it is sent with telemetry + last_update = int(latest_entry[0][0].split("-")[0]) / 1000 + if time.time() - last_update < st.session_state.inactivity_time * 60: # minutes -> seconds + l[f"{k.split('.')[-1]}"] = ( + f"**{k.split('.')[-1]}** " # TODO: add :material/abc:[drone model] once it is sent with telemetry + ) return l -def stream_to_dataframe(results, types=DATA_TYPES ) -> pd.DataFrame: + +def stream_to_dataframe(results, types=DATA_TYPES) -> pd.DataFrame: _container = {} for item in results: _container[item[0]] = json.loads(json.dumps(item[1])) - df = pd.DataFrame.from_dict(_container, orient='index') + df = pd.DataFrame.from_dict(_container, orient="index") if types is not None: df = df.astype(types) return df + def control_drone(drone): st.session_state.selected_drone = drone st.session_state.control_pressed = True @@ -131,6 +143,3 @@ def control_drone(drone): def menu(with_control=True): st.sidebar.page_link("overview.py", label=":satellite_antenna: Overview") st.sidebar.page_link("pages/plan.py", label=":ledger: Mission Planning") - - - diff --git a/droneDSL/python/interface/Task.py b/droneDSL/python/interface/Task.py index 8ac96b43..5f44a42c 100644 --- a/droneDSL/python/interface/Task.py +++ b/droneDSL/python/interface/Task.py @@ -2,36 +2,39 @@ # # SPDX-License-Identifier: GPL-2.0-only -from abc import ABC, abstractmethod import functools import logging import threading +from abc import ABC, abstractmethod + from aenum import Enum logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) + class TaskType(Enum): Detect = 1 Track = 2 Avoid = 3 Test = 4 -class TaskArguments(): + +class TaskArguments: def __init__(self, task_type, transitions_attributes, task_attributes): self.task_type = task_type self.task_attributes = task_attributes self.transitions_attributes = transitions_attributes - -class Task(ABC): + +class Task(ABC): def __init__(self, drone, cloudlet, task_id, trigger_event_queue, task_args): self.cloudlet = cloudlet self.drone = drone self.task_attributes = task_args.task_attributes self.transitions_attributes = task_args.transitions_attributes self.task_id = task_id - self.trans_active = [] + self.trans_active = [] self.trans_active_lock = threading.Lock() self.trigger_event_queue = trigger_event_queue @@ -42,25 +45,24 @@ async def run(self): def get_task_id(self): return self.task_id - def _exit(self): # kill all the transitions - logger.info(f"**************exit the task**************\n") + logger.info("**************exit the task**************\n") self.stop_trans() - self.trigger_event_queue.put((self.task_id, "done")) - + self.trigger_event_queue.put((self.task_id, "done")) + def stop_trans(self): - logger.info(f"**************stopping the transitions**************\n") + logger.info("**************stopping the transitions**************\n") for trans in self.trans_active: if trans.is_alive(): trans.stop() trans.join() - logger.info(f"**************the transitions stopped**************\n") - - + logger.info("**************the transitions stopped**************\n") + @classmethod def call_after_exit(cls, func): """Decorator to call _exit after the decorated function completes.""" + @functools.wraps(func) async def wrapper(self, *args, **kwargs): try: @@ -72,9 +74,11 @@ async def wrapper(self, *args, **kwargs): self._exit() return wrapper - + + @abstractmethod def pause(self): pass - + + @abstractmethod def resume(self): pass diff --git a/droneDSL/python/interface/Transition.py b/droneDSL/python/interface/Transition.py index a7978905..df90e7e5 100644 --- a/droneDSL/python/interface/Transition.py +++ b/droneDSL/python/interface/Transition.py @@ -1,37 +1,38 @@ -from abc import ABC, abstractmethod import logging import threading - +from abc import ABC, abstractmethod logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) + class Transition(threading.Thread, ABC): def __init__(self, args): super().__init__() - self.task_id = args['task_id'] - self.trans_active = args['trans_active'] - self.trans_active_lock = args['trans_active_lock'] - self.trigger_event_queue = args['trigger_event_queue'] + self.task_id = args["task_id"] + self.trans_active = args["trans_active"] + self.trans_active_lock = args["trans_active_lock"] + self.trigger_event_queue = args["trigger_event_queue"] # self.trigger_event_queue_lock = trigger_event_queue_lock - + @abstractmethod def stop(self): """This is an abstract method that must be implemented in a subclass.""" pass - + def _trigger_event(self, event): - logger.info(f"**************task id {self.task_id}: triggered event! {event}**************\n") + logger.info( + f"**************task id {self.task_id}: triggered event! {event}**************\n" + ) # with self.trigger_event_queue_lock: - self.trigger_event_queue.put((self.task_id, event)) - + self.trigger_event_queue.put((self.task_id, event)) + def _register(self): logger.info(f"**************{self.name} is registering by itself**************\n") with self.trans_active_lock: self.trans_active.append(self) - + def _unregister(self): logger.info(f"**************{self.name} is unregistering by itself**************\n") with self.trans_active_lock: self.trans_active.remove(self) - \ No newline at end of file diff --git a/droneDSL/python/project/task_defs/AvoidTask.py b/droneDSL/python/project/task_defs/AvoidTask.py index 1be64428..21d25b39 100644 --- a/droneDSL/python/project/task_defs/AvoidTask.py +++ b/droneDSL/python/project/task_defs/AvoidTask.py @@ -2,22 +2,23 @@ # # SPDX-License-Identifier: GPL-2.0-only -#from interfaces.Task import Task -import json -from json import JSONDecodeError -import time +# from interfaces.Task import Task import asyncio +import json import logging +import time +from json import JSONDecodeError + from gabriel_protocol import gabriel_pb2 -from ..transition_defs.TimerTransition import TimerTransition from interface.Task import Task +from ..transition_defs.TimerTransition import TimerTransition + logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) -class AvoidTask(Task): - +class AvoidTask(Task): def __init__(self, drone, cloudlet, task_id, trigger_event_queue, task_args): super().__init__(drone, cloudlet, task_id, trigger_event_queue, task_args) self.drone = drone @@ -27,23 +28,23 @@ def __init__(self, drone, cloudlet, task_id, trigger_event_queue, task_args): self.time_prev = None self.error_prev = 0 self.setpt = [0.0, 0.0] - self.roll_pid_info = {"constants" : {"Kp": 5.0, "Ki": 0.01, "Kd": 6.0}, "saved" : {"I": 0.0}} - self.pitch_pid_info = {"constants" : {"Kp": 5.0, "Ki": 0.01, "Kd": 6.0}, "saved" : {"I": 0.0}} - self.forwardspeed = 1.5 + self.roll_pid_info = {"constants": {"Kp": 5.0, "Ki": 0.01, "Kd": 6.0}, "saved": {"I": 0.0}} + self.pitch_pid_info = {"constants": {"Kp": 5.0, "Ki": 0.01, "Kd": 6.0}, "saved": {"I": 0.0}} + self.forwardspeed = 1.5 self.horizontalspeed = 1 self.oscillations = 0 def create_transition(self): logger.info(self.transitions_attributes) args = { - 'task_id': self.task_id, - 'trans_active': self.trans_active, - 'trans_active_lock': self.trans_active_lock, - 'trigger_event_queue': self.trigger_event_queue + "task_id": self.task_id, + "trans_active": self.trans_active, + "trans_active_lock": self.trans_active_lock, + "trigger_event_queue": self.trigger_event_queue, } - + # Triggered event - if ("timeout" in self.transitions_attributes): + if "timeout" in self.transitions_attributes: timer = TimerTransition(args, self.transitions_attributes["timeout"]) timer.daemon = True timer.start() @@ -61,7 +62,7 @@ async def computeError(self): async def moveForwardAndAvoid(self, error): ts = round(time.time() * 1000) if self.time_prev is None or (ts - self.time_prev) > 1000: - self.time_prev = ts - 1 # Do this to prevent a divide by zero error! + self.time_prev = ts - 1 # Do this to prevent a divide by zero error! self.error_prev = error # Roll control loop @@ -73,10 +74,14 @@ async def moveForwardAndAvoid(self, error): self.roll_pid_info["saved"]["I"] = 0 else: self.roll_pid_info["saved"]["I"] += self.clamp(Ir, -100.0, 100.0) - Dr = self.roll_pid_info["constants"]["Kd"] * (error[0] - self.error_prev[0]) / (ts - self.time_prev) - + Dr = ( + self.roll_pid_info["constants"]["Kd"] + * (error[0] - self.error_prev[0]) + / (ts - self.time_prev) + ) + roll = self.clamp(int(Pr + Ir + Dr), -100, 100) - + # Pitch control loop Pp = self.pitch_pid_info["constants"]["Kp"] * error[1] Ip = self.pitch_pid_info["constants"]["Ki"] * (ts - self.time_prev) @@ -86,8 +91,12 @@ async def moveForwardAndAvoid(self, error): self.pitch_pid_info["saved"]["I"] = 0 else: self.pitch_pid_info["saved"]["I"] += self.clamp(Ip, -100.0, 100.0) - Dp = self.pitch_pid_info["constants"]["Kd"] * (error[1] - self.error_prev[1]) / (ts - self.time_prev) - + Dp = ( + self.pitch_pid_info["constants"]["Kd"] + * (error[1] - self.error_prev[1]) + / (ts - self.time_prev) + ) + pitch = self.clamp(int(Pp + Ip + Dp), -100, 100) self.time_prev = ts @@ -99,8 +108,10 @@ async def moveForwardAndAvoid(self, error): def setPoint(self, error): # Calculate horizontal error newpt = error * self.horizontalspeed - if newpt * self.setpt[0] < 0 and abs(self.setpt[0] - newpt) > 0.5 and self.oscillations < 3: # Check if they have different signs - self.oscillations += 1 + if ( + newpt * self.setpt[0] < 0 and abs(self.setpt[0] - newpt) > 0.5 and self.oscillations < 3 + ): # Check if they have different signs + self.oscillations += 1 else: self.oscillations = 0 self.setpt[0] = error * self.horizontalspeed @@ -121,22 +132,21 @@ async def run(self): try: logger.info(f"[ObstacleTask] result: {result}") if result is not None and result.payload_type == gabriel_pb2.TEXT: - json_string = result.payload.decode('utf-8') + json_string = result.payload.decode("utf-8") json_data = json.loads(json_string) logger.info("[ObstacleTask] Decoded results") - offset = json_data[0]['vector'] + offset = json_data[0]["vector"] self.setPoint(offset) logger.info(f"[ObstacleTask] Set point {self.setpt}") error = await self.computeError() logger.info(f"[ObstacleTask] Error {error}") await self.moveForwardAndAvoid(error) - except JSONDecodeError as e: - logger.error(f"[ObstacleTask]: Error decoding JSON") + except JSONDecodeError: + logger.error("[ObstacleTask]: Error decoding JSON") except Exception as e: - logger.error(f"[ObstacleTask] Threw an exception") + logger.error("[ObstacleTask] Threw an exception") logger.error(e) await asyncio.sleep(0.1) except Exception as e: logger.info(f"[ObstacleTask] Task failed with exception {e}") await self.drone.hover() - diff --git a/droneDSL/python/project/task_defs/DetectTask.py b/droneDSL/python/project/task_defs/DetectTask.py index d14569ff..cb89fdc3 100644 --- a/droneDSL/python/project/task_defs/DetectTask.py +++ b/droneDSL/python/project/task_defs/DetectTask.py @@ -1,65 +1,79 @@ - -from ..transition_defs.ObjectDetectionTransition import ObjectDetectionTransition -from ..transition_defs.TimerTransition import TimerTransition -from ..transition_defs.HSVDetectionTransition import HSVDetectionTransition -from interface.Task import Task -import asyncio import ast +import asyncio import logging -from gabriel_protocol import gabriel_pb2 +from interface.Task import Task + +from ..transition_defs.HSVDetectionTransition import HSVDetectionTransition +from ..transition_defs.ObjectDetectionTransition import ObjectDetectionTransition +from ..transition_defs.TimerTransition import TimerTransition logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) -class DetectTask(Task): +class DetectTask(Task): def __init__(self, drone, cloudlet, task_id, trigger_event_queue, task_args): super().__init__(drone, cloudlet, task_id, trigger_event_queue, task_args) - - + def create_transition(self): - - logger.info(f"**************Detect Task {self.task_id}: create transition! **************\n") + logger.info( + f"**************Detect Task {self.task_id}: create transition! **************\n" + ) logger.info(self.transitions_attributes) args = { - 'task_id': self.task_id, - 'trans_active': self.trans_active, - 'trans_active_lock': self.trans_active_lock, - 'trigger_event_queue': self.trigger_event_queue + "task_id": self.task_id, + "trans_active": self.trans_active, + "trans_active_lock": self.trans_active_lock, + "trigger_event_queue": self.trigger_event_queue, } - + # triggered event - if ("timeout" in self.transitions_attributes): - logger.info(f"**************Detect Task {self.task_id}: timer transition! **************\n") + if "timeout" in self.transitions_attributes: + logger.info( + f"**************Detect Task {self.task_id}: timer transition! **************\n" + ) timer = TimerTransition(args, self.transitions_attributes["timeout"]) timer.daemon = True timer.start() - - if ("object_detection" in self.transitions_attributes): - logger.info(f"**************Detect Task {self.task_id}: object detection transition! **************\n") + + if "object_detection" in self.transitions_attributes: + logger.info( + f"**************Detect Task {self.task_id}: object detection transition! **************\n" + ) self.cloudlet.clearResults("openscout-object") - object_trans = ObjectDetectionTransition(args, self.transitions_attributes["object_detection"], self.cloudlet) + object_trans = ObjectDetectionTransition( + args, self.transitions_attributes["object_detection"], self.cloudlet + ) object_trans.daemon = True object_trans.start() - if ("hsv_detection" in self.transitions_attributes): - logger.info(f"**************Detect Task {self.task_id}: hsv detection transition! **************\n") + if "hsv_detection" in self.transitions_attributes: + logger.info( + f"**************Detect Task {self.task_id}: hsv detection transition! **************\n" + ) self.cloudlet.clearResults("openscout-object") - hsv = HSVDetectionTransition(args, self.transitions_attributes["hsv_detection"], self.cloudlet) + hsv = HSVDetectionTransition( + args, self.transitions_attributes["hsv_detection"], self.cloudlet + ) hsv.daemon = True hsv.start() - + @Task.call_after_exit async def run(self): # init the cloudlet self.cloudlet.switchModel(self.task_attributes["model"]) - self.cloudlet.setHSVFilter(lower_bound=self.task_attributes["lower_bound"], upper_bound=self.task_attributes["upper_bound"]) - + self.cloudlet.setHSVFilter( + lower_bound=self.task_attributes["lower_bound"], + upper_bound=self.task_attributes["upper_bound"], + ) + self.create_transition() - + # try: - logger.info(f"**************Detect Task {self.task_id}: hi this is detect task {self.task_id}**************\n") + logger.info( + f"**************Detect Task {self.task_id}: hi this is detect task {self.task_id}**************\n" + ) coords = ast.literal_eval(self.task_attributes["coords"]) await self.drone.setGimbalPose(0.0, float(self.task_attributes["gimbal_pitch"]), 0.0) hover_delay = int(self.task_attributes["hover_delay"]) @@ -68,11 +82,11 @@ async def run(self): lat = dest["lat"] alt = dest["alt"] logger.info(f"**************Detect Task {self.task_id}: Move **************\n") - logger.info(f"**************Detect Task {self.task_id}: move to {lat}, {lng}, {alt}**************\n") + logger.info( + f"**************Detect Task {self.task_id}: move to {lat}, {lng}, {alt}**************\n" + ) await self.drone.moveTo(lat, lng, alt) await asyncio.sleep(1) # await asyncio.sleep(hover_delay) logger.info(f"**************Detect Task {self.task_id}: Done**************\n") - - diff --git a/droneDSL/python/project/task_defs/SetHome.py b/droneDSL/python/project/task_defs/SetHome.py index 6117ab36..e18263b3 100644 --- a/droneDSL/python/project/task_defs/SetHome.py +++ b/droneDSL/python/project/task_defs/SetHome.py @@ -2,11 +2,12 @@ # # SPDX-License-Identifier: GPL-2.0-only -from interface.Task import Task import ast -class SetHome(Task): +from interface.Task import Task + +class SetHome(Task): def __init__(self, drone, cloudlet, **kwargs): super().__init__(drone, cloudlet, **kwargs) @@ -19,5 +20,3 @@ async def run(self): await self.drone.setHome(lat, lng, 1.0) except Exception as e: print(e) - - diff --git a/droneDSL/python/project/task_defs/TestTask.py b/droneDSL/python/project/task_defs/TestTask.py index 422eedea..c2ed730d 100644 --- a/droneDSL/python/project/task_defs/TestTask.py +++ b/droneDSL/python/project/task_defs/TestTask.py @@ -1,63 +1,68 @@ - -from ..transition_defs.TimerTransition import TimerTransition -from interface.Task import Task -import asyncio import ast +import asyncio import logging +from interface.Task import Task + +from ..transition_defs.TimerTransition import TimerTransition logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) -class TestTask(Task): +class TestTask(Task): def __init__(self, drone, cloudlet, task_id, trigger_event_queue, task_args): super().__init__(drone, cloudlet, task_id, trigger_event_queue, task_args) - - + def create_transition(self): - logger.info(f"**************Test Task 2{self.task_id}: create transition! **************\n") logger.info(self.transitions_attributes) args = { - 'task_id': self.task_id, - 'trans_active': self.trans_active, - 'trans_active_lock': self.trans_active_lock, - 'trigger_event_queue': self.trigger_event_queue + "task_id": self.task_id, + "trans_active": self.trans_active, + "trans_active_lock": self.trans_active_lock, + "trigger_event_queue": self.trigger_event_queue, } - + # triggered event - if ("timeout" in self.transitions_attributes): - logger.info(f"**************Test Task 2{self.task_id}: timer transition! **************\n") + if "timeout" in self.transitions_attributes: + logger.info( + f"**************Test Task 2{self.task_id}: timer transition! **************\n" + ) timer = TimerTransition(args, self.transitions_attributes["timeout"]) timer.daemon = True timer.start() - - - # test all the driver calls + + # test all the driver calls @Task.call_after_exit async def run(self): - # self.create_transition() - - logger.info(f"**************Test Task 2{self.task_id}: hi this is Test task2 {self.task_id}**************\n") - + + logger.info( + f"**************Test Task 2{self.task_id}: hi this is Test task2 {self.task_id}**************\n" + ) + # coords = [ # {"lat": 37.7749, "lng": -122.4194, "alt": 30, "bear": 0}, # San Francisco # {"lat": 34.0522, "lng": -118.2437, "alt": 50, "bear": 0}, # Los Angeles # {"lat": 40.7128, "lng": -74.0060, "alt": 100, "bear": 0} # New York # ] while True: - - avoid_res = await self.cloudlet.getResults('obstacle-avoidance') - detect_res = await self.cloudlet.getResults('openscout-object') - - logger.info(f"**************Test Task 2{self.task_id}: Avoidance Result: {avoid_res}**************\n") - logger.info(f"**************Test Task 2{self.task_id}: Detection Result: {detect_res}**************\n") - + avoid_res = await self.cloudlet.getResults("obstacle-avoidance") + detect_res = await self.cloudlet.getResults("openscout-object") + + logger.info( + f"**************Test Task 2{self.task_id}: Avoidance Result: {avoid_res}**************\n" + ) + logger.info( + f"**************Test Task 2{self.task_id}: Detection Result: {detect_res}**************\n" + ) + coords = ast.literal_eval(self.task_attributes["coords"]) - logger.info(f"**************Test Task 2{self.task_id}: hi this is Test task2 {self.task_id}**************\n") + logger.info( + f"**************Test Task 2{self.task_id}: hi this is Test task2 {self.task_id}**************\n" + ) for dest in coords: lng = dest["lng"] lat = dest["lat"] @@ -65,14 +70,10 @@ async def run(self): # bear = dest["bear"] bear = 0 logger.info(f"**************Test Task 2{self.task_id}: setGPSLocation **************\n") - logger.info(f"**************Test Task 2{self.task_id}: GPSLocation: {lat}, {lng}, {alt} {bear}**************\n") + logger.info( + f"**************Test Task 2{self.task_id}: GPSLocation: {lat}, {lng}, {alt} {bear}**************\n" + ) await self.drone.setGPSLocation(lat, lng, alt, bear) await asyncio.sleep(0) - - logger.info(f"**************Test Task 2{self.task_id}: Done**************\n") - - - - diff --git a/droneDSL/python/project/task_defs/TrackTask.py b/droneDSL/python/project/task_defs/TrackTask.py index 076bc9ea..5815e34c 100644 --- a/droneDSL/python/project/task_defs/TrackTask.py +++ b/droneDSL/python/project/task_defs/TrackTask.py @@ -1,22 +1,23 @@ import asyncio -from json import JSONDecodeError -import sys import json -import numpy as np +import logging import math -from ..transition_defs.TimerTransition import TimerTransition -from interface.Task import Task +import sys import time -import logging +from json import JSONDecodeError + +import numpy as np from gabriel_protocol import gabriel_pb2 +from interface.Task import Task from scipy.spatial.transform import Rotation as R +from ..transition_defs.TimerTransition import TimerTransition + logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) class TrackTask(Task): - def __init__(self, drone, cloudlet, task_id, trigger_event_queue, task_args): super().__init__(drone, cloudlet, task_id, trigger_event_queue, task_args) # TODO: Make this a drone interface command @@ -36,20 +37,23 @@ def __init__(self, drone, cloudlet, task_id, trigger_event_queue, task_args): # ANAFI series. self.time_prev = None self.error_prev = [0, 0, 0] - self.yaw_pid_info = {"constants": {"Kp": 10.0, "Ki": 0.09, "Kd": 40.0}, "saved" : {"I": 0.0}} - self.move_pid_info = {"constants": {"Kp": 2.0, "Ki": 0.030, "Kd": 35.0}, "saved" : {"I": 0.0}} + self.yaw_pid_info = {"constants": {"Kp": 10.0, "Ki": 0.09, "Kd": 40.0}, "saved": {"I": 0.0}} + self.move_pid_info = { + "constants": {"Kp": 2.0, "Ki": 0.030, "Kd": 35.0}, + "saved": {"I": 0.0}, + } def create_transition(self): logger.info(self.transitions_attributes) args = { - 'task_id': self.task_id, - 'trans_active': self.trans_active, - 'trans_active_lock': self.trans_active_lock, - 'trigger_event_queue': self.trigger_event_queue + "task_id": self.task_id, + "trans_active": self.trans_active, + "trans_active_lock": self.trans_active_lock, + "trigger_event_queue": self.trigger_event_queue, } # Triggered event - if ("timeout" in self.transitions_attributes): + if "timeout" in self.transitions_attributes: timer = TimerTransition(args, self.transitions_attributes["timeout"]) timer.daemon = True timer.start() @@ -62,11 +66,14 @@ def targetBearing(self, origin, destination): rlat2 = math.radians(lat2) rlon1 = math.radians(lon1) rlon2 = math.radians(lon2) - dlon = math.radians(lon2-lon1) + dlon = math.radians(lon2 - lon1) - b = math.atan2(math.sin(dlon)*math.cos(rlat2),math.cos(rlat1)*math.sin(rlat2)-math.sin(rlat1)*math.cos(rlat2)*math.cos(dlon)) + b = math.atan2( + math.sin(dlon) * math.cos(rlat2), + math.cos(rlat1) * math.sin(rlat2) - math.sin(rlat1) * math.cos(rlat2) * math.cos(dlon), + ) bd = math.degrees(b) - br,bn = divmod(bd+360,360) + br, bn = divmod(bd + 360, 360) return bn @@ -85,7 +92,7 @@ async def estimateDistance(self, yaw, pitch): gimbal = await self.drone.getGimbalPitch() vf = [0, 1, 0] - r = R.from_euler('ZYX', [yaw, 0, pitch + gimbal], degrees=True) + r = R.from_euler("ZYX", [yaw, 0, pitch + gimbal], degrees=True) target_dir = r.as_matrix().dot(vf) target_vec = self.findIntersection(target_dir, np.array([0, 0, alt])) @@ -97,9 +104,15 @@ async def estimateDistance(self, yaw, pitch): async def error(self, box): target_x_pix = self.image_res[0] - int(((box[3] - box[1]) / 2.0) + box[1]) target_y_pix = self.image_res[1] - int(((box[2] - box[0]) / 2.0) + box[0]) - target_yaw_angle = ((target_x_pix - self.pixel_center[0]) / self.pixel_center[0]) * (self.HFOV / 2) - target_pitch_angle = ((target_y_pix - self.pixel_center[1]) / self.pixel_center[1]) * (self.VFOV / 2) - target_bottom_pitch_angle = (((self.image_res[1] - box[2]) - self.pixel_center[1]) / self.pixel_center[1]) * (self.VFOV / 2) + target_yaw_angle = ((target_x_pix - self.pixel_center[0]) / self.pixel_center[0]) * ( + self.HFOV / 2 + ) + target_pitch_angle = ((target_y_pix - self.pixel_center[1]) / self.pixel_center[1]) * ( + self.VFOV / 2 + ) + target_bottom_pitch_angle = ( + ((self.image_res[1] - box[2]) - self.pixel_center[1]) / self.pixel_center[1] + ) * (self.VFOV / 2) yaw_error = -1 * target_yaw_angle gimbal_error = target_pitch_angle @@ -120,7 +133,7 @@ async def pid(self, box): # Reset pid loop if we haven't seen a target for a second or this is # the first target we have seen. if self.time_prev is None or (ts - self.time_prev) > 1000: - self.time_prev = ts - 1 # Do this to prevent a divide by zero error! + self.time_prev = ts - 1 # Do this to prevent a divide by zero error! self.error_prev = [ye, ge, me] # Control loop for yaw @@ -130,7 +143,9 @@ async def pid(self, box): Iy *= -1 self.yaw_pid_info["saved"]["I"] += Iy self.yaw_pid_info["saved"]["I"] = self.clamp(self.yaw_pid_info["saved"]["I"], -100.0, 100.0) - Dy = self.yaw_pid_info["constants"]["Kd"] * (ye - self.error_prev[0]) / (ts - self.time_prev) + Dy = ( + self.yaw_pid_info["constants"]["Kd"] * (ye - self.error_prev[0]) / (ts - self.time_prev) + ) logger.info(f"[TrackTask]: YAW values {ye} {Py} {Iy} {Dy}") yaw = Py + Iy + Dy @@ -147,8 +162,14 @@ async def pid(self, box): if me < 0: Im *= -1 self.move_pid_info["saved"]["I"] += Im * extra - self.move_pid_info["saved"]["I"] = self.clamp(self.move_pid_info["saved"]["I"], -100.0, 100.0) - Dm = self.move_pid_info["constants"]["Kd"] * (me - self.error_prev[2]) / (ts - self.time_prev) + self.move_pid_info["saved"]["I"] = self.clamp( + self.move_pid_info["saved"]["I"], -100.0, 100.0 + ) + Dm = ( + self.move_pid_info["constants"]["Kd"] + * (me - self.error_prev[2]) + / (ts - self.time_prev) + ) logger.info(f"[TrackTask]: MOVE values {me} {Pm} {Im} {Dm}") move = Pm + Im + Dm @@ -161,7 +182,7 @@ async def pid(self, box): return (yaw, gimbal, pitch) async def actuate(self, vels): - #await self.drone.PCMD(0, vels[2], vels[0], 0) + # await self.drone.PCMD(0, vels[2], vels[0], 0) logger.info(f"Calling pcmd with {vels}") await self.drone.PCMD(0, vels[2], vels[0], 0) g = await self.drone.getGimbalPitch() @@ -171,7 +192,10 @@ async def run(self): logger.info("[TrackTask]: Starting tracking task") self.cloudlet.switchModel(self.task_attributes["model"]) - self.cloudlet.setHSVFilter(lower_bound=self.task_attributes["lower_bound"], upper_bound=self.task_attributes["upper_bound"]) + self.cloudlet.setHSVFilter( + lower_bound=self.task_attributes["lower_bound"], + upper_bound=self.task_attributes["upper_bound"], + ) # TODO: Parameterize this # self.leash_length = float(self.task_attributes["leash"]) @@ -181,39 +205,39 @@ async def run(self): target = self.task_attributes["class"] # TODO: This should only be done if requested. - #await self.drone.setGimbalPose(0.0, float(self.task_attributes["gimbal_pitch"]), 0.0) + # await self.drone.setGimbalPose(0.0, float(self.task_attributes["gimbal_pitch"]), 0.0) self.create_transition() last_seen = None while True: result = self.cloudlet.getResults("openscout-object") - if last_seen is not None and int(time.time() - last_seen) > self.target_lost_duration: - #if we have not found the target in N seconds trigger the done transition + if last_seen is not None and int(time.time() - last_seen) > self.target_lost_duration: + # if we have not found the target in N seconds trigger the done transition break - if result != None: - if result.payload_type == gabriel_pb2.TEXT: - try: - json_string = result.payload.decode('utf-8') - json_data = json.loads(json_string) - box = None - for det in json_data: - - # Return the first instance found of the target class. - if det["class"] == target and det["hsv_filter"]: - box = det["box"] - last_seen = time.time() - break - - # Found an instance of target, start tracking! - if box is not None: - logger.info(f"[TrackTask]: Detected instance of {target}, tracking...") - vels = await self.pid(box) - await self.actuate(vels) - except JSONDecodeError as e: - logger.error(f"[TrackTask]: Error decoding json, ignoring") - except Exception as e: - exc_type, exc_obj, exc_tb = sys.exc_info() - logger.error(f"[TrackTask]: Exception encountered, {e}, line no {exc_tb.tb_lineno}") + if result is not None and result.payload_type == gabriel_pb2.TEXT: + try: + json_string = result.payload.decode("utf-8") + json_data = json.loads(json_string) + box = None + for det in json_data: + # Return the first instance found of the target class. + if det["class"] == target and det["hsv_filter"]: + box = det["box"] + last_seen = time.time() + break + + # Found an instance of target, start tracking! + if box is not None: + logger.info(f"[TrackTask]: Detected instance of {target}, tracking...") + vels = await self.pid(box) + await self.actuate(vels) + except JSONDecodeError: + logger.error("[TrackTask]: Error decoding json, ignoring") + except Exception as e: + exc_type, exc_obj, exc_tb = sys.exc_info() + logger.error( + f"[TrackTask]: Exception encountered, {e}, line no {exc_tb.tb_lineno}" + ) await asyncio.sleep(0.03) self._exit() diff --git a/droneDSL/python/project/transition_defs/HSVDetectionTransition.py b/droneDSL/python/project/transition_defs/HSVDetectionTransition.py index 816e57e8..1b868011 100644 --- a/droneDSL/python/project/transition_defs/HSVDetectionTransition.py +++ b/droneDSL/python/project/transition_defs/HSVDetectionTransition.py @@ -1,45 +1,45 @@ -from json import JSONDecodeError import json import logging -from venv import logger -from interface.Transition import Transition +from json import JSONDecodeError + from gabriel_protocol import gabriel_pb2 +from interface.Transition import Transition logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) + class HSVDetectionTransition(Transition): - def __init__(self, args, target, cloudlet): + def __init__(self, args, target, cloudlet): super().__init__(args) self.stop_signal = False self.target = target self.cloudlet = cloudlet - + def stop(self): self.stop_signal = True - + def run(self): self._register() self.cloudlet.clearResults("openscout-object") while not self.stop_signal: result = self.cloudlet.getResults("openscout-object") - if (result != None): - if result.payload_type == gabriel_pb2.TEXT: - try: - json_string = result.payload.decode('utf-8') - json_data = json.loads(json_string) - for item in json_data: - class_attribute = item['class'] - hsv_filter = item['hsv_filter'] - if (class_attribute == self.target and hsv_filter): - logger.info(f"**************Transition: Task {self.task_id}: detect condition met! {class_attribute}**************\n") - self._trigger_event("hsv_detection") - break - except JSONDecodeError as e: - logger.error(f'Error decoding json: {json_string}') - except Exception as e: - logger.info(e) - + if result is not None and result.payload_type == gabriel_pb2.TEXT: + try: + json_string = result.payload.decode("utf-8") + json_data = json.loads(json_string) + for item in json_data: + class_attribute = item["class"] + hsv_filter = item["hsv_filter"] + if class_attribute == self.target and hsv_filter: + logger.info( + f"**************Transition: Task {self.task_id}: detect condition met! {class_attribute}**************\n" + ) + self._trigger_event("hsv_detection") + break + except JSONDecodeError: + logger.error(f"Error decoding json: {json_string}") + except Exception as e: + logger.info(e) + self._unregister() - - diff --git a/droneDSL/python/project/transition_defs/ObjectDetectionTransition.py b/droneDSL/python/project/transition_defs/ObjectDetectionTransition.py index c2aa9bd3..5d8f750a 100644 --- a/droneDSL/python/project/transition_defs/ObjectDetectionTransition.py +++ b/droneDSL/python/project/transition_defs/ObjectDetectionTransition.py @@ -1,24 +1,25 @@ -from json import JSONDecodeError import json import logging import time -from venv import logger -from interface.Transition import Transition +from json import JSONDecodeError + from gabriel_protocol import gabriel_pb2 +from interface.Transition import Transition logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) + class ObjectDetectionTransition(Transition): def __init__(self, args, target, cloudlet): super().__init__(args) self.stop_signal = False - self.target =target + self.target = target self.cloudlet = cloudlet - + def stop(self): self.stop_signal = True - + def run(self): self._register() time.sleep(4) @@ -26,30 +27,34 @@ def run(self): while not self.stop_signal: # get result result = self.cloudlet.getResults("openscout-object") - if (result != None): - logger.info(f"**************Transition: Task {self.task_id}: detected payload! {result}**************\n") + if result is not None: + logger.info( + f"**************Transition: Task {self.task_id}: detected payload! {result}**************\n" + ) # Check if the payload type is TEXT, since your JSON seems to be text data if result.payload_type == gabriel_pb2.TEXT: try: # Decode the payload from bytes to string - json_string = result.payload.decode('utf-8') + json_string = result.payload.decode("utf-8") # Parse the JSON string json_data = json.loads(json_string) # Access the 'class' attribute - class_attribute = json_data[0]['class'] # Adjust the indexing based on your JSON structure + class_attribute = json_data[0][ + "class" + ] # Adjust the indexing based on your JSON structure logger.info(class_attribute) - if (class_attribute== self.target): - logger.info(f"**************Transition: Task {self.task_id}: detect condition met! {class_attribute}**************\n") - self._trigger_event("object_detection") - break - except JSONDecodeError as e: - logger.error(f'Error decoding json: {json_string}') + if class_attribute == self.target: + logger.info( + f"**************Transition: Task {self.task_id}: detect condition met! {class_attribute}**************\n" + ) + self._trigger_event("object_detection") + break + except JSONDecodeError: + logger.error(f"Error decoding json: {json_string}") except Exception as e: logger.info(e) - # print("object stopping...\n") + # print("object stopping...\n") self._unregister() - - \ No newline at end of file diff --git a/droneDSL/python/project/transition_defs/TimerTransition.py b/droneDSL/python/project/transition_defs/TimerTransition.py index 11637693..29ebd5bc 100644 --- a/droneDSL/python/project/transition_defs/TimerTransition.py +++ b/droneDSL/python/project/transition_defs/TimerTransition.py @@ -1,25 +1,27 @@ import logging import threading + from interface.Transition import Transition logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) -class TimerTransition (Transition): + +class TimerTransition(Transition): def __init__(self, args, timer_interval): super().__init__(args) # Create a Timer within the thread self.timer = threading.Timer(timer_interval, self._trigger_event, ["timeout"]) self.completed = True - - def stop (self): + + def stop(self): self.timer.cancel() self.completed = False - + def run(self): self._register() self.timer.start() self.timer.join() # Optionally wait for the timer to finish - if (self.completed): + if self.completed: logger.info(f"**************Transition: Task {self.task_id}: timeout!**************\n") - self._unregister() \ No newline at end of file + self._unregister() diff --git a/field-testing/avoidance/common.py b/field-testing/avoidance/common.py index 6849642c..669a7dbe 100755 --- a/field-testing/avoidance/common.py +++ b/field-testing/avoidance/common.py @@ -2,12 +2,13 @@ # # SPDX-License-Identifier: GPL-2.0-only -import numpy as np -import cv2 from operator import attrgetter +import cv2 +import numpy as np + -class KeyPointHistory(object): +class KeyPointHistory: def __init__(self): self.age = -1 self.lastFrameIdx = 0 @@ -18,7 +19,7 @@ def __init__(self): self.descriptor = None self.consecutive = 0 - def update(self,kp,desc,t0,t1,scale): + def update(self, kp, desc, t0, t1, scale): if self.timehist and t0 == self.timehist[-1][-1]: self.consecutive += 1 else: @@ -28,7 +29,7 @@ def update(self,kp,desc,t0,t1,scale): self.lastFrameIdx = 0 self.detects += 1 self.scalehist.append(scale) - self.timehist.append((t0,t1)) + self.timehist.append((t0, t1)) self.descriptor = desc.copy() self.keypoint = copyKP(kp) @@ -38,25 +39,41 @@ def downdate(self): return self def __repr__(self): - return repr(dict((attr,getattr(self,attr)) for attr in dir(self) - if not attr.startswith('_') and not callable(getattr(self,attr)))) + return repr( + dict( + (attr, getattr(self, attr)) + for attr in dir(self) + if not attr.startswith("_") and not callable(getattr(self, attr)) + ) + ) + def __str__(self): - return str(dict((attr,getattr(self,attr)) for attr in dir(self) - if not attr.startswith('_') and not callable(getattr(self,attr)))) + return str( + dict( + (attr, getattr(self, attr)) + for attr in dir(self) + if not attr.startswith("_") and not callable(getattr(self, attr)) + ) + ) + -class Cluster(object): - def __init__(self,keypoints,img): +class Cluster: + def __init__(self, keypoints, img): self.mask = np.zeros_like(img) for kp in keypoints: - cv2.circle(self.mask,inttuple(*kp.pt),int(kp.size//2),1,thickness=-1) + cv2.circle(self.mask, inttuple(*kp.pt), int(kp.size // 2), 1, thickness=-1) self.area = np.sum(self.mask) self.pt = findCoM(self.mask) self.p0, self.p1 = BlobBoundingBox(self.mask) self.KPs = [copyKP(kp) for kp in keypoints] - self.dist = [diffKP_L2(self.KPs[i],self.KPs[j]) for i in range(len(self.KPs)-1) for j in range(i+1,len(self.KPs))] + self.dist = [ + diffKP_L2(self.KPs[i], self.KPs[j]) + for i in range(len(self.KPs) - 1) + for j in range(i + 1, len(self.KPs)) + ] def __repr__(self): - return str(map(repr,(self.pt,self.area,len(self.KPs)))) + return str(map(repr, (self.pt, self.area, len(self.KPs)))) def BlobBoundingBox(blob): @@ -68,55 +85,102 @@ def BlobBoundingBox(blob): ones = np.flatnonzero(diff) ymin, ymax = ones[0], ones[-1] - return (xmin,ymin),(xmax,ymax) + return (xmin, ymin), (xmax, ymax) def findCoM(mask): - colnums = np.arange(np.shape(mask)[1]).reshape(1,-1) - rownums = np.arange(np.shape(mask)[0]).reshape(-1,1) + colnums = np.arange(np.shape(mask)[1]).reshape(1, -1) + rownums = np.arange(np.shape(mask)[0]).reshape(-1, 1) - x = np.sum(mask*colnums) // np.sum(mask) - y = np.sum(mask*rownums) // np.sum(mask) + x = np.sum(mask * colnums) // np.sum(mask) + y = np.sum(mask * rownums) // np.sum(mask) return x, y -trunc_coords = lambda shape,xy: [round(x) if x >= 0 and x <= dimsz else (0 if x < 0 else dimsz) - for dimsz,x in zip(shape[::-1],xy)] +def trunc_coords(shape, xy): + return [ + round(x) if x >= 0 and x <= dimsz else (0 if x < 0 else dimsz) + for dimsz, x in zip(shape[::-1], xy) + ] + + +def bboverlap(cl1, cl2): + return (cl1.p0[0] <= cl2.p1[0] and cl1.p1[0] >= cl2.p0[0]) and ( + cl1.p0[1] <= cl2.p1[1] and cl1.p1[1] >= cl2.p0[1] + ) + -bboverlap = lambda cl1,cl2: (cl1.p0[0] <= cl2.p1[0] and cl1.p1[0] >= cl2.p0[0]) and (cl1.p0[1] <= cl2.p1[1] and cl1.p1[1] >= cl2.p0[1]) +def overlap(kp1, kp2, eps=0): + return (kp1.size // 2 + kp2.size // 2 + eps) > diffKP_L2(kp1, kp2) -overlap = lambda kp1,kp2,eps=0: (kp1.size//2+kp2.size//2+eps) > diffKP_L2(kp1,kp2) -diffKP_L2 = lambda kp0,kp1: np.sqrt((kp0.pt[0]-kp1.pt[0])**2 + (kp0.pt[1]-kp1.pt[1])**2) +def diffKP_L2(kp0, kp1): + return np.sqrt((kp0.pt[0] - kp1.pt[0]) ** 2 + (kp0.pt[1] - kp1.pt[1]) ** 2) -diffKP = lambda kp0,kp1: (kp0.pt[0]-kp1.pt[0], kp0.pt[1]-kp1.pt[1]) -difftuple_L2 = lambda p0,p1: np.sqrt((p0[0]-p1[0])**2 + (p0[1]-p1[1])**2) +def diffKP(kp0, kp1): + return (kp0.pt[0] - kp1.pt[0], kp0.pt[1] - kp1.pt[1]) -difftuple = lambda p0,p1: (p1[0]-p0[0],p1[1]-p0[1]) -inttuple = lambda *x: tuple(map(int,x)) +def difftuple_L2(p0, p1): + return np.sqrt((p0[0] - p1[0]) ** 2 + (p0[1] - p1[1]) ** 2) -roundtuple = lambda *x: tuple(map(int,map(round,x))) -avgKP = lambda keypoints: map(lambda x: sum(x)/len(keypoints),zip(*map(attrgetter('pt'),keypoints))) +def difftuple(p0, p1): + return (p1[0] - p0[0], p1[1] - p0[1]) + + +def inttuple(*x): + return tuple(map(int, x)) + + +def roundtuple(*x): + return tuple(map(int, map(round, x))) + + +def avgKP(keypoints): + return map(lambda x: sum(x) / len(keypoints), zip(*map(attrgetter("pt"), keypoints))) + + +def toKeyPoint_cv(kp): + return cv2.KeyPoint( + kp.pt[0], + kp.pt[1], + kp.size, + _angle=kp.angle, + _response=kp.response, + _octave=kp.octave, + _class_id=kp.class_id, + ) -toKeyPoint_cv = lambda kp: cv2.KeyPoint(kp.pt[0],kp.pt[1],kp.size,_angle=kp.angle,_response=kp.response,_octave=kp.octave,_class_id=kp.class_id) def reprObj(obj): - return "\n".join(["%s = %s" % (attr, getattr(obj, attr)) for attr in dir(obj) if not attr.startswith('_') and not callable(getattr(src,attr))]) + return "\n".join( + [ + f"{attr} = {getattr(obj, attr)}" + for attr in dir(obj) + if not attr.startswith("_") and not callable(getattr(obj, attr)) + ] + ) + + +def cvtIdx(pt, shape): + return ( + int(pt[1] * shape[1] + pt[0]) + if hasattr(pt, "__len__") + else map(int, (pt % shape[1], pt // shape[1])) + ) -def cvtIdx(pt,shape): - return int(pt[1]*shape[1] + pt[0]) if hasattr(pt, '__len__') else map(int, (pt%shape[1], pt//shape[1])) +def drawInto(src, dst, tl=(0, 0)): + dst[tl[1] : tl[1] + src.shape[0], tl[0] : tl[0] + src.shape[1]] = src -def drawInto(src, dst, tl=(0,0)): - dst[tl[1]:tl[1]+src.shape[0], tl[0]:tl[0]+src.shape[1]] = src -def copyKP(src,dst=None): - if dst is None: dst = cv2.KeyPoint() +def copyKP(src, dst=None): + if dst is None: + dst = cv2.KeyPoint() for attr in dir(src): - if not attr.startswith('_') and not callable(getattr(src,attr)): - setattr(dst,attr,getattr(src,attr)) + if not attr.startswith("_") and not callable(getattr(src, attr)): + setattr(dst, attr, getattr(src, attr)) return dst diff --git a/field-testing/avoidance/midas_avoider.py b/field-testing/avoidance/midas_avoider.py index f3d72d88..e508a7c9 100644 --- a/field-testing/avoidance/midas_avoider.py +++ b/field-testing/avoidance/midas_avoider.py @@ -2,26 +2,24 @@ # # SPDX-License-Identifier: GPL-2.0-only -import olympe -from olympe.messages.ardrone3.Piloting import PCMD -from olympe.messages.ardrone3.PilotingState import GpsLocationChanged +import json import threading import time + import zmq -import json -import numpy as np -from datetime import datetime -import os +from olympe.messages.ardrone3.Piloting import PCMD +from olympe.messages.ardrone3.PilotingState import GpsLocationChanged FOLDER = "./avoidance/traces/" + class MiDaSAvoider(threading.Thread): def __init__(self, drone, speed=5, hysteresis=True): self.drone = drone self.context = zmq.Context() self.sub_socket = self.context.socket(zmq.SUB) - self.sub_socket.connect('tcp://localhost:5556') - self.sub_socket.setsockopt(zmq.SUBSCRIBE, b'') + self.sub_socket.connect("tcp://localhost:5556") + self.sub_socket.setsockopt(zmq.SUBSCRIBE, b"") self.speed = max(1, min(speed, 100)) self.hysteresis = hysteresis self.image_size = (640, 480) @@ -40,32 +38,29 @@ def move_by_offsets(self, vec): def run(self): self.tracking = False self.active = True - - #trace = open(FOLDER + datetime.now().strftime("%m-%d-%Y-%H-%M-%S") + ".txt", 'a') - #print("Writing trace!") + + # trace = open(FOLDER + datetime.now().strftime("%m-%d-%Y-%H-%M-%S") + ".txt", 'a') + # print("Writing trace!") lastvec = 0 while self.active: gps = self.drone.get_state(GpsLocationChanged) lat = gps["latitude"] lng = gps["longitude"] - #trace.write(f"{lat}, {lng}" + '\n') + # trace.write(f"{lat}, {lng}" + '\n') print("Wrote coordinates.") try: vec = json.loads(self.sub_socket.recv_json(flags=zmq.NOBLOCK))[0]["vector"] print(f"Receiving detections: {vec}") - if self.hysteresis: - diff = vec - lastvec - else: - diff = 0 + diff = vec - lastvec if self.hysteresis else 0 lastvec = vec self.move_by_offsets(vec + diff) - except Exception as e: + except Exception: print(f"Actuating on last: {lastvec}") self.move_by_offsets(lastvec) time.sleep(0.05) - #trace.close() + # trace.close() def stop(self): self.active = False diff --git a/field-testing/avoidance/sift_avoider.py b/field-testing/avoidance/sift_avoider.py index 9202b955..103ccc0d 100755 --- a/field-testing/avoidance/sift_avoider.py +++ b/field-testing/avoidance/sift_avoider.py @@ -4,37 +4,36 @@ # # SPDX-License-Identifier: GPL-2.0-only -import olympe -from olympe.messages.ardrone3.Piloting import PCMD -from olympe.messages.ardrone3.PilotingState import GpsLocationChanged +import sys +import time + import cv2 import numpy as np -import logging -from collections import OrderedDict -import time,sys -from matplotlib import pyplot as plt -import sys -sys.path.append('./avoidance/') -from common import * -import operator as op +from olympe.messages.ardrone3.Piloting import PCMD +from olympe.messages.ardrone3.PilotingState import GpsLocationChanged + +sys.path.append("./avoidance/") import math +import operator as op import threading -import zmq -import time -import traceback from datetime import datetime -STREAM_FPS = 1 #used for ttc +import zmq +from common import Cluster, overlap + +STREAM_FPS = 1 # used for ttc RATIO = 0.75 AGE = 10 FOLDER = "./avoidance/old-traces/" + def ClusterKeypoints(keypoints, img, epsilon=0): - if len(keypoints) < 2: return [] + if len(keypoints) < 2: + return [] cluster = [] - unclusteredKPs = sorted(keypoints,key=op.attrgetter('pt')) + unclusteredKPs = sorted(keypoints, key=op.attrgetter("pt")) while unclusteredKPs: clust = [unclusteredKPs.pop(0)] kp = clust[0] @@ -44,19 +43,24 @@ def ClusterKeypoints(keypoints, img, epsilon=0): clust.append(unclusteredKPs.pop(i)) else: i += 1 - if (len(clust) >= 3): cluster.append(Cluster(clust,img)) + if len(clust) >= 3: + cluster.append(Cluster(clust, img)) return cluster + def mse(img1, img2): - h, w = img1.shape - diff = cv2.subtract(img1, img2) - err = np.sum(diff**2) - mse = err/(float(h*w)) - return mse + h, w = img1.shape + diff = cv2.subtract(img1, img2) + err = np.sum(diff**2) + mse = err / (float(h * w)) + return mse + class SIFTAvoider(threading.Thread): - def __init__(self, drone, contrast=0.04, edge=50, dist=200.0, scale=1.3, roi=3, eps=50, speed=5): + def __init__( + self, drone, contrast=0.04, edge=50, dist=200.0, scale=1.3, roi=3, eps=50, speed=5 + ): self.drone = drone self.c = contrast self.e = edge @@ -69,8 +73,8 @@ def __init__(self, drone, contrast=0.04, edge=50, dist=200.0, scale=1.3, roi=3, self.context = zmq.Context() self.socket = self.context.socket(zmq.SUB) - self.socket.connect('tcp://localhost:5555') - self.socket.setsockopt(zmq.SUBSCRIBE, b'') + self.socket.connect("tcp://localhost:5555") + self.socket.setsockopt(zmq.SUBSCRIBE, b"") self.image_size = (640, 480) self.image = None @@ -80,7 +84,7 @@ def __init__(self, drone, contrast=0.04, edge=50, dist=200.0, scale=1.3, roi=3, self.init_sift() super().__init__() - + def execute_PCMD(self, dpitch, droll): print(f"ROLL: {droll}, PITCH: {dpitch}") self.drone(PCMD(1, round(droll), round(dpitch), 0, 0, timestampAndSeqNum=0)) @@ -96,21 +100,21 @@ def recv_array(self, flags=zmq.NOBLOCK, copy=True, track=False): md = self.socket.recv_json(flags=flags) msg = self.socket.recv(flags=flags, copy=copy, track=track) buf = memoryview(msg) - A = np.frombuffer(buf, dtype=md['dtype']) - return A.reshape(md['shape']) + A = np.frombuffer(buf, dtype=md["dtype"]) + return A.reshape(md["shape"]) def run(self): self.tracking = False self.active = True - trace = open(FOLDER + datetime.now().strftime("%m-%d-%Y-%H-%M-%S") + ".txt", 'a') + trace = open(FOLDER + datetime.now().strftime("%m-%d-%Y-%H-%M-%S") + ".txt", "a") lastvec = 0 while self.active: gps = self.drone.get_state(GpsLocationChanged) lat = gps["latitude"] lng = gps["longitude"] - trace.write(f"{lat}, {lng}" + '\n') + trace.write(f"{lat}, {lng}" + "\n") try: self.image = self.recv_array() vec = self.get_offsets() @@ -119,7 +123,7 @@ def run(self): self.prev_image = self.image if vec is not None: self.move_by_offsets(vec) - except Exception as e: + except Exception: if lastvec: self.move_by_offsets(lastvec) time.sleep(0.05) @@ -140,10 +144,13 @@ def init_sift(self): self.id = 1 try: - self.roi = np.zeros(prev_img.shape,np.uint8) - scrapY, scrapX = prev_img.shape[0]//self.r, prev_img.shape[1]//(self.r + 1) + self.roi = np.zeros(self.prev_img.shape, np.uint8) + scrapY, scrapX = ( + self.prev_img.shape[0] // self.r, + self.prev_img.shape[1] // (self.r + 1), + ) self.roi[scrapY:-scrapY, scrapX:-scrapX] = True - except Exception as e: + except Exception: pass def match(self): @@ -156,7 +163,7 @@ def match(self): # Sort them in the order of their distance. ## find the list of kps that were not matched. - matches = sorted(matches, key = lambda x:x.distance) + matches = sorted(matches, key=lambda x: x.distance) # Filter out bad matches good = [] @@ -165,13 +172,13 @@ def match(self): good.append(m) train_indices_matched.append(m.trainIdx) for each_prev_kp_index in range(len(self.prev_kps)): - if each_prev_kp_index not in train_indices_matched: + if each_prev_kp_index not in train_indices_matched: prev_kps_not_matched.append(self.prev_kps[each_prev_kp_index]) prev_descs_not_matched.append(self.prev_descs[each_prev_kp_index]) - if(len(matches) > 0): - max_value = max(matches, key=lambda x : x.distance) - min_value = min(matches, key=lambda x : x.distance) + if len(matches) > 0: + max_value = max(matches, key=lambda x: x.distance) + min_value = min(matches, key=lambda x: x.distance) return good @@ -185,29 +192,43 @@ def cull(self, good, curr, prev): for g in good: clsid = prev_kps[g.trainIdx].class_id - curr_point_x, curr_point_y = int(np.round(kps[g.queryIdx].pt[0])), int(np.round(kps[g.queryIdx].pt[1])) - prev_point_x, prev_point_y = int(np.round(prev_kps[g.trainIdx].pt[0])), int(np.round(prev_kps[g.trainIdx].pt[1])) - - size_expansion = 1.5 ### MAKE THIS TUNABLE - curr_size = int(max(1, np.round((kps[g.queryIdx].size)*size_expansion))) - prev_size = int(max(1, np.round((prev_kps[g.trainIdx].size)*size_expansion))) + curr_point_x, curr_point_y = ( + int(np.round(kps[g.queryIdx].pt[0])), + int(np.round(kps[g.queryIdx].pt[1])), + ) + prev_point_x, prev_point_y = ( + int(np.round(prev_kps[g.trainIdx].pt[0])), + int(np.round(prev_kps[g.trainIdx].pt[1])), + ) + + size_expansion = 1.5 ### MAKE THIS TUNABLE + curr_size = int(max(1, np.round((kps[g.queryIdx].size) * size_expansion))) + prev_size = int(max(1, np.round((prev_kps[g.trainIdx].size) * size_expansion))) if prev_size <= curr_size - 2: pass curr_total_x, curr_total_y = img.shape[1], img.shape[0] prev_total_x, prev_total_y = prev_img.shape[1], prev_img.shape[0] - if (curr_point_x + curr_size/2 + 1 > curr_total_x) or (curr_point_y + curr_size/2 + 1 > curr_total_y) or (prev_point_x + prev_size/2 + 1 > prev_total_x) or (prev_point_y + prev_size/2 + 1 > prev_total_y) or \ - (curr_point_x - curr_size / 2 - 1 < 0) or (curr_point_y - curr_size / 2 - 1 < 0) or (prev_point_x - prev_size / 2 - 1 < 0) or (prev_point_y - prev_size / 2 - 1 < 0): + if ( + (curr_point_x + curr_size / 2 + 1 > curr_total_x) + or (curr_point_y + curr_size / 2 + 1 > curr_total_y) + or (prev_point_x + prev_size / 2 + 1 > prev_total_x) + or (prev_point_y + prev_size / 2 + 1 > prev_total_y) + or (curr_point_x - curr_size / 2 - 1 < 0) + or (curr_point_y - curr_size / 2 - 1 < 0) + or (prev_point_x - prev_size / 2 - 1 < 0) + or (prev_point_y - prev_size / 2 - 1 < 0) + ): + continue + if curr_size < prev_size + 2: continue - if curr_size < prev_size + 2: continue - # extract sub image from perev image ## new image borders - if curr_size %2 == 1: - left = int(curr_point_x - math.floor(curr_size/2)) - right = int(curr_point_x + math.ceil(curr_size/2)) - top = int(curr_point_y - math.floor(curr_size/2)) - bottom = int(curr_point_y + math.ceil(curr_size/2)) + if curr_size % 2 == 1: + left = int(curr_point_x - math.floor(curr_size / 2)) + right = int(curr_point_x + math.ceil(curr_size / 2)) + top = int(curr_point_y - math.floor(curr_size / 2)) + bottom = int(curr_point_y + math.ceil(curr_size / 2)) else: left = int(curr_point_x - curr_size / 2) right = int(curr_point_x + curr_size / 2) @@ -230,25 +251,28 @@ def cull(self, good, curr, prev): temp_prev_image = np.asarray(prev_img[top:bottom, left:right]) ## loop from previous key point length to current key point length - scale_results=np.empty(0) + scale_results = np.empty(0) results_dict = {} for expansion in range(0, curr_size - prev_size, 2): - cropped_temp_curr_image = temp_curr_image[expansion//2:curr_size - expansion//2, expansion//2:curr_size - expansion//2] - resized_temp_prev_image = cv2.resize(temp_prev_image, (curr_size - expansion, curr_size - expansion)) + cropped_temp_curr_image = temp_curr_image[ + expansion // 2 : curr_size - expansion // 2, + expansion // 2 : curr_size - expansion // 2, + ] + resized_temp_prev_image = cv2.resize( + temp_prev_image, (curr_size - expansion, curr_size - expansion) + ) error = mse(cropped_temp_curr_image, resized_temp_prev_image) - results_dict[error] = resized_temp_prev_image.shape[0]/temp_prev_image.shape[0] + results_dict[error] = resized_temp_prev_image.shape[0] / temp_prev_image.shape[0] best = min(results_dict.keys()) ratio = results_dict[best] - - if(ratio > self.s): ## changed the condition check for clustering + if ratio > self.s: ## changed the condition check for clustering bigger.append(g) expandingKPs.append(prev_kps[g.trainIdx]) return bigger, expandingKPs - def get_offsets(self): print("Getting offsets!") if self.first_image: @@ -272,20 +296,24 @@ def get_offsets(self): for kp in self.kps: kp.class_id = self.id self.id += 1 - + good = self.match() prev_gray = cv2.cvtColor(self.prev_image, cv2.COLOR_BGR2GRAY) - b, expand = self.cull(good, (self.kps, self.descs, img), (self.prev_kps, self.prev_descs, prev_gray)) + b, expand = self.cull( + good, (self.kps, self.descs, img), (self.prev_kps, self.prev_descs, prev_gray) + ) print("Clustering.") cluster = ClusterKeypoints(expand, img, epsilon=self.eps) - b_w_disp[0:b_w_disp.shape[0], 0:b_w_disp.shape[1]] = (255, 255, 255) + b_w_disp[0 : b_w_disp.shape[0], 0 : b_w_disp.shape[1]] = (255, 255, 255) for c in cluster: - b_w_disp[0:img.shape[1], c.p0[0]:c.p1[0]] = (0,0,0) + b_w_disp[0 : img.shape[1], c.p0[0] : c.p1[0]] = (0, 0, 0) # Scrap out the ROI - scrapY, scrapX = self.image_size[0]//self.r, self.image_size[1]//(self.r + 1) - b_w_disp = b_w_disp[ scrapY : b_w_disp.shape[0] - scrapY,scrapX : b_w_disp.shape[1] - scrapX] + scrapY, scrapX = self.image_size[0] // self.r, self.image_size[1] // (self.r + 1) + b_w_disp = b_w_disp[ + scrapY : b_w_disp.shape[0] - scrapY, scrapX : b_w_disp.shape[1] - scrapX + ] b_w_disp = cv2.cvtColor(b_w_disp, cv2.COLOR_BGR2GRAY) # find contours in the binary image @@ -297,8 +325,16 @@ def get_offsets(self): cX = int(M["m10"] / M["m00"]) cY = int(M["m01"] / M["m00"]) cv2.circle(dispim, (scrapX + cX, scrapY + cY), 5, (0, 255, 0), -1) - cv2.putText(dispim, "safe", (scrapX + cX, scrapY + cY - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2) - except Exception as e: + cv2.putText( + dispim, + "safe", + (scrapX + cX, scrapY + cY - 10), + cv2.FONT_HERSHEY_SIMPLEX, + 0.5, + (0, 255, 0), + 2, + ) + except Exception: pass self.prev_img = self.image @@ -308,4 +344,4 @@ def get_offsets(self): cv2.imshow("Safe", dispim) cv2.waitKey(1) - return (scrapX + cX) - (self.image_size[0] / 2) + return (scrapX + cX) - (self.image_size[0] / 2) diff --git a/field-testing/laptop-controller.py b/field-testing/laptop-controller.py index 75ce62cc..3933df48 100644 --- a/field-testing/laptop-controller.py +++ b/field-testing/laptop-controller.py @@ -2,36 +2,37 @@ # # SPDX-License-Identifier: GPL-2.0-only -import olympe -from olympe.messages.ardrone3.Piloting import TakeOff, Landing, PCMD, moveBy -from olympe.messages.ardrone3.PilotingState import AttitudeChanged, GpsLocationChanged, AltitudeChanged -from olympe.messages.skyctrl.CoPiloting import setPilotingSource -from olympe.messages.obstacle_avoidance import set_mode -from olympe.enums.obstacle_avoidance import mode -from olympe.messages.gimbal import set_target, attitude -from olympe.enums.gimbal import control_mode -from olympe.video.renderer import PdrawRenderer -import threading -import time -import queue +import argparse import logging +import queue import subprocess -import cv2 -from pynput.keyboard import Listener, Key, KeyCode +import threading +import time from collections import defaultdict +from datetime import datetime from enum import Enum -from time import sleep -import numpy as np -import math -import os + +import cv2 +import olympe import zmq -import json -from trackers import dynamic, static, parrot -from avoidance import sift_avoider, midas_avoider -import argparse -from datetime import datetime +from avoidance import midas_avoider, sift_avoider +from olympe.enums.gimbal import control_mode +from olympe.enums.obstacle_avoidance import mode +from olympe.messages.ardrone3.Piloting import PCMD, Landing, TakeOff +from olympe.messages.ardrone3.PilotingState import ( + AltitudeChanged, + AttitudeChanged, + GpsLocationChanged, +) +from olympe.messages.gimbal import attitude, set_target +from olympe.messages.obstacle_avoidance import set_mode +from olympe.messages.skyctrl.CoPiloting import setPilotingSource +from olympe.video.renderer import PdrawRenderer +from pynput.keyboard import Key, KeyCode, Listener +from trackers import dynamic + +DRONE_IP = "192.168.42.1" # Real drone no controller -DRONE_IP = "192.168.42.1" # Real drone no controller class Ctrl(Enum): ( @@ -96,10 +97,7 @@ def _on_press(self, key): self._key_pressed[key.char] = True elif isinstance(key, Key): self._key_pressed[key] = True - if self._key_pressed[self._ctrl_keys[Ctrl.QUIT]]: - return False - else: - return True + return not self._key_pressed[self._ctrl_keys[Ctrl.QUIT]] def _on_release(self, key): if isinstance(key, KeyCode): @@ -112,50 +110,33 @@ def quit(self): return not self.running or self._key_pressed[self._ctrl_keys[Ctrl.QUIT]] def _axis(self, left_key, right_key): - return 20 * ( - int(self._key_pressed[right_key]) - int(self._key_pressed[left_key]) - ) + return 20 * (int(self._key_pressed[right_key]) - int(self._key_pressed[left_key])) def roll(self): - return self._axis( - self._ctrl_keys[Ctrl.MOVE_LEFT], - self._ctrl_keys[Ctrl.MOVE_RIGHT] - ) + return self._axis(self._ctrl_keys[Ctrl.MOVE_LEFT], self._ctrl_keys[Ctrl.MOVE_RIGHT]) def pitch(self): - return self._axis( - self._ctrl_keys[Ctrl.MOVE_BACKWARD], - self._ctrl_keys[Ctrl.MOVE_FORWARD] - ) + return self._axis(self._ctrl_keys[Ctrl.MOVE_BACKWARD], self._ctrl_keys[Ctrl.MOVE_FORWARD]) def yaw(self): - return self._axis( - self._ctrl_keys[Ctrl.TURN_LEFT], - self._ctrl_keys[Ctrl.TURN_RIGHT] - ) - + return self._axis(self._ctrl_keys[Ctrl.TURN_LEFT], self._ctrl_keys[Ctrl.TURN_RIGHT]) def _clamp(self, num, min_val, max_val): return max(min(num, max_val), min_val) - def gimbal_pitch(self): - axis = float(self._axis( - self._ctrl_keys[Ctrl.GIMBAL_DOWN], - self._ctrl_keys[Ctrl.GIMBAL_UP] - ) / 20) + axis = float( + self._axis(self._ctrl_keys[Ctrl.GIMBAL_DOWN], self._ctrl_keys[Ctrl.GIMBAL_UP]) / 20 + ) self._current_gimbal_pitch += axis self._clamp(self._current_gimbal_pitch, -90.0, 90.0) return axis - + def get_gimbal_target(self): return self._current_gimbal_pitch def throttle(self): - return self._axis( - self._ctrl_keys[Ctrl.MOVE_DOWN], - self._ctrl_keys[Ctrl.MOVE_UP] - ) + return self._axis(self._ctrl_keys[Ctrl.MOVE_DOWN], self._ctrl_keys[Ctrl.MOVE_UP]) def has_piloting_cmd(self): return ( @@ -181,7 +162,7 @@ def takeoff(self): def landing(self): return self._rate_limit_cmd(Ctrl.LANDING, 2.0) - + def start_track(self): return self._rate_limit_cmd(Ctrl.START_TRACK, 2.0) @@ -197,8 +178,7 @@ def _get_ctrl_keys(self, ctrl_keys): # and the following only works on *nix/X11... keyboard_variant = ( subprocess.check_output( - "setxkbmap -query | grep 'variant:'|" - "cut -d ':' -f2 | tr -d ' '", + "setxkbmap -query | grep 'variant:'|" "cut -d ':' -f2 | tr -d ' '", shell=True, ) .decode() @@ -213,13 +193,13 @@ def _get_ctrl_keys(self, ctrl_keys): class OlympeStreaming(threading.Thread): - def __init__(self, drone, sample_rate=5, model='coco'): + def __init__(self, drone, sample_rate=5, model="coco"): self.drone = drone self.sample_rate = sample_rate self.model = model self.frame_queue = queue.Queue() self.flush_queue_lock = threading.Lock() - self.frame_num = 0 + self.frame_num = 0 self.renderer = None super().__init__() super().start() @@ -230,7 +210,7 @@ def start(self): # Socket to talk to server print("Publishing images for OpenScout client's ZmqAdapter..") self.socket = self.context.socket(zmq.PUB) - self.socket.bind('tcp://*:5555') + self.socket.bind("tcp://*:5555") # Setup your callback functions to do some live video processing self.drone.streaming.set_callbacks( @@ -240,7 +220,7 @@ def start(self): end_cb=self.end_cb, flush_raw_cb=self.flush_cb, ) - + # Start video streaming self.drone.streaming.start() self.renderer = PdrawRenderer(pdraw=self.drone.streaming) @@ -280,14 +260,18 @@ def h264_frame_cb(self, h264_frame): def send_array(self, A, meta, flags=0, copy=True, track=False): """send a numpy array with metadata""" md = dict( - dtype = str(A.dtype), - shape = A.shape, - location = {"latitude": meta["latitude"], "longitude": meta["longitude"], "altitude": meta["altitude"]}, - model = self.model, - gimbal_pitch = meta["gimbal_pitch"], - heading = meta["heading"] + dtype=str(A.dtype), + shape=A.shape, + location={ + "latitude": meta["latitude"], + "longitude": meta["longitude"], + "altitude": meta["altitude"], + }, + model=self.model, + gimbal_pitch=meta["gimbal_pitch"], + heading=meta["heading"], ) - self.socket.send_json(md, flags|zmq.SNDMORE) + self.socket.send_json(md, flags | zmq.SNDMORE) return self.socket.send(A, flags, copy=copy, track=track) def send_yuv_frame_to_server(self, yuv_frame): @@ -301,12 +285,18 @@ def send_yuv_frame_to_server(self, yuv_frame): info["raw"]["frame"]["info"]["width"], ) - #print(yuv_frame.vmeta()[1]) + # print(yuv_frame.vmeta()[1]) gps = drone.get_state(GpsLocationChanged) alt = drone.get_state(AltitudeChanged) att = drone.get_state(AttitudeChanged) gatt = drone.get_state(attitude) - meta = {"latitude": gps["latitude"], "longitude": gps["longitude"], "altitude": alt["altitude"], "heading": att["yaw"], "gimbal_pitch": gatt[0]["pitch_absolute"]} + meta = { + "latitude": gps["latitude"], + "longitude": gps["longitude"], + "altitude": alt["altitude"], + "heading": att["yaw"], + "gimbal_pitch": gatt[0]["pitch_absolute"], + } # yuv_frame.vmeta() returns a dictionary that contains additional # metadata from the drone (GPS coordinates, battery percentage, ...) @@ -323,16 +313,14 @@ def send_yuv_frame_to_server(self, yuv_frame): cv2frame = cv2.cvtColor(yuv_frame.as_ndarray(), cv2_cvt_color_flag) cv2frame = cv2.resize(cv2frame, (640, 480)) if self.frame_num % (30 / self.sample_rate) == 0: - #print(f"Publishing frame {self.frame_num} to OpenScout client...") + # print(f"Publishing frame {self.frame_num} to OpenScout client...") self.send_array(cv2frame, meta) self.frame_num += 1 except Exception as e: print("Got an exception", e) def run(self): - main_thread = next( - filter(lambda t: t.name == "MainThread", threading.enumerate()) - ) + main_thread = next(filter(lambda t: t.name == "MainThread", threading.enumerate())) while main_thread.is_alive(): with self.flush_queue_lock: try: @@ -341,8 +329,8 @@ def run(self): continue try: self.send_yuv_frame_to_server(yuv_frame) - except Exception as e: - #print(e) + except Exception: + # print(e) pass finally: # Don't forget to unref the yuv frame. We don't want to @@ -355,13 +343,31 @@ def run(self): if __name__ == "__main__": parser = argparse.ArgumentParser(usage="laptop-controller.py [options]") - parser.add_argument("-c", "--controller", action='store_true', help="Use an attached Parrot SkyController to increase Olympe range") - parser.add_argument("-s", "--simulate", action='store_true', help="Run on a simulated drone in Parrot Sphinx") - parser.add_argument("-nf", "--nofly", action='store_true', help="Prevent flight while running") - parser.add_argument("-ns", "--nostream", action='store_true', help="Prevent streaming while running") - parser.add_argument("-mds", "--midas", action='store_true', help="Use MiDaS to do obstacle avoidance") - parser.add_argument("-sft", "--sift", action='store_true', help="Use SIFT to do obstacle avoidance") - parser.add_argument("-o", "--obstacle", action='store_true', help="Use built-in obstacle avoidance (Anafi Ai only)") + parser.add_argument( + "-c", + "--controller", + action="store_true", + help="Use an attached Parrot SkyController to increase Olympe range", + ) + parser.add_argument( + "-s", "--simulate", action="store_true", help="Run on a simulated drone in Parrot Sphinx" + ) + parser.add_argument("-nf", "--nofly", action="store_true", help="Prevent flight while running") + parser.add_argument( + "-ns", "--nostream", action="store_true", help="Prevent streaming while running" + ) + parser.add_argument( + "-mds", "--midas", action="store_true", help="Use MiDaS to do obstacle avoidance" + ) + parser.add_argument( + "-sft", "--sift", action="store_true", help="Use SIFT to do obstacle avoidance" + ) + parser.add_argument( + "-o", + "--obstacle", + action="store_true", + help="Use built-in obstacle avoidance (Anafi Ai only)", + ) opts = parser.parse_args() @@ -374,21 +380,21 @@ def run(self): tracker = None drone = olympe.Drone(DRONE_IP) - + time.sleep(1) drone.connect() if opts.controller: drone(setPilotingSource(source="Controller")).wait().success() time.sleep(1) - + if not opts.nostream: - streamer = OlympeStreaming(drone, sample_rate=3, model='coco') + streamer = OlympeStreaming(drone, sample_rate=3, model="coco") streamer.start() - + trace = None if opts.obstacle: drone(set_mode(mode.standard)).wait().success() - trace = open(datetime.now().strftime("%m-%d-%Y-%H-%M-%S") + ".txt", 'a') + trace = open(datetime.now().strftime("%m-%d-%Y-%H-%M-%S") + ".txt", "a") control = KeyboardCtrl() while not control.quit(): @@ -408,7 +414,7 @@ def run(self): elif opts.sift: tracker = sift_avoider.SIFTAvoider(drone) else: - tracker = dynamic.DynamicLeashTracker(drone) + tracker = dynamic.DynamicLeashTracker(drone) tracker.start() tracking = True print("Starting track!") @@ -424,10 +430,21 @@ def run(self): control.pitch(), control.yaw(), control.throttle(), - timestampAndSeqNum=0 + timestampAndSeqNum=0, + ) + ) + drone( + set_target( + 0, + control_mode.position, + "none", + 0.0, + "absolute", + control.get_gimbal_target(), + "none", + 0.0, ) ) - drone(set_target(0, control_mode.position, "none", 0.0, "absolute", control.get_gimbal_target(), "none", 0.0)) elif not tracking: drone(PCMD(0, 0, 0, 0, 0, timestampAndSeqNum=0)) time.sleep(0.05) diff --git a/field-testing/olympe-test.py b/field-testing/olympe-test.py index 99bb6e8f..9cce7413 100644 --- a/field-testing/olympe-test.py +++ b/field-testing/olympe-test.py @@ -2,23 +2,25 @@ # # SPDX-License-Identifier: GPL-2.0-only -import olympe import time + +import olympe from olympe.messages.gimbal import set_target if __name__ == "__main__": - drone = olympe.Drone('192.168.42.1') + drone = olympe.Drone("192.168.42.1") drone.connect() - drone(set_target( - gimbal_id=0, - control_mode="position", - yaw_frame_of_reference="none", - yaw=0.0, - pitch_frame_of_reference="absolute", - pitch=30.0, - roll_frame_of_reference="none", - roll=0.0, - )).wait().success() + drone( + set_target( + gimbal_id=0, + control_mode="position", + yaw_frame_of_reference="none", + yaw=0.0, + pitch_frame_of_reference="absolute", + pitch=30.0, + roll_frame_of_reference="none", + roll=0.0, + ) + ).wait().success() time.sleep(10) drone.disconnect() - diff --git a/field-testing/streaming.py b/field-testing/streaming.py index f8a53049..ced079dd 100644 --- a/field-testing/streaming.py +++ b/field-testing/streaming.py @@ -18,15 +18,13 @@ import time import olympe -from olympe.messages.ardrone3.Piloting import TakeOff, Landing -from olympe.messages.ardrone3.Piloting import moveBy -from olympe.messages.ardrone3.PilotingState import FlyingStateChanged +from olympe.messages.ardrone3.GPSSettingsState import GPSFixStateChanged +from olympe.messages.ardrone3.Piloting import Landing, TakeOff, moveBy from olympe.messages.ardrone3.PilotingSettings import MaxTilt from olympe.messages.ardrone3.PilotingSettingsState import MaxTiltChanged -from olympe.messages.ardrone3.GPSSettingsState import GPSFixStateChanged +from olympe.messages.ardrone3.PilotingState import FlyingStateChanged from olympe.video.renderer import PdrawRenderer - olympe.log.update_config({"loggers": {"olympe": {"level": "WARNING"}}}) DRONE_IP = "192.168.42.1" @@ -41,9 +39,7 @@ def __init__(self): print(f"Olympe streaming example output dir: {self.tempd}") self.h264_frame_stats = [] self.h264_stats_file = open(os.path.join(self.tempd, "h264_stats.csv"), "w+") - self.h264_stats_writer = csv.DictWriter( - self.h264_stats_file, ["fps", "bitrate"] - ) + self.h264_stats_writer = csv.DictWriter(self.h264_stats_file, ["fps", "bitrate"]) self.h264_stats_writer.writeheader() self.frame_queue = queue.Queue() self.processing_thread = threading.Thread(target=self.yuv_frame_processing) @@ -179,9 +175,7 @@ def fly(self): GPSFixStateChanged(fixed=1, _timeout=10, _policy="check_wait") >> ( TakeOff(_no_expect=True) - & FlyingStateChanged( - state="hovering", _timeout=10, _policy="check_wait" - ) + & FlyingStateChanged(state="hovering", _timeout=10, _policy="check_wait") ) ) ).wait() @@ -207,7 +201,7 @@ def test_streaming(): # Start the video stream streaming_example.start() # Perform some live video processing while the drone is flying - #streaming_example.fly() + # streaming_example.fly() time.sleep(10) # Stop the video stream streaming_example.stop() diff --git a/field-testing/trackers/dynamic.py b/field-testing/trackers/dynamic.py index 8bb38345..69418281 100644 --- a/field-testing/trackers/dynamic.py +++ b/field-testing/trackers/dynamic.py @@ -2,16 +2,15 @@ # # SPDX-License-Identifier: GPL-2.0-only -import olympe -from olympe.messages.ardrone3.Piloting import PCMD -from olympe.messages.ardrone3.PilotingState import AltitudeChanged -from olympe.messages.gimbal import set_target, attitude, set_max_speed -from olympe.enums.gimbal import control_mode +import json import threading import time -import zmq -import json + import numpy as np +import zmq +from olympe.messages.ardrone3.Piloting import PCMD +from olympe.messages.ardrone3.PilotingState import AltitudeChanged +from olympe.messages.gimbal import attitude, set_max_speed from scipy.spatial.transform import Rotation as R @@ -22,8 +21,8 @@ def __init__(self, drone, leash=10.0, hysteresis=True): self.leash = leash self.context = zmq.Context() self.sub_socket = self.context.socket(zmq.SUB) - self.sub_socket.connect('tcp://localhost:5556') - self.sub_socket.setsockopt(zmq.SUBSCRIBE, b'') + self.sub_socket.connect("tcp://localhost:5556") + self.sub_socket.setsockopt(zmq.SUBSCRIBE, b"") self.image_res = (640, 480) self.pixel_center = (self.image_res[0] / 2, self.image_res[1] / 2) self.HFOV = 69 @@ -42,19 +41,19 @@ def find_intersection(self, target_dir, target_insct): t = (plane_norm.dot(plane_pt) - plane_norm.dot(target_insct)) / plane_norm.dot(target_dir) return target_insct + (t * target_dir) - + def get_movement_vectors(self, yaw, pitch): gatt = self.drone.get_state(attitude) current_gimbal_pitch = gatt[0]["pitch_absolute"] alt = self.drone.get_state(AltitudeChanged) current_drone_altitude = alt["altitude"] - + forward_vec = [0, 1, 0] - r = R.from_euler('ZYX', [yaw, 0, pitch + current_gimbal_pitch], degrees=True) + r = R.from_euler("ZYX", [yaw, 0, pitch + current_gimbal_pitch], degrees=True) target_dir = r.as_matrix().dot(forward_vec) target_vec = self.find_intersection(target_dir, np.array([0, 0, current_drone_altitude])) print(f"Distance estimate: {np.linalg.norm(target_vec)}") - + leash_vec = self.leash * (target_vec / np.linalg.norm(target_vec)) print(f"Leash vector: {leash_vec}") movement_vec = target_vec - leash_vec @@ -65,17 +64,29 @@ def get_movement_vectors(self, yaw, pitch): def calculate_offsets(self, box): target_x_pix = int((((box[3] - box[1]) / 2.0) + box[1]) * self.image_res[0]) target_y_pix = int((1 - (((box[2] - box[0]) / 2.0) + box[0])) * self.image_res[1]) - target_yaw_angle = ((target_x_pix - self.pixel_center[0]) / self.pixel_center[0]) * (self.HFOV / 2) - target_pitch_angle = ((target_y_pix - self.pixel_center[1]) / self.pixel_center[1]) * (self.VFOV / 2) + target_yaw_angle = ((target_x_pix - self.pixel_center[0]) / self.pixel_center[0]) * ( + self.HFOV / 2 + ) + target_pitch_angle = ((target_y_pix - self.pixel_center[1]) / self.pixel_center[1]) * ( + self.VFOV / 2 + ) drone_roll, drone_pitch = self.get_movement_vectors(target_yaw_angle, target_pitch_angle) - if self.hysteresis and self.prev_center_ts != None and round(time.time() * 1000) - self.prev_center_ts < 500: - hysteresis_yaw_angle = ((self.prev_center[0] - target_x_pix) / self.prev_center[0]) * (self.HFOV / 2) - hysteresis_pitch_angle = ((self.prev_center[1] - target_y_pix) / self.prev_center[1]) * (self.VFOV / 2) + if ( + self.hysteresis + and self.prev_center_ts is not None + and round(time.time() * 1000) - self.prev_center_ts < 500 + ): + hysteresis_yaw_angle = ((self.prev_center[0] - target_x_pix) / self.prev_center[0]) * ( + self.HFOV / 2 + ) + hysteresis_pitch_angle = ( + (self.prev_center[1] - target_y_pix) / self.prev_center[1] + ) * (self.VFOV / 2) target_yaw_angle += 0.90 * hysteresis_yaw_angle target_pitch_angle += 0.90 * hysteresis_pitch_angle - + self.prev_center_ts = round(time.time() * 1000) self.prev_center = (target_x_pix, target_y_pix) @@ -94,11 +105,13 @@ def gain(self, gpitch, dyaw, dpitch, droll): def execute_PCMD(self, gpitch, dyaw, dpitch, droll): gpitch, dyaw, dpitch, droll = self.gain(gpitch, dyaw, dpitch, droll) - print(f"Gimbal Pitch: {gpitch}, Drone Yaw: {dyaw}, Drone Pitch: {dpitch}, Drone Roll: {droll}") + print( + f"Gimbal Pitch: {gpitch}, Drone Yaw: {dyaw}, Drone Pitch: {dpitch}, Drone Roll: {droll}" + ) self.drone(PCMD(1, 0, dpitch, dyaw, 0, timestampAndSeqNum=0)) - #gatt = self.drone.get_state(attitude) - #current_gimbal_pitch = gatt[0]["pitch_absolute"] - #self.drone(set_target(0, control_mode.position, "none", 0.0, "absolute", current_gimbal_pitch + gpitch, "none", 0.0)) + # gatt = self.drone.get_state(attitude) + # current_gimbal_pitch = gatt[0]["pitch_absolute"] + # self.drone(set_target(0, control_mode.position, "none", 0.0, "absolute", current_gimbal_pitch + gpitch, "none", 0.0)) def run(self): self.tracking = False @@ -109,11 +122,13 @@ def run(self): det = json.loads(self.sub_socket.recv_json()) if len(det) > 0: if not self.tracking: - print("Starting new track on object: \"{0}\"".format(det[0]["class"])) + print('Starting new track on object: "{}"'.format(det[0]["class"])) else: print(f"Got detection from the cloudlet: {det}") self.tracking = True - gimbal_pitch, drone_yaw, drone_pitch, drone_roll = self.calculate_offsets(det[0]["box"]) + gimbal_pitch, drone_yaw, drone_pitch, drone_roll = self.calculate_offsets( + det[0]["box"] + ) self.execute_PCMD(gimbal_pitch, drone_yaw, drone_pitch, drone_roll) except Exception as e: print(f"Exception: {e}") diff --git a/field-testing/trackers/parrot.py b/field-testing/trackers/parrot.py index 97fb04f8..b3f88133 100644 --- a/field-testing/trackers/parrot.py +++ b/field-testing/trackers/parrot.py @@ -2,18 +2,19 @@ # # SPDX-License-Identifier: GPL-2.0-only -import olympe -from olympe.messages.ardrone3.Piloting import TakeOff, moveBy, Landing, moveTo, NavigateHome, PCMD -from olympe.messages.ardrone3.PilotingState import AttitudeChanged, GpsLocationChanged, AltitudeChanged -import olympe.messages.follow_me as follow_me -from olympe.enums.follow_me import mode -from olympe.messages.gimbal import set_target, attitude -from olympe.enums.gimbal import control_mode +import json import threading import time + +import numpy as np +import olympe.messages.follow_me as follow_me import zmq -import json from geopy.distance import geodesic as GD +from olympe.enums.follow_me import mode +from olympe.messages.ardrone3.PilotingState import ( + AltitudeChanged, + GpsLocationChanged, +) class ParrotFollowMeTracker(threading.Thread): @@ -22,14 +23,14 @@ def __init__(self, drone, mode=mode.look_at): self.behavior = mode self.context = zmq.Context() self.sub_socket = self.context.socket(zmq.SUB) - self.sub_socket.connect('tcp://localhost:5556') - self.sub_socket.setsockopt(zmq.SUBSCRIBE, b'') + self.sub_socket.connect("tcp://localhost:5556") + self.sub_socket.setsockopt(zmq.SUBSCRIBE, b"") super().__init__() def calculate_azimuth_elevation(self, target_lat, target_lon): - gps = drone.get_state(GpsLocationChanged) - alt = drone.get_state(AltitudeChanged) - + gps = self.drone.get_state(GpsLocationChanged) + alt = self.drone.get_state(AltitudeChanged) + # Elevation calculation d = GD((target_lat, target_lon), (gps["latitude"], gps["longitude"])) elev = np.arctan(alt["altitude"] / d.m) @@ -39,7 +40,9 @@ def calculate_azimuth_elevation(self, target_lat, target_lon): drone_lat = gps["latitude"] * (np.pi * 180) delta_lon = (gps["longitude"] - target_lon) * (np.pi * 180) y = np.sin(delta_lon) * np.cos(drone_lat) - x = np.cos(target_lat) * np.sin(drone_lat) - np.sin(target_lat) * np.cos(drone_lat) * np.cos(delta_lon) + x = np.cos(target_lat) * np.sin(drone_lat) - np.sin(target_lat) * np.cos( + drone_lat + ) * np.cos(delta_lon) azi = np.arctan2(y, x) return azi, elev @@ -56,27 +59,35 @@ def run(self): try: det = json.loads(self.sub_socket.recv_json()) if not self.tracking and len(det) > 0: - print("Starting new track on object: \"{0}\"".format(det[0]["class"])) + print('Starting new track on object: "{}"'.format(det[0]["class"])) self.tracking = True azi, elev = self.calculate_azimuth_elevation(det[0]["lat"], det[0]["lon"]) conf = int(det[0]["score"] * 255) self.start = self.current_time_millis() self.drone(follow_me.set_target_is_controller(0)) - self.drone(follow_me.target_image_detection(azi, elev, 0.0, conf, 1, self.current_time_millis() - self.start)) + self.drone( + follow_me.target_image_detection( + azi, elev, 0.0, conf, 1, self.current_time_millis() - self.start + ) + ) self.drone(follow_me.target_framing_position(50, 50)) self.drone(follow_me.start(self.behavior, _no_expect=True)) elif self.tracking and len(det) > 0: - print("Got detection from cloudlet: {0}".format(json.dumps(det))) + print(f"Got detection from cloudlet: {json.dumps(det)}") azi, elev = self.calculate_azimuth_elevation(det[0]["lat"], det[0]["lon"]) conf = int(det[0]["score"] * 255) - self.drone(follow_me.target_image_detection(azi, elev, 0.0, conf, 0, self.current_time_millis() - self.start)) + self.drone( + follow_me.target_image_detection( + azi, elev, 0.0, conf, 0, self.current_time_millis() - self.start + ) + ) info = self.drone.get_state(follow_me.mode_info) state = self.drone.get_state(follow_me.target_image_detection_state) print(info) print(state) except Exception as e: print(f"Exception: {e}") - + def stop(self): self.context.destroy() self.active = False diff --git a/field-testing/trackers/static.py b/field-testing/trackers/static.py index 6b1d8d16..465e0414 100644 --- a/field-testing/trackers/static.py +++ b/field-testing/trackers/static.py @@ -2,16 +2,16 @@ # # SPDX-License-Identifier: GPL-2.0-only -import olympe -from olympe.messages.ardrone3.Piloting import PCMD -from olympe.messages.ardrone3.PilotingState import AltitudeChanged -from olympe.messages.gimbal import set_target, attitude -from olympe.enums.gimbal import control_mode -import numpy as np +import json import threading import time + +import numpy as np import zmq -import json +from olympe.enums.gimbal import control_mode +from olympe.messages.ardrone3.Piloting import PCMD +from olympe.messages.ardrone3.PilotingState import AltitudeChanged +from olympe.messages.gimbal import set_target class StaticLeashTracker(threading.Thread): @@ -20,8 +20,8 @@ def __init__(self, drone, leash=20.0): self.leash = leash self.context = zmq.Context() self.sub_socket = self.context.socket(zmq.SUB) - self.sub_socket.connect('tcp://localhost:5556') - self.sub_socket.setsockopt(zmq.SUBSCRIBE, b'') + self.sub_socket.connect("tcp://localhost:5556") + self.sub_socket.setsockopt(zmq.SUBSCRIBE, b"") self.image_res = (640, 480) self.pixel_center = (self.image_res[0] / 2, self.image_res[1] / 2) self.HFOV = 69 @@ -32,10 +32,12 @@ def set_leash(self): alt = self.drone.get_state(AltitudeChanged)["altitude"] print(alt) print(self.leash) - print(np.arctan(self.leash/alt)) + print(np.arctan(self.leash / alt)) angle = -1 * (90 - (np.arctan(self.leash / alt) * (180 / np.pi))) print(angle) - self.drone(set_target(0, control_mode.position, "none", 0.0, "absolute", angle, "none", 0.0)) + self.drone( + set_target(0, control_mode.position, "none", 0.0, "absolute", angle, "none", 0.0) + ) def pitch_step_func(self, p): CUTOFFS = [0.0, 0.35, 0.45, 0.55, 0.65, 1.0] @@ -68,7 +70,7 @@ def roll_step_func(self, r): return SPEEDS[4] def calculate_offsets(self, box): - target_x_percentage = ((box[3] - box[1]) / 2.0) + box[1] + target_x_percentage = ((box[3] - box[1]) / 2.0) + box[1] roll_speed = self.roll_step_func(target_x_percentage) target_y_percentage = 1 - (((box[2] - box[0]) / 2.0) + box[0]) pitch_speed = self.pitch_step_func(target_y_percentage) @@ -88,7 +90,7 @@ def run(self): det = json.loads(self.sub_socket.recv_json()) if len(det) > 0: if not self.tracking: - print("Starting new track on object: \"{0}\"".format(det[0]["class"])) + print('Starting new track on object: "{}"'.format(det[0]["class"])) else: print(f"Got detection from the cloudlet: {det}") self.tracking = True diff --git a/onboard/python/implementation/cloudlets/PartialOffloadCloudlet.py b/onboard/python/implementation/cloudlets/PartialOffloadCloudlet.py index 4228aa73..33bcad0c 100644 --- a/onboard/python/implementation/cloudlets/PartialOffloadCloudlet.py +++ b/onboard/python/implementation/cloudlets/PartialOffloadCloudlet.py @@ -2,28 +2,26 @@ # # SPDX-License-Identifier: GPL-2.0-only -from interfaces import CloudletItf +import asyncio import json -from json import JSONDecodeError -import threading -import time import logging -import asyncio -import cv2 +from json import JSONDecodeError +import cv2 from cnc_protocol import cnc_pb2 -from gabriel_protocol import gabriel_pb2 from gabriel_client.websocket_client import ProducerWrapper +from gabriel_protocol import gabriel_pb2 +from interfaces import CloudletItf logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) -class PartialOffloadCloudlet(CloudletItf.CloudletItf): +class PartialOffloadCloudlet(CloudletItf.CloudletItf): def __init__(self): self.engine_results = {} - self.source = 'telemetry' - self.model = 'coco' + self.source = "telemetry" + self.model = "coco" self.drone = None self.sample_rate = 1 self.stop = True @@ -34,15 +32,15 @@ def processResults(self, result_wrapper): for result in result_wrapper.results: if result.payload_type == gabriel_pb2.PayloadType.TEXT: - payload = result.payload.decode('utf-8') + payload = result.payload.decode("utf-8") data = "" try: if len(payload) != 0: data = json.loads(payload) producer = result_wrapper.result_producer_name.value self.engine_results[producer] = result - except JSONDecodeError as e: - logger.error(f'Error decoding json: {payload}') + except JSONDecodeError: + logger.error(f"Error decoding json: {payload}") except Exception as e: print(e) else: @@ -74,7 +72,7 @@ async def producer(): input_frame = gabriel_pb2.InputFrame() if not self.stop: try: - _, frame = cv2.imencode('.jpg', self.drone.getVideoFrame()) + _, frame = cv2.imencode(".jpg", self.drone.getVideoFrame()) input_frame.payload_type = gabriel_pb2.PayloadType.IMAGE input_frame.payloads.append(frame.tobytes()) @@ -83,8 +81,8 @@ async def producer(): input_frame.extras.Pack(extras) except Exception as e: input_frame.payload_type = gabriel_pb2.PayloadType.TEXT - input_frame.payloads.append("Unable to produce a frame!".encode('utf-8')) - logger.error(f'Unable to produce a frame: {e}') + input_frame.payloads.append(b"Unable to produce a frame!") + logger.error(f"Unable to produce a frame: {e}") else: input_frame.payload_type = gabriel_pb2.PayloadType.TEXT input_frame.payloads.append("Streaming not started, no frame to show.") @@ -101,4 +99,3 @@ def getResults(self, engine_key): def clearResults(self, engine_key): self.engine_results[engine_key] = None - diff --git a/onboard/python/implementation/cloudlets/PureOffloadCloudlet.py b/onboard/python/implementation/cloudlets/PureOffloadCloudlet.py index 9c44c44d..cb45ea7c 100644 --- a/onboard/python/implementation/cloudlets/PureOffloadCloudlet.py +++ b/onboard/python/implementation/cloudlets/PureOffloadCloudlet.py @@ -2,35 +2,32 @@ # # SPDX-License-Identifier: GPL-2.0-only -from interfaces import CloudletItf +import asyncio import json -from json import JSONDecodeError -import threading -import time import logging -import asyncio -from syncer import sync -import cv2 -from typing import Tuple +from json import JSONDecodeError +import cv2 from cnc_protocol import cnc_pb2 -from gabriel_protocol import gabriel_pb2 from gabriel_client.websocket_client import ProducerWrapper +from gabriel_protocol import gabriel_pb2 +from interfaces import CloudletItf +from syncer import sync logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) -class PureOffloadCloudlet(CloudletItf.CloudletItf): +class PureOffloadCloudlet(CloudletItf.CloudletItf): def __init__(self): self.engine_results = {} - self.source = 'telemetry' - self.model = 'coco' + self.source = "telemetry" + self.model = "coco" self.drone = None self.sample_rate = 1 self.stop = True - self.hsv_upper = [50,255,255] - self.hsv_lower = [30,100,100] + self.hsv_upper = [50, 255, 255] + self.hsv_lower = [30, 100, 100] def processResults(self, result_wrapper): if len(result_wrapper.results) != 1: @@ -38,15 +35,15 @@ def processResults(self, result_wrapper): for result in result_wrapper.results: if result.payload_type == gabriel_pb2.PayloadType.TEXT: - payload = result.payload.decode('utf-8') + payload = result.payload.decode("utf-8") data = "" try: if len(payload) != 0: data = json.loads(payload) producer = result_wrapper.result_producer_name.value self.engine_results[producer] = result - except JSONDecodeError as e: - logger.debug(f'Error decoding json: {payload}') + except JSONDecodeError: + logger.debug(f"Error decoding json: {payload}") except Exception as e: print(e) else: @@ -64,7 +61,7 @@ def stopStreaming(self): def switchModel(self, model): self.model = model - def setHSVFilter(self, lower_bound: Tuple[int, int, int], upper_bound: Tuple[int, int, int]): + def setHSVFilter(self, lower_bound: tuple[int, int, int], upper_bound: tuple[int, int, int]): self.hsv_lower = lower_bound self.hsv_upper = upper_bound @@ -89,7 +86,7 @@ async def producer(): if not self.stop: try: f = sync(self.drone.getVideoFrame()) - _, frame = cv2.imencode('.jpg', f) + _, frame = cv2.imencode(".jpg", f) input_frame.payload_type = gabriel_pb2.PayloadType.IMAGE input_frame.payloads.append(frame.tobytes()) @@ -98,8 +95,8 @@ async def producer(): input_frame.extras.Pack(extras) except Exception as e: input_frame.payload_type = gabriel_pb2.PayloadType.TEXT - input_frame.payloads.append("Unable to produce a frame!".encode('utf-8')) - logger.debug(f'Unable to produce a frame: {e}') + input_frame.payloads.append(b"Unable to produce a frame!") + logger.debug(f"Unable to produce a frame: {e}") else: input_frame.payload_type = gabriel_pb2.PayloadType.TEXT input_frame.payloads.append("Streaming not started, no frame to show.") diff --git a/onboard/python/implementation/drones/MavlinkDrone.py b/onboard/python/implementation/drones/MavlinkDrone.py index 791e4b38..d14add38 100644 --- a/onboard/python/implementation/drones/MavlinkDrone.py +++ b/onboard/python/implementation/drones/MavlinkDrone.py @@ -3,39 +3,51 @@ # SPDX-License-Identifier: GPL-2.0-only import asyncio -from interfaces import DroneItf +import logging import math -from mavsdk import System -from mavsdk.offboard import (OffboardError, PositionNedYaw, VelocityBodyYawspeed, PositionGlobalYaw) +import math as m +import os +import threading import time + +import cv2 import numpy as np -import math as m -import logging +from interfaces import DroneItf +from mavsdk import System +from mavsdk.offboard import ( + OffboardError, + PositionGlobalYaw, + VelocityBodyYawspeed, +) logger = logging.getLogger() + def bearing(origin, destination): - lat1, lon1 = origin - lat2, lon2 = destination + lat1, lon1 = origin + lat2, lon2 = destination - rlat1 = math.radians(lat1) - rlat2 = math.radians(lat2) - rlon1 = math.radians(lon1) - rlon2 = math.radians(lon2) - dlon = math.radians(lon2-lon1) + rlat1 = math.radians(lat1) + rlat2 = math.radians(lat2) + rlon1 = math.radians(lon1) + rlon2 = math.radians(lon2) + dlon = math.radians(lon2 - lon1) + + b = math.atan2( + math.sin(dlon) * math.cos(rlat2), + math.cos(rlat1) * math.sin(rlat2) - math.sin(rlat1) * math.cos(rlat2) * math.cos(dlon), + ) + bd = math.degrees(b) + br, bn = divmod(bd + 360, 360) + + return bn - b = math.atan2(math.sin(dlon)*math.cos(rlat2),math.cos(rlat1)*math.sin(rlat2)-math.sin(rlat1)*math.cos(rlat2)*math.cos(dlon)) - bd = math.degrees(b) - br,bn = divmod(bd+360,360) - - return bn def get_rot_mat(theta): return np.matrix([[m.cos(theta), -m.sin(theta), 0], [m.sin(theta), m.cos(theta), 0], [0, 0, 1]]) class MavlinkDrone(DroneItf.DroneItf): - VEL_TOL = 0.1 ANG_VEL_TOL = 0.01 RTH_ALT = 20 @@ -44,8 +56,8 @@ def __init__(self, **kwargs): self.drone = System() self.active = False - ''' Awaiting methods ''' - + """ Awaiting methods """ + async def hovering(self, timeout=None): start = time.time() # Allow previous command to take effect @@ -53,61 +65,78 @@ async def hovering(self, timeout=None): async for odometry in self.drone.telemetry.odometry(): velocity = odometry.velocity_body ang_velocity = odometry.angular_velocity_body - logger.info(f'[MavlinkDrone]: Current velocity: {velocity.x_m_s} {velocity.y_m_s} {velocity.z_m_s} {ang_velocity.yaw_rad_s}') - if velocity.x_m_s < self.VEL_TOL and velocity.y_m_s < self.VEL_TOL and velocity.z_m_s < self.VEL_TOL and ang_velocity.yaw_rad_s < self.ANG_VEL_TOL: - break # We are now hovering! + logger.info( + f"[MavlinkDrone]: Current velocity: {velocity.x_m_s} {velocity.y_m_s} {velocity.z_m_s} {ang_velocity.yaw_rad_s}" + ) + if ( + velocity.x_m_s < self.VEL_TOL + and velocity.y_m_s < self.VEL_TOL + and velocity.z_m_s < self.VEL_TOL + and ang_velocity.yaw_rad_s < self.ANG_VEL_TOL + ): + break # We are now hovering! else: - if timeout and time.time() - start > timeout: # Break with timeout + if timeout and time.time() - start > timeout: # Break with timeout break - + async def telemetry_subscriber(self): async def pos(self): await self.drone.telemetry.set_rate_position(1) async for position in self.drone.telemetry.position(): - self.telemetry['lat'] = position.latitude_deg - self.telemetry['lng'] = position.longitude_deg - self.telemetry['alt'] = position.absolute_altitude_m - self.telemetry['rel-alt'] = position.relative_altitude_m + self.telemetry["lat"] = position.latitude_deg + self.telemetry["lng"] = position.longitude_deg + self.telemetry["alt"] = position.absolute_altitude_m + self.telemetry["rel-alt"] = position.relative_altitude_m + async def head(self): async for heading in self.drone.telemetry.heading(): - self.telemetry['head'] = heading.heading_deg + self.telemetry["head"] = heading.heading_deg + async def battery(self): await self.drone.telemetry.set_rate_battery(1) async for battery in self.drone.telemetry.battery(): - self.telemetry['battery'] = battery.remaining_percent + self.telemetry["battery"] = battery.remaining_percent + async def mag(self): async for health in self.drone.telemetry.health(): - self.telemetry['mag'] = health.is_magnetometer_calibration_ok + self.telemetry["mag"] = health.is_magnetometer_calibration_ok + async def sat(self): await self.drone.telemetry.set_rate_gps_info(1) async for info in self.drone.telemetry.gps_info(): - self.telemetry['sat'] = info.num_satellites + self.telemetry["sat"] = info.num_satellites asyncio.gather(pos(self), head(self), battery(self), mag(self), sat(self)) - ''' Connection methods ''' + """ Connection methods """ async def connect(self): await self.drone.connect() # Set max speed for use by PCMD self.max_speed = await self.drone.action.get_maximum_speed() self.active = True - self.telemetry = {'battery': 100, 'rssi': 0, 'mag': 0, 'heading': 0, 'sat': 0, - 'lat': 0, 'lng': 0, 'alt': 0, 'rel-alt': 0} + self.telemetry = { + "battery": 100, + "rssi": 0, + "mag": 0, + "heading": 0, + "sat": 0, + "lat": 0, + "lng": 0, + "alt": 0, + "rel-alt": 0, + } asyncio.create_task(self.telemetry_subscriber()) async def isConnected(self): async for state in self.drone.core.connection_state(): - if state.is_connected: - return True - else: - return False + return bool(state.is_connected) async def disconnect(self): await self.drone.action.disarm() self.active = False - ''' Streaming methods ''' + """ Streaming methods """ async def startStreaming(self, **kwargs): pass @@ -118,7 +147,7 @@ async def getVideoFrame(self): async def stopStreaming(self): pass - ''' Take off / Landing methods ''' + """ Take off / Landing methods """ async def takeOff(self): await self.drone.action.arm() @@ -129,28 +158,35 @@ async def takeOff(self): await self.drone.offboard.set_velocity_body(VelocityBodyYawspeed(0.0, 0.0, 0.0, 0.0)) try: await self.drone.offboard.start() - except Exception as e: + except Exception: await self.land() async def land(self): try: await self.drone.offboard.stop() - except OffboardError as error: + except OffboardError: pass await self.drone.action.land() await self.drone.action.disarm() async def setHome(self, lat, lng, alt): - raise NotImplemented() + raise NotImplementedError() async def rth(self): await self.drone.action.set_return_to_launch_altitude(self.RTH_ALT) await self.drone.action.return_to_launch() - ''' Movement methods ''' + """ Movement methods """ async def PCMD(self, roll, pitch, yaw, gaz): - await self.drone.offboard.set_velocity_body(VelocityBodyYawspeed((pitch/100)*self.max_speed, (roll/100)*self.max_speed, (-1 * gaz/100)*self.max_speed, float(yaw))) + await self.drone.offboard.set_velocity_body( + VelocityBodyYawspeed( + (pitch / 100) * self.max_speed, + (roll / 100) * self.max_speed, + (-1 * gaz / 100) * self.max_speed, + float(yaw), + ) + ) async def moveTo(self, lat, lng, alt): # Get bearing to target @@ -158,12 +194,14 @@ async def moveTo(self, lat, lng, alt): currentLng = await self.getLng() b = bearing((currentLat, currentLng), (lat, lng)) try: - await self.drone.offboard.set_position_global(PositionGlobalYaw(lat, lng, alt, b, PositionGlobalYaw.AltitudeType(0))) + await self.drone.offboard.set_position_global( + PositionGlobalYaw(lat, lng, alt, b, PositionGlobalYaw.AltitudeType(0)) + ) except OffboardError as e: - logger.error(f'[MavlinkDrone] Offboard command failed: {e.result}') - logger.info('[MavlinkDrone] Awaiting hover state') + logger.error(f"[MavlinkDrone] Offboard command failed: {e.result}") + logger.info("[MavlinkDrone] Awaiting hover state") await self.hovering() - logger.info('[MavlinkDrone] Got to hover state') + logger.info("[MavlinkDrone] Got to hover state") async def moveBy(self, x, y, z, t): v = np.array([x, y, z]) @@ -174,7 +212,7 @@ async def moveBy(self, x, y, z, t): await self.hovering() async def rotateTo(self, theta): - await self.moveBy(0.0, 0.0, 0.0, theta) + await self.moveBy(0.0, 0.0, 0.0, theta) async def setGimbalPose(self, yaw_theta, pitch_theta, roll_theta): pass @@ -182,20 +220,20 @@ async def setGimbalPose(self, yaw_theta, pitch_theta, roll_theta): async def hover(self): self.drone.action.hold() - ''' Photography methods ''' + """ Photography methods """ async def takePhoto(self): - raise NotImplemented() + raise NotImplementedError() async def toggleThermal(self, on): - raise NotImplemented() + raise NotImplementedError() - ''' Status methods ''' + """ Status methods """ async def getName(self): try: product = await self.drone.info.get_product() - if product.product_name is not None and product.product_name != 'undefined': + if product.product_name is not None and product.product_name != "undefined": return product.product_name except Exception: pass @@ -203,46 +241,40 @@ async def getName(self): return "MavlinkDrone" async def getLat(self): - return self.telemetry['lat'] + return self.telemetry["lat"] async def getLng(self): - return self.telemetry['lng'] + return self.telemetry["lng"] async def getHeading(self): - return self.telemetry['head'] + return self.telemetry["head"] async def getRelAlt(self): - return self.telemetry['rel-alt'] + return self.telemetry["rel-alt"] async def getExactAlt(self): - return self.telemetry['alt'] - + return self.telemetry["alt"] + async def getRSSI(self): return 0 async def getBatteryPercentage(self): - return round(self.telemetry['battery'] * 100) + return round(self.telemetry["battery"] * 100) async def getMagnetometerReading(self): - return self.telemetry['mag'] + return self.telemetry["mag"] async def getSatellites(self): - return self.telemetry['sat'] + return self.telemetry["sat"] async def kill(self): self.active = False -import cv2 -import numpy as np -import os -import threading - class StreamingThread(threading.Thread): - def __init__(self, drone, ip): threading.Thread.__init__(self) - self.currentFrame = None + self.currentFrame = None self.drone = drone os.environ["OPENCV_FFMPEG_CAPTURE_OPTIONS"] = "rtsp_transport;udp" self.cap = cv2.VideoCapture(f"rtsp://{ip}/live", cv2.CAP_FFMPEG) @@ -250,7 +282,7 @@ def __init__(self, drone, ip): def run(self): try: - while(self.isRunning): + while self.isRunning: ret, self.currentFrame = self.cap.read() except Exception as e: print(e) @@ -259,9 +291,9 @@ def grabFrame(self): try: frame = self.currentFrame.copy() return frame - except Exception as e: + except Exception: # Send a blank frame - return np.zeros((720, 1280, 3), np.uint8) + return np.zeros((720, 1280, 3), np.uint8) def stop(self): self.isRunning = False diff --git a/onboard/python/implementation/drones/ModalAISeekerDrone.py b/onboard/python/implementation/drones/ModalAISeekerDrone.py index b40fa2b8..af411bc1 100644 --- a/onboard/python/implementation/drones/ModalAISeekerDrone.py +++ b/onboard/python/implementation/drones/ModalAISeekerDrone.py @@ -3,51 +3,64 @@ # SPDX-License-Identifier: GPL-2.0-only import asyncio -from interfaces import DroneItf +import logging import math -from mavsdk import System -from mavsdk.offboard import (OffboardError, PositionNedYaw, VelocityBodyYawspeed, PositionGlobalYaw) +import math as m +import os +import threading import time + +import cv2 import numpy as np -import math as m -import logging +from interfaces import DroneItf +from mavsdk import System +from mavsdk.offboard import ( + OffboardError, + PositionGlobalYaw, + PositionNedYaw, + VelocityBodyYawspeed, +) logger = logging.getLogger(__name__) logging.basicConfig() logger.setLevel(logging.INFO) + def bearing(origin, destination): - lat1, lon1 = origin - lat2, lon2 = destination + lat1, lon1 = origin + lat2, lon2 = destination + + rlat1 = math.radians(lat1) + rlat2 = math.radians(lat2) + rlon1 = math.radians(lon1) + rlon2 = math.radians(lon2) + dlon = math.radians(lon2 - lon1) - rlat1 = math.radians(lat1) - rlat2 = math.radians(lat2) - rlon1 = math.radians(lon1) - rlon2 = math.radians(lon2) - dlon = math.radians(lon2-lon1) + b = math.atan2( + math.sin(dlon) * math.cos(rlat2), + math.cos(rlat1) * math.sin(rlat2) - math.sin(rlat1) * math.cos(rlat2) * math.cos(dlon), + ) + bd = math.degrees(b) + br, bn = divmod(bd + 360, 360) - b = math.atan2(math.sin(dlon)*math.cos(rlat2),math.cos(rlat1)*math.sin(rlat2)-math.sin(rlat1)*math.cos(rlat2)*math.cos(dlon)) - bd = math.degrees(b) - br,bn = divmod(bd+360,360) + return bn - return bn def get_rot_mat(theta): return np.matrix([[m.cos(theta), -m.sin(theta), 0], [m.sin(theta), m.cos(theta), 0], [0, 0, 1]]) class ModalAISeekerDrone(DroneItf.DroneItf): - VEL_TOL = 0.1 ANG_VEL_TOL = 0.01 RTH_ALT = 20 def __init__(self, **kwargs): - self.server_address = kwargs['server_address'] + self.server_address = kwargs["server_address"] self.drone = System(mavsdk_server_address=self.server_address, port=50051) self.active = False - ''' Awaiting methods ''' + """ Awaiting methods """ async def hovering(self, timeout=None): start = time.time() @@ -56,39 +69,50 @@ async def hovering(self, timeout=None): async for odometry in self.drone.telemetry.odometry(): velocity = odometry.velocity_body ang_velocity = odometry.angular_velocity_body - logger.debug(f'[MavlinkDrone]: Current velocity: {velocity.x_m_s} {velocity.y_m_s} {velocity.z_m_s} {ang_velocity.yaw_rad_s}') - if velocity.x_m_s < self.VEL_TOL and velocity.y_m_s < self.VEL_TOL and velocity.z_m_s < self.VEL_TOL and ang_velocity.yaw_rad_s < self.ANG_VEL_TOL: - break # We are now hovering! + logger.debug( + f"[MavlinkDrone]: Current velocity: {velocity.x_m_s} {velocity.y_m_s} {velocity.z_m_s} {ang_velocity.yaw_rad_s}" + ) + if ( + velocity.x_m_s < self.VEL_TOL + and velocity.y_m_s < self.VEL_TOL + and velocity.z_m_s < self.VEL_TOL + and ang_velocity.yaw_rad_s < self.ANG_VEL_TOL + ): + break # We are now hovering! else: - if timeout and time.time() - start > timeout: # Break with timeout + if timeout and time.time() - start > timeout: # Break with timeout break async def telemetry_subscriber(self): async def pos(self): async for position in self.drone.telemetry.position(): - self.telemetry['lat'] = position.latitude_deg - self.telemetry['lng'] = position.longitude_deg - self.telemetry['alt'] = position.absolute_altitude_m - self.telemetry['rel-alt'] = position.relative_altitude_m + self.telemetry["lat"] = position.latitude_deg + self.telemetry["lng"] = position.longitude_deg + self.telemetry["alt"] = position.absolute_altitude_m + self.telemetry["rel-alt"] = position.relative_altitude_m + async def head(self): async for heading in self.drone.telemetry.heading(): - self.telemetry['head'] = heading.heading_deg + self.telemetry["head"] = heading.heading_deg + async def battery(self): async for battery in self.drone.telemetry.battery(): - self.telemetry['battery'] = battery.remaining_percent + self.telemetry["battery"] = battery.remaining_percent + async def mag(self): async for health in self.drone.telemetry.health(): - self.telemetry['mag'] = health.is_magnetometer_calibration_ok + self.telemetry["mag"] = health.is_magnetometer_calibration_ok + async def sat(self): async for info in self.drone.telemetry.gps_info(): - self.telemetry['sat'] = info.num_satellites + self.telemetry["sat"] = info.num_satellites try: await asyncio.gather(pos(self), head(self), battery(self), mag(self), sat(self)) except Exception as e: logger.error(f"Exception: {e}") - ''' Connection methods ''' + """ Connection methods """ async def connect(self): system_address = f"udp://{self.server_address}:14550" @@ -98,22 +122,28 @@ async def connect(self): # Set max speed for use by PCMD # self.max_speed = await self.drone.action.get_maximum_speed() self.active = True - self.telemetry = {'battery': 100, 'rssi': 0, 'mag': 0, 'heading': 0, 'sat': 0, - 'lat': 0, 'lng': 0, 'alt': 0, 'rel-alt': 0} + self.telemetry = { + "battery": 100, + "rssi": 0, + "mag": 0, + "heading": 0, + "sat": 0, + "lat": 0, + "lng": 0, + "alt": 0, + "rel-alt": 0, + } self.telemetry_task = asyncio.create_task(self.telemetry_subscriber()) async def isConnected(self): async for state in self.drone.core.connection_state(): - if state.is_connected: - return True - else: - return False + return bool(state.is_connected) async def disconnect(self): await self.drone.action.disarm() self.active = False - ''' Streaming methods ''' + """ Streaming methods """ async def startStreaming(self, **kwargs): self.streamingThread = StreamingThread(self.drone, "127.0.0.1:8900") @@ -127,7 +157,7 @@ async def stopStreaming(self): async def startOffboardMode(self): # Initial setpoint for offboard control - #await self.drone.offboard.set_velocity_body(VelocityBodyYawspeed(0.0, 0.0, 0.0, 0.0)) + # await self.drone.offboard.set_velocity_body(VelocityBodyYawspeed(0.0, 0.0, 0.0, 0.0)) await self.drone.offboard.set_position_ned(PositionNedYaw(0.0, 0.0, 0.0, 0.0)) try: await self.drone.offboard.start() @@ -136,7 +166,7 @@ async def startOffboardMode(self): logger.info("Landing...") await self.land() - ''' Take off / Landing methods ''' + """ Take off / Landing methods """ async def takeOff(self): logger.info("Takeoff: Arming") @@ -157,22 +187,29 @@ async def takeOff(self): async def land(self): try: await self.drone.offboard.stop() - except OffboardError as error: + except OffboardError: pass await self.drone.action.land() await self.drone.action.disarm() async def setHome(self, lat, lng, alt): - raise NotImplemented() + raise NotImplementedError() async def rth(self): await self.drone.action.set_return_to_launch_altitude(self.RTH_ALT) await self.drone.action.return_to_launch() - ''' Movement methods ''' + """ Movement methods """ async def PCMD(self, roll, pitch, yaw, gaz): - await self.drone.offboard.set_velocity_body(VelocityBodyYawspeed((pitch/100)*self.max_speed, (roll/100)*self.max_speed, (-1 * gaz/100)*self.max_speed, float(yaw))) + await self.drone.offboard.set_velocity_body( + VelocityBodyYawspeed( + (pitch / 100) * self.max_speed, + (roll / 100) * self.max_speed, + (-1 * gaz / 100) * self.max_speed, + float(yaw), + ) + ) async def moveTo(self, lat, lng, alt): # Get bearing to target @@ -180,12 +217,14 @@ async def moveTo(self, lat, lng, alt): currentLng = await self.getLng() b = bearing((currentLat, currentLng), (lat, lng)) try: - await self.drone.offboard.set_position_global(PositionGlobalYaw(lat, lng, alt, b, PositionGlobalYaw.AltitudeType(0))) + await self.drone.offboard.set_position_global( + PositionGlobalYaw(lat, lng, alt, b, PositionGlobalYaw.AltitudeType(0)) + ) except OffboardError as e: - logger.error(f'[MavlinkDrone] Offboard command failed: {e.result}') - logger.info('[MavlinkDrone] Awaiting hover state') + logger.error(f"[MavlinkDrone] Offboard command failed: {e.result}") + logger.info("[MavlinkDrone] Awaiting hover state") await self.hovering() - logger.info('[MavlinkDrone] Got to hover state') + logger.info("[MavlinkDrone] Got to hover state") async def moveBy(self, x, y, z, t): v = np.array([x, y, z]) @@ -194,7 +233,9 @@ async def moveBy(self, x, y, z, t): res = np.matmul(R, v) logger.info(f"Move by: setting position: {res}") - await self.drone.offboard.set_position_ned(PositionNedYaw(res[0,0], res[0,1], -1 * res[0,2], h + t)) + await self.drone.offboard.set_position_ned( + PositionNedYaw(res[0, 0], res[0, 1], -1 * res[0, 2], h + t) + ) logger.info("Waiting for hovering") await self.hovering() @@ -221,60 +262,54 @@ async def getSpeedRel(self): async def hover(self): await self.drone.action.hold() - ''' Photography methods ''' + """ Photography methods """ async def takePhoto(self): - raise NotImplemented() + raise NotImplementedError() async def toggleThermal(self, on): - raise NotImplemented() + raise NotImplementedError() - ''' Status methods ''' + """ Status methods """ async def getName(self): return "ModalAISeekerDrone" async def getLat(self): - return self.telemetry['lat'] + return self.telemetry["lat"] async def getLng(self): - return self.telemetry['lng'] + return self.telemetry["lng"] async def getHeading(self): - return self.telemetry['heading'] + return self.telemetry["heading"] async def getRelAlt(self): - return self.telemetry['rel-alt'] + return self.telemetry["rel-alt"] async def getExactAlt(self): - return self.telemetry['alt'] + return self.telemetry["alt"] async def getRSSI(self): return 0 async def getBatteryPercentage(self): - return round(self.telemetry['battery'] * 100) + return round(self.telemetry["battery"] * 100) async def getMagnetometerReading(self): - return self.telemetry['mag'] + return self.telemetry["mag"] async def getSatellites(self): - return self.telemetry['sat'] + return self.telemetry["sat"] - #async def getPositionNed(self): + # async def getPositionNed(self): # return self.telemetry[' async def kill(self): self.active = False -import cv2 -import numpy as np -import os -import threading - class StreamingThread(threading.Thread): - def __init__(self, drone, ip): threading.Thread.__init__(self) self.currentFrame = None @@ -285,9 +320,9 @@ def __init__(self, drone, ip): def run(self): try: - while(self.isRunning): + while self.isRunning: ret, frame = self.cap.read() - if len(frame.shape) < 3: # Grayscale + if len(frame.shape) < 3: # Grayscale self.currentFrame = cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR) else: self.currentFrame = frame @@ -298,7 +333,7 @@ def grabFrame(self): try: frame = self.currentFrame.copy() return frame - except Exception as e: + except Exception: # Send a blank frame return np.zeros((720, 1280, 3), np.uint8) diff --git a/onboard/python/implementation/drones/ParrotAnafiDrone.py b/onboard/python/implementation/drones/ParrotAnafiDrone.py index 95b8a51f..093b265e 100644 --- a/onboard/python/implementation/drones/ParrotAnafiDrone.py +++ b/onboard/python/implementation/drones/ParrotAnafiDrone.py @@ -3,49 +3,57 @@ # SPDX-License-Identifier: GPL-2.0-only import asyncio +import logging +import math +import os +import queue import threading -from interfaces import DroneItf +import time + +import cv2 +import numpy as np import olympe +import olympe.enums.move as move_mode +from interfaces import DroneItf from olympe import Drone -from olympe.messages.ardrone3.Piloting import TakeOff, Landing -from olympe.messages.ardrone3.Piloting import PCMD, moveTo, moveBy -from olympe.messages.rth import set_custom_location, return_to_home -from olympe.messages.ardrone3.PilotingState import moveToChanged -from olympe.messages.common.CommonState import BatteryStateChanged -from olympe.messages.ardrone3.PilotingState import AttitudeChanged, GpsLocationChanged, AltitudeChanged, FlyingStateChanged, SpeedChanged from olympe.messages.ardrone3.GPSState import NumberOfSatelliteChanged -from olympe.messages.gimbal import set_target, attitude -from olympe.messages.wifi import rssi_changed -from olympe.messages.battery import capacity +from olympe.messages.ardrone3.Piloting import PCMD, Landing, TakeOff, moveBy, moveTo +from olympe.messages.ardrone3.PilotingState import ( + AltitudeChanged, + AttitudeChanged, + FlyingStateChanged, + GpsLocationChanged, + SpeedChanged, +) from olympe.messages.common.CalibrationState import MagnetoCalibrationRequiredState -import olympe.enums.move as move_mode -import olympe.enums.gimbal as gimbal_mode -import math -import logging +from olympe.messages.common.CommonState import BatteryStateChanged +from olympe.messages.gimbal import attitude, set_target +from olympe.messages.rth import return_to_home, set_custom_location +from olympe.messages.wifi import rssi_changed logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) -class ParrotAnafiDrone(DroneItf.DroneItf): +class ParrotAnafiDrone(DroneItf.DroneItf): def __init__(self, **kwargs): self.dronename = None - if 'sim' in kwargs: - self.ip = '10.202.0.1' - elif 'droneip' in kwargs: - self.ip = kwargs['droneip'] + if "sim" in kwargs: + self.ip = "10.202.0.1" + elif "droneip" in kwargs: + self.ip = kwargs["droneip"] else: - self.ip = '192.168.42.1' - if 'lowdelay' in kwargs: + self.ip = "192.168.42.1" + if "lowdelay" in kwargs: self.lowdelay = True else: self.lowdelay = False - if 'dronename' in kwargs: - self.dronename = kwargs['dronename'] + if "dronename" in kwargs: + self.dronename = kwargs["dronename"] self.drone = Drone(self.ip) self.active = False - ''' Awaiting methods ''' + """ Awaiting methods """ async def hovering(self, timeout=None): # Let the task start before checking for hover state. @@ -54,14 +62,16 @@ async def hovering(self, timeout=None): if timeout is not None: start = time.time() while True: - if self.drone(FlyingStateChanged(state="hovering", _policy="check")).success(): - break - elif start is not None and time.time() - start < timeout: + if ( + self.drone(FlyingStateChanged(state="hovering", _policy="check")).success() + or start is not None + and time.time() - start < timeout + ): break else: await asyncio.sleep(1) - ''' Connection methods ''' + """ Connection methods """ async def connect(self): self.drone.connect() @@ -77,7 +87,7 @@ async def disconnect(self): self.drone.disconnect() self.active = False - ''' Streaming methods ''' + """ Streaming methods """ async def startStreaming(self, **kwargs): if self.lowdelay: @@ -93,7 +103,7 @@ async def getVideoFrame(self): async def stopStreaming(self): self.streamingThread.stop() - ''' Take off / Landing methods ''' + """ Take off / Landing methods """ async def takeOff(self): self.drone(TakeOff()) @@ -109,23 +119,17 @@ async def rth(self): await self.hover() self.drone(return_to_home()) - ''' Movement methods ''' + """ Movement methods """ async def PCMD(self, roll, pitch, yaw, gaz): - self.drone( - PCMD(1, roll, pitch, yaw, gaz, timestampAndSeqNum=0) - ) + self.drone(PCMD(1, roll, pitch, yaw, gaz, timestampAndSeqNum=0)) async def moveTo(self, lat, lng, alt): - self.drone( - moveTo(lat, lng, alt, move_mode.orientation_mode.to_target, 0.0) - ) + self.drone(moveTo(lat, lng, alt, move_mode.orientation_mode.to_target, 0.0)) await self.hovering() async def moveBy(self, x, y, z, t): - self.drone( - moveBy(x, y, z, t) - ) + self.drone(moveBy(x, y, z, t)) await self.hovering() async def rotateTo(self, theta): @@ -135,21 +139,23 @@ async def rotateTo(self, theta): async def setGimbalPose(self, yaw_theta, pitch_theta, roll_theta): # The Anafi does not support yaw or roll on its gimbal, thus these # parameters are discarded without effect. - self.drone(set_target( - gimbal_id=0, - control_mode="position", - yaw_frame_of_reference="none", - yaw=yaw_theta, - pitch_frame_of_reference="absolute", - pitch=pitch_theta, - roll_frame_of_reference="none", - roll=roll_theta,) + self.drone( + set_target( + gimbal_id=0, + control_mode="position", + yaw_frame_of_reference="none", + yaw=yaw_theta, + pitch_frame_of_reference="absolute", + pitch=pitch_theta, + roll_frame_of_reference="none", + roll=roll_theta, + ) ) async def hover(self): await self.PCMD(0, 0, 0, 0) - ''' Photography methods ''' + """ Photography methods """ async def takePhoto(self): # TODO: Take a photo and save it to the local drone folder @@ -157,12 +163,13 @@ async def takePhoto(self): async def toggleThermal(self, on): from olympe.messages.thermal import set_mode + if on: self.drone(set_mode(mode="blended")).wait().success() else: self.drone(set_mode(mode="disabled")).wait().success() - ''' Status methods ''' + """ Status methods """ async def getName(self): return self.drone._device_name @@ -196,10 +203,14 @@ async def getSpeedRel(self): vecr = np.array([0.0, 1.0], dtype=float) rt = np.radians(hd + 90) c, s = np.cos(rt), np.sin(rt) - R2 = np.array(((c,-s), (s, c))) + R2 = np.array(((c, -s), (s, c))) vecr = np.dot(R2, vecr) - res = {"speedX": np.dot(vec, vecf) * -1, "speedY": np.dot(vec, vecr) * -1, "speedZ": NED["speedZ"]} + res = { + "speedX": np.dot(vec, vecf) * -1, + "speedY": np.dot(vec, vecr) * -1, + "speedZ": NED["speedZ"], + } return res async def getExactAlt(self): @@ -224,23 +235,20 @@ async def kill(self): self.active = False -import cv2 -import numpy as np -import os - class StreamingThread(threading.Thread): - def __init__(self, drone, ip): threading.Thread.__init__(self) self.currentFrame = None self.drone = drone os.environ["OPENCV_FFMPEG_CAPTURE_OPTIONS"] = "rtsp_transport;udp" - self.cap = cv2.VideoCapture(f"rtsp://{ip}/live", cv2.CAP_FFMPEG, (cv2.CAP_PROP_N_THREADS, 1)) + self.cap = cv2.VideoCapture( + f"rtsp://{ip}/live", cv2.CAP_FFMPEG, (cv2.CAP_PROP_N_THREADS, 1) + ) self.isRunning = True def run(self): try: - while(self.isRunning): + while self.isRunning: ret, self.currentFrame = self.cap.read() except Exception as e: print(e) @@ -249,18 +257,16 @@ def grabFrame(self): try: frame = self.currentFrame.copy() return frame - except Exception as e: + except Exception: # Send a blank frame return np.zeros((720, 1280, 3), np.uint8) def stop(self): self.isRunning = False -import queue class LowDelayStreamingThread(threading.Thread): - - def __init__(self, drone, ip, save_frames = False): + def __init__(self, drone, ip, save_frames=False): threading.Thread.__init__(self) self.drone = drone self.frame_queue = queue.Queue() @@ -292,7 +298,7 @@ def grabFrame(self): try: frame = self.currentFrame.copy() return frame - except Exception as e: + except Exception: # Send a blank frame return np.zeros((720, 1280, 3), np.uint8) @@ -311,7 +317,7 @@ def copyFrame(self, yuv_frame): self.currentFrame = cv2.cvtColor(yuv_frame.as_ndarray(), cv2_cvt_color_flag) - ''' Callbacks ''' + """ Callbacks """ def yuvFrameCb(self, yuv_frame): """ @@ -320,7 +326,7 @@ def yuvFrameCb(self, yuv_frame): :type yuv_frame: olympe.VideoFrame """ yuv_frame.ref() - fps = 30 // float(os.environ.get('FPS')) + fps = 30 // float(os.environ.get("FPS")) self.frames_recd += 1 if self.frames_recd == fps: self.frame_queue.put_nowait(yuv_frame) diff --git a/onboard/python/implementation/drones/test.py b/onboard/python/implementation/drones/test.py index 3ffb6627..4175fb4d 100644 --- a/onboard/python/implementation/drones/test.py +++ b/onboard/python/implementation/drones/test.py @@ -1,30 +1,31 @@ -from ModalAISeekerDrone import ModalAISeekerDrone import asyncio import logging - import time +from ModalAISeekerDrone import ModalAISeekerDrone + logger = logging.getLogger(__name__) logging.basicConfig() logger.setLevel(logging.INFO) + async def main(): logger.info("Starting script") # args = {'server_address': '162.172.22.130'} - args = {'server_address': '192.168.8.1'} + args = {"server_address": "192.168.8.1"} drone = ModalAISeekerDrone(**args) await drone.connect() - #connected = await drone.isConnected() - #print(f"Connected: {connected}") + # connected = await drone.isConnected() + # print(f"Connected: {connected}") - #satellites = await drone.getBatteryPercentage() - #print(satellites) + # satellites = await drone.getBatteryPercentage() + # print(satellites) await drone.takeOff() - #logger.info("Done taking off") + # logger.info("Done taking off") await drone.startOffboardMode() @@ -37,5 +38,6 @@ async def main(): await drone.moveBy(0, 0, 10, 0) await drone.setVelocity(0, 0, 0, 0) + if __name__ == "__main__": asyncio.run(main()) diff --git a/onboard/python/supervisor.py b/onboard/python/supervisor.py index 30a396ad..b45ac660 100644 --- a/onboard/python/supervisor.py +++ b/onboard/python/supervisor.py @@ -4,39 +4,40 @@ import argparse import asyncio -import nest_asyncio -nest_asyncio.apply() -from syncer import sync +import importlib import logging -import requests +import os import subprocess import sys -import validators -import os +import time from zipfile import ZipFile -import importlib +# from websocket_client import WebsocketClient +import nest_asyncio +import requests +import validators +import zmq from cnc_protocol import cnc_pb2 -from gabriel_protocol import gabriel_pb2 from gabriel_client.websocket_client import ProducerWrapper, WebsocketClient -#from websocket_client import WebsocketClient +from gabriel_protocol import gabriel_pb2 +from syncer import sync -import zmq +nest_asyncio.apply() logger = logging.getLogger() logger.setLevel(logging.INFO) handler = logging.StreamHandler(sys.stdout) handler.setLevel(logging.INFO) -formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') +formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s") handler.setFormatter(formatter) logger.addHandler(handler) -fh = logging.FileHandler('supervisor.log') -formatter = logging.Formatter('%(asctime)s - %(message)s') +fh = logging.FileHandler("supervisor.log") +formatter = logging.Formatter("%(asctime)s - %(message)s") fh.setFormatter(formatter) logger.addHandler(fh) -class Supervisor: +class Supervisor: def __init__(self, args): # Import the files corresponding to the selected drone/cloudlet self.drone_id = None @@ -44,45 +45,45 @@ def __init__(self, args): cloudlet_import = f"implementation.cloudlets.{args.cloudlet}" try: Drone = importlib.import_module(drone_import) - except Exception as e: - logger.info('Could not import drone {args.drone}') + except Exception: + logger.info("Could not import drone {args.drone}") sys.exit(0) try: Cloudlet = importlib.import_module(cloudlet_import) - except Exception as e: - logger.info('Could not import cloudlet {args.cloudlet}') + except Exception: + logger.info("Could not import cloudlet {args.cloudlet}") sys.exit(0) try: self.cloudlet = getattr(Cloudlet, args.cloudlet)() - except Exception as e: - logger.info('Could not initialize {args.cloudlet}, name does not exist. Aborting.') + except Exception: + logger.info("Could not initialize {args.cloudlet}, name does not exist. Aborting.") sys.exit(0) try: kwargs = {} if args.sim: - kwargs['sim'] = True + kwargs["sim"] = True if args.lowdelay: - kwargs['lowdelay'] = True + kwargs["lowdelay"] = True if args.dronename: - kwargs['dronename'] = args.dronename + kwargs["dronename"] = args.dronename if args.droneip: - kwargs['droneip'] = args.droneip + kwargs["droneip"] = args.droneip logger.info(f"{kwargs=}") self.drone = getattr(Drone, args.drone)(**kwargs) - except Exception as e: - logger.info('Could not initialize {args.drone}, name does not exist. Aborting.') + except Exception: + logger.info("Could not initialize {args.drone}, name does not exist. Aborting.") sys.exit(0) # Set the Gabriel soure - self.source = 'telemetry' + self.source = "telemetry" self.reload = False self.mission = None self.missionTask = None - self.manual = True # Default to manual control + self.manual = True # Default to manual control self.heartbeats = 0 self.zmq = zmq.Context().socket(zmq.REQ) - self.zmq.connect(f'tcp://{args.server}:{args.zmqport}') + self.zmq.connect(f"tcp://{args.server}:{args.zmqport}") self.tlogfile = None if args.trajectory: self.tlogfile = open("trajectory.log", "w") @@ -90,59 +91,63 @@ def __init__(self, args): async def initializeConnection(self): await self.drone.connect() await self.drone.startStreaming() - self.cloudlet.startStreaming(self.drone, 'coco', 30) + self.cloudlet.startStreaming(self.drone, "coco", 30) async def executeFlightScript(self, url: str): - logger.debug('Starting flight plan download...') + logger.debug("Starting flight plan download...") try: self.download(url) - except Exception as e: - logger.debug('Flight script download failed! Aborting.') + except Exception: + logger.debug("Flight script download failed! Aborting.") return - logger.debug('Flight script downloaded...') + logger.debug("Flight script downloaded...") self.start_mission() def start_mission(self): - logger.debug('Start mission supervisor') + logger.debug("Start mission supervisor") logger.debug(self) # Stop existing mission (if there is one) self.stop_mission() # Start new task - logger.debug('MS import') + logger.debug("MS import") module_prefix = self.drone_id if not self.reload: - logger.info('first time...') + logger.info("first time...") importlib.import_module(f"{module_prefix}.mission") importlib.import_module(f"{module_prefix}.task_defs") importlib.import_module(f"{module_prefix}.transition_defs") else: - logger.info('Reloading...') + logger.info("Reloading...") modules = sys.modules.copy() for module in modules.values(): - if module.__name__.startswith(f'{module_prefix}.mission') or module.__name__.startswith(f'{module_prefix}.task_defs') or module.__name__.startswith('{module_prefix}.transition_defs'): + if ( + module.__name__.startswith(f"{module_prefix}.mission") + or module.__name__.startswith(f"{module_prefix}.task_defs") + or module.__name__.startswith("{module_prefix}.transition_defs") + ): importlib.reload(module) - logger.debug('MC init') - #from mission.MissionController import MissionController + logger.debug("MC init") + # from mission.MissionController import MissionController Mission = importlib.import_module(f"{module_prefix}.mission.MissionController") - self.mission = getattr(Mission, "MissionController")(self.drone, self.cloudlet) - logger.debug('Running flight script!') + self.mission = Mission.MissionController(self.drone, self.cloudlet) + logger.debug("Running flight script!") self.missionTask = asyncio.create_task(self.mission.run()) self.reload = True def stop_mission(self): if self.mission and not self.missionTask.cancelled(): - logger.info('Mission script stop signalled') + logger.info("Mission script stop signalled") self.missionTask.cancel() self.mission = None self.missionTask = None def download(self, url: str): - #download zipfile and extract reqs/flight script from cloudlet + # download zipfile and extract reqs/flight script from cloudlet try: - filename = url.rsplit(sep='/')[-1] - logger.info(f'Writing {filename} to disk...') + filename = url.rsplit(sep="/")[-1] + logger.info(f"Writing {filename} to disk...") r = requests.get(url, stream=True) - with open(filename, mode='wb') as f: + with open(filename, mode="wb") as f: for chunk in r.iter_content(): f.write(chunk) os.makedirs(self.drone_id, exist_ok=True) @@ -150,12 +155,14 @@ def download(self, url: str): sys.path.append(self.drone_id) os.chdir(self.drone_id) try: - subprocess.check_call(['rm', '-rf', './task_defs', './mission', './transition_defs']) + subprocess.check_call( + ["rm", "-rf", "./task_defs", "./mission", "./transition_defs"] + ) except subprocess.CalledProcessError as e: logger.debug(f"Error removing old task/transition defs: {e}") z.extractall() self.install_prereqs() - os.chdir('..') + os.chdir("..") except Exception as e: print(e) @@ -163,13 +170,12 @@ def install_prereqs(self) -> bool: ret = False # Pip install prerequsites for flight script try: - subprocess.check_call(['python3', '-m', 'pip', 'install', '-r', './requirements.txt']) + subprocess.check_call(["python3", "-m", "pip", "install", "-r", "./requirements.txt"]) ret = True except subprocess.CalledProcessError as e: logger.debug(f"Error pip installing requirements.txt: {e}") return ret - async def commandHandler(self): name = await self.drone.getName() @@ -181,48 +187,54 @@ async def commandHandler(self): try: self.zmq.send(req.SerializeToString()) rep = self.zmq.recv() - if b'No commands.' != rep: - extras = cnc_pb2.Extras() + if rep != b"No commands.": + extras = cnc_pb2.Extras() extras.ParseFromString(rep) if extras.cmd.rth: - logger.info('RTH signaled from commander') + logger.info("RTH signaled from commander") self.stop_mission() self.manual = False asyncio.create_task(self.drone.rth()) elif extras.cmd.halt: - logger.info('Killswitch signaled from commander') + logger.info("Killswitch signaled from commander") self.stop_mission() self.manual = True - logger.info('Manual control is now active!') + logger.info("Manual control is now active!") # Try cancelling the RTH task if it exists sync(self.drone.hover()) elif extras.cmd.script_url: # Validate url if validators.url(extras.cmd.script_url): - logger.info(f'Flight script sent by commander: {extras.cmd.script_url}') + logger.info(f"Flight script sent by commander: {extras.cmd.script_url}") self.manual = False asyncio.create_task(self.executeFlightScript(extras.cmd.script_url)) else: - logger.info(f'Invalid script URL sent by commander: {extras.cmd.script_url}') + logger.info( + f"Invalid script URL sent by commander: {extras.cmd.script_url}" + ) elif self.manual: if extras.cmd.takeoff: - logger.info(f'Received manual takeoff') + logger.info("Received manual takeoff") asyncio.create_task(self.drone.takeOff()) elif extras.cmd.land: - logger.info(f'Received manual land') + logger.info("Received manual land") asyncio.create_task(self.drone.land()) else: - logger.info(f'Received manual PCMD') + logger.info("Received manual PCMD") pitch = extras.cmd.pcmd.pitch yaw = extras.cmd.pcmd.yaw roll = extras.cmd.pcmd.roll gaz = extras.cmd.pcmd.gaz gimbal_pitch = extras.cmd.pcmd.gimbal_pitch - logger.debug(f'Got PCMD values: {pitch} {yaw} {roll} {gaz} {gimbal_pitch}') + logger.debug( + f"Got PCMD values: {pitch} {yaw} {roll} {gaz} {gimbal_pitch}" + ) asyncio.create_task(self.drone.PCMD(roll, pitch, yaw, gaz)) current = await self.drone.getGimbalPitch() - asyncio.create_task(self.drone.setGimbalPose(0, current+gimbal_pitch , 0)) - if self.tlogfile: # Log trajectory IMU data + asyncio.create_task( + self.drone.setGimbalPose(0, current + gimbal_pitch, 0) + ) + if self.tlogfile: # Log trajectory IMU data speeds = await self.drone.getSpeedRel() fspeed = speeds["speedX"] hspeed = speeds["speedY"] @@ -230,19 +242,19 @@ async def commandHandler(self): except Exception as e: logger.debug(e) - - ''' + """ Process results from engines. Forward openscout engine results to Cloudlet object Parse and deal with results from command engine - ''' + """ + def processResults(self, result_wrapper): - if self.cloudlet and result_wrapper.result_producer_name.value != 'telemetry': - #forward result to cloudlet + if self.cloudlet and result_wrapper.result_producer_name.value != "telemetry": + # forward result to cloudlet self.cloudlet.processResults(result_wrapper) return else: - #process result from command engine + # process result from command engine pass def get_producer_wrappers(self): @@ -251,7 +263,7 @@ async def producer(): self.heartbeats += 1 input_frame = gabriel_pb2.InputFrame() input_frame.payload_type = gabriel_pb2.PayloadType.TEXT - input_frame.payloads.append('heartbeart'.encode('utf8')) + input_frame.payloads.append(b"heartbeart") extras = cnc_pb2.Extras() try: @@ -259,20 +271,24 @@ async def producer(): extras.location.latitude = sync(self.drone.getLat()) extras.location.longitude = sync(self.drone.getLng()) extras.location.altitude = sync(self.drone.getRelAlt()) - logger.debug(f'Latitude: {extras.location.latitude} Longitude: {extras.location.longitude} Altitude: {extras.location.altitude}') + logger.debug( + f"Latitude: {extras.location.latitude} Longitude: {extras.location.longitude} Altitude: {extras.location.altitude}" + ) extras.status.battery = sync(self.drone.getBatteryPercentage()) extras.status.rssi = sync(self.drone.getRSSI()) extras.status.mag = sync(self.drone.getMagnetometerReading()) extras.status.bearing = sync(self.drone.getHeading()) - logger.debug(f'Battery: {extras.status.battery} RSSI: {extras.status.rssi} Magnetometer: {extras.status.mag} Heading: {extras.status.bearing}') + logger.debug( + f"Battery: {extras.status.battery} RSSI: {extras.status.rssi} Magnetometer: {extras.status.mag} Heading: {extras.status.bearing}" + ) except Exception as e: - logger.debug(f'Error getting telemetry: {e}') + logger.debug(f"Error getting telemetry: {e}") # Register on the first frame if self.heartbeats == 1: extras.registering = True - logger.debug('Producing Gabriel frame!') + logger.debug("Producing Gabriel frame!") input_frame.extras.Pack(extras) return input_frame @@ -280,34 +296,62 @@ async def producer(): async def _main(): - parser = argparse.ArgumentParser(prog='supervisor', - description='Bridges python API drones to SteelEagle.') - parser.add_argument('-d', '--drone', default='ParrotAnafiDrone', - help='Set the type of drone to interface with [default: ParrotAnafiDrone]') - parser.add_argument('-c', '--cloudlet', default='PureOffloadCloudlet', - help='Set the type of offload method to the cloudlet [default: PureOffloadCloudlet]') - parser.add_argument('-s', '--server', default='gabriel-server', - help='Specify address of Steel Eagle CNC server [default: gabriel-server]') - parser.add_argument('-p', '--port', default='9099', - help='Specify websocket port [default: 9099]') - parser.add_argument('-l', '--loglevel', default='INFO', - help='Set the log level') - parser.add_argument('-S', '--sim', action='store_true', - help='Connect to simulated drone instead of a real drone [default: False]') - parser.add_argument('-L', '--lowdelay', action='store_true', - help='Use low delay settings for video streaming [default: False]') - parser.add_argument('-zp', '--zmqport', type=int, default=6000, - help='Specify websocket port [default: 6000]') - parser.add_argument('-t', '--trajectory', action='store_true', - help='Log the trajectory of the drone over the flight duration [default: False]') - parser.add_argument('-i', '--droneip', default='192.168.42.1', - help='Specify drone IP address [default: 192.168.42.1]') + parser = argparse.ArgumentParser( + prog="supervisor", description="Bridges python API drones to SteelEagle." + ) + parser.add_argument( + "-d", + "--drone", + default="ParrotAnafiDrone", + help="Set the type of drone to interface with [default: ParrotAnafiDrone]", + ) + parser.add_argument( + "-c", + "--cloudlet", + default="PureOffloadCloudlet", + help="Set the type of offload method to the cloudlet [default: PureOffloadCloudlet]", + ) + parser.add_argument( + "-s", + "--server", + default="gabriel-server", + help="Specify address of Steel Eagle CNC server [default: gabriel-server]", + ) + parser.add_argument( + "-p", "--port", default="9099", help="Specify websocket port [default: 9099]" + ) + parser.add_argument("-l", "--loglevel", default="INFO", help="Set the log level") + parser.add_argument( + "-S", + "--sim", + action="store_true", + help="Connect to simulated drone instead of a real drone [default: False]", + ) + parser.add_argument( + "-L", + "--lowdelay", + action="store_true", + help="Use low delay settings for video streaming [default: False]", + ) + parser.add_argument( + "-zp", "--zmqport", type=int, default=6000, help="Specify websocket port [default: 6000]" + ) + parser.add_argument( + "-t", + "--trajectory", + action="store_true", + help="Log the trajectory of the drone over the flight duration [default: False]", + ) + parser.add_argument( + "-i", + "--droneip", + default="192.168.42.1", + help="Specify drone IP address [default: 192.168.42.1]", + ) - parser.add_argument('-n', '--dronename', - help='Specify drone name.') + parser.add_argument("-n", "--dronename", help="Specify drone name.") args = parser.parse_args() - logging.basicConfig(format="%(levelname)s: %(message)s", - level=args.loglevel) + logging.basicConfig(format="%(levelname)s: %(message)s", level=args.loglevel) logger.info(f"{args=}") adapter = Supervisor(args) @@ -316,8 +360,10 @@ async def _main(): logger.debug("Launching Gabriel") gabriel_client = WebsocketClient( - args.server, args.port, - [adapter.get_producer_wrappers(), adapter.cloudlet.sendFrame()], adapter.processResults + args.server, + args.port, + [adapter.get_producer_wrappers(), adapter.cloudlet.sendFrame()], + adapter.processResults, ) try: gabriel_client.launch() diff --git a/os/drivers/ModalAI/Seeker/Seeker.py b/os/drivers/ModalAI/Seeker/Seeker.py index 7ff9cbb9..f99f4a85 100644 --- a/os/drivers/ModalAI/Seeker/Seeker.py +++ b/os/drivers/ModalAI/Seeker/Seeker.py @@ -1,26 +1,26 @@ -from enum import Enum +import asyncio +import logging import math import os +import threading import time -import asyncio -import logging +from enum import Enum + +import cv2 from pymavlink import mavutil logger = logging.getLogger(__name__) -class ConnectionFailedException(Exception): - pass -class ModalAISeekerDrone(): - +class ModalAISeekerDrone: class FlightMode(Enum): - LAND = 'LAND' - RTL = 'RTL' - LOITER = 'LOITER' - TAKEOFF = 'TAKEOFF' - ALT_HOLD = 'ALT_HOLD' - OFFBOARD = 'OFFBOARD' - + LAND = "LAND" + RTL = "RTL" + LOITER = "LOITER" + TAKEOFF = "TAKEOFF" + ALT_HOLD = "ALT_HOLD" + OFFBOARD = "OFFBOARD" + def __init__(self): self.vehicle = None self.mode = None @@ -28,7 +28,8 @@ def __init__(self): self.listener_task = None self.gps_disabled = False - ''' Connect methods ''' + """ Connect methods """ + async def connect(self, connection_string): # connect to drone logger.info(f"Connecting to drone at {connection_string}...") @@ -37,11 +38,11 @@ async def connect(self, connection_string): logger.info("-- Connected to drone!") self.mode_mapping = self.vehicle.mode_mapping() logger.debug(f"Mode mapping: {self.mode_mapping}") - + # register telemetry streams await self.register_telemetry_streams() asyncio.create_task(self._message_listener()) - + async def register_telemetry_streams(self, frequency_hz: float = 10.0): # Define the telemetry message names telemetry_message_names = [ @@ -61,7 +62,9 @@ async def register_telemetry_streams(self, frequency_hz: float = 10.0): try: message_id = getattr(mavutil.mavlink, f"MAVLINK_MSG_ID_{message_name}", None) if message_id is None: - logger.warning(f"Message name {message_name} is not found in MAVLink definitions.") + logger.warning( + f"Message name {message_name} is not found in MAVLink definitions." + ) continue # Request the message interval @@ -73,14 +76,20 @@ async def register_telemetry_streams(self, frequency_hz: float = 10.0): def request_message_interval(self, message_id: int, frequency_hz: float): self.vehicle.mav.command_long_send( - self.vehicle.target_system, self.vehicle.target_component, - mavutil.mavlink.MAV_CMD_SET_MESSAGE_INTERVAL, 0, - message_id, # The MAVLink message ID - 1e6 / frequency_hz, # The interval between two messages in microseconds. Set to -1 to disable and 0 to request default rate. - 0, 0, 0, 0, # Unused parameters - 0, # Target address of message stream (if message has target address fields). 0: Flight-stack default (recommended), 1: address of requestor, 2: broadcast. + self.vehicle.target_system, + self.vehicle.target_component, + mavutil.mavlink.MAV_CMD_SET_MESSAGE_INTERVAL, + 0, + message_id, # The MAVLink message ID + 1e6 + / frequency_hz, # The interval between two messages in microseconds. Set to -1 to disable and 0 to request default rate. + 0, + 0, + 0, + 0, # Unused parameters + 0, # Target address of message stream (if message has target address fields). 0: Flight-stack default (recommended), 1: address of requestor, 2: broadcast. ) - + async def isConnected(self): return self.vehicle is not None @@ -93,19 +102,20 @@ async def disconnect(self): self.vehicle.close() logger.info("-- Disconnected from drone") - ''' Telemetry methods ''' + """ Telemetry methods """ + async def getTelemetry(self): try: tel_dict = {} - tel_dict['name'] = self.getName() - tel_dict['gps'] = self.getGPS() - tel_dict['relAlt'] = self.getAltitudeRel() - tel_dict['attitude'] = self.getAttitude() - tel_dict['magnetometer'] = self.getMagnetometerReading() - tel_dict['imu'] = self.getVelocityNEU() - tel_dict['battery'] = self.getBatteryPercentage() - tel_dict['satellites'] = self.getSatellites() - tel_dict['heading'] = self.getHeading() + tel_dict["name"] = self.getName() + tel_dict["gps"] = self.getGPS() + tel_dict["relAlt"] = self.getAltitudeRel() + tel_dict["attitude"] = self.getAttitude() + tel_dict["magnetometer"] = self.getMagnetometerReading() + tel_dict["imu"] = self.getVelocityNEU() + tel_dict["battery"] = self.getBatteryPercentage() + tel_dict["satellites"] = self.getSatellites() + tel_dict["heading"] = self.getHeading() logger.debug(f"Telemetry data: {tel_dict}") return tel_dict @@ -114,7 +124,7 @@ async def getTelemetry(self): return {} def getName(self): - drone_id = os.environ.get('DRONE_ID') + drone_id = os.environ.get("DRONE_ID") return drone_id def getGPS(self): @@ -124,7 +134,7 @@ def getGPS(self): return { "latitude": gps_msg.lat / 1e7, "longitude": gps_msg.lon / 1e7, - "altitude": gps_msg.alt / 1e3 + "altitude": gps_msg.alt / 1e3, } def getAltitudeRel(self): @@ -137,21 +147,13 @@ def getAttitude(self): attitude_msg = self._get_cached_message("ATTITUDE") if not attitude_msg: return None - return { - "roll": attitude_msg.roll, - "pitch": attitude_msg.pitch, - "yaw": attitude_msg.yaw - } + return {"roll": attitude_msg.roll, "pitch": attitude_msg.pitch, "yaw": attitude_msg.yaw} def getMagnetometerReading(self): imu_msg = self._get_cached_message("RAW_IMU") if not imu_msg: return {"x": None, "y": None, "z": None} - return { - "x": imu_msg.xmag, - "y": imu_msg.ymag, - "z": imu_msg.zmag - } + return {"x": imu_msg.xmag, "y": imu_msg.ymag, "z": imu_msg.zmag} def getBatteryPercentage(self): battery_msg = self._get_cached_message("BATTERY_STATUS") @@ -175,12 +177,8 @@ def getVelocityNEU(self): gps_msg = self._get_cached_message("GLOBAL_POSITION_INT") if not gps_msg: return None - return { - "forward": gps_msg.vx / 100, - "right": gps_msg.vy / 100, - "up": gps_msg.vz / 100 - } - + return {"forward": gps_msg.vx / 100, "right": gps_msg.vy / 100, "up": gps_msg.vz / 100} + def getVelocityBody(self): velocity_msg = self._get_cached_message("LOCAL_POSITION_NED") if not velocity_msg: @@ -188,7 +186,7 @@ def getVelocityBody(self): return { "vx": velocity_msg.vx, # Body-frame X velocity in m/s "vy": velocity_msg.vy, # Body-frame Y velocity in m/s - "vz": velocity_msg.vz # Body-frame Z velocity in m/s + "vz": velocity_msg.vz, # Body-frame Z velocity in m/s } def getRSSI(self): @@ -197,36 +195,43 @@ def getRSSI(self): return None return rssi_msg.rssi - ''' Actuation methods ''' + """ Actuation methods """ + async def hover(self): logger.info("-- Hovering") await self.setVelocity(0.0, 0.0, 0.0, 0.0) async def takeOff(self, target_altitude): logger.info("-- Taking off") - + await self.arm() await self.switchMode(ModalAISeekerDrone.FlightMode.TAKEOFF) - + # Take off at the current GPS location gps = self.getGPS() self.vehicle.mav.command_long_send( - self.vehicle.target_system, # target system - self.vehicle.target_component, # target component - mavutil.mavlink.MAV_CMD_NAV_TAKEOFF, # command - 0, # confirmation - 0, 0, 0, 0, gps['latitude'], gps['longitude'], gps['altitude'] + 2.5) # param 1 ~ 7 (param 7 is the target altitude) - + self.vehicle.target_system, # target system + self.vehicle.target_component, # target component + mavutil.mavlink.MAV_CMD_NAV_TAKEOFF, # command + 0, # confirmation + 0, + 0, + 0, + 0, + gps["latitude"], + gps["longitude"], + gps["altitude"] + 2.5, + ) # param 1 ~ 7 (param 7 is the target altitude) + result = await self._wait_for_condition( - lambda: self.is_mode_set(ModalAISeekerDrone.FlightMode.LOITER), - interval=1 + lambda: self.is_mode_set(ModalAISeekerDrone.FlightMode.LOITER), interval=1 ) if result: logger.info("-- Takeoff success") - else: + else: logger.error("-- Takeoff failed") - + return result async def land(self): @@ -234,19 +239,25 @@ async def land(self): await self.switchMode(ModalAISeekerDrone.FlightMode.LAND) self.vehicle.mav.command_long_send( - self.vehicle.target_system, self.vehicle.target_component, + self.vehicle.target_system, + self.vehicle.target_component, mavutil.mavlink.MAV_CMD_NAV_LAND, - 0, 0, 0, 0, 0, 0, 0, 0) - - result = await self._wait_for_condition( - lambda: self.is_disarmed(), - interval=1 + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, ) + + result = await self._wait_for_condition(lambda: self.is_disarmed(), interval=1) if result: logger.info("-- Landed and disarmed") - else: + else: logger.error("-- Landing failed") - + return result async def setHome(self, lat, lng, alt): @@ -256,61 +267,61 @@ async def setHome(self, lat, lng, alt): self.vehicle.target_component, mavutil.mavlink.MAV_CMD_DO_SET_HOME, 1, - 0, 0, 0, 0, - lat, lng, alt + 0, + 0, + 0, + 0, + lat, + lng, + alt, ) - result = await self._wait_for_condition( - lambda: self.is_home_set(), - timeout=5, - interval=0.1 - ) - + result = await self._wait_for_condition(lambda: self.is_home_set(), timeout=5, interval=0.1) + if result: logger.info("-- Home location set successfully") else: logger.error("-- Failed to set home location") - + return result - + async def rth(self): logger.info("-- Returning to launch") - if await self.switchMode(ModalAISeekerDrone.FlightMode.RTL) == False: + if await self.switchMode(ModalAISeekerDrone.FlightMode.RTL) is False: logger.error("Failed to set mode to RTL") return - result = await self._wait_for_condition( - lambda: self.is_disarmed(), - interval=1 - ) - + result = await self._wait_for_condition(lambda: self.is_disarmed(), interval=1) + if result: logger.info("-- Returned to launch and disarmed") - else: + else: logger.error("-- RTL failed") - + async def manual_control(self, forward_vel, right_vel, up_vel, angle_vel): if self.gps_disabled: - if await self.switchMode(ModalAISeekerDrone.FlightMode.OFFBOARD) == False: + if await self.switchMode(ModalAISeekerDrone.FlightMode.OFFBOARD) is False: logger.error("Failed to set mode to GUIDED_NOGPS") return # if await self.switchMode(ModalAISeekerDrone.FlightMode.ALT_HOLD) == False: # logger.error("Failed to set mode to GUIDED_NOGPS") # return else: - if await self.switchMode(ModalAISeekerDrone.FlightMode.OFFBOARD) == False: + if await self.switchMode(ModalAISeekerDrone.FlightMode.OFFBOARD) is False: logger.error("Failed to set mode to GUIDED") return - logger.info(f"Sending manual control: forward={forward_vel}, right={right_vel}, up={up_vel}, yaw={angle_vel}") + logger.info( + f"Sending manual control: forward={forward_vel}, right={right_vel}, up={up_vel}, yaw={angle_vel}" + ) # Ensure values are within MAVLink range (-1000 to 1000) def clamp(value, min_val, max_val): return max(min_val, min(max_val, int(value))) # Ensure integer conversion x = clamp(forward_vel * 1000, -1000, 1000) # Forward/backward movement - y = clamp(right_vel * 1000, -1000, 1000) # Left/right movement - z = clamp(up_vel * 1000, 0, 1000) # Throttle (0=lowest, 1000=full thrust) - r = clamp(angle_vel * 1000, -1000, 1000) # Yaw rotation + y = clamp(right_vel * 1000, -1000, 1000) # Left/right movement + z = clamp(up_vel * 1000, 0, 1000) # Throttle (0=lowest, 1000=full thrust) + r = clamp(angle_vel * 1000, -1000, 1000) # Yaw rotation buttons = 0 # No buttons pressed buttons2 = 0 # No additional buttons @@ -320,32 +331,41 @@ def clamp(value, min_val, max_val): try: self.vehicle.mav.manual_control_send( self.vehicle.target_system, - x, y, z, r, + x, + y, + z, + r, buttons, buttons2, enabled_extensions, - s, t, aux1, aux2, aux3, aux4, aux5, aux6 + s, + t, + aux1, + aux2, + aux3, + aux4, + aux5, + aux6, ) logger.info("Manual control command sent successfully!") except Exception as e: logger.error(f"Failed to send manual control: {e}") - - + async def setAttitude(self, pitch, roll, thrust, yaw): logger.info(f"-- Setting attitude: pitch={pitch}, roll={roll}, thrust={thrust}, yaw={yaw}") if self.gps_disabled: - if await self.switchMode(ModalAISeekerDrone.FlightMode.OFFBOARD) == False: + if await self.switchMode(ModalAISeekerDrone.FlightMode.OFFBOARD) is False: logger.error("Failed to set mode to GUIDED_NOGPS") return # if await self.switchMode(ModalAISeekerDrone.FlightMode.ALT_HOLD) == False: # logger.error("Failed to set mode to GUIDED_NOGPS") # return else: - if await self.switchMode(ModalAISeekerDrone.FlightMode.OFFBOARD) == False: + if await self.switchMode(ModalAISeekerDrone.FlightMode.OFFBOARD) is False: logger.error("Failed to set mode to GUIDED") return - + # Convert Euler angles to quaternion (w, x, y, z) def to_quaternion(roll=0.0, pitch=0.0, yaw=0.0): roll, pitch, yaw = map(math.radians, [roll, pitch, yaw]) @@ -353,54 +373,67 @@ def to_quaternion(roll=0.0, pitch=0.0, yaw=0.0): cp, sp = math.cos(pitch * 0.5), math.sin(pitch * 0.5) cr, sr = math.cos(roll * 0.5), math.sin(roll * 0.5) - return [cr * cp * cy + sr * sp * sy, # w - sr * cp * cy - cr * sp * sy, # x - cr * sp * cy + sr * cp * sy, # y - cr * cp * sy - sr * sp * cy] # z + return [ + cr * cp * cy + sr * sp * sy, # w + sr * cp * cy - cr * sp * sy, # x + cr * sp * cy + sr * cp * sy, # y + cr * cp * sy - sr * sp * cy, + ] # z q = to_quaternion(roll, pitch, yaw) base_thrust = 0.6 - + self.vehicle.mav.set_attitude_target_send( 0, # time_boot_ms self.vehicle.target_system, self.vehicle.target_component, 0b00000000, # type_mask q, # Quaternion - 0, 0, 0, # Body angular rates - base_thrust + thrust # Throttle + 0, + 0, + 0, # Body angular rates + base_thrust + thrust, # Throttle ) logger.info("-- setAttitude sent successfully") # continuous control: no blocking wait async def setVelocity(self, forward_vel, right_vel, up_vel, angle_vel): - logger.info(f"-- Setting velocity: forward_vel={forward_vel}, right_vel={right_vel}, up_vel={up_vel}, angle_vel={angle_vel}") - - if await self.switchMode(ModalAISeekerDrone.FlightMode.OFFBOARD) == False: + logger.info( + f"-- Setting velocity: forward_vel={forward_vel}, right_vel={right_vel}, up_vel={up_vel}, angle_vel={angle_vel}" + ) + + if await self.switchMode(ModalAISeekerDrone.FlightMode.OFFBOARD) is False: logger.error("Failed to set mode to GUIDED") return - + self.vehicle.mav.set_position_target_local_ned_send( 0, # time_boot_ms self.vehicle.target_system, self.vehicle.target_component, mavutil.mavlink.MAV_FRAME_BODY_NED, # frame 0b010111000111, # type_mask - 0, 0, 0, # x, y, z positions - forward_vel, right_vel, -up_vel, # x, y, z velocity - 0, 0, 0, # x, y, z acceleration - 0, angle_vel # yaw, yaw_rate + 0, + 0, + 0, # x, y, z positions + forward_vel, + right_vel, + -up_vel, # x, y, z velocity + 0, + 0, + 0, # x, y, z acceleration + 0, + angle_vel, # yaw, yaw_rate ) logger.info("-- setVelocity sent successfully") async def setGPSLocation(self, lat, lon, alt, bearing): logger.info(f"-- Setting GPS location: lat={lat}, lon={lon}, alt={alt}, bearing={bearing}") - - if await self.switchMode(ModalAISeekerDrone.FlightMode.OFFBOARD) == False: + + if await self.switchMode(ModalAISeekerDrone.FlightMode.OFFBOARD) is False: logger.error("Failed to set mode to GUIDED") return - + self.vehicle.mav.set_position_target_global_int_send( 0, self.vehicle.target_system, @@ -410,48 +443,58 @@ async def setGPSLocation(self, lat, lon, alt, bearing): int(lat * 1e7), int(lon * 1e7), alt, - 0, 0, 0, - 0, 0, 0, - 0, 0 + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, ) - # Calculate bearing if not provided + # Calculate bearing if not provided current_location = self.getGPS() current_lat = current_location["latitude"] current_lon = current_location["longitude"] if bearing is None: bearing = self.calculate_bearing(current_lat, current_lon, lat, lon) logger.info(f"-- Calculated bearing: {bearing}") - + await self.setBearing(bearing) - - result = await self._wait_for_condition( - lambda: self.is_at_target(lat, lon), - interval=1 - ) - - if result: + + result = await self._wait_for_condition(lambda: self.is_at_target(lat, lon), interval=1) + + if result: logger.info("-- Reached target GPS location") - else: + else: logger.info("-- Failed to reach target GPS location") - + return result async def setTranslatedLocation(self, forward, right, up, angle): - logger.info(f"-- Translating location: forward={forward}, right={right}, up={up}, angle={angle}") - - if await self.switchMode(ModalAISeekerDrone.FlightMode.OFFBOARD) == False: + logger.info( + f"-- Translating location: forward={forward}, right={right}, up={up}, angle={angle}" + ) + + if await self.switchMode(ModalAISeekerDrone.FlightMode.OFFBOARD) is False: logger.error("Failed to set mode to GUIDED") return - - current_location = self.getGPS() - current_heading = self.getHeading() - dx = forward * math.cos(math.radians(current_heading)) - right * math.sin(math.radians(current_heading)) - dy = forward * math.sin(math.radians(current_heading)) + right * math.cos(math.radians(current_heading)) + current_location = self.getGPS() + current_heading = self.getHeading() + + dx = forward * math.cos(math.radians(current_heading)) - right * math.sin( + math.radians(current_heading) + ) + dy = forward * math.sin(math.radians(current_heading)) + right * math.cos( + math.radians(current_heading) + ) dz = -up target_lat = current_location["latitude"] + (dx / 111320) - target_lon = current_location["longitude"] + (dy / (111320 * math.cos(math.radians(current_location["latitude"])))) + target_lon = current_location["longitude"] + ( + dy / (111320 * math.cos(math.radians(current_location["latitude"]))) + ) target_alt = current_location["altitude"] + dz self.vehicle.mav.set_position_target_global_int_send( @@ -463,34 +506,39 @@ async def setTranslatedLocation(self, forward, right, up, angle): int(target_lat * 1e7), int(target_lon * 1e7), target_alt, - 0, 0, 0, - 0, 0, 0, - 0, 0 + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, ) - - if angle is not None: await self.setBearing(angle) - + + if angle is not None: + await self.setBearing(angle) + result = await self._wait_for_condition( - lambda: self.is_at_target(target_lat, target_lon), - interval=1 + lambda: self.is_at_target(target_lat, target_lon), interval=1 ) - - if result: + + if result: logger.info("-- Reached target translated location") else: logger.error("-- Failed to reach target translated location") - + return result async def setBearing(self, bearing): logger.info(f"-- Setting yaw to {bearing} degrees") - - if await self.switchMode(ModalAISeekerDrone.FlightMode.OFFBOARD) == False: + + if await self.switchMode(ModalAISeekerDrone.FlightMode.OFFBOARD) is False: logger.error("Failed to set mode to GUIDED") return - - yaw_speed = 25 # deg/s - direction = 0 # 1: clockwise, -1: counter-clockwise 0: most quickly direction + + yaw_speed = 25 # deg/s + direction = 0 # 1: clockwise, -1: counter-clockwise 0: most quickly direction self.vehicle.mav.command_long_send( self.vehicle.target_system, self.vehicle.target_component, @@ -500,21 +548,22 @@ async def setBearing(self, bearing): yaw_speed, direction, 0, - 0, 0, 0 + 0, + 0, + 0, ) - - result = await self._wait_for_condition( - lambda: self.is_bearing_reached(bearing), - interval=0.5 + + result = await self._wait_for_condition( + lambda: self.is_bearing_reached(bearing), interval=0.5 ) - + result = True - + if result: logger.info(f"-- Yaw successfully set to {bearing} degrees") else: logger.error(f"-- Failed to set yaw to {bearing} degrees") - + return result async def arm(self): @@ -525,17 +574,17 @@ async def arm(self): mavutil.mavlink.MAV_CMD_COMPONENT_ARM_DISARM, 0, 1, - 0, 0, 0, 0, 0, 0 + 0, + 0, + 0, + 0, + 0, + 0, ) logger.info("-- Arm command sent") + result = await self._wait_for_condition(lambda: self.is_armed(), timeout=5, interval=1) - result = await self._wait_for_condition( - lambda: self.is_armed(), - timeout=5, - interval=1 - ) - if result: logger.info("-- Armed successfully") else: @@ -550,62 +599,72 @@ async def disarm(self): mavutil.mavlink.MAV_CMD_COMPONENT_ARM_DISARM, 0, 0, - 0, 0, 0, 0, 0, 0 + 0, + 0, + 0, + 0, + 0, + 0, ) logger.info("-- Disarm command sent") - result = await self._wait_for_condition( - lambda: self.is_disarmed(), - timeout=5, - interval=1 - ) - + result = await self._wait_for_condition(lambda: self.is_disarmed(), timeout=5, interval=1) + if result: self.mode = None logger.info("-- Disarmed successfully") else: logger.error("-- Disarm failed") - - return result + + return result async def switchMode(self, mode): logger.info(f"Switching mode to {mode}") mode_target = mode.value curr_mode = self.mode.value if self.mode else None - logger.info(f"mode map: {self.mode_mapping}, target mode: {mode_target}, current mode: {curr_mode}") - + logger.info( + f"mode map: {self.mode_mapping}, target mode: {mode_target}, current mode: {curr_mode}" + ) + if self.mode == mode: logger.info(f"Already in mode {mode_target}") return True - + # switch mode if mode_target not in self.mode_mapping: logger.info(f"Mode {mode_target} not supported!") return False - + mode_id = self.mode_mapping[mode_target] logger.info(f"Mode ID Triplet: {mode_id}") self.vehicle.mav.command_long_send( - self.vehicle.target_system, self.vehicle.target_component, - mavutil.mavlink.MAV_CMD_DO_SET_MODE, 0, - mode_id[0], mode_id[1], mode_id[2], 0, 0, 0, 0) - + self.vehicle.target_system, + self.vehicle.target_component, + mavutil.mavlink.MAV_CMD_DO_SET_MODE, + 0, + mode_id[0], + mode_id[1], + mode_id[2], + 0, + 0, + 0, + 0, + ) + if mode is not ModalAISeekerDrone.FlightMode.OFFBOARD: result = await self._wait_for_condition( - lambda: self.is_mode_set(mode), - timeout=5, - interval=1 + lambda: self.is_mode_set(mode), timeout=5, interval=1 ) - + if result: self.mode = mode logger.info(f"Mode switched to {mode_target}") return result else: - logger.info(f"Priming for OFFBOARD mode") + logger.info("Priming for OFFBOARD mode") return True - + async def disableGPS(self): # logger.info("-- Disabling GPS") @@ -617,41 +676,52 @@ async def disableGPS(self): # 3, # 3 = No GPS (indoor mode) # mavutil.mavlink.MAV_PARAM_TYPE_INT32 # ) - + # result = await self._wait_for_condition( # lambda: self.is_GPS_disabled() # ) - + # if result: # logger.info("-- GPS disabled") # else: - # logger.error("-- Failed to disable GPS") - # Wait and print received parameter messages + # logger.error("-- Failed to disable GPS") + # Wait and print received parameter messages # return result - + self.gps_disabled = True - - - - ''' ACK methods''' + + """ ACK methods""" + def is_armed(self): - return self.vehicle.recv_match(type="HEARTBEAT", blocking=True).base_mode & mavutil.mavlink.MAV_MODE_FLAG_SAFETY_ARMED - + return ( + self.vehicle.recv_match(type="HEARTBEAT", blocking=True).base_mode + & mavutil.mavlink.MAV_MODE_FLAG_SAFETY_ARMED + ) + def is_disarmed(self): - return not (self.vehicle.recv_match(type='HEARTBEAT', blocking=True).base_mode & mavutil.mavlink.MAV_MODE_FLAG_SAFETY_ARMED) + return not ( + self.vehicle.recv_match(type="HEARTBEAT", blocking=True).base_mode + & mavutil.mavlink.MAV_MODE_FLAG_SAFETY_ARMED + ) def is_mode_set(self, mode): - current_mode = mavutil.mode_string_v10(self.vehicle.recv_match(type='HEARTBEAT', blocking=True)) + current_mode = mavutil.mode_string_v10( + self.vehicle.recv_match(type="HEARTBEAT", blocking=True) + ) return current_mode == mode.value - + def is_home_set(self): - msg = self.vehicle.recv_match(type='COMMAND_ACK', blocking=True) - return msg and msg.command == mavutil.mavlink.MAV_CMD_DO_SET_HOME and msg.result == mavutil.mavlink.MAV_RESULT_ACCEPTED + msg = self.vehicle.recv_match(type="COMMAND_ACK", blocking=True) + return ( + msg + and msg.command == mavutil.mavlink.MAV_CMD_DO_SET_HOME + and msg.result == mavutil.mavlink.MAV_RESULT_ACCEPTED + ) def is_altitude_reached(self, target_altitude): current_altitude = self.getAltitudeRel() return current_altitude >= target_altitude * 0.95 - + def is_bearing_reached(self, bearing): logger.info(f"Checking if bearing is reached: {bearing}") attitude = self.getAttitude() @@ -661,7 +731,6 @@ def is_bearing_reached(self, bearing): current_yaw = (math.degrees(attitude["yaw"]) + 360) % 360 target_yaw = (bearing + 360) % 360 return abs(current_yaw - target_yaw) <= 2 - def is_at_target(self, lat, lon): current_location = self.getGPS() @@ -669,21 +738,22 @@ def is_at_target(self, lat, lon): return False dlat = lat - current_location["latitude"] dlon = lon - current_location["longitude"] - distance = math.sqrt((dlat ** 2) + (dlon ** 2)) * 1.113195e5 + distance = math.sqrt((dlat**2) + (dlon**2)) * 1.113195e5 return distance < 1.0 - - ''' Helper methods ''' + + """ Helper methods """ + # azimuth calculation for bearing def calculate_bearing(self, lat1, lon1, lat2, lon2): # Convert latitude and longitude from degrees to radians lat1, lon1, lat2, lon2 = map(math.radians, [lat1, lon1, lat2, lon2]) - + delta_lon = lon2 - lon1 # Bearing calculation x = math.sin(delta_lon) * math.cos(lat2) y = math.cos(lat1) * math.sin(lat2) - math.sin(lat1) * math.cos(lat2) * math.cos(delta_lon) - + initial_bearing = math.atan2(x, y) # Convert bearing from radians to degrees @@ -693,7 +763,7 @@ def calculate_bearing(self, lat1, lon1, lat2, lon2): converted_bearing = (initial_bearing + 360) % 360 return converted_bearing - + async def _message_listener(self): logger.info("-- Starting message listener") try: @@ -706,7 +776,7 @@ async def _message_listener(self): logger.info("-- Message listener stopped") except Exception as e: logger.error(f"-- Error in message listener: {e}") - + def _get_cached_message(self, message_type): try: logger.debug(f"Currently connection message types: {list(self.vehicle.messages)}") @@ -714,7 +784,7 @@ def _get_cached_message(self, message_type): except KeyError: logger.error(f"Message type {message_type} not found in cache") return None - + async def _wait_for_condition(self, condition_fn, timeout=None, interval=0.5): start_time = time.time() while True: @@ -729,60 +799,55 @@ async def _wait_for_condition(self, condition_fn, timeout=None, interval=0.5): return False await asyncio.sleep(interval) - ''' Stream methods ''' + """ Stream methods """ + async def getGimbalPose(self): pass - + async def startStreaming(self): - self.streamingThread = StreamingThread(self.vehicle) self.streamingThread.start() async def getVideoFrame(self): if self.streamingThread: - return [self.streamingThread.grabFrame().tobytes(), self.streamingThread.getFrameShape()] + return [ + self.streamingThread.grabFrame().tobytes(), + self.streamingThread.getFrameShape(), + ] async def stopStreaming(self): self.streamingThread.stop() - -import cv2 -import numpy as np -import os -import threading -class StreamingThread(threading.Thread): +class StreamingThread(threading.Thread): def __init__(self, drone): threading.Thread.__init__(self) self.currentFrame = None self.drone = drone - url_sim = os.environ.get('STREAM_SIM_URL') - url_mini = os.environ.get('STREAM_MINI_URL') - self.sim = os.environ.get('SIMULATION') - - if (self.sim == 'true'): - url = url_sim - else: - url = url_mini - + url_sim = os.environ.get("STREAM_SIM_URL") + url_mini = os.environ.get("STREAM_MINI_URL") + self.sim = os.environ.get("SIMULATION") + + url = url_sim if self.sim == "true" else url_mini + self.cap = cv2.VideoCapture(url) self.isRunning = True def run(self): try: - while(self.isRunning): + while self.isRunning: ret, self.currentFrame = self.cap.read() except Exception as e: logger.error(e) - + def getFrameShape(self): return self.currentFrame.shape - + def grabFrame(self): try: frame = self.currentFrame.copy() return frame - except Exception as e: + except Exception: # Send a blank frame return None diff --git a/os/drivers/Parrot/ParrotAnafi/ParrotAnafi.py b/os/drivers/Parrot/ParrotAnafi/ParrotAnafi.py index ce0e7b23..ee2c0e0a 100644 --- a/os/drivers/Parrot/ParrotAnafi/ParrotAnafi.py +++ b/os/drivers/Parrot/ParrotAnafi/ParrotAnafi.py @@ -2,68 +2,77 @@ # # SPDX-License-Identifier: GPL-2.0-only -import logging import asyncio -import threading +import logging import math import os +import queue +import threading import time +from enum import Enum +import cv2 +import logness +import numpy as np import olympe +import olympe.enums.move as move_mode from olympe import Drone -from olympe.messages.ardrone3.Piloting import TakeOff, Landing -from olympe.messages.ardrone3.Piloting import PCMD, moveTo, moveBy -from olympe.messages.rth import set_custom_location, return_to_home -from olympe.messages.ardrone3.PilotingState import moveToChanged -from olympe.messages.common.CommonState import BatteryStateChanged -from olympe.messages.ardrone3.PilotingSettingsState import MaxTiltChanged -from olympe.messages.ardrone3.SpeedSettingsState import MaxVerticalSpeedChanged, MaxRotationSpeedChanged -from olympe.messages.ardrone3.PilotingState import AttitudeChanged, GpsLocationChanged, AltitudeChanged, FlyingStateChanged, SpeedChanged from olympe.messages.ardrone3.GPSState import NumberOfSatelliteChanged -from olympe.messages.gimbal import set_target, attitude -from olympe.messages.wifi import rssi_changed -from olympe.messages.battery import capacity +from olympe.messages.ardrone3.Piloting import PCMD, Landing, TakeOff, moveBy, moveTo +from olympe.messages.ardrone3.PilotingSettingsState import MaxTiltChanged +from olympe.messages.ardrone3.PilotingState import ( + AltitudeChanged, + AttitudeChanged, + FlyingStateChanged, + GpsLocationChanged, + SpeedChanged, +) +from olympe.messages.ardrone3.SpeedSettingsState import ( + MaxRotationSpeedChanged, + MaxVerticalSpeedChanged, +) from olympe.messages.common.CalibrationState import MagnetoCalibrationRequiredState -import olympe.enums.move as move_mode -import olympe.enums.gimbal as gimbal_mode -from enum import Enum +from olympe.messages.common.CommonState import BatteryStateChanged +from olympe.messages.gimbal import attitude, set_target +from olympe.messages.rth import return_to_home, set_custom_location +from olympe.messages.wifi import rssi_changed logger = logging.getLogger(__name__) -import logness -logness.update_config({ - "handlers": { - "olympe_log_file": { - "class": "logness.FileHandler", - "formatter": "default_formatter", - "filename": "olympe.log" - }, - "ulog_log_file": { - "class": "logness.FileHandler", - "formatter": "default_formatter", - "filename": "ulog.log" +logness.update_config( + { + "handlers": { + "olympe_log_file": { + "class": "logness.FileHandler", + "formatter": "default_formatter", + "filename": "olympe.log", + }, + "ulog_log_file": { + "class": "logness.FileHandler", + "formatter": "default_formatter", + "filename": "ulog.log", + }, }, - }, - "loggers": { - "olympe": { - "level": "ERROR", - "handlers": ["console","olympe_log_file"] + "loggers": { + "olympe": {"level": "ERROR", "handlers": ["console", "olympe_log_file"]}, + "ulog": { + "level": "ERROR", + "handlers": ["console", "ulog_log_file"], + }, }, - "ulog": { - "level": "ERROR", - "handlers": ["console", "ulog_log_file"], - } } -}) +) + class ArgumentOutOfBoundsException(Exception): pass + class ConnectionFailedException(Exception): pass -class ParrotDrone(): +class ParrotDrone: class FlightMode(Enum): MANUAL = 1 ATTITUDE = 2 @@ -72,13 +81,13 @@ class FlightMode(Enum): def __init__(self, **kwargs): # Handle special arguments - self.ip = '192.168.42.1' - if 'sim' in kwargs and kwargs['sim']: - self.ip = '10.202.0.1' - if 'ip' in kwargs: - self.ip = kwargs['ip'] + self.ip = "192.168.42.1" + if "sim" in kwargs and kwargs["sim"]: + self.ip = "10.202.0.1" + if "ip" in kwargs: + self.ip = kwargs["ip"] self.ffmpeg = False - if 'ffmpeg' in kwargs and kwargs['ffmpeg']: + if "ffmpeg" in kwargs and kwargs["ffmpeg"]: self.ffmpeg = True # Create the drone object self.drone = Drone(self.ip) @@ -90,7 +99,7 @@ def __init__(self, **kwargs): self.flightmode = ParrotDrone.FlightMode.MANUAL logger.info("#####################parrot init##########################") - ''' Awaiting methods ''' + """ Awaiting methods """ async def switchModes(self, mode): if self.flightmode == mode: @@ -112,16 +121,18 @@ async def hovering(self, timeout=None): if timeout is not None: start = time.time() while True: - if self.drone(FlyingStateChanged(state="hovering", _policy="check")).success(): - break - elif start is not None and time.time() - start < timeout: + if ( + self.drone(FlyingStateChanged(state="hovering", _policy="check")).success() + or start is not None + and time.time() - start < timeout + ): break else: await asyncio.sleep(1) logger.info(f"Hovering function finished at: {time.time()}") - ''' Background PID tasks ''' + """ Background PID tasks """ async def _attitudePID(self): try: @@ -145,10 +156,7 @@ def updatePID(e, ep, tp, ts, pidDict): I *= -1 elif abs(e) <= 0.01 or I * pidDict["PrevI"] < 0: I = 0.0 - if abs(e) > 0.01: - D = pidDict["Kd"] * (e - ep) / (ts - tp) - else: - D = 0 + D = pidDict["Kd"] * (e - ep) / (ts - tp) if abs(e) > 0.01 else 0 return P, I, D @@ -234,10 +242,7 @@ def updatePID(e, ep, tp, ts, pidDict): I *= -1 elif abs(e) <= 0.05 or I * pidDict["PrevI"] < 0: I = 0.0 - if abs(e) > 0.01: - D = pidDict["Kd"] * (e - ep) / (ts - tp) - else: - D = 0.0 + D = pidDict["Kd"] * (e - ep) / (ts - tp) if abs(e) > 0.01 else 0.0 # For testing Integral component I = 0.0 @@ -316,7 +321,7 @@ def updatePID(e, ep, tp, ts, pidDict): except asyncio.CancelledError: pass - ''' Connection methods ''' + """ Connection methods """ async def connect(self): self.active = self.drone.connect() @@ -330,9 +335,9 @@ async def disconnect(self): self.drone.disconnect() self.active = False - ''' Streaming methods ''' + """ Streaming methods """ - async def startStreaming(self, save_frames = False): + async def startStreaming(self, save_frames=False): if not self.ffmpeg: self.streamingThread = PDRAWStreamingThread(self.drone, self.ip, save_frames) else: @@ -346,7 +351,7 @@ async def getVideoFrame(self): async def stopStreaming(self): self.streamingThread.stop() - ''' Take off / Landing methods ''' + """ Take off / Landing methods """ async def takeOff(self): logger.info(f"takeoff function started at: {time.time()}") @@ -375,19 +380,20 @@ async def rth(self): self.drone(return_to_home()) logger.info(f"rth function started at: {time.time()}") - ''' Camera methods ''' + """ Camera methods """ async def getCameras(self): pass async def switchCameras(self, camID): from olympe.messages.thermal import set_mode + if on: self.drone(set_mode(mode="blended")).wait().success() else: self.drone(set_mode(mode="disabled")).wait().success() - ''' Movement methods ''' + """ Movement methods """ async def setAttitude(self, pitch, roll, thrust, yaw): await self.switchModes(ParrotDrone.FlightMode.ATTITUDE) @@ -422,53 +428,51 @@ async def setVelocity(self, forward_vel, right_vel, up_vel, angle_vel): async def setGPSLocation(self, lat, lng, alt, bearing): await self.switchModes(ParrotDrone.FlightMode.GUIDED) if bearing is None: - self.drone( - moveTo(lat, lng, alt, move_mode.orientation_mode.to_target, 0.0) - ) + self.drone(moveTo(lat, lng, alt, move_mode.orientation_mode.to_target, 0.0)) else: - self.drone( - moveTo(lat, lng, alt, move_mode.orientation_mode.heading_during, bearing) - ) + self.drone(moveTo(lat, lng, alt, move_mode.orientation_mode.heading_during, bearing)) await self.hovering() async def setTranslatedPosition(self, forward, right, up, angle): await self.switchModes(ParrotDrone.FlightMode.GUIDED) - self.drone( - moveBy(forward, right, -1 * up, angle) - ) + self.drone(moveBy(forward, right, -1 * up, angle)) await self.hovering() async def rotateGimbal(self, yaw_theta, pitch_theta, roll_theta): pose_dict = await self.getGimbalPose() current_pitch = pose_dict["pitch"] - self.drone(set_target( - gimbal_id=0, - control_mode="position", - yaw_frame_of_reference="absolute", - yaw=yaw_theta, - pitch_frame_of_reference="absolute", - pitch=pitch_theta + current_pitch, - roll_frame_of_reference="absolute", - roll=roll_theta,) + self.drone( + set_target( + gimbal_id=0, + control_mode="position", + yaw_frame_of_reference="absolute", + yaw=yaw_theta, + pitch_frame_of_reference="absolute", + pitch=pitch_theta + current_pitch, + roll_frame_of_reference="absolute", + roll=roll_theta, + ) ) async def setGimbalPose(self, yaw_theta, pitch_theta, roll_theta): - self.drone(set_target( - gimbal_id=0, - control_mode="position", - yaw_frame_of_reference="absolute", - yaw=yaw_theta, - pitch_frame_of_reference="absolute", - pitch=pitch_theta, - roll_frame_of_reference="absolute", - roll=roll_theta,) + self.drone( + set_target( + gimbal_id=0, + control_mode="position", + yaw_frame_of_reference="absolute", + yaw=yaw_theta, + pitch_frame_of_reference="absolute", + pitch=pitch_theta, + roll_frame_of_reference="absolute", + roll=roll_theta, + ) ) async def hover(self): await self.switchModes(ParrotDrone.FlightMode.MANUAL) self.drone(PCMD(1, 0, 0, 0, 0, timestampAndSeqNum=0)) - ''' Status methods ''' + """ Status methods """ async def getTelemetry(self): telDict = {} @@ -489,10 +493,12 @@ async def getName(self): async def getGPS(self): try: - return (self.drone.get_state(GpsLocationChanged)["latitude"], + return ( + self.drone.get_state(GpsLocationChanged)["latitude"], self.drone.get_state(GpsLocationChanged)["longitude"], - self.drone.get_state(GpsLocationChanged)["altitude"]) - except Exception as e: + self.drone.get_state(GpsLocationChanged)["altitude"], + ) + except Exception: # If there is no GPS fix, return default values return (500.0, 500.0, 0.0) @@ -529,11 +535,10 @@ async def getVelocityBody(self): vecr = np.array([0.0, 1.0], dtype=float) rt = np.radians(hd + 90) c, s = np.cos(rt), np.sin(rt) - R2 = np.array(((c,-s), (s, c))) + R2 = np.array(((c, -s), (s, c))) vecr = np.dot(R2, vecr) - res = {"forward": np.dot(vec, vecf) * -1, "right": np.dot(vec, vecr) * -1, \ - "up": NEU["up"]} + res = {"forward": np.dot(vec, vecf) * -1, "right": np.dot(vec, vecr) * -1, "up": NEU["up"]} return res async def getRSSI(self): @@ -546,38 +551,42 @@ async def getMagnetometerReading(self): return self.drone.get_state(MagnetoCalibrationRequiredState)["required"] async def getGimbalPose(self): - return {"roll": 0.0, "pitch": self.drone.get_state(attitude)[0]["pitch_absolute"], "yaw": 0.0} + return { + "roll": 0.0, + "pitch": self.drone.get_state(attitude)[0]["pitch_absolute"], + "yaw": 0.0, + } async def getAttitude(self): att = self.drone.get_state(AttitudeChanged) rad_to_deg = 180 / math.pi - return {"roll": att["roll"] * rad_to_deg, "pitch": att["pitch"] * rad_to_deg, - "yaw": att["yaw"] * rad_to_deg} + return { + "roll": att["roll"] * rad_to_deg, + "pitch": att["pitch"] * rad_to_deg, + "yaw": att["yaw"] * rad_to_deg, + } - ''' Emergency methods ''' + """ Emergency methods """ async def kill(self): self.active = False -import cv2 -import numpy as np -import os - class FFMPEGStreamingThread(threading.Thread): - def __init__(self, drone, ip): threading.Thread.__init__(self) self.currentFrame = None self.drone = drone os.environ["OPENCV_FFMPEG_CAPTURE_OPTIONS"] = "rtsp_transport;udp" num_threads = int(os.environ.get("FFMPEG_THREADS")) - self.cap = cv2.VideoCapture(f"rtsp://{ip}/live", cv2.CAP_FFMPEG, (cv2.CAP_PROP_N_THREADS, num_threads)) + self.cap = cv2.VideoCapture( + f"rtsp://{ip}/live", cv2.CAP_FFMPEG, (cv2.CAP_PROP_N_THREADS, num_threads) + ) self.isRunning = True def run(self): try: - while(self.isRunning): + while self.isRunning: ret, self.currentFrame = self.cap.read() except Exception as e: logger.error(e) @@ -594,11 +603,9 @@ def grabFrame(self): def stop(self): self.isRunning = False -import queue class PDRAWStreamingThread(threading.Thread): - - def __init__(self, drone, ip, save_frames = False): + def __init__(self, drone, ip, save_frames=False): threading.Thread.__init__(self) self.drone = drone self.frame_queue = queue.Queue() @@ -657,7 +664,7 @@ def copyFrame(self, yuv_frame): os.makedirs(directory) cv2.imwrite(os.path.join(directory, filename), self.currentFrame) - ''' Callbacks ''' + """ Callbacks """ def yuvFrameCb(self, yuv_frame): """ diff --git a/os/drivers/SkyRocket/SkyViper2450GPS/SkyViper2450GPS.py b/os/drivers/SkyRocket/SkyViper2450GPS/SkyViper2450GPS.py index e820a032..8b32808c 100644 --- a/os/drivers/SkyRocket/SkyViper2450GPS/SkyViper2450GPS.py +++ b/os/drivers/SkyRocket/SkyViper2450GPS/SkyViper2450GPS.py @@ -1,26 +1,26 @@ -from enum import Enum +import asyncio +import logging import math import os +import threading import time -import asyncio -import logging +from enum import Enum + +import cv2 from pymavlink import mavutil logger = logging.getLogger(__name__) -class ConnectionFailedException(Exception): - pass -class SkyViper2450GPSDrone(): - +class SkyViper2450GPSDrone: class FlightMode(Enum): - LAND = 'LAND' - RTL = 'RTL' - LOITER = 'LOITER' - GUIDED = 'GUIDED' - GUIDED_NOGPS = 'GUIDED_NOGPS' - ALT_HOLD = 'ALT_HOLD' - + LAND = "LAND" + RTL = "RTL" + LOITER = "LOITER" + GUIDED = "GUIDED" + GUIDED_NOGPS = "GUIDED_NOGPS" + ALT_HOLD = "ALT_HOLD" + def __init__(self): self.vehicle = None self.mode = None @@ -28,8 +28,8 @@ def __init__(self): self.listener_task = None self.gps_disabled = False + """ Connect methods """ - ''' Connect methods ''' async def connect(self, connection_string): # connect to drone logger.info(f"Connecting to drone at {connection_string}...") @@ -38,11 +38,11 @@ async def connect(self, connection_string): logger.info("-- Connected to drone!") self.mode_mapping = self.vehicle.mode_mapping() logger.info(f"Mode mapping: {self.mode_mapping}") - + # register telemetry streams await self.register_telemetry_streams() asyncio.create_task(self._message_listener()) - + async def register_telemetry_streams(self, frequency_hz: float = 10.0): # Define the telemetry message names telemetry_message_names = [ @@ -62,7 +62,9 @@ async def register_telemetry_streams(self, frequency_hz: float = 10.0): try: message_id = getattr(mavutil.mavlink, f"MAVLINK_MSG_ID_{message_name}", None) if message_id is None: - logger.warning(f"Message name {message_name} is not found in MAVLink definitions.") + logger.warning( + f"Message name {message_name} is not found in MAVLink definitions." + ) continue # Request the message interval @@ -74,14 +76,20 @@ async def register_telemetry_streams(self, frequency_hz: float = 10.0): def request_message_interval(self, message_id: int, frequency_hz: float): self.vehicle.mav.command_long_send( - self.vehicle.target_system, self.vehicle.target_component, - mavutil.mavlink.MAV_CMD_SET_MESSAGE_INTERVAL, 0, - message_id, # The MAVLink message ID - 1e6 / frequency_hz, # The interval between two messages in microseconds. Set to -1 to disable and 0 to request default rate. - 0, 0, 0, 0, # Unused parameters - 0, # Target address of message stream (if message has target address fields). 0: Flight-stack default (recommended), 1: address of requestor, 2: broadcast. + self.vehicle.target_system, + self.vehicle.target_component, + mavutil.mavlink.MAV_CMD_SET_MESSAGE_INTERVAL, + 0, + message_id, # The MAVLink message ID + 1e6 + / frequency_hz, # The interval between two messages in microseconds. Set to -1 to disable and 0 to request default rate. + 0, + 0, + 0, + 0, # Unused parameters + 0, # Target address of message stream (if message has target address fields). 0: Flight-stack default (recommended), 1: address of requestor, 2: broadcast. ) - + async def isConnected(self): return self.vehicle is not None @@ -94,19 +102,20 @@ async def disconnect(self): self.vehicle.close() logger.info("-- Disconnected from drone") - ''' Telemetry methods ''' + """ Telemetry methods """ + async def getTelemetry(self): try: tel_dict = {} - tel_dict['name'] = self.getName() - tel_dict['gps'] = self.getGPS() - tel_dict['relAlt'] = self.getAltitudeRel() - tel_dict['attitude'] = self.getAttitude() - tel_dict['magnetometer'] = self.getMagnetometerReading() - tel_dict['imu'] = self.getVelocityNEU() - tel_dict['battery'] = self.getBatteryPercentage() - tel_dict['satellites'] = self.getSatellites() - tel_dict['heading'] = self.getHeading() + tel_dict["name"] = self.getName() + tel_dict["gps"] = self.getGPS() + tel_dict["relAlt"] = self.getAltitudeRel() + tel_dict["attitude"] = self.getAttitude() + tel_dict["magnetometer"] = self.getMagnetometerReading() + tel_dict["imu"] = self.getVelocityNEU() + tel_dict["battery"] = self.getBatteryPercentage() + tel_dict["satellites"] = self.getSatellites() + tel_dict["heading"] = self.getHeading() logger.debug(f"Telemetry data: {tel_dict}") return tel_dict @@ -115,7 +124,7 @@ async def getTelemetry(self): return {} def getName(self): - drone_id = os.environ.get('DRONE_ID') + drone_id = os.environ.get("DRONE_ID") return drone_id def getGPS(self): @@ -125,7 +134,7 @@ def getGPS(self): return { "latitude": gps_msg.lat / 1e7, "longitude": gps_msg.lon / 1e7, - "altitude": gps_msg.alt / 1e3 + "altitude": gps_msg.alt / 1e3, } def getAltitudeRel(self): @@ -138,21 +147,13 @@ def getAttitude(self): attitude_msg = self._get_cached_message("ATTITUDE") if not attitude_msg: return None - return { - "roll": attitude_msg.roll, - "pitch": attitude_msg.pitch, - "yaw": attitude_msg.yaw - } + return {"roll": attitude_msg.roll, "pitch": attitude_msg.pitch, "yaw": attitude_msg.yaw} def getMagnetometerReading(self): imu_msg = self._get_cached_message("RAW_IMU") if not imu_msg: return {"x": None, "y": None, "z": None} - return { - "x": imu_msg.xmag, - "y": imu_msg.ymag, - "z": imu_msg.zmag - } + return {"x": imu_msg.xmag, "y": imu_msg.ymag, "z": imu_msg.zmag} def getBatteryPercentage(self): battery_msg = self._get_cached_message("BATTERY_STATUS") @@ -176,12 +177,8 @@ def getVelocityNEU(self): gps_msg = self._get_cached_message("GLOBAL_POSITION_INT") if not gps_msg: return None - return { - "forward": gps_msg.vx / 100, - "right": gps_msg.vy / 100, - "up": gps_msg.vz / 100 - } - + return {"forward": gps_msg.vx / 100, "right": gps_msg.vy / 100, "up": gps_msg.vz / 100} + def getVelocityBody(self): velocity_msg = self._get_cached_message("LOCAL_POSITION_NED") if not velocity_msg: @@ -189,7 +186,7 @@ def getVelocityBody(self): return { "vx": velocity_msg.vx, # Body-frame X velocity in m/s "vy": velocity_msg.vy, # Body-frame Y velocity in m/s - "vz": velocity_msg.vz # Body-frame Z velocity in m/s + "vz": velocity_msg.vz, # Body-frame Z velocity in m/s } def getRSSI(self): @@ -198,7 +195,8 @@ def getRSSI(self): return None return rssi_msg.rssi - ''' Actuation methods ''' + """ Actuation methods """ + async def hover(self): logger.info("-- Hovering") # if await self.switchMode(SkyViper2450GPSDrone.FlightMode.LOITER) == False: @@ -208,28 +206,31 @@ async def hover(self): async def takeOff(self, target_altitude): logger.info("-- Taking off") - - if await self.switchMode(SkyViper2450GPSDrone.FlightMode.GUIDED) == False: + + if await self.switchMode(SkyViper2450GPSDrone.FlightMode.GUIDED) is False: logger.error("Failed to set mode to GUIDED") return - - if await self.arm() == False: + + if await self.arm() is False: logger.error("Failed to arm the drone") return - + self.vehicle.mav.command_long_send( self.vehicle.target_system, self.vehicle.target_component, mavutil.mavlink.MAV_CMD_NAV_TAKEOFF, 0, - 0, 0, 0, 0, - 0, 0, target_altitude + 0, + 0, + 0, + 0, + 0, + 0, + target_altitude, ) - + result = await self._wait_for_condition( - lambda: self.is_altitude_reached(target_altitude), - timeout=60, - interval=1 + lambda: self.is_altitude_reached(target_altitude), timeout=60, interval=1 ) if result: logger.info("-- Altitude reached") @@ -240,20 +241,16 @@ async def takeOff(self, target_altitude): async def land(self): logger.info("-- Landing") - if await self.switchMode(SkyViper2450GPSDrone.FlightMode.LAND) == False: + if await self.switchMode(SkyViper2450GPSDrone.FlightMode.LAND) is False: logger.error("Failed to set mode to LAND") return - result = await self._wait_for_condition( - lambda: self.is_disarmed(), - timeout=60, - interval=1 - ) + result = await self._wait_for_condition(lambda: self.is_disarmed(), timeout=60, interval=1) if result: logger.info("-- Landed and disarmed") - else: + else: logger.error("-- Landing failed") - + return result async def setHome(self, lat, lng, alt): @@ -263,62 +260,63 @@ async def setHome(self, lat, lng, alt): self.vehicle.target_component, mavutil.mavlink.MAV_CMD_DO_SET_HOME, 1, - 0, 0, 0, 0, - lat, lng, alt + 0, + 0, + 0, + 0, + lat, + lng, + alt, ) result = await self._wait_for_condition( - lambda: self.is_home_set(), - timeout=30, - interval=0.1 + lambda: self.is_home_set(), timeout=30, interval=0.1 ) - + if result: logger.info("-- Home location set successfully") else: logger.error("-- Failed to set home location") - + return result - + async def rth(self): logger.info("-- Returning to launch") - if await self.switchMode(SkyViper2450GPSDrone.FlightMode.RTL) == False: + if await self.switchMode(SkyViper2450GPSDrone.FlightMode.RTL) is False: logger.error("Failed to set mode to RTL") return - result = await self._wait_for_condition( - lambda: self.is_disarmed(), - timeout=60, - interval=1 - ) - + result = await self._wait_for_condition(lambda: self.is_disarmed(), timeout=60, interval=1) + if result: logger.info("-- Returned to launch and disarmed") - else: + else: logger.error("-- RTL failed") - + async def manual_control(self, forward_vel, right_vel, up_vel, angle_vel): if self.gps_disabled: - if await self.switchMode(SkyViper2450GPSDrone.FlightMode.GUIDED_NOGPS) == False: + if await self.switchMode(SkyViper2450GPSDrone.FlightMode.GUIDED_NOGPS) is False: logger.error("Failed to set mode to GUIDED_NOGPS") return # if await self.switchMode(SkyViper2450GPSDrone.FlightMode.ALT_HOLD) == False: # logger.error("Failed to set mode to GUIDED_NOGPS") # return else: - if await self.switchMode(SkyViper2450GPSDrone.FlightMode.GUIDED) == False: + if await self.switchMode(SkyViper2450GPSDrone.FlightMode.GUIDED) is False: logger.error("Failed to set mode to GUIDED") return - logger.info(f"Sending manual control: forward={forward_vel}, right={right_vel}, up={up_vel}, yaw={angle_vel}") + logger.info( + f"Sending manual control: forward={forward_vel}, right={right_vel}, up={up_vel}, yaw={angle_vel}" + ) # Ensure values are within MAVLink range (-1000 to 1000) def clamp(value, min_val, max_val): return max(min_val, min(max_val, int(value))) # Ensure integer conversion x = clamp(forward_vel * 1000, -1000, 1000) # Forward/backward movement - y = clamp(right_vel * 1000, -1000, 1000) # Left/right movement - z = clamp(up_vel * 1000, 0, 1000) # Throttle (0=lowest, 1000=full thrust) - r = clamp(angle_vel * 1000, -1000, 1000) # Yaw rotation + y = clamp(right_vel * 1000, -1000, 1000) # Left/right movement + z = clamp(up_vel * 1000, 0, 1000) # Throttle (0=lowest, 1000=full thrust) + r = clamp(angle_vel * 1000, -1000, 1000) # Yaw rotation buttons = 0 # No buttons pressed buttons2 = 0 # No additional buttons @@ -328,32 +326,41 @@ def clamp(value, min_val, max_val): try: self.vehicle.mav.manual_control_send( self.vehicle.target_system, - x, y, z, r, + x, + y, + z, + r, buttons, buttons2, enabled_extensions, - s, t, aux1, aux2, aux3, aux4, aux5, aux6 + s, + t, + aux1, + aux2, + aux3, + aux4, + aux5, + aux6, ) logger.info("Manual control command sent successfully!") except Exception as e: logger.error(f"Failed to send manual control: {e}") - - + async def setAttitude(self, pitch, roll, thrust, yaw): logger.info(f"-- Setting attitude: pitch={pitch}, roll={roll}, thrust={thrust}, yaw={yaw}") if self.gps_disabled: - if await self.switchMode(SkyViper2450GPSDrone.FlightMode.GUIDED_NOGPS) == False: + if await self.switchMode(SkyViper2450GPSDrone.FlightMode.GUIDED_NOGPS) is False: logger.error("Failed to set mode to GUIDED_NOGPS") return # if await self.switchMode(SkyViper2450GPSDrone.FlightMode.ALT_HOLD) == False: # logger.error("Failed to set mode to GUIDED_NOGPS") # return else: - if await self.switchMode(SkyViper2450GPSDrone.FlightMode.GUIDED) == False: + if await self.switchMode(SkyViper2450GPSDrone.FlightMode.GUIDED) is False: logger.error("Failed to set mode to GUIDED") return - + # Convert Euler angles to quaternion (w, x, y, z) def to_quaternion(roll=0.0, pitch=0.0, yaw=0.0): roll, pitch, yaw = map(math.radians, [roll, pitch, yaw]) @@ -361,60 +368,73 @@ def to_quaternion(roll=0.0, pitch=0.0, yaw=0.0): cp, sp = math.cos(pitch * 0.5), math.sin(pitch * 0.5) cr, sr = math.cos(roll * 0.5), math.sin(roll * 0.5) - return [cr * cp * cy + sr * sp * sy, # w - sr * cp * cy - cr * sp * sy, # x - cr * sp * cy + sr * cp * sy, # y - cr * cp * sy - sr * sp * cy] # z + return [ + cr * cp * cy + sr * sp * sy, # w + sr * cp * cy - cr * sp * sy, # x + cr * sp * cy + sr * cp * sy, # y + cr * cp * sy - sr * sp * cy, + ] # z q = to_quaternion(roll, pitch, yaw) base_thrust = 0.6 - + self.vehicle.mav.set_attitude_target_send( 0, # time_boot_ms self.vehicle.target_system, self.vehicle.target_component, 0b00000000, # type_mask q, # Quaternion - 0, 0, 0, # Body angular rates - base_thrust + thrust # Throttle + 0, + 0, + 0, # Body angular rates + base_thrust + thrust, # Throttle ) logger.info("-- setAttitude sent successfully") # continuous control: no blocking wait async def setVelocity(self, forward_vel, right_vel, up_vel, angle_vel): - logger.info(f"-- Setting velocity: forward_vel={forward_vel}, right_vel={right_vel}, up_vel={up_vel}, angle_vel={angle_vel}") - + logger.info( + f"-- Setting velocity: forward_vel={forward_vel}, right_vel={right_vel}, up_vel={up_vel}, angle_vel={angle_vel}" + ) + if self.gps_disabled: - if await self.switchMode(SkyViper2450GPSDrone.FlightMode.GUIDED_NOGPS) == False: + if await self.switchMode(SkyViper2450GPSDrone.FlightMode.GUIDED_NOGPS) is False: logger.error("Failed to set mode to GUIDED_NOGPS") return else: - if await self.switchMode(SkyViper2450GPSDrone.FlightMode.GUIDED) == False: + if await self.switchMode(SkyViper2450GPSDrone.FlightMode.GUIDED) is False: logger.error("Failed to set mode to GUIDED") return - + self.vehicle.mav.set_position_target_local_ned_send( 0, # time_boot_ms self.vehicle.target_system, self.vehicle.target_component, mavutil.mavlink.MAV_FRAME_BODY_NED, # frame 0b010111000111, # type_mask - 0, 0, 0, # x, y, z positions - forward_vel, right_vel, -up_vel, # x, y, z velocity - 0, 0, 0, # x, y, z acceleration - 0, angle_vel # yaw, yaw_rate + 0, + 0, + 0, # x, y, z positions + forward_vel, + right_vel, + -up_vel, # x, y, z velocity + 0, + 0, + 0, # x, y, z acceleration + 0, + angle_vel, # yaw, yaw_rate ) logger.info("-- setVelocity sent successfully") # continuous control: no blocking wait async def setGPSLocation(self, lat, lon, alt, bearing): logger.info(f"-- Setting GPS location: lat={lat}, lon={lon}, alt={alt}, bearing={bearing}") - - if await self.switchMode(SkyViper2450GPSDrone.FlightMode.GUIDED) == False: + + if await self.switchMode(SkyViper2450GPSDrone.FlightMode.GUIDED) is False: logger.error("Failed to set mode to GUIDED") return - + self.vehicle.mav.set_position_target_global_int_send( 0, self.vehicle.target_system, @@ -424,49 +444,60 @@ async def setGPSLocation(self, lat, lon, alt, bearing): int(lat * 1e7), int(lon * 1e7), alt, - 0, 0, 0, - 0, 0, 0, - 0, 0 + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, ) - # Calculate bearing if not provided + # Calculate bearing if not provided current_location = self.getGPS() current_lat = current_location["latitude"] current_lon = current_location["longitude"] if bearing is None: bearing = self.calculate_bearing(current_lat, current_lon, lat, lon) logger.info(f"-- Calculated bearing: {bearing}") - + await self.setBearing(bearing) - + result = await self._wait_for_condition( - lambda: self.is_at_target(lat, lon), - timeout=60, - interval=1 + lambda: self.is_at_target(lat, lon), timeout=60, interval=1 ) - - if result: + + if result: logger.info("-- Reached target GPS location") - else: + else: logger.info("-- Failed to reach target GPS location") - + return result async def setTranslatedLocation(self, forward, right, up, angle): - logger.info(f"-- Translating location: forward={forward}, right={right}, up={up}, angle={angle}") - - if await self.switchMode(SkyViper2450GPSDrone.FlightMode.GUIDED) == False: + logger.info( + f"-- Translating location: forward={forward}, right={right}, up={up}, angle={angle}" + ) + + if await self.switchMode(SkyViper2450GPSDrone.FlightMode.GUIDED) is False: logger.error("Failed to set mode to GUIDED") return - - current_location = self.getGPS() - current_heading = self.getHeading() - dx = forward * math.cos(math.radians(current_heading)) - right * math.sin(math.radians(current_heading)) - dy = forward * math.sin(math.radians(current_heading)) + right * math.cos(math.radians(current_heading)) + current_location = self.getGPS() + current_heading = self.getHeading() + + dx = forward * math.cos(math.radians(current_heading)) - right * math.sin( + math.radians(current_heading) + ) + dy = forward * math.sin(math.radians(current_heading)) + right * math.cos( + math.radians(current_heading) + ) dz = -up target_lat = current_location["latitude"] + (dx / 111320) - target_lon = current_location["longitude"] + (dy / (111320 * math.cos(math.radians(current_location["latitude"])))) + target_lon = current_location["longitude"] + ( + dy / (111320 * math.cos(math.radians(current_location["latitude"]))) + ) target_alt = current_location["altitude"] + dz self.vehicle.mav.set_position_target_global_int_send( @@ -478,35 +509,39 @@ async def setTranslatedLocation(self, forward, right, up, angle): int(target_lat * 1e7), int(target_lon * 1e7), target_alt, - 0, 0, 0, - 0, 0, 0, - 0, 0 + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, ) - - if angle is not None: await self.setBearing(angle) - + + if angle is not None: + await self.setBearing(angle) + result = await self._wait_for_condition( - lambda: self.is_at_target(target_lat, target_lon), - timeout=60, - interval=1 + lambda: self.is_at_target(target_lat, target_lon), timeout=60, interval=1 ) - - if result: + + if result: logger.info("-- Reached target translated location") else: logger.error("-- Failed to reach target translated location") - + return result async def setBearing(self, bearing): logger.info(f"-- Setting yaw to {bearing} degrees") - - if await self.switchMode(SkyViper2450GPSDrone.FlightMode.GUIDED) == False: + + if await self.switchMode(SkyViper2450GPSDrone.FlightMode.GUIDED) is False: logger.error("Failed to set mode to GUIDED") return - - yaw_speed = 25 # deg/s - direction = 0 # 1: clockwise, -1: counter-clockwise 0: most quickly direction + + yaw_speed = 25 # deg/s + direction = 0 # 1: clockwise, -1: counter-clockwise 0: most quickly direction self.vehicle.mav.command_long_send( self.vehicle.target_system, self.vehicle.target_component, @@ -516,22 +551,22 @@ async def setBearing(self, bearing): yaw_speed, direction, 0, - 0, 0, 0 + 0, + 0, + 0, ) - - result = await self._wait_for_condition( - lambda: self.is_bearing_reached(bearing), - timeout=30, - interval=0.5 + + result = await self._wait_for_condition( + lambda: self.is_bearing_reached(bearing), timeout=30, interval=0.5 ) - + result = True - + if result: logger.info(f"-- Yaw successfully set to {bearing} degrees") else: logger.error(f"-- Failed to set yaw to {bearing} degrees") - + return result async def arm(self): @@ -542,17 +577,17 @@ async def arm(self): mavutil.mavlink.MAV_CMD_COMPONENT_ARM_DISARM, 0, 1, - 0, 0, 0, 0, 0, 0 + 0, + 0, + 0, + 0, + 0, + 0, ) logger.info("-- Arm command sent") + result = await self._wait_for_condition(lambda: self.is_armed(), timeout=30, interval=1) - result = await self._wait_for_condition( - lambda: self.is_armed(), - timeout=30, - interval=1 - ) - if result: logger.info("-- Armed successfully") else: @@ -567,57 +602,57 @@ async def disarm(self): mavutil.mavlink.MAV_CMD_COMPONENT_ARM_DISARM, 0, 0, - 0, 0, 0, 0, 0, 0 + 0, + 0, + 0, + 0, + 0, + 0, ) logger.info("-- Disarm command sent") - result = await self._wait_for_condition( - lambda: self.is_disarmed(), - timeout=30, - interval=1 - ) - + result = await self._wait_for_condition(lambda: self.is_disarmed(), timeout=30, interval=1) + if result: self.mode = None logger.info("-- Disarmed successfully") else: logger.error("-- Disarm failed") - - return result + + return result + async def switchMode(self, mode): logger.info(f"Switching mode to {mode}") mode_target = mode.value curr_mode = self.mode.value if self.mode else None - logger.info(f"mode map: {self.mode_mapping}, target mode: {mode_target}, current mode: {curr_mode}") - + logger.info( + f"mode map: {self.mode_mapping}, target mode: {mode_target}, current mode: {curr_mode}" + ) + if self.mode == mode: logger.info(f"Already in mode {mode_target}") return True - + # switch mode if mode_target not in self.mode_mapping: logger.info(f"Mode {mode_target} not supported!") return False - + mode_id = self.mode_mapping[mode_target] self.vehicle.mav.set_mode_send( - self.vehicle.target_system, - mavutil.mavlink.MAV_MODE_FLAG_CUSTOM_MODE_ENABLED, - mode_id + self.vehicle.target_system, mavutil.mavlink.MAV_MODE_FLAG_CUSTOM_MODE_ENABLED, mode_id ) - + result = await self._wait_for_condition( - lambda: self.is_mode_set(mode_target), - timeout=30, - interval=1 + lambda: self.is_mode_set(mode_target), timeout=30, interval=1 ) - + if result: self.mode = mode logger.info(f"Mode switched to {mode_target}") - + return result - + async def disableGPS(self): # logger.info("-- Disabling GPS") @@ -629,41 +664,52 @@ async def disableGPS(self): # 3, # 3 = No GPS (indoor mode) # mavutil.mavlink.MAV_PARAM_TYPE_INT32 # ) - + # result = await self._wait_for_condition( # lambda: self.is_GPS_disabled() # ) - + # if result: # logger.info("-- GPS disabled") # else: - # logger.error("-- Failed to disable GPS") - # Wait and print received parameter messages + # logger.error("-- Failed to disable GPS") + # Wait and print received parameter messages # return result - + self.gps_disabled = True - - - - ''' ACK methods''' + + """ ACK methods""" + def is_armed(self): - return self.vehicle.recv_match(type="HEARTBEAT", blocking=True).base_mode & mavutil.mavlink.MAV_MODE_FLAG_SAFETY_ARMED - + return ( + self.vehicle.recv_match(type="HEARTBEAT", blocking=True).base_mode + & mavutil.mavlink.MAV_MODE_FLAG_SAFETY_ARMED + ) + def is_disarmed(self): - return not (self.vehicle.recv_match(type='HEARTBEAT', blocking=True).base_mode & mavutil.mavlink.MAV_MODE_FLAG_SAFETY_ARMED) + return not ( + self.vehicle.recv_match(type="HEARTBEAT", blocking=True).base_mode + & mavutil.mavlink.MAV_MODE_FLAG_SAFETY_ARMED + ) def is_mode_set(self, mode_target): - current_mode = mavutil.mode_string_v10(self.vehicle.recv_match(type='HEARTBEAT', blocking=True)) + current_mode = mavutil.mode_string_v10( + self.vehicle.recv_match(type="HEARTBEAT", blocking=True) + ) return current_mode == mode_target - + def is_home_set(self): - msg = self.vehicle.recv_match(type='COMMAND_ACK', blocking=True) - return msg and msg.command == mavutil.mavlink.MAV_CMD_DO_SET_HOME and msg.result == mavutil.mavlink.MAV_RESULT_ACCEPTED + msg = self.vehicle.recv_match(type="COMMAND_ACK", blocking=True) + return ( + msg + and msg.command == mavutil.mavlink.MAV_CMD_DO_SET_HOME + and msg.result == mavutil.mavlink.MAV_RESULT_ACCEPTED + ) def is_altitude_reached(self, target_altitude): current_altitude = self.getAltitudeRel() return current_altitude >= target_altitude * 0.95 - + def is_bearing_reached(self, bearing): logger.info(f"Checking if bearing is reached: {bearing}") attitude = self.getAttitude() @@ -673,7 +719,6 @@ def is_bearing_reached(self, bearing): current_yaw = (math.degrees(attitude["yaw"]) + 360) % 360 target_yaw = (bearing + 360) % 360 return abs(current_yaw - target_yaw) <= 2 - def is_at_target(self, lat, lon): current_location = self.getGPS() @@ -681,21 +726,22 @@ def is_at_target(self, lat, lon): return False dlat = lat - current_location["latitude"] dlon = lon - current_location["longitude"] - distance = math.sqrt((dlat ** 2) + (dlon ** 2)) * 1.113195e5 + distance = math.sqrt((dlat**2) + (dlon**2)) * 1.113195e5 return distance < 1.0 - - ''' Helper methods ''' + + """ Helper methods """ + # azimuth calculation for bearing def calculate_bearing(self, lat1, lon1, lat2, lon2): # Convert latitude and longitude from degrees to radians lat1, lon1, lat2, lon2 = map(math.radians, [lat1, lon1, lat2, lon2]) - + delta_lon = lon2 - lon1 # Bearing calculation x = math.sin(delta_lon) * math.cos(lat2) y = math.cos(lat1) * math.sin(lat2) - math.sin(lat1) * math.cos(lat2) * math.cos(delta_lon) - + initial_bearing = math.atan2(x, y) # Convert bearing from radians to degrees @@ -705,7 +751,7 @@ def calculate_bearing(self, lat1, lon1, lat2, lon2): converted_bearing = (initial_bearing + 360) % 360 return converted_bearing - + async def _message_listener(self): logger.info("-- Starting message listener") try: @@ -718,7 +764,7 @@ async def _message_listener(self): logger.info("-- Message listener stopped") except Exception as e: logger.error(f"-- Error in message listener: {e}") - + def _get_cached_message(self, message_type): try: logger.debug(f"Currently connection message types: {list(self.vehicle.messages)}") @@ -726,7 +772,7 @@ def _get_cached_message(self, message_type): except KeyError: logger.error(f"Message type {message_type} not found in cache") return None - + async def _wait_for_condition(self, condition_fn, timeout=30, interval=0.5): start_time = time.time() while True: @@ -741,64 +787,58 @@ async def _wait_for_condition(self, condition_fn, timeout=30, interval=0.5): return False await asyncio.sleep(interval) - ''' Stream methods ''' + """ Stream methods """ + async def getGimbalPose(self): pass - + async def startStreaming(self): - self.streamingThread = StreamingThread(self.vehicle) self.streamingThread.start() async def getVideoFrame(self): if self.streamingThread: - return [self.streamingThread.grabFrame().tobytes(), self.streamingThread.getFrameShape()] + return [ + self.streamingThread.grabFrame().tobytes(), + self.streamingThread.getFrameShape(), + ] async def stopStreaming(self): self.streamingThread.stop() - -import cv2 -import numpy as np -import os -import threading -class StreamingThread(threading.Thread): +class StreamingThread(threading.Thread): def __init__(self, drone): threading.Thread.__init__(self) self.currentFrame = None self.drone = drone - url_sim = os.environ.get('STREAM_SIM_URL') - url_mini = os.environ.get('STREAM_MINI_URL') - self.sim = os.environ.get('SIMULATION') - - if (self.sim == 'true'): - url = url_sim - else: - url = url_mini - + url_sim = os.environ.get("STREAM_SIM_URL") + url_mini = os.environ.get("STREAM_MINI_URL") + self.sim = os.environ.get("SIMULATION") + + url = url_sim if self.sim == "true" else url_mini + logger.info(f"url used: {url}") self.cap = cv2.VideoCapture(url) self.isRunning = True def run(self): try: - while(self.isRunning): - + while self.isRunning: ret, self.currentFrame = self.cap.read() # logger.info(f"Frame shape: {self.currentFrame.shape}") # logger.info(f"Frame: {self.currentFrame}") except Exception as e: logger.error(e) - + def getFrameShape(self): return self.currentFrame.shape - + def grabFrame(self): try: frame = self.currentFrame.copy() return frame - except Exception as e: + except Exception: # Send a blank frame return None diff --git a/os/drivers/base/DroneItf.py b/os/drivers/base/DroneItf.py index 4acb3ee9..78185909 100644 --- a/os/drivers/base/DroneItf.py +++ b/os/drivers/base/DroneItf.py @@ -1,5 +1,5 @@ -from abs import ABC -import asyncio +from abs import ABC, abstractmethod + class DroneDeviceItf(ABC): """ @@ -16,6 +16,7 @@ class Response: :param message: Message string to describe reason for failure. :type message: string """ + def __init__(self, rid, message): """ Constructor method. @@ -27,14 +28,14 @@ def __bool__(self): """ Overloaded boolean operator to support easy success checks. """ - return rid != 0 + return self.rid != 0 @abstractmethod async def connect(self): """ Connect to the drone hardware. - :return: 'True' if successful, 'False' otherwise + :return: 'True' if successful, 'False' otherwise :rtype: bool """ pass @@ -44,7 +45,7 @@ async def isConnected(self): """ Checks to see if the drone hardware is connected. - :return: 'True' if connected, 'False' otherwise + :return: 'True' if connected, 'False' otherwise :rtype: bool """ pass @@ -74,14 +75,14 @@ async def land(self): async def setHome(self, lat, lng, alt): """ Set the home destination for the drone. - + :param lat: New home latitude :type lat: float :param lng: New home longitude :type lng: float :param alt: New home altitude :type alt: float - :return: Response object + :return: Response object :rtype: class: Response """ pass @@ -118,10 +119,10 @@ async def setVeloctiy(self, forward_vel, right_vel, up_vel, angle_vel): """ Set the velocity of the drone. - :param forward_vel: Target velocity along forward axis, + :param forward_vel: Target velocity along forward axis, in meters per second :type forward_vel: float - :param right_vel: Target velocity along right axis, + :param right_vel: Target velocity along right axis, in meters per second :type right_vel: float :param up_vel: Target velocity along up axis, in meters per second @@ -168,12 +169,11 @@ async def setRelativePosition(self, north, east, up, bearing): """ pass - async def hover(self): - """ - Instruct the drone to hover. - - :return: Response object - :rtype: class: Response - """ - pass + async def hover(self): + """ + Instruct the drone to hover. + :return: Response object + :rtype: class: Response + """ + pass diff --git a/os/drivers/base/common.py b/os/drivers/base/common.py new file mode 100644 index 00000000..d62815a1 --- /dev/null +++ b/os/drivers/base/common.py @@ -0,0 +1,7 @@ +# SPDX-FileCopyrightText: 2025 Carnegie Mellon University - Satyalab +# +# SPDX-License-Identifier: GPL-2.0-only + + +class ConnectionFailedException(Exception): + pass diff --git a/os/drivers/driver.py b/os/drivers/driver.py index 2878e597..372f1fec 100644 --- a/os/drivers/driver.py +++ b/os/drivers/driver.py @@ -1,48 +1,49 @@ -import time -import zmq -import zmq.asyncio +import asyncio import json +import logging import os +import signal import sys -import asyncio -import logging +import time + import cnc_protocol.cnc_pb2 as cnc_protocol -from util.utils import setup_socket, SocketOperation -import signal -from drivers.ModalAI.Seeker.Seeker import ModalAISeekerDrone, ConnectionFailedException -from drivers.SkyRocket.SkyViper2450GPS.SkyViper2450GPS import SkyViper2450GPSDrone, ConnectionFailedException +import zmq +import zmq.asyncio +from drivers.base.common import ConnectionFailedException +from drivers.ModalAI.Seeker.Seeker import ModalAISeekerDrone +from drivers.SkyRocket.SkyViper2450GPS.SkyViper2450GPS import SkyViper2450GPSDrone +from util.utils import SocketOperation, setup_socket # Configure logger -logging_format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s' -logging.basicConfig(level=os.environ.get('LOG_LEVEL', logging.INFO), - format=logging_format) +logging_format = "%(asctime)s - %(levelname)s - %(name)s - %(message)s" +logging.basicConfig(level=os.environ.get("LOG_LEVEL", logging.INFO), format=logging_format) logger = logging.getLogger(__name__) if os.environ.get("LOG_TO_FILE") == "true": - file_handler = logging.FileHandler('driver.log') + file_handler = logging.FileHandler("driver.log") file_handler.setFormatter(logging.Formatter(logging_format)) logger.addHandler(file_handler) -telemetry_logger = logging.getLogger('telemetry') -telemetry_handler = logging.FileHandler('telemetry.log') +telemetry_logger = logging.getLogger("telemetry") +telemetry_handler = logging.FileHandler("telemetry.log") formatter = logging.Formatter(logging_format) telemetry_handler.setFormatter(formatter) telemetry_logger.handlers.clear() telemetry_logger.addHandler(telemetry_handler) telemetry_logger.propagate = False -driverArgs = json.loads(os.environ.get('DRIVER_ARGS')) -droneArgs = json.loads(os.environ.get('DRONE_ARGS')) +driverArgs = json.loads(os.environ.get("DRIVER_ARGS")) +droneArgs = json.loads(os.environ.get("DRONE_ARGS")) if droneArgs is not None: for key, value in droneArgs.items(): driverArgs[key] = value -drone_id = driverArgs.get('drone_id') -drone_type = driverArgs.get('drone_type') -connection_string = driverArgs.get('connection_string') +drone_id = driverArgs.get("drone_id") +drone_type = driverArgs.get("drone_type") +connection_string = driverArgs.get("connection_string") -if drone_type == 'modalai': +if drone_type == "modalai": drone = ModalAISeekerDrone(drone_id) -elif drone_type == 'SkyViper2450GPS': +elif drone_type == "SkyViper2450GPS": drone = SkyViper2450GPSDrone(drone_id) context = zmq.asyncio.Context() @@ -51,49 +52,71 @@ cam_sock = context.socket(zmq.PUB) tel_sock.setsockopt(zmq.CONFLATE, 1) cam_sock.setsockopt(zmq.CONFLATE, 1) -setup_socket(tel_sock, SocketOperation.CONNECT, 'TEL_PORT', 'Created telemetry socket endpoint', os.environ.get("DATA_ENDPOINT")) -setup_socket(cam_sock, SocketOperation.CONNECT, 'CAM_PORT', 'Created camera socket endpoint', os.environ.get("DATA_ENDPOINT")) -setup_socket(cmd_back_sock, SocketOperation.CONNECT, 'CMD_BACK_PORT', 'Created command backend socket endpoint', os.environ.get("CMD_ENDPOINT")) +setup_socket( + tel_sock, + SocketOperation.CONNECT, + "TEL_PORT", + "Created telemetry socket endpoint", + os.environ.get("DATA_ENDPOINT"), +) +setup_socket( + cam_sock, + SocketOperation.CONNECT, + "CAM_PORT", + "Created camera socket endpoint", + os.environ.get("DATA_ENDPOINT"), +) +setup_socket( + cmd_back_sock, + SocketOperation.CONNECT, + "CMD_BACK_PORT", + "Created command backend socket endpoint", + os.environ.get("CMD_ENDPOINT"), +) + def handle_signal(signum, frame): logger.info(f"Received signal {signum}, cleaning up...") sys.exit(0) + signal.signal(signal.SIGINT, handle_signal) signal.signal(signal.SIGTERM, handle_signal) + async def camera_stream(drone, cam_sock): - logger.info('Starting camera stream') + logger.info("Starting camera stream") frame_id = 0 while await drone.isConnected(): try: cam_message = cnc_protocol.Frame() frame, frame_shape = await drone.getVideoFrame() - + if frame is None: - logger.error('Failed to get video frame') + logger.error("Failed to get video frame") continue - + cam_message.data = frame cam_message.height = frame_shape[0] cam_message.width = frame_shape[1] cam_message.channels = frame_shape[2] cam_message.id = frame_id cam_sock.send(cam_message.SerializeToString()) - logger.debug(f'Camera stream: sent frame {frame_id}, shape: {frame_shape}') + logger.debug(f"Camera stream: sent frame {frame_id}, shape: {frame_shape}") frame_id = frame_id + 1 except Exception as e: - logger.error(f'Failed to get video frame, error: {e}') + logger.error(f"Failed to get video frame, error: {e}") await asyncio.sleep(0.033) logger.info("Camera stream ended, disconnected from drone") + async def telemetry_stream(drone, tel_sock): - logger.debug('Starting telemetry stream') - - await asyncio.sleep(1) # solving for some contention issue with connecting to drone - + logger.debug("Starting telemetry stream") + + await asyncio.sleep(1) # solving for some contention issue with connecting to drone + while await drone.isConnected(): - logger.debug('HI from telemetry stream') + logger.debug("HI from telemetry stream") try: tel_message = cnc_protocol.Telemetry() telDict = await drone.getTelemetry() @@ -104,43 +127,54 @@ async def telemetry_stream(drone, tel_sock): tel_message.drone_attitude.pitch = telDict["attitude"]["pitch"] tel_message.drone_attitude.roll = telDict["attitude"]["roll"] tel_message.satellites = telDict["satellites"] - + tel_message.relative_position.up = telDict["relAlt"] - + tel_message.global_position.latitude = telDict["gps"]["latitude"] tel_message.global_position.longitude = telDict["gps"]["longitude"] tel_message.global_position.altitude = telDict["gps"]["altitude"] - + tel_message.velocity.forward_vel = telDict["imu"]["forward"] tel_message.velocity.right_vel = telDict["imu"]["right"] tel_message.velocity.up_vel = telDict["imu"]["up"] - - #tel_message.gimbal_attitude.yaw = telDict["gimbalAttitude"]["yaw"] - #tel_message.gimbal_attitude.pitch = telDict["gimbalAttitude"]["pitch"] - #tel_message.gimbal_attitude.roll = telDict["gimbalAttitude"]["roll"] - + + # tel_message.gimbal_attitude.yaw = telDict["gimbalAttitude"]["yaw"] + # tel_message.gimbal_attitude.pitch = telDict["gimbalAttitude"]["pitch"] + # tel_message.gimbal_attitude.roll = telDict["gimbalAttitude"]["roll"] + logger.info(f"Telemetry: {telDict}") tel_sock.send(tel_message.SerializeToString()) - logger.debug('Sent telemetry') + logger.debug("Sent telemetry") except Exception as e: - logger.error(f'Failed to get telemetry, error: {e}') + logger.error(f"Failed to get telemetry, error: {e}") await asyncio.sleep(0.01) logger.debug("Telemetry stream ended, disconnected from drone") + async def handle(identity, message, resp, action, resp_sock): try: match action: case "takeOff": - logger.info(f"takeoff function call started at: {time.time()}, seq id {message.seqNum}") - logger.info('####################################Taking OFF################################################################') + logger.info( + f"takeoff function call started at: {time.time()}, seq id {message.seqNum}" + ) + logger.info( + "####################################Taking OFF################################################################" + ) await drone.takeOff(5) resp.resp = cnc_protocol.ResponseStatus.COMPLETED logger.info(f"tookoff function call finished at: {time.time()}") case "setVelocity": velocity = message.setVelocity - logger.info(f"Setting velocity: {velocity} started at {time.time()}, seq id {message.seqNum}") - logger.info('####################################Setting Velocity#######################################################################') - await drone.setVelocity(velocity.forward_vel, velocity.right_vel, velocity.up_vel, velocity.angle_vel) + logger.info( + f"Setting velocity: {velocity} started at {time.time()}, seq id {message.seqNum}" + ) + logger.info( + "####################################Setting Velocity#######################################################################" + ) + await drone.setVelocity( + velocity.forward_vel, velocity.right_vel, velocity.up_vel, velocity.angle_vel + ) # await drone.setAttitude(velocity.forward_vel, velocity.right_vel, velocity.up_vel, velocity.angle_vel) # await drone.manual_control(velocity.forward_vel, velocity.right_vel, velocity.up_vel, velocity.angle_vel) resp.resp = cnc_protocol.ResponseStatus.COMPLETED @@ -148,30 +182,42 @@ async def handle(identity, message, resp, action, resp_sock): logger.info(f"land function call started at: {time.time()}") await drone.land() resp.resp = cnc_protocol.ResponseStatus.COMPLETED - logger.info('####################################Landing#######################################################################') + logger.info( + "####################################Landing#######################################################################" + ) logger.info(f"land function call finished at: {time.time()}") case "rth": logger.info(f"rth function call started at: {time.time()}") - logger.info('####################################Returning to Home#######################################################################') + logger.info( + "####################################Returning to Home#######################################################################" + ) await drone.rth() resp.resp = cnc_protocol.ResponseStatus.COMPLETED logger.info(f"rth function call finished at: {time.time()}") case "hover": - logger.info('####################################Hovering#######################################################################') - logger.info(f"hover function call started at: {time.time()}, seq id {message.seqNum}") + logger.info( + "####################################Hovering#######################################################################" + ) + logger.info( + f"hover function call started at: {time.time()}, seq id {message.seqNum}" + ) await drone.hover() logger.info("hover !") resp.resp = cnc_protocol.ResponseStatus.COMPLETED logger.info(f"hover function call finished at: {time.time()}") case "setGPSLocation": logger.info(f"setGPSLocation function call started at: {time.time()}") - logger.info('####################################Setting GPS Location#######################################################################') + logger.info( + "####################################Setting GPS Location#######################################################################" + ) location = message.setGPSLocation - await drone.setGPSLocation(location.latitude, location.longitude, location.altitude, None) + await drone.setGPSLocation( + location.latitude, location.longitude, location.altitude, None + ) resp.resp = cnc_protocol.ResponseStatus.COMPLETED logger.info(f"setGPSLocation function call finished at: {time.time()}") except Exception as e: - logger.error(f'Failed to handle command, error: {e.message}') + logger.error(f"Failed to handle command, error: {e.message}") resp.resp = cnc_protocol.ResponseStatus.FAILED resp_sock.send_multipart([identity, resp.SerializeToString()]) @@ -179,18 +225,18 @@ async def handle(identity, message, resp, action, resp_sock): async def main(drone, cam_sock, tel_sock, args): while True: try: - logger.info('starting connecting...') + logger.info("starting connecting...") await drone.connect(connection_string) - logger.info('drone connected') - except ConnectionFailedException as e: - logger.error('Failed to connect to drone, retrying...') + logger.info("drone connected") + except ConnectionFailedException: + logger.error("Failed to connect to drone, retrying...") continue - logger.info(f'Established connection to drone, ready to receive commands!') - + logger.info("Established connection to drone, ready to receive commands!") + await drone.startStreaming() - logger.info('Started streaming') + logger.info("Started streaming") asyncio.create_task(camera_stream(drone, cam_sock)) - + asyncio.create_task(telemetry_stream(drone, tel_sock)) await drone.disableGPS() @@ -211,9 +257,10 @@ async def main(drone, cam_sock, tel_sock, args): resp = message asyncio.create_task(handle(identity, message, resp, action, cmd_back_sock)) except Exception as e: - logger.info(f'cmd received error: {e}') + logger.info(f"cmd received error: {e}") + + logger.info("Disconnected from drone") - logger.info(f"Disconnected from drone") if __name__ == "__main__": asyncio.run(main(drone, cam_sock, tel_sock, driverArgs)) diff --git a/os/drivers/test/driver_test.py b/os/drivers/test/driver_test.py index 74242ad4..72818882 100644 --- a/os/drivers/test/driver_test.py +++ b/os/drivers/test/driver_test.py @@ -1,14 +1,15 @@ -from pynput.keyboard import Listener, Key, KeyCode -from enum import Enum -import subprocess +import asyncio import logging import time -import zmq -import zmq.asyncio -import asyncio from collections import defaultdict -from util.utils import setup_socket, SocketOperation +from enum import Enum + import cnc_protocol.cnc_pb2 as cnc_pb2 +import zmq +import zmq.asyncio +from pynput.keyboard import Key, KeyCode, Listener +from util.utils import SocketOperation, setup_socket + class Ctrl(Enum): ( @@ -25,6 +26,7 @@ class Ctrl(Enum): TURN_RIGHT, ) = range(11) + QWERTY_CTRL_KEYS = { Ctrl.QUIT: Key.esc, Ctrl.TAKEOFF: "t", @@ -39,6 +41,7 @@ class Ctrl(Enum): Ctrl.TURN_RIGHT: Key.right, } + class KeyboardCtrl(Listener): def __init__(self, ctrl_keys=None): self._ctrl_keys = self._get_ctrl_keys(ctrl_keys) @@ -52,10 +55,7 @@ def _on_press(self, key): self._key_pressed[key.char] = True elif isinstance(key, Key): self._key_pressed[key] = True - if self._key_pressed[self._ctrl_keys[Ctrl.QUIT]]: - return False - else: - return True + return not self._key_pressed[self._ctrl_keys[Ctrl.QUIT]] def _on_release(self, key): if isinstance(key, KeyCode): @@ -68,41 +68,22 @@ def quit(self): return not self.running or self._key_pressed[self._ctrl_keys[Ctrl.QUIT]] def _axis(self, left_key, right_key): - return ( - int(self._key_pressed[right_key]) - int(self._key_pressed[left_key]) - ) + return int(self._key_pressed[right_key]) - int(self._key_pressed[left_key]) def roll(self): - return self._axis( - self._ctrl_keys[Ctrl.MOVE_LEFT], - self._ctrl_keys[Ctrl.MOVE_RIGHT] - ) + return self._axis(self._ctrl_keys[Ctrl.MOVE_LEFT], self._ctrl_keys[Ctrl.MOVE_RIGHT]) def pitch(self): - return self._axis( - self._ctrl_keys[Ctrl.MOVE_BACKWARD], - self._ctrl_keys[Ctrl.MOVE_FORWARD] - ) + return self._axis(self._ctrl_keys[Ctrl.MOVE_BACKWARD], self._ctrl_keys[Ctrl.MOVE_FORWARD]) def yaw(self): - return self._axis( - self._ctrl_keys[Ctrl.TURN_LEFT], - self._ctrl_keys[Ctrl.TURN_RIGHT] - ) + return self._axis(self._ctrl_keys[Ctrl.TURN_LEFT], self._ctrl_keys[Ctrl.TURN_RIGHT]) def throttle(self): - return self._axis( - self._ctrl_keys[Ctrl.MOVE_DOWN], - self._ctrl_keys[Ctrl.MOVE_UP] - ) + return self._axis(self._ctrl_keys[Ctrl.MOVE_DOWN], self._ctrl_keys[Ctrl.MOVE_UP]) def has_piloting_cmd(self): - return ( - bool(self.roll()) - or bool(self.pitch()) - or bool(self.yaw()) - or bool(self.throttle()) - ) + return bool(self.roll()) or bool(self.pitch()) or bool(self.yaw()) or bool(self.throttle()) def _rate_limit_cmd(self, ctrl, delay): now = time.time() @@ -132,20 +113,23 @@ def _get_ctrl_keys(self, ctrl_keys): context = zmq.asyncio.Context() cmd_back_sock = context.socket(zmq.DEALER) -setup_socket(cmd_back_sock, SocketOperation.BIND, 'CMD_BACK_PORT', 'Created command backend socket endpoint') +setup_socket( + cmd_back_sock, SocketOperation.BIND, "CMD_BACK_PORT", "Created command backend socket endpoint" +) tel_sock = context.socket(zmq.SUB) -tel_sock.setsockopt(zmq.SUBSCRIBE, b'') # Subscribe to all topics +tel_sock.setsockopt(zmq.SUBSCRIBE, b"") # Subscribe to all topics tel_sock.setsockopt(zmq.CONFLATE, 1) -setup_socket(tel_sock, SocketOperation.BIND, 'TEL_PORT', 'Created telemetry socket endpoint') +setup_socket(tel_sock, SocketOperation.BIND, "TEL_PORT", "Created telemetry socket endpoint") cam_sock = context.socket(zmq.SUB) -cam_sock.setsockopt(zmq.SUBSCRIBE, b'') # Subscribe to all topics +cam_sock.setsockopt(zmq.SUBSCRIBE, b"") # Subscribe to all topics cam_sock.setsockopt(zmq.CONFLATE, 1) -setup_socket(cam_sock, SocketOperation.BIND, 'CAM_PORT', 'Created camera socket endpoint') +setup_socket(cam_sock, SocketOperation.BIND, "CAM_PORT", "Created camera socket endpoint") command_seq = 0 + async def recv_telemetry(): while True: try: @@ -156,7 +140,8 @@ async def recv_telemetry(): await asyncio.sleep(0.3) except Exception as e: logger.error(f"Telemetry handler error: {e}") - + + async def recv_camera(): while True: try: @@ -168,6 +153,7 @@ async def recv_camera(): except Exception as e: logger.error(f"Camera handler error: {e}") + async def send_comm(control): global command_seq driver_command = cnc_pb2.Driver() @@ -175,34 +161,39 @@ async def send_comm(control): command_seq += 1 if control.takeoff(): - logger.info('Takeoff!') + logger.info("Takeoff!") driver_command.takeOff = True elif control.landing(): - logger.info('Land!') + logger.info("Land!") driver_command.land = True elif control.has_piloting_cmd(): - logger.info(f'setVelocity({control.pitch()}, {control.roll()}, {control.throttle()}, {control.yaw()})') + logger.info( + f"setVelocity({control.pitch()}, {control.roll()}, {control.throttle()}, {control.yaw()})" + ) driver_command.setVelocity.forward_vel = control.pitch() driver_command.setVelocity.right_vel = control.roll() driver_command.setVelocity.up_vel = control.throttle() driver_command.setVelocity.angle_vel = control.yaw() else: - logger.info('Hover.') + logger.info("Hover.") driver_command.hover = True message = driver_command.SerializeToString() - identity = b'cmdr' + identity = b"cmdr" await cmd_back_sock.send_multipart([identity, message]) - logger.info('Sent message.') + logger.info("Sent message.") resp = await cmd_back_sock.recv_multipart() -logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s') -async def main(): +logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(levelname)s - %(message)s") + + +async def main(): control = KeyboardCtrl() asyncio.create_task(recv_telemetry()) while not control.quit(): - await send_comm(control) + await send_comm(control) await asyncio.sleep(0.2) + asyncio.run(main()) diff --git a/os/kernel/CommandService.py b/os/kernel/CommandService.py index eb65c062..737edcd0 100644 --- a/os/kernel/CommandService.py +++ b/os/kernel/CommandService.py @@ -1,21 +1,24 @@ -from enum import Enum -import sys +import asyncio +import logging +import os import time +from enum import Enum + import validators import zmq import zmq.asyncio -import asyncio -import logging -import os from cnc_protocol import cnc_pb2 from kernel.Service import Service -from util.utils import setup_socket, SocketOperation +from util.utils import SocketOperation, setup_socket # Configure logger -logging.basicConfig(level=os.environ.get('LOG_LEVEL', logging.INFO), - format='%(asctime)s - %(levelname)s - %(name)s - %(message)s') +logging.basicConfig( + level=os.environ.get("LOG_LEVEL", logging.INFO), + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", +) logger = logging.getLogger(__name__) + # Enumerations for Commands and Drone Types class ManualCommand(Enum): RTH = 1 @@ -34,7 +37,7 @@ def __init__(self, drone_id, drone_type): self.drone_type = drone_type self.drone_id = drone_id self.manual = True - + # init cmd seq self.command_seq = 0 @@ -43,19 +46,39 @@ def __init__(self, drone_id, drone_type): # Setting up sockets cmd_front_cmdr_sock = context.socket(zmq.DEALER) - cmd_front_cmdr_sock.setsockopt(zmq.IDENTITY, self.drone_id.encode('utf-8')) + cmd_front_cmdr_sock.setsockopt(zmq.IDENTITY, self.drone_id.encode("utf-8")) cmd_front_usr_sock = context.socket(zmq.DEALER) cmd_back_sock = context.socket(zmq.DEALER) msn_sock = context.socket(zmq.REQ) - setup_socket(cmd_front_cmdr_sock, SocketOperation.CONNECT, 'CMD_FRONT_CMDR_PORT', 'Connected command frontend cmdr socket endpoint', os.environ.get('STEELEAGLE_GABRIEL_SERVER')) - setup_socket(cmd_front_usr_sock, SocketOperation.BIND, 'CMD_FRONT_USR_PORT', 'Created command frontend user socket endpoint') - setup_socket(cmd_back_sock, SocketOperation.BIND, 'CMD_BACK_PORT', 'Created command backend socket endpoint') - setup_socket(msn_sock, SocketOperation.BIND, 'MSN_PORT', 'Created userspace mission control socket endpoint') + setup_socket( + cmd_front_cmdr_sock, + SocketOperation.CONNECT, + "CMD_FRONT_CMDR_PORT", + "Connected command frontend cmdr socket endpoint", + os.environ.get("STEELEAGLE_GABRIEL_SERVER"), + ) + setup_socket( + cmd_front_usr_sock, + SocketOperation.BIND, + "CMD_FRONT_USR_PORT", + "Created command frontend user socket endpoint", + ) + setup_socket( + cmd_back_sock, + SocketOperation.BIND, + "CMD_BACK_PORT", + "Created command backend socket endpoint", + ) + setup_socket( + msn_sock, + SocketOperation.BIND, + "MSN_PORT", + "Created userspace mission control socket endpoint", + ) self.cmd_front_cmdr_sock = cmd_front_cmdr_sock - self.cmd_front_usr_sock = cmd_front_usr_sock + self.cmd_front_usr_sock = cmd_front_usr_sock self.cmd_back_sock = cmd_back_sock self.msn_sock = msn_sock - # setting up tasks cmd_task = asyncio.create_task(self.cmd_proxy()) @@ -73,7 +96,7 @@ async def send_download_mission(self, url): mission_command = cnc_pb2.Mission() mission_command.downloadMission = url message = mission_command.SerializeToString() - logger.info(f'download_mission message:{message}') + logger.info(f"download_mission message:{message}") self.msn_sock.send(message) reply = await self.msn_sock.recv_string() logger.info(f"Mission reply: {reply}") @@ -84,7 +107,7 @@ async def send_start_mission(self): mission_command = cnc_pb2.Mission() mission_command.startMission = True message = mission_command.SerializeToString() - logger.info(f'start_mission message:{message}') + logger.info(f"start_mission message:{message}") self.msn_sock.send(message) reply = await self.msn_sock.recv_string() logger.info(f"Mission reply: {reply}") @@ -98,7 +121,6 @@ async def send_stop_mission(self): reply = await self.msn_sock.recv_string() logger.info(f"Mission reply: {reply}") - ######################################################## DRIVER ############################################################ async def send_driver_command(self, command, params): driver_command = cnc_pb2.Driver() @@ -123,23 +145,27 @@ async def send_driver_command(self, command, params): driver_command.setVelocity.angle_vel = params["yaw"] driver_command.setVelocity.right_vel = params["roll"] driver_command.setVelocity.up_vel = params["thrust"] - logger.info(f'Driver Command setVelocities: {driver_command.setVelocity} sent at:{time.time()}, seq id {driver_command.seqNum}') + logger.info( + f"Driver Command setVelocities: {driver_command.setVelocity} sent at:{time.time()}, seq id {driver_command.seqNum}" + ) elif command == ManualCommand.GIMBAL: driver_command.setGimbal.pitch_theta = params["gimbal_pitch"] - logger.info(f'Driver Command setGimbal: {driver_command.setGimbal} sent at:{time.time()}, seq id {driver_command.seqNum}') + logger.info( + f"Driver Command setGimbal: {driver_command.setGimbal} sent at:{time.time()}, seq id {driver_command.seqNum}" + ) elif command == ManualCommand.CONNECTION: driver_command.connectionStatus = cnc_pb2.ConnectionStatus() message = driver_command.SerializeToString() - identity = b'cmdr' + identity = b"cmdr" await self.cmd_back_sock.send_multipart([identity, message]) logger.info(f"Driver Command sent: {message}") - + return None ######################################################## COMMAND ############################################################ async def cmd_proxy(self): - logger.info('cmd_proxy started') + logger.info("cmd_proxy started") poller = zmq.asyncio.Poller() poller.register(self.cmd_front_cmdr_sock, zmq.POLLIN) poller.register(self.cmd_back_sock, zmq.POLLIN) @@ -147,26 +173,29 @@ async def cmd_proxy(self): while True: try: - logger.debug('proxy loop') + logger.debug("proxy loop") socks = dict(await poller.poll()) # Check for messages from CMDR if self.cmd_front_cmdr_sock in socks: msg = await self.cmd_front_cmdr_sock.recv_multipart() - cmd = msg[0] + cmd = msg[0] # Filter the message - logger.debug(f"proxy : cmd_front_cmdr_sock Received message from FRONTEND: cmd: {cmd}") + logger.debug( + f"proxy : cmd_front_cmdr_sock Received message from FRONTEND: cmd: {cmd}" + ) await self.process_command(cmd) - + # Check for messages from MSN if self.cmd_front_usr_sock in socks: msg = await self.cmd_front_usr_sock.recv_multipart() cmd = msg[0] - logger.debug(f"proxy : cmd_front_usr_sock Received message from FRONTEND: {cmd}") - identity = b'usr' + logger.debug( + f"proxy : cmd_front_usr_sock Received message from FRONTEND: {cmd}" + ) + identity = b"usr" await self.cmd_back_sock.send_multipart([identity, cmd]) - - + # Check for messages from DRIVER if self.cmd_back_sock in socks: message = await self.cmd_back_sock.recv_multipart() @@ -175,16 +204,22 @@ async def cmd_proxy(self): # Filter the message identity = message[0] cmd = message[1] - logger.debug(f"proxy : cmd_back_sock Received message from BACKEND: identity: {identity} cmd: {cmd}") - - if identity == b'cmdr': - logger.debug(f"proxy : cmd_back_sock Received message from BACKEND: discard bc of cmdr") + logger.debug( + f"proxy : cmd_back_sock Received message from BACKEND: identity: {identity} cmd: {cmd}" + ) + + if identity == b"cmdr": + logger.debug( + "proxy : cmd_back_sock Received message from BACKEND: discard bc of cmdr" + ) pass - elif identity == b'usr': - logger.debug(f"proxy : cmd_back_sock Received message from BACKEND: sent back bc of user") + elif identity == b"usr": + logger.debug( + "proxy : cmd_back_sock Received message from BACKEND: sent back bc of user" + ) await self.cmd_front_usr_sock.send_multipart([cmd]) else: - logger.error(f"proxy: invalid identity") + logger.error("proxy: invalid identity") except Exception as e: logger.error(f"proxy: {e}") @@ -204,16 +239,16 @@ async def process_command(self, cmd): await self.send_stop_mission() asyncio.create_task(self.send_driver_command(ManualCommand.HALT, None)) self.manual = True - logger.info('Manual control is now active!') + logger.info("Manual control is now active!") elif extras.cmd.script_url: url = extras.cmd.script_url if validators.url(url): - logger.info(f'Flight script sent by commander: {url}') + logger.info(f"Flight script sent by commander: {url}") self.manual = False await self.send_download_mission(url) await self.send_start_mission() else: - logger.info(f'Invalid script URL sent by commander: {extras.cmd.script_url}') + logger.info(f"Invalid script URL sent by commander: {extras.cmd.script_url}") elif self.manual: if extras.cmd.takeoff: logger.info(f"takeoff signal started at: {time.time()} seq id {self.command_seq}") @@ -225,17 +260,31 @@ async def process_command(self, cmd): self.handle_pcmd_command(extras.cmd.pcmd) def handle_pcmd_command(self, pcmd): - pitch, yaw, roll, thrust, gimbal_pitch = pcmd.pitch, pcmd.yaw, pcmd.roll, pcmd.gaz, pcmd.gimbal_pitch - params = {"pitch": pitch, "yaw": yaw, "roll": roll, "thrust": thrust, "gimbal_pitch": gimbal_pitch} - logger.info(f"PCMD signal started at: {time.time()} PCMD values: {params} seq id {self.command_seq}") + pitch, yaw, roll, thrust, gimbal_pitch = ( + pcmd.pitch, + pcmd.yaw, + pcmd.roll, + pcmd.gaz, + pcmd.gimbal_pitch, + ) + params = { + "pitch": pitch, + "yaw": yaw, + "roll": roll, + "thrust": thrust, + "gimbal_pitch": gimbal_pitch, + } + logger.info( + f"PCMD signal started at: {time.time()} PCMD values: {params} seq id {self.command_seq}" + ) asyncio.create_task(self.send_driver_command(ManualCommand.PCMD, params)) asyncio.create_task(self.send_driver_command(ManualCommand.GIMBAL, params)) ######################################################## MAIN ############################################################## async def async_main(): - drone_type = os.environ.get('DRONE_TYPE') - drone_id = os.environ.get('DRONE_ID') + drone_type = os.environ.get("DRONE_TYPE") + drone_id = os.environ.get("DRONE_ID") # init CommandService cmd_service = CommandService(drone_id, drone_type) @@ -243,7 +292,6 @@ async def async_main(): await cmd_service.start() - # Main Execution Block if __name__ == "__main__": logger.info("Main: starting CommandService") diff --git a/os/kernel/DataService.py b/os/kernel/DataService.py index 161c79d5..a4aed9c3 100644 --- a/os/kernel/DataService.py +++ b/os/kernel/DataService.py @@ -1,27 +1,27 @@ -import time -import zmq -import zmq.asyncio import asyncio -import os -import logging -import yaml import importlib +import logging +import os import pkgutil -from util.utils import setup_socket, SocketOperation -from cnc_protocol import cnc_pb2 +import sys + import computes -from kernel.computes.ComputeItf import ComputeInterface +import yaml +import zmq +import zmq.asyncio +from cnc_protocol import cnc_pb2 from DataStore import DataStore +from kernel.computes.ComputeItf import ComputeInterface from kernel.Service import Service -import sys +from util.utils import SocketOperation, setup_socket # Set up logging -logging_format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s' -logging.basicConfig(level=os.environ.get('LOG_LEVEL', logging.INFO), format=logging_format) +logging_format = "%(asctime)s - %(levelname)s - %(name)s - %(message)s" +logging.basicConfig(level=os.environ.get("LOG_LEVEL", logging.INFO), format=logging_format) logger = logging.getLogger(__name__) if os.environ.get("LOG_TO_FILE") == "true": - file_handler = logging.FileHandler('data_service.log') + file_handler = logging.FileHandler("data_service.log") file_handler.setFormatter(logging.Formatter(logging_format)) logger.addHandler(file_handler) @@ -40,30 +40,35 @@ def __init__(self, config_yaml): cam_sock = context.socket(zmq.SUB) cpt_usr_sock = context.socket(zmq.DEALER) - tel_sock.setsockopt(zmq.SUBSCRIBE, b'') # Subscribe to all topics + tel_sock.setsockopt(zmq.SUBSCRIBE, b"") # Subscribe to all topics tel_sock.setsockopt(zmq.CONFLATE, 1) - cam_sock.setsockopt(zmq.SUBSCRIBE, b'') # Subscribe to all topics + cam_sock.setsockopt(zmq.SUBSCRIBE, b"") # Subscribe to all topics cam_sock.setsockopt(zmq.CONFLATE, 1) - setup_socket(tel_sock, SocketOperation.BIND, 'TEL_PORT', 'Created telemetry socket endpoint') - setup_socket(cam_sock, SocketOperation.BIND, 'CAM_PORT', 'Created camera socket endpoint') - setup_socket(cpt_usr_sock, SocketOperation.BIND, 'CPT_USR_PORT', 'Created command frontend socket endpoint') + setup_socket( + tel_sock, SocketOperation.BIND, "TEL_PORT", "Created telemetry socket endpoint" + ) + setup_socket(cam_sock, SocketOperation.BIND, "CAM_PORT", "Created camera socket endpoint") + setup_socket( + cpt_usr_sock, + SocketOperation.BIND, + "CPT_USR_PORT", + "Created command frontend socket endpoint", + ) self.register_socket(tel_sock) self.register_socket(cam_sock) self.register_socket(cpt_usr_sock) - self.cam_sock = cam_sock self.tel_sock = tel_sock self.cpt_usr_sock = cpt_usr_sock - # setting up tasks tel_task = asyncio.create_task(self.telemetry_handler()) cam_task = asyncio.create_task(self.camera_handler()) usr_task = asyncio.create_task(self.user_handler()) - + self.register_task(tel_task) self.register_task(cam_task) self.register_task(usr_task) @@ -79,30 +84,32 @@ def __init__(self, config_yaml): def get_result(self, compute_type): logger.info(f"Processing getter for compute type: {compute_type}") getter_list = [] - for compute_id in self.compute_dict.keys(): + for compute_id in self.compute_dict: cpt_res = self.data_store.get_compute_result(compute_id, compute_type) - + if cpt_res is None: logger.error(f"Result not found for compute_id: {compute_id}") continue - + res = cpt_res[0] timestamp = str(cpt_res[1]) - - result= cnc_pb2.ComputeResult() + + result = cnc_pb2.ComputeResult() result.compute_id = compute_id result.timestamp = timestamp result.string_result = res - + getter_list.append(result) - logger.info(f"Sending result: {res} with compute_id : {compute_id}, timestamp: {timestamp}") + logger.info( + f"Sending result: {res} with compute_id : {compute_id}, timestamp: {timestamp}" + ) return getter_list def clear_result(self): logger.info("Processing setter") - for compute_id in self.compute_dict.keys(): + for compute_id in self.compute_dict: self.data_store.clear_compute_result(compute_id) - + async def user_handler(self): """Handles user commands.""" logger.info("User handler started") @@ -146,7 +153,6 @@ async def handle_compute(self, cpt_command): await self.cpt_usr_sock.send(cpt_command.SerializeToString()) elif cpt_command.setter: - if cpt_command.setter.clearResult: logger.info("Processing setter clear") self.clear_result() @@ -158,7 +164,7 @@ async def handle_compute(self, cpt_command): async def handle_driver(self, driver_command): """Processes a Driver command.""" logger.info(f"Received Driver command: {driver_command}") - + if driver_command.getTelemetry: logger.info("Processing getTelemetry") self.data_store.get_raw_data(driver_command.getTelemetry) @@ -185,7 +191,7 @@ def parse_frame(self, msg): frame = cnc_pb2.Frame() frame.ParseFromString(msg) return frame - + async def camera_handler(self): """Handles camera messages.""" logger.info("Camera handler started") @@ -195,7 +201,7 @@ async def camera_handler(self): frame = await asyncio.to_thread(self.parse_frame, msg) # Offload parsing self.data_store.set_raw_data(frame, frame.id) logger.debug(f"Received camera message after set: {frame}") - + except Exception as e: logger.error(f"Camera handler error: {e}") @@ -209,7 +215,11 @@ def discover_compute_classes(self): module = importlib.import_module(module_name) for attr_name in dir(module): attr = getattr(module, attr_name) - if isinstance(attr, type) and issubclass(attr, ComputeInterface) and attr is not ComputeInterface: + if ( + isinstance(attr, type) + and issubclass(attr, ComputeInterface) + and attr is not ComputeInterface + ): compute_classes[attr_name.lower()] = attr return compute_classes @@ -244,8 +254,10 @@ def spawn_computes(self, config_yaml): return compute_tasks + ######################################################## MAIN ############################################################## + async def async_main(): """Main entry point for the DataService.""" logger.info("Starting DataService") @@ -256,5 +268,6 @@ async def async_main(): data_service = DataService(config_yaml) await data_service.start() + if __name__ == "__main__": asyncio.run(async_main()) diff --git a/os/kernel/DataStore.py b/os/kernel/DataStore.py index f07e6906..e04e3650 100644 --- a/os/kernel/DataStore.py +++ b/os/kernel/DataStore.py @@ -1,10 +1,11 @@ import asyncio -from cnc_protocol import cnc_pb2 -from typing import Optional, Union import logging +from cnc_protocol import cnc_pb2 + logger = logging.getLogger(__name__) + class DataStore: def __init__(self): # Raw data caches @@ -28,10 +29,12 @@ def __init__(self): def clear_compute_result(self, compute_id): logger.info(f"clear_compute_result: Clearing result for compute {compute_id}") self._result_cache.pop(compute_id, None) - + ######################################################## COMPUTE ############################################################ - def get_compute_result(self, compute_id, result_type: str) -> Optional[Union[None, tuple]]: - logger.info(f"get_compute_result: Getting result for compute {compute_id} with type {result_type}") + def get_compute_result(self, compute_id, result_type: str) -> tuple | None: + logger.info( + f"get_compute_result: Getting result for compute {compute_id} with type {result_type}" + ) logger.info(self._result_cache) if compute_id not in self._result_cache: # Log an error and return None @@ -47,7 +50,9 @@ def get_compute_result(self, compute_id, result_type: str) -> Optional[Union[Non result = cache.get(result_type) if result is None: # Log an error and return None - logger.error(f"get_compute_result: No result found for compute {compute_id} with type {result_type}") + logger.error( + f"get_compute_result: No result found for compute {compute_id} with type {result_type}" + ) return None return result @@ -56,9 +61,13 @@ def append_compute(self, compute_id): self._result_cache[compute_id] = {} def update_compute_result(self, compute_id, result_type: str, result, timestamp): - assert isinstance(result_type, str), f"Argument must be a string, got {type(result_type).__name__}" + assert isinstance( + result_type, str + ), f"Argument must be a string, got {type(result_type).__name__}" self._result_cache[compute_id][result_type] = (result, timestamp) - logger.debug(f"update_compute_result: Updated result cache for compute {compute_id} with type {result_type}; result: {result}") + logger.debug( + f"update_compute_result: Updated result cache for compute {compute_id} with type {result_type}; result: {result}" + ) ######################################################## RAW DATA ############################################################ def get_raw_data(self, data_copy): @@ -76,10 +85,10 @@ def get_raw_data(self, data_copy): # Create a copy of the protobuf message data_copy.CopyFrom(cache) - + return self._raw_data_id.get(data_copy_type) - def set_raw_data(self, data, data_id = None): + def set_raw_data(self, data, data_id=None): data_type = type(data) if data_type not in self._raw_data_cache: logger.error(f"set_raw_data: No such data: data type {data_type}") @@ -100,4 +109,3 @@ async def wait_for_new_data(self, data_type): return None self._raw_data_event[data_type].clear() await self._raw_data_event[data_type].wait() - diff --git a/os/kernel/Service.py b/os/kernel/Service.py index 19829191..8b6d8d89 100644 --- a/os/kernel/Service.py +++ b/os/kernel/Service.py @@ -1,9 +1,9 @@ - import asyncio import logging logger = logging.getLogger(__name__) + class Service: def __init__(self): self.context = None @@ -13,7 +13,7 @@ def __init__(self): def register_context(self, context): self.context = context - def register_socket (self, sock): + def register_socket(self, sock): self.socks.append(sock) def register_task(self, task): @@ -21,7 +21,7 @@ def register_task(self, task): self.tasks.append(task) async def start(self): - logger.info(f'service started') + logger.info("service started") try: await asyncio.gather(*self.tasks) @@ -29,7 +29,7 @@ async def start(self): await self.shutdown() async def shutdown(self): - logger.info(f'{self.__class__.__name__}: Shutting down CommandService') + logger.info(f"{self.__class__.__name__}: Shutting down CommandService") for sock in self.socks: sock.close() @@ -50,8 +50,4 @@ async def shutdown(self): except Exception as err: logger.error(f"Task raised exception: {err}") - logger.info("Main: CommandService shutdown complete") - - - diff --git a/os/kernel/computes/ComputeItf.py b/os/kernel/computes/ComputeItf.py index e3d66bf9..73c23eea 100644 --- a/os/kernel/computes/ComputeItf.py +++ b/os/kernel/computes/ComputeItf.py @@ -12,7 +12,6 @@ def __init__(self, compute_id): self.compute_id = compute_id self.compute_status = self.ComputeStatus.IDLE - @abstractmethod async def run(self): """Running the worker.""" diff --git a/os/kernel/computes/GabrielCompute.py b/os/kernel/computes/GabrielCompute.py index 8b895fd7..328614a7 100644 --- a/os/kernel/computes/GabrielCompute.py +++ b/os/kernel/computes/GabrielCompute.py @@ -1,14 +1,13 @@ import asyncio -import json import logging import os import time + import cv2 import numpy as np -from gabriel_protocol import gabriel_pb2 -from gabriel_client.zeromq_client import ProducerWrapper, ZeroMQClient -from util.timer import Timer from cnc_protocol import cnc_pb2 +from gabriel_client.zeromq_client import ProducerWrapper, ZeroMQClient +from gabriel_protocol import gabriel_pb2 from kernel.computes.ComputeItf import ComputeInterface from kernel.DataStore import DataStore @@ -16,28 +15,26 @@ class GabrielCompute(ComputeInterface): - def __init__(self, compute_id, data_store:DataStore): + def __init__(self, compute_id, data_store: DataStore): super().__init__(compute_id) # remote computation parameters - self.set_params = { - "model": "coco", - "hsv_lower": None, - "hsv_upper": None - } + self.set_params = {"model": "coco", "hsv_lower": None, "hsv_upper": None} # Gabriel - gabriel_server = os.environ.get('STEELEAGLE_GABRIEL_SERVER') - logger.info(f'Gabriel compute: Gabriel server: {gabriel_server}') - gabriel_port = os.environ.get('STEELEAGLE_GABRIEL_PORT') - logger.info(f'Gabriel compute: Gabriel port: {gabriel_port}') + gabriel_server = os.environ.get("STEELEAGLE_GABRIEL_SERVER") + logger.info(f"Gabriel compute: Gabriel server: {gabriel_server}") + gabriel_port = os.environ.get("STEELEAGLE_GABRIEL_PORT") + logger.info(f"Gabriel compute: Gabriel port: {gabriel_port}") self.gabriel_server = gabriel_server self.gabriel_port = gabriel_port self.engine_results = {} self.drone_registered = False self.gabriel_client = ZeroMQClient( - self.gabriel_server, self.gabriel_port, - [self.get_telemetry_producer(), self.get_frame_producer()], self.process_results + self.gabriel_server, + self.gabriel_port, + [self.get_telemetry_producer(), self.get_frame_producer()], + self.process_results, ) # data_store @@ -45,16 +42,14 @@ def __init__(self, compute_id, data_store:DataStore): self.frame_id = -1 async def run(self): - logger.info(f"Gabriel compute: launching Gabriel client") + logger.info("Gabriel compute: launching Gabriel client") await self.gabriel_client.launch_async() - def set(self): self.set_params["model"] = None self.set_params["hsv_lower"] = None self.set_params["hsv_upper"] = None - def stop(self): """Stopping the worker.""" pass @@ -70,7 +65,7 @@ def process_results(self, result_wrapper): for result in result_wrapper.results: if result.payload_type == gabriel_pb2.PayloadType.TEXT: - payload = result.payload.decode('utf-8') + payload = result.payload.decode("utf-8") try: if len(payload) != 0: # get engine id @@ -78,8 +73,12 @@ def process_results(self, result_wrapper): # get timestamp timestamp = time.time() # update - logger.debug(f"Gabriel compute: timestamp = {timestamp}, compute type = {compute_type}, result = {result}") - self.data_store.update_compute_result(self.compute_id, compute_type, payload, timestamp) + logger.debug( + f"Gabriel compute: timestamp = {timestamp}, compute type = {compute_type}, result = {result}" + ) + self.data_store.update_compute_result( + self.compute_id, compute_type, payload, timestamp + ) except Exception as e: logger.error(f"Gabriel compute process_results: error processing result: {e}") else: @@ -106,14 +105,19 @@ async def producer(): tel_data = cnc_pb2.Telemetry() self.data_store.get_raw_data(tel_data) try: - if frame_data is not None and frame_data.data != b'' and tel_data is not None: + if frame_data is not None and frame_data.data != b"" and tel_data is not None: logger.debug("Waiting for new frame from driver") - logger.info(f"New frame frame_id={frame_data.id} available from driver, tel_data={tel_data}") + logger.info( + f"New frame frame_id={frame_data.id} available from driver, tel_data={tel_data}" + ) frame_bytes = frame_data.data - nparr = np.frombuffer(frame_bytes, dtype = np.uint8) - frame = cv2.imencode('.jpg', nparr.reshape(frame_data.height, frame_data.width, frame_data.channels))[1] + nparr = np.frombuffer(frame_bytes, dtype=np.uint8) + frame = cv2.imencode( + ".jpg", + nparr.reshape(frame_data.height, frame_data.width, frame_data.channels), + )[1] input_frame.payload_type = gabriel_pb2.PayloadType.IMAGE input_frame.payloads.append(frame.tobytes()) @@ -123,30 +127,31 @@ async def producer(): extras.location.latitude = tel_data.global_position.latitude extras.location.longitude = tel_data.global_position.longitude - if self.set_params['model'] is not None: - extras.detection_model = self.set_params['model'] - if self.set_params['hsv_lower'] is not None: - extras.lower_bound.H = self.set_params['hsv_lower'][0] - extras.lower_bound.S = self.set_params['hsv_lower'][1] - extras.lower_bound.V = self.set_params['hsv_lower'][2] - if self.set_params['hsv_upper'] is not None: - extras.upper_bound.H = self.set_params['hsv_upper'][0] - extras.upper_bound.S = self.set_params['hsv_upper'][1] - extras.upper_bound.V = self.set_params['hsv_upper'][2] + if self.set_params["model"] is not None: + extras.detection_model = self.set_params["model"] + if self.set_params["hsv_lower"] is not None: + extras.lower_bound.H = self.set_params["hsv_lower"][0] + extras.lower_bound.S = self.set_params["hsv_lower"][1] + extras.lower_bound.V = self.set_params["hsv_lower"][2] + if self.set_params["hsv_upper"] is not None: + extras.upper_bound.H = self.set_params["hsv_upper"][0] + extras.upper_bound.S = self.set_params["hsv_upper"][1] + extras.upper_bound.V = self.set_params["hsv_upper"][2] if extras is not None: input_frame.extras.Pack(extras) else: - logger.info('Gabriel compute Frame producer: frame is None') + logger.info("Gabriel compute Frame producer: frame is None") input_frame.payload_type = gabriel_pb2.PayloadType.TEXT - input_frame.payloads.append("Streaming not started, no frame to show.".encode('utf-8')) + input_frame.payloads.append(b"Streaming not started, no frame to show.") except Exception as e: input_frame.payload_type = gabriel_pb2.PayloadType.TEXT - input_frame.payloads.append("Unable to produce a frame!".encode('utf-8')) - logger.error(f'Gabriel compute Frame producer: unable to produce a frame: {e}') + input_frame.payloads.append(b"Unable to produce a frame!") + logger.error(f"Gabriel compute Frame producer: unable to produce a frame: {e}") logger.debug(f"Gabriel compute Frame producer: finished time {time.time()}") return input_frame - return ProducerWrapper(producer=producer, source_name='telemetry') + + return ProducerWrapper(producer=producer, source_name="telemetry") def get_telemetry_producer(self): async def producer(): @@ -156,7 +161,7 @@ async def producer(): logger.debug(f"tel producer: starting time {time.time()}") input_frame = gabriel_pb2.InputFrame() input_frame.payload_type = gabriel_pb2.PayloadType.TEXT - input_frame.payloads.append('heartbeart'.encode('utf8')) + input_frame.payloads.append(b"heartbeart") tel_data = cnc_pb2.Telemetry() self.data_store.get_raw_data(tel_data) try: @@ -176,17 +181,21 @@ async def producer(): # Register when we start sending telemetry if not self.drone_registered: - logger.info("Gabriel compute telemetry producer: Sending registeration request to backend") + logger.info( + "Gabriel compute telemetry producer: Sending registeration request to backend" + ) extras.registering = True self.drone_registered = True - logger.debug('Gabriel compute telemetry producer: sending Gabriel telemerty! content: {}'.format(extras)) + logger.debug( + f"Gabriel compute telemetry producer: sending Gabriel telemerty! content: {extras}" + ) input_frame.extras.Pack(extras) else: - logger.error('Telemetry unavailable') + logger.error("Telemetry unavailable") except Exception as e: - logger.debug(f'Gabriel compute telemetry producer: {e}') + logger.debug(f"Gabriel compute telemetry producer: {e}") logger.debug(f"tel producer: finished time {time.time()}") return input_frame - return ProducerWrapper(producer=producer, source_name='telemetry') + return ProducerWrapper(producer=producer, source_name="telemetry") diff --git a/os/kernel/computes/VOXLCompute.py b/os/kernel/computes/VOXLCompute.py index 6ad1f9ea..8be171e6 100644 --- a/os/kernel/computes/VOXLCompute.py +++ b/os/kernel/computes/VOXLCompute.py @@ -1,35 +1,38 @@ import asyncio -from cnc_protocol import cnc_pb2 -import cv2 -from enum import Enum -from kernel.computes.ComputeItf import ComputeInterface -from kernel.DataStore import DataStore import logging -import numpy as np import os +from enum import Enum + +import cv2 import kernel.computes.onboard_compute_pb2 as onboard_compute_pb2 -from util.utils import setup_socket, SocketOperation, lazy_pirate_request +import numpy as np import zmq import zmq.asyncio +from cnc_protocol import cnc_pb2 +from kernel.computes.ComputeItf import ComputeInterface +from kernel.DataStore import DataStore +from util.utils import SocketOperation, lazy_pirate_request, setup_socket logger = logging.getLogger(__name__) + class ComputationType(Enum): OBJECT_DETECTION = 1 DEPTH_ESTIMATION = 2 + class VOXLCompute(ComputeInterface): - ''' + """ Utilizes the onboard computational capabilities on the Modal AI VOXL 2. - ''' + """ def __init__(self, compute_id: int, data_store: DataStore): super().__init__(compute_id) self.context = zmq.asyncio.Context() self.socket = self.context.socket(zmq.REQ) - host = os.environ.get('LCE_HOST') - port = os.environ.get('LCE_PORT') + host = os.environ.get("LCE_HOST") + port = os.environ.get("LCE_PORT") if host is None: logger.error("Host not specified") @@ -38,10 +41,14 @@ def __init__(self, compute_id: int, data_store: DataStore): logger.error("Port not specified") raise Exception("Port not specified") - setup_socket(self.socket, SocketOperation.CONNECT, 'LCE_PORT', - 'Created socket to connect to local compute engine', - host) - self.server_endpoint = f'tcp://{host}:{port}' + setup_socket( + self.socket, + SocketOperation.CONNECT, + "LCE_PORT", + "Created socket to connect to local compute engine", + host, + ) + self.server_endpoint = f"tcp://{host}:{port}" self.is_running = False self.frame_id = -1 self.data_store = data_store @@ -61,13 +68,13 @@ async def get_status(self): return super().get_status() async def set(self): - raise NotImplemented() + raise NotImplementedError() async def run_loop(self): - ''' + """ Query data store in a loop and feed frames for processing to onboard compute engine. - ''' + """ logger.info("VOXL compute is running") while self.is_running: frame_data = cnc_pb2.Frame() @@ -79,10 +86,10 @@ async def run_loop(self): frame_id = self.data_store.get_raw_data(frame_data) self.frame_id = frame_id - if frame_data.data != b'': + if frame_data.data != b"": logger.info(f"VOXL compute got new frame {frame_data.id} from data store") frame_bytes = frame_data.data - nparr = np.frombuffer(frame_bytes, dtype = np.uint8) + nparr = np.frombuffer(frame_bytes, dtype=np.uint8) height = frame_data.height width = frame_data.width @@ -93,10 +100,10 @@ async def run_loop(self): await self.process_frame(frame, ComputationType.OBJECT_DETECTION) async def process_frame(self, frame: np.ndarray, computation_type: ComputationType): - ''' + """ Send frames to onboard compute engine for processing. Currently only supports object detection, and prints detected classes. - ''' + """ request = onboard_compute_pb2.ComputeRequest() request.frame_data = cv2.cvtColor(frame, cv2.COLOR_BGR2YUV_YUYV).tobytes() request.frame_width = frame.shape[1] @@ -111,15 +118,14 @@ async def process_frame(self, frame: np.ndarray, computation_type: ComputationTy reply = None (self.socket, reply) = await lazy_pirate_request( - self.socket, request.SerializeToString(), self.context, - self.server_endpoint) + self.socket, request.SerializeToString(), self.context, self.server_endpoint + ) - if reply == None: - logger.error(f"Local compute engine did not respond to request") + if reply is None: + logger.error("Local compute engine did not respond to request") return - logger.info(f"Received response from local compute engine") + logger.info("Received response from local compute engine") detections = onboard_compute_pb2.ComputeResult() detections.ParseFromString(reply) logger.info(f"Received detections: {detections}") - diff --git a/os/test/cprofile_reader.py b/os/test/cprofile_reader.py index 4bfccb31..9466dcc5 100644 --- a/os/test/cprofile_reader.py +++ b/os/test/cprofile_reader.py @@ -4,9 +4,11 @@ if __name__ == "__main__": import argparse - parser = argparse.ArgumentParser(description='cprofile reader') - parser.add_argument('-p', '--path', default = './drivers/olympe/driver_test.out', help='the path to workspace') + parser = argparse.ArgumentParser(description="cprofile reader") + parser.add_argument( + "-p", "--path", default="./drivers/olympe/driver_test.out", help="the path to workspace" + ) args = parser.parse_args() path = args.path p = pstats.Stats(path) - p.sort_stats(SortKey.TIME).print_stats(10) \ No newline at end of file + p.sort_stats(SortKey.TIME).print_stats(10) diff --git a/os/test/dummy_commander.py b/os/test/dummy_commander.py index 59bedd7d..23068d77 100644 --- a/os/test/dummy_commander.py +++ b/os/test/dummy_commander.py @@ -1,52 +1,58 @@ +import asyncio import os -import zmq import time -from cnc_protocol import cnc_pb2 -import asyncio + +import zmq import zmq.asyncio +from cnc_protocol import cnc_pb2 from util.utils import setup_socket context = zmq.Context() # Create socket endpoints for driver cmd_front_sock = context.socket(zmq.DEALER) -cmdr_identity = b'cmdr' +cmdr_identity = b"cmdr" cmd_front_sock.setsockopt(zmq.IDENTITY, cmdr_identity) -setup_socket(cmd_front_sock, 'connect', 'CMD_FRONT_PORT', 'Created command frontend socket endpoint', os.environ.get("LOCALHOST")) - +setup_socket( + cmd_front_sock, + "connect", + "CMD_FRONT_PORT", + "Created command frontend socket endpoint", + os.environ.get("LOCALHOST"), +) -class c_client(): +class c_client: def send_takeOff(self): driver_command = cnc_pb2.Extras() driver_command.cmd.takeoff = True message = driver_command.SerializeToString() cmd_front_sock.send_multipart([message]) print(f"commander: take off sent at: {time.time()}") - + def send_land(self): driver_command = cnc_pb2.Extras() driver_command.cmd.land = True message = driver_command.SerializeToString() - cmd_front_socket.send_multipart([message]) + cmd_front_sock.send_multipart([message]) print(f"commander: land sent at: {time.time()}") - + def send_MCOM(self, key): driver_command = cnc_pb2.Extras() match key: - case 'w': + case "w": driver_command.cmd.pcmd.pitch = 25 - case 's': + case "s": driver_command.cmd.pcmd.pitch = -25 - case 'a': + case "a": driver_command.cmd.pcmd.roll = 25 - case 'd': + case "d": driver_command.cmd.pcmd.roll = -25 - case 'f': + case "f": pass message = driver_command.SerializeToString() cmd_front_sock.send_multipart([message]) - print(f"commander: manual command of \'{key}\' sent at: {time.time()}") + print(f"commander: manual command of '{key}' sent at: {time.time()}") def send_startM(self): driver_command = cnc_pb2.Extras() @@ -54,18 +60,18 @@ def send_startM(self): message = driver_command.SerializeToString() cmd_front_sock.send_multipart([message]) print(f"commander: mission sent at: {time.time()}") - + async def a_run(self): # Interactive command input loop - MCOM_SET = ['w', 'a', 's', 'd', 'i', 'j', 'k', 'l', 'f'] + MCOM_SET = ["w", "a", "s", "d", "i", "j", "k", "l", "f"] while True: user_input = input() - - if user_input == 't': + + if user_input == "t": self.send_takeOff() - elif user_input == 'g': + elif user_input == "g": self.send_land() - elif user_input == 'm': + elif user_input == "m": self.send_startM() elif user_input in MCOM_SET: self.send_MCOM(user_input) @@ -74,9 +80,10 @@ async def a_run(self): break else: print("Invalid command.") - + await asyncio.sleep(0) + if __name__ == "__main__": print("Starting client") k = c_client() diff --git a/os/test/dummy_driver.py b/os/test/dummy_driver.py index cbdfda7c..5c9733ac 100644 --- a/os/test/dummy_driver.py +++ b/os/test/dummy_driver.py @@ -1,6 +1,7 @@ import asyncio import os import time + import zmq import zmq.asyncio from cnc_protocol import cnc_pb2 @@ -12,24 +13,37 @@ cam_sock = context.socket(zmq.PUB) tel_sock.setsockopt(zmq.CONFLATE, 1) cam_sock.setsockopt(zmq.CONFLATE, 1) -setup_socket(tel_sock, 'connect', 'TEL_PORT', 'Created telemetry socket endpoint', os.environ.get("LOCALHOST")) -setup_socket(cam_sock, 'connect', 'CAM_PORT', 'Created camera socket endpoint', os.environ.get("LOCALHOST")) -setup_socket(cmd_back_sock, 'connect', 'CMD_BACK_PORT', 'Created command backend socket endpoint', os.environ.get("LOCALHOST")) - - +setup_socket( + tel_sock, + "connect", + "TEL_PORT", + "Created telemetry socket endpoint", + os.environ.get("LOCALHOST"), +) +setup_socket( + cam_sock, "connect", "CAM_PORT", "Created camera socket endpoint", os.environ.get("LOCALHOST") +) +setup_socket( + cmd_back_sock, + "connect", + "CMD_BACK_PORT", + "Created command backend socket endpoint", + os.environ.get("LOCALHOST"), +) async def camera_stream(drone, camera_sock): frame_id = 0 - cam_message = cnc_pb2.Frame() + cam_message = cnc_pb2.Frame() # while drone.isConnected(): while True: try: x = 1 - except Exception as e: + except Exception: pass await asyncio.sleep(0.033) + async def telemetry_stream(drone, telemetry_sock): tel_message = cnc_pb2.Telemetry() # while drone.isConnected(): @@ -37,39 +51,39 @@ async def telemetry_stream(drone, telemetry_sock): try: x = 2 except Exception as e: - print(f'Failed to get telemetry, error: {e}') + print(f"Failed to get telemetry, error: {e}") await asyncio.sleep(0) - -class d_server(): + +class d_server: async def a_run(self): asyncio.create_task(telemetry_stream(None, tel_sock)) - asyncio.create_task(camera_stream(None, cam_sock)) + asyncio.create_task(camera_stream(None, cam_sock)) while True: try: # Receive a message from the DEALER socket message_parts = await cmd_back_sock.recv_multipart() - + # Expecting three parts: [identity, empty, message] if len(message_parts) != 2: print(f"Invalid message received: {message_parts}") continue - + identity = message_parts[0] # Identity of the DEALER socket - message = message_parts[1] # The empty delimiter part + message = message_parts[1] # The empty delimiter part # message = message_parts[2] # The actual serialized request - + # Print each part to understand the structure print(f"Identity: {identity}") # print(f"Empty delimiter: {empty}") print(f"Message: {message}") - + # Parse the message driver_req = cnc_pb2.Driver() driver_req.ParseFromString(message) print(f"Received the message: {driver_req}") print(f"Request seqNum: {driver_req.seqNum}") - + # Print parsed message and determine the response if driver_req.takeOff: print(f"take off received: {time.time()}") @@ -78,17 +92,18 @@ async def a_run(self): else: print("Unknown request") driver_req.resp = cnc_pb2.ResponseStatus.NOTSUPPORTED - + serialized_response = driver_req.SerializeToString() - + # Send a reply back to the client with the identity frame and empty delimiter cmd_back_sock.send_multipart([identity, serialized_response]) - - print(f"done processing request") + + print("done processing request") except Exception as e: print(f"error: {e}") - + + if __name__ == "__main__": print("Starting server") - d= d_server() - asyncio.run(d.a_run()) \ No newline at end of file + d = d_server() + asyncio.run(d.a_run()) diff --git a/os/test/dummy_kernel_cmd.py b/os/test/dummy_kernel_cmd.py index 926ce9d0..ee56d367 100644 --- a/os/test/dummy_kernel_cmd.py +++ b/os/test/dummy_kernel_cmd.py @@ -1,19 +1,20 @@ +import asyncio +import logging import os -import zmq +import sys import time -from cnc_protocol import cnc_pb2 -import asyncio + +import zmq import zmq.asyncio +from cnc_protocol import cnc_pb2 from util.utils import SocketOperation, setup_socket -import logging -import sys # Configure logger logger = logging.getLogger() logger.setLevel(logging.INFO) handler = logging.StreamHandler(sys.stdout) handler.setLevel(logging.INFO) -formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') +formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") handler.setFormatter(formatter) logger.addHandler(handler) @@ -36,27 +37,41 @@ # Setting up conetxt context = zmq.asyncio.Context() -drone_id = os.environ.get('DRONE_ID') +drone_id = os.environ.get("DRONE_ID") # Setting up sockets cmd_front_cmdr_sock = context.socket(zmq.DEALER) -cmd_front_cmdr_sock.setsockopt(zmq.IDENTITY, drone_id.encode('utf-8')) +cmd_front_cmdr_sock.setsockopt(zmq.IDENTITY, drone_id.encode("utf-8")) cmd_front_usr_sock = context.socket(zmq.DEALER) cmd_back_sock = context.socket(zmq.DEALER) msn_sock = context.socket(zmq.REQ) -setup_socket(cmd_front_cmdr_sock, SocketOperation.CONNECT, 'CMD_FRONT_CMDR_PORT', 'Connected command frontend cmdr socket endpoint', os.environ.get('STEELEAGLE_GABRIEL_SERVER')) -setup_socket(cmd_front_usr_sock, SocketOperation.BIND, 'CMD_FRONT_USR_PORT', 'Created command frontend user socket endpoint') -setup_socket(cmd_back_sock, SocketOperation.BIND, 'CMD_BACK_PORT', 'Created command backend socket endpoint') -setup_socket(msn_sock, SocketOperation.BIND, 'MSN_PORT', 'Created userspace mission control socket endpoint') - +setup_socket( + cmd_front_cmdr_sock, + SocketOperation.CONNECT, + "CMD_FRONT_CMDR_PORT", + "Connected command frontend cmdr socket endpoint", + os.environ.get("STEELEAGLE_GABRIEL_SERVER"), +) +setup_socket( + cmd_front_usr_sock, + SocketOperation.BIND, + "CMD_FRONT_USR_PORT", + "Created command frontend user socket endpoint", +) +setup_socket( + cmd_back_sock, SocketOperation.BIND, "CMD_BACK_PORT", "Created command backend socket endpoint" +) +setup_socket( + msn_sock, SocketOperation.BIND, "MSN_PORT", "Created userspace mission control socket endpoint" +) -class k_client(): +class k_client: # Function to send a start mission command def send_start_mission(self): mission_command = cnc_pb2.Mission() mission_command.startMission = True message = mission_command.SerializeToString() - print(f'start_mission message:{message}') + print(f"start_mission message:{message}") msn_sock.send(message) reply = msn_sock.recv_string() print(f"Server reply: {reply}") @@ -69,21 +84,20 @@ def send_stop_mission(self): msn_sock.send(message) reply = msn_sock.recv_string() print(f"Server reply: {reply}") - - + async def send_takeOff(self): driver_command = cnc_pb2.Driver() driver_command.takeOff = True message = driver_command.SerializeToString() print(f"take off sent at: {time.time()}") await cmd_back_sock.send_multipart([message]) - + async def proxy(self): - print('user_driver_cmd_proxy started') + print("user_driver_cmd_proxy started") poller = zmq.asyncio.Poller() poller.register(cmd_front_cmdr_sock, zmq.POLLIN) poller.register(cmd_back_sock, zmq.POLLIN) - + while True: try: print("polling") @@ -98,14 +112,13 @@ async def proxy(self): identity = message[0] cmd = message[1] print(f"proxy : 2 Received message from BACKEND: identity: {identity}") - - if identity == b'cmdr': + + if identity == b"cmdr": await self.send_takeOff() - elif identity == b'usr': + elif identity == b"usr": await cmd_back_sock.send_multipart(message) else: - print(f"cmd_proxy: invalid identity") - + print("cmd_proxy: invalid identity") # Check for messages on the DEALER socket if cmd_back_sock in socks: @@ -116,27 +129,28 @@ async def proxy(self): identity = message[0] cmd = message[1] print(f"proxy : 4 Received message from FRONTEND: identity: {identity}") - - if identity == b'cmdr': - print(f"proxy : 5 Received message from FRONTEND: discard bc of cmdr") + + if identity == b"cmdr": + print("proxy : 5 Received message from FRONTEND: discard bc of cmdr") pass - elif identity == b'usr': - print(f"proxy : 5 Received message from FRONTEND: sent back bc of user") + elif identity == b"usr": + print("proxy : 5 Received message from FRONTEND: sent back bc of user") await cmd_front_cmdr_sock.send_multipart(message) else: - print(f"cmd_proxy: invalid identity") - + print("cmd_proxy: invalid identity") + except Exception as e: print(f"Proxy error: {e}") async def a_run(self): # Interactive command input loop asyncio.create_task(self.proxy()) - + while True: await asyncio.sleep(0) + if __name__ == "__main__": print("Starting client") k = k_client() - asyncio.run(k.a_run()) \ No newline at end of file + asyncio.run(k.a_run()) diff --git a/os/test/dummy_kernel_data.py b/os/test/dummy_kernel_data.py index af404b8f..0a5b8808 100644 --- a/os/test/dummy_kernel_data.py +++ b/os/test/dummy_kernel_data.py @@ -1,19 +1,19 @@ -import os -import zmq -import time -from cnc_protocol import cnc_pb2 import asyncio +import logging +import sys +import time + +import zmq import zmq.asyncio +from cnc_protocol import cnc_pb2 from util.utils import setup_socket -import logging -import sys # Configure logger logger = logging.getLogger() logger.setLevel(logging.INFO) handler = logging.StreamHandler(sys.stdout) handler.setLevel(logging.INFO) -formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') +formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") handler.setFormatter(formatter) logger.addHandler(handler) @@ -24,25 +24,24 @@ cmd_front_sock = context.socket(zmq.ROUTER) cmd_back_sock = context.socket(zmq.DEALER) msn_sock = context.socket(zmq.REQ) -tel_sock.setsockopt(zmq.SUBSCRIBE, b'') # Subscribe to all topics +tel_sock.setsockopt(zmq.SUBSCRIBE, b"") # Subscribe to all topics tel_sock.setsockopt(zmq.CONFLATE, 1) -cam_sock.setsockopt(zmq.SUBSCRIBE, b'') # Subscribe to all topics +cam_sock.setsockopt(zmq.SUBSCRIBE, b"") # Subscribe to all topics cam_sock.setsockopt(zmq.CONFLATE, 1) -setup_socket(tel_sock, 'bind', 'TEL_PORT', 'Created telemetry socket endpoint') -setup_socket(cam_sock, 'bind', 'CAM_PORT', 'Created camera socket endpoint') -setup_socket(cmd_front_sock, 'bind', 'CMD_FRONT_PORT', 'Created command frontend socket endpoint') -setup_socket(cmd_back_sock, 'bind', 'CMD_BACK_PORT', 'Created command backend socket endpoint') -setup_socket(msn_sock, 'bind', 'MSN_PORT', 'Created user space mission control socket endpoint') - +setup_socket(tel_sock, "bind", "TEL_PORT", "Created telemetry socket endpoint") +setup_socket(cam_sock, "bind", "CAM_PORT", "Created camera socket endpoint") +setup_socket(cmd_front_sock, "bind", "CMD_FRONT_PORT", "Created command frontend socket endpoint") +setup_socket(cmd_back_sock, "bind", "CMD_BACK_PORT", "Created command backend socket endpoint") +setup_socket(msn_sock, "bind", "MSN_PORT", "Created user space mission control socket endpoint") -class k_client(): +class k_client: # Function to send a start mission command def send_start_mission(self): mission_command = cnc_pb2.Mission() mission_command.startMission = True message = mission_command.SerializeToString() - print(f'start_mission message:{message}') + print(f"start_mission message:{message}") msn_sock.send(message) reply = msn_sock.recv_string() print(f"Server reply: {reply}") @@ -55,21 +54,20 @@ def send_stop_mission(self): msn_sock.send(message) reply = msn_sock.recv_string() print(f"Server reply: {reply}") - - + async def send_takeOff(self): driver_command = cnc_pb2.Driver() driver_command.takeOff = True message = driver_command.SerializeToString() print(f"take off sent at: {time.time()}") await cmd_back_sock.send_multipart([message]) - + async def proxy(self): - print('user_driver_cmd_proxy started') + print("user_driver_cmd_proxy started") poller = zmq.asyncio.Poller() poller.register(cmd_front_sock, zmq.POLLIN) poller.register(cmd_back_sock, zmq.POLLIN) - + while True: try: print("polling") @@ -84,14 +82,13 @@ async def proxy(self): identity = message[0] cmd = message[1] print(f"proxy : 2 Received message from BACKEND: identity: {identity}") - - if identity == b'cmdr': + + if identity == b"cmdr": await self.send_takeOff() - elif identity == b'usr': + elif identity == b"usr": await cmd_back_sock.send_multipart(message) else: - print(f"cmd_proxy: invalid identity") - + print("cmd_proxy: invalid identity") # Check for messages on the DEALER socket if cmd_back_sock in socks: @@ -102,27 +99,28 @@ async def proxy(self): identity = message[0] cmd = message[1] print(f"proxy : 4 Received message from FRONTEND: identity: {identity}") - - if identity == b'cmdr': - print(f"proxy : 5 Received message from FRONTEND: discard bc of cmdr") + + if identity == b"cmdr": + print("proxy : 5 Received message from FRONTEND: discard bc of cmdr") pass - elif identity == b'usr': - print(f"proxy : 5 Received message from FRONTEND: sent back bc of user") + elif identity == b"usr": + print("proxy : 5 Received message from FRONTEND: sent back bc of user") await cmd_front_sock.send_multipart(message) else: - print(f"cmd_proxy: invalid identity") - + print("cmd_proxy: invalid identity") + except Exception as e: print(f"Proxy error: {e}") async def a_run(self): # Interactive command input loop asyncio.create_task(self.proxy()) - + while True: await asyncio.sleep(0) + if __name__ == "__main__": print("Starting client") k = k_client() - asyncio.run(k.a_run()) \ No newline at end of file + asyncio.run(k.a_run()) diff --git a/os/test/kernel/dummy_gabriel_client.py b/os/test/kernel/dummy_gabriel_client.py index 145759d9..a0e83884 100644 --- a/os/test/kernel/dummy_gabriel_client.py +++ b/os/test/kernel/dummy_gabriel_client.py @@ -1,23 +1,25 @@ import asyncio import os import sys -from gabriel_protocol import gabriel_pb2 -from gabriel_client.websocket_client import ProducerWrapper, WebsocketClient -from cnc_protocol import cnc_pb2 + import nest_asyncio +from cnc_protocol import cnc_pb2 +from gabriel_client.websocket_client import ProducerWrapper, WebsocketClient +from gabriel_protocol import gabriel_pb2 + nest_asyncio.apply() + class Dummy: - def __init__(self, gabriel_server, gabriel_port): self.gabriel_server = gabriel_server self.gabriel_port = gabriel_port self.heartbeats = 0 self.drone_id = "test_drone" - + def processResults(self, result_wrapper): - if result_wrapper.result_producer_name.value == 'telemetry': - print(f'Telemetry received: {result_wrapper}') + if result_wrapper.result_producer_name.value == "telemetry": + print(f"Telemetry received: {result_wrapper}") def get_producer_wrappers(self): async def producer(): @@ -25,7 +27,7 @@ async def producer(): self.heartbeats += 1 input_frame = gabriel_pb2.InputFrame() input_frame.payload_type = gabriel_pb2.PayloadType.TEXT - input_frame.payloads.append('heartbeart'.encode('utf8')) + input_frame.payloads.append(b"heartbeart") extras = cnc_pb2.Extras() # test @@ -35,20 +37,22 @@ async def producer(): if self.heartbeats == 1: extras.registering = True - print('Producing Gabriel frame!') + print("Producing Gabriel frame!") input_frame.extras.Pack(extras) return input_frame - return ProducerWrapper(producer=producer, source_name='telemetry') - + return ProducerWrapper(producer=producer, source_name="telemetry") + async def run(self): - print('Creating client') + print("Creating client") gabriel_client = WebsocketClient( - self.gabriel_server, self.gabriel_port, - [self.get_producer_wrappers()], self.processResults + self.gabriel_server, + self.gabriel_port, + [self.get_producer_wrappers()], + self.processResults, ) - print('client created') - + print("client created") + try: # command_coroutine = asyncio.create_task(self.command_handler()) # telemetry_coroutine = asyncio.create_task(self.telemetry_handler()) @@ -56,7 +60,7 @@ async def run(self): # while True: # # logger.info('Running Kernel') # await asyncio.sleep(0) - + except KeyboardInterrupt: print("Shutting down Kernel") # command_coroutine.cancel() @@ -64,14 +68,14 @@ async def run(self): # await command_coroutine # await telemetry_coroutine sys.exit(0) - + + if __name__ == "__main__": print("Starting Kernel") - - gabriel_server = os.environ.get('STEELEAGLE_GABRIEL_SERVER') - print(f'Gabriel server: {gabriel_server}') - gabriel_port = os.environ.get('STEELEAGLE_GABRIEL_PORT') - print(f'Gabriel port: {gabriel_port}') + + gabriel_server = os.environ.get("STEELEAGLE_GABRIEL_SERVER") + print(f"Gabriel server: {gabriel_server}") + gabriel_port = os.environ.get("STEELEAGLE_GABRIEL_PORT") + print(f"Gabriel port: {gabriel_port}") k = Dummy(gabriel_server, gabriel_port) asyncio.run(k.run()) - \ No newline at end of file diff --git a/os/test/kernel/dummy_telemetry_handler.py b/os/test/kernel/dummy_telemetry_handler.py index c9625e71..d740a1cb 100644 --- a/os/test/kernel/dummy_telemetry_handler.py +++ b/os/test/kernel/dummy_telemetry_handler.py @@ -1,67 +1,67 @@ import asyncio import os + import zmq from cnc_protocol import cnc_pb2 # Create a pub/sub socket that telemetry can be read from context = zmq.Context() telemetry_socket = context.socket(zmq.SUB) -telemetry_socket.setsockopt(zmq.SUBSCRIBE, b'') -addr = 'tcp://' + os.environ.get('STEELEAGLE_DRIVER_TEL_SUB_ADDR') -print(f'Telemetry address: {addr}') +telemetry_socket.setsockopt(zmq.SUBSCRIBE, b"") +addr = "tcp://" + os.environ.get("STEELEAGLE_DRIVER_TEL_SUB_ADDR") +print(f"Telemetry address: {addr}") if addr: telemetry_socket.connect(addr) - print('Connected to telemetry publish endpoint') + print("Connected to telemetry publish endpoint") else: - print('Cannot get telemetry publish endpoint from system') + print("Cannot get telemetry publish endpoint from system") quit() - + + class Dummy: def __init__(self): self.telemetry_socket = telemetry_socket self.telemetry_cache = { - "location": { - "latitude": None, - "longitude": None, - "altitude": None - }, + "location": {"latitude": None, "longitude": None, "altitude": None}, "battery": None, "magnetometer": None, - "bearing": None + "bearing": None, } - + async def telemetry_handler(self): - print('Telemetry handler started') - + print("Telemetry handler started") + while True: try: - print(f'Telemetry Handler: Waiting for telemetry') + print("Telemetry Handler: Waiting for telemetry") msg = self.telemetry_socket.recv(flags=zmq.NOBLOCK) - print(f'Telemetry Handler: Received telemetry') + print("Telemetry Handler: Received telemetry") telemetry = cnc_pb2.Telemetry() telemetry.ParseFromString(msg) - print(f'Telemetry Handler: {telemetry}') - self.telemetry_cache['location']['latitude'] = telemetry.global_position.latitude - self.telemetry_cache['location']['longitude'] = telemetry.global_position.longitude - self.telemetry_cache['location']['altitude'] = telemetry.global_position.altitude - self.telemetry_cache['battery'] = telemetry.battery - self.telemetry_cache['magnetometer'] = telemetry.mag - self.telemetry_cache['bearing'] = telemetry.drone_attitude.yaw - - print(f'Telemetry Handler: Latitude: {telemetry.global_position.latitude} Longitude: {telemetry.global_position.longitude} Altitude: {telemetry.global_position.altitude}') - print(f'Telemetry Handler: Battery: {telemetry.battery}') - print(f'Telemetry Handler: Magnetometer: {telemetry.mag}') - print(f'Telemetry Handler: Bearing: {telemetry.drone_attitude.yaw}') + print(f"Telemetry Handler: {telemetry}") + self.telemetry_cache["location"]["latitude"] = telemetry.global_position.latitude + self.telemetry_cache["location"]["longitude"] = telemetry.global_position.longitude + self.telemetry_cache["location"]["altitude"] = telemetry.global_position.altitude + self.telemetry_cache["battery"] = telemetry.battery + self.telemetry_cache["magnetometer"] = telemetry.mag + self.telemetry_cache["bearing"] = telemetry.drone_attitude.yaw + + print( + f"Telemetry Handler: Latitude: {telemetry.global_position.latitude} Longitude: {telemetry.global_position.longitude} Altitude: {telemetry.global_position.altitude}" + ) + print(f"Telemetry Handler: Battery: {telemetry.battery}") + print(f"Telemetry Handler: Magnetometer: {telemetry.mag}") + print(f"Telemetry Handler: Bearing: {telemetry.drone_attitude.yaw}") except zmq.Again: - print('Telemetry handler no received telemetry') + print("Telemetry handler no received telemetry") pass - + except Exception as e: print(f"Telemetry Handler Exception: {e}") - + await asyncio.sleep(0) + if __name__ == "__main__": - k = Dummy() - asyncio.run(k.telemetry_handler()) \ No newline at end of file + asyncio.run(k.telemetry_handler()) diff --git a/os/test/user/dummy_kernel_msn.py b/os/test/user/dummy_kernel_msn.py index cc63bf7c..a2886a7e 100644 --- a/os/test/user/dummy_kernel_msn.py +++ b/os/test/user/dummy_kernel_msn.py @@ -1,8 +1,9 @@ import asyncio import logging import sys -from cnc_protocol import cnc_pb2 + import zmq +from cnc_protocol import cnc_pb2 from util.utils import SocketOperation, setup_socket # Configure logger @@ -10,27 +11,29 @@ logger.setLevel(logging.INFO) handler = logging.StreamHandler(sys.stdout) handler.setLevel(logging.INFO) -formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') +formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") handler.setFormatter(formatter) logger.addHandler(handler) context = zmq.asyncio.Context() msn_sock = context.socket(zmq.REQ) -setup_socket(msn_sock, SocketOperation.BIND, 'MSN_PORT', 'Created userspace mission control socket endpoint') +setup_socket( + msn_sock, SocketOperation.BIND, "MSN_PORT", "Created userspace mission control socket endpoint" +) + class c_client: - # Function to send a start mission command def send_start_mission(self): mission_command = cnc_pb2.Mission() mission_command.startMission = True message = mission_command.SerializeToString() - print(f'start_mission message:{message}') + print(f"start_mission message:{message}") msn_sock.send(message) reply = msn_sock.recv_string() print(f"Server reply: {reply}") - + # Function to send a stop mission command def send_stop_mission(self): mission_command = cnc_pb2.Mission() @@ -39,22 +42,23 @@ def send_stop_mission(self): msn_sock.send(message) reply = msn_sock.recv_string() print(f"Server reply: {reply}") - + async def a_run(self): # Interactive command input loop - MCOM_SET = ['start', 'stop'] + MCOM_SET = ["start", "stop"] while True: user_input = input() - - if user_input == 'start': + + if user_input == "start": self.send_start_mission() - elif user_input == 'stop': + elif user_input == "stop": self.send_stop_mission() else: print("Invalid command.") - + await asyncio.sleep(0) + if __name__ == "__main__": print("Starting client") k = c_client() diff --git a/os/user/common/MissionController.py b/os/user/common/MissionController.py index 9ef0891e..13774d4b 100644 --- a/os/user/common/MissionController.py +++ b/os/user/common/MissionController.py @@ -1,25 +1,23 @@ - +import asyncio import importlib +import logging import os -import subprocess import shutil +import subprocess import sys from zipfile import ZipFile + import requests import zmq -import asyncio -import logging - -from util.utils import SocketOperation, setup_socket -from system_call_stubs.DroneStub import DroneStub -from system_call_stubs.ComputeStub import ComputeStub from cnc_protocol import cnc_pb2 +from system_call_stubs.ComputeStub import ComputeStub +from system_call_stubs.DroneStub import DroneStub +from util.utils import SocketOperation, setup_socket logger = logging.getLogger(__name__) - -class MissionController(): +class MissionController: def __init__(self, user_path): self.isTerminated = False self.tm = None @@ -28,24 +26,29 @@ def __init__(self, user_path): self.reload = False self.user_path = user_path logger.info("Mission Controller created") - + context = zmq.Context() self.msn_sock = context.socket(zmq.REP) - setup_socket(self.msn_sock, SocketOperation.CONNECT, 'MSN_PORT', 'Connected to user space mission control socket endpoint', os.environ.get("CMD_ENDPOINT")) - + setup_socket( + self.msn_sock, + SocketOperation.CONNECT, + "MSN_PORT", + "Connected to user space mission control socket endpoint", + os.environ.get("CMD_ENDPOINT"), + ) + ######################################################## MISSION ############################################################ def install_prereqs(self) -> bool: ret = False # Pip install prerequsites for flight script - requirements_path = os.path.join(self.user_path, 'requirements.txt') + requirements_path = os.path.join(self.user_path, "requirements.txt") try: - subprocess.check_call(['python3', '-m', 'pip', 'install', '-r', requirements_path]) + subprocess.check_call(["python3", "-m", "pip", "install", "-r", requirements_path]) ret = True except subprocess.CalledProcessError as e: logger.debug(f"Error pip installing requirements.txt: {e}") return ret - - + def clean_user_path(self): for filename in os.listdir(self.user_path): file_path = os.path.join(self.user_path, filename) @@ -54,18 +57,17 @@ def clean_user_path(self): else: os.remove(file_path) # Remove files - def download_script(self, url): # Download zipfile and extract reqs/flight script from cloudlet try: - filename = url.rsplit(sep='/')[-1] - logger.info(f'Writing {filename} to disk...') - + filename = url.rsplit(sep="/")[-1] + logger.info(f"Writing {filename} to disk...") + # Download the file r = requests.get(url, stream=True) r.raise_for_status() # Raise an error for bad responses - - with open(filename, mode='wb') as f: + + with open(filename, mode="wb") as f: for chunk in r.iter_content(chunk_size=8192): # Use a chunk size f.write(chunk) @@ -83,54 +85,57 @@ def download_script(self, url): logger.info(f"Downloaded and extracted {filename} to {self.user_path}") self.install_prereqs() - + except Exception as e: logger.error(f"An unexpected error occurred: {e}") - + def download_mission(self, url): self.download_script(url) def reload_mission(self): - logger.info('Reloading...') + logger.info("Reloading...") modules = sys.modules.copy() for module_name, module in modules.items(): logger.info(f"Module name: {module_name}") - if module_name.startswith('project.task_defs') or module_name.startswith('project.Mission') or module_name.startswith('project.transition_defs'): + if ( + module_name.startswith("project.task_defs") + or module_name.startswith("project.Mission") + or module_name.startswith("project.transition_defs") + ): try: # Log and reload the module logger.info(f"Reloading module: {module_name}") importlib.reload(module) except Exception as e: logger.error(f"Failed to reload module {module_name}: {e}") - + def start_mission(self): if self.tm: - logger.info(f"mission already running") + logger.info("mission already running") return - else: # first time mission, create a task manager - import common.TaskManager as tm - - logger.info(f"start the mission") - if self.reload : + else: # first time mission, create a task manager + import common.TaskManager as tm + + logger.info("start the mission") + if self.reload: self.reload_mission() - - import project.Mission as msn # import the mission module instead of attribute of the module for the reload to work - self.reload = True - + + import project.Mission as msn # import the mission module instead of attribute of the module for the reload to work + + self.reload = True + msn.Mission.define_mission(self.transitMap, self.task_arg_map) - + # start the tm - logger.info(f"start the task manager") + logger.info("start the task manager") self.tm = tm.TaskManager(self.drone, self.compute, self.transitMap, self.task_arg_map) self.tm_coroutine = asyncio.create_task(self.tm.run()) - - - + async def end_mission(self): if self.tm and not self.tm_coroutine.cancelled(): self.tm_coroutine.cancel() try: - await self.tm_coroutine + await self.tm_coroutine except asyncio.CancelledError: logger.info("Mission coroutine was cancelled successfully.") except Exception as e: @@ -141,54 +146,52 @@ async def end_mission(self): logger.info("Mission has been ended and cleaned up.") else: logger.info("Mission not running or already cancelled.") - - ######################################################## MAIN LOOP ############################################################ + + ######################################################## MAIN LOOP ############################################################ async def run(self): self.drone = DroneStub() self.compute = ComputeStub() asyncio.create_task(self.drone.run()) asyncio.create_task(self.compute.run()) - + # self.compute = ComputeStub() while True: logger.debug("MC") try: # Receive a message message = self.msn_sock.recv(flags=zmq.NOBLOCK) - + # Log the raw received message logger.info(f"Received raw message: {message}") - + # Parse the message mission_command = cnc_pb2.Mission() mission_command.ParseFromString(message) logger.info(f"Parsed Command: {mission_command}") - + if mission_command.downloadMission: self.download_mission(mission_command.downloadMission) response = "Mission downloaded" - + elif mission_command.startMission: self.start_mission() response = "Mission started" - + elif mission_command.stopMission: await self.end_mission() response = "Mission stopped" - + else: response = "Unknown command" # Send a reply back to the client self.msn_sock.send_string(response) - + except zmq.Again: pass - + except Exception as e: logger.info(f"Failed to parse message: {e}") self.msn_sock.send_string("Error processing command") - + await asyncio.sleep(0) - - \ No newline at end of file diff --git a/os/user/common/TaskManager.py b/os/user/common/TaskManager.py index 23af6d2c..066c9240 100644 --- a/os/user/common/TaskManager.py +++ b/os/user/common/TaskManager.py @@ -1,17 +1,17 @@ import asyncio -import interface.Task as taskitf -import project.task_defs.TrackTask as track -import project.task_defs.DetectTask as detect +import logging +import queue + +import interface.Task as taskitf import project.task_defs.AvoidTask as avoid +import project.task_defs.DetectTask as detect import project.task_defs.TestTask as test -import queue -import logging +import project.task_defs.TrackTask as track logger = logging.getLogger(__name__) - -class TaskManager(): - + +class TaskManager: def __init__(self, drone, compute, transit_map, task_arg_map): super().__init__() self.trigger_event_queue = queue.Queue() @@ -21,82 +21,108 @@ def __init__(self, drone, compute, transit_map, task_arg_map): self.curr_task_id = None self.transit_map = transit_map self.task_arg_map = task_arg_map - - + self.currentTask = None self.taskCurrentCoroutinue = None ######################################################## TASK ############################################################# def get_current_task(self): return self.curr_task_id - + def retrieve_next_task(self, current_task_id, triggered_event): - logger.info(f"next task, current_task_id {current_task_id}, trigger_event {triggered_event}") + logger.info( + f"next task, current_task_id {current_task_id}, trigger_event {triggered_event}" + ) try: - next_task_id = self.transit_map.get(current_task_id)(triggered_event) + next_task_id = self.transit_map.get(current_task_id)(triggered_event) except Exception as e: logger.info(f"{e}") - + logger.info(f"next_task_id {next_task_id}") return next_task_id - + def transit_task_to(self, task): - logger.info(f"transit to task with task_id: {task.task_id}, current_task_id: {self.curr_task_id}") + logger.info( + f"transit to task with task_id: {task.task_id}, current_task_id: {self.curr_task_id}" + ) self.stop_task() self.start_task(task) self.curr_task_id = task.task_id - + async def init_task(self): - logger.info('init task') + logger.info("init task") self.start_task_id = self.retrieve_next_task("start", None) - logger.info('create start task') + logger.info("create start task") start_task = self.create_task(self.start_task_id) - if start_task != None: + if start_task is not None: # set the current task self.curr_task_id = start_task.task_id logger.info(f"start task, current taskid:{self.curr_task_id}\n") - - + # takeoff logger.info("taking off") await self.drone.takeOff() - + # start self.start_task(start_task) - + def create_task(self, task_id): - logger.info(f'taskid{task_id}') - if (task_id in self.task_arg_map.keys()): - if (self.task_arg_map[task_id].task_type == taskitf.TaskType.Detect): - logger.info('Detect task') - return detect.DetectTask(self.drone, self.compute, task_id, self.trigger_event_queue, self.task_arg_map[task_id]) - elif (self.task_arg_map[task_id].task_type == taskitf.TaskType.Track): - logger.info('Track task') - return track.TrackTask(self.drone, self.compute, task_id, self.trigger_event_queue, self.task_arg_map[task_id]) - elif (self.task_arg_map[task_id].task_type == taskitf.TaskType.Avoid): - logger.info('Avoid task') - return avoid.AvoidTask(self.drone, self.compute, task_id, self.trigger_event_queue, self.task_arg_map[task_id]) - elif (self.task_arg_map[task_id].task_type == taskitf.TaskType.Test): - logger.info('Test task') - return test.TestTask(self.drone, self.compute, task_id, self.trigger_event_queue, self.task_arg_map[task_id]) + logger.info(f"taskid{task_id}") + if task_id in self.task_arg_map: + if self.task_arg_map[task_id].task_type == taskitf.TaskType.Detect: + logger.info("Detect task") + return detect.DetectTask( + self.drone, + self.compute, + task_id, + self.trigger_event_queue, + self.task_arg_map[task_id], + ) + elif self.task_arg_map[task_id].task_type == taskitf.TaskType.Track: + logger.info("Track task") + return track.TrackTask( + self.drone, + self.compute, + task_id, + self.trigger_event_queue, + self.task_arg_map[task_id], + ) + elif self.task_arg_map[task_id].task_type == taskitf.TaskType.Avoid: + logger.info("Avoid task") + return avoid.AvoidTask( + self.drone, + self.compute, + task_id, + self.trigger_event_queue, + self.task_arg_map[task_id], + ) + elif self.task_arg_map[task_id].task_type == taskitf.TaskType.Test: + logger.info("Test task") + return test.TestTask( + self.drone, + self.compute, + task_id, + self.trigger_event_queue, + self.task_arg_map[task_id], + ) return None - + def stop_task(self): - logger.info(f'Stopping current task!') + logger.info("Stopping current task!") if self.taskCurrentCoroutinue: # stop all the transitions of the task self.currentTask.stop_trans() - logger.info(f'transitions in the current task stopped!') - + logger.info("transitions in the current task stopped!") + is_canceled = self.taskCurrentCoroutinue.cancel() if is_canceled: - logger.info(f' task cancelled successfully') - + logger.info(" task cancelled successfully") + def start_task(self, task): - logger.info(f'start the task! task: {str(task)}') + logger.info(f"start the task! task: {str(task)}") self.currentTask = task self.taskCurrentCoroutinue = asyncio.create_task(self.currentTask.run()) - logger.info(f'started the task! task: {str(task)}') + logger.info(f"started the task! task: {str(task)}") def pause_task(self): pass @@ -110,36 +136,34 @@ async def run(self): # start the mc logger.info("start the manager\n") await self.init_task() - + # main logger.info("go to the loop routine\n") while True: # logger.info("TM") # logger.info("loop routine\n") - if (not self.trigger_event_queue.empty()): + if not self.trigger_event_queue.empty(): item = self.trigger_event_queue.get() task_id = item[0] trigger_event = item[1] - logger.info(f"Trigger one event! \n") + logger.info("Trigger one event! \n") logger.info(f"Task id {task_id} \n") logger.info(f"event {trigger_event} \n") - if (task_id == self.get_current_task()): + if task_id == self.get_current_task(): next_task_id = self.retrieve_next_task(task_id, trigger_event) - if (next_task_id == "terminate"): + if next_task_id == "terminate": break else: next_task = self.create_task(next_task_id) logger.info(f"task created taskid {str(next_task.task_id)} \n") self.transit_task_to(next_task) - await asyncio.sleep(0) + await asyncio.sleep(0) except Exception as e: logger.info(f"catching the exception {e} \n") finally: # stop the current task self.stop_task() - #end the tr + # end the tr logger.info("terminate the manager\n") - - diff --git a/os/user/common/__main__.py b/os/user/common/__main__.py index 61ba9a3a..7f9c4835 100644 --- a/os/user/common/__main__.py +++ b/os/user/common/__main__.py @@ -1,21 +1,22 @@ -import asyncio +import asyncio import logging +import os import sys + from MissionController import MissionController -import os logger = logging.getLogger() logger.setLevel(logging.INFO) handler = logging.StreamHandler(sys.stdout) handler.setLevel(logging.INFO) -formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') +formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") handler.setFormatter(formatter) logger.addHandler(handler) logger.info("Starting the usr space") current_dir = os.path.dirname(os.path.abspath(__file__)) parent_dir = os.path.dirname(current_dir) -project_dir = os.path.join(parent_dir, 'project') +project_dir = os.path.join(parent_dir, "project") logger.info("proj_path: %s", project_dir) mc = MissionController(project_dir) diff --git a/os/user/interface/Task.py b/os/user/interface/Task.py index b4709c7d..5f44a42c 100644 --- a/os/user/interface/Task.py +++ b/os/user/interface/Task.py @@ -2,37 +2,39 @@ # # SPDX-License-Identifier: GPL-2.0-only -from abc import ABC, abstractmethod import functools import logging import threading +from abc import ABC, abstractmethod + from aenum import Enum -import inspect logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) + class TaskType(Enum): Detect = 1 Track = 2 Avoid = 3 Test = 4 -class TaskArguments(): + +class TaskArguments: def __init__(self, task_type, transitions_attributes, task_attributes): self.task_type = task_type self.task_attributes = task_attributes self.transitions_attributes = transitions_attributes - -class Task(ABC): + +class Task(ABC): def __init__(self, drone, cloudlet, task_id, trigger_event_queue, task_args): self.cloudlet = cloudlet self.drone = drone self.task_attributes = task_args.task_attributes self.transitions_attributes = task_args.transitions_attributes self.task_id = task_id - self.trans_active = [] + self.trans_active = [] self.trans_active_lock = threading.Lock() self.trigger_event_queue = trigger_event_queue @@ -43,25 +45,24 @@ async def run(self): def get_task_id(self): return self.task_id - def _exit(self): # kill all the transitions - logger.info(f"**************exit the task**************\n") + logger.info("**************exit the task**************\n") self.stop_trans() - self.trigger_event_queue.put((self.task_id, "done")) - + self.trigger_event_queue.put((self.task_id, "done")) + def stop_trans(self): - logger.info(f"**************stopping the transitions**************\n") + logger.info("**************stopping the transitions**************\n") for trans in self.trans_active: if trans.is_alive(): trans.stop() trans.join() - logger.info(f"**************the transitions stopped**************\n") - - + logger.info("**************the transitions stopped**************\n") + @classmethod def call_after_exit(cls, func): """Decorator to call _exit after the decorated function completes.""" + @functools.wraps(func) async def wrapper(self, *args, **kwargs): try: @@ -73,9 +74,11 @@ async def wrapper(self, *args, **kwargs): self._exit() return wrapper - + + @abstractmethod def pause(self): pass - + + @abstractmethod def resume(self): pass diff --git a/os/user/interface/Transition.py b/os/user/interface/Transition.py index a7978905..df90e7e5 100644 --- a/os/user/interface/Transition.py +++ b/os/user/interface/Transition.py @@ -1,37 +1,38 @@ -from abc import ABC, abstractmethod import logging import threading - +from abc import ABC, abstractmethod logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) + class Transition(threading.Thread, ABC): def __init__(self, args): super().__init__() - self.task_id = args['task_id'] - self.trans_active = args['trans_active'] - self.trans_active_lock = args['trans_active_lock'] - self.trigger_event_queue = args['trigger_event_queue'] + self.task_id = args["task_id"] + self.trans_active = args["trans_active"] + self.trans_active_lock = args["trans_active_lock"] + self.trigger_event_queue = args["trigger_event_queue"] # self.trigger_event_queue_lock = trigger_event_queue_lock - + @abstractmethod def stop(self): """This is an abstract method that must be implemented in a subclass.""" pass - + def _trigger_event(self, event): - logger.info(f"**************task id {self.task_id}: triggered event! {event}**************\n") + logger.info( + f"**************task id {self.task_id}: triggered event! {event}**************\n" + ) # with self.trigger_event_queue_lock: - self.trigger_event_queue.put((self.task_id, event)) - + self.trigger_event_queue.put((self.task_id, event)) + def _register(self): logger.info(f"**************{self.name} is registering by itself**************\n") with self.trans_active_lock: self.trans_active.append(self) - + def _unregister(self): logger.info(f"**************{self.name} is unregistering by itself**************\n") with self.trans_active_lock: self.trans_active.remove(self) - \ No newline at end of file diff --git a/os/user/system_call_stubs/ComputeStub.py b/os/user/system_call_stubs/ComputeStub.py index bb7052b8..f7649b77 100644 --- a/os/user/system_call_stubs/ComputeStub.py +++ b/os/user/system_call_stubs/ComputeStub.py @@ -5,45 +5,51 @@ import asyncio import logging import os + import zmq from cnc_protocol import cnc_pb2 -from util.utils import setup_socket -from util.utils import SocketOperation -from enum import Enum +from util.utils import SocketOperation, setup_socket logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) context = zmq.Context() cpt_usr_sock = context.socket(zmq.DEALER) -sock_identity = b'usr' +sock_identity = b"usr" cpt_usr_sock.setsockopt(zmq.IDENTITY, sock_identity) -setup_socket(cpt_usr_sock, SocketOperation.CONNECT, 'CPT_USR_PORT', 'Created command frontend socket endpoint', os.environ.get("DATA_ENDPOINT")) +setup_socket( + cpt_usr_sock, + SocketOperation.CONNECT, + "CPT_USR_PORT", + "Created command frontend socket endpoint", + os.environ.get("DATA_ENDPOINT"), +) + class ComputeRespond: - def __init__(self): self.event = asyncio.Event() self.permission = False self.result = None - + def putResult(self, result): self.result = result - + def getResult(self): return self.result - + async def wait(self): await self.event.wait() - - def set (self): + + def set(self): self.event.set() - -class ComputeStub(): + + +class ComputeStub: def __init__(self): - self.seqNum = 1 # set the initial seqNum to 1 caz cnc proto does not support to show 0 + self.seqNum = 1 # set the initial seqNum to 1 caz cnc proto does not support to show 0 self.seqNum_res = {} - + def sender(self, request, computeRespond): seqNum = self.seqNum logger.info(f"Sending request with seqNum: {seqNum}") @@ -52,7 +58,7 @@ def sender(self, request, computeRespond): self.seqNum_res[seqNum] = computeRespond serialized_request = request.SerializeToString() cpt_usr_sock.send_multipart([serialized_request]) - + def receiver(self, response_parts): if not response_parts: logger.error("Received empty response parts") @@ -65,7 +71,7 @@ def receiver(self, response_parts): data_rep = cnc_pb2.Compute() data_rep.ParseFromString(response) - if data_rep.seqNum == 0: # not the cpt reply + if data_rep.seqNum == 0: # not the cpt reply logger.info("Response does not look like Compute. Trying Driver parsing...") try: data_rep = cnc_pb2.Driver() @@ -88,10 +94,10 @@ def receiver(self, response_parts): elif status == cnc_pb2.ResponseStatus.COMPLETED: logger.info("STAGE 2: COMPLETED") - if hasattr(data_rep, 'getter') and hasattr(data_rep.getter, 'result'): + if hasattr(data_rep, "getter") and hasattr(data_rep.getter, "result"): logger.info("cpt rep") computeRespond.putResult(data_rep.getter.result) - elif hasattr(data_rep, 'getTelemetry'): + elif hasattr(data_rep, "getTelemetry"): logger.info("tel rep") computeRespond.putResult(data_rep.getTelemetry) else: @@ -99,7 +105,6 @@ def receiver(self, response_parts): computeRespond.set() - async def run(self): while True: try: @@ -111,32 +116,33 @@ async def run(self): logger.error(f"Failed to parse message: {e}") break await asyncio.sleep(0) - - '''Helper method to send a request and wait for a response''' + + """Helper method to send a request and wait for a response""" + async def send_and_wait(self, request): computeRespond = ComputeRespond() self.sender(request, computeRespond) - + await computeRespond.wait() return computeRespond.getResult() - + # Get results for a compute engine async def getResults(self, compute_type): logger.info(f"Getting results for compute type: {compute_type}") cpt_req = cnc_pb2.Compute() cpt_req.getter.compute_type = compute_type - + result = await self.send_and_wait(cpt_req) return result - + async def clearResults(self): logger.info("Clearing results") cpt_req = cnc_pb2.Compute() cpt_req.clear = True await self.send_and_wait(cpt_req) + """ Telemetry methods """ - ''' Telemetry methods ''' async def getTelemetry(self): logger.info("Getting telemetry") request = cnc_pb2.Driver(getTelemetry=cnc_pb2.Telemetry()) @@ -154,7 +160,7 @@ async def getTelemetry(self): logger.debug(f"Got telemetry: {result}\n") telDict["name"] = result.drone_name telDict["battery"] = result.battery - telDict["attitude"]["yaw"] = result.drone_attitude.yaw + telDict["attitude"]["yaw"] = result.drone_attitude.yaw telDict["attitude"]["pitch"] = result.drone_attitude.pitch telDict["attitude"]["roll"] = result.drone_attitude.roll telDict["satellites"] = result.satellites @@ -168,5 +174,5 @@ async def getTelemetry(self): logger.debug(f"finished receiving Telemetry: {telDict}") return telDict else: - logger.error("Failed to get telemetry") - return None \ No newline at end of file + logger.error("Failed to get telemetry") + return None diff --git a/os/user/system_call_stubs/DroneStub.py b/os/user/system_call_stubs/DroneStub.py index 38d2be48..aa5b9fb7 100644 --- a/os/user/system_call_stubs/DroneStub.py +++ b/os/user/system_call_stubs/DroneStub.py @@ -1,40 +1,47 @@ import asyncio import logging import os +from enum import Enum + import zmq from cnc_protocol import cnc_pb2 -from enum import Enum -from util.utils import setup_socket -from util.utils import SocketOperation +from util.utils import SocketOperation, setup_socket logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) context = zmq.Context() cmd_front_usr_sock = context.socket(zmq.DEALER) -sock_identity = b'usr' +sock_identity = b"usr" cmd_front_usr_sock.setsockopt(zmq.IDENTITY, sock_identity) -setup_socket(cmd_front_usr_sock, SocketOperation.CONNECT, 'CMD_FRONT_USR_PORT', 'Created command frontend socket endpoint', os.environ.get("CMD_ENDPOINT")) +setup_socket( + cmd_front_usr_sock, + SocketOperation.CONNECT, + "CMD_FRONT_USR_PORT", + "Created command frontend socket endpoint", + os.environ.get("CMD_ENDPOINT"), +) + -######################################################## DriverRespond ############################################################ +######################################################## DriverRespond ############################################################ class DriverRespond: def __init__(self): self.event = asyncio.Event() self.permission = False self.result = None - + def putResult(self, result): self.result = result - + def getResult(self): return self.result def set(self): self.event.set() - + def grantPermission(self): self.permission = True - + def checkPermission(self): return self.permission @@ -44,10 +51,9 @@ async def wait(self): ######################################################## DroneStub ############################################################ class DroneStub: - ######################################################## Common ############################################################ def __init__(self): - self.seqNum = 1 # set the initial seqNum to 1 caz cnc proto does not support to show 0 + self.seqNum = 1 # set the initial seqNum to 1 caz cnc proto does not support to show 0 self.seqNum_res = {} def sender(self, request, driverRespond): @@ -83,11 +89,11 @@ def receiver(self, response_parts): if status == cnc_pb2.ResponseStatus.FAILED: logger.error("STAGE 2: FAILED") elif status == cnc_pb2.ResponseStatus.COMPLETED: - logger.info("STAGE 2: COMPLETED") + logger.info("STAGE 2: COMPLETED") driverRespond.grantPermission() driverRespond.putResult(driver_rep) - driverRespond.set() - + driverRespond.set() + async def run(self): while True: try: @@ -99,17 +105,18 @@ async def run(self): logger.error(f"Failed to parse message: {e}") break await asyncio.sleep(0) - ######################################################## RPC ############################################################ - '''Helper method to send a request and wait for a response''' + """Helper method to send a request and wait for a response""" + async def send_and_wait(self, request): driverRespond = DriverRespond() self.sender(request, driverRespond) await driverRespond.wait() return driverRespond.getResult() if driverRespond.checkPermission() else None - - ''' Preemptive methods ''' + + """ Preemptive methods """ + async def takeOff(self): logger.info("takeOff") request = cnc_pb2.Driver(takeOff=True) @@ -127,14 +134,15 @@ async def rth(self): request = cnc_pb2.Driver(rth=True) result = await self.send_and_wait(request) return result.rth if result else False - + async def hover(self): logger.info("hover") request = cnc_pb2.Driver(hover=True) result = await self.send_and_wait(request) return result.hover if result else False - ''' Location methods ''' + """ Location methods """ + async def setHome(self, name, lat, lng, alt): logger.info("setHome") location = cnc_pb2.Location(name=name, latitude=lat, longitude=lng, altitude=alt) @@ -144,7 +152,7 @@ async def setHome(self, name, lat, lng, alt): async def getHome(self): pass - + # logger.info("getHome") # request = cnc_pb2.Driver(getHome=cnc_pb2.Location()) # result = await self.send_and_wait(request) @@ -153,55 +161,59 @@ async def getHome(self): # else: # return False - ''' Attitude methods ''' + """ Attitude methods """ + async def setAttitude(self, yaw, pitch, roll, thrust): logger.info("setAttitude") - attitude = cnc_pb2.Attitude(yaw = yaw, pitch = pitch, roll = roll, thrust = thrust) + attitude = cnc_pb2.Attitude(yaw=yaw, pitch=pitch, roll=roll, thrust=thrust) request = cnc_pb2.Driver(setAttitude=attitude) - + result = await self.send_and_wait(request) return result.setAttitude if result else False - - ''' Position methods ''' + + """ Position methods """ + async def setVelocity(self, forward_vel, right_vel, up_vel, angle_vel): logger.info("setVelocity") - velocity = cnc_pb2.Velocity(forward_vel=forward_vel, right_vel=right_vel, up_vel=up_vel, angle_vel=angle_vel) + velocity = cnc_pb2.Velocity( + forward_vel=forward_vel, right_vel=right_vel, up_vel=up_vel, angle_vel=angle_vel + ) request = cnc_pb2.Driver(setVelocity=velocity) result = await self.send_and_wait(request) return result.setVelocity if result else False - + async def setRelativePosition(self, forward, right, up, angle): logger.info("setRelativePosition") position = cnc_pb2.Position(forward=forward, right=right, up=up, angle=angle) request = cnc_pb2.Driver(setRelativePosition=position) result = await self.send_and_wait(request) return result.setRelativePosition if result else False - - + async def setTranslatedPosition(self, forward, right, up, angle): logger.info("setTranslatedPosition") position = cnc_pb2.Position(forward=forward, right=right, up=up, angle=angle) request = cnc_pb2.Driver(setTranslatedPosition=position) result = await self.send_and_wait(request) return result.setTranslatedPosition if result else False - + async def setGPSLocation(self, latitude, longitude, altitude, bearing): logger.info("setGPSLocation") - location = cnc_pb2.Location(latitude=latitude, longitude=longitude, altitude=altitude, bearing=bearing) + location = cnc_pb2.Location( + latitude=latitude, longitude=longitude, altitude=altitude, bearing=bearing + ) request = cnc_pb2.Driver(setGPSLocation=location) result = await self.send_and_wait(request) return result.setGPSLocation if result else False - - - ''' Camera methods ''' + + """ Camera methods """ + # define a camera type enum class CameraType(Enum): RGB = cnc_pb2.CameraType.RGB STEREO = cnc_pb2.CameraType.STEREO THERMAL = cnc_pb2.CameraType.THERMAL NIGHT = cnc_pb2.CameraType.NIGHT - - + async def getCameras(self): logger.info("getCameras") request = cnc_pb2.Driver(getCameras=cnc_pb2.Camera()) @@ -209,14 +221,13 @@ async def getCameras(self): if result: id = result.getCameras.id type = self.CameraType(result.getCameras.type) - + return [id, type] - else: + else: return False async def switchCamera(self, camera_id): logger.info("switchCamera") request = cnc_pb2.Driver(switchCamera=camera_id) result = await self.send_and_wait(request) - return result.switchCameras if result else False - + return result.switchCameras if result else False diff --git a/os/util/timer.py b/os/util/timer.py index 3dde071f..d6165049 100644 --- a/os/util/timer.py +++ b/os/util/timer.py @@ -1,19 +1,20 @@ +import logging import time from dataclasses import dataclass, field -from typing import Optional -import logging + class TimerError(Exception): """A custom exception used to report errors in use of Timer class""" + @dataclass class Timer: logger: logging.Logger - name: Optional[str] = None + name: str | None = None text: str = "{} took {:0.4f} seconds" - max_frequency: Optional[int] = None - _start_time: Optional[float] = field(default=None, init=False, repr=False) - _last_log_time: Optional[int] = None + max_frequency: int | None = None + _start_time: float | None = field(default=None, init=False, repr=False) + _last_log_time: int | None = None def __enter__(self): self.start() @@ -25,14 +26,14 @@ def __exit__(self, exc_type, exc_value, traceback): def start(self) -> None: """Start a new timer""" if self._start_time is not None: - raise TimerError(f"Timer is running. Use .stop() to stop it") + raise TimerError("Timer is running. Use .stop() to stop it") self._start_time = time.perf_counter() def stop(self) -> float: """Stop the timer, and report the elapsed time""" if self._start_time is None: - raise TimerError(f"Timer is not running. Use .start() to start it") + raise TimerError("Timer is not running. Use .start() to start it") # Calculate elapsed time elapsed_time = time.perf_counter() - self._start_time diff --git a/os/util/utils.py b/os/util/utils.py index ea2b530d..52de9faf 100644 --- a/os/util/utils.py +++ b/os/util/utils.py @@ -1,25 +1,28 @@ -from enum import Enum import logging import os +from enum import Enum + import zmq import zmq.asyncio logger = logging.getLogger(__name__) + class SocketOperation(Enum): BIND = 1 CONNECT = 2 + def setup_socket(socket, socket_op, port_num, logger_message, host_addr="*"): # Get port number from environment variables port = os.environ.get(port_num, "") if not port: - logger.fatal(f'Cannot get {port_num} from system') + logger.fatal(f"Cannot get {port_num} from system") quit() # Construct the address - addr = f'tcp://{host_addr}:{port}' + addr = f"tcp://{host_addr}:{port}" if socket_op == SocketOperation.CONNECT: logger.info(f"Connecting socket to {addr=}") @@ -33,21 +36,21 @@ def setup_socket(socket, socket_op, port_num, logger_message, host_addr="*"): logger.info(logger_message) -async def lazy_pirate_request(socket, payload, ctx, server_endpoint, retries=3, - timeout=2500): + +async def lazy_pirate_request(socket, payload, ctx, server_endpoint, retries=3, timeout=2500): if retries <= 0: raise ValueError(f"Retries must be positive; {retries=}") # Send payload socket.send(payload) retries_left = retries - while retries_left == None or retries_left > 0: + while retries_left is None or retries_left > 0: # Check if reply received within timeout poll_result = await socket.poll(timeout) if (poll_result & zmq.POLLIN) != 0: reply = await socket.recv() return (socket, reply) - if retries_left != None: + if retries_left is not None: retries_left -= 1 logger.warning(f"Request timeout for {server_endpoint=}") @@ -65,4 +68,3 @@ async def lazy_pirate_request(socket, payload, ctx, server_endpoint, retries=3, logger.info(f"Resending payload to {server_endpoint=}...") socket.send(payload) - diff --git a/test.py b/test.py new file mode 100644 index 00000000..e69de29b