From fdca5a18d6fe6f2f6516e160d6d276152dfdf099 Mon Sep 17 00:00:00 2001 From: Rob Royce Date: Fri, 23 Aug 2024 16:22:25 -0700 Subject: [PATCH] Improve ROS2 tools and add unit tests with CI (#15) * Simplified within_bounds function by removing redundant 'elif' condition. Improved code readability and maintainability. (#13) * Add unit tests and CI. (#14) * refactor: better error handling and response parsing for ROS2 tools, add blacklist where applicable. * feat(ros2): add ros2 topic echo tool. * chore: bump version to 1.0.4, update CHANGELOG.md * chore: bump langchain versions. * feat(tests): add unit tests for most tools and the ROSATools class. * fix: passing a blacklist into any of the tools no longer overrides the blacklist passed into the ROSA constructor. They are concatenated instead. * feat(CI): add ci workflow. * fix: properly filter out blacklisted topics and nodes. * feat(tests): add ros2 tests. * feat(ci): update humble jobs. * feat(tests): add stubs for additional test classes. * docs: update README * chore: bump version to 1.0.5 --------- Co-authored-by: Kejun Liu <119113065+dawnkisser@users.noreply.github.com> --- .github/workflows/ci.yml | 70 +++ CHANGELOG.md | 15 + Dockerfile | 2 +- README.md | 29 +- setup.py | 8 +- src/rosa/tools/__init__.py | 38 +- src/rosa/tools/log.py | 34 +- src/rosa/tools/ros1.py | 110 ++++- src/rosa/tools/ros2.py | 76 +-- src/rosa/tools/system.py | 2 - src/turtle_agent/scripts/tools/turtle.py | 2 +- tests/__init__.py | 0 tests/rosa/__init__.py | 0 tests/rosa/test_prompts.py | 13 + tests/rosa/test_rosa.py | 13 + tests/rosa/tools/__init__.py | 0 tests/rosa/tools/test_calculation.py | 195 ++++++++ tests/rosa/tools/test_log.py | 176 +++++++ tests/rosa/tools/test_ros1.py | 603 +++++++++++++++++++++++ tests/rosa/tools/test_ros2.py | 290 +++++++++++ tests/rosa/tools/test_rosa_tools.py | 110 +++++ tests/rosa/tools/test_system.py | 60 +++ 22 files changed, 1723 insertions(+), 123 deletions(-) create mode 100644 .github/workflows/ci.yml create mode 100644 tests/__init__.py create mode 100644 tests/rosa/__init__.py create mode 100644 tests/rosa/test_prompts.py create mode 100644 tests/rosa/test_rosa.py create mode 100644 tests/rosa/tools/__init__.py create mode 100644 tests/rosa/tools/test_calculation.py create mode 100644 tests/rosa/tools/test_log.py create mode 100644 tests/rosa/tools/test_ros1.py create mode 100644 tests/rosa/tools/test_ros2.py create mode 100644 tests/rosa/tools/test_rosa_tools.py create mode 100644 tests/rosa/tools/test_system.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..0729c58 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,70 @@ +name: CI Pipeline + +on: + push: + branches: + - main + - dev + pull_request: + branches: + - main + - dev + +jobs: + test-noetic: + runs-on: ubuntu-latest + container: + image: osrf/ros:noetic-desktop + steps: + - name: Checkout code + uses: actions/checkout@v2 + + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: '3.x' + + - name: Install dependencies + run: | + sudo apt-get update + sudo apt-get install -y libc6 libc6-dev + sudo apt-get install -y python3.9 + sudo apt-get install -y python3-pip + python3.9 -m pip install --user -e . + shell: bash + + - name: Run tests + run: | + . /opt/ros/noetic/setup.bash + python3.9 -m unittest discover -s tests --verbose + shell: bash + env: + ROS_VERSION: 1 + + test-humble: + runs-on: ubuntu-latest + container: + image: osrf/ros:humble-desktop + steps: + - name: Checkout code + uses: actions/checkout@v2 + + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: '3.10' + + - name: Install dependencies + run: | + sudo apt-get update + sudo apt-get install -y python3-pip + python3.10 -m pip install --user -e . + shell: bash + + - name: Run tests + run: | + . /opt/ros/humble/setup.bash + python3.10 -m unittest discover -s tests --verbose + shell: bash + env: + ROS_VERSION: 2 diff --git a/CHANGELOG.md b/CHANGELOG.md index f4ecc9d..b0c2b67 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,21 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [1.0.5] + +### Added + +* CI pipeline for automated testing +* Unit tests for ROSA tools and utilities + +### Changed + +* Improvements to various ROS2 tools +* Upgrade dependencies: + * `langchain` to 0.2.14 + * `langchain_core` to 0.2.34 + * `langchain-openai` to 0.1.22 + ## [1.0.4] - 2024-08-21 ### Added diff --git a/Dockerfile b/Dockerfile index b26d22f..cff6308 100644 --- a/Dockerfile +++ b/Dockerfile @@ -28,7 +28,7 @@ RUN apt-get update && apt-get install -y \ RUN apt-get update && apt-get install -y python3.9 RUN apt-get update && apt-get install -y python3-pip RUN python3 -m pip install -U python-dotenv catkin_tools -RUN python3.9 -m pip install -U jpl-rosa>=1.0.4 +RUN python3.9 -m pip install -U jpl-rosa>=1.0.5 # Configure ROS RUN rosdep update diff --git a/README.md b/README.md index 86cd2d3..e819ae3 100644 --- a/README.md +++ b/README.md @@ -7,17 +7,26 @@

ROSA - Robot Operating System Agent

-
ROSA is an AI Agent designed to interact with ROS-based robotics systems using natural language queries.
- +
+  ROSA is an AI Agent designed to interact with ROS-based robotics systems
using natural language queries. +
+ +
-![GitHub Actions Workflow Status](https://img.shields.io/github/actions/workflow/status/nasa-jpl/rosa/publish.yml) -![Static Badge](https://img.shields.io/badge/Python->=3.9-blue) -![Static Badge](https://img.shields.io/badge/ROS_1-Supported-blue) -![Static Badge](https://img.shields.io/badge/ROS_2-Supported-blue) +![Static Badge](https://img.shields.io/badge/ROS_1-Noetic-blue) +![Static Badge](https://img.shields.io/badge/ROS_2-Humble|Iron|Jazzy-blue) ![PyPI - License](https://img.shields.io/pypi/l/jpl-rosa) +[![SLIM](https://img.shields.io/badge/Best%20Practices%20from-SLIM-blue)](https://nasa-ammos.github.io/slim/) + +![GitHub Actions Workflow Status](https://img.shields.io/github/actions/workflow/status/nasa-jpl/rosa/ci.yml?branch=main&label=main) +![GitHub Actions Workflow Status](https://img.shields.io/github/actions/workflow/status/nasa-jpl/rosa/ci.yml?branch=dev&label=dev) +![GitHub Actions Workflow Status](https://img.shields.io/github/actions/workflow/status/nasa-jpl/rosa/publish.yml?label=publish) ![PyPI - Version](https://img.shields.io/pypi/v/jpl-rosa) ![PyPI - Downloads](https://img.shields.io/pypi/dw/jpl-rosa) -[![SLIM](https://img.shields.io/badge/Best%20Practices%20from-SLIM-blue)](https://nasa-ammos.github.io/slim/) + +
+ + ROSA is an AI agent that can be used to interact with ROS1 _and_ ROS2 systems in order to carry out various tasks. It is built using the open-source [Langchain](https://python.langchain.com/v0.2/docs/introduction/) framework, and can @@ -90,9 +99,11 @@ rosa.invoke("Show me a list of topics that have publishers but no subscribers") ## TurtleSim Demo -We have included a demo that uses ROSA to control the TurtleSim robot in simulation. To run the demo, you will need to have Docker installed on your machine. +We have included a demo that uses ROSA to control the TurtleSim robot in simulation. To run the demo, you will need to +have Docker installed on your machine. -The following video shows ROSA reasoning about how to draw a 5-point star, then executing the necessary commands to do so. +The following video shows ROSA reasoning about how to draw a 5-point star, then executing the necessary commands to do +so. https://github.com/user-attachments/assets/77b97014-6d2e-4123-8d0b-ea0916d93a4e diff --git a/setup.py b/setup.py index 926ba7a..f3e71a3 100644 --- a/setup.py +++ b/setup.py @@ -23,7 +23,7 @@ setup( name="jpl-rosa", - version="1.0.4", + version="1.0.5", license="Apache 2.0", description="ROSA: the Robot Operating System Agent", long_description=long_description, @@ -49,10 +49,10 @@ install_requires=[ "PyYAML==6.0.1", "python-dotenv>=1.0.1", - "langchain==0.2.13", + "langchain==0.2.14", "langchain-community==0.2.12", - "langchain-core==0.2.32", - "langchain-openai==0.1.21", + "langchain-core==0.2.34", + "langchain-openai==0.1.22", "pydantic", "pyinputplus", "azure-identity", diff --git a/src/rosa/tools/__init__.py b/src/rosa/tools/__init__.py index 78221f2..4285dde 100644 --- a/src/rosa/tools/__init__.py +++ b/src/rosa/tools/__init__.py @@ -19,7 +19,7 @@ from langchain.agents import Tool -def inject_blacklist(blacklist): +def inject_blacklist(default_blacklist: List[str]): """ Inject a blacklist parameter into @tool functions that require it. Required because we do not want to rely on the LLM to manually use the blacklist, as it may "forget" to do so. @@ -32,18 +32,28 @@ def inject_blacklist(blacklist): def decorator(func): @wraps(func) def wrapper(*args, **kwargs): - if "blacklist" in kwargs: - kwargs["blacklist"] = blacklist + if args and isinstance(args[0], dict): + if "blacklist" in args[0]: + args[0]["blacklist"] = default_blacklist + args[0]["blacklist"] + else: + args[0]["blacklist"] = default_blacklist else: - params = inspect.signature(func).parameters - if "blacklist" in params: - kwargs["blacklist"] = blacklist + if "blacklist" in kwargs: + kwargs["blacklist"] = default_blacklist + kwargs["blacklist"] + else: + params = inspect.signature(func).parameters + if "blacklist" in params: + kwargs["blacklist"] = default_blacklist return func(*args, **kwargs) # Rebuild the signature to include 'blacklist' sig = inspect.signature(func) new_params = [ - param.replace(default=blacklist) if param.name == "blacklist" else param + ( + param.replace(default=default_blacklist) + if param.name == "blacklist" + else param + ) for param in sig.parameters.values() ] wrapper.__signature__ = sig.replace(parameters=new_params) @@ -68,19 +78,13 @@ def __init__( self.__iterative_add(system) if self.__ros_version == 1: - try: - from . import ros1 + from . import ros1 - self.__iterative_add(ros1, blacklist=blacklist) - except Exception as e: - print(e) + self.__iterative_add(ros1, blacklist=blacklist) elif self.__ros_version == 2: - try: - from . import ros2 + from . import ros2 - self.__iterative_add(ros2, blacklist=blacklist) - except Exception as e: - print(e) + self.__iterative_add(ros2, blacklist=blacklist) else: raise ValueError("Invalid ROS version. Must be either 1 or 2.") diff --git a/src/rosa/tools/log.py b/src/rosa/tools/log.py index ea5a475..d6544ae 100644 --- a/src/rosa/tools/log.py +++ b/src/rosa/tools/log.py @@ -13,7 +13,7 @@ # limitations under the License. import os -from typing import Optional +from typing import Optional, Literal from langchain.agents import tool @@ -22,20 +22,28 @@ def read_log( log_file_directory: str, log_filename: str, - level_filter: Optional[str], - line_range: tuple = (-200, -1), + level_filter: Optional[ + Literal[ + "ERROR", "INFO", "DEBUG", "WARNING", "CRITICAL", "FATAL", "TRACE", "DEBUG" + ] + ] = None, + num_lines: Optional[int] = None, ) -> dict: """ Read a log file and return the log lines that match the level filter and line range. - :arg log_file_directory: The directory containing the log file to read (use your tools to get it) - :arg log_filename: The path to the log file to read - :arg level_filter: Only show log lines that contain this level (e.g. "ERROR", "INFO", "DEBUG", etc.) - :arg line_range: A tuple of two integers representing the start and end line numbers to return + :param log_file_directory: The directory containing the log file to read (use your tools to get it) + :param log_filename: The path to the log file to read + :param level_filter: Only show log lines that contain this level (e.g. "ERROR", "INFO", "DEBUG", etc.) + :param num_lines: The number of most recent lines to return from the log file """ + if num_lines is not None and num_lines < 1: + return {"error": "Invalid `num_lines` argument. It must be a positive integer."} + if not os.path.exists(log_file_directory): return { - "error": f"The log directory '{log_file_directory}' does not exist. You should first use your tools to get the correct log directory." + "error": f"The log directory '{log_file_directory}' does not exist. You should first use your tools to " + f"get the correct log directory." } full_log_path = os.path.join(log_file_directory, log_filename) @@ -56,13 +64,15 @@ def read_log( for i in range(len(log_lines)): log_lines[i] = f"line {i+1}: " + log_lines[i].strip() - print(f"Reading log file '{log_filename}' lines {line_range[0]} to {line_range[1]}") - log_lines = log_lines[line_range[0] : line_range[1]] + if num_lines is not None: + # Get the most recent num_lines from the log file + log_lines = log_lines[-num_lines:] # If there are more than 200 lines, return a message to use the line_range argument if len(log_lines) > 200: return { - "error": f"The log file '{log_filename}' has more than 200 lines. Please use the `line_range` argument to read a subset of the log file at a time." + "error": f"The log file '{log_filename}' has more than 200 lines. Please use the `num_lines` argument to " + f"read a subset of the log file at a time." } if level_filter is not None: @@ -72,7 +82,7 @@ def read_log( "log_filename": log_filename, "log_file_directory": log_file_directory, "level_filter": level_filter, - "line_range": line_range, + "requested_num_lines": num_lines, "total_lines": total_lines, "lines_returned": len(log_lines), "lines": log_lines, diff --git a/src/rosa/tools/ros1.py b/src/rosa/tools/ros1.py index 36bcd93..2aeb2a5 100644 --- a/src/rosa/tools/ros1.py +++ b/src/rosa/tools/ros1.py @@ -51,9 +51,17 @@ def get_entities( in_namespace = len(entities) if pattern: - entities = list(filter(lambda x: regex.match(f".*{pattern}", x), entities)) + entities = list(filter(lambda x: regex.match(f".*{pattern}.*", x), entities)) match_pattern = len(entities) + if blacklist: + entities = list( + filter( + lambda x: not any(regex.match(f".*{bl}.*", x) for bl in blacklist), + entities, + ) + ) + if total == 0: entities = [f"There are currently no {type}s available in the system."] elif in_namespace == 0: @@ -65,22 +73,11 @@ def get_entities( f"There are currently no {type}s available matching the specified pattern." ] - if blacklist: - entities = list( - filter( - lambda x: not any( - regex.match(f".*{pattern}", x) for pattern in blacklist - ), - entities, - ) - ) - return total, in_namespace, match_pattern, sorted(entities) @tool def rosgraph_get( - namespace: Optional[str] = "/", node_pattern: Optional[str] = ".*", topic_pattern: Optional[str] = ".*", blacklist: List[str] = None, @@ -89,12 +86,12 @@ def rosgraph_get( """ Get a list of tuples representing nodes and topics in the ROS graph. - :param namespace: ROS namespace to scope return values by. Namespace must already be resolved. :param node_pattern: A regex pattern for the nodes to include in the graph (publishers and subscribers). :param topic_pattern: A regex pattern for the topics to include in the graph. :param exclude_self_connections: Exclude connections where the publisher and subscriber are the same node. :note: you should avoid using the topic pattern when searching for nodes, as it may not return any results. + :important: you must NOT use this function to get lists of nodes, topics, etc. Example regex patterns: - .*node.* any node containing "node" @@ -102,9 +99,6 @@ def rosgraph_get( - node.* any node that starts with "node" - (.*node1.*|.*node2.*|.*node3.*) any node containing either "node1", "node2", or "node3" """ - rospy.loginfo( - f"Getting ROS graph with namespace '{namespace}', node_pattern '{node_pattern}', and topic_pattern '{topic_pattern}'" - ) try: publishers, subscribers, services = rosgraph.masterapi.Master( "/rosout" @@ -118,8 +112,6 @@ def rosgraph_get( for pub in publishers: for node in pub[1]: - if namespace and not node.startswith(namespace): - continue if pub[0] in topic_pub_map: topic_pub_map[pub[0]].append(node) else: @@ -127,8 +119,6 @@ def rosgraph_get( for sub in subscribers: for node in sub[1]: - if namespace and not node.startswith(namespace): - continue if sub[0] in topic_sub_map: topic_sub_map[sub[0]].append(node) else: @@ -229,6 +219,14 @@ def rostopic_list( except Exception as e: return {"error": f"Failed to get ROS topics: {e}"} + if blacklist: + topics = list( + filter( + lambda x: not any(regex.match(f".*{bl}.*", x) for bl in blacklist), + topics, + ) + ) + return dict( namespace=namespace if namespace else "/", pattern=pattern if pattern else ".*", @@ -260,6 +258,14 @@ def rosnode_list( except Exception as e: return {"error": f"Failed to get ROS nodes: {e}"} + if blacklist: + nodes = list( + filter( + lambda x: not any(regex.match(f".*{bl}.*", x) for bl in blacklist), + nodes, + ) + ) + return dict( namespace=namespace if namespace else "/", pattern=pattern if pattern else ".*", @@ -360,7 +366,6 @@ def rostopic_echo( for i in range(count): try: msg = rospy.wait_for_message(topic, msg_class, timeout) - print(msg) if return_echoes: msgs.append(msg) @@ -369,7 +374,6 @@ def rostopic_echo( time.sleep(delay) except (rospy.ROSException, rospy.ROSInterruptException) as e: - print(f"Failed to get message from topic '{topic}': {e}") break response = dict(topic=topic, requested_count=count, actual_count=len(msgs)) @@ -484,7 +488,6 @@ def rosservice_call(service: str, args: List[str]) -> dict: :param service: The name of the ROS service to call. :param args: A list of arguments to pass to the service. """ - print(f"Calling ROS service '{service}' with arguments: {args}") try: response = rosservice.call_service(service, args) return response @@ -519,7 +522,6 @@ def rossrv_info(srv_type: List[str], raw: bool = False) -> dict: for srv in srv_type: # Get the Python class corresponding to the srv file - print(f"Getting details for {srv}") srv_path = rosmsg.get_srv_text(srv, raw=raw) details[srv] = srv_path return details @@ -740,3 +742,63 @@ def get_roslog_directories() -> dict: latest=latest_directory, from_env=from_env, ) + + +@tool +def roslaunch(package: str, launch_file: str) -> str: + """Launches a ROS launch file. + + :param package: The name of the ROS package containing the launch file. + :param launch_file: The name of the launch file to launch. + """ + rospy.loginfo(f"Launching ROS launch file '{launch_file}' in package '{package}'") + try: + os.system(f"roslaunch {package} {launch_file}") + return f"Launched ROS launch file '{launch_file}' in package '{package}'." + except Exception as e: + return f"Failed to launch ROS launch file '{launch_file}' in package '{package}': {e}" + + +@tool +def roslaunch_list(package: str) -> dict: + """Returns a list of available ROS launch files in a package. + + :param package: The name of the ROS package to list launch files for. + """ + rospy.loginfo(f"Getting ROS launch files in package '{package}'") + try: + rospack = rospkg.RosPack() + directory = rospack.get_path(package) + launch = os.path.join(directory, "launch") + + launch_files = [] + + # Get all files in the launch directory + if os.path.exists(launch): + launch_files = [ + f for f in os.listdir(launch) if os.path.isfile(os.path.join(launch, f)) + ] + + return { + "package": package, + "directory": directory, + "total": len(launch_files), + "launch_files": launch_files, + } + + except Exception as e: + return {"error": f"Failed to get ROS launch files in package '{package}': {e}"} + + +@tool +def rosnode_kill(node: str) -> str: + """Kills a specific ROS node. + + :param node: The name of the ROS node to kill. + """ + rospy.loginfo(f"Killing ROS node '{node}'") + try: + os.system(f"rosnode kill {node}") + return f"Killed ROS node '{node}'." + except Exception as e: + return f"Failed to kill ROS node '{node}': {e}" diff --git a/src/rosa/tools/ros2.py b/src/rosa/tools/ros2.py index b825b69..bdfb359 100644 --- a/src/rosa/tools/ros2.py +++ b/src/rosa/tools/ros2.py @@ -83,6 +83,8 @@ def get_entities( if pattern: entities = list(filter(lambda x: re.match(f".*{pattern}.*", x), entities)) + entities = [e for e in entities if e.strip() != ""] + return entities @@ -189,57 +191,6 @@ def ros2_node_info(nodes: List[str]) -> dict: return data -def parse_ros2_topic_info(output): - topic_info = {"name": "", "type": "", "publishers": [], "subscribers": []} - - lines = output.split("\n") - - # Extract the topic name - for line in lines: - if line.startswith("ros2 topic info"): - topic_info["name"] = line.split(" ")[3] - - # Extract the Type - for line in lines: - if line.startswith("Type:"): - topic_info["type"] = line.split(": ")[1] - - # Extract publisher and subscriber sections - publisher_section = "" - subscriber_section = "" - collecting_publishers = False - collecting_subscribers = False - - for line in lines: - if line.startswith("Publisher count:"): - collecting_publishers = True - collecting_subscribers = False - elif line.startswith("Subscription count:"): - collecting_publishers = False - collecting_subscribers = True - - if collecting_publishers: - publisher_section += line + "\n" - if collecting_subscribers: - subscriber_section += line + "\n" - - # Extract node names for publishers - publisher_lines = publisher_section.split("\n") - for line in publisher_lines: - if line.startswith("Node name:"): - node_name = line.split(": ")[1] - topic_info["publishers"].append(node_name) - - # Extract node names for subscribers - subscriber_lines = subscriber_section.split("\n") - for line in subscriber_lines: - if line.startswith("Node name:"): - node_name = line.split(": ")[1] - topic_info["subscribers"].append(node_name) - - return topic_info - - @tool def ros2_topic_info(topics: List[str]) -> dict: """ @@ -255,7 +206,7 @@ def ros2_topic_info(topics: List[str]) -> dict: if not success: topic_info = dict(error=output) else: - topic_info = parse_ros2_topic_info(output) + topic_info = output data[topic] = topic_info @@ -276,7 +227,17 @@ def ros2_param_list( """ if node_name: cmd = f"ros2 param list {node_name}" - params = get_entities(cmd, pattern=pattern, blacklist=blacklist) + success, output = execute_ros_command(cmd) + if not success: + return {"error": output} + + params = [o for o in output.split("\n") if o] + if pattern: + params = [p for p in params if re.match(f".*{pattern}.*", p)] + if blacklist: + params = [ + p for p in params if not any(re.match(f".*{b}.*", p) for b in blacklist) + ] return {node_name: params} else: cmd = f"ros2 param list" @@ -296,6 +257,15 @@ def ros2_param_list( data[current_node] = [] elif line.strip() != "": data[current_node].append(line.strip()) + + if pattern: + data = {k: v for k, v in data.items() if re.match(f".*{pattern}.*", k)} + if blacklist: + data = { + k: v + for k, v in data.items() + if not any(re.match(f".*{b}.*", k) for b in blacklist) + } return data diff --git a/src/rosa/tools/system.py b/src/rosa/tools/system.py index e18295a..96ae8d1 100644 --- a/src/rosa/tools/system.py +++ b/src/rosa/tools/system.py @@ -26,7 +26,6 @@ def set_verbosity(enable_verbose_messages: bool) -> str: :arg enable_verbose_messages: A boolean value to enable or disable verbose messages. """ global VERBOSE - print(f"Setting verbosity to {enable_verbose_messages}") VERBOSE = enable_verbose_messages set_verbose(VERBOSE) return f"Verbose messages are now {'enabled' if VERBOSE else 'disabled'}." @@ -41,7 +40,6 @@ def set_debuging(enable_debug_messages: bool) -> str: :arg enable_debug_messages: A boolean value to enable or disable debug messages. """ global DEBUG - print(f"Setting debug to {enable_debug_messages}") DEBUG = enable_debug_messages set_debug(DEBUG) return f"Debug messages are now {'enabled' if DEBUG else 'disabled'}." diff --git a/src/turtle_agent/scripts/tools/turtle.py b/src/turtle_agent/scripts/tools/turtle.py index 398ddee..0068d20 100644 --- a/src/turtle_agent/scripts/tools/turtle.py +++ b/src/turtle_agent/scripts/tools/turtle.py @@ -48,7 +48,7 @@ def within_bounds(x: float, y: float) -> tuple: """ if 0 <= x <= 11 and 0 <= y <= 11: return True, "Coordinates are within bounds." - elif x < 0 or x > 11 or y < 0 or y > 11: + else: return False, f"({x}, {y}) will be out of bounds. Range is [0, 11] for each." diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/rosa/__init__.py b/tests/rosa/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/rosa/test_prompts.py b/tests/rosa/test_prompts.py new file mode 100644 index 0000000..b0da2ec --- /dev/null +++ b/tests/rosa/test_prompts.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024. Jet Propulsion Laboratory. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/rosa/test_rosa.py b/tests/rosa/test_rosa.py new file mode 100644 index 0000000..b0da2ec --- /dev/null +++ b/tests/rosa/test_rosa.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024. Jet Propulsion Laboratory. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/rosa/tools/__init__.py b/tests/rosa/tools/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/rosa/tools/test_calculation.py b/tests/rosa/tools/test_calculation.py new file mode 100644 index 0000000..b594299 --- /dev/null +++ b/tests/rosa/tools/test_calculation.py @@ -0,0 +1,195 @@ +# Copyright (c) 2024. Jet Propulsion Laboratory. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import statistics +import unittest + +from src.rosa.tools.calculation import ( + add_all, + multiply_all, + mean, + median, + mode, + variance, + add, + subtract, + multiply, + divide, + exponentiate, + modulo, + sine, + cosine, + tangent, + asin, + acos, + atan, + sinh, + cosh, + tanh, + count_list, + count_words, + count_lines, + degrees_to_radians, + radians_to_degrees, +) + + +class TestCalculationTools(unittest.TestCase): + + def test_add_all_returns_sum_of_numbers(self): + self.assertEqual(add_all.invoke({"numbers": [1, 2, 3]}), 6) + self.assertEqual(add_all.invoke({"numbers": []}), 0) + + def test_multiply_all_returns_product_of_numbers(self): + self.assertEqual(multiply_all.invoke({"numbers": [1, 2, 3]}), 6) + self.assertEqual(multiply_all.invoke({"numbers": [1, 0, 3]}), 0) + + def test_mean_returns_mean_and_stdev_of_numbers(self): + self.assertEqual(mean.invoke({"numbers": [1, 2, 3]}), {"mean": 2, "stdev": 1}) + with self.assertRaises(statistics.StatisticsError): + mean.invoke({"numbers": []}) + + def test_median_returns_median_of_numbers(self): + self.assertEqual(median.invoke({"numbers": [1, 2, 3]}), 2) + self.assertEqual(median.invoke({"numbers": [1, 2, 3, 4]}), 2.5) + + def test_mode_returns_mode_of_numbers(self): + self.assertEqual(mode.invoke({"numbers": [1, 1, 2, 3]}), 1) + self.assertEqual(mode.invoke({"numbers": [1, 2, 3]}), 1) + + def test_variance_returns_variance_of_numbers(self): + self.assertEqual(variance.invoke({"numbers": [1, 2, 3]}), 1) + with self.assertRaises(statistics.StatisticsError): + variance.invoke({"numbers": [1]}) + + def test_add_returns_sum_of_xy_pairs(self): + self.assertEqual( + add.invoke({"xy_pairs": [(1, 2), (3, 4)]}), [{"1+2": 3}, {"3+4": 7}] + ) + + def test_subtract_returns_difference_of_xy_pairs(self): + self.assertEqual( + subtract.invoke({"xy_pairs": [(1, 2), (3, 4)]}), [{"1-2": -1}, {"3-4": -1}] + ) + + def test_multiply_returns_product_of_xy_pairs(self): + self.assertEqual( + multiply.invoke({"xy_pairs": [(1, 2), (3, 4)]}), [{"1*2": 2}, {"3*4": 12}] + ) + + def test_divide_returns_quotient_of_xy_pairs(self): + self.assertEqual( + divide.invoke({"xy_pairs": [(1, 2), (3, 4)]}), [{"1/2": 0.5}, {"3/4": 0.75}] + ) + self.assertEqual(divide.invoke({"xy_pairs": [(1, 0)]}), [{"1/0": "undefined"}]) + + def test_exponentiate_returns_exponentiation_of_xy_pairs(self): + self.assertEqual( + exponentiate.invoke({"xy_pairs": [(2, 3), (3, 2)]}), + [{"2^3": 8}, {"3^2": 9}], + ) + + def test_modulo_returns_modulo_of_xy_pairs(self): + self.assertEqual( + modulo.invoke({"xy_pairs": [(5, 3), (10, 2)]}), [{"5%3": 2}, {"10%2": 0}] + ) + self.assertEqual(modulo.invoke({"xy_pairs": [(1, 0)]}), [{"1%0": "undefined"}]) + + def test_sine_returns_sine_of_x_values(self): + self.assertAlmostEqual( + sine.invoke({"x_values": [0, math.pi / 2]}), + [{"sin(0.0)": 0.0}, {"sin(1.5707963267948966)": 1.0}], + ) + + def test_cosine_returns_cosine_of_x_values(self): + cosines = cosine.invoke({"x_values": [0, math.pi / 2]}) + self.assertAlmostEqual(cosines[0]["cos(0.0)"], 1.0, delta=0.0000000000000001) + self.assertAlmostEqual( + cosines[1]["cos(1.5707963267948966)"], 0.0, delta=0.0000000000000001 + ) + + def test_tangent_returns_tangent_of_x_values(self): + # Convert the above to use assertAlmostEqual + tangents = tangent.invoke({"x_values": [0, math.pi / 4]}) + self.assertAlmostEqual(tangents[0]["tan(0.0)"], 0.0, delta=0.0000000000000001) + self.assertAlmostEqual( + tangents[1]["tan(0.7853981633974483)"], 1.0, delta=0.000000000000001 + ) + + def test_asin_returns_arcsine_of_x_values(self): + self.assertEqual( + asin.invoke({"x_values": [0, 1]}), + [{"asin(0.0)": 0.0}, {"asin(1.0)": 1.5707963267948966}], + ) + self.assertEqual(asin.invoke({"x_values": [2]}), [{"asin(2.0)": "undefined"}]) + + def test_acos_returns_arccosine_of_x_values(self): + self.assertEqual( + acos.invoke({"x_values": [0, 1]}), + [{"acos(0.0)": 1.5707963267948966}, {"acos(1.0)": 0.0}], + ) + self.assertEqual(acos.invoke({"x_values": [2]}), [{"acos(2.0)": "undefined"}]) + + def test_atan_returns_arctangent_of_x_values(self): + self.assertEqual( + atan.invoke({"x_values": [0, 1]}), + [{"atan(0.0)": 0.0}, {"atan(1.0)": 0.7853981633974483}], + ) + + def test_sinh_returns_hyperbolic_sine_of_x_values(self): + self.assertEqual( + sinh.invoke({"x_values": [0, 1]}), + [{"sinh(0.0)": 0.0}, {"sinh(1.0)": 1.1752011936438014}], + ) + + def test_cosh_returns_hyperbolic_cosine_of_x_values(self): + self.assertAlmostEqual( + cosh.invoke({"x_values": [0, 1]}), + [{"cosh(0.0)": 1.0}, {"cosh(1.0)": 1.5430806348152437}], + ) + + def test_tanh_returns_hyperbolic_tangent_of_x_values(self): + self.assertEqual( + tanh.invoke({"x_values": [0, 1]}), + [{"tanh(0.0)": 0.0}, {"tanh(1.0)": 0.7615941559557649}], + ) + + def test_count_list_returns_number_of_items_in_list(self): + self.assertEqual(count_list.invoke({"items": [1, 2, 3]}), 3) + self.assertEqual(count_list.invoke({"items": []}), 0) + + def test_count_words_returns_number_of_words_in_string(self): + self.assertEqual(count_words.invoke({"text": "Hello world"}), 2) + self.assertEqual(count_words.invoke({"text": ""}), 0) + + def test_count_lines_returns_number_of_lines_in_string(self): + self.assertEqual(count_lines.invoke({"text": "Hello\nworld"}), 2) + self.assertEqual(count_lines.invoke({"text": ""}), 1) + + def test_degrees_to_radians_converts_degrees_to_radians(self): + self.assertEqual( + degrees_to_radians.invoke({"degrees": [0, 180]}), + {0: "0.0 radians.", 180: "3.14159 radians."}, + ) + + def test_radians_to_degrees_converts_radians_to_degrees(self): + self.assertEqual( + radians_to_degrees.invoke({"radians": [0, 3.14159]}), + {0: "0.0 degrees.", 3.14159: "180.0 degrees."}, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/rosa/tools/test_log.py b/tests/rosa/tools/test_log.py new file mode 100644 index 0000000..ca6d09a --- /dev/null +++ b/tests/rosa/tools/test_log.py @@ -0,0 +1,176 @@ +# Copyright (c) 2024. Jet Propulsion Laboratory. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest.mock import patch, mock_open + +from src.rosa.tools.log import read_log + + +class TestReadLog(unittest.TestCase): + + @patch("os.path.exists") + def test_log_directory_does_not_exist(self, mock_exists): + mock_exists.return_value = False + result = read_log.invoke( + { + "log_file_directory": "/invalid/directory", + "log_filename": "logfile.log", + } + ) + self.assertEqual( + result["error"], + "The log directory '/invalid/directory' does not exist. You should first use your tools to get the " + "correct log directory.", + ) + + @patch("os.path.exists") + def test_log_path_is_not_a_file(self, mock_exists): + mock_exists.side_effect = [True, True] + with patch("os.path.isfile", return_value=False): + result = read_log.invoke( + { + "log_file_directory": "/valid/directory", + "log_filename": "logfile.log", + } + ) + self.assertEqual( + result["error"], + "The path '/valid/directory/logfile.log' is not a file.", + ) + + @patch( + "builtins.open", + new_callable=mock_open, + read_data="INFO: line 1\nERROR: line 2\nDEBUG: line 3\n", + ) + @patch("os.path.exists", return_value=True) + @patch("os.path.isfile", return_value=True) + def test_read_log_with_level_filter(self, mock_exists, mock_isfile, mock_file): + result = read_log.invoke( + { + "log_file_directory": "/valid/directory", + "log_filename": "logfile.log", + "level_filter": "ERROR", + } + ) + self.assertEqual(result["lines"], ["line 2: ERROR: line 2"]) + + @patch( + "builtins.open", + new_callable=mock_open, + read_data="INFO: line 1\nERROR: line 2\nDEBUG: line 3\n", + ) + @patch("os.path.exists", return_value=True) + @patch("os.path.isfile", return_value=True) + def test_read_log_with_line_range(self, mock_exists, mock_isfile, mock_file): + result = read_log.invoke( + { + "log_file_directory": "/valid/directory", + "log_filename": "logfile.log", + "num_lines": 2, + } + ) + self.assertEqual( + result["lines"], ["line 2: ERROR: line 2", "line 3: DEBUG: line 3"] + ) + + @patch("builtins.open", new_callable=mock_open, read_data="INFO: line 1\n" * 202) + @patch("os.path.exists", return_value=True) + @patch("os.path.isfile", return_value=True) + def test_log_file_exceeds_200_lines(self, mock_exists, mock_isfile, mock_file): + result = read_log.invoke( + { + "log_file_directory": "/valid/directory", + "log_filename": "logfile.log", + "num_lines": 203, + } + ) + self.assertEqual( + result["error"], + "The log file 'logfile.log' has more than 200 lines. Please use the `num_lines` argument to read a subset " + "of the log file at a time.", + ) + + @patch( + "builtins.open", + new_callable=mock_open, + read_data="INFO: line 1\nERROR: line 2\nDEBUG: line 3\n", + ) + @patch("os.path.exists", return_value=True) + @patch("os.path.isfile", return_value=True) + def test_read_log_happy_path(self, mock_exists, mock_isfile, mock_file): + result = read_log.invoke( + { + "log_file_directory": "/valid/directory", + "log_filename": "logfile.log", + } + ) + self.assertEqual( + result["lines"], + ["line 1: INFO: line 1", "line 2: ERROR: line 2", "line 3: DEBUG: line 3"], + ) + + @patch("os.path.exists", return_value=True) + @patch("os.path.isfile", return_value=True) + def test_invalid_num_lines_argument(self, mock_exists, mock_isfile): + with patch( + "builtins.open", + new_callable=mock_open, + read_data="INFO: line 1\nERROR: line 2\n", + ): + result = read_log.invoke( + { + "log_file_directory": "/valid/directory", + "log_filename": "logfile.log", + "num_lines": -1, + } + ) + self.assertEqual( + result["error"], + "Invalid `num_lines` argument. It must be a positive integer.", + ) + + @patch("os.path.exists", return_value=True) + @patch("os.path.isfile", return_value=True) + def test_empty_log_file(self, mock_exists, mock_isfile): + with patch("builtins.open", new_callable=mock_open, read_data=""): + result = read_log.invoke( + { + "log_file_directory": "/valid/directory", + "log_filename": "logfile.log", + } + ) + self.assertEqual(result["lines"], []) + + @patch("os.path.exists", return_value=True) + @patch("os.path.isfile", return_value=True) + def test_specific_log_level_not_present(self, mock_exists, mock_isfile): + with patch( + "builtins.open", + new_callable=mock_open, + read_data="INFO: line 1\nDEBUG: line 2\n", + ): + result = read_log.invoke( + { + "log_file_directory": "/valid/directory", + "log_filename": "logfile.log", + "level_filter": "ERROR", + } + ) + self.assertEqual(result["lines"], []) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/rosa/tools/test_ros1.py b/tests/rosa/tools/test_ros1.py new file mode 100644 index 0000000..2fdd463 --- /dev/null +++ b/tests/rosa/tools/test_ros1.py @@ -0,0 +1,603 @@ +# Copyright (c) 2024. Jet Propulsion Laboratory. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import unittest +from unittest.mock import patch, MagicMock + +try: + from src.rosa.tools.ros1 import ( + get_entities, + rosgraph_get, + rostopic_list, + rostopic_info, + rostopic_echo, + rosnode_list, + rosnode_info, + rosservice_list, + rosservice_info, + rosservice_call, + rosmsg_info, + rossrv_info, + rosparam_list, + rosparam_get, + rosparam_set, + rospkg_list, + rospkg_roots, + roslog_list, + ) +except ModuleNotFoundError: + pass + + +@unittest.skipIf( + os.environ.get("ROS_VERSION") == "2", + "Skipping ROS1 tests because ROS_VERSION is set to 2", +) +class TestROS1Tools(unittest.TestCase): + + @patch("src.rosa.tools.ros1.rostopic.get_topic_list") + def test_get_entities_topics(self, mock_get_topic_list): + mock_get_topic_list.return_value = ( + [("/turtle1/cmd_vel", "std_msgs/Empty")], + [("/turtle1/pose", "std_msgs/Empty")], + ) + total, in_namespace, match_pattern, entities = get_entities("topic", None, None) + self.assertEqual(total, 2) + self.assertEqual(in_namespace, 2) + self.assertEqual(match_pattern, 2) + self.assertIn("/turtle1/cmd_vel", entities) + self.assertIn("/turtle1/pose", entities) + + @patch("src.rosa.tools.ros1.rosnode.get_node_names") + def test_get_entities_nodes(self, mock_get_node_names): + mock_get_node_names.return_value = ["/turtlesim"] + total, in_namespace, match_pattern, entities = get_entities("node", None, None) + self.assertEqual(total, 1) + self.assertEqual(in_namespace, 1) + self.assertEqual(match_pattern, 1) + self.assertIn("/turtlesim", entities) + + @patch("src.rosa.tools.ros1.rosgraph.masterapi.Master.getSystemState") + def test_rosgraph_get_returns_graph(self, mock_get_system_state): + mock_get_system_state.return_value = ( + [("/topic1", ["/node1"]), ("/topic2", ["/node2"])], + [("/topic1", ["/node3"]), ("/topic2", ["/node4"])], + [], + ) + result = rosgraph_get.invoke( + { + "node_pattern": ".*", + "topic_pattern": ".*", + "blacklist": [], + "exclude_self_connections": True, + } + ) + self.assertIn("graph", result) + self.assertEqual(len(result["graph"]), 2) + + @patch("src.rosa.tools.ros1.rosgraph.masterapi.Master.getSystemState") + def test_rosgraph_get_handles_empty_graph(self, mock_get_system_state): + mock_get_system_state.return_value = ([], [], []) + result = rosgraph_get.invoke( + { + "node_pattern": ".*", + "topic_pattern": ".*", + "blacklist": [], + "exclude_self_connections": True, + } + ) + self.assertIn("error", result) + self.assertEqual( + result["error"], + "No results found for the specified parameters. Note that the following have been excluded: []", + ) + + @patch("src.rosa.tools.ros1.rosgraph.masterapi.Master.getSystemState") + def test_rosgraph_get_excludes_blacklisted_nodes(self, mock_get_system_state): + mock_get_system_state.return_value = ( + [("/topic1", ["/node1"]), ("/topic2", ["/node2"])], + [("/topic1", ["/node3"]), ("/topic2", ["/node4"])], + [], + ) + result = rosgraph_get.invoke( + { + "node_pattern": ".*", + "topic_pattern": ".*", + "blacklist": ["node1"], + "exclude_self_connections": True, + } + ) + self.assertIn("graph", result) + self.assertEqual(len(result["graph"]), 1) + self.assertNotIn("/node1", result["graph"][0]) + + @patch("src.rosa.tools.ros1.rosgraph.masterapi.Master.getSystemState") + def test_rosgraph_get_excludes_self_connections(self, mock_get_system_state): + mock_get_system_state.return_value = ( + [("/topic1", ["/node1"])], + [("/topic1", ["/node1"])], + [], + ) + result = rosgraph_get.invoke( + { + "node_pattern": ".*", + "topic_pattern": ".*", + "blacklist": [], + "exclude_self_connections": True, + } + ) + self.assertIn("error", result) + self.assertEqual( + result["error"], + "No results found for the specified parameters. Note that the following have been excluded: []", + ) + + @patch("src.rosa.tools.ros1.rostopic.get_info_text") + def test_rostopic_info(self, mock_get_info_text): + mock_get_info_text.return_value = ( + "Type: std_msgs/String\nPublishers:\n* /turtlesim\nSubscribers:\n* /rosout" + ) + result = rostopic_info.invoke({"topics": ["/turtle1/cmd_vel"]}) + self.assertIn("/turtle1/cmd_vel", result) + self.assertEqual(result["/turtle1/cmd_vel"]["type"], "std_msgs/String") + self.assertIn("/turtlesim", result["/turtle1/cmd_vel"]["publishers"]) + self.assertIn("/rosout", result["/turtle1/cmd_vel"]["subscribers"]) + + @patch("src.rosa.tools.ros1.rospy.wait_for_message") + @patch("src.rosa.tools.ros1.rostopic.get_topic_class") + def test_rostopic_echo(self, mock_get_topic_class, mock_wait_for_message): + mock_get_topic_class.return_value = (MagicMock(), None, None) + mock_wait_for_message.return_value = MagicMock() + result = rostopic_echo.invoke( + {"topic": "/turtle1/cmd_vel", "count": 1, "return_echoes": True} + ) + self.assertEqual(result["requested_count"], 1) + self.assertEqual(result["actual_count"], 1) + self.assertIn("echoes", result) + + @patch("rosnode.get_node_names") + def test_rosnode_list_returns_all_nodes(self, mock_get_node_names): + mock_get_node_names.return_value = ["/node1", "/node2", "/node3"] + result = rosnode_list.invoke({}) + self.assertEqual(result["total"], 3) + self.assertEqual(result["in_namespace"], 3) + self.assertEqual(result["match_pattern"], 3) + self.assertEqual(result["nodes"], ["/node1", "/node2", "/node3"]) + + @patch("rosnode.get_node_names") + def test_rosnode_list_filters_by_namespace(self, mock_get_node_names): + mock_get_node_names.return_value = ["/namespace1/node1", "/namespace2/node2"] + result = rosnode_list.invoke({"namespace": "/namespace1"}) + self.assertEqual(result["total"], 2) + self.assertEqual(result["in_namespace"], 1) + self.assertEqual(result["match_pattern"], 1) + self.assertEqual(result["nodes"], ["/namespace1/node1"]) + + @patch("rosnode.get_node_names") + def test_rosnode_list_filters_by_pattern(self, mock_get_node_names): + mock_get_node_names.return_value = ["/node1", "/node2", "/node3"] + result = rosnode_list.invoke({"pattern": "node1"}) + self.assertEqual(result["total"], 3) + self.assertEqual(result["in_namespace"], 3) + self.assertEqual(result["match_pattern"], 1) + self.assertEqual(result["nodes"], ["/node1"]) + + @patch("rosnode.get_node_names") + def test_rosnode_list_handles_no_nodes(self, mock_get_node_names): + mock_get_node_names.return_value = [] + result = rosnode_list.invoke({}) + self.assertEqual(result["total"], 0) + self.assertEqual(result["in_namespace"], 0) + self.assertEqual(result["match_pattern"], 0) + self.assertEqual( + result["nodes"], ["There are currently no nodes available in the system."] + ) + + @patch("rosnode.get_node_names") + def test_rosnode_list_handles_no_nodes_in_namespace(self, mock_get_node_names): + mock_get_node_names.return_value = ["/node1", "/node2"] + result = rosnode_list.invoke({"namespace": "/namespace1"}) + self.assertEqual(result["total"], 2) + self.assertEqual(result["in_namespace"], 0) + self.assertEqual(result["match_pattern"], 0) + self.assertEqual( + result["nodes"], + [ + "There are currently no nodes available using the '/namespace1' namespace." + ], + ) + + @patch("rosnode.get_node_names") + def test_rosnode_list_handles_no_nodes_matching_pattern(self, mock_get_node_names): + mock_get_node_names.return_value = ["/node1", "/node2"] + result = rosnode_list.invoke({"pattern": "node3"}) + self.assertEqual(result["total"], 2) + self.assertEqual(result["in_namespace"], 2) + self.assertEqual(result["match_pattern"], 0) + self.assertEqual( + result["nodes"], + ["There are currently no nodes available matching the specified pattern."], + ) + + @patch("rosnode.get_node_names") + def test_rosnode_list_filters_by_blacklist(self, mock_get_node_names): + mock_get_node_names.return_value = ["/node1", "/node2", "/node3"] + result = rosnode_list.invoke({"blacklist": ["node2"]}) + self.assertEqual(result["total"], 3) + self.assertEqual(result["in_namespace"], 3) + self.assertEqual(result["match_pattern"], 3) + self.assertEqual(result["nodes"], ["/node1", "/node3"]) + + @patch("src.rosa.tools.ros1.rosnode.get_node_info_description") + def test_rosnode_info(self, mock_get_node_info_description): + mock_get_node_info_description.return_value = ( + "Node: /turtlesim\nPublications: /turtle1/cmd_vel" + ) + result = rosnode_info.invoke({"nodes": ["/turtlesim"]}) + self.assertIn("/turtlesim", result) + self.assertIn("Node: /turtlesim", result["/turtlesim"]) + + @patch("src.rosa.tools.ros1.rosservice.get_service_list") + def test_rosservice_list(self, mock_get_service_list): + mock_get_service_list.return_value = ["/clear", "/reset"] + result = rosservice_list.invoke({}) + self.assertIn("/clear", result) + self.assertIn("/reset", result) + + @patch("src.rosa.tools.ros1.rosservice.get_service_headers") + @patch("src.rosa.tools.ros1.rosservice.get_service_uri") + def test_rosservice_info(self, mock_get_service_uri, mock_get_service_headers): + mock_get_service_uri.return_value = "rosrpc://localhost:12345" + mock_get_service_headers.return_value = {"callerid": "/turtlesim"} + result = rosservice_info.invoke({"services": ["/clear"]}) + self.assertIn("/clear", result) + self.assertIn("callerid", result["/clear"]) + + @patch("src.rosa.tools.ros1.rosservice.call_service") + def test_rosservice_call(self, mock_call_service): + mock_call_service.return_value = "success" + result = rosservice_call.invoke({"service": "/clear", "args": []}) + self.assertEqual(result, "success") + + @patch("src.rosa.tools.ros1.rosmsg.get_msg_text") + def test_rosmsg_info(self, mock_get_msg_text): + mock_get_msg_text.return_value = "string data" + result = rosmsg_info.invoke({"msg_type": ["std_msgs/String"]}) + self.assertIn("std_msgs/String", result) + self.assertEqual(result["std_msgs/String"], "string data") + + @patch("src.rosa.tools.ros1.rosmsg.get_srv_text") + def test_rossrv_info(self, mock_get_srv_text): + mock_get_srv_text.return_value = "string data" + result = rossrv_info.invoke({"srv_type": ["std_srvs/Empty"]}) + self.assertIn("std_srvs/Empty", result) + self.assertEqual(result["std_srvs/Empty"], "string data") + + @patch("src.rosa.tools.ros1.rosparam.list_params") + def test_rosparam_list(self, mock_list_params): + mock_list_params.return_value = [ + "/turtlesim/background_r", + "/turtlesim/background_g", + ] + result = rosparam_list.invoke({}) + self.assertIn("/turtlesim/background_r", result["ros_params"]) + self.assertIn("/turtlesim/background_g", result["ros_params"]) + + @patch("src.rosa.tools.ros1.rosparam.get_param") + def test_rosparam_get(self, mock_get_param): + mock_get_param.return_value = 255 + result = rosparam_get.invoke({"params": ["/turtlesim/background_r"]}) + self.assertIn("/turtlesim/background_r", result) + self.assertEqual(result["/turtlesim/background_r"], 255) + + @patch("src.rosa.tools.ros1.rosparam.set_param") + def test_rosparam_set(self, mock_set_param): + result = rosparam_set.invoke( + {"param": "/turtlesim/background_r", "value": "255", "is_rosa_param": False} + ) + self.assertEqual(result, "Set parameter '/turtlesim/background_r' to '255'.") + + @patch("src.rosa.tools.ros1.rospkg.RosPack.list") + def test_rospkg_list(self, mock_list): + mock_list.return_value = ["turtlesim", "std_msgs"] + result = rospkg_list.invoke({"ignore_msgs": True}) + self.assertIn("turtlesim", result["packages"]) + self.assertNotIn("std_msgs", result["packages"]) + + result = rospkg_list.invoke({"ignore_msgs": False}) + self.assertIn("turtlesim", result["packages"]) + self.assertIn("std_msgs", result["packages"]) + + @patch("src.rosa.tools.ros1.rospkg.get_ros_package_path") + def test_rospkg_roots(self, mock_get_ros_package_path): + mock_get_ros_package_path.return_value = ["/opt/ros/noetic/share"] + result = rospkg_roots.invoke({}) + self.assertIn("/opt/ros/noetic/share", result) + + @patch("src.rosa.tools.ros1.get_roslog_directories") + @patch("os.listdir") + @patch("os.path.isfile") + @patch("os.path.getsize") + def test_roslog_list_with_min_size( + self, mock_getsize, mock_isfile, mock_listdir, mock_get_roslog_directories + ): + mock_get_roslog_directories.return_value = {"default": "/mock/log/dir"} + mock_listdir.return_value = ["log1.log", "log2.log", "log3.log"] + mock_isfile.side_effect = lambda x: x.endswith(".log") + mock_getsize.side_effect = lambda x: 3000 if "log1.log" in x else 1000 + + result = roslog_list.invoke({"min_size": 2048}) + + self.assertEqual(result["total"], 1) + self.assertEqual(len(result["logs"]), 1) + self.assertEqual(result["logs"][0]["total"], 1) + self.assertIn("/log1.log", result["logs"][0]["files"][0]) + + @patch("src.rosa.tools.ros1.get_roslog_directories") + @patch("os.listdir") + @patch("os.path.isfile") + @patch("os.path.getsize") + def test_roslog_list_with_blacklist( + self, mock_getsize, mock_isfile, mock_listdir, mock_get_roslog_directories + ): + mock_get_roslog_directories.return_value = {"default": "/mock/log/dir"} + mock_listdir.return_value = ["log1.log", "log2.log", "log3.log"] + mock_isfile.side_effect = lambda x: x.endswith(".log") + mock_getsize.side_effect = lambda x: 3000 + + result = roslog_list.invoke({"blacklist": ["log2"]}) + + self.assertEqual(result["total"], 1) + self.assertEqual(len(result["logs"]), 1) + self.assertEqual(result["logs"][0]["total"], 2) + self.assertNotIn("log2.log", result["logs"][0]["files"][0]) + + @patch("src.rosa.tools.ros1.get_roslog_directories") + @patch("os.listdir") + @patch("os.path.isfile") + @patch("os.path.getsize") + def test_roslog_list_no_logs( + self, mock_getsize, mock_isfile, mock_listdir, mock_get_roslog_directories + ): + mock_get_roslog_directories.return_value = {"default": "/mock/log/dir"} + mock_listdir.return_value = [] + mock_isfile.side_effect = lambda x: x.endswith(".log") + mock_getsize.side_effect = lambda x: 3000 + + result = roslog_list.invoke({}) + + self.assertEqual(result["total"], 0) + self.assertEqual(len(result["logs"]), 0) + + @patch("src.rosa.tools.ros1.get_roslog_directories") + @patch("os.listdir") + @patch("os.path.isfile") + @patch("os.path.getsize") + def test_roslog_list_with_multiple_directories( + self, mock_getsize, mock_isfile, mock_listdir, mock_get_roslog_directories + ): + mock_get_roslog_directories.return_value = { + "default": "/mock/log/dir1", + "latest": "/mock/log/dir2", + } + mock_listdir.side_effect = lambda x: ( + ["log1.log", "log2.log"] if "dir1" in x else ["log3.log", "log4.log"] + ) + mock_isfile.side_effect = lambda x: x.endswith(".log") + mock_getsize.side_effect = lambda x: 3000 + + result = roslog_list.invoke({}) + + self.assertEqual(result["total"], 2) + self.assertEqual(len(result["logs"]), 2) + self.assertEqual(result["logs"][0]["total"], 2) + self.assertEqual(result["logs"][1]["total"], 2) + + @patch("rospy.loginfo") + @patch("src.rosa.tools.ros1.get_entities") + def test_rostopic_list_returns_all_topics(self, mock_get_entities, mock_loginfo): + mock_get_entities.return_value = (10, 10, 10, ["topic1", "topic2"]) + result = rostopic_list.invoke({}) + self.assertEqual(result["total"], 10) + self.assertEqual(result["in_namespace"], 10) + self.assertEqual(result["match_pattern"], 10) + self.assertEqual(result["topics"], ["topic1", "topic2"]) + + @patch("rospy.loginfo") + @patch("src.rosa.tools.ros1.get_entities") + def test_rostopic_list_with_pattern(self, mock_get_entities, mock_loginfo): + mock_get_entities.return_value = (10, 10, 2, ["topic1", "topic2"]) + result = rostopic_list.invoke({"pattern": "topic"}) + self.assertEqual(result["match_pattern"], 2) + self.assertEqual(result["topics"], ["topic1", "topic2"]) + + @patch("rospy.loginfo") + @patch("src.rosa.tools.ros1.get_entities") + def test_rostopic_list_with_namespace(self, mock_get_entities, mock_loginfo): + mock_get_entities.return_value = ( + 10, + 5, + 5, + ["namespace/topic1", "namespace/topic2"], + ) + result = rostopic_list.invoke({"namespace": "namespace"}) + self.assertEqual(result["in_namespace"], 5) + self.assertEqual(result["topics"], ["namespace/topic1", "namespace/topic2"]) + + @patch("rospy.loginfo") + @patch("src.rosa.tools.ros1.get_entities") + def test_rostopic_list_with_blacklist(self, mock_get_entities, mock_loginfo): + mock_get_entities.return_value = (2, 2, 2, ["topic1", "topic2"]) + result = rostopic_list.invoke({"blacklist": ["topic2"]}) + self.assertEqual(result["topics"], ["topic1"]) + + @patch("rospy.loginfo") + @patch("src.rosa.tools.ros1.get_entities") + def test_rostopic_list_no_topics_available(self, mock_get_entities, mock_loginfo): + mock_get_entities.return_value = ( + 0, + 0, + 0, + ["There are currently no topics available in the system."], + ) + result = rostopic_list.invoke({}) + self.assertEqual(result["total"], 0) + self.assertEqual( + result["topics"], ["There are currently no topics available in the system."] + ) + + @patch("src.rosa.tools.ros1.rostopic.get_topic_list") + @patch("src.rosa.tools.ros1.rosnode.get_node_names") + def test_get_entities_topics(self, mock_get_node_names, mock_get_topic_list): + mock_get_topic_list.return_value = ( + [(f"/topic{i}", "type") for i in range(5)], + [(f"/topic{i}", "type") for i in range(5, 10)], + ) + total, in_namespace, match_pattern, entities = get_entities("topic", None, None) + self.assertEqual(total, 10) + self.assertEqual(in_namespace, 10) + self.assertEqual(match_pattern, 10) + self.assertEqual(len(entities), 10) + + @patch("src.rosa.tools.ros1.rostopic.get_topic_list") + @patch("src.rosa.tools.ros1.rosnode.get_node_names") + def test_get_entities_nodes(self, mock_get_node_names, mock_get_topic_list): + mock_get_node_names.return_value = [f"/node{i}" for i in range(10)] + total, in_namespace, match_pattern, entities = get_entities("node", None, None) + self.assertEqual(total, 10) + self.assertEqual(in_namespace, 10) + self.assertEqual(match_pattern, 10) + self.assertEqual(len(entities), 10) + + @patch("src.rosa.tools.ros1.rostopic.get_topic_list") + @patch("src.rosa.tools.ros1.rosnode.get_node_names") + def test_get_entities_with_namespace( + self, mock_get_node_names, mock_get_topic_list + ): + mock_get_topic_list.return_value = ( + [(f"/namespace/topic{i}", "type") for i in range(5)] + + [(f"/topic{i}", "type") for i in range(5)], + [(f"/namespace/topic{i}", "type") for i in range(5, 10)], + ) + + mock_get_node_names.return_value = [ + f"/namespace/node{i}" for i in range(10) + ] + [f"/node{i}" for i in range(10)] + + total, in_namespace, match_pattern, entities = get_entities( + "topic", None, "/namespace" + ) + self.assertEqual(total, 15) + self.assertEqual(in_namespace, 10) + self.assertEqual(match_pattern, 10) + self.assertEqual(len(entities), 10) + + total, in_namespace, match_pattern, entities = get_entities( + "node", None, "/namespace" + ) + self.assertEqual(total, 20) + self.assertEqual(in_namespace, 10) + self.assertEqual(match_pattern, 10) + self.assertEqual(len(entities), 10) + + @patch("src.rosa.tools.ros1.rostopic.get_topic_list") + @patch("src.rosa.tools.ros1.rosnode.get_node_names") + def test_get_entities_with_pattern(self, mock_get_node_names, mock_get_topic_list): + mock_get_topic_list.return_value = ( + [(f"/topic{i}", "type") for i in range(5)], + [(f"/topic{i}", "type") for i in range(5, 10)], + ) + total, in_namespace, match_pattern, entities = get_entities( + "topic", "topic[0-4]", None + ) + self.assertEqual(total, 10) + self.assertEqual(in_namespace, 10) + self.assertEqual(match_pattern, 5) + self.assertEqual(len(entities), 5) + + @patch("src.rosa.tools.ros1.rostopic.get_topic_list") + @patch("src.rosa.tools.ros1.rosnode.get_node_names") + def test_get_entities_with_blacklist( + self, mock_get_node_names, mock_get_topic_list + ): + mock_get_topic_list.return_value = ( + [(f"/topic{i}", "type") for i in range(5)], + [(f"/topic{i}", "type") for i in range(5, 10)], + ) + total, in_namespace, match_pattern, entities = get_entities( + "topic", None, None, blacklist=["/topic0", "/topic1"] + ) + self.assertEqual(total, 10) + self.assertEqual(in_namespace, 10) + self.assertEqual(match_pattern, 10) + self.assertEqual(len(entities), 8) + + @patch("src.rosa.tools.ros1.rostopic.get_topic_list") + @patch("src.rosa.tools.ros1.rosnode.get_node_names") + def test_get_entities_no_entities(self, mock_get_node_names, mock_get_topic_list): + mock_get_topic_list.return_value = ([], []) + total, in_namespace, match_pattern, entities = get_entities("topic", None, None) + self.assertEqual(total, 0) + self.assertEqual(in_namespace, 0) + self.assertEqual(match_pattern, 0) + self.assertEqual( + entities, ["There are currently no topics available in the system."] + ) + + @patch("src.rosa.tools.ros1.rostopic.get_topic_list") + @patch("src.rosa.tools.ros1.rosnode.get_node_names") + def test_get_entities_no_namespace_entities( + self, mock_get_node_names, mock_get_topic_list + ): + mock_get_topic_list.return_value = ( + [(f"/topic{i}", "type") for i in range(5)], + [(f"/topic{i}", "type") for i in range(5, 10)], + ) + total, in_namespace, match_pattern, entities = get_entities( + "topic", None, "/nonexistent" + ) + self.assertEqual(total, 10) + self.assertEqual(in_namespace, 0) + self.assertEqual(match_pattern, 0) + self.assertEqual( + entities, + [ + "There are currently no topics available using the '/nonexistent' namespace." + ], + ) + + @patch("src.rosa.tools.ros1.rostopic.get_topic_list") + @patch("src.rosa.tools.ros1.rosnode.get_node_names") + def test_get_entities_no_pattern_entities( + self, mock_get_node_names, mock_get_topic_list + ): + mock_get_topic_list.return_value = ( + [(f"/topic{i}", "type") for i in range(5)], + [(f"/topic{i}", "type") for i in range(5, 10)], + ) + total, in_namespace, match_pattern, entities = get_entities( + "topic", "nonexistent", None + ) + self.assertEqual(total, 10) + self.assertEqual(in_namespace, 10) + self.assertEqual(match_pattern, 0) + self.assertEqual( + entities, + ["There are currently no topics available matching the specified pattern."], + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/rosa/tools/test_ros2.py b/tests/rosa/tools/test_ros2.py new file mode 100644 index 0000000..2dc4cf5 --- /dev/null +++ b/tests/rosa/tools/test_ros2.py @@ -0,0 +1,290 @@ +# Copyright (c) 2024. Jet Propulsion Laboratory. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import subprocess +import unittest +from unittest.mock import patch + +try: + from src.rosa.tools.ros2 import ( + execute_ros_command, + ros2_node_list, + ros2_topic_list, + ros2_topic_echo, + ros2_service_list, + ros2_node_info, + ros2_param_list, + ros2_param_get, + ros2_param_set, + ) +except ModuleNotFoundError: + pass + + +@unittest.skipIf( + os.environ.get("ROS_VERSION") == "1", + "Skipping ROS2 tests because ROS_VERSION is set to 1", +) +class TestROS2Tools(unittest.TestCase): + + @patch("src.rosa.tools.ros2.subprocess.check_output") + def test_execute_valid_ros2_command(self, mock_check_output): + mock_check_output.return_value = b"Node /example_node\n" + success, output = execute_ros_command("ros2 node list") + self.assertTrue(success) + self.assertEqual(output, "Node /example_node\n") + + @patch("src.rosa.tools.ros2.subprocess.check_output") + def test_execute_invalid_ros2_command(self, mock_check_output): + mock_check_output.side_effect = subprocess.CalledProcessError( + 1, "ros2 node list" + ) + success, output = execute_ros_command("ros2 node list") + self.assertFalse(success) + self.assertIn( + "Command 'ros2 node list' returned non-zero exit status 1.", output + ) + + def test_execute_command_with_invalid_prefix(self): + with self.assertRaises(ValueError): + execute_ros_command("invalid node list") + + def test_execute_command_with_invalid_subcommand(self): + with self.assertRaises(ValueError): + execute_ros_command("ros2 invalid_subcommand") + + def test_execute_command_with_insufficient_arguments(self): + with self.assertRaises(ValueError): + execute_ros_command("ros2") + + @patch("src.rosa.tools.ros2.execute_ros_command") + def test_ros2_node_list_returns_nodes(self, mock_execute): + mock_execute.return_value = (True, "/node1\n/node2\n") + result = ros2_node_list.invoke({"pattern": None}) + self.assertEqual(result, {"nodes": ["/node1", "/node2"]}) + + @patch("src.rosa.tools.ros2.execute_ros_command") + def test_ros2_node_list_with_pattern(self, mock_execute): + mock_execute.return_value = (True, "/node1\n/node2\n") + result = ros2_node_list.invoke({"pattern": "node1"}) + self.assertEqual(result, {"nodes": ["/node1"]}) + + @patch("src.rosa.tools.ros2.execute_ros_command") + def test_ros2_node_list_with_blacklist(self, mock_execute): + mock_execute.return_value = (True, "/node1\n/node2\n") + result = ros2_node_list.invoke({"blacklist": ["node2"]}) + self.assertEqual(result, {"nodes": ["/node1"]}) + + @patch("src.rosa.tools.ros2.execute_ros_command") + def test_ros2_node_list_invalid_command(self, mock_execute): + mock_execute.return_value = (False, "Invalid command") + result = ros2_node_list.invoke({"pattern": None}) + self.assertEqual(result, {"nodes": ["Invalid command"]}) + + @patch("src.rosa.tools.ros2.execute_ros_command") + def test_ros2_topic_list_returns_topics(self, mock_execute): + mock_execute.return_value = (True, "/topic1\n/topic2\n") + result = ros2_topic_list.invoke({"pattern": None}) + self.assertEqual(result, {"topics": ["/topic1", "/topic2"]}) + + @patch("src.rosa.tools.ros2.execute_ros_command") + def test_ros2_topic_list_with_pattern(self, mock_execute): + mock_execute.return_value = (True, "/topic1\n/topic2\n") + result = ros2_topic_list.invoke({"pattern": "topic1"}) + self.assertEqual(result, {"topics": ["/topic1"]}) + + @patch("src.rosa.tools.ros2.execute_ros_command") + def test_ros2_topic_list_with_blacklist(self, mock_execute): + mock_execute.return_value = (True, "/topic1\n/topic2\n") + result = ros2_topic_list.invoke({"blacklist": ["topic2"]}) + self.assertEqual(result, {"topics": ["/topic1"]}) + + @patch("src.rosa.tools.ros2.execute_ros_command") + def test_ros2_topic_list_invalid_command(self, mock_execute): + mock_execute.return_value = (False, "Invalid command") + result = ros2_topic_list.invoke({"pattern": None}) + self.assertEqual(result, {"topics": ["Invalid command"]}) + + @patch("src.rosa.tools.ros2.execute_ros_command") + def test_ros2_topic_echo_success(self, mock_execute): + mock_execute.return_value = (True, "Message 1\n") + result = ros2_topic_echo.invoke( + {"topic": "/example_topic", "count": 1, "return_echoes": True} + ) + self.assertEqual(result, {"echoes": ["Message 1\n"]}) + + @patch("src.rosa.tools.ros2.execute_ros_command") + def test_ros2_topic_echo_multiple_messages(self, mock_execute): + mock_execute.return_value = (True, "Message 1\n") + result = ros2_topic_echo.invoke( + {"topic": "/example_topic", "count": 3, "return_echoes": True} + ) + self.assertEqual( + result, {"echoes": ["Message 1\n", "Message 1\n", "Message 1\n"]} + ) + + @patch("src.rosa.tools.ros2.execute_ros_command") + def test_ros2_topic_echo_invalid_topic(self, mock_execute): + mock_execute.return_value = (False, "Invalid topic") + result = ros2_topic_echo.invoke({"topic": "/invalid_topic", "count": 1}) + self.assertEqual(result, {"error": "Invalid topic"}) + + @patch("src.rosa.tools.ros2.execute_ros_command") + def test_ros2_topic_echo_invalid_count(self, mock_execute): + result = ros2_topic_echo.invoke({"topic": "/example_topic", "count": 11}) + self.assertEqual(result, {"error": "Count must be between 1 and 10."}) + + @patch("src.rosa.tools.ros2.execute_ros_command") + def test_ros2_topic_echo_command_failure(self, mock_execute): + mock_execute.return_value = (False, "Command failed") + result = ros2_topic_echo.invoke({"topic": "/example_topic", "count": 1}) + self.assertEqual(result, {"error": "Command failed"}) + + @patch("src.rosa.tools.ros2.execute_ros_command") + def test_ros2_service_list_returns_services(self, mock_execute): + mock_execute.return_value = (True, "/service1\n/service2\n") + result = ros2_service_list.invoke({"pattern": None}) + self.assertEqual(result, {"services": ["/service1", "/service2"]}) + + @patch("src.rosa.tools.ros2.execute_ros_command") + def test_ros2_service_list_with_pattern(self, mock_execute): + mock_execute.return_value = (True, "/service1\n/service2\n") + result = ros2_service_list.invoke({"pattern": "service1"}) + self.assertEqual(result, {"services": ["/service1"]}) + + @patch("src.rosa.tools.ros2.execute_ros_command") + def test_ros2_service_list_with_blacklist(self, mock_execute): + mock_execute.return_value = (True, "/service1\n/service2\n") + result = ros2_service_list.invoke({"blacklist": ["service2"]}) + self.assertEqual(result, {"services": ["/service1"]}) + + @patch("src.rosa.tools.ros2.execute_ros_command") + def test_ros2_service_list_invalid_command(self, mock_execute): + mock_execute.return_value = (False, "Invalid command") + result = ros2_service_list.invoke({"pattern": None}) + self.assertEqual(result, {"services": ["Invalid command"]}) + + @patch("src.rosa.tools.ros2.execute_ros_command") + def test_ros2_node_info_success(self, mock_execute): + mock_execute.return_value = (True, "Node info for /node1") + result = ros2_node_info.invoke({"nodes": ["/node1"]}) + self.assertEqual(result, {"/node1": "Node info for /node1"}) + + @patch("src.rosa.tools.ros2.execute_ros_command") + def test_ros2_node_info_multiple_nodes(self, mock_execute): + mock_execute.side_effect = [ + (True, "Node info for /node1"), + (True, "Node info for /node2"), + ] + result = ros2_node_info.invoke({"nodes": ["/node1", "/node2"]}) + self.assertEqual( + result, {"/node1": "Node info for /node1", "/node2": "Node info for /node2"} + ) + + @patch("src.rosa.tools.ros2.execute_ros_command") + def test_ros2_node_info_invalid_node(self, mock_execute): + mock_execute.return_value = (False, "Invalid node") + result = ros2_node_info.invoke({"nodes": ["/invalid_node"]}) + self.assertEqual(result, {"/invalid_node": {"error": "Invalid node"}}) + + @patch("src.rosa.tools.ros2.execute_ros_command") + def test_ros2_node_info_command_failure(self, mock_execute): + mock_execute.return_value = (False, "Command failed") + result = ros2_node_info.invoke({"nodes": ["/node1"]}) + self.assertEqual(result, {"/node1": {"error": "Command failed"}}) + + @patch("src.rosa.tools.ros2.execute_ros_command") + def test_ros2_param_list_returns_params_for_node(self, mock_execute): + mock_execute.return_value = (True, "param1\nparam2\n") + result = ros2_param_list.invoke({"node_name": "/example_node"}) + self.assertEqual(result, {"/example_node": ["param1", "param2"]}) + + @patch("src.rosa.tools.ros2.execute_ros_command") + def test_ros2_param_list_returns_all_params(self, mock_execute): + mock_execute.return_value = ( + True, + "/node1\n param1\n param2\n/node2\n param3\n", + ) + result = ros2_param_list.invoke({}) + self.assertEqual(result, {"/node1": ["param1", "param2"], "/node2": ["param3"]}) + + @patch("src.rosa.tools.ros2.execute_ros_command") + def test_ros2_param_list_with_pattern(self, mock_execute): + mock_execute.return_value = (True, "param1\nparam2\n") + result = ros2_param_list.invoke( + {"node_name": "/example_node", "pattern": "param1"} + ) + self.assertEqual(result, {"/example_node": ["param1"]}) + + @patch("src.rosa.tools.ros2.execute_ros_command") + def test_ros2_param_list_with_blacklist(self, mock_execute): + mock_execute.return_value = (True, "param1\nparam2\n") + result = ros2_param_list.invoke( + {"node_name": "/example_node", "blacklist": ["param2"]} + ) + self.assertEqual(result, {"/example_node": ["param1"]}) + + @patch("src.rosa.tools.ros2.execute_ros_command") + def test_ros2_param_list_invalid_command(self, mock_execute): + mock_execute.return_value = (False, "Invalid command") + result = ros2_param_list.invoke({"node_name": "/example_node"}) + self.assertEqual(result, {"error": "Invalid command"}) + + @patch("src.rosa.tools.ros2.execute_ros_command") + def test_ros2_param_get_success(self, mock_execute): + mock_execute.return_value = (True, "value1") + result = ros2_param_get.invoke( + {"node_name": "/example_node", "param_name": "param1"} + ) + self.assertEqual(result, {"param1": "value1"}) + + @patch("src.rosa.tools.ros2.execute_ros_command") + def test_ros2_param_get_invalid_command(self, mock_execute): + mock_execute.return_value = (False, "Invalid command") + result = ros2_param_get.invoke( + {"node_name": "/example_node", "param_name": "param1"} + ) + self.assertEqual(result, {"error": "Invalid command"}) + + @patch("src.rosa.tools.ros2.execute_ros_command") + def test_ros2_param_set_success(self, mock_execute): + mock_execute.return_value = (True, "value1") + result = ros2_param_set.invoke( + { + "node_name": "/example_node", + "param_name": "param1", + "param_value": "value1", + } + ) + self.assertEqual(result, {"param1": "value1"}) + + @patch("src.rosa.tools.ros2.execute_ros_command") + def test_ros2_param_set_invalid_command(self, mock_execute): + mock_execute.return_value = (False, "Invalid command") + result = ros2_param_set.invoke( + { + "node_name": "/example_node", + "param_name": "param1", + "param_value": "value1", + } + ) + self.assertEqual(result, {"error": "Invalid command"}) + + +if __name__ == "__main__": + import os + + if os.environ.get("ROS_VERSION") == 2: + unittest.main() diff --git a/tests/rosa/tools/test_rosa_tools.py b/tests/rosa/tools/test_rosa_tools.py new file mode 100644 index 0000000..4cecdf3 --- /dev/null +++ b/tests/rosa/tools/test_rosa_tools.py @@ -0,0 +1,110 @@ +# Copyright (c) 2024. Jet Propulsion Laboratory. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import unittest +from unittest.mock import patch + +from langchain.agents import tool + +from src.rosa.tools import ROSATools, inject_blacklist + + +@tool +def sample_tool(blacklist=None): + """A sample tool that returns the blacklist.""" + return blacklist + + +class TestROSATools(unittest.TestCase): + def setUp(self): + self.ros_version = int(os.getenv("ROS_VERSION", 1)) + + def test_initializes_with_ros_version_1(self): + if self.ros_version == 1: + tools = ROSATools(ros_version=1) + self.assertEqual(tools._ROSATools__ros_version, 1) + else: + with self.assertRaises(ModuleNotFoundError): + tools = ROSATools(ros_version=1) + self.assertEqual(tools._ROSATools__ros_version, 1) + + def test_initializes_with_ros_version_2(self): + if self.ros_version == 2: + tools = ROSATools(ros_version=2) + self.assertEqual(tools._ROSATools__ros_version, 2) + else: + with self.assertRaises(ModuleNotFoundError): + tools = ROSATools(ros_version=2) + self.assertEqual(tools._ROSATools__ros_version, 2) + + def test_raises_value_error_for_invalid_ros_version(self): + if self.ros_version == 1: + with self.assertRaises(ModuleNotFoundError): + ROSATools(ros_version=2) + else: + with self.assertRaises(ModuleNotFoundError): + ROSATools(ros_version=1) + + @patch("src.rosa.tools.calculation") + @patch("src.rosa.tools.log") + @patch("src.rosa.tools.system") + def test_adds_default_tools(self, mock_system, mock_log, mock_calculation): + if self.ros_version == 1: + tools = ROSATools(ros_version=1) + else: + tools = ROSATools(ros_version=2) + self.assertIn(mock_calculation.return_value, tools.get_tools()) + self.assertIn(mock_log.return_value, tools.get_tools()) + self.assertIn(mock_system.return_value, tools.get_tools()) + + def test_injects_blacklist_into_tool_function(self): + def sample_tool(blacklist=None): + return blacklist + + decorated_tool = inject_blacklist(["item1", "item2"])(sample_tool) + self.assertEqual(decorated_tool(), ["item1", "item2"]) + + def test_blacklist_gets_concatenated(self): + decorated_tool = inject_blacklist(["item1", "item2"])(sample_tool) + self.assertEqual( + decorated_tool({"blacklist": ["item3"]}), + ["item1", "item2", "item3"], + ) + + +@unittest.skipIf(os.environ.get("ROS_VERSION") == "2", "Skipping ROS 1 tests") +class TestROSA1Tools(unittest.TestCase): + @patch("src.rosa.tools.ros1") + def test_ros1_tools(self, mock_ros1): + tools = ROSATools(ros_version=1) + self.assertIn(mock_ros1.return_value, tools.get_tools()) + with self.assertRaises(ModuleNotFoundError): + tools = ROSATools(ros_version=2) + self.assertIn(mock_ros1.return_value, tools.get_tools()) + + +@unittest.skipIf(os.environ.get("ROS_VERSION") == "1", "Skipping ROS 2 tests") +class TestROSA2Tools(unittest.TestCase): + @patch("src.rosa.tools.ros2") + def test_ros2_tools(self, mock_ros2): + tools = ROSATools(ros_version=2) + self.assertIn(mock_ros2.return_value, tools.get_tools()) + with self.assertRaises(ModuleNotFoundError): + tools = ROSATools(ros_version=1) + self.assertIn(mock_ros2.return_value, tools.get_tools()) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/rosa/tools/test_system.py b/tests/rosa/tools/test_system.py new file mode 100644 index 0000000..55924a1 --- /dev/null +++ b/tests/rosa/tools/test_system.py @@ -0,0 +1,60 @@ +# Copyright (c) 2024. Jet Propulsion Laboratory. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +import unittest + +from langchain.globals import get_debug, get_verbose, set_debug + +from src.rosa.tools.system import set_verbosity, set_debuging, wait + + +class TestSystemTools(unittest.TestCase): + + def test_sets_verbosity_to_true(self): + result = set_verbosity.invoke({"enable_verbose_messages": True}) + self.assertEqual(result, "Verbose messages are now enabled.") + self.assertTrue(get_verbose()) + result = set_verbosity.invoke({"enable_verbose_messages": False}) + self.assertEqual(result, "Verbose messages are now disabled.") + self.assertFalse(get_verbose()) + + def test_sets_debug_to_true(self): + result = set_debuging.invoke({"enable_debug_messages": True}) + self.assertEqual(result, "Debug messages are now enabled.") + self.assertTrue(get_debug()) + set_debug(False) + result = set_debuging.invoke({"enable_debug_messages": False}) + self.assertEqual(result, "Debug messages are now disabled.") + self.assertFalse(get_debug()) + + def test_waits_for_specified_seconds(self): + start = time.time() + result = wait.invoke({"seconds": 1.0}) + end = time.time() + + self.assertTrue(result.startswith("Waited exactly")) + self.assertAlmostEqual(end - start, 1.0, places=1) + + def test_waits_for_zero_seconds(self): + start = time.time() + result = wait.invoke({"seconds": 0}) + end = time.time() + + self.assertTrue(result.startswith("Waited exactly")) + self.assertAlmostEqual(end - start, 0.0, places=1) + + +if __name__ == "__main__": + unittest.main()