|
| 1 | +//! Reads batches from a `dyn Fn` |
| 2 | +
|
| 3 | +use async_trait::async_trait; |
1 | 4 | use polars_core::frame::DataFrame;
|
2 | 5 | use polars_core::schema::SchemaRef;
|
3 | 6 | use polars_error::{PolarsResult, polars_err};
|
| 7 | +use polars_utils::IdxSize; |
4 | 8 | use polars_utils::pl_str::PlSmallStr;
|
5 |
| -use polars_utils::{IdxSize, format_pl_smallstr}; |
6 |
| -use tokio::sync::oneshot; |
7 | 9 |
|
8 |
| -use super::{ |
9 |
| - JoinHandle, Morsel, MorselSeq, SourceNode, SourceOutput, StreamingExecutionState, TaskPriority, |
| 10 | +use crate::async_executor::{JoinHandle, TaskPriority, spawn}; |
| 11 | +use crate::execute::StreamingExecutionState; |
| 12 | +use crate::morsel::{Morsel, MorselSeq, SourceToken}; |
| 13 | +use crate::nodes::io_sources::multi_file_reader::reader_interface::output::{ |
| 14 | + FileReaderOutputRecv, FileReaderOutputSend, |
| 15 | +}; |
| 16 | +use crate::nodes::io_sources::multi_file_reader::reader_interface::{ |
| 17 | + BeginReadArgs, FileReader, FileReaderCallbacks, |
10 | 18 | };
|
11 |
| -use crate::async_executor::spawn; |
12 |
| -use crate::async_primitives::connector::Receiver; |
13 |
| -use crate::async_primitives::wait_group::WaitGroup; |
14 |
| -use crate::morsel::SourceToken; |
15 | 19 |
|
16 |
| -type GetBatchFn = |
| 20 | +pub mod builder { |
| 21 | + |
| 22 | + use std::sync::{Arc, Mutex}; |
| 23 | + |
| 24 | + use polars_utils::pl_str::PlSmallStr; |
| 25 | + |
| 26 | + use super::BatchFnReader; |
| 27 | + use crate::nodes::io_sources::multi_file_reader::reader_interface::FileReader; |
| 28 | + use crate::nodes::io_sources::multi_file_reader::reader_interface::builder::FileReaderBuilder; |
| 29 | + use crate::nodes::io_sources::multi_file_reader::reader_interface::capabilities::ReaderCapabilities; |
| 30 | + |
| 31 | + pub struct BatchFnReaderBuilder { |
| 32 | + pub name: PlSmallStr, |
| 33 | + pub reader: Mutex<Option<BatchFnReader>>, |
| 34 | + } |
| 35 | + |
| 36 | + impl FileReaderBuilder for BatchFnReaderBuilder { |
| 37 | + fn reader_name(&self) -> &str { |
| 38 | + &self.name |
| 39 | + } |
| 40 | + |
| 41 | + fn reader_capabilities(&self) -> ReaderCapabilities { |
| 42 | + ReaderCapabilities::empty() |
| 43 | + } |
| 44 | + |
| 45 | + fn build_file_reader( |
| 46 | + &self, |
| 47 | + _source: polars_plan::prelude::ScanSource, |
| 48 | + _cloud_options: Option<Arc<polars_io::cloud::CloudOptions>>, |
| 49 | + scan_source_idx: usize, |
| 50 | + ) -> Box<dyn FileReader> { |
| 51 | + assert_eq!(scan_source_idx, 0); |
| 52 | + |
| 53 | + Box::new( |
| 54 | + self.reader |
| 55 | + .try_lock() |
| 56 | + .unwrap() |
| 57 | + .take() |
| 58 | + .expect("BatchFnReaderBuilder called more than once"), |
| 59 | + ) as Box<dyn FileReader> |
| 60 | + } |
| 61 | + } |
| 62 | + |
| 63 | + impl std::fmt::Debug for BatchFnReaderBuilder { |
| 64 | + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { |
| 65 | + f.write_str("BatchFnReaderBuilder: name: ")?; |
| 66 | + f.write_str(&self.name)?; |
| 67 | + |
| 68 | + Ok(()) |
| 69 | + } |
| 70 | + } |
| 71 | +} |
| 72 | + |
| 73 | +pub type GetBatchFn = |
17 | 74 | Box<dyn Fn(&StreamingExecutionState) -> PolarsResult<Option<DataFrame>> + Send + Sync>;
|
18 | 75 |
|
19 |
| -pub struct BatchSourceNode { |
20 |
| - pub name: PlSmallStr, |
21 |
| - pub output_schema: SchemaRef, |
22 |
| - pub get_batch_fn: Option<GetBatchFn>, |
| 76 | +/// Wraps `GetBatchFn` to support peeking. |
| 77 | +pub struct GetBatchState { |
| 78 | + func: GetBatchFn, |
| 79 | + peek: Option<DataFrame>, |
23 | 80 | }
|
24 | 81 |
|
25 |
| -impl BatchSourceNode { |
26 |
| - pub fn new(name: &str, output_schema: SchemaRef, get_batch_fn: Option<GetBatchFn>) -> Self { |
27 |
| - let name = format_pl_smallstr!("batch_source[{name}]"); |
28 |
| - Self { |
29 |
| - name, |
30 |
| - output_schema, |
31 |
| - get_batch_fn, |
| 82 | +impl GetBatchState { |
| 83 | + pub fn peek(&mut self, state: &StreamingExecutionState) -> PolarsResult<Option<&DataFrame>> { |
| 84 | + if self.peek.is_none() { |
| 85 | + self.peek = (self.func)(state)?; |
| 86 | + } |
| 87 | + |
| 88 | + Ok(self.peek.as_ref()) |
| 89 | + } |
| 90 | + |
| 91 | + pub fn next(&mut self, state: &StreamingExecutionState) -> PolarsResult<Option<DataFrame>> { |
| 92 | + if let Some(df) = self.peek.take() { |
| 93 | + Ok(Some(df)) |
| 94 | + } else { |
| 95 | + (self.func)(state) |
32 | 96 | }
|
33 | 97 | }
|
34 | 98 | }
|
35 | 99 |
|
36 |
| -impl SourceNode for BatchSourceNode { |
37 |
| - fn name(&self) -> &str { |
38 |
| - self.name.as_str() |
| 100 | +impl From<GetBatchFn> for GetBatchState { |
| 101 | + fn from(func: GetBatchFn) -> Self { |
| 102 | + Self { func, peek: None } |
39 | 103 | }
|
| 104 | +} |
| 105 | + |
| 106 | +pub struct BatchFnReader { |
| 107 | + pub name: PlSmallStr, |
| 108 | + pub output_schema: Option<SchemaRef>, |
| 109 | + pub get_batch_state: Option<GetBatchState>, |
| 110 | + pub verbose: bool, |
| 111 | +} |
40 | 112 |
|
41 |
| - fn is_source_output_parallel(&self, _is_receiver_serial: bool) -> bool { |
42 |
| - false |
| 113 | +#[async_trait] |
| 114 | +impl FileReader for BatchFnReader { |
| 115 | + async fn initialize(&mut self) -> PolarsResult<()> { |
| 116 | + Ok(()) |
43 | 117 | }
|
44 | 118 |
|
45 |
| - fn spawn_source( |
| 119 | + fn begin_read( |
46 | 120 | &mut self,
|
47 |
| - mut output_recv: Receiver<SourceOutput>, |
48 |
| - state: &StreamingExecutionState, |
49 |
| - join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>, |
50 |
| - unrestricted_row_count: Option<oneshot::Sender<polars_utils::IdxSize>>, |
51 |
| - ) { |
52 |
| - // We only spawn this once, so this is all fine. |
53 |
| - let output_schema = self.output_schema.clone(); |
54 |
| - let get_batch_fn = self.get_batch_fn.take().unwrap(); |
55 |
| - let state = state.clone(); |
56 |
| - join_handles.push(spawn(TaskPriority::Low, async move { |
57 |
| - let mut seq = MorselSeq::default(); |
58 |
| - let mut n_rows_seen = 0; |
59 |
| - |
60 |
| - 'phase_loop: while let Ok(phase_output) = output_recv.recv().await { |
61 |
| - let mut sender = phase_output.port.serial(); |
62 |
| - let source_token = SourceToken::new(); |
63 |
| - let wait_group = WaitGroup::default(); |
64 |
| - |
65 |
| - loop { |
66 |
| - let df = (get_batch_fn)(&state)?; |
67 |
| - let Some(df) = df else { |
68 |
| - if let Some(unrestricted_row_count) = unrestricted_row_count { |
69 |
| - if unrestricted_row_count.send(n_rows_seen).is_err() { |
70 |
| - return Ok(()); |
71 |
| - } |
72 |
| - } |
73 |
| - |
74 |
| - if n_rows_seen == 0 { |
75 |
| - let morsel = Morsel::new( |
76 |
| - DataFrame::empty_with_schema(output_schema.as_ref()), |
77 |
| - seq, |
78 |
| - source_token.clone(), |
79 |
| - ); |
80 |
| - if sender.send(morsel).await.is_err() { |
81 |
| - return Ok(()); |
82 |
| - } |
83 |
| - } |
84 |
| - |
85 |
| - break 'phase_loop; |
86 |
| - }; |
87 |
| - |
88 |
| - let num_rows = IdxSize::try_from(df.height()).map_err(|_| { |
89 |
| - polars_err!(bigidx, ctx = "batch source", size = df.height()) |
90 |
| - })?; |
91 |
| - n_rows_seen = n_rows_seen.checked_add(num_rows).ok_or_else(|| { |
92 |
| - polars_err!( |
93 |
| - bigidx, |
94 |
| - ctx = "batch source", |
95 |
| - size = n_rows_seen as usize + num_rows as usize |
96 |
| - ) |
97 |
| - })?; |
98 |
| - |
99 |
| - let mut morsel = Morsel::new(df, seq, source_token.clone()); |
100 |
| - morsel.set_consume_token(wait_group.token()); |
101 |
| - seq = seq.successor(); |
102 |
| - |
103 |
| - if sender.send(morsel).await.is_err() { |
104 |
| - return Ok(()); |
105 |
| - } |
106 |
| - |
107 |
| - wait_group.wait().await; |
108 |
| - if source_token.stop_requested() { |
109 |
| - phase_output.outcome.stop(); |
110 |
| - continue 'phase_loop; |
111 |
| - } |
| 121 | + args: BeginReadArgs, |
| 122 | + ) -> PolarsResult<(FileReaderOutputRecv, JoinHandle<PolarsResult<()>>)> { |
| 123 | + let BeginReadArgs { |
| 124 | + projected_schema: _, |
| 125 | + row_index: None, |
| 126 | + pre_slice: None, |
| 127 | + predicate: None, |
| 128 | + cast_columns_policy: _, |
| 129 | + num_pipelines: _, |
| 130 | + callbacks: |
| 131 | + FileReaderCallbacks { |
| 132 | + file_schema_tx, |
| 133 | + n_rows_in_file_tx, |
| 134 | + row_position_on_end_tx, |
| 135 | + }, |
| 136 | + } = args |
| 137 | + else { |
| 138 | + panic!("unsupported args: {:?}", &args) |
| 139 | + }; |
| 140 | + |
| 141 | + // Must send this first before we `take()` the GetBatchState. |
| 142 | + if let Some(mut file_schema_tx) = file_schema_tx { |
| 143 | + _ = file_schema_tx.try_send(self._file_schema()?); |
| 144 | + } |
| 145 | + |
| 146 | + let mut get_batch_state = self |
| 147 | + .get_batch_state |
| 148 | + .take() |
| 149 | + // If this is ever needed we can buffer |
| 150 | + .expect("unimplemented: BatchFnReader called more than once"); |
| 151 | + |
| 152 | + // FIXME: Propagate this from BeginReadArgs. |
| 153 | + let exec_state = StreamingExecutionState::default(); |
| 154 | + |
| 155 | + let verbose = self.verbose; |
| 156 | + |
| 157 | + if verbose { |
| 158 | + eprintln!("[BatchFnReader]: name: {}", self.name); |
| 159 | + } |
| 160 | + |
| 161 | + let (mut morsel_sender, morsel_rx) = FileReaderOutputSend::new_serial(); |
| 162 | + |
| 163 | + let handle = spawn(TaskPriority::Low, async move { |
| 164 | + let mut seq: u64 = 0; |
| 165 | + // Note: We don't use this (it is handled by the bridge). But morsels require a source token. |
| 166 | + let source_token = SourceToken::new(); |
| 167 | + |
| 168 | + let mut n_rows_seen: usize = 0; |
| 169 | + |
| 170 | + while let Some(df) = get_batch_state.next(&exec_state)? { |
| 171 | + n_rows_seen = n_rows_seen.saturating_add(df.height()); |
| 172 | + |
| 173 | + if morsel_sender |
| 174 | + .send_morsel(Morsel::new(df, MorselSeq::new(seq), source_token.clone())) |
| 175 | + .await |
| 176 | + .is_err() |
| 177 | + { |
| 178 | + break; |
| 179 | + }; |
| 180 | + seq = seq.saturating_add(1); |
| 181 | + } |
| 182 | + |
| 183 | + if let Some(mut row_position_on_end_tx) = row_position_on_end_tx { |
| 184 | + let n_rows_seen = IdxSize::try_from(n_rows_seen) |
| 185 | + .map_err(|_| polars_err!(bigidx, ctx = "batch reader", size = n_rows_seen))?; |
| 186 | + |
| 187 | + _ = row_position_on_end_tx.try_send(n_rows_seen) |
| 188 | + } |
| 189 | + |
| 190 | + if let Some(mut n_rows_in_file_tx) = n_rows_in_file_tx { |
| 191 | + if verbose { |
| 192 | + eprintln!("[BatchFnReader]: read to end for full row count"); |
| 193 | + } |
| 194 | + |
| 195 | + while let Some(df) = get_batch_state.next(&exec_state)? { |
| 196 | + n_rows_seen = n_rows_seen.saturating_add(df.height()); |
112 | 197 | }
|
| 198 | + |
| 199 | + let n_rows_seen = IdxSize::try_from(n_rows_seen) |
| 200 | + .map_err(|_| polars_err!(bigidx, ctx = "batch reader", size = n_rows_seen))?; |
| 201 | + |
| 202 | + _ = n_rows_in_file_tx.try_send(n_rows_seen) |
113 | 203 | }
|
114 | 204 |
|
115 | 205 | Ok(())
|
116 |
| - })); |
| 206 | + }); |
| 207 | + |
| 208 | + Ok((morsel_rx, handle)) |
| 209 | + } |
| 210 | +} |
| 211 | + |
| 212 | +impl BatchFnReader { |
| 213 | + pub fn _file_schema(&mut self) -> PolarsResult<SchemaRef> { |
| 214 | + if self.output_schema.is_none() { |
| 215 | + let exec_state = StreamingExecutionState::default(); |
| 216 | + |
| 217 | + let schema = |
| 218 | + if let Some(df) = self.get_batch_state.as_mut().unwrap().peek(&exec_state)? { |
| 219 | + df.schema().clone() |
| 220 | + } else { |
| 221 | + SchemaRef::default() |
| 222 | + }; |
| 223 | + |
| 224 | + self.output_schema = Some(schema); |
| 225 | + } |
| 226 | + |
| 227 | + Ok(self.output_schema.clone().unwrap()) |
117 | 228 | }
|
118 | 229 | }
|
0 commit comments