diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 8464e25..58dca14 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -19,7 +19,7 @@ jobs:
redis-version: 7
- uses: namoshek/rabbitmq-github-action@v1
with:
- version: '3.8.9'
+ version: '4.1.1'
ports: '5672:5672'
- run: go test -count=1 -v ./...
lint:
diff --git a/README.md b/README.md
index 4eb15ab..a7f9fe6 100644
--- a/README.md
+++ b/README.md
@@ -103,12 +103,16 @@ $ celery --app myproject worker --queues important --loglevel=debug --without-he
Sending tasks from Python and receiving them on Go side.
```sh
-$ python producer.py --protocol=1
+$ python producer.py
$ go run ./consumer/
{"msg":"waiting for tasks..."}
received a=fizz b=bazz
```
+To send a task with Celery Protocol version 1, run *producer.py* with the *--protocol=1* command-line argument:
+```sh
+$ python producer.py --protocol=1
+```
@@ -214,6 +218,10 @@ $ go run ./consumer/
received a=fizz b=bazz
```
+To send a task with Celery Protocol version 1, run *producer.py* with the *--protocol=1* command-line argument:
+```sh
+$ python producer.py --protocol=1
+```
## Testing
diff --git a/celery_test.go b/celery_test.go
index adf987e..187fd1b 100644
--- a/celery_test.go
+++ b/celery_test.go
@@ -13,6 +13,7 @@ import (
"github.com/marselester/gopher-celery/goredis"
"github.com/marselester/gopher-celery/protocol"
+ "github.com/marselester/gopher-celery/rabbitmq"
)
func TestExecuteTaskPanic(t *testing.T) {
@@ -243,6 +244,57 @@ func TestGoredisProduceAndConsume100times(t *testing.T) {
}
}
+func TestRabbitmqProduceAndConsume100times(t *testing.T) {
+ app := NewApp(
+ WithBroker(rabbitmq.NewBroker(rabbitmq.WithAmqpUri("amqp://guest:guest@localhost:5672/"))),
+ WithLogger(log.NewJSONLogger(os.Stderr)),
+ )
+
+ queue := "rabbitmq_broker_test"
+
+ // Create the queue, if it doesn't exist.
+ app.conf.broker.Observe([]string{queue})
+
+ for i := 0; i < 100; i++ {
+ err := app.Delay(
+ "myproject.apps.myapp.tasks.mytask",
+ queue,
+ 2,
+ 3,
+ )
+ if err != nil {
+ t.Fatal(err)
+ }
+ }
+
+ // The test finishes either when ctx times out or all the tasks finish.
+ ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+ t.Cleanup(cancel)
+
+ var sum int32
+ app.Register(
+ "myproject.apps.myapp.tasks.mytask",
+ queue,
+ func(ctx context.Context, p *TaskParam) error {
+ p.NameArgs("a", "b")
+ atomic.AddInt32(
+ &sum,
+ int32(p.MustInt("a")+p.MustInt("b")),
+ )
+ return nil
+ },
+ )
+ if err := app.Run(ctx); err != nil {
+ t.Error(err)
+ }
+
+ var want int32 = 500
+ if want != sum {
+ t.Errorf("expected sum %d got %d", want, sum)
+ }
+
+}
+
func TestConsumeSequentially(t *testing.T) {
app := NewApp(
WithLogger(log.NewJSONLogger(os.Stderr)),
diff --git a/rabbitmq/broker.go b/rabbitmq/broker.go
index 12a0bb0..a101e47 100644
--- a/rabbitmq/broker.go
+++ b/rabbitmq/broker.go
@@ -3,7 +3,6 @@
package rabbitmq
import (
- "context"
"encoding/base64"
"encoding/json"
"fmt"
@@ -33,7 +32,7 @@ type Broker struct {
queues []string
conn *amqp.Connection
channel *amqp.Channel
- ctx context.Context
+ delivery map[string]<-chan amqp.Delivery
}
// WithAmqpUri sets the AMQP connection URI to RabbitMQ.
@@ -67,30 +66,27 @@ func NewBroker(options ...BrokerOption) *Broker {
amqpUri: DefaultAmqpUri,
receiveTimeout: DefaultReceiveTimeout * time.Second,
rawMode: false,
- ctx: context.Background(),
+ delivery: make(map[string]<-chan amqp.Delivery),
}
for _, opt := range options {
opt(&br)
}
if br.conn == nil {
- br.channel = nil
conn, err := amqp.Dial(br.amqpUri)
- br.conn = conn
if err != nil {
log.Panicf("Failed to connect to RabbitMQ: %s", err)
return nil
}
+ br.conn = conn
}
- if br.channel == nil {
- channel, err := br.conn.Channel()
- br.channel = channel
- if err != nil {
- log.Panicf("Failed to open a channel: %s", err)
- return nil
- }
+ channel, err := br.conn.Channel()
+ if err != nil {
+ log.Panicf("Failed to open a channel: %s", err)
+ return nil
}
+ br.channel = channel
return &br
}
@@ -135,8 +131,7 @@ func (br *Broker) Send(m []byte, q string) error {
replyTo = properties_in["reply_to"].(string)
}
- err := br.channel.PublishWithContext(
- br.ctx,
+ err := br.channel.Publish(
"", // exchange
q, // routing key
false, // mandatory
@@ -150,6 +145,7 @@ func (br *Broker) Send(m []byte, q string) error {
ReplyTo: replyTo,
Body: body,
})
+
return err
}
@@ -158,16 +154,44 @@ func (br *Broker) Send(m []byte, q string) error {
func (br *Broker) Observe(queues []string) {
br.queues = queues
for _, queue := range queues {
- _, err := br.channel.QueueDeclare(
- queue, // name
- true, // durable
- false, // delete when unused
- false, // exclusive
- false, // no-wait
- nil, // arguments
+ durable := true
+ autoDelete := false
+ exclusive := false
+ noWait := false
+
+ // Check whether the queue exists.
+ _, err := br.channel.QueueDeclarePassive(
+ queue,
+ durable,
+ autoDelete,
+ exclusive,
+ noWait,
+ nil,
)
+
+ // If the queue doesn't exist, attempt to create it.
if err != nil {
- log.Panicf("Failed to declare a queue: %s", err)
+ // QueueDeclarePassive() will close the channel if the queue does not exist, so we have to create a new channel when this happens.
+ if br.channel.IsClosed() {
+ channel, err := br.conn.Channel()
+ if err != nil {
+ log.Panicf("Failed to open a channel: %s", err)
+ }
+ br.channel = channel
+ }
+
+ _, err := br.channel.QueueDeclare(
+ queue,
+ durable,
+ autoDelete,
+ exclusive,
+ noWait,
+ nil,
+ )
+
+ if err != nil {
+ log.Panicf("Failed to declare a queue: %s", err)
+ }
}
}
}
@@ -175,70 +199,69 @@ func (br *Broker) Observe(queues []string) {
// Receive fetches a Celery task message from a tail of one of the queues in RabbitMQ.
// After a timeout it returns nil, nil.
func (br *Broker) Receive() ([]byte, error) {
+ queue := br.queues[0]
+ // Put the Celery queue name to the end of the slice for fair processing.
+ broker.Move2back(br.queues, queue)
- const retryIntervalMs = 100
+ var err error
+
+ delivery, delivery_exists := br.delivery[queue]
+ if !delivery_exists {
+ delivery, err = br.channel.Consume(
+ queue, // queue
+ "", // consumer
+ true, // autoAck
+ false, // exclusive
+ false, // noLocal (ignored)
+ false, // noWait
+ nil, // args
+ )
- try_receive := func() (msg amqp.Delivery, ok bool, err error) {
- queue := br.queues[0]
- // Put the Celery queue name to the end of the slice for fair processing.
- broker.Move2back(br.queues, queue)
- my_msg, my_ok, my_err := br.channel.Get(queue, true)
- if my_err != nil {
- log.Printf("Failed to g a message: %s", my_err)
+ if err != nil {
+ return nil, err
}
- return my_msg, my_ok, my_err
- }
- startTime := time.Now()
- timeoutTime := startTime.Add(br.receiveTimeout)
- msg, ok, err := try_receive()
- if err != nil {
- return nil, nil
+ br.delivery[queue] = delivery
}
- for !ok {
- if time.Now().After(timeoutTime) {
- return nil, nil
+
+ select {
+ case msg := <-delivery:
+ if br.rawMode {
+ return msg.Body, nil
}
- time.Sleep(retryIntervalMs * time.Millisecond)
- msg, ok, err = try_receive()
- if err != nil {
- return nil, nil
+ // Marshal msg from RabbitMQ Celery format to internal Celery format.
+
+ properties := make(map[string]interface{})
+ properties["correlation_id"] = msg.CorrelationId
+ properties["reply_to"] = msg.ReplyTo
+ properties["delivery_mode"] = msg.DeliveryMode
+ properties["delivery_info"] = map[string]interface{}{
+ "exchange": msg.Exchange,
+ "routing_key": msg.RoutingKey,
}
- }
+ properties["priority"] = msg.Priority
+ properties["body_encoding"] = "base64"
+ properties["delivery_tag"] = msg.DeliveryTag
- if br.rawMode {
- return msg.Body, nil
- }
+ imsg := make(map[string]interface{})
+ imsg["body"] = msg.Body
+ imsg["content-encoding"] = msg.ContentEncoding
+ imsg["content-type"] = msg.ContentType
+ imsg["headers"] = msg.Headers
+ imsg["properties"] = properties
- // Marshal msg from RabbitMQ Celery format to internal Celery format.
-
- properties := make(map[string]interface{})
- properties["correlation_id"] = msg.CorrelationId
- properties["reply_to"] = msg.ReplyTo
- properties["delivery_mode"] = msg.DeliveryMode
- delivery_info := make(map[string]interface{})
- properties["delivery_info"] = delivery_info
- delivery_info["exchange"] = msg.Exchange
- delivery_info["routing_key"] = msg.RoutingKey
- properties["priority"] = msg.Priority
- properties["body_encoding"] = "base64"
- properties["delivery_tag"] = msg.DeliveryTag
-
- imsg := make(map[string]interface{})
- imsg["body"] = msg.Body
- imsg["content-encoding"] = msg.ContentEncoding
- imsg["content-type"] = msg.ContentType
- imsg["headers"] = msg.Headers
- imsg["properties"] = properties
-
- var result []byte
- result, err = json.Marshal(imsg)
- if err != nil {
- err_str := fmt.Errorf("%w", err)
- log.Printf("json encode: %s", err_str)
+ var result []byte
+ result, err := json.Marshal(imsg)
+ if err != nil {
+ err_str := fmt.Errorf("%w", err)
+ log.Printf("json encode: %s", err_str)
+ return nil, err
+ }
+ return result, nil
+
+ case <-time.After(br.receiveTimeout):
+ // Receive timeout
return nil, nil
}
-
- return result, nil
}