Skip to content

Commit

Permalink
Merge pull request #4 from patrobinson/race-condition-get-lease
Browse files Browse the repository at this point in the history
Expose race conditions in GetLease
  • Loading branch information
Patrick Robinson authored Feb 17, 2019
2 parents fb5ee54 + 6efb6de commit 1a357e1
Show file tree
Hide file tree
Showing 7 changed files with 158 additions and 58 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@ before_install:
- docker pull deangiberson/aws-dynamodb-local
- docker pull dlsniper/kinesalite
install: make get
script: make travis-integration
script: make docker-integration
6 changes: 2 additions & 4 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@ integration: get
@go test -timeout 30s -tags=integration

docker-integration:
@docker-compose run --rm gokini make integration

travis-integration:
@docker-compose up -d
@sleep 10
@go test -timeout 30s -tags=integration
@docker-compose run gokini make integration
@docker-compose down
82 changes: 39 additions & 43 deletions checkpointer.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ import (

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/dynamodb"
"github.com/aws/aws-sdk-go/service/dynamodb/dynamodbiface"
"github.com/matryer/try"
log "github.com/sirupsen/logrus"
)

Expand All @@ -35,17 +35,19 @@ var ErrSequenceIDNotFound = errors.New("SequenceIDNotFoundForShard")

// DynamoCheckpoint implements the Checkpoint interface using DynamoDB as a backend
type DynamoCheckpoint struct {
TableName string
LeaseDuration int
svc dynamodbiface.DynamoDBAPI
Retries int
TableName string
LeaseDuration int
Retries int
svc dynamodbiface.DynamoDBAPI
skipTableCheck bool
}

// Init initialises the DynamoDB Checkpoint
func (checkpointer *DynamoCheckpoint) Init() error {
log.Debug("Creating DynamoDB session")
session, err := session.NewSessionWithOptions(
session.Options{
Config: aws.Config{Retryer: client.DefaultRetryer{NumMaxRetries: checkpointer.Retries}},
SharedConfigState: session.SharedConfigEnable,
},
)
Expand All @@ -54,7 +56,8 @@ func (checkpointer *DynamoCheckpoint) Init() error {
}

if endpoint := os.Getenv("DYNAMODB_ENDPOINT"); endpoint != "" {
session.Config.Endpoint = aws.String(endpoint)
log.Infof("Using dynamodb endpoint from environment %s", endpoint)
session.Config.Endpoint = &endpoint
}

checkpointer.svc = dynamodb.New(session)
Expand All @@ -63,7 +66,7 @@ func (checkpointer *DynamoCheckpoint) Init() error {
checkpointer.LeaseDuration = defaultLeaseDuration
}

if !checkpointer.doesTableExist() {
if !checkpointer.skipTableCheck && !checkpointer.doesTableExist() {
return checkpointer.createTable()
}
return nil
Expand All @@ -84,6 +87,14 @@ func (checkpointer *DynamoCheckpoint) GetLease(shard *shardStatus, newAssignTo s
var expressionAttributeValues map[string]*dynamodb.AttributeValue
if !leaseTimeoutOk || !assignedToOk {
conditionalExpression = "attribute_not_exists(AssignedTo)"
if shard.Checkpoint != "" {
conditionalExpression = conditionalExpression + " AND SequenceID = :id"
expressionAttributeValues = map[string]*dynamodb.AttributeValue{
":id": {
S: &shard.Checkpoint,
},
}
}
} else {
assignedTo := *assignedVar.S
leaseTimeout := *leaseVar.S
Expand All @@ -108,6 +119,12 @@ func (checkpointer *DynamoCheckpoint) GetLease(shard *shardStatus, newAssignTo s
S: &leaseTimeout,
},
}
if shard.Checkpoint != "" {
conditionalExpression = conditionalExpression + " AND SequenceID = :sid"
expressionAttributeValues[":sid"] = &dynamodb.AttributeValue{
S: &shard.Checkpoint,
}
}
}

marshalledCheckpoint := map[string]*dynamodb.AttributeValue{
Expand All @@ -122,6 +139,10 @@ func (checkpointer *DynamoCheckpoint) GetLease(shard *shardStatus, newAssignTo s
},
}

if shard.Checkpoint != "" {
marshalledCheckpoint["SequenceID"] = &dynamodb.AttributeValue{S: &shard.Checkpoint}
}

if shard.Checkpoint != "" {
marshalledCheckpoint["Checkpoint"] = &dynamodb.AttributeValue{
S: &shard.Checkpoint,
Expand Down Expand Up @@ -229,51 +250,26 @@ func (checkpointer *DynamoCheckpoint) saveItem(item map[string]*dynamodb.Attribu

func (checkpointer *DynamoCheckpoint) conditionalUpdate(conditionExpression string, expressionAttributeValues map[string]*dynamodb.AttributeValue, item map[string]*dynamodb.AttributeValue) error {
return checkpointer.putItem(&dynamodb.PutItemInput{
ConditionExpression: aws.String(conditionExpression),
TableName: aws.String(checkpointer.TableName),
Item: item,
ConditionExpression: aws.String(conditionExpression),
TableName: aws.String(checkpointer.TableName),
Item: item,
ExpressionAttributeValues: expressionAttributeValues,
})
}

func (checkpointer *DynamoCheckpoint) putItem(input *dynamodb.PutItemInput) error {
return try.Do(func(attempt int) (bool, error) {
_, err := checkpointer.svc.PutItem(input)
if awsErr, ok := err.(awserr.Error); ok {
if awsErr.Code() == dynamodb.ErrCodeProvisionedThroughputExceededException ||
awsErr.Code() == dynamodb.ErrCodeInternalServerError &&
attempt < checkpointer.Retries {
// Backoff time as recommended by https://docs.aws.amazon.com/general/latest/gr/api-retries.html
time.Sleep(time.Duration(2^attempt*100) * time.Millisecond)
return true, err
}
}
return false, err
})
_, err := checkpointer.svc.PutItem(input)
return err
}

func (checkpointer *DynamoCheckpoint) getItem(shardID string) (map[string]*dynamodb.AttributeValue, error) {
var item *dynamodb.GetItemOutput
err := try.Do(func(attempt int) (bool, error) {
var err error
item, err = checkpointer.svc.GetItem(&dynamodb.GetItemInput{
TableName: aws.String(checkpointer.TableName),
Key: map[string]*dynamodb.AttributeValue{
"ShardID": {
S: aws.String(shardID),
},
item, err := checkpointer.svc.GetItem(&dynamodb.GetItemInput{
TableName: aws.String(checkpointer.TableName),
Key: map[string]*dynamodb.AttributeValue{
"ShardID": {
S: aws.String(shardID),
},
})
if awsErr, ok := err.(awserr.Error); ok {
if awsErr.Code() == dynamodb.ErrCodeProvisionedThroughputExceededException ||
awsErr.Code() == dynamodb.ErrCodeInternalServerError &&
attempt < checkpointer.Retries {
// Backoff time as recommended by https://docs.aws.amazon.com/general/latest/gr/api-retries.html
time.Sleep(time.Duration(2^attempt*100) * time.Millisecond)
return true, err
}
}
return false, err
},
})
return item.Item, err
}
83 changes: 83 additions & 0 deletions checkpointer_integration_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
//+build integration

package gokini

import (
"sync"
"testing"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/dynamodb"
)

func TestRaceCondGetLeaseTimeout(t *testing.T) {
checkpoint := &DynamoCheckpoint{
TableName: "TableName",
}
checkpoint.Init()
marshalledCheckpoint := map[string]*dynamodb.AttributeValue{
"ShardID": {
S: aws.String("0001"),
},
"AssignedTo": {
S: aws.String("abcd-efgh"),
},
"LeaseTimeout": {
S: aws.String(time.Now().AddDate(0, -1, 0).UTC().Format(time.RFC3339)),
},
"SequenceID": {
S: aws.String("deadbeef"),
},
}
input := &dynamodb.PutItemInput{
TableName: aws.String("TableName"),
Item: marshalledCheckpoint,
}
_, err := checkpoint.svc.PutItem(input)
if err != nil {
t.Fatalf("Error writing to dynamo %s", err)
}
shard := &shardStatus{
ID: "0001",
Checkpoint: "TestRaceCondGetLeaseTimeout",
mux: &sync.Mutex{},
}
err = checkpoint.GetLease(shard, "ijkl-mnop")

if err == nil || err.Error() != ErrLeaseNotAquired {
t.Error("Got a lease when checkpoints didn't match. Potentially we stomped on the checkpoint")
}
}
func TestRaceCondGetLeaseNoAssignee(t *testing.T) {
checkpoint := &DynamoCheckpoint{
TableName: "TableName",
}
checkpoint.Init()
marshalledCheckpoint := map[string]*dynamodb.AttributeValue{
"ShardID": {
S: aws.String("0001"),
},
"SequenceID": {
S: aws.String("deadbeef"),
},
}
input := &dynamodb.PutItemInput{
TableName: aws.String("TableName"),
Item: marshalledCheckpoint,
}
_, err := checkpoint.svc.PutItem(input)
if err != nil {
t.Fatalf("Error writing to dynamo %s", err)
}
shard := &shardStatus{
ID: "0001",
Checkpoint: "TestRaceCondGetLeaseNoAssignee",
mux: &sync.Mutex{},
}
err = checkpoint.GetLease(shard, "ijkl-mnop")

if err == nil || err.Error() != ErrLeaseNotAquired {
t.Error("Got a lease when checkpoints didn't match. Potentially we stomped on the checkpoint")
}
}
29 changes: 21 additions & 8 deletions checkpointer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ func TestDoesTableExist(t *testing.T) {
func TestGetLeaseNotAquired(t *testing.T) {
svc := &mockDynamoDB{tableExist: true}
checkpoint := &DynamoCheckpoint{
TableName: "TableName",
svc: svc,
TableName: "TableName",
skipTableCheck: true,
}
checkpoint.Init()
checkpoint.svc = svc
Expand All @@ -82,8 +82,10 @@ func TestGetLeaseNotAquired(t *testing.T) {
func TestGetLeaseAquired(t *testing.T) {
svc := &mockDynamoDB{tableExist: true}
checkpoint := &DynamoCheckpoint{
TableName: "TableName",
TableName: "TableName",
skipTableCheck: true,
}
checkpoint.svc = svc
checkpoint.Init()
checkpoint.svc = svc
marshalledCheckpoint := map[string]*dynamodb.AttributeValue{
Expand All @@ -96,28 +98,39 @@ func TestGetLeaseAquired(t *testing.T) {
"LeaseTimeout": {
S: aws.String(time.Now().AddDate(0, -1, 0).UTC().Format(time.RFC3339)),
},
"SequenceID": {
S: aws.String("deadbeef"),
},
}
input := &dynamodb.PutItemInput{
TableName: aws.String("TableName"),
Item: marshalledCheckpoint,
}
checkpoint.svc.PutItem(input)
err := checkpoint.GetLease(&shardStatus{
shard := &shardStatus{
ID: "0001",
Checkpoint: "",
Checkpoint: "deadbeef",
mux: &sync.Mutex{},
}, "ijkl-mnop")
}
err := checkpoint.GetLease(shard, "ijkl-mnop")

if err != nil {
t.Errorf("Lease not aquired after timeout %s", err)
}

id, ok := svc.item["SequenceID"]
if !ok {
t.Error("Expected SequenceID to be set by GetLease")
} else if *id.S != "deadbeef" {
t.Errorf("Expected SequenceID to be deadbeef. Got '%s'", *id.S)
}
}

func TestGetLeaseRenewed(t *testing.T) {
svc := &mockDynamoDB{tableExist: true}
checkpoint := &DynamoCheckpoint{
TableName: "TableName",
svc: svc,
TableName: "TableName",
skipTableCheck: true,
}
checkpoint.Init()
checkpoint.svc = svc
Expand Down
12 changes: 11 additions & 1 deletion consumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/kinesis"
"github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
Expand Down Expand Up @@ -55,6 +56,7 @@ type KinesisConsumer struct {
EmptyRecordBackoffMs int
LeaseDuration int
Monitoring MonitoringConfiguration
Retries *int
svc kinesisiface.KinesisAPI
checkpointer Checkpointer
stop *chan struct{}
Expand All @@ -65,6 +67,8 @@ type KinesisConsumer struct {
mService monitoringService
}

var defaultRetries = 5

// StartConsumer starts the RecordConsumer, calls Init and starts sending records to ProcessRecords
func (kc *KinesisConsumer) StartConsumer() error {
// Set Defaults
Expand All @@ -81,9 +85,15 @@ func (kc *KinesisConsumer) StartConsumer() error {
kc.mService = kc.Monitoring.service

if kc.svc == nil && kc.checkpointer == nil {
retries := defaultRetries
if kc.Retries != nil {
retries = *kc.Retries
}

log.Debugf("Creating Kinesis Session")
session, err := session.NewSessionWithOptions(
session.Options{
Config: aws.Config{Retryer: client.DefaultRetryer{NumMaxRetries: retries}},
SharedConfigState: session.SharedConfigEnable,
},
)
Expand All @@ -96,7 +106,7 @@ func (kc *KinesisConsumer) StartConsumer() error {
kc.svc = kinesis.New(session)
kc.checkpointer = &DynamoCheckpoint{
TableName: kc.TableName,
Retries: 5,
Retries: retries,
LeaseDuration: kc.LeaseDuration,
}
}
Expand Down
2 changes: 1 addition & 1 deletion docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ services:
expose:
- 4567
dynamodb:
image: deangiberson/aws-dynamodb-local
image: amazon/dynamodb-local
ports:
- 8000:8000
expose:
Expand Down

0 comments on commit 1a357e1

Please sign in to comment.