
Commit 4788627

feat: bi-directional streaming map (numaproj#197)
Signed-off-by: Sidhant Kohli <[email protected]>
1 parent a1d3eb5 commit 4788627

18 files changed, +691 -402 lines changed
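
For orientation, the map UDF keeps its plain handler-plus-server shape on the user side; what this commit changes is the transport underneath, which now streams requests and responses over a single bidirectional MapFn call. A minimal sketch of an async UDF, assuming the MapAsyncServer, Messages, Message, and Datum names exported by pynumaflow.mapper (the handler body itself is illustrative, not part of this commit):

from pynumaflow.mapper import Datum, Message, Messages, MapAsyncServer


async def my_handler(keys: list[str], datum: Datum) -> Messages:
    # illustrative handler: forward the payload unchanged under the same keys
    return Messages(Message(value=datum.value, keys=keys))


if __name__ == "__main__":
    # the async servicer added in this commit handles the streaming underneath
    server = MapAsyncServer(my_handler)
    server.start()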

pynumaflow/mapper/_dtypes.py  (+4)

@@ -204,3 +204,7 @@ def handler(self, keys: list[str], datum: Datum) -> Messages:
 # MapAsyncCallable is a callable which can be used as a handler for the Asynchronous Map UDF
 MapAsyncHandlerCallable = Callable[[list[str], Datum], Awaitable[Messages]]
 MapAsyncCallable = Union[Mapper, MapAsyncHandlerCallable]
+
+
+class MapError(Exception):
+    """To Raise an error while executing a Map call"""
File renamed without changes.
pynumaflow/mapper/_servicer/_async_servicer.py  (new file, +129)

@@ -0,0 +1,129 @@
import asyncio
from collections.abc import AsyncIterable

from google.protobuf import empty_pb2 as _empty_pb2
from pynumaflow.shared.asynciter import NonBlockingIterator

from pynumaflow._constants import _LOGGER, STREAM_EOF
from pynumaflow.mapper._dtypes import MapAsyncCallable, Datum, MapError
from pynumaflow.proto.mapper import map_pb2, map_pb2_grpc
from pynumaflow.shared.server import exit_on_error, handle_async_error
from pynumaflow.types import NumaflowServicerContext


class AsyncMapServicer(map_pb2_grpc.MapServicer):
    """
    This class is used to create a new grpc Async Map Servicer instance.
    It implements the SyncMapServicer interface from the proto map.proto file.
    Provides the functionality for the required rpc methods.
    """

    def __init__(
        self,
        handler: MapAsyncCallable,
    ):
        self.background_tasks = set()
        self.__map_handler: MapAsyncCallable = handler

    async def MapFn(
        self,
        request_iterator: AsyncIterable[map_pb2.MapRequest],
        context: NumaflowServicerContext,
    ) -> AsyncIterable[map_pb2.MapResponse]:
        """
        Applies a function to each datum element.
        The pascal case function name comes from the proto map_pb2_grpc.py file.
        """
        # proto repeated field(keys) is of type google._upb._message.RepeatedScalarContainer
        # we need to explicitly convert it to list
        try:
            # The first message to be received should be a valid handshake
            req = await request_iterator.__anext__()
            # check if it is a valid handshake req
            if not (req.handshake and req.handshake.sot):
                raise MapError("MapFn: expected handshake as the first message")
            yield map_pb2.MapResponse(handshake=map_pb2.Handshake(sot=True))

            global_result_queue = NonBlockingIterator()

            # reader task to process the input task and invoke the required tasks
            producer = asyncio.create_task(
                self._process_inputs(request_iterator, global_result_queue)
            )

            # keep reading on result queue and send messages back
            consumer = global_result_queue.read_iterator()
            async for msg in consumer:
                # If the message is an exception, we raise the exception
                if isinstance(msg, BaseException):
                    await handle_async_error(context, msg)
                    return
                # Send window response back to the client
                else:
                    yield msg
            # wait for the producer task to complete
            await producer
        except BaseException as e:
            _LOGGER.critical("UDFError, re-raising the error", exc_info=True)
            exit_on_error(context, repr(e))
            return

    async def _process_inputs(
        self,
        request_iterator: AsyncIterable[map_pb2.MapRequest],
        result_queue: NonBlockingIterator,
    ):
        """
        Utility function for processing incoming MapRequests
        """
        try:
            # for each incoming request, create a background task to execute the
            # UDF code
            async for req in request_iterator:
                msg_task = asyncio.create_task(self._invoke_map(req, result_queue))
                # save a reference to a set to store active tasks
                self.background_tasks.add(msg_task)
                msg_task.add_done_callback(self.background_tasks.discard)

            # wait for all tasks to complete
            for task in self.background_tasks:
                await task

            # send an EOF to result queue to indicate that all tasks have completed
            await result_queue.put(STREAM_EOF)

        except BaseException as e:
            await result_queue.put(e)
            return

    async def _invoke_map(self, req: map_pb2.MapRequest, result_queue: NonBlockingIterator):
        """
        Invokes the user defined function.
        """
        try:
            datum = Datum(
                keys=list(req.request.keys),
                value=req.request.value,
                event_time=req.request.event_time.ToDatetime(),
                watermark=req.request.watermark.ToDatetime(),
                headers=dict(req.request.headers),
            )
            msgs = await self.__map_handler(list(req.request.keys), datum)
            datums = []
            for msg in msgs:
                datums.append(
                    map_pb2.MapResponse.Result(keys=msg.keys, value=msg.value, tags=msg.tags)
                )
            await result_queue.put(map_pb2.MapResponse(results=datums, id=req.id))
        except BaseException as err:
            _LOGGER.critical("UDFError, re-raising the error", exc_info=True)
            await result_queue.put(err)

    async def IsReady(
        self, request: _empty_pb2.Empty, context: NumaflowServicerContext
    ) -> map_pb2.ReadyResponse:
        """
        IsReady is the heartbeat endpoint for gRPC.
        The pascal case function name comes from the proto map_pb2_grpc.py file.
        """
        return map_pb2.ReadyResponse(ready=True)
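
To make the handshake and request flow above concrete, here is a rough sketch of the message sequence a client streams into MapFn. The field names mirror what the servicer reads (handshake.sot, request.keys, request.value, request.event_time, request.watermark, id); the nested MapRequest.Request constructor, the id, and the payload values are assumptions for illustration:

from google.protobuf import timestamp_pb2

from pynumaflow.proto.mapper import map_pb2


def build_request_stream():
    # the first message must be a handshake with sot=True,
    # otherwise the servicer raises MapError and closes the stream
    yield map_pb2.MapRequest(handshake=map_pb2.Handshake(sot=True))

    # every following message carries one datum; the id is echoed back on the
    # matching MapResponse, since each request runs in its own task and
    # responses are written to a shared result queue
    now = timestamp_pb2.Timestamp()
    now.GetCurrentTime()
    yield map_pb2.MapRequest(
        id="req-1",  # illustrative id
        request=map_pb2.MapRequest.Request(
            keys=["key-1"],
            value=b"payload",
            event_time=now,
            watermark=now,
        ),
    )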
pynumaflow/mapper/_servicer/_sync_servicer.py  (new file, +133)

@@ -0,0 +1,133 @@
import threading
from concurrent.futures import ThreadPoolExecutor
from collections.abc import Iterable

from google.protobuf import empty_pb2 as _empty_pb2
from pynumaflow.shared.server import exit_on_error

from pynumaflow._constants import NUM_THREADS_DEFAULT, STREAM_EOF, _LOGGER
from pynumaflow.mapper._dtypes import MapSyncCallable, Datum, MapError
from pynumaflow.proto.mapper import map_pb2, map_pb2_grpc
from pynumaflow.shared.synciter import SyncIterator
from pynumaflow.types import NumaflowServicerContext


class SyncMapServicer(map_pb2_grpc.MapServicer):
    """
    This class is used to create a new grpc Map Servicer instance.
    It implements the SyncMapServicer interface from the proto map.proto file.
    Provides the functionality for the required rpc methods.
    """

    def __init__(self, handler: MapSyncCallable, multiproc: bool = False):
        self.__map_handler: MapSyncCallable = handler
        # This indicates whether the grpc server attached is multiproc or not
        self.multiproc = multiproc
        # create a thread pool for executing UDF code
        self.executor = ThreadPoolExecutor(max_workers=NUM_THREADS_DEFAULT)

    def MapFn(
        self,
        request_iterator: Iterable[map_pb2.MapRequest],
        context: NumaflowServicerContext,
    ) -> Iterable[map_pb2.MapResponse]:
        """
        Applies a function to each datum element.
        The pascal case function name comes from the proto map_pb2_grpc.py file.
        """
        try:
            # The first message to be received should be a valid handshake
            req = next(request_iterator)
            # check if it is a valid handshake req
            if not (req.handshake and req.handshake.sot):
                raise MapError("MapFn: expected handshake as the first message")
            yield map_pb2.MapResponse(handshake=map_pb2.Handshake(sot=True))

            # result queue to stream messages from the user code back to the client
            result_queue = SyncIterator()

            # Reader thread to keep reading from the request iterator and schedule
            # execution for each of them
            reader_thread = threading.Thread(
                target=self._process_requests, args=(context, request_iterator, result_queue)
            )
            reader_thread.start()
            # Read the result queue and keep forwarding them upstream
            for res in result_queue.read_iterator():
                # if error handler accordingly
                if isinstance(res, BaseException):
                    # Terminate the current server process due to exception
                    exit_on_error(context, repr(res), parent=self.multiproc)
                    return
                # return the result
                yield res

            # wait for the threads to clean-up
            reader_thread.join()
            self.executor.shutdown(cancel_futures=True)

        except BaseException as err:
            _LOGGER.critical("UDFError, re-raising the error", exc_info=True)
            # Terminate the current server process due to exception
            exit_on_error(context, repr(err), parent=self.multiproc)
            return

    def _process_requests(
        self,
        context: NumaflowServicerContext,
        request_iterator: Iterable[map_pb2.MapRequest],
        result_queue: SyncIterator,
    ):
        try:
            # read through all incoming requests and submit to the
            # threadpool for invocation
            for request in request_iterator:
                _ = self.executor.submit(self._invoke_map, context, request, result_queue)
            # wait for all tasks to finish after all requests exhausted
            self.executor.shutdown(wait=True)
            # Indicate to the result queue that no more messages left to process
            result_queue.put(STREAM_EOF)
        except BaseException as e:
            _LOGGER.critical("MapFn Error, re-raising the error", exc_info=True)
            result_queue.put(e)

    def _invoke_map(
        self,
        context: NumaflowServicerContext,
        request: map_pb2.MapRequest,
        result_queue: SyncIterator,
    ):
        try:
            d = Datum(
                keys=list(request.request.keys),
                value=request.request.value,
                event_time=request.request.event_time.ToDatetime(),
                watermark=request.request.watermark.ToDatetime(),
                headers=dict(request.request.headers),
            )

            responses = self.__map_handler(list(request.request.keys), d)
            results = []
            for resp in responses:
                results.append(
                    map_pb2.MapResponse.Result(
                        keys=list(resp.keys),
                        value=resp.value,
                        tags=resp.tags,
                    )
                )
            result_queue.put(map_pb2.MapResponse(results=results, id=request.id))

        except BaseException as e:
            _LOGGER.critical("MapFn handler error", exc_info=True)
            result_queue.put(e)
            return

    def IsReady(
        self, request: _empty_pb2.Empty, context: NumaflowServicerContext
    ) -> map_pb2.ReadyResponse:
        """
        IsReady is the heartbeat endpoint for gRPC.
        The pascal case function name comes from the proto map_pb2_grpc.py file.
        """
        return map_pb2.ReadyResponse(ready=True)
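
The sync servicer touches SyncIterator only through put() and read_iterator(), with STREAM_EOF as the termination sentinel. As an illustration of that contract (not the actual pynumaflow.shared.synciter implementation, which this excerpt does not show), such an iterator could be a thin wrapper over queue.Queue:

import queue

from pynumaflow._constants import STREAM_EOF


class IllustrativeSyncIterator:
    """Sketch of the put()/read_iterator() contract used by SyncMapServicer."""

    def __init__(self):
        self._queue = queue.Queue()

    def put(self, item):
        # results, exceptions, and finally STREAM_EOF all flow through one queue
        self._queue.put(item)

    def read_iterator(self):
        # block on the queue and yield items until the EOF sentinel arrives
        while True:
            item = self._queue.get()
            if item == STREAM_EOF:
                return
            yield item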

pynumaflow/mapper/async_server.py  (+1 -1)

@@ -16,7 +16,7 @@
     ContainerType,
 )
 from pynumaflow.mapper._dtypes import MapAsyncCallable
-from pynumaflow.mapper.servicer.async_servicer import AsyncMapServicer
+from pynumaflow.mapper._servicer._async_servicer import AsyncMapServicer
 from pynumaflow.proto.mapper import map_pb2_grpc
 from pynumaflow.shared.server import (
     NumaflowServer,

pynumaflow/mapper/multiproc_server.py  (+1 -1)

@@ -18,7 +18,7 @@
     ContainerType,
 )
 from pynumaflow.mapper._dtypes import MapSyncCallable
-from pynumaflow.mapper.servicer.sync_servicer import SyncMapServicer
+from pynumaflow.mapper._servicer._sync_servicer import SyncMapServicer
 from pynumaflow.shared.server import (
     NumaflowServer,
     start_multiproc_server,

pynumaflow/mapper/servicer/async_servicer.py  (-75)

This file was deleted.
