Skip to content

Commit e3b82e4

Browse files
authored
feat: if controller auth key is set, use it for auth of all requests (#201)
* feat: if controller auth key is set, use it for auth of all requests * remove print * type * warn
1 parent cea7892 commit e3b82e4

File tree

2 files changed

+69
-1
lines changed

2 files changed

+69
-1
lines changed

src/isolate/server/server.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -597,7 +597,6 @@ def servicer(self) -> IsolateServicer:
597597
return self._servicer
598598

599599

600-
@dataclass
601600
class SingleTaskInterceptor(ServerBoundInterceptor):
602601
"""Sets server to terminate after the first Submit/Run task."""
603602

@@ -686,6 +685,25 @@ def _stop(*args):
686685
return wrap_server_method_handler(wrapper, handler)
687686

688687

688+
class ControllerAuthInterceptor(ServerBoundInterceptor):
689+
def __init__(self, controller_auth_key: str) -> None:
690+
super().__init__()
691+
self.controller_auth_key = controller_auth_key
692+
self._terminator = grpc.unary_unary_rpc_method_handler(
693+
lambda request, context: context.abort(
694+
grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"
695+
)
696+
)
697+
698+
def intercept_service(self, continuation, handler_call_details):
699+
metadata = dict(handler_call_details.invocation_metadata)
700+
controller_token = metadata.get("controller-token")
701+
if controller_token != self.controller_auth_key:
702+
return self._terminator
703+
704+
return continuation(handler_call_details)
705+
706+
689707
def main(argv: list[str] | None = None) -> None:
690708
parser = ArgumentParser()
691709
parser.add_argument("--host", default="0.0.0.0")
@@ -710,6 +728,15 @@ def main(argv: list[str] | None = None) -> None:
710728
if options.single_use:
711729
interceptors.append(SingleTaskInterceptor())
712730

731+
if controller_auth_key := os.getenv("ISOLATE_CONTROLLER_AUTH_KEY"):
732+
# Set an interceptor to only accept requests with the correct auth key
733+
interceptors.append(ControllerAuthInterceptor(controller_auth_key))
734+
else:
735+
print(
736+
"[WARN] ISOLATE_CONTROLLER_AUTH_KEY is not set, all requests will be "
737+
"accepted without authentication."
738+
)
739+
713740
server = grpc.server(
714741
futures.ThreadPoolExecutor(max_workers=options.num_workers),
715742
options=get_default_options(),

tests/test_server.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from isolate.server.interface import from_grpc, to_serialized_object
2020
from isolate.server.server import (
2121
BridgeManager,
22+
ControllerAuthInterceptor,
2223
IsolateServicer,
2324
ServerBoundInterceptor,
2425
SingleTaskInterceptor,
@@ -602,6 +603,46 @@ def test_health_check(health_stub: health.HealthStub) -> None:
602603
assert resp.status == health.HealthCheckResponse.SERVING
603604

604605

606+
@pytest.mark.parametrize(
607+
"interceptors",
608+
[[ControllerAuthInterceptor(controller_auth_key="test-secret")]],
609+
)
610+
def test_controller_auth_rejects_without_token(
611+
stub: definitions.IsolateStub,
612+
) -> None:
613+
with pytest.raises(grpc.RpcError) as exc_info:
614+
stub.List(definitions.ListRequest())
615+
assert exc_info.value.code() == grpc.StatusCode.UNAUTHENTICATED
616+
617+
618+
@pytest.mark.parametrize(
619+
"interceptors",
620+
[[ControllerAuthInterceptor(controller_auth_key="test-secret")]],
621+
)
622+
def test_controller_auth_rejects_wrong_token(
623+
stub: definitions.IsolateStub,
624+
) -> None:
625+
with pytest.raises(grpc.RpcError) as exc_info:
626+
stub.List(
627+
definitions.ListRequest(),
628+
metadata=[("controller-token", "wrong")],
629+
)
630+
assert exc_info.value.code() == grpc.StatusCode.UNAUTHENTICATED
631+
632+
633+
@pytest.mark.parametrize(
634+
"interceptors",
635+
[[ControllerAuthInterceptor(controller_auth_key="test-secret")]],
636+
)
637+
def test_controller_auth_accepts_correct_token(
638+
stub: definitions.IsolateStub,
639+
) -> None:
640+
stub.List(
641+
definitions.ListRequest(),
642+
metadata=[("controller-token", "test-secret")],
643+
)
644+
645+
605646
def check_machine():
606647
import os
607648

0 commit comments

Comments
 (0)