Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 48 additions & 4 deletions ext-arrow/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,20 @@ pub trait ArrowClientExt {
/// # Errors
/// Any [`ArrowError`][arrow_schema::ArrowError]s are wrapped as [`Error::Other`].
fn insert_arrow(&self, table: &str) -> Result<ArrowInsert, Error>;

/// Begin inserting Arrow [`RecordBatch`]es into the target table.
///
/// The request isn't begun until the first batch is written.
///
/// `sql` should be of the form `INSERT INTO [<schema>.]<table>[(<column>, ...)] FORMAT ArrowStream`.
///
/// # Note: Missing or Unknown Columns
/// Any fields in the record stream which do not match the target table are silently ignored
/// by default, which could lead to data loss if this is the result of a typo;
/// the intended field in the table would be filled with the default value for the type instead.
///
/// This method is intended for advanced usage only.
fn insert_arrow_with(&self, sql: &str) -> ArrowInsert;
}

impl ArrowClientExt for Client {
Expand All @@ -52,13 +66,23 @@ impl ArrowClientExt for Client {
.map_err(|e| Error::Other(e.into()))?;

Ok(ArrowInsert {
state: InsertState::NotStarted {
state: InsertState::TableName {
client: self.clone(),
table: table.to_string(),
},
sent_rows: Saturating(0),
})
}

fn insert_arrow_with(&self, sql: &str) -> ArrowInsert {
ArrowInsert {
state: InsertState::PresetSql {
sql: sql.to_string(),
client: self.clone(),
},
sent_rows: Saturating(0),
}
}
}

/// Extension methods for [`clickhouse::query::Query`] for use with Arrow.
Expand Down Expand Up @@ -115,7 +139,8 @@ pub struct ArrowInsert {
}

enum InsertState {
NotStarted { table: String, client: Client },
TableName { table: String, client: Client },
PresetSql { sql: String, client: Client },
Started(StreamWriter<InsertWriter>),
Finished,
}
Expand Down Expand Up @@ -200,6 +225,8 @@ impl ArrowInsert {

/// Flush the remaining data and finish the `INSERT` request.
///
/// If no data has been written, no request will be sent.
///
/// # Not Cancel-Safe
/// If this method is canceled while data is being flushed, any data left in the buffer is lost.
///
Expand All @@ -211,7 +238,9 @@ impl ArrowInsert {
pub async fn end(self) -> Result<(), Error> {
let mut insert = match self.state {
InsertState::Started(writer) => writer.into_inner().map_err(wrap_arrow_err)?.insert,
InsertState::NotStarted { .. } | InsertState::Finished => return Ok(()),
InsertState::TableName { .. }
| InsertState::PresetSql { .. }
| InsertState::Finished => return Ok(()),
};

tracing::record_all!(
Expand All @@ -237,7 +266,7 @@ impl InsertState {
}

match mem::replace(self, Self::Finished) {
Self::NotStarted { table, client } => {
Self::TableName { table, client } => {
let mut query_string = "INSERT INTO ".to_string();

sql_escape_identifier(&table, &mut query_string)
Expand Down Expand Up @@ -275,6 +304,21 @@ impl InsertState {
);
Ok(())
}
Self::PresetSql { sql, client } => {
let insert = client
.insert_formatted_with(sql)
// Prevent ClickHouse from double-compressing
.with_setting("output_format_arrow_compression_method", "none")
Copy link
Copy Markdown

@joe-clickhouse joe-clickhouse Jun 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Think this'll be a no-op for inserts? Any double compression concerns on insert would be in the transport-level on top of Arrow. Also, that same 5 line method chain is repeated twice, only diff seems to be where the sql comes from. Possible op to factor it out...not a big deal though.

// Add specific product info to let us track Arrow adoption
.with_product_info("clickhouse-ext-arrow", _priv::CARGO_PKG_VERSION)
.buffered();

*self = Self::Started(
StreamWriter::try_new(InsertWriter { insert }, schema)
.map_err(wrap_arrow_err)?,
);
Ok(())
}
Self::Started(writer) => {
*self = Self::Started(writer);
Ok(())
Expand Down
91 changes: 91 additions & 0 deletions tests/it/arrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,97 @@ async fn insert() {
assert_eq!(result, expected);
}

#[tokio::test]
async fn insert_custom() {
let client = prepare_database!();

client
.query(
"CREATE TABLE arrow_custom_insert_test(bar Int32, baz String) ENGINE = MergeTree ORDER BY bar",
)
.execute()
.await
.unwrap();

let batch_size = 100;
let num_batches = 100;
let mut next_id = 1..;

let mut batches = Vec::new();
let schema = Arc::new(Schema::new(vec![
Field::new("bar", DataType::Int32, false),
Field::new("baz", DataType::Utf8, false),
]));

for batch in 1..=num_batches {
let bars: PrimitiveArray<Int32Type> = (0..batch_size)
.zip(&mut next_id)
.map(|(_, id)| id)
.collect();

let bazzes: StringArray = bars
.iter()
.filter_map(|bar| {
let bar = bar?;
Some(format!("batch_{batch}_bar_{bar}"))
})
.collect::<Vec<String>>()
.into();

batches.push(
RecordBatch::try_new(schema.clone(), vec![Arc::new(bars), Arc::new(bazzes)]).unwrap(),
);
}

let mut insert = client
.insert_arrow_with("INSERT INTO arrow_custom_insert_test(bar, baz) FORMAT ArrowStream");

for batch in batches {
insert.write(&batch).await.unwrap();
}

insert.end().await.unwrap();

let result = client
.query(
"SELECT \
count(*) AS row_count, \
first_value(bar) AS min_bar, \
first_value(baz) AS min_baz,
last_value(bar) AS max_bar, \
last_value(baz) AS max_baz \
FROM (SELECT * FROM arrow_custom_insert_test ORDER BY bar)",
)
.fetch_arrow()
.unwrap()
.collect_merged()
.await
.unwrap();

let expected_count = batch_size * num_batches;

let expected = RecordBatch::try_new(
Schema::new(vec![
Field::new("row_count", DataType::UInt64, false),
Field::new("min_bar", DataType::Int32, false),
Field::new("min_baz", DataType::Utf8, false),
Field::new("max_bar", DataType::Int32, false),
Field::new("max_baz", DataType::Utf8, false),
])
.into(),
vec![
create_array!(UInt64, [expected_count]),
create_array!(Int32, [1]),
create_array!(Utf8, ["batch_1_bar_1"]),
create_array!(Int32, [expected_count as i32]),
create_array!(Utf8, ["batch_100_bar_10000"]),
],
)
.unwrap();

assert_eq!(result, expected);
}

#[tokio::test]
async fn query_empty_response() {
let client = get_client();
Expand Down
Loading