diff --git a/pkg/sidecar/proxy/connector_sglang.go b/pkg/sidecar/proxy/connector_sglang.go index b02fb7231..1e8251380 100644 --- a/pkg/sidecar/proxy/connector_sglang.go +++ b/pkg/sidecar/proxy/connector_sglang.go @@ -18,6 +18,7 @@ package proxy import ( "bytes" + "context" "encoding/json" "fmt" "io" @@ -77,8 +78,10 @@ func (s *Server) runSGLangProtocol(w http.ResponseWriter, r *http.Request, prefi func (s *Server) sendSGLangConcurrentRequests(w http.ResponseWriter, r *http.Request, body []byte, prefillHost string) { // Create separate requests for prefill and decode - prefillReq := cloneWithJSONBody(r, body) - decodeReq := cloneWithJSONBody(r, body) + // Use context.WithoutCancel for prefillReq to prevent it from being aborted + // if the main HTTP handler (which serves decodeReq) finishes first. + prefillReq := cloneWithJSONBody(context.WithoutCancel(r.Context()), r, body) + decodeReq := cloneWithJSONBody(r.Context(), r, body) prefillHandler, err := s.prefillerProxyHandler(prefillHost) if err != nil { @@ -90,6 +93,11 @@ func (s *Server) sendSGLangConcurrentRequests(w http.ResponseWriter, r *http.Req // Send prefill request asynchronously go func() { + defer func() { + if rec := recover(); rec != nil && rec != http.ErrAbortHandler { + s.logger.Error(fmt.Errorf("panic: %v", rec), "panic in prefill request") + } + }() pw := &bufferedResponseWriter{} prefillHandler.ServeHTTP(pw, prefillReq) s.logger.V(5).Info("prefill request completed", "status", pw.statusCode) @@ -99,8 +107,8 @@ func (s *Server) sendSGLangConcurrentRequests(w http.ResponseWriter, r *http.Req s.decoderProxy.ServeHTTP(w, decodeReq) } -func cloneWithJSONBody(r *http.Request, body []byte) *http.Request { - req := r.Clone(r.Context()) +func cloneWithJSONBody(ctx context.Context, r *http.Request, body []byte) *http.Request { + req := r.Clone(ctx) req.Body = io.NopCloser(bytes.NewReader(body)) req.ContentLength = int64(len(body)) return req diff --git a/pkg/sidecar/proxy/connector_sglang_test.go 
b/pkg/sidecar/proxy/connector_sglang_test.go
new file mode 100644
index 000000000..9b6a38cdb
--- /dev/null
+++ b/pkg/sidecar/proxy/connector_sglang_test.go
@@ -0,0 +1,185 @@
+/*
+Copyright 2025 The llm-d Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package proxy
+
+import (
+	"io"
+	"net/http"
+	"net/http/httptest"
+	"net/url"
+	"strings"
+	"sync/atomic"
+	"time"
+
+	"github.com/llm-d/llm-d-inference-scheduler/pkg/common"
+	. "github.com/onsi/ginkgo/v2" // nolint:revive
+	. "github.com/onsi/gomega"    // nolint:revive
+)
+
+var _ = Describe("SGLang Connector", func() {
+
+	var testInfo *sidecarTestInfo
+
+	BeforeEach(func() {
+		// Mock testing setup using the SGLang connector mode
+		testInfo = sidecarConnectionTestSetup(ConnectorSGLang)
+	})
+
+	It("should successfully send concurrent requests to prefill and decode with bootstrap info", func() {
+		By("starting the proxy")
+		go func() {
+			defer GinkgoRecover()
+
+			validator := &AllowlistValidator{enabled: false}
+			err := testInfo.proxy.Start(testInfo.ctx, nil, validator)
+			Expect(err).ToNot(HaveOccurred())
+
+			testInfo.stoppedCh <- struct{}{}
+		}()
+
+		// Wait for proxy to start
+		time.Sleep(1 * time.Second)
+		Expect(testInfo.proxy.addr).ToNot(BeNil())
+		proxyBaseAddr := "http://" + testInfo.proxy.addr.String()
+
+		By("sending a /v1/chat/completions request with prefill header")
+		body := `{
+				"model": "Qwen/Qwen2-0.5B",
+				"messages": [
+				  {"role": "user", "content": "Hello"}
+				],
+				"max_tokens": 50
+			}`
+
+		req, err := http.NewRequest(http.MethodPost, proxyBaseAddr+ChatCompletionsPath, strings.NewReader(body))
+		Expect(err).ToNot(HaveOccurred())
+
+		prefillHostPort := testInfo.prefillBackend.URL[len("http://"):]
+		req.Header.Add(common.PrefillPodHeader, prefillHostPort)
+
+		rp, err := http.DefaultClient.Do(req)
+		Expect(err).ToNot(HaveOccurred())
+		defer func() { _ = rp.Body.Close() }() // close the response body when the spec returns
+
+		if rp.StatusCode != 200 {
+			bp, _ := io.ReadAll(rp.Body) //nolint:all
+			Fail(string(bp))
+		}
+
+		// Because SGLang connector sends requests concurrently (prefill in goroutine),
+		// we sleep a tiny bit to ensure the prefill handler has time to finish processing.
+		time.Sleep(100 * time.Millisecond)
+
+		// Validate prefill request
+		Expect(testInfo.prefillHandler.RequestCount.Load()).To(BeNumerically("==", 1))
+		Expect(testInfo.prefillHandler.CompletionRequests).To(HaveLen(1))
+		prq1 := testInfo.prefillHandler.CompletionRequests[0]
+
+		// Validate decode request
+		Expect(testInfo.decodeHandler.RequestCount.Load()).To(BeNumerically("==", 1))
+		Expect(testInfo.decodeHandler.CompletionRequests).To(HaveLen(1))
+		drq1 := testInfo.decodeHandler.CompletionRequests[0]
+
+		// Bootstrap validations for prefill
+		Expect(prq1).To(HaveKey(requestFieldBootstrapHost))
+		Expect(prq1).To(HaveKey(requestFieldBootstrapPort))
+		Expect(prq1).To(HaveKey(requestFieldBootstrapRoom))
+
+		expectedHost := strings.Split(prefillHostPort, ":")[0]
+		Expect(prq1[requestFieldBootstrapHost]).To(Equal(expectedHost))
+		Expect(prq1[requestFieldBootstrapPort]).To(Equal(float64(sglangBootstrapPort)))
+		Expect(prq1[requestFieldBootstrapRoom]).ToNot(BeNil())
+
+		// Bootstrap validations for decode
+		Expect(drq1).To(HaveKey(requestFieldBootstrapHost))
+		Expect(drq1).To(HaveKey(requestFieldBootstrapPort))
+		Expect(drq1).To(HaveKey(requestFieldBootstrapRoom))
+
+		Expect(drq1[requestFieldBootstrapHost]).To(Equal(expectedHost))
+		Expect(drq1[requestFieldBootstrapPort]).To(Equal(float64(sglangBootstrapPort)))
+		Expect(drq1[requestFieldBootstrapRoom]).To(Equal(prq1[requestFieldBootstrapRoom])) // Room ID must match
+
+		testInfo.cancelFn()
+		<-testInfo.stoppedCh
+	})
+
+	It("should not panic when prefill response is slower than decode response", func() {
+		// Stop previously injected servers
+		testInfo.decodeBackend.Close()
+		testInfo.prefillBackend.Close()
+
+		// Stored by the backend handler's goroutine and loaded by the spec goroutine,
+		// so it has to be atomic to stay clean under the race detector.
+		var prefillFinished atomic.Bool
+
+		slowPrefill := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			testInfo.prefillHandler.ServeHTTP(w, r)
+			time.Sleep(300 * time.Millisecond) // Simulated load delay on KV Cache
+			prefillFinished.Store(true)
+		})
+		testInfo.prefillBackend = httptest.NewServer(slowPrefill)
+
+		fastDecode := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			testInfo.decodeHandler.ServeHTTP(w, r)
+		})
+		testInfo.decodeBackend = httptest.NewServer(fastDecode)
+		decodeURL, err := url.Parse(testInfo.decodeBackend.URL)
+		Expect(err).ToNot(HaveOccurred())
+		testInfo.decodeURL = decodeURL
+
+		// Re-initialize proxy to fetch the new mock addresses
+		cfg := Config{
+			Connector: ConnectorSGLang,
+		}
+		testInfo.proxy = NewProxy("0", testInfo.decodeURL, cfg)
+
+		go func() {
+			defer GinkgoRecover()
+			validator := &AllowlistValidator{enabled: false}
+			err := testInfo.proxy.Start(testInfo.ctx, nil, validator)
+			Expect(err).ToNot(HaveOccurred())
+			testInfo.stoppedCh <- struct{}{}
+		}()
+
+		time.Sleep(1 * time.Second)
+		Expect(testInfo.proxy.addr).ToNot(BeNil())
+		proxyBaseAddr := "http://" + testInfo.proxy.addr.String()
+
+		body := `{"model": "Qwen", "messages": [{"role": "user", "content": "Hello"}], "max_tokens": 50}`
+		req, err := http.NewRequest(http.MethodPost, proxyBaseAddr+ChatCompletionsPath, strings.NewReader(body))
+		Expect(err).ToNot(HaveOccurred())
+
+		prefillHostPort := testInfo.prefillBackend.URL[len("http://"):]
+		req.Header.Add(common.PrefillPodHeader, prefillHostPort)
+
+		// Submit request. This will complete as soon as fastDecode completes.
+		rp, err := http.DefaultClient.Do(req)
+		Expect(err).ToNot(HaveOccurred())
+		defer func() { _ = rp.Body.Close() }() // close the response body when the spec returns
+		Expect(rp.StatusCode).To(Equal(200))
+
+		// The original panicking goroutine takes 300ms total. Give it time to attempt finishing up!
+		time.Sleep(500 * time.Millisecond)
+
+		Expect(prefillFinished.Load()).To(BeTrue())
+		Expect(testInfo.prefillHandler.RequestCount.Load()).To(BeNumerically("==", 1))
+		Expect(testInfo.decodeHandler.RequestCount.Load()).To(BeNumerically("==", 1))
+
+		testInfo.cancelFn()
+		<-testInfo.stoppedCh
+	})
+})