
Commit cf682b5

Fix /chat/completion response in echo mode (#362)

* Fix /chat/completion response in echo mode + update tests accordingly
* Fix echo mode for grpc + lint
* Fix echo in grpc

Signed-off-by: Maya Barnea <mayab@il.ibm.com>
1 parent d890326 commit cf682b5
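In echo mode the simulator returns the request's own text as the completion instead of generating tokens. Before this commit, /chat/completions echoed the whole tokenized prompt, i.e. the full serialized conversation; after it, only the content of the last message is echoed. A minimal sketch of the intended behavior, built from the types used in the new test below (schematic, not a verbatim excerpt from the repository):

```go
// Schematic, based on the new test in pkg/dataset/dataset_test.go.
req := &openaiserverapi.ChatCompletionRequest{
	Messages: []openaiserverapi.Message{
		{Role: openaiserverapi.RoleUser, Content: openaiserverapi.Content{Raw: "user message1"}},
		{Role: openaiserverapi.RoleAssistant, Content: openaiserverapi.Content{Raw: "assistant message1"}},
		{Role: openaiserverapi.RoleUser, Content: openaiserverapi.Content{Raw: "echo me"}},
	},
}
// Before this commit: echo mode replied with the full serialized conversation.
// After this commit:  echo mode replies with "echo me", the last message only.
```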

File tree: 8 files changed (+80, -2 lines)

pkg/dataset/dataset.go (2 additions, 2 deletions)

```diff
@@ -57,9 +57,9 @@ type EchoDataset struct{}
 // if max-tokens is defined in the request and response's length is >= it value, finish reason is set to LENGTH,
 // otherwise finish reason is STOP
 func (ed *EchoDataset) GetResponseTokens(req openaiserverapi.Request) (*openaiserverapi.Tokenized, string, error) {
-	tokens := req.TokenizedPrompt()
+	tokens := req.TokenizedEchoResponse()
 	maxTokens := req.GetMaxCompletionTokens()
-	return tokens, common.FinishReason(maxTokens, len(tokens.Tokens)), nil
+	return tokens, common.FinishReason(maxTokens, tokens.Length()), nil
 }
 
 func (ed *EchoDataset) Close() error {
```
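Two things change in EchoDataset.GetResponseTokens: the tokens now come from the request's dedicated echo-response tokenization rather than from the tokenized prompt, and the finish-reason length check goes through Tokenized.Length() instead of reading the Tokens slice directly. Length() is not shown in this commit; a plausible, assumed shape is a nil-safe token count:

```go
// Assumed sketch of Tokenized.Length(); the real implementation is not part
// of this diff, but the call site above implies something equivalent.
func (t *Tokenized) Length() int {
	if t == nil {
		return 0 // safe even when no echo response was set
	}
	return len(t.Tokens)
}
```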

pkg/dataset/dataset_test.go (25 additions, 0 deletions)

```diff
@@ -187,6 +187,7 @@ var _ = Describe("Echo Dataset", Ordered, func() {
 				Prompt: theText,
 			}
 			req.SetTokenizedPrompt(&openaiserverapi.Tokenized{Tokens: tokens, Strings: strTokens})
+			req.SetTokenizedEchoResponse(&openaiserverapi.Tokenized{Tokens: tokens, Strings: strTokens})
 			tokens, finishReason, err := dataset.GetResponseTokens(req)
 			Expect(err).ShouldNot(HaveOccurred())
 			Expect(tokens.Strings).Should(Equal(strTokens))
@@ -199,6 +200,7 @@ var _ = Describe("Echo Dataset", Ordered, func() {
 				MaxTokens: &maxTokens,
 			}
 			req.SetTokenizedPrompt(&openaiserverapi.Tokenized{Tokens: tokens, Strings: strTokens})
+			req.SetTokenizedEchoResponse(&openaiserverapi.Tokenized{Tokens: tokens, Strings: strTokens})
 
 			tokens, finishReason, err := dataset.GetResponseTokens(req)
 			Expect(err).ShouldNot(HaveOccurred())
@@ -212,12 +214,34 @@ var _ = Describe("Echo Dataset", Ordered, func() {
 				MaxTokens: &maxTokens,
 			}
 			req.SetTokenizedPrompt(&openaiserverapi.Tokenized{Tokens: tokens, Strings: strTokens})
+			req.SetTokenizedEchoResponse(&openaiserverapi.Tokenized{Tokens: tokens, Strings: strTokens})
 
 			tokens, finishReason, err := dataset.GetResponseTokens(req)
 			Expect(err).ShouldNot(HaveOccurred())
 			Expect(tokens.Strings).Should(Equal(strTokens))
 			Expect(finishReason).Should(Equal(common.LengthFinishReason))
 		})
+		It("should return the last message in chat completion", func() {
+			req := &openaiserverapi.ChatCompletionRequest{
+				Messages: []openaiserverapi.Message{
+					{Role: openaiserverapi.RoleUser, Content: openaiserverapi.Content{Raw: "user message1"}},
+					{Role: openaiserverapi.RoleAssistant, Content: openaiserverapi.Content{Raw: "assistant message1"}},
+					{Role: openaiserverapi.RoleUser, Content: openaiserverapi.Content{Raw: testPrompt}},
+				},
+			}
+			promptTokens, promptStrTokens, err := tokenizer.Encode(req.GetFullPrompt(), "")
+			Expect(err).ShouldNot(HaveOccurred())
+			respTokens, resptStrTokens, err := tokenizer.Encode(testPrompt, "")
+			Expect(err).ShouldNot(HaveOccurred())
+
+			req.SetTokenizedPrompt(&openaiserverapi.Tokenized{Tokens: promptTokens, Strings: promptStrTokens})
+			req.SetTokenizedEchoResponse(&openaiserverapi.Tokenized{Tokens: respTokens, Strings: resptStrTokens})
+
+			tokens, _, err := dataset.GetResponseTokens(req)
+			Expect(err).ShouldNot(HaveOccurred())
+
+			Expect(tokens.Strings).Should(Equal(resptStrTokens))
+		})
 	})
 
 	DescribeTable("should work correctly in echo mode",
@@ -235,6 +259,7 @@ var _ = Describe("Echo Dataset", Ordered, func() {
 				req = &textReq
 			}
 			req.SetTokenizedPrompt(&openaiserverapi.Tokenized{Tokens: promptTokens, Strings: promptStrTokens})
+			req.SetTokenizedEchoResponse(&openaiserverapi.Tokenized{Tokens: promptTokens, Strings: promptStrTokens})
 
 			tokens, finishReason, err := dataset.GetResponseTokens(req)
 			Expect(err).NotTo(HaveOccurred())
```
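The new It block is the behavioral core of the fix: the tokenized prompt covers the full conversation (req.GetFullPrompt()), yet GetResponseTokens is expected to return only the tokens of the last user message (testPrompt). The pre-existing tests gain SetTokenizedEchoResponse calls because EchoDataset no longer reads the tokenized prompt at all.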

pkg/llm-d-inference-sim/chat_completion.go (8 additions, 0 deletions)

```diff
@@ -77,6 +77,14 @@ func (c *chatCompletionRequest) createResponseContext(reqCtx requestContext, dis
 	}
 }
 
+func (c *chatCompletionReqCtx) getEchoTokens() ([]uint32, []string, error) {
+	lastMsg := ""
+	if len(c.req.Messages) > 0 {
+		lastMsg = c.req.Messages[len(c.req.Messages)-1].Content.Raw
+	}
+	return c.sim.tokenizer.Encode(lastMsg, "")
+}
+
 var _ request = (*chatCompletionRequest)(nil)
 
 // Implementation of requestContext for /chat/completions requests
```
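For chat requests the echo text is the raw content of the last message, or the empty string when the message list is empty. The selection logic in isolation, as a self-contained sketch with a stand-in message type (names here are illustrative):

```go
package main

import "fmt"

// msg stands in for openaiserverapi.Message; only the raw content matters here.
type msg struct{ role, raw string }

// lastMessageContent mirrors the selection in getEchoTokens above: the raw
// content of the final message, or "" when the conversation is empty.
func lastMessageContent(msgs []msg) string {
	if len(msgs) == 0 {
		return ""
	}
	return msgs[len(msgs)-1].raw
}

func main() {
	conv := []msg{
		{"user", "user message1"},
		{"assistant", "assistant message1"},
		{"user", "echo me"},
	}
	fmt.Println(lastMessageContent(conv)) // echo me
	fmt.Println(lastMessageContent(nil))  // (empty string)
}
```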

pkg/llm-d-inference-sim/generation.go (8 additions, 0 deletions)

```diff
@@ -72,6 +72,14 @@ func (g *generationReqCtx) request() request {
 	return g.req
 }
 
+func (g *generationReqCtx) getEchoTokens() ([]uint32, []string, error) {
+	tokenisedResponse := g.req.TokenizedEchoResponse()
+	if tokenisedResponse != nil {
+		return tokenisedResponse.Tokens, tokenisedResponse.Strings, nil
+	}
+	return g.sim.tokenizer.Encode(g.req.Prompt, "")
+}
+
 func (g *generationReqCtx) kvCacheOnRequestStart() (hitRate float64, oaiServerError *openaiserverapi.Error) {
 	if g.sim.config.EnableKVCache {
 		var err error
```
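The generation context prefers an already-populated echo response and only falls back to encoding the raw prompt. The pre-populated path matters for gRPC generate requests (see pkg/llm-d-inference-sim/grpc.go below), which can arrive as token IDs with no text available to re-encode.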

pkg/llm-d-inference-sim/grpc.go (1 addition, 0 deletions)

```diff
@@ -144,6 +144,7 @@ func (s *VllmSimulator) pbRequestToRequest(in *pb.GenerateRequest) *generationRe
 		prompt := &openaiserverapi.Tokenized{}
 		prompt.Tokens = in.GetTokenized().InputIds
 		req.SetTokenizedPrompt(prompt)
+		req.SetTokenizedEchoResponse(prompt)
 	} else {
 		req.Prompt = in.GetText()
 	}
```
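When a gRPC request carries token IDs rather than text, the same Tokenized value now doubles as the echo response: there is no raw string to re-tokenize, so echoing the input tokens verbatim is the only faithful option. Text input takes the else branch, and generationReqCtx.getEchoTokens (above) encodes req.Prompt instead.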

pkg/llm-d-inference-sim/request.go (16 additions, 0 deletions)

```diff
@@ -48,6 +48,7 @@ type requestContext interface {
 	createToolCalls() ([]openaiserverapi.ToolCall, int, string, error)
 	handleRequest() (responseContext, *openaiserverapi.Error)
 	responseChannel() chan *responseInfo
+	getEchoTokens() ([]uint32, []string, error)
 }
 
 type baseRequestContext struct {
@@ -92,6 +93,21 @@ func (b *baseRequestContext) tokenize() *openaiserverapi.Error {
 		Tokens:  tokens,
 		Strings: textTokens,
 	})
+
+	if b.sim.config.Mode == common.ModeEcho {
+		tokens, textTokens, err = b.getEchoTokens()
+		if err != nil {
+			b.sim.logger.Error(err, "failed to tokenize echo mode response")
+			serverErr := openaiserverapi.NewError("Failed to tokenize echo mode response, "+err.Error(), fasthttp.StatusInternalServerError, nil)
+			return &serverErr
+		}
+
+		req.SetTokenizedEchoResponse(&openaiserverapi.Tokenized{
+			Tokens:  tokens,
+			Strings: textTokens,
+		})
+	}
+
 	return nil
 }
```
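getEchoTokens joins the requestContext interface so that the shared tokenize() step can populate the echo response polymorphically: chat contexts supply the last message, text-completion and generation contexts supply the prompt. The extra tokenizer call is gated on common.ModeEcho, so non-echo requests are unaffected. A self-contained sketch of the dispatch pattern (names and the whitespace "tokenizer" are illustrative only):

```go
package main

import (
	"fmt"
	"strings"
)

// echoSource plays the role of requestContext.getEchoTokens: each request
// kind decides what text an echo response should contain.
type echoSource interface {
	echoText() string
}

// chatCtx echoes the last message of a conversation.
type chatCtx struct{ messages []string }

func (c chatCtx) echoText() string {
	if len(c.messages) == 0 {
		return ""
	}
	return c.messages[len(c.messages)-1]
}

// textCtx echoes the prompt itself.
type textCtx struct{ prompt string }

func (t textCtx) echoText() string { return t.prompt }

// tokenizeEcho stands in for the echo branch of tokenize(); whitespace
// splitting substitutes for the real tokenizer.Encode call.
func tokenizeEcho(src echoSource) []string {
	return strings.Fields(src.echoText())
}

func main() {
	fmt.Println(tokenizeEcho(chatCtx{[]string{"hi there", "echo me"}})) // [echo me]
	fmt.Println(tokenizeEcho(textCtx{"echo the prompt"}))               // [echo the prompt]
}
```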

pkg/llm-d-inference-sim/text_completion.go (4 additions, 0 deletions)

```diff
@@ -101,6 +101,10 @@ func (t *textCompletionReqCtx) createToolCalls() ([]openaiserverapi.ToolCall, in
 	return nil, 0, "", nil
 }
 
+func (t *textCompletionReqCtx) getEchoTokens() ([]uint32, []string, error) {
+	return t.sim.tokenizer.Encode(t.req.Prompt, "")
+}
+
 var _ requestContext = (*textCompletionReqCtx)(nil)
 
 // Implementation of responseContext for /completions requests
```
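For plain /completions requests echo mode behaves as before: the echoed text is the prompt itself, so this context simply re-encodes req.Prompt.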

pkg/openai-server-api/request.go (16 additions, 0 deletions)

```diff
@@ -80,6 +80,10 @@ type Request interface {
 	TokenizedPrompt() *Tokenized
 	// SetTokenizedPrompt sets the tokenized prompt
 	SetTokenizedPrompt(tokenized *Tokenized)
+	// TokenizedEchoResponse returns the tokenized response in echo mode
+	TokenizedEchoResponse() *Tokenized
+	// SetTokenizedEchoResponse sets the tokenized response in echo mode
+	SetTokenizedEchoResponse(tokenized *Tokenized)
 	// CacheThresholdFinishReason returns cacheThresholdFinishReason, when true,
 	// forces a cache_threshold finish reason
 	CacheThresholdFinishReason() bool
@@ -112,6 +116,8 @@ type baseCompletionRequest struct {
 	cacheThresholdFinishReason bool
 	// tokenizedPrompt is the tokenized prompt
 	tokenizedPrompt *Tokenized
+	// tokenizedEchoResponse is the tokenized response in echo mode, exists only in echo mode
+	tokenizedEchoResponse *Tokenized
 }
 
 type KVTransferParams struct {
@@ -245,6 +251,16 @@ func (b *baseCompletionRequest) SetTokenizedPrompt(tokenized *Tokenized) {
 	b.tokenizedPrompt = tokenized
 }
 
+// TokenizedEchoResponse returns the tokenized response in echo mode
+func (b *baseCompletionRequest) TokenizedEchoResponse() *Tokenized {
+	return b.tokenizedEchoResponse
+}
+
+// SetTokenizedEchoResponse sets the tokenized response in echo mode
+func (b *baseCompletionRequest) SetTokenizedEchoResponse(tokenized *Tokenized) {
+	b.tokenizedEchoResponse = tokenized
+}
+
 // ChatCompletionRequest defines structure of /chat/completion request
 type ChatCompletionRequest struct {
 	baseCompletionRequest
```
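Keeping tokenizedEchoResponse as a field separate from tokenizedPrompt is what allows a chat completion to echo something other than its full prompt: for /chat/completions the prompt tokenization spans the whole conversation while the echo response spans only the last message, whereas for text completions and pre-tokenized gRPC requests the two coincide.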
