diff --git a/pulsar/consumer_multitopic_test.go b/pulsar/consumer_multitopic_test.go index cd236ecc2..30ae5ccd1 100644 --- a/pulsar/consumer_multitopic_test.go +++ b/pulsar/consumer_multitopic_test.go @@ -18,18 +18,20 @@ package pulsar import ( + "context" "errors" "fmt" "strings" "testing" "time" + "github.com/stretchr/testify/assert" + "github.com/apache/pulsar-client-go/pulsar/internal" pb "github.com/apache/pulsar-client-go/pulsar/internal/pulsar_proto" "github.com/apache/pulsar-client-go/pulsaradmin" "github.com/apache/pulsar-client-go/pulsaradmin/pkg/admin/config" "github.com/apache/pulsar-client-go/pulsaradmin/pkg/utils" - "github.com/stretchr/testify/assert" ) func TestMultiTopicConsumerReceive(t *testing.T) { @@ -330,7 +332,7 @@ func (dummyConnection) SendRequestNoWait(_ *pb.BaseCommand) error { return nil } -func (dummyConnection) WriteData(_ internal.Buffer) { +func (dummyConnection) WriteData(_ context.Context, _ internal.Buffer) { } func (dummyConnection) RegisterListener(_ uint64, _ internal.ConnectionListener) error { diff --git a/pulsar/internal/connection.go b/pulsar/internal/connection.go index 57fc72419..2ad1acb5f 100644 --- a/pulsar/internal/connection.go +++ b/pulsar/internal/connection.go @@ -18,6 +18,7 @@ package internal import ( + "context" "crypto/tls" "crypto/x509" "errors" @@ -78,7 +79,7 @@ type ConnectionListener interface { type Connection interface { SendRequest(requestID uint64, req *pb.BaseCommand, callback func(*pb.BaseCommand, error)) SendRequestNoWait(req *pb.BaseCommand) error - WriteData(data Buffer) + WriteData(ctx context.Context, data Buffer) RegisterListener(id uint64, listener ConnectionListener) error UnregisterListener(id uint64) AddConsumeHandler(id uint64, handler ConsumerHandler) error @@ -129,6 +130,11 @@ type request struct { callback func(command *pb.BaseCommand, err error) } +type dataRequest struct { + ctx context.Context + data Buffer +} + type connection struct { started int32 connectionTimeout time.Duration @@ -157,7 +163,7 @@ type connection struct { incomingRequestsCh chan *request closeCh chan struct{} readyCh chan struct{} - writeRequestsCh chan Buffer + writeRequestsCh chan *dataRequest pendingLock sync.Mutex pendingReqs map[uint64]*request @@ -209,7 +215,7 @@ func newConnection(opts connectionOptions) *connection { // partition produces writing on a single connection. In general it's // good to keep this above the number of partition producers assigned // to a single connection. - writeRequestsCh: make(chan Buffer, 256), + writeRequestsCh: make(chan *dataRequest, 256), listeners: make(map[uint64]ConnectionListener), consumerHandlers: make(map[uint64]ConsumerHandler), metrics: opts.metrics, @@ -421,11 +427,11 @@ func (c *connection) run() { return // TODO: this never gonna be happen } c.internalSendRequest(req) - case data := <-c.writeRequestsCh: - if data == nil { + case req := <-c.writeRequestsCh: + if req == nil { return } - c.internalWriteData(data) + c.internalWriteData(req.ctx, req.data) case <-pingSendTicker.C: c.sendPing() @@ -450,22 +456,26 @@ func (c *connection) runPingCheck(pingCheckTicker *time.Ticker) { } } -func (c *connection) WriteData(data Buffer) { +func (c *connection) WriteData(ctx context.Context, data Buffer) { select { - case c.writeRequestsCh <- data: + case c.writeRequestsCh <- &dataRequest{ctx: ctx, data: data}: // Channel is not full return - + case <-ctx.Done(): + c.log.Debug("Write data context cancelled") + return default: // Channel full, fallback to probe if connection is closed } for { select { - case c.writeRequestsCh <- data: + case c.writeRequestsCh <- &dataRequest{ctx: ctx, data: data}: // Successfully wrote on the channel return - + case <-ctx.Done(): + c.log.Debug("Write data context cancelled") + return case <-time.After(100 * time.Millisecond): // The channel is either: // 1. blocked, in which case we need to wait until we have space @@ -481,11 +491,17 @@ func (c *connection) WriteData(data Buffer) { } -func (c *connection) internalWriteData(data Buffer) { +func (c *connection) internalWriteData(ctx context.Context, data Buffer) { c.log.Debug("Write data: ", data.ReadableBytes()) - if _, err := c.cnx.Write(data.ReadableSlice()); err != nil { - c.log.WithError(err).Warn("Failed to write on connection") - c.Close() + + select { + case <-ctx.Done(): + return + default: + if _, err := c.cnx.Write(data.ReadableSlice()); err != nil { + c.log.WithError(err).Warn("Failed to write on connection") + c.Close() + } } } @@ -510,7 +526,7 @@ func (c *connection) writeCommand(cmd *pb.BaseCommand) { } c.writeBuffer.WrittenBytes(cmdSize) - c.internalWriteData(c.writeBuffer) + c.internalWriteData(context.Background(), c.writeBuffer) } func (c *connection) receivedCommand(cmd *pb.BaseCommand, headersAndPayload Buffer) { diff --git a/pulsar/producer_partition.go b/pulsar/producer_partition.go index 448f780cf..f6523124c 100755 --- a/pulsar/producer_partition.go +++ b/pulsar/producer_partition.go @@ -394,7 +394,7 @@ func (p *partitionProducer) grabCnx(assignedBrokerURL string) error { pi.sentAt = time.Now() pi.Unlock() p.pendingQueue.Put(pi) - p._getConn().WriteData(pi.buffer) + p._getConn().WriteData(pi.ctx, pi.buffer) if pi == lastViewItem { break @@ -837,6 +837,8 @@ func (p *partitionProducer) internalSingleSend( type pendingItem struct { sync.Mutex + ctx context.Context + cancel context.CancelFunc buffer internal.Buffer sequenceID uint64 createdAt time.Time @@ -895,14 +897,17 @@ func (p *partitionProducer) writeData(buffer internal.Buffer, sequenceID uint64, return default: now := time.Now() + ctx, cancel := context.WithCancel(context.Background()) p.pendingQueue.Put(&pendingItem{ + ctx: ctx, + cancel: cancel, createdAt: now, sentAt: now, buffer: buffer, sequenceID: sequenceID, sendRequests: callbacks, }) - p._getConn().WriteData(buffer) + p._getConn().WriteData(ctx, buffer) } } @@ -1579,14 +1584,14 @@ type sendRequest struct { uuid string chunkRecorder *chunkRecorder - /// resource management + // resource management memLimit internal.MemoryLimitController reservedMem int64 semaphore internal.Semaphore reservedSemaphore int - /// convey settable state + // convey settable state sendAsBatch bool transaction *transaction @@ -1659,7 +1664,7 @@ func (sr *sendRequest) done(msgID MessageID, err error) { } func (p *partitionProducer) blockIfQueueFull() bool { - //DisableBlockIfQueueFull == false means enable block + // DisableBlockIfQueueFull == false means enable block return !p.options.DisableBlockIfQueueFull } @@ -1741,6 +1746,10 @@ func (i *pendingItem) done(err error) { if i.flushCallback != nil { i.flushCallback(err) } + + if i.cancel != nil { + i.cancel() + } } // _setConn sets the internal connection field of this partition producer atomically.