diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
new file mode 100644
index 0000000..0729c58
--- /dev/null
+++ b/.github/workflows/ci.yml
@@ -0,0 +1,70 @@
+name: CI Pipeline
+
+on:
+ push:
+ branches:
+ - main
+ - dev
+ pull_request:
+ branches:
+ - main
+ - dev
+
+jobs:
+ test-noetic:
+ runs-on: ubuntu-latest
+ container:
+ image: osrf/ros:noetic-desktop
+ steps:
+ - name: Checkout code
+ uses: actions/checkout@v2
+
+ - name: Set up Python
+ uses: actions/setup-python@v2
+ with:
+ python-version: '3.x'
+
+ - name: Install dependencies
+ run: |
+ sudo apt-get update
+ sudo apt-get install -y libc6 libc6-dev
+ sudo apt-get install -y python3.9
+ sudo apt-get install -y python3-pip
+ python3.9 -m pip install --user -e .
+ shell: bash
+
+ - name: Run tests
+ run: |
+ . /opt/ros/noetic/setup.bash
+ python3.9 -m unittest discover -s tests --verbose
+ shell: bash
+ env:
+ ROS_VERSION: 1
+
+ test-humble:
+ runs-on: ubuntu-latest
+ container:
+ image: osrf/ros:humble-desktop
+ steps:
+ - name: Checkout code
+ uses: actions/checkout@v2
+
+ - name: Set up Python
+ uses: actions/setup-python@v2
+ with:
+ python-version: '3.10'
+
+ - name: Install dependencies
+ run: |
+ sudo apt-get update
+ sudo apt-get install -y python3-pip
+ python3.10 -m pip install --user -e .
+ shell: bash
+
+ - name: Run tests
+ run: |
+ . /opt/ros/humble/setup.bash
+ python3.10 -m unittest discover -s tests --verbose
+ shell: bash
+ env:
+ ROS_VERSION: 2
diff --git a/CHANGELOG.md b/CHANGELOG.md
index f4ecc9d..b0c2b67 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,21 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
+## [1.0.5]
+
+### Added
+
+* CI pipeline for automated testing
+* Unit tests for ROSA tools and utilities
+
+### Changed
+
+* Improvements to various ROS2 tools
+* Upgrade dependencies:
+ * `langchain` to 0.2.14
+ * `langchain_core` to 0.2.34
+ * `langchain-openai` to 0.1.22
+
## [1.0.4] - 2024-08-21
### Added
diff --git a/Dockerfile b/Dockerfile
index b26d22f..cff6308 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -28,7 +28,7 @@ RUN apt-get update && apt-get install -y \
RUN apt-get update && apt-get install -y python3.9
RUN apt-get update && apt-get install -y python3-pip
RUN python3 -m pip install -U python-dotenv catkin_tools
-RUN python3.9 -m pip install -U jpl-rosa>=1.0.4
+RUN python3.9 -m pip install -U jpl-rosa>=1.0.5
# Configure ROS
RUN rosdep update
diff --git a/README.md b/README.md
index 86cd2d3..e819ae3 100644
--- a/README.md
+++ b/README.md
@@ -7,17 +7,26 @@
ROSA - Robot Operating System Agent
-ROSA is an AI Agent designed to interact with ROS-based robotics systems using natural language queries.
-
+
+ ROSA is an AI Agent designed to interact with ROS-based robotics systems
using natural language queries.
+
+
+
-![GitHub Actions Workflow Status](https://img.shields.io/github/actions/workflow/status/nasa-jpl/rosa/publish.yml)
-![Static Badge](https://img.shields.io/badge/Python->=3.9-blue)
-![Static Badge](https://img.shields.io/badge/ROS_1-Supported-blue)
-![Static Badge](https://img.shields.io/badge/ROS_2-Supported-blue)
+![Static Badge](https://img.shields.io/badge/ROS_1-Noetic-blue)
+![Static Badge](https://img.shields.io/badge/ROS_2-Humble|Iron|Jazzy-blue)
![PyPI - License](https://img.shields.io/pypi/l/jpl-rosa)
+[![SLIM](https://img.shields.io/badge/Best%20Practices%20from-SLIM-blue)](https://nasa-ammos.github.io/slim/)
+
+![GitHub Actions Workflow Status](https://img.shields.io/github/actions/workflow/status/nasa-jpl/rosa/ci.yml?branch=main&label=main)
+![GitHub Actions Workflow Status](https://img.shields.io/github/actions/workflow/status/nasa-jpl/rosa/ci.yml?branch=dev&label=dev)
+![GitHub Actions Workflow Status](https://img.shields.io/github/actions/workflow/status/nasa-jpl/rosa/publish.yml?label=publish)
![PyPI - Version](https://img.shields.io/pypi/v/jpl-rosa)
![PyPI - Downloads](https://img.shields.io/pypi/dw/jpl-rosa)
-[![SLIM](https://img.shields.io/badge/Best%20Practices%20from-SLIM-blue)](https://nasa-ammos.github.io/slim/)
+
+
+
+
ROSA is an AI agent that can be used to interact with ROS1 _and_ ROS2 systems in order to carry out various tasks.
It is built using the open-source [Langchain](https://python.langchain.com/v0.2/docs/introduction/) framework, and can
@@ -90,9 +99,11 @@ rosa.invoke("Show me a list of topics that have publishers but no subscribers")
## TurtleSim Demo
-We have included a demo that uses ROSA to control the TurtleSim robot in simulation. To run the demo, you will need to have Docker installed on your machine.
+We have included a demo that uses ROSA to control the TurtleSim robot in simulation. To run the demo, you will need to
+have Docker installed on your machine.
-The following video shows ROSA reasoning about how to draw a 5-point star, then executing the necessary commands to do so.
+The following video shows ROSA reasoning about how to draw a 5-point star, then executing the necessary commands to do
+so.
https://github.com/user-attachments/assets/77b97014-6d2e-4123-8d0b-ea0916d93a4e
diff --git a/setup.py b/setup.py
index 926ba7a..f3e71a3 100644
--- a/setup.py
+++ b/setup.py
@@ -23,7 +23,7 @@
setup(
name="jpl-rosa",
- version="1.0.4",
+ version="1.0.5",
license="Apache 2.0",
description="ROSA: the Robot Operating System Agent",
long_description=long_description,
@@ -49,10 +49,10 @@
install_requires=[
"PyYAML==6.0.1",
"python-dotenv>=1.0.1",
- "langchain==0.2.13",
+ "langchain==0.2.14",
"langchain-community==0.2.12",
- "langchain-core==0.2.32",
- "langchain-openai==0.1.21",
+ "langchain-core==0.2.34",
+ "langchain-openai==0.1.22",
"pydantic",
"pyinputplus",
"azure-identity",
diff --git a/src/rosa/tools/__init__.py b/src/rosa/tools/__init__.py
index 78221f2..4285dde 100644
--- a/src/rosa/tools/__init__.py
+++ b/src/rosa/tools/__init__.py
@@ -19,7 +19,7 @@
from langchain.agents import Tool
-def inject_blacklist(blacklist):
+def inject_blacklist(default_blacklist: List[str]):
"""
Inject a blacklist parameter into @tool functions that require it. Required because we do not
want to rely on the LLM to manually use the blacklist, as it may "forget" to do so.
@@ -32,18 +32,28 @@ def inject_blacklist(blacklist):
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
- if "blacklist" in kwargs:
- kwargs["blacklist"] = blacklist
+ if args and isinstance(args[0], dict):
+ if "blacklist" in args[0]:
+ args[0]["blacklist"] = default_blacklist + args[0]["blacklist"]
+ else:
+ args[0]["blacklist"] = default_blacklist
else:
- params = inspect.signature(func).parameters
- if "blacklist" in params:
- kwargs["blacklist"] = blacklist
+ if "blacklist" in kwargs:
+ kwargs["blacklist"] = default_blacklist + kwargs["blacklist"]
+ else:
+ params = inspect.signature(func).parameters
+ if "blacklist" in params:
+ kwargs["blacklist"] = default_blacklist
return func(*args, **kwargs)
# Rebuild the signature to include 'blacklist'
sig = inspect.signature(func)
new_params = [
- param.replace(default=blacklist) if param.name == "blacklist" else param
+ (
+ param.replace(default=default_blacklist)
+ if param.name == "blacklist"
+ else param
+ )
for param in sig.parameters.values()
]
wrapper.__signature__ = sig.replace(parameters=new_params)
@@ -68,19 +78,13 @@ def __init__(
self.__iterative_add(system)
if self.__ros_version == 1:
- try:
- from . import ros1
+ from . import ros1
- self.__iterative_add(ros1, blacklist=blacklist)
- except Exception as e:
- print(e)
+ self.__iterative_add(ros1, blacklist=blacklist)
elif self.__ros_version == 2:
- try:
- from . import ros2
+ from . import ros2
- self.__iterative_add(ros2, blacklist=blacklist)
- except Exception as e:
- print(e)
+ self.__iterative_add(ros2, blacklist=blacklist)
else:
raise ValueError("Invalid ROS version. Must be either 1 or 2.")
diff --git a/src/rosa/tools/log.py b/src/rosa/tools/log.py
index ea5a475..d6544ae 100644
--- a/src/rosa/tools/log.py
+++ b/src/rosa/tools/log.py
@@ -13,7 +13,7 @@
# limitations under the License.
import os
-from typing import Optional
+from typing import Optional, Literal
from langchain.agents import tool
@@ -22,20 +22,28 @@
def read_log(
log_file_directory: str,
log_filename: str,
- level_filter: Optional[str],
- line_range: tuple = (-200, -1),
+ level_filter: Optional[
+ Literal[
+ "ERROR", "INFO", "DEBUG", "WARNING", "CRITICAL", "FATAL", "TRACE", "DEBUG"
+ ]
+ ] = None,
+ num_lines: Optional[int] = None,
) -> dict:
"""
Read a log file and return the log lines that match the level filter and line range.
- :arg log_file_directory: The directory containing the log file to read (use your tools to get it)
- :arg log_filename: The path to the log file to read
- :arg level_filter: Only show log lines that contain this level (e.g. "ERROR", "INFO", "DEBUG", etc.)
- :arg line_range: A tuple of two integers representing the start and end line numbers to return
+ :param log_file_directory: The directory containing the log file to read (use your tools to get it)
+ :param log_filename: The path to the log file to read
+ :param level_filter: Only show log lines that contain this level (e.g. "ERROR", "INFO", "DEBUG", etc.)
+ :param num_lines: The number of most recent lines to return from the log file
"""
+ if num_lines is not None and num_lines < 1:
+ return {"error": "Invalid `num_lines` argument. It must be a positive integer."}
+
if not os.path.exists(log_file_directory):
return {
- "error": f"The log directory '{log_file_directory}' does not exist. You should first use your tools to get the correct log directory."
+ "error": f"The log directory '{log_file_directory}' does not exist. You should first use your tools to "
+ f"get the correct log directory."
}
full_log_path = os.path.join(log_file_directory, log_filename)
@@ -56,13 +64,15 @@ def read_log(
for i in range(len(log_lines)):
log_lines[i] = f"line {i+1}: " + log_lines[i].strip()
- print(f"Reading log file '{log_filename}' lines {line_range[0]} to {line_range[1]}")
- log_lines = log_lines[line_range[0] : line_range[1]]
+ if num_lines is not None:
+ # Get the most recent num_lines from the log file
+ log_lines = log_lines[-num_lines:]
# If there are more than 200 lines, return a message to use the line_range argument
if len(log_lines) > 200:
return {
- "error": f"The log file '{log_filename}' has more than 200 lines. Please use the `line_range` argument to read a subset of the log file at a time."
+ "error": f"The log file '{log_filename}' has more than 200 lines. Please use the `num_lines` argument to "
+ f"read a subset of the log file at a time."
}
if level_filter is not None:
@@ -72,7 +82,7 @@ def read_log(
"log_filename": log_filename,
"log_file_directory": log_file_directory,
"level_filter": level_filter,
- "line_range": line_range,
+ "requested_num_lines": num_lines,
"total_lines": total_lines,
"lines_returned": len(log_lines),
"lines": log_lines,
diff --git a/src/rosa/tools/ros1.py b/src/rosa/tools/ros1.py
index 36bcd93..2aeb2a5 100644
--- a/src/rosa/tools/ros1.py
+++ b/src/rosa/tools/ros1.py
@@ -51,9 +51,17 @@ def get_entities(
in_namespace = len(entities)
if pattern:
- entities = list(filter(lambda x: regex.match(f".*{pattern}", x), entities))
+ entities = list(filter(lambda x: regex.match(f".*{pattern}.*", x), entities))
match_pattern = len(entities)
+ if blacklist:
+ entities = list(
+ filter(
+ lambda x: not any(regex.match(f".*{bl}.*", x) for bl in blacklist),
+ entities,
+ )
+ )
+
if total == 0:
entities = [f"There are currently no {type}s available in the system."]
elif in_namespace == 0:
@@ -65,22 +73,11 @@ def get_entities(
f"There are currently no {type}s available matching the specified pattern."
]
- if blacklist:
- entities = list(
- filter(
- lambda x: not any(
- regex.match(f".*{pattern}", x) for pattern in blacklist
- ),
- entities,
- )
- )
-
return total, in_namespace, match_pattern, sorted(entities)
@tool
def rosgraph_get(
- namespace: Optional[str] = "/",
node_pattern: Optional[str] = ".*",
topic_pattern: Optional[str] = ".*",
blacklist: List[str] = None,
@@ -89,12 +86,12 @@ def rosgraph_get(
"""
Get a list of tuples representing nodes and topics in the ROS graph.
- :param namespace: ROS namespace to scope return values by. Namespace must already be resolved.
:param node_pattern: A regex pattern for the nodes to include in the graph (publishers and subscribers).
:param topic_pattern: A regex pattern for the topics to include in the graph.
:param exclude_self_connections: Exclude connections where the publisher and subscriber are the same node.
:note: you should avoid using the topic pattern when searching for nodes, as it may not return any results.
+ :important: you must NOT use this function to get lists of nodes, topics, etc.
Example regex patterns:
- .*node.* any node containing "node"
@@ -102,9 +99,6 @@ def rosgraph_get(
- node.* any node that starts with "node"
- (.*node1.*|.*node2.*|.*node3.*) any node containing either "node1", "node2", or "node3"
"""
- rospy.loginfo(
- f"Getting ROS graph with namespace '{namespace}', node_pattern '{node_pattern}', and topic_pattern '{topic_pattern}'"
- )
try:
publishers, subscribers, services = rosgraph.masterapi.Master(
"/rosout"
@@ -118,8 +112,6 @@ def rosgraph_get(
for pub in publishers:
for node in pub[1]:
- if namespace and not node.startswith(namespace):
- continue
if pub[0] in topic_pub_map:
topic_pub_map[pub[0]].append(node)
else:
@@ -127,8 +119,6 @@ def rosgraph_get(
for sub in subscribers:
for node in sub[1]:
- if namespace and not node.startswith(namespace):
- continue
if sub[0] in topic_sub_map:
topic_sub_map[sub[0]].append(node)
else:
@@ -229,6 +219,14 @@ def rostopic_list(
except Exception as e:
return {"error": f"Failed to get ROS topics: {e}"}
+ if blacklist:
+ topics = list(
+ filter(
+ lambda x: not any(regex.match(f".*{bl}.*", x) for bl in blacklist),
+ topics,
+ )
+ )
+
return dict(
namespace=namespace if namespace else "/",
pattern=pattern if pattern else ".*",
@@ -260,6 +258,14 @@ def rosnode_list(
except Exception as e:
return {"error": f"Failed to get ROS nodes: {e}"}
+ if blacklist:
+ nodes = list(
+ filter(
+ lambda x: not any(regex.match(f".*{bl}.*", x) for bl in blacklist),
+ nodes,
+ )
+ )
+
return dict(
namespace=namespace if namespace else "/",
pattern=pattern if pattern else ".*",
@@ -360,7 +366,6 @@ def rostopic_echo(
for i in range(count):
try:
msg = rospy.wait_for_message(topic, msg_class, timeout)
- print(msg)
if return_echoes:
msgs.append(msg)
@@ -369,7 +374,6 @@ def rostopic_echo(
time.sleep(delay)
except (rospy.ROSException, rospy.ROSInterruptException) as e:
- print(f"Failed to get message from topic '{topic}': {e}")
break
response = dict(topic=topic, requested_count=count, actual_count=len(msgs))
@@ -484,7 +488,6 @@ def rosservice_call(service: str, args: List[str]) -> dict:
:param service: The name of the ROS service to call.
:param args: A list of arguments to pass to the service.
"""
- print(f"Calling ROS service '{service}' with arguments: {args}")
try:
response = rosservice.call_service(service, args)
return response
@@ -519,7 +522,6 @@ def rossrv_info(srv_type: List[str], raw: bool = False) -> dict:
for srv in srv_type:
# Get the Python class corresponding to the srv file
- print(f"Getting details for {srv}")
srv_path = rosmsg.get_srv_text(srv, raw=raw)
details[srv] = srv_path
return details
@@ -740,3 +742,63 @@ def get_roslog_directories() -> dict:
latest=latest_directory,
from_env=from_env,
)
+
+
+@tool
+def roslaunch(package: str, launch_file: str) -> str:
+ """Launches a ROS launch file.
+
+ :param package: The name of the ROS package containing the launch file.
+ :param launch_file: The name of the launch file to launch.
+ """
+ rospy.loginfo(f"Launching ROS launch file '{launch_file}' in package '{package}'")
+ try:
+ os.system(f"roslaunch {package} {launch_file}")
+ return f"Launched ROS launch file '{launch_file}' in package '{package}'."
+ except Exception as e:
+ return f"Failed to launch ROS launch file '{launch_file}' in package '{package}': {e}"
+
+
+@tool
+def roslaunch_list(package: str) -> dict:
+ """Returns a list of available ROS launch files in a package.
+
+ :param package: The name of the ROS package to list launch files for.
+ """
+ rospy.loginfo(f"Getting ROS launch files in package '{package}'")
+ try:
+ rospack = rospkg.RosPack()
+ directory = rospack.get_path(package)
+ launch = os.path.join(directory, "launch")
+
+ launch_files = []
+
+ # Get all files in the launch directory
+ if os.path.exists(launch):
+ launch_files = [
+ f for f in os.listdir(launch) if os.path.isfile(os.path.join(launch, f))
+ ]
+
+ return {
+ "package": package,
+ "directory": directory,
+ "total": len(launch_files),
+ "launch_files": launch_files,
+ }
+
+ except Exception as e:
+ return {"error": f"Failed to get ROS launch files in package '{package}': {e}"}
+
+
+@tool
+def rosnode_kill(node: str) -> str:
+ """Kills a specific ROS node.
+
+ :param node: The name of the ROS node to kill.
+ """
+ rospy.loginfo(f"Killing ROS node '{node}'")
+ try:
+ os.system(f"rosnode kill {node}")
+ return f"Killed ROS node '{node}'."
+ except Exception as e:
+ return f"Failed to kill ROS node '{node}': {e}"
diff --git a/src/rosa/tools/ros2.py b/src/rosa/tools/ros2.py
index b825b69..bdfb359 100644
--- a/src/rosa/tools/ros2.py
+++ b/src/rosa/tools/ros2.py
@@ -83,6 +83,8 @@ def get_entities(
if pattern:
entities = list(filter(lambda x: re.match(f".*{pattern}.*", x), entities))
+ entities = [e for e in entities if e.strip() != ""]
+
return entities
@@ -189,57 +191,6 @@ def ros2_node_info(nodes: List[str]) -> dict:
return data
-def parse_ros2_topic_info(output):
- topic_info = {"name": "", "type": "", "publishers": [], "subscribers": []}
-
- lines = output.split("\n")
-
- # Extract the topic name
- for line in lines:
- if line.startswith("ros2 topic info"):
- topic_info["name"] = line.split(" ")[3]
-
- # Extract the Type
- for line in lines:
- if line.startswith("Type:"):
- topic_info["type"] = line.split(": ")[1]
-
- # Extract publisher and subscriber sections
- publisher_section = ""
- subscriber_section = ""
- collecting_publishers = False
- collecting_subscribers = False
-
- for line in lines:
- if line.startswith("Publisher count:"):
- collecting_publishers = True
- collecting_subscribers = False
- elif line.startswith("Subscription count:"):
- collecting_publishers = False
- collecting_subscribers = True
-
- if collecting_publishers:
- publisher_section += line + "\n"
- if collecting_subscribers:
- subscriber_section += line + "\n"
-
- # Extract node names for publishers
- publisher_lines = publisher_section.split("\n")
- for line in publisher_lines:
- if line.startswith("Node name:"):
- node_name = line.split(": ")[1]
- topic_info["publishers"].append(node_name)
-
- # Extract node names for subscribers
- subscriber_lines = subscriber_section.split("\n")
- for line in subscriber_lines:
- if line.startswith("Node name:"):
- node_name = line.split(": ")[1]
- topic_info["subscribers"].append(node_name)
-
- return topic_info
-
-
@tool
def ros2_topic_info(topics: List[str]) -> dict:
"""
@@ -255,7 +206,7 @@ def ros2_topic_info(topics: List[str]) -> dict:
if not success:
topic_info = dict(error=output)
else:
- topic_info = parse_ros2_topic_info(output)
+ topic_info = output
data[topic] = topic_info
@@ -276,7 +227,17 @@ def ros2_param_list(
"""
if node_name:
cmd = f"ros2 param list {node_name}"
- params = get_entities(cmd, pattern=pattern, blacklist=blacklist)
+ success, output = execute_ros_command(cmd)
+ if not success:
+ return {"error": output}
+
+ params = [o for o in output.split("\n") if o]
+ if pattern:
+ params = [p for p in params if re.match(f".*{pattern}.*", p)]
+ if blacklist:
+ params = [
+ p for p in params if not any(re.match(f".*{b}.*", p) for b in blacklist)
+ ]
return {node_name: params}
else:
cmd = f"ros2 param list"
@@ -296,6 +257,15 @@ def ros2_param_list(
data[current_node] = []
elif line.strip() != "":
data[current_node].append(line.strip())
+
+ if pattern:
+ data = {k: v for k, v in data.items() if re.match(f".*{pattern}.*", k)}
+ if blacklist:
+ data = {
+ k: v
+ for k, v in data.items()
+ if not any(re.match(f".*{b}.*", k) for b in blacklist)
+ }
return data
diff --git a/src/rosa/tools/system.py b/src/rosa/tools/system.py
index e18295a..96ae8d1 100644
--- a/src/rosa/tools/system.py
+++ b/src/rosa/tools/system.py
@@ -26,7 +26,6 @@ def set_verbosity(enable_verbose_messages: bool) -> str:
:arg enable_verbose_messages: A boolean value to enable or disable verbose messages.
"""
global VERBOSE
- print(f"Setting verbosity to {enable_verbose_messages}")
VERBOSE = enable_verbose_messages
set_verbose(VERBOSE)
return f"Verbose messages are now {'enabled' if VERBOSE else 'disabled'}."
@@ -41,7 +40,6 @@ def set_debuging(enable_debug_messages: bool) -> str:
:arg enable_debug_messages: A boolean value to enable or disable debug messages.
"""
global DEBUG
- print(f"Setting debug to {enable_debug_messages}")
DEBUG = enable_debug_messages
set_debug(DEBUG)
return f"Debug messages are now {'enabled' if DEBUG else 'disabled'}."
diff --git a/src/turtle_agent/scripts/tools/turtle.py b/src/turtle_agent/scripts/tools/turtle.py
index 398ddee..0068d20 100644
--- a/src/turtle_agent/scripts/tools/turtle.py
+++ b/src/turtle_agent/scripts/tools/turtle.py
@@ -48,7 +48,7 @@ def within_bounds(x: float, y: float) -> tuple:
"""
if 0 <= x <= 11 and 0 <= y <= 11:
return True, "Coordinates are within bounds."
- elif x < 0 or x > 11 or y < 0 or y > 11:
+ else:
return False, f"({x}, {y}) will be out of bounds. Range is [0, 11] for each."
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/rosa/__init__.py b/tests/rosa/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/rosa/test_prompts.py b/tests/rosa/test_prompts.py
new file mode 100644
index 0000000..b0da2ec
--- /dev/null
+++ b/tests/rosa/test_prompts.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2024. Jet Propulsion Laboratory. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/rosa/test_rosa.py b/tests/rosa/test_rosa.py
new file mode 100644
index 0000000..b0da2ec
--- /dev/null
+++ b/tests/rosa/test_rosa.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2024. Jet Propulsion Laboratory. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/rosa/tools/__init__.py b/tests/rosa/tools/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/rosa/tools/test_calculation.py b/tests/rosa/tools/test_calculation.py
new file mode 100644
index 0000000..b594299
--- /dev/null
+++ b/tests/rosa/tools/test_calculation.py
@@ -0,0 +1,195 @@
+# Copyright (c) 2024. Jet Propulsion Laboratory. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import statistics
+import unittest
+
+from src.rosa.tools.calculation import (
+ add_all,
+ multiply_all,
+ mean,
+ median,
+ mode,
+ variance,
+ add,
+ subtract,
+ multiply,
+ divide,
+ exponentiate,
+ modulo,
+ sine,
+ cosine,
+ tangent,
+ asin,
+ acos,
+ atan,
+ sinh,
+ cosh,
+ tanh,
+ count_list,
+ count_words,
+ count_lines,
+ degrees_to_radians,
+ radians_to_degrees,
+)
+
+
+class TestCalculationTools(unittest.TestCase):
+
+ def test_add_all_returns_sum_of_numbers(self):
+ self.assertEqual(add_all.invoke({"numbers": [1, 2, 3]}), 6)
+ self.assertEqual(add_all.invoke({"numbers": []}), 0)
+
+ def test_multiply_all_returns_product_of_numbers(self):
+ self.assertEqual(multiply_all.invoke({"numbers": [1, 2, 3]}), 6)
+ self.assertEqual(multiply_all.invoke({"numbers": [1, 0, 3]}), 0)
+
+ def test_mean_returns_mean_and_stdev_of_numbers(self):
+ self.assertEqual(mean.invoke({"numbers": [1, 2, 3]}), {"mean": 2, "stdev": 1})
+ with self.assertRaises(statistics.StatisticsError):
+ mean.invoke({"numbers": []})
+
+ def test_median_returns_median_of_numbers(self):
+ self.assertEqual(median.invoke({"numbers": [1, 2, 3]}), 2)
+ self.assertEqual(median.invoke({"numbers": [1, 2, 3, 4]}), 2.5)
+
+ def test_mode_returns_mode_of_numbers(self):
+ self.assertEqual(mode.invoke({"numbers": [1, 1, 2, 3]}), 1)
+ self.assertEqual(mode.invoke({"numbers": [1, 2, 3]}), 1)
+
+ def test_variance_returns_variance_of_numbers(self):
+ self.assertEqual(variance.invoke({"numbers": [1, 2, 3]}), 1)
+ with self.assertRaises(statistics.StatisticsError):
+ variance.invoke({"numbers": [1]})
+
+ def test_add_returns_sum_of_xy_pairs(self):
+ self.assertEqual(
+ add.invoke({"xy_pairs": [(1, 2), (3, 4)]}), [{"1+2": 3}, {"3+4": 7}]
+ )
+
+ def test_subtract_returns_difference_of_xy_pairs(self):
+ self.assertEqual(
+ subtract.invoke({"xy_pairs": [(1, 2), (3, 4)]}), [{"1-2": -1}, {"3-4": -1}]
+ )
+
+ def test_multiply_returns_product_of_xy_pairs(self):
+ self.assertEqual(
+ multiply.invoke({"xy_pairs": [(1, 2), (3, 4)]}), [{"1*2": 2}, {"3*4": 12}]
+ )
+
+ def test_divide_returns_quotient_of_xy_pairs(self):
+ self.assertEqual(
+ divide.invoke({"xy_pairs": [(1, 2), (3, 4)]}), [{"1/2": 0.5}, {"3/4": 0.75}]
+ )
+ self.assertEqual(divide.invoke({"xy_pairs": [(1, 0)]}), [{"1/0": "undefined"}])
+
+ def test_exponentiate_returns_exponentiation_of_xy_pairs(self):
+ self.assertEqual(
+ exponentiate.invoke({"xy_pairs": [(2, 3), (3, 2)]}),
+ [{"2^3": 8}, {"3^2": 9}],
+ )
+
+ def test_modulo_returns_modulo_of_xy_pairs(self):
+ self.assertEqual(
+ modulo.invoke({"xy_pairs": [(5, 3), (10, 2)]}), [{"5%3": 2}, {"10%2": 0}]
+ )
+ self.assertEqual(modulo.invoke({"xy_pairs": [(1, 0)]}), [{"1%0": "undefined"}])
+
+ def test_sine_returns_sine_of_x_values(self):
+ self.assertAlmostEqual(
+ sine.invoke({"x_values": [0, math.pi / 2]}),
+ [{"sin(0.0)": 0.0}, {"sin(1.5707963267948966)": 1.0}],
+ )
+
+ def test_cosine_returns_cosine_of_x_values(self):
+ cosines = cosine.invoke({"x_values": [0, math.pi / 2]})
+ self.assertAlmostEqual(cosines[0]["cos(0.0)"], 1.0, delta=0.0000000000000001)
+ self.assertAlmostEqual(
+ cosines[1]["cos(1.5707963267948966)"], 0.0, delta=0.0000000000000001
+ )
+
+ def test_tangent_returns_tangent_of_x_values(self):
+ # Convert the above to use assertAlmostEqual
+ tangents = tangent.invoke({"x_values": [0, math.pi / 4]})
+ self.assertAlmostEqual(tangents[0]["tan(0.0)"], 0.0, delta=0.0000000000000001)
+ self.assertAlmostEqual(
+ tangents[1]["tan(0.7853981633974483)"], 1.0, delta=0.000000000000001
+ )
+
+ def test_asin_returns_arcsine_of_x_values(self):
+ self.assertEqual(
+ asin.invoke({"x_values": [0, 1]}),
+ [{"asin(0.0)": 0.0}, {"asin(1.0)": 1.5707963267948966}],
+ )
+ self.assertEqual(asin.invoke({"x_values": [2]}), [{"asin(2.0)": "undefined"}])
+
+ def test_acos_returns_arccosine_of_x_values(self):
+ self.assertEqual(
+ acos.invoke({"x_values": [0, 1]}),
+ [{"acos(0.0)": 1.5707963267948966}, {"acos(1.0)": 0.0}],
+ )
+ self.assertEqual(acos.invoke({"x_values": [2]}), [{"acos(2.0)": "undefined"}])
+
+ def test_atan_returns_arctangent_of_x_values(self):
+ self.assertEqual(
+ atan.invoke({"x_values": [0, 1]}),
+ [{"atan(0.0)": 0.0}, {"atan(1.0)": 0.7853981633974483}],
+ )
+
+ def test_sinh_returns_hyperbolic_sine_of_x_values(self):
+ self.assertEqual(
+ sinh.invoke({"x_values": [0, 1]}),
+ [{"sinh(0.0)": 0.0}, {"sinh(1.0)": 1.1752011936438014}],
+ )
+
+ def test_cosh_returns_hyperbolic_cosine_of_x_values(self):
+ self.assertAlmostEqual(
+ cosh.invoke({"x_values": [0, 1]}),
+ [{"cosh(0.0)": 1.0}, {"cosh(1.0)": 1.5430806348152437}],
+ )
+
+ def test_tanh_returns_hyperbolic_tangent_of_x_values(self):
+ self.assertEqual(
+ tanh.invoke({"x_values": [0, 1]}),
+ [{"tanh(0.0)": 0.0}, {"tanh(1.0)": 0.7615941559557649}],
+ )
+
+ def test_count_list_returns_number_of_items_in_list(self):
+ self.assertEqual(count_list.invoke({"items": [1, 2, 3]}), 3)
+ self.assertEqual(count_list.invoke({"items": []}), 0)
+
+ def test_count_words_returns_number_of_words_in_string(self):
+ self.assertEqual(count_words.invoke({"text": "Hello world"}), 2)
+ self.assertEqual(count_words.invoke({"text": ""}), 0)
+
+ def test_count_lines_returns_number_of_lines_in_string(self):
+ self.assertEqual(count_lines.invoke({"text": "Hello\nworld"}), 2)
+ self.assertEqual(count_lines.invoke({"text": ""}), 1)
+
+ def test_degrees_to_radians_converts_degrees_to_radians(self):
+ self.assertEqual(
+ degrees_to_radians.invoke({"degrees": [0, 180]}),
+ {0: "0.0 radians.", 180: "3.14159 radians."},
+ )
+
+ def test_radians_to_degrees_converts_radians_to_degrees(self):
+ self.assertEqual(
+ radians_to_degrees.invoke({"radians": [0, 3.14159]}),
+ {0: "0.0 degrees.", 3.14159: "180.0 degrees."},
+ )
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/rosa/tools/test_log.py b/tests/rosa/tools/test_log.py
new file mode 100644
index 0000000..ca6d09a
--- /dev/null
+++ b/tests/rosa/tools/test_log.py
@@ -0,0 +1,176 @@
+# Copyright (c) 2024. Jet Propulsion Laboratory. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+from unittest.mock import patch, mock_open
+
+from src.rosa.tools.log import read_log
+
+
+class TestReadLog(unittest.TestCase):
+
+ @patch("os.path.exists")
+ def test_log_directory_does_not_exist(self, mock_exists):
+ mock_exists.return_value = False
+ result = read_log.invoke(
+ {
+ "log_file_directory": "/invalid/directory",
+ "log_filename": "logfile.log",
+ }
+ )
+ self.assertEqual(
+ result["error"],
+ "The log directory '/invalid/directory' does not exist. You should first use your tools to get the "
+ "correct log directory.",
+ )
+
+ @patch("os.path.exists")
+ def test_log_path_is_not_a_file(self, mock_exists):
+ mock_exists.side_effect = [True, True]
+ with patch("os.path.isfile", return_value=False):
+ result = read_log.invoke(
+ {
+ "log_file_directory": "/valid/directory",
+ "log_filename": "logfile.log",
+ }
+ )
+ self.assertEqual(
+ result["error"],
+ "The path '/valid/directory/logfile.log' is not a file.",
+ )
+
+ @patch(
+ "builtins.open",
+ new_callable=mock_open,
+ read_data="INFO: line 1\nERROR: line 2\nDEBUG: line 3\n",
+ )
+ @patch("os.path.exists", return_value=True)
+ @patch("os.path.isfile", return_value=True)
+ def test_read_log_with_level_filter(self, mock_exists, mock_isfile, mock_file):
+ result = read_log.invoke(
+ {
+ "log_file_directory": "/valid/directory",
+ "log_filename": "logfile.log",
+ "level_filter": "ERROR",
+ }
+ )
+ self.assertEqual(result["lines"], ["line 2: ERROR: line 2"])
+
+ @patch(
+ "builtins.open",
+ new_callable=mock_open,
+ read_data="INFO: line 1\nERROR: line 2\nDEBUG: line 3\n",
+ )
+ @patch("os.path.exists", return_value=True)
+ @patch("os.path.isfile", return_value=True)
+ def test_read_log_with_line_range(self, mock_exists, mock_isfile, mock_file):
+ result = read_log.invoke(
+ {
+ "log_file_directory": "/valid/directory",
+ "log_filename": "logfile.log",
+ "num_lines": 2,
+ }
+ )
+ self.assertEqual(
+ result["lines"], ["line 2: ERROR: line 2", "line 3: DEBUG: line 3"]
+ )
+
+ @patch("builtins.open", new_callable=mock_open, read_data="INFO: line 1\n" * 202)
+ @patch("os.path.exists", return_value=True)
+ @patch("os.path.isfile", return_value=True)
+ def test_log_file_exceeds_200_lines(self, mock_exists, mock_isfile, mock_file):
+ result = read_log.invoke(
+ {
+ "log_file_directory": "/valid/directory",
+ "log_filename": "logfile.log",
+ "num_lines": 203,
+ }
+ )
+ self.assertEqual(
+ result["error"],
+ "The log file 'logfile.log' has more than 200 lines. Please use the `num_lines` argument to read a subset "
+ "of the log file at a time.",
+ )
+
+ @patch(
+ "builtins.open",
+ new_callable=mock_open,
+ read_data="INFO: line 1\nERROR: line 2\nDEBUG: line 3\n",
+ )
+ @patch("os.path.exists", return_value=True)
+ @patch("os.path.isfile", return_value=True)
+ def test_read_log_happy_path(self, mock_exists, mock_isfile, mock_file):
+ result = read_log.invoke(
+ {
+ "log_file_directory": "/valid/directory",
+ "log_filename": "logfile.log",
+ }
+ )
+ self.assertEqual(
+ result["lines"],
+ ["line 1: INFO: line 1", "line 2: ERROR: line 2", "line 3: DEBUG: line 3"],
+ )
+
+ @patch("os.path.exists", return_value=True)
+ @patch("os.path.isfile", return_value=True)
+ def test_invalid_num_lines_argument(self, mock_exists, mock_isfile):
+ with patch(
+ "builtins.open",
+ new_callable=mock_open,
+ read_data="INFO: line 1\nERROR: line 2\n",
+ ):
+ result = read_log.invoke(
+ {
+ "log_file_directory": "/valid/directory",
+ "log_filename": "logfile.log",
+ "num_lines": -1,
+ }
+ )
+ self.assertEqual(
+ result["error"],
+ "Invalid `num_lines` argument. It must be a positive integer.",
+ )
+
+ @patch("os.path.exists", return_value=True)
+ @patch("os.path.isfile", return_value=True)
+ def test_empty_log_file(self, mock_exists, mock_isfile):
+ with patch("builtins.open", new_callable=mock_open, read_data=""):
+ result = read_log.invoke(
+ {
+ "log_file_directory": "/valid/directory",
+ "log_filename": "logfile.log",
+ }
+ )
+ self.assertEqual(result["lines"], [])
+
+ @patch("os.path.exists", return_value=True)
+ @patch("os.path.isfile", return_value=True)
+ def test_specific_log_level_not_present(self, mock_exists, mock_isfile):
+ with patch(
+ "builtins.open",
+ new_callable=mock_open,
+ read_data="INFO: line 1\nDEBUG: line 2\n",
+ ):
+ result = read_log.invoke(
+ {
+ "log_file_directory": "/valid/directory",
+ "log_filename": "logfile.log",
+ "level_filter": "ERROR",
+ }
+ )
+ self.assertEqual(result["lines"], [])
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/rosa/tools/test_ros1.py b/tests/rosa/tools/test_ros1.py
new file mode 100644
index 0000000..2fdd463
--- /dev/null
+++ b/tests/rosa/tools/test_ros1.py
@@ -0,0 +1,603 @@
+# Copyright (c) 2024. Jet Propulsion Laboratory. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import unittest
+from unittest.mock import patch, MagicMock
+
+try:
+ from src.rosa.tools.ros1 import (
+ get_entities,
+ rosgraph_get,
+ rostopic_list,
+ rostopic_info,
+ rostopic_echo,
+ rosnode_list,
+ rosnode_info,
+ rosservice_list,
+ rosservice_info,
+ rosservice_call,
+ rosmsg_info,
+ rossrv_info,
+ rosparam_list,
+ rosparam_get,
+ rosparam_set,
+ rospkg_list,
+ rospkg_roots,
+ roslog_list,
+ )
+except ModuleNotFoundError:
+ pass
+
+
+@unittest.skipIf(
+ os.environ.get("ROS_VERSION") == "2",
+ "Skipping ROS1 tests because ROS_VERSION is set to 2",
+)
+class TestROS1Tools(unittest.TestCase):
+
+ @patch("src.rosa.tools.ros1.rostopic.get_topic_list")
+ def test_get_entities_topics(self, mock_get_topic_list):
+ mock_get_topic_list.return_value = (
+ [("/turtle1/cmd_vel", "std_msgs/Empty")],
+ [("/turtle1/pose", "std_msgs/Empty")],
+ )
+ total, in_namespace, match_pattern, entities = get_entities("topic", None, None)
+ self.assertEqual(total, 2)
+ self.assertEqual(in_namespace, 2)
+ self.assertEqual(match_pattern, 2)
+ self.assertIn("/turtle1/cmd_vel", entities)
+ self.assertIn("/turtle1/pose", entities)
+
+ @patch("src.rosa.tools.ros1.rosnode.get_node_names")
+ def test_get_entities_nodes(self, mock_get_node_names):
+ mock_get_node_names.return_value = ["/turtlesim"]
+ total, in_namespace, match_pattern, entities = get_entities("node", None, None)
+ self.assertEqual(total, 1)
+ self.assertEqual(in_namespace, 1)
+ self.assertEqual(match_pattern, 1)
+ self.assertIn("/turtlesim", entities)
+
+ @patch("src.rosa.tools.ros1.rosgraph.masterapi.Master.getSystemState")
+ def test_rosgraph_get_returns_graph(self, mock_get_system_state):
+ mock_get_system_state.return_value = (
+ [("/topic1", ["/node1"]), ("/topic2", ["/node2"])],
+ [("/topic1", ["/node3"]), ("/topic2", ["/node4"])],
+ [],
+ )
+ result = rosgraph_get.invoke(
+ {
+ "node_pattern": ".*",
+ "topic_pattern": ".*",
+ "blacklist": [],
+ "exclude_self_connections": True,
+ }
+ )
+ self.assertIn("graph", result)
+ self.assertEqual(len(result["graph"]), 2)
+
+ @patch("src.rosa.tools.ros1.rosgraph.masterapi.Master.getSystemState")
+ def test_rosgraph_get_handles_empty_graph(self, mock_get_system_state):
+ mock_get_system_state.return_value = ([], [], [])
+ result = rosgraph_get.invoke(
+ {
+ "node_pattern": ".*",
+ "topic_pattern": ".*",
+ "blacklist": [],
+ "exclude_self_connections": True,
+ }
+ )
+ self.assertIn("error", result)
+ self.assertEqual(
+ result["error"],
+ "No results found for the specified parameters. Note that the following have been excluded: []",
+ )
+
+ @patch("src.rosa.tools.ros1.rosgraph.masterapi.Master.getSystemState")
+ def test_rosgraph_get_excludes_blacklisted_nodes(self, mock_get_system_state):
+ mock_get_system_state.return_value = (
+ [("/topic1", ["/node1"]), ("/topic2", ["/node2"])],
+ [("/topic1", ["/node3"]), ("/topic2", ["/node4"])],
+ [],
+ )
+ result = rosgraph_get.invoke(
+ {
+ "node_pattern": ".*",
+ "topic_pattern": ".*",
+ "blacklist": ["node1"],
+ "exclude_self_connections": True,
+ }
+ )
+ self.assertIn("graph", result)
+ self.assertEqual(len(result["graph"]), 1)
+ self.assertNotIn("/node1", result["graph"][0])
+
+ @patch("src.rosa.tools.ros1.rosgraph.masterapi.Master.getSystemState")
+ def test_rosgraph_get_excludes_self_connections(self, mock_get_system_state):
+ mock_get_system_state.return_value = (
+ [("/topic1", ["/node1"])],
+ [("/topic1", ["/node1"])],
+ [],
+ )
+ result = rosgraph_get.invoke(
+ {
+ "node_pattern": ".*",
+ "topic_pattern": ".*",
+ "blacklist": [],
+ "exclude_self_connections": True,
+ }
+ )
+ self.assertIn("error", result)
+ self.assertEqual(
+ result["error"],
+ "No results found for the specified parameters. Note that the following have been excluded: []",
+ )
+
+ @patch("src.rosa.tools.ros1.rostopic.get_info_text")
+ def test_rostopic_info(self, mock_get_info_text):
+ mock_get_info_text.return_value = (
+ "Type: std_msgs/String\nPublishers:\n* /turtlesim\nSubscribers:\n* /rosout"
+ )
+ result = rostopic_info.invoke({"topics": ["/turtle1/cmd_vel"]})
+ self.assertIn("/turtle1/cmd_vel", result)
+ self.assertEqual(result["/turtle1/cmd_vel"]["type"], "std_msgs/String")
+ self.assertIn("/turtlesim", result["/turtle1/cmd_vel"]["publishers"])
+ self.assertIn("/rosout", result["/turtle1/cmd_vel"]["subscribers"])
+
+ @patch("src.rosa.tools.ros1.rospy.wait_for_message")
+ @patch("src.rosa.tools.ros1.rostopic.get_topic_class")
+ def test_rostopic_echo(self, mock_get_topic_class, mock_wait_for_message):
+ mock_get_topic_class.return_value = (MagicMock(), None, None)
+ mock_wait_for_message.return_value = MagicMock()
+ result = rostopic_echo.invoke(
+ {"topic": "/turtle1/cmd_vel", "count": 1, "return_echoes": True}
+ )
+ self.assertEqual(result["requested_count"], 1)
+ self.assertEqual(result["actual_count"], 1)
+ self.assertIn("echoes", result)
+
+ @patch("rosnode.get_node_names")
+ def test_rosnode_list_returns_all_nodes(self, mock_get_node_names):
+ mock_get_node_names.return_value = ["/node1", "/node2", "/node3"]
+ result = rosnode_list.invoke({})
+ self.assertEqual(result["total"], 3)
+ self.assertEqual(result["in_namespace"], 3)
+ self.assertEqual(result["match_pattern"], 3)
+ self.assertEqual(result["nodes"], ["/node1", "/node2", "/node3"])
+
+ @patch("rosnode.get_node_names")
+ def test_rosnode_list_filters_by_namespace(self, mock_get_node_names):
+ mock_get_node_names.return_value = ["/namespace1/node1", "/namespace2/node2"]
+ result = rosnode_list.invoke({"namespace": "/namespace1"})
+ self.assertEqual(result["total"], 2)
+ self.assertEqual(result["in_namespace"], 1)
+ self.assertEqual(result["match_pattern"], 1)
+ self.assertEqual(result["nodes"], ["/namespace1/node1"])
+
+ @patch("rosnode.get_node_names")
+ def test_rosnode_list_filters_by_pattern(self, mock_get_node_names):
+ mock_get_node_names.return_value = ["/node1", "/node2", "/node3"]
+ result = rosnode_list.invoke({"pattern": "node1"})
+ self.assertEqual(result["total"], 3)
+ self.assertEqual(result["in_namespace"], 3)
+ self.assertEqual(result["match_pattern"], 1)
+ self.assertEqual(result["nodes"], ["/node1"])
+
+ @patch("rosnode.get_node_names")
+ def test_rosnode_list_handles_no_nodes(self, mock_get_node_names):
+ mock_get_node_names.return_value = []
+ result = rosnode_list.invoke({})
+ self.assertEqual(result["total"], 0)
+ self.assertEqual(result["in_namespace"], 0)
+ self.assertEqual(result["match_pattern"], 0)
+ self.assertEqual(
+ result["nodes"], ["There are currently no nodes available in the system."]
+ )
+
+ @patch("rosnode.get_node_names")
+ def test_rosnode_list_handles_no_nodes_in_namespace(self, mock_get_node_names):
+ mock_get_node_names.return_value = ["/node1", "/node2"]
+ result = rosnode_list.invoke({"namespace": "/namespace1"})
+ self.assertEqual(result["total"], 2)
+ self.assertEqual(result["in_namespace"], 0)
+ self.assertEqual(result["match_pattern"], 0)
+ self.assertEqual(
+ result["nodes"],
+ [
+ "There are currently no nodes available using the '/namespace1' namespace."
+ ],
+ )
+
+ @patch("rosnode.get_node_names")
+ def test_rosnode_list_handles_no_nodes_matching_pattern(self, mock_get_node_names):
+ mock_get_node_names.return_value = ["/node1", "/node2"]
+ result = rosnode_list.invoke({"pattern": "node3"})
+ self.assertEqual(result["total"], 2)
+ self.assertEqual(result["in_namespace"], 2)
+ self.assertEqual(result["match_pattern"], 0)
+ self.assertEqual(
+ result["nodes"],
+ ["There are currently no nodes available matching the specified pattern."],
+ )
+
+ @patch("rosnode.get_node_names")
+ def test_rosnode_list_filters_by_blacklist(self, mock_get_node_names):
+ mock_get_node_names.return_value = ["/node1", "/node2", "/node3"]
+ result = rosnode_list.invoke({"blacklist": ["node2"]})
+ self.assertEqual(result["total"], 3)
+ self.assertEqual(result["in_namespace"], 3)
+ self.assertEqual(result["match_pattern"], 3)
+ self.assertEqual(result["nodes"], ["/node1", "/node3"])
+
+ @patch("src.rosa.tools.ros1.rosnode.get_node_info_description")
+ def test_rosnode_info(self, mock_get_node_info_description):
+ mock_get_node_info_description.return_value = (
+ "Node: /turtlesim\nPublications: /turtle1/cmd_vel"
+ )
+ result = rosnode_info.invoke({"nodes": ["/turtlesim"]})
+ self.assertIn("/turtlesim", result)
+ self.assertIn("Node: /turtlesim", result["/turtlesim"])
+
+ @patch("src.rosa.tools.ros1.rosservice.get_service_list")
+ def test_rosservice_list(self, mock_get_service_list):
+ mock_get_service_list.return_value = ["/clear", "/reset"]
+ result = rosservice_list.invoke({})
+ self.assertIn("/clear", result)
+ self.assertIn("/reset", result)
+
+ @patch("src.rosa.tools.ros1.rosservice.get_service_headers")
+ @patch("src.rosa.tools.ros1.rosservice.get_service_uri")
+ def test_rosservice_info(self, mock_get_service_uri, mock_get_service_headers):
+ mock_get_service_uri.return_value = "rosrpc://localhost:12345"
+ mock_get_service_headers.return_value = {"callerid": "/turtlesim"}
+ result = rosservice_info.invoke({"services": ["/clear"]})
+ self.assertIn("/clear", result)
+ self.assertIn("callerid", result["/clear"])
+
+ @patch("src.rosa.tools.ros1.rosservice.call_service")
+ def test_rosservice_call(self, mock_call_service):
+ mock_call_service.return_value = "success"
+ result = rosservice_call.invoke({"service": "/clear", "args": []})
+ self.assertEqual(result, "success")
+
+ @patch("src.rosa.tools.ros1.rosmsg.get_msg_text")
+ def test_rosmsg_info(self, mock_get_msg_text):
+ mock_get_msg_text.return_value = "string data"
+ result = rosmsg_info.invoke({"msg_type": ["std_msgs/String"]})
+ self.assertIn("std_msgs/String", result)
+ self.assertEqual(result["std_msgs/String"], "string data")
+
+ @patch("src.rosa.tools.ros1.rosmsg.get_srv_text")
+ def test_rossrv_info(self, mock_get_srv_text):
+ mock_get_srv_text.return_value = "string data"
+ result = rossrv_info.invoke({"srv_type": ["std_srvs/Empty"]})
+ self.assertIn("std_srvs/Empty", result)
+ self.assertEqual(result["std_srvs/Empty"], "string data")
+
+ @patch("src.rosa.tools.ros1.rosparam.list_params")
+ def test_rosparam_list(self, mock_list_params):
+ mock_list_params.return_value = [
+ "/turtlesim/background_r",
+ "/turtlesim/background_g",
+ ]
+ result = rosparam_list.invoke({})
+ self.assertIn("/turtlesim/background_r", result["ros_params"])
+ self.assertIn("/turtlesim/background_g", result["ros_params"])
+
+ @patch("src.rosa.tools.ros1.rosparam.get_param")
+ def test_rosparam_get(self, mock_get_param):
+ mock_get_param.return_value = 255
+ result = rosparam_get.invoke({"params": ["/turtlesim/background_r"]})
+ self.assertIn("/turtlesim/background_r", result)
+ self.assertEqual(result["/turtlesim/background_r"], 255)
+
+ @patch("src.rosa.tools.ros1.rosparam.set_param")
+ def test_rosparam_set(self, mock_set_param):
+ result = rosparam_set.invoke(
+ {"param": "/turtlesim/background_r", "value": "255", "is_rosa_param": False}
+ )
+ self.assertEqual(result, "Set parameter '/turtlesim/background_r' to '255'.")
+
+ @patch("src.rosa.tools.ros1.rospkg.RosPack.list")
+ def test_rospkg_list(self, mock_list):
+ mock_list.return_value = ["turtlesim", "std_msgs"]
+ result = rospkg_list.invoke({"ignore_msgs": True})
+ self.assertIn("turtlesim", result["packages"])
+ self.assertNotIn("std_msgs", result["packages"])
+
+ result = rospkg_list.invoke({"ignore_msgs": False})
+ self.assertIn("turtlesim", result["packages"])
+ self.assertIn("std_msgs", result["packages"])
+
+ @patch("src.rosa.tools.ros1.rospkg.get_ros_package_path")
+ def test_rospkg_roots(self, mock_get_ros_package_path):
+ mock_get_ros_package_path.return_value = ["/opt/ros/noetic/share"]
+ result = rospkg_roots.invoke({})
+ self.assertIn("/opt/ros/noetic/share", result)
+
+ @patch("src.rosa.tools.ros1.get_roslog_directories")
+ @patch("os.listdir")
+ @patch("os.path.isfile")
+ @patch("os.path.getsize")
+ def test_roslog_list_with_min_size(
+ self, mock_getsize, mock_isfile, mock_listdir, mock_get_roslog_directories
+ ):
+ mock_get_roslog_directories.return_value = {"default": "/mock/log/dir"}
+ mock_listdir.return_value = ["log1.log", "log2.log", "log3.log"]
+ mock_isfile.side_effect = lambda x: x.endswith(".log")
+ mock_getsize.side_effect = lambda x: 3000 if "log1.log" in x else 1000
+
+ result = roslog_list.invoke({"min_size": 2048})
+
+ self.assertEqual(result["total"], 1)
+ self.assertEqual(len(result["logs"]), 1)
+ self.assertEqual(result["logs"][0]["total"], 1)
+ self.assertIn("/log1.log", result["logs"][0]["files"][0])
+
+ @patch("src.rosa.tools.ros1.get_roslog_directories")
+ @patch("os.listdir")
+ @patch("os.path.isfile")
+ @patch("os.path.getsize")
+ def test_roslog_list_with_blacklist(
+ self, mock_getsize, mock_isfile, mock_listdir, mock_get_roslog_directories
+ ):
+ mock_get_roslog_directories.return_value = {"default": "/mock/log/dir"}
+ mock_listdir.return_value = ["log1.log", "log2.log", "log3.log"]
+ mock_isfile.side_effect = lambda x: x.endswith(".log")
+ mock_getsize.side_effect = lambda x: 3000
+
+ result = roslog_list.invoke({"blacklist": ["log2"]})
+
+ self.assertEqual(result["total"], 1)
+ self.assertEqual(len(result["logs"]), 1)
+ self.assertEqual(result["logs"][0]["total"], 2)
+ self.assertNotIn("log2.log", result["logs"][0]["files"][0])
+
+ @patch("src.rosa.tools.ros1.get_roslog_directories")
+ @patch("os.listdir")
+ @patch("os.path.isfile")
+ @patch("os.path.getsize")
+ def test_roslog_list_no_logs(
+ self, mock_getsize, mock_isfile, mock_listdir, mock_get_roslog_directories
+ ):
+ mock_get_roslog_directories.return_value = {"default": "/mock/log/dir"}
+ mock_listdir.return_value = []
+ mock_isfile.side_effect = lambda x: x.endswith(".log")
+ mock_getsize.side_effect = lambda x: 3000
+
+ result = roslog_list.invoke({})
+
+ self.assertEqual(result["total"], 0)
+ self.assertEqual(len(result["logs"]), 0)
+
+ @patch("src.rosa.tools.ros1.get_roslog_directories")
+ @patch("os.listdir")
+ @patch("os.path.isfile")
+ @patch("os.path.getsize")
+ def test_roslog_list_with_multiple_directories(
+ self, mock_getsize, mock_isfile, mock_listdir, mock_get_roslog_directories
+ ):
+ mock_get_roslog_directories.return_value = {
+ "default": "/mock/log/dir1",
+ "latest": "/mock/log/dir2",
+ }
+ mock_listdir.side_effect = lambda x: (
+ ["log1.log", "log2.log"] if "dir1" in x else ["log3.log", "log4.log"]
+ )
+ mock_isfile.side_effect = lambda x: x.endswith(".log")
+ mock_getsize.side_effect = lambda x: 3000
+
+ result = roslog_list.invoke({})
+
+ self.assertEqual(result["total"], 2)
+ self.assertEqual(len(result["logs"]), 2)
+ self.assertEqual(result["logs"][0]["total"], 2)
+ self.assertEqual(result["logs"][1]["total"], 2)
+
+ @patch("rospy.loginfo")
+ @patch("src.rosa.tools.ros1.get_entities")
+ def test_rostopic_list_returns_all_topics(self, mock_get_entities, mock_loginfo):
+ mock_get_entities.return_value = (10, 10, 10, ["topic1", "topic2"])
+ result = rostopic_list.invoke({})
+ self.assertEqual(result["total"], 10)
+ self.assertEqual(result["in_namespace"], 10)
+ self.assertEqual(result["match_pattern"], 10)
+ self.assertEqual(result["topics"], ["topic1", "topic2"])
+
+ @patch("rospy.loginfo")
+ @patch("src.rosa.tools.ros1.get_entities")
+ def test_rostopic_list_with_pattern(self, mock_get_entities, mock_loginfo):
+ mock_get_entities.return_value = (10, 10, 2, ["topic1", "topic2"])
+ result = rostopic_list.invoke({"pattern": "topic"})
+ self.assertEqual(result["match_pattern"], 2)
+ self.assertEqual(result["topics"], ["topic1", "topic2"])
+
+ @patch("rospy.loginfo")
+ @patch("src.rosa.tools.ros1.get_entities")
+ def test_rostopic_list_with_namespace(self, mock_get_entities, mock_loginfo):
+ mock_get_entities.return_value = (
+ 10,
+ 5,
+ 5,
+ ["namespace/topic1", "namespace/topic2"],
+ )
+ result = rostopic_list.invoke({"namespace": "namespace"})
+ self.assertEqual(result["in_namespace"], 5)
+ self.assertEqual(result["topics"], ["namespace/topic1", "namespace/topic2"])
+
+ @patch("rospy.loginfo")
+ @patch("src.rosa.tools.ros1.get_entities")
+ def test_rostopic_list_with_blacklist(self, mock_get_entities, mock_loginfo):
+ mock_get_entities.return_value = (2, 2, 2, ["topic1", "topic2"])
+ result = rostopic_list.invoke({"blacklist": ["topic2"]})
+ self.assertEqual(result["topics"], ["topic1"])
+
+ @patch("rospy.loginfo")
+ @patch("src.rosa.tools.ros1.get_entities")
+ def test_rostopic_list_no_topics_available(self, mock_get_entities, mock_loginfo):
+ mock_get_entities.return_value = (
+ 0,
+ 0,
+ 0,
+ ["There are currently no topics available in the system."],
+ )
+ result = rostopic_list.invoke({})
+ self.assertEqual(result["total"], 0)
+ self.assertEqual(
+ result["topics"], ["There are currently no topics available in the system."]
+ )
+
+ @patch("src.rosa.tools.ros1.rostopic.get_topic_list")
+ @patch("src.rosa.tools.ros1.rosnode.get_node_names")
+ def test_get_entities_topics(self, mock_get_node_names, mock_get_topic_list):
+ mock_get_topic_list.return_value = (
+ [(f"/topic{i}", "type") for i in range(5)],
+ [(f"/topic{i}", "type") for i in range(5, 10)],
+ )
+ total, in_namespace, match_pattern, entities = get_entities("topic", None, None)
+ self.assertEqual(total, 10)
+ self.assertEqual(in_namespace, 10)
+ self.assertEqual(match_pattern, 10)
+ self.assertEqual(len(entities), 10)
+
+ @patch("src.rosa.tools.ros1.rostopic.get_topic_list")
+ @patch("src.rosa.tools.ros1.rosnode.get_node_names")
+ def test_get_entities_nodes(self, mock_get_node_names, mock_get_topic_list):
+ mock_get_node_names.return_value = [f"/node{i}" for i in range(10)]
+ total, in_namespace, match_pattern, entities = get_entities("node", None, None)
+ self.assertEqual(total, 10)
+ self.assertEqual(in_namespace, 10)
+ self.assertEqual(match_pattern, 10)
+ self.assertEqual(len(entities), 10)
+
+ @patch("src.rosa.tools.ros1.rostopic.get_topic_list")
+ @patch("src.rosa.tools.ros1.rosnode.get_node_names")
+ def test_get_entities_with_namespace(
+ self, mock_get_node_names, mock_get_topic_list
+ ):
+ mock_get_topic_list.return_value = (
+ [(f"/namespace/topic{i}", "type") for i in range(5)]
+ + [(f"/topic{i}", "type") for i in range(5)],
+ [(f"/namespace/topic{i}", "type") for i in range(5, 10)],
+ )
+
+ mock_get_node_names.return_value = [
+ f"/namespace/node{i}" for i in range(10)
+ ] + [f"/node{i}" for i in range(10)]
+
+ total, in_namespace, match_pattern, entities = get_entities(
+ "topic", None, "/namespace"
+ )
+ self.assertEqual(total, 15)
+ self.assertEqual(in_namespace, 10)
+ self.assertEqual(match_pattern, 10)
+ self.assertEqual(len(entities), 10)
+
+ total, in_namespace, match_pattern, entities = get_entities(
+ "node", None, "/namespace"
+ )
+ self.assertEqual(total, 20)
+ self.assertEqual(in_namespace, 10)
+ self.assertEqual(match_pattern, 10)
+ self.assertEqual(len(entities), 10)
+
+ @patch("src.rosa.tools.ros1.rostopic.get_topic_list")
+ @patch("src.rosa.tools.ros1.rosnode.get_node_names")
+ def test_get_entities_with_pattern(self, mock_get_node_names, mock_get_topic_list):
+ mock_get_topic_list.return_value = (
+ [(f"/topic{i}", "type") for i in range(5)],
+ [(f"/topic{i}", "type") for i in range(5, 10)],
+ )
+ total, in_namespace, match_pattern, entities = get_entities(
+ "topic", "topic[0-4]", None
+ )
+ self.assertEqual(total, 10)
+ self.assertEqual(in_namespace, 10)
+ self.assertEqual(match_pattern, 5)
+ self.assertEqual(len(entities), 5)
+
+ @patch("src.rosa.tools.ros1.rostopic.get_topic_list")
+ @patch("src.rosa.tools.ros1.rosnode.get_node_names")
+ def test_get_entities_with_blacklist(
+ self, mock_get_node_names, mock_get_topic_list
+ ):
+ mock_get_topic_list.return_value = (
+ [(f"/topic{i}", "type") for i in range(5)],
+ [(f"/topic{i}", "type") for i in range(5, 10)],
+ )
+ total, in_namespace, match_pattern, entities = get_entities(
+ "topic", None, None, blacklist=["/topic0", "/topic1"]
+ )
+ self.assertEqual(total, 10)
+ self.assertEqual(in_namespace, 10)
+ self.assertEqual(match_pattern, 10)
+ self.assertEqual(len(entities), 8)
+
+ @patch("src.rosa.tools.ros1.rostopic.get_topic_list")
+ @patch("src.rosa.tools.ros1.rosnode.get_node_names")
+ def test_get_entities_no_entities(self, mock_get_node_names, mock_get_topic_list):
+ mock_get_topic_list.return_value = ([], [])
+ total, in_namespace, match_pattern, entities = get_entities("topic", None, None)
+ self.assertEqual(total, 0)
+ self.assertEqual(in_namespace, 0)
+ self.assertEqual(match_pattern, 0)
+ self.assertEqual(
+ entities, ["There are currently no topics available in the system."]
+ )
+
+ @patch("src.rosa.tools.ros1.rostopic.get_topic_list")
+ @patch("src.rosa.tools.ros1.rosnode.get_node_names")
+ def test_get_entities_no_namespace_entities(
+ self, mock_get_node_names, mock_get_topic_list
+ ):
+ mock_get_topic_list.return_value = (
+ [(f"/topic{i}", "type") for i in range(5)],
+ [(f"/topic{i}", "type") for i in range(5, 10)],
+ )
+ total, in_namespace, match_pattern, entities = get_entities(
+ "topic", None, "/nonexistent"
+ )
+ self.assertEqual(total, 10)
+ self.assertEqual(in_namespace, 0)
+ self.assertEqual(match_pattern, 0)
+ self.assertEqual(
+ entities,
+ [
+ "There are currently no topics available using the '/nonexistent' namespace."
+ ],
+ )
+
+ @patch("src.rosa.tools.ros1.rostopic.get_topic_list")
+ @patch("src.rosa.tools.ros1.rosnode.get_node_names")
+ def test_get_entities_no_pattern_entities(
+ self, mock_get_node_names, mock_get_topic_list
+ ):
+ mock_get_topic_list.return_value = (
+ [(f"/topic{i}", "type") for i in range(5)],
+ [(f"/topic{i}", "type") for i in range(5, 10)],
+ )
+ total, in_namespace, match_pattern, entities = get_entities(
+ "topic", "nonexistent", None
+ )
+ self.assertEqual(total, 10)
+ self.assertEqual(in_namespace, 10)
+ self.assertEqual(match_pattern, 0)
+ self.assertEqual(
+ entities,
+ ["There are currently no topics available matching the specified pattern."],
+ )
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/rosa/tools/test_ros2.py b/tests/rosa/tools/test_ros2.py
new file mode 100644
index 0000000..2dc4cf5
--- /dev/null
+++ b/tests/rosa/tools/test_ros2.py
@@ -0,0 +1,290 @@
+# Copyright (c) 2024. Jet Propulsion Laboratory. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import subprocess
+import unittest
+from unittest.mock import patch
+
+try:
+ from src.rosa.tools.ros2 import (
+ execute_ros_command,
+ ros2_node_list,
+ ros2_topic_list,
+ ros2_topic_echo,
+ ros2_service_list,
+ ros2_node_info,
+ ros2_param_list,
+ ros2_param_get,
+ ros2_param_set,
+ )
+except ModuleNotFoundError:
+ pass
+
+
+@unittest.skipIf(
+ os.environ.get("ROS_VERSION") == "1",
+ "Skipping ROS2 tests because ROS_VERSION is set to 1",
+)
+class TestROS2Tools(unittest.TestCase):
+
+ @patch("src.rosa.tools.ros2.subprocess.check_output")
+ def test_execute_valid_ros2_command(self, mock_check_output):
+ mock_check_output.return_value = b"Node /example_node\n"
+ success, output = execute_ros_command("ros2 node list")
+ self.assertTrue(success)
+ self.assertEqual(output, "Node /example_node\n")
+
+ @patch("src.rosa.tools.ros2.subprocess.check_output")
+ def test_execute_invalid_ros2_command(self, mock_check_output):
+ mock_check_output.side_effect = subprocess.CalledProcessError(
+ 1, "ros2 node list"
+ )
+ success, output = execute_ros_command("ros2 node list")
+ self.assertFalse(success)
+ self.assertIn(
+ "Command 'ros2 node list' returned non-zero exit status 1.", output
+ )
+
+ def test_execute_command_with_invalid_prefix(self):
+ with self.assertRaises(ValueError):
+ execute_ros_command("invalid node list")
+
+ def test_execute_command_with_invalid_subcommand(self):
+ with self.assertRaises(ValueError):
+ execute_ros_command("ros2 invalid_subcommand")
+
+ def test_execute_command_with_insufficient_arguments(self):
+ with self.assertRaises(ValueError):
+ execute_ros_command("ros2")
+
+ @patch("src.rosa.tools.ros2.execute_ros_command")
+ def test_ros2_node_list_returns_nodes(self, mock_execute):
+ mock_execute.return_value = (True, "/node1\n/node2\n")
+ result = ros2_node_list.invoke({"pattern": None})
+ self.assertEqual(result, {"nodes": ["/node1", "/node2"]})
+
+ @patch("src.rosa.tools.ros2.execute_ros_command")
+ def test_ros2_node_list_with_pattern(self, mock_execute):
+ mock_execute.return_value = (True, "/node1\n/node2\n")
+ result = ros2_node_list.invoke({"pattern": "node1"})
+ self.assertEqual(result, {"nodes": ["/node1"]})
+
+ @patch("src.rosa.tools.ros2.execute_ros_command")
+ def test_ros2_node_list_with_blacklist(self, mock_execute):
+ mock_execute.return_value = (True, "/node1\n/node2\n")
+ result = ros2_node_list.invoke({"blacklist": ["node2"]})
+ self.assertEqual(result, {"nodes": ["/node1"]})
+
+ @patch("src.rosa.tools.ros2.execute_ros_command")
+ def test_ros2_node_list_invalid_command(self, mock_execute):
+ mock_execute.return_value = (False, "Invalid command")
+ result = ros2_node_list.invoke({"pattern": None})
+ self.assertEqual(result, {"nodes": ["Invalid command"]})
+
+ @patch("src.rosa.tools.ros2.execute_ros_command")
+ def test_ros2_topic_list_returns_topics(self, mock_execute):
+ mock_execute.return_value = (True, "/topic1\n/topic2\n")
+ result = ros2_topic_list.invoke({"pattern": None})
+ self.assertEqual(result, {"topics": ["/topic1", "/topic2"]})
+
+ @patch("src.rosa.tools.ros2.execute_ros_command")
+ def test_ros2_topic_list_with_pattern(self, mock_execute):
+ mock_execute.return_value = (True, "/topic1\n/topic2\n")
+ result = ros2_topic_list.invoke({"pattern": "topic1"})
+ self.assertEqual(result, {"topics": ["/topic1"]})
+
+ @patch("src.rosa.tools.ros2.execute_ros_command")
+ def test_ros2_topic_list_with_blacklist(self, mock_execute):
+ mock_execute.return_value = (True, "/topic1\n/topic2\n")
+ result = ros2_topic_list.invoke({"blacklist": ["topic2"]})
+ self.assertEqual(result, {"topics": ["/topic1"]})
+
+ @patch("src.rosa.tools.ros2.execute_ros_command")
+ def test_ros2_topic_list_invalid_command(self, mock_execute):
+ mock_execute.return_value = (False, "Invalid command")
+ result = ros2_topic_list.invoke({"pattern": None})
+ self.assertEqual(result, {"topics": ["Invalid command"]})
+
+ @patch("src.rosa.tools.ros2.execute_ros_command")
+ def test_ros2_topic_echo_success(self, mock_execute):
+ mock_execute.return_value = (True, "Message 1\n")
+ result = ros2_topic_echo.invoke(
+ {"topic": "/example_topic", "count": 1, "return_echoes": True}
+ )
+ self.assertEqual(result, {"echoes": ["Message 1\n"]})
+
+ @patch("src.rosa.tools.ros2.execute_ros_command")
+ def test_ros2_topic_echo_multiple_messages(self, mock_execute):
+ mock_execute.return_value = (True, "Message 1\n")
+ result = ros2_topic_echo.invoke(
+ {"topic": "/example_topic", "count": 3, "return_echoes": True}
+ )
+ self.assertEqual(
+ result, {"echoes": ["Message 1\n", "Message 1\n", "Message 1\n"]}
+ )
+
+ @patch("src.rosa.tools.ros2.execute_ros_command")
+ def test_ros2_topic_echo_invalid_topic(self, mock_execute):
+ mock_execute.return_value = (False, "Invalid topic")
+ result = ros2_topic_echo.invoke({"topic": "/invalid_topic", "count": 1})
+ self.assertEqual(result, {"error": "Invalid topic"})
+
+ @patch("src.rosa.tools.ros2.execute_ros_command")
+ def test_ros2_topic_echo_invalid_count(self, mock_execute):
+ result = ros2_topic_echo.invoke({"topic": "/example_topic", "count": 11})
+ self.assertEqual(result, {"error": "Count must be between 1 and 10."})
+
+ @patch("src.rosa.tools.ros2.execute_ros_command")
+ def test_ros2_topic_echo_command_failure(self, mock_execute):
+ mock_execute.return_value = (False, "Command failed")
+ result = ros2_topic_echo.invoke({"topic": "/example_topic", "count": 1})
+ self.assertEqual(result, {"error": "Command failed"})
+
+ @patch("src.rosa.tools.ros2.execute_ros_command")
+ def test_ros2_service_list_returns_services(self, mock_execute):
+ mock_execute.return_value = (True, "/service1\n/service2\n")
+ result = ros2_service_list.invoke({"pattern": None})
+ self.assertEqual(result, {"services": ["/service1", "/service2"]})
+
+ @patch("src.rosa.tools.ros2.execute_ros_command")
+ def test_ros2_service_list_with_pattern(self, mock_execute):
+ mock_execute.return_value = (True, "/service1\n/service2\n")
+ result = ros2_service_list.invoke({"pattern": "service1"})
+ self.assertEqual(result, {"services": ["/service1"]})
+
+ @patch("src.rosa.tools.ros2.execute_ros_command")
+ def test_ros2_service_list_with_blacklist(self, mock_execute):
+ mock_execute.return_value = (True, "/service1\n/service2\n")
+ result = ros2_service_list.invoke({"blacklist": ["service2"]})
+ self.assertEqual(result, {"services": ["/service1"]})
+
+ @patch("src.rosa.tools.ros2.execute_ros_command")
+ def test_ros2_service_list_invalid_command(self, mock_execute):
+ mock_execute.return_value = (False, "Invalid command")
+ result = ros2_service_list.invoke({"pattern": None})
+ self.assertEqual(result, {"services": ["Invalid command"]})
+
+ @patch("src.rosa.tools.ros2.execute_ros_command")
+ def test_ros2_node_info_success(self, mock_execute):
+ mock_execute.return_value = (True, "Node info for /node1")
+ result = ros2_node_info.invoke({"nodes": ["/node1"]})
+ self.assertEqual(result, {"/node1": "Node info for /node1"})
+
+ @patch("src.rosa.tools.ros2.execute_ros_command")
+ def test_ros2_node_info_multiple_nodes(self, mock_execute):
+ mock_execute.side_effect = [
+ (True, "Node info for /node1"),
+ (True, "Node info for /node2"),
+ ]
+ result = ros2_node_info.invoke({"nodes": ["/node1", "/node2"]})
+ self.assertEqual(
+ result, {"/node1": "Node info for /node1", "/node2": "Node info for /node2"}
+ )
+
+ @patch("src.rosa.tools.ros2.execute_ros_command")
+ def test_ros2_node_info_invalid_node(self, mock_execute):
+ mock_execute.return_value = (False, "Invalid node")
+ result = ros2_node_info.invoke({"nodes": ["/invalid_node"]})
+ self.assertEqual(result, {"/invalid_node": {"error": "Invalid node"}})
+
+ @patch("src.rosa.tools.ros2.execute_ros_command")
+ def test_ros2_node_info_command_failure(self, mock_execute):
+ mock_execute.return_value = (False, "Command failed")
+ result = ros2_node_info.invoke({"nodes": ["/node1"]})
+ self.assertEqual(result, {"/node1": {"error": "Command failed"}})
+
+ @patch("src.rosa.tools.ros2.execute_ros_command")
+ def test_ros2_param_list_returns_params_for_node(self, mock_execute):
+ mock_execute.return_value = (True, "param1\nparam2\n")
+ result = ros2_param_list.invoke({"node_name": "/example_node"})
+ self.assertEqual(result, {"/example_node": ["param1", "param2"]})
+
+ @patch("src.rosa.tools.ros2.execute_ros_command")
+ def test_ros2_param_list_returns_all_params(self, mock_execute):
+ mock_execute.return_value = (
+ True,
+ "/node1\n param1\n param2\n/node2\n param3\n",
+ )
+ result = ros2_param_list.invoke({})
+ self.assertEqual(result, {"/node1": ["param1", "param2"], "/node2": ["param3"]})
+
+ @patch("src.rosa.tools.ros2.execute_ros_command")
+ def test_ros2_param_list_with_pattern(self, mock_execute):
+ mock_execute.return_value = (True, "param1\nparam2\n")
+ result = ros2_param_list.invoke(
+ {"node_name": "/example_node", "pattern": "param1"}
+ )
+ self.assertEqual(result, {"/example_node": ["param1"]})
+
+ @patch("src.rosa.tools.ros2.execute_ros_command")
+ def test_ros2_param_list_with_blacklist(self, mock_execute):
+ mock_execute.return_value = (True, "param1\nparam2\n")
+ result = ros2_param_list.invoke(
+ {"node_name": "/example_node", "blacklist": ["param2"]}
+ )
+ self.assertEqual(result, {"/example_node": ["param1"]})
+
+ @patch("src.rosa.tools.ros2.execute_ros_command")
+ def test_ros2_param_list_invalid_command(self, mock_execute):
+ mock_execute.return_value = (False, "Invalid command")
+ result = ros2_param_list.invoke({"node_name": "/example_node"})
+ self.assertEqual(result, {"error": "Invalid command"})
+
+ @patch("src.rosa.tools.ros2.execute_ros_command")
+ def test_ros2_param_get_success(self, mock_execute):
+ mock_execute.return_value = (True, "value1")
+ result = ros2_param_get.invoke(
+ {"node_name": "/example_node", "param_name": "param1"}
+ )
+ self.assertEqual(result, {"param1": "value1"})
+
+ @patch("src.rosa.tools.ros2.execute_ros_command")
+ def test_ros2_param_get_invalid_command(self, mock_execute):
+ mock_execute.return_value = (False, "Invalid command")
+ result = ros2_param_get.invoke(
+ {"node_name": "/example_node", "param_name": "param1"}
+ )
+ self.assertEqual(result, {"error": "Invalid command"})
+
+ @patch("src.rosa.tools.ros2.execute_ros_command")
+ def test_ros2_param_set_success(self, mock_execute):
+ mock_execute.return_value = (True, "value1")
+ result = ros2_param_set.invoke(
+ {
+ "node_name": "/example_node",
+ "param_name": "param1",
+ "param_value": "value1",
+ }
+ )
+ self.assertEqual(result, {"param1": "value1"})
+
+ @patch("src.rosa.tools.ros2.execute_ros_command")
+ def test_ros2_param_set_invalid_command(self, mock_execute):
+ mock_execute.return_value = (False, "Invalid command")
+ result = ros2_param_set.invoke(
+ {
+ "node_name": "/example_node",
+ "param_name": "param1",
+ "param_value": "value1",
+ }
+ )
+ self.assertEqual(result, {"error": "Invalid command"})
+
+
+if __name__ == "__main__":
+ import os
+
+ if os.environ.get("ROS_VERSION") == 2:
+ unittest.main()
diff --git a/tests/rosa/tools/test_rosa_tools.py b/tests/rosa/tools/test_rosa_tools.py
new file mode 100644
index 0000000..4cecdf3
--- /dev/null
+++ b/tests/rosa/tools/test_rosa_tools.py
@@ -0,0 +1,110 @@
+# Copyright (c) 2024. Jet Propulsion Laboratory. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import unittest
+from unittest.mock import patch
+
+from langchain.agents import tool
+
+from src.rosa.tools import ROSATools, inject_blacklist
+
+
+@tool
+def sample_tool(blacklist=None):
+ """A sample tool that returns the blacklist."""
+ return blacklist
+
+
+class TestROSATools(unittest.TestCase):
+ def setUp(self):
+ self.ros_version = int(os.getenv("ROS_VERSION", 1))
+
+ def test_initializes_with_ros_version_1(self):
+ if self.ros_version == 1:
+ tools = ROSATools(ros_version=1)
+ self.assertEqual(tools._ROSATools__ros_version, 1)
+ else:
+ with self.assertRaises(ModuleNotFoundError):
+ tools = ROSATools(ros_version=1)
+ self.assertEqual(tools._ROSATools__ros_version, 1)
+
+ def test_initializes_with_ros_version_2(self):
+ if self.ros_version == 2:
+ tools = ROSATools(ros_version=2)
+ self.assertEqual(tools._ROSATools__ros_version, 2)
+ else:
+ with self.assertRaises(ModuleNotFoundError):
+ tools = ROSATools(ros_version=2)
+ self.assertEqual(tools._ROSATools__ros_version, 2)
+
+ def test_raises_value_error_for_invalid_ros_version(self):
+ if self.ros_version == 1:
+ with self.assertRaises(ModuleNotFoundError):
+ ROSATools(ros_version=2)
+ else:
+ with self.assertRaises(ModuleNotFoundError):
+ ROSATools(ros_version=1)
+
+ @patch("src.rosa.tools.calculation")
+ @patch("src.rosa.tools.log")
+ @patch("src.rosa.tools.system")
+ def test_adds_default_tools(self, mock_system, mock_log, mock_calculation):
+ if self.ros_version == 1:
+ tools = ROSATools(ros_version=1)
+ else:
+ tools = ROSATools(ros_version=2)
+ self.assertIn(mock_calculation.return_value, tools.get_tools())
+ self.assertIn(mock_log.return_value, tools.get_tools())
+ self.assertIn(mock_system.return_value, tools.get_tools())
+
+ def test_injects_blacklist_into_tool_function(self):
+ def sample_tool(blacklist=None):
+ return blacklist
+
+ decorated_tool = inject_blacklist(["item1", "item2"])(sample_tool)
+ self.assertEqual(decorated_tool(), ["item1", "item2"])
+
+ def test_blacklist_gets_concatenated(self):
+ decorated_tool = inject_blacklist(["item1", "item2"])(sample_tool)
+ self.assertEqual(
+ decorated_tool({"blacklist": ["item3"]}),
+ ["item1", "item2", "item3"],
+ )
+
+
+@unittest.skipIf(os.environ.get("ROS_VERSION") == "2", "Skipping ROS 1 tests")
+class TestROSA1Tools(unittest.TestCase):
+ @patch("src.rosa.tools.ros1")
+ def test_ros1_tools(self, mock_ros1):
+ tools = ROSATools(ros_version=1)
+ self.assertIn(mock_ros1.return_value, tools.get_tools())
+ with self.assertRaises(ModuleNotFoundError):
+ tools = ROSATools(ros_version=2)
+ self.assertIn(mock_ros1.return_value, tools.get_tools())
+
+
+@unittest.skipIf(os.environ.get("ROS_VERSION") == "1", "Skipping ROS 2 tests")
+class TestROSA2Tools(unittest.TestCase):
+ @patch("src.rosa.tools.ros2")
+ def test_ros2_tools(self, mock_ros2):
+ tools = ROSATools(ros_version=2)
+ self.assertIn(mock_ros2.return_value, tools.get_tools())
+ with self.assertRaises(ModuleNotFoundError):
+ tools = ROSATools(ros_version=1)
+ self.assertIn(mock_ros2.return_value, tools.get_tools())
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/rosa/tools/test_system.py b/tests/rosa/tools/test_system.py
new file mode 100644
index 0000000..55924a1
--- /dev/null
+++ b/tests/rosa/tools/test_system.py
@@ -0,0 +1,60 @@
+# Copyright (c) 2024. Jet Propulsion Laboratory. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import time
+import unittest
+
+from langchain.globals import get_debug, get_verbose, set_debug
+
+from src.rosa.tools.system import set_verbosity, set_debuging, wait
+
+
+class TestSystemTools(unittest.TestCase):
+
+ def test_sets_verbosity_to_true(self):
+ result = set_verbosity.invoke({"enable_verbose_messages": True})
+ self.assertEqual(result, "Verbose messages are now enabled.")
+ self.assertTrue(get_verbose())
+ result = set_verbosity.invoke({"enable_verbose_messages": False})
+ self.assertEqual(result, "Verbose messages are now disabled.")
+ self.assertFalse(get_verbose())
+
+ def test_sets_debug_to_true(self):
+ result = set_debuging.invoke({"enable_debug_messages": True})
+ self.assertEqual(result, "Debug messages are now enabled.")
+ self.assertTrue(get_debug())
+ set_debug(False)
+ result = set_debuging.invoke({"enable_debug_messages": False})
+ self.assertEqual(result, "Debug messages are now disabled.")
+ self.assertFalse(get_debug())
+
+ def test_waits_for_specified_seconds(self):
+ start = time.time()
+ result = wait.invoke({"seconds": 1.0})
+ end = time.time()
+
+ self.assertTrue(result.startswith("Waited exactly"))
+ self.assertAlmostEqual(end - start, 1.0, places=1)
+
+ def test_waits_for_zero_seconds(self):
+ start = time.time()
+ result = wait.invoke({"seconds": 0})
+ end = time.time()
+
+ self.assertTrue(result.startswith("Waited exactly"))
+ self.assertAlmostEqual(end - start, 0.0, places=1)
+
+
+if __name__ == "__main__":
+ unittest.main()