From f0f10bc8f0fec25fda555be73200ef32c8ea5d9f Mon Sep 17 00:00:00 2001 From: Luke Kim <80174+lukekim@users.noreply.github.com> Date: Mon, 23 Mar 2026 15:10:59 -0700 Subject: [PATCH 1/3] feat: Add custom auth header support for GraphQL connector (#9899) * feat: add support for custom authentication header in GraphQL client * fix: address PR review - fix auth logic, add tests, revert unrelated changes - Fix auth selection: Basic auth now works even when auth_header is set but token is missing - Add 9 unit tests: auth selection logic + request_with_auth behavior for all variants - Add 2 integration tests: custom auth header verification, json_pointer + query combinations - Revert unrelated changes to cayenne, datafusion/mod.rs, and data_components/lib.rs * fix: add warning when auth_header is set but auth_token is missing * Address PR review: use prefixed param names in messages, improve test error handling * fix: Apply cargo fmt to tracing::warn macro call * Address PR review: generic warning messages, warn on auth_header without credentials, refactor test server helper * fix: Correct test_graphql_json_pointer_combinations assertion to match SELECT query --- Cargo.lock | 1 + .../connector-graphql/Cargo.toml | 1 + .../connector-graphql/src/lib.rs | 20 ++ crates/data_components/src/graphql/builder.rs | 9 + crates/data_components/src/graphql/client.rs | 271 +++++++++++++++++- crates/runtime/tests/graphql/mod.rs | 247 +++++++++++++++- 6 files changed, 538 insertions(+), 11 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 7954f1bb55..5a2da38c19 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4118,6 +4118,7 @@ dependencies = [ "datafusion", "linkme", "paste", + "reqwest 0.12.24", "runtime", "snafu", "token_provider", diff --git a/crates/data-connectors/connector-graphql/Cargo.toml b/crates/data-connectors/connector-graphql/Cargo.toml index 8266cd2427..5d99fddd18 100644 --- a/crates/data-connectors/connector-graphql/Cargo.toml +++ b/crates/data-connectors/connector-graphql/Cargo.toml @@ -28,6 +28,7 @@ data_components = { path = "../../data_components" } datafusion.workspace = true linkme.workspace = true paste.workspace = true +reqwest.workspace = true runtime = { path = "../../runtime" } snafu.workspace = true token_provider = { path = "../../token_provider" } diff --git a/crates/data-connectors/connector-graphql/src/lib.rs b/crates/data-connectors/connector-graphql/src/lib.rs index d0409c86a3..7f58a2c979 100644 --- a/crates/data-connectors/connector-graphql/src/lib.rs +++ b/crates/data-connectors/connector-graphql/src/lib.rs @@ -52,6 +52,8 @@ impl GraphQLFactory { const PARAMETERS: &[ParameterSpec] = &[ // Connector parameters + ParameterSpec::component("auth_header") + .description("A custom header name to use for authentication instead of the default 'Authorization: Bearer' header. When set, the value of 'auth_token' is sent as the value of this header."), ParameterSpec::component("auth_token") .description("The bearer token to use in the GraphQL requests.") .secret(), @@ -104,6 +106,23 @@ impl GraphQL { Arc::new(StaticTokenProvider::new(token.clone())) as Arc }); + let auth_header = self + .params + .get("auth_header") + .expose() + .ok() + .map(|h| { + reqwest::header::HeaderName::try_from(h).map_err(|source| { + DataConnectorError::InvalidConfiguration { + dataconnector: "graphql".to_string(), + message: format!("Invalid 'graphql_auth_header' value: '{h}'. Ensure it is a valid HTTP header name. For details, visit: https://spiceai.org/docs/components/data-connectors/graphql"), + connector_component: ConnectorComponent::from(dataset), + source: source.into(), + } + }) + }) + .transpose()?; + let user = self .params .get("auth_user") @@ -163,6 +182,7 @@ impl GraphQL { None, None, None, + auth_header, ) .boxed() .map_err(|source| DataConnectorError::InternalWithSource { diff --git a/crates/data_components/src/graphql/builder.rs b/crates/data_components/src/graphql/builder.rs index 337e5ef0ad..4ca4c55b52 100644 --- a/crates/data_components/src/graphql/builder.rs +++ b/crates/data_components/src/graphql/builder.rs @@ -36,6 +36,7 @@ pub struct GraphQLClientBuilder { rate_limiter: Option>, rate_controller: Option>, semaphore: Option>, + auth_header: Option, } impl GraphQLClientBuilder { @@ -52,6 +53,7 @@ impl GraphQLClientBuilder { rate_limiter: None, rate_controller: None, semaphore: None, + auth_header: None, } } @@ -103,6 +105,12 @@ impl GraphQLClientBuilder { self } + #[must_use] + pub fn with_auth_header(mut self, auth_header: Option) -> Self { + self.auth_header = auth_header; + self + } + pub fn build(self, client: reqwest::Client) -> Result { GraphQLClient::new( client, @@ -116,6 +124,7 @@ impl GraphQLClientBuilder { self.rate_limiter, self.rate_controller, self.semaphore, + self.auth_header, ) } } diff --git a/crates/data_components/src/graphql/client.rs b/crates/data_components/src/graphql/client.rs index 0af9e5e98d..0ff456fa84 100644 --- a/crates/data_components/src/graphql/client.rs +++ b/crates/data_components/src/graphql/client.rs @@ -49,6 +49,7 @@ use datafusion::{error::DataFusionError, physical_plan::stream::RecordBatchRecei pub enum Auth { Basic(String, Option), Bearer(Arc), + CustomHeader(reqwest::header::HeaderName, Arc), } #[derive(Debug, PartialEq, Eq)] @@ -782,6 +783,7 @@ impl GraphQLClient { rate_limiter: Option>, rate_controller: Option>, semaphore: Option>, + auth_header: Option, ) -> Result { // Validate unnest depth to prevent excessive recursion if let UnnestBehavior::Depth(depth) = &unnest_behavior @@ -792,9 +794,28 @@ impl GraphQLClient { }); } - let auth = match (token, user, pass) { - (None, Some(user), pass) => Some(Auth::Basic(user, pass)), - (Some(token), _, _) => Some(Auth::Bearer(token)), + let auth = match (auth_header, token, user, pass) { + // Custom header with token takes precedence when both are configured + (Some(header_name), Some(token), _, _) => Some(Auth::CustomHeader(header_name, token)), + // Bearer token without custom header + (None, Some(token), _, _) => Some(Auth::Bearer(token)), + // When no token is available but a username is provided, use Basic auth + // regardless of whether a custom auth header was configured + (Some(_), None, Some(user), pass) => { + tracing::warn!( + "Custom auth header is configured without an auth token; falling back to Basic auth" + ); + Some(Auth::Basic(user, pass)) + } + (_, None, Some(user), pass) => Some(Auth::Basic(user, pass)), + // Custom auth header configured without any credentials + (Some(_), None, None, _) => { + tracing::warn!( + "Custom auth header is configured but no credentials are provided; requests will be unauthenticated" + ); + None + } + // No authentication configured _ => None, }; @@ -1283,6 +1304,9 @@ fn request_with_auth(request_builder: RequestBuilder, auth: Option<&Auth>) -> Re Some(Auth::Bearer(token_provider)) => { request_builder.bearer_auth(token_provider.get_token()) } + Some(Auth::CustomHeader(header_name, token_provider)) => { + request_builder.header(header_name.clone(), token_provider.get_token()) + } _ => request_builder, } } @@ -1479,6 +1503,7 @@ mod tests { use reqwest::StatusCode; use serde_json::Value; + use url::Url; use crate::graphql::client::GraphQLQuery; @@ -2023,4 +2048,244 @@ mod tests { assert_eq!(obj.get("1"), Some(&Value::String("a".to_string()))); assert_eq!(obj.get("2"), Some(&Value::String("a".to_string()))); } + + #[test] + fn test_auth_custom_header_with_token() { + let token = Arc::new(token_provider::StaticTokenProvider::new( + secrecy::SecretString::from("my_secret"), + )) as Arc; + let header = reqwest::header::HeaderName::from_static("x-shopify-access-token"); + + let client = super::GraphQLClient::new( + reqwest::Client::new(), + Url::parse("https://example.com/graphql").expect("valid url"), + Some("/data"), + Some(token), + None, + None, + UnnestBehavior::Depth(0), + None, + None, + None, + None, + Some(header), + ) + .expect("Should create client"); + + assert!( + matches!(&client.auth, Some(super::Auth::CustomHeader(name, _)) if name.as_str() == "x-shopify-access-token"), + "Expected CustomHeader auth, got {:?}", + client.auth.as_ref().map(|a| match a { + super::Auth::Basic(_, _) => "Basic", + super::Auth::Bearer(_) => "Bearer", + super::Auth::CustomHeader(_, _) => "CustomHeader", + }) + ); + } + + #[test] + fn test_auth_bearer_without_custom_header() { + let token = Arc::new(token_provider::StaticTokenProvider::new( + secrecy::SecretString::from("my_secret"), + )) as Arc; + + let client = super::GraphQLClient::new( + reqwest::Client::new(), + Url::parse("https://example.com/graphql").expect("valid url"), + None, + Some(token), + None, + None, + UnnestBehavior::Depth(0), + None, + None, + None, + None, + None, + ) + .expect("Should create client"); + + assert!( + matches!(&client.auth, Some(super::Auth::Bearer(_))), + "Expected Bearer auth" + ); + } + + #[test] + fn test_auth_basic_when_no_token() { + let client = super::GraphQLClient::new( + reqwest::Client::new(), + Url::parse("https://example.com/graphql").expect("valid url"), + None, + None, + Some("user".to_string()), + Some("pass".to_string()), + UnnestBehavior::Depth(0), + None, + None, + None, + None, + None, + ) + .expect("Should create client"); + + assert!( + matches!(&client.auth, Some(super::Auth::Basic(u, Some(p))) if u == "user" && p == "pass"), + "Expected Basic auth with user and pass" + ); + } + + #[test] + fn test_auth_basic_fallback_when_auth_header_set_without_token() { + let header = reqwest::header::HeaderName::from_static("x-custom"); + + let client = super::GraphQLClient::new( + reqwest::Client::new(), + Url::parse("https://example.com/graphql").expect("valid url"), + None, + None, + Some("user".to_string()), + None, + UnnestBehavior::Depth(0), + None, + None, + None, + None, + Some(header), + ) + .expect("Should create client"); + + assert!( + matches!(&client.auth, Some(super::Auth::Basic(u, None)) if u == "user"), + "Expected Basic auth fallback when auth_header is set but token is missing" + ); + } + + #[test] + fn test_auth_header_without_credentials_warns_and_returns_none() { + let header = reqwest::header::HeaderName::from_static("x-custom"); + + let client = super::GraphQLClient::new( + reqwest::Client::new(), + Url::parse("https://example.com/graphql").expect("valid url"), + None, + None, + None, + None, + UnnestBehavior::Depth(0), + None, + None, + None, + None, + Some(header), + ) + .expect("Should create client"); + + assert!( + client.auth.is_none(), + "Expected no auth when auth_header is set but no token or user" + ); + } + + #[test] + fn test_auth_none_when_nothing_set() { + let client = super::GraphQLClient::new( + reqwest::Client::new(), + Url::parse("https://example.com/graphql").expect("valid url"), + None, + None, + None, + None, + UnnestBehavior::Depth(0), + None, + None, + None, + None, + None, + ) + .expect("Should create client"); + + assert!(client.auth.is_none(), "Expected no auth"); + } + + #[test] + fn test_request_with_auth_custom_header() { + let token = Arc::new(token_provider::StaticTokenProvider::new( + secrecy::SecretString::from("secret_token_value"), + )) as Arc; + let header = reqwest::header::HeaderName::from_static("x-api-key"); + let auth = super::Auth::CustomHeader(header, token); + + let client = reqwest::Client::new(); + let request_builder = client.post("https://example.com/graphql"); + let request_builder = super::request_with_auth(request_builder, Some(&auth)); + let request = request_builder.build().expect("Should build request"); + + assert_eq!( + request + .headers() + .get("x-api-key") + .expect("Should have x-api-key header") + .to_str() + .expect("valid str"), + "secret_token_value" + ); + } + + #[test] + fn test_request_with_auth_bearer() { + let token = Arc::new(token_provider::StaticTokenProvider::new( + secrecy::SecretString::from("bearer_token"), + )) as Arc; + let auth = super::Auth::Bearer(token); + + let client = reqwest::Client::new(); + let request_builder = client.post("https://example.com/graphql"); + let request_builder = super::request_with_auth(request_builder, Some(&auth)); + let request = request_builder.build().expect("Should build request"); + + assert_eq!( + request + .headers() + .get("authorization") + .expect("Should have authorization header") + .to_str() + .expect("valid str"), + "Bearer bearer_token" + ); + } + + #[test] + fn test_request_with_auth_basic() { + let auth = super::Auth::Basic("user".to_string(), Some("pass".to_string())); + + let client = reqwest::Client::new(); + let request_builder = client.post("https://example.com/graphql"); + let request_builder = super::request_with_auth(request_builder, Some(&auth)); + let request = request_builder.build().expect("Should build request"); + + let auth_header = request + .headers() + .get("authorization") + .expect("Should have authorization header") + .to_str() + .expect("valid str"); + assert!( + auth_header.starts_with("Basic "), + "Expected Basic auth header, got: {auth_header}" + ); + } + + #[test] + fn test_request_with_auth_none() { + let client = reqwest::Client::new(); + let request_builder = client.post("https://example.com/graphql"); + let request_builder = super::request_with_auth(request_builder, None); + let request = request_builder.build().expect("Should build request"); + + assert!( + request.headers().get("authorization").is_none(), + "Expected no authorization header" + ); + } } diff --git a/crates/runtime/tests/graphql/mod.rs b/crates/runtime/tests/graphql/mod.rs index 84dd28ea93..e4b1cf9a86 100644 --- a/crates/runtime/tests/graphql/mod.rs +++ b/crates/runtime/tests/graphql/mod.rs @@ -25,6 +25,7 @@ use arrow::array::RecordBatch; use async_graphql::{EmptyMutation, EmptySubscription, SimpleObject}; use async_graphql::{Object, Schema}; use async_graphql_axum::{GraphQLRequest, GraphQLResponse}; +use axum::http::StatusCode; use axum::{Extension, Router, routing::post}; use runtime::Runtime; @@ -215,13 +216,10 @@ async fn graphql_handler(schema: Extension, req: GraphQLRequest) response.into() } -async fn start_server() -> Result<(tokio::sync::oneshot::Sender<()>, SocketAddr), String> { +async fn start_graphql_server( + app: Router, +) -> Result<(tokio::sync::oneshot::Sender<()>, SocketAddr), String> { let (tx, rx) = tokio::sync::oneshot::channel::<()>(); - let schema = Schema::build(QueryRoot, EmptyMutation, EmptySubscription).finish(); - - let app = Router::new() - .route("/graphql", post(graphql_handler)) - .layer(Extension(schema)); let tcp_listener = TcpListener::bind("127.0.0.1:0").await.map_err(|e| { tracing::error!("Failed to bind to address: {e}"); @@ -233,17 +231,27 @@ async fn start_server() -> Result<(tokio::sync::oneshot::Sender<()>, SocketAddr) })?; tokio::spawn(async move { - axum::serve(tcp_listener, app) + if let Err(e) = axum::serve(tcp_listener, app) .with_graceful_shutdown(async { rx.await.ok(); }) .await - .unwrap_or_default(); + { + tracing::error!("GraphQL test server failed: {e}"); + } }); Ok((tx, addr)) } +async fn start_server() -> Result<(tokio::sync::oneshot::Sender<()>, SocketAddr), String> { + let schema = Schema::build(QueryRoot, EmptyMutation, EmptySubscription).finish(); + let app = Router::new() + .route("/graphql", post(graphql_handler)) + .layer(Extension(schema)); + start_graphql_server(app).await +} + fn make_graphql_dataset( path: &str, name: &str, @@ -499,3 +507,226 @@ async fn test_graphql_pagination_with_limit() -> Result<(), String> { }).await } + +/// A handler that requires a custom `X-Api-Key` header and returns 401 if missing. +async fn graphql_handler_with_custom_auth( + schema: Extension, + headers: axum::http::HeaderMap, + req: GraphQLRequest, +) -> Result { + let api_key = headers.get("x-api-key").and_then(|v| v.to_str().ok()); + + match api_key { + Some("test-secret-key") => { + let response = schema.execute(req.into_inner()).await; + Ok(response.into()) + } + _ => Err(StatusCode::UNAUTHORIZED), + } +} + +async fn start_auth_server() -> Result<(tokio::sync::oneshot::Sender<()>, SocketAddr), String> { + let schema = Schema::build(QueryRoot, EmptyMutation, EmptySubscription).finish(); + let app = Router::new() + .route("/graphql", post(graphql_handler_with_custom_auth)) + .layer(Extension(schema)); + start_graphql_server(app).await +} + +fn make_graphql_dataset_with_auth( + path: &str, + name: &str, + query: &str, + json_pointer: &str, + auth_header: &str, + auth_token: &str, +) -> Dataset { + let mut dataset = Dataset::new(format!("graphql:{path}"), name.to_string()); + let params = HashMap::from([ + ("json_pointer".to_string(), json_pointer.to_string()), + ("graphql_query".to_string(), query.to_string()), + ("graphql_auth_header".to_string(), auth_header.to_string()), + ("graphql_auth_token".to_string(), auth_token.to_string()), + ]); + + dataset.params = Some(DatasetParams::from_string_map(params)); + dataset +} + +#[tokio::test] +async fn test_graphql_custom_auth_header() -> Result<(), String> { + let _tracing = init_tracing(Some("integration=debug,info")); + register_test_connectors().await; + + test_request_context() + .scope(async { + let (tx, addr) = start_auth_server().await?; + tracing::debug!("Auth server started at {}", addr); + + let app = AppBuilder::new("graphql_custom_auth_test") + .with_dataset(make_graphql_dataset_with_auth( + &format!("http://{addr}/graphql"), + "test_graphql_auth", + "query { users { id name } }", + "/data/users", + "X-Api-Key", + "test-secret-key", + )) + .build(); + + configure_test_datafusion(); + let mut rt = Runtime::builder().with_app(app).build().await; + let cloned_rt = Arc::new(rt.clone()); + + tokio::select! { + () = tokio::time::sleep(std::time::Duration::from_secs(60)) => { + return Err("Timed out waiting for datasets to load".to_string()); + } + () = cloned_rt.load_components() => {} + } + + run_query_and_check_results( + &mut rt, + "test_graphql_custom_auth_header", + "SELECT id, name FROM test_graphql_auth", + false, + Some(Box::new(|result_batches: Vec| { + let total_rows: usize = result_batches.iter().map(RecordBatch::num_rows).sum(); + assert_eq!(total_rows, 4, "Expected 4 users, got {total_rows}"); + for batch in result_batches { + assert_eq!( + batch.num_columns(), + 2, + "Expected 2 columns (id, name), got {}", + batch.num_columns() + ); + } + })), + ) + .await?; + + tx.send(()).map_err(|()| { + tracing::error!("Failed to send shutdown signal"); + "Failed to send shutdown signal".to_string() + })?; + + Ok(()) + }) + .await +} + +#[tokio::test] +async fn test_graphql_json_pointer_combinations() -> Result<(), String> { + let _tracing = init_tracing(Some("integration=debug,info")); + register_test_connectors().await; + + test_request_context() + .scope(async { + let (tx, addr) = start_server().await?; + tracing::debug!("Server started at {}", addr); + + // Test 1: Root-level json_pointer /data/users with flat query + let app = AppBuilder::new("graphql_json_pointer_test") + .with_dataset(make_graphql_dataset( + &format!("http://{addr}/graphql"), + "users_flat", + "query { users { id name } }", + "/data/users", + None, + )) + // Test 2: Nested json_pointer with paginated query + .with_dataset(make_graphql_dataset( + &format!("http://{addr}/graphql"), + "users_nested", + "query { paginatedUsers(first: 10) { users { id name posts { id title content } } pageInfo { hasNextPage endCursor } } }", + "/data/paginatedUsers/users", + None, + )) + // Test 3: Query with unnest_depth + .with_dataset(make_graphql_dataset( + &format!("http://{addr}/graphql"), + "users_unnested", + "query { users { id name posts { id title content } } }", + "/data/users", + Some(1), + )) + .build(); + + configure_test_datafusion(); + let mut rt = Runtime::builder().with_app(app).build().await; + let cloned_rt = Arc::new(rt.clone()); + + tokio::select! { + () = tokio::time::sleep(std::time::Duration::from_secs(60)) => { + return Err("Timed out waiting for datasets to load".to_string()); + } + () = cloned_rt.load_components() => {} + } + + // Test flat query with /data/users pointer returns 4 users with id and name + run_query_and_check_results( + &mut rt, + "test_graphql_json_pointer_flat", + "SELECT id, name FROM users_flat", + false, + Some(Box::new(|result_batches: Vec| { + let total_rows: usize = + result_batches.iter().map(RecordBatch::num_rows).sum(); + assert_eq!(total_rows, 4, "Expected 4 users from flat query, got {total_rows}"); + for batch in result_batches { + assert_eq!( + batch.num_columns(), + 2, + "Expected 2 columns (id, name), got {}", + batch.num_columns() + ); + } + })), + ) + .await?; + + // Test nested pointer /data/paginatedUsers/users with pagination + run_query_and_check_results( + &mut rt, + "test_graphql_json_pointer_nested", + "SELECT * FROM users_nested", + false, + Some(Box::new(|result_batches: Vec| { + let total_rows: usize = + result_batches.iter().map(RecordBatch::num_rows).sum(); + assert_eq!(total_rows, 4, "Expected 4 users from nested paginated query, got {total_rows}"); + for batch in &result_batches { + assert!( + batch.num_columns() >= 2, + "Expected at least 2 columns (id, name), got {}", + batch.num_columns() + ); + } + })), + ) + .await?; + + // Test unnesting with depth=1 flattens post fields to top level + run_query_and_check_results( + &mut rt, + "test_graphql_json_pointer_unnested", + "SELECT * FROM users_unnested", + false, + Some(Box::new(|result_batches: Vec| { + let total_rows: usize = + result_batches.iter().map(RecordBatch::num_rows).sum(); + // With unnest_depth=1, nested posts objects are flattened + assert!(total_rows > 0, "Expected results from unnested query, got {total_rows}"); + })), + ) + .await?; + + tx.send(()).map_err(|()| { + tracing::error!("Failed to send shutdown signal"); + "Failed to send shutdown signal".to_string() + })?; + + Ok(()) + }) + .await +} From 82d1db0801907d23d35bb41e131e6b0e9cf86b31 Mon Sep 17 00:00:00 2001 From: Viktor Yershov Date: Mon, 23 Mar 2026 15:30:34 -0700 Subject: [PATCH 2/3] Fix ADBC Catalog (#9900) * Fix ADBC Catalog * fix * Fix * Fix * Add more basic integration tests for distributed accelerations (#9740) * more basic tests for distributed accelerations * Refactor distributed acceleration tests to use async file operations and improve wait logic * Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> * Improve wait_for_row_count robustness and fix doc comment * Fix lint errors and update inline snapshots for distributed acceleration tests - Replace #[allow] with #[expect] (clippy::allow-attributes) - Add backticks around TEST_DATA_CSV in doc comment (clippy::doc-markdown) - Replace match with if let for single-pattern destructuring (clippy::single-match) - Add #[expect(clippy::cast_possible_truncation)] for i64-to-usize cast - Update inline snapshots to include bucket partition filters in plans * formatting * fix: use r#"..."# raw strings for insta snapshots containing double quotes Entire-Checkpoint: 7f5359852068 * fix snapshots --------- Co-authored-by: Luke Kim <80174+lukekim@users.noreply.github.com> Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> * Refactor ADBC functions to be crate-private and update tests for driver options * Refactor import statements in adbc.rs --------- Co-authored-by: Jack Eadie Co-authored-by: Luke Kim <80174+lukekim@users.noreply.github.com> Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- crates/runtime/src/catalogconnector/adbc.rs | 63 ++++++++------------- crates/runtime/src/dataconnector/adbc.rs | 4 +- 2 files changed, 26 insertions(+), 41 deletions(-) diff --git a/crates/runtime/src/catalogconnector/adbc.rs b/crates/runtime/src/catalogconnector/adbc.rs index 1e702b95ba..afc94102ec 100644 --- a/crates/runtime/src/catalogconnector/adbc.rs +++ b/crates/runtime/src/catalogconnector/adbc.rs @@ -20,8 +20,9 @@ limitations under the License. //! and provides schema/table discovery using the ADBC metadata API. use super::{CatalogConnector, ConnectorComponent, ParameterSpec}; +use crate::dataconnector::adbc::{build_db_options, dialect_for_driver}; use crate::{Runtime, component::catalog::Catalog, dataconnector::parameters::ConnectorParams}; -use adbc_core::options::{AdbcVersion, OptionDatabase}; +use adbc_core::options::AdbcVersion; use adbc_core::{Driver as _, LOAD_FLAG_DEFAULT}; use adbc_driver_manager::{ManagedDatabase, ManagedDriver}; use async_trait::async_trait; @@ -56,6 +57,9 @@ pub const PARAMETERS: &[ParameterSpec] = &[ ParameterSpec::component("password") .description("Password for database authentication") .secret(), + ParameterSpec::component("driver_options").description( + "Semicolon-delimited driver-specific database options (e.g., 'key1=value1;key2=value2')", + ), ParameterSpec::runtime("connection_pool_size") .description("The maximum number of connections in the connection pool.") .default("5"), @@ -152,7 +156,7 @@ impl CatalogConnector for AdbcCatalog { ) -> super::Result> { let connector_component = ConnectorComponent::from(catalog); - let pool = create_pool(&self.params).await.map_err(|e| { + let (driver_name, pool) = create_pool(&self.params).await.map_err(|e| { super::Error::UnableToGetCatalogProvider { connector: PREFIX.to_string(), connector_component: connector_component.clone(), @@ -166,6 +170,7 @@ impl CatalogConnector for AdbcCatalog { pool, table_factory, catalog.include.clone(), + driver_name, )); provider @@ -181,24 +186,8 @@ impl CatalogConnector for AdbcCatalog { } } -/// Builds the list of ADBC database options from connection parameters. -fn build_db_options( - uri: &str, - username: Option<&str>, - password: Option<&str>, -) -> Vec<(OptionDatabase, adbc_core::options::OptionValue)> { - let mut opts = vec![(OptionDatabase::Uri, uri.into())]; - if let Some(u) = username { - opts.push((OptionDatabase::Username, u.into())); - } - if let Some(p) = password { - opts.push((OptionDatabase::Password, p.into())); - } - opts -} - /// Creates an ADBC connection pool from connector parameters. -async fn create_pool(params: &ConnectorParams) -> Result>> { +async fn create_pool(params: &ConnectorParams) -> Result<(String, Arc>)> { let driver_name = params .parameters .get("driver") @@ -220,7 +209,8 @@ async fn create_pool(params: &ConnectorParams) -> Result Result> { match params.parameters.get(name).expose().ok() { @@ -243,6 +233,7 @@ async fn create_pool(params: &ConnectorParams) -> Result Result Result>> { + tokio::task::spawn_blocking(move || -> Result<(String, Arc>)> { let mut driver = ManagedDriver::load_from_name( &driver_location, None, @@ -279,7 +270,7 @@ async fn create_pool(params: &ConnectorParams) -> Result, schemas: RwLock>>, include: Option>, + driver_name: String, } impl std::fmt::Debug for AdbcCatalogProvider { @@ -310,12 +302,14 @@ impl AdbcCatalogProvider { pool: Arc>, table_factory: AdbcTableFactory, include: Option, + driver_name: String, ) -> Self { Self { pool, table_factory, schemas: RwLock::new(HashMap::new()), include: include.map(Arc::new), + driver_name, } } @@ -390,6 +384,8 @@ impl AdbcCatalogProvider { ) -> HashMap> { let mut tables = HashMap::new(); + let dialect = dialect_for_driver(&self.driver_name); + for table_name in table_names { let schema_with_table = format!("{schema_name}.{table_name}"); if let Some(include) = &self.include @@ -401,7 +397,11 @@ impl AdbcCatalogProvider { let table_ref = TableReference::partial(schema_name.to_owned(), table_name.clone()); - match self.table_factory.table_provider(table_ref, None).await { + match self + .table_factory + .table_provider(table_ref, dialect.clone()) + .await + { Ok(provider) => { tables.insert(table_name, provider); } @@ -508,6 +508,7 @@ mod tests { assert!(param_names.contains(&"uri")); assert!(param_names.contains(&"username")); assert!(param_names.contains(&"password")); + assert!(param_names.contains(&"driver_options")); assert!(param_names.contains(&"connection_pool_size")); assert!(param_names.contains(&"connection_pool_min_idle")); } @@ -520,20 +521,4 @@ mod tests { let err = Error::MissingUri; assert_eq!(err.to_string(), "Missing required parameter: uri"); } - - #[test] - fn test_build_db_options_uri_only() { - let opts = build_db_options("file:test.db", None, None); - assert_eq!(opts.len(), 1); - assert_eq!(opts[0].0, OptionDatabase::Uri); - } - - #[test] - fn test_build_db_options_with_credentials() { - let opts = build_db_options("postgres://host/db", Some("admin"), Some("secret")); - assert_eq!(opts.len(), 3); - assert_eq!(opts[0].0, OptionDatabase::Uri); - assert_eq!(opts[1].0, OptionDatabase::Username); - assert_eq!(opts[2].0, OptionDatabase::Password); - } } diff --git a/crates/runtime/src/dataconnector/adbc.rs b/crates/runtime/src/dataconnector/adbc.rs index a0d917678c..04d8b7688a 100644 --- a/crates/runtime/src/dataconnector/adbc.rs +++ b/crates/runtime/src/dataconnector/adbc.rs @@ -305,7 +305,7 @@ impl DataConnectorFactory for AdbcFactory { } /// Builds the list of ADBC database options from connector parameters. -fn build_db_options( +pub(crate) fn build_db_options( uri: &str, username: Option<&str>, password: Option<&str>, @@ -371,7 +371,7 @@ fn build_conn_options( if opts.is_empty() { None } else { Some(opts) } } -fn dialect_for_driver(driver_name: &str) -> Option> { +pub(crate) fn dialect_for_driver(driver_name: &str) -> Option> { match driver_name { "bigquery" => Some(Arc::new(BigQueryDialect::new())), _ => None, From 1a5ee49fc258e94b8e6c27b98eac9e92de9327f5 Mon Sep 17 00:00:00 2001 From: Luke Kim <80174+lukekim@users.noreply.github.com> Date: Mon, 23 Mar 2026 19:11:09 -0700 Subject: [PATCH 3/3] Add --endpoint flag to spice run with scheme-based routing (#9903) * Add --endpoint flag to spice run with scheme-based routing Add a new --endpoint flag that auto-routes to the appropriate endpoint based on URL scheme: - http:// or https:// -> HTTP endpoint (--http) - grpc:// or grpc+tls:// -> Flight endpoint (--flight) Returns an error if no recognized scheme is provided, or if --endpoint conflicts with an already-specified --http-endpoint or --flight-endpoint. Fixes #9901 * Fix rustfmt formatting in resolve_endpoint --- bin/spice/src/commands/run.rs | 69 +++++++++++++++++++++++++++++++++-- 1 file changed, 65 insertions(+), 4 deletions(-) diff --git a/bin/spice/src/commands/run.rs b/bin/spice/src/commands/run.rs index c923dfe6ac..0db7b45b96 100644 --- a/bin/spice/src/commands/run.rs +++ b/bin/spice/src/commands/run.rs @@ -17,9 +17,11 @@ limitations under the License. //! Run command implementation - starts the Spice runtime. use crate::context::RuntimeContext; -use crate::error::{ChildProcessIdSnafu, Result, RuntimeExecutionSnafu, SignalHandlerSnafu}; +use crate::error::{ + ChildProcessIdSnafu, InvalidArgumentSnafu, Result, RuntimeExecutionSnafu, SignalHandlerSnafu, +}; use clap::Args; -use snafu::{OptionExt, ResultExt}; +use snafu::{OptionExt, ResultExt, ensure}; use std::process::Stdio; /// Arguments for the run command. @@ -44,6 +46,12 @@ Examples: See more at: https://spiceai.org/docs/"# )] pub struct RunArgs { + /// Specifies the runtime endpoint. The scheme determines the endpoint type: + /// http:// or https:// sets the HTTP endpoint, grpc:// or grpc+tls:// sets the Flight endpoint. + /// A scheme is required. + #[arg(long)] + endpoint: Option, + /// Specifies the runtime HTTP endpoint (overrides global --http-endpoint for binding) #[arg(long)] http_endpoint: Option, @@ -72,6 +80,13 @@ pub async fn execute(ctx: &RuntimeContext, args: &RunArgs, verbosity: u8) -> Res .await?; } + // Route --endpoint to the appropriate endpoint based on scheme + let (http_endpoint, flight_endpoint) = resolve_endpoint( + args.endpoint.as_deref(), + args.http_endpoint.as_deref(), + args.flight_endpoint.as_deref(), + )?; + tracing::info!("Spice.ai runtime starting..."); let mut spiced_args = args.args.clone(); @@ -83,7 +98,7 @@ pub async fn execute(ctx: &RuntimeContext, args: &RunArgs, verbosity: u8) -> Res } // Add endpoint flags if specified - if let Some(flight) = &args.flight_endpoint { + if let Some(flight) = &flight_endpoint { spiced_args.push("--flight".to_string()); spiced_args.push(flight.clone()); } @@ -93,7 +108,7 @@ pub async fn execute(ctx: &RuntimeContext, args: &RunArgs, verbosity: u8) -> Res spiced_args.push(metrics.clone()); } - let std_cmd = ctx.get_run_cmd(&spiced_args, args.http_endpoint.as_deref())?; + let std_cmd = ctx.get_run_cmd(&spiced_args, http_endpoint.as_deref())?; // Convert std::process::Command to tokio::process::Command let mut cmd = tokio::process::Command::from(std_cmd); @@ -159,3 +174,49 @@ async fn run_with_signal_forwarding( ) -> Result { child.wait().await.context(RuntimeExecutionSnafu) } + +/// Resolve `--endpoint` into the appropriate HTTP or Flight endpoint based on its URL scheme. +/// +/// Returns `(http_endpoint, flight_endpoint)`. If `--endpoint` is provided, it takes precedence +/// over the corresponding specific endpoint flag. An error is returned if `--endpoint` has no +/// recognized scheme or conflicts with an already-specified endpoint. +fn resolve_endpoint( + endpoint: Option<&str>, + http_endpoint: Option<&str>, + flight_endpoint: Option<&str>, +) -> Result<(Option, Option)> { + let Some(ep) = endpoint else { + return Ok(( + http_endpoint.map(String::from), + flight_endpoint.map(String::from), + )); + }; + + if ep.starts_with("http://") || ep.starts_with("https://") { + ensure!( + http_endpoint.is_none(), + InvalidArgumentSnafu { + message: "--endpoint with http(s):// scheme cannot be combined with --http-endpoint" + } + ); + Ok((Some(ep.to_string()), flight_endpoint.map(String::from))) + } else if ep.starts_with("grpc://") || ep.starts_with("grpc+tls://") { + ensure!( + flight_endpoint.is_none(), + InvalidArgumentSnafu { + message: "--endpoint with grpc:// scheme cannot be combined with --flight-endpoint" + } + ); + let addr = ep + .trim_start_matches("grpc+tls://") + .trim_start_matches("grpc://"); + Ok((http_endpoint.map(String::from), Some(addr.to_string()))) + } else { + Err(InvalidArgumentSnafu { + message: format!( + "Unrecognized scheme in --endpoint '{ep}'. Use http://, https://, grpc://, or grpc+tls://" + ), + } + .build()) + } +}