Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 40 additions & 6 deletions src/database/connection.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
use std::{future::Future, pin::Pin};

use futures_util::Stream;

use crate::{
DbBackend, DbErr, ExecResult, QueryResult, Statement, StatementBuilder, TransactionError,
};
use futures_util::Stream;
use std::{future::Future, pin::Pin};

/// The generic API for a database connection that can perform query or execute statements.
/// It abstracts database connection and transaction
Expand Down Expand Up @@ -125,6 +127,40 @@ impl std::fmt::Display for AccessMode {
}
}

#[derive(Copy, Clone, Debug, PartialEq, Eq)]
/// Which kind of transaction to start. Only supported by SQLite.
/// <https://www.sqlite.org/lang_transaction.html>
pub enum SqliteTransactionMode {
/// The default. Transaction starts when the next statement is executed, and
/// will be a read or write transaction depending on that statement.
Deferred,
/// Start a write transaction as soon as the BEGIN statement is received.
Immediate,
/// Start a write transaction as soon as the BEGIN statement is received.
/// When in non-WAL mode, also block all other transactions from reading the
/// database.
Exclusive,
}

impl SqliteTransactionMode {
/// The keyword used to start a transaction in this mode (the word coming after "BEGIN").
pub fn sqlite_keyword(&self) -> &'static str {
match self {
SqliteTransactionMode::Deferred => "DEFERRED",
SqliteTransactionMode::Immediate => "IMMEDIATE",
SqliteTransactionMode::Exclusive => "EXCLUSIVE",
}
}
}

/// Configuration for starting a transaction
#[derive(Copy, Clone, Debug, PartialEq, Eq, Default)]
pub struct TransactionConfig {
pub isolation_level: Option<IsolationLevel>,
pub access_mode: Option<AccessMode>,
pub sqlite_transaction_mode: Option<SqliteTransactionMode>,
}

/// Spawn database transaction
#[async_trait::async_trait]
pub trait TransactionTrait {
Expand All @@ -139,8 +175,7 @@ pub trait TransactionTrait {
/// Returns a Transaction that can be committed or rolled back
async fn begin_with_config(
&self,
isolation_level: Option<IsolationLevel>,
access_mode: Option<AccessMode>,
config: TransactionConfig,
) -> Result<Self::Transaction, DbErr>;

/// Execute the function inside a transaction.
Expand All @@ -159,8 +194,7 @@ pub trait TransactionTrait {
async fn transaction_with_config<F, T, E>(
&self,
callback: F,
isolation_level: Option<IsolationLevel>,
access_mode: Option<AccessMode>,
config: TransactionConfig,
) -> Result<T, TransactionError<E>>
where
F: for<'c> FnOnce(
Expand Down
90 changes: 44 additions & 46 deletions src/database/db_connection.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,19 @@
use crate::{
AccessMode, ConnectionTrait, DatabaseTransaction, ExecResult, IsolationLevel, QueryResult,
Schema, SchemaBuilder, Statement, StatementBuilder, StreamTrait, TransactionError,
TransactionTrait, error::*,
};
#[cfg(any(feature = "mock", feature = "proxy"))]
use std::sync::Arc;
use std::{fmt::Debug, future::Future, pin::Pin};
use tracing::instrument;
use url::Url;

#[cfg(feature = "sqlx-dep")]
use sqlx::pool::PoolConnection;
use tracing::instrument;
use url::Url;

#[cfg(feature = "rusqlite")]
use crate::driver::rusqlite::{RusqliteInnerConnection, RusqliteSharedConnection};

#[cfg(any(feature = "mock", feature = "proxy"))]
use std::sync::Arc;
use crate::{
ConnectionTrait, DatabaseTransaction, ExecResult, QueryResult, Schema, SchemaBuilder,
Statement, StatementBuilder, StreamTrait, TransactionConfig, TransactionError,
TransactionTrait, error::*,
};

/// Handle a database connection depending on the backend enabled by the feature
/// flags. This creates a connection pool internally (for SQLx connections),
Expand Down Expand Up @@ -351,15 +350,21 @@ impl TransactionTrait for DatabaseConnection {
async fn begin(&self) -> Result<DatabaseTransaction, DbErr> {
match &self.inner {
#[cfg(feature = "sqlx-mysql")]
DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => conn.begin(None, None).await,
DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => {
conn.begin(TransactionConfig::default()).await
}
#[cfg(feature = "sqlx-postgres")]
DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
conn.begin(None, None).await
conn.begin(TransactionConfig::default()).await
}
#[cfg(feature = "sqlx-sqlite")]
DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => conn.begin(None, None).await,
DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
conn.begin(TransactionConfig::default()).await
}
#[cfg(feature = "rusqlite")]
DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.begin(None, None),
DatabaseConnectionType::RusqliteSharedConnection(conn) => {
conn.begin(TransactionConfig::default())
}
#[cfg(feature = "mock")]
DatabaseConnectionType::MockDatabaseConnection(conn) => {
DatabaseTransaction::new_mock(Arc::clone(conn), None).await
Expand All @@ -375,26 +380,17 @@ impl TransactionTrait for DatabaseConnection {
#[instrument(level = "trace")]
async fn begin_with_config(
&self,
_isolation_level: Option<IsolationLevel>,
_access_mode: Option<AccessMode>,
_config: TransactionConfig,
) -> Result<DatabaseTransaction, DbErr> {
match &self.inner {
#[cfg(feature = "sqlx-mysql")]
DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => {
conn.begin(_isolation_level, _access_mode).await
}
DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => conn.begin(_config).await,
#[cfg(feature = "sqlx-postgres")]
DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
conn.begin(_isolation_level, _access_mode).await
}
DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => conn.begin(_config).await,
#[cfg(feature = "sqlx-sqlite")]
DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
conn.begin(_isolation_level, _access_mode).await
}
DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => conn.begin(_config).await,
#[cfg(feature = "rusqlite")]
DatabaseConnectionType::RusqliteSharedConnection(conn) => {
conn.begin(_isolation_level, _access_mode)
}
DatabaseConnectionType::RusqliteSharedConnection(conn) => conn.begin(_config),
#[cfg(feature = "mock")]
DatabaseConnectionType::MockDatabaseConnection(conn) => {
DatabaseTransaction::new_mock(Arc::clone(conn), None).await
Expand All @@ -408,7 +404,8 @@ impl TransactionTrait for DatabaseConnection {
}

/// Execute the function inside a transaction.
/// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed.
/// If the function returns an error, the transaction will be rolled back.
/// If it does not return an error, the transaction will be committed.
#[instrument(level = "trace", skip(_callback))]
async fn transaction<F, T, E>(&self, _callback: F) -> Result<T, TransactionError<E>>
where
Expand All @@ -422,19 +419,22 @@ impl TransactionTrait for DatabaseConnection {
match &self.inner {
#[cfg(feature = "sqlx-mysql")]
DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => {
conn.transaction(_callback, None, None).await
conn.transaction(_callback, TransactionConfig::default())
.await
}
#[cfg(feature = "sqlx-postgres")]
DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
conn.transaction(_callback, None, None).await
conn.transaction(_callback, TransactionConfig::default())
.await
}
#[cfg(feature = "sqlx-sqlite")]
DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
conn.transaction(_callback, None, None).await
conn.transaction(_callback, TransactionConfig::default())
.await
}
#[cfg(feature = "rusqlite")]
DatabaseConnectionType::RusqliteSharedConnection(conn) => {
conn.transaction(_callback, None, None)
conn.transaction(_callback, TransactionConfig::default())
}
#[cfg(feature = "mock")]
DatabaseConnectionType::MockDatabaseConnection(conn) => {
Expand All @@ -455,13 +455,13 @@ impl TransactionTrait for DatabaseConnection {
}

/// Execute the function inside a transaction.
/// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed.
/// If the function returns an error, the transaction will be rolled back.
/// If it does not return an error, the transaction will be committed.
#[instrument(level = "trace", skip(_callback))]
async fn transaction_with_config<F, T, E>(
&self,
_callback: F,
_isolation_level: Option<IsolationLevel>,
_access_mode: Option<AccessMode>,
_config: TransactionConfig,
) -> Result<T, TransactionError<E>>
where
F: for<'c> FnOnce(
Expand All @@ -474,22 +474,19 @@ impl TransactionTrait for DatabaseConnection {
match &self.inner {
#[cfg(feature = "sqlx-mysql")]
DatabaseConnectionType::SqlxMySqlPoolConnection(conn) => {
conn.transaction(_callback, _isolation_level, _access_mode)
.await
conn.transaction(_callback, _config).await
}
#[cfg(feature = "sqlx-postgres")]
DatabaseConnectionType::SqlxPostgresPoolConnection(conn) => {
conn.transaction(_callback, _isolation_level, _access_mode)
.await
conn.transaction(_callback, _config).await
}
#[cfg(feature = "sqlx-sqlite")]
DatabaseConnectionType::SqlxSqlitePoolConnection(conn) => {
conn.transaction(_callback, _isolation_level, _access_mode)
.await
conn.transaction(_callback, _config).await
}
#[cfg(feature = "rusqlite")]
DatabaseConnectionType::RusqliteSharedConnection(conn) => {
conn.transaction(_callback, _isolation_level, _access_mode)
conn.transaction(_callback, _config)
}
#[cfg(feature = "mock")]
DatabaseConnectionType::MockDatabaseConnection(conn) => {
Expand Down Expand Up @@ -556,8 +553,8 @@ impl DatabaseConnection {

#[cfg(feature = "rbac")]
impl DatabaseConnection {
/// Load RBAC data from the same database as this connection and setup RBAC engine.
/// If the RBAC engine already exists, it will be replaced.
/// Load RBAC data from the same database as this connection and setup RBAC
/// engine. If the RBAC engine already exists, it will be replaced.
pub async fn load_rbac(&self) -> Result<(), DbErr> {
self.load_rbac_from(self).await
}
Expand All @@ -575,7 +572,8 @@ impl DatabaseConnection {
self.rbac.replace(engine);
}

/// Create a restricted connection with access control specific for the user.
/// Create a restricted connection with access control specific for the
/// user.
pub fn restricted_for(
&self,
user_id: crate::rbac::RbacUserId,
Expand Down
30 changes: 11 additions & 19 deletions src/database/executor.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use std::future::Future;
use std::pin::Pin;

use crate::{
AccessMode, ConnectionTrait, DatabaseConnection, DatabaseTransaction, DbBackend, DbErr,
ExecResult, IsolationLevel, QueryResult, Statement, TransactionError, TransactionTrait,
ConnectionTrait, DatabaseConnection, DatabaseTransaction, DbBackend, DbErr, ExecResult,
QueryResult, Statement, TransactionConfig, TransactionError, TransactionTrait,
};
use crate::{Schema, SchemaBuilder};
use std::future::Future;
use std::pin::Pin;

/// A wrapper that holds either a reference to a [`DatabaseConnection`] or [`DatabaseTransaction`].
#[derive(Debug)]
Expand Down Expand Up @@ -78,16 +79,11 @@ impl TransactionTrait for DatabaseExecutor<'_> {

async fn begin_with_config(
&self,
isolation_level: Option<IsolationLevel>,
access_mode: Option<AccessMode>,
config: TransactionConfig,
) -> Result<DatabaseTransaction, DbErr> {
match self {
DatabaseExecutor::Connection(conn) => {
conn.begin_with_config(isolation_level, access_mode).await
}
DatabaseExecutor::Transaction(trans) => {
trans.begin_with_config(isolation_level, access_mode).await
}
DatabaseExecutor::Connection(conn) => conn.begin_with_config(config).await,
DatabaseExecutor::Transaction(trans) => trans.begin_with_config(config).await,
}
}

Expand All @@ -109,8 +105,7 @@ impl TransactionTrait for DatabaseExecutor<'_> {
async fn transaction_with_config<F, T, E>(
&self,
callback: F,
isolation_level: Option<IsolationLevel>,
access_mode: Option<AccessMode>,
config: TransactionConfig,
) -> Result<T, TransactionError<E>>
where
F: for<'c> FnOnce(
Expand All @@ -122,13 +117,10 @@ impl TransactionTrait for DatabaseExecutor<'_> {
{
match self {
DatabaseExecutor::Connection(conn) => {
conn.transaction_with_config(callback, isolation_level, access_mode)
.await
conn.transaction_with_config(callback, config).await
}
DatabaseExecutor::Transaction(trans) => {
trans
.transaction_with_config(callback, isolation_level, access_mode)
.await
trans.transaction_with_config(callback, config).await
}
}
}
Expand Down
Loading
Loading