Skip to content

Commit bc73a75

Browse files
mach-kernellukekim
authored andcommitted
Scheduler UDF sync -> client planning with stubs (#4)
* [1/?] attempt to serialize bare min data for stub udf planning in client * [2/?] fix return type issue, basic case working * fix error handling, serialize fn arguments for docs, add note about return field * tomlfmt * lint * more lint
1 parent 9b60af1 commit bc73a75

13 files changed

Lines changed: 1172 additions & 22 deletions

File tree

ballista/client/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ datafusion = { workspace = true }
3636
log = { workspace = true }
3737

3838
tokio = { workspace = true }
39+
tonic = { workspace = true }
3940
url = { workspace = true }
4041

4142
[dev-dependencies]

ballista/client/src/extension.rs

Lines changed: 152 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,19 @@
1616
// under the License.
1717

1818
pub use ballista_core::extension::{SessionConfigExt, SessionStateExt};
19+
use ballista_core::remote_catalog::remote_scalar_udf::RemoteScalarUDF;
20+
use ballista_core::remote_catalog::remote_table_provider::RemoteTableProvider;
1921
use ballista_core::serde::protobuf::scheduler_grpc_client::SchedulerGrpcClient;
22+
use ballista_core::serde::protobuf::{GetCatalogParams, GetRemoteFunctionsParams};
23+
use datafusion::catalog::{
24+
CatalogProvider, MemoryCatalogProvider, MemorySchemaProvider, SchemaProvider,
25+
};
26+
use datafusion::execution::FunctionRegistry;
27+
use datafusion::logical_expr::ScalarUDF;
2028
use datafusion::{
2129
error::DataFusionError, execution::SessionState, prelude::SessionContext,
2230
};
31+
use std::sync::Arc;
2332
use url::Url;
2433

2534
const DEFAULT_SCHEDULER_PORT: u16 = 50050;
@@ -86,6 +95,19 @@ pub trait SessionContextExt {
8695
url: &str,
8796
state: SessionState,
8897
) -> datafusion::error::Result<SessionContext>;
98+
99+
/// Populates the local catalog with metadata from the remote scheduler.
100+
/// This allows catalog queries like SHOW TABLES to work on the client.
101+
async fn populate_catalog_from_scheduler(
102+
&self,
103+
scheduler_url: &str,
104+
) -> datafusion::error::Result<()>;
105+
106+
/// Populates local context with functions from the scheduler.
107+
async fn populate_functions_from_scheduler(
108+
&self,
109+
scheduler_url: &str,
110+
) -> datafusion::error::Result<()>;
89111
}
90112

91113
#[async_trait::async_trait]
@@ -100,14 +122,21 @@ impl SessionContextExt for SessionContext {
100122
scheduler_url.clone()
101123
);
102124

103-
let session_state = state.upgrade_for_ballista(scheduler_url)?;
125+
let session_state = state.upgrade_for_ballista(scheduler_url.clone())?;
104126

105127
log::info!(
106128
"Server side SessionContext created with session id: {}",
107129
session_state.session_id()
108130
);
109131

110-
Ok(SessionContext::new_with_state(session_state))
132+
let ctx = SessionContext::new_with_state(session_state);
133+
134+
// Populate local catalog from scheduler
135+
ctx.populate_catalog_from_scheduler(&scheduler_url).await?;
136+
ctx.populate_functions_from_scheduler(&scheduler_url)
137+
.await?;
138+
139+
Ok(ctx)
111140
}
112141

113142
async fn remote(url: &str) -> datafusion::error::Result<SessionContext> {
@@ -117,13 +146,20 @@ impl SessionContextExt for SessionContext {
117146
scheduler_url.clone()
118147
);
119148

120-
let session_state = SessionState::new_ballista_state(scheduler_url)?;
149+
let session_state = SessionState::new_ballista_state(scheduler_url.clone())?;
121150
log::info!(
122151
"Server side SessionContext created with session id: {}",
123152
session_state.session_id()
124153
);
125154

126-
Ok(SessionContext::new_with_state(session_state))
155+
let ctx = SessionContext::new_with_state(session_state);
156+
157+
// Populate local catalog from scheduler
158+
ctx.populate_catalog_from_scheduler(&scheduler_url).await?;
159+
ctx.populate_functions_from_scheduler(&scheduler_url)
160+
.await?;
161+
162+
Ok(ctx)
127163
}
128164

129165
#[cfg(feature = "standalone")]
@@ -157,6 +193,118 @@ impl SessionContextExt for SessionContext {
157193

158194
Ok(SessionContext::new_with_state(session_state))
159195
}
196+
197+
async fn populate_catalog_from_scheduler(
198+
&self,
199+
scheduler_url: &str,
200+
) -> datafusion::error::Result<()> {
201+
let mut client = SchedulerGrpcClient::connect(scheduler_url.to_string())
202+
.await
203+
.map_err(|e| {
204+
DataFusionError::External(
205+
format!("Failed to connect to scheduler: {}", e).into(),
206+
)
207+
})?;
208+
209+
let request = tonic::Request::new(GetCatalogParams {
210+
session_id: self.state().session_id().to_string(),
211+
});
212+
213+
let response = client.get_catalog(request).await.map_err(|e| {
214+
DataFusionError::External(format!("Failed to fetch catalog: {}", e).into())
215+
})?;
216+
217+
let catalog_result = response.into_inner();
218+
219+
log::info!(
220+
"Received {} catalogs from scheduler",
221+
catalog_result.catalogs.len()
222+
);
223+
224+
for catalog_info in catalog_result.catalogs {
225+
let catalog_name = catalog_info.catalog_name;
226+
227+
let catalog: Arc<dyn CatalogProvider> =
228+
if let Some(_existing_catalog) = self.catalog(&catalog_name) {
229+
continue;
230+
} else {
231+
let new_catalog: Arc<dyn CatalogProvider> =
232+
Arc::new(MemoryCatalogProvider::new());
233+
self.register_catalog(&catalog_name, Arc::clone(&new_catalog));
234+
new_catalog
235+
};
236+
237+
for schema_info in catalog_info.schemas {
238+
let schema_name = schema_info.schema_name;
239+
let schema = if let Some(_existing_schema) = catalog.schema(&schema_name)
240+
{
241+
continue;
242+
} else {
243+
let new_schema: Arc<dyn SchemaProvider> =
244+
Arc::new(MemorySchemaProvider::new());
245+
catalog.register_schema(&schema_name, Arc::clone(&new_schema))?;
246+
new_schema
247+
};
248+
249+
// Bind `RemoteTableProvider` tables
250+
for table_info in schema_info.tables {
251+
if let Some(proto_schema) = table_info.schema {
252+
let arrow_schema: datafusion::arrow::datatypes::Schema =
253+
(&proto_schema)
254+
.try_into()
255+
.map_err(|e| DataFusionError::External(Box::new(e)))?;
256+
257+
let stub_table = RemoteTableProvider::new(
258+
&catalog_name,
259+
&schema_name,
260+
&table_info.table_name,
261+
Arc::new(arrow_schema),
262+
);
263+
264+
schema.register_table(
265+
table_info.table_name.clone(),
266+
Arc::new(stub_table),
267+
)?;
268+
}
269+
}
270+
}
271+
}
272+
273+
Ok(())
274+
}
275+
276+
async fn populate_functions_from_scheduler(
277+
&self,
278+
scheduler_url: &str,
279+
) -> datafusion::common::Result<()> {
280+
let mut client = SchedulerGrpcClient::connect(scheduler_url.to_string())
281+
.await
282+
.map_err(|e| {
283+
DataFusionError::External(
284+
format!("Failed to connect to scheduler: {}", e).into(),
285+
)
286+
})?;
287+
288+
let request = tonic::Request::new(GetRemoteFunctionsParams {
289+
session_id: self.state().session_id().to_string(),
290+
});
291+
292+
let response = client.get_remote_functions(request).await.map_err(|e| {
293+
DataFusionError::External(format!("Failed to fetch catalog: {}", e).into())
294+
})?;
295+
296+
let remote_functions = response.into_inner();
297+
298+
for udf in remote_functions.udfs {
299+
if self.state().udf(&udf.name).is_ok() {
300+
continue;
301+
}
302+
303+
self.register_udf(ScalarUDF::new_from_impl(RemoteScalarUDF::new(udf)?))
304+
}
305+
306+
Ok(())
307+
}
160308
}
161309

162310
struct Extension {}

ballista/core/proto/ballista.proto

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -710,6 +710,75 @@ message RunningTaskInfo {
710710
uint32 partition_id = 4;;
711711
}
712712

713+
///////////////////////////////////////////////////////////////////////////////////////////////////
714+
// Catalog Metadata
715+
///////////////////////////////////////////////////////////////////////////////////////////////////
716+
717+
message GetCatalogParams {
718+
string session_id = 1;
719+
}
720+
721+
message GetCatalogResult {
722+
repeated CatalogInfo catalogs = 1;
723+
}
724+
725+
message CatalogInfo {
726+
string catalog_name = 1;
727+
repeated SchemaInfo schemas = 2;
728+
}
729+
730+
message SchemaInfo {
731+
string schema_name = 1;
732+
repeated TableInfo tables = 2;
733+
}
734+
735+
message TableInfo {
736+
string table_name = 1;
737+
datafusion_common.Schema schema = 2;
738+
}
739+
740+
message RemoteTableProviderNode {
741+
string catalog_name = 1;
742+
string schema_name = 2;
743+
string table_name = 3;
744+
datafusion_common.Schema schema = 4;
745+
}
746+
747+
///////////////////////////////////////////////////////////////////////////////////////////////////
748+
// Remote UDF Metadata
749+
///////////////////////////////////////////////////////////////////////////////////////////////////
750+
751+
message GetRemoteFunctionsParams {
752+
string session_id = 1;
753+
}
754+
755+
message GetRemoteFunctionsResult {
756+
repeated ScalarUDFInfo udfs = 1;
757+
}
758+
759+
message ScalarUDFInfo {
760+
string name = 1;
761+
optional ScalarUDFDocumentation documentation = 2;
762+
repeated ScalarUDFTypeSignature signatures = 3;
763+
}
764+
765+
message ScalarUDFTypeSignature {
766+
repeated datafusion_common.ArrowType arity = 1;
767+
datafusion_common.ArrowType return_type = 2;
768+
}
769+
770+
message ScalarUDFDocumentation {
771+
string description = 1;
772+
string syntax_example = 2;
773+
optional string sql_example = 3;
774+
repeated ScalarUDFDocumentationArgument arguments = 4;
775+
}
776+
777+
message ScalarUDFDocumentationArgument {
778+
string argument = 1;
779+
string description = 2;
780+
}
781+
713782
service SchedulerGrpc {
714783
// Executors must poll the scheduler for heartbeat and to receive tasks
715784
rpc PollWork (PollWorkParams) returns (PollWorkResult) {}
@@ -736,6 +805,12 @@ service SchedulerGrpc {
736805
rpc CancelJob (CancelJobParams) returns (CancelJobResult) {}
737806

738807
rpc CleanJobData (CleanJobDataParams) returns (CleanJobDataResult) {}
808+
809+
// Get catalog metadata for a session
810+
rpc GetCatalog (GetCatalogParams) returns (GetCatalogResult) {}
811+
812+
// Get remote function (scalar UDF) metadata for a session
813+
rpc GetRemoteFunctions (GetRemoteFunctionsParams) returns (GetRemoteFunctionsResult) {}
739814
}
740815

741816
service ExecutorGrpc {
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use crate::serde::generated::ballista::{CatalogInfo, SchemaInfo, TableInfo};
19+
use datafusion::catalog::CatalogProvider;
20+
use datafusion::prelude::SessionContext;
21+
use futures::stream;
22+
use futures::StreamExt;
23+
use std::sync::Arc;
24+
25+
/// Used to serialize catalog schemas and names to ship to Ballista clients
26+
#[async_trait::async_trait]
27+
pub trait CatalogSerializeExt {
28+
async fn serialize_catalogs(&self) -> Vec<CatalogInfo>;
29+
30+
async fn serialize_catalog(&self, name: &str) -> Option<CatalogInfo>;
31+
32+
async fn serialize_schema(
33+
&self,
34+
name: &str,
35+
catalog: &Arc<dyn CatalogProvider>,
36+
) -> Option<SchemaInfo>;
37+
}
38+
39+
#[async_trait::async_trait]
40+
impl CatalogSerializeExt for SessionContext {
41+
async fn serialize_catalogs(&self) -> Vec<CatalogInfo> {
42+
let catalog_names = self.state().catalog_list().catalog_names();
43+
44+
stream::iter(catalog_names.iter())
45+
.filter_map(|catalog_name| self.serialize_catalog(catalog_name))
46+
.collect::<Vec<_>>()
47+
.await
48+
}
49+
50+
async fn serialize_catalog(&self, name: &str) -> Option<CatalogInfo> {
51+
let catalog = self.catalog(name)?;
52+
53+
let schemas = stream::iter(catalog.schema_names().iter())
54+
.filter_map(|schema_name| self.serialize_schema(schema_name, &catalog))
55+
.collect::<Vec<_>>()
56+
.await;
57+
58+
Some(CatalogInfo {
59+
catalog_name: name.to_string(),
60+
schemas,
61+
})
62+
}
63+
64+
async fn serialize_schema(
65+
&self,
66+
name: &str,
67+
catalog: &Arc<dyn CatalogProvider>,
68+
) -> Option<SchemaInfo> {
69+
let schema = catalog.schema(name)?;
70+
71+
let tables = stream::iter(schema.table_names())
72+
.filter_map(|table_name| async {
73+
let table_lookup = schema.table(&table_name).await;
74+
table_lookup.ok().and_then(|maybe_table| {
75+
maybe_table.map(|provider| (table_name, provider))
76+
})
77+
})
78+
.map(|(table_name, provider)| TableInfo {
79+
table_name,
80+
schema: Some(
81+
provider
82+
.schema()
83+
.as_ref()
84+
.try_into()
85+
.expect("Must serialize schema"),
86+
),
87+
})
88+
.collect::<Vec<_>>()
89+
.await;
90+
91+
Some(SchemaInfo {
92+
schema_name: name.to_string(),
93+
tables,
94+
})
95+
}
96+
}

0 commit comments

Comments
 (0)