Skip to content

Commit e9f3a3c

Browse files
🤖 feat: add provider response model mapping
1 parent c595e85 commit e9f3a3c

4 files changed

Lines changed: 369 additions & 4 deletions

File tree

‎internal/executor/middleware_dispatch.go‎

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"github.com/awsl-project/maxx/internal/converter"
1111
"github.com/awsl-project/maxx/internal/cooldown"
1212
"github.com/awsl-project/maxx/internal/domain"
13+
"github.com/awsl-project/maxx/internal/executor/responsemodifier"
1314
"github.com/awsl-project/maxx/internal/flow"
1415
"github.com/awsl-project/maxx/internal/pricing"
1516
"github.com/awsl-project/maxx/internal/usage"
@@ -158,24 +159,40 @@ func (e *Executor) dispatch(c *flow.Ctx) {
158159
var responseWriter http.ResponseWriter
159160
var convertingWriter *ConvertingResponseWriter
160161
responseCapture := NewResponseCapture(c.Writer)
162+
modifierWriter := responsemodifier.NewResponseModifierWriter(
163+
responseCapture,
164+
matchedRoute.Provider,
165+
originalClientType,
166+
state.requestModel,
167+
mappedModel,
168+
state.isStream,
169+
)
170+
if modifierWriter != nil {
171+
responseWriter = modifierWriter
172+
} else {
173+
responseWriter = responseCapture
174+
}
161175
if needsConversion {
162176
convertingWriter = NewConvertingResponseWriter(
163-
responseCapture, e.converter, originalClientType, currentClientType, state.isStream, state.originalRequestBody)
177+
responseWriter, e.converter, originalClientType, currentClientType, state.isStream, state.originalRequestBody)
164178
responseWriter = convertingWriter
165-
} else {
166-
responseWriter = responseCapture
167179
}
168180

169181
originalWriter := c.Writer
170182
c.Writer = responseWriter
171183
err := matchedRoute.ProviderAdapter.Execute(c, matchedRoute.Provider)
172184
c.Writer = originalWriter
173185

174-
if needsConversion && convertingWriter != nil && !state.isStream {
186+
if needsConversion && convertingWriter != nil && !state.isStream && (err == nil || modifierWriter == nil) {
175187
if finalizeErr := convertingWriter.Finalize(); finalizeErr != nil {
176188
log.Printf("[Executor] Response conversion finalize failed: %v", finalizeErr)
177189
}
178190
}
191+
if err == nil && modifierWriter != nil {
192+
if finalizeErr := modifierWriter.Finalize(); finalizeErr != nil {
193+
log.Printf("[Executor] Response modifier finalize failed: %v", finalizeErr)
194+
}
195+
}
179196

180197
eventChan.Close()
181198
<-eventDone
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
package responsemodifier
2+
3+
import (
4+
"bytes"
5+
"encoding/json"
6+
"regexp"
7+
"strings"
8+
9+
"github.com/awsl-project/maxx/internal/domain"
10+
)
11+
12+
var claudeSSEDataLineRE = regexp.MustCompile(`(?m)^(\s*data:\s*)(.*?)(\r?\n?)$`)
13+
14+
type claudeResponseModifier struct {
15+
requestModel string
16+
mappedModel string
17+
}
18+
19+
func newClaudeResponseModifier(config *domain.ProviderConfig, clientType domain.ClientType, requestModel string, mappedModel string) *claudeResponseModifier {
20+
if clientType != domain.ClientTypeClaude {
21+
return nil
22+
}
23+
if config == nil || config.Claude == nil {
24+
return nil
25+
}
26+
requestModel = strings.TrimSpace(requestModel)
27+
mappedModel = strings.TrimSpace(mappedModel)
28+
if requestModel == "" || mappedModel == "" || requestModel == mappedModel {
29+
return nil
30+
}
31+
return &claudeResponseModifier{requestModel: requestModel, mappedModel: mappedModel}
32+
}
33+
34+
func (m *claudeResponseModifier) modifyResponse(body []byte) []byte {
35+
return m.rewriteJSON(body)
36+
}
37+
38+
func (m *claudeResponseModifier) modifyStreamEvent(body []byte) []byte {
39+
return claudeSSEDataLineRE.ReplaceAllFunc(body, func(line []byte) []byte {
40+
parts := claudeSSEDataLineRE.FindSubmatch(line)
41+
if len(parts) != 4 || bytes.Equal(bytes.TrimSpace(parts[2]), []byte("[DONE]")) {
42+
return line
43+
}
44+
payload := m.rewriteJSON(parts[2])
45+
if bytes.Equal(payload, parts[2]) {
46+
return line
47+
}
48+
out := make([]byte, 0, len(parts[1])+len(payload)+len(parts[3]))
49+
out = append(out, parts[1]...)
50+
out = append(out, payload...)
51+
out = append(out, parts[3]...)
52+
return out
53+
})
54+
}
55+
56+
func (m *claudeResponseModifier) rewriteJSON(body []byte) []byte {
57+
if len(bytes.TrimSpace(body)) == 0 {
58+
return body
59+
}
60+
var object map[string]any
61+
if err := json.Unmarshal(body, &object); err != nil {
62+
return body
63+
}
64+
changed := m.rewriteModel(object, "model")
65+
if message, ok := object["message"].(map[string]any); ok {
66+
changed = m.rewriteModel(message, "model") || changed
67+
}
68+
if !changed {
69+
return body
70+
}
71+
var buf bytes.Buffer
72+
encoder := json.NewEncoder(&buf)
73+
encoder.SetEscapeHTML(false)
74+
if err := encoder.Encode(object); err != nil {
75+
return body
76+
}
77+
return bytes.TrimSuffix(buf.Bytes(), []byte("\n"))
78+
}
79+
80+
func (m *claudeResponseModifier) rewriteModel(object map[string]any, key string) bool {
81+
value, ok := object[key].(string)
82+
if !ok {
83+
return false
84+
}
85+
mapped := m.mapModel(value)
86+
if mapped == value {
87+
return false
88+
}
89+
object[key] = mapped
90+
return true
91+
}
92+
93+
func (m *claudeResponseModifier) mapModel(model string) string {
94+
model = strings.TrimSpace(model)
95+
if model != m.mappedModel {
96+
return model
97+
}
98+
return m.requestModel
99+
}
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
package responsemodifier
2+
3+
import (
4+
"bytes"
5+
"net/http"
6+
"strconv"
7+
8+
"github.com/awsl-project/maxx/internal/domain"
9+
)
10+
11+
type responseModifier interface {
12+
modifyResponse(body []byte) []byte
13+
modifyStreamEvent(event []byte) []byte
14+
}
15+
16+
// ResponseModifierWriter buffers a response and applies provider-specific response modifications before sending it.
17+
type ResponseModifierWriter struct {
18+
underlying http.ResponseWriter
19+
modifier responseModifier
20+
isStream bool
21+
statusCode int
22+
buffer bytes.Buffer
23+
headersSent bool
24+
}
25+
26+
func NewResponseModifierWriter(
27+
w http.ResponseWriter,
28+
provider *domain.Provider,
29+
clientType domain.ClientType,
30+
requestModel string,
31+
mappedModel string,
32+
isStream bool,
33+
) *ResponseModifierWriter {
34+
modifier := newResponseModifier(provider, clientType, requestModel, mappedModel)
35+
if modifier == nil {
36+
return nil
37+
}
38+
return &ResponseModifierWriter{underlying: w, modifier: modifier, isStream: isStream, statusCode: http.StatusOK}
39+
}
40+
41+
func newResponseModifier(provider *domain.Provider, clientType domain.ClientType, requestModel string, mappedModel string) responseModifier {
42+
if provider == nil {
43+
return nil
44+
}
45+
switch provider.Type {
46+
case "claude":
47+
modifier := newClaudeResponseModifier(provider.Config, clientType, requestModel, mappedModel)
48+
if modifier == nil {
49+
return nil
50+
}
51+
return modifier
52+
default:
53+
return nil
54+
}
55+
}
56+
57+
func (w *ResponseModifierWriter) Header() http.Header {
58+
return w.underlying.Header()
59+
}
60+
61+
func (w *ResponseModifierWriter) WriteHeader(code int) {
62+
w.statusCode = code
63+
if w.isStream {
64+
w.writeHeaderIfNeeded()
65+
}
66+
}
67+
68+
func (w *ResponseModifierWriter) Write(b []byte) (int, error) {
69+
if !w.isStream {
70+
_, err := w.buffer.Write(b)
71+
return len(b), err
72+
}
73+
if _, err := w.buffer.Write(b); err != nil {
74+
return 0, err
75+
}
76+
if err := w.flushCompleteStreamEvents(false); err != nil {
77+
return 0, err
78+
}
79+
return len(b), nil
80+
}
81+
82+
func (w *ResponseModifierWriter) Flush() {
83+
if w.isStream {
84+
_ = w.flushCompleteStreamEvents(false)
85+
}
86+
if f, ok := w.underlying.(http.Flusher); ok {
87+
f.Flush()
88+
}
89+
}
90+
91+
func (w *ResponseModifierWriter) Finalize() error {
92+
if w.isStream {
93+
return w.flushCompleteStreamEvents(true)
94+
}
95+
body := w.modifier.modifyResponse(w.buffer.Bytes())
96+
if w.underlying.Header().Get("Content-Length") != "" {
97+
w.underlying.Header().Set("Content-Length", strconv.Itoa(len(body)))
98+
}
99+
w.writeHeaderIfNeeded()
100+
_, err := w.underlying.Write(body)
101+
return err
102+
}
103+
104+
func (w *ResponseModifierWriter) flushCompleteStreamEvents(final bool) error {
105+
for {
106+
eventLen := completeSSEEventLen(w.buffer.Bytes())
107+
if eventLen == 0 {
108+
break
109+
}
110+
if err := w.writeStreamEvent(w.buffer.Next(eventLen)); err != nil {
111+
return err
112+
}
113+
}
114+
if final && w.buffer.Len() > 0 {
115+
if err := w.writeStreamEvent(w.buffer.Next(w.buffer.Len())); err != nil {
116+
return err
117+
}
118+
}
119+
if final {
120+
w.writeHeaderIfNeeded()
121+
}
122+
return nil
123+
}
124+
125+
func (w *ResponseModifierWriter) writeStreamEvent(event []byte) error {
126+
body := w.modifier.modifyStreamEvent(event)
127+
w.writeHeaderIfNeeded()
128+
_, err := w.underlying.Write(body)
129+
if err != nil {
130+
return err
131+
}
132+
if f, ok := w.underlying.(http.Flusher); ok {
133+
f.Flush()
134+
}
135+
return nil
136+
}
137+
138+
func (w *ResponseModifierWriter) writeHeaderIfNeeded() {
139+
if w.headersSent {
140+
return
141+
}
142+
w.underlying.WriteHeader(w.statusCode)
143+
w.headersSent = true
144+
}
145+
146+
func completeSSEEventLen(data []byte) int {
147+
lf := bytes.Index(data, []byte("\n\n"))
148+
crlf := bytes.Index(data, []byte("\r\n\r\n"))
149+
if lf == -1 && crlf == -1 {
150+
return 0
151+
}
152+
if lf == -1 || (crlf != -1 && crlf < lf) {
153+
return crlf + len("\r\n\r\n")
154+
}
155+
return lf + len("\n\n")
156+
}

0 commit comments

Comments
 (0)