
Commit f2f7bf6

kohlisid and vigith authored
feat: update batchmap and mapstream to use Map proto (numaproj#200)
Signed-off-by: Sidhant Kohli <[email protected]>
Co-authored-by: Vigith Maurice <[email protected]>
1 parent 688132a commit f2f7bf6

22 files changed, +252 -803 lines
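Both servicers now speak the shared Map protocol: the client opens the stream with a start-of-transmission handshake, sends payload requests, and the server answers with results followed by an end-of-transmission (EOT) status. A minimal sketch of the response framing, using only the map_pb2 fields that appear in the diffs below (handshake.sot, status.eot, results, id); the helper names are illustrative, not part of the library:

from pynumaflow.proto.mapper import map_pb2

def ack_handshake() -> map_pb2.MapResponse:
    # first reply on the stream: acknowledge the client's start-of-transmission
    return map_pb2.MapResponse(handshake=map_pb2.Handshake(sot=True))

def end_of_transmission(msg_id: str = "") -> map_pb2.MapResponse:
    # EOT marks the end of the responses for a request (map stream) or a batch (batch map)
    return map_pb2.MapResponse(status=map_pb2.TransmissionStatus(eot=True), id=msg_id)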

pynumaflow/batchmapper/_dtypes.py (+5 -2)

@@ -4,7 +4,6 @@
 from datetime import datetime
 from typing import TypeVar, Callable, Union, Optional
 from collections.abc import AsyncIterable
-from collections.abc import Awaitable

 from pynumaflow._constants import DROP

@@ -222,5 +221,9 @@ async def handler(self, datums: AsyncIterable[Datum]) -> BatchResponses:
         pass


-BatchMapAsyncCallable = Callable[[AsyncIterable[Datum]], Awaitable[BatchResponses]]
+BatchMapAsyncCallable = Callable[[AsyncIterable[Datum]], BatchResponses]
 BatchMapCallable = Union[BatchMapper, BatchMapAsyncCallable]
+
+
+class BatchMapError(Exception):
+    """To Raise an error while executing a BatchMap call"""

pynumaflow/batchmapper/async_server.py (+2 -2)

@@ -18,7 +18,7 @@
     MINIMUM_NUMAFLOW_VERSION,
     ContainerType,
 )
-from pynumaflow.proto.batchmapper import batchmap_pb2_grpc
+from pynumaflow.proto.mapper import map_pb2_grpc
 from pynumaflow.shared.server import NumaflowServer, start_async_server


@@ -103,7 +103,7 @@ async def aexec(self):
         # Create a new async server instance and add the servicer to it
         server = grpc.aio.server(options=self._server_options)
         server.add_insecure_port(self.sock_path)
-        batchmap_pb2_grpc.add_BatchMapServicer_to_server(
+        map_pb2_grpc.add_MapServicer_to_server(
             self.servicer,
             server,
         )
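From the UDF author's side the entry point is unchanged; only the servicer registration moved to the Map service. A sketch of the usual startup, assuming the my_batch_handler sketch above and the existing BatchMapAsyncServer wrapper:

from pynumaflow.batchmapper import BatchMapAsyncServer

if __name__ == "__main__":
    # register the handler and serve over the batch map socket as before
    grpc_server = BatchMapAsyncServer(my_batch_handler)
    grpc_server.start()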

pynumaflow/batchmapper/servicer/async_servicer.py (+69 -88)

@@ -5,40 +5,19 @@
 from google.protobuf import empty_pb2 as _empty_pb2

 from pynumaflow.batchmapper import Datum
-from pynumaflow.batchmapper._dtypes import BatchMapCallable
-from pynumaflow.proto.batchmapper import batchmap_pb2, batchmap_pb2_grpc
+from pynumaflow.batchmapper._dtypes import BatchMapCallable, BatchMapError
+from pynumaflow.proto.mapper import map_pb2, map_pb2_grpc
 from pynumaflow.shared.asynciter import NonBlockingIterator
 from pynumaflow.shared.server import exit_on_error
 from pynumaflow.types import NumaflowServicerContext
 from pynumaflow._constants import _LOGGER, STREAM_EOF


-async def datum_generator(
-    request_iterator: AsyncIterable[batchmap_pb2.BatchMapRequest],
-) -> AsyncIterable[Datum]:
-    """
-    This function is used to create an async generator
-    from the gRPC request iterator.
-    It yields a Datum instance for each request received which is then
-    forwarded to the UDF.
-    """
-    async for d in request_iterator:
-        request = Datum(
-            keys=d.keys,
-            value=d.value,
-            event_time=d.event_time.ToDatetime(),
-            watermark=d.watermark.ToDatetime(),
-            headers=dict(d.headers),
-            id=d.id,
-        )
-        yield request
-
-
-class AsyncBatchMapServicer(batchmap_pb2_grpc.BatchMapServicer):
+class AsyncBatchMapServicer(map_pb2_grpc.MapServicer):
     """
     This class is used to create a new grpc Batch Map Servicer instance.
-    It implements the BatchMapServicer interface from the proto
-    batchmap_pb2_grpc.py file.
+    It implements the MapServicer interface from the proto
+    map_pb2_grpc.py file.
     Provides the functionality for the required rpc methods.
     """

@@ -49,41 +28,74 @@ def __init__(
         self.background_tasks = set()
         self.__batch_map_handler: BatchMapCallable = handler

-    async def BatchMapFn(
+    async def MapFn(
         self,
-        request_iterator: AsyncIterable[batchmap_pb2.BatchMapRequest],
+        request_iterator: AsyncIterable[map_pb2.MapRequest],
         context: NumaflowServicerContext,
-    ) -> batchmap_pb2.BatchMapResponse:
+    ) -> AsyncIterable[map_pb2.MapResponse]:
         """
-        Applies a batch map function to a BatchMapRequest stream in a batching mode.
-        The pascal case function name comes from the proto batchmap_pb2_grpc.py file.
+        Applies a batch map function to a MapRequest stream in a batching mode.
+        The pascal case function name comes from the proto map_pb2_grpc.py file.
        """
-        # Create an async iterator from the request iterator
-        datum_iterator = datum_generator(request_iterator=request_iterator)
-
         try:
-            # invoke the UDF call for batch map
-            responses, request_counter = await self.invoke_batch_map(datum_iterator)
-
-            # If the number of responses received does not align with the request batch size,
-            # we will not be able to process the data correctly.
-            # This should be marked as an error and raised to the user.
-            if len(responses) != request_counter:
-                err_msg = "batchMapFn: mismatch between length of batch requests and responses"
-                raise Exception(err_msg)
-
-            # iterate over the responses received and covert to the required proto format
-            for batch_response in responses:
-                single_req_resp = []
-                for msg in batch_response.messages:
-                    single_req_resp.append(
-                        batchmap_pb2.BatchMapResponse.Result(
-                            keys=msg.keys, value=msg.value, tags=msg.tags
-                        )
+            # The first message to be received should be a valid handshake
+            req = await request_iterator.__anext__()
+            # check if it is a valid handshake req
+            if not (req.handshake and req.handshake.sot):
+                raise BatchMapError("BatchMapFn: expected handshake as the first message")
+            yield map_pb2.MapResponse(handshake=map_pb2.Handshake(sot=True))
+
+            # cur_task is used to track the task (coroutine) processing
+            # the current batch of messages.
+            cur_task = None
+            # iterate of the incoming messages ot the sink
+            async for d in request_iterator:
+                # if we do not have any active task currently processing the batch
+                # we need to create one and call the User function for processing the same.
+                if cur_task is None:
+                    req_queue = NonBlockingIterator()
+                    cur_task = asyncio.create_task(
+                        self.__batch_map_handler(req_queue.read_iterator())
                     )
-
-                # send the response for a given ID back to the stream
-                yield batchmap_pb2.BatchMapResponse(id=batch_response.id, results=single_req_resp)
+                    self.background_tasks.add(cur_task)
+                    cur_task.add_done_callback(self.background_tasks.discard)
+                # when we have end of transmission message, we need to stop the processing the
+                # current batch and wait for the next batch of messages.
+                # We will also wait for the current task to finish processing the current batch.
+                # We mark the current task as None to indicate that we are
+                # ready to process the next batch.
+                if d.status and d.status.eot:
+                    await req_queue.put(STREAM_EOF)
+                    await cur_task
+                    ret = cur_task.result()
+
+                    # iterate over the responses received and covert to the required proto format
+                    for batch_response in ret:
+                        single_req_resp = []
+                        for msg in batch_response.messages:
+                            single_req_resp.append(
+                                map_pb2.MapResponse.Result(
+                                    keys=msg.keys, value=msg.value, tags=msg.tags
+                                )
+                            )
+                        # send the response for a given ID back to the stream
+                        yield map_pb2.MapResponse(id=batch_response.id, results=single_req_resp)
+
+                    # send EOT after each finishing Batch responses
+                    yield map_pb2.MapResponse(status=map_pb2.TransmissionStatus(eot=True))
+                    cur_task = None
+                    continue
+
+                # if we have a valid message, we will add it to the request queue for processing.
+                datum = Datum(
+                    keys=list(d.request.keys),
+                    value=d.request.value,
+                    event_time=d.request.event_time.ToDatetime(),
+                    watermark=d.request.watermark.ToDatetime(),
+                    headers=dict(d.request.headers),
+                    id=d.id,
+                )
+                await req_queue.put(datum)

         except BaseException as err:
             _LOGGER.critical("UDFError, re-raising the error", exc_info=True)
@@ -93,42 +105,11 @@ async def BatchMapFn(
             exit_on_error(context, repr(err))
             return

-    async def invoke_batch_map(self, datum_iterator: AsyncIterable[Datum]):
-        """
-        # iterate over the incoming requests, and keep sending to the user code
-        # once all messages have been sent, we wait for the responses
-        """
-        # create a message queue to send to the user code
-        niter = NonBlockingIterator()
-        riter = niter.read_iterator()
-        # create a task for invoking the UDF handler
-        task = asyncio.create_task(self.__batch_map_handler(riter))
-        # Save a reference to the result of this function, to avoid a
-        # task disappearing mid-execution.
-        self.background_tasks.add(task)
-        task.add_done_callback(lambda t: self.background_tasks.remove(t))
-
-        req_count = 0
-        # start streaming the messages to the UDF code, and increment the request counter
-        async for datum in datum_iterator:
-            await niter.put(datum)
-            req_count += 1
-
-        # once all messages have been exhausted, send an EOF to indicate end of messages
-        # to the UDF
-        await niter.put(STREAM_EOF)
-
-        # wait for all the responses
-        await task
-
-        # return the result from the UDF, along with the request_counter
-        return task.result(), req_count
-
     async def IsReady(
         self, request: _empty_pb2.Empty, context: NumaflowServicerContext
-    ) -> batchmap_pb2.ReadyResponse:
+    ) -> map_pb2.ReadyResponse:
         """
         IsReady is the heartbeat endpoint for gRPC.
         The pascal case function name comes from the proto batchmap_pb2_grpc.py file.
         """
-        return batchmap_pb2.ReadyResponse(ready=True)
+        return map_pb2.ReadyResponse(ready=True)
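The new MapFn batches over a single bidirectional stream: after the handshake, incoming datums are fed into a per-batch queue, the user handler is started as a task, and an EOT status from the client flushes the batch and resets the state. The sketch below isolates that loop with a toy queue standing in for NonBlockingIterator; names such as SimpleQueue and run_batches are hypothetical and only mirror the control flow above:

import asyncio

STREAM_EOF = object()  # sentinel; the real constant lives in pynumaflow._constants

class SimpleQueue:
    """Toy stand-in for NonBlockingIterator: an asyncio.Queue exposed as an async iterator."""

    def __init__(self):
        self._q = asyncio.Queue()

    async def put(self, item):
        await self._q.put(item)

    async def read_iterator(self):
        # yield items until the EOF sentinel is seen
        while True:
            item = await self._q.get()
            if item is STREAM_EOF:
                return
            yield item

async def run_batches(requests, handler):
    """requests yields (is_eot, datum) pairs; handler consumes one batch's datums
    as an async iterator and returns the responses for that batch."""
    cur_task, queue = None, None
    async for is_eot, datum in requests:
        if cur_task is None:
            # lazily start a handler task for the new batch
            queue = SimpleQueue()
            cur_task = asyncio.create_task(handler(queue.read_iterator()))
        if is_eot:
            await queue.put(STREAM_EOF)   # close the current batch
            for resp in await cur_task:   # wait for the handler's batch result
                yield resp
            cur_task = None               # ready for the next batch
            continue
        await queue.put(datum)

The real servicer additionally converts each response into map_pb2.MapResponse messages and emits its own EOT after every batch, as the diff above shows.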

pynumaflow/mapstreamer/_dtypes.py (+4)

@@ -201,3 +201,7 @@ async def handler(self, keys: list[str], datum: Datum) -> AsyncIterable[Message]:

 MapStreamAsyncCallable = Callable[[list[str], Datum], AsyncIterable[Message]]
 MapStreamCallable = Union[MapStreamer, MapStreamAsyncCallable]
+
+
+class MapStreamError(Exception):
+    """To Raise an error while executing a MapStream call"""

pynumaflow/mapstreamer/async_server.py (+2 -2)

@@ -9,7 +9,7 @@
     ContainerType,
 )
 from pynumaflow.mapstreamer.servicer.async_servicer import AsyncMapStreamServicer
-from pynumaflow.proto.mapstreamer import mapstream_pb2_grpc
+from pynumaflow.proto.mapper import map_pb2_grpc

 from pynumaflow._constants import (
     MAP_STREAM_SOCK_PATH,
@@ -122,7 +122,7 @@ async def aexec(self):
         # Create a new async server instance and add the servicer to it
         server = grpc.aio.server(options=self._server_options)
         server.add_insecure_port(self.sock_path)
-        mapstream_pb2_grpc.add_MapStreamServicer_to_server(
+        map_pb2_grpc.add_MapServicer_to_server(
             self.servicer,
             server,
         )
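As with batch map, only the servicer registration changes; starting the server is the same from the user's perspective. A sketch assuming the my_stream_handler example above and the existing MapStreamAsyncServer wrapper:

from pynumaflow.mapstreamer import MapStreamAsyncServer

if __name__ == "__main__":
    # register the handler and serve over the map stream socket as before
    grpc_server = MapStreamAsyncServer(my_stream_handler)
    grpc_server.start()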

pynumaflow/mapstreamer/servicer/async_servicer.py (+37 -31)

@@ -3,18 +3,18 @@
 from google.protobuf import empty_pb2 as _empty_pb2

 from pynumaflow.mapstreamer import Datum
-from pynumaflow.mapstreamer._dtypes import MapStreamCallable
-from pynumaflow.proto.mapstreamer import mapstream_pb2_grpc, mapstream_pb2
+from pynumaflow.mapstreamer._dtypes import MapStreamCallable, MapStreamError
+from pynumaflow.proto.mapper import map_pb2_grpc, map_pb2
 from pynumaflow.shared.server import exit_on_error
 from pynumaflow.types import NumaflowServicerContext
 from pynumaflow._constants import _LOGGER


-class AsyncMapStreamServicer(mapstream_pb2_grpc.MapStreamServicer):
+class AsyncMapStreamServicer(map_pb2_grpc.MapServicer):
     """
     This class is used to create a new grpc Map Stream Servicer instance.
     It implements the SyncMapServicer interface from the proto
-    mapstream_pb2_grpc.py file.
+    map_pb2_grpc.py file.
     Provides the functionality for the required rpc methods.
     """

@@ -24,52 +24,58 @@ def __init__(
     ):
         self.__map_stream_handler: MapStreamCallable = handler

-    async def MapStreamFn(
+    async def MapFn(
         self,
-        request: mapstream_pb2.MapStreamRequest,
+        request_iterator: AsyncIterable[map_pb2.MapRequest],
         context: NumaflowServicerContext,
-    ) -> AsyncIterable[mapstream_pb2.MapStreamResponse]:
+    ) -> AsyncIterable[map_pb2.MapResponse]:
         """
         Applies a map function to a datum stream in streaming mode.
-        The pascal case function name comes from the proto mapstream_pb2_grpc.py file.
+        The pascal case function name comes from the proto map_pb2_grpc.py file.
         """
-
         try:
-            async for res in self.__invoke_map_stream(
-                list(request.keys),
-                Datum(
-                    keys=list(request.keys),
-                    value=request.value,
-                    event_time=request.event_time.ToDatetime(),
-                    watermark=request.watermark.ToDatetime(),
-                    headers=dict(request.headers),
-                ),
-                context,
-            ):
-                yield mapstream_pb2.MapStreamResponse(result=res)
+            # The first message to be received should be a valid handshake
+            req = await request_iterator.__anext__()
+            # check if it is a valid handshake req
+            if not (req.handshake and req.handshake.sot):
+                raise MapStreamError("MapStreamFn: expected handshake as the first message")
+            yield map_pb2.MapResponse(handshake=map_pb2.Handshake(sot=True))
+
+            # read for each input request
+            async for req in request_iterator:
+                # yield messages as received from the UDF
+                async for res in self.__invoke_map_stream(
+                    list(req.request.keys),
+                    Datum(
+                        keys=list(req.request.keys),
+                        value=req.request.value,
+                        event_time=req.request.event_time.ToDatetime(),
+                        watermark=req.request.watermark.ToDatetime(),
+                        headers=dict(req.request.headers),
+                    ),
+                ):
+                    yield map_pb2.MapResponse(results=[res], id=req.id)
+                # send EOT to indicate end of transmission for a given message
+                yield map_pb2.MapResponse(status=map_pb2.TransmissionStatus(eot=True), id=req.id)
         except BaseException as err:
             _LOGGER.critical("UDFError, re-raising the error", exc_info=True)
             exit_on_error(context, repr(err))
             return

-    async def __invoke_map_stream(
-        self, keys: list[str], req: Datum, context: NumaflowServicerContext
-    ):
+    async def __invoke_map_stream(self, keys: list[str], req: Datum):
         try:
+            # Invoke the user handler for map stream
             async for msg in self.__map_stream_handler(keys, req):
-                yield mapstream_pb2.MapStreamResponse.Result(
-                    keys=msg.keys, value=msg.value, tags=msg.tags
-                )
+                yield map_pb2.MapResponse.Result(keys=msg.keys, value=msg.value, tags=msg.tags)
         except BaseException as err:
             _LOGGER.critical("UDFError, re-raising the error", exc_info=True)
-            exit_on_error(context, repr(err))
             raise err

     async def IsReady(
         self, request: _empty_pb2.Empty, context: NumaflowServicerContext
-    ) -> mapstream_pb2.ReadyResponse:
+    ) -> map_pb2.ReadyResponse:
         """
         IsReady is the heartbeat endpoint for gRPC.
-        The pascal case function name comes from the proto mapstream_pb2_grpc.py file.
+        The pascal case function name comes from the proto map_pb2_grpc.py file.
         """
-        return mapstream_pb2.ReadyResponse(ready=True)
+        return map_pb2.ReadyResponse(ready=True)
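Under the new framing, every result for a given request is tagged with that request's id and is followed by an EOT status carrying the same id, so the client can tell when one request's stream of outputs has ended. A small sketch of that framing (frame_responses is a hypothetical helper; the message construction mirrors the servicer code above):

from pynumaflow.proto.mapper import map_pb2

def frame_responses(req_id: str, results: list):
    # one MapResponse per result, tagged with the originating request id
    for res in results:
        yield map_pb2.MapResponse(results=[res], id=req_id)
    # then a transmission-status message marking the end of this request's output
    yield map_pb2.MapResponse(status=map_pb2.TransmissionStatus(eot=True), id=req_id)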

pynumaflow/proto/batchmapper/__init__.py (whitespace-only changes)
