1010
1111from __future__ import annotations
1212
13+ import asyncio
1314import os
15+ import signal
1416import sys
1517import traceback
1618from argparse import ArgumentParser
1719from concurrent import futures
1820from dataclasses import dataclass
1921from typing import (
2022 Any ,
23+ AsyncIterator ,
2124 Iterable ,
22- Iterator ,
2325)
2426
25- import grpc
26- from grpc import ServicerContext , StatusCode
27+ from grpc import StatusCode , aio , local_server_credentials
28+
29+ from isolate .connections .grpc .definitions import PartialRunResult
2730
2831try :
2932 from isolate import __version__ as agent_version
@@ -48,12 +51,19 @@ def __init__(self, log_fd: int | None = None):
4851
4952 self ._run_cache : dict [str , Any ] = {}
5053 self ._log = sys .stdout if log_fd is None else os .fdopen (log_fd , "w" )
54+ self ._thread_pool = futures .ThreadPoolExecutor (max_workers = 1 )
55+
56+ def handle_termination (* args ):
57+ self .log ("Termination signal received, shutting down..." )
58+ signal .raise_signal (signal .SIGTERM )
5159
52- def Run (
60+ signal .signal (signal .SIGINT , handle_termination )
61+
62+ async def Run (
5363 self ,
5464 request : definitions .FunctionCall ,
55- context : ServicerContext ,
56- ) -> Iterator [ definitions . PartialRunResult ]:
65+ context : aio . ServicerContext ,
66+ ) -> AsyncIterator [ PartialRunResult ]:
5767 self .log (f"A connection has been established: { context .peer ()} !" )
5868 server_version = os .getenv ("ISOLATE_SERVER_VERSION" ) or "unknown"
5969 self .log (f"Isolate info: server { server_version } , agent { agent_version } " )
@@ -70,7 +80,7 @@ def Run(
7080 result ,
7181 was_it_raised ,
7282 stringized_tb ,
73- ) = self .execute_function (
83+ ) = await self .execute_function (
7484 request .setup_func ,
7585 "setup" ,
7686 )
@@ -87,15 +97,16 @@ def Run(
8797 )
8898 raise AbortException ("The setup function has thrown an error." )
8999 except AbortException as exc :
90- return self .abort_with_msg (context , exc .message )
100+ self .abort_with_msg (context , exc .message )
101+ return
91102 else :
92103 assert not was_it_raised
93104 self ._run_cache [cache_key ] = result
94105
95106 extra_args .append (self ._run_cache [cache_key ])
96107
97108 try :
98- result , was_it_raised , stringized_tb = self .execute_function (
109+ result , was_it_raised , stringized_tb = await self .execute_function (
99110 request .function ,
100111 "function" ,
101112 extra_args = extra_args ,
@@ -107,9 +118,10 @@ def Run(
107118 stringized_tb ,
108119 )
109120 except AbortException as exc :
110- return self .abort_with_msg (context , exc .message )
121+ self .abort_with_msg (context , exc .message )
122+ return
111123
112- def execute_function (
124+ async def execute_function (
113125 self ,
114126 function : definitions .SerializedObject ,
115127 function_kind : str ,
@@ -143,14 +155,34 @@ def execute_function(
143155 was_it_raised = False
144156 stringized_tb = None
145157 try :
146- result = function (* extra_args )
158+ # Newer fal SDK will mark async entrypoints with `_run_on_main_thread` so
159+ # we execute on the main loop and can await the coroutine they return.
160+ # Older fal SDK still call `asyncio.run(...)`.
161+ # To avoid error "asyncio.run() cannot be called from a running event loop"
162+ # and be backward compatible,
163+ # we offload those unflagged functions to a thread pool.
164+
165+ if getattr (function , "_run_on_main_thread" , False ):
166+ result = function (* extra_args )
167+ else :
168+ result = self ._thread_pool .submit (function , * extra_args ).result ()
169+
170+ if asyncio .iscoroutine (result ):
171+ result = await result
172+
147173 except BaseException as exc :
148174 result = exc
149175 was_it_raised = True
150176 num_frames = len (traceback .extract_stack ()[:- 5 ])
151177 stringized_tb = "" .join (traceback .format_exc (limit = - num_frames ))
152178
153- self .log (f"Completed the execution of the { function_kind } function." )
179+ if not was_it_raised :
180+ self .log (f"Completed the execution of the { function_kind } function." )
181+ else :
182+ self .log (
183+ f"Completed the execution of the { function_kind } function"
184+ f" with an error: { result } \n Traceback:\n { stringized_tb } "
185+ )
154186 return result , was_it_raised , stringized_tb
155187
156188 def send_object (
@@ -195,7 +227,7 @@ def log(self, message: str) -> None:
195227
196228 def abort_with_msg (
197229 self ,
198- context : ServicerContext ,
230+ context : aio . ServicerContext ,
199231 message : str ,
200232 * ,
201233 code : StatusCode = StatusCode .INVALID_ARGUMENT ,
@@ -205,23 +237,26 @@ def abort_with_msg(
205237 return None
206238
207239
208- def create_server (address : str ) -> grpc .Server :
240+ def create_server (address : str ) -> aio .Server :
209241 """Create a new (temporary) gRPC server listening on the given
210242 address."""
211- server = grpc .server (
212- futures .ThreadPoolExecutor (max_workers = 1 ),
213- maximum_concurrent_rpcs = 1 ,
243+ # Use asyncio server so requests can run in the main thread and intercept signals
244+ # There seems to be a weird bug with grpcio that makes subsequent requests fail with
245+ # concurrent rpc limit exceeded if we set maximum_current_rpcs to 1. Setting it to 2
246+ # fixes it, even though in practice, we only run one request at a time.
247+ server = aio .server (
248+ maximum_concurrent_rpcs = 2 ,
214249 options = get_default_options (),
215250 )
216251
217252 # Local server credentials allow us to ensure that the
218253 # connection is established by a local process.
219- server_credentials = grpc . local_server_credentials ()
254+ server_credentials = local_server_credentials ()
220255 server .add_secure_port (address , server_credentials )
221256 return server
222257
223258
224- def run_agent (address : str , log_fd : int | None = None ) -> int :
259+ async def run_agent (address : str , log_fd : int | None = None ) -> int :
225260 """Run the agent servicer on the given address."""
226261 server = create_server (address )
227262 servicer = AgentServicer (log_fd = log_fd )
@@ -231,19 +266,19 @@ def run_agent(address: str, log_fd: int | None = None) -> int:
231266 # not have any global side effects.
232267 definitions .register_agent (servicer , server )
233268
234- server .start ()
235- server .wait_for_termination ()
269+ await server .start ()
270+ await server .wait_for_termination ()
236271 return 0
237272
238273
239- def main () -> int :
274+ async def main () -> int :
240275 parser = ArgumentParser ()
241276 parser .add_argument ("address" , type = str )
242277 parser .add_argument ("--log-fd" , type = int )
243278
244279 options = parser .parse_args ()
245- return run_agent (options .address , log_fd = options .log_fd )
280+ return await run_agent (options .address , log_fd = options .log_fd )
246281
247282
248283if __name__ == "__main__" :
249- main ()
284+ asyncio . run ( main () )
0 commit comments