Skip to content

Commit 2723d97

Browse files
authored
fix: bridge backend HTTP responses through websocket sessions (#236)
* fix: bridge backend HTTP responses through websocket sessions * fix: defer bridge turn-state alias registration
1 parent 3236119 commit 2723d97

10 files changed

Lines changed: 1660 additions & 243 deletions

File tree

app/modules/proxy/api.py

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,13 @@ async def responses(
125125
api_key: ApiKeyData | None = Security(validate_proxy_api_key),
126126
) -> Response:
127127
return await _stream_responses(
128-
request, payload, context, api_key, codex_session_affinity=True, openai_cache_affinity=True
128+
request,
129+
payload,
130+
context,
131+
api_key,
132+
codex_session_affinity=True,
133+
openai_cache_affinity=True,
134+
prefer_http_bridge=True,
129135
)
130136

131137

@@ -437,6 +443,15 @@ async def _stream_responses(
437443
)
438444

439445
rate_limit_headers = await context.service.rate_limit_headers()
446+
bridge_active = prefer_http_bridge and proxy_service_module.get_settings().http_responses_session_bridge_enabled
447+
downstream_turn_state = (
448+
proxy_service_module.ensure_http_downstream_turn_state(request.headers) if bridge_active else None
449+
)
450+
turn_state_headers = (
451+
proxy_service_module.build_downstream_turn_state_response_headers(downstream_turn_state)
452+
if downstream_turn_state is not None
453+
else {}
454+
)
440455
payload.stream = True
441456
if prefer_http_bridge:
442457
stream = context.service.stream_http_responses(
@@ -448,6 +463,7 @@ async def _stream_responses(
448463
api_key=api_key,
449464
api_key_reservation=reservation,
450465
suppress_text_done_events=suppress_text_done_events,
466+
downstream_turn_state=downstream_turn_state,
451467
)
452468
else:
453469
stream = context.service.stream_responses(
@@ -470,11 +486,16 @@ async def _stream_responses(
470486
)
471487
except ProxyResponseError as exc:
472488
await _release_reservation(reservation)
473-
return _logged_error_json_response(request, exc.status_code, exc.payload, headers=rate_limit_headers)
489+
return _logged_error_json_response(
490+
request,
491+
exc.status_code,
492+
exc.payload,
493+
headers=rate_limit_headers,
494+
)
474495
return StreamingResponse(
475496
_prepend_first(first, stream),
476497
media_type="text/event-stream",
477-
headers={"Cache-Control": "no-cache", **rate_limit_headers},
498+
headers={"Cache-Control": "no-cache", **turn_state_headers, **rate_limit_headers},
478499
)
479500

480501

@@ -498,6 +519,15 @@ async def _collect_responses(
498519
)
499520

500521
rate_limit_headers = await context.service.rate_limit_headers()
522+
bridge_active = prefer_http_bridge and proxy_service_module.get_settings().http_responses_session_bridge_enabled
523+
downstream_turn_state = (
524+
proxy_service_module.ensure_http_downstream_turn_state(request.headers) if bridge_active else None
525+
)
526+
turn_state_headers = (
527+
proxy_service_module.build_downstream_turn_state_response_headers(downstream_turn_state)
528+
if downstream_turn_state is not None
529+
else {}
530+
)
501531
payload.stream = True
502532
if prefer_http_bridge:
503533
stream = context.service.stream_http_responses(
@@ -509,6 +539,7 @@ async def _collect_responses(
509539
api_key=api_key,
510540
api_key_reservation=reservation,
511541
suppress_text_done_events=suppress_text_done_events,
542+
downstream_turn_state=downstream_turn_state,
512543
)
513544
else:
514545
stream = context.service.stream_responses(
@@ -540,18 +571,18 @@ async def _collect_responses(
540571
request,
541572
status_code,
542573
error_payload.model_dump(mode="json", exclude_none=True),
543-
headers=rate_limit_headers,
574+
headers={**turn_state_headers, **rate_limit_headers},
544575
)
545576
return JSONResponse(
546577
content=response_payload.model_dump(mode="json", exclude_none=True),
547-
headers=rate_limit_headers,
578+
headers={**turn_state_headers, **rate_limit_headers},
548579
)
549580
status_code = _status_for_error(response_payload.error)
550581
return _logged_error_json_response(
551582
request,
552583
status_code,
553584
response_payload.model_dump(mode="json", exclude_none=True),
554-
headers=rate_limit_headers,
585+
headers={**turn_state_headers, **rate_limit_headers},
555586
)
556587

557588

app/modules/proxy/service.py

Lines changed: 130 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import time
88
from collections import deque
99
from collections.abc import Sequence
10-
from dataclasses import dataclass
10+
from dataclasses import dataclass, field
1111
from hashlib import sha256
1212
from typing import AsyncIterator, Mapping, NoReturn
1313
from uuid import uuid4
@@ -143,6 +143,7 @@ def __init__(self, repo_factory: ProxyRepoFactory) -> None:
143143
self._encryptor = TokenEncryptor()
144144
self._load_balancer = LoadBalancer(repo_factory)
145145
self._http_bridge_sessions: dict[_HTTPBridgeSessionKey, _HTTPBridgeSession] = {}
146+
self._http_bridge_turn_state_index: dict[tuple[str, str | None], _HTTPBridgeSessionKey] = {}
146147
self._http_bridge_lock = anyio.Lock()
147148

148149
def stream_responses(
@@ -183,6 +184,7 @@ def stream_http_responses(
183184
api_key: ApiKeyData | None = None,
184185
api_key_reservation: ApiKeyUsageReservationData | None = None,
185186
suppress_text_done_events: bool = False,
187+
downstream_turn_state: str | None = None,
186188
) -> AsyncIterator[str]:
187189
_maybe_log_proxy_request_payload("stream_http", payload, headers)
188190
filtered = filter_inbound_headers(headers)
@@ -195,6 +197,7 @@ def stream_http_responses(
195197
api_key=api_key,
196198
api_key_reservation=api_key_reservation,
197199
suppress_text_done_events=suppress_text_done_events,
200+
downstream_turn_state=downstream_turn_state,
198201
)
199202

200203
async def _stream_http_bridge_or_retry(
@@ -208,6 +211,7 @@ async def _stream_http_bridge_or_retry(
208211
api_key: ApiKeyData | None,
209212
api_key_reservation: ApiKeyUsageReservationData | None,
210213
suppress_text_done_events: bool,
214+
downstream_turn_state: str | None = None,
211215
) -> AsyncIterator[str]:
212216
settings = await get_settings_cache().get()
213217
if not _http_responses_session_bridge_enabled(settings):
@@ -238,6 +242,7 @@ async def _stream_http_bridge_or_retry(
238242
codex_idle_ttl_seconds=getattr(settings, "http_responses_session_bridge_codex_idle_ttl_seconds", 900.0),
239243
max_sessions=getattr(settings, "http_responses_session_bridge_max_sessions", 256),
240244
queue_limit=getattr(settings, "http_responses_session_bridge_queue_limit", 8),
245+
downstream_turn_state=downstream_turn_state,
241246
):
242247
yield line
243248

@@ -256,6 +261,7 @@ async def _stream_via_http_bridge(
256261
codex_idle_ttl_seconds: float,
257262
max_sessions: int,
258263
queue_limit: int,
264+
downstream_turn_state: str | None = None,
259265
) -> AsyncIterator[str]:
260266
del propagate_http_errors, suppress_text_done_events
261267
request_id = ensure_request_id()
@@ -306,7 +312,6 @@ async def _stream_via_http_bridge(
306312
),
307313
max_sessions=max_sessions,
308314
)
309-
310315
request_state, text_data = self._prepare_http_bridge_request(
311316
payload,
312317
api_key=api_key,
@@ -320,6 +325,8 @@ async def _stream_via_http_bridge(
320325
text_data=text_data,
321326
queue_limit=queue_limit,
322327
)
328+
if downstream_turn_state is not None:
329+
await self._register_http_bridge_turn_state(session, downstream_turn_state)
323330

324331
try:
325332
event_queue = request_state.event_queue
@@ -1398,8 +1405,8 @@ async def _get_or_create_http_bridge_session(
13981405
idle_ttl_seconds: float,
13991406
max_sessions: int,
14001407
) -> "_HTTPBridgeSession":
1401-
del api_key
14021408
settings = get_settings()
1409+
api_key_id = api_key.id if api_key is not None else None
14031410
effective_idle_ttl_seconds = _effective_http_bridge_idle_ttl_seconds(
14041411
affinity=affinity,
14051412
idle_ttl_seconds=idle_ttl_seconds,
@@ -1409,36 +1416,64 @@ async def _get_or_create_http_bridge_session(
14091416
900.0,
14101417
),
14111418
)
1412-
owner_instance = _http_bridge_owner_instance(key, settings)
1413-
current_instance, ring = _normalized_http_bridge_instance_ring(settings)
1414-
if (
1415-
key.affinity_kind != "request"
1416-
and owner_instance is not None
1417-
and len(ring) > 1
1418-
and owner_instance != current_instance
1419-
):
1420-
_log_http_bridge_event(
1421-
"owner_mismatch",
1422-
key,
1423-
account_id=None,
1424-
model=request_model,
1425-
detail=f"expected_instance={owner_instance}, current_instance={current_instance}",
1426-
)
1427-
raise ProxyResponseError(
1428-
409,
1429-
openai_error(
1430-
"bridge_instance_mismatch",
1431-
(
1432-
"HTTP responses session bridge request reached the wrong instance "
1433-
f"(expected {owner_instance}, got {current_instance})"
1434-
),
1435-
error_type="server_error",
1436-
),
1437-
)
14381419
async with self._http_bridge_lock:
1420+
incoming_turn_state = _sticky_key_from_turn_state_header(headers)
1421+
if incoming_turn_state is not None:
1422+
alias_key = self._http_bridge_turn_state_index.get(
1423+
_http_bridge_turn_state_alias_key(incoming_turn_state, api_key_id)
1424+
)
1425+
if alias_key is not None:
1426+
key = alias_key
1427+
elif incoming_turn_state.startswith("http_turn_"):
1428+
raise ProxyResponseError(
1429+
409,
1430+
openai_error(
1431+
"bridge_instance_mismatch",
1432+
"HTTP bridge turn-state reached an instance that does not own the live session",
1433+
error_type="server_error",
1434+
),
1435+
)
1436+
owner_instance = _http_bridge_owner_instance(key, settings)
1437+
current_instance, ring = _normalized_http_bridge_instance_ring(settings)
1438+
if (
1439+
key.affinity_kind != "request"
1440+
and owner_instance is not None
1441+
and len(ring) > 1
1442+
and owner_instance != current_instance
1443+
):
1444+
_log_http_bridge_event(
1445+
"owner_mismatch",
1446+
key,
1447+
account_id=None,
1448+
model=request_model,
1449+
detail=f"expected_instance={owner_instance}, current_instance={current_instance}",
1450+
)
1451+
raise ProxyResponseError(
1452+
409,
1453+
openai_error(
1454+
"bridge_instance_mismatch",
1455+
(
1456+
"HTTP responses session bridge request reached the wrong instance "
1457+
f"(expected {owner_instance}, got {current_instance})"
1458+
),
1459+
error_type="server_error",
1460+
),
1461+
)
14391462
await self._prune_http_bridge_sessions_locked()
14401463
existing = self._http_bridge_sessions.get(key)
14411464
if existing is not None and not existing.closed and existing.account.status == AccountStatus.ACTIVE:
1465+
if (
1466+
incoming_turn_state is not None
1467+
and self._http_bridge_turn_state_index.get(
1468+
_http_bridge_turn_state_alias_key(incoming_turn_state, api_key_id)
1469+
)
1470+
== key
1471+
):
1472+
self._promote_http_bridge_session_to_codex_affinity(
1473+
existing,
1474+
turn_state=incoming_turn_state,
1475+
settings=settings,
1476+
)
14421477
existing.request_model = request_model
14431478
existing.last_used_at = time.monotonic()
14441479
_log_http_bridge_event(
@@ -1538,6 +1573,11 @@ async def _prune_http_bridge_sessions_locked(self) -> None:
15381573

15391574
async def _close_http_bridge_session(self, session: "_HTTPBridgeSession") -> None:
15401575
session.closed = True
1576+
for alias in session.downstream_turn_state_aliases:
1577+
self._http_bridge_turn_state_index.pop(
1578+
_http_bridge_turn_state_alias_key(alias, session.key.api_key_id),
1579+
None,
1580+
)
15411581
if session.upstream_reader is not None:
15421582
session.upstream_reader.cancel()
15431583
try:
@@ -1555,6 +1595,34 @@ async def _close_http_bridge_session(self, session: "_HTTPBridgeSession") -> Non
15551595
model=session.request_model,
15561596
)
15571597

1598+
async def _register_http_bridge_turn_state(self, session: "_HTTPBridgeSession", turn_state: str) -> None:
1599+
async with self._http_bridge_lock:
1600+
if session.closed:
1601+
return
1602+
session.downstream_turn_state_aliases.add(turn_state)
1603+
if session.downstream_turn_state is None:
1604+
session.downstream_turn_state = turn_state
1605+
self._http_bridge_turn_state_index[
1606+
_http_bridge_turn_state_alias_key(turn_state, session.key.api_key_id)
1607+
] = session.key
1608+
1609+
def _promote_http_bridge_session_to_codex_affinity(
1610+
self,
1611+
session: "_HTTPBridgeSession",
1612+
*,
1613+
turn_state: str,
1614+
settings: object,
1615+
) -> None:
1616+
session.affinity = _AffinityPolicy(key=turn_state, kind=StickySessionKind.CODEX_SESSION)
1617+
session.codex_session = True
1618+
session.downstream_turn_state = turn_state
1619+
session.downstream_turn_state_aliases.add(turn_state)
1620+
session.idle_ttl_seconds = max(
1621+
session.idle_ttl_seconds,
1622+
float(getattr(settings, "http_responses_session_bridge_codex_idle_ttl_seconds", 900.0)),
1623+
)
1624+
session.headers = _headers_with_turn_state(session.headers, turn_state)
1625+
15581626
async def _create_http_bridge_session(
15591627
self,
15601628
key: "_HTTPBridgeSessionKey",
@@ -1621,6 +1689,7 @@ async def _create_http_bridge_session(
16211689
codex_session=affinity.kind == StickySessionKind.CODEX_SESSION,
16221690
prewarm_lock=anyio.Lock(),
16231691
upstream_turn_state=_upstream_turn_state_from_socket(upstream),
1692+
downstream_turn_state=None,
16241693
)
16251694
session.upstream_reader = asyncio.create_task(self._relay_http_bridge_upstream_messages(session))
16261695
return session
@@ -2008,7 +2077,10 @@ async def _reconnect_http_bridge_session(
20082077
),
20092078
)
20102079
account = await self._ensure_fresh_with_budget(account, timeout_seconds=_remaining_budget_seconds(deadline))
2011-
connect_headers = _headers_with_turn_state(session.headers, session.upstream_turn_state)
2080+
connect_headers = _headers_with_turn_state(
2081+
session.headers,
2082+
_preferred_http_bridge_reconnect_turn_state(session),
2083+
)
20122084
upstream = await self._open_upstream_websocket_with_budget(
20132085
account,
20142086
connect_headers,
@@ -3983,6 +4055,8 @@ class _HTTPBridgeSession:
39834055
prewarmed: bool = False
39844056
prewarm_lock: anyio.Lock | None = None
39854057
upstream_turn_state: str | None = None
4058+
downstream_turn_state: str | None = None
4059+
downstream_turn_state_aliases: set[str] = field(default_factory=set)
39864060
upstream_reader: asyncio.Task[None] | None = None
39874061
closed: bool = False
39884062

@@ -4507,10 +4581,21 @@ def ensure_downstream_turn_state(headers: Mapping[str, str]) -> str:
45074581
return f"turn_{uuid4().hex}"
45084582

45094583

4584+
def ensure_http_downstream_turn_state(headers: Mapping[str, str]) -> str:
4585+
existing = _sticky_key_from_turn_state_header(headers)
4586+
if existing is not None:
4587+
return existing
4588+
return f"http_turn_{uuid4().hex}"
4589+
4590+
45104591
def build_downstream_turn_state_accept_headers(turn_state: str) -> list[tuple[bytes, bytes]]:
45114592
return [(b"x-codex-turn-state", turn_state.encode("utf-8"))]
45124593

45134594

4595+
def build_downstream_turn_state_response_headers(turn_state: str) -> dict[str, str]:
4596+
return {"x-codex-turn-state": turn_state}
4597+
4598+
45144599
def _upstream_turn_state_from_socket(upstream: UpstreamResponsesWebSocket | None) -> str | None:
45154600
if upstream is None:
45164601
return None
@@ -4531,6 +4616,21 @@ def _headers_with_turn_state(headers: Mapping[str, str], turn_state: str | None)
45314616
return forwarded
45324617

45334618

4619+
def _preferred_http_bridge_reconnect_turn_state(session: "_HTTPBridgeSession") -> str | None:
4620+
if (
4621+
session.codex_session
4622+
and session.downstream_turn_state is not None
4623+
and session.affinity.kind == StickySessionKind.CODEX_SESSION
4624+
and session.affinity.key == session.downstream_turn_state
4625+
):
4626+
return session.downstream_turn_state
4627+
return session.upstream_turn_state
4628+
4629+
4630+
def _http_bridge_turn_state_alias_key(turn_state: str, api_key_id: str | None) -> tuple[str, str | None]:
4631+
return (turn_state, api_key_id)
4632+
4633+
45344634
def _resolve_prompt_cache_key(
45354635
payload: ResponsesRequest | ResponsesCompactRequest,
45364636
*,

0 commit comments

Comments
 (0)