Skip to content

Commit 36a5614

Browse files
fix(lib,provers,tasks): move from sync to async trait (#328)
1 parent 959bdea commit 36a5614

File tree

9 files changed

+62
-65
lines changed

9 files changed

+62
-65
lines changed

Cargo.lock

+1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

lib/Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ utoipa = { workspace = true }
4646
cfg-if = { workspace = true }
4747
tracing = { workspace = true }
4848
bincode = { workspace = true }
49+
async-trait = { workspace = true }
4950

5051
# [target.'cfg(feature = "std")'.dependencies]
5152
flate2 = { workspace = true, optional = true }

lib/src/prover.rs

+5-3
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,16 @@ pub struct Proof {
3737
pub kzg_proof: Option<String>,
3838
}
3939

40+
#[async_trait::async_trait]
4041
pub trait IdWrite: Send {
41-
fn store_id(&mut self, key: ProofKey, id: String) -> ProverResult<()>;
42+
async fn store_id(&mut self, key: ProofKey, id: String) -> ProverResult<()>;
4243

43-
fn remove_id(&mut self, key: ProofKey) -> ProverResult<()>;
44+
async fn remove_id(&mut self, key: ProofKey) -> ProverResult<()>;
4445
}
4546

47+
#[async_trait::async_trait]
4648
pub trait IdStore: IdWrite {
47-
fn read_id(&self, key: ProofKey) -> ProverResult<String>;
49+
async fn read_id(&self, key: ProofKey) -> ProverResult<String>;
4850
}
4951

5052
#[allow(async_fn_in_trait)]

provers/risc0/driver/src/bonsai.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ pub async fn prove_bonsai<O: Eq + Debug + DeserializeOwned>(
219219
)?;
220220

221221
if let Some(id_store) = id_store {
222-
id_store.store_id(proof_key, session.uuid.clone())?;
222+
id_store.store_id(proof_key, session.uuid.clone()).await?;
223223
}
224224

225225
verify_bonsai_receipt(image_id, expected_output, session.uuid.clone(), 8).await

provers/risc0/driver/src/lib.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -112,11 +112,11 @@ impl Prover for Risc0Prover {
112112
}
113113

114114
async fn cancel(key: ProofKey, id_store: Box<&mut dyn IdStore>) -> ProverResult<()> {
115-
let uuid = id_store.read_id(key)?;
115+
let uuid = id_store.read_id(key).await?;
116116
cancel_proof(uuid)
117117
.await
118118
.map_err(|e| ProverError::GuestError(e.to_string()))?;
119-
id_store.remove_id(key)?;
119+
id_store.remove_id(key).await?;
120120
Ok(())
121121
}
122122
}

provers/sp1/driver/src/lib.rs

+8-6
Original file line numberDiff line numberDiff line change
@@ -69,10 +69,12 @@ impl Prover for Sp1Prover {
6969
ProverError::GuestError("Sp1: creating proof failed".to_owned())
7070
})?;
7171
if let Some(id_store) = id_store {
72-
id_store.store_id(
73-
(input.chain_spec.chain_id, output.hash, SP1_PROVER_CODE),
74-
proof_id.clone(),
75-
)?;
72+
id_store
73+
.store_id(
74+
(input.chain_spec.chain_id, output.hash, SP1_PROVER_CODE),
75+
proof_id.clone(),
76+
)
77+
.await?;
7678
}
7779
let proof = {
7880
let mut is_claimed = false;
@@ -136,7 +138,7 @@ impl Prover for Sp1Prover {
136138
}
137139

138140
async fn cancel(key: ProofKey, id_store: Box<&mut dyn IdStore>) -> ProverResult<()> {
139-
let proof_id = id_store.read_id(key)?;
141+
let proof_id = id_store.read_id(key).await?;
140142
let private_key = env::var("SP1_PRIVATE_KEY").map_err(|_| {
141143
ProverError::GuestError("SP1_PRIVATE_KEY must be set for remote proving".to_owned())
142144
})?;
@@ -145,7 +147,7 @@ impl Prover for Sp1Prover {
145147
.unclaim_proof(proof_id, UnclaimReason::Abandoned, "".to_owned())
146148
.await
147149
.map_err(|_| ProverError::GuestError("Sp1: couldn't unclaim proof".to_owned()))?;
148-
id_store.remove_id(key)?;
150+
id_store.remove_id(key).await?;
149151
Ok(())
150152
}
151153
}

tasks/src/adv_sqlite.rs

+18-22
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ use raiko_lib::{
166166
use rusqlite::{
167167
named_params, {Connection, OpenFlags},
168168
};
169-
use tokio::{runtime::Builder, sync::Mutex};
169+
use tokio::sync::Mutex;
170170

171171
use crate::{
172172
TaskDescriptor, TaskManager, TaskManagerError, TaskManagerOpts, TaskManagerResult,
@@ -833,34 +833,30 @@ impl TaskDb {
833833
}
834834
}
835835

836+
#[async_trait::async_trait]
836837
impl IdWrite for SqliteTaskManager {
837-
fn store_id(&mut self, key: ProofKey, id: String) -> ProverResult<()> {
838-
let rt = Builder::new_current_thread().enable_all().build()?;
839-
rt.block_on(async move {
840-
let task_db = self.arc_task_db.lock().await;
841-
task_db.store_id(key, id)
842-
})
843-
.map_err(|e| ProverError::StoreError(e.to_string()))
838+
async fn store_id(&mut self, key: ProofKey, id: String) -> ProverResult<()> {
839+
let task_db = self.arc_task_db.lock().await;
840+
task_db
841+
.store_id(key, id)
842+
.map_err(|e| ProverError::StoreError(e.to_string()))
844843
}
845844

846-
fn remove_id(&mut self, key: ProofKey) -> ProverResult<()> {
847-
let rt = Builder::new_current_thread().enable_all().build()?;
848-
rt.block_on(async move {
849-
let task_db = self.arc_task_db.lock().await;
850-
task_db.remove_id(key)
851-
})
852-
.map_err(|e| ProverError::StoreError(e.to_string()))
845+
async fn remove_id(&mut self, key: ProofKey) -> ProverResult<()> {
846+
let task_db = self.arc_task_db.lock().await;
847+
task_db
848+
.remove_id(key)
849+
.map_err(|e| ProverError::StoreError(e.to_string()))
853850
}
854851
}
855852

853+
#[async_trait::async_trait]
856854
impl IdStore for SqliteTaskManager {
857-
fn read_id(&self, key: ProofKey) -> ProverResult<String> {
858-
let rt = Builder::new_current_thread().enable_all().build()?;
859-
rt.block_on(async move {
860-
let task_db = self.arc_task_db.lock().await;
861-
task_db.read_id(key)
862-
})
863-
.map_err(|e| ProverError::StoreError(e.to_string()))
855+
async fn read_id(&self, key: ProofKey) -> ProverResult<String> {
856+
let task_db = self.arc_task_db.lock().await;
857+
task_db
858+
.read_id(key)
859+
.map_err(|e| ProverError::StoreError(e.to_string()))
864860
}
865861
}
866862

tasks/src/lib.rs

+11-9
Original file line numberDiff line numberDiff line change
@@ -179,27 +179,29 @@ pub struct TaskManagerWrapper {
179179
manager: TaskManagerInstance,
180180
}
181181

182+
#[async_trait::async_trait]
182183
impl IdWrite for TaskManagerWrapper {
183-
fn store_id(&mut self, key: ProofKey, id: String) -> ProverResult<()> {
184+
async fn store_id(&mut self, key: ProofKey, id: String) -> ProverResult<()> {
184185
match &mut self.manager {
185-
TaskManagerInstance::InMemory(ref mut manager) => manager.store_id(key, id),
186-
TaskManagerInstance::Sqlite(ref mut manager) => manager.store_id(key, id),
186+
TaskManagerInstance::InMemory(ref mut manager) => manager.store_id(key, id).await,
187+
TaskManagerInstance::Sqlite(ref mut manager) => manager.store_id(key, id).await,
187188
}
188189
}
189190

190-
fn remove_id(&mut self, key: ProofKey) -> ProverResult<()> {
191+
async fn remove_id(&mut self, key: ProofKey) -> ProverResult<()> {
191192
match &mut self.manager {
192-
TaskManagerInstance::InMemory(ref mut manager) => manager.remove_id(key),
193-
TaskManagerInstance::Sqlite(ref mut manager) => manager.remove_id(key),
193+
TaskManagerInstance::InMemory(ref mut manager) => manager.remove_id(key).await,
194+
TaskManagerInstance::Sqlite(ref mut manager) => manager.remove_id(key).await,
194195
}
195196
}
196197
}
197198

199+
#[async_trait::async_trait]
198200
impl IdStore for TaskManagerWrapper {
199-
fn read_id(&self, key: ProofKey) -> ProverResult<String> {
201+
async fn read_id(&self, key: ProofKey) -> ProverResult<String> {
200202
match &self.manager {
201-
TaskManagerInstance::InMemory(manager) => manager.read_id(key),
202-
TaskManagerInstance::Sqlite(manager) => manager.read_id(key),
203+
TaskManagerInstance::InMemory(manager) => manager.read_id(key).await,
204+
TaskManagerInstance::Sqlite(manager) => manager.read_id(key).await,
203205
}
204206
}
205207
}

tasks/src/mem_db.rs

+15-22
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ use std::{
1414

1515
use chrono::Utc;
1616
use raiko_lib::prover::{IdStore, IdWrite, ProofKey, ProverError, ProverResult};
17-
use tokio::{runtime::Builder, sync::Mutex};
17+
use tokio::sync::Mutex;
1818
use tracing::{debug, info};
1919

2020
use crate::{
@@ -143,34 +143,27 @@ impl InMemoryTaskDb {
143143
}
144144
}
145145

146+
#[async_trait::async_trait]
146147
impl IdWrite for InMemoryTaskManager {
147-
fn store_id(&mut self, key: ProofKey, id: String) -> ProverResult<()> {
148-
let rt = Builder::new_current_thread().enable_all().build()?;
149-
rt.block_on(async move {
150-
let mut db = self.db.lock().await;
151-
db.store_id(key, id)
152-
})
153-
.map_err(|e| ProverError::StoreError(e.to_string()))
148+
async fn store_id(&mut self, key: ProofKey, id: String) -> ProverResult<()> {
149+
let mut db = self.db.lock().await;
150+
db.store_id(key, id)
151+
.map_err(|e| ProverError::StoreError(e.to_string()))
154152
}
155153

156-
fn remove_id(&mut self, key: ProofKey) -> ProverResult<()> {
157-
let rt = Builder::new_current_thread().enable_all().build()?;
158-
rt.block_on(async move {
159-
let mut db = self.db.lock().await;
160-
db.remove_id(key)
161-
})
162-
.map_err(|e| ProverError::StoreError(e.to_string()))
154+
async fn remove_id(&mut self, key: ProofKey) -> ProverResult<()> {
155+
let mut db = self.db.lock().await;
156+
db.remove_id(key)
157+
.map_err(|e| ProverError::StoreError(e.to_string()))
163158
}
164159
}
165160

161+
#[async_trait::async_trait]
166162
impl IdStore for InMemoryTaskManager {
167-
fn read_id(&self, key: ProofKey) -> ProverResult<String> {
168-
let rt = Builder::new_current_thread().enable_all().build()?;
169-
rt.block_on(async move {
170-
let mut db = self.db.lock().await;
171-
db.read_id(key)
172-
})
173-
.map_err(|e| ProverError::StoreError(e.to_string()))
163+
async fn read_id(&self, key: ProofKey) -> ProverResult<String> {
164+
let mut db = self.db.lock().await;
165+
db.read_id(key)
166+
.map_err(|e| ProverError::StoreError(e.to_string()))
174167
}
175168
}
176169

0 commit comments

Comments
 (0)