Skip to content

Fix producers reconnection deadlock #394

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 23 additions & 19 deletions pkg/stream/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -471,37 +471,41 @@ func (c *Client) closeHartBeat() {
}

func (c *Client) Close() error {

c.closeHartBeat()
for _, p := range c.coordinator.Producers() {
err := c.coordinator.RemoveProducerById(p.(*Producer).id, Event{
c.coordinator.Producers().Range(func(_, p any) bool {
producer := p.(*Producer)
err := c.coordinator.RemoveProducerById(producer.id, Event{
Command: CommandClose,
StreamName: p.(*Producer).GetStreamName(),
Name: p.(*Producer).GetName(),
StreamName: producer.GetStreamName(),
Name: producer.GetName(),
Reason: SocketClosed,
Err: nil,
})

if err != nil {
logs.LogWarn("error removing producer: %s", err)
}
}

for _, cs := range c.coordinator.GetConsumers() {
if cs != nil {
err := c.coordinator.RemoveConsumerById(cs.(*Consumer).ID, Event{
Command: CommandClose,
StreamName: cs.(*Consumer).GetStreamName(),
Name: cs.(*Consumer).GetName(),
Reason: SocketClosed,
Err: nil,
})
return true
})

if err != nil {
logs.LogWarn("error removing consumer: %s", err)
}
c.coordinator.Consumers().Range(func(_, cs any) bool {
consumer := cs.(*Consumer)
err := c.coordinator.RemoveConsumerById(consumer.ID, Event{
Command: CommandClose,
StreamName: consumer.GetStreamName(),
Name: consumer.GetName(),
Reason: SocketClosed,
Err: nil,
})

if err != nil {
logs.LogWarn("error removing consumer: %s", err)
}
}

return true
})

if c.getSocket().isOpen() {

res := c.coordinator.NewResponse(CommandClose)
Expand Down
98 changes: 52 additions & 46 deletions pkg/stream/coordinator.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ import (

type Coordinator struct {
counter int
producers map[interface{}]interface{}
consumers map[interface{}]interface{}
producers *sync.Map
consumers *sync.Map
responses map[interface{}]interface{}
nextItemProducer uint8
nextItemConsumer uint8
Expand Down Expand Up @@ -43,8 +43,8 @@ type Response struct {

func NewCoordinator() *Coordinator {
return &Coordinator{mutex: &sync.Mutex{},
producers: make(map[interface{}]interface{}),
consumers: make(map[interface{}]interface{}),
producers: &sync.Map{},
consumers: &sync.Map{},
responses: make(map[interface{}]interface{})}
}

Expand Down Expand Up @@ -77,7 +77,7 @@ func (coordinator *Coordinator) NewProducer(
confirmMutex: &sync.Mutex{},
onClose: cleanUp,
}
coordinator.producers[lastId] = producer
coordinator.producers.Store(lastId, producer)
return producer, err
}

Expand All @@ -89,11 +89,8 @@ func (coordinator *Coordinator) RemoveConsumerById(id interface{}, reason Event)
return consumer.close(reason)

}
func (coordinator *Coordinator) GetConsumers() map[interface{}]interface{} {
coordinator.mutex.Lock()
defer coordinator.mutex.Unlock()
func (coordinator *Coordinator) Consumers() *sync.Map {
return coordinator.consumers

}

func (coordinator *Coordinator) RemoveProducerById(id uint8, reason Event) error {
Expand All @@ -117,7 +114,7 @@ func (coordinator *Coordinator) RemoveResponseById(id interface{}) error {
}

func (coordinator *Coordinator) ProducersCount() int {
return coordinator.count(coordinator.producers)
return coordinator.countSyncMap(coordinator.producers)
}

// response
Expand Down Expand Up @@ -198,28 +195,25 @@ func (coordinator *Coordinator) NewConsumer(messagesHandler MessagesHandler,
onClose: cleanUp,
}

coordinator.consumers[lastId] = item
coordinator.consumers.Store(lastId, item)

return item
}

func (coordinator *Coordinator) GetConsumerById(id interface{}) (*Consumer, error) {
v, err := coordinator.getById(id, coordinator.consumers)
if err != nil {
return nil, err
if consumer, exists := coordinator.consumers.Load(id); exists {
return consumer.(*Consumer), nil
}
return v.(*Consumer), err

return nil, errors.New("item #{id} not found ")
}

func (coordinator *Coordinator) ExtractConsumerById(id interface{}) (*Consumer, error) {
coordinator.mutex.Lock()
defer coordinator.mutex.Unlock()
if coordinator.consumers[id] == nil {
return nil, errors.New("item #{id} not found ")
if consumer, exists := coordinator.consumers.LoadAndDelete(id); exists {
return consumer.(*Consumer), nil
}
consumer := coordinator.consumers[id].(*Consumer)
coordinator.consumers[id] = nil
delete(coordinator.consumers, id)
return consumer, nil

return nil, errors.New("item #{id} not found ")
}

func (coordinator *Coordinator) GetResponseById(id uint32) (*Response, error) {
Expand All @@ -231,31 +225,26 @@ func (coordinator *Coordinator) GetResponseById(id uint32) (*Response, error) {
}

func (coordinator *Coordinator) ConsumersCount() int {
return coordinator.count(coordinator.consumers)
return coordinator.countSyncMap(coordinator.consumers)
}

func (coordinator *Coordinator) GetProducerById(id interface{}) (*Producer, error) {
v, err := coordinator.getById(id, coordinator.producers)
if err != nil {
return nil, err
if producer, exists := coordinator.producers.Load(id); exists {
return producer.(*Producer), nil
}
return v.(*Producer), err

return nil, errors.New("item #{id} not found ")
}

func (coordinator *Coordinator) ExtractProducerById(id interface{}) (*Producer, error) {
coordinator.mutex.Lock()
defer coordinator.mutex.Unlock()
if coordinator.producers[id] == nil {
return nil, errors.New("item #{id} not found ")
if producer, exists := coordinator.producers.LoadAndDelete(id); exists {
return producer.(*Producer), nil
}
producer := coordinator.producers[id].(*Producer)
coordinator.producers[id] = nil
delete(coordinator.producers, id)
return producer, nil

return nil, errors.New("item #{id} not found ")
}

// general functions

func (coordinator *Coordinator) getById(id interface{}, refmap map[interface{}]interface{}) (interface{}, error) {
coordinator.mutex.Lock()
defer coordinator.mutex.Unlock()
Expand All @@ -276,11 +265,16 @@ func (coordinator *Coordinator) removeById(id interface{}, refmap map[interface{
return nil
}

func (coordinator *Coordinator) count(refmap map[interface{}]interface{}) int {
coordinator.mutex.Lock()
defer coordinator.mutex.Unlock()
return len(refmap)
func (coordinator *Coordinator) countSyncMap(refmap *sync.Map) int {
count := 0
refmap.Range(func(_, _ interface{}) bool {
count++
return true
})

return count
}

func (coordinator *Coordinator) getNextProducerItem() (uint8, error) {
if coordinator.nextItemProducer >= ^uint8(0) {
return coordinator.reuseFreeId(coordinator.producers)
Expand All @@ -299,11 +293,11 @@ func (coordinator *Coordinator) getNextConsumerItem() (uint8, error) {
return res, nil
}

func (coordinator *Coordinator) reuseFreeId(refMap map[interface{}]interface{}) (byte, error) {
func (coordinator *Coordinator) reuseFreeId(refMap *sync.Map) (byte, error) {
maxValue := int(^uint8(0))
var result byte
for i := 0; i < maxValue; i++ {
if refMap[byte(i)] == nil {
if _, exists := refMap.Load(byte(i)); !exists {
return byte(i), nil
}
result++
Expand All @@ -314,8 +308,20 @@ func (coordinator *Coordinator) reuseFreeId(refMap map[interface{}]interface{})
return result, nil
}

func (coordinator *Coordinator) Producers() map[interface{}]interface{} {
coordinator.mutex.Lock()
defer coordinator.mutex.Unlock()
func (coordinator *Coordinator) Producers() *sync.Map {
return coordinator.producers
}

func (coordinator *Coordinator) Close() {
coordinator.producers.Range(func(_, producer interface{}) bool {
_ = producer.(*Producer).Close()

return true
})

coordinator.consumers.Range(func(_, consumer interface{}) bool {
_ = consumer.(*Consumer).Close()

return true
})
}
39 changes: 17 additions & 22 deletions pkg/stream/environment.go
Original file line number Diff line number Diff line change
Expand Up @@ -510,42 +510,43 @@ func (cc *environmentCoordinator) maybeCleanClients() {
}

func (c *Client) maybeCleanProducers(streamName string) {
c.mutex.Lock()
for pidx, producer := range c.coordinator.Producers() {
if producer.(*Producer).GetStreamName() == streamName {
c.coordinator.Producers().Range(func(pidx, p any) bool {
producer := p.(*Producer)
if producer.GetStreamName() == streamName {
err := c.coordinator.RemoveProducerById(pidx.(uint8), Event{
Command: CommandMetadataUpdate,
StreamName: streamName,
Name: producer.(*Producer).GetName(),
Name: producer.GetName(),
Reason: MetaDataUpdate,
Err: nil,
})
if err != nil {
return
return false
}
}
}
c.mutex.Unlock()

return true
})
}

func (c *Client) maybeCleanConsumers(streamName string) {
c.mutex.Lock()
for pidx, consumer := range c.coordinator.consumers {
if consumer.(*Consumer).options.streamName == streamName {
c.coordinator.Consumers().Range(func(pidx, cs any) bool {
consumer := cs.(*Consumer)
if consumer.options.streamName == streamName {
err := c.coordinator.RemoveConsumerById(pidx.(uint8), Event{
Command: CommandMetadataUpdate,
StreamName: streamName,
Name: consumer.(*Consumer).GetName(),
Name: consumer.GetName(),
Reason: MetaDataUpdate,
Err: nil,
})
if err != nil {
return
return false
}
}
}
c.mutex.Unlock()

return true
})
}

func (cc *environmentCoordinator) newProducer(leader *Broker, tcpParameters *TCPParameters, saslConfiguration *SaslConfiguration, streamName string, options *ProducerOptions, rpcTimeout time.Duration, cleanUp func()) (*Producer, error) {
Expand Down Expand Up @@ -643,15 +644,9 @@ func (cc *environmentCoordinator) newConsumer(connectionName string, leader *Bro
}

func (cc *environmentCoordinator) Close() error {

cc.clientsPerContext.Range(func(key, value any) bool {
client := value.(*Client)
for i := range client.coordinator.producers {
_ = client.coordinator.producers[i].(*Producer).Close()
}
for i := range client.coordinator.consumers {
_ = client.coordinator.consumers[i].(*Consumer).Close()
}
value.(*Client).coordinator.Close()

return true
})

Expand Down
57 changes: 57 additions & 0 deletions pkg/stream/super_stream_producer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -476,4 +476,61 @@ var _ = Describe("Super Stream Producer", Label("super-stream-producer"), func()
Expect(env.Close()).NotTo(HaveOccurred())
})

It("should reconnect to the same partition after a close event", func() {
const partitionsCount = 3
env, err := NewEnvironment(nil)
Expect(err).NotTo(HaveOccurred())

var superStream = fmt.Sprintf("reconnect-test-super-stream-%d", time.Now().Unix())
Expect(env.DeclareSuperStream(superStream, NewPartitionsOptions(partitionsCount))).NotTo(HaveOccurred())

superProducer, err := newSuperStreamProducer(env, superStream, &SuperStreamProducerOptions{
RoutingStrategy: NewHashRoutingStrategy(func(msg message.StreamMessage) string {
return msg.GetApplicationProperties()["routingKey"].(string)
}),
})
Expect(err).To(BeNil())
Expect(superProducer).NotTo(BeNil())
Expect(superProducer.init()).NotTo(HaveOccurred())
producers := superProducer.getProducers()
Expect(producers).To(HaveLen(partitionsCount))
partitionToClose := producers[0].GetStreamName()

// Declare synchronization helpers and listeners
partitionCloseEvent := make(chan bool)

// Listen for the partition close event and try to reconnect
go func(ch <-chan PPartitionClose) {
for event := range ch {
err := event.Context.ConnectPartition(event.Partition)
Expect(err).To(BeNil())

partitionCloseEvent <- true

break

}
}(superProducer.NotifyPartitionClose(1))

// Imitates metadataUpdateFrameHandler - it can happen when stream members are changed.
go func() {
client, ok := env.producers.getCoordinators()["localhost:5552"].clientsPerContext.Load(1)
Expect(ok).To(BeTrue())
client.(*Client).maybeCleanProducers(partitionToClose)
}()

// Wait for the partition close event
Eventually(partitionCloseEvent).WithTimeout(5 * time.Second).WithPolling(100 * time.Millisecond).Should(Receive())

// Verify that the partition was successfully reconnected
Expect(superProducer.getProducers()).To(HaveLen(partitionsCount))
reconnectedProducer := superProducer.getProducer(partitionToClose)
Expect(reconnectedProducer).NotTo(BeNil())

// Clean up
Expect(superProducer.Close()).NotTo(HaveOccurred())
Expect(env.DeleteSuperStream(superStream)).NotTo(HaveOccurred())
Expect(env.Close()).NotTo(HaveOccurred())
})

})