From c6f09e74586625da8522d8689d846b34e2662adb Mon Sep 17 00:00:00 2001
From: Baodi Shi
Date: Wed, 26 Feb 2025 18:43:13 +0800
Subject: [PATCH] Fix wrong result of hasNext after seeking by id or time

---
 pulsar/consumer_impl.go      |   1 +
 pulsar/consumer_partition.go |  97 ++++++++++++++++------
 pulsar/consumer_test.go      |   5 +-
 pulsar/impl_message.go       |  14 ++++
 pulsar/reader_impl.go        |  22 ++---
 pulsar/reader_test.go        | 156 +++++++++++++++++++++++++++++++++++
 6 files changed, 254 insertions(+), 41 deletions(-)

diff --git a/pulsar/consumer_impl.go b/pulsar/consumer_impl.go
index b4903516ea..8fdd336aca 100644
--- a/pulsar/consumer_impl.go
+++ b/pulsar/consumer_impl.go
@@ -88,6 +88,7 @@ func newConsumer(client *client, options ConsumerOptions) (Consumer, error) {
 
 	if options.EnableZeroQueueConsumer {
 		options.ReceiverQueueSize = 0
+		options.StartMessageIDInclusive = true
 	}
 
 	if options.Interceptors == nil {
diff --git a/pulsar/consumer_partition.go b/pulsar/consumer_partition.go
index 24ffa401f9..8bf28b3f0e 100644
--- a/pulsar/consumer_partition.go
+++ b/pulsar/consumer_partition.go
@@ -190,9 +190,13 @@ type partitionConsumer struct {
 	backoffPolicyFunc func() backoff.Policy
 
 	dispatcherSeekingControlCh chan struct{}
-	isSeeking                  atomic.Bool
-	ctx                        context.Context
-	cancelFunc                 context.CancelFunc
+	// handle to the dispatcher goroutine
+	isSeeking atomic.Bool
+	// After executing seekByTime, the client is unaware of the startMessageId.
+	// Use this flag to compare markDeletePosition with BrokerLastMessageId when checking hasMoreMessages.
+	hasSoughtByTime atomic.Bool
+	ctx             context.Context
+	cancelFunc      context.CancelFunc
 }
 
 // pauseDispatchMessage used to discard the message in the dispatcher goroutine.
@@ -429,11 +433,12 @@ func newPartitionConsumer(parent Consumer, client *client, options *partitionCon
 
 	startingMessageID := pc.startMessageID.get()
 	if pc.options.startMessageIDInclusive && startingMessageID != nil && startingMessageID.equal(latestMessageID) {
-		msgID, err := pc.requestGetLastMessageID()
+		msgIDResp, err := pc.requestGetLastMessageID()
 		if err != nil {
 			pc.Close()
 			return nil, err
 		}
+		msgID := convertToMessageID(msgIDResp.GetLastMessageId())
 
 		if msgID.entryID != noMessageEntry {
 			pc.startMessageID.set(msgID)
@@ -616,18 +621,27 @@ func (pc *partitionConsumer) internalUnsubscribe(unsub *unsubscribeRequest) {
 }
 
 func (pc *partitionConsumer) getLastMessageID() (*trackingMessageID, error) {
+	res, err := pc.getLastMessageIDAndMarkDeletePosition()
+	if err != nil {
+		return nil, err
+	}
+	return res.msgID, err
+}
+
+func (pc *partitionConsumer) getLastMessageIDAndMarkDeletePosition() (*getLastMsgResult, error) {
 	if state := pc.getConsumerState(); state == consumerClosed || state == consumerClosing {
 		pc.log.WithField("state", state).Error("Failed to getLastMessageID for the closing or closed consumer")
 		return nil, errors.New("failed to getLastMessageID for the closing or closed consumer")
 	}
 	bo := pc.backoffPolicyFunc()
-	request := func() (*trackingMessageID, error) {
+	request := func() (*getLastMsgResult, error) {
 		req := &getLastMsgIDRequest{doneCh: make(chan struct{})}
 		pc.eventsCh <- req
 
 		// wait for the request to complete
 		<-req.doneCh
-		return req.msgID, req.err
+		res := &getLastMsgResult{req.msgID, req.markDeletePosition}
+		return res, req.err
 	}
 
 	ctx, cancel := context.WithTimeout(context.Background(), pc.client.operationTimeout)
@@ -647,10 +661,16 @@ func (pc *partitionConsumer) getLastMessageID() (*trackingMessageID, error) {
 
 func (pc *partitionConsumer) internalGetLastMessageID(req *getLastMsgIDRequest) {
 	defer close(req.doneCh)
-	req.msgID, req.err = pc.requestGetLastMessageID()
+	rsp, err := pc.requestGetLastMessageID()
+	if err != nil {
+		req.err = err
+		return
+	}
+	req.msgID = convertToMessageID(rsp.GetLastMessageId())
+	req.markDeletePosition = convertToMessageID(rsp.GetConsumerMarkDeletePosition())
 }
 
-func (pc *partitionConsumer) requestGetLastMessageID() (*trackingMessageID, error) {
+func (pc *partitionConsumer) requestGetLastMessageID() (*pb.CommandGetLastMessageIdResponse, error) {
 	if state := pc.getConsumerState(); state == consumerClosed || state == consumerClosing {
 		pc.log.WithField("state", state).Error("Failed to getLastMessageID closing or closed consumer")
 		return nil, errors.New("failed to getLastMessageID closing or closed consumer")
@@ -667,8 +687,7 @@ func (pc *partitionConsumer) requestGetLastMessageID() (*trackingMessageID, erro
 		pc.log.WithError(err).Error("Failed to get last message id")
 		return nil, err
 	}
-	id := res.Response.GetLastMessageIdResponse.GetLastMessageId()
-	return convertToMessageID(id), nil
+	return res.Response.GetLastMessageIdResponse, nil
 }
 
 func (pc *partitionConsumer) sendIndividualAck(msgID MessageID) *ackRequest {
@@ -997,7 +1016,15 @@ func (pc *partitionConsumer) requestSeek(msgID *messageID) error {
 	if err := pc.requestSeekWithoutClear(msgID); err != nil {
 		return err
 	}
-	pc.clearReceiverQueue()
+	// When the seek operation is successful, it indicates:
+	// 1. The broker has reset the cursor and sent a request to close the consumer on the client side.
+	// Since this method is in the same goroutine as reconnectToBroker,
+	// we can safely clear the messages in the queue (at this point, it won't contain messages after the seek).
+	// 2. The startMessageID is reset to ensure an accurate judgment the next time hasNext is called.
+	// Since the messages in the queue are cleared here, reconnection won't reset startMessageID.
+	pc.lastDequeuedMsg = nil
+	pc.startMessageID.set(toTrackingMessageID(msgID))
+	pc.clearQueueAndGetNextMessage()
 	return nil
 }
 
@@ -1069,7 +1096,9 @@ func (pc *partitionConsumer) internalSeekByTime(seek *seekByTimeRequest) {
 		seek.err = err
 		return
 	}
-	pc.clearReceiverQueue()
+	pc.lastDequeuedMsg = nil
+	pc.hasSoughtByTime.Store(true)
+	pc.clearQueueAndGetNextMessage()
 }
 
 func (pc *partitionConsumer) internalAck(req *ackRequest) {
@@ -1451,10 +1480,6 @@ func (pc *partitionConsumer) messageShouldBeDiscarded(msgID *trackingMessageID)
 	if pc.startMessageID.get() == nil {
 		return false
 	}
-	// if we start at latest message, we should never discard
-	if pc.options.startMessageID != nil && pc.options.startMessageID.equal(latestMessageID) {
-		return false
-	}
 
 	if pc.options.startMessageIDInclusive {
 		return pc.startMessageID.get().greater(msgID.messageID)
@@ -1709,9 +1734,15 @@ type redeliveryRequest struct {
 }
 
 type getLastMsgIDRequest struct {
-	doneCh chan struct{}
-	msgID  *trackingMessageID
-	err    error
+	doneCh             chan struct{}
+	msgID              *trackingMessageID
+	markDeletePosition *trackingMessageID
+	err                error
+}
+
+type getLastMsgResult struct {
+	msgID              *trackingMessageID
+	markDeletePosition *trackingMessageID
 }
 
 type seekRequest struct {
@@ -2195,6 +2226,24 @@ func (pc *partitionConsumer) discardCorruptedMessage(msgID *pb.MessageIdData,
 }
 
 func (pc *partitionConsumer) hasNext() bool {
+
+	// If a seek by time has been performed, then the `startMessageId` becomes irrelevant.
+	// We need to compare `markDeletePosition` and `lastMessageId`,
+	// and then reset `startMessageID` to `markDeletePosition`.
+	if pc.hasSoughtByTime.CompareAndSwap(true, false) {
+		res, err := pc.getLastMessageIDAndMarkDeletePosition()
+		if err != nil {
+			pc.log.WithError(err).Error("Failed to get last message id")
+			return false
+		}
+		pc.lastMessageInBroker = res.msgID
+		pc.startMessageID.set(res.markDeletePosition)
+		// We only care about comparing ledger ids and entry ids as mark delete position
+		// doesn't have other ids such as batch index
+		compareResult := pc.lastMessageInBroker.messageID.compareLedgerAndEntryID(pc.startMessageID.get().messageID)
+		return compareResult > 0 || (pc.options.startMessageIDInclusive && compareResult == 0)
+	}
+
 	if pc.lastMessageInBroker != nil && pc.hasMoreMessages() {
 		return true
 	}
@@ -2256,12 +2305,14 @@ func convertToMessageID(id *pb.MessageIdData) *trackingMessageID {
 
 	msgID := &trackingMessageID{
 		messageID: &messageID{
-			ledgerID: int64(*id.LedgerId),
-			entryID:  int64(*id.EntryId),
+			ledgerID:  int64(id.GetLedgerId()),
+			entryID:   int64(id.GetEntryId()),
+			batchIdx:  -1,
+			batchSize: id.GetBatchSize(),
 		},
 	}
-	if id.BatchIndex != nil {
-		msgID.batchIdx = *id.BatchIndex
+	if id.GetBatchSize() > 1 {
+		msgID.batchIdx = id.GetBatchIndex()
 	}
 
 	return msgID
diff --git a/pulsar/consumer_test.go b/pulsar/consumer_test.go
index 33e4d057ee..8e5c2b84e2 100644
--- a/pulsar/consumer_test.go
+++ b/pulsar/consumer_test.go
@@ -1262,8 +1262,9 @@ func TestConsumerSeek(t *testing.T) {
 	defer producer.Close()
 
 	consumer, err := client.Subscribe(ConsumerOptions{
-		Topic:            topicName,
-		SubscriptionName: "sub-1",
+		Topic:                   topicName,
+		SubscriptionName:        "sub-1",
+		StartMessageIDInclusive: true,
 	})
 	assert.Nil(t, err)
 	defer consumer.Close()
diff --git a/pulsar/impl_message.go b/pulsar/impl_message.go
index 0acd782b80..566e5082f8 100644
--- a/pulsar/impl_message.go
+++ b/pulsar/impl_message.go
@@ -18,6 +18,7 @@
 package pulsar
 
 import (
+	"cmp"
 	"errors"
 	"fmt"
 	"math"
@@ -147,6 +148,13 @@ func (id *messageID) equal(other *messageID) bool {
 		id.batchIdx == other.batchIdx
 }
 
+func (id *messageID) compareLedgerAndEntryID(other *messageID) int {
+	if result := cmp.Compare(id.ledgerID, other.ledgerID); result != 0 {
+		return result
+	}
+	return cmp.Compare(id.entryID, other.entryID)
+}
+
 func (id *messageID) greaterEqual(other *messageID) bool {
 	return id.equal(other) || id.greater(other)
 }
@@ -204,6 +212,9 @@ func deserializeMessageID(data []byte) (MessageID, error) {
 }
 
 func newMessageID(ledgerID int64, entryID int64, batchIdx int32, partitionIdx int32, batchSize int32) MessageID {
+	if batchSize <= 1 {
+		batchIdx = -1
+	}
 	return &messageID{
 		ledgerID: ledgerID,
 		entryID:  entryID,
@@ -225,6 +236,9 @@ func fromMessageID(msgID MessageID) *messageID {
 }
 
 func newTrackingMessageID(ledgerID int64, entryID int64, batchIdx int32, partitionIdx int32,
 	batchSize int32, tracker *ackTracker) *trackingMessageID {
+	if batchSize <= 1 {
+		batchIdx = -1
+	}
 	return &trackingMessageID{
 		messageID: &messageID{
 			ledgerID: ledgerID,
diff --git a/pulsar/reader_impl.go b/pulsar/reader_impl.go
index f76255e2e8..55b05037f7 100644
--- a/pulsar/reader_impl.go
+++ b/pulsar/reader_impl.go
@@ -196,19 +196,6 @@ func (r *reader) Close() {
 	r.metrics.ReadersClosed.Inc()
 }
 
-func (r *reader) messageID(msgID MessageID) *trackingMessageID {
-	mid := toTrackingMessageID(msgID)
-
-	partition := int(mid.partitionIdx)
-	// did we receive a valid partition index?
-	if partition < 0 {
-		r.log.Warnf("invalid partition index %d expected", partition)
-		return nil
-	}
-
-	return mid
-}
-
 func (r *reader) Seek(msgID MessageID) error {
 	r.Lock()
 	defer r.Unlock()
@@ -218,9 +205,12 @@ func (r *reader) Seek(msgID MessageID) error {
 		return fmt.Errorf("invalid message id type %T", msgID)
 	}
 
-	mid := r.messageID(msgID)
-	if mid == nil {
-		return nil
+	mid := toTrackingMessageID(msgID)
+
+	partition := int(mid.partitionIdx)
+	if partition < 0 {
+		r.log.Warnf("invalid partition index %d", partition)
+		return fmt.Errorf("seek msgID must include partitionIndex")
 	}
 
 	return r.c.Seek(mid)
diff --git a/pulsar/reader_test.go b/pulsar/reader_test.go
index 836535704b..61871e6297 100644
--- a/pulsar/reader_test.go
+++ b/pulsar/reader_test.go
@@ -1070,3 +1070,159 @@ func TestReaderNextReturnsOnClosedConsumer(t *testing.T) {
 	assert.ErrorAs(t, err, &e)
 	assert.Equal(t, ConsumerClosed, e.Result())
 }
+
+func testReaderSeekByIDWithHasNext(t *testing.T, startMessageID MessageID, startMessageIDInclusive bool) {
+	client, err := NewClient(ClientOptions{
+		URL: lookupURL,
+	})
+
+	assert.Nil(t, err)
+	defer client.Close()
+
+	topic := newTopicName()
+	ctx := context.Background()
+
+	// create producer
+	producer, err := client.CreateProducer(ProducerOptions{
+		Topic:           topic,
+		DisableBatching: true,
+	})
+	assert.Nil(t, err)
+	defer producer.Close()
+
+	// send 10 messages
+	var lastMsgID MessageID
+	for i := 0; i < 10; i++ {
+		lastMsgID, err = producer.Send(ctx, &ProducerMessage{
+			Payload: []byte(fmt.Sprintf("hello-%d", i)),
+		})
+		fmt.Println(lastMsgID.String())
+		assert.NoError(t, err)
+		assert.NotNil(t, lastMsgID)
+	}
+
+	reader, err := client.CreateReader(ReaderOptions{
+		Topic:                   topic,
+		StartMessageID:          startMessageID,
+		StartMessageIDInclusive: startMessageIDInclusive,
+	})
+	assert.Nil(t, err)
+	defer reader.Close()
+
+	// Seek to the last message ID
+	err = reader.Seek(lastMsgID)
+	assert.NoError(t, err)
+
+	if startMessageIDInclusive {
+		assert.True(t, reader.HasNext())
+		ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
+		msg, err := reader.Next(ctx)
+		assert.NoError(t, err)
+		assert.NotNil(t, msg)
+		assert.True(t, messageIDCompare(lastMsgID, msg.ID()) == 0)
+		cancel()
+	} else {
+		assert.False(t, reader.HasNext())
+		ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
+		msg, err := reader.Next(ctx)
+		assert.Error(t, err)
+		assert.Nil(t, msg)
+		cancel()
+	}
+
+}
+
+func TestReaderWithSeekByID(t *testing.T) {
+	testReaderSeekByIDWithHasNext(t, EarliestMessageID(), false)
+	testReaderSeekByIDWithHasNext(t, EarliestMessageID(), true)
+	testReaderSeekByIDWithHasNext(t, LatestMessageID(), false)
+	testReaderSeekByIDWithHasNext(t, LatestMessageID(), true)
+}
+
+func testReaderSeekByTimeWithHasNext(t *testing.T, startMessageID MessageID) {
+	client, err := NewClient(ClientOptions{
+		URL: lookupURL,
+	})
+
+	assert.Nil(t, err)
+	defer client.Close()
+
+	topic := newTopicName()
+	ctx := context.Background()
+
+	// create producer
+	producer, err := client.CreateProducer(ProducerOptions{
+		Topic:           topic,
+		DisableBatching: true,
+	})
+	assert.Nil(t, err)
+	defer producer.Close()
+
+	// 1. send 10 messages
+	var lastMsgID MessageID
+	for i := 0; i < 10; i++ {
+		lastMsgID, err = producer.Send(ctx, &ProducerMessage{
+			Payload: []byte(fmt.Sprintf("hello-%d", i)),
+		})
+		fmt.Println(lastMsgID.String())
+		assert.NoError(t, err)
+
+		assert.NotNil(t, lastMsgID)
+	}
+
+	// 2. create reader
+	reader, err := client.CreateReader(ReaderOptions{
+		Topic:                   topic,
+		StartMessageID:          startMessageID,
+		StartMessageIDInclusive: false,
+	})
+	assert.Nil(t, err)
+	defer reader.Close()
+
+	// 3. Seek to the current time
+	assert.NoError(t, reader.SeekByTime(time.Now()))
+
+	// 4. Should not receive any message
+	{
+		assert.False(t, reader.HasNext())
+		timeoutCtx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
+		msg, err := reader.Next(timeoutCtx)
+		assert.Error(t, err)
+		assert.Nil(t, msg)
+		cancel()
+	}
+
+	// 5. send 10 more messages
+	for i := 0; i < 10; i++ {
+		lastMsgID, err = producer.Send(ctx, &ProducerMessage{
+			Payload: []byte(fmt.Sprintf("hello2-%d", i)),
+		})
+		fmt.Println(lastMsgID.String())
+		assert.NoError(t, err)
+		assert.NotNil(t, lastMsgID)
+	}
+
+	// 6. Assert these messages are received
+	for i := 0; i < 10; i++ {
+		assert.True(t, reader.HasNext())
+		msg, err := reader.Next(context.Background())
+		assert.NoError(t, err)
+		assert.Equal(t, fmt.Sprintf("hello2-%d", i), string(msg.Payload()))
+	}
+
+	// 7. Assert no more messages
+	{
+		assert.False(t, reader.HasNext())
+		timeoutCtx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
+		msg, err := reader.Next(timeoutCtx)
+		assert.Error(t, err)
+		assert.Nil(t, msg)
+		cancel()
+	}
+}
+func TestReaderWithSeekByTime(t *testing.T) {
+	testReaderSeekByTimeWithHasNext(t, EarliestMessageID())
+	testReaderSeekByTimeWithHasNext(t, EarliestMessageID())
+	testReaderSeekByTimeWithHasNext(t, LatestMessageID())
+	testReaderSeekByTimeWithHasNext(t, LatestMessageID())
+}
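
Note (not part of the patch): a minimal application-side sketch of the behaviour this change addresses, namely that HasNext should reflect the cursor position after SeekByTime rather than the pre-seek start message ID. The service URL and topic name below are placeholders.

package main

import (
	"context"
	"fmt"
	"log"
	"time"

	"github.com/apache/pulsar-client-go/pulsar"
)

func main() {
	// Placeholder broker URL; replace with your own cluster address.
	client, err := pulsar.NewClient(pulsar.ClientOptions{URL: "pulsar://localhost:6650"})
	if err != nil {
		log.Fatal(err)
	}
	defer client.Close()

	// Placeholder topic name.
	reader, err := client.CreateReader(pulsar.ReaderOptions{
		Topic:          "my-topic",
		StartMessageID: pulsar.EarliestMessageID(),
	})
	if err != nil {
		log.Fatal(err)
	}
	defer reader.Close()

	// Seek to the current time; with the fix, HasNext is evaluated against the
	// post-seek position instead of the original start message ID.
	if err := reader.SeekByTime(time.Now()); err != nil {
		log.Fatal(err)
	}

	// Drain whatever is published after the seek point.
	for reader.HasNext() {
		msg, err := reader.Next(context.Background())
		if err != nil {
			log.Fatal(err)
		}
		fmt.Printf("received %v: %q\n", msg.ID(), msg.Payload())
	}
}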