5
5
from collections import OrderedDict
6
6
from concurrent .futures import ThreadPoolExecutor
7
7
from datetime import datetime
8
- from typing import Any , Callable , Dict , List , Optional , Tuple , Union
8
+ from typing import Any , Callable , Dict , List , Optional , Tuple , Union , Iterator
9
9
10
10
import pandas as pd
11
11
import pyarrow as arrow
@@ -578,6 +578,7 @@ def map_batches(
578
578
func : Callable [[arrow .Table ], arrow .Table ],
579
579
* ,
580
580
batch_size : int = 122880 ,
581
+ streaming : bool = False ,
581
582
** kwargs ,
582
583
) -> DataFrame :
583
584
"""
@@ -590,18 +591,35 @@ def map_batches(
590
591
It should take a `arrow.Table` as input and returns a `arrow.Table`.
591
592
batch_size, optional
592
593
The number of rows in each batch. Defaults to 122880.
594
+ streaming, optional
595
+ If true, the function takes an iterator of `arrow.Table` as input and yields a streaming of `arrow.Table` as output.
596
+ i.e. func: Callable[[Iterator[arrow.Table]], Iterator[arrow.Table]]
597
+ Defaults to false.
593
598
"""
594
599
595
- def process_func (_runtime_ctx , tables : List [arrow .Table ]) -> arrow .Table :
596
- return func (tables [0 ])
600
+ if streaming :
601
+ def process_func (_runtime_ctx , readers : List [arrow .RecordBatchReader ]) -> Iterator [arrow .Table ]:
602
+ tables = map (lambda batch : arrow .Table .from_batches ([batch ]), readers [0 ])
603
+ return func (tables )
597
604
598
- plan = ArrowBatchNode (
599
- self .session ._ctx ,
600
- (self .plan ,),
601
- process_func = process_func ,
602
- streaming_batch_size = batch_size ,
603
- ** kwargs ,
604
- )
605
+ plan = ArrowStreamNode (
606
+ self .session ._ctx ,
607
+ (self .plan ,),
608
+ process_func = process_func ,
609
+ streaming_batch_size = batch_size ,
610
+ ** kwargs ,
611
+ )
612
+ else :
613
+ def process_func (_runtime_ctx , tables : List [arrow .Table ]) -> arrow .Table :
614
+ return func (tables [0 ])
615
+
616
+ plan = ArrowBatchNode (
617
+ self .session ._ctx ,
618
+ (self .plan ,),
619
+ process_func = process_func ,
620
+ streaming_batch_size = batch_size ,
621
+ ** kwargs ,
622
+ )
605
623
return DataFrame (self .session , plan , recompute = self .need_recompute )
606
624
607
625
def limit (self , limit : int ) -> DataFrame :
0 commit comments