Skip to content

Commit 9819269

Browse files
committed
enable chunked decode option in the routing proxy
Signed-off-by: andreyod <andreyo@il.ibm.com>
1 parent 3c69796 commit 9819269

File tree

5 files changed

+216
-17
lines changed

5 files changed

+216
-17
lines changed

cmd/pd-sidecar/main.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ func main() {
6060
enablePrefillerSampling := flag.Bool("enable-prefiller-sampling", func() bool { b, _ := strconv.ParseBool(os.Getenv("ENABLE_PREFILLER_SAMPLING")); return b }(), "if true, the target prefill instance will be selected randomly from among the provided prefill host values")
6161
poolGroup := flag.String("pool-group", proxy.DefaultPoolGroup, "group of the InferencePool this Endpoint Picker is associated with.")
6262

63+
enableChunkedDecode := flag.Bool("enable-chunked-decode", false, "enable chunked decode output. Defaults to false.")
64+
decodeChunkSize := flag.Int("decode-chunk-size", 512, "the decode output chunk size in token. Only when enable-chunked-decode is true.") //TODO: maybe use KV cache block-size instead of tokens
65+
6366
opts := zap.Options{}
6467
opts.BindFlags(flag.CommandLine) // optional to allow zap logging control via CLI
6568
flag.Parse()
@@ -133,6 +136,8 @@ func main() {
133136
DecoderInsecureSkipVerify: *decoderInsecureSkipVerify,
134137
DataParallelSize: *vLLMDataParallelSize,
135138
EnablePrefillerSampling: *enablePrefillerSampling,
139+
EnableChunkedDecode: *enableChunkedDecode,
140+
DecodeChunkSize: *decodeChunkSize,
136141
}
137142

138143
// Create SSRF protection validator

pkg/sidecar/proxy/chat_completions.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,10 @@ func (s *Server) chatCompletionsHandler(w http.ResponseWriter, r *http.Request)
5858
s.logger.V(4).Info("skip disaggregated prefill")
5959

6060
if !s.forwardDataParallel || !s.dataParallelHandler(w, r) {
61+
if s.config.EnableChunkedDecode {
62+
s.sendChunkedDecodeRequest(w, r, nil) // currently supported for vLLM only
63+
return
64+
}
6165
s.decoderProxy.ServeHTTP(w, r)
6266
}
6367
return

pkg/sidecar/proxy/connector_nixlv2.go

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ func (s *Server) runNIXLProtocolV2(w http.ResponseWriter, r *http.Request, prefi
3939

4040
// Parse completion request
4141
var completionRequest map[string]any
42-
if err := json.Unmarshal(original, &completionRequest); err != nil {
42+
if err := json.Unmarshal(original, &completionRequest); err != nil { // TODO: use openai-go?
4343
if err := errorJSONInvalid(err, w); err != nil {
4444
s.logger.Error(err, "failed to send error response to client")
4545
}
@@ -155,21 +155,6 @@ func (s *Server) runNIXLProtocolV2(w http.ResponseWriter, r *http.Request, prefi
155155
}
156156
completionRequest[requestFieldKVTransferParams] = pKVTransferParams
157157

158-
dbody, err := json.Marshal(completionRequest)
159-
if err != nil {
160-
if err := errorJSONInvalid(err, w); err != nil {
161-
s.logger.Error(err, "failed to send error response to client")
162-
}
163-
return
164-
}
165-
dreq.Body = io.NopCloser(strings.NewReader(string(dbody)))
166-
dreq.ContentLength = int64(len(dbody))
167-
168158
// 2. Forward to local decoder.
169-
170-
s.logger.V(5).Info("sending request to decoder", "body", string(dbody))
171-
if !s.forwardDataParallel || !s.dataParallelHandler(w, dreq) {
172-
s.logger.V(4).Info("sending request to decoder", "to", s.decoderURL.Host)
173-
s.decoderProxy.ServeHTTP(w, dreq)
174-
}
159+
s.sendDecodeRequest(w, dreq, completionRequest)
175160
}

pkg/sidecar/proxy/decode.go

Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
/*
2+
Copyright 2026 The llm-d Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package proxy
18+
19+
import (
20+
"encoding/json"
21+
"io"
22+
"net/http"
23+
"net/http/httptest"
24+
"strings"
25+
)
26+
27+
func (s *Server) sendDecodeRequest(w http.ResponseWriter, r *http.Request, completionRequest map[string]any) {
28+
29+
if s.forwardDataParallel && s.dataParallelHandler(w, r) {
30+
return
31+
}
32+
if !s.config.EnableChunkedDecode {
33+
dreq, err := setRequestBody(r, completionRequest)
34+
if err != nil {
35+
if err := errorJSONInvalid(err, w); err != nil {
36+
s.logger.Error(err, "failed to send error response to client")
37+
}
38+
return
39+
}
40+
41+
s.logger.V(4).Info("sending request to decoder", "to", s.decoderURL.Host)
42+
s.decoderProxy.ServeHTTP(w, dreq)
43+
44+
return
45+
}
46+
s.sendChunkedDecodeRequest(w, r, completionRequest)
47+
}
48+
49+
type Message struct {
50+
Role string `json:"role"`
51+
Content string `json:"content"`
52+
}
53+
54+
// Partial response struct TODO: use openai-go?
55+
type ChatCompletionResponse struct {
56+
Choices []struct {
57+
Message Message `json:"message"`
58+
FinishReason string `json:"finish_reason"`
59+
} `json:"choices"`
60+
}
61+
62+
func (s *Server) sendChunkedDecodeRequest(w http.ResponseWriter, r *http.Request, completionRequest map[string]any) {
63+
64+
if completionRequest == nil {
65+
var err error
66+
completionRequest, err = parseCompletionRequest(r)
67+
if err != nil {
68+
w.WriteHeader(http.StatusBadRequest)
69+
w.Write([]byte(err.Error())) //nolint:all
70+
return
71+
}
72+
}
73+
74+
//TODO: validate if we should run chunked decode for this request
75+
// based on the request parameters,
76+
// e.g., continue_final_message, add_generation_prompt, max tokens, etc.
77+
78+
completionRequest[requestFieldMaxCompletionTokens] = s.config.DecodeChunkSize
79+
80+
s.logger.V(4).Info("sending chunked decode request", "chunk size", s.config.DecodeChunkSize)
81+
82+
messagesAny, ok := completionRequest["messages"]
83+
if !ok {
84+
s.logger.Error(nil, "chunked decode: missing 'messages' field in decode request")
85+
return
86+
}
87+
messages, ok := messagesAny.([]any)
88+
if !ok {
89+
s.logger.Error(nil, "chunked decode: invalid 'messages' field in decode request")
90+
return
91+
}
92+
93+
respBody := []byte{}
94+
respStausCode := 0
95+
var responseMessageContent string
96+
97+
for {
98+
dreq, err := setRequestBody(r, completionRequest)
99+
if err != nil {
100+
if err := errorJSONInvalid(err, w); err != nil {
101+
s.logger.Error(err, "failed to send error response to client")
102+
}
103+
return
104+
}
105+
106+
rec := httptest.NewRecorder()
107+
s.decoderProxy.ServeHTTP(rec, dreq)
108+
resp := rec.Result()
109+
defer resp.Body.Close()
110+
111+
respBody, _ = io.ReadAll(resp.Body) // TODO: handle error
112+
var parsed ChatCompletionResponse
113+
if err := json.Unmarshal(respBody, &parsed); err != nil {
114+
s.logger.Error(err, "failed to decode response")
115+
return
116+
}
117+
respStausCode = resp.StatusCode
118+
119+
if len(parsed.Choices) == 0 {
120+
s.logger.Error(nil, "no choices in decoder response")
121+
return
122+
}
123+
124+
choice := parsed.Choices[0]
125+
chunk := choice.Message.Content
126+
finishReason := choice.FinishReason
127+
128+
s.logger.V(4).Info("decoder response chunk", "chunk", chunk, "finish_reason", finishReason)
129+
// Append chunk to build full response content
130+
responseMessageContent += chunk
131+
132+
// Prepare for next iteration
133+
134+
// Append assistant message to continue generation
135+
messages = append(messages, Message{
136+
Role: "assistant",
137+
Content: chunk,
138+
})
139+
completionRequest["messages"] = messages
140+
141+
// Do not pull KV cache next time
142+
delete(completionRequest, requestFieldKVTransferParams)
143+
144+
completionRequest["continue_final_message"] = true
145+
completionRequest["add_generation_prompt"] = false
146+
147+
s.logger.V(4).Info("chunked decode combined output", "output", responseMessageContent)
148+
149+
// Stop unless the model was cut off due to token limit
150+
if finishReason != "length" {
151+
break
152+
}
153+
}
154+
155+
// add the combined message to the final response
156+
var finalResponse ChatCompletionResponse
157+
if err := json.Unmarshal(respBody, &finalResponse); err != nil {
158+
s.logger.Error(err, "failed to decode final decoder response")
159+
return
160+
}
161+
if len(finalResponse.Choices) == 0 {
162+
s.logger.Error(nil, "no choices in final decoder response")
163+
return
164+
}
165+
finalResponse.Choices[0].Message.Content = responseMessageContent
166+
167+
respBody, err := json.Marshal(finalResponse)
168+
if err != nil {
169+
s.logger.Error(err, "failed to marshal final decoder response")
170+
return
171+
}
172+
// Write response back to original writer
173+
w.Header().Set("Content-Type", "application/json")
174+
w.WriteHeader(respStausCode)
175+
w.Write(respBody) //nolint:errcheck
176+
}
177+
178+
func setRequestBody(r *http.Request, completionRequest map[string]any) (*http.Request, error) {
179+
dbody, err := json.Marshal(completionRequest)
180+
if err != nil {
181+
return nil, err
182+
}
183+
r.Body = io.NopCloser(strings.NewReader(string(dbody)))
184+
r.ContentLength = int64(len(dbody))
185+
return r, nil
186+
}
187+
188+
func parseCompletionRequest(r *http.Request) (map[string]any, error) {
189+
defer r.Body.Close() //nolint:all
190+
original, err := io.ReadAll(r.Body)
191+
if err != nil {
192+
return nil, err
193+
}
194+
var completionRequest map[string]any
195+
if err := json.Unmarshal(original, &completionRequest); err != nil { // TODO: use openai-go?
196+
return nil, err
197+
}
198+
return completionRequest, nil
199+
}

pkg/sidecar/proxy/proxy.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,12 @@ type Config struct {
8787
// EnablePrefillerSampling configures the proxy to randomly choose from the set
8888
// of provided prefill hosts instead of always using the first one.
8989
EnablePrefillerSampling bool
90+
91+
// EnableChunkedDecode if set to true, will enable chunked decode requests.
92+
EnableChunkedDecode bool
93+
94+
// DecodeChunkSize is the size of the decode chunk when chunked decode is enabled.
95+
DecodeChunkSize int
9096
}
9197

9298
type protocolRunner func(http.ResponseWriter, *http.Request, string)

0 commit comments

Comments
 (0)