Skip to content

Commit 9850255

Browse files
committed
🤖 fix: limit response model mapping to API output
What: - capture upstream response data before response model rewriting - keep proxy request and upstream attempt response models unchanged - update backend e2e assertions for unmapped recorded response models Why: - response model aliases should only affect the final client-facing API response - database records and operational logs should preserve upstream model identity Tests: - go test -count=1 ./internal/executor ./internal/handler ./tests/e2e -run 'Test.*Provider|Test.*ModelMapping|Test.*ResponseModelMapping|TestGenericProxyResponseModelMappingE2E' - go test -count=1 ./tests/e2e - go test -count=1 ./... - pnpm -C web typecheck - git diff --check
1 parent 49225c6 commit 9850255

5 files changed

Lines changed: 28 additions & 21 deletions

File tree

‎internal/executor/executor.go‎

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,7 @@ func (e *Executor) processAdapterEvents(eventChan domain.AdapterEventChan, attem
341341
}
342342
case domain.EventResponseModel:
343343
if event.ResponseModel != "" {
344-
attempt.ResponseModel = MapResponseModel(event.ResponseModel, responseModelMapping)
344+
attempt.ResponseModel = event.ResponseModel
345345
}
346346
case domain.EventFirstToken:
347347
if event.FirstTokenTime > 0 {
@@ -423,7 +423,7 @@ func (e *Executor) processAdapterEventsRealtime(
423423
}
424424
case domain.EventResponseModel:
425425
if ev.ResponseModel != "" {
426-
attempt.ResponseModel = MapResponseModel(ev.ResponseModel, responseModelMapping)
426+
attempt.ResponseModel = ev.ResponseModel
427427
dirty = true
428428
}
429429
case domain.EventFirstToken:

‎internal/executor/middleware_dispatch.go‎

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -156,24 +156,24 @@ func (e *Executor) dispatch(c *flow.Ctx) {
156156
eventDone := make(chan struct{})
157157
go e.processAdapterEventsRealtime(eventChan, attemptRecord, eventDone, clearDetail, responseModelMapping)
158158

159+
originalWriter := c.Writer
159160
var responseWriter http.ResponseWriter
160161
var convertingWriter *ConvertingResponseWriter
161162
var responseModelWriter *ResponseModelMappingWriter
162-
responseCapture := NewResponseCapture(c.Writer)
163-
finalWriter := http.ResponseWriter(responseCapture)
163+
finalWriter := http.ResponseWriter(originalWriter)
164164
if len(responseModelMapping) > 0 {
165-
responseModelWriter = NewResponseModelMappingWriter(responseCapture, responseModelMapping, state.isStream)
165+
responseModelWriter = NewResponseModelMappingWriter(originalWriter, responseModelMapping, state.isStream)
166166
finalWriter = responseModelWriter
167167
}
168+
responseCapture := NewResponseCapture(finalWriter)
168169
if needsConversion {
169170
convertingWriter = NewConvertingResponseWriter(
170-
finalWriter, e.converter, originalClientType, currentClientType, state.isStream, state.originalRequestBody)
171+
responseCapture, e.converter, originalClientType, currentClientType, state.isStream, state.originalRequestBody)
171172
responseWriter = convertingWriter
172173
} else {
173-
responseWriter = finalWriter
174+
responseWriter = responseCapture
174175
}
175176

176-
originalWriter := c.Writer
177177
c.Writer = responseWriter
178178
err := matchedRoute.ProviderAdapter.Execute(c, matchedRoute.Provider)
179179
c.Writer = originalWriter

‎internal/executor/response_capture.go‎

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,11 @@ func NewResponseCapture(w http.ResponseWriter) *ResponseCapture {
2323
}
2424
}
2525

26+
// SetResponseWriter changes the downstream writer while preserving captured data.
27+
func (rc *ResponseCapture) SetResponseWriter(w http.ResponseWriter) {
28+
rc.ResponseWriter = w
29+
}
30+
2631
// WriteHeader captures the status code and forwards to underlying writer
2732
func (rc *ResponseCapture) WriteHeader(code int) {
2833
rc.statusCode = code

‎internal/handler/provider_proxy.go‎

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -140,14 +140,15 @@ func (h *ProviderProxyHandler) directDispatch(provider *domain.Provider) flow.Ha
140140
c.Set(flow.KeyProxyRequest, proxyReq)
141141

142142
responseModelMapping := executor.ResponseModelMappingForProvider(provider)
143-
responseCapture := executor.NewResponseCapture(c.Writer)
143+
originalWriter := c.Writer
144144
var responseModelWriter *executor.ResponseModelMappingWriter
145-
responseWriter := http.ResponseWriter(responseCapture)
145+
finalWriter := http.ResponseWriter(originalWriter)
146146
if len(responseModelMapping) > 0 {
147-
responseModelWriter = executor.NewResponseModelMappingWriter(responseCapture, responseModelMapping, isStream)
148-
responseWriter = responseModelWriter
147+
responseModelWriter = executor.NewResponseModelMappingWriter(originalWriter, responseModelMapping, isStream)
148+
finalWriter = responseModelWriter
149149
}
150-
originalWriter := c.Writer
150+
responseCapture := executor.NewResponseCapture(finalWriter)
151+
responseWriter := http.ResponseWriter(responseCapture)
151152
c.Writer = responseWriter
152153
err = adapter.Execute(c, provider)
153154
c.Writer = originalWriter
@@ -164,8 +165,6 @@ func (h *ProviderProxyHandler) directDispatch(provider *domain.Provider) flow.Ha
164165
proxyReq.ResponseModel = mappedModel
165166
if responseModel := extractProviderProxyResponseModel(responseCapture.Body()); responseModel != "" {
166167
proxyReq.ResponseModel = responseModel
167-
} else if len(responseModelMapping) > 0 {
168-
proxyReq.ResponseModel = executor.MapResponseModel(mappedModel, responseModelMapping)
169168
}
170169
if !clearDetail {
171170
proxyReq.ResponseInfo = &domain.ResponseInfo{
@@ -184,6 +183,9 @@ func (h *ProviderProxyHandler) directDispatch(provider *domain.Provider) flow.Ha
184183

185184
proxyReq.Status = "FAILED"
186185
proxyReq.Error = err.Error()
186+
if responseModelWriter != nil {
187+
responseCapture.SetResponseWriter(originalWriter)
188+
}
187189
if proxyErr, ok := err.(*domain.ProxyError); ok {
188190
if isStream {
189191
writeStreamError(responseCapture, proxyErr)

‎tests/e2e/provider_proxy_test.go‎

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -196,8 +196,8 @@ func TestProviderResponseModelMapping(t *testing.T) {
196196
if len(requests.Items) == 0 {
197197
t.Fatalf("expected at least one request for provider %d", providerID)
198198
}
199-
if got := requests.Items[0]["responseModel"]; got != "client-visible-model" {
200-
t.Fatalf("expected recorded response model client-visible-model, got %v", got)
199+
if got := requests.Items[0]["responseModel"]; got != "upstream-visible-sonnet" {
200+
t.Fatalf("expected recorded response model upstream-visible-sonnet, got %v", got)
201201
}
202202
}
203203

@@ -236,8 +236,8 @@ func TestGenericProxyResponseModelMappingE2E(t *testing.T) {
236236
if len(requests.Items) == 0 {
237237
t.Fatalf("expected at least one request for provider %d", providerID)
238238
}
239-
if got := requests.Items[0]["responseModel"]; got != "client-visible-model" {
240-
t.Fatalf("expected recorded response model client-visible-model, got %v", got)
239+
if got := requests.Items[0]["responseModel"]; got != "upstream-visible-sonnet" {
240+
t.Fatalf("expected recorded response model upstream-visible-sonnet, got %v", got)
241241
}
242242
}
243243

@@ -290,7 +290,7 @@ func TestProviderResponseModelMappingStreamsSSE(t *testing.T) {
290290
if len(requests.Items) == 0 {
291291
t.Fatalf("expected at least one request for provider %d", providerID)
292292
}
293-
if got := requests.Items[0]["responseModel"]; got != "client-visible-model" {
294-
t.Fatalf("expected recorded response model client-visible-model, got %v", got)
293+
if got := requests.Items[0]["responseModel"]; got != "upstream-visible-sonnet" {
294+
t.Fatalf("expected recorded response model upstream-visible-sonnet, got %v", got)
295295
}
296296
}

0 commit comments

Comments
 (0)