Skip to content

Commit

Permalink
Add term check on the internal streams, to avoid wrong closures of le…
Browse files Browse the repository at this point in the history
…ader controller (#513)

During a leader election, an old leader can force back the new leader
into a follower mode.

This does not affect correctness (e.g., the old leader's requests are
then ignored because the term is invalid), though it causes the shard to
be unavailable until a new leader election is triggered.

The simplified sequence is like:
 1. `oxia-1` is the old leader in term: 4
 2. Coordinator initiates a new leader election and fences all the nodes
3. `oxia-0` and `oxia-2` are fenced correctly, while the failure to fence `oxia-1` is ignored,
and will be retried in the background
 4. `oxia-2` is elected leader (term: 5)
5. `oxia-1` is still running the leader controller (because it never got
successfully fenced), and it keeps trying to replicate to `oxia-0` and
`oxia-2`.


The replicate requests keep failing (because the term has already
changed to 5), though the first request has the effect of putting oxia-2
back in follower mode.

We are now in a state where the coordinator is happy, believing `oxia-2` to be
the leader, even though `oxia-2` went back to follower mode. A new leader
election is not triggered because `oxia-2` looks healthy overall (the
health check is done at the pod level, not at shard level).
  • Loading branch information
merlimat authored Sep 14, 2024
1 parent 2d1aa2a commit 1934d55
Show file tree
Hide file tree
Showing 9 changed files with 85 additions and 18 deletions.
1 change: 1 addition & 0 deletions common/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package common
import "time"

const (
MetadataTerm = "term"
MetadataNamespace = "namespace"
MetadataShardId = "shard-id"
DefaultNamespace = "default"
Expand Down
4 changes: 2 additions & 2 deletions maelstrom/replication_rpc_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func (r *maelstromReplicationRpcProvider) Close() error {
return nil
}

func (r *maelstromReplicationRpcProvider) GetReplicateStream(ctx context.Context, follower string, namespace string, shard int64) (
func (r *maelstromReplicationRpcProvider) GetReplicateStream(ctx context.Context, follower string, namespace string, shard int64, term int64) (
proto.OxiaLogReplication_ReplicateClient, error) {
s := &maelstromReplicateClient{
ctx: ctx,
Expand Down Expand Up @@ -83,7 +83,7 @@ func (r *maelstromReplicationRpcProvider) Truncate(follower string, req *proto.T
return res.(*proto.TruncateResponse), nil
}

func (r *maelstromReplicationRpcProvider) SendSnapshot(ctx context.Context, follower string, namespace string, shard int64) (proto.OxiaLogReplication_SendSnapshotClient, error) {
func (r *maelstromReplicationRpcProvider) SendSnapshot(ctx context.Context, follower string, namespace string, shard int64, term int64) (proto.OxiaLogReplication_SendSnapshotClient, error) {
panic("not implemented")
}

Expand Down
8 changes: 4 additions & 4 deletions server/follower_cursor.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ import (
// This is a provider for the ReplicateStream Grpc handler
// It's used to allow passing in a mocked version of the Grpc service.
type ReplicateStreamProvider interface {
GetReplicateStream(ctx context.Context, follower string, namespace string, shard int64) (proto.OxiaLogReplication_ReplicateClient, error)
SendSnapshot(ctx context.Context, follower string, namespace string, shard int64) (proto.OxiaLogReplication_SendSnapshotClient, error)
GetReplicateStream(ctx context.Context, follower string, namespace string, shard int64, term int64) (proto.OxiaLogReplication_ReplicateClient, error)
SendSnapshot(ctx context.Context, follower string, namespace string, shard int64, term int64) (proto.OxiaLogReplication_SendSnapshotClient, error)
}

// FollowerCursor
Expand Down Expand Up @@ -251,7 +251,7 @@ func (fc *followerCursor) sendSnapshot() error {
ctx, cancel := context.WithCancel(fc.ctx)
defer cancel()

stream, err := fc.replicateStreamProvider.SendSnapshot(ctx, fc.follower, fc.namespace, fc.shardId)
stream, err := fc.replicateStreamProvider.SendSnapshot(ctx, fc.follower, fc.namespace, fc.shardId, fc.term)
if err != nil {
return err
}
Expand Down Expand Up @@ -367,7 +367,7 @@ func (fc *followerCursor) streamEntries() error {

fc.Lock()
var err error
if fc.stream, err = fc.replicateStreamProvider.GetReplicateStream(ctx, fc.follower, fc.namespace, fc.shardId); err != nil {
if fc.stream, err = fc.replicateStreamProvider.GetReplicateStream(ctx, fc.follower, fc.namespace, fc.shardId, fc.term); err != nil {
fc.Unlock()
return err
}
Expand Down
28 changes: 25 additions & 3 deletions server/internal_rpc_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ func (s *internalRpcServer) Truncate(c context.Context, req *proto.TruncateReque

log.Info("Received Truncate request")

follower, err := s.shardsDirector.GetOrCreateFollower(req.Namespace, req.Shard)
follower, err := s.shardsDirector.GetOrCreateFollower(req.Namespace, req.Shard, req.Term)
if err != nil {
log.Warn(
"Truncate failed: could not get follower controller",
Expand Down Expand Up @@ -252,6 +252,11 @@ func (s *internalRpcServer) Replicate(srv proto.OxiaLogReplication_ReplicateServ
return err
}

term, err := readTerm(md)
if err != nil {
return err
}

log := s.log.With(
slog.Int64("shard", shardId),
slog.String("namespace", namespace),
Expand All @@ -260,7 +265,7 @@ func (s *internalRpcServer) Replicate(srv proto.OxiaLogReplication_ReplicateServ

log.Info("Received Replicate request")

follower, err := s.shardsDirector.GetOrCreateFollower(namespace, shardId)
follower, err := s.shardsDirector.GetOrCreateFollower(namespace, shardId, term)
if err != nil {
log.Warn(
"Replicate failed: could not get follower controller",
Expand Down Expand Up @@ -297,14 +302,19 @@ func (s *internalRpcServer) SendSnapshot(srv proto.OxiaLogReplication_SendSnapsh
return err
}

term, err := readTerm(md)
if err != nil {
return err
}

s.log.Info(
"Received SendSnapshot request",
slog.Int64("shard", shardId),
slog.String("namespace", namespace),
slog.String("peer", common.GetPeer(srv.Context())),
)

follower, err := s.shardsDirector.GetOrCreateFollower(namespace, shardId)
follower, err := s.shardsDirector.GetOrCreateFollower(namespace, shardId, term)
if err != nil {
s.log.Warn(
"SendSnapshot failed: could not get follower controller",
Expand Down Expand Up @@ -374,3 +384,15 @@ func ReadHeaderInt64(md metadata.MD, key string) (v int64, err error) {
_, err = fmt.Sscan(s, &r)
return r, err
}

// readTerm extracts the term carried in the stream metadata under
// common.MetadataTerm.
//
// When the metadata has no term entry it returns -1 with a nil error, so that
// callers can skip the term check; this keeps streams opened by peers that do
// not send the term working during a rollout. Otherwise the value is parsed
// with ReadHeaderInt64 and its result (including any error) is
// returned as-is.
func readTerm(md metadata.MD) (v int64, err error) {
	if values := md.Get(common.MetadataTerm); len(values) > 0 {
		return ReadHeaderInt64(md, common.MetadataTerm)
	}

	// No term was attached to the stream: report the "unknown term"
	// sentinel instead of failing the request.
	return -1, nil
}
4 changes: 2 additions & 2 deletions server/mock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,11 +98,11 @@ func (m *mockRpcClient) CloseSend() error {
return nil
}

func (m *mockRpcClient) GetReplicateStream(ctx context.Context, follower string, namespace string, shard int64) (proto.OxiaLogReplication_ReplicateClient, error) {
func (m *mockRpcClient) GetReplicateStream(ctx context.Context, follower string, namespace string, shard int64, term int64) (proto.OxiaLogReplication_ReplicateClient, error) {
return m, nil
}

func (m *mockRpcClient) SendSnapshot(ctx context.Context, follower string, namespace string, shard int64) (proto.OxiaLogReplication_SendSnapshotClient, error) {
func (m *mockRpcClient) SendSnapshot(ctx context.Context, follower string, namespace string, shard int64, term int64) (proto.OxiaLogReplication_SendSnapshotClient, error) {
return m.sendSnapshotStream, nil
}

Expand Down
6 changes: 4 additions & 2 deletions server/rpc_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func NewReplicationRpcProvider(tlsConf *tls.Config) ReplicationRpcProvider {
}
}

func (r *replicationRpcProvider) GetReplicateStream(ctx context.Context, follower string, namespace string, shard int64) (
func (r *replicationRpcProvider) GetReplicateStream(ctx context.Context, follower string, namespace string, shard int64, term int64) (
proto.OxiaLogReplication_ReplicateClient, error) {
rpc, err := r.pool.GetReplicationRpc(follower)
if err != nil {
Expand All @@ -55,12 +55,13 @@ func (r *replicationRpcProvider) GetReplicateStream(ctx context.Context, followe

ctx = metadata.AppendToOutgoingContext(ctx, common.MetadataNamespace, namespace)
ctx = metadata.AppendToOutgoingContext(ctx, common.MetadataShardId, fmt.Sprintf("%d", shard))
ctx = metadata.AppendToOutgoingContext(ctx, common.MetadataTerm, fmt.Sprintf("%d", term))

stream, err := rpc.Replicate(ctx)
return stream, err
}

func (r *replicationRpcProvider) SendSnapshot(ctx context.Context, follower string, namespace string, shard int64) (
func (r *replicationRpcProvider) SendSnapshot(ctx context.Context, follower string, namespace string, shard int64, term int64) (
proto.OxiaLogReplication_SendSnapshotClient, error) {
rpc, err := r.pool.GetReplicationRpc(follower)
if err != nil {
Expand All @@ -69,6 +70,7 @@ func (r *replicationRpcProvider) SendSnapshot(ctx context.Context, follower stri

ctx = metadata.AppendToOutgoingContext(ctx, common.MetadataNamespace, namespace)
ctx = metadata.AppendToOutgoingContext(ctx, common.MetadataShardId, fmt.Sprintf("%d", shard))
ctx = metadata.AppendToOutgoingContext(ctx, common.MetadataTerm, fmt.Sprintf("%d", term))

stream, err := rpc.SendSnapshot(ctx)
return stream, err
Expand Down
10 changes: 7 additions & 3 deletions server/shards_director.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ type ShardsDirector interface {
GetFollower(shardId int64) (FollowerController, error)

GetOrCreateLeader(namespace string, shardId int64) (LeaderController, error)
GetOrCreateFollower(namespace string, shardId int64) (FollowerController, error)
GetOrCreateFollower(namespace string, shardId int64, term int64) (FollowerController, error)

DeleteShard(req *proto.DeleteShardRequest) (*proto.DeleteShardResponse, error)
}
Expand Down Expand Up @@ -155,7 +155,7 @@ func (s *shardsDirector) GetOrCreateLeader(namespace string, shardId int64) (Lea
return lc, nil
}

func (s *shardsDirector) GetOrCreateFollower(namespace string, shardId int64) (FollowerController, error) {
func (s *shardsDirector) GetOrCreateFollower(namespace string, shardId int64, term int64) (FollowerController, error) {
s.Lock()
defer s.Unlock()

Expand All @@ -168,8 +168,12 @@ func (s *shardsDirector) GetOrCreateFollower(namespace string, shardId int64) (F
return follower, nil
} else if leader, ok := s.leaders[shardId]; ok {
// There is an existing leader controller
// Let's close it before creating the follower controller
if term >= 0 && term != leader.Term() {
// We should not close the existing leader because of a late request
return nil, common.ErrorInvalidTerm
}

// If we are in the right term, let's close the leader and reopen as a follower controller
if err := leader.Close(); err != nil {
return nil, err
}
Expand Down
38 changes: 38 additions & 0 deletions server/shards_director_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,41 @@ func TestShardsDirector_DeleteShardLeader(t *testing.T) {
assert.NoError(t, lc.Close())
assert.NoError(t, walFactory.Close())
}

// TestShardsDirector_GetOrCreateFollower verifies the term check performed by
// ShardsDirector.GetOrCreateFollower when a leader controller already exists
// for the shard:
//   - a request carrying a term different from the leader's term is rejected
//     with common.ErrorInvalidTerm and the leader keeps serving;
//   - a request carrying the leader's current term closes the leader and
//     returns a follower controller.
func TestShardsDirector_GetOrCreateFollower(t *testing.T) {
	var shard int64 = 1

	kvFactory, err := kv.NewPebbleKVFactory(testKVOptions)
	assert.NoError(t, err)
	walFactory := newTestWalFactory(t)

	sd := NewShardsDirector(Config{}, walFactory, kvFactory, newMockRpcClient())

	// Bring up a leader controller in term 2. Setup errors are asserted so
	// that a broken fixture is reported at its source rather than as a
	// confusing status/term mismatch further down.
	lc, err := sd.GetOrCreateLeader(common.DefaultNamespace, shard)
	assert.NoError(t, err)

	_, err = lc.NewTerm(&proto.NewTermRequest{Shard: shard, Term: 2})
	assert.NoError(t, err)

	_, err = lc.BecomeLeader(context.Background(), &proto.BecomeLeaderRequest{
		Shard:             shard,
		Term:              2,
		ReplicationFactor: 1,
		FollowerMaps:      nil,
	})
	assert.NoError(t, err)

	assert.Equal(t, proto.ServingStatus_LEADER, lc.Status())
	assert.EqualValues(t, 2, lc.Term())

	// A request from a stale term must not close the existing leader.
	// Note the argument order: ErrorIs(t, err, target).
	fc, err := sd.GetOrCreateFollower(common.DefaultNamespace, shard, 1)
	assert.ErrorIs(t, err, common.ErrorInvalidTerm)
	assert.Nil(t, fc)
	assert.Equal(t, proto.ServingStatus_LEADER, lc.Status())

	// With the matching term the leader is closed and replaced by a follower.
	fc, err = sd.GetOrCreateFollower(common.DefaultNamespace, shard, 2)
	assert.NoError(t, err)

	assert.Equal(t, proto.ServingStatus_NOT_MEMBER, lc.Status())

	assert.NoError(t, fc.Close())
	assert.NoError(t, lc.Close())
	assert.NoError(t, walFactory.Close())
}
4 changes: 2 additions & 2 deletions server/standalone.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,11 +162,11 @@ func (noOpReplicationRpcProvider) Close() error {
return nil
}

func (noOpReplicationRpcProvider) GetReplicateStream(context.Context, string, string, int64) (proto.OxiaLogReplication_ReplicateClient, error) {
func (noOpReplicationRpcProvider) GetReplicateStream(context.Context, string, string, int64, int64) (proto.OxiaLogReplication_ReplicateClient, error) {
panic("not implemented")
}

func (noOpReplicationRpcProvider) SendSnapshot(context.Context, string, string, int64) (proto.OxiaLogReplication_SendSnapshotClient, error) {
func (noOpReplicationRpcProvider) SendSnapshot(context.Context, string, string, int64, int64) (proto.OxiaLogReplication_SendSnapshotClient, error) {
panic("not implemented")
}

Expand Down

0 comments on commit 1934d55

Please sign in to comment.