Skip to content

Commit e793964

Browse files
committed
initial E/PD extension of the sidecar
Signed-off-by: roytman <roytman@il.ibm.com>
1 parent a0c8d17 commit e793964

File tree

11 files changed

+646
-80
lines changed

11 files changed

+646
-80
lines changed

cmd/pd-sidecar/main.go

Lines changed: 41 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -32,22 +32,30 @@ import (
3232
)
3333

3434
var (
35-
// supportedConnectors defines all valid P/D connector types
36-
supportedConnectors = []string{
37-
proxy.ConnectorNIXLV2,
38-
proxy.ConnectorSharedStorage,
39-
proxy.ConnectorSGLang,
35+
// supportedKVConnectors defines all valid KV (Prefiller-Decoder) connector types
36+
supportedKVConnectors = []string{
37+
proxy.KVConnectorNIXLV2,
38+
proxy.KVConnectorSharedStorage,
39+
proxy.KVConnectorSGLang,
40+
}
41+
42+
// supportedECConnectors defines all valid EC (Encoder-Prefiller) connector types
43+
supportedECConnectors = []string{
44+
proxy.ECExampleConnector,
4045
}
4146
)
4247

4348
func main() {
4449
port := flag.String("port", "8000", "the port the sidecar is listening on")
4550
vLLMPort := flag.String("vllm-port", "8001", "the port vLLM is listening on")
4651
vLLMDataParallelSize := flag.Int("data-parallel-size", 1, "the vLLM DATA-PARALLEL-SIZE value")
47-
connector := flag.String("connector", proxy.ConnectorNIXLV2, "the P/D connector being used. Supported: "+strings.Join(supportedConnectors, ", "))
52+
kvConnector := flag.String("kv-connector", proxy.KVConnectorNIXLV2, "the KV connector between Prefiller and Decoder. Supported: "+strings.Join(supportedKVConnectors, ", "))
53+
ecConnector := flag.String("ec-connector", proxy.ECExampleConnector, "the EC connector between Encoder and Prefiller (optional, for EPD mode). Supported: "+strings.Join(supportedECConnectors, ", "))
4854
prefillerUseTLS := flag.Bool("prefiller-use-tls", false, "whether to use TLS when sending requests to prefillers")
55+
encoderUseTLS := flag.Bool("encoder-use-tls", false, "whether to use TLS when sending requests to encoders")
4956
decoderUseTLS := flag.Bool("decoder-use-tls", false, "whether to use TLS when sending requests to the decoder")
5057
prefillerInsecureSkipVerify := flag.Bool("prefiller-tls-insecure-skip-verify", false, "configures the proxy to skip TLS verification for requests to prefiller")
58+
encoderInsecureSkipVerify := flag.Bool("encoder-tls-insecure-skip-verify", false, "configures the proxy to skip TLS verification for requests to encoder")
5159
decoderInsecureSkipVerify := flag.Bool("decoder-tls-insecure-skip-verify", false, "configures the proxy to skip TLS verification for requests to decoder")
5260
secureProxy := flag.Bool("secure-proxy", true, "Enables secure proxy. Defaults to true.")
5361
certPath := flag.String(
@@ -72,19 +80,37 @@ func main() {
7280

7381
logger.Info("Proxy starting", "Built on", version.BuildRef, "From Git SHA", version.CommitSHA)
7482

75-
// Validate connector
83+
// Validate KV connector (Prefiller-Decoder)
7684
isValidConnector := false
77-
for _, validConnector := range supportedConnectors {
78-
if *connector == validConnector {
85+
for _, validConnector := range supportedKVConnectors {
86+
if *kvConnector == validConnector {
7987
isValidConnector = true
8088
break
8189
}
8290
}
8391
if !isValidConnector {
84-
logger.Info("Error: --connector must be one of: " + strings.Join(supportedConnectors, ", "))
92+
logger.Info("Error: --kv-connector must be one of: " + strings.Join(supportedKVConnectors, ", "))
8593
return
8694
}
87-
logger.Info("p/d connector validated", "connector", connector)
95+
logger.Info("KV connector (prefiller-decoder) validated", "kvConnector", kvConnector)
96+
97+
// Validate EC connector (Encoder-Prefiller) if specified
98+
if *ecConnector != "" {
99+
isValidEncoderConnector := false
100+
for _, validConnector := range supportedECConnectors {
101+
if *ecConnector == validConnector {
102+
isValidEncoderConnector = true
103+
break
104+
}
105+
}
106+
if !isValidEncoderConnector {
107+
logger.Info("Error: --ec-connector must be one of: " + strings.Join(supportedECConnectors, ", "))
108+
return
109+
}
110+
logger.Info("EC connector (encoder-prefiller) validated", "ecConnector", ecConnector)
111+
} else {
112+
logger.Info("EC connector (encoder-prefiller) not specified, encoder stage will be skipped")
113+
}
88114

89115
// Determine namespace and pool name for SSRF protection
90116
if *enableSSRFProtection {
@@ -127,9 +153,12 @@ func main() {
127153
}
128154

129155
config := proxy.Config{
130-
Connector: *connector,
156+
KVConnector: *kvConnector,
157+
ECConnector: *ecConnector,
131158
PrefillerUseTLS: *prefillerUseTLS,
159+
EncoderUseTLS: *encoderUseTLS,
132160
PrefillerInsecureSkipVerify: *prefillerInsecureSkipVerify,
161+
EncoderInsecureSkipVerify: *encoderInsecureSkipVerify,
133162
DecoderInsecureSkipVerify: *decoderInsecureSkipVerify,
134163
DataParallelSize: *vLLMDataParallelSize,
135164
EnablePrefillerSampling: *enablePrefillerSampling,

pkg/sidecar/proxy/chat_completions.go

Lines changed: 57 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -54,26 +54,68 @@ func (s *Server) chatCompletionsHandler(w http.ResponseWriter, r *http.Request)
5454
}
5555
}
5656

57-
if len(prefillHostPort) == 0 {
58-
s.logger.V(4).Info("skip disaggregated prefill")
57+
// SSRF Protection: Check if the prefill target is allowed (if provided)
58+
if len(prefillHostPort) > 0 {
59+
if !s.allowlistValidator.IsAllowed(prefillHostPort) {
60+
s.logger.Error(nil, "SSRF protection: prefill target not in allowlist",
61+
"target", prefillHostPort,
62+
"clientIP", r.RemoteAddr,
63+
"userAgent", r.Header.Get("User-Agent"),
64+
"requestPath", r.URL.Path)
65+
http.Error(w, "Forbidden: prefill target not allowed by SSRF protection", http.StatusForbidden)
66+
return
67+
}
68+
s.logger.V(4).Info("SSRF protection: prefill target allowed", "target", prefillHostPort)
69+
}
5970

60-
if !s.forwardDataParallel || !s.dataParallelHandler(w, r) {
61-
s.decoderProxy.ServeHTTP(w, r)
71+
// Check if encoder headers are present to determine if we should use EPD protocol
72+
encoderHostPorts := r.Header.Values(common.EncoderHostsPortsHeader)
73+
if len(encoderHostPorts) == 1 {
74+
encoderHostPorts = strings.Split(encoderHostPorts[0], ",")
75+
}
76+
77+
// SSRF Protection: Filter encoder targets to only allowed hosts
78+
var allowedEncoders []string
79+
if len(encoderHostPorts) > 0 {
80+
allowedEncoders = make([]string, 0, len(encoderHostPorts))
81+
for _, encoderHost := range encoderHostPorts {
82+
encoderHost = strings.TrimSpace(encoderHost)
83+
if s.allowlistValidator.IsAllowed(encoderHost) {
84+
allowedEncoders = append(allowedEncoders, encoderHost)
85+
s.logger.V(4).Info("SSRF protection: encoder target allowed", "target", encoderHost)
86+
} else {
87+
s.logger.Info("SSRF protection: encoder target not in allowlist, removing from list",
88+
"target", encoderHost,
89+
"clientIP", r.RemoteAddr,
90+
"userAgent", r.Header.Get("User-Agent"),
91+
"requestPath", r.URL.Path)
92+
}
6293
}
63-
return
6494
}
6595

66-
// SSRF Protection: Check if the prefill target is allowed
67-
if !s.allowlistValidator.IsAllowed(prefillHostPort) {
68-
s.logger.Error(nil, "SSRF protection: prefill target not in allowlist",
69-
"target", prefillHostPort,
70-
"clientIP", r.RemoteAddr,
71-
"userAgent", r.Header.Get("User-Agent"),
72-
"requestPath", r.URL.Path)
73-
http.Error(w, "Forbidden: prefill target not allowed by SSRF protection", http.StatusForbidden)
96+
// Determine which protocol to use
97+
if len(allowedEncoders) > 0 && s.runEPDConnectorProtocol != nil {
98+
// Use EPD protocol (Encoder-Prefiller-Decoder or Encoder-Decoder)
99+
s.logger.V(4).Info("encoder headers detected, using EPD protocol",
100+
"encoderCount", len(allowedEncoders),
101+
"hasPrefiller", len(prefillHostPort) > 0)
102+
s.runEPDConnectorProtocol(w, r, prefillHostPort, allowedEncoders)
74103
return
75104
}
76105

77-
s.logger.V(4).Info("SSRF protection: prefill target allowed", "target", prefillHostPort)
78-
s.runConnectorProtocol(w, r, prefillHostPort)
106+
// If all encoders were filtered out, log and fall through
107+
if len(encoderHostPorts) > 0 && len(allowedEncoders) == 0 {
108+
s.logger.Info("SSRF protection: all encoder targets filtered out, falling back to P/D or decoder-only")
109+
}
110+
111+
// Use P/D protocol or decoder-only
112+
if len(prefillHostPort) > 0 {
113+
s.logger.V(4).Info("using P/D protocol")
114+
s.runPDConnectorProtocol(w, r, prefillHostPort)
115+
} else {
116+
s.logger.V(4).Info("no prefiller or encoder, using decoder only")
117+
if !s.forwardDataParallel || !s.dataParallelHandler(w, r) {
118+
s.decoderProxy.ServeHTTP(w, r)
119+
}
120+
}
79121
}

pkg/sidecar/proxy/chat_completions_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ func TestServer_chatCompletionsHandler(t *testing.T) {
119119
s.prefillSamplerFn = func(n int) int { return i % n }
120120
// verify the hostPort value
121121
var hostPort string
122-
s.runConnectorProtocol = func(_ http.ResponseWriter, _ *http.Request, selectedHostPort string) { hostPort = selectedHostPort }
122+
s.runPDConnectorProtocol = func(_ http.ResponseWriter, _ *http.Request, selectedHostPort string) { hostPort = selectedHostPort }
123123
var passthrough bool
124124
s.decoderProxy = http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
125125
passthrough = true

0 commit comments

Comments
 (0)