Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: Potential data race #1338

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions pulsar/consumer_multitopic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,20 @@
package pulsar

import (
"context"
"errors"
"fmt"
"strings"
"testing"
"time"

"github.com/stretchr/testify/assert"

"github.com/apache/pulsar-client-go/pulsar/internal"
pb "github.com/apache/pulsar-client-go/pulsar/internal/pulsar_proto"
"github.com/apache/pulsar-client-go/pulsaradmin"
"github.com/apache/pulsar-client-go/pulsaradmin/pkg/admin/config"
"github.com/apache/pulsar-client-go/pulsaradmin/pkg/utils"
"github.com/stretchr/testify/assert"
)

func TestMultiTopicConsumerReceive(t *testing.T) {
Expand Down Expand Up @@ -330,7 +332,7 @@ func (dummyConnection) SendRequestNoWait(_ *pb.BaseCommand) error {
return nil
}

func (dummyConnection) WriteData(_ internal.Buffer) {
func (dummyConnection) WriteData(_ context.Context, _ internal.Buffer) {
}

func (dummyConnection) RegisterListener(_ uint64, _ internal.ConnectionListener) error {
Expand Down
48 changes: 32 additions & 16 deletions pulsar/internal/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package internal

import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
Expand Down Expand Up @@ -78,7 +79,7 @@ type ConnectionListener interface {
type Connection interface {
SendRequest(requestID uint64, req *pb.BaseCommand, callback func(*pb.BaseCommand, error))
SendRequestNoWait(req *pb.BaseCommand) error
WriteData(data Buffer)
WriteData(ctx context.Context, data Buffer)
RegisterListener(id uint64, listener ConnectionListener) error
UnregisterListener(id uint64)
AddConsumeHandler(id uint64, handler ConsumerHandler) error
Expand Down Expand Up @@ -129,6 +130,11 @@ type request struct {
callback func(command *pb.BaseCommand, err error)
}

type dataRequest struct {
ctx context.Context
data Buffer
}

type connection struct {
started int32
connectionTimeout time.Duration
Expand Down Expand Up @@ -157,7 +163,7 @@ type connection struct {
incomingRequestsCh chan *request
closeCh chan struct{}
readyCh chan struct{}
writeRequestsCh chan Buffer
writeRequestsCh chan *dataRequest

pendingLock sync.Mutex
pendingReqs map[uint64]*request
Expand Down Expand Up @@ -209,7 +215,7 @@ func newConnection(opts connectionOptions) *connection {
// partition produces writing on a single connection. In general it's
// good to keep this above the number of partition producers assigned
// to a single connection.
writeRequestsCh: make(chan Buffer, 256),
writeRequestsCh: make(chan *dataRequest, 256),
listeners: make(map[uint64]ConnectionListener),
consumerHandlers: make(map[uint64]ConsumerHandler),
metrics: opts.metrics,
Expand Down Expand Up @@ -421,11 +427,11 @@ func (c *connection) run() {
return // TODO: this never gonna be happen
}
c.internalSendRequest(req)
case data := <-c.writeRequestsCh:
if data == nil {
case req := <-c.writeRequestsCh:
if req == nil {
return
}
c.internalWriteData(data)
c.internalWriteData(req.ctx, req.data)

case <-pingSendTicker.C:
c.sendPing()
Expand All @@ -450,22 +456,26 @@ func (c *connection) runPingCheck(pingCheckTicker *time.Ticker) {
}
}

func (c *connection) WriteData(data Buffer) {
func (c *connection) WriteData(ctx context.Context, data Buffer) {
select {
case c.writeRequestsCh <- data:
case c.writeRequestsCh <- &dataRequest{ctx: ctx, data: data}:
// Channel is not full
return

case <-ctx.Done():
c.log.Debug("Write data context cancelled")
return
default:
// Channel full, fallback to probe if connection is closed
}

for {
select {
case c.writeRequestsCh <- data:
case c.writeRequestsCh <- &dataRequest{ctx: ctx, data: data}:
// Successfully wrote on the channel
return

case <-ctx.Done():
c.log.Debug("Write data context cancelled")
return
case <-time.After(100 * time.Millisecond):
// The channel is either:
// 1. blocked, in which case we need to wait until we have space
Expand All @@ -481,11 +491,17 @@ func (c *connection) WriteData(data Buffer) {

}

func (c *connection) internalWriteData(data Buffer) {
func (c *connection) internalWriteData(ctx context.Context, data Buffer) {
c.log.Debug("Write data: ", data.ReadableBytes())
if _, err := c.cnx.Write(data.ReadableSlice()); err != nil {
c.log.WithError(err).Warn("Failed to write on connection")
c.Close()

select {
case <-ctx.Done():
return
default:
if _, err := c.cnx.Write(data.ReadableSlice()); err != nil {
c.log.WithError(err).Warn("Failed to write on connection")
c.Close()
}
}
}

Expand All @@ -510,7 +526,7 @@ func (c *connection) writeCommand(cmd *pb.BaseCommand) {
}

c.writeBuffer.WrittenBytes(cmdSize)
c.internalWriteData(c.writeBuffer)
c.internalWriteData(context.Background(), c.writeBuffer)
}

func (c *connection) receivedCommand(cmd *pb.BaseCommand, headersAndPayload Buffer) {
Expand Down
19 changes: 14 additions & 5 deletions pulsar/producer_partition.go
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ func (p *partitionProducer) grabCnx(assignedBrokerURL string) error {
pi.sentAt = time.Now()
pi.Unlock()
p.pendingQueue.Put(pi)
p._getConn().WriteData(pi.buffer)
p._getConn().WriteData(pi.ctx, pi.buffer)

if pi == lastViewItem {
break
Expand Down Expand Up @@ -837,6 +837,8 @@ func (p *partitionProducer) internalSingleSend(

type pendingItem struct {
sync.Mutex
ctx context.Context
cancel context.CancelFunc
buffer internal.Buffer
sequenceID uint64
createdAt time.Time
Expand Down Expand Up @@ -895,14 +897,17 @@ func (p *partitionProducer) writeData(buffer internal.Buffer, sequenceID uint64,
return
default:
now := time.Now()
ctx, cancel := context.WithCancel(context.Background())
p.pendingQueue.Put(&pendingItem{
ctx: ctx,
cancel: cancel,
createdAt: now,
sentAt: now,
buffer: buffer,
sequenceID: sequenceID,
sendRequests: callbacks,
})
p._getConn().WriteData(buffer)
p._getConn().WriteData(ctx, buffer)
}
}

Expand Down Expand Up @@ -1579,14 +1584,14 @@ type sendRequest struct {
uuid string
chunkRecorder *chunkRecorder

/// resource management
// resource management

memLimit internal.MemoryLimitController
reservedMem int64
semaphore internal.Semaphore
reservedSemaphore int

/// convey settable state
// convey settable state

sendAsBatch bool
transaction *transaction
Expand Down Expand Up @@ -1659,7 +1664,7 @@ func (sr *sendRequest) done(msgID MessageID, err error) {
}

func (p *partitionProducer) blockIfQueueFull() bool {
//DisableBlockIfQueueFull == false means enable block
// DisableBlockIfQueueFull == false means enable block
return !p.options.DisableBlockIfQueueFull
}

Expand Down Expand Up @@ -1741,6 +1746,10 @@ func (i *pendingItem) done(err error) {
if i.flushCallback != nil {
i.flushCallback(err)
}

if i.cancel != nil {
i.cancel()
}
}

// _setConn sets the internal connection field of this partition producer atomically.
Expand Down
Loading