fix: seek race
nodece committed Sep 11, 2024
1 parent 98dc8d4 commit 4dec08d
Showing 1 changed file with 82 additions and 97 deletions.
179 changes: 82 additions & 97 deletions pulsar/consumer_partition.go
@@ -24,6 +24,7 @@ import (
"math"
"strings"
"sync"
"sync/atomic"
"time"

"google.golang.org/protobuf/proto"
@@ -152,7 +153,7 @@ type partitionConsumer struct {

// the size of the queue channel for buffering messages
maxQueueSize int32
queueCh chan []*message
queueCh chan *message
startMessageID atomicMessageID
lastDequeuedMsg *trackingMessageID

@@ -182,6 +183,8 @@ type partitionConsumer struct {
lastMessageInBroker *trackingMessageID

redirectedClusterURI string

seekCh atomic.Pointer[chan struct{}]
}

func (pc *partitionConsumer) ActiveConsumerChanged(isActive bool) {
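The new seekCh field is an atomic.Pointer[chan struct{}] that acts as a single-flight guard: Seek and SeekByTime publish their done channel with CompareAndSwap(nil, ...) further down, so a second seek started while one is still pending fails fast with errSeekInProgress instead of racing it. A minimal, self-contained sketch of that claim step (names here are illustrative, not the client's API):

// Illustrative sketch only, not part of the commit: how a CompareAndSwap on an
// atomic.Pointer rejects a second concurrent claim, as Seek/SeekByTime do below.
package main

import (
	"errors"
	"fmt"
	"sync/atomic"
)

var errSeekInProgress = errors.New("seek operation is already in progress")

// guard holds the pending seek's done channel, or nil when no seek is running.
var guard atomic.Pointer[chan struct{}]

func tryStartSeek() (chan struct{}, error) {
	done := make(chan struct{})
	if !guard.CompareAndSwap(nil, &done) {
		// Someone else already published a done channel; refuse to start.
		return nil, errSeekInProgress
	}
	return done, nil
}

func main() {
	if _, err := tryStartSeek(); err != nil {
		panic(err)
	}
	if _, err := tryStartSeek(); err != nil {
		fmt.Println(err) // second concurrent seek is rejected
	}
}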
@@ -328,7 +331,7 @@ func newPartitionConsumer(parent Consumer, client *client, options *partitionCon
partitionIdx: int32(options.partitionIdx),
eventsCh: make(chan interface{}, 10),
maxQueueSize: int32(options.receiverQueueSize),
queueCh: make(chan []*message, options.receiverQueueSize),
queueCh: make(chan *message, options.receiverQueueSize),
startMessageID: atomicMessageID{msgID: options.startMessageID},
connectedCh: make(chan struct{}),
messageCh: messageCh,
@@ -847,6 +850,8 @@ func (pc *partitionConsumer) Close() {
<-req.doneCh
}

var errSeekInProgress = errors.New("seek operation is already in progress")

func (pc *partitionConsumer) Seek(msgID MessageID) error {
if state := pc.getConsumerState(); state == consumerClosed || state == consumerClosing {
pc.log.WithField("state", state).Error("Failed to seek by closing or closed consumer")
@@ -861,6 +866,9 @@ func (pc *partitionConsumer) Seek(msgID MessageID) error {
req := &seekRequest{
doneCh: make(chan struct{}),
}
if !pc.seekCh.CompareAndSwap(nil, &req.doneCh) {
return errSeekInProgress
}
if cmid, ok := msgID.(*chunkMessageID); ok {
req.msgID = cmid.firstChunkID
} else {
@@ -877,7 +885,6 @@ }
}

func (pc *partitionConsumer) internalSeek(seek *seekRequest) {
defer close(seek.doneCh)
seek.err = pc.requestSeek(seek.msgID)
}
func (pc *partitionConsumer) requestSeek(msgID *messageID) error {
@@ -926,6 +933,9 @@ func (pc *partitionConsumer) SeekByTime(time time.Time) error {
doneCh: make(chan struct{}),
publishTime: time,
}
if !pc.seekCh.CompareAndSwap(nil, &req.doneCh) {
return errSeekInProgress
}
pc.ackGroupingTracker.flushAndClean()
pc.eventsCh <- req

@@ -935,8 +945,6 @@ }
}

func (pc *partitionConsumer) internalSeekByTime(seek *seekByTimeRequest) {
defer close(seek.doneCh)

state := pc.getConsumerState()
if state == consumerClosing || state == consumerClosed {
pc.log.WithField("state", pc.state).Error("Failed seekByTime by consumer is closing or has closed")
@@ -1051,37 +1059,33 @@ func (pc *partitionConsumer) MessageReceived(response *pb.CommandMessage, header
return fmt.Errorf("discarding message on decryption error :%v", err)
case crypto.ConsumerCryptoFailureActionConsume:
pc.log.Warnf("consuming encrypted message due to error in decryption :%v", err)
messages := []*message{
{
publishTime: timeFromUnixTimestampMillis(msgMeta.GetPublishTime()),
eventTime: timeFromUnixTimestampMillis(msgMeta.GetEventTime()),
key: msgMeta.GetPartitionKey(),
producerName: msgMeta.GetProducerName(),
properties: internal.ConvertToStringMap(msgMeta.GetProperties()),
topic: pc.topic,
msgID: newMessageID(
int64(pbMsgID.GetLedgerId()),
int64(pbMsgID.GetEntryId()),
pbMsgID.GetBatchIndex(),
pc.partitionIdx,
pbMsgID.GetBatchSize(),
),
payLoad: headersAndPayload.ReadableSlice(),
schema: pc.options.schema,
replicationClusters: msgMeta.GetReplicateTo(),
replicatedFrom: msgMeta.GetReplicatedFrom(),
redeliveryCount: response.GetRedeliveryCount(),
encryptionContext: createEncryptionContext(msgMeta),
orderingKey: string(msgMeta.OrderingKey),
},
}

if pc.options.autoReceiverQueueSize {
pc.incomingMessages.Inc()
pc.markScaleIfNeed()
}

pc.queueCh <- messages
pc.queueCh <- &message{
publishTime: timeFromUnixTimestampMillis(msgMeta.GetPublishTime()),
eventTime: timeFromUnixTimestampMillis(msgMeta.GetEventTime()),
key: msgMeta.GetPartitionKey(),
producerName: msgMeta.GetProducerName(),
properties: internal.ConvertToStringMap(msgMeta.GetProperties()),
topic: pc.topic,
msgID: newMessageID(
int64(pbMsgID.GetLedgerId()),
int64(pbMsgID.GetEntryId()),
pbMsgID.GetBatchIndex(),
pc.partitionIdx,
pbMsgID.GetBatchSize(),
),
payLoad: headersAndPayload.ReadableSlice(),
schema: pc.options.schema,
replicationClusters: msgMeta.GetReplicateTo(),
replicatedFrom: msgMeta.GetReplicatedFrom(),
redeliveryCount: response.GetRedeliveryCount(),
encryptionContext: createEncryptionContext(msgMeta),
orderingKey: string(msgMeta.OrderingKey),
}
return nil
}
}
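Changing queueCh from chan []*message to chan *message means the receive path above now enqueues each decoded message on its own, and the channel's buffer (sized from receiverQueueSize) does the prefetch bounding that the dispatcher's slice bookkeeping used to handle. A rough sketch of that producer/consumer shape, with the message type reduced to a payload (illustrative only):

// Illustrative sketch only, not part of the commit: a buffered per-message
// queue lets the consuming side forward items one at a time with no batch state.
package main

import "fmt"

type message struct {
	payLoad []byte
}

func main() {
	// Buffered per-message queue, sized like receiverQueueSize.
	queueCh := make(chan *message, 4)

	// Receive path: push messages individually as they are decoded.
	go func() {
		for i := 0; i < 4; i++ {
			queueCh <- &message{payLoad: []byte(fmt.Sprintf("msg-%d", i))}
		}
		close(queueCh)
	}()

	// Dispatcher side: every message read is handed off immediately.
	for msg := range queueCh {
		fmt.Println(string(msg.payLoad))
	}
}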
@@ -1255,7 +1259,7 @@ func (pc *partitionConsumer) MessageReceived(response *pb.CommandMessage, header
Message: msg,
})

messages = append(messages, msg)
pc.queueCh <- msg
bytesReceived += msg.size()
}

@@ -1269,8 +1273,6 @@ func (pc *partitionConsumer) MessageReceived(response *pb.CommandMessage, header
pc.availablePermits.add(skippedMessages)
}

// send messages to the dispatcher
pc.queueCh <- messages
return nil
}

@@ -1426,36 +1428,7 @@ func (pc *partitionConsumer) dispatcher() {
defer func() {
pc.log.Debug("exiting dispatch loop")
}()
var messages []*message
for {
var queueCh chan []*message
var messageCh chan ConsumerMessage
var nextMessage ConsumerMessage
var nextMessageSize int

// are there more messages to send?
if len(messages) > 0 {
nextMessage = ConsumerMessage{
Consumer: pc.parentConsumer,
Message: messages[0],
}
nextMessageSize = messages[0].size()

if pc.dlq.shouldSendToDlq(&nextMessage) {
// pass the message to the DLQ router
pc.metrics.DlqCounter.Inc()
messageCh = pc.dlq.Chan()
} else {
// pass the message to application channel
messageCh = pc.messageCh
}

pc.metrics.PrefetchedMessages.Dec()
pc.metrics.PrefetchedBytes.Sub(float64(len(messages[0].payLoad)))
} else {
queueCh = pc.queueCh
}

select {
case <-pc.closeCh:
return
@@ -1466,8 +1439,6 @@ }
}
pc.log.Debug("dispatcher received connection event")

messages = nil

// reset available permits
pc.availablePermits.reset()

@@ -1484,28 +1455,33 @@ }
pc.log.WithError(err).Error("unable to send initial permits to broker")
}

case msgs, ok := <-queueCh:
case msg, ok := <-pc.queueCh:
if !ok {
return
}
// we only read messages here after the consumer has processed all messages
// in the previous batch
messages = msgs

// if the messageCh is nil or the messageCh is full this will not be selected
case messageCh <- nextMessage:
// allow this message to be garbage collected
messages[0] = nil
messages = messages[1:]
consumerMessage := ConsumerMessage{
Consumer: pc.parentConsumer,
Message: msg,
}
if pc.dlq.shouldSendToDlq(&consumerMessage) {
// pass the message to the DLQ router
pc.metrics.DlqCounter.Inc()
pc.dlq.Chan() <- consumerMessage
} else {
// pass the message to application channel
pc.messageCh <- consumerMessage
}

pc.availablePermits.inc()

if pc.options.autoReceiverQueueSize {
pc.incomingMessages.Dec()
pc.client.memLimit.ReleaseMemory(int64(nextMessageSize))
pc.client.memLimit.ReleaseMemory(int64(msg.size()))
pc.expectMoreIncomingMessages()
}

pc.metrics.PrefetchedMessages.Dec()
pc.metrics.PrefetchedBytes.Sub(float64(len(msg.payLoad)))
case clearQueueCb := <-pc.clearQueueCh:
// drain the message queue on any new connection by sending a
// special nil message to the channel so we know when to stop dropping messages
@@ -1519,15 +1495,12 @@ }
if m == nil {
break
} else if nextMessageInQueue == nil {
nextMessageInQueue = toTrackingMessageID(m[0].msgID)
nextMessageInQueue = toTrackingMessageID(m.msgID)
}
if pc.options.autoReceiverQueueSize {
pc.incomingMessages.Sub(int32(len(m)))
pc.incomingMessages.Sub(int32(1))
}
}

messages = nil

clearQueueCb(nextMessageInQueue)
}
}
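The lines removed from dispatcher() relied on the nil-channel idiom: with no pending message the code pointed queueCh at the real channel and left messageCh nil, and with a pending batch it did the reverse, because a select case on a nil channel can never fire. The rewritten loop drops that juggling and simply does a blocking send to pc.dlq.Chan() or pc.messageCh inside the receive case. A tiny sketch of the nil-channel behaviour the old code depended on (illustrative only):

// Illustrative sketch only, not part of the commit: a send on a nil channel
// blocks forever, so its select case is effectively disabled.
package main

import (
	"fmt"
	"time"
)

func main() {
	var out chan string // nil: the send case below can never be chosen

	select {
	case out <- "hello":
		fmt.Println("sent") // unreachable while out is nil
	case <-time.After(100 * time.Millisecond):
		fmt.Println("nil-channel case was skipped, as expected")
	}
}

The trade-off in the new loop is a simpler invariant: a message taken from queueCh is always delivered (or routed to the DLQ) before the next one is read.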
@@ -1587,26 +1560,24 @@ type seekByTimeRequest struct {

func (pc *partitionConsumer) runEventsLoop() {
defer func() {
load := pc.seekCh.Load()
if load != nil {
*load <- struct{}{}
}
pc.log.Debug("exiting events loop")
}()
pc.log.Debug("get into runEventsLoop")

go func() {
for {
select {
case <-pc.closeCh:
pc.log.Info("close consumer, exit reconnect")
return
case connectionClosed := <-pc.connectClosedCh:
pc.log.Debug("runEventsLoop will reconnect")
pc.reconnectToBroker(connectionClosed)
}
}
}()

for {
for i := range pc.eventsCh {
switch v := i.(type) {
select {
case <-pc.closeCh:
pc.log.Info("close consumer, exit reconnect")
return
case connectionClosed := <-pc.connectClosedCh:
pc.log.Debug("runEventsLoop will reconnect")
pc.reconnectToBroker(connectionClosed)
case event := <-pc.eventsCh:
switch v := event.(type) {
case *ackRequest:
pc.internalAck(v)
case *ackWithTxnRequest:
@@ -1684,6 +1655,16 @@ func (pc *partitionConsumer) internalClose(req *closeRequest) {
}

func (pc *partitionConsumer) reconnectToBroker(connectionClosed *connectionClosed) {
cleanupSeekChFn := func() {
seekCh := pc.seekCh.Swap(nil)
if seekCh != nil {
*seekCh <- struct{}{}
}
}
defer func() {
cleanupSeekChFn()
}()

var maxRetry int

if pc.options.maxReconnectToBroker == nil {
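Taken together, these changes form a handshake around seekCh: Seek/SeekByTime claim it with CompareAndSwap before waiting on the request's done channel (the wait itself sits in the collapsed portion of the hunk), internalSeek now only sends the seek command (its defer close(seek.doneCh) was removed), and it is the broker-triggered reconnect, via cleanupSeekChFn above and the runEventsLoop defer, that swaps seekCh back to nil and wakes the waiter once the cursor has actually moved. A condensed, hypothetical sketch of that lifecycle (field and function names mirror the diff, everything else is simplified):

// Illustrative sketch only, not part of the commit: the seek caller waits
// until the reconnect path releases the guard and signals completion.
package main

import (
	"errors"
	"fmt"
	"sync/atomic"
	"time"
)

var errSeekInProgress = errors.New("seek operation is already in progress")

type consumer struct {
	seekCh atomic.Pointer[chan struct{}]
}

// Seek claims the guard, would send the seek command, then waits for the
// reconnect path to report that the subscription was re-established.
func (c *consumer) Seek() error {
	done := make(chan struct{}, 1)
	if !c.seekCh.CompareAndSwap(nil, &done) {
		return errSeekInProgress
	}
	// ... send the seek command to the broker here ...
	<-done // released by reconnectToBroker
	return nil
}

// reconnectToBroker stands in for the reconnect the broker forces after a
// successful seek; its deferred cleanup is what unblocks Seek.
func (c *consumer) reconnectToBroker() {
	defer func() {
		if ch := c.seekCh.Swap(nil); ch != nil {
			*ch <- struct{}{}
		}
	}()
	time.Sleep(50 * time.Millisecond) // pretend to re-subscribe
}

func main() {
	c := &consumer{}
	go func() {
		time.Sleep(10 * time.Millisecond)
		c.reconnectToBroker()
	}()
	if err := c.Seek(); err != nil {
		fmt.Println(err)
		return
	}
	fmt.Println("seek finished only after the reconnect completed")
}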
@@ -1811,7 +1792,11 @@ func (pc *partitionConsumer) grabConn(assignedBrokerURL string) error {
KeySharedMeta: keySharedMeta,
}

pc.startMessageID.set(pc.clearReceiverQueue())
queue := pc.clearReceiverQueue()
if queue != nil {
pc.log.Info("StartMessageId " + pc.startMessageID.get().String())
}
pc.startMessageID.set(queue)
if pc.options.subscriptionMode != Durable {
// For regular subscriptions the broker will determine the restarting point
cmdSubscribe.StartMessageId = convertToMessageIDData(pc.startMessageID.get())
