Skip to content

Commit aed288e

Browse files
committed
typed action
fmt typed action example typed action TypedChannel + Chnl trait
1 parent 7c059dd commit aed288e

File tree

7 files changed

+391
-4
lines changed

7 files changed

+391
-4
lines changed

examples/typed_action.rs

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
//! Use [`TypedAction`] to rewrite compute_dag.rs
2+
//!
3+
//! Only use Dag, execute a job. The graph is as follows:
4+
//!
5+
//! ↱----------↴
6+
//! B -→ E --→ G
7+
//! ↗ ↗ ↗
8+
//! A --→ C /
9+
//! ↘ ↘ /
10+
//! D -→ F
11+
//!
12+
//! The final execution result is 272.
13+
14+
use std::sync::Arc;
15+
16+
use async_trait::async_trait;
17+
use dagrs::{
18+
connection::{in_channel::TypedInChannels, out_channel::TypedOutChannels},
19+
node::typed_action::TypedAction,
20+
Content, DefaultNode, EnvVar, Graph, Node, NodeTable, Output,
21+
};
22+
23+
const BASE: &str = "base";
24+
25+
struct Compute(usize);
26+
27+
#[async_trait]
28+
impl TypedAction for Compute {
29+
type I = usize;
30+
type O = usize;
31+
32+
async fn run(
33+
&self,
34+
mut in_channels: TypedInChannels<Self::I>,
35+
out_channels: TypedOutChannels<Self::O>,
36+
env: Arc<EnvVar>,
37+
) -> Output {
38+
let base = env.get::<usize>(BASE).unwrap();
39+
let mut sum = self.0;
40+
41+
// Collect all input values from input channels
42+
let inputs = in_channels
43+
.map(|result| {
44+
if let Ok(Some(value)) = result {
45+
*value
46+
} else {
47+
0
48+
}
49+
})
50+
.await;
51+
52+
// Calculate the sum
53+
for input in inputs {
54+
sum += input * base;
55+
}
56+
57+
// Broadcast the result to all output channels
58+
out_channels.broadcast(sum).await;
59+
60+
Output::Out(Some(Content::new(sum)))
61+
}
62+
}
63+
64+
fn main() {
65+
env_logger::init();
66+
67+
let mut node_table = NodeTable::default();
68+
69+
let a = DefaultNode::with_action("Compute A".to_string(), Compute(1), &mut node_table);
70+
let a_id = a.id();
71+
72+
let b = DefaultNode::with_action("Compute B".to_string(), Compute(2), &mut node_table);
73+
let b_id = b.id();
74+
75+
let mut c = DefaultNode::new("Compute C".to_string(), &mut node_table);
76+
c.set_action(Compute(4));
77+
let c_id = c.id();
78+
79+
let mut d = DefaultNode::new("Compute D".to_string(), &mut node_table);
80+
d.set_action(Compute(8));
81+
let d_id = d.id();
82+
83+
let e = DefaultNode::with_action("Compute E".to_string(), Compute(16), &mut node_table);
84+
let e_id = e.id();
85+
let f = DefaultNode::with_action("Compute F".to_string(), Compute(32), &mut node_table);
86+
let f_id = f.id();
87+
88+
let g = DefaultNode::with_action("Compute G".to_string(), Compute(64), &mut node_table);
89+
let g_id = g.id();
90+
91+
let mut graph = Graph::new();
92+
vec![a, b, c, d, e, f, g]
93+
.into_iter()
94+
.for_each(|node| graph.add_node(node));
95+
96+
graph.add_edge(a_id, vec![b_id, c_id, d_id]);
97+
graph.add_edge(b_id, vec![e_id, g_id]);
98+
graph.add_edge(c_id, vec![e_id, f_id]);
99+
graph.add_edge(d_id, vec![f_id]);
100+
graph.add_edge(e_id, vec![g_id]);
101+
graph.add_edge(f_id, vec![g_id]);
102+
103+
let mut env = EnvVar::new(node_table);
104+
env.set("base", 2usize);
105+
graph.set_env(env);
106+
107+
match graph.start() {
108+
Ok(_) => {
109+
let res = graph
110+
.get_results::<usize>()
111+
.get(&g_id)
112+
.unwrap()
113+
.clone()
114+
.unwrap();
115+
// 验证执行结果
116+
assert_eq!(*res, 272)
117+
}
118+
Err(e) => {
119+
panic!("图执行失败: {:?}", e);
120+
}
121+
}
122+
}

src/connection/in_channel.rs

Lines changed: 95 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use std::{collections::HashMap, sync::Arc};
1+
use std::{collections::HashMap, marker::PhantomData, sync::Arc};
22

33
use futures::future::join_all;
44
use tokio::sync::{broadcast, mpsc, Mutex};
@@ -141,7 +141,6 @@ impl InChannel {
141141
},
142142
}
143143
}
144-
145144
/// Close the channel and drop the messages inside.
146145
fn close(&mut self) {
147146
match self {
@@ -152,6 +151,100 @@ impl InChannel {
152151
}
153152
}
154153

154+
/// # Typed Input Channels
155+
/// A hash-table mapping `NodeId` to `InChannel`. This provides type-safe channel communication
156+
/// between nodes.
157+
#[derive(Default)]
158+
pub struct TypedInChannels<T: Send + Sync + 'static>(
159+
pub(crate) HashMap<NodeId, Arc<Mutex<InChannel>>>,
160+
// maker for type T
161+
pub(crate) PhantomData<T>,
162+
);
163+
164+
impl<T: Send + Sync + 'static> TypedInChannels<T> {
165+
/// Perform a blocking receive on the incoming channel from `NodeId`.
166+
pub fn blocking_recv_from(&mut self, id: &NodeId) -> Result<Option<Arc<T>>, RecvErr> {
167+
match self.get(id) {
168+
Some(channel) => {
169+
let content: Content = channel.blocking_lock().blocking_recv()?;
170+
Ok(content.into_inner())
171+
}
172+
None => Err(RecvErr::NoSuchChannel),
173+
}
174+
}
175+
176+
/// Perform a asynchronous receive on the incoming channel from `NodeId`.
177+
pub async fn recv_from(&mut self, id: &NodeId) -> Result<Option<Arc<T>>, RecvErr> {
178+
match self.get(id) {
179+
Some(channel) => {
180+
let content: Content = channel.lock().await.recv().await?;
181+
Ok(content.into_inner())
182+
}
183+
None => Err(RecvErr::NoSuchChannel),
184+
}
185+
}
186+
187+
/// Calls `blocking_recv` for all the [`InChannel`]s, and applies transformation `f` to
188+
/// the return values of the call.
189+
pub fn blocking_map<F, U>(&mut self, mut f: F) -> Vec<U>
190+
where
191+
F: FnMut(Result<Option<Arc<T>>, RecvErr>) -> U,
192+
{
193+
self.keys()
194+
.into_iter()
195+
.map(|id| f(self.blocking_recv_from(&id)))
196+
.collect()
197+
}
198+
199+
/// Calls `recv` for all the [`InChannel`]s, and applies transformation `f` to
200+
/// the return values of the call asynchronously.
201+
pub async fn map<F, U>(&mut self, mut f: F) -> Vec<U>
202+
where
203+
F: FnMut(Result<Option<Arc<T>>, RecvErr>) -> U,
204+
{
205+
let futures = self.0.iter_mut().map(|(_, c)| async {
206+
let content: Content = c.lock().await.recv().await?;
207+
Ok(content.into_inner())
208+
});
209+
join_all(futures).await.into_iter().map(|x| f(x)).collect()
210+
}
211+
212+
/// Close the channel by the given `NodeId` asynchronously, and remove the channel in this map.
213+
pub async fn close_async(&mut self, id: &NodeId) {
214+
if let Some(c) = self.get(id) {
215+
c.lock().await.close();
216+
self.0.remove(id);
217+
}
218+
}
219+
220+
/// Close the channel by the given `NodeId`, and remove the channel in this map.
221+
pub fn close(&mut self, id: &NodeId) {
222+
if let Some(c) = self.get(id) {
223+
c.blocking_lock().close();
224+
self.0.remove(id);
225+
}
226+
}
227+
228+
pub(crate) fn insert(&mut self, node_id: NodeId, channel: Arc<Mutex<InChannel>>) {
229+
self.0.insert(node_id, channel);
230+
}
231+
232+
pub(crate) fn close_all(&mut self) {
233+
self.0.values_mut().for_each(|c| c.blocking_lock().close());
234+
}
235+
236+
fn get(&self, id: &NodeId) -> Option<Arc<Mutex<InChannel>>> {
237+
match self.0.get(id) {
238+
Some(c) => Some(c.clone()),
239+
None => None,
240+
}
241+
}
242+
243+
fn keys(&self) -> Vec<NodeId> {
244+
self.0.keys().map(|x| *x).collect()
245+
}
246+
}
247+
155248
/// # Input Channel Error Types
156249
/// - NoSuchChannel: try to get a channel with an invalid `NodeId`.
157250
/// - Closed: the channel to receive messages from is closed and empty already.

src/connection/information_packet.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use std::{any::Any, sync::Arc};
33
/// Container type to store task output.
44
#[derive(Debug, Clone)]
55
pub struct Content {
6-
inner: Arc<dyn Any + Send + Sync>,
6+
pub inner: Arc<dyn Any + Send + Sync>,
77
}
88

99
impl Content {

src/connection/out_channel.rs

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use std::{collections::HashMap, sync::Arc};
1+
use std::{collections::HashMap, marker::PhantomData, sync::Arc};
22

33
use futures::future::join_all;
44
use tokio::sync::{broadcast, mpsc, Mutex};
@@ -124,3 +124,73 @@ pub enum SendErr {
124124
NoSuchChannel,
125125
ClosedChannel(Content),
126126
}
127+
128+
/// # Typed Output Channels
129+
/// A hash-table mapping [`NodeId`] to [`OutChannel`]. This provides type-safe channel communication
130+
/// between nodes.
131+
#[derive(Default)]
132+
pub struct TypedOutChannels<T: Send + Sync + 'static>(
133+
pub(crate) HashMap<NodeId, Arc<Mutex<OutChannel>>>,
134+
// maker for type T
135+
pub(crate) PhantomData<T>,
136+
);
137+
138+
impl<T: Send + Sync + 'static> TypedOutChannels<T> {
139+
/// Perform a blocking send on the outcoming channel from `NodeId`.
140+
pub fn blocking_send_to(&self, id: &NodeId, content: T) -> Result<(), SendErr> {
141+
match self.get(id) {
142+
Some(channel) => channel.blocking_lock().blocking_send(Content::new(content)),
143+
None => Err(SendErr::NoSuchChannel),
144+
}
145+
}
146+
147+
/// Perform a asynchronous send on the outcoming channel from `NodeId`.
148+
pub async fn send_to(&self, id: &NodeId, content: T) -> Result<(), SendErr> {
149+
match self.get(id) {
150+
Some(channel) => channel.lock().await.send(Content::new(content)).await,
151+
None => Err(SendErr::NoSuchChannel),
152+
}
153+
}
154+
155+
/// Broadcasts the `content` to all the [`TypedOutChannel`]s asynchronously.
156+
pub async fn broadcast(&self, content: T) -> Vec<Result<(), SendErr>> {
157+
let content = Content::new(content);
158+
let futures = self
159+
.0
160+
.iter()
161+
.map(|(_, c)| async { c.lock().await.send(content.clone()).await });
162+
163+
join_all(futures).await
164+
}
165+
166+
/// Blocking broadcasts the `content` to all the [`TypedOutChannel`]s.
167+
pub fn blocking_broadcast(&self, content: T) -> Vec<Result<(), SendErr>> {
168+
let content = Content::new(content);
169+
self.0
170+
.iter()
171+
.map(|(_, c)| c.blocking_lock().blocking_send(content.clone()))
172+
.collect()
173+
}
174+
175+
/// Close the channel by the given `NodeId`, and remove the channel in this map.
176+
pub fn close(&mut self, id: &NodeId) {
177+
if let Some(_) = self.get(id) {
178+
self.0.remove(id);
179+
}
180+
}
181+
182+
pub(crate) fn close_all(&mut self) {
183+
self.0.clear();
184+
}
185+
186+
fn get(&self, id: &NodeId) -> Option<Arc<Mutex<OutChannel>>> {
187+
match self.0.get(id) {
188+
Some(c) => Some(c.clone()),
189+
None => None,
190+
}
191+
}
192+
193+
pub(crate) fn insert(&mut self, node_id: NodeId, channel: Arc<Mutex<OutChannel>>) {
194+
self.0.insert(node_id, channel);
195+
}
196+
}

src/node/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@ pub mod conditional_node;
33
pub mod default_node;
44
pub mod id_allocate;
55
pub mod node;
6+
pub mod typed_action;

src/node/node.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,18 @@ pub trait Node: Send + Sync {
4444
fn loop_structure(&self) -> Option<Vec<Arc<Mutex<dyn Node>>>> {
4545
None
4646
}
47+
48+
/// Returns true if this node has TypedContent input.
49+
/// By default, it returns false.
50+
fn has_typed_input(&self) -> bool {
51+
false
52+
}
53+
54+
/// Returns true if this node has TypedContent output.
55+
/// By default, it returns false.
56+
fn has_typed_output(&self) -> bool {
57+
false
58+
}
4759
}
4860

4961
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, Ord, PartialOrd)]

0 commit comments

Comments
 (0)