@@ -26,16 +26,14 @@ import (
2626func (s * Server ) runLMCacheProtocol (w http.ResponseWriter , r * http.Request , prefillPodHostPort string ) {
2727 s .logger .Info ("running LMCache protocol" )
2828
29- // Read and parse request body
3029 defer r .Body .Close () //nolint:all
3130 original , err := io .ReadAll (r .Body )
3231 if err != nil {
33- w .WriteHeader (http .StatusBadRequest ) // TODO: check FastAPI error code when failing to read body
34- w .Write ([]byte (err .Error ())) //nolint:all
32+ w .WriteHeader (http .StatusBadRequest )
33+ w .Write ([]byte (err .Error ())) //nolint:all
3534 return
3635 }
3736
38- // Parse completion request
3937 var completionRequest map [string ]any
4038 if err := json .Unmarshal (original , & completionRequest ); err != nil {
4139 if err := errorJSONInvalid (err , w ); err != nil {
@@ -44,11 +42,120 @@ func (s *Server) runLMCacheProtocol(w http.ResponseWriter, r *http.Request, pref
4442 return
4543 }
4644
47- // Create prefiller request. Set max_tokens to 1.
45+ if s .forwardDataParallel && s .dataParallelHandler (w , r ) {
46+ if err := s .prefill (w , r , prefillPodHostPort , completionRequest ); err != nil {
47+ s .logger .Error (err , "prefill failed" )
48+ }
49+ return
50+ }
51+
52+ if _ , hasCacheHitThreshold := completionRequest [requestFieldCacheHitThreshold ]; hasCacheHitThreshold {
53+ s .decodeFirst (w , r , original , completionRequest , prefillPodHostPort )
54+ } else {
55+ s .prefillThenDecode (w , r , original , completionRequest , prefillPodHostPort )
56+ }
57+ }
58+
59+ // prefillThenDecode implements the prefill-first flow: prefill then decode
60+ func (s * Server ) prefillThenDecode (w http.ResponseWriter , r * http.Request , original []byte , completionRequest map [string ]any , prefillPodHostPort string ) {
61+ s .logger .V (4 ).Info ("running prefill-then-decode flow" )
62+
63+ if err := s .prefill (w , r , prefillPodHostPort , completionRequest ); err != nil {
64+ s .logger .Error (err , "prefill failed" )
65+ return
66+ }
67+
68+ s .logger .V (4 ).Info ("forwarding to decoder after prefill" )
69+ r .Body = io .NopCloser (strings .NewReader (string (original )))
70+ s .decoderProxy .ServeHTTP (w , r )
71+ }
72+
73+ // decodeFirst implements the decode-first flow with cache threshold checking
74+ func (s * Server ) decodeFirst (w http.ResponseWriter , r * http.Request , original []byte , completionRequest map [string ]any , prefillPodHostPort string ) {
75+ s .logger .V (4 ).Info ("running decode-first flow" )
76+
77+ // Step 1: Try decode first
78+ r .Body = io .NopCloser (strings .NewReader (string (original )))
79+ needsPrefill , err := s .tryDecode (w , r )
80+ if err != nil {
81+ s .logger .Error (err , "decode attempt failed" )
82+ return
83+ }
84+
85+ // If decode succeeded (cache hit was sufficient), we're done
86+ if ! needsPrefill {
87+ s .logger .V (4 ).Info ("decode succeeded without prefill" )
88+ return
89+ }
90+
91+ // Step 2: Cache threshold not met, execute prefill
92+ s .logger .V (4 ).Info ("cache threshold not met, executing prefill" , "prefillPod" , prefillPodHostPort )
93+ if err := s .prefill (w , r , prefillPodHostPort , completionRequest ); err != nil {
94+ s .logger .Error (err , "prefill failed" )
95+ return
96+ }
4897
98+ // Step 3: Retry decode after prefill
99+ s .logger .V (4 ).Info ("retrying decode after prefill" )
100+ r .Body = io .NopCloser (strings .NewReader (string (original )))
101+ s .decoderProxy .ServeHTTP (w , r )
102+ }
103+
104+ // tryDecode attempts to decode and returns whether prefill is needed
105+ func (s * Server ) tryDecode (w http.ResponseWriter , r * http.Request ) (needsPrefill bool , err error ) {
106+ dw := & bufferedResponseWriter {}
107+ s .decoderProxy .ServeHTTP (dw , r )
108+
109+ // Check for non-success status codes
110+ if dw .statusCode < 200 || dw .statusCode >= 300 {
111+ s .logger .Error (nil , "decode request failed" , "code" , dw .statusCode )
112+ w .WriteHeader (dw .statusCode )
113+ if dw .buffer .Len () > 0 {
114+ w .Write ([]byte (dw .buffer .String ())) //nolint:all
115+ }
116+ return false , nil
117+ }
118+
119+ // Parse response to check finish_reason
120+ var response map [string ]any
121+ if err := json .Unmarshal ([]byte (dw .buffer .String ()), & response ); err != nil {
122+ s .logger .Error (err , "failed to unmarshal decoder response" )
123+ // Forward response as-is if we can't parse it
124+ w .WriteHeader (dw .statusCode )
125+ w .Write ([]byte (dw .buffer .String ())) //nolint:all
126+ return false , nil
127+ }
128+
129+ // Check for cache_threshold finish reason
130+ if choices , ok := response [responseFieldChoices ].([]any ); ok && len (choices ) > 0 {
131+ if choice , ok := choices [0 ].(map [string ]any ); ok {
132+ if finishReason , ok := choice [responseFieldFinishReason ].(string ); ok {
133+ if finishReason == finishReasonCacheThreshold {
134+ s .logger .V (4 ).Info ("decode rejected due to cache threshold" )
135+ return true , nil
136+ }
137+ }
138+ }
139+ }
140+
141+ // Decode succeeded, write response to client
142+ for k , v := range dw .headers {
143+ for _ , val := range v {
144+ w .Header ().Add (k , val )
145+ }
146+ }
147+ w .WriteHeader (dw .statusCode )
148+ w .Write ([]byte (dw .buffer .String ())) //nolint:all
149+
150+ return false , nil
151+ }
152+
153+ // prefill routes a request to a preill node
154+ func (s * Server ) prefill (w http.ResponseWriter , r * http.Request , prefillPodHostPort string , completionRequest map [string ]any ) error {
49155 ctx := r .Context ()
50156 preq := r .Clone (ctx )
51157
158+ // Prepare prefill request
52159 completionRequest [requestFieldMaxTokens ] = 1
53160 completionRequest [requestFieldMaxCompletionTokens ] = 1
54161
@@ -57,34 +164,33 @@ func (s *Server) runLMCacheProtocol(w http.ResponseWriter, r *http.Request, pref
57164 if err := errorJSONInvalid (err , w ); err != nil {
58165 s .logger .Error (err , "failed to send error response to client" )
59166 }
60- return
167+ return err
61168 }
62169 preq .Body = io .NopCloser (strings .NewReader (string (pbody )))
63170 preq .ContentLength = int64 (len (pbody ))
64171
65- // Forward request to prefiller
66-
67172 prefillHandler , err := s .prefillerProxyHandler (prefillPodHostPort )
68173 if err != nil {
69174 if err := errorBadGateway (err , w ); err != nil {
70175 s .logger .Error (err , "failed to send error response to client" )
71176 }
72- return
177+ return err
73178 }
179+
180+ // send prefill request
74181 s .logger .V (4 ).Info ("sending prefill request" , "to" , prefillPodHostPort )
75182 pw := & bufferedResponseWriter {}
76183 prefillHandler .ServeHTTP (pw , preq )
77184
78185 if pw .statusCode < 200 || pw .statusCode >= 300 {
79- s .logger .Error (err , "request failed" , "code" , pw .statusCode )
186+ s .logger .Error (nil , "prefill request failed" , "code" , pw .statusCode )
80187 w .WriteHeader (pw .statusCode )
81- return
188+ if pw .buffer .Len () > 0 {
189+ w .Write ([]byte (pw .buffer .String ())) //nolint:all
190+ }
191+ return err
82192 }
83193
84- // Forward original request to local decoder
85-
86- r .Body = io .NopCloser (strings .NewReader (string (original )))
87- if s .forwardDataParallel && ! s .dataParallelHandler (w , r ) {
88- s .decoderProxy .ServeHTTP (w , r )
89- }
194+ s .logger .V (4 ).Info ("prefill completed successfully" )
195+ return nil
90196}
0 commit comments