diff --git a/Cargo.lock b/Cargo.lock index 4920368d7c..4b03ce4ed3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -999,8 +999,8 @@ dependencies = [ "async-trait", "aws-config", "aws-credential-types", + "bytes", "chrono", - "clap 4.5.54", "dashmap", "datafusion", "datafusion-proto", @@ -3968,6 +3968,7 @@ dependencies = [ "futures", "http 1.4.0", "http-body-util", + "httparse", "humantime", "hyper", "itertools", diff --git a/Cargo.toml b/Cargo.toml index a0aade4dc5..72b8fcfbdb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -40,7 +40,8 @@ datafusion-cli = "51.0.0" datafusion-proto = "51.0.0" datafusion-proto-common = "51.0.0" datafusion-substrait = "51.0.0" -object_store = "0.12" +object_store = { version = "0.12", features = ["aws", "azure"] } +bytes = "1.5" prost = "0.14" prost-types = "0.14" rstest = { version = "0.26" } diff --git a/ballista/core/Cargo.toml b/ballista/core/Cargo.toml index 65fc6724dc..068d7101e4 100644 --- a/ballista/core/Cargo.toml +++ b/ballista/core/Cargo.toml @@ -35,7 +35,6 @@ exclude = ["*.proto"] rustc-args = ["--cfg", "docsrs"] [features] -build-binary = ["aws-config", "aws-credential-types", "clap", "object_store"] docsrs = [] # Used for testing ONLY: causes all values to hash to the same value (test for collisions) force_hash_collisions = ["datafusion/force_hash_collisions"] @@ -45,10 +44,10 @@ vortex = ["vortex-array", "vortex-buffer", "vortex-dtype", "vortex-error", "vort [dependencies] arrow-flight = { workspace = true } async-trait = { workspace = true } -aws-config = { version = "1.6.0", optional = true } -aws-credential-types = { version = "1.2.0", optional = true } +aws-config = { version = "1.6.0" } +aws-credential-types = { version = "1.2.0" } +bytes = { workspace = true } chrono = { version = "0.4", default-features = false } -clap = { workspace = true, optional = true } dashmap = "6" datafusion = { workspace = true } datafusion-proto = { workspace = true } @@ -57,7 +56,7 @@ futures = { workspace = true } itertools = "0.14" log = { workspace = true } md-5 = { version = "^0.10.0" } -object_store = { workspace = true, features = ["aws", "http"], optional = true } +object_store = { workspace = true, features = ["aws", "azure", "http"] } parking_lot = { workspace = true } prost = { workspace = true } prost-types = { workspace = true } diff --git a/ballista/core/src/config.rs b/ballista/core/src/config.rs index 314ad05cb2..45c955153f 100644 --- a/ballista/core/src/config.rs +++ b/ballista/core/src/config.rs @@ -44,9 +44,14 @@ pub const BALLISTA_SHUFFLE_READER_FORCE_REMOTE_READ: &str = /// Configuration key to prefer Flight protocol for remote shuffle reads. pub const BALLISTA_SHUFFLE_READER_REMOTE_PREFER_FLIGHT: &str = "ballista.shuffle.remote_read_prefer_flight"; +/// Configuration key for shuffle storage type (local, s3, azure). +pub const BALLISTA_SHUFFLE_STORAGE_TYPE: &str = "ballista.shuffle.storage_type"; +/// Configuration key for shuffle storage base URL/path. +pub const BALLISTA_SHUFFLE_STORAGE_URL: &str = "ballista.shuffle.storage_url"; /// Configuration key for shuffle storage mode (disk or memory). pub const BALLISTA_SHUFFLE_MEMORY_MODE: &str = "ballista.shuffle.memory_mode"; -/// Configuration key indicating if this is the final output stage. +/// Internal configuration key indicating if this is the final output stage. +/// This is set by the scheduler based on stage topology, NOT user-configurable. /// When true, shuffle data is always written to disk regardless of memory_mode setting. pub const BALLISTA_IS_FINAL_STAGE: &str = "ballista.shuffle.is_final_stage"; /// Shuffle format configuration: "arrow_ipc" or "vortex" @@ -93,14 +98,21 @@ static CONFIG_ENTRIES: LazyLock> = LazyLock::new(|| "Forces the shuffle reader to use flight reader instead of block reader for remote read. Block reader usually has better performance and resource utilization".to_string(), DataType::Boolean, Some((false).to_string())), + ConfigEntry::new(BALLISTA_SHUFFLE_STORAGE_TYPE.to_string(), + "Storage type for shuffle data: 'local' (default), 's3', or 'azure'".to_string(), + DataType::Utf8, + Some("local".to_string())), + ConfigEntry::new(BALLISTA_SHUFFLE_STORAGE_URL.to_string(), + "Base URL/path for shuffle storage. For local: file path; For S3: s3://bucket/prefix; For Azure: abfs://container@account.dfs.core.windows.net/prefix".to_string(), + DataType::Utf8, + None), ConfigEntry::new(BALLISTA_SHUFFLE_MEMORY_MODE.to_string(), "When enabled, shuffle data is kept in memory on executors instead of being written to disk. This can improve performance for workloads with sufficient memory.".to_string(), DataType::Boolean, Some((false).to_string())), - ConfigEntry::new(BALLISTA_IS_FINAL_STAGE.to_string(), - "When true, indicates this is the final output stage. Final stages always write to disk regardless of memory_mode setting to ensure proper cleanup.".to_string(), - DataType::Boolean, - Some((false).to_string())), + // Note: BALLISTA_IS_FINAL_STAGE is intentionally NOT in CONFIG_ENTRIES. + // It's an internal flag set by the scheduler based on stage topology, + // not a user-configurable setting. ConfigEntry::new(BALLISTA_GRPC_CLIENT_CONNECT_TIMEOUT_SECONDS.to_string(), "Connection timeout for gRPC client in seconds".to_string(), DataType::UInt64, @@ -320,6 +332,16 @@ impl BallistaConfig { self.get_bool_setting(BALLISTA_SHUFFLE_READER_REMOTE_PREFER_FLIGHT) } + /// Returns the shuffle storage type (local, s3, azure). + pub fn shuffle_storage_type(&self) -> String { + self.get_string_setting(BALLISTA_SHUFFLE_STORAGE_TYPE) + } + + /// Returns the shuffle storage base URL/path if configured. + pub fn shuffle_storage_url(&self) -> Option { + self.settings.get(BALLISTA_SHUFFLE_STORAGE_URL).cloned() + } + /// Returns whether in-memory shuffle mode is enabled. /// /// When enabled, shuffle data is kept in memory on executors instead of @@ -330,9 +352,21 @@ impl BallistaConfig { } /// Returns whether this is the final output stage. + /// This is an internal flag set by the scheduler, not user-configurable. /// Final stages always write to disk regardless of memory_mode setting. pub fn is_final_stage(&self) -> bool { - self.get_bool_setting(BALLISTA_IS_FINAL_STAGE) + self.settings + .get(BALLISTA_IS_FINAL_STAGE) + .and_then(|v| v.parse::().ok()) + .unwrap_or(false) + } + + /// Sets the internal is_final_stage flag. + /// This should only be called by the scheduler when creating task configurations. + pub fn with_is_final_stage(mut self, is_final: bool) -> Self { + self.settings + .insert(BALLISTA_IS_FINAL_STAGE.to_string(), is_final.to_string()); + self } /// Returns the configured shuffle format (ArrowIpc or Vortex) @@ -434,7 +468,6 @@ impl datafusion::config::ConfigExtension for BallistaConfig { /// Ballista supports both push-based and pull-based task scheduling. /// It is recommended that you try both to determine which is the best for your use case. #[derive(Clone, Copy, Debug, serde::Deserialize, Default)] -#[cfg_attr(feature = "build-binary", derive(clap::ValueEnum))] pub enum TaskSchedulingPolicy { /// Pull-based scheduling works in a similar way to Apache Spark #[default] @@ -451,18 +484,23 @@ impl Display for TaskSchedulingPolicy { } } -#[cfg(feature = "build-binary")] impl std::str::FromStr for TaskSchedulingPolicy { type Err = String; fn from_str(s: &str) -> std::result::Result { - clap::ValueEnum::from_str(s, true) + match s.to_lowercase().as_str() { + "pull-staged" | "pullstaged" => Ok(TaskSchedulingPolicy::PullStaged), + "push-staged" | "pushstaged" => Ok(TaskSchedulingPolicy::PushStaged), + _ => Err(format!( + "Invalid scheduling policy '{}'. Valid options: 'pull-staged', 'push-staged'", + s + )), + } } } /// Configures the log file rotation policy. #[derive(Clone, Copy, Debug, serde::Deserialize, Default)] -#[cfg_attr(feature = "build-binary", derive(clap::ValueEnum))] pub enum LogRotationPolicy { /// Rotate log files every minute. Minutely, @@ -486,12 +524,20 @@ impl Display for LogRotationPolicy { } } -#[cfg(feature = "build-binary")] impl std::str::FromStr for LogRotationPolicy { type Err = String; fn from_str(s: &str) -> std::result::Result { - clap::ValueEnum::from_str(s, true) + match s.to_lowercase().as_str() { + "minutely" => Ok(LogRotationPolicy::Minutely), + "hourly" => Ok(LogRotationPolicy::Hourly), + "daily" => Ok(LogRotationPolicy::Daily), + "never" => Ok(LogRotationPolicy::Never), + _ => Err(format!( + "Invalid rotation policy '{}'. Valid options: 'minutely', 'hourly', 'daily', 'never'", + s + )), + } } } diff --git a/ballista/core/src/execution_plans/shuffle_reader.rs b/ballista/core/src/execution_plans/shuffle_reader.rs index cd88fcf317..ad45959fd5 100644 --- a/ballista/core/src/execution_plans/shuffle_reader.rs +++ b/ballista/core/src/execution_plans/shuffle_reader.rs @@ -22,12 +22,17 @@ use std::any::Any; use std::collections::HashMap; use std::fmt::Debug; use std::fs::File; -use std::io::BufReader; +use std::io::{BufReader, Cursor}; use std::pin::Pin; use std::result; use std::sync::Arc; use std::task::{Context, Poll}; +use object_store::ObjectStore; +use object_store::aws::AmazonS3Builder; +use object_store::azure::MicrosoftAzureBuilder; +use url::Url; + use crate::client::BallistaClient; use crate::execution_plans::shuffle_manager::global_shuffle_manager; use crate::extension::{ @@ -737,14 +742,6 @@ fn fetch_partition_local_vortex( Ok(Box::pin(stream)) } -async fn fetch_partition_object_store( - _location: &PartitionLocation, -) -> result::Result { - Err(BallistaError::NotImplemented( - "Should not use ObjectStorePartitionReader".to_string(), - )) -} - /// Fetch partition data from in-memory shuffle storage. /// /// After successfully fetching the data, the partition is removed from memory @@ -833,6 +830,137 @@ impl RecordBatchStream for InMemoryShuffleStream { } } +/// Check if the location is an object store path (S3 or Azure). +#[allow(dead_code)] +fn check_is_object_store_location(location: &PartitionLocation) -> bool { + let path = location.path.as_str(); + path.starts_with("s3://") || path.starts_with("abfs://") || path.starts_with("az://") +} + +async fn fetch_partition_object_store( + location: &PartitionLocation, +) -> result::Result { + use datafusion::physical_plan::stream::RecordBatchStreamAdapter; + + let path = &location.path; + let metadata = &location.executor_meta; + let partition_id = &location.partition_id; + + debug!("Fetching shuffle partition from object store: {}", path); + + let batches = fetch_partition_object_store_inner(path) + .await + .map_err(|e| { + // return BallistaError::FetchFailed may let scheduler retry this task. + BallistaError::FetchFailed( + metadata.id.clone(), + partition_id.stage_id, + partition_id.partition_id, + e.to_string(), + ) + })?; + + if batches.is_empty() { + return Err(BallistaError::General(format!( + "No batches found in shuffle partition at {}", + path + ))); + } + + let schema = batches[0].schema(); + let stream = futures::stream::iter(batches.into_iter().map(Ok)); + Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream))) +} + +async fn fetch_partition_object_store_inner( + path: &str, +) -> result::Result, BallistaError> { + use object_store::path::Path as ObjectPath; + + let url = Url::parse(path).map_err(|e| { + BallistaError::General(format!( + "Failed to parse object store URL '{}': {:?}", + path, e + )) + })?; + + let scheme = url.scheme(); + let store: Arc = match scheme { + "s3" => { + let bucket = url.host_str().ok_or_else(|| { + BallistaError::General(format!("No bucket in S3 URL: {}", path)) + })?; + let builder = AmazonS3Builder::from_env().with_bucket_name(bucket); + Arc::new(builder.build().map_err(|e| { + BallistaError::General(format!("Failed to create S3 client: {:?}", e)) + })?) + } + "abfs" | "az" => { + // Parse Azure URL: abfs://container@account.dfs.core.windows.net/path + let host = url.host_str().ok_or_else(|| { + BallistaError::General(format!("No host in Azure URL: {}", path)) + })?; + + // Extract container from username portion + let container = url.username(); + if container.is_empty() { + return Err(BallistaError::General(format!( + "No container in Azure URL. Expected format: abfs://container@account.dfs.core.windows.net/path. Got: {}", + path + ))); + } + + // Extract account from host (account.dfs.core.windows.net) + let account = host.split('.').next().ok_or_else(|| { + BallistaError::General(format!("No account in Azure URL: {}", path)) + })?; + + let builder = MicrosoftAzureBuilder::from_env() + .with_account(account) + .with_container_name(container); + Arc::new(builder.build().map_err(|e| { + BallistaError::General(format!("Failed to create Azure client: {:?}", e)) + })?) + } + _ => { + return Err(BallistaError::General(format!( + "Unsupported object store scheme: {}. Supported: s3, abfs, az", + scheme + ))); + } + }; + + // Extract the object path from the URL + let object_path = ObjectPath::from(url.path().trim_start_matches('/')); + + debug!("Reading object from path: {:?}", object_path); + + let get_result = store.get(&object_path).await.map_err(|e| { + BallistaError::General(format!("Failed to read object from {}: {:?}", path, e)) + })?; + + let bytes = get_result.bytes().await.map_err(|e| { + BallistaError::General(format!("Failed to read bytes from {}: {:?}", path, e)) + })?; + + let cursor = Cursor::new(bytes.to_vec()); + let stream_reader = StreamReader::try_new(cursor, None).map_err(|e| { + BallistaError::General(format!( + "Failed to create Arrow stream reader for {}: {:?}", + path, e + )) + })?; + + let mut batches = Vec::new(); + for batch_result in stream_reader { + batches.push(batch_result.map_err(|e| { + BallistaError::General(format!("Failed to read batch from {}: {:?}", path, e)) + })?); + } + + Ok(batches) +} + #[cfg(test)] mod tests { use super::*; diff --git a/ballista/core/src/extension.rs b/ballista/core/src/extension.rs index 8cb4a4ebbd..7ca09f2647 100644 --- a/ballista/core/src/extension.rs +++ b/ballista/core/src/extension.rs @@ -16,11 +16,11 @@ // under the License. use crate::config::{ - BALLISTA_GRPC_CLIENT_MAX_MESSAGE_SIZE, BALLISTA_IS_FINAL_STAGE, BALLISTA_JOB_NAME, - BALLISTA_SHUFFLE_FORMAT, BALLISTA_SHUFFLE_MEMORY_MODE, - BALLISTA_SHUFFLE_READER_FORCE_REMOTE_READ, BALLISTA_SHUFFLE_READER_MAX_REQUESTS, - BALLISTA_SHUFFLE_READER_REMOTE_PREFER_FLIGHT, BALLISTA_STANDALONE_PARALLELISM, - BallistaConfig, ShuffleFormat, + BALLISTA_GRPC_CLIENT_MAX_MESSAGE_SIZE, BALLISTA_JOB_NAME, BALLISTA_SHUFFLE_FORMAT, + BALLISTA_SHUFFLE_MEMORY_MODE, BALLISTA_SHUFFLE_READER_FORCE_REMOTE_READ, + BALLISTA_SHUFFLE_READER_MAX_REQUESTS, BALLISTA_SHUFFLE_READER_REMOTE_PREFER_FLIGHT, + BALLISTA_SHUFFLE_STORAGE_TYPE, BALLISTA_SHUFFLE_STORAGE_URL, + BALLISTA_STANDALONE_PARALLELISM, BallistaConfig, ShuffleFormat, }; use crate::planner::BallistaQueryPlanner; use crate::serde::protobuf::KeyValuePair; @@ -192,6 +192,18 @@ pub trait SessionConfigExt { /// Get whether to use TLS for executor connections fn ballista_use_tls(&self) -> bool; + /// Returns the shuffle storage type (local, s3, azure). + fn ballista_shuffle_storage_type(&self) -> String; + + /// Sets the shuffle storage type. + fn with_ballista_shuffle_storage_type(self, storage_type: &str) -> Self; + + /// Returns the shuffle storage base URL/path if configured. + fn ballista_shuffle_storage_url(&self) -> Option; + + /// Sets the shuffle storage base URL/path. + fn with_ballista_shuffle_storage_url(self, url: &str) -> Self; + /// Get the shuffle format (ArrowIpc or Vortex) /// /// Note: Vortex format requires the 'vortex' feature to be enabled. @@ -500,16 +512,20 @@ impl SessionConfigExt for SessionConfig { .extensions .get::() .map(|c| c.is_final_stage()) - .unwrap_or_else(|| BallistaConfig::default().is_final_stage()) + .unwrap_or(false) } fn with_ballista_is_final_stage(self, is_final: bool) -> Self { - if self.options().extensions.get::().is_some() { - self.set_bool(BALLISTA_IS_FINAL_STAGE, is_final) - } else { - self.with_option_extension(BallistaConfig::default()) - .set_bool(BALLISTA_IS_FINAL_STAGE, is_final) - } + // is_final_stage is an internal flag, not a user-configurable setting, + // so we modify the BallistaConfig directly instead of using set_bool + let ballista_config = self + .options() + .extensions + .get::() + .cloned() + .unwrap_or_default() + .with_is_final_stage(is_final); + self.with_option_extension(ballista_config) } fn with_ballista_grpc_metadata(self, metadata: HashMap) -> Self { @@ -550,6 +566,39 @@ impl SessionConfigExt for SessionConfig { .unwrap_or(false) } + fn ballista_shuffle_storage_type(&self) -> String { + self.options() + .extensions + .get::() + .map(|c| c.shuffle_storage_type()) + .unwrap_or_else(|| BallistaConfig::default().shuffle_storage_type()) + } + + fn with_ballista_shuffle_storage_type(self, storage_type: &str) -> Self { + if self.options().extensions.get::().is_some() { + self.set_str(BALLISTA_SHUFFLE_STORAGE_TYPE, storage_type) + } else { + self.with_option_extension(BallistaConfig::default()) + .set_str(BALLISTA_SHUFFLE_STORAGE_TYPE, storage_type) + } + } + + fn ballista_shuffle_storage_url(&self) -> Option { + self.options() + .extensions + .get::() + .and_then(|c| c.shuffle_storage_url()) + } + + fn with_ballista_shuffle_storage_url(self, url: &str) -> Self { + if self.options().extensions.get::().is_some() { + self.set_str(BALLISTA_SHUFFLE_STORAGE_URL, url) + } else { + self.with_option_extension(BallistaConfig::default()) + .set_str(BALLISTA_SHUFFLE_STORAGE_URL, url) + } + } + fn ballista_shuffle_format(&self) -> ShuffleFormat { self.options() .extensions @@ -934,7 +983,7 @@ mod test { }; use crate::{ - config::BALLISTA_JOB_NAME, + config::{BALLISTA_JOB_NAME, BallistaConfig}, extension::{SessionConfigExt, SessionConfigHelperExt, SessionStateExt}, }; @@ -1019,17 +1068,21 @@ mod test { } #[test] - fn test_is_final_stage_serialization() { - use crate::config::BALLISTA_IS_FINAL_STAGE; - - // Test that is_final_stage is included in key-value pairs + fn test_is_final_stage_internal_setting() { + // Test that is_final_stage is properly stored in BallistaConfig let config = SessionConfig::new_with_ballista().with_ballista_is_final_stage(true); - let pairs = config.to_key_value_pairs(); - let is_final_pair = pairs.iter().find(|p| p.key == BALLISTA_IS_FINAL_STAGE); - assert!(is_final_pair.is_some()); - assert_eq!(is_final_pair.unwrap().value, Some("true".to_string())); + // Verify via the getter + assert!(config.ballista_is_final_stage()); + + // Verify the internal BallistaConfig has the setting + let ballista_config = config + .options() + .extensions + .get::() + .expect("BallistaConfig should exist"); + assert!(ballista_config.is_final_stage()); } #[test] diff --git a/ballista/core/src/lib.rs b/ballista/core/src/lib.rs index 58e6d80a16..a9aa9be1a4 100644 --- a/ballista/core/src/lib.rs +++ b/ballista/core/src/lib.rs @@ -45,7 +45,6 @@ pub mod event_loop; pub mod execution_plans; /// Extension traits and utilities for DataFusion integration. pub mod extension; -#[cfg(feature = "build-binary")] /// Object store configuration and utilities for distributed file access. pub mod object_store; /// Query planning utilities for distributed execution. @@ -56,6 +55,8 @@ pub mod registry; pub mod remote_catalog; /// Serialization and deserialization for Ballista messages and plans. pub mod serde; +/// Shuffle storage abstraction for local and object store backends. +pub mod shuffle_storage; /// General utility functions for Ballista operations. pub mod utils; diff --git a/ballista/core/src/shuffle_storage.rs b/ballista/core/src/shuffle_storage.rs new file mode 100644 index 0000000000..dfc917e95d --- /dev/null +++ b/ballista/core/src/shuffle_storage.rs @@ -0,0 +1,767 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Shuffle storage abstraction for storing shuffle data on local disk or object stores. +//! +//! This module provides a unified interface for reading and writing shuffle data +//! to different storage backends including local filesystem, Amazon S3, and Azure Blob Storage. + +use async_trait::async_trait; +use bytes::Bytes; +use datafusion::arrow::datatypes::SchemaRef; +use datafusion::arrow::ipc::CompressionType; +use datafusion::arrow::ipc::reader::StreamReader; +use datafusion::arrow::ipc::writer::IpcWriteOptions; +use datafusion::arrow::ipc::writer::StreamWriter; +use datafusion::arrow::record_batch::RecordBatch; +use datafusion::physical_plan::metrics; +use futures::StreamExt; +use log::{debug, error}; +use object_store::aws::AmazonS3Builder; +use object_store::azure::MicrosoftAzureBuilder; +use object_store::path::Path as ObjectPath; +use object_store::{ObjectStore, PutPayload}; +use std::fmt::{Debug, Display}; +use std::fs::File; +use std::io::{BufReader, Cursor}; +use std::path::PathBuf; +use std::sync::Arc; +use url::Url; + +use crate::error::{BallistaError, Result}; +use crate::serde::scheduler::PartitionStats; + +/// Defines the type of storage to use for shuffle data. +#[derive(Clone, Debug, Default, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +pub enum ShuffleStorageType { + /// Store shuffle data on local disk (default behavior). + #[default] + Local, + /// Store shuffle data in Amazon S3. + S3, + /// Store shuffle data in Azure Blob Storage (ABFS). + Azure, +} + +impl Display for ShuffleStorageType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ShuffleStorageType::Local => write!(f, "local"), + ShuffleStorageType::S3 => write!(f, "s3"), + ShuffleStorageType::Azure => write!(f, "azure"), + } + } +} + +impl std::str::FromStr for ShuffleStorageType { + type Err = String; + + fn from_str(s: &str) -> std::result::Result { + match s.to_lowercase().as_str() { + "local" | "disk" => Ok(ShuffleStorageType::Local), + "s3" | "aws" => Ok(ShuffleStorageType::S3), + "azure" | "abfs" | "adls" => Ok(ShuffleStorageType::Azure), + _ => Err(format!( + "Unknown shuffle storage type: '{}'. Valid options are: local, s3, azure", + s + )), + } + } +} + +/// Configuration for S3 shuffle storage. +#[derive(Clone, Debug, Default)] +pub struct S3ShuffleConfig { + /// S3 bucket name for shuffle data. + pub bucket: Option, + /// AWS region. + pub region: Option, + /// S3 endpoint URL (for MinIO or custom S3-compatible storage). + pub endpoint: Option, + /// AWS access key ID. + pub access_key_id: Option, + /// AWS secret access key. + pub secret_access_key: Option, + /// Allow HTTP connections (default is HTTPS only). + pub allow_http: bool, +} + +/// Configuration for Azure Blob Storage shuffle storage. +#[derive(Clone, Debug, Default)] +pub struct AzureShuffleConfig { + /// Azure storage account name. + pub account: Option, + /// Azure storage container name. + pub container: Option, + /// Azure storage access key. + pub access_key: Option, + /// Azure SAS token (alternative to access key). + pub sas_token: Option, +} + +/// Configuration for shuffle storage. +#[derive(Clone, Debug, Default)] +pub struct ShuffleStorageConfig { + /// The type of storage to use. + pub storage_type: ShuffleStorageType, + /// Base URL/path for shuffle data storage. + /// For local: file path (e.g., /tmp/ballista) + /// For S3: s3://bucket/prefix + /// For Azure: abfs://container@account.dfs.core.windows.net/prefix + pub base_url: Option, + /// S3-specific configuration. + pub s3_config: S3ShuffleConfig, + /// Azure-specific configuration. + pub azure_config: AzureShuffleConfig, +} + +impl ShuffleStorageConfig { + /// Creates a new local storage configuration with the given work directory. + pub fn new_local(work_dir: &str) -> Self { + Self { + storage_type: ShuffleStorageType::Local, + base_url: Some(work_dir.to_string()), + ..Default::default() + } + } + + /// Creates a new S3 storage configuration. + pub fn new_s3(bucket: &str, prefix: Option<&str>, region: Option<&str>) -> Self { + let base_url = match prefix { + Some(p) => format!("s3://{}/{}", bucket, p), + None => format!("s3://{}", bucket), + }; + Self { + storage_type: ShuffleStorageType::S3, + base_url: Some(base_url), + s3_config: S3ShuffleConfig { + bucket: Some(bucket.to_string()), + region: region.map(|s| s.to_string()), + ..Default::default() + }, + ..Default::default() + } + } + + /// Creates a new Azure Blob Storage configuration. + pub fn new_azure(account: &str, container: &str, prefix: Option<&str>) -> Self { + let base_url = match prefix { + Some(p) => format!( + "abfs://{}@{}.dfs.core.windows.net/{}", + container, account, p + ), + None => format!("abfs://{}@{}.dfs.core.windows.net", container, account), + }; + Self { + storage_type: ShuffleStorageType::Azure, + base_url: Some(base_url), + azure_config: AzureShuffleConfig { + account: Some(account.to_string()), + container: Some(container.to_string()), + ..Default::default() + }, + ..Default::default() + } + } +} + +/// Trait for shuffle storage operations. +#[async_trait] +#[allow(clippy::too_many_arguments)] +pub trait ShuffleStorage: Send + Sync + Debug { + /// Write a record batch to storage and return the path where it was written. + async fn write_shuffle_data( + &self, + job_id: &str, + stage_id: usize, + partition_id: usize, + input_partition: usize, + batches: Vec, + schema: SchemaRef, + write_metric: &metrics::Time, + ) -> Result<(String, PartitionStats)>; + + /// Read shuffle data from storage. + async fn read_shuffle_data(&self, path: &str) -> Result>; + + /// Delete shuffle data for a job. + async fn delete_job_data(&self, job_id: &str) -> Result<()>; + + /// Get the base path/URL for this storage. + fn base_path(&self) -> &str; + + /// Check if a path is accessible by this storage backend. + fn can_handle(&self, path: &str) -> bool; +} + +/// Local filesystem shuffle storage implementation. +#[derive(Debug)] +pub struct LocalShuffleStorage { + work_dir: String, +} + +impl LocalShuffleStorage { + /// Creates a new local shuffle storage with the given work directory. + pub fn new(work_dir: &str) -> Self { + Self { + work_dir: work_dir.to_string(), + } + } +} + +#[async_trait] +impl ShuffleStorage for LocalShuffleStorage { + async fn write_shuffle_data( + &self, + job_id: &str, + stage_id: usize, + partition_id: usize, + input_partition: usize, + batches: Vec, + schema: SchemaRef, + write_metric: &metrics::Time, + ) -> Result<(String, PartitionStats)> { + let mut path = PathBuf::from(&self.work_dir); + path.push(job_id); + path.push(format!("{}", stage_id)); + path.push(format!("{}", partition_id)); + std::fs::create_dir_all(&path)?; + + let filename = if input_partition == partition_id { + "data.arrow".to_string() + } else { + format!("data-{}.arrow", input_partition) + }; + path.push(&filename); + + let path_str = path.to_str().unwrap().to_string(); + debug!("Writing shuffle data to local path: {}", path_str); + + let timer = write_metric.timer(); + let file = File::create(&path).map_err(|e| { + error!("Failed to create shuffle file at {}: {:?}", path_str, e); + BallistaError::IoError(e) + })?; + + let options = IpcWriteOptions::default() + .try_with_compression(Some(CompressionType::LZ4_FRAME))?; + + let mut writer = + StreamWriter::try_new_with_options(file, schema.as_ref(), options)?; + + let mut num_rows = 0; + let mut num_batches = 0; + let mut num_bytes = 0; + + for batch in batches { + num_batches += 1; + num_rows += batch.num_rows(); + num_bytes += batch.get_array_memory_size(); + writer.write(&batch)?; + } + + writer.finish()?; + timer.done(); + + let stats = PartitionStats::new( + Some(num_rows as u64), + Some(num_batches), + Some(num_bytes as u64), + ); + + Ok((path_str, stats)) + } + + async fn read_shuffle_data(&self, path: &str) -> Result> { + let file = File::open(path).map_err(|e| { + BallistaError::General(format!( + "Failed to open shuffle file at {}: {:?}", + path, e + )) + })?; + let reader = BufReader::new(file); + let stream_reader = StreamReader::try_new(reader, None)?; + + let mut batches = Vec::new(); + for batch_result in stream_reader { + batches.push(batch_result?); + } + + Ok(batches) + } + + async fn delete_job_data(&self, job_id: &str) -> Result<()> { + let mut path = PathBuf::from(&self.work_dir); + path.push(job_id); + if path.exists() { + std::fs::remove_dir_all(&path)?; + } + Ok(()) + } + + fn base_path(&self) -> &str { + &self.work_dir + } + + fn can_handle(&self, path: &str) -> bool { + // Local storage can handle paths that don't start with a URL scheme + !path.starts_with("s3://") + && !path.starts_with("abfs://") + && !path.starts_with("az://") + } +} + +/// Object store based shuffle storage implementation (for S3 and Azure). +#[derive(Debug)] +pub struct ObjectStoreShuffleStorage { + store: Arc, + base_url: String, + storage_type: ShuffleStorageType, +} + +impl ObjectStoreShuffleStorage { + /// Creates a new S3 shuffle storage. + pub fn new_s3(config: &ShuffleStorageConfig) -> Result { + let s3_config = &config.s3_config; + let bucket = s3_config.bucket.as_ref().ok_or_else(|| { + BallistaError::General("S3 bucket not configured".to_string()) + })?; + + let mut builder = AmazonS3Builder::from_env().with_bucket_name(bucket); + + if let Some(region) = &s3_config.region { + builder = builder.with_region(region); + } + + if let Some(endpoint) = &s3_config.endpoint { + builder = builder.with_endpoint(endpoint); + } + + if let (Some(access_key), Some(secret_key)) = + (&s3_config.access_key_id, &s3_config.secret_access_key) + { + builder = builder + .with_access_key_id(access_key) + .with_secret_access_key(secret_key); + } + + if s3_config.allow_http { + builder = builder.with_allow_http(true); + } + + let store = builder.build().map_err(|e| { + BallistaError::General(format!("Failed to create S3 object store: {:?}", e)) + })?; + + let base_url = config + .base_url + .clone() + .unwrap_or_else(|| format!("s3://{}", bucket)); + + Ok(Self { + store: Arc::new(store), + base_url, + storage_type: ShuffleStorageType::S3, + }) + } + + /// Creates a new Azure Blob Storage shuffle storage. + pub fn new_azure(config: &ShuffleStorageConfig) -> Result { + let azure_config = &config.azure_config; + let account = azure_config.account.as_ref().ok_or_else(|| { + BallistaError::General("Azure storage account not configured".to_string()) + })?; + let container = azure_config.container.as_ref().ok_or_else(|| { + BallistaError::General("Azure storage container not configured".to_string()) + })?; + + let mut builder = MicrosoftAzureBuilder::new() + .with_account(account) + .with_container_name(container); + + if let Some(access_key) = &azure_config.access_key { + builder = builder.with_access_key(access_key); + } + + if let Some(sas_token) = &azure_config.sas_token { + // Parse SAS token into key-value pairs + // SAS token format: ?sv=2021-06-08&ss=bf&srt=sco&... + let query_pairs: Vec<(String, String)> = sas_token + .trim_start_matches('?') + .split('&') + .filter_map(|pair| { + let mut parts = pair.splitn(2, '='); + match (parts.next(), parts.next()) { + (Some(key), Some(value)) => { + Some((key.to_string(), value.to_string())) + } + _ => None, + } + }) + .collect(); + builder = builder.with_sas_authorization(query_pairs); + } + + let store = builder.build().map_err(|e| { + BallistaError::General(format!( + "Failed to create Azure object store: {:?}", + e + )) + })?; + + let base_url = config.base_url.clone().unwrap_or_else(|| { + format!("abfs://{}@{}.dfs.core.windows.net", container, account) + }); + + Ok(Self { + store: Arc::new(store), + base_url, + storage_type: ShuffleStorageType::Azure, + }) + } + + /// Creates object store storage from configuration. + pub fn from_config(config: &ShuffleStorageConfig) -> Result { + match config.storage_type { + ShuffleStorageType::S3 => Self::new_s3(config), + ShuffleStorageType::Azure => Self::new_azure(config), + ShuffleStorageType::Local => Err(BallistaError::General( + "Use LocalShuffleStorage for local storage".to_string(), + )), + } + } + + fn make_path( + &self, + job_id: &str, + stage_id: usize, + partition_id: usize, + input_partition: usize, + ) -> String { + let filename = if input_partition == partition_id { + "data.arrow".to_string() + } else { + format!("data-{}.arrow", input_partition) + }; + format!("{}/{}/{}/{}", job_id, stage_id, partition_id, filename) + } +} + +#[async_trait] +impl ShuffleStorage for ObjectStoreShuffleStorage { + async fn write_shuffle_data( + &self, + job_id: &str, + stage_id: usize, + partition_id: usize, + input_partition: usize, + batches: Vec, + schema: SchemaRef, + write_metric: &metrics::Time, + ) -> Result<(String, PartitionStats)> { + let relative_path = + self.make_path(job_id, stage_id, partition_id, input_partition); + let full_url = format!("{}/{}", self.base_url, relative_path); + + debug!("Writing shuffle data to object store: {}", full_url); + + let timer = write_metric.timer(); + + // Write batches to an in-memory buffer first + let mut buffer = Vec::new(); + let options = IpcWriteOptions::default() + .try_with_compression(Some(CompressionType::LZ4_FRAME))?; + + let (_total_rows, _total_batches) = { + let mut writer = StreamWriter::try_new_with_options( + Cursor::new(&mut buffer), + schema.as_ref(), + options, + )?; + + let mut num_rows = 0; + let mut num_batches = 0; + + for batch in &batches { + num_rows += batch.num_rows(); + num_batches += 1; + writer.write(batch)?; + } + + writer.finish()?; + (num_rows, num_batches) + }; + + let num_bytes = buffer.len(); + + // Upload to object store + let object_path = ObjectPath::from(relative_path); + let payload = PutPayload::from(Bytes::from(buffer)); + + self.store.put(&object_path, payload).await.map_err(|e| { + BallistaError::General(format!( + "Failed to upload shuffle data to {}: {:?}", + full_url, e + )) + })?; + + timer.done(); + + let stats = PartitionStats::new( + Some(batches.iter().map(|b| b.num_rows() as u64).sum()), + Some(batches.len() as u64), + Some(num_bytes as u64), + ); + + Ok((full_url, stats)) + } + + async fn read_shuffle_data(&self, path: &str) -> Result> { + // Extract the object path from the full URL + let object_path = self.extract_object_path(path)?; + + debug!("Reading shuffle data from object store: {}", path); + + let get_result = self.store.get(&object_path).await.map_err(|e| { + BallistaError::General(format!( + "Failed to read shuffle data from {}: {:?}", + path, e + )) + })?; + + let bytes = get_result.bytes().await.map_err(|e| { + BallistaError::General(format!("Failed to read bytes from {}: {:?}", path, e)) + })?; + + let cursor = Cursor::new(bytes.to_vec()); + let stream_reader = StreamReader::try_new(cursor, None)?; + + let mut batches = Vec::new(); + for batch_result in stream_reader { + batches.push(batch_result?); + } + + Ok(batches) + } + + async fn delete_job_data(&self, job_id: &str) -> Result<()> { + let prefix = ObjectPath::from(job_id.to_string()); + + // List all objects with the job_id prefix + let mut list_stream = self.store.list(Some(&prefix)); + let mut objects_to_delete = Vec::new(); + + while let Some(result) = list_stream.next().await { + match result { + Ok(meta) => objects_to_delete.push(meta.location), + Err(e) => { + return Err(BallistaError::General(format!( + "Failed to list objects for job {}: {:?}", + job_id, e + ))); + } + } + } + + // Delete all objects + for path in objects_to_delete { + self.store.delete(&path).await.map_err(|e| { + BallistaError::General(format!( + "Failed to delete object {:?}: {:?}", + path, e + )) + })?; + } + + Ok(()) + } + + fn base_path(&self) -> &str { + &self.base_url + } + + fn can_handle(&self, path: &str) -> bool { + match self.storage_type { + ShuffleStorageType::S3 => path.starts_with("s3://"), + ShuffleStorageType::Azure => { + path.starts_with("abfs://") || path.starts_with("az://") + } + ShuffleStorageType::Local => false, + } + } +} + +impl ObjectStoreShuffleStorage { + fn extract_object_path(&self, path: &str) -> Result { + // Parse the URL and extract the path component + if let Ok(url) = Url::parse(path) { + let path_str = url.path().trim_start_matches('/'); + Ok(ObjectPath::from(path_str)) + } else { + // If it's not a valid URL, assume it's already a relative path + Ok(ObjectPath::from(path)) + } + } +} + +/// Factory for creating shuffle storage instances. +pub struct ShuffleStorageFactory; + +impl ShuffleStorageFactory { + /// Creates a shuffle storage instance based on the configuration. + pub fn create(config: &ShuffleStorageConfig) -> Result> { + match config.storage_type { + ShuffleStorageType::Local => { + let work_dir = config.base_url.as_ref().ok_or_else(|| { + BallistaError::General("Work directory not configured".to_string()) + })?; + Ok(Arc::new(LocalShuffleStorage::new(work_dir))) + } + ShuffleStorageType::S3 => { + Ok(Arc::new(ObjectStoreShuffleStorage::new_s3(config)?)) + } + ShuffleStorageType::Azure => { + Ok(Arc::new(ObjectStoreShuffleStorage::new_azure(config)?)) + } + } + } + + /// Creates a local shuffle storage with the given work directory. + pub fn create_local(work_dir: &str) -> Arc { + Arc::new(LocalShuffleStorage::new(work_dir)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use datafusion::arrow::array::Int32Array; + use datafusion::arrow::datatypes::{DataType, Field, Schema}; + use datafusion::physical_plan::metrics::ExecutionPlanMetricsSet; + use tempfile::TempDir; + + fn create_test_batch() -> (RecordBatch, SchemaRef) { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + ) + .unwrap(); + (batch, schema) + } + + #[tokio::test] + async fn test_local_shuffle_storage_write_read() { + let temp_dir = TempDir::new().unwrap(); + let storage = LocalShuffleStorage::new(temp_dir.path().to_str().unwrap()); + + let (batch, schema) = create_test_batch(); + let metrics = ExecutionPlanMetricsSet::new(); + let time_metric = + metrics::MetricBuilder::new(&metrics).subset_time("write_time", 0); + + let (path, stats) = storage + .write_shuffle_data( + "test_job", + 1, + 0, + 0, + vec![batch.clone()], + schema, + &time_metric, + ) + .await + .unwrap(); + + assert!(path.contains("test_job")); + assert_eq!(stats.num_rows, Some(3)); + assert_eq!(stats.num_batches, Some(1)); + + let read_batches = storage.read_shuffle_data(&path).await.unwrap(); + assert_eq!(read_batches.len(), 1); + assert_eq!(read_batches[0].num_rows(), 3); + } + + #[tokio::test] + async fn test_local_shuffle_storage_delete() { + let temp_dir = TempDir::new().unwrap(); + let storage = LocalShuffleStorage::new(temp_dir.path().to_str().unwrap()); + + let (batch, schema) = create_test_batch(); + let metrics = ExecutionPlanMetricsSet::new(); + let time_metric = + metrics::MetricBuilder::new(&metrics).subset_time("write_time", 0); + + let (path, _) = storage + .write_shuffle_data("test_job", 1, 0, 0, vec![batch], schema, &time_metric) + .await + .unwrap(); + + assert!(std::path::Path::new(&path).exists()); + + storage.delete_job_data("test_job").await.unwrap(); + assert!(!std::path::Path::new(&path).exists()); + } + + #[test] + fn test_shuffle_storage_type_parse() { + assert_eq!( + "local".parse::().unwrap(), + ShuffleStorageType::Local + ); + assert_eq!( + "s3".parse::().unwrap(), + ShuffleStorageType::S3 + ); + assert_eq!( + "azure".parse::().unwrap(), + ShuffleStorageType::Azure + ); + assert_eq!( + "abfs".parse::().unwrap(), + ShuffleStorageType::Azure + ); + } + + #[test] + fn test_storage_config_new_local() { + let config = ShuffleStorageConfig::new_local("/tmp/ballista"); + assert_eq!(config.storage_type, ShuffleStorageType::Local); + assert_eq!(config.base_url, Some("/tmp/ballista".to_string())); + } + + #[test] + fn test_storage_config_new_s3() { + let config = + ShuffleStorageConfig::new_s3("my-bucket", Some("shuffle"), Some("us-east-1")); + assert_eq!(config.storage_type, ShuffleStorageType::S3); + assert_eq!(config.base_url, Some("s3://my-bucket/shuffle".to_string())); + assert_eq!(config.s3_config.bucket, Some("my-bucket".to_string())); + assert_eq!(config.s3_config.region, Some("us-east-1".to_string())); + } + + #[test] + fn test_storage_config_new_azure() { + let config = + ShuffleStorageConfig::new_azure("myaccount", "mycontainer", Some("shuffle")); + assert_eq!(config.storage_type, ShuffleStorageType::Azure); + assert_eq!( + config.base_url, + Some("abfs://mycontainer@myaccount.dfs.core.windows.net/shuffle".to_string()) + ); + } +} diff --git a/ballista/executor/Cargo.toml b/ballista/executor/Cargo.toml index f0fb6043a7..750bed7f61 100644 --- a/ballista/executor/Cargo.toml +++ b/ballista/executor/Cargo.toml @@ -33,7 +33,7 @@ path = "src/bin/main.rs" required-features = ["build-binary"] [features] -build-binary = ["clap", "tracing-subscriber", "tracing-appender", "tracing", "ballista-core/build-binary"] +build-binary = ["clap", "tracing-subscriber", "tracing-appender", "tracing"] default = ["build-binary", "mimalloc"] vortex = ["ballista-core/vortex", "vortex-array", "vortex-ipc"] diff --git a/ballista/scheduler/Cargo.toml b/ballista/scheduler/Cargo.toml index 883b211d13..9245e1e1b6 100644 --- a/ballista/scheduler/Cargo.toml +++ b/ballista/scheduler/Cargo.toml @@ -33,7 +33,7 @@ path = "src/bin/main.rs" required-features = ["build-binary"] [features] -build-binary = ["clap", "tracing-subscriber", "tracing-appender", "tracing", "ballista-core/build-binary"] +build-binary = ["clap", "tracing-subscriber", "tracing-appender", "tracing"] default = ["build-binary", "substrait"] # job info can cache stage plans, in some cases where # task plans can be re-computed, cache behavior may need to be disabled.