11use crate :: sql:: db_connection_pool:: {
22 self ,
33 dbconnection:: {
4- duckdbconn:: { flatten_table_function_name, is_table_function, DuckDbConnection } ,
4+ duckdbconn:: {
5+ flatten_table_function_name, is_table_function, DuckDBParameter , DuckDbConnection ,
6+ } ,
57 get_schema, DbConnection ,
68 } ,
79 duckdbpool:: DuckDbConnectionPool ,
8- DbConnectionPool , Mode ,
10+ DbConnectionPool , DbInstanceKey , Mode ,
911} ;
1012use crate :: sql:: sql_provider_datafusion;
1113use crate :: util:: {
@@ -25,10 +27,11 @@ use datafusion::{
2527 logical_expr:: CreateExternalTable ,
2628 sql:: TableReference ,
2729} ;
28- use duckdb:: { AccessMode , DuckdbConnectionManager , ToSql , Transaction } ;
30+ use duckdb:: { AccessMode , DuckdbConnectionManager , Transaction } ;
2931use itertools:: Itertools ;
3032use snafu:: prelude:: * ;
3133use std:: { cmp, collections:: HashMap , sync:: Arc } ;
34+ use tokio:: sync:: Mutex ;
3235
3336use self :: { creator:: TableCreator , sql_table:: DuckDBTable , write:: DuckDBTableWriter } ;
3437
@@ -87,11 +90,6 @@ pub enum Error {
8790 #[ snafu( display( "Unable to commit transaction: {source}" ) ) ]
8891 UnableToCommitTransaction { source : duckdb:: Error } ,
8992
90- #[ snafu( display( "Unable to checkpoint duckdb: {source}" ) ) ]
91- UnableToCheckpoint {
92- source : Box < dyn std:: error:: Error + Send + Sync > ,
93- } ,
94-
9593 #[ snafu( display( "Unable to begin duckdb transaction: {source}" ) ) ]
9694 UnableToBeginTransaction { source : duckdb:: Error } ,
9795
@@ -121,6 +119,7 @@ type Result<T, E = Error> = std::result::Result<T, E>;
121119
122120pub struct DuckDBTableProviderFactory {
123121 access_mode : AccessMode ,
122+ instances : Arc < Mutex < HashMap < DbInstanceKey , DuckDbConnectionPool > > > ,
124123}
125124
126125const DUCKDB_DB_PATH_PARAM : & str = "open" ;
@@ -129,9 +128,10 @@ const DUCKDB_ATTACH_DATABASES_PARAM: &str = "attach_databases";
129128
130129impl DuckDBTableProviderFactory {
131130 #[ must_use]
132- pub fn new ( ) -> Self {
131+ pub fn new ( access_mode : AccessMode ) -> Self {
133132 Self {
134- access_mode : AccessMode :: ReadOnly ,
133+ access_mode,
134+ instances : Arc :: new ( Mutex :: new ( HashMap :: new ( ) ) ) ,
135135 }
136136 }
137137
@@ -148,12 +148,6 @@ impl DuckDBTableProviderFactory {
148148 . unwrap_or_default ( )
149149 }
150150
151- #[ must_use]
152- pub fn access_mode ( mut self , access_mode : AccessMode ) -> Self {
153- self . access_mode = access_mode;
154- self
155- }
156-
157151 #[ must_use]
158152 pub fn duckdb_file_path ( & self , name : & str , options : & mut HashMap < String , String > ) -> String {
159153 let options = util:: remove_prefix_from_hashmap_keys ( options. clone ( ) , "duckdb_" ) ;
@@ -169,15 +163,44 @@ impl DuckDBTableProviderFactory {
169163 . cloned ( )
170164 . unwrap_or ( default_filepath)
171165 }
172- }
173166
174- impl Default for DuckDBTableProviderFactory {
175- fn default ( ) -> Self {
176- Self :: new ( )
167+ pub async fn get_or_init_memory_instance ( & self ) -> Result < DuckDbConnectionPool > {
168+ let key = DbInstanceKey :: memory ( ) ;
169+ let mut instances = self . instances . lock ( ) . await ;
170+
171+ if let Some ( instance) = instances. get ( & key) {
172+ return Ok ( instance. clone ( ) ) ;
173+ }
174+
175+ let pool = DuckDbConnectionPool :: new_memory ( ) . context ( DbConnectionPoolSnafu ) ?;
176+
177+ instances. insert ( key, pool. clone ( ) ) ;
178+
179+ Ok ( pool)
180+ }
181+
182+ pub async fn get_or_init_file_instance (
183+ & self ,
184+ db_path : impl Into < Arc < str > > ,
185+ ) -> Result < DuckDbConnectionPool > {
186+ let db_path = db_path. into ( ) ;
187+ let key = DbInstanceKey :: file ( Arc :: clone ( & db_path) ) ;
188+ let mut instances = self . instances . lock ( ) . await ;
189+
190+ if let Some ( instance) = instances. get ( & key) {
191+ return Ok ( instance. clone ( ) ) ;
192+ }
193+
194+ let pool = DuckDbConnectionPool :: new_file ( & db_path, & self . access_mode )
195+ . context ( DbConnectionPoolSnafu ) ?;
196+
197+ instances. insert ( key, pool. clone ( ) ) ;
198+
199+ Ok ( pool)
177200 }
178201}
179202
180- type DynDuckDbConnectionPool = dyn DbConnectionPool < r2d2:: PooledConnection < DuckdbConnectionManager > , & ' static dyn ToSql >
203+ type DynDuckDbConnectionPool = dyn DbConnectionPool < r2d2:: PooledConnection < DuckdbConnectionManager > , DuckDBParameter >
181204 + Send
182205 + Sync ;
183206
@@ -229,12 +252,13 @@ impl TableProviderFactory for DuckDBTableProviderFactory {
229252 // open duckdb at given path or create a new one
230253 let db_path = self . duckdb_file_path ( & name, & mut options) ;
231254
232- DuckDbConnectionPool :: new_file ( & db_path , & self . access_mode )
233- . context ( DbConnectionPoolSnafu )
255+ self . get_or_init_file_instance ( db_path )
256+ . await
234257 . map_err ( to_datafusion_error) ?
235258 }
236- Mode :: Memory => DuckDbConnectionPool :: new_memory ( )
237- . context ( DbConnectionPoolSnafu )
259+ Mode :: Memory => self
260+ . get_or_init_memory_instance ( )
261+ . await
238262 . map_err ( to_datafusion_error) ?,
239263 } ;
240264
@@ -265,7 +289,12 @@ impl TableProviderFactory for DuckDBTableProviderFactory {
265289 ) ) ;
266290
267291 #[ cfg( feature = "duckdb-federation" ) ]
268- let read_provider = Arc :: new ( read_provider. create_federated_table_provider ( ) ?) ;
292+ let read_provider: Arc < dyn TableProvider > = if mode == Mode :: File {
293+ // federation is disabled for in-memory mode until memory connections are updated to use the same database instance instead of separate instances
294+ Arc :: new ( read_provider. create_federated_table_provider ( ) ?)
295+ } else {
296+ read_provider
297+ } ;
269298
270299 Ok ( DuckDBTableWriter :: create (
271300 read_provider,
@@ -317,18 +346,18 @@ impl DuckDB {
317346 pub fn connect_sync (
318347 & self ,
319348 ) -> Result <
320- Box < dyn DbConnection < r2d2:: PooledConnection < DuckdbConnectionManager > , & ' static dyn ToSql > > ,
349+ Box < dyn DbConnection < r2d2:: PooledConnection < DuckdbConnectionManager > , DuckDBParameter > > ,
321350 > {
322351 Arc :: clone ( & self . pool )
323352 . connect_sync ( )
324353 . context ( DbConnectionSnafu )
325354 }
326355
327- pub fn duckdb_conn < ' a > (
328- db_connection : & ' a mut Box <
329- dyn DbConnection < r2d2:: PooledConnection < DuckdbConnectionManager > , & ' static dyn ToSql > ,
356+ pub fn duckdb_conn (
357+ db_connection : & mut Box <
358+ dyn DbConnection < r2d2:: PooledConnection < DuckdbConnectionManager > , DuckDBParameter > ,
330359 > ,
331- ) -> Result < & ' a mut DuckDbConnection > {
360+ ) -> Result < & mut DuckDbConnection > {
332361 db_connection
333362 . as_any_mut ( )
334363 . downcast_mut :: < DuckDbConnection > ( )
@@ -441,7 +470,12 @@ impl DuckDBTableFactory {
441470 ) ) ;
442471
443472 #[ cfg( feature = "duckdb-federation" ) ]
444- let table_provider = Arc :: new ( table_provider. create_federated_table_provider ( ) ?) ;
473+ let table_provider: Arc < dyn TableProvider > = if self . pool . mode ( ) == Mode :: File {
474+ // federation is disabled for in-memory mode until memory connections are updated to use the same database instance instead of separate instances
475+ Arc :: new ( table_provider. create_federated_table_provider ( ) ?)
476+ } else {
477+ table_provider
478+ } ;
445479
446480 Ok ( table_provider)
447481 }
0 commit comments