From 9fcff62a720282e8da4effc0cf7129b84010a3d2 Mon Sep 17 00:00:00 2001 From: Luke Kim <80174+lukekim@users.noreply.github.com> Date: Wed, 21 Jan 2026 21:45:20 -0800 Subject: [PATCH 1/6] feat: Store shuffles in object store (S3, Azure) --- Cargo.lock | 2 + Cargo.toml | 3 +- ballista/core/Cargo.toml | 3 +- ballista/core/src/config.rs | 22 + .../src/execution_plans/shuffle_reader.rs | 199 ++++- ballista/core/src/extension.rs | 49 +- ballista/core/src/lib.rs | 3 + ballista/core/src/shuffle_storage.rs | 755 ++++++++++++++++++ 8 files changed, 1021 insertions(+), 15 deletions(-) create mode 100644 ballista/core/src/shuffle_storage.rs diff --git a/Cargo.lock b/Cargo.lock index cc2a39129f..ef14ba781e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -992,6 +992,7 @@ dependencies = [ "async-trait", "aws-config", "aws-credential-types", + "bytes", "chrono", "clap 4.5.54", "datafusion", @@ -3812,6 +3813,7 @@ dependencies = [ "futures", "http 1.4.0", "http-body-util", + "httparse", "humantime", "hyper", "itertools", diff --git a/Cargo.toml b/Cargo.toml index bd73e3601a..5693cabcb7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -40,7 +40,8 @@ datafusion-cli = "51.0.0" datafusion-proto = "51.0.0" datafusion-proto-common = "51.0.0" datafusion-substrait = "51.0.0" -object_store = "0.12" +object_store = { version = "0.12", features = ["aws", "azure"] } +bytes = "1.5" prost = "0.14" prost-types = "0.14" rstest = { version = "0.26" } diff --git a/ballista/core/Cargo.toml b/ballista/core/Cargo.toml index 4eebdc7007..5eec1c19e0 100644 --- a/ballista/core/Cargo.toml +++ b/ballista/core/Cargo.toml @@ -45,6 +45,7 @@ arrow-flight = { workspace = true } async-trait = { workspace = true } aws-config = { version = "1.6.0", optional = true } aws-credential-types = { version = "1.2.0", optional = true } +bytes = { workspace = true } chrono = { version = "0.4", default-features = false } clap = { workspace = true, optional = true } datafusion = { workspace = true } @@ -54,7 +55,7 @@ futures = { workspace = true } itertools = "0.14" log = { workspace = true } md-5 = { version = "^0.10.0" } -object_store = { workspace = true, features = ["aws", "http"], optional = true } +object_store = { workspace = true, features = ["aws", "azure", "http"], optional = true } parking_lot = { workspace = true } prost = { workspace = true } prost-types = { workspace = true } diff --git a/ballista/core/src/config.rs b/ballista/core/src/config.rs index 1e6a593fd2..e45fe1afa1 100644 --- a/ballista/core/src/config.rs +++ b/ballista/core/src/config.rs @@ -43,6 +43,10 @@ pub const BALLISTA_SHUFFLE_READER_FORCE_REMOTE_READ: &str = /// Configuration key to prefer Flight protocol for remote shuffle reads. pub const BALLISTA_SHUFFLE_READER_REMOTE_PREFER_FLIGHT: &str = "ballista.shuffle.remote_read_prefer_flight"; +/// Configuration key for shuffle storage type (local, s3, azure). +pub const BALLISTA_SHUFFLE_STORAGE_TYPE: &str = "ballista.shuffle.storage_type"; +/// Configuration key for shuffle storage base URL/path. +pub const BALLISTA_SHUFFLE_STORAGE_URL: &str = "ballista.shuffle.storage_url"; /// Configuration key for gRPC client connection timeout in seconds. pub const BALLISTA_GRPC_CLIENT_CONNECT_TIMEOUT_SECONDS: &str = @@ -85,6 +89,14 @@ static CONFIG_ENTRIES: LazyLock> = LazyLock::new(|| "Forces the shuffle reader to use flight reader instead of block reader for remote read. Block reader usually has better performance and resource utilization".to_string(), DataType::Boolean, Some((false).to_string())), + ConfigEntry::new(BALLISTA_SHUFFLE_STORAGE_TYPE.to_string(), + "Storage type for shuffle data: 'local' (default), 's3', or 'azure'".to_string(), + DataType::Utf8, + Some("local".to_string())), + ConfigEntry::new(BALLISTA_SHUFFLE_STORAGE_URL.to_string(), + "Base URL/path for shuffle storage. For local: file path; For S3: s3://bucket/prefix; For Azure: abfs://container@account.dfs.core.windows.net/prefix".to_string(), + DataType::Utf8, + None), ConfigEntry::new(BALLISTA_GRPC_CLIENT_CONNECT_TIMEOUT_SECONDS.to_string(), "Connection timeout for gRPC client in seconds".to_string(), DataType::UInt64, @@ -264,6 +276,16 @@ impl BallistaConfig { self.get_bool_setting(BALLISTA_SHUFFLE_READER_REMOTE_PREFER_FLIGHT) } + /// Returns the shuffle storage type (local, s3, azure). + pub fn shuffle_storage_type(&self) -> String { + self.get_string_setting(BALLISTA_SHUFFLE_STORAGE_TYPE) + } + + /// Returns the shuffle storage base URL/path if configured. + pub fn shuffle_storage_url(&self) -> Option { + self.settings.get(BALLISTA_SHUFFLE_STORAGE_URL).cloned() + } + fn get_usize_setting(&self, key: &str) -> usize { if let Some(v) = self.settings.get(key) { // infallible because we validate all configs in the constructor diff --git a/ballista/core/src/execution_plans/shuffle_reader.rs b/ballista/core/src/execution_plans/shuffle_reader.rs index ad9cea9bae..8d0ebbc4d4 100644 --- a/ballista/core/src/execution_plans/shuffle_reader.rs +++ b/ballista/core/src/execution_plans/shuffle_reader.rs @@ -22,12 +22,21 @@ use std::any::Any; use std::collections::HashMap; use std::fmt::Debug; use std::fs::File; -use std::io::BufReader; +use std::io::{BufReader, Cursor}; use std::pin::Pin; use std::result; use std::sync::Arc; use std::task::{Context, Poll}; +#[cfg(feature = "build-binary")] +use object_store::aws::AmazonS3Builder; +#[cfg(feature = "build-binary")] +use object_store::azure::MicrosoftAzureBuilder; +#[cfg(feature = "build-binary")] +use object_store::ObjectStore; +#[cfg(feature = "build-binary")] +use url::Url; + use crate::client::BallistaClient; use crate::extension::{BallistaConfigGrpcEndpoint, SessionConfigExt}; use crate::serde::scheduler::{PartitionLocation, PartitionStats}; @@ -371,23 +380,34 @@ impl Stream for AbortableReceiverStream { .map_err(|e| ArrowError::ExternalError(Box::new(e))) } } -/// Splits the provided partition locations into local and remote partitions. +/// Splits the provided partition locations into local, object store, and remote partitions. /// Local partitions are read directly from local Arrow IPC files, +/// object store partitions are read via the object store client, /// while remote partitions are fetched using the Arrow Flight client. /// If `force_remote_read` is true, all partitions are treated as remote. fn local_remote_read_split( partition_locations: Vec, force_remote_read: bool, -) -> (Vec, Vec) { +) -> (Vec, Vec, Vec) { if !force_remote_read { - partition_locations + let (local, non_local): (Vec<_>, Vec<_>) = partition_locations + .into_iter() + .partition(check_is_local_location); + let (object_store, remote): (Vec<_>, Vec<_>) = non_local .into_iter() - .partition(check_is_local_location) + .partition(check_is_object_store_location); + (local, object_store, remote) } else { - (vec![], partition_locations) + (vec![], vec![], partition_locations) } } +/// Check if the location is an object store path (S3 or Azure). +fn check_is_object_store_location(location: &PartitionLocation) -> bool { + let path = location.path.as_str(); + path.starts_with("s3://") || path.starts_with("abfs://") || path.starts_with("az://") +} + fn send_fetch_partitions( partition_locations: Vec, max_request_num: usize, @@ -401,12 +421,13 @@ fn send_fetch_partitions( let semaphore = Arc::new(Semaphore::new(max_request_num)); let mut spawned_tasks: Vec> = vec![]; - let (local_locations, remote_locations): (Vec<_>, Vec<_>) = + let (local_locations, object_store_locations, remote_locations): (Vec<_>, Vec<_>, Vec<_>) = local_remote_read_split(partition_locations, force_remote_read); debug!( - "local shuffle file counts:{}, remote shuffle file count:{}.", + "local shuffle file counts:{}, object store shuffle file count:{}, remote shuffle file count:{}.", local_locations.len(), + object_store_locations.len(), remote_locations.len() ); @@ -430,6 +451,31 @@ fn send_fetch_partitions( } })); + // Handle object store partitions with concurrency control + for p in object_store_locations.into_iter() { + let semaphore = semaphore.clone(); + let response_sender = response_sender.clone(); + spawned_tasks.push(SpawnedTask::spawn(async move { + // Block if exceeds max request number. + let permit = semaphore.acquire_owned().await.unwrap(); + let r = PartitionReaderEnum::ObjectStoreRemote + .fetch_partition( + &p, + max_message_size, + false, // flight_transport not used for object store + None, // customize_endpoint not used for object store + false, // use_tls not used for object store + ) + .await; + // Block if the channel buffer is full. + if let Err(e) = response_sender.send(r).await { + error!("Fail to send response event to the channel due to {e}"); + } + // Increase semaphore by dropping existing permits. + drop(permit); + })); + } + for p in remote_locations.into_iter() { let semaphore = semaphore.clone(); let response_sender = response_sender.clone(); @@ -590,14 +636,143 @@ fn fetch_partition_local_inner( Ok(reader) } +#[cfg(feature = "build-binary")] +async fn fetch_partition_object_store( + location: &PartitionLocation, +) -> result::Result { + let path = &location.path; + let metadata = &location.executor_meta; + let partition_id = &location.partition_id; + + debug!("Fetching shuffle partition from object store: {}", path); + + let batches = fetch_partition_object_store_inner(path).await.map_err(|e| { + // return BallistaError::FetchFailed may let scheduler retry this task. + BallistaError::FetchFailed( + metadata.id.clone(), + partition_id.stage_id, + partition_id.partition_id, + e.to_string(), + ) + })?; + + if batches.is_empty() { + return Err(BallistaError::General(format!( + "No batches found in shuffle partition at {}", + path + ))); + } + + let schema = batches[0].schema(); + let stream = futures::stream::iter(batches.into_iter().map(Ok)); + Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream))) +} + +#[cfg(not(feature = "build-binary"))] async fn fetch_partition_object_store( _location: &PartitionLocation, ) -> result::Result { Err(BallistaError::NotImplemented( - "Should not use ObjectStorePartitionReader".to_string(), + "Object store support requires 'build-binary' feature".to_string(), )) } +#[cfg(feature = "build-binary")] +async fn fetch_partition_object_store_inner( + path: &str, +) -> result::Result, BallistaError> { + use object_store::path::Path as ObjectPath; + + let url = Url::parse(path).map_err(|e| { + BallistaError::General(format!("Failed to parse object store URL '{}': {:?}", path, e)) + })?; + + let scheme = url.scheme(); + let store: Arc = match scheme { + "s3" => { + let bucket = url.host_str().ok_or_else(|| { + BallistaError::General(format!("No bucket in S3 URL: {}", path)) + })?; + let builder = AmazonS3Builder::from_env().with_bucket_name(bucket); + Arc::new(builder.build().map_err(|e| { + BallistaError::General(format!("Failed to create S3 client: {:?}", e)) + })?) + } + "abfs" | "az" => { + // Parse Azure URL: abfs://container@account.dfs.core.windows.net/path + let host = url.host_str().ok_or_else(|| { + BallistaError::General(format!("No host in Azure URL: {}", path)) + })?; + + // Extract container from username portion + let container = url.username(); + if container.is_empty() { + return Err(BallistaError::General(format!( + "No container in Azure URL. Expected format: abfs://container@account.dfs.core.windows.net/path. Got: {}", + path + ))); + } + + // Extract account from host (account.dfs.core.windows.net) + let account = host.split('.').next().ok_or_else(|| { + BallistaError::General(format!("No account in Azure URL: {}", path)) + })?; + + let builder = MicrosoftAzureBuilder::from_env() + .with_account(account) + .with_container_name(container); + Arc::new(builder.build().map_err(|e| { + BallistaError::General(format!("Failed to create Azure client: {:?}", e)) + })?) + } + _ => { + return Err(BallistaError::General(format!( + "Unsupported object store scheme: {}. Supported: s3, abfs, az", + scheme + ))); + } + }; + + // Extract the object path from the URL + let object_path = ObjectPath::from(url.path().trim_start_matches('/')); + + debug!("Reading object from path: {:?}", object_path); + + let get_result = store.get(&object_path).await.map_err(|e| { + BallistaError::General(format!( + "Failed to read object from {}: {:?}", + path, e + )) + })?; + + let bytes = get_result.bytes().await.map_err(|e| { + BallistaError::General(format!( + "Failed to read bytes from {}: {:?}", + path, e + )) + })?; + + let cursor = Cursor::new(bytes.to_vec()); + let stream_reader = StreamReader::try_new(cursor, None).map_err(|e| { + BallistaError::General(format!( + "Failed to create Arrow stream reader for {}: {:?}", + path, e + )) + })?; + + let mut batches = Vec::new(); + for batch_result in stream_reader { + batches.push(batch_result.map_err(|e| { + BallistaError::General(format!( + "Failed to read batch from {}: {:?}", + path, e + )) + })?); + } + + Ok(batches) +} + #[cfg(test)] mod tests { use super::*; @@ -955,14 +1130,16 @@ mod tests { let partition_locations = get_test_partition_locations(1, file_path.to_str().unwrap().to_string()); - let (local, remote) = local_remote_read_split(partition_locations.clone(), false); + let (local, object_store, remote) = local_remote_read_split(partition_locations.clone(), false); assert!(!local.is_empty()); + assert!(object_store.is_empty()); assert!(remote.is_empty()); - let (local, remote) = local_remote_read_split(partition_locations, true); + let (local, object_store, remote) = local_remote_read_split(partition_locations, true); assert!(local.is_empty()); + assert!(object_store.is_empty()); assert!(!remote.is_empty()); } diff --git a/ballista/core/src/extension.rs b/ballista/core/src/extension.rs index 0ba9eeda55..afe705dc00 100644 --- a/ballista/core/src/extension.rs +++ b/ballista/core/src/extension.rs @@ -18,8 +18,8 @@ use crate::config::{ BALLISTA_GRPC_CLIENT_MAX_MESSAGE_SIZE, BALLISTA_JOB_NAME, BALLISTA_SHUFFLE_READER_FORCE_REMOTE_READ, BALLISTA_SHUFFLE_READER_MAX_REQUESTS, - BALLISTA_SHUFFLE_READER_REMOTE_PREFER_FLIGHT, BALLISTA_STANDALONE_PARALLELISM, - BallistaConfig, + BALLISTA_SHUFFLE_READER_REMOTE_PREFER_FLIGHT, BALLISTA_SHUFFLE_STORAGE_TYPE, + BALLISTA_SHUFFLE_STORAGE_URL, BALLISTA_STANDALONE_PARALLELISM, BallistaConfig, }; use crate::planner::BallistaQueryPlanner; use crate::serde::protobuf::KeyValuePair; @@ -175,6 +175,18 @@ pub trait SessionConfigExt { /// Get whether to use TLS for executor connections fn ballista_use_tls(&self) -> bool; + + /// Returns the shuffle storage type (local, s3, azure). + fn ballista_shuffle_storage_type(&self) -> String; + + /// Sets the shuffle storage type. + fn with_ballista_shuffle_storage_type(self, storage_type: &str) -> Self; + + /// Returns the shuffle storage base URL/path if configured. + fn ballista_shuffle_storage_url(&self) -> Option; + + /// Sets the shuffle storage base URL/path. + fn with_ballista_shuffle_storage_url(self, url: &str) -> Self; } /// [SessionConfigHelperExt] is set of [SessionConfig] extension methods @@ -459,6 +471,39 @@ impl SessionConfigExt for SessionConfig { .map(|ext| ext.0) .unwrap_or(false) } + + fn ballista_shuffle_storage_type(&self) -> String { + self.options() + .extensions + .get::() + .map(|c| c.shuffle_storage_type()) + .unwrap_or_else(|| BallistaConfig::default().shuffle_storage_type()) + } + + fn with_ballista_shuffle_storage_type(self, storage_type: &str) -> Self { + if self.options().extensions.get::().is_some() { + self.set_str(BALLISTA_SHUFFLE_STORAGE_TYPE, storage_type) + } else { + self.with_option_extension(BallistaConfig::default()) + .set_str(BALLISTA_SHUFFLE_STORAGE_TYPE, storage_type) + } + } + + fn ballista_shuffle_storage_url(&self) -> Option { + self.options() + .extensions + .get::() + .and_then(|c| c.shuffle_storage_url()) + } + + fn with_ballista_shuffle_storage_url(self, url: &str) -> Self { + if self.options().extensions.get::().is_some() { + self.set_str(BALLISTA_SHUFFLE_STORAGE_URL, url) + } else { + self.with_option_extension(BallistaConfig::default()) + .set_str(BALLISTA_SHUFFLE_STORAGE_URL, url) + } + } } impl SessionConfigHelperExt for SessionConfig { diff --git a/ballista/core/src/lib.rs b/ballista/core/src/lib.rs index 8c483e5a31..bfdc7dc4f9 100644 --- a/ballista/core/src/lib.rs +++ b/ballista/core/src/lib.rs @@ -54,6 +54,9 @@ pub mod planner; pub mod registry; /// Serialization and deserialization for Ballista messages and plans. pub mod serde; +/// Shuffle storage abstraction for local and object store backends. +#[cfg(feature = "build-binary")] +pub mod shuffle_storage; /// General utility functions for Ballista operations. diff --git a/ballista/core/src/shuffle_storage.rs b/ballista/core/src/shuffle_storage.rs new file mode 100644 index 0000000000..6d67e9d859 --- /dev/null +++ b/ballista/core/src/shuffle_storage.rs @@ -0,0 +1,755 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Shuffle storage abstraction for storing shuffle data on local disk or object stores. +//! +//! This module provides a unified interface for reading and writing shuffle data +//! to different storage backends including local filesystem, Amazon S3, and Azure Blob Storage. + +use async_trait::async_trait; +use bytes::Bytes; +use datafusion::arrow::datatypes::SchemaRef; +use datafusion::arrow::ipc::reader::StreamReader; +use datafusion::arrow::ipc::writer::IpcWriteOptions; +use datafusion::arrow::ipc::writer::StreamWriter; +use datafusion::arrow::ipc::CompressionType; +use datafusion::arrow::record_batch::RecordBatch; +use datafusion::physical_plan::metrics; +use futures::StreamExt; +use log::{debug, error}; +use object_store::azure::MicrosoftAzureBuilder; +use object_store::aws::AmazonS3Builder; +use object_store::path::Path as ObjectPath; +use object_store::{ObjectStore, PutPayload}; +use std::fmt::{Debug, Display}; +use std::fs::File; +use std::io::{BufReader, Cursor}; +use std::path::PathBuf; +use std::sync::Arc; +use url::Url; + +use crate::error::{BallistaError, Result}; +use crate::serde::scheduler::PartitionStats; + +/// Defines the type of storage to use for shuffle data. +#[derive(Clone, Debug, Default, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +pub enum ShuffleStorageType { + /// Store shuffle data on local disk (default behavior). + #[default] + Local, + /// Store shuffle data in Amazon S3. + S3, + /// Store shuffle data in Azure Blob Storage (ABFS). + Azure, +} + +impl Display for ShuffleStorageType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ShuffleStorageType::Local => write!(f, "local"), + ShuffleStorageType::S3 => write!(f, "s3"), + ShuffleStorageType::Azure => write!(f, "azure"), + } + } +} + +impl std::str::FromStr for ShuffleStorageType { + type Err = String; + + fn from_str(s: &str) -> std::result::Result { + match s.to_lowercase().as_str() { + "local" | "disk" => Ok(ShuffleStorageType::Local), + "s3" | "aws" => Ok(ShuffleStorageType::S3), + "azure" | "abfs" | "adls" => Ok(ShuffleStorageType::Azure), + _ => Err(format!( + "Unknown shuffle storage type: '{}'. Valid options are: local, s3, azure", + s + )), + } + } +} + +/// Configuration for S3 shuffle storage. +#[derive(Clone, Debug, Default)] +pub struct S3ShuffleConfig { + /// S3 bucket name for shuffle data. + pub bucket: Option, + /// AWS region. + pub region: Option, + /// S3 endpoint URL (for MinIO or custom S3-compatible storage). + pub endpoint: Option, + /// AWS access key ID. + pub access_key_id: Option, + /// AWS secret access key. + pub secret_access_key: Option, + /// Allow HTTP connections (default is HTTPS only). + pub allow_http: bool, +} + +/// Configuration for Azure Blob Storage shuffle storage. +#[derive(Clone, Debug, Default)] +pub struct AzureShuffleConfig { + /// Azure storage account name. + pub account: Option, + /// Azure storage container name. + pub container: Option, + /// Azure storage access key. + pub access_key: Option, + /// Azure SAS token (alternative to access key). + pub sas_token: Option, +} + +/// Configuration for shuffle storage. +#[derive(Clone, Debug, Default)] +pub struct ShuffleStorageConfig { + /// The type of storage to use. + pub storage_type: ShuffleStorageType, + /// Base URL/path for shuffle data storage. + /// For local: file path (e.g., /tmp/ballista) + /// For S3: s3://bucket/prefix + /// For Azure: abfs://container@account.dfs.core.windows.net/prefix + pub base_url: Option, + /// S3-specific configuration. + pub s3_config: S3ShuffleConfig, + /// Azure-specific configuration. + pub azure_config: AzureShuffleConfig, +} + +impl ShuffleStorageConfig { + /// Creates a new local storage configuration with the given work directory. + pub fn new_local(work_dir: &str) -> Self { + Self { + storage_type: ShuffleStorageType::Local, + base_url: Some(work_dir.to_string()), + ..Default::default() + } + } + + /// Creates a new S3 storage configuration. + pub fn new_s3(bucket: &str, prefix: Option<&str>, region: Option<&str>) -> Self { + let base_url = match prefix { + Some(p) => format!("s3://{}/{}", bucket, p), + None => format!("s3://{}", bucket), + }; + Self { + storage_type: ShuffleStorageType::S3, + base_url: Some(base_url), + s3_config: S3ShuffleConfig { + bucket: Some(bucket.to_string()), + region: region.map(|s| s.to_string()), + ..Default::default() + }, + ..Default::default() + } + } + + /// Creates a new Azure Blob Storage configuration. + pub fn new_azure(account: &str, container: &str, prefix: Option<&str>) -> Self { + let base_url = match prefix { + Some(p) => format!("abfs://{}@{}.dfs.core.windows.net/{}", container, account, p), + None => format!("abfs://{}@{}.dfs.core.windows.net", container, account), + }; + Self { + storage_type: ShuffleStorageType::Azure, + base_url: Some(base_url), + azure_config: AzureShuffleConfig { + account: Some(account.to_string()), + container: Some(container.to_string()), + ..Default::default() + }, + ..Default::default() + } + } +} + +/// Trait for shuffle storage operations. +#[async_trait] +pub trait ShuffleStorage: Send + Sync + Debug { + /// Write a record batch to storage and return the path where it was written. + async fn write_shuffle_data( + &self, + job_id: &str, + stage_id: usize, + partition_id: usize, + input_partition: usize, + batches: Vec, + schema: SchemaRef, + write_metric: &metrics::Time, + ) -> Result<(String, PartitionStats)>; + + /// Read shuffle data from storage. + async fn read_shuffle_data(&self, path: &str) -> Result>; + + /// Delete shuffle data for a job. + async fn delete_job_data(&self, job_id: &str) -> Result<()>; + + /// Get the base path/URL for this storage. + fn base_path(&self) -> &str; + + /// Check if a path is accessible by this storage backend. + fn can_handle(&self, path: &str) -> bool; +} + +/// Local filesystem shuffle storage implementation. +#[derive(Debug)] +pub struct LocalShuffleStorage { + work_dir: String, +} + +impl LocalShuffleStorage { + /// Creates a new local shuffle storage with the given work directory. + pub fn new(work_dir: &str) -> Self { + Self { + work_dir: work_dir.to_string(), + } + } +} + +#[async_trait] +impl ShuffleStorage for LocalShuffleStorage { + async fn write_shuffle_data( + &self, + job_id: &str, + stage_id: usize, + partition_id: usize, + input_partition: usize, + batches: Vec, + schema: SchemaRef, + write_metric: &metrics::Time, + ) -> Result<(String, PartitionStats)> { + let mut path = PathBuf::from(&self.work_dir); + path.push(job_id); + path.push(format!("{}", stage_id)); + path.push(format!("{}", partition_id)); + std::fs::create_dir_all(&path)?; + + let filename = if input_partition == partition_id { + "data.arrow".to_string() + } else { + format!("data-{}.arrow", input_partition) + }; + path.push(&filename); + + let path_str = path.to_str().unwrap().to_string(); + debug!("Writing shuffle data to local path: {}", path_str); + + let timer = write_metric.timer(); + let file = File::create(&path).map_err(|e| { + error!("Failed to create shuffle file at {}: {:?}", path_str, e); + BallistaError::IoError(e) + })?; + + let options = IpcWriteOptions::default() + .try_with_compression(Some(CompressionType::LZ4_FRAME))?; + + let mut writer = StreamWriter::try_new_with_options(file, schema.as_ref(), options)?; + + let mut num_rows = 0; + let mut num_batches = 0; + let mut num_bytes = 0; + + for batch in batches { + num_batches += 1; + num_rows += batch.num_rows(); + num_bytes += batch.get_array_memory_size(); + writer.write(&batch)?; + } + + writer.finish()?; + timer.done(); + + let stats = PartitionStats::new( + Some(num_rows as u64), + Some(num_batches), + Some(num_bytes as u64), + ); + + Ok((path_str, stats)) + } + + async fn read_shuffle_data(&self, path: &str) -> Result> { + let file = File::open(path).map_err(|e| { + BallistaError::General(format!("Failed to open shuffle file at {}: {:?}", path, e)) + })?; + let reader = BufReader::new(file); + let stream_reader = StreamReader::try_new(reader, None)?; + + let mut batches = Vec::new(); + for batch_result in stream_reader { + batches.push(batch_result?); + } + + Ok(batches) + } + + async fn delete_job_data(&self, job_id: &str) -> Result<()> { + let mut path = PathBuf::from(&self.work_dir); + path.push(job_id); + if path.exists() { + std::fs::remove_dir_all(&path)?; + } + Ok(()) + } + + fn base_path(&self) -> &str { + &self.work_dir + } + + fn can_handle(&self, path: &str) -> bool { + // Local storage can handle paths that don't start with a URL scheme + !path.starts_with("s3://") && !path.starts_with("abfs://") && !path.starts_with("az://") + } +} + +/// Object store based shuffle storage implementation (for S3 and Azure). +#[derive(Debug)] +pub struct ObjectStoreShuffleStorage { + store: Arc, + base_url: String, + storage_type: ShuffleStorageType, +} + +impl ObjectStoreShuffleStorage { + /// Creates a new S3 shuffle storage. + pub fn new_s3(config: &ShuffleStorageConfig) -> Result { + let s3_config = &config.s3_config; + let bucket = s3_config.bucket.as_ref().ok_or_else(|| { + BallistaError::General("S3 bucket not configured".to_string()) + })?; + + let mut builder = AmazonS3Builder::from_env().with_bucket_name(bucket); + + if let Some(region) = &s3_config.region { + builder = builder.with_region(region); + } + + if let Some(endpoint) = &s3_config.endpoint { + builder = builder.with_endpoint(endpoint); + } + + if let (Some(access_key), Some(secret_key)) = + (&s3_config.access_key_id, &s3_config.secret_access_key) + { + builder = builder + .with_access_key_id(access_key) + .with_secret_access_key(secret_key); + } + + if s3_config.allow_http { + builder = builder.with_allow_http(true); + } + + let store = builder.build().map_err(|e| { + BallistaError::General(format!("Failed to create S3 object store: {:?}", e)) + })?; + + let base_url = config + .base_url + .clone() + .unwrap_or_else(|| format!("s3://{}", bucket)); + + Ok(Self { + store: Arc::new(store), + base_url, + storage_type: ShuffleStorageType::S3, + }) + } + + /// Creates a new Azure Blob Storage shuffle storage. + pub fn new_azure(config: &ShuffleStorageConfig) -> Result { + let azure_config = &config.azure_config; + let account = azure_config.account.as_ref().ok_or_else(|| { + BallistaError::General("Azure storage account not configured".to_string()) + })?; + let container = azure_config.container.as_ref().ok_or_else(|| { + BallistaError::General("Azure storage container not configured".to_string()) + })?; + + let mut builder = MicrosoftAzureBuilder::new() + .with_account(account) + .with_container_name(container); + + if let Some(access_key) = &azure_config.access_key { + builder = builder.with_access_key(access_key); + } + + if let Some(sas_token) = &azure_config.sas_token { + // Parse SAS token into key-value pairs + // SAS token format: ?sv=2021-06-08&ss=bf&srt=sco&... + let query_pairs: Vec<(String, String)> = sas_token + .trim_start_matches('?') + .split('&') + .filter_map(|pair| { + let mut parts = pair.splitn(2, '='); + match (parts.next(), parts.next()) { + (Some(key), Some(value)) => Some((key.to_string(), value.to_string())), + _ => None, + } + }) + .collect(); + builder = builder.with_sas_authorization(query_pairs); + } + + let store = builder.build().map_err(|e| { + BallistaError::General(format!("Failed to create Azure object store: {:?}", e)) + })?; + + let base_url = config.base_url.clone().unwrap_or_else(|| { + format!("abfs://{}@{}.dfs.core.windows.net", container, account) + }); + + Ok(Self { + store: Arc::new(store), + base_url, + storage_type: ShuffleStorageType::Azure, + }) + } + + /// Creates object store storage from configuration. + pub fn from_config(config: &ShuffleStorageConfig) -> Result { + match config.storage_type { + ShuffleStorageType::S3 => Self::new_s3(config), + ShuffleStorageType::Azure => Self::new_azure(config), + ShuffleStorageType::Local => Err(BallistaError::General( + "Use LocalShuffleStorage for local storage".to_string(), + )), + } + } + + fn make_path(&self, job_id: &str, stage_id: usize, partition_id: usize, input_partition: usize) -> String { + let filename = if input_partition == partition_id { + "data.arrow".to_string() + } else { + format!("data-{}.arrow", input_partition) + }; + format!("{}/{}/{}/{}", job_id, stage_id, partition_id, filename) + } +} + +#[async_trait] +impl ShuffleStorage for ObjectStoreShuffleStorage { + async fn write_shuffle_data( + &self, + job_id: &str, + stage_id: usize, + partition_id: usize, + input_partition: usize, + batches: Vec, + schema: SchemaRef, + write_metric: &metrics::Time, + ) -> Result<(String, PartitionStats)> { + let relative_path = self.make_path(job_id, stage_id, partition_id, input_partition); + let full_url = format!("{}/{}", self.base_url, relative_path); + + debug!("Writing shuffle data to object store: {}", full_url); + + let timer = write_metric.timer(); + + // Write batches to an in-memory buffer first + let mut buffer = Vec::new(); + let options = IpcWriteOptions::default() + .try_with_compression(Some(CompressionType::LZ4_FRAME))?; + + let (total_rows, total_batches) = { + let mut writer = StreamWriter::try_new_with_options( + Cursor::new(&mut buffer), + schema.as_ref(), + options, + )?; + + let mut num_rows = 0; + let mut num_batches = 0; + + for batch in &batches { + num_rows += batch.num_rows(); + num_batches += 1; + writer.write(batch)?; + } + + writer.finish()?; + (num_rows, num_batches) + }; + + let num_bytes = buffer.len(); + + // Upload to object store + let object_path = ObjectPath::from(relative_path); + let payload = PutPayload::from(Bytes::from(buffer)); + + self.store.put(&object_path, payload).await.map_err(|e| { + BallistaError::General(format!( + "Failed to upload shuffle data to {}: {:?}", + full_url, e + )) + })?; + + timer.done(); + + let stats = PartitionStats::new( + Some(batches.iter().map(|b| b.num_rows() as u64).sum()), + Some(batches.len() as u64), + Some(num_bytes as u64), + ); + + Ok((full_url, stats)) + } + + async fn read_shuffle_data(&self, path: &str) -> Result> { + // Extract the object path from the full URL + let object_path = self.extract_object_path(path)?; + + debug!("Reading shuffle data from object store: {}", path); + + let get_result = self.store.get(&object_path).await.map_err(|e| { + BallistaError::General(format!( + "Failed to read shuffle data from {}: {:?}", + path, e + )) + })?; + + let bytes = get_result.bytes().await.map_err(|e| { + BallistaError::General(format!( + "Failed to read bytes from {}: {:?}", + path, e + )) + })?; + + let cursor = Cursor::new(bytes.to_vec()); + let stream_reader = StreamReader::try_new(cursor, None)?; + + let mut batches = Vec::new(); + for batch_result in stream_reader { + batches.push(batch_result?); + } + + Ok(batches) + } + + async fn delete_job_data(&self, job_id: &str) -> Result<()> { + let prefix = ObjectPath::from(job_id.to_string()); + + // List all objects with the job_id prefix + let mut list_stream = self.store.list(Some(&prefix)); + let mut objects_to_delete = Vec::new(); + + while let Some(result) = list_stream.next().await { + match result { + Ok(meta) => objects_to_delete.push(meta.location), + Err(e) => { + return Err(BallistaError::General(format!( + "Failed to list objects for job {}: {:?}", + job_id, e + ))); + } + } + } + + // Delete all objects + for path in objects_to_delete { + self.store.delete(&path).await.map_err(|e| { + BallistaError::General(format!( + "Failed to delete object {:?}: {:?}", + path, e + )) + })?; + } + + Ok(()) + } + + fn base_path(&self) -> &str { + &self.base_url + } + + fn can_handle(&self, path: &str) -> bool { + match self.storage_type { + ShuffleStorageType::S3 => path.starts_with("s3://"), + ShuffleStorageType::Azure => { + path.starts_with("abfs://") || path.starts_with("az://") + } + ShuffleStorageType::Local => false, + } + } +} + +impl ObjectStoreShuffleStorage { + fn extract_object_path(&self, path: &str) -> Result { + // Parse the URL and extract the path component + if let Ok(url) = Url::parse(path) { + let path_str = url.path().trim_start_matches('/'); + Ok(ObjectPath::from(path_str)) + } else { + // If it's not a valid URL, assume it's already a relative path + Ok(ObjectPath::from(path)) + } + } +} + +/// Factory for creating shuffle storage instances. +pub struct ShuffleStorageFactory; + +impl ShuffleStorageFactory { + /// Creates a shuffle storage instance based on the configuration. + pub fn create(config: &ShuffleStorageConfig) -> Result> { + match config.storage_type { + ShuffleStorageType::Local => { + let work_dir = config + .base_url + .as_ref() + .ok_or_else(|| BallistaError::General("Work directory not configured".to_string()))?; + Ok(Arc::new(LocalShuffleStorage::new(work_dir))) + } + ShuffleStorageType::S3 => { + Ok(Arc::new(ObjectStoreShuffleStorage::new_s3(config)?)) + } + ShuffleStorageType::Azure => { + Ok(Arc::new(ObjectStoreShuffleStorage::new_azure(config)?)) + } + } + } + + /// Creates a local shuffle storage with the given work directory. + pub fn create_local(work_dir: &str) -> Arc { + Arc::new(LocalShuffleStorage::new(work_dir)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use datafusion::arrow::array::Int32Array; + use datafusion::arrow::datatypes::{DataType, Field, Schema}; + use datafusion::physical_plan::metrics::ExecutionPlanMetricsSet; + use tempfile::TempDir; + + fn create_test_batch() -> (RecordBatch, SchemaRef) { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + ])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + ) + .unwrap(); + (batch, schema) + } + + #[tokio::test] + async fn test_local_shuffle_storage_write_read() { + let temp_dir = TempDir::new().unwrap(); + let storage = LocalShuffleStorage::new(temp_dir.path().to_str().unwrap()); + + let (batch, schema) = create_test_batch(); + let metrics = ExecutionPlanMetricsSet::new(); + let time_metric = metrics::MetricBuilder::new(&metrics).subset_time("write_time", 0); + + let (path, stats) = storage + .write_shuffle_data( + "test_job", + 1, + 0, + 0, + vec![batch.clone()], + schema, + &time_metric, + ) + .await + .unwrap(); + + assert!(path.contains("test_job")); + assert_eq!(stats.num_rows, Some(3)); + assert_eq!(stats.num_batches, Some(1)); + + let read_batches = storage.read_shuffle_data(&path).await.unwrap(); + assert_eq!(read_batches.len(), 1); + assert_eq!(read_batches[0].num_rows(), 3); + } + + #[tokio::test] + async fn test_local_shuffle_storage_delete() { + let temp_dir = TempDir::new().unwrap(); + let storage = LocalShuffleStorage::new(temp_dir.path().to_str().unwrap()); + + let (batch, schema) = create_test_batch(); + let metrics = ExecutionPlanMetricsSet::new(); + let time_metric = metrics::MetricBuilder::new(&metrics).subset_time("write_time", 0); + + let (path, _) = storage + .write_shuffle_data( + "test_job", + 1, + 0, + 0, + vec![batch], + schema, + &time_metric, + ) + .await + .unwrap(); + + assert!(std::path::Path::new(&path).exists()); + + storage.delete_job_data("test_job").await.unwrap(); + assert!(!std::path::Path::new(&path).exists()); + } + + #[test] + fn test_shuffle_storage_type_parse() { + assert_eq!( + "local".parse::().unwrap(), + ShuffleStorageType::Local + ); + assert_eq!( + "s3".parse::().unwrap(), + ShuffleStorageType::S3 + ); + assert_eq!( + "azure".parse::().unwrap(), + ShuffleStorageType::Azure + ); + assert_eq!( + "abfs".parse::().unwrap(), + ShuffleStorageType::Azure + ); + } + + #[test] + fn test_storage_config_new_local() { + let config = ShuffleStorageConfig::new_local("/tmp/ballista"); + assert_eq!(config.storage_type, ShuffleStorageType::Local); + assert_eq!(config.base_url, Some("/tmp/ballista".to_string())); + } + + #[test] + fn test_storage_config_new_s3() { + let config = ShuffleStorageConfig::new_s3("my-bucket", Some("shuffle"), Some("us-east-1")); + assert_eq!(config.storage_type, ShuffleStorageType::S3); + assert_eq!(config.base_url, Some("s3://my-bucket/shuffle".to_string())); + assert_eq!(config.s3_config.bucket, Some("my-bucket".to_string())); + assert_eq!(config.s3_config.region, Some("us-east-1".to_string())); + } + + #[test] + fn test_storage_config_new_azure() { + let config = ShuffleStorageConfig::new_azure("myaccount", "mycontainer", Some("shuffle")); + assert_eq!(config.storage_type, ShuffleStorageType::Azure); + assert_eq!( + config.base_url, + Some("abfs://mycontainer@myaccount.dfs.core.windows.net/shuffle".to_string()) + ); + } +} From 1a98c29db4bdefcdc017ec9d3448cdbf8e74dddf Mon Sep 17 00:00:00 2001 From: Phillip LeBlanc Date: Thu, 22 Jan 2026 17:39:39 -0800 Subject: [PATCH 2/6] Add comprehensive metrics instrumentation for scheduler and executor (#10) * Add shuffle read metrics extraction and QueryStageExecutor::plan() method - Add public getter methods to PartitionStats (num_rows, num_batches, num_bytes) - Extend QueryStageExecutor trait with plan() method to access underlying ExecutionPlan - Add extract_shuffle_read_metrics() to walk plan tree and sum ShuffleReaderExec partition stats - Record shuffle read metrics (bytes, rows, duration) after successful task execution in executor * Add shuffle locality metrics to ExecutorMetricsCollector, SchedulerMetricsCollector, and ShuffleReaderExec - Add record_shuffle_read_local/remote methods to ExecutorMetricsCollector trait - Add record_task_shuffle_affinity_hit/miss methods to SchedulerMetricsCollector trait - Add ShuffleReadMetricsCallback trait in ballista-core for tracking local vs remote reads - Instrument shuffle_reader.rs to call metrics callback during partition fetches - Add SessionConfigExt methods to pass metrics callback via session config * Add metrics collector to SchedulerState and instrument executor and planning metrics - Add metrics_collector field to SchedulerState struct - Instrument record_planning_duration in submit_job - Instrument record_executor_registered/deregistered and set_active_executor_count - Update all SchedulerState constructors and call sites * Add stage and task lifecycle metrics instrumentation to update_task_status flow * Add shuffle affinity metrics to scheduler task binding * Add actual task scheduling latency tracking - Add schedulable_time_millis field to TaskDescription to track when a task became schedulable (when its stage transitioned to running state) - Update all TaskDescription creation sites to pass RunningStage.stage_running_time - Calculate actual scheduling latency in record_task_scheduled calls by computing the difference between current time and schedulable_time_millis - This enables accurate scheduler_task_scheduling_latency_ms metrics instead of the previous placeholder value of 0 * fix lint --- ballista/core/src/client.rs | 2 - .../src/execution_plans/distributed_query.rs | 62 +++-- .../src/execution_plans/shuffle_reader.rs | 116 ++++++--- ballista/core/src/extension.rs | 191 ++++++++++++++ ballista/core/src/lib.rs | 4 +- .../remote_catalog/catalog_serialize_ext.rs | 7 +- ballista/core/src/remote_catalog/mod.rs | 10 + .../remote_function_serialize_ext.rs | 3 +- .../src/remote_catalog/remote_scalar_udf.rs | 3 +- .../remote_catalog/remote_table_provider.rs | 14 +- ballista/core/src/serde/scheduler/mod.rs | 15 ++ ballista/core/src/utils.rs | 4 +- ballista/executor/Cargo.toml | 4 +- ballista/executor/src/execution_engine.rs | 10 + ballista/executor/src/execution_loop.rs | 13 +- ballista/executor/src/executor.rs | 234 +++++++++++++++++- ballista/executor/src/executor_process.rs | 7 +- ballista/executor/src/executor_server.rs | 3 +- ballista/executor/src/metrics/mod.rs | 179 +++++++++++++- ballista/scheduler/src/cluster/memory.rs | 47 ++-- ballista/scheduler/src/cluster/mod.rs | 183 +++++++++++--- ballista/scheduler/src/config.rs | 15 ++ ballista/scheduler/src/metrics/mod.rs | 191 +++++++++++++- ballista/scheduler/src/metrics/prometheus.rs | 40 +++ ballista/scheduler/src/scheduler_process.rs | 5 +- .../scheduler/src/scheduler_server/grpc.rs | 61 ++++- .../scheduler/src/scheduler_server/mod.rs | 25 +- .../scheduler_server/query_stage_scheduler.rs | 10 + .../scheduler/src/state/execution_graph.rs | 231 +++++++++++++++-- .../scheduler/src/state/executor_manager.rs | 51 +--- ballista/scheduler/src/state/mod.rs | 110 +++++++- ballista/scheduler/src/state/task_manager.rs | 66 ++++- ballista/scheduler/src/test_utils.rs | 51 ++++ 33 files changed, 1729 insertions(+), 238 deletions(-) diff --git a/ballista/core/src/client.rs b/ballista/core/src/client.rs index 02e82a95ac..0a7a7af368 100644 --- a/ballista/core/src/client.rs +++ b/ballista/core/src/client.rs @@ -47,8 +47,6 @@ use datafusion::error::Result; use crate::extension::BallistaConfigGrpcEndpoint; use crate::serde::protobuf; -use crate::utils::{GrpcClientConfig, create_grpc_client_connection}; - use crate::utils::create_grpc_client_endpoint; use datafusion::physical_plan::{RecordBatchStream, SendableRecordBatchStream}; diff --git a/ballista/core/src/execution_plans/distributed_query.rs b/ballista/core/src/execution_plans/distributed_query.rs index f083d8ed0d..d632e1063a 100644 --- a/ballista/core/src/execution_plans/distributed_query.rs +++ b/ballista/core/src/execution_plans/distributed_query.rs @@ -18,7 +18,8 @@ use crate::client::BallistaClient; use crate::config::BallistaConfig; use crate::extension::{ - BallistaConfigGrpcEndpoint, BallistaGrpcMetadataInterceptor, SessionConfigExt, + BallistaConfigGrpcEndpoint, BallistaGrpcMetadataInterceptor, + ResultFetchMetricsCallback, SessionConfigExt, }; use crate::serde::protobuf::SuccessfulJob; use crate::serde::protobuf::{ @@ -27,8 +28,6 @@ use crate::serde::protobuf::{ scheduler_grpc_client::SchedulerGrpcClient, }; -use crate::utils::{GrpcClientConfig, create_grpc_client_connection}; - use crate::utils::create_grpc_client_endpoint; use datafusion::arrow::datatypes::SchemaRef; @@ -248,8 +247,6 @@ impl ExecutionPlan for DistributedQueryExec { let metric_total_bytes = MetricBuilder::new(&self.metrics).counter("transferred_bytes", partition); - - let interceptor = context.session_config().ballista_grpc_interceptor(); let customize_endpoint = context @@ -258,23 +255,22 @@ impl ExecutionPlan for DistributedQueryExec { let use_tls = context.session_config().ballista_use_tls(); + let result_fetch_callback = context + .session_config() + .ballista_result_fetch_metrics_callback(); let stream = futures::stream::once( execute_query( self.scheduler_url.clone(), self.session_id.clone(), query, - - self.config.default_grpc_client_max_message_size(), - GrpcClientConfig::from(&self.config), Arc::new(self.metrics.clone()), partition, - self.config.clone(), interceptor, customize_endpoint, use_tls, - + result_fetch_callback, ) .map_err(|e| ArrowError::ExternalError(Box::new(e))), ) @@ -306,21 +302,18 @@ impl ExecutionPlan for DistributedQueryExec { } } +#[allow(clippy::too_many_arguments)] async fn execute_query( scheduler_url: String, session_id: String, query: ExecuteQueryParams, - - max_message_size: usize, - grpc_config: GrpcClientConfig, metrics: Arc, partition: usize, - config: BallistaConfig, grpc_interceptor: Arc, customize_endpoint: Option>, use_tls: bool, - + result_fetch_callback: Option>, ) -> Result> + Send> { // Capture query submission time for total_query_time_ms let query_start_time = std::time::Instant::now(); @@ -450,12 +443,14 @@ async fn execute_query( // This could be added in a future enhancement by wrapping the stream. let streams = partition_location.into_iter().map(move |partition| { + let callback = result_fetch_callback.clone(); let f = fetch_partition( partition, max_message_size, true, customize_endpoint.clone(), use_tls, + callback, ) .map_err(|e| ArrowError::ExternalError(Box::new(e))); @@ -474,13 +469,29 @@ async fn fetch_partition( flight_transport: bool, customize_endpoint: Option>, use_tls: bool, + metrics_callback: Option>, ) -> Result { + let start_time = std::time::Instant::now(); + let metadata = location.executor_meta.ok_or_else(|| { DataFusionError::Internal("Received empty executor metadata".to_owned()) })?; let partition_id = location.partition_id.ok_or_else(|| { DataFusionError::Internal("Received empty partition id".to_owned()) })?; + + // Extract stats before consuming location + let stats = location.partition_stats.as_ref(); + #[expect(clippy::cast_sign_loss)] + let expected_bytes = stats.map(|s| s.num_bytes as u64).unwrap_or(0); + #[expect(clippy::cast_sign_loss)] + let expected_rows = stats.map(|s| s.num_rows as u64).unwrap_or(0); + + let job_id = partition_id.job_id.clone(); + let stage_id = partition_id.stage_id as usize; + let partition = partition_id.partition_id as usize; + let executor_id = metadata.id.clone(); + let host = metadata.host.as_str(); let port = metadata.port as u16; let mut ballista_client = BallistaClient::try_new( @@ -492,7 +503,8 @@ async fn fetch_partition( ) .await .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?; - ballista_client + + let stream = ballista_client .fetch_partition( &metadata.id, &partition_id.into(), @@ -502,5 +514,21 @@ async fn fetch_partition( flight_transport, ) .await - .map_err(|e| DataFusionError::External(Box::new(e))) + .map_err(|e| DataFusionError::External(Box::new(e)))?; + + // Record metrics after successful fetch + if let Some(callback) = metrics_callback { + let duration_ms = start_time.elapsed().as_millis() as u64; + callback.record_result_fetch( + &job_id, + stage_id, + partition, + &executor_id, + expected_bytes, + expected_rows, + duration_ms, + ); + } + + Ok(stream) } diff --git a/ballista/core/src/execution_plans/shuffle_reader.rs b/ballista/core/src/execution_plans/shuffle_reader.rs index 8d0ebbc4d4..532d4f7dbd 100644 --- a/ballista/core/src/execution_plans/shuffle_reader.rs +++ b/ballista/core/src/execution_plans/shuffle_reader.rs @@ -28,17 +28,19 @@ use std::result; use std::sync::Arc; use std::task::{Context, Poll}; +#[cfg(feature = "build-binary")] +use object_store::ObjectStore; #[cfg(feature = "build-binary")] use object_store::aws::AmazonS3Builder; #[cfg(feature = "build-binary")] use object_store::azure::MicrosoftAzureBuilder; #[cfg(feature = "build-binary")] -use object_store::ObjectStore; -#[cfg(feature = "build-binary")] use url::Url; use crate::client::BallistaClient; -use crate::extension::{BallistaConfigGrpcEndpoint, SessionConfigExt}; +use crate::extension::{ + BallistaConfigGrpcEndpoint, SessionConfigExt, ShuffleReadMetricsCallback, +}; use crate::serde::scheduler::{PartitionLocation, PartitionStats}; use datafusion::arrow::datatypes::SchemaRef; @@ -173,6 +175,7 @@ impl ExecutionPlan for ShuffleReaderExec { let prefer_flight = config.ballista_shuffle_reader_remote_prefer_flight(); let customize_endpoint = config.ballista_override_create_grpc_client_endpoint(); let use_tls = config.ballista_use_tls(); + let metrics_callback = config.ballista_shuffle_read_metrics_callback(); if force_remote_read { debug!( @@ -208,6 +211,7 @@ impl ExecutionPlan for ShuffleReaderExec { prefer_flight, customize_endpoint, use_tls, + metrics_callback, ); let result = RecordBatchStreamAdapter::new( @@ -388,7 +392,11 @@ impl Stream for AbortableReceiverStream { fn local_remote_read_split( partition_locations: Vec, force_remote_read: bool, -) -> (Vec, Vec, Vec) { +) -> ( + Vec, + Vec, + Vec, +) { if !force_remote_read { let (local, non_local): (Vec<_>, Vec<_>) = partition_locations .into_iter() @@ -408,6 +416,7 @@ fn check_is_object_store_location(location: &PartitionLocation) -> bool { path.starts_with("s3://") || path.starts_with("abfs://") || path.starts_with("az://") } +#[allow(clippy::too_many_arguments)] fn send_fetch_partitions( partition_locations: Vec, max_request_num: usize, @@ -416,13 +425,17 @@ fn send_fetch_partitions( flight_transport: bool, customize_endpoint: Option>, use_tls: bool, + metrics_callback: Option>, ) -> AbortableReceiverStream { let (response_sender, response_receiver) = mpsc::channel(max_request_num); let semaphore = Arc::new(Semaphore::new(max_request_num)); let mut spawned_tasks: Vec> = vec![]; - let (local_locations, object_store_locations, remote_locations): (Vec<_>, Vec<_>, Vec<_>) = - local_remote_read_split(partition_locations, force_remote_read); + let (local_locations, object_store_locations, remote_locations): ( + Vec<_>, + Vec<_>, + Vec<_>, + ) = local_remote_read_split(partition_locations, force_remote_read); debug!( "local shuffle file counts:{}, object store shuffle file count:{}, remote shuffle file count:{}.", @@ -434,8 +447,10 @@ fn send_fetch_partitions( // keep local shuffle files reading in serial order for memory control. let response_sender_c = response_sender.clone(); let customize_endpoint_c = customize_endpoint.clone(); + let metrics_callback_c = metrics_callback.clone(); spawned_tasks.push(SpawnedTask::spawn(async move { for p in local_locations { + let start_time = std::time::Instant::now(); let r = PartitionReaderEnum::Local .fetch_partition( &p, @@ -445,6 +460,25 @@ fn send_fetch_partitions( use_tls, ) .await; + + // Record local read metrics if callback is set and read succeeded + if r.is_ok() + && let Some(ref callback) = metrics_callback_c + { + let duration_ms = start_time.elapsed().as_millis() as u64; + let bytes = p.partition_stats.num_bytes().unwrap_or(0); + let rows = p.partition_stats.num_rows().unwrap_or(0); + callback.record_local_read( + &p.partition_id.job_id, + p.partition_id.stage_id, + p.partition_id.partition_id, + &p.executor_meta.id, + bytes, + rows, + duration_ms, + ); + } + if let Err(e) = response_sender_c.send(r).await { error!("Fail to send response event to the channel due to {e}"); } @@ -480,9 +514,11 @@ fn send_fetch_partitions( let semaphore = semaphore.clone(); let response_sender = response_sender.clone(); let customize_endpoint_c = customize_endpoint.clone(); + let metrics_callback_c = metrics_callback.clone(); spawned_tasks.push(SpawnedTask::spawn(async move { // Block if exceeds max request number. let permit = semaphore.acquire_owned().await.unwrap(); + let start_time = std::time::Instant::now(); let r = PartitionReaderEnum::FlightRemote .fetch_partition( &p, @@ -492,6 +528,25 @@ fn send_fetch_partitions( use_tls, ) .await; + + // Record remote read metrics if callback is set and read succeeded + if r.is_ok() + && let Some(ref callback) = metrics_callback_c + { + let duration_ms = start_time.elapsed().as_millis() as u64; + let bytes = p.partition_stats.num_bytes().unwrap_or(0); + let rows = p.partition_stats.num_rows().unwrap_or(0); + callback.record_remote_read( + &p.partition_id.job_id, + p.partition_id.stage_id, + p.partition_id.partition_id, + &p.executor_meta.id, + bytes, + rows, + duration_ms, + ); + } + // Block if the channel buffer is full. if let Err(e) = response_sender.send(r).await { error!("Fail to send response event to the channel due to {e}"); @@ -646,15 +701,17 @@ async fn fetch_partition_object_store( debug!("Fetching shuffle partition from object store: {}", path); - let batches = fetch_partition_object_store_inner(path).await.map_err(|e| { - // return BallistaError::FetchFailed may let scheduler retry this task. - BallistaError::FetchFailed( - metadata.id.clone(), - partition_id.stage_id, - partition_id.partition_id, - e.to_string(), - ) - })?; + let batches = fetch_partition_object_store_inner(path) + .await + .map_err(|e| { + // return BallistaError::FetchFailed may let scheduler retry this task. + BallistaError::FetchFailed( + metadata.id.clone(), + partition_id.stage_id, + partition_id.partition_id, + e.to_string(), + ) + })?; if batches.is_empty() { return Err(BallistaError::General(format!( @@ -684,7 +741,10 @@ async fn fetch_partition_object_store_inner( use object_store::path::Path as ObjectPath; let url = Url::parse(path).map_err(|e| { - BallistaError::General(format!("Failed to parse object store URL '{}': {:?}", path, e)) + BallistaError::General(format!( + "Failed to parse object store URL '{}': {:?}", + path, e + )) })?; let scheme = url.scheme(); @@ -703,7 +763,7 @@ async fn fetch_partition_object_store_inner( let host = url.host_str().ok_or_else(|| { BallistaError::General(format!("No host in Azure URL: {}", path)) })?; - + // Extract container from username portion let container = url.username(); if container.is_empty() { @@ -739,17 +799,11 @@ async fn fetch_partition_object_store_inner( debug!("Reading object from path: {:?}", object_path); let get_result = store.get(&object_path).await.map_err(|e| { - BallistaError::General(format!( - "Failed to read object from {}: {:?}", - path, e - )) + BallistaError::General(format!("Failed to read object from {}: {:?}", path, e)) })?; let bytes = get_result.bytes().await.map_err(|e| { - BallistaError::General(format!( - "Failed to read bytes from {}: {:?}", - path, e - )) + BallistaError::General(format!("Failed to read bytes from {}: {:?}", path, e)) })?; let cursor = Cursor::new(bytes.to_vec()); @@ -763,10 +817,7 @@ async fn fetch_partition_object_store_inner( let mut batches = Vec::new(); for batch_result in stream_reader { batches.push(batch_result.map_err(|e| { - BallistaError::General(format!( - "Failed to read batch from {}: {:?}", - path, e - )) + BallistaError::General(format!("Failed to read batch from {}: {:?}", path, e)) })?); } @@ -1130,13 +1181,15 @@ mod tests { let partition_locations = get_test_partition_locations(1, file_path.to_str().unwrap().to_string()); - let (local, object_store, remote) = local_remote_read_split(partition_locations.clone(), false); + let (local, object_store, remote) = + local_remote_read_split(partition_locations.clone(), false); assert!(!local.is_empty()); assert!(object_store.is_empty()); assert!(remote.is_empty()); - let (local, object_store, remote) = local_remote_read_split(partition_locations, true); + let (local, object_store, remote) = + local_remote_read_split(partition_locations, true); assert!(local.is_empty()); assert!(object_store.is_empty()); @@ -1169,6 +1222,7 @@ mod tests { true, None, false, + None, // No metrics callback in tests ); let stream = RecordBatchStreamAdapter::new( diff --git a/ballista/core/src/extension.rs b/ballista/core/src/extension.rs index afe705dc00..28a06185a1 100644 --- a/ballista/core/src/extension.rs +++ b/ballista/core/src/extension.rs @@ -161,11 +161,13 @@ pub trait SessionConfigExt { /// Get a `tonic` interceptor configured to decorate the provided metadata keys fn ballista_grpc_interceptor(&self) -> Arc; + /// Set an override function for creating gRPC client endpoints. fn with_ballista_override_create_grpc_client_endpoint( self, override_f: EndpointOverrideFn, ) -> Self; + /// Get the override function for creating gRPC client endpoints. fn ballista_override_create_grpc_client_endpoint( &self, ) -> Option>; @@ -187,6 +189,34 @@ pub trait SessionConfigExt { /// Sets the shuffle storage base URL/path. fn with_ballista_shuffle_storage_url(self, url: &str) -> Self; + + /// Set a callback for recording shuffle read metrics (local vs remote). + /// + /// This callback will be invoked by the shuffle reader during execution + /// to record detailed metrics about local and remote shuffle reads. + fn with_ballista_shuffle_read_metrics_callback( + self, + callback: Arc, + ) -> Self; + + /// Get the shuffle read metrics callback if one has been set. + fn ballista_shuffle_read_metrics_callback( + &self, + ) -> Option>; + + /// Set a callback for recording result fetch metrics. + /// + /// This callback will be invoked by `DistributedQueryExec` when fetching + /// final query results from executors. + fn with_ballista_result_fetch_metrics_callback( + self, + callback: Arc, + ) -> Self; + + /// Get the result fetch metrics callback if one has been set. + fn ballista_result_fetch_metrics_callback( + &self, + ) -> Option>; } /// [SessionConfigHelperExt] is set of [SessionConfig] extension methods @@ -504,6 +534,36 @@ impl SessionConfigExt for SessionConfig { .set_str(BALLISTA_SHUFFLE_STORAGE_URL, url) } } + + fn with_ballista_shuffle_read_metrics_callback( + self, + callback: Arc, + ) -> Self { + let extension = ShuffleReadMetricsCallbackExtension::new(callback); + self.with_extension(Arc::new(extension)) + } + + fn ballista_shuffle_read_metrics_callback( + &self, + ) -> Option> { + self.get_extension::() + .map(|ext| ext.callback()) + } + + fn with_ballista_result_fetch_metrics_callback( + self, + callback: Arc, + ) -> Self { + let extension = ResultFetchMetricsCallbackExtension::new(callback); + self.with_extension(Arc::new(extension)) + } + + fn ballista_result_fetch_metrics_callback( + &self, + ) -> Option> { + self.get_extension::() + .map(|ext| ext.callback()) + } } impl SessionConfigHelperExt for SessionConfig { @@ -651,6 +711,7 @@ pub struct BallistaGrpcMetadataInterceptor { } impl BallistaGrpcMetadataInterceptor { + /// Create a new interceptor with additional metadata to be added to requests. pub fn new(additional_metadata: HashMap) -> Self { Self { additional_metadata, @@ -681,16 +742,19 @@ impl Interceptor for BallistaGrpcMetadataInterceptor { } } +/// Wrapper for gRPC endpoint configuration override function. #[derive(Clone)] pub struct BallistaConfigGrpcEndpoint { override_f: EndpointOverrideFn, } impl BallistaConfigGrpcEndpoint { + /// Create a new endpoint configuration with the given override function. pub fn new(override_f: EndpointOverrideFn) -> Self { Self { override_f } } + /// Configure an endpoint using the override function. pub fn configure_endpoint( &self, endpoint: Endpoint, @@ -703,6 +767,133 @@ impl BallistaConfigGrpcEndpoint { #[derive(Clone, Copy)] pub struct BallistaUseTls(pub bool); +/// Callback trait for recording shuffle read metrics from the shuffle reader. +/// +/// This trait is designed to be passed via session config extension to the +/// shuffle reader, allowing external systems (like Spice) to capture detailed +/// shuffle read locality metrics without creating circular dependencies. +pub trait ShuffleReadMetricsCallback: Send + Sync { + /// Record a local shuffle read operation. + /// + /// Called when the shuffle reader successfully reads data from a local file + /// (i.e., the partition was produced by this executor). + /// + /// # Arguments + /// * `job_id` - The job identifier + /// * `stage_id` - The stage that is reading the shuffle data + /// * `partition` - The partition being read + /// * `source_executor_id` - The executor that produced the shuffle data (same as current executor for local reads) + /// * `bytes` - Number of bytes read + /// * `rows` - Number of rows read + /// * `duration_ms` - Time taken to read the partition + #[allow(clippy::too_many_arguments)] + fn record_local_read( + &self, + job_id: &str, + stage_id: usize, + partition: usize, + source_executor_id: &str, + bytes: u64, + rows: u64, + duration_ms: u64, + ); + + /// Record a remote shuffle read operation. + /// + /// Called when the shuffle reader fetches data from a remote executor + /// via Arrow Flight. + /// + /// # Arguments + /// * `job_id` - The job identifier + /// * `stage_id` - The stage that is reading the shuffle data + /// * `partition` - The partition being read + /// * `source_executor_id` - The executor that produced the shuffle data + /// * `bytes` - Number of bytes read + /// * `rows` - Number of rows read + /// * `duration_ms` - Time taken to fetch the partition + #[allow(clippy::too_many_arguments)] + fn record_remote_read( + &self, + job_id: &str, + stage_id: usize, + partition: usize, + source_executor_id: &str, + bytes: u64, + rows: u64, + duration_ms: u64, + ); +} + +/// Session config extension wrapper for the shuffle read metrics callback. +#[derive(Clone)] +pub struct ShuffleReadMetricsCallbackExtension { + callback: Arc, +} + +impl ShuffleReadMetricsCallbackExtension { + /// Create a new extension wrapping the provided callback. + pub fn new(callback: Arc) -> Self { + Self { callback } + } + + /// Get the callback. + pub fn callback(&self) -> Arc { + Arc::clone(&self.callback) + } +} + +/// Callback trait for recording result fetch metrics from distributed query execution. +/// +/// This trait is designed to be passed via session config extension to the +/// `DistributedQueryExec`, allowing external systems (like Spice) to capture detailed +/// metrics about fetching final query results from executors. +/// +/// Note: Result fetching is always "remote" from the client's perspective since +/// the client (scheduler in Spice's case) always fetches from executors over the network. +pub trait ResultFetchMetricsCallback: Send + Sync { + /// Record a result fetch operation from an executor. + /// + /// Called when the client successfully fetches final query result data from an executor. + /// + /// # Arguments + /// * `job_id` - The job identifier + /// * `stage_id` - The final stage that produced the results + /// * `partition` - The partition being fetched + /// * `source_executor_id` - The executor that produced the result data + /// * `bytes` - Number of bytes fetched + /// * `rows` - Number of rows fetched + /// * `duration_ms` - Time taken to fetch the partition + #[allow(clippy::too_many_arguments)] + fn record_result_fetch( + &self, + job_id: &str, + stage_id: usize, + partition: usize, + source_executor_id: &str, + bytes: u64, + rows: u64, + duration_ms: u64, + ); +} + +/// Session config extension wrapper for the result fetch metrics callback. +#[derive(Clone)] +pub struct ResultFetchMetricsCallbackExtension { + callback: Arc, +} + +impl ResultFetchMetricsCallbackExtension { + /// Create a new extension wrapping the provided callback. + pub fn new(callback: Arc) -> Self { + Self { callback } + } + + /// Get the callback. + pub fn callback(&self) -> Arc { + Arc::clone(&self.callback) + } +} + #[cfg(test)] mod test { use datafusion::{ diff --git a/ballista/core/src/lib.rs b/ballista/core/src/lib.rs index bfdc7dc4f9..e0107bc6d0 100644 --- a/ballista/core/src/lib.rs +++ b/ballista/core/src/lib.rs @@ -58,10 +58,10 @@ pub mod serde; #[cfg(feature = "build-binary")] pub mod shuffle_storage; -/// General utility functions for Ballista operations. - +/// Remote catalog serialization and stub providers for Ballista clients. pub mod remote_catalog; +/// General utility functions for Ballista operations. pub mod utils; /// diff --git a/ballista/core/src/remote_catalog/catalog_serialize_ext.rs b/ballista/core/src/remote_catalog/catalog_serialize_ext.rs index af5d3c8603..331155ec23 100644 --- a/ballista/core/src/remote_catalog/catalog_serialize_ext.rs +++ b/ballista/core/src/remote_catalog/catalog_serialize_ext.rs @@ -18,17 +18,20 @@ use crate::serde::generated::ballista::{CatalogInfo, SchemaInfo, TableInfo}; use datafusion::catalog::CatalogProvider; use datafusion::prelude::SessionContext; -use futures::stream; use futures::StreamExt; +use futures::stream; use std::sync::Arc; -/// Used to serialize catalog schemas and names to ship to Ballista clients +/// Extension trait for serializing catalog schemas and names to ship to Ballista clients. #[async_trait::async_trait] pub trait CatalogSerializeExt { + /// Serialize all catalogs in the session context. async fn serialize_catalogs(&self) -> Vec; + /// Serialize a specific catalog by name. async fn serialize_catalog(&self, name: &str) -> Option; + /// Serialize a specific schema within a catalog. async fn serialize_schema( &self, name: &str, diff --git a/ballista/core/src/remote_catalog/mod.rs b/ballista/core/src/remote_catalog/mod.rs index 8b8728791e..c5e1a76385 100644 --- a/ballista/core/src/remote_catalog/mod.rs +++ b/ballista/core/src/remote_catalog/mod.rs @@ -16,7 +16,17 @@ // under the License. // +//! Remote catalog serialization and stub providers for Ballista clients. +//! +//! This module provides functionality to serialize catalog metadata (schemas, tables, functions) +//! from the scheduler to ship to Ballista clients, as well as stub providers that allow clients +//! to perform logical planning without access to actual table data. + +/// Extension trait for serializing catalog schemas and table names. pub mod catalog_serialize_ext; +/// Extension trait for serializing user-defined functions. pub mod remote_function_serialize_ext; +/// Stub scalar UDF implementation for remote function planning. pub mod remote_scalar_udf; +/// Stub table provider for remote table planning. pub mod remote_table_provider; diff --git a/ballista/core/src/remote_catalog/remote_function_serialize_ext.rs b/ballista/core/src/remote_catalog/remote_function_serialize_ext.rs index 8af5002ade..d7a0a28eb5 100644 --- a/ballista/core/src/remote_catalog/remote_function_serialize_ext.rs +++ b/ballista/core/src/remote_catalog/remote_function_serialize_ext.rs @@ -25,8 +25,9 @@ use datafusion::prelude::SessionContext; use datafusion_proto_common::ArrowType; use std::collections::HashSet; -/// Used to serialize function shapes to ship to Ballista clients +/// Extension trait for serializing function signatures to ship to Ballista clients. pub trait RemoteFunctionSerializeExt { + /// Serialize all user-defined scalar functions in the session context. fn serialize_udfs(&self) -> Vec; } diff --git a/ballista/core/src/remote_catalog/remote_scalar_udf.rs b/ballista/core/src/remote_catalog/remote_scalar_udf.rs index d93a97f409..bfe99ca09c 100644 --- a/ballista/core/src/remote_catalog/remote_scalar_udf.rs +++ b/ballista/core/src/remote_catalog/remote_scalar_udf.rs @@ -19,7 +19,7 @@ use crate::serde::protobuf::ScalarUdfInfo; use datafusion::arrow::datatypes::DataType; use datafusion::common::Result; -use datafusion::common::{exec_err, plan_err, DataFusionError}; +use datafusion::common::{DataFusionError, exec_err, plan_err}; use datafusion::logical_expr::{ ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility, @@ -56,6 +56,7 @@ impl PartialEq for RemoteScalarUDF { impl Eq for RemoteScalarUDF {} impl RemoteScalarUDF { + /// Create a new RemoteScalarUDF from a ScalarUdfInfo protobuf message. pub fn new(meta: ScalarUdfInfo) -> Result { let mut arities = vec![]; diff --git a/ballista/core/src/remote_catalog/remote_table_provider.rs b/ballista/core/src/remote_catalog/remote_table_provider.rs index 8f5c8ca1be..b78aacb8fb 100644 --- a/ballista/core/src/remote_catalog/remote_table_provider.rs +++ b/ballista/core/src/remote_catalog/remote_table_provider.rs @@ -18,7 +18,7 @@ use datafusion::arrow::datatypes::SchemaRef; use datafusion::catalog::{Session, TableProvider}; -use datafusion::common::{exec_err, Result}; +use datafusion::common::{Result, exec_err}; use datafusion::datasource::TableType; use datafusion::logical_expr::Expr; use datafusion::physical_plan::ExecutionPlan; @@ -36,6 +36,7 @@ pub struct RemoteTableProvider { } impl RemoteTableProvider { + /// Create a new RemoteTableProvider with the given catalog, schema, and table names. pub fn new( catalog_name: &str, schema_name: &str, @@ -50,14 +51,17 @@ impl RemoteTableProvider { } } + /// Get the catalog name. pub fn catalog_name(&self) -> &str { &self.catalog_name } + /// Get the schema name. pub fn schema_name(&self) -> &str { &self.schema_name } + /// Get the table name. pub fn table_name(&self) -> &str { &self.table_name } @@ -84,7 +88,11 @@ impl TableProvider for RemoteTableProvider { _filters: &[Expr], _limit: Option, ) -> Result> { - exec_err!("{}.{}.{} is a stub table implementation to be resolved on the Ballista scheduler. It should not be scanned on the client. This is a bug.", - self.catalog_name, self.schema_name, self.table_name) + exec_err!( + "{}.{}.{} is a stub table implementation to be resolved on the Ballista scheduler. It should not be scanned on the client. This is a bug.", + self.catalog_name, + self.schema_name, + self.table_name + ) } } diff --git a/ballista/core/src/serde/scheduler/mod.rs b/ballista/core/src/serde/scheduler/mod.rs index 1dd0855162..407d62bac2 100644 --- a/ballista/core/src/serde/scheduler/mod.rs +++ b/ballista/core/src/serde/scheduler/mod.rs @@ -163,6 +163,21 @@ impl PartitionStats { } } + /// Returns the number of rows in the partition, if known. + pub fn num_rows(&self) -> Option { + self.num_rows + } + + /// Returns the number of batches in the partition, if known. + pub fn num_batches(&self) -> Option { + self.num_batches + } + + /// Returns the number of bytes in the partition, if known. + pub fn num_bytes(&self) -> Option { + self.num_bytes + } + /// Returns the Arrow struct field representation of these statistics. pub fn arrow_struct_repr(self) -> Field { Field::new( diff --git a/ballista/core/src/utils.rs b/ballista/core/src/utils.rs index 3b6674e89c..861dd72e7a 100644 --- a/ballista/core/src/utils.rs +++ b/ballista/core/src/utils.rs @@ -223,7 +223,9 @@ where /// Creates a gRPC client endpoint (returns Endpoint without connecting). /// Used for TLS/API key customization before establishing connection. -pub fn create_grpc_client_endpoint(dst: D) -> std::result::Result +pub fn create_grpc_client_endpoint( + dst: D, +) -> std::result::Result where D: std::convert::TryInto, D::Error: Into, diff --git a/ballista/executor/Cargo.toml b/ballista/executor/Cargo.toml index fad29d3e65..fb692cbf51 100644 --- a/ballista/executor/Cargo.toml +++ b/ballista/executor/Cargo.toml @@ -41,10 +41,10 @@ arrow = { workspace = true } arrow-flight = { workspace = true } async-trait = { workspace = true } -ballista-core = { path = "../core", version = "51.0.0" } - backoff = { workspace = true } +ballista-core = { path = "../core", version = "51.0.0" } + clap = { workspace = true, optional = true } dashmap = { workspace = true } datafusion = { workspace = true } diff --git a/ballista/executor/src/execution_engine.rs b/ballista/executor/src/execution_engine.rs index 45732828d6..50adb1865c 100644 --- a/ballista/executor/src/execution_engine.rs +++ b/ballista/executor/src/execution_engine.rs @@ -72,6 +72,12 @@ pub trait QueryStageExecutor: Sync + Send + Debug + Display { /// Collects execution metrics from all operators in the plan. fn collect_plan_metrics(&self) -> Vec; + + /// Returns a reference to the underlying execution plan. + /// + /// This is used to walk the plan tree and extract metrics from specific + /// operators like ShuffleReaderExec. + fn plan(&self) -> &dyn ExecutionPlan; } /// Default execution engine using DataFusion's ShuffleWriterExec. @@ -162,4 +168,8 @@ impl QueryStageExecutor for DefaultQueryStageExec { fn collect_plan_metrics(&self) -> Vec { utils::collect_plan_metrics(&self.shuffle_writer) } + + fn plan(&self) -> &dyn ExecutionPlan { + &self.shuffle_writer + } } diff --git a/ballista/executor/src/execution_loop.rs b/ballista/executor/src/execution_loop.rs index 35e05002d5..3b16595977 100644 --- a/ballista/executor/src/execution_loop.rs +++ b/ballista/executor/src/execution_loop.rs @@ -27,8 +27,8 @@ use crate::executor_process::remove_job_dir; use crate::{TaskExecutionTimes, as_task_status}; -use backoff::backoff::Backoff; use backoff::ExponentialBackoff; +use backoff::backoff::Backoff; use ballista_core::error::BallistaError; use ballista_core::extension::SessionConfigHelperExt; @@ -48,15 +48,13 @@ use std::any::Any; use std::cell::LazyCell; use std::convert::TryInto; use std::error::Error; -use std::sync::mpsc::{Receiver, Sender, TryRecvError}; use std::sync::Arc; +use std::sync::mpsc::{Receiver, Sender, TryRecvError}; use std::time::{Duration, SystemTime, UNIX_EPOCH}; use tokio::sync::oneshot::Sender as OneShotSender; use tokio::sync::{OwnedSemaphorePermit, Semaphore}; use tonic::codegen::{Body, Bytes, StdError}; - - /// Main execution loop that polls the scheduler for available tasks. /// /// This function runs indefinitely, periodically asking the scheduler for @@ -68,7 +66,10 @@ use tonic::codegen::{Body, Bytes, StdError}; /// Number of consecutive failures before reducing log level from WARN to DEBUG. const QUIET_AFTER_FAILURES: u32 = 5; - +/// Main polling loop for executor task execution. +/// +/// This function polls the scheduler for new tasks to execute and runs them, +/// ensuring no more than the configured number of tasks run simultaneously. pub async fn poll_loop( mut scheduler: SchedulerGrpcClient, executor: Arc, @@ -237,7 +238,6 @@ where } } Err(error) => { - warn!( "Executor poll work loop failed. If this continues to happen the Scheduler might be marked as dead. Error: {error}" ); @@ -260,7 +260,6 @@ where tokio::time::sleep(duration).await; } continue; - } } diff --git a/ballista/executor/src/executor.rs b/ballista/executor/src/executor.rs index d376136913..43fc128243 100644 --- a/ballista/executor/src/executor.rs +++ b/ballista/executor/src/executor.rs @@ -25,6 +25,7 @@ use crate::metrics::LoggingMetricsCollector; use ballista_core::ConfigProducer; use ballista_core::RuntimeProducer; use ballista_core::error::BallistaError; +use ballista_core::execution_plans::ShuffleReaderExec; use ballista_core::registry::BallistaFunctionRegistry; use ballista_core::serde::protobuf; use ballista_core::serde::protobuf::ExecutorRegistration; @@ -32,6 +33,7 @@ use ballista_core::serde::scheduler::PartitionId; use dashmap::DashMap; use datafusion::execution::context::TaskContext; use datafusion::execution::runtime_env::RuntimeEnv; +use datafusion::physical_plan::ExecutionPlan; use datafusion::prelude::SessionConfig; use futures::future::AbortHandle; use std::future::Future; @@ -39,6 +41,160 @@ use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; +/// Categorize a BallistaError into a short string suitable for use as a metric label. +/// +/// This function is provided for use by custom `ExecutorMetricsCollector` implementations +/// that may need to categorize errors differently than the default. +#[allow(dead_code)] +pub fn categorize_ballista_error(error: &BallistaError) -> String { + match error { + BallistaError::NotImplemented(_) => "not_implemented".to_string(), + BallistaError::General(_) => "general".to_string(), + BallistaError::Internal(_) => "internal".to_string(), + BallistaError::Configuration(_) => "configuration".to_string(), + BallistaError::ArrowError(_) => "arrow".to_string(), + BallistaError::DataFusionError(_) => "datafusion".to_string(), + BallistaError::SqlError(_) => "sql".to_string(), + BallistaError::IoError(_) => "io".to_string(), + BallistaError::TonicError(_) => "tonic".to_string(), + BallistaError::GrpcError(_) => "grpc".to_string(), + BallistaError::GrpcConnectionError(_) => "grpc_connection".to_string(), + BallistaError::TokioError(_) => "tokio".to_string(), + BallistaError::GrpcActionError(_) => "grpc_action".to_string(), + BallistaError::FetchFailed(_, _, _, _) => "fetch_failed".to_string(), + BallistaError::Cancelled => "cancelled".to_string(), + } +} + +/// Categorize a DataFusionError into a short string suitable for use as a metric label. +fn categorize_datafusion_error(error: &datafusion::error::DataFusionError) -> String { + use datafusion::error::DataFusionError; + match error { + DataFusionError::ArrowError(_, _) => "arrow".to_string(), + DataFusionError::IoError(_) => "io".to_string(), + DataFusionError::SQL(_, _) => "sql".to_string(), + DataFusionError::NotImplemented(_) => "not_implemented".to_string(), + DataFusionError::Internal(_) => "internal".to_string(), + DataFusionError::Plan(_) => "plan".to_string(), + DataFusionError::Configuration(_) => "configuration".to_string(), + DataFusionError::SchemaError(_, _) => "schema".to_string(), + DataFusionError::Execution(_) => "execution".to_string(), + DataFusionError::ResourcesExhausted(_) => "resources_exhausted".to_string(), + DataFusionError::External(_) => "external".to_string(), + DataFusionError::Context(_, _) => "context".to_string(), + DataFusionError::Substrait(_) => "substrait".to_string(), + DataFusionError::Diagnostic(_, _) => "diagnostic".to_string(), + DataFusionError::Collection(_) => "collection".to_string(), + DataFusionError::ParquetError(_) => "parquet".to_string(), + DataFusionError::ObjectStore(_) => "object_store".to_string(), + DataFusionError::ExecutionJoin(_) => "execution_join".to_string(), + DataFusionError::Shared(_) => "shared".to_string(), + // Catch-all for feature-gated variants (e.g., AvroError when avro feature is enabled) + #[allow(unreachable_patterns)] + _ => "other".to_string(), + } +} + +/// Extract shuffle write metrics from the query stage executor's plan metrics. +/// +/// Returns (bytes_written, rows_written, write_time_ms) if metrics are available. +fn extract_shuffle_write_metrics( + query_stage_exec: &Arc, +) -> Option<(u64, u64, u64)> { + let metrics_sets = query_stage_exec.collect_plan_metrics(); + + let total_bytes = 0u64; + let mut total_rows = 0u64; + let mut total_write_time_nanos = 0u64; + + for metrics_set in &metrics_sets { + for metric in metrics_set.iter() { + let name = metric.value().name(); + match name { + "output_rows" => { + total_rows += metric.value().as_usize() as u64; + } + "write_time" => { + // write_time is recorded in nanoseconds + total_write_time_nanos += metric.value().as_usize() as u64; + } + _ => {} + } + } + } + + // Note: bytes_written is not directly tracked in ShuffleWriteMetrics, + // but we can estimate from the output file sizes or use output_rows as proxy. + // For now, we'll return 0 for bytes and let the caller decide. + // TODO: Add bytes tracking to ShuffleWriterExec if needed. + + if total_rows > 0 || total_write_time_nanos > 0 { + let write_time_ms = total_write_time_nanos / 1_000_000; + Some((total_bytes, total_rows, write_time_ms)) + } else { + None + } +} + +/// Extract shuffle read metrics by walking the execution plan tree and summing +/// the partition statistics from all ShuffleReaderExec nodes. +/// +/// Returns (total_bytes, total_rows) if any shuffle readers are found with stats. +/// Note: Duration is not available from stats; it would need to be tracked during execution. +fn extract_shuffle_read_metrics(plan: &dyn ExecutionPlan) -> Option<(u64, u64)> { + let mut total_bytes = 0u64; + let mut total_rows = 0u64; + let mut found_any = false; + + // Recursively walk the plan tree + extract_shuffle_read_metrics_recursive( + plan, + &mut total_bytes, + &mut total_rows, + &mut found_any, + ); + + if found_any { + Some((total_bytes, total_rows)) + } else { + None + } +} + +fn extract_shuffle_read_metrics_recursive( + plan: &dyn ExecutionPlan, + total_bytes: &mut u64, + total_rows: &mut u64, + found_any: &mut bool, +) { + // Check if this node is a ShuffleReaderExec + if let Some(shuffle_reader) = plan.as_any().downcast_ref::() { + // Sum up partition stats from all partition locations + for partition_locations in &shuffle_reader.partition { + for location in partition_locations { + if let Some(bytes) = location.partition_stats.num_bytes() { + *total_bytes += bytes; + *found_any = true; + } + if let Some(rows) = location.partition_stats.num_rows() { + *total_rows += rows; + *found_any = true; + } + } + } + } + + // Recurse into children + for child in plan.children() { + extract_shuffle_read_metrics_recursive( + child.as_ref(), + total_bytes, + total_rows, + found_any, + ); + } +} + /// A future that resolves when all active tasks on an executor have completed. /// /// This is used during graceful shutdown to wait for in-flight tasks to drain @@ -168,6 +324,15 @@ impl Executor { query_stage_exec: Arc, task_ctx: Arc, ) -> Result, BallistaError> { + let start_time = std::time::Instant::now(); + + // Record task start for metrics tracking + self.metrics_collector.record_task_started( + &partition.job_id, + partition.stage_id, + partition.partition_id, + ); + let (task, abort_handle) = futures::future::abortable( query_stage_exec.execute_query_stage(partition.partition_id, task_ctx), ); @@ -175,18 +340,71 @@ impl Executor { self.abort_handles .insert((task_id, partition.clone()), abort_handle); - let partitions = task.await??; + let result = task.await; + let duration_ms = start_time.elapsed().as_millis() as u64; self.abort_handles.remove(&(task_id, partition.clone())); - self.metrics_collector.record_stage( - &partition.job_id, - partition.stage_id, - partition.partition_id, - query_stage_exec, - ); + match result { + Ok(Ok(partitions)) => { + // Extract shuffle write metrics from the plan + let shuffle_write_metrics = + extract_shuffle_write_metrics(&query_stage_exec); + if let Some((bytes, rows, write_time_ms)) = shuffle_write_metrics { + self.metrics_collector.record_shuffle_write( + &partition.job_id, + partition.stage_id, + partition.partition_id, + bytes, + rows, + write_time_ms, + ); + } - Ok(partitions) + // Extract shuffle read metrics from ShuffleReaderExec nodes in the plan + // Note: Duration is approximated as task duration minus write time since + // we don't have fine-grained timing for the read phase + let shuffle_read_metrics = + extract_shuffle_read_metrics(query_stage_exec.plan()); + if let Some((bytes, rows)) = shuffle_read_metrics { + // Approximate read duration: if we have write time, subtract it from total + // Otherwise, use 0 (the bytes/rows are still valuable) + let read_duration_ms = shuffle_write_metrics + .map(|(_, _, write_ms)| duration_ms.saturating_sub(write_ms)) + .unwrap_or(0); + self.metrics_collector.record_shuffle_read( + &partition.job_id, + partition.stage_id, + partition.partition_id, + bytes, + rows, + read_duration_ms, + ); + } + + self.metrics_collector.record_stage( + &partition.job_id, + partition.stage_id, + partition.partition_id, + query_stage_exec, + duration_ms, + ); + Ok(partitions) + } + Ok(Err(e)) => { + self.metrics_collector.record_task_failed( + &partition.job_id, + partition.stage_id, + partition.partition_id, + &categorize_datafusion_error(&e), + ); + Err(BallistaError::from(e)) + } + Err(_aborted) => { + // Task was cancelled - don't record as failure, it was intentional + Err(BallistaError::Cancelled) + } + } } /// Cancels a running task by aborting its execution. diff --git a/ballista/executor/src/executor_process.rs b/ballista/executor/src/executor_process.rs index c270f4b8cf..971ff5b783 100644 --- a/ballista/executor/src/executor_process.rs +++ b/ballista/executor/src/executor_process.rs @@ -42,7 +42,6 @@ use datafusion::execution::runtime_env::RuntimeEnvBuilder; use ballista_core::config::{LogRotationPolicy, TaskSchedulingPolicy}; use ballista_core::error::BallistaError; -use ballista_core::extension::SessionConfigExt; use ballista_core::serde::protobuf::executor_resource::Resource; use ballista_core::serde::protobuf::executor_status::Status; use ballista_core::serde::protobuf::{ @@ -53,8 +52,8 @@ use ballista_core::serde::{ BallistaCodec, BallistaLogicalExtensionCodec, BallistaPhysicalExtensionCodec, }; use ballista_core::utils::{ - GrpcServerConfig, create_grpc_client_connection, create_grpc_client_endpoint, - create_grpc_server, default_config_producer, get_time_before, + GrpcServerConfig, create_grpc_client_endpoint, create_grpc_server, + default_config_producer, get_time_before, }; use ballista_core::{BALLISTA_VERSION, ConfigProducer, RuntimeProducer}; use tonic::transport::{Endpoint, Error as TonicTransportError}; @@ -289,8 +288,6 @@ pub async fn start_executor_process( )); let connect_timeout = opt.scheduler_connect_timeout_seconds as u64; - let session_config = (executor.config_producer)(); - let ballista_config = session_config.ballista_config(); let connection = if connect_timeout == 0 { let mut endpoint = create_grpc_client_endpoint(scheduler_url).map_err(|_| { BallistaError::GrpcConnectionError( diff --git a/ballista/executor/src/executor_server.rs b/ballista/executor/src/executor_server.rs index 86a21d558c..a792e89be2 100644 --- a/ballista/executor/src/executor_server.rs +++ b/ballista/executor/src/executor_server.rs @@ -38,7 +38,6 @@ pub type EndpointOverrideFn = Arc Result + Send + Sync>; use ballista_core::error::BallistaError; -use ballista_core::extension::SessionConfigExt; use ballista_core::serde::BallistaCodec; use ballista_core::serde::protobuf::{ CancelTasksParams, CancelTasksResult, ExecutorMetric, ExecutorStatus, @@ -55,7 +54,7 @@ use ballista_core::serde::scheduler::TaskDefinition; use ballista_core::serde::scheduler::from_proto::{ get_task_definition, get_task_definition_vec, }; -use ballista_core::utils::{create_grpc_client_connection, create_grpc_client_endpoint, create_grpc_server}; +use ballista_core::utils::{create_grpc_client_endpoint, create_grpc_server}; use dashmap::DashMap; use datafusion::execution::TaskContext; diff --git a/ballista/executor/src/metrics/mod.rs b/ballista/executor/src/metrics/mod.rs index acc2cceed3..4cdb31356b 100644 --- a/ballista/executor/src/metrics/mod.rs +++ b/ballista/executor/src/metrics/mod.rs @@ -19,37 +19,202 @@ use crate::execution_engine::QueryStageExecutor; use log::info; use std::sync::Arc; -/// `ExecutorMetricsCollector` records metrics for `ShuffleWriteExec` -/// after they are executed. +/// `ExecutorMetricsCollector` records metrics for task execution on an executor. /// -/// After each stage completes, `ShuffleWriteExec::record_stage` will be -/// called. +/// This trait provides hooks for recording metrics at various points during +/// task execution, including task start, completion, failure, and shuffle operations. +/// +/// Implementations can use these hooks to integrate with metrics systems like +/// Prometheus, OpenTelemetry, or custom monitoring solutions. pub trait ExecutorMetricsCollector: Send + Sync { - /// Record metrics for stage after it is executed + /// Record that a task has started execution. + /// + /// Called when a task begins executing on this executor. Use this to track + /// active task counts and task start times. + fn record_task_started(&self, job_id: &str, stage_id: usize, partition: usize); + + /// Record metrics for a stage/task after successful execution. + /// + /// Called when a task completes successfully. The `plan` contains execution + /// metrics from DataFusion, and `duration_ms` is the wall-clock execution time. fn record_stage( &self, job_id: &str, stage_id: usize, partition: usize, plan: Arc, + duration_ms: u64, + ); + + /// Record that a task has failed. + /// + /// Called when a task fails with an error. The `error_type` is a categorized + /// error string suitable for use as a metric label. + fn record_task_failed( + &self, + job_id: &str, + stage_id: usize, + partition: usize, + error_type: &str, + ); + + /// Record shuffle write metrics. + /// + /// Called after shuffle data is written. Tracks bytes, rows, and duration + /// for shuffle write operations. + fn record_shuffle_write( + &self, + job_id: &str, + stage_id: usize, + partition: usize, + bytes: u64, + rows: u64, + duration_ms: u64, + ); + + /// Record shuffle read metrics. + /// + /// Called after shuffle data is read. Tracks bytes, rows, and duration + /// for shuffle read operations. + fn record_shuffle_read( + &self, + job_id: &str, + stage_id: usize, + partition: usize, + bytes: u64, + rows: u64, + duration_ms: u64, ); + + /// Record local shuffle read metrics (data read from local disk). + /// + /// Called when shuffle data is read from a local file. This means the partition + /// was written by this same executor in a previous stage, avoiding network transfer. + fn record_shuffle_read_local( + &self, + job_id: &str, + stage_id: usize, + partition: usize, + bytes: u64, + rows: u64, + duration_ms: u64, + ); + + /// Record remote shuffle read metrics (data fetched from another executor). + /// + /// Called when shuffle data must be fetched over the network from another + /// executor that produced the partition. The `source_executor_id` identifies + /// the executor that holds the shuffle data. + #[allow(clippy::too_many_arguments)] + fn record_shuffle_read_remote( + &self, + job_id: &str, + stage_id: usize, + partition: usize, + source_executor_id: &str, + bytes: u64, + rows: u64, + duration_ms: u64, + ); + + /// Record executor memory availability. + /// + /// Called periodically (e.g., during heartbeat) to report the executor's + /// available memory. This helps with capacity planning and load balancing. + fn record_memory_available(&self, available_bytes: u64); } /// Implementation of `ExecutorMetricsCollector` which logs the completed -/// plan to stdout. +/// plan to stdout. Useful for debugging and development. #[derive(Default)] pub struct LoggingMetricsCollector {} impl ExecutorMetricsCollector for LoggingMetricsCollector { + fn record_task_started(&self, job_id: &str, stage_id: usize, partition: usize) { + info!("=== [{job_id}/{stage_id}/{partition}] Task started ==="); + } + fn record_stage( &self, job_id: &str, stage_id: usize, partition: usize, plan: Arc, + duration_ms: u64, + ) { + info!( + "=== [{job_id}/{stage_id}/{partition}] Task completed in {duration_ms}ms ===\n{plan}\n" + ); + } + + fn record_task_failed( + &self, + job_id: &str, + stage_id: usize, + partition: usize, + error_type: &str, + ) { + info!("=== [{job_id}/{stage_id}/{partition}] Task failed: {error_type} ==="); + } + + fn record_shuffle_write( + &self, + job_id: &str, + stage_id: usize, + partition: usize, + bytes: u64, + rows: u64, + duration_ms: u64, + ) { + info!( + "=== [{job_id}/{stage_id}/{partition}] Shuffle write: {bytes} bytes, {rows} rows in {duration_ms}ms ===" + ); + } + + fn record_shuffle_read( + &self, + job_id: &str, + stage_id: usize, + partition: usize, + bytes: u64, + rows: u64, + duration_ms: u64, ) { info!( - "=== [{job_id}/{stage_id}/{partition}] Physical plan with metrics ===\n{plan}\n" + "=== [{job_id}/{stage_id}/{partition}] Shuffle read: {bytes} bytes, {rows} rows in {duration_ms}ms ===" ); } + + fn record_shuffle_read_local( + &self, + job_id: &str, + stage_id: usize, + partition: usize, + bytes: u64, + rows: u64, + duration_ms: u64, + ) { + info!( + "=== [{job_id}/{stage_id}/{partition}] Local shuffle read: {bytes} bytes, {rows} rows in {duration_ms}ms ===" + ); + } + + fn record_shuffle_read_remote( + &self, + job_id: &str, + stage_id: usize, + partition: usize, + source_executor_id: &str, + bytes: u64, + rows: u64, + duration_ms: u64, + ) { + info!( + "=== [{job_id}/{stage_id}/{partition}] Remote shuffle read from {source_executor_id}: {bytes} bytes, {rows} rows in {duration_ms}ms ===" + ); + } + + fn record_memory_available(&self, available_bytes: u64) { + info!("=== Executor memory available: {available_bytes} bytes ==="); + } } diff --git a/ballista/scheduler/src/cluster/memory.rs b/ballista/scheduler/src/cluster/memory.rs index a1db74ecba..0066ed8b49 100644 --- a/ballista/scheduler/src/cluster/memory.rs +++ b/ballista/scheduler/src/cluster/memory.rs @@ -16,8 +16,8 @@ // under the License. use crate::cluster::{ - BoundTask, ClusterState, ExecutorSlot, JobState, JobStateEvent, JobStateEventStream, - JobStatus, TaskDistributionPolicy, TopologyNode, bind_task_bias, + BindingResult, ClusterState, ExecutorSlot, JobState, JobStateEvent, + JobStateEventStream, JobStatus, TaskDistributionPolicy, TopologyNode, bind_task_bias, bind_task_consistent_hash, bind_task_round_robin, get_scan_files, is_skip_consistent_hash, }; @@ -111,7 +111,7 @@ impl ClusterState for InMemoryClusterState { distribution: TaskDistributionPolicy, active_jobs: Arc>, executors: Option>, - ) -> Result> { + ) -> Result { let mut guard = self.task_slots.lock().await; let available_slots: Vec<&mut AvailableTaskSlots> = guard @@ -126,7 +126,7 @@ impl ClusterState for InMemoryClusterState { }) .collect(); - let bound_tasks = match distribution { + let result = match distribution { TaskDistributionPolicy::Bias => { bind_task_bias(available_slots, active_jobs, |_| false).await } @@ -137,7 +137,7 @@ impl ClusterState for InMemoryClusterState { num_replicas, tolerance, } => { - let mut bound_tasks = bind_task_round_robin( + let mut result = bind_task_round_robin( available_slots, active_jobs.clone(), |stage_plan: Arc| { @@ -150,22 +150,24 @@ impl ClusterState for InMemoryClusterState { }, ) .await; - info!("{} tasks bound by round robin policy", bound_tasks.len()); - let (bound_tasks_consistent_hash, ch_topology) = - bind_task_consistent_hash( - self.get_topology_nodes(&guard, executors), - num_replicas, - tolerance, - active_jobs, - |_, plan| get_scan_files(plan), - ) - .await?; + info!( + "{} tasks bound by round robin policy", + result.bound_tasks.len() + ); + let (consistent_hash_result, ch_topology) = bind_task_consistent_hash( + self.get_topology_nodes(&guard, executors), + num_replicas, + tolerance, + active_jobs, + |_, plan| get_scan_files(plan), + ) + .await?; info!( "{} tasks bound by consistent hashing policy", - bound_tasks_consistent_hash.len() + consistent_hash_result.bound_tasks.len() ); - if !bound_tasks_consistent_hash.is_empty() { - bound_tasks.extend(bound_tasks_consistent_hash); + if !consistent_hash_result.bound_tasks.is_empty() { + result.extend(consistent_hash_result); // Update the available slots let ch_topology = ch_topology.unwrap(); for node in ch_topology.nodes() { @@ -176,14 +178,17 @@ impl ClusterState for InMemoryClusterState { } } } - bound_tasks + result } TaskDistributionPolicy::Custom(ref policy) => { - policy.bind_tasks(available_slots, active_jobs).await? + // Custom policies don't support affinity tracking yet + BindingResult::from_tasks( + policy.bind_tasks(available_slots, active_jobs).await?, + ) } }; - Ok(bound_tasks) + Ok(result) } async fn unbind_tasks(&self, executor_slots: Vec) -> Result<()> { diff --git a/ballista/scheduler/src/cluster/mod.rs b/ballista/scheduler/src/cluster/mod.rs index 7b1f4d27a5..f5aeefce2b 100644 --- a/ballista/scheduler/src/cluster/mod.rs +++ b/ballista/scheduler/src/cluster/mod.rs @@ -155,6 +155,53 @@ pub type ExecutorHeartbeatStream = Pin /// Tuple of (executor_id, task_description). pub type BoundTask = (String, TaskDescription); +/// Shuffle affinity information for a bound task. +/// +/// Tracks whether a task was scheduled on an executor that has local shuffle data +/// from the task's input stages. +#[derive(Debug, Clone)] +pub struct ShuffleAffinityInfo { + /// Job ID for this task. + pub job_id: String, + /// Stage ID for this task. + pub stage_id: usize, + /// Executor ID where the task was scheduled. + pub executor_id: String, + /// True if the executor has local shuffle data for at least one input partition. + pub has_local_data: bool, +} + +/// Result of task binding including shuffle affinity metrics. +#[derive(Debug, Default)] +pub struct BindingResult { + /// Tasks bound to executors. + pub bound_tasks: Vec, + /// Shuffle affinity information for tasks that have shuffle inputs. + /// Only populated for tasks whose stages read from shuffle (not leaf stages). + pub shuffle_affinity: Vec, +} + +impl BindingResult { + /// Creates a new empty binding result. + pub fn new() -> Self { + Self::default() + } + + /// Creates a binding result from bound tasks without affinity info. + pub fn from_tasks(bound_tasks: Vec) -> Self { + Self { + bound_tasks, + shuffle_affinity: Vec::new(), + } + } + + /// Extends this result with another binding result. + pub fn extend(&mut self, other: BindingResult) { + self.bound_tasks.extend(other.bound_tasks); + self.shuffle_affinity.extend(other.shuffle_affinity); + } +} + /// An executor slot representing available task capacity. /// /// Tuple of (executor_id, slot_count). @@ -175,12 +222,13 @@ pub trait ClusterState: Send + Sync + 'static { /// Binds ready-to-run tasks from active jobs to available executor slots. /// /// If `executors` is provided, only bind slots from the specified executor IDs. + /// Returns both the bound tasks and shuffle affinity information for metrics. async fn bind_schedulable_tasks( &self, distribution: TaskDistributionPolicy, active_jobs: Arc>, executors: Option>, - ) -> Result>; + ) -> Result; /// Unbinds executor slots when tasks finish or fail. /// @@ -349,17 +397,43 @@ pub trait JobState: Send + Sync { fn produce_config(&self) -> SessionConfig; } +use crate::state::execution_stage::RunningStage; + +/// Collects the set of executor IDs that have local shuffle data for this stage. +/// +/// Returns `None` if the stage has no shuffle inputs (leaf stage), +/// otherwise returns `Some(set)` where the set contains executor IDs with local data. +fn get_executors_with_local_shuffle_data( + running_stage: &RunningStage, +) -> Option> { + if running_stage.inputs.is_empty() { + // Leaf stage with no shuffle inputs + return None; + } + + let mut executors_with_local_data = HashSet::new(); + for stage_output in running_stage.inputs.values() { + for partition_locations in stage_output.partition_locations.values() { + for location in partition_locations { + executors_with_local_data.insert(location.executor_meta.id.clone()); + } + } + } + + Some(executors_with_local_data) +} + pub(crate) async fn bind_task_bias( mut slots: Vec<&mut AvailableTaskSlots>, running_jobs: Arc>, if_skip: fn(Arc) -> bool, -) -> Vec { - let mut schedulable_tasks: Vec = vec![]; +) -> BindingResult { + let mut result = BindingResult::new(); let total_slots = slots.iter().fold(0, |acc, s| acc + s.slots); if total_slots == 0 { debug!("Not enough available executor slots for task running!!!"); - return schedulable_tasks; + return result; } // Sort the slots by descending order @@ -386,6 +460,12 @@ pub(crate) async fn bind_task_bias( black_list.push(running_stage.stage_id); continue; } + + // Get executors with local shuffle data before we borrow task_infos mutably + let executors_with_local_data = + get_executors_with_local_shuffle_data(running_stage); + let stage_id = running_stage.stage_id; + // We are sure that it will at least bind one task by going through the following logic. // It will not go into a dead loop. let runnable_tasks = running_stage @@ -400,7 +480,7 @@ pub(crate) async fn bind_task_bias( while slot.slots == 0 { idx_slot += 1; if idx_slot >= slots.len() { - return schedulable_tasks; + return result; } slot = &mut slots[idx_slot]; } @@ -409,9 +489,19 @@ pub(crate) async fn bind_task_bias( *task_id_gen += 1; *task_info = Some(create_task_info(executor_id.clone(), task_id)); + // Record shuffle affinity for this task if it has shuffle inputs + if let Some(ref local_executors) = executors_with_local_data { + result.shuffle_affinity.push(ShuffleAffinityInfo { + job_id: job_id.clone(), + stage_id, + executor_id: executor_id.clone(), + has_local_data: local_executors.contains(&executor_id), + }); + } + let partition = PartitionId { job_id: job_id.clone(), - stage_id: running_stage.stage_id, + stage_id, partition_id, }; let task_desc = TaskDescription { @@ -422,28 +512,29 @@ pub(crate) async fn bind_task_bias( task_attempt: running_stage.task_failure_numbers[partition_id], plan: running_stage.plan.clone(), session_config: running_stage.session_config.clone(), + schedulable_time_millis: running_stage.stage_running_time, }; - schedulable_tasks.push((executor_id, task_desc)); + result.bound_tasks.push((executor_id, task_desc)); slot.slots -= 1; } } } - schedulable_tasks + result } pub(crate) async fn bind_task_round_robin( mut slots: Vec<&mut AvailableTaskSlots>, running_jobs: Arc>, if_skip: fn(Arc) -> bool, -) -> Vec { - let mut schedulable_tasks: Vec = vec![]; +) -> BindingResult { + let mut result = BindingResult::new(); let mut total_slots = slots.iter().fold(0, |acc, s| acc + s.slots); if total_slots == 0 { debug!("Not enough available executor slots for task running!!!"); - return schedulable_tasks; + return result; } debug!("Total slot number is {total_slots}"); @@ -470,6 +561,12 @@ pub(crate) async fn bind_task_round_robin( black_list.push(running_stage.stage_id); continue; } + + // Get executors with local shuffle data before we borrow task_infos mutably + let executors_with_local_data = + get_executors_with_local_shuffle_data(running_stage); + let stage_id = running_stage.stage_id; + // We are sure that it will at least bind one task by going through the following logic. // It will not go into a dead loop. let runnable_tasks = running_stage @@ -495,9 +592,19 @@ pub(crate) async fn bind_task_round_robin( *task_id_gen += 1; *task_info = Some(create_task_info(executor_id.clone(), task_id)); + // Record shuffle affinity for this task if it has shuffle inputs + if let Some(ref local_executors) = executors_with_local_data { + result.shuffle_affinity.push(ShuffleAffinityInfo { + job_id: job_id.clone(), + stage_id, + executor_id: executor_id.clone(), + has_local_data: local_executors.contains(&executor_id), + }); + } + let partition = PartitionId { job_id: job_id.clone(), - stage_id: running_stage.stage_id, + stage_id, partition_id, }; let task_desc = TaskDescription { @@ -508,20 +615,21 @@ pub(crate) async fn bind_task_round_robin( task_attempt: running_stage.task_failure_numbers[partition_id], plan: running_stage.plan.clone(), session_config: running_stage.session_config.clone(), + schedulable_time_millis: running_stage.stage_running_time, }; - schedulable_tasks.push((executor_id, task_desc)); + result.bound_tasks.push((executor_id, task_desc)); idx_slot += 1; slot.slots -= 1; total_slots -= 1; if total_slots == 0 { - return schedulable_tasks; + return result; } } } } - schedulable_tasks + result } /// Maps execution plan to list of files it scans @@ -568,7 +676,7 @@ pub(crate) async fn bind_task_consistent_hash( tolerance: usize, running_jobs: Arc>, get_scan_files: GetScanFilesFunc, -) -> Result<(Vec, Option>)> { +) -> Result<(BindingResult, Option>)> { let mut total_slots = 0usize; for (_, node) in topology_nodes.iter() { total_slots += node.available_slots as usize; @@ -577,7 +685,7 @@ pub(crate) async fn bind_task_consistent_hash( debug!( "Not enough available executor slots for binding tasks with consistent hashing policy!!!" ); - return Ok((vec![], None)); + return Ok((BindingResult::new(), None)); } debug!("Total slot number for consistent hash binding is {total_slots}"); @@ -588,7 +696,7 @@ pub(crate) async fn bind_task_consistent_hash( let mut ch_topology: ConsistentHash = ConsistentHash::new(node_replicas); - let mut schedulable_tasks: Vec = vec![]; + let mut result = BindingResult::new(); for (job_id, job_info) in running_jobs.iter() { if !matches!(job_info.status, Some(job_status::Status::Running(_))) { debug!("Job {job_id} is not in running status and will be skipped"); @@ -638,6 +746,10 @@ pub(crate) async fn bind_task_consistent_hash( *task_id_gen += 1; *task_info = Some(create_task_info(executor_id.clone(), task_id)); + // Note: Consistent hash is used for scan (leaf) stages, not shuffle stages, + // so we don't track shuffle affinity here. The stage has scan files, + // meaning it's a leaf stage without shuffle inputs. + let partition = PartitionId { job_id: job_id.clone(), stage_id: running_stage.stage_id, @@ -652,13 +764,14 @@ pub(crate) async fn bind_task_consistent_hash( [partition_id], plan: running_stage.plan.clone(), session_config: running_stage.session_config.clone(), + schedulable_time_millis: running_stage.stage_running_time, }; - schedulable_tasks.push((executor_id, task_desc)); + result.bound_tasks.push((executor_id, task_desc)); node.available_slots -= 1; total_slots -= 1; if total_slots == 0 { - return Ok((schedulable_tasks, Some(ch_topology))); + return Ok((result, Some(ch_topology))); } } } @@ -671,7 +784,7 @@ pub(crate) async fn bind_task_consistent_hash( } } - Ok((schedulable_tasks, Some(ch_topology))) + Ok((result, Some(ch_topology))) } // If if there's no plan which needs to scan files, skip it. @@ -778,11 +891,11 @@ mod test { let mut available_slots = mock_available_slots(); let available_slots_ref: Vec<&mut AvailableTaskSlots> = available_slots.iter_mut().collect(); - let bound_tasks = + let binding_result = bind_task_bias(available_slots_ref, Arc::new(active_jobs), |_| false).await; - assert_eq!(9, bound_tasks.len()); + assert_eq!(9, binding_result.bound_tasks.len()); - let result = get_result(bound_tasks); + let result = get_result(binding_result.bound_tasks); let mut expected = Vec::new(); { @@ -828,12 +941,12 @@ mod test { let mut available_slots = mock_available_slots(); let available_slots_ref: Vec<&mut AvailableTaskSlots> = available_slots.iter_mut().collect(); - let bound_tasks = + let binding_result = bind_task_round_robin(available_slots_ref, Arc::new(active_jobs), |_| false) .await; - assert_eq!(9, bound_tasks.len()); + assert_eq!(9, binding_result.bound_tasks.len()); - let result = get_result(bound_tasks); + let result = get_result(binding_result.bound_tasks); let mut expected = Vec::new(); { @@ -888,7 +1001,7 @@ mod test { // Check none scan files case { - let (bound_tasks, _) = bind_task_consistent_hash( + let (binding_result, _) = bind_task_consistent_hash( topology_nodes.clone(), num_replicas, tolerance, @@ -896,12 +1009,12 @@ mod test { |_, _| Ok(vec![]), ) .await?; - assert_eq!(0, bound_tasks.len()); + assert_eq!(0, binding_result.bound_tasks.len()); } // Check job_b with scan files { - let (bound_tasks, _) = bind_task_consistent_hash( + let (binding_result, _) = bind_task_consistent_hash( topology_nodes, num_replicas, tolerance, @@ -909,9 +1022,9 @@ mod test { |job_id, _| mock_get_scan_files("job_b", job_id, 8), ) .await?; - assert_eq!(6, bound_tasks.len()); + assert_eq!(6, binding_result.bound_tasks.len()); - let result = get_result(bound_tasks); + let result = get_result(binding_result.bound_tasks); let mut expected = HashMap::new(); { @@ -941,7 +1054,7 @@ mod test { let tolerance = 1; { - let (bound_tasks, _) = bind_task_consistent_hash( + let (binding_result, _) = bind_task_consistent_hash( topology_nodes, num_replicas, tolerance, @@ -949,9 +1062,9 @@ mod test { |job_id, _| mock_get_scan_files("job_b", job_id, 8), ) .await?; - assert_eq!(7, bound_tasks.len()); + assert_eq!(7, binding_result.bound_tasks.len()); - let result = get_result(bound_tasks); + let result = get_result(binding_result.bound_tasks); let mut expected = HashMap::new(); { diff --git a/ballista/scheduler/src/config.rs b/ballista/scheduler/src/config.rs index 24b0e689fa..bf22d0d0f8 100644 --- a/ballista/scheduler/src/config.rs +++ b/ballista/scheduler/src/config.rs @@ -27,6 +27,7 @@ use crate::SessionBuilder; use crate::cluster::DistributionPolicy; +use crate::metrics::SchedulerMetricsCollector; use ballista_core::{ConfigProducer, config::TaskSchedulingPolicy}; use datafusion_proto::logical_plan::LogicalExtensionCodec; use datafusion_proto::physical_plan::PhysicalExtensionCodec; @@ -245,6 +246,8 @@ pub struct SchedulerConfig { pub override_physical_codec: Option>, /// Override function for customizing gRPC client endpoints before they are used pub override_create_grpc_client_endpoint: Option, + /// [SchedulerMetricsCollector] override option + pub override_metrics_collector: Option>, } impl Default for SchedulerConfig { @@ -272,6 +275,7 @@ impl Default for SchedulerConfig { override_logical_codec: None, override_physical_codec: None, override_create_grpc_client_endpoint: None, + override_metrics_collector: None, } } } @@ -392,6 +396,7 @@ impl SchedulerConfig { self } + /// Sets an override function for creating gRPC client endpoints. pub fn with_override_create_grpc_client_endpoint( mut self, override_fn: Arc< @@ -401,6 +406,15 @@ impl SchedulerConfig { self.override_create_grpc_client_endpoint = Some(override_fn); self } + + /// Sets a custom metrics collector. + pub fn with_override_metrics_collector( + mut self, + metrics_collector: Arc, + ) -> Self { + self.override_metrics_collector = Some(metrics_collector); + self + } } /// Policy of distributing tasks to available executor slots @@ -512,6 +526,7 @@ impl TryFrom for SchedulerConfig { override_physical_codec: None, override_session_builder: None, override_create_grpc_client_endpoint: None, + override_metrics_collector: None, }; Ok(config) diff --git a/ballista/scheduler/src/metrics/mod.rs b/ballista/scheduler/src/metrics/mod.rs index 819534e183..934dda1d95 100644 --- a/ballista/scheduler/src/metrics/mod.rs +++ b/ballista/scheduler/src/metrics/mod.rs @@ -23,10 +23,19 @@ use crate::metrics::prometheus::PrometheusMetricsCollector; use ballista_core::error::Result; use std::sync::Arc; -/// Interface for recording metrics events in the scheduler. An instance of `Arc` -/// will be passed when constructing the `QueryStageScheduler` which is the core event loop of the scheduler. +/// Interface for recording metrics events in the scheduler. +/// +/// An instance of `Arc` will be passed when constructing +/// the `QueryStageScheduler` which is the core event loop of the scheduler. /// The event loop will then record metric events through this trait. +/// +/// This trait provides hooks for job lifecycle events, stage lifecycle events, +/// task scheduling events, and executor management events. pub trait SchedulerMetricsCollector: Send + Sync { + // ========================================================================= + // Job lifecycle events (existing) + // ========================================================================= + /// Record that job with `job_id` was submitted. This will be invoked /// after the job's `ExecutionGraph` is created and it is ready to be scheduled /// on executors. @@ -52,9 +61,132 @@ pub trait SchedulerMetricsCollector: Send + Sync { /// to schedule on an executor but cannot be scheduled because no resources are available. fn set_pending_tasks_queue_size(&self, value: u64); + /// Set the current number of pending jobs in scheduler. A pending job is a job that has been + /// queued but not yet submitted for execution (i.e., not yet planned). + fn set_pending_jobs_queue_size(&self, value: u64); + /// Gather current metric set that should be returned when calling the scheduler's metrics API /// Should return a tuple containing the content of the metric set and the content type (e.g. `application/json`, `text/plain`, etc) fn gather_metrics(&self) -> Result, String)>>; + + // ========================================================================= + // Stage lifecycle events (new) + // ========================================================================= + + /// Record that a stage has started execution. + /// + /// Called when a stage transitions to Running state. The `task_count` is the + /// total number of partitions/tasks in this stage. + fn record_stage_started(&self, job_id: &str, stage_id: usize, task_count: usize); + + /// Record that a stage has completed successfully. + /// + /// Called when all tasks in a stage complete successfully. The `duration_ms` + /// is the wall-clock time from stage start to completion. + fn record_stage_completed(&self, job_id: &str, stage_id: usize, duration_ms: u64); + + /// Record that a stage has failed. + /// + /// Called when a stage fails (e.g., due to task failures exceeding retry limit). + /// The `error_type` is a categorized error string suitable for use as a metric label. + fn record_stage_failed(&self, job_id: &str, stage_id: usize, error_type: &str); + + /// Record that a stage is being retried. + /// + /// Called when a stage is reset for retry after a failure. + fn record_stage_retry(&self, job_id: &str, stage_id: usize); + + // ========================================================================= + // Task scheduling events (new) + // ========================================================================= + + /// Record that a task has been scheduled to an executor. + /// + /// Called when the scheduler assigns a task to an executor. The `latency_ms` + /// is the time from when the task became schedulable to when it was assigned. + fn record_task_scheduled( + &self, + job_id: &str, + stage_id: usize, + executor_id: &str, + latency_ms: u64, + ); + + /// Record that a task has completed on an executor. + /// + /// Called when the scheduler receives notification of task completion. + fn record_task_completed(&self, job_id: &str, stage_id: usize, executor_id: &str); + + /// Record that a task has failed on an executor. + /// + /// Called when the scheduler receives notification of task failure. + fn record_task_failed( + &self, + job_id: &str, + stage_id: usize, + executor_id: &str, + error_type: &str, + ); + + /// Record that a task is being retried. + /// + /// Called when a task is rescheduled after a failure. This is distinct from + /// stage-level retries and tracks individual task retry attempts. + fn record_task_retry(&self, job_id: &str, stage_id: usize); + + /// Record a shuffle affinity hit - task was assigned to an executor that has + /// local shuffle data from a parent stage. + /// + /// Called when the scheduler assigns a task to an executor that already has + /// the required shuffle partitions from upstream stages stored locally. + /// This indicates the task can read shuffle data without network transfer. + fn record_task_shuffle_affinity_hit( + &self, + job_id: &str, + stage_id: usize, + executor_id: &str, + ); + + /// Record a shuffle affinity miss - task was assigned to an executor that does + /// NOT have local shuffle data from a parent stage. + /// + /// Called when the scheduler assigns a task to an executor that does not have + /// the required shuffle partitions locally. This indicates the task will need + /// to fetch shuffle data over the network from other executors. + fn record_task_shuffle_affinity_miss( + &self, + job_id: &str, + stage_id: usize, + executor_id: &str, + ); + + // ========================================================================= + // Executor management events (new) + // ========================================================================= + + /// Set the current count of active executors. + /// + /// Called when the executor count changes (registration/deregistration). + fn set_active_executor_count(&self, count: usize); + + /// Record that an executor has registered with the scheduler. + fn record_executor_registered(&self, executor_id: &str); + + /// Record that an executor has been removed from the scheduler. + /// + /// This can happen due to explicit deregistration, heartbeat timeout, or + /// executor failure. + fn record_executor_deregistered(&self, executor_id: &str); + + // ========================================================================= + // Planning events (new) + // ========================================================================= + + /// Record the duration of query planning. + /// + /// Called after a query has been planned and the ExecutionGraph is created. + /// The `duration_ms` is the time spent in the distributed planner. + fn record_planning_duration(&self, job_id: &str, duration_ms: u64); } /// Implementation of `SchedulerMetricsCollector` that ignores all events. This can be used as @@ -63,15 +195,66 @@ pub trait SchedulerMetricsCollector: Send + Sync { pub struct NoopMetricsCollector {} impl SchedulerMetricsCollector for NoopMetricsCollector { + // Job lifecycle fn record_submitted(&self, _job_id: &str, _queued_at: u64, _submitted_at: u64) {} - fn record_completed(&self, _job_id: &str, _queued_at: u64, _completed_att: u64) {} + fn record_completed(&self, _job_id: &str, _queued_at: u64, _completed_at: u64) {} fn record_failed(&self, _job_id: &str, _queued_at: u64, _failed_at: u64) {} fn record_cancelled(&self, _job_id: &str) {} fn set_pending_tasks_queue_size(&self, _value: u64) {} - + fn set_pending_jobs_queue_size(&self, _value: u64) {} fn gather_metrics(&self) -> Result, String)>> { Ok(None) } + + // Stage lifecycle + fn record_stage_started(&self, _job_id: &str, _stage_id: usize, _task_count: usize) {} + fn record_stage_completed(&self, _job_id: &str, _stage_id: usize, _duration_ms: u64) { + } + fn record_stage_failed(&self, _job_id: &str, _stage_id: usize, _error_type: &str) {} + fn record_stage_retry(&self, _job_id: &str, _stage_id: usize) {} + + // Task scheduling + fn record_task_scheduled( + &self, + _job_id: &str, + _stage_id: usize, + _executor_id: &str, + _latency_ms: u64, + ) { + } + fn record_task_completed(&self, _job_id: &str, _stage_id: usize, _executor_id: &str) { + } + fn record_task_failed( + &self, + _job_id: &str, + _stage_id: usize, + _executor_id: &str, + _error_type: &str, + ) { + } + fn record_task_retry(&self, _job_id: &str, _stage_id: usize) {} + fn record_task_shuffle_affinity_hit( + &self, + _job_id: &str, + _stage_id: usize, + _executor_id: &str, + ) { + } + fn record_task_shuffle_affinity_miss( + &self, + _job_id: &str, + _stage_id: usize, + _executor_id: &str, + ) { + } + + // Executor management + fn set_active_executor_count(&self, _count: usize) {} + fn record_executor_registered(&self, _executor_id: &str) {} + fn record_executor_deregistered(&self, _executor_id: &str) {} + + // Planning + fn record_planning_duration(&self, _job_id: &str, _duration_ms: u64) {} } /// Returns the default metrics collector for the system. diff --git a/ballista/scheduler/src/metrics/prometheus.rs b/ballista/scheduler/src/metrics/prometheus.rs index 032f0f82b1..4cf68088e0 100644 --- a/ballista/scheduler/src/metrics/prometheus.rs +++ b/ballista/scheduler/src/metrics/prometheus.rs @@ -162,6 +162,10 @@ impl SchedulerMetricsCollector for PrometheusMetricsCollector { self.pending_queue_size.set(value as f64); } + fn set_pending_jobs_queue_size(&self, _value: u64) { + // Not tracked by default Prometheus collector + } + fn gather_metrics(&self) -> Result, String)>> { let encoder = TextEncoder::new(); @@ -173,4 +177,40 @@ impl SchedulerMetricsCollector for PrometheusMetricsCollector { Ok(Some((buffer, encoder.format_type().to_owned()))) } + + // Stage lifecycle - not tracked by default Prometheus collector + fn record_stage_started(&self, _job_id: &str, _stage_id: usize, _task_count: usize) {} + fn record_stage_completed(&self, _job_id: &str, _stage_id: usize, _duration_ms: u64) { + } + fn record_stage_failed(&self, _job_id: &str, _stage_id: usize, _error_type: &str) {} + fn record_stage_retry(&self, _job_id: &str, _stage_id: usize) {} + + // Task scheduling - not tracked by default Prometheus collector + fn record_task_scheduled( + &self, + _job_id: &str, + _stage_id: usize, + _executor_id: &str, + _latency_ms: u64, + ) { + } + fn record_task_completed(&self, _job_id: &str, _stage_id: usize, _executor_id: &str) { + } + fn record_task_failed( + &self, + _job_id: &str, + _stage_id: usize, + _executor_id: &str, + _error_type: &str, + ) { + } + fn record_task_retry(&self, _job_id: &str, _stage_id: usize) {} + + // Executor management - not tracked by default Prometheus collector + fn set_active_executor_count(&self, _count: usize) {} + fn record_executor_registered(&self, _executor_id: &str) {} + fn record_executor_deregistered(&self, _executor_id: &str) {} + + // Planning - not tracked by default Prometheus collector + fn record_planning_duration(&self, _job_id: &str, _duration_ms: u64) {} } diff --git a/ballista/scheduler/src/scheduler_process.rs b/ballista/scheduler/src/scheduler_process.rs index 361ade55fb..47b6af7f8d 100644 --- a/ballista/scheduler/src/scheduler_process.rs +++ b/ballista/scheduler/src/scheduler_process.rs @@ -65,7 +65,10 @@ pub async fn create_scheduler< .unwrap_or_else(|| Arc::new(BallistaPhysicalExtensionCodec::default())); let codec = BallistaCodec::new(codec_logical, codec_physical); - let metrics_collector = default_metrics_collector()?; + let metrics_collector = config + .override_metrics_collector + .clone() + .map_or_else(|| default_metrics_collector(), Ok)?; let mut scheduler_server = SchedulerServer::new( config.scheduler_name(), diff --git a/ballista/scheduler/src/scheduler_server/grpc.rs b/ballista/scheduler/src/scheduler_server/grpc.rs index 3579dbb0e1..1a922dc4f9 100644 --- a/ballista/scheduler/src/scheduler_server/grpc.rs +++ b/ballista/scheduler/src/scheduler_server/grpc.rs @@ -46,12 +46,12 @@ use { use std::ops::Deref; -use crate::cluster::{bind_task_bias, bind_task_round_robin}; +use crate::cluster::{BindingResult, bind_task_bias, bind_task_round_robin}; use crate::config::TaskDistributionPolicy; -use crate::scheduler_server::event::QueryStageSchedulerEvent; use crate::scheduler_server::SchedulerServer; -use ballista_core::remote_catalog::remote_function_serialize_ext::RemoteFunctionSerializeExt; +use crate::scheduler_server::event::QueryStageSchedulerEvent; use ballista_core::remote_catalog::catalog_serialize_ext::CatalogSerializeExt; +use ballista_core::remote_catalog::remote_function_serialize_ext::RemoteFunctionSerializeExt; use std::time::{SystemTime, UNIX_EPOCH}; use tonic::{Request, Response, Status}; @@ -118,7 +118,7 @@ impl SchedulerGrpc }]; let available_slots = available_slots.iter_mut().collect(); let running_jobs = self.state.task_manager.get_running_job_cache(); - let schedulable_tasks = match self.state.config.task_distribution { + let binding_result = match self.state.config.task_distribution { TaskDistributionPolicy::Bias => { bind_task_bias(available_slots, running_jobs, |_| false).await } @@ -131,15 +131,54 @@ impl SchedulerGrpc )); } - TaskDistributionPolicy::Custom(ref policy) => policy - .bind_tasks(available_slots, running_jobs) - .await - .map_err(|e| Status::internal(e.to_string()))?, + TaskDistributionPolicy::Custom(ref policy) => BindingResult::from_tasks( + policy + .bind_tasks(available_slots, running_jobs) + .await + .map_err(|e| Status::internal(e.to_string()))?, + ), }; + // Record shuffle affinity metrics + for affinity in &binding_result.shuffle_affinity { + if affinity.has_local_data { + self.state + .metrics_collector + .record_task_shuffle_affinity_hit( + &affinity.job_id, + affinity.stage_id, + &affinity.executor_id, + ); + } else { + self.state + .metrics_collector + .record_task_shuffle_affinity_miss( + &affinity.job_id, + affinity.stage_id, + &affinity.executor_id, + ); + } + } + let mut tasks = vec![]; - for (_, task) in schedulable_tasks { + let now_millis = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_millis(); + + for (_, task) in binding_result.bound_tasks { let job_id = task.partition.job_id.clone(); + let stage_id = task.partition.stage_id; + + // Record task scheduling metric with actual latency + let latency_ms = now_millis.saturating_sub(task.schedulable_time_millis); + self.state.metrics_collector.record_task_scheduled( + &job_id, + stage_id, + &executor_id, + latency_ms as u64, + ); + match self.state.task_manager.prepare_task_definition(task) { Ok(task_definition) => tasks.push(task_definition), Err(e) => { @@ -490,6 +529,7 @@ impl SchedulerGrpc ); let executor_manager = self.state.executor_manager.clone(); + let metrics_collector = self.state.metrics_collector.clone(); let event_sender = self.query_stage_event_loop.get_sender().map_err(|e| { let msg = format!("Get query stage event loop error due to {e:?}"); error!("{msg}"); @@ -499,6 +539,7 @@ impl SchedulerGrpc Self::remove_executor( executor_manager, event_sender, + metrics_collector, &executor_id, Some(reason), self.config.executor_termination_grace_period, @@ -665,6 +706,7 @@ mod test { SchedulerState::new_with_default_scheduler_name( cluster.clone(), BallistaCodec::default(), + default_metrics_collector().unwrap(), ); state.init().await?; @@ -697,6 +739,7 @@ mod test { SchedulerState::new_with_default_scheduler_name( cluster.clone(), BallistaCodec::default(), + default_metrics_collector().unwrap(), ); state.init().await?; diff --git a/ballista/scheduler/src/scheduler_server/mod.rs b/ballista/scheduler/src/scheduler_server/mod.rs index dd91bac2df..a6d8ac18c3 100644 --- a/ballista/scheduler/src/scheduler_server/mod.rs +++ b/ballista/scheduler/src/scheduler_server/mod.rs @@ -99,6 +99,7 @@ impl SchedulerServer SchedulerServer SchedulerServer SchedulerServer, + metrics_collector: Arc, executor_id: &str, reason: Option, wait_secs: u64, @@ -356,6 +360,13 @@ impl SchedulerServer SchedulerServer Result<()> { + let executor_id = metadata.id.clone(); let executor_data = ExecutorData { - executor_id: metadata.id.clone(), + executor_id: executor_id.clone(), total_task_slots: metadata.specification.task_slots, available_task_slots: metadata.specification.task_slots, }; @@ -378,6 +390,17 @@ impl SchedulerServer ); } } + + // Update queue size metrics after processing each event + let pending_jobs = self.state.task_manager.pending_job_number(); + self.metrics_collector + .set_pending_jobs_queue_size(pending_jobs as u64); + + let pending_tasks = self.state.task_manager.total_pending_tasks().await; + self.metrics_collector + .set_pending_tasks_queue_size(pending_tasks as u64); + Ok(()) } diff --git a/ballista/scheduler/src/state/execution_graph.rs b/ballista/scheduler/src/state/execution_graph.rs index 959272bc2c..f9526beb2f 100644 --- a/ballista/scheduler/src/state/execution_graph.rs +++ b/ballista/scheduler/src/state/execution_graph.rs @@ -50,6 +50,48 @@ pub(crate) use crate::state::execution_stage::{ }; use crate::state::task_manager::UpdatedStages; +/// Information about stage lifecycle changes during a task status update. +/// +/// This struct is returned from `update_task_status()` to allow the caller to +/// record metrics about stage transitions without coupling the execution graph +/// to the metrics system. +#[derive(Clone, Debug, Default)] +pub struct StageMetricsInfo { + /// Stages that started running (transitioned from Resolved to Running). + /// Contains (job_id, stage_id, task_count, started_at_ms). + pub stages_started: Vec<(String, usize, usize, u64)>, + /// Stages that completed successfully. + /// Contains (job_id, stage_id, duration_ms). + pub stages_completed: Vec<(String, usize, u64)>, + /// Stages that failed. + /// Contains (job_id, stage_id, error_type). + pub stages_failed: Vec<(String, usize, String)>, + /// Stages that are being retried. + /// Contains (job_id, stage_id). + pub stages_retried: Vec<(String, usize)>, + /// Tasks that completed successfully. + /// Contains (job_id, stage_id, executor_id). + pub tasks_completed: Vec<(String, usize, String)>, + /// Tasks that failed. + /// Contains (job_id, stage_id, executor_id, error_type). + pub tasks_failed: Vec<(String, usize, String, String)>, + /// Tasks that are being retried. + /// Contains (job_id, stage_id). + pub tasks_retried: Vec<(String, usize)>, +} + +/// Result from updating task statuses in an execution graph. +/// +/// Contains both scheduler events to be processed and metrics information +/// for recording observability data. +#[derive(Clone, Debug, Default)] +pub struct TaskStatusUpdateResult { + /// Events to be processed by the scheduler event loop. + pub events: Vec, + /// Metrics information for recording stage/task lifecycle events. + pub metrics_info: StageMetricsInfo, +} + /// Represents the DAG for a distributed query plan. /// /// A distributed query plan consists of a set of stages which must be executed sequentially. @@ -263,6 +305,15 @@ impl ExecutionGraph { /// Revive the execution graph by converting the resolved stages to running stages /// If any stages are converted, return true; else false. pub fn revive(&mut self) -> bool { + self.revive_with_metrics().0 + } + + /// Revive the execution graph by converting the resolved stages to running stages. + /// + /// Returns a tuple of (stages_converted, stages_started_info) where: + /// - `stages_converted`: true if any stages were converted + /// - `stages_started_info`: list of (stage_id, task_count, started_at_ms) for each started stage + pub fn revive_with_metrics(&mut self) -> (bool, Vec<(usize, usize, u64)>) { let running_stages = self .stages .values() @@ -276,28 +327,40 @@ impl ExecutionGraph { .collect::>(); if running_stages.is_empty() { - false + (false, vec![]) } else { + let mut stages_started = Vec::with_capacity(running_stages.len()); for running_stage in running_stages { + let stage_id = running_stage.stage_id; + let task_count = running_stage.partitions; + #[expect(clippy::cast_possible_truncation)] + let started_at_ms = running_stage.stage_running_time as u64; + + stages_started.push((stage_id, task_count, started_at_ms)); + self.stages.insert( running_stage.stage_id, ExecutionStage::Running(running_stage), ); } - true + (true, stages_started) } } /// Update task statuses and task metrics in the graph. /// This will also push shuffle partitions to their respective shuffle read stages. + /// + /// Returns a `TaskStatusUpdateResult` containing scheduler events to process + /// and metrics information for observability. pub fn update_task_status( &mut self, executor: &ExecutorMetadata, task_statuses: Vec, max_task_failures: usize, max_stage_failures: usize, - ) -> Result> { + ) -> Result { let job_id = self.job_id().to_owned(); + let mut metrics_info = StageMetricsInfo::default(); // First of all, classify the statuses by stages let mut job_task_statuses: HashMap> = HashMap::new(); for task_status in task_statuses { @@ -308,7 +371,15 @@ impl ExecutionGraph { // Revive before updating due to some updates not saved // It will be refined later - self.revive(); + let (_, stages_started) = self.revive_with_metrics(); + for (stage_id, task_count, started_at_ms) in stages_started { + metrics_info.stages_started.push(( + job_id.clone(), + stage_id, + task_count, + started_at_ms, + )); + } let current_running_stages: HashSet = HashSet::from_iter(self.running_stages()); @@ -372,6 +443,13 @@ impl ExecutionGraph { Some(FailedReason::FetchPartitionError( fetch_partiton_error, )) => { + // Record task failure metric + metrics_info.tasks_failed.push(( + job_id.clone(), + stage_id, + executor.id.clone(), + "fetch_partition_error".to_string(), + )); let failed_attempts = failed_stage_attempts .entry(stage_id) .or_default(); @@ -431,6 +509,13 @@ impl ExecutionGraph { } } Some(FailedReason::ExecutionError(_)) => { + // Record task failure metric + metrics_info.tasks_failed.push(( + job_id.clone(), + stage_id, + executor.id.clone(), + "execution_error".to_string(), + )); failed_stages.insert(stage_id, failed_task.error); } Some(_) => { @@ -440,6 +525,17 @@ impl ExecutionGraph { if running_stage.task_failure_number(partition_id) < max_task_failures { + // Record task retry metric + metrics_info + .tasks_retried + .push((job_id.clone(), stage_id)); + // Record task failure metric + metrics_info.tasks_failed.push(( + job_id.clone(), + stage_id, + executor.id.clone(), + "retryable_error".to_string(), + )); // TODO add new struct to track all the failed task infos // The failure TaskInfo is ignored and set to None here running_stage.reset_task_info(partition_id); @@ -452,9 +548,27 @@ impl ExecutionGraph { failed_task.error ); error!("{error_msg}"); + // Record task failure metric + metrics_info.tasks_failed.push(( + job_id.clone(), + stage_id, + executor.id.clone(), + "max_retries_exceeded".to_string(), + )); failed_stages.insert(stage_id, error_msg); } } else if failed_task.retryable { + // Record task retry metric + metrics_info + .tasks_retried + .push((job_id.clone(), stage_id)); + // Record task failure metric (but retryable) + metrics_info.tasks_failed.push(( + job_id.clone(), + stage_id, + executor.id.clone(), + "retryable_no_count".to_string(), + )); // TODO add new struct to track all the failed task infos // The failure TaskInfo is ignored and set to None here running_stage.reset_task_info(partition_id); @@ -465,6 +579,13 @@ impl ExecutionGraph { "Task {partition_id} in Stage {stage_id} failed with unknown failure reasons, fail the stage" ); error!("{error_msg}"); + // Record task failure metric + metrics_info.tasks_failed.push(( + job_id.clone(), + stage_id, + executor.id.clone(), + "unknown".to_string(), + )); failed_stages.insert(stage_id, error_msg); } } @@ -472,6 +593,12 @@ impl ExecutionGraph { successful_task, )) = task_status.status { + // Record task completion metric + metrics_info.tasks_completed.push(( + job_id.clone(), + stage_id, + executor.id.clone(), + )); // update task metrics for successfu task running_stage .update_task_metrics(partition_id, operator_metrics)?; @@ -673,7 +800,7 @@ impl ExecutionGraph { } } - self.processing_stages_update(UpdatedStages { + let (events, stage_metrics) = self.processing_stages_update(UpdatedStages { resolved_stages, successful_stages, failed_stages, @@ -682,29 +809,87 @@ impl ExecutionGraph { .keys() .cloned() .collect(), + })?; + + // Combine task metrics collected during processing with stage metrics + metrics_info + .stages_started + .extend(stage_metrics.stages_started); + metrics_info + .stages_completed + .extend(stage_metrics.stages_completed); + metrics_info + .stages_failed + .extend(stage_metrics.stages_failed); + metrics_info + .stages_retried + .extend(stage_metrics.stages_retried); + + Ok(TaskStatusUpdateResult { + events, + metrics_info, }) } /// Processing stage status update after task status changing + /// + /// Returns a tuple of (events, stage_metrics_info) containing scheduler events + /// and metrics information about stage lifecycle changes. fn processing_stages_update( &mut self, updated_stages: UpdatedStages, - ) -> Result> { + ) -> Result<(Vec, StageMetricsInfo)> { let job_id = self.job_id().to_owned(); let mut has_resolved = false; let mut job_err_msg = "".to_owned(); + let mut stage_metrics = StageMetricsInfo::default(); for stage_id in updated_stages.resolved_stages { self.resolve_stage(stage_id)?; has_resolved = true; } - for stage_id in updated_stages.successful_stages { + for stage_id in updated_stages.successful_stages.clone() { + // Get stage duration before transitioning + if let Some(ExecutionStage::Running(running_stage)) = + self.stages.get(&stage_id) + { + // Calculate duration from stage start time to now + let now = timestamp_millis(); + // Use the earliest task scheduled_time as an approximation for stage start + let stage_start = running_stage + .task_infos + .iter() + .filter_map(|info| info.as_ref().map(|t| t.scheduled_time as u64)) + .min() + .unwrap_or(now); + let duration_ms = now.saturating_sub(stage_start); + stage_metrics.stages_completed.push(( + job_id.clone(), + stage_id, + duration_ms, + )); + } self.succeed_stage(stage_id); } // Fail the stage and also abort the job for (stage_id, err_msg) in &updated_stages.failed_stages { + // Categorize error type for metrics + let error_type = if err_msg.contains("FetchPartitionError") { + "fetch_partition_error" + } else if err_msg.contains("ExecutionError") { + "execution_error" + } else if err_msg.contains("failed") && err_msg.contains("times") { + "max_retries_exceeded" + } else { + "unknown" + }; + stage_metrics.stages_failed.push(( + job_id.clone(), + *stage_id, + error_type.to_string(), + )); job_err_msg = format!("Job failed due to stage {stage_id} failed: {err_msg}\n"); } @@ -714,11 +899,19 @@ impl ExecutionGraph { if updated_stages.failed_stages.is_empty() { let mut running_tasks_to_cancel = vec![]; for (stage_id, failure_reasons) in updated_stages.rollback_running_stages { + // Record stage retry before rollback + stage_metrics + .stages_retried + .push((job_id.clone(), stage_id)); let tasks = self.rollback_running_stage(stage_id, failure_reasons)?; running_tasks_to_cancel.extend(tasks); } for stage_id in updated_stages.resubmit_successful_stages { + // Record stage retry for successful stages being rerun + stage_metrics + .stages_retried + .push((job_id.clone(), stage_id)); self.rerun_successful_stage(stage_id); } @@ -750,7 +943,7 @@ impl ExecutionGraph { } else if has_resolved { events.push(QueryStageSchedulerEvent::JobUpdated(job_id)) } - Ok(events) + Ok((events, stage_metrics)) } /// Return a Vec of resolvable stage ids @@ -942,7 +1135,8 @@ impl ExecutionGraph { task_id, task_attempt, plan: stage.plan.clone(), - session_config: self.session_config.clone() + session_config: self.session_config.clone(), + schedulable_time_millis: stage.stage_running_time, }) } else { Err(BallistaError::General(format!("Stage {stage_id} is not a running stage"))) @@ -1535,6 +1729,9 @@ pub struct TaskDescription { pub plan: Arc, /// Session configuration for this task's execution context. pub session_config: Arc, + /// Timestamp (millis since epoch) when this task became schedulable. + /// This is when the stage transitioned to running state. + pub schedulable_time_millis: u128, } impl Debug for TaskDescription { @@ -2028,7 +2225,7 @@ mod test { // This long delayed failed task should not failure the stage/job and should not trigger any query stage events let query_stage_events = agg_graph.update_task_status(&executor1, vec![task_status], 4, 4)?; - assert!(query_stage_events.is_empty()); + assert!(query_stage_events.events.is_empty()); drain_tasks(&mut agg_graph)?; assert!(agg_graph.is_successful(), "Failed to complete agg plan"); @@ -2082,9 +2279,9 @@ mod test { 4, )?; - assert_eq!(stage_events.len(), 1); + assert_eq!(stage_events.events.len(), 1); assert!(matches!( - stage_events[0], + stage_events.events[0], QueryStageSchedulerEvent::CancelTasks(_) )); @@ -2201,7 +2398,7 @@ mod test { if attempt < 3 { // No JobRunningFailed stage events - assert_eq!(stage_events.len(), 0); + assert_eq!(stage_events.events.len(), 0); // Stage 1 is running let running_stage = agg_graph.running_stages(); assert_eq!(running_stage.len(), 1); @@ -2209,9 +2406,9 @@ mod test { assert_eq!(agg_graph.available_tasks(), 2); } else { // Job is failed after exceeds the max_stage_failures - assert_eq!(stage_events.len(), 1); + assert_eq!(stage_events.events.len(), 1); assert!(matches!( - stage_events[0], + stage_events.events[0], QueryStageSchedulerEvent::JobRunningFailed { .. } )); // Stage 2 is still running @@ -2685,9 +2882,9 @@ mod test { 4, )?; - assert_eq!(stage_events.len(), 1); + assert_eq!(stage_events.events.len(), 1); assert!(matches!( - stage_events[0], + stage_events.events[0], QueryStageSchedulerEvent::JobRunningFailed { .. } )); diff --git a/ballista/scheduler/src/state/executor_manager.rs b/ballista/scheduler/src/state/executor_manager.rs index 059f885c56..ef368b3f43 100644 --- a/ballista/scheduler/src/state/executor_manager.rs +++ b/ballista/scheduler/src/state/executor_manager.rs @@ -22,12 +22,11 @@ use ballista_core::error::Result; use ballista_core::serde::protobuf; use log::trace; -use crate::cluster::{BoundTask, ClusterState, ExecutorSlot}; +use crate::cluster::{BindingResult, ClusterState, ExecutorSlot}; use crate::config::SchedulerConfig; use crate::state::execution_graph::RunningTaskInfo; use crate::state::task_manager::JobInfoCache; -use ballista_core::extension::SessionConfigExt; use ballista_core::serde::protobuf::executor_grpc_client::ExecutorGrpcClient; use ballista_core::serde::protobuf::{ CancelTasksParams, ExecutorHeartbeat, MultiTaskDefinition, RemoveJobDataParams, @@ -35,10 +34,7 @@ use ballista_core::serde::protobuf::{ }; use ballista_core::serde::scheduler::{ExecutorData, ExecutorMetadata}; -use ballista_core::utils::{ - GrpcClientConfig, create_grpc_client_connection, create_grpc_client_endpoint, get_time_before, -}; - +use ballista_core::utils::{create_grpc_client_endpoint, get_time_before}; use dashmap::DashMap; use log::{debug, error, info, warn}; @@ -66,8 +62,6 @@ pub struct ExecutorManager { clients: ExecutorClients, /// Jobs pending cleanup on each executor. pending_cleanup_jobs: Arc>>, - /// Configuration for gRPC client connections. - grpc_client_config: GrpcClientConfig, } impl ExecutorManager { @@ -76,20 +70,11 @@ impl ExecutorManager { cluster_state: Arc, config: Arc, ) -> Self { - let grpc_client_config = - if let Some(config_producer) = &config.override_config_producer { - let session_config = config_producer(); - let ballista_config = session_config.ballista_config(); - GrpcClientConfig::from(&ballista_config) - } else { - GrpcClientConfig::default() - }; Self { cluster_state, config, clients: Default::default(), pending_cleanup_jobs: Default::default(), - grpc_client_config, } } @@ -102,19 +87,19 @@ impl ExecutorManager { /// Binds ready-to-run tasks from active jobs to available executor slots. /// - /// Returns a list of bound tasks that can be launched on executors. + /// Returns a binding result containing bound tasks and shuffle affinity info. pub async fn bind_schedulable_tasks( &self, running_jobs: Arc>, - ) -> Result> { + ) -> Result { if running_jobs.is_empty() { debug!("There's no active jobs for binding tasks"); - return Ok(vec![]); + return Ok(BindingResult::new()); } let alive_executors = self.get_alive_executors(); if alive_executors.is_empty() { debug!("There's no alive executors for binding tasks"); - return Ok(vec![]); + return Ok(BindingResult::new()); } self.cluster_state .bind_schedulable_tasks( @@ -150,10 +135,7 @@ impl ExecutorManager { let executor_manager = self.clone(); tokio::spawn(async move { for (executor_id, infos) in tasks_to_cancel { - if let Ok(mut client) = executor_manager - .get_client(&executor_id, &executor_manager.grpc_client_config) - .await - { + if let Ok(mut client) = executor_manager.get_client(&executor_id).await { if let Err(e) = client .cancel_tasks(CancelTasksParams { task_infos: infos }) .await @@ -210,9 +192,7 @@ impl ExecutorManager { let job_id_clone = job_id.to_owned(); if self.config.is_push_staged_scheduling() { - if let Ok(mut client) = - self.get_client(&executor, &self.grpc_client_config).await - { + if let Ok(mut client) = self.get_client(&executor).await { tokio::spawn(async move { if let Err(err) = client .remove_job_data(RemoveJobDataParams { @@ -311,10 +291,7 @@ impl ExecutorManager { /// Sends a stop request to the specified executor. pub async fn stop_executor(&self, executor_id: &str, stop_reason: String) { let executor_id = executor_id.to_string(); - match self - .get_client(&executor_id, &self.grpc_client_config) - .await - { + match self.get_client(&executor_id).await { Ok(mut client) => { tokio::task::spawn(async move { match client @@ -347,9 +324,7 @@ impl ExecutorManager { multi_tasks: Vec, scheduler_id: String, ) -> Result<()> { - let mut client = self - .get_client(executor_id, &self.grpc_client_config) - .await?; + let mut client = self.get_client(executor_id).await?; client .launch_multi_task(protobuf::LaunchMultiTaskParams { multi_tasks, @@ -461,11 +436,7 @@ impl ExecutorManager { .collect::>() } - async fn get_client( - &self, - executor_id: &str, - grpc_client_config: &GrpcClientConfig, - ) -> Result> { + async fn get_client(&self, executor_id: &str) -> Result> { let client = self.clients.get(executor_id).map(|value| value.clone()); if let Some(client) = client { diff --git a/ballista/scheduler/src/state/mod.rs b/ballista/scheduler/src/state/mod.rs index 7315890262..aef8df736e 100644 --- a/ballista/scheduler/src/state/mod.rs +++ b/ballista/scheduler/src/state/mod.rs @@ -23,7 +23,7 @@ use datafusion::physical_plan::{ExecutionPlan, ExecutionPlanProperties}; use std::any::type_name; use std::collections::HashMap; use std::sync::Arc; -use std::time::Instant; +use std::time::{Instant, SystemTime, UNIX_EPOCH}; use crate::scheduler_server::event::QueryStageSchedulerEvent; @@ -37,6 +37,7 @@ use crate::state::task_manager::{TaskLauncher, TaskManager}; use crate::cluster::{BallistaCluster, BoundTask, ExecutorSlot}; use crate::config::SchedulerConfig; +use crate::metrics::SchedulerMetricsCollector; use crate::state::execution_graph::TaskDescription; use ballista_core::error::{BallistaError, Result}; use ballista_core::event_loop::EventSender; @@ -117,6 +118,8 @@ pub struct SchedulerState, /// Scheduler configuration. pub config: Arc, + /// Metrics collector for recording scheduler metrics. + pub metrics_collector: Arc, } impl SchedulerState { @@ -126,6 +129,7 @@ impl SchedulerState, scheduler_name: String, config: Arc, + metrics_collector: Arc, ) -> Self { Self { executor_manager: ExecutorManager::new( @@ -140,6 +144,7 @@ impl SchedulerState SchedulerState, + metrics_collector: Arc, ) -> Self { let config = Arc::new(SchedulerConfig::default()); - SchedulerState::new(cluster, codec, "localhost:50050".to_owned(), config) + SchedulerState::new( + cluster, + codec, + "localhost:50050".to_owned(), + config, + metrics_collector, + ) } #[allow(dead_code)] @@ -159,6 +171,7 @@ impl SchedulerState, scheduler_name: String, config: Arc, + metrics_collector: Arc, dispatcher: Arc, ) -> Self { Self { @@ -175,6 +188,7 @@ impl SchedulerState SchedulerState, ) -> Result<()> { - let schedulable_tasks = self + let binding_result = self .executor_manager .bind_schedulable_tasks(self.task_manager.get_running_job_cache()) .await?; - if schedulable_tasks.is_empty() { + if binding_result.bound_tasks.is_empty() { debug!("No schedulable tasks found to be launched"); return Ok(()); } + // Record shuffle affinity metrics + for affinity in &binding_result.shuffle_affinity { + if affinity.has_local_data { + self.metrics_collector.record_task_shuffle_affinity_hit( + &affinity.job_id, + affinity.stage_id, + &affinity.executor_id, + ); + } else { + self.metrics_collector.record_task_shuffle_affinity_miss( + &affinity.job_id, + affinity.stage_id, + &affinity.executor_id, + ); + } + } + + let schedulable_tasks = binding_result.bound_tasks; let state = self.clone(); tokio::spawn(async move { let mut if_revive = false; @@ -270,6 +302,24 @@ impl SchedulerState, ) -> Result> { + // Get current time once for all latency calculations + let now_millis = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_millis(); + + // Record task scheduling metrics for each task + for (executor_id, task) in &bound_tasks { + // Calculate scheduling latency: time from when task became schedulable to now + let latency_ms = now_millis.saturating_sub(task.schedulable_time_millis); + self.metrics_collector.record_task_scheduled( + &task.partition.job_id, + task.partition.stage_id, + executor_id, + latency_ms as u64, + ); + } + // Put tasks to the same executor together // And put tasks belonging to the same stage together for creating MultiTaskDefinition let mut executor_stage_assignments: HashMap< @@ -366,9 +416,53 @@ impl SchedulerState SchedulerState TaskManager self.active_job_cache.len() } + /// Get the total number of pending tasks across all active jobs. + /// + /// A pending task is a task that is available to schedule on an executor + /// but cannot be scheduled because no resources are available. + pub async fn total_pending_tasks(&self) -> usize { + let mut total = 0; + for entry in self.active_job_cache.iter() { + let graph = entry.value().execution_graph.read().await; + total += graph.available_tasks(); + } + total + } + /// Generate an ExecutionGraph for the job and save it to the persistent state. /// By default, this job will be curated by the scheduler which receives it. /// Then we will also save it to the active execution graph @@ -362,14 +375,15 @@ impl TaskManager } } - /// Update given task statuses in the respective job and return a tuple containing: - /// 1. A list of QueryStageSchedulerEvent to publish. - /// 2. A list of reservations that can now be offered. + /// Update given task statuses in the respective job and return a `TaskStatusUpdateResult` + /// containing: + /// 1. A list of `QueryStageSchedulerEvent` to publish. + /// 2. Metrics information about stage/task lifecycle changes. pub(crate) async fn update_task_statuses( &self, executor: &ExecutorMetadata, task_status: Vec, - ) -> Result> { + ) -> Result { let mut job_updates: HashMap> = HashMap::new(); for status in task_status { trace!("Task Update\n{status:?}"); @@ -378,13 +392,12 @@ impl TaskManager job_task_statuses.push(status); } - let mut events: Vec = vec![]; + let mut combined_result = TaskStatusUpdateResult::default(); for (job_id, statuses) in job_updates { let num_tasks = statuses.len(); debug!("Updating {num_tasks} tasks in job {job_id}"); - // let graph = self.get_active_execution_graph(&job_id).await; - let job_events = if let Some(cached) = + let job_result = if let Some(cached) = self.get_active_execution_graph(&job_id) { let mut graph = cached.write().await; @@ -399,15 +412,42 @@ impl TaskManager error!( "Fail to find job {job_id} in the active cache and it may not be curated by this scheduler" ); - vec![] + TaskStatusUpdateResult::default() }; - for event in job_events { - events.push(event); - } + // Combine events and metrics from all jobs + combined_result.events.extend(job_result.events); + combined_result + .metrics_info + .stages_started + .extend(job_result.metrics_info.stages_started); + combined_result + .metrics_info + .stages_completed + .extend(job_result.metrics_info.stages_completed); + combined_result + .metrics_info + .stages_failed + .extend(job_result.metrics_info.stages_failed); + combined_result + .metrics_info + .stages_retried + .extend(job_result.metrics_info.stages_retried); + combined_result + .metrics_info + .tasks_completed + .extend(job_result.metrics_info.tasks_completed); + combined_result + .metrics_info + .tasks_failed + .extend(job_result.metrics_info.tasks_failed); + combined_result + .metrics_info + .tasks_retried + .extend(job_result.metrics_info.tasks_retried); } - Ok(events) + Ok(combined_result) } /// Mark a job to success. This will create a key under the CompletedJobs keyspace diff --git a/ballista/scheduler/src/test_utils.rs b/ballista/scheduler/src/test_utils.rs index b4a9400202..25ad5356e1 100644 --- a/ballista/scheduler/src/test_utils.rs +++ b/ballista/scheduler/src/test_utils.rs @@ -755,6 +755,57 @@ impl SchedulerMetricsCollector for TestMetricsCollector { } fn set_pending_tasks_queue_size(&self, _value: u64) {} + fn set_pending_jobs_queue_size(&self, _value: u64) {} + + // Stage lifecycle + fn record_stage_started(&self, _job_id: &str, _stage_id: usize, _task_count: usize) {} + fn record_stage_completed(&self, _job_id: &str, _stage_id: usize, _duration_ms: u64) { + } + fn record_stage_failed(&self, _job_id: &str, _stage_id: usize, _error_type: &str) {} + fn record_stage_retry(&self, _job_id: &str, _stage_id: usize) {} + + // Task scheduling + fn record_task_scheduled( + &self, + _job_id: &str, + _stage_id: usize, + _executor_id: &str, + _latency_ms: u64, + ) { + } + fn record_task_completed(&self, _job_id: &str, _stage_id: usize, _executor_id: &str) { + } + fn record_task_failed( + &self, + _job_id: &str, + _stage_id: usize, + _executor_id: &str, + _error_type: &str, + ) { + } + fn record_task_retry(&self, _job_id: &str, _stage_id: usize) {} + fn record_task_shuffle_affinity_hit( + &self, + _job_id: &str, + _stage_id: usize, + _executor_id: &str, + ) { + } + fn record_task_shuffle_affinity_miss( + &self, + _job_id: &str, + _stage_id: usize, + _executor_id: &str, + ) { + } + + // Executor management + fn set_active_executor_count(&self, _count: usize) {} + fn record_executor_registered(&self, _executor_id: &str) {} + fn record_executor_deregistered(&self, _executor_id: &str) {} + + // Planning + fn record_planning_duration(&self, _job_id: &str, _duration_ms: u64) {} fn gather_metrics(&self) -> Result, String)>> { Ok(None) From f161028f06bc0d7ac329a48f8f4a17826b1f4c5e Mon Sep 17 00:00:00 2001 From: Luke Kim <80174+lukekim@users.noreply.github.com> Date: Thu, 22 Jan 2026 18:05:38 -0800 Subject: [PATCH 3/6] feat: add Vortex columnar format support for shuffle operations (#7) * feat: add Vortex columnar format support for shuffle operations - Introduced Vortex dependencies in Cargo.toml for columnar format handling. - Updated Ballista configuration to support shuffle format selection between Arrow IPC and Vortex. - Implemented Vortex shuffle reader and writer in execution plans. - Enhanced shuffle operations to detect and handle Vortex files. - Added utility functions for writing streams to disk in both Arrow IPC and Vortex formats. - Created a new module for Vortex shuffle operations, including reading and writing logic. - Added tests for Vortex write and read roundtrip functionality. * Fix Clippy and lint * Fix reading of Vortex files --- Cargo.lock | 422 ++++++++++++++++++ Cargo.toml | 7 + ballista/core/Cargo.toml | 9 + ballista/core/src/config.rs | 56 ++- ballista/core/src/execution_plans/mod.rs | 9 + .../src/execution_plans/shuffle_reader.rs | 74 ++- .../src/execution_plans/shuffle_writer.rs | 143 ++++-- .../src/execution_plans/vortex_shuffle.rs | 350 +++++++++++++++ ballista/core/src/extension.rs | 32 +- ballista/core/src/lib.rs | 6 +- ballista/core/src/utils.rs | 35 +- ballista/executor/Cargo.toml | 5 + ballista/executor/src/flight_service.rs | 200 ++++++++- ballista/scheduler/Cargo.toml | 4 + ballista/scheduler/build.rs | 2 +- benchmarks/src/bin/tpch.rs | 2 +- 16 files changed, 1288 insertions(+), 68 deletions(-) create mode 100644 ballista/core/src/execution_plans/vortex_shuffle.rs diff --git a/Cargo.lock b/Cargo.lock index ef14ba781e..e4886c1390 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -155,6 +155,12 @@ dependencies = [ "object", ] +[[package]] +name = "arcref" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28f6098a1e8ab66ff91324cce8fea6643101882cf7d09c85acdb1485ecf61e29" + [[package]] name = "arrayref" version = "0.3.9" @@ -383,6 +389,7 @@ version = "57.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6bb63203e8e0e54b288d0d8043ca8fa1013820822a27692ef1b78a977d879f2c" dependencies = [ + "bitflags 2.10.0", "serde_core", "serde_json", ] @@ -1018,6 +1025,11 @@ dependencies = [ "tonic-prost-build", "url", "uuid", + "vortex-array", + "vortex-buffer", + "vortex-dtype", + "vortex-error", + "vortex-ipc", ] [[package]] @@ -1066,6 +1078,8 @@ dependencies = [ "tracing-appender", "tracing-subscriber", "uuid", + "vortex-array", + "vortex-ipc", ] [[package]] @@ -1097,6 +1111,8 @@ dependencies = [ "tokio", "tokio-stream", "tonic", + "tonic-prost", + "tonic-prost-build", "tracing", "tracing-appender", "tracing-subscriber", @@ -1151,6 +1167,18 @@ version = "2.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "812e12b5285cc515a9c72a5c1d3b6d46a19dac5acfef5265968c166106e31dd3" +[[package]] +name = "bitvec" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bc2832c24239b0141d5674bb9174f9d68a8b5b3f2753311927c172ca46f7e9c" +dependencies = [ + "funty", + "radium", + "tap", + "wyz", +] + [[package]] name = "blake2" version = "0.10.6" @@ -1621,6 +1649,16 @@ version = "0.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e2931af7e13dc045d8e9d26afccc6fa115d64e115c9c84b1166288b46f6782c2" +[[package]] +name = "cudarc" +version = "0.18.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3aa12038120eb13347a6ae2ffab1d34efe78150125108627fd85044dd4d6ff1e" +dependencies = [ + "half", + "libloading", +] + [[package]] name = "darling" version = "0.21.3" @@ -2581,6 +2619,26 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c34f04666d835ff5d62e058c3995147c06f42fe86ff053337632bca83e42702d" +[[package]] +name = "enum-iterator" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4549325971814bda7a44061bf3fe7e487d447cba01e4220a4b454d630d7a016" +dependencies = [ + "enum-iterator-derive", +] + +[[package]] +name = "enum-iterator-derive" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "685adfa4d6f3d765a26bc5dbc936577de9abf756c1feeb3089b01dd395034842" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.114", +] + [[package]] name = "env_filter" version = "0.1.4" @@ -2732,6 +2790,12 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" +[[package]] +name = "funty" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" + [[package]] name = "futures" version = "0.3.31" @@ -2914,6 +2978,8 @@ dependencies = [ "cfg-if", "crunchy", "num-traits", + "rand 0.9.2", + "rand_distr", "zerocopy", ] @@ -3059,6 +3125,15 @@ version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" +[[package]] +name = "humansize" +version = "2.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6cb51c9a029ddc91b07a787f1d86b53ccfa49b0e86688c946ebe8d3555685dd7" +dependencies = [ + "libm", +] + [[package]] name = "humantime" version = "2.3.0" @@ -3364,6 +3439,15 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "inventory" +version = "0.3.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc61209c082fbeb19919bee74b176221b27223e27b65d781eb91af24eb1fb46e" +dependencies = [ + "rustversion", +] + [[package]] name = "ipnet" version = "2.11.0" @@ -3408,10 +3492,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e67e8da4c49d6d9909fe03361f9b620f58898859f5c7aded68351e85e71ecf50" dependencies = [ "jiff-static", + "jiff-tzdb-platform", "log", "portable-atomic", "portable-atomic-util", "serde_core", + "windows-sys 0.52.0", ] [[package]] @@ -3425,6 +3511,21 @@ dependencies = [ "syn 2.0.114", ] +[[package]] +name = "jiff-tzdb" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68971ebff725b9e2ca27a601c5eb38a4c5d64422c4cbab0c535f248087eda5c2" + +[[package]] +name = "jiff-tzdb-platform" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "875a5a69ac2bab1a891711cf5eccbec1ce0341ea805560dcd90b7a2e925132e8" +dependencies = [ + "jiff-tzdb", +] + [[package]] name = "jobserver" version = "0.1.34" @@ -3520,6 +3621,16 @@ version = "0.2.180" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bcc35a38544a891a5f7c865aca548a982ccb3b8650a5b06d0fd33a10283c56fc" +[[package]] +name = "libloading" +version = "0.8.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7c4b02199fee7c5d21a5ae7d8cfa79a6ef5bb2fc834d6e9058e89c825efdc55" +dependencies = [ + "cfg-if", + "windows-link", +] + [[package]] name = "libm" version = "0.2.15" @@ -3679,6 +3790,28 @@ version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1d87ecb2933e8aeadb3e3a02b828fed80a7528047e68b4f424523a0981a3a084" +[[package]] +name = "multiversion" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7edb7f0ff51249dfda9ab96b5823695e15a052dc15074c9dbf3d118afaf2c201" +dependencies = [ + "multiversion-macros", + "target-features", +] + +[[package]] +name = "multiversion-macros" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b093064383341eb3271f42e381cb8f10a01459478446953953c75d24bd339fc0" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.114", + "target-features", +] + [[package]] name = "nibble_vec" version = "0.1.0" @@ -3790,6 +3923,27 @@ dependencies = [ "libm", ] +[[package]] +name = "num_enum" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1207a7e20ad57b847bbddc6776b968420d38292bbfe2089accff5e19e82454c" +dependencies = [ + "num_enum_derive", + "rustversion", +] + +[[package]] +name = "num_enum_derive" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff32365de1b6743cb203b710788263c44a03de03802daf96092f2da4fe6ba4d7" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.114", +] + [[package]] name = "object" version = "0.32.2" @@ -4441,6 +4595,12 @@ version = "5.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" +[[package]] +name = "radium" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09" + [[package]] name = "radix_trie" version = "0.2.1" @@ -4510,6 +4670,16 @@ dependencies = [ "getrandom 0.3.4", ] +[[package]] +name = "rand_distr" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a8615d50dcf34fa31f7ab52692afec947c4dd0ab803cc87cb3b0b4570ff7463" +dependencies = [ + "num-traits", + "rand 0.9.2", +] + [[package]] name = "recursive" version = "0.1.1" @@ -5255,6 +5425,12 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "static_assertions" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" + [[package]] name = "strsim" version = "0.11.1" @@ -5399,6 +5575,18 @@ dependencies = [ "syn 2.0.114", ] +[[package]] +name = "tap" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" + +[[package]] +name = "target-features" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1bbb9f3c5c463a01705937a24fdabc5047929ac764b2d5b9cf681c1f5041ed5" + [[package]] name = "tempfile" version = "3.24.0" @@ -5412,6 +5600,12 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "termtree" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f50febec83f5ee1df3015341d8bd429f2d1cc62bcba7ea2076759d315084683" + [[package]] name = "testcontainers" version = "0.25.2" @@ -6079,6 +6273,225 @@ version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" +[[package]] +name = "vortex-array" +version = "0.1.0" +source = "git+https://github.com/spiceai/vortex?branch=spiceai-51#623936e22bdc1daad586af4a54053240f2009d5c" +dependencies = [ + "arcref", + "arrow-arith", + "arrow-array", + "arrow-buffer", + "arrow-cast", + "arrow-data", + "arrow-ord", + "arrow-schema", + "arrow-select", + "arrow-string", + "cfg-if", + "enum-iterator", + "flatbuffers", + "futures", + "getrandom 0.3.4", + "humansize", + "inventory", + "itertools", + "multiversion", + "num-traits", + "num_enum", + "parking_lot", + "paste", + "pin-project-lite", + "prost", + "rand 0.9.2", + "rustc-hash", + "simdutf8", + "termtree", + "tracing", + "vortex-buffer", + "vortex-compute", + "vortex-dtype", + "vortex-error", + "vortex-flatbuffers", + "vortex-mask", + "vortex-proto", + "vortex-scalar", + "vortex-session", + "vortex-utils", + "vortex-vector", +] + +[[package]] +name = "vortex-buffer" +version = "0.1.0" +source = "git+https://github.com/spiceai/vortex?branch=spiceai-51#623936e22bdc1daad586af4a54053240f2009d5c" +dependencies = [ + "arrow-buffer", + "bitvec", + "bytes", + "cudarc", + "itertools", + "simdutf8", + "vortex-error", +] + +[[package]] +name = "vortex-compute" +version = "0.1.0" +source = "git+https://github.com/spiceai/vortex?branch=spiceai-51#623936e22bdc1daad586af4a54053240f2009d5c" +dependencies = [ + "arrow-array", + "arrow-buffer", + "arrow-schema", + "half", + "itertools", + "multiversion", + "num-traits", + "paste", + "tracing", + "vortex-buffer", + "vortex-dtype", + "vortex-error", + "vortex-mask", + "vortex-vector", +] + +[[package]] +name = "vortex-dtype" +version = "0.1.0" +source = "git+https://github.com/spiceai/vortex?branch=spiceai-51#623936e22bdc1daad586af4a54053240f2009d5c" +dependencies = [ + "arrow-buffer", + "arrow-schema", + "flatbuffers", + "half", + "itertools", + "jiff", + "num-traits", + "num_enum", + "paste", + "prost", + "serde", + "static_assertions", + "vortex-buffer", + "vortex-error", + "vortex-flatbuffers", + "vortex-proto", + "vortex-utils", +] + +[[package]] +name = "vortex-error" +version = "0.1.0" +source = "git+https://github.com/spiceai/vortex?branch=spiceai-51#623936e22bdc1daad586af4a54053240f2009d5c" +dependencies = [ + "arrow-schema", + "flatbuffers", + "jiff", + "prost", + "url", +] + +[[package]] +name = "vortex-flatbuffers" +version = "0.1.0" +source = "git+https://github.com/spiceai/vortex?branch=spiceai-51#623936e22bdc1daad586af4a54053240f2009d5c" +dependencies = [ + "flatbuffers", + "vortex-buffer", +] + +[[package]] +name = "vortex-ipc" +version = "0.1.0" +source = "git+https://github.com/spiceai/vortex?branch=spiceai-51#623936e22bdc1daad586af4a54053240f2009d5c" +dependencies = [ + "bytes", + "flatbuffers", + "futures", + "itertools", + "pin-project-lite", + "vortex-array", + "vortex-buffer", + "vortex-dtype", + "vortex-error", + "vortex-flatbuffers", +] + +[[package]] +name = "vortex-mask" +version = "0.1.0" +source = "git+https://github.com/spiceai/vortex?branch=spiceai-51#623936e22bdc1daad586af4a54053240f2009d5c" +dependencies = [ + "arrow-buffer", + "itertools", + "vortex-buffer", + "vortex-error", +] + +[[package]] +name = "vortex-proto" +version = "0.1.0" +source = "git+https://github.com/spiceai/vortex?branch=spiceai-51#623936e22bdc1daad586af4a54053240f2009d5c" +dependencies = [ + "prost", + "prost-types", +] + +[[package]] +name = "vortex-scalar" +version = "0.1.0" +source = "git+https://github.com/spiceai/vortex?branch=spiceai-51#623936e22bdc1daad586af4a54053240f2009d5c" +dependencies = [ + "arrow-array", + "bytes", + "itertools", + "num-traits", + "paste", + "prost", + "vortex-buffer", + "vortex-dtype", + "vortex-error", + "vortex-mask", + "vortex-proto", + "vortex-utils", + "vortex-vector", +] + +[[package]] +name = "vortex-session" +version = "0.1.0" +source = "git+https://github.com/spiceai/vortex?branch=spiceai-51#623936e22bdc1daad586af4a54053240f2009d5c" +dependencies = [ + "dashmap", + "vortex-error", + "vortex-utils", +] + +[[package]] +name = "vortex-utils" +version = "0.1.0" +source = "git+https://github.com/spiceai/vortex?branch=spiceai-51#623936e22bdc1daad586af4a54053240f2009d5c" +dependencies = [ + "dashmap", + "hashbrown 0.16.1", + "vortex-error", +] + +[[package]] +name = "vortex-vector" +version = "0.1.0" +source = "git+https://github.com/spiceai/vortex?branch=spiceai-51#623936e22bdc1daad586af4a54053240f2009d5c" +dependencies = [ + "num-traits", + "paste", + "static_assertions", + "vortex-buffer", + "vortex-dtype", + "vortex-error", + "vortex-mask", +] + [[package]] name = "vsimd" version = "0.8.0" @@ -6496,6 +6909,15 @@ version = "0.6.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9edde0db4769d2dc68579893f2306b26c6ecfbe0ef499b013d731b7b9247e0b9" +[[package]] +name = "wyz" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05f360fc0b24296329c78fda852a1e9ae82de9cf7b27dae4b7f62f118f77b9ed" +dependencies = [ + "tap", +] + [[package]] name = "xattr" version = "1.6.1" diff --git a/Cargo.toml b/Cargo.toml index 5693cabcb7..72b8fcfbdb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -72,6 +72,13 @@ tokio-stream = { version = "0.1" } backoff = { version = "0.4" } url = { version = "2.5" } +# Vortex columnar format dependencies from spiceai fork +vortex-array = { git = "https://github.com/spiceai/vortex", branch = "spiceai-51", default-features = false } +vortex-buffer = { git = "https://github.com/spiceai/vortex", branch = "spiceai-51", default-features = false } +vortex-dtype = { git = "https://github.com/spiceai/vortex", branch = "spiceai-51", default-features = false } +vortex-error = { git = "https://github.com/spiceai/vortex", branch = "spiceai-51", default-features = false } +vortex-ipc = { git = "https://github.com/spiceai/vortex", branch = "spiceai-51", default-features = false } + # cargo build --profile release-lto [profile.release-lto] codegen-units = 1 diff --git a/ballista/core/Cargo.toml b/ballista/core/Cargo.toml index 5eec1c19e0..bb34be2bdc 100644 --- a/ballista/core/Cargo.toml +++ b/ballista/core/Cargo.toml @@ -39,6 +39,8 @@ build-binary = ["aws-config", "aws-credential-types", "clap", "object_store"] docsrs = [] # Used for testing ONLY: causes all values to hash to the same value (test for collisions) force_hash_collisions = ["datafusion/force_hash_collisions"] +# Enable Vortex columnar format support for shuffles +vortex = ["vortex-array", "vortex-buffer", "vortex-dtype", "vortex-error", "vortex-ipc"] [dependencies] arrow-flight = { workspace = true } @@ -68,6 +70,13 @@ tonic-prost = { workspace = true } url = { workspace = true } uuid = { workspace = true } +# Vortex columnar format dependencies (optional) +vortex-array = { workspace = true, optional = true } +vortex-buffer = { workspace = true, optional = true } +vortex-dtype = { workspace = true, optional = true } +vortex-error = { workspace = true, optional = true } +vortex-ipc = { workspace = true, optional = true } + [dev-dependencies] tempfile = { workspace = true } diff --git a/ballista/core/src/config.rs b/ballista/core/src/config.rs index e45fe1afa1..306a1083a5 100644 --- a/ballista/core/src/config.rs +++ b/ballista/core/src/config.rs @@ -19,6 +19,7 @@ //! Ballista configuration use std::result; +use std::str::FromStr; use std::{collections::HashMap, fmt::Display}; use crate::error::{BallistaError, Result}; @@ -47,6 +48,8 @@ pub const BALLISTA_SHUFFLE_READER_REMOTE_PREFER_FLIGHT: &str = pub const BALLISTA_SHUFFLE_STORAGE_TYPE: &str = "ballista.shuffle.storage_type"; /// Configuration key for shuffle storage base URL/path. pub const BALLISTA_SHUFFLE_STORAGE_URL: &str = "ballista.shuffle.storage_url"; +/// Shuffle format configuration: "arrow_ipc" or "vortex" +pub const BALLISTA_SHUFFLE_FORMAT: &str = "ballista.shuffle.format"; /// Configuration key for gRPC client connection timeout in seconds. pub const BALLISTA_GRPC_CLIENT_CONNECT_TIMEOUT_SECONDS: &str = @@ -112,7 +115,11 @@ static CONFIG_ENTRIES: LazyLock> = LazyLock::new(|| ConfigEntry::new(BALLISTA_GRPC_CLIENT_HTTP2_KEEPALIVE_INTERVAL_SECONDS.to_string(), "HTTP/2 keep-alive interval for gRPC client in seconds".to_string(), DataType::UInt64, - Some((300).to_string())) + Some((300).to_string())), + ConfigEntry::new(BALLISTA_SHUFFLE_FORMAT.to_string(), + "Shuffle data format: 'arrow_ipc' (default) or 'vortex'. Vortex requires the 'vortex' feature to be enabled.".to_string(), + DataType::Utf8, + Some(ShuffleFormat::default().to_string())) ]; entries .into_iter() @@ -120,6 +127,42 @@ static CONFIG_ENTRIES: LazyLock> = LazyLock::new(|| .collect::>() }); +/// Shuffle data format for intermediate shuffle files +#[derive( + Clone, Copy, Debug, PartialEq, Eq, Default, serde::Deserialize, serde::Serialize, +)] +pub enum ShuffleFormat { + /// Arrow IPC format (default, always available) + #[default] + ArrowIpc, + /// Vortex columnar format (requires 'vortex' feature) + Vortex, +} + +impl Display for ShuffleFormat { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ShuffleFormat::ArrowIpc => f.write_str("arrow_ipc"), + ShuffleFormat::Vortex => f.write_str("vortex"), + } + } +} + +impl FromStr for ShuffleFormat { + type Err = String; + + fn from_str(s: &str) -> std::result::Result { + match s.to_lowercase().as_str() { + "arrow_ipc" | "arrow-ipc" | "arrowips" | "ipc" => Ok(ShuffleFormat::ArrowIpc), + "vortex" => Ok(ShuffleFormat::Vortex), + _ => Err(format!( + "Invalid shuffle format '{}'. Valid options are: 'arrow_ipc', 'vortex'", + s + )), + } + } +} + /// Configuration option meta-data #[derive(Debug, Clone)] pub struct ConfigEntry { @@ -286,6 +329,17 @@ impl BallistaConfig { self.settings.get(BALLISTA_SHUFFLE_STORAGE_URL).cloned() } + /// Returns the configured shuffle format (ArrowIpc or Vortex) + /// + /// Note: Vortex format requires the 'vortex' feature to be enabled. + /// If Vortex is configured but the feature is not enabled, this will + /// still return Vortex, but the shuffle operations will fail at runtime. + pub fn shuffle_format(&self) -> ShuffleFormat { + self.get_string_setting(BALLISTA_SHUFFLE_FORMAT) + .parse() + .unwrap_or_default() + } + fn get_usize_setting(&self, key: &str) -> usize { if let Some(v) = self.settings.get(key) { // infallible because we validate all configs in the constructor diff --git a/ballista/core/src/execution_plans/mod.rs b/ballista/core/src/execution_plans/mod.rs index 7a5e105c6c..b431c0a1f6 100644 --- a/ballista/core/src/execution_plans/mod.rs +++ b/ballista/core/src/execution_plans/mod.rs @@ -23,7 +23,16 @@ mod shuffle_reader; mod shuffle_writer; mod unresolved_shuffle; +#[cfg(feature = "vortex")] +pub mod vortex_shuffle; + pub use distributed_query::DistributedQueryExec; pub use shuffle_reader::ShuffleReaderExec; pub use shuffle_writer::ShuffleWriterExec; pub use unresolved_shuffle::UnresolvedShuffleExec; + +#[cfg(feature = "vortex")] +pub use vortex_shuffle::{ + LocalVortexShuffleStream, VortexWriteTracker, vortex_file_extension, + write_stream_to_disk_vortex, +}; diff --git a/ballista/core/src/execution_plans/shuffle_reader.rs b/ballista/core/src/execution_plans/shuffle_reader.rs index 532d4f7dbd..6b49184aa7 100644 --- a/ballista/core/src/execution_plans/shuffle_reader.rs +++ b/ballista/core/src/execution_plans/shuffle_reader.rs @@ -666,19 +666,48 @@ async fn fetch_partition_local( let metadata = &location.executor_meta; let partition_id = &location.partition_id; - let reader = fetch_partition_local_inner(path).map_err(|e| { - // return BallistaError::FetchFailed may let scheduler retry this task. - BallistaError::FetchFailed( - metadata.id.clone(), - partition_id.stage_id, - partition_id.partition_id, - e.to_string(), - ) - })?; - Ok(Box::pin(LocalShuffleStream::new(reader))) + // Detect format from file extension + let is_vortex = path.ends_with(".vortex"); + + if is_vortex { + #[cfg(feature = "vortex")] + { + // For Vortex files, we need the schema. Get it from partition stats or infer. + // For now, we'll create a stream that reads the vortex file. + // Note: Vortex IPC format is self-describing, so we can read the schema from the file. + let stream = fetch_partition_local_vortex(path).map_err(|e| { + BallistaError::FetchFailed( + metadata.id.clone(), + partition_id.stage_id, + partition_id.partition_id, + e.to_string(), + ) + })?; + Ok(stream) + } + #[cfg(not(feature = "vortex"))] + { + Err(BallistaError::General( + "Vortex format files found but 'vortex' feature is not enabled" + .to_string(), + )) + } + } else { + // Arrow IPC format + let reader = fetch_partition_local_arrow(path).map_err(|e| { + BallistaError::FetchFailed( + metadata.id.clone(), + partition_id.stage_id, + partition_id.partition_id, + e.to_string(), + ) + })?; + Ok(Box::pin(LocalShuffleStream::new(reader))) + } } -fn fetch_partition_local_inner( +/// Fetch partition from local Arrow IPC file +fn fetch_partition_local_arrow( path: &str, ) -> result::Result>, BallistaError> { let file = File::open(path).map_err(|e| { @@ -686,11 +715,30 @@ fn fetch_partition_local_inner( })?; let file = BufReader::new(file); let reader = StreamReader::try_new(file, None).map_err(|e| { - BallistaError::General(format!("Failed to new arrow FileReader at {path}: {e:?}")) + BallistaError::General(format!( + "Failed to create Arrow IPC reader at {path}: {e:?}" + )) })?; Ok(reader) } +/// Fetch partition from local Vortex file +#[cfg(feature = "vortex")] +fn fetch_partition_local_vortex( + path: &str, +) -> result::Result { + use super::vortex_shuffle::LocalVortexShuffleStream; + + // Vortex IPC format is self-describing, but we need a schema for the stream interface. + // For now, use an empty schema - the actual data schema will come from the Vortex arrays. + // TODO: Consider storing schema metadata in the Vortex file or a sidecar file. + let schema = std::sync::Arc::new(datafusion::arrow::datatypes::Schema::empty()); + + // Create the stream - it handles reading and converting Vortex arrays to Arrow + let stream = LocalVortexShuffleStream::try_new(path, schema)?; + Ok(Box::pin(stream)) +} + #[cfg(feature = "build-binary")] async fn fetch_partition_object_store( location: &PartitionLocation, @@ -1146,7 +1194,7 @@ mod tests { // from to input partitions test the first one with two batches let file_path = path.value(0); - let reader = fetch_partition_local_inner(file_path).unwrap(); + let reader = fetch_partition_local_arrow(file_path).unwrap(); let mut stream: Pin> = async { Box::pin(LocalShuffleStream::new(reader)) }.await; diff --git a/ballista/core/src/execution_plans/shuffle_writer.rs b/ballista/core/src/execution_plans/shuffle_writer.rs index 114fc66437..ec281da65e 100644 --- a/ballista/core/src/execution_plans/shuffle_writer.rs +++ b/ballista/core/src/execution_plans/shuffle_writer.rs @@ -17,8 +17,9 @@ //! ShuffleWriterExec represents a section of a query plan that has consistent partitioning and //! can be executed as one unit with each partition being executed in parallel. The output of each -//! partition is re-partitioned and streamed to disk in Arrow IPC format. Future stages of the query -//! will use the ShuffleReaderExec to read these results. +//! partition is re-partitioned and streamed to disk in Arrow IPC format (default) or Vortex format. +//! The shuffle format is configurable. Future stages of the query will use the ShuffleReaderExec +//! to read these results. use datafusion::arrow::ipc::CompressionType; use datafusion::arrow::ipc::writer::IpcWriteOptions; @@ -34,6 +35,8 @@ use std::path::PathBuf; use std::sync::Arc; use std::time::Instant; +use crate::config::ShuffleFormat; +use crate::extension::SessionConfigExt; use crate::utils; use crate::serde::protobuf::ShuffleWritePartition; @@ -102,10 +105,96 @@ impl std::fmt::Display for ShuffleWriterExec { } } +/// Writer for Arrow IPC format +pub struct ArrowIpcWriter { + writer: StreamWriter, +} + +impl ArrowIpcWriter { + pub fn try_new( + file: File, + schema: &datafusion::arrow::datatypes::Schema, + ) -> Result { + let options = IpcWriteOptions::default() + .try_with_compression(Some(CompressionType::LZ4_FRAME))?; + let writer = StreamWriter::try_new_with_options(file, schema, options)?; + Ok(Self { writer }) + } + + pub fn write(&mut self, batch: &RecordBatch) -> Result<()> { + self.writer.write(batch)?; + Ok(()) + } + + pub fn finish(&mut self) -> Result<()> { + self.writer.finish()?; + Ok(()) + } +} + +/// Format-agnostic shuffle writer enum +pub enum ShuffleFileWriter { + ArrowIpc(ArrowIpcWriter), + #[cfg(feature = "vortex")] + Vortex(super::vortex_shuffle::VortexWriteTracker), +} + +impl ShuffleFileWriter { + pub fn try_new_arrow_ipc( + path: PathBuf, + schema: &datafusion::arrow::datatypes::Schema, + ) -> Result { + let file = File::create(&path)?; + Ok(Self::ArrowIpc(ArrowIpcWriter::try_new(file, schema)?)) + } + + #[cfg(feature = "vortex")] + pub fn try_new_vortex( + path: PathBuf, + schema: datafusion::arrow::datatypes::SchemaRef, + ) -> Result { + let tracker = super::vortex_shuffle::VortexWriteTracker::try_new(path, schema)?; + Ok(Self::Vortex(tracker)) + } + + pub fn try_new( + path: PathBuf, + schema: datafusion::arrow::datatypes::SchemaRef, + format: ShuffleFormat, + ) -> Result { + match format { + ShuffleFormat::ArrowIpc => Self::try_new_arrow_ipc(path, schema.as_ref()), + #[cfg(feature = "vortex")] + ShuffleFormat::Vortex => Self::try_new_vortex(path, schema), + #[cfg(not(feature = "vortex"))] + ShuffleFormat::Vortex => Err(DataFusionError::NotImplemented( + "Vortex format requires the 'vortex' feature to be enabled".to_string(), + )), + } + } + + pub fn write(&mut self, batch: &RecordBatch) -> Result<()> { + match self { + Self::ArrowIpc(w) => w.write(batch), + #[cfg(feature = "vortex")] + Self::Vortex(w) => w.write(batch), + } + } + + pub fn finish(self) -> Result<()> { + match self { + Self::ArrowIpc(mut w) => w.finish(), + #[cfg(feature = "vortex")] + Self::Vortex(w) => w.finish(), + } + } +} + +/// Tracks write progress for a partition pub struct WriteTracker { pub num_batches: usize, pub num_rows: usize, - pub writer: StreamWriter, + pub writer: ShuffleFileWriter, pub path: PathBuf, } @@ -205,6 +294,10 @@ impl ShuffleWriterExec { let output_partitioning = self.shuffle_output_partitioning.clone(); let plan = self.plan.clone(); + // Get shuffle format from session config + let shuffle_format = context.session_config().ballista_shuffle_format(); + let file_ext = utils::shuffle_file_extension(shuffle_format); + async move { let now = Instant::now(); let mut stream = plan.execute(input_partition, context)?; @@ -214,15 +307,16 @@ impl ShuffleWriterExec { let timer = write_metrics.write_time.timer(); path.push(format!("{input_partition}")); std::fs::create_dir_all(&path)?; - path.push("data.arrow"); + path.push(format!("data.{file_ext}")); let path = path.to_str().unwrap(); - debug!("Writing results to {path}"); + debug!("Writing results to {path} (format: {shuffle_format})"); - // stream results to disk - let stats = utils::write_stream_to_disk( + // stream results to disk using configured format + let stats = utils::write_stream_to_disk_with_format( &mut stream, path, &write_metrics.write_time, + shuffle_format, ) .await .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?; @@ -264,6 +358,8 @@ impl ShuffleWriterExec { write_metrics.repart_time.clone(), )?; + let schema = stream.schema(); + while let Some(result) = stream.next().await { let input_batch = result?; @@ -281,34 +377,27 @@ impl ShuffleWriterExec { w.writer.write(&output_batch)?; } None => { - let mut path = path.clone(); - path.push(format!("{output_partition}")); - std::fs::create_dir_all(&path)?; + let mut file_path = path.clone(); + file_path.push(format!("{output_partition}")); + std::fs::create_dir_all(&file_path)?; - path.push(format!( - "data-{input_partition}.arrow" + file_path.push(format!( + "data-{input_partition}.{file_ext}" )); - debug!("Writing results to {path:?}"); - - let options = IpcWriteOptions::default() - .try_with_compression(Some( - CompressionType::LZ4_FRAME, - ))?; + debug!("Writing results to {file_path:?} (format: {shuffle_format})"); - let file = File::create(path.clone())?; - let mut writer = - StreamWriter::try_new_with_options( - file, - stream.schema().as_ref(), - options, - )?; + let mut writer = ShuffleFileWriter::try_new( + file_path.clone(), + schema.clone(), + shuffle_format, + )?; writer.write(&output_batch)?; writers[output_partition] = Some(WriteTracker { num_batches: 1, num_rows: output_batch.num_rows(), writer, - path, + path: file_path, }); } } @@ -321,7 +410,7 @@ impl ShuffleWriterExec { let mut part_locs = vec![]; - for (i, w) in writers.iter_mut().enumerate() { + for (i, w) in writers.into_iter().enumerate() { if let Some(w) = w { let num_bytes = fs::metadata(&w.path)?.len(); w.writer.finish()?; diff --git a/ballista/core/src/execution_plans/vortex_shuffle.rs b/ballista/core/src/execution_plans/vortex_shuffle.rs new file mode 100644 index 0000000000..32ae0872d6 --- /dev/null +++ b/ballista/core/src/execution_plans/vortex_shuffle.rs @@ -0,0 +1,350 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Vortex format support for shuffle operations. +//! +//! This module provides Vortex-based serialization for shuffle data, +//! offering an alternative to Arrow IPC format with potentially better +//! compression and performance characteristics. +//! +//! Vortex IPC format is used for streaming data between processes. + +use std::fs::File; +use std::io::{BufReader, BufWriter, Cursor, Read, Write}; +use std::path::PathBuf; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use datafusion::arrow::datatypes::SchemaRef; +use datafusion::arrow::record_batch::RecordBatch; +use datafusion::error::Result; +use datafusion::physical_plan::RecordBatchStream; +use futures::Stream; +use log::debug; + +use vortex_array::ArrayRef; +use vortex_array::arrow::FromArrowArray; +use vortex_array::arrow::IntoArrowArray; +use vortex_array::iter::ArrayIteratorAdapter; +use vortex_array::session::ArraySession; +use vortex_error::VortexResult; +use vortex_ipc::iterator::{ArrayIteratorIPC, SyncIPCReader}; + +use crate::error::BallistaError; +use crate::serde::scheduler::PartitionStats; + +/// Writer for Vortex format shuffle data +pub struct VortexWriteTracker { + /// Number of record batches written + pub num_batches: usize, + /// Total number of rows written + pub num_rows: usize, + /// Path to the output file + pub path: PathBuf, + file: BufWriter, + #[allow(dead_code)] // May be needed for schema validation in the future + schema: SchemaRef, + buffer: Vec, +} + +impl VortexWriteTracker { + /// Create a new Vortex writer for the given path + pub fn try_new(path: PathBuf, schema: SchemaRef) -> Result { + let file = File::create(&path)?; + let writer = BufWriter::new(file); + + Ok(Self { + num_batches: 0, + num_rows: 0, + path, + file: writer, + schema, + buffer: Vec::new(), + }) + } + + /// Write a record batch to the buffer + pub fn write(&mut self, batch: &RecordBatch) -> Result<()> { + // Convert Arrow RecordBatch to Vortex Array + let vortex_array = ArrayRef::from_arrow(batch, false); + + self.buffer.push(vortex_array); + self.num_batches += 1; + self.num_rows += batch.num_rows(); + Ok(()) + } + + /// Finish writing and close the file + pub fn finish(mut self) -> Result<()> { + // Write all buffered arrays using IPC format + if !self.buffer.is_empty() { + // Get the dtype from the first array + let dtype = self.buffer[0].dtype().clone(); + + // Create an ArrayIterator from the buffer + let iter = self + .buffer + .into_iter() + .map(|a| Ok(a) as VortexResult); + let array_iter = ArrayIteratorAdapter::new(dtype, iter); + + // Convert to IPC bytes + let ipc_data = array_iter + .into_ipc() + .collect_to_buffer() + .map_err(|e| datafusion::error::DataFusionError::External(Box::new(e)))?; + + self.file.write_all(ipc_data.as_ref()).map_err(|e| { + datafusion::error::DataFusionError::Execution(format!( + "Failed to write Vortex IPC data: {e}" + )) + })?; + } + + self.file.flush().map_err(|e| { + datafusion::error::DataFusionError::Execution(format!( + "Failed to flush Vortex file: {e}" + )) + })?; + + Ok(()) + } +} + +/// Stream for reading Vortex shuffle files locally +pub struct LocalVortexShuffleStream { + arrays: std::vec::IntoIter, + schema: SchemaRef, +} + +impl LocalVortexShuffleStream { + /// Create a new stream from a Vortex file path + pub fn try_new( + path: &str, + schema: SchemaRef, + ) -> std::result::Result { + let file = File::open(path).map_err(|e| { + BallistaError::General(format!( + "Failed to open Vortex partition file at {path}: {e:?}" + )) + })?; + + let mut buf_reader = BufReader::new(file); + let mut data = Vec::new(); + buf_reader.read_to_end(&mut data).map_err(|e| { + BallistaError::General(format!("Failed to read Vortex file at {path}: {e:?}")) + })?; + + // Create default registry with all canonical encodings + let session = ArraySession::default(); + let registry = session.registry().clone(); + + // Read IPC data + let cursor = Cursor::new(data); + let reader = SyncIPCReader::try_new(cursor, registry).map_err(|e| { + BallistaError::General(format!( + "Failed to create Vortex IPC reader at {path}: {e:?}" + )) + })?; + + let arrays: Vec = reader + .map(|r| { + r.map_err(|e| { + BallistaError::General(format!("Failed to read array: {e:?}")) + }) + }) + .collect::, _>>()?; + + Ok(Self { + arrays: arrays.into_iter(), + schema, + }) + } +} + +impl Stream for LocalVortexShuffleStream { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + match self.arrays.next() { + Some(array) => { + // Convert Vortex array back to Arrow + let arrow_array = array.into_arrow_preferred().map_err(|e| { + datafusion::error::DataFusionError::External(Box::new(e)) + })?; + + // The arrow_array should be a StructArray since we converted from RecordBatch + let struct_array = arrow_array + .as_any() + .downcast_ref::( + ); + + match struct_array { + Some(sa) => { + let batch = RecordBatch::from(sa); + Poll::Ready(Some(Ok(batch))) + } + None => Poll::Ready(Some(Err( + datafusion::error::DataFusionError::Internal( + "Expected StructArray from Vortex".to_string(), + ), + ))), + } + } + None => Poll::Ready(None), + } + } +} + +impl RecordBatchStream for LocalVortexShuffleStream { + fn schema(&self) -> SchemaRef { + self.schema.clone() + } +} + +/// Write a stream to disk in Vortex IPC format +pub async fn write_stream_to_disk_vortex( + stream: &mut Pin>, + path: &str, + disk_write_metric: &datafusion::physical_plan::metrics::Time, +) -> std::result::Result { + use futures::StreamExt; + + let file = File::create(path).map_err(|e| { + log::error!("Failed to create Vortex partition file at {path}: {e:?}"); + BallistaError::IoError(e) + })?; + + let mut num_rows = 0; + let mut num_batches = 0; + let mut num_bytes = 0; + let mut arrays: Vec = Vec::new(); + + while let Some(result) = stream.next().await { + let batch = result?; + + let batch_size_bytes: usize = batch.get_array_memory_size(); + num_batches += 1; + num_rows += batch.num_rows(); + num_bytes += batch_size_bytes; + + // Convert Arrow RecordBatch to Vortex Array + let vortex_array = ArrayRef::from_arrow(&batch, false); + arrays.push(vortex_array); + } + + // Write all arrays using IPC format + let timer = disk_write_metric.timer(); + let mut writer = BufWriter::new(file); + + if !arrays.is_empty() { + // Get the dtype from the first array + let dtype = arrays[0].dtype().clone(); + + // Create an ArrayIterator from the buffer + let iter = arrays.into_iter().map(|a| Ok(a) as VortexResult); + let array_iter = ArrayIteratorAdapter::new(dtype, iter); + + // Convert to IPC bytes + let ipc_data = array_iter.into_ipc().collect_to_buffer().map_err(|e| { + BallistaError::General(format!("Failed to write Vortex IPC: {e}")) + })?; + + writer.write_all(ipc_data.as_ref()).map_err(|e| { + BallistaError::General(format!("Failed to write to file: {e}")) + })?; + } + + writer.flush().map_err(|e| { + BallistaError::General(format!("Failed to flush Vortex file: {e}")) + })?; + timer.done(); + + debug!( + "Wrote Vortex shuffle file to {}: {} rows, {} batches, {} bytes", + path, num_rows, num_batches, num_bytes + ); + + Ok(PartitionStats::new( + Some(num_rows as u64), + Some(num_batches), + Some(num_bytes as u64), + )) +} + +/// Get the file extension for Vortex files +pub fn vortex_file_extension() -> &'static str { + "vortex" +} + +#[cfg(test)] +mod tests { + use super::*; + use datafusion::arrow::array::{Int32Array, StringArray}; + use datafusion::arrow::datatypes::{DataType, Field, Schema}; + use std::sync::Arc; + use tempfile::TempDir; + + fn create_test_batch() -> RecordBatch { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Utf8, false), + ])); + + RecordBatch::try_new( + schema, + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(StringArray::from(vec!["a", "b", "c"])), + ], + ) + .unwrap() + } + + #[tokio::test] + async fn test_vortex_write_read_roundtrip() { + let work_dir = TempDir::new().unwrap(); + let path = work_dir.path().join("test.vortex"); + + let batch = create_test_batch(); + let schema = batch.schema(); + + // Write + { + let mut writer = + VortexWriteTracker::try_new(path.clone(), schema.clone()).unwrap(); + writer.write(&batch).unwrap(); + writer.finish().unwrap(); + } + + // Read + { + let stream = + LocalVortexShuffleStream::try_new(path.to_str().unwrap(), schema) + .unwrap(); + let mut pinned = Box::pin(stream); + + use futures::StreamExt; + let result = pinned.next().await.unwrap().unwrap(); + assert_eq!(result.num_rows(), 3); + assert_eq!(result.num_columns(), 2); + } + } +} diff --git a/ballista/core/src/extension.rs b/ballista/core/src/extension.rs index 28a06185a1..c799ebdf48 100644 --- a/ballista/core/src/extension.rs +++ b/ballista/core/src/extension.rs @@ -16,10 +16,11 @@ // under the License. use crate::config::{ - BALLISTA_GRPC_CLIENT_MAX_MESSAGE_SIZE, BALLISTA_JOB_NAME, + BALLISTA_GRPC_CLIENT_MAX_MESSAGE_SIZE, BALLISTA_JOB_NAME, BALLISTA_SHUFFLE_FORMAT, BALLISTA_SHUFFLE_READER_FORCE_REMOTE_READ, BALLISTA_SHUFFLE_READER_MAX_REQUESTS, BALLISTA_SHUFFLE_READER_REMOTE_PREFER_FLIGHT, BALLISTA_SHUFFLE_STORAGE_TYPE, BALLISTA_SHUFFLE_STORAGE_URL, BALLISTA_STANDALONE_PARALLELISM, BallistaConfig, + ShuffleFormat, }; use crate::planner::BallistaQueryPlanner; use crate::serde::protobuf::KeyValuePair; @@ -190,6 +191,18 @@ pub trait SessionConfigExt { /// Sets the shuffle storage base URL/path. fn with_ballista_shuffle_storage_url(self, url: &str) -> Self; + /// Get the shuffle format (ArrowIpc or Vortex) + /// + /// Note: Vortex format requires the 'vortex' feature to be enabled. + fn ballista_shuffle_format(&self) -> ShuffleFormat; + + /// Set the shuffle format for intermediate shuffle data + /// + /// Available formats: + /// - `ShuffleFormat::ArrowIpc` (default) - Standard Arrow IPC format + /// - `ShuffleFormat::Vortex` - Vortex columnar format (requires 'vortex' feature) + fn with_ballista_shuffle_format(self, format: ShuffleFormat) -> Self; + /// Set a callback for recording shuffle read metrics (local vs remote). /// /// This callback will be invoked by the shuffle reader during execution @@ -535,6 +548,23 @@ impl SessionConfigExt for SessionConfig { } } + fn ballista_shuffle_format(&self) -> ShuffleFormat { + self.options() + .extensions + .get::() + .map(|c| c.shuffle_format()) + .unwrap_or_else(|| BallistaConfig::default().shuffle_format()) + } + + fn with_ballista_shuffle_format(self, format: ShuffleFormat) -> Self { + if self.options().extensions.get::().is_some() { + self.set_str(BALLISTA_SHUFFLE_FORMAT, &format.to_string()) + } else { + self.with_option_extension(BallistaConfig::default()) + .set_str(BALLISTA_SHUFFLE_FORMAT, &format.to_string()) + } + } + fn with_ballista_shuffle_read_metrics_callback( self, callback: Arc, diff --git a/ballista/core/src/lib.rs b/ballista/core/src/lib.rs index e0107bc6d0..846e6671ac 100644 --- a/ballista/core/src/lib.rs +++ b/ballista/core/src/lib.rs @@ -52,15 +52,13 @@ pub mod object_store; pub mod planner; /// Runtime registry for codec and function registration. pub mod registry; +/// Remote catalog for distributed function and table registration. +pub mod remote_catalog; /// Serialization and deserialization for Ballista messages and plans. pub mod serde; /// Shuffle storage abstraction for local and object store backends. #[cfg(feature = "build-binary")] pub mod shuffle_storage; - -/// Remote catalog serialization and stub providers for Ballista clients. -pub mod remote_catalog; - /// General utility functions for Ballista operations. pub mod utils; diff --git a/ballista/core/src/utils.rs b/ballista/core/src/utils.rs index 861dd72e7a..d89dc6c5c5 100644 --- a/ballista/core/src/utils.rs +++ b/ballista/core/src/utils.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use crate::config::BallistaConfig; +use crate::config::{BallistaConfig, ShuffleFormat}; use crate::error::{BallistaError, Result}; use crate::extension::SessionConfigExt; use crate::serde::scheduler::PartitionStats; @@ -196,6 +196,39 @@ pub async fn collect_stream( Ok(batches) } +/// Write stream to disk using the specified shuffle format +/// +/// This function dispatches to the appropriate writer based on the format: +/// - ArrowIpc: Uses Arrow IPC streaming format with LZ4 compression +/// - Vortex: Uses Vortex columnar format (requires 'vortex' feature) +pub async fn write_stream_to_disk_with_format( + stream: &mut Pin>, + path: &str, + disk_write_metric: &metrics::Time, + format: ShuffleFormat, +) -> Result { + match format { + ShuffleFormat::ArrowIpc => write_stream_to_disk(stream, path, disk_write_metric).await, + #[cfg(feature = "vortex")] + ShuffleFormat::Vortex => { + crate::execution_plans::write_stream_to_disk_vortex(stream, path, disk_write_metric) + .await + } + #[cfg(not(feature = "vortex"))] + ShuffleFormat::Vortex => Err(BallistaError::General( + "Vortex format is not available. Enable the 'vortex' feature to use Vortex shuffle format.".to_string(), + )), + } +} + +/// Get the file extension for the given shuffle format +pub fn shuffle_file_extension(format: ShuffleFormat) -> &'static str { + match format { + ShuffleFormat::ArrowIpc => "arrow", + ShuffleFormat::Vortex => "vortex", + } +} + /// Creates a gRPC client connection with the specified configuration. pub async fn create_grpc_client_connection( dst: D, diff --git a/ballista/executor/Cargo.toml b/ballista/executor/Cargo.toml index fb692cbf51..41ab99b904 100644 --- a/ballista/executor/Cargo.toml +++ b/ballista/executor/Cargo.toml @@ -35,6 +35,7 @@ required-features = ["build-binary"] [features] build-binary = ["clap", "tracing-subscriber", "tracing-appender", "tracing", "ballista-core/build-binary"] default = ["build-binary", "mimalloc"] +vortex = ["ballista-core/vortex", "vortex-array", "vortex-ipc"] [dependencies] arrow = { workspace = true } @@ -63,6 +64,10 @@ tracing-appender = { workspace = true, optional = true } tracing-subscriber = { workspace = true, optional = true } uuid = { workspace = true } +# Vortex columnar format dependencies (optional) +vortex-array = { workspace = true, optional = true } +vortex-ipc = { workspace = true, optional = true } + [dev-dependencies] [build-dependencies] diff --git a/ballista/executor/src/flight_service.rs b/ballista/executor/src/flight_service.rs index 635ec6997b..73f9128e67 100644 --- a/ballista/executor/src/flight_service.rs +++ b/ballista/executor/src/flight_service.rs @@ -17,9 +17,11 @@ //! Implementation of the Apache Arrow Flight protocol that wraps an executor. +use datafusion::arrow::datatypes::SchemaRef; use datafusion::arrow::ipc::reader::StreamReader; use std::convert::TryFrom; use std::fs::File; +use std::path::Path; use std::pin::Pin; use tokio_util::io::ReaderStream; @@ -96,25 +98,35 @@ impl FlightService for BallistaFlightService { match &action { BallistaAction::FetchPartition { path, .. } => { - debug!("FetchPartition reading {path}"); - let file = File::open(path) - .map_err(|e| { - BallistaError::General(format!( - "Failed to open partition file at {path}: {e:?}" - )) - }) - .map_err(|e| from_ballista_err(&e))?; - let file = BufReader::new(file); - let reader = - StreamReader::try_new(file, None).map_err(|e| from_arrow_err(&e))?; - - let (tx, rx) = channel(2); - let schema = reader.schema(); - task::spawn_blocking(move || { - if let Err(e) = read_partition(reader, tx) { - log::warn!("error streaming shuffle partition: {e}"); + // Detect shuffle format based on file extension + let is_vortex = Path::new(path) + .extension() + .map(|ext| ext == "vortex") + .unwrap_or(false); + + let format = if is_vortex { "vortex" } else { "arrow-ipc" }; + debug!("FetchPartition reading {path} (format: {format})"); + + // Detect shuffle format based on file extension + let is_vortex = Path::new(path) + .extension() + .map(|ext| ext == "vortex") + .unwrap_or(false); + + let (schema, rx) = if is_vortex { + #[cfg(feature = "vortex")] + { + read_vortex_partition(path)? + } + #[cfg(not(feature = "vortex"))] + { + return Err(Status::unimplemented( + "Vortex format is not available. Enable the 'vortex' feature.", + )); } - }); + } else { + read_arrow_ipc_partition(path)? + }; let write_options: IpcWriteOptions = IpcWriteOptions::default() .try_with_compression(Some(CompressionType::LZ4_FRAME)) @@ -274,7 +286,157 @@ impl FlightService for BallistaFlightService { } } -fn read_partition( +/// Read an Arrow IPC partition file and return the schema and a receiver for record batches +fn read_arrow_ipc_partition( + path: &str, +) -> Result< + ( + SchemaRef, + tokio::sync::mpsc::Receiver>, + ), + Status, +> { + let file = File::open(path) + .map_err(|e| { + BallistaError::General(format!( + "Failed to open partition file at {path}: {e:?}" + )) + }) + .map_err(|e| from_ballista_err(&e))?; + let file = BufReader::new(file); + let reader = StreamReader::try_new(file, None).map_err(|e| from_arrow_err(&e))?; + + let (tx, rx) = channel(2); + let schema = reader.schema(); + task::spawn_blocking(move || { + if let Err(e) = read_arrow_ipc_batches(reader, tx) { + log::warn!("error streaming Arrow IPC shuffle partition: {e}"); + } + }); + + Ok((schema, rx)) +} + +/// Read Vortex partition file and return the schema and a receiver for record batches +#[cfg(feature = "vortex")] +fn read_vortex_partition( + path: &str, +) -> Result< + ( + SchemaRef, + tokio::sync::mpsc::Receiver>, + ), + Status, +> { + use std::io::Cursor; + use std::sync::Arc; + use vortex_array::ArrayRef; + use vortex_array::iter::ArrayIterator; + use vortex_array::session::ArraySession; + use vortex_ipc::iterator::SyncIPCReader; + + let file = File::open(path) + .map_err(|e| { + BallistaError::General(format!( + "Failed to open Vortex partition file at {path}: {e:?}" + )) + }) + .map_err(|e| from_ballista_err(&e))?; + + let mut buf_reader = BufReader::new(file); + let mut data = Vec::new(); + std::io::Read::read_to_end(&mut buf_reader, &mut data).map_err(|e| { + from_ballista_err(&BallistaError::General(format!( + "Failed to read Vortex file at {path}: {e:?}" + ))) + })?; + + // Create default registry with all canonical encodings + let session = ArraySession::default(); + let registry = session.registry().clone(); + + // Read IPC data + let cursor = Cursor::new(data); + let reader = SyncIPCReader::try_new(cursor, registry).map_err(|e| { + from_ballista_err(&BallistaError::General(format!( + "Failed to create Vortex IPC reader at {path}: {e:?}" + ))) + })?; + + // Get schema from IPC header via ArrayIterator::dtype() method + // This is stored in the Vortex IPC format header, not inferred from data + let dtype = reader.dtype().clone(); + let arrow_schema = dtype.to_arrow_schema().map_err(|e| { + from_ballista_err(&BallistaError::General(format!( + "Failed to convert Vortex DType to Arrow schema: {e:?}" + ))) + })?; + let schema = Arc::new(arrow_schema); + + let arrays: Vec = reader + .map(|r| { + r.map_err(|e| { + from_ballista_err(&BallistaError::General(format!( + "Failed to read Vortex array: {e:?}" + ))) + }) + }) + .collect::, _>>()?; + + let (tx, rx) = channel(2); + task::spawn_blocking(move || { + if let Err(e) = read_vortex_batches(arrays, tx) { + log::warn!("error streaming Vortex shuffle partition: {e}"); + } + }); + + Ok((schema, rx)) +} + +/// Read Vortex arrays and send them as record batches +#[cfg(feature = "vortex")] +fn read_vortex_batches( + arrays: Vec, + tx: Sender>, +) -> Result<(), FlightError> { + use vortex_array::arrow::IntoArrowArray; + + if tx.is_closed() { + return Err(FlightError::Tonic(Box::new(Status::internal( + "Can't send a batch, channel is closed", + )))); + } + + for array in arrays { + let arrow_array = array + .into_arrow_preferred() + .map_err(|e| FlightError::Arrow(ArrowError::ExternalError(Box::new(e))))?; + + let struct_array = arrow_array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + FlightError::Arrow(ArrowError::InvalidArgumentError( + "Expected StructArray from Vortex".to_string(), + )) + })?; + + let batch = RecordBatch::from(struct_array); + + tx.blocking_send(Ok(batch)).map_err(|err| { + if let SendError(Err(err)) = err { + err + } else { + FlightError::Tonic(Box::new(Status::internal(format!( + "Can't send a batch, something went wrong: {err:?}" + )))) + } + })?; + } + Ok(()) +} + +fn read_arrow_ipc_batches( reader: StreamReader>, tx: Sender>, ) -> Result<(), FlightError> diff --git a/ballista/scheduler/Cargo.toml b/ballista/scheduler/Cargo.toml index d29c7c5cc1..883b211d13 100644 --- a/ballista/scheduler/Cargo.toml +++ b/ballista/scheduler/Cargo.toml @@ -69,11 +69,15 @@ serde = { workspace = true, features = ["derive"] } tokio = { workspace = true, features = ["full"] } tokio-stream = { workspace = true, features = ["net"] } tonic = { workspace = true, features = ["router"] } +tonic-prost = { workspace = true } tracing = { workspace = true, optional = true } tracing-appender = { workspace = true, optional = true } tracing-subscriber = { workspace = true, optional = true } uuid = { workspace = true } +[build-dependencies] +tonic-prost-build = { workspace = true } + [dev-dependencies] rstest = { workspace = true } diff --git a/ballista/scheduler/build.rs b/ballista/scheduler/build.rs index ae0369dd53..5a90821250 100644 --- a/ballista/scheduler/build.rs +++ b/ballista/scheduler/build.rs @@ -20,7 +20,7 @@ fn main() -> Result<(), String> { println!("cargo:rerun-if-changed=proto/keda.proto"); #[cfg(feature = "keda-scaler")] - tonic_build::configure() + tonic_prost_build::configure() .compile_protos(&["proto/keda.proto"], &["proto"]) .map_err(|e| format!("protobuf compilation failed: {e}"))?; diff --git a/benchmarks/src/bin/tpch.rs b/benchmarks/src/bin/tpch.rs index d1401a6589..53dea05701 100644 --- a/benchmarks/src/bin/tpch.rs +++ b/benchmarks/src/bin/tpch.rs @@ -62,7 +62,7 @@ use std::{ use structopt::StructOpt; use tokio::task::JoinHandle; -#[cfg(feature = "snmalloc")] +#[cfg(all(feature = "snmalloc", not(feature = "mimalloc")))] #[global_allocator] static ALLOC: snmalloc_rs::SnMalloc = snmalloc_rs::SnMalloc; From 7ef8519b8cd0286ab00fe3dd091104b93017a45c Mon Sep 17 00:00:00 2001 From: Luke Kim <80174+lukekim@users.noreply.github.com> Date: Thu, 22 Jan 2026 21:05:56 -0800 Subject: [PATCH 4/6] Fix lint --- ballista/core/src/shuffle_storage.rs | 92 ++++++++++++++++------------ 1 file changed, 52 insertions(+), 40 deletions(-) diff --git a/ballista/core/src/shuffle_storage.rs b/ballista/core/src/shuffle_storage.rs index 6d67e9d859..dfc917e95d 100644 --- a/ballista/core/src/shuffle_storage.rs +++ b/ballista/core/src/shuffle_storage.rs @@ -23,16 +23,16 @@ use async_trait::async_trait; use bytes::Bytes; use datafusion::arrow::datatypes::SchemaRef; +use datafusion::arrow::ipc::CompressionType; use datafusion::arrow::ipc::reader::StreamReader; use datafusion::arrow::ipc::writer::IpcWriteOptions; use datafusion::arrow::ipc::writer::StreamWriter; -use datafusion::arrow::ipc::CompressionType; use datafusion::arrow::record_batch::RecordBatch; use datafusion::physical_plan::metrics; use futures::StreamExt; use log::{debug, error}; -use object_store::azure::MicrosoftAzureBuilder; use object_store::aws::AmazonS3Builder; +use object_store::azure::MicrosoftAzureBuilder; use object_store::path::Path as ObjectPath; use object_store::{ObjectStore, PutPayload}; use std::fmt::{Debug, Display}; @@ -160,7 +160,10 @@ impl ShuffleStorageConfig { /// Creates a new Azure Blob Storage configuration. pub fn new_azure(account: &str, container: &str, prefix: Option<&str>) -> Self { let base_url = match prefix { - Some(p) => format!("abfs://{}@{}.dfs.core.windows.net/{}", container, account, p), + Some(p) => format!( + "abfs://{}@{}.dfs.core.windows.net/{}", + container, account, p + ), None => format!("abfs://{}@{}.dfs.core.windows.net", container, account), }; Self { @@ -178,6 +181,7 @@ impl ShuffleStorageConfig { /// Trait for shuffle storage operations. #[async_trait] +#[allow(clippy::too_many_arguments)] pub trait ShuffleStorage: Send + Sync + Debug { /// Write a record batch to storage and return the path where it was written. async fn write_shuffle_data( @@ -256,7 +260,8 @@ impl ShuffleStorage for LocalShuffleStorage { let options = IpcWriteOptions::default() .try_with_compression(Some(CompressionType::LZ4_FRAME))?; - let mut writer = StreamWriter::try_new_with_options(file, schema.as_ref(), options)?; + let mut writer = + StreamWriter::try_new_with_options(file, schema.as_ref(), options)?; let mut num_rows = 0; let mut num_batches = 0; @@ -283,7 +288,10 @@ impl ShuffleStorage for LocalShuffleStorage { async fn read_shuffle_data(&self, path: &str) -> Result> { let file = File::open(path).map_err(|e| { - BallistaError::General(format!("Failed to open shuffle file at {}: {:?}", path, e)) + BallistaError::General(format!( + "Failed to open shuffle file at {}: {:?}", + path, e + )) })?; let reader = BufReader::new(file); let stream_reader = StreamReader::try_new(reader, None)?; @@ -311,7 +319,9 @@ impl ShuffleStorage for LocalShuffleStorage { fn can_handle(&self, path: &str) -> bool { // Local storage can handle paths that don't start with a URL scheme - !path.starts_with("s3://") && !path.starts_with("abfs://") && !path.starts_with("az://") + !path.starts_with("s3://") + && !path.starts_with("abfs://") + && !path.starts_with("az://") } } @@ -396,7 +406,9 @@ impl ObjectStoreShuffleStorage { .filter_map(|pair| { let mut parts = pair.splitn(2, '='); match (parts.next(), parts.next()) { - (Some(key), Some(value)) => Some((key.to_string(), value.to_string())), + (Some(key), Some(value)) => { + Some((key.to_string(), value.to_string())) + } _ => None, } }) @@ -405,7 +417,10 @@ impl ObjectStoreShuffleStorage { } let store = builder.build().map_err(|e| { - BallistaError::General(format!("Failed to create Azure object store: {:?}", e)) + BallistaError::General(format!( + "Failed to create Azure object store: {:?}", + e + )) })?; let base_url = config.base_url.clone().unwrap_or_else(|| { @@ -430,7 +445,13 @@ impl ObjectStoreShuffleStorage { } } - fn make_path(&self, job_id: &str, stage_id: usize, partition_id: usize, input_partition: usize) -> String { + fn make_path( + &self, + job_id: &str, + stage_id: usize, + partition_id: usize, + input_partition: usize, + ) -> String { let filename = if input_partition == partition_id { "data.arrow".to_string() } else { @@ -452,9 +473,10 @@ impl ShuffleStorage for ObjectStoreShuffleStorage { schema: SchemaRef, write_metric: &metrics::Time, ) -> Result<(String, PartitionStats)> { - let relative_path = self.make_path(job_id, stage_id, partition_id, input_partition); + let relative_path = + self.make_path(job_id, stage_id, partition_id, input_partition); let full_url = format!("{}/{}", self.base_url, relative_path); - + debug!("Writing shuffle data to object store: {}", full_url); let timer = write_metric.timer(); @@ -463,8 +485,8 @@ impl ShuffleStorage for ObjectStoreShuffleStorage { let mut buffer = Vec::new(); let options = IpcWriteOptions::default() .try_with_compression(Some(CompressionType::LZ4_FRAME))?; - - let (total_rows, total_batches) = { + + let (_total_rows, _total_batches) = { let mut writer = StreamWriter::try_new_with_options( Cursor::new(&mut buffer), schema.as_ref(), @@ -489,7 +511,7 @@ impl ShuffleStorage for ObjectStoreShuffleStorage { // Upload to object store let object_path = ObjectPath::from(relative_path); let payload = PutPayload::from(Bytes::from(buffer)); - + self.store.put(&object_path, payload).await.map_err(|e| { BallistaError::General(format!( "Failed to upload shuffle data to {}: {:?}", @@ -511,7 +533,7 @@ impl ShuffleStorage for ObjectStoreShuffleStorage { async fn read_shuffle_data(&self, path: &str) -> Result> { // Extract the object path from the full URL let object_path = self.extract_object_path(path)?; - + debug!("Reading shuffle data from object store: {}", path); let get_result = self.store.get(&object_path).await.map_err(|e| { @@ -522,10 +544,7 @@ impl ShuffleStorage for ObjectStoreShuffleStorage { })?; let bytes = get_result.bytes().await.map_err(|e| { - BallistaError::General(format!( - "Failed to read bytes from {}: {:?}", - path, e - )) + BallistaError::General(format!("Failed to read bytes from {}: {:?}", path, e)) })?; let cursor = Cursor::new(bytes.to_vec()); @@ -541,7 +560,7 @@ impl ShuffleStorage for ObjectStoreShuffleStorage { async fn delete_job_data(&self, job_id: &str) -> Result<()> { let prefix = ObjectPath::from(job_id.to_string()); - + // List all objects with the job_id prefix let mut list_stream = self.store.list(Some(&prefix)); let mut objects_to_delete = Vec::new(); @@ -607,10 +626,9 @@ impl ShuffleStorageFactory { pub fn create(config: &ShuffleStorageConfig) -> Result> { match config.storage_type { ShuffleStorageType::Local => { - let work_dir = config - .base_url - .as_ref() - .ok_or_else(|| BallistaError::General("Work directory not configured".to_string()))?; + let work_dir = config.base_url.as_ref().ok_or_else(|| { + BallistaError::General("Work directory not configured".to_string()) + })?; Ok(Arc::new(LocalShuffleStorage::new(work_dir))) } ShuffleStorageType::S3 => { @@ -637,9 +655,7 @@ mod tests { use tempfile::TempDir; fn create_test_batch() -> (RecordBatch, SchemaRef) { - let schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int32, false), - ])); + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); let batch = RecordBatch::try_new( schema.clone(), vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], @@ -655,7 +671,8 @@ mod tests { let (batch, schema) = create_test_batch(); let metrics = ExecutionPlanMetricsSet::new(); - let time_metric = metrics::MetricBuilder::new(&metrics).subset_time("write_time", 0); + let time_metric = + metrics::MetricBuilder::new(&metrics).subset_time("write_time", 0); let (path, stats) = storage .write_shuffle_data( @@ -686,18 +703,11 @@ mod tests { let (batch, schema) = create_test_batch(); let metrics = ExecutionPlanMetricsSet::new(); - let time_metric = metrics::MetricBuilder::new(&metrics).subset_time("write_time", 0); + let time_metric = + metrics::MetricBuilder::new(&metrics).subset_time("write_time", 0); let (path, _) = storage - .write_shuffle_data( - "test_job", - 1, - 0, - 0, - vec![batch], - schema, - &time_metric, - ) + .write_shuffle_data("test_job", 1, 0, 0, vec![batch], schema, &time_metric) .await .unwrap(); @@ -736,7 +746,8 @@ mod tests { #[test] fn test_storage_config_new_s3() { - let config = ShuffleStorageConfig::new_s3("my-bucket", Some("shuffle"), Some("us-east-1")); + let config = + ShuffleStorageConfig::new_s3("my-bucket", Some("shuffle"), Some("us-east-1")); assert_eq!(config.storage_type, ShuffleStorageType::S3); assert_eq!(config.base_url, Some("s3://my-bucket/shuffle".to_string())); assert_eq!(config.s3_config.bucket, Some("my-bucket".to_string())); @@ -745,7 +756,8 @@ mod tests { #[test] fn test_storage_config_new_azure() { - let config = ShuffleStorageConfig::new_azure("myaccount", "mycontainer", Some("shuffle")); + let config = + ShuffleStorageConfig::new_azure("myaccount", "mycontainer", Some("shuffle")); assert_eq!(config.storage_type, ShuffleStorageType::Azure); assert_eq!( config.base_url, From c1b6a7819333920c70d39285c85209c7e9bf548c Mon Sep 17 00:00:00 2001 From: Luke Kim <80174+lukekim@users.noreply.github.com> Date: Thu, 22 Jan 2026 21:31:40 -0800 Subject: [PATCH 5/6] Don't expose final stage --- ballista/core/src/config.rs | 24 ++++++++++++---- ballista/core/src/extension.rs | 51 +++++++++++++++++++--------------- 2 files changed, 47 insertions(+), 28 deletions(-) diff --git a/ballista/core/src/config.rs b/ballista/core/src/config.rs index 4ebfdaa64a..bd44d9ceaa 100644 --- a/ballista/core/src/config.rs +++ b/ballista/core/src/config.rs @@ -50,7 +50,8 @@ pub const BALLISTA_SHUFFLE_STORAGE_TYPE: &str = "ballista.shuffle.storage_type"; pub const BALLISTA_SHUFFLE_STORAGE_URL: &str = "ballista.shuffle.storage_url"; /// Configuration key for shuffle storage mode (disk or memory). pub const BALLISTA_SHUFFLE_MEMORY_MODE: &str = "ballista.shuffle.memory_mode"; -/// Configuration key indicating if this is the final output stage. +/// Internal configuration key indicating if this is the final output stage. +/// This is set by the scheduler based on stage topology, NOT user-configurable. /// When true, shuffle data is always written to disk regardless of memory_mode setting. pub const BALLISTA_IS_FINAL_STAGE: &str = "ballista.shuffle.is_final_stage"; /// Shuffle format configuration: "arrow_ipc" or "vortex" @@ -109,10 +110,9 @@ static CONFIG_ENTRIES: LazyLock> = LazyLock::new(|| "When enabled, shuffle data is kept in memory on executors instead of being written to disk. This can improve performance for workloads with sufficient memory.".to_string(), DataType::Boolean, Some((false).to_string())), - ConfigEntry::new(BALLISTA_IS_FINAL_STAGE.to_string(), - "When true, indicates this is the final output stage. Final stages always write to disk regardless of memory_mode setting to ensure proper cleanup.".to_string(), - DataType::Boolean, - Some((false).to_string())), + // Note: BALLISTA_IS_FINAL_STAGE is intentionally NOT in CONFIG_ENTRIES. + // It's an internal flag set by the scheduler based on stage topology, + // not a user-configurable setting. ConfigEntry::new(BALLISTA_GRPC_CLIENT_CONNECT_TIMEOUT_SECONDS.to_string(), "Connection timeout for gRPC client in seconds".to_string(), DataType::UInt64, @@ -352,9 +352,21 @@ impl BallistaConfig { } /// Returns whether this is the final output stage. + /// This is an internal flag set by the scheduler, not user-configurable. /// Final stages always write to disk regardless of memory_mode setting. pub fn is_final_stage(&self) -> bool { - self.get_bool_setting(BALLISTA_IS_FINAL_STAGE) + self.settings + .get(BALLISTA_IS_FINAL_STAGE) + .and_then(|v| v.parse::().ok()) + .unwrap_or(false) + } + + /// Sets the internal is_final_stage flag. + /// This should only be called by the scheduler when creating task configurations. + pub fn with_is_final_stage(mut self, is_final: bool) -> Self { + self.settings + .insert(BALLISTA_IS_FINAL_STAGE.to_string(), is_final.to_string()); + self } /// Returns the configured shuffle format (ArrowIpc or Vortex) diff --git a/ballista/core/src/extension.rs b/ballista/core/src/extension.rs index f14875c989..7ca09f2647 100644 --- a/ballista/core/src/extension.rs +++ b/ballista/core/src/extension.rs @@ -16,12 +16,11 @@ // under the License. use crate::config::{ - BALLISTA_GRPC_CLIENT_MAX_MESSAGE_SIZE, BALLISTA_IS_FINAL_STAGE, BALLISTA_JOB_NAME, - BALLISTA_SHUFFLE_FORMAT, BALLISTA_SHUFFLE_MEMORY_MODE, - BALLISTA_SHUFFLE_READER_FORCE_REMOTE_READ, BALLISTA_SHUFFLE_READER_MAX_REQUESTS, - BALLISTA_SHUFFLE_READER_REMOTE_PREFER_FLIGHT, BALLISTA_SHUFFLE_STORAGE_TYPE, - BALLISTA_SHUFFLE_STORAGE_URL, BALLISTA_STANDALONE_PARALLELISM, BallistaConfig, - ShuffleFormat, + BALLISTA_GRPC_CLIENT_MAX_MESSAGE_SIZE, BALLISTA_JOB_NAME, BALLISTA_SHUFFLE_FORMAT, + BALLISTA_SHUFFLE_MEMORY_MODE, BALLISTA_SHUFFLE_READER_FORCE_REMOTE_READ, + BALLISTA_SHUFFLE_READER_MAX_REQUESTS, BALLISTA_SHUFFLE_READER_REMOTE_PREFER_FLIGHT, + BALLISTA_SHUFFLE_STORAGE_TYPE, BALLISTA_SHUFFLE_STORAGE_URL, + BALLISTA_STANDALONE_PARALLELISM, BallistaConfig, ShuffleFormat, }; use crate::planner::BallistaQueryPlanner; use crate::serde::protobuf::KeyValuePair; @@ -513,16 +512,20 @@ impl SessionConfigExt for SessionConfig { .extensions .get::() .map(|c| c.is_final_stage()) - .unwrap_or_else(|| BallistaConfig::default().is_final_stage()) + .unwrap_or(false) } fn with_ballista_is_final_stage(self, is_final: bool) -> Self { - if self.options().extensions.get::().is_some() { - self.set_bool(BALLISTA_IS_FINAL_STAGE, is_final) - } else { - self.with_option_extension(BallistaConfig::default()) - .set_bool(BALLISTA_IS_FINAL_STAGE, is_final) - } + // is_final_stage is an internal flag, not a user-configurable setting, + // so we modify the BallistaConfig directly instead of using set_bool + let ballista_config = self + .options() + .extensions + .get::() + .cloned() + .unwrap_or_default() + .with_is_final_stage(is_final); + self.with_option_extension(ballista_config) } fn with_ballista_grpc_metadata(self, metadata: HashMap) -> Self { @@ -980,7 +983,7 @@ mod test { }; use crate::{ - config::BALLISTA_JOB_NAME, + config::{BALLISTA_JOB_NAME, BallistaConfig}, extension::{SessionConfigExt, SessionConfigHelperExt, SessionStateExt}, }; @@ -1065,17 +1068,21 @@ mod test { } #[test] - fn test_is_final_stage_serialization() { - use crate::config::BALLISTA_IS_FINAL_STAGE; - - // Test that is_final_stage is included in key-value pairs + fn test_is_final_stage_internal_setting() { + // Test that is_final_stage is properly stored in BallistaConfig let config = SessionConfig::new_with_ballista().with_ballista_is_final_stage(true); - let pairs = config.to_key_value_pairs(); - let is_final_pair = pairs.iter().find(|p| p.key == BALLISTA_IS_FINAL_STAGE); - assert!(is_final_pair.is_some()); - assert_eq!(is_final_pair.unwrap().value, Some("true".to_string())); + // Verify via the getter + assert!(config.ballista_is_final_stage()); + + // Verify the internal BallistaConfig has the setting + let ballista_config = config + .options() + .extensions + .get::() + .expect("BallistaConfig should exist"); + assert!(ballista_config.is_final_stage()); } #[test] From 3c333954493f986cf8f7adda11970c4a95ebdf79 Mon Sep 17 00:00:00 2001 From: Luke Kim <80174+lukekim@users.noreply.github.com> Date: Thu, 22 Jan 2026 21:37:48 -0800 Subject: [PATCH 6/6] Remove build-binary --- Cargo.lock | 1 - ballista/core/Cargo.toml | 8 +++---- ballista/core/src/config.rs | 24 ++++++++++++++----- .../src/execution_plans/shuffle_reader.rs | 15 ------------ ballista/core/src/lib.rs | 2 -- ballista/executor/Cargo.toml | 2 +- ballista/scheduler/Cargo.toml | 2 +- 7 files changed, 23 insertions(+), 31 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 1ea71540fc..4b03ce4ed3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1001,7 +1001,6 @@ dependencies = [ "aws-credential-types", "bytes", "chrono", - "clap 4.5.54", "dashmap", "datafusion", "datafusion-proto", diff --git a/ballista/core/Cargo.toml b/ballista/core/Cargo.toml index 7a21172787..068d7101e4 100644 --- a/ballista/core/Cargo.toml +++ b/ballista/core/Cargo.toml @@ -35,7 +35,6 @@ exclude = ["*.proto"] rustc-args = ["--cfg", "docsrs"] [features] -build-binary = ["aws-config", "aws-credential-types", "clap", "object_store"] docsrs = [] # Used for testing ONLY: causes all values to hash to the same value (test for collisions) force_hash_collisions = ["datafusion/force_hash_collisions"] @@ -45,11 +44,10 @@ vortex = ["vortex-array", "vortex-buffer", "vortex-dtype", "vortex-error", "vort [dependencies] arrow-flight = { workspace = true } async-trait = { workspace = true } -aws-config = { version = "1.6.0", optional = true } -aws-credential-types = { version = "1.2.0", optional = true } +aws-config = { version = "1.6.0" } +aws-credential-types = { version = "1.2.0" } bytes = { workspace = true } chrono = { version = "0.4", default-features = false } -clap = { workspace = true, optional = true } dashmap = "6" datafusion = { workspace = true } datafusion-proto = { workspace = true } @@ -58,7 +56,7 @@ futures = { workspace = true } itertools = "0.14" log = { workspace = true } md-5 = { version = "^0.10.0" } -object_store = { workspace = true, features = ["aws", "azure", "http"], optional = true } +object_store = { workspace = true, features = ["aws", "azure", "http"] } parking_lot = { workspace = true } prost = { workspace = true } prost-types = { workspace = true } diff --git a/ballista/core/src/config.rs b/ballista/core/src/config.rs index bd44d9ceaa..45c955153f 100644 --- a/ballista/core/src/config.rs +++ b/ballista/core/src/config.rs @@ -468,7 +468,6 @@ impl datafusion::config::ConfigExtension for BallistaConfig { /// Ballista supports both push-based and pull-based task scheduling. /// It is recommended that you try both to determine which is the best for your use case. #[derive(Clone, Copy, Debug, serde::Deserialize, Default)] -#[cfg_attr(feature = "build-binary", derive(clap::ValueEnum))] pub enum TaskSchedulingPolicy { /// Pull-based scheduling works in a similar way to Apache Spark #[default] @@ -485,18 +484,23 @@ impl Display for TaskSchedulingPolicy { } } -#[cfg(feature = "build-binary")] impl std::str::FromStr for TaskSchedulingPolicy { type Err = String; fn from_str(s: &str) -> std::result::Result { - clap::ValueEnum::from_str(s, true) + match s.to_lowercase().as_str() { + "pull-staged" | "pullstaged" => Ok(TaskSchedulingPolicy::PullStaged), + "push-staged" | "pushstaged" => Ok(TaskSchedulingPolicy::PushStaged), + _ => Err(format!( + "Invalid scheduling policy '{}'. Valid options: 'pull-staged', 'push-staged'", + s + )), + } } } /// Configures the log file rotation policy. #[derive(Clone, Copy, Debug, serde::Deserialize, Default)] -#[cfg_attr(feature = "build-binary", derive(clap::ValueEnum))] pub enum LogRotationPolicy { /// Rotate log files every minute. Minutely, @@ -520,12 +524,20 @@ impl Display for LogRotationPolicy { } } -#[cfg(feature = "build-binary")] impl std::str::FromStr for LogRotationPolicy { type Err = String; fn from_str(s: &str) -> std::result::Result { - clap::ValueEnum::from_str(s, true) + match s.to_lowercase().as_str() { + "minutely" => Ok(LogRotationPolicy::Minutely), + "hourly" => Ok(LogRotationPolicy::Hourly), + "daily" => Ok(LogRotationPolicy::Daily), + "never" => Ok(LogRotationPolicy::Never), + _ => Err(format!( + "Invalid rotation policy '{}'. Valid options: 'minutely', 'hourly', 'daily', 'never'", + s + )), + } } } diff --git a/ballista/core/src/execution_plans/shuffle_reader.rs b/ballista/core/src/execution_plans/shuffle_reader.rs index 30b51fbde7..ad45959fd5 100644 --- a/ballista/core/src/execution_plans/shuffle_reader.rs +++ b/ballista/core/src/execution_plans/shuffle_reader.rs @@ -28,13 +28,9 @@ use std::result; use std::sync::Arc; use std::task::{Context, Poll}; -#[cfg(feature = "build-binary")] use object_store::ObjectStore; -#[cfg(feature = "build-binary")] use object_store::aws::AmazonS3Builder; -#[cfg(feature = "build-binary")] use object_store::azure::MicrosoftAzureBuilder; -#[cfg(feature = "build-binary")] use url::Url; use crate::client::BallistaClient; @@ -841,7 +837,6 @@ fn check_is_object_store_location(location: &PartitionLocation) -> bool { path.starts_with("s3://") || path.starts_with("abfs://") || path.starts_with("az://") } -#[cfg(feature = "build-binary")] async fn fetch_partition_object_store( location: &PartitionLocation, ) -> result::Result { @@ -877,16 +872,6 @@ async fn fetch_partition_object_store( Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream))) } -#[cfg(not(feature = "build-binary"))] -async fn fetch_partition_object_store( - _location: &PartitionLocation, -) -> result::Result { - Err(BallistaError::NotImplemented( - "Object store support requires 'build-binary' feature".to_string(), - )) -} - -#[cfg(feature = "build-binary")] async fn fetch_partition_object_store_inner( path: &str, ) -> result::Result, BallistaError> { diff --git a/ballista/core/src/lib.rs b/ballista/core/src/lib.rs index 846e6671ac..a9aa9be1a4 100644 --- a/ballista/core/src/lib.rs +++ b/ballista/core/src/lib.rs @@ -45,7 +45,6 @@ pub mod event_loop; pub mod execution_plans; /// Extension traits and utilities for DataFusion integration. pub mod extension; -#[cfg(feature = "build-binary")] /// Object store configuration and utilities for distributed file access. pub mod object_store; /// Query planning utilities for distributed execution. @@ -57,7 +56,6 @@ pub mod remote_catalog; /// Serialization and deserialization for Ballista messages and plans. pub mod serde; /// Shuffle storage abstraction for local and object store backends. -#[cfg(feature = "build-binary")] pub mod shuffle_storage; /// General utility functions for Ballista operations. pub mod utils; diff --git a/ballista/executor/Cargo.toml b/ballista/executor/Cargo.toml index f0fb6043a7..750bed7f61 100644 --- a/ballista/executor/Cargo.toml +++ b/ballista/executor/Cargo.toml @@ -33,7 +33,7 @@ path = "src/bin/main.rs" required-features = ["build-binary"] [features] -build-binary = ["clap", "tracing-subscriber", "tracing-appender", "tracing", "ballista-core/build-binary"] +build-binary = ["clap", "tracing-subscriber", "tracing-appender", "tracing"] default = ["build-binary", "mimalloc"] vortex = ["ballista-core/vortex", "vortex-array", "vortex-ipc"] diff --git a/ballista/scheduler/Cargo.toml b/ballista/scheduler/Cargo.toml index 883b211d13..9245e1e1b6 100644 --- a/ballista/scheduler/Cargo.toml +++ b/ballista/scheduler/Cargo.toml @@ -33,7 +33,7 @@ path = "src/bin/main.rs" required-features = ["build-binary"] [features] -build-binary = ["clap", "tracing-subscriber", "tracing-appender", "tracing", "ballista-core/build-binary"] +build-binary = ["clap", "tracing-subscriber", "tracing-appender", "tracing"] default = ["build-binary", "substrait"] # job info can cache stage plans, in some cases where # task plans can be re-computed, cache behavior may need to be disabled.