diff --git a/coordinator/impl/coordinator.go b/coordinator/impl/coordinator.go
index 97367a2f..91251059 100644
--- a/coordinator/impl/coordinator.go
+++ b/coordinator/impl/coordinator.go
@@ -54,6 +54,8 @@ type Coordinator interface {
 
     NodeAvailabilityListener
 
+    FindServerAddressByInternalAddress(internalAddress string) (*model.ServerAddress, bool)
+
     ClusterStatus() model.ClusterStatus
 }
 
@@ -64,6 +66,8 @@ type coordinator struct {
     MetadataProvider
     clusterConfigProvider func() (model.ClusterConfig, error)
     model.ClusterConfig
+    serverIndexes         sync.Map
+
     clusterConfigChangeCh chan any
 
     shardControllers map[int64]ShardController
@@ -100,6 +104,7 @@ func NewCoordinator(metadataProvider MetadataProvider,
         shardControllers: make(map[int64]ShardController),
         nodeControllers:  make(map[string]NodeController),
         drainingNodes:    make(map[string]NodeController),
+        serverIndexes:    sync.Map{},
         rpc:              rpc,
         log: slog.With(
             slog.String("component", "coordinator"),
@@ -117,6 +122,7 @@ func NewCoordinator(metadataProvider MetadataProvider,
 
     for _, sa := range c.ClusterConfig.Servers {
         c.nodeControllers[sa.Internal] = NewNodeController(sa, c, c, c.rpc)
+        c.serverIndexes.Store(sa.Internal, sa)
     }
 
     if c.clusterStatus == nil {
@@ -446,6 +452,10 @@ func (c *coordinator) handleClusterConfigUpdated() error {
         slog.Any("metadataVersion", c.metadataVersion),
     )
 
+    for _, sc := range c.shardControllers {
+        sc.SyncServerAddress()
+    }
+
     c.checkClusterNodeChanges(newClusterConfig)
 
     clusterStatus, shardsToAdd, shardsToDelete := applyClusterChanges(&newClusterConfig, c.clusterStatus)
@@ -512,6 +522,17 @@ func (c *coordinator) rebalanceCluster() error {
     return nil
 }
 
+func (c *coordinator) FindServerAddressByInternalAddress(internalAddress string) (*model.ServerAddress, bool) {
+    if info, exist := c.serverIndexes.Load(internalAddress); exist {
+        address, ok := info.(model.ServerAddress)
+        if !ok {
+            panic("unexpected cast")
+        }
+        return &address, true
+    }
+    return nil, false
+}
+
 func (*coordinator) findServerByInternalAddress(newClusterConfig model.ClusterConfig, server string) *model.ServerAddress {
     for _, s := range newClusterConfig.Servers {
         if server == s.Internal {
@@ -525,6 +546,8 @@ func (*coordinator) findServerByInternalAddress(newClusterConfig model.ClusterCo
 func (c *coordinator) checkClusterNodeChanges(newClusterConfig model.ClusterConfig) {
     // Check for nodes to add
     for _, sa := range newClusterConfig.Servers {
+        c.serverIndexes.Store(sa.Internal, sa)
+
         if _, ok := c.nodeControllers[sa.Internal]; ok {
             continue
         }
@@ -548,6 +571,7 @@ func (c *coordinator) checkClusterNodeChanges(newClusterConf
         }
 
         c.log.Info("Detected a removed node", slog.Any("addr", ia))
+        c.serverIndexes.Delete(ia)
         // Moved the node
         delete(c.nodeControllers, ia)
         nc.SetStatus(Draining)
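For reference, a minimal, self-contained sketch of the `sync.Map` index pattern introduced above, keyed by the internal address and storing value-typed entries. The `ServerAddress` struct and `serverIndex` type here are stand-ins for illustration, not the project's `model.ServerAddress` or the coordinator itself:

```go
package main

import (
	"fmt"
	"sync"
)

// ServerAddress is a local stand-in for model.ServerAddress.
type ServerAddress struct {
	Public   string
	Internal string
}

type serverIndex struct {
	servers sync.Map // internal address -> ServerAddress (stored by value)
}

func (i *serverIndex) Store(sa ServerAddress) { i.servers.Store(sa.Internal, sa) }

func (i *serverIndex) Delete(internal string) { i.servers.Delete(internal) }

// FindByInternal mirrors FindServerAddressByInternalAddress: it returns a copy
// of the stored value, so callers never hold a pointer into the map.
func (i *serverIndex) FindByInternal(internal string) (*ServerAddress, bool) {
	info, ok := i.servers.Load(internal)
	if !ok {
		return nil, false
	}
	sa, ok := info.(ServerAddress)
	if !ok {
		panic("unexpected cast")
	}
	return &sa, true
}

func main() {
	idx := &serverIndex{}
	idx.Store(ServerAddress{Public: "localhost:6648", Internal: "localhost:6649"})

	if sa, ok := idx.FindByInternal("localhost:6649"); ok {
		fmt.Println(sa.Public) // localhost:6648
	}

	idx.Delete("localhost:6649")
	_, found := idx.FindByInternal("localhost:6649")
	fmt.Println(found) // false
}
```

Storing the address by value and returning a copy keeps lookups safe under concurrent config updates, since callers never share a mutable pointer with the index.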
diff --git a/coordinator/impl/coordinator_e2e_test.go b/coordinator/impl/coordinator_e2e_test.go
index b62f6ecc..6043b8b9 100644
--- a/coordinator/impl/coordinator_e2e_test.go
+++ b/coordinator/impl/coordinator_e2e_test.go
@@ -19,6 +19,7 @@ import (
     "fmt"
     "log/slog"
     "math"
+    "strings"
     "sync"
     "testing"
     "time"
@@ -769,3 +770,74 @@ func checkServerLists(t *testing.T, expected, actual []model.ServerAddress) {
         assert.True(t, ok)
     }
 }
+
+func TestCoordinator_RefreshServerInfo(t *testing.T) {
+    s1, sa1 := newServer(t)
+    s2, sa2 := newServer(t)
+    s3, sa3 := newServer(t)
+
+    metadataProvider := NewMetadataProviderMemory()
+    clusterConfig := model.ClusterConfig{
+        Namespaces: []model.NamespaceConfig{{
+            Name:              "my-ns-1",
+            ReplicationFactor: 3,
+            InitialShardCount: 1,
+        }},
+        Servers: []model.ServerAddress{sa1, sa2, sa3},
+    }
+    configChangesCh := make(chan any)
+    c, err := NewCoordinator(metadataProvider, func() (model.ClusterConfig, error) {
+        return clusterConfig, nil
+    }, configChangesCh,
+        NewRpcProvider(common.NewClientPool(nil, nil)))
+    assert.NoError(t, err)
+
+    // wait for all shards to be ready
+    assert.Eventually(t, func() bool {
+        for _, ns := range c.ClusterStatus().Namespaces {
+            for _, shard := range ns.Shards {
+                if shard.Status != model.ShardStatusSteadyState {
+                    return false
+                }
+            }
+        }
+        return true
+    }, 10*time.Second, 10*time.Millisecond)
+
+    // change the localhost to 127.0.0.1
+    clusterServer := make([]model.ServerAddress, 0)
+    for _, sv := range clusterConfig.Servers {
+        clusterServer = append(clusterServer, model.ServerAddress{
+            Public:   strings.ReplaceAll(sv.Public, "localhost", "127.0.0.1"),
+            Internal: sv.Internal,
+        })
+    }
+
+    clusterConfig.Servers = clusterServer
+    configChangesCh <- nil
+
+    assert.Eventually(t, func() bool {
+        for _, ns := range c.ClusterStatus().Namespaces {
+            for _, shard := range ns.Shards {
+                if shard.Status != model.ShardStatusSteadyState {
+                    return false
+                }
+                for _, sv := range shard.Ensemble {
+                    if !strings.HasPrefix(sv.Public, "127.0.0.1") {
+                        return false
+                    }
+                }
+            }
+        }
+        return true
+    }, 10*time.Second, 10*time.Millisecond)
+
+    err = s1.Close()
+    assert.NoError(t, err)
+    err = s2.Close()
+    assert.NoError(t, err)
+    err = s3.Close()
+    assert.NoError(t, err)
+    err = c.Close()
+    assert.NoError(t, err)
+}
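The test above drives the refresh purely through the cluster-config provider: only `Public` is rewritten from `localhost` to `127.0.0.1`, while `Internal` stays fixed, so each node keeps its identity in `nodeControllers` and `serverIndexes` and only the published ensemble address needs refreshing. A small standalone sketch of that rewrite step, using a local `ServerAddress` stand-in and a hypothetical `rewritePublic` helper that is not part of the codebase:

```go
package main

import (
	"fmt"
	"strings"
)

// ServerAddress is a local stand-in for model.ServerAddress.
type ServerAddress struct {
	Public   string
	Internal string
}

// rewritePublic mirrors the loop in the test: it swaps the advertised Public
// address while leaving the Internal address (the node's identity) untouched.
func rewritePublic(servers []ServerAddress) []ServerAddress {
	out := make([]ServerAddress, 0, len(servers))
	for _, sv := range servers {
		out = append(out, ServerAddress{
			Public:   strings.ReplaceAll(sv.Public, "localhost", "127.0.0.1"),
			Internal: sv.Internal,
		})
	}
	return out
}

func main() {
	servers := []ServerAddress{
		{Public: "localhost:6648", Internal: "localhost:6649"},
		{Public: "localhost:6650", Internal: "localhost:6651"},
	}
	for _, sv := range rewritePublic(servers) {
		fmt.Printf("%s -> %s\n", sv.Internal, sv.Public)
	}
}
```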
diff --git a/coordinator/impl/shard_controller.go b/coordinator/impl/shard_controller.go
index cb760060..36ebf421 100644
--- a/coordinator/impl/shard_controller.go
+++ b/coordinator/impl/shard_controller.go
@@ -20,6 +20,7 @@ import (
     "io"
     "log/slog"
     "math/rand"
+    "reflect"
     "sync"
     "time"
 
@@ -64,6 +65,8 @@ type ShardController interface {
 
     HandleNodeFailure(failedNode model.ServerAddress)
 
+    SyncServerAddress()
+
     SwapNode(from model.ServerAddress, to model.ServerAddress) error
 
     DeleteShard()
@@ -77,10 +80,11 @@ type shardController struct {
     shard              int64
     namespaceConfig    *model.NamespaceConfig
     shardMetadata      model.ShardMetadata
-    shardMetadataMutex sync.Mutex
+    shardMetadataMutex sync.RWMutex
     rpc                RpcProvider
     coordinator        Coordinator
 
+    electionOp              chan any
     deleteOp                chan any
     nodeFailureOp           chan model.ServerAddress
     swapNodeOp              chan swapNodeRequest
@@ -110,6 +114,7 @@ func NewShardController(namespace string, shard int64, namespaceConfig *model.Na
         shardMetadata:          shardMetadata,
         rpc:                    rpc,
         coordinator:            coordinator,
+        electionOp:             make(chan any, chanBufferSize),
         deleteOp:               make(chan any, chanBufferSize),
         nodeFailureOp:          make(chan model.ServerAddress, chanBufferSize),
        swapNodeOp:             make(chan swapNodeRequest, chanBufferSize),
@@ -169,6 +174,8 @@ func (s *shardController) run() {
 
     if !s.verifyCurrentEnsemble() {
         s.electLeaderWithRetries()
+    } else {
+        s.SyncServerAddress()
     }
 }
 
@@ -193,6 +200,9 @@ func (s *shardController) run() {
 
         case a := <-s.newTermAndAddFollowerOp:
             s.internalNewTermAndAddFollower(a.ctx, a.node, a.res)
+
+        case <-s.electionOp:
+            s.electLeaderWithRetries()
         }
     }
 }
@@ -209,7 +219,7 @@ func (s *shardController) handleNodeFailure(failedNode model.ServerAddress) {
     )
 
     if s.shardMetadata.Leader != nil &&
-        *s.shardMetadata.Leader == failedNode {
+        s.shardMetadata.Leader.Internal == failedNode.Internal {
         s.log.Info(
             "Detected failure on shard leader",
             slog.Any("leader", failedNode),
@@ -295,6 +305,8 @@ func (s *shardController) electLeader() error {
     s.shardMetadata.Status = model.ShardStatusElection
     s.shardMetadata.Leader = nil
     s.shardMetadata.Term++
+    // It's a safe point to refresh the ensemble's server addresses
+    s.shardMetadata.Ensemble = s.getRefreshedEnsemble()
     s.shardMetadataMutex.Unlock()
 
     s.log.Info(
@@ -369,6 +381,25 @@ func (s *shardController) electLeader() error {
     return nil
 }
 
+func (s *shardController) getRefreshedEnsemble() []model.ServerAddress {
+    currentEnsemble := s.shardMetadata.Ensemble
+    refreshedEnsembleServiceAddress := make([]model.ServerAddress, len(currentEnsemble))
+    for idx, candidate := range currentEnsemble {
+        if refreshedAddress, exist := s.coordinator.FindServerAddressByInternalAddress(candidate.Internal); exist {
+            refreshedEnsembleServiceAddress[idx] = *refreshedAddress
+            continue
+        }
+        refreshedEnsembleServiceAddress[idx] = candidate
+    }
+    if s.log.Enabled(s.ctx, slog.LevelDebug) {
+        if !reflect.DeepEqual(currentEnsemble, refreshedEnsembleServiceAddress) {
+            s.log.Debug("Refreshing the shard ensemble server addresses", slog.Any("current-ensemble", currentEnsemble),
+                slog.Any("new-ensemble", refreshedEnsembleServiceAddress))
+        }
+    }
+    return refreshedEnsembleServiceAddress
+}
+
 func (s *shardController) deletingRemovedNodes() error {
     for _, ds := range s.shardMetadata.RemovedNodes {
         if _, err := s.rpc.DeleteShard(s.ctx, ds, &proto.DeleteShardRequest{
@@ -866,9 +897,29 @@ func (s *shardController) waitForFollowersToCatchUp(ctx context.Context, leader
     return nil
 }
 
+func (s *shardController) SyncServerAddress() {
+    s.shardMetadataMutex.RLock()
+    exist := false
+    for _, candidate := range s.shardMetadata.Ensemble {
+        if newInfo, ok := s.coordinator.FindServerAddressByInternalAddress(candidate.Internal); ok {
+            if *newInfo != candidate {
+                exist = true
+                break
+            }
+        }
+    }
+    if !exist {
+        s.shardMetadataMutex.RUnlock()
+        return
+    }
+    s.shardMetadataMutex.RUnlock()
+    s.log.Info("Server address changed, triggering a new leader election")
+    s.electionOp <- nil
+}
+
 func listContains(list []model.ServerAddress, sa model.ServerAddress) bool {
     for _, item := range list {
-        if item.Public == sa.Public && item.Internal == sa.Internal {
+        if item.Internal == sa.Internal {
             return true
         }
     }
@@ -887,7 +938,7 @@ func mergeLists[T any](lists ...[]T) []T {
 func replaceInList(list []model.ServerAddress, oldServerAddress, newServerAddress model.ServerAddress) []model.ServerAddress {
     var res []model.ServerAddress
     for _, item := range list {
-        if item.Public != oldServerAddress.Public && item.Internal != oldServerAddress.Internal {
+        if item.Internal != oldServerAddress.Internal {
             res = append(res, item)
         }
     }
diff --git a/coordinator/impl/shard_controller_test.go b/coordinator/impl/shard_controller_test.go
index d798cd7d..eeef2850 100644
--- a/coordinator/impl/shard_controller_test.go
+++ b/coordinator/impl/shard_controller_test.go
@@ -365,6 +365,10 @@ func (m *mockCoordinator) WaitForNextUpdate(ctx context.Context, currentValue *p
     panic("not implemented")
 }
 
+func (m *mockCoordinator) FindServerAddressByInternalAddress(_ string) (*model.ServerAddress, bool) {
+    return nil, false
+}
+
 func (m *mockCoordinator) InitiateLeaderElection(namespace string, shard int64, metadata model.ShardMetadata) error {
     m.Lock()
     defer m.Unlock()
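To complement the mock above, a self-contained sketch of the refresh rule that `getRefreshedEnsemble` implements: ensemble members whose internal address is known to the lookup get the refreshed address, everything else is left unchanged. The `refreshedEnsemble` helper, the `lookup` type, and the local `ServerAddress` struct are illustrative stand-ins, not the repository's actual test code:

```go
package main

import (
	"reflect"
	"testing"
)

// ServerAddress is a local stand-in for model.ServerAddress.
type ServerAddress struct {
	Public   string
	Internal string
}

// lookup plays the role of Coordinator.FindServerAddressByInternalAddress.
type lookup func(internal string) (*ServerAddress, bool)

// refreshedEnsemble applies the same rule as getRefreshedEnsemble: refresh
// known members, keep unknown members as-is, and preserve ordering.
func refreshedEnsemble(ensemble []ServerAddress, find lookup) []ServerAddress {
	out := make([]ServerAddress, len(ensemble))
	for i, candidate := range ensemble {
		if refreshed, ok := find(candidate.Internal); ok {
			out[i] = *refreshed
			continue
		}
		out[i] = candidate
	}
	return out
}

func TestRefreshedEnsemble(t *testing.T) {
	index := map[string]ServerAddress{
		"node-1:6649": {Public: "127.0.0.1:6648", Internal: "node-1:6649"},
	}
	find := func(internal string) (*ServerAddress, bool) {
		sa, ok := index[internal]
		if !ok {
			return nil, false
		}
		return &sa, true
	}

	ensemble := []ServerAddress{
		{Public: "localhost:6648", Internal: "node-1:6649"}, // known: refreshed
		{Public: "localhost:6650", Internal: "node-2:6651"}, // unknown: unchanged
	}
	want := []ServerAddress{
		{Public: "127.0.0.1:6648", Internal: "node-1:6649"},
		{Public: "localhost:6650", Internal: "node-2:6651"},
	}
	if got := refreshedEnsemble(ensemble, find); !reflect.DeepEqual(got, want) {
		t.Fatalf("got %v, want %v", got, want)
	}
}
```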