Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

69 changes: 65 additions & 4 deletions bin/spice/src/commands/run.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@ limitations under the License.
//! Run command implementation - starts the Spice runtime.

use crate::context::RuntimeContext;
use crate::error::{ChildProcessIdSnafu, Result, RuntimeExecutionSnafu, SignalHandlerSnafu};
use crate::error::{
ChildProcessIdSnafu, InvalidArgumentSnafu, Result, RuntimeExecutionSnafu, SignalHandlerSnafu,
};
use clap::Args;
use snafu::{OptionExt, ResultExt};
use snafu::{OptionExt, ResultExt, ensure};
use std::process::Stdio;

/// Arguments for the run command.
Expand All @@ -44,6 +46,12 @@ Examples:
See more at: https://spiceai.org/docs/"#
)]
pub struct RunArgs {
/// Specifies the runtime endpoint. The scheme determines the endpoint type:
/// http:// or https:// sets the HTTP endpoint, grpc:// or grpc+tls:// sets the Flight endpoint.
/// A scheme is required.
#[arg(long)]
endpoint: Option<String>,

/// Specifies the runtime HTTP endpoint (overrides global --http-endpoint for binding)
#[arg(long)]
http_endpoint: Option<String>,
Expand Down Expand Up @@ -72,6 +80,13 @@ pub async fn execute(ctx: &RuntimeContext, args: &RunArgs, verbosity: u8) -> Res
.await?;
}

// Route --endpoint to the appropriate endpoint based on scheme
let (http_endpoint, flight_endpoint) = resolve_endpoint(
args.endpoint.as_deref(),
args.http_endpoint.as_deref(),
args.flight_endpoint.as_deref(),
)?;

tracing::info!("Spice.ai runtime starting...");

let mut spiced_args = args.args.clone();
Expand All @@ -83,7 +98,7 @@ pub async fn execute(ctx: &RuntimeContext, args: &RunArgs, verbosity: u8) -> Res
}

// Add endpoint flags if specified
if let Some(flight) = &args.flight_endpoint {
if let Some(flight) = &flight_endpoint {
spiced_args.push("--flight".to_string());
spiced_args.push(flight.clone());
}
Expand All @@ -93,7 +108,7 @@ pub async fn execute(ctx: &RuntimeContext, args: &RunArgs, verbosity: u8) -> Res
spiced_args.push(metrics.clone());
}

let std_cmd = ctx.get_run_cmd(&spiced_args, args.http_endpoint.as_deref())?;
let std_cmd = ctx.get_run_cmd(&spiced_args, http_endpoint.as_deref())?;

// Convert std::process::Command to tokio::process::Command
let mut cmd = tokio::process::Command::from(std_cmd);
Expand Down Expand Up @@ -159,3 +174,49 @@ async fn run_with_signal_forwarding(
) -> Result<std::process::ExitStatus> {
child.wait().await.context(RuntimeExecutionSnafu)
}

/// Resolve `--endpoint` into the appropriate HTTP or Flight endpoint based on its URL scheme.
///
/// Returns `(http_endpoint, flight_endpoint)`. If `--endpoint` is provided, it takes precedence
/// over the corresponding specific endpoint flag. An error is returned if `--endpoint` has no
/// recognized scheme or conflicts with an already-specified endpoint.
fn resolve_endpoint(
endpoint: Option<&str>,
http_endpoint: Option<&str>,
flight_endpoint: Option<&str>,
) -> Result<(Option<String>, Option<String>)> {
let Some(ep) = endpoint else {
return Ok((
http_endpoint.map(String::from),
flight_endpoint.map(String::from),
));
};

if ep.starts_with("http://") || ep.starts_with("https://") {
ensure!(
http_endpoint.is_none(),
InvalidArgumentSnafu {
message: "--endpoint with http(s):// scheme cannot be combined with --http-endpoint"
}
);
Ok((Some(ep.to_string()), flight_endpoint.map(String::from)))
} else if ep.starts_with("grpc://") || ep.starts_with("grpc+tls://") {
ensure!(
flight_endpoint.is_none(),
InvalidArgumentSnafu {
message: "--endpoint with grpc:// scheme cannot be combined with --flight-endpoint"
}
);
let addr = ep
.trim_start_matches("grpc+tls://")
.trim_start_matches("grpc://");
Ok((http_endpoint.map(String::from), Some(addr.to_string())))
} else {
Err(InvalidArgumentSnafu {
message: format!(
"Unrecognized scheme in --endpoint '{ep}'. Use http://, https://, grpc://, or grpc+tls://"
),
}
.build())
}
}
1 change: 1 addition & 0 deletions crates/data-connectors/connector-graphql/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ data_components = { path = "../../data_components" }
datafusion.workspace = true
linkme.workspace = true
paste.workspace = true
reqwest.workspace = true
runtime = { path = "../../runtime" }
snafu.workspace = true
token_provider = { path = "../../token_provider" }
Expand Down
20 changes: 20 additions & 0 deletions crates/data-connectors/connector-graphql/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ impl GraphQLFactory {

const PARAMETERS: &[ParameterSpec] = &[
// Connector parameters
ParameterSpec::component("auth_header")
.description("A custom header name to use for authentication instead of the default 'Authorization: Bearer' header. When set, the value of 'auth_token' is sent as the value of this header."),
ParameterSpec::component("auth_token")
.description("The bearer token to use in the GraphQL requests.")
.secret(),
Expand Down Expand Up @@ -104,6 +106,23 @@ impl GraphQL {
Arc::new(StaticTokenProvider::new(token.clone())) as Arc<dyn TokenProvider>
});

let auth_header = self
.params
.get("auth_header")
.expose()
.ok()
.map(|h| {
reqwest::header::HeaderName::try_from(h).map_err(|source| {
DataConnectorError::InvalidConfiguration {
dataconnector: "graphql".to_string(),
message: format!("Invalid 'graphql_auth_header' value: '{h}'. Ensure it is a valid HTTP header name. For details, visit: https://spiceai.org/docs/components/data-connectors/graphql"),
connector_component: ConnectorComponent::from(dataset),
source: source.into(),
}
})
})
.transpose()?;

let user = self
.params
.get("auth_user")
Expand Down Expand Up @@ -163,6 +182,7 @@ impl GraphQL {
None,
None,
None,
auth_header,
)
.boxed()
.map_err(|source| DataConnectorError::InternalWithSource {
Expand Down
9 changes: 9 additions & 0 deletions crates/data_components/src/graphql/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ pub struct GraphQLClientBuilder {
rate_limiter: Option<Arc<dyn RateLimiter>>,
rate_controller: Option<Arc<RateController>>,
semaphore: Option<Arc<Semaphore>>,
auth_header: Option<reqwest::header::HeaderName>,
}

impl GraphQLClientBuilder {
Expand All @@ -52,6 +53,7 @@ impl GraphQLClientBuilder {
rate_limiter: None,
rate_controller: None,
semaphore: None,
auth_header: None,
}
}

Expand Down Expand Up @@ -103,6 +105,12 @@ impl GraphQLClientBuilder {
self
}

#[must_use]
pub fn with_auth_header(mut self, auth_header: Option<reqwest::header::HeaderName>) -> Self {
self.auth_header = auth_header;
self
}

pub fn build(self, client: reqwest::Client) -> Result<GraphQLClient> {
GraphQLClient::new(
client,
Expand All @@ -116,6 +124,7 @@ impl GraphQLClientBuilder {
self.rate_limiter,
self.rate_controller,
self.semaphore,
self.auth_header,
)
}
}
Loading
Loading