@@ -597,7 +597,6 @@ def servicer(self) -> IsolateServicer:
597597 return self ._servicer
598598
599599
600- @dataclass
601600class SingleTaskInterceptor (ServerBoundInterceptor ):
602601 """Sets server to terminate after the first Submit/Run task."""
603602
@@ -686,6 +685,25 @@ def _stop(*args):
686685 return wrap_server_method_handler (wrapper , handler )
687686
688687
688+ class ControllerAuthInterceptor (ServerBoundInterceptor ):
689+ def __init__ (self , controller_auth_key : str ) -> None :
690+ super ().__init__ ()
691+ self .controller_auth_key = controller_auth_key
692+ self ._terminator = grpc .unary_unary_rpc_method_handler (
693+ lambda request , context : context .abort (
694+ grpc .StatusCode .UNAUTHENTICATED , "Unauthorized"
695+ )
696+ )
697+
698+ def intercept_service (self , continuation , handler_call_details ):
699+ metadata = dict (handler_call_details .invocation_metadata )
700+ controller_token = metadata .get ("controller-token" )
701+ if controller_token != self .controller_auth_key :
702+ return self ._terminator
703+
704+ return continuation (handler_call_details )
705+
706+
689707def main (argv : list [str ] | None = None ) -> None :
690708 parser = ArgumentParser ()
691709 parser .add_argument ("--host" , default = "0.0.0.0" )
@@ -710,6 +728,15 @@ def main(argv: list[str] | None = None) -> None:
710728 if options .single_use :
711729 interceptors .append (SingleTaskInterceptor ())
712730
731+ if controller_auth_key := os .getenv ("ISOLATE_CONTROLLER_AUTH_KEY" ):
732+ # Set an interceptor to only accept requests with the correct auth key
733+ interceptors .append (ControllerAuthInterceptor (controller_auth_key ))
734+ else :
735+ print (
736+ "[WARN] ISOLATE_CONTROLLER_AUTH_KEY is not set, all requests will be "
737+ "accepted without authentication."
738+ )
739+
713740 server = grpc .server (
714741 futures .ThreadPoolExecutor (max_workers = options .num_workers ),
715742 options = get_default_options (),
0 commit comments