@@ -24,7 +24,7 @@ use adbc_core::{Connection, Database, Driver, LOAD_FLAG_DEFAULT, Optionable, Sta
2424use adbc_driver_manager:: ManagedDriver ;
2525use arrow:: compute:: cast;
2626use arrow:: datatypes:: { DataType , Schema } ;
27- use arrow_array:: RecordBatch ;
27+ use arrow_array:: { Array , RecordBatch , StringArray } ;
2828use arrow_schema:: Field ;
2929use snafu:: prelude:: * ;
3030use std:: collections:: HashMap ;
@@ -56,15 +56,21 @@ pub type Result<T, E = Error> = std::result::Result<T, E>;
5656pub struct AdbcConnection {
5757 conn : adbc_driver_manager:: ManagedConnection ,
5858 downcast_utf8view : bool ,
59+ resolve_opaque_numerics : bool ,
5960}
6061
6162impl AdbcConnection {
6263 /// Create an `AdbcConnection` from an already-established [`ManagedConnection`].
6364 #[ must_use]
64- pub fn new ( conn : adbc_driver_manager:: ManagedConnection , downcast_utf8view : bool ) -> Self {
65+ pub fn new (
66+ conn : adbc_driver_manager:: ManagedConnection ,
67+ downcast_utf8view : bool ,
68+ resolve_opaque_numerics : bool ,
69+ ) -> Self {
6570 Self {
6671 conn,
6772 downcast_utf8view,
73+ resolve_opaque_numerics,
6874 }
6975 }
7076
@@ -106,7 +112,11 @@ impl AdbcConnection {
106112 reason : e. to_string ( ) ,
107113 } ) ?;
108114
109- Ok ( Self :: new ( conn, driver_name == "databricks" ) )
115+ Ok ( Self :: new (
116+ conn,
117+ driver_name == "databricks" ,
118+ driver_name == "postgresql" ,
119+ ) )
110120 }
111121
112122 /// Lightweight check that the connection is still usable.
@@ -135,9 +145,15 @@ impl AdbcConnection {
135145 reason : e. to_string ( ) ,
136146 } ) ?;
137147
138- reader
148+ let mut batches = reader
139149 . collect :: < std:: result:: Result < Vec < _ > , _ > > ( )
140- . context ( ReadBatchSnafu )
150+ . context ( ReadBatchSnafu ) ?;
151+
152+ if self . resolve_opaque_numerics {
153+ batches = resolve_opaque_numerics ( batches) ;
154+ }
155+
156+ Ok ( batches)
141157 }
142158
143159 /// Execute a SQL data-modification statement and return the affected row count when provided by the driver.
@@ -275,3 +291,110 @@ fn downcast_utf8view(batch: &RecordBatch) -> RecordBatch {
275291
276292 RecordBatch :: try_new ( Arc :: new ( Schema :: new ( fields) ) , columns) . unwrap ( )
277293}
294+
295+ /// Returns `true` if the field uses the Arrow opaque extension type for
296+ /// PostgreSQL `numeric`.
297+ fn is_opaque_numeric ( field : & Field ) -> bool {
298+ if !matches ! ( field. data_type( ) , DataType :: Utf8 ) {
299+ return false ;
300+ }
301+ let metadata = field. metadata ( ) ;
302+ let Some ( ext_name) = metadata. get ( "ARROW:extension:name" ) else {
303+ return false ;
304+ } ;
305+ if ext_name != "arrow.opaque" {
306+ return false ;
307+ }
308+ let Some ( ext_meta) = metadata. get ( "ARROW:extension:metadata" ) else {
309+ return false ;
310+ } ;
311+ serde_json:: from_str :: < serde_json:: Value > ( ext_meta)
312+ . ok ( )
313+ . and_then ( |v| v. get ( "type_name" ) ?. as_str ( ) . map ( |s| s == "numeric" ) )
314+ . unwrap_or ( false )
315+ }
316+
317+ /// Determine the maximum decimal scale (digits after the decimal point)
318+ /// across all non-null values in a string array.
319+ fn max_decimal_scale ( array : & StringArray ) -> i8 {
320+ let mut scale: i8 = 0 ;
321+ for i in 0 ..array. len ( ) {
322+ if array. is_null ( i) {
323+ continue ;
324+ }
325+ let val = array. value ( i) ;
326+ if let Some ( dot_pos) = val. find ( '.' ) {
327+ let s = ( val. len ( ) - dot_pos - 1 ) as i8 ;
328+ scale = scale. max ( s) ;
329+ }
330+ }
331+ scale
332+ }
333+
334+ /// Convert columns returned by the PostgreSQL ADBC driver as
335+ /// `Utf8` with `arrow.opaque` extension metadata for `numeric` to
336+ /// `Decimal128`, matching the representation used in checkpoint
337+ /// parquet files.
338+ ///
339+ /// The scale for each column is inferred from the actual data across
340+ /// all batches. If casting fails (e.g. the column contains `NaN` or
341+ /// `inf`), the original `Utf8` column is kept.
342+ fn resolve_opaque_numerics ( batches : Vec < RecordBatch > ) -> Vec < RecordBatch > {
343+ if batches. is_empty ( ) {
344+ return batches;
345+ }
346+
347+ let schema = batches[ 0 ] . schema ( ) ;
348+ let opaque_cols: Vec < usize > = schema
349+ . fields ( )
350+ . iter ( )
351+ . enumerate ( )
352+ . filter_map ( |( i, f) | if is_opaque_numeric ( f) { Some ( i) } else { None } )
353+ . collect ( ) ;
354+
355+ if opaque_cols. is_empty ( ) {
356+ return batches;
357+ }
358+
359+ // Determine the max scale for each opaque numeric column across
360+ // all batches so every batch uses a consistent Decimal128 type.
361+ let mut scales: Vec < i8 > = vec ! [ 0 ; opaque_cols. len( ) ] ;
362+ for batch in & batches {
363+ for ( j, & col_idx) in opaque_cols. iter ( ) . enumerate ( ) {
364+ if let Some ( arr) = batch. column ( col_idx) . as_any ( ) . downcast_ref :: < StringArray > ( ) {
365+ scales[ j] = scales[ j] . max ( max_decimal_scale ( arr) ) ;
366+ }
367+ }
368+ }
369+
370+ batches
371+ . into_iter ( )
372+ . map ( |batch| {
373+ let schema = batch. schema ( ) ;
374+ let mut fields = Vec :: with_capacity ( schema. fields ( ) . len ( ) ) ;
375+ let mut columns = Vec :: with_capacity ( schema. fields ( ) . len ( ) ) ;
376+
377+ for ( i, field) in schema. fields ( ) . iter ( ) . enumerate ( ) {
378+ if let Some ( j) = opaque_cols. iter ( ) . position ( |& idx| idx == i) {
379+ let target_type = DataType :: Decimal128 ( 38 , scales[ j] ) ;
380+ match cast ( batch. column ( i) , & target_type) {
381+ Ok ( converted) => {
382+ fields. push ( Arc :: new ( Field :: new (
383+ field. name ( ) ,
384+ target_type,
385+ field. is_nullable ( ) ,
386+ ) ) ) ;
387+ columns. push ( converted) ;
388+ continue ;
389+ }
390+ Err ( _) => { /* fall through to keep original */ }
391+ }
392+ }
393+ fields. push ( field. clone ( ) ) ;
394+ columns. push ( batch. column ( i) . clone ( ) ) ;
395+ }
396+
397+ RecordBatch :: try_new ( Arc :: new ( Schema :: new ( fields) ) , columns) . unwrap ( )
398+ } )
399+ . collect ( )
400+ }
0 commit comments