Skip to content

Commit 9e0c78f

Browse files
committed
Add proper handling of node storages to ParallelFlow
Adds merging of child storages into parent storage
1 parent 133c7cb commit 9e0c78f

3 files changed

Lines changed: 15 additions & 7 deletions

File tree

src/flows/parallel_flow/chain_run/poll.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ pub trait ChainPollParallel<Output>: Send {
1212
self: Pin<&mut Self>,
1313
cx: &mut Context<'_>,
1414
tail_ready: bool,
15-
storage: &mut Storage,
15+
storage_acc: &mut Vec<Storage>,
1616
) -> Poll<Output>;
1717
}
1818

@@ -28,19 +28,19 @@ where
2828
self: Pin<&mut Self>,
2929
cx: &mut Context<'_>,
3030
tail_ready: bool,
31-
storage: &mut Storage,
31+
storage_acc: &mut Vec<Storage>,
3232
) -> Poll<Result<(HeadOutput, TailOutput), Error>> {
3333
let (head, tail) = unsafe { self.get_unchecked_mut() };
3434
let (head, mut tail) = unsafe { (Pin::new_unchecked(head), Pin::new_unchecked(tail)) };
3535
let tail_ready = tail.as_mut().poll(cx).is_ready() && tail_ready;
3636

37-
let Poll::Ready(res) = ChainPollParallel::poll(head, cx, tail_ready, storage) else {
37+
let Poll::Ready(res) = ChainPollParallel::poll(head, cx, tail_ready, storage_acc) else {
3838
return Poll::Pending;
3939
};
4040
match res {
4141
Ok(head_out) => match tail.take_output().unwrap() {
4242
Ok((tail_out, node_storage)) => {
43-
// TODO: merge storage
43+
storage_acc.push(node_storage);
4444
Poll::Ready(Ok((head_out, tail_out)))
4545
}
4646
Err(e) => Poll::Ready(Err(e)),
@@ -60,13 +60,13 @@ where
6060
self: Pin<&mut Self>,
6161
cx: &mut Context<'_>,
6262
tail_ready: bool,
63-
storage: &mut Storage,
63+
storage_acc: &mut Vec<Storage>,
6464
) -> Poll<Result<(HeadOutput,), Error>> {
6565
let mut head = unsafe { Pin::new_unchecked(&mut self.get_unchecked_mut().0) };
6666
if head.as_mut().poll(cx).is_ready() && tail_ready {
6767
match head.take_output().unwrap() {
6868
Ok((output, node_storage)) => {
69-
// TODO: merge storage
69+
storage_acc.push(node_storage);
7070
Poll::Ready(Ok((output,)))
7171
}
7272
Err(e) => Poll::Ready(Err(e)),

src/flows/parallel_flow/chain_run/run.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,12 @@ where
2020
{
2121
async fn run_with_storage(&self, input: Input, storage: &mut Storage) -> Result<Output, Error> {
2222
let fut_chain = self.spawn_with_storage(input, storage.new_gen());
23+
let mut storage_acc = Vec::with_capacity(U::NUM_FUTURES);
2324
let mut fut_chain = pin!(fut_chain);
24-
poll_fn(move |cx| ChainPollParallel::poll(fut_chain.as_mut(), cx, true, storage)).await
25+
let res =
26+
poll_fn(|cx| ChainPollParallel::poll(fut_chain.as_mut(), cx, true, &mut storage_acc))
27+
.await;
28+
storage.merge(&mut storage_acc);
29+
res
2530
}
2631
}

src/flows/parallel_flow/chain_run/spawn.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use crate::{
88

99
pub trait ChainSpawn<Input, Error, HeadOut, T> {
1010
type ChainOut;
11+
const NUM_FUTURES: usize;
1112

1213
fn spawn_with_storage(
1314
&self,
@@ -45,6 +46,7 @@ where
4546
Error: Send,
4647
{
4748
type ChainOut = Result<(HeadOut, NodeOutputStruct<TailNodeOutType>), Error>;
49+
const NUM_FUTURES: usize = Head::NUM_FUTURES + 1;
4850

4951
fn spawn_with_storage(
5052
&self,
@@ -86,6 +88,7 @@ where
8688
Error: Send,
8789
{
8890
type ChainOut = Result<(NodeOutputStruct<HeadNodeOutType>,), Error>;
91+
const NUM_FUTURES: usize = 1;
8992

9093
fn spawn_with_storage(
9194
&self,

0 commit comments

Comments
 (0)