diff --git a/src/database/connection.rs b/src/database/connection.rs index f70341d78..d3f7d9ae6 100644 --- a/src/database/connection.rs +++ b/src/database/connection.rs @@ -1,8 +1,10 @@ +use std::{future::Future, pin::Pin}; + +use futures_util::Stream; + use crate::{ DbBackend, DbErr, ExecResult, QueryResult, Statement, StatementBuilder, TransactionError, }; -use futures_util::Stream; -use std::{future::Future, pin::Pin}; /// The generic API for a database connection that can perform query or execute statements. /// It abstracts database connection and transaction @@ -125,6 +127,40 @@ impl std::fmt::Display for AccessMode { } } +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +/// Which kind of transaction to start. Only supported by SQLite. +/// +pub enum SqliteTransactionMode { + /// The default. Transaction starts when the next statement is executed, and + /// will be a read or write transaction depending on that statement. + Deferred, + /// Start a write transaction as soon as the BEGIN statement is received. + Immediate, + /// Start a write transaction as soon as the BEGIN statement is received. + /// When in non-WAL mode, also block all other transactions from reading the + /// database. + Exclusive, +} + +impl SqliteTransactionMode { + /// The keyword used to start a transaction in this mode (the word coming after "BEGIN"). + pub fn sqlite_keyword(&self) -> &'static str { + match self { + SqliteTransactionMode::Deferred => "DEFERRED", + SqliteTransactionMode::Immediate => "IMMEDIATE", + SqliteTransactionMode::Exclusive => "EXCLUSIVE", + } + } +} + +/// Configuration for starting a transaction +#[derive(Copy, Clone, Debug, PartialEq, Eq, Default)] +pub struct TransactionConfig { + pub isolation_level: Option, + pub access_mode: Option, + pub sqlite_transaction_mode: Option, +} + /// Spawn database transaction #[async_trait::async_trait] pub trait TransactionTrait { @@ -139,8 +175,7 @@ pub trait TransactionTrait { /// Returns a Transaction that can be committed or rolled back async fn begin_with_config( &self, - isolation_level: Option, - access_mode: Option, + config: TransactionConfig, ) -> Result; /// Execute the function inside a transaction. @@ -159,8 +194,7 @@ pub trait TransactionTrait { async fn transaction_with_config( &self, callback: F, - isolation_level: Option, - access_mode: Option, + config: TransactionConfig, ) -> Result> where F: for<'c> FnOnce( diff --git a/src/database/db_connection.rs b/src/database/db_connection.rs index 108111ad1..e3a87360c 100644 --- a/src/database/db_connection.rs +++ b/src/database/db_connection.rs @@ -1,20 +1,19 @@ -use crate::{ - AccessMode, ConnectionTrait, DatabaseTransaction, ExecResult, IsolationLevel, QueryResult, - Schema, SchemaBuilder, Statement, StatementBuilder, StreamTrait, TransactionError, - TransactionTrait, error::*, -}; +#[cfg(any(feature = "mock", feature = "proxy"))] +use std::sync::Arc; use std::{fmt::Debug, future::Future, pin::Pin}; -use tracing::instrument; -use url::Url; #[cfg(feature = "sqlx-dep")] use sqlx::pool::PoolConnection; +use tracing::instrument; +use url::Url; #[cfg(feature = "rusqlite")] use crate::driver::rusqlite::{RusqliteInnerConnection, RusqliteSharedConnection}; - -#[cfg(any(feature = "mock", feature = "proxy"))] -use std::sync::Arc; +use crate::{ + ConnectionTrait, DatabaseTransaction, ExecResult, QueryResult, Schema, SchemaBuilder, + Statement, StatementBuilder, StreamTrait, TransactionConfig, TransactionError, + TransactionTrait, error::*, +}; /// Handle a database connection depending on the backend enabled by the feature /// flags. This creates a connection pool internally (for SQLx connections), @@ -351,15 +350,21 @@ impl TransactionTrait for DatabaseConnection { async fn begin(&self) -> Result { match &self.inner { #[cfg(feature = "sqlx-mysql")] - DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => conn.begin(None, None).await, + DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => { + conn.begin(TransactionConfig::default()).await + } #[cfg(feature = "sqlx-postgres")] DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => { - conn.begin(None, None).await + conn.begin(TransactionConfig::default()).await } #[cfg(feature = "sqlx-sqlite")] - DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => conn.begin(None, None).await, + DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => { + conn.begin(TransactionConfig::default()).await + } #[cfg(feature = "rusqlite")] - DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.begin(None, None), + DatabaseConnectionType::RusqliteSharedConnection(conn) => { + conn.begin(TransactionConfig::default()) + } #[cfg(feature = "mock")] DatabaseConnectionType::MockDatabaseConnection(conn) => { DatabaseTransaction::new_mock(Arc::clone(conn), None).await @@ -375,26 +380,17 @@ impl TransactionTrait for DatabaseConnection { #[instrument(level = "trace")] async fn begin_with_config( &self, - _isolation_level: Option, - _access_mode: Option, + _config: TransactionConfig, ) -> Result { match &self.inner { #[cfg(feature = "sqlx-mysql")] - DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => { - conn.begin(_isolation_level, _access_mode).await - } + DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => conn.begin(_config).await, #[cfg(feature = "sqlx-postgres")] - DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => { - conn.begin(_isolation_level, _access_mode).await - } + DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => conn.begin(_config).await, #[cfg(feature = "sqlx-sqlite")] - DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => { - conn.begin(_isolation_level, _access_mode).await - } + DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => conn.begin(_config).await, #[cfg(feature = "rusqlite")] - DatabaseConnectionType::RusqliteSharedConnection(conn) => { - conn.begin(_isolation_level, _access_mode) - } + DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.begin(_config), #[cfg(feature = "mock")] DatabaseConnectionType::MockDatabaseConnection(conn) => { DatabaseTransaction::new_mock(Arc::clone(conn), None).await @@ -408,7 +404,8 @@ impl TransactionTrait for DatabaseConnection { } /// Execute the function inside a transaction. - /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed. + /// If the function returns an error, the transaction will be rolled back. + /// If it does not return an error, the transaction will be committed. #[instrument(level = "trace", skip(_callback))] async fn transaction(&self, _callback: F) -> Result> where @@ -422,19 +419,22 @@ impl TransactionTrait for DatabaseConnection { match &self.inner { #[cfg(feature = "sqlx-mysql")] DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => { - conn.transaction(_callback, None, None).await + conn.transaction(_callback, TransactionConfig::default()) + .await } #[cfg(feature = "sqlx-postgres")] DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => { - conn.transaction(_callback, None, None).await + conn.transaction(_callback, TransactionConfig::default()) + .await } #[cfg(feature = "sqlx-sqlite")] DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => { - conn.transaction(_callback, None, None).await + conn.transaction(_callback, TransactionConfig::default()) + .await } #[cfg(feature = "rusqlite")] DatabaseConnectionType::RusqliteSharedConnection(conn) => { - conn.transaction(_callback, None, None) + conn.transaction(_callback, TransactionConfig::default()) } #[cfg(feature = "mock")] DatabaseConnectionType::MockDatabaseConnection(conn) => { @@ -455,13 +455,13 @@ impl TransactionTrait for DatabaseConnection { } /// Execute the function inside a transaction. - /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed. + /// If the function returns an error, the transaction will be rolled back. + /// If it does not return an error, the transaction will be committed. #[instrument(level = "trace", skip(_callback))] async fn transaction_with_config( &self, _callback: F, - _isolation_level: Option, - _access_mode: Option, + _config: TransactionConfig, ) -> Result> where F: for<'c> FnOnce( @@ -474,22 +474,19 @@ impl TransactionTrait for DatabaseConnection { match &self.inner { #[cfg(feature = "sqlx-mysql")] DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => { - conn.transaction(_callback, _isolation_level, _access_mode) - .await + conn.transaction(_callback, _config).await } #[cfg(feature = "sqlx-postgres")] DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => { - conn.transaction(_callback, _isolation_level, _access_mode) - .await + conn.transaction(_callback, _config).await } #[cfg(feature = "sqlx-sqlite")] DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => { - conn.transaction(_callback, _isolation_level, _access_mode) - .await + conn.transaction(_callback, _config).await } #[cfg(feature = "rusqlite")] DatabaseConnectionType::RusqliteSharedConnection(conn) => { - conn.transaction(_callback, _isolation_level, _access_mode) + conn.transaction(_callback, _config) } #[cfg(feature = "mock")] DatabaseConnectionType::MockDatabaseConnection(conn) => { @@ -556,8 +553,8 @@ impl DatabaseConnection { #[cfg(feature = "rbac")] impl DatabaseConnection { - /// Load RBAC data from the same database as this connection and setup RBAC engine. - /// If the RBAC engine already exists, it will be replaced. + /// Load RBAC data from the same database as this connection and setup RBAC + /// engine. If the RBAC engine already exists, it will be replaced. pub async fn load_rbac(&self) -> Result<(), DbErr> { self.load_rbac_from(self).await } @@ -575,7 +572,8 @@ impl DatabaseConnection { self.rbac.replace(engine); } - /// Create a restricted connection with access control specific for the user. + /// Create a restricted connection with access control specific for the + /// user. pub fn restricted_for( &self, user_id: crate::rbac::RbacUserId, diff --git a/src/database/executor.rs b/src/database/executor.rs index ec9135c01..a841336c7 100644 --- a/src/database/executor.rs +++ b/src/database/executor.rs @@ -1,10 +1,11 @@ +use std::future::Future; +use std::pin::Pin; + use crate::{ - AccessMode, ConnectionTrait, DatabaseConnection, DatabaseTransaction, DbBackend, DbErr, - ExecResult, IsolationLevel, QueryResult, Statement, TransactionError, TransactionTrait, + ConnectionTrait, DatabaseConnection, DatabaseTransaction, DbBackend, DbErr, ExecResult, + QueryResult, Statement, TransactionConfig, TransactionError, TransactionTrait, }; use crate::{Schema, SchemaBuilder}; -use std::future::Future; -use std::pin::Pin; /// A wrapper that holds either a reference to a [`DatabaseConnection`] or [`DatabaseTransaction`]. #[derive(Debug)] @@ -78,16 +79,11 @@ impl TransactionTrait for DatabaseExecutor<'_> { async fn begin_with_config( &self, - isolation_level: Option, - access_mode: Option, + config: TransactionConfig, ) -> Result { match self { - DatabaseExecutor::Connection(conn) => { - conn.begin_with_config(isolation_level, access_mode).await - } - DatabaseExecutor::Transaction(trans) => { - trans.begin_with_config(isolation_level, access_mode).await - } + DatabaseExecutor::Connection(conn) => conn.begin_with_config(config).await, + DatabaseExecutor::Transaction(trans) => trans.begin_with_config(config).await, } } @@ -109,8 +105,7 @@ impl TransactionTrait for DatabaseExecutor<'_> { async fn transaction_with_config( &self, callback: F, - isolation_level: Option, - access_mode: Option, + config: TransactionConfig, ) -> Result> where F: for<'c> FnOnce( @@ -122,13 +117,10 @@ impl TransactionTrait for DatabaseExecutor<'_> { { match self { DatabaseExecutor::Connection(conn) => { - conn.transaction_with_config(callback, isolation_level, access_mode) - .await + conn.transaction_with_config(callback, config).await } DatabaseExecutor::Transaction(trans) => { - trans - .transaction_with_config(callback, isolation_level, access_mode) - .await + trans.transaction_with_config(callback, config).await } } } diff --git a/src/database/restricted_connection.rs b/src/database/restricted_connection.rs index 441d14e03..44f45928e 100644 --- a/src/database/restricted_connection.rs +++ b/src/database/restricted_connection.rs @@ -1,19 +1,24 @@ -use crate::rbac::{ - PermissionRequest, RbacEngine, RbacError, RbacPermissionsByResources, - RbacResourcesAndPermissions, RbacRoleHierarchyList, RbacRolesAndRanks, RbacUserRolePermissions, - ResourceRequest, - entity::{role::RoleId, user::UserId}, +use std::{ + pin::Pin, + sync::{Arc, RwLock}, }; + +use tracing::instrument; + use crate::{ AccessMode, ConnectionTrait, DatabaseConnection, DatabaseTransaction, DbBackend, DbErr, ExecResult, IsolationLevel, QueryResult, Statement, StatementBuilder, TransactionError, TransactionSession, TransactionTrait, }; -use std::{ - pin::Pin, - sync::{Arc, RwLock}, +use crate::{ + TransactionConfig, + rbac::{ + PermissionRequest, RbacEngine, RbacError, RbacPermissionsByResources, + RbacResourcesAndPermissions, RbacRoleHierarchyList, RbacRolesAndRanks, + RbacUserRolePermissions, ResourceRequest, + entity::{role::RoleId, user::UserId}, + }, }; -use tracing::instrument; /// Wrapper of [`DatabaseConnection`] that performs authorization on all executed /// queries for the current user. Note that raw SQL [`Statement`] is not allowed @@ -223,15 +228,11 @@ impl TransactionTrait for RestrictedConnection { #[instrument(level = "trace")] async fn begin_with_config( &self, - isolation_level: Option, - access_mode: Option, + config: TransactionConfig, ) -> Result { Ok(RestrictedTransaction { user_id: self.user_id, - conn: self - .conn - .begin_with_config(isolation_level, access_mode) - .await?, + conn: self.conn.begin_with_config(config).await?, rbac: self.conn.rbac.clone(), }) } @@ -258,8 +259,7 @@ impl TransactionTrait for RestrictedConnection { async fn transaction_with_config( &self, callback: F, - isolation_level: Option, - access_mode: Option, + config: TransactionConfig, ) -> Result> where F: for<'c> FnOnce( @@ -270,7 +270,7 @@ impl TransactionTrait for RestrictedConnection { E: std::fmt::Display + std::fmt::Debug + Send, { let transaction = self - .begin_with_config(isolation_level, access_mode) + .begin_with_config(config) .await .map_err(TransactionError::Connection)?; transaction.run(callback).await @@ -293,15 +293,11 @@ impl TransactionTrait for RestrictedTransaction { #[instrument(level = "trace")] async fn begin_with_config( &self, - isolation_level: Option, - access_mode: Option, + config: TransactionConfig, ) -> Result { Ok(RestrictedTransaction { user_id: self.user_id, - conn: self - .conn - .begin_with_config(isolation_level, access_mode) - .await?, + conn: self.conn.begin_with_config(config).await?, rbac: self.rbac.clone(), }) } @@ -328,8 +324,7 @@ impl TransactionTrait for RestrictedTransaction { async fn transaction_with_config( &self, callback: F, - isolation_level: Option, - access_mode: Option, + config: TransactionConfig, ) -> Result> where F: for<'c> FnOnce( @@ -340,7 +335,7 @@ impl TransactionTrait for RestrictedTransaction { E: std::fmt::Display + std::fmt::Debug + Send, { let transaction = self - .begin_with_config(isolation_level, access_mode) + .begin_with_config(config) .await .map_err(TransactionError::Connection)?; transaction.run(callback).await diff --git a/src/database/transaction.rs b/src/database/transaction.rs index e16f158bc..854f70db4 100644 --- a/src/database/transaction.rs +++ b/src/database/transaction.rs @@ -1,7 +1,7 @@ #![allow(unused_assignments)] use crate::{ - AccessMode, ConnectionTrait, DbBackend, DbErr, ExecResult, InnerConnection, IsolationLevel, - QueryResult, Statement, StreamTrait, TransactionSession, TransactionStream, TransactionTrait, + ConnectionTrait, DbBackend, DbErr, ExecResult, InnerConnection, QueryResult, Statement, + StreamTrait, TransactionConfig, TransactionSession, TransactionStream, TransactionTrait, debug_print, error::*, }; #[cfg(feature = "sqlx-dep")] @@ -35,9 +35,13 @@ impl DatabaseTransaction { conn: Arc>, backend: DbBackend, metric_callback: Option, - isolation_level: Option, - access_mode: Option, + config: TransactionConfig, ) -> Result { + let TransactionConfig { + isolation_level, + access_mode, + sqlite_transaction_mode, + } = config; let res = DatabaseTransaction { conn, backend, @@ -92,9 +96,16 @@ impl DatabaseTransaction { access_mode, ) .await?; - ::TransactionManager::begin(c, None) - .await - .map_err(sqlx_error_to_query_err) + // TODO using this for beginning a nested transaction currently causes an error. Should we make it a warning instead? + let statement = config.sqlite_transaction_mode.map(|mode| { + std::borrow::Cow::from(format!("BEGIN {}", mode.sqlite_keyword())) + }); + ::TransactionManager::begin( + c, + statement.into(), + ) + .await + .map_err(sqlx_error_to_query_err) } #[cfg(feature = "rusqlite")] InnerConnection::Rusqlite(c) => c.begin(), @@ -603,8 +614,7 @@ impl TransactionTrait for DatabaseTransaction { Arc::clone(&self.conn), self.backend, self.metric_callback.clone(), - None, - None, + TransactionConfig::default(), ) .await } @@ -612,15 +622,13 @@ impl TransactionTrait for DatabaseTransaction { #[instrument(level = "trace")] async fn begin_with_config( &self, - isolation_level: Option, - access_mode: Option, + config: TransactionConfig, ) -> Result { DatabaseTransaction::begin( Arc::clone(&self.conn), self.backend, self.metric_callback.clone(), - isolation_level, - access_mode, + config, ) .await } @@ -649,8 +657,7 @@ impl TransactionTrait for DatabaseTransaction { async fn transaction_with_config( &self, _callback: F, - isolation_level: Option, - access_mode: Option, + config: TransactionConfig, ) -> Result> where F: for<'c> FnOnce( @@ -661,7 +668,7 @@ impl TransactionTrait for DatabaseTransaction { E: std::fmt::Display + std::fmt::Debug + Send, { let transaction = self - .begin_with_config(isolation_level, access_mode) + .begin_with_config(config) .await .map_err(TransactionError::Connection)?; transaction.run(_callback).await diff --git a/src/driver/mock.rs b/src/driver/mock.rs index 91bceb8e4..f8b217eac 100644 --- a/src/driver/mock.rs +++ b/src/driver/mock.rs @@ -1,8 +1,3 @@ -use crate::{ - DatabaseConnection, DatabaseConnectionType, DbBackend, ExecResult, MockDatabase, QueryResult, - Statement, Transaction, debug_print, error::*, -}; -use futures_util::Stream; use std::{ fmt::Debug, pin::Pin, @@ -11,8 +6,15 @@ use std::{ atomic::{AtomicUsize, Ordering}, }, }; + +use futures_util::Stream; use tracing::instrument; +use crate::{ + DatabaseConnection, DatabaseConnectionType, DbBackend, ExecResult, MockDatabase, QueryResult, + Statement, Transaction, TransactionConfig, debug_print, error::*, +}; + #[cfg(not(feature = "sync"))] type PinBoxStream = Pin> + Send>>; #[cfg(feature = "sync")] @@ -271,8 +273,7 @@ impl crate::DatabaseTransaction { Arc::new(Mutex::new(crate::InnerConnection::Mock(inner))), backend, metric_callback, - None, - None, + TransactionConfig::default(), ) .await } diff --git a/src/driver/rusqlite.rs b/src/driver/rusqlite.rs index 740deca5e..d4bd6b28e 100644 --- a/src/driver/rusqlite.rs +++ b/src/driver/rusqlite.rs @@ -3,7 +3,6 @@ use std::{ sync::{Arc, Mutex, MutexGuard, TryLockError}, time::{Duration, Instant}, }; -use tracing::{debug, instrument, warn}; pub use OwnedRow as RusqliteRow; use rusqlite::{ @@ -14,11 +13,12 @@ pub use rusqlite::{ Connection as RusqliteConnection, Error as RusqliteError, types::Value as RusqliteOwnedValue, }; use sea_query_rusqlite::{RusqliteValue, RusqliteValues, rusqlite}; +use tracing::{debug, instrument, warn}; use crate::{ - AccessMode, ColIdx, ConnectOptions, DatabaseConnection, DatabaseConnectionType, - DatabaseTransaction, InnerConnection, IsolationLevel, QueryStream, Statement, TransactionError, - error::*, executor::*, + ColIdx, ConnectOptions, DatabaseConnection, DatabaseConnectionType, DatabaseTransaction, + InnerConnection, QueryStream, Statement, TransactionConfig, TransactionError, error::*, + executor::*, }; /// A helper class to connect to Rusqlite @@ -329,18 +329,13 @@ impl RusqliteSharedConnection { /// Bundle a set of SQL statements that execute together. #[instrument(level = "trace")] - pub fn begin( - &self, - isolation_level: Option, - access_mode: Option, - ) -> Result { + pub fn begin(&self, config: TransactionConfig) -> Result { let conn = self.loan()?; DatabaseTransaction::begin( Arc::new(Mutex::new(InnerConnection::Rusqlite(conn))), crate::DbBackend::Sqlite, self.metric_callback.clone(), - isolation_level, - access_mode, + config, ) } @@ -349,14 +344,13 @@ impl RusqliteSharedConnection { pub fn transaction( &self, callback: F, - isolation_level: Option, - access_mode: Option, + config: TransactionConfig, ) -> Result> where F: for<'b> FnOnce(&'b DatabaseTransaction) -> Result, E: std::fmt::Display + std::fmt::Debug, { - self.begin(isolation_level, access_mode) + self.begin(config) .map_err(|e| TransactionError::Connection(e))? .run(callback) } diff --git a/src/driver/sqlx_mysql.rs b/src/driver/sqlx_mysql.rs index 62e7e5aab..5d0ca7857 100644 --- a/src/driver/sqlx_mysql.rs +++ b/src/driver/sqlx_mysql.rs @@ -1,25 +1,23 @@ +use std::{future::Future, pin::Pin, sync::Arc}; + use futures_util::lock::Mutex; use log::LevelFilter; use sea_query::Values; -use std::{future::Future, pin::Pin, sync::Arc}; - +use sea_query_sqlx::SqlxValues; use sqlx::{ Connection, Executor, MySql, MySqlPool, mysql::{MySqlConnectOptions, MySqlQueryResult, MySqlRow}, pool::PoolConnection, }; - -use sea_query_sqlx::SqlxValues; use tracing::instrument; +use super::sqlx_common::*; use crate::{ AccessMode, ConnectOptions, DatabaseConnection, DatabaseConnectionType, DatabaseTransaction, - DbBackend, IsolationLevel, QueryStream, Statement, TransactionError, debug_print, error::*, - executor::*, + DbBackend, IsolationLevel, QueryStream, Statement, TransactionConfig, TransactionError, + debug_print, error::*, executor::*, }; -use super::sqlx_common::*; - /// Defines the [sqlx::mysql] connector #[derive(Debug)] pub struct SqlxMySqlConnector; @@ -196,19 +194,9 @@ impl SqlxMySqlPoolConnection { /// Bundle a set of SQL statements that execute together. #[instrument(level = "trace")] - pub async fn begin( - &self, - isolation_level: Option, - access_mode: Option, - ) -> Result { + pub async fn begin(&self, config: TransactionConfig) -> Result { let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?; - DatabaseTransaction::new_mysql( - conn, - self.metric_callback.clone(), - isolation_level, - access_mode, - ) - .await + DatabaseTransaction::new_mysql(conn, self.metric_callback.clone(), config).await } /// Create a MySQL transaction @@ -216,8 +204,7 @@ impl SqlxMySqlPoolConnection { pub async fn transaction( &self, callback: F, - isolation_level: Option, - access_mode: Option, + config: TransactionConfig, ) -> Result> where F: for<'b> FnOnce( @@ -228,14 +215,10 @@ impl SqlxMySqlPoolConnection { E: std::fmt::Display + std::fmt::Debug + Send, { let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?; - let transaction = DatabaseTransaction::new_mysql( - conn, - self.metric_callback.clone(), - isolation_level, - access_mode, - ) - .await - .map_err(|e| TransactionError::Connection(e))?; + let transaction = + DatabaseTransaction::new_mysql(conn, self.metric_callback.clone(), config) + .await + .map_err(|e| TransactionError::Connection(e))?; transaction.run(callback).await } @@ -341,15 +324,13 @@ impl crate::DatabaseTransaction { pub(crate) async fn new_mysql( inner: PoolConnection, metric_callback: Option, - isolation_level: Option, - access_mode: Option, + config: TransactionConfig, ) -> Result { Self::begin( Arc::new(Mutex::new(crate::InnerConnection::MySql(inner))), crate::DbBackend::MySql, metric_callback, - isolation_level, - access_mode, + config, ) .await } diff --git a/src/driver/sqlx_postgres.rs b/src/driver/sqlx_postgres.rs index 142dd37f7..ede2b1d9d 100644 --- a/src/driver/sqlx_postgres.rs +++ b/src/driver/sqlx_postgres.rs @@ -1,24 +1,23 @@ +use std::{fmt::Write, future::Future, pin::Pin, sync::Arc}; + use futures_util::lock::Mutex; use log::LevelFilter; use sea_query::Values; -use std::{fmt::Write, future::Future, pin::Pin, sync::Arc}; - +use sea_query_sqlx::SqlxValues; use sqlx::{ Connection, Executor, PgPool, Postgres, pool::PoolConnection, postgres::{PgConnectOptions, PgQueryResult, PgRow}, }; - -use sea_query_sqlx::SqlxValues; use tracing::instrument; +use super::sqlx_common::*; use crate::{ AccessMode, ConnectOptions, DatabaseConnection, DatabaseConnectionType, DatabaseTransaction, - IsolationLevel, QueryStream, Statement, TransactionError, debug_print, error::*, executor::*, + IsolationLevel, QueryStream, Statement, TransactionConfig, TransactionError, debug_print, + error::*, executor::*, }; -use super::sqlx_common::*; - /// Defines the [sqlx::postgres] connector #[derive(Debug)] pub struct SqlxPostgresConnector; @@ -226,19 +225,9 @@ impl SqlxPostgresPoolConnection { /// Bundle a set of SQL statements that execute together. #[instrument(level = "trace")] - pub async fn begin( - &self, - isolation_level: Option, - access_mode: Option, - ) -> Result { + pub async fn begin(&self, config: TransactionConfig) -> Result { let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?; - DatabaseTransaction::new_postgres( - conn, - self.metric_callback.clone(), - isolation_level, - access_mode, - ) - .await + DatabaseTransaction::new_postgres(conn, self.metric_callback.clone(), config).await } /// Create a PostgreSQL transaction @@ -246,8 +235,7 @@ impl SqlxPostgresPoolConnection { pub async fn transaction( &self, callback: F, - isolation_level: Option, - access_mode: Option, + config: TransactionConfig, ) -> Result> where F: for<'b> FnOnce( @@ -258,14 +246,10 @@ impl SqlxPostgresPoolConnection { E: std::fmt::Display + std::fmt::Debug + Send, { let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?; - let transaction = DatabaseTransaction::new_postgres( - conn, - self.metric_callback.clone(), - isolation_level, - access_mode, - ) - .await - .map_err(|e| TransactionError::Connection(e))?; + let transaction = + DatabaseTransaction::new_postgres(conn, self.metric_callback.clone(), config) + .await + .map_err(|e| TransactionError::Connection(e))?; transaction.run(callback).await } @@ -373,15 +357,13 @@ impl crate::DatabaseTransaction { pub(crate) async fn new_postgres( inner: PoolConnection, metric_callback: Option, - isolation_level: Option, - access_mode: Option, + config: TransactionConfig, ) -> Result { Self::begin( Arc::new(Mutex::new(crate::InnerConnection::Postgres(inner))), crate::DbBackend::Postgres, metric_callback, - isolation_level, - access_mode, + config, ) .await } diff --git a/src/driver/sqlx_sqlite.rs b/src/driver/sqlx_sqlite.rs index 42b85582e..5ca5f282f 100644 --- a/src/driver/sqlx_sqlite.rs +++ b/src/driver/sqlx_sqlite.rs @@ -1,25 +1,23 @@ +use std::{future::Future, pin::Pin, sync::Arc}; + use futures_util::lock::Mutex; use log::LevelFilter; use sea_query::Values; -use std::{future::Future, pin::Pin, sync::Arc}; - +use sea_query_sqlx::SqlxValues; use sqlx::{ Connection, Executor, Sqlite, SqlitePool, pool::PoolConnection, sqlite::{SqliteConnectOptions, SqliteQueryResult, SqliteRow}, }; - -use sea_query_sqlx::SqlxValues; use tracing::{instrument, warn}; +use super::sqlx_common::*; use crate::{ AccessMode, ConnectOptions, DatabaseConnection, DatabaseConnectionType, DatabaseTransaction, - IsolationLevel, QueryStream, Statement, TransactionError, debug_print, error::*, executor::*, - sqlx_error_to_exec_err, + IsolationLevel, QueryStream, Statement, TransactionConfig, TransactionError, debug_print, + error::*, executor::*, sqlx_error_to_exec_err, }; -use super::sqlx_common::*; - /// Defines the [sqlx::sqlite] connector #[derive(Debug)] pub struct SqlxSqliteConnector; @@ -163,7 +161,8 @@ impl SqlxSqlitePoolConnection { } } - /// Get one result from a SQL query. Returns [Option::None] if no match was found + /// Get one result from a SQL query. Returns [Option::None] if no match was + /// found #[instrument(level = "trace")] pub async fn query_one(&self, stmt: Statement) -> Result, DbErr> { debug_print!("{}", stmt); @@ -211,19 +210,9 @@ impl SqlxSqlitePoolConnection { /// Bundle a set of SQL statements that execute together. #[instrument(level = "trace")] - pub async fn begin( - &self, - isolation_level: Option, - access_mode: Option, - ) -> Result { + pub async fn begin(&self, config: TransactionConfig) -> Result { let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?; - DatabaseTransaction::new_sqlite( - conn, - self.metric_callback.clone(), - isolation_level, - access_mode, - ) - .await + DatabaseTransaction::new_sqlite(conn, self.metric_callback.clone(), config).await } /// Create a SQLite transaction @@ -231,8 +220,7 @@ impl SqlxSqlitePoolConnection { pub async fn transaction( &self, callback: F, - isolation_level: Option, - access_mode: Option, + config: TransactionConfig, ) -> Result> where F: for<'b> FnOnce( @@ -243,14 +231,10 @@ impl SqlxSqlitePoolConnection { E: std::fmt::Display + std::fmt::Debug + Send, { let conn = self.pool.acquire().await.map_err(sqlx_conn_acquire_err)?; - let transaction = DatabaseTransaction::new_sqlite( - conn, - self.metric_callback.clone(), - isolation_level, - access_mode, - ) - .await - .map_err(|e| TransactionError::Connection(e))?; + let transaction = + DatabaseTransaction::new_sqlite(conn, self.metric_callback.clone(), config) + .await + .map_err(|e| TransactionError::Connection(e))?; transaction.run(callback).await } @@ -360,15 +344,13 @@ impl crate::DatabaseTransaction { pub(crate) async fn new_sqlite( inner: PoolConnection, metric_callback: Option, - isolation_level: Option, - access_mode: Option, + config: TransactionConfig, ) -> Result { Self::begin( Arc::new(Mutex::new(crate::InnerConnection::Sqlite(inner))), crate::DbBackend::Sqlite, metric_callback, - isolation_level, - access_mode, + config, ) .await } diff --git a/src/rbac/context.rs b/src/rbac/context.rs index 690bcf16d..39fe0cf35 100644 --- a/src/rbac/context.rs +++ b/src/rbac/context.rs @@ -1,3 +1,5 @@ +use std::collections::HashMap; + use super::{ AccessType, RbacError, RbacUserId, entity::{ @@ -11,10 +13,9 @@ use super::{ }, }; use crate::{ - AccessMode, EntityTrait, IsolationLevel, Set, TransactionSession, TransactionTrait, - error::DbErr, sea_query::OnConflict, + AccessMode, EntityTrait, IsolationLevel, Set, TransactionConfig, TransactionSession, + TransactionTrait, error::DbErr, sea_query::OnConflict, }; -use std::collections::HashMap; /// Helper class for manipulation of RBAC tables #[derive(Debug)] @@ -43,10 +44,11 @@ impl RbacContext { pub async fn load(db: &C) -> Result { // ensure snapshot is consistent across all tables let txn = &db - .begin_with_config( - Some(IsolationLevel::ReadCommitted), - Some(AccessMode::ReadOnly), - ) + .begin_with_config(TransactionConfig { + isolation_level: Some(IsolationLevel::ReadCommitted), + access_mode: Some(AccessMode::ReadOnly), + sqlite_transaction_mode: None, + }) .await?; let tables = resource::Entity::find() diff --git a/src/rbac/engine/loader.rs b/src/rbac/engine/loader.rs index 6b91b80d0..6bd947cfb 100644 --- a/src/rbac/engine/loader.rs +++ b/src/rbac/engine/loader.rs @@ -4,16 +4,19 @@ use super::super::entity::{ user_override::Entity as UserOverride, user_role::Entity as UserRole, }; use super::{RbacEngine, RbacSnapshot}; -use crate::{AccessMode, DbConn, DbErr, EntityTrait, IsolationLevel, TransactionTrait}; +use crate::{ + AccessMode, DbConn, DbErr, EntityTrait, IsolationLevel, TransactionConfig, TransactionTrait, +}; impl RbacEngine { pub async fn load_from(db: &DbConn) -> Result { // ensure snapshot is consistent across all tables let txn = &db - .begin_with_config( - Some(IsolationLevel::ReadCommitted), - Some(AccessMode::ReadOnly), - ) + .begin_with_config(TransactionConfig { + isolation_level: Some(IsolationLevel::ReadCommitted), + access_mode: Some(AccessMode::ReadOnly), + sqlite_transaction_mode: None, + }) .await?; let resources = Resource::find().all(txn).await?;