5
5
from google .protobuf import empty_pb2 as _empty_pb2
6
6
7
7
from pynumaflow .batchmapper import Datum
8
- from pynumaflow .batchmapper ._dtypes import BatchMapCallable
9
- from pynumaflow .proto .batchmapper import batchmap_pb2 , batchmap_pb2_grpc
8
+ from pynumaflow .batchmapper ._dtypes import BatchMapCallable , BatchMapError
9
+ from pynumaflow .proto .mapper import map_pb2 , map_pb2_grpc
10
10
from pynumaflow .shared .asynciter import NonBlockingIterator
11
11
from pynumaflow .shared .server import exit_on_error
12
12
from pynumaflow .types import NumaflowServicerContext
13
13
from pynumaflow ._constants import _LOGGER , STREAM_EOF
14
14
15
15
16
- async def datum_generator (
17
- request_iterator : AsyncIterable [batchmap_pb2 .BatchMapRequest ],
18
- ) -> AsyncIterable [Datum ]:
19
- """
20
- This function is used to create an async generator
21
- from the gRPC request iterator.
22
- It yields a Datum instance for each request received which is then
23
- forwarded to the UDF.
24
- """
25
- async for d in request_iterator :
26
- request = Datum (
27
- keys = d .keys ,
28
- value = d .value ,
29
- event_time = d .event_time .ToDatetime (),
30
- watermark = d .watermark .ToDatetime (),
31
- headers = dict (d .headers ),
32
- id = d .id ,
33
- )
34
- yield request
35
-
36
-
37
- class AsyncBatchMapServicer (batchmap_pb2_grpc .BatchMapServicer ):
16
+ class AsyncBatchMapServicer (map_pb2_grpc .MapServicer ):
38
17
"""
39
18
This class is used to create a new grpc Batch Map Servicer instance.
40
- It implements the BatchMapServicer interface from the proto
41
- batchmap_pb2_grpc .py file.
19
+ It implements the MapServicer interface from the proto
20
+ map_pb2_grpc .py file.
42
21
Provides the functionality for the required rpc methods.
43
22
"""
44
23
@@ -49,41 +28,74 @@ def __init__(
49
28
self .background_tasks = set ()
50
29
self .__batch_map_handler : BatchMapCallable = handler
51
30
52
- async def BatchMapFn (
31
+ async def MapFn (
53
32
self ,
54
- request_iterator : AsyncIterable [batchmap_pb2 . BatchMapRequest ],
33
+ request_iterator : AsyncIterable [map_pb2 . MapRequest ],
55
34
context : NumaflowServicerContext ,
56
- ) -> batchmap_pb2 . BatchMapResponse :
35
+ ) -> AsyncIterable [ map_pb2 . MapResponse ] :
57
36
"""
58
- Applies a batch map function to a BatchMapRequest stream in a batching mode.
59
- The pascal case function name comes from the proto batchmap_pb2_grpc .py file.
37
+ Applies a batch map function to a MapRequest stream in a batching mode.
38
+ The pascal case function name comes from the proto map_pb2_grpc .py file.
60
39
"""
61
- # Create an async iterator from the request iterator
62
- datum_iterator = datum_generator (request_iterator = request_iterator )
63
-
64
40
try :
65
- # invoke the UDF call for batch map
66
- responses , request_counter = await self . invoke_batch_map ( datum_iterator )
67
-
68
- # If the number of responses received does not align with the request batch size,
69
- # we will not be able to process the data correctly.
70
- # This should be marked as an error and raised to the user.
71
- if len ( responses ) != request_counter :
72
- err_msg = "batchMapFn: mismatch between length of batch requests and responses"
73
- raise Exception ( err_msg )
74
-
75
- # iterate over the responses received and covert to the required proto format
76
- for batch_response in responses :
77
- single_req_resp = []
78
- for msg in batch_response . messages :
79
- single_req_resp . append (
80
- batchmap_pb2 . BatchMapResponse . Result (
81
- keys = msg . keys , value = msg . value , tags = msg . tags
82
- )
41
+ # The first message to be received should be a valid handshake
42
+ req = await request_iterator . __anext__ ( )
43
+ # check if it is a valid handshake req
44
+ if not ( req . handshake and req . handshake . sot ):
45
+ raise BatchMapError ( "BatchMapFn: expected handshake as the first message" )
46
+ yield map_pb2 . MapResponse ( handshake = map_pb2 . Handshake ( sot = True ))
47
+
48
+ # cur_task is used to track the task (coroutine) processing
49
+ # the current batch of messages.
50
+ cur_task = None
51
+ # iterate of the incoming messages ot the sink
52
+ async for d in request_iterator :
53
+ # if we do not have any active task currently processing the batch
54
+ # we need to create one and call the User function for processing the same.
55
+ if cur_task is None :
56
+ req_queue = NonBlockingIterator ()
57
+ cur_task = asyncio . create_task (
58
+ self . __batch_map_handler ( req_queue . read_iterator () )
83
59
)
84
-
85
- # send the response for a given ID back to the stream
86
- yield batchmap_pb2 .BatchMapResponse (id = batch_response .id , results = single_req_resp )
60
+ self .background_tasks .add (cur_task )
61
+ cur_task .add_done_callback (self .background_tasks .discard )
62
+ # when we have end of transmission message, we need to stop the processing the
63
+ # current batch and wait for the next batch of messages.
64
+ # We will also wait for the current task to finish processing the current batch.
65
+ # We mark the current task as None to indicate that we are
66
+ # ready to process the next batch.
67
+ if d .status and d .status .eot :
68
+ await req_queue .put (STREAM_EOF )
69
+ await cur_task
70
+ ret = cur_task .result ()
71
+
72
+ # iterate over the responses received and covert to the required proto format
73
+ for batch_response in ret :
74
+ single_req_resp = []
75
+ for msg in batch_response .messages :
76
+ single_req_resp .append (
77
+ map_pb2 .MapResponse .Result (
78
+ keys = msg .keys , value = msg .value , tags = msg .tags
79
+ )
80
+ )
81
+ # send the response for a given ID back to the stream
82
+ yield map_pb2 .MapResponse (id = batch_response .id , results = single_req_resp )
83
+
84
+ # send EOT after each finishing Batch responses
85
+ yield map_pb2 .MapResponse (status = map_pb2 .TransmissionStatus (eot = True ))
86
+ cur_task = None
87
+ continue
88
+
89
+ # if we have a valid message, we will add it to the request queue for processing.
90
+ datum = Datum (
91
+ keys = list (d .request .keys ),
92
+ value = d .request .value ,
93
+ event_time = d .request .event_time .ToDatetime (),
94
+ watermark = d .request .watermark .ToDatetime (),
95
+ headers = dict (d .request .headers ),
96
+ id = d .id ,
97
+ )
98
+ await req_queue .put (datum )
87
99
88
100
except BaseException as err :
89
101
_LOGGER .critical ("UDFError, re-raising the error" , exc_info = True )
@@ -93,42 +105,11 @@ async def BatchMapFn(
93
105
exit_on_error (context , repr (err ))
94
106
return
95
107
96
- async def invoke_batch_map (self , datum_iterator : AsyncIterable [Datum ]):
97
- """
98
- # iterate over the incoming requests, and keep sending to the user code
99
- # once all messages have been sent, we wait for the responses
100
- """
101
- # create a message queue to send to the user code
102
- niter = NonBlockingIterator ()
103
- riter = niter .read_iterator ()
104
- # create a task for invoking the UDF handler
105
- task = asyncio .create_task (self .__batch_map_handler (riter ))
106
- # Save a reference to the result of this function, to avoid a
107
- # task disappearing mid-execution.
108
- self .background_tasks .add (task )
109
- task .add_done_callback (lambda t : self .background_tasks .remove (t ))
110
-
111
- req_count = 0
112
- # start streaming the messages to the UDF code, and increment the request counter
113
- async for datum in datum_iterator :
114
- await niter .put (datum )
115
- req_count += 1
116
-
117
- # once all messages have been exhausted, send an EOF to indicate end of messages
118
- # to the UDF
119
- await niter .put (STREAM_EOF )
120
-
121
- # wait for all the responses
122
- await task
123
-
124
- # return the result from the UDF, along with the request_counter
125
- return task .result (), req_count
126
-
127
108
async def IsReady (
128
109
self , request : _empty_pb2 .Empty , context : NumaflowServicerContext
129
- ) -> batchmap_pb2 .ReadyResponse :
110
+ ) -> map_pb2 .ReadyResponse :
130
111
"""
131
112
IsReady is the heartbeat endpoint for gRPC.
132
113
The pascal case function name comes from the proto batchmap_pb2_grpc.py file.
133
114
"""
134
- return batchmap_pb2 .ReadyResponse (ready = True )
115
+ return map_pb2 .ReadyResponse (ready = True )
0 commit comments