Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
225 changes: 78 additions & 147 deletions README.md

Large diffs are not rendered by default.

16 changes: 11 additions & 5 deletions api/internal_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,17 @@ const (
// values here. All internal pipeline code reads routing exclusively from this
// struct rather than reaching back into the typed request.
type InternalRouting struct {
RetryCount int `json:"retry_count,omitempty"`
RequestQueueName string `json:"request_queue_name,omitempty"`
ResultQueueName string `json:"result_queue_name,omitempty"`
TransportCorrelationID string `json:"transport_correlation_id,omitempty"`
Classification QuotaClassification `json:"classification,omitempty"`
RetryCount int `json:"retry_count,omitempty"`
RequestQueueName string `json:"request_queue_name,omitempty"`
ResultQueueName string `json:"result_queue_name,omitempty"`
TransportCorrelationID string `json:"transport_correlation_id,omitempty"`
// Labels is the framework's per-message label set. Seeded by the
// Flow at pull time from the originating channel's effective
// static label set (auto-injected pool ID + pool.Labels +
// subscription.Labels, merged at startup). Gates and the merge
// policy read and mutate this map in place. Producer-controlled
// per-message correlation data rides on body.Metadata, not Labels.
Labels map[string]string `json:"labels,omitempty"`
}

// InternalRequest is the internal envelope: routing data plus a concrete Request.
Expand Down
12 changes: 0 additions & 12 deletions charts/async-processor/templates/_helpers.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -57,18 +57,6 @@ Create the name of the service account to use
{{- default (include "async-processor.fullname" .) .Values.serviceAccount.name }}
{{- end }}

{{/*
Render gate params as JSON with all values as strings.
The gate params parser expects map[string]string, so numeric values must be quoted.
*/}}
{{- define "async-processor.gateParamsJson" -}}
{{- $out := dict -}}
{{- range $k, $v := .Values.ap.redis.gateParams -}}
{{- $_ := set $out $k ($v | toString) -}}
{{- end -}}
{{- $out | toJson -}}
{{- end }}

{{/*
Resolve the Redis secret name.
If redis.url is set, the chart creates a Secret named <fullname>-redis.
Expand Down
4 changes: 0 additions & 4 deletions charts/async-processor/templates/ap-deployments.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,6 @@ spec:
- "{{ .Values.ap.redis.pollIntervalMs | default 1000 }}"
- --redis.ss.batch-size
- "{{ .Values.ap.redis.batchSize | default 10 }}"
{{- if .Values.ap.redis.gateType }}
- --redis.ss.gate-type={{ .Values.ap.redis.gateType }}
- --redis.ss.gate-params={{ include "async-processor.gateParamsJson" . }}
{{- end }}
{{- else }}
- --message-queue-impl=redis-pubsub
- --redis.igw-base-url={{ .Values.ap.igwBaseURL }}
Expand Down
7 changes: 3 additions & 4 deletions charts/async-processor/values.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,12 @@ ap:
resultQueueName: "result-list"
pollIntervalMs: 1000
batchSize: 10
gateType: ""
gateParams: {}
# Note: per-pool / per-subscription gate chains are configured via
# the topics-config-file (see README), not chart values. Chart-level
# support for the new gate types is not yet wired.
# PodMonitor for model server metrics relabeling.
# Creates a PodMonitor that scrapes vLLM metrics and relabels
# the inference_pool pod label into scraped metrics.
# Required when using prometheus-budget or prometheus-saturation gates
# without llm-d's flow control plugin enabled.
modelServerMonitor:
enabled: false
selector:
Expand Down
104 changes: 81 additions & 23 deletions cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package main
import (
"crypto/tls"
"crypto/x509"
"encoding/json"
"flag"
"fmt"
"net/http"
Expand All @@ -12,13 +13,15 @@ import (
"github.com/go-logr/logr"
"github.com/llm-d-incubation/llm-d-async/internal/logging"
"github.com/llm-d-incubation/llm-d-async/pipeline"
"github.com/llm-d-incubation/llm-d-async/pkg/async"
_ "github.com/llm-d-incubation/llm-d-async/pkg/async" // register built-in merge policies
_ "github.com/llm-d-incubation/llm-d-async/pkg/async/inference/mergepolicy/tierpriority" // register tier-priority merge policy
"github.com/llm-d-incubation/llm-d-async/pkg/async/inference/flowcontrol"
"github.com/llm-d-incubation/llm-d-async/pkg/asyncworker"
"github.com/llm-d-incubation/llm-d-async/pkg/metrics"
"github.com/llm-d-incubation/llm-d-async/pkg/pubsub"
"github.com/llm-d-incubation/llm-d-async/pkg/redis"
"github.com/llm-d-incubation/llm-d-async/pkg/version"
goredis "github.com/redis/go-redis/v9"
"k8s.io/client-go/rest"
"sigs.k8s.io/controller-runtime/pkg/log/zap"

Expand Down Expand Up @@ -53,9 +56,11 @@ func main() {
flag.IntVar(&concurrency, "concurrency", 8, "number of concurrent workers")
flag.DurationVar(&requestTimeout, "request-timeout", 5*time.Minute, "timeout for individual inference requests")

flag.StringVar(&requestMergePolicy, "request-merge-policy", "random-robin", "The request merge policy to use. Supported policies: random-robin")
flag.StringVar(&requestMergePolicy, "request-merge-policy", "random-robin", "The request merge policy to use. Supported policies: random-robin, tier-priority")
flag.StringVar(&messageQueueImpl, "message-queue-impl", "redis-pubsub", "The message queue implementation to use. Supported implementations: redis-pubsub, redis-sortedset, gcp-pubsub, gcp-pubsub-gated")

var mergePolicyConfigJSON = flag.String("merge-policy-config", "{}", "JSON-encoded free-form config map passed to the selected merge policy's factory (see policy package docs for recognized keys).")

flag.StringVar(&tlsCACert, "tls-ca-cert", "", "Path to CA certificate file (PEM) for verifying the inference gateway")
flag.StringVar(&tlsCert, "tls-cert", "", "Path to client certificate file (PEM) for mTLS")
flag.StringVar(&tlsKey, "tls-key", "", "Path to client key file (PEM) for mTLS")
Expand All @@ -81,23 +86,32 @@ func main() {
setupLog.Info("Async Processor starting", "version", version.Version, "commit", version.Commit, "buildDate", version.BuildDate)

printAllFlags(setupLog)
// Create Gate Factory for per-queue gate instantiation
gateFactory := flowcontrol.NewGateFactoryWithCacheTTL(*prometheusURL, *prometheusCacheTTL)
defer func() {
if err := gateFactory.Close(); err != nil {
setupLog.Error(err, "Failed to close gate factory")
}
}()

var policy pipeline.RequestMergePolicy
switch requestMergePolicy {
case "random-robin":
policy = async.NewRandomRobinPolicy()
default:
setupLog.Error(fmt.Errorf("unknown request merge policy: %s", requestMergePolicy), "Unknown request merge policy", "request-merge-policy",
requestMergePolicy)
os.Exit(1)
// Set up the process-lifetime context before the gate factory so
// background goroutines owned by gates (e.g. PromQL refresh loops
// for tier-priority-admission) are driven by the signal-handler ctx.
ctx := ctrl.SetupSignalHandler()

// Shared redis client for gates that need it (e.g.
// reservation-classifier, tier-priority-admission redis-counter).
// Best-effort: if --redis.url is not set, those gates will fail at
// construction. Other gate types and Flow impls that don't need
// redis are unaffected.
var gateRedisClient *goredis.Client
if redisOpts, err := redis.RedisOptions(); err == nil {
gateRedisClient = goredis.NewClient(redisOpts)
} else {
setupLog.Info("Redis client for gates not available; redis-backed gates will not initialize", "reason", err.Error())
}

gateFactoryOpts := []flowcontrol.GateFactoryOption{
flowcontrol.WithBackgroundContext(ctx),
}
if gateRedisClient != nil {
gateFactoryOpts = append(gateFactoryOpts, flowcontrol.WithRedisClient(gateRedisClient))
}
gateFactory := flowcontrol.NewGateFactoryWithCacheTTL(*prometheusURL, *prometheusCacheTTL, gateFactoryOpts...)

var impl pipeline.Flow
switch messageQueueImpl {
case "redis-pubsub":
Expand Down Expand Up @@ -126,9 +140,17 @@ func main() {
os.Exit(1)
}

metrics.Register(metrics.GetAsyncProcessorCollectors(impl.Characteristics().SupportsMessageLatency)...)
// Build the merge policy via the registry after the Flow is constructed.
policy, err := pipeline.NewMergePolicy(requestMergePolicy, pipeline.MergePolicyDeps{
GateFactory: gateFactory,
Config: parseStringMap(*mergePolicyConfigJSON),
})
if err != nil {
setupLog.Error(err, "Failed to construct merge policy", "request-merge-policy", requestMergePolicy)
os.Exit(1)
}

ctx := ctrl.SetupSignalHandler()
metrics.Register(metrics.GetAsyncProcessorCollectors(impl.Characteristics().SupportsMessageLatency)...)

// Register metrics handler.
// Metrics endpoint is enabled in 'config/default/kustomization.yaml'. The Metrics options configure the server.
Expand Down Expand Up @@ -170,11 +192,33 @@ func main() {
inferenceHTTPClient := &http.Client{Transport: inferenceTransport}
inferenceClient := asyncworker.NewHTTPInferenceClient(inferenceHTTPClient)

requestChannel := policy.MergeRequestChannels(impl.RequestChannels()).Channel
for w := 1; w <= concurrency; w++ {

go asyncworker.Worker(ctx, impl.Characteristics(), inferenceClient, requestChannel, impl.RetryChannel(), impl.ResultChannel(), requestTimeout)
// Per-pool dispatch: the merge policy fans subscriptions into one
// channel per inference pool. Each pool gets its own dedicated
// worker pool so backpressure on one pool's downstream endpoint
// stays local — a saturated pool's workers and prefetch stall
// without affecting any other pool's throughput.
dispatch := policy.MergeRequestChannels(impl.RequestChannels())
poolByID := map[string]pipeline.Pool{}
for _, p := range impl.Pools() {
poolByID[p.ID] = p
}
totalWorkers := 0
for poolID, ch := range dispatch.Channels {
pool := poolByID[poolID]
workers := pool.Workers
if workers <= 0 {
workers = concurrency
}
totalWorkers += workers
setupLog.Info("Spawning per-pool worker pool",
"poolID", poolID,
"gatewayURL", pool.GatewayURL,
"workers", workers)
for w := 0; w < workers; w++ {
go asyncworker.Worker(ctx, impl.Characteristics(), inferenceClient, ch, impl.RetryChannel(), impl.ResultChannel(), requestTimeout)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Worker goroutines are fire-and-forget: workers are spawned with a bare `go asyncworker.Worker(...)` and nothing waits for them to finish — the old WaitGroup was removed. After `<-ctx.Done()` returns (line 224), main() exits immediately, so in-flight requests are abandoned mid-dispatch. This can lose results or leave partial state in Redis.

Add a sync.WaitGroup (or similar) so the shutdown path waits for all workers to drain before exiting.

}
}
setupLog.Info("Per-pool worker pools started", "pools", len(dispatch.Channels), "totalWorkers", totalWorkers)

impl.Start(ctx)
<-ctx.Done()
Expand Down Expand Up @@ -221,6 +265,20 @@ func buildTLSConfig(caCertPath, certPath, keyPath string, insecureSkipVerify boo
return tlsConfig, nil
}

// parseStringMap decodes s, a JSON object whose values are all strings,
// into a map[string]string.
//
// The empty string and the literal "{}" short-circuit to a fresh empty
// (non-nil) map without invoking the decoder. Any input that
// json.Unmarshal rejects produces a nil map; the decode error itself is
// deliberately dropped. Indexing a nil map is safe in Go, so callers that
// only look up individual keys observe absent entries either way.
func parseStringMap(s string) map[string]string {
	switch s {
	case "", "{}":
		// Fast path: nothing to decode.
		return map[string]string{}
	}

	parsed := make(map[string]string)
	if json.Unmarshal([]byte(s), &parsed) != nil {
		// Malformed JSON or non-string values: report "no config" as nil.
		return nil
	}
	return parsed
}

func printAllFlags(setupLog logr.Logger) {
flags := make(map[string]any)
flag.VisitAll(func(f *flag.Flag) {
Expand Down
76 changes: 0 additions & 76 deletions docs/dispatch-budget.md

This file was deleted.

Loading