Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 108 additions & 4 deletions pkg/plugins/gateway/algorithms/prefix_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package routingalgorithms

import (
"encoding/json"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The project adopts github.com/bytedance/sonic over encoding/json to optimize performance; let's move to sonic here too, if this scenario is supported.

"errors"
"fmt"
"math"
Expand Down Expand Up @@ -455,6 +456,64 @@ func (p *prefixCacheRouter) Cleanup() error {
return nil
}

// buildTokenizeInputFromChatRequest converts a ChatCompletionRequest into a
// tokenizer.TokenizeInput, preserving multimodal content and applying
// vLLM-compatible defaults for the optional tokenization flags.
//
// The content of each message is carried through as raw JSON so that both
// plain-string and structured multimodal array content reach the tokenizer
// byte-identical to what the client sent.
//
// Returns an error when the request contains no messages, a message has no
// role, or a message cannot be round-tripped through JSON.
func buildTokenizeInputFromChatRequest(chatReq *types.ChatCompletionRequest) (*tokenizer.TokenizeInput, error) {
	if len(chatReq.Messages) == 0 {
		return nil, fmt.Errorf("no messages in chat completion request")
	}

	// Convert OpenAI messages to tokenizer messages, preserving content structure.
	messages := make([]tokenizer.ChatMessage, len(chatReq.Messages))
	for i, msg := range chatReq.Messages {
		role := msg.GetRole()
		if role == nil {
			return nil, fmt.Errorf("message at index %d has no role", i)
		}

		// Marshal the entire message to JSON, then re-parse it so the content
		// field can be extracted as a RawMessage. Going through JSON (rather
		// than reading typed fields) keeps the content's original structure
		// intact, whatever shape the OpenAI union type holds.
		msgJSON, err := json.Marshal(msg)
		if err != nil {
			return nil, fmt.Errorf("failed to marshal message at index %d: %w", i, err)
		}

		var msgMap map[string]json.RawMessage
		if err := json.Unmarshal(msgJSON, &msgMap); err != nil {
			return nil, fmt.Errorf("failed to parse message at index %d: %w", i, err)
		}

		messages[i] = tokenizer.ChatMessage{
			Role:    *role,
			Content: msgMap["content"], // Preserve as RawMessage
		}
	}

	// Defaults match vLLM behavior: special tokens are left to the chat
	// template (false), the generation prompt is appended (true), and token
	// strings are only returned when explicitly requested (false).
	return &tokenizer.TokenizeInput{
		Type:                tokenizer.ChatInput,
		Messages:            messages,
		AddSpecialTokens:    boolOrDefault(chatReq.AddSpecialTokens, false),
		AddGenerationPrompt: boolOrDefault(chatReq.AddGenerationPrompt, true),
		ReturnTokenStrings:  boolOrDefault(chatReq.ReturnTokenStrings, false),
	}, nil
}

// boolOrDefault returns *p when p is set, otherwise def.
func boolOrDefault(p *bool, def bool) bool {
	if p != nil {
		return *p
	}
	return def
}

// Route handles KV sync routing with clean implementation
func (k *kvSyncPrefixCacheRouter) Route(ctx *types.RoutingContext, readyPodList types.PodList) (string, error) {
// Start timing for latency metric if metrics are enabled
Expand Down Expand Up @@ -487,10 +546,55 @@ func (k *kvSyncPrefixCacheRouter) Route(ctx *types.RoutingContext, readyPodList
return "", fmt.Errorf("TokenizerPool not initialized for KV sync router")
}

// Tokenize the input
tokens, err := tokenizerToUse.TokenizeInputText(ctx.Message)
if err != nil {
return "", err
// Tokenize the input based on endpoint type
var tokens []byte
var err error

if ctx.ReqPath == "/v1/chat/completions" {
// For chat completions, try to use chat template tokenization
if extTokenizer, ok := tokenizerToUse.(tokenizer.ExtendedTokenizer); ok {
// Parse request body as ChatCompletionRequest
var chatReq types.ChatCompletionRequest
if parseErr := json.Unmarshal(ctx.ReqBody, &chatReq); parseErr == nil && len(chatReq.Messages) > 0 {
// Build TokenizeInput directly from request
input, buildErr := buildTokenizeInputFromChatRequest(&chatReq)
if buildErr == nil {
result, tokenizeErr := extTokenizer.TokenizeWithOptions(ctx.Context, *input)
if tokenizeErr == nil {
tokens = tokenizer.IntToByteArray(result.Tokens)
klog.V(4).InfoS("tokenized using chat template",
"request_id", ctx.RequestID,
"message_count", len(input.Messages),
"token_count", len(result.Tokens),
"add_generation_prompt", input.AddGenerationPrompt,
"add_special_tokens", input.AddSpecialTokens)
} else {
klog.V(4).InfoS("chat tokenization failed, falling back to text",
"request_id", ctx.RequestID,
"error", tokenizeErr)
}
} else {
klog.V(4).InfoS("failed to build tokenize input, falling back to text",
"request_id", ctx.RequestID,
"error", buildErr)
}
} else {
klog.V(4).InfoS("failed to parse chat request, falling back to text",
"request_id", ctx.RequestID,
"error", parseErr)
}
} else {
klog.V(4).InfoS("tokenizer does not support ExtendedTokenizer, using text tokenization",
"request_id", ctx.RequestID)
}
}

// Fallback to text tokenization if chat tokenization wasn't used or failed
if tokens == nil {
tokens, err = tokenizerToUse.TokenizeInputText(ctx.Message)
if err != nil {
return "", err
}
}

readyPods := readyPodList.All()
Expand Down
45 changes: 45 additions & 0 deletions pkg/types/request.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
Copyright 2025 The Aibrix Team.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package types

import (
"github.com/openai/openai-go"
)

// ChatCompletionRequest extends OpenAI's ChatCompletionNewParams with vLLM-specific parameters
// that the standard OpenAI schema does not carry. The embedded struct supplies the full
// OpenAI chat-completion surface; the pointer fields below distinguish "unset" (nil, use
// the vLLM default) from an explicit false/true sent by the client.
// Reference: https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/openai/chat_completion/protocol.py
type ChatCompletionRequest struct {
	openai.ChatCompletionNewParams

	// vLLM-specific parameters

	// AddGenerationPrompt controls whether to add the generation prompt to the chat template.
	// Default: true (vLLM default)
	// Reference: protocol.py line 217-224
	AddGenerationPrompt *bool `json:"add_generation_prompt,omitempty"`

	// AddSpecialTokens controls whether to add special tokens (e.g. BOS) on top of
	// what is added by the chat template.
	// Default: false (vLLM default) - chat template handles special tokens
	// Reference: protocol.py line 235-244
	AddSpecialTokens *bool `json:"add_special_tokens,omitempty"`

	// ReturnTokenStrings controls whether to return token strings in tokenization results.
	// This is used for debugging and verification purposes.
	// Default: false
	ReturnTokenStrings *bool `json:"return_token_strs,omitempty"`
}
8 changes: 4 additions & 4 deletions pkg/utils/tokenizer/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,15 @@ type Tokenizer interface {
TokenizeInputText(string) ([]byte, error)
}

// extendedTokenizer represents an extended tokenizer interface with advanced features
// ExtendedTokenizer represents an extended tokenizer interface with advanced features
// Advanced features include:
// - Context-aware tokenization with cancellation support
// - Tokenization with custom options (special tokens, generation prompts)
// - Detokenization support for converting tokens back to text
//
// TODO: Consider simplifying the interface hierarchy by removing this intermediate layer
// if only remoteTokenizer needs these advanced features
type extendedTokenizer interface {
type ExtendedTokenizer interface {
Tokenizer // Embed existing interface for backward compatibility

// TokenizeWithOptions performs tokenization with advanced options
Expand All @@ -41,9 +41,9 @@ type extendedTokenizer interface {
Detokenize(ctx context.Context, tokens []int) (string, error)
}

// remoteTokenizer interface extends extendedTokenizer with remote-specific methods
// remoteTokenizer interface extends ExtendedTokenizer with remote-specific methods
type remoteTokenizer interface {
extendedTokenizer
ExtendedTokenizer
GetEndpoint() string
IsHealthy(ctx context.Context) bool
Close() error
Expand Down
135 changes: 135 additions & 0 deletions pkg/utils/tokenizer/serialization_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
/*
Copyright 2025 The Aibrix Team.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package tokenizer

import (
"encoding/json"
"testing"

"github.com/bytedance/sonic"
)

// TestSonicSerializationWithRawMessage verifies that a ChatMessage whose
// Content is a json.RawMessage survives a sonic.Marshal / json.Unmarshal
// round trip, for both plain-string and multimodal array content.
func TestSonicSerializationWithRawMessage(t *testing.T) {
	cases := []struct {
		name    string
		message ChatMessage
	}{
		{
			name: "string content",
			message: ChatMessage{
				Role:    "user",
				Content: json.RawMessage(`"Hello world"`),
			},
		},
		{
			name: "multimodal array content",
			message: ChatMessage{
				Role: "user",
				Content: json.RawMessage(`[
				{"type":"text","text":"What's in this image?"},
				{"type":"image_url","image_url":{"url":"https://example.com/img.jpg"}}
			]`),
			},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// Encode with sonic, the project's preferred JSON library.
			encoded, err := sonic.Marshal(tc.message)
			if err != nil {
				t.Fatalf("sonic.Marshal failed: %v", err)
			}

			// Decode back with the standard library to cross-check compatibility.
			var decoded ChatMessage
			if err := json.Unmarshal(encoded, &decoded); err != nil {
				t.Fatalf("json.Unmarshal failed: %v", err)
			}

			if decoded.Role != tc.message.Role {
				t.Errorf("Role mismatch: got %v, want %v", decoded.Role, tc.message.Role)
			}

			// The raw content bytes must come back unchanged.
			if string(decoded.Content) != string(tc.message.Content) {
				t.Errorf("Content mismatch:\ngot: %s\nwant: %s", string(decoded.Content), string(tc.message.Content))
			}
		})
	}
}

// TestVLLMRequestSerialization simulates the vLLM adapter's request encoding:
// a multimodal ChatMessage plus the optional vLLM flags are serialized with
// sonic (as the HTTP client does), then decoded with encoding/json to confirm
// the structured content is preserved as a JSON array rather than flattened.
func TestVLLMRequestSerialization(t *testing.T) {
	// Simulate what happens in vLLM adapter
	messages := []ChatMessage{
		{
			Role:    "user",
			Content: json.RawMessage(`[{"type":"text","text":"test"},{"type":"image_url","image_url":{"url":"http://example.com/img.jpg"}}]`),
		},
	}

	addSpecialTokens := false
	addGenerationPrompt := true
	returnTokenStrs := false

	request := struct {
		Messages            []ChatMessage `json:"messages"`
		AddSpecialTokens    *bool         `json:"add_special_tokens,omitempty"`
		AddGenerationPrompt *bool         `json:"add_generation_prompt,omitempty"`
		ReturnTokenStrs     *bool         `json:"return_token_strs,omitempty"`
	}{
		Messages:            messages,
		AddSpecialTokens:    &addSpecialTokens,
		AddGenerationPrompt: &addGenerationPrompt,
		ReturnTokenStrs:     &returnTokenStrs,
	}

	// This is what the HTTP client does
	jsonData, err := sonic.Marshal(request)
	if err != nil {
		t.Fatalf("Failed to marshal: %v", err)
	}

	t.Logf("Serialized request:\n%s", string(jsonData))

	// Verify the output is valid JSON and content is preserved
	var result map[string]interface{}
	if err := json.Unmarshal(jsonData, &result); err != nil {
		t.Fatalf("Failed to unmarshal result: %v", err)
	}

	// Check messages array
	messagesArray, ok := result["messages"].([]interface{})
	if !ok {
		t.Fatal("messages should be an array")
	}

	if len(messagesArray) != 1 {
		t.Fatalf("Expected 1 message, got %d", len(messagesArray))
	}

	// Check first message has content as array. Use the checked assertion so
	// an unexpected shape fails the test instead of panicking the run.
	firstMsg, ok := messagesArray[0].(map[string]interface{})
	if !ok {
		t.Fatal("first message should be a JSON object")
	}
	content, ok := firstMsg["content"].([]interface{})
	if !ok {
		t.Fatal("content should be an array for multimodal message")
	}

	if len(content) != 2 {
		t.Fatalf("Expected 2 content parts, got %d", len(content))
	}
}
6 changes: 4 additions & 2 deletions pkg/utils/tokenizer/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package tokenizer

import (
"encoding/json"
"time"
)

Expand Down Expand Up @@ -49,9 +50,10 @@ type TokenizeResult struct {
}

// ChatMessage represents a single message in a chat conversation
// Content can be either a string (simple text) or a structured array (multimodal)
type ChatMessage struct {
Role string `json:"role"`
Content string `json:"content"`
Role string `json:"role"`
Content json.RawMessage `json:"content"` // Can be string or array for multimodal
}

// RemoteTokenizerConfig represents configuration for a remote tokenizer
Expand Down
Loading
Loading