Skip to content

Commit ddf9402

Browse files
committed
moved flags into config file
Signed-off-by: Shimi Bandiel <shimib@google.com>
1 parent 20ec7d9 commit ddf9402

12 files changed

Lines changed: 379 additions & 342 deletions

File tree

cmd/main.go

Lines changed: 24 additions & 43 deletions
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@ import (
1111
"github.com/llm-d-incubation/llm-d-async/pkg/async"
1212
"github.com/llm-d-incubation/llm-d-async/pkg/async/api"
1313
"github.com/llm-d-incubation/llm-d-async/pkg/async/inference/flowcontrol"
14+
"github.com/llm-d-incubation/llm-d-async/pkg/config"
1415
"github.com/llm-d-incubation/llm-d-async/pkg/metrics"
1516
"github.com/llm-d-incubation/llm-d-async/pkg/pubsub"
1617
"github.com/llm-d-incubation/llm-d-async/pkg/redis"
@@ -25,86 +26,66 @@ import (
2526

2627
func main() {
2728

28-
var loggerVerbosity int
29-
30-
var metricsPort int
31-
var metricsEndpointAuth bool
32-
33-
var concurrency int
34-
var requestMergePolicy string
35-
var messageQueueImpl string
36-
37-
flag.IntVar(&loggerVerbosity, "v", logging.DEFAULT, "number for the log level verbosity")
38-
39-
flag.IntVar(&metricsPort, "metrics-port", 9090, "The metrics port")
40-
flag.BoolVar(&metricsEndpointAuth, "metrics-endpoint-auth", true, "Enables authentication and authorization of the metrics endpoint")
41-
42-
flag.IntVar(&concurrency, "concurrency", 8, "number of concurrent workers")
43-
44-
flag.StringVar(&requestMergePolicy, "request-merge-policy", "random-robin", "The request merge policy to use. Supported policies: random-robin")
45-
flag.StringVar(&messageQueueImpl, "message-queue-impl", "redis-pubsub", "The message queue implementation to use. Supported implementations: redis-pubsub, redis-sortedset, gcp-pubsub, gcp-pubsub-gated")
46-
47-
var prometheusURL = flag.String("prometheus-url", "", "Prometheus server URL for metric-based gates (e.g., http://localhost:9090)")
29+
var configFile string
30+
flag.StringVar(&configFile, "config", "", "Path to the configuration file")
4831

4932
opts := zap.Options{
5033
Development: true,
5134
}
52-
5335
opts.BindFlags(flag.CommandLine)
5436
flag.Parse()
5537

56-
logging.InitLogging(&opts, loggerVerbosity)
38+
cfg, err := config.LoadConfig(configFile)
39+
if err != nil {
40+
fmt.Printf("failed to load config: %v\n", err)
41+
os.Exit(1)
42+
}
43+
44+
logging.InitLogging(&opts, cfg.LogLevel)
5745
defer logging.Sync() // nolint:errcheck
5846

5947
setupLog := ctrl.Log.WithName("setup")
6048
setupLog.Info("Logger initialized")
6149

62-
////////setupLog.Info("GIE build", "commit-sha", version.CommitSHA, "build-ref", version.BuildRef)
63-
6450
printAllFlags(setupLog)
6551
// Create Gate Factory for per-queue gate instantiation
66-
gateFactory := flowcontrol.NewGateFactory(*prometheusURL)
52+
gateFactory := flowcontrol.NewGateFactory(cfg)
6753

6854
var policy api.RequestMergePolicy
69-
switch requestMergePolicy {
55+
switch cfg.RequestMergePolicy {
7056
case "random-robin":
7157
policy = async.NewRandomRobinPolicy()
7258
default:
73-
setupLog.Error(fmt.Errorf("unknown request merge policy: %s", requestMergePolicy), "Unknown request merge policy", "request-merge-policy",
74-
requestMergePolicy)
59+
setupLog.Error(fmt.Errorf("unknown request merge policy: %s", cfg.RequestMergePolicy), "Unknown request merge policy", "request-merge-policy",
60+
cfg.RequestMergePolicy)
7561
os.Exit(1)
7662
}
7763
var impl api.Flow
78-
switch messageQueueImpl {
64+
switch cfg.MessageQueueImpl {
7965
case "redis-pubsub":
80-
impl = redis.NewRedisMQFlow()
66+
impl = redis.NewRedisMQFlow(cfg.Redis)
8167
case "redis-sortedset":
82-
impl = redis.NewRedisSortedSetFlow(redis.WithGateFactory(gateFactory))
68+
impl = redis.NewRedisSortedSetFlow(cfg.RedisSortedSet, redis.WithGateFactory(gateFactory))
8369
setupLog.Info("Using Redis sorted-set flow with per-queue gating")
8470
case "gcp-pubsub":
85-
impl = pubsub.NewGCPPubSubMQFlow()
71+
impl = pubsub.NewGCPPubSubMQFlow(cfg.PubSub)
8672
case "gcp-pubsub-gated":
87-
impl = pubsub.NewGCPPubSubMQFlow(pubsub.WithGateFactory(gateFactory))
73+
impl = pubsub.NewGCPPubSubMQFlow(cfg.PubSub, pubsub.WithGateFactory(gateFactory))
8874
setupLog.Info("Using GCP PubSub flow with per-queue gating")
8975
default:
90-
setupLog.Error(fmt.Errorf("unknown message queue implementation: %s", messageQueueImpl), "Unknown message queue implementation",
91-
"message-queue-impl", messageQueueImpl)
76+
setupLog.Error(fmt.Errorf("unknown message queue implementation: %s", cfg.MessageQueueImpl), "Unknown message queue implementation",
77+
"message-queue-impl", cfg.MessageQueueImpl)
9278
os.Exit(1)
9379
}
9480

9581
metrics.Register(metrics.GetAsyncProcessorCollectors(impl.Characteristics().SupportsMessageLatency)...)
9682

9783
ctx := ctrl.SetupSignalHandler()
9884

99-
// Register metrics handler.
100-
// Metrics endpoint is enabled in 'config/default/kustomization.yaml'. The Metrics options configure the server.
101-
// More info:
102-
// - https://pkg.go.dev/sigs.k8s.io/controller-runtime@v0.19.1/pkg/metrics/server
103-
// - https://book.kubebuilder.io/reference/metrics.html
10485
metricsServerOptions := metricsserver.Options{
105-
BindAddress: fmt.Sprintf(":%d", metricsPort),
86+
BindAddress: fmt.Sprintf(":%d", cfg.MetricsPort),
10687
FilterProvider: func() func(c *rest.Config, httpClient *http.Client) (metricsserver.Filter, error) {
107-
if metricsEndpointAuth {
88+
if cfg.MetricsEndpointAuth {
10889
return filters.WithAuthenticationAndAuthorization
10990
}
11091

@@ -121,7 +102,7 @@ func main() {
121102
inferenceClient := api.NewHTTPInferenceClient(httpClient)
122103

123104
requestChannel := policy.MergeRequestChannels(impl.RequestChannels()).Channel
124-
for w := 1; w <= concurrency; w++ {
105+
for w := 1; w <= cfg.Concurrency; w++ {
125106

126107
go api.Worker(ctx, impl.Characteristics(), inferenceClient, requestChannel, impl.RetryChannel(), impl.ResultChannel())
127108
}

pkg/async/inference/flowcontrol/binary_metric_dispatch_gate.go

Lines changed: 0 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -18,21 +18,11 @@ package flowcontrol
1818

1919
import (
2020
"context"
21-
"flag"
2221

23-
"github.com/prometheus/client_golang/api"
2422
"sigs.k8s.io/controller-runtime/pkg/log"
2523
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
2624
)
2725

28-
var isGMP = flag.Bool("gate.pmetric.is-gmp", false, "Is this GMP (Google Managed Prometheus).")
29-
var prometheusURL = flag.String("gate.prometheus.url", "", "Prometheus URL for non GMP metric")
30-
var gmpProjectID = flag.String("gate.pmetric.gmp.project-id", "", "Project ID for Google Managed Prometheus")
31-
var prometheusQueryModelName = flag.String("gate.prometheus.model-name", "", "metrics name to use for avg_queue_size")
32-
33-
// BinaryMetricDispatchGate implements DispatchGate using a MetricSource.
34-
// It returns 0.0 (no capacity) if the metric value is non-zero,
35-
// and 1.0 (full capacity) if the metric value is zero.
3626
type BinaryMetricDispatchGate struct {
3727
source MetricSource
3828
}
@@ -64,28 +54,3 @@ func (g *BinaryMetricDispatchGate) Budget(ctx context.Context) float64 {
6454
}
6555
return 0.0
6656
}
67-
68-
// AverageQueueSizeGate creates a BinaryMetricDispatchGate from command-line flags.
69-
func AverageQueueSizeGate() *BinaryMetricDispatchGate {
70-
expr := buildPromQL("inference_pool_average_queue_size",
71-
map[string]string{"name": *prometheusQueryModelName})
72-
73-
var source MetricSource
74-
if *isGMP {
75-
var err error
76-
source, err = NewGMPPromQLMetricSource(*gmpProjectID, expr)
77-
if err != nil {
78-
panic(err)
79-
}
80-
} else {
81-
var err error
82-
source, err = NewPromQLMetricSource(api.Config{
83-
Address: *prometheusURL,
84-
}, expr)
85-
if err != nil {
86-
panic(err)
87-
}
88-
}
89-
90-
return NewBinaryMetricDispatchGateWithSource(source)
91-
}

pkg/async/inference/flowcontrol/gate_factory.go

Lines changed: 19 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -21,34 +21,27 @@ import (
2121
"strconv"
2222

2323
asyncapi "github.com/llm-d-incubation/llm-d-async/pkg/async/api"
24+
"github.com/llm-d-incubation/llm-d-async/pkg/config"
2425
redisgate "github.com/llm-d-incubation/llm-d-async/pkg/redis"
2526
promapi "github.com/prometheus/client_golang/api"
2627
goredis "github.com/redis/go-redis/v9"
2728
)
2829

2930
// GateFactory creates DispatchGate instances based on configuration.
3031
type GateFactory struct {
31-
prometheusURL string
32-
redisClients map[string]*goredis.Client
32+
cfg *config.Config
33+
redisClients map[string]*goredis.Client
3334
}
3435

35-
// NewGateFactory creates a new GateFactory with an optional Prometheus URL.
36-
// If prometheusURL is empty, Prometheus gates will fail at creation time.
37-
func NewGateFactory(prometheusURL string) *GateFactory {
36+
// NewGateFactory creates a new GateFactory with the provided configuration.
37+
func NewGateFactory(cfg *config.Config) *GateFactory {
3838
return &GateFactory{
39-
prometheusURL: prometheusURL,
40-
redisClients: make(map[string]*goredis.Client),
39+
cfg: cfg,
40+
redisClients: make(map[string]*goredis.Client),
4141
}
4242
}
4343

4444
// CreateGate creates a DispatchGate based on the gate type and parameters.
45-
// Supported gate types:
46-
// - "constant": Always returns budget 1.0 (fully open)
47-
// - "redis": Queries Redis for dispatch budget
48-
// - "prometheus-saturation": Queries Prometheus for pool saturation metric
49-
// Optional params: threshold (default 0.8), fallback (default 0.0)
50-
//
51-
// For unsupported or unknown gate types, returns ConstOpenGate as a safe default.
5245
func (f *GateFactory) CreateGate(gateType string, params map[string]string) (asyncapi.DispatchGate, error) {
5346
switch gateType {
5447
case "constant":
@@ -71,13 +64,17 @@ func (f *GateFactory) CreateGate(gateType string, params map[string]string) (asy
7164
return redisgate.NewRedisDispatchGate(client, budgetKey), nil
7265

7366
case "prometheus-saturation":
74-
if f.prometheusURL == "" {
75-
return nil, fmt.Errorf("prometheus-saturation gate type requires --prometheus-url flag to be set")
67+
prometheusURL := f.cfg.PrometheusURL
68+
if prometheusURL == "" {
69+
prometheusURL = f.cfg.Gates.Prometheus.URL
70+
}
71+
if prometheusURL == "" {
72+
return nil, fmt.Errorf("prometheus-saturation gate type requires prometheusURL to be set in config")
7673
}
7774

7875
pool := params["pool"]
7976

80-
threshold := 0.8 // default threshold
77+
threshold := f.cfg.Gates.Saturation.Threshold
8178
if thresholdStr := params["threshold"]; thresholdStr != "" {
8279
t, err := strconv.ParseFloat(thresholdStr, 64)
8380
if err != nil {
@@ -86,7 +83,7 @@ func (f *GateFactory) CreateGate(gateType string, params map[string]string) (asy
8683
threshold = t
8784
}
8885

89-
fallback := 0.0 // default fallback saturation
86+
fallback := f.cfg.Gates.Saturation.Fallback
9087
if fallbackStr := params["fallback"]; fallbackStr != "" {
9188
fb, err := strconv.ParseFloat(fallbackStr, 64)
9289
if err != nil {
@@ -96,6 +93,9 @@ func (f *GateFactory) CreateGate(gateType string, params map[string]string) (asy
9693
}
9794

9895
queryExpr := params["query"]
96+
if queryExpr == "" {
97+
queryExpr = f.cfg.Gates.Saturation.QueryExpr
98+
}
9999
if queryExpr == "" {
100100
labels := map[string]string{}
101101
if pool != "" {
@@ -104,7 +104,7 @@ func (f *GateFactory) CreateGate(gateType string, params map[string]string) (asy
104104
queryExpr = buildPromQL("inference_extension_flow_control_pool_saturation", labels)
105105
}
106106

107-
source, err := NewPromQLMetricSource(promapi.Config{Address: f.prometheusURL}, queryExpr)
107+
source, err := NewPromQLMetricSource(promapi.Config{Address: prometheusURL}, queryExpr)
108108
if err != nil {
109109
return nil, fmt.Errorf("failed to create Prometheus metric source: %w", err)
110110
}

pkg/async/inference/flowcontrol/gate_factory_test.go

Lines changed: 24 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -20,11 +20,12 @@ import (
2020
"context"
2121
"testing"
2222

23+
"github.com/llm-d-incubation/llm-d-async/pkg/config"
2324
"github.com/stretchr/testify/assert"
2425
)
2526

2627
func TestGateFactory_CreateConstantGate(t *testing.T) {
27-
factory := NewGateFactory("")
28+
factory := NewGateFactory(config.DefaultConfig())
2829
gate, err := factory.CreateGate("constant", nil)
2930

3031
assert.NoError(t, err)
@@ -34,7 +35,7 @@ func TestGateFactory_CreateConstantGate(t *testing.T) {
3435
}
3536

3637
func TestGateFactory_UnknownGateType(t *testing.T) {
37-
factory := NewGateFactory("")
38+
factory := NewGateFactory(config.DefaultConfig())
3839
gate, err := factory.CreateGate("unknown-type", nil)
3940

4041
assert.NoError(t, err)
@@ -45,7 +46,7 @@ func TestGateFactory_UnknownGateType(t *testing.T) {
4546
}
4647

4748
func TestGateFactory_EmptyGateType(t *testing.T) {
48-
factory := NewGateFactory("")
49+
factory := NewGateFactory(config.DefaultConfig())
4950
gate, err := factory.CreateGate("", nil)
5051

5152
assert.NoError(t, err)
@@ -55,22 +56,28 @@ func TestGateFactory_EmptyGateType(t *testing.T) {
5556
}
5657

5758
func TestGateFactory_PrometheusGateWithoutURL(t *testing.T) {
58-
factory := NewGateFactory("") // No Prometheus URL
59+
cfg := config.DefaultConfig()
60+
cfg.PrometheusURL = ""
61+
factory := NewGateFactory(cfg) // No Prometheus URL
5962
gate, err := factory.CreateGate("prometheus-saturation", map[string]string{})
6063
assert.Error(t, err, "should return error when Prometheus URL is not set")
6164
assert.Nil(t, gate)
62-
assert.Contains(t, err.Error(), "prometheus-saturation gate type requires --prometheus-url flag to be set")
65+
assert.Contains(t, err.Error(), "prometheus-saturation gate type requires prometheusURL to be set in config")
6366
}
6467

6568
func TestGateFactory_PrometheusGateWithoutPoolParam(t *testing.T) {
66-
factory := NewGateFactory("http://localhost:9090")
69+
cfg := config.DefaultConfig()
70+
cfg.PrometheusURL = "http://localhost:9090"
71+
factory := NewGateFactory(cfg)
6772
gate, err := factory.CreateGate("prometheus-saturation", map[string]string{})
6873
assert.NoError(t, err, "should not error when pool parameter is missing")
6974
assert.NotNil(t, gate)
7075
}
7176

7277
func TestGateFactory_PrometheusGateWithInvalidThreshold(t *testing.T) {
73-
factory := NewGateFactory("http://localhost:9090")
78+
cfg := config.DefaultConfig()
79+
cfg.PrometheusURL = "http://localhost:9090"
80+
factory := NewGateFactory(cfg)
7481
gate, err := factory.CreateGate("prometheus-saturation", map[string]string{
7582
"threshold": "not-a-number",
7683
})
@@ -80,7 +87,9 @@ func TestGateFactory_PrometheusGateWithInvalidThreshold(t *testing.T) {
8087
}
8188

8289
func TestGateFactory_PrometheusGateWithInvalidFallback(t *testing.T) {
83-
factory := NewGateFactory("http://localhost:9090")
90+
cfg := config.DefaultConfig()
91+
cfg.PrometheusURL = "http://localhost:9090"
92+
factory := NewGateFactory(cfg)
8493
gate, err := factory.CreateGate("prometheus-saturation", map[string]string{
8594
"fallback": "not-a-number",
8695
})
@@ -90,7 +99,9 @@ func TestGateFactory_PrometheusGateWithInvalidFallback(t *testing.T) {
9099
}
91100

92101
func TestGateFactory_PrometheusGateWithThresholdAndFallback(t *testing.T) {
93-
factory := NewGateFactory("http://localhost:9090")
102+
cfg := config.DefaultConfig()
103+
cfg.PrometheusURL = "http://localhost:9090"
104+
factory := NewGateFactory(cfg)
94105
gate, err := factory.CreateGate("prometheus-saturation", map[string]string{
95106
"threshold": "0.7",
96107
"fallback": "0.3",
@@ -100,22 +111,22 @@ func TestGateFactory_PrometheusGateWithThresholdAndFallback(t *testing.T) {
100111
}
101112

102113
func TestGateFactory_RedisGateMissingAddress(t *testing.T) {
103-
factory := NewGateFactory("")
114+
factory := NewGateFactory(config.DefaultConfig())
104115
gate, err := factory.CreateGate("redis", map[string]string{})
105116
assert.Error(t, err, "should return error when address is missing")
106117
assert.Nil(t, gate)
107118
assert.Contains(t, err.Error(), "redis gate requires an 'address' in gate_params")
108119
}
109120

110121
func TestGateFactory_RedisGateNilParams(t *testing.T) {
111-
factory := NewGateFactory("")
122+
factory := NewGateFactory(config.DefaultConfig())
112123
gate, err := factory.CreateGate("redis", nil)
113124
assert.Error(t, err, "should return error when params is nil")
114125
assert.Nil(t, gate)
115126
}
116127

117128
func TestGateFactory_RedisGateSharesClient(t *testing.T) {
118-
factory := NewGateFactory("")
129+
factory := NewGateFactory(config.DefaultConfig())
119130
params := map[string]string{"address": "localhost:6379"}
120131
gate1, err1 := factory.CreateGate("redis", params)
121132
gate2, err2 := factory.CreateGate("redis", params)
@@ -128,7 +139,7 @@ func TestGateFactory_RedisGateSharesClient(t *testing.T) {
128139
}
129140

130141
func TestGateFactory_RedisGateDifferentAddresses(t *testing.T) {
131-
factory := NewGateFactory("")
142+
factory := NewGateFactory(config.DefaultConfig())
132143
gate1, err1 := factory.CreateGate("redis", map[string]string{"address": "host1:6379"})
133144
gate2, err2 := factory.CreateGate("redis", map[string]string{"address": "host2:6379"})
134145
assert.NoError(t, err1)

0 commit comments

Comments
 (0)