-
Notifications
You must be signed in to change notification settings - Fork 133
feat(lmcache): implement decode first flow on lmcache connector when cache_hit_threshold field is present #509
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
b228bb7
a436b50
a6ae771
58388eb
04b7ffd
ce74f50
1de6035
c0ac69e
7ce5e19
6430a02
4c15d95
cac084f
91c7a06
69d30b5
1ed1d89
515b385
4c8659e
878585b
722b58e
487d333
cb00b52
88739c6
9fbb2d1
7b18827
460b9c8
a722510
e2b3380
548e6c8
160fb72
9a08dfb
7508f10
bd114fa
5a6a4f6
6e6ff8f
ea60bf0
0cbc6f9
fd43c17
7030e38
f84046a
2f0e99e
ea3f1da
2e35d85
d71b10b
e34600f
efd638a
5e988e5
fbb12fc
903249f
f6024d1
5f042e7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,6 +18,7 @@ package proxy | |
|
|
||
| import ( | ||
| "encoding/json" | ||
| "fmt" | ||
| "io" | ||
| "net/http" | ||
| "strings" | ||
|
|
@@ -26,16 +27,14 @@ import ( | |
| func (s *Server) runLMCacheProtocol(w http.ResponseWriter, r *http.Request, prefillPodHostPort string) { | ||
| s.logger.Info("running LMCache protocol") | ||
|
|
||
| // Read and parse request body | ||
| defer r.Body.Close() //nolint:all | ||
| original, err := io.ReadAll(r.Body) | ||
| if err != nil { | ||
| w.WriteHeader(http.StatusBadRequest) // TODO: check FastAPI error code when failing to read body | ||
| w.Write([]byte(err.Error())) //nolint:all | ||
| w.WriteHeader(http.StatusBadRequest) | ||
kyanokashi marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| w.Write([]byte(err.Error())) //nolint:all | ||
| return | ||
| } | ||
|
|
||
| // Parse completion request | ||
kyanokashi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| var completionRequest map[string]any | ||
| if err := json.Unmarshal(original, &completionRequest); err != nil { | ||
| if err := errorJSONInvalid(err, w); err != nil { | ||
|
|
@@ -44,11 +43,117 @@ func (s *Server) runLMCacheProtocol(w http.ResponseWriter, r *http.Request, pref | |
| return | ||
| } | ||
|
|
||
| // Create prefiller request. Set max_tokens to 1. | ||
| if s.forwardDataParallel && s.dataParallelHandler(w, r) { | ||
|
||
| if err := s.prefill(w, r, prefillPodHostPort, completionRequest); err != nil { | ||
| s.logger.Error(err, "prefill failed") | ||
| } | ||
| return | ||
| } | ||
|
|
||
kyanokashi marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| if _, hasCacheHitThreshold := completionRequest[requestFieldCacheHitThreshold]; hasCacheHitThreshold { | ||
| s.decodeFirst(w, r, original, completionRequest, prefillPodHostPort) | ||
| } else { | ||
| s.prefillThenDecode(w, r, original, completionRequest, prefillPodHostPort) | ||
| } | ||
| } | ||
|
|
||
| // prefillThenDecode implements the prefill-first flow: prefill then decode | ||
| func (s *Server) prefillThenDecode(w http.ResponseWriter, r *http.Request, original []byte, completionRequest map[string]any, prefillPodHostPort string) { | ||
| s.logger.V(4).Info("running prefill-then-decode flow") | ||
|
|
||
| if err := s.prefill(w, r, prefillPodHostPort, completionRequest); err != nil { | ||
| s.logger.Error(err, "prefill failed") | ||
| return | ||
| } | ||
|
|
||
kyanokashi marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| s.logger.V(4).Info("forwarding to decoder after prefill") | ||
| r.Body = io.NopCloser(strings.NewReader(string(original))) | ||
| s.decoderProxy.ServeHTTP(w, r) | ||
| } | ||
|
|
||
| // decodeFirst implements the decode-first flow with cache threshold checking | ||
| func (s *Server) decodeFirst(w http.ResponseWriter, r *http.Request, original []byte, completionRequest map[string]any, prefillPodHostPort string) { | ||
| s.logger.V(4).Info("running decode-first flow") | ||
|
|
||
| // Step 1: Try decode first | ||
| r.Body = io.NopCloser(strings.NewReader(string(original))) | ||
| needsPrefill, err := s.tryDecode(w, r) | ||
| if err != nil { | ||
| s.logger.Error(err, "decode attempt failed") | ||
| return | ||
| } | ||
|
|
||
| // If decode succeeded (cache hit was sufficient), we're done | ||
| if !needsPrefill { | ||
| s.logger.V(4).Info("decode succeeded without prefill") | ||
| return | ||
| } | ||
|
|
||
| // Step 2: Cache threshold not met, execute prefill | ||
kyanokashi marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| s.logger.V(4).Info("cache threshold not met, executing prefill", "prefillPod", prefillPodHostPort) | ||
| if err := s.prefill(w, r, prefillPodHostPort, completionRequest); err != nil { | ||
| s.logger.Error(err, "prefill failed") | ||
| return | ||
| } | ||
|
|
||
| // Step 3: Retry decode after prefill | ||
| s.logger.V(4).Info("retrying decode after prefill") | ||
| r.Body = io.NopCloser(strings.NewReader(string(original))) | ||
| s.decoderProxy.ServeHTTP(w, r) | ||
| } | ||
|
|
||
| // tryDecode attempts to decode and returns whether prefill is needed | ||
| func (s *Server) tryDecode(w http.ResponseWriter, r *http.Request) (bool, error) { | ||
| dw := &bufferedResponseWriter{} | ||
| s.decoderProxy.ServeHTTP(dw, r) | ||
|
|
||
| // Check for non-success status codes | ||
| if dw.statusCode < 200 || dw.statusCode >= 300 { | ||
elevran marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| w.WriteHeader(dw.statusCode) | ||
kyanokashi marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| if dw.buffer.Len() > 0 { | ||
| w.Write([]byte(dw.buffer.String())) //nolint:all | ||
| } | ||
| return false, fmt.Errorf("decode request failed with status code: %d", dw.statusCode) | ||
| } | ||
|
|
||
| // Parse response to check finish_reason | ||
| var response map[string]any | ||
| if err := json.Unmarshal([]byte(dw.buffer.String()), &response); err != nil { | ||
kyanokashi marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| w.WriteHeader(dw.statusCode) | ||
| w.Write([]byte(dw.buffer.String())) //nolint:all | ||
| return false, err | ||
| } | ||
|
|
||
| // Check for cache_threshold finish reason | ||
| if choices, ok := response[responseFieldChoices].([]any); ok && len(choices) > 0 { | ||
| if choice, ok := choices[0].(map[string]any); ok { | ||
| if finishReason, ok := choice[responseFieldFinishReason].(string); ok { | ||
| if finishReason == finishReasonCacheThreshold { | ||
| s.logger.V(4).Info("decode rejected due to cache threshold") | ||
| return true, nil | ||
| } | ||
| } | ||
|
|
||
| } | ||
| } | ||
|
|
||
| // Decode succeeded, write response to client | ||
|
||
| for k, v := range dw.headers { | ||
| for _, val := range v { | ||
| w.Header().Add(k, val) | ||
| } | ||
| } | ||
kyanokashi marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| w.Write([]byte(dw.buffer.String())) //nolint:all | ||
kyanokashi marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| return false, nil | ||
| } | ||
|
|
||
| // prefill routes a request to a preill node | ||
kyanokashi marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| func (s *Server) prefill(w http.ResponseWriter, r *http.Request, prefillPodHostPort string, completionRequest map[string]any) error { | ||
| ctx := r.Context() | ||
| preq := r.Clone(ctx) | ||
|
|
||
| // Prepare prefill request | ||
| completionRequest[requestFieldMaxTokens] = 1 | ||
| completionRequest[requestFieldMaxCompletionTokens] = 1 | ||
|
|
||
kyanokashi marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
@@ -57,35 +162,33 @@ func (s *Server) runLMCacheProtocol(w http.ResponseWriter, r *http.Request, pref | |
| if err := errorJSONInvalid(err, w); err != nil { | ||
| s.logger.Error(err, "failed to send error response to client") | ||
kyanokashi marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| } | ||
| return | ||
| return err | ||
| } | ||
| preq.Body = io.NopCloser(strings.NewReader(string(pbody))) | ||
| preq.ContentLength = int64(len(pbody)) | ||
|
|
||
| // Forward request to prefiller | ||
|
|
||
| prefillHandler, err := s.prefillerProxyHandler(prefillPodHostPort) | ||
| if err != nil { | ||
| if err := errorBadGateway(err, w); err != nil { | ||
| s.logger.Error(err, "failed to send error response to client") | ||
kyanokashi marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| } | ||
| return | ||
| return err | ||
| } | ||
|
|
||
| // send prefill request | ||
| s.logger.V(4).Info("sending prefill request", "to", prefillPodHostPort) | ||
| pw := &bufferedResponseWriter{} | ||
| prefillHandler.ServeHTTP(pw, preq) | ||
|
|
||
| if pw.statusCode < 200 || pw.statusCode >= 300 { | ||
| s.logger.Error(err, "request failed", "code", pw.statusCode) | ||
| s.logger.Error(nil, "prefill request failed", "code", pw.statusCode) | ||
kyanokashi marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| w.WriteHeader(pw.statusCode) | ||
| return | ||
| if pw.buffer.Len() > 0 { | ||
| w.Write([]byte(pw.buffer.String())) //nolint:all | ||
| } | ||
| return err | ||
| } | ||
|
|
||
| // Forward original request to local decoder | ||
|
|
||
| r.Body = io.NopCloser(strings.NewReader(string(original))) | ||
| if !s.forwardDataParallel || !s.dataParallelHandler(w, r) { | ||
| s.logger.V(4).Info("sending request to decoder", "to", s.decoderURL.Host) | ||
| s.decoderProxy.ServeHTTP(w, r) | ||
| } | ||
| s.logger.V(4).Info("prefill completed successfully") | ||
| return nil | ||
| } | ||
Uh oh!
There was an error while loading. Please reload this page.