Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 3 additions & 17 deletions cmd/pd-sidecar/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ limitations under the License.
package main

import (
"crypto/tls"
"flag"
"net/url"
"os"
Expand Down Expand Up @@ -111,28 +110,15 @@ func main() {
return
}

var cert *tls.Certificate
if *secureProxy {
var tempCert tls.Certificate
if *certPath != "" {
tempCert, err = tls.LoadX509KeyPair(*certPath+"/tls.crt", *certPath+"/tls.key")
} else {
tempCert, err = proxy.CreateSelfSignedTLSCertificate()
}
if err != nil {
logger.Error(err, "failed to create TLS certificate")
return
}
cert = &tempCert
}

config := proxy.Config{
Connector: *connector,
PrefillerUseTLS: *prefillerUseTLS,
PrefillerInsecureSkipVerify: *prefillerInsecureSkipVerify,
DecoderInsecureSkipVerify: *decoderInsecureSkipVerify,
DataParallelSize: *vLLMDataParallelSize,
EnablePrefillerSampling: *enablePrefillerSampling,
SecureServing: *secureProxy,
CertPath: *certPath,
}

// Create SSRF protection validator
Expand All @@ -144,7 +130,7 @@ func main() {

proxyServer := proxy.NewProxy(*port, targetURL, config)

if err := proxyServer.Start(ctx, cert, validator); err != nil {
if err := proxyServer.Start(ctx, validator); err != nil {
logger.Error(err, "failed to start proxy server")
}
}
2 changes: 1 addition & 1 deletion pkg/sidecar/proxy/connector_nixlv2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ var _ = Describe("NIXL Connector (v2)", func() {
defer GinkgoRecover()

validator := &AllowlistValidator{enabled: false}
err := testInfo.proxy.Start(testInfo.ctx, nil, validator)
err := testInfo.proxy.Start(testInfo.ctx, validator)
Expect(err).ToNot(HaveOccurred())

testInfo.stoppedCh <- struct{}{}
Expand Down
4 changes: 2 additions & 2 deletions pkg/sidecar/proxy/connector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ var _ = Describe("Common Connector tests", func() {
defer GinkgoRecover()

validator := &AllowlistValidator{enabled: false}
err := testInfo.proxy.Start(testInfo.ctx, nil, validator)
err := testInfo.proxy.Start(testInfo.ctx, validator)
Expect(err).ToNot(HaveOccurred())

testInfo.stoppedCh <- struct{}{}
Expand Down Expand Up @@ -121,7 +121,7 @@ var _ = Describe("Common Connector tests", func() {
defer GinkgoRecover()

validator := &AllowlistValidator{enabled: false}
err := testInfo.proxy.Start(testInfo.ctx, nil, validator)
err := testInfo.proxy.Start(testInfo.ctx, validator)
Expect(err).ToNot(HaveOccurred())

testInfo.stoppedCh <- struct{}{}
Expand Down
5 changes: 2 additions & 3 deletions pkg/sidecar/proxy/data_parallel.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package proxy

import (
"context"
"crypto/tls"
"net"
"net/http"
"net/url"
Expand Down Expand Up @@ -35,7 +34,7 @@ func (s *Server) dataParallelHandler(w http.ResponseWriter, r *http.Request) boo
return false
}

func (s *Server) startDataParallel(ctx context.Context, cert *tls.Certificate, grp *errgroup.Group) error {
func (s *Server) startDataParallel(ctx context.Context, grp *errgroup.Group) error {
podIP := os.Getenv("POD_IP")
basePort, err := strconv.Atoi(s.port)
if err != nil {
Expand Down Expand Up @@ -79,7 +78,7 @@ func (s *Server) startDataParallel(ctx context.Context, cert *tls.Certificate, g
clone.handler = clone.createRoutes()
clone.setConnector()

return clone.startHTTP(ctx, cert)
return clone.startHTTP(ctx)
Copy link

Copilot AI Feb 23, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Data-parallel clone servers call startHTTP(), which now derives TLS settings (and connector selection) from s.config. However, Clone() does not copy the config field, so cloned servers will run with a zero-value config (e.g., SecureServing false and CertPath empty). As a result, TLS serving and certificate reloading (and potentially connector behavior) won’t apply to the data-parallel proxy ports. Consider copying config (and any other required fields) in Clone(), or explicitly setting clone.config = s.config before calling clone.startHTTP().

Copilot uses AI. Check for mistakes.
})
}
return nil
Expand Down
2 changes: 1 addition & 1 deletion pkg/sidecar/proxy/data_parallel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ var _ = Describe("Data Parallel support", func() {
theProxy.allowlistValidator, err = NewAllowlistValidator(false, DefaultPoolGroup, "", "")
Expect(err).ToNot(HaveOccurred())

err = theProxy.startDataParallel(ctx, nil, grp)
err = theProxy.startDataParallel(ctx, grp)
Expect(err).ToNot(HaveOccurred())

Expect(theProxy.dataParallelProxies).To(HaveLen(testDataParallelSize))
Expand Down
11 changes: 8 additions & 3 deletions pkg/sidecar/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,11 @@ type Config struct {
// EnablePrefillerSampling configures the proxy to randomly choose from the set
// of provided prefill hosts instead of always using the first one.
EnablePrefillerSampling bool

// CertPath is the path to TLS certificates for the sidecar server.
CertPath string
// SecureServing enables TLS for the sidecar server.
SecureServing bool
}

type protocolRunner func(http.ResponseWriter, *http.Request, string)
Expand Down Expand Up @@ -143,7 +148,7 @@ func NewProxy(port string, decodeURL *url.URL, config Config) *Server {
}

// Start the HTTP reverse proxy.
func (s *Server) Start(ctx context.Context, cert *tls.Certificate, allowlistValidator *AllowlistValidator) error {
func (s *Server) Start(ctx context.Context, allowlistValidator *AllowlistValidator) error {
s.logger = log.FromContext(ctx).WithName("proxy server on port " + s.port)

s.allowlistValidator = allowlistValidator
Expand All @@ -152,12 +157,12 @@ func (s *Server) Start(ctx context.Context, cert *tls.Certificate, allowlistVali
s.handler = s.createRoutes()

grp, ctx := errgroup.WithContext(ctx)
if err := s.startDataParallel(ctx, cert, grp); err != nil {
if err := s.startDataParallel(ctx, grp); err != nil {
return err
}

grp.Go(func() error {
return s.startHTTP(ctx, cert)
return s.startHTTP(ctx)
})

return grp.Wait()
Expand Down
37 changes: 33 additions & 4 deletions pkg/sidecar/proxy/proxy_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,19 @@ import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"net/http"
"net/http/httputil"
"net/url"
"syscall"
"time"

"sigs.k8s.io/gateway-api-inference-extension/pkg/common"
)

// startHTTP starts the HTTP reverse proxy.
func (s *Server) startHTTP(ctx context.Context, cert *tls.Certificate) error {
func (s *Server) startHTTP(ctx context.Context) error {
// Start SSRF protection validator
if err := s.allowlistValidator.Start(ctx); err != nil {
s.logger.Error(err, "Failed to start allowlist validator")
Expand All @@ -35,11 +38,36 @@ func (s *Server) startHTTP(ctx context.Context, cert *tls.Certificate) error {
MaxHeaderBytes: 1 << 20, // 1 MB for headers is sufficient
}

// Create TLS certificates
var cert *tls.Certificate
if s.config.SecureServing {
var tempCert tls.Certificate
if s.config.CertPath != "" {
tempCert, err = tls.LoadX509KeyPair(s.config.CertPath+"/tls.crt", s.config.CertPath+"/tls.key")
} else {
tempCert, err = CreateSelfSignedTLSCertificate()
}
if err != nil {
return fmt.Errorf("failed to create TLS certificate: %w", err)
}
Comment on lines +58 to +64
Copy link

Copilot AI Feb 23, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The error message "failed to create TLS certificate" is used for both loading a keypair from CertPath and generating a self-signed cert. It would be more actionable to distinguish these cases (e.g., include the cert/key filenames when LoadX509KeyPair fails, and a separate message for self-signed generation failure).

Suggested change
tempCert, err = tls.LoadX509KeyPair(s.config.CertPath+"/tls.crt", s.config.CertPath+"/tls.key")
} else {
tempCert, err = CreateSelfSignedTLSCertificate()
}
if err != nil {
return fmt.Errorf("failed to create TLS certificate: %w", err)
}
certFile := s.config.CertPath + "/tls.crt"
keyFile := s.config.CertPath + "/tls.key"
tempCert, err = tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return fmt.Errorf("failed to load TLS key pair from cert %q and key %q: %w", certFile, keyFile, err)
}
} else {
tempCert, err = CreateSelfSignedTLSCertificate()
if err != nil {
return fmt.Errorf("failed to generate self-signed TLS certificate: %w", err)
}
}

Copilot uses AI. Check for mistakes.
cert = &tempCert
}

if cert != nil {
getCertificate := func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
return cert, nil
}
if s.config.CertPath != "" {
reloader, err := common.NewCertReloader(ctx, "", cert)
if err != nil {
return fmt.Errorf("failed to start reloader: %w", err)
}
Comment on lines +72 to +76
Copy link

Copilot AI Feb 23, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TLS certificate reloading via common.NewCertReloader is newly introduced but doesn’t appear to be covered by tests. Adding an integration/unit test that starts the server with CertPath pointing to a temp dir, rotates tls.crt/tls.key, and verifies a new TLS handshake presents the updated cert would help prevent regressions.

Copilot uses AI. Check for mistakes.
getCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
return reloader.Get(), nil
}
}

server.TLSConfig = &tls.Config{
Certificates: []tls.Certificate{*cert},
MinVersion: tls.VersionTLS12,
MinVersion: tls.VersionTLS12,
CipherSuites: []uint16{
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
Expand All @@ -48,6 +76,7 @@ func (s *Server) startHTTP(ctx context.Context, cert *tls.Certificate) error {
tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
},
GetCertificate: getCertificate,
}
s.logger.Info("server TLS configured")
}
Expand Down
16 changes: 6 additions & 10 deletions pkg/sidecar/proxy/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,6 @@ var _ = Describe("Reverse Proxy", func() {
func(path string, secureProxy bool) {

ctx := newTestContext()
var cert *tls.Certificate
if secureProxy {
tempCert, err := CreateSelfSignedTLSCertificate()
Expect(err).ToNot(HaveOccurred())
cert = &tempCert
}

ackHandlerFn := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(200)
Expand All @@ -70,7 +64,9 @@ var _ = Describe("Reverse Proxy", func() {
targetURL, err := url.Parse(decodeBackend.URL)
Expect(err).ToNot(HaveOccurred())

cfg := Config{}
cfg := Config{
SecureServing: secureProxy,
}
proxy := NewProxy("0", targetURL, cfg) // port 0 to automatically choose one that's available.

ctx, cancelFn := context.WithCancel(ctx)
Expand All @@ -80,7 +76,7 @@ var _ = Describe("Reverse Proxy", func() {
defer GinkgoRecover()

validator := &AllowlistValidator{enabled: false}
err := proxy.Start(ctx, cert, validator)
err := proxy.Start(ctx, validator)
Expect(err).ToNot(HaveOccurred())
stoppedCh <- struct{}{}
}()
Expand Down Expand Up @@ -180,7 +176,7 @@ var _ = Describe("Reverse Proxy", func() {
defer GinkgoRecover()

validator := &AllowlistValidator{enabled: false}
err := proxy.Start(ctx, nil, validator)
err := proxy.Start(ctx, validator)
Expect(err).ToNot(HaveOccurred())
stoppedCh <- struct{}{}
}()
Expand Down Expand Up @@ -254,7 +250,7 @@ var _ = Describe("Reverse Proxy", func() {
defer GinkgoRecover()

validator := &AllowlistValidator{enabled: false}
err := proxy.Start(ctx, nil, validator)
err := proxy.Start(ctx, validator)
Expect(err).ToNot(HaveOccurred())
stoppedCh <- struct{}{}
}()
Expand Down