@@ -48,6 +48,11 @@ pub const BALLISTA_SHUFFLE_READER_REMOTE_PREFER_FLIGHT: &str =
4848pub const BALLISTA_SHUFFLE_STORAGE_TYPE : & str = "ballista.shuffle.storage_type" ;
4949/// Configuration key for shuffle storage base URL/path.
5050pub const BALLISTA_SHUFFLE_STORAGE_URL : & str = "ballista.shuffle.storage_url" ;
51+ /// Configuration key for shuffle storage mode (disk or memory).
52+ pub const BALLISTA_SHUFFLE_MEMORY_MODE : & str = "ballista.shuffle.memory_mode" ;
53+ /// Configuration key indicating if this is the final output stage.
54+ /// When true, shuffle data is always written to disk regardless of memory_mode setting.
55+ pub const BALLISTA_IS_FINAL_STAGE : & str = "ballista.shuffle.is_final_stage" ;
5156/// Shuffle format configuration: "arrow_ipc" or "vortex"
5257pub const BALLISTA_SHUFFLE_FORMAT : & str = "ballista.shuffle.format" ;
5358
@@ -100,6 +105,14 @@ static CONFIG_ENTRIES: LazyLock<HashMap<String, ConfigEntry>> = LazyLock::new(||
100105 "Base URL/path for shuffle storage. For local: file path; For S3: s3://bucket/prefix; For Azure: abfs://container@account.dfs.core.windows.net/prefix" . to_string( ) ,
101106 DataType :: Utf8 ,
102107 None ) ,
108+ ConfigEntry :: new( BALLISTA_SHUFFLE_MEMORY_MODE . to_string( ) ,
109+ "When enabled, shuffle data is kept in memory on executors instead of being written to disk. This can improve performance for workloads with sufficient memory." . to_string( ) ,
110+ DataType :: Boolean ,
111+ Some ( ( false ) . to_string( ) ) ) ,
112+ ConfigEntry :: new( BALLISTA_IS_FINAL_STAGE . to_string( ) ,
113+ "When true, indicates this is the final output stage. Final stages always write to disk regardless of memory_mode setting to ensure proper cleanup." . to_string( ) ,
114+ DataType :: Boolean ,
115+ Some ( ( false ) . to_string( ) ) ) ,
103116 ConfigEntry :: new( BALLISTA_GRPC_CLIENT_CONNECT_TIMEOUT_SECONDS . to_string( ) ,
104117 "Connection timeout for gRPC client in seconds" . to_string( ) ,
105118 DataType :: UInt64 ,
@@ -329,6 +342,21 @@ impl BallistaConfig {
329342 self . settings . get ( BALLISTA_SHUFFLE_STORAGE_URL ) . cloned ( )
330343 }
331344
345+ /// Returns whether in-memory shuffle mode is enabled.
346+ ///
347+ /// When enabled, shuffle data is kept in memory on executors instead of
348+ /// being written to disk. This can improve performance for workloads
349+ /// with sufficient memory.
350+ pub fn shuffle_memory_mode ( & self ) -> bool {
351+ self . get_bool_setting ( BALLISTA_SHUFFLE_MEMORY_MODE )
352+ }
353+
354+ /// Returns whether this is the final output stage.
355+ /// Final stages always write to disk regardless of memory_mode setting.
356+ pub fn is_final_stage ( & self ) -> bool {
357+ self . get_bool_setting ( BALLISTA_IS_FINAL_STAGE )
358+ }
359+
332360 /// Returns the configured shuffle format (ArrowIpc or Vortex)
333361 ///
334362 /// Note: Vortex format requires the 'vortex' feature to be enabled.
@@ -499,4 +527,46 @@ mod tests {
499527 assert_eq ! ( 16777216 , config. default_grpc_client_max_message_size( ) ) ;
500528 Ok ( ( ) )
501529 }
530+
531+ #[ test]
532+ fn test_is_final_stage_default ( ) {
533+ let config = BallistaConfig :: default ( ) ;
534+ // Default should be false
535+ assert ! ( !config. is_final_stage( ) ) ;
536+ }
537+
538+ #[ test]
539+ fn test_shuffle_memory_mode_default ( ) {
540+ let config = BallistaConfig :: default ( ) ;
541+ // Default should be false (disk-based shuffles)
542+ assert ! ( !config. shuffle_memory_mode( ) ) ;
543+ }
544+
545+ #[ test]
546+ fn test_shuffle_format_default ( ) {
547+ let config = BallistaConfig :: default ( ) ;
548+ // Default should be ArrowIpc
549+ assert_eq ! ( config. shuffle_format( ) , ShuffleFormat :: ArrowIpc ) ;
550+ }
551+
552+ #[ test]
553+ fn test_shuffle_format_parsing ( ) {
554+ assert_eq ! (
555+ "arrow_ipc" . parse:: <ShuffleFormat >( ) . unwrap( ) ,
556+ ShuffleFormat :: ArrowIpc
557+ ) ;
558+ assert_eq ! (
559+ "arrow-ipc" . parse:: <ShuffleFormat >( ) . unwrap( ) ,
560+ ShuffleFormat :: ArrowIpc
561+ ) ;
562+ assert_eq ! (
563+ "ipc" . parse:: <ShuffleFormat >( ) . unwrap( ) ,
564+ ShuffleFormat :: ArrowIpc
565+ ) ;
566+ assert_eq ! (
567+ "vortex" . parse:: <ShuffleFormat >( ) . unwrap( ) ,
568+ ShuffleFormat :: Vortex
569+ ) ;
570+ assert ! ( "invalid" . parse:: <ShuffleFormat >( ) . is_err( ) ) ;
571+ }
502572}
0 commit comments