diff --git a/.travis.yml b/.travis.yml index 549ebd8..c8c0064 100644 --- a/.travis.yml +++ b/.travis.yml @@ -11,4 +11,4 @@ before_install: - docker pull deangiberson/aws-dynamodb-local - docker pull dlsniper/kinesalite install: make get -script: make travis-integration +script: make docker-integration diff --git a/Makefile b/Makefile index 1dd0f91..50eac2b 100644 --- a/Makefile +++ b/Makefile @@ -10,9 +10,7 @@ integration: get @go test -timeout 30s -tags=integration docker-integration: - @docker-compose run --rm gokini make integration - -travis-integration: @docker-compose up -d @sleep 10 - @go test -timeout 30s -tags=integration + @docker-compose run gokini make integration + @docker-compose down diff --git a/checkpointer.go b/checkpointer.go index afd12cf..507701d 100644 --- a/checkpointer.go +++ b/checkpointer.go @@ -7,10 +7,10 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/aws/client" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/dynamodb" "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbiface" - "github.com/matryer/try" log "github.com/sirupsen/logrus" ) @@ -35,10 +35,11 @@ var ErrSequenceIDNotFound = errors.New("SequenceIDNotFoundForShard") // DynamoCheckpoint implements the Checkpoint interface using DynamoDB as a backend type DynamoCheckpoint struct { - TableName string - LeaseDuration int - svc dynamodbiface.DynamoDBAPI - Retries int + TableName string + LeaseDuration int + Retries int + svc dynamodbiface.DynamoDBAPI + skipTableCheck bool } // Init initialises the DynamoDB Checkpoint @@ -46,6 +47,7 @@ func (checkpointer *DynamoCheckpoint) Init() error { log.Debug("Creating DynamoDB session") session, err := session.NewSessionWithOptions( session.Options{ + Config: aws.Config{Retryer: client.DefaultRetryer{NumMaxRetries: checkpointer.Retries}}, SharedConfigState: session.SharedConfigEnable, }, ) @@ -54,7 +56,8 @@ func (checkpointer *DynamoCheckpoint) Init() error { } if endpoint := os.Getenv("DYNAMODB_ENDPOINT"); endpoint != "" { - session.Config.Endpoint = aws.String(endpoint) + log.Infof("Using dynamodb endpoint from environment %s", endpoint) + session.Config.Endpoint = &endpoint } checkpointer.svc = dynamodb.New(session) @@ -63,7 +66,7 @@ func (checkpointer *DynamoCheckpoint) Init() error { checkpointer.LeaseDuration = defaultLeaseDuration } - if !checkpointer.doesTableExist() { + if !checkpointer.skipTableCheck && !checkpointer.doesTableExist() { return checkpointer.createTable() } return nil @@ -84,6 +87,14 @@ func (checkpointer *DynamoCheckpoint) GetLease(shard *shardStatus, newAssignTo s var expressionAttributeValues map[string]*dynamodb.AttributeValue if !leaseTimeoutOk || !assignedToOk { conditionalExpression = "attribute_not_exists(AssignedTo)" + if shard.Checkpoint != "" { + conditionalExpression = conditionalExpression + " AND SequenceID = :id" + expressionAttributeValues = map[string]*dynamodb.AttributeValue{ + ":id": { + S: &shard.Checkpoint, + }, + } + } } else { assignedTo := *assignedVar.S leaseTimeout := *leaseVar.S @@ -108,6 +119,12 @@ func (checkpointer *DynamoCheckpoint) GetLease(shard *shardStatus, newAssignTo s S: &leaseTimeout, }, } + if shard.Checkpoint != "" { + conditionalExpression = conditionalExpression + " AND SequenceID = :sid" + expressionAttributeValues[":sid"] = &dynamodb.AttributeValue{ + S: &shard.Checkpoint, + } + } } marshalledCheckpoint := map[string]*dynamodb.AttributeValue{ @@ -122,6 +139,10 @@ func (checkpointer *DynamoCheckpoint) GetLease(shard *shardStatus, newAssignTo s }, } + if shard.Checkpoint != "" { + marshalledCheckpoint["SequenceID"] = &dynamodb.AttributeValue{S: &shard.Checkpoint} + } + if shard.Checkpoint != "" { marshalledCheckpoint["Checkpoint"] = &dynamodb.AttributeValue{ S: &shard.Checkpoint, @@ -229,51 +250,26 @@ func (checkpointer *DynamoCheckpoint) saveItem(item map[string]*dynamodb.Attribu func (checkpointer *DynamoCheckpoint) conditionalUpdate(conditionExpression string, expressionAttributeValues map[string]*dynamodb.AttributeValue, item map[string]*dynamodb.AttributeValue) error { return checkpointer.putItem(&dynamodb.PutItemInput{ - ConditionExpression: aws.String(conditionExpression), - TableName: aws.String(checkpointer.TableName), - Item: item, + ConditionExpression: aws.String(conditionExpression), + TableName: aws.String(checkpointer.TableName), + Item: item, ExpressionAttributeValues: expressionAttributeValues, }) } func (checkpointer *DynamoCheckpoint) putItem(input *dynamodb.PutItemInput) error { - return try.Do(func(attempt int) (bool, error) { - _, err := checkpointer.svc.PutItem(input) - if awsErr, ok := err.(awserr.Error); ok { - if awsErr.Code() == dynamodb.ErrCodeProvisionedThroughputExceededException || - awsErr.Code() == dynamodb.ErrCodeInternalServerError && - attempt < checkpointer.Retries { - // Backoff time as recommended by https://docs.aws.amazon.com/general/latest/gr/api-retries.html - time.Sleep(time.Duration(2^attempt*100) * time.Millisecond) - return true, err - } - } - return false, err - }) + _, err := checkpointer.svc.PutItem(input) + return err } func (checkpointer *DynamoCheckpoint) getItem(shardID string) (map[string]*dynamodb.AttributeValue, error) { - var item *dynamodb.GetItemOutput - err := try.Do(func(attempt int) (bool, error) { - var err error - item, err = checkpointer.svc.GetItem(&dynamodb.GetItemInput{ - TableName: aws.String(checkpointer.TableName), - Key: map[string]*dynamodb.AttributeValue{ - "ShardID": { - S: aws.String(shardID), - }, + item, err := checkpointer.svc.GetItem(&dynamodb.GetItemInput{ + TableName: aws.String(checkpointer.TableName), + Key: map[string]*dynamodb.AttributeValue{ + "ShardID": { + S: aws.String(shardID), }, - }) - if awsErr, ok := err.(awserr.Error); ok { - if awsErr.Code() == dynamodb.ErrCodeProvisionedThroughputExceededException || - awsErr.Code() == dynamodb.ErrCodeInternalServerError && - attempt < checkpointer.Retries { - // Backoff time as recommended by https://docs.aws.amazon.com/general/latest/gr/api-retries.html - time.Sleep(time.Duration(2^attempt*100) * time.Millisecond) - return true, err - } - } - return false, err + }, }) return item.Item, err } diff --git a/checkpointer_integration_test.go b/checkpointer_integration_test.go new file mode 100644 index 0000000..7f2a29d --- /dev/null +++ b/checkpointer_integration_test.go @@ -0,0 +1,83 @@ +//+build integration + +package gokini + +import ( + "sync" + "testing" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/dynamodb" +) + +func TestRaceCondGetLeaseTimeout(t *testing.T) { + checkpoint := &DynamoCheckpoint{ + TableName: "TableName", + } + checkpoint.Init() + marshalledCheckpoint := map[string]*dynamodb.AttributeValue{ + "ShardID": { + S: aws.String("0001"), + }, + "AssignedTo": { + S: aws.String("abcd-efgh"), + }, + "LeaseTimeout": { + S: aws.String(time.Now().AddDate(0, -1, 0).UTC().Format(time.RFC3339)), + }, + "SequenceID": { + S: aws.String("deadbeef"), + }, + } + input := &dynamodb.PutItemInput{ + TableName: aws.String("TableName"), + Item: marshalledCheckpoint, + } + _, err := checkpoint.svc.PutItem(input) + if err != nil { + t.Fatalf("Error writing to dynamo %s", err) + } + shard := &shardStatus{ + ID: "0001", + Checkpoint: "TestRaceCondGetLeaseTimeout", + mux: &sync.Mutex{}, + } + err = checkpoint.GetLease(shard, "ijkl-mnop") + + if err == nil || err.Error() != ErrLeaseNotAquired { + t.Error("Got a lease when checkpoints didn't match. Potentially we stomped on the checkpoint") + } +} +func TestRaceCondGetLeaseNoAssignee(t *testing.T) { + checkpoint := &DynamoCheckpoint{ + TableName: "TableName", + } + checkpoint.Init() + marshalledCheckpoint := map[string]*dynamodb.AttributeValue{ + "ShardID": { + S: aws.String("0001"), + }, + "SequenceID": { + S: aws.String("deadbeef"), + }, + } + input := &dynamodb.PutItemInput{ + TableName: aws.String("TableName"), + Item: marshalledCheckpoint, + } + _, err := checkpoint.svc.PutItem(input) + if err != nil { + t.Fatalf("Error writing to dynamo %s", err) + } + shard := &shardStatus{ + ID: "0001", + Checkpoint: "TestRaceCondGetLeaseNoAssignee", + mux: &sync.Mutex{}, + } + err = checkpoint.GetLease(shard, "ijkl-mnop") + + if err == nil || err.Error() != ErrLeaseNotAquired { + t.Error("Got a lease when checkpoints didn't match. Potentially we stomped on the checkpoint") + } +} diff --git a/checkpointer_test.go b/checkpointer_test.go index 0b84416..f5aa946 100644 --- a/checkpointer_test.go +++ b/checkpointer_test.go @@ -55,8 +55,8 @@ func TestDoesTableExist(t *testing.T) { func TestGetLeaseNotAquired(t *testing.T) { svc := &mockDynamoDB{tableExist: true} checkpoint := &DynamoCheckpoint{ - TableName: "TableName", - svc: svc, + TableName: "TableName", + skipTableCheck: true, } checkpoint.Init() checkpoint.svc = svc @@ -82,8 +82,10 @@ func TestGetLeaseNotAquired(t *testing.T) { func TestGetLeaseAquired(t *testing.T) { svc := &mockDynamoDB{tableExist: true} checkpoint := &DynamoCheckpoint{ - TableName: "TableName", + TableName: "TableName", + skipTableCheck: true, } + checkpoint.svc = svc checkpoint.Init() checkpoint.svc = svc marshalledCheckpoint := map[string]*dynamodb.AttributeValue{ @@ -96,28 +98,39 @@ func TestGetLeaseAquired(t *testing.T) { "LeaseTimeout": { S: aws.String(time.Now().AddDate(0, -1, 0).UTC().Format(time.RFC3339)), }, + "SequenceID": { + S: aws.String("deadbeef"), + }, } input := &dynamodb.PutItemInput{ TableName: aws.String("TableName"), Item: marshalledCheckpoint, } checkpoint.svc.PutItem(input) - err := checkpoint.GetLease(&shardStatus{ + shard := &shardStatus{ ID: "0001", - Checkpoint: "", + Checkpoint: "deadbeef", mux: &sync.Mutex{}, - }, "ijkl-mnop") + } + err := checkpoint.GetLease(shard, "ijkl-mnop") if err != nil { t.Errorf("Lease not aquired after timeout %s", err) } + + id, ok := svc.item["SequenceID"] + if !ok { + t.Error("Expected SequenceID to be set by GetLease") + } else if *id.S != "deadbeef" { + t.Errorf("Expected SequenceID to be deadbeef. Got '%s'", *id.S) + } } func TestGetLeaseRenewed(t *testing.T) { svc := &mockDynamoDB{tableExist: true} checkpoint := &DynamoCheckpoint{ - TableName: "TableName", - svc: svc, + TableName: "TableName", + skipTableCheck: true, } checkpoint.Init() checkpoint.svc = svc diff --git a/consumer.go b/consumer.go index ad70fb5..f687231 100644 --- a/consumer.go +++ b/consumer.go @@ -10,6 +10,7 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/aws/client" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/kinesis" "github.com/aws/aws-sdk-go/service/kinesis/kinesisiface" @@ -55,6 +56,7 @@ type KinesisConsumer struct { EmptyRecordBackoffMs int LeaseDuration int Monitoring MonitoringConfiguration + Retries *int svc kinesisiface.KinesisAPI checkpointer Checkpointer stop *chan struct{} @@ -65,6 +67,8 @@ type KinesisConsumer struct { mService monitoringService } +var defaultRetries = 5 + // StartConsumer starts the RecordConsumer, calls Init and starts sending records to ProcessRecords func (kc *KinesisConsumer) StartConsumer() error { // Set Defaults @@ -81,9 +85,15 @@ func (kc *KinesisConsumer) StartConsumer() error { kc.mService = kc.Monitoring.service if kc.svc == nil && kc.checkpointer == nil { + retries := defaultRetries + if kc.Retries != nil { + retries = *kc.Retries + } + log.Debugf("Creating Kinesis Session") session, err := session.NewSessionWithOptions( session.Options{ + Config: aws.Config{Retryer: client.DefaultRetryer{NumMaxRetries: retries}}, SharedConfigState: session.SharedConfigEnable, }, ) @@ -96,7 +106,7 @@ func (kc *KinesisConsumer) StartConsumer() error { kc.svc = kinesis.New(session) kc.checkpointer = &DynamoCheckpoint{ TableName: kc.TableName, - Retries: 5, + Retries: retries, LeaseDuration: kc.LeaseDuration, } } diff --git a/docker-compose.yml b/docker-compose.yml index f7d3a1f..a35e188 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -7,7 +7,7 @@ services: expose: - 4567 dynamodb: - image: deangiberson/aws-dynamodb-local + image: amazon/dynamodb-local ports: - 8000:8000 expose: