Skip to content

Commit 04e8c78

Browse files
refactor(rust): Add streaming IO sink components (#25594)
1 parent 4359355 commit 04e8c78

File tree

21 files changed

+861
-4
lines changed

21 files changed

+861
-4
lines changed
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
use std::sync::Arc;
2+
3+
use futures::{StreamExt as _, TryStreamExt as _};
4+
use polars_core::frame::DataFrame;
5+
use polars_core::prelude::sort::arg_sort;
6+
use polars_core::prelude::{Column, IdxArr, SortMultipleOptions};
7+
use polars_error::PolarsResult;
8+
use polars_expr::state::ExecutionState;
9+
10+
use crate::async_executor::TaskPriority;
11+
use crate::async_primitives::opt_spawned_future::parallelize_first_to_local;
12+
use crate::expression::StreamExpr;
13+
14+
#[derive(Clone)]
15+
pub struct ArgSortBy {
16+
pub by: Arc<[StreamExpr]>,
17+
pub sort_options: SortMultipleOptions,
18+
pub in_memory_exec_state: Arc<ExecutionState>,
19+
}
20+
21+
impl ArgSortBy {
22+
pub async fn arg_sort_by_par(self, df: &Arc<DataFrame>) -> PolarsResult<IdxArr> {
23+
let ArgSortBy {
24+
by,
25+
sort_options,
26+
in_memory_exec_state,
27+
} = self;
28+
29+
let sort_by_cols: Vec<Column> = futures::stream::iter(parallelize_first_to_local(
30+
TaskPriority::Low,
31+
(0..by.len()).map(|i| {
32+
let df = Arc::clone(df);
33+
let by = Arc::clone(&by);
34+
let in_memory_exec_state = Arc::clone(&in_memory_exec_state);
35+
36+
async move {
37+
by[i]
38+
.evaluate(&df, in_memory_exec_state.as_ref())
39+
.await
40+
.map(|c| c.rechunk())
41+
}
42+
}),
43+
))
44+
.then(|x| x)
45+
.try_collect()
46+
.await?;
47+
48+
Ok(arg_sort(&sort_by_cols, sort_options)?
49+
.downcast_as_array()
50+
.clone())
51+
}
52+
}
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
use std::any::Any;
2+
use std::panic::AssertUnwindSafe;
3+
4+
use futures::FutureExt;
5+
use polars_error::{PolarsError, PolarsResult};
6+
7+
/// Utility to capture errors and propagate them to an associated [`ErrorHandle`].
8+
#[derive(Clone)]
9+
pub struct ErrorCapture {
10+
tx: tokio::sync::mpsc::Sender<ErrorMessage>,
11+
}
12+
13+
impl ErrorCapture {
14+
pub fn new() -> (Self, ErrorHandle) {
15+
let (tx, rx) = tokio::sync::mpsc::channel(1);
16+
(Self { tx }, ErrorHandle { rx })
17+
}
18+
19+
/// Wraps a future such that its error result is sent to the associated [`ErrorHandle`].
20+
pub async fn wrap_future<F, O>(self, fut: F)
21+
where
22+
F: Future<Output = PolarsResult<O>>,
23+
{
24+
let err: Result<(), tokio::sync::mpsc::error::TrySendError<ErrorMessage>> =
25+
match AssertUnwindSafe(fut).catch_unwind().await {
26+
Ok(Ok(_)) => return,
27+
Ok(Err(err)) => self.tx.try_send(ErrorMessage::Error(err)),
28+
Err(panic) => self.tx.try_send(ErrorMessage::Panic(panic)),
29+
};
30+
drop(err);
31+
}
32+
}
33+
34+
enum ErrorMessage {
35+
Error(PolarsError),
36+
Panic(Box<dyn Any + Send + 'static>),
37+
}
38+
39+
/// Handle to await the completion of multiple tasks. Propagates error results
40+
/// and resumes unwinds when joined.
41+
pub struct ErrorHandle {
42+
rx: tokio::sync::mpsc::Receiver<ErrorMessage>,
43+
}
44+
45+
impl ErrorHandle {
46+
pub fn has_errored(&self) -> bool {
47+
!self.rx.is_empty()
48+
}
49+
50+
/// Block until either an error is received, or all [`ErrorCapture`]s associated with this
51+
/// handle are dropped (i.e. successful completion of all wrapped futures).
52+
///
53+
/// # Panics
54+
/// If a panic is received, this will resume unwinding.
55+
pub async fn join(self) -> PolarsResult<()> {
56+
let ErrorHandle { mut rx } = self;
57+
58+
match rx.recv().await {
59+
None => Ok(()),
60+
Some(ErrorMessage::Error(e)) => Err(e),
61+
Some(ErrorMessage::Panic(panic)) => std::panic::resume_unwind(panic),
62+
}
63+
}
64+
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
use std::sync::Arc;
2+
3+
#[derive(Clone)]
4+
pub enum ExcludeKeysProjection {
5+
/// Project these indices
6+
Indices(Arc<[usize]>),
7+
/// Project this many columns from the left.
8+
Width(usize),
9+
}
10+
11+
impl ExcludeKeysProjection {
12+
pub fn iter_indices(&self) -> impl ExactSizeIterator<Item = usize> {
13+
let (indices, end) = match self {
14+
Self::Indices(indices) => (indices.as_ref(), indices.len()),
15+
Self::Width(width) => (&[][..], *width),
16+
};
17+
18+
(0..end).map(|i| if indices.is_empty() { i } else { indices[i] })
19+
}
20+
21+
pub fn len(&self) -> usize {
22+
match self {
23+
Self::Indices(indices) => indices.len(),
24+
Self::Width(len) => *len,
25+
}
26+
}
27+
}
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
use std::fmt::Write;
2+
use std::sync::Arc;
3+
4+
use polars_core::prelude::{Column, DataType};
5+
use polars_error::PolarsResult;
6+
use polars_io::cloud::CloudOptions;
7+
use polars_io::pl_async;
8+
use polars_io::utils::HIVE_VALUE_ENCODE_CHARSET;
9+
use polars_io::utils::file::Writeable;
10+
use polars_plan::dsl::sink2::{FileProviderReturn, FileProviderType};
11+
use polars_plan::prelude::sink2::FileProviderArgs;
12+
use polars_utils::plpath::PlPath;
13+
14+
pub struct FileProvider {
15+
pub base_path: PlPath,
16+
pub cloud_options: Option<Arc<CloudOptions>>,
17+
pub provider_type: FileProviderType,
18+
}
19+
20+
impl FileProvider {
21+
pub async fn open_file(&self, args: FileProviderArgs) -> PolarsResult<Writeable> {
22+
let provided_path: String = match &self.provider_type {
23+
FileProviderType::Hive { extension } => {
24+
let FileProviderArgs {
25+
index_in_partition,
26+
partition_keys,
27+
} = args;
28+
29+
let mut partition_parts = String::new();
30+
31+
let partition_keys: &[Column] = partition_keys.get_columns();
32+
33+
write!(
34+
&mut partition_parts,
35+
"{}",
36+
HivePathFormatter::new(partition_keys)
37+
)
38+
.unwrap();
39+
40+
assert!(index_in_partition <= 0xffff_ffff);
41+
42+
write!(&mut partition_parts, "{index_in_partition:08x}.{extension}").unwrap();
43+
44+
partition_parts
45+
},
46+
47+
FileProviderType::Function(f) => {
48+
let f = f.clone();
49+
50+
let out = pl_async::get_runtime()
51+
.spawn_blocking(move || f.get_file(args))
52+
.await
53+
.unwrap()?;
54+
55+
match out {
56+
FileProviderReturn::Path(p) => p,
57+
FileProviderReturn::Writeable(v) => return Ok(v),
58+
}
59+
},
60+
61+
FileProviderType::Legacy(_) => unreachable!(),
62+
};
63+
64+
let path = self.base_path.as_ref().join(&provided_path);
65+
let path = path.as_ref();
66+
67+
if let Some(path) = path.as_local_path().and_then(|p| p.parent()) {
68+
// Ignore errors from directory creation - the `Writeable::try_new()` below will raise
69+
// appropriate errors.
70+
let _ = tokio::fs::DirBuilder::new()
71+
.recursive(true)
72+
.create(path)
73+
.await;
74+
}
75+
76+
Writeable::try_new(path, self.cloud_options.as_deref())
77+
}
78+
}
79+
80+
/// # Panics
81+
/// The `Display` impl of this will panic if a column has non-unit length.
82+
pub struct HivePathFormatter<'a> {
83+
keys: &'a [Column],
84+
}
85+
86+
impl<'a> HivePathFormatter<'a> {
87+
pub fn new(keys: &'a [Column]) -> Self {
88+
Self { keys }
89+
}
90+
}
91+
92+
impl std::fmt::Display for HivePathFormatter<'_> {
93+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94+
for column in self.keys {
95+
assert_eq!(column.len(), 1);
96+
let column = column.cast(&DataType::String).unwrap();
97+
98+
let key = column.name();
99+
let value = percent_encoding::percent_encode(
100+
column
101+
.str()
102+
.unwrap()
103+
.get(0)
104+
.unwrap_or("__HIVE_DEFAULT_PARTITION__")
105+
.as_bytes(),
106+
HIVE_VALUE_ENCODE_CHARSET,
107+
);
108+
109+
write!(f, "{key}={value}/")?
110+
}
111+
112+
Ok(())
113+
}
114+
}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
use polars_error::PolarsResult;
2+
3+
use crate::async_executor;
4+
use crate::async_primitives::connector;
5+
use crate::nodes::io_sinks2::components::sink_morsel::SinkMorsel;
6+
use crate::nodes::io_sinks2::components::size::RowCountAndSize;
7+
8+
/// Alias for an owned tokio semaphore permit; a file sink task resolves to one on success
/// (see `FileSinkTaskData::task_handle`).
pub type FileSinkPermit = tokio::sync::OwnedSemaphorePermit;
9+
10+
pub struct FileSinkTaskData {
11+
pub morsel_tx: connector::Sender<SinkMorsel>,
12+
pub start_position: RowCountAndSize,
13+
pub task_handle: async_executor::JoinHandle<PolarsResult<FileSinkPermit>>,
14+
}
15+
16+
impl FileSinkTaskData {
17+
/// Signals to the writer to close, and returns its task handle.
18+
pub fn close(self) -> async_executor::JoinHandle<PolarsResult<FileSinkPermit>> {
19+
self.task_handle
20+
}
21+
}
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
use std::sync::Arc;
2+
3+
use polars_core::prelude::Column;
4+
use polars_core::schema::Schema;
5+
use polars_utils::marked_usize::MarkedUsize;
6+
7+
/// Applies a `with_columns()` operation with pre-computed indices.
8+
#[derive(Clone)]
9+
pub struct HStackColumns {
10+
gather_indices: Arc<[MarkedUsize]>,
11+
}
12+
13+
impl HStackColumns {
14+
/// Note:
15+
/// * Dtypes of the schemas are unused.
16+
pub fn new(output_schema: &Schema, schema_left: &Schema, schema_right: &Schema) -> Self {
17+
assert!(schema_left.len() <= MarkedUsize::UNMARKED_MAX);
18+
assert!(schema_right.len() <= MarkedUsize::UNMARKED_MAX);
19+
20+
let gather_indices: Arc<[MarkedUsize]> = output_schema
21+
.iter_names()
22+
.map(|name| {
23+
if let Some((idx, ..)) = schema_right.get_full(name) {
24+
MarkedUsize::new(idx, true)
25+
} else {
26+
MarkedUsize::new(schema_left.get_full(name).unwrap().0, false)
27+
}
28+
})
29+
.collect();
30+
31+
Self { gather_indices }
32+
}
33+
34+
#[expect(unused)]
35+
pub fn output_width(&self) -> usize {
36+
self.gather_indices.len()
37+
}
38+
39+
/// Broadcasts unit-length columns from the RHS.
40+
pub fn hstack_columns_broadcast(
41+
&self,
42+
height: usize,
43+
cols_left: &[Column],
44+
cols_right: &[Column],
45+
) -> Vec<Column> {
46+
self.gather_indices
47+
.iter()
48+
.copied()
49+
.map(|mi| {
50+
let i = mi.idx();
51+
52+
if mi.marked() {
53+
let c = &cols_right[i];
54+
55+
if c.len() != height {
56+
assert_eq!(c.len(), 1);
57+
c.new_from_index(0, height)
58+
} else {
59+
c.clone()
60+
}
61+
} else {
62+
cols_left[i].clone()
63+
}
64+
})
65+
.collect()
66+
}
67+
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
pub mod arg_sort;
2+
pub mod error_capture;
3+
pub mod exclude_keys_projection;
4+
pub mod file_provider;
5+
pub mod file_sink;
6+
pub mod hstack_columns;
7+
pub mod sink_morsel;
8+
pub mod size;
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
use polars_core::frame::DataFrame;
2+
3+
/// Alias for an owned tokio semaphore permit, as held by each `SinkMorsel`.
pub type SinkMorselPermit = tokio::sync::OwnedSemaphorePermit;
4+
5+
/// In-flight morsel in the IO sink. Holds a permit against a semaphore that restricts
6+
/// the total number of sink morsels in memory.
7+
pub struct SinkMorsel {
8+
df: DataFrame,
9+
/// Should only be dropped once the data associated with this morsel has been dropped from memory.
10+
permit: SinkMorselPermit,
11+
}
12+
13+
impl SinkMorsel {
14+
pub fn new(df: DataFrame, permit: SinkMorselPermit) -> Self {
15+
Self { df, permit }
16+
}
17+
18+
pub fn into_inner(self) -> (DataFrame, SinkMorselPermit) {
19+
(self.df, self.permit)
20+
}
21+
22+
pub fn df(&self) -> &DataFrame {
23+
&self.df
24+
}
25+
26+
pub fn df_mut(&mut self) -> &mut DataFrame {
27+
&mut self.df
28+
}
29+
}

0 commit comments

Comments
 (0)