Skip to content

Commit 98654d0

Browse files
committed
adds ability to configure sidecar proxy through YAML
1 parent 4cd7046 commit 98654d0

File tree

3 files changed

+691
-8
lines changed

3 files changed

+691
-8
lines changed

cmd/pd-sidecar/main.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ func main() {
3333
opts := proxy.NewOptions()
3434

3535
// Add options flags (including logging flags)
36-
opts.AddFlags(pflag.CommandLine)
36+
opts.FlagSet = pflag.CommandLine
37+
opts.AddFlags(opts.FlagSet)
3738
pflag.Parse()
3839

3940
logger := zap.New(zap.UseFlagOptions(&opts.LoggingOptions))

pkg/sidecar/proxy/options.go

Lines changed: 294 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import (
2626

2727
"github.com/spf13/pflag"
2828
"sigs.k8s.io/controller-runtime/pkg/log/zap"
29+
"sigs.k8s.io/yaml"
2930
)
3031

3132
// Options holds all configuration options for the pd-sidecar proxy.
@@ -48,6 +49,8 @@ type Options struct {
4849
InsecureSkipVerifyForPrefiller bool // InsecureSkipVerifyForPrefiller configures the proxy to skip TLS verification for requests to prefiller (set from TLSInsecureSkipVerify)
4950
InsecureSkipVerifyForEncoder bool // InsecureSkipVerifyForEncoder configures the proxy to skip TLS verification for requests to encoder (set from TLSInsecureSkipVerify)
5051
InsecureSkipVerifyForDecoder bool // InsecureSkipVerifyForDecoder configures the proxy to skip TLS verification for requests to decoder (set from TLSInsecureSkipVerify)
52+
Configuration string // Configuration is sidecar configuration in YAML provided as inline specification. Example `--configuration={port: 8085, vllm-port: 8203}`
53+
ConfigurationFile string // ConfigurationFile is path to file which contains sidecar configuration in YAML. Example `--configuration-file=/etc/config/sidecar-config.yaml`
5154

5255
// Deprecated flag fields (kept for backward compatibility)
5356
PrefillerUseTLS bool // Deprecated: Use EnableTLS instead. PrefillerUseTLS indicates whether to use TLS when sending requests to prefillers
@@ -65,13 +68,19 @@ type Options struct {
6568
EnablePrefillerSampling bool // EnablePrefillerSampling enables random selection of prefill instances
6669
PoolGroup string // PoolGroup is the group of the InferencePool this Endpoint Picker is associated with
6770
LoggingOptions zap.Options // LoggingOptions holds the zap logging configuration
71+
FlagSet *pflag.FlagSet
6872
}
6973

74+
type configurationMap map[string]any
75+
7076
const (
7177
// TLS stages
72-
prefillStage = "prefiller"
73-
decodeStage = "decoder"
74-
encodeStage = "encoder"
78+
prefillStage = "prefiller"
79+
decodeStage = "decoder"
80+
encodeStage = "encoder"
81+
defaultPort = "8000"
82+
defaultvLLMPort = "8001"
83+
defaultDataParallelSize = 1
7584
)
7685

7786
var (
@@ -118,9 +127,9 @@ func NewOptions() *Options {
118127
}
119128

120129
return &Options{
121-
Port: "8000",
122-
VLLMPort: "8001",
123-
DataParallelSize: 1,
130+
Port: defaultPort,
131+
VLLMPort: defaultvLLMPort,
132+
DataParallelSize: defaultDataParallelSize,
124133
KVConnector: "",
125134
ECConnector: "",
126135
Connector: KVConnectorNIXLV2,
@@ -154,6 +163,8 @@ func (opts *Options) AddFlags(fs *pflag.FlagSet) {
154163

155164
fs.StringSliceVar(&opts.EnableTLS, "enable-tls", opts.EnableTLS, "stages to enable TLS for. Supported: "+supportedTLSStageNamesStr+". Can be specified multiple times or as comma-separated values.")
156165
fs.StringSliceVar(&opts.TLSInsecureSkipVerify, "tls-insecure-skip-verify", opts.TLSInsecureSkipVerify, "stages to skip TLS verification for. Supported: "+supportedTLSStageNamesStr+". Can be specified multiple times or as comma-separated values.")
166+
fs.StringVar(&opts.Configuration, "configuration", "", "Sidecar configuration in YAML provided as inline specification. Example `--configuration={port: 8085, vllm-port: 8203}`")
167+
fs.StringVar(&opts.ConfigurationFile, "configuration-file", "", "Path to file which contains sidecar configuration in YAML. Example `--configuration-file=/etc/config/sidecar-config.yaml`")
157168

158169
// Deprecated flags - kept for backward compatibility
159170
fs.StringVar(&opts.Connector, "connector", opts.Connector, "Deprecated: use --kv-connector instead. The P/D connector being used. Supported: "+supportedKVConnectorNamesStr)
@@ -192,9 +203,13 @@ func validateStages(stages []string, supportedStages map[string]struct{}, flagNa
192203
}
193204

194205
// Complete performs post-processing of parsed command-line arguments.
195-
// This handles migration from deprecated boolean flags to new StringSlice flags,
206+
// This handles migration from deprecated boolean flags to new StringSlice flags, extracts YAML configuration,
196207
// parses the InferencePool field, sets configuration fields from flag fields, and computes the target URL.
197208
func (opts *Options) Complete() error {
209+
if err := opts.extractYAMLConfiguration(opts.Configuration, opts.ConfigurationFile); err != nil {
210+
return err
211+
}
212+
198213
// Migrate deprecated Connector flag to KVConnector
199214
if opts.Connector != "" && opts.KVConnector == "" {
200215
opts.KVConnector = opts.Connector
@@ -311,3 +326,275 @@ func (opts *Options) Validate() error {
311326

312327
return nil
313328
}
329+
330+
// extractYAMLConfiguration extracts sidecar configuration (if provided)
331+
// from `--configuration` and `--configuration-file` parameters
332+
func (opts *Options) extractYAMLConfiguration(configuration string, configurationFile string) error {
333+
var configurationMap1, configurationMap2 configurationMap
334+
var err error
335+
if configuration != "" {
336+
configurationMap1, err = YAMLConfigurationFromInlineSpecification(configuration)
337+
if err != nil {
338+
return err
339+
}
340+
}
341+
if configurationFile != "" {
342+
configurationMap2, err = YAMLConfigurationFromFile(configurationFile)
343+
if err != nil {
344+
return err
345+
}
346+
}
347+
switch {
348+
case configurationMap1 != nil && configurationMap2 != nil:
349+
opts.updateSidecarConfiguration(mergeYAMLConfigurations(configurationMap2, configurationMap1))
350+
case configurationMap1 != nil && configurationMap2 == nil:
351+
opts.updateSidecarConfiguration(configurationMap1)
352+
case configurationMap1 == nil && configurationMap2 != nil:
353+
opts.updateSidecarConfiguration(configurationMap2)
354+
default:
355+
break
356+
}
357+
return nil
358+
}
359+
360+
// isDefault checks if flag contains default or parsed value
361+
func (o *Options) isDefault(parameter string) bool {
362+
flag := o.FlagSet.Lookup(parameter)
363+
return flag == nil || !flag.Changed
364+
}
365+
366+
// YAMLConfigurationFromInlineSpecification extracts YAML configuration provided as inline specification
367+
// "--configuration={port: 8085, vllm-port: 8203}"
368+
func YAMLConfigurationFromInlineSpecification(config string) (map[string]any, error) {
369+
var temp map[string]any
370+
if err := yaml.Unmarshal([]byte(config), &temp); err != nil {
371+
return nil, errors.New("Failed to unmarshal sidecar configuration")
372+
}
373+
return temp, nil
374+
375+
}
376+
377+
// YAMLConfigurationFromFile extracts YAML configuration from file path
378+
// "--configuration-file=/etc/config/sidecar-config.yaml"
379+
func YAMLConfigurationFromFile(configFile string) (map[string]any, error) {
380+
var temp map[string]any
381+
rawFile, err := os.ReadFile(configFile)
382+
if err != nil {
383+
return nil, errors.New("Failed to read sidecar configuration")
384+
}
385+
if err := yaml.Unmarshal(rawFile, &temp); err != nil {
386+
return nil, errors.New("Failed to unmarshal sidecar configuration")
387+
388+
}
389+
return temp, nil
390+
}
391+
392+
// mergeYAMLConfigurations merges two decoded YAML mappings:
//  1. fileYAML, the configuration read from `--configuration-file`
//  2. parameterYAML, the configuration given inline through `--configuration`
//
// Values from parameterYAML take priority. When both sides hold a nested
// mapping under the same key the two mappings are merged recursively; in every
// other case the value from parameterYAML replaces the one from fileYAML.
//
// The merge is performed in place: fileYAML is modified and returned. A nil
// fileYAML is treated as an empty mapping and a new map is returned.
func mergeYAMLConfigurations(fileYAML, parameterYAML map[string]any) map[string]any {
	if fileYAML == nil {
		// Assigning into a nil map panics, so start from an empty one.
		fileYAML = make(map[string]any, len(parameterYAML))
	}
	for key, inlineValue := range parameterYAML {
		// A missing key yields a nil interface, for which the assertion
		// simply reports false.
		fileNested, fileIsMap := fileYAML[key].(map[string]any)
		inlineNested, inlineIsMap := inlineValue.(map[string]any)
		if fileIsMap && inlineIsMap {
			fileYAML[key] = mergeYAMLConfigurations(fileNested, inlineNested)
			continue
		}
		fileYAML[key] = inlineValue
	}
	return fileYAML
}
410+
411+
// updateSidecarConfiguration updates value from YAML only when:
412+
// 1. YAML configuration contains non-zero value
413+
// 2. sidecar configuration contains value not explicitely set by flag
414+
// i.e. gives higher priority to configuration provided individually through flags (e.g. `--port`, `--vllm-port`) over configuration provided through YAML
415+
func (opts *Options) updateSidecarConfiguration(configurationMap configurationMap) error {
416+
if configurationMap["port"] != nil {
417+
if v, ok := configurationMap["port"].(float64); ok {
418+
if opts.isDefault("port") {
419+
opts.Port = strconv.Itoa(int(v))
420+
}
421+
} else {
422+
return errors.New("Type assertion failed for port: " + fmt.Sprintf("%v", configurationMap["port"]))
423+
}
424+
}
425+
if configurationMap["vllm-port"] != nil {
426+
if v, ok := configurationMap["vllm-port"].(float64); ok {
427+
if opts.isDefault("vllm-port") {
428+
opts.VLLMPort = strconv.Itoa(int(v))
429+
}
430+
} else {
431+
return errors.New("Type assertion failed for vllm-port: " + fmt.Sprintf("%v", configurationMap["vllm-port"]))
432+
}
433+
}
434+
if configurationMap["data-parallel-size"] != nil {
435+
if v, ok := configurationMap["data-parallel-size"].(float64); ok {
436+
if opts.isDefault("data-parallel-size") {
437+
opts.DataParallelSize = int(v)
438+
}
439+
} else {
440+
return errors.New("Type assertion failed for data-parallel-size: " + fmt.Sprintf("%v", configurationMap["data-parallel-size"]))
441+
}
442+
}
443+
if configurationMap["connector"] != nil {
444+
if v, ok := configurationMap["connector"].(string); ok {
445+
if opts.isDefault("connector") {
446+
opts.Connector = v
447+
}
448+
} else {
449+
return errors.New("Type assertion failed for connector: " + fmt.Sprintf("%v", configurationMap["connector"]))
450+
}
451+
}
452+
if configurationMap["kv-connector"] != nil {
453+
if v, ok := configurationMap["kv-connector"].(string); ok {
454+
if opts.isDefault("kv-connector") {
455+
opts.KVConnector = v
456+
}
457+
} else {
458+
return errors.New("Type assertion failed for kv-connector: " + fmt.Sprintf("%v", configurationMap["kv-connector"]))
459+
}
460+
}
461+
if configurationMap["ec-connector"] != nil {
462+
if v, ok := configurationMap["ec-connector"].(string); ok {
463+
if opts.isDefault("ec-connector") {
464+
opts.ECConnector = v
465+
}
466+
} else {
467+
return errors.New("Type assertion failed for ec-connector: " + fmt.Sprintf("%v", configurationMap["ec-connector"]))
468+
}
469+
}
470+
if configurationMap["enable-tls"] != nil {
471+
switch v := configurationMap["enable-tls"].(type) {
472+
case string:
473+
opts.EnableTLS = append(opts.EnableTLS, strings.Split(v, ",")...)
474+
case []any:
475+
for _, val := range v {
476+
opts.EnableTLS = append(opts.EnableTLS, fmt.Sprintf("%v", val))
477+
}
478+
default:
479+
return errors.New("Type assertion failed for enable-tls: " + fmt.Sprintf("%v", configurationMap["enable-tls"]))
480+
}
481+
}
482+
if configurationMap["prefiller-use-tls"] != nil {
483+
if v, ok := configurationMap["prefiller-use-tls"].(bool); ok {
484+
if opts.isDefault("prefiller-use-tls") {
485+
opts.PrefillerUseTLS = v
486+
}
487+
} else {
488+
return errors.New("Type assertion failed for prefiller-use-tls: " + fmt.Sprintf("%v", configurationMap["prefiller-use-tls"]))
489+
}
490+
}
491+
if configurationMap["decoder-use-tls"] != nil {
492+
if v, ok := configurationMap["decoder-use-tls"].(bool); ok {
493+
if opts.isDefault("decoder-use-tls") {
494+
opts.DecoderUseTLS = v
495+
}
496+
} else {
497+
return errors.New("Type assertion failed for decoder-use-tls: " + fmt.Sprintf("%v", configurationMap["decoder-use-tls"]))
498+
}
499+
}
500+
if configurationMap["tls-insecure-skip-verify"] != nil {
501+
if v, ok := configurationMap["tls-insecure-skip-verify"].(bool); ok {
502+
if opts.isDefault("tls-insecure-skip-verify") {
503+
opts.PrefillerInsecureSkipVerify = v
504+
}
505+
} else {
506+
return errors.New("Type assertion failed for tls-insecure-skip-verify: " + fmt.Sprintf("%v", configurationMap["tls-insecure-skip-verify"]))
507+
}
508+
}
509+
if configurationMap["prefiller-tls-insecure-skip-verify"] != nil {
510+
if v, ok := configurationMap["prefiller-tls-insecure-skip-verify"].(bool); ok {
511+
if opts.isDefault("prefiller-tls-insecure-skip-verify") {
512+
opts.PrefillerInsecureSkipVerify = v
513+
}
514+
} else {
515+
return errors.New("Type assertion failed for prefiller-tls-insecure-skip-verify: " + fmt.Sprintf("%v", configurationMap["prefiller-tls-insecure-skip-verify"]))
516+
}
517+
}
518+
if configurationMap["decoder-tls-insecure-skip-verify"] != nil {
519+
if v, ok := configurationMap["decoder-tls-insecure-skip-verify"].(bool); ok {
520+
if opts.isDefault("decoder-tls-insecure-skip-verify") {
521+
opts.DecoderInsecureSkipVerify = v
522+
}
523+
} else {
524+
return errors.New("Type assertion failed for decoder-tls-insecure-skip-verify: " + fmt.Sprintf("%v", configurationMap["decoder-tls-insecure-skip-verify"]))
525+
}
526+
}
527+
if configurationMap["secure-proxy"] != nil {
528+
if v, ok := configurationMap["secure-proxy"].(bool); ok {
529+
if opts.isDefault("secure-proxy") {
530+
opts.SecureProxy = v
531+
}
532+
} else {
533+
return errors.New("Type assertion failed for secure-proxy: " + fmt.Sprintf("%v", configurationMap["secure-proxy"]))
534+
}
535+
}
536+
if configurationMap["cert-path"] != nil {
537+
if v, ok := configurationMap["cert-path"].(string); ok {
538+
if opts.isDefault("cert-path") {
539+
opts.CertPath = v
540+
}
541+
} else {
542+
return errors.New("Type assertion failed for cert-path: " + fmt.Sprintf("%v", configurationMap["cert-path"]))
543+
}
544+
}
545+
if configurationMap["enable-ssrf-protection"] != nil {
546+
if v, ok := configurationMap["enable-ssrf-protection"].(bool); ok {
547+
if opts.isDefault("enable-ssrf-protection") {
548+
opts.EnableSSRFProtection = v
549+
}
550+
} else {
551+
return errors.New("Type assertion failed for enable-ssrf-protection: " + fmt.Sprintf("%v", configurationMap["enable-ssrf-protection"]))
552+
}
553+
}
554+
if configurationMap["inference-pool"] != nil {
555+
if v, ok := configurationMap["inference-pool"].(string); ok {
556+
if opts.isDefault("inference-pool") {
557+
opts.InferencePool = v
558+
}
559+
} else {
560+
return errors.New("Type assertion failed for inference-pool: " + fmt.Sprintf("%v", configurationMap["inference-pool"]))
561+
}
562+
}
563+
if configurationMap["inference-pool-namespace"] != nil {
564+
if v, ok := configurationMap["inference-pool-namespace"].(string); ok {
565+
if opts.isDefault("inference-pool-namespace") {
566+
opts.InferencePoolNamespace = v
567+
}
568+
} else {
569+
return errors.New("Type assertion failed for inference-pool-namespace: " + fmt.Sprintf("%v", configurationMap["inference-pool-namespace"]))
570+
}
571+
}
572+
if configurationMap["inference-pool-name"] != nil {
573+
if v, ok := configurationMap["inference-pool-name"].(string); ok {
574+
if opts.isDefault("inference-pool-name") {
575+
opts.InferencePoolName = v
576+
}
577+
} else {
578+
return errors.New("Type assertion failed for inference-pool-name: " + fmt.Sprintf("%v", configurationMap["inference-pool-name"]))
579+
}
580+
}
581+
if configurationMap["enable-prefiller-sampling"] != nil {
582+
if v, ok := configurationMap["enable-prefiller-sampling"].(bool); ok {
583+
if opts.isDefault("enable-prefiller-sampling") {
584+
opts.EnablePrefillerSampling = v
585+
}
586+
} else {
587+
return errors.New("Type assertion failed for enable-prefiller-sampling: " + fmt.Sprintf("%v", configurationMap["enable-prefiller-sampling"]))
588+
}
589+
}
590+
if configurationMap["pool-group"] != nil {
591+
if v, ok := configurationMap["pool-group"].(string); ok {
592+
if opts.isDefault("pool-group") {
593+
opts.PoolGroup = v
594+
}
595+
} else {
596+
return errors.New("Type assertion failed for pool-group: " + fmt.Sprintf("%v", configurationMap["pool-group"]))
597+
}
598+
}
599+
return nil
600+
}

0 commit comments

Comments
 (0)