diff --git a/bindings/kafka/kafka.go b/bindings/kafka/kafka.go index aa58758240..7440d0674d 100644 --- a/bindings/kafka/kafka.go +++ b/bindings/kafka/kafka.go @@ -100,29 +100,26 @@ func (b *Binding) Read(ctx context.Context, handler bindings.Handler) error { return nil } - handlerConfig := kafka.SubscriptionHandlerConfig{ - IsBulkSubscribe: false, - Handler: adaptHandler(handler), - } - for _, t := range b.topics { - b.kafka.AddTopicHandler(t, handlerConfig) - } + ctx, cancel := context.WithCancel(ctx) + b.wg.Add(1) go func() { - defer b.wg.Done() - // Wait for context cancelation or closure. select { case <-ctx.Done(): case <-b.closeCh: } - - // Remove the topic handlers. - for _, t := range b.topics { - b.kafka.RemoveTopicHandler(t) - } + cancel() + b.wg.Done() }() - return b.kafka.Subscribe(ctx) + handlerConfig := kafka.SubscriptionHandlerConfig{ + IsBulkSubscribe: false, + Handler: adaptHandler(handler), + } + + b.kafka.Subscribe(ctx, handlerConfig, b.topics...) + + return nil } func adaptHandler(handler bindings.Handler) kafka.EventHandler { diff --git a/common/component/kafka/consumer.go b/common/component/kafka/consumer.go index 21d2c65b9d..a05e611707 100644 --- a/common/component/kafka/consumer.go +++ b/common/component/kafka/consumer.go @@ -14,12 +14,10 @@ limitations under the License. package kafka import ( - "context" "errors" "fmt" "strconv" "sync" - "sync/atomic" "time" "github.com/IBM/sarama" @@ -29,15 +27,8 @@ import ( ) type consumer struct { - k *Kafka - ready chan bool - running chan struct{} - stopped atomic.Bool - once sync.Once - mutex sync.Mutex - skipConsume bool - consumeCtx context.Context - consumeCancel context.CancelFunc + k *Kafka + mutex sync.Mutex } func (consumer *consumer) ConsumeClaim(session sarama.ConsumerGroupSession, claim sarama.ConsumerGroupClaim) error { @@ -233,27 +224,9 @@ func (consumer *consumer) Cleanup(sarama.ConsumerGroupSession) error { } func (consumer *consumer) Setup(sarama.ConsumerGroupSession) error { - consumer.once.Do(func() { - close(consumer.ready) - }) - return nil } -// AddTopicHandler adds a handler and configuration for a topic -func (k *Kafka) AddTopicHandler(topic string, handlerConfig SubscriptionHandlerConfig) { - k.subscribeLock.Lock() - k.subscribeTopics[topic] = handlerConfig - k.subscribeLock.Unlock() -} - -// RemoveTopicHandler removes a topic handler -func (k *Kafka) RemoveTopicHandler(topic string) { - k.subscribeLock.Lock() - delete(k.subscribeTopics, topic) - k.subscribeLock.Unlock() -} - // checkBulkSubscribe checks if a bulk handler and config are correctly registered for provided topic func (k *Kafka) checkBulkSubscribe(topic string) bool { if bulkHandlerConfig, ok := k.subscribeTopics[topic]; ok && @@ -275,120 +248,3 @@ func (k *Kafka) GetTopicHandlerConfig(topic string) (SubscriptionHandlerConfig, return SubscriptionHandlerConfig{}, fmt.Errorf("any handler for messages of topic %s not found", topic) } - -// Subscribe to topic in the Kafka cluster, in a background goroutine -func (k *Kafka) Subscribe(ctx context.Context) error { - if k.consumerGroup == "" { - return errors.New("kafka: consumerGroup must be set to subscribe") - } - - k.subscribeLock.Lock() - defer k.subscribeLock.Unlock() - - topics := k.subscribeTopics.TopicList() - if len(topics) == 0 { - // Nothing to subscribe to - return nil - } - k.consumer.skipConsume = true - - ctxCreateFn := func() { - consumeCtx, cancel := context.WithCancel(context.Background()) - - k.consumer.consumeCtx = consumeCtx - k.consumer.consumeCancel = cancel - - k.consumer.skipConsume = false - } - - if k.cg == nil { - cg, err := sarama.NewConsumerGroup(k.brokers, k.consumerGroup, k.config) - if err != nil { - return err - } - - k.cg = cg - - ready := make(chan bool) - k.consumer = consumer{ - k: k, - ready: ready, - running: make(chan struct{}), - } - - ctxCreateFn() - - go func() { - k.logger.Debugf("Subscribed and listening to topics: %s", topics) - - for { - // If the context was cancelled, as is the case when handling SIGINT and SIGTERM below, then this pops - // us out of the consume loop - if ctx.Err() != nil { - k.logger.Info("Consume context cancelled") - break - } - - k.logger.Debugf("Starting loop to consume.") - - if k.consumer.skipConsume { - continue - } - - topics = k.subscribeTopics.TopicList() - - // Consume the requested topics - bo := backoff.WithContext(backoff.NewConstantBackOff(k.consumeRetryInterval), ctx) - innerErr := retry.NotifyRecover(func() error { - if ctxErr := ctx.Err(); ctxErr != nil { - return backoff.Permanent(ctxErr) - } - return k.cg.Consume(k.consumer.consumeCtx, topics, &(k.consumer)) - }, bo, func(err error, t time.Duration) { - k.logger.Errorf("Error consuming %v. Retrying...: %v", topics, err) - }, func() { - k.logger.Infof("Recovered consuming %v", topics) - }) - if innerErr != nil && !errors.Is(innerErr, context.Canceled) { - k.logger.Errorf("Permanent error consuming %v: %v", topics, innerErr) - } - } - - k.logger.Debugf("Closing ConsumerGroup for topics: %v", topics) - err := k.cg.Close() - if err != nil { - k.logger.Errorf("Error closing consumer group: %v", err) - } - - // Ensure running channel is only closed once. - if k.consumer.stopped.CompareAndSwap(false, true) { - close(k.consumer.running) - } - }() - - <-ready - } else { - // The consumer group is already created and consuming topics. This means a new subscription is being added - k.consumer.consumeCancel() - ctxCreateFn() - } - - return nil -} - -// Close down consumer group resources, refresh once. -func (k *Kafka) closeSubscriptionResources() { - if k.cg != nil { - err := k.cg.Close() - if err != nil { - k.logger.Errorf("Error closing consumer group: %v", err) - } - - k.consumer.once.Do(func() { - // Wait for shutdown to be complete - <-k.consumer.running - close(k.consumer.ready) - k.consumer.once = sync.Once{} - }) - } -} diff --git a/common/component/kafka/kafka.go b/common/component/kafka/kafka.go index d9cb1d16c9..876365e489 100644 --- a/common/component/kafka/kafka.go +++ b/common/component/kafka/kafka.go @@ -20,6 +20,7 @@ import ( "fmt" "strings" "sync" + "sync/atomic" "time" "github.com/IBM/sarama" @@ -34,19 +35,24 @@ import ( // Kafka allows reading/writing to a Kafka consumer group. type Kafka struct { - producer sarama.SyncProducer - consumerGroup string - brokers []string - logger logger.Logger - authType string - saslUsername string - saslPassword string - initialOffset int64 + producer sarama.SyncProducer + consumerGroup string + brokers []string + logger logger.Logger + authType string + saslUsername string + saslPassword string + initialOffset int64 + config *sarama.Config + cg sarama.ConsumerGroup - consumer consumer - config *sarama.Config subscribeTopics TopicHandlerConfig subscribeLock sync.Mutex + consumerCancel context.CancelFunc + consumerWG sync.WaitGroup + closeCh chan struct{} + closed atomic.Bool + wg sync.WaitGroup // schema registry settings srClient srclient.ISchemaRegistryClient @@ -106,7 +112,7 @@ func NewKafka(logger logger.Logger) *Kafka { return &Kafka{ logger: logger, subscribeTopics: make(TopicHandlerConfig), - subscribeLock: sync.Mutex{}, + closeCh: make(chan struct{}), } } @@ -184,11 +190,11 @@ func (k *Kafka) Init(ctx context.Context, metadata map[string]string) error { // Default retry configuration is used if no // backOff properties are set. - if err := retry.DecodeConfigWithPrefix( + if rerr := retry.DecodeConfigWithPrefix( &k.backOffConfig, metadata, - "backOff"); err != nil { - return err + "backOff"); rerr != nil { + return rerr } k.consumeRetryEnabled = meta.ConsumeRetryEnabled k.consumeRetryInterval = meta.ConsumeRetryInterval @@ -207,22 +213,41 @@ func (k *Kafka) Init(ctx context.Context, metadata map[string]string) error { } k.logger.Debug("Kafka message bus initialization complete") + k.cg, err = sarama.NewConsumerGroup(k.brokers, k.consumerGroup, k.config) + if err != nil { + return err + } + return nil } -func (k *Kafka) Close() (err error) { - k.closeSubscriptionResources() +func (k *Kafka) Close() error { + defer k.wg.Wait() + defer k.consumerWG.Wait() - if k.producer != nil { - err = k.producer.Close() - k.producer = nil - } + errs := make([]error, 2) + if k.closed.CompareAndSwap(false, true) { + close(k.closeCh) + + if k.producer != nil { + errs[0] = k.producer.Close() + k.producer = nil + } + + if k.internalContext != nil { + k.internalContextCancel() + } + + k.subscribeLock.Lock() + if k.consumerCancel != nil { + k.consumerCancel() + } + k.subscribeLock.Unlock() - if k.internalContext != nil { - k.internalContextCancel() + errs[1] = k.cg.Close() } - return err + return errors.Join(errs...) } func getSchemaSubject(topic string) string { diff --git a/common/component/kafka/mocks/consumergroup.go b/common/component/kafka/mocks/consumergroup.go new file mode 100644 index 0000000000..5106767e95 --- /dev/null +++ b/common/component/kafka/mocks/consumergroup.go @@ -0,0 +1,115 @@ +/* +Copyright 2024 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package mocks + +import ( + "context" + + "github.com/IBM/sarama" +) + +type FakeConsumerGroup struct { + consumerFn func(context.Context, []string, sarama.ConsumerGroupHandler) error + errorsFn func() <-chan error + closeFn func() error + pauseFn func(map[string][]int32) + resumeFn func(map[string][]int32) + pauseAllFn func() + resumeAllFn func() +} + +func NewConsumerGroup() *FakeConsumerGroup { + return &FakeConsumerGroup{ + consumerFn: func(context.Context, []string, sarama.ConsumerGroupHandler) error { + return nil + }, + errorsFn: func() <-chan error { + return nil + }, + closeFn: func() error { + return nil + }, + pauseFn: func(map[string][]int32) { + }, + resumeFn: func(map[string][]int32) { + }, + pauseAllFn: func() { + }, + resumeAllFn: func() { + }, + } +} + +func (f *FakeConsumerGroup) WithConsumeFn(fn func(context.Context, []string, sarama.ConsumerGroupHandler) error) *FakeConsumerGroup { + f.consumerFn = fn + return f +} + +func (f *FakeConsumerGroup) WithErrorsFn(fn func() <-chan error) *FakeConsumerGroup { + f.errorsFn = fn + return f +} + +func (f *FakeConsumerGroup) WithCloseFn(fn func() error) *FakeConsumerGroup { + f.closeFn = fn + return f +} + +func (f *FakeConsumerGroup) WithPauseFn(fn func(map[string][]int32)) *FakeConsumerGroup { + f.pauseFn = fn + return f +} + +func (f *FakeConsumerGroup) WithResumeFn(fn func(map[string][]int32)) *FakeConsumerGroup { + f.resumeFn = fn + return f +} + +func (f *FakeConsumerGroup) WithPauseAllFn(fn func()) *FakeConsumerGroup { + f.pauseAllFn = fn + return f +} + +func (f *FakeConsumerGroup) WithResumeAllFn(fn func()) *FakeConsumerGroup { + f.resumeAllFn = fn + return f +} + +func (f *FakeConsumerGroup) Consume(ctx context.Context, topics []string, handler sarama.ConsumerGroupHandler) error { + return f.consumerFn(ctx, topics, handler) +} + +func (f *FakeConsumerGroup) Errors() <-chan error { + return f.errorsFn() +} + +func (f *FakeConsumerGroup) Close() error { + return f.closeFn() +} + +func (f *FakeConsumerGroup) Pause(partitions map[string][]int32) { + f.pauseFn(partitions) +} + +func (f *FakeConsumerGroup) Resume(partitions map[string][]int32) { + f.resumeFn(partitions) +} + +func (f *FakeConsumerGroup) PauseAll() { + f.pauseAllFn() +} + +func (f *FakeConsumerGroup) ResumeAll() { + f.resumeAllFn() +} diff --git a/common/component/kafka/mocks/mock_ISchemaRegistryClient.go b/common/component/kafka/mocks/mock_ISchemaRegistryClient.go index b52653090b..0e7d3bb72c 100644 --- a/common/component/kafka/mocks/mock_ISchemaRegistryClient.go +++ b/common/component/kafka/mocks/mock_ISchemaRegistryClient.go @@ -2,7 +2,7 @@ // Source: /Users/patrick.assuied/go/pkg/mod/github.com/riferrei/srclient@v0.6.0/schemaRegistryClient.go // Package mock_srclient is a generated GoMock package. -package mock_srclient +package mocks import ( reflect "reflect" diff --git a/common/component/kafka/subscriber.go b/common/component/kafka/subscriber.go new file mode 100644 index 0000000000..95bdd5a232 --- /dev/null +++ b/common/component/kafka/subscriber.go @@ -0,0 +1,103 @@ +/* +Copyright 2024 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package kafka + +import ( + "context" + "errors" + "time" +) + +// Subscribe adds a handler and configuration for a topic, and subscribes. +// Unsubscribes to the topic on context cancel. +func (k *Kafka) Subscribe(ctx context.Context, handlerConfig SubscriptionHandlerConfig, topics ...string) { + k.subscribeLock.Lock() + defer k.subscribeLock.Unlock() + for _, topic := range topics { + k.subscribeTopics[topic] = handlerConfig + } + + k.logger.Debugf("Subscribing to topic: %v", topics) + + k.reloadConsumerGroup() + + k.wg.Add(1) + go func() { + defer k.wg.Done() + select { + case <-ctx.Done(): + case <-k.closeCh: + } + + k.subscribeLock.Lock() + defer k.subscribeLock.Unlock() + + k.logger.Debugf("Unsubscribing to topic: %v", topics) + + for _, topic := range topics { + delete(k.subscribeTopics, topic) + } + + k.reloadConsumerGroup() + }() +} + +// reloadConsumerGroup reloads the consumer group with the new topics. +func (k *Kafka) reloadConsumerGroup() { + if k.consumerCancel != nil { + k.consumerCancel() + k.consumerCancel = nil + k.consumerWG.Wait() + } + + if len(k.subscribeTopics) == 0 || k.closed.Load() { + return + } + + topics := k.subscribeTopics.TopicList() + + k.logger.Debugf("Subscribed and listening to topics: %s", topics) + + consumer := &consumer{k: k} + + ctx, cancel := context.WithCancel(context.Background()) + k.consumerCancel = cancel + + k.consumerWG.Add(1) + go func() { + defer k.consumerWG.Done() + k.consume(ctx, topics, consumer) + k.logger.Debugf("Closing ConsumerGroup for topics: %v", topics) + }() +} + +func (k *Kafka) consume(ctx context.Context, topics []string, consumer *consumer) { + for { + err := k.cg.Consume(ctx, topics, consumer) + if errors.Is(err, context.Canceled) { + return + } + if err != nil { + k.logger.Errorf("Error consuming %v. Retrying...: %v", topics, err) + } + + select { + case <-k.closeCh: + return + case <-ctx.Done(): + return + case <-time.After(k.consumeRetryInterval): + } + } +} diff --git a/common/component/kafka/subscriber_test.go b/common/component/kafka/subscriber_test.go new file mode 100644 index 0000000000..f54948e341 --- /dev/null +++ b/common/component/kafka/subscriber_test.go @@ -0,0 +1,522 @@ +/* +Copyright 2024 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package kafka + +import ( + "context" + "errors" + "strconv" + "sync/atomic" + "testing" + "time" + + "github.com/IBM/sarama" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/dapr/components-contrib/common/component/kafka/mocks" + "github.com/dapr/kit/logger" +) + +func Test_reloadConsumerGroup(t *testing.T) { + t.Run("if reload called with no topics and not closed, expect return and cancel called", func(t *testing.T) { + var consumeCalled atomic.Bool + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + cg := mocks.NewConsumerGroup().WithConsumeFn(func(context.Context, []string, sarama.ConsumerGroupHandler) error { + consumeCalled.Store(true) + return nil + }) + + k := &Kafka{ + logger: logger.NewLogger("test"), + cg: cg, + subscribeTopics: nil, + closeCh: make(chan struct{}), + consumerCancel: cancel, + } + + k.reloadConsumerGroup() + + require.Error(t, ctx.Err()) + assert.False(t, consumeCalled.Load()) + }) + + t.Run("if reload called with topics but is closed, expect return and cancel called", func(t *testing.T) { + var consumeCalled atomic.Bool + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + cg := mocks.NewConsumerGroup().WithConsumeFn(func(context.Context, []string, sarama.ConsumerGroupHandler) error { + consumeCalled.Store(true) + return nil + }) + k := &Kafka{ + logger: logger.NewLogger("test"), + cg: cg, + consumerCancel: cancel, + closeCh: make(chan struct{}), + subscribeTopics: TopicHandlerConfig{"foo": SubscriptionHandlerConfig{}}, + } + + k.closed.Store(true) + + k.reloadConsumerGroup() + + require.Error(t, ctx.Err()) + assert.False(t, consumeCalled.Load()) + }) + + t.Run("if reload called with topics, expect Consume to be called. If cancelled return", func(t *testing.T) { + var consumeCalled atomic.Bool + var consumeCancel atomic.Bool + cg := mocks.NewConsumerGroup().WithConsumeFn(func(ctx context.Context, _ []string, _ sarama.ConsumerGroupHandler) error { + consumeCalled.Store(true) + <-ctx.Done() + consumeCancel.Store(true) + return nil + }) + k := &Kafka{ + logger: logger.NewLogger("test"), + cg: cg, + consumerCancel: nil, + closeCh: make(chan struct{}), + subscribeTopics: TopicHandlerConfig{"foo": SubscriptionHandlerConfig{}}, + } + + k.reloadConsumerGroup() + + assert.Eventually(t, consumeCalled.Load, time.Second, time.Millisecond) + assert.False(t, consumeCancel.Load()) + assert.NotNil(t, k.consumerCancel) + + k.consumerCancel() + k.consumerWG.Wait() + }) + + t.Run("Consume retries if returns non-context cancel error", func(t *testing.T) { + var consumeCalled atomic.Int64 + cg := mocks.NewConsumerGroup().WithConsumeFn(func(ctx context.Context, _ []string, _ sarama.ConsumerGroupHandler) error { + consumeCalled.Add(1) + return errors.New("some error") + }) + k := &Kafka{ + logger: logger.NewLogger("test"), + cg: cg, + consumerCancel: nil, + closeCh: make(chan struct{}), + subscribeTopics: TopicHandlerConfig{"foo": SubscriptionHandlerConfig{}}, + consumeRetryInterval: time.Millisecond, + } + + k.reloadConsumerGroup() + + assert.Eventually(t, func() bool { + return consumeCalled.Load() > 10 + }, time.Second, time.Millisecond) + + assert.NotNil(t, k.consumerCancel) + + called := consumeCalled.Load() + k.consumerCancel() + k.consumerWG.Wait() + assert.InDelta(t, called, consumeCalled.Load(), 1) + }) + + t.Run("Consume return immediately if returns a context cancelled error", func(t *testing.T) { + var consumeCalled atomic.Int64 + cg := mocks.NewConsumerGroup().WithConsumeFn(func(ctx context.Context, _ []string, _ sarama.ConsumerGroupHandler) error { + consumeCalled.Add(1) + if consumeCalled.Load() == 5 { + return context.Canceled + } + return errors.New("some error") + }) + k := &Kafka{ + logger: logger.NewLogger("test"), + cg: cg, + consumerCancel: nil, + closeCh: make(chan struct{}), + subscribeTopics: map[string]SubscriptionHandlerConfig{"foo": {}}, + consumeRetryInterval: time.Millisecond, + } + + k.reloadConsumerGroup() + + assert.Eventually(t, func() bool { + return consumeCalled.Load() == 5 + }, time.Second, time.Millisecond) + + k.consumerWG.Wait() + assert.Equal(t, int64(5), consumeCalled.Load()) + }) + + t.Run("Calling reloadConsumerGroup causes context to be cancelled and Consume called again (close by closed)", func(t *testing.T) { + var consumeCalled atomic.Int64 + var cancelCalled atomic.Int64 + cg := mocks.NewConsumerGroup().WithConsumeFn(func(ctx context.Context, _ []string, _ sarama.ConsumerGroupHandler) error { + consumeCalled.Add(1) + <-ctx.Done() + cancelCalled.Add(1) + return nil + }) + k := &Kafka{ + logger: logger.NewLogger("test"), + cg: cg, + consumerCancel: nil, + closeCh: make(chan struct{}), + subscribeTopics: map[string]SubscriptionHandlerConfig{"foo": {}}, + consumeRetryInterval: time.Millisecond, + } + + k.reloadConsumerGroup() + assert.Eventually(t, func() bool { + return consumeCalled.Load() == 1 + }, time.Second, time.Millisecond) + assert.Equal(t, int64(0), cancelCalled.Load()) + + k.reloadConsumerGroup() + assert.Eventually(t, func() bool { + return consumeCalled.Load() == 2 + }, time.Second, time.Millisecond) + assert.Equal(t, int64(1), cancelCalled.Load()) + + k.closed.Store(true) + k.reloadConsumerGroup() + assert.Equal(t, int64(2), cancelCalled.Load()) + assert.Equal(t, int64(2), consumeCalled.Load()) + }) + + t.Run("Calling reloadConsumerGroup causes context to be cancelled and Consume called again (close by no subscriptions)", func(t *testing.T) { + var consumeCalled atomic.Int64 + var cancelCalled atomic.Int64 + cg := mocks.NewConsumerGroup().WithConsumeFn(func(ctx context.Context, _ []string, _ sarama.ConsumerGroupHandler) error { + consumeCalled.Add(1) + <-ctx.Done() + cancelCalled.Add(1) + return nil + }) + k := &Kafka{ + logger: logger.NewLogger("test"), + cg: cg, + consumerCancel: nil, + closeCh: make(chan struct{}), + subscribeTopics: map[string]SubscriptionHandlerConfig{"foo": {}}, + consumeRetryInterval: time.Millisecond, + } + + k.reloadConsumerGroup() + assert.Eventually(t, func() bool { + return consumeCalled.Load() == 1 + }, time.Second, time.Millisecond) + assert.Equal(t, int64(0), cancelCalled.Load()) + + k.reloadConsumerGroup() + assert.Eventually(t, func() bool { + return consumeCalled.Load() == 2 + }, time.Second, time.Millisecond) + assert.Equal(t, int64(1), cancelCalled.Load()) + + k.subscribeTopics = nil + k.reloadConsumerGroup() + assert.Equal(t, int64(2), cancelCalled.Load()) + assert.Equal(t, int64(2), consumeCalled.Load()) + }) +} + +func Test_Subscribe(t *testing.T) { + t.Run("Calling subscribe with no topics should not consume", func(t *testing.T) { + var consumeCalled atomic.Int64 + var cancelCalled atomic.Int64 + cg := mocks.NewConsumerGroup().WithConsumeFn(func(ctx context.Context, _ []string, _ sarama.ConsumerGroupHandler) error { + consumeCalled.Add(1) + <-ctx.Done() + cancelCalled.Add(1) + return nil + }) + k := &Kafka{ + logger: logger.NewLogger("test"), + cg: cg, + consumerCancel: nil, + closeCh: make(chan struct{}), + consumeRetryInterval: time.Millisecond, + subscribeTopics: make(TopicHandlerConfig), + } + + k.Subscribe(context.Background(), SubscriptionHandlerConfig{}) + + assert.Nil(t, k.consumerCancel) + assert.Equal(t, int64(0), consumeCalled.Load()) + assert.Equal(t, int64(0), cancelCalled.Load()) + }) + + t.Run("Calling subscribe when closed should not consume", func(t *testing.T) { + var consumeCalled atomic.Int64 + var cancelCalled atomic.Int64 + cg := mocks.NewConsumerGroup().WithConsumeFn(func(ctx context.Context, _ []string, _ sarama.ConsumerGroupHandler) error { + consumeCalled.Add(1) + <-ctx.Done() + cancelCalled.Add(1) + return nil + }) + k := &Kafka{ + logger: logger.NewLogger("test"), + cg: cg, + consumerCancel: nil, + closeCh: make(chan struct{}), + consumeRetryInterval: time.Millisecond, + subscribeTopics: make(TopicHandlerConfig), + } + + k.closed.Store(true) + + k.Subscribe(context.Background(), SubscriptionHandlerConfig{}, "abc") + + assert.Nil(t, k.consumerCancel) + assert.Equal(t, int64(0), consumeCalled.Load()) + assert.Equal(t, int64(0), cancelCalled.Load()) + }) + + t.Run("Subscribe should subscribe to a topic until context is cancelled", func(t *testing.T) { + var consumeCalled atomic.Int64 + var cancelCalled atomic.Int64 + var consumeTopics atomic.Value + cg := mocks.NewConsumerGroup().WithConsumeFn(func(ctx context.Context, topics []string, _ sarama.ConsumerGroupHandler) error { + consumeTopics.Store(topics) + consumeCalled.Add(1) + <-ctx.Done() + cancelCalled.Add(1) + return nil + }) + k := &Kafka{ + logger: logger.NewLogger("test"), + cg: cg, + consumerCancel: nil, + closeCh: make(chan struct{}), + consumeRetryInterval: time.Millisecond, + subscribeTopics: make(TopicHandlerConfig), + } + + ctx, cancel := context.WithCancel(context.Background()) + k.Subscribe(ctx, SubscriptionHandlerConfig{}, "abc") + + assert.Eventually(t, func() bool { + return consumeCalled.Load() == 1 + }, time.Second, time.Millisecond) + assert.Equal(t, int64(0), cancelCalled.Load()) + + cancel() + + assert.Eventually(t, func() bool { + return cancelCalled.Load() == 1 + }, time.Second, time.Millisecond) + assert.Equal(t, int64(1), consumeCalled.Load()) + + assert.Equal(t, []string{"abc"}, consumeTopics.Load()) + }) + + t.Run("Calling subscribe multiple times with new topics should re-consume will full topics list", func(t *testing.T) { + var consumeCalled atomic.Int64 + var cancelCalled atomic.Int64 + var consumeTopics atomic.Value + cg := mocks.NewConsumerGroup().WithConsumeFn(func(ctx context.Context, topics []string, _ sarama.ConsumerGroupHandler) error { + consumeTopics.Store(topics) + consumeCalled.Add(1) + <-ctx.Done() + cancelCalled.Add(1) + return nil + }) + k := &Kafka{ + logger: logger.NewLogger("test"), + cg: cg, + consumerCancel: nil, + closeCh: make(chan struct{}), + consumeRetryInterval: time.Millisecond, + subscribeTopics: make(TopicHandlerConfig), + } + + ctx, cancel := context.WithCancel(context.Background()) + k.Subscribe(ctx, SubscriptionHandlerConfig{}, "abc") + + assert.Eventually(t, func() bool { + return consumeCalled.Load() == 1 + }, time.Second, time.Millisecond) + assert.Equal(t, int64(0), cancelCalled.Load()) + assert.Equal(t, []string{"abc"}, consumeTopics.Load()) + assert.Equal(t, TopicHandlerConfig{"abc": SubscriptionHandlerConfig{}}, k.subscribeTopics) + + k.Subscribe(ctx, SubscriptionHandlerConfig{}, "def") + assert.Equal(t, TopicHandlerConfig{ + "abc": SubscriptionHandlerConfig{}, + "def": SubscriptionHandlerConfig{}, + }, k.subscribeTopics) + + assert.Eventually(t, func() bool { + return consumeCalled.Load() == 2 + }, time.Second, time.Millisecond) + assert.Equal(t, int64(1), cancelCalled.Load()) + assert.ElementsMatch(t, []string{"abc", "def"}, consumeTopics.Load()) + + cancel() + assert.Eventually(t, func() bool { + return consumeCalled.Load() == 3 + }, time.Second, time.Millisecond) + assert.Equal(t, int64(3), cancelCalled.Load()) + + k.Subscribe(ctx, SubscriptionHandlerConfig{}) + assert.Nil(t, k.consumerCancel) + assert.Empty(t, k.subscribeTopics) + }) + + t.Run("Consume return immediately if returns a context cancelled error", func(t *testing.T) { + var consumeCalled atomic.Int64 + cg := mocks.NewConsumerGroup().WithConsumeFn(func(ctx context.Context, _ []string, _ sarama.ConsumerGroupHandler) error { + consumeCalled.Add(1) + if consumeCalled.Load() == 5 { + return context.Canceled + } + return errors.New("some error") + }) + k := &Kafka{ + logger: logger.NewLogger("test"), + cg: cg, + consumerCancel: nil, + closeCh: make(chan struct{}), + subscribeTopics: make(TopicHandlerConfig), + consumeRetryInterval: time.Millisecond, + } + + k.Subscribe(context.Background(), SubscriptionHandlerConfig{}, "foo") + assert.Equal(t, TopicHandlerConfig{"foo": SubscriptionHandlerConfig{}}, k.subscribeTopics) + assert.Eventually(t, func() bool { + return consumeCalled.Load() == 5 + }, time.Second, time.Millisecond) + k.consumerWG.Wait() + assert.Equal(t, int64(5), consumeCalled.Load()) + assert.Equal(t, TopicHandlerConfig{"foo": SubscriptionHandlerConfig{}}, k.subscribeTopics) + }) + + t.Run("Consume dynamically changes topics which are being consumed", func(t *testing.T) { + var consumeTopics atomic.Value + var consumeCalled atomic.Int64 + var cancelCalled atomic.Int64 + cg := mocks.NewConsumerGroup().WithConsumeFn(func(ctx context.Context, topics []string, _ sarama.ConsumerGroupHandler) error { + consumeTopics.Store(topics) + consumeCalled.Add(1) + <-ctx.Done() + cancelCalled.Add(1) + return nil + }) + k := &Kafka{ + logger: logger.NewLogger("test"), + cg: cg, + consumerCancel: nil, + closeCh: make(chan struct{}), + subscribeTopics: make(TopicHandlerConfig), + consumeRetryInterval: time.Millisecond, + } + + ctx1, cancel1 := context.WithCancel(context.Background()) + k.Subscribe(ctx1, SubscriptionHandlerConfig{}, "abc") + assert.Eventually(t, func() bool { + return consumeCalled.Load() == 1 + }, time.Second, time.Millisecond) + assert.ElementsMatch(t, []string{"abc"}, consumeTopics.Load()) + assert.Equal(t, int64(0), cancelCalled.Load()) + + ctx2, cancel2 := context.WithCancel(context.Background()) + k.Subscribe(ctx2, SubscriptionHandlerConfig{}, "def") + assert.Eventually(t, func() bool { + return consumeCalled.Load() == 2 + }, time.Second, time.Millisecond) + assert.ElementsMatch(t, []string{"abc", "def"}, consumeTopics.Load()) + assert.Equal(t, int64(1), cancelCalled.Load()) + + ctx3, cancel3 := context.WithCancel(context.Background()) + k.Subscribe(ctx3, SubscriptionHandlerConfig{}, "123") + assert.Eventually(t, func() bool { + return consumeCalled.Load() == 3 + }, time.Second, time.Millisecond) + assert.ElementsMatch(t, []string{"abc", "def", "123"}, consumeTopics.Load()) + assert.Equal(t, int64(2), cancelCalled.Load()) + + cancel2() + assert.Eventually(t, func() bool { + return consumeCalled.Load() == 4 + }, time.Second, time.Millisecond) + assert.ElementsMatch(t, []string{"abc", "123"}, consumeTopics.Load()) + assert.Equal(t, int64(3), cancelCalled.Load()) + + ctx2, cancel2 = context.WithCancel(context.Background()) + k.Subscribe(ctx2, SubscriptionHandlerConfig{}, "456") + assert.Eventually(t, func() bool { + return consumeCalled.Load() == 5 + }, time.Second, time.Millisecond) + assert.ElementsMatch(t, []string{"abc", "123", "456"}, consumeTopics.Load()) + assert.Equal(t, int64(4), cancelCalled.Load()) + + cancel1() + cancel3() + + assert.Eventually(t, func() bool { + return consumeCalled.Load() == 7 + }, time.Second, time.Millisecond) + assert.ElementsMatch(t, []string{"456"}, consumeTopics.Load()) + assert.Equal(t, int64(6), cancelCalled.Load()) + + cancel2() + assert.Eventually(t, func() bool { + return cancelCalled.Load() == 7 + }, time.Second, time.Millisecond) + assert.Empty(t, k.subscribeTopics) + assert.Equal(t, int64(7), consumeCalled.Load()) + }) + + t.Run("Can call Subscribe concurrently", func(t *testing.T) { + var cancelCalled atomic.Int64 + var consumeCalled atomic.Int64 + cg := mocks.NewConsumerGroup().WithConsumeFn(func(ctx context.Context, topics []string, _ sarama.ConsumerGroupHandler) error { + consumeCalled.Add(1) + <-ctx.Done() + cancelCalled.Add(1) + return nil + }) + k := &Kafka{ + logger: logger.NewLogger("test"), + cg: cg, + consumerCancel: nil, + closeCh: make(chan struct{}), + subscribeTopics: make(TopicHandlerConfig), + consumeRetryInterval: time.Millisecond, + } + + ctx, cancel := context.WithCancel(context.Background()) + for i := 0; i < 100; i++ { + go func(i int) { + k.Subscribe(ctx, SubscriptionHandlerConfig{}, strconv.Itoa(i)) + }(i) + } + + assert.Eventually(t, func() bool { + return consumeCalled.Load() == 100 + }, time.Second, time.Millisecond) + assert.Equal(t, int64(99), cancelCalled.Load()) + cancel() + assert.Eventually(t, func() bool { + return cancelCalled.Load() == 199 + }, time.Second, time.Millisecond) + assert.Equal(t, int64(199), consumeCalled.Load()) + }) +} diff --git a/pubsub/kafka/kafka.go b/pubsub/kafka/kafka.go index 6a81184a73..f47bf06dd9 100644 --- a/pubsub/kafka/kafka.go +++ b/pubsub/kafka/kafka.go @@ -54,7 +54,9 @@ func (p *PubSub) Subscribe(ctx context.Context, req pubsub.SubscribeRequest, han Handler: adaptHandler(handler), ValueSchemaType: valueSchemaType, } - return p.subscribeUtil(ctx, req, handlerConfig) + + p.subscribeUtil(ctx, req, handlerConfig) + return nil } func (p *PubSub) BulkSubscribe(ctx context.Context, req pubsub.SubscribeRequest, @@ -78,36 +80,24 @@ func (p *PubSub) BulkSubscribe(ctx context.Context, req pubsub.SubscribeRequest, BulkHandler: adaptBulkHandler(handler), ValueSchemaType: valueSchemaType, } - return p.subscribeUtil(ctx, req, handlerConfig) + p.subscribeUtil(ctx, req, handlerConfig) + return nil } -func (p *PubSub) subscribeUtil(ctx context.Context, req pubsub.SubscribeRequest, handlerConfig kafka.SubscriptionHandlerConfig) error { - p.kafka.AddTopicHandler(req.Topic, handlerConfig) +func (p *PubSub) subscribeUtil(ctx context.Context, req pubsub.SubscribeRequest, handlerConfig kafka.SubscriptionHandlerConfig) { + ctx, cancel := context.WithCancel(ctx) p.wg.Add(1) go func() { - defer p.wg.Done() - // Wait for context cancelation select { case <-ctx.Done(): case <-p.closeCh: } - - // Remove the topic handler before restarting the subscriber - p.kafka.RemoveTopicHandler(req.Topic) - - // If the component's context has been canceled, do not re-subscribe - if ctx.Err() != nil { - return - } - - err := p.kafka.Subscribe(ctx) - if err != nil { - p.logger.Errorf("kafka pubsub: error re-subscribing: %v", err) - } + cancel() + p.wg.Done() }() - return p.kafka.Subscribe(ctx) + p.kafka.Subscribe(ctx, handlerConfig, req.Topic) } // NewKafka returns a new kafka pubsub instance. diff --git a/tests/certification/go.mod b/tests/certification/go.mod index 43ae80133a..bc875115ac 100644 --- a/tests/certification/go.mod +++ b/tests/certification/go.mod @@ -19,8 +19,8 @@ require ( github.com/cenkalti/backoff/v4 v4.2.1 github.com/cloudwego/kitex v0.5.0 github.com/cloudwego/kitex-examples v0.1.1 - github.com/dapr/components-contrib v1.13.0-rc.3 - github.com/dapr/dapr v1.13.0-rc.2 + github.com/dapr/components-contrib v1.13.0-rc.6 + github.com/dapr/dapr v1.13.0-rc.7 github.com/dapr/go-sdk v1.6.1-0.20231102031149-87bbb8cd690a github.com/dapr/kit v0.13.0 github.com/eclipse/paho.mqtt.golang v1.4.3 diff --git a/tests/certification/go.sum b/tests/certification/go.sum index b601744a6d..06a0361a71 100644 --- a/tests/certification/go.sum +++ b/tests/certification/go.sum @@ -388,8 +388,8 @@ github.com/cyphar/filepath-securejoin v0.2.4 h1:Ugdm7cg7i6ZK6x3xDF1oEu1nfkyfH53E github.com/cyphar/filepath-securejoin v0.2.4/go.mod h1:aPGpWjXOXUn2NCNjFvBE6aRxGGx79pTxQpKOJNYHHl4= github.com/danieljoos/wincred v1.1.2 h1:QLdCxFs1/Yl4zduvBdcHB8goaYk9RARS2SgLLRuAyr0= github.com/danieljoos/wincred v1.1.2/go.mod h1:GijpziifJoIBfYh+S7BbkdUTU4LfM+QnGqR5Vl2tAx0= -github.com/dapr/dapr v1.13.0-rc.2 h1:Y5tQ07KB856aSWXxVjb/Lob4AT8Gy/hJxZtwODI21CI= -github.com/dapr/dapr v1.13.0-rc.2/go.mod h1:QvxJ5htwv17PeRfFMGkHznEVRkpnt35re7TpF4CsCc8= +github.com/dapr/dapr v1.13.0-rc.7 h1:Z3r+eCPlWK6reJcfNuSL5Gu2+V81qyOIBgvY6EV8ZP4= +github.com/dapr/dapr v1.13.0-rc.7/go.mod h1:NHMC48qz9yEwIRDT1apo3GO+2SVoz5Ae7ejtg2B48RM= github.com/dapr/go-sdk v1.6.1-0.20231102031149-87bbb8cd690a h1:Sapb/wyFdMRDxn6PYYNh/P3WW3WcOIrpRSUdW+LT3bE= github.com/dapr/go-sdk v1.6.1-0.20231102031149-87bbb8cd690a/go.mod h1:DtFOk+AKGMho/vDTECVVX7WgovkGw64X30nyaEmRGXw= github.com/dapr/kit v0.13.0 h1:4S+5QqDCreva+MBONtIgxeg6B2b1W89bB8F5lqKgTa0=