diff --git a/Cargo.lock b/Cargo.lock
index a35d18b..4ac7b4a 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -450,6 +450,7 @@ dependencies = [
   "fallible-streaming-iterator",
   "hashlink",
   "libduckdb-sys",
+ "num",
   "num-integer",
   "rust_decimal",
   "strum 0.27.2",
diff --git a/Cargo.toml b/Cargo.toml
index 407bc99..1559c60 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -20,6 +20,11 @@ path = "src/wasm_lib.rs"
 crate-type = ["staticlib"]
 
 [dependencies]
-duckdb = { version = "1.4.1", features = ["vtab-loadable"] }
+duckdb = { version = "1.4.1", features = [
+    "vscalar",
+    "vscalar-arrow",
+    "vtab-arrow",
+    "vtab-loadable"
+] }
 duckdb-loadable-macros = "0.1.11"
 libduckdb-sys = { version = "1.4.1", features = ["loadable-extension"] }
diff --git a/src/lib.rs b/src/lib.rs
index ad65b14..3438c35 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -4,17 +4,85 @@ extern crate libduckdb_sys;
 
 use duckdb::{
     core::{DataChunkHandle, Inserter, LogicalTypeHandle, LogicalTypeId},
+    types::DuckString,
+    vscalar::{ScalarFunctionSignature, VScalar},
+    vtab::arrow::WritableVector,
     vtab::{BindInfo, InitInfo, TableFunctionInfo, VTab},
     Connection, Result,
 };
 use duckdb_loadable_macros::duckdb_entrypoint_c_api;
 use libduckdb_sys as ffi;
+use libduckdb_sys::duckdb_string_t;
 use std::{
     error::Error,
     ffi::CString,
     sync::atomic::{AtomicBool, Ordering},
 };
 
+/// Settings applied by the `rusty_echo` scalar function to every row.
+struct EchoState {
+    /// Number of times each input string is repeated.
+    multiplier: usize,
+    /// Text inserted between consecutive repetitions.
+    separator: String,
+    /// Text emitted, followed by one space, ahead of the repetitions.
+    prefix: String,
+}
+
+impl Default for EchoState {
+    /// Three repetitions separated by "📢" and introduced by "🐤".
+    fn default() -> Self {
+        Self {
+            multiplier: 3,
+            separator: "📢".to_string(),
+            prefix: "🐤".to_string(),
+        }
+    }
+}
+
+/// Scalar function that repeats its single VARCHAR argument.
+struct EchoScalar {}
+
+impl VScalar for EchoScalar {
+    type State = EchoState;
+
+    /// Writes `<prefix> <s><separator><s>...` into `output` for each row of `input`.
+    ///
+    /// # Safety
+    ///
+    /// Column 0 of `input` is reinterpreted as `input.len()` `duckdb_string_t`
+    /// values, so it must really be a VARCHAR column of that length.
+    unsafe fn invoke(
+        state: &Self::State,
+        input: &mut DataChunkHandle,
+        output: &mut dyn WritableVector,
+    ) -> Result<(), Box<dyn std::error::Error>> {
+        let rows = input.len();
+        let column = input.flat_vector(0);
+        let values = column.as_slice_with_len::<duckdb_string_t>(rows);
+        let output = output.flat_vector();
+
+        // NOTE(review): row validity is not inspected here; confirm NULL inputs never reach us.
+        for (i, raw) in values.iter().enumerate() {
+            // `DuckString::new` takes a mutable reference, so hand it a copy of the raw value.
+            let text = DuckString::new(&mut { *raw }).as_str().to_string();
+            // Joining borrowed slices avoids cloning the `String` for every repetition.
+            let echoed = vec![text.as_str(); state.multiplier].join(&state.separator);
+            let res = format!("{} {}", state.prefix, echoed);
+            output.insert(i, res.as_str());
+        }
+        Ok(())
+    }
+
+    /// Declares the only supported signature: `(VARCHAR) -> VARCHAR`.
+    fn signatures() -> Vec<ScalarFunctionSignature> {
+        vec![ScalarFunctionSignature::exact(
+            vec![LogicalTypeId::Varchar.into()],
+            LogicalTypeId::Varchar.into(),
+        )]
+    }
+}
+
 #[repr(C)]
 struct HelloBindData {
     name: String,
@@ -43,7 +111,10 @@ impl VTab for HelloVTab {
         })
     }
 
-    fn func(func: &TableFunctionInfo<Self>, output: &mut DataChunkHandle) -> Result<(), Box<dyn std::error::Error>> {
+    fn func(
+        func: &TableFunctionInfo<Self>,
+        output: &mut DataChunkHandle,
+    ) -> Result<(), Box<dyn std::error::Error>> {
         let init_data = func.get_init_data();
         let bind_data = func.get_bind_data();
         if init_data.done.swap(true, Ordering::Relaxed) {
@@ -68,5 +139,7 @@ const EXTENSION_NAME: &str = env!("CARGO_PKG_NAME");
 pub unsafe fn extension_entrypoint(con: Connection) -> Result<(), Box<dyn Error>> {
     con.register_table_function::<HelloVTab>(EXTENSION_NAME)
         .expect("Failed to register hello table function");
+    con.register_scalar_function::<EchoScalar>("rusty_echo")
+        .expect("Failed to register echo scalar function");
     Ok(())
-}
\ No newline at end of file
+}
diff --git a/test/sql/rusty_echo.test b/test/sql/rusty_echo.test
new file mode 100644
index 0000000..ade8a1a
--- /dev/null
+++ b/test/sql/rusty_echo.test
@@ -0,0 +1,22 @@
+# name: test/sql/rusty_echo.test
+# description: test rusty_quack extension
+# group: [quack]
+
+# Before we load the extension, this will fail
+statement error
+SELECT rusty_echo('col0') FROM values ('Hello') as v;
+----
+Catalog Error: Scalar Function with name rusty_echo does not exist!
+
+# Require statement will ensure the extension is loaded from now on
+require rusty_quack
+
+require icu
+
+# Confirm the extension works
+query I
+SELECT rusty_echo(col0) FROM values ('Hello'), ('rusty'), ('world') as v;
+----
+🐤 Hello📢Hello📢Hello
+🐤 rusty📢rusty📢rusty
+🐤 world📢world📢world