Skip to content

Commit 27d00e2

Browse files
committed
Add a basic mock implementation of generic traits
1 parent 63a282b commit 27d00e2

File tree

6 files changed

+308
-4
lines changed

6 files changed

+308
-4
lines changed

Cargo.lock

Lines changed: 24 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,6 @@ members = [
88
"roslibrust_codegen_macro",
99
"roslibrust_genmsg",
1010
"roslibrust_test",
11+
"roslibrust_mock",
1112
]
1213
resolver = "2"

roslibrust/src/lib.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,11 @@
100100
101101
mod rosbridge;
102102
pub use rosbridge::*;
103-
use roslibrust_codegen::RosServiceType;
103+
104+
// Re export the codegen traits so that crates that only interact with abstract messages
105+
// don't need to depend on the codegen crate
106+
pub use roslibrust_codegen::RosMessageType;
107+
pub use roslibrust_codegen::RosServiceType;
104108

105109
#[cfg(feature = "rosapi")]
106110
pub mod rosapi;

roslibrust_mock/Cargo.toml

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
[package]
2+
name = "roslibrust_mock"
3+
version = "0.1.0"
4+
edition = "2021"
5+
6+
[dependencies]
7+
roslibrust = { path = "../roslibrust", features = ["topic_provider"] }
8+
tokio = { version = "1.41", features = ["sync", "rt-multi-thread", "macros"] }
9+
# Used for serializing messages
10+
bincode = "1.3"
11+
# We add logging to aid in debugging tests
12+
log = "0.4"
13+
14+
[dev-dependencies]
15+
roslibrust_codegen = { path = "../roslibrust_codegen" }
16+
roslibrust_codegen_macro = { path = "../roslibrust_codegen_macro" }

roslibrust_mock/README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# RosLibRust Mock
2+
3+
A mock implementation of roslibrust's generic traits for use in building automated testing of nodes.

roslibrust_mock/src/lib.rs

Lines changed: 259 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,259 @@
1+
use std::collections::BTreeMap;
2+
use std::sync::Arc;
3+
4+
use roslibrust::topic_provider::*;
5+
use roslibrust::RosLibRustError;
6+
use roslibrust::RosLibRustResult;
7+
use roslibrust::RosMessageType;
8+
9+
use roslibrust::RosServiceType;
10+
use roslibrust::ServiceFn;
11+
use tokio::sync::broadcast as Channel;
12+
use tokio::sync::RwLock;
13+
14+
use log::*;
15+
16+
type TypeErasedCallback = Arc<
17+
dyn Fn(Vec<u8>) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>>
18+
+ Send
19+
+ Sync
20+
+ 'static,
21+
>;
22+
23+
pub struct MockRos {
24+
// We could probably achieve some fancier type erasure than actually serializing the data
25+
// but this ends up being pretty simple
26+
topics: RwLock<BTreeMap<String, (Channel::Sender<Vec<u8>>, Channel::Receiver<Vec<u8>>)>>,
27+
services: RwLock<BTreeMap<String, TypeErasedCallback>>,
28+
}
29+
30+
impl MockRos {
31+
pub fn new() -> Self {
32+
Self {
33+
topics: RwLock::new(BTreeMap::new()),
34+
services: RwLock::new(BTreeMap::new()),
35+
}
36+
}
37+
}
38+
39+
// This is a very basic mocking of sending and receiving messages over topics
40+
// It does not implement automatic shutdown of topics on dropping
41+
impl TopicProvider for MockRos {
42+
type Publisher<T: RosMessageType> = MockPublisher<T>;
43+
type Subscriber<T: RosMessageType> = MockSubscriber<T>;
44+
45+
async fn advertise<T: RosMessageType>(
46+
&self,
47+
topic: &str,
48+
) -> RosLibRustResult<Self::Publisher<T>> {
49+
// Check if we already have this channel
50+
{
51+
let topics = self.topics.read().await;
52+
if let Some((sender, _)) = topics.get(topic) {
53+
debug!("Issued new publisher to existing topic {}", topic);
54+
return Ok(MockPublisher {
55+
sender: sender.clone(),
56+
_marker: Default::default(),
57+
});
58+
}
59+
} // Drop read lock here
60+
// Create a new channel
61+
let tx_rx = Channel::channel(10);
62+
let tx_copy = tx_rx.0.clone();
63+
let mut topics = self.topics.write().await;
64+
topics.insert(topic.to_string(), tx_rx);
65+
debug!("Created new publisher and channel for topic {}", topic);
66+
Ok(MockPublisher {
67+
sender: tx_copy,
68+
_marker: Default::default(),
69+
})
70+
}
71+
72+
async fn subscribe<T: RosMessageType>(
73+
&self,
74+
topic: &str,
75+
) -> RosLibRustResult<Self::Subscriber<T>> {
76+
// Check if we already have this channel
77+
{
78+
let topics = self.topics.read().await;
79+
if let Some((_, receiver)) = topics.get(topic) {
80+
debug!("Issued new subscriber to existing topic {}", topic);
81+
return Ok(MockSubscriber {
82+
receiver: receiver.resubscribe(),
83+
_marker: Default::default(),
84+
});
85+
}
86+
} // Drop read lock here
87+
// Create a new channel
88+
let tx_rx = Channel::channel(10);
89+
let rx_copy = tx_rx.1.resubscribe();
90+
let mut topics = self.topics.write().await;
91+
topics.insert(topic.to_string(), tx_rx);
92+
debug!("Created new subscriber and channel for topic {}", topic);
93+
Ok(MockSubscriber {
94+
receiver: rx_copy,
95+
_marker: Default::default(),
96+
})
97+
}
98+
}
99+
100+
pub struct MockServiceClient<T: RosServiceType> {
101+
callback: TypeErasedCallback,
102+
_marker: std::marker::PhantomData<T>,
103+
}
104+
105+
impl<T: RosServiceType> Service<T> for MockServiceClient<T> {
106+
async fn call(&self, request: &T::Request) -> RosLibRustResult<T::Response> {
107+
let data = bincode::serialize(request)
108+
.map_err(|e| RosLibRustError::SerializationError(e.to_string()))?;
109+
let response = (self.callback)(data)
110+
.map_err(|e| RosLibRustError::SerializationError(e.to_string()))?;
111+
let response = bincode::deserialize(&response[..])
112+
.map_err(|e| RosLibRustError::SerializationError(e.to_string()))?;
113+
Ok(response)
114+
}
115+
}
116+
117+
impl ServiceProvider for MockRos {
118+
type ServiceClient<T: RosServiceType> = MockServiceClient<T>;
119+
type ServiceServer = ();
120+
121+
async fn service_client<T: RosServiceType + 'static>(
122+
&self,
123+
topic: &str,
124+
) -> RosLibRustResult<Self::ServiceClient<T>> {
125+
let services = self.services.read().await;
126+
if let Some(callback) = services.get(topic) {
127+
return Ok(MockServiceClient {
128+
callback: callback.clone(),
129+
_marker: Default::default(),
130+
});
131+
}
132+
Err(RosLibRustError::Disconnected)
133+
}
134+
135+
async fn advertise_service<T: RosServiceType + 'static, F>(
136+
&self,
137+
topic: &str,
138+
server: F,
139+
) -> RosLibRustResult<Self::ServiceServer>
140+
where
141+
F: ServiceFn<T>,
142+
{
143+
// Type erase the service function here
144+
let erased_closure =
145+
move |message: Vec<u8>| -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
146+
let request = bincode::deserialize(&message[..])
147+
.map_err(|e| RosLibRustError::SerializationError(e.to_string()))?;
148+
let response = server(request)?;
149+
let bytes = bincode::serialize(&response)
150+
.map_err(|e| RosLibRustError::SerializationError(e.to_string()))?;
151+
Ok(bytes)
152+
};
153+
let erased_closure = Arc::new(erased_closure);
154+
let mut services = self.services.write().await;
155+
services.insert(topic.to_string(), erased_closure);
156+
157+
// We technically need to hand back a token that shuts the service down here
158+
// But we haven't implemented that yet in this mock
159+
Ok(())
160+
}
161+
}
162+
163+
pub struct MockPublisher<T: RosMessageType> {
164+
sender: Channel::Sender<Vec<u8>>,
165+
_marker: std::marker::PhantomData<T>,
166+
}
167+
168+
impl<T: RosMessageType> Publish<T> for MockPublisher<T> {
169+
async fn publish(&self, data: &T) -> RosLibRustResult<()> {
170+
let data = bincode::serialize(data)
171+
.map_err(|e| RosLibRustError::SerializationError(e.to_string()))?;
172+
self.sender
173+
.send(data)
174+
.map_err(|_e| RosLibRustError::Disconnected)?;
175+
debug!("Sent data on topic {}", T::ROS_TYPE_NAME);
176+
Ok(())
177+
}
178+
}
179+
180+
pub struct MockSubscriber<T: RosMessageType> {
181+
receiver: Channel::Receiver<Vec<u8>>,
182+
_marker: std::marker::PhantomData<T>,
183+
}
184+
185+
impl<T: RosMessageType> Subscribe<T> for MockSubscriber<T> {
186+
async fn next(&mut self) -> RosLibRustResult<T> {
187+
let data = self
188+
.receiver
189+
.recv()
190+
.await
191+
.map_err(|_| RosLibRustError::Disconnected)?;
192+
let msg = bincode::deserialize(&data[..])
193+
.map_err(|e| RosLibRustError::SerializationError(e.to_string()))?;
194+
debug!("Received data on topic {}", T::ROS_TYPE_NAME);
195+
Ok(msg)
196+
}
197+
}
198+
199+
#[cfg(test)]
200+
mod tests {
201+
use super::*;
202+
203+
roslibrust_codegen_macro::find_and_generate_ros_messages!(
204+
"assets/ros1_common_interfaces/std_msgs",
205+
"assets/ros1_common_interfaces/ros_comm_msgs/std_srvs"
206+
);
207+
208+
#[tokio::test(flavor = "multi_thread")]
209+
async fn test_mock_topics() {
210+
let mock_ros = MockRos::new();
211+
212+
let pub_handle = mock_ros
213+
.advertise::<std_msgs::String>("test_topic")
214+
.await
215+
.unwrap();
216+
let mut sub_handle = mock_ros
217+
.subscribe::<std_msgs::String>("test_topic")
218+
.await
219+
.unwrap();
220+
221+
let msg = std_msgs::String {
222+
data: "Hello, world!".to_string(),
223+
};
224+
225+
pub_handle.publish(&msg).await.unwrap();
226+
227+
let received_msg = sub_handle.next().await.unwrap();
228+
229+
assert_eq!(msg, received_msg);
230+
}
231+
232+
#[tokio::test(flavor = "multi_thread")]
233+
async fn test_mock_services() {
234+
let mock_topics = MockRos::new();
235+
236+
let server_fn = |request: std_srvs::SetBoolRequest| {
237+
Ok(std_srvs::SetBoolResponse {
238+
success: request.data,
239+
message: "You set my bool!".to_string(),
240+
})
241+
};
242+
243+
mock_topics
244+
.advertise_service::<std_srvs::SetBool, _>("test_service", server_fn)
245+
.await
246+
.unwrap();
247+
248+
let client = mock_topics
249+
.service_client::<std_srvs::SetBool>("test_service")
250+
.await
251+
.unwrap();
252+
253+
let request = std_srvs::SetBoolRequest { data: true };
254+
255+
let response = client.call(&request).await.unwrap();
256+
assert_eq!(response.success, true);
257+
assert_eq!(response.message, "You set my bool!");
258+
}
259+
}

0 commit comments

Comments
 (0)