Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ datafusion-cli = "51.0.0"
datafusion-proto = "51.0.0"
datafusion-proto-common = "51.0.0"
datafusion-substrait = "51.0.0"
object_store = "0.12"
object_store = { version = "0.12", features = ["aws", "azure"] }
bytes = "1.5"
prost = "0.14"
prost-types = "0.14"
rstest = { version = "0.26" }
Expand Down
9 changes: 4 additions & 5 deletions ballista/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ exclude = ["*.proto"]
rustc-args = ["--cfg", "docsrs"]

[features]
build-binary = ["aws-config", "aws-credential-types", "clap", "object_store"]
docsrs = []
# Used for testing ONLY: causes all values to hash to the same value (test for collisions)
force_hash_collisions = ["datafusion/force_hash_collisions"]
Expand All @@ -45,10 +44,10 @@ vortex = ["vortex-array", "vortex-buffer", "vortex-dtype", "vortex-error", "vort
[dependencies]
arrow-flight = { workspace = true }
async-trait = { workspace = true }
aws-config = { version = "1.6.0", optional = true }
aws-credential-types = { version = "1.2.0", optional = true }
aws-config = { version = "1.6.0" }
aws-credential-types = { version = "1.2.0" }
bytes = { workspace = true }
chrono = { version = "0.4", default-features = false }
clap = { workspace = true, optional = true }
dashmap = "6"
datafusion = { workspace = true }
datafusion-proto = { workspace = true }
Expand All @@ -57,7 +56,7 @@ futures = { workspace = true }
itertools = "0.14"
log = { workspace = true }
md-5 = { version = "^0.10.0" }
object_store = { workspace = true, features = ["aws", "http"], optional = true }
object_store = { workspace = true, features = ["aws", "azure", "http"] }
parking_lot = { workspace = true }
prost = { workspace = true }
prost-types = { workspace = true }
Expand Down
70 changes: 58 additions & 12 deletions ballista/core/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,14 @@ pub const BALLISTA_SHUFFLE_READER_FORCE_REMOTE_READ: &str =
/// Configuration key to prefer Flight protocol for remote shuffle reads.
pub const BALLISTA_SHUFFLE_READER_REMOTE_PREFER_FLIGHT: &str =
"ballista.shuffle.remote_read_prefer_flight";
/// Configuration key for shuffle storage type (local, s3, azure).
pub const BALLISTA_SHUFFLE_STORAGE_TYPE: &str = "ballista.shuffle.storage_type";
/// Configuration key for shuffle storage base URL/path.
pub const BALLISTA_SHUFFLE_STORAGE_URL: &str = "ballista.shuffle.storage_url";
/// Configuration key for shuffle storage mode (disk or memory).
pub const BALLISTA_SHUFFLE_MEMORY_MODE: &str = "ballista.shuffle.memory_mode";
/// Configuration key indicating if this is the final output stage.
/// Internal configuration key indicating if this is the final output stage.
/// This is set by the scheduler based on stage topology, NOT user-configurable.
/// When true, shuffle data is always written to disk regardless of memory_mode setting.
pub const BALLISTA_IS_FINAL_STAGE: &str = "ballista.shuffle.is_final_stage";
/// Shuffle format configuration: "arrow_ipc" or "vortex"
Expand Down Expand Up @@ -93,14 +98,21 @@ static CONFIG_ENTRIES: LazyLock<HashMap<String, ConfigEntry>> = LazyLock::new(||
"Forces the shuffle reader to use flight reader instead of block reader for remote read. Block reader usually has better performance and resource utilization".to_string(),
DataType::Boolean,
Some((false).to_string())),
ConfigEntry::new(BALLISTA_SHUFFLE_STORAGE_TYPE.to_string(),
"Storage type for shuffle data: 'local' (default), 's3', or 'azure'".to_string(),
DataType::Utf8,
Some("local".to_string())),
ConfigEntry::new(BALLISTA_SHUFFLE_STORAGE_URL.to_string(),
"Base URL/path for shuffle storage. For local: file path; For S3: s3://bucket/prefix; For Azure: abfs://container@account.dfs.core.windows.net/prefix".to_string(),
DataType::Utf8,
None),
ConfigEntry::new(BALLISTA_SHUFFLE_MEMORY_MODE.to_string(),
"When enabled, shuffle data is kept in memory on executors instead of being written to disk. This can improve performance for workloads with sufficient memory.".to_string(),
DataType::Boolean,
Some((false).to_string())),
ConfigEntry::new(BALLISTA_IS_FINAL_STAGE.to_string(),
"When true, indicates this is the final output stage. Final stages always write to disk regardless of memory_mode setting to ensure proper cleanup.".to_string(),
DataType::Boolean,
Some((false).to_string())),
// Note: BALLISTA_IS_FINAL_STAGE is intentionally NOT in CONFIG_ENTRIES.
// It's an internal flag set by the scheduler based on stage topology,
// not a user-configurable setting.
ConfigEntry::new(BALLISTA_GRPC_CLIENT_CONNECT_TIMEOUT_SECONDS.to_string(),
"Connection timeout for gRPC client in seconds".to_string(),
DataType::UInt64,
Expand Down Expand Up @@ -320,6 +332,16 @@ impl BallistaConfig {
self.get_bool_setting(BALLISTA_SHUFFLE_READER_REMOTE_PREFER_FLIGHT)
}

/// Returns the shuffle storage type (local, s3, azure).
pub fn shuffle_storage_type(&self) -> String {
self.get_string_setting(BALLISTA_SHUFFLE_STORAGE_TYPE)
}

/// Returns the shuffle storage base URL/path if configured.
pub fn shuffle_storage_url(&self) -> Option<String> {
self.settings.get(BALLISTA_SHUFFLE_STORAGE_URL).cloned()
}

/// Returns whether in-memory shuffle mode is enabled.
///
/// When enabled, shuffle data is kept in memory on executors instead of
Expand All @@ -330,9 +352,21 @@ impl BallistaConfig {
}

/// Returns whether this is the final output stage.
/// This is an internal flag set by the scheduler, not user-configurable.
/// Final stages always write to disk regardless of memory_mode setting.
pub fn is_final_stage(&self) -> bool {
self.get_bool_setting(BALLISTA_IS_FINAL_STAGE)
self.settings
.get(BALLISTA_IS_FINAL_STAGE)
.and_then(|v| v.parse::<bool>().ok())
.unwrap_or(false)
}

/// Sets the internal is_final_stage flag.
/// This should only be called by the scheduler when creating task configurations.
pub fn with_is_final_stage(mut self, is_final: bool) -> Self {
self.settings
.insert(BALLISTA_IS_FINAL_STAGE.to_string(), is_final.to_string());
self
}

/// Returns the configured shuffle format (ArrowIpc or Vortex)
Expand Down Expand Up @@ -434,7 +468,6 @@ impl datafusion::config::ConfigExtension for BallistaConfig {
/// Ballista supports both push-based and pull-based task scheduling.
/// It is recommended that you try both to determine which is the best for your use case.
#[derive(Clone, Copy, Debug, serde::Deserialize, Default)]
#[cfg_attr(feature = "build-binary", derive(clap::ValueEnum))]
pub enum TaskSchedulingPolicy {
/// Pull-based scheduling works in a similar way to Apache Spark
#[default]
Expand All @@ -451,18 +484,23 @@ impl Display for TaskSchedulingPolicy {
}
}

#[cfg(feature = "build-binary")]
impl std::str::FromStr for TaskSchedulingPolicy {
type Err = String;

fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
clap::ValueEnum::from_str(s, true)
match s.to_lowercase().as_str() {
"pull-staged" | "pullstaged" => Ok(TaskSchedulingPolicy::PullStaged),
"push-staged" | "pushstaged" => Ok(TaskSchedulingPolicy::PushStaged),
_ => Err(format!(
"Invalid scheduling policy '{}'. Valid options: 'pull-staged', 'push-staged'",
s
)),
}
}
}

/// Configures the log file rotation policy.
#[derive(Clone, Copy, Debug, serde::Deserialize, Default)]
#[cfg_attr(feature = "build-binary", derive(clap::ValueEnum))]
pub enum LogRotationPolicy {
/// Rotate log files every minute.
Minutely,
Expand All @@ -486,12 +524,20 @@ impl Display for LogRotationPolicy {
}
}

#[cfg(feature = "build-binary")]
impl std::str::FromStr for LogRotationPolicy {
type Err = String;

fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
clap::ValueEnum::from_str(s, true)
match s.to_lowercase().as_str() {
"minutely" => Ok(LogRotationPolicy::Minutely),
"hourly" => Ok(LogRotationPolicy::Hourly),
"daily" => Ok(LogRotationPolicy::Daily),
"never" => Ok(LogRotationPolicy::Never),
_ => Err(format!(
"Invalid rotation policy '{}'. Valid options: 'minutely', 'hourly', 'daily', 'never'",
s
)),
}
}
}

Expand Down
146 changes: 137 additions & 9 deletions ballista/core/src/execution_plans/shuffle_reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,17 @@ use std::any::Any;
use std::collections::HashMap;
use std::fmt::Debug;
use std::fs::File;
use std::io::BufReader;
use std::io::{BufReader, Cursor};
use std::pin::Pin;
use std::result;
use std::sync::Arc;
use std::task::{Context, Poll};

use object_store::ObjectStore;
use object_store::aws::AmazonS3Builder;
use object_store::azure::MicrosoftAzureBuilder;
use url::Url;

use crate::client::BallistaClient;
use crate::execution_plans::shuffle_manager::global_shuffle_manager;
use crate::extension::{
Expand Down Expand Up @@ -737,14 +742,6 @@ fn fetch_partition_local_vortex(
Ok(Box::pin(stream))
}

async fn fetch_partition_object_store(
_location: &PartitionLocation,
) -> result::Result<SendableRecordBatchStream, BallistaError> {
Err(BallistaError::NotImplemented(
"Should not use ObjectStorePartitionReader".to_string(),
))
}

/// Fetch partition data from in-memory shuffle storage.
///
/// After successfully fetching the data, the partition is removed from memory
Expand Down Expand Up @@ -833,6 +830,137 @@ impl RecordBatchStream for InMemoryShuffleStream {
}
}

/// Check if the location is an object store path (S3 or Azure).
#[allow(dead_code)]
fn check_is_object_store_location(location: &PartitionLocation) -> bool {
let path = location.path.as_str();
path.starts_with("s3://") || path.starts_with("abfs://") || path.starts_with("az://")
}

async fn fetch_partition_object_store(
location: &PartitionLocation,
) -> result::Result<SendableRecordBatchStream, BallistaError> {
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;

let path = &location.path;
let metadata = &location.executor_meta;
let partition_id = &location.partition_id;

debug!("Fetching shuffle partition from object store: {}", path);

let batches = fetch_partition_object_store_inner(path)
.await
.map_err(|e| {
// return BallistaError::FetchFailed may let scheduler retry this task.
BallistaError::FetchFailed(
metadata.id.clone(),
partition_id.stage_id,
partition_id.partition_id,
e.to_string(),
)
})?;

if batches.is_empty() {
return Err(BallistaError::General(format!(
"No batches found in shuffle partition at {}",
path
)));
}

let schema = batches[0].schema();
let stream = futures::stream::iter(batches.into_iter().map(Ok));
Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
}

async fn fetch_partition_object_store_inner(
path: &str,
) -> result::Result<Vec<RecordBatch>, BallistaError> {
use object_store::path::Path as ObjectPath;

let url = Url::parse(path).map_err(|e| {
BallistaError::General(format!(
"Failed to parse object store URL '{}': {:?}",
path, e
))
})?;

let scheme = url.scheme();
let store: Arc<dyn ObjectStore> = match scheme {
"s3" => {
let bucket = url.host_str().ok_or_else(|| {
BallistaError::General(format!("No bucket in S3 URL: {}", path))
})?;
let builder = AmazonS3Builder::from_env().with_bucket_name(bucket);
Arc::new(builder.build().map_err(|e| {
BallistaError::General(format!("Failed to create S3 client: {:?}", e))
})?)
}
"abfs" | "az" => {
// Parse Azure URL: abfs://container@account.dfs.core.windows.net/path
let host = url.host_str().ok_or_else(|| {
BallistaError::General(format!("No host in Azure URL: {}", path))
})?;

// Extract container from username portion
let container = url.username();
if container.is_empty() {
return Err(BallistaError::General(format!(
"No container in Azure URL. Expected format: abfs://container@account.dfs.core.windows.net/path. Got: {}",
path
)));
}

// Extract account from host (account.dfs.core.windows.net)
let account = host.split('.').next().ok_or_else(|| {
BallistaError::General(format!("No account in Azure URL: {}", path))
})?;

let builder = MicrosoftAzureBuilder::from_env()
.with_account(account)
.with_container_name(container);
Arc::new(builder.build().map_err(|e| {
BallistaError::General(format!("Failed to create Azure client: {:?}", e))
})?)
}
_ => {
return Err(BallistaError::General(format!(
"Unsupported object store scheme: {}. Supported: s3, abfs, az",
scheme
)));
}
};

// Extract the object path from the URL
let object_path = ObjectPath::from(url.path().trim_start_matches('/'));

debug!("Reading object from path: {:?}", object_path);

let get_result = store.get(&object_path).await.map_err(|e| {
BallistaError::General(format!("Failed to read object from {}: {:?}", path, e))
})?;

let bytes = get_result.bytes().await.map_err(|e| {
BallistaError::General(format!("Failed to read bytes from {}: {:?}", path, e))
})?;

let cursor = Cursor::new(bytes.to_vec());
let stream_reader = StreamReader::try_new(cursor, None).map_err(|e| {
BallistaError::General(format!(
"Failed to create Arrow stream reader for {}: {:?}",
path, e
))
})?;

let mut batches = Vec::new();
for batch_result in stream_reader {
batches.push(batch_result.map_err(|e| {
BallistaError::General(format!("Failed to read batch from {}: {:?}", path, e))
})?);
}

Ok(batches)
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
Loading