Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ballista/client/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ datafusion = { workspace = true }
log = { workspace = true }

tokio = { workspace = true }
url = { workspace = true }
tonic = { workspace = true }
url = { workspace = true }

[dev-dependencies]
ballista-executor = { path = "../executor", version = "50.0.0" }
Expand Down
48 changes: 47 additions & 1 deletion ballista/client/src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,15 @@
// under the License.

pub use ballista_core::extension::{SessionConfigExt, SessionStateExt};
use ballista_core::remote_catalog::remote_scalar_udf::RemoteScalarUDF;
use ballista_core::remote_catalog::remote_table_provider::RemoteTableProvider;
use ballista_core::serde::protobuf::scheduler_grpc_client::SchedulerGrpcClient;
use ballista_core::serde::protobuf::GetCatalogParams;
use ballista_core::serde::protobuf::{GetCatalogParams, GetRemoteFunctionsParams};
use datafusion::catalog::{
CatalogProvider, MemoryCatalogProvider, MemorySchemaProvider, SchemaProvider,
};
use datafusion::execution::FunctionRegistry;
use datafusion::logical_expr::ScalarUDF;
use datafusion::{
error::DataFusionError, execution::SessionState, prelude::SessionContext,
};
Expand Down Expand Up @@ -99,6 +102,12 @@ pub trait SessionContextExt {
&self,
scheduler_url: &str,
) -> datafusion::error::Result<()>;

/// Populates local context with functions from the scheduler.
async fn populate_functions_from_scheduler(
&self,
scheduler_url: &str,
) -> datafusion::error::Result<()>;
}

#[async_trait::async_trait]
Expand All @@ -124,6 +133,8 @@ impl SessionContextExt for SessionContext {

// Populate local catalog from scheduler
ctx.populate_catalog_from_scheduler(&scheduler_url).await?;
ctx.populate_functions_from_scheduler(&scheduler_url)
.await?;

Ok(ctx)
}
Expand All @@ -145,6 +156,8 @@ impl SessionContextExt for SessionContext {

// Populate local catalog from scheduler
ctx.populate_catalog_from_scheduler(&scheduler_url).await?;
ctx.populate_functions_from_scheduler(&scheduler_url)
.await?;

Ok(ctx)
}
Expand Down Expand Up @@ -259,6 +272,39 @@ impl SessionContextExt for SessionContext {

Ok(())
}

async fn populate_functions_from_scheduler(
&self,
scheduler_url: &str,
) -> datafusion::common::Result<()> {
let mut client = SchedulerGrpcClient::connect(scheduler_url.to_string())
.await
.map_err(|e| {
DataFusionError::External(
format!("Failed to connect to scheduler: {}", e).into(),
)
})?;

let request = tonic::Request::new(GetRemoteFunctionsParams {
session_id: self.state().session_id().to_string(),
});

let response = client.get_remote_functions(request).await.map_err(|e| {
DataFusionError::External(format!("Failed to fetch catalog: {}", e).into())
})?;

let remote_functions = response.into_inner();

for udf in remote_functions.udfs {
if self.state().udf(&udf.name).is_ok() {
continue;
}

self.register_udf(ScalarUDF::new_from_impl(RemoteScalarUDF::new(udf)?))
}

Ok(())
}
}

struct Extension {}
Expand Down
38 changes: 38 additions & 0 deletions ballista/core/proto/ballista.proto
Original file line number Diff line number Diff line change
Expand Up @@ -745,6 +745,41 @@ message RemoteTableProviderNode {
datafusion_common.Schema schema = 4;
}

///////////////////////////////////////////////////////////////////////////////////////////////////
// Remote UDF Metadata
///////////////////////////////////////////////////////////////////////////////////////////////////

message GetRemoteFunctionsParams {
string session_id = 1;
}

message GetRemoteFunctionsResult {
repeated ScalarUDFInfo udfs = 1;
}

message ScalarUDFInfo {
string name = 1;
optional ScalarUDFDocumentation documentation = 2;
repeated ScalarUDFTypeSignature signatures = 3;
}

message ScalarUDFTypeSignature {
repeated datafusion_common.ArrowType arity = 1;
datafusion_common.ArrowType return_type = 2;
}

message ScalarUDFDocumentation {
string description = 1;
string syntax_example = 2;
optional string sql_example = 3;
repeated ScalarUDFDocumentationArgument arguments = 4;
}

message ScalarUDFDocumentationArgument {
string argument = 1;
string description = 2;
}

service SchedulerGrpc {
// Executors must poll the scheduler for heartbeat and to receive tasks
rpc PollWork (PollWorkParams) returns (PollWorkResult) {}
Expand Down Expand Up @@ -774,6 +809,9 @@ service SchedulerGrpc {

// Get catalog metadata for a session
rpc GetCatalog (GetCatalogParams) returns (GetCatalogResult) {}

// Get catalog metadata for a session
rpc GetRemoteFunctions (GetRemoteFunctionsParams) returns (GetRemoteFunctionsResult) {}
}

service ExecutorGrpc {
Expand Down
10 changes: 3 additions & 7 deletions ballista/core/src/remote_catalog/catalog_serialize_ext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,13 @@ impl CatalogSerializeExt for SessionContext {
let catalog_names = self.state().catalog_list().catalog_names();

stream::iter(catalog_names.iter())
.filter_map(|catalog_name| self.serialize_catalog(&catalog_name))
.filter_map(|catalog_name| self.serialize_catalog(catalog_name))
.collect::<Vec<_>>()
.await
}

async fn serialize_catalog(&self, name: &str) -> Option<CatalogInfo> {
let Some(catalog) = self.catalog(name) else {
return None;
};
let catalog = self.catalog(name)?;

let schemas = stream::iter(catalog.schema_names().iter())
.filter_map(|schema_name| self.serialize_schema(schema_name, &catalog))
Expand All @@ -68,9 +66,7 @@ impl CatalogSerializeExt for SessionContext {
name: &str,
catalog: &Arc<dyn CatalogProvider>,
) -> Option<SchemaInfo> {
let Some(schema) = catalog.schema(name) else {
return None;
};
let schema = catalog.schema(name)?;

let tables = stream::iter(schema.table_names())
.filter_map(|table_name| async {
Expand Down
2 changes: 2 additions & 0 deletions ballista/core/src/remote_catalog/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,6 @@
//

pub mod catalog_serialize_ext;
pub mod remote_function_serialize_ext;
pub mod remote_scalar_udf;
pub mod remote_table_provider;
99 changes: 99 additions & 0 deletions ballista/core/src/remote_catalog/remote_function_serialize_ext.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use crate::serde::protobuf::{
ScalarUdfDocumentation, ScalarUdfDocumentationArgument, ScalarUdfInfo,
ScalarUdfTypeSignature,
};
use datafusion::execution::FunctionRegistry;
use datafusion::functions::all_default_functions;
use datafusion::prelude::SessionContext;
use datafusion_proto_common::ArrowType;
use std::collections::HashSet;

/// Used to serialize function shapes to ship to Ballista clients
pub trait RemoteFunctionSerializeExt {
fn serialize_udfs(&self) -> Vec<ScalarUdfInfo>;
}

impl RemoteFunctionSerializeExt for SessionContext {
fn serialize_udfs(&self) -> Vec<ScalarUdfInfo> {
let mut udfs = vec![];

let skip = all_default_functions()
.iter()
.map(|f| f.name().to_string())
.collect::<HashSet<_>>();

for udf in self.udfs() {
if skip.contains(&udf) {
continue;
}

let f = self.udf(&udf).expect("Must find defined UDF");
let signature = f.signature();
let signatures = signature
.type_signature
.get_example_types()
.into_iter()
.filter_map(|t| {
let arity = t
.iter()
.map(TryInto::try_into)
.collect::<Result<Vec<ArrowType>, _>>()
.expect("Must serialize data types");

// TODO: some functions use `ScalarUDF::return_field_from_args`, which this does not support
f.return_type(&t)
.ok()
.and_then(|ref return_type| return_type.try_into().ok())
.map(|arrow_return_type| ScalarUdfTypeSignature {
arity,
return_type: Some(arrow_return_type),
})
})
.collect::<Vec<_>>();

let docs = f.documentation().map(|d| {
let arguments = d
.arguments
.iter()
.flatten()
.map(|(arg, desc)| ScalarUdfDocumentationArgument {
argument: arg.clone(),
description: desc.clone(),
})
.collect::<Vec<_>>();

ScalarUdfDocumentation {
description: d.description.clone(),
syntax_example: d.syntax_example.clone(),
sql_example: d.sql_example.clone(),
arguments,
}
});

udfs.push(ScalarUdfInfo {
name: f.name().to_string(),
documentation: docs,
signatures,
});
}

udfs
}
}
Loading
Loading