Skip to content

Commit 12957e4

Browse files
authored
Scheduler UDF sync -> client planning with stubs (#4)
* [1/?] attempt to serialize bare min data for stub udf planning in client * [2/?] fix return type issue, basic case working * fix error handling, serialize fn arguments for docs, add note about return field * tomlfmt * lint * more lintg
1 parent 336bb6e commit 12957e4

11 files changed

Lines changed: 488 additions & 18 deletions

File tree

ballista/client/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ datafusion = { workspace = true }
3636
log = { workspace = true }
3737

3838
tokio = { workspace = true }
39-
url = { workspace = true }
4039
tonic = { workspace = true }
40+
url = { workspace = true }
4141

4242
[dev-dependencies]
4343
ballista-executor = { path = "../executor", version = "50.0.0" }

ballista/client/src/extension.rs

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,15 @@
1616
// under the License.
1717

1818
pub use ballista_core::extension::{SessionConfigExt, SessionStateExt};
19+
use ballista_core::remote_catalog::remote_scalar_udf::RemoteScalarUDF;
1920
use ballista_core::remote_catalog::remote_table_provider::RemoteTableProvider;
2021
use ballista_core::serde::protobuf::scheduler_grpc_client::SchedulerGrpcClient;
21-
use ballista_core::serde::protobuf::GetCatalogParams;
22+
use ballista_core::serde::protobuf::{GetCatalogParams, GetRemoteFunctionsParams};
2223
use datafusion::catalog::{
2324
CatalogProvider, MemoryCatalogProvider, MemorySchemaProvider, SchemaProvider,
2425
};
26+
use datafusion::execution::FunctionRegistry;
27+
use datafusion::logical_expr::ScalarUDF;
2528
use datafusion::{
2629
error::DataFusionError, execution::SessionState, prelude::SessionContext,
2730
};
@@ -99,6 +102,12 @@ pub trait SessionContextExt {
99102
&self,
100103
scheduler_url: &str,
101104
) -> datafusion::error::Result<()>;
105+
106+
/// Populates local context with functions from the scheduler.
107+
async fn populate_functions_from_scheduler(
108+
&self,
109+
scheduler_url: &str,
110+
) -> datafusion::error::Result<()>;
102111
}
103112

104113
#[async_trait::async_trait]
@@ -124,6 +133,8 @@ impl SessionContextExt for SessionContext {
124133

125134
// Populate local catalog from scheduler
126135
ctx.populate_catalog_from_scheduler(&scheduler_url).await?;
136+
ctx.populate_functions_from_scheduler(&scheduler_url)
137+
.await?;
127138

128139
Ok(ctx)
129140
}
@@ -145,6 +156,8 @@ impl SessionContextExt for SessionContext {
145156

146157
// Populate local catalog from scheduler
147158
ctx.populate_catalog_from_scheduler(&scheduler_url).await?;
159+
ctx.populate_functions_from_scheduler(&scheduler_url)
160+
.await?;
148161

149162
Ok(ctx)
150163
}
@@ -259,6 +272,39 @@ impl SessionContextExt for SessionContext {
259272

260273
Ok(())
261274
}
275+
276+
async fn populate_functions_from_scheduler(
277+
&self,
278+
scheduler_url: &str,
279+
) -> datafusion::common::Result<()> {
280+
let mut client = SchedulerGrpcClient::connect(scheduler_url.to_string())
281+
.await
282+
.map_err(|e| {
283+
DataFusionError::External(
284+
format!("Failed to connect to scheduler: {}", e).into(),
285+
)
286+
})?;
287+
288+
let request = tonic::Request::new(GetRemoteFunctionsParams {
289+
session_id: self.state().session_id().to_string(),
290+
});
291+
292+
let response = client.get_remote_functions(request).await.map_err(|e| {
293+
DataFusionError::External(format!("Failed to fetch catalog: {}", e).into())
294+
})?;
295+
296+
let remote_functions = response.into_inner();
297+
298+
for udf in remote_functions.udfs {
299+
if self.state().udf(&udf.name).is_ok() {
300+
continue;
301+
}
302+
303+
self.register_udf(ScalarUDF::new_from_impl(RemoteScalarUDF::new(udf)?))
304+
}
305+
306+
Ok(())
307+
}
262308
}
263309

264310
struct Extension {}

ballista/core/proto/ballista.proto

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -745,6 +745,41 @@ message RemoteTableProviderNode {
745745
datafusion_common.Schema schema = 4;
746746
}
747747

748+
///////////////////////////////////////////////////////////////////////////////////////////////////
749+
// Remote UDF Metadata
750+
///////////////////////////////////////////////////////////////////////////////////////////////////
751+
752+
message GetRemoteFunctionsParams {
753+
string session_id = 1;
754+
}
755+
756+
message GetRemoteFunctionsResult {
757+
repeated ScalarUDFInfo udfs = 1;
758+
}
759+
760+
message ScalarUDFInfo {
761+
string name = 1;
762+
optional ScalarUDFDocumentation documentation = 2;
763+
repeated ScalarUDFTypeSignature signatures = 3;
764+
}
765+
766+
message ScalarUDFTypeSignature {
767+
repeated datafusion_common.ArrowType arity = 1;
768+
datafusion_common.ArrowType return_type = 2;
769+
}
770+
771+
message ScalarUDFDocumentation {
772+
string description = 1;
773+
string syntax_example = 2;
774+
optional string sql_example = 3;
775+
repeated ScalarUDFDocumentationArgument arguments = 4;
776+
}
777+
778+
message ScalarUDFDocumentationArgument {
779+
string argument = 1;
780+
string description = 2;
781+
}
782+
748783
service SchedulerGrpc {
749784
// Executors must poll the scheduler for heartbeat and to receive tasks
750785
rpc PollWork (PollWorkParams) returns (PollWorkResult) {}
@@ -774,6 +809,9 @@ service SchedulerGrpc {
774809

775810
// Get catalog metadata for a session
776811
rpc GetCatalog (GetCatalogParams) returns (GetCatalogResult) {}
812+
813+
// Get catalog metadata for a session
814+
rpc GetRemoteFunctions (GetRemoteFunctionsParams) returns (GetRemoteFunctionsResult) {}
777815
}
778816

779817
service ExecutorGrpc {

ballista/core/src/remote_catalog/catalog_serialize_ext.rs

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -42,15 +42,13 @@ impl CatalogSerializeExt for SessionContext {
4242
let catalog_names = self.state().catalog_list().catalog_names();
4343

4444
stream::iter(catalog_names.iter())
45-
.filter_map(|catalog_name| self.serialize_catalog(&catalog_name))
45+
.filter_map(|catalog_name| self.serialize_catalog(catalog_name))
4646
.collect::<Vec<_>>()
4747
.await
4848
}
4949

5050
async fn serialize_catalog(&self, name: &str) -> Option<CatalogInfo> {
51-
let Some(catalog) = self.catalog(name) else {
52-
return None;
53-
};
51+
let catalog = self.catalog(name)?;
5452

5553
let schemas = stream::iter(catalog.schema_names().iter())
5654
.filter_map(|schema_name| self.serialize_schema(schema_name, &catalog))
@@ -68,9 +66,7 @@ impl CatalogSerializeExt for SessionContext {
6866
name: &str,
6967
catalog: &Arc<dyn CatalogProvider>,
7068
) -> Option<SchemaInfo> {
71-
let Some(schema) = catalog.schema(name) else {
72-
return None;
73-
};
69+
let schema = catalog.schema(name)?;
7470

7571
let tables = stream::iter(schema.table_names())
7672
.filter_map(|table_name| async {

ballista/core/src/remote_catalog/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,6 @@
1717
//
1818

1919
pub mod catalog_serialize_ext;
20+
pub mod remote_function_serialize_ext;
21+
pub mod remote_scalar_udf;
2022
pub mod remote_table_provider;
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use crate::serde::protobuf::{
19+
ScalarUdfDocumentation, ScalarUdfDocumentationArgument, ScalarUdfInfo,
20+
ScalarUdfTypeSignature,
21+
};
22+
use datafusion::execution::FunctionRegistry;
23+
use datafusion::functions::all_default_functions;
24+
use datafusion::prelude::SessionContext;
25+
use datafusion_proto_common::ArrowType;
26+
use std::collections::HashSet;
27+
28+
/// Used to serialize function shapes to ship to Ballista clients
29+
pub trait RemoteFunctionSerializeExt {
30+
fn serialize_udfs(&self) -> Vec<ScalarUdfInfo>;
31+
}
32+
33+
impl RemoteFunctionSerializeExt for SessionContext {
34+
fn serialize_udfs(&self) -> Vec<ScalarUdfInfo> {
35+
let mut udfs = vec![];
36+
37+
let skip = all_default_functions()
38+
.iter()
39+
.map(|f| f.name().to_string())
40+
.collect::<HashSet<_>>();
41+
42+
for udf in self.udfs() {
43+
if skip.contains(&udf) {
44+
continue;
45+
}
46+
47+
let f = self.udf(&udf).expect("Must find defined UDF");
48+
let signature = f.signature();
49+
let signatures = signature
50+
.type_signature
51+
.get_example_types()
52+
.into_iter()
53+
.filter_map(|t| {
54+
let arity = t
55+
.iter()
56+
.map(TryInto::try_into)
57+
.collect::<Result<Vec<ArrowType>, _>>()
58+
.expect("Must serialize data types");
59+
60+
// TODO: some functions use `ScalarUDF::return_field_from_args`, which this does not support
61+
f.return_type(&t)
62+
.ok()
63+
.and_then(|ref return_type| return_type.try_into().ok())
64+
.map(|arrow_return_type| ScalarUdfTypeSignature {
65+
arity,
66+
return_type: Some(arrow_return_type),
67+
})
68+
})
69+
.collect::<Vec<_>>();
70+
71+
let docs = f.documentation().map(|d| {
72+
let arguments = d
73+
.arguments
74+
.iter()
75+
.flatten()
76+
.map(|(arg, desc)| ScalarUdfDocumentationArgument {
77+
argument: arg.clone(),
78+
description: desc.clone(),
79+
})
80+
.collect::<Vec<_>>();
81+
82+
ScalarUdfDocumentation {
83+
description: d.description.clone(),
84+
syntax_example: d.syntax_example.clone(),
85+
sql_example: d.sql_example.clone(),
86+
arguments,
87+
}
88+
});
89+
90+
udfs.push(ScalarUdfInfo {
91+
name: f.name().to_string(),
92+
documentation: docs,
93+
signatures,
94+
});
95+
}
96+
97+
udfs
98+
}
99+
}

0 commit comments

Comments
 (0)