Skip to content

Commit 7c1493a

Browse files
committed
Allow sidecar server to reload TLS certificates
Enables TLS certificates to be rotated without restarting sidecar and vLLM deployments. Signed-off-by: Pierangelo Di Pilato <pierdipi@redhat.com>
1 parent 3ed3d5c commit 7c1493a

File tree

8 files changed

+56
-41
lines changed

8 files changed

+56
-41
lines changed

cmd/pd-sidecar/main.go

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ limitations under the License.
1616
package main
1717

1818
import (
19-
"crypto/tls"
2019
"flag"
2120
"net/url"
2221
"os"
@@ -111,28 +110,15 @@ func main() {
111110
return
112111
}
113112

114-
var cert *tls.Certificate
115-
if *secureProxy {
116-
var tempCert tls.Certificate
117-
if *certPath != "" {
118-
tempCert, err = tls.LoadX509KeyPair(*certPath+"/tls.crt", *certPath+"/tls.key")
119-
} else {
120-
tempCert, err = proxy.CreateSelfSignedTLSCertificate()
121-
}
122-
if err != nil {
123-
logger.Error(err, "failed to create TLS certificate")
124-
return
125-
}
126-
cert = &tempCert
127-
}
128-
129113
config := proxy.Config{
130114
Connector: *connector,
131115
PrefillerUseTLS: *prefillerUseTLS,
132116
PrefillerInsecureSkipVerify: *prefillerInsecureSkipVerify,
133117
DecoderInsecureSkipVerify: *decoderInsecureSkipVerify,
134118
DataParallelSize: *vLLMDataParallelSize,
135119
EnablePrefillerSampling: *enablePrefillerSampling,
120+
SecureServing: *secureProxy,
121+
CertPath: *certPath,
136122
}
137123

138124
// Create SSRF protection validator
@@ -144,7 +130,7 @@ func main() {
144130

145131
proxyServer := proxy.NewProxy(*port, targetURL, config)
146132

147-
if err := proxyServer.Start(ctx, cert, validator); err != nil {
133+
if err := proxyServer.Start(ctx, validator); err != nil {
148134
logger.Error(err, "failed to start proxy server")
149135
}
150136
}

pkg/sidecar/proxy/connector_nixlv2_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ var _ = Describe("NIXL Connector (v2)", func() {
4141
defer GinkgoRecover()
4242

4343
validator := &AllowlistValidator{enabled: false}
44-
err := testInfo.proxy.Start(testInfo.ctx, nil, validator)
44+
err := testInfo.proxy.Start(testInfo.ctx, validator)
4545
Expect(err).ToNot(HaveOccurred())
4646

4747
testInfo.stoppedCh <- struct{}{}

pkg/sidecar/proxy/connector_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ var _ = Describe("Common Connector tests", func() {
5959
defer GinkgoRecover()
6060

6161
validator := &AllowlistValidator{enabled: false}
62-
err := testInfo.proxy.Start(testInfo.ctx, nil, validator)
62+
err := testInfo.proxy.Start(testInfo.ctx, validator)
6363
Expect(err).ToNot(HaveOccurred())
6464

6565
testInfo.stoppedCh <- struct{}{}
@@ -121,7 +121,7 @@ var _ = Describe("Common Connector tests", func() {
121121
defer GinkgoRecover()
122122

123123
validator := &AllowlistValidator{enabled: false}
124-
err := testInfo.proxy.Start(testInfo.ctx, nil, validator)
124+
err := testInfo.proxy.Start(testInfo.ctx, validator)
125125
Expect(err).ToNot(HaveOccurred())
126126

127127
testInfo.stoppedCh <- struct{}{}

pkg/sidecar/proxy/data_parallel.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package proxy
22

33
import (
44
"context"
5-
"crypto/tls"
65
"net"
76
"net/http"
87
"net/url"
@@ -35,7 +34,7 @@ func (s *Server) dataParallelHandler(w http.ResponseWriter, r *http.Request) boo
3534
return false
3635
}
3736

38-
func (s *Server) startDataParallel(ctx context.Context, cert *tls.Certificate, grp *errgroup.Group) error {
37+
func (s *Server) startDataParallel(ctx context.Context, grp *errgroup.Group) error {
3938
podIP := os.Getenv("POD_IP")
4039
basePort, err := strconv.Atoi(s.port)
4140
if err != nil {
@@ -79,7 +78,7 @@ func (s *Server) startDataParallel(ctx context.Context, cert *tls.Certificate, g
7978
clone.handler = clone.createRoutes()
8079
clone.setConnector()
8180

82-
return clone.startHTTP(ctx, cert)
81+
return clone.startHTTP(ctx)
8382
})
8483
}
8584
return nil

pkg/sidecar/proxy/data_parallel_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ var _ = Describe("Data Parallel support", func() {
6262
theProxy.allowlistValidator, err = NewAllowlistValidator(false, DefaultPoolGroup, "", "")
6363
Expect(err).ToNot(HaveOccurred())
6464

65-
err = theProxy.startDataParallel(ctx, nil, grp)
65+
err = theProxy.startDataParallel(ctx, grp)
6666
Expect(err).ToNot(HaveOccurred())
6767

6868
Expect(theProxy.dataParallelProxies).To(HaveLen(testDataParallelSize))

pkg/sidecar/proxy/proxy.go

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,11 @@ type Config struct {
9393
// EnablePrefillerSampling configures the proxy to randomly choose from the set
9494
// of provided prefill hosts instead of always using the first one.
9595
EnablePrefillerSampling bool
96+
97+
// CertPath is the path to TLS certificates for the sidecar server.
98+
CertPath string
99+
// SecureServing enables TLS for the sidecar server.
100+
SecureServing bool
96101
}
97102

98103
type protocolRunner func(http.ResponseWriter, *http.Request, string)
@@ -143,7 +148,7 @@ func NewProxy(port string, decodeURL *url.URL, config Config) *Server {
143148
}
144149

145150
// Start the HTTP reverse proxy.
146-
func (s *Server) Start(ctx context.Context, cert *tls.Certificate, allowlistValidator *AllowlistValidator) error {
151+
func (s *Server) Start(ctx context.Context, allowlistValidator *AllowlistValidator) error {
147152
s.logger = log.FromContext(ctx).WithName("proxy server on port " + s.port)
148153

149154
s.allowlistValidator = allowlistValidator
@@ -152,12 +157,12 @@ func (s *Server) Start(ctx context.Context, cert *tls.Certificate, allowlistVali
152157
s.handler = s.createRoutes()
153158

154159
grp, ctx := errgroup.WithContext(ctx)
155-
if err := s.startDataParallel(ctx, cert, grp); err != nil {
160+
if err := s.startDataParallel(ctx, grp); err != nil {
156161
return err
157162
}
158163

159164
grp.Go(func() error {
160-
return s.startHTTP(ctx, cert)
165+
return s.startHTTP(ctx)
161166
})
162167

163168
return grp.Wait()

pkg/sidecar/proxy/proxy_helpers.go

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,19 @@ import (
44
"context"
55
"crypto/tls"
66
"errors"
7+
"fmt"
78
"net"
89
"net/http"
910
"net/http/httputil"
1011
"net/url"
1112
"syscall"
1213
"time"
14+
15+
"sigs.k8s.io/gateway-api-inference-extension/pkg/common"
1316
)
1417

1518
// startHTTP starts the HTTP reverse proxy.
16-
func (s *Server) startHTTP(ctx context.Context, cert *tls.Certificate) error {
19+
func (s *Server) startHTTP(ctx context.Context) error {
1720
// Start SSRF protection validator
1821
if err := s.allowlistValidator.Start(ctx); err != nil {
1922
s.logger.Error(err, "Failed to start allowlist validator")
@@ -35,11 +38,36 @@ func (s *Server) startHTTP(ctx context.Context, cert *tls.Certificate) error {
3538
MaxHeaderBytes: 1 << 20, // 1 MB for headers is sufficient
3639
}
3740

38-
// Create TLS certificates
41+
var cert *tls.Certificate
42+
if s.config.SecureServing {
43+
var tempCert tls.Certificate
44+
if s.config.CertPath != "" {
45+
tempCert, err = tls.LoadX509KeyPair(s.config.CertPath+"/tls.crt", s.config.CertPath+"/tls.key")
46+
} else {
47+
tempCert, err = CreateSelfSignedTLSCertificate()
48+
}
49+
if err != nil {
50+
return fmt.Errorf("failed to create TLS certificate: %w", err)
51+
}
52+
cert = &tempCert
53+
}
54+
3955
if cert != nil {
56+
getCertificate := func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
57+
return cert, nil
58+
}
59+
if s.config.CertPath != "" {
60+
reloader, err := common.NewCertReloader(ctx, "", cert)
61+
if err != nil {
62+
return fmt.Errorf("failed to start reloader: %w", err)
63+
}
64+
getCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
65+
return reloader.Get(), nil
66+
}
67+
}
68+
4069
server.TLSConfig = &tls.Config{
41-
Certificates: []tls.Certificate{*cert},
42-
MinVersion: tls.VersionTLS12,
70+
MinVersion: tls.VersionTLS12,
4371
CipherSuites: []uint16{
4472
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
4573
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
@@ -48,6 +76,7 @@ func (s *Server) startHTTP(ctx context.Context, cert *tls.Certificate) error {
4876
tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
4977
tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
5078
},
79+
GetCertificate: getCertificate,
5180
}
5281
s.logger.Info("server TLS configured")
5382
}

pkg/sidecar/proxy/proxy_test.go

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,6 @@ var _ = Describe("Reverse Proxy", func() {
5353
func(path string, secureProxy bool) {
5454

5555
ctx := newTestContext()
56-
var cert *tls.Certificate
57-
if secureProxy {
58-
tempCert, err := CreateSelfSignedTLSCertificate()
59-
Expect(err).ToNot(HaveOccurred())
60-
cert = &tempCert
61-
}
6256

6357
ackHandlerFn := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
6458
w.WriteHeader(200)
@@ -70,7 +64,9 @@ var _ = Describe("Reverse Proxy", func() {
7064
targetURL, err := url.Parse(decodeBackend.URL)
7165
Expect(err).ToNot(HaveOccurred())
7266

73-
cfg := Config{}
67+
cfg := Config{
68+
SecureServing: secureProxy,
69+
}
7470
proxy := NewProxy("0", targetURL, cfg) // port 0 to automatically choose one that's available.
7571

7672
ctx, cancelFn := context.WithCancel(ctx)
@@ -80,7 +76,7 @@ var _ = Describe("Reverse Proxy", func() {
8076
defer GinkgoRecover()
8177

8278
validator := &AllowlistValidator{enabled: false}
83-
err := proxy.Start(ctx, cert, validator)
79+
err := proxy.Start(ctx, validator)
8480
Expect(err).ToNot(HaveOccurred())
8581
stoppedCh <- struct{}{}
8682
}()
@@ -180,7 +176,7 @@ var _ = Describe("Reverse Proxy", func() {
180176
defer GinkgoRecover()
181177

182178
validator := &AllowlistValidator{enabled: false}
183-
err := proxy.Start(ctx, nil, validator)
179+
err := proxy.Start(ctx, validator)
184180
Expect(err).ToNot(HaveOccurred())
185181
stoppedCh <- struct{}{}
186182
}()
@@ -254,7 +250,7 @@ var _ = Describe("Reverse Proxy", func() {
254250
defer GinkgoRecover()
255251

256252
validator := &AllowlistValidator{enabled: false}
257-
err := proxy.Start(ctx, nil, validator)
253+
err := proxy.Start(ctx, validator)
258254
Expect(err).ToNot(HaveOccurred())
259255
stoppedCh <- struct{}{}
260256
}()

0 commit comments

Comments
 (0)