Skip to content

Commit e2a22d1

Browse files
authored
[ai-proxy] vertex image edits & variations (#3536)
1 parent e9aecb6 commit e2a22d1

8 files changed

Lines changed: 830 additions & 28 deletions

File tree

plugins/wasm-go/extensions/ai-proxy/main.go

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,9 +225,9 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
225225
}
226226
}
227227

228-
if contentType, _ := proxywasm.GetHttpRequestHeader(util.HeaderContentType); contentType != "" && !strings.Contains(contentType, util.MimeTypeApplicationJson) {
228+
if contentType, _ := proxywasm.GetHttpRequestHeader(util.HeaderContentType); contentType != "" && !isSupportedRequestContentType(apiName, contentType) {
229229
ctx.DontReadRequestBody()
230-
log.Debugf("[onHttpRequestHeader] unsupported content type: %s, will not process the request body", contentType)
230+
log.Debugf("[onHttpRequestHeader] unsupported content type for api %s: %s, will not process the request body", apiName, contentType)
231231
}
232232

233233
if apiName == "" {
@@ -306,6 +306,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig
306306
if err == nil {
307307
return action
308308
}
309+
log.Errorf("[onHttpRequestBody] failed to process request body, apiName=%s, err=%v", apiName, err)
309310
_ = util.ErrorHandler("ai-proxy.proc_req_body_failed", fmt.Errorf("failed to process request body: %v", err))
310311
}
311312
return types.ActionContinue
@@ -594,3 +595,14 @@ func getApiName(path string) provider.ApiName {
594595

595596
return ""
596597
}
598+
599+
func isSupportedRequestContentType(apiName provider.ApiName, contentType string) bool {
600+
if strings.Contains(contentType, util.MimeTypeApplicationJson) {
601+
return true
602+
}
603+
contentType = strings.ToLower(contentType)
604+
if strings.HasPrefix(contentType, "multipart/form-data") {
605+
return apiName == provider.ApiNameImageEdit || apiName == provider.ApiNameImageVariation
606+
}
607+
return false
608+
}

plugins/wasm-go/extensions/ai-proxy/main_test.go

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,54 @@ func Test_getApiName(t *testing.T) {
6363
}
6464
}
6565

66+
func Test_isSupportedRequestContentType(t *testing.T) {
67+
tests := []struct {
68+
name string
69+
apiName provider.ApiName
70+
contentType string
71+
want bool
72+
}{
73+
{
74+
name: "json chat completion",
75+
apiName: provider.ApiNameChatCompletion,
76+
contentType: "application/json",
77+
want: true,
78+
},
79+
{
80+
name: "multipart image edit",
81+
apiName: provider.ApiNameImageEdit,
82+
contentType: "multipart/form-data; boundary=----boundary",
83+
want: true,
84+
},
85+
{
86+
name: "multipart image variation",
87+
apiName: provider.ApiNameImageVariation,
88+
contentType: "multipart/form-data; boundary=----boundary",
89+
want: true,
90+
},
91+
{
92+
name: "multipart chat completion",
93+
apiName: provider.ApiNameChatCompletion,
94+
contentType: "multipart/form-data; boundary=----boundary",
95+
want: false,
96+
},
97+
{
98+
name: "text plain image edit",
99+
apiName: provider.ApiNameImageEdit,
100+
contentType: "text/plain",
101+
want: false,
102+
},
103+
}
104+
for _, tt := range tests {
105+
t.Run(tt.name, func(t *testing.T) {
106+
got := isSupportedRequestContentType(tt.apiName, tt.contentType)
107+
if got != tt.want {
108+
t.Errorf("isSupportedRequestContentType(%v, %q) = %v, want %v", tt.apiName, tt.contentType, got, tt.want)
109+
}
110+
})
111+
}
112+
}
113+
66114
func TestAi360(t *testing.T) {
67115
test.RunAi360ParseConfigTests(t)
68116
test.RunAi360OnHttpRequestHeadersTests(t)
@@ -137,6 +185,8 @@ func TestVertex(t *testing.T) {
137185
test.RunVertexExpressModeOnStreamingResponseBodyTests(t)
138186
test.RunVertexExpressModeImageGenerationRequestBodyTests(t)
139187
test.RunVertexExpressModeImageGenerationResponseBodyTests(t)
188+
test.RunVertexExpressModeImageEditVariationRequestBodyTests(t)
189+
test.RunVertexExpressModeImageEditVariationResponseBodyTests(t)
140190
// Vertex Raw 模式测试
141191
test.RunVertexRawModeOnHttpRequestHeadersTests(t)
142192
test.RunVertexRawModeOnHttpRequestBodyTests(t)

plugins/wasm-go/extensions/ai-proxy/provider/model.go

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package provider
22

33
import (
4+
"encoding/json"
45
"fmt"
56
"strings"
67

@@ -461,6 +462,122 @@ type imageGenerationRequest struct {
461462
Size string `json:"size,omitempty"`
462463
}
463464

465+
type imageInputURL struct {
466+
URL string `json:"url,omitempty"`
467+
ImageURL *chatMessageContentImageUrl `json:"image_url,omitempty"`
468+
}
469+
470+
func (i *imageInputURL) UnmarshalJSON(data []byte) error {
471+
// Support a plain string payload, e.g. "data:image/png;base64,..."
472+
var rawURL string
473+
if err := json.Unmarshal(data, &rawURL); err == nil {
474+
i.URL = rawURL
475+
i.ImageURL = nil
476+
return nil
477+
}
478+
479+
type alias imageInputURL
480+
var value alias
481+
if err := json.Unmarshal(data, &value); err != nil {
482+
return err
483+
}
484+
*i = imageInputURL(value)
485+
return nil
486+
}
487+
488+
func (i *imageInputURL) GetURL() string {
489+
if i == nil {
490+
return ""
491+
}
492+
if i.ImageURL != nil && i.ImageURL.Url != "" {
493+
return i.ImageURL.Url
494+
}
495+
return i.URL
496+
}
497+
498+
type imageEditRequest struct {
499+
Model string `json:"model"`
500+
Prompt string `json:"prompt"`
501+
Image *imageInputURL `json:"image,omitempty"`
502+
Images []imageInputURL `json:"images,omitempty"`
503+
ImageURL *imageInputURL `json:"image_url,omitempty"`
504+
Mask *imageInputURL `json:"mask,omitempty"`
505+
MaskURL *imageInputURL `json:"mask_url,omitempty"`
506+
Background string `json:"background,omitempty"`
507+
Moderation string `json:"moderation,omitempty"`
508+
OutputCompression int `json:"output_compression,omitempty"`
509+
OutputFormat string `json:"output_format,omitempty"`
510+
Quality string `json:"quality,omitempty"`
511+
ResponseFormat string `json:"response_format,omitempty"`
512+
Style string `json:"style,omitempty"`
513+
N int `json:"n,omitempty"`
514+
Size string `json:"size,omitempty"`
515+
}
516+
517+
func (r *imageEditRequest) GetImageURLs() []string {
518+
urls := make([]string, 0, len(r.Images)+2)
519+
for _, image := range r.Images {
520+
if url := image.GetURL(); url != "" {
521+
urls = append(urls, url)
522+
}
523+
}
524+
if r.Image != nil {
525+
if url := r.Image.GetURL(); url != "" {
526+
urls = append(urls, url)
527+
}
528+
}
529+
if r.ImageURL != nil {
530+
if url := r.ImageURL.GetURL(); url != "" {
531+
urls = append(urls, url)
532+
}
533+
}
534+
return urls
535+
}
536+
537+
func (r *imageEditRequest) HasMask() bool {
538+
if r.Mask != nil && r.Mask.GetURL() != "" {
539+
return true
540+
}
541+
return r.MaskURL != nil && r.MaskURL.GetURL() != ""
542+
}
543+
544+
type imageVariationRequest struct {
545+
Model string `json:"model"`
546+
Prompt string `json:"prompt,omitempty"`
547+
Image *imageInputURL `json:"image,omitempty"`
548+
Images []imageInputURL `json:"images,omitempty"`
549+
ImageURL *imageInputURL `json:"image_url,omitempty"`
550+
Background string `json:"background,omitempty"`
551+
Moderation string `json:"moderation,omitempty"`
552+
OutputCompression int `json:"output_compression,omitempty"`
553+
OutputFormat string `json:"output_format,omitempty"`
554+
Quality string `json:"quality,omitempty"`
555+
ResponseFormat string `json:"response_format,omitempty"`
556+
Style string `json:"style,omitempty"`
557+
N int `json:"n,omitempty"`
558+
Size string `json:"size,omitempty"`
559+
}
560+
561+
func (r *imageVariationRequest) GetImageURLs() []string {
562+
urls := make([]string, 0, len(r.Images)+2)
563+
for _, image := range r.Images {
564+
if url := image.GetURL(); url != "" {
565+
urls = append(urls, url)
566+
}
567+
}
568+
if r.Image != nil {
569+
if url := r.Image.GetURL(); url != "" {
570+
urls = append(urls, url)
571+
}
572+
}
573+
if r.ImageURL != nil {
574+
if url := r.ImageURL.GetURL(); url != "" {
575+
urls = append(urls, url)
576+
}
577+
}
578+
return urls
579+
}
580+
464581
type imageGenerationData struct {
465582
URL string `json:"url,omitempty"`
466583
B64 string `json:"b64_json,omitempty"`
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
package provider
2+
3+
import (
4+
"bytes"
5+
"encoding/base64"
6+
"fmt"
7+
"io"
8+
"mime"
9+
"mime/multipart"
10+
"net/http"
11+
"strconv"
12+
"strings"
13+
)
14+
15+
// multipartImageRequest is the result of decoding a multipart/form-data image
// request body with parseMultipartImageRequest.
type multipartImageRequest struct {
	// Model, Prompt, Size and OutputFormat hold the whitespace-trimmed values
	// of the form fields "model", "prompt", "size" and "output_format".
	Model        string
	Prompt       string
	Size         string
	OutputFormat string
	// N is the integer value of the "n" field; it stays 0 when the field is
	// absent or not a valid integer.
	N int
	// ImageURLs lists the image inputs in the order they appear in the body.
	// Each entry is either the URL/data-URL text sent in the field or a
	// base64 data URL built from an uploaded binary part.
	ImageURLs []string
	// HasMask is true when a non-empty mask field was present.
	HasMask bool
}
24+
25+
// isMultipartFormData reports whether contentType is a well-formed
// Content-Type value whose media type is multipart/form-data. Parameters such
// as the boundary are ignored; a value that fails to parse yields false.
func isMultipartFormData(contentType string) bool {
	// ParseMediaType returns the media type already lower-cased and trimmed,
	// so a plain comparison is case-insensitive with respect to the input.
	mediaType, _, err := mime.ParseMediaType(contentType)
	return err == nil && mediaType == "multipart/form-data"
}
32+
33+
func parseMultipartImageRequest(body []byte, contentType string) (*multipartImageRequest, error) {
34+
_, params, err := mime.ParseMediaType(contentType)
35+
if err != nil {
36+
return nil, fmt.Errorf("unable to parse content-type: %v", err)
37+
}
38+
boundary := params["boundary"]
39+
if boundary == "" {
40+
return nil, fmt.Errorf("missing multipart boundary")
41+
}
42+
43+
req := &multipartImageRequest{
44+
ImageURLs: make([]string, 0),
45+
}
46+
reader := multipart.NewReader(bytes.NewReader(body), boundary)
47+
for {
48+
part, err := reader.NextPart()
49+
if err == io.EOF {
50+
break
51+
}
52+
if err != nil {
53+
return nil, fmt.Errorf("unable to read multipart part: %v", err)
54+
}
55+
fieldName := part.FormName()
56+
if fieldName == "" {
57+
_ = part.Close()
58+
continue
59+
}
60+
partContentType := strings.TrimSpace(part.Header.Get("Content-Type"))
61+
62+
partData, err := io.ReadAll(part)
63+
_ = part.Close()
64+
if err != nil {
65+
return nil, fmt.Errorf("unable to read multipart field %s: %v", fieldName, err)
66+
}
67+
68+
value := strings.TrimSpace(string(partData))
69+
switch fieldName {
70+
case "model":
71+
req.Model = value
72+
continue
73+
case "prompt":
74+
req.Prompt = value
75+
continue
76+
case "size":
77+
req.Size = value
78+
continue
79+
case "output_format":
80+
req.OutputFormat = value
81+
continue
82+
case "n":
83+
if value != "" {
84+
if parsed, err := strconv.Atoi(value); err == nil {
85+
req.N = parsed
86+
}
87+
}
88+
continue
89+
}
90+
91+
if isMultipartImageField(fieldName) {
92+
if isMultipartImageURLValue(value) {
93+
req.ImageURLs = append(req.ImageURLs, value)
94+
continue
95+
}
96+
if len(partData) == 0 {
97+
continue
98+
}
99+
imageURL := buildMultipartDataURL(partContentType, partData)
100+
req.ImageURLs = append(req.ImageURLs, imageURL)
101+
continue
102+
}
103+
if isMultipartMaskField(fieldName) {
104+
if len(partData) > 0 || value != "" {
105+
req.HasMask = true
106+
}
107+
continue
108+
}
109+
}
110+
111+
return req, nil
112+
}
113+
114+
// isMultipartImageField reports whether fieldName names an image input of a
// multipart image request: exactly "image", or any name starting with
// "image[" (which covers "image[]" and indexed forms such as "image[0]").
func isMultipartImageField(fieldName string) bool {
	return fieldName == "image" || strings.HasPrefix(fieldName, "image[")
}
117+
118+
// isMultipartMaskField reports whether fieldName names a mask input of a
// multipart image request: exactly "mask", or any name starting with "mask["
// (which covers "mask[]" and indexed forms such as "mask[0]").
func isMultipartMaskField(fieldName string) bool {
	return fieldName == "mask" || strings.HasPrefix(fieldName, "mask[")
}
121+
122+
// isMultipartImageURLValue reports whether value looks like an image reference
// rather than raw image bytes, i.e. whether it starts (case-insensitively)
// with "data:", "http://" or "https://". The empty string is not a URL.
func isMultipartImageURLValue(value string) bool {
	// Only the first few bytes decide the answer. The value may be a
	// multi-megabyte data URL or binary upload, so lower-case just the head
	// instead of copying the whole string.
	const longestPrefix = len("https://")
	head := value
	if len(head) > longestPrefix {
		head = head[:longestPrefix]
	}
	head = strings.ToLower(head)
	return strings.HasPrefix(head, "data:") ||
		strings.HasPrefix(head, "http://") ||
		strings.HasPrefix(head, "https://")
}
129+
130+
func buildMultipartDataURL(contentType string, data []byte) string {
131+
mimeType := strings.TrimSpace(contentType)
132+
if mimeType == "" || strings.EqualFold(mimeType, "application/octet-stream") {
133+
mimeType = http.DetectContentType(data)
134+
}
135+
mimeType = normalizeMultipartMimeType(mimeType)
136+
if mimeType == "" {
137+
mimeType = "application/octet-stream"
138+
}
139+
encoded := base64.StdEncoding.EncodeToString(data)
140+
return fmt.Sprintf("data:%s;base64,%s", mimeType, encoded)
141+
}
142+
143+
// normalizeMultipartMimeType reduces a Content-Type value to its bare media
// type, dropping parameters such as charset.
//
// Blank input yields "". When the media type itself is well-formed the result
// is lower-cased, even if the parameters that follow are malformed. When the
// value cannot be parsed at all, the text before the first ";" (trimmed) is
// returned, or the trimmed input itself if there is no such separator.
func normalizeMultipartMimeType(contentType string) string {
	contentType = strings.TrimSpace(contentType)
	if contentType == "" {
		return ""
	}

	// ParseMediaType reports ErrInvalidMediaParameter together with a valid,
	// lower-cased media type when only the parameters are malformed; that
	// media type is still the best answer.
	mediaType, _, err := mime.ParseMediaType(contentType)
	if mediaType != "" && (err == nil || err == mime.ErrInvalidMediaParameter) {
		return strings.TrimSpace(mediaType)
	}

	// Unparseable media type: fall back to a purely textual split.
	if sep := strings.IndexByte(contentType, ';'); sep > 0 {
		return strings.TrimSpace(contentType[:sep])
	}
	return contentType
}

0 commit comments

Comments
 (0)