Skip to content

Commit 2a6e437

Browse files
committed
feat: implement decode first flow on lmcache connector
- if cache_hit_threshold field is present in completion request, then we perform a decode first flow Signed-off-by: kyano <kyanokashi2@gmail.com>
1 parent e1480d7 commit 2a6e437

File tree

2 files changed

+129
-17
lines changed

2 files changed

+129
-17
lines changed

pkg/sidecar/proxy/connector_lmcache.go

Lines changed: 123 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -26,16 +26,14 @@ import (
2626
func (s *Server) runLMCacheProtocol(w http.ResponseWriter, r *http.Request, prefillPodHostPort string) {
2727
s.logger.Info("running LMCache protocol")
2828

29-
// Read and parse request body
3029
defer r.Body.Close() //nolint:all
3130
original, err := io.ReadAll(r.Body)
3231
if err != nil {
33-
w.WriteHeader(http.StatusBadRequest) // TODO: check FastAPI error code when failing to read body
34-
w.Write([]byte(err.Error())) //nolint:all
32+
w.WriteHeader(http.StatusBadRequest)
33+
w.Write([]byte(err.Error())) //nolint:all
3534
return
3635
}
3736

38-
// Parse completion request
3937
var completionRequest map[string]any
4038
if err := json.Unmarshal(original, &completionRequest); err != nil {
4139
if err := errorJSONInvalid(err, w); err != nil {
@@ -44,11 +42,120 @@ func (s *Server) runLMCacheProtocol(w http.ResponseWriter, r *http.Request, pref
4442
return
4543
}
4644

47-
// Create prefiller request. Set max_tokens to 1.
45+
if s.forwardDataParallel && s.dataParallelHandler(w, r) {
46+
if err := s.prefill(w, r, prefillPodHostPort, completionRequest); err != nil {
47+
s.logger.Error(err, "prefill failed")
48+
}
49+
return
50+
}
51+
52+
if _, hasCacheHitThreshold := completionRequest[requestFieldCacheHitThreshold]; hasCacheHitThreshold {
53+
s.decodeFirst(w, r, original, completionRequest, prefillPodHostPort)
54+
} else {
55+
s.prefillThenDecode(w, r, original, completionRequest, prefillPodHostPort)
56+
}
57+
}
58+
59+
// prefillThenDecode implements the prefill-first flow: prefill then decode
60+
func (s *Server) prefillThenDecode(w http.ResponseWriter, r *http.Request, original []byte, completionRequest map[string]any, prefillPodHostPort string) {
61+
s.logger.V(4).Info("running prefill-then-decode flow")
62+
63+
if err := s.prefill(w, r, prefillPodHostPort, completionRequest); err != nil {
64+
s.logger.Error(err, "prefill failed")
65+
return
66+
}
67+
68+
s.logger.V(4).Info("forwarding to decoder after prefill")
69+
r.Body = io.NopCloser(strings.NewReader(string(original)))
70+
s.decoderProxy.ServeHTTP(w, r)
71+
}
72+
73+
// decodeFirst implements the decode-first flow with cache threshold checking
74+
func (s *Server) decodeFirst(w http.ResponseWriter, r *http.Request, original []byte, completionRequest map[string]any, prefillPodHostPort string) {
75+
s.logger.V(4).Info("running decode-first flow")
76+
77+
// Step 1: Try decode first
78+
r.Body = io.NopCloser(strings.NewReader(string(original)))
79+
needsPrefill, err := s.tryDecode(w, r)
80+
if err != nil {
81+
s.logger.Error(err, "decode attempt failed")
82+
return
83+
}
84+
85+
// If decode succeeded (cache hit was sufficient), we're done
86+
if !needsPrefill {
87+
s.logger.V(4).Info("decode succeeded without prefill")
88+
return
89+
}
90+
91+
// Step 2: Cache threshold not met, execute prefill
92+
s.logger.V(4).Info("cache threshold not met, executing prefill", "prefillPod", prefillPodHostPort)
93+
if err := s.prefill(w, r, prefillPodHostPort, completionRequest); err != nil {
94+
s.logger.Error(err, "prefill failed")
95+
return
96+
}
4897

98+
// Step 3: Retry decode after prefill
99+
s.logger.V(4).Info("retrying decode after prefill")
100+
r.Body = io.NopCloser(strings.NewReader(string(original)))
101+
s.decoderProxy.ServeHTTP(w, r)
102+
}
103+
104+
// tryDecode attempts to decode and returns whether prefill is needed
105+
func (s *Server) tryDecode(w http.ResponseWriter, r *http.Request) (needsPrefill bool, err error) {
106+
dw := &bufferedResponseWriter{}
107+
s.decoderProxy.ServeHTTP(dw, r)
108+
109+
// Check for non-success status codes
110+
if dw.statusCode < 200 || dw.statusCode >= 300 {
111+
s.logger.Error(nil, "decode request failed", "code", dw.statusCode)
112+
w.WriteHeader(dw.statusCode)
113+
if dw.buffer.Len() > 0 {
114+
w.Write([]byte(dw.buffer.String())) //nolint:all
115+
}
116+
return false, nil
117+
}
118+
119+
// Parse response to check finish_reason
120+
var response map[string]any
121+
if err := json.Unmarshal([]byte(dw.buffer.String()), &response); err != nil {
122+
s.logger.Error(err, "failed to unmarshal decoder response")
123+
// Forward response as-is if we can't parse it
124+
w.WriteHeader(dw.statusCode)
125+
w.Write([]byte(dw.buffer.String())) //nolint:all
126+
return false, nil
127+
}
128+
129+
// Check for cache_threshold finish reason
130+
if choices, ok := response[responseFieldChoices].([]any); ok && len(choices) > 0 {
131+
if choice, ok := choices[0].(map[string]any); ok {
132+
if finishReason, ok := choice[responseFieldFinishReason].(string); ok {
133+
if finishReason == finishReasonCacheThreshold {
134+
s.logger.V(4).Info("decode rejected due to cache threshold")
135+
return true, nil
136+
}
137+
}
138+
}
139+
}
140+
141+
// Decode succeeded, write response to client
142+
for k, v := range dw.headers {
143+
for _, val := range v {
144+
w.Header().Add(k, val)
145+
}
146+
}
147+
w.WriteHeader(dw.statusCode)
148+
w.Write([]byte(dw.buffer.String())) //nolint:all
149+
150+
return false, nil
151+
}
152+
153+
// prefill routes a request to a preill node
154+
func (s *Server) prefill(w http.ResponseWriter, r *http.Request, prefillPodHostPort string, completionRequest map[string]any) error {
49155
ctx := r.Context()
50156
preq := r.Clone(ctx)
51157

158+
// Prepare prefill request
52159
completionRequest[requestFieldMaxTokens] = 1
53160
completionRequest[requestFieldMaxCompletionTokens] = 1
54161

@@ -57,34 +164,33 @@ func (s *Server) runLMCacheProtocol(w http.ResponseWriter, r *http.Request, pref
57164
if err := errorJSONInvalid(err, w); err != nil {
58165
s.logger.Error(err, "failed to send error response to client")
59166
}
60-
return
167+
return err
61168
}
62169
preq.Body = io.NopCloser(strings.NewReader(string(pbody)))
63170
preq.ContentLength = int64(len(pbody))
64171

65-
// Forward request to prefiller
66-
67172
prefillHandler, err := s.prefillerProxyHandler(prefillPodHostPort)
68173
if err != nil {
69174
if err := errorBadGateway(err, w); err != nil {
70175
s.logger.Error(err, "failed to send error response to client")
71176
}
72-
return
177+
return err
73178
}
179+
180+
// send prefill request
74181
s.logger.V(4).Info("sending prefill request", "to", prefillPodHostPort)
75182
pw := &bufferedResponseWriter{}
76183
prefillHandler.ServeHTTP(pw, preq)
77184

78185
if pw.statusCode < 200 || pw.statusCode >= 300 {
79-
s.logger.Error(err, "request failed", "code", pw.statusCode)
186+
s.logger.Error(nil, "prefill request failed", "code", pw.statusCode)
80187
w.WriteHeader(pw.statusCode)
81-
return
188+
if pw.buffer.Len() > 0 {
189+
w.Write([]byte(pw.buffer.String())) //nolint:all
190+
}
191+
return err
82192
}
83193

84-
// Forward original request to local decoder
85-
86-
r.Body = io.NopCloser(strings.NewReader(string(original)))
87-
if s.forwardDataParallel && !s.dataParallelHandler(w, r) {
88-
s.decoderProxy.ServeHTTP(w, r)
89-
}
194+
s.logger.V(4).Info("prefill completed successfully")
195+
return nil
90196
}

pkg/sidecar/proxy/proxy.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,12 @@ const (
4545
requestFieldRemotePort = "remote_port"
4646
requestFieldStream = "stream"
4747
requestFieldStreamOptions = "stream_options"
48+
requestFieldCacheHitThreshold = "cache_hit_threshold"
49+
50+
responseFieldChoices = "choices"
51+
responseFieldFinishReason = "finish_reason"
52+
53+
finishReasonCacheThreshold = "cache_threshold"
4854

4955
// ConnectorNIXLV2 enables the P/D NIXL v2 protocol
5056
ConnectorNIXLV2 = "nixlv2"

0 commit comments

Comments
 (0)