Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 24 additions & 43 deletions cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/llm-d-incubation/llm-d-async/pkg/async"
"github.com/llm-d-incubation/llm-d-async/pkg/async/api"
"github.com/llm-d-incubation/llm-d-async/pkg/async/inference/flowcontrol"
"github.com/llm-d-incubation/llm-d-async/pkg/config"
"github.com/llm-d-incubation/llm-d-async/pkg/metrics"
"github.com/llm-d-incubation/llm-d-async/pkg/pubsub"
"github.com/llm-d-incubation/llm-d-async/pkg/redis"
Expand All @@ -25,86 +26,66 @@ import (

func main() {

var loggerVerbosity int

var metricsPort int
var metricsEndpointAuth bool

var concurrency int
var requestMergePolicy string
var messageQueueImpl string

flag.IntVar(&loggerVerbosity, "v", logging.DEFAULT, "number for the log level verbosity")

flag.IntVar(&metricsPort, "metrics-port", 9090, "The metrics port")
flag.BoolVar(&metricsEndpointAuth, "metrics-endpoint-auth", true, "Enables authentication and authorization of the metrics endpoint")

flag.IntVar(&concurrency, "concurrency", 8, "number of concurrent workers")

flag.StringVar(&requestMergePolicy, "request-merge-policy", "random-robin", "The request merge policy to use. Supported policies: random-robin")
flag.StringVar(&messageQueueImpl, "message-queue-impl", "redis-pubsub", "The message queue implementation to use. Supported implementations: redis-pubsub, redis-sortedset, gcp-pubsub, gcp-pubsub-gated")

var prometheusURL = flag.String("prometheus-url", "", "Prometheus server URL for metric-based gates (e.g., http://localhost:9090)")
var configFile string
flag.StringVar(&configFile, "config", "", "Path to the configuration file")

opts := zap.Options{
Development: true,
}

opts.BindFlags(flag.CommandLine)
flag.Parse()

logging.InitLogging(&opts, loggerVerbosity)
cfg, err := config.LoadConfig(configFile)
if err != nil {
fmt.Printf("failed to load config: %v\n", err)
os.Exit(1)
}

logging.InitLogging(&opts, cfg.LogLevel)
defer logging.Sync() // nolint:errcheck

setupLog := ctrl.Log.WithName("setup")
setupLog.Info("Logger initialized")

////////setupLog.Info("GIE build", "commit-sha", version.CommitSHA, "build-ref", version.BuildRef)

printAllFlags(setupLog)
// Create Gate Factory for per-queue gate instantiation
gateFactory := flowcontrol.NewGateFactory(*prometheusURL)
gateFactory := flowcontrol.NewGateFactory(cfg)

var policy api.RequestMergePolicy
switch requestMergePolicy {
switch cfg.RequestMergePolicy {
case "random-robin":
policy = async.NewRandomRobinPolicy()
default:
setupLog.Error(fmt.Errorf("unknown request merge policy: %s", requestMergePolicy), "Unknown request merge policy", "request-merge-policy",
requestMergePolicy)
setupLog.Error(fmt.Errorf("unknown request merge policy: %s", cfg.RequestMergePolicy), "Unknown request merge policy", "request-merge-policy",
cfg.RequestMergePolicy)
os.Exit(1)
}
var impl api.Flow
switch messageQueueImpl {
switch cfg.MessageQueueImpl {
case "redis-pubsub":
impl = redis.NewRedisMQFlow()
impl = redis.NewRedisMQFlow(cfg.Redis)
case "redis-sortedset":
impl = redis.NewRedisSortedSetFlow(redis.WithGateFactory(gateFactory))
impl = redis.NewRedisSortedSetFlow(cfg.RedisSortedSet, redis.WithGateFactory(gateFactory))
setupLog.Info("Using Redis sorted-set flow with per-queue gating")
case "gcp-pubsub":
impl = pubsub.NewGCPPubSubMQFlow()
impl = pubsub.NewGCPPubSubMQFlow(cfg.PubSub)
case "gcp-pubsub-gated":
impl = pubsub.NewGCPPubSubMQFlow(pubsub.WithGateFactory(gateFactory))
impl = pubsub.NewGCPPubSubMQFlow(cfg.PubSub, pubsub.WithGateFactory(gateFactory))
setupLog.Info("Using GCP PubSub flow with per-queue gating")
default:
setupLog.Error(fmt.Errorf("unknown message queue implementation: %s", messageQueueImpl), "Unknown message queue implementation",
"message-queue-impl", messageQueueImpl)
setupLog.Error(fmt.Errorf("unknown message queue implementation: %s", cfg.MessageQueueImpl), "Unknown message queue implementation",
"message-queue-impl", cfg.MessageQueueImpl)
os.Exit(1)
}

metrics.Register(metrics.GetAsyncProcessorCollectors(impl.Characteristics().SupportsMessageLatency)...)

ctx := ctrl.SetupSignalHandler()

// Register metrics handler.
// Metrics endpoint is enabled in 'config/default/kustomization.yaml'. The Metrics options configure the server.
// More info:
// - https://pkg.go.dev/sigs.k8s.io/controller-runtime@v0.19.1/pkg/metrics/server
// - https://book.kubebuilder.io/reference/metrics.html
metricsServerOptions := metricsserver.Options{
BindAddress: fmt.Sprintf(":%d", metricsPort),
BindAddress: fmt.Sprintf(":%d", cfg.MetricsPort),
FilterProvider: func() func(c *rest.Config, httpClient *http.Client) (metricsserver.Filter, error) {
if metricsEndpointAuth {
if cfg.MetricsEndpointAuth {
return filters.WithAuthenticationAndAuthorization
}

Expand All @@ -121,7 +102,7 @@ func main() {
inferenceClient := api.NewHTTPInferenceClient(httpClient)

requestChannel := policy.MergeRequestChannels(impl.RequestChannels()).Channel
for w := 1; w <= concurrency; w++ {
for w := 1; w <= cfg.Concurrency; w++ {

go api.Worker(ctx, impl.Characteristics(), inferenceClient, requestChannel, impl.RetryChannel(), impl.ResultChannel())
}
Expand Down
35 changes: 0 additions & 35 deletions pkg/async/inference/flowcontrol/binary_metric_dispatch_gate.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,11 @@ package flowcontrol

import (
"context"
"flag"

"github.com/prometheus/client_golang/api"
"sigs.k8s.io/controller-runtime/pkg/log"
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
)

var isGMP = flag.Bool("gate.pmetric.is-gmp", false, "Is this GMP (Google Managed Prometheus).")
var prometheusURL = flag.String("gate.prometheus.url", "", "Prometheus URL for non GMP metric")
var gmpProjectID = flag.String("gate.pmetric.gmp.project-id", "", "Project ID for Google Managed Prometheus")
var prometheusQueryModelName = flag.String("gate.prometheus.model-name", "", "metrics name to use for avg_queue_size")

// BinaryMetricDispatchGate implements DispatchGate using a MetricSource.
// It returns 0.0 (no capacity) if the metric value is non-zero,
// and 1.0 (full capacity) if the metric value is zero.
type BinaryMetricDispatchGate struct {
source MetricSource
}
Expand Down Expand Up @@ -64,28 +54,3 @@ func (g *BinaryMetricDispatchGate) Budget(ctx context.Context) float64 {
}
return 0.0
}

// AverageQueueSizeGate creates a BinaryMetricDispatchGate from command-line flags.
func AverageQueueSizeGate() *BinaryMetricDispatchGate {
expr := buildPromQL("inference_pool_average_queue_size",
map[string]string{"name": *prometheusQueryModelName})

var source MetricSource
if *isGMP {
var err error
source, err = NewGMPPromQLMetricSource(*gmpProjectID, expr)
if err != nil {
panic(err)
}
} else {
var err error
source, err = NewPromQLMetricSource(api.Config{
Address: *prometheusURL,
}, expr)
if err != nil {
panic(err)
}
}

return NewBinaryMetricDispatchGateWithSource(source)
}
38 changes: 19 additions & 19 deletions pkg/async/inference/flowcontrol/gate_factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,34 +21,27 @@ import (
"strconv"

asyncapi "github.com/llm-d-incubation/llm-d-async/pkg/async/api"
"github.com/llm-d-incubation/llm-d-async/pkg/config"
redisgate "github.com/llm-d-incubation/llm-d-async/pkg/redis"
promapi "github.com/prometheus/client_golang/api"
goredis "github.com/redis/go-redis/v9"
)

// GateFactory creates DispatchGate instances based on configuration.
type GateFactory struct {
prometheusURL string
redisClients map[string]*goredis.Client
cfg *config.Config
redisClients map[string]*goredis.Client
}

// NewGateFactory creates a new GateFactory with an optional Prometheus URL.
// If prometheusURL is empty, Prometheus gates will fail at creation time.
func NewGateFactory(prometheusURL string) *GateFactory {
// NewGateFactory creates a new GateFactory with the provided configuration.
func NewGateFactory(cfg *config.Config) *GateFactory {
return &GateFactory{
prometheusURL: prometheusURL,
redisClients: make(map[string]*goredis.Client),
cfg: cfg,
redisClients: make(map[string]*goredis.Client),
}
}

// CreateGate creates a DispatchGate based on the gate type and parameters.
// Supported gate types:
// - "constant": Always returns budget 1.0 (fully open)
// - "redis": Queries Redis for dispatch budget
// - "prometheus-saturation": Queries Prometheus for pool saturation metric
// Optional params: threshold (default 0.8), fallback (default 0.0)
//
// For unsupported or unknown gate types, returns ConstOpenGate as a safe default.
func (f *GateFactory) CreateGate(gateType string, params map[string]string) (asyncapi.DispatchGate, error) {
switch gateType {
case "constant":
Expand All @@ -71,13 +64,17 @@ func (f *GateFactory) CreateGate(gateType string, params map[string]string) (asy
return redisgate.NewRedisDispatchGate(client, budgetKey), nil

case "prometheus-saturation":
if f.prometheusURL == "" {
return nil, fmt.Errorf("prometheus-saturation gate type requires --prometheus-url flag to be set")
prometheusURL := f.cfg.PrometheusURL
if prometheusURL == "" {
prometheusURL = f.cfg.Gates.Prometheus.URL
}
if prometheusURL == "" {
return nil, fmt.Errorf("prometheus-saturation gate type requires prometheusURL to be set in config")
}

pool := params["pool"]

threshold := 0.8 // default threshold
threshold := f.cfg.Gates.Saturation.Threshold
if thresholdStr := params["threshold"]; thresholdStr != "" {
t, err := strconv.ParseFloat(thresholdStr, 64)
if err != nil {
Expand All @@ -86,7 +83,7 @@ func (f *GateFactory) CreateGate(gateType string, params map[string]string) (asy
threshold = t
}

fallback := 0.0 // default fallback saturation
fallback := f.cfg.Gates.Saturation.Fallback
if fallbackStr := params["fallback"]; fallbackStr != "" {
fb, err := strconv.ParseFloat(fallbackStr, 64)
if err != nil {
Expand All @@ -96,6 +93,9 @@ func (f *GateFactory) CreateGate(gateType string, params map[string]string) (asy
}

queryExpr := params["query"]
if queryExpr == "" {
queryExpr = f.cfg.Gates.Saturation.QueryExpr
}
if queryExpr == "" {
labels := map[string]string{}
if pool != "" {
Expand All @@ -104,7 +104,7 @@ func (f *GateFactory) CreateGate(gateType string, params map[string]string) (asy
queryExpr = buildPromQL("inference_extension_flow_control_pool_saturation", labels)
}

source, err := NewPromQLMetricSource(promapi.Config{Address: f.prometheusURL}, queryExpr)
source, err := NewPromQLMetricSource(promapi.Config{Address: prometheusURL}, queryExpr)
if err != nil {
return nil, fmt.Errorf("failed to create Prometheus metric source: %w", err)
}
Expand Down
37 changes: 24 additions & 13 deletions pkg/async/inference/flowcontrol/gate_factory_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,12 @@ import (
"context"
"testing"

"github.com/llm-d-incubation/llm-d-async/pkg/config"
"github.com/stretchr/testify/assert"
)

func TestGateFactory_CreateConstantGate(t *testing.T) {
factory := NewGateFactory("")
factory := NewGateFactory(config.DefaultConfig())
gate, err := factory.CreateGate("constant", nil)

assert.NoError(t, err)
Expand All @@ -34,7 +35,7 @@ func TestGateFactory_CreateConstantGate(t *testing.T) {
}

func TestGateFactory_UnknownGateType(t *testing.T) {
factory := NewGateFactory("")
factory := NewGateFactory(config.DefaultConfig())
gate, err := factory.CreateGate("unknown-type", nil)

assert.NoError(t, err)
Expand All @@ -45,7 +46,7 @@ func TestGateFactory_UnknownGateType(t *testing.T) {
}

func TestGateFactory_EmptyGateType(t *testing.T) {
factory := NewGateFactory("")
factory := NewGateFactory(config.DefaultConfig())
gate, err := factory.CreateGate("", nil)

assert.NoError(t, err)
Expand All @@ -55,22 +56,28 @@ func TestGateFactory_EmptyGateType(t *testing.T) {
}

func TestGateFactory_PrometheusGateWithoutURL(t *testing.T) {
factory := NewGateFactory("") // No Prometheus URL
cfg := config.DefaultConfig()
cfg.PrometheusURL = ""
factory := NewGateFactory(cfg) // No Prometheus URL
gate, err := factory.CreateGate("prometheus-saturation", map[string]string{})
assert.Error(t, err, "should return error when Prometheus URL is not set")
assert.Nil(t, gate)
assert.Contains(t, err.Error(), "prometheus-saturation gate type requires --prometheus-url flag to be set")
assert.Contains(t, err.Error(), "prometheus-saturation gate type requires prometheusURL to be set in config")
}

func TestGateFactory_PrometheusGateWithoutPoolParam(t *testing.T) {
factory := NewGateFactory("http://localhost:9090")
cfg := config.DefaultConfig()
cfg.PrometheusURL = "http://localhost:9090"
factory := NewGateFactory(cfg)
gate, err := factory.CreateGate("prometheus-saturation", map[string]string{})
assert.NoError(t, err, "should not error when pool parameter is missing")
assert.NotNil(t, gate)
}

func TestGateFactory_PrometheusGateWithInvalidThreshold(t *testing.T) {
factory := NewGateFactory("http://localhost:9090")
cfg := config.DefaultConfig()
cfg.PrometheusURL = "http://localhost:9090"
factory := NewGateFactory(cfg)
gate, err := factory.CreateGate("prometheus-saturation", map[string]string{
"threshold": "not-a-number",
})
Expand All @@ -80,7 +87,9 @@ func TestGateFactory_PrometheusGateWithInvalidThreshold(t *testing.T) {
}

func TestGateFactory_PrometheusGateWithInvalidFallback(t *testing.T) {
factory := NewGateFactory("http://localhost:9090")
cfg := config.DefaultConfig()
cfg.PrometheusURL = "http://localhost:9090"
factory := NewGateFactory(cfg)
gate, err := factory.CreateGate("prometheus-saturation", map[string]string{
"fallback": "not-a-number",
})
Expand All @@ -90,7 +99,9 @@ func TestGateFactory_PrometheusGateWithInvalidFallback(t *testing.T) {
}

func TestGateFactory_PrometheusGateWithThresholdAndFallback(t *testing.T) {
factory := NewGateFactory("http://localhost:9090")
cfg := config.DefaultConfig()
cfg.PrometheusURL = "http://localhost:9090"
factory := NewGateFactory(cfg)
gate, err := factory.CreateGate("prometheus-saturation", map[string]string{
"threshold": "0.7",
"fallback": "0.3",
Expand All @@ -100,22 +111,22 @@ func TestGateFactory_PrometheusGateWithThresholdAndFallback(t *testing.T) {
}

func TestGateFactory_RedisGateMissingAddress(t *testing.T) {
factory := NewGateFactory("")
factory := NewGateFactory(config.DefaultConfig())
gate, err := factory.CreateGate("redis", map[string]string{})
assert.Error(t, err, "should return error when address is missing")
assert.Nil(t, gate)
assert.Contains(t, err.Error(), "redis gate requires an 'address' in gate_params")
}

func TestGateFactory_RedisGateNilParams(t *testing.T) {
factory := NewGateFactory("")
factory := NewGateFactory(config.DefaultConfig())
gate, err := factory.CreateGate("redis", nil)
assert.Error(t, err, "should return error when params is nil")
assert.Nil(t, gate)
}

func TestGateFactory_RedisGateSharesClient(t *testing.T) {
factory := NewGateFactory("")
factory := NewGateFactory(config.DefaultConfig())
params := map[string]string{"address": "localhost:6379"}
gate1, err1 := factory.CreateGate("redis", params)
gate2, err2 := factory.CreateGate("redis", params)
Expand All @@ -128,7 +139,7 @@ func TestGateFactory_RedisGateSharesClient(t *testing.T) {
}

func TestGateFactory_RedisGateDifferentAddresses(t *testing.T) {
factory := NewGateFactory("")
factory := NewGateFactory(config.DefaultConfig())
gate1, err1 := factory.CreateGate("redis", map[string]string{"address": "host1:6379"})
gate2, err2 := factory.CreateGate("redis", map[string]string{"address": "host2:6379"})
assert.NoError(t, err1)
Expand Down
Loading
Loading