diff --git a/ballista/client/Cargo.toml b/ballista/client/Cargo.toml
index 52614c8159..4b7f12e66a 100644
--- a/ballista/client/Cargo.toml
+++ b/ballista/client/Cargo.toml
@@ -36,8 +36,8 @@ datafusion = { workspace = true }
 log = { workspace = true }
 
 tokio = { workspace = true }
-url = { workspace = true }
 tonic = { workspace = true }
+url = { workspace = true }
 
 [dev-dependencies]
 ballista-executor = { path = "../executor", version = "50.0.0" }
diff --git a/ballista/client/src/extension.rs b/ballista/client/src/extension.rs
index 52142a980e..a9b6240dfe 100644
--- a/ballista/client/src/extension.rs
+++ b/ballista/client/src/extension.rs
@@ -16,12 +16,15 @@
 // under the License.
 
 pub use ballista_core::extension::{SessionConfigExt, SessionStateExt};
+use ballista_core::remote_catalog::remote_scalar_udf::RemoteScalarUDF;
 use ballista_core::remote_catalog::remote_table_provider::RemoteTableProvider;
 use ballista_core::serde::protobuf::scheduler_grpc_client::SchedulerGrpcClient;
-use ballista_core::serde::protobuf::GetCatalogParams;
+use ballista_core::serde::protobuf::{GetCatalogParams, GetRemoteFunctionsParams};
 use datafusion::catalog::{
     CatalogProvider, MemoryCatalogProvider, MemorySchemaProvider, SchemaProvider,
 };
+use datafusion::execution::FunctionRegistry;
+use datafusion::logical_expr::ScalarUDF;
 use datafusion::{
     error::DataFusionError, execution::SessionState, prelude::SessionContext,
 };
@@ -99,6 +102,12 @@ pub trait SessionContextExt {
         &self,
         scheduler_url: &str,
     ) -> datafusion::error::Result<()>;
+
+    /// Populates local context with functions from the scheduler.
+    async fn populate_functions_from_scheduler(
+        &self,
+        scheduler_url: &str,
+    ) -> datafusion::error::Result<()>;
 }
 
 #[async_trait::async_trait]
@@ -124,6 +133,8 @@ impl SessionContextExt for SessionContext {
 
         // Populate local catalog from scheduler
         ctx.populate_catalog_from_scheduler(&scheduler_url).await?;
+        ctx.populate_functions_from_scheduler(&scheduler_url)
+            .await?;
 
         Ok(ctx)
     }
@@ -145,6 +156,8 @@ impl SessionContextExt for SessionContext {
 
         // Populate local catalog from scheduler
         ctx.populate_catalog_from_scheduler(&scheduler_url).await?;
+        ctx.populate_functions_from_scheduler(&scheduler_url)
+            .await?;
 
         Ok(ctx)
     }
@@ -259,6 +272,51 @@ impl SessionContextExt for SessionContext {
 
         Ok(())
     }
+
+    /// Registers a planning stub for every scalar UDF the scheduler exposes.
+    ///
+    /// Functions whose name is already registered locally are left untouched,
+    /// so a local implementation is never replaced by a remote stub.
+    ///
+    /// # Errors
+    ///
+    /// Returns [`DataFusionError::External`] when the scheduler cannot be
+    /// reached or the request fails, and propagates the error of
+    /// [`RemoteScalarUDF::new`] when the received metadata cannot be decoded.
+    async fn populate_functions_from_scheduler(
+        &self,
+        scheduler_url: &str,
+    ) -> datafusion::error::Result<()> {
+        let mut client = SchedulerGrpcClient::connect(scheduler_url.to_string())
+            .await
+            .map_err(|e| {
+                DataFusionError::External(
+                    format!("Failed to connect to scheduler: {}", e).into(),
+                )
+            })?;
+
+        let request = tonic::Request::new(GetRemoteFunctionsParams {
+            session_id: self.state().session_id().to_string(),
+        });
+
+        let response = client.get_remote_functions(request).await.map_err(|e| {
+            DataFusionError::External(
+                format!("Failed to fetch remote functions: {}", e).into(),
+            )
+        })?;
+
+        for udf in response.into_inner().udfs {
+            // Ask the context directly: `self.state()` returns the session
+            // state by value, which is wasteful to do once per function.
+            if self.udf(&udf.name).is_ok() {
+                continue;
+            }
+
+            self.register_udf(ScalarUDF::new_from_impl(RemoteScalarUDF::new(udf)?));
+        }
+
+        Ok(())
+    }
 }
 
 struct Extension {}
diff --git a/ballista/core/proto/ballista.proto b/ballista/core/proto/ballista.proto
index 03f9cf27b1..61ab3c9e67 100644
--- a/ballista/core/proto/ballista.proto
+++ b/ballista/core/proto/ballista.proto
@@ -745,6 +745,41 @@ message RemoteTableProviderNode {
   datafusion_common.Schema schema = 4;
 }
 
+///////////////////////////////////////////////////////////////////////////////////////////////////
+// Remote UDF Metadata
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+message GetRemoteFunctionsParams {
+  string session_id = 1;
+}
+
+message GetRemoteFunctionsResult {
+  repeated ScalarUDFInfo udfs = 1;
+}
+
+message ScalarUDFInfo {
+  string name = 1;
+  optional ScalarUDFDocumentation documentation = 2;
+  repeated ScalarUDFTypeSignature signatures = 3;
+}
+
+message ScalarUDFTypeSignature {
+  repeated datafusion_common.ArrowType arity = 1;
+  datafusion_common.ArrowType return_type = 2;
+}
+
+message ScalarUDFDocumentation {
+  string description = 1;
+  string syntax_example = 2;
+  optional string sql_example = 3;
+  repeated ScalarUDFDocumentationArgument arguments = 4;
+}
+
+message ScalarUDFDocumentationArgument {
+  string argument = 1;
+  string description = 2;
+}
+
 service SchedulerGrpc {
   // Executors must poll the scheduler for heartbeat and to receive tasks
   rpc PollWork (PollWorkParams) returns (PollWorkResult) {}
@@ -774,6 +809,9 @@ service SchedulerGrpc {
 
   // Get catalog metadata for a session
   rpc GetCatalog (GetCatalogParams) returns (GetCatalogResult) {}
+
+  // Get catalog metadata for a session
+  rpc GetRemoteFunctions (GetRemoteFunctionsParams) returns (GetRemoteFunctionsResult) {}
 }
 
 service ExecutorGrpc {
diff --git a/ballista/core/src/remote_catalog/catalog_serialize_ext.rs b/ballista/core/src/remote_catalog/catalog_serialize_ext.rs
index e8e030423f..af5d3c8603 100644
--- a/ballista/core/src/remote_catalog/catalog_serialize_ext.rs
+++ b/ballista/core/src/remote_catalog/catalog_serialize_ext.rs
@@ -42,15 +42,13 @@ impl CatalogSerializeExt for SessionContext {
         let catalog_names = self.state().catalog_list().catalog_names();
 
         stream::iter(catalog_names.iter())
-            .filter_map(|catalog_name| self.serialize_catalog(&catalog_name))
+            .filter_map(|catalog_name| self.serialize_catalog(catalog_name))
             .collect::>()
             .await
     }
 
     async fn serialize_catalog(&self, name: &str) -> Option {
-        let Some(catalog) = self.catalog(name) else {
-            return None;
-        };
+        let catalog = self.catalog(name)?;
 
         let schemas = stream::iter(catalog.schema_names().iter())
             .filter_map(|schema_name| self.serialize_schema(schema_name, &catalog))
@@ -68,9 +66,7 @@ impl CatalogSerializeExt for SessionContext {
         name: &str,
         catalog: &Arc,
     ) -> Option {
-        let Some(schema) = catalog.schema(name) else {
-            return None;
-        };
+        let schema = catalog.schema(name)?;
 
         let tables = stream::iter(schema.table_names())
             .filter_map(|table_name| async {
diff --git a/ballista/core/src/remote_catalog/mod.rs b/ballista/core/src/remote_catalog/mod.rs
index 3803470e4b..8b8728791e 100644
--- a/ballista/core/src/remote_catalog/mod.rs
+++ b/ballista/core/src/remote_catalog/mod.rs
@@ -17,4 +17,6 @@
 //
 
 pub mod catalog_serialize_ext;
+pub mod remote_function_serialize_ext;
+pub mod remote_scalar_udf;
 pub mod remote_table_provider;
diff --git a/ballista/core/src/remote_catalog/remote_function_serialize_ext.rs b/ballista/core/src/remote_catalog/remote_function_serialize_ext.rs
new file mode 100644
index 0000000000..8af5002ade
--- /dev/null
+++ b/ballista/core/src/remote_catalog/remote_function_serialize_ext.rs
@@ -0,0 +1,119 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::serde::protobuf::{
+    ScalarUdfDocumentation, ScalarUdfDocumentationArgument, ScalarUdfInfo,
+    ScalarUdfTypeSignature,
+};
+use datafusion::execution::FunctionRegistry;
+use datafusion::functions::all_default_functions;
+use datafusion::prelude::SessionContext;
+use datafusion_proto_common::ArrowType;
+use std::collections::HashSet;
+
+/// Used to serialize function shapes to ship to Ballista clients
+pub trait RemoteFunctionSerializeExt {
+    /// Returns the metadata of every registered scalar UDF that is not one of
+    /// DataFusion's default functions.
+    ///
+    /// Signatures whose argument or return types cannot be resolved or encoded
+    /// are left out, so serialization itself never fails.
+    fn serialize_udfs(&self) -> Vec<ScalarUdfInfo>;
+}
+
+impl RemoteFunctionSerializeExt for SessionContext {
+    fn serialize_udfs(&self) -> Vec<ScalarUdfInfo> {
+        // Functions every DataFusion session registers by default.
+        let skip = all_default_functions()
+            .iter()
+            .map(|f| f.name().to_string())
+            .collect::<HashSet<_>>();
+
+        // Canonical names already emitted: the registry may list one UDF under
+        // several keys (e.g. aliases) and a client needs each function once.
+        let mut seen = HashSet::new();
+        let mut udfs = vec![];
+
+        for udf in self.udfs() {
+            if skip.contains(&udf) {
+                continue;
+            }
+
+            // The key was listed by the registry a moment ago, so a failed
+            // lookup is unexpected; skip it rather than panic inside the
+            // scheduler's request handler.
+            let Ok(f) = self.udf(&udf) else {
+                continue;
+            };
+            if skip.contains(f.name()) || !seen.insert(f.name().to_string()) {
+                continue;
+            }
+
+            let signatures = f
+                .signature()
+                .type_signature
+                .get_example_types()
+                .into_iter()
+                .filter_map(|t| {
+                    // Argument types without a protobuf encoding drop this
+                    // signature, exactly like an unresolvable return type.
+                    let arity = t
+                        .iter()
+                        .map(TryInto::try_into)
+                        .collect::<Result<Vec<ArrowType>, _>>()
+                        .ok()?;
+
+                    // TODO: some functions use `ScalarUDF::return_field_from_args`, which this does not support
+                    let return_type = f.return_type(&t).ok()?;
+                    let return_type: ArrowType = (&return_type).try_into().ok()?;
+
+                    Some(ScalarUdfTypeSignature {
+                        arity,
+                        return_type: Some(return_type),
+                    })
+                })
+                .collect::<Vec<_>>();
+
+            let documentation = f.documentation().map(|d| {
+                let arguments = d
+                    .arguments
+                    .iter()
+                    .flatten()
+                    .map(|(arg, desc)| ScalarUdfDocumentationArgument {
+                        argument: arg.clone(),
+                        description: desc.clone(),
+                    })
+                    .collect::<Vec<_>>();
+
+                ScalarUdfDocumentation {
+                    description: d.description.clone(),
+                    syntax_example: d.syntax_example.clone(),
+                    sql_example: d.sql_example.clone(),
+                    arguments,
+                }
+            });
+
+            udfs.push(ScalarUdfInfo {
+                name: f.name().to_string(),
+                documentation,
+                signatures,
+            });
+        }
+
+        udfs
+    }
+}
diff --git a/ballista/core/src/remote_catalog/remote_scalar_udf.rs b/ballista/core/src/remote_catalog/remote_scalar_udf.rs
new file mode 100644
index 0000000000..2d154b56e5
--- /dev/null
+++ b/ballista/core/src/remote_catalog/remote_scalar_udf.rs
@@ -0,0 +1,166 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+//
+
+use crate::serde::protobuf::ScalarUdfInfo;
+use arrow::datatypes::DataType;
+use datafusion::common::Result;
+use datafusion::common::{exec_err, plan_err, DataFusionError};
+use datafusion::logical_expr::{
+    ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
+    TypeSignature, Volatility,
+};
+use std::any::Any;
+use std::hash::Hash;
+
+/// A stub provider to encapsulate a function that exists in the scheduler's registry,
+/// in order to decouple its specific implementation from logical planning
+///
+/// The stub only answers planning questions (name, return type, documentation);
+/// invoking it always fails.
+#[derive(Debug)]
+pub struct RemoteScalarUDF {
+    /// Argument types of every signature, index-aligned with `meta.signatures`.
+    arities: Vec<Vec<DataType>>,
+    /// Function metadata as received from the scheduler.
+    meta: ScalarUdfInfo,
+    /// Accepts any arguments; the actual matching happens in `return_type`.
+    signature: Signature,
+    /// Documentation converted from `meta.documentation`, when present.
+    documentation: Option<Documentation>,
+}
+
+impl Hash for RemoteScalarUDF {
+    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+        self.arities.hash(state);
+        self.meta.name.hash(state);
+        self.signature.hash(state);
+    }
+}
+
+impl PartialEq for RemoteScalarUDF {
+    fn eq(&self, other: &Self) -> bool {
+        // Compare the wire signatures rather than `arities` so that return
+        // types take part as well; `arities` is derived from them in `new`,
+        // which keeps this consistent with `Hash`.
+        self.meta.name == other.meta.name
+            && self.meta.signatures == other.meta.signatures
+            && self.signature == other.signature
+    }
+}
+
+impl Eq for RemoteScalarUDF {}
+
+impl RemoteScalarUDF {
+    /// Builds a stub from the metadata sent by the scheduler.
+    ///
+    /// # Errors
+    ///
+    /// Fails when an argument type in `meta.signatures` cannot be converted
+    /// into an Arrow [`DataType`].
+    pub fn new(meta: ScalarUdfInfo) -> Result<Self> {
+        let mut arities = vec![];
+
+        for signature in &meta.signatures {
+            let signature_types = signature
+                .arity
+                .iter()
+                .map(|t| t.try_into())
+                .collect::<Result<Vec<DataType>, _>>()?;
+
+            arities.push(signature_types);
+        }
+
+        let documentation = meta.documentation.as_ref().map(|d| Documentation {
+            doc_section: Default::default(),
+            description: d.description.clone(),
+            syntax_example: d.syntax_example.clone(),
+            sql_example: d.sql_example.clone(),
+            // The wire format has no "missing" list, so empty means undocumented.
+            arguments: (!d.arguments.is_empty()).then(|| {
+                d.arguments
+                    .iter()
+                    .map(|a| (a.argument.clone(), a.description.clone()))
+                    .collect()
+            }),
+            alternative_syntax: None,
+            related_udfs: None,
+        });
+
+        Ok(Self {
+            arities,
+            documentation,
+            meta,
+            // NOTE(review): presumably volatile so that the optimizer never tries
+            // to evaluate the stub on the client -- confirm.
+            signature: Signature::new(TypeSignature::VariadicAny, Volatility::Volatile),
+        })
+    }
+}
+
+impl ScalarUDFImpl for RemoteScalarUDF {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        self.meta.name.as_str()
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    /// Looks up the return type recorded for exactly these argument types.
+    fn return_type(
+        &self,
+        arg_types: &[DataType],
+    ) -> datafusion::common::Result<DataType> {
+        let Some((_, signature)) = self
+            .arities
+            .iter()
+            .zip(&self.meta.signatures)
+            .find(|(arity, _)| arity.as_slice() == arg_types)
+        else {
+            return plan_err!("Unable to determine function return type");
+        };
+
+        // The return type is an optional field of the message, so metadata
+        // without it must surface as a planning error, not as a panic.
+        let Some(arrow_type) = signature.return_type.as_ref() else {
+            return plan_err!(
+                "Remote function '{}' does not declare a return type",
+                self.meta.name
+            );
+        };
+
+        arrow_type
+            .try_into()
+            .map_err(|e| DataFusionError::External(Box::new(e)))
+    }
+
+    fn invoke_with_args(
+        &self,
+        _args: ScalarFunctionArgs,
+    ) -> datafusion::common::Result<ColumnarValue> {
+        exec_err!("This is a stub function and should never be called on the client")
+    }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        self.documentation.as_ref()
+    }
+}
diff --git a/ballista/core/src/serde/generated/ballista.rs b/ballista/core/src/serde/generated/ballista.rs
index c4f43e8afc..485819cce9 100644
--- a/ballista/core/src/serde/generated/ballista.rs
+++ b/ballista/core/src/serde/generated/ballista.rs
@@ -1117,6 +1117,50 @@ pub struct RemoteTableProviderNode {
     #[prost(message, optional, tag = "4")]
     pub schema: ::core::option::Option<::datafusion_proto_common::Schema>,
 }
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct GetRemoteFunctionsParams {
+    #[prost(string, tag = "1")]
+    pub session_id: ::prost::alloc::string::String,
+}
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct GetRemoteFunctionsResult {
+    #[prost(message, repeated, tag = "1")]
+    pub udfs: ::prost::alloc::vec::Vec<ScalarUdfInfo>,
+}
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct ScalarUdfInfo {
+    #[prost(string, tag = "1")]
+    pub name: ::prost::alloc::string::String,
+    #[prost(message, optional, tag = "2")]
+    pub documentation: ::core::option::Option<ScalarUdfDocumentation>,
+    #[prost(message, repeated, tag = "3")]
+    pub signatures: ::prost::alloc::vec::Vec<ScalarUdfTypeSignature>,
+}
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct ScalarUdfTypeSignature {
+    #[prost(message, repeated, tag = "1")]
+    pub arity: ::prost::alloc::vec::Vec<::datafusion_proto_common::ArrowType>,
+    #[prost(message, optional, tag = "2")]
+    pub return_type: ::core::option::Option<::datafusion_proto_common::ArrowType>,
+}
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct ScalarUdfDocumentation {
+    #[prost(string, tag = "1")]
+    pub description: ::prost::alloc::string::String,
+    #[prost(string, tag = "2")]
+    pub syntax_example: ::prost::alloc::string::String,
+    #[prost(string, optional, tag = "3")]
+    pub sql_example: ::core::option::Option<::prost::alloc::string::String>,
+    #[prost(message, repeated, tag = "4")]
+    pub arguments: ::prost::alloc::vec::Vec<ScalarUdfDocumentationArgument>,
+}
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct ScalarUdfDocumentationArgument {
+    #[prost(string, tag = "1")]
+    pub argument: ::prost::alloc::string::String,
+    #[prost(string, tag = "2")]
+    pub description: ::prost::alloc::string::String,
+}
 /// Generated client implementations.
 pub mod scheduler_grpc_client {
     #![allow(
@@ -1530,6 +1574,36 @@ pub mod scheduler_grpc_client {
             );
             self.inner.unary(req, path, codec).await
         }
+        /// Get catalog metadata for a session
+        pub async fn get_remote_functions(
+            &mut self,
+            request: impl tonic::IntoRequest<super::GetRemoteFunctionsParams>,
+        ) -> std::result::Result<
+            tonic::Response<super::GetRemoteFunctionsResult>,
+            tonic::Status,
+        > {
+            self.inner
+                .ready()
+                .await
+                .map_err(|e| {
+                    tonic::Status::unknown(
+                        format!("Service was not ready: {}", e.into()),
+                    )
+                })?;
+            let codec = tonic::codec::ProstCodec::default();
+            let path = http::uri::PathAndQuery::from_static(
+                "/ballista.protobuf.SchedulerGrpc/GetRemoteFunctions",
+            );
+            let mut req = request.into_request();
+            req.extensions_mut()
+                .insert(
+                    GrpcMethod::new(
+                        "ballista.protobuf.SchedulerGrpc",
+                        "GetRemoteFunctions",
+                    ),
+                );
+            self.inner.unary(req, path, codec).await
+        }
     }
 }
 /// Generated client implementations.
@@ -1846,6 +1920,14 @@ pub mod scheduler_grpc_server {
             tonic::Response<super::GetCatalogResult>,
             tonic::Status,
         >;
+        /// Get catalog metadata for a session
+        async fn get_remote_functions(
+            &self,
+            request: tonic::Request<super::GetRemoteFunctionsParams>,
+        ) -> std::result::Result<
+            tonic::Response<super::GetRemoteFunctionsResult>,
+            tonic::Status,
+        >;
     }
     #[derive(Debug)]
     pub struct SchedulerGrpcServer<T> {
@@ -2471,6 +2553,52 @@ pub mod scheduler_grpc_server {
                     };
                     Box::pin(fut)
                 }
+                "/ballista.protobuf.SchedulerGrpc/GetRemoteFunctions" => {
+                    #[allow(non_camel_case_types)]
+                    struct GetRemoteFunctionsSvc<T: SchedulerGrpc>(pub Arc<T>);
+                    impl<
+                        T: SchedulerGrpc,
+                    > tonic::server::UnaryService<super::GetRemoteFunctionsParams>
+                    for GetRemoteFunctionsSvc<T> {
+                        type Response = super::GetRemoteFunctionsResult;
+                        type Future = BoxFuture<
+                            tonic::Response<Self::Response>,
+                            tonic::Status,
+                        >;
+                        fn call(
+                            &mut self,
+                            request: tonic::Request<super::GetRemoteFunctionsParams>,
+                        ) -> Self::Future {
+                            let inner = Arc::clone(&self.0);
+                            let fut = async move {
+                                <T as SchedulerGrpc>::get_remote_functions(&inner, request)
+                                    .await
+                            };
+                            Box::pin(fut)
+                        }
+                    }
+                    let accept_compression_encodings = self.accept_compression_encodings;
+                    let send_compression_encodings = self.send_compression_encodings;
+                    let max_decoding_message_size = self.max_decoding_message_size;
+                    let max_encoding_message_size = self.max_encoding_message_size;
+                    let inner = self.inner.clone();
+                    let fut = async move {
+                        let method = GetRemoteFunctionsSvc(inner);
+                        let codec = tonic::codec::ProstCodec::default();
+                        let mut grpc = tonic::server::Grpc::new(codec)
+                            .apply_compression_config(
+                                accept_compression_encodings,
+                                send_compression_encodings,
+                            )
+                            .apply_max_message_size_config(
+                                max_decoding_message_size,
+                                max_encoding_message_size,
+                            );
+                        let res = grpc.unary(method, req).await;
+                        Ok(res)
+                    };
+                    Box::pin(fut)
+                }
                 _ => {
                     Box::pin(async move {
                         let mut response = http::Response::new(
diff --git a/ballista/core/src/serde/mod.rs b/ballista/core/src/serde/mod.rs
index 017e4fdc73..354cdeb6da 100644
--- a/ballista/core/src/serde/mod.rs
+++ b/ballista/core/src/serde/mod.rs
@@ -50,7 +50,6 @@ use crate::remote_catalog::remote_table_provider::RemoteTableProvider;
 use crate::serde::protobuf::ballista_physical_plan_node::PhysicalPlanType;
 use crate::serde::scheduler::PartitionLocation;
 use datafusion::catalog::TableProvider;
-use datafusion::logical_expr::UserDefinedLogicalNode;
 pub use generated::ballista as protobuf;
 use prost::Message;
 use std::fmt::Debug;
diff --git a/ballista/scheduler/src/scheduler_server/grpc.rs b/ballista/scheduler/src/scheduler_server/grpc.rs
index 3db3c6b1df..dc93c9509e 100644
--- a/ballista/scheduler/src/scheduler_server/grpc.rs
+++ b/ballista/scheduler/src/scheduler_server/grpc.rs
@@ -27,10 +27,10 @@ use ballista_core::serde::protobuf::{
     CreateUpdateSessionParams, CreateUpdateSessionResult, ExecuteQueryFailureResult,
     ExecuteQueryParams, ExecuteQueryResult, ExecuteQuerySuccessResult, ExecutorHeartbeat,
     ExecutorStoppedParams, ExecutorStoppedResult, GetCatalogParams, GetCatalogResult,
-    GetJobStatusParams, GetJobStatusResult, HeartBeatParams, HeartBeatResult,
-    PollWorkParams, PollWorkResult, RegisterExecutorParams, RegisterExecutorResult,
-    RemoveSessionParams, RemoveSessionResult, UpdateTaskStatusParams,
-    UpdateTaskStatusResult,
+    GetJobStatusParams, GetJobStatusResult, GetRemoteFunctionsParams,
+    GetRemoteFunctionsResult, HeartBeatParams, HeartBeatResult, PollWorkParams,
+    PollWorkResult, RegisterExecutorParams, RegisterExecutorResult, RemoveSessionParams,
+    RemoveSessionResult, UpdateTaskStatusParams, UpdateTaskStatusResult,
 };
 use ballista_core::serde::scheduler::ExecutorMetadata;
 use datafusion_proto::logical_plan::AsLogicalPlan;
@@ -43,11 +43,11 @@ use std::ops::Deref;
 use crate::cluster::{bind_task_bias, bind_task_round_robin};
 use crate::config::TaskDistributionPolicy;
 use crate::scheduler_server::event::QueryStageSchedulerEvent;
+use crate::scheduler_server::SchedulerServer;
+use ballista_core::remote_catalog::remote_function_serialize_ext::RemoteFunctionSerializeExt;
 use std::time::{SystemTime, UNIX_EPOCH};
 use tonic::{Request, Response, Status};
 
-use crate::scheduler_server::SchedulerServer;
-
 #[tonic::async_trait]
 impl SchedulerGrpc
     for SchedulerServer
@@ -548,6 +548,31 @@ impl SchedulerGrpc
             catalogs: ctx.serialize_catalogs().await,
         }))
     }
+
+    /// Returns the shape of the scalar UDFs registered for a session, so that a
+    /// client can plan queries that call them.
+    async fn get_remote_functions(
+        &self,
+        request: Request<GetRemoteFunctionsParams>,
+    ) -> Result<Response<GetRemoteFunctionsResult>, Status> {
+        let GetRemoteFunctionsParams { session_id } = request.into_inner();
+
+        // NOTE(review): this creates or updates the session with the default
+        // produced config, as the code below shows; whether that can override
+        // settings of an already existing session should be confirmed.
+        let session_manager = &self.state.session_manager;
+        let ctx = session_manager
+            .create_or_update_session(
+                session_id.as_str(),
+                &session_manager.produce_config(),
+            )
+            .await
+            .map_err(|e| Status::internal(format!("Error creating session {e}")))?;
+
+        Ok(Response::new(GetRemoteFunctionsResult {
+            udfs: ctx.serialize_udfs(),
+        }))
+    }
 }
 
 fn extract_connect_info(request: &Request) -> Option> {
diff --git a/python/Cargo.toml b/python/Cargo.toml
index 4e2cf1780d..fb4971fcce 100644
--- a/python/Cargo.toml
+++ b/python/Cargo.toml
@@ -42,8 +42,8 @@ ballista = { path = "../ballista/client", version = "50.0.0" }
 ballista-core = { path = "../ballista/core", version = "50.0.0" }
 ballista-executor = { path = "../ballista/executor", version = "50.0.0" }
 ballista-scheduler = { path = "../ballista/scheduler", version = "50.0.0" }
-datafusion.workspace = true
-datafusion-proto.workspace = true
+datafusion = { workspace = true }
+datafusion-proto = { workspace = true }
 datafusion-python = "50.1.0"
 pyo3 = { version = "0.25", features = [