diff --git a/consumer.go b/consumer.go index 763b0e25..80206dc7 100644 --- a/consumer.go +++ b/consumer.go @@ -78,6 +78,7 @@ type Consumer struct { store Store scanInterval time.Duration maxRecords int64 + retentionPeriod int64 } // ScanFunc is the type of the function called for each message read @@ -104,6 +105,23 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error { shardc = make(chan *kinesis.Shard, 1) ) + // Discover the retention period on the stream, which is used later in ScanShard() + // for failure mode handling. + summary, err := c.client.DescribeStreamSummaryWithContext( + aws.Context(ctx), + &kinesis.DescribeStreamSummaryInput{ + StreamName: aws.String(c.streamName), + }, + ) + if err != nil { + if err.(awserr.Error).Code() != "AccessDeniedException" { + return err + } + c.logger.Log("[CONSUMER] IAM entity lacks kinesis:DescribeStreamSummary permissions, skipping extra sanity checks") + } else { + c.retentionPeriod = *summary.StreamDescriptionSummary.RetentionPeriodHours * 3600 * 1000 + } + go func() { c.group.Start(ctx, shardc) <-ctx.Done() @@ -146,7 +164,7 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e } // get shard iterator - shardIterator, err := c.getShardIterator(ctx, c.streamName, shardID, lastSeqNum) + shardIterator, err := c.getShardIterator(ctx, c.streamName, shardID, lastSeqNum, false) if err != nil { return fmt.Errorf("get shard iterator error: %v", err) } @@ -155,10 +173,19 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e defer func() { c.logger.Log("[CONSUMER] stop scan:", shardID) }() + + // True if we should detect expired sequence number and fetch iterator for oldest + // record if necessary. + checkExpiredSeq := true + // User requested polling interval scanTicker := time.NewTicker(c.scanInterval) defer scanTicker.Stop() + // Limit before throttling based on the published Kinesis API limit of 5 TPS per shard + fastTicker := time.NewTicker(200 * time.Millisecond) + defer fastTicker.Stop() for { + nextTicker := scanTicker resp, err := c.client.GetRecords(&kinesis.GetRecordsInput{ Limit: aws.Int64(c.maxRecords), ShardIterator: shardIterator, @@ -174,7 +201,7 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e } } - shardIterator, err = c.getShardIterator(ctx, c.streamName, shardID, lastSeqNum) + shardIterator, err = c.getShardIterator(ctx, c.streamName, shardID, lastSeqNum, false) if err != nil { return fmt.Errorf("get shard iterator error: %v", err) } @@ -207,13 +234,28 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e } shardIterator = resp.NextShardIterator + if len(resp.Records) == 0 && c.retentionPeriod > 0 && *resp.MillisBehindLatest >= c.retentionPeriod { + // No records were returned and we are behind at least the retention + // period of the stream which means the last sequence number refers to + // expired data. If we haven't done so already, fetch an iterator for the + // shard's trim horizon. + if checkExpiredSeq { + shardIterator, err = c.getShardIterator(ctx, c.streamName, shardID, "", true) + if err != nil { + return fmt.Errorf("get shard iterator error: %v", err) + } + // Only fetch the trim horizon once. + checkExpiredSeq = false + } + nextTicker = fastTicker + } } // Wait for next scan select { case <-ctx.Done(): return nil - case <-scanTicker.C: + case <-nextTicker.C: continue } } @@ -228,7 +270,7 @@ func isShardClosed(nextShardIterator, currentShardIterator *string) bool { return nextShardIterator == nil || currentShardIterator == nextShardIterator } -func (c *Consumer) getShardIterator(ctx context.Context, streamName, shardID, seqNum string) (*string, error) { +func (c *Consumer) getShardIterator(ctx context.Context, streamName, shardID, seqNum string, oldest bool) (*string, error) { params := &kinesis.GetShardIteratorInput{ ShardId: aws.String(shardID), StreamName: aws.String(streamName), @@ -237,6 +279,8 @@ func (c *Consumer) getShardIterator(ctx context.Context, streamName, shardID, se if seqNum != "" { params.ShardIteratorType = aws.String(kinesis.ShardIteratorTypeAfterSequenceNumber) params.StartingSequenceNumber = aws.String(seqNum) + } else if oldest { + params.ShardIteratorType = aws.String(kinesis.ShardIteratorTypeTrimHorizon) } else if c.initialTimestamp != nil { params.ShardIteratorType = aws.String(kinesis.ShardIteratorTypeAtTimestamp) params.Timestamp = c.initialTimestamp