From 1934d55f0f619971d83f43fbc56865ce9221ca92 Mon Sep 17 00:00:00 2001 From: Matteo Merli Date: Fri, 13 Sep 2024 17:28:34 -0700 Subject: [PATCH] Add term check on the internal streams, to avoid wrong closures of leader controller (#513) During a leader election, an old leader can force the new leader back into follower mode. This does not affect correctness (e.g., the old leader's requests are then ignored because the term is invalid), though it causes the shard to be unavailable until a new leader election is triggered. The simplified sequence is: 1. `oxia-1` is the old leader in term: 4 2. Coordinator initiates a new leader election and fences all the nodes 3. `oxia-0` and `oxia-2` are fenced correctly, while `oxia-1` is ignored and will be retried in the background 4. `oxia-2` is elected leader (term: 5) 5. `oxia-1` is still running the leader controller (because it never got successfully fenced), and it keeps trying to replicate to `oxia-0` and `oxia-2`. The replicate requests keep failing (because the term has already changed to 5), though the first request has the effect of putting `oxia-2` back in follower mode. We are now in a state where the coordinator is happy, thinking of `oxia-2` as the leader, though `oxia-2` has gone back to follower mode. A new leader election is not triggered because `oxia-2` looks healthy overall (the health check is done at the pod level, not at the shard level). 
--- common/constants.go | 1 + maelstrom/replication_rpc_provider.go | 4 +-- server/follower_cursor.go | 8 +++--- server/internal_rpc_server.go | 28 +++++++++++++++++--- server/mock_test.go | 4 +-- server/rpc_provider.go | 6 +++-- server/shards_director.go | 10 ++++--- server/shards_director_test.go | 38 +++++++++++++++++++++++++++ server/standalone.go | 4 +-- 9 files changed, 85 insertions(+), 18 deletions(-) diff --git a/common/constants.go b/common/constants.go index d910a103..cb8848eb 100644 --- a/common/constants.go +++ b/common/constants.go @@ -17,6 +17,7 @@ package common import "time" const ( + MetadataTerm = "term" MetadataNamespace = "namespace" MetadataShardId = "shard-id" DefaultNamespace = "default" diff --git a/maelstrom/replication_rpc_provider.go b/maelstrom/replication_rpc_provider.go index c1dacdec..4916f8c7 100644 --- a/maelstrom/replication_rpc_provider.go +++ b/maelstrom/replication_rpc_provider.go @@ -45,7 +45,7 @@ func (r *maelstromReplicationRpcProvider) Close() error { return nil } -func (r *maelstromReplicationRpcProvider) GetReplicateStream(ctx context.Context, follower string, namespace string, shard int64) ( +func (r *maelstromReplicationRpcProvider) GetReplicateStream(ctx context.Context, follower string, namespace string, shard int64, term int64) ( proto.OxiaLogReplication_ReplicateClient, error) { s := &maelstromReplicateClient{ ctx: ctx, @@ -83,7 +83,7 @@ func (r *maelstromReplicationRpcProvider) Truncate(follower string, req *proto.T return res.(*proto.TruncateResponse), nil } -func (r *maelstromReplicationRpcProvider) SendSnapshot(ctx context.Context, follower string, namespace string, shard int64) (proto.OxiaLogReplication_SendSnapshotClient, error) { +func (r *maelstromReplicationRpcProvider) SendSnapshot(ctx context.Context, follower string, namespace string, shard int64, term int64) (proto.OxiaLogReplication_SendSnapshotClient, error) { panic("not implemented") } diff --git a/server/follower_cursor.go 
b/server/follower_cursor.go index 9e1306d6..b6b970fd 100644 --- a/server/follower_cursor.go +++ b/server/follower_cursor.go @@ -40,8 +40,8 @@ import ( // This is a provider for the ReplicateStream Grpc handler // It's used to allow passing in a mocked version of the Grpc service. type ReplicateStreamProvider interface { - GetReplicateStream(ctx context.Context, follower string, namespace string, shard int64) (proto.OxiaLogReplication_ReplicateClient, error) - SendSnapshot(ctx context.Context, follower string, namespace string, shard int64) (proto.OxiaLogReplication_SendSnapshotClient, error) + GetReplicateStream(ctx context.Context, follower string, namespace string, shard int64, term int64) (proto.OxiaLogReplication_ReplicateClient, error) + SendSnapshot(ctx context.Context, follower string, namespace string, shard int64, term int64) (proto.OxiaLogReplication_SendSnapshotClient, error) } // FollowerCursor @@ -251,7 +251,7 @@ func (fc *followerCursor) sendSnapshot() error { ctx, cancel := context.WithCancel(fc.ctx) defer cancel() - stream, err := fc.replicateStreamProvider.SendSnapshot(ctx, fc.follower, fc.namespace, fc.shardId) + stream, err := fc.replicateStreamProvider.SendSnapshot(ctx, fc.follower, fc.namespace, fc.shardId, fc.term) if err != nil { return err } @@ -367,7 +367,7 @@ func (fc *followerCursor) streamEntries() error { fc.Lock() var err error - if fc.stream, err = fc.replicateStreamProvider.GetReplicateStream(ctx, fc.follower, fc.namespace, fc.shardId); err != nil { + if fc.stream, err = fc.replicateStreamProvider.GetReplicateStream(ctx, fc.follower, fc.namespace, fc.shardId, fc.term); err != nil { fc.Unlock() return err } diff --git a/server/internal_rpc_server.go b/server/internal_rpc_server.go index 0aceb53a..6a2f57d1 100644 --- a/server/internal_rpc_server.go +++ b/server/internal_rpc_server.go @@ -217,7 +217,7 @@ func (s *internalRpcServer) Truncate(c context.Context, req *proto.TruncateReque log.Info("Received Truncate request") - follower, err 
:= s.shardsDirector.GetOrCreateFollower(req.Namespace, req.Shard) + follower, err := s.shardsDirector.GetOrCreateFollower(req.Namespace, req.Shard, req.Term) if err != nil { log.Warn( "Truncate failed: could not get follower controller", @@ -252,6 +252,11 @@ func (s *internalRpcServer) Replicate(srv proto.OxiaLogReplication_ReplicateServ return err } + term, err := readTerm(md) + if err != nil { + return err + } + log := s.log.With( slog.Int64("shard", shardId), slog.String("namespace", namespace), @@ -260,7 +265,7 @@ func (s *internalRpcServer) Replicate(srv proto.OxiaLogReplication_ReplicateServ log.Info("Received Replicate request") - follower, err := s.shardsDirector.GetOrCreateFollower(namespace, shardId) + follower, err := s.shardsDirector.GetOrCreateFollower(namespace, shardId, term) if err != nil { log.Warn( "Replicate failed: could not get follower controller", @@ -297,6 +302,11 @@ func (s *internalRpcServer) SendSnapshot(srv proto.OxiaLogReplication_SendSnapsh return err } + term, err := readTerm(md) + if err != nil { + return err + } + s.log.Info( "Received SendSnapshot request", slog.Int64("shard", shardId), @@ -304,7 +314,7 @@ func (s *internalRpcServer) SendSnapshot(srv proto.OxiaLogReplication_SendSnapsh slog.String("peer", common.GetPeer(srv.Context())), ) - follower, err := s.shardsDirector.GetOrCreateFollower(namespace, shardId) + follower, err := s.shardsDirector.GetOrCreateFollower(namespace, shardId, term) if err != nil { s.log.Warn( "SendSnapshot failed: could not get follower controller", @@ -374,3 +384,15 @@ func ReadHeaderInt64(md metadata.MD, key string) (v int64, err error) { _, err = fmt.Sscan(s, &r) return r, err } + +func readTerm(md metadata.MD) (v int64, err error) { + arr := md.Get(common.MetadataTerm) + if len(arr) == 0 { + // There was no term in the metadata for the stream. 
+ // In order to retain compatibility in a rollout scenario, skip + // the term check + return -1, nil + } + + return ReadHeaderInt64(md, common.MetadataTerm) +} diff --git a/server/mock_test.go b/server/mock_test.go index ebfb4b31..31618f25 100644 --- a/server/mock_test.go +++ b/server/mock_test.go @@ -98,11 +98,11 @@ func (m *mockRpcClient) CloseSend() error { return nil } -func (m *mockRpcClient) GetReplicateStream(ctx context.Context, follower string, namespace string, shard int64) (proto.OxiaLogReplication_ReplicateClient, error) { +func (m *mockRpcClient) GetReplicateStream(ctx context.Context, follower string, namespace string, shard int64, term int64) (proto.OxiaLogReplication_ReplicateClient, error) { return m, nil } -func (m *mockRpcClient) SendSnapshot(ctx context.Context, follower string, namespace string, shard int64) (proto.OxiaLogReplication_SendSnapshotClient, error) { +func (m *mockRpcClient) SendSnapshot(ctx context.Context, follower string, namespace string, shard int64, term int64) (proto.OxiaLogReplication_SendSnapshotClient, error) { return m.sendSnapshotStream, nil } diff --git a/server/rpc_provider.go b/server/rpc_provider.go index dddbb296..388b3812 100644 --- a/server/rpc_provider.go +++ b/server/rpc_provider.go @@ -46,7 +46,7 @@ func NewReplicationRpcProvider(tlsConf *tls.Config) ReplicationRpcProvider { } } -func (r *replicationRpcProvider) GetReplicateStream(ctx context.Context, follower string, namespace string, shard int64) ( +func (r *replicationRpcProvider) GetReplicateStream(ctx context.Context, follower string, namespace string, shard int64, term int64) ( proto.OxiaLogReplication_ReplicateClient, error) { rpc, err := r.pool.GetReplicationRpc(follower) if err != nil { @@ -55,12 +55,13 @@ func (r *replicationRpcProvider) GetReplicateStream(ctx context.Context, followe ctx = metadata.AppendToOutgoingContext(ctx, common.MetadataNamespace, namespace) ctx = metadata.AppendToOutgoingContext(ctx, common.MetadataShardId, fmt.Sprintf("%d", 
shard)) + ctx = metadata.AppendToOutgoingContext(ctx, common.MetadataTerm, fmt.Sprintf("%d", term)) stream, err := rpc.Replicate(ctx) return stream, err } -func (r *replicationRpcProvider) SendSnapshot(ctx context.Context, follower string, namespace string, shard int64) ( +func (r *replicationRpcProvider) SendSnapshot(ctx context.Context, follower string, namespace string, shard int64, term int64) ( proto.OxiaLogReplication_SendSnapshotClient, error) { rpc, err := r.pool.GetReplicationRpc(follower) if err != nil { @@ -69,6 +70,7 @@ func (r *replicationRpcProvider) SendSnapshot(ctx context.Context, follower stri ctx = metadata.AppendToOutgoingContext(ctx, common.MetadataNamespace, namespace) ctx = metadata.AppendToOutgoingContext(ctx, common.MetadataShardId, fmt.Sprintf("%d", shard)) + ctx = metadata.AppendToOutgoingContext(ctx, common.MetadataTerm, fmt.Sprintf("%d", term)) stream, err := rpc.SendSnapshot(ctx) return stream, err diff --git a/server/shards_director.go b/server/shards_director.go index b8fa1e44..d2e2b620 100644 --- a/server/shards_director.go +++ b/server/shards_director.go @@ -36,7 +36,7 @@ type ShardsDirector interface { GetFollower(shardId int64) (FollowerController, error) GetOrCreateLeader(namespace string, shardId int64) (LeaderController, error) - GetOrCreateFollower(namespace string, shardId int64) (FollowerController, error) + GetOrCreateFollower(namespace string, shardId int64, term int64) (FollowerController, error) DeleteShard(req *proto.DeleteShardRequest) (*proto.DeleteShardResponse, error) } @@ -155,7 +155,7 @@ func (s *shardsDirector) GetOrCreateLeader(namespace string, shardId int64) (Lea return lc, nil } -func (s *shardsDirector) GetOrCreateFollower(namespace string, shardId int64) (FollowerController, error) { +func (s *shardsDirector) GetOrCreateFollower(namespace string, shardId int64, term int64) (FollowerController, error) { s.Lock() defer s.Unlock() @@ -168,8 +168,12 @@ func (s *shardsDirector) GetOrCreateFollower(namespace 
string, shardId int64) (F return follower, nil } else if leader, ok := s.leaders[shardId]; ok { // There is an existing leader controller - // Let's close it before creating the follower controller + if term >= 0 && term != leader.Term() { + // We should not close the existing leader because of a late request + return nil, common.ErrorInvalidTerm + } + // If we are in the right term, let's close the leader and reopen as a follower controller if err := leader.Close(); err != nil { return nil, err } diff --git a/server/shards_director_test.go b/server/shards_director_test.go index 8d731e62..c0cf2dc0 100644 --- a/server/shards_director_test.go +++ b/server/shards_director_test.go @@ -62,3 +62,41 @@ func TestShardsDirector_DeleteShardLeader(t *testing.T) { assert.NoError(t, lc.Close()) assert.NoError(t, walFactory.Close()) } + +func TestShardsDirector_GetOrCreateFollower(t *testing.T) { + var shard int64 = 1 + + kvFactory, _ := kv.NewPebbleKVFactory(testKVOptions) + walFactory := newTestWalFactory(t) + + sd := NewShardsDirector(Config{}, walFactory, kvFactory, newMockRpcClient()) + + lc, _ := sd.GetOrCreateLeader(common.DefaultNamespace, shard) + _, _ = lc.NewTerm(&proto.NewTermRequest{Shard: shard, Term: 2}) + _, _ = lc.BecomeLeader(context.Background(), &proto.BecomeLeaderRequest{ + Shard: shard, + Term: 2, + ReplicationFactor: 1, + FollowerMaps: nil, + }) + + assert.Equal(t, proto.ServingStatus_LEADER, lc.Status()) + + assert.EqualValues(t, 2, lc.Term()) + + // Should fail to get closed if the term is wrong + fc, err := sd.GetOrCreateFollower(common.DefaultNamespace, shard, 1) + assert.ErrorIs(t, common.ErrorInvalidTerm, err) + assert.Nil(t, fc) + assert.Equal(t, proto.ServingStatus_LEADER, lc.Status()) + + // Will get closed if term is correct + fc, err = sd.GetOrCreateFollower(common.DefaultNamespace, shard, 2) + assert.NoError(t, err) + + assert.Equal(t, proto.ServingStatus_NOT_MEMBER, lc.Status()) + + assert.NoError(t, fc.Close()) + assert.NoError(t, lc.Close()) 
+ assert.NoError(t, walFactory.Close()) +} diff --git a/server/standalone.go b/server/standalone.go index ccb9fed8..6f32d719 100644 --- a/server/standalone.go +++ b/server/standalone.go @@ -162,11 +162,11 @@ func (noOpReplicationRpcProvider) Close() error { return nil } -func (noOpReplicationRpcProvider) GetReplicateStream(context.Context, string, string, int64) (proto.OxiaLogReplication_ReplicateClient, error) { +func (noOpReplicationRpcProvider) GetReplicateStream(context.Context, string, string, int64, int64) (proto.OxiaLogReplication_ReplicateClient, error) { panic("not implemented") } -func (noOpReplicationRpcProvider) SendSnapshot(context.Context, string, string, int64) (proto.OxiaLogReplication_SendSnapshotClient, error) { +func (noOpReplicationRpcProvider) SendSnapshot(context.Context, string, string, int64, int64) (proto.OxiaLogReplication_SendSnapshotClient, error) { panic("not implemented") }