diff --git a/Cargo.lock b/Cargo.lock index e093d349..46706ca8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -690,6 +690,17 @@ dependencies = [ "tokio", ] +[[package]] +name = "bb8-oracle" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a471c2e1027f28a98972d3d38bd2612efd05cf7b27d6cc1566ec901537e2d1c8" +dependencies = [ + "bb8", + "oracle", + "tokio", +] + [[package]] name = "bb8-postgres" version = "0.9.0" @@ -986,9 +997,9 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.2.41" +version = "1.2.52" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac9fe6cdbb24b6ade63616c0a0688e45bb56732262c158df3c0c4bea4ca47cb7" +checksum = "cd4932aefd12402b36c60956a4fe0035421f544799057659ff86f923657aada3" dependencies = [ "find-msvc-tools", "jobserver", @@ -1314,14 +1325,38 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f27ae1dd37df86211c42e150270f82743308803d90a6f6e6651cd730d5e1732f" +[[package]] +name = "darling" +version = "0.13.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a01d95850c592940db9b8194bc39f4bc0e89dee5c4265e4b1807c34a9aba453c" +dependencies = [ + "darling_core 0.13.4", + "darling_macro 0.13.4", +] + [[package]] name = "darling" version = "0.20.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc7f46116c46ff9ab3eb1597a45688b6715c6e628b5c133e288e709a29bcb4ee" dependencies = [ - "darling_core", - "darling_macro", + "darling_core 0.20.11", + "darling_macro 0.20.11", +] + +[[package]] +name = "darling_core" +version = "0.13.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "859d65a907b6852c9361e3185c862aae7fafd2887876799fa55f5f99dc40d610" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim 0.10.0", + "syn 1.0.109", ] [[package]] @@ -1334,17 +1369,28 @@ dependencies = [ "ident_case", "proc-macro2", "quote", - "strsim", + "strsim 0.11.1", "syn 2.0.111", ] +[[package]] +name = "darling_macro" +version = "0.13.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c972679f83bdf9c42bd905396b6c3588a843a17f0f16dfcfa3e2c5d57441835" +dependencies = [ + "darling_core 0.13.4", + "quote", + "syn 1.0.109", +] + [[package]] name = "darling_macro" version = "0.20.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc34b93ccb385b40dc71c6fceac4b2ad23662c7eeb248cf10d529b7e055b6ead" dependencies = [ - "darling_core", + "darling_core 0.20.11", "quote", "syn 2.0.111", ] @@ -2111,6 +2157,7 @@ dependencies = [ "async-trait", "base64", "bb8", + "bb8-oracle", "bb8-postgres", "bigdecimal", "bollard", @@ -2139,6 +2186,7 @@ dependencies = [ "native-tls", "num-bigint", "odbc-api", + "oracle", "pem", "postgres-native-tls", "prost", @@ -2353,9 +2401,9 @@ dependencies = [ [[package]] name = "find-msvc-tools" -version = "0.1.4" +version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "52051878f80a721bb68ebfbc930e07b65ba72f2da88968ea5c06fd6ca3d3a127" +checksum = "f449e6c6c08c865631d4890cfacf252b3d396c9bcc83adb6623cdb02a8336c41" [[package]] name = "fixedbitset" @@ -3561,7 +3609,7 @@ version = "0.32.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "66f62cad7623a9cb6f8f64037f0c4f69c8db8e82914334a83c9788201c2c1bfa" dependencies = [ - "darling", + "darling 0.20.11", "heck 0.5.0", "num-bigint", "proc-macro-crate", @@ -4038,6 +4086,15 @@ version = "0.27.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bd7e3c4b5b7bbd3e7bd01dc00cb4614f2445591cad1f6f18a7e16d7f98c392e9" +[[package]] +name = "odpic-sys" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "920b5474a5128a9f0232df5a0ffc50aaa5b077b29b8b06ab0131985ac82793ed" +dependencies = [ + "cc", +] + [[package]] name = "once_cell" version = "1.21.3" @@ -4094,6 +4151,32 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "oracle" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3db40fe6e4df881b683691ade5ef1f7b1afd52aefa115581f7b92855524d7ec0" +dependencies = [ + "cc", + "odpic-sys", + "once_cell", + "oracle_procmacro", + "paste", + "rustversion", +] + +[[package]] +name = "oracle_procmacro" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad247f3421d57de56a0d0408d3249d4b1048a522be2013656d92f022c3d8af27" +dependencies = [ + "darling 0.13.4", + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "orbclient" version = "0.3.48" @@ -5089,7 +5172,7 @@ version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bae0cbad6ab996955664982739354128c58d16e126114fe88c2a493642502aab" dependencies = [ - "darling", + "darling 0.20.11", "heck 0.4.1", "proc-macro2", "quote", @@ -5482,6 +5565,12 @@ dependencies = [ "unicode-properties", ] +[[package]] +name = "strsim" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" + [[package]] name = "strsim" version = "0.11.1" diff --git a/README.md b/README.md index 2976d271..55cab613 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,7 @@ let ctx = SessionContext::with_state(state); - PostgreSQL - MySQL +- Oracle - SQLite - ClickHouse - DuckDB @@ -165,6 +166,57 @@ EOF cargo run -p datafusion-table-providers --example mysql --features mysql ``` +### Oracle + +In order to run the Oracle example, you need to have an Oracle database server running. You can use the following command to start an Oracle Free server in a Docker container the example can use: + +```bash +docker run --name oracle-free \ + -e ORACLE_PASSWORD=OraclePassword123 \ + -p 1521:1521 \ + -d gvenzl/oracle-free:latest + +# Wait for the Oracle server to start and healthcheck to pass +echo "Waiting for Oracle to start (this may take 1-2 minutes)..." +until docker exec oracle-free /usr/local/bin/checkHealth.sh >/dev/null 2>&1; do + sleep 5 +done +echo "Oracle is ready!" + +# Create a table in the Oracle server and insert some data +docker exec -i oracle-free sqlplus system/OraclePassword123@FREEPDB1 <, + }, +} + +pub type Result = std::result::Result; + +pub struct OracleTableFactory { + pool: Arc, +} + +impl OracleTableFactory { + #[must_use] + pub fn new(pool: Arc) -> Self { + Self { pool } + } + + pub async fn table_provider( + &self, + table_reference: TableReference, + ) -> Result, Box> { + let pool = Arc::clone(&self.pool); + let dyn_pool = pool as Arc< + dyn db_connection_pool::DbConnectionPool< + OraclePooledConnection, + oracle::sql_type::OracleType, + > + Send + + Sync + + 'static, + >; + + let table = SqlTable::new("oracle", &dyn_pool, table_reference) + .await + .map_err(|e| Box::new(e) as Box)?; + + let oracle_table = Arc::new(OracleTable::new(Arc::clone(&self.pool), table)); + + #[cfg(feature = "oracle-federation")] + let oracle_table = Arc::new( + oracle_table + .create_federated_table_provider() + .map_err(|e| Box::new(e) as Box)?, + ); + + Ok(oracle_table) + } +} + +#[derive(Debug)] +pub struct OracleTableProviderFactory {} + +impl OracleTableProviderFactory { + #[must_use] + pub fn new() -> Self { + Self {} + } +} + +impl Default for OracleTableProviderFactory { + fn default() -> Self { + Self::new() + } +} + +#[async_trait] +impl TableProviderFactory for OracleTableProviderFactory { + async fn create( + &self, + _state: &dyn Session, + cmd: &CreateExternalTable, + ) -> datafusion::common::Result> { + let name = cmd.name.to_string(); + let options = &cmd.options; + + // Construct params from options + let mut params: HashMap = HashMap::new(); + for (k, v) in options { + params.insert(k.clone(), SecretString::from(v.clone())); + } + + let pool = OracleConnectionPool::new(params) + .await + .map_err(|e| DataFusionError::External(Box::new(e)))?; + + let factory = OracleTableFactory::new(Arc::new(pool)); + + let table = factory + .table_provider(TableReference::from(name)) + .await + .map_err(DataFusionError::External)?; + + Ok(table) + } +} diff --git a/core/src/oracle/federation.rs b/core/src/oracle/federation.rs new file mode 100644 index 00000000..3ed44b4c --- /dev/null +++ b/core/src/oracle/federation.rs @@ -0,0 +1,104 @@ +use crate::sql::db_connection_pool::dbconnection::oracleconn::OraclePooledConnection; +use crate::sql::db_connection_pool::dbconnection::{get_schema, Error as DbError}; +use crate::sql::sql_provider_datafusion::{get_stream, to_execution_error}; +use arrow::datatypes::SchemaRef; +use async_trait::async_trait; +use datafusion::sql::unparser::dialect::Dialect; +use datafusion_federation::sql::{ + RemoteTableRef, SQLExecutor, SQLFederationProvider, SQLTableSource, +}; +use datafusion_federation::{FederatedTableProviderAdaptor, FederatedTableSource}; +use futures::TryStreamExt; +use snafu::ResultExt; +use std::sync::Arc; + +use super::sql_table::OracleTable; +use datafusion::{ + datasource::TableProvider, + error::{DataFusionError, Result as DataFusionResult}, + execution::SendableRecordBatchStream, + physical_plan::stream::RecordBatchStreamAdapter, + sql::TableReference, +}; + +impl OracleTable { + pub fn create_federated_table_source( + self: Arc, + ) -> DataFusionResult> { + let table_reference = self.base_table.table_reference.clone(); + let schema = Arc::clone(&self.base_table.schema()); + let fed_provider = Arc::new(SQLFederationProvider::new(self.clone())); + Ok(Arc::new(SQLTableSource::new_with_schema( + fed_provider, + RemoteTableRef::from(table_reference), + schema, + ))) + } + + pub fn create_federated_table_provider( + self: Arc, + ) -> DataFusionResult { + let table_source = self.clone().create_federated_table_source()?; + Ok(FederatedTableProviderAdaptor::new_with_provider( + table_source, + self, + )) + } +} + +#[async_trait] +impl SQLExecutor for OracleTable { + fn name(&self) -> &str { + self.base_table.name() + } + + fn compute_context(&self) -> Option { + None + } + + fn dialect(&self) -> Arc { + Arc::new(Self::dialect()) + } + + fn execute( + &self, + query: &str, + schema: SchemaRef, + ) -> DataFusionResult { + let pool = self.base_table.clone_pool(); + let dyn_pool = pool as Arc< + dyn crate::sql::db_connection_pool::DbConnectionPool< + OraclePooledConnection, + oracle::sql_type::OracleType, + > + Send + + Sync, + >; + let fut = get_stream(dyn_pool, query.to_string(), Arc::clone(&schema)); + + let stream = futures::stream::once(fut).try_flatten(); + Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream))) + } + + async fn table_names(&self) -> DataFusionResult> { + Err(DataFusionError::NotImplemented( + "table inference not implemented".to_string(), + )) + } + + async fn get_table_schema(&self, table_name: &str) -> DataFusionResult { + let pool = self.base_table.clone_pool(); + let dyn_pool = pool as Arc< + dyn crate::sql::db_connection_pool::DbConnectionPool< + OraclePooledConnection, + oracle::sql_type::OracleType, + > + Send + + Sync, + >; + let conn = dyn_pool.connect().await.map_err(to_execution_error)?; + get_schema(conn, &TableReference::from(table_name)) + .await + .boxed() + .map_err(|e| DbError::UnableToGetSchema { source: e }) + .map_err(to_execution_error) + } +} diff --git a/core/src/oracle/sql_table.rs b/core/src/oracle/sql_table.rs new file mode 100644 index 00000000..7ee6ae79 --- /dev/null +++ b/core/src/oracle/sql_table.rs @@ -0,0 +1,284 @@ +use crate::sql::db_connection_pool::oraclepool::OracleConnectionPool; + +use async_trait::async_trait; +use datafusion::catalog::Session; +use futures::TryStreamExt; +use std::fmt::Display; +use std::{any::Any, fmt, sync::Arc}; + +use crate::sql::db_connection_pool::dbconnection::oracleconn::OraclePooledConnection; +use crate::sql::sql_provider_datafusion::{ + get_stream, to_execution_error, Result as SqlResult, SqlExec, SqlTable, +}; +use datafusion::{ + arrow::datatypes::{DataType, SchemaRef}, + common::utils::quote_identifier, + datasource::TableProvider, + error::Result as DataFusionResult, + execution::TaskContext, + logical_expr::{Expr, TableProviderFilterPushDown, TableType}, + physical_plan::{ + stream::RecordBatchStreamAdapter, DisplayAs, DisplayFormatType, ExecutionPlan, + PlanProperties, SendableRecordBatchStream, + }, + sql::{ + sqlparser, + unparser::{ + dialect::{CustomDialect, CustomDialectBuilder}, + Unparser, + }, + }, +}; + +pub struct OracleTable { + pool: Arc, + pub(crate) base_table: SqlTable, +} + +impl std::fmt::Debug for OracleTable { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("OracleTable") + .field("base_table", &self.base_table) + .finish() + } +} + +impl OracleTable { + pub fn new( + pool: Arc, + base_table: SqlTable, + ) -> Self { + Self { pool, base_table } + } + + pub(crate) fn dialect() -> CustomDialect { + CustomDialectBuilder::new() + .with_identifier_quote_style('"') + // There is no 'DOUBLE' SQL type in Oracle: it can use 'FLOAT' for both single and double precision float values + .with_float64_ast_dtype(sqlparser::ast::DataType::Float( + sqlparser::ast::ExactNumberInfo::None, + )) + .build() + } + + fn create_physical_plan( + &self, + projections: Option<&Vec>, + schema: &SchemaRef, + filters: &[Expr], + limit: Option, + ) -> DataFusionResult> { + let projected_schema = if let Some(proj) = projections { + Arc::new(schema.project(proj)?) + } else { + Arc::clone(schema) + }; + + let columns = projected_schema + .fields() + .iter() + .map(|f| quote_identifier(f.name())) + .collect::>() + .join(", "); + + let dialect = Self::dialect(); + + let where_expr = if filters.is_empty() { + String::new() + } else { + let filter_expr = filters + .iter() + .map(|f| { + Unparser::new(&dialect) + .expr_to_sql(f) + .map(|e| e.to_string()) + }) + .collect::>>()? + .join(" AND "); + format!("WHERE {filter_expr}") + }; + + let limit_expr = if let Some(limit) = limit { + format!("FETCH FIRST {limit} ROWS ONLY") + } else { + String::new() + }; + + let table_reference = self.base_table.table_reference.to_quoted_string(); + let sql = format!("SELECT {columns} FROM {table_reference} {where_expr} {limit_expr}"); + + Ok(Arc::new(OracleSQLExec::new( + projections, + schema, + Arc::clone(&self.pool), + sql, + )?)) + } + + /// Check if an expression contains datetime-related types that Oracle cannot handle + /// in filter pushdown due to datetime literal format requirements. + fn contains_datetime_expr(expr: &Expr) -> bool { + match expr { + Expr::BinaryExpr(binary_expr) => { + Self::is_datetime_type_expr(&binary_expr.left) + || Self::is_datetime_type_expr(&binary_expr.right) + || Self::contains_datetime_expr(&binary_expr.left) + || Self::contains_datetime_expr(&binary_expr.right) + } + Expr::Not(inner) => Self::contains_datetime_expr(inner), + _ => Self::is_datetime_type_expr(expr), + } + } + + fn is_datetime_type_expr(expr: &Expr) -> bool { + match expr { + Expr::Cast(cast) => matches!( + cast.data_type, + DataType::Time32(_) + | DataType::Time64(_) + | DataType::Date32 + | DataType::Date64 + | DataType::Timestamp(_, _) + ), + Expr::Literal(literal, _) => matches!( + literal.data_type(), + DataType::Time32(_) + | DataType::Time64(_) + | DataType::Date32 + | DataType::Date64 + | DataType::Timestamp(_, _) + ), + _ => false, + } + } +} + +#[async_trait] +impl TableProvider for OracleTable { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.base_table.schema() + } + + fn table_type(&self) -> TableType { + self.base_table.table_type() + } + + fn supports_filters_pushdown( + &self, + filters: &[&Expr], + ) -> DataFusionResult> { + // Oracle requires specific format for datetime literals that the expression + // unparser cannot handle correctly, resulting in ORA-01843 errors. + // We mark datetime-related filters as unsupported to prevent pushdown. + let mut results = Vec::with_capacity(filters.len()); + for filter in filters { + if Self::contains_datetime_expr(filter) { + results.push(TableProviderFilterPushDown::Unsupported); + } else { + // For non-datetime filters, delegate to base table + let base_result = self.base_table.supports_filters_pushdown(&[filter])?; + results.extend(base_result); + } + } + Ok(results) + } + + async fn scan( + &self, + _state: &dyn Session, + projection: Option<&Vec>, + filters: &[Expr], + limit: Option, + ) -> DataFusionResult> { + return self.create_physical_plan(projection, &self.schema(), filters, limit); + } +} + +impl Display for OracleTable { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "OracleTable {}", self.base_table.name()) + } +} + +struct OracleSQLExec { + base_exec: SqlExec, +} + +impl OracleSQLExec { + fn new( + projections: Option<&Vec>, + schema: &SchemaRef, + pool: Arc, + sql: String, + ) -> DataFusionResult { + let base_exec = SqlExec::new(projections, schema, pool, sql)?; + + Ok(Self { base_exec }) + } + + fn sql(&self) -> SqlResult { + self.base_exec.sql() + } +} + +impl std::fmt::Debug for OracleSQLExec { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + let sql = self.sql().unwrap_or_default(); + write!(f, "OracleSQLExec sql={sql}") + } +} + +impl DisplayAs for OracleSQLExec { + fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> std::fmt::Result { + let sql = self.sql().unwrap_or_default(); + write!(f, "OracleSQLExec sql={sql}") + } +} + +impl ExecutionPlan for OracleSQLExec { + fn name(&self) -> &'static str { + "OracleSQLExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.base_exec.schema() + } + + fn properties(&self) -> &PlanProperties { + self.base_exec.properties() + } + + fn children(&self) -> Vec<&Arc> { + self.base_exec.children() + } + + fn with_new_children( + self: Arc, + _children: Vec>, + ) -> DataFusionResult> { + Ok(self) + } + + fn execute( + &self, + _partition: usize, + _context: Arc, + ) -> DataFusionResult { + let sql = self.sql().map_err(to_execution_error)?; + tracing::debug!("OracleSQLExec sql: {sql}"); + + let fut = get_stream(self.base_exec.clone_pool(), sql, Arc::clone(&self.schema())); + + let stream = futures::stream::once(fut).try_flatten(); + let schema = Arc::clone(&self.schema()); + Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream))) + } +} diff --git a/core/src/oracle/write.rs b/core/src/oracle/write.rs new file mode 100644 index 00000000..8c321b83 --- /dev/null +++ b/core/src/oracle/write.rs @@ -0,0 +1,21 @@ +use crate::sql::db_connection_pool::oraclepool::OracleConnectionPool; +use datafusion::arrow::record_batch::RecordBatch; +use datafusion::error::{DataFusionError, Result}; +use std::sync::Arc; + +pub struct OracleTableWriter { + _pool: Arc, +} + +impl OracleTableWriter { + pub fn new(pool: Arc) -> Self { + Self { _pool: pool } + } + + pub async fn insert_batch(&self, _batch: RecordBatch) -> Result { + // Implement batch insert using oracle-rs BatchBuilder + Err(DataFusionError::NotImplemented( + "Oracle write support not yet implemented".to_string(), + )) + } +} diff --git a/core/src/sql/arrow_sql_gen/arrow.rs b/core/src/sql/arrow_sql_gen/arrow.rs index 437d76a0..12625672 100644 --- a/core/src/sql/arrow_sql_gen/arrow.rs +++ b/core/src/sql/arrow_sql_gen/arrow.rs @@ -3,13 +3,13 @@ use datafusion::arrow::{ types::Int8Type, ArrayBuilder, BinaryBuilder, BooleanBuilder, Date32Builder, Date64Builder, Decimal128Builder, Decimal256Builder, Decimal32Builder, Decimal64Builder, FixedSizeBinaryBuilder, FixedSizeListBuilder, Float32Builder, Float64Builder, Int16Builder, - Int32Builder, Int64Builder, Int8Builder, IntervalMonthDayNanoBuilder, LargeBinaryBuilder, - LargeStringBuilder, ListBuilder, NullBuilder, StringBuilder, StringDictionaryBuilder, - StructBuilder, Time64NanosecondBuilder, TimestampMicrosecondBuilder, - TimestampMillisecondBuilder, TimestampNanosecondBuilder, TimestampSecondBuilder, - UInt16Builder, UInt32Builder, UInt64Builder, UInt8Builder, + Int32Builder, Int64Builder, Int8Builder, IntervalMonthDayNanoBuilder, + IntervalYearMonthBuilder, LargeBinaryBuilder, LargeStringBuilder, ListBuilder, NullBuilder, + StringBuilder, StringDictionaryBuilder, StructBuilder, Time64NanosecondBuilder, + TimestampMicrosecondBuilder, TimestampMillisecondBuilder, TimestampNanosecondBuilder, + TimestampSecondBuilder, UInt16Builder, UInt32Builder, UInt64Builder, UInt8Builder, }, - datatypes::{DataType, TimeUnit, UInt16Type}, + datatypes::{DataType, IntervalUnit, TimeUnit, UInt16Type}, }; pub fn map_data_type_to_array_builder_optional( @@ -39,7 +39,12 @@ pub fn map_data_type_to_array_builder(data_type: &DataType) -> Box Box::new(BooleanBuilder::new()), DataType::Binary => Box::new(BinaryBuilder::new()), DataType::LargeBinary => Box::new(LargeBinaryBuilder::new()), - DataType::Interval(_) => Box::new(IntervalMonthDayNanoBuilder::new()), + DataType::Interval(interval_unit) => match interval_unit { + IntervalUnit::YearMonth => Box::new(IntervalYearMonthBuilder::new()), + IntervalUnit::DayTime | IntervalUnit::MonthDayNano => { + Box::new(IntervalMonthDayNanoBuilder::new()) + } + }, DataType::Decimal32(precision, scale) => Box::new( Decimal32Builder::new() .with_precision_and_scale(*precision, *scale) diff --git a/core/src/sql/arrow_sql_gen/mod.rs b/core/src/sql/arrow_sql_gen/mod.rs index 606b831a..d0e8947b 100644 --- a/core/src/sql/arrow_sql_gen/mod.rs +++ b/core/src/sql/arrow_sql_gen/mod.rs @@ -45,6 +45,8 @@ pub mod arrow; #[cfg(feature = "mysql")] pub mod mysql; +#[cfg(feature = "oracle")] +pub mod oracle; #[cfg(feature = "postgres")] pub mod postgres; #[cfg(feature = "sqlite")] diff --git a/core/src/sql/arrow_sql_gen/oracle.rs b/core/src/sql/arrow_sql_gen/oracle.rs new file mode 100644 index 00000000..5603b086 --- /dev/null +++ b/core/src/sql/arrow_sql_gen/oracle.rs @@ -0,0 +1,454 @@ +use crate::sql::arrow_sql_gen::arrow::map_data_type_to_array_builder_optional; +use arrow::{ + array::{ + ArrayBuilder, ArrayRef, BinaryBuilder, BooleanBuilder, Date32Builder, Decimal128Builder, + Decimal256Builder, IntervalMonthDayNanoBuilder, IntervalYearMonthBuilder, + LargeBinaryBuilder, LargeStringBuilder, StringBuilder, TimestampMicrosecondBuilder, + TimestampMillisecondBuilder, TimestampNanosecondBuilder, TimestampSecondBuilder, + }, + datatypes::{ + i256, DataType, Field, IntervalMonthDayNano, IntervalUnit, Schema, SchemaRef, TimeUnit, + }, + error::ArrowError, + record_batch::RecordBatch, +}; + +use bigdecimal::num_bigint; +use bigdecimal::{BigDecimal, ToPrimitive}; +use chrono::{TimeZone, Utc}; +use oracle::Row; +use snafu::{OptionExt, ResultExt, Snafu}; +use std::sync::Arc; + +#[derive(Debug, Snafu)] +pub enum Error { + #[snafu(display("Failed to build record batch: {source}"))] + FailedToBuildRecordBatch { source: ArrowError }, + + #[snafu(display("No builder found for index {index}"))] + NoBuilderForIndex { index: usize }, + + #[snafu(display("Failed to downcast builder for index {index}"))] + FailedToDowncastBuilder { index: usize }, + + #[snafu(display("Oracle error: {source}"))] + OracleError { source: oracle::Error }, + + #[snafu(display("Cannot represent BigDecimal as i128: {big_decimal}"))] + FailedToConvertBigDecimalToI128 { big_decimal: BigDecimal }, + + #[snafu(display("Failed to parse BigDecimal from string '{value}': {source}"))] + ParseBigDecimalError { + value: String, + source: bigdecimal::ParseBigDecimalError, + }, + + #[snafu(display("Failed to map column {name} to arrow type"))] + FailedToMapColumnType { name: String }, +} + +pub type Result = std::result::Result; + +pub fn rows_to_arrow(rows: Vec, projected_schema: &Option) -> Result { + if rows.is_empty() { + return Ok(RecordBatch::new_empty( + projected_schema + .clone() + .unwrap_or_else(|| Arc::new(Schema::empty())), + )); + } + + let mut builders: Vec> = Vec::new(); + let mut arrow_fields: Vec = Vec::new(); + + let first_row = &rows[0]; + + // Determine schema fields + if let Some(schema) = projected_schema { + for field in schema.fields() { + arrow_fields.push((**field).clone()); + builders.push( + map_data_type_to_array_builder_optional(Some(field.data_type())) + .context(FailedToMapColumnTypeSnafu { name: field.name() })?, + ); + } + } else { + // Infer from first row - using ODPI-C metadata + // We can get column names from the Row's column_info. + // We default to Utf8 for all columns when schema is not provided, + // as we don't have the explicit mappings available here easily without the broader context. + let column_info = first_row.column_info(); + for info in column_info { + let name = info.name().to_string(); + let data_type = DataType::Utf8; + let field = Field::new(name, data_type.clone(), true); + arrow_fields.push(field.clone()); + builders.push( + map_data_type_to_array_builder_optional(Some(&data_type)) + .context(FailedToMapColumnTypeSnafu { name: field.name() })?, + ); + } + } + + for row in rows { + for (i, builder) in builders.iter_mut().enumerate() { + let field = &arrow_fields[i]; + + match field.data_type() { + DataType::Utf8 => { + let builder = builder + .as_any_mut() + .downcast_mut::() + .ok_or(Error::FailedToDowncastBuilder { index: i })?; + let val: Option = row + .get::<_, Option>(i) + .map_err(|e| Error::OracleError { source: e })?; + builder.append_option(val); + } + DataType::Float64 => { + use arrow::array::Float64Builder; + let builder = builder + .as_any_mut() + .downcast_mut::() + .ok_or(Error::FailedToDowncastBuilder { index: i })?; + let val: Option = row + .get::<_, Option>(i) + .map_err(|e| Error::OracleError { source: e })?; + builder.append_option(val); + } + DataType::Float32 => { + use arrow::array::Float32Builder; + let builder = builder + .as_any_mut() + .downcast_mut::() + .ok_or(Error::FailedToDowncastBuilder { index: i })?; + let val: Option = row + .get::<_, Option>(i) + .map_err(|e| Error::OracleError { source: e })?; + builder.append_option(val); + } + DataType::Int64 => { + use arrow::array::Int64Builder; + let builder = builder + .as_any_mut() + .downcast_mut::() + .ok_or(Error::FailedToDowncastBuilder { index: i })?; + let val: Option = row + .get::<_, Option>(i) + .map_err(|e| Error::OracleError { source: e })?; + builder.append_option(val); + } + DataType::Decimal128(_p, s) => { + let builder = builder + .as_any_mut() + .downcast_mut::() + .ok_or(Error::FailedToDowncastBuilder { index: i })?; + let val: Option = row + .get::<_, Option>(i) + .map_err(|e| Error::OracleError { source: e })?; + if let Some(s_val) = val { + let big_dec = s_val.parse::().map_err(|e| { + Error::ParseBigDecimalError { + value: s_val.clone(), + source: e, + } + })?; + let i128_val = to_decimal_128(&big_dec, *s as i64).ok_or( + Error::FailedToConvertBigDecimalToI128 { + big_decimal: big_dec, + }, + )?; + builder.append_value(i128_val); + } else { + builder.append_null(); + } + } + DataType::Decimal256(_p, _s) => { + let builder = builder + .as_any_mut() + .downcast_mut::() + .ok_or(Error::FailedToDowncastBuilder { index: i })?; + let val: Option = row + .get::<_, Option>(i) + .map_err(|e| Error::OracleError { source: e })?; + if let Some(s_val) = val { + let big_dec = s_val.parse::().map_err(|e| { + Error::ParseBigDecimalError { + value: s_val.clone(), + source: e, + } + })?; + let i256_val = to_decimal_256(&big_dec); + builder.append_value(i256_val); + } else { + builder.append_null(); + } + } + DataType::Timestamp(TimeUnit::Microsecond, _) => { + let builder = builder + .as_any_mut() + .downcast_mut::() + .ok_or(Error::FailedToDowncastBuilder { index: i })?; + let val: Option = row + .get::<_, Option>(i) + .map_err(|e| Error::OracleError { source: e })?; + if let Some(ts) = val { + let chrono_ts = Utc + .with_ymd_and_hms( + ts.year(), + ts.month(), + ts.day(), + ts.hour(), + ts.minute(), + ts.second(), + ) + .single(); + + if let Some(chrono_ts) = chrono_ts { + let micros = + chrono_ts.timestamp() * 1_000_000 + (ts.nanosecond() / 1000) as i64; + builder.append_value(micros); + } else { + builder.append_null(); + } + } else { + builder.append_null(); + } + } + DataType::Timestamp(TimeUnit::Second, _) => { + let builder = builder + .as_any_mut() + .downcast_mut::() + .ok_or(Error::FailedToDowncastBuilder { index: i })?; + let val: Option = row + .get::<_, Option>(i) + .map_err(|e| Error::OracleError { source: e })?; + if let Some(ts) = val { + let chrono_ts = Utc + .with_ymd_and_hms( + ts.year(), + ts.month(), + ts.day(), + ts.hour(), + ts.minute(), + ts.second(), + ) + .single(); + + if let Some(chrono_ts) = chrono_ts { + builder.append_value(chrono_ts.timestamp()); + } else { + builder.append_null(); + } + } else { + builder.append_null(); + } + } + DataType::Timestamp(TimeUnit::Millisecond, _) => { + let builder = builder + .as_any_mut() + .downcast_mut::() + .ok_or(Error::FailedToDowncastBuilder { index: i })?; + let val: Option = row + .get::<_, Option>(i) + .map_err(|e| Error::OracleError { source: e })?; + if let Some(ts) = val { + let chrono_ts = Utc + .with_ymd_and_hms( + ts.year(), + ts.month(), + ts.day(), + ts.hour(), + ts.minute(), + ts.second(), + ) + .single(); + + if let Some(chrono_ts) = chrono_ts { + let millis = chrono_ts.timestamp() * 1_000 + + (ts.nanosecond() / 1_000_000) as i64; + builder.append_value(millis); + } else { + builder.append_null(); + } + } else { + builder.append_null(); + } + } + DataType::Timestamp(TimeUnit::Nanosecond, _) => { + let builder = builder + .as_any_mut() + .downcast_mut::() + .ok_or(Error::FailedToDowncastBuilder { index: i })?; + let val: Option = row + .get::<_, Option>(i) + .map_err(|e| Error::OracleError { source: e })?; + if let Some(ts) = val { + let chrono_ts = Utc + .with_ymd_and_hms( + ts.year(), + ts.month(), + ts.day(), + ts.hour(), + ts.minute(), + ts.second(), + ) + .single(); + + if let Some(chrono_ts) = chrono_ts { + let nanos = + chrono_ts.timestamp() * 1_000_000_000 + ts.nanosecond() as i64; + builder.append_value(nanos); + } else { + builder.append_null(); + } + } else { + builder.append_null(); + } + } + DataType::Date32 => { + let builder = builder + .as_any_mut() + .downcast_mut::() + .ok_or(Error::FailedToDowncastBuilder { index: i })?; + let val: Option = row + .get::<_, Option>(i) + .map_err(|e| Error::OracleError { source: e })?; + if let Some(ts) = val { + // Date32 is days since Unix epoch + // Use NaiveDate to avoid timezone issues + use chrono::NaiveDate; + let naive_date = NaiveDate::from_ymd_opt(ts.year(), ts.month(), ts.day()); + if let Some(date) = naive_date { + // Calculate days since Unix epoch (1970-01-01) + let epoch = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap(); + let days = date.signed_duration_since(epoch).num_days() as i32; + builder.append_value(days); + } else { + builder.append_null(); + } + } else { + builder.append_null(); + } + } + DataType::Boolean => { + let builder = builder + .as_any_mut() + .downcast_mut::() + .ok_or(Error::FailedToDowncastBuilder { index: i })?; + let val: Option = row + .get::<_, Option>(i) + .map_err(|e| Error::OracleError { source: e })?; + builder.append_option(val); + } + DataType::Binary => { + let builder = builder + .as_any_mut() + .downcast_mut::() + .ok_or(Error::FailedToDowncastBuilder { index: i })?; + let val: Option> = row + .get::<_, Option>>(i) + .map_err(|e| Error::OracleError { source: e })?; + builder.append_option(val); + } + DataType::LargeBinary => { + let builder = builder + .as_any_mut() + .downcast_mut::() + .ok_or(Error::FailedToDowncastBuilder { index: i })?; + let val: Option> = row + .get::<_, Option>>(i) + .map_err(|e| Error::OracleError { source: e })?; + builder.append_option(val); + } + DataType::LargeUtf8 => { + let builder = builder + .as_any_mut() + .downcast_mut::() + .ok_or(Error::FailedToDowncastBuilder { index: i })?; + let val: Option = row + .get::<_, Option>(i) + .map_err(|e| Error::OracleError { source: e })?; + builder.append_option(val); + } + DataType::Interval(IntervalUnit::YearMonth) => { + let builder = builder + .as_any_mut() + .downcast_mut::() + .ok_or(Error::FailedToDowncastBuilder { index: i })?; + let val: Option = row + .get::<_, Option>(i) + .map_err(|e| Error::OracleError { source: e })?; + if let Some(interval) = val { + // Convert to total months: years * 12 + months + let total_months = interval.years() * 12 + interval.months(); + builder.append_value(total_months); + } else { + builder.append_null(); + } + } + DataType::Interval(IntervalUnit::MonthDayNano) => { + let builder = builder + .as_any_mut() + .downcast_mut::() + .ok_or(Error::FailedToDowncastBuilder { index: i })?; + let val: Option = row + .get::<_, Option>(i) + .map_err(|e| Error::OracleError { source: e })?; + if let Some(interval) = val { + // Convert to nanoseconds: hours, minutes, seconds, nanoseconds + let nanos = (interval.hours() as i64 * 3600 + + interval.minutes() as i64 * 60 + + interval.seconds() as i64) + * 1_000_000_000 + + interval.nanoseconds() as i64; + let interval_value = IntervalMonthDayNano::new(0, interval.days(), nanos); + builder.append_value(interval_value); + } else { + builder.append_null(); + } + } + _ => { + // Fallback: try to get as string + let builder = builder + .as_any_mut() + .downcast_mut::() + .ok_or(Error::FailedToDowncastBuilder { index: i })?; + let val: Option = row + .get::<_, Option>(i) + .map_err(|e| Error::OracleError { source: e })?; + builder.append_option(val); + } + } + } + } + + let arrays: Vec = builders.into_iter().map(|mut b| b.finish()).collect(); + let schema = Arc::new(Schema::new(arrow_fields)); + + RecordBatch::try_new(schema, arrays).context(FailedToBuildRecordBatchSnafu) +} + +fn to_decimal_128(decimal: &BigDecimal, scale: i64) -> Option { + let scale_u32: u32 = scale.try_into().ok()?; + (decimal * 10i128.pow(scale_u32)).to_i128() +} + +fn to_decimal_256(decimal: &BigDecimal) -> i256 { + let (bigint_value, _) = decimal.as_bigint_and_exponent(); + let mut bigint_bytes = bigint_value.to_signed_bytes_le(); + + let is_negative = bigint_value.sign() == num_bigint::Sign::Minus; + let fill_byte = if is_negative { 0xFF } else { 0x00 }; + + if bigint_bytes.len() > 32 { + bigint_bytes.truncate(32); + } else { + bigint_bytes.resize(32, fill_byte); + }; + + let mut array = [0u8; 32]; + array.copy_from_slice(&bigint_bytes); + + i256::from_le_bytes(array) +} diff --git a/core/src/sql/db_connection_pool/dbconnection.rs b/core/src/sql/db_connection_pool/dbconnection.rs index 94590026..8eb6164a 100644 --- a/core/src/sql/db_connection_pool/dbconnection.rs +++ b/core/src/sql/db_connection_pool/dbconnection.rs @@ -13,6 +13,8 @@ pub mod duckdbconn; pub mod mysqlconn; #[cfg(feature = "odbc")] pub mod odbcconn; +#[cfg(feature = "oracle")] +pub mod oracleconn; #[cfg(feature = "postgres")] pub mod postgresconn; #[cfg(feature = "sqlite")] diff --git a/core/src/sql/db_connection_pool/dbconnection/oracleconn.rs b/core/src/sql/db_connection_pool/dbconnection/oracleconn.rs new file mode 100644 index 00000000..44ec9e17 --- /dev/null +++ b/core/src/sql/db_connection_pool/dbconnection/oracleconn.rs @@ -0,0 +1,375 @@ +use async_trait::async_trait; +use bb8_oracle::OracleConnectionManager; +use datafusion::{ + arrow::datatypes::SchemaRef, execution::SendableRecordBatchStream, + physical_plan::stream::RecordBatchStreamAdapter, sql::TableReference, +}; +use std::{any::Any, sync::Arc}; + +use async_stream::stream; +use snafu::ResultExt; +use tokio::sync::mpsc; +use tokio::task; + +use crate::sql::{ + arrow_sql_gen::oracle::rows_to_arrow, + db_connection_pool::dbconnection::{ + AsyncDbConnection, DbConnection, Error, GenericError, Result, + }, +}; + +pub type OraclePooledConnection = bb8::PooledConnection<'static, OracleConnectionManager>; + +pub struct OracleConnection { + pub conn: OraclePooledConnection, +} + +impl OracleConnection { + pub fn new(conn: OraclePooledConnection) -> Self { + Self { conn } + } +} + +impl DbConnection for OracleConnection { + fn as_any(&self) -> &dyn Any { + self + } + + fn as_any_mut(&mut self) -> &mut dyn Any { + self + } + + fn as_async( + &self, + ) -> Option<&dyn AsyncDbConnection> { + Some(self) + } +} + +#[async_trait] +impl AsyncDbConnection for OracleConnection { + fn new(conn: OraclePooledConnection) -> Self { + Self { conn } + } + + async fn get_schema( + &self, + table_reference: &TableReference, + ) -> std::result::Result { + let table_name = table_reference.table().to_uppercase(); + let schema_name = table_reference.schema().map(|s| s.to_uppercase()); + + let conn = self.conn.clone(); + + let rows = task::spawn_blocking(move || { + if let Some(schema) = schema_name { + let rows = conn.query( + "SELECT column_name, data_type, data_precision, data_scale, nullable + FROM all_tab_columns + WHERE owner = :1 AND table_name = :2 + ORDER BY column_id", + &[&schema, &table_name], + )?; + rows.collect::, _>>() + } else { + let rows = conn.query( + "SELECT column_name, data_type, data_precision, data_scale, nullable + FROM all_tab_columns + WHERE table_name = :1 + ORDER BY column_id", + &[&table_name], + )?; + rows.collect::, _>>() + } + }) + .await + .map_err(|e| Box::new(e) as GenericError) + .context(super::UnableToGetSchemaSnafu)? + .map_err(|e| Box::new(e) as GenericError) + .context(super::UnableToGetSchemaSnafu)?; + + let mut fields: Vec = Vec::new(); + + for row in rows { + let column_name: String = row + .get(0) + .map_err(|e| Box::new(e) as GenericError) + .context(super::UnableToGetSchemaSnafu)?; + let data_type_str: String = row + .get(1) + .map_err(|e| Box::new(e) as GenericError) + .context(super::UnableToGetSchemaSnafu)?; + let precision: Option = row + .get(2) + .map_err(|e| Box::new(e) as GenericError) + .context(super::UnableToGetSchemaSnafu)?; + let scale: Option = row + .get(3) + .map_err(|e| Box::new(e) as GenericError) + .context(super::UnableToGetSchemaSnafu)?; + let nullable_str: String = row + .get(4) + .map_err(|e| Box::new(e) as GenericError) + .context(super::UnableToGetSchemaSnafu)?; + let nullable = nullable_str != "N"; + + let arrow_type = map_oracle_type_to_arrow(&data_type_str, precision, scale); + + fields.push(datafusion::arrow::datatypes::Field::new( + column_name, // Keep original case from Oracle + arrow_type, + nullable, + )); + } + + Ok(Arc::new(datafusion::arrow::datatypes::Schema::new(fields))) + } + + async fn query_arrow( + &self, + sql: &str, + _params: &[oracle::sql_type::OracleType], + projected_schema: Option, + ) -> Result { + let sql = sql.to_string(); + let conn = self.conn.clone(); + let schema_clone = projected_schema.clone(); + + let (tx, mut rx) = mpsc::channel(2); + + task::spawn_blocking(move || { + let process = || -> std::result::Result<(), GenericError> { + let mut stmt = conn + .statement(&sql) + .fetch_array_size(100_000) + .build() + .map_err(|e| Box::new(e) as GenericError) + .context(super::UnableToQueryArrowSnafu)?; + + let rows = stmt + .query(&[]) + .map_err(|e| Box::new(e) as GenericError) + .context(super::UnableToQueryArrowSnafu)?; + + let mut chunk = Vec::with_capacity(4096); + for row_result in rows { + let row = row_result + .map_err(|e| Box::new(e) as GenericError) + .context(super::UnableToQueryArrowSnafu)?; + + chunk.push(row); + if chunk.len() >= 4096 { + let batch_res = rows_to_arrow(chunk, &schema_clone) + .map_err(|e| Box::new(e) as GenericError) + .context(super::UnableToQueryArrowSnafu) + .map_err(|e| Box::new(e) as GenericError); + + if tx.blocking_send(batch_res).is_err() { + return Ok(()); + } + chunk = Vec::with_capacity(4096); + } + } + if !chunk.is_empty() { + let batch_res = rows_to_arrow(chunk, &schema_clone) + .map_err(|e| Box::new(e) as GenericError) + .context(super::UnableToQueryArrowSnafu) + .map_err(|e| Box::new(e) as GenericError); + let _ = tx.blocking_send(batch_res); + } + Ok(()) + }; + + if let Err(e) = process() { + let _ = tx.blocking_send(Err(e)); + } + }); + + // Peek first batch to determine schema if needed + let first_result = rx.recv().await; + + let Some(first_batch_res) = first_result else { + // Stream empty + let empty_schema = projected_schema + .unwrap_or_else(|| Arc::new(datafusion::arrow::datatypes::Schema::empty())); + return Ok(Box::pin(RecordBatchStreamAdapter::new( + empty_schema, + futures::stream::empty(), + ))); + }; + + let first_batch = first_batch_res?; + let schema = first_batch.schema(); + + let output_stream = stream! { + yield Ok(first_batch); + while let Some(result) = rx.recv().await { + match result { + Ok(batch) => yield Ok(batch), + Err(e) => yield Err(datafusion::error::DataFusionError::External(e)), + } + } + }; + + Ok(Box::pin(RecordBatchStreamAdapter::new( + projected_schema.unwrap_or(schema), + output_stream, + ))) + } + + async fn execute(&self, sql: &str, _params: &[oracle::sql_type::OracleType]) -> Result { + let sql = sql.to_string(); + let conn = self.conn.clone(); + + let row_count = task::spawn_blocking(move || { + let stmt = conn.execute(&sql, &[])?; + stmt.row_count() + }) + .await + .map_err(|e| Box::new(e) as GenericError) + .context(super::UnableToQueryArrowSnafu)? + .map_err(|e| Box::new(e) as GenericError) + .context(super::UnableToQueryArrowSnafu)?; + + Ok(row_count) + } + + async fn tables(&self, schema: &str) -> std::result::Result, Error> { + let schema = schema.to_uppercase(); + let conn = self.conn.clone(); + + let table_names = task::spawn_blocking(move || { + let rows = conn.query( + "SELECT table_name FROM all_tables WHERE owner = :1", + &[&schema], + )?; + let mut result = Vec::new(); + for row in rows { + let row = row?; + let val: String = row.get(0)?; + result.push(val); + } + Ok::, oracle::Error>(result) + }) + .await + .map_err(|e| Box::new(e) as GenericError) + .context(super::UnableToGetTablesSnafu)? + .map_err(|e| Box::new(e) as GenericError) + .context(super::UnableToGetTablesSnafu)?; + + Ok(table_names) + } + + async fn schemas(&self) -> std::result::Result, Error> { + let conn = self.conn.clone(); + + let schemas = task::spawn_blocking(move || { + let rows = conn.query("SELECT username FROM all_users", &[])?; + let mut result = Vec::new(); + for row in rows { + let row = row?; + let val: String = row.get(0)?; + result.push(val); + } + Ok::, oracle::Error>(result) + }) + .await + .map_err(|e| Box::new(e) as GenericError) + .context(super::UnableToGetSchemasSnafu)? + .map_err(|e| Box::new(e) as GenericError) + .context(super::UnableToGetSchemasSnafu)?; + + Ok(schemas) + } +} + +/// Map Oracle data types to Arrow data types +fn map_oracle_type_to_arrow( + oracle_type: &str, + precision: Option, + scale: Option, +) -> datafusion::arrow::datatypes::DataType { + use datafusion::arrow::datatypes::DataType; + + let type_upper = oracle_type.to_uppercase(); + + // Handle types with parameters like VARCHAR2(100) + let base_type = if let Some(paren_pos) = type_upper.find('(') { + &type_upper[..paren_pos] + } else { + &type_upper + }; + + match base_type.trim() { + // String types + "VARCHAR2" | "NVARCHAR2" | "CHAR" | "NCHAR" => DataType::Utf8, + "CLOB" | "NCLOB" | "LONG" => DataType::LargeUtf8, + + // Numeric types + "NUMBER" | "NUMERIC" | "DECIMAL" | "DEC" => { + let p = precision.unwrap_or(38) as u8; + let s = scale.unwrap_or(0) as i8; + // Int64 for integer types (scale = 0, precision ≤ 18) + if s == 0 && p <= 18 { + return DataType::Int64; + } + if p > 38 { + DataType::Decimal256(p, s) + } else { + DataType::Decimal128(p, s) + } + } + "INTEGER" | "INT" | "SMALLINT" => DataType::Int64, + "FLOAT" => { + // FLOAT precision in Oracle is binary: ≤24 → Float32, >24 → Float64 + match precision { + Some(p) if p <= 24 => DataType::Float32, + _ => DataType::Float64, + } + } + "REAL" | "DOUBLE PRECISION" => DataType::Float64, + "BINARY_FLOAT" => DataType::Float32, + "BINARY_DOUBLE" => DataType::Float64, + + "BOOLEAN" => DataType::Boolean, + + // Date/Time types - Oracle DATE is conventionally a date, not a timestamp + "DATE" => DataType::Date32, + _ if type_upper.contains("TIMESTAMP") => { + use datafusion::arrow::datatypes::TimeUnit; + // Precision-aware timestamp: scale contains fractional seconds precision + let fractional_precision = scale.unwrap_or(6); + let time_unit = match fractional_precision { + 0 => TimeUnit::Second, + 1..=3 => TimeUnit::Millisecond, + 4..=6 => TimeUnit::Microsecond, + _ => TimeUnit::Nanosecond, + }; + let tz = if type_upper.contains("WITH TIME ZONE") + || type_upper.contains("WITH LOCAL TIME ZONE") + { + Some("UTC".into()) + } else { + None + }; + DataType::Timestamp(time_unit, tz) + } + + // Interval types + _ if type_upper.starts_with("INTERVAL YEAR") => { + use datafusion::arrow::datatypes::IntervalUnit; + DataType::Interval(IntervalUnit::YearMonth) + } + _ if type_upper.starts_with("INTERVAL DAY") => { + use datafusion::arrow::datatypes::IntervalUnit; + DataType::Interval(IntervalUnit::MonthDayNano) + } + + // Binary types + "RAW" => DataType::Binary, + "BLOB" | "LONG RAW" => DataType::LargeBinary, + + // Other types - default to string + _ => DataType::Utf8, + } +} diff --git a/core/src/sql/db_connection_pool/mod.rs b/core/src/sql/db_connection_pool/mod.rs index 535110e9..b38bf8fe 100644 --- a/core/src/sql/db_connection_pool/mod.rs +++ b/core/src/sql/db_connection_pool/mod.rs @@ -12,6 +12,8 @@ pub mod duckdbpool; pub mod mysqlpool; #[cfg(feature = "odbc")] pub mod odbcpool; +#[cfg(feature = "oracle")] +pub mod oraclepool; #[cfg(feature = "postgres")] pub mod postgrespool; pub mod runtime; diff --git a/core/src/sql/db_connection_pool/oraclepool.rs b/core/src/sql/db_connection_pool/oraclepool.rs new file mode 100644 index 00000000..9bf1498d --- /dev/null +++ b/core/src/sql/db_connection_pool/oraclepool.rs @@ -0,0 +1,191 @@ +use std::collections::HashMap; +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; + +use async_trait::async_trait; +use bb8::CustomizeConnection; +use bb8_oracle::OracleConnectionManager; +use oracle::{Connection, Connector}; + +use secrecy::{ExposeSecret, SecretString}; +use snafu::prelude::*; + +use super::DbConnectionPool; + +/// Default TCP port for Oracle Database +const DEFAULT_ORACLE_PORT: u16 = 1521; + +/// Default service name for Oracle Database connections +static DEFAULT_SERVICE_NAME: &str = "ORCL"; + +/// Default maximum pool size +const DEFAULT_POOL_MAX_SIZE: u32 = 10; + +/// Default timezone for Oracle sessions (UTC for consistent timestamp handling) +static DEFAULT_TIMEZONE: &str = "UTC"; + +#[derive(Debug, Snafu)] +pub enum Error { + #[snafu(display("Oracle connection failed: {source}"))] + ConnectionError { source: oracle::Error }, + + #[snafu(display("Unable to create Oracle connection pool: {source}"))] + PoolCreationError { source: bb8_oracle::Error }, + + #[snafu(display("Unable to get Oracle connection from pool: {source}"))] + PoolRunError { + source: bb8::RunError, + }, + + #[snafu(display("Missing required parameter: {param}"))] + MissingParameter { param: String }, +} + +pub type Result = std::result::Result; + +/// Customizer that sets session timezone on connection acquire. +/// This ensures consistent timestamp handling across all connections. +#[derive(Debug, Clone)] +pub struct SetTimezoneCustomizer { + pub timezone: String, +} + +impl CustomizeConnection, bb8_oracle::Error> for SetTimezoneCustomizer { + fn on_acquire<'a>( + &'a self, + conn: &'a mut Arc, + ) -> Pin> + Send + 'a>> { + let sql = format!("ALTER SESSION SET TIME_ZONE = '{}'", self.timezone); + Box::pin(async move { + // Execute the timezone setting synchronously + // rust-oracle is synchronous, so this is safe in async context + conn.execute(&sql, &[]) + .map_err(bb8_oracle::Error::Database)?; + Ok(()) + }) + } +} + +pub struct OracleConnectionPool { + pool: Arc>, +} + +impl std::fmt::Debug for OracleConnectionPool { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("OracleConnectionPool").finish() + } +} + +impl OracleConnectionPool { + pub async fn new(params: HashMap) -> Result { + let user = params + .get("user") + .ok_or(Error::MissingParameter { + param: "user".to_string(), + })? + .expose_secret(); + + let password = params + .get("password") + .ok_or(Error::MissingParameter { + param: "password".to_string(), + })? + .expose_secret(); + + let host = params + .get("host") + .ok_or(Error::MissingParameter { + param: "host".to_string(), + })? + .expose_secret(); + + let port = params + .get("port") + .map(|s| { + s.expose_secret() + .parse::() + .unwrap_or(DEFAULT_ORACLE_PORT) + }) + .unwrap_or(DEFAULT_ORACLE_PORT); + + let service_name = params + .get("service_name") + .or_else(|| params.get("sid")) + .map(|s| s.expose_secret().to_string()) + .unwrap_or_else(|| DEFAULT_SERVICE_NAME.to_string()); + + let connector = Connector::new( + user, + password, + format!("//{}:{}/{}", host, port, service_name), + ); + + let manager = OracleConnectionManager::from_connector(connector); + + // Get timezone setting (default to UTC for consistent timestamp handling) + let timezone = params + .get("timezone") + .map(|s| s.expose_secret().to_string()) + .unwrap_or_else(|| DEFAULT_TIMEZONE.to_string()); + + let pool = bb8::Pool::builder() + .max_size( + params + .get("pool_max") + .and_then(|s| s.expose_secret().parse().ok()) + .unwrap_or(DEFAULT_POOL_MAX_SIZE), + ) + .connection_customizer(Box::new(SetTimezoneCustomizer { timezone })) + .build(manager) + .await + .map_err(|e| Error::PoolCreationError { source: e })?; + + Ok(Self { + pool: Arc::new(pool), + }) + } + + pub async fn connect_direct( + &self, + ) -> Result { + let conn = Arc::clone(&self.pool) + .get_owned() + .await + .map_err(|e| Error::PoolRunError { source: e })?; + + Ok(super::dbconnection::oracleconn::OracleConnection::new(conn)) + } +} + +#[async_trait] +impl + DbConnectionPool< + bb8::PooledConnection<'static, OracleConnectionManager>, + oracle::sql_type::OracleType, + > for OracleConnectionPool +{ + async fn connect( + &self, + ) -> std::result::Result< + Box< + dyn super::dbconnection::DbConnection< + bb8::PooledConnection<'static, OracleConnectionManager>, + oracle::sql_type::OracleType, + >, + >, + Box, + > { + let conn = Arc::clone(&self.pool).get_owned().await.map_err(|e| { + Box::new(Error::PoolRunError { source: e }) as Box + })?; + + Ok(Box::new( + super::dbconnection::oracleconn::OracleConnection::new(conn), + )) + } + + fn join_push_down(&self) -> super::JoinPushDown { + super::JoinPushDown::Disallow + } +} diff --git a/core/tests/integration.rs b/core/tests/integration.rs index 02dd2041..a0ce8050 100644 --- a/core/tests/integration.rs +++ b/core/tests/integration.rs @@ -10,6 +10,8 @@ mod duckdb; mod flight; #[cfg(feature = "mysql")] mod mysql; +#[cfg(feature = "oracle")] +mod oracle; #[cfg(feature = "postgres")] mod postgres; #[cfg(feature = "sqlite")] diff --git a/core/tests/oracle/common.rs b/core/tests/oracle/common.rs new file mode 100644 index 00000000..d09e7e96 --- /dev/null +++ b/core/tests/oracle/common.rs @@ -0,0 +1,43 @@ +use datafusion_table_providers::sql::db_connection_pool::oraclepool::OracleConnectionPool; +use secrecy::SecretString; +use std::collections::HashMap; +use std::env; + +const ORACLE_PASSWORD: &str = "password"; +const ORACLE_USER: &str = "system"; +const ORACLE_SERVICE: &str = "FREEPDB1"; +const DEFAULT_ORACLE_PORT: u16 = 1521; + +pub fn get_oracle_params() -> HashMap { + let mut params = HashMap::new(); + + // Default to strict env vars or defaults + let host = env::var("ORACLE_HOST").unwrap_or_else(|_| "localhost".to_string()); + let port = env::var("ORACLE_PORT").unwrap_or_else(|_| DEFAULT_ORACLE_PORT.to_string()); + let user = env::var("ORACLE_USER").unwrap_or_else(|_| ORACLE_USER.to_string()); + let pass = env::var("ORACLE_PASSWORD").unwrap_or_else(|_| ORACLE_PASSWORD.to_string()); + let service = env::var("ORACLE_SERVICE").unwrap_or_else(|_| ORACLE_SERVICE.to_string()); + + params.insert("host".to_string(), SecretString::from(host)); + params.insert("port".to_string(), SecretString::from(port)); + params.insert("user".to_string(), SecretString::from(user)); + params.insert("password".to_string(), SecretString::from(pass)); + params.insert("service_name".to_string(), SecretString::from(service)); + + // Optional wallet params + if let Ok(wallet) = env::var("ORACLE_WALLET_PATH") { + params.insert("wallet_path".to_string(), SecretString::from(wallet)); + } + if let Ok(wpass) = env::var("ORACLE_WALLET_PASSWORD") { + params.insert("wallet_password".to_string(), SecretString::from(wpass)); + } + + params +} + +pub async fn get_oracle_connection_pool() -> OracleConnectionPool { + let params = get_oracle_params(); + OracleConnectionPool::new(params) + .await + .expect("Failed to create Oracle connection pool") +} diff --git a/core/tests/oracle/mod.rs b/core/tests/oracle/mod.rs new file mode 100644 index 00000000..bb451432 --- /dev/null +++ b/core/tests/oracle/mod.rs @@ -0,0 +1,861 @@ +use datafusion::arrow::array::*; +use datafusion::arrow::datatypes::{ + i256, DataType, Field, IntervalMonthDayNano, IntervalUnit, Schema, TimeUnit, +}; +use datafusion::arrow::record_batch::RecordBatch; +use datafusion::execution::context::SessionContext; +use datafusion::sql::TableReference; +use datafusion_table_providers::oracle::OracleTableFactory; +use datafusion_table_providers::sql::db_connection_pool::dbconnection::oracleconn::OraclePooledConnection; +use datafusion_table_providers::sql::db_connection_pool::dbconnection::DbConnection; +use datafusion_table_providers::sql::db_connection_pool::DbConnectionPool; +use datafusion_table_providers::sql::sql_provider_datafusion::SqlTable; +use futures::StreamExt; +use std::sync::Arc; + +mod common; + +#[tokio::test] +async fn test_oracle_connection_pool() { + let pool = common::get_oracle_connection_pool().await; + let conn = pool + .connect_direct() + .await + .expect("Failed to get connection"); + + let rows = conn + .conn + .query("SELECT 1 FROM DUAL", &[]) + .expect("Failed to execute query"); + let rows: Vec = rows + .collect::, _>>() + .expect("Failed to collect rows"); + assert!(!rows.is_empty()); + + let first_row = &rows[0]; + let val_str: String = first_row.get(0).expect("Value should exist"); + assert_eq!(val_str, "1"); +} + +/// Test registering Oracle's DUAL table as a DataFusion table provider +#[tokio::test] +async fn test_oracle_table_provider_registration() { + let pool = common::get_oracle_connection_pool().await; + let factory = OracleTableFactory::new(Arc::new(pool)); + + let provider = factory + .table_provider(TableReference::from("DUAL")) + .await + .expect("Failed to create table provider"); + + let ctx = SessionContext::new(); + ctx.register_table("dual_test", provider) + .expect("Failed to register table"); + + let df = ctx + .sql("SELECT * FROM dual_test") + .await + .expect("Failed to create dataframe"); + let _result = df.collect().await.expect("Failed to execute query"); +} + +/// Test querying data through DataFusion using ALL_TABLES system view. +/// Validates the full Provider -> DataFusion query path. +#[tokio::test] +async fn test_oracle_query_with_data() { + let pool = common::get_oracle_connection_pool().await; + let factory = OracleTableFactory::new(Arc::new(pool)); + + let table_name = "ALL_TABLES"; + let provider = factory + .table_provider(TableReference::from(table_name)) + .await + .expect("Failed to create table provider for ALL_TABLES"); + + // Verify schema contains expected columns + let schema = provider.schema(); + let fields: Vec = schema.fields().iter().map(|f| f.name().clone()).collect(); + assert!( + fields.contains(&"TABLE_NAME".to_string()) || fields.contains(&"table_name".to_string()), + "Schema missing 'table_name' column: {:?}", + fields + ); + + let ctx = SessionContext::new(); + ctx.register_table("system_tables", provider) + .expect("Failed to register table"); + + let sql = "SELECT \"TABLE_NAME\", \"OWNER\" FROM system_tables"; + let df = ctx.sql(sql).await.expect("Failed to build plan"); + + let batches = df.collect().await.expect("Query execution failed"); + let row_count: usize = batches.iter().map(|b| b.num_rows()).sum(); + assert!(row_count > 0, "Expected to read rows from ALL_TABLES"); +} + +/// Test schema inference for ALL_TABLES system view +#[tokio::test] +async fn test_oracle_explain_plan() { + let pool = common::get_oracle_connection_pool().await; + let factory = OracleTableFactory::new(Arc::new(pool)); + + let provider = factory + .table_provider(TableReference::from("ALL_TABLES")) + .await + .expect("Failed to create table provider for ALL_TABLES"); + + let schema = provider.schema(); + println!("\n=== ALL_TABLES Schema ==="); + for field in schema.fields() { + println!( + " {} : {:?} (nullable: {})", + field.name(), + field.data_type(), + field.is_nullable() + ); + } + + assert!(!schema.fields().is_empty(), "Expected schema with columns"); + + let field_names: Vec<&str> = schema.fields().iter().map(|f| f.name().as_str()).collect(); + assert!( + field_names.contains(&"TABLE_NAME"), + "Expected TABLE_NAME column" + ); + assert!(field_names.contains(&"OWNER"), "Expected OWNER column"); +} + +/// Test schema inference for ALL_TAB_COLUMNS system view +#[tokio::test] +async fn test_oracle_explain_verbose() { + let pool = common::get_oracle_connection_pool().await; + let factory = OracleTableFactory::new(Arc::new(pool)); + + let provider = factory + .table_provider(TableReference::from("ALL_TAB_COLUMNS")) + .await + .expect("Failed to create table provider for ALL_TAB_COLUMNS"); + + let schema = provider.schema(); + println!("\n=== ALL_TAB_COLUMNS Schema ==="); + for field in schema.fields() { + println!(" {} : {:?}", field.name(), field.data_type()); + } + + assert!(!schema.fields().is_empty(), "Expected schema with columns"); + + let field_names: Vec<&str> = schema.fields().iter().map(|f| f.name().as_str()).collect(); + assert!( + field_names.contains(&"COLUMN_NAME"), + "Expected COLUMN_NAME column" + ); + assert!( + field_names.contains(&"DATA_TYPE"), + "Expected DATA_TYPE column" + ); +} + +/// Row struct for insertion test +#[derive(Debug)] +struct Row { + id: i64, + name: String, + age: i32, + score: f64, +} + +fn create_sample_rows() -> Vec { + vec![ + Row { + id: 1, + name: "Alice".to_string(), + age: 30, + score: 91.5, + }, + Row { + id: 2, + name: "Bob".to_string(), + age: 45, + score: 85.2, + }, + ] +} + +/// Creates or recreates a test table with the given name +async fn create_test_table( + conn: &oracle::Connection, + table_name: &str, +) -> std::result::Result<(), oracle::Error> { + let check_sql = format!( + "SELECT count(*) FROM user_tables WHERE table_name = '{}'", + table_name + ); + let rows = conn.query(&check_sql, &[])?; + let rows: Vec = rows.collect::, _>>()?; + let count: i64 = if !rows.is_empty() { rows[0].get(0)? } else { 0 }; + + if count > 0 { + let _ = conn.execute(&format!("DROP TABLE {}", table_name), &[]); + } + + let sql = format!( + "CREATE TABLE {} ( + id NUMBER, + name VARCHAR2(100), + age NUMBER, + score BINARY_DOUBLE + )", + table_name + ); + + conn.execute(&sql, &[])?; + Ok(()) +} + +async fn insert_test_rows( + conn: &oracle::Connection, + table_name: &str, + rows: Vec, +) -> std::result::Result<(), oracle::Error> { + let sql = format!( + "INSERT INTO {} (id, name, age, score) VALUES (:1, :2, :3, :4)", + table_name + ); + + for row in rows { + conn.execute(&sql, &[&row.id, &row.name, &row.age, &row.score])?; + } + + conn.commit()?; + Ok(()) +} + +/// Full integration test: Create table -> Insert data -> Read via DataFusion +#[tokio::test] +async fn test_oracle_insert_and_read() { + let table_name = "TEST_EMPLOYEES"; + + let pool = common::get_oracle_connection_pool().await; + let conn = pool + .connect_direct() + .await + .expect("Failed to get connection"); + + create_test_table(&conn.conn, table_name) + .await + .expect("Create table failed"); + insert_test_rows(&conn.conn, table_name, create_sample_rows()) + .await + .expect("Insert failed"); + + drop(conn); + drop(pool); + + let pool_query = common::get_oracle_connection_pool().await; + let factory = OracleTableFactory::new(Arc::new(pool_query)); + + let ctx = SessionContext::new(); + let provider = factory + .table_provider(TableReference::from(table_name)) + .await + .expect("Provider creation failed"); + + // Verify schema has expected columns (uppercase in Oracle) + let schema = provider.schema(); + let fields: Vec = schema.fields().iter().map(|f| f.name().clone()).collect(); + assert!(fields.contains(&"ID".to_string())); + assert!(fields.contains(&"NAME".to_string())); + assert!(fields.contains(&"SCORE".to_string())); + + ctx.register_table("employees", provider) + .expect("Table register failed"); + + // Note: Column names must be quoted for uppercase identifiers in DataFusion SQL + let sql = "SELECT * FROM employees ORDER BY \"ID\""; + let df = ctx.sql(sql).await.expect("Query failed"); + let batches = df.collect().await.expect("Collect failed"); + + let row_count: usize = batches.iter().map(|b| b.num_rows()).sum(); + assert_eq!(row_count, 2); +} + +#[tokio::test] +async fn test_oracle_number_types() { + let create_table_stmt = " + CREATE TABLE number_test_table ( + n1 NUMBER, + n2 NUMBER(10), + n3 NUMBER(10, 2), + n4 NUMBER(38, 10), + n5 NUMBER(38) + ) + "; + let insert_table_stmt = " + INSERT INTO number_test_table (n1, n2, n3, n4, n5) + VALUES ( + 123.456, + 1234567890, + 12345678.90, + 123456789012345678.1234567890, + 12345678901234567890123456789012345678 + ) + "; + + let schema = Arc::new(Schema::new(vec![ + Field::new("N1", DataType::Decimal128(38, 0), true), + Field::new("N2", DataType::Int64, true), // NUMBER(10,0) where p≤18 → Int64 + Field::new("N3", DataType::Decimal128(10, 2), true), + Field::new("N4", DataType::Decimal128(38, 10), true), + Field::new("N5", DataType::Decimal128(38, 0), true), + ])); + + let expected_record = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new( + Decimal128Array::from(vec![Some(123)]) + .with_precision_and_scale(38, 0) + .unwrap(), + ), + Arc::new(Int64Array::from(vec![1234567890])), // NUMBER(10,0) → Int64 + Arc::new( + Decimal128Array::from(vec![Some(1234567890)]) + .with_precision_and_scale(10, 2) + .unwrap(), + ), + Arc::new( + Decimal128Array::from(vec![Some(1234567890123456781234567890)]) + .with_precision_and_scale(38, 10) + .unwrap(), + ), + Arc::new( + Decimal128Array::from(vec![Some(12345678901234567890123456789012345678)]) + .with_precision_and_scale(38, 0) + .unwrap(), + ), + ], + ) + .expect("Failed to create expected record batch"); + + arrow_oracle_one_way( + "NUMBER_TEST_TABLE", + create_table_stmt, + insert_table_stmt, + expected_record, + ) + .await; +} + +async fn arrow_oracle_one_way( + table_name: &str, + create_table_stmt: &str, + insert_table_stmt: &str, + expected_record: RecordBatch, +) -> Vec { + let pool = common::get_oracle_connection_pool().await; + let conn = pool + .connect_direct() + .await + .expect("Failed to get connection"); + + // Cleanup and create table + let _ = conn + .conn + .execute(&format!("DROP TABLE {}", table_name), &[]); + conn.conn + .execute(create_table_stmt, &[]) + .expect("Failed to create table"); + conn.conn + .execute(insert_table_stmt, &[]) + .expect("Failed to insert data"); + conn.conn.commit().expect("Failed to commit"); + + let sqltable_pool: Arc< + dyn DbConnectionPool + + Send + + Sync + + 'static, + > = Arc::new(pool); + let table = SqlTable::new("oracle", &sqltable_pool, table_name) + .await + .expect("Table should be created"); + + let ctx = SessionContext::new(); + ctx.register_table(table_name, Arc::new(table)) + .expect("Table should be registered"); + + let sql = format!("SELECT * FROM {}", table_name); + let df = ctx.sql(&sql).await.expect("Query failed"); + + let record_batches = df.collect().await.expect("Collect failed"); + + assert_eq!(record_batches.len(), 1); + assert_eq!(record_batches[0].schema(), expected_record.schema()); + assert_eq!(record_batches[0], expected_record); + + record_batches +} + +#[tokio::test] +async fn test_oracle_date_time_types() { + let create_table_stmt = " + CREATE TABLE date_time_test_table ( + d1 DATE, + t1 TIMESTAMP, + t2 TIMESTAMP(6), + t3 TIMESTAMP WITH TIME ZONE + ) + "; + let insert_table_stmt = " + INSERT INTO date_time_test_table (d1, t1, t2, t3) + VALUES ( + TO_DATE('2024-09-12', 'YYYY-MM-DD'), + TO_TIMESTAMP('2024-09-12 10:00:00.123', 'YYYY-MM-DD HH24:MI:SS.FF3'), + TO_TIMESTAMP('2024-09-12 10:00:00.123456', 'YYYY-MM-DD HH24:MI:SS.FF6'), + TO_TIMESTAMP_TZ('2024-09-12 10:00:00.123 +00:00', 'YYYY-MM-DD HH24:MI:SS.FF3 TZH:TZM') + ) + "; + + let schema = Arc::new(Schema::new(vec![ + Field::new("D1", DataType::Date32, true), + Field::new("T1", DataType::Timestamp(TimeUnit::Microsecond, None), true), + Field::new("T2", DataType::Timestamp(TimeUnit::Microsecond, None), true), + Field::new( + "T3", + DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into())), + true, + ), + ])); + + let expected_record = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Date32Array::from(vec![19978])), // 2024-09-12 as days since epoch + Arc::new(TimestampMicrosecondArray::from(vec![1_726_135_200_123_000])), + Arc::new(TimestampMicrosecondArray::from(vec![1_726_135_200_123_456])), + Arc::new( + TimestampMicrosecondArray::from(vec![1_726_135_200_123_000]).with_timezone("UTC"), + ), + ], + ) + .expect("Failed to create expected record batch"); + + arrow_oracle_one_way( + "DATE_TIME_TEST_TABLE", + create_table_stmt, + insert_table_stmt, + expected_record, + ) + .await; +} + +#[tokio::test] +async fn test_oracle_integer_types() { + let create_table_stmt = " + CREATE TABLE integer_test_table ( + i1 NUMBER(5, 0), + i2 NUMBER(10, 0), + i3 NUMBER(18, 0), + i4 NUMBER(38, 0) + ) + "; + let insert_table_stmt = " + INSERT INTO integer_test_table (i1, i2, i3, i4) + VALUES (12345, 1234567890, 123456789012345678, 12345678901234567890123456789012345678) + "; + + // NUMBER(p, 0) where p <= 18 should map to Int64 + let schema = Arc::new(Schema::new(vec![ + Field::new("I1", DataType::Int64, true), + Field::new("I2", DataType::Int64, true), + Field::new("I3", DataType::Int64, true), + Field::new("I4", DataType::Decimal128(38, 0), true), // p > 18, stays Decimal128 + ])); + + let expected_record = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int64Array::from(vec![12345])), + Arc::new(Int64Array::from(vec![1234567890])), + Arc::new(Int64Array::from(vec![123456789012345678])), + Arc::new( + Decimal128Array::from(vec![Some(12345678901234567890123456789012345678)]) + .with_precision_and_scale(38, 0) + .unwrap(), + ), + ], + ) + .expect("Failed to create expected record batch"); + + arrow_oracle_one_way( + "INTEGER_TEST_TABLE", + create_table_stmt, + insert_table_stmt, + expected_record, + ) + .await; +} + +#[tokio::test] +async fn test_oracle_precision_aware_timestamps() { + let create_table_stmt = " + CREATE TABLE timestamp_precision_table ( + t0 TIMESTAMP(0), + t3 TIMESTAMP(3), + t6 TIMESTAMP(6), + t9 TIMESTAMP(9) + ) + "; + let insert_table_stmt = " + INSERT INTO timestamp_precision_table (t0, t3, t6, t9) + VALUES ( + TO_TIMESTAMP('2024-09-12 10:00:00', 'YYYY-MM-DD HH24:MI:SS'), + TO_TIMESTAMP('2024-09-12 10:00:00.123', 'YYYY-MM-DD HH24:MI:SS.FF3'), + TO_TIMESTAMP('2024-09-12 10:00:00.123456', 'YYYY-MM-DD HH24:MI:SS.FF6'), + TO_TIMESTAMP('2024-09-12 10:00:00.123456789', 'YYYY-MM-DD HH24:MI:SS.FF9') + ) + "; + + let schema = Arc::new(Schema::new(vec![ + Field::new("T0", DataType::Timestamp(TimeUnit::Second, None), true), + Field::new("T3", DataType::Timestamp(TimeUnit::Millisecond, None), true), + Field::new("T6", DataType::Timestamp(TimeUnit::Microsecond, None), true), + Field::new("T9", DataType::Timestamp(TimeUnit::Nanosecond, None), true), + ])); + + let expected_record = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(TimestampSecondArray::from(vec![1_726_135_200])), + Arc::new(TimestampMillisecondArray::from(vec![1_726_135_200_123])), + Arc::new(TimestampMicrosecondArray::from(vec![1_726_135_200_123_456])), + Arc::new(TimestampNanosecondArray::from(vec![ + 1_726_135_200_123_456_789, + ])), + ], + ) + .expect("Failed to create expected record batch"); + + arrow_oracle_one_way( + "TIMESTAMP_PRECISION_TABLE", + create_table_stmt, + insert_table_stmt, + expected_record, + ) + .await; +} + +#[tokio::test] +async fn test_oracle_interval_types() { + let create_table_stmt = " + CREATE TABLE interval_test_table ( + ym INTERVAL YEAR TO MONTH, + ds INTERVAL DAY TO SECOND + ) + "; + let insert_table_stmt = " + INSERT INTO interval_test_table (ym, ds) + VALUES ( + INTERVAL '2-6' YEAR TO MONTH, + INTERVAL '3 12:30:45.123456' DAY TO SECOND + ) + "; + + let schema = Arc::new(Schema::new(vec![ + Field::new("YM", DataType::Interval(IntervalUnit::YearMonth), true), + Field::new("DS", DataType::Interval(IntervalUnit::MonthDayNano), true), + ])); + + // INTERVAL '2-6' YEAR TO MONTH = 2*12 + 6 = 30 months + // INTERVAL '3 12:30:45.123456' DAY TO SECOND = 3 days + (12*3600 + 30*60 + 45)*1e9 + 123456000 nanos + let ds_nanos = (12 * 3600 + 30 * 60 + 45) * 1_000_000_000 + 123_456_000; + + let expected_record = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(IntervalYearMonthArray::from(vec![30])), + Arc::new(IntervalMonthDayNanoArray::from(vec![ + IntervalMonthDayNano::new(0, 3, ds_nanos), + ])), + ], + ) + .expect("Failed to create expected record batch"); + + arrow_oracle_one_way( + "INTERVAL_TEST_TABLE", + create_table_stmt, + insert_table_stmt, + expected_record, + ) + .await; +} + +#[tokio::test] +async fn test_oracle_binary_types() { + let create_table_stmt = " + CREATE TABLE binary_test_table ( + r1 RAW(10), + r2 RAW(100) + ) + "; + let insert_table_stmt = " + INSERT INTO binary_test_table (r1, r2) + VALUES ( + HEXTORAW('DEADBEEF'), + HEXTORAW('ABCDEF0123456789') + ) + "; + + let schema = Arc::new(Schema::new(vec![ + Field::new("R1", DataType::Binary, true), + Field::new("R2", DataType::Binary, true), + ])); + + let expected_record = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(BinaryArray::from_vec(vec![b"\xDE\xAD\xBE\xEF"])), + Arc::new(BinaryArray::from_vec(vec![ + b"\xAB\xCD\xEF\x01\x23\x45\x67\x89", + ])), + ], + ) + .expect("Failed to create expected record batch"); + + arrow_oracle_one_way( + "BINARY_TEST_TABLE", + create_table_stmt, + insert_table_stmt, + expected_record, + ) + .await; +} + +#[tokio::test] +async fn test_oracle_lob_types() { + let create_table_stmt = " + CREATE TABLE lob_test_table ( + b1 BLOB, + c1 CLOB + ) + "; + let insert_table_stmt = " + INSERT INTO lob_test_table (b1, c1) + VALUES ( + HEXTORAW('0102030405'), + 'Large text content for CLOB' + ) + "; + + let schema = Arc::new(Schema::new(vec![ + Field::new("B1", DataType::LargeBinary, true), + Field::new("C1", DataType::LargeUtf8, true), + ])); + + let expected_record = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(LargeBinaryArray::from_vec(vec![b"\x01\x02\x03\x04\x05"])), + Arc::new(LargeStringArray::from(vec!["Large text content for CLOB"])), + ], + ) + .expect("Failed to create expected record batch"); + + arrow_oracle_one_way( + "LOB_TEST_TABLE", + create_table_stmt, + insert_table_stmt, + expected_record, + ) + .await; +} + +#[tokio::test] +async fn test_oracle_null_handling() { + let create_table_stmt = " + CREATE TABLE null_test_table ( + n1 NUMBER, + d1 DATE, + t1 TIMESTAMP, + r1 RAW(10), + b1 BLOB, + c1 CLOB + ) + "; + let insert_table_stmt = " + INSERT INTO null_test_table (n1, d1, t1, r1, b1, c1) + VALUES (NULL, NULL, NULL, NULL, NULL, NULL) + "; + + let schema = Arc::new(Schema::new(vec![ + Field::new("N1", DataType::Decimal128(38, 0), true), + Field::new("D1", DataType::Date32, true), + Field::new("T1", DataType::Timestamp(TimeUnit::Microsecond, None), true), + Field::new("R1", DataType::Binary, true), + Field::new("B1", DataType::LargeBinary, true), + Field::new("C1", DataType::LargeUtf8, true), + ])); + + let expected_record = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new( + Decimal128Array::from(vec![Option::::None]) + .with_precision_and_scale(38, 0) + .unwrap(), + ), + Arc::new(Date32Array::from(vec![Option::::None])), + Arc::new(TimestampMicrosecondArray::from(vec![Option::::None])), + Arc::new(BinaryArray::from_opt_vec(vec![None])), + Arc::new(LargeBinaryArray::from_opt_vec(vec![None])), + Arc::new(LargeStringArray::from(vec![Option::<&str>::None])), + ], + ) + .expect("Failed to create expected record batch"); + + arrow_oracle_one_way( + "NULL_TEST_TABLE", + create_table_stmt, + insert_table_stmt, + expected_record, + ) + .await; +} +#[tokio::test] +async fn test_oracle_sql_generation_limit() { + let pool = common::get_oracle_connection_pool().await; + let factory = OracleTableFactory::new(Arc::new(pool)); + + // reusing "ALL_TABLES" as it is guaranteed to exist + let table_name = "ALL_TABLES"; + let provider = factory + .table_provider(TableReference::from(table_name)) + .await + .expect("Failed to create table provider"); + + let ctx = SessionContext::new(); + ctx.register_table("system_tables", provider) + .expect("Failed to register table"); + + // Test 1: Verify FETCH FIRST syntax for LIMIT + let sql_limit = "SELECT * FROM system_tables LIMIT 5"; + let df_limit = ctx.sql(sql_limit).await.expect("Failed to build plan"); + let plan_limit = df_limit + .create_physical_plan() + .await + .expect("Failed to create physical plan"); + let display_limit = format!( + "{}", + datafusion::physical_plan::displayable(plan_limit.as_ref()).indent(true) + ); + assert!( + display_limit.contains("FETCH FIRST 5 ROWS ONLY"), + "Plan did not contain FETCH FIRST syntax: {}", + display_limit + ); +} + +#[tokio::test] +async fn test_oracle_sql_generation_filter() { + let pool = common::get_oracle_connection_pool().await; + let factory = OracleTableFactory::new(Arc::new(pool)); + + let table_name = "ALL_TABLES"; + let provider = factory + .table_provider(TableReference::from(table_name)) + .await + .expect("Failed to create table provider"); + + let ctx = SessionContext::new(); + ctx.register_table("system_tables", provider) + .expect("Failed to register table"); + + // Test 2: Verify identifier quoting (using a filter) + // "OWNER" column exists in ALL_TABLES + let sql_filter = "SELECT * FROM system_tables WHERE \"OWNER\" = 'SYS'"; + let df_filter = ctx.sql(sql_filter).await.expect("Failed to build plan"); + let plan_filter = df_filter + .create_physical_plan() + .await + .expect("Failed to create physical plan"); + let display_filter = format!( + "{}", + datafusion::physical_plan::displayable(plan_filter.as_ref()).indent(true) + ); + + // Verify that the custom dialect correctly quotes identifiers (e.g. "OWNER"). + // Note: The expression might be wrapped in parentheses like WHERE ("OWNER" = 'SYS') + assert!( + display_filter.contains(r#""OWNER""#), + "Plan did not contain quoted identifier \"OWNER\": {}", + display_filter + ); +} + +#[tokio::test] +async fn test_oracle_large_query_streaming() { + let pool = common::get_oracle_connection_pool().await; + let conn = pool + .connect_direct() + .await + .expect("Failed to get connection"); + let oracle_conn = conn.conn.clone(); + + // Create table with test data + let table_name = "LARGE_QUERY_TEST"; + let _ = oracle_conn.execute(&format!("DROP TABLE {}", table_name), &[]); + oracle_conn + .execute( + &format!("CREATE TABLE {} (id NUMBER, val VARCHAR2(50))", table_name), + &[], + ) + .expect("Failed to create table"); + + // Insert 5000 rows (chunk size is 4096, so this guarantees at least 2 chunks) + let batch_size = 1000; + for i in 0..5 { + let mut ids = Vec::with_capacity(batch_size); + let mut vals = Vec::with_capacity(batch_size); + for j in 0..batch_size { + let id = i * batch_size + j; + ids.push(id as i64); + vals.push(format!("value_{}", id)); + } + + let sql = format!("INSERT INTO {} (id, val) VALUES (:1, :2)", table_name); + let mut stmt = oracle_conn.statement(&sql).build().expect("Prepare failed"); + + // Execute batch insertion manually for simplicity in test + for j in 0..batch_size { + stmt.execute(&[&ids[j], &vals[j]]).expect("Insert failed"); + } + } + oracle_conn.commit().expect("Commit failed"); + + // Query using query_arrow directly to verify chunks + let sql = format!("SELECT * FROM {} ORDER BY id", table_name); + let mut stream = conn + .as_async() + .expect("Should be async") + .query_arrow(&sql, &[], None) + .await + .expect("Query failed"); + + let mut total_rows = 0; + let mut batch_count = 0; + + while let Some(batch_result) = stream.next().await { + let batch = batch_result.expect("Batch should be Ok"); + total_rows += batch.num_rows(); + batch_count += 1; + println!("Batch {}: {} rows", batch_count, batch.num_rows()); + } + + assert_eq!(total_rows, 5000, "Total rows mismatch"); + // Verify we got multiple batches (streaming active) + assert!( + batch_count > 1, + "Expected streaming (more than 1 batch), got {}", + batch_count + ); +} diff --git a/instantclient-basiclite-linuxx64.zip b/instantclient-basiclite-linuxx64.zip new file mode 100644 index 00000000..d34442af Binary files /dev/null and b/instantclient-basiclite-linuxx64.zip differ