diff --git a/subdomains/safe_map.go b/subdomains/safe_map.go new file mode 100644 index 0000000..908074b --- /dev/null +++ b/subdomains/safe_map.go @@ -0,0 +1,71 @@ +package subdomains + +import "sync" + +func InitSafeMap[T any]() SafeMap[T] { + return SafeMap[T]{mp: map[string]T{}} +} + +func NewSafeMap[T any]() *SafeMap[T] { + return &SafeMap[T]{mp: map[string]T{}} +} + +type SafeMap[T any] struct { + mu sync.RWMutex + mp map[string]T +} + +func (a *SafeMap[T]) Size() int { + a.mu.RLock() + size := len(a.mp) + a.mu.RUnlock() + return size +} + +func (a *SafeMap[T]) GetOk(key string) (T, bool) { + a.mu.RLock() + val, ok := a.mp[key] + a.mu.RUnlock() + return val, ok +} + +func (a *SafeMap[T]) Get(key string) T { + a.mu.RLock() + val := a.mp[key] + a.mu.RUnlock() + return val +} + +func (a *SafeMap[T]) Set(key string, info T) { + a.mu.Lock() + a.mp[key] = info + a.mu.Unlock() +} + +func (a *SafeMap[T]) Remove(keys ...string) { + a.mu.Lock() + for _, key := range keys { + delete(a.mp, key) + } + a.mu.Unlock() +} + +func (a *SafeMap[T]) Range(f func(key string, val T) bool) { + a.mu.RLock() + for key, val := range a.mp { + if !f(key, val) { + break + } + } + a.mu.RUnlock() +} + +func (a *SafeMap[T]) ClearEmpty(f func(key string, val T) bool) { + a.mu.Lock() + for key, val := range a.mp { + if f(key, val) { + delete(a.mp, key) + } + } + a.mu.Unlock() +} diff --git a/subdomains/subdomains.go b/subdomains/subdomains.go index cd631b3..c21ccc3 100644 --- a/subdomains/subdomains.go +++ b/subdomains/subdomains.go @@ -1,8 +1,10 @@ package subdomains import ( + "slices" "sort" "strings" + "sync/atomic" "github.com/admpub/log" @@ -17,8 +19,8 @@ var Default = New() func New() *Subdomains { s := &Subdomains{ - Hosts: map[string][]string{}, - Alias: map[string]*Info{}, + Hosts: InitSafeMap[*[]string](), + Alias: InitSafeMap[*Info](), Default: ``, Protocol: `http`, } @@ -75,9 +77,10 @@ func (info *Info) RelativeURLByName(s *Subdomains, name string, args ...interfac type Dispatcher func(r engine.Request, w engine.Response) (*echo.Echo, bool) type Subdomains struct { - Hosts map[string][]string //{host:name} - Alias map[string]*Info + Hosts SafeMap[*[]string] //{host:name} + Alias SafeMap[*Info] Prefixes []string + hostsNum atomic.Int32 Default string //default name Protocol string //http/https Boot string @@ -109,13 +112,34 @@ func (s *Subdomains) Add(name string, e *echo.Echo) *Subdomains { hosts = append(hosts, ``) } } + var hasRemoved bool + var addedHosts int + var appendsHosts []string for _, host := range hosts { - if _, ok := s.Hosts[host]; !ok { - s.Hosts[host] = []string{name} - } else if !com.InSlice(name, s.Hosts[host]) { - s.Hosts[host] = append(s.Hosts[host], name) + if aliases, ok := s.Hosts.GetOk(host); !ok { + s.Hosts.Set(host, &[]string{name}) + addedHosts++ + } else if !com.InSlice(name, *aliases) { + *aliases = append(*aliases, name) + appendsHosts = append(appendsHosts, host) } } + s.Hosts.Range(func(host string, aliases *[]string) bool { + index := slices.Index(*aliases, name) + if index > -1 && !com.InSlice(host, hosts) { + *aliases = slices.Delete(*aliases, index, index+1) + hasRemoved = true + } + return true + }) + if hasRemoved { + s.Hosts.ClearEmpty(func(_ string, val *[]string) bool { + return len(*val) == 0 + }) + } + if addedHosts > 0 || hasRemoved { + s.hostsNum.Store(int32(s.Hosts.Size())) + } info := &Info{ Protocol: `http`, Name: name, @@ -135,7 +159,14 @@ func (s *Subdomains) Add(name string, e *echo.Echo) *Subdomains { } } } - s.Alias[name] = info + s.Alias.Set(name, info) + if len(e.Prefix()) > 0 { + for _, host := range appendsHosts { + if aliases, ok := s.Hosts.GetOk(host); ok { + s.sort(*aliases) + } + } + } return s } @@ -144,16 +175,17 @@ func (s *Subdomains) Get(args ...string) *Info { if len(args) > 0 { name = args[0] } - if e, ok := s.Alias[name]; ok { + if e, ok := s.Alias.GetOk(name); ok { return e } return nil } func (s *Subdomains) SetDebug(on bool) *Subdomains { - for _, info := range s.Alias { + s.Alias.Range(func(key string, info *Info) bool { info.SetDebug(on) - } + return true + }) return s } @@ -214,46 +246,48 @@ func (s *Subdomains) RelativeURLByName(name string, params ...interface{}) strin func (s *Subdomains) sort(names []string) []string { sort.Slice(names, func(i, j int) bool { - return len(s.Alias[names[i]].Prefix()) > len(s.Alias[names[j]].Prefix()) + return len(s.Alias.Get(names[i]).Prefix()) > len(s.Alias.Get(names[j]).Prefix()) }) return names } -func (s *Subdomains) sortHosts() { - for k := range s.Hosts { - s.Hosts[k] = s.sort(s.Hosts[k]) - } +func (s *Subdomains) SortHosts() { + s.Hosts.Range(func(key string, val *[]string) bool { + s.sort(*val) + return true + }) } func (s *Subdomains) FindByDomain(host string, upath string) (*echo.Echo, bool) { var ( - names []string + names *[]string exists bool ) - if len(s.Hosts) == 1 && len(s.Hosts[``]) > 0 { - names = s.Hosts[``] - exists = true - } else { - names, exists = s.Hosts[host] + if s.hostsNum.Load() == 1 { + names = s.Hosts.Get(``) + exists = names != nil && len(*names) > 0 + } + if !exists { + names, exists = s.Hosts.GetOk(host) if !exists { if p := strings.LastIndexByte(host, ':'); p > -1 { - names, exists = s.Hosts[host[0:p]] + names, exists = s.Hosts.GetOk(host[0:p]) if !exists { - names, exists = s.Hosts[``] + names, exists = s.Hosts.GetOk(``) } } } } var info *Info - if exists { - for _, name := range names { - info, exists = s.Alias[name] + if exists && names != nil { + for _, name := range *names { + info, exists = s.Alias.GetOk(name) if exists && (upath == info.Prefix() || strings.HasPrefix(upath, info.Prefix()+`/`)) { return info.Echo, exists } } } - info, exists = s.Alias[s.Default] + info, exists = s.Alias.GetOk(s.Default) if exists { return info.Echo, exists } @@ -277,20 +311,22 @@ func (s *Subdomains) Ready() *Info { if s.dispatcher == nil { s.dispatcher = s.DefaultDispatcher } - s.sortHosts() + s.hostsNum.Store(int32(s.Hosts.Size())) + s.SortHosts() e := s.Get(s.Boot) if e == nil { - for _, info := range s.Alias { + s.Alias.Range(func(key string, info *Info) bool { e = info - break - } + return false + }) } - for _, info := range s.Alias { + s.Alias.Range(func(key string, info *Info) bool { if e == info { - continue + return true } info.Commit() - } + return true + }) return e } diff --git a/subdomains/subdomains_test.go b/subdomains/subdomains_test.go index a142d50..7ea8d35 100644 --- a/subdomains/subdomains_test.go +++ b/subdomains/subdomains_test.go @@ -4,6 +4,7 @@ import ( "net/http" "testing" + "github.com/admpub/log" "github.com/stretchr/testify/assert" "github.com/webx-top/echo" "github.com/webx-top/echo/engine" @@ -14,7 +15,9 @@ func request(method, path string, h engine.Handler, reqRewrite ...func(*http.Req rec := test.Request(method, path, h, reqRewrite...) return rec.Code, rec.Body.String() } + func TestSortHosts(t *testing.T) { + defer log.Close() a := New() e := echo.New() e.Get(`/`, func(c echo.Context) error { @@ -32,7 +35,8 @@ func TestSortHosts(t *testing.T) { a.Add(`backend`, e2) a.Ready().Commit() - assert.Equal(t, []string{`backend`, `frontend`}, a.Hosts[``]) + assert.Equal(t, []string{`backend`, `frontend`}, *a.Hosts.Get(``)) + assert.Equal(t, int32(1), a.hostsNum.Load()) c, b := request(echo.GET, "/", a) assert.Equal(t, http.StatusOK, c) @@ -49,4 +53,25 @@ func TestSortHosts(t *testing.T) { c, b = request(echo.GET, "/adminindex", a) assert.Equal(t, http.StatusNotFound, c) assert.Equal(t, http.StatusText(http.StatusNotFound), b) + + e3 := echo.New() + e3.Get(``, func(c echo.Context) error { + return c.String(`backend`) + }) + e3.Get(`/index`, func(c echo.Context) error { + return c.String(`backend-index`) + }) + a.Add(`backend@github.com,coscms.com`, e3) + assert.Equal(t, []string{`frontend`}, *a.Hosts.Get(``)) + assert.Equal(t, []string{`backend`}, *a.Hosts.Get(`github.com`)) + assert.Equal(t, []string{`backend`}, *a.Hosts.Get(`coscms.com`)) + assert.Equal(t, int32(3), a.hostsNum.Load()) + + e4 := echo.New() + e4.SetPrefix(`/portal`) + e4.Get(`/`, func(c echo.Context) error { + return c.String(`portal`) + }) + a.Add(`portal`, e4) + assert.Equal(t, []string{`portal`, `frontend`}, *a.Hosts.Get(``)) }