Skip to content

Commit c16d601

Browse files
authored
[iris] Add optional auth mode for gradual adoption (#3937)
Add optional field to AuthConfig proto. When set, authentication is attempted but not required -- requests with valid tokens get the authenticated identity while unauthenticated requests fall through as anonymous/admin. This enables gradual rollout of JWT auth on clusters that currently rely on SSH tunnels. The dashboard /auth/config endpoint reports the optional flag so the frontend can adapt its login UI accordingly. Fixes #3936
1 parent ab588aa commit c16d601

15 files changed

Lines changed: 269 additions & 35 deletions

File tree

lib/fray/src/fray/v2/iris_backend.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -508,6 +508,7 @@ def submit(self, request: JobRequest, adopt_existing: bool = True) -> IrisJobHan
508508
replicas = request.replicas or 1
509509
coscheduling = resolve_coscheduling(request.resources.device, replicas)
510510

511+
policy = cluster_pb2.EXISTING_JOB_POLICY_KEEP if adopt_existing else cluster_pb2.EXISTING_JOB_POLICY_UNSPECIFIED
511512
try:
512513
job = self._iris.submit(
513514
entrypoint=iris_entrypoint,
@@ -519,12 +520,10 @@ def submit(self, request: JobRequest, adopt_existing: bool = True) -> IrisJobHan
519520
replicas=replicas,
520521
max_retries_failure=request.max_retries_failure,
521522
max_retries_preemption=request.max_retries_preemption,
523+
existing_job_policy=policy,
522524
)
523525
except IrisJobAlreadyExists as e:
524-
if adopt_existing:
525-
logger.info("Job %s already exists, adopting existing job", request.name)
526-
return IrisJobHandle(e.job)
527-
raise FrayJobAlreadyExists(request.name, handle=IrisJobHandle(e.job)) from e
526+
raise FrayJobAlreadyExists(request.name) from e
528527
return IrisJobHandle(job)
529528

530529
def host_actor(

lib/iris/dashboard/src/App.vue

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,19 +75,21 @@ onMounted(async () => {
7575
window.addEventListener('iris-auth-required', onAuthRequired)
7676
7777
let hasSession = false
78+
let authOptional = false
7879
try {
7980
const resp = await fetch('/auth/config')
8081
if (resp.ok) {
8182
const config = await resp.json()
8283
authEnabled.value = config.auth_enabled ?? false
8384
hasSession = config.has_session ?? false
85+
authOptional = config.optional ?? false
8486
providerKind.value = config.provider_kind === 'kubernetes' ? 'kubernetes' : 'worker'
8587
}
8688
} catch {
8789
// Auth config endpoint unavailable — assume no auth
8890
}
8991
90-
if (authEnabled.value && !hasSession && route.path !== '/login') {
92+
if (authEnabled.value && !authOptional && !hasSession && route.path !== '/login') {
9193
router.push('/login')
9294
}
9395
})

lib/iris/examples/local-auth-gcp.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ auth:
1616
project_id: hai-gcp-models # GCP project ID — users must have access to log in
1717
admin_users:
1818
- russell.power@gmail.com # Replace with actual admin email
19+
optional: true
1920

2021
scale_groups:
2122
cpu:

lib/iris/src/iris/client/client.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -132,8 +132,7 @@ def __init__(self, job_id: JobName, status: cluster_pb2.JobStatus):
132132
class JobAlreadyExists(Exception):
133133
"""Raised when a job with the same name is already running."""
134134

135-
def __init__(self, job: "Job", message: str):
136-
self.job = job
135+
def __init__(self, message: str):
137136
super().__init__(message)
138137

139138

@@ -704,7 +703,7 @@ def submit(
704703
)
705704

706705
try:
707-
self._cluster_client.submit_job(
706+
canonical_id = self._cluster_client.submit_job(
708707
job_id=job_id,
709708
entrypoint=entrypoint,
710709
resources=resources_proto,
@@ -723,10 +722,10 @@ def submit(
723722
)
724723
except ConnectError as e:
725724
if e.code == Code.ALREADY_EXISTS:
726-
raise JobAlreadyExists(Job(self, job_id), str(e)) from e
725+
raise JobAlreadyExists(str(e)) from e
727726
raise
728727

729-
return Job(self, job_id)
728+
return Job(self, canonical_id)
730729

731730
def status(self, job_id: JobName) -> cluster_pb2.JobStatus:
732731
"""Get job status.

lib/iris/src/iris/cluster/client/protocol.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def submit_job(
5555
reservation: cluster_pb2.ReservationConfig | None = None,
5656
preemption_policy: cluster_pb2.JobPreemptionPolicy = cluster_pb2.JOB_PREEMPTION_POLICY_UNSPECIFIED,
5757
existing_job_policy: cluster_pb2.ExistingJobPolicy = cluster_pb2.EXISTING_JOB_POLICY_UNSPECIFIED,
58-
) -> None: ...
58+
) -> JobName: ...
5959

6060
def get_job_status(self, job_id: JobName) -> cluster_pb2.JobStatus: ...
6161

lib/iris/src/iris/cluster/client/remote_client.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def submit_job(
7373
reservation: cluster_pb2.ReservationConfig | None = None,
7474
preemption_policy: cluster_pb2.JobPreemptionPolicy = cluster_pb2.JOB_PREEMPTION_POLICY_UNSPECIFIED,
7575
existing_job_policy: cluster_pb2.ExistingJobPolicy = cluster_pb2.EXISTING_JOB_POLICY_UNSPECIFIED,
76-
) -> None:
76+
) -> JobName:
7777
if replicas < 1:
7878
raise ValueError(f"replicas must be >= 1, got {replicas}")
7979
replicas = adjust_tpu_replicas(resources.device if resources.HasField("device") else None, replicas)
@@ -112,9 +112,10 @@ def submit_job(
112112
request.reservation.CopyFrom(reservation)
113113

114114
def _call():
115-
self._client.launch_job(request)
115+
return self._client.launch_job(request)
116116

117-
call_with_retry(f"launch_job({job_id})", _call)
117+
response = call_with_retry(f"launch_job({job_id})", _call)
118+
return JobName.from_wire(response.job_id)
118119

119120
def get_job_status(self, job_id: JobName) -> cluster_pb2.JobStatus:
120121
def _call():

lib/iris/src/iris/cluster/controller/auth.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,7 @@ class ControllerAuth:
259259
login_verifier: TokenVerifier | None = None
260260
gcp_project_id: str | None = None
261261
jwt_manager: JwtTokenManager | None = None
262+
optional: bool = False
262263

263264

264265
def create_controller_auth(
@@ -328,14 +329,22 @@ def create_controller_auth(
328329
static_tokens = dict(auth_config.static.tokens)
329330
login_verifier = StaticTokenVerifier(static_tokens)
330331

331-
logger.info("Auth enabled: provider=%s, db=%s, jwt=%s", provider, "yes" if db else "no", "yes" if jwt_mgr else "no")
332+
optional = auth_config.optional
333+
logger.info(
334+
"Auth enabled: provider=%s, db=%s, jwt=%s, optional=%s",
335+
provider,
336+
"yes" if db else "no",
337+
"yes" if jwt_mgr else "no",
338+
optional,
339+
)
332340
return ControllerAuth(
333341
verifier=verifier,
334342
provider=provider,
335343
worker_token=worker_token,
336344
login_verifier=login_verifier,
337345
gcp_project_id=gcp_project_id,
338346
jwt_manager=jwt_mgr,
347+
optional=optional,
339348
)
340349

341350

lib/iris/src/iris/cluster/controller/controller.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -760,6 +760,7 @@ def __init__(
760760
port=config.port,
761761
auth_verifier=config.auth_verifier,
762762
auth_provider=config.auth_provider,
763+
auth_optional=config.auth.optional if config.auth else False,
763764
)
764765

765766
# Ingest process logs into the LogStore so they are available via FetchLogs.

lib/iris/src/iris/cluster/controller/dashboard.py

Lines changed: 56 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939

4040
from iris.cluster.controller.service import ControllerServiceImpl
4141
from iris.cluster.dashboard_common import html_shell, static_files_mount
42-
from iris.rpc.auth import SESSION_COOKIE, AuthInterceptor, NullAuthInterceptor, TokenVerifier
42+
from iris.rpc.auth import SESSION_COOKIE, NullAuthInterceptor, TokenVerifier, extract_bearer_token, resolve_auth
4343
from iris.rpc.cluster_connect import ControllerServiceWSGIApplication
4444
from iris.rpc.interceptors import RequestTimingInterceptor
4545

@@ -86,11 +86,15 @@ class _RouteAuthMiddleware:
8686
@public / @requires_auth annotation. Routes without an annotation are
8787
denied (default-deny). RPC Mount routes and static file mounts are
8888
skipped (they have their own auth).
89+
90+
Uses resolve_auth() — the same policy function as the gRPC interceptor —
91+
so HTTP and gRPC layers agree on allow/deny for every token state.
8992
"""
9093

91-
def __init__(self, app: Starlette, verifier: TokenVerifier):
94+
def __init__(self, app: Starlette, verifier: TokenVerifier, optional: bool = False):
9295
self._app = app
9396
self._verifier = verifier
97+
self._optional = optional
9498
self._router = app.router
9599

96100
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
@@ -134,15 +138,13 @@ def _resolve_policy(self, scope: Scope) -> str:
134138

135139
async def _check_auth(self, scope: Scope, receive: Receive, send: Send) -> None:
136140
token = _extract_token_from_scope(scope)
137-
if token is None:
138-
response = JSONResponse({"error": "authentication required"}, status_code=401)
139-
return await response(scope, receive, send)
140141
try:
141-
identity = self._verifier.verify(token)
142+
identity = resolve_auth(token, self._verifier, self._optional)
142143
except ValueError:
143-
response = JSONResponse({"error": "invalid session"}, status_code=401)
144+
response = JSONResponse({"error": "authentication required"}, status_code=401)
144145
return await response(scope, receive, send)
145-
scope["auth_identity"] = identity
146+
if identity is not None:
147+
scope["auth_identity"] = identity
146148
return await self._app(scope, receive, send)
147149

148150

@@ -168,16 +170,49 @@ def _check_csrf(request: Request) -> bool:
168170
return origin == expected_origin
169171

170172

171-
class _SelectiveAuthInterceptor:
172-
"""Auth interceptor that skips authentication for specific RPC methods."""
173+
class _DashboardAuthInterceptor:
174+
"""RPC auth interceptor that uses resolve_auth() — same policy as HTTP middleware.
173175
174-
def __init__(self, verifier: TokenVerifier):
175-
self._inner = AuthInterceptor(verifier)
176+
Login and GetAuthInfo RPCs are always unauthenticated. All other RPCs go
177+
through resolve_auth(token, verifier, optional) which:
178+
- token present + valid → authenticated identity
179+
- token present + invalid → rejected
180+
- no token + optional → anonymous/admin fallback via NullAuthInterceptor
181+
- no token + required → rejected
182+
"""
183+
184+
def __init__(self, verifier: TokenVerifier, optional: bool = False):
185+
self._verifier = verifier
186+
self._optional = optional
187+
self._null = NullAuthInterceptor(verifier=verifier)
176188

177189
def intercept_unary_sync(self, call_next, request, ctx):
190+
from iris.rpc.auth import _verified_identity
191+
178192
if ctx.method().name in _UNAUTHENTICATED_RPCS:
179193
return call_next(request, ctx)
180-
return self._inner.intercept_unary_sync(call_next, request, ctx)
194+
195+
token = extract_bearer_token(ctx.request_headers())
196+
try:
197+
identity = resolve_auth(token, self._verifier, self._optional)
198+
except ValueError as exc:
199+
from connectrpc.code import Code
200+
from connectrpc.errors import ConnectError
201+
202+
if token is None:
203+
raise ConnectError(Code.UNAUTHENTICATED, str(exc)) from exc
204+
logger.warning("Authentication failed: %s", exc)
205+
raise ConnectError(Code.UNAUTHENTICATED, "Authentication failed") from exc
206+
207+
if identity is None:
208+
# Optional mode, no token — anonymous fallback.
209+
return self._null.intercept_unary_sync(call_next, request, ctx)
210+
211+
reset_token = _verified_identity.set(identity)
212+
try:
213+
return call_next(request, ctx)
214+
finally:
215+
_verified_identity.reset(reset_token)
181216

182217

183218
class ControllerDashboard:
@@ -196,12 +231,14 @@ def __init__(
196231
port: int = 8080,
197232
auth_verifier: TokenVerifier | None = None,
198233
auth_provider: str | None = None,
234+
auth_optional: bool = False,
199235
):
200236
self._service = service
201237
self._host = host
202238
self._port = port
203239
self._auth_verifier = auth_verifier
204240
self._auth_provider = auth_provider
241+
self._auth_optional = auth_optional
205242
self._app = self._create_app()
206243

207244
@property
@@ -214,9 +251,11 @@ def app(self) -> ASGIApp:
214251

215252
def _create_app(self) -> ASGIApp:
216253
interceptors = [RequestTimingInterceptor(include_traceback=bool(os.environ.get("IRIS_DEBUG")))]
217-
if self._auth_provider is not None:
218-
interceptors.insert(0, _SelectiveAuthInterceptor(self._auth_verifier))
254+
if self._auth_provider is not None and self._auth_verifier is not None:
255+
interceptors.insert(0, _DashboardAuthInterceptor(self._auth_verifier, optional=self._auth_optional))
219256
else:
257+
# Null-auth mode: no provider configured. Verify worker tokens
258+
# when present but treat everything as anonymous/admin.
220259
interceptors.insert(0, NullAuthInterceptor(verifier=self._auth_verifier))
221260
rpc_wsgi_app = ControllerServiceWSGIApplication(service=self._service, interceptors=interceptors)
222261
rpc_app = WSGIMiddleware(rpc_wsgi_app)
@@ -236,7 +275,7 @@ def _create_app(self) -> ASGIApp:
236275
]
237276
app: Starlette | _RouteAuthMiddleware = Starlette(routes=routes)
238277
if self._auth_verifier is not None and self._auth_provider is not None:
239-
app = _RouteAuthMiddleware(app, self._auth_verifier)
278+
app = _RouteAuthMiddleware(app, self._auth_verifier, optional=self._auth_optional)
240279
return app
241280

242281
@public
@@ -283,6 +322,7 @@ def _auth_config(self, request: Request) -> JSONResponse:
283322
"provider": self._auth_provider,
284323
"has_session": has_session,
285324
"provider_kind": provider_kind,
325+
"optional": self._auth_optional,
286326
}
287327
)
288328

lib/iris/src/iris/rpc/auth.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,23 @@ def verify(self, token: str) -> VerifiedIdentity:
217217
raise ValueError(f"All verifiers failed: {'; '.join(errors)}")
218218

219219

220+
def resolve_auth(
221+
token: str | None,
222+
verifier: TokenVerifier,
223+
optional: bool,
224+
) -> VerifiedIdentity | None:
225+
"""Shared auth policy for gRPC interceptors and HTTP middleware.
226+
227+
Returns VerifiedIdentity on success, None for anonymous passthrough.
228+
Raises ValueError on rejected tokens (invalid token, or missing when required).
229+
"""
230+
if token is None:
231+
if optional:
232+
return None
233+
raise ValueError("Missing authentication")
234+
return verifier.verify(token)
235+
236+
220237
class AuthInterceptor:
221238
"""Server-side Connect RPC interceptor that enforces bearer token auth.
222239

0 commit comments

Comments
 (0)