Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ jobs:
redis-version: 7
- uses: namoshek/rabbitmq-github-action@v1
with:
version: '3.8.9'
Copy link
Contributor Author

@roncemer roncemer Jul 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I kept getting "connection reset by peer" when running the unit tests in the CI pipeline. Upgrading to the latest rabbitmq docker image seems to have improved the situation. If a test pipeline fails, re-running it usually works. Not happy about this, but I don't have a quick solution for it.

version: '4.1.1'
ports: '5672:5672'
- run: go test -count=1 -v ./...
lint:
Expand Down
10 changes: 9 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -103,12 +103,16 @@ $ celery --app myproject worker --queues important --loglevel=debug --without-he
<summary>Sending tasks from Python and receiving them on Go side.</summary>

```sh
$ python producer.py --protocol=1
$ python producer.py
$ go run ./consumer/
{"msg":"waiting for tasks..."}
received a=fizz b=bazz
```

To send a task with Celery Protocol version 1, run *producer.py* with the *--protocol=1* command-line argument:
```sh
$ python producer.py --protocol=1
```
</details>

<details>
Expand Down Expand Up @@ -214,6 +218,10 @@ $ go run ./consumer/
received a=fizz b=bazz
```

To send a task with Celery Protocol version 1, run *producer.py* with the *--protocol=1* command-line argument:
```sh
$ python producer.py --protocol=1
```
</details>

## Testing
Expand Down
52 changes: 52 additions & 0 deletions celery_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (

"github.com/marselester/gopher-celery/goredis"
"github.com/marselester/gopher-celery/protocol"
"github.com/marselester/gopher-celery/rabbitmq"
)

func TestExecuteTaskPanic(t *testing.T) {
Expand Down Expand Up @@ -243,6 +244,57 @@ func TestGoredisProduceAndConsume100times(t *testing.T) {
}
}

func TestRabbitmqProduceAndConsume100times(t *testing.T) {
app := NewApp(
WithBroker(rabbitmq.NewBroker(rabbitmq.WithAmqpUri("amqp://guest:guest@localhost:5672/"))),
WithLogger(log.NewJSONLogger(os.Stderr)),
)

queue := "rabbitmq_broker_test"

// Create the queue, if it doesn't exist.
app.conf.broker.Observe([]string{queue})

for i := 0; i < 100; i++ {
err := app.Delay(
"myproject.apps.myapp.tasks.mytask",
queue,
2,
3,
)
if err != nil {
t.Fatal(err)
}
}

// The test finishes either when ctx times out or all the tasks finish.
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
t.Cleanup(cancel)

var sum int32
app.Register(
"myproject.apps.myapp.tasks.mytask",
queue,
func(ctx context.Context, p *TaskParam) error {
p.NameArgs("a", "b")
atomic.AddInt32(
&sum,
int32(p.MustInt("a")+p.MustInt("b")),
)
return nil
},
)
if err := app.Run(ctx); err != nil {
t.Error(err)
}

var want int32 = 500
if want != sum {
t.Errorf("expected sum %d got %d", want, sum)
}

}

func TestConsumeSequentially(t *testing.T) {
app := NewApp(
WithLogger(log.NewJSONLogger(os.Stderr)),
Expand Down
175 changes: 99 additions & 76 deletions rabbitmq/broker.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
package rabbitmq

import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
Expand Down Expand Up @@ -33,7 +32,7 @@ type Broker struct {
queues []string
conn *amqp.Connection
channel *amqp.Channel
ctx context.Context
delivery map[string]<-chan amqp.Delivery
}

// WithAmqpUri sets the AMQP connection URI to RabbitMQ.
Expand Down Expand Up @@ -67,30 +66,27 @@ func NewBroker(options ...BrokerOption) *Broker {
amqpUri: DefaultAmqpUri,
receiveTimeout: DefaultReceiveTimeout * time.Second,
rawMode: false,
ctx: context.Background(),
delivery: make(map[string]<-chan amqp.Delivery),
}
for _, opt := range options {
opt(&br)
}

if br.conn == nil {
br.channel = nil
conn, err := amqp.Dial(br.amqpUri)
br.conn = conn
if err != nil {
log.Panicf("Failed to connect to RabbitMQ: %s", err)
return nil
}
br.conn = conn
}

if br.channel == nil {
channel, err := br.conn.Channel()
br.channel = channel
if err != nil {
log.Panicf("Failed to open a channel: %s", err)
return nil
}
channel, err := br.conn.Channel()
if err != nil {
log.Panicf("Failed to open a channel: %s", err)
return nil
}
br.channel = channel

return &br
}
Expand Down Expand Up @@ -135,8 +131,7 @@ func (br *Broker) Send(m []byte, q string) error {
replyTo = properties_in["reply_to"].(string)
}

err := br.channel.PublishWithContext(
br.ctx,
err := br.channel.Publish(
"", // exchange
q, // routing key
false, // mandatory
Expand All @@ -150,6 +145,7 @@ func (br *Broker) Send(m []byte, q string) error {
ReplyTo: replyTo,
Body: body,
})

return err
}

Expand All @@ -158,87 +154,114 @@ func (br *Broker) Send(m []byte, q string) error {
func (br *Broker) Observe(queues []string) {
br.queues = queues
for _, queue := range queues {
_, err := br.channel.QueueDeclare(
queue, // name
true, // durable
false, // delete when unused
false, // exclusive
false, // no-wait
nil, // arguments
durable := true
autoDelete := false
exclusive := false
noWait := false

// Check whether the queue exists.
_, err := br.channel.QueueDeclarePassive(
queue,
durable,
autoDelete,
exclusive,
noWait,
nil,
)

// If the queue doesn't exist, attempt to create it.
if err != nil {
log.Panicf("Failed to declare a queue: %s", err)
// QueueDeclarePassive() will close the channel if the queue does not exist, so we have to create a new channel when this happens.
if br.channel.IsClosed() {
channel, err := br.conn.Channel()
if err != nil {
log.Panicf("Failed to open a channel: %s", err)
}
br.channel = channel
}

_, err := br.channel.QueueDeclare(
queue,
durable,
autoDelete,
exclusive,
noWait,
nil,
)

if err != nil {
log.Panicf("Failed to declare a queue: %s", err)
}
}
}
}

// Receive fetches a Celery task message from a tail of one of the queues in RabbitMQ.
// After a timeout it returns nil, nil.
func (br *Broker) Receive() ([]byte, error) {
queue := br.queues[0]
// Put the Celery queue name to the end of the slice for fair processing.
broker.Move2back(br.queues, queue)

const retryIntervalMs = 100
var err error

delivery, delivery_exists := br.delivery[queue]
if !delivery_exists {
delivery, err = br.channel.Consume(
queue, // queue
"", // consumer
true, // autoAck
false, // exclusive
false, // noLocal (ignored)
false, // noWait
nil, // args
)

try_receive := func() (msg amqp.Delivery, ok bool, err error) {
queue := br.queues[0]
// Put the Celery queue name to the end of the slice for fair processing.
broker.Move2back(br.queues, queue)
my_msg, my_ok, my_err := br.channel.Get(queue, true)
if my_err != nil {
log.Printf("Failed to g a message: %s", my_err)
if err != nil {
return nil, err
}
return my_msg, my_ok, my_err
}

startTime := time.Now()
timeoutTime := startTime.Add(br.receiveTimeout)
msg, ok, err := try_receive()
if err != nil {
return nil, nil
br.delivery[queue] = delivery
}
for !ok {
if time.Now().After(timeoutTime) {
return nil, nil

select {
case msg := <-delivery:
if br.rawMode {
return msg.Body, nil
}
time.Sleep(retryIntervalMs * time.Millisecond)

msg, ok, err = try_receive()
if err != nil {
return nil, nil
// Marshal msg from RabbitMQ Celery format to internal Celery format.

properties := make(map[string]interface{})
properties["correlation_id"] = msg.CorrelationId
properties["reply_to"] = msg.ReplyTo
properties["delivery_mode"] = msg.DeliveryMode
properties["delivery_info"] = map[string]interface{}{
"exchange": msg.Exchange,
"routing_key": msg.RoutingKey,
}
}
properties["priority"] = msg.Priority
properties["body_encoding"] = "base64"
properties["delivery_tag"] = msg.DeliveryTag

if br.rawMode {
return msg.Body, nil
}
imsg := make(map[string]interface{})
imsg["body"] = msg.Body
imsg["content-encoding"] = msg.ContentEncoding
imsg["content-type"] = msg.ContentType
imsg["headers"] = msg.Headers
imsg["properties"] = properties

// Marshal msg from RabbitMQ Celery format to internal Celery format.

properties := make(map[string]interface{})
properties["correlation_id"] = msg.CorrelationId
properties["reply_to"] = msg.ReplyTo
properties["delivery_mode"] = msg.DeliveryMode
delivery_info := make(map[string]interface{})
properties["delivery_info"] = delivery_info
delivery_info["exchange"] = msg.Exchange
delivery_info["routing_key"] = msg.RoutingKey
properties["priority"] = msg.Priority
properties["body_encoding"] = "base64"
properties["delivery_tag"] = msg.DeliveryTag

imsg := make(map[string]interface{})
imsg["body"] = msg.Body
imsg["content-encoding"] = msg.ContentEncoding
imsg["content-type"] = msg.ContentType
imsg["headers"] = msg.Headers
imsg["properties"] = properties

var result []byte
result, err = json.Marshal(imsg)
if err != nil {
err_str := fmt.Errorf("%w", err)
log.Printf("json encode: %s", err_str)
var result []byte
result, err := json.Marshal(imsg)
if err != nil {
err_str := fmt.Errorf("%w", err)
log.Printf("json encode: %s", err_str)
return nil, err
}
return result, nil

case <-time.After(br.receiveTimeout):
// Receive timeout
return nil, nil
}

return result, nil
}