Skip to content

Commit b806520

Browse files
committed
parallelize sending first token and decode request prep
Signed-off-by: RishabhSaini <rishabhsaini01@gmail.com>
1 parent 2b1a630 commit b806520

File tree

1 file changed

+79
-36
lines changed

1 file changed

+79
-36
lines changed

pkg/sidecar/proxy/connector_nixlv2.go

Lines changed: 79 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -177,51 +177,63 @@ func (s *Server) runNIXLProtocolV2(w http.ResponseWriter, r *http.Request, prefi
177177
s.logger.V(5).Info("received prefiller response", requestFieldKVTransferParams, pKVTransferParams)
178178

179179
// 4. Send first token to client if streaming
180-
var firstTokenSentTime time.Time
181-
decodeResponseWriter := w
180+
// Channel to signal when first token has been sent and receive the timestamp
181+
// Also carries any error that occurred and whether headers were sent
182+
type firstTokenResult struct {
183+
sentTime time.Time
184+
err error
185+
headersSent bool
186+
}
187+
firstTokenSent := make(chan firstTokenResult, 1)
188+
182189
if clientWantsStreaming {
183-
// Convert non-streaming response to SSE format and send first chunk
184-
// Remove kv_transfer_params before sending to user
190+
// Remove kv_transfer_params before sending to user (do this before goroutine to avoid race)
185191
delete(prefillerResponse, requestFieldKVTransferParams)
186192

187-
streamChunk, err := json.Marshal(prefillerResponse)
188-
if err != nil {
189-
if err := errorJSONInvalid(err, w); err != nil {
193+
// Send first token in goroutine to allow parallel execution with decode prep
194+
go func() {
195+
streamChunk, err := json.Marshal(prefillerResponse)
196+
if err != nil {
190197
s.logger.Error(err, "failed to marshal streaming chunk")
198+
firstTokenSent <- firstTokenResult{err: err, headersSent: false}
199+
return
191200
}
192-
return
193-
}
194201

195-
w.Header().Set("Content-Type", "text/event-stream")
196-
w.Header().Set("Cache-Control", "no-cache")
197-
w.Header().Set("Connection", "keep-alive")
198-
w.WriteHeader(http.StatusOK)
202+
w.Header().Set("Content-Type", "text/event-stream")
203+
w.Header().Set("Cache-Control", "no-cache")
204+
w.Header().Set("Connection", "keep-alive")
205+
w.WriteHeader(http.StatusOK)
199206

200-
_, err = w.Write([]byte("data: "))
201-
if err != nil {
202-
s.logger.Error(err, "failed to write SSE prefix")
203-
return
204-
}
205-
_, err = w.Write(streamChunk)
206-
if err != nil {
207-
s.logger.Error(err, "failed to write prefill chunk to client")
208-
return
209-
}
210-
_, err = w.Write([]byte("\n\n"))
211-
if err != nil {
212-
s.logger.Error(err, "failed to write SSE suffix")
213-
return
214-
}
215-
216-
if flusher, ok := w.(http.Flusher); ok {
217-
flusher.Flush()
218-
}
207+
_, err = w.Write([]byte("data: "))
208+
if err != nil {
209+
s.logger.Error(err, "failed to write SSE prefix")
210+
firstTokenSent <- firstTokenResult{err: err, headersSent: true}
211+
return
212+
}
213+
_, err = w.Write(streamChunk)
214+
if err != nil {
215+
s.logger.Error(err, "failed to write prefill chunk to client")
216+
firstTokenSent <- firstTokenResult{err: err, headersSent: true}
217+
return
218+
}
219+
_, err = w.Write([]byte("\n\n"))
220+
if err != nil {
221+
s.logger.Error(err, "failed to write SSE suffix")
222+
firstTokenSent <- firstTokenResult{err: err, headersSent: true}
223+
return
224+
}
219225

220-
firstTokenSentTime = time.Now()
221-
s.logger.V(4).Info("sent first token to client immediately after prefill")
226+
if flusher, ok := w.(http.Flusher); ok {
227+
flusher.Flush()
228+
}
222229

223-
// Wrap writer to prevent decode from re-writing headers
224-
decodeResponseWriter = &headersSentWriter{ResponseWriter: w}
230+
sentTime := time.Now()
231+
s.logger.V(4).Info("sent first token to client immediately after prefill")
232+
firstTokenSent <- firstTokenResult{sentTime: sentTime, headersSent: true}
233+
}()
234+
} else {
235+
// For non-streaming, signal immediately that we don't need to wait
236+
firstTokenSent <- firstTokenResult{}
225237
}
226238

227239
// Decode Stage
@@ -284,6 +296,14 @@ func (s *Server) runNIXLProtocolV2(w http.ResponseWriter, r *http.Request, prefi
284296

285297
dbody, err := json.Marshal(completionRequest)
286298
if err != nil {
299+
// Wait for goroutine to complete before returning
300+
result := <-firstTokenSent
301+
if result.headersSent {
302+
// Headers already sent, can't send error response
303+
s.logger.Error(err, "failed to marshal decode request, but response already started")
304+
return
305+
}
306+
// Safe to send error response - headers not sent yet
287307
if err := errorJSONInvalid(err, w); err != nil {
288308
s.logger.Error(err, "failed to send error response to client")
289309
}
@@ -293,6 +313,29 @@ func (s *Server) runNIXLProtocolV2(w http.ResponseWriter, r *http.Request, prefi
293313
dreq.ContentLength = int64(len(dbody))
294314

295315
// 6. Forward to local decoder.
316+
// Wait for first token to be sent before forwarding to decoder
317+
result := <-firstTokenSent
318+
if result.err != nil {
319+
// First token send failed
320+
if result.headersSent {
321+
// Partial response already sent to client, can't recover
322+
s.logger.Error(result.err, "failed to send first token after headers sent, aborting request")
323+
} else {
324+
// Headers not sent yet, try to send error response
325+
s.logger.Error(result.err, "failed to send first token before headers sent")
326+
if err := errorJSONInvalid(result.err, w); err != nil {
327+
s.logger.Error(err, "failed to send error response to client")
328+
}
329+
}
330+
return
331+
}
332+
firstTokenSentTime := result.sentTime
333+
334+
// Use wrapper to prevent duplicate WriteHeader for streaming
335+
var decodeResponseWriter http.ResponseWriter = w
336+
if clientWantsStreaming {
337+
decodeResponseWriter = &headersSentWriter{ResponseWriter: w}
338+
}
296339

297340
s.logger.V(5).Info("sending request to decoder", "body", string(dbody))
298341
dataParallelUsed := s.forwardDataParallel && s.dataParallelHandler(decodeResponseWriter, dreq)

0 commit comments

Comments
 (0)