
Commit 0f43cf8

expose ArrowStreamNode to DataFrame API
1 parent 770aa41

2 files changed: 54 additions & 11 deletions

smallpond/dataframe.py

Lines changed: 28 additions & 10 deletions
@@ -5,7 +5,7 @@
 from collections import OrderedDict
 from concurrent.futures import ThreadPoolExecutor
 from datetime import datetime
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Iterator

 import pandas as pd
 import pyarrow as arrow
@@ -578,6 +578,7 @@ def map_batches(
         func: Callable[[arrow.Table], arrow.Table],
         *,
         batch_size: int = 122880,
+        streaming: bool = False,
         **kwargs,
     ) -> DataFrame:
         """
@@ -590,18 +591,35 @@ def map_batches(
         It should take a `arrow.Table` as input and returns a `arrow.Table`.
         batch_size, optional
             The number of rows in each batch. Defaults to 122880.
+        streaming, optional
+            If true, the function takes an iterator of `arrow.Table` as input and yields a stream of `arrow.Table` as output,
+            i.e. func: Callable[[Iterator[arrow.Table]], Iterator[arrow.Table]].
+            Defaults to false.
         """

-        def process_func(_runtime_ctx, tables: List[arrow.Table]) -> arrow.Table:
-            return func(tables[0])
+        if streaming:
+            def process_func(_runtime_ctx, readers: List[arrow.RecordBatchReader]) -> Iterator[arrow.Table]:
+                tables = map(lambda batch: arrow.Table.from_batches([batch]), readers[0])
+                return func(tables)

-        plan = ArrowBatchNode(
-            self.session._ctx,
-            (self.plan,),
-            process_func=process_func,
-            streaming_batch_size=batch_size,
-            **kwargs,
-        )
+            plan = ArrowStreamNode(
+                self.session._ctx,
+                (self.plan,),
+                process_func=process_func,
+                streaming_batch_size=batch_size,
+                **kwargs,
+            )
+        else:
+            def process_func(_runtime_ctx, tables: List[arrow.Table]) -> arrow.Table:
+                return func(tables[0])
+
+            plan = ArrowBatchNode(
+                self.session._ctx,
+                (self.plan,),
+                process_func=process_func,
+                streaming_batch_size=batch_size,
+                **kwargs,
+            )
         return DataFrame(self.session, plan, recompute=self.need_recompute)

     def limit(self, limit: int) -> DataFrame:
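
For context, the streaming branch hands the user function an iterator of `arrow.Table` by wrapping each `arrow.RecordBatch` pulled from the reader into a single-batch table. A minimal sketch of that conversion with plain pyarrow, outside the smallpond runtime (the sample table and chunk size are invented for illustration):

import pyarrow as arrow

# Build a reader that yields record batches of at most 4 rows each.
table = arrow.table({"x": list(range(10))})
reader = arrow.RecordBatchReader.from_batches(
    table.schema, table.to_batches(max_chunksize=4)
)

# Mirror the wrapper in process_func: each RecordBatch becomes a
# single-batch Table, so the user function only ever sees arrow.Table.
tables = map(lambda batch: arrow.Table.from_batches([batch]), reader)
print([t.num_rows for t in tables])  # [4, 4, 2]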

tests/test_dataframe.py

Lines changed: 26 additions & 1 deletion
@@ -1,4 +1,4 @@
-from typing import List
+from typing import Iterator, List

 import pandas as pd
 import pyarrow as pa
@@ -84,6 +84,31 @@ def test_map_batches(sp: Session):
     assert df.take_all() == [{"num_rows": 350}, {"num_rows": 350}, {"num_rows": 300}]


+def test_map_batches_streaming(sp: Session):
+    df = sp.read_parquet("tests/data/mock_urls/*.parquet")
+
+    def batched2(tables: Iterator[pa.Table]) -> Iterator[pa.Table]:
+        # sum row counts over every two consecutive tables, like itertools.batched(tables, 2)
+        num_rows = 0
+        count = 0
+        for batch in tables:
+            num_rows += batch.num_rows
+            count += 1
+            if count == 2:
+                yield pa.table({"num_rows": [num_rows]})
+                num_rows = 0
+                count = 0
+        if count > 0:
+            yield pa.table({"num_rows": [num_rows]})
+
+    df = df.map_batches(
+        batched2,
+        batch_size=350,
+        streaming=True,
+    )
+    assert df.take_all() == [{"num_rows": 700}, {"num_rows": 300}]
+
+
 def test_filter(sp: Session):
     df = sp.from_arrow(pa.table({"a": [1, 2, 3], "b": [4, 5, 6]}))
     df1 = df.filter("a > 1")