Skip to content

Commit 6f84d17

Browse files
authored
support multi modal inputs (#1617)
1 parent e4fe22d commit 6f84d17

File tree

3 files changed

+137
-25
lines changed

3 files changed

+137
-25
lines changed

pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -217,8 +217,8 @@ func TestPrefixPluginChatCompletions(t *testing.T) {
217217
Body: &types.LLMRequestBody{
218218
ChatCompletions: &types.ChatCompletionsRequest{
219219
Messages: []types.Message{
220-
{Role: "user", Content: "hello world"},
221-
{Role: "assistant", Content: "hi there"},
220+
{Role: "user", Content: types.Content{Raw: "hello world"}},
221+
{Role: "assistant", Content: types.Content{Raw: "hi there"}},
222222
},
223223
},
224224
},
@@ -252,8 +252,8 @@ func TestPrefixPluginChatCompletionsGrowth(t *testing.T) {
252252
Body: &types.LLMRequestBody{
253253
ChatCompletions: &types.ChatCompletionsRequest{
254254
Messages: []types.Message{
255-
{Role: "system", Content: "You are a helpful assistant"},
256-
{Role: "user", Content: "Hello, how are you?"},
255+
{Role: "system", Content: types.Content{Raw: "You are a helpful assistant"}},
256+
{Role: "user", Content: types.Content{Raw: "Hello, how are you?"}},
257257
},
258258
},
259259
},
@@ -285,10 +285,10 @@ func TestPrefixPluginChatCompletionsGrowth(t *testing.T) {
285285
Body: &types.LLMRequestBody{
286286
ChatCompletions: &types.ChatCompletionsRequest{
287287
Messages: []types.Message{
288-
{Role: "system", Content: "You are a helpful assistant"},
289-
{Role: "user", Content: "Hello, how are you?"},
290-
{Role: "assistant", Content: "I'm doing well, thank you! How can I help you today?"},
291-
{Role: "user", Content: "Can you explain how prefix caching works?"},
288+
{Role: "system", Content: types.Content{Raw: "You are a helpful assistant"}},
289+
{Role: "user", Content: types.Content{Raw: "Hello, how are you?"}},
290+
{Role: "assistant", Content: types.Content{Raw: "I'm doing well, thank you! How can I help you today?"}},
291+
{Role: "user", Content: types.Content{Raw: "Can you explain how prefix caching works?"}},
292292
},
293293
},
294294
},
@@ -318,12 +318,12 @@ func TestPrefixPluginChatCompletionsGrowth(t *testing.T) {
318318
Body: &types.LLMRequestBody{
319319
ChatCompletions: &types.ChatCompletionsRequest{
320320
Messages: []types.Message{
321-
{Role: "system", Content: "You are a helpful assistant"},
322-
{Role: "user", Content: "Hello, how are you?"},
323-
{Role: "assistant", Content: "I'm doing well, thank you! How can I help you today?"},
324-
{Role: "user", Content: "Can you explain how prefix caching works?"},
325-
{Role: "assistant", Content: "Prefix caching is a technique where..."},
326-
{Role: "user", Content: "That's very helpful, thank you!"},
321+
{Role: "system", Content: types.Content{Raw: "You are a helpful assistant"}},
322+
{Role: "user", Content: types.Content{Raw: "Hello, how are you?"}},
323+
{Role: "assistant", Content: types.Content{Raw: "I'm doing well, thank you! How can I help you today?"}},
324+
{Role: "user", Content: types.Content{Raw: "Can you explain how prefix caching works?"}},
325+
{Role: "assistant", Content: types.Content{Raw: "Prefix caching is a technique where..."}},
326+
{Role: "user", Content: types.Content{Raw: "That's very helpful, thank you!"}},
327327
},
328328
},
329329
},
@@ -437,15 +437,15 @@ func BenchmarkPrefixPluginChatCompletionsStress(b *testing.B) {
437437
b.Run(fmt.Sprintf("messages_%d_length_%d", scenario.messageCount, scenario.messageLength), func(b *testing.B) {
438438
// Generate messages for this scenario
439439
messages := make([]types.Message, scenario.messageCount)
440-
messages[0] = types.Message{Role: "system", Content: "You are a helpful assistant."}
440+
messages[0] = types.Message{Role: "system", Content: types.Content{Raw: "You are a helpful assistant."}}
441441

442442
for i := 1; i < scenario.messageCount; i++ {
443443
role := "user"
444444
if i%2 == 0 {
445445
role = "assistant"
446446
}
447447
content := randomPrompt(scenario.messageLength)
448-
messages[i] = types.Message{Role: role, Content: content}
448+
messages[i] = types.Message{Role: role, Content: types.Content{Raw: content}}
449449
}
450450

451451
pod := &types.PodMetrics{

pkg/epp/scheduling/types/types.go

Lines changed: 66 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@ limitations under the License.
1717
package types
1818

1919
import (
20+
"encoding/json"
21+
"errors"
2022
"fmt"
23+
"strings"
2124

2225
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
2326
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
@@ -113,16 +116,75 @@ func (r *ChatCompletionsRequest) String() string {
113116

114117
messagesLen := 0
115118
for _, msg := range r.Messages {
116-
messagesLen += len(msg.Content)
119+
messagesLen += len(msg.Content.PlainText())
117120
}
118-
119121
return fmt.Sprintf("{MessagesLength: %d}", messagesLen)
120122
}
121123

122124
// Message represents a single message in a chat-completions request.
123125
type Message struct {
124-
Role string
125-
Content string // TODO: support multi-modal content
126+
// Role is the message Role, optional values are 'user', 'assistant', ...
127+
Role string `json:"role,omitempty"`
128+
// Content defines text of this message
129+
Content Content `json:"content,omitempty"`
130+
}
131+
132+
type Content struct {
133+
Raw string
134+
Structured []ContentBlock
135+
}
136+
137+
type ContentBlock struct {
138+
Type string `json:"type"`
139+
Text string `json:"text,omitempty"`
140+
ImageURL ImageBlock `json:"image_url,omitempty"`
141+
}
142+
143+
type ImageBlock struct {
144+
Url string `json:"url,omitempty"`
145+
}
146+
147+
// UnmarshalJSON allow use both format
148+
func (mc *Content) UnmarshalJSON(data []byte) error {
149+
// Raw format
150+
var str string
151+
if err := json.Unmarshal(data, &str); err == nil {
152+
mc.Raw = str
153+
return nil
154+
}
155+
156+
// Block format
157+
var blocks []ContentBlock
158+
if err := json.Unmarshal(data, &blocks); err == nil {
159+
mc.Structured = blocks
160+
return nil
161+
}
162+
163+
return errors.New("content format not supported")
164+
}
165+
166+
func (mc Content) MarshalJSON() ([]byte, error) {
167+
if mc.Raw != "" {
168+
return json.Marshal(mc.Raw)
169+
}
170+
if mc.Structured != nil {
171+
return json.Marshal(mc.Structured)
172+
}
173+
return json.Marshal("")
174+
}
175+
176+
func (mc Content) PlainText() string {
177+
if mc.Raw != "" {
178+
return mc.Raw
179+
}
180+
var sb strings.Builder
181+
for _, block := range mc.Structured {
182+
if block.Type == "text" {
183+
sb.WriteString(block.Text)
184+
sb.WriteString(" ")
185+
}
186+
}
187+
return sb.String()
126188
}
127189

128190
type Pod interface {

pkg/epp/util/request/body_test.go

Lines changed: 55 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,58 @@ func TestExtractRequestData(t *testing.T) {
5858
want: &types.LLMRequestBody{
5959
ChatCompletions: &types.ChatCompletionsRequest{
6060
Messages: []types.Message{
61-
{Role: "system", Content: "this is a system message"},
62-
{Role: "user", Content: "hello"},
61+
{Role: "system", Content: types.Content{Raw: "this is a system message"}},
62+
{Role: "user", Content: types.Content{Raw: "hello"}},
63+
},
64+
},
65+
},
66+
},
67+
{
68+
name: "chat completions request body with multi-modal content",
69+
body: map[string]any{
70+
"model": "test",
71+
"messages": []any{
72+
map[string]any{
73+
"role": "system",
74+
"content": []map[string]any{
75+
{
76+
"type": "text",
77+
"text": "Describe this image in one sentence.",
78+
},
79+
},
80+
},
81+
map[string]any{
82+
"role": "user",
83+
"content": []map[string]any{
84+
{
85+
"type": "image_url",
86+
"image_url": map[string]any{
87+
"url": "https://example.com/images/dui.jpg.",
88+
},
89+
},
90+
},
91+
},
92+
},
93+
},
94+
want: &types.LLMRequestBody{
95+
ChatCompletions: &types.ChatCompletionsRequest{
96+
Messages: []types.Message{
97+
{Role: "system", Content: types.Content{
98+
Structured: []types.ContentBlock{
99+
{
100+
Text: "Describe this image in one sentence.",
101+
Type: "text",
102+
},
103+
},
104+
}},
105+
{Role: "user", Content: types.Content{
106+
Structured: []types.ContentBlock{
107+
{
108+
Type: "image_url",
109+
ImageURL: types.ImageBlock{Url: "https://example.com/images/dui.jpg."},
110+
},
111+
},
112+
}},
63113
},
64114
},
65115
},
@@ -81,7 +131,7 @@ func TestExtractRequestData(t *testing.T) {
81131
},
82132
want: &types.LLMRequestBody{
83133
ChatCompletions: &types.ChatCompletionsRequest{
84-
Messages: []types.Message{{Role: "user", Content: "hello"}},
134+
Messages: []types.Message{{Role: "user", Content: types.Content{Raw: "hello"}}},
85135
Tools: []any{map[string]any{"type": "function"}},
86136
Documents: []any{map[string]any{"content": "doc"}},
87137
ChatTemplate: "custom template",
@@ -256,8 +306,8 @@ func TestExtractRequestData(t *testing.T) {
256306
want: &types.LLMRequestBody{
257307
ChatCompletions: &types.ChatCompletionsRequest{
258308
Messages: []types.Message{
259-
{Role: "system", Content: "this is a system message"},
260-
{Role: "user", Content: "hello"},
309+
{Role: "system", Content: types.Content{Raw: "this is a system message"}},
310+
{Role: "user", Content: types.Content{Raw: "hello"}},
261311
},
262312
CacheSalt: "Z3V2bmV3aGxza3ZubGFoZ3Zud3V3ZWZ2bmd0b3V2bnZmc2xpZ3RoZ2x2aQ==",
263313
},

0 commit comments

Comments
 (0)