Skip to content

Commit

Permalink
Support acknowledging a list of message IDs (#1301)
Browse files Browse the repository at this point in the history
  • Loading branch information
BewareMyPower authored Nov 6, 2024
1 parent 875f6ba commit 35076ac
Show file tree
Hide file tree
Showing 11 changed files with 486 additions and 28 deletions.
60 changes: 36 additions & 24 deletions pulsar/ack_grouping_tracker.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func newAckGroupingTracker(options *AckGroupingOptions,
maxNumAcks: int(options.MaxSize),
ackCumulative: ackCumulative,
ackList: ackList,
pendingAcks: make(map[[2]uint64]*bitset.BitSet),
pendingAcks: make(map[position]*bitset.BitSet),
lastCumulativeAck: EarliestMessageID(),
}

Expand Down Expand Up @@ -110,6 +110,15 @@ func (i *immediateAckGroupingTracker) flushAndClean() {
func (i *immediateAckGroupingTracker) close() {
}

type position struct {
ledgerID uint64
entryID uint64
}

func newPosition(msgID MessageID) position {
return position{ledgerID: uint64(msgID.LedgerID()), entryID: uint64(msgID.EntryID())}
}

type timedAckGroupingTracker struct {
sync.RWMutex

Expand All @@ -124,7 +133,7 @@ type timedAckGroupingTracker struct {
// in the batch whose batch size is 3 are not acknowledged.
// After the 1st message (i.e. batch index is 0) is acknowledged, the bits will become "011".
// Value is nil if the entry represents a single message.
pendingAcks map[[2]uint64]*bitset.BitSet
pendingAcks map[position]*bitset.BitSet

lastCumulativeAck MessageID
cumulativeAckRequired int32
Expand All @@ -138,35 +147,36 @@ func (t *timedAckGroupingTracker) add(id MessageID) {
}
}

func (t *timedAckGroupingTracker) tryAddIndividual(id MessageID) map[[2]uint64]*bitset.BitSet {
t.Lock()
defer t.Unlock()
key := [2]uint64{uint64(id.LedgerID()), uint64(id.EntryID())}

func addMsgIDToPendingAcks(pendingAcks map[position]*bitset.BitSet, id MessageID) {
key := newPosition(id)
batchIdx := id.BatchIdx()
batchSize := id.BatchSize()

if batchIdx >= 0 && batchSize > 0 {
bs, found := t.pendingAcks[key]
bs, found := pendingAcks[key]
if !found {
if batchSize > 1 {
bs = bitset.New(uint(batchSize))
for i := uint(0); i < uint(batchSize); i++ {
bs.Set(i)
}
bs = bitset.New(uint(batchSize))
for i := uint(0); i < uint(batchSize); i++ {
bs.Set(i)
}
t.pendingAcks[key] = bs
pendingAcks[key] = bs
}
if bs != nil {
bs.Clear(uint(batchIdx))
}
} else {
t.pendingAcks[key] = nil
pendingAcks[key] = nil
}
}

func (t *timedAckGroupingTracker) tryAddIndividual(id MessageID) map[position]*bitset.BitSet {
t.Lock()
defer t.Unlock()

addMsgIDToPendingAcks(t.pendingAcks, id)
if len(t.pendingAcks) >= t.maxNumAcks {
pendingAcks := t.pendingAcks
t.pendingAcks = make(map[[2]uint64]*bitset.BitSet)
t.pendingAcks = make(map[position]*bitset.BitSet)
return pendingAcks
}
return nil
Expand Down Expand Up @@ -195,7 +205,7 @@ func (t *timedAckGroupingTracker) isDuplicate(id MessageID) bool {
if messageIDCompare(t.lastCumulativeAck, id) >= 0 {
return true
}
key := [2]uint64{uint64(id.LedgerID()), uint64(id.EntryID())}
key := newPosition(id)
if bs, found := t.pendingAcks[key]; found {
if bs == nil {
return true
Expand Down Expand Up @@ -232,11 +242,11 @@ func (t *timedAckGroupingTracker) flushAndClean() {
}
}

func (t *timedAckGroupingTracker) clearPendingAcks() map[[2]uint64]*bitset.BitSet {
func (t *timedAckGroupingTracker) clearPendingAcks() map[position]*bitset.BitSet {
t.Lock()
defer t.Unlock()
pendingAcks := t.pendingAcks
t.pendingAcks = make(map[[2]uint64]*bitset.BitSet)
t.pendingAcks = make(map[position]*bitset.BitSet)
return pendingAcks
}

Expand All @@ -250,12 +260,10 @@ func (t *timedAckGroupingTracker) close() {
}
}

func (t *timedAckGroupingTracker) flushIndividual(pendingAcks map[[2]uint64]*bitset.BitSet) {
func toMsgIDDataList(pendingAcks map[position]*bitset.BitSet) []*pb.MessageIdData {
msgIDs := make([]*pb.MessageIdData, 0, len(pendingAcks))
for k, v := range pendingAcks {
ledgerID := k[0]
entryID := k[1]
msgID := &pb.MessageIdData{LedgerId: &ledgerID, EntryId: &entryID}
msgID := &pb.MessageIdData{LedgerId: &k.ledgerID, EntryId: &k.entryID}
if v != nil && !v.None() {
bytes := v.Bytes()
msgID.AckSet = make([]int64, len(bytes))
Expand All @@ -265,5 +273,9 @@ func (t *timedAckGroupingTracker) flushIndividual(pendingAcks map[[2]uint64]*bit
}
msgIDs = append(msgIDs, msgID)
}
t.ackList(msgIDs)
return msgIDs
}

func (t *timedAckGroupingTracker) flushIndividual(pendingAcks map[position]*bitset.BitSet) {
t.ackList(toMsgIDDataList(pendingAcks))
}
31 changes: 31 additions & 0 deletions pulsar/consumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ package pulsar

import (
"context"
"fmt"
"strings"
"time"

"github.com/apache/pulsar-client-go/pulsar/backoff"
Expand Down Expand Up @@ -266,6 +268,23 @@ type ConsumerOptions struct {
startMessageID *trackingMessageID
}

// This error is returned when `AckIDList` failed and `AckWithResponse` is true.
// It only contains the valid message IDs that failed to be acknowledged in the `AckIDList` call.
// For those invalid message IDs, users should ignore them and not acknowledge them again.
type AckError map[MessageID]error

func (e AckError) Error() string {
builder := strings.Builder{}
errorMap := make(map[string][]MessageID)
for id, err := range e {
errorMap[err.Error()] = append(errorMap[err.Error()], id)
}
for err, msgIDs := range errorMap {
builder.WriteString(fmt.Sprintf("error: %s, failed message IDs: %v\n", err, msgIDs))
}
return builder.String()
}

// Consumer is an interface that abstracts behavior of Pulsar's consumer
type Consumer interface {
// Subscription get a subscription for the consumer
Expand Down Expand Up @@ -305,8 +324,20 @@ type Consumer interface {
Ack(Message) error

// AckID the consumption of a single message, identified by its MessageID
// When `EnableBatchIndexAcknowledgment` is false, if a message ID represents a message in the batch,
// it will not be actually acknowledged by broker until all messages in that batch are acknowledged via
// `AckID` or `AckIDList`.
AckID(MessageID) error

// AckIDList the consumption of a list of messages, identified by their MessageIDs
//
// This method should be used when `AckWithResponse` is true. Otherwise, it will be equivalent with calling
// `AckID` on each message ID in the list.
//
// When `AckWithResponse` is true, the returned error could be an `AckError` which contains the failed message ID
// and the corresponding error.
AckIDList([]MessageID) error

// AckWithTxn the consumption of a single message with a transaction
AckWithTxn(Message, Transaction) error

Expand Down
10 changes: 10 additions & 0 deletions pulsar/consumer_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ const defaultNackRedeliveryDelay = 1 * time.Minute
type acker interface {
// AckID does not handle errors returned by the Broker side, so no need to wait for doneCh to finish.
AckID(id MessageID) error
AckIDList(msgIDs []MessageID) error
AckIDWithResponse(id MessageID) error
AckIDWithTxn(msgID MessageID, txn Transaction) error
AckIDCumulative(msgID MessageID) error
Expand Down Expand Up @@ -559,6 +560,15 @@ func (c *consumer) AckID(msgID MessageID) error {
return c.consumers[msgID.PartitionIdx()].AckID(msgID)
}

func (c *consumer) AckIDList(msgIDs []MessageID) error {
return ackIDListFromMultiTopics(c.log, msgIDs, func(msgID MessageID) (acker, error) {
if err := c.checkMsgIDPartition(msgID); err != nil {
return nil, err
}
return c.consumers[msgID.PartitionIdx()], nil
})
}

// AckCumulative the reception of all the messages in the stream up to (and including)
// the provided message, identified by its MessageID
func (c *consumer) AckCumulative(msg Message) error {
Expand Down
43 changes: 43 additions & 0 deletions pulsar/consumer_multitopic.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,49 @@ func (c *multiTopicConsumer) AckID(msgID MessageID) error {
return mid.consumer.AckID(msgID)
}

func (c *multiTopicConsumer) AckIDList(msgIDs []MessageID) error {
return ackIDListFromMultiTopics(c.log, msgIDs, func(msgID MessageID) (acker, error) {
if !checkMessageIDType(msgID) {
return nil, fmt.Errorf("invalid message id type %T", msgID)
}
if mid := toTrackingMessageID(msgID); mid != nil && mid.consumer != nil {
return mid.consumer, nil
}
return nil, errors.New("consumer is nil")
})
}

func ackIDListFromMultiTopics(log log.Logger, msgIDs []MessageID, findConsumer func(MessageID) (acker, error)) error {
consumerToMsgIDs := make(map[acker][]MessageID)
for _, msgID := range msgIDs {
if consumer, err := findConsumer(msgID); err == nil {
consumerToMsgIDs[consumer] = append(consumerToMsgIDs[consumer], msgID)
} else {
log.Warnf("Can not find consumer for %v", msgID)
}
}

ackError := AckError{}
for consumer, ids := range consumerToMsgIDs {
if err := consumer.AckIDList(ids); err != nil {
if topicAckError := err.(AckError); topicAckError != nil {
for id, err := range topicAckError {
ackError[id] = err
}
} else {
// It should not reach here
for _, id := range ids {
ackError[id] = err
}
}
}
}
if len(ackError) == 0 {
return nil
}
return ackError
}

// AckWithTxn the consumption of a single message with a transaction
func (c *multiTopicConsumer) AckWithTxn(msg Message, txn Transaction) error {
msgID := msg.ID()
Expand Down
99 changes: 99 additions & 0 deletions pulsar/consumer_multitopic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"fmt"
"strings"
"testing"
"time"

"github.com/apache/pulsar-client-go/pulsaradmin"
"github.com/apache/pulsar-client-go/pulsaradmin/pkg/admin/config"
Expand Down Expand Up @@ -218,3 +219,101 @@ func TestMultiTopicGetLastMessageIDs(t *testing.T) {
}

}

func TestMultiTopicAckIDList(t *testing.T) {
for _, params := range []bool{true, false} {
t.Run(fmt.Sprintf("TestMultiTopicConsumerAckIDList%v", params), func(t *testing.T) {
runMultiTopicAckIDList(t, params)
})
}
}

func runMultiTopicAckIDList(t *testing.T, regex bool) {
topicPrefix := fmt.Sprintf("multiTopicAckIDList%v", time.Now().UnixNano())
topic1 := "persistent://public/default/" + topicPrefix + "1"
topic2 := "persistent://public/default/" + topicPrefix + "2"

client, err := NewClient(ClientOptions{URL: "pulsar://localhost:6650"})
assert.Nil(t, err)
defer client.Close()

if regex {
admin, err := pulsaradmin.NewClient(&config.Config{})
assert.Nil(t, err)
for _, topic := range []string{topic1, topic2} {
topicName, err := utils.GetTopicName(topic)
assert.Nil(t, err)
admin.Topics().Create(*topicName, 0)
}
}

createConsumer := func() Consumer {
options := ConsumerOptions{
SubscriptionName: "sub",
Type: Shared,
AckWithResponse: true,
}
if regex {
options.TopicsPattern = topicPrefix + ".*"
} else {
options.Topics = []string{topic1, topic2}
}
consumer, err := client.Subscribe(options)
assert.Nil(t, err)
return consumer
}
consumer := createConsumer()

sendMessages(t, client, topic1, 0, 3, false)
sendMessages(t, client, topic2, 0, 2, false)

receiveMessageMap := func(consumer Consumer, numMessages int) map[string][]Message {
msgs := receiveMessages(t, consumer, numMessages)
topicToMsgs := make(map[string][]Message)
for _, msg := range msgs {
topicToMsgs[msg.Topic()] = append(topicToMsgs[msg.Topic()], msg)
}
return topicToMsgs
}

topicToMsgs := receiveMessageMap(consumer, 5)
assert.Equal(t, 3, len(topicToMsgs[topic1]))
for i := 0; i < 3; i++ {
assert.Equal(t, fmt.Sprintf("msg-%d", i), string(topicToMsgs[topic1][i].Payload()))
}
assert.Equal(t, 2, len(topicToMsgs[topic2]))
for i := 0; i < 2; i++ {
assert.Equal(t, fmt.Sprintf("msg-%d", i), string(topicToMsgs[topic2][i].Payload()))
}

assert.Nil(t, consumer.AckIDList([]MessageID{
topicToMsgs[topic1][0].ID(),
topicToMsgs[topic1][2].ID(),
topicToMsgs[topic2][1].ID(),
}))

consumer.Close()
consumer = createConsumer()
topicToMsgs = receiveMessageMap(consumer, 2)
assert.Equal(t, 1, len(topicToMsgs[topic1]))
assert.Equal(t, "msg-1", string(topicToMsgs[topic1][0].Payload()))
assert.Equal(t, 1, len(topicToMsgs[topic2]))
assert.Equal(t, "msg-0", string(topicToMsgs[topic2][0].Payload()))
consumer.Close()

msgID0 := topicToMsgs[topic1][0].ID()
err = consumer.AckIDList([]MessageID{msgID0})
assert.NotNil(t, err)
t.Logf("AckIDList error: %v", err)

msgID1 := topicToMsgs[topic2][0].ID()
if ackError, ok := consumer.AckIDList([]MessageID{msgID0, msgID1}).(AckError); ok {
assert.Equal(t, 2, len(ackError))
assert.Contains(t, ackError, msgID0)
assert.Equal(t, "consumer state is closed", ackError[msgID0].Error())
assert.Contains(t, ackError, msgID1)
assert.Equal(t, "consumer state is closed", ackError[msgID1].Error())
} else {
assert.Fail(t, "AckIDList should return AckError")
}
}
Loading

0 comments on commit 35076ac

Please sign in to comment.