Skip to content

Commit 215df6d

Browse files
authored
Various improvements for Databricks and result validation (#164)
* Various improvements * WIP * Fix * Lint
1 parent 290dd72 commit 215df6d

6 files changed

Lines changed: 171 additions & 10 deletions

File tree

crates/adbc_client/src/lib.rs

Lines changed: 128 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ use adbc_core::{Connection, Database, Driver, LOAD_FLAG_DEFAULT, Optionable, Sta
2424
use adbc_driver_manager::ManagedDriver;
2525
use arrow::compute::cast;
2626
use arrow::datatypes::{DataType, Schema};
27-
use arrow_array::RecordBatch;
27+
use arrow_array::{Array, RecordBatch, StringArray};
2828
use arrow_schema::Field;
2929
use snafu::prelude::*;
3030
use std::collections::HashMap;
@@ -56,15 +56,21 @@ pub type Result<T, E = Error> = std::result::Result<T, E>;
5656
pub struct AdbcConnection {
5757
conn: adbc_driver_manager::ManagedConnection,
5858
downcast_utf8view: bool,
59+
resolve_opaque_numerics: bool,
5960
}
6061

6162
impl AdbcConnection {
6263
/// Create an `AdbcConnection` from an already-established [`ManagedConnection`].
6364
#[must_use]
64-
pub fn new(conn: adbc_driver_manager::ManagedConnection, downcast_utf8view: bool) -> Self {
65+
pub fn new(
66+
conn: adbc_driver_manager::ManagedConnection,
67+
downcast_utf8view: bool,
68+
resolve_opaque_numerics: bool,
69+
) -> Self {
6570
Self {
6671
conn,
6772
downcast_utf8view,
73+
resolve_opaque_numerics,
6874
}
6975
}
7076

@@ -106,7 +112,11 @@ impl AdbcConnection {
106112
reason: e.to_string(),
107113
})?;
108114

109-
Ok(Self::new(conn, driver_name == "databricks"))
115+
Ok(Self::new(
116+
conn,
117+
driver_name == "databricks",
118+
driver_name == "postgresql",
119+
))
110120
}
111121

112122
/// Lightweight check that the connection is still usable.
@@ -135,9 +145,15 @@ impl AdbcConnection {
135145
reason: e.to_string(),
136146
})?;
137147

138-
reader
148+
let mut batches = reader
139149
.collect::<std::result::Result<Vec<_>, _>>()
140-
.context(ReadBatchSnafu)
150+
.context(ReadBatchSnafu)?;
151+
152+
if self.resolve_opaque_numerics {
153+
batches = resolve_opaque_numerics(batches);
154+
}
155+
156+
Ok(batches)
141157
}
142158

143159
/// Execute a SQL data-modification statement and return the affected row count when provided by the driver.
@@ -275,3 +291,110 @@ fn downcast_utf8view(batch: &RecordBatch) -> RecordBatch {
275291

276292
RecordBatch::try_new(Arc::new(Schema::new(fields)), columns).unwrap()
277293
}
294+
295+
/// Returns `true` if the field uses the Arrow opaque extension type for
296+
/// PostgreSQL `numeric`.
297+
fn is_opaque_numeric(field: &Field) -> bool {
298+
if !matches!(field.data_type(), DataType::Utf8) {
299+
return false;
300+
}
301+
let metadata = field.metadata();
302+
let Some(ext_name) = metadata.get("ARROW:extension:name") else {
303+
return false;
304+
};
305+
if ext_name != "arrow.opaque" {
306+
return false;
307+
}
308+
let Some(ext_meta) = metadata.get("ARROW:extension:metadata") else {
309+
return false;
310+
};
311+
serde_json::from_str::<serde_json::Value>(ext_meta)
312+
.ok()
313+
.and_then(|v| v.get("type_name")?.as_str().map(|s| s == "numeric"))
314+
.unwrap_or(false)
315+
}
316+
317+
/// Determine the maximum decimal scale (digits after the decimal point)
318+
/// across all non-null values in a string array.
319+
fn max_decimal_scale(array: &StringArray) -> i8 {
320+
let mut scale: i8 = 0;
321+
for i in 0..array.len() {
322+
if array.is_null(i) {
323+
continue;
324+
}
325+
let val = array.value(i);
326+
if let Some(dot_pos) = val.find('.') {
327+
let s = (val.len() - dot_pos - 1) as i8;
328+
scale = scale.max(s);
329+
}
330+
}
331+
scale
332+
}
333+
334+
/// Convert columns returned by the PostgreSQL ADBC driver as
335+
/// `Utf8` with `arrow.opaque` extension metadata for `numeric` to
336+
/// `Decimal128`, matching the representation used in checkpoint
337+
/// parquet files.
338+
///
339+
/// The scale for each column is inferred from the actual data across
340+
/// all batches. If casting fails (e.g. the column contains `NaN` or
341+
/// `inf`), the original `Utf8` column is kept.
342+
fn resolve_opaque_numerics(batches: Vec<RecordBatch>) -> Vec<RecordBatch> {
343+
if batches.is_empty() {
344+
return batches;
345+
}
346+
347+
let schema = batches[0].schema();
348+
let opaque_cols: Vec<usize> = schema
349+
.fields()
350+
.iter()
351+
.enumerate()
352+
.filter_map(|(i, f)| if is_opaque_numeric(f) { Some(i) } else { None })
353+
.collect();
354+
355+
if opaque_cols.is_empty() {
356+
return batches;
357+
}
358+
359+
// Determine the max scale for each opaque numeric column across
360+
// all batches so every batch uses a consistent Decimal128 type.
361+
let mut scales: Vec<i8> = vec![0; opaque_cols.len()];
362+
for batch in &batches {
363+
for (j, &col_idx) in opaque_cols.iter().enumerate() {
364+
if let Some(arr) = batch.column(col_idx).as_any().downcast_ref::<StringArray>() {
365+
scales[j] = scales[j].max(max_decimal_scale(arr));
366+
}
367+
}
368+
}
369+
370+
batches
371+
.into_iter()
372+
.map(|batch| {
373+
let schema = batch.schema();
374+
let mut fields = Vec::with_capacity(schema.fields().len());
375+
let mut columns = Vec::with_capacity(schema.fields().len());
376+
377+
for (i, field) in schema.fields().iter().enumerate() {
378+
if let Some(j) = opaque_cols.iter().position(|&idx| idx == i) {
379+
let target_type = DataType::Decimal128(38, scales[j]);
380+
match cast(batch.column(i), &target_type) {
381+
Ok(converted) => {
382+
fields.push(Arc::new(Field::new(
383+
field.name(),
384+
target_type,
385+
field.is_nullable(),
386+
)));
387+
columns.push(converted);
388+
continue;
389+
}
390+
Err(_) => { /* fall through to keep original */ }
391+
}
392+
}
393+
fields.push(field.clone());
394+
columns.push(batch.column(i).clone());
395+
}
396+
397+
RecordBatch::try_new(Arc::new(Schema::new(fields)), columns).unwrap()
398+
})
399+
.collect()
400+
}

crates/adbc_client/src/pool.rs

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,20 @@ const DEFAULT_POOL_SIZE: u32 = 10;
4141
pub struct AdbcConnectionManager {
4242
database: Arc<ManagedDatabase>,
4343
downcast_utf8view: bool,
44+
resolve_opaque_numerics: bool,
4445
}
4546

4647
impl AdbcConnectionManager {
4748
/// Create a new manager from an existing [`ManagedDatabase`].
48-
pub fn new(database: ManagedDatabase, downcast_utf8view: bool) -> Self {
49+
pub fn new(
50+
database: ManagedDatabase,
51+
downcast_utf8view: bool,
52+
resolve_opaque_numerics: bool,
53+
) -> Self {
4954
Self {
5055
database: Arc::new(database),
5156
downcast_utf8view,
57+
resolve_opaque_numerics,
5258
}
5359
}
5460
}
@@ -64,7 +70,11 @@ impl r2d2::ManageConnection for AdbcConnectionManager {
6470
.map_err(|e| Error::CreateConnection {
6571
reason: e.to_string(),
6672
})?;
67-
Ok(AdbcConnection::new(conn, self.downcast_utf8view))
73+
Ok(AdbcConnection::new(
74+
conn,
75+
self.downcast_utf8view,
76+
self.resolve_opaque_numerics,
77+
))
6878
}
6979

7080
fn is_valid(&self, conn: &mut Self::Connection) -> std::result::Result<(), Self::Error> {
@@ -119,7 +129,8 @@ pub fn create_pool(
119129
reason: e.to_string(),
120130
})?;
121131

122-
let manager = AdbcConnectionManager::new(db, driver_name == "databricks");
132+
let manager =
133+
AdbcConnectionManager::new(db, driver_name == "databricks", driver_name == "postgresql");
123134

124135
r2d2::Pool::builder()
125136
.max_size(pool_size)

crates/etl/src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,7 @@ async fn read_batches_until_min_rows(
357357
logical_steps_consumed: &StdArc<AtomicU64>,
358358
table_name: &str,
359359
start_batch_id: u64,
360+
step_limit: Option<usize>,
360361
) -> Result<(Vec<RecordBatch>, Vec<String>, bool, u64, u64), String> {
361362
let mut all_batches: Vec<RecordBatch> = Vec::new();
362363
let mut total_rows: usize = 0;
@@ -422,6 +423,7 @@ async fn read_batches_until_min_rows(
422423
&& join_set.len() < MAX_IN_FLIGHT_SOURCE_BATCH_READS.max(1)
423424
&& total_rows < TARGET_BATCH_ROWS
424425
&& !table_finished
426+
&& step_limit.map_or(true, |limit| consumed_work_units < limit as u64)
425427
{
426428
let reservation = {
427429
let mut state = work_state.lock().expect("work_state lock poisoned");
@@ -1557,6 +1559,7 @@ async fn run_pipeline(
15571559
&logical_steps_consumed,
15581560
&table_name,
15591561
batch_id,
1562+
step_limit,
15601563
)
15611564
.await
15621565
{

crates/system-adapter-protocol/src/client.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,11 @@ impl Client {
115115
.stdout(std::process::Stdio::piped())
116116
.stderr(std::process::Stdio::inherit());
117117

118+
// Place the child in its own process group so it doesn't receive
119+
// SIGINT when the user presses ctrl+c, allowing orderly teardown.
120+
#[cfg(unix)]
121+
cmd.process_group(0);
122+
118123
let mut child = cmd.spawn().map_err(|e| {
119124
ClientError::Transport(format!(
120125
"Failed to start stdio command '{command_str}': {e}"

crates/test-framework/src/queries/validation/mod.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,11 +148,16 @@ fn datatype_equivalent(expected_type: &DataType, actual_type: &DataType) -> bool
148148
// Existing numeric and string type equivalences
149149
_ => matches!(
150150
(expected_type, actual_type),
151-
(DataType::Float32, DataType::Float64)
151+
(DataType::Decimal128(_, _), DataType::Decimal128(_, _))
152+
| (DataType::Float32, DataType::Float64)
152153
| (
153154
DataType::Float64 | DataType::Int64,
154155
DataType::Decimal128(_, _)
155156
)
157+
| (
158+
DataType::Decimal128(_, _),
159+
DataType::Int64 | DataType::Int32 | DataType::Float64
160+
)
156161
| (DataType::Int32, DataType::Int64)
157162
| (
158163
DataType::Int64,

src/commands/load/mod.rs

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -596,6 +596,7 @@ pub(crate) async fn run(
596596
const POLL_INTERVAL: Duration = Duration::from_secs(5);
597597
const MAX_WAIT: Duration = Duration::from_secs(600);
598598
let mut timed_out = false;
599+
let mut interrupted = false;
599600
loop {
600601
let status =
601602
validation_controller.status_rx.borrow().clone();
@@ -616,7 +617,13 @@ pub(crate) async fn run(
616617
timed_out = true;
617618
break;
618619
}
619-
tokio::time::sleep(POLL_INTERVAL).await;
620+
tokio::select! {
621+
_ = tokio::time::sleep(POLL_INTERVAL) => {}
622+
_ = signal::ctrl_c() => {
623+
interrupted = true;
624+
break;
625+
}
626+
}
620627
}
621628

622629
// Read the validation status before disabling.
@@ -656,6 +663,13 @@ pub(crate) async fn run(
656663
.command_tx
657664
.send(Some(ValidationCommand::Disable));
658665

666+
if interrupted {
667+
eprintln!("Interrupt received during checkpoint validation, stopping...");
668+
shutdown_token.cancel();
669+
etl_pipeline.cancel();
670+
break Some("Interrupted by user".to_string());
671+
}
672+
659673
if timed_out {
660674
eprintln!(
661675
"Checkpoint {} validation timed out after {}s without convergence, aborting run",

0 commit comments

Comments
 (0)