Skip to content

Commit 75bb455

Browse files
committed
fix: stabilize codex websocket auth failover
1 parent f644ee7 commit 75bb455

4 files changed

Lines changed: 209 additions & 5 deletions

File tree

internal/runtime/executor/codex_websockets_executor.go

Lines changed: 40 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1284,7 +1284,27 @@ func (e *CodexWebsocketsExecutor) ensureUpstreamConn(ctx context.Context, auth *
12841284
sess.connMu.Lock()
12851285
conn := sess.conn
12861286
readerConn := sess.readerConn
1287-
sess.connMu.Unlock()
1287+
sessionAuthID := strings.TrimSpace(sess.authID)
1288+
sessionWSURL := strings.TrimSpace(sess.wsURL)
1289+
sessionID := sess.sessionID
1290+
if conn != nil && (sessionAuthID != strings.TrimSpace(authID) || sessionWSURL != strings.TrimSpace(wsURL)) {
1291+
oldConn := conn
1292+
oldAuthID := sess.authID
1293+
oldWSURL := sess.wsURL
1294+
sess.conn = nil
1295+
if sess.readerConn == oldConn {
1296+
sess.readerConn = nil
1297+
}
1298+
conn = nil
1299+
readerConn = nil
1300+
sess.connMu.Unlock()
1301+
logCodexWebsocketDisconnected(sessionID, oldAuthID, oldWSURL, "auth_or_endpoint_changed", nil)
1302+
if errClose := oldConn.Close(); errClose != nil {
1303+
log.Errorf("codex websockets executor: close websocket error: %v", errClose)
1304+
}
1305+
} else {
1306+
sess.connMu.Unlock()
1307+
}
12881308
if conn != nil {
12891309
if readerConn != conn {
12901310
sess.connMu.Lock()
@@ -1404,12 +1424,30 @@ func (e *CodexWebsocketsExecutor) invalidateUpstreamConn(sess *codexWebsocketSes
14041424
sess.connMu.Unlock()
14051425

14061426
logCodexWebsocketDisconnected(sessionID, authID, wsURL, reason, err)
1407-
sess.notifyUpstreamDisconnect(err)
1427+
if shouldNotifyCodexUpstreamDisconnect(err) {
1428+
sess.notifyUpstreamDisconnect(err)
1429+
}
14081430
if errClose := conn.Close(); errClose != nil {
14091431
log.Errorf("codex websockets executor: close websocket error: %v", errClose)
14101432
}
14111433
}
14121434

1435+
// shouldNotifyCodexUpstreamDisconnect reports whether an upstream websocket
// failure described by err should be forwarded as a disconnect notification.
//
// It returns false only when err itself exposes a StatusCode() int method and
// that code is 401, 402, 403 or 429. A nil err, an err without a status code,
// and every other status code all return true.
//
// The check is a direct type assertion on err; wrapped errors are not
// unwrapped, so a status error hidden behind another error still returns true.
func shouldNotifyCodexUpstreamDisconnect(err error) bool {
	if err == nil {
		return true
	}

	coded, hasStatus := err.(interface{ StatusCode() int })
	if !hasStatus || coded == nil {
		return true
	}

	code := coded.StatusCode()
	suppress := code == http.StatusUnauthorized ||
		code == http.StatusPaymentRequired ||
		code == http.StatusForbidden ||
		code == http.StatusTooManyRequests
	return !suppress
}
1450+
14131451
func (e *CodexWebsocketsExecutor) CloseExecutionSession(sessionID string) {
14141452
sessionID = strings.TrimSpace(sessionID)
14151453
if e == nil {

internal/runtime/executor/codex_websockets_executor_test.go

Lines changed: 92 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@ package executor
33
import (
44
"bytes"
55
"context"
6+
"errors"
67
"net/http"
78
"net/http/httptest"
89
"testing"
@@ -91,6 +92,97 @@ func TestCodexWebsocketsExecutePreservesPreviousResponseIDUpstream(t *testing.T)
9192
}
9293
}
9394

95+
func TestCodexWebsocketSessionRedialsWhenAuthChanges(t *testing.T) {
96+
upgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }}
97+
authHeaders := make(chan string, 2)
98+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
99+
conn, err := upgrader.Upgrade(w, r, nil)
100+
if err != nil {
101+
t.Errorf("upgrade websocket: %v", err)
102+
return
103+
}
104+
defer func() { _ = conn.Close() }()
105+
106+
for {
107+
msgType, _, errRead := conn.ReadMessage()
108+
if errRead != nil {
109+
return
110+
}
111+
if msgType != websocket.TextMessage {
112+
continue
113+
}
114+
authHeaders <- r.Header.Get("Authorization")
115+
completed := []byte(`{"type":"response.completed","response":{"id":"resp-1","output":[],"usage":{"input_tokens":0,"output_tokens":0,"total_tokens":0}}}`)
116+
if errWrite := conn.WriteMessage(websocket.TextMessage, completed); errWrite != nil {
117+
return
118+
}
119+
}
120+
}))
121+
defer server.Close()
122+
123+
exec := NewCodexWebsocketsExecutor(&config.Config{})
124+
sessionID := "session-redial-auth"
125+
defer exec.CloseExecutionSession(sessionID)
126+
127+
req := cliproxyexecutor.Request{
128+
Model: "gpt-5-codex",
129+
Payload: []byte(`{"model":"gpt-5-codex","input":[]}`),
130+
}
131+
opts := cliproxyexecutor.Options{
132+
SourceFormat: sdktranslator.FromString("openai-response"),
133+
Metadata: map[string]any{
134+
cliproxyexecutor.ExecutionSessionMetadataKey: sessionID,
135+
},
136+
}
137+
auths := []struct {
138+
auth *cliproxyauth.Auth
139+
want string
140+
}{
141+
{
142+
auth: &cliproxyauth.Auth{ID: "auth-a", Attributes: map[string]string{"api_key": "sk-a", "base_url": server.URL}},
143+
want: "Bearer sk-a",
144+
},
145+
{
146+
auth: &cliproxyauth.Auth{ID: "auth-b", Attributes: map[string]string{"api_key": "sk-b", "base_url": server.URL}},
147+
want: "Bearer sk-b",
148+
},
149+
}
150+
151+
for i := range auths {
152+
streamResult, err := exec.ExecuteStream(context.Background(), auths[i].auth, req, opts)
153+
if err != nil {
154+
t.Fatalf("ExecuteStream(%s) error = %v", auths[i].auth.ID, err)
155+
}
156+
for chunk := range streamResult.Chunks {
157+
if chunk.Err != nil {
158+
t.Fatalf("stream chunk error for %s: %v", auths[i].auth.ID, chunk.Err)
159+
}
160+
}
161+
162+
select {
163+
case got := <-authHeaders:
164+
if got != auths[i].want {
165+
t.Fatalf("request %d Authorization = %q, want %q", i+1, got, auths[i].want)
166+
}
167+
case <-time.After(5 * time.Second):
168+
t.Fatalf("timed out waiting for upstream auth header %d", i+1)
169+
}
170+
}
171+
}
172+
173+
func TestShouldNotifyCodexUpstreamDisconnectSkipsRecoverableStatus(t *testing.T) {
174+
err, ok := parseCodexWebsocketError([]byte(`{"type":"error","status":429,"error":{"message":"quota exhausted"}}`))
175+
if !ok {
176+
t.Fatalf("expected websocket status error")
177+
}
178+
if shouldNotifyCodexUpstreamDisconnect(err) {
179+
t.Fatalf("recoverable websocket status should not force downstream disconnect")
180+
}
181+
if !shouldNotifyCodexUpstreamDisconnect(errors.New("network reset")) {
182+
t.Fatalf("network disconnect should still close downstream")
183+
}
184+
}
185+
94186
func TestApplyCodexWebsocketHeadersDefaultsToCurrentResponsesBeta(t *testing.T) {
95187
headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, nil, "", nil)
96188

sdk/api/handlers/openai/openai_responses_websocket.go

Lines changed: 32 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -210,6 +210,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
210210
}
211211

212212
requestJSON = repairResponsesWebsocketToolCalls(downstreamSessionKey, requestJSON)
213+
requestJSON = dedupeResponsesWebsocketRequestToolItems(requestJSON)
213214
updatedLastRequest = bytes.Clone(requestJSON)
214215
previousLastRequest := bytes.Clone(lastRequest)
215216
previousLastResponseOutput := bytes.Clone(lastResponseOutput)
@@ -296,22 +297,50 @@ func normalizeResponsesWebsocketRequestWithMode(rawJSON []byte, lastRequest []by
296297
}
297298
rawJSON = normalizedJSON
298299
requestType := strings.TrimSpace(gjson.GetBytes(rawJSON, "type").String())
300+
var normalized []byte
301+
var updatedLastRequest []byte
299302
switch requestType {
300303
case wsRequestTypeCreate:
301304
// log.Infof("responses websocket: response.create request")
302305
if len(lastRequest) == 0 {
303-
return normalizeResponseCreateRequest(rawJSON)
306+
normalized, updatedLastRequest, errMsg = normalizeResponseCreateRequest(rawJSON)
307+
} else {
308+
normalized, updatedLastRequest, errMsg = normalizeResponseSubsequentRequest(rawJSON, lastRequest, lastResponseOutput, allowIncrementalInputWithPreviousResponseID, allowCompactionReplayBypass)
304309
}
305-
return normalizeResponseSubsequentRequest(rawJSON, lastRequest, lastResponseOutput, allowIncrementalInputWithPreviousResponseID, allowCompactionReplayBypass)
306310
case wsRequestTypeAppend:
307311
// log.Infof("responses websocket: response.append request")
308-
return normalizeResponseSubsequentRequest(rawJSON, lastRequest, lastResponseOutput, allowIncrementalInputWithPreviousResponseID, allowCompactionReplayBypass)
312+
normalized, updatedLastRequest, errMsg = normalizeResponseSubsequentRequest(rawJSON, lastRequest, lastResponseOutput, allowIncrementalInputWithPreviousResponseID, allowCompactionReplayBypass)
309313
default:
310314
return nil, lastRequest, &interfaces.ErrorMessage{
311315
StatusCode: http.StatusBadRequest,
312316
Error: fmt.Errorf("unsupported websocket request type: %s", requestType),
313317
}
314318
}
319+
if errMsg != nil {
320+
return nil, updatedLastRequest, errMsg
321+
}
322+
normalized = dedupeResponsesWebsocketRequestToolItems(normalized)
323+
return normalized, bytes.Clone(normalized), nil
324+
}
325+
326+
func dedupeResponsesWebsocketRequestToolItems(rawJSON []byte) []byte {
327+
if len(rawJSON) == 0 {
328+
return rawJSON
329+
}
330+
deduped := cliproxyexecutor.DedupeToolOutputs(rawJSON)
331+
input := gjson.GetBytes(deduped, "input")
332+
if !input.Exists() || !input.IsArray() {
333+
return deduped
334+
}
335+
dedupedInput, errDedupe := dedupeFunctionCallsByCallID(input.Raw)
336+
if errDedupe != nil || dedupedInput == input.Raw {
337+
return deduped
338+
}
339+
updated, errSet := sjson.SetRawBytes(deduped, "input", []byte(dedupedInput))
340+
if errSet != nil {
341+
return deduped
342+
}
343+
return updated
315344
}
316345

317346
func normalizeResponseCreateRequest(rawJSON []byte) ([]byte, []byte, *interfaces.ErrorMessage) {

sdk/api/handlers/openai/openai_responses_websocket_test.go

Lines changed: 45 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1421,6 +1421,51 @@ func TestNormalizeResponsesWebsocketRequestDropsDuplicateFunctionCallsByCallID(t
14211421
}
14221422
}
14231423

1424+
func TestNormalizeResponsesWebsocketRequestDropsDuplicateToolOutputsByCallID(t *testing.T) {
1425+
lastRequest := []byte(`{"model":"test-model","stream":true,"input":[{"type":"function_call","id":"fc-1","call_id":"call-1","name":"tool"},{"type":"function_call_output","id":"tool-out-old","call_id":"call-1","output":"old"}]}`)
1426+
lastResponseOutput := []byte(`[]`)
1427+
raw := []byte(`{"type":"response.create","input":[{"type":"function_call_output","id":"tool-out-new","call_id":"call-1","output":"new"},{"type":"message","id":"msg-2"}]}`)
1428+
1429+
normalized, _, errMsg := normalizeResponsesWebsocketRequest(raw, lastRequest, lastResponseOutput)
1430+
if errMsg != nil {
1431+
t.Fatalf("unexpected error: %v", errMsg.Error)
1432+
}
1433+
1434+
items := gjson.GetBytes(normalized, "input").Array()
1435+
if len(items) != 3 {
1436+
t.Fatalf("merged input len = %d, want 3: %s", len(items), normalized)
1437+
}
1438+
if items[0].Get("id").String() != "fc-1" ||
1439+
items[1].Get("id").String() != "tool-out-new" ||
1440+
items[2].Get("id").String() != "msg-2" {
1441+
t.Fatalf("unexpected deduped input order: %s", normalized)
1442+
}
1443+
if got := items[1].Get("output").String(); got != "new" {
1444+
t.Fatalf("tool output = %q, want new", got)
1445+
}
1446+
}
1447+
1448+
func TestNormalizeResponsesWebsocketRequestWithPreviousResponseIDDedupesIncrementalToolItems(t *testing.T) {
1449+
lastRequest := []byte(`{"model":"test-model","stream":true,"input":[{"type":"message","id":"msg-1"}]}`)
1450+
raw := []byte(`{"type":"response.create","previous_response_id":"resp-1","input":[{"type":"function_call","id":"fc-1","call_id":"call-1","name":"tool"},{"type":"function_call","id":"fc-dup","call_id":"call-1","name":"tool"},{"type":"function_call_output","id":"tool-out-old","call_id":"call-1","output":"old"},{"type":"function_call_output","id":"tool-out-new","call_id":"call-1","output":"new"}]}`)
1451+
1452+
normalized, _, errMsg := normalizeResponsesWebsocketRequestWithMode(raw, lastRequest, []byte(`[]`), true, true)
1453+
if errMsg != nil {
1454+
t.Fatalf("unexpected error: %v", errMsg.Error)
1455+
}
1456+
if got := gjson.GetBytes(normalized, "previous_response_id").String(); got != "resp-1" {
1457+
t.Fatalf("previous_response_id = %s, want resp-1", got)
1458+
}
1459+
1460+
items := gjson.GetBytes(normalized, "input").Array()
1461+
if len(items) != 2 {
1462+
t.Fatalf("incremental input len = %d, want 2: %s", len(items), normalized)
1463+
}
1464+
if items[0].Get("id").String() != "fc-1" || items[1].Get("id").String() != "tool-out-new" {
1465+
t.Fatalf("unexpected deduped incremental input: %s", normalized)
1466+
}
1467+
}
1468+
14241469
func TestNormalizeResponsesWebsocketRequestTreatsCustomToolTranscriptReplacementAsReset(t *testing.T) {
14251470
lastRequest := []byte(`{"model":"test-model","stream":true,"input":[{"type":"message","id":"msg-1"},{"type":"custom_tool_call","id":"ctc-1","call_id":"call-1","name":"apply_patch"},{"type":"custom_tool_call_output","id":"tool-out-1","call_id":"call-1"},{"type":"message","id":"assistant-1","role":"assistant"}]}`)
14261471
lastResponseOutput := []byte(`[

0 commit comments

Comments
 (0)