Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 33 additions & 14 deletions src/sqlite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ pub mod federation;
#[cfg(feature = "sqlite-federation")]
pub mod sqlite_interval;

#[cfg(feature = "sqlite-federation")]
pub mod between;

pub mod sql_table;
pub mod write;

Expand Down Expand Up @@ -119,6 +122,7 @@ type Result<T, E = Error> = std::result::Result<T, E>;
#[derive(Debug)]
pub struct SqliteTableProviderFactory {
instances: Arc<Mutex<HashMap<DbInstanceKey, SqliteConnectionPool>>>,
decimal_between: bool,
}

const SQLITE_DB_PATH_PARAM: &str = "file";
Expand All @@ -131,9 +135,16 @@ impl SqliteTableProviderFactory {
pub fn new() -> Self {
Self {
instances: Arc::new(Mutex::new(HashMap::new())),
decimal_between: false,
}
}

#[must_use]
pub fn with_decimal_between(mut self, decimal_between: bool) -> Self {
self.decimal_between = decimal_between;
self
}

#[must_use]
pub fn attach_databases(&self, options: &HashMap<String, String>) -> Option<Vec<Arc<str>>> {
options.get(SQLITE_ATTACH_DATABASES_PARAM).map(|databases| {
Expand Down Expand Up @@ -353,11 +364,10 @@ impl TableProviderFactory for SqliteTableProviderFactory {

let dyn_pool: Arc<DynSqliteConnectionPool> = read_pool;

let read_provider = Arc::new(SQLiteTable::new_with_schema(
&dyn_pool,
Arc::clone(&schema),
name,
));
let read_provider = Arc::new(
SQLiteTable::new_with_schema(&dyn_pool, Arc::clone(&schema), name)
.with_decimal_between(self.decimal_between),
);

let sqlite = Arc::into_inner(sqlite)
.context(DanglingReferenceToSqliteSnafu)
Expand All @@ -377,12 +387,22 @@ impl TableProviderFactory for SqliteTableProviderFactory {

pub struct SqliteTableFactory {
pool: Arc<SqliteConnectionPool>,
decimal_between: bool,
}

impl SqliteTableFactory {
#[must_use]
pub fn new(pool: Arc<SqliteConnectionPool>) -> Self {
Self { pool }
Self {
pool,
decimal_between: false,
}
}

#[must_use]
pub fn with_decimal_between(mut self, decimal_between: bool) -> Self {
self.decimal_between = decimal_between;
self
}

pub async fn table_provider(
Expand All @@ -398,11 +418,10 @@ impl SqliteTableFactory {

let dyn_pool: Arc<DynSqliteConnectionPool> = pool;

let read_provider = Arc::new(SQLiteTable::new_with_schema(
&dyn_pool,
Arc::clone(&schema),
table_reference,
));
let read_provider = Arc::new(
SQLiteTable::new_with_schema(&dyn_pool, Arc::clone(&schema), table_reference)
.with_decimal_between(self.decimal_between),
);

Ok(read_provider)
}
Expand Down Expand Up @@ -473,12 +492,12 @@ impl Sqlite {

async fn table_exists(&self, sqlite_conn: &mut SqliteConnection) -> bool {
let sql = format!(
r#"SELECT EXISTS (
"SELECT EXISTS (
SELECT 1
FROM sqlite_master
WHERE type='table'
AND name = '{name}'
)"#,
)",
name = self.table
);
tracing::trace!("{sql}");
Expand Down Expand Up @@ -516,7 +535,7 @@ impl Sqlite {

fn delete_all_table_data(&self, transaction: &Transaction<'_>) -> rusqlite::Result<()> {
transaction.execute(
format!(r#"DELETE FROM {}"#, self.table.to_quoted_string()).as_str(),
format!("DELETE FROM {}", self.table.to_quoted_string()).as_str(),
[],
)?;

Expand Down
Loading