Skip to content

Commit 949671e

Browse files
authored
client: HTTP timeout must be larger than the consume timeout (#233)
1 parent 03f4aec commit 949671e

File tree

3 files changed

+207
-6
lines changed

3 files changed

+207
-6
lines changed

client/client.go

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,45 @@ func (c *LmstfyClient) ConfigRetry(retryCount int, backOffMillisecond int) {
9999
c.backOff = backOffMillisecond
100100
}
101101

102+
// validateConsumeTimeout validates that the consume timeout is less than HTTP client timeout
103+
func (c *LmstfyClient) validateConsumeTimeout(timeoutSecond uint32) *APIError {
104+
if timeoutSecond == 0 {
105+
return nil
106+
}
107+
108+
httpTimeout := c.getHTTPTimeout()
109+
if httpTimeout <= time.Duration(timeoutSecond)*time.Second {
110+
return &APIError{
111+
Type: RequestErr,
112+
Reason: fmt.Sprintf("consume timeout (%d seconds) must be less than HTTP client timeout (%d seconds)",
113+
timeoutSecond, int(httpTimeout.Seconds())),
114+
}
115+
}
116+
return nil
117+
}
118+
119+
// getHTTPTimeout returns the HTTP client timeout duration
120+
func (c *LmstfyClient) getHTTPTimeout() time.Duration {
121+
if c.httpCli == nil {
122+
return maxReadTimeout * time.Second
123+
}
124+
125+
// Check the client timeout
126+
if c.httpCli.Timeout > 0 {
127+
return c.httpCli.Timeout
128+
}
129+
130+
// Try to get the timeout from the transport and client
131+
if transport, ok := c.httpCli.Transport.(*http.Transport); ok {
132+
if transport.ResponseHeaderTimeout > 0 {
133+
return transport.ResponseHeaderTimeout
134+
}
135+
}
136+
137+
// Default to maxReadTimeout
138+
return maxReadTimeout * time.Second
139+
}
140+
102141
func (c *LmstfyClient) getReq(method, relativePath string, query url.Values, body []byte) (req *http.Request, err error) {
103142
targetUrl := url.URL{
104143
Scheme: c.scheme,
@@ -242,7 +281,9 @@ RETRY:
242281
// - ttlSecond is the time-to-live of the job. If it's zero, job won't expire; if it's positive, the value is the TTL.
243282
// - tries is the maximum times the job can be fetched.
244283
// - delaySecond is the duration before the job is released for consuming. When it's zero, no delay is applied.
245-
func (c *LmstfyClient) BatchPublish(queue string, jobs []interface{}, ttlSecond uint32, tries uint16, delaySecond uint32) (jobIDs []string, e error) {
284+
func (c *LmstfyClient) BatchPublish(queue string, jobs []interface{}, ttlSecond uint32,
285+
tries uint16, delaySecond uint32,
286+
) (jobIDs []string, e error) {
246287
query := url.Values{}
247288
query.Add("ttl", strconv.FormatUint(uint64(ttlSecond), 10))
248289
query.Add("tries", strconv.FormatUint(uint64(tries), 10))
@@ -363,6 +404,11 @@ func (c *LmstfyClient) consume(queue string, ttrSecond, timeoutSecond uint32, fr
363404
Reason: fmt.Sprintf("timeout should be < %d", maxReadTimeout),
364405
}
365406
}
407+
408+
// Check if HTTP client timeout is larger than consume timeout
409+
if err := c.validateConsumeTimeout(timeoutSecond); err != nil {
410+
return nil, err
411+
}
366412
query := url.Values{}
367413
query.Add("ttr", strconv.FormatUint(uint64(ttrSecond), 10))
368414
query.Add("timeout", strconv.FormatUint(uint64(timeoutSecond), 10))
@@ -436,7 +482,9 @@ func (c *LmstfyClient) BatchConsume(queues []string, count, ttrSecond, timeoutSe
436482
// these job will be released for consuming again if the `(tries - 1) > 0`.
437483
// - count is the job count of this consume. If it's zero or over 100, this method will return an error.
438484
// If it's positive, this method would return some jobs, and it's count is between 0 and count.
439-
func (c *LmstfyClient) BatchConsumeWithFreezeTries(queues []string, count, ttrSecond, timeoutSecond uint32) (jobs []*Job, e error) {
485+
func (c *LmstfyClient) BatchConsumeWithFreezeTries(queues []string,
486+
count, ttrSecond, timeoutSecond uint32,
487+
) (jobs []*Job, e error) {
440488
return c.batchConsume(queues, count, ttrSecond, timeoutSecond, true)
441489
}
442490

@@ -473,6 +521,11 @@ func (c *LmstfyClient) batchConsume(queues []string, count, ttrSecond, timeoutSe
473521
}
474522
}
475523

524+
// Check if HTTP client timeout is larger than consume timeout
525+
if err := c.validateConsumeTimeout(timeoutSecond); err != nil {
526+
return nil, err
527+
}
528+
476529
query := url.Values{}
477530
query.Add("ttr", strconv.FormatUint(uint64(ttrSecond), 10))
478531
query.Add("count", strconv.FormatUint(uint64(count), 10))
@@ -569,6 +622,11 @@ func (c *LmstfyClient) consumeFromQueues(ttrSecond, timeoutSecond uint32, freeze
569622
Reason: fmt.Sprintf("timeout must be < %d when fetch from multiple queues", maxReadTimeout),
570623
}
571624
}
625+
626+
// Check if HTTP client timeout is larger than consume timeout
627+
if err := c.validateConsumeTimeout(timeoutSecond); err != nil {
628+
return nil, err
629+
}
572630
query := url.Values{}
573631
query.Add("ttr", strconv.FormatUint(uint64(ttrSecond), 10))
574632
query.Add("timeout", strconv.FormatUint(uint64(timeoutSecond), 10))

client/client_test.go

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ package client
33
import (
44
"bytes"
55
"encoding/json"
6+
"errors"
7+
"net/http"
68
"testing"
79
"time"
810

@@ -307,3 +309,138 @@ func TestLmstfyClient_DeleteDeadLetter(t *testing.T) {
307309
t.Fatal("delete deadletter failed")
308310
}
309311
}
312+
313+
func TestLmstfyClient_ValidateConsumeTimeout(t *testing.T) {
314+
// Test with default HTTP client (600 seconds timeout)
315+
cli := NewLmstfyClient(Host, Port, Namespace, Token)
316+
317+
// Test valid timeout (should return nil)
318+
err := cli.validateConsumeTimeout(300)
319+
if err != nil {
320+
t.Fatalf("Expected nil for valid timeout, got: %v", err)
321+
}
322+
323+
// Test zero timeout (should return nil)
324+
err = cli.validateConsumeTimeout(0)
325+
if err != nil {
326+
t.Fatalf("Expected nil for zero timeout, got: %v", err)
327+
}
328+
329+
// Test invalid timeout (should return error)
330+
err = cli.validateConsumeTimeout(600)
331+
if err == nil {
332+
t.Fatal("Expected error for timeout >= HTTP timeout")
333+
}
334+
expectedMsg := "consume timeout (600 seconds) must be less than HTTP client timeout (600 seconds)"
335+
var apiErr *APIError
336+
if !errors.As(err, &apiErr) || apiErr.Reason != expectedMsg {
337+
t.Fatalf("Expected error message '%s', got '%v'", expectedMsg, err)
338+
}
339+
}
340+
341+
func TestLmstfyClient_ConsumeWithTimeoutValidation(t *testing.T) {
342+
// Test with custom HTTP client with short timeout
343+
shortHTTPClient := &http.Client{
344+
Timeout: 5 * time.Second,
345+
}
346+
cli := NewLmstfyWithClient(shortHTTPClient, Host, Port, Namespace, Token)
347+
348+
// Test consume with valid timeout
349+
job, err := cli.Consume("test-timeout-validation", 10, 3)
350+
if err != nil {
351+
t.Fatalf("Consume should succeed with valid timeout: %v", err)
352+
}
353+
if job != nil {
354+
cli.Ack("test-timeout-validation", job.ID)
355+
}
356+
357+
// Test consume with invalid timeout (should fail)
358+
_, err = cli.Consume("test-timeout-validation", 10, 6)
359+
if err == nil {
360+
t.Fatal("Consume should fail with timeout >= HTTP timeout")
361+
}
362+
expectedMsg := "consume timeout (6 seconds) must be less than HTTP client timeout (5 seconds)"
363+
var apiErr *APIError
364+
if !errors.As(err, &apiErr) || apiErr.Reason != expectedMsg {
365+
t.Fatalf("Expected error message '%s', got '%v'", expectedMsg, err)
366+
}
367+
}
368+
369+
func TestLmstfyClient_BatchConsumeWithTimeoutValidation(t *testing.T) {
370+
// Test with custom HTTP client with short timeout
371+
shortHTTPClient := &http.Client{
372+
Timeout: 5 * time.Second,
373+
}
374+
cli := NewLmstfyWithClient(shortHTTPClient, Host, Port, Namespace, Token)
375+
376+
// Test batch consume with valid timeout
377+
queues := []string{"test-batch-timeout-validation"}
378+
jobs, err := cli.BatchConsume(queues, 3, 10, 3)
379+
if err != nil {
380+
t.Fatalf("BatchConsume should succeed with valid timeout: %v", err)
381+
}
382+
383+
// Ack any jobs that were returned
384+
for _, job := range jobs {
385+
cli.Ack("test-batch-timeout-validation", job.ID)
386+
}
387+
388+
// Test batch consume with invalid timeout (should fail)
389+
_, err = cli.BatchConsume(queues, 3, 10, 6)
390+
if err == nil {
391+
t.Fatal("BatchConsume should fail with timeout >= HTTP timeout")
392+
}
393+
expectedMsg := "consume timeout (6 seconds) must be less than HTTP client timeout (5 seconds)"
394+
var apiErr *APIError
395+
if !errors.As(err, &apiErr) || apiErr.Reason != expectedMsg {
396+
t.Fatalf("Expected error message '%s', got '%v'", expectedMsg, err)
397+
}
398+
}
399+
400+
func TestLmstfyClient_ConsumeFromQueuesWithTimeoutValidation(t *testing.T) {
401+
// Test with custom HTTP client with short timeout
402+
shortHTTPClient := &http.Client{
403+
Timeout: 5 * time.Second,
404+
}
405+
cli := NewLmstfyWithClient(shortHTTPClient, Host, Port, Namespace, Token)
406+
407+
// Test consume from queues with valid timeout
408+
job, err := cli.ConsumeFromQueues(10, 3, "test-multi-queue-timeout-validation")
409+
if err != nil {
410+
t.Fatalf("ConsumeFromQueues should succeed with valid timeout: %v", err)
411+
}
412+
if job != nil {
413+
cli.Ack("test-multi-queue-timeout-validation", job.ID)
414+
}
415+
416+
// Test consume from queues with invalid timeout (should fail)
417+
_, err = cli.ConsumeFromQueues(10, 6, "test-multi-queue-timeout-validation")
418+
if err == nil {
419+
t.Fatal("ConsumeFromQueues should fail with timeout >= HTTP timeout")
420+
}
421+
expectedMsg := "consume timeout (6 seconds) must be less than HTTP client timeout (5 seconds)"
422+
var apiErr *APIError
423+
if !errors.As(err, &apiErr) || apiErr.Reason != expectedMsg {
424+
t.Fatalf("Expected error message '%s', got '%v'", expectedMsg, err)
425+
}
426+
}
427+
428+
func TestLmstfyClient_GetHTTPTimeout(t *testing.T) {
429+
// Test with default HTTP client
430+
cli := NewLmstfyClient(Host, Port, Namespace, Token)
431+
timeout := cli.getHTTPTimeout()
432+
if timeout != 600*time.Second {
433+
t.Fatalf("Expected 600 seconds timeout for default client, got %v", timeout)
434+
}
435+
436+
// Test with custom HTTP client
437+
customTimeout := 30 * time.Second
438+
customClient := &http.Client{
439+
Timeout: customTimeout,
440+
}
441+
cli2 := NewLmstfyWithClient(customClient, Host, Port, Namespace, Token)
442+
timeout = cli2.getHTTPTimeout()
443+
if timeout != customTimeout {
444+
t.Fatalf("Expected %v timeout for custom client, got %v", customTimeout, timeout)
445+
}
446+
}

engine/redis/queue.go

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,9 @@ func (q *Queue) Destroy() (count int64, err error) {
115115
poolPrefix := PoolJobKeyPrefix(q.name.Namespace, q.name.Queue)
116116
var batchSize int64 = 100
117117
for {
118-
val, err := q.redis.Conn.EvalSha(dummyCtx, q.destroySHA, []string{q.Name(), poolPrefix}, batchSize).Result()
118+
val, err := q.redis.Conn.EvalSha(dummyCtx, q.destroySHA, []string{
119+
q.Name(), poolPrefix,
120+
}, batchSize).Result()
119121
if err != nil {
120122
if isLuaScriptGone(err) {
121123
if err := PreloadDeadLetterLuaScript(q.redis); err != nil {
@@ -174,7 +176,9 @@ func popMultiQueues(redis *RedisInstance, queueNames []string) (string, string,
174176
}
175177

176178
// Poll from multiple queues using blocking method; OR pop a job from one queue using non-blocking method
177-
func PollQueues(redis *RedisInstance, timer *Timer, queueNames []QueueName, timeoutSecond, ttrSecond uint32) (queueName *QueueName, jobID string, retries uint16, err error) {
179+
func PollQueues(redis *RedisInstance, timer *Timer, queueNames []QueueName,
180+
timeoutSecond, ttrSecond uint32,
181+
) (queueName *QueueName, jobID string, retries uint16, err error) {
178182
defer func() {
179183
if jobID != "" {
180184
metrics.queuePopJobs.WithLabelValues(redis.Name).Inc()
@@ -236,7 +240,8 @@ func PollQueues(redis *RedisInstance, timer *Timer, queueNames []QueueName, time
236240
}
237241

238242
// Pack (tries, jobID) into lua struct pack of format "HHHc0", in lua this can be done:
239-
// ```local data = struct.pack("HHc0", tries, #job_id, job_id)```
243+
//
244+
// ```local data = struct.pack("HHc0", tries, #job_id, job_id)```
240245
func structPack(tries uint16, jobID string) (data string) {
241246
buf := make([]byte, 2+2+len(jobID))
242247
binary.LittleEndian.PutUint16(buf[0:], tries)
@@ -246,7 +251,8 @@ func structPack(tries uint16, jobID string) (data string) {
246251
}
247252

248253
// Unpack the "HHc0" lua struct format, in lua this can be done:
249-
// ```local tries, job_id = struct.unpack("HHc0", data)```
254+
//
255+
// ```local tries, job_id = struct.unpack("HHc0", data)```
250256
func structUnpack(data string) (tries uint16, jobID string, err error) {
251257
buf := []byte(data)
252258
h1 := binary.LittleEndian.Uint16(buf[0:])

0 commit comments

Comments
 (0)