Skip to content

Commit 1576cc0

Browse files
committed
initial E/PD extension of the sidecar
Signed-off-by: roytman <roytman@il.ibm.com>
1 parent 1519a28 commit 1576cc0

File tree

12 files changed

+670
-87
lines changed

12 files changed

+670
-87
lines changed

cmd/pd-sidecar/main.go

Lines changed: 41 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -33,22 +33,30 @@ import (
3333
)
3434

3535
var (
36-
// supportedConnectors defines all valid P/D connector types
37-
supportedConnectors = []string{
38-
proxy.ConnectorNIXLV2,
39-
proxy.ConnectorSharedStorage,
40-
proxy.ConnectorSGLang,
36+
// supportedKVConnectors defines all valid KV (Prefiller-Decoder) connector types
37+
supportedKVConnectors = []string{
38+
proxy.KVConnectorNIXLV2,
39+
proxy.KVConnectorSharedStorage,
40+
proxy.KVConnectorSGLang,
41+
}
42+
43+
// supportedECConnectors defines all valid EC (Encoder-Prefiller) connector types
44+
supportedECConnectors = []string{
45+
proxy.ECExampleConnector,
4146
}
4247
)
4348

4449
func main() {
4550
port := flag.String("port", "8000", "the port the sidecar is listening on")
4651
vLLMPort := flag.String("vllm-port", "8001", "the port vLLM is listening on")
4752
vLLMDataParallelSize := flag.Int("data-parallel-size", 1, "the vLLM DATA-PARALLEL-SIZE value")
48-
connector := flag.String("connector", proxy.ConnectorNIXLV2, "the P/D connector being used. Supported: "+strings.Join(supportedConnectors, ", "))
53+
kvConnector := flag.String("kv-connector", proxy.KVConnectorNIXLV2, "the KV connector between Prefiller and Decoder. Supported: "+strings.Join(supportedKVConnectors, ", "))
54+
ecConnector := flag.String("ec-connector", proxy.ECExampleConnector, "the EC connector between Encoder and Prefiller (optional, for EPD mode). Supported: "+strings.Join(supportedECConnectors, ", "))
4955
prefillerUseTLS := flag.Bool("prefiller-use-tls", false, "whether to use TLS when sending requests to prefillers")
56+
encoderUseTLS := flag.Bool("encoder-use-tls", false, "whether to use TLS when sending requests to encoders")
5057
decoderUseTLS := flag.Bool("decoder-use-tls", false, "whether to use TLS when sending requests to the decoder")
5158
prefillerInsecureSkipVerify := flag.Bool("prefiller-tls-insecure-skip-verify", false, "configures the proxy to skip TLS verification for requests to prefiller")
59+
encoderInsecureSkipVerify := flag.Bool("encoder-tls-insecure-skip-verify", false, "configures the proxy to skip TLS verification for requests to encoder")
5260
decoderInsecureSkipVerify := flag.Bool("decoder-tls-insecure-skip-verify", false, "configures the proxy to skip TLS verification for requests to decoder")
5361
secureProxy := flag.Bool("secure-proxy", true, "Enables secure proxy. Defaults to true.")
5462
certPath := flag.String(
@@ -87,19 +95,37 @@ func main() {
8795

8896
logger.Info("Proxy starting", "Built on", version.BuildRef, "From Git SHA", version.CommitSHA)
8997

90-
// Validate connector
98+
// Validate KV connector (Prefiller-Decoder)
9199
isValidConnector := false
92-
for _, validConnector := range supportedConnectors {
93-
if *connector == validConnector {
100+
for _, validConnector := range supportedKVConnectors {
101+
if *kvConnector == validConnector {
94102
isValidConnector = true
95103
break
96104
}
97105
}
98106
if !isValidConnector {
99-
logger.Info("Error: --connector must be one of: " + strings.Join(supportedConnectors, ", "))
107+
logger.Info("Error: --kv-connector must be one of: " + strings.Join(supportedKVConnectors, ", "))
100108
return
101109
}
102-
logger.Info("p/d connector validated", "connector", connector)
110+
logger.Info("KV connector (prefiller-decoder) validated", "kvConnector", kvConnector)
111+
112+
// Validate EC connector (Encoder-Prefiller) if specified
113+
if *ecConnector != "" {
114+
isValidEncoderConnector := false
115+
for _, validConnector := range supportedECConnectors {
116+
if *ecConnector == validConnector {
117+
isValidEncoderConnector = true
118+
break
119+
}
120+
}
121+
if !isValidEncoderConnector {
122+
logger.Info("Error: --ec-connector must be one of: " + strings.Join(supportedECConnectors, ", "))
123+
return
124+
}
125+
logger.Info("EC connector (encoder-prefiller) validated", "ecConnector", ecConnector)
126+
} else {
127+
logger.Info("EC connector (encoder-prefiller) not specified, encoder stage will be skipped")
128+
}
103129

104130
// Determine namespace and pool name for SSRF protection
105131
if *enableSSRFProtection {
@@ -142,9 +168,12 @@ func main() {
142168
}
143169

144170
config := proxy.Config{
145-
Connector: *connector,
171+
KVConnector: *kvConnector,
172+
ECConnector: *ecConnector,
146173
PrefillerUseTLS: *prefillerUseTLS,
174+
EncoderUseTLS: *encoderUseTLS,
147175
PrefillerInsecureSkipVerify: *prefillerInsecureSkipVerify,
176+
EncoderInsecureSkipVerify: *encoderInsecureSkipVerify,
148177
DecoderInsecureSkipVerify: *decoderInsecureSkipVerify,
149178
DataParallelSize: *vLLMDataParallelSize,
150179
EnablePrefillerSampling: *enablePrefillerSampling,

pkg/common/common.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@ const (
1212

1313
// DataParallelPodHeader is the header name used to indicate the worker <ip:port> for Data Parallel
1414
DataParallelPodHeader = "x-data-parallel-host-port"
15+
16+
// EncoderHostsPortsHeader is the header name used to indicate Encoder workers <ip:port> list
17+
EncoderHostsPortsHeader = "x-encoder-hosts-ports"
1518
)
1619

1720
// StripScheme removes the scheme from an endpoint URL, returning host:port.

pkg/sidecar/proxy/chat_completions.go

Lines changed: 78 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,8 @@ func (s *Server) chatCompletionsHandler(w http.ResponseWriter, r *http.Request)
6060
requestPath = r.URL.Path
6161
}
6262
span.SetAttributes(
63-
attribute.String("llm_d.pd_proxy.connector", s.config.Connector),
63+
attribute.String("llm_d.pd_proxy.kv_connector", s.config.KVConnector),
64+
attribute.String("llm_d.pd_proxy.ec_connector", s.config.ECConnector),
6465
attribute.String("llm_d.pd_proxy.request_path", requestPath),
6566
)
6667

@@ -86,41 +87,96 @@ func (s *Server) chatCompletionsHandler(w http.ResponseWriter, r *http.Request)
8687
}
8788
}
8889

90+
// Set span attributes for prefill
8991
if len(prefillHostPort) == 0 {
9092
s.logger.V(4).Info("skip disaggregated prefill")
9193
span.SetAttributes(
9294
attribute.Bool("llm_d.pd_proxy.disaggregation_used", false),
9395
attribute.String("llm_d.pd_proxy.reason", "no_prefill_header"),
9496
)
97+
} else {
98+
span.SetAttributes(
99+
attribute.Bool("llm_d.pd_proxy.disaggregation_used", true),
100+
attribute.String("llm_d.pd_proxy.prefill_target", prefillHostPort),
101+
attribute.Int("llm_d.pd_proxy.prefill_candidates", numHosts),
102+
)
103+
}
95104

96-
if !s.forwardDataParallel || !s.dataParallelHandler(w, r) {
97-
s.decoderProxy.ServeHTTP(w, r)
105+
// SSRF Protection: Check if the prefill target is allowed (if provided)
106+
if len(prefillHostPort) > 0 {
107+
if !s.allowlistValidator.IsAllowed(prefillHostPort) {
108+
s.logger.Error(nil, "SSRF protection: prefill target not in allowlist",
109+
"target", prefillHostPort,
110+
"clientIP", r.RemoteAddr,
111+
"userAgent", r.Header.Get("User-Agent"),
112+
"requestPath", r.URL.Path)
113+
span.SetAttributes(
114+
attribute.String("llm_d.pd_proxy.error", "ssrf_protection_denied"),
115+
attribute.String("llm_d.pd_proxy.denied_target", prefillHostPort),
116+
)
117+
span.SetStatus(codes.Error, "SSRF protection: prefill target not in allowlist")
118+
http.Error(w, "Forbidden: prefill target not allowed by SSRF protection", http.StatusForbidden)
119+
return
98120
}
99-
return
121+
s.logger.V(4).Info("SSRF protection: prefill target allowed", "target", prefillHostPort)
100122
}
101123

102-
span.SetAttributes(
103-
attribute.Bool("llm_d.pd_proxy.disaggregation_used", true),
104-
attribute.String("llm_d.pd_proxy.prefill_target", prefillHostPort),
105-
attribute.Int("llm_d.pd_proxy.prefill_candidates", numHosts),
106-
)
124+
// Check if encoder headers are present to determine if we should use EPD protocol
125+
encoderHostPorts := r.Header.Values(common.EncoderHostsPortsHeader)
126+
if len(encoderHostPorts) == 1 {
127+
encoderHostPorts = strings.Split(encoderHostPorts[0], ",")
128+
}
107129

108-
// SSRF Protection: Check if the prefill target is allowed
109-
if !s.allowlistValidator.IsAllowed(prefillHostPort) {
110-
s.logger.Error(nil, "SSRF protection: prefill target not in allowlist",
111-
"target", prefillHostPort,
112-
"clientIP", r.RemoteAddr,
113-
"userAgent", r.Header.Get("User-Agent"),
114-
"requestPath", r.URL.Path)
130+
// SSRF Protection: Filter encoder targets to only allowed hosts
131+
var allowedEncoders []string
132+
if len(encoderHostPorts) > 0 {
133+
allowedEncoders = make([]string, 0, len(encoderHostPorts))
134+
for _, encoderHost := range encoderHostPorts {
135+
encoderHost = strings.TrimSpace(encoderHost)
136+
if s.allowlistValidator.IsAllowed(encoderHost) {
137+
allowedEncoders = append(allowedEncoders, encoderHost)
138+
s.logger.V(4).Info("SSRF protection: encoder target allowed", "target", encoderHost)
139+
} else {
140+
s.logger.Info("SSRF protection: encoder target not in allowlist, removing from list",
141+
"target", encoderHost,
142+
"clientIP", r.RemoteAddr,
143+
"userAgent", r.Header.Get("User-Agent"),
144+
"requestPath", r.URL.Path)
145+
}
146+
}
147+
}
148+
149+
// Determine which protocol to use
150+
if len(allowedEncoders) > 0 && s.runEPDConnectorProtocol != nil {
151+
// Use EPD protocol (Encoder-Prefiller-Decoder or Encoder-Decoder)
152+
s.logger.V(4).Info("encoder headers detected, using EPD protocol",
153+
"encoderCount", len(allowedEncoders),
154+
"encoderCandidates", len(encoderHostPorts),
155+
"hasPrefiller", len(prefillHostPort) > 0)
115156
span.SetAttributes(
116-
attribute.String("llm_d.pd_proxy.error", "ssrf_protection_denied"),
117-
attribute.String("llm_d.pd_proxy.denied_target", prefillHostPort),
157+
attribute.Bool("llm_d.epd_proxy.encode_disaggregation_used", true),
158+
attribute.Int("llm_d.epd_proxy.encoder_count", len(allowedEncoders)),
159+
attribute.Int("llm_d.epd_proxy.encoder_candidates", len(encoderHostPorts)),
118160
)
119-
span.SetStatus(codes.Error, "SSRF protection: prefill target not in allowlist")
120-
http.Error(w, "Forbidden: prefill target not allowed by SSRF protection", http.StatusForbidden)
161+
s.runEPDConnectorProtocol(w, r, prefillHostPort, allowedEncoders)
121162
return
122163
}
123164

124-
s.logger.V(4).Info("SSRF protection: prefill target allowed", "target", prefillHostPort)
125-
s.runConnectorProtocol(w, r, prefillHostPort)
165+
// If all encoders were filtered out, log and fall through
166+
if len(encoderHostPorts) > 0 && len(allowedEncoders) == 0 {
167+
s.logger.Info("SSRF protection: all encoder targets filtered out, falling back to P/D or decoder-only")
168+
}
169+
170+
// Use P/D protocol or decoder-only
171+
if len(prefillHostPort) > 0 {
172+
s.logger.V(4).Info("using P/D protocol")
173+
s.runPDConnectorProtocol(w, r, prefillHostPort)
174+
} else {
175+
s.logger.V(4).Info("no prefiller or encoder, using decoder only")
176+
if !s.forwardDataParallel || !s.dataParallelHandler(w, r) {
177+
s.decoderProxy.ServeHTTP(w, r)
178+
}
179+
}
126180
}
181+
182+
// Made with Bob

pkg/sidecar/proxy/chat_completions_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ func TestServer_chatCompletionsHandler(t *testing.T) {
119119
s.prefillSamplerFn = func(n int) int { return i % n }
120120
// verify the hostPort value
121121
var hostPort string
122-
s.runConnectorProtocol = func(_ http.ResponseWriter, _ *http.Request, selectedHostPort string) { hostPort = selectedHostPort }
122+
s.runPDConnectorProtocol = func(_ http.ResponseWriter, _ *http.Request, selectedHostPort string) { hostPort = selectedHostPort }
123123
var passthrough bool
124124
s.decoderProxy = http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
125125
passthrough = true

0 commit comments

Comments
 (0)