diff --git a/ballista/executor/src/config.rs b/ballista/executor/src/config.rs index dfe2c7625c..9ddaf006fc 100644 --- a/ballista/executor/src/config.rs +++ b/ballista/executor/src/config.rs @@ -140,6 +140,8 @@ impl TryFrom for ExecutorProcessConfig { grpc_max_decoding_message_size: opt.grpc_server_max_decoding_message_size, grpc_max_encoding_message_size: opt.grpc_server_max_encoding_message_size, executor_heartbeat_interval_seconds: opt.executor_heartbeat_interval_seconds, + disable_scheduler_heartbeats: false, + disable_task_status_push: false, override_execution_engine: None, override_function_registry: None, override_config_producer: None, diff --git a/ballista/executor/src/executor.rs b/ballista/executor/src/executor.rs index 3723c42e82..dcd4e71ed3 100644 --- a/ballista/executor/src/executor.rs +++ b/ballista/executor/src/executor.rs @@ -22,6 +22,7 @@ use crate::execution_engine::ExecutionEngine; use crate::execution_engine::QueryStageExecutor; use crate::metrics::ExecutorMetricsCollector; use crate::metrics::LoggingMetricsCollector; +use crate::status_store::ExecutorStatusStore; use ballista_core::error::BallistaError; use ballista_core::registry::BallistaFunctionRegistry; use ballista_core::serde::protobuf; @@ -85,6 +86,9 @@ pub struct Executor { /// Execution engine that the executor will delegate to /// for executing query stages pub(crate) execution_engine: Arc, + + /// Stores task status updates for scheduler polling. + status_store: Arc, } impl Executor { @@ -133,6 +137,7 @@ impl Executor { abort_handles: Default::default(), execution_engine: execution_engine .unwrap_or_else(|| Arc::new(DefaultExecutionEngine {})), + status_store: Arc::new(ExecutorStatusStore::new()), } } } @@ -149,6 +154,11 @@ impl Executor { (self.config_producer)() } + #[must_use] + pub fn status_store(&self) -> Arc { + Arc::clone(&self.status_store) + } + /// Execute one partition of a query stage and persist the result to disk in IPC format. On /// success, return a RecordBatch containing metadata about the results, including path /// and statistics. diff --git a/ballista/executor/src/executor_process.rs b/ballista/executor/src/executor_process.rs index 7e81eb233b..8fc6322505 100644 --- a/ballista/executor/src/executor_process.rs +++ b/ballista/executor/src/executor_process.rs @@ -94,6 +94,10 @@ pub struct ExecutorProcessConfig { /// The maximum size of an encoded message pub grpc_max_encoding_message_size: u32, pub executor_heartbeat_interval_seconds: u64, + /// Disable outbound scheduler heartbeats for polling-based health. + pub disable_scheduler_heartbeats: bool, + /// Disable outbound task status updates for polling-based status. + pub disable_task_status_push: bool, /// Optional execution engine to use to execute physical plans, will default to /// DataFusion if none is provided. pub override_execution_engine: Option>, @@ -147,6 +151,8 @@ impl Default for ExecutorProcessConfig { grpc_max_decoding_message_size: 16777216, grpc_max_encoding_message_size: 16777216, executor_heartbeat_interval_seconds: 60, + disable_scheduler_heartbeats: false, + disable_task_status_push: false, override_execution_engine: None, override_function_registry: None, override_runtime_producer: None, diff --git a/ballista/executor/src/executor_server.rs b/ballista/executor/src/executor_server.rs index af7ee89321..f7552281a4 100644 --- a/ballista/executor/src/executor_server.rs +++ b/ballista/executor/src/executor_server.rs @@ -78,31 +78,22 @@ struct CuratorTaskStatus { } pub async fn startup( - mut scheduler: SchedulerGrpcClient, + scheduler: SchedulerGrpcClient, config: Arc, executor: Arc, codec: BallistaCodec, stop_send: mpsc::Sender, shutdown_noti: &ShutdownNotifier, ) -> Result { - let channel_buf_size = executor.concurrent_tasks * 50; - let (tx_task, rx_task) = mpsc::channel::(channel_buf_size); - let (tx_task_status, rx_task_status) = - mpsc::channel::(channel_buf_size); - - let executor_server = ExecutorServer::new( + let executor_server = create_executor_server( scheduler.clone(), + Arc::clone(&config), executor.clone(), - ExecutorEnv { - tx_task, - tx_task_status, - tx_stop: stop_send, - }, codec, - config.grpc_max_encoding_message_size as usize, - config.grpc_max_decoding_message_size as usize, - config.override_create_grpc_client_endpoint.clone(), - ); + stop_send, + shutdown_noti, + ) + .await?; // 1. Start executor grpc service let server = { @@ -113,7 +104,7 @@ pub async fn startup( info!( "Ballista v{BALLISTA_VERSION} Rust Executor Grpc Server listening on {addr:?}" ); - let server = ExecutorGrpcServer::new(executor_server.clone()) + let server = ExecutorGrpcServer::new(executor_server.as_ref().clone()) .max_encoding_message_size(config.grpc_max_encoding_message_size as usize) .max_decoding_message_size(config.grpc_max_decoding_message_size as usize); let mut grpc_shutdown = shutdown_noti.subscribe_for_shutdown(); @@ -129,34 +120,89 @@ pub async fn startup( }) }; - // 2. Do executor registration - // TODO the executor registration should happen only after the executor grpc server started. - let executor_server = Arc::new(executor_server); - match register_executor(&mut scheduler, executor.clone()).await { + Ok(server) +} + +pub async fn create_executor_server< + T: 'static + AsLogicalPlan, + U: 'static + AsExecutionPlan, +>( + scheduler: SchedulerGrpcClient, + config: Arc, + executor: Arc, + codec: BallistaCodec, + stop_send: mpsc::Sender, + shutdown_noti: &ShutdownNotifier, +) -> Result>, BallistaError> { + let mut scheduler = scheduler; + let executor_server = create_executor_server_without_registration( + scheduler.clone(), + config, + executor.clone(), + codec, + stop_send, + shutdown_noti, + ) + .await?; + register_executor_with_scheduler(&mut scheduler, executor).await?; + Ok(executor_server) +} + +pub async fn create_executor_server_without_registration< + T: 'static + AsLogicalPlan, + U: 'static + AsExecutionPlan, +>( + scheduler: SchedulerGrpcClient, + config: Arc, + executor: Arc, + codec: BallistaCodec, + stop_send: mpsc::Sender, + shutdown_noti: &ShutdownNotifier, +) -> Result>, BallistaError> { + let channel_buf_size = executor.concurrent_tasks * 50; + let (tx_task, rx_task) = mpsc::channel::(channel_buf_size); + let (tx_task_status, rx_task_status) = + mpsc::channel::(channel_buf_size); + + let executor_server = Arc::new(ExecutorServer::new( + scheduler, + executor.clone(), + ExecutorEnv { + tx_task, + tx_task_status, + tx_stop: stop_send, + }, + codec, + config.grpc_max_encoding_message_size as usize, + config.grpc_max_decoding_message_size as usize, + config.override_create_grpc_client_endpoint.clone(), + !config.disable_task_status_push, + !config.disable_scheduler_heartbeats, + )); + + let heartbeater = Heartbeater::new(Arc::clone(&executor_server)); + heartbeater.start(shutdown_noti, config.executor_heartbeat_interval_seconds); + + let task_runner_pool = TaskRunnerPool::new(Arc::clone(&executor_server)); + task_runner_pool.start(rx_task, rx_task_status, shutdown_noti); + + Ok(executor_server) +} + +pub async fn register_executor_with_scheduler( + scheduler: &mut SchedulerGrpcClient, + executor: Arc, +) -> Result<(), BallistaError> { + match register_executor(scheduler, executor).await { Ok(_) => { info!("Executor registration succeed"); + Ok(()) } Err(error) => { error!("Executor registration failed due to: {error}"); - // abort the Executor Grpc Future - server.abort(); - return Err(error); + Err(error) } - }; - - // 3. Start Heartbeater loop - { - let heartbeater = Heartbeater::new(executor_server.clone()); - heartbeater.start(shutdown_noti, config.executor_heartbeat_interval_seconds); } - - // 4. Start TaskRunnerPool loop - { - let task_runner_pool = TaskRunnerPool::new(executor_server.clone()); - task_runner_pool.start(rx_task, rx_task_status, shutdown_noti); - } - - Ok(server) } #[allow(clippy::clone_on_copy)] @@ -189,6 +235,8 @@ pub struct ExecutorServer, + push_task_status: bool, + send_heartbeats: bool, } #[derive(Clone)] @@ -208,6 +256,7 @@ unsafe impl Sync for ExecutorEnv {} pub static TERMINATING: AtomicBool = AtomicBool::new(false); impl ExecutorServer { + #[allow(clippy::too_many_arguments)] fn new( scheduler_to_register: SchedulerGrpcClient, executor: Arc, @@ -216,6 +265,8 @@ impl ExecutorServer, + push_task_status: bool, + send_heartbeats: bool, ) -> Self { Self { _start_time: SystemTime::now() @@ -230,6 +281,8 @@ impl ExecutorServer ExecutorServer Heartbeater shutdown_noti: &ShutdownNotifier, executor_heartbeat_interval_seconds: u64, ) { + if !self.executor_server.send_heartbeats { + return; + } let executor_server = self.executor_server.clone(); let mut heartbeat_shutdown = shutdown_noti.subscribe_for_shutdown(); let heartbeat_complete = shutdown_noti.shutdown_complete_tx.clone(); @@ -482,89 +545,95 @@ impl TaskRunnerPool, shutdown_noti: &ShutdownNotifier, ) { - //1. loop for task status reporting - let executor_server = self.executor_server.clone(); - let mut tasks_status_shutdown = shutdown_noti.subscribe_for_shutdown(); - let tasks_status_complete = shutdown_noti.shutdown_complete_tx.clone(); - tokio::spawn(async move { - info!("Starting the task status reporter"); - // As long as the shutdown notification has not been received - while !tasks_status_shutdown.is_shutdown() { - let mut curator_task_status_map: HashMap> = - HashMap::new(); - // First try to fetch task status from the channel in *blocking* mode - let maybe_task_status: Option = tokio::select! { - task_status = rx_task_status.recv() => task_status, - _ = tasks_status_shutdown.recv() => { - info!("Stop task status reporting loop"); + if self.executor_server.push_task_status { + //1. loop for task status reporting + let executor_server = self.executor_server.clone(); + let mut tasks_status_shutdown = shutdown_noti.subscribe_for_shutdown(); + let tasks_status_complete = shutdown_noti.shutdown_complete_tx.clone(); + tokio::spawn(async move { + info!("Starting the task status reporter"); + // As long as the shutdown notification has not been received + while !tasks_status_shutdown.is_shutdown() { + let mut curator_task_status_map: HashMap> = + HashMap::new(); + // First try to fetch task status from the channel in *blocking* mode + let maybe_task_status: Option = tokio::select! { + task_status = rx_task_status.recv() => task_status, + _ = tasks_status_shutdown.recv() => { + info!("Stop task status reporting loop"); + drop(tasks_status_complete); + return; + } + }; + + let mut fetched_task_num = 0usize; + if let Some(task_status) = maybe_task_status { + let task_status_vec = curator_task_status_map + .entry(task_status.scheduler_id) + .or_default(); + task_status_vec.push(task_status.task_status); + fetched_task_num += 1; + } else { + info!("Channel is closed and will exit the task status report loop."); drop(tasks_status_complete); return; } - }; - - let mut fetched_task_num = 0usize; - if let Some(task_status) = maybe_task_status { - let task_status_vec = curator_task_status_map - .entry(task_status.scheduler_id) - .or_default(); - task_status_vec.push(task_status.task_status); - fetched_task_num += 1; - } else { - info!("Channel is closed and will exit the task status report loop."); - drop(tasks_status_complete); - return; - } - // Then try to fetch by non-blocking mode to fetch as much finished tasks as possible - loop { - match rx_task_status.try_recv() { - Ok(task_status) => { - let task_status_vec = curator_task_status_map - .entry(task_status.scheduler_id) - .or_default(); - task_status_vec.push(task_status.task_status); - fetched_task_num += 1; - } - Err(TryRecvError::Empty) => { - info!("Fetched {fetched_task_num} tasks status to report"); - break; - } - Err(TryRecvError::Disconnected) => { - info!("Channel is closed and will exit the task status report loop"); - drop(tasks_status_complete); - return; + // Then try to fetch by non-blocking mode to fetch as much finished tasks as possible + loop { + match rx_task_status.try_recv() { + Ok(task_status) => { + let task_status_vec = curator_task_status_map + .entry(task_status.scheduler_id) + .or_default(); + task_status_vec.push(task_status.task_status); + fetched_task_num += 1; + } + Err(TryRecvError::Empty) => { + info!( + "Fetched {fetched_task_num} tasks status to report" + ); + break; + } + Err(TryRecvError::Disconnected) => { + info!("Channel is closed and will exit the task status report loop"); + drop(tasks_status_complete); + return; + } } } - } - for (scheduler_id, tasks_status) in curator_task_status_map.into_iter() { - match executor_server.get_scheduler_client(&scheduler_id).await { - Ok(mut scheduler) => { - if let Err(e) = scheduler - .update_task_status(UpdateTaskStatusParams { - executor_id: executor_server - .executor - .metadata - .id - .clone(), - task_status: tasks_status.clone(), - }) - .await - { + for (scheduler_id, tasks_status) in + curator_task_status_map.into_iter() + { + match executor_server.get_scheduler_client(&scheduler_id).await { + Ok(mut scheduler) => { + if let Err(e) = scheduler + .update_task_status(UpdateTaskStatusParams { + executor_id: executor_server + .executor + .metadata + .id + .clone(), + task_status: tasks_status.clone(), + }) + .await + { + error!( + "Fail to update tasks {tasks_status:?} due to {e:?}" + ); + } + } + Err(e) => { error!( - "Fail to update tasks {tasks_status:?} due to {e:?}" + "Fail to connect to scheduler {scheduler_id} due to {e:?}" ); } } - Err(e) => { - error!( - "Fail to connect to scheduler {scheduler_id} due to {e:?}" - ); - } } } - } - }); + }); + } //2. loop for task fetching and running let executor_server = self.executor_server.clone(); diff --git a/ballista/executor/src/lib.rs b/ballista/executor/src/lib.rs index 2636516247..d41c8e2bc0 100644 --- a/ballista/executor/src/lib.rs +++ b/ballista/executor/src/lib.rs @@ -28,6 +28,7 @@ pub mod executor_server; pub mod flight_service; pub mod metrics; pub mod shutdown; +pub mod status_store; pub mod terminate; mod cpu_bound_executor; diff --git a/ballista/executor/src/status_store.rs b/ballista/executor/src/status_store.rs new file mode 100644 index 0000000000..6ce78b9a5c --- /dev/null +++ b/ballista/executor/src/status_store.rs @@ -0,0 +1,68 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use ballista_core::serde::protobuf::TaskStatus; +use dashmap::DashMap; +use std::collections::VecDeque; +use std::sync::{Arc, Mutex}; + +#[derive(Default)] +pub struct ExecutorStatusStore { + task_statuses: DashMap>>>, +} + +impl ExecutorStatusStore { + #[must_use] + pub fn new() -> Self { + Self::default() + } + + pub fn record_task_status(&self, scheduler_id: String, status: TaskStatus) { + let queue = self + .task_statuses + .entry(scheduler_id) + .or_insert_with(|| Arc::new(Mutex::new(VecDeque::new()))) + .value() + .clone(); + let lock_result = queue.lock(); + if let Ok(mut guard) = lock_result { + guard.push_back(status); + } + } + + pub fn drain_task_statuses( + &self, + scheduler_id: &str, + max_count: usize, + ) -> Vec { + let Some(entry) = self.task_statuses.get(scheduler_id) else { + return Vec::new(); + }; + let queue = entry.value().clone(); + let Ok(mut queue) = queue.lock() else { + return Vec::new(); + }; + + let count = if max_count == 0 { + queue.len() + } else { + max_count.min(queue.len()) + }; + + (0..count).filter_map(|_| queue.pop_front()).collect() + } +} diff --git a/ballista/scheduler/src/config.rs b/ballista/scheduler/src/config.rs index e6ee57ff64..a6f2344885 100644 --- a/ballista/scheduler/src/config.rs +++ b/ballista/scheduler/src/config.rs @@ -207,6 +207,8 @@ pub struct SchedulerConfig { pub grpc_server_max_decoding_message_size: u32, /// The maximum size of an encoded message at the grpc server side. pub grpc_server_max_encoding_message_size: u32, + /// Whether to use TLS when connecting to executor gRPC services. + pub executor_grpc_use_tls: bool, /// The executor timeout in seconds. It should be longer than executor's heartbeat intervals. pub executor_timeout_seconds: u64, /// The interval to check expired or dead executors @@ -242,6 +244,7 @@ impl Default for SchedulerConfig { scheduler_event_expected_processing_duration: 0, grpc_server_max_decoding_message_size: 16777216, grpc_server_max_encoding_message_size: 16777216, + executor_grpc_use_tls: false, executor_timeout_seconds: 180, expire_dead_executor_interval_seconds: 15, override_config_producer: None, @@ -316,6 +319,11 @@ impl SchedulerConfig { self } + pub fn with_executor_grpc_use_tls(mut self, use_tls: bool) -> Self { + self.executor_grpc_use_tls = use_tls; + self + } + pub fn with_job_resubmit_interval_ms(mut self, interval_ms: u64) -> Self { self.job_resubmit_interval_ms = Some(interval_ms); self @@ -461,6 +469,7 @@ impl TryFrom for SchedulerConfig { .grpc_server_max_decoding_message_size, grpc_server_max_encoding_message_size: opt .grpc_server_max_encoding_message_size, + executor_grpc_use_tls: false, executor_timeout_seconds: opt.executor_timeout_seconds, expire_dead_executor_interval_seconds: opt .expire_dead_executor_interval_seconds, diff --git a/ballista/scheduler/src/scheduler_server/mod.rs b/ballista/scheduler/src/scheduler_server/mod.rs index e741601e4f..0c5d87eca0 100644 --- a/ballista/scheduler/src/scheduler_server/mod.rs +++ b/ballista/scheduler/src/scheduler_server/mod.rs @@ -221,7 +221,7 @@ impl SchedulerServer, diff --git a/ballista/scheduler/src/state/executor_manager.rs b/ballista/scheduler/src/state/executor_manager.rs index 0db3d4487e..a11e820934 100644 --- a/ballista/scheduler/src/state/executor_manager.rs +++ b/ballista/scheduler/src/state/executor_manager.rs @@ -251,7 +251,7 @@ impl ExecutorManager { metadata.id, specification.total_task_slots ); - ExecutorManager::test_connectivity(&metadata).await?; + self.test_connectivity(&metadata).await?; self.cluster_state .register_executor(metadata, specification) @@ -330,7 +330,7 @@ impl ExecutorManager { .unwrap_or_default() } - pub(crate) async fn save_executor_heartbeat( + pub async fn save_executor_heartbeat( &self, heartbeat: ExecutorHeartbeat, ) -> Result<()> { @@ -423,10 +423,7 @@ impl ExecutorManager { Ok(client) } else { let executor_metadata = self.get_executor_metadata(executor_id).await?; - let executor_url = format!( - "http://{}:{}", - executor_metadata.host, executor_metadata.grpc_port - ); + let executor_url = self.executor_url(&executor_metadata); let mut endpoint = create_grpc_client_endpoint(executor_url)?; if let Some(ref override_fn) = @@ -446,22 +443,33 @@ impl ExecutorManager { } #[cfg(not(test))] - async fn test_connectivity(metadata: &ExecutorMetadata) -> Result<()> { - let executor_url = format!("http://{}:{}", metadata.host, metadata.grpc_port); + async fn test_connectivity(&self, metadata: &ExecutorMetadata) -> Result<()> { + let executor_url = self.executor_url(metadata); debug!("Connecting to executor {executor_url:?}"); - let _ = protobuf::executor_grpc_client::ExecutorGrpcClient::connect(executor_url) - .await - .map_err(|e| { - BallistaError::Internal(format!( - "Failed to register executor at {}:{}, could not connect: {:?}", - metadata.host, metadata.grpc_port, e - )) - })?; + let mut endpoint = create_grpc_client_endpoint(executor_url)?; + if let Some(ref override_fn) = self.config.override_create_grpc_client_endpoint { + endpoint = override_fn(endpoint)?; + } + let _ = endpoint.connect().await.map_err(|e| { + BallistaError::Internal(format!( + "Failed to register executor at {}:{}, could not connect: {:?}", + metadata.host, metadata.grpc_port, e + )) + })?; Ok(()) } #[cfg(test)] - async fn test_connectivity(_metadata: &ExecutorMetadata) -> Result<()> { + async fn test_connectivity(&self, _metadata: &ExecutorMetadata) -> Result<()> { Ok(()) } + + fn executor_url(&self, metadata: &ExecutorMetadata) -> String { + let scheme = if self.config.executor_grpc_use_tls { + "https" + } else { + "http" + }; + format!("{}://{}:{}", scheme, metadata.host, metadata.grpc_port) + } }