Skip to content
This repository was archived by the owner on Oct 18, 2023. It is now read-only.

Commit a4eaec3

Browse files
committed
remove txn state from execute return
It is now passed to the result builder
1 parent 70a906f commit a4eaec3

File tree

10 files changed

+116
-84
lines changed

10 files changed

+116
-84
lines changed

Cargo.lock

+1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

sqld/src/connection/libsql.rs

+54-35
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use std::sync::atomic::{AtomicBool, Ordering};
44
use std::sync::Arc;
55

66
use parking_lot::{Mutex, RwLock};
7-
use rusqlite::{DatabaseName, ErrorCode, OpenFlags, StatementStatus};
7+
use rusqlite::{DatabaseName, ErrorCode, OpenFlags, StatementStatus, TransactionState};
88
use sqld_libsql_bindings::wal_hook::{TransparentMethods, WalMethodsHook};
99
use tokio::sync::{watch, Notify};
1010
use tokio::time::{Duration, Instant};
@@ -144,7 +144,6 @@ where
144144
}
145145
}
146146

147-
#[derive(Clone)]
148147
pub struct LibSqlConnection<W: WalHook> {
149148
inner: Arc<Mutex<Connection<W>>>,
150149
}
@@ -160,6 +159,12 @@ impl<W: WalHook> std::fmt::Debug for LibSqlConnection<W> {
160159
}
161160
}
162161

162+
impl<W: WalHook> Clone for LibSqlConnection<W> {
163+
fn clone(&self) -> Self {
164+
Self { inner: self.inner.clone() }
165+
}
166+
}
167+
163168
pub fn open_conn<W>(
164169
path: &Path,
165170
wal_methods: &'static WalMethodsHook<W>,
@@ -219,6 +224,15 @@ where
219224
inner: Arc::new(Mutex::new(conn)),
220225
})
221226
}
227+
228+
pub fn txn_status(&self) -> crate::Result<TxnStatus> {
229+
Ok(self
230+
.inner
231+
.lock()
232+
.conn
233+
.transaction_state(Some(DatabaseName::Main))?
234+
.into())
235+
}
222236
}
223237

224238
struct Connection<W: WalHook = TransparentMethods> {
@@ -351,6 +365,16 @@ unsafe extern "C" fn busy_handler<W: WalHook>(state: *mut c_void, _retries: c_in
351365
})
352366
}
353367

368+
impl From<TransactionState> for TxnStatus {
369+
fn from(value: TransactionState) -> Self {
370+
use TransactionState as Tx;
371+
match value {
372+
Tx::None => TxnStatus::Init,
373+
Tx::Read | Tx::Write => TxnStatus::Txn,
374+
_ => unreachable!(),
375+
}
376+
}
377+
}
354378
impl<W: WalHook> Connection<W> {
355379
fn new(
356380
path: &Path,
@@ -405,7 +429,7 @@ impl<W: WalHook> Connection<W> {
405429
this: Arc<Mutex<Self>>,
406430
pgm: Program,
407431
mut builder: B,
408-
) -> Result<(B, TxnStatus)> {
432+
) -> Result<B> {
409433
use rusqlite::TransactionState as Tx;
410434

411435
let state = this.lock().state.clone();
@@ -469,23 +493,18 @@ impl<W: WalHook> Connection<W> {
469493
results.push(res);
470494
}
471495

472-
let status = if matches!(
473-
this.lock()
474-
.conn
475-
.transaction_state(Some(DatabaseName::Main))?,
476-
Tx::Read | Tx::Write
477-
) {
478-
TxnStatus::Txn
479-
} else {
480-
TxnStatus::Init
481-
};
496+
let status = this
497+
.lock()
498+
.conn
499+
.transaction_state(Some(DatabaseName::Main))?
500+
.into();
482501

483502
builder.finish(
484503
*this.lock().current_frame_no_receiver.borrow_and_update(),
485504
status,
486505
)?;
487506

488-
Ok((builder, status))
507+
Ok(builder)
489508
}
490509

491510
fn execute_step(
@@ -736,7 +755,7 @@ where
736755
auth: Authenticated,
737756
builder: B,
738757
_replication_index: Option<FrameNo>,
739-
) -> Result<(B, TxnStatus)> {
758+
) -> Result<B> {
740759
check_program_auth(auth, &pgm)?;
741760
let conn = self.inner.clone();
742761
tokio::task::spawn_blocking(move || Connection::run(conn, pgm, builder))
@@ -828,7 +847,7 @@ mod test {
828847
fn test_libsql_conn_builder_driver() {
829848
test_driver(1000, |b| {
830849
let conn = setup_test_conn();
831-
Connection::run(conn, Program::seq(&["select * from test"]), b).map(|x| x.0)
850+
Connection::run(conn, Program::seq(&["select * from test"]), b)
832851
})
833852
}
834853

@@ -852,23 +871,23 @@ mod test {
852871

853872
tokio::time::pause();
854873
let conn = make_conn.make_connection().await.unwrap();
855-
let (_builder, state) = Connection::run(
874+
let _builder = Connection::run(
856875
conn.inner.clone(),
857876
Program::seq(&["BEGIN IMMEDIATE"]),
858877
TestBuilder::default(),
859878
)
860879
.unwrap();
861-
assert_eq!(state, TxnStatus::Txn);
880+
assert_eq!(conn.txn_status().unwrap(), TxnStatus::Txn);
862881

863882
tokio::time::advance(TXN_TIMEOUT * 2).await;
864883

865-
let (builder, state) = Connection::run(
884+
let builder = Connection::run(
866885
conn.inner.clone(),
867886
Program::seq(&["BEGIN IMMEDIATE"]),
868887
TestBuilder::default(),
869888
)
870889
.unwrap();
871-
assert_eq!(state, TxnStatus::Init);
890+
assert_eq!(conn.txn_status().unwrap(), TxnStatus::Init);
872891
assert!(matches!(builder.into_ret()[0], Err(Error::LibSqlTxTimeout)));
873892
}
874893

@@ -896,13 +915,13 @@ mod test {
896915
for _ in 0..10 {
897916
let conn = make_conn.make_connection().await.unwrap();
898917
set.spawn_blocking(move || {
899-
let (builder, state) = Connection::run(
900-
conn.inner,
918+
let builder = Connection::run(
919+
conn.inner.clone(),
901920
Program::seq(&["BEGIN IMMEDIATE"]),
902921
TestBuilder::default(),
903922
)
904923
.unwrap();
905-
assert_eq!(state, TxnStatus::Txn);
924+
assert_eq!(conn.txn_status().unwrap(), TxnStatus::Txn);
906925
assert!(builder.into_ret()[0].is_ok());
907926
});
908927
}
@@ -937,15 +956,15 @@ mod test {
937956

938957
let conn1 = make_conn.make_connection().await.unwrap();
939958
tokio::task::spawn_blocking({
940-
let conn = conn1.inner.clone();
959+
let conn = conn1.clone();
941960
move || {
942-
let (builder, state) = Connection::run(
943-
conn,
961+
let builder = Connection::run(
962+
conn.inner.clone(),
944963
Program::seq(&["BEGIN IMMEDIATE"]),
945964
TestBuilder::default(),
946965
)
947966
.unwrap();
948-
assert_eq!(state, TxnStatus::Txn);
967+
assert_eq!(conn.txn_status().unwrap(), TxnStatus::Txn);
949968
assert!(builder.into_ret()[0].is_ok());
950969
}
951970
})
@@ -954,16 +973,16 @@ mod test {
954973

955974
let conn2 = make_conn.make_connection().await.unwrap();
956975
let handle = tokio::task::spawn_blocking({
957-
let conn = conn2.inner.clone();
976+
let conn = conn2.clone();
958977
move || {
959978
let before = Instant::now();
960-
let (builder, state) = Connection::run(
961-
conn,
979+
let builder = Connection::run(
980+
conn.inner.clone(),
962981
Program::seq(&["BEGIN IMMEDIATE"]),
963982
TestBuilder::default(),
964983
)
965984
.unwrap();
966-
assert_eq!(state, TxnStatus::Txn);
985+
assert_eq!(conn.txn_status().unwrap(), TxnStatus::Txn);
967986
assert!(builder.into_ret()[0].is_ok());
968987
before.elapsed()
969988
}
@@ -973,12 +992,12 @@ mod test {
973992
tokio::time::sleep(wait_time).await;
974993

975994
tokio::task::spawn_blocking({
976-
let conn = conn1.inner.clone();
995+
let conn = conn1.clone();
977996
move || {
978-
let (builder, state) =
979-
Connection::run(conn, Program::seq(&["COMMIT"]), TestBuilder::default())
997+
let builder =
998+
Connection::run(conn.inner.clone(), Program::seq(&["COMMIT"]), TestBuilder::default())
980999
.unwrap();
981-
assert_eq!(state, TxnStatus::Init);
1000+
assert_eq!(conn.txn_status().unwrap(), TxnStatus::Init);
9821001
assert!(builder.into_ret()[0].is_ok());
9831002
}
9841003
})

sqld/src/connection/mod.rs

+8-8
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use tokio::{sync::Semaphore, time::timeout};
88
use crate::auth::Authenticated;
99
use crate::error::Error;
1010
use crate::query::{Params, Query};
11-
use crate::query_analysis::{Statement, TxnStatus};
11+
use crate::query_analysis::Statement;
1212
use crate::query_result_builder::{IgnoreResult, QueryResultBuilder};
1313
use crate::replication::FrameNo;
1414
use crate::Result;
@@ -32,7 +32,7 @@ pub trait Connection: Send + Sync + 'static {
3232
auth: Authenticated,
3333
response_builder: B,
3434
replication_index: Option<FrameNo>,
35-
) -> Result<(B, TxnStatus)>;
35+
) -> Result<B>;
3636

3737
/// Execute all the queries in the batch sequentially.
3838
/// If an query in the batch fails, the remaining queries are ignores, and the batch current
@@ -43,7 +43,7 @@ pub trait Connection: Send + Sync + 'static {
4343
auth: Authenticated,
4444
result_builder: B,
4545
replication_index: Option<FrameNo>,
46-
) -> Result<(B, TxnStatus)> {
46+
) -> Result<B> {
4747
let batch_len = batch.len();
4848
let mut steps = make_batch_program(batch);
4949

@@ -67,11 +67,11 @@ pub trait Connection: Send + Sync + 'static {
6767

6868
// ignore the rollback result
6969
let builder = result_builder.take(batch_len);
70-
let (builder, state) = self
70+
let builder = self
7171
.execute_program(pgm, auth, builder, replication_index)
7272
.await?;
7373

74-
Ok((builder.into_inner(), state))
74+
Ok(builder.into_inner())
7575
}
7676

7777
/// Execute all the queries in the batch sequentially.
@@ -82,7 +82,7 @@ pub trait Connection: Send + Sync + 'static {
8282
auth: Authenticated,
8383
result_builder: B,
8484
replication_index: Option<FrameNo>,
85-
) -> Result<(B, TxnStatus)> {
85+
) -> Result<B> {
8686
let steps = make_batch_program(batch);
8787
let pgm = Program::new(steps);
8888
self.execute_program(pgm, auth, result_builder, replication_index)
@@ -312,7 +312,7 @@ impl<DB: Connection> Connection for TrackedConnection<DB> {
312312
auth: Authenticated,
313313
builder: B,
314314
replication_index: Option<FrameNo>,
315-
) -> crate::Result<(B, TxnStatus)> {
315+
) -> crate::Result<B> {
316316
self.atime.store(now_millis(), Ordering::Relaxed);
317317
self.inner
318318
.execute_program(pgm, auth, builder, replication_index)
@@ -367,7 +367,7 @@ mod test {
367367
_auth: Authenticated,
368368
_builder: B,
369369
_replication_index: Option<FrameNo>,
370-
) -> crate::Result<(B, TxnStatus)> {
370+
) -> crate::Result<B> {
371371
unreachable!()
372372
}
373373

sqld/src/connection/write_proxy.rs

+25-15
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ impl WriteProxyConnection {
184184
status: &mut TxnStatus,
185185
auth: Authenticated,
186186
builder: B,
187-
) -> Result<(B, TxnStatus)> {
187+
) -> Result<B> {
188188
self.stats.inc_write_requests_delegated();
189189
*status = TxnStatus::Invalid;
190190
let res = self
@@ -208,7 +208,7 @@ impl WriteProxyConnection {
208208
self.update_last_write_frame_no(current_frame_no);
209209
}
210210

211-
Ok((builder, new_status))
211+
Ok(builder)
212212
}
213213

214214
fn update_last_write_frame_no(&self, new_frame_no: FrameNo) {
@@ -339,8 +339,8 @@ impl RemoteConnection {
339339
}
340340
Err(e) => {
341341
tracing::error!("received error from connection stream: {e}");
342-
return Err(Error::StreamDisconnect)
343-
},
342+
return Err(Error::StreamDisconnect);
343+
}
344344
}
345345
}
346346

@@ -356,26 +356,30 @@ impl Connection for WriteProxyConnection {
356356
auth: Authenticated,
357357
builder: B,
358358
replication_index: Option<FrameNo>,
359-
) -> Result<(B, TxnStatus)> {
359+
) -> Result<B> {
360360
let mut state = self.state.lock().await;
361361

362362
// This is a fresh namespace, and it is not replicated yet, proxy the first request.
363363
if self.applied_frame_no_receiver.borrow().is_none() {
364364
self.execute_remote(pgm, &mut state, auth, builder).await
365365
} else if *state == TxnStatus::Init && pgm.is_read_only() {
366+
// set the state to invalid before doing anything, and set it to a valid state after.
367+
*state = TxnStatus::Invalid;
366368
self.wait_replication_sync(replication_index).await?;
367369
// We know that this program won't perform any writes. We attempt to run it on the
368370
// replica. If it leaves an open transaction, then this program is an interactive
369371
// transaction, so we rollback the replica, and execute again on the primary.
370-
let (builder, new_state) = self
372+
let builder = self
371373
.read_conn
372374
.execute_program(pgm.clone(), auth.clone(), builder, replication_index)
373375
.await?;
376+
let new_state = self.read_conn.txn_status()?;
374377
if new_state != TxnStatus::Init {
375378
self.read_conn.rollback(auth.clone()).await?;
376379
self.execute_remote(pgm, &mut state, auth, builder).await
377380
} else {
378-
Ok((builder, new_state))
381+
*state = new_state;
382+
Ok(builder)
379383
}
380384
} else {
381385
self.execute_remote(pgm, &mut state, auth, builder).await
@@ -433,7 +437,10 @@ pub mod test {
433437
use rand::Fill;
434438

435439
use super::*;
436-
use crate::{query_result_builder::test::test_driver, rpc::proxy::rpc::{ExecuteResults, query_result::RowResult}};
440+
use crate::{
441+
query_result_builder::test::test_driver,
442+
rpc::proxy::rpc::{query_result::RowResult, ExecuteResults},
443+
};
437444

438445
/// generate an arbitraty rpc value. see build.rs for usage.
439446
pub fn arbitrary_rpc_value(u: &mut Unstructured) -> arbitrary::Result<Vec<u8>> {
@@ -497,12 +504,15 @@ pub mod test {
497504
/// In this test, we generate random ExecuteResults, and ensures that the `execute_results_to_builder` drives the builder FSM correctly.
498505
#[test]
499506
fn test_execute_results_to_builder() {
500-
test_driver(1000, |b| -> std::result::Result<crate::query_result_builder::test::FsmQueryBuilder, Error> {
501-
let mut data = [0; 10_000];
502-
data.try_fill(&mut rand::thread_rng()).unwrap();
503-
let mut un = Unstructured::new(&data);
504-
let res = ExecuteResults::arbitrary(&mut un).unwrap();
505-
execute_results_to_builder(res, b, &QueryBuilderConfig::default())
506-
});
507+
test_driver(
508+
1000,
509+
|b| -> std::result::Result<crate::query_result_builder::test::FsmQueryBuilder, Error> {
510+
let mut data = [0; 10_000];
511+
data.try_fill(&mut rand::thread_rng()).unwrap();
512+
let mut un = Unstructured::new(&data);
513+
let res = ExecuteResults::arbitrary(&mut un).unwrap();
514+
execute_results_to_builder(res, b, &QueryBuilderConfig::default())
515+
},
516+
);
507517
}
508518
}

0 commit comments

Comments
 (0)