Commit f9a548c

fix: make retry dequeue atomic with Lua script (#108)
Signed-off-by: Raymond Zhao <zhaoeryi@gmail.com>
1 parent aae7fe8 commit f9a548c

2 files changed: 225 additions & 26 deletions

pkg/redis/redisimpl.go

Lines changed: 82 additions & 26 deletions
@@ -6,8 +6,6 @@ import (
 	"flag"
 	"fmt"
 	"os"
-
-	"strconv"
 	"time"

 	"github.com/llm-d-incubation/llm-d-async/pkg/async/api"
@@ -41,6 +39,31 @@ var (
 	queuesConfigFile = flag.String("redis.queues-config-file", "", "Queues Configuration file. Mutually exclusive with redis.igw-base-url, redis.request-queue-name, redis.request-path-url and redis.inference-objective flags. See documentation about syntax")
 )

+const retryPopBatchSize = 100
+
+// popDueRetryMessagesScript atomically fetches due retry entries (score <= now) and removes them.
+var popDueRetryMessagesScript = redis.NewScript(`
+local key = KEYS[1]
+local now = tonumber(ARGV[1])
+local limit = tonumber(ARGV[2])
+
+local items = redis.call("ZRANGEBYSCORE", key, "-inf", now, "LIMIT", 0, limit)
+if #items > 0 then
+	-- Chunk ZREM arguments to avoid Lua unpack stack limits if
+	-- limit is increased significantly in the future.
+	local chunk_size = 1000
+	for i = 1, #items, chunk_size do
+		local last = math.min(i + chunk_size - 1, #items)
+		local chunk = {}
+		for j = i, last do
+			chunk[#chunk + 1] = items[j]
+		end
+		redis.call("ZREM", key, unpack(chunk))
+	end
+end
+return items
+`)
+
 type QueueConfig struct {
 	QueueName          string `json:"queue_name"`
 	InferenceObjective string `json:"inference_objective"`
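
The script assumes the retry sorted set's convention: each member is a JSON-encoded api.RequestMessage and its score is the Unix timestamp at which the retry becomes due. A minimal sketch of the producer side under that assumption (scheduleRetry and the backoff parameter are illustrative, not part of this commit; imports match what redisimpl.go already uses):

func scheduleRetry(ctx context.Context, rdb *redis.Client, key string, msg api.RequestMessage, backoff time.Duration) error {
	payload, err := json.Marshal(msg)
	if err != nil {
		return err
	}
	// Score is the due time; popDueRetryMessagesScript pops members with score <= now.
	return rdb.ZAdd(ctx, key, redis.Z{
		Score:  float64(time.Now().Add(backoff).Unix()),
		Member: string(payload),
	}).Err()
}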
@@ -243,6 +266,7 @@ func addMsgToRetryWorker(ctx context.Context, rdb *redis.Client, retryChannel ch

 // Every second polls the sorted set and publishes the messages that need to be retried into the request queue
 func (r *RedisMQFlow) retryWorker(ctx context.Context, rdb *redis.Client) {
+	logger := log.FromContext(ctx)
 	// create a map of queuename to channel based on requestchannels
 	msgChannels := make(map[string]chan api.RequestMessage)
 	for _, channelData := range r.requestChannels {
@@ -255,38 +279,70 @@ func (r *RedisMQFlow) retryWorker(ctx context.Context, rdb *redis.Client) {
 			return

 		default:
-			currentTimeSec := float64(time.Now().Unix())
-
-			results, err := rdb.ZRangeArgs(ctx, redis.ZRangeArgs{
-				Key:     *retryQueueName,
-				Start:   "0",
-				Stop:    strconv.FormatFloat(currentTimeSec, 'f', -1, 64),
-				ByScore: true,
-			}).Result()
-			if err != nil {
-				panic(err)
-			}
-			for _, msg := range results {
-				var message api.RequestMessage
-				err := json.Unmarshal([]byte(msg), &message)
-				if err != nil {
-					fmt.Println(err)
+			// Keep one fixed cutoff for this drain cycle so we only process
+			// messages due at cycle start, avoiding an ever-expanding window.
+			currentTimeSec := time.Now().Unix()

-				}
-				err = rdb.ZRem(ctx, *retryQueueName, msg).Err()
+			for {
+				results, err := popDueRetryMessages(ctx, rdb, *retryQueueName, currentTimeSec, retryPopBatchSize)
 				if err != nil {
-					fmt.Println(err)
-
+					logger.V(logutil.DEFAULT).Error(err, "Failed to atomically pop due retry messages")
+					break
+				}
+				if len(results) == 0 {
+					break
 				}
-				queueName := message.Metadata[QUEUE_NAME_KEY]

-				// TODO: We probably want to write here back to the request queue/channel in Redis. Adding the msg to the
-				// golang channel directly is not that wise as this might be blocking.
-				msgChannels[queueName] <- message
+				for _, msg := range results {
+					var message api.RequestMessage
+					err := json.Unmarshal([]byte(msg), &message)
+					if err != nil {
+						logger.V(logutil.DEFAULT).Error(err, "Failed to unmarshal retry message")
+						continue
+					}
+					queueName := message.Metadata[QUEUE_NAME_KEY]
+					msgChannel, ok := msgChannels[queueName]
+					if !ok {
+						logger.V(logutil.DEFAULT).Info("Unknown retry queue, dropping message", "queueName", queueName, "messageId", message.Id)
+						continue
+					}
+
+					// TODO: We probably want to write here back to the request queue/channel in Redis. Adding the msg to the
+					// golang channel directly is not that wise as this might be blocking.
+					select {
+					case msgChannel <- message:
+					case <-ctx.Done():
+						return
+					}
+				}
 			}
 			time.Sleep(time.Second)
 		}
 	}

 }

+// popDueRetryMessages atomically pops up to limit retry messages whose score is <= nowUnixSec.
+// It returns the raw message payloads removed from the sorted set.
+func popDueRetryMessages(ctx context.Context, rdb *redis.Client, key string, nowUnixSec int64, limit int) ([]string, error) {
+	raw, err := popDueRetryMessagesScript.Run(ctx, rdb, []string{key}, nowUnixSec, limit).Result()
+	if err != nil {
+		return nil, err
+	}
+
+	entries, ok := raw.([]interface{})
+	if !ok {
+		return nil, fmt.Errorf("unexpected script result type: %T", raw)
+	}
+
+	messages := make([]string, 0, len(entries))
+	for _, entry := range entries {
+		msg, ok := entry.(string)
+		if !ok {
+			return nil, fmt.Errorf("unexpected script entry type: %T", entry)
+		}
+		messages = append(messages, msg)
+	}
+
+	return messages, nil
+}
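
Why the Lua script fixes the race: Redis executes a script to completion before serving other commands, so the ZRANGEBYSCORE read and the ZREM delete can no longer interleave between two callers; the old two-step flow let a second worker (or a crash between the steps) deliver the same message twice. Note also that go-redis's Script.Run first attempts EVALSHA and falls back to EVAL when the script is not yet cached on the server. A sketch of draining everything due at one fixed cutoff with the new helper (illustrative, mirroring the loop in retryWorker above):

cutoff := time.Now().Unix()
for {
	batch, err := popDueRetryMessages(ctx, rdb, *retryQueueName, cutoff, retryPopBatchSize)
	if err != nil {
		break // retryWorker logs the error and retries on the next cycle
	}
	if len(batch) == 0 {
		break // nothing more is due at this cutoff
	}
	// Each element of batch is a JSON-encoded api.RequestMessage.
}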

pkg/redis/redisimpl_test.go

Lines changed: 143 additions & 0 deletions
@@ -240,3 +240,146 @@ func TestPubsubResultWorker_ConcurrentProducers(t *testing.T) {
 		}
 	}
 }
+
+func TestPopDueRetryMessages_PopsDueAndRemovesFromSortedSet(t *testing.T) {
+	s := miniredis.RunT(t)
+	defer s.Close()
+	rdb := redis.NewClient(&redis.Options{Addr: s.Addr()})
+	defer rdb.Close() // nolint:errcheck
+
+	ctx := context.Background()
+	queue := "retry-pop-test"
+	now := time.Now().Unix()
+
+	due := api.RequestMessage{
+		Id:       "due",
+		Metadata: map[string]string{QUEUE_NAME_KEY: "request-queue"},
+	}
+	future := api.RequestMessage{
+		Id:       "future",
+		Metadata: map[string]string{QUEUE_NAME_KEY: "request-queue"},
+	}
+
+	dueBytes, err := json.Marshal(due)
+	if err != nil {
+		t.Fatalf("marshal due message: %v", err)
+	}
+	futureBytes, err := json.Marshal(future)
+	if err != nil {
+		t.Fatalf("marshal future message: %v", err)
+	}
+
+	if err := rdb.ZAdd(ctx, queue,
+		redis.Z{Score: float64(now - 1), Member: string(dueBytes)},
+		redis.Z{Score: float64(now + 60), Member: string(futureBytes)},
+	).Err(); err != nil {
+		t.Fatalf("seed retry sorted set: %v", err)
+	}
+
+	items, err := popDueRetryMessages(ctx, rdb, queue, now, 10)
+	if err != nil {
+		t.Fatalf("pop due messages: %v", err)
+	}
+	if len(items) != 1 {
+		t.Fatalf("expected exactly one popped message, got %d", len(items))
+	}
+
+	var popped api.RequestMessage
+	if err := json.Unmarshal([]byte(items[0]), &popped); err != nil {
+		t.Fatalf("unmarshal popped message: %v", err)
+	}
+	if popped.Id != "due" {
+		t.Fatalf("expected popped message id 'due', got %q", popped.Id)
+	}
+
+	remaining, err := rdb.ZCard(ctx, queue).Result()
+	if err != nil {
+		t.Fatalf("read remaining queue size: %v", err)
+	}
+	if remaining != 1 {
+		t.Fatalf("expected one remaining future message, got %d", remaining)
+	}
+}
+
+func TestPopDueRetryMessages_ConcurrentCallers_NoDuplicatePops(t *testing.T) {
+	s := miniredis.RunT(t)
+	defer s.Close()
+	rdb := redis.NewClient(&redis.Options{Addr: s.Addr()})
+	defer rdb.Close() // nolint:errcheck
+
+	ctx := context.Background()
+	queue := "retry-pop-concurrent-test"
+	now := time.Now().Unix()
+	totalMessages := 40
+
+	for i := 0; i < totalMessages; i++ {
+		msg := api.RequestMessage{
+			Id:       "msg-" + strconv.Itoa(i),
+			Metadata: map[string]string{QUEUE_NAME_KEY: "request-queue"},
+		}
+		msgBytes, err := json.Marshal(msg)
+		if err != nil {
+			t.Fatalf("marshal seed message %d: %v", i, err)
+		}
+		if err := rdb.ZAdd(ctx, queue, redis.Z{
+			Score:  float64(now),
+			Member: string(msgBytes),
+		}).Err(); err != nil {
+			t.Fatalf("seed retry queue: %v", err)
+		}
+	}
+
+	var (
+		wg     sync.WaitGroup
+		mu     sync.Mutex
+		seenID = make(map[string]int, totalMessages)
+	)
+
+	workerCount := 4
+	for i := 0; i < workerCount; i++ {
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			for {
+				items, err := popDueRetryMessages(ctx, rdb, queue, now, 3)
+				if err != nil {
+					t.Errorf("pop due retry messages: %v", err)
+					return
+				}
+				if len(items) == 0 {
+					return
+				}
+
+				for _, raw := range items {
+					var msg api.RequestMessage
+					if err := json.Unmarshal([]byte(raw), &msg); err != nil {
+						t.Errorf("unmarshal popped message: %v", err)
+						return
+					}
+					mu.Lock()
+					seenID[msg.Id]++
+					mu.Unlock()
+				}
+			}
+		}()
+	}
+	wg.Wait()
+
+	if len(seenID) != totalMessages {
+		t.Fatalf("expected %d unique popped messages, got %d", totalMessages, len(seenID))
+	}
+
+	for id, count := range seenID {
+		if count != 1 {
+			t.Fatalf("message %s popped %d times, expected exactly once", id, count)
+		}
+	}
+
+	remaining, err := rdb.ZCard(ctx, queue).Result()
+	if err != nil {
+		t.Fatalf("read remaining queue size: %v", err)
+	}
+	if remaining != 0 {
+		t.Fatalf("expected queue to be empty after concurrent pops, got %d", remaining)
+	}
+}
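
Both tests share the same setup: an in-process miniredis, which implements enough of Redis (including Lua scripting) for popDueRetryMessagesScript to run without a real server, plus a go-redis client pointed at it. A hypothetical helper (not part of this diff) that could factor that out:

func newTestRedis(t *testing.T) *redis.Client {
	t.Helper()
	// RunT starts an in-process miniredis and registers shutdown via t.Cleanup.
	s := miniredis.RunT(t)
	rdb := redis.NewClient(&redis.Options{Addr: s.Addr()})
	t.Cleanup(func() { _ = rdb.Close() })
	return rdb
}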
