|
| 1 | +/* |
| 2 | +Copyright 2024 The Spice.ai OSS Authors |
| 3 | +
|
| 4 | +Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +you may not use this file except in compliance with the License. |
| 6 | +You may obtain a copy of the License at |
| 7 | +
|
| 8 | + https://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +
|
| 10 | +Unless required by applicable law or agreed to in writing, software |
| 11 | +distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +See the License for the specific language governing permissions and |
| 14 | +limitations under the License. |
| 15 | +*/ |
| 16 | + |
| 17 | +use crate::sql::db_connection_pool::dbconnection::odbcconn::ODBCDbConnectionPool; |
| 18 | +use crate::sql::{ |
| 19 | + db_connection_pool as db_connection_pool_datafusion, |
| 20 | + sql_provider_datafusion::{Engine, SqlTable}, |
| 21 | +}; |
| 22 | +use arrow::datatypes::SchemaRef; |
| 23 | +use datafusion::error::DataFusionError; |
| 24 | +use datafusion::{datasource::TableProvider, sql::TableReference}; |
| 25 | +use snafu::prelude::*; |
| 26 | +use std::sync::Arc; |
| 27 | + |
| 28 | +#[derive(Debug, Snafu)] |
| 29 | +pub enum Error { |
| 30 | + #[snafu(display("DbConnectionError: {source}"))] |
| 31 | + DbConnectionError { |
| 32 | + source: db_connection_pool_datafusion::dbconnection::GenericError, |
| 33 | + }, |
| 34 | + #[snafu(display("The table '{table_name}' doesn't exist in the Postgres server"))] |
| 35 | + TableDoesntExist { table_name: String }, |
| 36 | + |
| 37 | + #[snafu(display("Unable to get a DB connection from the pool: {source}"))] |
| 38 | + UnableToGetConnectionFromPool { |
| 39 | + source: db_connection_pool_datafusion::Error, |
| 40 | + }, |
| 41 | + |
| 42 | + #[snafu(display("Unable to get schema: {source}"))] |
| 43 | + UnableToGetSchema { |
| 44 | + source: db_connection_pool_datafusion::dbconnection::Error, |
| 45 | + }, |
| 46 | + |
| 47 | + #[snafu(display("Unable to generate SQL: {source}"))] |
| 48 | + UnableToGenerateSQL { source: DataFusionError }, |
| 49 | +} |
| 50 | + |
| 51 | +type Result<T, E = Error> = std::result::Result<T, E>; |
| 52 | + |
| 53 | +pub struct ODBCTableFactory<'a> { |
| 54 | + pool: Arc<ODBCDbConnectionPool<'a>>, |
| 55 | +} |
| 56 | + |
| 57 | +impl<'a> ODBCTableFactory<'a> |
| 58 | +where |
| 59 | + 'a: 'static, |
| 60 | +{ |
| 61 | + #[must_use] |
| 62 | + pub fn new(pool: Arc<ODBCDbConnectionPool<'a>>) -> Self { |
| 63 | + Self { pool } |
| 64 | + } |
| 65 | + |
| 66 | + pub async fn table_provider( |
| 67 | + &self, |
| 68 | + table_reference: TableReference, |
| 69 | + _schema: Option<SchemaRef>, |
| 70 | + ) -> Result<Arc<dyn TableProvider + 'static>, Box<dyn std::error::Error + Send + Sync>> { |
| 71 | + let pool = Arc::clone(&self.pool); |
| 72 | + let dyn_pool: Arc<ODBCDbConnectionPool<'a>> = pool; |
| 73 | + |
| 74 | + let table = SqlTable::new("odbc", &dyn_pool, table_reference, Some(Engine::ODBC)) |
| 75 | + .await |
| 76 | + .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?; |
| 77 | + |
| 78 | + let table_provider = Arc::new(table); |
| 79 | + |
| 80 | + #[cfg(feature = "odbc-federation")] |
| 81 | + let table_provider = Arc::new( |
| 82 | + table_provider |
| 83 | + .create_federated_table_provider() |
| 84 | + .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?, |
| 85 | + ); |
| 86 | + |
| 87 | + Ok(table_provider) |
| 88 | + } |
| 89 | +} |
0 commit comments