Skip to content

Commit 4aba1cb

Browse files
committed
Add worker pool-level gates
Introduces a pool-level admission control gating mechanism. This allows worker pools to block/park incoming requests in-memory when capacity is reached, preventing expensive broker nack-and-retry cycles. Signed-off-by: Jacob Murry <jacobmurry@google.com>
1 parent ab5d128 commit 4aba1cb

7 files changed

Lines changed: 466 additions & 15 deletions

File tree

cmd/main.go

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,19 @@ func main() {
287287

288288
dispatch := policy.MergeRequestChannels(impl.RequestChannels(), poolsMap)
289289

290+
poolGates := make(map[string]pipeline.Gate)
291+
for poolID, pool := range poolsMap {
292+
if pool.GateType != "" {
293+
gate, err := gateFactory.CreateGate(pool.GateType, pool.GateParams)
294+
if err != nil {
295+
setupLog.Error(err, "Failed to create pool gate", "poolID", poolID, "gateType", pool.GateType)
296+
os.Exit(1)
297+
}
298+
poolGates[poolID] = gate
299+
setupLog.Info("Created pool gate", "poolID", poolID, "gateType", pool.GateType, "gateParams", pool.GateParams)
300+
}
301+
}
302+
290303
var wg sync.WaitGroup
291304
for poolID, mergedChan := range dispatch.Channels {
292305
pool, ok := poolsMap[poolID]
@@ -295,14 +308,15 @@ func main() {
295308
os.Exit(1)
296309
}
297310
workersCount := pool.Workers
311+
poolGate := poolGates[poolID]
298312

299-
setupLog.Info("Spawning workers for pool", "poolID", poolID, "workers", workersCount)
313+
setupLog.Info("Spawning workers for pool", "poolID", poolID, "workers", workersCount, "hasGate", poolGate != nil)
300314
for w := 1; w <= workersCount; w++ {
301315
wg.Add(1)
302-
go func(mergedChan chan pipeline.EmbelishedRequestMessage) {
316+
go func(mergedChan chan pipeline.EmbelishedRequestMessage, poolGate pipeline.Gate) {
303317
defer wg.Done()
304-
asyncworker.Worker(signalCtx, drainCtx, impl.Characteristics(), inferenceClient, mergedChan, impl.RetryChannel(), impl.ResultChannel(), requestTimeout, transforms)
305-
}(mergedChan)
318+
asyncworker.WorkerWithGate(signalCtx, drainCtx, impl.Characteristics(), inferenceClient, mergedChan, impl.RetryChannel(), impl.ResultChannel(), requestTimeout, transforms, poolGate)
319+
}(mergedChan, poolGate)
306320
}
307321
}
308322

pipeline/worker_pool.go

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,44 @@ import (
44
"encoding/json"
55
"fmt"
66
"os"
7+
"strconv"
78
)
89

10+
// StringMap is a map[string]string that tolerates non-string JSON values
// by converting them to their string representation during unmarshaling.
//
// Strings are kept as-is, booleans become "true"/"false", null becomes "",
// and numbers are rendered in plain decimal form. Nested objects and arrays
// are rejected with an error.
type StringMap map[string]string

// UnmarshalJSON implements json.Unmarshaler for StringMap.
//
// Integer literals that fit in an int64 are stored using their exact source
// text. Decoding them through float64 alone would silently round values
// above 2^53 (e.g. 9007199254740993 would become "9007199254740992").
// All other numbers keep the shortest plain-decimal float formatting.
func (m *StringMap) UnmarshalJSON(data []byte) error {
	// Keep each value's raw bytes so number literals can be inspected
	// before any float64 conversion takes place.
	var raw map[string]json.RawMessage
	if err := json.Unmarshal(data, &raw); err != nil {
		return err
	}
	result := make(map[string]string, len(raw))
	for k, rv := range raw {
		var v interface{}
		if err := json.Unmarshal(rv, &v); err != nil {
			return fmt.Errorf("gate_params key %q: %w", k, err)
		}
		switch val := v.(type) {
		case string:
			result[k] = val
		case float64:
			literal := string(rv)
			if _, err := strconv.ParseInt(literal, 10, 64); err == nil {
				// Exact integer literal: keep it verbatim to avoid precision loss.
				result[k] = literal
			} else {
				result[k] = strconv.FormatFloat(val, 'f', -1, 64)
			}
		case bool:
			result[k] = strconv.FormatBool(val)
		case nil:
			result[k] = ""
		default:
			return fmt.Errorf("gate_params key %q: unsupported value type %T (only strings, numbers, and booleans are allowed)", k, v)
		}
	}
	*m = result
	return nil
}
37+
938
// WorkerPoolConfig defines the configuration for a worker pool,
1039
// specifying the concurrency limit (number of workers) and its ID.
1140
type WorkerPoolConfig struct {
12-
ID string `json:"id"`
13-
Workers int `json:"workers"`
41+
ID string `json:"id"`
42+
Workers int `json:"workers"`
43+
GateType string `json:"gate_type,omitempty"`
44+
GateParams StringMap `json:"gate_params,omitempty"`
1445
}
1546

1647
// LoadWorkerPools loads and validates worker pool configurations from a JSON file.

pkg/async/inference/flowcontrol/gate_factory.go

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,15 @@ func (f *GateFactory) CreateGate(gateType string, params map[string]string) (pip
346346
if limit <= 0 {
347347
return nil, fmt.Errorf("local-max-concurrency limit must be greater than 0, got %d", limit)
348348
}
349-
return NewLocalConcurrencyGate(limit), nil
349+
gate := NewLocalConcurrencyGate(limit)
350+
gatingMode := params["gating_mode"]
351+
if gatingMode != "" {
352+
if gatingMode != string(GatingModeBlocking) && gatingMode != string(GatingModeClassifying) {
353+
return nil, fmt.Errorf("local-max-concurrency gating_mode must be either 'blocking' or 'classifying', got %q", gatingMode)
354+
}
355+
gate.WithGatingMode(GatingMode(gatingMode))
356+
}
357+
return gate, nil
350358

351359
default:
352360
// Unknown gate types default to open gate

pkg/async/inference/flowcontrol/local_concurrency_gate.go

Lines changed: 52 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,21 +10,42 @@ import (
1010

1111
var _ pipeline.Gate = (*LocalConcurrencyGate)(nil)
1212

13+
// GatingMode names the strategy a gate uses when deciding how to admit a
// request. Its values are plain strings so they can be set from configuration.
type GatingMode string

// Recognized gating modes.
const (
	// GatingModeBlocking is the "blocking" mode.
	GatingModeBlocking = GatingMode("blocking")
	// GatingModeClassifying is the "classifying" mode.
	GatingModeClassifying = GatingMode("classifying")
)
19+
1320
// LocalConcurrencyGate limits the number of concurrent in-flight requests
1421
// processed from a single queue locally.
1522
type LocalConcurrencyGate struct {
16-
mu sync.Mutex
17-
limit int
18-
inFlight int
23+
mu sync.Mutex
24+
limit int
25+
inFlight int
26+
gatingMode GatingMode
27+
sem chan struct{}
1928
}
2029

2130
// NewLocalConcurrencyGate creates a new LocalConcurrencyGate with the specified limit.
2231
func NewLocalConcurrencyGate(limit int) *LocalConcurrencyGate {
2332
return &LocalConcurrencyGate{
24-
limit: limit,
33+
limit: limit,
34+
gatingMode: GatingModeClassifying,
2535
}
2636
}
2737

38+
// WithGatingMode configures the gating mode (blocking or classifying).
39+
func (g *LocalConcurrencyGate) WithGatingMode(mode GatingMode) *LocalConcurrencyGate {
40+
g.mu.Lock()
41+
defer g.mu.Unlock()
42+
g.gatingMode = mode
43+
if mode == GatingModeBlocking && g.limit > 0 {
44+
g.sem = make(chan struct{}, g.limit)
45+
}
46+
return g
47+
}
48+
2849
// Budget implements pipeline.Gate.
2950
// Returns the fraction of available capacity in [0.0, 1.0].
3051
func (g *LocalConcurrencyGate) Budget(ctx context.Context) float64 {
@@ -44,17 +65,43 @@ func (g *LocalConcurrencyGate) Budget(ctx context.Context) float64 {
4465
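// In classifying mode, returns VerdictContinue if request fits in budget, VerdictRefuse with redeliver otherwise.
// In blocking mode (with a positive limit), waits for a free slot instead of refusing and returns ctx.Err() if ctx is done first.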
4566
func (g *LocalConcurrencyGate) Apply(ctx context.Context, msg *api.InternalRequest) (pipeline.Verdict, error) {
4667
g.mu.Lock()
47-
defer g.mu.Unlock()
4868

4969
if g.limit <= 0 {
70+
g.mu.Unlock()
5071
return pipeline.Refuse(), nil
5172
}
5273

74+
if g.gatingMode == GatingModeBlocking {
75+
sem := g.sem
76+
g.mu.Unlock()
77+
78+
select {
79+
case sem <- struct{}{}:
80+
g.mu.Lock()
81+
g.inFlight++
82+
g.mu.Unlock()
83+
84+
msg.AttachRelease(func() {
85+
<-sem
86+
g.mu.Lock()
87+
g.inFlight--
88+
g.mu.Unlock()
89+
})
90+
return pipeline.Continue(), nil
91+
case <-ctx.Done():
92+
return pipeline.Verdict{}, ctx.Err()
93+
}
94+
}
95+
96+
// Classifying/non-blocking mode
5397
if g.inFlight >= g.limit {
98+
g.mu.Unlock()
5499
return pipeline.Refuse(), nil
55100
}
56101

57102
g.inFlight++
103+
g.mu.Unlock()
104+
58105
msg.AttachRelease(func() {
59106
g.mu.Lock()
60107
g.inFlight--

pkg/async/inference/flowcontrol/local_concurrency_gate_test.go

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"sync"
66
"testing"
7+
"time"
78

89
"github.com/llm-d-incubation/llm-d-async/api"
910
"github.com/llm-d-incubation/llm-d-async/pipeline"
@@ -131,3 +132,93 @@ func TestLocalConcurrencyGate_Concurrency(t *testing.T) {
131132
// Budget should be fully recovered to 1.0
132133
assert.Equal(t, 1.0, gate.Budget(ctx))
133134
}
135+
136+
func TestLocalConcurrencyGate_BlockingMode(t *testing.T) {
137+
gate := NewLocalConcurrencyGate(2).WithGatingMode(GatingModeBlocking)
138+
ctx := context.Background()
139+
140+
// 1. Initial Budget is 1.0
141+
assert.Equal(t, 1.0, gate.Budget(ctx))
142+
143+
// 2. Admit 2 requests
144+
r1 := api.NewInternalRequest(api.InternalRouting{}, &api.RequestMessage{})
145+
v1, err := gate.Apply(ctx, r1)
146+
require.NoError(t, err)
147+
assert.Equal(t, pipeline.ActionContinue, v1.Action)
148+
149+
r2 := api.NewInternalRequest(api.InternalRouting{}, &api.RequestMessage{})
150+
v2, err := gate.Apply(ctx, r2)
151+
require.NoError(t, err)
152+
assert.Equal(t, pipeline.ActionContinue, v2.Action)
153+
154+
// Budget should now be 0.0
155+
assert.Equal(t, 0.0, gate.Budget(ctx))
156+
157+
// 3. Attempt to admit 3rd request in a separate goroutine (should block)
158+
r3 := api.NewInternalRequest(api.InternalRouting{}, &api.RequestMessage{})
159+
blockedCh := make(chan struct{})
160+
resultCh := make(chan pipeline.Verdict, 1)
161+
errCh := make(chan error, 1)
162+
163+
go func() {
164+
close(blockedCh)
165+
v, e := gate.Apply(ctx, r3)
166+
resultCh <- v
167+
errCh <- e
168+
}()
169+
170+
<-blockedCh
171+
// Sleep briefly to ensure the goroutine is indeed parked waiting on semaphore
172+
time.Sleep(50 * time.Millisecond)
173+
174+
select {
175+
case <-resultCh:
176+
t.Fatal("Apply should have blocked, but returned result immediately")
177+
default:
178+
// Passed: goroutine is blocked
179+
}
180+
181+
// 4. Release request 1
182+
r1.Release()
183+
184+
// 5. Goroutine should unblock and the request should be admitted
185+
select {
186+
case verdict := <-resultCh:
187+
require.NoError(t, <-errCh)
188+
assert.Equal(t, pipeline.ActionContinue, verdict.Action)
189+
case <-time.After(1 * time.Second):
190+
t.Fatal("timed out waiting for 3rd request to unblock")
191+
}
192+
193+
// Budget should be back to 0.0
194+
assert.Equal(t, 0.0, gate.Budget(ctx))
195+
196+
// Clean up
197+
r2.Release()
198+
r3.Release()
199+
assert.Equal(t, 1.0, gate.Budget(ctx))
200+
}
201+
202+
func TestLocalConcurrencyGate_BlockingModeCancel(t *testing.T) {
203+
gate := NewLocalConcurrencyGate(1).WithGatingMode(GatingModeBlocking)
204+
ctx := context.Background()
205+
206+
// Admit 1 request to exhaust capacity
207+
r1 := api.NewInternalRequest(api.InternalRouting{}, &api.RequestMessage{})
208+
v1, err := gate.Apply(ctx, r1)
209+
require.NoError(t, err)
210+
assert.Equal(t, pipeline.ActionContinue, v1.Action)
211+
212+
// Try to admit 2nd request with a cancelled context
213+
r2 := api.NewInternalRequest(api.InternalRouting{}, &api.RequestMessage{})
214+
cancelCtx, cancel := context.WithCancel(ctx)
215+
cancel() // cancel immediately
216+
217+
_, err = gate.Apply(cancelCtx, r2)
218+
assert.ErrorIs(t, err, context.Canceled)
219+
220+
// Clean up
221+
r1.Release()
222+
assert.Equal(t, 1.0, gate.Budget(ctx))
223+
}
224+

pkg/asyncworker/worker.go

Lines changed: 74 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,11 @@ const (
3030

3131
// Worker is the gate-less entry point: it forwards all of its arguments
// unchanged to WorkerWithGate and passes nil as the pool gate.
func Worker(consumeCtx, requestCtx context.Context, characteristics pipeline.Characteristics, client asyncapi.InferenceClient, requestChannel chan pipeline.EmbelishedRequestMessage,
3232
retryChannel chan pipeline.RetryMessage, resultChannel chan asyncapi.ResultMessage, requestTimeout time.Duration, transforms *transform.Chain) {
33+
WorkerWithGate(consumeCtx, requestCtx, characteristics, client, requestChannel, retryChannel, resultChannel, requestTimeout, transforms, nil)
34+
}
35+
36+
func WorkerWithGate(consumeCtx, requestCtx context.Context, characteristics pipeline.Characteristics, client asyncapi.InferenceClient, requestChannel chan pipeline.EmbelishedRequestMessage,
37+
retryChannel chan pipeline.RetryMessage, resultChannel chan asyncapi.ResultMessage, requestTimeout time.Duration, transforms *transform.Chain, poolGate pipeline.Gate) {
3338

3439
logger := log.FromContext(requestCtx)
3540
for {
@@ -75,9 +80,6 @@ func Worker(consumeCtx, requestCtx context.Context, characteristics pipeline.Cha
7580
}
7681

7782
processMessage := func() {
78-
metrics.IncInflight(queueID, queueName, msg.WorkerPoolID)
79-
defer metrics.DecInflight(queueID, queueName, msg.WorkerPoolID)
80-
8183
if msg.RetryCount == 0 {
8284
metrics.RecordAsyncReq(queueID, queueName, msg.WorkerPoolID)
8385
}
@@ -87,6 +89,75 @@ func Worker(consumeCtx, requestCtx context.Context, characteristics pipeline.Cha
8789
return
8890
}
8991

92+
snapshot := len(msg.Releases())
93+
defer msg.RollbackReleases(snapshot)
94+
95+
if poolGate != nil {
96+
reqDeadline := time.Now().Add(requestTimeout)
97+
if dline := msg.PublicRequest.ReqDeadline(); dline > 0 {
98+
if msgDeadline := time.Unix(dline, 0); msgDeadline.Before(reqDeadline) {
99+
reqDeadline = msgDeadline
100+
}
101+
}
102+
gateCtx, cancelGate := context.WithDeadline(requestCtx, reqDeadline)
103+
defer cancelGate()
104+
105+
var verdict pipeline.Verdict
106+
var err error
107+
for {
108+
verdict, err = poolGate.Apply(gateCtx, msg.InternalRequest)
109+
if err != nil {
110+
if errors.Is(err, context.DeadlineExceeded) || gateCtx.Err() != nil {
111+
metrics.RecordExceededDeadlineReq(queueID, queueName, msg.WorkerPoolID)
112+
select {
113+
case resultChannel <- CreateDeadlineExceededResultMessage(msg.PublicRequest, msg.InternalRouting):
114+
case <-requestCtx.Done():
115+
}
116+
return
117+
}
118+
select {
119+
case resultChannel <- CreateErrorResultMessage(msg.PublicRequest, msg.InternalRouting, fmt.Sprintf("Pool gating error: %s", err.Error())):
120+
case <-requestCtx.Done():
121+
}
122+
return
123+
}
124+
125+
if verdict.Action == pipeline.ActionContinue {
126+
break
127+
}
128+
129+
if verdict.Action == pipeline.ActionDrop {
130+
var resultMsg asyncapi.ResultMessage
131+
if verdict.Result != nil {
132+
resultMsg = *verdict.Result
133+
} else {
134+
resultMsg = CreateErrorResultMessage(msg.PublicRequest, msg.InternalRouting, "Pool gating dropped request")
135+
}
136+
select {
137+
case resultChannel <- resultMsg:
138+
case <-requestCtx.Done():
139+
}
140+
return
141+
}
142+
143+
// ActionRefuse: park/wait and retry
144+
select {
145+
case <-gateCtx.Done():
146+
metrics.RecordExceededDeadlineReq(queueID, queueName, msg.WorkerPoolID)
147+
select {
148+
case resultChannel <- CreateDeadlineExceededResultMessage(msg.PublicRequest, msg.InternalRouting):
149+
case <-requestCtx.Done():
150+
}
151+
return
152+
case <-time.After(50 * time.Millisecond):
153+
// poll again
154+
}
155+
}
156+
}
157+
158+
metrics.IncInflight(queueID, queueName, msg.WorkerPoolID)
159+
defer metrics.DecInflight(queueID, queueName, msg.WorkerPoolID)
160+
90161
sendInferenceRequest := func() {
91162
reqCtx := requestCtx
92163
if md := msg.PublicRequest.ReqMetadata(); len(md) > 0 {

0 commit comments

Comments
 (0)