@@ -4,7 +4,7 @@ use std::sync::atomic::{AtomicBool, Ordering};
4
4
use std:: sync:: Arc ;
5
5
6
6
use parking_lot:: { Mutex , RwLock } ;
7
- use rusqlite:: { DatabaseName , ErrorCode , OpenFlags , StatementStatus } ;
7
+ use rusqlite:: { DatabaseName , ErrorCode , OpenFlags , StatementStatus , TransactionState } ;
8
8
use sqld_libsql_bindings:: wal_hook:: { TransparentMethods , WalMethodsHook } ;
9
9
use tokio:: sync:: { watch, Notify } ;
10
10
use tokio:: time:: { Duration , Instant } ;
@@ -144,7 +144,6 @@ where
144
144
}
145
145
}
146
146
147
- #[ derive( Clone ) ]
148
147
pub struct LibSqlConnection < W : WalHook > {
149
148
inner : Arc < Mutex < Connection < W > > > ,
150
149
}
@@ -160,6 +159,12 @@ impl<W: WalHook> std::fmt::Debug for LibSqlConnection<W> {
160
159
}
161
160
}
162
161
162
+ impl < W : WalHook > Clone for LibSqlConnection < W > {
163
+ fn clone ( & self ) -> Self {
164
+ Self { inner : self . inner . clone ( ) }
165
+ }
166
+ }
167
+
163
168
pub fn open_conn < W > (
164
169
path : & Path ,
165
170
wal_methods : & ' static WalMethodsHook < W > ,
@@ -219,6 +224,15 @@ where
219
224
inner : Arc :: new ( Mutex :: new ( conn) ) ,
220
225
} )
221
226
}
227
+
228
+ pub fn txn_status ( & self ) -> crate :: Result < TxnStatus > {
229
+ Ok ( self
230
+ . inner
231
+ . lock ( )
232
+ . conn
233
+ . transaction_state ( Some ( DatabaseName :: Main ) ) ?
234
+ . into ( ) )
235
+ }
222
236
}
223
237
224
238
struct Connection < W : WalHook = TransparentMethods > {
@@ -351,6 +365,16 @@ unsafe extern "C" fn busy_handler<W: WalHook>(state: *mut c_void, _retries: c_in
351
365
} )
352
366
}
353
367
368
+ impl From < TransactionState > for TxnStatus {
369
+ fn from ( value : TransactionState ) -> Self {
370
+ use TransactionState as Tx ;
371
+ match value {
372
+ Tx :: None => TxnStatus :: Init ,
373
+ Tx :: Read | Tx :: Write => TxnStatus :: Txn ,
374
+ _ => unreachable ! ( ) ,
375
+ }
376
+ }
377
+ }
354
378
impl < W : WalHook > Connection < W > {
355
379
fn new (
356
380
path : & Path ,
@@ -405,7 +429,7 @@ impl<W: WalHook> Connection<W> {
405
429
this : Arc < Mutex < Self > > ,
406
430
pgm : Program ,
407
431
mut builder : B ,
408
- ) -> Result < ( B , TxnStatus ) > {
432
+ ) -> Result < B > {
409
433
use rusqlite:: TransactionState as Tx ;
410
434
411
435
let state = this. lock ( ) . state . clone ( ) ;
@@ -469,23 +493,18 @@ impl<W: WalHook> Connection<W> {
469
493
results. push ( res) ;
470
494
}
471
495
472
- let status = if matches ! (
473
- this. lock( )
474
- . conn
475
- . transaction_state( Some ( DatabaseName :: Main ) ) ?,
476
- Tx :: Read | Tx :: Write
477
- ) {
478
- TxnStatus :: Txn
479
- } else {
480
- TxnStatus :: Init
481
- } ;
496
+ let status = this
497
+ . lock ( )
498
+ . conn
499
+ . transaction_state ( Some ( DatabaseName :: Main ) ) ?
500
+ . into ( ) ;
482
501
483
502
builder. finish (
484
503
* this. lock ( ) . current_frame_no_receiver . borrow_and_update ( ) ,
485
504
status,
486
505
) ?;
487
506
488
- Ok ( ( builder, status ) )
507
+ Ok ( builder)
489
508
}
490
509
491
510
fn execute_step (
@@ -736,7 +755,7 @@ where
736
755
auth : Authenticated ,
737
756
builder : B ,
738
757
_replication_index : Option < FrameNo > ,
739
- ) -> Result < ( B , TxnStatus ) > {
758
+ ) -> Result < B > {
740
759
check_program_auth ( auth, & pgm) ?;
741
760
let conn = self . inner . clone ( ) ;
742
761
tokio:: task:: spawn_blocking ( move || Connection :: run ( conn, pgm, builder) )
@@ -828,7 +847,7 @@ mod test {
828
847
fn test_libsql_conn_builder_driver ( ) {
829
848
test_driver ( 1000 , |b| {
830
849
let conn = setup_test_conn ( ) ;
831
- Connection :: run ( conn, Program :: seq ( & [ "select * from test" ] ) , b) . map ( |x| x . 0 )
850
+ Connection :: run ( conn, Program :: seq ( & [ "select * from test" ] ) , b)
832
851
} )
833
852
}
834
853
@@ -852,23 +871,23 @@ mod test {
852
871
853
872
tokio:: time:: pause ( ) ;
854
873
let conn = make_conn. make_connection ( ) . await . unwrap ( ) ;
855
- let ( _builder, state ) = Connection :: run (
874
+ let _builder = Connection :: run (
856
875
conn. inner . clone ( ) ,
857
876
Program :: seq ( & [ "BEGIN IMMEDIATE" ] ) ,
858
877
TestBuilder :: default ( ) ,
859
878
)
860
879
. unwrap ( ) ;
861
- assert_eq ! ( state , TxnStatus :: Txn ) ;
880
+ assert_eq ! ( conn . txn_status ( ) . unwrap ( ) , TxnStatus :: Txn ) ;
862
881
863
882
tokio:: time:: advance ( TXN_TIMEOUT * 2 ) . await ;
864
883
865
- let ( builder, state ) = Connection :: run (
884
+ let builder = Connection :: run (
866
885
conn. inner . clone ( ) ,
867
886
Program :: seq ( & [ "BEGIN IMMEDIATE" ] ) ,
868
887
TestBuilder :: default ( ) ,
869
888
)
870
889
. unwrap ( ) ;
871
- assert_eq ! ( state , TxnStatus :: Init ) ;
890
+ assert_eq ! ( conn . txn_status ( ) . unwrap ( ) , TxnStatus :: Init ) ;
872
891
assert ! ( matches!( builder. into_ret( ) [ 0 ] , Err ( Error :: LibSqlTxTimeout ) ) ) ;
873
892
}
874
893
@@ -896,13 +915,13 @@ mod test {
896
915
for _ in 0 ..10 {
897
916
let conn = make_conn. make_connection ( ) . await . unwrap ( ) ;
898
917
set. spawn_blocking ( move || {
899
- let ( builder, state ) = Connection :: run (
900
- conn. inner ,
918
+ let builder = Connection :: run (
919
+ conn. inner . clone ( ) ,
901
920
Program :: seq ( & [ "BEGIN IMMEDIATE" ] ) ,
902
921
TestBuilder :: default ( ) ,
903
922
)
904
923
. unwrap ( ) ;
905
- assert_eq ! ( state , TxnStatus :: Txn ) ;
924
+ assert_eq ! ( conn . txn_status ( ) . unwrap ( ) , TxnStatus :: Txn ) ;
906
925
assert ! ( builder. into_ret( ) [ 0 ] . is_ok( ) ) ;
907
926
} ) ;
908
927
}
@@ -937,15 +956,15 @@ mod test {
937
956
938
957
let conn1 = make_conn. make_connection ( ) . await . unwrap ( ) ;
939
958
tokio:: task:: spawn_blocking ( {
940
- let conn = conn1. inner . clone ( ) ;
959
+ let conn = conn1. clone ( ) ;
941
960
move || {
942
- let ( builder, state ) = Connection :: run (
943
- conn,
961
+ let builder = Connection :: run (
962
+ conn. inner . clone ( ) ,
944
963
Program :: seq ( & [ "BEGIN IMMEDIATE" ] ) ,
945
964
TestBuilder :: default ( ) ,
946
965
)
947
966
. unwrap ( ) ;
948
- assert_eq ! ( state , TxnStatus :: Txn ) ;
967
+ assert_eq ! ( conn . txn_status ( ) . unwrap ( ) , TxnStatus :: Txn ) ;
949
968
assert ! ( builder. into_ret( ) [ 0 ] . is_ok( ) ) ;
950
969
}
951
970
} )
@@ -954,16 +973,16 @@ mod test {
954
973
955
974
let conn2 = make_conn. make_connection ( ) . await . unwrap ( ) ;
956
975
let handle = tokio:: task:: spawn_blocking ( {
957
- let conn = conn2. inner . clone ( ) ;
976
+ let conn = conn2. clone ( ) ;
958
977
move || {
959
978
let before = Instant :: now ( ) ;
960
- let ( builder, state ) = Connection :: run (
961
- conn,
979
+ let builder = Connection :: run (
980
+ conn. inner . clone ( ) ,
962
981
Program :: seq ( & [ "BEGIN IMMEDIATE" ] ) ,
963
982
TestBuilder :: default ( ) ,
964
983
)
965
984
. unwrap ( ) ;
966
- assert_eq ! ( state , TxnStatus :: Txn ) ;
985
+ assert_eq ! ( conn . txn_status ( ) . unwrap ( ) , TxnStatus :: Txn ) ;
967
986
assert ! ( builder. into_ret( ) [ 0 ] . is_ok( ) ) ;
968
987
before. elapsed ( )
969
988
}
@@ -973,12 +992,12 @@ mod test {
973
992
tokio:: time:: sleep ( wait_time) . await ;
974
993
975
994
tokio:: task:: spawn_blocking ( {
976
- let conn = conn1. inner . clone ( ) ;
995
+ let conn = conn1. clone ( ) ;
977
996
move || {
978
- let ( builder, state ) =
979
- Connection :: run ( conn, Program :: seq ( & [ "COMMIT" ] ) , TestBuilder :: default ( ) )
997
+ let builder =
998
+ Connection :: run ( conn. inner . clone ( ) , Program :: seq ( & [ "COMMIT" ] ) , TestBuilder :: default ( ) )
980
999
. unwrap ( ) ;
981
- assert_eq ! ( state , TxnStatus :: Init ) ;
1000
+ assert_eq ! ( conn . txn_status ( ) . unwrap ( ) , TxnStatus :: Init ) ;
982
1001
assert ! ( builder. into_ret( ) [ 0 ] . is_ok( ) ) ;
983
1002
}
984
1003
} )
0 commit comments