Skip to content

Commit ec483b7

Browse files
feat(raiko): put the tasks that cannot run in parallel into pending list (#358)
* put the tasks that cannot run in parallel into pending list
  Signed-off-by: smtmfft <[email protected]>
* Update host/src/proof.rs
  Co-authored-by: Petar Vujović <[email protected]>
* Update host/src/proof.rs
  Co-authored-by: Petar Vujović <[email protected]>
* fix merge conflicts
* fix compile issue
* Update host/src/proof.rs
  Co-authored-by: Petar Vujović <[email protected]>
* Update host/src/proof.rs
  Co-authored-by: Petar Vujović <[email protected]>
* Update host/src/proof.rs
  Co-authored-by: Petar Vujović <[email protected]>
* Update host/src/proof.rs
  Co-authored-by: Petar Vujović <[email protected]>
* Update host/src/proof.rs
  Co-authored-by: Petar Vujović <[email protected]>
* Update host/src/proof.rs
  Co-authored-by: Petar Vujović <[email protected]>
---------
Signed-off-by: smtmfft <[email protected]>
Co-authored-by: Petar Vujović <[email protected]>
1 parent eb4d032 commit ec483b7

File tree

2 files changed

+72
-22
lines changed

2 files changed

+72
-22
lines changed

host/src/lib.rs

+3-2
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ pub struct ProverState {
152152
pub enum Message {
153153
Cancel(TaskDescriptor),
154154
Task(ProofRequest),
155+
TaskComplete(ProofRequest),
155156
CancelAggregate(AggregationOnlyRequest),
156157
Aggregate(AggregationOnlyRequest),
157158
}
@@ -200,9 +201,9 @@ impl ProverState {
200201

201202
let opts_clone = opts.clone();
202203
let chain_specs_clone = chain_specs.clone();
203-
204+
let sender = task_channel.clone();
204205
tokio::spawn(async move {
205-
ProofActor::new(receiver, opts_clone, chain_specs_clone)
206+
ProofActor::new(sender, receiver, opts_clone, chain_specs_clone)
206207
.run()
207208
.await;
208209
});

host/src/proof.rs

+69-20
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1-
use std::{collections::HashMap, str::FromStr, sync::Arc};
1+
use std::{
2+
collections::{HashMap, VecDeque},
3+
str::FromStr,
4+
sync::Arc,
5+
};
26

37
use anyhow::anyhow;
48
use raiko_core::{
@@ -16,10 +20,13 @@ use raiko_tasks::{get_task_manager, TaskDescriptor, TaskManager, TaskManagerWrap
1620
use reth_primitives::B256;
1721
use tokio::{
1822
select,
19-
sync::{mpsc::Receiver, Mutex, OwnedSemaphorePermit, Semaphore},
23+
sync::{
24+
mpsc::{Receiver, Sender},
25+
Mutex, OwnedSemaphorePermit, Semaphore,
26+
},
2027
};
2128
use tokio_util::sync::CancellationToken;
22-
use tracing::{error, info, warn};
29+
use tracing::{debug, error, info, warn};
2330

2431
use crate::{
2532
cache,
@@ -35,32 +42,42 @@ use crate::{
3542
pub struct ProofActor {
3643
opts: Opts,
3744
chain_specs: SupportedChainSpecs,
38-
tasks: Arc<Mutex<HashMap<TaskDescriptor, CancellationToken>>>,
3945
aggregate_tasks: Arc<Mutex<HashMap<AggregationOnlyRequest, CancellationToken>>>,
46+
running_tasks: Arc<Mutex<HashMap<TaskDescriptor, CancellationToken>>>,
47+
pending_tasks: Arc<Mutex<VecDeque<ProofRequest>>>,
4048
receiver: Receiver<Message>,
49+
sender: Sender<Message>,
4150
}
4251

4352
impl ProofActor {
44-
pub fn new(receiver: Receiver<Message>, opts: Opts, chain_specs: SupportedChainSpecs) -> Self {
45-
let tasks = Arc::new(Mutex::new(
53+
pub fn new(
54+
sender: Sender<Message>,
55+
receiver: Receiver<Message>,
56+
opts: Opts,
57+
chain_specs: SupportedChainSpecs,
58+
) -> Self {
59+
let running_tasks = Arc::new(Mutex::new(
4660
HashMap::<TaskDescriptor, CancellationToken>::new(),
4761
));
4862
let aggregate_tasks = Arc::new(Mutex::new(HashMap::<
4963
AggregationOnlyRequest,
5064
CancellationToken,
5165
>::new()));
66+
let pending_tasks = Arc::new(Mutex::new(VecDeque::<ProofRequest>::new()));
5267

5368
Self {
54-
tasks,
55-
aggregate_tasks,
5669
opts,
5770
chain_specs,
71+
aggregate_tasks,
72+
running_tasks,
73+
pending_tasks,
5874
receiver,
75+
sender,
5976
}
6077
}
6178

6279
pub async fn cancel_task(&mut self, key: TaskDescriptor) -> HostResult<()> {
63-
let tasks_map = self.tasks.lock().await;
80+
let tasks_map = self.running_tasks.lock().await;
6481
let Some(task) = tasks_map.get(&key) else {
6582
warn!("No task with those keys to cancel");
6683
return Ok(());
@@ -85,7 +102,7 @@ impl ProofActor {
85102
Ok(())
86103
}
87104

88-
pub async fn run_task(&mut self, proof_request: ProofRequest, _permit: OwnedSemaphorePermit) {
105+
pub async fn run_task(&mut self, proof_request: ProofRequest) {
89106
let cancel_token = CancellationToken::new();
90107

91108
let Ok((chain_id, blockhash)) = get_task_data(
@@ -106,10 +123,11 @@ impl ProofActor {
106123
proof_request.prover.clone().to_string(),
107124
));
108125

109-
let mut tasks = self.tasks.lock().await;
126+
let mut tasks = self.running_tasks.lock().await;
110127
tasks.insert(key.clone(), cancel_token.clone());
128+
let sender = self.sender.clone();
111129

112-
let tasks = self.tasks.clone();
130+
let tasks = self.running_tasks.clone();
113131
let opts = self.opts.clone();
114132
let chain_specs = self.chain_specs.clone();
115133

@@ -118,7 +136,7 @@ impl ProofActor {
118136
_ = cancel_token.cancelled() => {
119137
info!("Task cancelled");
120138
}
121-
result = Self::handle_message(proof_request, key.clone(), &opts, &chain_specs) => {
139+
result = Self::handle_message(proof_request.clone(), key.clone(), &opts, &chain_specs) => {
122140
match result {
123141
Ok(status) => {
124142
info!("Host handling message: {status:?}");
@@ -131,6 +149,11 @@ impl ProofActor {
131149
}
132150
let mut tasks = tasks.lock().await;
133151
tasks.remove(&key);
152+
// notify complete task to let next pending task run
153+
sender
154+
.send(Message::TaskComplete(proof_request))
155+
.await
156+
.expect("Couldn't send message");
134157
});
135158
}
136159

@@ -203,21 +226,47 @@ impl ProofActor {
203226
}
204227

205228
pub async fn run(&mut self) {
229+
// recv() is protected by outside mpsc, no lock needed here
206230
let semaphore = Arc::new(Semaphore::new(self.opts.concurrency_limit));
207-
208231
while let Some(message) = self.receiver.recv().await {
209232
match message {
210233
Message::Cancel(key) => {
234+
debug!("Message::Cancel task: {key:?}");
211235
if let Err(error) = self.cancel_task(key).await {
212236
error!("Failed to cancel task: {error}")
213237
}
214238
}
215239
Message::Task(proof_request) => {
216-
let permit = Arc::clone(&semaphore)
217-
.acquire_owned()
218-
.await
219-
.expect("Couldn't acquire permit");
220-
self.run_task(proof_request, permit).await;
240+
debug!("Message::Task proof_request: {proof_request:?}");
241+
let running_task_count = self.running_tasks.lock().await.len();
242+
if running_task_count < self.opts.concurrency_limit {
243+
info!("Running task {proof_request:?}");
244+
self.run_task(proof_request).await;
245+
} else {
246+
info!(
247+
"Task concurrency limit reached, current running {running_task_count:?}, pending: {:?}",
248+
self.pending_tasks.lock().await.len()
249+
);
250+
let mut pending_tasks = self.pending_tasks.lock().await;
251+
pending_tasks.push_back(proof_request);
252+
}
253+
}
254+
Message::TaskComplete(req) => {
255+
// pop up pending task if any task complete
256+
debug!("Message::TaskComplete: {req:?}");
257+
info!(
258+
"task completed, current running {:?}, pending: {:?}",
259+
self.running_tasks.lock().await.len(),
260+
self.pending_tasks.lock().await.len()
261+
);
262+
let mut pending_tasks = self.pending_tasks.lock().await;
263+
if let Some(proof_request) = pending_tasks.pop_front() {
264+
info!("Pop out pending task {proof_request:?}");
265+
self.sender
266+
.send(Message::Task(proof_request))
267+
.await
268+
.expect("Couldn't send message");
269+
}
221270
}
222271
Message::CancelAggregate(request) => {
223272
if let Err(error) = self.cancel_aggregation_task(request).await {
@@ -326,7 +375,7 @@ pub async fn handle_proof(
326375
store: Option<&mut TaskManagerWrapper>,
327376
) -> HostResult<Proof> {
328377
info!(
329-
"# Generating proof for block {} on {}",
378+
"Generating proof for block {} on {}",
330379
proof_request.block_number, proof_request.network
331380
);
332381

0 commit comments

Comments
 (0)