1212import socket
1313import time
1414from dataclasses import dataclass
15- from typing import Any , AsyncContextManager , AsyncIterator , Awaitable , Mapping , Protocol , TypeAlias , cast
15+ from typing import Any , AsyncContextManager , AsyncIterator , Awaitable , Callable , Mapping , Protocol , TypeAlias , cast
1616from urllib .parse import ParseResult , urlparse , urlunparse
1717
1818import aiohttp
1919from aiohttp import hdrs
2020from aiohttp .client_ws import DEFAULT_WS_CLIENT_TIMEOUT
21+ from aiohttp .http_websocket import WS_KEY , WebSocketReader , WebSocketWriter
2122from multidict import CIMultiDict
2223
2324from app .core .clients .http import get_http_client
2425from app .core .config .settings import get_settings
25- from app .core .errors import OpenAIErrorEnvelope , ResponseFailedEvent , openai_error , response_failed_event
26+ from app .core .errors import (
27+ OpenAIErrorDetail ,
28+ OpenAIErrorEnvelope ,
29+ ResponseFailedEvent ,
30+ openai_error ,
31+ response_failed_event ,
32+ )
2633from app .core .openai .model_registry import get_model_registry
27- from app .core .openai .models import CompactResponsePayload
34+ from app .core .openai .models import CompactResponsePayload , OpenAIError
2835from app .core .openai .parsing import (
2936 parse_compact_response_payload ,
3037 parse_error_payload ,
@@ -408,7 +415,7 @@ def _error_payload_from_websocket_handshake_error(exc: aiohttp.WSServerHandshake
408415 if extracted is not None :
409416 error = parse_error_payload (extracted )
410417 if error is not None :
411- return {"error" : error . model_dump ( exclude_none = True )}
418+ return {"error" : _openai_error_detail ( error )}
412419
413420 code = _infer_websocket_handshake_error_code (exc .status , message )
414421 if code == "invalid_api_key" :
@@ -694,13 +701,32 @@ async def _error_payload_from_response(resp: ErrorResponse) -> OpenAIErrorEnvelo
694701 if isinstance (data , dict ):
695702 error = parse_error_payload (data )
696703 if error :
697- return {"error" : error . model_dump ( exclude_none = True )}
704+ return {"error" : _openai_error_detail ( error )}
698705 message = _extract_upstream_message (data )
699706 if message :
700707 return openai_error ("upstream_error" , message )
701708 return openai_error ("upstream_error" , fallback_message )
702709
703710
711+ def _openai_error_detail (error : OpenAIError ) -> OpenAIErrorDetail :
712+ detail : OpenAIErrorDetail = {}
713+ if error .message is not None :
714+ detail ["message" ] = error .message
715+ if error .type is not None :
716+ detail ["type" ] = error .type
717+ if error .code is not None :
718+ detail ["code" ] = error .code
719+ if error .param is not None :
720+ detail ["param" ] = error .param
721+ if error .plan_type is not None :
722+ detail ["plan_type" ] = error .plan_type
723+ if error .resets_at is not None :
724+ detail ["resets_at" ] = error .resets_at
725+ if error .resets_in_seconds is not None :
726+ detail ["resets_in_seconds" ] = error .resets_in_seconds
727+ return detail
728+
729+
704730def _extract_upstream_message (data : Mapping [str , object ]) -> str | None :
705731 for key in ("message" , "detail" , "error" ):
706732 value = data .get (key )
@@ -856,8 +882,8 @@ async def _open_upstream_websocket(
856882 connect_timeout_seconds : float ,
857883 max_msg_size : int ,
858884) -> tuple [AsyncContextManager [aiohttp .ClientWebSocketResponse ], aiohttp .ClientWebSocketResponse ]:
859- request = getattr (session , "request" , None )
860- if not callable (request ):
885+ request_obj = getattr (session , "request" , None )
886+ if not callable (request_obj ):
861887 websocket_cm = session .ws_connect (
862888 url ,
863889 headers = headers ,
@@ -868,6 +894,7 @@ async def _open_upstream_websocket(
868894 )
869895 websocket = await asyncio .wait_for (websocket_cm .__aenter__ (), timeout = connect_timeout_seconds )
870896 return websocket_cm , websocket
897+ request = cast (Callable [..., Awaitable [aiohttp .ClientResponse ]], request_obj )
871898
872899 request_headers = CIMultiDict (headers )
873900 request_headers .setdefault (hdrs .UPGRADE , "websocket" )
@@ -910,7 +937,7 @@ async def _raise_handshake_error(message: str) -> None:
910937 await _raise_handshake_error ("Invalid connection header" )
911938
912939 response_key = resp .headers .get (hdrs .SEC_WEBSOCKET_ACCEPT , "" )
913- expected_key = base64 .b64encode (hashlib .sha1 (sec_key .encode () + aiohttp . client . WS_KEY ).digest ()).decode ()
940+ expected_key = base64 .b64encode (hashlib .sha1 (sec_key .encode () + WS_KEY ).digest ()).decode ()
914941 if response_key != expected_key :
915942 await _raise_handshake_error ("Invalid challenge response" )
916943
@@ -922,9 +949,10 @@ async def _raise_handshake_error(message: str) -> None:
922949
923950 transport = conn .transport
924951 assert transport is not None
925- reader = aiohttp .client .WebSocketDataQueue (conn_proto , 2 ** 16 , loop = session ._loop )
926- conn_proto .set_parser (aiohttp .client .WebSocketReader (reader , max_msg_size ), reader )
927- writer = aiohttp .client .WebSocketWriter (conn_proto , transport , use_mask = True , compress = 0 , notakeover = False )
952+ web_socket_data_queue = cast (Callable [..., Any ], getattr (aiohttp .client_ws , "WebSocketDataQueue" ))
953+ reader = web_socket_data_queue (conn_proto , 2 ** 16 , loop = session ._loop )
954+ conn_proto .set_parser (WebSocketReader (reader , max_msg_size ), reader )
955+ writer = WebSocketWriter (conn_proto , transport , use_mask = True , compress = 0 , notakeover = False )
928956 except BaseException :
929957 resp .close ()
930958 raise
0 commit comments