Skip to content

Commit bea56af

Browse files
authored
Allow flightsql and spiceai connectors to override flight max message size (spiceai#5407)
* Allow `flightsql` and `spiceai` connectors to override default flight max message size * Update
1 parent b1d37f3 commit bea56af

9 files changed

Lines changed: 116 additions & 6 deletions

File tree

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ base64 = "0.22.1"
6464
bb8 = "0.8"
6565
bb8-postgres = "0.8"
6666
bytes = "1.10.0"
67+
byte-unit = "5.1.4"
6768
charset = "0.1.5"
6869
chrono = "0.4.38"
6970
clap = { version = "4.5.36", features = ["derive", "env"] }

crates/cache/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ version.workspace = true
1212
arrow.workspace = true
1313
async-stream.workspace = true
1414
async-trait.workspace = true
15-
byte-unit = "5.1.4"
15+
byte-unit.workspace = true
1616
datafusion.workspace = true
1717
fundu = { workspace = true }
1818
futures.workspace = true

crates/flight_client/src/lib.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,20 @@ impl FlightClient {
246246
self
247247
}
248248

249+
/// Overrides default maximum message size for encoding and decoding.
250+
#[must_use]
251+
pub fn with_max_message_size(
252+
mut self,
253+
max_encoding_message_size: usize,
254+
max_decoding_message_size: usize,
255+
) -> Self {
256+
self.flight_client = self
257+
.flight_client
258+
.max_encoding_message_size(max_encoding_message_size)
259+
.max_decoding_message_size(max_decoding_message_size);
260+
self
261+
}
262+
249263
/// Queries the flight service for the schema of the path.
250264
///
251265
/// # Arguments

crates/runtime/src/dataconnector.rs

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ use crate::get_params_with_secrets;
2727
use crate::parameters::ParameterSpec;
2828
use crate::parameters::Parameters;
2929
use crate::secrets::Secrets;
30+
use app::App;
3031
use arrow_schema::SchemaRef;
3132
use arrow_tools::schema::schema_meta_get_computed_columns;
3233
use async_trait::async_trait;
@@ -579,6 +580,7 @@ pub struct ConnectorParams {
579580
pub(crate) parameters: Parameters,
580581
pub(crate) unsupported_type_action: Option<UnsupportedTypeAction>,
581582
pub(crate) component: ConnectorComponent,
583+
pub(crate) app: Option<Arc<App>>,
582584
}
583585

584586
pub struct ConnectorParamsBuilder {
@@ -601,7 +603,7 @@ impl ConnectorParamsBuilder {
601603
) -> Result<ConnectorParams, Box<dyn std::error::Error + Send + Sync>> {
602604
let name = self.connector.to_string();
603605
let mut unsupported_type_action = None;
604-
let (params, prefix, parameters) = match &self.component {
606+
let (params, prefix, parameters, app) = match &self.component {
605607
ConnectorComponent::Catalog(catalog) => {
606608
let guard = CATALOG_CONNECTOR_FACTORY_REGISTRY.lock().await;
607609
let connector_factory = guard.get(&name);
@@ -616,6 +618,7 @@ impl ConnectorParamsBuilder {
616618
get_params_with_secrets(Arc::clone(&secrets), &catalog.params).await,
617619
factory.prefix(),
618620
factory.parameters(),
621+
None,
619622
)
620623
}
621624
ConnectorComponent::Dataset(dataset) => {
@@ -639,7 +642,12 @@ impl ConnectorParamsBuilder {
639642

640643
let params = get_params_with_secrets(Arc::clone(&secrets), &dataset.params).await;
641644

642-
(params, factory.prefix(), factory.parameters())
645+
(
646+
params,
647+
factory.prefix(),
648+
factory.parameters(),
649+
Some(Arc::clone(&dataset.app)),
650+
)
643651
}
644652
};
645653

@@ -656,6 +664,7 @@ impl ConnectorParamsBuilder {
656664
parameters,
657665
unsupported_type_action: unsupported_type_action.map(UnsupportedTypeAction::from),
658666
component: self.component,
667+
app,
659668
})
660669
}
661670
}

crates/runtime/src/dataconnector/flightsql.rs

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,14 @@ pub enum Error {
4343

4444
#[snafu(display("Failed to connect to the Flight server.\n{source}"))]
4545
UnableToPerformHandshake { source: arrow::error::ArrowError },
46+
47+
#[snafu(display(
48+
"Failed to apply parameter '{parameter}': {source}. Ensure the value is valid and retry.\nFor details, visit: https://spiceai.org/docs/components/data-connectors/flightsql#params"
49+
))]
50+
InvalidParameterValue {
51+
parameter: String,
52+
source: Box<dyn std::error::Error + Send + Sync>,
53+
},
4654
}
4755

4856
pub type Result<T, E = Error> = std::result::Result<T, E>;
@@ -95,9 +103,24 @@ impl DataConnectorFactory for FlightSQLFactory {
95103
.await
96104
.context(UnableToConstructTlsChannelSnafu)?;
97105

106+
let max_message_size =
107+
match params
108+
.app
109+
.as_ref()
110+
.and_then(|app| app.runtime.flight.as_ref())
111+
{
112+
Some(flight) => flight.max_message_size_bytes().map_err(|err| {
113+
Error::InvalidParameterValue {
114+
parameter: "max_message_size".to_string(),
115+
source: err,
116+
}
117+
})?,
118+
None => None,
119+
};
120+
98121
let flight_client = FlightServiceClient::new(flight_channel)
99-
.max_encoding_message_size(MAX_ENCODING_MESSAGE_SIZE)
100-
.max_decoding_message_size(MAX_DECODING_MESSAGE_SIZE);
122+
.max_encoding_message_size(max_message_size.unwrap_or(MAX_ENCODING_MESSAGE_SIZE))
123+
.max_decoding_message_size(max_message_size.unwrap_or(MAX_DECODING_MESSAGE_SIZE));
101124

102125
let mut client = FlightSqlServiceClient::new_from_inner(flight_client);
103126
let username = params.parameters.get("username").expose().ok();

crates/runtime/src/dataconnector/spiceai.rs

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,14 @@ pub enum Error {
7272
value: Arc<str>,
7373
source: InvalidMetadataValue,
7474
},
75+
76+
#[snafu(display(
77+
"Failed to apply parameter '{parameter}': {source}. Ensure the value is valid and retry.\nFor details, visit: https://spiceai.org/docs/components/data-connectors/spiceai#parameters"
78+
))]
79+
InvalidParameterValue {
80+
parameter: String,
81+
source: Box<dyn std::error::Error + Send + Sync>,
82+
},
7583
}
7684

7785
pub type Result<T, E = Error> = std::result::Result<T, E>;
@@ -188,9 +196,12 @@ impl DataConnectorFactory for SpiceAIFactory {
188196
.ok_or_else(|p| MissingRequiredParameterSnafu { parameter: p.0 }.build())?;
189197
let credentials = Credentials::new("", api_key.clone());
190198

191-
let flight_client = FlightClient::try_new(url, credentials, None)
199+
let mut flight_client = FlightClient::try_new(url, credentials, None)
192200
.await
193201
.context(UnableToCreateFlightClientSnafu)?;
202+
203+
flight_client = configure_max_message_size(flight_client, &params)?;
204+
194205
let flight_factory = FlightFactory::new(
195206
"spice.ai",
196207
flight_client,
@@ -211,6 +222,29 @@ impl DataConnectorFactory for SpiceAIFactory {
211222
}
212223
}
213224

225+
/// Configures flight client's message size based on app parameters
226+
fn configure_max_message_size(
227+
mut flight_client: FlightClient,
228+
params: &ConnectorParams,
229+
) -> Result<FlightClient> {
230+
if let Some(app) = params.app.as_ref() {
231+
if let Some(flight) = app.runtime.flight.as_ref() {
232+
if let Some(max_message_size) =
233+
flight
234+
.max_message_size_bytes()
235+
.map_err(|err| Error::InvalidParameterValue {
236+
parameter: "max_message_size".to_string(),
237+
source: err,
238+
})?
239+
{
240+
flight_client =
241+
flight_client.with_max_message_size(max_message_size, max_message_size);
242+
}
243+
}
244+
}
245+
Ok(flight_client)
246+
}
247+
214248
#[async_trait]
215249
impl DataConnector for SpiceAI {
216250
fn as_any(&self) -> &dyn Any {

crates/spicepod/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ rust-version.workspace = true
88
version.workspace = true
99

1010
[dependencies]
11+
byte-unit.workspace = true
1112
fundu.workspace = true
1213
regex.workspace = true
1314
schemars = { version = "0.8.22", optional = true }

crates/spicepod/src/component/runtime.rs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@ pub struct Runtime {
5656
#[serde(default, skip_serializing_if = "is_default")]
5757
pub cors: CorsConfig,
5858

59+
#[serde(skip_serializing_if = "Option::is_none")]
60+
pub flight: Option<Flight>,
61+
5962
/// Configures where the runtime will store temporary files needed for operations like
6063
/// spilling to disk for queries & accelerations that are larger than memory.
6164
#[serde(default, skip_serializing_if = "Option::is_none")]
@@ -183,6 +186,30 @@ impl Default for TelemetryConfig {
183186
}
184187
}
185188

189+
#[derive(Default, Debug, Clone, Serialize, Deserialize, PartialEq)]
190+
#[cfg_attr(feature = "schemars", derive(JsonSchema))]
191+
pub struct Flight {
192+
pub max_message_size: Option<String>,
193+
}
194+
195+
impl Flight {
196+
pub fn max_message_size_bytes(&self) -> Result<Option<usize>, Box<dyn Error + Send + Sync>> {
197+
if let Some(size_str) = &self.max_message_size {
198+
let size_in_bytes = usize::try_from(
199+
byte_unit::Byte::parse_str(size_str, true)
200+
.map_err(|e| {
201+
format!("Failed to parse 'max_message_size' value '{size_str}': {e}")
202+
})?
203+
.as_u64(),
204+
)
205+
.unwrap_or_default();
206+
Ok(Some(size_in_bytes))
207+
} else {
208+
Ok(None)
209+
}
210+
}
211+
}
212+
186213
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
187214
#[serde(deny_unknown_fields)]
188215
#[cfg_attr(feature = "schemars", derive(JsonSchema))]

0 commit comments

Comments
 (0)