|
| 1 | +import threading |
| 2 | +from concurrent.futures import ThreadPoolExecutor |
| 3 | +from collections.abc import Iterable |
| 4 | + |
| 5 | +from google.protobuf import empty_pb2 as _empty_pb2 |
| 6 | +from pynumaflow.shared.server import exit_on_error |
| 7 | + |
| 8 | +from pynumaflow._constants import NUM_THREADS_DEFAULT, STREAM_EOF, _LOGGER |
| 9 | +from pynumaflow.mapper._dtypes import MapSyncCallable, Datum, MapError |
| 10 | +from pynumaflow.proto.mapper import map_pb2, map_pb2_grpc |
| 11 | +from pynumaflow.shared.synciter import SyncIterator |
| 12 | +from pynumaflow.types import NumaflowServicerContext |
| 13 | + |
| 14 | + |
| 15 | +class SyncMapServicer(map_pb2_grpc.MapServicer): |
| 16 | + """ |
| 17 | + This class is used to create a new grpc Map Servicer instance. |
| 18 | + It implements the SyncMapServicer interface from the proto map.proto file. |
| 19 | + Provides the functionality for the required rpc methods. |
| 20 | + """ |
| 21 | + |
| 22 | + def __init__(self, handler: MapSyncCallable, multiproc: bool = False): |
| 23 | + self.__map_handler: MapSyncCallable = handler |
| 24 | + # This indicates whether the grpc server attached is multiproc or not |
| 25 | + self.multiproc = multiproc |
| 26 | + # create a thread pool for executing UDF code |
| 27 | + self.executor = ThreadPoolExecutor(max_workers=NUM_THREADS_DEFAULT) |
| 28 | + |
| 29 | + def MapFn( |
| 30 | + self, |
| 31 | + request_iterator: Iterable[map_pb2.MapRequest], |
| 32 | + context: NumaflowServicerContext, |
| 33 | + ) -> Iterable[map_pb2.MapResponse]: |
| 34 | + """ |
| 35 | + Applies a function to each datum element. |
| 36 | + The pascal case function name comes from the proto map_pb2_grpc.py file. |
| 37 | + """ |
| 38 | + try: |
| 39 | + # The first message to be received should be a valid handshake |
| 40 | + req = next(request_iterator) |
| 41 | + # check if it is a valid handshake req |
| 42 | + if not (req.handshake and req.handshake.sot): |
| 43 | + raise MapError("MapFn: expected handshake as the first message") |
| 44 | + yield map_pb2.MapResponse(handshake=map_pb2.Handshake(sot=True)) |
| 45 | + |
| 46 | + # result queue to stream messages from the user code back to the client |
| 47 | + result_queue = SyncIterator() |
| 48 | + |
| 49 | + # Reader thread to keep reading from the request iterator and schedule |
| 50 | + # execution for each of them |
| 51 | + reader_thread = threading.Thread( |
| 52 | + target=self._process_requests, args=(context, request_iterator, result_queue) |
| 53 | + ) |
| 54 | + reader_thread.start() |
| 55 | + # Read the result queue and keep forwarding them upstream |
| 56 | + for res in result_queue.read_iterator(): |
| 57 | + # if error handler accordingly |
| 58 | + if isinstance(res, BaseException): |
| 59 | + # Terminate the current server process due to exception |
| 60 | + exit_on_error(context, repr(res), parent=self.multiproc) |
| 61 | + return |
| 62 | + # return the result |
| 63 | + yield res |
| 64 | + |
| 65 | + # wait for the threads to clean-up |
| 66 | + reader_thread.join() |
| 67 | + self.executor.shutdown(cancel_futures=True) |
| 68 | + |
| 69 | + except BaseException as err: |
| 70 | + _LOGGER.critical("UDFError, re-raising the error", exc_info=True) |
| 71 | + # Terminate the current server process due to exception |
| 72 | + exit_on_error(context, repr(err), parent=self.multiproc) |
| 73 | + return |
| 74 | + |
| 75 | + def _process_requests( |
| 76 | + self, |
| 77 | + context: NumaflowServicerContext, |
| 78 | + request_iterator: Iterable[map_pb2.MapRequest], |
| 79 | + result_queue: SyncIterator, |
| 80 | + ): |
| 81 | + try: |
| 82 | + # read through all incoming requests and submit to the |
| 83 | + # threadpool for invocation |
| 84 | + for request in request_iterator: |
| 85 | + _ = self.executor.submit(self._invoke_map, context, request, result_queue) |
| 86 | + # wait for all tasks to finish after all requests exhausted |
| 87 | + self.executor.shutdown(wait=True) |
| 88 | + # Indicate to the result queue that no more messages left to process |
| 89 | + result_queue.put(STREAM_EOF) |
| 90 | + except BaseException as e: |
| 91 | + _LOGGER.critical("MapFn Error, re-raising the error", exc_info=True) |
| 92 | + result_queue.put(e) |
| 93 | + |
| 94 | + def _invoke_map( |
| 95 | + self, |
| 96 | + context: NumaflowServicerContext, |
| 97 | + request: map_pb2.MapRequest, |
| 98 | + result_queue: SyncIterator, |
| 99 | + ): |
| 100 | + try: |
| 101 | + d = Datum( |
| 102 | + keys=list(request.request.keys), |
| 103 | + value=request.request.value, |
| 104 | + event_time=request.request.event_time.ToDatetime(), |
| 105 | + watermark=request.request.watermark.ToDatetime(), |
| 106 | + headers=dict(request.request.headers), |
| 107 | + ) |
| 108 | + |
| 109 | + responses = self.__map_handler(list(request.request.keys), d) |
| 110 | + results = [] |
| 111 | + for resp in responses: |
| 112 | + results.append( |
| 113 | + map_pb2.MapResponse.Result( |
| 114 | + keys=list(resp.keys), |
| 115 | + value=resp.value, |
| 116 | + tags=resp.tags, |
| 117 | + ) |
| 118 | + ) |
| 119 | + result_queue.put(map_pb2.MapResponse(results=results, id=request.id)) |
| 120 | + |
| 121 | + except BaseException as e: |
| 122 | + _LOGGER.critical("MapFn handler error", exc_info=True) |
| 123 | + result_queue.put(e) |
| 124 | + return |
| 125 | + |
| 126 | + def IsReady( |
| 127 | + self, request: _empty_pb2.Empty, context: NumaflowServicerContext |
| 128 | + ) -> map_pb2.ReadyResponse: |
| 129 | + """ |
| 130 | + IsReady is the heartbeat endpoint for gRPC. |
| 131 | + The pascal case function name comes from the proto map_pb2_grpc.py file. |
| 132 | + """ |
| 133 | + return map_pb2.ReadyResponse(ready=True) |
0 commit comments