Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion src/isolate/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,7 +597,6 @@ def servicer(self) -> IsolateServicer:
return self._servicer


@dataclass
class SingleTaskInterceptor(ServerBoundInterceptor):
"""Sets server to terminate after the first Submit/Run task."""

Expand Down Expand Up @@ -686,6 +685,25 @@ def _stop(*args):
return wrap_server_method_handler(wrapper, handler)


class ControllerAuthInterceptor(ServerBoundInterceptor):
def __init__(self, controller_auth_key: str) -> None:
super().__init__()
self.controller_auth_key = controller_auth_key
self._terminator = grpc.unary_unary_rpc_method_handler(
lambda request, context: context.abort(
grpc.StatusCode.UNAUTHENTICATED, "Unauthorized"
)
)

def intercept_service(self, continuation, handler_call_details):
metadata = dict(handler_call_details.invocation_metadata)
controller_token = metadata.get("controller-token")
if controller_token != self.controller_auth_key:
return self._terminator

return continuation(handler_call_details)


def main(argv: list[str] | None = None) -> None:
parser = ArgumentParser()
parser.add_argument("--host", default="0.0.0.0")
Expand All @@ -710,6 +728,15 @@ def main(argv: list[str] | None = None) -> None:
if options.single_use:
interceptors.append(SingleTaskInterceptor())

if controller_auth_key := os.getenv("ISOLATE_CONTROLLER_AUTH_KEY"):
# Set an interceptor to only accept requests with the correct auth key
interceptors.append(ControllerAuthInterceptor(controller_auth_key))
else:
print(
"[WARN] ISOLATE_CONTROLLER_AUTH_KEY is not set, all requests will be "
"accepted without authentication."
)

server = grpc.server(
futures.ThreadPoolExecutor(max_workers=options.num_workers),
options=get_default_options(),
Expand Down
41 changes: 41 additions & 0 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from isolate.server.interface import from_grpc, to_serialized_object
from isolate.server.server import (
BridgeManager,
ControllerAuthInterceptor,
IsolateServicer,
ServerBoundInterceptor,
SingleTaskInterceptor,
Expand Down Expand Up @@ -602,6 +603,46 @@ def test_health_check(health_stub: health.HealthStub) -> None:
assert resp.status == health.HealthCheckResponse.SERVING


@pytest.mark.parametrize(
"interceptors",
[[ControllerAuthInterceptor(controller_auth_key="test-secret")]],
)
def test_controller_auth_rejects_without_token(
stub: definitions.IsolateStub,
) -> None:
with pytest.raises(grpc.RpcError) as exc_info:
stub.List(definitions.ListRequest())
assert exc_info.value.code() == grpc.StatusCode.UNAUTHENTICATED


@pytest.mark.parametrize(
"interceptors",
[[ControllerAuthInterceptor(controller_auth_key="test-secret")]],
)
def test_controller_auth_rejects_wrong_token(
stub: definitions.IsolateStub,
) -> None:
with pytest.raises(grpc.RpcError) as exc_info:
stub.List(
definitions.ListRequest(),
metadata=[("controller-token", "wrong")],
)
assert exc_info.value.code() == grpc.StatusCode.UNAUTHENTICATED


@pytest.mark.parametrize(
"interceptors",
[[ControllerAuthInterceptor(controller_auth_key="test-secret")]],
)
def test_controller_auth_accepts_correct_token(
stub: definitions.IsolateStub,
) -> None:
stub.List(
definitions.ListRequest(),
metadata=[("controller-token", "test-secret")],
)


def check_machine():
import os

Expand Down