diff --git a/Cargo.lock b/Cargo.lock index 3c7478b079..4920368d7c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1001,6 +1001,7 @@ dependencies = [ "aws-credential-types", "chrono", "clap 4.5.54", + "dashmap", "datafusion", "datafusion-proto", "datafusion-proto-common", @@ -1059,6 +1060,7 @@ dependencies = [ "async-trait", "backoff", "ballista-core", + "bytes", "clap 4.5.54", "dashmap", "datafusion", diff --git a/ballista/core/Cargo.toml b/ballista/core/Cargo.toml index 8402c3613b..65fc6724dc 100644 --- a/ballista/core/Cargo.toml +++ b/ballista/core/Cargo.toml @@ -49,6 +49,7 @@ aws-config = { version = "1.6.0", optional = true } aws-credential-types = { version = "1.2.0", optional = true } chrono = { version = "0.4", default-features = false } clap = { workspace = true, optional = true } +dashmap = "6" datafusion = { workspace = true } datafusion-proto = { workspace = true } datafusion-proto-common = { workspace = true } diff --git a/ballista/core/src/config.rs b/ballista/core/src/config.rs index 7510724a41..314ad05cb2 100644 --- a/ballista/core/src/config.rs +++ b/ballista/core/src/config.rs @@ -44,6 +44,11 @@ pub const BALLISTA_SHUFFLE_READER_FORCE_REMOTE_READ: &str = /// Configuration key to prefer Flight protocol for remote shuffle reads. pub const BALLISTA_SHUFFLE_READER_REMOTE_PREFER_FLIGHT: &str = "ballista.shuffle.remote_read_prefer_flight"; +/// Configuration key for shuffle storage mode (disk or memory). +pub const BALLISTA_SHUFFLE_MEMORY_MODE: &str = "ballista.shuffle.memory_mode"; +/// Configuration key indicating if this is the final output stage. +/// When true, shuffle data is always written to disk regardless of memory_mode setting. +pub const BALLISTA_IS_FINAL_STAGE: &str = "ballista.shuffle.is_final_stage"; /// Shuffle format configuration: "arrow_ipc" or "vortex" pub const BALLISTA_SHUFFLE_FORMAT: &str = "ballista.shuffle.format"; @@ -88,6 +93,14 @@ static CONFIG_ENTRIES: LazyLock> = LazyLock::new(|| "Forces the shuffle reader to use flight reader instead of block reader for remote read. Block reader usually has better performance and resource utilization".to_string(), DataType::Boolean, Some((false).to_string())), + ConfigEntry::new(BALLISTA_SHUFFLE_MEMORY_MODE.to_string(), + "When enabled, shuffle data is kept in memory on executors instead of being written to disk. This can improve performance for workloads with sufficient memory.".to_string(), + DataType::Boolean, + Some((false).to_string())), + ConfigEntry::new(BALLISTA_IS_FINAL_STAGE.to_string(), + "When true, indicates this is the final output stage. Final stages always write to disk regardless of memory_mode setting to ensure proper cleanup.".to_string(), + DataType::Boolean, + Some((false).to_string())), ConfigEntry::new(BALLISTA_GRPC_CLIENT_CONNECT_TIMEOUT_SECONDS.to_string(), "Connection timeout for gRPC client in seconds".to_string(), DataType::UInt64, @@ -307,6 +320,21 @@ impl BallistaConfig { self.get_bool_setting(BALLISTA_SHUFFLE_READER_REMOTE_PREFER_FLIGHT) } + /// Returns whether in-memory shuffle mode is enabled. + /// + /// When enabled, shuffle data is kept in memory on executors instead of + /// being written to disk. This can improve performance for workloads + /// with sufficient memory. + pub fn shuffle_memory_mode(&self) -> bool { + self.get_bool_setting(BALLISTA_SHUFFLE_MEMORY_MODE) + } + + /// Returns whether this is the final output stage. + /// Final stages always write to disk regardless of memory_mode setting. + pub fn is_final_stage(&self) -> bool { + self.get_bool_setting(BALLISTA_IS_FINAL_STAGE) + } + /// Returns the configured shuffle format (ArrowIpc or Vortex) /// /// Note: Vortex format requires the 'vortex' feature to be enabled. @@ -477,4 +505,46 @@ mod tests { assert_eq!(16777216, config.default_grpc_client_max_message_size()); Ok(()) } + + #[test] + fn test_is_final_stage_default() { + let config = BallistaConfig::default(); + // Default should be false + assert!(!config.is_final_stage()); + } + + #[test] + fn test_shuffle_memory_mode_default() { + let config = BallistaConfig::default(); + // Default should be false (disk-based shuffles) + assert!(!config.shuffle_memory_mode()); + } + + #[test] + fn test_shuffle_format_default() { + let config = BallistaConfig::default(); + // Default should be ArrowIpc + assert_eq!(config.shuffle_format(), ShuffleFormat::ArrowIpc); + } + + #[test] + fn test_shuffle_format_parsing() { + assert_eq!( + "arrow_ipc".parse::().unwrap(), + ShuffleFormat::ArrowIpc + ); + assert_eq!( + "arrow-ipc".parse::().unwrap(), + ShuffleFormat::ArrowIpc + ); + assert_eq!( + "ipc".parse::().unwrap(), + ShuffleFormat::ArrowIpc + ); + assert_eq!( + "vortex".parse::().unwrap(), + ShuffleFormat::Vortex + ); + assert!("invalid".parse::().is_err()); + } } diff --git a/ballista/core/src/execution_plans/mod.rs b/ballista/core/src/execution_plans/mod.rs index b431c0a1f6..ad33d65743 100644 --- a/ballista/core/src/execution_plans/mod.rs +++ b/ballista/core/src/execution_plans/mod.rs @@ -19,6 +19,7 @@ //! several Ballista executors. mod distributed_query; +mod shuffle_manager; mod shuffle_reader; mod shuffle_writer; mod unresolved_shuffle; @@ -27,6 +28,10 @@ mod unresolved_shuffle; pub mod vortex_shuffle; pub use distributed_query::DistributedQueryExec; +pub use shuffle_manager::{ + InMemoryShuffleManager, ShufflePartitionData, ShufflePartitionKey, + global_shuffle_manager, +}; pub use shuffle_reader::ShuffleReaderExec; pub use shuffle_writer::ShuffleWriterExec; pub use unresolved_shuffle::UnresolvedShuffleExec; diff --git a/ballista/core/src/execution_plans/shuffle_manager.rs b/ballista/core/src/execution_plans/shuffle_manager.rs new file mode 100644 index 0000000000..fdadaff7f5 --- /dev/null +++ b/ballista/core/src/execution_plans/shuffle_manager.rs @@ -0,0 +1,488 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! In-memory shuffle manager for storing shuffle data in executor memory. +//! +//! This module provides a thread-safe in-memory store for shuffle data that +//! can be used as an alternative to disk-based shuffle storage. When enabled, +//! shuffle writers store data in memory and shuffle readers fetch it directly +//! from memory instead of reading from disk. + +use dashmap::DashMap; +use datafusion::arrow::datatypes::SchemaRef; +use datafusion::arrow::record_batch::RecordBatch; +use std::sync::Arc; + +use crate::config::ShuffleFormat; +use crate::error::{BallistaError, Result}; + +/// Key for identifying a shuffle partition in the in-memory store. +/// Format: "{job_id}/{stage_id}/{partition_id}" or "{job_id}/{stage_id}/{output_partition}/{input_partition}" +pub type ShufflePartitionKey = String; + +/// In-memory representation of shuffle data. +/// Supports both Arrow RecordBatch and Vortex formats. +#[derive(Debug, Clone)] +pub enum InMemoryShuffleData { + /// Arrow RecordBatch format (default) + Arrow(Vec), + /// Vortex columnar format (requires 'vortex' feature) + #[cfg(feature = "vortex")] + Vortex(Vec), +} + +/// Data stored for a single shuffle partition. +#[derive(Debug, Clone)] +pub struct ShufflePartitionData { + /// The schema of the record batches + pub schema: SchemaRef, + /// The data for this partition (Arrow or Vortex format) + pub data: InMemoryShuffleData, + /// Total number of rows across all batches + pub num_rows: u64, + /// Total number of batches + pub num_batches: u64, + /// Approximate size in bytes (based on array memory size) + pub num_bytes: u64, + /// The format of the data + pub format: ShuffleFormat, +} + +impl ShufflePartitionData { + /// Creates a new ShufflePartitionData from a schema and Arrow batches. + pub fn new(schema: SchemaRef, batches: Vec) -> Self { + let num_rows: u64 = batches.iter().map(|b| b.num_rows() as u64).sum(); + let num_batches = batches.len() as u64; + let num_bytes: u64 = batches + .iter() + .map(|b| b.get_array_memory_size() as u64) + .sum(); + + Self { + schema, + data: InMemoryShuffleData::Arrow(batches), + num_rows, + num_batches, + num_bytes, + format: ShuffleFormat::ArrowIpc, + } + } + + /// Creates a new ShufflePartitionData from a schema and Vortex arrays. + #[cfg(feature = "vortex")] + pub fn new_vortex( + schema: SchemaRef, + arrays: Vec, + num_rows: u64, + num_bytes: u64, + ) -> Self { + let num_batches = arrays.len() as u64; + + Self { + schema, + data: InMemoryShuffleData::Vortex(arrays), + num_rows, + num_batches, + num_bytes, + format: ShuffleFormat::Vortex, + } + } + + /// Returns the batches if stored in Arrow format, otherwise converts from Vortex. + pub fn to_batches(&self) -> Result> { + match &self.data { + InMemoryShuffleData::Arrow(batches) => Ok(batches.clone()), + #[cfg(feature = "vortex")] + InMemoryShuffleData::Vortex(arrays) => { + use vortex_array::arrow::IntoArrowArray; + arrays + .iter() + .map(|array| { + let arrow_array = + array.clone().into_arrow_preferred().map_err(|e| { + BallistaError::General(format!( + "Failed to convert Vortex array to Arrow: {e}" + )) + })?; + let struct_array = arrow_array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + BallistaError::General( + "Expected StructArray from Vortex conversion" + .to_string(), + ) + })?; + Ok(RecordBatch::from(struct_array)) + }) + .collect() + } + } + } +} + +/// Thread-safe in-memory storage for shuffle partition data. +/// +/// This manager stores shuffle data in memory, keyed by a string that +/// uniquely identifies the shuffle partition (job_id/stage_id/partition_id). +/// It is designed to be shared across all tasks in an executor. +#[derive(Debug, Default)] +pub struct InMemoryShuffleManager { + /// Map from partition key to partition data + partitions: DashMap, +} + +impl InMemoryShuffleManager { + /// Creates a new empty in-memory shuffle manager. + pub fn new() -> Self { + Self { + partitions: DashMap::new(), + } + } + + /// Stores shuffle partition data in memory. + /// + /// # Arguments + /// * `key` - Unique identifier for the partition (e.g., "job_id/stage_id/partition_id") + /// * `data` - The partition data to store + pub fn store_partition(&self, key: ShufflePartitionKey, data: ShufflePartitionData) { + log::debug!( + "Storing shuffle partition in memory: {} ({} batches, {} rows, {} bytes)", + key, + data.num_batches, + data.num_rows, + data.num_bytes + ); + self.partitions.insert(key, data); + } + + /// Retrieves shuffle partition data from memory. + /// + /// # Arguments + /// * `key` - Unique identifier for the partition + /// + /// # Returns + /// * `Ok(ShufflePartitionData)` if the partition exists + /// * `Err(BallistaError)` if the partition is not found + pub fn get_partition(&self, key: &str) -> Result { + self.partitions + .get(key) + .map(|entry| entry.value().clone()) + .ok_or_else(|| { + BallistaError::General(format!( + "Shuffle partition not found in memory: {key}" + )) + }) + } + + /// Checks if a shuffle partition exists in memory. + pub fn contains_partition(&self, key: &str) -> bool { + self.partitions.contains_key(key) + } + + /// Removes a shuffle partition from memory. + /// + /// # Returns + /// The removed partition data if it existed + pub fn remove_partition(&self, key: &str) -> Option { + self.partitions.remove(key).map(|(_, v)| v) + } + + /// Removes all partitions for a given job. + /// + /// # Arguments + /// * `job_id` - The job identifier + pub fn remove_job_partitions(&self, job_id: &str) { + let prefix = format!("{job_id}/"); + self.partitions.retain(|k, _| !k.starts_with(&prefix)); + log::debug!("Removed all shuffle partitions for job: {job_id}"); + } + + /// Removes all partitions for a given stage within a job. + /// + /// This is called when a stage's output has been fully consumed by the next stage, + /// allowing the memory to be reclaimed immediately rather than waiting for job completion. + /// + /// # Arguments + /// * `job_id` - The job identifier + /// * `stage_id` - The stage identifier + /// + /// # Returns + /// The number of partitions that were removed + pub fn remove_stage_partitions(&self, job_id: &str, stage_id: usize) -> usize { + let prefix = format!("{job_id}/{stage_id}/"); + let initial_count = self.partitions.len(); + self.partitions.retain(|k, _| !k.starts_with(&prefix)); + let removed = initial_count - self.partitions.len(); + log::debug!( + "Removed {} shuffle partitions for stage: {}/{}", + removed, + job_id, + stage_id + ); + removed + } + + /// Returns the total number of partitions stored in memory. + pub fn partition_count(&self) -> usize { + self.partitions.len() + } + + /// Returns the approximate total memory usage in bytes. + pub fn total_memory_usage(&self) -> u64 { + self.partitions + .iter() + .map(|entry| entry.value().num_bytes) + .sum() + } + + /// Generates the partition key for a simple partition (no repartitioning). + pub fn partition_key(job_id: &str, stage_id: usize, partition_id: usize) -> String { + format!("{job_id}/{stage_id}/{partition_id}/data") + } + + /// Generates the partition key for a hash-partitioned output. + pub fn hash_partition_key( + job_id: &str, + stage_id: usize, + output_partition: usize, + input_partition: usize, + ) -> String { + format!("{job_id}/{stage_id}/{output_partition}/data-{input_partition}") + } + + /// Clears all stored partitions. + pub fn clear(&self) { + self.partitions.clear(); + log::debug!("Cleared all shuffle partitions from memory"); + } +} + +/// Global in-memory shuffle manager instance. +/// +/// This is a singleton that can be accessed from anywhere in the executor. +/// It is initialized lazily on first access. +static GLOBAL_SHUFFLE_MANAGER: std::sync::LazyLock> = + std::sync::LazyLock::new(|| Arc::new(InMemoryShuffleManager::new())); + +/// Returns the global in-memory shuffle manager instance. +pub fn global_shuffle_manager() -> Arc { + GLOBAL_SHUFFLE_MANAGER.clone() +} + +#[cfg(test)] +mod tests { + use super::*; + use datafusion::arrow::array::{Int32Array, StringArray}; + use datafusion::arrow::datatypes::{DataType, Field, Schema}; + + fn create_test_batch() -> RecordBatch { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, true), + ])); + + RecordBatch::try_new( + schema, + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(StringArray::from(vec!["a", "b", "c"])), + ], + ) + .unwrap() + } + + #[test] + fn test_store_and_retrieve() { + let manager = InMemoryShuffleManager::new(); + let batch = create_test_batch(); + let schema = batch.schema(); + let data = ShufflePartitionData::new(schema.clone(), vec![batch]); + + let key = InMemoryShuffleManager::partition_key("job1", 1, 0); + manager.store_partition(key.clone(), data); + + assert!(manager.contains_partition(&key)); + + let retrieved = manager.get_partition(&key).unwrap(); + assert_eq!(retrieved.num_rows, 3); + assert_eq!(retrieved.num_batches, 1); + let batches = retrieved.to_batches().unwrap(); + assert_eq!(batches.len(), 1); + } + + #[test] + fn test_remove_job_partitions() { + let manager = InMemoryShuffleManager::new(); + let batch = create_test_batch(); + let schema = batch.schema(); + + // Store partitions for two jobs + for job in ["job1", "job2"] { + for stage in 0..2 { + for partition in 0..3 { + let key = + InMemoryShuffleManager::partition_key(job, stage, partition); + let data = + ShufflePartitionData::new(schema.clone(), vec![batch.clone()]); + manager.store_partition(key, data); + } + } + } + + assert_eq!(manager.partition_count(), 12); + + manager.remove_job_partitions("job1"); + assert_eq!(manager.partition_count(), 6); + + // Verify job2 partitions still exist + let key = InMemoryShuffleManager::partition_key("job2", 0, 0); + assert!(manager.contains_partition(&key)); + } + + #[test] + fn test_hash_partition_key() { + let key = InMemoryShuffleManager::hash_partition_key("job1", 1, 2, 3); + assert_eq!(key, "job1/1/2/data-3"); + } + + #[test] + fn test_remove_stage_partitions() { + let manager = InMemoryShuffleManager::new(); + let batch = create_test_batch(); + let schema = batch.schema(); + + // Store partitions for multiple stages in the same job + for stage in 0..3 { + for partition in 0..4 { + let key = InMemoryShuffleManager::partition_key("job1", stage, partition); + let data = ShufflePartitionData::new(schema.clone(), vec![batch.clone()]); + manager.store_partition(key, data); + } + } + + assert_eq!(manager.partition_count(), 12); + + // Remove stage 1 partitions + let removed = manager.remove_stage_partitions("job1", 1); + assert_eq!(removed, 4); + assert_eq!(manager.partition_count(), 8); + + // Verify stage 0 and 2 partitions still exist + let key0 = InMemoryShuffleManager::partition_key("job1", 0, 0); + let key2 = InMemoryShuffleManager::partition_key("job1", 2, 0); + assert!(manager.contains_partition(&key0)); + assert!(manager.contains_partition(&key2)); + + // Verify stage 1 partitions are gone + let key1 = InMemoryShuffleManager::partition_key("job1", 1, 0); + assert!(!manager.contains_partition(&key1)); + } + + #[test] + fn test_remove_stage_partitions_different_jobs() { + let manager = InMemoryShuffleManager::new(); + let batch = create_test_batch(); + let schema = batch.schema(); + + // Store partitions for stage 1 in two different jobs + for job in ["job1", "job2"] { + for partition in 0..3 { + let key = InMemoryShuffleManager::partition_key(job, 1, partition); + let data = ShufflePartitionData::new(schema.clone(), vec![batch.clone()]); + manager.store_partition(key, data); + } + } + + assert_eq!(manager.partition_count(), 6); + + // Remove stage 1 from job1 only + let removed = manager.remove_stage_partitions("job1", 1); + assert_eq!(removed, 3); + assert_eq!(manager.partition_count(), 3); + + // Verify job2 stage 1 partitions still exist + let key = InMemoryShuffleManager::partition_key("job2", 1, 0); + assert!(manager.contains_partition(&key)); + } + + #[test] + fn test_remove_partition_returns_data() { + let manager = InMemoryShuffleManager::new(); + let batch = create_test_batch(); + let schema = batch.schema(); + let data = ShufflePartitionData::new(schema.clone(), vec![batch]); + + let key = InMemoryShuffleManager::partition_key("job1", 1, 0); + manager.store_partition(key.clone(), data); + + assert!(manager.contains_partition(&key)); + + // Remove should return the data + let removed = manager.remove_partition(&key); + assert!(removed.is_some()); + let removed_data = removed.unwrap(); + assert_eq!(removed_data.num_rows, 3); + assert_eq!(removed_data.num_batches, 1); + + // Partition should no longer exist + assert!(!manager.contains_partition(&key)); + + // Second remove should return None + let removed_again = manager.remove_partition(&key); + assert!(removed_again.is_none()); + } + + #[test] + fn test_total_memory_usage() { + let manager = InMemoryShuffleManager::new(); + let batch = create_test_batch(); + let schema = batch.schema(); + + // Store multiple partitions + for i in 0..3 { + let key = InMemoryShuffleManager::partition_key("job1", 1, i); + let data = ShufflePartitionData::new(schema.clone(), vec![batch.clone()]); + manager.store_partition(key, data); + } + + // Memory usage should be > 0 + let usage = manager.total_memory_usage(); + assert!(usage > 0); + + // Remove partitions and verify usage decreases + manager.remove_job_partitions("job1"); + assert_eq!(manager.total_memory_usage(), 0); + } + + #[test] + fn test_clear() { + let manager = InMemoryShuffleManager::new(); + let batch = create_test_batch(); + let schema = batch.schema(); + + for i in 0..5 { + let key = InMemoryShuffleManager::partition_key("job1", 1, i); + let data = ShufflePartitionData::new(schema.clone(), vec![batch.clone()]); + manager.store_partition(key, data); + } + + assert_eq!(manager.partition_count(), 5); + manager.clear(); + assert_eq!(manager.partition_count(), 0); + } +} diff --git a/ballista/core/src/execution_plans/shuffle_reader.rs b/ballista/core/src/execution_plans/shuffle_reader.rs index aac8e17c46..cd88fcf317 100644 --- a/ballista/core/src/execution_plans/shuffle_reader.rs +++ b/ballista/core/src/execution_plans/shuffle_reader.rs @@ -29,6 +29,7 @@ use std::sync::Arc; use std::task::{Context, Poll}; use crate::client::BallistaClient; +use crate::execution_plans::shuffle_manager::global_shuffle_manager; use crate::extension::{ BallistaConfigGrpcEndpoint, SessionConfigExt, ShuffleReadMetricsCallback, }; @@ -379,6 +380,7 @@ impl Stream for AbortableReceiverStream { /// Local partitions are read directly from local Arrow IPC files, /// while remote partitions are fetched using the Arrow Flight client. /// If `force_remote_read` is true, all partitions are treated as remote. +#[allow(dead_code)] fn local_remote_read_split( partition_locations: Vec, force_remote_read: bool, @@ -392,6 +394,34 @@ fn local_remote_read_split( } } +/// Splits partition locations into memory, local disk, and remote categories. +/// Returns (memory_locations, local_locations, remote_locations) +fn split_partition_locations( + partition_locations: Vec, + force_remote_read: bool, +) -> ( + Vec, + Vec, + Vec, +) { + let mut memory_locations = Vec::new(); + let mut local_locations = Vec::new(); + let mut remote_locations = Vec::new(); + + for loc in partition_locations { + if check_is_memory_location(&loc) { + // Memory locations are always read locally + memory_locations.push(loc); + } else if !force_remote_read && check_is_local_location(&loc) { + local_locations.push(loc); + } else { + remote_locations.push(loc); + } + } + + (memory_locations, local_locations, remote_locations) +} + #[allow(clippy::too_many_arguments)] fn send_fetch_partitions( partition_locations: Vec, @@ -407,15 +437,29 @@ fn send_fetch_partitions( let semaphore = Arc::new(Semaphore::new(max_request_num)); let mut spawned_tasks: Vec> = vec![]; - let (local_locations, remote_locations): (Vec<_>, Vec<_>) = - local_remote_read_split(partition_locations, force_remote_read); + let (memory_locations, local_locations, remote_locations) = + split_partition_locations(partition_locations, force_remote_read); debug!( - "local shuffle file counts:{}, remote shuffle file count:{}.", + "memory shuffle partition count: {}, local shuffle file counts: {}, remote shuffle file count: {}.", + memory_locations.len(), local_locations.len(), remote_locations.len() ); + // Read memory partitions first (fastest path) + let response_sender_m = response_sender.clone(); + spawned_tasks.push(SpawnedTask::spawn(async move { + for p in memory_locations { + let r = PartitionReaderEnum::Memory + .fetch_partition(&p, max_message_size, flight_transport, None, false) + .await; + if let Err(e) = response_sender_m.send(r).await { + error!("Fail to send response event to the channel due to {e}"); + } + } + })); + // keep local shuffle files reading in serial order for memory control. let response_sender_c = response_sender.clone(); let customize_endpoint_c = customize_endpoint.clone(); @@ -510,6 +554,11 @@ fn check_is_local_location(location: &PartitionLocation) -> bool { std::path::Path::new(location.path.as_str()).exists() } +/// Check if the partition location is stored in memory +fn check_is_memory_location(location: &PartitionLocation) -> bool { + location.path.starts_with("memory://") +} + /// Partition reader Trait, different partition reader can have #[async_trait] trait PartitionReader: Send + Sync + Clone { @@ -527,6 +576,7 @@ trait PartitionReader: Send + Sync + Clone { #[derive(Clone)] enum PartitionReaderEnum { Local, + Memory, FlightRemote, #[allow(dead_code)] ObjectStoreRemote, @@ -555,6 +605,7 @@ impl PartitionReader for PartitionReaderEnum { .await } PartitionReaderEnum::Local => fetch_partition_local(location).await, + PartitionReaderEnum::Memory => fetch_partition_memory(location).await, PartitionReaderEnum::ObjectStoreRemote => { fetch_partition_object_store(location).await } @@ -694,6 +745,94 @@ async fn fetch_partition_object_store( )) } +/// Fetch partition data from in-memory shuffle storage. +/// +/// After successfully fetching the data, the partition is removed from memory +/// to allow for immediate memory reclamation. This is safe because each shuffle +/// partition is typically read only once by the consuming stage. +async fn fetch_partition_memory( + location: &PartitionLocation, +) -> result::Result { + let path = &location.path; + let metadata = &location.executor_meta; + let partition_id = &location.partition_id; + + // Extract the key from the "memory://{key}" path format + let key = path.strip_prefix("memory://").ok_or_else(|| { + BallistaError::General(format!("Invalid in-memory partition path format: {path}")) + })?; + + let shuffle_manager = global_shuffle_manager(); + + // Remove and retrieve the partition data in one atomic operation + // This ensures the memory is reclaimed as soon as the data is read + let data = shuffle_manager + .remove_partition(key) + .ok_or_else(|| { + // If remove fails, try a regular get (for retry scenarios) + shuffle_manager.get_partition(key).map_err(|e| { + BallistaError::FetchFailed( + metadata.id.clone(), + partition_id.stage_id, + partition_id.partition_id, + e.to_string(), + ) + }) + }) + .or_else(|result| result)?; + + debug!( + "Fetched and removed partition {} from memory: {} batches, {} rows", + key, data.num_batches, data.num_rows + ); + + let batches = data.to_batches().map_err(|e| { + BallistaError::FetchFailed( + metadata.id.clone(), + partition_id.stage_id, + partition_id.partition_id, + format!("Failed to convert in-memory partition to batches: {e}"), + ) + })?; + + Ok(Box::pin(InMemoryShuffleStream::new(data.schema, batches))) +} + +/// Stream that reads from in-memory shuffle data +struct InMemoryShuffleStream { + schema: SchemaRef, + batches: std::vec::IntoIter, +} + +impl InMemoryShuffleStream { + pub fn new(schema: SchemaRef, batches: Vec) -> Self { + Self { + schema, + batches: batches.into_iter(), + } + } +} + +impl Stream for InMemoryShuffleStream { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + _: &mut Context<'_>, + ) -> Poll> { + match self.batches.next() { + Some(batch) => Poll::Ready(Some(Ok(batch))), + None => Poll::Ready(None), + } + } +} + +impl RecordBatchStream for InMemoryShuffleStream { + fn schema(&self) -> SchemaRef { + self.schema.clone() + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/ballista/core/src/execution_plans/shuffle_writer.rs b/ballista/core/src/execution_plans/shuffle_writer.rs index ec281da65e..526b2fe049 100644 --- a/ballista/core/src/execution_plans/shuffle_writer.rs +++ b/ballista/core/src/execution_plans/shuffle_writer.rs @@ -36,6 +36,9 @@ use std::sync::Arc; use std::time::Instant; use crate::config::ShuffleFormat; +use crate::execution_plans::shuffle_manager::{ + InMemoryShuffleManager, ShufflePartitionData, global_shuffle_manager, +}; use crate::extension::SessionConfigExt; use crate::utils; @@ -198,6 +201,16 @@ pub struct WriteTracker { pub path: PathBuf, } +/// Tracker for in-memory shuffle writes. +/// Collects record batches in memory instead of writing to disk. +pub struct InMemoryWriteTracker { + pub num_batches: usize, + pub num_rows: usize, + pub num_bytes: usize, + pub batches: Vec, + pub key: String, +} + #[derive(Debug, Clone)] struct ShuffleWriteMetrics { /// Time spend writing batches to shuffle files @@ -293,6 +306,16 @@ impl ShuffleWriterExec { let write_metrics = ShuffleWriteMetrics::new(input_partition, &self.metrics); let output_partitioning = self.shuffle_output_partitioning.clone(); let plan = self.plan.clone(); + let job_id = self.job_id.clone(); + let stage_id = self.stage_id; + + // Check if memory mode is enabled and this is not the final stage + // Final stages always write to disk to ensure proper cleanup via existing mechanisms + let memory_mode = context.session_config().ballista_shuffle_memory_mode(); + let is_final_stage = context.session_config().ballista_is_final_stage(); + + // Use memory mode only for intermediate stages, not for the final output stage + let use_memory = memory_mode && !is_final_stage; // Get shuffle format from session config let shuffle_format = context.session_config().ballista_shuffle_format(); @@ -302,139 +325,381 @@ impl ShuffleWriterExec { let now = Instant::now(); let mut stream = plan.execute(input_partition, context)?; - match output_partitioning { - None => { - let timer = write_metrics.write_time.timer(); - path.push(format!("{input_partition}")); - std::fs::create_dir_all(&path)?; - path.push(format!("data.{file_ext}")); - let path = path.to_str().unwrap(); - debug!("Writing results to {path} (format: {shuffle_format})"); - - // stream results to disk using configured format - let stats = utils::write_stream_to_disk_with_format( - &mut stream, - path, - &write_metrics.write_time, - shuffle_format, - ) - .await - .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?; - - write_metrics - .input_rows - .add(stats.num_rows.unwrap_or(0) as usize); - write_metrics - .output_rows - .add(stats.num_rows.unwrap_or(0) as usize); - timer.done(); - - info!( - "Executed partition {} in {} seconds. Statistics: {}", - input_partition, - now.elapsed().as_secs(), - stats - ); - - Ok(vec![ShuffleWritePartition { - partition_id: input_partition as u64, - path: path.to_owned(), - num_batches: stats.num_batches.unwrap_or(0), - num_rows: stats.num_rows.unwrap_or(0), - num_bytes: stats.num_bytes.unwrap_or(0), - }]) + if use_memory { + // Use in-memory shuffle storage with configurable format + Self::execute_shuffle_write_memory( + &job_id, + stage_id, + input_partition, + &mut stream, + output_partitioning, + write_metrics, + now, + shuffle_format, + ) + .await + } else { + // Use disk-based shuffle storage with configurable format + // This is used for: + // 1. When memory_mode is disabled + // 2. For final stages (even if memory_mode is enabled) + Self::execute_shuffle_write_disk( + path, + input_partition, + &mut stream, + output_partitioning, + write_metrics, + now, + shuffle_format, + file_ext, + ) + .await + } + } + } + + /// Executes shuffle write to disk (original behavior). + #[allow(clippy::too_many_arguments)] + async fn execute_shuffle_write_disk( + mut path: PathBuf, + input_partition: usize, + stream: &mut std::pin::Pin< + Box, + >, + output_partitioning: Option, + write_metrics: ShuffleWriteMetrics, + now: Instant, + shuffle_format: ShuffleFormat, + file_ext: &str, + ) -> Result> { + match output_partitioning { + None => { + let timer = write_metrics.write_time.timer(); + path.push(format!("{input_partition}")); + std::fs::create_dir_all(&path)?; + path.push(format!("data.{file_ext}")); + let path = path.to_str().unwrap(); + debug!("Writing results to {path} (format: {shuffle_format})"); + + // stream results to disk using configured format + let stats = utils::write_stream_to_disk_with_format( + stream, + path, + &write_metrics.write_time, + shuffle_format, + ) + .await + .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?; + + write_metrics + .input_rows + .add(stats.num_rows.unwrap_or(0) as usize); + write_metrics + .output_rows + .add(stats.num_rows.unwrap_or(0) as usize); + timer.done(); + + info!( + "Executed partition {} in {} seconds. Statistics: {}", + input_partition, + now.elapsed().as_secs(), + stats + ); + + Ok(vec![ShuffleWritePartition { + partition_id: input_partition as u64, + path: path.to_owned(), + num_batches: stats.num_batches.unwrap_or(0), + num_rows: stats.num_rows.unwrap_or(0), + num_bytes: stats.num_bytes.unwrap_or(0), + }]) + } + + Some(Partitioning::Hash(exprs, num_output_partitions)) => { + // we won't necessary produce output for every possible partition, so we + // create writers on demand + let mut writers: Vec> = vec![]; + for _ in 0..num_output_partitions { + writers.push(None); } - Some(Partitioning::Hash(exprs, num_output_partitions)) => { - // we won't necessary produce output for every possible partition, so we - // create writers on demand - let mut writers: Vec> = vec![]; - for _ in 0..num_output_partitions { - writers.push(None); - } + let mut partitioner = BatchPartitioner::try_new( + Partitioning::Hash(exprs, num_output_partitions), + write_metrics.repart_time.clone(), + )?; + + let schema = stream.schema(); + + while let Some(result) = stream.next().await { + let input_batch = result?; - let mut partitioner = BatchPartitioner::try_new( - Partitioning::Hash(exprs, num_output_partitions), - write_metrics.repart_time.clone(), + write_metrics.input_rows.add(input_batch.num_rows()); + + partitioner.partition( + input_batch, + |output_partition, output_batch| { + // partition func in datafusion make sure not write empty output_batch. + let timer = write_metrics.write_time.timer(); + match &mut writers[output_partition] { + Some(w) => { + w.num_batches += 1; + w.num_rows += output_batch.num_rows(); + w.writer.write(&output_batch)?; + } + None => { + let mut file_path = path.clone(); + file_path.push(format!("{output_partition}")); + std::fs::create_dir_all(&file_path)?; + + file_path.push(format!( + "data-{input_partition}.{file_ext}" + )); + debug!("Writing results to {file_path:?} (format: {shuffle_format})"); + + let mut writer = ShuffleFileWriter::try_new( + file_path.clone(), + schema.clone(), + shuffle_format, + )?; + + writer.write(&output_batch)?; + writers[output_partition] = Some(WriteTracker { + num_batches: 1, + num_rows: output_batch.num_rows(), + writer, + path: file_path, + }); + } + } + write_metrics.output_rows.add(output_batch.num_rows()); + timer.done(); + Ok(()) + }, )?; + } - let schema = stream.schema(); - - while let Some(result) = stream.next().await { - let input_batch = result?; - - write_metrics.input_rows.add(input_batch.num_rows()); - - partitioner.partition( - input_batch, - |output_partition, output_batch| { - // partition func in datafusion make sure not write empty output_batch. - let timer = write_metrics.write_time.timer(); - match &mut writers[output_partition] { - Some(w) => { - w.num_batches += 1; - w.num_rows += output_batch.num_rows(); - w.writer.write(&output_batch)?; - } - None => { - let mut file_path = path.clone(); - file_path.push(format!("{output_partition}")); - std::fs::create_dir_all(&file_path)?; - - file_path.push(format!( - "data-{input_partition}.{file_ext}" - )); - debug!("Writing results to {file_path:?} (format: {shuffle_format})"); - - let mut writer = ShuffleFileWriter::try_new( - file_path.clone(), - schema.clone(), - shuffle_format, - )?; - - writer.write(&output_batch)?; - writers[output_partition] = Some(WriteTracker { + let mut part_locs = vec![]; + + for (i, w) in writers.into_iter().enumerate() { + if let Some(w) = w { + let num_bytes = fs::metadata(&w.path)?.len(); + w.writer.finish()?; + debug!( + "Finished writing shuffle partition {} at {:?}. Batches: {}. Rows: {}. Bytes: {}.", + i, w.path, w.num_batches, w.num_rows, num_bytes + ); + + part_locs.push(ShuffleWritePartition { + partition_id: i as u64, + path: w.path.to_string_lossy().to_string(), + num_batches: w.num_batches as u64, + num_rows: w.num_rows as u64, + num_bytes, + }); + } + } + Ok(part_locs) + } + + _ => Err(DataFusionError::Execution( + "Invalid shuffle partitioning scheme".to_owned(), + )), + } + } + + /// Executes shuffle write to in-memory storage. + #[allow(clippy::too_many_arguments)] + async fn execute_shuffle_write_memory( + job_id: &str, + stage_id: usize, + input_partition: usize, + stream: &mut std::pin::Pin< + Box, + >, + output_partitioning: Option, + write_metrics: ShuffleWriteMetrics, + now: Instant, + shuffle_format: ShuffleFormat, + ) -> Result> { + let shuffle_manager = global_shuffle_manager(); + let schema = stream.schema(); + + match output_partitioning { + None => { + let timer = write_metrics.write_time.timer(); + + // Collect all batches into memory + let mut batches = Vec::new(); + let mut num_rows = 0usize; + let mut num_bytes = 0usize; + + while let Some(result) = stream.next().await { + let batch = result?; + num_rows += batch.num_rows(); + num_bytes += batch.get_array_memory_size(); + write_metrics.input_rows.add(batch.num_rows()); + write_metrics.output_rows.add(batch.num_rows()); + batches.push(batch); + } + + let num_batches = batches.len(); + let key = InMemoryShuffleManager::partition_key( + job_id, + stage_id, + input_partition, + ); + + // Store in the global shuffle manager using the configured format + let data = + Self::create_partition_data(schema.clone(), batches, shuffle_format)?; + shuffle_manager.store_partition(key.clone(), data); + + timer.done(); + + info!( + "Executed partition {} to memory ({shuffle_format}) in {} seconds. Batches: {}, Rows: {}, Bytes: {}", + input_partition, + now.elapsed().as_secs(), + num_batches, + num_rows, + num_bytes + ); + + // Use special "memory://" prefix to indicate in-memory storage + Ok(vec![ShuffleWritePartition { + partition_id: input_partition as u64, + path: format!("memory://{key}"), + num_batches: num_batches as u64, + num_rows: num_rows as u64, + num_bytes: num_bytes as u64, + }]) + } + + Some(Partitioning::Hash(exprs, num_output_partitions)) => { + // We collect batches per output partition in memory + let mut mem_writers: Vec> = vec![]; + for _ in 0..num_output_partitions { + mem_writers.push(None); + } + + let mut partitioner = BatchPartitioner::try_new( + Partitioning::Hash(exprs, num_output_partitions), + write_metrics.repart_time.clone(), + )?; + + while let Some(result) = stream.next().await { + let input_batch = result?; + write_metrics.input_rows.add(input_batch.num_rows()); + + partitioner.partition( + input_batch, + |output_partition, output_batch| { + let timer = write_metrics.write_time.timer(); + let batch_bytes = output_batch.get_array_memory_size(); + let batch_rows = output_batch.num_rows(); + + match &mut mem_writers[output_partition] { + Some(w) => { + w.num_batches += 1; + w.num_rows += batch_rows; + w.num_bytes += batch_bytes; + w.batches.push(output_batch); + } + None => { + let key = InMemoryShuffleManager::hash_partition_key( + job_id, + stage_id, + output_partition, + input_partition, + ); + mem_writers[output_partition] = + Some(InMemoryWriteTracker { num_batches: 1, - num_rows: output_batch.num_rows(), - writer, - path: file_path, + num_rows: batch_rows, + num_bytes: batch_bytes, + batches: vec![output_batch], + key, }); - } } - write_metrics.output_rows.add(output_batch.num_rows()); - timer.done(); - Ok(()) - }, + } + write_metrics.output_rows.add(batch_rows); + timer.done(); + Ok(()) + }, + )?; + } + + let mut part_locs = vec![]; + + for (i, w) in mem_writers.into_iter().enumerate() { + if let Some(w) = w { + debug!( + "Finished writing shuffle partition {} to memory ({shuffle_format}). Batches: {}. Rows: {}. Bytes: {}.", + i, w.num_batches, w.num_rows, w.num_bytes + ); + + // Store in the global shuffle manager using the configured format + let data = Self::create_partition_data( + schema.clone(), + w.batches, + shuffle_format, )?; + shuffle_manager.store_partition(w.key.clone(), data); + + part_locs.push(ShuffleWritePartition { + partition_id: i as u64, + path: format!("memory://{}", w.key), + num_batches: w.num_batches as u64, + num_rows: w.num_rows as u64, + num_bytes: w.num_bytes as u64, + }); } + } + Ok(part_locs) + } - let mut part_locs = vec![]; - - for (i, w) in writers.into_iter().enumerate() { - if let Some(w) = w { - let num_bytes = fs::metadata(&w.path)?.len(); - w.writer.finish()?; - debug!( - "Finished writing shuffle partition {} at {:?}. Batches: {}. Rows: {}. Bytes: {}.", - i, w.path, w.num_batches, w.num_rows, num_bytes - ); - - part_locs.push(ShuffleWritePartition { - partition_id: i as u64, - path: w.path.to_string_lossy().to_string(), - num_batches: w.num_batches as u64, - num_rows: w.num_rows as u64, - num_bytes, - }); - } - } - Ok(part_locs) + _ => Err(DataFusionError::Execution( + "Invalid shuffle partitioning scheme".to_owned(), + )), + } + } + + /// Creates partition data in the specified format (Arrow or Vortex). + fn create_partition_data( + schema: SchemaRef, + batches: Vec, + format: ShuffleFormat, + ) -> Result { + match format { + ShuffleFormat::ArrowIpc => Ok(ShufflePartitionData::new(schema, batches)), + #[cfg(feature = "vortex")] + ShuffleFormat::Vortex => { + use vortex_array::ArrayRef; + use vortex_array::arrow::FromArrowArray; + + let mut arrays = Vec::with_capacity(batches.len()); + let mut total_rows = 0u64; + let mut total_bytes = 0u64; + + for batch in batches { + total_rows += batch.num_rows() as u64; + // Convert Arrow RecordBatch to Vortex Array + let vortex_array = ArrayRef::from_arrow(&batch, false); + total_bytes += vortex_array.nbytes(); + arrays.push(vortex_array); } - _ => Err(DataFusionError::Execution( - "Invalid shuffle partitioning scheme".to_owned(), - )), + Ok(ShufflePartitionData::new_vortex( + schema, + arrays, + total_rows, + total_bytes, + )) } + #[cfg(not(feature = "vortex"))] + ShuffleFormat::Vortex => Err(DataFusionError::NotImplemented( + "Vortex format requires the 'vortex' feature to be enabled".to_string(), + )), } } } diff --git a/ballista/core/src/extension.rs b/ballista/core/src/extension.rs index 05b33a7de6..8cb4a4ebbd 100644 --- a/ballista/core/src/extension.rs +++ b/ballista/core/src/extension.rs @@ -16,7 +16,8 @@ // under the License. use crate::config::{ - BALLISTA_GRPC_CLIENT_MAX_MESSAGE_SIZE, BALLISTA_JOB_NAME, BALLISTA_SHUFFLE_FORMAT, + BALLISTA_GRPC_CLIENT_MAX_MESSAGE_SIZE, BALLISTA_IS_FINAL_STAGE, BALLISTA_JOB_NAME, + BALLISTA_SHUFFLE_FORMAT, BALLISTA_SHUFFLE_MEMORY_MODE, BALLISTA_SHUFFLE_READER_FORCE_REMOTE_READ, BALLISTA_SHUFFLE_READER_MAX_REQUESTS, BALLISTA_SHUFFLE_READER_REMOTE_PREFER_FLIGHT, BALLISTA_STANDALONE_PARALLELISM, BallistaConfig, ShuffleFormat, @@ -155,6 +156,19 @@ pub trait SessionConfigExt { prefer_flight: bool, ) -> Self; + /// Returns whether in-memory shuffle mode is enabled. + fn ballista_shuffle_memory_mode(&self) -> bool; + + /// Sets whether to use in-memory shuffle mode. + fn with_ballista_shuffle_memory_mode(self, memory_mode: bool) -> Self; + + /// Returns whether this is the final output stage. + /// Final stages always write to disk regardless of memory_mode setting. + fn ballista_is_final_stage(&self) -> bool; + + /// Sets whether this is the final output stage. + fn with_ballista_is_final_stage(self, is_final: bool) -> Self; + /// Set user defined metadata keys in Ballista gRPC requests fn with_ballista_grpc_metadata(self, metadata: HashMap) -> Self; @@ -464,6 +478,40 @@ impl SessionConfigExt for SessionConfig { } } + fn ballista_shuffle_memory_mode(&self) -> bool { + self.options() + .extensions + .get::() + .map(|c| c.shuffle_memory_mode()) + .unwrap_or_else(|| BallistaConfig::default().shuffle_memory_mode()) + } + + fn with_ballista_shuffle_memory_mode(self, memory_mode: bool) -> Self { + if self.options().extensions.get::().is_some() { + self.set_bool(BALLISTA_SHUFFLE_MEMORY_MODE, memory_mode) + } else { + self.with_option_extension(BallistaConfig::default()) + .set_bool(BALLISTA_SHUFFLE_MEMORY_MODE, memory_mode) + } + } + + fn ballista_is_final_stage(&self) -> bool { + self.options() + .extensions + .get::() + .map(|c| c.is_final_stage()) + .unwrap_or_else(|| BallistaConfig::default().is_final_stage()) + } + + fn with_ballista_is_final_stage(self, is_final: bool) -> Self { + if self.options().extensions.get::().is_some() { + self.set_bool(BALLISTA_IS_FINAL_STAGE, is_final) + } else { + self.with_option_extension(BallistaConfig::default()) + .set_bool(BALLISTA_IS_FINAL_STAGE, is_final) + } + } + fn with_ballista_grpc_metadata(self, metadata: HashMap) -> Self { let extension = BallistaGrpcMetadataInterceptor::new(metadata); self.with_extension(Arc::new(extension)) @@ -922,4 +970,79 @@ mod test { .any(|p| p.key == "datafusion.catalog.information_schema") ) } + + #[test] + fn test_is_final_stage_config() { + // Default should be false + let config = SessionConfig::new_with_ballista(); + assert!(!config.ballista_is_final_stage()); + + // Set to true + let config = config.with_ballista_is_final_stage(true); + assert!(config.ballista_is_final_stage()); + + // Set back to false + let config = config.with_ballista_is_final_stage(false); + assert!(!config.ballista_is_final_stage()); + } + + #[test] + fn test_shuffle_memory_mode_config() { + // Default should be false (disk-based) + let config = SessionConfig::new_with_ballista(); + assert!(!config.ballista_shuffle_memory_mode()); + + // Enable memory mode + let config = config.with_ballista_shuffle_memory_mode(true); + assert!(config.ballista_shuffle_memory_mode()); + + // Disable memory mode + let config = config.with_ballista_shuffle_memory_mode(false); + assert!(!config.ballista_shuffle_memory_mode()); + } + + #[test] + fn test_shuffle_format_config() { + use crate::config::ShuffleFormat; + + // Default should be ArrowIpc + let config = SessionConfig::new_with_ballista(); + assert_eq!(config.ballista_shuffle_format(), ShuffleFormat::ArrowIpc); + + // Set to Vortex + let config = config.with_ballista_shuffle_format(ShuffleFormat::Vortex); + assert_eq!(config.ballista_shuffle_format(), ShuffleFormat::Vortex); + + // Set back to ArrowIpc + let config = config.with_ballista_shuffle_format(ShuffleFormat::ArrowIpc); + assert_eq!(config.ballista_shuffle_format(), ShuffleFormat::ArrowIpc); + } + + #[test] + fn test_is_final_stage_serialization() { + use crate::config::BALLISTA_IS_FINAL_STAGE; + + // Test that is_final_stage is included in key-value pairs + let config = + SessionConfig::new_with_ballista().with_ballista_is_final_stage(true); + let pairs = config.to_key_value_pairs(); + + let is_final_pair = pairs.iter().find(|p| p.key == BALLISTA_IS_FINAL_STAGE); + assert!(is_final_pair.is_some()); + assert_eq!(is_final_pair.unwrap().value, Some("true".to_string())); + } + + #[test] + fn test_config_without_ballista_extension() { + // Test that methods work even without explicit ballista extension + let config = SessionConfig::new(); + + // Should return defaults + assert!(!config.ballista_is_final_stage()); + assert!(!config.ballista_shuffle_memory_mode()); + + // Should be able to set values (which adds the extension) + let config = config.with_ballista_is_final_stage(true); + assert!(config.ballista_is_final_stage()); + } } diff --git a/ballista/executor/Cargo.toml b/ballista/executor/Cargo.toml index 41ab99b904..f0fb6043a7 100644 --- a/ballista/executor/Cargo.toml +++ b/ballista/executor/Cargo.toml @@ -45,6 +45,7 @@ async-trait = { workspace = true } backoff = { workspace = true } ballista-core = { path = "../core", version = "51.0.0" } +bytes = "1" clap = { workspace = true, optional = true } dashmap = { workspace = true } diff --git a/ballista/executor/src/executor_server.rs b/ballista/executor/src/executor_server.rs index a792e89be2..364c094335 100644 --- a/ballista/executor/src/executor_server.rs +++ b/ballista/executor/src/executor_server.rs @@ -38,6 +38,7 @@ pub type EndpointOverrideFn = Arc Result + Send + Sync>; use ballista_core::error::BallistaError; +use ballista_core::execution_plans::global_shuffle_manager; use ballista_core::serde::BallistaCodec; use ballista_core::serde::protobuf::{ CancelTasksParams, CancelTasksResult, ExecutorMetric, ExecutorStatus, @@ -774,6 +775,11 @@ impl ExecutorGrpc ) -> Result, Status> { let job_id = request.into_inner().job_id; + // Clean up in-memory shuffle partitions for this job + let shuffle_manager = global_shuffle_manager(); + shuffle_manager.remove_job_partitions(&job_id); + + // Clean up disk-based shuffle data remove_job_dir(&self.executor.work_dir, &job_id) .await .map_err(|e| Status::invalid_argument(e.to_string()))?; diff --git a/ballista/executor/src/flight_service.rs b/ballista/executor/src/flight_service.rs index 73f9128e67..ede4299c50 100644 --- a/ballista/executor/src/flight_service.rs +++ b/ballista/executor/src/flight_service.rs @@ -28,6 +28,7 @@ use tokio_util::io::ReaderStream; use arrow_flight::encode::FlightDataEncoderBuilder; use arrow_flight::error::FlightError; use ballista_core::error::BallistaError; +use ballista_core::execution_plans::global_shuffle_manager; use ballista_core::serde::decode_protobuf; use ballista_core::serde::scheduler::Action as BallistaAction; use datafusion::arrow::ipc::CompressionType; @@ -98,6 +99,55 @@ impl FlightService for BallistaFlightService { match &action { BallistaAction::FetchPartition { path, .. } => { + // Check if this is an in-memory partition + if let Some(key) = path.strip_prefix("memory://") { + // Fetch from in-memory shuffle manager + let shuffle_manager = global_shuffle_manager(); + let data = shuffle_manager.get_partition(key).map_err(|e| { + Status::not_found(format!( + "In-memory partition not found: {key}: {e}" + )) + })?; + + debug!( + "FetchPartition serving in-memory partition: {} ({} batches, {} rows, format: {:?})", + key, data.num_batches, data.num_rows, data.format + ); + + let (tx, rx) = channel(2); + let schema = data.schema.clone(); + + // Convert to batches (handles both Arrow and Vortex formats) + let batches = data.to_batches().map_err(|e| { + Status::internal(format!( + "Failed to convert in-memory partition to batches: {e}" + )) + })?; + + // Stream the batches from memory + task::spawn(async move { + for batch in batches { + if tx.send(Ok(batch)).await.is_err() { + break; + } + } + }); + + let write_options: IpcWriteOptions = IpcWriteOptions::default() + .try_with_compression(Some(CompressionType::LZ4_FRAME)) + .map_err(|e| from_arrow_err(&e))?; + let flight_data_stream = FlightDataEncoderBuilder::new() + .with_schema(schema) + .with_options(write_options) + .build(ReceiverStream::new(rx)) + .map_err(|err| Status::from_error(Box::new(err))); + + return Ok(Response::new( + Box::pin(flight_data_stream) as Self::DoGetStream + )); + } + + // Handle disk-based partition // Detect shuffle format based on file extension let is_vortex = Path::new(path) .extension() @@ -107,12 +157,6 @@ impl FlightService for BallistaFlightService { let format = if is_vortex { "vortex" } else { "arrow-ipc" }; debug!("FetchPartition reading {path} (format: {format})"); - // Detect shuffle format based on file extension - let is_vortex = Path::new(path) - .extension() - .map(|ext| ext == "vortex") - .unwrap_or(false); - let (schema, rx) = if is_vortex { #[cfg(feature = "vortex")] { @@ -222,6 +266,64 @@ impl FlightService for BallistaFlightService { match &action { BallistaAction::FetchPartition { path, .. } => { debug!("FetchPartition reading {path}"); + + // Check if this is an in-memory partition + // For in-memory partitions, we need to serialize to IPC format first + if let Some(key) = path.strip_prefix("memory://") { + let shuffle_manager = global_shuffle_manager(); + let data = + shuffle_manager.get_partition(key).map_err(|e| { + Status::not_found(format!( + "In-memory partition not found: {key}: {e}" + )) + })?; + + debug!( + "FetchPartition serving in-memory partition via block transfer: {} ({} batches, format: {:?})", + key, data.num_batches, data.format + ); + + // Convert to batches (handles both Arrow and Vortex formats) + let batches = data.to_batches().map_err(|e| { + Status::internal(format!( + "Failed to convert in-memory partition to batches: {e}" + )) + })?; + + // Serialize batches to IPC format in memory + let mut buffer = Vec::new(); + { + use datafusion::arrow::ipc::writer::StreamWriter; + let mut writer = StreamWriter::try_new_with_options( + &mut buffer, + &data.schema, + IpcWriteOptions::default() + .try_with_compression(Some( + CompressionType::LZ4_FRAME, + )) + .map_err(|e| from_arrow_err(&e))?, + ) + .map_err(|e| from_arrow_err(&e))?; + + for batch in &batches { + writer + .write(batch) + .map_err(|e| from_arrow_err(&e))?; + } + writer.finish().map_err(|e| from_arrow_err(&e))?; + } + + let bytes = bytes::Bytes::from(buffer); + let result_stream = futures::stream::once(async move { + Ok(arrow_flight::Result { body: bytes }) + }); + + return Ok(Response::new( + Box::pin(result_stream) as Self::DoActionStream + )); + } + + // Handle disk-based partition let file = tokio::fs::File::open(&path).await.map_err(|e| { Status::internal(format!("Failed to open file: {e}")) })?; diff --git a/ballista/scheduler/src/api/mod.rs b/ballista/scheduler/src/api/mod.rs index 5dd1aee2e4..42b7b89e1f 100644 --- a/ballista/scheduler/src/api/mod.rs +++ b/ballista/scheduler/src/api/mod.rs @@ -19,6 +19,7 @@ use datafusion_proto::logical_plan::AsLogicalPlan; use datafusion_proto::physical_plan::AsExecutionPlan; use std::sync::Arc; +/// Creates the API routes for the scheduler REST API. pub fn get_routes< T: AsLogicalPlan + Clone + Send + Sync + 'static, U: AsExecutionPlan + Send + Sync + 'static, diff --git a/ballista/scheduler/src/metrics/mod.rs b/ballista/scheduler/src/metrics/mod.rs index 934dda1d95..984a9990d2 100644 --- a/ballista/scheduler/src/metrics/mod.rs +++ b/ballista/scheduler/src/metrics/mod.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +/// Prometheus metrics collector implementation. #[cfg(feature = "prometheus")] pub mod prometheus; diff --git a/ballista/scheduler/src/metrics/prometheus.rs b/ballista/scheduler/src/metrics/prometheus.rs index 4cf68088e0..2628155f3a 100644 --- a/ballista/scheduler/src/metrics/prometheus.rs +++ b/ballista/scheduler/src/metrics/prometheus.rs @@ -49,6 +49,7 @@ pub struct PrometheusMetricsCollector { } impl PrometheusMetricsCollector { + /// Creates a new PrometheusMetricsCollector with the given registry. pub fn new(registry: &Registry) -> Result { let execution_time = register_histogram_with_registry!( "job_exec_time_seconds", @@ -126,6 +127,7 @@ impl PrometheusMetricsCollector { }) } + /// Returns the current global PrometheusMetricsCollector instance. pub fn current() -> Result> { COLLECTOR .get_or_try_init(|| { @@ -133,7 +135,7 @@ impl PrometheusMetricsCollector { Ok(Arc::new(collector) as Arc) }) - .map(|arc| arc.clone()) + .cloned() } } @@ -206,6 +208,22 @@ impl SchedulerMetricsCollector for PrometheusMetricsCollector { } fn record_task_retry(&self, _job_id: &str, _stage_id: usize) {} + // Shuffle affinity - not tracked by default Prometheus collector + fn record_task_shuffle_affinity_hit( + &self, + _job_id: &str, + _stage_id: usize, + _executor_id: &str, + ) { + } + fn record_task_shuffle_affinity_miss( + &self, + _job_id: &str, + _stage_id: usize, + _executor_id: &str, + ) { + } + // Executor management - not tracked by default Prometheus collector fn set_active_executor_count(&self, _count: usize) {} fn record_executor_registered(&self, _executor_id: &str) {} diff --git a/ballista/scheduler/src/scheduler_server/mod.rs b/ballista/scheduler/src/scheduler_server/mod.rs index a6d8ac18c3..1bffda332f 100644 --- a/ballista/scheduler/src/scheduler_server/mod.rs +++ b/ballista/scheduler/src/scheduler_server/mod.rs @@ -46,6 +46,7 @@ use crate::state::task_manager::TaskLauncher; // include the generated protobuf source as a submodule #[cfg(feature = "keda-scaler")] #[allow(clippy::all)] +#[allow(missing_docs)] pub mod externalscaler { include!(concat!(env!("OUT_DIR"), "/externalscaler.rs")); } diff --git a/ballista/scheduler/src/state/execution_graph.rs b/ballista/scheduler/src/state/execution_graph.rs index f9526beb2f..6f0862911d 100644 --- a/ballista/scheduler/src/state/execution_graph.rs +++ b/ballista/scheduler/src/state/execution_graph.rs @@ -29,6 +29,7 @@ use log::{debug, error, info, warn}; use ballista_core::error::{BallistaError, Result}; use ballista_core::execution_plans::{ShuffleWriterExec, UnresolvedShuffleExec}; +use ballista_core::extension::SessionConfigExt; use ballista_core::serde::protobuf::failed_task::FailedReason; use ballista_core::serde::protobuf::job_status::Status; use ballista_core::serde::protobuf::{FailedJob, ShuffleWritePartition, job_status}; @@ -1128,6 +1129,20 @@ impl ExecutionGraph { // Set the task info to Running for new task stage.task_infos[partition_id] = Some(task_info); + // Check if this is the final stage (no output links means this is the output stage) + let is_final_stage = stage.output_links.is_empty(); + + // Create session config with the is_final_stage flag set + let task_session_config = if is_final_stage { + Arc::new( + (*self.session_config) + .clone() + .with_ballista_is_final_stage(true), + ) + } else { + self.session_config.clone() + }; + Ok(TaskDescription { session_id, partition, @@ -1135,7 +1150,7 @@ impl ExecutionGraph { task_id, task_attempt, plan: stage.plan.clone(), - session_config: self.session_config.clone(), + session_config: task_session_config, schedulable_time_millis: stage.stage_running_time, }) } else { @@ -2903,6 +2918,65 @@ mod test { // todo!() // } + /// Test that is_final_stage flag is correctly set on tasks from the final output stage + #[tokio::test] + async fn test_is_final_stage_flag() -> Result<()> { + use ballista_core::extension::SessionConfigExt; + + // Create a simple two-stage aggregation plan + // Stage 1: partial aggregation (has output_links to stage 2) + // Stage 2: final aggregation (no output_links - this is the final stage) + let mut agg_graph = test_aggregation_plan(4).await; + + let executor = mock_executor("executor-id1".to_string()); + + // Collect all tasks and their is_final_stage flags + let mut stages_and_flags: Vec<(usize, bool)> = Vec::new(); + + while let Some(task) = agg_graph.pop_next_task(&executor.id)? { + let stage_id = task.partition.stage_id; + let is_final = task.session_config.ballista_is_final_stage(); + stages_and_flags.push((stage_id, is_final)); + + // Complete the task to move to next stage + let task_status = mock_completed_task(task, &executor.id); + agg_graph.update_task_status(&executor, vec![task_status], 1, 1)?; + } + + // Verify we got tasks from multiple stages + assert!( + !stages_and_flags.is_empty(), + "Should have at least one task" + ); + + // Get unique stage IDs + let unique_stages: HashSet = + stages_and_flags.iter().map(|(s, _)| *s).collect(); + + // Find the final stage (highest stage number for aggregation plan) + let final_stage_id = *unique_stages.iter().max().unwrap(); + + // Verify: tasks from non-final stages should have is_final_stage=false + // tasks from final stage should have is_final_stage=true + for (stage_id, is_final) in &stages_and_flags { + if *stage_id == final_stage_id { + assert!( + *is_final, + "Final stage {} should have is_final_stage=true", + stage_id + ); + } else { + assert!( + !*is_final, + "Non-final stage {} should have is_final_stage=false", + stage_id + ); + } + } + + Ok(()) + } + fn drain_tasks(graph: &mut ExecutionGraph) -> Result<()> { let executor = mock_executor("executor-id1".to_string()); while let Some(task) = graph.pop_next_task(&executor.id)? {