Skip to content

Commit 56f5089

Browse files
authored
[FEAT] Enable Actor Pool UDFs by default (Eventual-Inc#3488)
Todo:
- [x] fix tests
- [ ] add docs (future PR)
- [ ] add threaded concurrency (future PR)
1 parent 8de0101 commit 56f5089

File tree

36 files changed

+810
-1561
lines changed

36 files changed

+810
-1561
lines changed

daft/daft/__init__.pyi

Lines changed: 7 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@ from daft.io.scan import ScanOperator
99
from daft.plan_scheduler.physical_plan_scheduler import PartitionT
1010
from daft.runners.partitioning import PartitionCacheEntry
1111
from daft.sql.sql_connection import SQLConnection
12-
from daft.udf import InitArgsType, PartialStatefulUDF, PartialStatelessUDF
12+
from daft.udf import BoundUDFArgs, InitArgsType, UninitializedUdf
1313

1414
if TYPE_CHECKING:
1515
import pyarrow as pa
@@ -1123,29 +1123,20 @@ def interval_lit(
11231123
) -> PyExpr: ...
11241124
def decimal_lit(sign: bool, digits: tuple[int, ...], exp: int) -> PyExpr: ...
11251125
def series_lit(item: PySeries) -> PyExpr: ...
1126-
def stateless_udf(
1126+
def udf(
11271127
name: str,
1128-
partial_stateless_udf: PartialStatelessUDF,
1128+
inner: UninitializedUdf,
1129+
bound_args: BoundUDFArgs,
11291130
expressions: list[PyExpr],
11301131
return_dtype: PyDataType,
1131-
resource_request: ResourceRequest | None,
1132-
batch_size: int | None,
1133-
) -> PyExpr: ...
1134-
def stateful_udf(
1135-
name: str,
1136-
partial_stateful_udf: PartialStatefulUDF,
1137-
expressions: list[PyExpr],
1138-
return_dtype: PyDataType,
1139-
resource_request: ResourceRequest | None,
11401132
init_args: InitArgsType,
1133+
resource_request: ResourceRequest | None,
11411134
batch_size: int | None,
11421135
concurrency: int | None,
11431136
) -> PyExpr: ...
11441137
def check_column_name_validity(name: str, schema: PySchema): ...
1145-
def extract_partial_stateful_udf_py(
1146-
expression: PyExpr,
1147-
) -> dict[str, tuple[PartialStatefulUDF, InitArgsType]]: ...
1148-
def bind_stateful_udfs(expression: PyExpr, initialized_funcs: dict[str, Callable]) -> PyExpr: ...
1138+
def initialize_udfs(expression: PyExpr) -> PyExpr: ...
1139+
def get_udf_names(expression: PyExpr) -> list[str]: ...
11491140
def resolve_expr(expr: PyExpr, schema: PySchema) -> tuple[PyExpr, PyField]: ...
11501141
def hash(expr: PyExpr, seed: Any | None = None) -> PyExpr: ...
11511142
def cosine_distance(expr: PyExpr, other: PyExpr) -> PyExpr: ...
@@ -1885,12 +1876,9 @@ class PyDaftPlanningConfig:
18851876
def with_config_values(
18861877
self,
18871878
default_io_config: IOConfig | None = None,
1888-
enable_actor_pool_projections: bool | None = None,
18891879
) -> PyDaftPlanningConfig: ...
18901880
@property
18911881
def default_io_config(self) -> IOConfig: ...
1892-
@property
1893-
def enable_actor_pool_projections(self) -> bool: ...
18941882

18951883
def build_type() -> str: ...
18961884
def version() -> str: ...

daft/execution/actor_pool_udf.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
from __future__ import annotations
2+
3+
import logging
4+
import multiprocessing as mp
5+
from typing import TYPE_CHECKING
6+
7+
from daft.expressions import Expression, ExpressionsProjection
8+
from daft.table import MicroPartition
9+
10+
if TYPE_CHECKING:
11+
from multiprocessing.connection import Connection
12+
13+
from daft.daft import PyExpr, PyMicroPartition
14+
15+
logger = logging.getLogger(__name__)
16+
17+
18+
def actor_event_loop(uninitialized_projection: ExpressionsProjection, conn: Connection) -> None:
19+
"""
20+
Event loop that runs in an actor process and receives MicroPartitions to evaluate with an initialized UDF projection.
21+
22+
Terminates once it receives None.
23+
"""
24+
initialized_projection = ExpressionsProjection([e._initialize_udfs() for e in uninitialized_projection])
25+
26+
while True:
27+
input: MicroPartition | None = conn.recv()
28+
if input is None:
29+
break
30+
31+
output = input.eval_expression_list(initialized_projection)
32+
conn.send(output)
33+
34+
35+
class ActorHandle:
36+
"""Handle class for initializing, interacting with, and tearing down a single local actor process."""
37+
38+
def __init__(self, projection: list[PyExpr]) -> None:
39+
self.handle_conn, actor_conn = mp.Pipe()
40+
41+
expr_projection = ExpressionsProjection([Expression._from_pyexpr(expr) for expr in projection])
42+
self.actor_process = mp.Process(target=actor_event_loop, args=(expr_projection, actor_conn))
43+
self.actor_process.start()
44+
45+
def eval_input(self, input: PyMicroPartition) -> PyMicroPartition:
46+
self.handle_conn.send(MicroPartition._from_pymicropartition(input))
47+
output: MicroPartition = self.handle_conn.recv()
48+
return output._micropartition
49+
50+
def teardown(self) -> None:
51+
self.handle_conn.send(None)
52+
self.handle_conn.close()
53+
self.actor_process.join()

daft/execution/execution_step.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -553,7 +553,7 @@ def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata])
553553

554554

555555
@dataclass(frozen=True)
556-
class StatefulUDFProject(SingleOutputInstruction):
556+
class ActorPoolProject(SingleOutputInstruction):
557557
projection: ExpressionsProjection
558558

559559
def run(self, inputs: list[MicroPartition]) -> list[MicroPartition]:
@@ -564,7 +564,7 @@ def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata])
564564
PartialPartitionMetadata(
565565
num_rows=None, # UDFs can potentially change cardinality
566566
size_bytes=None,
567-
boundaries=None, # TODO: figure out if the stateful UDF projection changes boundaries
567+
boundaries=None, # TODO: figure out if the actor pool UDF projection changes boundaries
568568
)
569569
]
570570

daft/execution/physical_plan.py

Lines changed: 5 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -229,7 +229,7 @@ def actor_pool_context(
229229
name: Name of the actor pool for debugging/observability
230230
resource_request: Requested amount of resources for each actor
231231
num_actors: Number of actors to spin up
232-
projection: Projection to be run on the incoming data (contains Stateful UDFs as well as other stateless expressions such as aliases)
232+
projection: Projection to be run on the incoming data (contains actor pool UDFs as well as other stateless expressions such as aliases)
233233
"""
234234
...
235235

@@ -243,12 +243,10 @@ def actor_pool_project(
243243
) -> InProgressPhysicalPlan[PartitionT]:
244244
stage_id = next(stage_id_counter)
245245

246-
from daft.daft import extract_partial_stateful_udf_py
246+
from daft.daft import get_udf_names
247247

248-
stateful_udf_names = "-".join(
249-
name for expr in projection for name in extract_partial_stateful_udf_py(expr._expr).keys()
250-
)
251-
actor_pool_name = f"{stateful_udf_names}-stage={stage_id}"
248+
udf_names = "-".join(name for expr in projection for name in get_udf_names(expr._expr))
249+
actor_pool_name = f"{udf_names}-stage={stage_id}"
252250

253251
# Keep track of materializations of the children tasks
254252
child_materializations: deque[SingleOutputPartitionTask[PartitionT]] = deque()
@@ -285,7 +283,7 @@ def actor_pool_project(
285283
actor_pool_id=actor_pool_id,
286284
)
287285
.add_instruction(
288-
instruction=execution_step.StatefulUDFProject(projection),
286+
instruction=execution_step.ActorPoolProject(projection),
289287
resource_request=task_resource_request,
290288
)
291289
.finalize_partition_task_single_output(

daft/execution/stateful_actor.py

Lines changed: 0 additions & 79 deletions
This file was deleted.

daft/expressions/expressions.py

Lines changed: 16 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
import daft.daft as native
2121
from daft import context
22-
from daft.daft import CountMode, ImageFormat, ImageMode, ResourceRequest, bind_stateful_udfs
22+
from daft.daft import CountMode, ImageFormat, ImageMode, ResourceRequest, initialize_udfs
2323
from daft.daft import PyExpr as _PyExpr
2424
from daft.daft import col as _col
2525
from daft.daft import date_lit as _date_lit
@@ -28,13 +28,12 @@
2828
from daft.daft import list_sort as _list_sort
2929
from daft.daft import lit as _lit
3030
from daft.daft import series_lit as _series_lit
31-
from daft.daft import stateful_udf as _stateful_udf
32-
from daft.daft import stateless_udf as _stateless_udf
3331
from daft.daft import time_lit as _time_lit
3432
from daft.daft import timestamp_lit as _timestamp_lit
3533
from daft.daft import to_struct as _to_struct
3634
from daft.daft import tokenize_decode as _tokenize_decode
3735
from daft.daft import tokenize_encode as _tokenize_encode
36+
from daft.daft import udf as _udf
3837
from daft.daft import url_download as _url_download
3938
from daft.daft import utf8_count_matches as _utf8_count_matches
4039
from daft.datatype import DataType, TimeUnit
@@ -45,7 +44,7 @@
4544

4645
if TYPE_CHECKING:
4746
from daft.io import IOConfig
48-
from daft.udf import PartialStatefulUDF, PartialStatelessUDF
47+
from daft.udf import BoundUDFArgs, InitArgsType, UninitializedUdf
4948
# This allows Sphinx to correctly work against our "namespaced" accessor functions by overriding @property to
5049
# return a class instance of the namespace instead of a property object.
5150
elif os.getenv("DAFT_SPHINX_BUILD") == "1":
@@ -260,39 +259,26 @@ def _to_expression(obj: object) -> Expression:
260259
return lit(obj)
261260

262261
@staticmethod
263-
def stateless_udf(
262+
def udf(
264263
name: builtins.str,
265-
partial: PartialStatelessUDF,
264+
inner: UninitializedUdf,
265+
bound_args: BoundUDFArgs,
266266
expressions: builtins.list[Expression],
267267
return_dtype: DataType,
268+
init_args: InitArgsType,
268269
resource_request: ResourceRequest | None,
269270
batch_size: int | None,
270-
) -> Expression:
271-
return Expression._from_pyexpr(
272-
_stateless_udf(
273-
name, partial, [e._expr for e in expressions], return_dtype._dtype, resource_request, batch_size
274-
)
275-
)
276-
277-
@staticmethod
278-
def stateful_udf(
279-
name: builtins.str,
280-
partial: PartialStatefulUDF,
281-
expressions: builtins.list[Expression],
282-
return_dtype: DataType,
283-
resource_request: ResourceRequest | None,
284-
init_args: tuple[tuple[Any, ...], dict[builtins.str, Any]] | None,
285-
batch_size: int | None,
286271
concurrency: int | None,
287272
) -> Expression:
288273
return Expression._from_pyexpr(
289-
_stateful_udf(
274+
_udf(
290275
name,
291-
partial,
276+
inner,
277+
bound_args,
292278
[e._expr for e in expressions],
293279
return_dtype._dtype,
294-
resource_request,
295280
init_args,
281+
resource_request,
296282
batch_size,
297283
concurrency,
298284
)
@@ -1018,7 +1004,7 @@ def apply(self, func: Callable, return_dtype: DataType) -> Expression:
10181004
Returns:
10191005
Expression: New expression after having run the function on the expression
10201006
"""
1021-
from daft.udf import CommonUDFArgs, StatelessUDF
1007+
from daft.udf import UDF
10221008

10231009
def batch_func(self_series):
10241010
return [func(x) for x in self_series.to_pylist()]
@@ -1028,14 +1014,10 @@ def batch_func(self_series):
10281014
name = name + "."
10291015
name = name + getattr(func, "__qualname__") # type: ignore[call-overload]
10301016

1031-
return StatelessUDF(
1017+
return UDF(
1018+
inner=batch_func,
10321019
name=name,
1033-
func=batch_func,
10341020
return_dtype=return_dtype,
1035-
common_args=CommonUDFArgs(
1036-
resource_request=None,
1037-
batch_size=None,
1038-
),
10391021
)(self)
10401022

10411023
def is_null(self) -> Expression:
@@ -1263,8 +1245,8 @@ def __reduce__(self) -> tuple:
12631245
def _input_mapping(self) -> builtins.str | None:
12641246
return self._expr._input_mapping()
12651247

1266-
def _bind_stateful_udfs(self, initialized_funcs: dict[builtins.str, Callable]) -> Expression:
1267-
return Expression._from_pyexpr(bind_stateful_udfs(self._expr, initialized_funcs))
1248+
def _initialize_udfs(self) -> Expression:
1249+
return Expression._from_pyexpr(initialize_udfs(self._expr))
12681250

12691251

12701252
SomeExpressionNamespace = TypeVar("SomeExpressionNamespace", bound="ExpressionNamespace")

0 commit comments

Comments
 (0)