Skip to content

Commit cf6c299

Browse files
refactor(rust): Dispatch new-streaming IO plugin source to updated multiscan (#22009)
1 parent e4ead38 commit cf6c299

File tree

3 files changed

+263
-101
lines changed

3 files changed

+263
-101
lines changed

Diff for: crates/polars-stream/src/execute.rs

+9
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,15 @@ pub struct StreamingExecutionState {
1818
pub in_memory_exec_state: ExecutionState,
1919
}
2020

21+
impl Default for StreamingExecutionState {
22+
fn default() -> Self {
23+
Self {
24+
num_pipelines: POOL.current_num_threads(),
25+
in_memory_exec_state: ExecutionState::default(),
26+
}
27+
}
28+
}
29+
2130
/// Finds all runnable pipeline blockers in the graph, that is, nodes which:
2231
/// - Only have blocked output ports.
2332
/// - Have at least one ready input port connected to a ready output port.

Diff for: crates/polars-stream/src/nodes/io_sources/batch.rs

+203-92
Original file line numberDiff line numberDiff line change
@@ -1,118 +1,229 @@
1+
//! Reads batches from a `dyn Fn`
2+
3+
use async_trait::async_trait;
14
use polars_core::frame::DataFrame;
25
use polars_core::schema::SchemaRef;
36
use polars_error::{PolarsResult, polars_err};
7+
use polars_utils::IdxSize;
48
use polars_utils::pl_str::PlSmallStr;
5-
use polars_utils::{IdxSize, format_pl_smallstr};
6-
use tokio::sync::oneshot;
79

8-
use super::{
9-
JoinHandle, Morsel, MorselSeq, SourceNode, SourceOutput, StreamingExecutionState, TaskPriority,
10+
use crate::async_executor::{JoinHandle, TaskPriority, spawn};
11+
use crate::execute::StreamingExecutionState;
12+
use crate::morsel::{Morsel, MorselSeq, SourceToken};
13+
use crate::nodes::io_sources::multi_file_reader::reader_interface::output::{
14+
FileReaderOutputRecv, FileReaderOutputSend,
15+
};
16+
use crate::nodes::io_sources::multi_file_reader::reader_interface::{
17+
BeginReadArgs, FileReader, FileReaderCallbacks,
1018
};
11-
use crate::async_executor::spawn;
12-
use crate::async_primitives::connector::Receiver;
13-
use crate::async_primitives::wait_group::WaitGroup;
14-
use crate::morsel::SourceToken;
1519

16-
type GetBatchFn =
20+
pub mod builder {
21+
22+
use std::sync::{Arc, Mutex};
23+
24+
use polars_utils::pl_str::PlSmallStr;
25+
26+
use super::BatchFnReader;
27+
use crate::nodes::io_sources::multi_file_reader::reader_interface::FileReader;
28+
use crate::nodes::io_sources::multi_file_reader::reader_interface::builder::FileReaderBuilder;
29+
use crate::nodes::io_sources::multi_file_reader::reader_interface::capabilities::ReaderCapabilities;
30+
31+
pub struct BatchFnReaderBuilder {
32+
pub name: PlSmallStr,
33+
pub reader: Mutex<Option<BatchFnReader>>,
34+
}
35+
36+
impl FileReaderBuilder for BatchFnReaderBuilder {
37+
fn reader_name(&self) -> &str {
38+
&self.name
39+
}
40+
41+
fn reader_capabilities(&self) -> ReaderCapabilities {
42+
ReaderCapabilities::empty()
43+
}
44+
45+
fn build_file_reader(
46+
&self,
47+
_source: polars_plan::prelude::ScanSource,
48+
_cloud_options: Option<Arc<polars_io::cloud::CloudOptions>>,
49+
scan_source_idx: usize,
50+
) -> Box<dyn FileReader> {
51+
assert_eq!(scan_source_idx, 0);
52+
53+
Box::new(
54+
self.reader
55+
.try_lock()
56+
.unwrap()
57+
.take()
58+
.expect("BatchFnReaderBuilder called more than once"),
59+
) as Box<dyn FileReader>
60+
}
61+
}
62+
63+
impl std::fmt::Debug for BatchFnReaderBuilder {
64+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
65+
f.write_str("BatchFnReaderBuilder: name: ")?;
66+
f.write_str(&self.name)?;
67+
68+
Ok(())
69+
}
70+
}
71+
}
72+
73+
pub type GetBatchFn =
1774
Box<dyn Fn(&StreamingExecutionState) -> PolarsResult<Option<DataFrame>> + Send + Sync>;
1875

19-
pub struct BatchSourceNode {
20-
pub name: PlSmallStr,
21-
pub output_schema: SchemaRef,
22-
pub get_batch_fn: Option<GetBatchFn>,
76+
/// Wraps `GetBatchFn` to support peeking.
77+
pub struct GetBatchState {
78+
func: GetBatchFn,
79+
peek: Option<DataFrame>,
2380
}
2481

25-
impl BatchSourceNode {
26-
pub fn new(name: &str, output_schema: SchemaRef, get_batch_fn: Option<GetBatchFn>) -> Self {
27-
let name = format_pl_smallstr!("batch_source[{name}]");
28-
Self {
29-
name,
30-
output_schema,
31-
get_batch_fn,
82+
impl GetBatchState {
83+
pub fn peek(&mut self, state: &StreamingExecutionState) -> PolarsResult<Option<&DataFrame>> {
84+
if self.peek.is_none() {
85+
self.peek = (self.func)(state)?;
86+
}
87+
88+
Ok(self.peek.as_ref())
89+
}
90+
91+
pub fn next(&mut self, state: &StreamingExecutionState) -> PolarsResult<Option<DataFrame>> {
92+
if let Some(df) = self.peek.take() {
93+
Ok(Some(df))
94+
} else {
95+
(self.func)(state)
3296
}
3397
}
3498
}
3599

36-
impl SourceNode for BatchSourceNode {
37-
fn name(&self) -> &str {
38-
self.name.as_str()
100+
impl From<GetBatchFn> for GetBatchState {
101+
fn from(func: GetBatchFn) -> Self {
102+
Self { func, peek: None }
39103
}
104+
}
105+
106+
pub struct BatchFnReader {
107+
pub name: PlSmallStr,
108+
pub output_schema: Option<SchemaRef>,
109+
pub get_batch_state: Option<GetBatchState>,
110+
pub verbose: bool,
111+
}
40112

41-
fn is_source_output_parallel(&self, _is_receiver_serial: bool) -> bool {
42-
false
113+
#[async_trait]
114+
impl FileReader for BatchFnReader {
115+
async fn initialize(&mut self) -> PolarsResult<()> {
116+
Ok(())
43117
}
44118

45-
fn spawn_source(
119+
fn begin_read(
46120
&mut self,
47-
mut output_recv: Receiver<SourceOutput>,
48-
state: &StreamingExecutionState,
49-
join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
50-
unrestricted_row_count: Option<oneshot::Sender<polars_utils::IdxSize>>,
51-
) {
52-
// We only spawn this once, so this is all fine.
53-
let output_schema = self.output_schema.clone();
54-
let get_batch_fn = self.get_batch_fn.take().unwrap();
55-
let state = state.clone();
56-
join_handles.push(spawn(TaskPriority::Low, async move {
57-
let mut seq = MorselSeq::default();
58-
let mut n_rows_seen = 0;
59-
60-
'phase_loop: while let Ok(phase_output) = output_recv.recv().await {
61-
let mut sender = phase_output.port.serial();
62-
let source_token = SourceToken::new();
63-
let wait_group = WaitGroup::default();
64-
65-
loop {
66-
let df = (get_batch_fn)(&state)?;
67-
let Some(df) = df else {
68-
if let Some(unrestricted_row_count) = unrestricted_row_count {
69-
if unrestricted_row_count.send(n_rows_seen).is_err() {
70-
return Ok(());
71-
}
72-
}
73-
74-
if n_rows_seen == 0 {
75-
let morsel = Morsel::new(
76-
DataFrame::empty_with_schema(output_schema.as_ref()),
77-
seq,
78-
source_token.clone(),
79-
);
80-
if sender.send(morsel).await.is_err() {
81-
return Ok(());
82-
}
83-
}
84-
85-
break 'phase_loop;
86-
};
87-
88-
let num_rows = IdxSize::try_from(df.height()).map_err(|_| {
89-
polars_err!(bigidx, ctx = "batch source", size = df.height())
90-
})?;
91-
n_rows_seen = n_rows_seen.checked_add(num_rows).ok_or_else(|| {
92-
polars_err!(
93-
bigidx,
94-
ctx = "batch source",
95-
size = n_rows_seen as usize + num_rows as usize
96-
)
97-
})?;
98-
99-
let mut morsel = Morsel::new(df, seq, source_token.clone());
100-
morsel.set_consume_token(wait_group.token());
101-
seq = seq.successor();
102-
103-
if sender.send(morsel).await.is_err() {
104-
return Ok(());
105-
}
106-
107-
wait_group.wait().await;
108-
if source_token.stop_requested() {
109-
phase_output.outcome.stop();
110-
continue 'phase_loop;
111-
}
121+
args: BeginReadArgs,
122+
) -> PolarsResult<(FileReaderOutputRecv, JoinHandle<PolarsResult<()>>)> {
123+
let BeginReadArgs {
124+
projected_schema: _,
125+
row_index: None,
126+
pre_slice: None,
127+
predicate: None,
128+
cast_columns_policy: _,
129+
num_pipelines: _,
130+
callbacks:
131+
FileReaderCallbacks {
132+
file_schema_tx,
133+
n_rows_in_file_tx,
134+
row_position_on_end_tx,
135+
},
136+
} = args
137+
else {
138+
panic!("unsupported args: {:?}", &args)
139+
};
140+
141+
// Must send this first before we `take()` the GetBatchState.
142+
if let Some(mut file_schema_tx) = file_schema_tx {
143+
_ = file_schema_tx.try_send(self._file_schema()?);
144+
}
145+
146+
let mut get_batch_state = self
147+
.get_batch_state
148+
.take()
149+
// If this is ever needed we can buffer
150+
.expect("unimplemented: BatchFnReader called more than once");
151+
152+
// FIXME: Propagate this from BeginReadArgs.
153+
let exec_state = StreamingExecutionState::default();
154+
155+
let verbose = self.verbose;
156+
157+
if verbose {
158+
eprintln!("[BatchFnReader]: name: {}", self.name);
159+
}
160+
161+
let (mut morsel_sender, morsel_rx) = FileReaderOutputSend::new_serial();
162+
163+
let handle = spawn(TaskPriority::Low, async move {
164+
let mut seq: u64 = 0;
165+
// Note: We don't use this (it is handled by the bridge). But morsels require a source token.
166+
let source_token = SourceToken::new();
167+
168+
let mut n_rows_seen: usize = 0;
169+
170+
while let Some(df) = get_batch_state.next(&exec_state)? {
171+
n_rows_seen = n_rows_seen.saturating_add(df.height());
172+
173+
if morsel_sender
174+
.send_morsel(Morsel::new(df, MorselSeq::new(seq), source_token.clone()))
175+
.await
176+
.is_err()
177+
{
178+
break;
179+
};
180+
seq = seq.saturating_add(1);
181+
}
182+
183+
if let Some(mut row_position_on_end_tx) = row_position_on_end_tx {
184+
let n_rows_seen = IdxSize::try_from(n_rows_seen)
185+
.map_err(|_| polars_err!(bigidx, ctx = "batch reader", size = n_rows_seen))?;
186+
187+
_ = row_position_on_end_tx.try_send(n_rows_seen)
188+
}
189+
190+
if let Some(mut n_rows_in_file_tx) = n_rows_in_file_tx {
191+
if verbose {
192+
eprintln!("[BatchFnReader]: read to end for full row count");
193+
}
194+
195+
while let Some(df) = get_batch_state.next(&exec_state)? {
196+
n_rows_seen = n_rows_seen.saturating_add(df.height());
112197
}
198+
199+
let n_rows_seen = IdxSize::try_from(n_rows_seen)
200+
.map_err(|_| polars_err!(bigidx, ctx = "batch reader", size = n_rows_seen))?;
201+
202+
_ = n_rows_in_file_tx.try_send(n_rows_seen)
113203
}
114204

115205
Ok(())
116-
}));
206+
});
207+
208+
Ok((morsel_rx, handle))
209+
}
210+
}
211+
212+
impl BatchFnReader {
213+
pub fn _file_schema(&mut self) -> PolarsResult<SchemaRef> {
214+
if self.output_schema.is_none() {
215+
let exec_state = StreamingExecutionState::default();
216+
217+
let schema =
218+
if let Some(df) = self.get_batch_state.as_mut().unwrap().peek(&exec_state)? {
219+
df.schema().clone()
220+
} else {
221+
SchemaRef::default()
222+
};
223+
224+
self.output_schema = Some(schema);
225+
}
226+
227+
Ok(self.output_schema.clone().unwrap())
117228
}
118229
}

0 commit comments

Comments
 (0)