diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8464e25..58dca14 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -19,7 +19,7 @@ jobs: redis-version: 7 - uses: namoshek/rabbitmq-github-action@v1 with: - version: '3.8.9' + version: '4.1.1' ports: '5672:5672' - run: go test -count=1 -v ./... lint: diff --git a/README.md b/README.md index 4eb15ab..a7f9fe6 100644 --- a/README.md +++ b/README.md @@ -103,12 +103,16 @@ $ celery --app myproject worker --queues important --loglevel=debug --without-he Sending tasks from Python and receiving them on Go side. ```sh -$ python producer.py --protocol=1 +$ python producer.py $ go run ./consumer/ {"msg":"waiting for tasks..."} received a=fizz b=bazz ``` +To send a task with Celery Protocol version 1, run *producer.py* with the *--protocol=1* command-line argument: +```sh +$ python producer.py --protocol=1 +```
@@ -214,6 +218,10 @@ $ go run ./consumer/ received a=fizz b=bazz ``` +To send a task with Celery Protocol version 1, run *producer.py* with the *--protocol=1* command-line argument: +```sh +$ python producer.py --protocol=1 +```
## Testing diff --git a/celery_test.go b/celery_test.go index adf987e..187fd1b 100644 --- a/celery_test.go +++ b/celery_test.go @@ -13,6 +13,7 @@ import ( "github.com/marselester/gopher-celery/goredis" "github.com/marselester/gopher-celery/protocol" + "github.com/marselester/gopher-celery/rabbitmq" ) func TestExecuteTaskPanic(t *testing.T) { @@ -243,6 +244,57 @@ func TestGoredisProduceAndConsume100times(t *testing.T) { } } +func TestRabbitmqProduceAndConsume100times(t *testing.T) { + app := NewApp( + WithBroker(rabbitmq.NewBroker(rabbitmq.WithAmqpUri("amqp://guest:guest@localhost:5672/"))), + WithLogger(log.NewJSONLogger(os.Stderr)), + ) + + queue := "rabbitmq_broker_test" + + // Create the queue, if it doesn't exist. + app.conf.broker.Observe([]string{queue}) + + for i := 0; i < 100; i++ { + err := app.Delay( + "myproject.apps.myapp.tasks.mytask", + queue, + 2, + 3, + ) + if err != nil { + t.Fatal(err) + } + } + + // The test finishes either when ctx times out or all the tasks finish. + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + t.Cleanup(cancel) + + var sum int32 + app.Register( + "myproject.apps.myapp.tasks.mytask", + queue, + func(ctx context.Context, p *TaskParam) error { + p.NameArgs("a", "b") + atomic.AddInt32( + &sum, + int32(p.MustInt("a")+p.MustInt("b")), + ) + return nil + }, + ) + if err := app.Run(ctx); err != nil { + t.Error(err) + } + + var want int32 = 500 + if want != sum { + t.Errorf("expected sum %d got %d", want, sum) + } + +} + func TestConsumeSequentially(t *testing.T) { app := NewApp( WithLogger(log.NewJSONLogger(os.Stderr)), diff --git a/rabbitmq/broker.go b/rabbitmq/broker.go index 12a0bb0..a101e47 100644 --- a/rabbitmq/broker.go +++ b/rabbitmq/broker.go @@ -3,7 +3,6 @@ package rabbitmq import ( - "context" "encoding/base64" "encoding/json" "fmt" @@ -33,7 +32,7 @@ type Broker struct { queues []string conn *amqp.Connection channel *amqp.Channel - ctx context.Context + delivery map[string]<-chan amqp.Delivery } // WithAmqpUri sets the AMQP connection URI to RabbitMQ. @@ -67,30 +66,27 @@ func NewBroker(options ...BrokerOption) *Broker { amqpUri: DefaultAmqpUri, receiveTimeout: DefaultReceiveTimeout * time.Second, rawMode: false, - ctx: context.Background(), + delivery: make(map[string]<-chan amqp.Delivery), } for _, opt := range options { opt(&br) } if br.conn == nil { - br.channel = nil conn, err := amqp.Dial(br.amqpUri) - br.conn = conn if err != nil { log.Panicf("Failed to connect to RabbitMQ: %s", err) return nil } + br.conn = conn } - if br.channel == nil { - channel, err := br.conn.Channel() - br.channel = channel - if err != nil { - log.Panicf("Failed to open a channel: %s", err) - return nil - } + channel, err := br.conn.Channel() + if err != nil { + log.Panicf("Failed to open a channel: %s", err) + return nil } + br.channel = channel return &br } @@ -135,8 +131,7 @@ func (br *Broker) Send(m []byte, q string) error { replyTo = properties_in["reply_to"].(string) } - err := br.channel.PublishWithContext( - br.ctx, + err := br.channel.Publish( "", // exchange q, // routing key false, // mandatory @@ -150,6 +145,7 @@ func (br *Broker) Send(m []byte, q string) error { ReplyTo: replyTo, Body: body, }) + return err } @@ -158,16 +154,44 @@ func (br *Broker) Send(m []byte, q string) error { func (br *Broker) Observe(queues []string) { br.queues = queues for _, queue := range queues { - _, err := br.channel.QueueDeclare( - queue, // name - true, // durable - false, // delete when unused - false, // exclusive - false, // no-wait - nil, // arguments + durable := true + autoDelete := false + exclusive := false + noWait := false + + // Check whether the queue exists. + _, err := br.channel.QueueDeclarePassive( + queue, + durable, + autoDelete, + exclusive, + noWait, + nil, ) + + // If the queue doesn't exist, attempt to create it. if err != nil { - log.Panicf("Failed to declare a queue: %s", err) + // QueueDeclarePassive() will close the channel if the queue does not exist, so we have to create a new channel when this happens. + if br.channel.IsClosed() { + channel, err := br.conn.Channel() + if err != nil { + log.Panicf("Failed to open a channel: %s", err) + } + br.channel = channel + } + + _, err := br.channel.QueueDeclare( + queue, + durable, + autoDelete, + exclusive, + noWait, + nil, + ) + + if err != nil { + log.Panicf("Failed to declare a queue: %s", err) + } } } } @@ -175,70 +199,69 @@ func (br *Broker) Observe(queues []string) { // Receive fetches a Celery task message from a tail of one of the queues in RabbitMQ. // After a timeout it returns nil, nil. func (br *Broker) Receive() ([]byte, error) { + queue := br.queues[0] + // Put the Celery queue name to the end of the slice for fair processing. + broker.Move2back(br.queues, queue) - const retryIntervalMs = 100 + var err error + + delivery, delivery_exists := br.delivery[queue] + if !delivery_exists { + delivery, err = br.channel.Consume( + queue, // queue + "", // consumer + true, // autoAck + false, // exclusive + false, // noLocal (ignored) + false, // noWait + nil, // args + ) - try_receive := func() (msg amqp.Delivery, ok bool, err error) { - queue := br.queues[0] - // Put the Celery queue name to the end of the slice for fair processing. - broker.Move2back(br.queues, queue) - my_msg, my_ok, my_err := br.channel.Get(queue, true) - if my_err != nil { - log.Printf("Failed to g a message: %s", my_err) + if err != nil { + return nil, err } - return my_msg, my_ok, my_err - } - startTime := time.Now() - timeoutTime := startTime.Add(br.receiveTimeout) - msg, ok, err := try_receive() - if err != nil { - return nil, nil + br.delivery[queue] = delivery } - for !ok { - if time.Now().After(timeoutTime) { - return nil, nil + + select { + case msg := <-delivery: + if br.rawMode { + return msg.Body, nil } - time.Sleep(retryIntervalMs * time.Millisecond) - msg, ok, err = try_receive() - if err != nil { - return nil, nil + // Marshal msg from RabbitMQ Celery format to internal Celery format. + + properties := make(map[string]interface{}) + properties["correlation_id"] = msg.CorrelationId + properties["reply_to"] = msg.ReplyTo + properties["delivery_mode"] = msg.DeliveryMode + properties["delivery_info"] = map[string]interface{}{ + "exchange": msg.Exchange, + "routing_key": msg.RoutingKey, } - } + properties["priority"] = msg.Priority + properties["body_encoding"] = "base64" + properties["delivery_tag"] = msg.DeliveryTag - if br.rawMode { - return msg.Body, nil - } + imsg := make(map[string]interface{}) + imsg["body"] = msg.Body + imsg["content-encoding"] = msg.ContentEncoding + imsg["content-type"] = msg.ContentType + imsg["headers"] = msg.Headers + imsg["properties"] = properties - // Marshal msg from RabbitMQ Celery format to internal Celery format. - - properties := make(map[string]interface{}) - properties["correlation_id"] = msg.CorrelationId - properties["reply_to"] = msg.ReplyTo - properties["delivery_mode"] = msg.DeliveryMode - delivery_info := make(map[string]interface{}) - properties["delivery_info"] = delivery_info - delivery_info["exchange"] = msg.Exchange - delivery_info["routing_key"] = msg.RoutingKey - properties["priority"] = msg.Priority - properties["body_encoding"] = "base64" - properties["delivery_tag"] = msg.DeliveryTag - - imsg := make(map[string]interface{}) - imsg["body"] = msg.Body - imsg["content-encoding"] = msg.ContentEncoding - imsg["content-type"] = msg.ContentType - imsg["headers"] = msg.Headers - imsg["properties"] = properties - - var result []byte - result, err = json.Marshal(imsg) - if err != nil { - err_str := fmt.Errorf("%w", err) - log.Printf("json encode: %s", err_str) + var result []byte + result, err := json.Marshal(imsg) + if err != nil { + err_str := fmt.Errorf("%w", err) + log.Printf("json encode: %s", err_str) + return nil, err + } + return result, nil + + case <-time.After(br.receiveTimeout): + // Receive timeout return nil, nil } - - return result, nil }