Skip to content

Commit

Permalink
Fetch multiple messages (#163)
Browse files Browse the repository at this point in the history
* add MaxNumberOfMessages parameter

* add savedMessages cache

* fix concurrency bug

* add backoff retry loop and comments

* explain why receiveMessageCalled

* fix failing test

* DRY up the test a bit

* remove locking mutex

* apply pr suggestion

---------

Co-authored-by: Guillem <guillemus@proton.me>
  • Loading branch information
alarbada and Guillem authored Sep 17, 2024
1 parent 151eeff commit 55f6971
Show file tree
Hide file tree
Showing 5 changed files with 155 additions and 37 deletions.
19 changes: 10 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,16 @@ Queue.

The configuration passed to `Configure` can contain the following fields:

| name | description | required | default value |
| ----------------------- | ------------------------------------------------------------------------------------------------------------ | -------- | ------------- |
| `aws.accessKeyId` | AWS Access Key ID | yes | |
| `aws.secretAccessKey` | AWS Secret Access Key | yes | |
| `aws.region` | AWS SQS Region | yes | |
| `aws.queue` | AWS SQS Queue Name | yes | |
| `aws.visibilityTimeout` | The duration (in seconds) that the received messages are hidden from subsequent reads after being retrieved. | no | 0 |
| `aws.waitTimeSeconds` | the duration (in seconds) for which the call waits for a message to arrive in the queue before returning. | no | 10 |
| `aws.url` | URL for AWS (internal use only) | no | |
| name | description | required | default value |
| ------------------------- | ------------------------------------------------------------------------------------------------------------ | -------- | ------------- |
| `aws.accessKeyId` | AWS Access Key ID | yes | |
| `aws.secretAccessKey` | AWS Secret Access Key | yes | |
| `aws.region` | AWS SQS Region | yes | |
| `aws.queue` | AWS SQS Queue Name | yes | |
| `aws.visibilityTimeout` | The duration (in seconds) that the received messages are hidden from subsequent reads after being retrieved. | no | 0 |
| `aws.waitTimeSeconds` | The duration (in seconds) for which the call waits for a message to arrive in the queue before returning. | no | 10 |
| `aws.maxNumberOfMessages` | The maximum number of messages (between 1 and 10) to fetch from SQS in a single batch. | no | 1 |
| `aws.url` | URL for AWS (internal use only) | no | |

## Destination

Expand Down
3 changes: 3 additions & 0 deletions source/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,7 @@ type Config struct {
// WaitTimeSeconds is the duration (in seconds) for which the call waits for
// a message to arrive in the queue before returning.
WaitTimeSeconds int32 `json:"aws.waitTimeSeconds" default:"10"`

// MaxNumberOfMessages is the maximum number of messages to fetch from SQS in a single batch.
MaxNumberOfMessages int32 `json:"aws.maxNumberOfMessages" default:"1" validate:"gt=0,lt=11"`
}
57 changes: 44 additions & 13 deletions source/source.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,31 @@ import (

// Source is the connector source that reads messages from an Amazon SQS queue.
type Source struct {
sdk.UnimplementedSource
config Config
svc *sqs.Client
queueURL string
config Config
svc *sqs.Client
queueURL string
// savedMessages holds the messages of the most recently fetched batch that
// have not been returned yet; receiveMessage hands these out one per call
// before it asks SQS for a new batch.
savedMessages []types.Message

// receiveMessageCalled, when non-nil, is invoked after each call to the
// `ReceiveMessage` method of the SQS client that returned without error.
// This is useful in tests only; in non-test environments it is expected to
// stay nil, so it has no side effects there.
receiveMessageCalled func()

// httpClient allows us to clean up leftover http connections. Useful to not
// leak goroutines when tearing down the connector
httpClient *http.Client
}

// newSource builds a bare Source that only has its HTTP client set and is not
// wrapped in any SDK middleware. Integration tests use it to get at the
// concrete *Source.
func newSource() *Source {
	src := new(Source)
	src.httpClient = new(http.Client)
	return src
}

func NewSource() sdk.Source {
return sdk.SourceWithMiddleware(
&Source{
httpClient: &http.Client{},
},
newSource(),
sdk.DefaultSourceMiddleware(
// disable schema extraction by default, because the source produces raw data
sdk.SourceWithSchemaExtractionConfig{
Expand Down Expand Up @@ -107,30 +118,50 @@ func (s *Source) Open(ctx context.Context, sdkPos opencdc.Position) (err error)
return nil
}

func (s *Source) Read(ctx context.Context) (rec opencdc.Record, err error) {
// receiveMessage returns the next SQS message to process.
//
// Messages left over from an earlier batch (s.savedMessages) are returned
// first, one per call. Only when none remain does it call ReceiveMessage on the
// SQS client, requesting up to s.config.MaxNumberOfMessages messages; the first
// message of that batch is returned and the remainder is saved for later calls.
//
// It returns sdk.ErrBackoffRetry when the call yielded no messages, and a
// wrapped error when the ReceiveMessage call itself fails.
func (s *Source) receiveMessage(ctx context.Context) (msg types.Message, err error) {
// serve from the saved batch first so no request is made while messages remain
if len(s.savedMessages) >= 1 {
first := s.savedMessages[0]
s.savedMessages = s.savedMessages[1:]
return first, nil
}

receiveMessage := &sqs.ReceiveMessageInput{
MessageAttributeNames: []string{
string(types.QueueAttributeNameAll),
},
QueueUrl: &s.queueURL,
MaxNumberOfMessages: 1,
MaxNumberOfMessages: s.config.MaxNumberOfMessages,
VisibilityTimeout: s.config.VisibilityTimeout,
WaitTimeSeconds: s.config.WaitTimeSeconds,
}

// fetch a batch of messages from the queue
sqsMessages, err := s.svc.ReceiveMessage(ctx, receiveMessage)
if err != nil {
return rec, fmt.Errorf("error retrieving amazon sqs messages: %w", err)
return msg, fmt.Errorf("error retrieving amazon sqs messages: %w", err)
}
// optional hook; only reached when the call above returned without error
if s.receiveMessageCalled != nil {
s.receiveMessageCalled()
}

// if there are no messages in queue, backoff
if len(sqsMessages.Messages) == 0 {
sdk.Logger(ctx).Warn().Msg("got 0 messages from queue")
return rec, sdk.ErrBackoffRetry
return msg, sdk.ErrBackoffRetry
}

msg = sqsMessages.Messages[0]
// a single-message batch leaves nothing to save
if len(sqsMessages.Messages) == 1 {
return msg, nil
}

message := sqsMessages.Messages[0]
// keep the rest of the batch for the following calls
s.savedMessages = sqsMessages.Messages[1:]
return msg, nil
}

func (s *Source) Read(ctx context.Context) (rec opencdc.Record, err error) {
message, err := s.receiveMessage(ctx)
if err != nil {
return rec, err
}

mt := opencdc.Metadata{}
for key, value := range message.MessageAttributes {
Expand Down
89 changes: 81 additions & 8 deletions source/source_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,33 @@
package source

import (
"context"
"fmt"
"sort"
"strconv"
"strings"
"testing"

"github.com/aws/aws-sdk-go-v2/service/sqs"
"github.com/conduitio-labs/conduit-connector-sqs/common"
testutils "github.com/conduitio-labs/conduit-connector-sqs/test"
"github.com/conduitio/conduit-commons/opencdc"
sdk "github.com/conduitio/conduit-connector-sdk"
"github.com/matryer/is"
)

// sendMessage publishes a single message with body msg to the queue at
// queueURL using client, and fails the test if the send returns an error.
func sendMessage(ctx context.Context, is *is.I, client *sqs.Client, queueURL, msg *string) {
	is.Helper()

	input := &sqs.SendMessageInput{
		MessageBody: msg,
		QueueUrl:    queueURL,
	}
	_, err := client.SendMessage(ctx, input)
	is.NoErr(err)
}

func TestSource_SuccessfulMessageReceive(t *testing.T) {
is := is.New(t)
ctx := testutils.TestContext(t)
Expand All @@ -38,14 +55,7 @@ func TestSource_SuccessfulMessageReceive(t *testing.T) {
defer cleanSource()

messageBody := "Test message body"
_, err := testClient.SendMessage(
ctx,
&sqs.SendMessageInput{
MessageBody: &messageBody,
QueueUrl: testQueue.URL,
},
)
is.NoErr(err)
sendMessage(ctx, is, testClient, testQueue.URL, &messageBody)

record, err := source.Read(ctx)
is.NoErr(err)
Expand Down Expand Up @@ -101,3 +111,66 @@ func TestSource_OpenWithPosition(t *testing.T) {
))
}
}

// TestMultipleMessageFetch sends totalMessages messages to a fresh test queue,
// reads all of them back through a source whose ConfigAwsMaxNumberOfMessages
// setting is maxNumberOfMessages, and then checks both the payloads and how
// many times the receiveMessageCalled hook fired.
func TestMultipleMessageFetch(t *testing.T) {
is := is.New(t)
ctx := testutils.TestContext(t)

testClient, cleanTestClient := testutils.NewSQSClient(ctx, is)
defer cleanTestClient()

testQueue := testutils.CreateTestQueue(ctx, t, is, testClient)

totalMessages := 20
maxNumberOfMessages := 5

// enqueue the messages and remember their bodies in send order
expectedMessages := make([]string, totalMessages)
for i := range totalMessages {
msg := fmt.Sprintf("message %d", i)
sendMessage(ctx, is, testClient, testQueue.URL, &msg)
expectedMessages[i] = msg
}

// newSource returns the concrete *Source, which lets the test set the
// receiveMessageCalled hook and count the calls
source := newSource()
var receiveMessageCalls int

source.receiveMessageCalled = func() {
receiveMessageCalls++
}

cfg := testutils.SourceConfig(testQueue.Name)
cfg[ConfigAwsVisibilityTimeout] = "10"
cfg[ConfigAwsMaxNumberOfMessages] = fmt.Sprint(maxNumberOfMessages)

is.NoErr(source.Configure(ctx, cfg))
is.NoErr(source.Open(ctx, nil))
defer func() { is.NoErr(source.Teardown(ctx)) }()

// read and ack the records one at a time until all messages are consumed
recs := make([]opencdc.Record, totalMessages)
for i := range totalMessages {
rec, err := source.Read(ctx)
is.NoErr(err)
is.NoErr(source.Ack(ctx, rec.Position))
recs[i] = rec
}

// records might come unsorted
// order them by the number that follows the "message " prefix of the payload;
// Atoi errors are deliberately ignored here
sort.Slice(recs, func(i, j int) bool {
prevInt, _ := strconv.Atoi(string(recs[i].Payload.After.Bytes())[len("message "):])
nextInt, _ := strconv.Atoi(string(recs[j].Payload.After.Bytes())[len("message "):])
return prevInt < nextInt
})

// assert record contents
for i := range recs {
expected := expectedMessages[i]
actual := string(recs[i].Payload.After.Bytes())

is.Equal(expected, actual)
}

// the hook count equals totalMessages/maxNumberOfMessages only when every
// counted fetch returned a full batch of maxNumberOfMessages messages
is.Equal(
totalMessages/maxNumberOfMessages,
receiveMessageCalls,
) // expected receive calls != actual receive calls made
}
24 changes: 17 additions & 7 deletions source/source_paramgen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 55f6971

Please sign in to comment.